spark_sdk/wallet/internal_handlers/implementations/
create_tree.rs

1use crate::common_types::types::Address;
2use crate::common_types::types::Encodable;
3use crate::common_types::types::Secp256k1;
4use crate::common_types::types::Transaction;
5use crate::common_types::types::TxIn;
6use crate::common_types::types::TxOut;
7use crate::common_types::types::Txid;
8use crate::error::SparkSdkError;
9use crate::signer::traits::secp256k1::KeygenMethod;
10use crate::signer::traits::SparkSigner;
11use crate::wallet::client::SparkSdk;
12use crate::wallet::internal_handlers::traits::create_tree::BuildCreationNodesFromTreeSdkResponse;
13use crate::wallet::internal_handlers::traits::create_tree::CreateAddressRequestNodeFromTreeNodesSdkResponse;
14use crate::wallet::internal_handlers::traits::create_tree::CreateDepositAddressBinaryTreeSdkResponse;
15use crate::wallet::internal_handlers::traits::create_tree::CreateTreeInternalHandlers;
16use crate::wallet::internal_handlers::traits::create_tree::DepositAddressTree;
17use crate::wallet::internal_handlers::traits::create_tree::FinalizeTreeCreationSdkResponse;
18use crate::wallet::internal_handlers::traits::create_tree::GenerateDepositAddressForTreeSdkResponse;
19use crate::wallet::internal_handlers::utils::frost_commitment_to_proto_commitment;
20use crate::wallet::leaf_manager::LeafNode;
21use parking_lot::RwLock;
22use spark_protos::spark::prepare_tree_address_request::Source as SourceProto;
23use spark_protos::spark::AddressNode as AddressNodeProto;
24use spark_protos::spark::AddressRequestNode as AddressRequestNodeProto;
25use spark_protos::spark::CreateTreeRequest;
26use spark_protos::spark::FinalizeNodeSignaturesRequest;
27use spark_protos::spark::NodeOutput as NodeOutputProto;
28use spark_protos::spark::PrepareTreeAddressRequest;
29use spark_protos::spark::TreeNode as TreeNodeProto;
30use spark_protos::spark::Utxo as UtxoProto;
31use std::collections::VecDeque;
32use std::str::FromStr;
33use std::sync::Arc;
34use tonic::async_trait;
35
36#[async_trait]
37impl<S: SparkSigner + Send + Sync + Clone + 'static> CreateTreeInternalHandlers<S> for SparkSdk<S> {
38    /// Creates a binary tree of deposit addresses.
39    ///
40    /// The tree is created by recursively splitting the target signing private key into two halves.
41    /// Each node in the tree represents a deposit address, and the children of each node are the
42    /// next level of the tree.
43    ///
44    /// # Arguments
45    ///
46    /// * `split_level` - The level of the tree to create.
47    /// * `target_pubkey` - The public key to split into the tree.
48    fn create_deposit_address_binary_tree(
49        &self,
50        split_level: u32,
51        target_pubkey: &Vec<u8>,
52    ) -> Result<CreateDepositAddressBinaryTreeSdkResponse, SparkSdkError> {
53        if split_level == 0 {
54            return Ok(CreateDepositAddressBinaryTreeSdkResponse { tree: vec![] });
55        }
56
57        // generate left pubkey
58        let left_pubkey = self.signer.new_secp256k1_keypair(KeygenMethod::Random)?;
59
60        // left node
61        let mut left_node = DepositAddressTree {
62            address: None,
63            verification_key: None,
64            signing_public_key: left_pubkey.clone(),
65            children: vec![],
66        };
67
68        // create left children recursively
69        let left_children =
70            self.create_deposit_address_binary_tree(split_level - 1, &left_pubkey)?;
71        left_node.children = left_children.tree;
72
73        // calculate right pubkey
74        let right_pubkey =
75            self.signer
76                .subtract_secret_keys_given_pubkeys(target_pubkey, &left_pubkey, true)?;
77
78        // right node
79        let mut right_node = DepositAddressTree {
80            address: None,
81            verification_key: None,
82            signing_public_key: right_pubkey.clone(),
83            children: vec![],
84        };
85
86        // create right children recursively
87        let right_children =
88            self.create_deposit_address_binary_tree(split_level - 1, &right_pubkey)?;
89        right_node.children = right_children.tree;
90
91        Ok(CreateDepositAddressBinaryTreeSdkResponse {
92            tree: vec![
93                Arc::new(RwLock::new(left_node)),
94                Arc::new(RwLock::new(right_node)),
95            ],
96        })
97    }
98
99    fn create_address_request_node_from_tree_nodes(
100        &self,
101        tree_nodes: &Vec<Arc<RwLock<DepositAddressTree>>>,
102    ) -> Result<CreateAddressRequestNodeFromTreeNodesSdkResponse, SparkSdkError> {
103        let mut results = Vec::<AddressRequestNodeProto>::new();
104
105        for node in tree_nodes {
106            let node = node.read();
107
108            let address_request_node =
109                self.create_address_request_node_from_tree_nodes(&node.children)?;
110            let address_request_node = AddressRequestNodeProto {
111                user_public_key: node.signing_public_key.clone(),
112                children: address_request_node.address_request_nodes,
113            };
114            results.push(address_request_node);
115        }
116
117        Ok(CreateAddressRequestNodeFromTreeNodesSdkResponse {
118            address_request_nodes: results,
119        })
120    }
121
122    fn apply_address_nodes_to_tree(
123        &self,
124        tree: &mut Vec<Arc<RwLock<DepositAddressTree>>>,
125        address_nodes: Vec<AddressNodeProto>,
126    ) -> Result<(), SparkSdkError> {
127        for (i, node) in tree.iter_mut().enumerate() {
128            let mut node = node.write();
129            let node_address_data = address_nodes[i].address.clone().unwrap();
130            node.address = Some(node_address_data.address);
131            node.verification_key = Some(node_address_data.verifying_key);
132
133            if !node.children.is_empty() {
134                self.apply_address_nodes_to_tree(
135                    &mut node.children,
136                    address_nodes[i].children.clone(),
137                )?;
138            }
139        }
140
141        Ok(())
142    }
143
144    async fn generate_deposit_address_for_tree(
145        &self,
146        parent_tx: Option<Transaction>,
147        parent_node: Option<Arc<RwLock<TreeNodeProto>>>,
148        vout: u32,
149        parent_public_key: Vec<u8>,
150        split_level: u32,
151    ) -> Result<GenerateDepositAddressForTreeSdkResponse, SparkSdkError> {
152        let mut spark_client = self.config.spark_config.get_spark_connection(None).await?;
153        let network_proto = self.config.spark_config.network.marshal_proto();
154
155        // 1. Create the binary tree given the user request
156        let time_start = std::time::Instant::now();
157        let mut deposit_address_tree =
158            self.create_deposit_address_binary_tree(split_level, &parent_public_key)?;
159        let duration = time_start.elapsed();
160        println!(
161            "[create_deposit_address_binary_tree] duration: {:?}",
162            duration
163        );
164
165        // If split_level = 0, len = 1. if split_level = 1, len = 1 + 2 = 3. If split_level = 2, len = 1 + 2 + 4 = 7.
166        // This is because the tree is a binary tree, and the number of nodes is 2^split_level.
167        // assert!(tree.len() == 2u32.pow(split_level) as usize);
168
169        // 2. Create the address request nodes (in proto format) from the tree nodes
170        let time_start = std::time::Instant::now();
171        let address_nodes =
172            self.create_address_request_node_from_tree_nodes(&mut deposit_address_tree.tree)?;
173        let duration = time_start.elapsed();
174        println!(
175            "[create_address_request_node_from_tree_nodes] duration: {:?}",
176            duration
177        );
178
179        // 3. Send PrepareTreeAddressRequest to Spark. This is the first step of tree creation.
180        let request_source = match parent_node {
181            Some(parent_node) => {
182                let node_output = NodeOutputProto {
183                    node_id: parent_node.read().id.clone(),
184                    vout,
185                };
186                SourceProto::ParentNodeOutput(node_output)
187            }
188            None => {
189                let mut raw_tx = Vec::new();
190                parent_tx
191                    .as_ref()
192                    .unwrap()
193                    .consensus_encode(&mut raw_tx)
194                    .map_err(|e| SparkSdkError::InvalidArgument(e.to_string()))?;
195                let utxo = UtxoProto {
196                    vout,
197                    raw_tx,
198                    network: network_proto,
199                };
200                SourceProto::OnChainUtxo(utxo)
201            }
202        };
203        let source = Some(request_source);
204        let request_node = AddressRequestNodeProto {
205            user_public_key: parent_public_key.clone(),
206            children: address_nodes.address_request_nodes,
207        };
208        let node = Some(request_node);
209
210        let mut request = tonic::Request::new(PrepareTreeAddressRequest {
211            user_identity_public_key: self.get_identity_public_key().to_vec(),
212            node,
213            source,
214        });
215        self.add_authorization_header_to_request(&mut request, None);
216
217        let time_start = std::time::Instant::now();
218
219        let local_time = chrono::Local::now();
220        println!(
221            "Sending prepare_tree_address request at: {}",
222            local_time.format("%Y-%m-%d %H:%M:%S")
223        );
224        let spark_tree_response_ = spark_client.prepare_tree_address(request).await?;
225        let spark_tree_response = spark_tree_response_.into_inner();
226        let duration = time_start.elapsed();
227        println!("[prepare_tree_address] duration: {:?}", duration);
228
229        // 4. Create the root node
230        let response_address_node = spark_tree_response.node.unwrap();
231        let root = DepositAddressTree {
232            address: None,
233            verification_key: None,
234            signing_public_key: parent_public_key.clone(),
235            children: deposit_address_tree.tree.clone(),
236        };
237        let root = Arc::new(RwLock::new(root));
238
239        // 5. Apply the address nodes to the tree
240        let mut root_in_vec = vec![root];
241        let time_start = std::time::Instant::now();
242        println!("Applying address nodes to tree at time: {:?}", time_start);
243        self.apply_address_nodes_to_tree(&mut root_in_vec, vec![response_address_node])?;
244        let duration = time_start.elapsed();
245        println!("[apply_address_nodes_to_tree] duration: {:?}", duration);
246
247        // 6. Extract the root back from the vector and return it
248        let root = root_in_vec.remove(0);
249
250        Ok(GenerateDepositAddressForTreeSdkResponse { tree: root })
251    }
252
253    fn build_creation_nodes_from_tree(
254        &self,
255        parent_txid: Txid,
256        txout_: &TxOut,
257        vout: u32,
258        root: Arc<RwLock<DepositAddressTree>>,
259    ) -> Result<BuildCreationNodesFromTreeSdkResponse, SparkSdkError> {
260        struct TreeNode {
261            parent_txid: Txid,
262            txout: TxOut,
263            vout: u32,
264            node: Arc<RwLock<DepositAddressTree>>,
265        }
266
267        let mut creation_node = spark_protos::spark::CreationNode::default();
268        let mut queue = VecDeque::<(TreeNode, *mut spark_protos::spark::CreationNode)>::new();
269
270        queue.push_back((
271            TreeNode {
272                parent_txid,
273                txout: txout_.clone(),
274                vout,
275                node: root,
276            },
277            &mut creation_node,
278        ));
279
280        let network = self.config.spark_config.network.to_bitcoin_network();
281        let secp = Secp256k1::new();
282
283        while let Some((current, creation_ptr)) = queue.pop_front() {
284            let creation_ref = unsafe { &mut *creation_ptr };
285            let local_node = current.node.read();
286
287            // let node_signing_key = local_node.verification_key.clone();
288            // let user_signing_key = SecretKey::from_slice(&node_signing_key).unwrap();
289            let user_verifying_key_ = local_node.signing_public_key.clone();
290            let user_verifying_key =
291                bitcoin::secp256k1::PublicKey::from_slice(&user_verifying_key_).unwrap();
292
293            if !local_node.children.is_empty() {
294                let tx = {
295                    let tx_input = bitcoin::TxIn {
296                        previous_output: bitcoin::OutPoint {
297                            txid: current.parent_txid,
298                            vout: current.vout,
299                        },
300                        script_sig: bitcoin::ScriptBuf::default(),
301                        sequence: bitcoin::Sequence::ZERO,
302                        witness: bitcoin::Witness::default(),
303                    };
304
305                    let mut output = Vec::new();
306                    for child in local_node.children.iter() {
307                        let child_address = child.read().address.clone().unwrap();
308                        let child_address = Address::from_str(&child_address).unwrap();
309                        let child_address = child_address.require_network(network).unwrap();
310                        output.push(TxOut {
311                            value: bitcoin::Amount::from_sat(
312                                current.txout.value.to_sat() / local_node.children.len() as u64,
313                            ),
314                            script_pubkey: child_address.script_pubkey(),
315                        });
316                    }
317
318                    bitcoin::Transaction {
319                        version: bitcoin::transaction::Version::TWO,
320                        lock_time: bitcoin::absolute::LockTime::ZERO,
321                        input: vec![tx_input],
322                        output,
323                    }
324                };
325
326                let mut tx_buf = vec![];
327                tx.consensus_encode(&mut tx_buf)
328                    .map_err(|e| SparkSdkError::InvalidArgument(e.to_string()))?;
329
330                let commitment = self.signer.new_frost_signing_noncepair()?;
331
332                let signing_job = spark_protos::spark::SigningJob {
333                    signing_public_key: user_verifying_key.serialize().to_vec(),
334                    raw_tx: tx_buf,
335                    signing_nonce_commitment: Some(frost_commitment_to_proto_commitment(
336                        &commitment,
337                    )?),
338                };
339
340                creation_ref.node_tx_signing_job = Some(signing_job);
341                creation_ref.children = vec![Default::default(); local_node.children.len()];
342
343                let txid = tx.compute_txid();
344                for (i, child) in local_node.children.iter().enumerate() {
345                    queue.push_back((
346                        TreeNode {
347                            parent_txid: txid,
348                            txout: tx.output[i].clone(),
349                            vout: i as u32,
350                            node: child.clone(),
351                        },
352                        &mut creation_ref.children[i] as *mut _,
353                    ));
354                }
355            } else {
356                let aggregated_address = local_node.address.clone().unwrap();
357                let aggregated_address = Address::from_str(&aggregated_address).unwrap();
358                let aggregated_address = aggregated_address.require_network(network).unwrap();
359
360                let node_tx = {
361                    let input = bitcoin::TxIn {
362                        previous_output: bitcoin::OutPoint {
363                            txid: current.parent_txid,
364                            vout: current.vout,
365                        },
366                        script_sig: bitcoin::ScriptBuf::default(),
367                        // sequence: bitcoin::Sequence::ZERO,
368                        sequence: bitcoin::Sequence::ZERO,
369                        witness: bitcoin::Witness::default(),
370                    };
371
372                    bitcoin::Transaction {
373                        version: bitcoin::transaction::Version::TWO,
374                        lock_time: bitcoin::absolute::LockTime::ZERO,
375                        input: vec![input],
376                        output: vec![TxOut {
377                            value: current.txout.value,
378                            script_pubkey: aggregated_address.script_pubkey(),
379                        }],
380                    }
381                };
382
383                let mut node_tx_buf = vec![];
384                node_tx
385                    .consensus_encode(&mut node_tx_buf)
386                    .map_err(|e| SparkSdkError::InvalidArgument(e.to_string()))?;
387
388                let refund_tx = {
389                    let user_self_xonly = user_verifying_key.x_only_public_key().0;
390                    let user_self_address = Address::p2tr(&secp, user_self_xonly, None, network);
391
392                    bitcoin::Transaction {
393                        version: bitcoin::transaction::Version::TWO,
394                        lock_time: bitcoin::absolute::LockTime::ZERO,
395                        input: vec![TxIn {
396                            previous_output: bitcoin::OutPoint {
397                                txid: node_tx.compute_txid(),
398                                vout: 0,
399                            },
400                            script_sig: bitcoin::ScriptBuf::default(),
401                            // TODO: this must be the default sequence. For tree creation here, we can set it to MAX, since this is an SSP feature and is unlikely to be used here. Yet, this will affect unilateral exits.
402                            sequence: bitcoin::Sequence::MAX,
403                            witness: bitcoin::Witness::default(),
404                        }],
405                        output: vec![TxOut {
406                            value: current.txout.value,
407                            script_pubkey: user_self_address.script_pubkey(),
408                        }],
409                    }
410                };
411
412                let mut refund_tx_buf = vec![];
413                refund_tx
414                    .consensus_encode(&mut refund_tx_buf)
415                    .map_err(|e| SparkSdkError::InvalidArgument(e.to_string()))?;
416
417                let node_commitment = self.signer.new_frost_signing_noncepair()?;
418                let refund_commitment = self.signer.new_frost_signing_noncepair()?;
419
420                creation_ref.node_tx_signing_job = Some(spark_protos::spark::SigningJob {
421                    signing_public_key: user_verifying_key.serialize().to_vec(),
422                    raw_tx: node_tx_buf,
423                    signing_nonce_commitment: Some(frost_commitment_to_proto_commitment(
424                        &node_commitment,
425                    )?),
426                });
427
428                creation_ref.refund_tx_signing_job = Some(spark_protos::spark::SigningJob {
429                    signing_public_key: user_verifying_key.serialize().to_vec(),
430                    raw_tx: refund_tx_buf,
431                    signing_nonce_commitment: Some(frost_commitment_to_proto_commitment(
432                        &refund_commitment,
433                    )?),
434                });
435            }
436        }
437
438        Ok(BuildCreationNodesFromTreeSdkResponse {
439            creation_nodes: creation_node,
440        })
441    }
442
443    async fn finalize_tree_creation(
444        &self,
445        parent_tx: Option<Transaction>,
446        parent_node: Option<Arc<RwLock<TreeNodeProto>>>,
447        vout: u32,
448        root: Arc<RwLock<DepositAddressTree>>,
449    ) -> Result<FinalizeTreeCreationSdkResponse, SparkSdkError> {
450        let mut request = CreateTreeRequest {
451            user_identity_public_key: self.get_identity_public_key().to_vec(),
452            ..Default::default()
453        };
454
455        let final_parent_tx = if let Some(ptx) = parent_tx {
456            let mut raw_tx = Vec::new();
457            ptx.consensus_encode(&mut raw_tx)
458                .map_err(|e| SparkSdkError::InvalidArgument(e.to_string()))?;
459
460            request.source = Some(
461                spark_protos::spark::create_tree_request::Source::OnChainUtxo(
462                    spark_protos::spark::Utxo {
463                        // txid,
464                        vout,
465                        raw_tx,
466                        network: self.config.spark_config.network.marshal_proto(),
467                    },
468                ),
469            );
470            ptx
471        } else if let Some(parent_node) = parent_node {
472            let mut tx_buf = parent_node.read().node_tx.clone();
473            let ptx: Transaction = bitcoin::consensus::deserialize(&mut tx_buf).map_err(|_| {
474                SparkSdkError::InvalidArgument("Failed to parse parent node_tx".into())
475            })?;
476            let node_id = parent_node.read().id.clone();
477            request.source = Some(
478                spark_protos::spark::create_tree_request::Source::ParentNodeOutput(
479                    spark_protos::spark::NodeOutput { node_id, vout },
480                ),
481            );
482            ptx
483        } else {
484            return Err(SparkSdkError::InvalidArgument(
485                "No parent_tx or parent_node provided to create_tree".into(),
486            ));
487        };
488
489        let parent_txid = final_parent_tx.compute_txid();
490        let time_start = std::time::Instant::now();
491        let creation_node_response = self.build_creation_nodes_from_tree(
492            parent_txid,
493            &final_parent_tx.output[vout as usize],
494            vout,
495            root.clone(),
496        )?;
497        let duration = time_start.elapsed();
498        println!("[build_creation_nodes_from_tree] duration: {:?}", duration);
499
500        request.node = Some(creation_node_response.creation_nodes.clone());
501        let mut tonic_request = tonic::Request::new(request);
502        self.add_authorization_header_to_request(&mut tonic_request, None);
503
504        let mut spark_client = self.config.spark_config.get_spark_connection(None).await?;
505        let time_start = std::time::Instant::now();
506        let resp = spark_client.create_tree(tonic_request).await?;
507        let duration = time_start.elapsed();
508        println!("[create_tree] duration: {:?}", duration);
509
510        let create_tree_result = resp.into_inner().node.ok_or_else(|| {
511            SparkSdkError::InvalidArgument("Coordinator returned no creation node".into())
512        })?;
513
514        let time_start = std::time::Instant::now();
515        let (node_signatures, signing_public_keys) = self.signer.sign_created_tree_in_bfs_order(
516            final_parent_tx,
517            vout,
518            root,
519            creation_node_response.creation_nodes,
520            create_tree_result,
521        )?;
522        let duration = time_start.elapsed();
523        println!("[sign_created_tree_in_bfs_order] duration: {:?}", duration);
524
525        let mut spark_client = self.config.spark_config.get_spark_connection(None).await?;
526        let mut spark_signatures_request = tonic::Request::new(FinalizeNodeSignaturesRequest {
527            node_signatures: node_signatures.clone(),
528            ..Default::default()
529        });
530        self.add_authorization_header_to_request(&mut spark_signatures_request, None);
531
532        let time_start = std::time::Instant::now();
533        let spark_signatures_response_ = spark_client
534            .finalize_node_signatures(spark_signatures_request)
535            .await?;
536        let spark_signatures_response = spark_signatures_response_.into_inner();
537        let duration = time_start.elapsed();
538        println!("[finalize_node_signatures] duration: {:?}", duration);
539
540        // Slice the array to get the second half (including the center)
541        let starting_index = spark_signatures_response.nodes.len() / 2;
542
543        // Get all node ids
544        let node_ids = spark_signatures_response
545            .nodes
546            .iter()
547            .map(|node| node.id.clone())
548            .collect::<Vec<_>>();
549
550        let leaf_nodes = spark_signatures_response
551            .nodes
552            .iter()
553            .skip(spark_signatures_response.nodes.len() / 2)
554            .collect::<Vec<_>>();
555
556        let mut leaf_nodes_to_insert = vec![];
557        for (i, node) in leaf_nodes.iter().enumerate() {
558            println!("node.tree_id: {:?}", node.tree_id);
559
560            let node_id = node.id.clone();
561            let i = i + starting_index;
562            if node_id != node_signatures[i].clone().node_id {
563                return Err(SparkSdkError::InvalidArgument("Node ID mismatch".into()));
564            }
565
566            let verifying_public_key = node.verifying_public_key.clone();
567            let node_tx_transaction = node.node_tx.clone();
568            let refund_tx_transaction = node.refund_tx.clone();
569            let value = node.value;
570            let node_tx_signature = node_signatures[i].node_tx_signature.clone();
571            let refund_tx_signature = node_signatures[i].refund_tx_signature.clone();
572
573            let _verifying_key = node.verifying_public_key.clone();
574
575            // TODO: this is an error-prone approach. For roots, this approach is correct.
576            // For splits, this approach is incorrect because now it should be strictly the right half, excluding the center.
577
578            // TODO: For the parents, add aggregated.
579            leaf_nodes_to_insert.push(LeafNode::new(
580                node_id,
581                "".to_string(),
582                value,
583                node.parent_node_id.clone(),
584                node.vout,
585                verifying_public_key,
586                signing_public_keys[i].clone(),
587                // node_tx_signature,
588                // refund_tx_signature,
589                node_tx_transaction,
590                refund_tx_transaction,
591                None,   // default will be the Bitcoin token public key
592                vec![], // No revocation public key for tree creation
593                vec![], // No token transaction hash for tree creation
594            ));
595        }
596
597        self.leaf_manager
598            .insert_leaves_in_batch(leaf_nodes_to_insert)?;
599
600        Ok(FinalizeTreeCreationSdkResponse {
601            finalize_tree_response: spark_signatures_response,
602            signing_public_keys,
603        })
604    }
605}