Add client apis for WebSocket subscriptions (requires a ws-enabled server build) and tests behind the "ws" feature.

This commit is contained in:
Sebastian Jeltsch
2026-01-17 20:05:48 +01:00
parent 369785a7f1
commit 2046f16c05
5 changed files with 225 additions and 7 deletions
Generated
+38
View File
@@ -459,6 +459,22 @@ dependencies = [
"syn",
]
[[package]]
name = "async-tungstenite"
version = "0.32.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8acc405d38be14342132609f06f02acaf825ddccfe76c4824a69281e0458ebd4"
dependencies = [
"atomic-waker",
"futures-core",
"futures-io",
"futures-task",
"futures-util",
"log",
"pin-project-lite",
"tungstenite",
]
[[package]]
name = "atomic-waker"
version = "1.1.2"
@@ -4646,6 +4662,26 @@ dependencies = [
"web-sys",
]
[[package]]
name = "reqwest-websocket"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7705b649c3b66b85c4e9c304a6898b1ae3eecb880c474720ebf925e4a932ae02"
dependencies = [
"async-tungstenite",
"bytes",
"futures-util",
"reqwest 0.13.1",
"serde",
"serde_json",
"thiserror 2.0.18",
"tokio",
"tokio-util",
"tracing",
"tungstenite",
"web-sys",
]
[[package]]
name = "reserve-port"
version = "2.3.0"
@@ -5779,6 +5815,7 @@ checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098"
dependencies = [
"bytes",
"futures-core",
"futures-io",
"futures-sink",
"pin-project-lite",
"tokio",
@@ -6214,6 +6251,7 @@ dependencies = [
"lazy_static",
"parking_lot",
"reqwest 0.13.1",
"reqwest-websocket",
"serde",
"serde_json",
"temp-dir",
+1
View File
@@ -15,6 +15,7 @@ name = "trail"
default = []
swagger = ["dep:utoipa-swagger-ui"]
vendor-ssl = ["dep:openssl"]
ws = ["trailbase/ws"]
[dependencies]
axum = { version = "^0.8.1", features=["multipart"] }
+5
View File
@@ -11,12 +11,17 @@ exclude = [
"tests",
]
[features]
default = []
ws = ["dep:reqwest-websocket"]
[dependencies]
eventsource-stream = { version = "0.2.3", features = [] }
futures-lite = "2.6.1"
jsonwebtoken = { version = "10.2.0", default-features = false, features = ["rust_crypto"] }
parking_lot = { workspace = true }
reqwest = { version = "0.13.1", features = ["stream"] }
reqwest-websocket = { version = "0.6.0", features = ["json"], optional = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = "2.0.12"
+108 -7
View File
@@ -42,6 +42,10 @@ pub enum Error {
// NOTE: This error is leaky but comprehensively unpacking reqwest is unsustainable.
#[error("Reqwest: {0}")]
OtherReqwest(reqwest::Error),
#[cfg(feature = "ws")]
#[error("WebSocket: {0}")]
WebSocket(#[from] reqwest_websocket::Error),
}
impl From<reqwest::Error> for Error {
@@ -189,7 +193,7 @@ struct ThinClient {
}
impl ThinClient {
async fn fetch<T: Serialize>(
async fn fetch_impl<T: Serialize>(
&self,
path: &str,
headers: HeaderMap,
@@ -220,6 +224,41 @@ impl ThinClient {
return Ok(self.client.execute(request).await?);
}
#[cfg(feature = "ws")]
async fn upgrade_ws_impl<T: Serialize>(
&self,
path: &str,
headers: HeaderMap,
method: Method,
body: Option<&T>,
query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
) -> Result<reqwest_websocket::UpgradeResponse, Error> {
use reqwest_websocket::Upgrade;
assert!(path.starts_with("/"));
let mut url = self.url.clone();
url.set_path(path);
if let Some(query_params) = query_params {
let mut params = url.query_pairs_mut();
for (key, value) in query_params {
params.append_pair(key, value);
}
}
let request = {
let mut builder = self.client.request(method, url).headers(headers);
if let Some(ref body) = body {
let json = serde_json::to_string(body).map_err(Error::RecordSerialization)?;
builder = builder.body(json);
}
builder.upgrade()
};
return Ok(request.send().await?);
}
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
@@ -556,15 +595,54 @@ impl RecordApi {
.bytes_stream()
.eventsource()
.filter_map(|event_or| {
if let Ok(event) = event_or
&& let Ok(db_event) = serde_json::from_str::<DbEvent>(&event.data)
{
return Some(db_event);
// QUESTION: Should we instead return a `Stream<Item = Result<DbEvent, _>>` to allow for
// better error handling here.
if let Ok(event) = event_or {
return serde_json::from_str::<DbEvent>(&event.data).ok();
}
return None;
}),
);
}
#[cfg(feature = "ws")]
pub async fn subscribe_ws<'a, T: RecordId<'a>>(
&self,
id: T,
) -> Result<impl Stream<Item = DbEvent> + use<T>, Error> {
let response = self
.client
.upgrade_ws(
&format!(
"/{RECORD_API}/{name}/subscribe/{id}",
name = self.name,
id = id.serialized_id()
),
Method::GET,
None::<&()>,
Some(&[("ws".into(), "true".into())]),
)
.await?;
let websocket = response.into_websocket().await?;
return Ok(websocket.filter_map(|message| {
use reqwest_websocket::Message;
return match message {
Ok(Message::Text(msg)) => serde_json::from_str::<DbEvent>(&msg)
.map_err(|err| {
warn!("json error: {err}");
return err;
})
.ok(),
msg => {
warn!("unexpected msg: {msg:?}");
None
}
};
}));
}
}
#[derive(Clone, Debug)]
@@ -617,12 +695,35 @@ impl ClientState {
return Ok(
self
.client
.fetch(path, headers, method, body, query_params)
.fetch_impl(path, headers, method, body, query_params)
.await?
.error_for_status()?,
);
}
#[cfg(feature = "ws")]
#[inline]
async fn upgrade_ws<T: Serialize>(
&self,
path: &str,
method: Method,
body: Option<&T>,
query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
) -> Result<reqwest_websocket::UpgradeResponse, Error> {
let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
if let Some(refresh_token) = refresh_token {
let new_tokens = ClientState::refresh_tokens(&self.client, headers, refresh_token).await?;
headers = new_tokens.headers.clone();
*self.tokens.write() = new_tokens;
}
return self
.client
.upgrade_ws_impl(path, headers, method, body, query_params)
.await;
}
#[inline]
fn extract_headers_and_refresh_token_if_exp(&self) -> (HeaderMap, Option<String>) {
#[inline]
@@ -659,7 +760,7 @@ impl ClientState {
}
let response = client
.fetch(
.fetch_impl(
&format!("/{AUTH_API}/refresh"),
headers,
Method::POST,
+73
View File
@@ -49,6 +49,10 @@ fn start_server() -> Result<Option<Server>, std::io::Error> {
let args = [
"run".to_string(),
#[cfg(feature = "ws")]
{
"--features=ws".to_string()
},
"--".to_string(),
format!("--data-dir={depot_path}"),
"run".to_string(),
@@ -471,6 +475,69 @@ async fn subscription_test() {
}
}
#[cfg(feature = "ws")]
async fn subscription_ws_test() {
let client = connect().await;
let api = client.records("simple_strict_table");
let table_stream = api.subscribe_ws("*").await.unwrap();
let now = now();
let create_message = format!("rust client realtime test 0: =?&{now}");
let id = api
.create(json!({"text_not_null": create_message}))
.await
.unwrap();
let record_stream = api.subscribe_ws(&id).await.unwrap();
let updated_message = format!("rust client updated realtime test 0: =?&{now}");
api
.update(&id, json!({"text_not_null": updated_message}))
.await
.unwrap();
api.delete(&id).await.unwrap();
{
let record_events = record_stream.take(2).collect::<Vec<_>>().await;
match &record_events[0] {
DbEvent::Update(Some(serde_json::Value::Object(obj))) => {
assert_eq!(obj["text_not_null"], updated_message);
}
msg => panic!("Unexpected event: {msg:?}"),
};
match &record_events[1] {
DbEvent::Delete(Some(serde_json::Value::Object(obj))) => {
assert_eq!(obj["text_not_null"], updated_message);
}
msg => panic!("Unexpected event: {msg:?}"),
};
}
{
let table_events = table_stream.take(3).collect::<Vec<_>>().await;
match &table_events[0] {
DbEvent::Insert(Some(serde_json::Value::Object(obj))) => {
assert_eq!(obj["text_not_null"], create_message);
}
msg => panic!("Unexpected event: {msg:?}"),
};
match &table_events[1] {
DbEvent::Update(Some(serde_json::Value::Object(obj))) => {
assert_eq!(obj["text_not_null"], updated_message);
}
msg => panic!("Unexpected event: {msg:?}"),
};
match &table_events[2] {
DbEvent::Delete(Some(serde_json::Value::Object(obj))) => {
assert_eq!(obj["text_not_null"], updated_message);
}
msg => panic!("Unexpected event: {msg:?}"),
};
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
struct FileUpload {
/// The file's UUID, should be stripped.
@@ -690,6 +757,12 @@ fn integration_test() {
runtime.block_on(subscription_test());
println!("Ran subscription tests");
#[cfg(feature = "ws")]
{
runtime.block_on(subscription_ws_test());
println!("Ran subscription websocket tests");
}
runtime.block_on(file_upload_json_base64_test());
println!("Ran file upload JSON base64 tests");