mirror of
https://github.com/trailbaseio/trailbase.git
synced 2025-12-30 14:19:43 -06:00
Add support for table/recordapi-wide "realtime" subscriptions, i.e. insertions, updates, and deletions.
Remove subscriptions for missing table.
This commit is contained in:
@@ -28,7 +28,6 @@ static SUBSCRIPTION_COUNTER: AtomicI64 = AtomicI64::new(0);
|
||||
|
||||
// TODO:
|
||||
// * clients
|
||||
// * table-wide subscriptions
|
||||
// * optimize: avoid repeated encoding of events. Easy to do but makes testing harder since there's
|
||||
// no good way to parse sse::Event back :/. We should probably just bite the bullet and parse,
|
||||
// it's literally "data: <json>\n\n".
|
||||
@@ -61,12 +60,6 @@ pub enum DbEvent {
|
||||
Error(String),
|
||||
}
|
||||
|
||||
// pub struct SubscriptionId {
|
||||
// table_name: String,
|
||||
// row_id: i64,
|
||||
// subscription_id: i64,
|
||||
// }
|
||||
|
||||
pub struct Subscription {
|
||||
/// Id uniquely identifying this subscription.
|
||||
subscription_id: i64,
|
||||
@@ -92,7 +85,10 @@ struct ManagerState {
|
||||
record_apis: Computed<Vec<(String, RecordApi)>, crate::config::proto::Config>,
|
||||
|
||||
/// Map from table name to row id to list of subscriptions.
|
||||
subscriptions: RwLock<HashMap<String, HashMap<i64, Vec<Subscription>>>>,
|
||||
record_subscriptions: RwLock<HashMap<String, HashMap<i64, Vec<Subscription>>>>,
|
||||
|
||||
/// Map from table name to table subscriptions.
|
||||
table_subscriptions: RwLock<HashMap<String, Vec<Subscription>>>,
|
||||
}
|
||||
|
||||
impl ManagerState {
|
||||
@@ -113,7 +109,7 @@ pub struct SubscriptionManager {
|
||||
|
||||
struct ContinuationState {
|
||||
state: Arc<ManagerState>,
|
||||
table_metadata: Arc<TableMetadata>,
|
||||
table_metadata: Option<Arc<TableMetadata>>,
|
||||
action: RecordAction,
|
||||
table_name: String,
|
||||
rowid: i64,
|
||||
@@ -132,14 +128,15 @@ impl SubscriptionManager {
|
||||
table_metadata,
|
||||
record_apis,
|
||||
|
||||
subscriptions: RwLock::new(HashMap::new()),
|
||||
record_subscriptions: RwLock::new(HashMap::new()),
|
||||
table_subscriptions: RwLock::new(HashMap::new()),
|
||||
}),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn num_subscriptions(&self) -> usize {
|
||||
pub fn num_record_subscriptions(&self) -> usize {
|
||||
let mut count: usize = 0;
|
||||
for table in self.state.subscriptions.read().values() {
|
||||
for table in self.state.record_subscriptions.read().values() {
|
||||
for record in table.values() {
|
||||
count += record.len();
|
||||
}
|
||||
@@ -147,7 +144,50 @@ impl SubscriptionManager {
|
||||
return count;
|
||||
}
|
||||
|
||||
/// Preupdate hook that runs in a continuation of the trailbase-sqlite executor.
|
||||
fn broker_subscriptions(
|
||||
s: &ManagerState,
|
||||
conn: &rusqlite::Connection,
|
||||
subs: &[Subscription],
|
||||
record: &[(&str, rusqlite::types::Value)],
|
||||
event: &DbEvent,
|
||||
) -> Vec<usize> {
|
||||
let mut dead_subscriptions: Vec<usize> = vec![];
|
||||
for (idx, sub) in subs.iter().enumerate() {
|
||||
let Some(api) = s.lookup_record_api(&sub.record_api_name) else {
|
||||
dead_subscriptions.push(idx);
|
||||
continue;
|
||||
};
|
||||
|
||||
if let Err(_err) = api.check_record_level_read_access(
|
||||
conn,
|
||||
Permission::Read,
|
||||
// TODO: Maybe we could inject ValueRef instead to avoid repeated cloning.
|
||||
record.to_owned(),
|
||||
sub.user.as_ref(),
|
||||
) {
|
||||
// This can happen if the record api configuration has changed since originally
|
||||
// subscribed. In this case we just send and error and cancel the subscription.
|
||||
let _ = sub.channel.try_send(DbEvent::Error("Access denied".into()));
|
||||
dead_subscriptions.push(idx);
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: Avoid cloning the event/record over and over.
|
||||
match sub.channel.try_send(event.clone()) {
|
||||
Ok(_) => {}
|
||||
Err(async_channel::TrySendError::Full(ev)) => {
|
||||
log::warn!("Channel full, dropping event: {ev:?}");
|
||||
}
|
||||
Err(async_channel::TrySendError::Closed(_ev)) => {
|
||||
dead_subscriptions.push(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return dead_subscriptions;
|
||||
}
|
||||
|
||||
/// Continuation of the preupdate hook being scheduled on the executor.
|
||||
fn hook_continuation(conn: &rusqlite::Connection, state: ContinuationState) {
|
||||
let ContinuationState {
|
||||
state,
|
||||
@@ -157,11 +197,24 @@ impl SubscriptionManager {
|
||||
rowid,
|
||||
record_values,
|
||||
} = state;
|
||||
|
||||
let s = &state;
|
||||
let table_name = table_name.as_str();
|
||||
|
||||
// If table_metadata is missing, the config/schema must have changed, thus removing the
|
||||
// subscriptions.
|
||||
let Some(table_metadata) = table_metadata else {
|
||||
log::warn!("Table not found: {table_name}. Removing subscriptions");
|
||||
|
||||
let mut record_subs = s.record_subscriptions.write();
|
||||
record_subs.remove(table_name);
|
||||
|
||||
let mut table_subs = s.table_subscriptions.write();
|
||||
table_subs.remove(table_name);
|
||||
|
||||
if record_subs.is_empty() && table_subs.is_empty() {
|
||||
conn.preupdate_hook(NO_HOOK);
|
||||
}
|
||||
|
||||
let mut read_lock = s.subscriptions.upgradable_read();
|
||||
let Some(subs) = read_lock.get(&table_name).and_then(|m| m.get(&rowid)) else {
|
||||
return;
|
||||
};
|
||||
|
||||
@@ -193,80 +246,84 @@ impl SubscriptionManager {
|
||||
}
|
||||
};
|
||||
|
||||
let mut dead_subscriptions: Vec<usize> = vec![];
|
||||
for (idx, sub) in subs.iter().enumerate() {
|
||||
let Some(api) = s.lookup_record_api(&sub.record_api_name) else {
|
||||
dead_subscriptions.push(idx);
|
||||
continue;
|
||||
'record_subs: {
|
||||
let mut read_lock = s.record_subscriptions.upgradable_read();
|
||||
let Some(subs) = read_lock.get(table_name).and_then(|m| m.get(&rowid)) else {
|
||||
break 'record_subs;
|
||||
};
|
||||
|
||||
if let Err(_err) = api.check_record_level_read_access(
|
||||
conn,
|
||||
Permission::Read,
|
||||
// TODO: Maybe we could inject ValueRef instead to avoid repeated cloning.
|
||||
record.clone(),
|
||||
sub.user.as_ref(),
|
||||
) {
|
||||
// This can happen if the record api configuration has changed since originally
|
||||
// subscribed. In this case we just send and error and cancel the subscription.
|
||||
let _ = sub.channel.try_send(DbEvent::Error("Access denied".into()));
|
||||
dead_subscriptions.push(idx);
|
||||
continue;
|
||||
let dead_subscriptions = Self::broker_subscriptions(s, conn, subs, &record, &event);
|
||||
if dead_subscriptions.is_empty() && action != RecordAction::Delete {
|
||||
// No cleanup needed.
|
||||
break 'record_subs;
|
||||
}
|
||||
|
||||
// TODO: Avoid cloning the event/record over and over.
|
||||
match sub.channel.try_send(event.clone()) {
|
||||
Ok(_) => {}
|
||||
Err(async_channel::TrySendError::Full(ev)) => {
|
||||
log::warn!("Channel full, dropping event: {ev:?}");
|
||||
}
|
||||
Err(async_channel::TrySendError::Closed(_ev)) => {
|
||||
dead_subscriptions.push(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
read_lock.with_upgraded(move |subscriptions| {
|
||||
let Some(table_subscriptions) = subscriptions.get_mut(table_name) else {
|
||||
return;
|
||||
};
|
||||
|
||||
if dead_subscriptions.is_empty() && action != RecordAction::Delete {
|
||||
// No cleanup needed.
|
||||
return;
|
||||
}
|
||||
|
||||
read_lock.with_upgraded(move |subscriptions| {
|
||||
let Some(table_subscriptions) = subscriptions.get_mut(&table_name) else {
|
||||
return;
|
||||
};
|
||||
|
||||
if action == RecordAction::Delete {
|
||||
// Also drops the channel and thus automatically closes the SSE connection.
|
||||
table_subscriptions.remove(&rowid);
|
||||
|
||||
if table_subscriptions.is_empty() {
|
||||
subscriptions.remove(&table_name);
|
||||
if subscriptions.is_empty() {
|
||||
conn.preupdate_hook(NO_HOOK);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(m) = table_subscriptions.get_mut(&rowid) {
|
||||
for idx in dead_subscriptions.iter().rev() {
|
||||
m.swap_remove(*idx);
|
||||
}
|
||||
|
||||
if m.is_empty() {
|
||||
if action == RecordAction::Delete {
|
||||
// Also drops the channel and thus automatically closes the SSE connection.
|
||||
table_subscriptions.remove(&rowid);
|
||||
|
||||
if table_subscriptions.is_empty() {
|
||||
subscriptions.remove(&table_name);
|
||||
if subscriptions.is_empty() {
|
||||
subscriptions.remove(table_name);
|
||||
if subscriptions.is_empty() && s.table_subscriptions.read().is_empty() {
|
||||
conn.preupdate_hook(NO_HOOK);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(m) = table_subscriptions.get_mut(&rowid) {
|
||||
for idx in dead_subscriptions.iter().rev() {
|
||||
m.swap_remove(*idx);
|
||||
}
|
||||
|
||||
if m.is_empty() {
|
||||
table_subscriptions.remove(&rowid);
|
||||
|
||||
if table_subscriptions.is_empty() {
|
||||
subscriptions.remove(table_name);
|
||||
if subscriptions.is_empty() && s.table_subscriptions.read().is_empty() {
|
||||
conn.preupdate_hook(NO_HOOK);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
'table_subs: {
|
||||
let mut read_lock = s.table_subscriptions.upgradable_read();
|
||||
let Some(subs) = read_lock.get(table_name) else {
|
||||
break 'table_subs;
|
||||
};
|
||||
|
||||
let dead_subscriptions = Self::broker_subscriptions(s, conn, subs, &record, &event);
|
||||
if dead_subscriptions.is_empty() && action != RecordAction::Delete {
|
||||
// No cleanup needed.
|
||||
break 'table_subs;
|
||||
}
|
||||
});
|
||||
|
||||
read_lock.with_upgraded(move |subscriptions| {
|
||||
let Some(table_subscriptions) = subscriptions.get_mut(table_name) else {
|
||||
return;
|
||||
};
|
||||
|
||||
for idx in dead_subscriptions.iter().rev() {
|
||||
table_subscriptions.swap_remove(*idx);
|
||||
}
|
||||
|
||||
if table_subscriptions.is_empty() {
|
||||
if subscriptions.is_empty() && s.record_subscriptions.read().is_empty() {
|
||||
conn.preupdate_hook(NO_HOOK);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn add_hook(&self) -> trailbase_sqlite::connection::Result<()> {
|
||||
@@ -294,22 +351,17 @@ impl SubscriptionManager {
|
||||
};
|
||||
|
||||
// If there are no subscriptions, do nothing.
|
||||
if s
|
||||
.subscriptions
|
||||
let record_subs_candidate = s
|
||||
.record_subscriptions
|
||||
.read()
|
||||
.get(table_name)
|
||||
.and_then(|m| m.get(&rowid))
|
||||
.is_none()
|
||||
{
|
||||
.is_some();
|
||||
let table_subs_candidate = s.table_subscriptions.read().get(table_name).is_some();
|
||||
if !record_subs_candidate && !table_subs_candidate {
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(table_metadata) = s.table_metadata.get(table_name) else {
|
||||
// TODO: Should we cleanup here? Probably, since we won't recover from this issue.
|
||||
log::error!("Table not found: {table_name}");
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(record_values) = extract_record_values(case) else {
|
||||
log::error!("Failed to extract values");
|
||||
return;
|
||||
@@ -317,7 +369,7 @@ impl SubscriptionManager {
|
||||
|
||||
let state = ContinuationState {
|
||||
state: s.clone(),
|
||||
table_metadata,
|
||||
table_metadata: s.table_metadata.get(table_name),
|
||||
action,
|
||||
table_name: table_name.to_string(),
|
||||
rowid,
|
||||
@@ -335,15 +387,12 @@ impl SubscriptionManager {
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn add_subscription(
|
||||
async fn add_record_subscription(
|
||||
&self,
|
||||
api: RecordApi,
|
||||
record: Option<trailbase_sqlite::Value>,
|
||||
record: trailbase_sqlite::Value,
|
||||
user: Option<User>,
|
||||
) -> Result<async_channel::Receiver<DbEvent>, RecordError> {
|
||||
let Some(record) = record else {
|
||||
return Err(RecordError::BadRequest("Missing record id"));
|
||||
};
|
||||
let (sender, receiver) = async_channel::bounded::<DbEvent>(16);
|
||||
|
||||
let table_name = api.table_name();
|
||||
@@ -366,7 +415,7 @@ impl SubscriptionManager {
|
||||
|
||||
let subscription_id = SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst);
|
||||
let empty = {
|
||||
let mut lock = self.state.subscriptions.write();
|
||||
let mut lock = self.state.record_subscriptions.write();
|
||||
let empty = lock.is_empty();
|
||||
let m: &mut HashMap<i64, Vec<Subscription>> = lock.entry(table_name.to_string()).or_default();
|
||||
|
||||
@@ -388,6 +437,38 @@ impl SubscriptionManager {
|
||||
return Ok(receiver);
|
||||
}
|
||||
|
||||
async fn add_table_subscription(
|
||||
&self,
|
||||
api: RecordApi,
|
||||
user: Option<User>,
|
||||
) -> Result<async_channel::Receiver<DbEvent>, RecordError> {
|
||||
let (sender, receiver) = async_channel::bounded::<DbEvent>(16);
|
||||
|
||||
let table_name = api.table_name();
|
||||
|
||||
let subscription_id = SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst);
|
||||
let empty = {
|
||||
let mut lock = self.state.table_subscriptions.write();
|
||||
let empty = lock.is_empty() && self.state.record_subscriptions.read().is_empty();
|
||||
let m: &mut Vec<Subscription> = lock.entry(table_name.to_string()).or_default();
|
||||
|
||||
m.push(Subscription {
|
||||
subscription_id,
|
||||
record_api_name: api.api_name().to_string(),
|
||||
user,
|
||||
channel: sender,
|
||||
});
|
||||
|
||||
empty
|
||||
};
|
||||
|
||||
if empty {
|
||||
self.add_hook().await.unwrap();
|
||||
}
|
||||
|
||||
return Ok(receiver);
|
||||
}
|
||||
|
||||
// TODO: Cleaning up subscriptions might be a thing, e.g. if SSE handlers had an onDisconnect
|
||||
// handler. Right now we're handling cleanups reactively, i.e. we only remove subscriptions when
|
||||
// sending new events and the receiving end of a handler channel became invalid. It would
|
||||
@@ -427,29 +508,35 @@ pub async fn add_subscription_sse_handler(
|
||||
return Err(RecordError::ApiNotFound);
|
||||
};
|
||||
|
||||
let record_id = api.id_to_sql(&record)?;
|
||||
fn encode(ev: DbEvent) -> Result<Event, axum::Error> {
|
||||
// TODO: We're re-encoding the event over and over again for all subscriptions. Would be easy
|
||||
// to pre-encode on the sender side but makes testing much harder, since there's no good way
|
||||
// to parse sse::Event back.
|
||||
return Event::default().json_data(ev);
|
||||
}
|
||||
|
||||
let Ok(()) = api
|
||||
.check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref())
|
||||
.await
|
||||
else {
|
||||
return Err(RecordError::Forbidden);
|
||||
};
|
||||
if record == "*" {
|
||||
api.check_table_level_access(Permission::Read, user.as_ref())?;
|
||||
|
||||
let receiver = state
|
||||
.subscription_manager()
|
||||
.add_subscription(api, Some(record_id), user)
|
||||
.await?;
|
||||
let receiver = state
|
||||
.subscription_manager()
|
||||
.add_table_subscription(api, user)
|
||||
.await?;
|
||||
|
||||
return Ok(
|
||||
Sse::new(receiver.map(|ev| {
|
||||
// TODO: We're re-encoding the event over and over again for all subscriptions. Would be easy
|
||||
// to pre-encode on the sender side but makes testing much harder, since there's no good way
|
||||
// to parse sse::Event back.
|
||||
return Event::default().json_data(ev);
|
||||
}))
|
||||
.keep_alive(KeepAlive::default()),
|
||||
);
|
||||
return Ok(Sse::new(receiver.map(encode)).keep_alive(KeepAlive::default()));
|
||||
} else {
|
||||
let record_id = api.id_to_sql(&record)?;
|
||||
api
|
||||
.check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref())
|
||||
.await?;
|
||||
|
||||
let receiver = state
|
||||
.subscription_manager()
|
||||
.add_record_subscription(api, record_id, user)
|
||||
.await?;
|
||||
|
||||
return Ok(Sse::new(receiver.map(encode)).keep_alive(KeepAlive::default()));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -460,7 +547,7 @@ mod tests {
|
||||
use crate::records::{add_record_api, AccessRules, Acls, PermissionFlag};
|
||||
|
||||
#[tokio::test]
|
||||
async fn subscribe_connection_test() {
|
||||
async fn subscribe_to_record_test() {
|
||||
let state = test_state(None).await.unwrap();
|
||||
let conn = state.conn().clone();
|
||||
|
||||
@@ -483,11 +570,7 @@ mod tests {
|
||||
world: vec![PermissionFlag::Create, PermissionFlag::Read],
|
||||
..Default::default()
|
||||
},
|
||||
AccessRules {
|
||||
// read: Some("(_ROW_._owner = _USER_.id OR EXISTS(SELECT 1 FROM room_members WHERE room =
|
||||
// _ROW_.room AND user = _USER_.id))".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
AccessRules::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -510,11 +593,11 @@ mod tests {
|
||||
let manager = state.subscription_manager();
|
||||
let api = state.lookup_record_api("api_name").unwrap();
|
||||
let receiver = manager
|
||||
.add_subscription(api, Some(trailbase_sqlite::Value::Integer(0)), None)
|
||||
.add_record_subscription(api, trailbase_sqlite::Value::Integer(0), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(1, manager.num_subscriptions());
|
||||
assert_eq!(1, manager.num_record_subscriptions());
|
||||
|
||||
conn
|
||||
.execute(
|
||||
@@ -538,7 +621,7 @@ mod tests {
|
||||
};
|
||||
|
||||
conn
|
||||
.execute("DELETE FROM test WHERE _rowid_ = $2", params!(rowid))
|
||||
.execute("DELETE FROM test WHERE _rowid_ = $1", params!(rowid))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -551,7 +634,98 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(0, manager.num_subscriptions());
|
||||
assert_eq!(0, manager.num_record_subscriptions());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn subscribe_to_table_test() {
|
||||
let state = test_state(None).await.unwrap();
|
||||
let conn = state.conn().clone();
|
||||
|
||||
conn
|
||||
.execute(
|
||||
"CREATE TABLE test (id INTEGER PRIMARY KEY, text TEXT) STRICT",
|
||||
(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
state.table_metadata().invalidate_all().await.unwrap();
|
||||
|
||||
// Register message table as record api with moderator read access.
|
||||
add_record_api(
|
||||
&state,
|
||||
"api_name",
|
||||
"test",
|
||||
Acls {
|
||||
world: vec![PermissionFlag::Create, PermissionFlag::Read],
|
||||
..Default::default()
|
||||
},
|
||||
AccessRules::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let manager = state.subscription_manager();
|
||||
let api = state.lookup_record_api("api_name").unwrap();
|
||||
let receiver = manager.add_table_subscription(api, None).await.unwrap();
|
||||
|
||||
let record_id_raw = 0;
|
||||
conn
|
||||
.query_row(
|
||||
"INSERT INTO test (id, text) VALUES ($1, 'foo')",
|
||||
params!(record_id_raw),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
conn
|
||||
.execute(
|
||||
"UPDATE test SET text = $1 WHERE id = $2",
|
||||
params!("bar", record_id_raw),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let expected = serde_json::json!({
|
||||
"id": record_id_raw,
|
||||
"text": "foo",
|
||||
});
|
||||
match receiver.recv().await.unwrap() {
|
||||
DbEvent::Insert(Some(value)) => {
|
||||
assert_eq!(value, expected);
|
||||
}
|
||||
x => {
|
||||
assert!(false, "Expected update, got: {x:?}");
|
||||
}
|
||||
};
|
||||
|
||||
let expected = serde_json::json!({
|
||||
"id": record_id_raw,
|
||||
"text": "bar",
|
||||
});
|
||||
match receiver.recv().await.unwrap() {
|
||||
DbEvent::Update(Some(value)) => {
|
||||
assert_eq!(value, expected);
|
||||
}
|
||||
x => {
|
||||
assert!(false, "Expected update, got: {x:?}");
|
||||
}
|
||||
};
|
||||
|
||||
conn
|
||||
.execute("DELETE FROM test WHERE id = $1", params!(record_id_raw))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match receiver.recv().await.unwrap() {
|
||||
DbEvent::Delete(Some(value)) => {
|
||||
assert_eq!(value, expected);
|
||||
}
|
||||
x => {
|
||||
assert!(false, "Expected update, got: {x:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Test actual SSE handler.
|
||||
|
||||
Reference in New Issue
Block a user