mirror of
https://github.com/trailbaseio/trailbase.git
synced 2026-05-08 18:09:34 -05:00
Increase the space of applicable VIEWs for Record APIs by treating top-level GROUP BY <single col> expressions as key-defining. #99
This commit is contained in:
@@ -6,6 +6,7 @@ import 'package:test/test.dart';
|
||||
import 'package:dio/dio.dart';
|
||||
|
||||
const port = 4006;
|
||||
const address = '127.0.0.1:${port}';
|
||||
|
||||
class SimpleStrict {
|
||||
final String id;
|
||||
@@ -150,7 +151,7 @@ class Comment {
|
||||
}
|
||||
|
||||
Future<Client> connect() async {
|
||||
final client = Client('http://127.0.0.1:${port}');
|
||||
final client = Client('http://${address}');
|
||||
await client.login('admin@localhost', 'secret');
|
||||
return client;
|
||||
}
|
||||
@@ -171,7 +172,7 @@ Future<Process> initTrailBase() async {
|
||||
'--',
|
||||
'--data-dir=${depotPath}',
|
||||
'run',
|
||||
'--address=127.0.0.1:${port}',
|
||||
'--address=${address}',
|
||||
// We want at least some parallelism to experience isolate-local state.
|
||||
'--js-runtime-threads=2',
|
||||
]);
|
||||
@@ -179,8 +180,8 @@ Future<Process> initTrailBase() async {
|
||||
final dio = Dio();
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
try {
|
||||
final response = await dio.fetch(
|
||||
RequestOptions(path: 'http://127.0.0.1:${port}/api/healthcheck'));
|
||||
final response = await dio
|
||||
.fetch(RequestOptions(path: 'http://${address}/api/healthcheck'));
|
||||
if (response.statusCode == 200) {
|
||||
return process;
|
||||
}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
# Auto-generated config.Vault textproto
|
||||
secrets: [{
|
||||
key: "TRAIL_AUTH_OAUTH_PROVIDERS_OIDC0_CLIENT_SECRET"
|
||||
value: "invalid_oidc_client_secret"
|
||||
}, {
|
||||
key: "TRAIL_AUTH_OAUTH_PROVIDERS_DISCORD_CLIENT_SECRET"
|
||||
value: "invalid_discord_client_secret"
|
||||
}]
|
||||
secrets: [
|
||||
{
|
||||
key: "TRAIL_AUTH_OAUTH_PROVIDERS_OIDC0_CLIENT_SECRET"
|
||||
value: "invalid_oidc_client_secret"
|
||||
},
|
||||
{
|
||||
key: "TRAIL_AUTH_OAUTH_PROVIDERS_DISCORD_CLIENT_SECRET"
|
||||
value: "invalid_discord_client_secret"
|
||||
}
|
||||
]
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# Auto-generated config.Vault textproto
|
||||
secrets: [{
|
||||
key: "TRAIL_AUTH_OAUTH_PROVIDERS_DISCORD_CLIENT_SECRET"
|
||||
value: "invalid_discord_client_secret"
|
||||
}]
|
||||
secrets: [
|
||||
{
|
||||
key: "TRAIL_AUTH_OAUTH_PROVIDERS_DISCORD_CLIENT_SECRET"
|
||||
value: "invalid_discord_client_secret"
|
||||
}
|
||||
]
|
||||
|
||||
@@ -80,7 +80,7 @@ function buildSchema(schemas: ListSchemasResponse): SQLNamespace {
|
||||
const viewName = view.name.name;
|
||||
schema[viewName] = {
|
||||
self: { label: viewName, type: "keyword" },
|
||||
children: view.columns?.map((c) => c.name) ?? [],
|
||||
children: view.column_mapping?.columns.map((c) => c.column.name) ?? [],
|
||||
} satisfies SQLNamespace;
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import {
|
||||
isNotNull,
|
||||
hiddenTable,
|
||||
tableType,
|
||||
getColumns,
|
||||
ForeignKey,
|
||||
} from "@/lib/schema";
|
||||
|
||||
@@ -66,7 +67,7 @@ function findTargetPortName(
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const column of tableOrView.columns ?? []) {
|
||||
for (const column of getColumns(tableOrView) ?? []) {
|
||||
const unique = getUnique(column.options);
|
||||
if (unique?.is_primary ?? false) {
|
||||
return `${foreignKey.foreign_table}-${column.name}`;
|
||||
@@ -88,7 +89,7 @@ function buildErNode(
|
||||
};
|
||||
|
||||
const name = prettyFormatQualifiedName(tableOrView.name);
|
||||
const columns = tableOrView.columns ?? [];
|
||||
const columns = getColumns(tableOrView) ?? [];
|
||||
|
||||
const view = tableType(tableOrView) === "view";
|
||||
const ports: PortMetadata[] = columns.map((column) => {
|
||||
|
||||
@@ -62,7 +62,7 @@ import {
|
||||
} from "@/components/FormFields";
|
||||
import { createConfigQuery, setConfig } from "@/lib/config";
|
||||
import { parseSqlExpression } from "@/lib/parse";
|
||||
import { tableType, getForeignKey } from "@/lib/schema";
|
||||
import { tableType, getForeignKey, getColumns } from "@/lib/schema";
|
||||
import { buildDefaultRow } from "@/lib/convert";
|
||||
import { client } from "@/lib/fetch";
|
||||
|
||||
@@ -330,7 +330,7 @@ function getForeignKeyColumns(schema: Table | View): [string, ForeignKey][] {
|
||||
return true;
|
||||
}
|
||||
|
||||
return (schema.columns ?? [])
|
||||
return (getColumns(schema) ?? [])
|
||||
.map(
|
||||
(c) =>
|
||||
[c.name, getForeignKey(c.options)] as [string, ForeignKey | undefined],
|
||||
|
||||
@@ -196,16 +196,15 @@ export function isJSONColumn(column: Column): boolean {
|
||||
return false;
|
||||
}
|
||||
|
||||
function columnsSatisfyRecordApiRequirements(
|
||||
columns: Column[],
|
||||
all: Table[],
|
||||
): boolean {
|
||||
for (const column of columns) {
|
||||
if (isPrimaryKeyColumn(column)) {
|
||||
if (column.data_type === "Integer") {
|
||||
return true;
|
||||
}
|
||||
function isSuitableRecordPkColumn(column: Column, all: Table[]): boolean {
|
||||
if (!isPrimaryKeyColumn(column)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (column.data_type) {
|
||||
case "Integer":
|
||||
return true;
|
||||
case "Blob": {
|
||||
if (isUUIDColumn(column)) {
|
||||
return true;
|
||||
}
|
||||
@@ -214,14 +213,14 @@ function columnsSatisfyRecordApiRequirements(
|
||||
if (foreign_key) {
|
||||
const foreign_col_name = foreign_key.referred_columns[0];
|
||||
if (!foreign_col_name) {
|
||||
continue;
|
||||
return false;
|
||||
}
|
||||
|
||||
const foreign_table = all.find(
|
||||
(t) => t.name.name === foreign_key.foreign_table,
|
||||
);
|
||||
if (!foreign_table) {
|
||||
continue;
|
||||
return false;
|
||||
}
|
||||
|
||||
const foreign_col = foreign_table.columns.find(
|
||||
@@ -231,6 +230,7 @@ function columnsSatisfyRecordApiRequirements(
|
||||
return true;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -242,7 +242,11 @@ export function tableSatisfiesRecordApiRequirements(
|
||||
all: Table[],
|
||||
): boolean {
|
||||
if (table.strict) {
|
||||
return columnsSatisfyRecordApiRequirements(table.columns, all);
|
||||
for (const column of table.columns) {
|
||||
if (isSuitableRecordPkColumn(column, all)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -251,15 +255,49 @@ export function viewSatisfiesRecordApiRequirements(
|
||||
view: View,
|
||||
all: Table[],
|
||||
): boolean {
|
||||
const columns = view.columns;
|
||||
if (columns) {
|
||||
return columnsSatisfyRecordApiRequirements(columns, all);
|
||||
const mapping = view.column_mapping;
|
||||
if (!mapping) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const groupBy = mapping.group_by;
|
||||
if (groupBy != null) {
|
||||
if (isSuitableRecordPkColumn(mapping.columns[groupBy].column, all)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
const RIGHT = 0x10;
|
||||
const CROSS = 0x02;
|
||||
const NATURAL = 0x04;
|
||||
const MASK = RIGHT | CROSS | NATURAL;
|
||||
for (const joinType of mapping.joins) {
|
||||
if (joinType & MASK) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (const column of mapping.columns.map((c) => c.column)) {
|
||||
if (isSuitableRecordPkColumn(column, all)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
export type TableType = "table" | "virtualTable" | "view";
|
||||
|
||||
export function getColumns(tableOrView: Table | View): undefined | Column[] {
|
||||
switch (tableType(tableOrView)) {
|
||||
case "table":
|
||||
case "virtualTable":
|
||||
return (tableOrView as Table).columns;
|
||||
case "view":
|
||||
return (tableOrView as View).column_mapping?.columns.map((c) => c.column);
|
||||
}
|
||||
}
|
||||
|
||||
export function tableType(table: Table | View): TableType {
|
||||
if ("virtual_table" in table) {
|
||||
if (table.virtual_table) {
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
|
||||
import type { ViewColumn } from "./ViewColumn";
|
||||
|
||||
export type ColumnMapping = { columns: Array<ViewColumn>,
|
||||
/**
|
||||
* Group by that can be used as a key for record APIs.
|
||||
*/
|
||||
group_by: number | null,
|
||||
/**
|
||||
* A list of joins.
|
||||
*/
|
||||
joins: Array<number>, };
|
||||
@@ -1,5 +1,5 @@
|
||||
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
|
||||
import type { Column } from "./Column";
|
||||
import type { ColumnMapping } from "./ColumnMapping";
|
||||
import type { QualifiedName } from "./QualifiedName";
|
||||
|
||||
export type View = { name: QualifiedName,
|
||||
@@ -10,6 +10,8 @@ export type View = { name: QualifiedName,
|
||||
* functions, ..., which makes them inherently not type safe and therefore their columns not
|
||||
* well defined.
|
||||
*
|
||||
* NOTE: Should all this inference be in ViewMetadata?
|
||||
* QUESTION: We've been wondering if the inference should live more in ViewMetadata, however
|
||||
* right now the `View` is heavily used in the UI to e.g. render tables and infer record API
|
||||
* suitability. It's ok that this is more than just an AST.
|
||||
*/
|
||||
columns: Array<Column> | null, query: string, temporary: boolean, };
|
||||
column_mapping: ColumnMapping | null, query: string, temporary: boolean, };
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
|
||||
import type { Column } from "./Column";
|
||||
|
||||
export type ViewColumn = { column: Column, parent_name: string | null, };
|
||||
@@ -1 +1,2 @@
|
||||
export const PORT: number = 4005;
|
||||
export const ADDRESS: string = `127.0.0.1:${PORT}`;
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import { expect, test } from "vitest";
|
||||
import { OAuth2Server } from "oauth2-mock-server";
|
||||
import { PORT } from "../constants";
|
||||
|
||||
const address: string = `http://127.0.0.1:${PORT}`;
|
||||
import { ADDRESS } from "../constants";
|
||||
|
||||
type OpenIdConfig = {
|
||||
issuer: string;
|
||||
@@ -38,7 +36,7 @@ test("OIDC", async () => {
|
||||
userInfoResponse.statusCode = 200;
|
||||
});
|
||||
|
||||
const login = await fetch(`${address}/api/auth/v1/oauth/oidc0/login`, {
|
||||
const login = await fetch(`http://${ADDRESS}/api/auth/v1/oauth/oidc0/login`, {
|
||||
redirect: "manual",
|
||||
});
|
||||
|
||||
|
||||
@@ -4,9 +4,7 @@ import { initClient, urlSafeBase64Encode } from "../../src/index";
|
||||
import type { Client, Event } from "../../src/index";
|
||||
import { status } from "http-status";
|
||||
import { v7 as uuidv7, parse as uuidParse } from "uuid";
|
||||
|
||||
const port: number = 4005;
|
||||
const address: string = `http://127.0.0.1:${port}`;
|
||||
import { ADDRESS } from "../constants";
|
||||
|
||||
const sleep = (ms: number) => new Promise((r) => setTimeout(r, ms));
|
||||
|
||||
@@ -33,7 +31,7 @@ type SimpleSubsetView = {
|
||||
};
|
||||
|
||||
async function connect(): Promise<Client> {
|
||||
const client = initClient(address);
|
||||
const client = initClient(`http://${ADDRESS}`);
|
||||
await client.login("admin@localhost", "secret");
|
||||
return client;
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import { expect, test } from "vitest";
|
||||
import { status } from "http-status";
|
||||
|
||||
const port: number = 4005;
|
||||
const address: string = `http://127.0.0.1:${port}`;
|
||||
import { ADDRESS } from "../constants";
|
||||
|
||||
test("JS runtime", async () => {
|
||||
const expected = {
|
||||
@@ -14,16 +12,18 @@ test("JS runtime", async () => {
|
||||
},
|
||||
};
|
||||
|
||||
const jsonUrl = `${address}/json`;
|
||||
const jsonUrl = `http://${ADDRESS}/json`;
|
||||
const json = await (await fetch(jsonUrl)).json();
|
||||
expect(json).toMatchObject(expected);
|
||||
|
||||
const response = await fetch(`${address}/fetch?url=${encodeURI(jsonUrl)}`);
|
||||
const response = await fetch(
|
||||
`http://${ADDRESS}/fetch?url=${encodeURI(jsonUrl)}`,
|
||||
);
|
||||
expect(await response.json()).toMatchObject(expected);
|
||||
|
||||
const errResp = await fetch(`${address}/error`);
|
||||
const errResp = await fetch(`http://${ADDRESS}/error`);
|
||||
expect(errResp.status).equals(status.IM_A_TEAPOT);
|
||||
|
||||
// Test that the periodic callback was called.
|
||||
expect((await fetch(`${address}/await`)).status).equals(status.OK);
|
||||
expect((await fetch(`http://${ADDRESS}/await`)).status).equals(status.OK);
|
||||
});
|
||||
|
||||
@@ -5,7 +5,7 @@ import { cwd } from "node:process";
|
||||
import { join } from "node:path";
|
||||
import { execa, type Subprocess } from "execa";
|
||||
|
||||
import { PORT } from "./constants";
|
||||
import { ADDRESS } from "./constants";
|
||||
|
||||
const sleep = (ms: number) => new Promise((r) => setTimeout(r, ms));
|
||||
|
||||
@@ -28,7 +28,7 @@ async function initTrailBase(): Promise<{ subprocess: Subprocess }> {
|
||||
cwd: root,
|
||||
stdout: process.stdout,
|
||||
stderr: process.stdout,
|
||||
})`cargo run -- --data-dir client/testfixture --public-url http://127.0.0.1:${PORT} run -a 127.0.0.1:${PORT} --js-runtime-threads 1`;
|
||||
})`cargo run -- --data-dir client/testfixture --public-url http://${ADDRESS} run -a ${ADDRESS} --js-runtime-threads 1`;
|
||||
|
||||
for (let i = 0; i < 100; ++i) {
|
||||
if ((subprocess.exitCode ?? 0) > 0) {
|
||||
@@ -36,7 +36,7 @@ async function initTrailBase(): Promise<{ subprocess: Subprocess }> {
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(`http://127.0.0.1:${PORT}/api/healthcheck`);
|
||||
const response = await fetch(`http://${ADDRESS}/api/healthcheck`);
|
||||
if (response.ok) {
|
||||
return { subprocess };
|
||||
}
|
||||
|
||||
@@ -264,9 +264,9 @@ pub(crate) fn auth_ui_router() -> Router<AppState> {
|
||||
.nest_service(
|
||||
"/_/auth/",
|
||||
AssetService::<trailbase_assets::AuthAssets>::with_parameters(
|
||||
// We want as little magic as possible. The only /_/auth/subpath that isn't SSR, is profile, so
|
||||
// we when hitting /profile or /profile, we want actually want to serve the static
|
||||
// profile/index.html.
|
||||
// We want as little magic as possible. The only /_/auth/subpath that isn't SSR, is
|
||||
// profile, so we when hitting /profile or /profile, we want actually want to serve
|
||||
// the static profile/index.html.
|
||||
Some(Box::new(|path| {
|
||||
if path == "profile" {
|
||||
Some(format!("{path}/index.html"))
|
||||
|
||||
@@ -608,7 +608,7 @@ mod tests {
|
||||
id INTEGER PRIMARY KEY,
|
||||
'index' TEXT NOT NULL DEFAULT '',
|
||||
nullable INTEGER
|
||||
);
|
||||
) STRICT;
|
||||
INSERT INTO 'table' (id, 'index', nullable) VALUES (1, '1', 1), (2, '2', NULL), (3, '3', NULL);
|
||||
"#,
|
||||
)
|
||||
|
||||
@@ -104,7 +104,7 @@ impl RecordApiSchema {
|
||||
};
|
||||
let record_pk_column = (pk_index, pk_column.clone());
|
||||
|
||||
let Some(ref columns) = view_metadata.schema.columns else {
|
||||
let Some(columns) = view_metadata.columns() else {
|
||||
return Err("RecordApi requires schema".to_string());
|
||||
};
|
||||
let Some(json_metadata) = view_metadata.json_metadata() else {
|
||||
|
||||
@@ -3,7 +3,9 @@ use log::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::metadata::{JsonColumnMetadata, JsonSchemaError, TableMetadata, extract_json_metadata};
|
||||
use crate::metadata::{
|
||||
JsonColumnMetadata, JsonSchemaError, TableMetadata, extract_json_metadata, is_pk_column,
|
||||
};
|
||||
use crate::sqlite::{Column, ColumnDataType, ColumnOption};
|
||||
|
||||
/// Influeces the generated JSON schema. In `Insert` mode columns with default values will be
|
||||
@@ -120,8 +122,7 @@ pub fn build_json_schema_expanded(
|
||||
};
|
||||
|
||||
let Some(pk_column) = (match referred_columns.len() {
|
||||
0 => crate::metadata::find_pk_column_index(&table.schema.columns)
|
||||
.map(|idx| &table.schema.columns[idx]),
|
||||
0 => table.schema.columns.iter().find(|c| is_pk_column(c)),
|
||||
1 => table
|
||||
.schema
|
||||
.columns
|
||||
|
||||
+339
-136
@@ -2,13 +2,16 @@ use jsonschema::Validator;
|
||||
use lazy_static::lazy_static;
|
||||
use log::*;
|
||||
use regex::Regex;
|
||||
use sqlite3_parser::ast::JoinType;
|
||||
use std::borrow::Borrow;
|
||||
use std::collections::HashMap;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::sqlite::{Column, ColumnDataType, ColumnOption, QualifiedName, Table, View};
|
||||
use crate::sqlite::{
|
||||
Column, ColumnDataType, ColumnMapping, ColumnOption, QualifiedName, Table, View,
|
||||
};
|
||||
|
||||
// TODO: Can we merge this with crate::sqlite::SchemaError?
|
||||
#[derive(Debug, Clone, Error)]
|
||||
@@ -76,18 +79,12 @@ impl JsonMetadata {
|
||||
return Self::from_columns(&table.columns);
|
||||
}
|
||||
|
||||
fn from_view(view: &View) -> Option<Self> {
|
||||
return view.columns.as_ref().map(|cols| Self::from_columns(cols));
|
||||
}
|
||||
|
||||
fn from_columns(columns: &[Column]) -> Self {
|
||||
let columns: Vec<_> = columns.iter().map(build_json_metadata).collect();
|
||||
|
||||
let file_column_indexes = find_file_column_indexes(&columns);
|
||||
|
||||
return Self {
|
||||
file_column_indexes: find_file_column_indexes(&columns),
|
||||
columns,
|
||||
file_column_indexes,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -124,7 +121,7 @@ impl TableMetadata {
|
||||
.map(|(index, col)| (col.name.clone(), index)),
|
||||
);
|
||||
|
||||
let record_pk_column = find_record_pk_column_index(&table.columns, tables);
|
||||
let record_pk_column = find_record_pk_column_index_for_table(&table, tables);
|
||||
let user_id_columns = find_user_id_foreign_key_columns(&table.columns, user_table_name);
|
||||
let json_metadata = JsonMetadata::from_table(&table);
|
||||
|
||||
@@ -193,6 +190,10 @@ impl Borrow<QualifiedName> for Arc<TableMetadata> {
|
||||
pub struct ViewMetadata {
|
||||
pub schema: View,
|
||||
|
||||
// QUESTION: Why do we have copy of the columns here? Right now it's duplicate from `.schema`.
|
||||
// This probably only exists because we have a trait impl that returns Option<&[Column]>.
|
||||
columns: Option<Vec<Column>>,
|
||||
|
||||
name_to_index: HashMap<String, usize>,
|
||||
record_pk_column: Option<usize>,
|
||||
json_metadata: Option<JsonMetadata>,
|
||||
@@ -204,28 +205,36 @@ impl ViewMetadata {
|
||||
/// NOTE: The list of all tables is needed only to extract interger/UUIDv7 pk columns for foreign
|
||||
/// key relationships.
|
||||
pub fn new(view: View, tables: &[Table]) -> Self {
|
||||
let name_to_index = if let Some(ref columns) = view.columns {
|
||||
HashMap::<String, usize>::from_iter(
|
||||
columns
|
||||
return match view.column_mapping {
|
||||
Some(ref column_mapping) => {
|
||||
let columns: Vec<Column> = column_mapping
|
||||
.columns
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, col)| (col.name.clone(), index)),
|
||||
)
|
||||
} else {
|
||||
HashMap::<String, usize>::new()
|
||||
};
|
||||
.map(|m| m.column.clone())
|
||||
.collect();
|
||||
|
||||
let record_pk_column = view
|
||||
.columns
|
||||
.as_ref()
|
||||
.and_then(|c| find_record_pk_column_index(c, tables));
|
||||
let json_metadata = JsonMetadata::from_view(&view);
|
||||
let name_to_index = HashMap::<String, usize>::from_iter(
|
||||
columns
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, col)| (col.name.clone(), index)),
|
||||
);
|
||||
|
||||
return ViewMetadata {
|
||||
schema: view,
|
||||
name_to_index,
|
||||
record_pk_column,
|
||||
json_metadata,
|
||||
ViewMetadata {
|
||||
name_to_index,
|
||||
json_metadata: Some(JsonMetadata::from_columns(&columns)),
|
||||
columns: Some(columns),
|
||||
record_pk_column: find_record_pk_column_index_for_view(column_mapping, tables),
|
||||
schema: view,
|
||||
}
|
||||
}
|
||||
None => ViewMetadata {
|
||||
name_to_index: HashMap::<String, usize>::default(),
|
||||
columns: None,
|
||||
record_pk_column: None,
|
||||
json_metadata: None,
|
||||
schema: view,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
@@ -234,6 +243,11 @@ impl ViewMetadata {
|
||||
&self.schema.name
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn columns(&self) -> Option<&[Column]> {
|
||||
return self.columns.as_deref();
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn column_index_by_name(&self, key: &str) -> Option<usize> {
|
||||
self.name_to_index.get(key).copied()
|
||||
@@ -242,8 +256,8 @@ impl ViewMetadata {
|
||||
#[inline]
|
||||
pub fn column_by_name(&self, key: &str) -> Option<(usize, &Column)> {
|
||||
let index = self.column_index_by_name(key)?;
|
||||
let cols = self.schema.columns.as_ref()?;
|
||||
return Some((index, &cols[index]));
|
||||
let mapping = self.schema.column_mapping.as_ref()?;
|
||||
return Some((index, &mapping.columns[index].column));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -313,7 +327,7 @@ impl TableOrViewMetadata for ViewMetadata {
|
||||
}
|
||||
|
||||
fn columns(&self) -> Option<&[Column]> {
|
||||
return self.schema.columns.as_deref();
|
||||
return self.columns.as_deref();
|
||||
}
|
||||
|
||||
fn json_metadata(&self) -> Option<&JsonMetadata> {
|
||||
@@ -321,7 +335,7 @@ impl TableOrViewMetadata for ViewMetadata {
|
||||
}
|
||||
|
||||
fn record_pk_column(&self) -> Option<(usize, &Column)> {
|
||||
let Some(columns) = &self.schema.columns else {
|
||||
let Some(columns) = &self.columns else {
|
||||
return None;
|
||||
};
|
||||
let index = self.record_pk_column?;
|
||||
@@ -427,85 +441,122 @@ pub fn find_user_id_foreign_key_columns(columns: &[Column], user_table_name: &st
|
||||
return indexes;
|
||||
}
|
||||
|
||||
pub(crate) fn find_pk_column_index(columns: &[Column]) -> Option<usize> {
|
||||
return columns.iter().position(|col| {
|
||||
for opt in &col.options {
|
||||
if let ColumnOption::Unique { is_primary, .. } = opt {
|
||||
return *is_primary;
|
||||
}
|
||||
pub(crate) fn is_pk_column(column: &Column) -> bool {
|
||||
for opt in &column.options {
|
||||
if let ColumnOption::Unique { is_primary, .. } = opt {
|
||||
return *is_primary;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
fn is_suitable_record_pk_column(column: &Column, tables: &[Table]) -> bool {
|
||||
if !is_pk_column(column) {
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
return match column.data_type {
|
||||
ColumnDataType::Integer => {
|
||||
// TODO: We should detect the "integer pk" desc case and at least warn:
|
||||
// https://www.sqlite.org/lang_createtable.html#rowid.
|
||||
true
|
||||
}
|
||||
ColumnDataType::Blob => {
|
||||
lazy_static! {
|
||||
static ref UUID_CHECK_RE: Regex =
|
||||
Regex::new(r"^is_uuid(|_v7|_v4)\s*\(").expect("infallible");
|
||||
}
|
||||
|
||||
for opts in &column.options {
|
||||
match opts {
|
||||
// Check the column itself is a UUID column.
|
||||
ColumnOption::Check(expr) if UUID_CHECK_RE.is_match(expr) => return true,
|
||||
// Or that a referenced column is a UUID column.
|
||||
ColumnOption::ForeignKey {
|
||||
foreign_table,
|
||||
referred_columns,
|
||||
..
|
||||
} => {
|
||||
let referred_column = {
|
||||
if referred_columns.len() != 1 {
|
||||
return false;
|
||||
}
|
||||
&referred_columns[0]
|
||||
};
|
||||
|
||||
// NOTE: Foreign keys cannot cross database boundaries, we can therefore compare by
|
||||
// unqualified name.
|
||||
let Some(referred_table) = tables.iter().find(|t| t.name.name == *foreign_table) else {
|
||||
warn!("Failed to get foreign key schema for {foreign_table}");
|
||||
return false;
|
||||
};
|
||||
|
||||
let Some(foreign_column) = referred_table
|
||||
.columns
|
||||
.iter()
|
||||
.find(|c| c.name == *referred_column)
|
||||
else {
|
||||
return false;
|
||||
};
|
||||
|
||||
for opt in &foreign_column.options {
|
||||
match opt {
|
||||
ColumnOption::Check(expr) if UUID_CHECK_RE.is_match(expr) => return true,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
_ => false,
|
||||
};
|
||||
}
|
||||
|
||||
/// Finds suitable Integer or UUIDv7/UUIDv4 primary key columns, if present.
|
||||
///
|
||||
/// Cursors require certain properties like a stable, time-sortable primary key.
|
||||
fn find_record_pk_column_index(columns: &[Column], tables: &[Table]) -> Option<usize> {
|
||||
let index = find_pk_column_index(columns)?;
|
||||
let column = &columns[index];
|
||||
|
||||
if column.data_type == ColumnDataType::Integer {
|
||||
// TODO: We should detect the "integer pk" desc case and at least warn:
|
||||
// https://www.sqlite.org/lang_createtable.html#rowid.
|
||||
return Some(index);
|
||||
}
|
||||
|
||||
for opts in &column.options {
|
||||
lazy_static! {
|
||||
static ref UUID_CHECK_RE: Regex = Regex::new(r"^is_uuid(|_v7|_v4)\s*\(").expect("infallible");
|
||||
}
|
||||
|
||||
match &opts {
|
||||
// Check if the referenced column is a uuidv7 column.
|
||||
ColumnOption::ForeignKey {
|
||||
foreign_table,
|
||||
referred_columns,
|
||||
..
|
||||
} => {
|
||||
// NOTE: Foreign keys cannot cross database boundaries, we can therefore compare by
|
||||
// unqualified name.
|
||||
let Some(referred_table) = tables.iter().find(|t| t.name.name == *foreign_table) else {
|
||||
error!("Failed to get foreign key schema for {foreign_table}");
|
||||
continue;
|
||||
};
|
||||
|
||||
if referred_columns.len() != 1 {
|
||||
return None;
|
||||
}
|
||||
let referred_column = &referred_columns[0];
|
||||
|
||||
let col = referred_table
|
||||
.columns
|
||||
.iter()
|
||||
.find(|c| c.name == *referred_column)?;
|
||||
|
||||
let mut is_pk = false;
|
||||
for opt in &col.options {
|
||||
match opt {
|
||||
ColumnOption::Check(expr) if UUID_CHECK_RE.is_match(expr) => {
|
||||
return Some(index);
|
||||
}
|
||||
ColumnOption::Unique { is_primary, .. } if *is_primary => {
|
||||
is_pk = true;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if is_pk && col.data_type == ColumnDataType::Integer {
|
||||
return Some(index);
|
||||
}
|
||||
|
||||
return None;
|
||||
}
|
||||
ColumnOption::Check(expr) if UUID_CHECK_RE.is_match(expr) => {
|
||||
fn find_record_pk_column_index_for_table(table: &Table, tables: &[Table]) -> Option<usize> {
|
||||
if table.strict {
|
||||
for (index, column) in table.columns.iter().enumerate() {
|
||||
if is_suitable_record_pk_column(column, tables) {
|
||||
return Some(index);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
return None;
|
||||
}
|
||||
|
||||
fn find_record_pk_column_index_for_view(
|
||||
column_mapping: &ColumnMapping,
|
||||
tables: &[Table],
|
||||
) -> Option<usize> {
|
||||
if let Some(group_by_index) = column_mapping.group_by {
|
||||
let column = &column_mapping.columns[group_by_index];
|
||||
if is_suitable_record_pk_column(&column.column, tables) {
|
||||
return Some(group_by_index);
|
||||
}
|
||||
return None;
|
||||
}
|
||||
|
||||
// NOTE: We could be smarter here. It's quite tricky to say with a set of arbitrary joins, which
|
||||
// (integer, UUID) columns end up being unique afterwards. Rely on explicit GROUP BY instead.
|
||||
let mask = JoinType::RIGHT | JoinType::CROSS | JoinType::NATURAL;
|
||||
for join_type in &column_mapping.joins {
|
||||
if join_type & mask.bits() != 0 {
|
||||
warn!("Only LEFT and INNER JOINS supported yet, got: {join_type:?}");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
for (index, mapped_column) in column_mapping.columns.iter().enumerate() {
|
||||
if is_suitable_record_pk_column(&mapped_column.column, tables) {
|
||||
return Some(index);
|
||||
}
|
||||
}
|
||||
return None;
|
||||
}
|
||||
|
||||
@@ -515,16 +566,55 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::parse::parse_into_statement;
|
||||
use crate::sqlite::Table;
|
||||
use crate::sqlite::{SchemaError, Table};
|
||||
|
||||
fn parse_create_table(create_table_sql: &str) -> Table {
|
||||
let create_table_statement = parse_into_statement(create_table_sql).unwrap().unwrap();
|
||||
return create_table_statement.try_into().unwrap();
|
||||
}
|
||||
|
||||
fn parse_create_view(create_view_sql: &str, tables: &[Table]) -> View {
|
||||
fn parse_create_view(create_view_sql: &str, tables: &[Table]) -> Result<View, SchemaError> {
|
||||
let create_view_statement = parse_into_statement(create_view_sql).unwrap().unwrap();
|
||||
return View::from(create_view_statement, tables).unwrap();
|
||||
return View::from(create_view_statement, tables);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_record_pk_column_index_for_table() {
|
||||
let table = parse_create_table("CREATE TABLE t (id INTEGER PRIMARY KEY) STRICT");
|
||||
let tables = [table.clone()];
|
||||
assert_eq!(
|
||||
Some(0),
|
||||
find_record_pk_column_index_for_table(&table, &tables)
|
||||
);
|
||||
|
||||
let table = parse_create_table(
|
||||
r#"
|
||||
CREATE TABLE t (
|
||||
value TEXT,
|
||||
id BLOB PRIMARY KEY NOT NULL CHECK(is_uuid(id))
|
||||
) STRICT;
|
||||
"#,
|
||||
);
|
||||
|
||||
let tables = [table.clone()];
|
||||
assert_eq!(
|
||||
Some(1),
|
||||
find_record_pk_column_index_for_table(&table, &tables)
|
||||
);
|
||||
|
||||
let non_strict_table = parse_create_table(
|
||||
r#"
|
||||
CREATE TABLE t (
|
||||
value TEXT,
|
||||
id BLOB PRIMARY KEY NOT NULL CHECK(is_uuid(id))
|
||||
);
|
||||
"#,
|
||||
);
|
||||
let tables = [non_strict_table.clone()];
|
||||
assert_eq!(
|
||||
None,
|
||||
find_record_pk_column_index_for_table(&non_strict_table, &tables)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -551,19 +641,20 @@ mod tests {
|
||||
let table_view = parse_create_view(
|
||||
"CREATE VIEW view0 AS SELECT col0, col1 FROM table0",
|
||||
&tables,
|
||||
);
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(table_view.name.name, "view0");
|
||||
assert_eq!(table_view.query, "SELECT col0, col1 FROM table0");
|
||||
assert_eq!(table_view.temporary, false);
|
||||
|
||||
let view_columns = table_view.columns.as_ref().unwrap();
|
||||
let view_columns = &table_view.column_mapping.as_ref().unwrap().columns;
|
||||
|
||||
assert_eq!(view_columns.len(), 2);
|
||||
assert_eq!(view_columns[0].name, "col0");
|
||||
assert_eq!(view_columns[0].data_type, ColumnDataType::Text);
|
||||
assert_eq!(view_columns[0].column.name, "col0");
|
||||
assert_eq!(view_columns[0].column.data_type, ColumnDataType::Text);
|
||||
|
||||
assert_eq!(view_columns[1].name, "col1");
|
||||
assert_eq!(view_columns[1].data_type, ColumnDataType::Blob);
|
||||
assert_eq!(view_columns[1].column.name, "col1");
|
||||
assert_eq!(view_columns[1].column.data_type, ColumnDataType::Blob);
|
||||
|
||||
let view_metadata = ViewMetadata::new(table_view, &tables);
|
||||
|
||||
@@ -573,7 +664,8 @@ mod tests {
|
||||
|
||||
{
|
||||
let query = "SELECT id, col0, col1 FROM table0";
|
||||
let table_view = parse_create_view(&format!("CREATE VIEW view0 AS {query}"), &tables);
|
||||
let table_view =
|
||||
parse_create_view(&format!("CREATE VIEW view0 AS {query}"), &tables).unwrap();
|
||||
|
||||
assert_eq!(table_view.name.name, "view0");
|
||||
assert_eq!(table_view.query, query);
|
||||
@@ -600,15 +692,16 @@ mod tests {
|
||||
let view = parse_create_view(
|
||||
"CREATE VIEW view0 AS SELECT * FROM (SELECT * FROM a);",
|
||||
&tables,
|
||||
);
|
||||
let view_columns = view.columns.as_ref().unwrap();
|
||||
)
|
||||
.unwrap();
|
||||
let view_columns = &view.column_mapping.as_ref().unwrap().columns;
|
||||
|
||||
assert_eq!(view_columns.len(), 2);
|
||||
assert_eq!(view_columns[0].name, "id");
|
||||
assert_eq!(view_columns[0].data_type, ColumnDataType::Integer);
|
||||
assert_eq!(view_columns[0].column.name, "id");
|
||||
assert_eq!(view_columns[0].column.data_type, ColumnDataType::Integer);
|
||||
|
||||
assert_eq!(view_columns[1].name, "data");
|
||||
assert_eq!(view_columns[1].data_type, ColumnDataType::Text);
|
||||
assert_eq!(view_columns[1].column.name, "data");
|
||||
assert_eq!(view_columns[1].column.data_type, ColumnDataType::Text);
|
||||
|
||||
let metadata = ViewMetadata::new(view, &tables);
|
||||
let (pk_index, pk_col) = metadata.record_pk_column().unwrap();
|
||||
@@ -620,11 +713,12 @@ mod tests {
|
||||
let view = parse_create_view(
|
||||
"CREATE VIEW view0 AS SELECT id FROM (SELECT * FROM a);",
|
||||
&tables,
|
||||
);
|
||||
let view_columns = view.columns.as_ref().unwrap();
|
||||
)
|
||||
.unwrap();
|
||||
let view_columns = &view.column_mapping.as_ref().unwrap().columns;
|
||||
assert_eq!(view_columns.len(), 1);
|
||||
assert_eq!(view_columns[0].name, "id");
|
||||
assert_eq!(view_columns[0].data_type, ColumnDataType::Integer);
|
||||
assert_eq!(view_columns[0].column.name, "id");
|
||||
assert_eq!(view_columns[0].column.data_type, ColumnDataType::Integer);
|
||||
|
||||
let metadata = ViewMetadata::new(view, &tables);
|
||||
let (pk_index, pk_col) = metadata.record_pk_column().unwrap();
|
||||
@@ -636,11 +730,12 @@ mod tests {
|
||||
let view = parse_create_view(
|
||||
"CREATE VIEW view0 AS SELECT x.id FROM (SELECT * FROM a) AS x;",
|
||||
&tables,
|
||||
);
|
||||
let view_columns = view.columns.as_ref().unwrap();
|
||||
)
|
||||
.unwrap();
|
||||
let view_columns = &view.column_mapping.as_ref().unwrap().columns;
|
||||
assert_eq!(view_columns.len(), 1);
|
||||
assert_eq!(view_columns[0].name, "id");
|
||||
assert_eq!(view_columns[0].data_type, ColumnDataType::Integer);
|
||||
assert_eq!(view_columns[0].column.name, "id");
|
||||
assert_eq!(view_columns[0].column.data_type, ColumnDataType::Integer);
|
||||
|
||||
let metadata = ViewMetadata::new(view, &tables);
|
||||
let (pk_index, pk_col) = metadata.record_pk_column().unwrap();
|
||||
@@ -653,8 +748,8 @@ mod tests {
|
||||
let view = parse_create_view(
|
||||
"CREATE VIEW view0 AS SELECT x.id, y.id FROM (SELECT * FROM a) AS x, (SELECT * FROM a) AS y;",
|
||||
&tables,
|
||||
);
|
||||
assert_eq!(view.columns, None);
|
||||
).unwrap();
|
||||
assert!(view.column_mapping.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -678,24 +773,131 @@ mod tests {
|
||||
let view = parse_create_view(
|
||||
"CREATE VIEW view0 AS SELECT a.data, b.fk, a.id FROM a AS a LEFT JOIN b AS b ON a.id = b.fk;",
|
||||
&tables,
|
||||
);
|
||||
let view_columns = view.columns.as_ref().unwrap();
|
||||
).unwrap();
|
||||
let view_columns = &view.column_mapping.as_ref().unwrap().columns;
|
||||
|
||||
assert_eq!(view_columns.len(), 3);
|
||||
assert_eq!(view_columns[2].name, "id");
|
||||
assert_eq!(view_columns[2].data_type, ColumnDataType::Integer);
|
||||
assert_eq!(view_columns[2].column.name, "id");
|
||||
assert_eq!(view_columns[2].column.data_type, ColumnDataType::Integer);
|
||||
|
||||
assert_eq!(view_columns[0].name, "data");
|
||||
assert_eq!(view_columns[0].data_type, ColumnDataType::Text);
|
||||
assert_eq!(view_columns[0].column.name, "data");
|
||||
assert_eq!(view_columns[0].column.data_type, ColumnDataType::Text);
|
||||
|
||||
assert_eq!(view_columns[1].name, "fk");
|
||||
assert_eq!(view_columns[1].data_type, ColumnDataType::Integer);
|
||||
assert_eq!(view_columns[1].column.name, "fk");
|
||||
assert_eq!(view_columns[1].column.data_type, ColumnDataType::Integer);
|
||||
|
||||
let metadata = ViewMetadata::new(view, &tables);
|
||||
let (pk_index, pk_col) = metadata.record_pk_column().unwrap();
|
||||
assert_eq!(pk_index, 2);
|
||||
assert_eq!(pk_col.name, "id");
|
||||
}
|
||||
|
||||
{
|
||||
// JOINs
|
||||
for (join_type, expected) in [
|
||||
("LEFT", Some(1)),
|
||||
("INNER", Some(1)),
|
||||
("RIGHT", None),
|
||||
("CROSS", None),
|
||||
] {
|
||||
let view = parse_create_view(
|
||||
&format!(
|
||||
"CREATE VIEW view0 AS SELECT a.data, a.id FROM a AS a {join_type} JOIN b AS b ON a.id = b.fk;"
|
||||
),
|
||||
&tables,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let metadata = ViewMetadata::new(view, &tables);
|
||||
assert_eq!(
|
||||
expected,
|
||||
metadata.record_pk_column().map(|c| c.0),
|
||||
"{join_type}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_create_view_with_group_by() {
|
||||
let table_a = parse_create_table(
|
||||
"CREATE TABLE a (id INTEGER PRIMARY KEY, data TEXT NOT NULL DEFAULT '') STRICT",
|
||||
);
|
||||
let table_b = parse_create_table(
|
||||
r#"
|
||||
CREATE TABLE b (
|
||||
id INTEGER PRIMARY KEY,
|
||||
fk INTEGER NOT NULL REFERENCES a(id)
|
||||
) STRICT"#,
|
||||
);
|
||||
|
||||
let tables = [table_a, table_b];
|
||||
|
||||
{
|
||||
// JOIN on a SELECT is not suitable for APIs. They're cross-producty nature spoils PKs.
|
||||
{
|
||||
for (i, sql) in [
|
||||
"CREATE VIEW v AS SELECT data, a.id AS z FROM a RIGHT JOIN b ON a.id = b.id GROUP BY z;",
|
||||
"CREATE VIEW v AS SELECT data, x.id AS z FROM a AS x RIGHT JOIN b ON x.id = b.id GROUP BY z;",
|
||||
"CREATE VIEW v AS SELECT data, x.id AS z FROM a x RIGHT JOIN b ON x.id = b.id GROUP BY z;",
|
||||
].iter().enumerate() {
|
||||
let view = parse_create_view(sql, &tables).unwrap();
|
||||
assert!(view.column_mapping.is_some(), "{i}: {sql}");
|
||||
|
||||
let metadata = ViewMetadata::new(view, &tables);
|
||||
assert_eq!(Some(1), metadata.record_pk_column().map(|c| c.0));
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let view = parse_create_view(
|
||||
"CREATE VIEW v AS SELECT a.data, a.id FROM a RIGHT JOIN b ON a.id = b.id GROUP BY a.id;",
|
||||
&tables,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let metadata = ViewMetadata::new(view, &tables);
|
||||
assert_eq!(Some(1), metadata.record_pk_column().map(|c| c.0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_create_view_from_issue_99() {
|
||||
let authors_table = parse_create_table(
|
||||
"
|
||||
CREATE TABLE authors (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
age INTEGER DEFAULT NULL
|
||||
) STRICT;
|
||||
",
|
||||
);
|
||||
let posts_table = parse_create_table(
|
||||
"
|
||||
CREATE TABLE posts (
|
||||
id INTEGER PRIMARY KEY,
|
||||
author INTEGER DEFAULT NULL REFERENCES persons(id),
|
||||
title TEXT NOT NULL
|
||||
) STRICT;
|
||||
",
|
||||
);
|
||||
|
||||
let tables = [authors_table, posts_table];
|
||||
|
||||
let view = parse_create_view(
|
||||
"
|
||||
CREATE VIEW authors_view_posts AS
|
||||
SELECT authors.* FROM authors authors
|
||||
INNER JOIN posts posts ON posts.author = authors.id
|
||||
GROUP BY authors.id;
|
||||
",
|
||||
&tables,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let metadata = ViewMetadata::new(view, &tables);
|
||||
assert_eq!(Some(0), metadata.record_pk_column().map(|c| c.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -726,15 +928,16 @@ mod tests {
|
||||
name: "view_name".to_string(),
|
||||
database_schema: Some("main".to_string()),
|
||||
};
|
||||
let table_view = parse_create_view(
|
||||
let view = parse_create_view(
|
||||
&format!(
|
||||
"CREATE VIEW {view_name} AS SELECT id FROM {table_name}",
|
||||
view_name = view_name.escaped_string(),
|
||||
table_name = table_name.escaped_string()
|
||||
),
|
||||
&tables,
|
||||
);
|
||||
let view_metadata = Arc::new(ViewMetadata::new(table_view, &[table.clone()]));
|
||||
)
|
||||
.unwrap();
|
||||
let view_metadata = Arc::new(ViewMetadata::new(view, &[table.clone()]));
|
||||
|
||||
let mut view_set = HashSet::<Arc<ViewMetadata>>::new();
|
||||
|
||||
|
||||
+325
-187
@@ -851,7 +851,7 @@ impl std::fmt::Display for SelectFormatter {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, TS, PartialEq)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, TS)]
|
||||
pub struct View {
|
||||
pub name: QualifiedName,
|
||||
|
||||
@@ -861,8 +861,10 @@ pub struct View {
|
||||
/// functions, ..., which makes them inherently not type safe and therefore their columns not
|
||||
/// well defined.
|
||||
///
|
||||
/// NOTE: Should all this inference be in ViewMetadata?
|
||||
pub columns: Option<Vec<Column>>,
|
||||
/// QUESTION: We've been wondering if the inference should live more in ViewMetadata, however
|
||||
/// right now the `View` is heavily used in the UI to e.g. render tables and infer record API
|
||||
/// suitability. It's ok that this is more than just an AST.
|
||||
pub(crate) column_mapping: Option<ColumnMapping>,
|
||||
|
||||
pub query: String,
|
||||
|
||||
@@ -887,7 +889,7 @@ impl View {
|
||||
));
|
||||
};
|
||||
|
||||
let column_mapping: Option<Vec<ColumnMapping>> = if columns.is_some() {
|
||||
let column_mapping: Option<ColumnMapping> = if columns.is_some() {
|
||||
// Example, `CREATE VIEW view0(alias0, alias1) AS SELECT * FROM table0;`
|
||||
//
|
||||
// We probably never want to support this due to its late failure mode,
|
||||
@@ -913,7 +915,7 @@ impl View {
|
||||
|
||||
return Ok(View {
|
||||
name: view_name.into(),
|
||||
columns: column_mapping.map(|o| o.into_iter().map(|m| m.column).collect()),
|
||||
column_mapping,
|
||||
query: SelectFormatter(*select).to_string(),
|
||||
temporary,
|
||||
if_not_exists,
|
||||
@@ -921,141 +923,140 @@ impl View {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[allow(unused)]
|
||||
struct ReferredColumn {
|
||||
table_name: QualifiedName,
|
||||
column_name: String,
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, TS)]
|
||||
pub(crate) struct ViewColumn {
|
||||
// e.g. "foo" for CREATE VIEW v AS SELECT foo.bar AS baz FROM ...
|
||||
// #[allow(unused)]
|
||||
// pub(crate) qualifier: Option<String>,
|
||||
//
|
||||
// // e.g. "baz" for CREATE VIEW v AS SELECT foo.bar AS baz FROM ...
|
||||
// pub(crate) alias: Option<String>,
|
||||
|
||||
// The inferred column schema, either via a cast from a computed column or the underlying table
|
||||
// column if inferable.
|
||||
// NOTE: It would be cleaner to separate (Table)`Column` from `ViewColumn`, just pulling the in
|
||||
// the contents here. However, the UI currently depends on `Column`.
|
||||
pub(crate) column: Column,
|
||||
|
||||
// Would be "foo" for CREATE VIEW v AS SELECT foo.bar FROM foo;
|
||||
pub(crate) parent_name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct ColumnMapping {
|
||||
column: Column,
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, TS)]
|
||||
pub(crate) struct ColumnMapping {
|
||||
pub(crate) columns: Vec<ViewColumn>,
|
||||
|
||||
#[allow(unused)]
|
||||
referred_column: Option<ReferredColumn>,
|
||||
/// Group by that can be used as a key for record APIs.
|
||||
pub(crate) group_by: Option<usize>,
|
||||
|
||||
/// A list of joins.
|
||||
pub(crate) joins: Vec<u8>,
|
||||
}
|
||||
|
||||
fn extract_column_mapping(
|
||||
select: sqlite3_parser::ast::Select,
|
||||
tables: &[Table],
|
||||
) -> Result<Vec<ColumnMapping>, SchemaError> {
|
||||
) -> Result<ColumnMapping, SchemaError> {
|
||||
let result_columns = extract_result_columns(&select)?;
|
||||
let referenced_table_by_alias = extract_referenced_tables_by_alias(select, tables)?;
|
||||
let group_by_key_candidate = extract_group_by_key_candidate(&select)?;
|
||||
|
||||
let (joins, referenced_table_by_alias) =
|
||||
extract_joins_and_referenced_tables_by_alias(select, tables)?;
|
||||
|
||||
// SQLite checks comprehensively and will return an `ambiguous column name: <col>` error at
|
||||
// query time (as opposed to VIEW-creation-time).
|
||||
let find_column_by_unqualified_name = |col_name: &str| -> Result<(&Table, &Column), SchemaError> {
|
||||
// Search tables in index order.
|
||||
let mut found: Option<(&Table, &Column)> = None;
|
||||
for (_alias, table) in &referenced_table_by_alias {
|
||||
for col in &table.columns {
|
||||
if col.name == col_name {
|
||||
if found.is_some() {
|
||||
return Err(precondition(&format!("Ambibuous column: {col_name}")));
|
||||
let find_column_by_unqualified_name =
|
||||
|col_name: &str| -> Result<(&ReferredTable, usize), SchemaError> {
|
||||
// Search tables in index order.
|
||||
let mut found: Option<(&ReferredTable, usize)> = None;
|
||||
for referred_table in &referenced_table_by_alias {
|
||||
for (i, col) in referred_table.table.columns.iter().enumerate() {
|
||||
if col.name == col_name {
|
||||
if found.is_some() {
|
||||
return Err(precondition(&format!("Ambiguous column: {col_name}")));
|
||||
}
|
||||
found = Some((referred_table, i));
|
||||
}
|
||||
found = Some((table, col));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return found.ok_or(precondition(&format!("Column '{col_name}' not found")));
|
||||
};
|
||||
return found.ok_or(precondition(&format!("Column '{col_name}' not found")));
|
||||
};
|
||||
|
||||
let find_table_by_alias = |a: &str| -> Result<&Table, SchemaError> {
|
||||
for (alias, table) in &referenced_table_by_alias {
|
||||
if alias.as_deref() == Some(a) {
|
||||
return Ok(table);
|
||||
let find_table_by_alias = |a: &str| -> Result<&ReferredTable, SchemaError> {
|
||||
for referred_table in &referenced_table_by_alias {
|
||||
if referred_table.alias.as_deref() == Some(a) {
|
||||
return Ok(referred_table);
|
||||
}
|
||||
if referred_table.table.name.name == a {
|
||||
return Ok(referred_table);
|
||||
}
|
||||
}
|
||||
return Err(precondition(&format!("No table found for '{a}'")));
|
||||
};
|
||||
|
||||
let mut mapping: Vec<ColumnMapping> = vec![];
|
||||
let mut mapping: Vec<ViewColumn> = vec![];
|
||||
for col in result_columns {
|
||||
use sqlite3_parser::ast::Expr;
|
||||
|
||||
match col {
|
||||
ResultColumn::Star => {
|
||||
for (_alias, table) in &referenced_table_by_alias {
|
||||
for c in &table.columns {
|
||||
mapping.push(ColumnMapping {
|
||||
for referred_table in &referenced_table_by_alias {
|
||||
for c in &referred_table.table.columns {
|
||||
mapping.push(ViewColumn {
|
||||
column: c.clone(),
|
||||
referred_column: Some(ReferredColumn {
|
||||
table_name: table.name.clone(),
|
||||
column_name: c.name.clone(),
|
||||
}),
|
||||
parent_name: get_parent_name(referred_table),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
ResultColumn::TableStar(name) => {
|
||||
let name = unquote_name(name);
|
||||
let table = find_table_by_alias(&name)?;
|
||||
let referred_table = find_table_by_alias(&name)?;
|
||||
|
||||
for c in &table.columns {
|
||||
mapping.push(ColumnMapping {
|
||||
for c in &referred_table.table.columns {
|
||||
mapping.push(ViewColumn {
|
||||
column: c.clone(),
|
||||
referred_column: Some(ReferredColumn {
|
||||
table_name: table.name.clone(),
|
||||
column_name: c.name.clone(),
|
||||
}),
|
||||
parent_name: get_parent_name(referred_table),
|
||||
});
|
||||
}
|
||||
}
|
||||
ResultColumn::Expr(expr, alias) => match expr {
|
||||
Expr::Id(id) => {
|
||||
let col_name = unquote_id(id.clone());
|
||||
let (table, column) = find_column_by_unqualified_name(&col_name)?;
|
||||
let (referred_table, column_index) = find_column_by_unqualified_name(&col_name)?;
|
||||
let column = &referred_table.table.columns[column_index];
|
||||
|
||||
let name = alias
|
||||
.and_then(|alias| {
|
||||
if let sqlite3_parser::ast::As::As(name) = alias {
|
||||
return Some(unquote_name(name));
|
||||
}
|
||||
None
|
||||
})
|
||||
.unwrap_or_else(|| column.name.clone());
|
||||
|
||||
mapping.push(ColumnMapping {
|
||||
mapping.push(ViewColumn {
|
||||
column: Column {
|
||||
name,
|
||||
name: to_alias(alias).unwrap_or_else(|| column.name.clone()),
|
||||
data_type: column.data_type,
|
||||
options: column.options.clone(),
|
||||
},
|
||||
referred_column: Some(ReferredColumn {
|
||||
table_name: table.name.clone(),
|
||||
column_name: column.name.clone(),
|
||||
}),
|
||||
parent_name: get_parent_name(referred_table),
|
||||
});
|
||||
}
|
||||
Expr::Qualified(qualifier, name) => {
|
||||
let table = find_table_by_alias(&unquote_name(qualifier))?;
|
||||
let qualifier = unquote_name(qualifier);
|
||||
let referred_table = find_table_by_alias(&qualifier)?;
|
||||
|
||||
let col_name = unquote_name(name);
|
||||
let Some(column) = table.columns.iter().find(|c| c.name == col_name) else {
|
||||
let Some(column_index) = referred_table
|
||||
.table
|
||||
.columns
|
||||
.iter()
|
||||
.position(|c| c.name == col_name)
|
||||
else {
|
||||
return Err(precondition(&format!("Missing col: {col_name}")));
|
||||
};
|
||||
|
||||
let name = alias
|
||||
.and_then(|alias| {
|
||||
if let sqlite3_parser::ast::As::As(name) = alias {
|
||||
return Some(unquote_name(name));
|
||||
}
|
||||
None
|
||||
})
|
||||
.unwrap_or_else(|| column.name.clone());
|
||||
let column = &referred_table.table.columns[column_index];
|
||||
|
||||
mapping.push(ColumnMapping {
|
||||
mapping.push(ViewColumn {
|
||||
column: Column {
|
||||
name,
|
||||
name: to_alias(alias).unwrap_or_else(|| column.name.clone()),
|
||||
data_type: column.data_type,
|
||||
options: column.options.clone(),
|
||||
},
|
||||
referred_column: Some(ReferredColumn {
|
||||
table_name: table.name.clone(),
|
||||
column_name: column.name.clone(),
|
||||
}),
|
||||
parent_name: get_parent_name(referred_table),
|
||||
});
|
||||
}
|
||||
Expr::Cast { expr: _, type_name } => {
|
||||
@@ -1070,22 +1071,17 @@ fn extract_column_mapping(
|
||||
));
|
||||
};
|
||||
|
||||
let Some(name) = alias.and_then(|alias| {
|
||||
if let sqlite3_parser::ast::As::As(name) = alias {
|
||||
return Some(unquote_name(name));
|
||||
}
|
||||
None
|
||||
}) else {
|
||||
let Some(name) = to_alias(alias) else {
|
||||
return Err(SchemaError::Precondition("Missing alias in cast".into()));
|
||||
};
|
||||
|
||||
mapping.push(ColumnMapping {
|
||||
mapping.push(ViewColumn {
|
||||
column: Column {
|
||||
name,
|
||||
data_type,
|
||||
options: vec![ColumnOption::Null],
|
||||
options: vec![],
|
||||
},
|
||||
referred_column: None,
|
||||
parent_name: None,
|
||||
});
|
||||
}
|
||||
x => {
|
||||
@@ -1096,27 +1092,61 @@ fn extract_column_mapping(
|
||||
};
|
||||
}
|
||||
|
||||
return Ok(mapping);
|
||||
}
|
||||
return match group_by_key_candidate {
|
||||
None => Ok(ColumnMapping {
|
||||
columns: mapping,
|
||||
group_by: None,
|
||||
joins,
|
||||
}),
|
||||
Some(group_by) => {
|
||||
// NOTE: GROUP BY can technically reference any column, but only columns also exposed by the
|
||||
// VIEW are useful to us as keys. In other words, there's no point of us to search for this
|
||||
// column through all referenced tables.
|
||||
let group_by = match group_by.qualifier {
|
||||
Some(ref qualifier) => {
|
||||
// If the "GROUP BY" uses a qualifier, it must reference a table or subselect, e.g.:
|
||||
// CREATE VIEW v AS SELECT a.id FROM a RIGHT JOIN b ON a.id = b.id GROUP BY a.id;
|
||||
mapping.iter().position(|v: &ViewColumn| {
|
||||
v.parent_name.as_ref() == Some(qualifier) && v.column.name == group_by.name
|
||||
})
|
||||
}
|
||||
None => mapping
|
||||
.iter()
|
||||
.position(|v: &ViewColumn| v.column.name == group_by.name),
|
||||
};
|
||||
|
||||
#[inline]
|
||||
fn precondition(m: &str) -> SchemaError {
|
||||
return SchemaError::Precondition(m.into());
|
||||
}
|
||||
|
||||
fn extract_result_columns(
|
||||
select: &sqlite3_parser::ast::Select,
|
||||
) -> Result<Vec<ResultColumn>, SchemaError> {
|
||||
let sqlite3_parser::ast::OneSelect::Select { columns, .. } = &select.body.select else {
|
||||
return Err(precondition("VALUES not supported"));
|
||||
Ok(ColumnMapping {
|
||||
columns: mapping,
|
||||
group_by: Some(group_by.ok_or_else(|| precondition("GROUP BY column not exposed"))?),
|
||||
joins,
|
||||
})
|
||||
}
|
||||
};
|
||||
return Ok(columns.clone());
|
||||
}
|
||||
|
||||
fn extract_referenced_tables_by_alias(
|
||||
fn get_parent_name(referred_table: &ReferredTable) -> Option<String> {
|
||||
return Some(
|
||||
referred_table
|
||||
.alias
|
||||
.as_ref()
|
||||
.unwrap_or(&referred_table.table.name.name)
|
||||
.to_owned(),
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct ReferredTable {
|
||||
/// Optional top-most alias (nested aliases, e.g. in a sub-query, are not accessible).
|
||||
pub(crate) alias: Option<String>,
|
||||
|
||||
/// The referenced table.
|
||||
pub(crate) table: Table,
|
||||
}
|
||||
|
||||
fn extract_joins_and_referenced_tables_by_alias(
|
||||
select: sqlite3_parser::ast::Select,
|
||||
tables: &[Table],
|
||||
) -> Result<Vec<(Option<String>, &Table)>, SchemaError> {
|
||||
) -> Result<(Vec<u8>, Vec<ReferredTable>), SchemaError> {
|
||||
let body = select.body;
|
||||
if body.compounds.is_some() {
|
||||
return Err(precondition("Compound queries not supported"));
|
||||
@@ -1126,7 +1156,7 @@ fn extract_referenced_tables_by_alias(
|
||||
columns: _,
|
||||
distinctness,
|
||||
from,
|
||||
group_by,
|
||||
group_by: _,
|
||||
having: _,
|
||||
where_clause: _,
|
||||
window_clause,
|
||||
@@ -1138,10 +1168,6 @@ fn extract_referenced_tables_by_alias(
|
||||
)));
|
||||
};
|
||||
|
||||
if group_by.is_some() {
|
||||
return Err(precondition("GROUP BY clause not (yet) supported"));
|
||||
}
|
||||
|
||||
if distinctness.is_some() {
|
||||
return Err(precondition("DISTINCT clause not (yet) supported"));
|
||||
}
|
||||
@@ -1168,7 +1194,7 @@ fn extract_referenced_tables_by_alias(
|
||||
// Make sure all referenced tables are strict.
|
||||
if !table.strict {
|
||||
return Err(precondition(&format!(
|
||||
"Referenced table: {:?} must be STRICT",
|
||||
"Table {:?} must be STRICT to derive type",
|
||||
table.name
|
||||
)));
|
||||
}
|
||||
@@ -1176,42 +1202,76 @@ fn extract_referenced_tables_by_alias(
|
||||
return Ok(table);
|
||||
};
|
||||
|
||||
// Map from "alias" to table. Use IndexMap to preserve insertion order.
|
||||
let referenced_table_by_alias: Vec<(Option<String>, &Table)> = match nested_select.map(|s| *s) {
|
||||
let mut all_joins: Vec<u8> = joins
|
||||
.as_ref()
|
||||
.map(|joins| {
|
||||
joins
|
||||
.iter()
|
||||
.map(|join| extract_join_type(join.operator).bits())
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
// List of referenced tables in insertion order (left-to-right).
|
||||
let referenced_table_by_alias: Vec<ReferredTable> = match nested_select.map(|s| *s) {
|
||||
Some(SelectTable::Table(fqn, alias, _indexed)) => {
|
||||
let mut referenced_tables = vec![(to_alias(alias), find_table(&fqn.into())?)];
|
||||
// Table itself
|
||||
let mut referenced_tables = vec![ReferredTable {
|
||||
alias: to_alias(alias),
|
||||
table: find_table(&fqn.into())?.clone(),
|
||||
}];
|
||||
|
||||
// // Plus possible joins.
|
||||
for join in joins.unwrap_or_default() {
|
||||
match join.operator {
|
||||
JoinOperator::TypedJoin(Some(t)) if t.contains(JoinType::LEFT) => {}
|
||||
_ => {
|
||||
// Right now, we're picking the VIEW's primary key left to right. Other joins would
|
||||
// require more sophistication. Many joins will spoil PKs, e.g. by computing a
|
||||
// cross-product yielding a non-unique column.
|
||||
return Err(precondition(&format!(
|
||||
"Only LEFT JOINS supported yet, got: {:?}",
|
||||
join.operator
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// We don't currently allow joining sub-queries, etc.
|
||||
let SelectTable::Table(fqn, alias, _indexed) = join.table else {
|
||||
return Err(precondition("JOIN with TABLE expected"));
|
||||
match join.table {
|
||||
SelectTable::Table(fqn, alias, _indexed) => {
|
||||
referenced_tables.push(ReferredTable {
|
||||
alias: to_alias(alias),
|
||||
table: find_table(&fqn.into())?.clone(),
|
||||
});
|
||||
}
|
||||
SelectTable::Select(subselect, alias) => {
|
||||
let alias = to_alias(alias);
|
||||
|
||||
let (joins_in_subselect, referenced_tables_in_subselect) =
|
||||
extract_joins_and_referenced_tables_by_alias(*subselect, tables)?;
|
||||
|
||||
all_joins.extend(joins_in_subselect);
|
||||
referenced_tables.extend(referenced_tables_in_subselect.into_iter().map(
|
||||
|ReferredTable { table, .. }| -> ReferredTable {
|
||||
return ReferredTable {
|
||||
alias: alias.clone(),
|
||||
table,
|
||||
};
|
||||
},
|
||||
));
|
||||
}
|
||||
_ => {
|
||||
return Err(precondition("JOIN with TABLE expected"));
|
||||
}
|
||||
};
|
||||
referenced_tables.push((to_alias(alias), find_table(&fqn.into())?));
|
||||
}
|
||||
|
||||
referenced_tables
|
||||
}
|
||||
Some(SelectTable::Select(select, alias)) => {
|
||||
Some(SelectTable::Select(nested_select, alias)) => {
|
||||
// Simply recurse tu unnest the select.
|
||||
let alias = to_alias(alias);
|
||||
return Ok(
|
||||
extract_referenced_tables_by_alias(*select, tables)?
|
||||
let (joins_in_nested_select, referenced_tables_in_nested_select) =
|
||||
extract_joins_and_referenced_tables_by_alias(*nested_select, tables)?;
|
||||
|
||||
return Ok((
|
||||
joins_in_nested_select,
|
||||
referenced_tables_in_nested_select
|
||||
.into_iter()
|
||||
.map(|(_, table)| (alias.clone(), table))
|
||||
.map(|referred_table| ReferredTable {
|
||||
// NOTE: Reset the alias.
|
||||
alias: alias.clone(),
|
||||
table: referred_table.table,
|
||||
})
|
||||
.collect(),
|
||||
);
|
||||
));
|
||||
}
|
||||
Some(x) => {
|
||||
return Err(precondition(&format!(
|
||||
@@ -1223,7 +1283,68 @@ fn extract_referenced_tables_by_alias(
|
||||
}
|
||||
};
|
||||
|
||||
return Ok(referenced_table_by_alias);
|
||||
return Ok((all_joins, referenced_table_by_alias));
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn precondition(m: &str) -> SchemaError {
|
||||
return SchemaError::Precondition(m.into());
|
||||
}
|
||||
|
||||
fn extract_join_type(op: JoinOperator) -> JoinType {
|
||||
return match op {
|
||||
JoinOperator::TypedJoin(Some(t)) => t,
|
||||
JoinOperator::Comma | JoinOperator::TypedJoin(None) => JoinType::INNER,
|
||||
};
|
||||
}
|
||||
|
||||
fn extract_result_columns(
|
||||
select: &sqlite3_parser::ast::Select,
|
||||
) -> Result<Vec<ResultColumn>, SchemaError> {
|
||||
let sqlite3_parser::ast::OneSelect::Select { columns, .. } = &select.body.select else {
|
||||
return Err(precondition("VALUES not supported"));
|
||||
};
|
||||
return Ok(columns.clone());
|
||||
}
|
||||
|
||||
struct GroupBy {
|
||||
qualifier: Option<String>,
|
||||
name: String,
|
||||
}
|
||||
|
||||
fn extract_group_by_key_candidate(
|
||||
select: &sqlite3_parser::ast::Select,
|
||||
) -> Result<Option<GroupBy>, SchemaError> {
|
||||
let sqlite3_parser::ast::OneSelect::Select { group_by, .. } = &select.body.select else {
|
||||
return Err(precondition("VALUES not supported"));
|
||||
};
|
||||
|
||||
let Some(group_by) = group_by else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
return match group_by.len() {
|
||||
1 => match group_by[0].clone() {
|
||||
Expr::Id(id) => Ok(Some(GroupBy {
|
||||
qualifier: None,
|
||||
name: unquote_id(id),
|
||||
})),
|
||||
Expr::Name(name) => Ok(Some(GroupBy {
|
||||
qualifier: None,
|
||||
name: unquote_name(name),
|
||||
})),
|
||||
Expr::Qualified(qualifier, name) => Ok(Some(GroupBy {
|
||||
qualifier: Some(unquote_name(qualifier)),
|
||||
name: unquote_name(name),
|
||||
})),
|
||||
expr => Err(precondition(&format!(
|
||||
"For RecordAPIs GROUP BY expressions must reference an exposed VIEW column, got {expr:?}"
|
||||
))),
|
||||
},
|
||||
n => Err(precondition(&format!(
|
||||
"For RecordAPIs GROUP BY expressions must reference a single VIEW column, got {n}"
|
||||
))),
|
||||
};
|
||||
}
|
||||
|
||||
fn build_foreign_key(
|
||||
@@ -1541,6 +1662,14 @@ mod tests {
|
||||
assert_eq!(index, index1, "Parsed: {sql1}");
|
||||
}
|
||||
|
||||
fn parse_into_select(sql: &str) -> sqlite3_parser::ast::Select {
|
||||
let sqlite3_parser::ast::Stmt::Select(select) = parse_into_statement(sql).unwrap().unwrap()
|
||||
else {
|
||||
panic!("Not a select");
|
||||
};
|
||||
return *select;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_view_column_extraction() {
|
||||
let tables = vec![Table {
|
||||
@@ -1563,56 +1692,45 @@ mod tests {
|
||||
|
||||
{
|
||||
// No alias
|
||||
let sql = "SELECT column FROM table_name";
|
||||
let sqlite3_parser::ast::Stmt::Select(select) = parse_into_statement(sql).unwrap().unwrap()
|
||||
else {
|
||||
panic!("Not a select");
|
||||
};
|
||||
let _mapping = extract_column_mapping(*select, &tables).unwrap();
|
||||
let select = parse_into_select("SELECT column FROM table_name");
|
||||
let _mapping = extract_column_mapping(select, &tables).unwrap();
|
||||
}
|
||||
|
||||
{
|
||||
// With alias
|
||||
let sql = "SELECT alias.column FROM table_name AS alias";
|
||||
let sqlite3_parser::ast::Stmt::Select(select) = parse_into_statement(sql).unwrap().unwrap()
|
||||
else {
|
||||
panic!("Not a select");
|
||||
};
|
||||
let _mapping = extract_column_mapping(*select, &tables).unwrap();
|
||||
let select = parse_into_select("SELECT alias.column FROM table_name AS alias");
|
||||
let _mapping = extract_column_mapping(select, &tables).unwrap();
|
||||
}
|
||||
|
||||
{
|
||||
// With "elided" alias
|
||||
let sql = "SELECT alias.column FROM table_name alias";
|
||||
let sqlite3_parser::ast::Stmt::Select(select) = parse_into_statement(sql).unwrap().unwrap()
|
||||
else {
|
||||
panic!("Not a select");
|
||||
};
|
||||
let _mapping = extract_column_mapping(*select, &tables).unwrap();
|
||||
let select = parse_into_select("SELECT alias.column FROM table_name alias");
|
||||
let _mapping = extract_column_mapping(select, &tables).unwrap();
|
||||
}
|
||||
|
||||
{
|
||||
// JOIN on a SELECT.
|
||||
let sql = "SELECT x.column, y.column FROM table_name AS x LEFT JOIN (SELECT * FROM table_name) AS y ON x.column = y.column";
|
||||
let sqlite3_parser::ast::Stmt::Select(select) = parse_into_statement(sql).unwrap().unwrap()
|
||||
else {
|
||||
panic!("Not a select");
|
||||
};
|
||||
let err = extract_column_mapping(*select, &tables)
|
||||
.err()
|
||||
.unwrap()
|
||||
.to_string();
|
||||
assert!(err.contains("JOIN with TABLE expected"), "{err}");
|
||||
let select = parse_into_select(
|
||||
"SELECT x.column, y.column AS foo FROM table_name AS x LEFT JOIN (SELECT * FROM table_name) AS y ON x.column = y.column",
|
||||
);
|
||||
let column_mapping = extract_column_mapping(select, &tables).unwrap();
|
||||
let columns = &column_mapping.columns;
|
||||
assert_eq!(columns.len(), 2, "{columns:?}");
|
||||
|
||||
let first = &columns[0];
|
||||
assert_eq!(first.column.data_type, ColumnDataType::Text);
|
||||
assert_eq!(first.column.name, "column");
|
||||
|
||||
let second = &columns[1];
|
||||
assert_eq!(second.column.data_type, ColumnDataType::Text);
|
||||
assert_eq!(second.column.name, "foo");
|
||||
}
|
||||
|
||||
{
|
||||
// Compound SELECT.
|
||||
let sql = "SELECT column FROM table_name UNION SELECT column FROM table_name";
|
||||
let sqlite3_parser::ast::Stmt::Select(select) = parse_into_statement(sql).unwrap().unwrap()
|
||||
else {
|
||||
panic!("Not a select");
|
||||
};
|
||||
let err = extract_column_mapping(*select, &tables)
|
||||
let select =
|
||||
parse_into_select("SELECT column FROM table_name UNION SELECT column FROM table_name");
|
||||
let err = extract_column_mapping(select, &tables)
|
||||
.err()
|
||||
.unwrap()
|
||||
.to_string();
|
||||
@@ -1620,6 +1738,38 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_create_view_select(sql: &str) -> sqlite3_parser::ast::Select {
|
||||
let sqlite3_parser::ast::Stmt::CreateView { select, .. } =
|
||||
parse_into_statement(sql).unwrap().unwrap()
|
||||
else {
|
||||
panic!("Not a CREATE VIEW: {sql}");
|
||||
};
|
||||
return *select;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_creare_view_colum_mapping() {
|
||||
let table_a = parse_create_table(
|
||||
"CREATE TABLE a (id INTEGER PRIMARY KEY, data TEXT NOT NULL DEFAULT '') STRICT",
|
||||
);
|
||||
|
||||
let tables = [table_a];
|
||||
|
||||
let select =
|
||||
parse_create_view_select("CREATE VIEW view0 AS SELECT x.id FROM a AS x GROUP BY x.id");
|
||||
assert_eq!(
|
||||
Some(0),
|
||||
extract_column_mapping(select, &tables).unwrap().group_by
|
||||
);
|
||||
|
||||
let select =
|
||||
parse_create_view_select("CREATE VIEW view0 AS SELECT x.id FROM a AS x GROUP BY id");
|
||||
assert_eq!(
|
||||
Some(0),
|
||||
extract_column_mapping(select, &tables).unwrap().group_by
|
||||
)
|
||||
}
|
||||
|
||||
fn parse_create_table(create_table_sql: &str) -> Table {
|
||||
let create_table_statement = parse_into_statement(create_table_sql).unwrap().unwrap();
|
||||
return create_table_statement.try_into().unwrap();
|
||||
@@ -1627,12 +1777,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_view_column_extraction_join() {
|
||||
let sql = "SELECT user, *, a.*, p.user AS foo FROM foo.articles AS a LEFT JOIN bar.profiles AS p ON p.user = a.author";
|
||||
let sqlite3_parser::ast::Stmt::Select(select) = parse_into_statement(sql).unwrap().unwrap()
|
||||
else {
|
||||
panic!("Not a select");
|
||||
};
|
||||
|
||||
let profiles_table = parse_create_table(
|
||||
r#"
|
||||
CREATE TABLE bar.profiles (
|
||||
@@ -1654,20 +1798,14 @@ mod tests {
|
||||
|
||||
let tables = [profiles_table, articles_table];
|
||||
|
||||
let mapping = extract_column_mapping(*select, &tables).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
mapping
|
||||
.iter()
|
||||
.map(|m| m.referred_column.as_ref().unwrap().column_name.as_str())
|
||||
.collect::<Vec<_>>(),
|
||||
[
|
||||
"user", "id", "author", "body", "user", "username", "id", "author", "body", "user"
|
||||
]
|
||||
let select = parse_into_select(
|
||||
"SELECT user, *, a.*, p.user AS foo FROM foo.articles AS a LEFT JOIN bar.profiles AS p ON p.user = a.author",
|
||||
);
|
||||
let mapping = extract_column_mapping(select, &tables).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
mapping
|
||||
.columns
|
||||
.iter()
|
||||
.map(|m| m.column.name.as_str())
|
||||
.collect::<Vec<_>>(),
|
||||
|
||||
Reference in New Issue
Block a user