// spark_cryptography/key_arithmetic.rs

use bitcoin::secp256k1::{Error as Secp256k1Error, PublicKey, Secp256k1, SecretKey};

3/// Add two public keys
4/// Input keys must be 33-byte compressed secp256k1 public keys
5pub fn add_public_keys(key1: &[u8], key2: &[u8]) -> Result<Vec<u8>, Secp256k1Error> {
6    if key1.len() != 33 || key2.len() != 33 {
7        return Err(Secp256k1Error::InvalidPublicKey);
8    }
9
10    let pub1 = PublicKey::from_slice(key1)?;
11    let pub2 = PublicKey::from_slice(key2)?;
12
13    // Combine the public keys
14    let combined = pub1.combine(&pub2)?;
15
16    // Return compressed format (33 bytes)
17    Ok(combined.serialize().to_vec())
18}
19
20pub fn apply_public_key_tweak(pubkey: &[u8], tweak: &[u8]) -> Result<Vec<u8>, Secp256k1Error> {
21    if pubkey.len() != 33 || tweak.len() != 32 {
22        return Err(Secp256k1Error::InvalidPublicKey);
23    }
24
25    let secp = Secp256k1::new();
26    let pub_key = PublicKey::from_slice(pubkey)?;
27
28    // Apply the tweak
29    let tweak_key = SecretKey::from_slice(tweak)?;
30    let tweaked_key = pub_key.add_exp_tweak(&secp, &tweak_key.into())?;
31
32    // Return compressed format
33    Ok(tweaked_key.serialize().to_vec())
34}
35
36/// Subtract public keys
37/// Input keys must be 33-byte compressed secp256k1 public keys
38pub fn subtract_public_keys(key1: &[u8], key2: &[u8]) -> Result<Vec<u8>, Secp256k1Error> {
39    if key1.len() != 33 || key2.len() != 33 {
40        return Err(Secp256k1Error::InvalidPublicKey);
41    }
42
43    let secp = Secp256k1::new();
44    let pub1 = PublicKey::from_slice(key1)?;
45    let pub2 = PublicKey::from_slice(key2)?;
46
47    // Negate the second key and add
48    let negated = pub2.negate(&secp);
49    let result = pub1.combine(&negated)?;
50
51    Ok(result.serialize().to_vec())
52}
53
54/// Add private keys
55/// Input keys must be 32 bytes
56pub fn add_private_keys(key1: &[u8], key2: &[u8]) -> Result<Vec<u8>, Secp256k1Error> {
57    if key1.len() != 32 || key2.len() != 32 {
58        return Err(Secp256k1Error::InvalidSecretKey);
59    }
60
61    let sec1 = SecretKey::from_slice(key1)?;
62    let sec2 = SecretKey::from_slice(key2)?;
63
64    // Add the keys modulo curve order
65    let combined = sec1.add_tweak(&sec2.into())?;
66
67    Ok(combined.secret_bytes().to_vec())
68}
69
70/// Subtract private keys
71/// Input keys must be 32 bytes
72pub fn subtract_secret_keys<T: AsRef<[u8]>, U: AsRef<[u8]>>(
73    key1: T,
74    key2: U,
75) -> Result<Vec<u8>, Secp256k1Error> {
76    if key1.as_ref().len() != 32 || key2.as_ref().len() != 32 {
77        return Err(Secp256k1Error::InvalidSecretKey);
78    }
79
80    let sec1 = SecretKey::from_slice(key1.as_ref())?;
81    let sec2 = SecretKey::from_slice(key2.as_ref())?;
82
83    // Negate the second key and add
84    let negated = sec2.negate();
85    let result = sec1.add_tweak(&negated.into())?;
86
87    Ok(result.secret_bytes().to_vec())
88}
89
90/// Sum of private keys
91/// Returns the sum of the given private keys modulo the curve order
92pub fn sum_private_keys(keys: &[&[u8]]) -> Result<Vec<u8>, Secp256k1Error> {
93    if keys.is_empty() {
94        return Err(Secp256k1Error::InvalidSecretKey);
95    }
96
97    let mut result = SecretKey::from_slice(keys[0])?;
98
99    // Add all subsequent keys
100    for key in keys.iter().skip(1) {
101        if key.len() != 32 {
102            return Err(Secp256k1Error::InvalidSecretKey);
103        }
104        let next_key = SecretKey::from_slice(key)?;
105        result = result.add_tweak(&next_key.into())?;
106    }
107
108    Ok(result.secret_bytes().to_vec())
109}
110
111/// Last key with target
112/// Tweaks the given keys so their sum equals the target
113pub fn last_key_with_target(keys: &[&[u8]], target: &[u8]) -> Result<Vec<u8>, Secp256k1Error> {
114    if target.len() != 32 {
115        return Err(Secp256k1Error::InvalidSecretKey);
116    }
117
118    let target_key = SecretKey::from_slice(target)?;
119    let current_sum = sum_private_keys(keys)?;
120    let current_sum_key = SecretKey::from_slice(&current_sum)?;
121
122    // Calculate target - sum(keys)
123    let result = subtract_secret_keys(&target_key.secret_bytes(), &current_sum_key.secret_bytes())?;
124
125    Ok(result)
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use bitcoin::secp256k1::{rand::rngs::OsRng, Secp256k1};

    /// Generates `n` random secret keys, each serialized as 32 bytes.
    fn generate_test_keys(n: usize) -> Vec<Vec<u8>> {
        let secp = Secp256k1::new();
        let mut keys = Vec::with_capacity(n);
        for _ in 0..n {
            let (secret_key, _) = secp.generate_keypair(&mut OsRng);
            keys.push(secret_key.secret_bytes().to_vec());
        }
        keys
    }

    /// Borrows each owned key as a slice, the shape `sum_private_keys` expects.
    fn as_slices(keys: &[Vec<u8>]) -> Vec<&[u8]> {
        keys.iter().map(Vec::as_slice).collect()
    }

    /// Reference sum: folds `add_private_keys` over `keys` one pair at a time.
    fn fold_with_add_private_keys(keys: &[Vec<u8>]) -> Vec<u8> {
        let (first, rest) = keys.split_first().expect("at least one key");
        rest.iter().fold(first.clone(), |acc, key| {
            add_private_keys(&acc, key).expect("Failed to add private keys")
        })
    }

    #[test]
    fn test_key_additions_multiple() {
        let secp = Secp256k1::new();

        // Test with multiple pairs of keys
        for _ in 0..10 {
            let (priv_a, pub_a) = secp.generate_keypair(&mut OsRng);
            let (priv_b, pub_b) = secp.generate_keypair(&mut OsRng);

            let priv_sum =
                add_private_keys(&priv_a.secret_bytes(), &priv_b.secret_bytes()).unwrap();

            let pub_sum = add_public_keys(&pub_a.serialize(), &pub_b.serialize()).unwrap();

            let sum_key = SecretKey::from_slice(&priv_sum).unwrap();
            let sum_pub = PublicKey::from_secret_key(&secp, &sum_key);

            assert_eq!(sum_pub.serialize().to_vec(), pub_sum);
        }
    }

    #[test]
    fn test_subtract_keys() {
        let secp = Secp256k1::new();

        // Test that (A + B) - B = A for both public and private keys
        let (priv_a, pub_a) = secp.generate_keypair(&mut OsRng);
        let (priv_b, pub_b) = secp.generate_keypair(&mut OsRng);

        // Test private key subtraction
        let sum = add_private_keys(&priv_a.secret_bytes(), &priv_b.secret_bytes()).unwrap();
        let diff = subtract_secret_keys(&sum, &priv_b.secret_bytes()).unwrap();
        assert_eq!(diff, priv_a.secret_bytes());

        // Test public key subtraction
        let pub_sum = add_public_keys(&pub_a.serialize(), &pub_b.serialize()).unwrap();
        let pub_diff = subtract_public_keys(&pub_sum, &pub_b.serialize()).unwrap();
        assert_eq!(pub_diff, pub_a.serialize());
    }

    #[test]
    #[should_panic(expected = "InvalidSecretKey")]
    fn test_invalid_private_key() {
        // Test with invalid private key (all zeros)
        let zeros = vec![0u8; 32];
        let (valid_key, _) = Secp256k1::new().generate_keypair(&mut OsRng);

        add_private_keys(&zeros, &valid_key.secret_bytes()).unwrap();
    }

    #[test]
    #[should_panic(expected = "InvalidPublicKey")]
    fn test_invalid_public_key() {
        // Test with invalid public key (all zeros)
        let zeros = vec![0u8; 33];
        let (_, valid_pub) = Secp256k1::new().generate_keypair(&mut OsRng);

        add_public_keys(&zeros, &valid_pub.serialize()).unwrap();
    }

    #[test]
    fn test_wrong_length_inputs_are_rejected() {
        let (secret, public) = Secp256k1::new().generate_keypair(&mut OsRng);
        let secret = secret.secret_bytes();
        let public = public.serialize();

        // Secret-key helpers require exactly 32 bytes.
        assert_eq!(
            add_private_keys(&secret[..31], &secret),
            Err(Secp256k1Error::InvalidSecretKey)
        );
        assert_eq!(
            subtract_secret_keys(&secret, &secret[..31]),
            Err(Secp256k1Error::InvalidSecretKey)
        );
        assert_eq!(sum_private_keys(&[]), Err(Secp256k1Error::InvalidSecretKey));

        // Public-key helpers require the 33-byte compressed encoding.
        assert_eq!(
            add_public_keys(&public[..32], &public),
            Err(Secp256k1Error::InvalidPublicKey)
        );
        assert_eq!(
            subtract_public_keys(&public, &public[..32]),
            Err(Secp256k1Error::InvalidPublicKey)
        );
    }

    #[test]
    fn test_sum_of_private_keys_large_set() {
        // Test with a larger set of keys (100 keys)
        let keys = generate_test_keys(100);

        let sum = sum_private_keys(&as_slices(&keys)).unwrap();

        // Verify by adding one by one
        assert_eq!(sum, fold_with_add_private_keys(&keys));
    }

    #[test]
    fn test_last_key_with_target_properties() {
        let secp = Secp256k1::new();

        // Generate random target and keys
        let (target_key, _) = secp.generate_keypair(&mut OsRng);
        let keys = generate_test_keys(5);

        // Calculate tweak
        let tweak = last_key_with_target(&as_slices(&keys), &target_key.secret_bytes()).unwrap();

        // 1. Tweak should be a valid private key
        assert!(SecretKey::from_slice(&tweak).is_ok());

        // 2. Sum of keys + tweak should equal target
        let mut keys_with_tweak = keys.clone();
        keys_with_tweak.push(tweak);

        let sum = sum_private_keys(&as_slices(&keys_with_tweak)).unwrap();
        assert_eq!(sum, target_key.secret_bytes());
    }

    #[test]
    fn test_apply_tweak_associative_property() {
        let secp = Secp256k1::new();

        // Test that (P + t1) + t2 = P + (t1 + t2)
        let (_, base_pub) = secp.generate_keypair(&mut OsRng);
        let (tweak1, _) = secp.generate_keypair(&mut OsRng);
        let (tweak2, _) = secp.generate_keypair(&mut OsRng);

        // Method 1: (P + t1) + t2
        let intermediate =
            apply_public_key_tweak(&base_pub.serialize(), &tweak1.secret_bytes()).unwrap();
        let result1 = apply_public_key_tweak(&intermediate, &tweak2.secret_bytes()).unwrap();

        // Method 2: P + (t1 + t2)
        let combined_tweak =
            add_private_keys(&tweak1.secret_bytes(), &tweak2.secret_bytes()).unwrap();
        let result2 = apply_public_key_tweak(&base_pub.serialize(), &combined_tweak).unwrap();

        assert_eq!(result1, result2);
    }

    #[test]
    fn test_key_additions() {
        let secp = Secp256k1::new();

        // Generate first key pair
        let (priv_a, pub_a) = secp.generate_keypair(&mut OsRng);
        let priv_a_bytes = priv_a.secret_bytes();
        let pub_a_bytes = pub_a.serialize();

        // Generate second key pair
        let (priv_b, pub_b) = secp.generate_keypair(&mut OsRng);
        let priv_b_bytes = priv_b.secret_bytes();
        let pub_b_bytes = pub_b.serialize();

        // Test that public key of private key addition equals the public key addition
        let priv_sum =
            add_private_keys(&priv_a_bytes, &priv_b_bytes).expect("Failed to add private keys");
        let pub_sum =
            add_public_keys(&pub_a_bytes, &pub_b_bytes).expect("Failed to add public keys");

        let target_key =
            SecretKey::from_slice(&priv_sum).expect("Failed to create secret key from sum");
        let target_pub = PublicKey::from_secret_key(&secp, &target_key);

        assert_eq!(
            target_pub.serialize().to_vec(),
            pub_sum,
            "Public key of private key addition does not equal the public key addition"
        );
    }

    #[test]
    fn test_sum_of_private_keys() {
        // Generate 10 random keys
        let keys = generate_test_keys(10);

        // Calculate sum using sum_private_keys
        let sum = sum_private_keys(&as_slices(&keys)).expect("Failed to sum private keys");

        // Calculate sum manually using add_private_keys
        let sum2 = fold_with_add_private_keys(&keys);

        assert_eq!(sum, sum2, "Sum of private keys does not match");
    }

    #[test]
    fn test_private_key_tweak_with_target() {
        let secp = Secp256k1::new();

        // Generate target key
        let (target_key, _) = secp.generate_keypair(&mut OsRng);
        let target_bytes = target_key.secret_bytes();

        // Generate 10 random keys
        let mut keys = generate_test_keys(10);

        // Calculate tweak
        let tweak = last_key_with_target(&as_slices(&keys), &target_bytes)
            .expect("Failed to calculate tweak");

        // Add tweak to keys and re-sum
        keys.push(tweak);
        let sum = sum_private_keys(&as_slices(&keys))
            .expect("Failed to sum private keys with tweak");

        assert_eq!(
            sum,
            target_bytes.to_vec(),
            "Private key tweak with target does not match"
        );
    }

    #[test]
    fn test_apply_additive_tweak_to_public_key() {
        let secp = Secp256k1::new();

        // Generate initial keypair
        let (priv_key, pub_key) = secp.generate_keypair(&mut OsRng);
        let priv_key_bytes = priv_key.secret_bytes();
        let pub_key_bytes = pub_key.serialize();

        // Generate tweak
        let (tweak_key, _) = secp.generate_keypair(&mut OsRng);
        let tweak_bytes = tweak_key.secret_bytes();

        // Calculate new private key by adding
        let new_priv =
            add_private_keys(&priv_key_bytes, &tweak_bytes).expect("Failed to add private keys");
        let target_key =
            SecretKey::from_slice(&new_priv).expect("Failed to create secret key from sum");
        let target_pub = PublicKey::from_secret_key(&secp, &target_key);

        // Apply tweak to public key
        let new_pub_key = apply_public_key_tweak(&pub_key_bytes, &tweak_bytes)
            .expect("Failed to apply tweak to public key");

        assert_eq!(
            new_pub_key,
            target_pub.serialize(),
            "Apply additive tweak to public key does not match"
        );
    }
}