Optimization: reduce allocations, cloning and re-encoding.

This commit is contained in:
Sebastian Jeltsch
2025-01-12 11:55:18 +01:00
parent e2b0c0d05e
commit 1828ebad5a
6 changed files with 155 additions and 56 deletions

View File

@@ -2,7 +2,7 @@ use itertools::Itertools;
use log::*;
use std::borrow::Cow;
use std::sync::Arc;
use trailbase_sqlite::{NamedParams, Params as _, Value};
use trailbase_sqlite::{NamedParamRef, NamedParams, NamedParamsRef, Params as _, Value};
use crate::auth::user::User;
use crate::config::proto::{ConflictResolutionStrategy, RecordApiConfig};
@@ -343,7 +343,7 @@ impl RecordApi {
&self,
conn: &rusqlite::Connection,
p: Permission,
record: Vec<(&str, rusqlite::types::Value)>,
record: &[(&str, rusqlite::types::ValueRef<'_>)],
user: Option<&User>,
) -> Result<(), RecordError> {
// First check table level access and if present check row-level access based on access rule.
@@ -355,7 +355,7 @@ impl RecordApi {
let (query, params) = build_query_and_params_for_record_read(access_rule, user, record);
match Self::query_access(conn, &query, params) {
match Self::query_access_ref(conn, &query, &params) {
Ok(allowed) => {
if allowed {
return Ok(());
@@ -404,6 +404,23 @@ impl RecordApi {
return Err(rusqlite::Error::QueryReturnedNoRows.into());
}
#[inline]
fn query_access_ref(
conn: &rusqlite::Connection,
access_query: &str,
params: NamedParamsRef,
) -> Result<bool, trailbase_sqlite::Error> {
let mut stmt = conn.prepare_cached(access_query)?;
params.bind(&mut stmt)?;
let mut rows = stmt.raw_query();
if let Some(row) = rows.next()? {
return Ok(row.get(0)?);
}
return Err(rusqlite::Error::QueryReturnedNoRows.into());
}
#[inline]
fn has_access(&self, e: Entity, p: Permission) -> bool {
return (self.state.acl[e as usize] & (p as u8)) > 0;
@@ -453,11 +470,11 @@ impl RecordApi {
}
}
fn build_query_and_params_for_record_read(
fn build_query_and_params_for_record_read<'a>(
access_rule: &str,
user: Option<&User>,
record: Vec<(&str, rusqlite::types::Value)>,
) -> (String, NamedParams) {
record: &[(&str, rusqlite::types::ValueRef<'a>)],
) -> (String, Vec<NamedParamRef<'a>>) {
let row = record
.iter()
.map(|(name, _value)| {
@@ -467,15 +484,22 @@ fn build_query_and_params_for_record_read(
.join(", ");
let mut params: Vec<_> = record
.into_iter()
.iter()
.map(|(name, value)| {
return (Cow::Owned(format!(":__v_{name}")), value);
return (
Cow::Owned(format!(":__v_{name}")),
rusqlite::types::ToSqlOutput::Borrowed(*value),
);
})
.collect();
static NULL: rusqlite::types::ToSqlOutput<'static> =
rusqlite::types::ToSqlOutput::Owned(Value::Null);
params.push((
Cow::Borrowed(":__user_id"),
user.map_or(Value::Null, |u| Value::Blob(u.uuid.into())),
user.map_or(NULL.clone(), |u| {
rusqlite::types::ToSqlOutput::Owned(Value::Blob(u.uuid.into()))
}),
));
// Assumes access_rule is an expression: https://www.sqlite.org/syntax/expr.html

View File

@@ -36,6 +36,25 @@ pub(crate) fn value_to_json(value: rusqlite::types::Value) -> Result<serde_json:
});
}
pub(crate) fn valueref_to_json(
value: rusqlite::types::ValueRef<'_>,
) -> Result<serde_json::Value, JsonError> {
use rusqlite::types::ValueRef;
return Ok(match value {
ValueRef::Null => serde_json::Value::Null,
ValueRef::Real(real) => {
let Some(number) = serde_json::Number::from_f64(real) else {
return Err(JsonError::Finite);
};
serde_json::Value::Number(number)
}
ValueRef::Integer(integer) => serde_json::Value::Number(serde_json::Number::from(integer)),
ValueRef::Blob(blob) => serde_json::Value::String(BASE64_URL_SAFE.encode(blob)),
ValueRef::Text(text) => serde_json::Value::String(String::from_utf8_lossy(text).to_string()),
});
}
/// Serialize SQL row to json.
pub fn row_to_json(
metadata: &(dyn TableOrViewMetadata + Send + Sync),

View File

@@ -5,19 +5,16 @@ use axum::{
use futures::stream::{Stream, StreamExt};
use parking_lot::RwLock;
use rusqlite::hooks::{Action, PreUpdateCase};
use serde::Serialize;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicI64, Ordering},
Arc,
};
use trailbase_sqlite::{
connection::{extract_record_values, extract_row_id},
params,
};
use trailbase_sqlite::connection::{extract_record_values, extract_row_id};
use crate::auth::user::User;
use crate::records::sql_to_json::value_to_json;
use crate::records::sql_to_json::valueref_to_json;
use crate::records::RecordApi;
use crate::records::{Permission, RecordError};
use crate::table_metadata::{TableMetadata, TableMetadataCache};
@@ -28,9 +25,6 @@ static SUBSCRIPTION_COUNTER: AtomicI64 = AtomicI64::new(0);
// TODO:
// * clients
// * optimize: avoid repeated encoding of events. Easy to do but makes testing harder since there's
// no good way to parse sse::Event back :/. We should probably just bite the bullet and parse,
// it's literally "data: <json>\n\n".
type SseEvent = Result<axum::response::sse::Event, axum::Error>;
@@ -52,7 +46,7 @@ impl From<Action> for RecordAction {
}
}
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum DbEvent {
Update(Option<serde_json::Value>),
Insert(Option<serde_json::Value>),
@@ -71,7 +65,7 @@ pub struct Subscription {
// record_id: Option<trailbase_sqlite::Value>,
user: Option<User>,
/// Channel for sending events to the SSE handler.
channel: async_channel::Sender<DbEvent>,
channel: async_channel::Sender<Event>,
}
/// Internal, shareable state of the cloneable SubscriptionManager.
@@ -148,8 +142,8 @@ impl SubscriptionManager {
s: &ManagerState,
conn: &rusqlite::Connection,
subs: &[Subscription],
record: &[(&str, rusqlite::types::Value)],
event: &DbEvent,
record: &[(&str, rusqlite::types::ValueRef<'_>)],
event: &Event,
) -> Vec<usize> {
let mut dead_subscriptions: Vec<usize> = vec![];
for (idx, sub) in subs.iter().enumerate() {
@@ -158,21 +152,18 @@ impl SubscriptionManager {
continue;
};
if let Err(_err) = api.check_record_level_read_access(
conn,
Permission::Read,
// TODO: Maybe we could inject ValueRef instead to avoid repeated cloning.
record.to_owned(),
sub.user.as_ref(),
) {
if let Err(_err) =
api.check_record_level_read_access(conn, Permission::Read, record, sub.user.as_ref())
{
// This can happen if the record api configuration has changed since originally
// subscribed. In this case we just send and error and cancel the subscription.
let _ = sub.channel.try_send(DbEvent::Error("Access denied".into()));
if let Ok(ev) = Event::default().json_data(DbEvent::Error("Access denied".into())) {
let _ = sub.channel.try_send(ev);
}
dead_subscriptions.push(idx);
continue;
}
// TODO: Avoid cloning the event/record over and over.
match sub.channel.try_send(event.clone()) {
Ok(_) => {}
Err(async_channel::TrySendError::Full(ev)) => {
@@ -219,10 +210,10 @@ impl SubscriptionManager {
};
// Join values with column names.
let record: Vec<_> = record_values
.into_iter()
let record: Vec<(&str, rusqlite::types::ValueRef<'_>)> = record_values
.iter()
.enumerate()
.map(|(idx, v)| (table_metadata.schema.columns[idx].name.as_str(), v))
.map(|(idx, v)| (table_metadata.schema.columns[idx].name.as_str(), v.into()))
.collect();
// Build a JSON-encoded SQLite event (insert, update, delete).
@@ -231,7 +222,7 @@ impl SubscriptionManager {
record
.iter()
.filter_map(|(name, value)| {
if let Ok(v) = value_to_json(value.clone()) {
if let Ok(v) = valueref_to_json(*value) {
return Some(((*name).to_string(), v));
};
return None;
@@ -239,11 +230,17 @@ impl SubscriptionManager {
.collect(),
);
match action {
let db_event = match action {
RecordAction::Delete => DbEvent::Delete(Some(json_value)),
RecordAction::Insert => DbEvent::Insert(Some(json_value)),
RecordAction::Update => DbEvent::Update(Some(json_value)),
}
};
let Ok(event) = Event::default().json_data(db_event) else {
return;
};
event
};
'record_subs: {
@@ -318,6 +315,8 @@ impl SubscriptionManager {
}
if table_subscriptions.is_empty() {
subscriptions.remove(table_name);
if subscriptions.is_empty() && s.record_subscriptions.read().is_empty() {
conn.preupdate_hook(NO_HOOK);
}
@@ -392,9 +391,7 @@ impl SubscriptionManager {
api: RecordApi,
record: trailbase_sqlite::Value,
user: Option<User>,
) -> Result<async_channel::Receiver<DbEvent>, RecordError> {
let (sender, receiver) = async_channel::bounded::<DbEvent>(16);
) -> Result<async_channel::Receiver<Event>, RecordError> {
let table_name = api.table_name();
let pk_column = &api.record_pk_column().name;
@@ -403,7 +400,7 @@ impl SubscriptionManager {
.conn
.query_row(
&format!(r#"SELECT _rowid_ FROM "{table_name}" WHERE "{pk_column}" = $1"#),
params!(record.clone()),
[record],
)
.await?
else {
@@ -413,6 +410,7 @@ impl SubscriptionManager {
.get(0)
.map_err(|err| RecordError::Internal(err.into()))?;
let (sender, receiver) = async_channel::bounded::<Event>(16);
let subscription_id = SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst);
let empty = {
let mut lock = self.state.record_subscriptions.write();
@@ -441,11 +439,10 @@ impl SubscriptionManager {
&self,
api: RecordApi,
user: Option<User>,
) -> Result<async_channel::Receiver<DbEvent>, RecordError> {
let (sender, receiver) = async_channel::bounded::<DbEvent>(16);
) -> Result<async_channel::Receiver<Event>, RecordError> {
let table_name = api.table_name();
let (sender, receiver) = async_channel::bounded::<Event>(16);
let subscription_id = SUBSCRIPTION_COUNTER.fetch_add(1, Ordering::SeqCst);
let empty = {
let mut lock = self.state.table_subscriptions.write();
@@ -508,11 +505,8 @@ pub async fn add_subscription_sse_handler(
return Err(RecordError::ApiNotFound);
};
fn encode(ev: DbEvent) -> Result<Event, axum::Error> {
// TODO: We're re-encoding the event over and over again for all subscriptions. Would be easy
// to pre-encode on the sender side but makes testing much harder, since there's no good way
// to parse sse::Event back.
return Event::default().json_data(ev);
fn encode(ev: Event) -> Result<Event, axum::Error> {
return Ok(ev);
}
if record == "*" {
@@ -539,13 +533,56 @@ pub async fn add_subscription_sse_handler(
}
}
#[cfg(test)]
async fn decode_sse_json_event(event: Event) -> serde_json::Value {
use axum::response::IntoResponse;
let (sender, receiver) = async_channel::unbounded::<Event>();
let sse = Sse::new(receiver.map(|ev| -> Result<Event, axum::Error> { Ok(ev) }));
sender.send(event).await.unwrap();
sender.close();
let resp = sse.into_response();
let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
let str = String::from_utf8_lossy(&bytes);
let x = str
.strip_prefix("data: ")
.unwrap()
.strip_suffix("\n\n")
.unwrap();
return serde_json::from_str(x).unwrap();
}
#[cfg(test)]
mod tests {
use super::DbEvent;
use super::*;
use trailbase_sqlite::params;
use crate::app_state::test_state;
use crate::records::{add_record_api, AccessRules, Acls, PermissionFlag};
async fn decode_db_event(event: Event) -> DbEvent {
let json = decode_sse_json_event(event).await;
return serde_json::from_value(json).unwrap();
}
#[tokio::test]
async fn sse_event_encoding_test() {
let json = serde_json::json!({
"a": 5,
"b": "text",
});
let db_event = DbEvent::Delete(Some(json));
let event = Event::default().json_data(db_event.clone()).unwrap();
assert_eq!(decode_db_event(event).await, db_event);
}
#[tokio::test]
async fn subscribe_to_record_test() {
let state = test_state(None).await.unwrap();
@@ -611,7 +648,7 @@ mod tests {
"id": record_id_raw,
"text": "bar",
});
match receiver.recv().await.unwrap() {
match decode_db_event(receiver.recv().await.unwrap()).await {
DbEvent::Update(Some(value)) => {
assert_eq!(value, expected);
}
@@ -625,7 +662,7 @@ mod tests {
.await
.unwrap();
match receiver.recv().await.unwrap() {
match decode_db_event(receiver.recv().await.unwrap()).await {
DbEvent::Delete(Some(value)) => {
assert_eq!(value, expected);
}
@@ -691,7 +728,7 @@ mod tests {
"id": record_id_raw,
"text": "foo",
});
match receiver.recv().await.unwrap() {
match decode_db_event(receiver.recv().await.unwrap()).await {
DbEvent::Insert(Some(value)) => {
assert_eq!(value, expected);
}
@@ -704,7 +741,7 @@ mod tests {
"id": record_id_raw,
"text": "bar",
});
match receiver.recv().await.unwrap() {
match decode_db_event(receiver.recv().await.unwrap()).await {
DbEvent::Update(Some(value)) => {
assert_eq!(value, expected);
}
@@ -718,7 +755,7 @@ mod tests {
.await
.unwrap();
match receiver.recv().await.unwrap() {
match decode_db_event(receiver.recv().await.unwrap()).await {
DbEvent::Delete(Some(value)) => {
assert_eq!(value, expected);
}

View File

@@ -214,7 +214,12 @@ impl Connection {
&self,
hook: Option<impl (Fn(Action, &str, &str, &PreUpdateCase)) + Send + Sync + 'static>,
) -> Result<()> {
return self.call(|conn| Ok(conn.preupdate_hook(hook))).await;
return self
.call(|conn| {
conn.preupdate_hook(hook);
return Ok(());
})
.await;
}
/// Close the database connection.

View File

@@ -22,6 +22,6 @@ pub mod schema;
pub use connection::Connection;
pub use error::Error;
pub use extension::connect_sqlite;
pub use params::{NamedParams, Params};
pub use params::{NamedParamRef, NamedParams, NamedParamsRef, Params};
pub use rows::{Row, Rows, ValueType};
pub use rusqlite::types::Value;

View File

@@ -3,6 +3,8 @@ use rusqlite::{types, Result, Statement};
use std::borrow::Cow;
pub type NamedParams = Vec<(Cow<'static, str>, types::Value)>;
pub type NamedParamRef<'a> = (Cow<'static, str>, types::ToSqlOutput<'a>);
pub type NamedParamsRef<'a> = &'a [NamedParamRef<'a>];
// This strong typedef only exists to implement From<Option<T>>.
#[allow(missing_debug_implementations)]
@@ -122,6 +124,18 @@ impl Params for &[(&str, types::Value)] {
}
}
impl Params for NamedParamsRef<'_> {
fn bind(self, stmt: &mut Statement<'_>) -> rusqlite::Result<()> {
for (name, v) in self {
let Some(idx) = stmt.parameter_index(name)? else {
continue;
};
stmt.raw_bind_parameter(idx, v)?;
}
return Ok(());
}
}
impl<const N: usize> Params for [(&str, types::Value); N] {
fn bind(self, stmt: &mut Statement<'_>) -> rusqlite::Result<()> {
for (name, v) in self {