Optimize rust client: less allocations and locking.

This commit is contained in:
Sebastian Jeltsch
2025-01-17 23:02:03 +01:00
parent b3d4cbd99f
commit 9fdc206587
4 changed files with 127 additions and 133 deletions

2
Cargo.lock generated
View File

@@ -6370,7 +6370,7 @@ dependencies = [
[[package]]
name = "trailbase-client"
version = "0.0.2"
version = "0.0.3"
dependencies = [
"jsonwebtoken",
"log",

View File

@@ -1,6 +1,6 @@
[package]
name = "trailbase-client"
version = "0.0.2"
version = "0.0.3"
edition = "2021"
license = "OSL-3.0"
description = "Client for accessing TrailBase's record APIs"

View File

@@ -72,37 +72,39 @@ impl RecordId<'_> for i64 {
}
struct ThinClient {
client: reqwest::Client,
site: String,
url: url::Url,
}
impl ThinClient {
async fn fetch(
async fn fetch<T: Serialize>(
&self,
path: String,
path: &str,
headers: HeaderMap,
method: Method,
body: Option<serde_json::Value>,
query_params: Option<Vec<(String, String)>>,
body: Option<&T>,
query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
) -> Result<reqwest::Response, Error> {
if path.starts_with("/") {
return Err(Error::Precondition("path must not start with '/'"));
}
assert!(path.starts_with("/"));
let mut url = url::Url::parse(&format!("{}/{path}", self.site))?;
let mut url = self.url.clone();
url.set_path(path);
if let Some(query_params) = query_params {
let mut params = url.query_pairs_mut();
for (key, value) in query_params {
url.query_pairs_mut().append_pair(&key, &value);
params.append_pair(key, value);
}
}
let mut builder = self.client.request(method, url).headers(headers);
let request = {
let mut builder = self.client.request(method, url).headers(headers);
if let Some(ref body) = body {
builder = builder.body(serde_json::to_string(body)?);
}
builder.build()?
};
if let Some(body) = body {
builder = builder.body(serde_json::to_string(&body)?);
}
return Ok(self.client.execute(builder.build()?).await?);
return Ok(self.client.execute(request).await?);
}
}
@@ -140,19 +142,19 @@ impl RecordApi {
order: Option<&[&str]>,
filters: Option<&[&str]>,
) -> Result<Vec<T>, Error> {
let mut params: Vec<(String, String)> = vec![];
let mut params: Vec<(Cow<'static, str>, Cow<'static, str>)> = vec![];
if let Some(pagination) = pagination {
if let Some(cursor) = pagination.cursor {
params.push(("cursor".to_string(), cursor));
params.push((Cow::Borrowed("cursor"), Cow::Owned(cursor)));
}
if let Some(limit) = pagination.limit {
params.push(("limit".to_string(), limit.to_string()));
params.push((Cow::Borrowed("limit"), Cow::Owned(limit.to_string())));
}
}
if let Some(order) = order {
params.push(("order".to_string(), order.join(",")));
params.push((Cow::Borrowed("order"), Cow::Owned(order.join(","))));
}
if let Some(filters) = filters {
@@ -161,17 +163,20 @@ impl RecordApi {
panic!("Filter '{filter}' does not match: 'name[op]=value'");
};
params.push((name_op.to_string(), value.to_string()));
params.push((
Cow::Owned(name_op.to_string()),
Cow::Owned(value.to_string()),
));
}
}
let response = self
.client
.fetch(
format!("{RECORD_API}/{}", self.name),
&format!("/{RECORD_API}/{}", self.name),
Method::GET,
None,
Some(params),
None::<&()>,
Some(&params),
)
.await?;
@@ -182,13 +187,13 @@ impl RecordApi {
let response = self
.client
.fetch(
format!(
"{RECORD_API}/{name}/{id}",
&format!(
"/{RECORD_API}/{name}/{id}",
name = self.name,
id = id.serialized_id()
),
Method::GET,
None,
None::<&()>,
None,
)
.await?;
@@ -200,9 +205,9 @@ impl RecordApi {
let response = self
.client
.fetch(
format!("{RECORD_API}/{name}", name = self.name),
&format!("/{RECORD_API}/{name}", name = self.name),
Method::POST,
Some(serde_json::to_value(record)?),
Some(&record),
None,
)
.await?;
@@ -223,13 +228,13 @@ impl RecordApi {
self
.client
.fetch(
format!(
"{RECORD_API}/{name}/{id}",
&format!(
"/{RECORD_API}/{name}/{id}",
name = self.name,
id = id.serialized_id()
),
Method::PATCH,
Some(serde_json::to_value(record)?),
Some(&record),
None,
)
.await?;
@@ -241,13 +246,13 @@ impl RecordApi {
self
.client
.fetch(
format!(
"{RECORD_API}/{name}/{id}",
&format!(
"/{RECORD_API}/{name}/{id}",
name = self.name,
id = id.serialized_id()
),
Method::DELETE,
None,
None::<&()>,
None,
)
.await?;
@@ -281,68 +286,50 @@ impl TokenState {
struct ClientState {
client: ThinClient,
site: String,
token_state: RwLock<TokenState>,
tokens: RwLock<TokenState>,
}
impl ClientState {
#[inline]
async fn fetch(
async fn fetch<T: Serialize>(
&self,
url: String,
path: &str,
method: Method,
body: Option<serde_json::Value>,
query_params: Option<Vec<(String, String)>>,
body: Option<&T>,
query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
) -> Result<reqwest::Response, Error> {
let (needs_refetch, mut headers) = {
let token_state = self.token_state.read();
(
Self::should_refresh(&token_state),
token_state.headers.clone(),
)
};
let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
if let Some(refresh_token) = refresh_token {
let new_tokens = ClientState::refresh_tokens(&self.client, headers, refresh_token).await?;
if needs_refetch {
let refresh_token = {
let token_state = self.token_state.read();
let Some(ref refresh_token) = token_state
.state
.as_ref()
.and_then(|s| s.0.refresh_token.clone())
else {
return Err(Error::Precondition("Missing refresh token"));
};
refresh_token.clone()
};
let new_token_state =
ClientState::refresh_tokens(&self.client, headers, refresh_token).await?;
let new_headers = new_token_state.headers.clone();
*self.token_state.write() = new_token_state;
headers = new_headers;
headers = new_tokens.headers.clone();
*self.tokens.write() = new_tokens;
}
return self
.client
.fetch(url, headers, method, body, query_params)
.fetch(path, headers, method, body, query_params)
.await;
}
#[inline]
fn should_refresh(token_state: &TokenState) -> bool {
let now = now();
if let Some(ref state) = token_state.state {
return state.1.exp - 60 < now as i64;
fn extract_headers_and_refresh_token_if_exp(&self) -> (HeaderMap, Option<String>) {
#[inline]
fn should_refresh(jwt: &JwtTokenClaims) -> bool {
return jwt.exp - 60 < now() as i64;
}
return false;
let tokens = self.tokens.read();
let headers = tokens.headers.clone();
return match tokens.state {
Some(ref state) if should_refresh(&state.1) => (headers, state.0.refresh_token.clone()),
_ => (headers, None),
};
}
fn extract_refresh_token_and_headers(
token_state: &TokenState,
) -> Result<(String, HeaderMap), Error> {
let Some(ref state) = token_state.state else {
fn extract_headers_refresh_token(&self) -> Result<(HeaderMap, String), Error> {
let tokens = self.tokens.read();
let Some(ref state) = tokens.state else {
return Err(Error::Precondition("Not logged int?"));
};
@@ -350,7 +337,7 @@ impl ClientState {
return Err(Error::Precondition("Missing refresh token"));
};
return Ok((refresh_token.clone(), token_state.headers.clone()));
return Ok((tokens.headers.clone(), refresh_token.clone()));
}
async fn refresh_tokens(
@@ -358,14 +345,19 @@ impl ClientState {
headers: HeaderMap,
refresh_token: String,
) -> Result<TokenState, Error> {
#[derive(Serialize)]
struct RefreshRequest<'a> {
refresh_token: &'a str,
}
let response = client
.fetch(
format!("{AUTH_API}/refresh"),
&format!("/{AUTH_API}/refresh"),
headers,
Method::POST,
Some(serde_json::json!({
"refresh_token": refresh_token,
})),
Some(&RefreshRequest {
refresh_token: &refresh_token,
}),
None,
)
.await?;
@@ -391,17 +383,17 @@ pub struct Client {
}
impl Client {
pub fn new(site: &str, tokens: Option<Tokens>) -> Client {
return Client {
pub fn new(site: &str, tokens: Option<Tokens>) -> Result<Client, Error> {
return Ok(Client {
state: Arc::new(ClientState {
client: ThinClient {
client: reqwest::Client::new(),
site: site.to_string(),
url: url::Url::parse(site)?,
},
site: site.to_string(),
token_state: RwLock::new(TokenState::build(tokens.as_ref())),
tokens: RwLock::new(TokenState::build(tokens.as_ref())),
}),
};
});
}
pub fn site(&self) -> String {
@@ -409,17 +401,11 @@ impl Client {
}
pub fn tokens(&self) -> Option<Tokens> {
return self
.state
.token_state
.read()
.state
.as_ref()
.map(|x| x.0.clone());
return self.state.tokens.read().state.as_ref().map(|x| x.0.clone());
}
pub fn user(&self) -> Option<User> {
if let Some(state) = &self.state.token_state.read().state {
if let Some(state) = &self.state.tokens.read().state {
return Some(User {
sub: state.1.sub.clone(),
email: state.1.email.clone(),
@@ -436,25 +422,27 @@ impl Client {
}
pub async fn refresh(&self) -> Result<(), Error> {
let (refresh_token, headers) =
ClientState::extract_refresh_token_and_headers(&self.state.token_state.read())?;
let new_token_state =
let (headers, refresh_token) = self.state.extract_headers_refresh_token()?;
let new_tokens =
ClientState::refresh_tokens(&self.state.client, headers, refresh_token).await?;
*self.state.token_state.write() = new_token_state;
*self.state.tokens.write() = new_tokens;
return Ok(());
}
pub async fn login(&self, email: &str, password: &str) -> Result<Tokens, Error> {
#[derive(Serialize)]
struct Credentials<'a> {
email: &'a str,
password: &'a str,
}
let response = self
.state
.fetch(
format!("{AUTH_API}/login"),
&format!("/{AUTH_API}/login"),
Method::POST,
Some(serde_json::json!({
"email": email,
"password": password,
})),
Some(&Credentials { email, password }),
None,
)
.await?;
@@ -465,39 +453,45 @@ impl Client {
}
pub async fn logout(&self) -> Result<(), Error> {
let refresh_token: Option<String> = self
.state
.token_state
.read()
.state
.as_ref()
.and_then(|s| s.0.refresh_token.clone());
if let Some(refresh_token) = refresh_token {
self
.state
.fetch(
format!("{AUTH_API}/logout"),
Method::POST,
Some(serde_json::json!({"refresh_token": refresh_token})),
None,
)
.await?;
} else {
self
.state
.fetch(format!("{AUTH_API}/logout"), Method::GET, None, None)
.await?;
#[derive(Serialize)]
struct LogoutRequest {
refresh_token: String,
}
let response_or = match self.state.extract_headers_refresh_token() {
Ok((_headers, refresh_token)) => {
self
.state
.fetch(
&format!("/{AUTH_API}/logout"),
Method::POST,
Some(&LogoutRequest { refresh_token }),
None,
)
.await
}
_ => {
self
.state
.fetch(
&format!("/{AUTH_API}/logout"),
Method::GET,
None::<&()>,
None,
)
.await
}
};
self.update_tokens(None);
return Ok(());
return response_or.map(|_| ());
}
fn update_tokens(&self, tokens: Option<&Tokens>) -> TokenState {
let state = TokenState::build(tokens);
*self.state.token_state.write() = state.clone();
*self.state.tokens.write() = state.clone();
// _authChange?.call(this, state.state?.$1);
if let Some(ref s) = state.state {
@@ -558,7 +552,7 @@ mod tests {
#[tokio::test]
async fn is_send_test() {
let client = Client::new("http://127.0.0.1:4000", None);
let client = Client::new("http://127.0.0.1:4000", None).unwrap();
let api = client.records("simple_strict_table");

View File

@@ -79,7 +79,7 @@ struct SimpleStrict {
}
async fn connect() -> Client {
let client = Client::new(&format!("http://127.0.0.1:{PORT}"), None);
let client = Client::new(&format!("http://127.0.0.1:{PORT}"), None).unwrap();
let _ = client.login("admin@localhost", "secret").await.unwrap();
return client;
}