spark_sdk/wallet/leaf_manager/
mod.rs

1//! Leaf manager for the Spark wallet.
2
3// std
4use std::sync::Arc;
5
6//crates
7use crate::{
8    constants::spark::{BITCOIN_TOKEN_PUBLIC_KEY, DUST_AMOUNT},
9    error::SparkSdkError,
10    SparkNetwork,
11};
12use hashbrown::HashMap;
13use serde::{Deserialize, Serialize};
14use spark_protos::spark::TreeNode;
15use uuid::Uuid;
16
17use super::internal_handlers::traits::leaves::LeafSelectionResponse;
18
/// Thread-safe, shared map from leaf id to its in-memory [`LeafNode`] state.
pub(crate) type LeafMap = Arc<parking_lot::RwLock<HashMap<String, LeafNode>>>;

/// In-memory store of the wallet's leaf nodes.
///
/// All reads and writes go through the interior `RwLock`, so a shared
/// reference is enough to query or mutate leaves.
pub(crate) struct LeafManager {
    /// The map of leaf nodes, keyed by leaf id
    leaves: LeafMap,
}
25
/// A single leaf (tree node) tracked by the wallet.
///
/// Transaction and key fields hold raw serialized bytes. Token-specific
/// fields are left empty for bitcoin leaves (see `From<TreeNode>` below).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LeafNode {
    /// The id of the leaf node. This is used to derive the child index of the leaf node as well.
    pub(crate) id: String,

    /// The tree id of the leaf node
    pub(crate) tree_id: String,

    /// The value of the leaf node
    pub(crate) value: u64,

    /// The parent node id of the leaf node, if any
    pub(crate) parent_node_id: Option<String>,

    /// The serialized transaction of the node
    pub(crate) node_transaction: Vec<u8>,

    /// The serialized transaction of the refund
    pub(crate) refund_transaction: Vec<u8>,

    /// The vout of the leaf node
    pub(crate) vout: u32,

    /// The verifying public key of the leaf node
    pub(crate) verifying_public_key: Vec<u8>,

    /// The signing public key
    pub(crate) signing_public_key: Vec<u8>,

    /// The token public key (empty for bitcoin leaves)
    pub(crate) token_public_key: Vec<u8>,

    /// Revocation public key (for tokens only)
    pub(crate) revocation_public_key: Vec<u8>,

    /// Token transaction hash
    pub(crate) token_transaction_hash: Vec<u8>,

    /// The status of the leaf node. Kept private (even within the crate) so
    /// state transitions only happen through `LeafManager` methods.
    status: LeafNodeStatusInternal,
}
67
/// Read-only, public-facing wrapper around a [`LeafNode`].
///
/// External code receives clones wrapped in this newtype; the internal status
/// (which carries a request id) is only exposed as the reduced
/// [`LeafNodeStatus`] via `PublicLeafNode::status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PublicLeafNode(LeafNode);

// A trait for converting to public-facing types
pub trait ToPublic {
    /// The public-facing counterpart of the implementing type.
    type Public;
    /// Consumes `self` and returns its public representation.
    fn to_public(self) -> Self::Public;
}
76
impl ToPublic for LeafNode {
    type Public = PublicLeafNode;

    // Wrapping is free: `PublicLeafNode` is a newtype around `LeafNode`.
    fn to_public(self) -> PublicLeafNode {
        PublicLeafNode(self)
    }
}

// Implement Deref/DerefMut if you want field access to work seamlessly.
// Only `Deref` is provided, so callers get read access to the inner fields
// but cannot mutate the wrapped `LeafNode` through this wrapper.
impl std::ops::Deref for PublicLeafNode {
    type Target = LeafNode;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
93
94// Then implement methods specifically for PublicLeafNode where status returns LeafNodeStatus
95impl PublicLeafNode {
96    pub fn status(&self) -> LeafNodeStatus {
97        self.0.status.clone().into()
98    }
99}
100
101impl From<TreeNode> for PublicLeafNode {
102    fn from(tree_node: TreeNode) -> Self {
103        PublicLeafNode(LeafNode {
104            id: tree_node.id,
105            tree_id: tree_node.tree_id,
106            value: tree_node.value,
107            parent_node_id: tree_node.parent_node_id,
108            status: LeafNodeStatusInternal::Available,
109            node_transaction: tree_node.node_tx,
110            refund_transaction: tree_node.refund_tx,
111            vout: tree_node.vout,
112            verifying_public_key: tree_node.verifying_public_key,
113            signing_public_key: tree_node.owner_identity_public_key,
114            token_public_key: vec![],
115            revocation_public_key: vec![],
116            token_transaction_hash: vec![],
117        })
118    }
119}
120
121impl LeafNode {
122    // Provide a public helper that checks whether the node is 'Available'
123    pub(crate) fn is_available(&self) -> bool {
124        // LeafNodeStatusInternal is private, but we can still match on it here
125        matches!(self.status, LeafNodeStatusInternal::Available)
126    }
127
128    pub(crate) fn marshal_to_tree_node(
129        &self,
130        identity_public_key: impl Into<Vec<u8>>,
131        network: &SparkNetwork,
132    ) -> TreeNode {
133        TreeNode {
134            id: self.id.clone(),
135            tree_id: self.tree_id.clone(),
136            value: self.value,
137            parent_node_id: self.parent_node_id.clone(),
138            node_tx: self.node_transaction.clone(),
139            refund_tx: self.refund_transaction.clone(),
140            vout: self.vout,
141            verifying_public_key: self.verifying_public_key.clone(),
142            owner_identity_public_key: identity_public_key.into(),
143            signing_keyshare: None,
144            status: "".to_string(),
145            network: marshal_spark_network(network),
146        }
147    }
148}
149
150impl LeafNode {
151    pub(crate) fn new(
152        id: String,
153        tree_id: String,
154        value: u64,
155        parent_node_id: Option<String>,
156        vout: u32,
157        verifying_public_key: Vec<u8>,
158        signing_public_key: Vec<u8>,
159        node_transaction: Vec<u8>,
160        refund_transaction: Vec<u8>,
161        token_public_key: Option<Vec<u8>>,
162        revocation_public_key: Vec<u8>,
163        token_transaction_hash: Vec<u8>,
164    ) -> Self {
165        Self {
166            id,
167            tree_id,
168            value,
169            parent_node_id,
170            vout,
171            verifying_public_key,
172            signing_public_key,
173            node_transaction,
174            refund_transaction,
175            token_public_key: token_public_key.unwrap_or(vec![]),
176            revocation_public_key,
177            token_transaction_hash,
178            status: LeafNodeStatusInternal::Available,
179        }
180    }
181}
182
/// Leaf status, as exposed publicly (without the internal request id).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum LeafNodeStatus {
    /// Not locked by any operation; usable.
    Available,

    /// Being transferred - not usable for any other operation.
    InTransfer,

    /// Being split - not usable for any other operation.
    InSplit,

    /// In swap - not usable for any other operation.
    InSwap,

    /// Aggregatable parent
    AggregatableParent,
}

/// Id tying a leaf lock to the operation that created it (a UUID string).
type RequestId = String;

/// Internal leaf node status. Locking variants carry the [`RequestId`] of the
/// operation holding the lock, which is stripped when converting to the
/// public [`LeafNodeStatus`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
enum LeafNodeStatusInternal {
    /// Available
    Available,

    /// Being transferred - not usable for any other operation. The string is the transfer id.
    InTransfer(RequestId),

    /// Being split - not usable for any other operation. The string is the split id.
    InSplit(RequestId),

    /// In swap - not usable for any other operation. The string is the swap id.
    InSwap(RequestId),

    /// Aggregatable parent
    AggregatableParent,
}
222
223impl From<LeafNodeStatusInternal> for LeafNodeStatus {
224    fn from(status: LeafNodeStatusInternal) -> Self {
225        match status {
226            LeafNodeStatusInternal::Available => LeafNodeStatus::Available,
227            LeafNodeStatusInternal::InTransfer(_) => LeafNodeStatus::InTransfer,
228            LeafNodeStatusInternal::InSplit(_) => LeafNodeStatus::InSplit,
229            LeafNodeStatusInternal::InSwap(_) => LeafNodeStatus::InSwap,
230            LeafNodeStatusInternal::AggregatableParent => LeafNodeStatus::AggregatableParent,
231        }
232    }
233}
234
235impl From<LeafNodeStatus> for LeafNodeStatusInternal {
236    fn from(status: LeafNodeStatus) -> Self {
237        let request_id = uuid::Uuid::now_v7().to_string();
238        match status {
239            LeafNodeStatus::Available => LeafNodeStatusInternal::Available,
240            LeafNodeStatus::InTransfer => LeafNodeStatusInternal::InTransfer(request_id),
241            LeafNodeStatus::InSplit => LeafNodeStatusInternal::InSplit(request_id),
242            LeafNodeStatus::InSwap => LeafNodeStatusInternal::InSwap(request_id),
243            LeafNodeStatus::AggregatableParent => LeafNodeStatusInternal::AggregatableParent,
244        }
245    }
246}
247
248impl LeafNodeStatusInternal {
249    fn get_request_id(&self) -> Option<String> {
250        match self {
251            LeafNodeStatusInternal::Available => None,
252            LeafNodeStatusInternal::InTransfer(request_id) => Some(request_id.clone()),
253            LeafNodeStatusInternal::InSplit(request_id) => Some(request_id.clone()),
254            LeafNodeStatusInternal::InSwap(request_id) => Some(request_id.clone()),
255            LeafNodeStatusInternal::AggregatableParent => None,
256        }
257    }
258}
259
260impl LeafManager {
261    pub(crate) fn new() -> Self {
262        Self {
263            leaves: Arc::new(parking_lot::RwLock::new(HashMap::new())),
264        }
265    }
266
267    pub(crate) fn refresh_leaves(&self, leaves: Vec<PublicLeafNode>) -> Result<(), SparkSdkError> {
268        let mut guard = self.leaves.write();
269        for leaf in leaves {
270            guard.insert(leaf.id.clone(), leaf.0);
271        }
272        drop(guard);
273        Ok(())
274    }
275
276    pub(crate) fn get_leaf_count(
277        &self,
278        filter_cb: Option<Box<dyn Fn(&PublicLeafNode) -> bool>>,
279    ) -> Result<u32, SparkSdkError> {
280        // If there's no filter, just return how many items are in the database
281        if filter_cb.is_none() {
282            return Ok(self.leaves.read().len() as u32);
283        }
284
285        let filter_cb = filter_cb.unwrap();
286        let mut count = 0;
287
288        // Iterate over each entry in the leaves map
289        for node in self.leaves.read().values() {
290            // If the filter callback returns true, increment the count
291            if filter_cb(&node.clone().to_public()) {
292                count += 1;
293            }
294        }
295
296        Ok(count)
297    }
298
299    pub(crate) fn query_single_node(
300        &self,
301        cb: Option<Box<dyn Fn(&LeafNode) -> bool>>,
302        new_status: Option<LeafNodeStatus>,
303    ) -> Result<PublicLeafNode, SparkSdkError> {
304        match new_status {
305            Some(status) => {
306                let mut guard = self.leaves.write();
307                for node in guard.values_mut() {
308                    if cb.as_ref().map_or(false, |f| f(node)) {
309                        node.status = status.clone().into();
310                        return Ok(node.clone().to_public());
311                    }
312                }
313            }
314            None => {
315                let guard = self.leaves.read();
316                for node in guard.values() {
317                    if cb.as_ref().map_or(false, |f| f(node)) {
318                        return Ok(node.clone().to_public());
319                    }
320                }
321            }
322        }
323
324        Err(SparkSdkError::InvalidInput("No node found".to_string()))
325    }
326
327    pub(crate) fn insert_leaves_in_batch(
328        &self,
329        new_leaves: Vec<LeafNode>,
330    ) -> Result<(), SparkSdkError> {
331        let mut guard = self.leaves.write();
332        for leaf in new_leaves {
333            // when inserting in batch, the key shouldn't be already existing in the map.
334            if guard.contains_key(&leaf.id) {
335                return Err(SparkSdkError::InvalidInput(format!(
336                    "Leaf {} already exists",
337                    leaf.id
338                )));
339            }
340            guard.insert(leaf.id.clone(), leaf);
341        }
342        Ok(())
343    }
344
345    pub(crate) fn lock_leaf_ids_for_operation(
346        &self,
347        leaf_ids: &[String],
348        new_status: LeafNodeStatus,
349    ) -> Result<(Vec<PublicLeafNode>, String), SparkSdkError> {
350        let internal_status = LeafNodeStatusInternal::from(new_status);
351        let unlock_id = internal_status.get_request_id().unwrap();
352
353        let mut leaves = Vec::new();
354        let mut guard = self.leaves.write();
355        for leaf_id in leaf_ids {
356            let get_leaf = guard.get(leaf_id);
357            if get_leaf.is_none() {
358                drop(guard);
359                return Err(SparkSdkError::InvalidInput(format!(
360                    "Leaf {} not found",
361                    leaf_id
362                )));
363            }
364            let leaf = get_leaf.unwrap();
365            if leaf.status != LeafNodeStatusInternal::Available {
366                drop(guard);
367                return Err(SparkSdkError::InvalidInput(format!(
368                    "Leaf {} is not available",
369                    leaf_id
370                )));
371            }
372            leaves.push(leaf.clone().to_public());
373        }
374
375        for leaf_id in leaf_ids {
376            let leaf = guard.get_mut(leaf_id).unwrap();
377            leaf.status = internal_status.clone();
378        }
379
380        Ok((leaves, unlock_id))
381    }
382
383    pub(crate) fn get_btc_value(
384        &self,
385        filter_cb: Option<Box<dyn Fn(&PublicLeafNode) -> bool>>,
386    ) -> Result<u64, SparkSdkError> {
387        // set the default filter for mapping btc leaves
388        let mut default_cb: Box<dyn Fn(&PublicLeafNode) -> bool> = Box::new(|_| true);
389
390        // if a custom filter is provided, combine it with the default filter
391        if let Some(additional_filter) = filter_cb {
392            let combined_filter =
393                move |node: &PublicLeafNode| default_cb(node) && additional_filter(node);
394            default_cb = Box::new(combined_filter);
395        }
396
397        let value_sum: u64 = self
398            .leaves
399            .read()
400            .values()
401            .filter_map(|node| {
402                let pub_node = node.clone().to_public();
403                if default_cb(&pub_node) {
404                    Some(pub_node.value)
405                } else {
406                    None
407                }
408            })
409            .sum();
410
411        Ok(value_sum)
412    }
413
414    pub(crate) fn get_node(&self, leaf_id: &String) -> Result<PublicLeafNode, SparkSdkError> {
415        let guard = self.leaves.read();
416        let node = guard
417            .get(leaf_id)
418            .ok_or_else(|| SparkSdkError::InvalidInput(format!("Leaf {} not found", leaf_id)))?
419            .clone();
420
421        Ok(node.to_public())
422    }
423
424    pub(crate) fn get_nodes(
425        &self,
426        leaf_ids: &[String],
427    ) -> Result<Vec<PublicLeafNode>, SparkSdkError> {
428        let nodes = leaf_ids
429            .iter()
430            .map(|id| self.get_node(id))
431            .collect::<Result<Vec<PublicLeafNode>, SparkSdkError>>()?;
432        Ok(nodes)
433    }
434
435    pub(crate) fn update_leaves_with_status(
436        &self,
437        leaf_ids: &[String],
438        new_status: LeafNodeStatus,
439    ) -> Result<(), SparkSdkError> {
440        let mut leaves = self.leaves.write();
441
442        // if all the leaves aren't available, return an errpr
443        for leaf_id in leaf_ids {
444            let leaf = leaves.get(leaf_id).ok_or_else(|| {
445                SparkSdkError::InvalidInput(format!("Leaf {} not found", leaf_id))
446            })?;
447            if leaf.status != LeafNodeStatusInternal::Available {
448                drop(leaves);
449                return Err(SparkSdkError::InvalidInput(format!(
450                    "Leaf {} is not available",
451                    leaf_id
452                )));
453            }
454        }
455
456        for leaf_id in leaf_ids {
457            let leaf = leaves.get_mut(leaf_id).unwrap();
458            leaf.status = new_status.clone().into();
459        }
460
461        drop(leaves);
462
463        Ok(())
464    }
465
466    pub(crate) fn select_leaves(
467        &self,
468        target_value: u64,
469        token_pubkey: Option<&Vec<u8>>,
470        new_status: LeafNodeStatus,
471    ) -> Result<LeafSelectionResponse, SparkSdkError> {
472        // get a write lock
473        let mut guard = self.leaves.write();
474
475        // extract the token pubkey
476        let empty_vec: Vec<u8> = vec![];
477        let token_pubkey = token_pubkey.unwrap_or(&empty_vec);
478
479        // define the filter condition: leaf node, available, and token pubkey matches
480        let filter_condition = |node: &LeafNode| {
481            node.token_public_key == *token_pubkey
482                && node.status == LeafNodeStatusInternal::Available
483        };
484
485        // map all the leaves
486        let leaves: HashMap<String, LeafNode> = guard
487            .iter()
488            .filter(|(_, node)| filter_condition(node))
489            .map(|(leaf_id, node)| (leaf_id.clone(), node.clone()))
490            .collect();
491
492        // check if the leaves add up to the target value
493        let total_value: u64 = leaves.values().map(|leaf| leaf.value).sum();
494        if total_value < target_value {
495            drop(guard);
496            return Err(SparkSdkError::InvalidInput(format!(
497                "Total value of leaves is less than target value: Total value is: {}. Target value is: {}",
498                total_value, target_value
499            )));
500        }
501
502        // get the candidates
503        let mut chosen_leaves: Vec<(String, PublicLeafNode)> = leaves
504            .iter()
505            .filter_map(|(leaf_id, node)| {
506                if node.value >= target_value && node.status == LeafNodeStatusInternal::Available {
507                    Some((leaf_id.clone(), node.clone().to_public()))
508                } else {
509                    None
510                }
511            })
512            .collect();
513
514        // sort the candidates by value
515        chosen_leaves.sort_by(|(_, a), (_, b)| a.value.cmp(&b.value));
516
517        // Update status for all candidates atomically
518        let unlocking_id = Uuid::now_v7().to_string();
519        let mut leaves = vec![];
520        for (leaf_id, _) in &chosen_leaves {
521            let node = guard.get_mut(leaf_id).unwrap();
522            node.status = LeafNodeStatusInternal::InTransfer(unlocking_id.clone());
523
524            leaves.push(node.clone().to_public());
525        }
526
527        Ok(LeafSelectionResponse {
528            leaves,
529            total_value,
530            unlocking_id: Some(unlocking_id),
531        })
532    }
533
534    pub(crate) fn update_node_status_for_split(
535        &self,
536        node_id: &str,
537        identity_public_key: &[u8],
538        network: &SparkNetwork,
539    ) -> Result<(Vec<u8>, Vec<u8>, TreeNode), SparkSdkError> {
540        let mut leaves = self.leaves.write();
541
542        // get and verify node exists
543        let node = leaves
544            .get_mut(node_id)
545            .ok_or_else(|| SparkSdkError::InvalidInput(format!("node {} not found", node_id)))?;
546
547        // check if node is available
548        if node.status != LeafNodeStatusInternal::Available {
549            return Err(SparkSdkError::InvalidInput(format!(
550                "node {} is not available",
551                node_id
552            )));
553        }
554
555        // Update status
556        node.status = LeafNodeStatus::InSplit.into();
557
558        // Store values we need to return
559        let return_values = (
560            node.refund_transaction.clone(),
561            node.signing_public_key.clone(),
562            node.marshal_to_tree_node(identity_public_key, network),
563        );
564
565        Ok(return_values)
566    }
567
568    pub(crate) async fn delete_leaves_after_transfer(
569        &self,
570        transfer_request_id: &str,
571        leaf_ids: &[String],
572    ) -> Result<(), SparkSdkError> {
573        let mut leaves = self.leaves.write();
574
575        // make sure that all leaves are in transfer lock with the same request id
576        for leaf_id in leaf_ids {
577            let leaf = leaves.get(leaf_id).ok_or_else(|| {
578                SparkSdkError::InvalidInput(format!("Leaf {} not found", leaf_id))
579            })?;
580            if leaf.status != LeafNodeStatusInternal::InTransfer(transfer_request_id.to_string()) {
581                return Err(SparkSdkError::InvalidInput(format!(
582                    "Leaf {} is not in transfer lock with request id {}",
583                    leaf_id, transfer_request_id
584                )));
585            }
586        }
587
588        for leaf_id in leaf_ids {
589            leaves.remove(leaf_id);
590        }
591
592        Ok(())
593    }
594
595    pub(crate) fn select_leaf_to_split(
596        &self,
597        target_value: u64,
598    ) -> Result<PublicLeafNode, SparkSdkError> {
599        let guard = self.leaves.write();
600
601        // Find first available leaf with value greater than target
602        let leaf = guard
603            .values()
604            .find(|leaf| {
605                leaf.status == LeafNodeStatusInternal::Available && leaf.value >= target_value
606            })
607            .ok_or_else(|| SparkSdkError::InvalidInput("No suitable leaf found".to_string()))?;
608
609        Ok(leaf.clone().to_public())
610    }
611}
612
613fn marshal_spark_network(network: &SparkNetwork) -> i32 {
614    match network {
615        SparkNetwork::Mainnet => spark_protos::spark::Network::Mainnet as i32,
616        SparkNetwork::Regtest => spark_protos::spark::Network::Regtest as i32,
617    }
618}