diff --git a/Cargo.lock b/Cargo.lock index c34927f8..f3c2f33e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6370,7 +6370,7 @@ dependencies = [ [[package]] name = "trailbase-client" -version = "0.0.2" +version = "0.0.3" dependencies = [ "jsonwebtoken", "log", diff --git a/client/trailbase-rs/Cargo.toml b/client/trailbase-rs/Cargo.toml index 15575074..6e90e6a0 100644 --- a/client/trailbase-rs/Cargo.toml +++ b/client/trailbase-rs/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "trailbase-client" -version = "0.0.2" +version = "0.0.3" edition = "2021" license = "OSL-3.0" description = "Client for accessing TrailBase's record APIs" diff --git a/client/trailbase-rs/src/lib.rs b/client/trailbase-rs/src/lib.rs index a221ac6f..3b28e786 100644 --- a/client/trailbase-rs/src/lib.rs +++ b/client/trailbase-rs/src/lib.rs @@ -72,37 +72,39 @@ impl RecordId<'_> for i64 { } struct ThinClient { client: reqwest::Client, - site: String, + url: url::Url, } impl ThinClient { - async fn fetch( + async fn fetch( &self, - path: String, + path: &str, headers: HeaderMap, method: Method, - body: Option, - query_params: Option>, + body: Option<&T>, + query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>, ) -> Result { - if path.starts_with("/") { - return Err(Error::Precondition("path must not start with '/'")); - } + assert!(path.starts_with("/")); - let mut url = url::Url::parse(&format!("{}/{path}", self.site))?; + let mut url = self.url.clone(); + url.set_path(path); if let Some(query_params) = query_params { + let mut params = url.query_pairs_mut(); for (key, value) in query_params { - url.query_pairs_mut().append_pair(&key, &value); + params.append_pair(key, value); } } - let mut builder = self.client.request(method, url).headers(headers); + let request = { + let mut builder = self.client.request(method, url).headers(headers); + if let Some(ref body) = body { + builder = builder.body(serde_json::to_string(body)?); + } + builder.build()? + }; - if let Some(body) = body { - builder = builder.body(serde_json::to_string(&body)?); - } - - return Ok(self.client.execute(builder.build()?).await?); + return Ok(self.client.execute(request).await?); } } @@ -140,19 +142,19 @@ impl RecordApi { order: Option<&[&str]>, filters: Option<&[&str]>, ) -> Result, Error> { - let mut params: Vec<(String, String)> = vec![]; + let mut params: Vec<(Cow<'static, str>, Cow<'static, str>)> = vec![]; if let Some(pagination) = pagination { if let Some(cursor) = pagination.cursor { - params.push(("cursor".to_string(), cursor)); + params.push((Cow::Borrowed("cursor"), Cow::Owned(cursor))); } if let Some(limit) = pagination.limit { - params.push(("limit".to_string(), limit.to_string())); + params.push((Cow::Borrowed("limit"), Cow::Owned(limit.to_string()))); } } if let Some(order) = order { - params.push(("order".to_string(), order.join(","))); + params.push((Cow::Borrowed("order"), Cow::Owned(order.join(",")))); } if let Some(filters) = filters { @@ -161,17 +163,20 @@ impl RecordApi { panic!("Filter '{filter}' does not match: 'name[op]=value'"); }; - params.push((name_op.to_string(), value.to_string())); + params.push(( + Cow::Owned(name_op.to_string()), + Cow::Owned(value.to_string()), + )); } } let response = self .client .fetch( - format!("{RECORD_API}/{}", self.name), + &format!("/{RECORD_API}/{}", self.name), Method::GET, - None, - Some(params), + None::<&()>, + Some(¶ms), ) .await?; @@ -182,13 +187,13 @@ impl RecordApi { let response = self .client .fetch( - format!( - "{RECORD_API}/{name}/{id}", + &format!( + "/{RECORD_API}/{name}/{id}", name = self.name, id = id.serialized_id() ), Method::GET, - None, + None::<&()>, None, ) .await?; @@ -200,9 +205,9 @@ impl RecordApi { let response = self .client .fetch( - format!("{RECORD_API}/{name}", name = self.name), + &format!("/{RECORD_API}/{name}", name = self.name), Method::POST, - Some(serde_json::to_value(record)?), + Some(&record), None, ) .await?; @@ -223,13 +228,13 @@ impl RecordApi { self .client .fetch( - format!( - "{RECORD_API}/{name}/{id}", + &format!( + "/{RECORD_API}/{name}/{id}", name = self.name, id = id.serialized_id() ), Method::PATCH, - Some(serde_json::to_value(record)?), + Some(&record), None, ) .await?; @@ -241,13 +246,13 @@ impl RecordApi { self .client .fetch( - format!( - "{RECORD_API}/{name}/{id}", + &format!( + "/{RECORD_API}/{name}/{id}", name = self.name, id = id.serialized_id() ), Method::DELETE, - None, + None::<&()>, None, ) .await?; @@ -281,68 +286,50 @@ impl TokenState { struct ClientState { client: ThinClient, site: String, - token_state: RwLock, + tokens: RwLock, } impl ClientState { #[inline] - async fn fetch( + async fn fetch( &self, - url: String, + path: &str, method: Method, - body: Option, - query_params: Option>, + body: Option<&T>, + query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>, ) -> Result { - let (needs_refetch, mut headers) = { - let token_state = self.token_state.read(); - ( - Self::should_refresh(&token_state), - token_state.headers.clone(), - ) - }; + let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp(); + if let Some(refresh_token) = refresh_token { + let new_tokens = ClientState::refresh_tokens(&self.client, headers, refresh_token).await?; - if needs_refetch { - let refresh_token = { - let token_state = self.token_state.read(); - let Some(ref refresh_token) = token_state - .state - .as_ref() - .and_then(|s| s.0.refresh_token.clone()) - else { - return Err(Error::Precondition("Missing refresh token")); - }; - refresh_token.clone() - }; - - let new_token_state = - ClientState::refresh_tokens(&self.client, headers, refresh_token).await?; - - let new_headers = new_token_state.headers.clone(); - - *self.token_state.write() = new_token_state; - - headers = new_headers; + headers = new_tokens.headers.clone(); + *self.tokens.write() = new_tokens; } return self .client - .fetch(url, headers, method, body, query_params) + .fetch(path, headers, method, body, query_params) .await; } #[inline] - fn should_refresh(token_state: &TokenState) -> bool { - let now = now(); - if let Some(ref state) = token_state.state { - return state.1.exp - 60 < now as i64; + fn extract_headers_and_refresh_token_if_exp(&self) -> (HeaderMap, Option) { + #[inline] + fn should_refresh(jwt: &JwtTokenClaims) -> bool { + return jwt.exp - 60 < now() as i64; } - return false; + + let tokens = self.tokens.read(); + let headers = tokens.headers.clone(); + return match tokens.state { + Some(ref state) if should_refresh(&state.1) => (headers, state.0.refresh_token.clone()), + _ => (headers, None), + }; } - fn extract_refresh_token_and_headers( - token_state: &TokenState, - ) -> Result<(String, HeaderMap), Error> { - let Some(ref state) = token_state.state else { + fn extract_headers_refresh_token(&self) -> Result<(HeaderMap, String), Error> { + let tokens = self.tokens.read(); + let Some(ref state) = tokens.state else { return Err(Error::Precondition("Not logged int?")); }; @@ -350,7 +337,7 @@ impl ClientState { return Err(Error::Precondition("Missing refresh token")); }; - return Ok((refresh_token.clone(), token_state.headers.clone())); + return Ok((tokens.headers.clone(), refresh_token.clone())); } async fn refresh_tokens( @@ -358,14 +345,19 @@ impl ClientState { headers: HeaderMap, refresh_token: String, ) -> Result { + #[derive(Serialize)] + struct RefreshRequest<'a> { + refresh_token: &'a str, + } + let response = client .fetch( - format!("{AUTH_API}/refresh"), + &format!("/{AUTH_API}/refresh"), headers, Method::POST, - Some(serde_json::json!({ - "refresh_token": refresh_token, - })), + Some(&RefreshRequest { + refresh_token: &refresh_token, + }), None, ) .await?; @@ -391,17 +383,17 @@ pub struct Client { } impl Client { - pub fn new(site: &str, tokens: Option) -> Client { - return Client { + pub fn new(site: &str, tokens: Option) -> Result { + return Ok(Client { state: Arc::new(ClientState { client: ThinClient { client: reqwest::Client::new(), - site: site.to_string(), + url: url::Url::parse(site)?, }, site: site.to_string(), - token_state: RwLock::new(TokenState::build(tokens.as_ref())), + tokens: RwLock::new(TokenState::build(tokens.as_ref())), }), - }; + }); } pub fn site(&self) -> String { @@ -409,17 +401,11 @@ impl Client { } pub fn tokens(&self) -> Option { - return self - .state - .token_state - .read() - .state - .as_ref() - .map(|x| x.0.clone()); + return self.state.tokens.read().state.as_ref().map(|x| x.0.clone()); } pub fn user(&self) -> Option { - if let Some(state) = &self.state.token_state.read().state { + if let Some(state) = &self.state.tokens.read().state { return Some(User { sub: state.1.sub.clone(), email: state.1.email.clone(), @@ -436,25 +422,27 @@ impl Client { } pub async fn refresh(&self) -> Result<(), Error> { - let (refresh_token, headers) = - ClientState::extract_refresh_token_and_headers(&self.state.token_state.read())?; - let new_token_state = + let (headers, refresh_token) = self.state.extract_headers_refresh_token()?; + let new_tokens = ClientState::refresh_tokens(&self.state.client, headers, refresh_token).await?; - *self.state.token_state.write() = new_token_state; + *self.state.tokens.write() = new_tokens; return Ok(()); } pub async fn login(&self, email: &str, password: &str) -> Result { + #[derive(Serialize)] + struct Credentials<'a> { + email: &'a str, + password: &'a str, + } + let response = self .state .fetch( - format!("{AUTH_API}/login"), + &format!("/{AUTH_API}/login"), Method::POST, - Some(serde_json::json!({ - "email": email, - "password": password, - })), + Some(&Credentials { email, password }), None, ) .await?; @@ -465,39 +453,45 @@ impl Client { } pub async fn logout(&self) -> Result<(), Error> { - let refresh_token: Option = self - .state - .token_state - .read() - .state - .as_ref() - .and_then(|s| s.0.refresh_token.clone()); - if let Some(refresh_token) = refresh_token { - self - .state - .fetch( - format!("{AUTH_API}/logout"), - Method::POST, - Some(serde_json::json!({"refresh_token": refresh_token})), - None, - ) - .await?; - } else { - self - .state - .fetch(format!("{AUTH_API}/logout"), Method::GET, None, None) - .await?; + #[derive(Serialize)] + struct LogoutRequest { + refresh_token: String, } + let response_or = match self.state.extract_headers_refresh_token() { + Ok((_headers, refresh_token)) => { + self + .state + .fetch( + &format!("/{AUTH_API}/logout"), + Method::POST, + Some(&LogoutRequest { refresh_token }), + None, + ) + .await + } + _ => { + self + .state + .fetch( + &format!("/{AUTH_API}/logout"), + Method::GET, + None::<&()>, + None, + ) + .await + } + }; + self.update_tokens(None); - return Ok(()); + return response_or.map(|_| ()); } fn update_tokens(&self, tokens: Option<&Tokens>) -> TokenState { let state = TokenState::build(tokens); - *self.state.token_state.write() = state.clone(); + *self.state.tokens.write() = state.clone(); // _authChange?.call(this, state.state?.$1); if let Some(ref s) = state.state { @@ -558,7 +552,7 @@ mod tests { #[tokio::test] async fn is_send_test() { - let client = Client::new("http://127.0.0.1:4000", None); + let client = Client::new("http://127.0.0.1:4000", None).unwrap(); let api = client.records("simple_strict_table"); diff --git a/client/trailbase-rs/tests/integration_test.rs b/client/trailbase-rs/tests/integration_test.rs index 4899f60b..22b9fe14 100644 --- a/client/trailbase-rs/tests/integration_test.rs +++ b/client/trailbase-rs/tests/integration_test.rs @@ -79,7 +79,7 @@ struct SimpleStrict { } async fn connect() -> Client { - let client = Client::new(&format!("http://127.0.0.1:{PORT}"), None); + let client = Client::new(&format!("http://127.0.0.1:{PORT}"), None).unwrap(); let _ = client.login("admin@localhost", "secret").await.unwrap(); return client; }