From e2b0c0d05e3068f18ca1d1704fc52942e5499034 Mon Sep 17 00:00:00 2001 From: Sebastian Jeltsch Date: Sat, 11 Jan 2025 23:00:47 +0100 Subject: [PATCH] Add support for table/recordapi-wide "realtime" subscriptions, i.e. insertions, updates, and deletions. Remove subscriptions for missing table. --- trailbase-core/src/records/subscribe.rs | 422 +++++++++++++++++------- 1 file changed, 298 insertions(+), 124 deletions(-) diff --git a/trailbase-core/src/records/subscribe.rs b/trailbase-core/src/records/subscribe.rs index 29bb51e6..b9d2b224 100644 --- a/trailbase-core/src/records/subscribe.rs +++ b/trailbase-core/src/records/subscribe.rs @@ -28,7 +28,6 @@ static SUBSCRIPTION_COUNTER: AtomicI64 = AtomicI64::new(0); // TODO: // * clients -// * table-wide subscriptions // * optimize: avoid repeated encoding of events. Easy to do but makes testing harder since there's // no good way to parse sse::Event back :/. We should probably just bite the bullet and parse, // it's literally "data: \n\n". @@ -61,12 +60,6 @@ pub enum DbEvent { Error(String), } -// pub struct SubscriptionId { -// table_name: String, -// row_id: i64, -// subscription_id: i64, -// } - pub struct Subscription { /// Id uniquely identifying this subscription. subscription_id: i64, @@ -92,7 +85,10 @@ struct ManagerState { record_apis: Computed, crate::config::proto::Config>, /// Map from table name to row id to list of subscriptions. - subscriptions: RwLock>>>, + record_subscriptions: RwLock>>>, + + /// Map from table name to table subscriptions. + table_subscriptions: RwLock>>, } impl ManagerState { @@ -113,7 +109,7 @@ pub struct SubscriptionManager { struct ContinuationState { state: Arc, - table_metadata: Arc, + table_metadata: Option>, action: RecordAction, table_name: String, rowid: i64, @@ -132,14 +128,15 @@ impl SubscriptionManager { table_metadata, record_apis, - subscriptions: RwLock::new(HashMap::new()), + record_subscriptions: RwLock::new(HashMap::new()), + table_subscriptions: RwLock::new(HashMap::new()), }), }; } - pub fn num_subscriptions(&self) -> usize { + pub fn num_record_subscriptions(&self) -> usize { let mut count: usize = 0; - for table in self.state.subscriptions.read().values() { + for table in self.state.record_subscriptions.read().values() { for record in table.values() { count += record.len(); } @@ -147,7 +144,50 @@ impl SubscriptionManager { return count; } - /// Preupdate hook that runs in a continuation of the trailbase-sqlite executor. + fn broker_subscriptions( + s: &ManagerState, + conn: &rusqlite::Connection, + subs: &[Subscription], + record: &[(&str, rusqlite::types::Value)], + event: &DbEvent, + ) -> Vec { + let mut dead_subscriptions: Vec = vec![]; + for (idx, sub) in subs.iter().enumerate() { + let Some(api) = s.lookup_record_api(&sub.record_api_name) else { + dead_subscriptions.push(idx); + continue; + }; + + if let Err(_err) = api.check_record_level_read_access( + conn, + Permission::Read, + // TODO: Maybe we could inject ValueRef instead to avoid repeated cloning. + record.to_owned(), + sub.user.as_ref(), + ) { + // This can happen if the record api configuration has changed since originally + // subscribed. In this case we just send and error and cancel the subscription. + let _ = sub.channel.try_send(DbEvent::Error("Access denied".into())); + dead_subscriptions.push(idx); + continue; + } + + // TODO: Avoid cloning the event/record over and over. + match sub.channel.try_send(event.clone()) { + Ok(_) => {} + Err(async_channel::TrySendError::Full(ev)) => { + log::warn!("Channel full, dropping event: {ev:?}"); + } + Err(async_channel::TrySendError::Closed(_ev)) => { + dead_subscriptions.push(idx); + } + } + } + + return dead_subscriptions; + } + + /// Continuation of the preupdate hook being scheduled on the executor. fn hook_continuation(conn: &rusqlite::Connection, state: ContinuationState) { let ContinuationState { state, @@ -157,11 +197,24 @@ impl SubscriptionManager { rowid, record_values, } = state; - let s = &state; + let table_name = table_name.as_str(); + + // If table_metadata is missing, the config/schema must have changed, thus removing the + // subscriptions. + let Some(table_metadata) = table_metadata else { + log::warn!("Table not found: {table_name}. Removing subscriptions"); + + let mut record_subs = s.record_subscriptions.write(); + record_subs.remove(table_name); + + let mut table_subs = s.table_subscriptions.write(); + table_subs.remove(table_name); + + if record_subs.is_empty() && table_subs.is_empty() { + conn.preupdate_hook(NO_HOOK); + } - let mut read_lock = s.subscriptions.upgradable_read(); - let Some(subs) = read_lock.get(&table_name).and_then(|m| m.get(&rowid)) else { return; }; @@ -193,80 +246,84 @@ impl SubscriptionManager { } }; - let mut dead_subscriptions: Vec = vec![]; - for (idx, sub) in subs.iter().enumerate() { - let Some(api) = s.lookup_record_api(&sub.record_api_name) else { - dead_subscriptions.push(idx); - continue; + 'record_subs: { + let mut read_lock = s.record_subscriptions.upgradable_read(); + let Some(subs) = read_lock.get(table_name).and_then(|m| m.get(&rowid)) else { + break 'record_subs; }; - if let Err(_err) = api.check_record_level_read_access( - conn, - Permission::Read, - // TODO: Maybe we could inject ValueRef instead to avoid repeated cloning. - record.clone(), - sub.user.as_ref(), - ) { - // This can happen if the record api configuration has changed since originally - // subscribed. In this case we just send and error and cancel the subscription. - let _ = sub.channel.try_send(DbEvent::Error("Access denied".into())); - dead_subscriptions.push(idx); - continue; + let dead_subscriptions = Self::broker_subscriptions(s, conn, subs, &record, &event); + if dead_subscriptions.is_empty() && action != RecordAction::Delete { + // No cleanup needed. + break 'record_subs; } - // TODO: Avoid cloning the event/record over and over. - match sub.channel.try_send(event.clone()) { - Ok(_) => {} - Err(async_channel::TrySendError::Full(ev)) => { - log::warn!("Channel full, dropping event: {ev:?}"); - } - Err(async_channel::TrySendError::Closed(_ev)) => { - dead_subscriptions.push(idx); - } - } - } + read_lock.with_upgraded(move |subscriptions| { + let Some(table_subscriptions) = subscriptions.get_mut(table_name) else { + return; + }; - if dead_subscriptions.is_empty() && action != RecordAction::Delete { - // No cleanup needed. - return; - } - - read_lock.with_upgraded(move |subscriptions| { - let Some(table_subscriptions) = subscriptions.get_mut(&table_name) else { - return; - }; - - if action == RecordAction::Delete { - // Also drops the channel and thus automatically closes the SSE connection. - table_subscriptions.remove(&rowid); - - if table_subscriptions.is_empty() { - subscriptions.remove(&table_name); - if subscriptions.is_empty() { - conn.preupdate_hook(NO_HOOK); - } - } - - return; - } - - if let Some(m) = table_subscriptions.get_mut(&rowid) { - for idx in dead_subscriptions.iter().rev() { - m.swap_remove(*idx); - } - - if m.is_empty() { + if action == RecordAction::Delete { + // Also drops the channel and thus automatically closes the SSE connection. table_subscriptions.remove(&rowid); if table_subscriptions.is_empty() { - subscriptions.remove(&table_name); - if subscriptions.is_empty() { + subscriptions.remove(table_name); + if subscriptions.is_empty() && s.table_subscriptions.read().is_empty() { conn.preupdate_hook(NO_HOOK); } } + + return; } + + if let Some(m) = table_subscriptions.get_mut(&rowid) { + for idx in dead_subscriptions.iter().rev() { + m.swap_remove(*idx); + } + + if m.is_empty() { + table_subscriptions.remove(&rowid); + + if table_subscriptions.is_empty() { + subscriptions.remove(table_name); + if subscriptions.is_empty() && s.table_subscriptions.read().is_empty() { + conn.preupdate_hook(NO_HOOK); + } + } + } + } + }); + } + + 'table_subs: { + let mut read_lock = s.table_subscriptions.upgradable_read(); + let Some(subs) = read_lock.get(table_name) else { + break 'table_subs; + }; + + let dead_subscriptions = Self::broker_subscriptions(s, conn, subs, &record, &event); + if dead_subscriptions.is_empty() && action != RecordAction::Delete { + // No cleanup needed. + break 'table_subs; } - }); + + read_lock.with_upgraded(move |subscriptions| { + let Some(table_subscriptions) = subscriptions.get_mut(table_name) else { + return; + }; + + for idx in dead_subscriptions.iter().rev() { + table_subscriptions.swap_remove(*idx); + } + + if table_subscriptions.is_empty() { + if subscriptions.is_empty() && s.record_subscriptions.read().is_empty() { + conn.preupdate_hook(NO_HOOK); + } + } + }); + } } async fn add_hook(&self) -> trailbase_sqlite::connection::Result<()> { @@ -294,22 +351,17 @@ impl SubscriptionManager { }; // If there are no subscriptions, do nothing. - if s - .subscriptions + let record_subs_candidate = s + .record_subscriptions .read() .get(table_name) .and_then(|m| m.get(&rowid)) - .is_none() - { + .is_some(); + let table_subs_candidate = s.table_subscriptions.read().get(table_name).is_some(); + if !record_subs_candidate && !table_subs_candidate { return; } - let Some(table_metadata) = s.table_metadata.get(table_name) else { - // TODO: Should we cleanup here? Probably, since we won't recover from this issue. - log::error!("Table not found: {table_name}"); - return; - }; - let Some(record_values) = extract_record_values(case) else { log::error!("Failed to extract values"); return; @@ -317,7 +369,7 @@ impl SubscriptionManager { let state = ContinuationState { state: s.clone(), - table_metadata, + table_metadata: s.table_metadata.get(table_name), action, table_name: table_name.to_string(), rowid, @@ -335,15 +387,12 @@ impl SubscriptionManager { .await; } - async fn add_subscription( + async fn add_record_subscription( &self, api: RecordApi, - record: Option, + record: trailbase_sqlite::Value, user: Option, ) -> Result, RecordError> { - let Some(record) = record else { - return Err(RecordError::BadRequest("Missing record id")); - }; let (sender, receiver) = async_channel::bounded::(16); let table_name = api.table_name(); @@ -366,7 +415,7 @@ impl SubscriptionManager { let subscription_id = SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst); let empty = { - let mut lock = self.state.subscriptions.write(); + let mut lock = self.state.record_subscriptions.write(); let empty = lock.is_empty(); let m: &mut HashMap> = lock.entry(table_name.to_string()).or_default(); @@ -388,6 +437,38 @@ impl SubscriptionManager { return Ok(receiver); } + async fn add_table_subscription( + &self, + api: RecordApi, + user: Option, + ) -> Result, RecordError> { + let (sender, receiver) = async_channel::bounded::(16); + + let table_name = api.table_name(); + + let subscription_id = SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst); + let empty = { + let mut lock = self.state.table_subscriptions.write(); + let empty = lock.is_empty() && self.state.record_subscriptions.read().is_empty(); + let m: &mut Vec = lock.entry(table_name.to_string()).or_default(); + + m.push(Subscription { + subscription_id, + record_api_name: api.api_name().to_string(), + user, + channel: sender, + }); + + empty + }; + + if empty { + self.add_hook().await.unwrap(); + } + + return Ok(receiver); + } + // TODO: Cleaning up subscriptions might be a thing, e.g. if SSE handlers had an onDisconnect // handler. Right now we're handling cleanups reactively, i.e. we only remove subscriptions when // sending new events and the receiving end of a handler channel became invalid. It would @@ -427,29 +508,35 @@ pub async fn add_subscription_sse_handler( return Err(RecordError::ApiNotFound); }; - let record_id = api.id_to_sql(&record)?; + fn encode(ev: DbEvent) -> Result { + // TODO: We're re-encoding the event over and over again for all subscriptions. Would be easy + // to pre-encode on the sender side but makes testing much harder, since there's no good way + // to parse sse::Event back. + return Event::default().json_data(ev); + } - let Ok(()) = api - .check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref()) - .await - else { - return Err(RecordError::Forbidden); - }; + if record == "*" { + api.check_table_level_access(Permission::Read, user.as_ref())?; - let receiver = state - .subscription_manager() - .add_subscription(api, Some(record_id), user) - .await?; + let receiver = state + .subscription_manager() + .add_table_subscription(api, user) + .await?; - return Ok( - Sse::new(receiver.map(|ev| { - // TODO: We're re-encoding the event over and over again for all subscriptions. Would be easy - // to pre-encode on the sender side but makes testing much harder, since there's no good way - // to parse sse::Event back. - return Event::default().json_data(ev); - })) - .keep_alive(KeepAlive::default()), - ); + return Ok(Sse::new(receiver.map(encode)).keep_alive(KeepAlive::default())); + } else { + let record_id = api.id_to_sql(&record)?; + api + .check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref()) + .await?; + + let receiver = state + .subscription_manager() + .add_record_subscription(api, record_id, user) + .await?; + + return Ok(Sse::new(receiver.map(encode)).keep_alive(KeepAlive::default())); + } } #[cfg(test)] @@ -460,7 +547,7 @@ mod tests { use crate::records::{add_record_api, AccessRules, Acls, PermissionFlag}; #[tokio::test] - async fn subscribe_connection_test() { + async fn subscribe_to_record_test() { let state = test_state(None).await.unwrap(); let conn = state.conn().clone(); @@ -483,11 +570,7 @@ mod tests { world: vec![PermissionFlag::Create, PermissionFlag::Read], ..Default::default() }, - AccessRules { - // read: Some("(_ROW_._owner = _USER_.id OR EXISTS(SELECT 1 FROM room_members WHERE room = - // _ROW_.room AND user = _USER_.id))".to_string()), - ..Default::default() - }, + AccessRules::default(), ) .await .unwrap(); @@ -510,11 +593,11 @@ mod tests { let manager = state.subscription_manager(); let api = state.lookup_record_api("api_name").unwrap(); let receiver = manager - .add_subscription(api, Some(trailbase_sqlite::Value::Integer(0)), None) + .add_record_subscription(api, trailbase_sqlite::Value::Integer(0), None) .await .unwrap(); - assert_eq!(1, manager.num_subscriptions()); + assert_eq!(1, manager.num_record_subscriptions()); conn .execute( @@ -538,7 +621,7 @@ mod tests { }; conn - .execute("DELETE FROM test WHERE _rowid_ = $2", params!(rowid)) + .execute("DELETE FROM test WHERE _rowid_ = $1", params!(rowid)) .await .unwrap(); @@ -551,7 +634,98 @@ mod tests { } } - assert_eq!(0, manager.num_subscriptions()); + assert_eq!(0, manager.num_record_subscriptions()); + } + + #[tokio::test] + async fn subscribe_to_table_test() { + let state = test_state(None).await.unwrap(); + let conn = state.conn().clone(); + + conn + .execute( + "CREATE TABLE test (id INTEGER PRIMARY KEY, text TEXT) STRICT", + (), + ) + .await + .unwrap(); + + state.table_metadata().invalidate_all().await.unwrap(); + + // Register message table as record api with moderator read access. + add_record_api( + &state, + "api_name", + "test", + Acls { + world: vec![PermissionFlag::Create, PermissionFlag::Read], + ..Default::default() + }, + AccessRules::default(), + ) + .await + .unwrap(); + + let manager = state.subscription_manager(); + let api = state.lookup_record_api("api_name").unwrap(); + let receiver = manager.add_table_subscription(api, None).await.unwrap(); + + let record_id_raw = 0; + conn + .query_row( + "INSERT INTO test (id, text) VALUES ($1, 'foo')", + params!(record_id_raw), + ) + .await + .unwrap(); + + conn + .execute( + "UPDATE test SET text = $1 WHERE id = $2", + params!("bar", record_id_raw), + ) + .await + .unwrap(); + + let expected = serde_json::json!({ + "id": record_id_raw, + "text": "foo", + }); + match receiver.recv().await.unwrap() { + DbEvent::Insert(Some(value)) => { + assert_eq!(value, expected); + } + x => { + assert!(false, "Expected update, got: {x:?}"); + } + }; + + let expected = serde_json::json!({ + "id": record_id_raw, + "text": "bar", + }); + match receiver.recv().await.unwrap() { + DbEvent::Update(Some(value)) => { + assert_eq!(value, expected); + } + x => { + assert!(false, "Expected update, got: {x:?}"); + } + }; + + conn + .execute("DELETE FROM test WHERE id = $1", params!(record_id_raw)) + .await + .unwrap(); + + match receiver.recv().await.unwrap() { + DbEvent::Delete(Some(value)) => { + assert_eq!(value, expected); + } + x => { + assert!(false, "Expected update, got: {x:?}"); + } + } } // TODO: Test actual SSE handler.