Fix multi-db transactions (and remove state.conn()).

This is also a breaking change, since this drive-by changes `request.transaction`'s default from true to false to default to the cheaper option.
This commit is contained in:
Sebastian Jeltsch
2025-12-12 11:24:37 +01:00
parent 2d687d7a54
commit 7bfed3d7cc
7 changed files with 71 additions and 38 deletions

View File

@@ -40,6 +40,7 @@ fn get_conn_and_migration_path(
Ok((
trailbase_sqlite::Connection::new(
move || {
// TODO: We should load WASM SQLite functions, since migrations may depend on them.
return trailbase_extension::connect_sqlite(
Some(db_path.clone()),
Some(json_registry.clone()),
@@ -52,7 +53,7 @@ fn get_conn_and_migration_path(
))
}
_ => Ok((
state.conn().clone(),
(*state.connection_manager().main_entry().connection).clone(),
(state.data_dir().migrations_path().join("main")),
)),
};

View File

@@ -248,6 +248,7 @@ impl AppState {
return &self.state.json_schema_registry;
}
#[cfg(test)]
pub fn conn(&self) -> &trailbase_sqlite::Connection {
return &self.state.conn;
}
@@ -260,7 +261,7 @@ impl AppState {
return &self.state.logs_conn;
}
pub(crate) fn connection_manager(&self) -> ConnectionManager {
pub fn connection_manager(&self) -> ConnectionManager {
return self.state.connection_manager.clone();
}

View File

@@ -139,8 +139,8 @@ pub async fn delete_avatar_handler(
static ref SQL: String = format!("DELETE FROM {AVATAR_TABLE} WHERE user = ?1");
}
state
.conn()
let main_conn = state.connection_manager().main_entry().connection;
main_conn
.execute(&*SQL, [rusqlite::types::Value::Blob(user.uuid.into())])
.await?;

View File

@@ -55,7 +55,7 @@ struct ConnectionKey {
}
#[derive(Clone)]
pub(crate) struct ConnectionEntry {
pub struct ConnectionEntry {
pub connection: Arc<Connection>,
pub metadata: Arc<ConnectionMetadata>,
}
@@ -157,7 +157,7 @@ impl ConnectionManager {
// return self.state.main.read().connection.clone();
// }
pub(crate) fn main_entry(&self) -> ConnectionEntry {
pub fn main_entry(&self) -> ConnectionEntry {
return self.state.main.read().clone();
}
@@ -169,7 +169,7 @@ impl ConnectionManager {
// return Ok(self.get_entry(main, attached_databases)?.connection);
// }
pub(crate) fn get_entry(
pub fn get_entry(
&self,
main: bool,
attached_databases: Option<BTreeSet<String>>,
@@ -206,7 +206,7 @@ impl ConnectionManager {
};
}
pub(crate) fn get_entry_for_qn(
pub fn get_entry_for_qn(
&self,
name: &trailbase_schema::QualifiedName,
) -> Result<ConnectionEntry, ConnectionError> {

View File

@@ -1,6 +1,8 @@
use axum::extract::{Json, State};
use base64::prelude::*;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use trailbase_schema::QualifiedName;
use utoipa::ToSchema;
use crate::app_state::AppState;
@@ -56,19 +58,47 @@ pub async fn record_transactions_handler(
user: Option<User>,
Json(request): Json<TransactionRequest>,
) -> Result<Json<TransactionResponse>, RecordError> {
// NOTE: We may want to make this user-configurable. The cost also heavily depends on whether
// `request.transaction == true`.
if request.operations.len() > 128 {
return Err(RecordError::BadRequest("Transactions exceed limit: 128"));
}
if request.operations.is_empty() {
return Ok(Json(TransactionResponse { ids: vec![] }));
}
type Op = dyn (FnOnce(&rusqlite::Connection) -> Result<Option<String>, RecordError>) + Send;
let mut db: (String, Option<Arc<trailbase_sqlite::Connection>>) = Default::default();
let mut get_api =
|state: &AppState, api_name: &str, idx: usize| -> Result<RecordApi, RecordError> {
let api = state
.lookup_record_api(api_name)
.ok_or_else(|| RecordError::ApiNotFound)?;
if !api.is_table() {
return Err(RecordError::ApiRequiresTable);
}
// Check that all ops reference same DB.
let db_name = get_db_name(api.qualified_name());
if idx == 0 {
db = (db_name.to_string(), Some(api.conn().clone()));
} else if db_name != db.0 {
return Err(RecordError::BadRequest("ops can only touch same db"));
}
return Ok(api);
};
let operations: Vec<Box<Op>> = request
.operations
.into_iter()
.map(|op| -> Result<Box<Op>, RecordError> {
.enumerate()
.map(|(idx, op)| -> Result<Box<Op>, RecordError> {
return match op {
Operation::Create { api_name, value } => {
let api = get_api(&state, &api_name)?;
let api = get_api(&state, &api_name, idx)?;
let mut record = extract_record(value)?;
if api.insert_autofill_missing_user_id_columns()
@@ -121,7 +151,7 @@ pub async fn record_transactions_handler(
record_id,
value,
} => {
let api = get_api(&state, &api_name)?;
let api = get_api(&state, &api_name, idx)?;
let record = extract_record(value)?;
let record_id = api.primary_key_to_value(record_id)?;
let (_index, pk_column) = api.record_pk_column();
@@ -163,7 +193,7 @@ pub async fn record_transactions_handler(
api_name,
record_id,
} => {
let api = get_api(&state, &api_name)?;
let api = get_api(&state, &api_name, idx)?;
let record_id = api.primary_key_to_value(record_id)?;
let acl_check = api.build_record_level_access_check(
@@ -190,9 +220,12 @@ pub async fn record_transactions_handler(
})
.collect::<Result<Vec<_>, _>>()?;
let ids = if request.transaction.unwrap_or(true) {
state
.conn()
let conn = db
.1
.ok_or_else(|| RecordError::Internal("missing db".into()))?;
let ids = if request.transaction.unwrap_or(false) {
conn
.call(
move |conn: &mut rusqlite::Connection| -> Result<Vec<String>, trailbase_sqlite::Error> {
let tx = conn.transaction()?;
@@ -211,8 +244,7 @@ pub async fn record_transactions_handler(
)
.await?
} else {
state
.conn()
conn
.call(
move |conn: &mut rusqlite::Connection| -> Result<Vec<String>, trailbase_sqlite::Error> {
let mut ids: Vec<String> = vec![];
@@ -243,14 +275,8 @@ fn extract_record_id(value: rusqlite::types::Value) -> Result<String, trailbase_
}
#[inline]
fn get_api(state: &AppState, api_name: &str) -> Result<RecordApi, RecordError> {
let Some(api) = state.lookup_record_api(api_name) else {
return Err(RecordError::ApiNotFound);
};
if !api.is_table() {
return Err(RecordError::ApiRequiresTable);
}
return Ok(api);
fn get_db_name(name: &QualifiedName) -> &str {
return name.database_schema.as_deref().unwrap_or("main");
}
#[inline]

View File

@@ -249,9 +249,13 @@ impl Server {
// Re-apply migrations. This needs to happen before reloading the config, which is
// consistent with the startup order. Otherwise, we may validate a configuration
// against a stale database schema.
//
// TODO: Right now we're only re-applying main migrations.
let user_migrations_path = state.data_dir().migrations_path();
match state
.conn()
.connection_manager()
.main_entry()
.connection
.call(|conn: &mut rusqlite::Connection| {
return crate::migrations::apply_main_migrations(conn, Some(user_migrations_path))
.map_err(|err| trailbase_sqlite::Error::Other(err.into()));

View File

@@ -2,6 +2,7 @@ use axum::extract::{Json, State};
use axum::http::StatusCode;
use axum_test::TestServer;
use axum_test::multipart::MultipartForm;
use std::sync::Arc;
use tower_cookies::Cookie;
use trailbase_sqlite::params;
@@ -12,7 +13,7 @@ use trailbase::constants::{COOKIE_AUTH_TOKEN, RECORD_API_PATH};
use trailbase::util::id_to_b64;
use trailbase::{DataDir, Server, ServerOptions};
pub(crate) async fn add_record_api_config(
async fn add_record_api_config(
state: &AppState,
api: RecordApiConfig,
) -> Result<(), anyhow::Error> {
@@ -54,13 +55,13 @@ async fn test_record_apis() {
assert!(admin_router.is_none());
assert!(tls.is_none());
let conn = state.conn();
let conn = state.connection_manager().main_entry().connection;
let logs_conn = state.logs_conn();
create_chat_message_app_tables(conn).await.unwrap();
create_chat_message_app_tables(&conn).await.unwrap();
state.rebuild_connection_metadata().await.unwrap();
let room = add_room(conn, "room0").await.unwrap();
let room = add_room(&conn, "room0").await.unwrap();
let password = "Secret!1!!";
let client_ip = "22.11.22.11";
@@ -89,7 +90,7 @@ async fn test_record_apis() {
.await
.unwrap();
add_user_to_room(conn, user_x, room).await.unwrap();
add_user_to_room(&conn, user_x, room).await.unwrap();
#[allow(unused_mut)]
let (_address, mut router) = main_router;
@@ -261,8 +262,8 @@ async fn test_record_apis() {
assert_eq!(status, 200);
}
pub async fn create_chat_message_app_tables(
conn: &trailbase_sqlite::Connection,
async fn create_chat_message_app_tables(
conn: &Arc<trailbase_sqlite::Connection>,
) -> Result<(), anyhow::Error> {
// Create a messages, chat room and members tables.
conn
@@ -299,8 +300,8 @@ pub async fn create_chat_message_app_tables(
return Ok(());
}
pub async fn add_room(
conn: &trailbase_sqlite::Connection,
async fn add_room(
conn: &Arc<trailbase_sqlite::Connection>,
name: &str,
) -> Result<[u8; 16], anyhow::Error> {
let room: [u8; 16] = conn
@@ -315,8 +316,8 @@ pub async fn add_room(
return Ok(room);
}
pub async fn add_user_to_room(
conn: &trailbase_sqlite::Connection,
async fn add_user_to_room(
conn: &Arc<trailbase_sqlite::Connection>,
user: [u8; 16],
room: [u8; 16],
) -> Result<(), anyhow::Error> {
@@ -329,7 +330,7 @@ pub async fn add_user_to_room(
return Ok(());
}
pub(crate) async fn create_user_for_test(
async fn create_user_for_test(
state: &AppState,
email: &str,
password: &str,