spark_sdk/wallet/
graphql.rs

1// Copyright ©, 2023-present, Lightspark Group, Inc. - All Rights Reserved
2
3use std::collections::HashMap;
4
5use regex::Regex;
6use reqwest::Client;
7use serde_json::{from_slice, to_vec, Value};
8use url::Url;
9use zstd::{decode_all, encode_all};
10
11use crate::constants::spark::LIGHTSPARK_SSP_ENDPOINT;
12
13// Custom error types
14
/// Indicates a request to the Lightspark API failed (e.g., network or server error).
///
/// Retry if `status_code` is in the 500-599 range.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequestError {
    /// Human-readable description of the failure.
    pub message: String,
    /// HTTP status code returned by the server.
    pub status_code: u16,
}

impl std::fmt::Display for RequestError {
    /// Renders the error as `lightspark request failed: <status_code>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            message,
            status_code,
        } = self;
        write!(f, "lightspark request failed: {status_code}: {message}")
    }
}

impl std::error::Error for RequestError {}
34
/// Indicates a failure in the Lightspark API (e.g., a bug).
///
/// Retrying might help if the failure is transient.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GraphQLInternalError {
    /// Error message reported by the API.
    pub message: String,
}

impl std::fmt::Display for GraphQLInternalError {
    /// Renders the error as `lightspark request failed: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { message } = self;
        write!(f, "lightspark request failed: {message}")
    }
}

impl std::error::Error for GraphQLInternalError {}
49
/// Indicates a successful GraphQL request but with a user error.
///
/// Do not retry; fix the input instead.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GraphQLError {
    /// Error message reported by the API.
    pub message: String,
    /// Error name reported by the API alongside the message.
    pub error_type: String,
}

impl std::fmt::Display for GraphQLError {
    /// Renders the error as `<error_type>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            message,
            error_type,
        } = self;
        write!(f, "{error_type}: {message}")
    }
}

impl std::error::Error for GraphQLError {}
65
66/// Requester struct for making GraphQL requests to Lightspark API.
67#[derive(Clone)]
68pub struct GraphqlClient {
69    base_url: Option<String>,
70    identity_public_key: String,
71    http_client: Client,
72}
73
74impl GraphqlClient {
75    /// Creates a new Requester with a custom base URL.
76    pub fn with_base_url(
77        identity_public_key: String,
78        base_url: Option<String>,
79    ) -> Result<Self, Box<dyn std::error::Error>> {
80        if let Some(ref url) = base_url {
81            validate_base_url(url)?;
82        }
83        Ok(GraphqlClient {
84            base_url,
85            identity_public_key,
86            http_client: Client::new(),
87        })
88    }
89
90    /// Executes a GraphQL query or mutation with the given context (timeout handled via Client).
91    pub async fn execute_graphql(
92        &self,
93        query: &str,
94        variables: HashMap<String, Value>,
95    ) -> Result<HashMap<String, Value>, Box<dyn std::error::Error>> {
96        let re = Regex::new(r"(?i)\s*(?:query|mutation)\s+(?P<OperationName>\w+)")?;
97        let captures = re.captures(query).ok_or("invalid query payload")?;
98        let operation_name = captures
99            .name("OperationName")
100            .ok_or("invalid query payload")?
101            .as_str();
102
103        // Prepare payload
104        let payload = serde_json::json!({
105            "operationName": operation_name,
106            "query": query,
107            "variables": variables,
108        });
109        let encoded_payload = to_vec(&payload)?;
110
111        // Compress if payload > 1024 bytes
112        let (body, compressed) = if encoded_payload.len() > 1024 {
113            let compressed = encode_all(&encoded_payload[..], 0)?;
114            (compressed, true)
115        } else {
116            (encoded_payload, false)
117        };
118
119        // Determine server URL
120        let server_url = self.base_url.as_deref().unwrap_or(LIGHTSPARK_SSP_ENDPOINT);
121        validate_base_url(server_url)?;
122
123        // Build HTTP request
124        let mut request = self
125            .http_client
126            .post(server_url)
127            .header("Spark-Identity-Public-Key", &self.identity_public_key)
128            .header("Content-Type", "application/json")
129            .header("Accept-Encoding", "zstd")
130            .header("X-GraphQL-Operation", operation_name)
131            .header("User-Agent", self.get_user_agent())
132            .header("X-Polarity-SDK", self.get_user_agent());
133
134        if compressed {
135            request = request.header("Content-Encoding", "zstd");
136        }
137
138        let response = request.body(body).send().await?;
139
140        // Handle response
141        let status = response.status();
142        if !status.is_success() {
143            return Err(Box::new(RequestError {
144                message: status.to_string(),
145                status_code: status.as_u16(),
146            }));
147        }
148
149        // Check headers before consuming response
150        let is_zstd_encoded = response
151            .headers()
152            .get("Content-Encoding")
153            .map_or(false, |v| v == "zstd");
154
155        // Now consume response to get the body
156        let mut data = response.bytes().await?.to_vec();
157
158        // Decode if zstd encoded
159        if is_zstd_encoded {
160            data = decode_all(&data[..])?;
161        }
162
163        // Parse JSON response
164        let result: HashMap<String, Value> = from_slice(&data)?;
165        if let Some(errors) = result.get("errors") {
166            let err = errors
167                .as_array()
168                .and_then(|arr| arr.first())
169                .and_then(|v| v.as_object())
170                .ok_or("invalid error format")?;
171            let error_message = err
172                .get("message")
173                .and_then(|v| v.as_str())
174                .ok_or("missing error message")?
175                .to_string();
176
177            if let Some(extensions) = err.get("extensions").and_then(|v| v.as_object()) {
178                if let Some(error_name) = extensions.get("error_name").and_then(|v| v.as_str()) {
179                    return Err(Box::new(GraphQLError {
180                        message: error_message,
181                        error_type: error_name.to_string(),
182                    }));
183                }
184            }
185            return Err(Box::new(GraphQLInternalError {
186                message: error_message,
187            }));
188        }
189
190        result
191            .get("data")
192            .and_then(|v| {
193                v.as_object()
194                    .map(|o| o.into_iter().map(|(k, v)| (k.clone(), v.clone())).collect())
195            })
196            .ok_or_else(|| "missing data field".into())
197    }
198
199    fn get_user_agent(&self) -> &str {
200        "spark"
201    }
202}
203
204fn validate_base_url(base_url: &str) -> Result<(), Box<dyn std::error::Error>> {
205    let parsed_url = Url::parse(base_url)?;
206    let hostname = parsed_url.host_str().unwrap_or("");
207    let host_parts: Vec<&str> = hostname.split('.').collect();
208    let tld = host_parts.last().unwrap_or(&"");
209
210    let is_whitelisted_localhost =
211        hostname == "localhost" || *tld == "local" || *tld == "internal" || hostname == "127.0.0.1";
212
213    if parsed_url.scheme() != "https" && !is_whitelisted_localhost {
214        return Err("invalid base url: must be https:// if not targeting localhost".into());
215    }
216    Ok(())
217}