Back-port tb-sqlite improvements: support for ArcLock, improved connection setup, and removal of the no-longer-needed add_preupdate_hook.

This commit is contained in:
Sebastian Jeltsch
2025-08-25 11:35:26 +02:00
parent 528ed59f6d
commit bb91f3e11f
11 changed files with 145 additions and 117 deletions

1
Cargo.lock generated
View File

@@ -7607,6 +7607,7 @@ dependencies = [
"parking_lot",
"rand 0.9.2",
"rusqlite",
"self_cell",
"serde",
"serde_json",
"serde_rusqlite",

View File

@@ -61,6 +61,7 @@ askama = { version = "0.14.0", default-features = false, features = ["derive", "
axum = { version = "^0.8.1", features = ["multipart"] }
env_logger = { version = "^0.11.8", default-features = false, features = ["auto-color", "humantime"] }
libsqlite3-sys = { version = "0.35.0", default-features = false, features = ["bundled", "preupdate_hook"] }
parking_lot = { version = "0.12.3", default-features = false, features = ["send_guard", "arc_lock"] }
rusqlite = { version = "0.37.0", default-features = false, features = ["bundled", "column_decltype", "functions", "backup", "preupdate_hook"] }
rust-embed = { version = "8.4.0", default-features = false, features = ["mime-guess"] }
tokio = { version = "^1.38.0", default-features = false, features = ["macros", "net", "rt-multi-thread", "fs", "signal", "time", "sync"] }

View File

@@ -15,7 +15,7 @@ exclude = [
eventsource-stream = { version = "0.2.3", features = [] }
futures-lite = "2.6.1"
jsonwebtoken = { version = "9.3.0", default-features = false }
parking_lot = "0.12.3"
parking_lot = { workspace = true }
reqwest = { version = "0.12.8", features = ["stream"] }
serde = { version = "1.0.217", features = ["derive"] }
serde_json = "1.0.135"

View File

@@ -55,7 +55,7 @@ mini-moka = "0.10.3"
minijinja = { version = "2.1.2", default-features = false }
oauth2 = { version = "5.0.0-alpha.4", default-features = false, features = ["reqwest", "rustls-tls"] }
object_store = { version = "0.12.0", default-features = false, features = ["aws", "fs"] }
parking_lot = { version = "0.12.3", default-features = false }
parking_lot = { workspace = true }
pin-project-lite = "0.2.16"
prost = { version = "^0.14.1", default-features = false }
prost-reflect = { version = "^0.16.0", default-features = false, features = ["derive", "text-format"] }

View File

@@ -19,7 +19,7 @@ jsonschema = { version = "0.32.0", default-features = false }
log = "0.4.27"
maxminddb = "0.26.0"
mini-moka = "0.10.3"
parking_lot = { version = "0.12.3", default-features = false }
parking_lot = { workspace = true }
regex = "1.11.0"
rusqlite = { workspace = true }
serde = { version = "^1.0.203", features = ["derive"] }

View File

@@ -16,7 +16,7 @@ bytes = { version = "1.8.0", features = ["serde"] }
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
kanal = "0.1.1"
log = { version = "^0.4.21", default-features = false }
parking_lot = { version = "0.12.3", default-features = false }
parking_lot = { workspace = true }
rusqlite = { workspace = true }
rust-embed = { workspace = true }
rustyscript = { version = "^0.12.0", features = ["web", "fs"] }

View File

@@ -16,7 +16,7 @@ itertools = "0.14.0"
jsonschema = { version = "0.32.0", default-features = false }
lazy_static = "1.5.0"
log = { version = "^0.4.21", default-features = false }
parking_lot = { version = "0.12.3", default-features = false }
parking_lot = { workspace = true }
regex = "1.11.1"
rusqlite = { workspace = true }
schemars = "1.0.0"

View File

@@ -18,8 +18,9 @@ harness = false
crossbeam-channel = "0.5.13"
kanal = "0.1.1"
log = { version = "^0.4.21", default-features = false }
parking_lot = { version = "0.12.3", default-features = false }
parking_lot = { workspace = true }
rusqlite = { workspace = true }
self_cell = "1.2.0"
serde = { version = "^1.0.203", features = ["derive"] }
serde_json = "1.0.122"
serde_rusqlite = "0.40"

View File

@@ -2,7 +2,7 @@ use kanal::{Receiver, Sender};
use log::*;
use parking_lot::RwLock;
use rusqlite::fallible_iterator::FallibleIterator;
use rusqlite::hooks::{Action, PreUpdateCase};
use rusqlite::hooks::PreUpdateCase;
use rusqlite::types::Value;
use std::ops::{Deref, DerefMut};
use std::{
@@ -42,11 +42,13 @@ pub struct Database {
pub name: String,
}
struct LockedConnections(RwLock<Vec<rusqlite::Connection>>);
#[derive(Default)]
struct ConnectionVec(Vec<rusqlite::Connection>);
// NOTE: We must never access the same connection concurrently even as &Connection, due to
// Statement cache. We can ensure this by uniquely assigning one connection to each thread.
unsafe impl Sync for LockedConnections {}
// NOTE: We must never access the same connection concurrently even as immutable &Connection, due
// to intrinsic statement cache. We can ensure this by uniquely assigning one connection to each
// thread.
unsafe impl Sync for ConnectionVec {}
/// The result returned on method calls in this crate.
pub type Result<T> = std::result::Result<T, Error>;
@@ -77,7 +79,7 @@ impl Default for Options {
pub struct Connection {
reader: Sender<Message>,
writer: Sender<Message>,
conns: Arc<LockedConnections>,
conns: Arc<RwLock<ConnectionVec>>,
}
impl Connection {
@@ -93,62 +95,68 @@ impl Connection {
return Ok(conn);
};
let conn = new_conn()?;
let name = conn.path().and_then(|s| {
let write_conn = new_conn()?;
let in_memory = write_conn.path().map_or(true, |s| {
// Returns empty string for in-memory databases.
if s.is_empty() {
None
} else {
Some(s.to_string())
}
// An empty path is how rusqlite reports an in-memory (or temporary) database, so the
// database is in-memory exactly when the path is empty.
return s.is_empty();
});
let n_read_threads = if name.is_some() {
let n_read_threads = match opt.as_ref().map_or(0, |o| o.n_read_threads) {
1 => {
warn!(
"Using a single dedicated reader thread won't improve performance, falling back to 0."
);
0
}
n => n,
};
if let Ok(n) = std::thread::available_parallelism() {
if n_read_threads > n.get() {
debug!(
"Using {n_read_threads} exceeding hardware parallelism: {}",
n.get()
);
}
let n_read_threads: i64 = match (in_memory, opt.as_ref().map_or(0, |o| o.n_read_threads)) {
(true, _) => {
// We cannot share an in-memory database across threads, they're all independent.
0
}
(false, 1) => {
warn!("A single reader thread won't improve performance, falling back to 0.");
0
}
(false, n) => {
if let Ok(max) = std::thread::available_parallelism() {
if n > max.get() {
debug!(
"Num read threads '{n}' exceeds hardware parallelism: {}",
max.get()
);
}
}
n as i64
}
n_read_threads
} else {
// We cannot share an in-memory database across threads, they're all independent.
0
};
let conns = {
let mut conns = vec![conn];
for _ in 0..n_read_threads {
let conns = Arc::new(RwLock::new(ConnectionVec({
let mut conns = vec![write_conn];
for _ in 0..(n_read_threads - 1).max(0) {
conns.push(new_conn()?);
}
conns
})));
Arc::new(LockedConnections(RwLock::new(conns)))
};
assert_eq!(n_read_threads.max(1) as usize, conns.read().0.len());
// Spawn writer.
let (shared_write_sender, shared_write_receiver) = kanal::unbounded::<Message>();
let conns_clone = conns.clone();
std::thread::spawn(move || event_loop(0, conns_clone, shared_write_receiver));
{
let conns = conns.clone();
std::thread::Builder::new()
.name("tb-sqlite-writer".to_string())
.spawn(move || event_loop(0, conns, shared_write_receiver))
.expect("startup");
}
// Spawn readers.
let shared_read_sender = if n_read_threads > 0 {
let (shared_read_sender, shared_read_receiver) = kanal::unbounded::<Message>();
for i in 0..n_read_threads {
// NOTE: read and writer threads are sharing the first conn, given they're mutually
// exclusive.
let index = i as usize;
let shared_read_receiver = shared_read_receiver.clone();
let conns_clone = conns.clone();
std::thread::spawn(move || event_loop(i, conns_clone, shared_read_receiver));
let conns = conns.clone();
std::thread::Builder::new()
.name(format!("tb-sqlite-reader-{index}"))
.spawn(move || event_loop(index, conns, shared_read_receiver))
.expect("startup");
}
shared_read_sender
} else {
@@ -156,8 +164,8 @@ impl Connection {
};
debug!(
"Opened SQLite DB '{name}' with {n_read_threads} dedicated reader threads",
name = name.as_deref().unwrap_or("<in-memory>")
"Opened SQLite DB '{}' with {n_read_threads} reader threads",
conns.read().0[0].path().unwrap_or("<in-memory>")
);
return Ok(Self {
@@ -171,7 +179,7 @@ impl Connection {
use parking_lot::lock_api::RwLock;
let (shared_write_sender, shared_write_receiver) = kanal::unbounded::<Message>();
let conns = Arc::new(LockedConnections(RwLock::new(vec![conn])));
let conns = Arc::new(RwLock::new(ConnectionVec(vec![conn])));
let conns_clone = conns.clone();
std::thread::spawn(move || event_loop(0, conns_clone, shared_write_receiver));
@@ -194,7 +202,7 @@ impl Connection {
#[inline]
pub fn write_lock(&self) -> LockGuard<'_> {
return LockGuard {
guard: self.conns.0.write(),
guard: self.conns.write(),
};
}
@@ -202,11 +210,25 @@ impl Connection {
pub fn try_write_lock_for(&self, duration: tokio::time::Duration) -> Option<LockGuard<'_>> {
return self
.conns
.0
.try_write_for(duration)
.map(|guard| LockGuard { guard });
}
/// Takes the write lock on the shared connections via the owning `Arc`.
///
/// The returned [`ArcLockGuard`] has no lifetime parameter, i.e. it does not borrow from
/// `self`. It dereferences to the connection stored at index 0.
#[inline]
pub fn write_arc_lock(&self) -> ArcLockGuard {
  let guard = self.conns.write_arc();
  return ArcLockGuard { guard };
}
/// Tries to take the write lock on the shared connections via the owning `Arc`, waiting at
/// most `duration`.
///
/// Returns `None` when `try_write_arc_for` does not hand out a guard, otherwise an
/// [`ArcLockGuard`] that dereferences to the connection stored at index 0.
#[inline]
pub fn try_write_arc_lock_for(&self, duration: tokio::time::Duration) -> Option<ArcLockGuard> {
  let guard = self.conns.try_write_arc_for(duration)?;
  return Some(ArcLockGuard { guard });
}
/// Call a function in background thread and get the result
/// asynchronously.
///
@@ -476,22 +498,9 @@ impl Connection {
.await;
}
/// Convenience API for (un)setting a new pre-update hook.
pub async fn add_preupdate_hook(
&self,
hook: Option<impl Fn(Action, &str, &str, &PreUpdateCase) + Send + Sync + 'static>,
) -> Result<()> {
return self
.call(move |conn| {
conn.preupdate_hook(hook);
return Ok(());
})
.await;
}
pub fn attach(&self, path: &str, name: &str) -> Result<()> {
let lock = self.conns.0.write();
for conn in &*lock {
let lock = self.conns.write();
for conn in &lock.0 {
conn.execute(&format!("ATTACH DATABASE '{path}' AS {name} "), ())?;
}
return Ok(());
@@ -530,8 +539,8 @@ impl Connection {
}
let mut errors = vec![];
let conns: Vec<_> = std::mem::take(&mut self.conns.0.write());
for conn in conns {
let conns: ConnectionVec = std::mem::take(&mut self.conns.write());
for conn in conns.0 {
if let Err((_, err)) = conn.close() {
errors.push(err);
};
@@ -552,16 +561,16 @@ impl Debug for Connection {
}
}
fn event_loop(id: usize, conns: Arc<LockedConnections>, receiver: Receiver<Message>) {
fn event_loop(id: usize, conns: Arc<RwLock<ConnectionVec>>, receiver: Receiver<Message>) {
while let Ok(message) = receiver.recv() {
match message {
Message::RunConst(f) => {
let lock = conns.0.read();
f(&lock[id])
let lock = conns.read();
f(&lock.0[id])
}
Message::RunMut(f) => {
let mut lock = conns.0.write();
f(&mut lock[0])
let mut lock = conns.write();
f(&mut lock.0[0])
}
Message::Terminate => {
return;
@@ -615,21 +624,40 @@ pub fn extract_record_values(case: &PreUpdateCase) -> Option<Vec<Value>> {
}
pub struct LockGuard<'a> {
guard: parking_lot::RwLockWriteGuard<'a, Vec<rusqlite::Connection>>,
guard: parking_lot::RwLockWriteGuard<'a, ConnectionVec>,
}
impl Deref for LockGuard<'_> {
type Target = rusqlite::Connection;
#[inline]
fn deref(&self) -> &rusqlite::Connection {
return &self.guard.deref()[0];
return &self.guard.deref().0[0];
}
}
impl DerefMut for LockGuard<'_> {
#[inline]
fn deref_mut(&mut self) -> &mut rusqlite::Connection {
return &mut self.guard.deref_mut()[0];
return &mut self.guard.deref_mut().0[0];
}
}
pub struct ArcLockGuard {
guard: parking_lot::ArcRwLockWriteGuard<parking_lot::RawRwLock, ConnectionVec>,
}
impl Deref for ArcLockGuard {
type Target = rusqlite::Connection;
#[inline]
fn deref(&self) -> &rusqlite::Connection {
return &self.guard.deref().0[0];
}
}
impl DerefMut for ArcLockGuard {
#[inline]
fn deref_mut(&mut self) -> &mut rusqlite::Connection {
return &mut self.guard.deref_mut().0[0];
}
}

View File

@@ -125,7 +125,7 @@ pub(crate) fn columns(stmt: &Statement<'_>) -> Vec<Column> {
}
#[derive(Debug)]
pub struct Row(Vec<types::Value>, Arc<Vec<Column>>);
pub struct Row(pub Vec<types::Value>, pub Arc<Vec<Column>>);
impl Row {
pub(crate) fn from_row(row: &rusqlite::Row, cols: Arc<Vec<Column>>) -> rusqlite::Result<Self> {

View File

@@ -459,42 +459,39 @@ async fn test_hooks() {
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel::<String>();
let c = conn.clone();
conn
.add_preupdate_hook(Some(
move |action: rusqlite::hooks::Action, _db: &str, table_name: &str, case: &PreUpdateCase| {
let row_id = extract_row_id(case).unwrap();
let state = State {
action,
table_name: table_name.to_string(),
row_id,
conn.write_lock().preupdate_hook(Some(
move |action: rusqlite::hooks::Action, _db: &str, table_name: &str, case: &PreUpdateCase| {
let row_id = extract_row_id(case).unwrap();
let state = State {
action,
table_name: table_name.to_string(),
row_id,
};
let sender = sender.clone();
c.call_and_forget(move |conn| {
match state.action {
rusqlite::hooks::Action::SQLITE_INSERT => {
let text = conn
.query_row(
&format!(
r#"SELECT text FROM "{}" WHERE _rowid_ = $1"#,
state.table_name
),
[state.row_id],
|row| row.get::<_, String>(0),
)
.unwrap();
sender.send(text).unwrap();
}
_ => {
panic!("unexpected action: {:?}", state.action);
}
};
let sender = sender.clone();
c.call_and_forget(move |conn| {
match state.action {
rusqlite::hooks::Action::SQLITE_INSERT => {
let text = conn
.query_row(
&format!(
r#"SELECT text FROM "{}" WHERE _rowid_ = $1"#,
state.table_name
),
[state.row_id],
|row| row.get::<_, String>(0),
)
.unwrap();
sender.send(text).unwrap();
}
_ => {
panic!("unexpected action: {:?}", state.action);
}
};
});
},
))
.await
.unwrap();
});
},
));
conn
.execute("INSERT INTO test (id, text) VALUES (5, 'foo')", ())