spark_cryptography/secret_sharing/shamir_new.rs

1use bitcoin::secp256k1::SecretKey;
2use elliptic_curve::sec1::ToEncodedPoint;
3use elliptic_curve::{generic_array::GenericArray, PrimeField};
4use k256::{elliptic_curve::Field, ProjectivePoint, Scalar};
5use spark_protos::spark::SecretShare;
6use thiserror::Error;
7
8#[derive(Error, Debug)]
9pub enum VSSError {
10    #[error("Invalid threshold: {0}")]
11    InvalidThreshold(String),
12    #[error("Insufficient shares: {0}")]
13    InsufficientShares(String),
14    #[error("Math error: {0}")]
15    MathError(String),
16    #[error("Verification failed")]
17    VerificationFailed,
18}
19
/// Parameters of a `threshold`-of-`total_shares` Feldman verifiable secret sharing.
///
/// Any `threshold` shares reconstruct the secret; `total_shares` shares are produced by a
/// split. Both values are validated by `VSS::new`.
///
/// The type holds only two counts, so it is cheap to copy and safe to print.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct VSS {
    /// Minimum number of shares required for reconstruction.
    threshold: usize,
    /// Number of shares produced when splitting a secret.
    total_shares: usize,
}
24
25pub struct Share {
26    x: Scalar,
27    y: Scalar,
28    commitment: Vec<ProjectivePoint>,
29}
30
31impl Share {
32    pub fn to_bytes(&self) -> Vec<u8> {
33        self.y.to_bytes().to_vec()
34    }
35
36    pub fn marshal_proto(&self) -> SecretShare {
37        // Convert commitments to bytes (similar to proofs in Go)
38        let proofs: Vec<Vec<u8>> = self
39            .commitment
40            .iter()
41            .map(|point| point.to_affine().to_encoded_point(true).as_bytes().to_vec())
42            .collect();
43
44        SecretShare {
45            secret_share: self.to_bytes(),
46            proofs,
47        }
48    }
49}
50
51impl VSS {
52    pub fn new(threshold: usize, total_shares: usize) -> Result<Self, VSSError> {
53        if threshold > total_shares {
54            return Err(VSSError::InvalidThreshold(
55                "Threshold cannot exceed total shares".to_string(),
56            ));
57        }
58        if threshold < 1 || total_shares < 1 {
59            return Err(VSSError::InvalidThreshold(
60                "Threshold and shares must be positive".to_string(),
61            ));
62        }
63
64        Ok(VSS {
65            threshold,
66            total_shares,
67        })
68    }
69
70    pub fn split_from_secret_key(&self, sk: &SecretKey) -> Result<Vec<Share>, VSSError> {
71        let bytes = sk.secret_bytes();
72        let generic_array = GenericArray::from_slice(&bytes);
73        let scalar = Scalar::from_repr(*generic_array).unwrap();
74        self.split_from_scalar(&scalar)
75    }
76
77    pub fn split_from_scalar(&self, secret: &Scalar) -> Result<Vec<Share>, VSSError> {
78        let mut rng = rand::thread_rng();
79        let g = ProjectivePoint::GENERATOR;
80
81        // Generate random polynomial coefficients
82        let mut coefficients = vec![*secret];
83        for _ in 1..self.threshold {
84            coefficients.push(Scalar::random(&mut rng));
85        }
86
87        // Generate commitments for verifiability (Feldman's scheme)
88        let mut commitments = Vec::with_capacity(self.threshold);
89        for coef in &coefficients {
90            commitments.push(g * coef);
91        }
92
93        // Generate shares
94        let mut shares = Vec::with_capacity(self.total_shares);
95        for x in 1..=self.total_shares {
96            let x_val = Scalar::from(x as u32);
97            let mut y = Scalar::ZERO;
98
99            // Evaluate polynomial at x using Horner's method
100            for coef in coefficients.iter().rev() {
101                y = y * x_val + coef;
102            }
103
104            shares.push(Share {
105                x: x_val,
106                y,
107                commitment: commitments.clone(),
108            });
109        }
110
111        Ok(shares)
112    }
113
114    pub fn verify_share(&self, share: &Share) -> Result<bool, VSSError> {
115        let g = ProjectivePoint::GENERATOR;
116        let left = g * share.y;
117
118        let mut right = ProjectivePoint::IDENTITY;
119        for (i, commit) in share.commitment.iter().enumerate() {
120            let exp = share.x.pow_vartime([i as u64, 0, 0, 0]);
121            right = right + (*commit * exp);
122        }
123
124        Ok(left == right)
125    }
126
127    pub fn reconstruct(&self, shares: &[Share]) -> Result<Scalar, VSSError> {
128        if shares.len() < self.threshold {
129            return Err(VSSError::InsufficientShares(format!(
130                "Need at least {} shares, got {}",
131                self.threshold,
132                shares.len()
133            )));
134        }
135
136        // Lagrange interpolation in scalar field
137        let mut secret = Scalar::ZERO;
138        for i in 0..self.threshold {
139            let mut numerator = Scalar::ONE;
140            let mut denominator = Scalar::ONE;
141
142            for j in 0..self.threshold {
143                if i != j {
144                    numerator *= shares[j].x;
145                    denominator *= shares[j].x - shares[i].x;
146                }
147            }
148
149            let li = shares[i].y * numerator * denominator.invert().unwrap();
150            secret += li;
151        }
152
153        Ok(secret)
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use k256::SecretKey;

    #[test]
    fn test_basic_split_reconstruct() {
        let vss = VSS::new(3, 5).unwrap();
        let secret = Scalar::random(&mut rand::thread_rng());
        let shares = vss.split_from_scalar(&secret).unwrap();

        assert_eq!(shares.len(), 5);
        let reconstructed = vss.reconstruct(&shares[0..3]).unwrap();
        assert_eq!(reconstructed, secret);
    }

    /// Any `threshold`-sized subset must reconstruct the secret, not only the leading shares.
    #[test]
    fn test_reconstruct_from_other_subsets() {
        let vss = VSS::new(3, 5).unwrap();
        let secret = Scalar::random(&mut rand::thread_rng());
        let shares = vss.split_from_scalar(&secret).unwrap();

        assert_eq!(vss.reconstruct(&shares[2..5]).unwrap(), secret);

        // Non-contiguous subset: the shares at x = 1, 3, 5.
        let spaced: Vec<Share> = shares.into_iter().step_by(2).collect();
        assert_eq!(spaced.len(), 3);
        assert_eq!(vss.reconstruct(&spaced).unwrap(), secret);
    }

    #[test]
    fn test_verification() {
        let vss = VSS::new(2, 3).unwrap();
        let secret = Scalar::random(&mut rand::thread_rng());
        let shares = vss.split_from_scalar(&secret).unwrap();

        for share in &shares {
            assert!(vss.verify_share(share).unwrap());
        }
    }

    /// Changing a share's value must make it fail verification against its commitments.
    #[test]
    fn test_verification_rejects_tampered_share() {
        let vss = VSS::new(2, 3).unwrap();
        let secret = Scalar::random(&mut rand::thread_rng());
        let mut shares = vss.split_from_scalar(&secret).unwrap();

        shares[0].y += Scalar::ONE;
        assert!(!vss.verify_share(&shares[0]).unwrap());
    }

    #[test]
    fn test_insufficient_shares() {
        let vss = VSS::new(3, 5).unwrap();
        let secret = Scalar::random(&mut rand::thread_rng());
        let shares = vss.split_from_scalar(&secret).unwrap();

        assert!(matches!(
            vss.reconstruct(&shares[0..2]),
            Err(VSSError::InsufficientShares(_))
        ));
    }

    #[test]
    fn test_invalid_threshold() {
        assert!(matches!(VSS::new(5, 3), Err(VSSError::InvalidThreshold(_))));

        assert!(matches!(VSS::new(0, 5), Err(VSSError::InvalidThreshold(_))));

        assert!(matches!(VSS::new(0, 0), Err(VSSError::InvalidThreshold(_))));
    }

    #[test]
    fn test_with_secret_key() {
        let vss = VSS::new(2, 3).unwrap();
        let sk = SecretKey::random(&mut rand::thread_rng());
        let secret = sk.to_nonzero_scalar();
        let shares = vss.split_from_scalar(&secret).unwrap();

        let reconstructed = vss.reconstruct(&shares[0..2]).unwrap();
        assert_eq!(reconstructed, *secret);
    }

    /// Exercises `split_from_secret_key`. The `bitcoin` key is built from the same bytes as
    /// a k256 key, so the reconstructed scalar must match the k256 view of that key.
    #[test]
    fn test_split_from_bitcoin_secret_key() {
        let vss = VSS::new(2, 3).unwrap();
        let k256_sk = SecretKey::random(&mut rand::thread_rng());
        let btc_sk = bitcoin::secp256k1::SecretKey::from_slice(&k256_sk.to_bytes()).unwrap();

        let shares = vss.split_from_secret_key(&btc_sk).unwrap();
        let reconstructed = vss.reconstruct(&shares[1..3]).unwrap();
        assert_eq!(reconstructed, *k256_sk.to_nonzero_scalar());
    }
}