Merge remote-tracking branch 'origin/main' into aaron/github-workflows-icu4c

This commit is contained in:
Aaron Son
2025-09-10 09:20:44 -07:00
51 changed files with 2343 additions and 1389 deletions
+23
View File
@@ -308,6 +308,29 @@ func CreateLogArgParser(isTableFunction bool) *argparser.ArgParser {
return ap
}
func CreateDiffArgParser(isTableFunction bool) *argparser.ArgParser {
ap := argparser.NewArgParserWithVariableArgs("diff")
ap.SupportsFlag(SkinnyFlag, "sk", "Shows only primary key columns and any columns with data changes.")
ap.SupportsStringList(IncludeCols, "ic", "columns", "A list of columns to include in the diff.")
if !isTableFunction { // TODO: support for table function
ap.SupportsFlag(DataFlag, "d", "Show only the data changes, do not show the schema changes (Both shown by default).")
ap.SupportsFlag(SchemaFlag, "s", "Show only the schema changes, do not show the data changes (Both shown by default).")
ap.SupportsFlag(StatFlag, "", "Show stats of data changes")
ap.SupportsFlag(SummaryFlag, "", "Show summary of data and schema changes")
ap.SupportsString(FormatFlag, "r", "result output format", "How to format diff output. Valid values are tabular, sql, json. Defaults to tabular.")
ap.SupportsString(WhereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.")
ap.SupportsInt(LimitParam, "", "record_count", "limits to the first N diffs.")
ap.SupportsFlag(StagedFlag, "", "Show only the staged data changes.")
ap.SupportsFlag(CachedFlag, "c", "Synonym for --staged")
ap.SupportsFlag(MergeBase, "", "Uses merge base of the first commit and second commit (or HEAD if not supplied) as the first commit")
ap.SupportsString(DiffMode, "", "diff mode", "Determines how to display modified rows with tabular output. Valid values are row, line, in-place, context. Defaults to context.")
ap.SupportsFlag(ReverseFlag, "R", "Reverses the direction of the diff.")
ap.SupportsFlag(NameOnlyFlag, "", "Only shows table names.")
ap.SupportsFlag(SystemFlag, "", "Show system tables in addition to user tables")
}
return ap
}
func CreateGCArgParser() *argparser.ArgParser {
ap := argparser.NewArgParserWithMaxArgs("gc", 0)
ap.SupportsFlag(ShallowFlag, "s", "perform a fast, but incomplete garbage collection pass")
+16
View File
@@ -88,3 +88,19 @@ const (
UpperCaseAllFlag = "ALL"
UserFlag = "user"
)
// Flags used by `dolt diff` command and `dolt_diff()` table function.
const (
SkinnyFlag = "skinny"
IncludeCols = "include-cols"
DataFlag = "data"
SchemaFlag = "schema"
NameOnlyFlag = "name-only"
SummaryFlag = "summary"
WhereParam = "where"
LimitParam = "limit"
MergeBase = "merge-base"
DiffMode = "diff-mode"
ReverseFlag = "reverse"
FormatFlag = "result-format"
)
+30 -51
View File
@@ -60,18 +60,6 @@ const (
TabularDiffOutput diffOutput = 1
SQLDiffOutput diffOutput = 2
JsonDiffOutput diffOutput = 3
DataFlag = "data"
SchemaFlag = "schema"
NameOnlyFlag = "name-only"
StatFlag = "stat"
SummaryFlag = "summary"
whereParam = "where"
limitParam = "limit"
SkinnyFlag = "skinny"
MergeBase = "merge-base"
DiffMode = "diff-mode"
ReverseFlag = "reverse"
)
var diffDocs = cli.CommandDocumentationContent{
@@ -107,12 +95,13 @@ The {{.EmphasisLeft}}--diff-mode{{.EmphasisRight}} argument controls how modifie
}
type diffDisplaySettings struct {
diffParts diffPart
diffOutput diffOutput
diffMode diff.Mode
limit int
where string
skinny bool
diffParts diffPart
diffOutput diffOutput
diffMode diff.Mode
limit int
where string
skinny bool
includeCols []string
}
type diffDatasets struct {
@@ -164,23 +153,7 @@ func (cmd DiffCmd) Docs() *cli.CommandDocumentation {
}
func (cmd DiffCmd) ArgParser() *argparser.ArgParser {
ap := argparser.NewArgParserWithVariableArgs(cmd.Name())
ap.SupportsFlag(DataFlag, "d", "Show only the data changes, do not show the schema changes (Both shown by default).")
ap.SupportsFlag(SchemaFlag, "s", "Show only the schema changes, do not show the data changes (Both shown by default).")
ap.SupportsFlag(StatFlag, "", "Show stats of data changes")
ap.SupportsFlag(SummaryFlag, "", "Show summary of data and schema changes")
ap.SupportsString(FormatFlag, "r", "result output format", "How to format diff output. Valid values are tabular, sql, json. Defaults to tabular.")
ap.SupportsString(whereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.")
ap.SupportsInt(limitParam, "", "record_count", "limits to the first N diffs.")
ap.SupportsFlag(cli.StagedFlag, "", "Show only the staged data changes.")
ap.SupportsFlag(cli.CachedFlag, "c", "Synonym for --staged")
ap.SupportsFlag(SkinnyFlag, "sk", "Shows only primary key columns and any columns with data changes.")
ap.SupportsFlag(MergeBase, "", "Uses merge base of the first commit and second commit (or HEAD if not supplied) as the first commit")
ap.SupportsString(DiffMode, "", "diff mode", "Determines how to display modified rows with tabular output. Valid values are row, line, in-place, context. Defaults to context.")
ap.SupportsFlag(ReverseFlag, "R", "Reverses the direction of the diff.")
ap.SupportsFlag(NameOnlyFlag, "", "Only shows table names.")
ap.SupportsFlag(cli.SystemFlag, "", "Show system tables in addition to user tables")
return ap
return cli.CreateDiffArgParser(false)
}
func (cmd DiffCmd) RequiresRepo() bool {
@@ -228,14 +201,14 @@ func (cmd DiffCmd) Exec(ctx context.Context, commandStr string, args []string, _
}
func (cmd DiffCmd) validateArgs(apr *argparser.ArgParseResults) errhand.VerboseError {
if apr.Contains(StatFlag) || apr.Contains(SummaryFlag) {
if apr.Contains(SchemaFlag) || apr.Contains(DataFlag) {
if apr.Contains(cli.StatFlag) || apr.Contains(cli.SummaryFlag) {
if apr.Contains(cli.SchemaFlag) || apr.Contains(cli.DataFlag) {
return errhand.BuildDError("invalid Arguments: --stat and --summary cannot be combined with --schema or --data").Build()
}
}
if apr.Contains(NameOnlyFlag) {
if apr.Contains(SchemaFlag) || apr.Contains(DataFlag) || apr.Contains(StatFlag) || apr.Contains(SummaryFlag) {
if apr.Contains(cli.NameOnlyFlag) {
if apr.Contains(cli.SchemaFlag) || apr.Contains(cli.DataFlag) || apr.Contains(cli.StatFlag) || apr.Contains(cli.SummaryFlag) {
return errhand.BuildDError("invalid Arguments: --name-only cannot be combined with --schema, --data, --stat, or --summary").Build()
}
}
@@ -254,25 +227,29 @@ func parseDiffDisplaySettings(apr *argparser.ArgParseResults) *diffDisplaySettin
displaySettings := &diffDisplaySettings{}
displaySettings.diffParts = SchemaAndDataDiff
if apr.Contains(DataFlag) && !apr.Contains(SchemaFlag) {
if apr.Contains(cli.DataFlag) && !apr.Contains(cli.SchemaFlag) {
displaySettings.diffParts = DataOnlyDiff
} else if apr.Contains(SchemaFlag) && !apr.Contains(DataFlag) {
} else if apr.Contains(cli.SchemaFlag) && !apr.Contains(cli.DataFlag) {
displaySettings.diffParts = SchemaOnlyDiff
} else if apr.Contains(StatFlag) {
} else if apr.Contains(cli.StatFlag) {
displaySettings.diffParts = Stat
} else if apr.Contains(SummaryFlag) {
} else if apr.Contains(cli.SummaryFlag) {
displaySettings.diffParts = Summary
} else if apr.Contains(NameOnlyFlag) {
} else if apr.Contains(cli.NameOnlyFlag) {
displaySettings.diffParts = NameOnlyDiff
}
displaySettings.skinny = apr.Contains(SkinnyFlag)
displaySettings.skinny = apr.Contains(cli.SkinnyFlag)
if cols, ok := apr.GetValueList(cli.IncludeCols); ok {
displaySettings.includeCols = cols
}
f := apr.GetValueOrDefault(FormatFlag, "tabular")
switch strings.ToLower(f) {
case "tabular":
displaySettings.diffOutput = TabularDiffOutput
switch strings.ToLower(apr.GetValueOrDefault(DiffMode, "context")) {
switch strings.ToLower(apr.GetValueOrDefault(cli.DiffMode, "context")) {
case "row":
displaySettings.diffMode = diff.ModeRow
case "line":
@@ -288,8 +265,8 @@ func parseDiffDisplaySettings(apr *argparser.ArgParseResults) *diffDisplaySettin
displaySettings.diffOutput = JsonDiffOutput
}
displaySettings.limit, _ = apr.GetInt(limitParam)
displaySettings.where = apr.GetValueOrDefault(whereParam, "")
displaySettings.limit, _ = apr.GetInt(cli.LimitParam)
displaySettings.where = apr.GetValueOrDefault(cli.WhereParam, "")
return displaySettings
}
@@ -301,12 +278,12 @@ func parseDiffArgs(queryist cli.Queryist, sqlCtx *sql.Context, apr *argparser.Ar
staged := apr.Contains(cli.StagedFlag) || apr.Contains(cli.CachedFlag)
tableNames, err := dArgs.applyDiffRoots(queryist, sqlCtx, apr.Args, staged, apr.Contains(MergeBase))
tableNames, err := dArgs.applyDiffRoots(queryist, sqlCtx, apr.Args, staged, apr.Contains(cli.MergeBase))
if err != nil {
return nil, err
}
if apr.Contains(ReverseFlag) {
if apr.Contains(cli.ReverseFlag) {
dArgs.diffDatasets = &diffDatasets{
fromRef: dArgs.toRef,
toRef: dArgs.fromRef,
@@ -1556,7 +1533,9 @@ func diffRows(
if err != nil {
return errhand.BuildDError("Error running diff query:\n%s", interpolatedQuery).AddCause(err).Build()
}
for _, col := range dArgs.includeCols {
modifiedColNames[col] = true // ensure included columns are always present
}
// instantiate a new schema that only contains the columns with changes
var filteredUnionSch sql.Schema
for _, s := range unionSch {
+11 -11
View File
@@ -83,17 +83,17 @@ func (cmd ShowCmd) ArgParser() *argparser.ArgParser {
ap.SupportsFlag(cli.NoPrettyFlag, "", "Show the object without making it pretty.")
// Flags inherited from Diff
ap.SupportsFlag(DataFlag, "d", "Show only the data changes, do not show the schema changes (Both shown by default).")
ap.SupportsFlag(SchemaFlag, "s", "Show only the schema changes, do not show the data changes (Both shown by default).")
ap.SupportsFlag(StatFlag, "", "Show stats of data changes")
ap.SupportsFlag(SummaryFlag, "", "Show summary of data and schema changes")
ap.SupportsFlag(cli.DataFlag, "d", "Show only the data changes, do not show the schema changes (Both shown by default).")
ap.SupportsFlag(cli.SchemaFlag, "s", "Show only the schema changes, do not show the data changes (Both shown by default).")
ap.SupportsFlag(cli.StatFlag, "", "Show stats of data changes")
ap.SupportsFlag(cli.SummaryFlag, "", "Show summary of data and schema changes")
ap.SupportsString(FormatFlag, "r", "result output format", "How to format diff output. Valid values are tabular, sql, json. Defaults to tabular.")
ap.SupportsString(whereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.")
ap.SupportsInt(limitParam, "", "record_count", "limits to the first N diffs.")
ap.SupportsString(cli.WhereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.")
ap.SupportsInt(cli.LimitParam, "", "record_count", "limits to the first N diffs.")
ap.SupportsFlag(cli.CachedFlag, "c", "Show only the staged data changes.")
ap.SupportsFlag(SkinnyFlag, "sk", "Shows only primary key columns and any columns with data changes.")
ap.SupportsFlag(MergeBase, "", "Uses merge base of the first commit and second commit (or HEAD if not supplied) as the first commit")
ap.SupportsString(DiffMode, "", "diff mode", "Determines how to display modified rows with tabular output. Valid values are row, line, in-place, context. Defaults to context.")
ap.SupportsFlag(cli.SkinnyFlag, "sk", "Shows only primary key columns and any columns with data changes.")
ap.SupportsFlag(cli.MergeBase, "", "Uses merge base of the first commit and second commit (or HEAD if not supplied) as the first commit")
ap.SupportsString(cli.DiffMode, "", "diff mode", "Determines how to display modified rows with tabular output. Valid values are row, line, in-place, context. Defaults to context.")
return ap
}
@@ -275,8 +275,8 @@ func getValueFromRefSpec(ctx context.Context, dEnv *env.DoltEnv, specRef string)
}
func (cmd ShowCmd) validateArgs(apr *argparser.ArgParseResults) errhand.VerboseError {
if apr.Contains(StatFlag) || apr.Contains(SummaryFlag) {
if apr.Contains(SchemaFlag) || apr.Contains(DataFlag) {
if apr.Contains(cli.StatFlag) || apr.Contains(cli.SummaryFlag) {
if apr.Contains(cli.SchemaFlag) || apr.Contains(cli.DataFlag) {
return errhand.BuildDError("invalid Arguments: --stat and --summary cannot be combined with --schema or --data").Build()
}
}
+1 -1
View File
@@ -203,7 +203,7 @@ func (cmd SqlServerCmd) ArgParserWithName(name string) *argparser.ArgParser {
ap.SupportsUint(mcpPortFlag, "", "port", "If provided, runs a Dolt MCP HTTP server on this port alongside the sql-server.")
// MCP SQL credentials (user required when MCP enabled; password optional)
ap.SupportsString(mcpUserFlag, "", "user", "SQL user for MCP to connect as (required when --mcp-port is set).")
ap.SupportsString(mcpPasswordFlag, "", "password", "Optional SQL password for MCP to connect with (requires --mcp-user). Defaults to env DOLT_ROOT_PASSWORD if unset.")
ap.SupportsString(mcpPasswordFlag, "", "password", "Optional SQL password for MCP to connect with (requires --mcp-user).")
ap.SupportsString(mcpDatabaseFlag, "", "database", "Optional SQL database name MCP should connect to (requires --mcp-port and --mcp-user).")
return ap
}
+1 -1
View File
@@ -15,5 +15,5 @@
package doltversion
const (
Version = "1.58.8"
Version = "1.59.6"
)
+2 -2
View File
@@ -13,7 +13,7 @@ require (
github.com/dolthub/fslock v0.0.3
github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
github.com/dolthub/vitess v0.0.0-20250902185630-90811959cbd1
github.com/dolthub/vitess v0.0.0-20250902225707-0159e964d73d
github.com/dustin/go-humanize v1.0.1
github.com/fatih/color v1.13.0
github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568
@@ -61,7 +61,7 @@ require (
github.com/dolthub/dolt-mcp v0.2.1-0.20250827202412-9d0f6e658fba
github.com/dolthub/eventsapi_schema v0.0.0-20250725194025-a087efa1ee55
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2
github.com/dolthub/go-mysql-server v0.20.1-0.20250902204612-4e1a10a95d8c
github.com/dolthub/go-mysql-server v0.20.1-0.20250909231122-5fb8788af3dd
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63
github.com/edsrzf/mmap-go v1.2.0
github.com/esote/minmaxheap v1.0.0
+4 -4
View File
@@ -213,8 +213,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
github.com/dolthub/go-icu-regex v0.0.0-20250820171420-f2b78f56ce9f h1:oSA8CptGeCEdTdD9LFtv8x4juDfdaLKsx1eocyaj1bE=
github.com/dolthub/go-icu-regex v0.0.0-20250820171420-f2b78f56ce9f/go.mod h1:kpsRG+a196Y69zsAFL0RkQICII9a571lcaxhvQnmrdY=
github.com/dolthub/go-mysql-server v0.20.1-0.20250902204612-4e1a10a95d8c h1:AHPSDdj6UnnVkc2eT3p4ZZ/7lf7MhHAwYCYkokad+PY=
github.com/dolthub/go-mysql-server v0.20.1-0.20250902204612-4e1a10a95d8c/go.mod h1:/OxBMtXwiC67PVMBV+/9sNjyfsK4SOF1moNnvdotEZM=
github.com/dolthub/go-mysql-server v0.20.1-0.20250909231122-5fb8788af3dd h1:FdOA59vPc2ilQAnQlXNfFSeVkCTHPa4m8Gl2KmOErTs=
github.com/dolthub/go-mysql-server v0.20.1-0.20250909231122-5fb8788af3dd/go.mod h1:ymoHIRZoZKO1EH9iUGcq4E6XyIpUaMgZz3ZPeWa828w=
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI=
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q=
github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE=
@@ -223,8 +223,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY=
github.com/dolthub/vitess v0.0.0-20250902185630-90811959cbd1 h1:sCpjbwm7rIV5o9OQDWsqGSzNXGSxpeAk3kX5gjPQlQI=
github.com/dolthub/vitess v0.0.0-20250902185630-90811959cbd1/go.mod h1:tV3BrIVyDWVkkYy8dKt2o6hjJ89cHb5opY5FpCyhncQ=
github.com/dolthub/vitess v0.0.0-20250902225707-0159e964d73d h1:oTWJxjzRmuHKuICUunCUwNuonubkXwOqPa5hXX3dXBo=
github.com/dolthub/vitess v0.0.0-20250902225707-0159e964d73d/go.mod h1:tV3BrIVyDWVkkYy8dKt2o6hjJ89cHb5opY5FpCyhncQ=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/edsrzf/mmap-go v1.2.0 h1:hXLYlkbaPzt1SaQk+anYwKSRNhufIDCchSPkUD6dD84=
+510
View File
@@ -0,0 +1,510 @@
# AGENT.md - Dolt Database Operations Guide
This file provides guidance for AI agents working with Dolt databases to maximize productivity and follow best practices.
## Quick Start
Dolt is "Git for Data" - a SQL database with version control capabilities. All Git commands have Dolt equivalents:
- `git add` → `dolt add`
- `git commit` → `dolt commit`
- `git branch` → `dolt branch`
- `git merge` → `dolt merge`
- `git diff` → `dolt diff`
For help and documentation on commands, you can run `dolt --help` and `dolt <command> --help`.
## Essential Dolt CLI Commands
### Repository Operations
```bash
# Initialize new database
dolt init
# Clone existing database
dolt clone <remote-url>
# Show current status
dolt status
# View commit history
dolt log
```
### Branch Management
```bash
# List branches
dolt branch
# Create new branch
dolt branch <branch-name>
# Switch branches
dolt checkout <branch-name>
# Create and switch to new branch
dolt checkout -b <branch-name>
```
### Data Operations
```bash
# Stage changes
dolt add <table-name>
dolt add . # stage all changes
# Commit changes
dolt commit -m "commit message"
# View differences
dolt diff
dolt diff <table-name>
dolt diff <branch1> <branch2>
# Merge branches
dolt merge <branch-name>
```
## Starting and Connecting to Dolt SQL Server
### Start SQL Server
```bash
# Start server on default port (3306)
dolt sql-server
# Start on specific port
dolt sql-server --port=3307
# Start with specific host
dolt sql-server --host=0.0.0.0 --port=3307
# Start in background
dolt sql-server --port=3307 &
```
### Connecting to SQL Server
```bash
# Connect with dolt sql command
dolt sql
# Connect with mysql client
mysql -h 127.0.0.1 -P 3306 -u root
# Connect with specific database
mysql -h 127.0.0.1 -P 3306 -u root -D <database-name>
```
## Dolt Testing with dolt_test System Table
### Unit Testing with dolt_test
The dolt_test system table provides a powerful way to create and run unit tests for your database. This is the preferred method for testing data integrity, business rules, and schema validation.
#### Creating Tests
Tests are created by inserting rows into the `dolt_tests` system table:
```sql
-- Create a simple test
INSERT INTO `dolt_tests` VALUES (
'test_user_count',
'validation',
'SELECT COUNT(*) as user_count FROM users;',
'row_count',
'>',
'0'
);
-- Create a test with expected result
INSERT INTO `dolt_tests` VALUES (
'test_valid_emails',
'validation',
'SELECT COUNT(*) FROM users WHERE email NOT LIKE "%@%";',
'row_count',
'==',
'0'
);
-- Create a schema validation test
INSERT INTO `dolt_tests` VALUES (
'test_users_schema',
'schema',
'DESCRIBE users;',
'row_count',
'>=',
'5'
);
```
#### Test Structure
Each test row contains:
- test_name: Unique identifier for the test
- test_group: Optional grouping for tests (e.g., 'validation', 'schema', 'integration')
- test_query: SQL query to execute
- assertion_type: Type of assertion ('expected_rows', 'expected_columns', 'expected_single_value')
- assertion_comparator: Comparison operator ('==', '>', '<', '>=', '<=', '!=')
- assertion_value: Expected value for comparison
#### Running Tests
```sql
-- Run all tests
SELECT * FROM dolt_test_run();
-- Run specific test
SELECT * FROM dolt_test_run('test_user_count');
-- Run tests with filtering
SELECT * FROM dolt_test_run() WHERE test_name LIKE 'test_user%' AND status != 'PASS';
```
#### Test Result Interpretation
The dolt_test_run() function returns:
- test_name: Name of the test
- status: PASS, FAIL, or ERROR
- actual_result: Actual query result
- expected_result: Expected result
- message: Additional details
#### Advanced Testing Examples
```sql
-- Test data integrity
INSERT INTO `dolt_tests` VALUES (
'test_no_orphaned_orders',
'integrity',
'SELECT COUNT(*) FROM orders o LEFT JOIN users u ON o.user_id = u.id WHERE u.id IS NULL;',
'row_count',
'==',
'0'
);
-- Test business rules
INSERT INTO `dolt_tests` VALUES (
'test_positive_prices',
'business_rules',
'SELECT COUNT(*) FROM products WHERE price <= 0;',
'row_count',
'==',
'0'
);
-- Test complex relationships
INSERT INTO `dolt_tests` VALUES (
'test_order_totals',
'integrity',
'SELECT COUNT(*) FROM orders o JOIN order_items oi ON o.id = oi.order_id GROUP BY o.id HAVING SUM(oi.quantity * oi.price) != o.total;',
'row_count',
'==',
'0'
);
```
### Dolt CI for DoltHub Integration
Dolt CI is specifically designed for running tests on DoltHub when pull requests are created. Use this only for tests you want to run automatically on DoltHub.
#### Prerequisites for DoltHub CI
- Requires Dolt v1.43.14 or later
- Must initialize CI capabilities: `dolt ci init`
- Workflows defined in YAML files
#### Available CI Commands
```bash
# Initialize CI capabilities
dolt ci init
# List available workflows
dolt ci ls
# View workflow details
dolt ci view <workflow-name>
# View specific job in workflow
dolt ci view <workflow-name> <job-name>
# Run workflow locally (for testing before DoltHub)
dolt ci run <workflow-name>
```
#### Creating CI Workflows for DoltHub
Create workflow files that will run on DoltHub when pull requests are opened:
```yaml
name: doltHub validation workflow
on:
push:
branches:
- master
- main
jobs:
- name: validate schema
steps:
- name: check required tables exist
saved_query_name: show_tables
expected_rows: ">= 3"
- name: validate user data
saved_query_name: user_count_check
expected_columns: "== 1"
expected_rows: "> 0"
- name: data integrity checks
steps:
- name: check email format
saved_query_name: valid_emails
expected_rows: "== 0" # No invalid emails
```
### Best Practices for Testing
1. **Use dolt_test for Unit Testing**
- Create tests for data validation
- Test business rules and constraints
- Validate schema changes
- Run tests frequently during development
2. **Use Dolt CI for DoltHub Integration**
- Only for tests that should run on pull requests
- Focus on integration and deployment validation
- Test against production-like data
3. **Create Comprehensive Test Suites**
- Test data integrity constraints
- Validate business rules
- Check schema requirements
- Verify data relationships
4. **Version Control Your Tests**
- Commit test definitions to repository
- Track changes to test configuration
- Use branches for test development
## System Tables for Version Control
Dolt exposes version control operations through system tables accessible via SQL:
### Core System Tables
```sql
-- View commit history
SELECT * FROM dolt_log;
-- Check current status
SELECT * FROM dolt_status;
-- View branch information
SELECT * FROM dolt_branches;
-- See table diffs
SELECT * FROM dolt_diff_<table_name>;
-- View schema changes
SELECT * FROM dolt_schema_diff;
-- Check conflicts during merge
SELECT * FROM dolt_conflicts_<table_name>;
-- View commit metadata
SELECT * FROM dolt_commits;
```
### Version Control Operations via SQL
When working in SQL sessions, you can execute version control operations using stored procedures:
```sql
-- Stage and commit changes
CALL dolt_add('.');
CALL dolt_commit('-m', 'commit message');
-- Branch operations
CALL dolt_branch('<branch_name>');
CALL dolt_checkout('<branch_name>');
CALL dolt_merge('<branch_name>');
```
**Note:** Use CLI commands (`dolt add`, `dolt commit`, etc.) for most operations. SQL procedures are useful when already in a SQL session.
### Advanced System Tables
```sql
-- View remotes
SELECT * FROM dolt_remotes;
-- Check merge conflicts
SELECT * FROM dolt_conflicts;
-- View statistics
SELECT * FROM dolt_statistics;
-- See ignored tables
SELECT * FROM dolt_ignore;
```
## CLI vs SQL Approach
**Prefer CLI commands for:**
- Version control operations (add, commit, branch, merge)
- Repository management (init, clone, push, pull)
- Conflict resolution
- Status checking and history viewing
**Use SQL for:**
- Data queries and analysis
- Complex data transformations
- Examining system tables (dolt_log, dolt_status, etc.)
- When already in an active SQL session
## Schema Design Recommendations
### Use UUID Keys Instead of Auto-Increment
For Dolt's version control features, use UUID primary keys instead of auto-increment:
```sql
-- Recommended
CREATE TABLE users (
id varchar(36) default(uuid()) primary key,
name varchar(255)
);
-- Avoid auto-increment with Dolt
-- id int auto_increment primary key
```
**Benefits:**
- Prevents merge conflicts across branches and database clones
- Automatic generation with default(uuid())
- Works seamlessly in distributed environments
## Best Practices for Agents
### 1. Always Work on Feature Branches
```bash
# Create feature branch before making changes
dolt checkout -b feature/agent-changes
# Make changes on feature branch
dolt sql -q "INSERT INTO users VALUES (1, 'Alice');"
# Stage and commit
dolt add .
dolt commit -m "Add new user Alice"
# Switch back to main to merge
dolt checkout main
dolt merge feature/agent-changes
```
### 2. Use SQL for Data Operations, CLI for Version Control
```bash
# Use dolt sql for data changes
dolt sql -q "INSERT INTO users VALUES (1, 'Alice');"
dolt sql -q "UPDATE products SET price = price * 1.1 WHERE category = 'electronics';"
# Check status and commit using CLI
dolt status
dolt add .
dolt commit -m "Update user and product data"
```
### 3. Validate Changes with System Tables
```sql
-- Before major operations, check current state
SELECT * FROM dolt_status;
SELECT * FROM dolt_branches;
-- After changes, verify with diffs
SELECT * FROM dolt_diff_users;
SELECT * FROM dolt_schema_diff;
```
### 4. Use dolt_test for Data Validation
Create tests to validate:
- Data integrity after changes
- Schema compatibility
- Business rule compliance
- Cross-table relationships
### 5. Handle Conflicts Gracefully
```bash
# Check for conflicts using CLI
dolt conflicts cat <table_name>
dolt conflicts resolve --ours <table_name>
dolt conflicts resolve --theirs <table_name>
# Or use SQL to examine conflicts
dolt sql -q "SELECT * FROM dolt_conflicts_<table_name>;"
```
## Common Workflow Examples
### Data Migration Workflow
```bash
# Create migration branch
dolt checkout -b migration/update-schema
# Apply schema changes via SQL
dolt sql -q "ALTER TABLE users ADD COLUMN email VARCHAR(255);"
# Create validation tests
dolt sql -q "INSERT INTO \`dolt_tests\` VALUES ('test_users_schema', 'schema', 'DESCRIBE users;', 'row_count', '>=', '6');"
dolt sql -q "INSERT INTO \`dolt_tests\` VALUES ('test_email_column', 'schema', 'SELECT COUNT(*) FROM users WHERE email IS NULL;', 'row_count', '>=', '0');"
# Run tests to validate changes
dolt sql -q "SELECT * FROM dolt_test_run();"
# Stage and commit
dolt add .
dolt commit -m "Add email column to users table"
# Merge back
dolt checkout main
dolt merge migration/update-schema
```
### Data Analysis Workflow
```bash
# Create analysis branch
dolt checkout -b analysis/user-behavior
# Create analysis tables via SQL
dolt sql -q "CREATE TABLE user_metrics AS
SELECT user_id, COUNT(*) as actions
FROM user_actions
GROUP BY user_id;"
# Create tests to validate analysis
dolt sql -q "INSERT INTO \`dolt_tests\` VALUES ('test_metrics_created', 'analysis', 'SELECT COUNT(*) FROM user_metrics;', 'row_count', '>', '0');"
dolt sql -q "INSERT INTO \`dolt_tests\` VALUES ('test_metrics_integrity', 'integrity', 'SELECT COUNT(*) FROM user_metrics um LEFT JOIN users u ON um.user_id = u.id WHERE u.id IS NULL;', 'row_count', '==', '0');"
# Run tests to validate analysis
dolt sql -q "SELECT * FROM dolt_test_run();"
# Stage and commit using CLI
dolt add user_metrics
dolt commit -m "Add user behavior analysis"
```
## Integration with External Tools
### Database Clients
Most MySQL clients work with Dolt:
- MySQL Workbench
- phpMyAdmin
- DataGrip
- DBeaver
### Backup and Sync
```bash
# Push to remote
dolt push origin main
# Pull changes
dolt pull origin main
# Clone for backup
dolt clone <remote-url> backup-location
```
This guide enables agents to leverage Dolt's unique version control capabilities while maintaining data integrity and following collaborative development practices.
+29
View File
@@ -174,6 +174,35 @@ func ExcludeIgnoredTables(ctx context.Context, roots Roots, tables []TableName)
return filteredTables, nil
}
// IdentifyIgnoredTables takes a list of table names and identifies any tables that are ignored, by evaluating the
// table names against the patterns in the dolt_ignore table from the working set.
func IdentifyIgnoredTables(ctx context.Context, roots Roots, tables []TableName) (ignoredTables []TableName, err error) {
schemas := GetUniqueSchemaNamesFromTableNames(tables)
ignorePatternMap, err := GetIgnoredTablePatterns(ctx, roots, schemas)
if err != nil {
return nil, err
}
for _, tbl := range tables {
ignorePatterns := ignorePatternMap[tbl.Schema]
ignored, err := ignorePatterns.IsTableNameIgnored(tbl)
if err != nil {
return nil, err
}
if conflict := AsDoltIgnoreInConflict(err); conflict != nil {
// no-op
} else if ignored == DontIgnore {
// no-op
} else if ignored == Ignore {
ignoredTables = append(ignoredTables, tbl)
} else {
return nil, fmt.Errorf("IsTableNameIgnored returned ErrorOccurred but no error!")
}
}
return ignoredTables, nil
}
// compilePattern takes a dolt_ignore pattern and generate a Regexp that matches against the same table names as the pattern.
func compilePattern(pattern string) (*regexp.Regexp, error) {
pattern = "^" + regexp.QuoteMeta(pattern) + "$"
+4 -511
View File
@@ -16,6 +16,7 @@ package doltdb
import (
"context"
_ "embed"
"errors"
"sort"
"strings"
@@ -27,6 +28,9 @@ import (
"github.com/dolthub/dolt/go/store/types"
)
//go:embed AGENT.md
var DefaultAgentDocValue string
type ctxKey int
type ctxValue int
@@ -174,517 +178,6 @@ const (
ReadmeDoc = "README.md"
// AgentDoc is the key for accessing the agent documentation within the docs table
AgentDoc = "AGENT.md"
DefaultAgentDocValue = `# AGENT.md - Dolt Database Operations Guide
This file provides guidance for AI agents working with Dolt databases to maximize productivity and follow best practices.
## Quick Start
Dolt is "Git for Data" - a SQL database with version control capabilities. All Git commands have Dolt equivalents:
- ` + "`git add` → `dolt add`" + `
- ` + "`git commit` → `dolt commit`" + `
- ` + "`git branch` → `dolt branch`" + `
- ` + "`git merge` → `dolt merge`" + `
- ` + "`git diff` → `dolt diff`" + `
For help and documentation on commands, you can run ` + "`dolt --help`" + ` and ` + "`dolt <command> --help`" + `.
## Essential Dolt CLI Commands
### Repository Operations
` + "```bash" + `
# Initialize new database
dolt init
# Clone existing database
dolt clone <remote-url>
# Show current status
dolt status
# View commit history
dolt log
` + "```" + `
### Branch Management
` + "```bash" + `
# List branches
dolt branch
# Create new branch
dolt branch <branch-name>
# Switch branches
dolt checkout <branch-name>
# Create and switch to new branch
dolt checkout -b <branch-name>
` + "```" + `
### Data Operations
` + "```bash" + `
# Stage changes
dolt add <table-name>
dolt add . # stage all changes
# Commit changes
dolt commit -m "commit message"
# View differences
dolt diff
dolt diff <table-name>
dolt diff <branch1> <branch2>
# Merge branches
dolt merge <branch-name>
` + "```" + `
## Starting and Connecting to Dolt SQL Server
### Start SQL Server
` + "```bash" + `
# Start server on default port (3306)
dolt sql-server
# Start on specific port
dolt sql-server --port=3307
# Start with specific host
dolt sql-server --host=0.0.0.0 --port=3307
# Start in background
dolt sql-server --port=3307 &
` + "```" + `
### Connecting to SQL Server
` + "```bash" + `
# Connect with dolt sql command
dolt sql
# Connect with mysql client
mysql -h 127.0.0.1 -P 3306 -u root
# Connect with specific database
mysql -h 127.0.0.1 -P 3306 -u root -D <database-name>
` + "```" + `
## Dolt Testing with dolt_test System Table
### Unit Testing with dolt_test
The dolt_test system table provides a powerful way to create and run unit tests for your database. This is the preferred method for testing data integrity, business rules, and schema validation.
#### Creating Tests
Tests are created by inserting rows into the ` + "`dolt_tests`" + ` system table:
` + "```sql" + `
-- Create a simple test
INSERT INTO ` + "`dolt_tests`" + ` VALUES (
'test_user_count',
'validation',
'SELECT COUNT(*) as user_count FROM users;',
'row_count',
'>',
'0'
);
-- Create a test with expected result
INSERT INTO ` + "`dolt_tests`" + ` VALUES (
'test_valid_emails',
'validation',
'SELECT COUNT(*) FROM users WHERE email NOT LIKE "%@%";',
'row_count',
'==',
'0'
);
-- Create a schema validation test
INSERT INTO ` + "`dolt_tests`" + ` VALUES (
'test_users_schema',
'schema',
'DESCRIBE users;',
'row_count',
'>=',
'5'
);
` + "```" + `
#### Test Structure
Each test row contains:
- test_name: Unique identifier for the test
- test_group: Optional grouping for tests (e.g., 'validation', 'schema', 'integration')
- test_query: SQL query to execute
- assertion_type: Type of assertion ('expected_rows', 'expected_columns', 'expected_single_value')
- assertion_comparator: Comparison operator ('==', '>', '<', '>=', '<=', '!=')
- assertion_value: Expected value for comparison
#### Running Tests
` + "```sql" + `
-- Run all tests
SELECT * FROM dolt_test_run();
-- Run specific test
SELECT * FROM dolt_test_run('test_user_count');
-- Run tests with filtering
SELECT * FROM dolt_test_run() WHERE test_name LIKE 'test_user%' AND status != 'PASS';
` + "```" + `
#### Test Result Interpretation
The dolt_test_run() function returns:
- test_name: Name of the test
- status: PASS, FAIL, or ERROR
- actual_result: Actual query result
- expected_result: Expected result
- message: Additional details
#### Advanced Testing Examples
` + "```sql" + `
-- Test data integrity
INSERT INTO ` + "`dolt_tests`" + ` VALUES (
'test_no_orphaned_orders',
'integrity',
'SELECT COUNT(*) FROM orders o LEFT JOIN users u ON o.user_id = u.id WHERE u.id IS NULL;',
'row_count',
'==',
'0'
);
-- Test business rules
INSERT INTO ` + "`dolt_tests`" + ` VALUES (
'test_positive_prices',
'business_rules',
'SELECT COUNT(*) FROM products WHERE price <= 0;',
'row_count',
'==',
'0'
);
-- Test complex relationships
INSERT INTO ` + "`dolt_tests`" + ` VALUES (
'test_order_totals',
'integrity',
'SELECT COUNT(*) FROM orders o JOIN order_items oi ON o.id = oi.order_id GROUP BY o.id HAVING SUM(oi.quantity * oi.price) != o.total;',
'row_count',
'==',
'0'
);
` + "```" + `
### Dolt CI for DoltHub Integration
Dolt CI is specifically designed for running tests on DoltHub when pull requests are created. Use this only for tests you want to run automatically on DoltHub.
#### Prerequisites for DoltHub CI
- Requires Dolt v1.43.14 or later
- Must initialize CI capabilities: ` + "`dolt ci init`" + `
- Workflows defined in YAML files
#### Available CI Commands
` + "```bash" + `
# Initialize CI capabilities
dolt ci init
# List available workflows
dolt ci ls
# View workflow details
dolt ci view <workflow-name>
# View specific job in workflow
dolt ci view <workflow-name> <job-name>
# Run workflow locally (for testing before DoltHub)
dolt ci run <workflow-name>
` + "```" + `
#### Creating CI Workflows for DoltHub
Create workflow files that will run on DoltHub when pull requests are opened:
` + "```yaml" + `
name: doltHub validation workflow
on:
push:
branches:
- master
- main
jobs:
- name: validate schema
steps:
- name: check required tables exist
saved_query_name: show_tables
expected_rows: ">= 3"
- name: validate user data
saved_query_name: user_count_check
expected_columns: "== 1"
expected_rows: "> 0"
- name: data integrity checks
steps:
- name: check email format
saved_query_name: valid_emails
expected_rows: "== 0" # No invalid emails
` + "```" + `
### Best Practices for Testing
1. **Use dolt_test for Unit Testing**
- Create tests for data validation
- Test business rules and constraints
- Validate schema changes
- Run tests frequently during development
2. **Use Dolt CI for DoltHub Integration**
- Only for tests that should run on pull requests
- Focus on integration and deployment validation
- Test against production-like data
3. **Create Comprehensive Test Suites**
- Test data integrity constraints
- Validate business rules
- Check schema requirements
- Verify data relationships
4. **Version Control Your Tests**
- Commit test definitions to repository
- Track changes to test configuration
- Use branches for test development
## System Tables for Version Control
Dolt exposes version control operations through system tables accessible via SQL:
### Core System Tables
` + "```sql" + `
-- View commit history
SELECT * FROM dolt_log;
-- Check current status
SELECT * FROM dolt_status;
-- View branch information
SELECT * FROM dolt_branches;
-- See table diffs
SELECT * FROM dolt_diff_<table_name>;
-- View schema changes
SELECT * FROM dolt_schema_diff;
-- Check conflicts during merge
SELECT * FROM dolt_conflicts_<table_name>;
-- View commit metadata
SELECT * FROM dolt_commits;
` + "```" + `
### Version Control Operations via SQL
When working in SQL sessions, you can execute version control operations using stored procedures:
` + "```sql" + `
-- Stage and commit changes
CALL dolt_add('.');
CALL dolt_commit('-m', 'commit message');
-- Branch operations
CALL dolt_branch('<branch_name>');
CALL dolt_checkout('<branch_name>');
CALL dolt_merge('<branch_name>');
` + "```" + `
**Note:** Use CLI commands (` + "`dolt add`, `dolt commit`, etc." + `) for most operations. SQL procedures are useful when already in a SQL session.
### Advanced System Tables
` + "```sql" + `
-- View remotes
SELECT * FROM dolt_remotes;
-- Check merge conflicts
SELECT * FROM dolt_conflicts;
-- View statistics
SELECT * FROM dolt_statistics;
-- See ignored tables
SELECT * FROM dolt_ignore;
` + "```" + `
## CLI vs SQL Approach
**Prefer CLI commands for:**
- Version control operations (add, commit, branch, merge)
- Repository management (init, clone, push, pull)
- Conflict resolution
- Status checking and history viewing
**Use SQL for:**
- Data queries and analysis
- Complex data transformations
- Examining system tables (dolt_log, dolt_status, etc.)
- When already in an active SQL session
## Schema Design Recommendations
### Use UUID Keys Instead of Auto-Increment
For Dolt's version control features, use UUID primary keys instead of auto-increment:
` + "```sql" + `
-- Recommended
CREATE TABLE users (
id varchar(36) default(uuid()) primary key,
name varchar(255)
);
-- Avoid auto-increment with Dolt
-- id int auto_increment primary key
` + "```" + `
**Benefits:**
- Prevents merge conflicts across branches and database clones
- Automatic generation with default(uuid())
- Works seamlessly in distributed environments
## Best Practices for Agents
### 1. Always Work on Feature Branches
` + "```bash" + `
# Create feature branch before making changes
dolt checkout -b feature/agent-changes
# Make changes on feature branch
dolt sql -q "INSERT INTO users VALUES (1, 'Alice');"
# Stage and commit
dolt add .
dolt commit -m "Add new user Alice"
# Switch back to main to merge
dolt checkout main
dolt merge feature/agent-changes
` + "```" + `
### 2. Use SQL for Data Operations, CLI for Version Control
` + "```bash" + `
# Use dolt sql for data changes
dolt sql -q "INSERT INTO users VALUES (1, 'Alice');"
dolt sql -q "UPDATE products SET price = price * 1.1 WHERE category = 'electronics';"
# Check status and commit using CLI
dolt status
dolt add .
dolt commit -m "Update user and product data"
` + "```" + `
### 3. Validate Changes with System Tables
` + "```sql" + `
-- Before major operations, check current state
SELECT * FROM dolt_status;
SELECT * FROM dolt_branches;
-- After changes, verify with diffs
SELECT * FROM dolt_diff_users;
SELECT * FROM dolt_schema_diff;
` + "```" + `
### 4. Use dolt_test for Data Validation
Create tests to validate:
- Data integrity after changes
- Schema compatibility
- Business rule compliance
- Cross-table relationships
### 5. Handle Conflicts Gracefully
` + "```bash" + `
# Check for conflicts using CLI
dolt conflicts cat <table_name>
dolt conflicts resolve --ours <table_name>
dolt conflicts resolve --theirs <table_name>
# Or use SQL to examine conflicts
dolt sql -q "SELECT * FROM dolt_conflicts_<table_name>;"
` + "```" + `
## Common Workflow Examples
### Data Migration Workflow
` + "```bash" + `
# Create migration branch
dolt checkout -b migration/update-schema
# Apply schema changes via SQL
dolt sql -q "ALTER TABLE users ADD COLUMN email VARCHAR(255);"
# Create validation tests
dolt sql -q "INSERT INTO ` + "`dolt_tests`" + ` VALUES ('test_users_schema', 'schema', 'DESCRIBE users;', 'row_count', '>=', '6');"
dolt sql -q "INSERT INTO ` + "`dolt_tests`" + ` VALUES ('test_email_column', 'schema', 'SELECT COUNT(*) FROM users WHERE email IS NULL;', 'row_count', '>=', '0');"
# Run tests to validate changes
dolt sql -q "SELECT * FROM dolt_test_run();"
# Stage and commit
dolt add .
dolt commit -m "Add email column to users table"
# Merge back
dolt checkout main
dolt merge migration/update-schema
` + "```" + `
### Data Analysis Workflow
` + "```bash" + `
# Create analysis branch
dolt checkout -b analysis/user-behavior
# Create analysis tables via SQL
dolt sql -q "CREATE TABLE user_metrics AS
SELECT user_id, COUNT(*) as actions
FROM user_actions
GROUP BY user_id;"
# Create tests to validate analysis
dolt sql -q "INSERT INTO ` + "`dolt_tests`" + ` VALUES ('test_metrics_created', 'analysis', 'SELECT COUNT(*) FROM user_metrics;', 'row_count', '>', '0');"
dolt sql -q "INSERT INTO ` + "`dolt_tests`" + ` VALUES ('test_metrics_integrity', 'integrity', 'SELECT COUNT(*) FROM user_metrics um LEFT JOIN users u ON um.user_id = u.id WHERE u.id IS NULL;', 'row_count', '==', '0');"
# Run tests to validate analysis
dolt sql -q "SELECT * FROM dolt_test_run();"
# Stage and commit using CLI
dolt add user_metrics
dolt commit -m "Add user behavior analysis"
` + "```" + `
## Integration with External Tools
### Database Clients
Most MySQL clients work with Dolt:
- MySQL Workbench
- phpMyAdmin
- DataGrip
- DBeaver
### Backup and Sync
` + "```bash" + `
# Push to remote
dolt push origin main
# Pull changes
dolt pull origin main
# Clone for backup
dolt clone <remote-url> backup-location
` + "```" + `
This guide enables agents to leverage Dolt's unique version control capabilities while maintaining data integrity and following collaborative development practices.`
)
// GetDocTableName returns the name of the dolt table containing documents such as the license and readme
+9 -24
View File
@@ -169,43 +169,28 @@ func AbortMerge(ctx *sql.Context, workingSet *doltdb.WorkingSet, roots doltdb.Ro
return nil, fmt.Errorf("there is no merge to abort")
}
tbls, err := doltdb.UnionTableNames(ctx, roots.Working, roots.Staged, roots.Head)
if err != nil {
return nil, err
}
tbls, err = doltdb.ExcludeIgnoredTables(ctx, roots, tbls)
if err != nil {
return nil, err
}
roots, err = actions.MoveTablesFromHeadToWorking(ctx, roots, tbls)
if err != nil {
return nil, err
}
preMergeWorkingRoot := workingSet.MergeState().PreMergeWorkingRoot()
preMergeWorkingTables, err := preMergeWorkingRoot.GetTableNames(ctx, doltdb.DefaultSchemaName, true)
if err != nil {
return nil, err
}
nonIgnoredTables, err := doltdb.ExcludeIgnoredTables(ctx, roots, doltdb.ToTableNames(preMergeWorkingTables, doltdb.DefaultSchemaName))
// Revert the working set back to the pre-merge working root
workingSet = workingSet.WithStagedRoot(roots.Head).WithWorkingRoot(preMergeWorkingRoot).WithStagedRoot(roots.Head)
workingSet = workingSet.ClearMerge()
// Carry over any ignored tables (which could have been manually modified by a user while a merge was halted)
ignoredTables, err := doltdb.IdentifyIgnoredTables(ctx, roots, doltdb.ToTableNames(preMergeWorkingTables, doltdb.DefaultSchemaName))
if err != nil {
return nil, err
}
someTablesAreIgnored := len(nonIgnoredTables) != len(preMergeWorkingTables)
if someTablesAreIgnored {
newWorking, err := actions.MoveTablesBetweenRoots(ctx, nonIgnoredTables, preMergeWorkingRoot, roots.Working)
if len(ignoredTables) > 0 {
newWorking, err := actions.MoveTablesBetweenRoots(ctx, ignoredTables, roots.Working, preMergeWorkingRoot)
if err != nil {
return nil, err
}
workingSet = workingSet.WithWorkingRoot(newWorking)
} else {
workingSet = workingSet.WithWorkingRoot(preMergeWorkingRoot)
}
// Unstage everything by making Staged match Head
workingSet = workingSet.WithStagedRoot(roots.Head)
workingSet = workingSet.ClearMerge()
return workingSet, nil
}
@@ -688,7 +688,7 @@ func sqlTypeString(t typeinfo.TypeInfo) string {
}
// Extended types are string serializable, so we'll just prepend a tag
if extendedType, ok := typ.(sqltypes.ExtendedType); ok {
if extendedType, ok := typ.(sql.ExtendedType); ok {
serializedType, err := sqltypes.SerializeTypeToString(extendedType)
if err != nil {
panic(err)
+2 -2
View File
@@ -476,7 +476,7 @@ func (si *schemaImpl) getKeyColumnsDescriptor(vs val.ValueStore, convertAddressC
var handler val.TupleTypeHandler
_, contentHashedField := contentHashedFields[tag]
extendedType, isExtendedType := sqlType.(gmstypes.ExtendedType)
extendedType, isExtendedType := sqlType.(sql.ExtendedType)
if isExtendedType {
encoding := EncodingFromSqlType(sqlType)
@@ -573,7 +573,7 @@ func (si *schemaImpl) GetValueDescriptor(vs val.ValueStore) val.TupleDesc {
collations = append(collations, sql.Collation_Unspecified)
}
if extendedType, ok := sqlType.(gmstypes.ExtendedType); ok {
if extendedType, ok := sqlType.(sql.ExtendedType); ok {
switch encoding {
case serial.EncodingExtendedAddr:
handlers = append(handlers, val.NewExtendedAddressTypeHandler(vs, extendedType))
@@ -18,7 +18,6 @@ import (
"fmt"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/vitess/go/vt/proto/query"
"github.com/dolthub/dolt/go/gen/fb/serial"
@@ -26,11 +25,11 @@ import (
// EncodingFromSqlType returns a serial.Encoding for a sql.Type.
func EncodingFromSqlType(typ sql.Type) serial.Encoding {
if extendedType, ok := typ.(types.ExtendedType); ok {
if extendedType, ok := typ.(sql.ExtendedType); ok {
switch extendedType.MaxSerializedWidth() {
case types.ExtendedTypeSerializedWidth_64K:
case sql.ExtendedTypeSerializedWidth_64K:
return serial.EncodingExtended
case types.ExtendedTypeSerializedWidth_Unbounded:
case sql.ExtendedTypeSerializedWidth_Unbounded:
// Always uses adaptive encoding for extended types, regardless of the setting of UseAdaptiveEncoding below.
return serial.EncodingExtendedAdaptive
default:
@@ -112,6 +111,8 @@ func EncodingFromQueryType(typ query.Type) serial.Encoding {
return serial.EncodingStringAdaptive
}
return serial.EncodingStringAddr
case query.Type_VECTOR:
return serial.EncodingBytesAdaptive
default:
panic(fmt.Sprintf("unknown encoding %v", typ))
}
@@ -121,8 +121,8 @@ func (d doltTypeCompatibilityChecker) IsTypeChangeCompatible(from, to typeinfo.T
// The TypeCompatibility checkers don't support ExtendedTypes added by integrators, so if we see
// one, return early and report the types are not compatible.
_, fromExtendedType := fromSqlType.(types.ExtendedType)
_, toExtendedType := toSqlType.(types.ExtendedType)
_, fromExtendedType := fromSqlType.(sql.ExtendedType)
_, toExtendedType := toSqlType.(sql.ExtendedType)
if fromExtendedType || toExtendedType {
return res
}
@@ -367,7 +367,7 @@ func mustCreateType(sqlType sql.Type) typeinfo.TypeInfo {
// extendedType is a no-op implementation of gmstypes.ExtendedType, used for testing type compatibility with extended types.
type extendedType struct{}
var _ gmstypes.ExtendedType = extendedType{}
var _ sql.ExtendedType = extendedType{}
func (e extendedType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
panic("unimplemented")
@@ -381,7 +381,7 @@ func (e extendedType) Convert(ctx context.Context, i interface{}) (interface{},
panic("unimplemented")
}
func (e extendedType) ConvertToType(ctx *sql.Context, typ gmstypes.ExtendedType, val any) (any, error) {
func (e extendedType) ConvertToType(ctx *sql.Context, typ sql.ExtendedType, val any) (any, error) {
panic("unimplemented")
}
@@ -433,6 +433,6 @@ func (e extendedType) FormatValue(val any) (string, error) {
panic("unimplemented")
}
func (e extendedType) MaxSerializedWidth() gmstypes.ExtendedTypeSerializedWidth {
func (e extendedType) MaxSerializedWidth() sql.ExtendedTypeSerializedWidth {
panic("unimplemented")
}
@@ -31,7 +31,7 @@ const (
// extendedType is a type that refers to an ExtendedType in GMS. These are only supported in the new format, and have many
// more limitations than traditional types (for now).
type extendedType struct {
sqlExtendedType gmstypes.ExtendedType
sqlExtendedType sql.ExtendedType
}
var _ TypeInfo = (*extendedType)(nil)
@@ -49,7 +49,7 @@ func CreateExtendedTypeFromParams(params map[string]string) (TypeInfo, error) {
}
// CreateExtendedTypeFromSqlType creates a TypeInfo from the given extended type.
func CreateExtendedTypeFromSqlType(typ gmstypes.ExtendedType) TypeInfo {
func CreateExtendedTypeFromSqlType(typ sql.ExtendedType) TypeInfo {
return &extendedType{typ}
}
@@ -110,7 +110,7 @@ func (ti *extendedType) NomsKind() types.NomsKind {
// Promote implements the TypeInfo interface.
func (ti *extendedType) Promote() TypeInfo {
return &extendedType{ti.sqlExtendedType.Promote().(gmstypes.ExtendedType)}
return &extendedType{ti.sqlExtendedType.Promote().(sql.ExtendedType)}
}
// String implements the TypeInfo interface.
@@ -46,6 +46,7 @@ const (
UuidTypeIdentifier Identifier = "uuid"
VarBinaryTypeIdentifier Identifier = "varbinary"
VarStringTypeIdentifier Identifier = "varstring"
VectorTypeIdentifier Identifier = "vector"
YearTypeIdentifier Identifier = "year"
GeometryTypeIdentifier Identifier = "geometry"
PointTypeIdentifier Identifier = "point"
@@ -136,7 +137,7 @@ type TypeInfo interface {
// FromSqlType takes in a sql.Type and returns the most relevant TypeInfo.
func FromSqlType(sqlType sql.Type) (TypeInfo, error) {
if gmsExtendedType, ok := sqlType.(gmstypes.ExtendedType); ok {
if gmsExtendedType, ok := sqlType.(sql.ExtendedType); ok {
return CreateExtendedTypeFromSqlType(gmsExtendedType), nil
}
sqlType, err := fillInCollationWithDefault(sqlType)
@@ -273,6 +274,12 @@ func FromSqlType(sqlType sql.Type) (TypeInfo, error) {
return nil, fmt.Errorf(`expected "SetTypeIdentifier" from SQL basetype "Set"`)
}
return &setType{setSQLType}, nil
case sqltypes.Vector:
vectorSQLType, ok := sqlType.(gmstypes.VectorType)
if !ok {
return nil, fmt.Errorf(`expected "VectorTypeIdentifier" from SQL basetype "Vector"`)
}
return &vectorType{vectorSQLType}, nil
default:
return nil, fmt.Errorf(`no type info can be created from SQL base type "%v"`, sqlType.String())
}
@@ -0,0 +1,162 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package typeinfo
import (
"context"
"fmt"
"strconv"
"strings"
"unsafe"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/go-mysql-server/sql/values"
"github.com/dolthub/dolt/go/store/types"
)
const (
	// vectorTypeParam_Length is the serialized type-parameter key that stores
	// the vector's dimension count (see GetTypeParams).
	vectorTypeParam_Length = "length"
)
// vectorType is the TypeInfo implementation backing the SQL VECTOR type.
// Values are persisted as noms Blobs holding the vector's raw bytes; a valid
// blob is Dimensions * Float32Size bytes long (see IsValid).
//
// NOTE(review): the comment previously here described BLOB/BINARY string
// handling and appears to have been copy-pasted from a blob/string type; it
// did not apply to vectors.
type vectorType struct {
	sqlVectorType gmstypes.VectorType
}

var _ TypeInfo = (*vectorType)(nil)
// ConvertNomsValueToValue implements TypeInfo interface.
// Null (or a nil Value) converts to nil; Blob values are decoded via fromBlob.
// Any other noms kind is an error.
func (ti *vectorType) ConvertNomsValueToValue(v types.Value) (interface{}, error) {
	if _, isNull := v.(types.Null); isNull || v == nil {
		return nil, nil
	}
	if blob, isBlob := v.(types.Blob); isBlob {
		return fromBlob(blob)
	}
	return nil, fmt.Errorf(`"%v" cannot convert NomsKind "%v" to a value`, ti, v.Kind())
}
// ReadFrom reads a go value from a noms types.CodecReader directly.
// A Blob is read and decoded via fromBlob; a Null is consumed and yields nil.
// Any other peeked kind is an error (and is left unread).
func (ti *vectorType) ReadFrom(_ *types.NomsBinFormat, reader types.CodecReader) (interface{}, error) {
	kind := reader.PeekKind()
	if kind == types.BlobKind {
		blob, err := reader.ReadBlob()
		if err != nil {
			return nil, err
		}
		return fromBlob(blob)
	}
	if kind == types.NullKind {
		_ = reader.ReadKind()
		return nil, nil
	}
	return nil, fmt.Errorf(`"%v" cannot convert NomsKind "%v" to a value`, ti.String(), kind)
}
// ConvertValueToNomsValue implements TypeInfo interface.
// nil converts to types.NullValue; otherwise the value is converted through
// the SQL vector type and its byte representation is stored as a noms Blob.
func (ti *vectorType) ConvertValueToNomsValue(ctx context.Context, vrw types.ValueReadWriter, v interface{}) (types.Value, error) {
	if v == nil {
		return types.NullValue, nil
	}
	converted, _, err := ti.sqlVectorType.Convert(ctx, v)
	if err != nil {
		return nil, err
	}
	raw, ok := converted.([]byte)
	if !ok {
		return nil, fmt.Errorf(`"%v" cannot convert value "%v" of type "%T" as it is invalid`, ti.String(), v, v)
	}
	return types.NewBlob(ctx, vrw, strings.NewReader(string(raw)))
}
// Equals implements TypeInfo interface.
// Two vector types are equal iff both are vectorType and have the same
// dimension count.
func (ti *vectorType) Equals(other TypeInfo) bool {
	otherVec, ok := other.(*vectorType)
	if !ok {
		return false
	}
	return ti.sqlVectorType.Dimensions == otherVec.sqlVectorType.Dimensions
}
// FormatValue implements TypeInfo interface.
// Blob values are rendered via fromBlob and returned as a *string; Null (or a
// nil Value) formats as a nil *string. Any other noms kind is an error.
func (ti *vectorType) FormatValue(v types.Value) (*string, error) {
	if val, ok := v.(types.Blob); ok {
		resStr, err := fromBlob(val)
		if err != nil {
			return nil, err
		}
		// This is safe (See https://go101.org/article/unsafe.html)
		// NOTE(review): this reinterprets resStr's storage as a string without
		// copying; it presumes fromBlob's result has a string-compatible layout
		// — confirm, otherwise a plain conversion would be clearer and safer.
		return (*string)(unsafe.Pointer(&resStr)), nil
	}
	if _, ok := v.(types.Null); ok || v == nil {
		return nil, nil
	}
	return nil, fmt.Errorf(`"%v" cannot convert NomsKind "%v" to a string`, ti, v.Kind())
}
// GetTypeIdentifier implements TypeInfo interface.
// Vector types serialize under the VectorTypeIdentifier ("vector") key.
func (ti *vectorType) GetTypeIdentifier() Identifier {
	return VectorTypeIdentifier
}
// GetTypeParams implements TypeInfo interface.
// The only parameter is the vector's dimension count, stored under the
// vectorTypeParam_Length key as a base-10 string.
func (ti *vectorType) GetTypeParams() map[string]string {
	params := make(map[string]string, 1)
	params[vectorTypeParam_Length] = strconv.FormatInt(int64(ti.sqlVectorType.Dimensions), 10)
	return params
}
// IsValid implements TypeInfo interface.
// A Blob is valid iff its length is exactly Dimensions * Float32Size bytes;
// Null and nil are always valid; every other value is invalid.
func (ti *vectorType) IsValid(v types.Value) bool {
	switch tv := v.(type) {
	case types.Blob:
		expectedLen := ti.sqlVectorType.Dimensions * int(values.Float32Size)
		return int(tv.Len()) == expectedLen
	case types.Null:
		return true
	}
	return v == nil
}
// NomsKind implements TypeInfo interface.
// Vector values are persisted as noms Blobs (see ConvertValueToNomsValue).
func (ti *vectorType) NomsKind() types.NomsKind {
	return types.BlobKind
}
// Promote implements TypeInfo interface.
// Vector types have no wider representation, so promotion is the identity.
func (ti *vectorType) Promote() TypeInfo {
	return ti
}
// String implements TypeInfo interface.
// Renders as "Vector(N)" where N is the dimension count.
func (ti *vectorType) String() string {
	return "Vector(" + strconv.FormatInt(int64(ti.sqlVectorType.Dimensions), 10) + ")"
}
// ToSqlType implements TypeInfo interface.
// Returns the wrapped go-mysql-server VECTOR type unchanged.
func (ti *vectorType) ToSqlType() sql.Type {
	return ti.sqlVectorType
}
@@ -408,66 +408,85 @@ func ResolveSchemaConflicts(ctx *sql.Context, ddb *doltdb.DoltDB, ws *doltdb.Wor
return ws.WithWorkingRoot(root).WithUnmergableTables(unmerged).WithMergedTables(merged), nil
}
// ResolveDataConflictsForTable resolves any data conflicts on the named table in |root|,
// keeping either "our" side (|ours| == true) or "their" side (|ours| == false) of each
// conflicting row. |getEditorOpts| is invoked lazily, only when the old (noms) storage
// format requires editor options. It returns the updated root, a bool reporting whether
// the table actually had conflicts to resolve, and any error. When the table has no
// conflicts, (nil, false, nil) is returned and nothing is modified.
func ResolveDataConflictsForTable(ctx *sql.Context, root doltdb.RootValue, tblName doltdb.TableName, ours bool, getEditorOpts func() (editor.Options, error)) (doltdb.RootValue, bool, error) {
	tbl, ok, err := root.GetTable(ctx, tblName)
	if err != nil {
		return nil, false, err
	}
	if !ok {
		return nil, false, doltdb.ErrTableNotFound
	}

	if has, err := tbl.HasConflicts(ctx); err != nil {
		return nil, false, err
	} else if !has {
		return nil, false, nil
	}

	sch, err := tbl.GetSchema(ctx)
	if err != nil {
		return nil, false, err
	}

	_, ourSch, theirSch, err := tbl.GetConflictSchemas(ctx, tblName)
	if err != nil {
		return nil, false, err
	}

	// Resolving in favor of a side is only safe when the table's current schema matches
	// that side's conflict schema.
	if ours && !schema.ColCollsAreEqual(sch.GetAllCols(), ourSch.GetAllCols()) {
		return nil, false, ErrConfSchIncompatible
	} else if !ours && !schema.ColCollsAreEqual(sch.GetAllCols(), theirSch.GetAllCols()) {
		return nil, false, ErrConfSchIncompatible
	}

	if !ours {
		if tbl.Format() == types.Format_DOLT {
			tbl, err = resolveProllyConflicts(ctx, tbl, tblName, ourSch, sch)
		} else {
			opts, optsErr := getEditorOpts()
			if optsErr != nil {
				return nil, false, optsErr
			}
			// BUG FIX: the original wrote `opts, err := getEditorOpts()`, which shadowed
			// |err| inside this branch; the error from resolveNomsConflicts was assigned
			// to the shadowed variable and the outer check below silently dropped it.
			tbl, err = resolveNomsConflicts(ctx, opts, tbl, tblName.Name, sch)
		}
		if err != nil {
			return nil, false, err
		}
	}

	newRoot, err := clearTableAndUpdateRoot(ctx, root, tbl, tblName)
	if err != nil {
		return nil, false, err
	}

	err = validateConstraintViolations(ctx, root, newRoot, tblName)
	if err != nil {
		return nil, false, err
	}
	return newRoot, true, nil
}
func ResolveDataConflicts(ctx *sql.Context, dSess *dsess.DoltSession, root doltdb.RootValue, dbName string, ours bool, tblNames []doltdb.TableName) error {
getEditorOpts := func() (editor.Options, error) {
state, _, err := dSess.LookupDbState(ctx, dbName)
if err != nil {
return editor.Options{}, err
}
var opts editor.Options
if ws := state.WriteSession(); ws != nil {
opts = ws.GetOptions()
}
return opts, nil
}
for _, tblName := range tblNames {
tbl, ok, err := root.GetTable(ctx, tblName)
newRoot, hasConflicts, err := ResolveDataConflictsForTable(ctx, root, tblName, ours, getEditorOpts)
if err != nil {
return err
}
if !ok {
return doltdb.ErrTableNotFound
}
if has, err := tbl.HasConflicts(ctx); err != nil {
return err
} else if !has {
if !hasConflicts {
continue
}
sch, err := tbl.GetSchema(ctx)
if err != nil {
return err
}
_, ourSch, theirSch, err := tbl.GetConflictSchemas(ctx, tblName)
if err != nil {
return err
}
if ours && !schema.ColCollsAreEqual(sch.GetAllCols(), ourSch.GetAllCols()) {
return ErrConfSchIncompatible
} else if !ours && !schema.ColCollsAreEqual(sch.GetAllCols(), theirSch.GetAllCols()) {
return ErrConfSchIncompatible
}
if !ours {
if tbl.Format() == types.Format_DOLT {
tbl, err = resolveProllyConflicts(ctx, tbl, tblName, ourSch, sch)
} else {
state, _, err := dSess.LookupDbState(ctx, dbName)
if err != nil {
return err
}
var opts editor.Options
if ws := state.WriteSession(); ws != nil {
opts = ws.GetOptions()
}
tbl, err = resolveNomsConflicts(ctx, opts, tbl, tblName.Name, sch)
}
if err != nil {
return err
}
}
newRoot, err := clearTableAndUpdateRoot(ctx, root, tbl, tblName)
if err != nil {
return err
}
err = validateConstraintViolations(ctx, root, newRoot, tblName)
if err != nil {
return err
}
root = newRoot
}
return dSess.SetWorkingRoot(ctx, dbName, root)
@@ -513,9 +532,6 @@ func DoDoltConflictsResolve(ctx *sql.Context, args []string) (int, error) {
if len(strTableNames) == 1 && strTableNames[0] == "." {
all := actions.GetAllTableNames(ctx, ws.WorkingRoot())
if err != nil {
return 1, nil
}
tableNames = all
} else {
for _, tblName := range strTableNames {
@@ -16,12 +16,15 @@ package dtablefunctions
import (
"fmt"
"io"
"strings"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
"gopkg.in/src-d/go-errors.v1"
"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/libraries/doltcore/diff"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/merge"
@@ -32,6 +35,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dtables"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
dolttable "github.com/dolthub/dolt/go/libraries/doltcore/table"
"github.com/dolthub/dolt/go/store/types"
)
@@ -56,6 +60,8 @@ type DiffTableFunction struct {
fromDate *types.Timestamp
toDate *types.Timestamp
sqlSch sql.Schema
showSkinny bool
includeCols map[string]struct{}
}
// NewInstance creates a new instance of TableFunction interface
@@ -100,26 +106,25 @@ func (dtf *DiffTableFunction) WithDatabase(database sql.Database) (sql.Node, err
// Expressions implements the sql.Expressioner interface
func (dtf *DiffTableFunction) Expressions() []sql.Expression {
exprs := []sql.Expression{}
if dtf.dotCommitExpr != nil {
return []sql.Expression{
dtf.dotCommitExpr, dtf.tableNameExpr,
}
}
return []sql.Expression{
dtf.fromCommitExpr, dtf.toCommitExpr, dtf.tableNameExpr,
exprs = append(exprs, dtf.dotCommitExpr, dtf.tableNameExpr)
} else {
exprs = append(exprs, dtf.fromCommitExpr, dtf.toCommitExpr, dtf.tableNameExpr)
}
return exprs
}
// WithExpressions implements the sql.Expressioner interface
func (dtf *DiffTableFunction) WithExpressions(expression ...sql.Expression) (sql.Node, error) {
if len(expression) < 2 {
return nil, sql.ErrInvalidArgumentNumber.New(dtf.Name(), "2 to 3", len(expression))
}
func (dtf *DiffTableFunction) WithExpressions(expressions ...sql.Expression) (sql.Node, error) {
newDtf := *dtf
// TODO: For now, we will only support literal / fully-resolved arguments to the
// DiffTableFunction to avoid issues where the schema is needed in the analyzer
// before the arguments could be resolved.
for _, expr := range expression {
var exprStrs []string
strToExpr := map[string]sql.Expression{}
for _, expr := range expressions {
if !expr.Resolved() {
return nil, ErrInvalidNonLiteralArgument.New(dtf.Name(), expr.String())
}
@@ -127,22 +132,52 @@ func (dtf *DiffTableFunction) WithExpressions(expression ...sql.Expression) (sql
if _, ok := expr.(sql.FunctionExpression); ok {
return nil, ErrInvalidNonLiteralArgument.New(dtf.Name(), expr.String())
}
strVal := expr.String()
if lit, ok := expr.(*expression.Literal); ok { // rm quotes from string literals
strVal = fmt.Sprintf("%v", lit.Value())
}
exprStrs = append(exprStrs, strVal) // args extracted from apr later to filter out options
strToExpr[strVal] = expr
}
newDtf := *dtf
if strings.Contains(expression[0].String(), "..") {
if len(expression) != 2 {
return nil, sql.ErrInvalidArgumentNumber.New(fmt.Sprintf("%v with .. or ...", newDtf.Name()), 2, len(expression))
apr, err := cli.CreateDiffArgParser(true).Parse(exprStrs)
if err != nil {
return nil, err
}
if apr.Contains(cli.SkinnyFlag) {
newDtf.showSkinny = true
}
if cols, ok := apr.GetValueList(cli.IncludeCols); ok {
newDtf.includeCols = make(map[string]struct{})
for _, col := range cols {
newDtf.includeCols[col] = struct{}{}
}
newDtf.dotCommitExpr = expression[0]
newDtf.tableNameExpr = expression[1]
}
expressions = []sql.Expression{}
for _, posArg := range apr.Args {
expressions = append(expressions, strToExpr[posArg])
}
if len(expressions) < 2 {
return nil, sql.ErrInvalidArgumentNumber.New(dtf.Name(), "2 to 3", len(expressions))
}
if strings.Contains(expressions[0].String(), "..") {
if len(expressions) != 2 {
return nil, sql.ErrInvalidArgumentNumber.New(fmt.Sprintf("%v with .. or ...", newDtf.Name()), 2, len(expressions))
}
newDtf.dotCommitExpr = expressions[0]
newDtf.tableNameExpr = expressions[1]
} else {
if len(expression) != 3 {
return nil, sql.ErrInvalidArgumentNumber.New(newDtf.Name(), 3, len(expression))
if len(expressions) != 3 {
return nil, sql.ErrInvalidArgumentNumber.New(newDtf.Name(), 3, len(expressions))
}
newDtf.fromCommitExpr = expression[0]
newDtf.toCommitExpr = expression[1]
newDtf.tableNameExpr = expression[2]
newDtf.fromCommitExpr = expressions[0]
newDtf.toCommitExpr = expressions[1]
newDtf.tableNameExpr = expressions[2]
}
fromCommitVal, toCommitVal, dotCommitVal, tableName, err := newDtf.evaluateArguments()
@@ -423,6 +458,110 @@ func (dtf *DiffTableFunction) evaluateArguments() (interface{}, interface{}, int
return fromCommitVal, toCommitVal, nil, tableName, nil
}
// filterDeltaSchemaToSkinnyCols creates a filtered version of the table delta that omits columns which are identical
// in type and value across all rows in both schemas. Primary key columns, and columns explicitly named via the
// include-cols option, are always kept. This also updates dtf.tableDelta with the filtered result.
func (dtf *DiffTableFunction) filterDeltaSchemaToSkinnyCols(ctx *sql.Context, delta *diff.TableDelta) (*diff.TableDelta, error) {
	// A table creation or drop has no counterpart to compare against, so the delta is returned unfiltered.
	if delta.FromTable == nil || delta.ToTable == nil {
		return delta, nil
	}
	// gather map of potential cols for removal from skinny diff:
	// column name -> [index in from-schema, index in to-schema]. A column is only a removal
	// candidate when it exists in both schemas with an equal type; columns the user asked
	// for via include-cols, and dropped columns, are never candidates.
	equalDiffColsIndices := map[string][2]int{}
	toCols := delta.ToSch.GetAllCols()
	for fromIdx, fromCol := range delta.FromSch.GetAllCols().GetColumns() {
		if _, ok := dtf.includeCols[fromCol.Name]; ok {
			continue // user explicitly included this column
		}
		col, ok := delta.ToSch.GetAllCols().GetByName(fromCol.Name)
		if !ok { // column was dropped
			continue
		}
		if fromCol.TypeInfo.Equals(col.TypeInfo) {
			toIdx := toCols.TagToIdx[toCols.NameToCol[fromCol.Name].Tag]
			equalDiffColsIndices[fromCol.Name] = [2]int{fromIdx, toIdx}
		}
	}
	fromRowData, err := delta.FromTable.GetRowData(ctx)
	if err != nil {
		return nil, err
	}
	toRowData, err := delta.ToTable.GetRowData(ctx)
	if err != nil {
		return nil, err
	}
	fromIter, err := dolttable.NewTableIterator(ctx, delta.FromSch, fromRowData)
	if err != nil {
		return nil, err
	}
	defer fromIter.Close(ctx)
	toIter, err := dolttable.NewTableIterator(ctx, delta.ToSch, toRowData)
	if err != nil {
		return nil, err
	}
	defer toIter.Close(ctx)
	// Walk both tables in lockstep, deleting a candidate column the moment a value
	// difference is seen; the loop ends early once no candidates remain.
	// NOTE(review): rows are paired purely by iteration order, which assumes both
	// iterators yield corresponding rows in the same sequence — confirm this holds
	// when rows were added or removed between the two table versions.
	for len(equalDiffColsIndices) > 0 {
		fromRow, fromErr := fromIter.Next(ctx)
		toRow, toErr := toIter.Next(ctx)
		if fromErr == io.EOF && toErr == io.EOF {
			break
		}
		if fromErr != nil && fromErr != io.EOF {
			return nil, fromErr
		}
		if toErr != nil && toErr != io.EOF {
			return nil, toErr
		}
		// xor: if only one is nil, then all cols are diffs
		if (fromRow == nil) != (toRow == nil) {
			equalDiffColsIndices = map[string][2]int{}
			break
		}
		if fromRow == nil && toRow == nil {
			continue
		}
		for colName, idx := range equalDiffColsIndices {
			// NOTE(review): interface `!=` comparison — assumes cell values are
			// comparable Go types (non-comparable values would panic); verify.
			if fromRow[idx[0]] != toRow[idx[1]] { // same row and col, values differ
				delete(equalDiffColsIndices, colName)
			}
		}
	}
	// Rebuild each schema, keeping primary key columns and every column that showed
	// at least one differing value (i.e. anything no longer in equalDiffColsIndices).
	var fromSkCols []schema.Column
	for _, col := range delta.FromSch.GetAllCols().GetColumns() {
		_, ok := equalDiffColsIndices[col.Name]
		if col.IsPartOfPK || !ok {
			fromSkCols = append(fromSkCols, col)
		}
	}
	var toSkCols []schema.Column
	for _, col := range delta.ToSch.GetAllCols().GetColumns() {
		_, ok := equalDiffColsIndices[col.Name]
		if col.IsPartOfPK || !ok {
			toSkCols = append(toSkCols, col)
		}
	}
	// Shallow-copy the delta with the narrowed schemas, and remember it on the
	// table function so later schema generation sees the filtered columns.
	skDelta := *delta
	skDelta.FromSch = schema.MustSchemaFromCols(schema.NewColCollection(fromSkCols...))
	skDelta.ToSch = schema.MustSchemaFromCols(schema.NewColCollection(toSkCols...))
	dtf.tableDelta = skDelta
	return &skDelta, nil
}
func (dtf *DiffTableFunction) generateSchema(ctx *sql.Context, fromCommitVal, toCommitVal, dotCommitVal interface{}, tableName string) error {
if !dtf.Resolved() {
return nil
@@ -438,27 +577,27 @@ func (dtf *DiffTableFunction) generateSchema(ctx *sql.Context, fromCommitVal, to
return err
}
if dtf.showSkinny {
skDelta, err := dtf.filterDeltaSchemaToSkinnyCols(ctx, &delta)
if err != nil {
return err
}
delta = *skDelta
}
fromTable, fromTableExists := delta.FromTable, delta.FromTable != nil
toTable, toTableExists := delta.ToTable, delta.ToTable != nil
if !toTableExists && !fromTableExists {
var format *types.NomsBinFormat
if toTableExists {
format = toTable.Format()
} else if fromTableExists {
format = fromTable.Format()
} else {
return sql.ErrTableNotFound.New(tableName)
}
var toSchema, fromSchema schema.Schema
var format *types.NomsBinFormat
if fromTableExists {
fromSchema = delta.FromSch
format = fromTable.Format()
}
if toTableExists {
toSchema = delta.ToSch
format = toTable.Format()
}
diffTableSch, j, err := dtables.GetDiffTableSchemaAndJoiner(format, fromSchema, toSchema)
diffTableSch, j, err := dtables.GetDiffTableSchemaAndJoiner(format, delta.FromSch, delta.ToSch)
if err != nil {
return err
}
@@ -571,15 +710,14 @@ func (dtf *DiffTableFunction) IsReadOnly() bool {
// String implements the Stringer interface
func (dtf *DiffTableFunction) String() string {
args := []string{}
if dtf.dotCommitExpr != nil {
return fmt.Sprintf("DOLT_DIFF(%s, %s)",
dtf.dotCommitExpr.String(),
dtf.tableNameExpr.String())
args = append(args, dtf.dotCommitExpr.String(), dtf.tableNameExpr.String())
} else {
args = append(args, dtf.fromCommitExpr.String(), dtf.toCommitExpr.String(), dtf.tableNameExpr.String())
}
return fmt.Sprintf("DOLT_DIFF(%s, %s, %s)",
dtf.fromCommitExpr.String(),
dtf.toCommitExpr.String(),
dtf.tableNameExpr.String())
return fmt.Sprintf("DOLT_DIFF(%s)", strings.Join(args, ", "))
}
// Name implements the sql.TableFunction interface
@@ -16,6 +16,7 @@ package dtablefunctions
import (
"fmt"
"io"
"strings"
gms "github.com/dolthub/go-mysql-server"
@@ -191,7 +192,7 @@ func (trtf *TestsRunTableFunction) RowIter(_ *sql.Context, _ sql.Row) (sql.RowIt
return nil, err
}
for _, row := range *testRows {
for _, row := range testRows {
result, err := trtf.queryAndAssert(row)
if err != nil {
return nil, err
@@ -254,7 +255,7 @@ func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResul
return result, nil
}
func (trtf *TestsRunTableFunction) getDoltTestsData(arg string) (*[]sql.Row, error) {
func (trtf *TestsRunTableFunction) getDoltTestsData(arg string) ([]sql.Row, error) {
var queries []string
if arg == "*" {
@@ -280,12 +281,21 @@ func (trtf *TestsRunTableFunction) getDoltTestsData(arg string) (*[]sql.Row, err
if err != nil {
return nil, err
}
rows, err := sql.RowIterToRows(trtf.ctx, iter)
if err != nil {
return nil, err
// Calling iter.Close(ctx) will cause TrackedRowIter to cancel the context, causing problems when running with
// dolt sql-server. Since we only support `SELECT...` queries anyway, it's not necessary to Close() the iter.
var rows []sql.Row
for {
row, rErr := iter.Next(trtf.ctx)
if rErr == io.EOF {
break
}
if rErr != nil {
return nil, rErr
}
rows = append(rows, row)
}
if len(rows) > 0 {
return &rows, nil
return rows, nil
}
}
return nil, fmt.Errorf("could not find tests for argument: %s", arg)
@@ -283,7 +283,7 @@ var ModifyAndChangeColumnScripts = []queries.ScriptTest{
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table people modify rating double default 'not a number'",
ExpectedErrStr: "incompatible type for default value: error: 'not a number' is not a valid value for 'double'",
ExpectedErrStr: "incompatible type for default value: Truncated incorrect double value: not a number",
},
},
},
@@ -909,6 +909,12 @@ func TestVectorFunctions(t *testing.T) {
enginetest.TestVectorFunctions(t, harness)
}
func TestVectorType(t *testing.T) {
harness := newDoltHarness(t)
defer harness.Close()
enginetest.TestVectorType(t, harness)
}
func TestIndexPrefix(t *testing.T) {
skipOldFormat(t)
harness := newDoltHarness(t)
@@ -15,6 +15,9 @@
package enginetest
import (
"fmt"
"strings"
"github.com/dolthub/go-mysql-server/enginetest/queries"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
@@ -818,7 +821,138 @@ var Dolt1DiffSystemTableScripts = []queries.ScriptTest{
},
}
// assertDoltDiffColumnCount builds a short assertion script that creates a view
// over the given SELECT, counts the view's distinct data columns (commit
// metadata columns are excluded and to_/from_ prefixes are collapsed so each
// underlying column counts once), checks the count, and drops the view again.
func assertDoltDiffColumnCount(view, selectStmt string, expected int64) []queries.ScriptTestAssertion {
	metaCols := []string{
		"'to_commit'",
		"'from_commit'",
		"'to_commit_date'",
		"'from_commit_date'",
		"'diff_type'",
	}
	countQuery := fmt.Sprintf(`
SELECT COUNT(DISTINCT REPLACE(REPLACE(column_name, 'to_', ''), 'from_', ''))
FROM information_schema.columns
WHERE table_schema = DATABASE()
AND table_name = '%s'
AND column_name NOT IN (%s)`,
		view, strings.Join(metaCols, ", "),
	)
	asserts := make([]queries.ScriptTestAssertion, 0, 4)
	asserts = append(asserts,
		queries.ScriptTestAssertion{Query: fmt.Sprintf("DROP VIEW IF EXISTS %s;", view)},
		queries.ScriptTestAssertion{Query: fmt.Sprintf("CREATE VIEW %s AS %s;", view, selectStmt)},
		queries.ScriptTestAssertion{Query: countQuery, Expected: []sql.Row{{expected}}},
		queries.ScriptTestAssertion{Query: fmt.Sprintf("DROP VIEW %s;", view)},
	)
	return asserts
}
var DiffTableFunctionScriptTests = []queries.ScriptTest{
{
Name: "dolt_diff: SELECT * skinny schema visibility",
SetUpScript: []string{
`CREATE TABLE t (
pk BIGINT NOT NULL COMMENT 'tag:0',
c1 BIGINT COMMENT 'tag:1',
c2 BIGINT COMMENT 'tag:2',
c3 BIGINT COMMENT 'tag:3',
c4 BIGINT COMMENT 'tag:4',
c5 BIGINT COMMENT 'tag:5',
PRIMARY KEY (pk)
);`,
"call dolt_add('.')",
"set @C0 = '';",
"call dolt_commit_hash_out(@C0, '-m', 'Created table t');",
"INSERT INTO t VALUES (0,1,2,3,4,5), (1,1,2,3,4,5);",
"call dolt_add('.')",
"set @C1 = '';",
"call dolt_commit_hash_out(@C1, '-m', 'Added initial data');",
"UPDATE t SET c1=100, c3=300 WHERE pk=0;",
"UPDATE t SET c2=200 WHERE pk=1;",
"call dolt_add('.')",
"set @C2 = '';",
"call dolt_commit_hash_out(@C2, '-m', 'Updated some columns');",
"ALTER TABLE t ADD COLUMN c6 BIGINT;",
"UPDATE t SET c6=600 WHERE pk=0;",
"call dolt_add('.')",
"set @C3 = '';",
"call dolt_commit_hash_out(@C3, '-m', 'Added new column and updated it');",
"DELETE FROM t WHERE pk=1;",
"call dolt_add('.')",
"set @C4 = '';",
"call dolt_commit_hash_out(@C4, '-m', 'Deleted a row');",
},
Assertions: func() []queries.ScriptTestAssertion {
asserts := []queries.ScriptTestAssertion{
{
Query: "SELECT d.to_pk, d.to_c1, d.to_c2, d.to_c3, d.to_c4, d.to_c5, d.from_pk, d.from_c1, d.from_c2, d.from_c3, d.from_c4, d.from_c5, d.diff_type " +
"FROM (SELECT * FROM dolt_diff('--skinny', @C0, @C1, 't')) d " +
"ORDER BY COALESCE(d.to_pk, d.from_pk)",
Expected: []sql.Row{
{int64(0), int64(1), int64(2), int64(3), int64(4), int64(5), interface{}(nil), interface{}(nil), interface{}(nil), interface{}(nil), interface{}(nil), interface{}(nil), "added"},
{int64(1), int64(1), int64(2), int64(3), int64(4), int64(5), interface{}(nil), interface{}(nil), interface{}(nil), interface{}(nil), interface{}(nil), interface{}(nil), "added"},
},
},
{
Query: "SELECT d.to_pk, d.to_c1, d.to_c2, d.to_c3, d.from_pk, d.from_c1, d.from_c2, d.from_c3, d.diff_type " +
"FROM (SELECT * FROM dolt_diff(@C1, @C2, 't')) d " +
"ORDER BY COALESCE(d.to_pk, d.from_pk)",
Expected: []sql.Row{
{int64(0), int64(100), int64(2), int64(300), int64(0), int64(1), int64(2), int64(3), "modified"},
{int64(1), int64(1), int64(200), int64(3), int64(1), int64(1), int64(2), int64(3), "modified"},
},
},
{
Query: "SELECT d.to_pk, d.to_c1, d.to_c2, d.to_c3, d.diff_type " +
"FROM (SELECT * FROM dolt_diff('--skinny', @C1, @C2, 't')) d " +
"ORDER BY d.to_pk",
Expected: []sql.Row{
{int64(0), int64(100), int64(2), int64(300), "modified"},
{int64(1), int64(1), int64(200), int64(3), "modified"},
},
},
{
Query: "SELECT d.to_pk, d.to_c6, d.diff_type " +
"FROM (SELECT * FROM dolt_diff('--skinny', @C2, @C3, 't')) d",
Expected: []sql.Row{
{int64(0), int64(600), "modified"},
},
},
{
Query: "SELECT d.to_pk, d.to_c1, d.to_c2, d.to_c6, d.diff_type " +
"FROM (SELECT * FROM dolt_diff('--skinny', '--include-cols=c1,c2', @C2, @C3, 't')) d",
Expected: []sql.Row{
{int64(0), int64(100), int64(2), int64(600), "modified"},
},
},
{
Query: "SELECT d.from_pk, d.from_c1, d.from_c2, d.from_c3, d.from_c4, d.from_c5, d.from_c6, d.diff_type " +
"FROM (SELECT * FROM dolt_diff('--skinny', @C3, @C4, 't')) d",
Expected: []sql.Row{
{int64(1), int64(1), int64(200), int64(3), int64(4), int64(5), nil, "removed"},
},
},
}
asserts = append(asserts, assertDoltDiffColumnCount("v_all_01", "SELECT * FROM dolt_diff(@C0, @C1, 't')", 6)...)
asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_01", "SELECT * FROM dolt_diff('--skinny', @C0, @C1, 't')", 6)...)
asserts = append(asserts, assertDoltDiffColumnCount("v_all_12", "SELECT * FROM dolt_diff(@C1, @C2, 't')", 6)...)
asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_12", "SELECT * FROM dolt_diff('--skinny', @C1, @C2, 't')", 4)...)
asserts = append(asserts, assertDoltDiffColumnCount("v_all_23", "SELECT * FROM dolt_diff(@C2, @C3, 't')", 7)...)
asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_23", "SELECT * FROM dolt_diff('--skinny', @C2, @C3, 't')", 2)...)
asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_23", "SELECT * FROM dolt_diff(@C2, @C3, 't', '--skinny')", 2)...)
asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_23", "SELECT * FROM dolt_diff('--skinny', '--include-cols=c1,c2', @C2, @C3, 't')", 4)...)
asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_23", "SELECT * FROM dolt_diff('--skinny', '--include-cols=c1,c2,c6', @C2, @C3, 't')", 4)...)
asserts = append(asserts, assertDoltDiffColumnCount("v_all_34", "SELECT * FROM dolt_diff(@C3, @C4, 't')", 7)...)
asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_34", "SELECT * FROM dolt_diff('--skinny', @C3, @C4, 't')", 7)...)
return asserts
}(),
},
{
Name: "invalid arguments",
SetUpScript: []string{
@@ -129,7 +129,6 @@ var MergeScripts = []queries.ScriptTest{
},
},
},
{
// When there is a constraint violation for duplicate copies of a row in a keyless table, each row
// will violate constraint in exactly the same way. Currently, the dolt_constraint_violations_<table>
@@ -1391,6 +1390,90 @@ var MergeScripts = []queries.ScriptTest{
},
},
},
{
// Customer issue repro: when a merge halts from a conflict, if there was a table rename
// it was preventing users from being able to abort the merge.
Name: "Abort merge with table rename and conflict",
SetUpScript: []string{
"SET @@autocommit=0;",
// Create tables on main
"CREATE TABLE conflict_table (pk int primary key, c1 varchar(100));",
"CREATE TABLE table1 (pk int primary key, c1 varchar(100));",
"CALL dolt_commit('-Am', 'creating tables on main');",
"CALL dolt_branch('branch1');",
"INSERT INTO conflict_table VALUES (1, 'one');",
"CALL dolt_commit('-Am', 'adding another table on main');",
// Rename a table on branch1 and create a conflict to halt the merge
"CALL dolt_checkout('branch1');",
"INSERT INTO conflict_table VALUES (1, 'uno');",
"INSERT INTO table1 VALUES (1, 'one');",
"RENAME TABLE table1 to table2;",
"CALL dolt_commit('-Am', 'renaming table on branch1');",
"CALL dolt_checkout('main');",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_merge('branch1');",
Expected: []sql.Row{{"", 0, 1, "conflicts found"}},
},
{
Query: "CALL dolt_merge('--abort');",
Expected: []sql.Row{{"", 0, 0, "merge aborted"}},
},
},
},
{
// dolt_ignore does not work properly in Doltgres yet, so marking the dialect
// as mysql so that it skips this test.
Dialect: "mysql",
Name: "Abort merge with table rename and conflict (with ignored table)",
SetUpScript: []string{
"SET @@autocommit=0;",
// Set up an ignored table
"INSERT INTO dolt_ignore VALUES ('ignore_me', true);",
"CREATE TABLE ignore_me (pk int primary key, c1 varchar(100));",
"INSERT INTO ignore_me VALUES (1, 'uno');",
// Create tables on main
"CREATE TABLE conflict_table (pk int primary key, c1 varchar(100));",
"CREATE TABLE table1 (pk int primary key, c1 varchar(100));",
"CALL dolt_commit('-Am', 'creating tables on main');",
"CALL dolt_branch('branch1');",
"INSERT INTO conflict_table VALUES (1, 'one');",
"CALL dolt_commit('-Am', 'adding another table on main');",
// Rename a table on branch1 and create a conflict to halt the merge
"CALL dolt_checkout('branch1');",
"INSERT INTO conflict_table VALUES (1, 'uno');",
"INSERT INTO table1 VALUES (1, 'one');",
"RENAME TABLE table1 to table2;",
"CALL dolt_commit('-Am', 'renaming table on branch1');",
"CALL dolt_checkout('main');",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_merge('branch1');",
Expected: []sql.Row{{"", 0, 1, "conflicts found"}},
},
{
Query: "INSERT INTO ignore_me VALUES (2, 'duex');",
Expected: []sql.Row{{types.NewOkResult(1)}},
},
{
Query: "CALL dolt_merge('--abort');",
Expected: []sql.Row{{"", 0, 0, "merge aborted"}},
},
{
Query: "SELECT * FROM ignore_me;",
Expected: []sql.Row{{1, "uno"}, {2, "duex"}},
},
},
},
{
Name: "CALL DOLT_MERGE complains when a merge overrides local changes",
SetUpScript: []string{
@@ -1060,7 +1060,9 @@ func TestDoltIndexBetween(t *testing.T) {
exprs := idx.Expressions()
sqlIndex := sql.NewMySQLIndexBuilder(idx)
for i := range test.greaterThanOrEqual {
sqlIndex = sqlIndex.GreaterOrEqual(ctx, exprs[i], test.greaterThanOrEqual[i]).LessOrEqual(ctx, exprs[i], test.lessThanOrEqual[i])
sqlIndex = sqlIndex.
GreaterOrEqual(ctx, exprs[i], nil, test.greaterThanOrEqual[i]).
LessOrEqual(ctx, exprs[i], nil, test.lessThanOrEqual[i])
}
indexLookup, err := sqlIndex.Build(ctx)
require.NoError(t, err)
@@ -1298,17 +1300,17 @@ func testDoltIndex(t *testing.T, ctx *sql.Context, root doltdb.RootValue, keys [
for i, key := range keys {
switch cmp {
case indexComp_Eq:
builder = builder.Equals(ctx, exprs[i], key)
builder = builder.Equals(ctx, exprs[i], nil, key)
case indexComp_NEq:
builder = builder.NotEquals(ctx, exprs[i], key)
builder = builder.NotEquals(ctx, exprs[i], nil, key)
case indexComp_Gt:
builder = builder.GreaterThan(ctx, exprs[i], key)
builder = builder.GreaterThan(ctx, exprs[i], nil, key)
case indexComp_GtE:
builder = builder.GreaterOrEqual(ctx, exprs[i], key)
builder = builder.GreaterOrEqual(ctx, exprs[i], nil, key)
case indexComp_Lt:
builder = builder.LessThan(ctx, exprs[i], key)
builder = builder.LessThan(ctx, exprs[i], nil, key)
case indexComp_LtE:
builder = builder.LessOrEqual(ctx, exprs[i], key)
builder = builder.LessOrEqual(ctx, exprs[i], nil, key)
default:
panic("should not be hit")
}
@@ -67,6 +67,7 @@ func TestCountAgg(t *testing.T) {
name: "reject multi parameter",
setup: []string{
"create table xy (x int primary key, y int, key y_idx(y))",
"SET SESSION sql_mode = REPLACE(@@SESSION.sql_mode, 'ONLY_FULL_GROUP_BY', '');",
},
query: "select count(y), x from xy",
doRowexec: false,
@@ -84,7 +85,7 @@ func TestCountAgg(t *testing.T) {
setup: []string{
"create table xy (x int primary key, y int, key y_idx(y))",
},
query: "select count(y+1), x from xy",
query: "select count(y+1) from xy",
doRowexec: false,
},
}
@@ -337,7 +337,7 @@ func DoltProceduresGetAll(ctx *sql.Context, db Database, procedureName string) (
if procedureName == "" {
lookup, err = sql.NewMySQLIndexBuilder(idx).IsNotNull(ctx, nameExpr).Build(ctx)
} else {
lookup, err = sql.NewMySQLIndexBuilder(idx).Equals(ctx, nameExpr, procedureName).Build(ctx)
lookup, err = sql.NewMySQLIndexBuilder(idx).Equals(ctx, nameExpr, gmstypes.Text, procedureName).Build(ctx)
}
if err != nil {
return nil, err
@@ -471,7 +471,9 @@ func DoltProceduresGetDetails(ctx *sql.Context, tbl *WritableDoltTable, name str
return sql.StoredProcedureDetails{}, false, fmt.Errorf("could not find primary key index on system table `%s`", doltdb.ProceduresTableName)
}
indexLookup, err := sql.NewMySQLIndexBuilder(fragNameIndex).Equals(ctx, fragNameIndex.Expressions()[0], name).Build(ctx)
indexLookup, err := sql.NewMySQLIndexBuilder(fragNameIndex).
Equals(ctx, fragNameIndex.Expressions()[0], gmstypes.Text, name).
Build(ctx)
if err != nil {
return sql.StoredProcedureDetails{}, false, err
}
+1 -1
View File
@@ -541,7 +541,7 @@ func interfaceValueAsSqlString(ctx *sql.Context, ti typeinfo.TypeInfo, value int
return singleQuote + str + singleQuote, nil
case typeinfo.DatetimeTypeIdentifier:
return singleQuote + str + singleQuote, nil
case typeinfo.InlineBlobTypeIdentifier, typeinfo.VarBinaryTypeIdentifier:
case typeinfo.InlineBlobTypeIdentifier, typeinfo.VarBinaryTypeIdentifier, typeinfo.VectorTypeIdentifier:
value, err := sql.UnwrapAny(ctx, value)
if err != nil {
return "", err
+2
View File
@@ -118,6 +118,8 @@ func GetSubtrees(msg serial.Message) ([]uint64, error) {
switch id {
case serial.ProllyTreeNodeFileID:
return getProllyMapSubtrees(msg)
case serial.VectorIndexNodeFileID:
return getVectorIndexSubtrees(msg)
case serial.AddressMapFileID:
return getAddressMapSubtrees(msg)
case serial.MergeArtifactsFileID:
+101 -101
View File
@@ -16,6 +16,7 @@ package prolly
import (
"context"
"fmt"
"io"
"iter"
@@ -138,33 +139,50 @@ func (p *proximityMapIter) Next(ctx context.Context) (k val.Tuple, v val.Tuple,
return
}
// getConvertToVectorFunction returns a callback that decodes the vector stored
// in the first column of a key tuple, chosen by that column's encoding:
// JSON-address columns are dereferenced through the node store and converted
// from an indexed JSON document, while adaptive byte columns are read directly.
// Any other encoding yields an error.
func getConvertToVectorFunction(keyDesc val.TupleDesc, ns tree.NodeStore) (tree.ConvertToVectorFunction, error) {
	enc := keyDesc.Types[0].Enc
	if enc == val.JSONAddrEnc {
		fromJSONAddr := func(ctx context.Context, keyBytes []byte) ([]float32, error) {
			addr, _ := keyDesc.GetJSONAddr(0, keyBytes)
			wrapped, err := tree.NewJSONDoc(addr, ns).ToIndexedJSONDocument(ctx)
			if err != nil {
				return nil, err
			}
			return sql.ConvertToVector(ctx, wrapped)
		}
		return fromJSONAddr, nil
	}
	if enc == val.BytesAdaptiveEnc {
		fromAdaptiveBytes := func(ctx context.Context, keyBytes []byte) ([]float32, error) {
			raw, _, err := keyDesc.GetBytesAdaptiveValue(ctx, 0, ns, keyBytes)
			if err != nil {
				return nil, err
			}
			return sql.ConvertToVector(ctx, raw)
		}
		return fromAdaptiveBytes, nil
	}
	return nil, fmt.Errorf("unexpected encoding for vector index: %v", enc)
}
// NewProximityMap creates a new ProximityMap from a supplied root node.
func NewProximityMap(ns tree.NodeStore, node tree.Node, keyDesc val.TupleDesc, valDesc val.TupleDesc, distanceType vector.DistanceType, logChunkSize uint8) ProximityMap {
func NewProximityMap(ns tree.NodeStore, node tree.Node, keyDesc val.TupleDesc, valDesc val.TupleDesc, distanceType vector.DistanceType, logChunkSize uint8) (ProximityMap, error) {
convertFunc, err := getConvertToVectorFunction(keyDesc, ns)
if err != nil {
return ProximityMap{}, err
}
tuples := tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]{
Root: node,
NodeStore: ns,
Order: keyDesc,
DistanceType: distanceType,
Convert: func(ctx context.Context, bytes []byte) []float64 {
h, _ := keyDesc.GetJSONAddr(0, bytes)
doc := tree.NewJSONDoc(h, ns)
jsonWrapper, err := doc.ToIndexedJSONDocument(ctx)
if err != nil {
panic(err)
}
floats, err := sql.ConvertToVector(ctx, jsonWrapper)
if err != nil {
panic(err)
}
return floats
},
Convert: convertFunc,
}
return ProximityMap{
tuples: tuples,
keyDesc: keyDesc,
valDesc: valDesc,
logChunkSize: logChunkSize,
}
}, nil
}
var proximitylevelMapKeyDesc = val.NewTupleDescriptor(
@@ -180,6 +198,10 @@ func NewProximityMapBuilder(ctx context.Context, ns tree.NodeStore, distanceType
return ProximityMapBuilder{}, err
}
mutableLevelMap := newMutableMap(emptyLevelMap)
convertFunc, err := getConvertToVectorFunction(keyDesc, ns)
if err != nil {
return ProximityMapBuilder{}, err
}
return ProximityMapBuilder{
ns: ns,
vectorIndexSerializer: message.NewVectorIndexSerializer(ns.Pool(), logChunkSize, distanceType),
@@ -189,6 +211,7 @@ func NewProximityMapBuilder(ctx context.Context, ns tree.NodeStore, distanceType
logChunkSize: logChunkSize,
maxLevel: 0,
levelMap: mutableLevelMap,
convertFunc: convertFunc,
}, nil
}
@@ -211,7 +234,7 @@ func NewProximityMapBuilder(ctx context.Context, ns tree.NodeStore, distanceType
//
// Step 2: Create `pathMaps`, a list of maps, each corresponding to a different level of the ProximityMap
//
// The pathMap at depth `i` has the schema (vectorAddrs[1]...vectorAddr[i], keyBytes) -> value
// The pathMap at depth `i` has the schema (vectorAddrs[0], ..., vectorAddr[i], keyBytes) -> value
// and contains a row for every vector whose maximum depth is i.
// - vectorAddrs: the path of vectors visited when walking from the root to the maximum depth where the vector appears.
// - keyBytes: a bytestring containing the bytes of the ProximityMap key (which includes the vector)
@@ -232,11 +255,12 @@ type ProximityMapBuilder struct {
vectorIndexSerializer message.VectorIndexSerializer
ns tree.NodeStore
distanceType vector.DistanceType
levelMap *MutableMap
keyDesc val.TupleDesc
valDesc val.TupleDesc
logChunkSize uint8
maxLevel uint8
levelMap *MutableMap
convertFunc tree.ConvertToVectorFunction
}
// Insert adds a new key-value pair to the ProximityMap under construction.
@@ -299,7 +323,7 @@ func (b *ProximityMapBuilder) makeRootNode(ctx context.Context, keys, values [][
return ProximityMap{}, err
}
return NewProximityMap(b.ns, rootNode, b.keyDesc, b.valDesc, b.distanceType, b.logChunkSize), nil
return NewProximityMap(b.ns, rootNode, b.keyDesc, b.valDesc, b.distanceType, b.logChunkSize)
}
// Flush finishes constructing a ProximityMap. Call this after all calls to Insert.
@@ -371,7 +395,7 @@ func (b *ProximityMapBuilder) makePathMaps(ctx context.Context, mutableLevelMap
// Create every val.TupleBuilder and MutableMap that we will need
// pathMaps[i] is the pathMap for level i (and depth maxLevel - i)
pathMaps, keyTupleBuilder, prefixTupleBuilder, err := b.createInitialPathMaps(ctx, maxLevel)
pathMaps, keyTupleBuilder, err := b.createInitialPathMaps(ctx, maxLevel)
// Next, visit each key-value pair in decreasing order of level / increasing order of depth.
// When visiting a pair from depth `i`, we use each of the previous `i` pathMaps to compute a path of `i` index keys.
@@ -382,55 +406,53 @@ func (b *ProximityMapBuilder) makePathMaps(ctx context.Context, mutableLevelMap
depth := int(maxLevel - level)
// hashPath is a list of concatenated hashes, representing the sequence of closest vectors at each level of the tree.
var hashPath []byte
keyToInsert, _ := mutableLevelMap.keyDesc.GetBytes(1, levelMapKey)
vectorHashToInsert, _ := b.keyDesc.GetJSONAddr(0, keyToInsert)
vectorToInsert, err := getVectorFromHash(ctx, b.ns, vectorHashToInsert)
vectorToInsert, err := b.convertFunc(ctx, keyToInsert)
if err != nil {
return nil, err
}
// Compute the path that this row will have in the vector index, starting at the root.
// A key-value pair at depth D will have a path D prior keys.
// This path is computed in steps, by performing a lookup in each of the prior pathMaps.
// Each iteration sets another column in |keyTupleBuilder|, then does a prefix lookup in the next pathMap
// with all currently set columns.
for pathDepth := 0; pathDepth < depth; pathDepth++ {
lookupLevel := int(maxLevel) - pathDepth
pathMap := pathMaps[lookupLevel]
pathMapIter, err := b.getNextPathSegmentCandidates(ctx, pathMap, prefixTupleBuilder, hashPath)
pathMapIter, err := b.getNextPathSegmentCandidates(ctx, pathMap, keyTupleBuilder.Desc.PrefixDesc(pathDepth), keyTupleBuilder.BuildPrefixNoRecycle(b.ns.Pool(), pathDepth))
if err != nil {
return nil, err
}
// Create an iterator that yields every candidate vector
nextCandidate, stopIter := iter.Pull2(func(yield func(hash.Hash, error) bool) {
nextCandidate, stopIter := iter.Pull2(func(yield func([]byte, error) bool) {
for {
pathMapKey, _, err := pathMapIter.Next(ctx)
if err == io.EOF {
return
}
if err != nil {
yield(hash.Hash{}, err)
yield(nil, err)
}
originalKey, _ := pathMap.keyDesc.GetBytes(1, pathMapKey)
candidateVectorHash, _ := b.keyDesc.GetJSONAddr(0, originalKey)
yield(candidateVectorHash, nil)
originalKey, _ := pathMap.keyDesc.GetBytes(pathDepth, pathMapKey)
yield(originalKey, nil)
}
})
defer stopIter()
closestVectorHash, _ := b.getClosestVector(ctx, vectorToInsert, nextCandidate)
closestVectorEncoding, _, err := b.getClosestVector(ctx, vectorToInsert, nextCandidate)
if err != nil {
return nil, err
}
hashPath = append(hashPath, closestVectorHash[:]...)
keyTupleBuilder.PutByteString(pathDepth, closestVectorEncoding)
}
// Once we have the path for this key, we turn it into a tuple and add it to the next pathMap.
keyTupleBuilder.PutByteString(0, hashPath)
keyTupleBuilder.PutByteString(1, keyToInsert)
keyTupleBuilder.PutByteString(depth, keyToInsert)
keyTuple, err := keyTupleBuilder.Build(b.ns.Pool())
if err != nil {
return nil, err
}
keyTuple := keyTupleBuilder.BuildPrefixNoRecycle(b.ns.Pool(), depth+1)
err = pathMaps[level].Put(ctx, keyTuple, levelMapValue)
if err != nil {
return nil, err
@@ -441,14 +463,9 @@ func (b *ProximityMapBuilder) makePathMaps(ctx context.Context, mutableLevelMap
childLevel := level - 1
if level > 0 {
for {
hashPath = append(hashPath, vectorHashToInsert[:]...)
keyTupleBuilder.PutByteString(0, hashPath)
keyTupleBuilder.PutByteString(1, keyToInsert)
childKeyTuple, err := keyTupleBuilder.Build(b.ns.Pool())
if err != nil {
return nil, err
}
depth++
keyTupleBuilder.PutByteString(depth, keyToInsert)
childKeyTuple := keyTupleBuilder.BuildPrefixNoRecycle(b.ns.Pool(), depth+1)
err = pathMaps[childLevel].Put(ctx, childKeyTuple, levelMapValue)
if err != nil {
return nil, err
@@ -472,79 +489,75 @@ func (b *ProximityMapBuilder) makePathMaps(ctx context.Context, mutableLevelMap
}
// createInitialPathMaps creates a list of MutableMaps that will eventually store a single level of the to-be-built ProximityMap
func (b *ProximityMapBuilder) createInitialPathMaps(ctx context.Context, maxLevel uint8) (pathMaps []*MutableMap, keyTupleBuilder, prefixTupleBuilder *val.TupleBuilder, err error) {
func (b *ProximityMapBuilder) createInitialPathMaps(ctx context.Context, maxLevel uint8) (pathMaps []*MutableMap, keyTupleBuilder *val.TupleBuilder, err error) {
pathMaps = make([]*MutableMap, maxLevel+1)
pathMapKeyDescTypes := []val.Type{{Enc: val.ByteStringEnc, Nullable: false}, {Enc: val.ByteStringEnc, Nullable: false}}
pathMapKeyDesc := val.NewTupleDescriptor(pathMapKeyDescTypes...)
emptyPathMap, err := NewMapFromTuples(ctx, b.ns, pathMapKeyDesc, b.valDesc)
keyTupleBuilder = val.NewTupleBuilder(pathMapKeyDesc, b.ns)
prefixTupleBuilder = val.NewTupleBuilder(val.NewTupleDescriptor(pathMapKeyDescTypes[0]), b.ns)
for i := uint8(0); i <= maxLevel; i++ {
if err != nil {
return nil, nil, nil, err
}
pathMaps[i] = newMutableMap(emptyPathMap)
pathMapKeyDescTypes := make([]val.Type, maxLevel+1)
for i := range pathMapKeyDescTypes {
pathMapKeyDescTypes[i] = val.Type{Enc: val.ByteStringEnc, Nullable: false}
}
return pathMaps, keyTupleBuilder, prefixTupleBuilder, nil
for level := uint8(0); level <= maxLevel; level++ {
depth := maxLevel - level
pathMapKeyDesc := val.NewTupleDescriptor(pathMapKeyDescTypes[:depth+1]...)
emptyPathMap, err := NewMapFromTuples(ctx, b.ns, pathMapKeyDesc, b.valDesc)
if err != nil {
return nil, nil, err
}
pathMaps[level] = newMutableMap(emptyPathMap)
}
keyTupleBuilder = val.NewTupleBuilder(val.NewTupleDescriptor(pathMapKeyDescTypes...), b.ns)
return pathMaps, keyTupleBuilder, nil
}
// getNextPathSegmentCandidates takes a list of keys, representing a path into the ProximityMap from the root.
// It returns an iter over all possible keys that could be the next path segment.
func (b *ProximityMapBuilder) getNextPathSegmentCandidates(ctx context.Context, pathMap *MutableMap, prefixTupleBuilder *val.TupleBuilder, currentPath []byte) (MapIter, error) {
prefixTupleBuilder.PutByteString(0, currentPath)
prefixTuple, err := prefixTupleBuilder.Build(b.ns.Pool())
if err != nil {
return nil, err
}
prefixRange := PrefixRange(ctx, prefixTuple, prefixTupleBuilder.Desc)
func (b *ProximityMapBuilder) getNextPathSegmentCandidates(ctx context.Context, pathMap *MutableMap, prefixTupleDesc val.TupleDesc, prefixTuple val.Tuple) (MapIter, error) {
prefixRange := PrefixRange(ctx, prefixTuple, prefixTupleDesc)
return pathMap.IterRange(ctx, prefixRange)
}
// getClosestVector iterates over a range of candidate vectors to determine which one is the closest to the target.
func (b *ProximityMapBuilder) getClosestVector(ctx context.Context, targetVector []float64, nextCandidate func() (candidate hash.Hash, err error, valid bool)) (hash.Hash, error) {
func (b *ProximityMapBuilder) getClosestVector(ctx context.Context, targetVector []float32, nextCandidate func() (candidate []byte, err error, valid bool)) (closestVectorEncoding []byte, closestVector []float32, err error) {
// First call to nextCandidate is guaranteed to be valid because there's at least one vector in the set.
// (non-root nodes inherit the first vector from their parent)
candidateVectorHash, err, _ := nextCandidate()
closestVectorEncoding, err, _ = nextCandidate()
if err != nil {
return hash.Hash{}, err
return nil, nil, err
}
closestVector, err = b.convertFunc(ctx, closestVectorEncoding)
if err != nil {
return nil, nil, err
}
candidateVector, err := getVectorFromHash(ctx, b.ns, candidateVectorHash)
closestDistance, err := b.distanceType.Eval(targetVector, closestVector)
if err != nil {
return hash.Hash{}, err
}
closestVectorHash := candidateVectorHash
closestDistance, err := b.distanceType.Eval(targetVector, candidateVector)
if err != nil {
return hash.Hash{}, err
return nil, nil, err
}
for {
candidateVectorHash, err, valid := nextCandidate()
candidateVectorEncoding, err, valid := nextCandidate()
if err != nil {
return hash.Hash{}, err
return nil, nil, err
}
if !valid {
return closestVectorHash, nil
return closestVectorEncoding, closestVector, nil
}
candidateVector, err = getVectorFromHash(ctx, b.ns, candidateVectorHash)
candidateVector, err := b.convertFunc(ctx, candidateVectorEncoding)
if err != nil {
return hash.Hash{}, err
return nil, nil, err
}
candidateDistance, err := b.distanceType.Eval(targetVector, candidateVector)
if err != nil {
return hash.Hash{}, err
return nil, nil, err
}
if candidateDistance < closestDistance {
closestVectorHash = candidateVectorHash
closestVector = candidateVector
closestVectorEncoding = candidateVectorEncoding
closestDistance = candidateDistance
}
}
@@ -580,9 +593,8 @@ func (b *ProximityMapBuilder) makeProximityMapFromPathMaps(ctx context.Context,
if err != nil {
return ProximityMap{}, err
}
originalKey, _ := rootPathMap.keyDesc.GetBytes(1, key)
path, _ := b.keyDesc.GetJSONAddr(0, originalKey)
_, nodeCount, nodeHash, err := chunker.Next(ctx, b.ns, b.vectorIndexSerializer, path, maxLevel-1, 1, b.keyDesc)
originalKey, _ := rootPathMap.keyDesc.GetBytes(0, key)
_, nodeCount, nodeHash, err := chunker.Next(ctx, b.ns, b.vectorIndexSerializer, originalKey, maxLevel-1, 1, b.keyDesc)
if err != nil {
return ProximityMap{}, err
}
@@ -592,15 +604,3 @@ func (b *ProximityMapBuilder) makeProximityMapFromPathMaps(ctx context.Context,
}
return b.makeRootNode(ctx, topLevelKeys, topLevelValues, topLevelSubtrees, maxLevel)
}
func getJsonValueFromHash(ctx context.Context, ns tree.NodeStore, h hash.Hash) (sql.JSONWrapper, error) {
return tree.NewJSONDoc(h, ns).ToIndexedJSONDocument(ctx)
}
func getVectorFromHash(ctx context.Context, ns tree.NodeStore, h hash.Hash) ([]float64, error) {
otherValue, err := getJsonValueFromHash(ctx, ns, h)
if err != nil {
return nil, err
}
return sql.ConvertToVector(ctx, otherValue)
}
File diff suppressed because it is too large Load Diff
+30 -20
View File
@@ -18,7 +18,6 @@ import (
"context"
"fmt"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression/function/vector"
"github.com/dolthub/dolt/go/gen/fb/serial"
@@ -56,18 +55,9 @@ func (f ProximityFlusher) ApplyMutationsWithSerializer(
keyDesc := mutableMap.keyDesc
valDesc := mutableMap.valDesc
ns := mutableMap.NodeStore()
convert := func(ctx context.Context, bytes []byte) []float64 {
h, _ := keyDesc.GetJSONAddr(0, bytes)
doc := tree.NewJSONDoc(h, ns)
jsonWrapper, err := doc.ToIndexedJSONDocument(ctx)
if err != nil {
panic(err)
}
floats, err := sql.ConvertToVector(ctx, jsonWrapper)
if err != nil {
panic(err)
}
return floats
convertFunc, err := getConvertToVectorFunction(keyDesc, ns)
if err != nil {
return tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]{}, err
}
edits := make([]VectorIndexKV, 0, mutableMap.tuples.Edits.Count())
editIter := mutableMap.tuples.Mutations()
@@ -86,7 +76,6 @@ func (f ProximityFlusher) ApplyMutationsWithSerializer(
mutation = editIter.NextMutation(ctx)
}
var newRoot tree.Node
var err error
root := mutableMap.tuples.Static.Root
distanceType := mutableMap.tuples.Static.DistanceType
if root.Count() == 0 {
@@ -96,7 +85,11 @@ func (f ProximityFlusher) ApplyMutationsWithSerializer(
// The root node has changed, or there may be a new level to the tree. We need to rebuild the tree.
newRoot, _, err = f.rebuildNode(ctx, ns, root, edits, distanceType, keyDesc, valDesc, maxEditLevel)
} else {
newRoot, _, err = f.visitNode(ctx, serializer, ns, root, edits, convert, distanceType, keyDesc, valDesc)
root, err = root.LoadSubtrees()
if err != nil {
return tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]{}, err
}
newRoot, _, err = f.visitNode(ctx, serializer, ns, root, edits, convertFunc, distanceType, keyDesc, valDesc)
}
if err != nil {
@@ -106,7 +99,7 @@ func (f ProximityFlusher) ApplyMutationsWithSerializer(
Root: newRoot,
NodeStore: ns,
DistanceType: distanceType,
Convert: convert,
Convert: convertFunc,
Order: keyDesc,
}, nil
}
@@ -161,7 +154,7 @@ func (f ProximityFlusher) visitNode(
ns tree.NodeStore,
node tree.Node,
edits []VectorIndexKV,
convert func(context.Context, []byte) []float64,
convert tree.ConvertToVectorFunction,
distanceType vector.DistanceType,
keyDesc val.TupleDesc,
valDesc val.TupleDesc,
@@ -177,18 +170,29 @@ func (f ProximityFlusher) visitNode(
childEdits := make(map[int]childEditList)
for _, edit := range edits {
key := edit.key
editVector := convert(ctx, key)
editVector, err := convert(ctx, key)
if err != nil {
return tree.Node{}, 0, err
}
level := edit.level
// visit each child in the node to determine which is closest
closestIdx := 0
childKey := node.GetKey(0)
closestDistance, err := distanceType.Eval(convert(ctx, childKey), editVector)
childVector, err := convert(ctx, childKey)
if err != nil {
return tree.Node{}, 0, err
}
closestDistance, err := distanceType.Eval(childVector, editVector)
if err != nil {
return tree.Node{}, 0, err
}
for i := 1; i < node.Count(); i++ {
childKey = node.GetKey(i)
newDistance, err := distanceType.Eval(convert(ctx, childKey), editVector)
childVector, err = convert(ctx, childKey)
if err != nil {
return tree.Node{}, 0, err
}
newDistance, err := distanceType.Eval(childVector, editVector)
if err != nil {
return tree.Node{}, 0, err
}
@@ -215,6 +219,8 @@ func (f ProximityFlusher) visitNode(
if len(childEditList.edits) == 0 {
// No edits affected this node, leave it as is.
values = append(values, childValue)
childSubtrees := node.GetSubtreeCount(i)
nodeSubtrees = append(nodeSubtrees, uint64(childSubtrees))
} else {
childNodeAddress := hash.New(childValue)
childNode, err := ns.Read(ctx, childNodeAddress)
@@ -226,6 +232,10 @@ func (f ProximityFlusher) visitNode(
if childEditList.mustRebuild {
newChildNode, childSubtrees, err = f.rebuildNode(ctx, ns, childNode, childEditList.edits, distanceType, keyDesc, valDesc, uint8(childNode.Level()))
} else {
childNode, err = childNode.LoadSubtrees()
if err != nil {
return tree.Node{}, 0, err
}
newChildNode, childSubtrees, err = f.visitNode(ctx, serializer, ns, childNode, childEditList.edits, convert, distanceType, keyDesc, valDesc)
}
+2 -2
View File
@@ -67,7 +67,7 @@ func MapInterfaceFromValue(ctx context.Context, v types.Value, sch schema.Schema
// TODO: We should read the distance function and chunk size from the message.
// Currently, vector.DistanceL2Squared{} and prolly.DefaultLogChunkSize are the only values that can be written,
// but this may not be true in the future.
return prolly.NewProximityMap(ns, root, kd, vd, vector.DistanceL2Squared{}, prolly.DefaultLogChunkSize), nil
return prolly.NewProximityMap(ns, root, kd, vd, vector.DistanceL2Squared{}, prolly.DefaultLogChunkSize)
default:
return prolly.NewMap(root, ns, kd, vd), nil
}
@@ -83,7 +83,7 @@ func MapFromValueWithDescriptors(v types.Value, kd, vd val.TupleDesc, ns tree.No
// TODO: We should read the distance function and chunk size from the message.
// Currently, vector.DistanceL2Squared{} and prolly.DefaultLogChunkSize are the only values that can be written,
// but this may not be true in the future.
return prolly.NewProximityMap(ns, root, kd, vd, vector.DistanceL2Squared{}, prolly.DefaultLogChunkSize), nil
return prolly.NewProximityMap(ns, root, kd, vd, vector.DistanceL2Squared{}, prolly.DefaultLogChunkSize)
default:
return prolly.NewMap(root, ns, kd, vd), nil
}
+2 -3
View File
@@ -188,11 +188,10 @@ func TestWriteImmutableTree(t *testing.T) {
assert.Equal(t, tt.inputSize, byteCnt)
assert.Equal(t, expUnfilled, unfilledCnt)
if expLevel > 0 {
root, err = root.loadSubtrees()
root, err = root.LoadSubtrees()
require.NoError(t, err)
for i := range expSubtrees {
sc, err := root.getSubtreeCount(i)
require.NoError(t, err)
sc := root.GetSubtreeCount(i)
assert.Equal(t, expSubtrees[i], sc)
}
}
+2 -5
View File
@@ -290,15 +290,12 @@ func insertNode[K ~[]byte, S message.Serializer, O Ordering[K]](ctx context.Cont
}
}
} else {
nd, err = nd.loadSubtrees()
nd, err = nd.LoadSubtrees()
if err != nil {
return err
}
for i := 0; i < nd.Count(); i++ {
subtreeCount, err := nd.getSubtreeCount(i)
if err != nil {
return err
}
subtreeCount := nd.GetSubtreeCount(i)
err = insertNode[K, S, O](ctx, tc, nil, K(nd.GetKey(i)), nd.getAddress(i), subtreeCount, level-1, order)
if err != nil {
return err
+4 -4
View File
@@ -160,7 +160,7 @@ func (nd Node) GetValue(i int) Item {
return nd.values.GetItem(i, nd.msg)
}
func (nd Node) loadSubtrees() (Node, error) {
func (nd Node) LoadSubtrees() (Node, error) {
var err error
if nd.subtrees == nil {
// deserializing subtree counts requires a malloc,
@@ -174,12 +174,12 @@ func (nd Node) loadSubtrees() (Node, error) {
return nd, err
}
func (nd Node) getSubtreeCount(i int) (uint64, error) {
func (nd Node) GetSubtreeCount(i int) uint64 {
if nd.IsLeaf() {
return 1, nil
return 1
}
// this will panic unless subtrees were loaded.
return (*nd.subtrees)[i], nil
return (*nd.subtrees)[i]
}
// getAddress returns the |ith| address of this node.
+8 -14
View File
@@ -111,10 +111,10 @@ func newCursorAtOrdinal(ctx context.Context, ns NodeStore, nd Node, ord uint64)
if nd.IsLeaf() {
return int(distance)
}
nd, _ = nd.loadSubtrees()
nd, _ = nd.LoadSubtrees()
for idx = 0; idx < nd.Count(); idx++ {
cnt, _ := nd.getSubtreeCount(idx)
cnt := nd.GetSubtreeCount(idx)
card := int64(cnt)
if (distance - card) < 0 {
break
@@ -144,16 +144,13 @@ func getOrdinalOfCursor(curr *cursor) (ord uint64, err error) {
return 0, fmt.Errorf("found invalid parent cursor behind node start")
}
curr.nd, err = curr.nd.loadSubtrees()
curr.nd, err = curr.nd.LoadSubtrees()
if err != nil {
return 0, err
}
for idx := curr.idx - 1; idx >= 0; idx-- {
cnt, err := curr.nd.getSubtreeCount(idx)
if err != nil {
return 0, err
}
cnt := curr.nd.GetSubtreeCount(idx)
ord += cnt
}
}
@@ -289,15 +286,12 @@ func recursiveFetchLeafNodeSpan(ctx context.Context, ns NodeStore, nodes []Node,
var err error
for _, nd := range nodes {
if nd, err = nd.loadSubtrees(); err != nil {
if nd, err = nd.LoadSubtrees(); err != nil {
return nil, 0, err
}
for i := 0; i < nd.Count(); i++ {
card, err := nd.getSubtreeCount(i)
if err != nil {
return nil, 0, err
}
card := nd.GetSubtreeCount(i)
if acc == 0 && card < start {
start -= card
@@ -388,11 +382,11 @@ func (cur *cursor) currentSubtreeSize() (uint64, error) {
return 1, nil
}
var err error
cur.nd, err = cur.nd.loadSubtrees()
cur.nd, err = cur.nd.LoadSubtrees()
if err != nil {
return 0, err
}
return cur.nd.getSubtreeCount(cur.idx)
return cur.nd.GetSubtreeCount(cur.idx), nil
}
func (cur *cursor) firstKey() Item {
+1 -1
View File
@@ -418,7 +418,7 @@ func (td *PatchGenerator[K, O]) split(ctx context.Context) (patch Patch, diffTyp
if err != nil {
return Patch{}, NoDiff, err
}
toChild, err = toChild.loadSubtrees()
toChild, err = toChild.LoadSubtrees()
if err != nil {
return Patch{}, NoDiff, err
}
+1 -1
View File
@@ -137,7 +137,7 @@ func GetField(ctx context.Context, td val.TupleDesc, i int, tup val.Tuple, ns No
v = val.NewTextStorage(ctx, h, ns)
}
case val.BytesAdaptiveEnc:
v, ok, err = td.GetBytesAdaptiveValue(i, ns, tup)
v, ok, err = td.GetBytesAdaptiveValue(ctx, i, ns, tup)
case val.StringAdaptiveEnc:
v, ok, err = td.GetStringAdaptiveValue(i, ns, tup)
case val.CommitAddrEnc:
+22 -17
View File
@@ -29,13 +29,15 @@ import (
type KeyValueDistanceFn[K, V ~[]byte] func(key K, value V, distance float64) error
type ConvertToVectorFunction func(context.Context, []byte) ([]float32, error)
// ProximityMap is a static Prolly Tree where the position of a key in the tree is based on proximity, as opposed to a traditional ordering.
// O provides the ordering only within a node.
type ProximityMap[K, V ~[]byte, O Ordering[K]] struct {
NodeStore NodeStore
DistanceType vector.DistanceType
Order O
Convert func(context.Context, []byte) []float64
Convert ConvertToVectorFunction
Root Node
}
@@ -94,7 +96,10 @@ func (t ProximityMap[K, V, O]) WalkNodes(ctx context.Context, cb NodeCb) error {
func (t ProximityMap[K, V, O]) Get(ctx context.Context, query K, cb KeyValueFn[K, V]) (err error) {
nd := t.Root
queryVector := t.Convert(ctx, query)
queryVector, err := t.Convert(ctx, query)
if err != nil {
return err
}
// Find the child with the minimum distance.
@@ -105,7 +110,11 @@ func (t ProximityMap[K, V, O]) Get(ctx context.Context, query K, cb KeyValueFn[K
for i := 0; i < int(nd.count); i++ {
k := nd.GetKey(i)
newDistance, err := t.DistanceType.Eval(t.Convert(ctx, k), queryVector)
vec, err := t.Convert(ctx, k)
if err != nil {
return err
}
newDistance, err := t.DistanceType.Eval(vec, queryVector)
if err != nil {
return err
}
@@ -201,7 +210,11 @@ func (t ProximityMap[K, V, O]) GetClosest(ctx context.Context, query interface{}
for i := 0; i < int(t.Root.count); i++ {
k := t.Root.GetKey(i)
newDistance, err := t.DistanceType.Eval(t.Convert(ctx, k), queryVector)
vec, err := t.Convert(ctx, k)
if err != nil {
return err
}
newDistance, err := t.DistanceType.Eval(vec, queryVector)
if err != nil {
return err
}
@@ -222,7 +235,11 @@ func (t ProximityMap[K, V, O]) GetClosest(ctx context.Context, query interface{}
// TODO: We don't need to recompute the distance when visiting the same key as the parent.
for i := 0; i < int(node.count); i++ {
k := node.GetKey(i)
newDistance, err := t.DistanceType.Eval(t.Convert(ctx, k), queryVector)
vec, err := t.Convert(ctx, k)
if err != nil {
return err
}
newDistance, err := t.DistanceType.Eval(vec, queryVector)
if err != nil {
return err
}
@@ -265,15 +282,3 @@ func (t ProximityMap[K, V, O]) IterAll(ctx context.Context) (*OrderedTreeIter[K,
return &OrderedTreeIter[K, V]{curr: c, stop: stop, step: c.advance}, nil
}
func getJsonValueFromHash(ctx context.Context, ns NodeStore, h hash.Hash) (interface{}, error) {
return NewJSONDoc(h, ns).ToIndexedJSONDocument(ctx)
}
func getVectorFromHash(ctx context.Context, ns NodeStore, h hash.Hash) ([]float64, error) {
otherValue, err := getJsonValueFromHash(ctx, ns, h)
if err != nil {
return nil, err
}
return sql.ConvertToVector(ctx, otherValue)
}
+2 -3
View File
@@ -81,11 +81,10 @@ func histLevelCount(t *testing.T, nodes []Node) int {
case 0:
cnt += n.Count()
default:
n, err := n.loadSubtrees()
n, err := n.LoadSubtrees()
require.NoError(t, err)
for i := 0; i < n.Count(); i++ {
subCnt, err := n.getSubtreeCount(i)
require.NoError(t, err)
subCnt := n.GetSubtreeCount(i)
cnt += int(subCnt)
}
}
+22 -12
View File
@@ -15,6 +15,7 @@
package prolly
import (
"bytes"
"context"
"io"
@@ -27,6 +28,7 @@ import (
// vectorIndexChunker is a stateful chunker that iterates over |pathMap|, a map that contains an element
// for every key-value pair for a given level of a ProximityMap, and provides the path of keys to reach
// that pair from the root. It uses this iterator to build each of the ProximityMap nodes for that level.
// A linked list of N vectorIndexChunkers is used in order to build a vector index with N levels.
type vectorIndexChunker struct {
pathMapIter MapIter
pathMap *MutableMap
@@ -34,8 +36,10 @@ type vectorIndexChunker struct {
lastKey []byte
lastValue []byte
lastSubtreeCount uint64
lastPathSegment hash.Hash
atEnd bool
// lastPathSegment is the last observed parent key. When Next() is called with a different value for |parentPathSegment|,
// we know that the chunker needs to end the previous chunk and start a new one.
lastPathSegment []byte
atEnd bool
}
func newVectorIndexChunker(ctx context.Context, pathMap *MutableMap, childChunker *vectorIndexChunker) (*vectorIndexChunker, error) {
@@ -56,9 +60,8 @@ func newVectorIndexChunker(ctx context.Context, pathMap *MutableMap, childChunke
if err != nil {
return nil, err
}
path, _ := pathMap.keyDesc.GetBytes(0, firstKey)
lastPathSegment := hash.New(path[len(path)-20:])
originalKey, _ := pathMap.keyDesc.GetBytes(1, firstKey)
lastPathSegment, _ := pathMap.keyDesc.GetBytes(pathMap.keyDesc.Count()-2, firstKey)
originalKey, _ := pathMap.keyDesc.GetBytes(pathMap.keyDesc.Count()-1, firstKey)
return &vectorIndexChunker{
pathMap: pathMap,
pathMapIter: pathMapIter,
@@ -70,14 +73,15 @@ func newVectorIndexChunker(ctx context.Context, pathMap *MutableMap, childChunke
}, nil
}
func (c *vectorIndexChunker) Next(ctx context.Context, ns tree.NodeStore, serializer message.VectorIndexSerializer, parentPathSegment hash.Hash, level, depth int, originalKeyDesc val.TupleDesc) (tree.Node, uint64, hash.Hash, error) {
// Next produces the next tree node for the corresponding level of the tree.
func (c *vectorIndexChunker) Next(ctx context.Context, ns tree.NodeStore, serializer message.VectorIndexSerializer, parentPathSegment []byte, level, depth int, originalKeyDesc val.TupleDesc) (tree.Node, uint64, hash.Hash, error) {
var indexMapKeys [][]byte
var indexMapValues [][]byte
var indexMapSubtrees []uint64
subtreeSum := uint64(0)
for {
if c.atEnd || c.lastPathSegment != parentPathSegment {
if c.atEnd || !bytes.Equal(c.lastPathSegment, parentPathSegment) {
msg := serializer.Serialize(indexMapKeys, indexMapValues, indexMapSubtrees, level)
node, _, err := tree.NodeFromBytes(msg)
if err != nil {
@@ -86,9 +90,10 @@ func (c *vectorIndexChunker) Next(ctx context.Context, ns tree.NodeStore, serial
nodeHash, err := ns.Write(ctx, node)
return node, subtreeSum, nodeHash, err
}
vectorHash, _ := originalKeyDesc.GetJSONAddr(0, c.lastKey)
if c.childChunker != nil {
_, childCount, nodeHash, err := c.childChunker.Next(ctx, ns, serializer, vectorHash, level-1, depth+1, originalKeyDesc)
// This chunker isn't chunking a leaf node. To insert the next key-value pair, we call Next() on the child chunker, which produces
// a node one level down, that will be pointed to by this node.
_, childCount, nodeHash, err := c.childChunker.Next(ctx, ns, serializer, c.lastKey, level-1, depth+1, originalKeyDesc)
if err != nil {
return tree.Node{}, 0, hash.Hash{}, err
}
@@ -107,9 +112,14 @@ func (c *vectorIndexChunker) Next(ctx context.Context, ns tree.NodeStore, serial
} else if err != nil {
return tree.Node{}, 0, hash.Hash{}, err
} else {
lastPath, _ := c.pathMap.keyDesc.GetBytes(0, nextKey)
c.lastPathSegment = hash.New(lastPath[len(lastPath)-20:])
c.lastKey, _ = c.pathMap.keyDesc.GetBytes(1, nextKey)
// nextValue is a pathMap value tuple: it contains a primary key from the underlying table.
// nextKey is a pathMap key tuple: it contains a field for each edge in the final index graph that connects
// the root to |nextValue|, ending with the vector corresponding to the primary key in |nextValue|.
// This chunker stores that vector in |c.lastKey|, so that it can write it into the index on the subsequent call.
// It also stores the direct parent vector. When the direct parent vector changes, we use that as an indicator
// To finish one chunk and begin the next one.
c.lastPathSegment, _ = c.pathMap.keyDesc.GetBytes(c.pathMap.keyDesc.Count()-2, nextKey)
c.lastKey, _ = c.pathMap.keyDesc.GetBytes(c.pathMap.keyDesc.Count()-1, nextKey)
c.lastValue = nextValue
}
}
+4 -4
View File
@@ -288,7 +288,7 @@ func TestTupleBuilderAdaptiveEncodings(t *testing.T) {
tup, err := tb.Build(testPool)
require.NoError(t, err)
adaptiveEncodingBytes, _, err := td.GetBytesAdaptiveValue(0, vs, tup)
adaptiveEncodingBytes, _, err := td.GetBytesAdaptiveValue(ctx, 0, vs, tup)
require.NoError(t, err)
require.Equal(t, shortByteArray, adaptiveEncodingBytes)
})
@@ -303,7 +303,7 @@ func TestTupleBuilderAdaptiveEncodings(t *testing.T) {
tup, err := tb.Build(testPool)
require.NoError(t, err)
adaptiveEncodingBytes, _, err := td.GetBytesAdaptiveValue(0, vs, tup)
adaptiveEncodingBytes, _, err := td.GetBytesAdaptiveValue(ctx, 0, vs, tup)
require.NoError(t, err)
adaptiveEncodingByteArray := adaptiveEncodingBytes.(*ByteArray)
outBytes, err := adaptiveEncodingByteArray.ToBytes(ctx)
@@ -336,7 +336,7 @@ func TestTupleBuilderAdaptiveEncodings(t *testing.T) {
{
// Check that first column is stored out-of-band
adaptiveEncodingBytes, _, err := td.GetBytesAdaptiveValue(0, vs, tup)
adaptiveEncodingBytes, _, err := td.GetBytesAdaptiveValue(ctx, 0, vs, tup)
require.NoError(t, err)
adaptiveEncodingByteArray := adaptiveEncodingBytes.(*ByteArray)
outBytes, err := adaptiveEncodingByteArray.ToBytes(ctx)
@@ -346,7 +346,7 @@ func TestTupleBuilderAdaptiveEncodings(t *testing.T) {
{
// Check that second column is stored inline
adaptiveEncodingBytes, _, err := td.GetBytesAdaptiveValue(1, vs, tup)
adaptiveEncodingBytes, _, err := td.GetBytesAdaptiveValue(ctx, 1, vs, tup)
require.NoError(t, err)
adaptiveEncodingByteArray := adaptiveEncodingBytes.([]byte)
require.Equal(t, mediumByteArray, adaptiveEncodingByteArray)
+6 -4
View File
@@ -535,11 +535,13 @@ func (td TupleDesc) GetBytesAddr(i int, tup Tuple) (hash.Hash, bool) {
}
// GetBytesAdaptiveValue returns either a []byte or a BytesWrapper, but Go doesn't allow us to use a single type for that.
func (td TupleDesc) GetBytesAdaptiveValue(i int, vs ValueStore, tup Tuple) (interface{}, bool, error) {
// TODO: Add context parameter
ctx := context.Background()
func (td TupleDesc) GetBytesAdaptiveValue(ctx context.Context, i int, vs ValueStore, tup Tuple) (interface{}, bool, error) {
td.expectEncoding(i, BytesAdaptiveEnc)
adaptiveValue := AdaptiveValue(td.GetField(i, tup))
return GetBytesAdaptiveValue(ctx, vs, td.GetField(i, tup))
}
func GetBytesAdaptiveValue(ctx context.Context, vs ValueStore, val []byte) (interface{}, bool, error) {
adaptiveValue := AdaptiveValue(val)
if len(adaptiveValue) == 0 {
return nil, false, nil
}
+23
View File
@@ -0,0 +1,23 @@
#!/usr/bin/env bats
load $BATS_TEST_DIRNAME/helper/common.bash

setup() {
# This suite starts its own sql-server, so running under the remote-engine
# harness would be redundant.
if [ "$SQL_ENGINE" = "remote-engine" ]; then
skip "This test tests remote connections directly, SQL_ENGINE is not needed."
fi
setup_common
}

teardown() {
# Stop the server started by the test; the short sleep gives it time to
# release its port before the next test's server starts.
stop_sql_server 1 && sleep 0.5
teardown_common
}

@test "dolt-test-run: sanity test on sql-server" {
start_sql_server
# Register one trivial test case in the dolt_tests system table…
dolt sql -q "insert into dolt_tests values ('test', 'test', 'select 1', 'expected_rows', '==', '1');"
# …then run it through the dolt_test_run() table function and expect PASS.
run dolt sql -q "select * from dolt_test_run()"
[ $status -eq 0 ]
[[ $output =~ "| test | test | select 1 | PASS | |" ]] || false
}
+6 -2
View File
@@ -960,11 +960,12 @@ CREATE TABLE \`all_types\` (
\`v31\` varchar(255) DEFAULT NULL,
\`v32\` varbinary(255) DEFAULT NULL,
\`v33\` year DEFAULT NULL,
\`v34\` vector(1) DEFAULT NULL,
PRIMARY KEY (\`pk\`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin;
INSERT INTO \`all_types\` (\`pk\`,\`v1\`,\`v2\`,\`v3\`,\`v4\`,\`v5\`,\`v6\`,\`v7\`,\`v8\`,\`v9\`,\`v10\`,\`v11\`,\`v12\`,\`v13\`,\`v14\`,\`v15\`,\`v16\`,\`v17\`,\`v18\`,\`v19\`,\`v20\`,\`v21\`,\`v22\`,\`v23\`,\`v24\`,\`v25\`,\`v26\`,\`v27\`,\`v28\`,\`v29\`,\`v30\`,\`v31\`,\`v32\`,\`v33\`) VALUES
(1,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,'null',NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL),
(2, 1, 1, 1 ,('abc'),'a','2022-04-05','2022-10-05 10:14:41',2.34,2.34,'s',2.34,POINT(1,2),1,'{"a":1}',LINESTRING(POINT(0,0),POINT(1,2)),('abcd'),'abcd',('ab'),1,'abc',POINT(2,1),polygon(linestring(point(1,2),point(3,4),point(5,6),point(1,2))),'one',1,'abc','10:14:41','2022-10-05 10:14:41',('a'),1,'a','abcde',1,2022);
(1,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,'null',NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL),
(2, 1, 1, 1 ,('abc'),'a','2022-04-05','2022-10-05 10:14:41',2.34,2.34,'s',2.34,POINT(1,2),1,'{"a":1}',LINESTRING(POINT(0,0),POINT(1,2)),('abcd'),'abcd',('ab'),1,'abc',POINT(2,1),polygon(linestring(point(1,2),point(3,4),point(5,6),point(1,2))),'one',1,'abc','10:14:41','2022-10-05 10:14:41',('a'),1,'a','abcde',1,2022,0x12345678);
SQL
[ "$status" -eq 0 ]
@@ -978,6 +979,9 @@ SQL
run dolt sql -q "SELECT ST_AsText(v12), ST_AsText(v21), ST_AsText(v15), ST_AsText(v22) from t1;"
[[ "$output" =~ "POINT(1 2) | POINT(2 1) | LINESTRING(0 0,1 2) | POLYGON((1 2,3 4,5 6,1 2))" ]] || false
run dolt sql -q "SELECT v34 from t1"
[[ "$output" =~ "0x12345678" ]] || false
# need to test binary, bit and blob types
}
+137
View File
@@ -0,0 +1,137 @@
# dolt/integration-tests/bats/helper/sql-diff.bash
# Debug switches for this helper library.
# SQL_DIFF_DEBUG: set to any non-empty value to enable debug output on stderr.
: "${SQL_DIFF_DEBUG:=}" # set to any value to enable debug output

# _dbg MSG... — print the arguments to stderr when debugging is enabled.
# Returns 0 even when debugging is off: the original `[ -n ... ] && printf`
# form returned 1 in that case, which aborts callers that invoke `_dbg` as a
# plain statement under errexit (as bats does).
_dbg() { [ -n "$SQL_DIFF_DEBUG" ] || return 0; printf '%s\n' "$*" >&2; }

# _dbg_block LABEL BODY — print a labeled multi-line block to stderr when
# debugging is enabled; always returns 0 (see _dbg).
_dbg_block() { [ -n "$SQL_DIFF_DEBUG" ] || return 0; printf '%s\n' "$1" >&2; printf '%s\n' "$2" >&2; }
# first table header row from CLI diff (data section), as newline list
# $1: full `dolt diff` CLI output.
# The awk program remembers the most recent '|'-prefixed row; on the first
# data row (one whose first cell is a diff marker: -, +, < or >) it prints
# the remembered row — the data table's header — and exits.
# The rest of the pipeline splits that header on '|', trims surrounding
# whitespace from each cell, and drops marker-only ('<', '>') and empty cells.
_cli_header_cols() {
awk '
/^\s*\|\s*[-+<>]\s*\|/ && last_header != "" { print last_header; exit }
/^\s*\|/ { last_header = $0 }
' <<<"$1" \
| tr '|' '\n' \
| sed -e 's/^[[:space:]]*//;s/[[:space:]]*$//' \
| grep -v -E '^(<|>|)$' \
| grep -v '^$'
}
# first table header row from SQL diff, strip to_/from_, drop metadata, as newline list
# $1: full output of `dolt sql -q "SELECT * FROM dolt_diff(...)"`.
# Takes the first '|'-prefixed line (the result-set header), splits it on '|',
# trims whitespace, keeps only the to_*/from_* data columns, strips the
# to_/from_ prefix, and drops the commit/commit_date/diff_type metadata
# columns so the result is comparable with _cli_header_cols output.
_sql_data_header_cols() {
echo "$1" \
| awk '/^\|/ {print; exit}' \
| tr '|' '\n' \
| sed -e 's/^[[:space:] ]*//;s/[[:space:] ]*$//' \
| grep -E '^(to_|from_)' \
| sed -E 's/^(to_|from_)//' \
| grep -Ev '^(commit|commit_date|diff_type)$' \
| grep -v '^$'
}
# count CLI changes by unique PK (includes +, -, <, >)
# $1: full `dolt diff` CLI output.
# Once the first data-marker row is seen, every row whose second '|'-field is
# a lone marker contributes its third field (the primary key cell, trimmed)
# to a set; the END block prints the set's size. Counting unique PKs means a
# modified row rendered as a </> pair counts once, matching one SQL diff row.
_cli_change_count() {
awk -F'|' '
# start counting once we see a data row marker
/^\s*\|\s*[-+<>]\s*\|/ { in_table=1 }
in_table && $2 ~ /^[[:space:]]*[-+<>][[:space:]]*$/ {
pk=$3
gsub(/^[[:space:]]+|[[:space:]]+$/, "", pk)
if (pk != "") seen[pk]=1
}
END { c=0; for (k in seen) c++; print c+0 }
' <<<"$1"
}
# count SQL data rows (lines starting with '|' minus header)
# $1: SQL table-function output; the first '|' line is the header row.
_sql_row_count() {
awk 'index($0, "|") == 1 { rows++ } END { print (rows > 1 ? rows - 1 : 0) }' <<<"$1"
}
# compare two newline lists as sets (sorted)
# $1 name (label for messages), $2 CLI column list, $3 SQL column list,
# $4 full CLI output, $5 full SQL output (the last two are echoed on mismatch
# so the bats failure log shows what was compared).
# Returns 0 when both lists contain the same set of entries, 1 otherwise.
_compare_sets_or_err() {
local name="$1" cli_cols="$2" sql_cols="$3" cli_out="$4" sql_out="$5"
local cli_sorted sql_sorted
# sort -u normalizes ordering and duplicates so the comparison is set-wise.
cli_sorted=$(echo "$cli_cols" | sort -u)
sql_sorted=$(echo "$sql_cols" | sort -u)
_dbg_block "$name CLI columns:" "$cli_sorted"
_dbg_block "$name SQL data columns:" "$sql_sorted"
if [ "$cli_sorted" != "$sql_sorted" ]; then
echo "$name column set mismatch"
echo "--- $name CLI columns ---"; echo "$cli_sorted"
echo "--- $name SQL data columns ---"; echo "$sql_sorted"
echo "--- $name CLI output ---"; echo "$cli_out"
echo "--- $name SQL output ---"; echo "$sql_out"
return 1
fi
return 0
}
# compare change/row counts; on mismatch, print both outputs
# $1 name (label), $2 full CLI output, $3 full SQL output,
# $4 CLI change count, $5 SQL row count.
# Returns 0 when the counts match (string compare), 1 otherwise; the full
# outputs are echoed on mismatch so bats records the evidence.
_compare_counts_or_err() {
local name="$1" cli_out="$2" sql_out="$3" cli_count="$4" sql_count="$5"
_dbg "$name counts: CLI=$cli_count SQL=$sql_count"
if [ "$cli_count" != "$sql_count" ]; then
echo "$name change count mismatch: CLI=$cli_count, SQL=$sql_count"
echo "--- $name CLI output ---"; echo "$cli_out"
echo "--- $name SQL output ---"; echo "$sql_out"
return 1
fi
return 0
}
# ---- main entrypoint ----
# Compare CLI diff with SQL dolt_diff
# Usage: compare_dolt_diff [all dolt diff args...]
# Runs `dolt diff` and `SELECT * FROM dolt_diff(...)` with the same arguments
# and fails (return 1) if either command errors, or if the changed-row counts
# or the data-column sets differ between the two outputs.
compare_dolt_diff() {
local args=("$@") # all arguments
# --- normal diff ---
local cli_output sql_output cli_status sql_status
cli_output=$(dolt diff "${args[@]}" 2>&1)
cli_status=$?
# Build SQL argument list safely
# NOTE(review): each arg is wrapped in single quotes with no escaping, so an
# argument containing a single quote would break the query — confirm that is
# acceptable for the test inputs used here.
local sql_args=""
for arg in "${args[@]}"; do
if [ -z "$sql_args" ]; then
sql_args="'$arg'"
else
sql_args+=", '$arg'"
fi
done
sql_output=$(dolt sql -q "SELECT * FROM dolt_diff($sql_args)" 2>&1)
sql_status=$?
# normally prints in bats using `run`, so no debug blocks here
echo "$cli_output"
echo "$sql_output"
if [ $cli_status -ne 0 ]; then
_dbg "$cli_output"
return 1
fi
if [ $sql_status -ne 0 ]; then
_dbg "$sql_output"
return 1
fi
# Compare counts
local cli_changes sql_rows
cli_changes=$(_cli_change_count "$cli_output")
sql_rows=$(_sql_row_count "$sql_output")
_compare_counts_or_err "Diff" "$cli_output" "$sql_output" "$cli_changes" "$sql_rows" || return 1
# Compare columns
local cli_cols sql_cols
cli_cols=$(_cli_header_cols "$cli_output")
sql_cols=$(_sql_data_header_cols "$sql_output")
_compare_sets_or_err "Diff" "$cli_cols" "$sql_cols" "$cli_output" "$sql_output" || return 1
return 0
}
+58
View File
@@ -1,5 +1,6 @@
#!/usr/bin/env bats
load $BATS_TEST_DIRNAME/helper/common.bash
load $BATS_TEST_DIRNAME/helper/sql-diff.bash
setup() {
setup_common
@@ -890,3 +891,60 @@ EOF
[ "$status" -eq 1 ]
[[ "$output" =~ "invalid output format: sql. SQL format diffs only rendered for schema or data changes" ]] || false
}
@test "sql-diff: skinny flag comparison between CLI and SQL table function" {
# Seed a 6-column table so the skinny flag has untouched columns to hide.
dolt sql <<SQL
CREATE TABLE test (
pk BIGINT NOT NULL COMMENT 'tag:0',
c1 BIGINT COMMENT 'tag:1',
c2 BIGINT COMMENT 'tag:2',
c3 BIGINT COMMENT 'tag:3',
c4 BIGINT COMMENT 'tag:4',
c5 BIGINT COMMENT 'tag:5',
PRIMARY KEY (pk)
);
SQL
dolt table import -u test `batshelper 1pk5col-ints.csv`
dolt add test
dolt commit -m "Added initial data"
# Insert-only diff: CLI and table function must agree, with the skinny flag
# accepted in short/long form and in any argument position.
compare_dolt_diff "HEAD~1" "HEAD" "test"
compare_dolt_diff "-sk" "HEAD~1" "HEAD" "test"
compare_dolt_diff "--skinny" "HEAD~1" "HEAD" "test"
compare_dolt_diff "HEAD~1" "HEAD" "-sk" "test"
# Update-only diff: skinny should restrict output to pk plus changed columns.
dolt sql -q "UPDATE test SET c1=100, c3=300 WHERE pk=0"
dolt sql -q "UPDATE test SET c2=200 WHERE pk=1"
dolt add test
dolt commit -m "Updated some columns"
compare_dolt_diff "HEAD~1" "HEAD" "test"
compare_dolt_diff "--skinny" "HEAD~1" "HEAD" "test"
compare_dolt_diff "HEAD~1" "HEAD" "test" "--skinny"
# Schema change plus data change, including --include-cols combinations.
dolt sql -q "ALTER TABLE test ADD COLUMN c6 BIGINT"
dolt sql -q "UPDATE test SET c6=600 WHERE pk=0"
dolt add test
dolt commit -m "Added new column and updated it"
compare_dolt_diff "HEAD~1" "HEAD" "test"
compare_dolt_diff "--skinny" "HEAD~1" "HEAD" "test"
compare_dolt_diff "--skinny" "--include-cols=c1,c2" "HEAD~1" "HEAD" "test"
compare_dolt_diff "--skinny" "HEAD~2" "HEAD" "test" "--include-cols" "c1" "c2"
compare_dolt_diff "--skinny" "--include-cols=c1,c2" "HEAD~2" "HEAD" "test"
# Delete-only diff.
dolt sql -q "DELETE FROM test WHERE pk=1"
dolt add test
dolt commit -m "Deleted a row"
compare_dolt_diff "HEAD~1" "HEAD" "test"
compare_dolt_diff "--skinny" "HEAD~1" "HEAD" "test"
# Error cases: unknown option and a flag supplied twice must both be
# rejected by the table function with a non-zero status.
run dolt sql -q "SELECT * FROM dolt_diff('-err', 'HEAD~1', 'HEAD', 'test')"
[[ "$output" =~ "unknown option \`err" ]] || false
[ "$status" -eq 1 ]
run dolt sql -q "SELECT * FROM dolt_diff('-sk', '--skinny', 'HEAD~1', 'HEAD', 'test')"
[[ "$output" =~ "multiple values provided for \`skinny" ]] || false
[ "$status" -eq 1 ]
}