mirror of
https://github.com/trailbaseio/trailbase.git
synced 2026-05-19 07:49:57 -05:00
Massively overhaul change event execution model.
* minimize work in SQLite's preupdate hook * push brokering onto a separate thread * kick filter and acl checking further downstream into the SSE handlers. * add layered sequence numbers to detect server-side event losses and allow clients to detect client-side event losses.
This commit is contained in:
@@ -20,6 +20,7 @@ pub enum ValueOrComposite {
|
||||
Composite(Combiner, Vec<ValueOrComposite>),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum Filter {
|
||||
Passthrough,
|
||||
Record(ValueOrComposite),
|
||||
@@ -224,7 +225,7 @@ fn parse_geometries(record: &[u8], filter: &str) -> Option<(geos::Geometry, geos
|
||||
|
||||
pub(crate) fn apply_filter_recursively_to_record(
|
||||
filter: &ValueOrComposite,
|
||||
record: &indexmap::IndexMap<&str, rusqlite::types::Value>,
|
||||
record: &indexmap::IndexMap<String, rusqlite::types::Value>,
|
||||
) -> bool {
|
||||
return match filter {
|
||||
ValueOrComposite::Value(col_op_value) => {
|
||||
@@ -314,7 +315,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_basic_value_filter() {
|
||||
let record: IndexMap<&str, Value> = IndexMap::from([("a", Value::Text("a value".to_string()))]);
|
||||
let record: IndexMap<String, Value> = IndexMap::from([(
|
||||
"a".to_string(),
|
||||
Value::Text("a value".to_string().to_string()),
|
||||
)]);
|
||||
|
||||
assert!(apply_filter_recursively_to_record(
|
||||
&ValueOrComposite::Value(ColumnOpValue {
|
||||
@@ -355,8 +359,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_basic_composite_filter() {
|
||||
let record: IndexMap<&str, Value> =
|
||||
IndexMap::from([("a", Value::Integer(5)), ("b", Value::Integer(-5))]);
|
||||
let record: IndexMap<String, Value> = IndexMap::from([
|
||||
("a".to_string(), Value::Integer(5)),
|
||||
("b".to_string(), Value::Integer(-5)),
|
||||
]);
|
||||
|
||||
assert!(apply_filter_recursively_to_record(
|
||||
&ValueOrComposite::Composite(
|
||||
|
||||
@@ -564,43 +564,60 @@ impl RecordApi {
|
||||
|
||||
/// Check if the given user (if any) can access a record given the request and the operation.
|
||||
#[inline]
|
||||
pub(crate) fn check_record_level_read_access_for_subscriptions(
|
||||
pub(crate) async fn check_record_level_read_access_for_subscriptions(
|
||||
&self,
|
||||
conn: &rusqlite::Connection,
|
||||
params: SubscriptionAclParams<'_>,
|
||||
conn: &trailbase_sqlite::Connection,
|
||||
record: &Arc<indexmap::IndexMap<String, rusqlite::types::Value>>,
|
||||
user: Option<&User>,
|
||||
) -> Result<(), RecordError> {
|
||||
// First check table level access and if present check row-level access based on access rule.
|
||||
self.check_table_level_access(Permission::Read, params.user)?;
|
||||
self.check_table_level_access(Permission::Read, user)?;
|
||||
|
||||
let Some(ref access_query) = self.state.subscription_read_access_query else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let mut stmt = conn
|
||||
.prepare_cached(access_query)
|
||||
.map_err(|_err| RecordError::Forbidden)?;
|
||||
conn
|
||||
.call_reader({
|
||||
let access_query = access_query.clone();
|
||||
let user = user.cloned();
|
||||
let record = record.clone();
|
||||
|
||||
// NOTE: the `bind` impl does the heavy lifting.
|
||||
params
|
||||
.bind(&mut stmt)
|
||||
.map_err(|_err| RecordError::Forbidden)?;
|
||||
move |conn| {
|
||||
let params = SubscriptionAclParams {
|
||||
params: &record,
|
||||
user: user.as_ref(),
|
||||
};
|
||||
|
||||
match stmt.raw_query().next() {
|
||||
Ok(Some(row)) => {
|
||||
if row.get(0).unwrap_or(false) {
|
||||
return Ok(());
|
||||
let mut stmt = conn.prepare_cached(&access_query)?;
|
||||
|
||||
// NOTE: the `bind` impl does the heavy lifting.
|
||||
params.bind(&mut stmt)?;
|
||||
|
||||
match stmt.raw_query().next() {
|
||||
Ok(Some(row)) => {
|
||||
if row.get(0).unwrap_or(false) {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(err) => {
|
||||
warn!("RLA query failed: {err}");
|
||||
|
||||
#[cfg(test)]
|
||||
panic!("RLA query failed: {err}");
|
||||
}
|
||||
};
|
||||
|
||||
return Err(trailbase_sqlite::Error::Other(
|
||||
RecordError::Forbidden.into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(err) => {
|
||||
warn!("RLA query failed: {err}");
|
||||
})
|
||||
.await
|
||||
.map_err(|_| RecordError::Forbidden)?;
|
||||
|
||||
#[cfg(test)]
|
||||
panic!("RLA query failed: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
return Err(RecordError::Forbidden);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
#[inline]
|
||||
@@ -704,9 +721,9 @@ impl RecordApi {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct SubscriptionAclParams<'a> {
|
||||
pub params: &'a indexmap::IndexMap<&'a str, rusqlite::types::Value>,
|
||||
pub user: Option<&'a User>,
|
||||
struct SubscriptionAclParams<'a> {
|
||||
params: &'a indexmap::IndexMap<String, rusqlite::types::Value>,
|
||||
user: Option<&'a User>,
|
||||
}
|
||||
|
||||
impl<'a> trailbase_sqlite::Params for SubscriptionAclParams<'a> {
|
||||
|
||||
@@ -1,41 +1,46 @@
|
||||
use axum::response::sse::Event as SseEvent;
|
||||
use serde::Serialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::records::RecordError;
|
||||
|
||||
type JsonObject = serde_json::value::Map<String, serde_json::Value>;
|
||||
|
||||
#[repr(i64)]
|
||||
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq)]
|
||||
pub enum EventErrorStatus {
|
||||
/// Unknown or unspecified error.
|
||||
Unknown = 0,
|
||||
/// Access forbidden.
|
||||
Forbidden = 1,
|
||||
/// Server-side event-loss, e.g. a buffer ran out of capacity. This does not account for
|
||||
/// additional losses that may happen between the TrailBase server and the client. This
|
||||
/// needs to be determined client-side based on event `seq` numbers.
|
||||
Loss = 2,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
|
||||
pub struct EventError {
|
||||
pub status: EventErrorStatus,
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum JsonEventPayload {
|
||||
Update { value: JsonObject },
|
||||
Insert { value: JsonObject },
|
||||
Delete { value: JsonObject },
|
||||
Error { error: String },
|
||||
Error { value: EventError },
|
||||
Ping,
|
||||
}
|
||||
|
||||
// fn serialize_raw_json<S>(json: &Option<String>, s: S) -> Result<S::Ok, S::Error>
|
||||
// where
|
||||
// S: serde::ser::Serializer,
|
||||
// {
|
||||
// // This should be pretty efficient: it just checks that the string is valid;
|
||||
// // it doesn't parse it into a new data structure.
|
||||
// if let Some(json) = json {
|
||||
// let v: &serde_json::value::RawValue = serde_json::from_str(json).expect("invalid json");
|
||||
// return v.serialize(s);
|
||||
// }
|
||||
//
|
||||
// return s.serialize_none();
|
||||
// }
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub enum EventPayload {
|
||||
Update(Option<Box<serde_json::value::RawValue>>),
|
||||
Insert(Option<Box<serde_json::value::RawValue>>),
|
||||
Delete(Option<Box<serde_json::value::RawValue>>),
|
||||
Error(String),
|
||||
Error(Option<Box<serde_json::value::RawValue>>),
|
||||
Ping,
|
||||
}
|
||||
|
||||
@@ -49,7 +54,7 @@ impl PartialEq for EventPayload {
|
||||
(Self::Update(lhs), Self::Update(rhs)) => get(lhs) == get(rhs),
|
||||
(Self::Insert(lhs), Self::Insert(rhs)) => get(lhs) == get(rhs),
|
||||
(Self::Delete(lhs), Self::Delete(rhs)) => get(lhs) == get(rhs),
|
||||
(Self::Error(lhs), Self::Error(rhs)) => lhs == rhs,
|
||||
(Self::Error(lhs), Self::Error(rhs)) => get(lhs) == get(rhs),
|
||||
(Self::Ping, Self::Ping) => true,
|
||||
_ => false,
|
||||
};
|
||||
@@ -74,7 +79,11 @@ impl EventPayload {
|
||||
.map(|v| v.to_owned())
|
||||
.ok(),
|
||||
),
|
||||
JsonEventPayload::Error { error } => EventPayload::Error(error.clone()),
|
||||
JsonEventPayload::Error { value } => EventPayload::Error(
|
||||
serde_json::value::to_raw_value(&value)
|
||||
.map(|v| v.to_owned())
|
||||
.ok(),
|
||||
),
|
||||
JsonEventPayload::Ping => EventPayload::Ping,
|
||||
};
|
||||
}
|
||||
@@ -118,20 +127,24 @@ pub struct ChangeEvent {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn deserialize_event(ev: Arc<EventPayload>) -> Result<JsonEventPayload, serde_json::Error> {
|
||||
return match *ev {
|
||||
EventPayload::Update(ref v) => Ok(JsonEventPayload::Update {
|
||||
value: serde_json::from_str(v.as_ref().map_or("", |v| v.get()))?,
|
||||
}),
|
||||
EventPayload::Insert(ref v) => Ok(JsonEventPayload::Insert {
|
||||
value: serde_json::from_str(v.as_ref().map_or("", |v| v.get()))?,
|
||||
}),
|
||||
EventPayload::Delete(ref v) => Ok(JsonEventPayload::Delete {
|
||||
value: serde_json::from_str(v.as_ref().map_or("", |v| v.get()))?,
|
||||
}),
|
||||
EventPayload::Error(ref err) => Ok(JsonEventPayload::Error { error: err.clone() }),
|
||||
EventPayload::Ping => Ok(JsonEventPayload::Ping),
|
||||
};
|
||||
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||
pub enum TestJsonEventPayload {
|
||||
Update(JsonObject),
|
||||
Insert(JsonObject),
|
||||
Delete(JsonObject),
|
||||
Error {
|
||||
status: EventErrorStatus,
|
||||
message: Option<String>,
|
||||
},
|
||||
Ping,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||
pub struct TestChangeEvent {
|
||||
#[serde(flatten)]
|
||||
pub event: TestJsonEventPayload,
|
||||
pub seq: Option<i64>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -139,66 +152,50 @@ mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
// #[test]
|
||||
// fn serialization_test() {
|
||||
// let payload0 = EventPayload::from(&JsonEventPayload::Insert {
|
||||
// value: JsonObject::from_iter([("key0".to_string(), json!("value0"))]),
|
||||
// });
|
||||
//
|
||||
// let ev0 = ChangeEvent {
|
||||
// payload: Arc::new(payload0.clone()),
|
||||
// seq: None,
|
||||
// };
|
||||
//
|
||||
// let expected0 = r#"{"type":"insert","value":{"key0":"value0"}}"#;
|
||||
// let ev0str = serde_json::to_string(&ev0).unwrap();
|
||||
// assert_eq!(expected0, ev0str);
|
||||
// assert_eq!(
|
||||
// expected0,
|
||||
// serde_json::to_string(&ChangeEvent {
|
||||
// payload: Arc::new(payload0.clone()),
|
||||
// seq: None,
|
||||
// })
|
||||
// .unwrap()
|
||||
// );
|
||||
//
|
||||
// let ev0deserialized: EventPayload = serde_json::from_str(&expected0).unwrap();
|
||||
// assert_eq!(payload0, ev0deserialized);
|
||||
//
|
||||
// let payload1 = EventPayload::from(&JsonEventPayload::Error {
|
||||
// error: "boom".to_string(),
|
||||
// });
|
||||
// let ev1 = ChangeEvent {
|
||||
// payload: Arc::new(payload1),
|
||||
// seq: Some(11),
|
||||
// };
|
||||
//
|
||||
// let expected1 = r#"{"type":"error","error":"boom","seq":11}"#;
|
||||
// let ev1str = serde_json::to_string(&ev1).unwrap();
|
||||
// assert_eq!(expected1, ev1str);
|
||||
// }
|
||||
|
||||
#[test]
|
||||
fn serialization_foo_test() {
|
||||
let event = ChangeEvent {
|
||||
event: Arc::new(EventPayload::Delete(Some(
|
||||
serde_json::value::to_raw_value(&json!({
|
||||
"foo": 4,
|
||||
}))
|
||||
.unwrap(),
|
||||
))),
|
||||
seq: Some(4),
|
||||
};
|
||||
{
|
||||
let event = ChangeEvent {
|
||||
event: Arc::new(EventPayload::from(&JsonEventPayload::Delete {
|
||||
value: JsonObject::from_iter([("foo".to_string(), json!(4))]),
|
||||
})),
|
||||
seq: Some(4),
|
||||
};
|
||||
|
||||
let value = serde_json::to_value(&event).unwrap();
|
||||
assert_eq!(
|
||||
serde_json::json!({
|
||||
"Delete": {
|
||||
"foo": 4,
|
||||
let value = serde_json::to_value(&event).unwrap();
|
||||
assert_eq!(
|
||||
serde_json::json!({
|
||||
"Delete": {
|
||||
"foo": 4,
|
||||
},
|
||||
"seq": 4,
|
||||
}),
|
||||
value
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let event = ChangeEvent {
|
||||
event: Arc::new(EventPayload::from(&JsonEventPayload::Error {
|
||||
value: EventError {
|
||||
status: EventErrorStatus::Loss,
|
||||
message: Some("test".to_string()),
|
||||
},
|
||||
"seq": 4,
|
||||
}),
|
||||
value
|
||||
);
|
||||
})),
|
||||
seq: Some(4),
|
||||
};
|
||||
|
||||
let value = serde_json::to_value(&event).unwrap();
|
||||
assert_eq!(
|
||||
serde_json::json!({
|
||||
"Error": {
|
||||
"status": EventErrorStatus::Loss,
|
||||
"message": "test",
|
||||
},
|
||||
"seq": 4,
|
||||
}),
|
||||
value
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,22 @@
|
||||
use axum::extract::{Path, RawQuery, Request, State};
|
||||
use axum::response::sse::{KeepAlive, Sse};
|
||||
use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use futures_util::StreamExt;
|
||||
use futures_util::stream;
|
||||
use serde::Deserialize;
|
||||
use std::sync::atomic::{AtomicI64, Ordering};
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use trailbase_qs::ValueOrComposite;
|
||||
use ts_rs::TS;
|
||||
|
||||
use crate::app_state::AppState;
|
||||
use crate::auth::User;
|
||||
use crate::records::RecordApi;
|
||||
use crate::records::filter::{Filter, apply_filter_recursively_to_record};
|
||||
use crate::records::subscribe::event::{
|
||||
EventError, EventErrorStatus, EventPayload, JsonEventPayload,
|
||||
};
|
||||
use crate::records::subscribe::state::{EventCandidate, Subscription};
|
||||
use crate::records::{Permission, RecordError};
|
||||
|
||||
#[derive(Clone, Default, Debug, PartialEq, Deserialize)]
|
||||
@@ -84,6 +91,48 @@ pub async fn add_subscription_sse_and_ws_handler(
|
||||
};
|
||||
}
|
||||
|
||||
struct ValidateEventArgs {
|
||||
state: AppState,
|
||||
// FIXME: We could probably do with a subset of information from the `Subscription` and keep it
|
||||
// internal.
|
||||
subscription: Arc<Subscription>,
|
||||
expected_candidate_seq: AtomicI64,
|
||||
}
|
||||
|
||||
async fn validate_event(
|
||||
args: Arc<ValidateEventArgs>,
|
||||
ev: EventCandidate,
|
||||
) -> Result<Option<Arc<EventPayload>>, RecordError> {
|
||||
if ev.seq != args.expected_candidate_seq.fetch_add(1, Ordering::SeqCst) {
|
||||
args.expected_candidate_seq.store(ev.seq, Ordering::SeqCst);
|
||||
return Ok(Some(EVENT_LOSS_EVENT.clone()));
|
||||
}
|
||||
|
||||
let Some(ref record) = ev.record else {
|
||||
// Established events.
|
||||
return Ok(Some(ev.payload));
|
||||
};
|
||||
|
||||
let sub: &Subscription = &args.subscription;
|
||||
if let Filter::Record(ref filter) = sub.filter
|
||||
&& !apply_filter_recursively_to_record(filter, record)
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// We don't memoize and eagerly look up the APIs to make sure we get an up-to-date
|
||||
// version.
|
||||
let Some(api) = args.state.lookup_record_api(&sub.record_api_name) else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
api
|
||||
.check_record_level_read_access_for_subscriptions(api.conn(), record, sub.user.as_ref())
|
||||
.await?;
|
||||
|
||||
return Ok(Some(ev.payload));
|
||||
}
|
||||
|
||||
pub async fn subscribe_sse(
|
||||
state: AppState,
|
||||
api: RecordApi,
|
||||
@@ -91,21 +140,35 @@ pub async fn subscribe_sse(
|
||||
filter: Option<ValueOrComposite>,
|
||||
user: Option<User>,
|
||||
) -> Result<Response, RecordError> {
|
||||
let seq = Arc::new(AtomicI64::default());
|
||||
|
||||
return match record.as_str() {
|
||||
"*" => {
|
||||
api.check_table_level_access(Permission::Read, user.as_ref())?;
|
||||
|
||||
let receiver = state
|
||||
let (receiver, subscription) = state
|
||||
.subscription_manager()
|
||||
.add_sse_table_subscription(api, user, filter)
|
||||
.await?;
|
||||
|
||||
let seq = AtomicI64::default();
|
||||
let args = Arc::new(ValidateEventArgs {
|
||||
state,
|
||||
subscription,
|
||||
expected_candidate_seq: AtomicI64::default(),
|
||||
});
|
||||
|
||||
Ok(
|
||||
Sse::new(
|
||||
receiver.map(move |ev| ev.into_sse_event(Some(seq.fetch_add(1, Ordering::SeqCst)))),
|
||||
)
|
||||
Sse::new(receiver.filter_map(move |ev: EventCandidate| {
|
||||
let seq = seq.clone();
|
||||
let args = args.clone();
|
||||
|
||||
return async move {
|
||||
validate_event(args.clone(), ev)
|
||||
.await
|
||||
.unwrap_or_default()
|
||||
.map(|ev| ev.into_sse_event(Some(seq.fetch_add(1, Ordering::SeqCst))))
|
||||
};
|
||||
}))
|
||||
.keep_alive(KeepAlive::default())
|
||||
.into_response(),
|
||||
)
|
||||
@@ -116,16 +179,48 @@ pub async fn subscribe_sse(
|
||||
.check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref())
|
||||
.await?;
|
||||
|
||||
let receiver = state
|
||||
let (receiver, subscription) = state
|
||||
.subscription_manager()
|
||||
.add_sse_record_subscription(api, record_id, user)
|
||||
.await?;
|
||||
|
||||
let seq = AtomicI64::default();
|
||||
let args = Arc::new(ValidateEventArgs {
|
||||
state,
|
||||
subscription,
|
||||
expected_candidate_seq: AtomicI64::default(),
|
||||
});
|
||||
|
||||
Ok(
|
||||
Sse::new(
|
||||
receiver.map(move |ev| ev.into_sse_event(Some(seq.fetch_add(1, Ordering::SeqCst)))),
|
||||
receiver
|
||||
.then(move |ev: EventCandidate| {
|
||||
let seq = seq.clone();
|
||||
let args = args.clone();
|
||||
|
||||
return async move {
|
||||
match validate_event(args.clone(), ev).await {
|
||||
Ok(None) => stream::empty().boxed(),
|
||||
Ok(Some(ev)) => stream::once(std::future::ready(
|
||||
ev.into_sse_event(Some(seq.fetch_add(1, Ordering::SeqCst))),
|
||||
))
|
||||
.boxed(),
|
||||
Err(_) => {
|
||||
// Death sentence for record subscriptions to not have access
|
||||
stream::iter(vec![
|
||||
// First send an error event to the user.
|
||||
ACCESS_DENIED_EVENT
|
||||
.clone()
|
||||
.into_sse_event(Some(seq.fetch_add(1, Ordering::SeqCst))),
|
||||
// Then terminate the stream via the `take_while` below.
|
||||
Err(RecordError::Forbidden),
|
||||
])
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
};
|
||||
})
|
||||
.flatten()
|
||||
.take_while(|event: &Result<SseEvent, RecordError>| std::future::ready(event.is_ok())),
|
||||
)
|
||||
.keep_alive(KeepAlive::default())
|
||||
.into_response(),
|
||||
@@ -157,6 +252,7 @@ pub async fn subscribe_ws(
|
||||
|
||||
use crate::records::subscribe::event::EventPayload;
|
||||
use crate::records::subscribe::state::AutoCleanupEventStream;
|
||||
use crate::records::subscribe::state::EventCandidate;
|
||||
|
||||
let (mut parts, _body) = request.into_parts();
|
||||
let ws = match WebSocketUpgrade::from_request_parts(&mut parts, &state).await {
|
||||
@@ -197,14 +293,42 @@ pub async fn subscribe_ws(
|
||||
}
|
||||
|
||||
async fn broker<S: SinkExt<Message> + std::marker::Unpin>(
|
||||
state: AppState,
|
||||
subscription: Arc<Subscription>,
|
||||
// Receive events from SQLite
|
||||
receiver: AutoCleanupEventStream,
|
||||
// Send messages via WebSocket.
|
||||
sender: &mut S,
|
||||
is_record_subscription: bool,
|
||||
) {
|
||||
let args = Arc::new(ValidateEventArgs {
|
||||
state,
|
||||
subscription,
|
||||
expected_candidate_seq: AtomicI64::default(),
|
||||
});
|
||||
|
||||
let mut pinned_receiver = std::pin::pin!(receiver);
|
||||
while let Some(ev) = pinned_receiver.next().await {
|
||||
match ev.into_ws_event() {
|
||||
let payload = match validate_event(args.clone(), ev).await {
|
||||
Ok(Some(payload)) => payload,
|
||||
Ok(None) => {
|
||||
continue;
|
||||
}
|
||||
Err(_) => {
|
||||
if is_record_subscription {
|
||||
// Death sentence for record subscriptions to not have access
|
||||
let _ = ACCESS_DENIED_EVENT
|
||||
.clone()
|
||||
.into_ws_event()
|
||||
.map(|ev| sender.send(ev));
|
||||
return;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match payload.into_ws_event() {
|
||||
Ok(msg) => {
|
||||
if let Err(_value) = sender.send(msg).await {
|
||||
log::debug!("Sending WS event to client failed");
|
||||
@@ -271,6 +395,8 @@ pub async fn subscribe_ws(
|
||||
return match record.as_str() {
|
||||
"*" => {
|
||||
Ok(ws.on_upgrade(async move |socket: WebSocket| {
|
||||
use crate::records::subscribe::state::EventCandidate;
|
||||
|
||||
let Some(mut ws_sender) = init(&state, socket, &mut user).await else {
|
||||
return;
|
||||
};
|
||||
@@ -283,10 +409,10 @@ pub async fn subscribe_ws(
|
||||
return;
|
||||
}
|
||||
|
||||
let (sender, receiver) = async_channel::bounded::<Arc<EventPayload>>(16);
|
||||
let state = state.subscription_manager().get_per_connection_state(&api);
|
||||
let (sender, receiver) = async_channel::bounded::<EventCandidate>(64);
|
||||
let conn_state = state.subscription_manager().get_per_connection_state(&api);
|
||||
|
||||
let Ok(id) = state
|
||||
let Ok(subscription) = conn_state
|
||||
.clone()
|
||||
.add_table_subscription(api, user, filter, sender)
|
||||
.await
|
||||
@@ -295,15 +421,17 @@ pub async fn subscribe_ws(
|
||||
return;
|
||||
};
|
||||
|
||||
let receiver = AutoCleanupEventStream::new(receiver, state, id);
|
||||
let receiver = AutoCleanupEventStream::new(receiver, conn_state, subscription.id.clone());
|
||||
|
||||
broker(receiver, &mut ws_sender).await
|
||||
broker(state, subscription, receiver, &mut ws_sender, false).await
|
||||
}))
|
||||
}
|
||||
_ => {
|
||||
let record_id = api.primary_key_to_value(record)?;
|
||||
|
||||
Ok(ws.on_upgrade(async move |socket: WebSocket| {
|
||||
use crate::records::subscribe::state::EventCandidate;
|
||||
|
||||
let Some(mut ws_sender) = init(&state, socket, &mut user).await else {
|
||||
return;
|
||||
};
|
||||
@@ -319,10 +447,10 @@ pub async fn subscribe_ws(
|
||||
return;
|
||||
}
|
||||
|
||||
let (sender, receiver) = async_channel::bounded::<Arc<EventPayload>>(16);
|
||||
let state = state.subscription_manager().get_per_connection_state(&api);
|
||||
let (sender, receiver) = async_channel::bounded::<EventCandidate>(64);
|
||||
let conn_state = state.subscription_manager().get_per_connection_state(&api);
|
||||
|
||||
let Ok(id) = state
|
||||
let Ok(subscription) = conn_state
|
||||
.clone()
|
||||
.add_record_subscription(api, record_id, user, sender)
|
||||
.await
|
||||
@@ -331,10 +459,37 @@ pub async fn subscribe_ws(
|
||||
return;
|
||||
};
|
||||
|
||||
let receiver = AutoCleanupEventStream::new(receiver, state, id);
|
||||
let receiver = AutoCleanupEventStream::new(receiver, conn_state, subscription.id.clone());
|
||||
|
||||
broker(receiver, &mut ws_sender).await;
|
||||
broker(state, subscription, receiver, &mut ws_sender, true).await;
|
||||
}))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
static ACCESS_DENIED_EVENT: LazyLock<Arc<EventPayload>> = LazyLock::new(|| {
|
||||
Arc::new(EventPayload::from(&JsonEventPayload::Error {
|
||||
value: EventError {
|
||||
status: EventErrorStatus::Forbidden,
|
||||
message: Some("Access denied".into()),
|
||||
},
|
||||
}))
|
||||
});
|
||||
static EVENT_LOSS_EVENT: LazyLock<Arc<EventPayload>> = LazyLock::new(|| {
|
||||
Arc::new(EventPayload::from(&JsonEventPayload::Error {
|
||||
value: EventError {
|
||||
status: EventErrorStatus::Loss,
|
||||
message: None,
|
||||
},
|
||||
}))
|
||||
});
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn static_sse_event_test() {
|
||||
let _x: Arc<EventPayload> = (*ACCESS_DENIED_EVENT).clone();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,13 +23,14 @@ pub struct PreupdateHookEvent {
|
||||
pub record: Vec<Value>,
|
||||
}
|
||||
|
||||
pub fn install_hook(conn: &Connection) -> kanal::Receiver<PreupdateHookEvent> {
|
||||
pub fn install_hook(conn: &Connection) -> kanal::Receiver<(usize, PreupdateHookEvent)> {
|
||||
let (sender, receiver) = kanal::bounded(CAPACITY);
|
||||
|
||||
conn
|
||||
.write_lock()
|
||||
.preupdate_hook({
|
||||
let conn = conn.clone();
|
||||
let mut cnt = 0;
|
||||
|
||||
Some(
|
||||
move |action: Action, db: &str, table_name: &str, case: &PreUpdateCase| {
|
||||
@@ -69,7 +70,9 @@ pub fn install_hook(conn: &Connection) -> kanal::Receiver<PreupdateHookEvent> {
|
||||
record,
|
||||
};
|
||||
|
||||
match sender.try_send(event) {
|
||||
cnt += 1;
|
||||
|
||||
match sender.try_send((cnt, event)) {
|
||||
Ok(true) => {}
|
||||
Ok(false) => {
|
||||
warn!("Channel full. Failed to forward preupdate event.")
|
||||
@@ -129,11 +132,13 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let ev0 = receiver.next().unwrap();
|
||||
let (cnt, ev0) = receiver.next().unwrap();
|
||||
assert_eq!(1, cnt);
|
||||
assert_eq!("\"test\"", ev0.table_name.escaped_string());
|
||||
assert_eq!(Value::Integer(3), ev0.record[0]);
|
||||
|
||||
let ev1 = receiver.next().unwrap();
|
||||
let (cnt, ev1) = receiver.next().unwrap();
|
||||
assert_eq!(2, cnt);
|
||||
assert_eq!(Value::Integer(4), ev1.record[0]);
|
||||
|
||||
uninstall_hook(&conn);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use parking_lot::RwLock;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use reactivate::Reactive;
|
||||
use std::collections::{HashMap, hash_map::Entry};
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use trailbase_qs::ValueOrComposite;
|
||||
|
||||
@@ -9,7 +10,10 @@ use crate::auth::User;
|
||||
use crate::records::RecordApi;
|
||||
use crate::records::RecordError;
|
||||
use crate::records::subscribe::event::{EventPayload, JsonEventPayload};
|
||||
use crate::records::subscribe::state::{AutoCleanupEventStream, PerConnectionState};
|
||||
use crate::records::subscribe::state::{
|
||||
AutoCleanupEventStream, EventCandidate, PerConnectionState, PerConnectionStateInternal,
|
||||
Subscription,
|
||||
};
|
||||
|
||||
/// Internal, shareable state of the cloneable SubscriptionManager.
|
||||
struct ManagerState {
|
||||
@@ -47,17 +51,21 @@ impl SubscriptionManager {
|
||||
|
||||
let id = api.conn().id();
|
||||
|
||||
// TODO: Clean subscriptions from existing entries for tables that not longer have a
|
||||
// corresponding API.
|
||||
// TODO: Skip/cleanup subscriptions from existing entries for tables that not longer have
|
||||
// a corresponding API.
|
||||
if let Some(existing) = old.remove(&id) {
|
||||
let apis = filter_record_apis(id, record_apis);
|
||||
let Some(first) = apis.values().nth(0) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
// Update metadata and add back.
|
||||
*existing.connection_metadata.write() = first.connection_metadata().clone();
|
||||
*existing.record_apis.write() = apis;
|
||||
{
|
||||
let mut state = existing.state.lock();
|
||||
// Update metadata and add back.
|
||||
state.connection_metadata = first.connection_metadata().clone();
|
||||
state.record_apis = apis;
|
||||
}
|
||||
|
||||
lock.insert(id, existing);
|
||||
}
|
||||
}
|
||||
@@ -72,23 +80,32 @@ impl SubscriptionManager {
|
||||
api: RecordApi,
|
||||
user: Option<User>,
|
||||
filter: Option<ValueOrComposite>,
|
||||
) -> Result<AutoCleanupEventStream, RecordError> {
|
||||
let (sender, receiver) = async_channel::bounded::<Arc<EventPayload>>(16);
|
||||
) -> Result<(AutoCleanupEventStream, Arc<Subscription>), RecordError> {
|
||||
let (sender, receiver) = async_channel::bounded::<EventCandidate>(64);
|
||||
let state = self.get_per_connection_state(&api);
|
||||
|
||||
let id = state
|
||||
let subscription = state
|
||||
.clone()
|
||||
.add_table_subscription(api, user, filter, sender.clone())
|
||||
.await?;
|
||||
|
||||
// Send an immediate comment to flush SSE headers and establish the connection
|
||||
if sender.send(ESTABLISHED_EVENT.clone()).await.is_err() {
|
||||
if sender
|
||||
.send(EventCandidate {
|
||||
record: None,
|
||||
payload: ESTABLISHED_EVENT.clone(),
|
||||
seq: subscription.candidate_seq.fetch_add(1, Ordering::SeqCst),
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return Err(RecordError::BadRequest("channel already closed"));
|
||||
}
|
||||
|
||||
let receiver = AutoCleanupEventStream::new(receiver, state, id);
|
||||
|
||||
return Ok(receiver);
|
||||
return Ok((
|
||||
AutoCleanupEventStream::new(receiver, state, subscription.id.clone()),
|
||||
subscription,
|
||||
));
|
||||
}
|
||||
|
||||
pub async fn add_sse_record_subscription(
|
||||
@@ -96,23 +113,32 @@ impl SubscriptionManager {
|
||||
api: RecordApi,
|
||||
record: trailbase_sqlite::Value,
|
||||
user: Option<User>,
|
||||
) -> Result<AutoCleanupEventStream, RecordError> {
|
||||
let (sender, receiver) = async_channel::bounded::<Arc<EventPayload>>(16);
|
||||
) -> Result<(AutoCleanupEventStream, Arc<Subscription>), RecordError> {
|
||||
let (sender, receiver) = async_channel::bounded::<EventCandidate>(64);
|
||||
let state = self.get_per_connection_state(&api);
|
||||
|
||||
let id = state
|
||||
let subscription = state
|
||||
.clone()
|
||||
.add_record_subscription(api, record, user, sender.clone())
|
||||
.await?;
|
||||
|
||||
// Send an immediate comment to flush SSE headers and establish the connection
|
||||
if sender.send(ESTABLISHED_EVENT.clone()).await.is_err() {
|
||||
if sender
|
||||
.send(EventCandidate {
|
||||
record: None,
|
||||
payload: ESTABLISHED_EVENT.clone(),
|
||||
seq: subscription.candidate_seq.fetch_add(1, Ordering::SeqCst),
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return Err(RecordError::BadRequest("channel already closed"));
|
||||
}
|
||||
|
||||
let receiver = AutoCleanupEventStream::new(receiver, state, id);
|
||||
|
||||
return Ok(receiver);
|
||||
return Ok((
|
||||
AutoCleanupEventStream::new(receiver, state, subscription.id.clone()),
|
||||
subscription,
|
||||
));
|
||||
}
|
||||
|
||||
pub fn get_per_connection_state(&self, api: &RecordApi) -> Arc<PerConnectionState> {
|
||||
@@ -127,9 +153,12 @@ impl SubscriptionManager {
|
||||
Entry::Occupied(v) => v.get().clone(),
|
||||
Entry::Vacant(v) => {
|
||||
let state = Arc::new(PerConnectionState {
|
||||
connection_metadata: RwLock::new(api.connection_metadata().clone()),
|
||||
record_apis: RwLock::new(filter_record_apis(id, &self.state.record_apis.value())),
|
||||
subscriptions: Default::default(),
|
||||
state: Mutex::new(PerConnectionStateInternal {
|
||||
connection_metadata: api.connection_metadata().clone(),
|
||||
record_apis: filter_record_apis(id, &self.state.record_apis.value()),
|
||||
conn: api.conn().clone(),
|
||||
subscriptions: Default::default(),
|
||||
}),
|
||||
});
|
||||
v.insert(state).clone()
|
||||
}
|
||||
@@ -141,8 +170,8 @@ impl SubscriptionManager {
|
||||
pub fn num_record_subscriptions(&self) -> usize {
|
||||
let mut count: usize = 0;
|
||||
for state in self.state.connections.read().values() {
|
||||
for (_table_name, subs) in state.subscriptions.read().iter() {
|
||||
for record in subs.read().record.values() {
|
||||
for (_table_name, subs) in state.state.lock().subscriptions.iter() {
|
||||
for record in subs.record.values() {
|
||||
count += record.len();
|
||||
}
|
||||
}
|
||||
@@ -154,8 +183,8 @@ impl SubscriptionManager {
|
||||
pub fn num_table_subscriptions(&self) -> usize {
|
||||
let mut count: usize = 0;
|
||||
for state in self.state.connections.read().values() {
|
||||
for (_table_name, subs) in state.subscriptions.read().iter() {
|
||||
count += subs.read().table.len();
|
||||
for (_table_name, subs) in state.state.lock().subscriptions.iter() {
|
||||
count += subs.table.len();
|
||||
}
|
||||
}
|
||||
return count;
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
use async_channel::{TrySendError, WeakReceiver};
|
||||
use async_channel::WeakReceiver;
|
||||
use futures_util::Stream;
|
||||
use log::*;
|
||||
use parking_lot::RwLock;
|
||||
use parking_lot::Mutex;
|
||||
use pin_project_lite::pin_project;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicI64, Ordering};
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::task::{Context, Poll};
|
||||
use trailbase_qs::ValueOrComposite;
|
||||
use trailbase_schema::QualifiedName;
|
||||
@@ -15,20 +15,17 @@ use trailbase_schema::json::value_to_flat_json;
|
||||
use crate::auth::User;
|
||||
use crate::records::RecordApi;
|
||||
use crate::records::RecordError;
|
||||
use crate::records::filter::{
|
||||
Filter, apply_filter_recursively_to_record, qs_filter_to_record_filter,
|
||||
};
|
||||
use crate::records::record_api::SubscriptionAclParams;
|
||||
use crate::records::filter::{Filter, qs_filter_to_record_filter};
|
||||
use crate::records::subscribe::event::{EventPayload, JsonEventPayload};
|
||||
use crate::records::subscribe::hook::{
|
||||
PreupdateHookEvent, RecordAction, install_hook, uninstall_hook, uninstall_hook_rusqlite,
|
||||
PreupdateHookEvent, RecordAction, install_hook, uninstall_hook,
|
||||
};
|
||||
use crate::schema_metadata::ConnectionMetadata;
|
||||
|
||||
/// Composite id uniquely identifying a subscription.
|
||||
///
|
||||
/// If row_id is Some, this is considered to reference a subscription to a specific record.
|
||||
#[derive(Default, Debug)]
|
||||
#[derive(Clone, Default, Debug)]
|
||||
pub struct SubscriptionId {
|
||||
pub table_name: QualifiedName,
|
||||
pub sub_id: i64,
|
||||
@@ -38,25 +35,24 @@ pub struct SubscriptionId {
|
||||
/// RAII type for automatically cleaning up subscriptions when the receiving side gets dropped,
|
||||
/// e.g. client disconnects.
|
||||
struct AutoCleanupEventStreamState {
|
||||
receiver: WeakReceiver<Arc<EventPayload>>,
|
||||
state: Arc<PerConnectionState>,
|
||||
receiver: WeakReceiver<EventCandidate>,
|
||||
state: Weak<PerConnectionState>,
|
||||
id: SubscriptionId,
|
||||
}
|
||||
|
||||
impl Drop for AutoCleanupEventStreamState {
|
||||
fn drop(&mut self) {
|
||||
// Subscriptions can be cleaned up either by the sender, i.e. when trying to broker events and
|
||||
// tables or records get deleted, or by the client-receiver, e.g. by disconnecting. In the
|
||||
// latter case, we need to clean up the subscription.
|
||||
// tables or records get deleted, or by the client-receiver, e.g. by disconnecting.
|
||||
// When dropped by the client-side, we need to clean up the subscription.
|
||||
if self.receiver.upgrade().is_some() {
|
||||
let id = std::mem::take(&mut self.id);
|
||||
let state = self.state.clone();
|
||||
|
||||
if let Some(first) = self.state.record_apis.read().values().nth(0) {
|
||||
first.conn().call_and_forget(move |conn| {
|
||||
state.remove_subscription(conn, id);
|
||||
});
|
||||
}
|
||||
let Some(state) = self.state.upgrade() else {
|
||||
return;
|
||||
};
|
||||
|
||||
state.state.lock().remove_subscription2(id);
|
||||
} else {
|
||||
debug!("Subscription cleaned up already by the sender side.");
|
||||
}
|
||||
@@ -70,19 +66,19 @@ pin_project! {
|
||||
state: AutoCleanupEventStreamState,
|
||||
|
||||
#[pin]
|
||||
pub receiver: async_channel::Receiver<Arc<EventPayload>>,
|
||||
pub receiver: async_channel::Receiver<EventCandidate>,
|
||||
}
|
||||
}
|
||||
impl AutoCleanupEventStream {
|
||||
pub fn new(
|
||||
receiver: async_channel::Receiver<Arc<EventPayload>>,
|
||||
receiver: async_channel::Receiver<EventCandidate>,
|
||||
state: Arc<PerConnectionState>,
|
||||
id: SubscriptionId,
|
||||
) -> Self {
|
||||
return Self {
|
||||
state: AutoCleanupEventStreamState {
|
||||
receiver: receiver.downgrade(),
|
||||
state,
|
||||
state: Arc::downgrade(&state),
|
||||
id,
|
||||
},
|
||||
receiver,
|
||||
@@ -91,7 +87,7 @@ impl AutoCleanupEventStream {
|
||||
}
|
||||
|
||||
impl Stream for AutoCleanupEventStream {
|
||||
type Item = Arc<EventPayload>;
|
||||
type Item = EventCandidate;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let mut this = self.project();
|
||||
@@ -104,27 +100,38 @@ impl Stream for AutoCleanupEventStream {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Subscription {
|
||||
/// Id uniquely identifying this subscription.
|
||||
subscription_id: i64,
|
||||
pub id: SubscriptionId,
|
||||
/// Name of the API this subscription is subscribed to. We need to lookup the Record API on the
|
||||
/// hot path to make sure we're getting the latest configuration.
|
||||
record_api_name: String,
|
||||
pub record_api_name: String,
|
||||
/// User associated with subscriber.
|
||||
user: Option<User>,
|
||||
pub user: Option<User>,
|
||||
/// Record filter.
|
||||
pub filter: Filter,
|
||||
/// Channel for sending events to the SSE handler.
|
||||
sender: async_channel::Sender<Arc<EventPayload>>,
|
||||
/// Filter
|
||||
filter: Filter,
|
||||
pub sender: async_channel::Sender<EventCandidate>,
|
||||
|
||||
pub candidate_seq: AtomicI64,
|
||||
}
|
||||
|
||||
// Represents a change event that needs further filtering, e.g. ACLs.
|
||||
#[derive(Debug)]
|
||||
pub struct EventCandidate {
|
||||
pub record: Option<Arc<indexmap::IndexMap<String, rusqlite::types::Value>>>,
|
||||
pub payload: Arc<EventPayload>,
|
||||
pub seq: i64,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Subscriptions {
|
||||
/// A list of table subscriptions for this table.
|
||||
pub table: Vec<Subscription>,
|
||||
pub table: Vec<Arc<Subscription>>,
|
||||
|
||||
/// A map of record subscriptions for this.
|
||||
pub record: HashMap<i64, Vec<Subscription>>,
|
||||
pub record: HashMap<i64, Vec<Arc<Subscription>>>,
|
||||
}
|
||||
|
||||
impl Subscriptions {
|
||||
@@ -133,96 +140,104 @@ impl Subscriptions {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PerConnectionState {
|
||||
pub struct PerConnectionStateInternal {
|
||||
/// Metadata: always updated together when config -> record APIs change.
|
||||
pub record_apis: RwLock<HashMap<String, RecordApi>>,
|
||||
pub record_apis: HashMap<String, RecordApi>,
|
||||
|
||||
/// Denormalized metadata. We could also grab this from:
|
||||
/// `record_apis.read().nth(0).unwrap().connection_metadata()`.
|
||||
pub connection_metadata: RwLock<Arc<ConnectionMetadata>>,
|
||||
pub connection_metadata: Arc<ConnectionMetadata>,
|
||||
|
||||
/// Should be the same as for all `record_apis` above.
|
||||
pub conn: Arc<trailbase_sqlite::Connection>,
|
||||
|
||||
/// Map from table name to row id to list of subscriptions.
|
||||
///
|
||||
/// NOTE: Use layered locking to allow cleaning up per-table subscriptions w/o having to
|
||||
/// exclusively lock the entire map.
|
||||
pub subscriptions: RwLock<HashMap</* table_name= */ QualifiedName, RwLock<Subscriptions>>>,
|
||||
pub subscriptions: HashMap</* table_name= */ QualifiedName, Subscriptions>,
|
||||
}
|
||||
|
||||
impl PerConnectionStateInternal {
|
||||
pub fn remove_subscription2(&mut self, id: SubscriptionId) {
|
||||
let Some(subscriptions) = self.subscriptions.get_mut(&id.table_name) else {
|
||||
return;
|
||||
};
|
||||
|
||||
if let Some(row_id) = id.row_id {
|
||||
if let Some(record_subscriptions) = subscriptions.record.get_mut(&row_id) {
|
||||
record_subscriptions.retain(|sub| {
|
||||
return sub.id.sub_id != id.sub_id;
|
||||
});
|
||||
|
||||
if record_subscriptions.is_empty() {
|
||||
subscriptions.record.remove(&row_id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
subscriptions.table.retain(|sub| {
|
||||
return sub.id.sub_id != id.sub_id;
|
||||
});
|
||||
}
|
||||
|
||||
if subscriptions.is_empty() {
|
||||
self.subscriptions.remove(&id.table_name);
|
||||
if self.subscriptions.is_empty() {
|
||||
uninstall_hook(&self.conn);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PerConnectionState {
|
||||
pub state: Mutex<PerConnectionStateInternal>,
|
||||
}
|
||||
|
||||
impl PerConnectionState {
|
||||
fn lookup_record_api(&self, name: &str) -> Option<RecordApi> {
|
||||
return self.record_apis.read().get(name).cloned();
|
||||
}
|
||||
|
||||
// Gets called by the Stream destructor, e.g. when a client disconnects.
|
||||
fn remove_subscription(&self, conn: &rusqlite::Connection, id: SubscriptionId) {
|
||||
let mut read_lock = self.subscriptions.upgradable_read();
|
||||
|
||||
let remove_subscription_entry_for_table = {
|
||||
let Some(mut subscriptions) = read_lock.get(&id.table_name).map(|l| l.write()) else {
|
||||
return;
|
||||
};
|
||||
|
||||
if let Some(row_id) = id.row_id {
|
||||
if let Some(record_subscriptions) = subscriptions.record.get_mut(&row_id) {
|
||||
record_subscriptions.retain(|sub| {
|
||||
return sub.subscription_id != id.sub_id;
|
||||
});
|
||||
|
||||
if record_subscriptions.is_empty() {
|
||||
subscriptions.record.remove(&row_id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
subscriptions.table.retain(|sub| {
|
||||
return sub.subscription_id != id.sub_id;
|
||||
});
|
||||
}
|
||||
|
||||
subscriptions.is_empty()
|
||||
};
|
||||
|
||||
if remove_subscription_entry_for_table {
|
||||
let table_name = &id.table_name;
|
||||
// NOTE: Only write lock across all tables when necessary.
|
||||
read_lock.with_upgraded(|lock| {
|
||||
// Check again to avoid races:
|
||||
if lock.get(table_name).is_some_and(|e| e.read().is_empty()) {
|
||||
lock.remove(table_name);
|
||||
|
||||
if lock.is_empty() {
|
||||
uninstall_hook_rusqlite(conn);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn add_hook(self: &Arc<Self>, api: RecordApi) {
|
||||
let conn = (**api.conn()).clone();
|
||||
let conn = api.conn().clone();
|
||||
let state = self.clone();
|
||||
|
||||
let receiver = install_hook(&conn).to_async();
|
||||
let receiver = install_hook(&conn);
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
if receiver.sender_count() == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
let event = match receiver.recv().await {
|
||||
Ok(event) => event,
|
||||
Err(kanal::ReceiveError::Closed) | Err(kanal::ReceiveError::SendClosed) => {
|
||||
// Spawn broker task.
|
||||
if let Err(err) = std::thread::Builder::new()
|
||||
.name("subscriptions".to_string())
|
||||
.spawn(move || {
|
||||
let mut expected = 1;
|
||||
loop {
|
||||
if receiver.sender_count() == 0 {
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let state = state.clone();
|
||||
conn.call_reader_and_forget(move |conn| {
|
||||
broker_event(conn, state, event);
|
||||
});
|
||||
}
|
||||
let event = match receiver.recv() {
|
||||
Ok((cnt, event)) => {
|
||||
if cnt != expected {
|
||||
// QUESTION: There's several ways we could deal with failure. We
|
||||
// probably shouldn't create back pressure on the preupdate_hook and gunk up the
|
||||
// SQLite access. We could try to deliver event loss messages to all receivers but
|
||||
// that may just make the problem worse. We're probably at limit already
|
||||
// if we don't manage to catch up. Should we just disconnect all subscriptions?
|
||||
state.state.lock().subscriptions.clear();
|
||||
break;
|
||||
}
|
||||
expected += 1;
|
||||
|
||||
debug!("Channel closed: terminating subscription broker task.");
|
||||
});
|
||||
event
|
||||
}
|
||||
Err(kanal::ReceiveError::Closed) | Err(kanal::ReceiveError::SendClosed) => {
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
broker(&conn, &state, event);
|
||||
}
|
||||
|
||||
debug!("Channel closed: terminating subscription broker task.");
|
||||
})
|
||||
{
|
||||
log::error!("Failed to start subscription broker: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_record_subscription(
|
||||
@@ -230,8 +245,8 @@ impl PerConnectionState {
|
||||
api: RecordApi,
|
||||
record: trailbase_sqlite::Value,
|
||||
user: Option<User>,
|
||||
sender: async_channel::Sender<Arc<EventPayload>>,
|
||||
) -> Result<SubscriptionId, RecordError> {
|
||||
sender: async_channel::Sender<EventCandidate>,
|
||||
) -> Result<Arc<Subscription>, RecordError> {
|
||||
let table_name = api.table_name();
|
||||
let pk_column = &api.record_pk_column().column.name;
|
||||
|
||||
@@ -248,25 +263,32 @@ impl PerConnectionState {
|
||||
};
|
||||
|
||||
let qualified_name = api.qualified_name();
|
||||
let subscription_id = SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst);
|
||||
let install_hook: bool = {
|
||||
let mut lock = self.subscriptions.write();
|
||||
let empty = lock.is_empty();
|
||||
let sender = sender.clone();
|
||||
let subscription_entry = Arc::new(Subscription {
|
||||
id: SubscriptionId {
|
||||
table_name: qualified_name.clone(),
|
||||
row_id: Some(row_id),
|
||||
sub_id: SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst),
|
||||
},
|
||||
record_api_name: api.api_name().to_string(),
|
||||
user,
|
||||
sender,
|
||||
filter: Filter::Passthrough,
|
||||
candidate_seq: AtomicI64::default(),
|
||||
});
|
||||
|
||||
let subscriptions = lock.entry(qualified_name.clone()).or_default();
|
||||
let install_hook: bool = {
|
||||
let mut lock = self.state.lock();
|
||||
let empty = lock.subscriptions.is_empty();
|
||||
|
||||
let subscriptions = lock
|
||||
.subscriptions
|
||||
.entry(qualified_name.clone())
|
||||
.or_default();
|
||||
subscriptions
|
||||
.write()
|
||||
.record
|
||||
.entry(row_id)
|
||||
.or_default()
|
||||
.push(Subscription {
|
||||
subscription_id,
|
||||
record_api_name: api.api_name().to_string(),
|
||||
user,
|
||||
sender,
|
||||
filter: Filter::Passthrough,
|
||||
});
|
||||
.push(subscription_entry.clone());
|
||||
|
||||
empty
|
||||
};
|
||||
@@ -275,11 +297,7 @@ impl PerConnectionState {
|
||||
self.add_hook(api.clone());
|
||||
}
|
||||
|
||||
return Ok(SubscriptionId {
|
||||
table_name: qualified_name.clone(),
|
||||
row_id: Some(row_id),
|
||||
sub_id: subscription_id,
|
||||
});
|
||||
return Ok(subscription_entry);
|
||||
}
|
||||
|
||||
pub async fn add_table_subscription(
|
||||
@@ -287,8 +305,8 @@ impl PerConnectionState {
|
||||
api: RecordApi,
|
||||
user: Option<User>,
|
||||
filter: Option<ValueOrComposite>,
|
||||
sender: async_channel::Sender<Arc<EventPayload>>,
|
||||
) -> Result<SubscriptionId, RecordError> {
|
||||
sender: async_channel::Sender<EventCandidate>,
|
||||
) -> Result<Arc<Subscription>, RecordError> {
|
||||
let filter = if let Some(filter) = filter {
|
||||
Filter::Record(qs_filter_to_record_filter(api.columns(), filter)?)
|
||||
} else {
|
||||
@@ -296,20 +314,29 @@ impl PerConnectionState {
|
||||
};
|
||||
|
||||
let qualified_name = api.qualified_name();
|
||||
let subscription_id = SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst);
|
||||
let install_hook: bool = {
|
||||
let mut lock = self.subscriptions.write();
|
||||
let empty = lock.is_empty();
|
||||
let sender = sender.clone();
|
||||
let subscription_entry = Arc::new(Subscription {
|
||||
id: SubscriptionId {
|
||||
table_name: qualified_name.clone(),
|
||||
row_id: None,
|
||||
sub_id: SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst),
|
||||
},
|
||||
record_api_name: api.api_name().to_string(),
|
||||
user,
|
||||
sender,
|
||||
filter,
|
||||
candidate_seq: AtomicI64::default(),
|
||||
});
|
||||
|
||||
let subscriptions = lock.entry(qualified_name.clone()).or_default();
|
||||
subscriptions.write().table.push(Subscription {
|
||||
subscription_id,
|
||||
record_api_name: api.api_name().to_string(),
|
||||
user,
|
||||
sender,
|
||||
filter,
|
||||
});
|
||||
let install_hook: bool = {
|
||||
let mut lock = self.state.lock();
|
||||
let empty = lock.subscriptions.is_empty();
|
||||
|
||||
let subscriptions = lock
|
||||
.subscriptions
|
||||
.entry(qualified_name.clone())
|
||||
.or_default();
|
||||
|
||||
subscriptions.table.push(subscription_entry.clone());
|
||||
|
||||
empty
|
||||
};
|
||||
@@ -318,94 +345,51 @@ impl PerConnectionState {
|
||||
self.add_hook(api.clone());
|
||||
}
|
||||
|
||||
return Ok(SubscriptionId {
|
||||
table_name: qualified_name.clone(),
|
||||
row_id: None,
|
||||
sub_id: subscription_id,
|
||||
});
|
||||
return Ok(subscription_entry);
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PerConnectionState {
|
||||
fn drop(&mut self) {
|
||||
if let Some(first) = self.record_apis.read().values().nth(0) {
|
||||
uninstall_hook(first.conn());
|
||||
}
|
||||
uninstall_hook(&self.state.lock().conn);
|
||||
}
|
||||
}
|
||||
|
||||
fn broker_subscriptions(
|
||||
s: &PerConnectionState,
|
||||
conn: &rusqlite::Connection,
|
||||
subs: &[Subscription],
|
||||
record_subscriptions: bool,
|
||||
record: &indexmap::IndexMap<&str, rusqlite::types::Value>,
|
||||
subs: &[Arc<Subscription>],
|
||||
record: &Arc<indexmap::IndexMap<String, rusqlite::types::Value>>,
|
||||
event: &Arc<EventPayload>,
|
||||
) -> Vec<usize> {
|
||||
let mut dead_subscriptions: Vec<usize> = vec![];
|
||||
|
||||
for (idx, sub) in subs.iter().enumerate() {
|
||||
// Skip events for records that are being filtered out anyway.
|
||||
if let Filter::Record(ref filter) = sub.filter
|
||||
&& !apply_filter_recursively_to_record(filter, record)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// We don't memoize and eagerly look up the APIs to make sure we get an up-to-date version.
|
||||
let Some(api) = s.lookup_record_api(&sub.record_api_name) else {
|
||||
dead_subscriptions.push(idx);
|
||||
sub.sender.close();
|
||||
continue;
|
||||
};
|
||||
|
||||
if let Err(_err) = api.check_record_level_read_access_for_subscriptions(
|
||||
conn,
|
||||
SubscriptionAclParams {
|
||||
params: record,
|
||||
user: sub.user.as_ref(),
|
||||
},
|
||||
) {
|
||||
// NOTE: that access failures for table subscriptions for specific records are simply ignored,
|
||||
// i.e. those events will just not be send. Other records in the table may pass the
|
||||
// check. For record subscriptions, however, missing access is a death sentence.
|
||||
if record_subscriptions {
|
||||
// This can happen if the record api configuration has changed since originally
|
||||
// subscribed. In this case we just send and error and cancel the subscription.
|
||||
match sub.sender.try_send(ACCESS_DENIED_EVENT.clone()) {
|
||||
Ok(_) | Err(TrySendError::Full(_)) => {
|
||||
sub.sender.close();
|
||||
return subs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, sub)| {
|
||||
// Cloning the event. It's important that we use a try_send here to not block other
|
||||
// subscriptions if a subscriber is slow and their channel fills up.
|
||||
if let Err(err) = sub.sender.try_send(EventCandidate {
|
||||
record: Some(record.clone()),
|
||||
payload: event.clone(),
|
||||
seq: sub.candidate_seq.fetch_add(1, Ordering::SeqCst),
|
||||
}) {
|
||||
match err {
|
||||
async_channel::TrySendError::Full(ev) => {
|
||||
debug!("Channel full, dropping event: {ev:?}");
|
||||
}
|
||||
async_channel::TrySendError::Closed(_ev) => {
|
||||
return Some(idx);
|
||||
}
|
||||
Err(TrySendError::Closed(_)) => {}
|
||||
};
|
||||
|
||||
dead_subscriptions.push(idx);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Cloning the event. It's important that we use a try_send here to not block other
|
||||
// subscriptions if a subscriber is slow and their channel fills up.
|
||||
if let Err(err) = sub.sender.try_send(event.clone()) {
|
||||
match err {
|
||||
async_channel::TrySendError::Full(ev) => {
|
||||
debug!("Channel full, dropping event: {ev:?}");
|
||||
}
|
||||
async_channel::TrySendError::Closed(_ev) => {
|
||||
dead_subscriptions.push(idx);
|
||||
sub.sender.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return dead_subscriptions;
|
||||
return None;
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
/// Broker event to various subscriptions.
|
||||
fn broker_event(
|
||||
conn: &rusqlite::Connection,
|
||||
state: Arc<PerConnectionState>,
|
||||
fn broker(
|
||||
conn: &trailbase_sqlite::Connection,
|
||||
state: &Arc<PerConnectionState>,
|
||||
event: PreupdateHookEvent,
|
||||
) {
|
||||
let PreupdateHookEvent {
|
||||
@@ -415,153 +399,75 @@ fn broker_event(
|
||||
record,
|
||||
} = event;
|
||||
|
||||
let mut per_table_subscriptions = state.subscriptions.upgradable_read();
|
||||
let mut state = state.state.lock();
|
||||
|
||||
// If table_metadata is missing, the config/schema must have changed, thus removing the
|
||||
// subscriptions.
|
||||
let connection_metadata_lock = state.connection_metadata.read();
|
||||
let Some(table_metadata) = connection_metadata_lock.get_table(&table_name) else {
|
||||
let connection_metadata = state.connection_metadata.clone();
|
||||
let Some(table_metadata) = connection_metadata.get_table(&table_name) else {
|
||||
warn!("Table {table_name:?} not found. Removing subscriptions");
|
||||
|
||||
per_table_subscriptions.with_upgraded(|lock| {
|
||||
lock.remove(&table_name);
|
||||
if lock.is_empty() {
|
||||
uninstall_hook_rusqlite(conn);
|
||||
}
|
||||
});
|
||||
|
||||
state.subscriptions.remove(&table_name);
|
||||
if state.subscriptions.is_empty() {
|
||||
uninstall_hook(conn);
|
||||
}
|
||||
return;
|
||||
};
|
||||
|
||||
let remove_subscription_entry_for_table = {
|
||||
// Check if there are any matching subscriptions and otherwise go back to listening.
|
||||
let Some(mut subscriptions) = per_table_subscriptions
|
||||
.get(&table_name)
|
||||
.map(|r| r.upgradable_read())
|
||||
else {
|
||||
return;
|
||||
};
|
||||
if subscriptions.table.is_empty() && !subscriptions.record.contains_key(&row_id) {
|
||||
return;
|
||||
}
|
||||
// Check if there are any matching subscriptions and otherwise go back to listening.
|
||||
let Some(subscriptions) = state.subscriptions.get_mut(&table_name) else {
|
||||
return;
|
||||
};
|
||||
if subscriptions.table.is_empty() && !subscriptions.record.contains_key(&row_id) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Join values with column names. We use a map rather than a Vec<(String, Value)> for filter
|
||||
// access.
|
||||
let record: indexmap::IndexMap<&str, rusqlite::types::Value> = record
|
||||
// Join values with column names. We use a map rather than a Vec<(String, Value)> for filter
|
||||
// access.
|
||||
let record: Arc<indexmap::IndexMap<String, rusqlite::types::Value>> = Arc::new(
|
||||
record
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, v)| (table_metadata.schema.columns[idx].name.as_str(), v))
|
||||
.map(|(idx, v)| (table_metadata.schema.columns[idx].name.clone(), v))
|
||||
.collect(),
|
||||
);
|
||||
|
||||
// Build a JSON-encoded SQLite event (insert, update, delete).
|
||||
let event: Arc<EventPayload> = {
|
||||
let json_obj = record
|
||||
.iter()
|
||||
.filter_map(|(name, value)| {
|
||||
return value_to_flat_json(value)
|
||||
.ok()
|
||||
.map(|v| (name.to_string(), v));
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Build a JSON-encoded SQLite event (insert, update, delete).
|
||||
let event: Arc<EventPayload> = {
|
||||
let json_obj = record
|
||||
.iter()
|
||||
.filter_map(|(name, value)| {
|
||||
return value_to_flat_json(value)
|
||||
.ok()
|
||||
.map(|v| (name.to_string(), v));
|
||||
})
|
||||
.collect();
|
||||
|
||||
let payload = EventPayload::from(&match action {
|
||||
RecordAction::Delete => JsonEventPayload::Delete { value: json_obj },
|
||||
RecordAction::Insert => JsonEventPayload::Insert { value: json_obj },
|
||||
RecordAction::Update => JsonEventPayload::Update { value: json_obj },
|
||||
});
|
||||
|
||||
Arc::new(payload)
|
||||
};
|
||||
|
||||
// First broker record subscriptions.
|
||||
let (dead_record_subscriptions, dead_table_subscriptions) = {
|
||||
// let Some(subscriptions) = per_table_subscriptions.get(&table_name).map(|r| r.read()) else {
|
||||
// return;
|
||||
// };
|
||||
|
||||
let dead_record_subscriptions = subscriptions
|
||||
.record
|
||||
.get(&row_id)
|
||||
.map(|record_subscriptions| {
|
||||
broker_subscriptions(&state, conn, record_subscriptions, true, &record, &event)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
// Then broker table subscriptions.
|
||||
let dead_table_subscriptions =
|
||||
broker_subscriptions(&state, conn, &subscriptions.table, false, &record, &event);
|
||||
|
||||
(dead_record_subscriptions, dead_table_subscriptions)
|
||||
};
|
||||
|
||||
if dead_record_subscriptions.is_empty()
|
||||
&& dead_table_subscriptions.is_empty()
|
||||
&& action != RecordAction::Delete
|
||||
{
|
||||
// No cleanup needed.
|
||||
return;
|
||||
}
|
||||
|
||||
subscriptions.with_upgraded(|subscriptions| {
|
||||
// Record subscription cleanup.
|
||||
match action {
|
||||
RecordAction::Delete => {
|
||||
// This is unique for record subscriptions: if the record is deleted, cancel all
|
||||
// subscriptions.
|
||||
subscriptions.record.remove(&row_id);
|
||||
}
|
||||
RecordAction::Update | RecordAction::Insert => {
|
||||
if let Some(m) = subscriptions.record.get_mut(&row_id) {
|
||||
for idx in dead_record_subscriptions.iter().rev() {
|
||||
m.swap_remove(*idx);
|
||||
}
|
||||
|
||||
if m.is_empty() {
|
||||
subscriptions.record.remove(&row_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Table subscription cleanup.
|
||||
for idx in dead_table_subscriptions.iter().rev() {
|
||||
subscriptions.table.swap_remove(*idx);
|
||||
}
|
||||
|
||||
/* remove_subscription_entry_for_table = */
|
||||
subscriptions.is_empty()
|
||||
})
|
||||
Arc::new(EventPayload::from(&match action {
|
||||
RecordAction::Delete => JsonEventPayload::Delete { value: json_obj },
|
||||
RecordAction::Insert => JsonEventPayload::Insert { value: json_obj },
|
||||
RecordAction::Update => JsonEventPayload::Update { value: json_obj },
|
||||
}))
|
||||
};
|
||||
|
||||
if remove_subscription_entry_for_table {
|
||||
// NOTE: Only write lock across all tables when necessary.
|
||||
per_table_subscriptions.with_upgraded(|lock| {
|
||||
// Check again to avoid races:
|
||||
if lock.get(&table_name).is_some_and(|e| e.read().is_empty()) {
|
||||
lock.remove(&table_name);
|
||||
// First broker record subscriptions.
|
||||
if let Some(record_subscriptions) = subscriptions.record.get_mut(&row_id) {
|
||||
let dead = broker_subscriptions(record_subscriptions, &record, &event);
|
||||
|
||||
if lock.is_empty() {
|
||||
uninstall_hook_rusqlite(conn);
|
||||
}
|
||||
}
|
||||
});
|
||||
for idx in dead.iter().rev() {
|
||||
record_subscriptions.remove(*idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Then broker table subscriptions.
|
||||
let dead = broker_subscriptions(&subscriptions.table, &record, &event);
|
||||
for idx in dead.iter().rev() {
|
||||
subscriptions.table.remove(*idx);
|
||||
}
|
||||
|
||||
if subscriptions.is_empty() {
|
||||
uninstall_hook(conn);
|
||||
}
|
||||
}
|
||||
|
||||
static SUBSCRIPTION_COUNTER: AtomicI64 = AtomicI64::new(0);
|
||||
|
||||
static ACCESS_DENIED_EVENT: LazyLock<Arc<EventPayload>> = LazyLock::new(|| {
|
||||
Arc::new(EventPayload::from(&JsonEventPayload::Error {
|
||||
error: "Access denied".into(),
|
||||
}))
|
||||
});
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn static_sse_event_test() {
|
||||
let _x: Arc<EventPayload> = (*ACCESS_DENIED_EVENT).clone();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use async_channel::TryRecvError;
|
||||
use axum::extract::{Path, RawQuery, State};
|
||||
use futures_util::StreamExt;
|
||||
use http_body_util::BodyExt;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicI64, Ordering};
|
||||
use trailbase_sqlite::params;
|
||||
|
||||
use crate::User;
|
||||
@@ -9,10 +11,12 @@ use crate::admin::user::*;
|
||||
use crate::app_state::{AppState, test_state};
|
||||
use crate::auth::util::login_with_password;
|
||||
use crate::config::proto::RecordApiConfig;
|
||||
use crate::records::subscribe::event::{JsonEventPayload, deserialize_event};
|
||||
use crate::records::subscribe::handler::{SubscriptionQuery, add_subscription_sse_and_ws_handler};
|
||||
use crate::records::subscribe::event::{EventErrorStatus, TestChangeEvent, TestJsonEventPayload};
|
||||
use crate::records::subscribe::handler::{
|
||||
SubscriptionQuery, add_subscription_sse_and_ws_handler, subscribe_sse,
|
||||
};
|
||||
use crate::records::test_utils::add_record_api_config;
|
||||
use crate::records::{PermissionFlag, RecordError};
|
||||
use crate::records::{PermissionFlag, RecordApi, RecordError};
|
||||
use crate::util::uuid_to_b64;
|
||||
|
||||
async fn setup_world_readable() -> AppState {
|
||||
@@ -67,10 +71,7 @@ async fn subscribe_to_record_test() {
|
||||
|
||||
let manager = state.subscription_manager();
|
||||
let api = state.lookup_record_api("api_name").unwrap();
|
||||
let stream = manager
|
||||
.add_sse_record_subscription(api, trailbase_sqlite::Value::Integer(0), None)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut stream = subscribe_to_records(state.clone(), api, "0", None, /* filter= */ None).await;
|
||||
|
||||
assert_eq!(1, manager.num_record_subscriptions());
|
||||
|
||||
@@ -87,8 +88,8 @@ async fn subscribe_to_record_test() {
|
||||
|
||||
// First event is "connection established".
|
||||
assert!(matches!(
|
||||
deserialize_event(stream.receiver.recv().await.unwrap()).unwrap(),
|
||||
JsonEventPayload::Ping
|
||||
stream.next().await.unwrap().event,
|
||||
TestJsonEventPayload::Ping
|
||||
));
|
||||
|
||||
// This should do nothing since nobody is subscribed to id = 5.
|
||||
@@ -112,9 +113,9 @@ async fn subscribe_to_record_test() {
|
||||
"id": record_id_raw,
|
||||
"text": "bar",
|
||||
});
|
||||
match deserialize_event(stream.receiver.recv().await.unwrap()).unwrap() {
|
||||
JsonEventPayload::Update { value: obj } => {
|
||||
assert_eq!(Value::Object(obj), expected);
|
||||
match stream.next().await.unwrap().event {
|
||||
TestJsonEventPayload::Update(obj) => {
|
||||
assert_eq!(Value::Object(obj.clone()), expected);
|
||||
}
|
||||
x => {
|
||||
panic!("Expected update, got: {x:?}");
|
||||
@@ -126,9 +127,9 @@ async fn subscribe_to_record_test() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match deserialize_event(stream.receiver.recv().await.unwrap()).unwrap() {
|
||||
JsonEventPayload::Delete { value: obj } => {
|
||||
assert_eq!(Value::Object(obj), expected);
|
||||
match stream.next().await.unwrap().event {
|
||||
TestJsonEventPayload::Delete(obj) => {
|
||||
assert_eq!(Value::Object(obj.clone()), expected);
|
||||
}
|
||||
x => {
|
||||
panic!("Expected delete, got: {x:?}");
|
||||
@@ -142,6 +143,67 @@ async fn subscribe_to_record_test() {
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
async fn subscribe_to_records(
|
||||
state: AppState,
|
||||
api: RecordApi,
|
||||
record: &str,
|
||||
user: Option<User>,
|
||||
filter: Option<&str>,
|
||||
// ) -> kanal::AsyncReceiver<TestChangeEvent> {
|
||||
) -> std::pin::Pin<Box<dyn futures_util::Stream<Item = TestChangeEvent>>> {
|
||||
let filter = filter.map(|f| SubscriptionQuery::parse(f).unwrap().filter.unwrap());
|
||||
let response = subscribe_sse(state, api, record.to_string(), filter, user)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(200, response.status());
|
||||
|
||||
let cnt = Arc::new(AtomicI64::default());
|
||||
let stream = response.into_data_stream();
|
||||
|
||||
return stream
|
||||
.take_while(|bytes| std::future::ready(bytes.is_ok()))
|
||||
.filter_map(move |bytes| {
|
||||
let cnt = cnt.clone();
|
||||
|
||||
return async move {
|
||||
let Ok(bytes) = bytes.as_ref() else {
|
||||
return None;
|
||||
};
|
||||
let payload = String::from_utf8_lossy(&bytes).to_string();
|
||||
|
||||
cnt.fetch_add(1, Ordering::SeqCst);
|
||||
if cnt.load(Ordering::SeqCst) == 1 {
|
||||
// Make sure we have an explicit ping as a first message to establish connection.
|
||||
assert!(payload.contains("ping"));
|
||||
|
||||
return Some(TestChangeEvent {
|
||||
event: TestJsonEventPayload::Ping,
|
||||
seq: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Ignore heartbeats.
|
||||
if let Some((_a, b)) = payload.split_once("data: ") {
|
||||
return Some(serde_json::from_str(b).unwrap());
|
||||
}
|
||||
return None;
|
||||
};
|
||||
})
|
||||
.boxed();
|
||||
}
|
||||
|
||||
async fn take_test_events(
|
||||
stream: std::pin::Pin<Box<dyn futures_util::Stream<Item = TestChangeEvent>>>,
|
||||
n: usize,
|
||||
) -> Vec<TestChangeEvent> {
|
||||
use tokio::time::*;
|
||||
|
||||
return timeout(Duration::from_secs(4), stream.take(n).collect())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn subscribe_to_table_test() {
|
||||
let state = setup_world_readable().await;
|
||||
@@ -151,16 +213,13 @@ async fn subscribe_to_table_test() {
|
||||
let api = state.lookup_record_api("api_name").unwrap();
|
||||
|
||||
{
|
||||
let stream = manager
|
||||
.add_sse_table_subscription(api, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut stream = subscribe_to_records(state.clone(), api, "*", None, /* filter= */ None).await;
|
||||
|
||||
assert_eq!(1, manager.num_table_subscriptions());
|
||||
// First event is "connection established".
|
||||
assert!(matches!(
|
||||
deserialize_event(stream.receiver.recv().await.unwrap()).unwrap(),
|
||||
JsonEventPayload::Ping
|
||||
stream.next().await.unwrap().event,
|
||||
TestJsonEventPayload::Ping
|
||||
));
|
||||
|
||||
let record_id_raw = 0;
|
||||
@@ -180,13 +239,13 @@ async fn subscribe_to_table_test() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match deserialize_event(stream.receiver.recv().await.unwrap()).unwrap() {
|
||||
JsonEventPayload::Insert { value: obj } => {
|
||||
match stream.next().await.unwrap().event {
|
||||
TestJsonEventPayload::Insert(obj) => {
|
||||
let expected = serde_json::json!({
|
||||
"id": record_id_raw,
|
||||
"text": "foo",
|
||||
});
|
||||
assert_eq!(Value::Object(obj), expected);
|
||||
assert_eq!(Value::Object(obj.clone()), expected);
|
||||
}
|
||||
x => {
|
||||
panic!("Expected insert, got: {x:?}");
|
||||
@@ -197,9 +256,9 @@ async fn subscribe_to_table_test() {
|
||||
"id": record_id_raw,
|
||||
"text": "bar",
|
||||
});
|
||||
match deserialize_event(stream.receiver.recv().await.unwrap()).unwrap() {
|
||||
JsonEventPayload::Update { value: obj } => {
|
||||
assert_eq!(Value::Object(obj), expected);
|
||||
match stream.next().await.unwrap().event {
|
||||
TestJsonEventPayload::Update(obj) => {
|
||||
assert_eq!(Value::Object(obj.clone()), expected);
|
||||
}
|
||||
x => {
|
||||
panic!("Expected update, got: {x:?}");
|
||||
@@ -211,9 +270,9 @@ async fn subscribe_to_table_test() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match deserialize_event(stream.receiver.recv().await.unwrap()).unwrap() {
|
||||
JsonEventPayload::Delete { value: obj } => {
|
||||
assert_eq!(Value::Object(obj), expected);
|
||||
match stream.next().await.unwrap().event {
|
||||
TestJsonEventPayload::Delete(obj) => {
|
||||
assert_eq!(Value::Object(obj.clone()), expected);
|
||||
}
|
||||
x => {
|
||||
panic!("Expected delete, got: {x:?}");
|
||||
@@ -430,29 +489,29 @@ async fn test_acl_selective_table_subs() {
|
||||
|
||||
// Assert events for table subscriptions are selective on ACLs.
|
||||
{
|
||||
let user_x_subscription = manager
|
||||
.add_sse_table_subscription(
|
||||
api.clone(),
|
||||
User::from_auth_token(&state, &user_x_token.auth_token),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut user_x_subscription = subscribe_to_records(
|
||||
state.clone(),
|
||||
api.clone(),
|
||||
"*",
|
||||
User::from_auth_token(&state, &user_x_token.auth_token),
|
||||
/* filter= */ None,
|
||||
)
|
||||
.await;
|
||||
|
||||
// First event is "connection established".
|
||||
assert!(matches!(
|
||||
deserialize_event(user_x_subscription.receiver.recv().await.unwrap()).unwrap(),
|
||||
JsonEventPayload::Ping
|
||||
user_x_subscription.next().await.unwrap().event,
|
||||
TestJsonEventPayload::Ping
|
||||
));
|
||||
|
||||
let user_y_subscription = manager
|
||||
.add_sse_table_subscription(
|
||||
api.clone(),
|
||||
User::from_auth_token(&state, &user_y_token.auth_token),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut user_y_subscription = subscribe_to_records(
|
||||
state.clone(),
|
||||
api.clone(),
|
||||
"*",
|
||||
User::from_auth_token(&state, &user_y_token.auth_token),
|
||||
/* filter= */ None,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(2, manager.num_table_subscriptions());
|
||||
|
||||
@@ -468,14 +527,14 @@ async fn test_acl_selective_table_subs() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match deserialize_event(user_x_subscription.receiver.recv().await.unwrap()).unwrap() {
|
||||
JsonEventPayload::Insert { value: obj } => {
|
||||
match user_x_subscription.next().await.unwrap().event {
|
||||
TestJsonEventPayload::Insert(obj) => {
|
||||
let expected = serde_json::json!({
|
||||
"id": record_id_raw,
|
||||
"user": uuid_to_b64(&user_x),
|
||||
"text": "foo",
|
||||
});
|
||||
assert_eq!(Value::Object(obj), expected);
|
||||
assert_eq!(Value::Object(obj.clone()), expected);
|
||||
}
|
||||
x => {
|
||||
panic!("Expected insert, got: {x:?}");
|
||||
@@ -483,18 +542,26 @@ async fn test_acl_selective_table_subs() {
|
||||
};
|
||||
|
||||
// User y should *not* have received the insert event.
|
||||
assert!(
|
||||
tokio::time::timeout(
|
||||
tokio::time::Duration::from_millis(300),
|
||||
user_y_subscription.receiver.clone().count()
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
assert_eq!(
|
||||
user_y_subscription.receiver.try_recv().err().unwrap(),
|
||||
TryRecvError::Empty
|
||||
);
|
||||
// assert!(
|
||||
// tokio::time::timeout(
|
||||
// tokio::time::Duration::from_millis(300),
|
||||
// user_y_subscription.receiver.clone().count()
|
||||
// )
|
||||
// .await
|
||||
// .is_err()
|
||||
// );
|
||||
// assert_eq!(
|
||||
// user_y_subscription.next().await.unwrap(),
|
||||
// TryRecvError::Empty
|
||||
// );
|
||||
assert!(matches!(
|
||||
user_y_subscription.next().await.unwrap().event,
|
||||
TestJsonEventPayload::Ping
|
||||
));
|
||||
|
||||
use tokio::time::*;
|
||||
let got = timeout(Duration::from_millis(100), user_y_subscription.next()).await;
|
||||
assert!(got.is_err(), "Got: {got:?}");
|
||||
}
|
||||
|
||||
// Implicitly await for scheduled cleanups to go through.
|
||||
@@ -538,16 +605,20 @@ async fn subscription_acl_change_owner() {
|
||||
|
||||
let manager = state.subscription_manager();
|
||||
let api = state.lookup_record_api("api_name").unwrap();
|
||||
let stream = manager
|
||||
.add_sse_record_subscription(api, trailbase_sqlite::Value::Integer(record_id), user_x)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut stream = subscribe_to_records(
|
||||
state.clone(),
|
||||
api,
|
||||
&record_id.to_string(),
|
||||
user_x,
|
||||
/* filter= */ None,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(1, manager.num_record_subscriptions());
|
||||
// First event is "connection established".
|
||||
assert!(matches!(
|
||||
deserialize_event(stream.receiver.recv().await.unwrap()).unwrap(),
|
||||
JsonEventPayload::Ping
|
||||
stream.next().await.unwrap().event,
|
||||
TestJsonEventPayload::Ping,
|
||||
));
|
||||
|
||||
conn
|
||||
@@ -567,8 +638,8 @@ async fn subscription_acl_change_owner() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match deserialize_event(stream.receiver.recv().await.unwrap()).unwrap() {
|
||||
JsonEventPayload::Update { value: obj } => {
|
||||
match stream.next().await.unwrap().event {
|
||||
TestJsonEventPayload::Update(obj) => {
|
||||
let expected = serde_json::json!({
|
||||
"id": record_id,
|
||||
"user": uuid_to_b64(&user_x_id),
|
||||
@@ -581,20 +652,19 @@ async fn subscription_acl_change_owner() {
|
||||
}
|
||||
}
|
||||
|
||||
match deserialize_event(stream.receiver.recv().await.unwrap()).unwrap() {
|
||||
JsonEventPayload::Error { .. } => {}
|
||||
match stream.next().await.unwrap().event {
|
||||
TestJsonEventPayload::Error { status, .. } => {
|
||||
assert_eq!(EventErrorStatus::Forbidden, status);
|
||||
}
|
||||
x => {
|
||||
panic!("Expected error, got: {x:?}");
|
||||
}
|
||||
}
|
||||
|
||||
conn
|
||||
.read_query_row_f("SELECT 1", (), |row| row.get::<_, i64>(0))
|
||||
.await
|
||||
.unwrap();
|
||||
drop(stream);
|
||||
|
||||
// Make sure the subscription was cleaned up after the access error.
|
||||
assert!(stream.receiver.is_closed());
|
||||
// assert!(stream.is_closed());
|
||||
assert_eq!(0, manager.num_record_subscriptions());
|
||||
}
|
||||
|
||||
@@ -607,22 +677,17 @@ async fn subscription_filter_test() {
|
||||
let api = state.lookup_record_api("api_name").unwrap();
|
||||
|
||||
{
|
||||
let filter =
|
||||
SubscriptionQuery::parse("filter[$and][0][id][$gt]=5&filter[$and][1][id][$lt]=100").unwrap();
|
||||
|
||||
let stream = manager
|
||||
.add_sse_table_subscription(api, None, Some(filter.filter.unwrap()))
|
||||
.await
|
||||
.unwrap();
|
||||
let stream = subscribe_to_records(
|
||||
state.clone(),
|
||||
api.clone(),
|
||||
"*",
|
||||
/* user= */ None,
|
||||
Some("filter[$and][0][id][$gt]=5&filter[$and][1][id][$lt]=100"),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(1, manager.num_table_subscriptions());
|
||||
// First event is "connection established".
|
||||
assert!(matches!(
|
||||
deserialize_event(stream.receiver.recv().await.unwrap()).unwrap(),
|
||||
JsonEventPayload::Ping
|
||||
));
|
||||
|
||||
// This one should be filter out.
|
||||
conn
|
||||
.execute("INSERT INTO test (id, text) VALUES ($1, 'foo')", params!(1))
|
||||
.await
|
||||
@@ -637,13 +702,17 @@ async fn subscription_filter_test() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match deserialize_event(stream.receiver.recv().await.unwrap()).unwrap() {
|
||||
JsonEventPayload::Insert { value: obj } => {
|
||||
let events = take_test_events(stream, 2).await;
|
||||
|
||||
assert!(matches!(events[0].event, TestJsonEventPayload::Ping));
|
||||
|
||||
match &events[1].event {
|
||||
TestJsonEventPayload::Insert(obj) => {
|
||||
let expected = serde_json::json!({
|
||||
"id": 25,
|
||||
"text": "foo",
|
||||
});
|
||||
assert_eq!(Value::Object(obj), expected);
|
||||
assert_eq!(Value::Object(obj.clone()), expected);
|
||||
}
|
||||
x => {
|
||||
panic!("Expected insert, got: {x:?}");
|
||||
|
||||
Reference in New Issue
Block a user