diff --git a/Cargo.lock b/Cargo.lock index 0b8d5962..46c682fc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -459,6 +459,22 @@ dependencies = [ "syn", ] +[[package]] +name = "async-tungstenite" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acc405d38be14342132609f06f02acaf825ddccfe76c4824a69281e0458ebd4" +dependencies = [ + "atomic-waker", + "futures-core", + "futures-io", + "futures-task", + "futures-util", + "log", + "pin-project-lite", + "tungstenite", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -4646,6 +4662,26 @@ dependencies = [ "web-sys", ] +[[package]] +name = "reqwest-websocket" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7705b649c3b66b85c4e9c304a6898b1ae3eecb880c474720ebf925e4a932ae02" +dependencies = [ + "async-tungstenite", + "bytes", + "futures-util", + "reqwest 0.13.1", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + "tokio-util", + "tracing", + "tungstenite", + "web-sys", +] + [[package]] name = "reserve-port" version = "2.3.0" @@ -5779,6 +5815,7 @@ checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" dependencies = [ "bytes", "futures-core", + "futures-io", "futures-sink", "pin-project-lite", "tokio", @@ -6214,6 +6251,7 @@ dependencies = [ "lazy_static", "parking_lot", "reqwest 0.13.1", + "reqwest-websocket", "serde", "serde_json", "temp-dir", diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index 2e406c14..3c6f49dc 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -15,6 +15,7 @@ name = "trail" default = [] swagger = ["dep:utoipa-swagger-ui"] vendor-ssl = ["dep:openssl"] +ws = ["trailbase/ws"] [dependencies] axum = { version = "^0.8.1", features=["multipart"] } diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index da456224..2eca31b9 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -11,12 +11,17 @@ exclude = [ "tests", ] +[features] +default = [] +ws = ["dep:reqwest-websocket"] + [dependencies] eventsource-stream = { version = "0.2.3", features = [] } futures-lite = "2.6.1" jsonwebtoken = { version = "10.2.0", default-features = false, features = ["rust_crypto"] } parking_lot = { workspace = true } reqwest = { version = "0.13.1", features = ["stream"] } +reqwest-websocket = { version = "0.6.0", features = ["json"], optional = true } serde = { workspace = true } serde_json = { workspace = true } thiserror = "2.0.12" diff --git a/crates/client/src/lib.rs b/crates/client/src/lib.rs index cb713fef..48c36903 100644 --- a/crates/client/src/lib.rs +++ b/crates/client/src/lib.rs @@ -42,6 +42,10 @@ pub enum Error { // NOTE: This error is leaky but comprehensively unpacking reqwest is unsustainable. #[error("Reqwest: {0}")] OtherReqwest(reqwest::Error), + + #[cfg(feature = "ws")] + #[error("WebSocket: {0}")] + WebSocket(#[from] reqwest_websocket::Error), } impl From for Error { @@ -189,7 +193,7 @@ struct ThinClient { } impl ThinClient { - async fn fetch( + async fn fetch_impl( &self, path: &str, headers: HeaderMap, @@ -220,6 +224,41 @@ impl ThinClient { return Ok(self.client.execute(request).await?); } + + #[cfg(feature = "ws")] + async fn upgrade_ws_impl( + &self, + path: &str, + headers: HeaderMap, + method: Method, + body: Option<&T>, + query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>, + ) -> Result { + use reqwest_websocket::Upgrade; + + assert!(path.starts_with("/")); + + let mut url = self.url.clone(); + url.set_path(path); + + if let Some(query_params) = query_params { + let mut params = url.query_pairs_mut(); + for (key, value) in query_params { + params.append_pair(key, value); + } + } + + let request = { + let mut builder = self.client.request(method, url).headers(headers); + if let Some(ref body) = body { + let json = serde_json::to_string(body).map_err(Error::RecordSerialization)?; + builder = builder.body(json); + } + builder.upgrade() + }; + + return Ok(request.send().await?); + } } #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] @@ -556,15 +595,54 @@ impl RecordApi { .bytes_stream() .eventsource() .filter_map(|event_or| { - if let Ok(event) = event_or - && let Ok(db_event) = serde_json::from_str::(&event.data) - { - return Some(db_event); + // QUESTION: Should we instead return a `Stream>` to allow for + // better error handling here. + if let Ok(event) = event_or { + return serde_json::from_str::(&event.data).ok(); } return None; }), ); } + + #[cfg(feature = "ws")] + pub async fn subscribe_ws<'a, T: RecordId<'a>>( + &self, + id: T, + ) -> Result + use, Error> { + let response = self + .client + .upgrade_ws( + &format!( + "/{RECORD_API}/{name}/subscribe/{id}", + name = self.name, + id = id.serialized_id() + ), + Method::GET, + None::<&()>, + Some(&[("ws".into(), "true".into())]), + ) + .await?; + + let websocket = response.into_websocket().await?; + + return Ok(websocket.filter_map(|message| { + use reqwest_websocket::Message; + + return match message { + Ok(Message::Text(msg)) => serde_json::from_str::(&msg) + .map_err(|err| { + warn!("json error: {err}"); + return err; + }) + .ok(), + msg => { + warn!("unexpected msg: {msg:?}"); + None + } + }; + })); + } } #[derive(Clone, Debug)] @@ -617,12 +695,35 @@ impl ClientState { return Ok( self .client - .fetch(path, headers, method, body, query_params) + .fetch_impl(path, headers, method, body, query_params) .await? .error_for_status()?, ); } + #[cfg(feature = "ws")] + #[inline] + async fn upgrade_ws( + &self, + path: &str, + method: Method, + body: Option<&T>, + query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>, + ) -> Result { + let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp(); + if let Some(refresh_token) = refresh_token { + let new_tokens = ClientState::refresh_tokens(&self.client, headers, refresh_token).await?; + + headers = new_tokens.headers.clone(); + *self.tokens.write() = new_tokens; + } + + return self + .client + .upgrade_ws_impl(path, headers, method, body, query_params) + .await; + } + #[inline] fn extract_headers_and_refresh_token_if_exp(&self) -> (HeaderMap, Option) { #[inline] @@ -659,7 +760,7 @@ impl ClientState { } let response = client - .fetch( + .fetch_impl( &format!("/{AUTH_API}/refresh"), headers, Method::POST, diff --git a/crates/client/tests/integration_test.rs b/crates/client/tests/integration_test.rs index c2817355..b68f10ea 100644 --- a/crates/client/tests/integration_test.rs +++ b/crates/client/tests/integration_test.rs @@ -49,6 +49,10 @@ fn start_server() -> Result, std::io::Error> { let args = [ "run".to_string(), + #[cfg(feature = "ws")] + { + "--features=ws".to_string() + }, "--".to_string(), format!("--data-dir={depot_path}"), "run".to_string(), @@ -471,6 +475,69 @@ async fn subscription_test() { } } +#[cfg(feature = "ws")] +async fn subscription_ws_test() { + let client = connect().await; + let api = client.records("simple_strict_table"); + + let table_stream = api.subscribe_ws("*").await.unwrap(); + + let now = now(); + let create_message = format!("rust client realtime test 0: =?&{now}"); + let id = api + .create(json!({"text_not_null": create_message})) + .await + .unwrap(); + + let record_stream = api.subscribe_ws(&id).await.unwrap(); + + let updated_message = format!("rust client updated realtime test 0: =?&{now}"); + api + .update(&id, json!({"text_not_null": updated_message})) + .await + .unwrap(); + + api.delete(&id).await.unwrap(); + + { + let record_events = record_stream.take(2).collect::>().await; + match &record_events[0] { + DbEvent::Update(Some(serde_json::Value::Object(obj))) => { + assert_eq!(obj["text_not_null"], updated_message); + } + msg => panic!("Unexpected event: {msg:?}"), + }; + match &record_events[1] { + DbEvent::Delete(Some(serde_json::Value::Object(obj))) => { + assert_eq!(obj["text_not_null"], updated_message); + } + msg => panic!("Unexpected event: {msg:?}"), + }; + } + + { + let table_events = table_stream.take(3).collect::>().await; + match &table_events[0] { + DbEvent::Insert(Some(serde_json::Value::Object(obj))) => { + assert_eq!(obj["text_not_null"], create_message); + } + msg => panic!("Unexpected event: {msg:?}"), + }; + match &table_events[1] { + DbEvent::Update(Some(serde_json::Value::Object(obj))) => { + assert_eq!(obj["text_not_null"], updated_message); + } + msg => panic!("Unexpected event: {msg:?}"), + }; + match &table_events[2] { + DbEvent::Delete(Some(serde_json::Value::Object(obj))) => { + assert_eq!(obj["text_not_null"], updated_message); + } + msg => panic!("Unexpected event: {msg:?}"), + }; + } +} + #[derive(Debug, Clone, Deserialize, Serialize)] struct FileUpload { /// The file's UUID, should be stripped. @@ -690,6 +757,12 @@ fn integration_test() { runtime.block_on(subscription_test()); println!("Ran subscription tests"); + #[cfg(feature = "ws")] + { + runtime.block_on(subscription_ws_test()); + println!("Ran subscription websocket tests"); + } + runtime.block_on(file_upload_json_base64_test()); println!("Ran file upload JSON base64 tests");