Add "realtime" subscriptions for a specific record, i.e. updates and deletion.

This commit is contained in:
Sebastian Jeltsch
2024-12-10 23:52:39 +01:00
parent 5706ff85fe
commit 30f295e6fd
16 changed files with 919 additions and 112 deletions

27
Cargo.lock generated
View File

@@ -715,6 +715,26 @@ dependencies = [
"serde",
]
[[package]]
name = "bindgen"
version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
"bitflags",
"cexpr",
"clang-sys",
"itertools 0.10.5",
"lazy_static",
"lazycell",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.95",
]
[[package]]
name = "bindgen"
version = "0.70.1"
@@ -3028,6 +3048,12 @@ dependencies = [
"spin",
]
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "lettre"
version = "0.11.11"
@@ -3097,6 +3123,7 @@ version = "0.30.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149"
dependencies = [
"bindgen 0.69.5",
"cc",
"pkg-config",
"vcpkg",

View File

@@ -48,6 +48,7 @@ rusqlite = { version = "^0.32.1", default-features = false, features = [
"limits",
"backup",
"hooks",
"preupdate_hook",
] }
trailbase-sqlean = { path = "vendor/sqlean", version = "^0.0.1" }
trailbase-extension = { path = "trailbase-extension", version = "^0.1.0" }

View File

@@ -5,6 +5,8 @@ import 'package:jwt_decoder/jwt_decoder.dart';
import 'package:logging/logging.dart';
import 'package:dio/dio.dart' as dio;
import 'sse.dart';
class User {
final String id;
final String email;
@@ -272,6 +274,22 @@ class RecordApi {
);
}
Future<Stream<Map<String, dynamic>>> subscribe(RecordId id) async {
final resp = await _client.fetch(
'${RecordApi._recordApi}/${_name}/subscribe/${id}',
responseType: dio.ResponseType.stream,
);
final Stream<Uint8List> stream = resp.data.stream;
return stream.asyncMap((Uint8List bytes) {
final decoded = utf8.decode(bytes);
if (decoded.startsWith('data: ')) {
return jsonDecode(decoded.substring(6));
}
return jsonDecode(decoded);
});
}
Uri imageUri(RecordId id, String colName, {int? index}) {
if (index != null) {
return Uri.parse(
@@ -283,7 +301,7 @@ class RecordApi {
}
class _ThinClient {
static final _dio = dio.Dio();
static final _dio = dio.Dio()..interceptors.add(SeeInterceptor());
final String site;
@@ -295,6 +313,7 @@ class _ThinClient {
Object? data,
String? method,
Map<String, dynamic>? queryParams,
dio.ResponseType? responseType,
}) async {
if (path.startsWith('/')) {
throw Exception('Path starts with "/". Relative path expected.');
@@ -308,6 +327,7 @@ class _ThinClient {
method: method,
headers: tokenState.headers,
validateStatus: (int? status) => true,
responseType: responseType,
),
);
@@ -508,6 +528,7 @@ class Client {
Object? data,
String? method,
Map<String, dynamic>? queryParams,
dio.ResponseType? responseType,
}) async {
var tokenState = _tokenState;
final refreshToken = _shouldRefresh(tokenState);
@@ -515,8 +536,14 @@ class Client {
tokenState = _tokenState = await _refreshTokensImpl(refreshToken);
}
final response = await _client.fetch(path, tokenState,
data: data, method: method, queryParams: queryParams);
final response = await _client.fetch(
path,
tokenState,
data: data,
method: method,
queryParams: queryParams,
responseType: responseType,
);
if (response.statusCode != 200 && (throwOnError ?? true)) {
final errMsg = await response.data;

View File

@@ -0,0 +1,50 @@
import 'dart:async';
import 'dart:typed_data';
import 'package:dio/dio.dart' as dio;
class SeeInterceptor extends dio.Interceptor {
@override
void onResponse(
dio.Response response, dio.ResponseInterceptorHandler handler) {
if (response.requestOptions.responseType == dio.ResponseType.stream) {
final Stream<Uint8List> stream = response.data.stream;
final buffer = BytesBuilder();
final transformedStream = stream.transform<Uint8List>(
StreamTransformer.fromHandlers(
handleData: (Uint8List data, EventSink<Uint8List> sink) {
// If terminated correctly (\n\n) write to sink, otherwise buffer.
if (endsWithNewlineNewline(data)) {
if (buffer.isNotEmpty) {
buffer.add(data);
sink.add(buffer.takeBytes());
} else {
sink.add(data);
}
} else {
buffer.add(data);
}
},
),
);
return handler.resolve(dio.Response(
requestOptions: response.requestOptions,
data: dio.ResponseBody(transformedStream, response.data.contentLength),
statusCode: response.statusCode,
headers: response.headers,
));
}
handler.next(response);
}
bool endsWithNewlineNewline(List<int> bytes) {
if (bytes.length < 2) {
return false;
}
return bytes[bytes.length - 1] == 10 && bytes[bytes.length - 2] == 10;
}
}

View File

@@ -2,3 +2,4 @@ library;
export 'src/client.dart';
export 'src/pkce.dart';
export 'src/sse.dart';

View File

@@ -57,14 +57,19 @@ Future<Process> initTrailBase() async {
print('Trying to connect to TrailBase');
}
await Future.delayed(Duration(milliseconds: 500));
if (await process.exitCode
.timeout(Duration(milliseconds: 500), onTimeout: () => -1) >=
0) {
break;
}
}
process.kill(ProcessSignal.sigkill);
final exitCode = await process.exitCode;
await process.stdout.forEach(print);
await process.stderr.forEach(print);
await process.stderr.forEach(stdout.add);
await process.stdout.forEach(stdout.add);
throw Exception('Cargo run failed: ${exitCode}.');
}
@@ -73,7 +78,15 @@ Future<void> main() async {
throw Exception('Unexpected working directory');
}
await initTrailBase();
final process = await initTrailBase();
tearDownAll(() async {
process.kill(ProcessSignal.sigkill);
final _ = await process.exitCode;
// await process.stderr.forEach(stdout.add);
// await process.stdout.forEach(stdout.add);
});
group('client tests', () {
test('auth', () async {
@@ -159,5 +172,27 @@ Future<void> main() async {
await api.delete(ids[0]);
expect(() async => await api.read(ids[0]), throwsException);
});
test('realtime', () async {
final client = await connect();
final api = client.records('simple_strict_table');
final int now = DateTime.now().millisecondsSinceEpoch ~/ 1000;
final id = await api
.create({'text_not_null': 'dart client realtime test 0: =?&${now}'});
final events = await api.subscribe(id);
final updatedMessage = 'dart client updated realtime test 0: ${now}';
await api.update(id, {'text_not_null': updatedMessage});
await api.delete(id);
final eventList =
await events.timeout(Duration(seconds: 10), onTimeout: (sink) {
print('Stream timeout');
sink.close();
}).toList();
expect(eventList.length, equals(2));
});
});
}

View File

@@ -11,6 +11,7 @@ use crate::constants::SITE_URL_DEFAULT;
use crate::data_dir::DataDir;
use crate::email::Mailer;
use crate::js::RuntimeHandle;
use crate::records::subscribe::SubscriptionManager;
use crate::records::RecordApi;
use crate::table_metadata::TableMetadataCache;
use crate::value_notifier::{Computed, ValueNotifier};
@@ -33,6 +34,7 @@ struct InternalState {
jwt: JwtHelper,
table_metadata: TableMetadataCache,
subscription_manager: SubscriptionManager,
object_store: Box<dyn ObjectStore + Send + Sync>,
runtime: RuntimeHandle,
@@ -67,6 +69,22 @@ impl AppState {
let table_metadata_clone = args.table_metadata.clone();
let conn_clone = args.conn.clone();
let record_apis = Computed::new(&config, move |c| {
return c
.record_apis
.iter()
.filter_map(|config| {
match build_record_api(conn_clone.clone(), &table_metadata_clone, config.clone()) {
Ok(api) => Some((api.api_name().to_string(), api)),
Err(err) => {
error!("{err}");
None
}
}
})
.collect::<Vec<_>>();
});
let runtime = args
.js_runtime_threads
.map_or_else(RuntimeHandle::new, RuntimeHandle::new_with_threads);
@@ -87,26 +105,13 @@ impl AppState {
}
}),
mailer: build_mailer(&config, None),
record_apis: Computed::new(&config, move |c| {
return c
.record_apis
.iter()
.filter_map(|config| {
match build_record_api(conn_clone.clone(), &table_metadata_clone, config.clone()) {
Ok(api) => Some((api.api_name().to_string(), api)),
Err(err) => {
error!("{err}");
None
}
}
})
.collect::<Vec<_>>();
}),
record_apis: record_apis.clone(),
config,
conn: args.conn.clone(),
logs_conn: args.logs_conn,
jwt: args.jwt,
table_metadata: args.table_metadata,
table_metadata: args.table_metadata.clone(),
subscription_manager: SubscriptionManager::new(args.conn, args.table_metadata, record_apis),
object_store: args.object_store,
runtime,
#[cfg(test)]
@@ -145,6 +150,10 @@ impl AppState {
return &self.state.table_metadata;
}
pub(crate) fn subscription_manager(&self) -> &SubscriptionManager {
return &self.state.subscription_manager;
}
pub async fn refresh_table_cache(&self) -> Result<(), crate::table_metadata::TableLookupError> {
self.table_metadata().invalidate_all().await
}
@@ -360,6 +369,23 @@ pub async fn test_state(options: Option<TestStateOptions>) -> anyhow::Result<App
build_objectstore(&data_dir, None).unwrap()
};
let record_apis = Computed::new(&config, move |c| {
return c
.record_apis
.iter()
.filter_map(|config| {
let api = build_record_api(
main_conn_clone.clone(),
&table_metadata_clone,
config.clone(),
)
.unwrap();
return Some((api.api_name().to_string(), api));
})
.collect::<Vec<_>>();
});
let runtime = RuntimeHandle::new();
runtime.set_connection(conn.clone());
@@ -372,27 +398,13 @@ pub async fn test_state(options: Option<TestStateOptions>) -> anyhow::Result<App
ConfiguredOAuthProviders::from_config(c.auth.clone()).unwrap()
}),
mailer: build_mailer(&config, options.and_then(|o| o.mailer)),
record_apis: Computed::new(&config, move |c| {
return c
.record_apis
.iter()
.filter_map(|config| {
let api = build_record_api(
main_conn_clone.clone(),
&table_metadata_clone,
config.clone(),
)
.unwrap();
return Some((api.api_name().to_string(), api));
})
.collect::<Vec<_>>();
}),
record_apis: record_apis.clone(),
config,
conn,
conn: conn.clone(),
logs_conn,
jwt: jwt::test_jwt_helper(),
table_metadata,
table_metadata: table_metadata.clone(),
subscription_manager: SubscriptionManager::new(conn, table_metadata, record_apis),
object_store,
runtime,
cleanup: vec![Box::new(temp_dir)],

View File

@@ -14,6 +14,7 @@ mod list_records;
pub(crate) mod read_record;
mod record_api;
pub mod sql_to_json;
pub(crate) mod subscribe;
pub mod test_utils;
mod update_record;
mod validate;
@@ -76,6 +77,10 @@ pub(crate) fn router() -> Router<AppState> {
.route(
&format!("/{RECORD_API_PATH}/{{name}}/schema"),
get(json_schema::json_schema_handler),
)
.route(
&format!("/{RECORD_API_PATH}/{{name}}/subscribe/{{record}}"),
get(subscribe::add_subscription_sse_handler),
);
}

View File

@@ -339,23 +339,23 @@ impl RecordApi {
/// Check if the given user (if any) can access a record given the request and the operation.
#[allow(unused)]
pub fn check_record_level_access_sync(
pub fn check_record_level_read_access(
&self,
conn: &mut rusqlite::Connection,
conn: &rusqlite::Connection,
p: Permission,
record_id: Option<&Value>,
request_params: Option<&mut LazyParams<'_>>,
record: Vec<(&str, rusqlite::types::Value)>,
user: Option<&User>,
) -> Result<(), RecordError> {
// First check table level access and if present check row-level access based on access rule.
self.check_table_level_access(p, user)?;
let Some(access_query) = self.access_query(p) else {
let Some(access_rule) = self.access_rule(p) else {
return Ok(());
};
let params = self.build_named_params(p, record_id, request_params, user)?;
match Self::query_access(conn, access_query, params) {
let (query, params) = build_query_and_params_for_record_read(access_rule, user, record);
match Self::query_access(conn, &query, params) {
Ok(allowed) => {
if allowed {
return Ok(());
@@ -389,7 +389,7 @@ impl RecordApi {
#[inline]
fn query_access(
conn: &mut rusqlite::Connection,
conn: &rusqlite::Connection,
access_query: &str,
params: NamedParams,
) -> Result<bool, trailbase_sqlite::Error> {
@@ -453,6 +453,45 @@ impl RecordApi {
}
}
fn build_query_and_params_for_record_read(
access_rule: &str,
user: Option<&User>,
record: Vec<(&str, rusqlite::types::Value)>,
) -> (String, NamedParams) {
let row = record
.iter()
.map(|(name, _value)| {
return format!(":__v_{name} AS {name}");
})
.collect::<Vec<_>>()
.join(", ");
let mut params: Vec<_> = record
.into_iter()
.map(|(name, value)| {
return (Cow::Owned(format!(":__v_{name}")), value);
})
.collect();
params.push((
Cow::Borrowed(":__user_id"),
user.map_or(Value::Null, |u| Value::Blob(u.uuid.into())),
));
// Assumes access_rule is an expression: https://www.sqlite.org/syntax/expr.html
let query = indoc::formatdoc!(
r#"
SELECT
({access_rule})
FROM
(SELECT :__user_id AS id) AS _USER_,
(SELECT {row}) AS _ROW_
"#
);
return (query, params);
}
/// Build access query for record reads, deletes and query access.
///
/// Assumes access_rule is an expression: https://www.sqlite.org/syntax/expr.html

View File

@@ -19,20 +19,20 @@ pub enum JsonError {
ValueNotFound,
}
fn value_to_json(value: rusqlite::types::Value) -> Result<serde_json::Value, JsonError> {
pub(crate) fn value_to_json(value: rusqlite::types::Value) -> Result<serde_json::Value, JsonError> {
use rusqlite::types::Value;
return Ok(match value {
rusqlite::types::Value::Null => serde_json::Value::Null,
rusqlite::types::Value::Real(real) => {
Value::Null => serde_json::Value::Null,
Value::Real(real) => {
let Some(number) = serde_json::Number::from_f64(real) else {
return Err(JsonError::Finite);
};
serde_json::Value::Number(number)
}
rusqlite::types::Value::Integer(integer) => {
serde_json::Value::Number(serde_json::Number::from(integer))
}
rusqlite::types::Value::Blob(blob) => serde_json::Value::String(BASE64_URL_SAFE.encode(blob)),
rusqlite::types::Value::Text(text) => serde_json::Value::String(text),
Value::Integer(integer) => serde_json::Value::Number(serde_json::Number::from(integer)),
Value::Blob(blob) => serde_json::Value::String(BASE64_URL_SAFE.encode(blob)),
Value::Text(text) => serde_json::Value::String(text),
});
}

View File

@@ -0,0 +1,560 @@
use axum::{
extract::{Path, State},
response::sse::{Event, KeepAlive, Sse},
};
use futures::stream::{Stream, StreamExt};
use parking_lot::RwLock;
use rusqlite::hooks::{Action, PreUpdateCase};
use serde::Serialize;
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicI64, Ordering},
Arc,
};
use trailbase_sqlite::{
connection::{extract_record_values, extract_row_id},
params,
};
use crate::auth::user::User;
use crate::records::sql_to_json::value_to_json;
use crate::records::RecordApi;
use crate::records::{Permission, RecordError};
use crate::table_metadata::{TableMetadata, TableMetadataCache};
use crate::value_notifier::Computed;
use crate::AppState;
static SUBSCRIPTION_COUNTER: AtomicI64 = AtomicI64::new(0);
// TODO:
// * clients
// * table-wide subscriptions
// * optimize: avoid repeated encoding of events. Easy to do but makes testing harder since there's
// no good way to parse sse::Event back :/. We should probably just bite the bullet and parse,
// it's literally "data: <json>\n\n".
type SseEvent = Result<axum::response::sse::Event, axum::Error>;
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize)]
pub enum RecordAction {
Delete,
Insert,
Update,
}
impl From<Action> for RecordAction {
fn from(value: Action) -> Self {
return match value {
Action::SQLITE_DELETE => RecordAction::Delete,
Action::SQLITE_INSERT => RecordAction::Insert,
Action::SQLITE_UPDATE => RecordAction::Update,
_ => unreachable!("{value:?}"),
};
}
}
#[derive(Debug, Clone, Serialize)]
pub enum DbEvent {
Update(Option<serde_json::Value>),
Insert(Option<serde_json::Value>),
Delete(Option<serde_json::Value>),
Error(String),
}
// pub struct SubscriptionId {
// table_name: String,
// row_id: i64,
// subscription_id: i64,
// }
pub struct Subscription {
/// Id uniquely identifying this subscription.
subscription_id: i64,
/// Name of the API this subscription is subscribed to. We need to lookup the Record API on the
/// hot path to make sure we're getting the latest configuration.
record_api_name: String,
/// Record id present for subscriptions to specific records.
// record_id: Option<trailbase_sqlite::Value>,
user: Option<User>,
/// Channel for sending events to the SSE handler.
channel: async_channel::Sender<DbEvent>,
}
/// Internal, shareable state of the cloneable SubscriptionManager.
struct ManagerState {
/// SQLite connection to monitor.
conn: trailbase_sqlite::Connection,
/// Table metadata for mapping column indexes to column names needed for building JSON encoded
/// records.
table_metadata: TableMetadataCache,
/// Record API configurations.
record_apis: Computed<Vec<(String, RecordApi)>, crate::config::proto::Config>,
/// Map from table name to row id to list of subscriptions.
subscriptions: RwLock<HashMap<String, HashMap<i64, Vec<Subscription>>>>,
}
impl ManagerState {
fn lookup_record_api(&self, name: &str) -> Option<RecordApi> {
for (record_api_name, record_api) in self.record_apis.load().iter() {
if record_api_name == name {
return Some(record_api.clone());
}
}
return None;
}
}
#[derive(Clone)]
pub struct SubscriptionManager {
state: Arc<ManagerState>,
}
struct ContinuationState {
state: Arc<ManagerState>,
table_metadata: Arc<TableMetadata>,
action: RecordAction,
table_name: String,
rowid: i64,
record_values: Vec<rusqlite::types::Value>,
}
impl SubscriptionManager {
pub fn new(
conn: trailbase_sqlite::Connection,
table_metadata: TableMetadataCache,
record_apis: Computed<Vec<(String, RecordApi)>, crate::config::proto::Config>,
) -> Self {
return Self {
state: Arc::new(ManagerState {
conn,
table_metadata,
record_apis,
subscriptions: RwLock::new(HashMap::new()),
}),
};
}
pub fn num_subscriptions(&self) -> usize {
let mut count: usize = 0;
for table in self.state.subscriptions.read().values() {
for record in table.values() {
count += record.len();
}
}
return count;
}
/// Preupdate hook that runs in a continuation of the trailbase-sqlite executor.
fn hook_continuation(conn: &rusqlite::Connection, state: ContinuationState) {
let ContinuationState {
state,
table_metadata,
table_name,
action,
rowid,
record_values,
} = state;
let s = &state;
let mut read_lock = s.subscriptions.upgradable_read();
let Some(subs) = read_lock.get(&table_name).and_then(|m| m.get(&rowid)) else {
return;
};
// Join values with column names.
let record: Vec<_> = record_values
.into_iter()
.enumerate()
.map(|(idx, v)| (table_metadata.schema.columns[idx].name.as_str(), v))
.collect();
// Build a JSON-encoded SQLite event (insert, update, delete).
let event = {
let json_value = serde_json::Value::Object(
record
.iter()
.filter_map(|(name, value)| {
if let Ok(v) = value_to_json(value.clone()) {
return Some(((*name).to_string(), v));
};
return None;
})
.collect(),
);
match action {
RecordAction::Delete => DbEvent::Delete(Some(json_value)),
RecordAction::Insert => DbEvent::Insert(Some(json_value)),
RecordAction::Update => DbEvent::Update(Some(json_value)),
}
};
let mut dead_subscriptions: Vec<usize> = vec![];
for (idx, sub) in subs.iter().enumerate() {
let Some(api) = s.lookup_record_api(&sub.record_api_name) else {
dead_subscriptions.push(idx);
continue;
};
if let Err(_err) = api.check_record_level_read_access(
conn,
Permission::Read,
// TODO: Maybe we could inject ValueRef instead to avoid repeated cloning.
record.clone(),
sub.user.as_ref(),
) {
// This can happen if the record api configuration has changed since originally
// subscribed. In this case we just send and error and cancel the subscription.
let _ = sub.channel.try_send(DbEvent::Error("Access denied".into()));
dead_subscriptions.push(idx);
continue;
}
// TODO: Avoid cloning the event/record over and over.
match sub.channel.try_send(event.clone()) {
Ok(_) => {}
Err(async_channel::TrySendError::Full(ev)) => {
log::warn!("Channel full, dropping event: {ev:?}");
}
Err(async_channel::TrySendError::Closed(_ev)) => {
dead_subscriptions.push(idx);
}
}
}
if dead_subscriptions.is_empty() && action != RecordAction::Delete {
// No cleanup needed.
return;
}
read_lock.with_upgraded(move |subscriptions| {
let Some(table_subscriptions) = subscriptions.get_mut(&table_name) else {
return;
};
if action == RecordAction::Delete {
// Also drops the channel and thus automatically closes the SSE connection.
table_subscriptions.remove(&rowid);
if table_subscriptions.is_empty() {
subscriptions.remove(&table_name);
if subscriptions.is_empty() {
conn.preupdate_hook(NO_HOOK);
}
}
return;
}
if let Some(m) = table_subscriptions.get_mut(&rowid) {
for idx in dead_subscriptions.iter().rev() {
m.swap_remove(*idx);
}
if m.is_empty() {
table_subscriptions.remove(&rowid);
if table_subscriptions.is_empty() {
subscriptions.remove(&table_name);
if subscriptions.is_empty() {
conn.preupdate_hook(NO_HOOK);
}
}
}
}
});
}
async fn add_hook(&self) -> trailbase_sqlite::connection::Result<()> {
let state = &self.state;
let conn = state.conn.clone();
let s = state.clone();
return state
.conn
.add_preupdate_hook(Some(
move |action: Action, db: &str, table_name: &str, case: &PreUpdateCase| {
assert_eq!(db, "main");
let action: RecordAction = match action {
Action::SQLITE_UPDATE | Action::SQLITE_INSERT | Action::SQLITE_DELETE => action.into(),
a => {
log::error!("Unknown action: {a:?}");
return;
}
};
let Some(rowid) = extract_row_id(case) else {
log::error!("Failed to extract row id");
return;
};
// If there are no subscriptions, do nothing.
if s
.subscriptions
.read()
.get(table_name)
.and_then(|m| m.get(&rowid))
.is_none()
{
return;
}
let Some(table_metadata) = s.table_metadata.get(table_name) else {
// TODO: Should we cleanup here? Probably, since we won't recover from this issue.
log::error!("Table not found: {table_name}");
return;
};
let Some(record_values) = extract_record_values(case) else {
log::error!("Failed to extract values");
return;
};
let state = ContinuationState {
state: s.clone(),
table_metadata,
action,
table_name: table_name.to_string(),
rowid,
record_values,
};
// TODO: Optimization: in cases where there's only table-level access restrictions, we
// could avoid the continuation and even dispatch the subscription handling to a
// different thread entirely to take more work off the SQLite thread.
conn.call_and_forget(move |conn| {
Self::hook_continuation(conn, state);
});
},
))
.await;
}
async fn add_subscription(
&self,
api: RecordApi,
record: Option<trailbase_sqlite::Value>,
user: Option<User>,
) -> Result<async_channel::Receiver<DbEvent>, RecordError> {
let Some(record) = record else {
return Err(RecordError::BadRequest("Missing record id"));
};
let (sender, receiver) = async_channel::bounded::<DbEvent>(16);
let table_name = api.table_name();
let pk_column = &api.record_pk_column().name;
let Some(row) = self
.state
.conn
.query_row(
&format!(r#"SELECT _rowid_ FROM "{table_name}" WHERE "{pk_column}" = $1"#),
params!(record.clone()),
)
.await?
else {
return Err(RecordError::RecordNotFound);
};
let row_id: i64 = row
.get(0)
.map_err(|err| RecordError::Internal(err.into()))?;
let subscription_id = SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst);
let empty = {
let mut lock = self.state.subscriptions.write();
let empty = lock.is_empty();
let m: &mut HashMap<i64, Vec<Subscription>> = lock.entry(table_name.to_string()).or_default();
m.entry(row_id).or_default().push(Subscription {
subscription_id,
record_api_name: api.api_name().to_string(),
// record_id: Some(record),
user,
channel: sender,
});
empty
};
if empty {
self.add_hook().await.unwrap();
}
return Ok(receiver);
}
// TODO: Cleaning up subscriptions might be a thing, e.g. if SSE handlers had an onDisconnect
// handler. Right now we're handling cleanups reactively, i.e. we only remove subscriptions when
// sending new events and the receiving end of a handler channel became invalid. It would
// be better to be pro-active and remove subscriptions sooner.
//
// async fn cleanup_subscription(&self, subscription_id: SubscriptionId) -> Result<(),
// RecordError> { let mut lock = self.state.subscriptions.write();
//
// if let Some(table_subs) = lock.get_mut(&subscription_id.table_name) {
// if let Some(subs) = table_subs.get_mut(&subscription_id.row_id) {
// subs.retain(|s| s.id != subscription_id.subscription_id);
//
// if subs.is_empty() {
// table_subs.remove(&subscription_id.row_id);
// }
// }
//
// if table_subs.is_empty() {
// lock.remove(&subscription_id.table_name);
// }
// }
//
// if lock.is_empty() {
// Self::remove_preupdate_hook(&*self.state).await?;
// }
//
// return Ok(());
// }
}
pub async fn add_subscription_sse_handler(
State(state): State<AppState>,
Path((api_name, record)): Path<(String, String)>,
user: Option<User>,
) -> Result<Sse<impl Stream<Item = SseEvent>>, RecordError> {
let Some(api) = state.lookup_record_api(&api_name) else {
return Err(RecordError::ApiNotFound);
};
let record_id = api.id_to_sql(&record)?;
let Ok(()) = api
.check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref())
.await
else {
return Err(RecordError::Forbidden);
};
let receiver = state
.subscription_manager()
.add_subscription(api, Some(record_id), user)
.await?;
return Ok(
Sse::new(receiver.map(|ev| {
// TODO: We're re-encoding the event over and over again for all subscriptions. Would be easy
// to pre-encode on the sender side but makes testing much harder, since there's no good way
// to parse sse::Event back.
return Event::default().json_data(ev);
}))
.keep_alive(KeepAlive::default()),
);
}
#[cfg(test)]
mod tests {
use super::DbEvent;
use super::*;
use crate::app_state::test_state;
use crate::records::{add_record_api, AccessRules, Acls, PermissionFlag};
#[tokio::test]
async fn subscribe_connection_test() {
let state = test_state(None).await.unwrap();
let conn = state.conn().clone();
conn
.execute(
"CREATE TABLE test (id INTEGER PRIMARY KEY, text TEXT) STRICT",
(),
)
.await
.unwrap();
state.table_metadata().invalidate_all().await.unwrap();
// Register message table as record api with moderator read access.
add_record_api(
&state,
"api_name",
"test",
Acls {
world: vec![PermissionFlag::Create, PermissionFlag::Read],
..Default::default()
},
AccessRules {
// read: Some("(_ROW_._owner = _USER_.id OR EXISTS(SELECT 1 FROM room_members WHERE room =
// _ROW_.room AND user = _USER_.id))".to_string()),
..Default::default()
},
)
.await
.unwrap();
let record_id_raw = 0;
let record_id = trailbase_sqlite::Value::Integer(record_id_raw);
let rowid: i64 = conn
.query_row(
"INSERT INTO test (id, text) VALUES ($1, 'foo') RETURNING _rowid_",
[record_id],
)
.await
.unwrap()
.unwrap()
.get(0)
.unwrap();
assert_eq!(rowid, record_id_raw);
let manager = state.subscription_manager();
let api = state.lookup_record_api("api_name").unwrap();
let receiver = manager
.add_subscription(api, Some(trailbase_sqlite::Value::Integer(0)), None)
.await
.unwrap();
assert_eq!(1, manager.num_subscriptions());
conn
.execute(
"UPDATE test SET text = $1 WHERE _rowid_ = $2",
params!("bar", rowid),
)
.await
.unwrap();
let expected = serde_json::json!({
"id": record_id_raw,
"text": "bar",
});
match receiver.recv().await.unwrap() {
DbEvent::Update(Some(value)) => {
assert_eq!(value, expected);
}
x => {
assert!(false, "Expected update, got: {x:?}");
}
};
conn
.execute("DELETE FROM test WHERE _rowid_ = $2", params!(rowid))
.await
.unwrap();
match receiver.recv().await.unwrap() {
DbEvent::Delete(Some(value)) => {
assert_eq!(value, expected);
}
x => {
assert!(false, "Expected update, got: {x:?}");
}
}
assert_eq!(0, manager.num_subscriptions());
}
// TODO: Test actual SSE handler.
}
const NO_HOOK: Option<fn(Action, &str, &str, &PreUpdateCase)> = None;

View File

@@ -73,7 +73,7 @@ pub struct TableMetadata {
metadata: Vec<ColumnMetadata>,
name_to_index: HashMap<String, usize>,
record_pk_column: Option<usize>,
pub record_pk_column: Option<usize>,
pub user_id_columns: Vec<usize>,
pub file_upload_columns: Vec<usize>,
pub file_uploads_columns: Vec<usize>,

View File

@@ -68,6 +68,7 @@ struct ComputedState<T, V> {
f: Box<dyn Sync + Send + Fn(&V) -> T>,
}
#[derive(Clone)]
pub struct Computed<T, V> {
state: Arc<ComputedState<T, V>>,
}

View File

@@ -1,5 +1,6 @@
use crossbeam_channel::{Receiver, Sender};
use rusqlite::hooks::Action;
use rusqlite::hooks::{Action, PreUpdateCase};
use rusqlite::types::Value;
use std::{
fmt::{self, Debug},
sync::Arc,
@@ -35,11 +36,9 @@ macro_rules! named_params {
pub type Result<T> = std::result::Result<T, Error>;
type CallFn = Box<dyn FnOnce(&mut rusqlite::Connection) + Send + 'static>;
type HookFn = Arc<dyn Fn(&rusqlite::Connection, Action, &str, &str, i64) + Send + Sync + 'static>;
enum Message {
Run(CallFn),
ExecuteHook(HookFn, Action, String, String, i64),
Close(oneshot::Sender<std::result::Result<(), rusqlite::Error>>),
}
@@ -73,7 +72,7 @@ impl Connection {
/// Will return `Err` if the database connection has been closed.
pub async fn call<F, R>(&self, function: F) -> Result<R>
where
F: FnOnce(&mut rusqlite::Connection) -> Result<R> + 'static + Send,
F: FnOnce(&mut rusqlite::Connection) -> Result<R> + Send + 'static,
R: Send + 'static,
{
let (sender, receiver) = oneshot::channel::<Result<R>>();
@@ -89,6 +88,12 @@ impl Connection {
receiver.await.map_err(|_| Error::ConnectionClosed)?
}
pub fn call_and_forget(&self, function: impl FnOnce(&rusqlite::Connection) + Send + 'static) {
let _ = self
.sender
.send(Message::Run(Box::new(move |conn| function(conn))));
}
/// Query SQL statement.
pub async fn query(&self, sql: &str, params: impl Params + Send + 'static) -> Result<Rows> {
let sql = sql.to_string();
@@ -204,39 +209,12 @@ impl Connection {
.await;
}
pub async fn add_hook(
/// Convenience API for (un)setting a new pre-update hook.
pub async fn add_preupdate_hook(
&self,
f: impl Fn(&rusqlite::Connection, Action, &str, &str, i64) + Send + Sync + 'static,
hook: Option<impl (Fn(Action, &str, &str, &PreUpdateCase)) + Send + Sync + 'static>,
) -> Result<()> {
let sender = self.sender.clone();
let f = Arc::new(f);
return self
.call(|conn| {
conn.update_hook(Some(
move |action: Action, db: &str, table: &str, row: i64| {
let _ = sender.send(Message::ExecuteHook(
f.clone(),
action,
db.to_string(),
table.to_string(),
row,
));
},
));
return Ok(());
})
.await;
}
pub async fn remove_hook(&self) -> Result<()> {
return self
.call(|conn| {
conn.update_hook(None::<fn(Action, &str, &str, i64)>);
return Ok(());
})
.await;
return self.call(|conn| Ok(conn.preupdate_hook(hook))).await;
}
/// Close the database connection.
@@ -286,7 +264,6 @@ fn event_loop(mut conn: rusqlite::Connection, receiver: Receiver<Message>) {
while let Ok(message) = receiver.recv() {
match message {
Message::Run(f) => f(&mut conn),
Message::ExecuteHook(f, action, db, table, row) => f(&conn, action, &db, &table, row),
Message::Close(ch) => {
match conn.close() {
Ok(v) => ch.send(Ok(v)).expect(BUG_TEXT),
@@ -299,6 +276,50 @@ fn event_loop(mut conn: rusqlite::Connection, receiver: Receiver<Message>) {
}
}
pub fn extract_row_id(case: &PreUpdateCase) -> Option<i64> {
return match case {
PreUpdateCase::Insert(accessor) => Some(accessor.get_new_row_id()),
PreUpdateCase::Delete(accessor) => Some(accessor.get_old_row_id()),
PreUpdateCase::Update {
new_value_accessor: accessor,
..
} => Some(accessor.get_new_row_id()),
PreUpdateCase::Unknown => None,
};
}
pub fn extract_record_values(case: &PreUpdateCase) -> Option<Vec<Value>> {
return Some(match case {
PreUpdateCase::Insert(accessor) => (0..accessor.get_column_count())
.map(|idx| -> Value {
accessor
.get_new_column_value(idx)
.map_or(rusqlite::types::Value::Null, |v| v.into())
})
.collect(),
PreUpdateCase::Delete(accessor) => (0..accessor.get_column_count())
.map(|idx| -> rusqlite::types::Value {
accessor
.get_old_column_value(idx)
.map_or(rusqlite::types::Value::Null, |v| v.into())
})
.collect(),
PreUpdateCase::Update {
new_value_accessor: accessor,
..
} => (0..accessor.get_column_count())
.map(|idx| -> rusqlite::types::Value {
accessor
.get_new_column_value(idx)
.map_or(rusqlite::types::Value::Null, |v| v.into())
})
.collect(),
PreUpdateCase::Unknown => {
return None;
}
});
}
#[cfg(test)]
#[path = "tests.rs"]
mod tests;

View File

@@ -78,10 +78,9 @@ impl Params for () {
impl Params for Vec<(String, types::Value)> {
fn bind(self, stmt: &mut Statement<'_>) -> rusqlite::Result<()> {
for (name, v) in self {
let Some(idx) = stmt.parameter_index(&name)? else {
continue;
if let Some(idx) = stmt.parameter_index(&name)? {
stmt.raw_bind_parameter(idx, v)?;
};
stmt.raw_bind_parameter(idx, v)?;
}
return Ok(());
}

View File

@@ -1,6 +1,8 @@
use rusqlite::ffi;
use rusqlite::hooks::PreUpdateCase;
use serde::Deserialize;
use crate::connection::extract_row_id;
use crate::{named_params, params, Connection, Error, Value, ValueType};
use rusqlite::ErrorCode;
@@ -322,22 +324,49 @@ async fn test_hooks() {
.await
.unwrap();
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel::<String>();
conn
.add_hook(move |c, action, _db, table, row_id| match action {
rusqlite::hooks::Action::SQLITE_INSERT => {
let text = c
.query_row(
&format!(r#"SELECT text FROM "{table}" WHERE _rowid_ = $1"#),
[row_id],
|row| row.get::<_, String>(0),
)
.unwrap();
struct State {
action: rusqlite::hooks::Action,
table_name: String,
row_id: i64,
}
sender.send(text).unwrap();
}
_ => {}
})
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel::<String>();
let c = conn.clone();
conn
.add_preupdate_hook(Some(
move |action: rusqlite::hooks::Action, _db: &str, table_name: &str, case: &PreUpdateCase| {
let row_id = extract_row_id(case).unwrap();
let state = State {
action,
table_name: table_name.to_string(),
row_id,
};
let sender = sender.clone();
c.call_and_forget(move |conn| {
match state.action {
rusqlite::hooks::Action::SQLITE_INSERT => {
let text = conn
.query_row(
&format!(
r#"SELECT text FROM "{}" WHERE _rowid_ = $1"#,
state.table_name
),
[state.row_id],
|row| row.get::<_, String>(0),
)
.unwrap();
sender.send(text).unwrap();
}
_ => {
panic!("unexpected action: {:?}", state.action);
}
};
});
},
))
.await
.unwrap();