From 27ce368addd8b44f559b910b6278e8a62cc9ba6e Mon Sep 17 00:00:00 2001 From: Sebastian Jeltsch Date: Sat, 13 Dec 2025 22:35:13 +0100 Subject: [PATCH] Introduce a connection identiy and use it to manage subscriptions more robustely. Ultimately names don't matter. Subscriptions only works if the update_hook is installed on the very connection the record APIs use for mutations. --- crates/core/src/app_state.rs | 2 +- crates/core/src/connection.rs | 20 +++-- crates/core/src/records/subscribe.rs | 115 +++++++++++++++------------ crates/sqlite/src/connection.rs | 68 ++++++++++------ crates/sqlite/src/tests.rs | 11 ++- 5 files changed, 133 insertions(+), 83 deletions(-) diff --git a/crates/core/src/app_state.rs b/crates/core/src/app_state.rs index 17914901..3e6874ff 100644 --- a/crates/core/src/app_state.rs +++ b/crates/core/src/app_state.rs @@ -370,7 +370,7 @@ impl AppState { } } None => { - self.state.config.set(config); + self.state.config.update(|_old| config); } }; diff --git a/crates/core/src/connection.rs b/crates/core/src/connection.rs index a43928ed..63ba30f3 100644 --- a/crates/core/src/connection.rs +++ b/crates/core/src/connection.rs @@ -149,6 +149,7 @@ impl ConnectionManager { }; } + // Gets called when the metadata was updated. pub(crate) fn add_observer(&self, o: impl Fn() + Send + Sync + 'static) { self.state.observers.lock().push(Box::new(o)); } @@ -261,16 +262,21 @@ impl ConnectionManager { return Ok(Arc::new(conn)); } + // Updates connection metadata for cached connections. pub(crate) fn rebuild_metadata(&self) -> Result<(), ConnectionError> { - let new_metadata = Arc::new(build_metadata( - &self.state.main.read().connection.write_lock(), - &self.state.json_schema_registry, - )?); + // Main + { + let new_metadata = Arc::new(build_metadata( + &self.state.main.read().connection.write_lock(), + &self.state.json_schema_registry, + )?); - self.state.main.write().metadata = new_metadata; + self.state.main.write().metadata = new_metadata; + } + // Others: for (key, entry) in self.state.connections.iter() { - let metadata = Arc::new(build_metadata( + let new_metadata = Arc::new(build_metadata( &entry.connection.write_lock(), &self.state.json_schema_registry, )?); @@ -279,7 +285,7 @@ impl ConnectionManager { key, ConnectionEntry { connection: entry.connection.clone(), - metadata, + metadata: new_metadata, }, true, ); diff --git a/crates/core/src/records/subscribe.rs b/crates/core/src/records/subscribe.rs index 614abada..7b9beb48 100644 --- a/crates/core/src/records/subscribe.rs +++ b/crates/core/src/records/subscribe.rs @@ -10,8 +10,7 @@ use pin_project_lite::pin_project; use reactivate::Reactive; use rusqlite::hooks::{Action, PreUpdateCase}; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::collections::hash_map::Entry; +use std::collections::{HashMap, hash_map::Entry}; use std::pin::Pin; use std::sync::{ Arc, @@ -63,9 +62,12 @@ impl Drop for AutoCleanupEventStreamState { if self.receiver.upgrade().is_some() { let id = std::mem::take(&mut self.id); let state = self.state.clone(); - self.state.conn.call_and_forget(move |conn| { - state.remove_subscription(conn, id); - }); + + if let Some(first) = self.state.record_apis.read().values().nth(0) { + first.conn().call_and_forget(move |conn| { + state.remove_subscription(conn, id); + }); + } } else { debug!("Subscription cleaned up already by the sender side."); } @@ -147,12 +149,11 @@ struct Subscriptions { } struct PerConnectionState { - // TODO: Should this be a weak ref? - conn: Arc, - - // Metadata: always updated together when config -> record APIs change. - connection_metadata: RwLock>, + /// Metadata: always updated together when config -> record APIs change. record_apis: RwLock>, + /// Denormalized metadata. We could also grab this from: + /// `record_apis.read().nth(0).unwrap().connection_metadata()`. + connection_metadata: RwLock>, /// Map from table name to row id to list of subscriptions. /// @@ -383,9 +384,11 @@ impl PerConnectionState { impl Drop for PerConnectionState { fn drop(&mut self) { - self - .conn - .call_and_forget(|conn| conn.preupdate_hook(NO_HOOK)); + if let Some(first) = self.record_apis.read().values().nth(0) { + first + .conn() + .call_and_forget(|conn| conn.preupdate_hook(NO_HOOK)); + } } } @@ -394,14 +397,8 @@ struct ManagerState { /// Record API configurations. record_apis: Reactive>>, - /// Manages subscriptions for differents databases and connections respectively. - /// - /// TODO: There's a disconnect. Connections can have many db names. - /// We should probably: - /// * establish a connection identity so we can establish a RecordAPI to connection mapping - /// * keep only a weak ref to connections (for them to die) - /// * And turn this into a list here. - connections: RwLock>>, + /// Manages subscriptions for different connections based on `conn.id()`. + connections: RwLock>>, } #[derive(Clone)] @@ -417,21 +414,17 @@ struct ContinuationState { record_values: Vec, } -fn filter_record_apis(db: &str, record_apis: &[(String, RecordApi)]) -> HashMap { +fn filter_record_apis( + conn_id: usize, + record_apis: &[(String, RecordApi)], +) -> HashMap { return record_apis .iter() .flat_map(|(name, api)| { if !api.enable_subscriptions() { return None; } - - if api - .qualified_name() - .database_schema - .as_deref() - .unwrap_or("main") - == db - { + if api.conn().id() == conn_id { return Some((name.to_string(), api.clone())); } @@ -450,17 +443,31 @@ impl SubscriptionManager { { let state = state.clone(); record_apis.add_observer(move |record_apis| { - // FIXME:: We need to do some more bookkeeping, e.g. remove pre-update hooks for no - // longer existing APIs or APIs which are no-longer subscribable. + // FIXME: Reload currently depends on ConnectionManager's cache to retain **all** + // connections. Currently, subscriptions would currently get cancelled when old Connections + // get evicted and new ones established. This is when RecordApis get rebuild even if + // nothing changed. + let mut lock = state.connections.write(); - let connections = state.connections.read(); - for (db_name, state) in connections.iter() { - let apis = filter_record_apis(db_name, record_apis); + let mut old: HashMap> = std::mem::take(&mut lock); - if let Some(first) = apis.values().nth(0) { - *state.connection_metadata.write() = first.connection_metadata().clone(); + for (_name, api) in record_apis.iter() { + if !api.enable_subscriptions() { + continue; + } + + let id = api.conn().id(); + if let Some(existing) = old.remove(&id) { + let apis = filter_record_apis(id, record_apis); + let Some(first) = apis.values().nth(0) else { + continue; + }; + + // Update metadata and add back. + *existing.connection_metadata.write() = first.connection_metadata().clone(); + *existing.record_apis.write() = apis; + lock.insert(id, existing); } - *state.record_apis.write() = apis; } }); } @@ -474,11 +481,8 @@ impl SubscriptionManager { user: Option, filter: Option, ) -> Result { - let table_name = api.qualified_name(); - let db = table_name.database_schema.as_deref().unwrap_or("main"); - return self - .get_per_connection_state(db, &api) + .get_per_connection_state(&api) .add_table_subscription(api, user, filter) .await; } @@ -489,29 +493,26 @@ impl SubscriptionManager { record: trailbase_sqlite::Value, user: Option, ) -> Result { - let table_name = api.qualified_name(); - let db = table_name.database_schema.as_deref().unwrap_or("main"); - return self - .get_per_connection_state(db, &api) + .get_per_connection_state(&api) .add_record_subscription(api, record, user) .await; } - fn get_per_connection_state(&self, db: &str, api: &RecordApi) -> Arc { + fn get_per_connection_state(&self, api: &RecordApi) -> Arc { + let id: usize = api.conn().id(); let mut lock = self.state.connections.upgradable_read(); - if let Some(state) = lock.get(db) { + if let Some(state) = lock.get(&id) { return state.clone(); } return lock.with_upgraded(|m| { - return match m.entry(db.to_string()) { + return match m.entry(id) { Entry::Occupied(v) => v.get().clone(), Entry::Vacant(v) => { let state = Arc::new(PerConnectionState { - conn: api.conn().clone(), connection_metadata: RwLock::new(api.connection_metadata().clone()), - record_apis: RwLock::new(filter_record_apis(db, &self.state.record_apis.value())), + record_apis: RwLock::new(filter_record_apis(id, &self.state.record_apis.value())), subscriptions: Default::default(), }); v.insert(state).clone() @@ -898,6 +899,18 @@ mod tests { .unwrap(); assert_eq!(1, manager.num_record_subscriptions()); + + // Make sure rebuilding connection metadata doesn't drop subscriptions. + state.rebuild_connection_metadata().await.unwrap(); + + assert_eq!(1, manager.num_record_subscriptions()); + + // Make sure updating config doesn't drop subscriptions. + state + .validate_and_update_config(state.get_config(), None) + .await + .unwrap(); + // First event is "connection established". assert!( decode_db_event(stream.receiver.recv().await.unwrap()) diff --git a/crates/sqlite/src/connection.rs b/crates/sqlite/src/connection.rs index 900dd130..a9315e2d 100644 --- a/crates/sqlite/src/connection.rs +++ b/crates/sqlite/src/connection.rs @@ -4,11 +4,11 @@ use parking_lot::RwLock; use rusqlite::fallible_iterator::FallibleIterator; use rusqlite::hooks::PreUpdateCase; use rusqlite::types::Value; +use std::fmt::Debug; +use std::hash::{Hash, Hasher}; use std::ops::{Deref, DerefMut}; -use std::{ - fmt::{self, Debug}, - sync::Arc, -}; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use tokio::sync::oneshot; use crate::error::Error; @@ -77,6 +77,7 @@ impl Default for Options { /// A handle to call functions in background thread. #[derive(Clone)] pub struct Connection { + id: usize, reader: Sender, writer: Sender, conns: Arc>, @@ -96,10 +97,9 @@ impl Connection { }; let write_conn = new_conn()?; - let in_memory = write_conn.path().is_none_or(|s| { - // Returns empty string for in-memory databases. - return !s.is_empty(); - }); + let path = write_conn.path().map(|p| p.to_string()); + // Returns empty string for in-memory databases. + let in_memory = path.as_ref().is_none_or(|s| !s.is_empty()); let n_read_threads: i64 = match (in_memory, opt.as_ref().map_or(0, |o| o.n_read_threads)) { (true, _) => { @@ -165,10 +165,11 @@ impl Connection { debug!( "Opened SQLite DB '{}' with {n_read_threads} reader threads", - conns.read().0[0].path().unwrap_or("") + path.as_deref().unwrap_or("") ); return Ok(Self { + id: UNIQUE_CONN_ID.fetch_add(1, Ordering::SeqCst), reader: shared_read_sender, writer: shared_write_sender, conns, @@ -184,6 +185,7 @@ impl Connection { std::thread::spawn(move || event_loop(0, conns_clone, shared_write_receiver)); return Self { + id: UNIQUE_CONN_ID.fetch_add(1, Ordering::SeqCst), reader: shared_write_sender.clone(), writer: shared_write_sender, conns, @@ -199,6 +201,10 @@ impl Connection { return Self::new(|| Ok(rusqlite::Connection::open_in_memory()?), None); } + pub fn id(&self) -> usize { + return self.id; + } + #[inline] pub fn write_lock(&self) -> LockGuard<'_> { return LockGuard { @@ -206,20 +212,20 @@ impl Connection { }; } - #[inline] - pub fn try_write_lock_for(&self, duration: tokio::time::Duration) -> Option> { - return self - .conns - .try_write_for(duration) - .map(|guard| LockGuard { guard }); - } + // #[inline] + // pub fn try_write_lock_for(&self, duration: tokio::time::Duration) -> Option> { + // return self + // .conns + // .try_write_for(duration) + // .map(|guard| LockGuard { guard }); + // } - #[inline] - pub fn write_arc_lock(&self) -> ArcLockGuard { - return ArcLockGuard { - guard: self.conns.write_arc(), - }; - } + // #[inline] + // pub fn write_arc_lock(&self) -> ArcLockGuard { + // return ArcLockGuard { + // guard: self.conns.write_arc(), + // }; + // } #[inline] pub fn try_write_arc_lock_for(&self, duration: tokio::time::Duration) -> Option { @@ -545,11 +551,25 @@ impl Connection { } impl Debug for Connection { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Connection").finish() } } +impl Hash for Connection { + fn hash(&self, state: &mut H) { + self.id.hash(state); + } +} + +impl PartialEq for Connection { + fn eq(&self, other: &Self) -> bool { + return self.id == other.id; + } +} + +impl Eq for Connection {} + fn event_loop(id: usize, conns: Arc>, receiver: Receiver) { while let Ok(message) = receiver.recv() { match message { @@ -661,6 +681,8 @@ pub fn list_databases(conn: &rusqlite::Connection) -> Result> { return Ok(databases); } +static UNIQUE_CONN_ID: AtomicUsize = AtomicUsize::new(0); + #[cfg(test)] #[path = "tests.rs"] mod tests; diff --git a/crates/sqlite/src/tests.rs b/crates/sqlite/src/tests.rs index 6e87477d..510ae185 100644 --- a/crates/sqlite/src/tests.rs +++ b/crates/sqlite/src/tests.rs @@ -4,7 +4,7 @@ use serde::Deserialize; use std::borrow::Cow; use crate::connection::{Connection, Database, Error, Options, extract_row_id}; -use crate::{Value, ValueType, named_params, params}; +use crate::{Value, ValueType}; use rusqlite::ErrorCode; #[tokio::test] @@ -349,6 +349,15 @@ async fn test_execute_batch_error() { assert!(result.is_err(), "{result:?}"); } +#[tokio::test] +async fn test_identity() { + let conn0 = Connection::open_in_memory().unwrap(); + let conn1 = Connection::open_in_memory().unwrap(); + + assert_ne!(conn0, conn1); + assert_eq!(conn0, conn0.clone()); +} + #[test] fn test_locking() { let conn = Connection::open_in_memory().unwrap();