mirror of
https://github.com/trailbaseio/trailbase.git
synced 2026-05-23 02:28:34 -05:00
Support named params and alternate placeholder syntax ?\d.
This commit is contained in:
Generated
+1
@@ -7046,6 +7046,7 @@ dependencies = [
|
||||
"pgrow2serde",
|
||||
"postgres",
|
||||
"rand 0.10.1",
|
||||
"regex",
|
||||
"rusqlite",
|
||||
"serde",
|
||||
"serde_rusqlite",
|
||||
|
||||
@@ -21,18 +21,18 @@ harness = false
|
||||
|
||||
[features]
|
||||
default = ["pg"]
|
||||
pg = ["dep:postgres", "dep:bytes", "dep:pgrow2serde"]
|
||||
pg = ["dep:postgres", "dep:bytes", "dep:pgrow2serde", "dep:regex"]
|
||||
|
||||
[dependencies]
|
||||
bytes = { version = "1.11.1", optional = true }
|
||||
flume = { workspace = true }
|
||||
log = { version = "^0.4.21", default-features = false }
|
||||
parking_lot = { workspace = true }
|
||||
pgrow2serde = { git = "https://github.com/deknowny/pgrow2serde", optional = true }
|
||||
postgres = { version = "0.19.12", optional = true, features = ["with-serde_json-1"] }
|
||||
regex = { version = "1.12.3", optional = true}
|
||||
rusqlite = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
# serde-pgrow = { version = "0.3.6", optional = true }
|
||||
pgrow2serde = { git = "https://github.com/deknowny/pgrow2serde", optional = true }
|
||||
serde_rusqlite = { workspace = true }
|
||||
smallvec = { version = "1.15.1", features = ["const_generics"] }
|
||||
thiserror = "2.0.12"
|
||||
|
||||
@@ -668,7 +668,7 @@ mod tests {
|
||||
|
||||
let row = conn
|
||||
.read_query_row(
|
||||
"SELECT COUNT(*) FROM test_table_poc_generic WHERE data = $1",
|
||||
"SELECT COUNT(*) FROM test_table_poc_generic WHERE data = ?1",
|
||||
("a".to_string(),),
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -5,8 +5,7 @@ use tokio::sync::oneshot;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::params::Params;
|
||||
use crate::pg::util::bind;
|
||||
use crate::value::Value;
|
||||
use crate::pg::util::PgStatement;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct Options {
|
||||
@@ -166,8 +165,8 @@ impl Executor {
|
||||
{
|
||||
return self
|
||||
.call(move |conn: &mut postgres::Client| {
|
||||
let params: Vec<Value> = bind(sql.as_ref(), params)?;
|
||||
return f(conn.query_raw(sql.as_ref(), ¶ms)?);
|
||||
let (sql, params) = PgStatement::new(sql.as_ref())?.bind(params)?;
|
||||
return f(conn.query_raw(&sql, ¶ms)?);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
@@ -211,11 +210,12 @@ fn event_loop(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::named_params;
|
||||
|
||||
use postgres::{Client, NoTls, fallible_iterator::FallibleIterator};
|
||||
|
||||
#[tokio::test]
|
||||
async fn pg_poc_test() {
|
||||
let exec = Executor::new(
|
||||
fn build_executor() -> Result<Executor, Error> {
|
||||
return Executor::new(
|
||||
|| {
|
||||
return Client::configure()
|
||||
.host("localhost")
|
||||
@@ -227,9 +227,12 @@ mod tests {
|
||||
Options {
|
||||
num_threads: Some(2),
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pg_poc_test() {
|
||||
let exec = build_executor().unwrap();
|
||||
assert_eq!(2, exec.threads());
|
||||
|
||||
exec
|
||||
@@ -274,4 +277,38 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pg_poc_named_parameter_test() {
|
||||
let exec = build_executor().unwrap();
|
||||
assert_eq!(2, exec.threads());
|
||||
|
||||
exec
|
||||
.call(|client| {
|
||||
return client.execute(
|
||||
"
|
||||
CREATE TABLE IF NOT EXISTS test_table_poc_named_params(
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
",
|
||||
&[],
|
||||
);
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
exec
|
||||
.query_rows_f(
|
||||
"
|
||||
INSERT INTO test_table_poc_named_params (data) VALUES (:param);
|
||||
",
|
||||
named_params! {":param": "value"},
|
||||
|_rows| -> Result<(), Error> {
|
||||
return Ok(());
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,18 +4,16 @@ use postgres::fallible_iterator::FallibleIterator;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::params::Params;
|
||||
use crate::pg::util::bind;
|
||||
use crate::pg::util::{columns, from_row, from_rows};
|
||||
use crate::pg::util::{PgStatement, columns, from_row, from_rows};
|
||||
use crate::rows::{Row, Rows};
|
||||
use crate::traits::SyncConnection as SyncConnectionTrait;
|
||||
use crate::traits::SyncTransaction as SyncTransactionTrait;
|
||||
use crate::value::Value;
|
||||
|
||||
impl SyncConnectionTrait for postgres::Client {
|
||||
// Queries the first row and returns it if present, otherwise `None`.
|
||||
fn query_row(&mut self, sql: impl AsRef<str>, params: impl Params) -> Result<Option<Row>, Error> {
|
||||
let params: Vec<Value> = bind(sql.as_ref(), params)?;
|
||||
let mut row_iter = self.query_raw(sql.as_ref(), ¶ms)?;
|
||||
let (sql, params) = PgStatement::new(sql.as_ref())?.bind(params)?;
|
||||
let mut row_iter = self.query_raw(&sql, ¶ms)?;
|
||||
|
||||
if let Some(row) = row_iter.next()? {
|
||||
return Ok(Some(from_row(&row, Arc::new(columns(&row)))?));
|
||||
@@ -25,16 +23,14 @@ impl SyncConnectionTrait for postgres::Client {
|
||||
}
|
||||
|
||||
fn query_rows(&mut self, sql: impl AsRef<str>, params: impl Params) -> Result<Rows, Error> {
|
||||
let params: Vec<Value> = bind(sql.as_ref(), params)?;
|
||||
let row_iter = self.query_raw(sql.as_ref(), ¶ms)?;
|
||||
let (sql, params) = PgStatement::new(sql.as_ref())?.bind(params)?;
|
||||
let row_iter = self.query_raw(&sql, ¶ms)?;
|
||||
return from_rows(row_iter);
|
||||
}
|
||||
|
||||
fn execute(&mut self, sql: impl AsRef<str>, params: impl Params) -> Result<usize, Error> {
|
||||
let params: Vec<Value> = bind(sql.as_ref(), params)?;
|
||||
// let param_refs: Vec<&(dyn postgres::types::ToSql + Sync)> = params.iter().collect();
|
||||
// return Ok(self.execute(sql.as_ref(), ¶m_refs)? as usize);
|
||||
let row_iter = self.query_raw(sql.as_ref(), params)?;
|
||||
let (sql, params) = PgStatement::new(sql.as_ref())?.bind(params)?;
|
||||
let row_iter = self.query_raw(&sql, params)?;
|
||||
return Ok(row_iter.rows_affected().unwrap_or_default() as usize);
|
||||
}
|
||||
|
||||
@@ -46,8 +42,8 @@ impl SyncConnectionTrait for postgres::Client {
|
||||
impl<'a> SyncConnectionTrait for postgres::Transaction<'a> {
|
||||
// Queries the first row and returns it if present, otherwise `None`.
|
||||
fn query_row(&mut self, sql: impl AsRef<str>, params: impl Params) -> Result<Option<Row>, Error> {
|
||||
let params: Vec<Value> = bind(sql.as_ref(), params)?;
|
||||
let mut row_iter = self.query_raw(sql.as_ref(), ¶ms)?;
|
||||
let (sql, params) = PgStatement::new(sql.as_ref())?.bind(params)?;
|
||||
let mut row_iter = self.query_raw(&sql, ¶ms)?;
|
||||
|
||||
if let Some(row) = row_iter.next()? {
|
||||
return Ok(Some(from_row(&row, Arc::new(columns(&row)))?));
|
||||
@@ -57,14 +53,14 @@ impl<'a> SyncConnectionTrait for postgres::Transaction<'a> {
|
||||
}
|
||||
|
||||
fn query_rows(&mut self, sql: impl AsRef<str>, params: impl Params) -> Result<Rows, Error> {
|
||||
let params: Vec<Value> = bind(sql.as_ref(), params)?;
|
||||
let row_iter = self.query_raw(sql.as_ref(), ¶ms)?;
|
||||
let (sql, params) = PgStatement::new(sql.as_ref())?.bind(params)?;
|
||||
let row_iter = self.query_raw(&sql, ¶ms)?;
|
||||
return from_rows(row_iter);
|
||||
}
|
||||
|
||||
fn execute(&mut self, sql: impl AsRef<str>, params: impl Params) -> Result<usize, Error> {
|
||||
let params: Vec<Value> = bind(sql.as_ref(), params)?;
|
||||
let row_iter = self.query_raw(sql.as_ref(), params)?;
|
||||
let (sql, params) = PgStatement::new(sql.as_ref())?.bind(params)?;
|
||||
let row_iter = self.query_raw(&sql, params)?;
|
||||
return Ok(row_iter.rows_affected().unwrap_or_default() as usize);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use postgres::fallible_iterator::FallibleIterator;
|
||||
use std::sync::Arc;
|
||||
use regex::Regex;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::params::Params;
|
||||
@@ -9,11 +11,64 @@ use crate::to_sql::ToSqlProxy;
|
||||
use crate::value::Value;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PgStatement<'a> {
|
||||
pub(crate) struct PgStatement<'a> {
|
||||
#[allow(unused)]
|
||||
sql: &'a str,
|
||||
// TODO: Can we use ToSqlProxy here?
|
||||
params: &'a mut Vec<(usize, Value)>,
|
||||
|
||||
// TODO: Could we use ToSqlProxy here to reduce copies?
|
||||
params: Vec<(usize, Value)>,
|
||||
placeholders: HashMap<String, usize>,
|
||||
}
|
||||
|
||||
impl<'a> PgStatement<'a> {
|
||||
pub fn new(sql: &'a str) -> Result<Self, Error> {
|
||||
static NAMED_RE: LazyLock<Regex> =
|
||||
LazyLock::new(|| Regex::new(r"(?<named>:[[:alpha:]][[:alnum:]]*)").expect("startup"));
|
||||
|
||||
let mut placeholders: HashMap<String, usize> = Default::default();
|
||||
for (idx, cap) in NAMED_RE.captures_iter(sql).enumerate() {
|
||||
let named_params = &cap["named"];
|
||||
placeholders.insert(named_params.to_string(), idx + 1);
|
||||
}
|
||||
|
||||
return Ok(Self {
|
||||
sql,
|
||||
params: vec![],
|
||||
placeholders,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn bind(mut self, params: impl Params) -> Result<(String, Vec<Value>), Error> {
|
||||
params.bind(&mut self)?;
|
||||
|
||||
let Self {
|
||||
sql,
|
||||
placeholders,
|
||||
mut params,
|
||||
} = self;
|
||||
|
||||
// TODO: Do we need further validation, e.g. that indexes are consecutive, that they're
|
||||
// matching the SQL...?
|
||||
let bound_params = {
|
||||
params.sort_by(|a, b| {
|
||||
return a.0.cmp(&b.0);
|
||||
});
|
||||
params.into_iter().map(|p| p.1).collect()
|
||||
};
|
||||
|
||||
// Also support "?1" placeholders like sqlite (PG only supports $1).
|
||||
static RE: LazyLock<Regex> =
|
||||
LazyLock::new(|| Regex::new(r"[?](?<index>\d+)").expect("startup"));
|
||||
|
||||
let mut sql = RE.replace_all(sql, "$$$index").to_string();
|
||||
|
||||
// TODO: We should probably do this along the initial parse when we find the placeholders.
|
||||
for (name, idx) in placeholders {
|
||||
sql = sql.replace(&name, &format!("${idx}"));
|
||||
}
|
||||
|
||||
return Ok((sql, bound_params));
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Statement for PgStatement<'a> {
|
||||
@@ -22,29 +77,16 @@ impl<'a> Statement for PgStatement<'a> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
fn parameter_index(&self, _name: &str) -> Result<Option<usize>, Error> {
|
||||
return Err(Error::Other("not implemented: parse `self.sql`".into()));
|
||||
/// Will return Err if `name` is invalid. Will return Ok(None) if the name
|
||||
/// is valid but not a bound parameter of this statement.
|
||||
fn parameter_index(&self, name: &str) -> Result<Option<usize>, Error> {
|
||||
if &name[0..1] != ":" || name[1..].chars().any(|c| !c.is_ascii_alphanumeric()) {
|
||||
return Err(Error::Other(format!("invalid param name: {name}").into()));
|
||||
}
|
||||
return Ok(self.placeholders.get(name).cloned());
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn bind(sql: &str, params: impl Params) -> Result<Vec<Value>, Error> {
|
||||
let mut bound: Vec<(usize, Value)> = vec![];
|
||||
let mut stmt = PgStatement {
|
||||
sql,
|
||||
params: &mut bound,
|
||||
};
|
||||
params.bind(&mut stmt)?;
|
||||
|
||||
bound.sort_by(|a, b| {
|
||||
return a.0.cmp(&b.0);
|
||||
});
|
||||
|
||||
// TODO: Do we need further validation, e.g. that indexes are consecutive?
|
||||
|
||||
return Ok(bound.into_iter().map(|p| p.1).collect());
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn map_first<T>(
|
||||
mut rows: postgres::RowIter<'_>,
|
||||
@@ -110,3 +152,32 @@ pub(crate) fn columns(row: &postgres::Row) -> Vec<Column> {
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::named_params;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn pg_statement_test() {
|
||||
let (sql, params) = PgStatement::new("INSERT INTO 'table' (col) VALUES (?1), (?1)")
|
||||
.unwrap()
|
||||
.bind(("foo",))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!("INSERT INTO 'table' (col) VALUES ($1), ($1)", sql);
|
||||
assert_eq!(Value::Text("foo".to_string()), *params.first().unwrap());
|
||||
|
||||
let (sql, params) = PgStatement::new("INSERT INTO 'table' (col) VALUES (:p0), (:p1)")
|
||||
.unwrap()
|
||||
.bind(named_params! {":p0": "p0", ":p1": "p1"})
|
||||
.unwrap();
|
||||
|
||||
assert_eq!("INSERT INTO 'table' (col) VALUES ($1), ($2)", sql);
|
||||
assert_eq!(
|
||||
vec![Value::Text("p0".to_string()), Value::Text("p1".to_string())],
|
||||
params,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user