Add support for table/recordapi-wide "realtime" subscriptions, i.e. insertions, updates, and deletions.

Remove subscriptions for missing table.
This commit is contained in:
Sebastian Jeltsch
2025-01-11 23:00:47 +01:00
parent 30f295e6fd
commit e2b0c0d05e

View File

@@ -28,7 +28,6 @@ static SUBSCRIPTION_COUNTER: AtomicI64 = AtomicI64::new(0);
// TODO:
// * clients
// * table-wide subscriptions
// * optimize: avoid repeated encoding of events. Easy to do but makes testing harder since there's
// no good way to parse sse::Event back :/. We should probably just bite the bullet and parse,
// it's literally "data: <json>\n\n".
@@ -61,12 +60,6 @@ pub enum DbEvent {
Error(String),
}
// pub struct SubscriptionId {
// table_name: String,
// row_id: i64,
// subscription_id: i64,
// }
pub struct Subscription {
/// Id uniquely identifying this subscription.
subscription_id: i64,
@@ -92,7 +85,10 @@ struct ManagerState {
record_apis: Computed<Vec<(String, RecordApi)>, crate::config::proto::Config>,
/// Map from table name to row id to list of subscriptions.
subscriptions: RwLock<HashMap<String, HashMap<i64, Vec<Subscription>>>>,
record_subscriptions: RwLock<HashMap<String, HashMap<i64, Vec<Subscription>>>>,
/// Map from table name to table subscriptions.
table_subscriptions: RwLock<HashMap<String, Vec<Subscription>>>,
}
impl ManagerState {
@@ -113,7 +109,7 @@ pub struct SubscriptionManager {
struct ContinuationState {
state: Arc<ManagerState>,
table_metadata: Arc<TableMetadata>,
table_metadata: Option<Arc<TableMetadata>>,
action: RecordAction,
table_name: String,
rowid: i64,
@@ -132,14 +128,15 @@ impl SubscriptionManager {
table_metadata,
record_apis,
subscriptions: RwLock::new(HashMap::new()),
record_subscriptions: RwLock::new(HashMap::new()),
table_subscriptions: RwLock::new(HashMap::new()),
}),
};
}
pub fn num_subscriptions(&self) -> usize {
pub fn num_record_subscriptions(&self) -> usize {
let mut count: usize = 0;
for table in self.state.subscriptions.read().values() {
for table in self.state.record_subscriptions.read().values() {
for record in table.values() {
count += record.len();
}
@@ -147,7 +144,50 @@ impl SubscriptionManager {
return count;
}
/// Preupdate hook that runs in a continuation of the trailbase-sqlite executor.
fn broker_subscriptions(
s: &ManagerState,
conn: &rusqlite::Connection,
subs: &[Subscription],
record: &[(&str, rusqlite::types::Value)],
event: &DbEvent,
) -> Vec<usize> {
let mut dead_subscriptions: Vec<usize> = vec![];
for (idx, sub) in subs.iter().enumerate() {
let Some(api) = s.lookup_record_api(&sub.record_api_name) else {
dead_subscriptions.push(idx);
continue;
};
if let Err(_err) = api.check_record_level_read_access(
conn,
Permission::Read,
// TODO: Maybe we could inject ValueRef instead to avoid repeated cloning.
record.to_owned(),
sub.user.as_ref(),
) {
// This can happen if the record api configuration has changed since originally
// subscribed. In this case we just send and error and cancel the subscription.
let _ = sub.channel.try_send(DbEvent::Error("Access denied".into()));
dead_subscriptions.push(idx);
continue;
}
// TODO: Avoid cloning the event/record over and over.
match sub.channel.try_send(event.clone()) {
Ok(_) => {}
Err(async_channel::TrySendError::Full(ev)) => {
log::warn!("Channel full, dropping event: {ev:?}");
}
Err(async_channel::TrySendError::Closed(_ev)) => {
dead_subscriptions.push(idx);
}
}
}
return dead_subscriptions;
}
/// Continuation of the preupdate hook being scheduled on the executor.
fn hook_continuation(conn: &rusqlite::Connection, state: ContinuationState) {
let ContinuationState {
state,
@@ -157,11 +197,24 @@ impl SubscriptionManager {
rowid,
record_values,
} = state;
let s = &state;
let table_name = table_name.as_str();
// If table_metadata is missing, the config/schema must have changed, thus removing the
// subscriptions.
let Some(table_metadata) = table_metadata else {
log::warn!("Table not found: {table_name}. Removing subscriptions");
let mut record_subs = s.record_subscriptions.write();
record_subs.remove(table_name);
let mut table_subs = s.table_subscriptions.write();
table_subs.remove(table_name);
if record_subs.is_empty() && table_subs.is_empty() {
conn.preupdate_hook(NO_HOOK);
}
let mut read_lock = s.subscriptions.upgradable_read();
let Some(subs) = read_lock.get(&table_name).and_then(|m| m.get(&rowid)) else {
return;
};
@@ -193,80 +246,84 @@ impl SubscriptionManager {
}
};
let mut dead_subscriptions: Vec<usize> = vec![];
for (idx, sub) in subs.iter().enumerate() {
let Some(api) = s.lookup_record_api(&sub.record_api_name) else {
dead_subscriptions.push(idx);
continue;
'record_subs: {
let mut read_lock = s.record_subscriptions.upgradable_read();
let Some(subs) = read_lock.get(table_name).and_then(|m| m.get(&rowid)) else {
break 'record_subs;
};
if let Err(_err) = api.check_record_level_read_access(
conn,
Permission::Read,
// TODO: Maybe we could inject ValueRef instead to avoid repeated cloning.
record.clone(),
sub.user.as_ref(),
) {
// This can happen if the record api configuration has changed since originally
// subscribed. In this case we just send and error and cancel the subscription.
let _ = sub.channel.try_send(DbEvent::Error("Access denied".into()));
dead_subscriptions.push(idx);
continue;
let dead_subscriptions = Self::broker_subscriptions(s, conn, subs, &record, &event);
if dead_subscriptions.is_empty() && action != RecordAction::Delete {
// No cleanup needed.
break 'record_subs;
}
// TODO: Avoid cloning the event/record over and over.
match sub.channel.try_send(event.clone()) {
Ok(_) => {}
Err(async_channel::TrySendError::Full(ev)) => {
log::warn!("Channel full, dropping event: {ev:?}");
}
Err(async_channel::TrySendError::Closed(_ev)) => {
dead_subscriptions.push(idx);
}
}
}
read_lock.with_upgraded(move |subscriptions| {
let Some(table_subscriptions) = subscriptions.get_mut(table_name) else {
return;
};
if dead_subscriptions.is_empty() && action != RecordAction::Delete {
// No cleanup needed.
return;
}
read_lock.with_upgraded(move |subscriptions| {
let Some(table_subscriptions) = subscriptions.get_mut(&table_name) else {
return;
};
if action == RecordAction::Delete {
// Also drops the channel and thus automatically closes the SSE connection.
table_subscriptions.remove(&rowid);
if table_subscriptions.is_empty() {
subscriptions.remove(&table_name);
if subscriptions.is_empty() {
conn.preupdate_hook(NO_HOOK);
}
}
return;
}
if let Some(m) = table_subscriptions.get_mut(&rowid) {
for idx in dead_subscriptions.iter().rev() {
m.swap_remove(*idx);
}
if m.is_empty() {
if action == RecordAction::Delete {
// Also drops the channel and thus automatically closes the SSE connection.
table_subscriptions.remove(&rowid);
if table_subscriptions.is_empty() {
subscriptions.remove(&table_name);
if subscriptions.is_empty() {
subscriptions.remove(table_name);
if subscriptions.is_empty() && s.table_subscriptions.read().is_empty() {
conn.preupdate_hook(NO_HOOK);
}
}
return;
}
if let Some(m) = table_subscriptions.get_mut(&rowid) {
for idx in dead_subscriptions.iter().rev() {
m.swap_remove(*idx);
}
if m.is_empty() {
table_subscriptions.remove(&rowid);
if table_subscriptions.is_empty() {
subscriptions.remove(table_name);
if subscriptions.is_empty() && s.table_subscriptions.read().is_empty() {
conn.preupdate_hook(NO_HOOK);
}
}
}
}
});
}
'table_subs: {
let mut read_lock = s.table_subscriptions.upgradable_read();
let Some(subs) = read_lock.get(table_name) else {
break 'table_subs;
};
let dead_subscriptions = Self::broker_subscriptions(s, conn, subs, &record, &event);
if dead_subscriptions.is_empty() && action != RecordAction::Delete {
// No cleanup needed.
break 'table_subs;
}
});
read_lock.with_upgraded(move |subscriptions| {
let Some(table_subscriptions) = subscriptions.get_mut(table_name) else {
return;
};
for idx in dead_subscriptions.iter().rev() {
table_subscriptions.swap_remove(*idx);
}
if table_subscriptions.is_empty() {
if subscriptions.is_empty() && s.record_subscriptions.read().is_empty() {
conn.preupdate_hook(NO_HOOK);
}
}
});
}
}
async fn add_hook(&self) -> trailbase_sqlite::connection::Result<()> {
@@ -294,22 +351,17 @@ impl SubscriptionManager {
};
// If there are no subscriptions, do nothing.
if s
.subscriptions
let record_subs_candidate = s
.record_subscriptions
.read()
.get(table_name)
.and_then(|m| m.get(&rowid))
.is_none()
{
.is_some();
let table_subs_candidate = s.table_subscriptions.read().get(table_name).is_some();
if !record_subs_candidate && !table_subs_candidate {
return;
}
let Some(table_metadata) = s.table_metadata.get(table_name) else {
// TODO: Should we cleanup here? Probably, since we won't recover from this issue.
log::error!("Table not found: {table_name}");
return;
};
let Some(record_values) = extract_record_values(case) else {
log::error!("Failed to extract values");
return;
@@ -317,7 +369,7 @@ impl SubscriptionManager {
let state = ContinuationState {
state: s.clone(),
table_metadata,
table_metadata: s.table_metadata.get(table_name),
action,
table_name: table_name.to_string(),
rowid,
@@ -335,15 +387,12 @@ impl SubscriptionManager {
.await;
}
async fn add_subscription(
async fn add_record_subscription(
&self,
api: RecordApi,
record: Option<trailbase_sqlite::Value>,
record: trailbase_sqlite::Value,
user: Option<User>,
) -> Result<async_channel::Receiver<DbEvent>, RecordError> {
let Some(record) = record else {
return Err(RecordError::BadRequest("Missing record id"));
};
let (sender, receiver) = async_channel::bounded::<DbEvent>(16);
let table_name = api.table_name();
@@ -366,7 +415,7 @@ impl SubscriptionManager {
let subscription_id = SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst);
let empty = {
let mut lock = self.state.subscriptions.write();
let mut lock = self.state.record_subscriptions.write();
let empty = lock.is_empty();
let m: &mut HashMap<i64, Vec<Subscription>> = lock.entry(table_name.to_string()).or_default();
@@ -388,6 +437,38 @@ impl SubscriptionManager {
return Ok(receiver);
}
async fn add_table_subscription(
&self,
api: RecordApi,
user: Option<User>,
) -> Result<async_channel::Receiver<DbEvent>, RecordError> {
let (sender, receiver) = async_channel::bounded::<DbEvent>(16);
let table_name = api.table_name();
let subscription_id = SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst);
let empty = {
let mut lock = self.state.table_subscriptions.write();
let empty = lock.is_empty() && self.state.record_subscriptions.read().is_empty();
let m: &mut Vec<Subscription> = lock.entry(table_name.to_string()).or_default();
m.push(Subscription {
subscription_id,
record_api_name: api.api_name().to_string(),
user,
channel: sender,
});
empty
};
if empty {
self.add_hook().await.unwrap();
}
return Ok(receiver);
}
// TODO: Cleaning up subscriptions might be a thing, e.g. if SSE handlers had an onDisconnect
// handler. Right now we're handling cleanups reactively, i.e. we only remove subscriptions when
// sending new events and the receiving end of a handler channel became invalid. It would
@@ -427,29 +508,35 @@ pub async fn add_subscription_sse_handler(
return Err(RecordError::ApiNotFound);
};
let record_id = api.id_to_sql(&record)?;
fn encode(ev: DbEvent) -> Result<Event, axum::Error> {
// TODO: We're re-encoding the event over and over again for all subscriptions. Would be easy
// to pre-encode on the sender side but makes testing much harder, since there's no good way
// to parse sse::Event back.
return Event::default().json_data(ev);
}
let Ok(()) = api
.check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref())
.await
else {
return Err(RecordError::Forbidden);
};
if record == "*" {
api.check_table_level_access(Permission::Read, user.as_ref())?;
let receiver = state
.subscription_manager()
.add_subscription(api, Some(record_id), user)
.await?;
let receiver = state
.subscription_manager()
.add_table_subscription(api, user)
.await?;
return Ok(
Sse::new(receiver.map(|ev| {
// TODO: We're re-encoding the event over and over again for all subscriptions. Would be easy
// to pre-encode on the sender side but makes testing much harder, since there's no good way
// to parse sse::Event back.
return Event::default().json_data(ev);
}))
.keep_alive(KeepAlive::default()),
);
return Ok(Sse::new(receiver.map(encode)).keep_alive(KeepAlive::default()));
} else {
let record_id = api.id_to_sql(&record)?;
api
.check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref())
.await?;
let receiver = state
.subscription_manager()
.add_record_subscription(api, record_id, user)
.await?;
return Ok(Sse::new(receiver.map(encode)).keep_alive(KeepAlive::default()));
}
}
#[cfg(test)]
@@ -460,7 +547,7 @@ mod tests {
use crate::records::{add_record_api, AccessRules, Acls, PermissionFlag};
#[tokio::test]
async fn subscribe_connection_test() {
async fn subscribe_to_record_test() {
let state = test_state(None).await.unwrap();
let conn = state.conn().clone();
@@ -483,11 +570,7 @@ mod tests {
world: vec![PermissionFlag::Create, PermissionFlag::Read],
..Default::default()
},
AccessRules {
// read: Some("(_ROW_._owner = _USER_.id OR EXISTS(SELECT 1 FROM room_members WHERE room =
// _ROW_.room AND user = _USER_.id))".to_string()),
..Default::default()
},
AccessRules::default(),
)
.await
.unwrap();
@@ -510,11 +593,11 @@ mod tests {
let manager = state.subscription_manager();
let api = state.lookup_record_api("api_name").unwrap();
let receiver = manager
.add_subscription(api, Some(trailbase_sqlite::Value::Integer(0)), None)
.add_record_subscription(api, trailbase_sqlite::Value::Integer(0), None)
.await
.unwrap();
assert_eq!(1, manager.num_subscriptions());
assert_eq!(1, manager.num_record_subscriptions());
conn
.execute(
@@ -538,7 +621,7 @@ mod tests {
};
conn
.execute("DELETE FROM test WHERE _rowid_ = $2", params!(rowid))
.execute("DELETE FROM test WHERE _rowid_ = $1", params!(rowid))
.await
.unwrap();
@@ -551,7 +634,98 @@ mod tests {
}
}
assert_eq!(0, manager.num_subscriptions());
assert_eq!(0, manager.num_record_subscriptions());
}
#[tokio::test]
async fn subscribe_to_table_test() {
let state = test_state(None).await.unwrap();
let conn = state.conn().clone();
conn
.execute(
"CREATE TABLE test (id INTEGER PRIMARY KEY, text TEXT) STRICT",
(),
)
.await
.unwrap();
state.table_metadata().invalidate_all().await.unwrap();
// Register message table as record api with moderator read access.
add_record_api(
&state,
"api_name",
"test",
Acls {
world: vec![PermissionFlag::Create, PermissionFlag::Read],
..Default::default()
},
AccessRules::default(),
)
.await
.unwrap();
let manager = state.subscription_manager();
let api = state.lookup_record_api("api_name").unwrap();
let receiver = manager.add_table_subscription(api, None).await.unwrap();
let record_id_raw = 0;
conn
.query_row(
"INSERT INTO test (id, text) VALUES ($1, 'foo')",
params!(record_id_raw),
)
.await
.unwrap();
conn
.execute(
"UPDATE test SET text = $1 WHERE id = $2",
params!("bar", record_id_raw),
)
.await
.unwrap();
let expected = serde_json::json!({
"id": record_id_raw,
"text": "foo",
});
match receiver.recv().await.unwrap() {
DbEvent::Insert(Some(value)) => {
assert_eq!(value, expected);
}
x => {
assert!(false, "Expected update, got: {x:?}");
}
};
let expected = serde_json::json!({
"id": record_id_raw,
"text": "bar",
});
match receiver.recv().await.unwrap() {
DbEvent::Update(Some(value)) => {
assert_eq!(value, expected);
}
x => {
assert!(false, "Expected update, got: {x:?}");
}
};
conn
.execute("DELETE FROM test WHERE id = $1", params!(record_id_raw))
.await
.unwrap();
match receiver.recv().await.unwrap() {
DbEvent::Delete(Some(value)) => {
assert_eq!(value, expected);
}
x => {
assert!(false, "Expected update, got: {x:?}");
}
}
}
// TODO: Test actual SSE handler.