Add SQLite transaction support to JS runtime.

This commit is contained in:
Sebastian Jeltsch
2025-05-03 13:58:05 +02:00
parent 4e645ee487
commit aa1a767320
6 changed files with 348 additions and 35 deletions

8
Cargo.lock generated
View File

@@ -5398,6 +5398,12 @@ dependencies = [
"libc",
]
[[package]]
name = "self_cell"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f7d95a54511e0c7be3f51e8867aa8cf35148d7b9445d44de2f943e2b206e749"
[[package]]
name = "semver"
version = "0.9.0"
@@ -6876,8 +6882,10 @@ dependencies = [
"kanal",
"log",
"parking_lot",
"rusqlite",
"rust-embed",
"rustyscript",
"self_cell",
"serde",
"serde_json",
"tokio",

View File

@@ -7,8 +7,10 @@ import {
htmlHandler,
jsonHandler,
stringHandler,
transaction,
HttpError,
StatusCodes,
Transaction,
} from "../trailbase.js";
import type {
JsonRequestType,
@@ -24,7 +26,7 @@ addRoute(
const table = uri.query.get("table");
if (table) {
const rows = await query(`SELECT COUNT(*) FROM ${table}`, []);
const rows = await query(`SELECT COUNT(*) FROM "${table}"`, []);
return `entries: ${rows[0][0]}`;
}
@@ -38,7 +40,7 @@ addRoute(
stringHandler(async (req: StringRequestType) => {
const table = req.params["table"];
if (table) {
const rows = await query(`SELECT COUNT(*) FROM ${table}`, []);
const rows = await query(`SELECT COUNT(*) FROM "${table}"`, []);
return `entries: ${rows[0][0]}`;
}
@@ -46,6 +48,24 @@ addRoute(
}),
);
addRoute(
"GET",
"/tx/{table}",
stringHandler(async (req: StringRequestType) => {
const table = req.params["table"];
if (table) {
const count = transaction((tx: Transaction) => {
const rows = tx.query(`SELECT COUNT(*) FROM "${table}"`, []);
return rows[0][0] as number;
});
return `entries: ${count}`;
}
return `test: ${req.uri}`;
}),
);
addRoute(
"GET",
"/html",

View File

@@ -17,8 +17,10 @@ futures-util = { version = "0.3", default-features = false, features = ["alloc"]
kanal = "0.1.1"
log = { version = "^0.4.21", default-features = false }
parking_lot = { version = "0.12.3", default-features = false }
rusqlite = { workspace = true }
rust-embed = { workspace = true }
rustyscript = { version = "^0.11.0", features = ["web", "fs"] }
self_cell = "1.2.0"
serde = { version = "^1.0.203", features = ["derive"] }
serde_json = "^1.0.117"
tokio = { workspace = true }

View File

@@ -28,6 +28,8 @@ export {
parsePath,
query,
stringHandler,
transaction,
Transaction,
} from "./trailbase";
export type {

View File

@@ -703,18 +703,57 @@ export function addPeriodicCallback(
/// Queries the SQLite database.
export async function query(
queryStr: string,
sql: string,
params: unknown[],
): Promise<unknown[][]> {
return await rustyscript.async_functions.query(queryStr, params);
return await rustyscript.async_functions.query(sql, params);
}
/// Executes given query against SQLite database.
export async function execute(
queryStr: string,
params: unknown[],
): Promise<number> {
return await rustyscript.async_functions.execute(queryStr, params);
export async function execute(sql: string, params: unknown[]): Promise<number> {
return await rustyscript.async_functions.execute(sql, params);
}
export class Transaction {
finalized: boolean;
constructor() {
this.finalized = false;
}
public query(queryStr: string, params: unknown[]): unknown[][] {
return rustyscript.functions.transaction_query(queryStr, params);
}
public execute(queryStr: string, params: unknown[]): number {
return rustyscript.functions.transaction_execute(queryStr, params);
}
public commit(): void {
this.finalized = true;
rustyscript.functions.transaction_commit();
}
public rollback(): void {
this.finalized = true;
rustyscript.functions.transaction_rollback();
}
}
export function transaction<T>(f: (tx: Transaction) => T): T {
rustyscript.functions.transaction_begin();
const tx = new Transaction();
try {
const r = f(tx);
if (!tx.finalized) {
rustyscript.functions.transaction_rollback();
}
return r;
} catch (e) {
rustyscript.functions.transaction_rollback();
throw e;
}
}
export type ParsedPath = {

View File

@@ -1,13 +1,18 @@
use futures_util::future::LocalBoxFuture;
use log::*;
use parking_lot::Mutex;
use rusqlite::Transaction;
use rustyscript::{deno_core::PollEventLoopOptions, init_platform, js_value::Promise};
use self_cell::{MutBorrow, self_cell};
use serde::Serialize;
use std::collections::HashSet;
use std::path::Path;
use std::rc::Rc;
use std::sync::OnceLock;
use tokio::sync::oneshot;
use tokio::time::Duration;
use tracing_subscriber::prelude::*;
use trailbase_sqlite::Params;
use trailbase_sqlite::connection::LockGuard;
use trailbase_sqlite::rows::{JsonError, row_to_json_array};
use crate::JsRuntimeAssets;
@@ -154,6 +159,8 @@ impl RuntimeState {
let handle = if n_threads > 0 {
Some(std::thread::spawn(move || {
use tracing_subscriber::prelude::*;
// swc_ecma_codegen is very spammy (or at least used to be):
// https://github.com/swc-project/swc/pull/9604
tracing_subscriber::Registry::default()
@@ -172,7 +179,7 @@ impl RuntimeState {
let shared_receiver = shared_receiver.clone();
return std::thread::spawn(move || {
let tokio_runtime = std::rc::Rc::new(
let tokio_runtime = Rc::new(
tokio::runtime::Builder::new_current_thread()
.enable_time()
.enable_io()
@@ -213,7 +220,7 @@ impl RuntimeState {
fn init_runtime(
index: usize,
tokio_runtime: std::rc::Rc<tokio::runtime::Runtime>,
tokio_runtime: Rc<tokio::runtime::Runtime>,
) -> Result<Runtime, AnyError> {
let mut runtime = rustyscript::Runtime::with_tokio_runtime(
rustyscript::RuntimeOptions {
@@ -416,10 +423,36 @@ impl RuntimeHandle {
}
}
self_cell!(
struct OwnedLock {
owner: trailbase_sqlite::Connection,
#[covariant]
dependent: LockGuard,
}
);
struct OwnedLockWrapper(OwnedLock);
self_cell!(
struct OwnedTransaction {
owner: MutBorrow<OwnedLockWrapper>,
#[covariant]
dependent: Transaction,
}
);
pub fn register_database_functions(handle: &RuntimeHandle, conn: trailbase_sqlite::Connection) {
fn error_mapper(err: impl std::error::Error) -> Error {
return Error::Runtime(err.to_string());
}
fn register(runtime: &mut Runtime, conn: trailbase_sqlite::Connection) -> Result<(), Error> {
let conn_clone = conn.clone();
runtime.register_async_function("query", move |args: Vec<serde_json::Value>| {
assert_eq!(args.len(), 2);
let conn = conn_clone.clone();
Box::pin(async move {
let query: String = get_arg(&args, 0)?;
@@ -428,7 +461,7 @@ pub fn register_database_functions(handle: &RuntimeHandle, conn: trailbase_sqlit
let rows = conn
.write_query_rows(query, params)
.await
.map_err(|err| rustyscript::Error::Runtime(err.to_string()))?;
.map_err(error_mapper)?;
let values = rows
.iter()
@@ -436,7 +469,7 @@ pub fn register_database_functions(handle: &RuntimeHandle, conn: trailbase_sqlit
return Ok(serde_json::Value::Array(row_to_json_array(row)?));
})
.collect::<Result<Vec<_>, _>>()
.map_err(|err| rustyscript::Error::Runtime(err.to_string()))?;
.map_err(error_mapper)?;
return Ok(serde_json::Value::Array(values));
})
@@ -444,20 +477,123 @@ pub fn register_database_functions(handle: &RuntimeHandle, conn: trailbase_sqlit
let conn_clone = conn.clone();
runtime.register_async_function("execute", move |args: Vec<serde_json::Value>| {
assert_eq!(args.len(), 2);
let conn = conn_clone.clone();
Box::pin(async move {
let query: String = get_arg(&args, 0)?;
let params = json_values_to_params(get_arg(&args, 1)?)?;
let rows_affected = conn
.execute(query, params)
.await
.map_err(|err| rustyscript::Error::Runtime(err.to_string()))?;
let rows_affected = conn.execute(query, params).await.map_err(error_mapper)?;
return Ok(serde_json::Value::Number(rows_affected.into()));
})
})?;
let current_transaction: Rc<Mutex<Option<OwnedTransaction>>> = Rc::new(Mutex::new(None));
let current_transaction_clone = current_transaction.clone();
runtime.register_function("transaction_begin", move |args: &[serde_json::Value]| {
assert_eq!(args.len(), 0);
assert!(current_transaction_clone.lock().is_none());
let lock = OwnedLock::new(conn.clone(), |owner| owner.write_lock());
let tx = OwnedTransaction::try_new(MutBorrow::new(OwnedLockWrapper(lock)), |lock| {
lock
.borrow_mut()
.0
.with_dependent_mut(|_owner, depdendent| depdendent.transaction())
})
.map_err(error_mapper)?;
*current_transaction_clone.lock() = Some(tx);
return Ok(serde_json::Value::Null);
})?;
let current_transaction_clone = current_transaction.clone();
runtime.register_function("transaction_query", move |args: &[serde_json::Value]| {
assert_eq!(args.len(), 2);
let query: String = get_arg(args, 0)?;
let params = json_values_to_params(get_arg(args, 1)?)?;
let tx = current_transaction_clone.lock();
if let Some(tx) = &*tx {
let mut stmt = tx
.borrow_dependent()
.prepare(&query)
.map_err(error_mapper)?;
params.bind(&mut stmt).map_err(error_mapper)?;
let rows =
trailbase_sqlite::rows::Rows::from_rows(stmt.raw_query()).map_err(error_mapper)?;
let values = rows
.iter()
.map(|row| -> Result<serde_json::Value, JsonError> {
return Ok(serde_json::Value::Array(row_to_json_array(row)?));
})
.collect::<Result<Vec<_>, _>>()
.map_err(error_mapper)?;
return Ok(serde_json::Value::Array(values));
}
return Ok(serde_json::Value::Null);
})?;
let current_transaction_clone = current_transaction.clone();
runtime.register_function(
"transaction_execute",
move |args: &[serde_json::Value]| {
assert_eq!(args.len(), 2);
let query: String = get_arg(args, 0)?;
let params = json_values_to_params(get_arg(args, 1)?)?;
let tx = current_transaction_clone.lock();
if let Some(tx) = &*tx {
let mut stmt = tx
.borrow_dependent()
.prepare(&query)
.map_err(error_mapper)?;
params.bind(&mut stmt).map_err(error_mapper)?;
let rows_affected = stmt.raw_execute().map_err(error_mapper)?;
return Ok(serde_json::Value::Number(rows_affected.into()));
}
return Ok(serde_json::Value::Null);
},
)?;
let current_transaction_clone = current_transaction.clone();
runtime.register_function("transaction_commit", move |args: &[serde_json::Value]| {
assert_eq!(args.len(), 0);
let tx = current_transaction_clone.lock().take();
if let Some(tx) = tx {
// NOTE: this is the same as `tx.commit()` just w/o consuming.
tx.borrow_dependent()
.execute_batch("COMMIT")
.map_err(error_mapper)?;
}
return Ok(serde_json::Value::Null);
})?;
let current_transaction_clone = current_transaction.clone();
runtime.register_function(
"transaction_rollback",
move |args: &[serde_json::Value]| {
assert_eq!(args.len(), 0);
let tx = current_transaction_clone.lock().take();
if let Some(tx) = tx {
// NOTE: this is the same as `tx.rollback()` just w/o consuming.
tx.borrow_dependent()
.execute_batch("ROLLBACK")
.map_err(error_mapper)?;
}
return Ok(serde_json::Value::Null);
},
)?;
return Ok(());
}
@@ -566,10 +702,16 @@ pub async fn write_js_runtime_files(data_dir: impl AsRef<Path>) {
#[cfg(test)]
mod tests {
use super::*;
use rustyscript::Module;
use tracing_subscriber::prelude::*;
#[tokio::test]
async fn test_serial_tests() {
tracing_subscriber::Registry::default()
.with(tracing_subscriber::filter::LevelFilter::WARN)
.set_default();
// Run on a single thread to make sure that any potential blocking is maximally bad.
let handle = RuntimeHandle::singleton_or_init_with_threads(1);
@@ -580,6 +722,7 @@ mod tests {
test_runtime_javascript_blocking(&handle).await;
test_javascript_query(&handle).await;
test_javascript_execute(&handle).await;
test_javascript_transaction(&handle).await;
}
async fn test_runtime_apply(handle: &RuntimeHandle) {
@@ -602,9 +745,6 @@ mod tests {
}
async fn test_runtime_javascript(handle: &RuntimeHandle) {
tracing_subscriber::Registry::default()
.with(tracing_subscriber::filter::LevelFilter::WARN)
.set_default();
let module = Module::new(
"module.js",
r#"
@@ -631,10 +771,6 @@ mod tests {
}
async fn test_runtime_javascript_blocking(handle: &RuntimeHandle) {
tracing_subscriber::Registry::default()
.with(tracing_subscriber::filter::LevelFilter::WARN)
.set_default();
let (ext_sender, ext_receiver) = kanal::bounded_async::<i64>(10);
{
// Register custom functions.
@@ -723,20 +859,19 @@ mod tests {
async fn test_javascript_query(handle: &RuntimeHandle) {
let conn = trailbase_sqlite::Connection::open_in_memory().unwrap();
conn
.execute("CREATE TABLE test (v0 TEXT, v1 INTEGER);", ())
.execute("CREATE TABLE 'table' (v0 TEXT, v1 INTEGER);", ())
.await
.unwrap();
conn
.execute("INSERT INTO test (v0, v1) VALUES ('0', 0), ('1', 1);", ())
.execute(
"INSERT INTO 'table' (v0, v1) VALUES ('0', 0), ('1', 1);",
(),
)
.await
.unwrap();
register_database_functions(&handle, conn);
tracing_subscriber::Registry::default()
.with(tracing_subscriber::filter::LevelFilter::WARN)
.set_default();
let module = Module::new(
"module.ts",
r#"
@@ -755,7 +890,7 @@ mod tests {
>(
Some(module),
"test_query",
vec![serde_json::json!("SELECT * FROM test")],
vec![serde_json::json!("SELECT * FROM 'table'")],
sender,
))
.await
@@ -792,9 +927,6 @@ mod tests {
register_database_functions(&handle, conn.clone());
tracing_subscriber::Registry::default()
.with(tracing_subscriber::filter::LevelFilter::WARN)
.set_default();
let module = Module::new(
"module.ts",
r#"
@@ -827,4 +959,114 @@ mod tests {
.unwrap();
assert_eq!(0, count);
}
async fn test_javascript_transaction(handle: &RuntimeHandle) {
let conn = trailbase_sqlite::Connection::open_in_memory().unwrap();
conn
.execute_batch(
r#"
CREATE TABLE 'table' (
v0 TEXT NOT NULL,
v1 INTEGER NOT NULL
);
INSERT INTO 'table' (v0, v1) VALUES ('foo', 5), ('bar', 3);
"#,
)
.await
.unwrap();
register_database_functions(&handle, conn.clone());
{
// Check that the rolled back transaction would delete 2 rows but deletes none.
let module = Module::new(
"module.ts",
r#"
import { transaction, Transaction } from "trailbase:main";
export function test_transaction_rollback() : number {
return transaction((tx: Transaction) => {
const n = tx.execute("DELETE FROM 'table' WHERE TRUE", []);
tx.rollback();
return n;
});
}
"#,
);
let (sender, receiver) = oneshot::channel();
handle
.send_to_any_isolate(build_call_sync_js_function_message::<i64>(
Some(module),
"test_transaction_rollback",
Vec::<serde_json::Value>::new(),
sender,
))
.await
.unwrap();
let rows_affected = receiver.await.unwrap().unwrap();
assert_eq!(2, rows_affected);
let count: i64 = conn
.query_row_f("SELECT COUNT(*) FROM 'table'", (), |row| row.get(0))
.await
.unwrap()
.unwrap();
assert_eq!(2, count);
}
{
// Check that the committed transaction takes effect
let module = Module::new(
"module.ts",
r#"
import { transaction, Transaction } from "trailbase:main";
export function test_transaction_commit() : [number, number] {
return transaction((tx: Transaction) => {
const count = tx.query("SELECT COUNT(*) FROM 'table'", [])[0][0];
const inserted = tx.execute("INSERT INTO 'table' (v0, v1) VALUES (?1, ?2)", ["baz", "7"]);
tx.commit();
return [count, inserted];
});
}
"#,
);
// Check that the rolled back transaction would delete 2 rows but deletes none.
let (sender, receiver) = oneshot::channel();
handle
.send_to_any_isolate(build_call_sync_js_function_message::<Vec<i64>>(
Some(module),
"test_transaction_commit",
Vec::<serde_json::Value>::new(),
sender,
))
.await
.unwrap();
let result = receiver.await.unwrap().unwrap();
assert_eq!(2, result.len());
assert_eq!(2, result[0]);
assert_eq!(1, result[1]);
let count: i64 = conn
.query_row_f("SELECT COUNT(*) FROM 'table'", (), |row| row.get(0))
.await
.unwrap()
.unwrap();
assert_eq!(3, count);
let v0: String = conn
.query_row_f(
"SELECT v0 FROM 'table' WHERE v1 = ?1",
trailbase_sqlite::params!(7),
|row| row.get(0),
)
.await
.unwrap()
.unwrap();
assert_eq!("baz", v0);
}
}
}