spark_sdk/wallet/
graphql.rs1use std::collections::HashMap;
4
5use regex::Regex;
6use reqwest::Client;
7use serde_json::{from_slice, to_vec, Value};
8use url::Url;
9use zstd::{decode_all, encode_all};
10
11use crate::constants::spark::LIGHTSPARK_SSP_ENDPOINT;
12
/// Error produced when the GraphQL endpoint replies with a non-success HTTP
/// status (constructed by `GraphqlClient::execute_graphql`).
#[derive(Debug)]
pub struct RequestError {
    /// Description of the failure (the client stores the rendered HTTP status here).
    pub message: String,
    /// Numeric HTTP status code of the failed response.
    pub status_code: u16,
}

impl std::fmt::Display for RequestError {
    /// Renders as `lightspark request failed: <status_code>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            message,
            status_code,
        } = self;
        write!(f, "lightspark request failed: {}: {}", status_code, message)
    }
}

impl std::error::Error for RequestError {}

/// Error produced when a GraphQL response reports an error whose first entry
/// has no string `extensions.error_name` (see `GraphqlClient::execute_graphql`).
#[derive(Debug)]
pub struct GraphQLInternalError {
    /// The `message` field of the first GraphQL error.
    pub message: String,
}

impl std::fmt::Display for GraphQLInternalError {
    /// Renders as `lightspark request failed: <message>`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { message } = self;
        write!(formatter, "lightspark request failed: {}", message)
    }
}

impl std::error::Error for GraphQLInternalError {}

/// Error produced when the first GraphQL error in a response names its kind
/// via a string `extensions.error_name` (see `GraphqlClient::execute_graphql`).
#[derive(Debug)]
pub struct GraphQLError {
    /// The `message` field of the first GraphQL error.
    pub message: String,
    /// The `extensions.error_name` value of the first GraphQL error.
    pub error_type: String,
}

impl std::fmt::Display for GraphQLError {
    /// Renders as `<error_type>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let GraphQLError {
            message,
            error_type,
        } = self;
        write!(f, "{}: {}", error_type, message)
    }
}

impl std::error::Error for GraphQLError {}

66#[derive(Clone)]
68pub struct GraphqlClient {
69 base_url: Option<String>,
70 identity_public_key: String,
71 http_client: Client,
72}
73
74impl GraphqlClient {
75 pub fn with_base_url(
77 identity_public_key: String,
78 base_url: Option<String>,
79 ) -> Result<Self, Box<dyn std::error::Error>> {
80 if let Some(ref url) = base_url {
81 validate_base_url(url)?;
82 }
83 Ok(GraphqlClient {
84 base_url,
85 identity_public_key,
86 http_client: Client::new(),
87 })
88 }
89
90 pub async fn execute_graphql(
92 &self,
93 query: &str,
94 variables: HashMap<String, Value>,
95 ) -> Result<HashMap<String, Value>, Box<dyn std::error::Error>> {
96 let re = Regex::new(r"(?i)\s*(?:query|mutation)\s+(?P<OperationName>\w+)")?;
97 let captures = re.captures(query).ok_or("invalid query payload")?;
98 let operation_name = captures
99 .name("OperationName")
100 .ok_or("invalid query payload")?
101 .as_str();
102
103 let payload = serde_json::json!({
105 "operationName": operation_name,
106 "query": query,
107 "variables": variables,
108 });
109 let encoded_payload = to_vec(&payload)?;
110
111 let (body, compressed) = if encoded_payload.len() > 1024 {
113 let compressed = encode_all(&encoded_payload[..], 0)?;
114 (compressed, true)
115 } else {
116 (encoded_payload, false)
117 };
118
119 let server_url = self.base_url.as_deref().unwrap_or(LIGHTSPARK_SSP_ENDPOINT);
121 validate_base_url(server_url)?;
122
123 let mut request = self
125 .http_client
126 .post(server_url)
127 .header("Spark-Identity-Public-Key", &self.identity_public_key)
128 .header("Content-Type", "application/json")
129 .header("Accept-Encoding", "zstd")
130 .header("X-GraphQL-Operation", operation_name)
131 .header("User-Agent", self.get_user_agent())
132 .header("X-Polarity-SDK", self.get_user_agent());
133
134 if compressed {
135 request = request.header("Content-Encoding", "zstd");
136 }
137
138 let response = request.body(body).send().await?;
139
140 let status = response.status();
142 if !status.is_success() {
143 return Err(Box::new(RequestError {
144 message: status.to_string(),
145 status_code: status.as_u16(),
146 }));
147 }
148
149 let is_zstd_encoded = response
151 .headers()
152 .get("Content-Encoding")
153 .map_or(false, |v| v == "zstd");
154
155 let mut data = response.bytes().await?.to_vec();
157
158 if is_zstd_encoded {
160 data = decode_all(&data[..])?;
161 }
162
163 let result: HashMap<String, Value> = from_slice(&data)?;
165 if let Some(errors) = result.get("errors") {
166 let err = errors
167 .as_array()
168 .and_then(|arr| arr.first())
169 .and_then(|v| v.as_object())
170 .ok_or("invalid error format")?;
171 let error_message = err
172 .get("message")
173 .and_then(|v| v.as_str())
174 .ok_or("missing error message")?
175 .to_string();
176
177 if let Some(extensions) = err.get("extensions").and_then(|v| v.as_object()) {
178 if let Some(error_name) = extensions.get("error_name").and_then(|v| v.as_str()) {
179 return Err(Box::new(GraphQLError {
180 message: error_message,
181 error_type: error_name.to_string(),
182 }));
183 }
184 }
185 return Err(Box::new(GraphQLInternalError {
186 message: error_message,
187 }));
188 }
189
190 result
191 .get("data")
192 .and_then(|v| {
193 v.as_object()
194 .map(|o| o.into_iter().map(|(k, v)| (k.clone(), v.clone())).collect())
195 })
196 .ok_or_else(|| "missing data field".into())
197 }
198
199 fn get_user_agent(&self) -> &str {
200 "spark"
201 }
202}
203
204fn validate_base_url(base_url: &str) -> Result<(), Box<dyn std::error::Error>> {
205 let parsed_url = Url::parse(base_url)?;
206 let hostname = parsed_url.host_str().unwrap_or("");
207 let host_parts: Vec<&str> = hostname.split('.').collect();
208 let tld = host_parts.last().unwrap_or(&"");
209
210 let is_whitelisted_localhost =
211 hostname == "localhost" || *tld == "local" || *tld == "internal" || hostname == "127.0.0.1";
212
213 if parsed_url.scheme() != "https" && !is_whitelisted_localhost {
214 return Err("invalid base url: must be https:// if not targeting localhost".into());
215 }
216 Ok(())
217}