Increase the space of applicable VIEWs for Record APIs by treating top-level GROUP BY <single col> expressions as key-defining. #99

This commit is contained in:
Sebastian Jeltsch
2025-07-23 17:10:06 +02:00
parent 94abee6206
commit 6ec0fc3941
21 changed files with 789 additions and 387 deletions
+5 -4
View File
@@ -6,6 +6,7 @@ import 'package:test/test.dart';
import 'package:dio/dio.dart';
const port = 4006;
const address = '127.0.0.1:${port}';
class SimpleStrict {
final String id;
@@ -150,7 +151,7 @@ class Comment {
}
Future<Client> connect() async {
final client = Client('http://127.0.0.1:${port}');
final client = Client('http://${address}');
await client.login('admin@localhost', 'secret');
return client;
}
@@ -171,7 +172,7 @@ Future<Process> initTrailBase() async {
'--',
'--data-dir=${depotPath}',
'run',
'--address=127.0.0.1:${port}',
'--address=${address}',
// We want at least some parallelism to experience isolate-local state.
'--js-runtime-threads=2',
]);
@@ -179,8 +180,8 @@ Future<Process> initTrailBase() async {
final dio = Dio();
for (int i = 0; i < 100; ++i) {
try {
final response = await dio.fetch(
RequestOptions(path: 'http://127.0.0.1:${port}/api/healthcheck'));
final response = await dio
.fetch(RequestOptions(path: 'http://${address}/api/healthcheck'));
if (response.statusCode == 200) {
return process;
}
+10 -7
View File
@@ -1,8 +1,11 @@
# Auto-generated config.Vault textproto
secrets: [{
key: "TRAIL_AUTH_OAUTH_PROVIDERS_OIDC0_CLIENT_SECRET"
value: "invalid_oidc_client_secret"
}, {
key: "TRAIL_AUTH_OAUTH_PROVIDERS_DISCORD_CLIENT_SECRET"
value: "invalid_discord_client_secret"
}]
secrets: [
{
key: "TRAIL_AUTH_OAUTH_PROVIDERS_OIDC0_CLIENT_SECRET"
value: "invalid_oidc_client_secret"
},
{
key: "TRAIL_AUTH_OAUTH_PROVIDERS_DISCORD_CLIENT_SECRET"
value: "invalid_discord_client_secret"
}
]
@@ -1,5 +1,7 @@
# Auto-generated config.Vault textproto
secrets: [{
key: "TRAIL_AUTH_OAUTH_PROVIDERS_DISCORD_CLIENT_SECRET"
value: "invalid_discord_client_secret"
}]
secrets: [
{
key: "TRAIL_AUTH_OAUTH_PROVIDERS_DISCORD_CLIENT_SECRET"
value: "invalid_discord_client_secret"
}
]
@@ -80,7 +80,7 @@ function buildSchema(schemas: ListSchemasResponse): SQLNamespace {
const viewName = view.name.name;
schema[viewName] = {
self: { label: viewName, type: "keyword" },
children: view.columns?.map((c) => c.name) ?? [],
children: view.column_mapping?.columns.map((c) => c.column.name) ?? [],
} satisfies SQLNamespace;
}
@@ -20,6 +20,7 @@ import {
isNotNull,
hiddenTable,
tableType,
getColumns,
ForeignKey,
} from "@/lib/schema";
@@ -66,7 +67,7 @@ function findTargetPortName(
continue;
}
for (const column of tableOrView.columns ?? []) {
for (const column of getColumns(tableOrView) ?? []) {
const unique = getUnique(column.options);
if (unique?.is_primary ?? false) {
return `${foreignKey.foreign_table}-${column.name}`;
@@ -88,7 +89,7 @@ function buildErNode(
};
const name = prettyFormatQualifiedName(tableOrView.name);
const columns = tableOrView.columns ?? [];
const columns = getColumns(tableOrView) ?? [];
const view = tableType(tableOrView) === "view";
const ports: PortMetadata[] = columns.map((column) => {
@@ -62,7 +62,7 @@ import {
} from "@/components/FormFields";
import { createConfigQuery, setConfig } from "@/lib/config";
import { parseSqlExpression } from "@/lib/parse";
import { tableType, getForeignKey } from "@/lib/schema";
import { tableType, getForeignKey, getColumns } from "@/lib/schema";
import { buildDefaultRow } from "@/lib/convert";
import { client } from "@/lib/fetch";
@@ -330,7 +330,7 @@ function getForeignKeyColumns(schema: Table | View): [string, ForeignKey][] {
return true;
}
return (schema.columns ?? [])
return (getColumns(schema) ?? [])
.map(
(c) =>
[c.name, getForeignKey(c.options)] as [string, ForeignKey | undefined],
+53 -15
View File
@@ -196,16 +196,15 @@ export function isJSONColumn(column: Column): boolean {
return false;
}
function columnsSatisfyRecordApiRequirements(
columns: Column[],
all: Table[],
): boolean {
for (const column of columns) {
if (isPrimaryKeyColumn(column)) {
if (column.data_type === "Integer") {
return true;
}
function isSuitableRecordPkColumn(column: Column, all: Table[]): boolean {
if (!isPrimaryKeyColumn(column)) {
return false;
}
switch (column.data_type) {
case "Integer":
return true;
case "Blob": {
if (isUUIDColumn(column)) {
return true;
}
@@ -214,14 +213,14 @@ function columnsSatisfyRecordApiRequirements(
if (foreign_key) {
const foreign_col_name = foreign_key.referred_columns[0];
if (!foreign_col_name) {
continue;
return false;
}
const foreign_table = all.find(
(t) => t.name.name === foreign_key.foreign_table,
);
if (!foreign_table) {
continue;
return false;
}
const foreign_col = foreign_table.columns.find(
@@ -231,6 +230,7 @@ function columnsSatisfyRecordApiRequirements(
return true;
}
}
break;
}
}
@@ -242,7 +242,11 @@ export function tableSatisfiesRecordApiRequirements(
all: Table[],
): boolean {
if (table.strict) {
return columnsSatisfyRecordApiRequirements(table.columns, all);
for (const column of table.columns) {
if (isSuitableRecordPkColumn(column, all)) {
return true;
}
}
}
return false;
}
@@ -251,15 +255,49 @@ export function viewSatisfiesRecordApiRequirements(
view: View,
all: Table[],
): boolean {
const columns = view.columns;
if (columns) {
return columnsSatisfyRecordApiRequirements(columns, all);
const mapping = view.column_mapping;
if (!mapping) {
return false;
}
const groupBy = mapping.group_by;
if (groupBy != null) {
if (isSuitableRecordPkColumn(mapping.columns[groupBy].column, all)) {
return true;
}
}
const RIGHT = 0x10;
const CROSS = 0x02;
const NATURAL = 0x04;
const MASK = RIGHT | CROSS | NATURAL;
for (const joinType of mapping.joins) {
if (joinType & MASK) {
return false;
}
}
for (const column of mapping.columns.map((c) => c.column)) {
if (isSuitableRecordPkColumn(column, all)) {
return true;
}
}
return false;
}
export type TableType = "table" | "virtualTable" | "view";
/**
 * Returns the column schemas of a table or a view.
 *
 * Tables and virtual tables expose their `columns` directly. For views the
 * columns are unwrapped from the entries of `column_mapping`; the result is
 * `undefined` when the view has no column mapping.
 */
export function getColumns(tableOrView: Table | View): undefined | Column[] {
  const kind = tableType(tableOrView);

  if (kind === "view") {
    const mapping = (tableOrView as View).column_mapping;
    return mapping?.columns.map((entry) => entry.column);
  }

  if (kind === "table" || kind === "virtualTable") {
    return (tableOrView as Table).columns;
  }

  return undefined;
}
export function tableType(table: Table | View): TableType {
if ("virtual_table" in table) {
if (table.virtual_table) {
@@ -0,0 +1,12 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
import type { ViewColumn } from "./ViewColumn";
export type ColumnMapping = { columns: Array<ViewColumn>,
/**
* Group by that can be used as a key for record APIs.
*/
group_by: number | null,
/**
* A list of joins.
*/
joins: Array<number>, };
+5 -3
View File
@@ -1,5 +1,5 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
import type { Column } from "./Column";
import type { ColumnMapping } from "./ColumnMapping";
import type { QualifiedName } from "./QualifiedName";
export type View = { name: QualifiedName,
@@ -10,6 +10,8 @@ export type View = { name: QualifiedName,
* functions, ..., which makes them inherently not type safe and therefore their columns not
* well defined.
*
* NOTE: Should all this inference be in ViewMetadata?
* QUESTION: We've been wondering if the inference should live more in ViewMetadata, however
* right now the `View` is heavily used in the UI to e.g. render tables and infer record API
* suitability. It's ok that this is more than just an AST.
*/
columns: Array<Column> | null, query: string, temporary: boolean, };
column_mapping: ColumnMapping | null, query: string, temporary: boolean, };
@@ -0,0 +1,4 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
import type { Column } from "./Column";
export type ViewColumn = { column: Column, parent_name: string | null, };
@@ -1 +1,2 @@
export const PORT: number = 4005;
export const ADDRESS: string = `127.0.0.1:${PORT}`;
@@ -1,8 +1,6 @@
import { expect, test } from "vitest";
import { OAuth2Server } from "oauth2-mock-server";
import { PORT } from "../constants";
const address: string = `http://127.0.0.1:${PORT}`;
import { ADDRESS } from "../constants";
type OpenIdConfig = {
issuer: string;
@@ -38,7 +36,7 @@ test("OIDC", async () => {
userInfoResponse.statusCode = 200;
});
const login = await fetch(`${address}/api/auth/v1/oauth/oidc0/login`, {
const login = await fetch(`http://${ADDRESS}/api/auth/v1/oauth/oidc0/login`, {
redirect: "manual",
});
@@ -4,9 +4,7 @@ import { initClient, urlSafeBase64Encode } from "../../src/index";
import type { Client, Event } from "../../src/index";
import { status } from "http-status";
import { v7 as uuidv7, parse as uuidParse } from "uuid";
const port: number = 4005;
const address: string = `http://127.0.0.1:${port}`;
import { ADDRESS } from "../constants";
const sleep = (ms: number) => new Promise((r) => setTimeout(r, ms));
@@ -33,7 +31,7 @@ type SimpleSubsetView = {
};
async function connect(): Promise<Client> {
const client = initClient(address);
const client = initClient(`http://${ADDRESS}`);
await client.login("admin@localhost", "secret");
return client;
}
@@ -1,8 +1,6 @@
import { expect, test } from "vitest";
import { status } from "http-status";
const port: number = 4005;
const address: string = `http://127.0.0.1:${port}`;
import { ADDRESS } from "../constants";
test("JS runtime", async () => {
const expected = {
@@ -14,16 +12,18 @@ test("JS runtime", async () => {
},
};
const jsonUrl = `${address}/json`;
const jsonUrl = `http://${ADDRESS}/json`;
const json = await (await fetch(jsonUrl)).json();
expect(json).toMatchObject(expected);
const response = await fetch(`${address}/fetch?url=${encodeURI(jsonUrl)}`);
const response = await fetch(
`http://${ADDRESS}/fetch?url=${encodeURI(jsonUrl)}`,
);
expect(await response.json()).toMatchObject(expected);
const errResp = await fetch(`${address}/error`);
const errResp = await fetch(`http://${ADDRESS}/error`);
expect(errResp.status).equals(status.IM_A_TEAPOT);
// Test that the periodic callback was called.
expect((await fetch(`${address}/await`)).status).equals(status.OK);
expect((await fetch(`http://${ADDRESS}/await`)).status).equals(status.OK);
});
@@ -5,7 +5,7 @@ import { cwd } from "node:process";
import { join } from "node:path";
import { execa, type Subprocess } from "execa";
import { PORT } from "./constants";
import { ADDRESS } from "./constants";
const sleep = (ms: number) => new Promise((r) => setTimeout(r, ms));
@@ -28,7 +28,7 @@ async function initTrailBase(): Promise<{ subprocess: Subprocess }> {
cwd: root,
stdout: process.stdout,
stderr: process.stdout,
})`cargo run -- --data-dir client/testfixture --public-url http://127.0.0.1:${PORT} run -a 127.0.0.1:${PORT} --js-runtime-threads 1`;
})`cargo run -- --data-dir client/testfixture --public-url http://${ADDRESS} run -a ${ADDRESS} --js-runtime-threads 1`;
for (let i = 0; i < 100; ++i) {
if ((subprocess.exitCode ?? 0) > 0) {
@@ -36,7 +36,7 @@ async function initTrailBase(): Promise<{ subprocess: Subprocess }> {
}
try {
const response = await fetch(`http://127.0.0.1:${PORT}/api/healthcheck`);
const response = await fetch(`http://${ADDRESS}/api/healthcheck`);
if (response.ok) {
return { subprocess };
}
+3 -3
View File
@@ -264,9 +264,9 @@ pub(crate) fn auth_ui_router() -> Router<AppState> {
.nest_service(
"/_/auth/",
AssetService::<trailbase_assets::AuthAssets>::with_parameters(
// We want as little magic as possible. The only /_/auth/subpath that isn't SSR, is profile, so
// we when hitting /profile or /profile, we want actually want to serve the static
// profile/index.html.
// We want as little magic as possible. The only /_/auth/subpath that isn't SSR, is
// profile, so when hitting /profile or /profile, we actually want to serve
// the static profile/index.html.
Some(Box::new(|path| {
if path == "profile" {
Some(format!("{path}/index.html"))
+1 -1
View File
@@ -608,7 +608,7 @@ mod tests {
id INTEGER PRIMARY KEY,
'index' TEXT NOT NULL DEFAULT '',
nullable INTEGER
);
) STRICT;
INSERT INTO 'table' (id, 'index', nullable) VALUES (1, '1', 1), (2, '2', NULL), (3, '3', NULL);
"#,
)
+1 -1
View File
@@ -104,7 +104,7 @@ impl RecordApiSchema {
};
let record_pk_column = (pk_index, pk_column.clone());
let Some(ref columns) = view_metadata.schema.columns else {
let Some(columns) = view_metadata.columns() else {
return Err("RecordApi requires schema".to_string());
};
let Some(json_metadata) = view_metadata.json_metadata() else {
+4 -3
View File
@@ -3,7 +3,9 @@ use log::*;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::metadata::{JsonColumnMetadata, JsonSchemaError, TableMetadata, extract_json_metadata};
use crate::metadata::{
JsonColumnMetadata, JsonSchemaError, TableMetadata, extract_json_metadata, is_pk_column,
};
use crate::sqlite::{Column, ColumnDataType, ColumnOption};
/// Influences the generated JSON schema. In `Insert` mode columns with default values will be
@@ -120,8 +122,7 @@ pub fn build_json_schema_expanded(
};
let Some(pk_column) = (match referred_columns.len() {
0 => crate::metadata::find_pk_column_index(&table.schema.columns)
.map(|idx| &table.schema.columns[idx]),
0 => table.schema.columns.iter().find(|c| is_pk_column(c)),
1 => table
.schema
.columns
+339 -136
View File
@@ -2,13 +2,16 @@ use jsonschema::Validator;
use lazy_static::lazy_static;
use log::*;
use regex::Regex;
use sqlite3_parser::ast::JoinType;
use std::borrow::Borrow;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use thiserror::Error;
use crate::sqlite::{Column, ColumnDataType, ColumnOption, QualifiedName, Table, View};
use crate::sqlite::{
Column, ColumnDataType, ColumnMapping, ColumnOption, QualifiedName, Table, View,
};
// TODO: Can we merge this with crate::sqlite::SchemaError?
#[derive(Debug, Clone, Error)]
@@ -76,18 +79,12 @@ impl JsonMetadata {
return Self::from_columns(&table.columns);
}
fn from_view(view: &View) -> Option<Self> {
return view.columns.as_ref().map(|cols| Self::from_columns(cols));
}
fn from_columns(columns: &[Column]) -> Self {
let columns: Vec<_> = columns.iter().map(build_json_metadata).collect();
let file_column_indexes = find_file_column_indexes(&columns);
return Self {
file_column_indexes: find_file_column_indexes(&columns),
columns,
file_column_indexes,
};
}
}
@@ -124,7 +121,7 @@ impl TableMetadata {
.map(|(index, col)| (col.name.clone(), index)),
);
let record_pk_column = find_record_pk_column_index(&table.columns, tables);
let record_pk_column = find_record_pk_column_index_for_table(&table, tables);
let user_id_columns = find_user_id_foreign_key_columns(&table.columns, user_table_name);
let json_metadata = JsonMetadata::from_table(&table);
@@ -193,6 +190,10 @@ impl Borrow<QualifiedName> for Arc<TableMetadata> {
pub struct ViewMetadata {
pub schema: View,
// QUESTION: Why do we have a copy of the columns here? Right now it duplicates `.schema`.
// This probably only exists because we have a trait impl that returns Option<&[Column]>.
columns: Option<Vec<Column>>,
name_to_index: HashMap<String, usize>,
record_pk_column: Option<usize>,
json_metadata: Option<JsonMetadata>,
@@ -204,28 +205,36 @@ impl ViewMetadata {
/// NOTE: The list of all tables is needed only to extract integer/UUIDv7 pk columns for foreign
/// key relationships.
pub fn new(view: View, tables: &[Table]) -> Self {
let name_to_index = if let Some(ref columns) = view.columns {
HashMap::<String, usize>::from_iter(
columns
return match view.column_mapping {
Some(ref column_mapping) => {
let columns: Vec<Column> = column_mapping
.columns
.iter()
.enumerate()
.map(|(index, col)| (col.name.clone(), index)),
)
} else {
HashMap::<String, usize>::new()
};
.map(|m| m.column.clone())
.collect();
let record_pk_column = view
.columns
.as_ref()
.and_then(|c| find_record_pk_column_index(c, tables));
let json_metadata = JsonMetadata::from_view(&view);
let name_to_index = HashMap::<String, usize>::from_iter(
columns
.iter()
.enumerate()
.map(|(index, col)| (col.name.clone(), index)),
);
return ViewMetadata {
schema: view,
name_to_index,
record_pk_column,
json_metadata,
ViewMetadata {
name_to_index,
json_metadata: Some(JsonMetadata::from_columns(&columns)),
columns: Some(columns),
record_pk_column: find_record_pk_column_index_for_view(column_mapping, tables),
schema: view,
}
}
None => ViewMetadata {
name_to_index: HashMap::<String, usize>::default(),
columns: None,
record_pk_column: None,
json_metadata: None,
schema: view,
},
};
}
@@ -234,6 +243,11 @@ impl ViewMetadata {
&self.schema.name
}
#[inline]
pub fn columns(&self) -> Option<&[Column]> {
return self.columns.as_deref();
}
#[inline]
pub fn column_index_by_name(&self, key: &str) -> Option<usize> {
self.name_to_index.get(key).copied()
@@ -242,8 +256,8 @@ impl ViewMetadata {
#[inline]
pub fn column_by_name(&self, key: &str) -> Option<(usize, &Column)> {
let index = self.column_index_by_name(key)?;
let cols = self.schema.columns.as_ref()?;
return Some((index, &cols[index]));
let mapping = self.schema.column_mapping.as_ref()?;
return Some((index, &mapping.columns[index].column));
}
}
@@ -313,7 +327,7 @@ impl TableOrViewMetadata for ViewMetadata {
}
fn columns(&self) -> Option<&[Column]> {
return self.schema.columns.as_deref();
return self.columns.as_deref();
}
fn json_metadata(&self) -> Option<&JsonMetadata> {
@@ -321,7 +335,7 @@ impl TableOrViewMetadata for ViewMetadata {
}
fn record_pk_column(&self) -> Option<(usize, &Column)> {
let Some(columns) = &self.schema.columns else {
let Some(columns) = &self.columns else {
return None;
};
let index = self.record_pk_column?;
@@ -427,85 +441,122 @@ pub fn find_user_id_foreign_key_columns(columns: &[Column], user_table_name: &st
return indexes;
}
pub(crate) fn find_pk_column_index(columns: &[Column]) -> Option<usize> {
return columns.iter().position(|col| {
for opt in &col.options {
if let ColumnOption::Unique { is_primary, .. } = opt {
return *is_primary;
}
pub(crate) fn is_pk_column(column: &Column) -> bool {
for opt in &column.options {
if let ColumnOption::Unique { is_primary, .. } = opt {
return *is_primary;
}
}
return false;
}
fn is_suitable_record_pk_column(column: &Column, tables: &[Table]) -> bool {
if !is_pk_column(column) {
return false;
});
}
return match column.data_type {
ColumnDataType::Integer => {
// TODO: We should detect the "integer pk" desc case and at least warn:
// https://www.sqlite.org/lang_createtable.html#rowid.
true
}
ColumnDataType::Blob => {
lazy_static! {
static ref UUID_CHECK_RE: Regex =
Regex::new(r"^is_uuid(|_v7|_v4)\s*\(").expect("infallible");
}
for opts in &column.options {
match opts {
// Check the column itself is a UUID column.
ColumnOption::Check(expr) if UUID_CHECK_RE.is_match(expr) => return true,
// Or that a referenced column is a UUID column.
ColumnOption::ForeignKey {
foreign_table,
referred_columns,
..
} => {
let referred_column = {
if referred_columns.len() != 1 {
return false;
}
&referred_columns[0]
};
// NOTE: Foreign keys cannot cross database boundaries, we can therefore compare by
// unqualified name.
let Some(referred_table) = tables.iter().find(|t| t.name.name == *foreign_table) else {
warn!("Failed to get foreign key schema for {foreign_table}");
return false;
};
let Some(foreign_column) = referred_table
.columns
.iter()
.find(|c| c.name == *referred_column)
else {
return false;
};
for opt in &foreign_column.options {
match opt {
ColumnOption::Check(expr) if UUID_CHECK_RE.is_match(expr) => return true,
_ => {}
}
}
}
_ => {}
}
}
false
}
_ => false,
};
}
/// Finds suitable Integer or UUIDv7/UUIDv4 primary key columns, if present.
///
/// Cursors require certain properties like a stable, time-sortable primary key.
fn find_record_pk_column_index(columns: &[Column], tables: &[Table]) -> Option<usize> {
let index = find_pk_column_index(columns)?;
let column = &columns[index];
if column.data_type == ColumnDataType::Integer {
// TODO: We should detect the "integer pk" desc case and at least warn:
// https://www.sqlite.org/lang_createtable.html#rowid.
return Some(index);
}
for opts in &column.options {
lazy_static! {
static ref UUID_CHECK_RE: Regex = Regex::new(r"^is_uuid(|_v7|_v4)\s*\(").expect("infallible");
}
match &opts {
// Check if the referenced column is a uuidv7 column.
ColumnOption::ForeignKey {
foreign_table,
referred_columns,
..
} => {
// NOTE: Foreign keys cannot cross database boundaries, we can therefore compare by
// unqualified name.
let Some(referred_table) = tables.iter().find(|t| t.name.name == *foreign_table) else {
error!("Failed to get foreign key schema for {foreign_table}");
continue;
};
if referred_columns.len() != 1 {
return None;
}
let referred_column = &referred_columns[0];
let col = referred_table
.columns
.iter()
.find(|c| c.name == *referred_column)?;
let mut is_pk = false;
for opt in &col.options {
match opt {
ColumnOption::Check(expr) if UUID_CHECK_RE.is_match(expr) => {
return Some(index);
}
ColumnOption::Unique { is_primary, .. } if *is_primary => {
is_pk = true;
}
_ => {}
}
}
if is_pk && col.data_type == ColumnDataType::Integer {
return Some(index);
}
return None;
}
ColumnOption::Check(expr) if UUID_CHECK_RE.is_match(expr) => {
fn find_record_pk_column_index_for_table(table: &Table, tables: &[Table]) -> Option<usize> {
if table.strict {
for (index, column) in table.columns.iter().enumerate() {
if is_suitable_record_pk_column(column, tables) {
return Some(index);
}
_ => {}
}
}
return None;
}
/// Finds the index of a view column that can serve as the record primary key, if any.
///
/// When the mapping carries a `group_by` index, only that column is considered: its index is
/// returned if the column is a suitable record PK, `None` otherwise. Without a `group_by`, any
/// join with the RIGHT, CROSS or NATURAL bits set disqualifies the view; otherwise the index of
/// the first suitable column is returned.
fn find_record_pk_column_index_for_view(
  column_mapping: &ColumnMapping,
  tables: &[Table],
) -> Option<usize> {
  let suitable_at =
    |index: usize| is_suitable_record_pk_column(&column_mapping.columns[index].column, tables);

  // An explicit GROUP BY is decisive: either its column qualifies or nothing does.
  if let Some(group_by_index) = column_mapping.group_by {
    return suitable_at(group_by_index).then_some(group_by_index);
  }

  // NOTE: We could be smarter here. It's quite tricky to say with a set of arbitrary joins, which
  // (integer, UUID) columns end up being unique afterwards. Rely on explicit GROUP BY instead.
  let unsupported = (JoinType::RIGHT | JoinType::CROSS | JoinType::NATURAL).bits();
  if let Some(join_type) = column_mapping
    .joins
    .iter()
    .find(|join_type| (*join_type & unsupported) != 0)
  {
    warn!("Only LEFT and INNER JOINS supported yet, got: {join_type:?}");
    return None;
  }

  return (0..column_mapping.columns.len()).find(|&index| suitable_at(index));
}
@@ -515,16 +566,55 @@ mod tests {
use super::*;
use crate::parse::parse_into_statement;
use crate::sqlite::Table;
use crate::sqlite::{SchemaError, Table};
fn parse_create_table(create_table_sql: &str) -> Table {
let create_table_statement = parse_into_statement(create_table_sql).unwrap().unwrap();
return create_table_statement.try_into().unwrap();
}
fn parse_create_view(create_view_sql: &str, tables: &[Table]) -> View {
fn parse_create_view(create_view_sql: &str, tables: &[Table]) -> Result<View, SchemaError> {
let create_view_statement = parse_into_statement(create_view_sql).unwrap().unwrap();
return View::from(create_view_statement, tables).unwrap();
return View::from(create_view_statement, tables);
}
#[test]
fn test_find_record_pk_column_index_for_table() {
let table = parse_create_table("CREATE TABLE t (id INTEGER PRIMARY KEY) STRICT");
let tables = [table.clone()];
assert_eq!(
Some(0),
find_record_pk_column_index_for_table(&table, &tables)
);
let table = parse_create_table(
r#"
CREATE TABLE t (
value TEXT,
id BLOB PRIMARY KEY NOT NULL CHECK(is_uuid(id))
) STRICT;
"#,
);
let tables = [table.clone()];
assert_eq!(
Some(1),
find_record_pk_column_index_for_table(&table, &tables)
);
let non_strict_table = parse_create_table(
r#"
CREATE TABLE t (
value TEXT,
id BLOB PRIMARY KEY NOT NULL CHECK(is_uuid(id))
);
"#,
);
let tables = [non_strict_table.clone()];
assert_eq!(
None,
find_record_pk_column_index_for_table(&non_strict_table, &tables)
);
}
#[test]
@@ -551,19 +641,20 @@ mod tests {
let table_view = parse_create_view(
"CREATE VIEW view0 AS SELECT col0, col1 FROM table0",
&tables,
);
)
.unwrap();
assert_eq!(table_view.name.name, "view0");
assert_eq!(table_view.query, "SELECT col0, col1 FROM table0");
assert_eq!(table_view.temporary, false);
let view_columns = table_view.columns.as_ref().unwrap();
let view_columns = &table_view.column_mapping.as_ref().unwrap().columns;
assert_eq!(view_columns.len(), 2);
assert_eq!(view_columns[0].name, "col0");
assert_eq!(view_columns[0].data_type, ColumnDataType::Text);
assert_eq!(view_columns[0].column.name, "col0");
assert_eq!(view_columns[0].column.data_type, ColumnDataType::Text);
assert_eq!(view_columns[1].name, "col1");
assert_eq!(view_columns[1].data_type, ColumnDataType::Blob);
assert_eq!(view_columns[1].column.name, "col1");
assert_eq!(view_columns[1].column.data_type, ColumnDataType::Blob);
let view_metadata = ViewMetadata::new(table_view, &tables);
@@ -573,7 +664,8 @@ mod tests {
{
let query = "SELECT id, col0, col1 FROM table0";
let table_view = parse_create_view(&format!("CREATE VIEW view0 AS {query}"), &tables);
let table_view =
parse_create_view(&format!("CREATE VIEW view0 AS {query}"), &tables).unwrap();
assert_eq!(table_view.name.name, "view0");
assert_eq!(table_view.query, query);
@@ -600,15 +692,16 @@ mod tests {
let view = parse_create_view(
"CREATE VIEW view0 AS SELECT * FROM (SELECT * FROM a);",
&tables,
);
let view_columns = view.columns.as_ref().unwrap();
)
.unwrap();
let view_columns = &view.column_mapping.as_ref().unwrap().columns;
assert_eq!(view_columns.len(), 2);
assert_eq!(view_columns[0].name, "id");
assert_eq!(view_columns[0].data_type, ColumnDataType::Integer);
assert_eq!(view_columns[0].column.name, "id");
assert_eq!(view_columns[0].column.data_type, ColumnDataType::Integer);
assert_eq!(view_columns[1].name, "data");
assert_eq!(view_columns[1].data_type, ColumnDataType::Text);
assert_eq!(view_columns[1].column.name, "data");
assert_eq!(view_columns[1].column.data_type, ColumnDataType::Text);
let metadata = ViewMetadata::new(view, &tables);
let (pk_index, pk_col) = metadata.record_pk_column().unwrap();
@@ -620,11 +713,12 @@ mod tests {
let view = parse_create_view(
"CREATE VIEW view0 AS SELECT id FROM (SELECT * FROM a);",
&tables,
);
let view_columns = view.columns.as_ref().unwrap();
)
.unwrap();
let view_columns = &view.column_mapping.as_ref().unwrap().columns;
assert_eq!(view_columns.len(), 1);
assert_eq!(view_columns[0].name, "id");
assert_eq!(view_columns[0].data_type, ColumnDataType::Integer);
assert_eq!(view_columns[0].column.name, "id");
assert_eq!(view_columns[0].column.data_type, ColumnDataType::Integer);
let metadata = ViewMetadata::new(view, &tables);
let (pk_index, pk_col) = metadata.record_pk_column().unwrap();
@@ -636,11 +730,12 @@ mod tests {
let view = parse_create_view(
"CREATE VIEW view0 AS SELECT x.id FROM (SELECT * FROM a) AS x;",
&tables,
);
let view_columns = view.columns.as_ref().unwrap();
)
.unwrap();
let view_columns = &view.column_mapping.as_ref().unwrap().columns;
assert_eq!(view_columns.len(), 1);
assert_eq!(view_columns[0].name, "id");
assert_eq!(view_columns[0].data_type, ColumnDataType::Integer);
assert_eq!(view_columns[0].column.name, "id");
assert_eq!(view_columns[0].column.data_type, ColumnDataType::Integer);
let metadata = ViewMetadata::new(view, &tables);
let (pk_index, pk_col) = metadata.record_pk_column().unwrap();
@@ -653,8 +748,8 @@ mod tests {
let view = parse_create_view(
"CREATE VIEW view0 AS SELECT x.id, y.id FROM (SELECT * FROM a) AS x, (SELECT * FROM a) AS y;",
&tables,
);
assert_eq!(view.columns, None);
).unwrap();
assert!(view.column_mapping.is_none());
}
}
@@ -678,24 +773,131 @@ mod tests {
let view = parse_create_view(
"CREATE VIEW view0 AS SELECT a.data, b.fk, a.id FROM a AS a LEFT JOIN b AS b ON a.id = b.fk;",
&tables,
);
let view_columns = view.columns.as_ref().unwrap();
).unwrap();
let view_columns = &view.column_mapping.as_ref().unwrap().columns;
assert_eq!(view_columns.len(), 3);
assert_eq!(view_columns[2].name, "id");
assert_eq!(view_columns[2].data_type, ColumnDataType::Integer);
assert_eq!(view_columns[2].column.name, "id");
assert_eq!(view_columns[2].column.data_type, ColumnDataType::Integer);
assert_eq!(view_columns[0].name, "data");
assert_eq!(view_columns[0].data_type, ColumnDataType::Text);
assert_eq!(view_columns[0].column.name, "data");
assert_eq!(view_columns[0].column.data_type, ColumnDataType::Text);
assert_eq!(view_columns[1].name, "fk");
assert_eq!(view_columns[1].data_type, ColumnDataType::Integer);
assert_eq!(view_columns[1].column.name, "fk");
assert_eq!(view_columns[1].column.data_type, ColumnDataType::Integer);
let metadata = ViewMetadata::new(view, &tables);
let (pk_index, pk_col) = metadata.record_pk_column().unwrap();
assert_eq!(pk_index, 2);
assert_eq!(pk_col.name, "id");
}
{
// JOINs
for (join_type, expected) in [
("LEFT", Some(1)),
("INNER", Some(1)),
("RIGHT", None),
("CROSS", None),
] {
let view = parse_create_view(
&format!(
"CREATE VIEW view0 AS SELECT a.data, a.id FROM a AS a {join_type} JOIN b AS b ON a.id = b.fk;"
),
&tables,
)
.unwrap();
let metadata = ViewMetadata::new(view, &tables);
assert_eq!(
expected,
metadata.record_pk_column().map(|c| c.0),
"{join_type}"
);
}
}
}
#[test]
fn test_parse_create_view_with_group_by() {
let table_a = parse_create_table(
"CREATE TABLE a (id INTEGER PRIMARY KEY, data TEXT NOT NULL DEFAULT '') STRICT",
);
let table_b = parse_create_table(
r#"
CREATE TABLE b (
id INTEGER PRIMARY KEY,
fk INTEGER NOT NULL REFERENCES a(id)
) STRICT"#,
);
let tables = [table_a, table_b];
{
// A RIGHT JOIN on its own is not suitable for record APIs: its cross-product-like nature
// spoils PKs. An explicit GROUP BY on the PK column makes the view usable again.
{
for (i, sql) in [
"CREATE VIEW v AS SELECT data, a.id AS z FROM a RIGHT JOIN b ON a.id = b.id GROUP BY z;",
"CREATE VIEW v AS SELECT data, x.id AS z FROM a AS x RIGHT JOIN b ON x.id = b.id GROUP BY z;",
"CREATE VIEW v AS SELECT data, x.id AS z FROM a x RIGHT JOIN b ON x.id = b.id GROUP BY z;",
].iter().enumerate() {
let view = parse_create_view(sql, &tables).unwrap();
assert!(view.column_mapping.is_some(), "{i}: {sql}");
let metadata = ViewMetadata::new(view, &tables);
assert_eq!(Some(1), metadata.record_pk_column().map(|c| c.0));
}
}
{
let view = parse_create_view(
"CREATE VIEW v AS SELECT a.data, a.id FROM a RIGHT JOIN b ON a.id = b.id GROUP BY a.id;",
&tables,
)
.unwrap();
let metadata = ViewMetadata::new(view, &tables);
assert_eq!(Some(1), metadata.record_pk_column().map(|c| c.0));
}
}
}
#[test]
fn test_parse_create_view_from_issue_99() {
  // Regression test for the view shape named in the test title (issue #99): `authors` joined
  // with `posts` and grouped by the authors' integer primary key.
  let authors_table = parse_create_table(
    "
      CREATE TABLE authors (
        id INTEGER PRIMARY KEY,
        name TEXT NOT NULL,
        age INTEGER DEFAULT NULL
      ) STRICT;
    ",
  );

  // The foreign key points at `authors`, the only other table in this fixture. It previously
  // referenced a `persons` table that is not defined here.
  let posts_table = parse_create_table(
    "
      CREATE TABLE posts (
        id INTEGER PRIMARY KEY,
        author INTEGER DEFAULT NULL REFERENCES authors(id),
        title TEXT NOT NULL
      ) STRICT;
    ",
  );

  let tables = [authors_table, posts_table];

  let view = parse_create_view(
    "
      CREATE VIEW authors_view_posts AS
        SELECT authors.* FROM authors authors
          INNER JOIN posts posts ON posts.author = authors.id
        GROUP BY authors.id;
    ",
    &tables,
  )
  .unwrap();

  // `authors.*` expands with `id` first, so the GROUP BY column is expected at index 0.
  let metadata = ViewMetadata::new(view, &tables);
  assert_eq!(Some(0), metadata.record_pk_column().map(|c| c.0));
}
#[test]
@@ -726,15 +928,16 @@ mod tests {
name: "view_name".to_string(),
database_schema: Some("main".to_string()),
};
let table_view = parse_create_view(
let view = parse_create_view(
&format!(
"CREATE VIEW {view_name} AS SELECT id FROM {table_name}",
view_name = view_name.escaped_string(),
table_name = table_name.escaped_string()
),
&tables,
);
let view_metadata = Arc::new(ViewMetadata::new(table_view, &[table.clone()]));
)
.unwrap();
let view_metadata = Arc::new(ViewMetadata::new(view, &[table.clone()]));
let mut view_set = HashSet::<Arc<ViewMetadata>>::new();
+325 -187
View File
@@ -851,7 +851,7 @@ impl std::fmt::Display for SelectFormatter {
}
}
#[derive(Clone, Debug, Serialize, Deserialize, TS, PartialEq)]
#[derive(Clone, Debug, Serialize, Deserialize, TS)]
pub struct View {
pub name: QualifiedName,
@@ -861,8 +861,10 @@ pub struct View {
/// functions, ..., which makes them inherently not type safe and therefore their columns not
/// well defined.
///
/// NOTE: Should all this inference be in ViewMetadata?
pub columns: Option<Vec<Column>>,
/// QUESTION: We've been wondering if the inference should live more in ViewMetadata, however
/// right now the `View` is heavily used in the UI to e.g. render tables and infer record API
/// suitability. It's ok that this is more than just an AST.
pub(crate) column_mapping: Option<ColumnMapping>,
pub query: String,
@@ -887,7 +889,7 @@ impl View {
));
};
let column_mapping: Option<Vec<ColumnMapping>> = if columns.is_some() {
let column_mapping: Option<ColumnMapping> = if columns.is_some() {
// Example, `CREATE VIEW view0(alias0, alias1) AS SELECT * FROM table0;`
//
// We probably never want to support this due to its late failure mode,
@@ -913,7 +915,7 @@ impl View {
return Ok(View {
name: view_name.into(),
columns: column_mapping.map(|o| o.into_iter().map(|m| m.column).collect()),
column_mapping,
query: SelectFormatter(*select).to_string(),
temporary,
if_not_exists,
@@ -921,141 +923,140 @@ impl View {
}
}
#[derive(Clone, Debug)]
#[allow(unused)]
struct ReferredColumn {
table_name: QualifiedName,
column_name: String,
#[derive(Clone, Debug, Deserialize, Serialize, TS)]
pub(crate) struct ViewColumn {
// e.g. "foo" for CREATE VIEW v AS SELECT foo.bar AS baz FROM ...
// #[allow(unused)]
// pub(crate) qualifier: Option<String>,
//
// // e.g. "baz" for CREATE VIEW v AS SELECT foo.bar AS baz FROM ...
// pub(crate) alias: Option<String>,
// The inferred column schema, either via a cast from a computed column or the underlying table
// column if inferable.
// NOTE: It would be cleaner to separate (Table)`Column` from `ViewColumn`, just pulling in
// the contents here. However, the UI currently depends on `Column`.
pub(crate) column: Column,
// Would be "foo" for CREATE VIEW v AS SELECT foo.bar FROM foo;
pub(crate) parent_name: Option<String>,
}
#[derive(Clone, Debug)]
struct ColumnMapping {
column: Column,
#[derive(Clone, Debug, Serialize, Deserialize, TS)]
pub(crate) struct ColumnMapping {
pub(crate) columns: Vec<ViewColumn>,
#[allow(unused)]
referred_column: Option<ReferredColumn>,
/// Group by that can be used as a key for record APIs.
pub(crate) group_by: Option<usize>,
/// A list of joins.
pub(crate) joins: Vec<u8>,
}
fn extract_column_mapping(
select: sqlite3_parser::ast::Select,
tables: &[Table],
) -> Result<Vec<ColumnMapping>, SchemaError> {
) -> Result<ColumnMapping, SchemaError> {
let result_columns = extract_result_columns(&select)?;
let referenced_table_by_alias = extract_referenced_tables_by_alias(select, tables)?;
let group_by_key_candidate = extract_group_by_key_candidate(&select)?;
let (joins, referenced_table_by_alias) =
extract_joins_and_referenced_tables_by_alias(select, tables)?;
// SQLite checks comprehensively and will return an `ambiguous column name: <col>` error at
// query time (as opposed to VIEW-creation-time).
let find_column_by_unqualified_name = |col_name: &str| -> Result<(&Table, &Column), SchemaError> {
// Search tables in index order.
let mut found: Option<(&Table, &Column)> = None;
for (_alias, table) in &referenced_table_by_alias {
for col in &table.columns {
if col.name == col_name {
if found.is_some() {
return Err(precondition(&format!("Ambibuous column: {col_name}")));
let find_column_by_unqualified_name =
|col_name: &str| -> Result<(&ReferredTable, usize), SchemaError> {
// Search tables in index order.
let mut found: Option<(&ReferredTable, usize)> = None;
for referred_table in &referenced_table_by_alias {
for (i, col) in referred_table.table.columns.iter().enumerate() {
if col.name == col_name {
if found.is_some() {
return Err(precondition(&format!("Ambiguous column: {col_name}")));
}
found = Some((referred_table, i));
}
found = Some((table, col));
}
}
}
return found.ok_or(precondition(&format!("Column '{col_name}' not found")));
};
return found.ok_or(precondition(&format!("Column '{col_name}' not found")));
};
let find_table_by_alias = |a: &str| -> Result<&Table, SchemaError> {
for (alias, table) in &referenced_table_by_alias {
if alias.as_deref() == Some(a) {
return Ok(table);
let find_table_by_alias = |a: &str| -> Result<&ReferredTable, SchemaError> {
for referred_table in &referenced_table_by_alias {
if referred_table.alias.as_deref() == Some(a) {
return Ok(referred_table);
}
if referred_table.table.name.name == a {
return Ok(referred_table);
}
}
return Err(precondition(&format!("No table found for '{a}'")));
};
let mut mapping: Vec<ColumnMapping> = vec![];
let mut mapping: Vec<ViewColumn> = vec![];
for col in result_columns {
use sqlite3_parser::ast::Expr;
match col {
ResultColumn::Star => {
for (_alias, table) in &referenced_table_by_alias {
for c in &table.columns {
mapping.push(ColumnMapping {
for referred_table in &referenced_table_by_alias {
for c in &referred_table.table.columns {
mapping.push(ViewColumn {
column: c.clone(),
referred_column: Some(ReferredColumn {
table_name: table.name.clone(),
column_name: c.name.clone(),
}),
parent_name: get_parent_name(referred_table),
});
}
}
}
ResultColumn::TableStar(name) => {
let name = unquote_name(name);
let table = find_table_by_alias(&name)?;
let referred_table = find_table_by_alias(&name)?;
for c in &table.columns {
mapping.push(ColumnMapping {
for c in &referred_table.table.columns {
mapping.push(ViewColumn {
column: c.clone(),
referred_column: Some(ReferredColumn {
table_name: table.name.clone(),
column_name: c.name.clone(),
}),
parent_name: get_parent_name(referred_table),
});
}
}
ResultColumn::Expr(expr, alias) => match expr {
Expr::Id(id) => {
let col_name = unquote_id(id.clone());
let (table, column) = find_column_by_unqualified_name(&col_name)?;
let (referred_table, column_index) = find_column_by_unqualified_name(&col_name)?;
let column = &referred_table.table.columns[column_index];
let name = alias
.and_then(|alias| {
if let sqlite3_parser::ast::As::As(name) = alias {
return Some(unquote_name(name));
}
None
})
.unwrap_or_else(|| column.name.clone());
mapping.push(ColumnMapping {
mapping.push(ViewColumn {
column: Column {
name,
name: to_alias(alias).unwrap_or_else(|| column.name.clone()),
data_type: column.data_type,
options: column.options.clone(),
},
referred_column: Some(ReferredColumn {
table_name: table.name.clone(),
column_name: column.name.clone(),
}),
parent_name: get_parent_name(referred_table),
});
}
Expr::Qualified(qualifier, name) => {
let table = find_table_by_alias(&unquote_name(qualifier))?;
let qualifier = unquote_name(qualifier);
let referred_table = find_table_by_alias(&qualifier)?;
let col_name = unquote_name(name);
let Some(column) = table.columns.iter().find(|c| c.name == col_name) else {
let Some(column_index) = referred_table
.table
.columns
.iter()
.position(|c| c.name == col_name)
else {
return Err(precondition(&format!("Missing col: {col_name}")));
};
let name = alias
.and_then(|alias| {
if let sqlite3_parser::ast::As::As(name) = alias {
return Some(unquote_name(name));
}
None
})
.unwrap_or_else(|| column.name.clone());
let column = &referred_table.table.columns[column_index];
mapping.push(ColumnMapping {
mapping.push(ViewColumn {
column: Column {
name,
name: to_alias(alias).unwrap_or_else(|| column.name.clone()),
data_type: column.data_type,
options: column.options.clone(),
},
referred_column: Some(ReferredColumn {
table_name: table.name.clone(),
column_name: column.name.clone(),
}),
parent_name: get_parent_name(referred_table),
});
}
Expr::Cast { expr: _, type_name } => {
@@ -1070,22 +1071,17 @@ fn extract_column_mapping(
));
};
let Some(name) = alias.and_then(|alias| {
if let sqlite3_parser::ast::As::As(name) = alias {
return Some(unquote_name(name));
}
None
}) else {
let Some(name) = to_alias(alias) else {
return Err(SchemaError::Precondition("Missing alias in cast".into()));
};
mapping.push(ColumnMapping {
mapping.push(ViewColumn {
column: Column {
name,
data_type,
options: vec![ColumnOption::Null],
options: vec![],
},
referred_column: None,
parent_name: None,
});
}
x => {
@@ -1096,27 +1092,61 @@ fn extract_column_mapping(
};
}
return Ok(mapping);
}
return match group_by_key_candidate {
None => Ok(ColumnMapping {
columns: mapping,
group_by: None,
joins,
}),
Some(group_by) => {
// NOTE: GROUP BY can technically reference any column, but only columns also exposed by the
// VIEW are useful to us as keys. In other words, there's no point of us to search for this
// column through all referenced tables.
let group_by = match group_by.qualifier {
Some(ref qualifier) => {
// If the "GROUP BY" uses a qualifier, it must reference a table or subselect, e.g.:
// CREATE VIEW v AS SELECT a.id FROM a RIGHT JOIN b ON a.id = b.id GROUP BY a.id;
mapping.iter().position(|v: &ViewColumn| {
v.parent_name.as_ref() == Some(qualifier) && v.column.name == group_by.name
})
}
None => mapping
.iter()
.position(|v: &ViewColumn| v.column.name == group_by.name),
};
#[inline]
fn precondition(m: &str) -> SchemaError {
return SchemaError::Precondition(m.into());
}
fn extract_result_columns(
select: &sqlite3_parser::ast::Select,
) -> Result<Vec<ResultColumn>, SchemaError> {
let sqlite3_parser::ast::OneSelect::Select { columns, .. } = &select.body.select else {
return Err(precondition("VALUES not supported"));
Ok(ColumnMapping {
columns: mapping,
group_by: Some(group_by.ok_or_else(|| precondition("GROUP BY column not exposed"))?),
joins,
})
}
};
return Ok(columns.clone());
}
fn extract_referenced_tables_by_alias(
/// Returns the name by which columns of `referred_table` are qualified: its alias if one is set,
/// otherwise the table's own (unqualified) name. Currently always yields `Some`.
fn get_parent_name(referred_table: &ReferredTable) -> Option<String> {
  let parent = match &referred_table.alias {
    Some(alias) => alias,
    None => &referred_table.table.name.name,
  };
  Some(parent.to_owned())
}
/// A table referenced by a SELECT (through its FROM clause or a JOIN), paired with the alias
/// under which it is visible to the enclosing query.
#[derive(Clone, Debug)]
pub(crate) struct ReferredTable {
/// Optional top-most alias (nested aliases, e.g. in a sub-query, are not accessible).
pub(crate) alias: Option<String>,
/// The referenced table.
pub(crate) table: Table,
}
fn extract_joins_and_referenced_tables_by_alias(
select: sqlite3_parser::ast::Select,
tables: &[Table],
) -> Result<Vec<(Option<String>, &Table)>, SchemaError> {
) -> Result<(Vec<u8>, Vec<ReferredTable>), SchemaError> {
let body = select.body;
if body.compounds.is_some() {
return Err(precondition("Compound queries not supported"));
@@ -1126,7 +1156,7 @@ fn extract_referenced_tables_by_alias(
columns: _,
distinctness,
from,
group_by,
group_by: _,
having: _,
where_clause: _,
window_clause,
@@ -1138,10 +1168,6 @@ fn extract_referenced_tables_by_alias(
)));
};
if group_by.is_some() {
return Err(precondition("GROUP BY clause not (yet) supported"));
}
if distinctness.is_some() {
return Err(precondition("DISTINCT clause not (yet) supported"));
}
@@ -1168,7 +1194,7 @@ fn extract_referenced_tables_by_alias(
// Make sure all referenced tables are strict.
if !table.strict {
return Err(precondition(&format!(
"Referenced table: {:?} must be STRICT",
"Table {:?} must be STRICT to derive type",
table.name
)));
}
@@ -1176,42 +1202,76 @@ fn extract_referenced_tables_by_alias(
return Ok(table);
};
// Map from "alias" to table. Use IndexMap to preserve insertion order.
let referenced_table_by_alias: Vec<(Option<String>, &Table)> = match nested_select.map(|s| *s) {
let mut all_joins: Vec<u8> = joins
.as_ref()
.map(|joins| {
joins
.iter()
.map(|join| extract_join_type(join.operator).bits())
.collect()
})
.unwrap_or_default();
// List of referenced tables in insertion order (left-to-right).
let referenced_table_by_alias: Vec<ReferredTable> = match nested_select.map(|s| *s) {
Some(SelectTable::Table(fqn, alias, _indexed)) => {
let mut referenced_tables = vec![(to_alias(alias), find_table(&fqn.into())?)];
// Table itself
let mut referenced_tables = vec![ReferredTable {
alias: to_alias(alias),
table: find_table(&fqn.into())?.clone(),
}];
// Plus possible joins.
for join in joins.unwrap_or_default() {
match join.operator {
JoinOperator::TypedJoin(Some(t)) if t.contains(JoinType::LEFT) => {}
_ => {
// Right now, we're picking the VIEW's primary key left to right. Other joins would
// require more sophistication. Many joins will spoil PKs, e.g. by computing a
// cross-product yielding a non-unique column.
return Err(precondition(&format!(
"Only LEFT JOINS supported yet, got: {:?}",
join.operator
)));
}
}
// We don't currently allow joining sub-queries, etc.
let SelectTable::Table(fqn, alias, _indexed) = join.table else {
return Err(precondition("JOIN with TABLE expected"));
match join.table {
SelectTable::Table(fqn, alias, _indexed) => {
referenced_tables.push(ReferredTable {
alias: to_alias(alias),
table: find_table(&fqn.into())?.clone(),
});
}
SelectTable::Select(subselect, alias) => {
let alias = to_alias(alias);
let (joins_in_subselect, referenced_tables_in_subselect) =
extract_joins_and_referenced_tables_by_alias(*subselect, tables)?;
all_joins.extend(joins_in_subselect);
referenced_tables.extend(referenced_tables_in_subselect.into_iter().map(
|ReferredTable { table, .. }| -> ReferredTable {
return ReferredTable {
alias: alias.clone(),
table,
};
},
));
}
_ => {
return Err(precondition("JOIN with TABLE expected"));
}
};
referenced_tables.push((to_alias(alias), find_table(&fqn.into())?));
}
referenced_tables
}
Some(SelectTable::Select(select, alias)) => {
Some(SelectTable::Select(nested_select, alias)) => {
// Simply recurse to unnest the select.
let alias = to_alias(alias);
return Ok(
extract_referenced_tables_by_alias(*select, tables)?
let (joins_in_nested_select, referenced_tables_in_nested_select) =
extract_joins_and_referenced_tables_by_alias(*nested_select, tables)?;
return Ok((
joins_in_nested_select,
referenced_tables_in_nested_select
.into_iter()
.map(|(_, table)| (alias.clone(), table))
.map(|referred_table| ReferredTable {
// NOTE: Reset the alias.
alias: alias.clone(),
table: referred_table.table,
})
.collect(),
);
));
}
Some(x) => {
return Err(precondition(&format!(
@@ -1223,7 +1283,68 @@ fn extract_referenced_tables_by_alias(
}
};
return Ok(referenced_table_by_alias);
return Ok((all_joins, referenced_table_by_alias));
}
/// Shorthand for building a `SchemaError::Precondition` carrying the message `m`.
#[inline]
fn precondition(m: &str) -> SchemaError {
  SchemaError::Precondition(m.into())
}
/// Maps a parsed join operator onto its `JoinType`.
///
/// An explicitly typed join yields its own type; comma joins and untyped joins are reported as
/// `JoinType::INNER`.
fn extract_join_type(op: JoinOperator) -> JoinType {
  if let JoinOperator::TypedJoin(Some(join_type)) = op {
    return join_type;
  }

  // Remaining cases: `JoinOperator::Comma` and `JoinOperator::TypedJoin(None)`.
  JoinType::INNER
}
/// Returns a copy of the result columns of `select`'s top-level SELECT body.
///
/// Fails with a precondition error if the body is not a plain SELECT.
fn extract_result_columns(
  select: &sqlite3_parser::ast::Select,
) -> Result<Vec<ResultColumn>, SchemaError> {
  match &select.body.select {
    sqlite3_parser::ast::OneSelect::Select { columns, .. } => Ok(columns.clone()),
    _ => Err(precondition("VALUES not supported")),
  }
}
/// A single-column GROUP BY reference extracted from a SELECT statement.
struct GroupBy {
// Unquoted qualifier if the expression was qualified, e.g. "a" for `GROUP BY a.id`.
qualifier: Option<String>,
// Unquoted column name, e.g. "id" for `GROUP BY a.id` or `GROUP BY id`.
name: String,
}
/// Extracts the column referenced by the top-level GROUP BY clause of `select`, if any.
///
/// Returns:
///  * `Ok(None)` if the SELECT has no GROUP BY clause,
///  * `Ok(Some(..))` if the clause consists of exactly one plain or qualified column reference,
///  * a precondition error if the body is not a plain SELECT, if the clause does not contain
///    exactly one expression, or if that expression is not a column reference.
fn extract_group_by_key_candidate(
  select: &sqlite3_parser::ast::Select,
) -> Result<Option<GroupBy>, SchemaError> {
  let group_by = match &select.body.select {
    sqlite3_parser::ast::OneSelect::Select { group_by, .. } => group_by,
    _ => return Err(precondition("VALUES not supported")),
  };

  let Some(exprs) = group_by else {
    return Ok(None);
  };

  // Only a single grouping expression can act as a key.
  let n = exprs.len();
  if n != 1 {
    return Err(precondition(&format!(
      "For RecordAPIs GROUP BY expressions must reference a single VIEW column, got {n}"
    )));
  }

  let (qualifier, name) = match exprs[0].clone() {
    Expr::Id(id) => (None, unquote_id(id)),
    Expr::Name(name) => (None, unquote_name(name)),
    Expr::Qualified(qualifier, name) => (Some(unquote_name(qualifier)), unquote_name(name)),
    expr => {
      return Err(precondition(&format!(
        "For RecordAPIs GROUP BY expressions must reference an exposed VIEW column, got {expr:?}"
      )));
    }
  };

  Ok(Some(GroupBy { qualifier, name }))
}
fn build_foreign_key(
@@ -1541,6 +1662,14 @@ mod tests {
assert_eq!(index, index1, "Parsed: {sql1}");
}
/// Test helper: parses `sql` and unwraps the resulting SELECT statement.
///
/// Panics if `sql` fails to parse, yields no statement, or is not a SELECT.
fn parse_into_select(sql: &str) -> sqlite3_parser::ast::Select {
  match parse_into_statement(sql).unwrap().unwrap() {
    sqlite3_parser::ast::Stmt::Select(select) => *select,
    _ => panic!("Not a select"),
  }
}
#[test]
fn test_view_column_extraction() {
let tables = vec![Table {
@@ -1563,56 +1692,45 @@ mod tests {
{
// No alias
let sql = "SELECT column FROM table_name";
let sqlite3_parser::ast::Stmt::Select(select) = parse_into_statement(sql).unwrap().unwrap()
else {
panic!("Not a select");
};
let _mapping = extract_column_mapping(*select, &tables).unwrap();
let select = parse_into_select("SELECT column FROM table_name");
let _mapping = extract_column_mapping(select, &tables).unwrap();
}
{
// With alias
let sql = "SELECT alias.column FROM table_name AS alias";
let sqlite3_parser::ast::Stmt::Select(select) = parse_into_statement(sql).unwrap().unwrap()
else {
panic!("Not a select");
};
let _mapping = extract_column_mapping(*select, &tables).unwrap();
let select = parse_into_select("SELECT alias.column FROM table_name AS alias");
let _mapping = extract_column_mapping(select, &tables).unwrap();
}
{
// With "elided" alias
let sql = "SELECT alias.column FROM table_name alias";
let sqlite3_parser::ast::Stmt::Select(select) = parse_into_statement(sql).unwrap().unwrap()
else {
panic!("Not a select");
};
let _mapping = extract_column_mapping(*select, &tables).unwrap();
let select = parse_into_select("SELECT alias.column FROM table_name alias");
let _mapping = extract_column_mapping(select, &tables).unwrap();
}
{
// JOIN on a SELECT.
let sql = "SELECT x.column, y.column FROM table_name AS x LEFT JOIN (SELECT * FROM table_name) AS y ON x.column = y.column";
let sqlite3_parser::ast::Stmt::Select(select) = parse_into_statement(sql).unwrap().unwrap()
else {
panic!("Not a select");
};
let err = extract_column_mapping(*select, &tables)
.err()
.unwrap()
.to_string();
assert!(err.contains("JOIN with TABLE expected"), "{err}");
let select = parse_into_select(
"SELECT x.column, y.column AS foo FROM table_name AS x LEFT JOIN (SELECT * FROM table_name) AS y ON x.column = y.column",
);
let column_mapping = extract_column_mapping(select, &tables).unwrap();
let columns = &column_mapping.columns;
assert_eq!(columns.len(), 2, "{columns:?}");
let first = &columns[0];
assert_eq!(first.column.data_type, ColumnDataType::Text);
assert_eq!(first.column.name, "column");
let second = &columns[1];
assert_eq!(second.column.data_type, ColumnDataType::Text);
assert_eq!(second.column.name, "foo");
}
{
// Compound SELECT.
let sql = "SELECT column FROM table_name UNION SELECT column FROM table_name";
let sqlite3_parser::ast::Stmt::Select(select) = parse_into_statement(sql).unwrap().unwrap()
else {
panic!("Not a select");
};
let err = extract_column_mapping(*select, &tables)
let select =
parse_into_select("SELECT column FROM table_name UNION SELECT column FROM table_name");
let err = extract_column_mapping(select, &tables)
.err()
.unwrap()
.to_string();
@@ -1620,6 +1738,38 @@ mod tests {
}
}
/// Test helper: parses a `CREATE VIEW` statement and returns its inner SELECT.
///
/// Panics if `sql` fails to parse, yields no statement, or is not a `CREATE VIEW`.
fn parse_create_view_select(sql: &str) -> sqlite3_parser::ast::Select {
  match parse_into_statement(sql).unwrap().unwrap() {
    sqlite3_parser::ast::Stmt::CreateView { select, .. } => *select,
    _ => panic!("Not a CREATE VIEW: {sql}"),
  }
}
/// Checks that a single-column GROUP BY is resolved to the index of the matching exposed VIEW
/// column, both when the GROUP BY expression is qualified and when it is not.
#[test]
fn test_creare_view_colum_mapping() {
  let tables = [parse_create_table(
    "CREATE TABLE a (id INTEGER PRIMARY KEY, data TEXT NOT NULL DEFAULT '') STRICT",
  )];

  // First qualified (`x.id`), then unqualified (`id`).
  for sql in [
    "CREATE VIEW view0 AS SELECT x.id FROM a AS x GROUP BY x.id",
    "CREATE VIEW view0 AS SELECT x.id FROM a AS x GROUP BY id",
  ] {
    let column_mapping = extract_column_mapping(parse_create_view_select(sql), &tables).unwrap();
    assert_eq!(Some(0), column_mapping.group_by);
  }
}
fn parse_create_table(create_table_sql: &str) -> Table {
let create_table_statement = parse_into_statement(create_table_sql).unwrap().unwrap();
return create_table_statement.try_into().unwrap();
@@ -1627,12 +1777,6 @@ mod tests {
#[test]
fn test_view_column_extraction_join() {
let sql = "SELECT user, *, a.*, p.user AS foo FROM foo.articles AS a LEFT JOIN bar.profiles AS p ON p.user = a.author";
let sqlite3_parser::ast::Stmt::Select(select) = parse_into_statement(sql).unwrap().unwrap()
else {
panic!("Not a select");
};
let profiles_table = parse_create_table(
r#"
CREATE TABLE bar.profiles (
@@ -1654,20 +1798,14 @@ mod tests {
let tables = [profiles_table, articles_table];
let mapping = extract_column_mapping(*select, &tables).unwrap();
assert_eq!(
mapping
.iter()
.map(|m| m.referred_column.as_ref().unwrap().column_name.as_str())
.collect::<Vec<_>>(),
[
"user", "id", "author", "body", "user", "username", "id", "author", "body", "user"
]
let select = parse_into_select(
"SELECT user, *, a.*, p.user AS foo FROM foo.articles AS a LEFT JOIN bar.profiles AS p ON p.user = a.author",
);
let mapping = extract_column_mapping(select, &tables).unwrap();
assert_eq!(
mapping
.columns
.iter()
.map(|m| m.column.name.as_str())
.collect::<Vec<_>>(),