From 0a5625bcf638af37f20141d8d28fdd6675aca6ed Mon Sep 17 00:00:00 2001 From: Sebastian Jeltsch Date: Sun, 18 Jan 2026 11:03:12 +0100 Subject: [PATCH] Move WS subscriptions behind the same endpoint as SSE and improve error handling. --- Cargo.lock | 1 + Cargo.toml | 1 + crates/core/Cargo.toml | 3 +- crates/core/src/records/mod.rs | 6 - crates/core/src/records/subscribe.rs | 237 +++++++++++++++++---------- crates/qs/Cargo.toml | 4 +- 6 files changed, 158 insertions(+), 94 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7eec7086..970c7147 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6104,6 +6104,7 @@ dependencies = [ "serde", "serde_json", "serde_path_to_error", + "serde_qs", "serde_urlencoded", "sha2", "sqlformat", diff --git a/Cargo.toml b/Cargo.toml index 123e7e9a..8be35578 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -85,6 +85,7 @@ rusqlite = { version = "0.38.0", default-features = false, features = ["bundled" rust-embed = { version = "8.4.0", default-features = false, features = ["mime-guess"] } serde = { version = "^1.0.203", features = ["derive"] } serde_json = { version = "^1.0.117" } +serde_qs = { version = "0.15.0", default-features = false } serde_rusqlite = { path = "vendor/serde_rusqlite" } sqlite-vec = { path = "vendor/sqlite-vec/bindings/rust", default-features = false } tokio = { version = "^1.38.0", default-features = false, features = ["macros", "net", "rt-multi-thread", "fs", "signal", "time", "sync"] } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index bf7d53a4..f6f98d2c 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -78,6 +78,7 @@ rust-embed = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } serde_path_to_error = "0.1.16" +serde_qs = { workspace = true } serde_urlencoded = "0.7.1" sha2 = "0.10.8" sqlformat = "0.5.0" @@ -87,9 +88,9 @@ tokio = { workspace = true } tokio-rustls = { version = "0.26.1", default-features = false } tower = "0.5.0" tower-cookies = "0.11.0" -tower_governor = { version = "0.8.0", default-features = false, features = ["axum"] } tower-http = { version = "^0.6.0", default-features = false, features = ["cors", "trace", "fs", "limit"] } tower-service = { version = "0.3.3", default-features = false } +tower_governor = { version = "0.8.0", default-features = false, features = ["axum"] } tracing = { workspace = true } tracing-opentelemetry-instrumentation-sdk = "0.32.0" tracing-subscriber = { workspace = true } diff --git a/crates/core/src/records/mod.rs b/crates/core/src/records/mod.rs index 9b272d78..098f018a 100644 --- a/crates/core/src/records/mod.rs +++ b/crates/core/src/records/mod.rs @@ -86,12 +86,6 @@ pub(crate) fn router(enable_transactions: bool) -> Router { get(subscribe::add_subscription_sse_handler), ); - #[cfg(feature = "ws")] - let router = router.route( - &format!("/{RECORD_API_PATH}/{{name}}/subscribe_ws/{{record}}"), - get(subscribe::add_subscription_ws_handler), - ); - if enable_transactions { return router.route( &format!("/{TRANSACTION_API_PATH}/execute"), diff --git a/crates/core/src/records/subscribe.rs b/crates/core/src/records/subscribe.rs index 2172793b..81ea81ea 100644 --- a/crates/core/src/records/subscribe.rs +++ b/crates/core/src/records/subscribe.rs @@ -1,5 +1,5 @@ use async_channel::{TrySendError, WeakReceiver}; -use axum::extract::{Path, RawQuery, State}; +use axum::extract::{Path, RawQuery, Request, State}; use axum::response::sse::{Event as SseEvent, KeepAlive, Sse}; use axum::response::{IntoResponse, Response}; use futures_util::{Stream, StreamExt}; @@ -14,7 +14,7 @@ use std::pin::Pin; use std::sync::atomic::{AtomicI64, Ordering}; use std::sync::{Arc, LazyLock}; use std::task::{Context, Poll}; -use trailbase_qs::FilterQuery; +use trailbase_qs::ValueOrComposite; use trailbase_schema::QualifiedName; use trailbase_schema::json::value_to_flat_json; use trailbase_sqlite::connection::{extract_record_values, extract_row_id}; @@ -151,7 +151,7 @@ impl DbEvent { #[inline] fn into_event(self) -> Result { return Ok(Event::Json( - serde_json::to_string(&self).map_err(|err| axum::Error::new(err))?, + serde_json::to_string(&self).map_err(axum::Error::new)?, )); } } @@ -385,7 +385,7 @@ impl PerConnectionState { self: Arc, api: RecordApi, user: Option, - filter: Option, + filter: Option, ) -> Result<(async_channel::Sender, AutoCleanupEventStream), RecordError> { let table_name = api.qualified_name().clone(); @@ -529,7 +529,7 @@ impl SubscriptionManager { &self, api: RecordApi, user: Option, - filter: Option, + filter: Option, ) -> Result { let (sender, receiver) = self .get_per_connection_state(&api) @@ -810,13 +810,35 @@ fn hook_continuation(conn: &rusqlite::Connection, s: ContinuationState) { } } +#[derive(Clone, Default, Debug, PartialEq, Deserialize)] +struct SubscriptionQuery { + /// Map from filter params to filter value. It's a vector in cases like: + /// `col0[$gte]=2&col0[$lte]=10`. + filter: Option, + + /// Whether to use WebSocket instead of default SSE. + ws: Option, +} + +impl SubscriptionQuery { + fn parse(query: &str) -> Result { + // NOTE: We rely on non-strict mode to parse `filter[col0]=a&b%filter[col1]=c`. + let qs = serde_qs::Config::new(9, false); + return qs + .deserialize_bytes::(query.as_bytes()) + .map_err(|_err| RecordError::BadRequest("Invalid query")); + } +} + /// Read record. #[utoipa::path( get, path = "/{name}/subscribe/{record}", tag = "records", + // TODO: Document the params. Requires utoipa support in trailbase_qs or external impl. + // params(SubscritpionParams), responses( - (status = 200, description = "SSE stream of record changes.") + (status = 200, description = "Starts streaming changes to matching records via SSE/WebSocket") ) )] pub async fn add_subscription_sse_handler( @@ -824,6 +846,7 @@ pub async fn add_subscription_sse_handler( Path((api_name, record)): Path<(String, String)>, user: Option, RawQuery(raw_url_query): RawQuery, + request: Request, ) -> Result { let Some(api) = state.lookup_record_api(&api_name) else { return Err(RecordError::ApiNotFound); @@ -833,16 +856,20 @@ pub async fn add_subscription_sse_handler( return Err(RecordError::Forbidden); } - let FilterQuery { filter } = raw_url_query + let SubscriptionQuery { filter, ws } = raw_url_query .as_ref() .map_or_else( - || Ok(FilterQuery::default()), - |query| FilterQuery::parse(query), + || Ok(SubscriptionQuery::default()), + |query| SubscriptionQuery::parse(query), ) .map_err(|_err| { return RecordError::BadRequest("Invalid query"); })?; + if ws.unwrap_or(false) { + return subscribe_ws(state, api, record, filter, user, request).await; + } + return match record.as_str() { "*" => { api.check_table_level_access(Permission::Read, user.as_ref())?; @@ -878,100 +905,134 @@ pub async fn add_subscription_sse_handler( }; } -#[cfg(feature = "ws")] -pub async fn add_subscription_ws_handler( - State(state): State, - ws: axum::extract::ws::WebSocketUpgrade, - Path((api_name, record)): Path<(String, String)>, - RawQuery(raw_url_query): RawQuery, +#[allow(unused)] +pub async fn subscribe_ws( + state: AppState, + api: RecordApi, + record: String, + filter: Option, user: Option, + request: Request, ) -> Result { - use axum::extract::ws::WebSocket; - use futures_util::SinkExt; - - let Some(api) = state.lookup_record_api(&api_name) else { - return Err(RecordError::ApiNotFound); - }; - if !api.enable_subscriptions() { return Err(RecordError::Forbidden); } - let FilterQuery { filter } = raw_url_query - .as_ref() - .map_or_else( - || Ok(FilterQuery::default()), - |query| FilterQuery::parse(query), - ) - .map_err(|_err| { - return RecordError::BadRequest("Invalid query"); - })?; + #[cfg(not(feature = "ws"))] + return Err(RecordError::BadRequest("ws unsupported")); - async fn on_upgrade(socket: WebSocket, receiver: AutoCleanupEventStream) { - let (mut sender, _) = socket.split(); + #[cfg(feature = "ws")] + { + use axum::extract::FromRequestParts; + use axum::extract::ws::{CloseFrame, Message, WebSocket, WebSocketUpgrade}; + use futures_util::SinkExt; - let mut pinned_receiver = std::pin::pin!(receiver); - while let Some(ev) = pinned_receiver.next().await { - match ev.into_ws_event() { - Ok(msg) => { - if let Err(axum_err) = sender.send(msg).await { - debug!("{axum_err}"); - let _ = sender.close(); - return; + let (mut parts, body) = request.into_parts(); + let ws = match WebSocketUpgrade::from_request_parts(&mut parts, &state).await { + Ok(ws) => ws, + Err(err) => { + return Ok(err.into_response()); + } + }; + + async fn abort + std::marker::Unpin>(sender: &mut S, reason: &str) { + // https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1 + // + // 1011 indicates that a server is terminating the connection because + // it encountered an unexpected condition that prevented it from + // fulfilling the request. + let _ = sender + .send(Message::Close(Some(CloseFrame { + code: 1011, + reason: reason.into(), + }))) + .await; + + sender.close(); + } + + // WebSocket uses the HTTP `UPGRADE` mechanism to switch over to dedicated, non-HTTP `ws://` + // protocol. + async fn on_upgrade(mut socket: WebSocket, receiver: AutoCleanupEventStream) { + // NOTE: We're dropping the receiver end of the bidirectional WebSocket. We're not expecting + // any messages from the client. + let (mut sender, _) = socket.split(); + + let mut pinned_receiver = std::pin::pin!(receiver); + while let Some(ev) = pinned_receiver.next().await { + match ev.into_ws_event() { + Ok(msg) => { + if let Err(err) = sender.send(msg).await { + debug!("Sending WS event to client failed: {err}"); + abort(&mut sender, "Failed to send event. Closing channel"); + return; + } } - } - Err(err) => { - debug_assert!(false, "into_ws_event failed: {err}"); - } - }; + Err(err) => { + debug_assert!(false, "into_ws_event failed: {err}"); + } + }; + } } + + return match record.as_str() { + "*" => { + api.check_table_level_access(Permission::Read, user.as_ref())?; + + Ok(ws.on_upgrade(async move |mut socket| { + let Ok((_sender, receiver)) = state + .subscription_manager() + .get_per_connection_state(&api) + .add_table_subscription(api, user, filter) + .await + else { + abort(&mut socket, "subscription failed"); + return; + }; + + on_upgrade(socket, receiver).await + })) + } + _ => { + let record_id = api.primary_key_to_value(record)?; + api + .check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref()) + .await?; + + Ok(ws.on_upgrade(async move |mut socket| { + let Ok((_sender, receiver)) = state + .subscription_manager() + .get_per_connection_state(&api) + .add_record_subscription(api, record_id, user) + .await + else { + abort(&mut socket, "subscription failed"); + return; + }; + + on_upgrade(socket, receiver).await; + })) + } + }; } - - return match record.as_str() { - "*" => { - api.check_table_level_access(Permission::Read, user.as_ref())?; - - let (_sender, receiver) = state - .subscription_manager() - .get_per_connection_state(&api) - .add_table_subscription(api, user, filter) - .await?; - - Ok(ws.on_upgrade(async |socket| on_upgrade(socket, receiver).await)) - } - _ => { - let record_id = api.primary_key_to_value(record)?; - api - .check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref()) - .await?; - - let (_sender, receiver) = state - .subscription_manager() - .get_per_connection_state(&api) - .add_record_subscription(api, record_id, user) - .await?; - - Ok(ws.on_upgrade(async |socket| on_upgrade(socket, receiver).await)) - } - }; } static SUBSCRIPTION_COUNTER: AtomicI64 = AtomicI64::new(0); + +static ESTABLISHED_EVENT: LazyLock = + LazyLock::new(|| Event::Sse(SseEvent::default().comment("subscription established"))); +#[cfg(not(feature = "ws"))] static ACCESS_DENIED_EVENT: LazyLock = LazyLock::new(|| { - #[cfg(not(feature = "ws"))] - return Event::Sse( + Event::Sse( SseEvent::default() .json_data(DbEvent::Error("Access denied".into())) .expect("static"), - ); - - #[cfg(feature = "ws")] - return Event::Json( - serde_json::to_string(&DbEvent::Error("Access denied".into())).expect("static"), - ); + ) +}); +#[cfg(feature = "ws")] +static ACCESS_DENIED_EVENT: LazyLock = LazyLock::new(|| { + Event::Json(serde_json::to_string(&DbEvent::Error("Access denied".into())).expect("static")) }); -static ESTABLISHED_EVENT: LazyLock = - LazyLock::new(|| Event::Sse(SseEvent::default().comment("subscription established"))); const NO_HOOK: Option = None; @@ -1284,6 +1345,7 @@ mod tests { Path(("api_name".to_string(), record_id_raw.to_string())), None, RawQuery(None), + axum::extract::Request::default(), ) .await; @@ -1352,6 +1414,7 @@ mod tests { Path(("api_name".to_string(), "*".to_string())), None, RawQuery(None), + axum::extract::Request::default(), ) .await; @@ -1372,6 +1435,7 @@ mod tests { Path(("api_name".to_string(), "*".to_string())), User::from_auth_token(&state, &user_x_token.auth_token), RawQuery(None), + axum::extract::Request::default(), ) .await .unwrap(); @@ -1399,6 +1463,7 @@ mod tests { Path(("api_name".to_string(), record_id_raw.to_string())), User::from_auth_token(&state, &user_x_token.auth_token), RawQuery(None), + axum::extract::Request::default(), ) .await .unwrap(); @@ -1420,6 +1485,7 @@ mod tests { Path(("api_name".to_string(), record_id_raw.to_string())), User::from_auth_token(&state, &user_y_token.auth_token), RawQuery(None), + axum::extract::Request::default(), ) .await; @@ -1635,7 +1701,8 @@ mod tests { { let filter = - FilterQuery::parse("filter[$and][0][id][$gt]=5&filter[$and][1][id][$lt]=100").unwrap(); + SubscriptionQuery::parse("filter[$and][0][id][$gt]=5&filter[$and][1][id][$lt]=100") + .unwrap(); let stream = manager .add_sse_table_subscription(api, None, Some(filter.filter.unwrap())) diff --git a/crates/qs/Cargo.toml b/crates/qs/Cargo.toml index 8d422687..dc531ee9 100644 --- a/crates/qs/Cargo.toml +++ b/crates/qs/Cargo.toml @@ -11,9 +11,9 @@ readme = "../../README.md" [dependencies] base64 = { version = "0.22.1", default-features = false, features = ["alloc"] } itertools = "0.14.0" -serde = "1.0.219" +serde = { workspace = true } serde-value = "0.7.0" -serde_qs = "0.15.0" +serde_qs = { workspace = true } uuid = { workspace = true } [dev-dependencies]