From 7ffd800bc967cc66401b8e4e7778a8e0e5bd4584 Mon Sep 17 00:00:00 2001 From: Sebastian Jeltsch Date: Tue, 22 Jul 2025 10:21:31 +0200 Subject: [PATCH] Minor: tidy up the VIEW extraction code. --- trailbase-schema/src/metadata.rs | 143 +++++++------- trailbase-schema/src/sqlite.rs | 310 ++++++++++++++----------------- 2 files changed, 211 insertions(+), 242 deletions(-) diff --git a/trailbase-schema/src/metadata.rs b/trailbase-schema/src/metadata.rs index cf0aa9dd..8ddbd77e 100644 --- a/trailbase-schema/src/metadata.rs +++ b/trailbase-schema/src/metadata.rs @@ -514,23 +514,34 @@ mod tests { use std::collections::HashSet; use super::*; - use crate::sqlite::{Table, sqlite3_parse_into_statement}; + use crate::sqlite::{SchemaError, Table, sqlite3_parse_into_statement}; + + fn parse_create_table(create_table_sql: &str) -> Table { + let create_table_statement = sqlite3_parse_into_statement(create_table_sql) + .unwrap() + .unwrap(); + return create_table_statement.try_into().unwrap(); + } + + fn parse_create_view(create_view_sql: &str, tables: &[Table]) -> Result { + let create_view_statement = sqlite3_parse_into_statement(create_view_sql) + .unwrap() + .unwrap(); + return View::from(create_view_statement, tables); + } #[test] fn test_parse_create_view() { - let table: Table = { - let table_sql = r#" - CREATE TABLE table0 ( - id BLOB PRIMARY KEY NOT NULL CHECK(is_uuid_v7(id)) DEFAULT (uuid_v7()), - col0 TEXT NOT NULL DEFAULT '', - col1 BLOB NOT NULL, - hidden INTEGER DEFAULT 42 - ) STRICT; - "#; - - let create_table_statement = sqlite3_parse_into_statement(table_sql).unwrap().unwrap(); - create_table_statement.try_into().unwrap() - }; + let table = parse_create_table( + r#" + CREATE TABLE table0 ( + id BLOB PRIMARY KEY NOT NULL CHECK(is_uuid_v7(id)) DEFAULT (uuid_v7()), + col0 TEXT NOT NULL DEFAULT '', + col1 BLOB NOT NULL, + hidden INTEGER DEFAULT 42 + ) STRICT; + "#, + ); let tables = [table.clone()]; let metadata = TableMetadata::new(table, &tables, "_user"); @@ -540,12 +551,11 @@ mod tests { assert_eq!(1, *metadata.name_to_index.get("col0").unwrap()); { - let table_view: View = { - let view_sql = "CREATE VIEW view0 AS SELECT col0, col1 FROM table0"; - let create_view_statement = sqlite3_parse_into_statement(view_sql).unwrap().unwrap(); - - View::from(create_view_statement, &tables).unwrap() - }; + let table_view = parse_create_view( + "CREATE VIEW view0 AS SELECT col0, col1 FROM table0", + &tables, + ) + .unwrap(); assert_eq!(table_view.name.name, "view0"); assert_eq!(table_view.query, "SELECT col0, col1 FROM table0"); assert_eq!(table_view.temporary, false); @@ -567,12 +577,8 @@ mod tests { { let query = "SELECT id, col0, col1 FROM table0"; - let table_view: View = { - let view_sql = format!("CREATE VIEW view0 AS {query}"); - let create_view_statement = sqlite3_parse_into_statement(&view_sql).unwrap().unwrap(); - - View::from(create_view_statement, &tables).unwrap() - }; + let table_view = + parse_create_view(&format!("CREATE VIEW view0 AS {query}"), &tables).unwrap(); assert_eq!(table_view.name.name, "view0"); assert_eq!(table_view.query, query); @@ -589,21 +595,18 @@ mod tests { #[test] fn test_parse_create_view_with_subquery() { - let table_a: Table = { - let table_sql = - "CREATE TABLE a (id INTEGER PRIMARY KEY, data TEXT NOT NULL DEFAULT '') STRICT"; - let stmt = sqlite3_parse_into_statement(table_sql).unwrap().unwrap(); - stmt.try_into().unwrap() - }; + let table_a = parse_create_table( + "CREATE TABLE a (id INTEGER PRIMARY KEY, data TEXT NOT NULL DEFAULT '') STRICT", + ); let tables = [table_a]; { - let view: View = { - let view_sql = "CREATE VIEW view0 AS SELECT * FROM (SELECT * FROM a);"; - let create_view_statement = sqlite3_parse_into_statement(&view_sql).unwrap().unwrap(); - View::from(create_view_statement, &tables).unwrap() - }; + let view = parse_create_view( + "CREATE VIEW view0 AS SELECT * FROM (SELECT * FROM a);", + &tables, + ) + .unwrap(); let view_columns = view.columns.as_ref().unwrap(); assert_eq!(view_columns.len(), 2); @@ -620,12 +623,10 @@ mod tests { } { - let _view_result: Result = { - let view_sql = "CREATE VIEW view0 AS SELECT id FROM (SELECT * FROM a);"; - let create_view_statement = sqlite3_parse_into_statement(&view_sql).unwrap().unwrap(); - - View::from(create_view_statement, &tables) - }; + let _view_result = parse_create_view( + "CREATE VIEW view0 AS SELECT id FROM (SELECT * FROM a);", + &tables, + ); // TODO: Support column filter on sub-queries. // let view = _view_result.unwrap(); @@ -644,33 +645,25 @@ mod tests { #[test] fn test_parse_create_view_with_joins() { - let table_a: Table = { - let table_sql = - "CREATE TABLE a (id INTEGER PRIMARY KEY, data TEXT NOT NULL DEFAULT '') STRICT"; - let stmt = sqlite3_parse_into_statement(table_sql).unwrap().unwrap(); - stmt.try_into().unwrap() - }; - let table_b: Table = { - let table_sql = r#" + let table_a = parse_create_table( + "CREATE TABLE a (id INTEGER PRIMARY KEY, data TEXT NOT NULL DEFAULT '') STRICT", + ); + let table_b = parse_create_table( + r#" CREATE TABLE b ( id INTEGER PRIMARY KEY, fk INTEGER NOT NULL REFERENCES a(id) - ) STRICT"#; - let stmt = sqlite3_parse_into_statement(table_sql).unwrap().unwrap(); - stmt.try_into().unwrap() - }; + ) STRICT"#, + ); let tables = [table_a, table_b]; { // LEFT JOIN - let view: View = { - let view_sql = r#" - CREATE VIEW view0 AS SELECT a.data, b.fk, a.id FROM a AS a LEFT JOIN b AS b ON a.id = b.fk; - "#; - let create_view_statement = sqlite3_parse_into_statement(&view_sql).unwrap().unwrap(); - View::from(create_view_statement, &tables).unwrap() - }; + let view = parse_create_view( + "CREATE VIEW view0 AS SELECT a.data, b.fk, a.id FROM a AS a LEFT JOIN b AS b ON a.id = b.fk;", + &tables, + ).unwrap(); let view_columns = view.columns.as_ref().unwrap(); assert_eq!(view_columns.len(), 3); @@ -696,13 +689,13 @@ mod tests { name: "table_name".to_string(), database_schema: Some("main".to_string()), }; - let table_sql = format!( + let table = parse_create_table(&format!( "CREATE TABLE {table_name} (id INTEGER PRIMARY KEY) STRICT", table_name = table_name.escaped_string() - ); - let create_table_statement = sqlite3_parse_into_statement(&table_sql).unwrap().unwrap(); - let table: Table = create_table_statement.try_into().unwrap(); - let table_metadata = TableMetadata::new(table.clone(), &[table.clone()], "_user"); + )); + let tables = [table.clone()]; + + let table_metadata = TableMetadata::new(table.clone(), &tables, "_user"); let mut table_set = HashSet::::new(); @@ -718,13 +711,15 @@ mod tests { name: "view_name".to_string(), database_schema: Some("main".to_string()), }; - let view_sql = format!( - "CREATE VIEW {view_name} AS SELECT id FROM {table_name}", - view_name = view_name.escaped_string(), - table_name = table_name.escaped_string() - ); - let create_view_statement = sqlite3_parse_into_statement(&view_sql).unwrap().unwrap(); - let table_view = View::from(create_view_statement, &[table.clone()]).unwrap(); + let table_view = parse_create_view( + &format!( + "CREATE VIEW {view_name} AS SELECT id FROM {table_name}", + view_name = view_name.escaped_string(), + table_name = table_name.escaped_string() + ), + &tables, + ) + .unwrap(); let view_metadata = Arc::new(ViewMetadata::new(table_view, &[table.clone()])); let mut view_set = HashSet::>::new(); diff --git a/trailbase-schema/src/sqlite.rs b/trailbase-schema/src/sqlite.rs index 78e71eda..20bed343 100644 --- a/trailbase-schema/src/sqlite.rs +++ b/trailbase-schema/src/sqlite.rs @@ -1,4 +1,5 @@ use fallible_iterator::FallibleIterator; +use indexmap::IndexMap; use itertools::Itertools; use log::*; use serde::{Deserialize, Serialize}; @@ -933,40 +934,51 @@ impl std::fmt::Display for SelectFormatter { } impl View { - pub fn from(value: sqlite3_parser::ast::Stmt, tables: &[Table]) -> Result { - return match value { - sqlite3_parser::ast::Stmt::CreateView { - temporary, - if_not_exists, - view_name, - columns, - select, - } => { - let columns = match columns.is_some() { - true => { - info!("CREATE VIEW column filtering not supported (yet)"); - None - } - false => try_extract_column_mapping((*select).clone(), tables)?.map(|column_mapping| { - column_mapping - .into_iter() - .map(|mapping| mapping.column) - .collect() - }), - }; - - Ok(View { - name: view_name.into(), - columns, - query: SelectFormatter(*select).to_string(), - temporary, - if_not_exists, - }) - } - _ => Err(SchemaError::Precondition( - format!("expected 'CREATE VIEW', got: {value:?}").into(), - )), + pub fn from(stmt: sqlite3_parser::ast::Stmt, tables: &[Table]) -> Result { + let sqlite3_parser::ast::Stmt::CreateView { + temporary, + if_not_exists, + view_name, + columns, + select, + } = stmt + else { + return Err(SchemaError::Precondition( + format!("expected 'CREATE VIEW', got: {stmt:?}").into(), + )); }; + + let column_mapping: Option> = if columns.is_some() { + // Example, `CREATE VIEW view0(alias0, alias1) AS SELECT * FROM table0;` + // + // We probably never want to support this due to its late failure mode, + // i.e. column mismatches are discovered at query-time rather than + // view-creation. Also table schema changes may later invalidate + // existing views. + debug!("VIEW column aliases not supported for APIs"); + None + } else { + // Try to parse columns very liberally. We don't want to disallow complex + // VIEWs but returning a `View` with `None` columns, means it cannot be used + // for APIs. + extract_column_mapping((*select).clone(), tables) + .map_err(|err| { + debug!( + "Failed to extract VIEW column mapping from '{:?}': {err}", + *select + ); + return err; + }) + .ok() + }; + + return Ok(View { + name: view_name.into(), + columns: column_mapping.map(|o| o.into_iter().map(|m| m.column).collect()), + query: SelectFormatter(*select).to_string(), + temporary, + if_not_exists, + }); } } @@ -1000,14 +1012,17 @@ struct ColumnMapping { referred_column: Option, } -fn try_extract_column_mapping( +fn extract_column_mapping( select: sqlite3_parser::ast::Select, tables: &[Table], -) -> Result>, SchemaError> { - let body = select.body; +) -> Result, SchemaError> { + fn precondition(m: &str) -> SchemaError { + return SchemaError::Precondition(m.into()); + } + let body = select.body; if body.compounds.is_some() { - return Ok(None); + return Err(precondition("Compound bodies not (yet) supported")); } let sqlite3_parser::ast::OneSelect::Select { @@ -1020,44 +1035,65 @@ fn try_extract_column_mapping( window_clause, } = body.select else { - return Ok(None); + return Err(precondition(&format!( + "Expected Select, got: {:?}", + body.select + ))); }; - if distinctness.is_some() || group_by.is_some() || window_clause.is_some() { - return Ok(None); + if group_by.is_some() { + return Err(precondition("GROUP BY clause not (yet) supported")); + } + + if distinctness.is_some() { + return Err(precondition("DISTINCT clause not (yet) supported")); + } + + if window_clause.is_some() { + return Err(precondition("WINDOW clause not (yet) supported")); } // First build list of referenced tables and their aliases. - let Some(FromClause { select, joins, .. }) = from else { - return Ok(None); + let Some(FromClause { + select: nested_select, + joins, + .. + }) = from + else { + return Err(precondition("missing FROM clause")); }; - let Some(select) = select else { - return Ok(None); - }; - let (fqn, alias) = match *select { - SelectTable::Table(fqn, alias, _indexed) => (fqn, alias), - SelectTable::Select(select, _as) => { - if Some(&ResultColumn::Star) == columns.get(0) { + + let (fqn, alias) = match nested_select.map(|s| *s) { + Some(SelectTable::Table(fqn, alias, _indexed)) => (fqn, alias), + Some(SelectTable::Select(select, _as)) => { + // Nested sub-query case. + if Some(&ResultColumn::Star) == columns.first() { // Recurse - return try_extract_column_mapping(*select, tables); + return extract_column_mapping(*select, tables); } + // Support more complex - debug!("The following sub-query is not (yet) supported: {select:?}"); - return Ok(None); + return Err(precondition(&format!( + "The following sub-query is not (yet) supported: {select:?}" + ))); } - _ => { - debug!("The following select is not (yet) supported: {select:?}"); - return Ok(None); + Some(x) => { + return Err(precondition(&format!( + "The following sub-query is not (yet) supported: {x:?}" + ))); + } + None => { + return Err(precondition("missing SELECT")); } }; // Use IndexMap to preserve insertion order. - let mut table_names = indexmap::IndexMap::::from([to_entry(fqn, alias)]); + let mut table_names = IndexMap::::from([to_entry(fqn, alias)]); if let Some(joins) = joins { for join in joins { let SelectTable::Table(fqn, alias, _indexed) = join.table else { - return Ok(None); + return Err(precondition("JOIN with TABLE expected")); }; let entry = to_entry(fqn, alias); @@ -1075,8 +1111,9 @@ fn try_extract_column_mapping( match all_tables.get(table_name) { Some(table) => { if !table.strict { - info!("Skipping view: referenced table: {table_name:?} not strict"); - return Ok(None); + return Err(precondition(&format!( + "Referenced table: {table_name:?} must be STRICT" + ))); } for col in &table.columns { @@ -1084,9 +1121,7 @@ fn try_extract_column_mapping( } } None => { - return Err(SchemaError::Precondition( - format!("View's SELECT references missing table: {table_name:?}").into(), - )); + return Err(precondition(&format!("Table missing: {table_name:?}"))); } }; } @@ -1113,9 +1148,7 @@ fn try_extract_column_mapping( ResultColumn::TableStar(name) => { let name = unquote_name(name); let Some(table_name) = table_names.get(&name) else { - return Err(SchemaError::Precondition( - format!("Missing alias: {name}").into(), - )); + return Err(precondition(&format!("Missing alias: {name}"))); }; let table = all_tables.get(table_name).expect("checked above"); @@ -1133,9 +1166,7 @@ fn try_extract_column_mapping( Expr::Id(id) => { let col_name = unquote_id(id.clone()); let Some((table, column)) = all_columns.get(&col_name) else { - return Err(SchemaError::Precondition( - format!("Missing columns: {id:?}").into(), - )); + return Err(precondition(&format!("Missing columns: {id:?}"))); }; let name = alias @@ -1164,16 +1195,14 @@ fn try_extract_column_mapping( let col_name = unquote_name(name.clone()); let Some(table_name) = table_names.get(&qualifier) else { - return Err(SchemaError::Precondition( - format!("Missing table: Qualified({qualifier}, {name})").into(), - )); + return Err(precondition(&format!( + "Missing table ({qualifier}, {name})" + ))); }; let table = all_tables.get(table_name).expect("checked above"); let Some(column) = table.columns.iter().find(|c| c.name == col_name) else { - return Err(SchemaError::Precondition( - format!("Missing col: {col_name}").into(), - )); + return Err(precondition(&format!("Missing col: {col_name}"))); }; let name = alias @@ -1227,18 +1256,15 @@ fn try_extract_column_mapping( referred_column: None, }); } - _x => { + x => { // We cannot map arbitrary expressions. - #[cfg(debug_assertions)] - debug!("skipping expr: {_x:?}"); - - return Ok(None); + return Err(precondition(&format!("Unsupported expr: {x:?}"))); } }, }; } - return Ok(Some(mapping)); + return Ok(mapping); } fn build_foreign_key( @@ -1533,8 +1559,11 @@ mod tests { fn test_parse_create_index() { let sql = r#"CREATE UNIQUE INDEX "main"."index_name" ON 'table_name' (a ASC, b DESC) WHERE x > 0"#; - let stmt = sqlite3_parse_into_statement(sql).unwrap().unwrap(); - let index: TableIndex = stmt.try_into().unwrap(); + let index: TableIndex = sqlite3_parse_into_statement(sql) + .unwrap() + .unwrap() + .try_into() + .unwrap(); let sql1 = index.create_index_statement(); let stmt1 = sqlite3_parse_into_statement(&sql1).unwrap().unwrap(); @@ -1571,9 +1600,7 @@ mod tests { else { panic!("Not a select"); }; - let _mapping = try_extract_column_mapping(*select, &tables) - .unwrap() - .unwrap(); + let _mapping = extract_column_mapping(*select, &tables).unwrap(); } { @@ -1584,9 +1611,7 @@ mod tests { else { panic!("Not a select"); }; - let _mapping = try_extract_column_mapping(*select, &tables) - .unwrap() - .unwrap(); + let _mapping = extract_column_mapping(*select, &tables).unwrap(); } { @@ -1597,12 +1622,17 @@ mod tests { else { panic!("Not a select"); }; - let _mapping = try_extract_column_mapping(*select, &tables) - .unwrap() - .unwrap(); + let _mapping = extract_column_mapping(*select, &tables).unwrap(); } } + fn parse_create_table(create_table_sql: &str) -> Table { + let create_table_statement = sqlite3_parse_into_statement(create_table_sql) + .unwrap() + .unwrap(); + return create_table_statement.try_into().unwrap(); + } + #[test] fn test_view_column_extraction_join() { let sql = "SELECT user, *, a.*, p.user AS foo FROM foo.articles AS a LEFT JOIN bar.profiles AS p ON p.user = a.author"; @@ -1612,84 +1642,28 @@ mod tests { panic!("Not a select"); }; - let tables = vec![ - Table { - name: QualifiedName { - name: "profiles".to_string(), - database_schema: Some("bar".to_string()), - }, - strict: true, - columns: vec![ - Column { - name: "user".to_string(), - data_type: ColumnDataType::Blob, - options: vec![ - ColumnOption::Unique { - is_primary: true, - conflict_clause: None, - }, - ColumnOption::ForeignKey { - foreign_table: "_user".to_string(), - referred_columns: vec!["id".to_string()], - on_delete: None, - on_update: None, - }, - ], - }, - Column { - name: "username".to_string(), - data_type: ColumnDataType::Text, - options: vec![], - }, - ], - foreign_keys: vec![], - unique: vec![], - checks: vec![], - virtual_table: false, - temporary: false, - }, - Table { - name: QualifiedName { - name: "articles".to_string(), - database_schema: Some("foo".to_string()), - }, - strict: true, - columns: vec![ - Column { - name: "id".to_string(), - data_type: ColumnDataType::Blob, - options: vec![ColumnOption::Unique { - is_primary: true, - conflict_clause: None, - }], - }, - Column { - name: "author".to_string(), - data_type: ColumnDataType::Blob, - options: vec![ColumnOption::ForeignKey { - foreign_table: "_user".to_string(), - referred_columns: vec!["id".to_string()], - on_delete: None, - on_update: None, - }], - }, - Column { - name: "body".to_string(), - data_type: ColumnDataType::Text, - options: vec![], - }, - ], - foreign_keys: vec![], - unique: vec![], - checks: vec![], - virtual_table: false, - temporary: false, - }, - ]; + let profiles_table = parse_create_table( + r#" + CREATE TABLE bar.profiles ( + user BLOB PRIMARY KEY NOT NULL REFERENCES _user(id), + username TEXT NOT NULL + ) STRICT; + "#, + ); - let mapping = try_extract_column_mapping(*select, &tables) - .unwrap() - .unwrap(); + let articles_table = parse_create_table( + r#" + CREATE TABLE foo.articles ( + id BLOB PRIMARY KEY NOT NULL, + author BLOB NOT NULL REFERENCES _user(id), + body TEXT + ) STRICT; + "#, + ); + + let tables = [profiles_table, articles_table]; + + let mapping = extract_column_mapping(*select, &tables).unwrap(); assert_eq!( mapping