mirror of
https://github.com/trailbaseio/trailbase.git
synced 2026-01-06 01:40:12 -06:00
Add a hooks API to trailbase_sqlite::Connection.
This commit is contained in:
@@ -38,4 +38,5 @@ rusqlite = { version = "^0.32.1", default-features = false, features = [
|
||||
"functions",
|
||||
"limits",
|
||||
"backup",
|
||||
"hooks",
|
||||
] }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crossbeam_channel::{Receiver, Sender};
|
||||
pub use rusqlite::types::{ToSqlOutput, Value};
|
||||
use rusqlite::hooks::Action;
|
||||
use std::{
|
||||
fmt::{self, Debug},
|
||||
sync::Arc,
|
||||
@@ -35,9 +35,11 @@ macro_rules! named_params {
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
type CallFn = Box<dyn FnOnce(&mut rusqlite::Connection) + Send + 'static>;
|
||||
type HookFn = Arc<dyn Fn(&rusqlite::Connection, Action, &str, &str, i64) + Send + Sync + 'static>;
|
||||
|
||||
enum Message {
|
||||
Run(CallFn),
|
||||
ExecuteHook(HookFn, Action, String, String, i64),
|
||||
Close(oneshot::Sender<std::result::Result<(), rusqlite::Error>>),
|
||||
}
|
||||
|
||||
@@ -49,7 +51,9 @@ pub struct Connection {
|
||||
|
||||
impl Connection {
|
||||
pub fn from_conn(conn: rusqlite::Connection) -> Result<Self> {
|
||||
return Ok(start(conn));
|
||||
let (sender, receiver) = crossbeam_channel::unbounded::<Message>();
|
||||
std::thread::spawn(move || event_loop(conn, receiver));
|
||||
return Ok(Self { sender });
|
||||
}
|
||||
|
||||
/// Open a new connection to an in-memory SQLite database.
|
||||
@@ -58,8 +62,7 @@ impl Connection {
|
||||
///
|
||||
/// Will return `Err` if the underlying SQLite open call fails.
|
||||
pub fn open_in_memory() -> Result<Self> {
|
||||
let conn = rusqlite::Connection::open_in_memory()?;
|
||||
return Ok(start(conn));
|
||||
return Self::from_conn(rusqlite::Connection::open_in_memory()?);
|
||||
}
|
||||
|
||||
/// Call a function in background thread and get the result
|
||||
@@ -201,6 +204,41 @@ impl Connection {
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn add_hook(
|
||||
&self,
|
||||
f: impl Fn(&rusqlite::Connection, Action, &str, &str, i64) + Send + Sync + 'static,
|
||||
) -> Result<()> {
|
||||
let sender = self.sender.clone();
|
||||
let f = Arc::new(f);
|
||||
|
||||
return self
|
||||
.call(|conn| {
|
||||
conn.update_hook(Some(
|
||||
move |action: Action, db: &str, table: &str, row: i64| {
|
||||
let _ = sender.send(Message::ExecuteHook(
|
||||
f.clone(),
|
||||
action,
|
||||
db.to_string(),
|
||||
table.to_string(),
|
||||
row,
|
||||
));
|
||||
},
|
||||
));
|
||||
|
||||
return Ok(());
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn remove_hook(&self) -> Result<()> {
|
||||
return self
|
||||
.call(|conn| {
|
||||
conn.update_hook(None::<fn(Action, &str, &str, i64)>);
|
||||
return Ok(());
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Close the database connection.
|
||||
///
|
||||
/// This is functionally equivalent to the `Drop` implementation for
|
||||
@@ -226,15 +264,13 @@ impl Connection {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let result = receiver.await;
|
||||
|
||||
if result.is_err() {
|
||||
let Ok(result) = receiver.await else {
|
||||
// If we get a RecvError at this point, it also means the channel closed in the meantime
|
||||
// we can assume the connection is closed
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
result.unwrap().map_err(|e| Error::Close(self, e))
|
||||
return result.map_err(|e| Error::Close(self, e));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -244,24 +280,17 @@ impl Debug for Connection {
|
||||
}
|
||||
}
|
||||
|
||||
fn start(conn: rusqlite::Connection) -> Connection
|
||||
// F: FnOnce() -> rusqlite::Result<rusqlite::Connection> + Send + 'static,
|
||||
{
|
||||
let (sender, receiver) = crossbeam_channel::unbounded::<Message>();
|
||||
|
||||
std::thread::spawn(move || event_loop(conn, receiver));
|
||||
|
||||
return Connection { sender };
|
||||
}
|
||||
|
||||
fn event_loop(mut conn: rusqlite::Connection, receiver: Receiver<Message>) {
|
||||
const BUG_TEXT: &str = "bug in trailbase-sqlite, please report";
|
||||
|
||||
while let Ok(message) = receiver.recv() {
|
||||
match message {
|
||||
Message::Run(f) => f(&mut conn),
|
||||
Message::Close(s) => {
|
||||
Message::ExecuteHook(f, action, db, table, row) => f(&conn, action, &db, &table, row),
|
||||
Message::Close(ch) => {
|
||||
match conn.close() {
|
||||
Ok(v) => s.send(Ok(v)).expect(BUG_TEXT),
|
||||
Err((_conn, e)) => s.send(Err(e)).expect(BUG_TEXT),
|
||||
Ok(v) => ch.send(Ok(v)).expect(BUG_TEXT),
|
||||
Err((_conn, e)) => ch.send(Err(e)).expect(BUG_TEXT),
|
||||
};
|
||||
|
||||
return;
|
||||
@@ -270,8 +299,6 @@ fn event_loop(mut conn: rusqlite::Connection, receiver: Receiver<Message>) {
|
||||
}
|
||||
}
|
||||
|
||||
const BUG_TEXT: &str = "bug in trailbase-sqlite, please report";
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests.rs"]
|
||||
mod tests;
|
||||
|
||||
@@ -19,8 +19,9 @@ pub mod params;
|
||||
mod rows;
|
||||
pub mod schema;
|
||||
|
||||
pub use connection::{Connection, Value};
|
||||
pub use connection::Connection;
|
||||
pub use error::Error;
|
||||
pub use extension::connect_sqlite;
|
||||
pub use params::Params;
|
||||
pub use rows::{Row, Rows, ValueType};
|
||||
pub use rusqlite::types::Value;
|
||||
|
||||
@@ -313,6 +313,43 @@ async fn test_params() {
|
||||
assert_eq!(rows.0.get(0).unwrap().get::<i64>(0), Ok(4));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hooks() {
|
||||
let conn = Connection::open_in_memory().unwrap();
|
||||
|
||||
conn
|
||||
.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, text TEXT)", ())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel::<String>();
|
||||
conn
|
||||
.add_hook(move |c, action, _db, table, row_id| match action {
|
||||
rusqlite::hooks::Action::SQLITE_INSERT => {
|
||||
let text = c
|
||||
.query_row(
|
||||
&format!(r#"SELECT text FROM "{table}" WHERE _rowid_ = $1"#),
|
||||
[row_id],
|
||||
|row| row.get::<_, String>(0),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
sender.send(text).unwrap();
|
||||
}
|
||||
_ => {}
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
conn
|
||||
.execute("INSERT INTO test (id, text) VALUES (5, 'foo')", ())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let text = receiver.recv().await.unwrap();
|
||||
assert_eq!(text, "foo");
|
||||
}
|
||||
|
||||
// The rest is boilerplate, not really that important
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
enum MyError {
|
||||
|
||||
Reference in New Issue
Block a user