|
|
|
|
@@ -1,5 +1,5 @@
|
|
|
|
|
use async_channel::{TrySendError, WeakReceiver};
|
|
|
|
|
use axum::extract::{Path, RawQuery, State};
|
|
|
|
|
use axum::extract::{Path, RawQuery, Request, State};
|
|
|
|
|
use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
|
|
|
|
|
use axum::response::{IntoResponse, Response};
|
|
|
|
|
use futures_util::{Stream, StreamExt};
|
|
|
|
|
@@ -14,7 +14,7 @@ use std::pin::Pin;
|
|
|
|
|
use std::sync::atomic::{AtomicI64, Ordering};
|
|
|
|
|
use std::sync::{Arc, LazyLock};
|
|
|
|
|
use std::task::{Context, Poll};
|
|
|
|
|
use trailbase_qs::FilterQuery;
|
|
|
|
|
use trailbase_qs::ValueOrComposite;
|
|
|
|
|
use trailbase_schema::QualifiedName;
|
|
|
|
|
use trailbase_schema::json::value_to_flat_json;
|
|
|
|
|
use trailbase_sqlite::connection::{extract_record_values, extract_row_id};
|
|
|
|
|
@@ -151,7 +151,7 @@ impl DbEvent {
|
|
|
|
|
#[inline]
|
|
|
|
|
fn into_event(self) -> Result<Event, axum::Error> {
|
|
|
|
|
return Ok(Event::Json(
|
|
|
|
|
serde_json::to_string(&self).map_err(|err| axum::Error::new(err))?,
|
|
|
|
|
serde_json::to_string(&self).map_err(axum::Error::new)?,
|
|
|
|
|
));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@@ -385,7 +385,7 @@ impl PerConnectionState {
|
|
|
|
|
self: Arc<Self>,
|
|
|
|
|
api: RecordApi,
|
|
|
|
|
user: Option<User>,
|
|
|
|
|
filter: Option<trailbase_qs::ValueOrComposite>,
|
|
|
|
|
filter: Option<ValueOrComposite>,
|
|
|
|
|
) -> Result<(async_channel::Sender<Event>, AutoCleanupEventStream), RecordError> {
|
|
|
|
|
let table_name = api.qualified_name().clone();
|
|
|
|
|
|
|
|
|
|
@@ -529,7 +529,7 @@ impl SubscriptionManager {
|
|
|
|
|
&self,
|
|
|
|
|
api: RecordApi,
|
|
|
|
|
user: Option<User>,
|
|
|
|
|
filter: Option<trailbase_qs::ValueOrComposite>,
|
|
|
|
|
filter: Option<ValueOrComposite>,
|
|
|
|
|
) -> Result<AutoCleanupEventStream, RecordError> {
|
|
|
|
|
let (sender, receiver) = self
|
|
|
|
|
.get_per_connection_state(&api)
|
|
|
|
|
@@ -810,13 +810,35 @@ fn hook_continuation(conn: &rusqlite::Connection, s: ContinuationState) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Clone, Default, Debug, PartialEq, Deserialize)]
|
|
|
|
|
struct SubscriptionQuery {
|
|
|
|
|
/// Map from filter params to filter value. It's a vector in cases like:
|
|
|
|
|
/// `col0[$gte]=2&col0[$lte]=10`.
|
|
|
|
|
filter: Option<ValueOrComposite>,
|
|
|
|
|
|
|
|
|
|
/// Whether to use WebSocket instead of default SSE.
|
|
|
|
|
ws: Option<bool>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl SubscriptionQuery {
|
|
|
|
|
fn parse(query: &str) -> Result<SubscriptionQuery, RecordError> {
|
|
|
|
|
// NOTE: We rely on non-strict mode to parse `filter[col0]=a&b%filter[col1]=c`.
|
|
|
|
|
let qs = serde_qs::Config::new(9, false);
|
|
|
|
|
return qs
|
|
|
|
|
.deserialize_bytes::<SubscriptionQuery>(query.as_bytes())
|
|
|
|
|
.map_err(|_err| RecordError::BadRequest("Invalid query"));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Read record.
|
|
|
|
|
#[utoipa::path(
|
|
|
|
|
get,
|
|
|
|
|
path = "/{name}/subscribe/{record}",
|
|
|
|
|
tag = "records",
|
|
|
|
|
// TODO: Document the params. Requires utoipa support in trailbase_qs or external impl.
|
|
|
|
|
// params(SubscritpionParams),
|
|
|
|
|
responses(
|
|
|
|
|
(status = 200, description = "SSE stream of record changes.")
|
|
|
|
|
(status = 200, description = "Starts streaming changes to matching records via SSE/WebSocket")
|
|
|
|
|
)
|
|
|
|
|
)]
|
|
|
|
|
pub async fn add_subscription_sse_handler(
|
|
|
|
|
@@ -824,6 +846,7 @@ pub async fn add_subscription_sse_handler(
|
|
|
|
|
Path((api_name, record)): Path<(String, String)>,
|
|
|
|
|
user: Option<User>,
|
|
|
|
|
RawQuery(raw_url_query): RawQuery,
|
|
|
|
|
request: Request,
|
|
|
|
|
) -> Result<Response, RecordError> {
|
|
|
|
|
let Some(api) = state.lookup_record_api(&api_name) else {
|
|
|
|
|
return Err(RecordError::ApiNotFound);
|
|
|
|
|
@@ -833,16 +856,20 @@ pub async fn add_subscription_sse_handler(
|
|
|
|
|
return Err(RecordError::Forbidden);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let FilterQuery { filter } = raw_url_query
|
|
|
|
|
let SubscriptionQuery { filter, ws } = raw_url_query
|
|
|
|
|
.as_ref()
|
|
|
|
|
.map_or_else(
|
|
|
|
|
|| Ok(FilterQuery::default()),
|
|
|
|
|
|query| FilterQuery::parse(query),
|
|
|
|
|
|| Ok(SubscriptionQuery::default()),
|
|
|
|
|
|query| SubscriptionQuery::parse(query),
|
|
|
|
|
)
|
|
|
|
|
.map_err(|_err| {
|
|
|
|
|
return RecordError::BadRequest("Invalid query");
|
|
|
|
|
})?;
|
|
|
|
|
|
|
|
|
|
if ws.unwrap_or(false) {
|
|
|
|
|
return subscribe_ws(state, api, record, filter, user, request).await;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return match record.as_str() {
|
|
|
|
|
"*" => {
|
|
|
|
|
api.check_table_level_access(Permission::Read, user.as_ref())?;
|
|
|
|
|
@@ -878,100 +905,134 @@ pub async fn add_subscription_sse_handler(
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[cfg(feature = "ws")]
|
|
|
|
|
pub async fn add_subscription_ws_handler(
|
|
|
|
|
State(state): State<AppState>,
|
|
|
|
|
ws: axum::extract::ws::WebSocketUpgrade,
|
|
|
|
|
Path((api_name, record)): Path<(String, String)>,
|
|
|
|
|
RawQuery(raw_url_query): RawQuery,
|
|
|
|
|
#[allow(unused)]
|
|
|
|
|
pub async fn subscribe_ws(
|
|
|
|
|
state: AppState,
|
|
|
|
|
api: RecordApi,
|
|
|
|
|
record: String,
|
|
|
|
|
filter: Option<ValueOrComposite>,
|
|
|
|
|
user: Option<User>,
|
|
|
|
|
request: Request,
|
|
|
|
|
) -> Result<Response, RecordError> {
|
|
|
|
|
use axum::extract::ws::WebSocket;
|
|
|
|
|
use futures_util::SinkExt;
|
|
|
|
|
|
|
|
|
|
let Some(api) = state.lookup_record_api(&api_name) else {
|
|
|
|
|
return Err(RecordError::ApiNotFound);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
if !api.enable_subscriptions() {
|
|
|
|
|
return Err(RecordError::Forbidden);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let FilterQuery { filter } = raw_url_query
|
|
|
|
|
.as_ref()
|
|
|
|
|
.map_or_else(
|
|
|
|
|
|| Ok(FilterQuery::default()),
|
|
|
|
|
|query| FilterQuery::parse(query),
|
|
|
|
|
)
|
|
|
|
|
.map_err(|_err| {
|
|
|
|
|
return RecordError::BadRequest("Invalid query");
|
|
|
|
|
})?;
|
|
|
|
|
#[cfg(not(feature = "ws"))]
|
|
|
|
|
return Err(RecordError::BadRequest("ws unsupported"));
|
|
|
|
|
|
|
|
|
|
async fn on_upgrade(socket: WebSocket, receiver: AutoCleanupEventStream) {
|
|
|
|
|
let (mut sender, _) = socket.split();
|
|
|
|
|
#[cfg(feature = "ws")]
|
|
|
|
|
{
|
|
|
|
|
use axum::extract::FromRequestParts;
|
|
|
|
|
use axum::extract::ws::{CloseFrame, Message, WebSocket, WebSocketUpgrade};
|
|
|
|
|
use futures_util::SinkExt;
|
|
|
|
|
|
|
|
|
|
let mut pinned_receiver = std::pin::pin!(receiver);
|
|
|
|
|
while let Some(ev) = pinned_receiver.next().await {
|
|
|
|
|
match ev.into_ws_event() {
|
|
|
|
|
Ok(msg) => {
|
|
|
|
|
if let Err(axum_err) = sender.send(msg).await {
|
|
|
|
|
debug!("{axum_err}");
|
|
|
|
|
let _ = sender.close();
|
|
|
|
|
return;
|
|
|
|
|
let (mut parts, body) = request.into_parts();
|
|
|
|
|
let ws = match WebSocketUpgrade::from_request_parts(&mut parts, &state).await {
|
|
|
|
|
Ok(ws) => ws,
|
|
|
|
|
Err(err) => {
|
|
|
|
|
return Ok(err.into_response());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
async fn abort<S: SinkExt<Message> + std::marker::Unpin>(sender: &mut S, reason: &str) {
|
|
|
|
|
// https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1
|
|
|
|
|
//
|
|
|
|
|
// 1011 indicates that a server is terminating the connection because
|
|
|
|
|
// it encountered an unexpected condition that prevented it from
|
|
|
|
|
// fulfilling the request.
|
|
|
|
|
let _ = sender
|
|
|
|
|
.send(Message::Close(Some(CloseFrame {
|
|
|
|
|
code: 1011,
|
|
|
|
|
reason: reason.into(),
|
|
|
|
|
})))
|
|
|
|
|
.await;
|
|
|
|
|
|
|
|
|
|
sender.close();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// WebSocket uses the HTTP `UPGRADE` mechanism to switch over to dedicated, non-HTTP `ws://`
|
|
|
|
|
// protocol.
|
|
|
|
|
async fn on_upgrade(mut socket: WebSocket, receiver: AutoCleanupEventStream) {
|
|
|
|
|
// NOTE: We're dropping the receiver end of the bidirectional WebSocket. We're not expecting
|
|
|
|
|
// any messages from the client.
|
|
|
|
|
let (mut sender, _) = socket.split();
|
|
|
|
|
|
|
|
|
|
let mut pinned_receiver = std::pin::pin!(receiver);
|
|
|
|
|
while let Some(ev) = pinned_receiver.next().await {
|
|
|
|
|
match ev.into_ws_event() {
|
|
|
|
|
Ok(msg) => {
|
|
|
|
|
if let Err(err) = sender.send(msg).await {
|
|
|
|
|
debug!("Sending WS event to client failed: {err}");
|
|
|
|
|
abort(&mut sender, "Failed to send event. Closing channel");
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Err(err) => {
|
|
|
|
|
debug_assert!(false, "into_ws_event failed: {err}");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
Err(err) => {
|
|
|
|
|
debug_assert!(false, "into_ws_event failed: {err}");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return match record.as_str() {
|
|
|
|
|
"*" => {
|
|
|
|
|
api.check_table_level_access(Permission::Read, user.as_ref())?;
|
|
|
|
|
|
|
|
|
|
Ok(ws.on_upgrade(async move |mut socket| {
|
|
|
|
|
let Ok((_sender, receiver)) = state
|
|
|
|
|
.subscription_manager()
|
|
|
|
|
.get_per_connection_state(&api)
|
|
|
|
|
.add_table_subscription(api, user, filter)
|
|
|
|
|
.await
|
|
|
|
|
else {
|
|
|
|
|
abort(&mut socket, "subscription failed");
|
|
|
|
|
return;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
on_upgrade(socket, receiver).await
|
|
|
|
|
}))
|
|
|
|
|
}
|
|
|
|
|
_ => {
|
|
|
|
|
let record_id = api.primary_key_to_value(record)?;
|
|
|
|
|
api
|
|
|
|
|
.check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref())
|
|
|
|
|
.await?;
|
|
|
|
|
|
|
|
|
|
Ok(ws.on_upgrade(async move |mut socket| {
|
|
|
|
|
let Ok((_sender, receiver)) = state
|
|
|
|
|
.subscription_manager()
|
|
|
|
|
.get_per_connection_state(&api)
|
|
|
|
|
.add_record_subscription(api, record_id, user)
|
|
|
|
|
.await
|
|
|
|
|
else {
|
|
|
|
|
abort(&mut socket, "subscription failed");
|
|
|
|
|
return;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
on_upgrade(socket, receiver).await;
|
|
|
|
|
}))
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return match record.as_str() {
|
|
|
|
|
"*" => {
|
|
|
|
|
api.check_table_level_access(Permission::Read, user.as_ref())?;
|
|
|
|
|
|
|
|
|
|
let (_sender, receiver) = state
|
|
|
|
|
.subscription_manager()
|
|
|
|
|
.get_per_connection_state(&api)
|
|
|
|
|
.add_table_subscription(api, user, filter)
|
|
|
|
|
.await?;
|
|
|
|
|
|
|
|
|
|
Ok(ws.on_upgrade(async |socket| on_upgrade(socket, receiver).await))
|
|
|
|
|
}
|
|
|
|
|
_ => {
|
|
|
|
|
let record_id = api.primary_key_to_value(record)?;
|
|
|
|
|
api
|
|
|
|
|
.check_record_level_access(Permission::Read, Some(&record_id), None, user.as_ref())
|
|
|
|
|
.await?;
|
|
|
|
|
|
|
|
|
|
let (_sender, receiver) = state
|
|
|
|
|
.subscription_manager()
|
|
|
|
|
.get_per_connection_state(&api)
|
|
|
|
|
.add_record_subscription(api, record_id, user)
|
|
|
|
|
.await?;
|
|
|
|
|
|
|
|
|
|
Ok(ws.on_upgrade(async |socket| on_upgrade(socket, receiver).await))
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static SUBSCRIPTION_COUNTER: AtomicI64 = AtomicI64::new(0);
|
|
|
|
|
|
|
|
|
|
static ESTABLISHED_EVENT: LazyLock<Event> =
|
|
|
|
|
LazyLock::new(|| Event::Sse(SseEvent::default().comment("subscription established")));
|
|
|
|
|
#[cfg(not(feature = "ws"))]
|
|
|
|
|
static ACCESS_DENIED_EVENT: LazyLock<Event> = LazyLock::new(|| {
|
|
|
|
|
#[cfg(not(feature = "ws"))]
|
|
|
|
|
return Event::Sse(
|
|
|
|
|
Event::Sse(
|
|
|
|
|
SseEvent::default()
|
|
|
|
|
.json_data(DbEvent::Error("Access denied".into()))
|
|
|
|
|
.expect("static"),
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
#[cfg(feature = "ws")]
|
|
|
|
|
return Event::Json(
|
|
|
|
|
serde_json::to_string(&DbEvent::Error("Access denied".into())).expect("static"),
|
|
|
|
|
);
|
|
|
|
|
)
|
|
|
|
|
});
|
|
|
|
|
#[cfg(feature = "ws")]
|
|
|
|
|
static ACCESS_DENIED_EVENT: LazyLock<Event> = LazyLock::new(|| {
|
|
|
|
|
Event::Json(serde_json::to_string(&DbEvent::Error("Access denied".into())).expect("static"))
|
|
|
|
|
});
|
|
|
|
|
static ESTABLISHED_EVENT: LazyLock<Event> =
|
|
|
|
|
LazyLock::new(|| Event::Sse(SseEvent::default().comment("subscription established")));
|
|
|
|
|
|
|
|
|
|
const NO_HOOK: Option<fn(Action, &str, &str, &PreUpdateCase)> = None;
|
|
|
|
|
|
|
|
|
|
@@ -1284,6 +1345,7 @@ mod tests {
|
|
|
|
|
Path(("api_name".to_string(), record_id_raw.to_string())),
|
|
|
|
|
None,
|
|
|
|
|
RawQuery(None),
|
|
|
|
|
axum::extract::Request::default(),
|
|
|
|
|
)
|
|
|
|
|
.await;
|
|
|
|
|
|
|
|
|
|
@@ -1352,6 +1414,7 @@ mod tests {
|
|
|
|
|
Path(("api_name".to_string(), "*".to_string())),
|
|
|
|
|
None,
|
|
|
|
|
RawQuery(None),
|
|
|
|
|
axum::extract::Request::default(),
|
|
|
|
|
)
|
|
|
|
|
.await;
|
|
|
|
|
|
|
|
|
|
@@ -1372,6 +1435,7 @@ mod tests {
|
|
|
|
|
Path(("api_name".to_string(), "*".to_string())),
|
|
|
|
|
User::from_auth_token(&state, &user_x_token.auth_token),
|
|
|
|
|
RawQuery(None),
|
|
|
|
|
axum::extract::Request::default(),
|
|
|
|
|
)
|
|
|
|
|
.await
|
|
|
|
|
.unwrap();
|
|
|
|
|
@@ -1399,6 +1463,7 @@ mod tests {
|
|
|
|
|
Path(("api_name".to_string(), record_id_raw.to_string())),
|
|
|
|
|
User::from_auth_token(&state, &user_x_token.auth_token),
|
|
|
|
|
RawQuery(None),
|
|
|
|
|
axum::extract::Request::default(),
|
|
|
|
|
)
|
|
|
|
|
.await
|
|
|
|
|
.unwrap();
|
|
|
|
|
@@ -1420,6 +1485,7 @@ mod tests {
|
|
|
|
|
Path(("api_name".to_string(), record_id_raw.to_string())),
|
|
|
|
|
User::from_auth_token(&state, &user_y_token.auth_token),
|
|
|
|
|
RawQuery(None),
|
|
|
|
|
axum::extract::Request::default(),
|
|
|
|
|
)
|
|
|
|
|
.await;
|
|
|
|
|
|
|
|
|
|
@@ -1635,7 +1701,8 @@ mod tests {
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
let filter =
|
|
|
|
|
FilterQuery::parse("filter[$and][0][id][$gt]=5&filter[$and][1][id][$lt]=100").unwrap();
|
|
|
|
|
SubscriptionQuery::parse("filter[$and][0][id][$gt]=5&filter[$and][1][id][$lt]=100")
|
|
|
|
|
.unwrap();
|
|
|
|
|
|
|
|
|
|
let stream = manager
|
|
|
|
|
.add_sse_table_subscription(api, None, Some(filter.filter.unwrap()))
|
|
|
|
|
|