mirror of
https://github.com/trailbaseio/trailbase.git
synced 2026-05-23 02:28:34 -05:00
Add client apis for WebSocket subscriptions (requires a ws-enabled server build) and tests behind the "ws" feature.
This commit is contained in:
Generated
+38
@@ -459,6 +459,22 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-tungstenite"
|
||||
version = "0.32.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8acc405d38be14342132609f06f02acaf825ddccfe76c4824a69281e0458ebd4"
|
||||
dependencies = [
|
||||
"atomic-waker",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
"log",
|
||||
"pin-project-lite",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atomic-waker"
|
||||
version = "1.1.2"
|
||||
@@ -4646,6 +4662,26 @@ dependencies = [
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "reqwest-websocket"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7705b649c3b66b85c4e9c304a6898b1ae3eecb880c474720ebf925e4a932ae02"
|
||||
dependencies = [
|
||||
"async-tungstenite",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"reqwest 0.13.1",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"tungstenite",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "reserve-port"
|
||||
version = "2.3.0"
|
||||
@@ -5779,6 +5815,7 @@ checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-sink",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
@@ -6214,6 +6251,7 @@ dependencies = [
|
||||
"lazy_static",
|
||||
"parking_lot",
|
||||
"reqwest 0.13.1",
|
||||
"reqwest-websocket",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"temp-dir",
|
||||
|
||||
@@ -15,6 +15,7 @@ name = "trail"
|
||||
default = []
|
||||
swagger = ["dep:utoipa-swagger-ui"]
|
||||
vendor-ssl = ["dep:openssl"]
|
||||
ws = ["trailbase/ws"]
|
||||
|
||||
[dependencies]
|
||||
axum = { version = "^0.8.1", features=["multipart"] }
|
||||
|
||||
@@ -11,12 +11,17 @@ exclude = [
|
||||
"tests",
|
||||
]
|
||||
|
||||
[features]
|
||||
default = []
|
||||
ws = ["dep:reqwest-websocket"]
|
||||
|
||||
[dependencies]
|
||||
eventsource-stream = { version = "0.2.3", features = [] }
|
||||
futures-lite = "2.6.1"
|
||||
jsonwebtoken = { version = "10.2.0", default-features = false, features = ["rust_crypto"] }
|
||||
parking_lot = { workspace = true }
|
||||
reqwest = { version = "0.13.1", features = ["stream"] }
|
||||
reqwest-websocket = { version = "0.6.0", features = ["json"], optional = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = "2.0.12"
|
||||
|
||||
+108
-7
@@ -42,6 +42,10 @@ pub enum Error {
|
||||
// NOTE: This error is leaky but comprehensively unpacking reqwest is unsustainable.
|
||||
#[error("Reqwest: {0}")]
|
||||
OtherReqwest(reqwest::Error),
|
||||
|
||||
#[cfg(feature = "ws")]
|
||||
#[error("WebSocket: {0}")]
|
||||
WebSocket(#[from] reqwest_websocket::Error),
|
||||
}
|
||||
|
||||
impl From<reqwest::Error> for Error {
|
||||
@@ -189,7 +193,7 @@ struct ThinClient {
|
||||
}
|
||||
|
||||
impl ThinClient {
|
||||
async fn fetch<T: Serialize>(
|
||||
async fn fetch_impl<T: Serialize>(
|
||||
&self,
|
||||
path: &str,
|
||||
headers: HeaderMap,
|
||||
@@ -220,6 +224,41 @@ impl ThinClient {
|
||||
|
||||
return Ok(self.client.execute(request).await?);
|
||||
}
|
||||
|
||||
#[cfg(feature = "ws")]
|
||||
async fn upgrade_ws_impl<T: Serialize>(
|
||||
&self,
|
||||
path: &str,
|
||||
headers: HeaderMap,
|
||||
method: Method,
|
||||
body: Option<&T>,
|
||||
query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
|
||||
) -> Result<reqwest_websocket::UpgradeResponse, Error> {
|
||||
use reqwest_websocket::Upgrade;
|
||||
|
||||
assert!(path.starts_with("/"));
|
||||
|
||||
let mut url = self.url.clone();
|
||||
url.set_path(path);
|
||||
|
||||
if let Some(query_params) = query_params {
|
||||
let mut params = url.query_pairs_mut();
|
||||
for (key, value) in query_params {
|
||||
params.append_pair(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
let request = {
|
||||
let mut builder = self.client.request(method, url).headers(headers);
|
||||
if let Some(ref body) = body {
|
||||
let json = serde_json::to_string(body).map_err(Error::RecordSerialization)?;
|
||||
builder = builder.body(json);
|
||||
}
|
||||
builder.upgrade()
|
||||
};
|
||||
|
||||
return Ok(request.send().await?);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
|
||||
@@ -556,15 +595,54 @@ impl RecordApi {
|
||||
.bytes_stream()
|
||||
.eventsource()
|
||||
.filter_map(|event_or| {
|
||||
if let Ok(event) = event_or
|
||||
&& let Ok(db_event) = serde_json::from_str::<DbEvent>(&event.data)
|
||||
{
|
||||
return Some(db_event);
|
||||
// QUESTION: Should we instead return a `Stream<Item = Result<DbEvent, _>>` to allow for
|
||||
// better error handling here.
|
||||
if let Ok(event) = event_or {
|
||||
return serde_json::from_str::<DbEvent>(&event.data).ok();
|
||||
}
|
||||
return None;
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "ws")]
|
||||
pub async fn subscribe_ws<'a, T: RecordId<'a>>(
|
||||
&self,
|
||||
id: T,
|
||||
) -> Result<impl Stream<Item = DbEvent> + use<T>, Error> {
|
||||
let response = self
|
||||
.client
|
||||
.upgrade_ws(
|
||||
&format!(
|
||||
"/{RECORD_API}/{name}/subscribe/{id}",
|
||||
name = self.name,
|
||||
id = id.serialized_id()
|
||||
),
|
||||
Method::GET,
|
||||
None::<&()>,
|
||||
Some(&[("ws".into(), "true".into())]),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let websocket = response.into_websocket().await?;
|
||||
|
||||
return Ok(websocket.filter_map(|message| {
|
||||
use reqwest_websocket::Message;
|
||||
|
||||
return match message {
|
||||
Ok(Message::Text(msg)) => serde_json::from_str::<DbEvent>(&msg)
|
||||
.map_err(|err| {
|
||||
warn!("json error: {err}");
|
||||
return err;
|
||||
})
|
||||
.ok(),
|
||||
msg => {
|
||||
warn!("unexpected msg: {msg:?}");
|
||||
None
|
||||
}
|
||||
};
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -617,12 +695,35 @@ impl ClientState {
|
||||
return Ok(
|
||||
self
|
||||
.client
|
||||
.fetch(path, headers, method, body, query_params)
|
||||
.fetch_impl(path, headers, method, body, query_params)
|
||||
.await?
|
||||
.error_for_status()?,
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "ws")]
|
||||
#[inline]
|
||||
async fn upgrade_ws<T: Serialize>(
|
||||
&self,
|
||||
path: &str,
|
||||
method: Method,
|
||||
body: Option<&T>,
|
||||
query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
|
||||
) -> Result<reqwest_websocket::UpgradeResponse, Error> {
|
||||
let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
|
||||
if let Some(refresh_token) = refresh_token {
|
||||
let new_tokens = ClientState::refresh_tokens(&self.client, headers, refresh_token).await?;
|
||||
|
||||
headers = new_tokens.headers.clone();
|
||||
*self.tokens.write() = new_tokens;
|
||||
}
|
||||
|
||||
return self
|
||||
.client
|
||||
.upgrade_ws_impl(path, headers, method, body, query_params)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn extract_headers_and_refresh_token_if_exp(&self) -> (HeaderMap, Option<String>) {
|
||||
#[inline]
|
||||
@@ -659,7 +760,7 @@ impl ClientState {
|
||||
}
|
||||
|
||||
let response = client
|
||||
.fetch(
|
||||
.fetch_impl(
|
||||
&format!("/{AUTH_API}/refresh"),
|
||||
headers,
|
||||
Method::POST,
|
||||
|
||||
@@ -49,6 +49,10 @@ fn start_server() -> Result<Option<Server>, std::io::Error> {
|
||||
|
||||
let args = [
|
||||
"run".to_string(),
|
||||
#[cfg(feature = "ws")]
|
||||
{
|
||||
"--features=ws".to_string()
|
||||
},
|
||||
"--".to_string(),
|
||||
format!("--data-dir={depot_path}"),
|
||||
"run".to_string(),
|
||||
@@ -471,6 +475,69 @@ async fn subscription_test() {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "ws")]
|
||||
async fn subscription_ws_test() {
|
||||
let client = connect().await;
|
||||
let api = client.records("simple_strict_table");
|
||||
|
||||
let table_stream = api.subscribe_ws("*").await.unwrap();
|
||||
|
||||
let now = now();
|
||||
let create_message = format!("rust client realtime test 0: =?&{now}");
|
||||
let id = api
|
||||
.create(json!({"text_not_null": create_message}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let record_stream = api.subscribe_ws(&id).await.unwrap();
|
||||
|
||||
let updated_message = format!("rust client updated realtime test 0: =?&{now}");
|
||||
api
|
||||
.update(&id, json!({"text_not_null": updated_message}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
api.delete(&id).await.unwrap();
|
||||
|
||||
{
|
||||
let record_events = record_stream.take(2).collect::<Vec<_>>().await;
|
||||
match &record_events[0] {
|
||||
DbEvent::Update(Some(serde_json::Value::Object(obj))) => {
|
||||
assert_eq!(obj["text_not_null"], updated_message);
|
||||
}
|
||||
msg => panic!("Unexpected event: {msg:?}"),
|
||||
};
|
||||
match &record_events[1] {
|
||||
DbEvent::Delete(Some(serde_json::Value::Object(obj))) => {
|
||||
assert_eq!(obj["text_not_null"], updated_message);
|
||||
}
|
||||
msg => panic!("Unexpected event: {msg:?}"),
|
||||
};
|
||||
}
|
||||
|
||||
{
|
||||
let table_events = table_stream.take(3).collect::<Vec<_>>().await;
|
||||
match &table_events[0] {
|
||||
DbEvent::Insert(Some(serde_json::Value::Object(obj))) => {
|
||||
assert_eq!(obj["text_not_null"], create_message);
|
||||
}
|
||||
msg => panic!("Unexpected event: {msg:?}"),
|
||||
};
|
||||
match &table_events[1] {
|
||||
DbEvent::Update(Some(serde_json::Value::Object(obj))) => {
|
||||
assert_eq!(obj["text_not_null"], updated_message);
|
||||
}
|
||||
msg => panic!("Unexpected event: {msg:?}"),
|
||||
};
|
||||
match &table_events[2] {
|
||||
DbEvent::Delete(Some(serde_json::Value::Object(obj))) => {
|
||||
assert_eq!(obj["text_not_null"], updated_message);
|
||||
}
|
||||
msg => panic!("Unexpected event: {msg:?}"),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
struct FileUpload {
|
||||
/// The file's UUID, should be stripped.
|
||||
@@ -690,6 +757,12 @@ fn integration_test() {
|
||||
runtime.block_on(subscription_test());
|
||||
println!("Ran subscription tests");
|
||||
|
||||
#[cfg(feature = "ws")]
|
||||
{
|
||||
runtime.block_on(subscription_ws_test());
|
||||
println!("Ran subscription websocket tests");
|
||||
}
|
||||
|
||||
runtime.block_on(file_upload_json_base64_test());
|
||||
println!("Ran file upload JSON base64 tests");
|
||||
|
||||
|
||||
Reference in New Issue
Block a user