spark_cryptography/secret_sharing/
secret_sharing.rs

1// use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
2use k256::{
3    elliptic_curve::{generic_array::GenericArray, PrimeField},
4    AffinePoint, ProjectivePoint, PublicKey, Scalar,
5};
6use rand::{rngs::OsRng, RngCore};
7use std::error::Error;
8
9fn scalar_to_pubkey(secret: &k256::Scalar) -> PublicKey {
10    let point = ProjectivePoint::GENERATOR * *secret;
11    PublicKey::from_affine(AffinePoint::from(point)).expect("invalid public key")
12}
13
14/// Polynomial used for secret sharing
15#[derive(Clone)]
16pub struct Polynomial {
17    /// Field modulus of the polynomial
18    pub field_modulus: Scalar,
19
20    /// Coefficients of the polynomial
21    coefficients: Vec<Scalar>,
22
23    /// Proofs of the polynomial
24    pub proofs: Vec<Vec<u8>>,
25}
26
27/// Trait for Lagrange interpolation
28pub trait LagrangeInterpolatable {
29    /// Returns the index of the share
30    fn get_index(&self) -> &Scalar;
31
32    /// Returns the share value
33    fn get_share(&self) -> &Scalar;
34
35    /// Returns the field modulus
36    fn get_field_modulus(&self) -> &Scalar;
37
38    /// Returns the threshold
39    fn get_threshold(&self) -> usize;
40}
41
42/// Basic secret share structure
43#[derive(Debug, Clone)]
44pub struct SecretShare {
45    /// Field modulus of the share
46    pub field_modulus: Scalar,
47
48    /// Threshold of the secret
49    pub threshold: usize,
50
51    /// Index of the share
52    pub index: Scalar,
53
54    /// Share value
55    pub share: Scalar,
56}
57
58impl LagrangeInterpolatable for SecretShare {
59    fn get_index(&self) -> &Scalar {
60        &self.index
61    }
62
63    fn get_field_modulus(&self) -> &Scalar {
64        &self.field_modulus
65    }
66
67    fn get_share(&self) -> &Scalar {
68        &self.share
69    }
70
71    fn get_threshold(&self) -> usize {
72        self.threshold
73    }
74}
75
76/// Verifiable secret share with proofs
77#[derive(Debug, Clone)]
78pub struct VerifiableSecretShare {
79    /// Base secret share
80    pub secret_share: SecretShare,
81
82    /// Proofs for verification
83    pub proofs: Vec<Vec<u8>>,
84}
85
86impl VerifiableSecretShare {
87    pub fn marshal_proto(&self) -> spark_protos::spark::SecretShare {
88        spark_protos::spark::SecretShare {
89            secret_share: self.secret_share.share.to_bytes().to_vec(),
90            proofs: self.proofs.clone(),
91        }
92    }
93}
94
95impl LagrangeInterpolatable for VerifiableSecretShare {
96    fn get_index(&self) -> &Scalar {
97        &self.secret_share.index
98    }
99
100    fn get_field_modulus(&self) -> &Scalar {
101        &self.secret_share.field_modulus
102    }
103
104    fn get_share(&self) -> &Scalar {
105        &self.secret_share.share
106    }
107
108    fn get_threshold(&self) -> usize {
109        self.secret_share.threshold
110    }
111}
112
113impl Polynomial {
114    /// Evaluates the polynomial at a given point
115    pub fn evaluate(&self, x: &Scalar) -> Scalar {
116        let mut result = Scalar::ZERO;
117        let mut x_power = Scalar::ONE;
118
119        for coeff in self.coefficients.iter() {
120            result += *coeff * x_power;
121            x_power *= x;
122        }
123        result
124    }
125}
126
127/// Performs field division in the given modulus
128fn field_div(
129    numerator: &k256::Scalar,
130    denominator: &k256::Scalar,
131    _field_modulus: &k256::Scalar,
132) -> Result<k256::Scalar, String> {
133    if bool::from(denominator.is_zero()) {
134        return Err("division by zero".to_string());
135    }
136
137    let inverse = denominator
138        .invert()
139        .into_option()
140        .ok_or("element not invertible".to_string())?;
141    Ok(*numerator * inverse)
142}
143
144/// Computes Lagrange coefficients for interpolation
145pub fn compute_lagrange_coefficients<T: LagrangeInterpolatable>(
146    index: &Scalar,
147    points: &[T],
148) -> Result<Scalar, String> {
149    let mut numerator = Scalar::ONE;
150    let mut denominator = Scalar::ONE;
151    let field_modulus = points[0].get_field_modulus();
152
153    for point in points {
154        if point.get_index() == index {
155            continue;
156        }
157        numerator = numerator * point.get_index();
158        let value = point.get_index() - index;
159        denominator = denominator * value;
160    }
161
162    field_div(&numerator, &denominator, &field_modulus)
163}
164
165fn bytes_to_scalar(bytes: &[u8]) -> Result<Scalar, String> {
166    // Disallow anything larger than 32 bytes.
167    if bytes.len() != 32 {
168        return Err(format!(
169            "Invalid byte length for scalar. Expected 32, got {}",
170            bytes.len()
171        ));
172    }
173
174    // Convert bytes directly to a GenericArray
175    let arr = GenericArray::clone_from_slice(bytes);
176
177    // Attempt to create Scalar from bytes representation
178    let scalar_opt = Scalar::from_repr(arr);
179
180    // Convert CtOption to Result
181    scalar_opt
182        .into_option()
183        .ok_or_else(|| "Failed to parse Scalar (out of range)".to_string())
184}
185
186/// Generates a polynomial for secret sharing
187fn generate_polynomial_for_secret_sharing(
188    field_modulus: &Scalar,
189    secret: &Scalar,
190    threshold: usize,
191) -> Result<Polynomial, Box<dyn Error>> {
192    let mut coefficients = Vec::with_capacity(threshold + 1);
193    let mut proofs = Vec::with_capacity(threshold + 1);
194
195    // Set the constant term (secret)
196    coefficients.push(*secret);
197
198    // Generate proof for secret
199    proofs.push(scalar_to_pubkey(secret).to_sec1_bytes().to_vec());
200
201    // Generate random coefficients for higher terms
202    for _ in 1..=threshold {
203        let mut random_bytes = [0u8; 32];
204        OsRng.fill_bytes(&mut random_bytes);
205
206        // Convert to scalar and ensure it's within field modulus
207        let random_scalar = bytes_to_scalar(&random_bytes)?;
208        coefficients.push(random_scalar);
209        proofs.push(scalar_to_pubkey(&random_scalar).to_sec1_bytes().to_vec());
210    }
211
212    Ok(Polynomial {
213        field_modulus: *field_modulus,
214        coefficients,
215        proofs,
216    })
217}
218
219/// Splits a secret into shares
220pub fn split_secret(
221    secret: &Scalar,
222    field_modulus: &Scalar,
223    threshold: usize,
224    number_of_shares: usize,
225) -> Result<Vec<SecretShare>, Box<dyn Error>> {
226    let polynomial = generate_polynomial_for_secret_sharing(field_modulus, secret, threshold - 1)?;
227
228    let mut shares = Vec::with_capacity(number_of_shares);
229    for i in 1..=number_of_shares {
230        let index = Scalar::from(i as u64);
231        let share = polynomial.evaluate(&index);
232
233        shares.push(SecretShare {
234            field_modulus: *field_modulus,
235            threshold,
236            index,
237            share,
238        });
239    }
240
241    Ok(shares)
242}
243
244/// Helper function to perform modular exponentiation for k256::Scalar
245fn scalar_modpow(base: &Scalar, exp: usize, _modulus: &Scalar) -> Scalar {
246    if exp == 0 {
247        return Scalar::ONE;
248    }
249
250    let mut result = Scalar::ONE;
251    let mut base = *base;
252    let mut exp = exp;
253
254    while exp > 0 {
255        if exp & 1 == 1 {
256            result *= base;
257        }
258        base *= base;
259        exp >>= 1;
260    }
261    result
262}
263
264/// Splits a secret into verifiable shares
265pub fn split_secret_with_proofs(
266    secret: &[u8],
267    field_modulus: &[u8],
268    threshold: usize,
269    number_of_shares: usize,
270) -> Result<Vec<VerifiableSecretShare>, Box<dyn Error>> {
271    // Ensure secret is valid scalar
272    let secret_scalar = bytes_to_scalar(secret)?;
273    let field_modulus_scalar = bytes_to_scalar(field_modulus)?;
274
275    // Validate inputs
276    if threshold == 0 || threshold > number_of_shares {
277        return Err("Invalid threshold".into());
278    }
279
280    let polynomial = generate_polynomial_for_secret_sharing(
281        &field_modulus_scalar,
282        &secret_scalar,
283        threshold - 1,
284    )?;
285
286    let mut shares = Vec::with_capacity(number_of_shares);
287    for i in 1..=number_of_shares {
288        let index = Scalar::from(i as u64);
289        let share = polynomial.evaluate(&index);
290
291        shares.push(VerifiableSecretShare {
292            secret_share: SecretShare {
293                field_modulus: field_modulus_scalar,
294                threshold,
295                index,
296                share,
297            },
298            proofs: polynomial.proofs.clone(),
299        });
300    }
301
302    Ok(shares)
303}
304
305/// Recovers a secret from a set of shares
306pub fn recover_secret<T: LagrangeInterpolatable>(shares: &[T]) -> Result<Scalar, String> {
307    if shares.len() < shares[0].get_threshold() {
308        return Err("not enough shares to recover secret".to_string());
309    }
310
311    let mut result = Scalar::ZERO;
312
313    for share in shares {
314        let coeff = compute_lagrange_coefficients(share.get_index(), shares)?;
315        result += share.get_share() * &coeff;
316    }
317
318    Ok(result)
319}
320
321/// Validates a verifiable share
322pub fn validate_share(share: &VerifiableSecretShare) -> Result<(), String> {
323    let target_pubkey = scalar_to_pubkey(&share.secret_share.share);
324    let mut result = ProjectivePoint::IDENTITY;
325
326    // Add the base proof
327    if let Some(base_proof) = share.proofs.first() {
328        let base_pubkey = PublicKey::from_sec1_bytes(base_proof).map_err(|e| e.to_string())?;
329        result += ProjectivePoint::from(base_pubkey.as_affine());
330    }
331
332    // Add the higher-degree terms
333    for (i, proof) in share.proofs.iter().enumerate().skip(1) {
334        let pubkey = PublicKey::from_sec1_bytes(proof).map_err(|e| e.to_string())?;
335        let exp = scalar_modpow(
336            &share.secret_share.index,
337            i,
338            &share.secret_share.field_modulus,
339        );
340        result += ProjectivePoint::from(pubkey.as_affine()) * exp;
341    }
342
343    if AffinePoint::from(result) == *target_pubkey.as_affine() {
344        Ok(())
345    } else {
346        Err("Share validation failed".to_string())
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    use crate::secp256k1::CURVE_ORDER;
    use k256::elliptic_curve::bigint::{Encoding as _, U256};

    /// Hex encoding of the secp256k1 group order `n`.
    const CURVE_ORDER_HEX: &str =
        "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141";

    /// Returns the fixed 32-byte encoding of `scalar`.
    fn scalar_to_32_bytes(scalar: &Scalar) -> [u8; 32] {
        let bytes: [u8; 32] = scalar.to_bytes().into();
        bytes
    }

    /// Returns `CURVE_ORDER - 1` as big-endian bytes, so that it parses as an
    /// in-range scalar for use as a "field modulus".
    fn get_test_field_modulus_in_range() -> Vec<u8> {
        let order = U256::from_be_slice(&hex::decode(CURVE_ORDER).unwrap());
        order.saturating_sub(&U256::ONE).to_be_bytes().to_vec()
    }

    /// Returns a 32-byte secret whose first five bytes are `1..=5` and whose
    /// remaining bytes are zero.
    fn padded_test_secret() -> [u8; 32] {
        let mut bytes = [0u8; 32];
        bytes[..5].copy_from_slice(&[1, 2, 3, 4, 5]);
        bytes
    }

    #[test]
    fn test_secret_sharing_basic() -> Result<(), Box<dyn Error>> {
        let secret = bytes_to_scalar(&padded_test_secret())?;
        let field_modulus = bytes_to_scalar(&get_test_field_modulus_in_range())?;

        let shares = split_secret(&secret, &field_modulus, 3, 5)?;
        let recovered = recover_secret(&shares[..3])?;

        assert_eq!(secret, recovered);
        Ok(())
    }

    #[test]
    fn test_verifiable_secret_sharing() -> Result<(), Box<dyn Error>> {
        let secret_arr = padded_test_secret();
        let field_modulus = get_test_field_modulus_in_range();

        let shares = split_secret_with_proofs(&secret_arr, &field_modulus, 3, 5)?;

        // Every share must check out against the commitments.
        shares.iter().try_for_each(validate_share)?;

        // Any three shares reconstruct the original secret.
        let recovered = recover_secret(&shares[..3])?;
        assert_eq!(bytes_to_scalar(&secret_arr)?, recovered);

        Ok(())
    }

    #[test]
    fn test_share_bytes_compatibility() -> Result<(), Box<dyn Error>> {
        let field_modulus = get_test_field_modulus_in_range();
        let secret_bytes = vec![0x11; 32];

        let shares = split_secret_with_proofs(&secret_bytes, &field_modulus, 3, 5)?;

        // Snapshot every share value as raw bytes.
        let original_share_bytes: Vec<[u8; 32]> = shares
            .iter()
            .map(|s| scalar_to_32_bytes(&s.secret_share.share))
            .collect();

        for share in &shares {
            validate_share(share)?;
        }

        // Rebuild each share from its serialized bytes.
        let mut new_shares: Vec<VerifiableSecretShare> = Vec::with_capacity(shares.len());
        for (position, (original, bytes)) in shares.iter().zip(&original_share_bytes).enumerate() {
            new_shares.push(VerifiableSecretShare {
                secret_share: SecretShare {
                    field_modulus: original.secret_share.field_modulus,
                    threshold: original.secret_share.threshold,
                    index: Scalar::from((position + 1) as u64),
                    share: bytes_to_scalar(bytes).unwrap(),
                },
                proofs: original.proofs.clone(),
            });
        }

        // Rebuilt shares must still verify.
        for share in &new_shares {
            validate_share(share)?;
        }

        // Round-tripping through bytes must not alter any share.
        for (orig, rebuilt) in shares.iter().zip(&new_shares) {
            assert_eq!(
                scalar_to_32_bytes(&orig.secret_share.share),
                scalar_to_32_bytes(&rebuilt.secret_share.share),
                "Share bytes don't match after reconstruction"
            );
        }

        // Both sets of shares reconstruct the same secret.
        let recovered_orig = recover_secret(&shares[0..3])?;
        let recovered_new = recover_secret(&new_shares[0..3])?;
        assert_eq!(recovered_orig, recovered_new);

        Ok(())
    }

    #[test]
    fn test_secret_sharing_out_of_range_share() {
        // Exactly `n`, which is not a canonical k256 scalar.
        let out_of_range_bytes = hex::decode(CURVE_ORDER_HEX).expect("invalid hex for n");

        match bytes_to_scalar(&out_of_range_bytes) {
            Ok(_) => panic!("Expected out-of-range error, but got an Ok(Scalar)?"),
            Err(e) => println!("Got expected error: {:?}", e),
        }
    }
}