Minor: tidy up the VIEW extraction code.

This commit is contained in:
Sebastian Jeltsch
2025-07-22 10:21:31 +02:00
parent b3a9656964
commit 7ffd800bc9
2 changed files with 211 additions and 242 deletions
+69 -74
View File
@@ -514,23 +514,34 @@ mod tests {
use std::collections::HashSet;
use super::*;
use crate::sqlite::{Table, sqlite3_parse_into_statement};
use crate::sqlite::{SchemaError, Table, sqlite3_parse_into_statement};
fn parse_create_table(create_table_sql: &str) -> Table {
let create_table_statement = sqlite3_parse_into_statement(create_table_sql)
.unwrap()
.unwrap();
return create_table_statement.try_into().unwrap();
}
fn parse_create_view(create_view_sql: &str, tables: &[Table]) -> Result<View, SchemaError> {
let create_view_statement = sqlite3_parse_into_statement(create_view_sql)
.unwrap()
.unwrap();
return View::from(create_view_statement, tables);
}
#[test]
fn test_parse_create_view() {
let table: Table = {
let table_sql = r#"
CREATE TABLE table0 (
id BLOB PRIMARY KEY NOT NULL CHECK(is_uuid_v7(id)) DEFAULT (uuid_v7()),
col0 TEXT NOT NULL DEFAULT '',
col1 BLOB NOT NULL,
hidden INTEGER DEFAULT 42
) STRICT;
"#;
let create_table_statement = sqlite3_parse_into_statement(table_sql).unwrap().unwrap();
create_table_statement.try_into().unwrap()
};
let table = parse_create_table(
r#"
CREATE TABLE table0 (
id BLOB PRIMARY KEY NOT NULL CHECK(is_uuid_v7(id)) DEFAULT (uuid_v7()),
col0 TEXT NOT NULL DEFAULT '',
col1 BLOB NOT NULL,
hidden INTEGER DEFAULT 42
) STRICT;
"#,
);
let tables = [table.clone()];
let metadata = TableMetadata::new(table, &tables, "_user");
@@ -540,12 +551,11 @@ mod tests {
assert_eq!(1, *metadata.name_to_index.get("col0").unwrap());
{
let table_view: View = {
let view_sql = "CREATE VIEW view0 AS SELECT col0, col1 FROM table0";
let create_view_statement = sqlite3_parse_into_statement(view_sql).unwrap().unwrap();
View::from(create_view_statement, &tables).unwrap()
};
let table_view = parse_create_view(
"CREATE VIEW view0 AS SELECT col0, col1 FROM table0",
&tables,
)
.unwrap();
assert_eq!(table_view.name.name, "view0");
assert_eq!(table_view.query, "SELECT col0, col1 FROM table0");
assert_eq!(table_view.temporary, false);
@@ -567,12 +577,8 @@ mod tests {
{
let query = "SELECT id, col0, col1 FROM table0";
let table_view: View = {
let view_sql = format!("CREATE VIEW view0 AS {query}");
let create_view_statement = sqlite3_parse_into_statement(&view_sql).unwrap().unwrap();
View::from(create_view_statement, &tables).unwrap()
};
let table_view =
parse_create_view(&format!("CREATE VIEW view0 AS {query}"), &tables).unwrap();
assert_eq!(table_view.name.name, "view0");
assert_eq!(table_view.query, query);
@@ -589,21 +595,18 @@ mod tests {
#[test]
fn test_parse_create_view_with_subquery() {
let table_a: Table = {
let table_sql =
"CREATE TABLE a (id INTEGER PRIMARY KEY, data TEXT NOT NULL DEFAULT '') STRICT";
let stmt = sqlite3_parse_into_statement(table_sql).unwrap().unwrap();
stmt.try_into().unwrap()
};
let table_a = parse_create_table(
"CREATE TABLE a (id INTEGER PRIMARY KEY, data TEXT NOT NULL DEFAULT '') STRICT",
);
let tables = [table_a];
{
let view: View = {
let view_sql = "CREATE VIEW view0 AS SELECT * FROM (SELECT * FROM a);";
let create_view_statement = sqlite3_parse_into_statement(&view_sql).unwrap().unwrap();
View::from(create_view_statement, &tables).unwrap()
};
let view = parse_create_view(
"CREATE VIEW view0 AS SELECT * FROM (SELECT * FROM a);",
&tables,
)
.unwrap();
let view_columns = view.columns.as_ref().unwrap();
assert_eq!(view_columns.len(), 2);
@@ -620,12 +623,10 @@ mod tests {
}
{
let _view_result: Result<View, _> = {
let view_sql = "CREATE VIEW view0 AS SELECT id FROM (SELECT * FROM a);";
let create_view_statement = sqlite3_parse_into_statement(&view_sql).unwrap().unwrap();
View::from(create_view_statement, &tables)
};
let _view_result = parse_create_view(
"CREATE VIEW view0 AS SELECT id FROM (SELECT * FROM a);",
&tables,
);
// TODO: Support column filter on sub-queries.
// let view = _view_result.unwrap();
@@ -644,33 +645,25 @@ mod tests {
#[test]
fn test_parse_create_view_with_joins() {
let table_a: Table = {
let table_sql =
"CREATE TABLE a (id INTEGER PRIMARY KEY, data TEXT NOT NULL DEFAULT '') STRICT";
let stmt = sqlite3_parse_into_statement(table_sql).unwrap().unwrap();
stmt.try_into().unwrap()
};
let table_b: Table = {
let table_sql = r#"
let table_a = parse_create_table(
"CREATE TABLE a (id INTEGER PRIMARY KEY, data TEXT NOT NULL DEFAULT '') STRICT",
);
let table_b = parse_create_table(
r#"
CREATE TABLE b (
id INTEGER PRIMARY KEY,
fk INTEGER NOT NULL REFERENCES a(id)
) STRICT"#;
let stmt = sqlite3_parse_into_statement(table_sql).unwrap().unwrap();
stmt.try_into().unwrap()
};
) STRICT"#,
);
let tables = [table_a, table_b];
{
// LEFT JOIN
let view: View = {
let view_sql = r#"
CREATE VIEW view0 AS SELECT a.data, b.fk, a.id FROM a AS a LEFT JOIN b AS b ON a.id = b.fk;
"#;
let create_view_statement = sqlite3_parse_into_statement(&view_sql).unwrap().unwrap();
View::from(create_view_statement, &tables).unwrap()
};
let view = parse_create_view(
"CREATE VIEW view0 AS SELECT a.data, b.fk, a.id FROM a AS a LEFT JOIN b AS b ON a.id = b.fk;",
&tables,
).unwrap();
let view_columns = view.columns.as_ref().unwrap();
assert_eq!(view_columns.len(), 3);
@@ -696,13 +689,13 @@ mod tests {
name: "table_name".to_string(),
database_schema: Some("main".to_string()),
};
let table_sql = format!(
let table = parse_create_table(&format!(
"CREATE TABLE {table_name} (id INTEGER PRIMARY KEY) STRICT",
table_name = table_name.escaped_string()
);
let create_table_statement = sqlite3_parse_into_statement(&table_sql).unwrap().unwrap();
let table: Table = create_table_statement.try_into().unwrap();
let table_metadata = TableMetadata::new(table.clone(), &[table.clone()], "_user");
));
let tables = [table.clone()];
let table_metadata = TableMetadata::new(table.clone(), &tables, "_user");
let mut table_set = HashSet::<TableMetadata>::new();
@@ -718,13 +711,15 @@ mod tests {
name: "view_name".to_string(),
database_schema: Some("main".to_string()),
};
let view_sql = format!(
"CREATE VIEW {view_name} AS SELECT id FROM {table_name}",
view_name = view_name.escaped_string(),
table_name = table_name.escaped_string()
);
let create_view_statement = sqlite3_parse_into_statement(&view_sql).unwrap().unwrap();
let table_view = View::from(create_view_statement, &[table.clone()]).unwrap();
let table_view = parse_create_view(
&format!(
"CREATE VIEW {view_name} AS SELECT id FROM {table_name}",
view_name = view_name.escaped_string(),
table_name = table_name.escaped_string()
),
&tables,
)
.unwrap();
let view_metadata = Arc::new(ViewMetadata::new(table_view, &[table.clone()]));
let mut view_set = HashSet::<Arc<ViewMetadata>>::new();
+142 -168
View File
@@ -1,4 +1,5 @@
use fallible_iterator::FallibleIterator;
use indexmap::IndexMap;
use itertools::Itertools;
use log::*;
use serde::{Deserialize, Serialize};
@@ -933,40 +934,51 @@ impl std::fmt::Display for SelectFormatter {
}
impl View {
pub fn from(value: sqlite3_parser::ast::Stmt, tables: &[Table]) -> Result<Self, SchemaError> {
return match value {
sqlite3_parser::ast::Stmt::CreateView {
temporary,
if_not_exists,
view_name,
columns,
select,
} => {
let columns = match columns.is_some() {
true => {
info!("CREATE VIEW column filtering not supported (yet)");
None
}
false => try_extract_column_mapping((*select).clone(), tables)?.map(|column_mapping| {
column_mapping
.into_iter()
.map(|mapping| mapping.column)
.collect()
}),
};
Ok(View {
name: view_name.into(),
columns,
query: SelectFormatter(*select).to_string(),
temporary,
if_not_exists,
})
}
_ => Err(SchemaError::Precondition(
format!("expected 'CREATE VIEW', got: {value:?}").into(),
)),
pub fn from(stmt: sqlite3_parser::ast::Stmt, tables: &[Table]) -> Result<Self, SchemaError> {
let sqlite3_parser::ast::Stmt::CreateView {
temporary,
if_not_exists,
view_name,
columns,
select,
} = stmt
else {
return Err(SchemaError::Precondition(
format!("expected 'CREATE VIEW', got: {stmt:?}").into(),
));
};
let column_mapping: Option<Vec<ColumnMapping>> = if columns.is_some() {
// Example, `CREATE VIEW view0(alias0, alias1) AS SELECT * FROM table0;`
//
// We probably never want to support this due to its late failure mode,
// i.e. column mismatches are discovered at query-time rather than
// view-creation. Also table schema changes may later invalidate
// existing views.
debug!("VIEW column aliases not supported for APIs");
None
} else {
// Try to parse columns very liberally. We don't want to disallow complex
// VIEWs but returning a `View` with `None` columns, means it cannot be used
// for APIs.
extract_column_mapping((*select).clone(), tables)
.map_err(|err| {
debug!(
"Failed to extract VIEW column mapping from '{:?}': {err}",
*select
);
return err;
})
.ok()
};
return Ok(View {
name: view_name.into(),
columns: column_mapping.map(|o| o.into_iter().map(|m| m.column).collect()),
query: SelectFormatter(*select).to_string(),
temporary,
if_not_exists,
});
}
}
@@ -1000,14 +1012,17 @@ struct ColumnMapping {
referred_column: Option<ReferredColumn>,
}
fn try_extract_column_mapping(
fn extract_column_mapping(
select: sqlite3_parser::ast::Select,
tables: &[Table],
) -> Result<Option<Vec<ColumnMapping>>, SchemaError> {
let body = select.body;
) -> Result<Vec<ColumnMapping>, SchemaError> {
fn precondition(m: &str) -> SchemaError {
return SchemaError::Precondition(m.into());
}
let body = select.body;
if body.compounds.is_some() {
return Ok(None);
return Err(precondition("Compound bodies not (yet) supported"));
}
let sqlite3_parser::ast::OneSelect::Select {
@@ -1020,44 +1035,65 @@ fn try_extract_column_mapping(
window_clause,
} = body.select
else {
return Ok(None);
return Err(precondition(&format!(
"Expected Select, got: {:?}",
body.select
)));
};
if distinctness.is_some() || group_by.is_some() || window_clause.is_some() {
return Ok(None);
if group_by.is_some() {
return Err(precondition("GROUP BY clause not (yet) supported"));
}
if distinctness.is_some() {
return Err(precondition("DISTINCT clause not (yet) supported"));
}
if window_clause.is_some() {
return Err(precondition("WINDOW clause not (yet) supported"));
}
// First build list of referenced tables and their aliases.
let Some(FromClause { select, joins, .. }) = from else {
return Ok(None);
let Some(FromClause {
select: nested_select,
joins,
..
}) = from
else {
return Err(precondition("missing FROM clause"));
};
let Some(select) = select else {
return Ok(None);
};
let (fqn, alias) = match *select {
SelectTable::Table(fqn, alias, _indexed) => (fqn, alias),
SelectTable::Select(select, _as) => {
if Some(&ResultColumn::Star) == columns.get(0) {
let (fqn, alias) = match nested_select.map(|s| *s) {
Some(SelectTable::Table(fqn, alias, _indexed)) => (fqn, alias),
Some(SelectTable::Select(select, _as)) => {
// Nested sub-query case.
if Some(&ResultColumn::Star) == columns.first() {
// Recurse
return try_extract_column_mapping(*select, tables);
return extract_column_mapping(*select, tables);
}
// Support more complex
debug!("The following sub-query is not (yet) supported: {select:?}");
return Ok(None);
return Err(precondition(&format!(
"The following sub-query is not (yet) supported: {select:?}"
)));
}
_ => {
debug!("The following select is not (yet) supported: {select:?}");
return Ok(None);
Some(x) => {
return Err(precondition(&format!(
"The following sub-query is not (yet) supported: {x:?}"
)));
}
None => {
return Err(precondition("missing SELECT"));
}
};
// Use IndexMap to preserve insertion order.
let mut table_names = indexmap::IndexMap::<String, QualifiedName>::from([to_entry(fqn, alias)]);
let mut table_names = IndexMap::<String, QualifiedName>::from([to_entry(fqn, alias)]);
if let Some(joins) = joins {
for join in joins {
let SelectTable::Table(fqn, alias, _indexed) = join.table else {
return Ok(None);
return Err(precondition("JOIN with TABLE expected"));
};
let entry = to_entry(fqn, alias);
@@ -1075,8 +1111,9 @@ fn try_extract_column_mapping(
match all_tables.get(table_name) {
Some(table) => {
if !table.strict {
info!("Skipping view: referenced table: {table_name:?} not strict");
return Ok(None);
return Err(precondition(&format!(
"Referenced table: {table_name:?} must be STRICT"
)));
}
for col in &table.columns {
@@ -1084,9 +1121,7 @@ fn try_extract_column_mapping(
}
}
None => {
return Err(SchemaError::Precondition(
format!("View's SELECT references missing table: {table_name:?}").into(),
));
return Err(precondition(&format!("Table missing: {table_name:?}")));
}
};
}
@@ -1113,9 +1148,7 @@ fn try_extract_column_mapping(
ResultColumn::TableStar(name) => {
let name = unquote_name(name);
let Some(table_name) = table_names.get(&name) else {
return Err(SchemaError::Precondition(
format!("Missing alias: {name}").into(),
));
return Err(precondition(&format!("Missing alias: {name}")));
};
let table = all_tables.get(table_name).expect("checked above");
@@ -1133,9 +1166,7 @@ fn try_extract_column_mapping(
Expr::Id(id) => {
let col_name = unquote_id(id.clone());
let Some((table, column)) = all_columns.get(&col_name) else {
return Err(SchemaError::Precondition(
format!("Missing columns: {id:?}").into(),
));
return Err(precondition(&format!("Missing columns: {id:?}")));
};
let name = alias
@@ -1164,16 +1195,14 @@ fn try_extract_column_mapping(
let col_name = unquote_name(name.clone());
let Some(table_name) = table_names.get(&qualifier) else {
return Err(SchemaError::Precondition(
format!("Missing table: Qualified({qualifier}, {name})").into(),
));
return Err(precondition(&format!(
"Missing table ({qualifier}, {name})"
)));
};
let table = all_tables.get(table_name).expect("checked above");
let Some(column) = table.columns.iter().find(|c| c.name == col_name) else {
return Err(SchemaError::Precondition(
format!("Missing col: {col_name}").into(),
));
return Err(precondition(&format!("Missing col: {col_name}")));
};
let name = alias
@@ -1227,18 +1256,15 @@ fn try_extract_column_mapping(
referred_column: None,
});
}
_x => {
x => {
// We cannot map arbitrary expressions.
#[cfg(debug_assertions)]
debug!("skipping expr: {_x:?}");
return Ok(None);
return Err(precondition(&format!("Unsupported expr: {x:?}")));
}
},
};
}
return Ok(Some(mapping));
return Ok(mapping);
}
fn build_foreign_key(
@@ -1533,8 +1559,11 @@ mod tests {
fn test_parse_create_index() {
let sql =
r#"CREATE UNIQUE INDEX "main"."index_name" ON 'table_name' (a ASC, b DESC) WHERE x > 0"#;
let stmt = sqlite3_parse_into_statement(sql).unwrap().unwrap();
let index: TableIndex = stmt.try_into().unwrap();
let index: TableIndex = sqlite3_parse_into_statement(sql)
.unwrap()
.unwrap()
.try_into()
.unwrap();
let sql1 = index.create_index_statement();
let stmt1 = sqlite3_parse_into_statement(&sql1).unwrap().unwrap();
@@ -1571,9 +1600,7 @@ mod tests {
else {
panic!("Not a select");
};
let _mapping = try_extract_column_mapping(*select, &tables)
.unwrap()
.unwrap();
let _mapping = extract_column_mapping(*select, &tables).unwrap();
}
{
@@ -1584,9 +1611,7 @@ mod tests {
else {
panic!("Not a select");
};
let _mapping = try_extract_column_mapping(*select, &tables)
.unwrap()
.unwrap();
let _mapping = extract_column_mapping(*select, &tables).unwrap();
}
{
@@ -1597,12 +1622,17 @@ mod tests {
else {
panic!("Not a select");
};
let _mapping = try_extract_column_mapping(*select, &tables)
.unwrap()
.unwrap();
let _mapping = extract_column_mapping(*select, &tables).unwrap();
}
}
fn parse_create_table(create_table_sql: &str) -> Table {
let create_table_statement = sqlite3_parse_into_statement(create_table_sql)
.unwrap()
.unwrap();
return create_table_statement.try_into().unwrap();
}
#[test]
fn test_view_column_extraction_join() {
let sql = "SELECT user, *, a.*, p.user AS foo FROM foo.articles AS a LEFT JOIN bar.profiles AS p ON p.user = a.author";
@@ -1612,84 +1642,28 @@ mod tests {
panic!("Not a select");
};
let tables = vec![
Table {
name: QualifiedName {
name: "profiles".to_string(),
database_schema: Some("bar".to_string()),
},
strict: true,
columns: vec![
Column {
name: "user".to_string(),
data_type: ColumnDataType::Blob,
options: vec![
ColumnOption::Unique {
is_primary: true,
conflict_clause: None,
},
ColumnOption::ForeignKey {
foreign_table: "_user".to_string(),
referred_columns: vec!["id".to_string()],
on_delete: None,
on_update: None,
},
],
},
Column {
name: "username".to_string(),
data_type: ColumnDataType::Text,
options: vec![],
},
],
foreign_keys: vec![],
unique: vec![],
checks: vec![],
virtual_table: false,
temporary: false,
},
Table {
name: QualifiedName {
name: "articles".to_string(),
database_schema: Some("foo".to_string()),
},
strict: true,
columns: vec![
Column {
name: "id".to_string(),
data_type: ColumnDataType::Blob,
options: vec![ColumnOption::Unique {
is_primary: true,
conflict_clause: None,
}],
},
Column {
name: "author".to_string(),
data_type: ColumnDataType::Blob,
options: vec![ColumnOption::ForeignKey {
foreign_table: "_user".to_string(),
referred_columns: vec!["id".to_string()],
on_delete: None,
on_update: None,
}],
},
Column {
name: "body".to_string(),
data_type: ColumnDataType::Text,
options: vec![],
},
],
foreign_keys: vec![],
unique: vec![],
checks: vec![],
virtual_table: false,
temporary: false,
},
];
let profiles_table = parse_create_table(
r#"
CREATE TABLE bar.profiles (
user BLOB PRIMARY KEY NOT NULL REFERENCES _user(id),
username TEXT NOT NULL
) STRICT;
"#,
);
let mapping = try_extract_column_mapping(*select, &tables)
.unwrap()
.unwrap();
let articles_table = parse_create_table(
r#"
CREATE TABLE foo.articles (
id BLOB PRIMARY KEY NOT NULL,
author BLOB NOT NULL REFERENCES _user(id),
body TEXT
) STRICT;
"#,
);
let tables = [profiles_table, articles_table];
let mapping = extract_column_mapping(*select, &tables).unwrap();
assert_eq!(
mapping