mirror of
https://github.com/trailbaseio/trailbase.git
synced 2025-12-21 09:29:44 -06:00
Optimization: reduce allocations, cloning and re-encoding.
This commit is contained in:
@@ -2,7 +2,7 @@ use itertools::Itertools;
|
||||
use log::*;
|
||||
use std::borrow::Cow;
|
||||
use std::sync::Arc;
|
||||
use trailbase_sqlite::{NamedParams, Params as _, Value};
|
||||
use trailbase_sqlite::{NamedParamRef, NamedParams, NamedParamsRef, Params as _, Value};
|
||||
|
||||
use crate::auth::user::User;
|
||||
use crate::config::proto::{ConflictResolutionStrategy, RecordApiConfig};
|
||||
@@ -343,7 +343,7 @@ impl RecordApi {
|
||||
&self,
|
||||
conn: &rusqlite::Connection,
|
||||
p: Permission,
|
||||
record: Vec<(&str, rusqlite::types::Value)>,
|
||||
record: &[(&str, rusqlite::types::ValueRef<'_>)],
|
||||
user: Option<&User>,
|
||||
) -> Result<(), RecordError> {
|
||||
// First check table level access and if present check row-level access based on access rule.
|
||||
@@ -355,7 +355,7 @@ impl RecordApi {
|
||||
|
||||
let (query, params) = build_query_and_params_for_record_read(access_rule, user, record);
|
||||
|
||||
match Self::query_access(conn, &query, params) {
|
||||
match Self::query_access_ref(conn, &query, ¶ms) {
|
||||
Ok(allowed) => {
|
||||
if allowed {
|
||||
return Ok(());
|
||||
@@ -404,6 +404,23 @@ impl RecordApi {
|
||||
return Err(rusqlite::Error::QueryReturnedNoRows.into());
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn query_access_ref(
|
||||
conn: &rusqlite::Connection,
|
||||
access_query: &str,
|
||||
params: NamedParamsRef,
|
||||
) -> Result<bool, trailbase_sqlite::Error> {
|
||||
let mut stmt = conn.prepare_cached(access_query)?;
|
||||
params.bind(&mut stmt)?;
|
||||
|
||||
let mut rows = stmt.raw_query();
|
||||
if let Some(row) = rows.next()? {
|
||||
return Ok(row.get(0)?);
|
||||
}
|
||||
|
||||
return Err(rusqlite::Error::QueryReturnedNoRows.into());
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn has_access(&self, e: Entity, p: Permission) -> bool {
|
||||
return (self.state.acl[e as usize] & (p as u8)) > 0;
|
||||
@@ -453,11 +470,11 @@ impl RecordApi {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_query_and_params_for_record_read(
|
||||
fn build_query_and_params_for_record_read<'a>(
|
||||
access_rule: &str,
|
||||
user: Option<&User>,
|
||||
record: Vec<(&str, rusqlite::types::Value)>,
|
||||
) -> (String, NamedParams) {
|
||||
record: &[(&str, rusqlite::types::ValueRef<'a>)],
|
||||
) -> (String, Vec<NamedParamRef<'a>>) {
|
||||
let row = record
|
||||
.iter()
|
||||
.map(|(name, _value)| {
|
||||
@@ -467,15 +484,22 @@ fn build_query_and_params_for_record_read(
|
||||
.join(", ");
|
||||
|
||||
let mut params: Vec<_> = record
|
||||
.into_iter()
|
||||
.iter()
|
||||
.map(|(name, value)| {
|
||||
return (Cow::Owned(format!(":__v_{name}")), value);
|
||||
return (
|
||||
Cow::Owned(format!(":__v_{name}")),
|
||||
rusqlite::types::ToSqlOutput::Borrowed(*value),
|
||||
);
|
||||
})
|
||||
.collect();
|
||||
|
||||
static NULL: rusqlite::types::ToSqlOutput<'static> =
|
||||
rusqlite::types::ToSqlOutput::Owned(Value::Null);
|
||||
params.push((
|
||||
Cow::Borrowed(":__user_id"),
|
||||
user.map_or(Value::Null, |u| Value::Blob(u.uuid.into())),
|
||||
user.map_or(NULL.clone(), |u| {
|
||||
rusqlite::types::ToSqlOutput::Owned(Value::Blob(u.uuid.into()))
|
||||
}),
|
||||
));
|
||||
|
||||
// Assumes access_rule is an expression: https://www.sqlite.org/syntax/expr.html
|
||||
|
||||
@@ -36,6 +36,25 @@ pub(crate) fn value_to_json(value: rusqlite::types::Value) -> Result<serde_json:
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn valueref_to_json(
|
||||
value: rusqlite::types::ValueRef<'_>,
|
||||
) -> Result<serde_json::Value, JsonError> {
|
||||
use rusqlite::types::ValueRef;
|
||||
|
||||
return Ok(match value {
|
||||
ValueRef::Null => serde_json::Value::Null,
|
||||
ValueRef::Real(real) => {
|
||||
let Some(number) = serde_json::Number::from_f64(real) else {
|
||||
return Err(JsonError::Finite);
|
||||
};
|
||||
serde_json::Value::Number(number)
|
||||
}
|
||||
ValueRef::Integer(integer) => serde_json::Value::Number(serde_json::Number::from(integer)),
|
||||
ValueRef::Blob(blob) => serde_json::Value::String(BASE64_URL_SAFE.encode(blob)),
|
||||
ValueRef::Text(text) => serde_json::Value::String(String::from_utf8_lossy(text).to_string()),
|
||||
});
|
||||
}
|
||||
|
||||
/// Serialize SQL row to json.
|
||||
pub fn row_to_json(
|
||||
metadata: &(dyn TableOrViewMetadata + Send + Sync),
|
||||
|
||||
@@ -5,19 +5,16 @@ use axum::{
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use parking_lot::RwLock;
|
||||
use rusqlite::hooks::{Action, PreUpdateCase};
|
||||
use serde::Serialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{
|
||||
atomic::{AtomicI64, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use trailbase_sqlite::{
|
||||
connection::{extract_record_values, extract_row_id},
|
||||
params,
|
||||
};
|
||||
use trailbase_sqlite::connection::{extract_record_values, extract_row_id};
|
||||
|
||||
use crate::auth::user::User;
|
||||
use crate::records::sql_to_json::value_to_json;
|
||||
use crate::records::sql_to_json::valueref_to_json;
|
||||
use crate::records::RecordApi;
|
||||
use crate::records::{Permission, RecordError};
|
||||
use crate::table_metadata::{TableMetadata, TableMetadataCache};
|
||||
@@ -28,9 +25,6 @@ static SUBSCRIPTION_COUNTER: AtomicI64 = AtomicI64::new(0);
|
||||
|
||||
// TODO:
|
||||
// * clients
|
||||
// * optimize: avoid repeated encoding of events. Easy to do but makes testing harder since there's
|
||||
// no good way to parse sse::Event back :/. We should probably just bite the bullet and parse,
|
||||
// it's literally "data: <json>\n\n".
|
||||
|
||||
type SseEvent = Result<axum::response::sse::Event, axum::Error>;
|
||||
|
||||
@@ -52,7 +46,7 @@ impl From<Action> for RecordAction {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum DbEvent {
|
||||
Update(Option<serde_json::Value>),
|
||||
Insert(Option<serde_json::Value>),
|
||||
@@ -71,7 +65,7 @@ pub struct Subscription {
|
||||
// record_id: Option<trailbase_sqlite::Value>,
|
||||
user: Option<User>,
|
||||
/// Channel for sending events to the SSE handler.
|
||||
channel: async_channel::Sender<DbEvent>,
|
||||
channel: async_channel::Sender<Event>,
|
||||
}
|
||||
|
||||
/// Internal, shareable state of the cloneable SubscriptionManager.
|
||||
@@ -148,8 +142,8 @@ impl SubscriptionManager {
|
||||
s: &ManagerState,
|
||||
conn: &rusqlite::Connection,
|
||||
subs: &[Subscription],
|
||||
record: &[(&str, rusqlite::types::Value)],
|
||||
event: &DbEvent,
|
||||
record: &[(&str, rusqlite::types::ValueRef<'_>)],
|
||||
event: &Event,
|
||||
) -> Vec<usize> {
|
||||
let mut dead_subscriptions: Vec<usize> = vec![];
|
||||
for (idx, sub) in subs.iter().enumerate() {
|
||||
@@ -158,21 +152,18 @@ impl SubscriptionManager {
|
||||
continue;
|
||||
};
|
||||
|
||||
if let Err(_err) = api.check_record_level_read_access(
|
||||
conn,
|
||||
Permission::Read,
|
||||
// TODO: Maybe we could inject ValueRef instead to avoid repeated cloning.
|
||||
record.to_owned(),
|
||||
sub.user.as_ref(),
|
||||
) {
|
||||
if let Err(_err) =
|
||||
api.check_record_level_read_access(conn, Permission::Read, record, sub.user.as_ref())
|
||||
{
|
||||
// This can happen if the record api configuration has changed since originally
|
||||
// subscribed. In this case we just send and error and cancel the subscription.
|
||||
let _ = sub.channel.try_send(DbEvent::Error("Access denied".into()));
|
||||
if let Ok(ev) = Event::default().json_data(DbEvent::Error("Access denied".into())) {
|
||||
let _ = sub.channel.try_send(ev);
|
||||
}
|
||||
dead_subscriptions.push(idx);
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: Avoid cloning the event/record over and over.
|
||||
match sub.channel.try_send(event.clone()) {
|
||||
Ok(_) => {}
|
||||
Err(async_channel::TrySendError::Full(ev)) => {
|
||||
@@ -219,10 +210,10 @@ impl SubscriptionManager {
|
||||
};
|
||||
|
||||
// Join values with column names.
|
||||
let record: Vec<_> = record_values
|
||||
.into_iter()
|
||||
let record: Vec<(&str, rusqlite::types::ValueRef<'_>)> = record_values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, v)| (table_metadata.schema.columns[idx].name.as_str(), v))
|
||||
.map(|(idx, v)| (table_metadata.schema.columns[idx].name.as_str(), v.into()))
|
||||
.collect();
|
||||
|
||||
// Build a JSON-encoded SQLite event (insert, update, delete).
|
||||
@@ -231,7 +222,7 @@ impl SubscriptionManager {
|
||||
record
|
||||
.iter()
|
||||
.filter_map(|(name, value)| {
|
||||
if let Ok(v) = value_to_json(value.clone()) {
|
||||
if let Ok(v) = valueref_to_json(*value) {
|
||||
return Some(((*name).to_string(), v));
|
||||
};
|
||||
return None;
|
||||
@@ -239,11 +230,17 @@ impl SubscriptionManager {
|
||||
.collect(),
|
||||
);
|
||||
|
||||
match action {
|
||||
let db_event = match action {
|
||||
RecordAction::Delete => DbEvent::Delete(Some(json_value)),
|
||||
RecordAction::Insert => DbEvent::Insert(Some(json_value)),
|
||||
RecordAction::Update => DbEvent::Update(Some(json_value)),
|
||||
}
|
||||
};
|
||||
|
||||
let Ok(event) = Event::default().json_data(db_event) else {
|
||||
return;
|
||||
};
|
||||
|
||||
event
|
||||
};
|
||||
|
||||
'record_subs: {
|
||||
@@ -318,6 +315,8 @@ impl SubscriptionManager {
|
||||
}
|
||||
|
||||
if table_subscriptions.is_empty() {
|
||||
subscriptions.remove(table_name);
|
||||
|
||||
if subscriptions.is_empty() && s.record_subscriptions.read().is_empty() {
|
||||
conn.preupdate_hook(NO_HOOK);
|
||||
}
|
||||
@@ -392,9 +391,7 @@ impl SubscriptionManager {
|
||||
api: RecordApi,
|
||||
record: trailbase_sqlite::Value,
|
||||
user: Option<User>,
|
||||
) -> Result<async_channel::Receiver<DbEvent>, RecordError> {
|
||||
let (sender, receiver) = async_channel::bounded::<DbEvent>(16);
|
||||
|
||||
) -> Result<async_channel::Receiver<Event>, RecordError> {
|
||||
let table_name = api.table_name();
|
||||
let pk_column = &api.record_pk_column().name;
|
||||
|
||||
@@ -403,7 +400,7 @@ impl SubscriptionManager {
|
||||
.conn
|
||||
.query_row(
|
||||
&format!(r#"SELECT _rowid_ FROM "{table_name}" WHERE "{pk_column}" = $1"#),
|
||||
params!(record.clone()),
|
||||
[record],
|
||||
)
|
||||
.await?
|
||||
else {
|
||||
@@ -413,6 +410,7 @@ impl SubscriptionManager {
|
||||
.get(0)
|
||||
.map_err(|err| RecordError::Internal(err.into()))?;
|
||||
|
||||
let (sender, receiver) = async_channel::bounded::<Event>(16);
|
||||
let subscription_id = SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst);
|
||||
let empty = {
|
||||
let mut lock = self.state.record_subscriptions.write();
|
||||
@@ -441,11 +439,10 @@ impl SubscriptionManager {
|
||||
&self,
|
||||
api: RecordApi,
|
||||
user: Option<User>,
|
||||
) -> Result<async_channel::Receiver<DbEvent>, RecordError> {
|
||||
let (sender, receiver) = async_channel::bounded::<DbEvent>(16);
|
||||
|
||||
) -> Result<async_channel::Receiver<Event>, RecordError> {
|
||||
let table_name = api.table_name();
|
||||
|
||||
let (sender, receiver) = async_channel::bounded::<Event>(16);
|
||||
let subscription_id = SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst);
|
||||
let empty = {
|
||||
let mut lock = self.state.table_subscriptions.write();
|
||||
@@ -508,11 +505,8 @@ pub async fn add_subscription_sse_handler(
|
||||
return Err(RecordError::ApiNotFound);
|
||||
};
|
||||
|
||||
fn encode(ev: DbEvent) -> Result<Event, axum::Error> {
|
||||
// TODO: We're re-encoding the event over and over again for all subscriptions. Would be easy
|
||||
// to pre-encode on the sender side but makes testing much harder, since there's no good way
|
||||
// to parse sse::Event back.
|
||||
return Event::default().json_data(ev);
|
||||
fn encode(ev: Event) -> Result<Event, axum::Error> {
|
||||
return Ok(ev);
|
||||
}
|
||||
|
||||
if record == "*" {
|
||||
@@ -539,13 +533,56 @@ pub async fn add_subscription_sse_handler(
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn decode_sse_json_event(event: Event) -> serde_json::Value {
|
||||
use axum::response::IntoResponse;
|
||||
|
||||
let (sender, receiver) = async_channel::unbounded::<Event>();
|
||||
let sse = Sse::new(receiver.map(|ev| -> Result<Event, axum::Error> { Ok(ev) }));
|
||||
|
||||
sender.send(event).await.unwrap();
|
||||
sender.close();
|
||||
|
||||
let resp = sse.into_response();
|
||||
let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let str = String::from_utf8_lossy(&bytes);
|
||||
let x = str
|
||||
.strip_prefix("data: ")
|
||||
.unwrap()
|
||||
.strip_suffix("\n\n")
|
||||
.unwrap();
|
||||
return serde_json::from_str(x).unwrap();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::DbEvent;
|
||||
use super::*;
|
||||
use trailbase_sqlite::params;
|
||||
|
||||
use crate::app_state::test_state;
|
||||
use crate::records::{add_record_api, AccessRules, Acls, PermissionFlag};
|
||||
|
||||
async fn decode_db_event(event: Event) -> DbEvent {
|
||||
let json = decode_sse_json_event(event).await;
|
||||
return serde_json::from_value(json).unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sse_event_encoding_test() {
|
||||
let json = serde_json::json!({
|
||||
"a": 5,
|
||||
"b": "text",
|
||||
});
|
||||
let db_event = DbEvent::Delete(Some(json));
|
||||
let event = Event::default().json_data(db_event.clone()).unwrap();
|
||||
|
||||
assert_eq!(decode_db_event(event).await, db_event);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn subscribe_to_record_test() {
|
||||
let state = test_state(None).await.unwrap();
|
||||
@@ -611,7 +648,7 @@ mod tests {
|
||||
"id": record_id_raw,
|
||||
"text": "bar",
|
||||
});
|
||||
match receiver.recv().await.unwrap() {
|
||||
match decode_db_event(receiver.recv().await.unwrap()).await {
|
||||
DbEvent::Update(Some(value)) => {
|
||||
assert_eq!(value, expected);
|
||||
}
|
||||
@@ -625,7 +662,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match receiver.recv().await.unwrap() {
|
||||
match decode_db_event(receiver.recv().await.unwrap()).await {
|
||||
DbEvent::Delete(Some(value)) => {
|
||||
assert_eq!(value, expected);
|
||||
}
|
||||
@@ -691,7 +728,7 @@ mod tests {
|
||||
"id": record_id_raw,
|
||||
"text": "foo",
|
||||
});
|
||||
match receiver.recv().await.unwrap() {
|
||||
match decode_db_event(receiver.recv().await.unwrap()).await {
|
||||
DbEvent::Insert(Some(value)) => {
|
||||
assert_eq!(value, expected);
|
||||
}
|
||||
@@ -704,7 +741,7 @@ mod tests {
|
||||
"id": record_id_raw,
|
||||
"text": "bar",
|
||||
});
|
||||
match receiver.recv().await.unwrap() {
|
||||
match decode_db_event(receiver.recv().await.unwrap()).await {
|
||||
DbEvent::Update(Some(value)) => {
|
||||
assert_eq!(value, expected);
|
||||
}
|
||||
@@ -718,7 +755,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match receiver.recv().await.unwrap() {
|
||||
match decode_db_event(receiver.recv().await.unwrap()).await {
|
||||
DbEvent::Delete(Some(value)) => {
|
||||
assert_eq!(value, expected);
|
||||
}
|
||||
|
||||
@@ -214,7 +214,12 @@ impl Connection {
|
||||
&self,
|
||||
hook: Option<impl (Fn(Action, &str, &str, &PreUpdateCase)) + Send + Sync + 'static>,
|
||||
) -> Result<()> {
|
||||
return self.call(|conn| Ok(conn.preupdate_hook(hook))).await;
|
||||
return self
|
||||
.call(|conn| {
|
||||
conn.preupdate_hook(hook);
|
||||
return Ok(());
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Close the database connection.
|
||||
|
||||
@@ -22,6 +22,6 @@ pub mod schema;
|
||||
pub use connection::Connection;
|
||||
pub use error::Error;
|
||||
pub use extension::connect_sqlite;
|
||||
pub use params::{NamedParams, Params};
|
||||
pub use params::{NamedParamRef, NamedParams, NamedParamsRef, Params};
|
||||
pub use rows::{Row, Rows, ValueType};
|
||||
pub use rusqlite::types::Value;
|
||||
|
||||
@@ -3,6 +3,8 @@ use rusqlite::{types, Result, Statement};
|
||||
use std::borrow::Cow;
|
||||
|
||||
pub type NamedParams = Vec<(Cow<'static, str>, types::Value)>;
|
||||
pub type NamedParamRef<'a> = (Cow<'static, str>, types::ToSqlOutput<'a>);
|
||||
pub type NamedParamsRef<'a> = &'a [NamedParamRef<'a>];
|
||||
|
||||
// This strong typedef only exists to implement From<Option<T>>.
|
||||
#[allow(missing_debug_implementations)]
|
||||
@@ -122,6 +124,18 @@ impl Params for &[(&str, types::Value)] {
|
||||
}
|
||||
}
|
||||
|
||||
impl Params for NamedParamsRef<'_> {
|
||||
fn bind(self, stmt: &mut Statement<'_>) -> rusqlite::Result<()> {
|
||||
for (name, v) in self {
|
||||
let Some(idx) = stmt.parameter_index(name)? else {
|
||||
continue;
|
||||
};
|
||||
stmt.raw_bind_parameter(idx, v)?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize> Params for [(&str, types::Value); N] {
|
||||
fn bind(self, stmt: &mut Statement<'_>) -> rusqlite::Result<()> {
|
||||
for (name, v) in self {
|
||||
|
||||
Reference in New Issue
Block a user