diff --git a/Cargo.lock b/Cargo.lock index 4396bd29..187fe919 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8061,6 +8061,7 @@ name = "trailbase-qs" version = "0.1.0" dependencies = [ "base64", + "rusqlite", "serde", "serde-value", "serde_qs", diff --git a/crates/core/src/listing.rs b/crates/core/src/listing.rs index ee2e7647..ab6c34c0 100644 --- a/crates/core/src/listing.rs +++ b/crates/core/src/listing.rs @@ -41,41 +41,34 @@ pub(crate) fn build_filter_where_clause( }); }; - let validator = |column_name: &str| -> Result<(), WhereClauseError> { + let convert = |column_name: &str, + value: trailbase_qs::Value| + -> Result { if column_name.starts_with("_") { return Err(WhereClauseError::UnrecognizedParam(format!( "Invalid parameter: {column_name}" ))); } - // IMPORTANT: We only include parameters with known columns to avoid building an invalid - // query early and prevent injections. - if !columns.iter().any(|c| c.name == column_name) { + let Some(column) = columns.iter().find(|c| c.name == column_name) else { return Err(WhereClauseError::UnrecognizedParam(format!( "Unrecognized parameter: {column_name}" ))); }; - return Ok(()); + // TODO: Improve hacky error handling. + return crate::records::filter::qs_value_to_sql_with_constraints(column, value) + .map_err(|err| WhereClauseError::UnrecognizedParam(err.to_string())); }; - let (sql, params) = filter_params.into_sql(Some(table_name), &validator)?; - - use trailbase_sqlite::Value; - type Param = (Cow<'static, str>, Value); - let sql_params: Vec = params - .into_iter() - .map(|(name, value)| { - return ( - Cow::Owned(name), - crate::records::filter::qs_value_to_sql(value), - ); - }) - .collect(); + let (sql, params) = filter_params.into_sql(Some(table_name), &convert)?; return Ok(WhereClause { clause: sql, - params: sql_params, + params: params + .into_iter() + .map(|(name, v)| (Cow::Owned(name), v)) + .collect(), }); } diff --git a/crates/core/src/records/filter.rs b/crates/core/src/records/filter.rs index 401d241a..2d3723b8 100644 --- a/crates/core/src/records/filter.rs +++ b/crates/core/src/records/filter.rs @@ -22,7 +22,7 @@ pub enum Filter { Record(ValueOrComposite), } -pub(crate) fn qs_value_to_sql(value: trailbase_qs::Value) -> rusqlite::types::Value { +fn any_qs_value_to_sql(value: trailbase_qs::Value) -> rusqlite::types::Value { use base64::prelude::*; use rusqlite::types::Value; use trailbase_qs::Value as QsValue; @@ -50,7 +50,7 @@ pub(crate) fn qs_value_to_sql_with_constraints( return match column.data_type { ColumnDataType::Null => Err(RecordError::BadRequest("Invalid query")), - ColumnDataType::Any => Ok(qs_value_to_sql(value)), + ColumnDataType::Any => Ok(any_qs_value_to_sql(value)), ColumnDataType::Blob => match value { QsValue::String(s) => Ok(Value::Blob( BASE64_URL_SAFE diff --git a/crates/qs/Cargo.toml b/crates/qs/Cargo.toml index 25cdc97b..692bda82 100644 --- a/crates/qs/Cargo.toml +++ b/crates/qs/Cargo.toml @@ -10,6 +10,7 @@ readme = "../README.md" [dependencies] base64 = { version = "0.22.1", default-features = false, features = ["alloc"] } +rusqlite = { workspace = true } serde = "1.0.219" serde-value = "0.7.0" serde_qs = "0.15.0" diff --git a/crates/qs/src/column_rel_value.rs b/crates/qs/src/column_rel_value.rs index 2883e4d6..9b5dc897 100644 --- a/crates/qs/src/column_rel_value.rs +++ b/crates/qs/src/column_rel_value.rs @@ -1,4 +1,5 @@ use base64::prelude::*; +use rusqlite::types::Value as SqlValue; use serde::de::{Deserializer, Error}; use crate::value::Value; @@ -57,48 +58,38 @@ pub struct ColumnOpValue { } impl ColumnOpValue { - pub fn into_sql( + pub fn into_sql( self, column_prefix: Option<&str>, + convert: &dyn Fn(&str, Value) -> Result, index: &mut usize, - ) -> (String, Option<(String, Value)>) { + ) -> Result<(String, Option<(String, SqlValue)>), E> { + let v = self.value; + let c = self.column; + return match self.op { CompareOp::Is => { - assert!(matches!(self.value, Value::String(_)), "{:?}", self.value); + assert!(matches!(v, Value::String(_)), "{v:?}"); - match column_prefix { - Some(p) => ( - format!(r#"{p}."{c}" IS {v}"#, c = self.column, v = self.value), - None, - ), - None => ( - format!(r#""{c}" IS {v}"#, c = self.column, v = self.value), - None, - ), - } + Ok(match column_prefix { + Some(p) => (format!(r#"{p}."{c}" IS {v}"#), None), + None => (format!(r#""{c}" IS {v}"#), None), + }) } _ => { let param = param_name(*index); *index += 1; - match column_prefix { + Ok(match column_prefix { Some(p) => ( - format!( - r#"{p}."{c}" {o} {param}"#, - c = self.column, - o = self.op.as_sql() - ), - Some((param, self.value)), + format!(r#"{p}."{c}" {o} {param}"#, o = self.op.as_sql()), + Some((param, convert(&c, v)?)), ), None => ( - format!( - r#""{c}" {o} {param}"#, - c = self.column, - o = self.op.as_sql() - ), - Some((param, self.value)), + format!(r#""{c}" {o} {param}"#, o = self.op.as_sql()), + Some((param, convert(&c, v)?)), ), - } + }) } }; } diff --git a/crates/qs/src/filter.rs b/crates/qs/src/filter.rs index 55b26c5a..cadc4729 100644 --- a/crates/qs/src/filter.rs +++ b/crates/qs/src/filter.rs @@ -7,6 +7,7 @@ /// filters[column][eq]=value /// filters[and][0][column0][eq]=value0&filters[and][1][column1][eq]=value1 /// filters[and][0][or][0][column0]=value0&[and][0][or][1][column1]=value1 +use rusqlite::types::Value as SqlValue; use std::collections::BTreeMap; use crate::column_rel_value::{ColumnOpValue, serde_value_to_single_column_rel_value}; @@ -28,33 +29,31 @@ impl ValueOrComposite { pub fn into_sql( self, column_prefix: Option<&str>, - validator: &dyn Fn(&str) -> Result<(), E>, - ) -> Result<(String, Vec<(String, Value)>), E> { + convert: &dyn Fn(&str, Value) -> Result, + ) -> Result<(String, Vec<(String, SqlValue)>), E> { let mut index: usize = 0; - return self.into_sql_impl(column_prefix, validator, &mut index); + return self.into_sql_impl(column_prefix, convert, &mut index); } fn into_sql_impl( self, column_prefix: Option<&str>, - validator: &dyn Fn(&str) -> Result<(), E>, + convert: &dyn Fn(&str, Value) -> Result, index: &mut usize, - ) -> Result<(String, Vec<(String, Value)>), E> { + ) -> Result<(String, Vec<(String, SqlValue)>), E> { match self { Self::Value(v) => { - validator(&v.column)?; - - return Ok(match v.into_sql(column_prefix, index) { + return Ok(match v.into_sql(column_prefix, convert, index)? { (sql, Some(param)) => (sql, vec![param]), (sql, None) => (sql, vec![]), }); } Self::Composite(combiner, vec) => { let mut fragments = Vec::::with_capacity(vec.len()); - let mut params = Vec::<(String, Value)>::with_capacity(vec.len()); + let mut params = Vec::<(String, SqlValue)>::with_capacity(vec.len()); for value_or_composite in vec { - let (f, p) = value_or_composite.into_sql_impl::(column_prefix, validator, index)?; + let (f, p) = value_or_composite.into_sql_impl::(column_prefix, convert, index)?; fragments.push(f); params.extend(p); } @@ -280,16 +279,20 @@ mod tests { value: Value::String("val0".to_string()), }); - let validator = |_: &str| -> Result<(), String> { - return Ok(()); + let convert = |_: &str, value: Value| -> Result { + return Ok(match value { + Value::String(s) => SqlValue::Text(s), + Value::Integer(i) => SqlValue::Integer(i), + Value::Double(d) => SqlValue::Real(d), + }); }; let sql0 = v0 .clone() - .into_sql(/* column_prefix= */ None, &validator) + .into_sql(/* column_prefix= */ None, &convert) .unwrap(); assert_eq!(sql0.0, r#""col0" = :__p0"#); let sql0 = v0 - .into_sql(/* column_prefix= */ Some("p"), &validator) + .into_sql(/* column_prefix= */ Some("p"), &convert) .unwrap(); assert_eq!(sql0.0, r#"p."col0" = :__p0"#); @@ -298,7 +301,7 @@ mod tests { op: CompareOp::Is, value: Value::String("NULL".to_string()), }); - let sql1 = v1.into_sql(None, &validator).unwrap(); + let sql1 = v1.into_sql(None, &convert).unwrap(); assert_eq!(sql1.0, r#""col0" IS NULL"#, "{sql1:?}",); } } diff --git a/crates/qs/src/query.rs b/crates/qs/src/query.rs index 02fe7fa2..4b9fd583 100644 --- a/crates/qs/src/query.rs +++ b/crates/qs/src/query.rs @@ -203,6 +203,7 @@ impl FilterQuery { mod tests { use super::*; + use rusqlite::types::Value as SqlValue; use serde_qs::Config; use crate::column_rel_value::{ColumnOpValue, CompareOp}; @@ -395,8 +396,12 @@ mod tests { ) ); - let filter = |_: &str| -> Result<(), String> { - return Ok(()); + let filter = |_: &str, value: Value| -> Result { + return Ok(match value { + Value::String(s) => SqlValue::Text(s), + Value::Integer(i) => SqlValue::Integer(i), + Value::Double(d) => SqlValue::Real(d), + }); }; let (sql, params) = q1.filter.clone().unwrap().into_sql(None, &filter).unwrap(); assert_eq!( @@ -406,9 +411,9 @@ mod tests { assert_eq!( params, vec![ - (":__p0".to_string(), Value::String("val2".to_string())), - (":__p1".to_string(), Value::String("val0".to_string())), - (":__p2".to_string(), Value::Integer(1)), + (":__p0".to_string(), SqlValue::Text("val2".to_string())), + (":__p1".to_string(), SqlValue::Text("val0".to_string())), + (":__p2".to_string(), SqlValue::Integer(1)), ] ); let (sql, _) = q1.filter.unwrap().into_sql(Some("p"), &filter).unwrap();