Massively overhaul change event execution model.

* minimize work in SQLite's preupdate hook
* push brokering onto a separate thread
* kick filter and acl checking further downstream into the SSE handlers.
* add layered sequence numbers to detect server-side event losses and allow clients to detect client-side event losses.
This commit is contained in:
Sebastian Jeltsch
2026-04-02 16:46:30 +02:00
parent 175800287f
commit e8eaacfc7c
8 changed files with 792 additions and 608 deletions
+10 -4
View File
@@ -20,6 +20,7 @@ pub enum ValueOrComposite {
Composite(Combiner, Vec<ValueOrComposite>),
}
#[derive(Clone, Debug, PartialEq)]
pub enum Filter {
Passthrough,
Record(ValueOrComposite),
@@ -224,7 +225,7 @@ fn parse_geometries(record: &[u8], filter: &str) -> Option<(geos::Geometry, geos
pub(crate) fn apply_filter_recursively_to_record(
filter: &ValueOrComposite,
record: &indexmap::IndexMap<&str, rusqlite::types::Value>,
record: &indexmap::IndexMap<String, rusqlite::types::Value>,
) -> bool {
return match filter {
ValueOrComposite::Value(col_op_value) => {
@@ -314,7 +315,10 @@ mod tests {
#[test]
fn test_basic_value_filter() {
let record: IndexMap<&str, Value> = IndexMap::from([("a", Value::Text("a value".to_string()))]);
let record: IndexMap<String, Value> = IndexMap::from([(
"a".to_string(),
Value::Text("a value".to_string().to_string()),
)]);
assert!(apply_filter_recursively_to_record(
&ValueOrComposite::Value(ColumnOpValue {
@@ -355,8 +359,10 @@ mod tests {
#[test]
fn test_basic_composite_filter() {
let record: IndexMap<&str, Value> =
IndexMap::from([("a", Value::Integer(5)), ("b", Value::Integer(-5))]);
let record: IndexMap<String, Value> = IndexMap::from([
("a".to_string(), Value::Integer(5)),
("b".to_string(), Value::Integer(-5)),
]);
assert!(apply_filter_recursively_to_record(
&ValueOrComposite::Composite(
+45 -28
View File
@@ -564,43 +564,60 @@ impl RecordApi {
/// Check if the given user (if any) can access a record given the request and the operation.
#[inline]
pub(crate) fn check_record_level_read_access_for_subscriptions(
pub(crate) async fn check_record_level_read_access_for_subscriptions(
&self,
conn: &rusqlite::Connection,
params: SubscriptionAclParams<'_>,
conn: &trailbase_sqlite::Connection,
record: &Arc<indexmap::IndexMap<String, rusqlite::types::Value>>,
user: Option<&User>,
) -> Result<(), RecordError> {
// First check table level access and if present check row-level access based on access rule.
self.check_table_level_access(Permission::Read, params.user)?;
self.check_table_level_access(Permission::Read, user)?;
let Some(ref access_query) = self.state.subscription_read_access_query else {
return Ok(());
};
let mut stmt = conn
.prepare_cached(access_query)
.map_err(|_err| RecordError::Forbidden)?;
conn
.call_reader({
let access_query = access_query.clone();
let user = user.cloned();
let record = record.clone();
// NOTE: the `bind` impl does the heavy lifting.
params
.bind(&mut stmt)
.map_err(|_err| RecordError::Forbidden)?;
move |conn| {
let params = SubscriptionAclParams {
params: &record,
user: user.as_ref(),
};
match stmt.raw_query().next() {
Ok(Some(row)) => {
if row.get(0).unwrap_or(false) {
return Ok(());
let mut stmt = conn.prepare_cached(&access_query)?;
// NOTE: the `bind` impl does the heavy lifting.
params.bind(&mut stmt)?;
match stmt.raw_query().next() {
Ok(Some(row)) => {
if row.get(0).unwrap_or(false) {
return Ok(());
}
}
Ok(None) => {}
Err(err) => {
warn!("RLA query failed: {err}");
#[cfg(test)]
panic!("RLA query failed: {err}");
}
};
return Err(trailbase_sqlite::Error::Other(
RecordError::Forbidden.into(),
));
}
}
Ok(None) => {}
Err(err) => {
warn!("RLA query failed: {err}");
})
.await
.map_err(|_| RecordError::Forbidden)?;
#[cfg(test)]
panic!("RLA query failed: {err}");
}
}
return Err(RecordError::Forbidden);
return Ok(());
}
#[inline]
@@ -704,9 +721,9 @@ impl RecordApi {
}
}
pub(crate) struct SubscriptionAclParams<'a> {
pub params: &'a indexmap::IndexMap<&'a str, rusqlite::types::Value>,
pub user: Option<&'a User>,
struct SubscriptionAclParams<'a> {
params: &'a indexmap::IndexMap<String, rusqlite::types::Value>,
user: Option<&'a User>,
}
impl<'a> trailbase_sqlite::Params for SubscriptionAclParams<'a> {
+87 -90
View File
@@ -1,41 +1,46 @@
use axum::response::sse::Event as SseEvent;
use serde::Serialize;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::records::RecordError;
type JsonObject = serde_json::value::Map<String, serde_json::Value>;
#[repr(i64)]
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq)]
pub enum EventErrorStatus {
/// Unknown or unspecified error.
Unknown = 0,
/// Access forbidden.
Forbidden = 1,
/// Server-side event-loss, e.g. a buffer ran out of capacity. This does not account for
/// additional losses that may happen between the TrailBase server and the client. This
/// needs to be determined client-side based on event `seq` numbers.
Loss = 2,
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct EventError {
pub status: EventErrorStatus,
pub message: Option<String>,
}
#[derive(Debug, Clone, PartialEq)]
pub enum JsonEventPayload {
Update { value: JsonObject },
Insert { value: JsonObject },
Delete { value: JsonObject },
Error { error: String },
Error { value: EventError },
Ping,
}
// fn serialize_raw_json<S>(json: &Option<String>, s: S) -> Result<S::Ok, S::Error>
// where
// S: serde::ser::Serializer,
// {
// // This should be pretty efficient: it just checks that the string is valid;
// // it doesn't parse it into a new data structure.
// if let Some(json) = json {
// let v: &serde_json::value::RawValue = serde_json::from_str(json).expect("invalid json");
// return v.serialize(s);
// }
//
// return s.serialize_none();
// }
#[allow(unused)]
#[derive(Debug, Clone, Serialize)]
pub enum EventPayload {
Update(Option<Box<serde_json::value::RawValue>>),
Insert(Option<Box<serde_json::value::RawValue>>),
Delete(Option<Box<serde_json::value::RawValue>>),
Error(String),
Error(Option<Box<serde_json::value::RawValue>>),
Ping,
}
@@ -49,7 +54,7 @@ impl PartialEq for EventPayload {
(Self::Update(lhs), Self::Update(rhs)) => get(lhs) == get(rhs),
(Self::Insert(lhs), Self::Insert(rhs)) => get(lhs) == get(rhs),
(Self::Delete(lhs), Self::Delete(rhs)) => get(lhs) == get(rhs),
(Self::Error(lhs), Self::Error(rhs)) => lhs == rhs,
(Self::Error(lhs), Self::Error(rhs)) => get(lhs) == get(rhs),
(Self::Ping, Self::Ping) => true,
_ => false,
};
@@ -74,7 +79,11 @@ impl EventPayload {
.map(|v| v.to_owned())
.ok(),
),
JsonEventPayload::Error { error } => EventPayload::Error(error.clone()),
JsonEventPayload::Error { value } => EventPayload::Error(
serde_json::value::to_raw_value(&value)
.map(|v| v.to_owned())
.ok(),
),
JsonEventPayload::Ping => EventPayload::Ping,
};
}
@@ -118,20 +127,24 @@ pub struct ChangeEvent {
}
#[cfg(test)]
pub fn deserialize_event(ev: Arc<EventPayload>) -> Result<JsonEventPayload, serde_json::Error> {
return match *ev {
EventPayload::Update(ref v) => Ok(JsonEventPayload::Update {
value: serde_json::from_str(v.as_ref().map_or("", |v| v.get()))?,
}),
EventPayload::Insert(ref v) => Ok(JsonEventPayload::Insert {
value: serde_json::from_str(v.as_ref().map_or("", |v| v.get()))?,
}),
EventPayload::Delete(ref v) => Ok(JsonEventPayload::Delete {
value: serde_json::from_str(v.as_ref().map_or("", |v| v.get()))?,
}),
EventPayload::Error(ref err) => Ok(JsonEventPayload::Error { error: err.clone() }),
EventPayload::Ping => Ok(JsonEventPayload::Ping),
};
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub enum TestJsonEventPayload {
Update(JsonObject),
Insert(JsonObject),
Delete(JsonObject),
Error {
status: EventErrorStatus,
message: Option<String>,
},
Ping,
}
#[cfg(test)]
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct TestChangeEvent {
#[serde(flatten)]
pub event: TestJsonEventPayload,
pub seq: Option<i64>,
}
#[cfg(test)]
@@ -139,66 +152,50 @@ mod tests {
use super::*;
use serde_json::json;
// #[test]
// fn serialization_test() {
// let payload0 = EventPayload::from(&JsonEventPayload::Insert {
// value: JsonObject::from_iter([("key0".to_string(), json!("value0"))]),
// });
//
// let ev0 = ChangeEvent {
// payload: Arc::new(payload0.clone()),
// seq: None,
// };
//
// let expected0 = r#"{"type":"insert","value":{"key0":"value0"}}"#;
// let ev0str = serde_json::to_string(&ev0).unwrap();
// assert_eq!(expected0, ev0str);
// assert_eq!(
// expected0,
// serde_json::to_string(&ChangeEvent {
// payload: Arc::new(payload0.clone()),
// seq: None,
// })
// .unwrap()
// );
//
// let ev0deserialized: EventPayload = serde_json::from_str(&expected0).unwrap();
// assert_eq!(payload0, ev0deserialized);
//
// let payload1 = EventPayload::from(&JsonEventPayload::Error {
// error: "boom".to_string(),
// });
// let ev1 = ChangeEvent {
// payload: Arc::new(payload1),
// seq: Some(11),
// };
//
// let expected1 = r#"{"type":"error","error":"boom","seq":11}"#;
// let ev1str = serde_json::to_string(&ev1).unwrap();
// assert_eq!(expected1, ev1str);
// }
#[test]
fn serialization_foo_test() {
let event = ChangeEvent {
event: Arc::new(EventPayload::Delete(Some(
serde_json::value::to_raw_value(&json!({
"foo": 4,
}))
.unwrap(),
))),
seq: Some(4),
};
{
let event = ChangeEvent {
event: Arc::new(EventPayload::from(&JsonEventPayload::Delete {
value: JsonObject::from_iter([("foo".to_string(), json!(4))]),
})),
seq: Some(4),
};
let value = serde_json::to_value(&event).unwrap();
assert_eq!(
serde_json::json!({
"Delete": {
"foo": 4,
let value = serde_json::to_value(&event).unwrap();
assert_eq!(
serde_json::json!({
"Delete": {
"foo": 4,
},
"seq": 4,
}),
value
);
}
{
let event = ChangeEvent {
event: Arc::new(EventPayload::from(&JsonEventPayload::Error {
value: EventError {
status: EventErrorStatus::Loss,
message: Some("test".to_string()),
},
"seq": 4,
}),
value
);
})),
seq: Some(4),
};
let value = serde_json::to_value(&event).unwrap();
assert_eq!(
serde_json::json!({
"Error": {
"status": EventErrorStatus::Loss,
"message": "test",
},
"seq": 4,
}),
value
);
}
}
}
+175 -20
View File
@@ -1,15 +1,22 @@
use axum::extract::{Path, RawQuery, Request, State};
use axum::response::sse::{KeepAlive, Sse};
use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use futures_util::StreamExt;
use futures_util::stream;
use serde::Deserialize;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::{Arc, LazyLock};
use trailbase_qs::ValueOrComposite;
use ts_rs::TS;
use crate::app_state::AppState;
use crate::auth::User;
use crate::records::RecordApi;
use crate::records::filter::{Filter, apply_filter_recursively_to_record};
use crate::records::subscribe::event::{
EventError, EventErrorStatus, EventPayload, JsonEventPayload,
};
use crate::records::subscribe::state::{EventCandidate, Subscription};
use crate::records::{Permission, RecordError};
#[derive(Clone, Default, Debug, PartialEq, Deserialize)]
@@ -84,6 +91,48 @@ pub async fn add_subscription_sse_and_ws_handler(
};
}
struct ValidateEventArgs {
state: AppState,
// FIXME: We could probably do with a subset of information from the `Subscription` and keep it
// internal.
subscription: Arc<Subscription>,
expected_candidate_seq: AtomicI64,
}
async fn validate_event(
args: Arc<ValidateEventArgs>,
ev: EventCandidate,
) -> Result<Option<Arc<EventPayload>>, RecordError> {
if ev.seq != args.expected_candidate_seq.fetch_add(1, Ordering::SeqCst) {
args.expected_candidate_seq.store(ev.seq, Ordering::SeqCst);
return Ok(Some(EVENT_LOSS_EVENT.clone()));
}
let Some(ref record) = ev.record else {
// Established events.
return Ok(Some(ev.payload));
};
let sub: &Subscription = &args.subscription;
if let Filter::Record(ref filter) = sub.filter
&& !apply_filter_recursively_to_record(filter, record)
{
return Ok(None);
}
// We don't memoize and eagerly look up the APIs to make sure we get an up-to-date
// version.
let Some(api) = args.state.lookup_record_api(&sub.record_api_name) else {
return Ok(None);
};
api
.check_record_level_read_access_for_subscriptions(api.conn(), record, sub.user.as_ref())
.await?;
return Ok(Some(ev.payload));
}
pub async fn subscribe_sse(
state: AppState,
api: RecordApi,
@@ -91,21 +140,35 @@ pub async fn subscribe_sse(
filter: Option<ValueOrComposite>,
user: Option<User>,
) -> Result<Response, RecordError> {
let seq = Arc::new(AtomicI64::default());
return match record.as_str() {
"*" => {
api.check_table_level_access(Permission::Read, user.as_ref())?;
let receiver = state
let (receiver, subscription) = state
.subscription_manager()
.add_sse_table_subscription(api, user, filter)
.await?;
let seq = AtomicI64::default();
let args = Arc::new(ValidateEventArgs {
state,
subscription,
expected_candidate_seq: AtomicI64::default(),
});
Ok(
Sse::new(
receiver.map(move |ev| ev.into_sse_event(Some(seq.fetch_add(1, Ordering::SeqCst)))),
)
Sse::new(receiver.filter_map(move |ev: EventCandidate| {
let seq = seq.clone();
let args = args.clone();
return async move {
validate_event(args.clone(), ev)
.await
.unwrap_or_default()
.map(|ev| ev.into_sse_event(Some(seq.fetch_add(1, Ordering::SeqCst))))
};
}))
.keep_alive(KeepAlive::default())
.into_response(),
)
@@ -116,16 +179,48 @@ pub async fn subscribe_sse(
.check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref())
.await?;
let receiver = state
let (receiver, subscription) = state
.subscription_manager()
.add_sse_record_subscription(api, record_id, user)
.await?;
let seq = AtomicI64::default();
let args = Arc::new(ValidateEventArgs {
state,
subscription,
expected_candidate_seq: AtomicI64::default(),
});
Ok(
Sse::new(
receiver.map(move |ev| ev.into_sse_event(Some(seq.fetch_add(1, Ordering::SeqCst)))),
receiver
.then(move |ev: EventCandidate| {
let seq = seq.clone();
let args = args.clone();
return async move {
match validate_event(args.clone(), ev).await {
Ok(None) => stream::empty().boxed(),
Ok(Some(ev)) => stream::once(std::future::ready(
ev.into_sse_event(Some(seq.fetch_add(1, Ordering::SeqCst))),
))
.boxed(),
Err(_) => {
// Death sentence for record subscriptions to not have access
stream::iter(vec![
// First send an error event to the user.
ACCESS_DENIED_EVENT
.clone()
.into_sse_event(Some(seq.fetch_add(1, Ordering::SeqCst))),
// Then terminate the stream via the `take_while` below.
Err(RecordError::Forbidden),
])
.boxed()
}
}
};
})
.flatten()
.take_while(|event: &Result<SseEvent, RecordError>| std::future::ready(event.is_ok())),
)
.keep_alive(KeepAlive::default())
.into_response(),
@@ -157,6 +252,7 @@ pub async fn subscribe_ws(
use crate::records::subscribe::event::EventPayload;
use crate::records::subscribe::state::AutoCleanupEventStream;
use crate::records::subscribe::state::EventCandidate;
let (mut parts, _body) = request.into_parts();
let ws = match WebSocketUpgrade::from_request_parts(&mut parts, &state).await {
@@ -197,14 +293,42 @@ pub async fn subscribe_ws(
}
async fn broker<S: SinkExt<Message> + std::marker::Unpin>(
state: AppState,
subscription: Arc<Subscription>,
// Receive events from SQLite
receiver: AutoCleanupEventStream,
// Send messages via WebSocket.
sender: &mut S,
is_record_subscription: bool,
) {
let args = Arc::new(ValidateEventArgs {
state,
subscription,
expected_candidate_seq: AtomicI64::default(),
});
let mut pinned_receiver = std::pin::pin!(receiver);
while let Some(ev) = pinned_receiver.next().await {
match ev.into_ws_event() {
let payload = match validate_event(args.clone(), ev).await {
Ok(Some(payload)) => payload,
Ok(None) => {
continue;
}
Err(_) => {
if is_record_subscription {
// Death sentence for record subscriptions to not have access
let _ = ACCESS_DENIED_EVENT
.clone()
.into_ws_event()
.map(|ev| sender.send(ev));
return;
} else {
continue;
}
}
};
match payload.into_ws_event() {
Ok(msg) => {
if let Err(_value) = sender.send(msg).await {
log::debug!("Sending WS event to client failed");
@@ -271,6 +395,8 @@ pub async fn subscribe_ws(
return match record.as_str() {
"*" => {
Ok(ws.on_upgrade(async move |socket: WebSocket| {
use crate::records::subscribe::state::EventCandidate;
let Some(mut ws_sender) = init(&state, socket, &mut user).await else {
return;
};
@@ -283,10 +409,10 @@ pub async fn subscribe_ws(
return;
}
let (sender, receiver) = async_channel::bounded::<Arc<EventPayload>>(16);
let state = state.subscription_manager().get_per_connection_state(&api);
let (sender, receiver) = async_channel::bounded::<EventCandidate>(64);
let conn_state = state.subscription_manager().get_per_connection_state(&api);
let Ok(id) = state
let Ok(subscription) = conn_state
.clone()
.add_table_subscription(api, user, filter, sender)
.await
@@ -295,15 +421,17 @@ pub async fn subscribe_ws(
return;
};
let receiver = AutoCleanupEventStream::new(receiver, state, id);
let receiver = AutoCleanupEventStream::new(receiver, conn_state, subscription.id.clone());
broker(receiver, &mut ws_sender).await
broker(state, subscription, receiver, &mut ws_sender, false).await
}))
}
_ => {
let record_id = api.primary_key_to_value(record)?;
Ok(ws.on_upgrade(async move |socket: WebSocket| {
use crate::records::subscribe::state::EventCandidate;
let Some(mut ws_sender) = init(&state, socket, &mut user).await else {
return;
};
@@ -319,10 +447,10 @@ pub async fn subscribe_ws(
return;
}
let (sender, receiver) = async_channel::bounded::<Arc<EventPayload>>(16);
let state = state.subscription_manager().get_per_connection_state(&api);
let (sender, receiver) = async_channel::bounded::<EventCandidate>(64);
let conn_state = state.subscription_manager().get_per_connection_state(&api);
let Ok(id) = state
let Ok(subscription) = conn_state
.clone()
.add_record_subscription(api, record_id, user, sender)
.await
@@ -331,10 +459,37 @@ pub async fn subscribe_ws(
return;
};
let receiver = AutoCleanupEventStream::new(receiver, state, id);
let receiver = AutoCleanupEventStream::new(receiver, conn_state, subscription.id.clone());
broker(receiver, &mut ws_sender).await;
broker(state, subscription, receiver, &mut ws_sender, true).await;
}))
}
};
}
static ACCESS_DENIED_EVENT: LazyLock<Arc<EventPayload>> = LazyLock::new(|| {
Arc::new(EventPayload::from(&JsonEventPayload::Error {
value: EventError {
status: EventErrorStatus::Forbidden,
message: Some("Access denied".into()),
},
}))
});
static EVENT_LOSS_EVENT: LazyLock<Arc<EventPayload>> = LazyLock::new(|| {
Arc::new(EventPayload::from(&JsonEventPayload::Error {
value: EventError {
status: EventErrorStatus::Loss,
message: None,
},
}))
});
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn static_sse_event_test() {
let _x: Arc<EventPayload> = (*ACCESS_DENIED_EVENT).clone();
}
}
+9 -4
View File
@@ -23,13 +23,14 @@ pub struct PreupdateHookEvent {
pub record: Vec<Value>,
}
pub fn install_hook(conn: &Connection) -> kanal::Receiver<PreupdateHookEvent> {
pub fn install_hook(conn: &Connection) -> kanal::Receiver<(usize, PreupdateHookEvent)> {
let (sender, receiver) = kanal::bounded(CAPACITY);
conn
.write_lock()
.preupdate_hook({
let conn = conn.clone();
let mut cnt = 0;
Some(
move |action: Action, db: &str, table_name: &str, case: &PreUpdateCase| {
@@ -69,7 +70,9 @@ pub fn install_hook(conn: &Connection) -> kanal::Receiver<PreupdateHookEvent> {
record,
};
match sender.try_send(event) {
cnt += 1;
match sender.try_send((cnt, event)) {
Ok(true) => {}
Ok(false) => {
warn!("Channel full. Failed to forward preupdate event.")
@@ -129,11 +132,13 @@ mod tests {
.await
.unwrap();
let ev0 = receiver.next().unwrap();
let (cnt, ev0) = receiver.next().unwrap();
assert_eq!(1, cnt);
assert_eq!("\"test\"", ev0.table_name.escaped_string());
assert_eq!(Value::Integer(3), ev0.record[0]);
let ev1 = receiver.next().unwrap();
let (cnt, ev1) = receiver.next().unwrap();
assert_eq!(2, cnt);
assert_eq!(Value::Integer(4), ev1.record[0]);
uninstall_hook(&conn);
+57 -28
View File
@@ -1,6 +1,7 @@
use parking_lot::RwLock;
use parking_lot::{Mutex, RwLock};
use reactivate::Reactive;
use std::collections::{HashMap, hash_map::Entry};
use std::sync::atomic::Ordering;
use std::sync::{Arc, LazyLock};
use trailbase_qs::ValueOrComposite;
@@ -9,7 +10,10 @@ use crate::auth::User;
use crate::records::RecordApi;
use crate::records::RecordError;
use crate::records::subscribe::event::{EventPayload, JsonEventPayload};
use crate::records::subscribe::state::{AutoCleanupEventStream, PerConnectionState};
use crate::records::subscribe::state::{
AutoCleanupEventStream, EventCandidate, PerConnectionState, PerConnectionStateInternal,
Subscription,
};
/// Internal, shareable state of the cloneable SubscriptionManager.
struct ManagerState {
@@ -47,17 +51,21 @@ impl SubscriptionManager {
let id = api.conn().id();
// TODO: Clean subscriptions from existing entries for tables that not longer have a
// corresponding API.
// TODO: Skip/cleanup subscriptions from existing entries for tables that not longer have
// a corresponding API.
if let Some(existing) = old.remove(&id) {
let apis = filter_record_apis(id, record_apis);
let Some(first) = apis.values().nth(0) else {
continue;
};
// Update metadata and add back.
*existing.connection_metadata.write() = first.connection_metadata().clone();
*existing.record_apis.write() = apis;
{
let mut state = existing.state.lock();
// Update metadata and add back.
state.connection_metadata = first.connection_metadata().clone();
state.record_apis = apis;
}
lock.insert(id, existing);
}
}
@@ -72,23 +80,32 @@ impl SubscriptionManager {
api: RecordApi,
user: Option<User>,
filter: Option<ValueOrComposite>,
) -> Result<AutoCleanupEventStream, RecordError> {
let (sender, receiver) = async_channel::bounded::<Arc<EventPayload>>(16);
) -> Result<(AutoCleanupEventStream, Arc<Subscription>), RecordError> {
let (sender, receiver) = async_channel::bounded::<EventCandidate>(64);
let state = self.get_per_connection_state(&api);
let id = state
let subscription = state
.clone()
.add_table_subscription(api, user, filter, sender.clone())
.await?;
// Send an immediate comment to flush SSE headers and establish the connection
if sender.send(ESTABLISHED_EVENT.clone()).await.is_err() {
if sender
.send(EventCandidate {
record: None,
payload: ESTABLISHED_EVENT.clone(),
seq: subscription.candidate_seq.fetch_add(1, Ordering::SeqCst),
})
.await
.is_err()
{
return Err(RecordError::BadRequest("channel already closed"));
}
let receiver = AutoCleanupEventStream::new(receiver, state, id);
return Ok(receiver);
return Ok((
AutoCleanupEventStream::new(receiver, state, subscription.id.clone()),
subscription,
));
}
pub async fn add_sse_record_subscription(
@@ -96,23 +113,32 @@ impl SubscriptionManager {
api: RecordApi,
record: trailbase_sqlite::Value,
user: Option<User>,
) -> Result<AutoCleanupEventStream, RecordError> {
let (sender, receiver) = async_channel::bounded::<Arc<EventPayload>>(16);
) -> Result<(AutoCleanupEventStream, Arc<Subscription>), RecordError> {
let (sender, receiver) = async_channel::bounded::<EventCandidate>(64);
let state = self.get_per_connection_state(&api);
let id = state
let subscription = state
.clone()
.add_record_subscription(api, record, user, sender.clone())
.await?;
// Send an immediate comment to flush SSE headers and establish the connection
if sender.send(ESTABLISHED_EVENT.clone()).await.is_err() {
if sender
.send(EventCandidate {
record: None,
payload: ESTABLISHED_EVENT.clone(),
seq: subscription.candidate_seq.fetch_add(1, Ordering::SeqCst),
})
.await
.is_err()
{
return Err(RecordError::BadRequest("channel already closed"));
}
let receiver = AutoCleanupEventStream::new(receiver, state, id);
return Ok(receiver);
return Ok((
AutoCleanupEventStream::new(receiver, state, subscription.id.clone()),
subscription,
));
}
pub fn get_per_connection_state(&self, api: &RecordApi) -> Arc<PerConnectionState> {
@@ -127,9 +153,12 @@ impl SubscriptionManager {
Entry::Occupied(v) => v.get().clone(),
Entry::Vacant(v) => {
let state = Arc::new(PerConnectionState {
connection_metadata: RwLock::new(api.connection_metadata().clone()),
record_apis: RwLock::new(filter_record_apis(id, &self.state.record_apis.value())),
subscriptions: Default::default(),
state: Mutex::new(PerConnectionStateInternal {
connection_metadata: api.connection_metadata().clone(),
record_apis: filter_record_apis(id, &self.state.record_apis.value()),
conn: api.conn().clone(),
subscriptions: Default::default(),
}),
});
v.insert(state).clone()
}
@@ -141,8 +170,8 @@ impl SubscriptionManager {
pub fn num_record_subscriptions(&self) -> usize {
let mut count: usize = 0;
for state in self.state.connections.read().values() {
for (_table_name, subs) in state.subscriptions.read().iter() {
for record in subs.read().record.values() {
for (_table_name, subs) in state.state.lock().subscriptions.iter() {
for record in subs.record.values() {
count += record.len();
}
}
@@ -154,8 +183,8 @@ impl SubscriptionManager {
pub fn num_table_subscriptions(&self) -> usize {
let mut count: usize = 0;
for state in self.state.connections.read().values() {
for (_table_name, subs) in state.subscriptions.read().iter() {
count += subs.read().table.len();
for (_table_name, subs) in state.state.lock().subscriptions.iter() {
count += subs.table.len();
}
}
return count;
+245 -339
View File
@@ -1,12 +1,12 @@
use async_channel::{TrySendError, WeakReceiver};
use async_channel::WeakReceiver;
use futures_util::Stream;
use log::*;
use parking_lot::RwLock;
use parking_lot::Mutex;
use pin_project_lite::pin_project;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::{Arc, LazyLock};
use std::sync::{Arc, Weak};
use std::task::{Context, Poll};
use trailbase_qs::ValueOrComposite;
use trailbase_schema::QualifiedName;
@@ -15,20 +15,17 @@ use trailbase_schema::json::value_to_flat_json;
use crate::auth::User;
use crate::records::RecordApi;
use crate::records::RecordError;
use crate::records::filter::{
Filter, apply_filter_recursively_to_record, qs_filter_to_record_filter,
};
use crate::records::record_api::SubscriptionAclParams;
use crate::records::filter::{Filter, qs_filter_to_record_filter};
use crate::records::subscribe::event::{EventPayload, JsonEventPayload};
use crate::records::subscribe::hook::{
PreupdateHookEvent, RecordAction, install_hook, uninstall_hook, uninstall_hook_rusqlite,
PreupdateHookEvent, RecordAction, install_hook, uninstall_hook,
};
use crate::schema_metadata::ConnectionMetadata;
/// Composite id uniquely identifying a subscription.
///
/// If row_id is Some, this is considered to reference a subscription to a specific record.
#[derive(Default, Debug)]
#[derive(Clone, Default, Debug)]
pub struct SubscriptionId {
pub table_name: QualifiedName,
pub sub_id: i64,
@@ -38,25 +35,24 @@ pub struct SubscriptionId {
/// RAII type for automatically cleaning up subscriptions when the receiving side gets dropped,
/// e.g. client disconnects.
struct AutoCleanupEventStreamState {
receiver: WeakReceiver<Arc<EventPayload>>,
state: Arc<PerConnectionState>,
receiver: WeakReceiver<EventCandidate>,
state: Weak<PerConnectionState>,
id: SubscriptionId,
}
impl Drop for AutoCleanupEventStreamState {
fn drop(&mut self) {
// Subscriptions can be cleaned up either by the sender, i.e. when trying to broker events and
// tables or records get deleted, or by the client-receiver, e.g. by disconnecting. In the
// latter case, we need to clean up the subscription.
// tables or records get deleted, or by the client-receiver, e.g. by disconnecting.
// When dropped by the client-side, we need to clean up the subscription.
if self.receiver.upgrade().is_some() {
let id = std::mem::take(&mut self.id);
let state = self.state.clone();
if let Some(first) = self.state.record_apis.read().values().nth(0) {
first.conn().call_and_forget(move |conn| {
state.remove_subscription(conn, id);
});
}
let Some(state) = self.state.upgrade() else {
return;
};
state.state.lock().remove_subscription2(id);
} else {
debug!("Subscription cleaned up already by the sender side.");
}
@@ -70,19 +66,19 @@ pin_project! {
state: AutoCleanupEventStreamState,
#[pin]
pub receiver: async_channel::Receiver<Arc<EventPayload>>,
pub receiver: async_channel::Receiver<EventCandidate>,
}
}
impl AutoCleanupEventStream {
pub fn new(
receiver: async_channel::Receiver<Arc<EventPayload>>,
receiver: async_channel::Receiver<EventCandidate>,
state: Arc<PerConnectionState>,
id: SubscriptionId,
) -> Self {
return Self {
state: AutoCleanupEventStreamState {
receiver: receiver.downgrade(),
state,
state: Arc::downgrade(&state),
id,
},
receiver,
@@ -91,7 +87,7 @@ impl AutoCleanupEventStream {
}
impl Stream for AutoCleanupEventStream {
type Item = Arc<EventPayload>;
type Item = EventCandidate;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
@@ -104,27 +100,38 @@ impl Stream for AutoCleanupEventStream {
}
}
#[derive(Debug)]
pub struct Subscription {
/// Id uniquely identifying this subscription.
subscription_id: i64,
pub id: SubscriptionId,
/// Name of the API this subscription is subscribed to. We need to lookup the Record API on the
/// hot path to make sure we're getting the latest configuration.
record_api_name: String,
pub record_api_name: String,
/// User associated with subscriber.
user: Option<User>,
pub user: Option<User>,
/// Record filter.
pub filter: Filter,
/// Channel for sending events to the SSE handler.
sender: async_channel::Sender<Arc<EventPayload>>,
/// Filter
filter: Filter,
pub sender: async_channel::Sender<EventCandidate>,
pub candidate_seq: AtomicI64,
}
// Represents a change event that needs further filtering, e.g. ACLs.
#[derive(Debug)]
pub struct EventCandidate {
pub record: Option<Arc<indexmap::IndexMap<String, rusqlite::types::Value>>>,
pub payload: Arc<EventPayload>,
pub seq: i64,
}
#[derive(Default)]
pub struct Subscriptions {
/// A list of table subscriptions for this table.
pub table: Vec<Subscription>,
pub table: Vec<Arc<Subscription>>,
/// A map of record subscriptions for this.
pub record: HashMap<i64, Vec<Subscription>>,
pub record: HashMap<i64, Vec<Arc<Subscription>>>,
}
impl Subscriptions {
@@ -133,96 +140,104 @@ impl Subscriptions {
}
}
pub struct PerConnectionState {
pub struct PerConnectionStateInternal {
/// Metadata: always updated together when config -> record APIs change.
pub record_apis: RwLock<HashMap<String, RecordApi>>,
pub record_apis: HashMap<String, RecordApi>,
/// Denormalized metadata. We could also grab this from:
/// `record_apis.read().nth(0).unwrap().connection_metadata()`.
pub connection_metadata: RwLock<Arc<ConnectionMetadata>>,
pub connection_metadata: Arc<ConnectionMetadata>,
/// Should be the same as for all `record_apis` above.
pub conn: Arc<trailbase_sqlite::Connection>,
/// Map from table name to row id to list of subscriptions.
///
/// NOTE: Use layered locking to allow cleaning up per-table subscriptions w/o having to
/// exclusively lock the entire map.
pub subscriptions: RwLock<HashMap</* table_name= */ QualifiedName, RwLock<Subscriptions>>>,
pub subscriptions: HashMap</* table_name= */ QualifiedName, Subscriptions>,
}
impl PerConnectionStateInternal {
pub fn remove_subscription2(&mut self, id: SubscriptionId) {
let Some(subscriptions) = self.subscriptions.get_mut(&id.table_name) else {
return;
};
if let Some(row_id) = id.row_id {
if let Some(record_subscriptions) = subscriptions.record.get_mut(&row_id) {
record_subscriptions.retain(|sub| {
return sub.id.sub_id != id.sub_id;
});
if record_subscriptions.is_empty() {
subscriptions.record.remove(&row_id);
}
}
} else {
subscriptions.table.retain(|sub| {
return sub.id.sub_id != id.sub_id;
});
}
if subscriptions.is_empty() {
self.subscriptions.remove(&id.table_name);
if self.subscriptions.is_empty() {
uninstall_hook(&self.conn);
}
}
}
}
pub struct PerConnectionState {
pub state: Mutex<PerConnectionStateInternal>,
}
impl PerConnectionState {
fn lookup_record_api(&self, name: &str) -> Option<RecordApi> {
return self.record_apis.read().get(name).cloned();
}
// Gets called by the Stream destructor, e.g. when a client disconnects.
fn remove_subscription(&self, conn: &rusqlite::Connection, id: SubscriptionId) {
let mut read_lock = self.subscriptions.upgradable_read();
let remove_subscription_entry_for_table = {
let Some(mut subscriptions) = read_lock.get(&id.table_name).map(|l| l.write()) else {
return;
};
if let Some(row_id) = id.row_id {
if let Some(record_subscriptions) = subscriptions.record.get_mut(&row_id) {
record_subscriptions.retain(|sub| {
return sub.subscription_id != id.sub_id;
});
if record_subscriptions.is_empty() {
subscriptions.record.remove(&row_id);
}
}
} else {
subscriptions.table.retain(|sub| {
return sub.subscription_id != id.sub_id;
});
}
subscriptions.is_empty()
};
if remove_subscription_entry_for_table {
let table_name = &id.table_name;
// NOTE: Only write lock across all tables when necessary.
read_lock.with_upgraded(|lock| {
// Check again to avoid races:
if lock.get(table_name).is_some_and(|e| e.read().is_empty()) {
lock.remove(table_name);
if lock.is_empty() {
uninstall_hook_rusqlite(conn);
}
}
});
}
}
fn add_hook(self: &Arc<Self>, api: RecordApi) {
let conn = (**api.conn()).clone();
let conn = api.conn().clone();
let state = self.clone();
let receiver = install_hook(&conn).to_async();
let receiver = install_hook(&conn);
tokio::spawn(async move {
loop {
if receiver.sender_count() == 0 {
break;
}
let event = match receiver.recv().await {
Ok(event) => event,
Err(kanal::ReceiveError::Closed) | Err(kanal::ReceiveError::SendClosed) => {
// Spawn broker task.
if let Err(err) = std::thread::Builder::new()
.name("subscriptions".to_string())
.spawn(move || {
let mut expected = 1;
loop {
if receiver.sender_count() == 0 {
break;
}
};
let state = state.clone();
conn.call_reader_and_forget(move |conn| {
broker_event(conn, state, event);
});
}
let event = match receiver.recv() {
Ok((cnt, event)) => {
if cnt != expected {
// QUESTION: There's several ways we could deal with failure. We
// probably shouldn't create back pressure on the preupdate_hook and gunk up the
// SQLite access. We could try to deliver event loss messages to all receivers but
// that may just make the problem worse. We're probably at limit already
// if we don't manage to catch up. Should we just disconnect all subscriptions?
state.state.lock().subscriptions.clear();
break;
}
expected += 1;
debug!("Channel closed: terminating subscription broker task.");
});
event
}
Err(kanal::ReceiveError::Closed) | Err(kanal::ReceiveError::SendClosed) => {
break;
}
};
broker(&conn, &state, event);
}
debug!("Channel closed: terminating subscription broker task.");
})
{
log::error!("Failed to start subscription broker: {err}");
}
}
pub async fn add_record_subscription(
@@ -230,8 +245,8 @@ impl PerConnectionState {
api: RecordApi,
record: trailbase_sqlite::Value,
user: Option<User>,
sender: async_channel::Sender<Arc<EventPayload>>,
) -> Result<SubscriptionId, RecordError> {
sender: async_channel::Sender<EventCandidate>,
) -> Result<Arc<Subscription>, RecordError> {
let table_name = api.table_name();
let pk_column = &api.record_pk_column().column.name;
@@ -248,25 +263,32 @@ impl PerConnectionState {
};
let qualified_name = api.qualified_name();
let subscription_id = SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst);
let install_hook: bool = {
let mut lock = self.subscriptions.write();
let empty = lock.is_empty();
let sender = sender.clone();
let subscription_entry = Arc::new(Subscription {
id: SubscriptionId {
table_name: qualified_name.clone(),
row_id: Some(row_id),
sub_id: SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst),
},
record_api_name: api.api_name().to_string(),
user,
sender,
filter: Filter::Passthrough,
candidate_seq: AtomicI64::default(),
});
let subscriptions = lock.entry(qualified_name.clone()).or_default();
let install_hook: bool = {
let mut lock = self.state.lock();
let empty = lock.subscriptions.is_empty();
let subscriptions = lock
.subscriptions
.entry(qualified_name.clone())
.or_default();
subscriptions
.write()
.record
.entry(row_id)
.or_default()
.push(Subscription {
subscription_id,
record_api_name: api.api_name().to_string(),
user,
sender,
filter: Filter::Passthrough,
});
.push(subscription_entry.clone());
empty
};
@@ -275,11 +297,7 @@ impl PerConnectionState {
self.add_hook(api.clone());
}
return Ok(SubscriptionId {
table_name: qualified_name.clone(),
row_id: Some(row_id),
sub_id: subscription_id,
});
return Ok(subscription_entry);
}
pub async fn add_table_subscription(
@@ -287,8 +305,8 @@ impl PerConnectionState {
api: RecordApi,
user: Option<User>,
filter: Option<ValueOrComposite>,
sender: async_channel::Sender<Arc<EventPayload>>,
) -> Result<SubscriptionId, RecordError> {
sender: async_channel::Sender<EventCandidate>,
) -> Result<Arc<Subscription>, RecordError> {
let filter = if let Some(filter) = filter {
Filter::Record(qs_filter_to_record_filter(api.columns(), filter)?)
} else {
@@ -296,20 +314,29 @@ impl PerConnectionState {
};
let qualified_name = api.qualified_name();
let subscription_id = SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst);
let install_hook: bool = {
let mut lock = self.subscriptions.write();
let empty = lock.is_empty();
let sender = sender.clone();
let subscription_entry = Arc::new(Subscription {
id: SubscriptionId {
table_name: qualified_name.clone(),
row_id: None,
sub_id: SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst),
},
record_api_name: api.api_name().to_string(),
user,
sender,
filter,
candidate_seq: AtomicI64::default(),
});
let subscriptions = lock.entry(qualified_name.clone()).or_default();
subscriptions.write().table.push(Subscription {
subscription_id,
record_api_name: api.api_name().to_string(),
user,
sender,
filter,
});
let install_hook: bool = {
let mut lock = self.state.lock();
let empty = lock.subscriptions.is_empty();
let subscriptions = lock
.subscriptions
.entry(qualified_name.clone())
.or_default();
subscriptions.table.push(subscription_entry.clone());
empty
};
@@ -318,94 +345,51 @@ impl PerConnectionState {
self.add_hook(api.clone());
}
return Ok(SubscriptionId {
table_name: qualified_name.clone(),
row_id: None,
sub_id: subscription_id,
});
return Ok(subscription_entry);
}
}
impl Drop for PerConnectionState {
fn drop(&mut self) {
if let Some(first) = self.record_apis.read().values().nth(0) {
uninstall_hook(first.conn());
}
uninstall_hook(&self.state.lock().conn);
}
}
fn broker_subscriptions(
s: &PerConnectionState,
conn: &rusqlite::Connection,
subs: &[Subscription],
record_subscriptions: bool,
record: &indexmap::IndexMap<&str, rusqlite::types::Value>,
subs: &[Arc<Subscription>],
record: &Arc<indexmap::IndexMap<String, rusqlite::types::Value>>,
event: &Arc<EventPayload>,
) -> Vec<usize> {
let mut dead_subscriptions: Vec<usize> = vec![];
for (idx, sub) in subs.iter().enumerate() {
// Skip events for records that are being filtered out anyway.
if let Filter::Record(ref filter) = sub.filter
&& !apply_filter_recursively_to_record(filter, record)
{
continue;
}
// We don't memoize and eagerly look up the APIs to make sure we get an up-to-date version.
let Some(api) = s.lookup_record_api(&sub.record_api_name) else {
dead_subscriptions.push(idx);
sub.sender.close();
continue;
};
if let Err(_err) = api.check_record_level_read_access_for_subscriptions(
conn,
SubscriptionAclParams {
params: record,
user: sub.user.as_ref(),
},
) {
// NOTE: that access failures for table subscriptions for specific records are simply ignored,
// i.e. those events will just not be send. Other records in the table may pass the
// check. For record subscriptions, however, missing access is a death sentence.
if record_subscriptions {
// This can happen if the record api configuration has changed since originally
// subscribed. In this case we just send and error and cancel the subscription.
match sub.sender.try_send(ACCESS_DENIED_EVENT.clone()) {
Ok(_) | Err(TrySendError::Full(_)) => {
sub.sender.close();
return subs
.iter()
.enumerate()
.filter_map(|(idx, sub)| {
// Cloning the event. It's important that we use a try_send here to not block other
// subscriptions if a subscriber is slow and their channel fills up.
if let Err(err) = sub.sender.try_send(EventCandidate {
record: Some(record.clone()),
payload: event.clone(),
seq: sub.candidate_seq.fetch_add(1, Ordering::SeqCst),
}) {
match err {
async_channel::TrySendError::Full(ev) => {
debug!("Channel full, dropping event: {ev:?}");
}
async_channel::TrySendError::Closed(_ev) => {
return Some(idx);
}
Err(TrySendError::Closed(_)) => {}
};
dead_subscriptions.push(idx);
}
continue;
}
// Cloning the event. It's important that we use a try_send here to not block other
// subscriptions if a subscriber is slow and their channel fills up.
if let Err(err) = sub.sender.try_send(event.clone()) {
match err {
async_channel::TrySendError::Full(ev) => {
debug!("Channel full, dropping event: {ev:?}");
}
async_channel::TrySendError::Closed(_ev) => {
dead_subscriptions.push(idx);
sub.sender.close();
}
}
}
}
};
return dead_subscriptions;
return None;
})
.collect();
}
/// Broker event to various subscriptions.
fn broker_event(
conn: &rusqlite::Connection,
state: Arc<PerConnectionState>,
fn broker(
conn: &trailbase_sqlite::Connection,
state: &Arc<PerConnectionState>,
event: PreupdateHookEvent,
) {
let PreupdateHookEvent {
@@ -415,153 +399,75 @@ fn broker_event(
record,
} = event;
let mut per_table_subscriptions = state.subscriptions.upgradable_read();
let mut state = state.state.lock();
// If table_metadata is missing, the config/schema must have changed, thus removing the
// subscriptions.
let connection_metadata_lock = state.connection_metadata.read();
let Some(table_metadata) = connection_metadata_lock.get_table(&table_name) else {
let connection_metadata = state.connection_metadata.clone();
let Some(table_metadata) = connection_metadata.get_table(&table_name) else {
warn!("Table {table_name:?} not found. Removing subscriptions");
per_table_subscriptions.with_upgraded(|lock| {
lock.remove(&table_name);
if lock.is_empty() {
uninstall_hook_rusqlite(conn);
}
});
state.subscriptions.remove(&table_name);
if state.subscriptions.is_empty() {
uninstall_hook(conn);
}
return;
};
let remove_subscription_entry_for_table = {
// Check if there are any matching subscriptions and otherwise go back to listening.
let Some(mut subscriptions) = per_table_subscriptions
.get(&table_name)
.map(|r| r.upgradable_read())
else {
return;
};
if subscriptions.table.is_empty() && !subscriptions.record.contains_key(&row_id) {
return;
}
// Check if there are any matching subscriptions and otherwise go back to listening.
let Some(subscriptions) = state.subscriptions.get_mut(&table_name) else {
return;
};
if subscriptions.table.is_empty() && !subscriptions.record.contains_key(&row_id) {
return;
}
// Join values with column names. We use a map rather than a Vec<(String, Value)> for filter
// access.
let record: indexmap::IndexMap<&str, rusqlite::types::Value> = record
// Join values with column names. We use a map rather than a Vec<(String, Value)> for filter
// access.
let record: Arc<indexmap::IndexMap<String, rusqlite::types::Value>> = Arc::new(
record
.into_iter()
.enumerate()
.map(|(idx, v)| (table_metadata.schema.columns[idx].name.as_str(), v))
.map(|(idx, v)| (table_metadata.schema.columns[idx].name.clone(), v))
.collect(),
);
// Build a JSON-encoded SQLite event (insert, update, delete).
let event: Arc<EventPayload> = {
let json_obj = record
.iter()
.filter_map(|(name, value)| {
return value_to_flat_json(value)
.ok()
.map(|v| (name.to_string(), v));
})
.collect();
// Build a JSON-encoded SQLite event (insert, update, delete).
let event: Arc<EventPayload> = {
let json_obj = record
.iter()
.filter_map(|(name, value)| {
return value_to_flat_json(value)
.ok()
.map(|v| (name.to_string(), v));
})
.collect();
let payload = EventPayload::from(&match action {
RecordAction::Delete => JsonEventPayload::Delete { value: json_obj },
RecordAction::Insert => JsonEventPayload::Insert { value: json_obj },
RecordAction::Update => JsonEventPayload::Update { value: json_obj },
});
Arc::new(payload)
};
// First broker record subscriptions.
let (dead_record_subscriptions, dead_table_subscriptions) = {
// let Some(subscriptions) = per_table_subscriptions.get(&table_name).map(|r| r.read()) else {
// return;
// };
let dead_record_subscriptions = subscriptions
.record
.get(&row_id)
.map(|record_subscriptions| {
broker_subscriptions(&state, conn, record_subscriptions, true, &record, &event)
})
.unwrap_or_default();
// Then broker table subscriptions.
let dead_table_subscriptions =
broker_subscriptions(&state, conn, &subscriptions.table, false, &record, &event);
(dead_record_subscriptions, dead_table_subscriptions)
};
if dead_record_subscriptions.is_empty()
&& dead_table_subscriptions.is_empty()
&& action != RecordAction::Delete
{
// No cleanup needed.
return;
}
subscriptions.with_upgraded(|subscriptions| {
// Record subscription cleanup.
match action {
RecordAction::Delete => {
// This is unique for record subscriptions: if the record is deleted, cancel all
// subscriptions.
subscriptions.record.remove(&row_id);
}
RecordAction::Update | RecordAction::Insert => {
if let Some(m) = subscriptions.record.get_mut(&row_id) {
for idx in dead_record_subscriptions.iter().rev() {
m.swap_remove(*idx);
}
if m.is_empty() {
subscriptions.record.remove(&row_id);
}
}
}
}
// Table subscription cleanup.
for idx in dead_table_subscriptions.iter().rev() {
subscriptions.table.swap_remove(*idx);
}
/* remove_subscription_entry_for_table = */
subscriptions.is_empty()
})
Arc::new(EventPayload::from(&match action {
RecordAction::Delete => JsonEventPayload::Delete { value: json_obj },
RecordAction::Insert => JsonEventPayload::Insert { value: json_obj },
RecordAction::Update => JsonEventPayload::Update { value: json_obj },
}))
};
if remove_subscription_entry_for_table {
// NOTE: Only write lock across all tables when necessary.
per_table_subscriptions.with_upgraded(|lock| {
// Check again to avoid races:
if lock.get(&table_name).is_some_and(|e| e.read().is_empty()) {
lock.remove(&table_name);
// First broker record subscriptions.
if let Some(record_subscriptions) = subscriptions.record.get_mut(&row_id) {
let dead = broker_subscriptions(record_subscriptions, &record, &event);
if lock.is_empty() {
uninstall_hook_rusqlite(conn);
}
}
});
for idx in dead.iter().rev() {
record_subscriptions.remove(*idx);
}
}
// Then broker table subscriptions.
let dead = broker_subscriptions(&subscriptions.table, &record, &event);
for idx in dead.iter().rev() {
subscriptions.table.remove(*idx);
}
if subscriptions.is_empty() {
uninstall_hook(conn);
}
}
static SUBSCRIPTION_COUNTER: AtomicI64 = AtomicI64::new(0);
static ACCESS_DENIED_EVENT: LazyLock<Arc<EventPayload>> = LazyLock::new(|| {
Arc::new(EventPayload::from(&JsonEventPayload::Error {
error: "Access denied".into(),
}))
});
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn static_sse_event_test() {
let _x: Arc<EventPayload> = (*ACCESS_DENIED_EVENT).clone();
}
}
+164 -95
View File
@@ -1,7 +1,9 @@
use async_channel::TryRecvError;
use axum::extract::{Path, RawQuery, State};
use futures_util::StreamExt;
use http_body_util::BodyExt;
use serde_json::Value;
use std::sync::Arc;
use std::sync::atomic::{AtomicI64, Ordering};
use trailbase_sqlite::params;
use crate::User;
@@ -9,10 +11,12 @@ use crate::admin::user::*;
use crate::app_state::{AppState, test_state};
use crate::auth::util::login_with_password;
use crate::config::proto::RecordApiConfig;
use crate::records::subscribe::event::{JsonEventPayload, deserialize_event};
use crate::records::subscribe::handler::{SubscriptionQuery, add_subscription_sse_and_ws_handler};
use crate::records::subscribe::event::{EventErrorStatus, TestChangeEvent, TestJsonEventPayload};
use crate::records::subscribe::handler::{
SubscriptionQuery, add_subscription_sse_and_ws_handler, subscribe_sse,
};
use crate::records::test_utils::add_record_api_config;
use crate::records::{PermissionFlag, RecordError};
use crate::records::{PermissionFlag, RecordApi, RecordError};
use crate::util::uuid_to_b64;
async fn setup_world_readable() -> AppState {
@@ -67,10 +71,7 @@ async fn subscribe_to_record_test() {
let manager = state.subscription_manager();
let api = state.lookup_record_api("api_name").unwrap();
let stream = manager
.add_sse_record_subscription(api, trailbase_sqlite::Value::Integer(0), None)
.await
.unwrap();
let mut stream = subscribe_to_records(state.clone(), api, "0", None, /* filter= */ None).await;
assert_eq!(1, manager.num_record_subscriptions());
@@ -87,8 +88,8 @@ async fn subscribe_to_record_test() {
// First event is "connection established".
assert!(matches!(
deserialize_event(stream.receiver.recv().await.unwrap()).unwrap(),
JsonEventPayload::Ping
stream.next().await.unwrap().event,
TestJsonEventPayload::Ping
));
// This should do nothing since nobody is subscribed to id = 5.
@@ -112,9 +113,9 @@ async fn subscribe_to_record_test() {
"id": record_id_raw,
"text": "bar",
});
match deserialize_event(stream.receiver.recv().await.unwrap()).unwrap() {
JsonEventPayload::Update { value: obj } => {
assert_eq!(Value::Object(obj), expected);
match stream.next().await.unwrap().event {
TestJsonEventPayload::Update(obj) => {
assert_eq!(Value::Object(obj.clone()), expected);
}
x => {
panic!("Expected update, got: {x:?}");
@@ -126,9 +127,9 @@ async fn subscribe_to_record_test() {
.await
.unwrap();
match deserialize_event(stream.receiver.recv().await.unwrap()).unwrap() {
JsonEventPayload::Delete { value: obj } => {
assert_eq!(Value::Object(obj), expected);
match stream.next().await.unwrap().event {
TestJsonEventPayload::Delete(obj) => {
assert_eq!(Value::Object(obj.clone()), expected);
}
x => {
panic!("Expected delete, got: {x:?}");
@@ -142,6 +143,67 @@ async fn subscribe_to_record_test() {
.unwrap();
}
async fn subscribe_to_records(
state: AppState,
api: RecordApi,
record: &str,
user: Option<User>,
filter: Option<&str>,
// ) -> kanal::AsyncReceiver<TestChangeEvent> {
) -> std::pin::Pin<Box<dyn futures_util::Stream<Item = TestChangeEvent>>> {
let filter = filter.map(|f| SubscriptionQuery::parse(f).unwrap().filter.unwrap());
let response = subscribe_sse(state, api, record.to_string(), filter, user)
.await
.unwrap();
assert_eq!(200, response.status());
let cnt = Arc::new(AtomicI64::default());
let stream = response.into_data_stream();
return stream
.take_while(|bytes| std::future::ready(bytes.is_ok()))
.filter_map(move |bytes| {
let cnt = cnt.clone();
return async move {
let Ok(bytes) = bytes.as_ref() else {
return None;
};
let payload = String::from_utf8_lossy(&bytes).to_string();
cnt.fetch_add(1, Ordering::SeqCst);
if cnt.load(Ordering::SeqCst) == 1 {
// Make sure we have an explicit ping as a first message to establish connection.
assert!(payload.contains("ping"));
return Some(TestChangeEvent {
event: TestJsonEventPayload::Ping,
seq: None,
});
}
// Ignore heartbeats.
if let Some((_a, b)) = payload.split_once("data: ") {
return Some(serde_json::from_str(b).unwrap());
}
return None;
};
})
.boxed();
}
async fn take_test_events(
stream: std::pin::Pin<Box<dyn futures_util::Stream<Item = TestChangeEvent>>>,
n: usize,
) -> Vec<TestChangeEvent> {
use tokio::time::*;
return timeout(Duration::from_secs(4), stream.take(n).collect())
.await
.unwrap();
}
#[tokio::test]
async fn subscribe_to_table_test() {
let state = setup_world_readable().await;
@@ -151,16 +213,13 @@ async fn subscribe_to_table_test() {
let api = state.lookup_record_api("api_name").unwrap();
{
let stream = manager
.add_sse_table_subscription(api, None, None)
.await
.unwrap();
let mut stream = subscribe_to_records(state.clone(), api, "*", None, /* filter= */ None).await;
assert_eq!(1, manager.num_table_subscriptions());
// First event is "connection established".
assert!(matches!(
deserialize_event(stream.receiver.recv().await.unwrap()).unwrap(),
JsonEventPayload::Ping
stream.next().await.unwrap().event,
TestJsonEventPayload::Ping
));
let record_id_raw = 0;
@@ -180,13 +239,13 @@ async fn subscribe_to_table_test() {
.await
.unwrap();
match deserialize_event(stream.receiver.recv().await.unwrap()).unwrap() {
JsonEventPayload::Insert { value: obj } => {
match stream.next().await.unwrap().event {
TestJsonEventPayload::Insert(obj) => {
let expected = serde_json::json!({
"id": record_id_raw,
"text": "foo",
});
assert_eq!(Value::Object(obj), expected);
assert_eq!(Value::Object(obj.clone()), expected);
}
x => {
panic!("Expected insert, got: {x:?}");
@@ -197,9 +256,9 @@ async fn subscribe_to_table_test() {
"id": record_id_raw,
"text": "bar",
});
match deserialize_event(stream.receiver.recv().await.unwrap()).unwrap() {
JsonEventPayload::Update { value: obj } => {
assert_eq!(Value::Object(obj), expected);
match stream.next().await.unwrap().event {
TestJsonEventPayload::Update(obj) => {
assert_eq!(Value::Object(obj.clone()), expected);
}
x => {
panic!("Expected update, got: {x:?}");
@@ -211,9 +270,9 @@ async fn subscribe_to_table_test() {
.await
.unwrap();
match deserialize_event(stream.receiver.recv().await.unwrap()).unwrap() {
JsonEventPayload::Delete { value: obj } => {
assert_eq!(Value::Object(obj), expected);
match stream.next().await.unwrap().event {
TestJsonEventPayload::Delete(obj) => {
assert_eq!(Value::Object(obj.clone()), expected);
}
x => {
panic!("Expected delete, got: {x:?}");
@@ -430,29 +489,29 @@ async fn test_acl_selective_table_subs() {
// Assert events for table subscriptions are selective on ACLs.
{
let user_x_subscription = manager
.add_sse_table_subscription(
api.clone(),
User::from_auth_token(&state, &user_x_token.auth_token),
None,
)
.await
.unwrap();
let mut user_x_subscription = subscribe_to_records(
state.clone(),
api.clone(),
"*",
User::from_auth_token(&state, &user_x_token.auth_token),
/* filter= */ None,
)
.await;
// First event is "connection established".
assert!(matches!(
deserialize_event(user_x_subscription.receiver.recv().await.unwrap()).unwrap(),
JsonEventPayload::Ping
user_x_subscription.next().await.unwrap().event,
TestJsonEventPayload::Ping
));
let user_y_subscription = manager
.add_sse_table_subscription(
api.clone(),
User::from_auth_token(&state, &user_y_token.auth_token),
None,
)
.await
.unwrap();
let mut user_y_subscription = subscribe_to_records(
state.clone(),
api.clone(),
"*",
User::from_auth_token(&state, &user_y_token.auth_token),
/* filter= */ None,
)
.await;
assert_eq!(2, manager.num_table_subscriptions());
@@ -468,14 +527,14 @@ async fn test_acl_selective_table_subs() {
.await
.unwrap();
match deserialize_event(user_x_subscription.receiver.recv().await.unwrap()).unwrap() {
JsonEventPayload::Insert { value: obj } => {
match user_x_subscription.next().await.unwrap().event {
TestJsonEventPayload::Insert(obj) => {
let expected = serde_json::json!({
"id": record_id_raw,
"user": uuid_to_b64(&user_x),
"text": "foo",
});
assert_eq!(Value::Object(obj), expected);
assert_eq!(Value::Object(obj.clone()), expected);
}
x => {
panic!("Expected insert, got: {x:?}");
@@ -483,18 +542,26 @@ async fn test_acl_selective_table_subs() {
};
// User y should *not* have received the insert event.
assert!(
tokio::time::timeout(
tokio::time::Duration::from_millis(300),
user_y_subscription.receiver.clone().count()
)
.await
.is_err()
);
assert_eq!(
user_y_subscription.receiver.try_recv().err().unwrap(),
TryRecvError::Empty
);
// assert!(
// tokio::time::timeout(
// tokio::time::Duration::from_millis(300),
// user_y_subscription.receiver.clone().count()
// )
// .await
// .is_err()
// );
// assert_eq!(
// user_y_subscription.next().await.unwrap(),
// TryRecvError::Empty
// );
assert!(matches!(
user_y_subscription.next().await.unwrap().event,
TestJsonEventPayload::Ping
));
use tokio::time::*;
let got = timeout(Duration::from_millis(100), user_y_subscription.next()).await;
assert!(got.is_err(), "Got: {got:?}");
}
// Implicitly await for scheduled cleanups to go through.
@@ -538,16 +605,20 @@ async fn subscription_acl_change_owner() {
let manager = state.subscription_manager();
let api = state.lookup_record_api("api_name").unwrap();
let stream = manager
.add_sse_record_subscription(api, trailbase_sqlite::Value::Integer(record_id), user_x)
.await
.unwrap();
let mut stream = subscribe_to_records(
state.clone(),
api,
&record_id.to_string(),
user_x,
/* filter= */ None,
)
.await;
assert_eq!(1, manager.num_record_subscriptions());
// First event is "connection established".
assert!(matches!(
deserialize_event(stream.receiver.recv().await.unwrap()).unwrap(),
JsonEventPayload::Ping
stream.next().await.unwrap().event,
TestJsonEventPayload::Ping,
));
conn
@@ -567,8 +638,8 @@ async fn subscription_acl_change_owner() {
.await
.unwrap();
match deserialize_event(stream.receiver.recv().await.unwrap()).unwrap() {
JsonEventPayload::Update { value: obj } => {
match stream.next().await.unwrap().event {
TestJsonEventPayload::Update(obj) => {
let expected = serde_json::json!({
"id": record_id,
"user": uuid_to_b64(&user_x_id),
@@ -581,20 +652,19 @@ async fn subscription_acl_change_owner() {
}
}
match deserialize_event(stream.receiver.recv().await.unwrap()).unwrap() {
JsonEventPayload::Error { .. } => {}
match stream.next().await.unwrap().event {
TestJsonEventPayload::Error { status, .. } => {
assert_eq!(EventErrorStatus::Forbidden, status);
}
x => {
panic!("Expected error, got: {x:?}");
}
}
conn
.read_query_row_f("SELECT 1", (), |row| row.get::<_, i64>(0))
.await
.unwrap();
drop(stream);
// Make sure the subscription was cleaned up after the access error.
assert!(stream.receiver.is_closed());
// assert!(stream.is_closed());
assert_eq!(0, manager.num_record_subscriptions());
}
@@ -607,22 +677,17 @@ async fn subscription_filter_test() {
let api = state.lookup_record_api("api_name").unwrap();
{
let filter =
SubscriptionQuery::parse("filter[$and][0][id][$gt]=5&filter[$and][1][id][$lt]=100").unwrap();
let stream = manager
.add_sse_table_subscription(api, None, Some(filter.filter.unwrap()))
.await
.unwrap();
let stream = subscribe_to_records(
state.clone(),
api.clone(),
"*",
/* user= */ None,
Some("filter[$and][0][id][$gt]=5&filter[$and][1][id][$lt]=100"),
)
.await;
assert_eq!(1, manager.num_table_subscriptions());
// First event is "connection established".
assert!(matches!(
deserialize_event(stream.receiver.recv().await.unwrap()).unwrap(),
JsonEventPayload::Ping
));
// This one should be filter out.
conn
.execute("INSERT INTO test (id, text) VALUES ($1, 'foo')", params!(1))
.await
@@ -637,13 +702,17 @@ async fn subscription_filter_test() {
.await
.unwrap();
match deserialize_event(stream.receiver.recv().await.unwrap()).unwrap() {
JsonEventPayload::Insert { value: obj } => {
let events = take_test_events(stream, 2).await;
assert!(matches!(events[0].event, TestJsonEventPayload::Ping));
match &events[1].event {
TestJsonEventPayload::Insert(obj) => {
let expected = serde_json::json!({
"id": 25,
"text": "foo",
});
assert_eq!(Value::Object(obj), expected);
assert_eq!(Value::Object(obj.clone()), expected);
}
x => {
panic!("Expected insert, got: {x:?}");