diff --git a/crates/extension/src/b64.rs b/crates/extension/src/b64.rs new file mode 100644 index 00000000..c074fd31 --- /dev/null +++ b/crates/extension/src/b64.rs @@ -0,0 +1,153 @@ +use base64::prelude::*; +use rusqlite::Error; +use rusqlite::Result; +use rusqlite::functions::Context; +use rusqlite::types::{Value, ValueRef}; + +/// A base64 conversion utility similar to the SQLite's `base64()` extension. +/// +/// It introspects on the inputs, and will convert based on that: +/// - BLOB → TEXT (encode) +/// - TEXT → BLOB (decode) +/// - NULL → NULL +/// - Other types → error +/// +/// Note, however, that this implementation is more strict with respect to b64 +/// text inputs, e.g. it checks the padding. +pub(super) fn base64(context: &Context) -> Result { + #[cfg(debug_assertions)] + if context.len() != 1 { + return Err(Error::InvalidParameterCount(context.len(), 1)); + } + + return match context.get_raw(0) { + ValueRef::Null => Ok(Value::Null), + ValueRef::Blob(blob) => Ok(Value::Text(BASE64_STANDARD.encode(blob))), + ValueRef::Text(text_bytes) => { + let text_str = + std::str::from_utf8(text_bytes).map_err(|err| Error::UserFunctionError(err.into()))?; + + Ok(Value::Blob( + BASE64_STANDARD + .decode(text_str) + .map_err(|err| Error::UserFunctionError(err.into()))?, + )) + } + v => Err(Error::InvalidFunctionParameterType(0, v.data_type())), + }; +} + +/// A URL-safe base64 conversion utility similar to the SQLite's `base64()` extension. +/// +/// It introspects on the inputs, and will convert based on that: +/// - BLOB → TEXT (encode) +/// - TEXT → BLOB (decode) +/// - NULL → NULL +/// - Other types → error +/// +/// Note, however, that this implementation is more strict with respect to b64 +/// text inputs, e.g. it checks the padding. +pub(super) fn base64_url_safe(context: &Context) -> Result { + #[cfg(debug_assertions)] + if context.len() != 1 { + return Err(Error::InvalidParameterCount(context.len(), 1)); + } + + return match context.get_raw(0) { + ValueRef::Null => Ok(Value::Null), + ValueRef::Blob(blob) => Ok(Value::Text(BASE64_URL_SAFE.encode(blob))), + ValueRef::Text(text_bytes) => { + let text_str = + std::str::from_utf8(text_bytes).map_err(|err| Error::UserFunctionError(err.into()))?; + + Ok(Value::Blob( + BASE64_URL_SAFE + .decode(text_str) + .map_err(|err| Error::UserFunctionError(err.into()))?, + )) + } + v => Err(Error::InvalidFunctionParameterType(0, v.data_type())), + }; +} + +#[cfg(test)] +mod tests { + use base64::prelude::*; + use rusqlite::Error; + + #[test] + fn test_base64_wrong_number_of_arguments() { + let conn = crate::connect_sqlite(None).unwrap(); + let val = conn.query_row( + "SELECT base64_url_safe('a', 'b')", + [], + |row| -> Result<[u8; 16], Error> { Ok(row.get(0)?) }, + ); + assert!(val.is_err()); + } + + #[test] + fn test_base64_url_safe_roundtrip() { + let conn = crate::connect_sqlite(None).unwrap(); + + let value = b"832!@#$%^&*()>./"; + for query in [ + format!("SELECT base64(base64(?1))"), + format!("SELECT base64_url_safe(base64_url_safe(?1))"), + ] { + // BLOB → TEXT → BLOB round-trip test + let val = conn + .query_row(&query, [value], |row| -> Result, Error> { + Ok(row.get(0)?) + }) + .unwrap(); + assert_eq!(val, value); + } + } + + #[test] + fn test_base64_url_safe_null_handling() { + let conn = crate::connect_sqlite(None).unwrap(); + for query in [ + format!("SELECT base64(NULL)"), + format!("SELECT base64_url_safe(NULL)"), + ] { + let val: Option> = conn.query_row(&query, [], |row| row.get(0)).unwrap(); + assert!(val.is_none()); + } + } + + #[test] + fn test_base64_url_safe_empty_string() { + let conn = crate::connect_sqlite(None).unwrap(); + for query in [ + format!("SELECT base64('')"), + format!("SELECT base64_url_safe('')"), + ] { + let val: Vec = conn.query_row(&query, [], |row| row.get(0)).unwrap(); + assert!(val.is_empty()); + } + } + + #[test] + fn test_base64_url_safe_trimmed_input() { + let conn = crate::connect_sqlite(None).unwrap(); + let encoded = BASE64_URL_SAFE.encode(&[1, 2, 3, 4]); + + for query in [ + format!("SELECT base64(' {encoded} ')"), + format!("SELECT base64_url_safe(' {encoded} ')"), + ] { + let v = conn.query_row(&query, [], |row| row.get::<_, Vec>(0)); + assert!(v.is_err()); + } + + for query in [ + format!("SELECT base64(trim(' {encoded} '))"), + format!("SELECT base64_url_safe(trim(' {encoded} '))"), + ] { + let val: Vec = conn.query_row(&query, [], |row| row.get(0)).unwrap(); + assert_eq!(val, vec![1, 2, 3, 4]); + } + } +} diff --git a/crates/extension/src/lib.rs b/crates/extension/src/lib.rs index e352dfbc..ffdcea1b 100644 --- a/crates/extension/src/lib.rs +++ b/crates/extension/src/lib.rs @@ -8,6 +8,7 @@ pub mod geoip; pub mod jsonschema; pub mod password; +mod b64; mod regex; mod uuid; mod validators; @@ -73,8 +74,8 @@ pub fn connect_sqlite(path: Option) -> Result Result { - // WARN: Be careful with declaring INNOCUOUS. This allows these "app-defined functions" to run - // even when "trusted_schema=OFF", which means as part of: VIEWs, TRIGGERs, CHECK, DEFAULT, + // WARN: Be careful with declaring INNOCUOUS. It allows "user-defined functions" to run + // when "trusted_schema=OFF", which means as part of: VIEWs, TRIGGERs, CHECK, DEFAULT, // GENERATED cols, ... as opposed to just top-level SELECTs. db.create_scalar_function( @@ -100,7 +101,9 @@ pub fn sqlite3_extension_init( db.create_scalar_function( "uuid_text", 1, - FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_INNOCUOUS, + FunctionFlags::SQLITE_UTF8 + | FunctionFlags::SQLITE_DETERMINISTIC + | FunctionFlags::SQLITE_INNOCUOUS, uuid::uuid_text, )?; @@ -117,7 +120,9 @@ pub fn sqlite3_extension_init( db.create_scalar_function( "hash_password", 1, - FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_INNOCUOUS, + FunctionFlags::SQLITE_UTF8 + | FunctionFlags::SQLITE_DETERMINISTIC + | FunctionFlags::SQLITE_INNOCUOUS, password::hash_password_sqlite, )?; @@ -202,6 +207,23 @@ pub fn sqlite3_extension_init( geoip::geoip_city_json, )?; + db.create_scalar_function( + "base64", + 1, + FunctionFlags::SQLITE_UTF8 + | FunctionFlags::SQLITE_DETERMINISTIC + | FunctionFlags::SQLITE_INNOCUOUS, + b64::base64, + )?; + db.create_scalar_function( + "base64_url_safe", + 1, + FunctionFlags::SQLITE_UTF8 + | FunctionFlags::SQLITE_DETERMINISTIC + | FunctionFlags::SQLITE_INNOCUOUS, + b64::base64_url_safe, + )?; + return Ok(db); }