diff --git a/crates/core/src/admin/table/mod.rs b/crates/core/src/admin/table/mod.rs index 25a00dc8..4df367c3 100644 --- a/crates/core/src/admin/table/mod.rs +++ b/crates/core/src/admin/table/mod.rs @@ -40,6 +40,7 @@ fn get_conn_and_migration_path( Ok(( trailbase_sqlite::Connection::new( move || { + // TODO: We should load WASM SQLite functions, since migrations may depend on them. return trailbase_extension::connect_sqlite( Some(db_path.clone()), Some(json_registry.clone()), @@ -52,7 +53,7 @@ fn get_conn_and_migration_path( )) } _ => Ok(( - state.conn().clone(), + (*state.connection_manager().main_entry().connection).clone(), (state.data_dir().migrations_path().join("main")), )), }; diff --git a/crates/core/src/app_state.rs b/crates/core/src/app_state.rs index 26a49337..17914901 100644 --- a/crates/core/src/app_state.rs +++ b/crates/core/src/app_state.rs @@ -248,6 +248,7 @@ impl AppState { return &self.state.json_schema_registry; } + #[cfg(test)] pub fn conn(&self) -> &trailbase_sqlite::Connection { return &self.state.conn; } @@ -260,7 +261,7 @@ impl AppState { return &self.state.logs_conn; } - pub(crate) fn connection_manager(&self) -> ConnectionManager { + pub fn connection_manager(&self) -> ConnectionManager { return self.state.connection_manager.clone(); } diff --git a/crates/core/src/auth/api/avatar.rs b/crates/core/src/auth/api/avatar.rs index c45a9f37..5c84af4b 100644 --- a/crates/core/src/auth/api/avatar.rs +++ b/crates/core/src/auth/api/avatar.rs @@ -139,8 +139,8 @@ pub async fn delete_avatar_handler( static ref SQL: String = format!("DELETE FROM {AVATAR_TABLE} WHERE user = ?1"); } - state - .conn() + let main_conn = state.connection_manager().main_entry().connection; + main_conn .execute(&*SQL, [rusqlite::types::Value::Blob(user.uuid.into())]) .await?; diff --git a/crates/core/src/connection.rs b/crates/core/src/connection.rs index bc2a494f..a43928ed 100644 --- a/crates/core/src/connection.rs +++ b/crates/core/src/connection.rs @@ -55,7 +55,7 @@ struct ConnectionKey { } #[derive(Clone)] -pub(crate) struct ConnectionEntry { +pub struct ConnectionEntry { pub connection: Arc, pub metadata: Arc, } @@ -157,7 +157,7 @@ impl ConnectionManager { // return self.state.main.read().connection.clone(); // } - pub(crate) fn main_entry(&self) -> ConnectionEntry { + pub fn main_entry(&self) -> ConnectionEntry { return self.state.main.read().clone(); } @@ -169,7 +169,7 @@ impl ConnectionManager { // return Ok(self.get_entry(main, attached_databases)?.connection); // } - pub(crate) fn get_entry( + pub fn get_entry( &self, main: bool, attached_databases: Option>, @@ -206,7 +206,7 @@ impl ConnectionManager { }; } - pub(crate) fn get_entry_for_qn( + pub fn get_entry_for_qn( &self, name: &trailbase_schema::QualifiedName, ) -> Result { diff --git a/crates/core/src/records/transaction.rs b/crates/core/src/records/transaction.rs index cfbebaf3..6a631290 100644 --- a/crates/core/src/records/transaction.rs +++ b/crates/core/src/records/transaction.rs @@ -1,6 +1,8 @@ use axum::extract::{Json, State}; use base64::prelude::*; use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use trailbase_schema::QualifiedName; use utoipa::ToSchema; use crate::app_state::AppState; @@ -56,19 +58,47 @@ pub async fn record_transactions_handler( user: Option, Json(request): Json, ) -> Result, RecordError> { + // NOTE: We may want to make this user-configurable. The cost also heavily depends on whether + // `request.transaction == true`. if request.operations.len() > 128 { return Err(RecordError::BadRequest("Transactions exceed limit: 128")); } + if request.operations.is_empty() { + return Ok(Json(TransactionResponse { ids: vec![] })); + } + type Op = dyn (FnOnce(&rusqlite::Connection) -> Result, RecordError>) + Send; + let mut db: (String, Option>) = Default::default(); + let mut get_api = + |state: &AppState, api_name: &str, idx: usize| -> Result { + let api = state + .lookup_record_api(api_name) + .ok_or_else(|| RecordError::ApiNotFound)?; + if !api.is_table() { + return Err(RecordError::ApiRequiresTable); + } + + // Check that all ops reference same DB. + let db_name = get_db_name(api.qualified_name()); + if idx == 0 { + db = (db_name.to_string(), Some(api.conn().clone())); + } else if db_name != db.0 { + return Err(RecordError::BadRequest("ops can only touch same db")); + } + + return Ok(api); + }; + let operations: Vec> = request .operations .into_iter() - .map(|op| -> Result, RecordError> { + .enumerate() + .map(|(idx, op)| -> Result, RecordError> { return match op { Operation::Create { api_name, value } => { - let api = get_api(&state, &api_name)?; + let api = get_api(&state, &api_name, idx)?; let mut record = extract_record(value)?; if api.insert_autofill_missing_user_id_columns() @@ -121,7 +151,7 @@ pub async fn record_transactions_handler( record_id, value, } => { - let api = get_api(&state, &api_name)?; + let api = get_api(&state, &api_name, idx)?; let record = extract_record(value)?; let record_id = api.primary_key_to_value(record_id)?; let (_index, pk_column) = api.record_pk_column(); @@ -163,7 +193,7 @@ pub async fn record_transactions_handler( api_name, record_id, } => { - let api = get_api(&state, &api_name)?; + let api = get_api(&state, &api_name, idx)?; let record_id = api.primary_key_to_value(record_id)?; let acl_check = api.build_record_level_access_check( @@ -190,9 +220,12 @@ pub async fn record_transactions_handler( }) .collect::, _>>()?; - let ids = if request.transaction.unwrap_or(true) { - state - .conn() + let conn = db + .1 + .ok_or_else(|| RecordError::Internal("missing db".into()))?; + + let ids = if request.transaction.unwrap_or(false) { + conn .call( move |conn: &mut rusqlite::Connection| -> Result, trailbase_sqlite::Error> { let tx = conn.transaction()?; @@ -211,8 +244,7 @@ pub async fn record_transactions_handler( ) .await? } else { - state - .conn() + conn .call( move |conn: &mut rusqlite::Connection| -> Result, trailbase_sqlite::Error> { let mut ids: Vec = vec![]; @@ -243,14 +275,8 @@ fn extract_record_id(value: rusqlite::types::Value) -> Result Result { - let Some(api) = state.lookup_record_api(api_name) else { - return Err(RecordError::ApiNotFound); - }; - if !api.is_table() { - return Err(RecordError::ApiRequiresTable); - } - return Ok(api); +fn get_db_name(name: &QualifiedName) -> &str { + return name.database_schema.as_deref().unwrap_or("main"); } #[inline] diff --git a/crates/core/src/server/mod.rs b/crates/core/src/server/mod.rs index 2b26a2c6..add49eb3 100644 --- a/crates/core/src/server/mod.rs +++ b/crates/core/src/server/mod.rs @@ -249,9 +249,13 @@ impl Server { // Re-apply migrations. This needs to happen before reloading the config, which is // consistent with the startup order. Otherwise, we may validate a configuration // against a stale database schema. + // + // TODO: Right now we're only re-applying main migrations. let user_migrations_path = state.data_dir().migrations_path(); match state - .conn() + .connection_manager() + .main_entry() + .connection .call(|conn: &mut rusqlite::Connection| { return crate::migrations::apply_main_migrations(conn, Some(user_migrations_path)) .map_err(|err| trailbase_sqlite::Error::Other(err.into())); diff --git a/crates/core/tests/integration_test.rs b/crates/core/tests/integration_test.rs index bb747c4d..68a17165 100644 --- a/crates/core/tests/integration_test.rs +++ b/crates/core/tests/integration_test.rs @@ -2,6 +2,7 @@ use axum::extract::{Json, State}; use axum::http::StatusCode; use axum_test::TestServer; use axum_test::multipart::MultipartForm; +use std::sync::Arc; use tower_cookies::Cookie; use trailbase_sqlite::params; @@ -12,7 +13,7 @@ use trailbase::constants::{COOKIE_AUTH_TOKEN, RECORD_API_PATH}; use trailbase::util::id_to_b64; use trailbase::{DataDir, Server, ServerOptions}; -pub(crate) async fn add_record_api_config( +async fn add_record_api_config( state: &AppState, api: RecordApiConfig, ) -> Result<(), anyhow::Error> { @@ -54,13 +55,13 @@ async fn test_record_apis() { assert!(admin_router.is_none()); assert!(tls.is_none()); - let conn = state.conn(); + let conn = state.connection_manager().main_entry().connection; let logs_conn = state.logs_conn(); - create_chat_message_app_tables(conn).await.unwrap(); + create_chat_message_app_tables(&conn).await.unwrap(); state.rebuild_connection_metadata().await.unwrap(); - let room = add_room(conn, "room0").await.unwrap(); + let room = add_room(&conn, "room0").await.unwrap(); let password = "Secret!1!!"; let client_ip = "22.11.22.11"; @@ -89,7 +90,7 @@ async fn test_record_apis() { .await .unwrap(); - add_user_to_room(conn, user_x, room).await.unwrap(); + add_user_to_room(&conn, user_x, room).await.unwrap(); #[allow(unused_mut)] let (_address, mut router) = main_router; @@ -261,8 +262,8 @@ async fn test_record_apis() { assert_eq!(status, 200); } -pub async fn create_chat_message_app_tables( - conn: &trailbase_sqlite::Connection, +async fn create_chat_message_app_tables( + conn: &Arc, ) -> Result<(), anyhow::Error> { // Create a messages, chat room and members tables. conn @@ -299,8 +300,8 @@ pub async fn create_chat_message_app_tables( return Ok(()); } -pub async fn add_room( - conn: &trailbase_sqlite::Connection, +async fn add_room( + conn: &Arc, name: &str, ) -> Result<[u8; 16], anyhow::Error> { let room: [u8; 16] = conn @@ -315,8 +316,8 @@ pub async fn add_room( return Ok(room); } -pub async fn add_user_to_room( - conn: &trailbase_sqlite::Connection, +async fn add_user_to_room( + conn: &Arc, user: [u8; 16], room: [u8; 16], ) -> Result<(), anyhow::Error> { @@ -329,7 +330,7 @@ pub async fn add_user_to_room( return Ok(()); } -pub(crate) async fn create_user_for_test( +async fn create_user_for_test( state: &AppState, email: &str, password: &str,