Add a hooks API to trailbase_sqlite::Connection.

This commit is contained in:
Sebastian Jeltsch
2024-12-16 17:07:17 +01:00
parent 746f0c1108
commit 18d1d87710
4 changed files with 91 additions and 25 deletions

View File

@@ -38,4 +38,5 @@ rusqlite = { version = "^0.32.1", default-features = false, features = [
"functions",
"limits",
"backup",
"hooks",
] }

View File

@@ -1,5 +1,5 @@
use crossbeam_channel::{Receiver, Sender};
pub use rusqlite::types::{ToSqlOutput, Value};
use rusqlite::hooks::Action;
use std::{
fmt::{self, Debug},
sync::Arc,
@@ -35,9 +35,11 @@ macro_rules! named_params {
pub type Result<T> = std::result::Result<T, Error>;
type CallFn = Box<dyn FnOnce(&mut rusqlite::Connection) + Send + 'static>;
type HookFn = Arc<dyn Fn(&rusqlite::Connection, Action, &str, &str, i64) + Send + Sync + 'static>;
enum Message {
Run(CallFn),
ExecuteHook(HookFn, Action, String, String, i64),
Close(oneshot::Sender<std::result::Result<(), rusqlite::Error>>),
}
@@ -49,7 +51,9 @@ pub struct Connection {
impl Connection {
pub fn from_conn(conn: rusqlite::Connection) -> Result<Self> {
return Ok(start(conn));
let (sender, receiver) = crossbeam_channel::unbounded::<Message>();
std::thread::spawn(move || event_loop(conn, receiver));
return Ok(Self { sender });
}
/// Open a new connection to an in-memory SQLite database.
@@ -58,8 +62,7 @@ impl Connection {
///
/// Will return `Err` if the underlying SQLite open call fails.
pub fn open_in_memory() -> Result<Self> {
let conn = rusqlite::Connection::open_in_memory()?;
return Ok(start(conn));
return Self::from_conn(rusqlite::Connection::open_in_memory()?);
}
/// Call a function in background thread and get the result
@@ -201,6 +204,41 @@ impl Connection {
.await;
}
pub async fn add_hook(
&self,
f: impl Fn(&rusqlite::Connection, Action, &str, &str, i64) + Send + Sync + 'static,
) -> Result<()> {
let sender = self.sender.clone();
let f = Arc::new(f);
return self
.call(|conn| {
conn.update_hook(Some(
move |action: Action, db: &str, table: &str, row: i64| {
let _ = sender.send(Message::ExecuteHook(
f.clone(),
action,
db.to_string(),
table.to_string(),
row,
));
},
));
return Ok(());
})
.await;
}
pub async fn remove_hook(&self) -> Result<()> {
return self
.call(|conn| {
conn.update_hook(None::<fn(Action, &str, &str, i64)>);
return Ok(());
})
.await;
}
/// Close the database connection.
///
/// This is functionally equivalent to the `Drop` implementation for
@@ -226,15 +264,13 @@ impl Connection {
return Ok(());
}
let result = receiver.await;
if result.is_err() {
let Ok(result) = receiver.await else {
// If we get a RecvError at this point, it also means the channel closed in the meantime
// we can assume the connection is closed
return Ok(());
}
};
result.unwrap().map_err(|e| Error::Close(self, e))
return result.map_err(|e| Error::Close(self, e));
}
}
@@ -244,24 +280,17 @@ impl Debug for Connection {
}
}
fn start(conn: rusqlite::Connection) -> Connection
// F: FnOnce() -> rusqlite::Result<rusqlite::Connection> + Send + 'static,
{
let (sender, receiver) = crossbeam_channel::unbounded::<Message>();
std::thread::spawn(move || event_loop(conn, receiver));
return Connection { sender };
}
fn event_loop(mut conn: rusqlite::Connection, receiver: Receiver<Message>) {
const BUG_TEXT: &str = "bug in trailbase-sqlite, please report";
while let Ok(message) = receiver.recv() {
match message {
Message::Run(f) => f(&mut conn),
Message::Close(s) => {
Message::ExecuteHook(f, action, db, table, row) => f(&conn, action, &db, &table, row),
Message::Close(ch) => {
match conn.close() {
Ok(v) => s.send(Ok(v)).expect(BUG_TEXT),
Err((_conn, e)) => s.send(Err(e)).expect(BUG_TEXT),
Ok(v) => ch.send(Ok(v)).expect(BUG_TEXT),
Err((_conn, e)) => ch.send(Err(e)).expect(BUG_TEXT),
};
return;
@@ -270,8 +299,6 @@ fn event_loop(mut conn: rusqlite::Connection, receiver: Receiver<Message>) {
}
}
const BUG_TEXT: &str = "bug in trailbase-sqlite, please report";
#[cfg(test)]
#[path = "tests.rs"]
mod tests;

View File

@@ -19,8 +19,9 @@ pub mod params;
mod rows;
pub mod schema;
pub use connection::{Connection, Value};
pub use connection::Connection;
pub use error::Error;
pub use extension::connect_sqlite;
pub use params::Params;
pub use rows::{Row, Rows, ValueType};
pub use rusqlite::types::Value;

View File

@@ -313,6 +313,43 @@ async fn test_params() {
assert_eq!(rows.0.get(0).unwrap().get::<i64>(0), Ok(4));
}
#[tokio::test]
async fn test_hooks() {
let conn = Connection::open_in_memory().unwrap();
conn
.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, text TEXT)", ())
.await
.unwrap();
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel::<String>();
conn
.add_hook(move |c, action, _db, table, row_id| match action {
rusqlite::hooks::Action::SQLITE_INSERT => {
let text = c
.query_row(
&format!(r#"SELECT text FROM "{table}" WHERE _rowid_ = $1"#),
[row_id],
|row| row.get::<_, String>(0),
)
.unwrap();
sender.send(text).unwrap();
}
_ => {}
})
.await
.unwrap();
conn
.execute("INSERT INTO test (id, text) VALUES (5, 'foo')", ())
.await
.unwrap();
let text = receiver.recv().await.unwrap();
assert_eq!(text, "foo");
}
// The rest is boilerplate, not really that important
#[derive(Debug, thiserror::Error)]
enum MyError {