spark_sdk/rpc/connections/
connection.rs

1use crate::common_types::types::certificates;
2use crate::error::SparkSdkError;
3use crate::rpc::traits::SparkRpcConnection;
4use crate::rpc::SparkRpcClient;
5use tonic::async_trait;
6use tonic::transport::Uri;
7use tonic::transport::{Certificate, Channel, ClientTlsConfig};
8
9#[derive(Debug, Clone)]
10pub(crate) struct SparkConnection {
11    /// The URL of the Spark RPC service
12    pub(crate) uri: Uri,
13
14    /// The client to use for the Spark RPC service
15    pub(crate) channel: Channel,
16}
17
18#[async_trait]
19impl SparkRpcConnection for SparkConnection {
20    async fn establish_connection(uri: Uri) -> Result<SparkRpcClient, SparkSdkError> {
21        let server_root_ca_cert = Certificate::from_pem(certificates::amazon_root_ca::CA_PEM);
22
23        // parse uri and domain name
24        let uri_str = uri.clone().to_string();
25        let domain_name = uri_str.trim_start_matches("https://").trim_end_matches('/');
26
27        // create tls config
28        let tls = ClientTlsConfig::new()
29            .domain_name(domain_name)
30            .ca_certificate(server_root_ca_cert);
31
32        // create channel with tls
33        let channel = Channel::from_shared(uri.to_string())
34            .unwrap()
35            .tls_config(tls)?
36            .connect()
37            .await?;
38
39        let connection = SparkConnection {
40            uri: uri.clone(),
41            channel,
42        };
43
44        Ok(SparkRpcClient::DefaultConnection(connection))
45    }
46}