Move WS subscriptions behind the same endpoint as SSE and improve error handling.

This commit is contained in:
Sebastian Jeltsch
2026-01-18 11:03:12 +01:00
parent 9a29ce2388
commit 0a5625bcf6
6 changed files with 158 additions and 94 deletions

1
Cargo.lock generated
View File

@@ -6104,6 +6104,7 @@ dependencies = [
"serde",
"serde_json",
"serde_path_to_error",
"serde_qs",
"serde_urlencoded",
"sha2",
"sqlformat",

View File

@@ -85,6 +85,7 @@ rusqlite = { version = "0.38.0", default-features = false, features = ["bundled"
rust-embed = { version = "8.4.0", default-features = false, features = ["mime-guess"] }
serde = { version = "^1.0.203", features = ["derive"] }
serde_json = { version = "^1.0.117" }
serde_qs = { version = "0.15.0", default-features = false }
serde_rusqlite = { path = "vendor/serde_rusqlite" }
sqlite-vec = { path = "vendor/sqlite-vec/bindings/rust", default-features = false }
tokio = { version = "^1.38.0", default-features = false, features = ["macros", "net", "rt-multi-thread", "fs", "signal", "time", "sync"] }

View File

@@ -78,6 +78,7 @@ rust-embed = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
serde_path_to_error = "0.1.16"
serde_qs = { workspace = true }
serde_urlencoded = "0.7.1"
sha2 = "0.10.8"
sqlformat = "0.5.0"
@@ -87,9 +88,9 @@ tokio = { workspace = true }
tokio-rustls = { version = "0.26.1", default-features = false }
tower = "0.5.0"
tower-cookies = "0.11.0"
tower_governor = { version = "0.8.0", default-features = false, features = ["axum"] }
tower-http = { version = "^0.6.0", default-features = false, features = ["cors", "trace", "fs", "limit"] }
tower-service = { version = "0.3.3", default-features = false }
tower_governor = { version = "0.8.0", default-features = false, features = ["axum"] }
tracing = { workspace = true }
tracing-opentelemetry-instrumentation-sdk = "0.32.0"
tracing-subscriber = { workspace = true }

View File

@@ -86,12 +86,6 @@ pub(crate) fn router(enable_transactions: bool) -> Router<AppState> {
get(subscribe::add_subscription_sse_handler),
);
#[cfg(feature = "ws")]
let router = router.route(
&format!("/{RECORD_API_PATH}/{{name}}/subscribe_ws/{{record}}"),
get(subscribe::add_subscription_ws_handler),
);
if enable_transactions {
return router.route(
&format!("/{TRANSACTION_API_PATH}/execute"),

View File

@@ -1,5 +1,5 @@
use async_channel::{TrySendError, WeakReceiver};
use axum::extract::{Path, RawQuery, State};
use axum::extract::{Path, RawQuery, Request, State};
use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use futures_util::{Stream, StreamExt};
@@ -14,7 +14,7 @@ use std::pin::Pin;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::{Arc, LazyLock};
use std::task::{Context, Poll};
use trailbase_qs::FilterQuery;
use trailbase_qs::ValueOrComposite;
use trailbase_schema::QualifiedName;
use trailbase_schema::json::value_to_flat_json;
use trailbase_sqlite::connection::{extract_record_values, extract_row_id};
@@ -151,7 +151,7 @@ impl DbEvent {
#[inline]
fn into_event(self) -> Result<Event, axum::Error> {
return Ok(Event::Json(
serde_json::to_string(&self).map_err(|err| axum::Error::new(err))?,
serde_json::to_string(&self).map_err(axum::Error::new)?,
));
}
}
@@ -385,7 +385,7 @@ impl PerConnectionState {
self: Arc<Self>,
api: RecordApi,
user: Option<User>,
filter: Option<trailbase_qs::ValueOrComposite>,
filter: Option<ValueOrComposite>,
) -> Result<(async_channel::Sender<Event>, AutoCleanupEventStream), RecordError> {
let table_name = api.qualified_name().clone();
@@ -529,7 +529,7 @@ impl SubscriptionManager {
&self,
api: RecordApi,
user: Option<User>,
filter: Option<trailbase_qs::ValueOrComposite>,
filter: Option<ValueOrComposite>,
) -> Result<AutoCleanupEventStream, RecordError> {
let (sender, receiver) = self
.get_per_connection_state(&api)
@@ -810,13 +810,35 @@ fn hook_continuation(conn: &rusqlite::Connection, s: ContinuationState) {
}
}
#[derive(Clone, Default, Debug, PartialEq, Deserialize)]
struct SubscriptionQuery {
/// Map from filter params to filter value. It's a vector in cases like:
/// `col0[$gte]=2&col0[$lte]=10`.
filter: Option<ValueOrComposite>,
/// Whether to use WebSocket instead of default SSE.
ws: Option<bool>,
}
impl SubscriptionQuery {
fn parse(query: &str) -> Result<SubscriptionQuery, RecordError> {
// NOTE: We rely on non-strict mode to parse `filter[col0]=a&b%filter[col1]=c`.
let qs = serde_qs::Config::new(9, false);
return qs
.deserialize_bytes::<SubscriptionQuery>(query.as_bytes())
.map_err(|_err| RecordError::BadRequest("Invalid query"));
}
}
/// Read record.
#[utoipa::path(
get,
path = "/{name}/subscribe/{record}",
tag = "records",
// TODO: Document the params. Requires utoipa support in trailbase_qs or external impl.
// params(SubscritpionParams),
responses(
(status = 200, description = "SSE stream of record changes.")
(status = 200, description = "Starts streaming changes to matching records via SSE/WebSocket")
)
)]
pub async fn add_subscription_sse_handler(
@@ -824,6 +846,7 @@ pub async fn add_subscription_sse_handler(
Path((api_name, record)): Path<(String, String)>,
user: Option<User>,
RawQuery(raw_url_query): RawQuery,
request: Request,
) -> Result<Response, RecordError> {
let Some(api) = state.lookup_record_api(&api_name) else {
return Err(RecordError::ApiNotFound);
@@ -833,16 +856,20 @@ pub async fn add_subscription_sse_handler(
return Err(RecordError::Forbidden);
}
let FilterQuery { filter } = raw_url_query
let SubscriptionQuery { filter, ws } = raw_url_query
.as_ref()
.map_or_else(
|| Ok(FilterQuery::default()),
|query| FilterQuery::parse(query),
|| Ok(SubscriptionQuery::default()),
|query| SubscriptionQuery::parse(query),
)
.map_err(|_err| {
return RecordError::BadRequest("Invalid query");
})?;
if ws.unwrap_or(false) {
return subscribe_ws(state, api, record, filter, user, request).await;
}
return match record.as_str() {
"*" => {
api.check_table_level_access(Permission::Read, user.as_ref())?;
@@ -878,100 +905,134 @@ pub async fn add_subscription_sse_handler(
};
}
#[cfg(feature = "ws")]
pub async fn add_subscription_ws_handler(
State(state): State<AppState>,
ws: axum::extract::ws::WebSocketUpgrade,
Path((api_name, record)): Path<(String, String)>,
RawQuery(raw_url_query): RawQuery,
#[allow(unused)]
pub async fn subscribe_ws(
state: AppState,
api: RecordApi,
record: String,
filter: Option<ValueOrComposite>,
user: Option<User>,
request: Request,
) -> Result<Response, RecordError> {
use axum::extract::ws::WebSocket;
use futures_util::SinkExt;
let Some(api) = state.lookup_record_api(&api_name) else {
return Err(RecordError::ApiNotFound);
};
if !api.enable_subscriptions() {
return Err(RecordError::Forbidden);
}
let FilterQuery { filter } = raw_url_query
.as_ref()
.map_or_else(
|| Ok(FilterQuery::default()),
|query| FilterQuery::parse(query),
)
.map_err(|_err| {
return RecordError::BadRequest("Invalid query");
})?;
#[cfg(not(feature = "ws"))]
return Err(RecordError::BadRequest("ws unsupported"));
async fn on_upgrade(socket: WebSocket, receiver: AutoCleanupEventStream) {
let (mut sender, _) = socket.split();
#[cfg(feature = "ws")]
{
use axum::extract::FromRequestParts;
use axum::extract::ws::{CloseFrame, Message, WebSocket, WebSocketUpgrade};
use futures_util::SinkExt;
let mut pinned_receiver = std::pin::pin!(receiver);
while let Some(ev) = pinned_receiver.next().await {
match ev.into_ws_event() {
Ok(msg) => {
if let Err(axum_err) = sender.send(msg).await {
debug!("{axum_err}");
let _ = sender.close();
return;
let (mut parts, body) = request.into_parts();
let ws = match WebSocketUpgrade::from_request_parts(&mut parts, &state).await {
Ok(ws) => ws,
Err(err) => {
return Ok(err.into_response());
}
};
async fn abort<S: SinkExt<Message> + std::marker::Unpin>(sender: &mut S, reason: &str) {
// https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
//
// 1011 indicates that a server is terminating the connection because
// it encountered an unexpected condition that prevented it from
// fulfilling the request.
let _ = sender
.send(Message::Close(Some(CloseFrame {
code: 1011,
reason: reason.into(),
})))
.await;
sender.close();
}
// WebSocket uses the HTTP `UPGRADE` mechanism to switch over to dedicated, non-HTTP `ws://`
// protocol.
async fn on_upgrade(mut socket: WebSocket, receiver: AutoCleanupEventStream) {
// NOTE: We're dropping the receiver end of the bidirectional WebSocket. We're not expecting
// any messages from the client.
let (mut sender, _) = socket.split();
let mut pinned_receiver = std::pin::pin!(receiver);
while let Some(ev) = pinned_receiver.next().await {
match ev.into_ws_event() {
Ok(msg) => {
if let Err(err) = sender.send(msg).await {
debug!("Sending WS event to client failed: {err}");
abort(&mut sender, "Failed to send event. Closing channel");
return;
}
}
}
Err(err) => {
debug_assert!(false, "into_ws_event failed: {err}");
}
};
Err(err) => {
debug_assert!(false, "into_ws_event failed: {err}");
}
};
}
}
return match record.as_str() {
"*" => {
api.check_table_level_access(Permission::Read, user.as_ref())?;
Ok(ws.on_upgrade(async move |mut socket| {
let Ok((_sender, receiver)) = state
.subscription_manager()
.get_per_connection_state(&api)
.add_table_subscription(api, user, filter)
.await
else {
abort(&mut socket, "subscription failed");
return;
};
on_upgrade(socket, receiver).await
}))
}
_ => {
let record_id = api.primary_key_to_value(record)?;
api
.check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref())
.await?;
Ok(ws.on_upgrade(async move |mut socket| {
let Ok((_sender, receiver)) = state
.subscription_manager()
.get_per_connection_state(&api)
.add_record_subscription(api, record_id, user)
.await
else {
abort(&mut socket, "subscription failed");
return;
};
on_upgrade(socket, receiver).await;
}))
}
};
}
return match record.as_str() {
"*" => {
api.check_table_level_access(Permission::Read, user.as_ref())?;
let (_sender, receiver) = state
.subscription_manager()
.get_per_connection_state(&api)
.add_table_subscription(api, user, filter)
.await?;
Ok(ws.on_upgrade(async |socket| on_upgrade(socket, receiver).await))
}
_ => {
let record_id = api.primary_key_to_value(record)?;
api
.check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref())
.await?;
let (_sender, receiver) = state
.subscription_manager()
.get_per_connection_state(&api)
.add_record_subscription(api, record_id, user)
.await?;
Ok(ws.on_upgrade(async |socket| on_upgrade(socket, receiver).await))
}
};
}
static SUBSCRIPTION_COUNTER: AtomicI64 = AtomicI64::new(0);
static ESTABLISHED_EVENT: LazyLock<Event> =
LazyLock::new(|| Event::Sse(SseEvent::default().comment("subscription established")));
#[cfg(not(feature = "ws"))]
static ACCESS_DENIED_EVENT: LazyLock<Event> = LazyLock::new(|| {
#[cfg(not(feature = "ws"))]
return Event::Sse(
Event::Sse(
SseEvent::default()
.json_data(DbEvent::Error("Access denied".into()))
.expect("static"),
);
#[cfg(feature = "ws")]
return Event::Json(
serde_json::to_string(&DbEvent::Error("Access denied".into())).expect("static"),
);
)
});
#[cfg(feature = "ws")]
static ACCESS_DENIED_EVENT: LazyLock<Event> = LazyLock::new(|| {
Event::Json(serde_json::to_string(&DbEvent::Error("Access denied".into())).expect("static"))
});
static ESTABLISHED_EVENT: LazyLock<Event> =
LazyLock::new(|| Event::Sse(SseEvent::default().comment("subscription established")));
const NO_HOOK: Option<fn(Action, &str, &str, &PreUpdateCase)> = None;
@@ -1284,6 +1345,7 @@ mod tests {
Path(("api_name".to_string(), record_id_raw.to_string())),
None,
RawQuery(None),
axum::extract::Request::default(),
)
.await;
@@ -1352,6 +1414,7 @@ mod tests {
Path(("api_name".to_string(), "*".to_string())),
None,
RawQuery(None),
axum::extract::Request::default(),
)
.await;
@@ -1372,6 +1435,7 @@ mod tests {
Path(("api_name".to_string(), "*".to_string())),
User::from_auth_token(&state, &user_x_token.auth_token),
RawQuery(None),
axum::extract::Request::default(),
)
.await
.unwrap();
@@ -1399,6 +1463,7 @@ mod tests {
Path(("api_name".to_string(), record_id_raw.to_string())),
User::from_auth_token(&state, &user_x_token.auth_token),
RawQuery(None),
axum::extract::Request::default(),
)
.await
.unwrap();
@@ -1420,6 +1485,7 @@ mod tests {
Path(("api_name".to_string(), record_id_raw.to_string())),
User::from_auth_token(&state, &user_y_token.auth_token),
RawQuery(None),
axum::extract::Request::default(),
)
.await;
@@ -1635,7 +1701,8 @@ mod tests {
{
let filter =
FilterQuery::parse("filter[$and][0][id][$gt]=5&filter[$and][1][id][$lt]=100").unwrap();
SubscriptionQuery::parse("filter[$and][0][id][$gt]=5&filter[$and][1][id][$lt]=100")
.unwrap();
let stream = manager
.add_sse_table_subscription(api, None, Some(filter.filter.unwrap()))

View File

@@ -11,9 +11,9 @@ readme = "../../README.md"
[dependencies]
base64 = { version = "0.22.1", default-features = false, features = ["alloc"] }
itertools = "0.14.0"
serde = "1.0.219"
serde = { workspace = true }
serde-value = "0.7.0"
serde_qs = "0.15.0"
serde_qs = { workspace = true }
uuid = { workspace = true }
[dev-dependencies]