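//! spark_cryptography/secret_sharing — Shamir secret sharing (with per-coefficient public commitments) over the secp256k1 scalar field.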
secret_sharing.rs1use k256::{
3 elliptic_curve::{generic_array::GenericArray, PrimeField},
4 AffinePoint, ProjectivePoint, PublicKey, Scalar,
5};
6use rand::{rngs::OsRng, RngCore};
7use std::error::Error;
8
9fn scalar_to_pubkey(secret: &k256::Scalar) -> PublicKey {
10 let point = ProjectivePoint::GENERATOR * *secret;
11 PublicKey::from_affine(AffinePoint::from(point)).expect("invalid public key")
12}
13
/// A polynomial over the secp256k1 scalar field used to split a secret.
///
/// `coefficients[0]` is the secret and the remaining coefficients are random (see
/// `generate_polynomial_for_secret_sharing`). Deliberately not `Debug`-derived here.
/// NOTE(review): presumably this is to keep secret coefficients out of logs — confirm.
#[derive(Clone)]
pub struct Polynomial {
    /// Modulus recorded alongside the polynomial. `evaluate` does not read it: all
    /// arithmetic uses `k256::Scalar` operations.
    pub field_modulus: Scalar,

    /// Coefficients in ascending degree order (`coefficients[i]` multiplies `x^i`).
    coefficients: Vec<Scalar>,

    /// `proofs[i]` is the SEC1 encoding of `G * coefficients[i]` — a public commitment to
    /// each coefficient, checked by `validate_share`.
    pub proofs: Vec<Vec<u8>>,
}
26
/// A point `(index, share)` that can take part in Lagrange interpolation.
pub trait LagrangeInterpolatable {
    /// The x-coordinate of this point.
    fn get_index(&self) -> &Scalar;

    /// The y-coordinate (share value) of this point.
    fn get_share(&self) -> &Scalar;

    /// The modulus recorded with the share.
    fn get_field_modulus(&self) -> &Scalar;

    /// The minimum number of points `recover_secret` requires.
    fn get_threshold(&self) -> usize;
}
41
/// One share of a split secret: the sharing polynomial's value `share` at x = `index`.
///
/// NOTE(review): the derived `Debug` prints `share` verbatim, so logging a share leaks
/// secret material — confirm this is acceptable.
#[derive(Debug, Clone)]
pub struct SecretShare {
    /// Modulus recorded at split time. The arithmetic in this module is done with
    /// `k256::Scalar` operations and does not read this value.
    pub field_modulus: Scalar,

    /// Minimum number of shares `recover_secret` requires.
    pub threshold: usize,

    /// x-coordinate of the share (the split functions use 1, 2, ..., number_of_shares).
    pub index: Scalar,

    /// The polynomial evaluated at `index`.
    pub share: Scalar,
}
57
58impl LagrangeInterpolatable for SecretShare {
59 fn get_index(&self) -> &Scalar {
60 &self.index
61 }
62
63 fn get_field_modulus(&self) -> &Scalar {
64 &self.field_modulus
65 }
66
67 fn get_share(&self) -> &Scalar {
68 &self.share
69 }
70
71 fn get_threshold(&self) -> usize {
72 self.threshold
73 }
74}
75
/// A [`SecretShare`] bundled with the polynomial's coefficient commitments, so a holder can
/// check the share with `validate_share` without learning the polynomial.
#[derive(Debug, Clone)]
pub struct VerifiableSecretShare {
    /// The underlying share.
    pub secret_share: SecretShare,

    /// SEC1-encoded commitments `G * a_j`, one per polynomial coefficient `a_j`, in
    /// ascending degree order (as produced by `split_secret_with_proofs`).
    pub proofs: Vec<Vec<u8>>,
}
85
86impl VerifiableSecretShare {
87 pub fn marshal_proto(&self) -> spark_protos::spark::SecretShare {
88 spark_protos::spark::SecretShare {
89 secret_share: self.secret_share.share.to_bytes().to_vec(),
90 proofs: self.proofs.clone(),
91 }
92 }
93}
94
95impl LagrangeInterpolatable for VerifiableSecretShare {
96 fn get_index(&self) -> &Scalar {
97 &self.secret_share.index
98 }
99
100 fn get_field_modulus(&self) -> &Scalar {
101 &self.secret_share.field_modulus
102 }
103
104 fn get_share(&self) -> &Scalar {
105 &self.secret_share.share
106 }
107
108 fn get_threshold(&self) -> usize {
109 self.secret_share.threshold
110 }
111}
112
113impl Polynomial {
114 pub fn evaluate(&self, x: &Scalar) -> Scalar {
116 let mut result = Scalar::ZERO;
117 let mut x_power = Scalar::ONE;
118
119 for coeff in self.coefficients.iter() {
120 result += *coeff * x_power;
121 x_power *= x;
122 }
123 result
124 }
125}
126
127fn field_div(
129 numerator: &k256::Scalar,
130 denominator: &k256::Scalar,
131 _field_modulus: &k256::Scalar,
132) -> Result<k256::Scalar, String> {
133 if bool::from(denominator.is_zero()) {
134 return Err("division by zero".to_string());
135 }
136
137 let inverse = denominator
138 .invert()
139 .into_option()
140 .ok_or("element not invertible".to_string())?;
141 Ok(*numerator * inverse)
142}
143
144pub fn compute_lagrange_coefficients<T: LagrangeInterpolatable>(
146 index: &Scalar,
147 points: &[T],
148) -> Result<Scalar, String> {
149 let mut numerator = Scalar::ONE;
150 let mut denominator = Scalar::ONE;
151 let field_modulus = points[0].get_field_modulus();
152
153 for point in points {
154 if point.get_index() == index {
155 continue;
156 }
157 numerator = numerator * point.get_index();
158 let value = point.get_index() - index;
159 denominator = denominator * value;
160 }
161
162 field_div(&numerator, &denominator, &field_modulus)
163}
164
165fn bytes_to_scalar(bytes: &[u8]) -> Result<Scalar, String> {
166 if bytes.len() != 32 {
168 return Err(format!(
169 "Invalid byte length for scalar. Expected 32, got {}",
170 bytes.len()
171 ));
172 }
173
174 let arr = GenericArray::clone_from_slice(bytes);
176
177 let scalar_opt = Scalar::from_repr(arr);
179
180 scalar_opt
182 .into_option()
183 .ok_or_else(|| "Failed to parse Scalar (out of range)".to_string())
184}
185
186fn generate_polynomial_for_secret_sharing(
188 field_modulus: &Scalar,
189 secret: &Scalar,
190 threshold: usize,
191) -> Result<Polynomial, Box<dyn Error>> {
192 let mut coefficients = Vec::with_capacity(threshold + 1);
193 let mut proofs = Vec::with_capacity(threshold + 1);
194
195 coefficients.push(*secret);
197
198 proofs.push(scalar_to_pubkey(secret).to_sec1_bytes().to_vec());
200
201 for _ in 1..=threshold {
203 let mut random_bytes = [0u8; 32];
204 OsRng.fill_bytes(&mut random_bytes);
205
206 let random_scalar = bytes_to_scalar(&random_bytes)?;
208 coefficients.push(random_scalar);
209 proofs.push(scalar_to_pubkey(&random_scalar).to_sec1_bytes().to_vec());
210 }
211
212 Ok(Polynomial {
213 field_modulus: *field_modulus,
214 coefficients,
215 proofs,
216 })
217}
218
219pub fn split_secret(
221 secret: &Scalar,
222 field_modulus: &Scalar,
223 threshold: usize,
224 number_of_shares: usize,
225) -> Result<Vec<SecretShare>, Box<dyn Error>> {
226 let polynomial = generate_polynomial_for_secret_sharing(field_modulus, secret, threshold - 1)?;
227
228 let mut shares = Vec::with_capacity(number_of_shares);
229 for i in 1..=number_of_shares {
230 let index = Scalar::from(i as u64);
231 let share = polynomial.evaluate(&index);
232
233 shares.push(SecretShare {
234 field_modulus: *field_modulus,
235 threshold,
236 index,
237 share,
238 });
239 }
240
241 Ok(shares)
242}
243
244fn scalar_modpow(base: &Scalar, exp: usize, _modulus: &Scalar) -> Scalar {
246 if exp == 0 {
247 return Scalar::ONE;
248 }
249
250 let mut result = Scalar::ONE;
251 let mut base = *base;
252 let mut exp = exp;
253
254 while exp > 0 {
255 if exp & 1 == 1 {
256 result *= base;
257 }
258 base *= base;
259 exp >>= 1;
260 }
261 result
262}
263
264pub fn split_secret_with_proofs(
266 secret: &[u8],
267 field_modulus: &[u8],
268 threshold: usize,
269 number_of_shares: usize,
270) -> Result<Vec<VerifiableSecretShare>, Box<dyn Error>> {
271 let secret_scalar = bytes_to_scalar(secret)?;
273 let field_modulus_scalar = bytes_to_scalar(field_modulus)?;
274
275 if threshold == 0 || threshold > number_of_shares {
277 return Err("Invalid threshold".into());
278 }
279
280 let polynomial = generate_polynomial_for_secret_sharing(
281 &field_modulus_scalar,
282 &secret_scalar,
283 threshold - 1,
284 )?;
285
286 let mut shares = Vec::with_capacity(number_of_shares);
287 for i in 1..=number_of_shares {
288 let index = Scalar::from(i as u64);
289 let share = polynomial.evaluate(&index);
290
291 shares.push(VerifiableSecretShare {
292 secret_share: SecretShare {
293 field_modulus: field_modulus_scalar,
294 threshold,
295 index,
296 share,
297 },
298 proofs: polynomial.proofs.clone(),
299 });
300 }
301
302 Ok(shares)
303}
304
305pub fn recover_secret<T: LagrangeInterpolatable>(shares: &[T]) -> Result<Scalar, String> {
307 if shares.len() < shares[0].get_threshold() {
308 return Err("not enough shares to recover secret".to_string());
309 }
310
311 let mut result = Scalar::ZERO;
312
313 for share in shares {
314 let coeff = compute_lagrange_coefficients(share.get_index(), shares)?;
315 result += share.get_share() * &coeff;
316 }
317
318 Ok(result)
319}
320
321pub fn validate_share(share: &VerifiableSecretShare) -> Result<(), String> {
323 let target_pubkey = scalar_to_pubkey(&share.secret_share.share);
324 let mut result = ProjectivePoint::IDENTITY;
325
326 if let Some(base_proof) = share.proofs.first() {
328 let base_pubkey = PublicKey::from_sec1_bytes(base_proof).map_err(|e| e.to_string())?;
329 result += ProjectivePoint::from(base_pubkey.as_affine());
330 }
331
332 for (i, proof) in share.proofs.iter().enumerate().skip(1) {
334 let pubkey = PublicKey::from_sec1_bytes(proof).map_err(|e| e.to_string())?;
335 let exp = scalar_modpow(
336 &share.secret_share.index,
337 i,
338 &share.secret_share.field_modulus,
339 );
340 result += ProjectivePoint::from(pubkey.as_affine()) * exp;
341 }
342
343 if AffinePoint::from(result) == *target_pubkey.as_affine() {
344 Ok(())
345 } else {
346 Err("Share validation failed".to_string())
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): `CURVE_ORDER` is passed to `hex::decode` below, so it is presumably a
    // hex string; it appears to duplicate `CURVE_ORDER_HEX` — confirm and consider
    // using a single constant.
    use crate::secp256k1::CURVE_ORDER;
    use k256::elliptic_curve::bigint::{Encoding as _, U256};

    // Big-endian hex of the curve order; `bytes_to_scalar` is expected to reject it
    // (see `test_secret_sharing_out_of_range_share`).
    const CURVE_ORDER_HEX: &str =
        "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141";

    // 32-byte encoding of a scalar.
    fn scalar_to_32_bytes(scalar: &Scalar) -> [u8; 32] {
        scalar.to_bytes().into()
    }

    // Returns (curve order - 1) as 32 big-endian bytes: the largest value that
    // `bytes_to_scalar` accepts, used as the recorded "field modulus".
    fn get_test_field_modulus_in_range() -> Vec<u8> {
        let raw = hex::decode(CURVE_ORDER).unwrap();
        let u = U256::from_be_slice(&raw);

        let minus_one = u.saturating_sub(&U256::ONE);
        minus_one.to_be_bytes().to_vec()
    }

    // Split into 5 shares with threshold 3, then recover from the first 3.
    #[test]
    fn test_secret_sharing_basic() -> Result<(), Box<dyn Error>> {
        let mut secret_bytes = [0u8; 32];
        secret_bytes[..5].copy_from_slice(&[1, 2, 3, 4, 5]);
        let secret = bytes_to_scalar(&secret_bytes)?;

        let field_modulus_bytes = get_test_field_modulus_in_range();
        let field_modulus = bytes_to_scalar(&field_modulus_bytes)?;

        let shares = split_secret(&secret, &field_modulus, 3, 5)?;
        let recovered = recover_secret(&shares[0..3])?;

        assert_eq!(secret, recovered);
        Ok(())
    }

    // Every verifiable share must validate against its proofs, and 3 of 5 recover.
    #[test]
    fn test_verifiable_secret_sharing() -> Result<(), Box<dyn Error>> {
        let mut secret_arr = [0u8; 32];
        secret_arr[..5].copy_from_slice(&[1, 2, 3, 4, 5]);

        let field_modulus = get_test_field_modulus_in_range();

        let shares = split_secret_with_proofs(&secret_arr, &field_modulus, 3, 5)?;

        for share in &shares {
            validate_share(share)?;
        }

        let recovered = recover_secret(&shares[0..3])?;
        assert_eq!(bytes_to_scalar(&secret_arr)?, recovered);

        Ok(())
    }

    // Round-trips each share value through its 32-byte encoding, rebuilds the shares,
    // and checks that they still validate and recover the same secret.
    #[test]
    fn test_share_bytes_compatibility() -> Result<(), Box<dyn Error>> {
        let field_modulus = get_test_field_modulus_in_range();
        let secret_bytes = vec![0x11; 32];
        let shares = split_secret_with_proofs(&secret_bytes, &field_modulus, 3, 5)?;

        let original_share_bytes: Vec<Vec<u8>> = shares
            .iter()
            .map(|s| scalar_to_32_bytes(&s.secret_share.share).to_vec())
            .collect();

        for share in &shares {
            validate_share(share)?;
        }

        // Rebuild shares from the serialized bytes; indices are re-derived as i + 1,
        // matching the 1-based indices the split function assigns.
        let new_shares: Vec<VerifiableSecretShare> = shares
            .iter()
            .enumerate()
            .map(|(i, original_share)| {
                let share_scalar = bytes_to_scalar(&original_share_bytes[i]).unwrap();
                VerifiableSecretShare {
                    secret_share: SecretShare {
                        field_modulus: original_share.secret_share.field_modulus,
                        threshold: original_share.secret_share.threshold,
                        index: Scalar::from((i + 1) as u64),
                        share: share_scalar,
                    },
                    proofs: original_share.proofs.clone(),
                }
            })
            .collect();

        for share in &new_shares {
            validate_share(share)?;
        }

        for (orig, new) in shares.iter().zip(new_shares.iter()) {
            assert_eq!(
                scalar_to_32_bytes(&orig.secret_share.share),
                scalar_to_32_bytes(&new.secret_share.share),
                "Share bytes don't match after reconstruction"
            );
        }

        let recovered_orig = recover_secret(&shares[0..3])?;
        let recovered_new = recover_secret(&new_shares[0..3])?;
        assert_eq!(recovered_orig, recovered_new);

        Ok(())
    }

    // Encoding of the curve order must be rejected as out of range.
    #[test]
    fn test_secret_sharing_out_of_range_share() {
        let out_of_range_bytes = hex::decode(CURVE_ORDER_HEX).expect("invalid hex for n");

        let result = bytes_to_scalar(&out_of_range_bytes);
        assert!(
            result.is_err(),
            "Expected out-of-range error, but got an Ok(Scalar)?"
        );

        println!("Got expected error: {:?}", result.err().unwrap());
    }
}