Provide a base64_url_safe modeled based on SQLite's own base64() extension.

This commit is contained in:
Bilux
2025-10-14 22:06:26 +01:00
committed by Sebastian Jeltsch
parent dc1978ff2e
commit c0f896f621
2 changed files with 179 additions and 4 deletions

153
crates/extension/src/b64.rs Normal file
View File

@@ -0,0 +1,153 @@
use base64::prelude::*;
use rusqlite::Error;
use rusqlite::Result;
use rusqlite::functions::Context;
use rusqlite::types::{Value, ValueRef};
/// A base64 conversion utility similar to the SQLite's `base64()` extension.
///
/// It introspects on the inputs, and will convert based on that:
/// - BLOB → TEXT (encode)
/// - TEXT → BLOB (decode)
/// - NULL → NULL
/// - Other types → error
///
/// Note, however, that this implementation is more strict with respect to b64
/// text inputs, e.g. it checks the padding.
pub(super) fn base64(context: &Context) -> Result<Value> {
#[cfg(debug_assertions)]
if context.len() != 1 {
return Err(Error::InvalidParameterCount(context.len(), 1));
}
return match context.get_raw(0) {
ValueRef::Null => Ok(Value::Null),
ValueRef::Blob(blob) => Ok(Value::Text(BASE64_STANDARD.encode(blob))),
ValueRef::Text(text_bytes) => {
let text_str =
std::str::from_utf8(text_bytes).map_err(|err| Error::UserFunctionError(err.into()))?;
Ok(Value::Blob(
BASE64_STANDARD
.decode(text_str)
.map_err(|err| Error::UserFunctionError(err.into()))?,
))
}
v => Err(Error::InvalidFunctionParameterType(0, v.data_type())),
};
}
/// A URL-safe base64 conversion utility similar to the SQLite's `base64()` extension.
///
/// It introspects on the inputs, and will convert based on that:
/// - BLOB → TEXT (encode)
/// - TEXT → BLOB (decode)
/// - NULL → NULL
/// - Other types → error
///
/// Note, however, that this implementation is more strict with respect to b64
/// text inputs, e.g. it checks the padding.
pub(super) fn base64_url_safe(context: &Context) -> Result<Value> {
#[cfg(debug_assertions)]
if context.len() != 1 {
return Err(Error::InvalidParameterCount(context.len(), 1));
}
return match context.get_raw(0) {
ValueRef::Null => Ok(Value::Null),
ValueRef::Blob(blob) => Ok(Value::Text(BASE64_URL_SAFE.encode(blob))),
ValueRef::Text(text_bytes) => {
let text_str =
std::str::from_utf8(text_bytes).map_err(|err| Error::UserFunctionError(err.into()))?;
Ok(Value::Blob(
BASE64_URL_SAFE
.decode(text_str)
.map_err(|err| Error::UserFunctionError(err.into()))?,
))
}
v => Err(Error::InvalidFunctionParameterType(0, v.data_type())),
};
}
#[cfg(test)]
mod tests {
use base64::prelude::*;
use rusqlite::Error;
#[test]
fn test_base64_wrong_number_of_arguments() {
let conn = crate::connect_sqlite(None).unwrap();
let val = conn.query_row(
"SELECT base64_url_safe('a', 'b')",
[],
|row| -> Result<[u8; 16], Error> { Ok(row.get(0)?) },
);
assert!(val.is_err());
}
#[test]
fn test_base64_url_safe_roundtrip() {
let conn = crate::connect_sqlite(None).unwrap();
let value = b"832!@#$%^&*()>./";
for query in [
format!("SELECT base64(base64(?1))"),
format!("SELECT base64_url_safe(base64_url_safe(?1))"),
] {
// BLOB → TEXT → BLOB round-trip test
let val = conn
.query_row(&query, [value], |row| -> Result<Vec<u8>, Error> {
Ok(row.get(0)?)
})
.unwrap();
assert_eq!(val, value);
}
}
#[test]
fn test_base64_url_safe_null_handling() {
let conn = crate::connect_sqlite(None).unwrap();
for query in [
format!("SELECT base64(NULL)"),
format!("SELECT base64_url_safe(NULL)"),
] {
let val: Option<Vec<u8>> = conn.query_row(&query, [], |row| row.get(0)).unwrap();
assert!(val.is_none());
}
}
#[test]
fn test_base64_url_safe_empty_string() {
let conn = crate::connect_sqlite(None).unwrap();
for query in [
format!("SELECT base64('')"),
format!("SELECT base64_url_safe('')"),
] {
let val: Vec<u8> = conn.query_row(&query, [], |row| row.get(0)).unwrap();
assert!(val.is_empty());
}
}
#[test]
fn test_base64_url_safe_trimmed_input() {
let conn = crate::connect_sqlite(None).unwrap();
let encoded = BASE64_URL_SAFE.encode(&[1, 2, 3, 4]);
for query in [
format!("SELECT base64(' {encoded} ')"),
format!("SELECT base64_url_safe(' {encoded} ')"),
] {
let v = conn.query_row(&query, [], |row| row.get::<_, Vec<u8>>(0));
assert!(v.is_err());
}
for query in [
format!("SELECT base64(trim(' {encoded} '))"),
format!("SELECT base64_url_safe(trim(' {encoded} '))"),
] {
let val: Vec<u8> = conn.query_row(&query, [], |row| row.get(0)).unwrap();
assert_eq!(val, vec![1, 2, 3, 4]);
}
}
}

View File

@@ -8,6 +8,7 @@ pub mod geoip;
pub mod jsonschema;
pub mod password;
mod b64;
mod regex;
mod uuid;
mod validators;
@@ -73,8 +74,8 @@ pub fn connect_sqlite(path: Option<PathBuf>) -> Result<rusqlite::Connection, Err
pub fn sqlite3_extension_init(
db: rusqlite::Connection,
) -> Result<rusqlite::Connection, rusqlite::Error> {
// WARN: Be careful with declaring INNOCUOUS. This allows these "app-defined functions" to run
// even when "trusted_schema=OFF", which means as part of: VIEWs, TRIGGERs, CHECK, DEFAULT,
// WARN: Be careful with declaring INNOCUOUS. It allows "user-defined functions" to run
// when "trusted_schema=OFF", which means as part of: VIEWs, TRIGGERs, CHECK, DEFAULT,
// GENERATED cols, ... as opposed to just top-level SELECTs.
db.create_scalar_function(
@@ -100,7 +101,9 @@ pub fn sqlite3_extension_init(
db.create_scalar_function(
"uuid_text",
1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_INNOCUOUS,
FunctionFlags::SQLITE_UTF8
| FunctionFlags::SQLITE_DETERMINISTIC
| FunctionFlags::SQLITE_INNOCUOUS,
uuid::uuid_text,
)?;
@@ -117,7 +120,9 @@ pub fn sqlite3_extension_init(
db.create_scalar_function(
"hash_password",
1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_INNOCUOUS,
FunctionFlags::SQLITE_UTF8
| FunctionFlags::SQLITE_DETERMINISTIC
| FunctionFlags::SQLITE_INNOCUOUS,
password::hash_password_sqlite,
)?;
@@ -202,6 +207,23 @@ pub fn sqlite3_extension_init(
geoip::geoip_city_json,
)?;
db.create_scalar_function(
"base64",
1,
FunctionFlags::SQLITE_UTF8
| FunctionFlags::SQLITE_DETERMINISTIC
| FunctionFlags::SQLITE_INNOCUOUS,
b64::base64,
)?;
db.create_scalar_function(
"base64_url_safe",
1,
FunctionFlags::SQLITE_UTF8
| FunctionFlags::SQLITE_DETERMINISTIC
| FunctionFlags::SQLITE_INNOCUOUS,
b64::base64_url_safe,
)?;
return Ok(db);
}