Support named params and alternate placeholder syntax ?\d.

This commit is contained in:
Sebastian Jeltsch
2026-04-29 11:35:43 +02:00
parent 5e52dab7a8
commit 1c5b01969c
6 changed files with 159 additions and 54 deletions
Generated
+1
View File
@@ -7046,6 +7046,7 @@ dependencies = [
"pgrow2serde",
"postgres",
"rand 0.10.1",
"regex",
"rusqlite",
"serde",
"serde_rusqlite",
+3 -3
View File
@@ -21,18 +21,18 @@ harness = false
[features]
default = ["pg"]
pg = ["dep:postgres", "dep:bytes", "dep:pgrow2serde"]
pg = ["dep:postgres", "dep:bytes", "dep:pgrow2serde", "dep:regex"]
[dependencies]
bytes = { version = "1.11.1", optional = true }
flume = { workspace = true }
log = { version = "^0.4.21", default-features = false }
parking_lot = { workspace = true }
pgrow2serde = { git = "https://github.com/deknowny/pgrow2serde", optional = true }
postgres = { version = "0.19.12", optional = true, features = ["with-serde_json-1"] }
regex = { version = "1.12.3", optional = true}
rusqlite = { workspace = true }
serde = { workspace = true }
# serde-pgrow = { version = "0.3.6", optional = true }
pgrow2serde = { git = "https://github.com/deknowny/pgrow2serde", optional = true }
serde_rusqlite = { workspace = true }
smallvec = { version = "1.15.1", features = ["const_generics"] }
thiserror = "2.0.12"
+1 -1
View File
@@ -668,7 +668,7 @@ mod tests {
let row = conn
.read_query_row(
"SELECT COUNT(*) FROM test_table_poc_generic WHERE data = $1",
"SELECT COUNT(*) FROM test_table_poc_generic WHERE data = ?1",
("a".to_string(),),
)
.await
+46 -9
View File
@@ -5,8 +5,7 @@ use tokio::sync::oneshot;
use crate::error::Error;
use crate::params::Params;
use crate::pg::util::bind;
use crate::value::Value;
use crate::pg::util::PgStatement;
#[derive(Clone, Default)]
pub struct Options {
@@ -166,8 +165,8 @@ impl Executor {
{
return self
.call(move |conn: &mut postgres::Client| {
let params: Vec<Value> = bind(sql.as_ref(), params)?;
return f(conn.query_raw(sql.as_ref(), &params)?);
let (sql, params) = PgStatement::new(sql.as_ref())?.bind(params)?;
return f(conn.query_raw(&sql, &params)?);
})
.await;
}
@@ -211,11 +210,12 @@ fn event_loop(
#[cfg(test)]
mod tests {
use super::*;
use crate::named_params;
use postgres::{Client, NoTls, fallible_iterator::FallibleIterator};
#[tokio::test]
async fn pg_poc_test() {
let exec = Executor::new(
fn build_executor() -> Result<Executor, Error> {
return Executor::new(
|| {
return Client::configure()
.host("localhost")
@@ -227,9 +227,12 @@ mod tests {
Options {
num_threads: Some(2),
},
)
.unwrap();
);
}
#[tokio::test]
async fn pg_poc_test() {
let exec = build_executor().unwrap();
assert_eq!(2, exec.threads());
exec
@@ -274,4 +277,38 @@ mod tests {
.await
.unwrap();
}
#[tokio::test]
async fn pg_poc_named_parameter_test() {
let exec = build_executor().unwrap();
assert_eq!(2, exec.threads());
exec
.call(|client| {
return client.execute(
"
CREATE TABLE IF NOT EXISTS test_table_poc_named_params(
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
",
&[],
);
})
.await
.unwrap();
exec
.query_rows_f(
"
INSERT INTO test_table_poc_named_params (data) VALUES (:param);
",
named_params! {":param": "value"},
|_rows| -> Result<(), Error> {
return Ok(());
},
)
.await
.unwrap();
}
}
+13 -17
View File
@@ -4,18 +4,16 @@ use postgres::fallible_iterator::FallibleIterator;
use crate::error::Error;
use crate::params::Params;
use crate::pg::util::bind;
use crate::pg::util::{columns, from_row, from_rows};
use crate::pg::util::{PgStatement, columns, from_row, from_rows};
use crate::rows::{Row, Rows};
use crate::traits::SyncConnection as SyncConnectionTrait;
use crate::traits::SyncTransaction as SyncTransactionTrait;
use crate::value::Value;
impl SyncConnectionTrait for postgres::Client {
// Queries the first row and returns it if present, otherwise `None`.
fn query_row(&mut self, sql: impl AsRef<str>, params: impl Params) -> Result<Option<Row>, Error> {
let params: Vec<Value> = bind(sql.as_ref(), params)?;
let mut row_iter = self.query_raw(sql.as_ref(), &params)?;
let (sql, params) = PgStatement::new(sql.as_ref())?.bind(params)?;
let mut row_iter = self.query_raw(&sql, &params)?;
if let Some(row) = row_iter.next()? {
return Ok(Some(from_row(&row, Arc::new(columns(&row)))?));
@@ -25,16 +23,14 @@ impl SyncConnectionTrait for postgres::Client {
}
fn query_rows(&mut self, sql: impl AsRef<str>, params: impl Params) -> Result<Rows, Error> {
let params: Vec<Value> = bind(sql.as_ref(), params)?;
let row_iter = self.query_raw(sql.as_ref(), &params)?;
let (sql, params) = PgStatement::new(sql.as_ref())?.bind(params)?;
let row_iter = self.query_raw(&sql, &params)?;
return from_rows(row_iter);
}
fn execute(&mut self, sql: impl AsRef<str>, params: impl Params) -> Result<usize, Error> {
let params: Vec<Value> = bind(sql.as_ref(), params)?;
// let param_refs: Vec<&(dyn postgres::types::ToSql + Sync)> = params.iter().collect();
// return Ok(self.execute(sql.as_ref(), &param_refs)? as usize);
let row_iter = self.query_raw(sql.as_ref(), params)?;
let (sql, params) = PgStatement::new(sql.as_ref())?.bind(params)?;
let row_iter = self.query_raw(&sql, params)?;
return Ok(row_iter.rows_affected().unwrap_or_default() as usize);
}
@@ -46,8 +42,8 @@ impl SyncConnectionTrait for postgres::Client {
impl<'a> SyncConnectionTrait for postgres::Transaction<'a> {
// Queries the first row and returns it if present, otherwise `None`.
fn query_row(&mut self, sql: impl AsRef<str>, params: impl Params) -> Result<Option<Row>, Error> {
let params: Vec<Value> = bind(sql.as_ref(), params)?;
let mut row_iter = self.query_raw(sql.as_ref(), &params)?;
let (sql, params) = PgStatement::new(sql.as_ref())?.bind(params)?;
let mut row_iter = self.query_raw(&sql, &params)?;
if let Some(row) = row_iter.next()? {
return Ok(Some(from_row(&row, Arc::new(columns(&row)))?));
@@ -57,14 +53,14 @@ impl<'a> SyncConnectionTrait for postgres::Transaction<'a> {
}
fn query_rows(&mut self, sql: impl AsRef<str>, params: impl Params) -> Result<Rows, Error> {
let params: Vec<Value> = bind(sql.as_ref(), params)?;
let row_iter = self.query_raw(sql.as_ref(), &params)?;
let (sql, params) = PgStatement::new(sql.as_ref())?.bind(params)?;
let row_iter = self.query_raw(&sql, &params)?;
return from_rows(row_iter);
}
fn execute(&mut self, sql: impl AsRef<str>, params: impl Params) -> Result<usize, Error> {
let params: Vec<Value> = bind(sql.as_ref(), params)?;
let row_iter = self.query_raw(sql.as_ref(), params)?;
let (sql, params) = PgStatement::new(sql.as_ref())?.bind(params)?;
let row_iter = self.query_raw(&sql, params)?;
return Ok(row_iter.rows_affected().unwrap_or_default() as usize);
}
+95 -24
View File
@@ -1,5 +1,7 @@
use postgres::fallible_iterator::FallibleIterator;
use std::sync::Arc;
use regex::Regex;
use std::collections::HashMap;
use std::sync::{Arc, LazyLock};
use crate::error::Error;
use crate::params::Params;
@@ -9,11 +11,64 @@ use crate::to_sql::ToSqlProxy;
use crate::value::Value;
#[derive(Debug)]
pub struct PgStatement<'a> {
pub(crate) struct PgStatement<'a> {
#[allow(unused)]
sql: &'a str,
// TODO: Can we use ToSqlProxy here?
params: &'a mut Vec<(usize, Value)>,
// TODO: Could we use ToSqlProxy here to reduce copies?
params: Vec<(usize, Value)>,
placeholders: HashMap<String, usize>,
}
impl<'a> PgStatement<'a> {
pub fn new(sql: &'a str) -> Result<Self, Error> {
static NAMED_RE: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"(?<named>:[[:alpha:]][[:alnum:]]*)").expect("startup"));
let mut placeholders: HashMap<String, usize> = Default::default();
for (idx, cap) in NAMED_RE.captures_iter(sql).enumerate() {
let named_params = &cap["named"];
placeholders.insert(named_params.to_string(), idx + 1);
}
return Ok(Self {
sql,
params: vec![],
placeholders,
});
}
pub fn bind(mut self, params: impl Params) -> Result<(String, Vec<Value>), Error> {
params.bind(&mut self)?;
let Self {
sql,
placeholders,
mut params,
} = self;
// TODO: Do we need further validation, e.g. that indexes are consecutive, that they're
// matching the SQL...?
let bound_params = {
params.sort_by(|a, b| {
return a.0.cmp(&b.0);
});
params.into_iter().map(|p| p.1).collect()
};
// Also support "?1" placeholders like sqlite (PG only supports $1).
static RE: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"[?](?<index>\d+)").expect("startup"));
let mut sql = RE.replace_all(sql, "$$$index").to_string();
// TODO: We should probably do this along the initial parse when we find the placeholders.
for (name, idx) in placeholders {
sql = sql.replace(&name, &format!("${idx}"));
}
return Ok((sql, bound_params));
}
}
impl<'a> Statement for PgStatement<'a> {
@@ -22,29 +77,16 @@ impl<'a> Statement for PgStatement<'a> {
return Ok(());
}
fn parameter_index(&self, _name: &str) -> Result<Option<usize>, Error> {
return Err(Error::Other("not implemented: parse `self.sql`".into()));
/// Will return Err if `name` is invalid. Will return Ok(None) if the name
/// is valid but not a bound parameter of this statement.
fn parameter_index(&self, name: &str) -> Result<Option<usize>, Error> {
if &name[0..1] != ":" || name[1..].chars().any(|c| !c.is_ascii_alphanumeric()) {
return Err(Error::Other(format!("invalid param name: {name}").into()));
}
return Ok(self.placeholders.get(name).cloned());
}
}
#[inline]
pub(crate) fn bind(sql: &str, params: impl Params) -> Result<Vec<Value>, Error> {
let mut bound: Vec<(usize, Value)> = vec![];
let mut stmt = PgStatement {
sql,
params: &mut bound,
};
params.bind(&mut stmt)?;
bound.sort_by(|a, b| {
return a.0.cmp(&b.0);
});
// TODO: Do we need further validation, e.g. that indexes are consecutive?
return Ok(bound.into_iter().map(|p| p.1).collect());
}
#[inline]
pub(crate) fn map_first<T>(
mut rows: postgres::RowIter<'_>,
@@ -110,3 +152,32 @@ pub(crate) fn columns(row: &postgres::Row) -> Vec<Column> {
})
.collect();
}
#[cfg(test)]
mod tests {
use crate::named_params;
use super::*;
#[test]
fn pg_statement_test() {
let (sql, params) = PgStatement::new("INSERT INTO 'table' (col) VALUES (?1), (?1)")
.unwrap()
.bind(("foo",))
.unwrap();
assert_eq!("INSERT INTO 'table' (col) VALUES ($1), ($1)", sql);
assert_eq!(Value::Text("foo".to_string()), *params.first().unwrap());
let (sql, params) = PgStatement::new("INSERT INTO 'table' (col) VALUES (:p0), (:p1)")
.unwrap()
.bind(named_params! {":p0": "p0", ":p1": "p1"})
.unwrap();
assert_eq!("INSERT INTO 'table' (col) VALUES ($1), ($2)", sql);
assert_eq!(
vec![Value::Text("p0".to_string()), Value::Text("p1".to_string())],
params,
);
}
}