mirror of
https://github.com/trailbaseio/trailbase.git
synced 2025-12-30 14:19:43 -06:00
Introduce a connection identiy and use it to manage subscriptions more robustely.
Ultimately names don't matter. Subscriptions only works if the update_hook is installed on the very connection the record APIs use for mutations.
This commit is contained in:
@@ -370,7 +370,7 @@ impl AppState {
|
||||
}
|
||||
}
|
||||
None => {
|
||||
self.state.config.set(config);
|
||||
self.state.config.update(|_old| config);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -149,6 +149,7 @@ impl ConnectionManager {
|
||||
};
|
||||
}
|
||||
|
||||
// Gets called when the metadata was updated.
|
||||
pub(crate) fn add_observer(&self, o: impl Fn() + Send + Sync + 'static) {
|
||||
self.state.observers.lock().push(Box::new(o));
|
||||
}
|
||||
@@ -261,16 +262,21 @@ impl ConnectionManager {
|
||||
return Ok(Arc::new(conn));
|
||||
}
|
||||
|
||||
// Updates connection metadata for cached connections.
|
||||
pub(crate) fn rebuild_metadata(&self) -> Result<(), ConnectionError> {
|
||||
let new_metadata = Arc::new(build_metadata(
|
||||
&self.state.main.read().connection.write_lock(),
|
||||
&self.state.json_schema_registry,
|
||||
)?);
|
||||
// Main
|
||||
{
|
||||
let new_metadata = Arc::new(build_metadata(
|
||||
&self.state.main.read().connection.write_lock(),
|
||||
&self.state.json_schema_registry,
|
||||
)?);
|
||||
|
||||
self.state.main.write().metadata = new_metadata;
|
||||
self.state.main.write().metadata = new_metadata;
|
||||
}
|
||||
|
||||
// Others:
|
||||
for (key, entry) in self.state.connections.iter() {
|
||||
let metadata = Arc::new(build_metadata(
|
||||
let new_metadata = Arc::new(build_metadata(
|
||||
&entry.connection.write_lock(),
|
||||
&self.state.json_schema_registry,
|
||||
)?);
|
||||
@@ -279,7 +285,7 @@ impl ConnectionManager {
|
||||
key,
|
||||
ConnectionEntry {
|
||||
connection: entry.connection.clone(),
|
||||
metadata,
|
||||
metadata: new_metadata,
|
||||
},
|
||||
true,
|
||||
);
|
||||
|
||||
@@ -10,8 +10,7 @@ use pin_project_lite::pin_project;
|
||||
use reactivate::Reactive;
|
||||
use rusqlite::hooks::{Action, PreUpdateCase};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::{HashMap, hash_map::Entry};
|
||||
use std::pin::Pin;
|
||||
use std::sync::{
|
||||
Arc,
|
||||
@@ -63,9 +62,12 @@ impl Drop for AutoCleanupEventStreamState {
|
||||
if self.receiver.upgrade().is_some() {
|
||||
let id = std::mem::take(&mut self.id);
|
||||
let state = self.state.clone();
|
||||
self.state.conn.call_and_forget(move |conn| {
|
||||
state.remove_subscription(conn, id);
|
||||
});
|
||||
|
||||
if let Some(first) = self.state.record_apis.read().values().nth(0) {
|
||||
first.conn().call_and_forget(move |conn| {
|
||||
state.remove_subscription(conn, id);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
debug!("Subscription cleaned up already by the sender side.");
|
||||
}
|
||||
@@ -147,12 +149,11 @@ struct Subscriptions {
|
||||
}
|
||||
|
||||
struct PerConnectionState {
|
||||
// TODO: Should this be a weak ref?
|
||||
conn: Arc<trailbase_sqlite::Connection>,
|
||||
|
||||
// Metadata: always updated together when config -> record APIs change.
|
||||
connection_metadata: RwLock<Arc<ConnectionMetadata>>,
|
||||
/// Metadata: always updated together when config -> record APIs change.
|
||||
record_apis: RwLock<HashMap<String, RecordApi>>,
|
||||
/// Denormalized metadata. We could also grab this from:
|
||||
/// `record_apis.read().nth(0).unwrap().connection_metadata()`.
|
||||
connection_metadata: RwLock<Arc<ConnectionMetadata>>,
|
||||
|
||||
/// Map from table name to row id to list of subscriptions.
|
||||
///
|
||||
@@ -383,9 +384,11 @@ impl PerConnectionState {
|
||||
|
||||
impl Drop for PerConnectionState {
|
||||
fn drop(&mut self) {
|
||||
self
|
||||
.conn
|
||||
.call_and_forget(|conn| conn.preupdate_hook(NO_HOOK));
|
||||
if let Some(first) = self.record_apis.read().values().nth(0) {
|
||||
first
|
||||
.conn()
|
||||
.call_and_forget(|conn| conn.preupdate_hook(NO_HOOK));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -394,14 +397,8 @@ struct ManagerState {
|
||||
/// Record API configurations.
|
||||
record_apis: Reactive<Arc<Vec<(String, RecordApi)>>>,
|
||||
|
||||
/// Manages subscriptions for differents databases and connections respectively.
|
||||
///
|
||||
/// TODO: There's a disconnect. Connections can have many db names.
|
||||
/// We should probably:
|
||||
/// * establish a connection identity so we can establish a RecordAPI to connection mapping
|
||||
/// * keep only a weak ref to connections (for them to die)
|
||||
/// * And turn this into a list here.
|
||||
connections: RwLock<HashMap</* db_name= */ String, Arc<PerConnectionState>>>,
|
||||
/// Manages subscriptions for different connections based on `conn.id()`.
|
||||
connections: RwLock<HashMap</* conn id= */ usize, Arc<PerConnectionState>>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -417,21 +414,17 @@ struct ContinuationState {
|
||||
record_values: Vec<rusqlite::types::Value>,
|
||||
}
|
||||
|
||||
fn filter_record_apis(db: &str, record_apis: &[(String, RecordApi)]) -> HashMap<String, RecordApi> {
|
||||
fn filter_record_apis(
|
||||
conn_id: usize,
|
||||
record_apis: &[(String, RecordApi)],
|
||||
) -> HashMap<String, RecordApi> {
|
||||
return record_apis
|
||||
.iter()
|
||||
.flat_map(|(name, api)| {
|
||||
if !api.enable_subscriptions() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if api
|
||||
.qualified_name()
|
||||
.database_schema
|
||||
.as_deref()
|
||||
.unwrap_or("main")
|
||||
== db
|
||||
{
|
||||
if api.conn().id() == conn_id {
|
||||
return Some((name.to_string(), api.clone()));
|
||||
}
|
||||
|
||||
@@ -450,17 +443,31 @@ impl SubscriptionManager {
|
||||
{
|
||||
let state = state.clone();
|
||||
record_apis.add_observer(move |record_apis| {
|
||||
// FIXME:: We need to do some more bookkeeping, e.g. remove pre-update hooks for no
|
||||
// longer existing APIs or APIs which are no-longer subscribable.
|
||||
// FIXME: Reload currently depends on ConnectionManager's cache to retain **all**
|
||||
// connections. Currently, subscriptions would currently get cancelled when old Connections
|
||||
// get evicted and new ones established. This is when RecordApis get rebuild even if
|
||||
// nothing changed.
|
||||
let mut lock = state.connections.write();
|
||||
|
||||
let connections = state.connections.read();
|
||||
for (db_name, state) in connections.iter() {
|
||||
let apis = filter_record_apis(db_name, record_apis);
|
||||
let mut old: HashMap<usize, Arc<PerConnectionState>> = std::mem::take(&mut lock);
|
||||
|
||||
if let Some(first) = apis.values().nth(0) {
|
||||
*state.connection_metadata.write() = first.connection_metadata().clone();
|
||||
for (_name, api) in record_apis.iter() {
|
||||
if !api.enable_subscriptions() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let id = api.conn().id();
|
||||
if let Some(existing) = old.remove(&id) {
|
||||
let apis = filter_record_apis(id, record_apis);
|
||||
let Some(first) = apis.values().nth(0) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
// Update metadata and add back.
|
||||
*existing.connection_metadata.write() = first.connection_metadata().clone();
|
||||
*existing.record_apis.write() = apis;
|
||||
lock.insert(id, existing);
|
||||
}
|
||||
*state.record_apis.write() = apis;
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -474,11 +481,8 @@ impl SubscriptionManager {
|
||||
user: Option<User>,
|
||||
filter: Option<trailbase_qs::ValueOrComposite>,
|
||||
) -> Result<AutoCleanupEventStream, RecordError> {
|
||||
let table_name = api.qualified_name();
|
||||
let db = table_name.database_schema.as_deref().unwrap_or("main");
|
||||
|
||||
return self
|
||||
.get_per_connection_state(db, &api)
|
||||
.get_per_connection_state(&api)
|
||||
.add_table_subscription(api, user, filter)
|
||||
.await;
|
||||
}
|
||||
@@ -489,29 +493,26 @@ impl SubscriptionManager {
|
||||
record: trailbase_sqlite::Value,
|
||||
user: Option<User>,
|
||||
) -> Result<AutoCleanupEventStream, RecordError> {
|
||||
let table_name = api.qualified_name();
|
||||
let db = table_name.database_schema.as_deref().unwrap_or("main");
|
||||
|
||||
return self
|
||||
.get_per_connection_state(db, &api)
|
||||
.get_per_connection_state(&api)
|
||||
.add_record_subscription(api, record, user)
|
||||
.await;
|
||||
}
|
||||
|
||||
fn get_per_connection_state(&self, db: &str, api: &RecordApi) -> Arc<PerConnectionState> {
|
||||
fn get_per_connection_state(&self, api: &RecordApi) -> Arc<PerConnectionState> {
|
||||
let id: usize = api.conn().id();
|
||||
let mut lock = self.state.connections.upgradable_read();
|
||||
if let Some(state) = lock.get(db) {
|
||||
if let Some(state) = lock.get(&id) {
|
||||
return state.clone();
|
||||
}
|
||||
|
||||
return lock.with_upgraded(|m| {
|
||||
return match m.entry(db.to_string()) {
|
||||
return match m.entry(id) {
|
||||
Entry::Occupied(v) => v.get().clone(),
|
||||
Entry::Vacant(v) => {
|
||||
let state = Arc::new(PerConnectionState {
|
||||
conn: api.conn().clone(),
|
||||
connection_metadata: RwLock::new(api.connection_metadata().clone()),
|
||||
record_apis: RwLock::new(filter_record_apis(db, &self.state.record_apis.value())),
|
||||
record_apis: RwLock::new(filter_record_apis(id, &self.state.record_apis.value())),
|
||||
subscriptions: Default::default(),
|
||||
});
|
||||
v.insert(state).clone()
|
||||
@@ -898,6 +899,18 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(1, manager.num_record_subscriptions());
|
||||
|
||||
// Make sure rebuilding connection metadata doesn't drop subscriptions.
|
||||
state.rebuild_connection_metadata().await.unwrap();
|
||||
|
||||
assert_eq!(1, manager.num_record_subscriptions());
|
||||
|
||||
// Make sure updating config doesn't drop subscriptions.
|
||||
state
|
||||
.validate_and_update_config(state.get_config(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// First event is "connection established".
|
||||
assert!(
|
||||
decode_db_event(stream.receiver.recv().await.unwrap())
|
||||
|
||||
@@ -4,11 +4,11 @@ use parking_lot::RwLock;
|
||||
use rusqlite::fallible_iterator::FallibleIterator;
|
||||
use rusqlite::hooks::PreUpdateCase;
|
||||
use rusqlite::types::Value;
|
||||
use std::fmt::Debug;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::{
|
||||
fmt::{self, Debug},
|
||||
sync::Arc,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
use crate::error::Error;
|
||||
@@ -77,6 +77,7 @@ impl Default for Options {
|
||||
/// A handle to call functions in background thread.
|
||||
#[derive(Clone)]
|
||||
pub struct Connection {
|
||||
id: usize,
|
||||
reader: Sender<Message>,
|
||||
writer: Sender<Message>,
|
||||
conns: Arc<RwLock<ConnectionVec>>,
|
||||
@@ -96,10 +97,9 @@ impl Connection {
|
||||
};
|
||||
|
||||
let write_conn = new_conn()?;
|
||||
let in_memory = write_conn.path().is_none_or(|s| {
|
||||
// Returns empty string for in-memory databases.
|
||||
return !s.is_empty();
|
||||
});
|
||||
let path = write_conn.path().map(|p| p.to_string());
|
||||
// Returns empty string for in-memory databases.
|
||||
let in_memory = path.as_ref().is_none_or(|s| !s.is_empty());
|
||||
|
||||
let n_read_threads: i64 = match (in_memory, opt.as_ref().map_or(0, |o| o.n_read_threads)) {
|
||||
(true, _) => {
|
||||
@@ -165,10 +165,11 @@ impl Connection {
|
||||
|
||||
debug!(
|
||||
"Opened SQLite DB '{}' with {n_read_threads} reader threads",
|
||||
conns.read().0[0].path().unwrap_or("<in-memory>")
|
||||
path.as_deref().unwrap_or("<in-memory>")
|
||||
);
|
||||
|
||||
return Ok(Self {
|
||||
id: UNIQUE_CONN_ID.fetch_add(1, Ordering::SeqCst),
|
||||
reader: shared_read_sender,
|
||||
writer: shared_write_sender,
|
||||
conns,
|
||||
@@ -184,6 +185,7 @@ impl Connection {
|
||||
std::thread::spawn(move || event_loop(0, conns_clone, shared_write_receiver));
|
||||
|
||||
return Self {
|
||||
id: UNIQUE_CONN_ID.fetch_add(1, Ordering::SeqCst),
|
||||
reader: shared_write_sender.clone(),
|
||||
writer: shared_write_sender,
|
||||
conns,
|
||||
@@ -199,6 +201,10 @@ impl Connection {
|
||||
return Self::new(|| Ok(rusqlite::Connection::open_in_memory()?), None);
|
||||
}
|
||||
|
||||
pub fn id(&self) -> usize {
|
||||
return self.id;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn write_lock(&self) -> LockGuard<'_> {
|
||||
return LockGuard {
|
||||
@@ -206,20 +212,20 @@ impl Connection {
|
||||
};
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn try_write_lock_for(&self, duration: tokio::time::Duration) -> Option<LockGuard<'_>> {
|
||||
return self
|
||||
.conns
|
||||
.try_write_for(duration)
|
||||
.map(|guard| LockGuard { guard });
|
||||
}
|
||||
// #[inline]
|
||||
// pub fn try_write_lock_for(&self, duration: tokio::time::Duration) -> Option<LockGuard<'_>> {
|
||||
// return self
|
||||
// .conns
|
||||
// .try_write_for(duration)
|
||||
// .map(|guard| LockGuard { guard });
|
||||
// }
|
||||
|
||||
#[inline]
|
||||
pub fn write_arc_lock(&self) -> ArcLockGuard {
|
||||
return ArcLockGuard {
|
||||
guard: self.conns.write_arc(),
|
||||
};
|
||||
}
|
||||
// #[inline]
|
||||
// pub fn write_arc_lock(&self) -> ArcLockGuard {
|
||||
// return ArcLockGuard {
|
||||
// guard: self.conns.write_arc(),
|
||||
// };
|
||||
// }
|
||||
|
||||
#[inline]
|
||||
pub fn try_write_arc_lock_for(&self, duration: tokio::time::Duration) -> Option<ArcLockGuard> {
|
||||
@@ -545,11 +551,25 @@ impl Connection {
|
||||
}
|
||||
|
||||
impl Debug for Connection {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Connection").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Hash for Connection {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
self.id.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for Connection {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
return self.id == other.id;
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for Connection {}
|
||||
|
||||
fn event_loop(id: usize, conns: Arc<RwLock<ConnectionVec>>, receiver: Receiver<Message>) {
|
||||
while let Ok(message) = receiver.recv() {
|
||||
match message {
|
||||
@@ -661,6 +681,8 @@ pub fn list_databases(conn: &rusqlite::Connection) -> Result<Vec<Database>> {
|
||||
return Ok(databases);
|
||||
}
|
||||
|
||||
static UNIQUE_CONN_ID: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests.rs"]
|
||||
mod tests;
|
||||
|
||||
@@ -4,7 +4,7 @@ use serde::Deserialize;
|
||||
use std::borrow::Cow;
|
||||
|
||||
use crate::connection::{Connection, Database, Error, Options, extract_row_id};
|
||||
use crate::{Value, ValueType, named_params, params};
|
||||
use crate::{Value, ValueType};
|
||||
use rusqlite::ErrorCode;
|
||||
|
||||
#[tokio::test]
|
||||
@@ -349,6 +349,15 @@ async fn test_execute_batch_error() {
|
||||
assert!(result.is_err(), "{result:?}");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_identity() {
|
||||
let conn0 = Connection::open_in_memory().unwrap();
|
||||
let conn1 = Connection::open_in_memory().unwrap();
|
||||
|
||||
assert_ne!(conn0, conn1);
|
||||
assert_eq!(conn0, conn0.clone());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_locking() {
|
||||
let conn = Connection::open_in_memory().unwrap();
|
||||
|
||||
Reference in New Issue
Block a user