Introduce a connection identity and use it to manage subscriptions more robustly.

Ultimately, names don't matter: subscriptions only work if the update_hook is installed on the very connection the record APIs use for mutations.
Sebastian Jeltsch
2025-12-13 22:35:13 +01:00
parent 2beab0f770
commit 27ce368add
5 changed files with 133 additions and 83 deletions
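
The gist of the change, as a minimal sketch (simplified; the open() constructor and the SubscriptionState shape here are illustrative stand-ins, while id(), UNIQUE_CONN_ID, and the identity-based equality mirror the diff below):

use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};

// Process-wide counter handing out one unique id per opened connection.
static UNIQUE_CONN_ID: AtomicUsize = AtomicUsize::new(0);

#[derive(Clone)]
pub struct Connection {
  id: usize,
  // ... channels to the reader/writer threads elided.
}

impl Connection {
  pub fn open() -> Self {
    return Self {
      id: UNIQUE_CONN_ID.fetch_add(1, Ordering::SeqCst),
    };
  }

  /// Stable identity, independent of any attached database names.
  pub fn id(&self) -> usize {
    return self.id;
  }
}

impl PartialEq for Connection {
  fn eq(&self, other: &Self) -> bool {
    return self.id == other.id;
  }
}
impl Eq for Connection {}

// Subscription bookkeeping can then be keyed by connection identity instead of a
// db name, so the preupdate hook is guaranteed to live on the exact connection
// the record APIs use for mutations.
struct SubscriptionState {
  per_connection: HashMap</* conn id= */ usize, Vec<String /* table name */>>,
}
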

View File

@@ -370,7 +370,7 @@ impl AppState {
}
}
None => {
self.state.config.set(config);
self.state.config.update(|_old| config);
}
};

View File

@@ -149,6 +149,7 @@ impl ConnectionManager {
};
}
// Registers an observer that gets called whenever the metadata is updated.
pub(crate) fn add_observer(&self, o: impl Fn() + Send + Sync + 'static) {
self.state.observers.lock().push(Box::new(o));
}
@@ -261,16 +262,21 @@ impl ConnectionManager {
return Ok(Arc::new(conn));
}
// Updates connection metadata for cached connections.
pub(crate) fn rebuild_metadata(&self) -> Result<(), ConnectionError> {
let new_metadata = Arc::new(build_metadata(
&self.state.main.read().connection.write_lock(),
&self.state.json_schema_registry,
)?);
// Main
{
let new_metadata = Arc::new(build_metadata(
&self.state.main.read().connection.write_lock(),
&self.state.json_schema_registry,
)?);
self.state.main.write().metadata = new_metadata;
self.state.main.write().metadata = new_metadata;
}
// Others:
for (key, entry) in self.state.connections.iter() {
let metadata = Arc::new(build_metadata(
let new_metadata = Arc::new(build_metadata(
&entry.connection.write_lock(),
&self.state.json_schema_registry,
)?);
@@ -279,7 +285,7 @@ impl ConnectionManager {
key,
ConnectionEntry {
connection: entry.connection.clone(),
metadata,
metadata: new_metadata,
},
true,
);

View File

@@ -10,8 +10,7 @@ use pin_project_lite::pin_project;
use reactivate::Reactive;
use rusqlite::hooks::{Action, PreUpdateCase};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, hash_map::Entry};
use std::pin::Pin;
use std::sync::{
Arc,
@@ -63,9 +62,12 @@ impl Drop for AutoCleanupEventStreamState {
if self.receiver.upgrade().is_some() {
let id = std::mem::take(&mut self.id);
let state = self.state.clone();
self.state.conn.call_and_forget(move |conn| {
state.remove_subscription(conn, id);
});
if let Some(first) = self.state.record_apis.read().values().nth(0) {
first.conn().call_and_forget(move |conn| {
state.remove_subscription(conn, id);
});
}
} else {
debug!("Subscription cleaned up already by the sender side.");
}
@@ -147,12 +149,11 @@ struct Subscriptions {
}
struct PerConnectionState {
// TODO: Should this be a weak ref?
conn: Arc<trailbase_sqlite::Connection>,
// Metadata: always updated together when config -> record APIs change.
connection_metadata: RwLock<Arc<ConnectionMetadata>>,
/// Metadata: always updated together when config -> record APIs change.
record_apis: RwLock<HashMap<String, RecordApi>>,
/// Denormalized metadata. We could also grab this from:
/// `record_apis.read().nth(0).unwrap().connection_metadata()`.
connection_metadata: RwLock<Arc<ConnectionMetadata>>,
/// Map from table name to row id to list of subscriptions.
///
@@ -383,9 +384,11 @@ impl PerConnectionState {
impl Drop for PerConnectionState {
fn drop(&mut self) {
self
.conn
.call_and_forget(|conn| conn.preupdate_hook(NO_HOOK));
if let Some(first) = self.record_apis.read().values().nth(0) {
first
.conn()
.call_and_forget(|conn| conn.preupdate_hook(NO_HOOK));
}
}
}
@@ -394,14 +397,8 @@ struct ManagerState {
/// Record API configurations.
record_apis: Reactive<Arc<Vec<(String, RecordApi)>>>,
/// Manages subscriptions for differents databases and connections respectively.
///
/// TODO: There's a disconnect. Connections can have many db names.
/// We should probably:
/// * establish a connection identity so we can establish a RecordAPI to connection mapping
/// * keep only a weak ref to connections (for them to die)
/// * And turn this into a list here.
connections: RwLock<HashMap</* db_name= */ String, Arc<PerConnectionState>>>,
/// Manages subscriptions for different connections based on `conn.id()`.
connections: RwLock<HashMap</* conn id= */ usize, Arc<PerConnectionState>>>,
}
#[derive(Clone)]
@@ -417,21 +414,17 @@ struct ContinuationState {
record_values: Vec<rusqlite::types::Value>,
}
fn filter_record_apis(db: &str, record_apis: &[(String, RecordApi)]) -> HashMap<String, RecordApi> {
fn filter_record_apis(
conn_id: usize,
record_apis: &[(String, RecordApi)],
) -> HashMap<String, RecordApi> {
return record_apis
.iter()
.flat_map(|(name, api)| {
if !api.enable_subscriptions() {
return None;
}
if api
.qualified_name()
.database_schema
.as_deref()
.unwrap_or("main")
== db
{
if api.conn().id() == conn_id {
return Some((name.to_string(), api.clone()));
}
@@ -450,17 +443,31 @@ impl SubscriptionManager {
{
let state = state.clone();
record_apis.add_observer(move |record_apis| {
// FIXME:: We need to do some more bookkeeping, e.g. remove pre-update hooks for no
// longer existing APIs or APIs which are no-longer subscribable.
// FIXME: Reload currently depends on ConnectionManager's cache retaining **all**
// connections. Otherwise, subscriptions would get cancelled when old Connections
// get evicted and new ones established, which happens whenever RecordApis get rebuilt
// even if nothing changed.
let mut lock = state.connections.write();
let connections = state.connections.read();
for (db_name, state) in connections.iter() {
let apis = filter_record_apis(db_name, record_apis);
let mut old: HashMap<usize, Arc<PerConnectionState>> = std::mem::take(&mut lock);
if let Some(first) = apis.values().nth(0) {
*state.connection_metadata.write() = first.connection_metadata().clone();
for (_name, api) in record_apis.iter() {
if !api.enable_subscriptions() {
continue;
}
let id = api.conn().id();
if let Some(existing) = old.remove(&id) {
let apis = filter_record_apis(id, record_apis);
let Some(first) = apis.values().nth(0) else {
continue;
};
// Update metadata and add back.
*existing.connection_metadata.write() = first.connection_metadata().clone();
*existing.record_apis.write() = apis;
lock.insert(id, existing);
}
*state.record_apis.write() = apis;
}
});
}
@@ -474,11 +481,8 @@ impl SubscriptionManager {
user: Option<User>,
filter: Option<trailbase_qs::ValueOrComposite>,
) -> Result<AutoCleanupEventStream, RecordError> {
let table_name = api.qualified_name();
let db = table_name.database_schema.as_deref().unwrap_or("main");
return self
.get_per_connection_state(db, &api)
.get_per_connection_state(&api)
.add_table_subscription(api, user, filter)
.await;
}
@@ -489,29 +493,26 @@ impl SubscriptionManager {
record: trailbase_sqlite::Value,
user: Option<User>,
) -> Result<AutoCleanupEventStream, RecordError> {
let table_name = api.qualified_name();
let db = table_name.database_schema.as_deref().unwrap_or("main");
return self
.get_per_connection_state(db, &api)
.get_per_connection_state(&api)
.add_record_subscription(api, record, user)
.await;
}
fn get_per_connection_state(&self, db: &str, api: &RecordApi) -> Arc<PerConnectionState> {
fn get_per_connection_state(&self, api: &RecordApi) -> Arc<PerConnectionState> {
let id: usize = api.conn().id();
let mut lock = self.state.connections.upgradable_read();
if let Some(state) = lock.get(db) {
if let Some(state) = lock.get(&id) {
return state.clone();
}
return lock.with_upgraded(|m| {
return match m.entry(db.to_string()) {
return match m.entry(id) {
Entry::Occupied(v) => v.get().clone(),
Entry::Vacant(v) => {
let state = Arc::new(PerConnectionState {
conn: api.conn().clone(),
connection_metadata: RwLock::new(api.connection_metadata().clone()),
record_apis: RwLock::new(filter_record_apis(db, &self.state.record_apis.value())),
record_apis: RwLock::new(filter_record_apis(id, &self.state.record_apis.value())),
subscriptions: Default::default(),
});
v.insert(state).clone()
@@ -898,6 +899,18 @@ mod tests {
.unwrap();
assert_eq!(1, manager.num_record_subscriptions());
// Make sure rebuilding connection metadata doesn't drop subscriptions.
state.rebuild_connection_metadata().await.unwrap();
assert_eq!(1, manager.num_record_subscriptions());
// Make sure updating config doesn't drop subscriptions.
state
.validate_and_update_config(state.get_config(), None)
.await
.unwrap();
// First event is "connection established".
assert!(
decode_db_event(stream.receiver.recv().await.unwrap())

View File

@@ -4,11 +4,11 @@ use parking_lot::RwLock;
use rusqlite::fallible_iterator::FallibleIterator;
use rusqlite::hooks::PreUpdateCase;
use rusqlite::types::Value;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::ops::{Deref, DerefMut};
use std::{
fmt::{self, Debug},
sync::Arc,
};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::sync::oneshot;
use crate::error::Error;
@@ -77,6 +77,7 @@ impl Default for Options {
/// A handle to call functions in background thread.
#[derive(Clone)]
pub struct Connection {
id: usize,
reader: Sender<Message>,
writer: Sender<Message>,
conns: Arc<RwLock<ConnectionVec>>,
@@ -96,10 +97,9 @@ impl Connection {
};
let write_conn = new_conn()?;
let in_memory = write_conn.path().is_none_or(|s| {
// Returns empty string for in-memory databases.
return !s.is_empty();
});
let path = write_conn.path().map(|p| p.to_string());
// Returns empty string for in-memory databases.
let in_memory = path.as_ref().is_none_or(|s| !s.is_empty());
let n_read_threads: i64 = match (in_memory, opt.as_ref().map_or(0, |o| o.n_read_threads)) {
(true, _) => {
@@ -165,10 +165,11 @@ impl Connection {
debug!(
"Opened SQLite DB '{}' with {n_read_threads} reader threads",
conns.read().0[0].path().unwrap_or("<in-memory>")
path.as_deref().unwrap_or("<in-memory>")
);
return Ok(Self {
id: UNIQUE_CONN_ID.fetch_add(1, Ordering::SeqCst),
reader: shared_read_sender,
writer: shared_write_sender,
conns,
@@ -184,6 +185,7 @@ impl Connection {
std::thread::spawn(move || event_loop(0, conns_clone, shared_write_receiver));
return Self {
id: UNIQUE_CONN_ID.fetch_add(1, Ordering::SeqCst),
reader: shared_write_sender.clone(),
writer: shared_write_sender,
conns,
@@ -199,6 +201,10 @@ impl Connection {
return Self::new(|| Ok(rusqlite::Connection::open_in_memory()?), None);
}
pub fn id(&self) -> usize {
return self.id;
}
#[inline]
pub fn write_lock(&self) -> LockGuard<'_> {
return LockGuard {
@@ -206,20 +212,20 @@ impl Connection {
};
}
#[inline]
pub fn try_write_lock_for(&self, duration: tokio::time::Duration) -> Option<LockGuard<'_>> {
return self
.conns
.try_write_for(duration)
.map(|guard| LockGuard { guard });
}
// #[inline]
// pub fn try_write_lock_for(&self, duration: tokio::time::Duration) -> Option<LockGuard<'_>> {
// return self
// .conns
// .try_write_for(duration)
// .map(|guard| LockGuard { guard });
// }
#[inline]
pub fn write_arc_lock(&self) -> ArcLockGuard {
return ArcLockGuard {
guard: self.conns.write_arc(),
};
}
// #[inline]
// pub fn write_arc_lock(&self) -> ArcLockGuard {
// return ArcLockGuard {
// guard: self.conns.write_arc(),
// };
// }
#[inline]
pub fn try_write_arc_lock_for(&self, duration: tokio::time::Duration) -> Option<ArcLockGuard> {
@@ -545,11 +551,25 @@ impl Connection {
}
impl Debug for Connection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Connection").finish()
}
}
impl Hash for Connection {
fn hash<H: Hasher>(&self, state: &mut H) {
self.id.hash(state);
}
}
impl PartialEq for Connection {
fn eq(&self, other: &Self) -> bool {
return self.id == other.id;
}
}
impl Eq for Connection {}
fn event_loop(id: usize, conns: Arc<RwLock<ConnectionVec>>, receiver: Receiver<Message>) {
while let Ok(message) = receiver.recv() {
match message {
@@ -661,6 +681,8 @@ pub fn list_databases(conn: &rusqlite::Connection) -> Result<Vec<Database>> {
return Ok(databases);
}
static UNIQUE_CONN_ID: AtomicUsize = AtomicUsize::new(0);
#[cfg(test)]
#[path = "tests.rs"]
mod tests;

View File

@@ -4,7 +4,7 @@ use serde::Deserialize;
use std::borrow::Cow;
use crate::connection::{Connection, Database, Error, Options, extract_row_id};
use crate::{Value, ValueType, named_params, params};
use crate::{Value, ValueType};
use rusqlite::ErrorCode;
#[tokio::test]
@@ -349,6 +349,15 @@ async fn test_execute_batch_error() {
assert!(result.is_err(), "{result:?}");
}
#[tokio::test]
async fn test_identity() {
let conn0 = Connection::open_in_memory().unwrap();
let conn1 = Connection::open_in_memory().unwrap();
assert_ne!(conn0, conn1);
assert_eq!(conn0, conn0.clone());
}
#[test]
fn test_locking() {
let conn = Connection::open_in_memory().unwrap();