mirror of
https://github.com/trailbaseio/trailbase.git
synced 2026-01-08 10:50:15 -06:00
Provide a base64_url_safe modeled based on SQLite's own base64() extension.
This commit is contained in:
153
crates/extension/src/b64.rs
Normal file
153
crates/extension/src/b64.rs
Normal file
@@ -0,0 +1,153 @@
|
||||
use base64::prelude::*;
|
||||
use rusqlite::Error;
|
||||
use rusqlite::Result;
|
||||
use rusqlite::functions::Context;
|
||||
use rusqlite::types::{Value, ValueRef};
|
||||
|
||||
/// A base64 conversion utility similar to the SQLite's `base64()` extension.
|
||||
///
|
||||
/// It introspects on the inputs, and will convert based on that:
|
||||
/// - BLOB → TEXT (encode)
|
||||
/// - TEXT → BLOB (decode)
|
||||
/// - NULL → NULL
|
||||
/// - Other types → error
|
||||
///
|
||||
/// Note, however, that this implementation is more strict with respect to b64
|
||||
/// text inputs, e.g. it checks the padding.
|
||||
pub(super) fn base64(context: &Context) -> Result<Value> {
|
||||
#[cfg(debug_assertions)]
|
||||
if context.len() != 1 {
|
||||
return Err(Error::InvalidParameterCount(context.len(), 1));
|
||||
}
|
||||
|
||||
return match context.get_raw(0) {
|
||||
ValueRef::Null => Ok(Value::Null),
|
||||
ValueRef::Blob(blob) => Ok(Value::Text(BASE64_STANDARD.encode(blob))),
|
||||
ValueRef::Text(text_bytes) => {
|
||||
let text_str =
|
||||
std::str::from_utf8(text_bytes).map_err(|err| Error::UserFunctionError(err.into()))?;
|
||||
|
||||
Ok(Value::Blob(
|
||||
BASE64_STANDARD
|
||||
.decode(text_str)
|
||||
.map_err(|err| Error::UserFunctionError(err.into()))?,
|
||||
))
|
||||
}
|
||||
v => Err(Error::InvalidFunctionParameterType(0, v.data_type())),
|
||||
};
|
||||
}
|
||||
|
||||
/// A URL-safe base64 conversion utility similar to the SQLite's `base64()` extension.
|
||||
///
|
||||
/// It introspects on the inputs, and will convert based on that:
|
||||
/// - BLOB → TEXT (encode)
|
||||
/// - TEXT → BLOB (decode)
|
||||
/// - NULL → NULL
|
||||
/// - Other types → error
|
||||
///
|
||||
/// Note, however, that this implementation is more strict with respect to b64
|
||||
/// text inputs, e.g. it checks the padding.
|
||||
pub(super) fn base64_url_safe(context: &Context) -> Result<Value> {
|
||||
#[cfg(debug_assertions)]
|
||||
if context.len() != 1 {
|
||||
return Err(Error::InvalidParameterCount(context.len(), 1));
|
||||
}
|
||||
|
||||
return match context.get_raw(0) {
|
||||
ValueRef::Null => Ok(Value::Null),
|
||||
ValueRef::Blob(blob) => Ok(Value::Text(BASE64_URL_SAFE.encode(blob))),
|
||||
ValueRef::Text(text_bytes) => {
|
||||
let text_str =
|
||||
std::str::from_utf8(text_bytes).map_err(|err| Error::UserFunctionError(err.into()))?;
|
||||
|
||||
Ok(Value::Blob(
|
||||
BASE64_URL_SAFE
|
||||
.decode(text_str)
|
||||
.map_err(|err| Error::UserFunctionError(err.into()))?,
|
||||
))
|
||||
}
|
||||
v => Err(Error::InvalidFunctionParameterType(0, v.data_type())),
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use base64::prelude::*;
|
||||
use rusqlite::Error;
|
||||
|
||||
#[test]
|
||||
fn test_base64_wrong_number_of_arguments() {
|
||||
let conn = crate::connect_sqlite(None).unwrap();
|
||||
let val = conn.query_row(
|
||||
"SELECT base64_url_safe('a', 'b')",
|
||||
[],
|
||||
|row| -> Result<[u8; 16], Error> { Ok(row.get(0)?) },
|
||||
);
|
||||
assert!(val.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_base64_url_safe_roundtrip() {
|
||||
let conn = crate::connect_sqlite(None).unwrap();
|
||||
|
||||
let value = b"832!@#$%^&*()>./";
|
||||
for query in [
|
||||
format!("SELECT base64(base64(?1))"),
|
||||
format!("SELECT base64_url_safe(base64_url_safe(?1))"),
|
||||
] {
|
||||
// BLOB → TEXT → BLOB round-trip test
|
||||
let val = conn
|
||||
.query_row(&query, [value], |row| -> Result<Vec<u8>, Error> {
|
||||
Ok(row.get(0)?)
|
||||
})
|
||||
.unwrap();
|
||||
assert_eq!(val, value);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_base64_url_safe_null_handling() {
|
||||
let conn = crate::connect_sqlite(None).unwrap();
|
||||
for query in [
|
||||
format!("SELECT base64(NULL)"),
|
||||
format!("SELECT base64_url_safe(NULL)"),
|
||||
] {
|
||||
let val: Option<Vec<u8>> = conn.query_row(&query, [], |row| row.get(0)).unwrap();
|
||||
assert!(val.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_base64_url_safe_empty_string() {
|
||||
let conn = crate::connect_sqlite(None).unwrap();
|
||||
for query in [
|
||||
format!("SELECT base64('')"),
|
||||
format!("SELECT base64_url_safe('')"),
|
||||
] {
|
||||
let val: Vec<u8> = conn.query_row(&query, [], |row| row.get(0)).unwrap();
|
||||
assert!(val.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_base64_url_safe_trimmed_input() {
|
||||
let conn = crate::connect_sqlite(None).unwrap();
|
||||
let encoded = BASE64_URL_SAFE.encode(&[1, 2, 3, 4]);
|
||||
|
||||
for query in [
|
||||
format!("SELECT base64(' {encoded} ')"),
|
||||
format!("SELECT base64_url_safe(' {encoded} ')"),
|
||||
] {
|
||||
let v = conn.query_row(&query, [], |row| row.get::<_, Vec<u8>>(0));
|
||||
assert!(v.is_err());
|
||||
}
|
||||
|
||||
for query in [
|
||||
format!("SELECT base64(trim(' {encoded} '))"),
|
||||
format!("SELECT base64_url_safe(trim(' {encoded} '))"),
|
||||
] {
|
||||
let val: Vec<u8> = conn.query_row(&query, [], |row| row.get(0)).unwrap();
|
||||
assert_eq!(val, vec![1, 2, 3, 4]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ pub mod geoip;
|
||||
pub mod jsonschema;
|
||||
pub mod password;
|
||||
|
||||
mod b64;
|
||||
mod regex;
|
||||
mod uuid;
|
||||
mod validators;
|
||||
@@ -73,8 +74,8 @@ pub fn connect_sqlite(path: Option<PathBuf>) -> Result<rusqlite::Connection, Err
|
||||
pub fn sqlite3_extension_init(
|
||||
db: rusqlite::Connection,
|
||||
) -> Result<rusqlite::Connection, rusqlite::Error> {
|
||||
// WARN: Be careful with declaring INNOCUOUS. This allows these "app-defined functions" to run
|
||||
// even when "trusted_schema=OFF", which means as part of: VIEWs, TRIGGERs, CHECK, DEFAULT,
|
||||
// WARN: Be careful with declaring INNOCUOUS. It allows "user-defined functions" to run
|
||||
// when "trusted_schema=OFF", which means as part of: VIEWs, TRIGGERs, CHECK, DEFAULT,
|
||||
// GENERATED cols, ... as opposed to just top-level SELECTs.
|
||||
|
||||
db.create_scalar_function(
|
||||
@@ -100,7 +101,9 @@ pub fn sqlite3_extension_init(
|
||||
db.create_scalar_function(
|
||||
"uuid_text",
|
||||
1,
|
||||
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_INNOCUOUS,
|
||||
FunctionFlags::SQLITE_UTF8
|
||||
| FunctionFlags::SQLITE_DETERMINISTIC
|
||||
| FunctionFlags::SQLITE_INNOCUOUS,
|
||||
uuid::uuid_text,
|
||||
)?;
|
||||
|
||||
@@ -117,7 +120,9 @@ pub fn sqlite3_extension_init(
|
||||
db.create_scalar_function(
|
||||
"hash_password",
|
||||
1,
|
||||
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_INNOCUOUS,
|
||||
FunctionFlags::SQLITE_UTF8
|
||||
| FunctionFlags::SQLITE_DETERMINISTIC
|
||||
| FunctionFlags::SQLITE_INNOCUOUS,
|
||||
password::hash_password_sqlite,
|
||||
)?;
|
||||
|
||||
@@ -202,6 +207,23 @@ pub fn sqlite3_extension_init(
|
||||
geoip::geoip_city_json,
|
||||
)?;
|
||||
|
||||
db.create_scalar_function(
|
||||
"base64",
|
||||
1,
|
||||
FunctionFlags::SQLITE_UTF8
|
||||
| FunctionFlags::SQLITE_DETERMINISTIC
|
||||
| FunctionFlags::SQLITE_INNOCUOUS,
|
||||
b64::base64,
|
||||
)?;
|
||||
db.create_scalar_function(
|
||||
"base64_url_safe",
|
||||
1,
|
||||
FunctionFlags::SQLITE_UTF8
|
||||
| FunctionFlags::SQLITE_DETERMINISTIC
|
||||
| FunctionFlags::SQLITE_INNOCUOUS,
|
||||
b64::base64_url_safe,
|
||||
)?;
|
||||
|
||||
return Ok(db);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user