mirror of
https://github.com/trailbaseio/trailbase.git
synced 2026-01-06 09:50:10 -06:00
Add "realtime" subscriptions for a specific record, i.e. updates and deletion.
This commit is contained in:
27
Cargo.lock
generated
27
Cargo.lock
generated
@@ -715,6 +715,26 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bindgen"
|
||||
version = "0.69.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"itertools 0.10.5",
|
||||
"lazy_static",
|
||||
"lazycell",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"regex",
|
||||
"rustc-hash 1.1.0",
|
||||
"shlex",
|
||||
"syn 2.0.95",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bindgen"
|
||||
version = "0.70.1"
|
||||
@@ -3028,6 +3048,12 @@ dependencies = [
|
||||
"spin",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazycell"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
|
||||
|
||||
[[package]]
|
||||
name = "lettre"
|
||||
version = "0.11.11"
|
||||
@@ -3097,6 +3123,7 @@ version = "0.30.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149"
|
||||
dependencies = [
|
||||
"bindgen 0.69.5",
|
||||
"cc",
|
||||
"pkg-config",
|
||||
"vcpkg",
|
||||
|
||||
@@ -48,6 +48,7 @@ rusqlite = { version = "^0.32.1", default-features = false, features = [
|
||||
"limits",
|
||||
"backup",
|
||||
"hooks",
|
||||
"preupdate_hook",
|
||||
] }
|
||||
trailbase-sqlean = { path = "vendor/sqlean", version = "^0.0.1" }
|
||||
trailbase-extension = { path = "trailbase-extension", version = "^0.1.0" }
|
||||
|
||||
@@ -5,6 +5,8 @@ import 'package:jwt_decoder/jwt_decoder.dart';
|
||||
import 'package:logging/logging.dart';
|
||||
import 'package:dio/dio.dart' as dio;
|
||||
|
||||
import 'sse.dart';
|
||||
|
||||
class User {
|
||||
final String id;
|
||||
final String email;
|
||||
@@ -272,6 +274,22 @@ class RecordApi {
|
||||
);
|
||||
}
|
||||
|
||||
Future<Stream<Map<String, dynamic>>> subscribe(RecordId id) async {
|
||||
final resp = await _client.fetch(
|
||||
'${RecordApi._recordApi}/${_name}/subscribe/${id}',
|
||||
responseType: dio.ResponseType.stream,
|
||||
);
|
||||
|
||||
final Stream<Uint8List> stream = resp.data.stream;
|
||||
return stream.asyncMap((Uint8List bytes) {
|
||||
final decoded = utf8.decode(bytes);
|
||||
if (decoded.startsWith('data: ')) {
|
||||
return jsonDecode(decoded.substring(6));
|
||||
}
|
||||
return jsonDecode(decoded);
|
||||
});
|
||||
}
|
||||
|
||||
Uri imageUri(RecordId id, String colName, {int? index}) {
|
||||
if (index != null) {
|
||||
return Uri.parse(
|
||||
@@ -283,7 +301,7 @@ class RecordApi {
|
||||
}
|
||||
|
||||
class _ThinClient {
|
||||
static final _dio = dio.Dio();
|
||||
static final _dio = dio.Dio()..interceptors.add(SeeInterceptor());
|
||||
|
||||
final String site;
|
||||
|
||||
@@ -295,6 +313,7 @@ class _ThinClient {
|
||||
Object? data,
|
||||
String? method,
|
||||
Map<String, dynamic>? queryParams,
|
||||
dio.ResponseType? responseType,
|
||||
}) async {
|
||||
if (path.startsWith('/')) {
|
||||
throw Exception('Path starts with "/". Relative path expected.');
|
||||
@@ -308,6 +327,7 @@ class _ThinClient {
|
||||
method: method,
|
||||
headers: tokenState.headers,
|
||||
validateStatus: (int? status) => true,
|
||||
responseType: responseType,
|
||||
),
|
||||
);
|
||||
|
||||
@@ -508,6 +528,7 @@ class Client {
|
||||
Object? data,
|
||||
String? method,
|
||||
Map<String, dynamic>? queryParams,
|
||||
dio.ResponseType? responseType,
|
||||
}) async {
|
||||
var tokenState = _tokenState;
|
||||
final refreshToken = _shouldRefresh(tokenState);
|
||||
@@ -515,8 +536,14 @@ class Client {
|
||||
tokenState = _tokenState = await _refreshTokensImpl(refreshToken);
|
||||
}
|
||||
|
||||
final response = await _client.fetch(path, tokenState,
|
||||
data: data, method: method, queryParams: queryParams);
|
||||
final response = await _client.fetch(
|
||||
path,
|
||||
tokenState,
|
||||
data: data,
|
||||
method: method,
|
||||
queryParams: queryParams,
|
||||
responseType: responseType,
|
||||
);
|
||||
|
||||
if (response.statusCode != 200 && (throwOnError ?? true)) {
|
||||
final errMsg = await response.data;
|
||||
|
||||
50
client/trailbase-dart/lib/src/sse.dart
Normal file
50
client/trailbase-dart/lib/src/sse.dart
Normal file
@@ -0,0 +1,50 @@
|
||||
import 'dart:async';
|
||||
import 'dart:typed_data';
|
||||
|
||||
import 'package:dio/dio.dart' as dio;
|
||||
|
||||
class SeeInterceptor extends dio.Interceptor {
|
||||
@override
|
||||
void onResponse(
|
||||
dio.Response response, dio.ResponseInterceptorHandler handler) {
|
||||
if (response.requestOptions.responseType == dio.ResponseType.stream) {
|
||||
final Stream<Uint8List> stream = response.data.stream;
|
||||
|
||||
final buffer = BytesBuilder();
|
||||
final transformedStream = stream.transform<Uint8List>(
|
||||
StreamTransformer.fromHandlers(
|
||||
handleData: (Uint8List data, EventSink<Uint8List> sink) {
|
||||
// If terminated correctly (\n\n) write to sink, otherwise buffer.
|
||||
if (endsWithNewlineNewline(data)) {
|
||||
if (buffer.isNotEmpty) {
|
||||
buffer.add(data);
|
||||
sink.add(buffer.takeBytes());
|
||||
} else {
|
||||
sink.add(data);
|
||||
}
|
||||
} else {
|
||||
buffer.add(data);
|
||||
}
|
||||
},
|
||||
),
|
||||
);
|
||||
|
||||
return handler.resolve(dio.Response(
|
||||
requestOptions: response.requestOptions,
|
||||
data: dio.ResponseBody(transformedStream, response.data.contentLength),
|
||||
statusCode: response.statusCode,
|
||||
headers: response.headers,
|
||||
));
|
||||
}
|
||||
|
||||
handler.next(response);
|
||||
}
|
||||
|
||||
bool endsWithNewlineNewline(List<int> bytes) {
|
||||
if (bytes.length < 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return bytes[bytes.length - 1] == 10 && bytes[bytes.length - 2] == 10;
|
||||
}
|
||||
}
|
||||
@@ -2,3 +2,4 @@ library;
|
||||
|
||||
export 'src/client.dart';
|
||||
export 'src/pkce.dart';
|
||||
export 'src/sse.dart';
|
||||
|
||||
@@ -57,14 +57,19 @@ Future<Process> initTrailBase() async {
|
||||
print('Trying to connect to TrailBase');
|
||||
}
|
||||
|
||||
await Future.delayed(Duration(milliseconds: 500));
|
||||
if (await process.exitCode
|
||||
.timeout(Duration(milliseconds: 500), onTimeout: () => -1) >=
|
||||
0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
process.kill(ProcessSignal.sigkill);
|
||||
final exitCode = await process.exitCode;
|
||||
|
||||
await process.stdout.forEach(print);
|
||||
await process.stderr.forEach(print);
|
||||
await process.stderr.forEach(stdout.add);
|
||||
await process.stdout.forEach(stdout.add);
|
||||
|
||||
throw Exception('Cargo run failed: ${exitCode}.');
|
||||
}
|
||||
|
||||
@@ -73,7 +78,15 @@ Future<void> main() async {
|
||||
throw Exception('Unexpected working directory');
|
||||
}
|
||||
|
||||
await initTrailBase();
|
||||
final process = await initTrailBase();
|
||||
|
||||
tearDownAll(() async {
|
||||
process.kill(ProcessSignal.sigkill);
|
||||
final _ = await process.exitCode;
|
||||
|
||||
// await process.stderr.forEach(stdout.add);
|
||||
// await process.stdout.forEach(stdout.add);
|
||||
});
|
||||
|
||||
group('client tests', () {
|
||||
test('auth', () async {
|
||||
@@ -159,5 +172,27 @@ Future<void> main() async {
|
||||
await api.delete(ids[0]);
|
||||
expect(() async => await api.read(ids[0]), throwsException);
|
||||
});
|
||||
|
||||
test('realtime', () async {
|
||||
final client = await connect();
|
||||
final api = client.records('simple_strict_table');
|
||||
|
||||
final int now = DateTime.now().millisecondsSinceEpoch ~/ 1000;
|
||||
final id = await api
|
||||
.create({'text_not_null': 'dart client realtime test 0: =?&${now}'});
|
||||
|
||||
final events = await api.subscribe(id);
|
||||
|
||||
final updatedMessage = 'dart client updated realtime test 0: ${now}';
|
||||
await api.update(id, {'text_not_null': updatedMessage});
|
||||
await api.delete(id);
|
||||
|
||||
final eventList =
|
||||
await events.timeout(Duration(seconds: 10), onTimeout: (sink) {
|
||||
print('Stream timeout');
|
||||
sink.close();
|
||||
}).toList();
|
||||
expect(eventList.length, equals(2));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ use crate::constants::SITE_URL_DEFAULT;
|
||||
use crate::data_dir::DataDir;
|
||||
use crate::email::Mailer;
|
||||
use crate::js::RuntimeHandle;
|
||||
use crate::records::subscribe::SubscriptionManager;
|
||||
use crate::records::RecordApi;
|
||||
use crate::table_metadata::TableMetadataCache;
|
||||
use crate::value_notifier::{Computed, ValueNotifier};
|
||||
@@ -33,6 +34,7 @@ struct InternalState {
|
||||
jwt: JwtHelper,
|
||||
|
||||
table_metadata: TableMetadataCache,
|
||||
subscription_manager: SubscriptionManager,
|
||||
object_store: Box<dyn ObjectStore + Send + Sync>,
|
||||
|
||||
runtime: RuntimeHandle,
|
||||
@@ -67,6 +69,22 @@ impl AppState {
|
||||
let table_metadata_clone = args.table_metadata.clone();
|
||||
let conn_clone = args.conn.clone();
|
||||
|
||||
let record_apis = Computed::new(&config, move |c| {
|
||||
return c
|
||||
.record_apis
|
||||
.iter()
|
||||
.filter_map(|config| {
|
||||
match build_record_api(conn_clone.clone(), &table_metadata_clone, config.clone()) {
|
||||
Ok(api) => Some((api.api_name().to_string(), api)),
|
||||
Err(err) => {
|
||||
error!("{err}");
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
});
|
||||
|
||||
let runtime = args
|
||||
.js_runtime_threads
|
||||
.map_or_else(RuntimeHandle::new, RuntimeHandle::new_with_threads);
|
||||
@@ -87,26 +105,13 @@ impl AppState {
|
||||
}
|
||||
}),
|
||||
mailer: build_mailer(&config, None),
|
||||
record_apis: Computed::new(&config, move |c| {
|
||||
return c
|
||||
.record_apis
|
||||
.iter()
|
||||
.filter_map(|config| {
|
||||
match build_record_api(conn_clone.clone(), &table_metadata_clone, config.clone()) {
|
||||
Ok(api) => Some((api.api_name().to_string(), api)),
|
||||
Err(err) => {
|
||||
error!("{err}");
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
}),
|
||||
record_apis: record_apis.clone(),
|
||||
config,
|
||||
conn: args.conn.clone(),
|
||||
logs_conn: args.logs_conn,
|
||||
jwt: args.jwt,
|
||||
table_metadata: args.table_metadata,
|
||||
table_metadata: args.table_metadata.clone(),
|
||||
subscription_manager: SubscriptionManager::new(args.conn, args.table_metadata, record_apis),
|
||||
object_store: args.object_store,
|
||||
runtime,
|
||||
#[cfg(test)]
|
||||
@@ -145,6 +150,10 @@ impl AppState {
|
||||
return &self.state.table_metadata;
|
||||
}
|
||||
|
||||
pub(crate) fn subscription_manager(&self) -> &SubscriptionManager {
|
||||
return &self.state.subscription_manager;
|
||||
}
|
||||
|
||||
pub async fn refresh_table_cache(&self) -> Result<(), crate::table_metadata::TableLookupError> {
|
||||
self.table_metadata().invalidate_all().await
|
||||
}
|
||||
@@ -360,6 +369,23 @@ pub async fn test_state(options: Option<TestStateOptions>) -> anyhow::Result<App
|
||||
build_objectstore(&data_dir, None).unwrap()
|
||||
};
|
||||
|
||||
let record_apis = Computed::new(&config, move |c| {
|
||||
return c
|
||||
.record_apis
|
||||
.iter()
|
||||
.filter_map(|config| {
|
||||
let api = build_record_api(
|
||||
main_conn_clone.clone(),
|
||||
&table_metadata_clone,
|
||||
config.clone(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
return Some((api.api_name().to_string(), api));
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
});
|
||||
|
||||
let runtime = RuntimeHandle::new();
|
||||
runtime.set_connection(conn.clone());
|
||||
|
||||
@@ -372,27 +398,13 @@ pub async fn test_state(options: Option<TestStateOptions>) -> anyhow::Result<App
|
||||
ConfiguredOAuthProviders::from_config(c.auth.clone()).unwrap()
|
||||
}),
|
||||
mailer: build_mailer(&config, options.and_then(|o| o.mailer)),
|
||||
record_apis: Computed::new(&config, move |c| {
|
||||
return c
|
||||
.record_apis
|
||||
.iter()
|
||||
.filter_map(|config| {
|
||||
let api = build_record_api(
|
||||
main_conn_clone.clone(),
|
||||
&table_metadata_clone,
|
||||
config.clone(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
return Some((api.api_name().to_string(), api));
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
}),
|
||||
record_apis: record_apis.clone(),
|
||||
config,
|
||||
conn,
|
||||
conn: conn.clone(),
|
||||
logs_conn,
|
||||
jwt: jwt::test_jwt_helper(),
|
||||
table_metadata,
|
||||
table_metadata: table_metadata.clone(),
|
||||
subscription_manager: SubscriptionManager::new(conn, table_metadata, record_apis),
|
||||
object_store,
|
||||
runtime,
|
||||
cleanup: vec![Box::new(temp_dir)],
|
||||
|
||||
@@ -14,6 +14,7 @@ mod list_records;
|
||||
pub(crate) mod read_record;
|
||||
mod record_api;
|
||||
pub mod sql_to_json;
|
||||
pub(crate) mod subscribe;
|
||||
pub mod test_utils;
|
||||
mod update_record;
|
||||
mod validate;
|
||||
@@ -76,6 +77,10 @@ pub(crate) fn router() -> Router<AppState> {
|
||||
.route(
|
||||
&format!("/{RECORD_API_PATH}/{{name}}/schema"),
|
||||
get(json_schema::json_schema_handler),
|
||||
)
|
||||
.route(
|
||||
&format!("/{RECORD_API_PATH}/{{name}}/subscribe/{{record}}"),
|
||||
get(subscribe::add_subscription_sse_handler),
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -339,23 +339,23 @@ impl RecordApi {
|
||||
|
||||
/// Check if the given user (if any) can access a record given the request and the operation.
|
||||
#[allow(unused)]
|
||||
pub fn check_record_level_access_sync(
|
||||
pub fn check_record_level_read_access(
|
||||
&self,
|
||||
conn: &mut rusqlite::Connection,
|
||||
conn: &rusqlite::Connection,
|
||||
p: Permission,
|
||||
record_id: Option<&Value>,
|
||||
request_params: Option<&mut LazyParams<'_>>,
|
||||
record: Vec<(&str, rusqlite::types::Value)>,
|
||||
user: Option<&User>,
|
||||
) -> Result<(), RecordError> {
|
||||
// First check table level access and if present check row-level access based on access rule.
|
||||
self.check_table_level_access(p, user)?;
|
||||
|
||||
let Some(access_query) = self.access_query(p) else {
|
||||
let Some(access_rule) = self.access_rule(p) else {
|
||||
return Ok(());
|
||||
};
|
||||
let params = self.build_named_params(p, record_id, request_params, user)?;
|
||||
|
||||
match Self::query_access(conn, access_query, params) {
|
||||
let (query, params) = build_query_and_params_for_record_read(access_rule, user, record);
|
||||
|
||||
match Self::query_access(conn, &query, params) {
|
||||
Ok(allowed) => {
|
||||
if allowed {
|
||||
return Ok(());
|
||||
@@ -389,7 +389,7 @@ impl RecordApi {
|
||||
|
||||
#[inline]
|
||||
fn query_access(
|
||||
conn: &mut rusqlite::Connection,
|
||||
conn: &rusqlite::Connection,
|
||||
access_query: &str,
|
||||
params: NamedParams,
|
||||
) -> Result<bool, trailbase_sqlite::Error> {
|
||||
@@ -453,6 +453,45 @@ impl RecordApi {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_query_and_params_for_record_read(
|
||||
access_rule: &str,
|
||||
user: Option<&User>,
|
||||
record: Vec<(&str, rusqlite::types::Value)>,
|
||||
) -> (String, NamedParams) {
|
||||
let row = record
|
||||
.iter()
|
||||
.map(|(name, _value)| {
|
||||
return format!(":__v_{name} AS {name}");
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
let mut params: Vec<_> = record
|
||||
.into_iter()
|
||||
.map(|(name, value)| {
|
||||
return (Cow::Owned(format!(":__v_{name}")), value);
|
||||
})
|
||||
.collect();
|
||||
|
||||
params.push((
|
||||
Cow::Borrowed(":__user_id"),
|
||||
user.map_or(Value::Null, |u| Value::Blob(u.uuid.into())),
|
||||
));
|
||||
|
||||
// Assumes access_rule is an expression: https://www.sqlite.org/syntax/expr.html
|
||||
let query = indoc::formatdoc!(
|
||||
r#"
|
||||
SELECT
|
||||
({access_rule})
|
||||
FROM
|
||||
(SELECT :__user_id AS id) AS _USER_,
|
||||
(SELECT {row}) AS _ROW_
|
||||
"#
|
||||
);
|
||||
|
||||
return (query, params);
|
||||
}
|
||||
|
||||
/// Build access query for record reads, deletes and query access.
|
||||
///
|
||||
/// Assumes access_rule is an expression: https://www.sqlite.org/syntax/expr.html
|
||||
|
||||
@@ -19,20 +19,20 @@ pub enum JsonError {
|
||||
ValueNotFound,
|
||||
}
|
||||
|
||||
fn value_to_json(value: rusqlite::types::Value) -> Result<serde_json::Value, JsonError> {
|
||||
pub(crate) fn value_to_json(value: rusqlite::types::Value) -> Result<serde_json::Value, JsonError> {
|
||||
use rusqlite::types::Value;
|
||||
|
||||
return Ok(match value {
|
||||
rusqlite::types::Value::Null => serde_json::Value::Null,
|
||||
rusqlite::types::Value::Real(real) => {
|
||||
Value::Null => serde_json::Value::Null,
|
||||
Value::Real(real) => {
|
||||
let Some(number) = serde_json::Number::from_f64(real) else {
|
||||
return Err(JsonError::Finite);
|
||||
};
|
||||
serde_json::Value::Number(number)
|
||||
}
|
||||
rusqlite::types::Value::Integer(integer) => {
|
||||
serde_json::Value::Number(serde_json::Number::from(integer))
|
||||
}
|
||||
rusqlite::types::Value::Blob(blob) => serde_json::Value::String(BASE64_URL_SAFE.encode(blob)),
|
||||
rusqlite::types::Value::Text(text) => serde_json::Value::String(text),
|
||||
Value::Integer(integer) => serde_json::Value::Number(serde_json::Number::from(integer)),
|
||||
Value::Blob(blob) => serde_json::Value::String(BASE64_URL_SAFE.encode(blob)),
|
||||
Value::Text(text) => serde_json::Value::String(text),
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
560
trailbase-core/src/records/subscribe.rs
Normal file
560
trailbase-core/src/records/subscribe.rs
Normal file
@@ -0,0 +1,560 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::sse::{Event, KeepAlive, Sse},
|
||||
};
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use parking_lot::RwLock;
|
||||
use rusqlite::hooks::{Action, PreUpdateCase};
|
||||
use serde::Serialize;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{
|
||||
atomic::{AtomicI64, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use trailbase_sqlite::{
|
||||
connection::{extract_record_values, extract_row_id},
|
||||
params,
|
||||
};
|
||||
|
||||
use crate::auth::user::User;
|
||||
use crate::records::sql_to_json::value_to_json;
|
||||
use crate::records::RecordApi;
|
||||
use crate::records::{Permission, RecordError};
|
||||
use crate::table_metadata::{TableMetadata, TableMetadataCache};
|
||||
use crate::value_notifier::Computed;
|
||||
use crate::AppState;
|
||||
|
||||
static SUBSCRIPTION_COUNTER: AtomicI64 = AtomicI64::new(0);
|
||||
|
||||
// TODO:
|
||||
// * clients
|
||||
// * table-wide subscriptions
|
||||
// * optimize: avoid repeated encoding of events. Easy to do but makes testing harder since there's
|
||||
// no good way to parse sse::Event back :/. We should probably just bite the bullet and parse,
|
||||
// it's literally "data: <json>\n\n".
|
||||
|
||||
type SseEvent = Result<axum::response::sse::Event, axum::Error>;
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize)]
|
||||
pub enum RecordAction {
|
||||
Delete,
|
||||
Insert,
|
||||
Update,
|
||||
}
|
||||
|
||||
impl From<Action> for RecordAction {
|
||||
fn from(value: Action) -> Self {
|
||||
return match value {
|
||||
Action::SQLITE_DELETE => RecordAction::Delete,
|
||||
Action::SQLITE_INSERT => RecordAction::Insert,
|
||||
Action::SQLITE_UPDATE => RecordAction::Update,
|
||||
_ => unreachable!("{value:?}"),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub enum DbEvent {
|
||||
Update(Option<serde_json::Value>),
|
||||
Insert(Option<serde_json::Value>),
|
||||
Delete(Option<serde_json::Value>),
|
||||
Error(String),
|
||||
}
|
||||
|
||||
// pub struct SubscriptionId {
|
||||
// table_name: String,
|
||||
// row_id: i64,
|
||||
// subscription_id: i64,
|
||||
// }
|
||||
|
||||
pub struct Subscription {
|
||||
/// Id uniquely identifying this subscription.
|
||||
subscription_id: i64,
|
||||
/// Name of the API this subscription is subscribed to. We need to lookup the Record API on the
|
||||
/// hot path to make sure we're getting the latest configuration.
|
||||
record_api_name: String,
|
||||
|
||||
/// Record id present for subscriptions to specific records.
|
||||
// record_id: Option<trailbase_sqlite::Value>,
|
||||
user: Option<User>,
|
||||
/// Channel for sending events to the SSE handler.
|
||||
channel: async_channel::Sender<DbEvent>,
|
||||
}
|
||||
|
||||
/// Internal, shareable state of the cloneable SubscriptionManager.
|
||||
struct ManagerState {
|
||||
/// SQLite connection to monitor.
|
||||
conn: trailbase_sqlite::Connection,
|
||||
/// Table metadata for mapping column indexes to column names needed for building JSON encoded
|
||||
/// records.
|
||||
table_metadata: TableMetadataCache,
|
||||
/// Record API configurations.
|
||||
record_apis: Computed<Vec<(String, RecordApi)>, crate::config::proto::Config>,
|
||||
|
||||
/// Map from table name to row id to list of subscriptions.
|
||||
subscriptions: RwLock<HashMap<String, HashMap<i64, Vec<Subscription>>>>,
|
||||
}
|
||||
|
||||
impl ManagerState {
|
||||
fn lookup_record_api(&self, name: &str) -> Option<RecordApi> {
|
||||
for (record_api_name, record_api) in self.record_apis.load().iter() {
|
||||
if record_api_name == name {
|
||||
return Some(record_api.clone());
|
||||
}
|
||||
}
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SubscriptionManager {
|
||||
state: Arc<ManagerState>,
|
||||
}
|
||||
|
||||
struct ContinuationState {
|
||||
state: Arc<ManagerState>,
|
||||
table_metadata: Arc<TableMetadata>,
|
||||
action: RecordAction,
|
||||
table_name: String,
|
||||
rowid: i64,
|
||||
record_values: Vec<rusqlite::types::Value>,
|
||||
}
|
||||
|
||||
impl SubscriptionManager {
|
||||
pub fn new(
|
||||
conn: trailbase_sqlite::Connection,
|
||||
table_metadata: TableMetadataCache,
|
||||
record_apis: Computed<Vec<(String, RecordApi)>, crate::config::proto::Config>,
|
||||
) -> Self {
|
||||
return Self {
|
||||
state: Arc::new(ManagerState {
|
||||
conn,
|
||||
table_metadata,
|
||||
record_apis,
|
||||
|
||||
subscriptions: RwLock::new(HashMap::new()),
|
||||
}),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn num_subscriptions(&self) -> usize {
|
||||
let mut count: usize = 0;
|
||||
for table in self.state.subscriptions.read().values() {
|
||||
for record in table.values() {
|
||||
count += record.len();
|
||||
}
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
/// Preupdate hook that runs in a continuation of the trailbase-sqlite executor.
|
||||
fn hook_continuation(conn: &rusqlite::Connection, state: ContinuationState) {
|
||||
let ContinuationState {
|
||||
state,
|
||||
table_metadata,
|
||||
table_name,
|
||||
action,
|
||||
rowid,
|
||||
record_values,
|
||||
} = state;
|
||||
|
||||
let s = &state;
|
||||
|
||||
let mut read_lock = s.subscriptions.upgradable_read();
|
||||
let Some(subs) = read_lock.get(&table_name).and_then(|m| m.get(&rowid)) else {
|
||||
return;
|
||||
};
|
||||
|
||||
// Join values with column names.
|
||||
let record: Vec<_> = record_values
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, v)| (table_metadata.schema.columns[idx].name.as_str(), v))
|
||||
.collect();
|
||||
|
||||
// Build a JSON-encoded SQLite event (insert, update, delete).
|
||||
let event = {
|
||||
let json_value = serde_json::Value::Object(
|
||||
record
|
||||
.iter()
|
||||
.filter_map(|(name, value)| {
|
||||
if let Ok(v) = value_to_json(value.clone()) {
|
||||
return Some(((*name).to_string(), v));
|
||||
};
|
||||
return None;
|
||||
})
|
||||
.collect(),
|
||||
);
|
||||
|
||||
match action {
|
||||
RecordAction::Delete => DbEvent::Delete(Some(json_value)),
|
||||
RecordAction::Insert => DbEvent::Insert(Some(json_value)),
|
||||
RecordAction::Update => DbEvent::Update(Some(json_value)),
|
||||
}
|
||||
};
|
||||
|
||||
let mut dead_subscriptions: Vec<usize> = vec![];
|
||||
for (idx, sub) in subs.iter().enumerate() {
|
||||
let Some(api) = s.lookup_record_api(&sub.record_api_name) else {
|
||||
dead_subscriptions.push(idx);
|
||||
continue;
|
||||
};
|
||||
|
||||
if let Err(_err) = api.check_record_level_read_access(
|
||||
conn,
|
||||
Permission::Read,
|
||||
// TODO: Maybe we could inject ValueRef instead to avoid repeated cloning.
|
||||
record.clone(),
|
||||
sub.user.as_ref(),
|
||||
) {
|
||||
// This can happen if the record api configuration has changed since originally
|
||||
// subscribed. In this case we just send and error and cancel the subscription.
|
||||
let _ = sub.channel.try_send(DbEvent::Error("Access denied".into()));
|
||||
dead_subscriptions.push(idx);
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: Avoid cloning the event/record over and over.
|
||||
match sub.channel.try_send(event.clone()) {
|
||||
Ok(_) => {}
|
||||
Err(async_channel::TrySendError::Full(ev)) => {
|
||||
log::warn!("Channel full, dropping event: {ev:?}");
|
||||
}
|
||||
Err(async_channel::TrySendError::Closed(_ev)) => {
|
||||
dead_subscriptions.push(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if dead_subscriptions.is_empty() && action != RecordAction::Delete {
|
||||
// No cleanup needed.
|
||||
return;
|
||||
}
|
||||
|
||||
read_lock.with_upgraded(move |subscriptions| {
|
||||
let Some(table_subscriptions) = subscriptions.get_mut(&table_name) else {
|
||||
return;
|
||||
};
|
||||
|
||||
if action == RecordAction::Delete {
|
||||
// Also drops the channel and thus automatically closes the SSE connection.
|
||||
table_subscriptions.remove(&rowid);
|
||||
|
||||
if table_subscriptions.is_empty() {
|
||||
subscriptions.remove(&table_name);
|
||||
if subscriptions.is_empty() {
|
||||
conn.preupdate_hook(NO_HOOK);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(m) = table_subscriptions.get_mut(&rowid) {
|
||||
for idx in dead_subscriptions.iter().rev() {
|
||||
m.swap_remove(*idx);
|
||||
}
|
||||
|
||||
if m.is_empty() {
|
||||
table_subscriptions.remove(&rowid);
|
||||
|
||||
if table_subscriptions.is_empty() {
|
||||
subscriptions.remove(&table_name);
|
||||
if subscriptions.is_empty() {
|
||||
conn.preupdate_hook(NO_HOOK);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn add_hook(&self) -> trailbase_sqlite::connection::Result<()> {
|
||||
let state = &self.state;
|
||||
let conn = state.conn.clone();
|
||||
let s = state.clone();
|
||||
|
||||
return state
|
||||
.conn
|
||||
.add_preupdate_hook(Some(
|
||||
move |action: Action, db: &str, table_name: &str, case: &PreUpdateCase| {
|
||||
assert_eq!(db, "main");
|
||||
|
||||
let action: RecordAction = match action {
|
||||
Action::SQLITE_UPDATE | Action::SQLITE_INSERT | Action::SQLITE_DELETE => action.into(),
|
||||
a => {
|
||||
log::error!("Unknown action: {a:?}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let Some(rowid) = extract_row_id(case) else {
|
||||
log::error!("Failed to extract row id");
|
||||
return;
|
||||
};
|
||||
|
||||
// If there are no subscriptions, do nothing.
|
||||
if s
|
||||
.subscriptions
|
||||
.read()
|
||||
.get(table_name)
|
||||
.and_then(|m| m.get(&rowid))
|
||||
.is_none()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(table_metadata) = s.table_metadata.get(table_name) else {
|
||||
// TODO: Should we cleanup here? Probably, since we won't recover from this issue.
|
||||
log::error!("Table not found: {table_name}");
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(record_values) = extract_record_values(case) else {
|
||||
log::error!("Failed to extract values");
|
||||
return;
|
||||
};
|
||||
|
||||
let state = ContinuationState {
|
||||
state: s.clone(),
|
||||
table_metadata,
|
||||
action,
|
||||
table_name: table_name.to_string(),
|
||||
rowid,
|
||||
record_values,
|
||||
};
|
||||
|
||||
// TODO: Optimization: in cases where there's only table-level access restrictions, we
|
||||
// could avoid the continuation and even dispatch the subscription handling to a
|
||||
// different thread entirely to take more work off the SQLite thread.
|
||||
conn.call_and_forget(move |conn| {
|
||||
Self::hook_continuation(conn, state);
|
||||
});
|
||||
},
|
||||
))
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn add_subscription(
|
||||
&self,
|
||||
api: RecordApi,
|
||||
record: Option<trailbase_sqlite::Value>,
|
||||
user: Option<User>,
|
||||
) -> Result<async_channel::Receiver<DbEvent>, RecordError> {
|
||||
let Some(record) = record else {
|
||||
return Err(RecordError::BadRequest("Missing record id"));
|
||||
};
|
||||
let (sender, receiver) = async_channel::bounded::<DbEvent>(16);
|
||||
|
||||
let table_name = api.table_name();
|
||||
let pk_column = &api.record_pk_column().name;
|
||||
|
||||
let Some(row) = self
|
||||
.state
|
||||
.conn
|
||||
.query_row(
|
||||
&format!(r#"SELECT _rowid_ FROM "{table_name}" WHERE "{pk_column}" = $1"#),
|
||||
params!(record.clone()),
|
||||
)
|
||||
.await?
|
||||
else {
|
||||
return Err(RecordError::RecordNotFound);
|
||||
};
|
||||
let row_id: i64 = row
|
||||
.get(0)
|
||||
.map_err(|err| RecordError::Internal(err.into()))?;
|
||||
|
||||
let subscription_id = SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst);
|
||||
let empty = {
|
||||
let mut lock = self.state.subscriptions.write();
|
||||
let empty = lock.is_empty();
|
||||
let m: &mut HashMap<i64, Vec<Subscription>> = lock.entry(table_name.to_string()).or_default();
|
||||
|
||||
m.entry(row_id).or_default().push(Subscription {
|
||||
subscription_id,
|
||||
record_api_name: api.api_name().to_string(),
|
||||
// record_id: Some(record),
|
||||
user,
|
||||
channel: sender,
|
||||
});
|
||||
|
||||
empty
|
||||
};
|
||||
|
||||
if empty {
|
||||
self.add_hook().await.unwrap();
|
||||
}
|
||||
|
||||
return Ok(receiver);
|
||||
}
|
||||
|
||||
// TODO: Cleaning up subscriptions might be a thing, e.g. if SSE handlers had an onDisconnect
|
||||
// handler. Right now we're handling cleanups reactively, i.e. we only remove subscriptions when
|
||||
// sending new events and the receiving end of a handler channel became invalid. It would
|
||||
// be better to be pro-active and remove subscriptions sooner.
|
||||
//
|
||||
// async fn cleanup_subscription(&self, subscription_id: SubscriptionId) -> Result<(),
|
||||
// RecordError> { let mut lock = self.state.subscriptions.write();
|
||||
//
|
||||
// if let Some(table_subs) = lock.get_mut(&subscription_id.table_name) {
|
||||
// if let Some(subs) = table_subs.get_mut(&subscription_id.row_id) {
|
||||
// subs.retain(|s| s.id != subscription_id.subscription_id);
|
||||
//
|
||||
// if subs.is_empty() {
|
||||
// table_subs.remove(&subscription_id.row_id);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// if table_subs.is_empty() {
|
||||
// lock.remove(&subscription_id.table_name);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// if lock.is_empty() {
|
||||
// Self::remove_preupdate_hook(&*self.state).await?;
|
||||
// }
|
||||
//
|
||||
// return Ok(());
|
||||
// }
|
||||
}
|
||||
|
||||
pub async fn add_subscription_sse_handler(
|
||||
State(state): State<AppState>,
|
||||
Path((api_name, record)): Path<(String, String)>,
|
||||
user: Option<User>,
|
||||
) -> Result<Sse<impl Stream<Item = SseEvent>>, RecordError> {
|
||||
let Some(api) = state.lookup_record_api(&api_name) else {
|
||||
return Err(RecordError::ApiNotFound);
|
||||
};
|
||||
|
||||
let record_id = api.id_to_sql(&record)?;
|
||||
|
||||
let Ok(()) = api
|
||||
.check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref())
|
||||
.await
|
||||
else {
|
||||
return Err(RecordError::Forbidden);
|
||||
};
|
||||
|
||||
let receiver = state
|
||||
.subscription_manager()
|
||||
.add_subscription(api, Some(record_id), user)
|
||||
.await?;
|
||||
|
||||
return Ok(
|
||||
Sse::new(receiver.map(|ev| {
|
||||
// TODO: We're re-encoding the event over and over again for all subscriptions. Would be easy
|
||||
// to pre-encode on the sender side but makes testing much harder, since there's no good way
|
||||
// to parse sse::Event back.
|
||||
return Event::default().json_data(ev);
|
||||
}))
|
||||
.keep_alive(KeepAlive::default()),
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::DbEvent;
|
||||
use super::*;
|
||||
use crate::app_state::test_state;
|
||||
use crate::records::{add_record_api, AccessRules, Acls, PermissionFlag};
|
||||
|
||||
#[tokio::test]
|
||||
async fn subscribe_connection_test() {
|
||||
let state = test_state(None).await.unwrap();
|
||||
let conn = state.conn().clone();
|
||||
|
||||
conn
|
||||
.execute(
|
||||
"CREATE TABLE test (id INTEGER PRIMARY KEY, text TEXT) STRICT",
|
||||
(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
state.table_metadata().invalidate_all().await.unwrap();
|
||||
|
||||
// Register message table as record api with moderator read access.
|
||||
add_record_api(
|
||||
&state,
|
||||
"api_name",
|
||||
"test",
|
||||
Acls {
|
||||
world: vec![PermissionFlag::Create, PermissionFlag::Read],
|
||||
..Default::default()
|
||||
},
|
||||
AccessRules {
|
||||
// read: Some("(_ROW_._owner = _USER_.id OR EXISTS(SELECT 1 FROM room_members WHERE room =
|
||||
// _ROW_.room AND user = _USER_.id))".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let record_id_raw = 0;
|
||||
let record_id = trailbase_sqlite::Value::Integer(record_id_raw);
|
||||
let rowid: i64 = conn
|
||||
.query_row(
|
||||
"INSERT INTO test (id, text) VALUES ($1, 'foo') RETURNING _rowid_",
|
||||
[record_id],
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.get(0)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(rowid, record_id_raw);
|
||||
|
||||
let manager = state.subscription_manager();
|
||||
let api = state.lookup_record_api("api_name").unwrap();
|
||||
let receiver = manager
|
||||
.add_subscription(api, Some(trailbase_sqlite::Value::Integer(0)), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(1, manager.num_subscriptions());
|
||||
|
||||
conn
|
||||
.execute(
|
||||
"UPDATE test SET text = $1 WHERE _rowid_ = $2",
|
||||
params!("bar", rowid),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let expected = serde_json::json!({
|
||||
"id": record_id_raw,
|
||||
"text": "bar",
|
||||
});
|
||||
match receiver.recv().await.unwrap() {
|
||||
DbEvent::Update(Some(value)) => {
|
||||
assert_eq!(value, expected);
|
||||
}
|
||||
x => {
|
||||
assert!(false, "Expected update, got: {x:?}");
|
||||
}
|
||||
};
|
||||
|
||||
conn
|
||||
.execute("DELETE FROM test WHERE _rowid_ = $2", params!(rowid))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match receiver.recv().await.unwrap() {
|
||||
DbEvent::Delete(Some(value)) => {
|
||||
assert_eq!(value, expected);
|
||||
}
|
||||
x => {
|
||||
assert!(false, "Expected update, got: {x:?}");
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(0, manager.num_subscriptions());
|
||||
}
|
||||
|
||||
// TODO: Test actual SSE handler.
|
||||
}
|
||||
|
||||
const NO_HOOK: Option<fn(Action, &str, &str, &PreUpdateCase)> = None;
|
||||
@@ -73,7 +73,7 @@ pub struct TableMetadata {
|
||||
metadata: Vec<ColumnMetadata>,
|
||||
name_to_index: HashMap<String, usize>,
|
||||
|
||||
record_pk_column: Option<usize>,
|
||||
pub record_pk_column: Option<usize>,
|
||||
pub user_id_columns: Vec<usize>,
|
||||
pub file_upload_columns: Vec<usize>,
|
||||
pub file_uploads_columns: Vec<usize>,
|
||||
|
||||
@@ -68,6 +68,7 @@ struct ComputedState<T, V> {
|
||||
f: Box<dyn Sync + Send + Fn(&V) -> T>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Computed<T, V> {
|
||||
state: Arc<ComputedState<T, V>>,
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use crossbeam_channel::{Receiver, Sender};
|
||||
use rusqlite::hooks::Action;
|
||||
use rusqlite::hooks::{Action, PreUpdateCase};
|
||||
use rusqlite::types::Value;
|
||||
use std::{
|
||||
fmt::{self, Debug},
|
||||
sync::Arc,
|
||||
@@ -35,11 +36,9 @@ macro_rules! named_params {
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
type CallFn = Box<dyn FnOnce(&mut rusqlite::Connection) + Send + 'static>;
|
||||
type HookFn = Arc<dyn Fn(&rusqlite::Connection, Action, &str, &str, i64) + Send + Sync + 'static>;
|
||||
|
||||
enum Message {
|
||||
Run(CallFn),
|
||||
ExecuteHook(HookFn, Action, String, String, i64),
|
||||
Close(oneshot::Sender<std::result::Result<(), rusqlite::Error>>),
|
||||
}
|
||||
|
||||
@@ -73,7 +72,7 @@ impl Connection {
|
||||
/// Will return `Err` if the database connection has been closed.
|
||||
pub async fn call<F, R>(&self, function: F) -> Result<R>
|
||||
where
|
||||
F: FnOnce(&mut rusqlite::Connection) -> Result<R> + 'static + Send,
|
||||
F: FnOnce(&mut rusqlite::Connection) -> Result<R> + Send + 'static,
|
||||
R: Send + 'static,
|
||||
{
|
||||
let (sender, receiver) = oneshot::channel::<Result<R>>();
|
||||
@@ -89,6 +88,12 @@ impl Connection {
|
||||
receiver.await.map_err(|_| Error::ConnectionClosed)?
|
||||
}
|
||||
|
||||
pub fn call_and_forget(&self, function: impl FnOnce(&rusqlite::Connection) + Send + 'static) {
|
||||
let _ = self
|
||||
.sender
|
||||
.send(Message::Run(Box::new(move |conn| function(conn))));
|
||||
}
|
||||
|
||||
/// Query SQL statement.
|
||||
pub async fn query(&self, sql: &str, params: impl Params + Send + 'static) -> Result<Rows> {
|
||||
let sql = sql.to_string();
|
||||
@@ -204,39 +209,12 @@ impl Connection {
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn add_hook(
|
||||
/// Convenience API for (un)setting a new pre-update hook.
|
||||
pub async fn add_preupdate_hook(
|
||||
&self,
|
||||
f: impl Fn(&rusqlite::Connection, Action, &str, &str, i64) + Send + Sync + 'static,
|
||||
hook: Option<impl (Fn(Action, &str, &str, &PreUpdateCase)) + Send + Sync + 'static>,
|
||||
) -> Result<()> {
|
||||
let sender = self.sender.clone();
|
||||
let f = Arc::new(f);
|
||||
|
||||
return self
|
||||
.call(|conn| {
|
||||
conn.update_hook(Some(
|
||||
move |action: Action, db: &str, table: &str, row: i64| {
|
||||
let _ = sender.send(Message::ExecuteHook(
|
||||
f.clone(),
|
||||
action,
|
||||
db.to_string(),
|
||||
table.to_string(),
|
||||
row,
|
||||
));
|
||||
},
|
||||
));
|
||||
|
||||
return Ok(());
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn remove_hook(&self) -> Result<()> {
|
||||
return self
|
||||
.call(|conn| {
|
||||
conn.update_hook(None::<fn(Action, &str, &str, i64)>);
|
||||
return Ok(());
|
||||
})
|
||||
.await;
|
||||
return self.call(|conn| Ok(conn.preupdate_hook(hook))).await;
|
||||
}
|
||||
|
||||
/// Close the database connection.
|
||||
@@ -286,7 +264,6 @@ fn event_loop(mut conn: rusqlite::Connection, receiver: Receiver<Message>) {
|
||||
while let Ok(message) = receiver.recv() {
|
||||
match message {
|
||||
Message::Run(f) => f(&mut conn),
|
||||
Message::ExecuteHook(f, action, db, table, row) => f(&conn, action, &db, &table, row),
|
||||
Message::Close(ch) => {
|
||||
match conn.close() {
|
||||
Ok(v) => ch.send(Ok(v)).expect(BUG_TEXT),
|
||||
@@ -299,6 +276,50 @@ fn event_loop(mut conn: rusqlite::Connection, receiver: Receiver<Message>) {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_row_id(case: &PreUpdateCase) -> Option<i64> {
|
||||
return match case {
|
||||
PreUpdateCase::Insert(accessor) => Some(accessor.get_new_row_id()),
|
||||
PreUpdateCase::Delete(accessor) => Some(accessor.get_old_row_id()),
|
||||
PreUpdateCase::Update {
|
||||
new_value_accessor: accessor,
|
||||
..
|
||||
} => Some(accessor.get_new_row_id()),
|
||||
PreUpdateCase::Unknown => None,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn extract_record_values(case: &PreUpdateCase) -> Option<Vec<Value>> {
|
||||
return Some(match case {
|
||||
PreUpdateCase::Insert(accessor) => (0..accessor.get_column_count())
|
||||
.map(|idx| -> Value {
|
||||
accessor
|
||||
.get_new_column_value(idx)
|
||||
.map_or(rusqlite::types::Value::Null, |v| v.into())
|
||||
})
|
||||
.collect(),
|
||||
PreUpdateCase::Delete(accessor) => (0..accessor.get_column_count())
|
||||
.map(|idx| -> rusqlite::types::Value {
|
||||
accessor
|
||||
.get_old_column_value(idx)
|
||||
.map_or(rusqlite::types::Value::Null, |v| v.into())
|
||||
})
|
||||
.collect(),
|
||||
PreUpdateCase::Update {
|
||||
new_value_accessor: accessor,
|
||||
..
|
||||
} => (0..accessor.get_column_count())
|
||||
.map(|idx| -> rusqlite::types::Value {
|
||||
accessor
|
||||
.get_new_column_value(idx)
|
||||
.map_or(rusqlite::types::Value::Null, |v| v.into())
|
||||
})
|
||||
.collect(),
|
||||
PreUpdateCase::Unknown => {
|
||||
return None;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests.rs"]
|
||||
mod tests;
|
||||
|
||||
@@ -78,10 +78,9 @@ impl Params for () {
|
||||
impl Params for Vec<(String, types::Value)> {
|
||||
fn bind(self, stmt: &mut Statement<'_>) -> rusqlite::Result<()> {
|
||||
for (name, v) in self {
|
||||
let Some(idx) = stmt.parameter_index(&name)? else {
|
||||
continue;
|
||||
if let Some(idx) = stmt.parameter_index(&name)? {
|
||||
stmt.raw_bind_parameter(idx, v)?;
|
||||
};
|
||||
stmt.raw_bind_parameter(idx, v)?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use rusqlite::ffi;
|
||||
use rusqlite::hooks::PreUpdateCase;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::connection::extract_row_id;
|
||||
use crate::{named_params, params, Connection, Error, Value, ValueType};
|
||||
use rusqlite::ErrorCode;
|
||||
|
||||
@@ -322,22 +324,49 @@ async fn test_hooks() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel::<String>();
|
||||
conn
|
||||
.add_hook(move |c, action, _db, table, row_id| match action {
|
||||
rusqlite::hooks::Action::SQLITE_INSERT => {
|
||||
let text = c
|
||||
.query_row(
|
||||
&format!(r#"SELECT text FROM "{table}" WHERE _rowid_ = $1"#),
|
||||
[row_id],
|
||||
|row| row.get::<_, String>(0),
|
||||
)
|
||||
.unwrap();
|
||||
struct State {
|
||||
action: rusqlite::hooks::Action,
|
||||
table_name: String,
|
||||
row_id: i64,
|
||||
}
|
||||
|
||||
sender.send(text).unwrap();
|
||||
}
|
||||
_ => {}
|
||||
})
|
||||
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel::<String>();
|
||||
let c = conn.clone();
|
||||
|
||||
conn
|
||||
.add_preupdate_hook(Some(
|
||||
move |action: rusqlite::hooks::Action, _db: &str, table_name: &str, case: &PreUpdateCase| {
|
||||
let row_id = extract_row_id(case).unwrap();
|
||||
let state = State {
|
||||
action,
|
||||
table_name: table_name.to_string(),
|
||||
row_id,
|
||||
};
|
||||
|
||||
let sender = sender.clone();
|
||||
c.call_and_forget(move |conn| {
|
||||
match state.action {
|
||||
rusqlite::hooks::Action::SQLITE_INSERT => {
|
||||
let text = conn
|
||||
.query_row(
|
||||
&format!(
|
||||
r#"SELECT text FROM "{}" WHERE _rowid_ = $1"#,
|
||||
state.table_name
|
||||
),
|
||||
[state.row_id],
|
||||
|row| row.get::<_, String>(0),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
sender.send(text).unwrap();
|
||||
}
|
||||
_ => {
|
||||
panic!("unexpected action: {:?}", state.action);
|
||||
}
|
||||
};
|
||||
});
|
||||
},
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user