diff --git a/Cargo.lock b/Cargo.lock index c187cc1d..153d1173 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -715,6 +715,26 @@ dependencies = [ "serde", ] +[[package]] +name = "bindgen" +version = "0.69.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +dependencies = [ + "bitflags", + "cexpr", + "clang-sys", + "itertools 0.10.5", + "lazy_static", + "lazycell", + "proc-macro2", + "quote", + "regex", + "rustc-hash 1.1.0", + "shlex", + "syn 2.0.95", +] + [[package]] name = "bindgen" version = "0.70.1" @@ -3028,6 +3048,12 @@ dependencies = [ "spin", ] +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "lettre" version = "0.11.11" @@ -3097,6 +3123,7 @@ version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" dependencies = [ + "bindgen 0.69.5", "cc", "pkg-config", "vcpkg", diff --git a/Cargo.toml b/Cargo.toml index 46b949f4..417c36d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ rusqlite = { version = "^0.32.1", default-features = false, features = [ "limits", "backup", "hooks", + "preupdate_hook", ] } trailbase-sqlean = { path = "vendor/sqlean", version = "^0.0.1" } trailbase-extension = { path = "trailbase-extension", version = "^0.1.0" } diff --git a/client/trailbase-dart/lib/src/client.dart b/client/trailbase-dart/lib/src/client.dart index d7bf1f21..7733eae7 100644 --- a/client/trailbase-dart/lib/src/client.dart +++ b/client/trailbase-dart/lib/src/client.dart @@ -5,6 +5,8 @@ import 'package:jwt_decoder/jwt_decoder.dart'; import 'package:logging/logging.dart'; import 'package:dio/dio.dart' as dio; +import 'sse.dart'; + class User { final String id; final String email; @@ -272,6 +274,22 @@ class RecordApi { ); } + Future>> subscribe(RecordId id) async { + final resp = await _client.fetch( + '${RecordApi._recordApi}/${_name}/subscribe/${id}', + responseType: dio.ResponseType.stream, + ); + + final Stream stream = resp.data.stream; + return stream.asyncMap((Uint8List bytes) { + final decoded = utf8.decode(bytes); + if (decoded.startsWith('data: ')) { + return jsonDecode(decoded.substring(6)); + } + return jsonDecode(decoded); + }); + } + Uri imageUri(RecordId id, String colName, {int? index}) { if (index != null) { return Uri.parse( @@ -283,7 +301,7 @@ class RecordApi { } class _ThinClient { - static final _dio = dio.Dio(); + static final _dio = dio.Dio()..interceptors.add(SeeInterceptor()); final String site; @@ -295,6 +313,7 @@ class _ThinClient { Object? data, String? method, Map? queryParams, + dio.ResponseType? responseType, }) async { if (path.startsWith('/')) { throw Exception('Path starts with "/". Relative path expected.'); @@ -308,6 +327,7 @@ class _ThinClient { method: method, headers: tokenState.headers, validateStatus: (int? status) => true, + responseType: responseType, ), ); @@ -508,6 +528,7 @@ class Client { Object? data, String? method, Map? queryParams, + dio.ResponseType? 
responseType, }) async { var tokenState = _tokenState; final refreshToken = _shouldRefresh(tokenState); @@ -515,8 +536,14 @@ class Client { tokenState = _tokenState = await _refreshTokensImpl(refreshToken); } - final response = await _client.fetch(path, tokenState, - data: data, method: method, queryParams: queryParams); + final response = await _client.fetch( + path, + tokenState, + data: data, + method: method, + queryParams: queryParams, + responseType: responseType, + ); if (response.statusCode != 200 && (throwOnError ?? true)) { final errMsg = await response.data; diff --git a/client/trailbase-dart/lib/src/sse.dart b/client/trailbase-dart/lib/src/sse.dart new file mode 100644 index 00000000..cce5d0a6 --- /dev/null +++ b/client/trailbase-dart/lib/src/sse.dart @@ -0,0 +1,50 @@ +import 'dart:async'; +import 'dart:typed_data'; + +import 'package:dio/dio.dart' as dio; + +class SeeInterceptor extends dio.Interceptor { + @override + void onResponse( + dio.Response response, dio.ResponseInterceptorHandler handler) { + if (response.requestOptions.responseType == dio.ResponseType.stream) { + final Stream stream = response.data.stream; + + final buffer = BytesBuilder(); + final transformedStream = stream.transform( + StreamTransformer.fromHandlers( + handleData: (Uint8List data, EventSink sink) { + // If terminated correctly (\n\n) write to sink, otherwise buffer. + if (endsWithNewlineNewline(data)) { + if (buffer.isNotEmpty) { + buffer.add(data); + sink.add(buffer.takeBytes()); + } else { + sink.add(data); + } + } else { + buffer.add(data); + } + }, + ), + ); + + return handler.resolve(dio.Response( + requestOptions: response.requestOptions, + data: dio.ResponseBody(transformedStream, response.data.contentLength), + statusCode: response.statusCode, + headers: response.headers, + )); + } + + handler.next(response); + } + + bool endsWithNewlineNewline(List bytes) { + if (bytes.length < 2) { + return false; + } + + return bytes[bytes.length - 1] == 10 && bytes[bytes.length - 2] == 10; + } +} diff --git a/client/trailbase-dart/lib/trailbase.dart b/client/trailbase-dart/lib/trailbase.dart index 06a947f8..1d840900 100644 --- a/client/trailbase-dart/lib/trailbase.dart +++ b/client/trailbase-dart/lib/trailbase.dart @@ -2,3 +2,4 @@ library; export 'src/client.dart'; export 'src/pkce.dart'; +export 'src/sse.dart'; diff --git a/client/trailbase-dart/test/trailbase_test.dart b/client/trailbase-dart/test/trailbase_test.dart index a5638830..7b9dc54d 100644 --- a/client/trailbase-dart/test/trailbase_test.dart +++ b/client/trailbase-dart/test/trailbase_test.dart @@ -57,14 +57,19 @@ Future initTrailBase() async { print('Trying to connect to TrailBase'); } - await Future.delayed(Duration(milliseconds: 500)); + if (await process.exitCode + .timeout(Duration(milliseconds: 500), onTimeout: () => -1) >= + 0) { + break; + } } process.kill(ProcessSignal.sigkill); final exitCode = await process.exitCode; - await process.stdout.forEach(print); - await process.stderr.forEach(print); + await process.stderr.forEach(stdout.add); + await process.stdout.forEach(stdout.add); + throw Exception('Cargo run failed: ${exitCode}.'); } @@ -73,7 +78,15 @@ Future main() async { throw Exception('Unexpected working directory'); } - await initTrailBase(); + final process = await initTrailBase(); + + tearDownAll(() async { + process.kill(ProcessSignal.sigkill); + final _ = await process.exitCode; + + // await process.stderr.forEach(stdout.add); + // await process.stdout.forEach(stdout.add); + }); group('client tests', () { test('auth', () 
async { @@ -159,5 +172,27 @@ Future main() async { await api.delete(ids[0]); expect(() async => await api.read(ids[0]), throwsException); }); + + test('realtime', () async { + final client = await connect(); + final api = client.records('simple_strict_table'); + + final int now = DateTime.now().millisecondsSinceEpoch ~/ 1000; + final id = await api + .create({'text_not_null': 'dart client realtime test 0: =?&${now}'}); + + final events = await api.subscribe(id); + + final updatedMessage = 'dart client updated realtime test 0: ${now}'; + await api.update(id, {'text_not_null': updatedMessage}); + await api.delete(id); + + final eventList = + await events.timeout(Duration(seconds: 10), onTimeout: (sink) { + print('Stream timeout'); + sink.close(); + }).toList(); + expect(eventList.length, equals(2)); + }); }); } diff --git a/trailbase-core/src/app_state.rs b/trailbase-core/src/app_state.rs index f18e7c4f..e97a63bd 100644 --- a/trailbase-core/src/app_state.rs +++ b/trailbase-core/src/app_state.rs @@ -11,6 +11,7 @@ use crate::constants::SITE_URL_DEFAULT; use crate::data_dir::DataDir; use crate::email::Mailer; use crate::js::RuntimeHandle; +use crate::records::subscribe::SubscriptionManager; use crate::records::RecordApi; use crate::table_metadata::TableMetadataCache; use crate::value_notifier::{Computed, ValueNotifier}; @@ -33,6 +34,7 @@ struct InternalState { jwt: JwtHelper, table_metadata: TableMetadataCache, + subscription_manager: SubscriptionManager, object_store: Box, runtime: RuntimeHandle, @@ -67,6 +69,22 @@ impl AppState { let table_metadata_clone = args.table_metadata.clone(); let conn_clone = args.conn.clone(); + let record_apis = Computed::new(&config, move |c| { + return c + .record_apis + .iter() + .filter_map(|config| { + match build_record_api(conn_clone.clone(), &table_metadata_clone, config.clone()) { + Ok(api) => Some((api.api_name().to_string(), api)), + Err(err) => { + error!("{err}"); + None + } + } + }) + .collect::>(); + }); + let runtime = args .js_runtime_threads .map_or_else(RuntimeHandle::new, RuntimeHandle::new_with_threads); @@ -87,26 +105,13 @@ impl AppState { } }), mailer: build_mailer(&config, None), - record_apis: Computed::new(&config, move |c| { - return c - .record_apis - .iter() - .filter_map(|config| { - match build_record_api(conn_clone.clone(), &table_metadata_clone, config.clone()) { - Ok(api) => Some((api.api_name().to_string(), api)), - Err(err) => { - error!("{err}"); - None - } - } - }) - .collect::>(); - }), + record_apis: record_apis.clone(), config, conn: args.conn.clone(), logs_conn: args.logs_conn, jwt: args.jwt, - table_metadata: args.table_metadata, + table_metadata: args.table_metadata.clone(), + subscription_manager: SubscriptionManager::new(args.conn, args.table_metadata, record_apis), object_store: args.object_store, runtime, #[cfg(test)] @@ -145,6 +150,10 @@ impl AppState { return &self.state.table_metadata; } + pub(crate) fn subscription_manager(&self) -> &SubscriptionManager { + return &self.state.subscription_manager; + } + pub async fn refresh_table_cache(&self) -> Result<(), crate::table_metadata::TableLookupError> { self.table_metadata().invalidate_all().await } @@ -360,6 +369,23 @@ pub async fn test_state(options: Option) -> anyhow::Result>(); + }); + let runtime = RuntimeHandle::new(); runtime.set_connection(conn.clone()); @@ -372,27 +398,13 @@ pub async fn test_state(options: Option) -> anyhow::Result>(); - }), + record_apis: record_apis.clone(), config, - conn, + conn: conn.clone(), logs_conn, jwt: jwt::test_jwt_helper(), - 
table_metadata, + table_metadata: table_metadata.clone(), + subscription_manager: SubscriptionManager::new(conn, table_metadata, record_apis), object_store, runtime, cleanup: vec![Box::new(temp_dir)], diff --git a/trailbase-core/src/records/mod.rs b/trailbase-core/src/records/mod.rs index 3e92aeec..51bad382 100644 --- a/trailbase-core/src/records/mod.rs +++ b/trailbase-core/src/records/mod.rs @@ -14,6 +14,7 @@ mod list_records; pub(crate) mod read_record; mod record_api; pub mod sql_to_json; +pub(crate) mod subscribe; pub mod test_utils; mod update_record; mod validate; @@ -76,6 +77,10 @@ pub(crate) fn router() -> Router { .route( &format!("/{RECORD_API_PATH}/{{name}}/schema"), get(json_schema::json_schema_handler), + ) + .route( + &format!("/{RECORD_API_PATH}/{{name}}/subscribe/{{record}}"), + get(subscribe::add_subscription_sse_handler), ); } diff --git a/trailbase-core/src/records/record_api.rs b/trailbase-core/src/records/record_api.rs index d20a66bd..3f17e97a 100644 --- a/trailbase-core/src/records/record_api.rs +++ b/trailbase-core/src/records/record_api.rs @@ -339,23 +339,23 @@ impl RecordApi { /// Check if the given user (if any) can access a record given the request and the operation. #[allow(unused)] - pub fn check_record_level_access_sync( + pub fn check_record_level_read_access( &self, - conn: &mut rusqlite::Connection, + conn: &rusqlite::Connection, p: Permission, - record_id: Option<&Value>, - request_params: Option<&mut LazyParams<'_>>, + record: Vec<(&str, rusqlite::types::Value)>, user: Option<&User>, ) -> Result<(), RecordError> { // First check table level access and if present check row-level access based on access rule. self.check_table_level_access(p, user)?; - let Some(access_query) = self.access_query(p) else { + let Some(access_rule) = self.access_rule(p) else { return Ok(()); }; - let params = self.build_named_params(p, record_id, request_params, user)?; - match Self::query_access(conn, access_query, params) { + let (query, params) = build_query_and_params_for_record_read(access_rule, user, record); + + match Self::query_access(conn, &query, params) { Ok(allowed) => { if allowed { return Ok(()); @@ -389,7 +389,7 @@ impl RecordApi { #[inline] fn query_access( - conn: &mut rusqlite::Connection, + conn: &rusqlite::Connection, access_query: &str, params: NamedParams, ) -> Result { @@ -453,6 +453,45 @@ impl RecordApi { } } +fn build_query_and_params_for_record_read( + access_rule: &str, + user: Option<&User>, + record: Vec<(&str, rusqlite::types::Value)>, +) -> (String, NamedParams) { + let row = record + .iter() + .map(|(name, _value)| { + return format!(":__v_{name} AS {name}"); + }) + .collect::>() + .join(", "); + + let mut params: Vec<_> = record + .into_iter() + .map(|(name, value)| { + return (Cow::Owned(format!(":__v_{name}")), value); + }) + .collect(); + + params.push(( + Cow::Borrowed(":__user_id"), + user.map_or(Value::Null, |u| Value::Blob(u.uuid.into())), + )); + + // Assumes access_rule is an expression: https://www.sqlite.org/syntax/expr.html + let query = indoc::formatdoc!( + r#" + SELECT + ({access_rule}) + FROM + (SELECT :__user_id AS id) AS _USER_, + (SELECT {row}) AS _ROW_ + "# + ); + + return (query, params); +} + /// Build access query for record reads, deletes and query access. 
/// /// Assumes access_rule is an expression: https://www.sqlite.org/syntax/expr.html diff --git a/trailbase-core/src/records/sql_to_json.rs b/trailbase-core/src/records/sql_to_json.rs index 1384bf68..d3054dad 100644 --- a/trailbase-core/src/records/sql_to_json.rs +++ b/trailbase-core/src/records/sql_to_json.rs @@ -19,20 +19,20 @@ pub enum JsonError { ValueNotFound, } -fn value_to_json(value: rusqlite::types::Value) -> Result { +pub(crate) fn value_to_json(value: rusqlite::types::Value) -> Result { + use rusqlite::types::Value; + return Ok(match value { - rusqlite::types::Value::Null => serde_json::Value::Null, - rusqlite::types::Value::Real(real) => { + Value::Null => serde_json::Value::Null, + Value::Real(real) => { let Some(number) = serde_json::Number::from_f64(real) else { return Err(JsonError::Finite); }; serde_json::Value::Number(number) } - rusqlite::types::Value::Integer(integer) => { - serde_json::Value::Number(serde_json::Number::from(integer)) - } - rusqlite::types::Value::Blob(blob) => serde_json::Value::String(BASE64_URL_SAFE.encode(blob)), - rusqlite::types::Value::Text(text) => serde_json::Value::String(text), + Value::Integer(integer) => serde_json::Value::Number(serde_json::Number::from(integer)), + Value::Blob(blob) => serde_json::Value::String(BASE64_URL_SAFE.encode(blob)), + Value::Text(text) => serde_json::Value::String(text), }); } diff --git a/trailbase-core/src/records/subscribe.rs b/trailbase-core/src/records/subscribe.rs new file mode 100644 index 00000000..29bb51e6 --- /dev/null +++ b/trailbase-core/src/records/subscribe.rs @@ -0,0 +1,560 @@ +use axum::{ + extract::{Path, State}, + response::sse::{Event, KeepAlive, Sse}, +}; +use futures::stream::{Stream, StreamExt}; +use parking_lot::RwLock; +use rusqlite::hooks::{Action, PreUpdateCase}; +use serde::Serialize; +use std::collections::HashMap; +use std::sync::{ + atomic::{AtomicI64, Ordering}, + Arc, +}; +use trailbase_sqlite::{ + connection::{extract_record_values, extract_row_id}, + params, +}; + +use crate::auth::user::User; +use crate::records::sql_to_json::value_to_json; +use crate::records::RecordApi; +use crate::records::{Permission, RecordError}; +use crate::table_metadata::{TableMetadata, TableMetadataCache}; +use crate::value_notifier::Computed; +use crate::AppState; + +static SUBSCRIPTION_COUNTER: AtomicI64 = AtomicI64::new(0); + +// TODO: +// * clients +// * table-wide subscriptions +// * optimize: avoid repeated encoding of events. Easy to do but makes testing harder since there's +// no good way to parse sse::Event back :/. We should probably just bite the bullet and parse, +// it's literally "data: \n\n". + +type SseEvent = Result; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize)] +pub enum RecordAction { + Delete, + Insert, + Update, +} + +impl From for RecordAction { + fn from(value: Action) -> Self { + return match value { + Action::SQLITE_DELETE => RecordAction::Delete, + Action::SQLITE_INSERT => RecordAction::Insert, + Action::SQLITE_UPDATE => RecordAction::Update, + _ => unreachable!("{value:?}"), + }; + } +} + +#[derive(Debug, Clone, Serialize)] +pub enum DbEvent { + Update(Option), + Insert(Option), + Delete(Option), + Error(String), +} + +// pub struct SubscriptionId { +// table_name: String, +// row_id: i64, +// subscription_id: i64, +// } + +pub struct Subscription { + /// Id uniquely identifying this subscription. + subscription_id: i64, + /// Name of the API this subscription is subscribed to. 
We need to lookup the Record API on the + /// hot path to make sure we're getting the latest configuration. + record_api_name: String, + + /// Record id present for subscriptions to specific records. + // record_id: Option, + user: Option, + /// Channel for sending events to the SSE handler. + channel: async_channel::Sender, +} + +/// Internal, shareable state of the cloneable SubscriptionManager. +struct ManagerState { + /// SQLite connection to monitor. + conn: trailbase_sqlite::Connection, + /// Table metadata for mapping column indexes to column names needed for building JSON encoded + /// records. + table_metadata: TableMetadataCache, + /// Record API configurations. + record_apis: Computed, crate::config::proto::Config>, + + /// Map from table name to row id to list of subscriptions. + subscriptions: RwLock>>>, +} + +impl ManagerState { + fn lookup_record_api(&self, name: &str) -> Option { + for (record_api_name, record_api) in self.record_apis.load().iter() { + if record_api_name == name { + return Some(record_api.clone()); + } + } + return None; + } +} + +#[derive(Clone)] +pub struct SubscriptionManager { + state: Arc, +} + +struct ContinuationState { + state: Arc, + table_metadata: Arc, + action: RecordAction, + table_name: String, + rowid: i64, + record_values: Vec, +} + +impl SubscriptionManager { + pub fn new( + conn: trailbase_sqlite::Connection, + table_metadata: TableMetadataCache, + record_apis: Computed, crate::config::proto::Config>, + ) -> Self { + return Self { + state: Arc::new(ManagerState { + conn, + table_metadata, + record_apis, + + subscriptions: RwLock::new(HashMap::new()), + }), + }; + } + + pub fn num_subscriptions(&self) -> usize { + let mut count: usize = 0; + for table in self.state.subscriptions.read().values() { + for record in table.values() { + count += record.len(); + } + } + return count; + } + + /// Preupdate hook that runs in a continuation of the trailbase-sqlite executor. + fn hook_continuation(conn: &rusqlite::Connection, state: ContinuationState) { + let ContinuationState { + state, + table_metadata, + table_name, + action, + rowid, + record_values, + } = state; + + let s = &state; + + let mut read_lock = s.subscriptions.upgradable_read(); + let Some(subs) = read_lock.get(&table_name).and_then(|m| m.get(&rowid)) else { + return; + }; + + // Join values with column names. + let record: Vec<_> = record_values + .into_iter() + .enumerate() + .map(|(idx, v)| (table_metadata.schema.columns[idx].name.as_str(), v)) + .collect(); + + // Build a JSON-encoded SQLite event (insert, update, delete). + let event = { + let json_value = serde_json::Value::Object( + record + .iter() + .filter_map(|(name, value)| { + if let Ok(v) = value_to_json(value.clone()) { + return Some(((*name).to_string(), v)); + }; + return None; + }) + .collect(), + ); + + match action { + RecordAction::Delete => DbEvent::Delete(Some(json_value)), + RecordAction::Insert => DbEvent::Insert(Some(json_value)), + RecordAction::Update => DbEvent::Update(Some(json_value)), + } + }; + + let mut dead_subscriptions: Vec = vec![]; + for (idx, sub) in subs.iter().enumerate() { + let Some(api) = s.lookup_record_api(&sub.record_api_name) else { + dead_subscriptions.push(idx); + continue; + }; + + if let Err(_err) = api.check_record_level_read_access( + conn, + Permission::Read, + // TODO: Maybe we could inject ValueRef instead to avoid repeated cloning. + record.clone(), + sub.user.as_ref(), + ) { + // This can happen if the record api configuration has changed since originally + // subscribed. 
In this case we just send and error and cancel the subscription. + let _ = sub.channel.try_send(DbEvent::Error("Access denied".into())); + dead_subscriptions.push(idx); + continue; + } + + // TODO: Avoid cloning the event/record over and over. + match sub.channel.try_send(event.clone()) { + Ok(_) => {} + Err(async_channel::TrySendError::Full(ev)) => { + log::warn!("Channel full, dropping event: {ev:?}"); + } + Err(async_channel::TrySendError::Closed(_ev)) => { + dead_subscriptions.push(idx); + } + } + } + + if dead_subscriptions.is_empty() && action != RecordAction::Delete { + // No cleanup needed. + return; + } + + read_lock.with_upgraded(move |subscriptions| { + let Some(table_subscriptions) = subscriptions.get_mut(&table_name) else { + return; + }; + + if action == RecordAction::Delete { + // Also drops the channel and thus automatically closes the SSE connection. + table_subscriptions.remove(&rowid); + + if table_subscriptions.is_empty() { + subscriptions.remove(&table_name); + if subscriptions.is_empty() { + conn.preupdate_hook(NO_HOOK); + } + } + + return; + } + + if let Some(m) = table_subscriptions.get_mut(&rowid) { + for idx in dead_subscriptions.iter().rev() { + m.swap_remove(*idx); + } + + if m.is_empty() { + table_subscriptions.remove(&rowid); + + if table_subscriptions.is_empty() { + subscriptions.remove(&table_name); + if subscriptions.is_empty() { + conn.preupdate_hook(NO_HOOK); + } + } + } + } + }); + } + + async fn add_hook(&self) -> trailbase_sqlite::connection::Result<()> { + let state = &self.state; + let conn = state.conn.clone(); + let s = state.clone(); + + return state + .conn + .add_preupdate_hook(Some( + move |action: Action, db: &str, table_name: &str, case: &PreUpdateCase| { + assert_eq!(db, "main"); + + let action: RecordAction = match action { + Action::SQLITE_UPDATE | Action::SQLITE_INSERT | Action::SQLITE_DELETE => action.into(), + a => { + log::error!("Unknown action: {a:?}"); + return; + } + }; + + let Some(rowid) = extract_row_id(case) else { + log::error!("Failed to extract row id"); + return; + }; + + // If there are no subscriptions, do nothing. + if s + .subscriptions + .read() + .get(table_name) + .and_then(|m| m.get(&rowid)) + .is_none() + { + return; + } + + let Some(table_metadata) = s.table_metadata.get(table_name) else { + // TODO: Should we cleanup here? Probably, since we won't recover from this issue. + log::error!("Table not found: {table_name}"); + return; + }; + + let Some(record_values) = extract_record_values(case) else { + log::error!("Failed to extract values"); + return; + }; + + let state = ContinuationState { + state: s.clone(), + table_metadata, + action, + table_name: table_name.to_string(), + rowid, + record_values, + }; + + // TODO: Optimization: in cases where there's only table-level access restrictions, we + // could avoid the continuation and even dispatch the subscription handling to a + // different thread entirely to take more work off the SQLite thread. 
+ conn.call_and_forget(move |conn| { + Self::hook_continuation(conn, state); + }); + }, + )) + .await; + } + + async fn add_subscription( + &self, + api: RecordApi, + record: Option, + user: Option, + ) -> Result, RecordError> { + let Some(record) = record else { + return Err(RecordError::BadRequest("Missing record id")); + }; + let (sender, receiver) = async_channel::bounded::(16); + + let table_name = api.table_name(); + let pk_column = &api.record_pk_column().name; + + let Some(row) = self + .state + .conn + .query_row( + &format!(r#"SELECT _rowid_ FROM "{table_name}" WHERE "{pk_column}" = $1"#), + params!(record.clone()), + ) + .await? + else { + return Err(RecordError::RecordNotFound); + }; + let row_id: i64 = row + .get(0) + .map_err(|err| RecordError::Internal(err.into()))?; + + let subscription_id = SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst); + let empty = { + let mut lock = self.state.subscriptions.write(); + let empty = lock.is_empty(); + let m: &mut HashMap> = lock.entry(table_name.to_string()).or_default(); + + m.entry(row_id).or_default().push(Subscription { + subscription_id, + record_api_name: api.api_name().to_string(), + // record_id: Some(record), + user, + channel: sender, + }); + + empty + }; + + if empty { + self.add_hook().await.unwrap(); + } + + return Ok(receiver); + } + + // TODO: Cleaning up subscriptions might be a thing, e.g. if SSE handlers had an onDisconnect + // handler. Right now we're handling cleanups reactively, i.e. we only remove subscriptions when + // sending new events and the receiving end of a handler channel became invalid. It would + // be better to be pro-active and remove subscriptions sooner. + // + // async fn cleanup_subscription(&self, subscription_id: SubscriptionId) -> Result<(), + // RecordError> { let mut lock = self.state.subscriptions.write(); + // + // if let Some(table_subs) = lock.get_mut(&subscription_id.table_name) { + // if let Some(subs) = table_subs.get_mut(&subscription_id.row_id) { + // subs.retain(|s| s.id != subscription_id.subscription_id); + // + // if subs.is_empty() { + // table_subs.remove(&subscription_id.row_id); + // } + // } + // + // if table_subs.is_empty() { + // lock.remove(&subscription_id.table_name); + // } + // } + // + // if lock.is_empty() { + // Self::remove_preupdate_hook(&*self.state).await?; + // } + // + // return Ok(()); + // } +} + +pub async fn add_subscription_sse_handler( + State(state): State, + Path((api_name, record)): Path<(String, String)>, + user: Option, +) -> Result>, RecordError> { + let Some(api) = state.lookup_record_api(&api_name) else { + return Err(RecordError::ApiNotFound); + }; + + let record_id = api.id_to_sql(&record)?; + + let Ok(()) = api + .check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref()) + .await + else { + return Err(RecordError::Forbidden); + }; + + let receiver = state + .subscription_manager() + .add_subscription(api, Some(record_id), user) + .await?; + + return Ok( + Sse::new(receiver.map(|ev| { + // TODO: We're re-encoding the event over and over again for all subscriptions. Would be easy + // to pre-encode on the sender side but makes testing much harder, since there's no good way + // to parse sse::Event back. 
+ return Event::default().json_data(ev); + })) + .keep_alive(KeepAlive::default()), + ); +} + +#[cfg(test)] +mod tests { + use super::DbEvent; + use super::*; + use crate::app_state::test_state; + use crate::records::{add_record_api, AccessRules, Acls, PermissionFlag}; + + #[tokio::test] + async fn subscribe_connection_test() { + let state = test_state(None).await.unwrap(); + let conn = state.conn().clone(); + + conn + .execute( + "CREATE TABLE test (id INTEGER PRIMARY KEY, text TEXT) STRICT", + (), + ) + .await + .unwrap(); + + state.table_metadata().invalidate_all().await.unwrap(); + + // Register message table as record api with moderator read access. + add_record_api( + &state, + "api_name", + "test", + Acls { + world: vec![PermissionFlag::Create, PermissionFlag::Read], + ..Default::default() + }, + AccessRules { + // read: Some("(_ROW_._owner = _USER_.id OR EXISTS(SELECT 1 FROM room_members WHERE room = + // _ROW_.room AND user = _USER_.id))".to_string()), + ..Default::default() + }, + ) + .await + .unwrap(); + + let record_id_raw = 0; + let record_id = trailbase_sqlite::Value::Integer(record_id_raw); + let rowid: i64 = conn + .query_row( + "INSERT INTO test (id, text) VALUES ($1, 'foo') RETURNING _rowid_", + [record_id], + ) + .await + .unwrap() + .unwrap() + .get(0) + .unwrap(); + + assert_eq!(rowid, record_id_raw); + + let manager = state.subscription_manager(); + let api = state.lookup_record_api("api_name").unwrap(); + let receiver = manager + .add_subscription(api, Some(trailbase_sqlite::Value::Integer(0)), None) + .await + .unwrap(); + + assert_eq!(1, manager.num_subscriptions()); + + conn + .execute( + "UPDATE test SET text = $1 WHERE _rowid_ = $2", + params!("bar", rowid), + ) + .await + .unwrap(); + + let expected = serde_json::json!({ + "id": record_id_raw, + "text": "bar", + }); + match receiver.recv().await.unwrap() { + DbEvent::Update(Some(value)) => { + assert_eq!(value, expected); + } + x => { + assert!(false, "Expected update, got: {x:?}"); + } + }; + + conn + .execute("DELETE FROM test WHERE _rowid_ = $2", params!(rowid)) + .await + .unwrap(); + + match receiver.recv().await.unwrap() { + DbEvent::Delete(Some(value)) => { + assert_eq!(value, expected); + } + x => { + assert!(false, "Expected update, got: {x:?}"); + } + } + + assert_eq!(0, manager.num_subscriptions()); + } + + // TODO: Test actual SSE handler. 
+} + +const NO_HOOK: Option = None; diff --git a/trailbase-core/src/table_metadata.rs b/trailbase-core/src/table_metadata.rs index 3e291797..5009722e 100644 --- a/trailbase-core/src/table_metadata.rs +++ b/trailbase-core/src/table_metadata.rs @@ -73,7 +73,7 @@ pub struct TableMetadata { metadata: Vec, name_to_index: HashMap, - record_pk_column: Option, + pub record_pk_column: Option, pub user_id_columns: Vec, pub file_upload_columns: Vec, pub file_uploads_columns: Vec, diff --git a/trailbase-core/src/value_notifier.rs b/trailbase-core/src/value_notifier.rs index 8ae2a34f..3fa6762a 100644 --- a/trailbase-core/src/value_notifier.rs +++ b/trailbase-core/src/value_notifier.rs @@ -68,6 +68,7 @@ struct ComputedState { f: Box T>, } +#[derive(Clone)] pub struct Computed { state: Arc>, } diff --git a/trailbase-sqlite/src/connection.rs b/trailbase-sqlite/src/connection.rs index 4c025679..db9f43e5 100644 --- a/trailbase-sqlite/src/connection.rs +++ b/trailbase-sqlite/src/connection.rs @@ -1,5 +1,6 @@ use crossbeam_channel::{Receiver, Sender}; -use rusqlite::hooks::Action; +use rusqlite::hooks::{Action, PreUpdateCase}; +use rusqlite::types::Value; use std::{ fmt::{self, Debug}, sync::Arc, @@ -35,11 +36,9 @@ macro_rules! named_params { pub type Result = std::result::Result; type CallFn = Box; -type HookFn = Arc; enum Message { Run(CallFn), - ExecuteHook(HookFn, Action, String, String, i64), Close(oneshot::Sender>), } @@ -73,7 +72,7 @@ impl Connection { /// Will return `Err` if the database connection has been closed. pub async fn call(&self, function: F) -> Result where - F: FnOnce(&mut rusqlite::Connection) -> Result + 'static + Send, + F: FnOnce(&mut rusqlite::Connection) -> Result + Send + 'static, R: Send + 'static, { let (sender, receiver) = oneshot::channel::>(); @@ -89,6 +88,12 @@ impl Connection { receiver.await.map_err(|_| Error::ConnectionClosed)? } + pub fn call_and_forget(&self, function: impl FnOnce(&rusqlite::Connection) + Send + 'static) { + let _ = self + .sender + .send(Message::Run(Box::new(move |conn| function(conn)))); + } + /// Query SQL statement. pub async fn query(&self, sql: &str, params: impl Params + Send + 'static) -> Result { let sql = sql.to_string(); @@ -204,39 +209,12 @@ impl Connection { .await; } - pub async fn add_hook( + /// Convenience API for (un)setting a new pre-update hook. + pub async fn add_preupdate_hook( &self, - f: impl Fn(&rusqlite::Connection, Action, &str, &str, i64) + Send + Sync + 'static, + hook: Option, ) -> Result<()> { - let sender = self.sender.clone(); - let f = Arc::new(f); - - return self - .call(|conn| { - conn.update_hook(Some( - move |action: Action, db: &str, table: &str, row: i64| { - let _ = sender.send(Message::ExecuteHook( - f.clone(), - action, - db.to_string(), - table.to_string(), - row, - )); - }, - )); - - return Ok(()); - }) - .await; - } - - pub async fn remove_hook(&self) -> Result<()> { - return self - .call(|conn| { - conn.update_hook(None::); - return Ok(()); - }) - .await; + return self.call(|conn| Ok(conn.preupdate_hook(hook))).await; } /// Close the database connection. 
@@ -286,7 +264,6 @@ fn event_loop(mut conn: rusqlite::Connection, receiver: Receiver) { while let Ok(message) = receiver.recv() { match message { Message::Run(f) => f(&mut conn), - Message::ExecuteHook(f, action, db, table, row) => f(&conn, action, &db, &table, row), Message::Close(ch) => { match conn.close() { Ok(v) => ch.send(Ok(v)).expect(BUG_TEXT), @@ -299,6 +276,50 @@ fn event_loop(mut conn: rusqlite::Connection, receiver: Receiver) { } } +pub fn extract_row_id(case: &PreUpdateCase) -> Option { + return match case { + PreUpdateCase::Insert(accessor) => Some(accessor.get_new_row_id()), + PreUpdateCase::Delete(accessor) => Some(accessor.get_old_row_id()), + PreUpdateCase::Update { + new_value_accessor: accessor, + .. + } => Some(accessor.get_new_row_id()), + PreUpdateCase::Unknown => None, + }; +} + +pub fn extract_record_values(case: &PreUpdateCase) -> Option> { + return Some(match case { + PreUpdateCase::Insert(accessor) => (0..accessor.get_column_count()) + .map(|idx| -> Value { + accessor + .get_new_column_value(idx) + .map_or(rusqlite::types::Value::Null, |v| v.into()) + }) + .collect(), + PreUpdateCase::Delete(accessor) => (0..accessor.get_column_count()) + .map(|idx| -> rusqlite::types::Value { + accessor + .get_old_column_value(idx) + .map_or(rusqlite::types::Value::Null, |v| v.into()) + }) + .collect(), + PreUpdateCase::Update { + new_value_accessor: accessor, + .. + } => (0..accessor.get_column_count()) + .map(|idx| -> rusqlite::types::Value { + accessor + .get_new_column_value(idx) + .map_or(rusqlite::types::Value::Null, |v| v.into()) + }) + .collect(), + PreUpdateCase::Unknown => { + return None; + } + }); +} + #[cfg(test)] #[path = "tests.rs"] mod tests; diff --git a/trailbase-sqlite/src/params.rs b/trailbase-sqlite/src/params.rs index 1111724a..9228dd62 100644 --- a/trailbase-sqlite/src/params.rs +++ b/trailbase-sqlite/src/params.rs @@ -78,10 +78,9 @@ impl Params for () { impl Params for Vec<(String, types::Value)> { fn bind(self, stmt: &mut Statement<'_>) -> rusqlite::Result<()> { for (name, v) in self { - let Some(idx) = stmt.parameter_index(&name)? else { - continue; + if let Some(idx) = stmt.parameter_index(&name)? 
{ + stmt.raw_bind_parameter(idx, v)?; }; - stmt.raw_bind_parameter(idx, v)?; } return Ok(()); } diff --git a/trailbase-sqlite/src/tests.rs b/trailbase-sqlite/src/tests.rs index 5a91ace7..6738bc4f 100644 --- a/trailbase-sqlite/src/tests.rs +++ b/trailbase-sqlite/src/tests.rs @@ -1,6 +1,8 @@ use rusqlite::ffi; +use rusqlite::hooks::PreUpdateCase; use serde::Deserialize; +use crate::connection::extract_row_id; use crate::{named_params, params, Connection, Error, Value, ValueType}; use rusqlite::ErrorCode; @@ -322,22 +324,49 @@ async fn test_hooks() { .await .unwrap(); - let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel::(); - conn - .add_hook(move |c, action, _db, table, row_id| match action { - rusqlite::hooks::Action::SQLITE_INSERT => { - let text = c - .query_row( - &format!(r#"SELECT text FROM "{table}" WHERE _rowid_ = $1"#), - [row_id], - |row| row.get::<_, String>(0), - ) - .unwrap(); + struct State { + action: rusqlite::hooks::Action, + table_name: String, + row_id: i64, + } - sender.send(text).unwrap(); - } - _ => {} - }) + let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel::(); + let c = conn.clone(); + + conn + .add_preupdate_hook(Some( + move |action: rusqlite::hooks::Action, _db: &str, table_name: &str, case: &PreUpdateCase| { + let row_id = extract_row_id(case).unwrap(); + let state = State { + action, + table_name: table_name.to_string(), + row_id, + }; + + let sender = sender.clone(); + c.call_and_forget(move |conn| { + match state.action { + rusqlite::hooks::Action::SQLITE_INSERT => { + let text = conn + .query_row( + &format!( + r#"SELECT text FROM "{}" WHERE _rowid_ = $1"#, + state.table_name + ), + [state.row_id], + |row| row.get::<_, String>(0), + ) + .unwrap(); + + sender.send(text).unwrap(); + } + _ => { + panic!("unexpected action: {:?}", state.action); + } + }; + }); + }, + )) .await .unwrap();
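
For reference, the subscription endpoint added in subscribe.rs emits standard SSE frames ("data: <json>\n\n"), which is exactly the framing the Dart SeeInterceptor buffers on and that RecordApi.subscribe() strips before JSON-decoding. Below is a minimal, hypothetical Rust sketch (not part of this patch) of decoding one complete frame back into JSON; it assumes a serde_json dependency, and parse_sse_frame plus the example payload are illustrative only.

use serde_json;

// Decode a single, complete SSE frame (terminated by a blank line) into JSON.
fn parse_sse_frame(frame: &[u8]) -> Option<serde_json::Value> {
  // A complete frame ends with "\n\n"; drop the trailing newlines.
  let text = std::str::from_utf8(frame).ok()?.trim_end_matches('\n');
  // Strip the "data: " field prefix mandated by the SSE wire format.
  let payload = text.strip_prefix("data: ")?;
  return serde_json::from_str(payload).ok();
}

fn main() {
  // Example payload shaped like serde's externally tagged encoding of
  // DbEvent::Update(Some(...)); the field values here are made up.
  let frame = b"data: {\"Update\":{\"id\":0,\"text\":\"bar\"}}\n\n";
  let event = parse_sse_frame(frame).unwrap();
  assert!(event.get("Update").is_some());
}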