spark_cryptography/secret_sharing/
shamir_new.rs1use bitcoin::secp256k1::SecretKey;
2use elliptic_curve::sec1::ToEncodedPoint;
3use elliptic_curve::{generic_array::GenericArray, PrimeField};
4use k256::{elliptic_curve::Field, ProjectivePoint, Scalar};
5use spark_protos::spark::SecretShare;
6use thiserror::Error;
7
8#[derive(Error, Debug)]
9pub enum VSSError {
10 #[error("Invalid threshold: {0}")]
11 InvalidThreshold(String),
12 #[error("Insufficient shares: {0}")]
13 InsufficientShares(String),
14 #[error("Math error: {0}")]
15 MathError(String),
16 #[error("Verification failed")]
17 VerificationFailed,
18}
19
/// Parameters of a `threshold`-of-`total_shares` verifiable secret sharing scheme.
///
/// The fields are private; construct values with [`VSS::new`], which validates
/// the parameters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VSS {
    /// Minimum number of shares needed to reconstruct the secret. This is also the
    /// number of coefficients of the sharing polynomial (its degree plus one).
    threshold: usize,
    /// Number of shares produced by each split.
    total_shares: usize,
}
24
25pub struct Share {
26 x: Scalar,
27 y: Scalar,
28 commitment: Vec<ProjectivePoint>,
29}
30
31impl Share {
32 pub fn to_bytes(&self) -> Vec<u8> {
33 self.y.to_bytes().to_vec()
34 }
35
36 pub fn marshal_proto(&self) -> SecretShare {
37 let proofs: Vec<Vec<u8>> = self
39 .commitment
40 .iter()
41 .map(|point| point.to_affine().to_encoded_point(true).as_bytes().to_vec())
42 .collect();
43
44 SecretShare {
45 secret_share: self.to_bytes(),
46 proofs,
47 }
48 }
49}
50
51impl VSS {
52 pub fn new(threshold: usize, total_shares: usize) -> Result<Self, VSSError> {
53 if threshold > total_shares {
54 return Err(VSSError::InvalidThreshold(
55 "Threshold cannot exceed total shares".to_string(),
56 ));
57 }
58 if threshold < 1 || total_shares < 1 {
59 return Err(VSSError::InvalidThreshold(
60 "Threshold and shares must be positive".to_string(),
61 ));
62 }
63
64 Ok(VSS {
65 threshold,
66 total_shares,
67 })
68 }
69
70 pub fn split_from_secret_key(&self, sk: &SecretKey) -> Result<Vec<Share>, VSSError> {
71 let bytes = sk.secret_bytes();
72 let generic_array = GenericArray::from_slice(&bytes);
73 let scalar = Scalar::from_repr(*generic_array).unwrap();
74 self.split_from_scalar(&scalar)
75 }
76
77 pub fn split_from_scalar(&self, secret: &Scalar) -> Result<Vec<Share>, VSSError> {
78 let mut rng = rand::thread_rng();
79 let g = ProjectivePoint::GENERATOR;
80
81 let mut coefficients = vec![*secret];
83 for _ in 1..self.threshold {
84 coefficients.push(Scalar::random(&mut rng));
85 }
86
87 let mut commitments = Vec::with_capacity(self.threshold);
89 for coef in &coefficients {
90 commitments.push(g * coef);
91 }
92
93 let mut shares = Vec::with_capacity(self.total_shares);
95 for x in 1..=self.total_shares {
96 let x_val = Scalar::from(x as u32);
97 let mut y = Scalar::ZERO;
98
99 for coef in coefficients.iter().rev() {
101 y = y * x_val + coef;
102 }
103
104 shares.push(Share {
105 x: x_val,
106 y,
107 commitment: commitments.clone(),
108 });
109 }
110
111 Ok(shares)
112 }
113
114 pub fn verify_share(&self, share: &Share) -> Result<bool, VSSError> {
115 let g = ProjectivePoint::GENERATOR;
116 let left = g * share.y;
117
118 let mut right = ProjectivePoint::IDENTITY;
119 for (i, commit) in share.commitment.iter().enumerate() {
120 let exp = share.x.pow_vartime([i as u64, 0, 0, 0]);
121 right = right + (*commit * exp);
122 }
123
124 Ok(left == right)
125 }
126
127 pub fn reconstruct(&self, shares: &[Share]) -> Result<Scalar, VSSError> {
128 if shares.len() < self.threshold {
129 return Err(VSSError::InsufficientShares(format!(
130 "Need at least {} shares, got {}",
131 self.threshold,
132 shares.len()
133 )));
134 }
135
136 let mut secret = Scalar::ZERO;
138 for i in 0..self.threshold {
139 let mut numerator = Scalar::ONE;
140 let mut denominator = Scalar::ONE;
141
142 for j in 0..self.threshold {
143 if i != j {
144 numerator *= shares[j].x;
145 denominator *= shares[j].x - shares[i].x;
146 }
147 }
148
149 let li = shares[i].y * numerator * denominator.invert().unwrap();
150 secret += li;
151 }
152
153 Ok(secret)
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    // Renamed so it does not shadow the `bitcoin::secp256k1::SecretKey` brought in
    // by `super::*`.
    use k256::SecretKey as K256SecretKey;

    #[test]
    fn test_basic_split_reconstruct() {
        let vss = VSS::new(3, 5).unwrap();
        let secret = Scalar::random(&mut rand::thread_rng());
        let shares = vss.split_from_scalar(&secret).unwrap();

        assert_eq!(shares.len(), 5);
        let reconstructed = vss.reconstruct(&shares[0..3]).unwrap();
        assert_eq!(reconstructed, secret);
    }

    #[test]
    fn test_reconstruct_from_any_subset() {
        let vss = VSS::new(3, 5).unwrap();
        let secret = Scalar::random(&mut rand::thread_rng());
        let shares = vss.split_from_scalar(&secret).unwrap();

        // Interpolation at zero must not depend on which shares are supplied.
        assert_eq!(vss.reconstruct(&shares[2..5]).unwrap(), secret);
        assert_eq!(vss.reconstruct(&shares[1..4]).unwrap(), secret);
        // Supplying more than `threshold` shares is accepted.
        assert_eq!(vss.reconstruct(&shares).unwrap(), secret);
    }

    #[test]
    fn test_verification() {
        let vss = VSS::new(2, 3).unwrap();
        let secret = Scalar::random(&mut rand::thread_rng());
        let shares = vss.split_from_scalar(&secret).unwrap();

        for share in &shares {
            assert!(vss.verify_share(share).unwrap());
        }
    }

    #[test]
    fn test_verification_rejects_tampered_share() {
        let vss = VSS::new(2, 3).unwrap();
        let secret = Scalar::random(&mut rand::thread_rng());
        let shares = vss.split_from_scalar(&secret).unwrap();

        let tampered = Share {
            x: shares[0].x,
            y: shares[0].y + Scalar::ONE,
            commitment: shares[0].commitment.clone(),
        };
        assert!(!vss.verify_share(&tampered).unwrap());
    }

    #[test]
    fn test_insufficient_shares() {
        let vss = VSS::new(3, 5).unwrap();
        let secret = Scalar::random(&mut rand::thread_rng());
        let shares = vss.split_from_scalar(&secret).unwrap();

        assert!(matches!(
            vss.reconstruct(&shares[0..2]),
            Err(VSSError::InsufficientShares(_))
        ));
    }

    #[test]
    fn test_invalid_threshold() {
        assert!(matches!(VSS::new(5, 3), Err(VSSError::InvalidThreshold(_))));

        assert!(matches!(VSS::new(0, 5), Err(VSSError::InvalidThreshold(_))));
    }

    #[test]
    fn test_with_secret_key() {
        let vss = VSS::new(2, 3).unwrap();
        let sk = K256SecretKey::random(&mut rand::thread_rng());
        let secret = sk.to_nonzero_scalar();
        let shares = vss.split_from_scalar(&secret).unwrap();

        let reconstructed = vss.reconstruct(&shares[0..2]).unwrap();
        assert_eq!(reconstructed, *secret);
    }

    #[test]
    fn test_split_from_secret_key() {
        // Secret key bytes are big-endian, so this key encodes the scalar 7.
        let mut bytes = [0u8; 32];
        bytes[31] = 7;
        let sk = bitcoin::secp256k1::SecretKey::from_slice(&bytes).unwrap();

        let vss = VSS::new(2, 3).unwrap();
        let shares = vss.split_from_secret_key(&sk).unwrap();

        for share in &shares {
            assert!(vss.verify_share(share).unwrap());
        }
        assert_eq!(vss.reconstruct(&shares[1..3]).unwrap(), Scalar::from(7u32));
    }
}