mirror of
https://github.com/dolthub/dolt.git
synced 2026-04-23 05:13:00 -05:00
Merge remote-tracking branch 'origin/main' into aaron/github-workflows-icu4c
This commit is contained in:
@@ -308,6 +308,29 @@ func CreateLogArgParser(isTableFunction bool) *argparser.ArgParser {
|
||||
return ap
|
||||
}
|
||||
|
||||
func CreateDiffArgParser(isTableFunction bool) *argparser.ArgParser {
|
||||
ap := argparser.NewArgParserWithVariableArgs("diff")
|
||||
ap.SupportsFlag(SkinnyFlag, "sk", "Shows only primary key columns and any columns with data changes.")
|
||||
ap.SupportsStringList(IncludeCols, "ic", "columns", "A list of columns to include in the diff.")
|
||||
if !isTableFunction { // TODO: support for table function
|
||||
ap.SupportsFlag(DataFlag, "d", "Show only the data changes, do not show the schema changes (Both shown by default).")
|
||||
ap.SupportsFlag(SchemaFlag, "s", "Show only the schema changes, do not show the data changes (Both shown by default).")
|
||||
ap.SupportsFlag(StatFlag, "", "Show stats of data changes")
|
||||
ap.SupportsFlag(SummaryFlag, "", "Show summary of data and schema changes")
|
||||
ap.SupportsString(FormatFlag, "r", "result output format", "How to format diff output. Valid values are tabular, sql, json. Defaults to tabular.")
|
||||
ap.SupportsString(WhereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.")
|
||||
ap.SupportsInt(LimitParam, "", "record_count", "limits to the first N diffs.")
|
||||
ap.SupportsFlag(StagedFlag, "", "Show only the staged data changes.")
|
||||
ap.SupportsFlag(CachedFlag, "c", "Synonym for --staged")
|
||||
ap.SupportsFlag(MergeBase, "", "Uses merge base of the first commit and second commit (or HEAD if not supplied) as the first commit")
|
||||
ap.SupportsString(DiffMode, "", "diff mode", "Determines how to display modified rows with tabular output. Valid values are row, line, in-place, context. Defaults to context.")
|
||||
ap.SupportsFlag(ReverseFlag, "R", "Reverses the direction of the diff.")
|
||||
ap.SupportsFlag(NameOnlyFlag, "", "Only shows table names.")
|
||||
ap.SupportsFlag(SystemFlag, "", "Show system tables in addition to user tables")
|
||||
}
|
||||
return ap
|
||||
}
|
||||
|
||||
func CreateGCArgParser() *argparser.ArgParser {
|
||||
ap := argparser.NewArgParserWithMaxArgs("gc", 0)
|
||||
ap.SupportsFlag(ShallowFlag, "s", "perform a fast, but incomplete garbage collection pass")
|
||||
|
||||
@@ -88,3 +88,19 @@ const (
|
||||
UpperCaseAllFlag = "ALL"
|
||||
UserFlag = "user"
|
||||
)
|
||||
|
||||
// Flags used by `dolt diff` command and `dolt_diff()` table function.
|
||||
const (
|
||||
SkinnyFlag = "skinny"
|
||||
IncludeCols = "include-cols"
|
||||
DataFlag = "data"
|
||||
SchemaFlag = "schema"
|
||||
NameOnlyFlag = "name-only"
|
||||
SummaryFlag = "summary"
|
||||
WhereParam = "where"
|
||||
LimitParam = "limit"
|
||||
MergeBase = "merge-base"
|
||||
DiffMode = "diff-mode"
|
||||
ReverseFlag = "reverse"
|
||||
FormatFlag = "result-format"
|
||||
)
|
||||
|
||||
@@ -60,18 +60,6 @@ const (
|
||||
TabularDiffOutput diffOutput = 1
|
||||
SQLDiffOutput diffOutput = 2
|
||||
JsonDiffOutput diffOutput = 3
|
||||
|
||||
DataFlag = "data"
|
||||
SchemaFlag = "schema"
|
||||
NameOnlyFlag = "name-only"
|
||||
StatFlag = "stat"
|
||||
SummaryFlag = "summary"
|
||||
whereParam = "where"
|
||||
limitParam = "limit"
|
||||
SkinnyFlag = "skinny"
|
||||
MergeBase = "merge-base"
|
||||
DiffMode = "diff-mode"
|
||||
ReverseFlag = "reverse"
|
||||
)
|
||||
|
||||
var diffDocs = cli.CommandDocumentationContent{
|
||||
@@ -107,12 +95,13 @@ The {{.EmphasisLeft}}--diff-mode{{.EmphasisRight}} argument controls how modifie
|
||||
}
|
||||
|
||||
type diffDisplaySettings struct {
|
||||
diffParts diffPart
|
||||
diffOutput diffOutput
|
||||
diffMode diff.Mode
|
||||
limit int
|
||||
where string
|
||||
skinny bool
|
||||
diffParts diffPart
|
||||
diffOutput diffOutput
|
||||
diffMode diff.Mode
|
||||
limit int
|
||||
where string
|
||||
skinny bool
|
||||
includeCols []string
|
||||
}
|
||||
|
||||
type diffDatasets struct {
|
||||
@@ -164,23 +153,7 @@ func (cmd DiffCmd) Docs() *cli.CommandDocumentation {
|
||||
}
|
||||
|
||||
func (cmd DiffCmd) ArgParser() *argparser.ArgParser {
|
||||
ap := argparser.NewArgParserWithVariableArgs(cmd.Name())
|
||||
ap.SupportsFlag(DataFlag, "d", "Show only the data changes, do not show the schema changes (Both shown by default).")
|
||||
ap.SupportsFlag(SchemaFlag, "s", "Show only the schema changes, do not show the data changes (Both shown by default).")
|
||||
ap.SupportsFlag(StatFlag, "", "Show stats of data changes")
|
||||
ap.SupportsFlag(SummaryFlag, "", "Show summary of data and schema changes")
|
||||
ap.SupportsString(FormatFlag, "r", "result output format", "How to format diff output. Valid values are tabular, sql, json. Defaults to tabular.")
|
||||
ap.SupportsString(whereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.")
|
||||
ap.SupportsInt(limitParam, "", "record_count", "limits to the first N diffs.")
|
||||
ap.SupportsFlag(cli.StagedFlag, "", "Show only the staged data changes.")
|
||||
ap.SupportsFlag(cli.CachedFlag, "c", "Synonym for --staged")
|
||||
ap.SupportsFlag(SkinnyFlag, "sk", "Shows only primary key columns and any columns with data changes.")
|
||||
ap.SupportsFlag(MergeBase, "", "Uses merge base of the first commit and second commit (or HEAD if not supplied) as the first commit")
|
||||
ap.SupportsString(DiffMode, "", "diff mode", "Determines how to display modified rows with tabular output. Valid values are row, line, in-place, context. Defaults to context.")
|
||||
ap.SupportsFlag(ReverseFlag, "R", "Reverses the direction of the diff.")
|
||||
ap.SupportsFlag(NameOnlyFlag, "", "Only shows table names.")
|
||||
ap.SupportsFlag(cli.SystemFlag, "", "Show system tables in addition to user tables")
|
||||
return ap
|
||||
return cli.CreateDiffArgParser(false)
|
||||
}
|
||||
|
||||
func (cmd DiffCmd) RequiresRepo() bool {
|
||||
@@ -228,14 +201,14 @@ func (cmd DiffCmd) Exec(ctx context.Context, commandStr string, args []string, _
|
||||
}
|
||||
|
||||
func (cmd DiffCmd) validateArgs(apr *argparser.ArgParseResults) errhand.VerboseError {
|
||||
if apr.Contains(StatFlag) || apr.Contains(SummaryFlag) {
|
||||
if apr.Contains(SchemaFlag) || apr.Contains(DataFlag) {
|
||||
if apr.Contains(cli.StatFlag) || apr.Contains(cli.SummaryFlag) {
|
||||
if apr.Contains(cli.SchemaFlag) || apr.Contains(cli.DataFlag) {
|
||||
return errhand.BuildDError("invalid Arguments: --stat and --summary cannot be combined with --schema or --data").Build()
|
||||
}
|
||||
}
|
||||
|
||||
if apr.Contains(NameOnlyFlag) {
|
||||
if apr.Contains(SchemaFlag) || apr.Contains(DataFlag) || apr.Contains(StatFlag) || apr.Contains(SummaryFlag) {
|
||||
if apr.Contains(cli.NameOnlyFlag) {
|
||||
if apr.Contains(cli.SchemaFlag) || apr.Contains(cli.DataFlag) || apr.Contains(cli.StatFlag) || apr.Contains(cli.SummaryFlag) {
|
||||
return errhand.BuildDError("invalid Arguments: --name-only cannot be combined with --schema, --data, --stat, or --summary").Build()
|
||||
}
|
||||
}
|
||||
@@ -254,25 +227,29 @@ func parseDiffDisplaySettings(apr *argparser.ArgParseResults) *diffDisplaySettin
|
||||
displaySettings := &diffDisplaySettings{}
|
||||
|
||||
displaySettings.diffParts = SchemaAndDataDiff
|
||||
if apr.Contains(DataFlag) && !apr.Contains(SchemaFlag) {
|
||||
if apr.Contains(cli.DataFlag) && !apr.Contains(cli.SchemaFlag) {
|
||||
displaySettings.diffParts = DataOnlyDiff
|
||||
} else if apr.Contains(SchemaFlag) && !apr.Contains(DataFlag) {
|
||||
} else if apr.Contains(cli.SchemaFlag) && !apr.Contains(cli.DataFlag) {
|
||||
displaySettings.diffParts = SchemaOnlyDiff
|
||||
} else if apr.Contains(StatFlag) {
|
||||
} else if apr.Contains(cli.StatFlag) {
|
||||
displaySettings.diffParts = Stat
|
||||
} else if apr.Contains(SummaryFlag) {
|
||||
} else if apr.Contains(cli.SummaryFlag) {
|
||||
displaySettings.diffParts = Summary
|
||||
} else if apr.Contains(NameOnlyFlag) {
|
||||
} else if apr.Contains(cli.NameOnlyFlag) {
|
||||
displaySettings.diffParts = NameOnlyDiff
|
||||
}
|
||||
|
||||
displaySettings.skinny = apr.Contains(SkinnyFlag)
|
||||
displaySettings.skinny = apr.Contains(cli.SkinnyFlag)
|
||||
|
||||
if cols, ok := apr.GetValueList(cli.IncludeCols); ok {
|
||||
displaySettings.includeCols = cols
|
||||
}
|
||||
|
||||
f := apr.GetValueOrDefault(FormatFlag, "tabular")
|
||||
switch strings.ToLower(f) {
|
||||
case "tabular":
|
||||
displaySettings.diffOutput = TabularDiffOutput
|
||||
switch strings.ToLower(apr.GetValueOrDefault(DiffMode, "context")) {
|
||||
switch strings.ToLower(apr.GetValueOrDefault(cli.DiffMode, "context")) {
|
||||
case "row":
|
||||
displaySettings.diffMode = diff.ModeRow
|
||||
case "line":
|
||||
@@ -288,8 +265,8 @@ func parseDiffDisplaySettings(apr *argparser.ArgParseResults) *diffDisplaySettin
|
||||
displaySettings.diffOutput = JsonDiffOutput
|
||||
}
|
||||
|
||||
displaySettings.limit, _ = apr.GetInt(limitParam)
|
||||
displaySettings.where = apr.GetValueOrDefault(whereParam, "")
|
||||
displaySettings.limit, _ = apr.GetInt(cli.LimitParam)
|
||||
displaySettings.where = apr.GetValueOrDefault(cli.WhereParam, "")
|
||||
|
||||
return displaySettings
|
||||
}
|
||||
@@ -301,12 +278,12 @@ func parseDiffArgs(queryist cli.Queryist, sqlCtx *sql.Context, apr *argparser.Ar
|
||||
|
||||
staged := apr.Contains(cli.StagedFlag) || apr.Contains(cli.CachedFlag)
|
||||
|
||||
tableNames, err := dArgs.applyDiffRoots(queryist, sqlCtx, apr.Args, staged, apr.Contains(MergeBase))
|
||||
tableNames, err := dArgs.applyDiffRoots(queryist, sqlCtx, apr.Args, staged, apr.Contains(cli.MergeBase))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if apr.Contains(ReverseFlag) {
|
||||
if apr.Contains(cli.ReverseFlag) {
|
||||
dArgs.diffDatasets = &diffDatasets{
|
||||
fromRef: dArgs.toRef,
|
||||
toRef: dArgs.fromRef,
|
||||
@@ -1556,7 +1533,9 @@ func diffRows(
|
||||
if err != nil {
|
||||
return errhand.BuildDError("Error running diff query:\n%s", interpolatedQuery).AddCause(err).Build()
|
||||
}
|
||||
|
||||
for _, col := range dArgs.includeCols {
|
||||
modifiedColNames[col] = true // ensure included columns are always present
|
||||
}
|
||||
// instantiate a new schema that only contains the columns with changes
|
||||
var filteredUnionSch sql.Schema
|
||||
for _, s := range unionSch {
|
||||
|
||||
@@ -83,17 +83,17 @@ func (cmd ShowCmd) ArgParser() *argparser.ArgParser {
|
||||
ap.SupportsFlag(cli.NoPrettyFlag, "", "Show the object without making it pretty.")
|
||||
|
||||
// Flags inherited from Diff
|
||||
ap.SupportsFlag(DataFlag, "d", "Show only the data changes, do not show the schema changes (Both shown by default).")
|
||||
ap.SupportsFlag(SchemaFlag, "s", "Show only the schema changes, do not show the data changes (Both shown by default).")
|
||||
ap.SupportsFlag(StatFlag, "", "Show stats of data changes")
|
||||
ap.SupportsFlag(SummaryFlag, "", "Show summary of data and schema changes")
|
||||
ap.SupportsFlag(cli.DataFlag, "d", "Show only the data changes, do not show the schema changes (Both shown by default).")
|
||||
ap.SupportsFlag(cli.SchemaFlag, "s", "Show only the schema changes, do not show the data changes (Both shown by default).")
|
||||
ap.SupportsFlag(cli.StatFlag, "", "Show stats of data changes")
|
||||
ap.SupportsFlag(cli.SummaryFlag, "", "Show summary of data and schema changes")
|
||||
ap.SupportsString(FormatFlag, "r", "result output format", "How to format diff output. Valid values are tabular, sql, json. Defaults to tabular.")
|
||||
ap.SupportsString(whereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.")
|
||||
ap.SupportsInt(limitParam, "", "record_count", "limits to the first N diffs.")
|
||||
ap.SupportsString(cli.WhereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.")
|
||||
ap.SupportsInt(cli.LimitParam, "", "record_count", "limits to the first N diffs.")
|
||||
ap.SupportsFlag(cli.CachedFlag, "c", "Show only the staged data changes.")
|
||||
ap.SupportsFlag(SkinnyFlag, "sk", "Shows only primary key columns and any columns with data changes.")
|
||||
ap.SupportsFlag(MergeBase, "", "Uses merge base of the first commit and second commit (or HEAD if not supplied) as the first commit")
|
||||
ap.SupportsString(DiffMode, "", "diff mode", "Determines how to display modified rows with tabular output. Valid values are row, line, in-place, context. Defaults to context.")
|
||||
ap.SupportsFlag(cli.SkinnyFlag, "sk", "Shows only primary key columns and any columns with data changes.")
|
||||
ap.SupportsFlag(cli.MergeBase, "", "Uses merge base of the first commit and second commit (or HEAD if not supplied) as the first commit")
|
||||
ap.SupportsString(cli.DiffMode, "", "diff mode", "Determines how to display modified rows with tabular output. Valid values are row, line, in-place, context. Defaults to context.")
|
||||
return ap
|
||||
}
|
||||
|
||||
@@ -275,8 +275,8 @@ func getValueFromRefSpec(ctx context.Context, dEnv *env.DoltEnv, specRef string)
|
||||
}
|
||||
|
||||
func (cmd ShowCmd) validateArgs(apr *argparser.ArgParseResults) errhand.VerboseError {
|
||||
if apr.Contains(StatFlag) || apr.Contains(SummaryFlag) {
|
||||
if apr.Contains(SchemaFlag) || apr.Contains(DataFlag) {
|
||||
if apr.Contains(cli.StatFlag) || apr.Contains(cli.SummaryFlag) {
|
||||
if apr.Contains(cli.SchemaFlag) || apr.Contains(cli.DataFlag) {
|
||||
return errhand.BuildDError("invalid Arguments: --stat and --summary cannot be combined with --schema or --data").Build()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,7 +203,7 @@ func (cmd SqlServerCmd) ArgParserWithName(name string) *argparser.ArgParser {
|
||||
ap.SupportsUint(mcpPortFlag, "", "port", "If provided, runs a Dolt MCP HTTP server on this port alongside the sql-server.")
|
||||
// MCP SQL credentials (user required when MCP enabled; password optional)
|
||||
ap.SupportsString(mcpUserFlag, "", "user", "SQL user for MCP to connect as (required when --mcp-port is set).")
|
||||
ap.SupportsString(mcpPasswordFlag, "", "password", "Optional SQL password for MCP to connect with (requires --mcp-user). Defaults to env DOLT_ROOT_PASSWORD if unset.")
|
||||
ap.SupportsString(mcpPasswordFlag, "", "password", "Optional SQL password for MCP to connect with (requires --mcp-user).")
|
||||
ap.SupportsString(mcpDatabaseFlag, "", "database", "Optional SQL database name MCP should connect to (requires --mcp-port and --mcp-user).")
|
||||
return ap
|
||||
}
|
||||
|
||||
@@ -15,5 +15,5 @@
|
||||
package doltversion
|
||||
|
||||
const (
|
||||
Version = "1.58.8"
|
||||
Version = "1.59.6"
|
||||
)
|
||||
|
||||
@@ -13,7 +13,7 @@ require (
|
||||
github.com/dolthub/fslock v0.0.3
|
||||
github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718
|
||||
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
|
||||
github.com/dolthub/vitess v0.0.0-20250902185630-90811959cbd1
|
||||
github.com/dolthub/vitess v0.0.0-20250902225707-0159e964d73d
|
||||
github.com/dustin/go-humanize v1.0.1
|
||||
github.com/fatih/color v1.13.0
|
||||
github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568
|
||||
@@ -61,7 +61,7 @@ require (
|
||||
github.com/dolthub/dolt-mcp v0.2.1-0.20250827202412-9d0f6e658fba
|
||||
github.com/dolthub/eventsapi_schema v0.0.0-20250725194025-a087efa1ee55
|
||||
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2
|
||||
github.com/dolthub/go-mysql-server v0.20.1-0.20250902204612-4e1a10a95d8c
|
||||
github.com/dolthub/go-mysql-server v0.20.1-0.20250909231122-5fb8788af3dd
|
||||
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63
|
||||
github.com/edsrzf/mmap-go v1.2.0
|
||||
github.com/esote/minmaxheap v1.0.0
|
||||
|
||||
@@ -213,8 +213,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
|
||||
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
|
||||
github.com/dolthub/go-icu-regex v0.0.0-20250820171420-f2b78f56ce9f h1:oSA8CptGeCEdTdD9LFtv8x4juDfdaLKsx1eocyaj1bE=
|
||||
github.com/dolthub/go-icu-regex v0.0.0-20250820171420-f2b78f56ce9f/go.mod h1:kpsRG+a196Y69zsAFL0RkQICII9a571lcaxhvQnmrdY=
|
||||
github.com/dolthub/go-mysql-server v0.20.1-0.20250902204612-4e1a10a95d8c h1:AHPSDdj6UnnVkc2eT3p4ZZ/7lf7MhHAwYCYkokad+PY=
|
||||
github.com/dolthub/go-mysql-server v0.20.1-0.20250902204612-4e1a10a95d8c/go.mod h1:/OxBMtXwiC67PVMBV+/9sNjyfsK4SOF1moNnvdotEZM=
|
||||
github.com/dolthub/go-mysql-server v0.20.1-0.20250909231122-5fb8788af3dd h1:FdOA59vPc2ilQAnQlXNfFSeVkCTHPa4m8Gl2KmOErTs=
|
||||
github.com/dolthub/go-mysql-server v0.20.1-0.20250909231122-5fb8788af3dd/go.mod h1:ymoHIRZoZKO1EH9iUGcq4E6XyIpUaMgZz3ZPeWa828w=
|
||||
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI=
|
||||
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q=
|
||||
github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE=
|
||||
@@ -223,8 +223,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE
|
||||
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI=
|
||||
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE=
|
||||
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY=
|
||||
github.com/dolthub/vitess v0.0.0-20250902185630-90811959cbd1 h1:sCpjbwm7rIV5o9OQDWsqGSzNXGSxpeAk3kX5gjPQlQI=
|
||||
github.com/dolthub/vitess v0.0.0-20250902185630-90811959cbd1/go.mod h1:tV3BrIVyDWVkkYy8dKt2o6hjJ89cHb5opY5FpCyhncQ=
|
||||
github.com/dolthub/vitess v0.0.0-20250902225707-0159e964d73d h1:oTWJxjzRmuHKuICUunCUwNuonubkXwOqPa5hXX3dXBo=
|
||||
github.com/dolthub/vitess v0.0.0-20250902225707-0159e964d73d/go.mod h1:tV3BrIVyDWVkkYy8dKt2o6hjJ89cHb5opY5FpCyhncQ=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/edsrzf/mmap-go v1.2.0 h1:hXLYlkbaPzt1SaQk+anYwKSRNhufIDCchSPkUD6dD84=
|
||||
|
||||
@@ -0,0 +1,510 @@
|
||||
# AGENT.md - Dolt Database Operations Guide
|
||||
|
||||
This file provides guidance for AI agents working with Dolt databases to maximize productivity and follow best practices.
|
||||
|
||||
## Quick Start
|
||||
|
||||
Dolt is "Git for Data" - a SQL database with version control capabilities. All Git commands have Dolt equivalents:
|
||||
- `git add` → `dolt add`
|
||||
- `git commit` → `dolt commit`
|
||||
- `git branch` → `dolt branch`
|
||||
- `git merge` → `dolt merge`
|
||||
- `git diff` → `dolt diff`
|
||||
|
||||
For help and documentation on commands, you can run `dolt --help` and `dolt <command> --help`.
|
||||
|
||||
## Essential Dolt CLI Commands
|
||||
|
||||
### Repository Operations
|
||||
```bash
|
||||
# Initialize new database
|
||||
dolt init
|
||||
|
||||
# Clone existing database
|
||||
dolt clone <remote-url>
|
||||
|
||||
# Show current status
|
||||
dolt status
|
||||
|
||||
# View commit history
|
||||
dolt log
|
||||
```
|
||||
|
||||
### Branch Management
|
||||
```bash
|
||||
# List branches
|
||||
dolt branch
|
||||
|
||||
# Create new branch
|
||||
dolt branch <branch-name>
|
||||
|
||||
# Switch branches
|
||||
dolt checkout <branch-name>
|
||||
|
||||
# Create and switch to new branch
|
||||
dolt checkout -b <branch-name>
|
||||
```
|
||||
|
||||
### Data Operations
|
||||
```bash
|
||||
# Stage changes
|
||||
dolt add <table-name>
|
||||
dolt add . # stage all changes
|
||||
|
||||
# Commit changes
|
||||
dolt commit -m "commit message"
|
||||
|
||||
# View differences
|
||||
dolt diff
|
||||
dolt diff <table-name>
|
||||
dolt diff <branch1> <branch2>
|
||||
|
||||
# Merge branches
|
||||
dolt merge <branch-name>
|
||||
```
|
||||
|
||||
## Starting and Connecting to Dolt SQL Server
|
||||
|
||||
### Start SQL Server
|
||||
```bash
|
||||
# Start server on default port (3306)
|
||||
dolt sql-server
|
||||
|
||||
# Start on specific port
|
||||
dolt sql-server --port=3307
|
||||
|
||||
# Start with specific host
|
||||
dolt sql-server --host=0.0.0.0 --port=3307
|
||||
|
||||
# Start in background
|
||||
dolt sql-server --port=3307 &
|
||||
```
|
||||
|
||||
### Connecting to SQL Server
|
||||
```bash
|
||||
# Connect with dolt sql command
|
||||
dolt sql
|
||||
|
||||
# Connect with mysql client
|
||||
mysql -h 127.0.0.1 -P 3306 -u root
|
||||
|
||||
# Connect with specific database
|
||||
mysql -h 127.0.0.1 -P 3306 -u root -D <database-name>
|
||||
```
|
||||
|
||||
## Dolt Testing with dolt_test System Table
|
||||
|
||||
### Unit Testing with dolt_test
|
||||
|
||||
The dolt_test system table provides a powerful way to create and run unit tests for your database. This is the preferred method for testing data integrity, business rules, and schema validation.
|
||||
|
||||
#### Creating Tests
|
||||
|
||||
Tests are created by inserting rows into the `dolt_tests` system table:
|
||||
|
||||
```sql
|
||||
-- Create a simple test
|
||||
INSERT INTO `dolt_tests` VALUES (
|
||||
'test_user_count',
|
||||
'validation',
|
||||
'SELECT COUNT(*) as user_count FROM users;',
|
||||
'row_count',
|
||||
'>',
|
||||
'0'
|
||||
);
|
||||
|
||||
-- Create a test with expected result
|
||||
INSERT INTO `dolt_tests` VALUES (
|
||||
'test_valid_emails',
|
||||
'validation',
|
||||
'SELECT COUNT(*) FROM users WHERE email NOT LIKE "%@%";',
|
||||
'row_count',
|
||||
'==',
|
||||
'0'
|
||||
);
|
||||
|
||||
-- Create a schema validation test
|
||||
INSERT INTO `dolt_tests` VALUES (
|
||||
'test_users_schema',
|
||||
'schema',
|
||||
'DESCRIBE users;',
|
||||
'row_count',
|
||||
'>=',
|
||||
'5'
|
||||
);
|
||||
```
|
||||
|
||||
#### Test Structure
|
||||
|
||||
Each test row contains:
|
||||
- test_name: Unique identifier for the test
|
||||
- test_group: Optional grouping for tests (e.g., 'validation', 'schema', 'integration')
|
||||
- test_query: SQL query to execute
|
||||
- assertion_type: Type of assertion ('expected_rows', 'expected_columns', 'expected_single_value')
|
||||
- assertion_comparator: Comparison operator ('==', '>', '<', '>=', '<=', '!=')
|
||||
- assertion_value: Expected value for comparison
|
||||
|
||||
#### Running Tests
|
||||
|
||||
```sql
|
||||
-- Run all tests
|
||||
SELECT * FROM dolt_test_run();
|
||||
|
||||
-- Run specific test
|
||||
SELECT * FROM dolt_test_run('test_user_count');
|
||||
|
||||
-- Run tests with filtering
|
||||
SELECT * FROM dolt_test_run() WHERE test_name LIKE 'test_user%' AND status != 'PASS';
|
||||
```
|
||||
|
||||
#### Test Result Interpretation
|
||||
|
||||
The dolt_test_run() function returns:
|
||||
- test_name: Name of the test
|
||||
- status: PASS, FAIL, or ERROR
|
||||
- actual_result: Actual query result
|
||||
- expected_result: Expected result
|
||||
- message: Additional details
|
||||
|
||||
#### Advanced Testing Examples
|
||||
|
||||
```sql
|
||||
-- Test data integrity
|
||||
INSERT INTO `dolt_tests` VALUES (
|
||||
'test_no_orphaned_orders',
|
||||
'integrity',
|
||||
'SELECT COUNT(*) FROM orders o LEFT JOIN users u ON o.user_id = u.id WHERE u.id IS NULL;',
|
||||
'row_count',
|
||||
'==',
|
||||
'0'
|
||||
);
|
||||
|
||||
-- Test business rules
|
||||
INSERT INTO `dolt_tests` VALUES (
|
||||
'test_positive_prices',
|
||||
'business_rules',
|
||||
'SELECT COUNT(*) FROM products WHERE price <= 0;',
|
||||
'row_count',
|
||||
'==',
|
||||
'0'
|
||||
);
|
||||
|
||||
-- Test complex relationships
|
||||
INSERT INTO `dolt_tests` VALUES (
|
||||
'test_order_totals',
|
||||
'integrity',
|
||||
'SELECT COUNT(*) FROM orders o JOIN order_items oi ON o.id = oi.order_id GROUP BY o.id HAVING SUM(oi.quantity * oi.price) != o.total;',
|
||||
'row_count',
|
||||
'==',
|
||||
'0'
|
||||
);
|
||||
```
|
||||
|
||||
### Dolt CI for DoltHub Integration
|
||||
|
||||
Dolt CI is specifically designed for running tests on DoltHub when pull requests are created. Use this only for tests you want to run automatically on DoltHub.
|
||||
|
||||
#### Prerequisites for DoltHub CI
|
||||
- Requires Dolt v1.43.14 or later
|
||||
- Must initialize CI capabilities: `dolt ci init`
|
||||
- Workflows defined in YAML files
|
||||
|
||||
#### Available CI Commands
|
||||
```bash
|
||||
# Initialize CI capabilities
|
||||
dolt ci init
|
||||
|
||||
# List available workflows
|
||||
dolt ci ls
|
||||
|
||||
# View workflow details
|
||||
dolt ci view <workflow-name>
|
||||
|
||||
# View specific job in workflow
|
||||
dolt ci view <workflow-name> <job-name>
|
||||
|
||||
# Run workflow locally (for testing before DoltHub)
|
||||
dolt ci run <workflow-name>
|
||||
```
|
||||
|
||||
#### Creating CI Workflows for DoltHub
|
||||
|
||||
Create workflow files that will run on DoltHub when pull requests are opened:
|
||||
|
||||
```yaml
|
||||
name: doltHub validation workflow
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- main
|
||||
jobs:
|
||||
- name: validate schema
|
||||
steps:
|
||||
- name: check required tables exist
|
||||
saved_query_name: show_tables
|
||||
expected_rows: ">= 3"
|
||||
|
||||
- name: validate user data
|
||||
saved_query_name: user_count_check
|
||||
expected_columns: "== 1"
|
||||
expected_rows: "> 0"
|
||||
|
||||
- name: data integrity checks
|
||||
steps:
|
||||
- name: check email format
|
||||
saved_query_name: valid_emails
|
||||
expected_rows: "== 0" # No invalid emails
|
||||
```
|
||||
|
||||
### Best Practices for Testing
|
||||
|
||||
1. **Use dolt_test for Unit Testing**
|
||||
- Create tests for data validation
|
||||
- Test business rules and constraints
|
||||
- Validate schema changes
|
||||
- Run tests frequently during development
|
||||
|
||||
2. **Use Dolt CI for DoltHub Integration**
|
||||
- Only for tests that should run on pull requests
|
||||
- Focus on integration and deployment validation
|
||||
- Test against production-like data
|
||||
|
||||
3. **Create Comprehensive Test Suites**
|
||||
- Test data integrity constraints
|
||||
- Validate business rules
|
||||
- Check schema requirements
|
||||
- Verify data relationships
|
||||
|
||||
4. **Version Control Your Tests**
|
||||
- Commit test definitions to repository
|
||||
- Track changes to test configuration
|
||||
- Use branches for test development
|
||||
|
||||
## System Tables for Version Control
|
||||
|
||||
Dolt exposes version control operations through system tables accessible via SQL:
|
||||
|
||||
### Core System Tables
|
||||
```sql
|
||||
-- View commit history
|
||||
SELECT * FROM dolt_log;
|
||||
|
||||
-- Check current status
|
||||
SELECT * FROM dolt_status;
|
||||
|
||||
-- View branch information
|
||||
SELECT * FROM dolt_branches;
|
||||
|
||||
-- See table diffs
|
||||
SELECT * FROM dolt_diff_<table_name>;
|
||||
|
||||
-- View schema changes
|
||||
SELECT * FROM dolt_schema_diff;
|
||||
|
||||
-- Check conflicts during merge
|
||||
SELECT * FROM dolt_conflicts_<table_name>;
|
||||
|
||||
-- View commit metadata
|
||||
SELECT * FROM dolt_commits;
|
||||
```
|
||||
|
||||
### Version Control Operations via SQL
|
||||
|
||||
When working in SQL sessions, you can execute version control operations using stored procedures:
|
||||
|
||||
```sql
|
||||
-- Stage and commit changes
|
||||
CALL dolt_add('.');
|
||||
CALL dolt_commit('-m', 'commit message');
|
||||
|
||||
-- Branch operations
|
||||
CALL dolt_branch('<branch_name>');
|
||||
CALL dolt_checkout('<branch_name>');
|
||||
CALL dolt_merge('<branch_name>');
|
||||
```
|
||||
|
||||
**Note:** Use CLI commands (`dolt add`, `dolt commit`, etc.) for most operations. SQL procedures are useful when already in a SQL session.
|
||||
|
||||
### Advanced System Tables
|
||||
```sql
|
||||
-- View remotes
|
||||
SELECT * FROM dolt_remotes;
|
||||
|
||||
-- Check merge conflicts
|
||||
SELECT * FROM dolt_conflicts;
|
||||
|
||||
-- View statistics
|
||||
SELECT * FROM dolt_statistics;
|
||||
|
||||
-- See ignored tables
|
||||
SELECT * FROM dolt_ignore;
|
||||
```
|
||||
|
||||
## CLI vs SQL Approach
|
||||
|
||||
**Prefer CLI commands for:**
|
||||
- Version control operations (add, commit, branch, merge)
|
||||
- Repository management (init, clone, push, pull)
|
||||
- Conflict resolution
|
||||
- Status checking and history viewing
|
||||
|
||||
**Use SQL for:**
|
||||
- Data queries and analysis
|
||||
- Complex data transformations
|
||||
- Examining system tables (dolt_log, dolt_status, etc.)
|
||||
- When already in an active SQL session
|
||||
|
||||
## Schema Design Recommendations
|
||||
|
||||
### Use UUID Keys Instead of Auto-Increment
|
||||
|
||||
For Dolt's version control features, use UUID primary keys instead of auto-increment:
|
||||
|
||||
```sql
|
||||
-- Recommended
|
||||
CREATE TABLE users (
|
||||
id varchar(36) default(uuid()) primary key,
|
||||
name varchar(255)
|
||||
);
|
||||
|
||||
-- Avoid auto-increment with Dolt
|
||||
-- id int auto_increment primary key
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Prevents merge conflicts across branches and database clones
|
||||
- Automatic generation with default(uuid())
|
||||
- Works seamlessly in distributed environments
|
||||
|
||||
## Best Practices for Agents
|
||||
|
||||
### 1. Always Work on Feature Branches
|
||||
```bash
|
||||
# Create feature branch before making changes
|
||||
dolt checkout -b feature/agent-changes
|
||||
|
||||
# Make changes on feature branch
|
||||
dolt sql -q "INSERT INTO users VALUES (1, 'Alice');"
|
||||
|
||||
# Stage and commit
|
||||
dolt add .
|
||||
dolt commit -m "Add new user Alice"
|
||||
|
||||
# Switch back to main to merge
|
||||
dolt checkout main
|
||||
dolt merge feature/agent-changes
|
||||
```
|
||||
|
||||
### 2. Use SQL for Data Operations, CLI for Version Control
|
||||
```bash
|
||||
# Use dolt sql for data changes
|
||||
dolt sql -q "INSERT INTO users VALUES (1, 'Alice');"
|
||||
dolt sql -q "UPDATE products SET price = price * 1.1 WHERE category = 'electronics';"
|
||||
|
||||
# Check status and commit using CLI
|
||||
dolt status
|
||||
dolt add .
|
||||
dolt commit -m "Update user and product data"
|
||||
```
|
||||
|
||||
### 3. Validate Changes with System Tables
|
||||
```sql
|
||||
-- Before major operations, check current state
|
||||
SELECT * FROM dolt_status;
|
||||
SELECT * FROM dolt_branches;
|
||||
|
||||
-- After changes, verify with diffs
|
||||
SELECT * FROM dolt_diff_users;
|
||||
SELECT * FROM dolt_schema_diff;
|
||||
```
|
||||
|
||||
### 4. Use dolt_test for Data Validation
|
||||
Create tests to validate:
|
||||
- Data integrity after changes
|
||||
- Schema compatibility
|
||||
- Business rule compliance
|
||||
- Cross-table relationships
|
||||
|
||||
### 5. Handle Conflicts Gracefully
|
||||
```bash
|
||||
# Check for conflicts using CLI
|
||||
dolt conflicts cat <table_name>
|
||||
dolt conflicts resolve --ours <table_name>
|
||||
dolt conflicts resolve --theirs <table_name>
|
||||
|
||||
# Or use SQL to examine conflicts
|
||||
dolt sql -q "SELECT * FROM dolt_conflicts_<table_name>;"
|
||||
```
|
||||
|
||||
## Common Workflow Examples
|
||||
|
||||
### Data Migration Workflow
|
||||
```bash
|
||||
# Create migration branch
|
||||
dolt checkout -b migration/update-schema
|
||||
|
||||
# Apply schema changes via SQL
|
||||
dolt sql -q "ALTER TABLE users ADD COLUMN email VARCHAR(255);"
|
||||
|
||||
# Create validation tests
|
||||
dolt sql -q "INSERT INTO `dolt_tests` VALUES ('test_users_schema', 'schema', 'DESCRIBE users;', 'row_count', '>=', '6');"
|
||||
dolt sql -q "INSERT INTO `dolt_tests` VALUES ('test_email_column', 'schema', 'SELECT COUNT(*) FROM users WHERE email IS NULL;', 'row_count', '>=', '0');"
|
||||
|
||||
# Run tests to validate changes
|
||||
dolt sql -q "SELECT * FROM dolt_test_run();"
|
||||
|
||||
# Stage and commit
|
||||
dolt add .
|
||||
dolt commit -m "Add email column to users table"
|
||||
|
||||
# Merge back
|
||||
dolt checkout main
|
||||
dolt merge migration/update-schema
|
||||
```
|
||||
|
||||
### Data Analysis Workflow
|
||||
```bash
|
||||
# Create analysis branch
|
||||
dolt checkout -b analysis/user-behavior
|
||||
|
||||
# Create analysis tables via SQL
|
||||
dolt sql -q "CREATE TABLE user_metrics AS
|
||||
SELECT user_id, COUNT(*) as actions
|
||||
FROM user_actions
|
||||
GROUP BY user_id;"
|
||||
|
||||
# Create tests to validate analysis
|
||||
dolt sql -q "INSERT INTO `dolt_tests` VALUES ('test_metrics_created', 'analysis', 'SELECT COUNT(*) FROM user_metrics;', 'row_count', '>', '0');"
|
||||
dolt sql -q "INSERT INTO `dolt_tests` VALUES ('test_metrics_integrity', 'integrity', 'SELECT COUNT(*) FROM user_metrics um LEFT JOIN users u ON um.user_id = u.id WHERE u.id IS NULL;', 'row_count', '==', '0');"
|
||||
|
||||
# Run tests to validate analysis
|
||||
dolt sql -q "SELECT * FROM dolt_test_run();"
|
||||
|
||||
# Stage and commit using CLI
|
||||
dolt add user_metrics
|
||||
dolt commit -m "Add user behavior analysis"
|
||||
```
|
||||
|
||||
## Integration with External Tools
|
||||
|
||||
### Database Clients
|
||||
Most MySQL clients work with Dolt:
|
||||
- MySQL Workbench
|
||||
- phpMyAdmin
|
||||
- DataGrip
|
||||
- DBeaver
|
||||
|
||||
### Backup and Sync
|
||||
```bash
|
||||
# Push to remote
|
||||
dolt push origin main
|
||||
|
||||
# Pull changes
|
||||
dolt pull origin main
|
||||
|
||||
# Clone for backup
|
||||
dolt clone <remote-url> backup-location
|
||||
```
|
||||
|
||||
This guide enables agents to leverage Dolt's unique version control capabilities while maintaining data integrity and following collaborative development practices.
|
||||
@@ -174,6 +174,35 @@ func ExcludeIgnoredTables(ctx context.Context, roots Roots, tables []TableName)
|
||||
return filteredTables, nil
|
||||
}
|
||||
|
||||
// IdentifyIgnoredTables takes a list of table names and identifies any tables that are ignored, by evaluating the
|
||||
// table names against the patterns in the dolt_ignore table from the working set.
|
||||
func IdentifyIgnoredTables(ctx context.Context, roots Roots, tables []TableName) (ignoredTables []TableName, err error) {
|
||||
schemas := GetUniqueSchemaNamesFromTableNames(tables)
|
||||
ignorePatternMap, err := GetIgnoredTablePatterns(ctx, roots, schemas)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, tbl := range tables {
|
||||
ignorePatterns := ignorePatternMap[tbl.Schema]
|
||||
ignored, err := ignorePatterns.IsTableNameIgnored(tbl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if conflict := AsDoltIgnoreInConflict(err); conflict != nil {
|
||||
// no-op
|
||||
} else if ignored == DontIgnore {
|
||||
// no-op
|
||||
} else if ignored == Ignore {
|
||||
ignoredTables = append(ignoredTables, tbl)
|
||||
} else {
|
||||
return nil, fmt.Errorf("IsTableNameIgnored returned ErrorOccurred but no error!")
|
||||
}
|
||||
}
|
||||
|
||||
return ignoredTables, nil
|
||||
}
|
||||
|
||||
// compilePattern takes a dolt_ignore pattern and generate a Regexp that matches against the same table names as the pattern.
|
||||
func compilePattern(pattern string) (*regexp.Regexp, error) {
|
||||
pattern = "^" + regexp.QuoteMeta(pattern) + "$"
|
||||
|
||||
@@ -16,6 +16,7 @@ package doltdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -27,6 +28,9 @@ import (
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
//go:embed AGENT.md
|
||||
var DefaultAgentDocValue string
|
||||
|
||||
type ctxKey int
|
||||
type ctxValue int
|
||||
|
||||
@@ -174,517 +178,6 @@ const (
|
||||
ReadmeDoc = "README.md"
|
||||
// AgentDoc is the key for accessing the agent documentation within the docs table
|
||||
AgentDoc = "AGENT.md"
|
||||
|
||||
DefaultAgentDocValue = `# AGENT.md - Dolt Database Operations Guide
|
||||
|
||||
This file provides guidance for AI agents working with Dolt databases to maximize productivity and follow best practices.
|
||||
|
||||
## Quick Start
|
||||
|
||||
Dolt is "Git for Data" - a SQL database with version control capabilities. All Git commands have Dolt equivalents:
|
||||
- ` + "`git add` → `dolt add`" + `
|
||||
- ` + "`git commit` → `dolt commit`" + `
|
||||
- ` + "`git branch` → `dolt branch`" + `
|
||||
- ` + "`git merge` → `dolt merge`" + `
|
||||
- ` + "`git diff` → `dolt diff`" + `
|
||||
|
||||
For help and documentation on commands, you can run ` + "`dolt --help`" + ` and ` + "`dolt <command> --help`" + `.
|
||||
|
||||
## Essential Dolt CLI Commands
|
||||
|
||||
### Repository Operations
|
||||
` + "```bash" + `
|
||||
# Initialize new database
|
||||
dolt init
|
||||
|
||||
# Clone existing database
|
||||
dolt clone <remote-url>
|
||||
|
||||
# Show current status
|
||||
dolt status
|
||||
|
||||
# View commit history
|
||||
dolt log
|
||||
` + "```" + `
|
||||
|
||||
### Branch Management
|
||||
` + "```bash" + `
|
||||
# List branches
|
||||
dolt branch
|
||||
|
||||
# Create new branch
|
||||
dolt branch <branch-name>
|
||||
|
||||
# Switch branches
|
||||
dolt checkout <branch-name>
|
||||
|
||||
# Create and switch to new branch
|
||||
dolt checkout -b <branch-name>
|
||||
` + "```" + `
|
||||
|
||||
### Data Operations
|
||||
` + "```bash" + `
|
||||
# Stage changes
|
||||
dolt add <table-name>
|
||||
dolt add . # stage all changes
|
||||
|
||||
# Commit changes
|
||||
dolt commit -m "commit message"
|
||||
|
||||
# View differences
|
||||
dolt diff
|
||||
dolt diff <table-name>
|
||||
dolt diff <branch1> <branch2>
|
||||
|
||||
# Merge branches
|
||||
dolt merge <branch-name>
|
||||
` + "```" + `
|
||||
|
||||
## Starting and Connecting to Dolt SQL Server
|
||||
|
||||
### Start SQL Server
|
||||
` + "```bash" + `
|
||||
# Start server on default port (3306)
|
||||
dolt sql-server
|
||||
|
||||
# Start on specific port
|
||||
dolt sql-server --port=3307
|
||||
|
||||
# Start with specific host
|
||||
dolt sql-server --host=0.0.0.0 --port=3307
|
||||
|
||||
# Start in background
|
||||
dolt sql-server --port=3307 &
|
||||
` + "```" + `
|
||||
|
||||
### Connecting to SQL Server
|
||||
` + "```bash" + `
|
||||
# Connect with dolt sql command
|
||||
dolt sql
|
||||
|
||||
# Connect with mysql client
|
||||
mysql -h 127.0.0.1 -P 3306 -u root
|
||||
|
||||
# Connect with specific database
|
||||
mysql -h 127.0.0.1 -P 3306 -u root -D <database-name>
|
||||
` + "```" + `
|
||||
|
||||
## Dolt Testing with dolt_test System Table
|
||||
|
||||
### Unit Testing with dolt_test
|
||||
|
||||
The dolt_test system table provides a powerful way to create and run unit tests for your database. This is the preferred method for testing data integrity, business rules, and schema validation.
|
||||
|
||||
#### Creating Tests
|
||||
|
||||
Tests are created by inserting rows into the ` + "`dolt_tests`" + ` system table:
|
||||
|
||||
` + "```sql" + `
|
||||
-- Create a simple test
|
||||
INSERT INTO ` + "`dolt_tests`" + ` VALUES (
|
||||
'test_user_count',
|
||||
'validation',
|
||||
'SELECT COUNT(*) as user_count FROM users;',
|
||||
'row_count',
|
||||
'>',
|
||||
'0'
|
||||
);
|
||||
|
||||
-- Create a test with expected result
|
||||
INSERT INTO ` + "`dolt_tests`" + ` VALUES (
|
||||
'test_valid_emails',
|
||||
'validation',
|
||||
'SELECT COUNT(*) FROM users WHERE email NOT LIKE "%@%";',
|
||||
'row_count',
|
||||
'==',
|
||||
'0'
|
||||
);
|
||||
|
||||
-- Create a schema validation test
|
||||
INSERT INTO ` + "`dolt_tests`" + ` VALUES (
|
||||
'test_users_schema',
|
||||
'schema',
|
||||
'DESCRIBE users;',
|
||||
'row_count',
|
||||
'>=',
|
||||
'5'
|
||||
);
|
||||
` + "```" + `
|
||||
|
||||
#### Test Structure
|
||||
|
||||
Each test row contains:
|
||||
- test_name: Unique identifier for the test
|
||||
- test_group: Optional grouping for tests (e.g., 'validation', 'schema', 'integration')
|
||||
- test_query: SQL query to execute
|
||||
- assertion_type: Type of assertion ('expected_rows', 'expected_columns', 'expected_single_value')
|
||||
- assertion_comparator: Comparison operator ('==', '>', '<', '>=', '<=', '!=')
|
||||
- assertion_value: Expected value for comparison
|
||||
|
||||
#### Running Tests
|
||||
|
||||
` + "```sql" + `
|
||||
-- Run all tests
|
||||
SELECT * FROM dolt_test_run();
|
||||
|
||||
-- Run specific test
|
||||
SELECT * FROM dolt_test_run('test_user_count');
|
||||
|
||||
-- Run tests with filtering
|
||||
SELECT * FROM dolt_test_run() WHERE test_name LIKE 'test_user%' AND status != 'PASS';
|
||||
` + "```" + `
|
||||
|
||||
#### Test Result Interpretation
|
||||
|
||||
The dolt_test_run() function returns:
|
||||
- test_name: Name of the test
|
||||
- status: PASS, FAIL, or ERROR
|
||||
- actual_result: Actual query result
|
||||
- expected_result: Expected result
|
||||
- message: Additional details
|
||||
|
||||
#### Advanced Testing Examples
|
||||
|
||||
` + "```sql" + `
|
||||
-- Test data integrity
|
||||
INSERT INTO ` + "`dolt_tests`" + ` VALUES (
|
||||
'test_no_orphaned_orders',
|
||||
'integrity',
|
||||
'SELECT COUNT(*) FROM orders o LEFT JOIN users u ON o.user_id = u.id WHERE u.id IS NULL;',
|
||||
'row_count',
|
||||
'==',
|
||||
'0'
|
||||
);
|
||||
|
||||
-- Test business rules
|
||||
INSERT INTO ` + "`dolt_tests`" + ` VALUES (
|
||||
'test_positive_prices',
|
||||
'business_rules',
|
||||
'SELECT COUNT(*) FROM products WHERE price <= 0;',
|
||||
'row_count',
|
||||
'==',
|
||||
'0'
|
||||
);
|
||||
|
||||
-- Test complex relationships
|
||||
INSERT INTO ` + "`dolt_tests`" + ` VALUES (
|
||||
'test_order_totals',
|
||||
'integrity',
|
||||
'SELECT COUNT(*) FROM orders o JOIN order_items oi ON o.id = oi.order_id GROUP BY o.id HAVING SUM(oi.quantity * oi.price) != o.total;',
|
||||
'row_count',
|
||||
'==',
|
||||
'0'
|
||||
);
|
||||
` + "```" + `
|
||||
|
||||
### Dolt CI for DoltHub Integration
|
||||
|
||||
Dolt CI is specifically designed for running tests on DoltHub when pull requests are created. Use this only for tests you want to run automatically on DoltHub.
|
||||
|
||||
#### Prerequisites for DoltHub CI
|
||||
- Requires Dolt v1.43.14 or later
|
||||
- Must initialize CI capabilities: ` + "`dolt ci init`" + `
|
||||
- Workflows defined in YAML files
|
||||
|
||||
#### Available CI Commands
|
||||
` + "```bash" + `
|
||||
# Initialize CI capabilities
|
||||
dolt ci init
|
||||
|
||||
# List available workflows
|
||||
dolt ci ls
|
||||
|
||||
# View workflow details
|
||||
dolt ci view <workflow-name>
|
||||
|
||||
# View specific job in workflow
|
||||
dolt ci view <workflow-name> <job-name>
|
||||
|
||||
# Run workflow locally (for testing before DoltHub)
|
||||
dolt ci run <workflow-name>
|
||||
` + "```" + `
|
||||
|
||||
#### Creating CI Workflows for DoltHub
|
||||
|
||||
Create workflow files that will run on DoltHub when pull requests are opened:
|
||||
|
||||
` + "```yaml" + `
|
||||
name: doltHub validation workflow
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- main
|
||||
jobs:
|
||||
- name: validate schema
|
||||
steps:
|
||||
- name: check required tables exist
|
||||
saved_query_name: show_tables
|
||||
expected_rows: ">= 3"
|
||||
|
||||
- name: validate user data
|
||||
saved_query_name: user_count_check
|
||||
expected_columns: "== 1"
|
||||
expected_rows: "> 0"
|
||||
|
||||
- name: data integrity checks
|
||||
steps:
|
||||
- name: check email format
|
||||
saved_query_name: valid_emails
|
||||
expected_rows: "== 0" # No invalid emails
|
||||
` + "```" + `
|
||||
|
||||
### Best Practices for Testing
|
||||
|
||||
1. **Use dolt_test for Unit Testing**
|
||||
- Create tests for data validation
|
||||
- Test business rules and constraints
|
||||
- Validate schema changes
|
||||
- Run tests frequently during development
|
||||
|
||||
2. **Use Dolt CI for DoltHub Integration**
|
||||
- Only for tests that should run on pull requests
|
||||
- Focus on integration and deployment validation
|
||||
- Test against production-like data
|
||||
|
||||
3. **Create Comprehensive Test Suites**
|
||||
- Test data integrity constraints
|
||||
- Validate business rules
|
||||
- Check schema requirements
|
||||
- Verify data relationships
|
||||
|
||||
4. **Version Control Your Tests**
|
||||
- Commit test definitions to repository
|
||||
- Track changes to test configuration
|
||||
- Use branches for test development
|
||||
|
||||
## System Tables for Version Control
|
||||
|
||||
Dolt exposes version control operations through system tables accessible via SQL:
|
||||
|
||||
### Core System Tables
|
||||
` + "```sql" + `
|
||||
-- View commit history
|
||||
SELECT * FROM dolt_log;
|
||||
|
||||
-- Check current status
|
||||
SELECT * FROM dolt_status;
|
||||
|
||||
-- View branch information
|
||||
SELECT * FROM dolt_branches;
|
||||
|
||||
-- See table diffs
|
||||
SELECT * FROM dolt_diff_<table_name>;
|
||||
|
||||
-- View schema changes
|
||||
SELECT * FROM dolt_schema_diff;
|
||||
|
||||
-- Check conflicts during merge
|
||||
SELECT * FROM dolt_conflicts_<table_name>;
|
||||
|
||||
-- View commit metadata
|
||||
SELECT * FROM dolt_commits;
|
||||
` + "```" + `
|
||||
|
||||
### Version Control Operations via SQL
|
||||
|
||||
When working in SQL sessions, you can execute version control operations using stored procedures:
|
||||
|
||||
` + "```sql" + `
|
||||
-- Stage and commit changes
|
||||
CALL dolt_add('.');
|
||||
CALL dolt_commit('-m', 'commit message');
|
||||
|
||||
-- Branch operations
|
||||
CALL dolt_branch('<branch_name>');
|
||||
CALL dolt_checkout('<branch_name>');
|
||||
CALL dolt_merge('<branch_name>');
|
||||
` + "```" + `
|
||||
|
||||
**Note:** Use CLI commands (` + "`dolt add`, `dolt commit`, etc." + `) for most operations. SQL procedures are useful when already in a SQL session.
|
||||
|
||||
### Advanced System Tables
|
||||
` + "```sql" + `
|
||||
-- View remotes
|
||||
SELECT * FROM dolt_remotes;
|
||||
|
||||
-- Check merge conflicts
|
||||
SELECT * FROM dolt_conflicts;
|
||||
|
||||
-- View statistics
|
||||
SELECT * FROM dolt_statistics;
|
||||
|
||||
-- See ignored tables
|
||||
SELECT * FROM dolt_ignore;
|
||||
` + "```" + `
|
||||
|
||||
## CLI vs SQL Approach
|
||||
|
||||
**Prefer CLI commands for:**
|
||||
- Version control operations (add, commit, branch, merge)
|
||||
- Repository management (init, clone, push, pull)
|
||||
- Conflict resolution
|
||||
- Status checking and history viewing
|
||||
|
||||
**Use SQL for:**
|
||||
- Data queries and analysis
|
||||
- Complex data transformations
|
||||
- Examining system tables (dolt_log, dolt_status, etc.)
|
||||
- When already in an active SQL session
|
||||
|
||||
## Schema Design Recommendations
|
||||
|
||||
### Use UUID Keys Instead of Auto-Increment
|
||||
|
||||
For Dolt's version control features, use UUID primary keys instead of auto-increment:
|
||||
|
||||
` + "```sql" + `
|
||||
-- Recommended
|
||||
CREATE TABLE users (
|
||||
id varchar(36) default(uuid()) primary key,
|
||||
name varchar(255)
|
||||
);
|
||||
|
||||
-- Avoid auto-increment with Dolt
|
||||
-- id int auto_increment primary key
|
||||
` + "```" + `
|
||||
|
||||
**Benefits:**
|
||||
- Prevents merge conflicts across branches and database clones
|
||||
- Automatic generation with default(uuid())
|
||||
- Works seamlessly in distributed environments
|
||||
|
||||
## Best Practices for Agents
|
||||
|
||||
### 1. Always Work on Feature Branches
|
||||
` + "```bash" + `
|
||||
# Create feature branch before making changes
|
||||
dolt checkout -b feature/agent-changes
|
||||
|
||||
# Make changes on feature branch
|
||||
dolt sql -q "INSERT INTO users VALUES (1, 'Alice');"
|
||||
|
||||
# Stage and commit
|
||||
dolt add .
|
||||
dolt commit -m "Add new user Alice"
|
||||
|
||||
# Switch back to main to merge
|
||||
dolt checkout main
|
||||
dolt merge feature/agent-changes
|
||||
` + "```" + `
|
||||
|
||||
### 2. Use SQL for Data Operations, CLI for Version Control
|
||||
` + "```bash" + `
|
||||
# Use dolt sql for data changes
|
||||
dolt sql -q "INSERT INTO users VALUES (1, 'Alice');"
|
||||
dolt sql -q "UPDATE products SET price = price * 1.1 WHERE category = 'electronics';"
|
||||
|
||||
# Check status and commit using CLI
|
||||
dolt status
|
||||
dolt add .
|
||||
dolt commit -m "Update user and product data"
|
||||
` + "```" + `
|
||||
|
||||
### 3. Validate Changes with System Tables
|
||||
` + "```sql" + `
|
||||
-- Before major operations, check current state
|
||||
SELECT * FROM dolt_status;
|
||||
SELECT * FROM dolt_branches;
|
||||
|
||||
-- After changes, verify with diffs
|
||||
SELECT * FROM dolt_diff_users;
|
||||
SELECT * FROM dolt_schema_diff;
|
||||
` + "```" + `
|
||||
|
||||
### 4. Use dolt_test for Data Validation
|
||||
Create tests to validate:
|
||||
- Data integrity after changes
|
||||
- Schema compatibility
|
||||
- Business rule compliance
|
||||
- Cross-table relationships
|
||||
|
||||
### 5. Handle Conflicts Gracefully
|
||||
` + "```bash" + `
|
||||
# Check for conflicts using CLI
|
||||
dolt conflicts cat <table_name>
|
||||
dolt conflicts resolve --ours <table_name>
|
||||
dolt conflicts resolve --theirs <table_name>
|
||||
|
||||
# Or use SQL to examine conflicts
|
||||
dolt sql -q "SELECT * FROM dolt_conflicts_<table_name>;"
|
||||
` + "```" + `
|
||||
|
||||
## Common Workflow Examples
|
||||
|
||||
### Data Migration Workflow
|
||||
` + "```bash" + `
|
||||
# Create migration branch
|
||||
dolt checkout -b migration/update-schema
|
||||
|
||||
# Apply schema changes via SQL
|
||||
dolt sql -q "ALTER TABLE users ADD COLUMN email VARCHAR(255);"
|
||||
|
||||
# Create validation tests
|
||||
dolt sql -q "INSERT INTO ` + "`dolt_tests`" + ` VALUES ('test_users_schema', 'schema', 'DESCRIBE users;', 'row_count', '>=', '6');"
|
||||
dolt sql -q "INSERT INTO ` + "`dolt_tests`" + ` VALUES ('test_email_column', 'schema', 'SELECT COUNT(*) FROM users WHERE email IS NULL;', 'row_count', '>=', '0');"
|
||||
|
||||
# Run tests to validate changes
|
||||
dolt sql -q "SELECT * FROM dolt_test_run();"
|
||||
|
||||
# Stage and commit
|
||||
dolt add .
|
||||
dolt commit -m "Add email column to users table"
|
||||
|
||||
# Merge back
|
||||
dolt checkout main
|
||||
dolt merge migration/update-schema
|
||||
` + "```" + `
|
||||
|
||||
### Data Analysis Workflow
|
||||
` + "```bash" + `
|
||||
# Create analysis branch
|
||||
dolt checkout -b analysis/user-behavior
|
||||
|
||||
# Create analysis tables via SQL
|
||||
dolt sql -q "CREATE TABLE user_metrics AS
|
||||
SELECT user_id, COUNT(*) as actions
|
||||
FROM user_actions
|
||||
GROUP BY user_id;"
|
||||
|
||||
# Create tests to validate analysis
|
||||
dolt sql -q "INSERT INTO ` + "`dolt_tests`" + ` VALUES ('test_metrics_created', 'analysis', 'SELECT COUNT(*) FROM user_metrics;', 'row_count', '>', '0');"
|
||||
dolt sql -q "INSERT INTO ` + "`dolt_tests`" + ` VALUES ('test_metrics_integrity', 'integrity', 'SELECT COUNT(*) FROM user_metrics um LEFT JOIN users u ON um.user_id = u.id WHERE u.id IS NULL;', 'row_count', '==', '0');"
|
||||
|
||||
# Run tests to validate analysis
|
||||
dolt sql -q "SELECT * FROM dolt_test_run();"
|
||||
|
||||
# Stage and commit using CLI
|
||||
dolt add user_metrics
|
||||
dolt commit -m "Add user behavior analysis"
|
||||
` + "```" + `
|
||||
|
||||
## Integration with External Tools
|
||||
|
||||
### Database Clients
|
||||
Most MySQL clients work with Dolt:
|
||||
- MySQL Workbench
|
||||
- phpMyAdmin
|
||||
- DataGrip
|
||||
- DBeaver
|
||||
|
||||
### Backup and Sync
|
||||
` + "```bash" + `
|
||||
# Push to remote
|
||||
dolt push origin main
|
||||
|
||||
# Pull changes
|
||||
dolt pull origin main
|
||||
|
||||
# Clone for backup
|
||||
dolt clone <remote-url> backup-location
|
||||
` + "```" + `
|
||||
|
||||
This guide enables agents to leverage Dolt's unique version control capabilities while maintaining data integrity and following collaborative development practices.`
|
||||
)
|
||||
|
||||
// GetDocTableName returns the name of the dolt table containing documents such as the license and readme
|
||||
|
||||
@@ -169,43 +169,28 @@ func AbortMerge(ctx *sql.Context, workingSet *doltdb.WorkingSet, roots doltdb.Ro
|
||||
return nil, fmt.Errorf("there is no merge to abort")
|
||||
}
|
||||
|
||||
tbls, err := doltdb.UnionTableNames(ctx, roots.Working, roots.Staged, roots.Head)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tbls, err = doltdb.ExcludeIgnoredTables(ctx, roots, tbls)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
roots, err = actions.MoveTablesFromHeadToWorking(ctx, roots, tbls)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
preMergeWorkingRoot := workingSet.MergeState().PreMergeWorkingRoot()
|
||||
preMergeWorkingTables, err := preMergeWorkingRoot.GetTableNames(ctx, doltdb.DefaultSchemaName, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nonIgnoredTables, err := doltdb.ExcludeIgnoredTables(ctx, roots, doltdb.ToTableNames(preMergeWorkingTables, doltdb.DefaultSchemaName))
|
||||
|
||||
// Revert the working set back to the pre-merge working root
|
||||
workingSet = workingSet.WithStagedRoot(roots.Head).WithWorkingRoot(preMergeWorkingRoot).WithStagedRoot(roots.Head)
|
||||
workingSet = workingSet.ClearMerge()
|
||||
|
||||
// Carry over any ignored tables (which could have been manually modified by a user while a merge was halted)
|
||||
ignoredTables, err := doltdb.IdentifyIgnoredTables(ctx, roots, doltdb.ToTableNames(preMergeWorkingTables, doltdb.DefaultSchemaName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
someTablesAreIgnored := len(nonIgnoredTables) != len(preMergeWorkingTables)
|
||||
|
||||
if someTablesAreIgnored {
|
||||
newWorking, err := actions.MoveTablesBetweenRoots(ctx, nonIgnoredTables, preMergeWorkingRoot, roots.Working)
|
||||
if len(ignoredTables) > 0 {
|
||||
newWorking, err := actions.MoveTablesBetweenRoots(ctx, ignoredTables, roots.Working, preMergeWorkingRoot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
workingSet = workingSet.WithWorkingRoot(newWorking)
|
||||
} else {
|
||||
workingSet = workingSet.WithWorkingRoot(preMergeWorkingRoot)
|
||||
}
|
||||
// Unstage everything by making Staged match Head
|
||||
workingSet = workingSet.WithStagedRoot(roots.Head)
|
||||
workingSet = workingSet.ClearMerge()
|
||||
|
||||
return workingSet, nil
|
||||
}
|
||||
|
||||
@@ -688,7 +688,7 @@ func sqlTypeString(t typeinfo.TypeInfo) string {
|
||||
}
|
||||
|
||||
// Extended types are string serializable, so we'll just prepend a tag
|
||||
if extendedType, ok := typ.(sqltypes.ExtendedType); ok {
|
||||
if extendedType, ok := typ.(sql.ExtendedType); ok {
|
||||
serializedType, err := sqltypes.SerializeTypeToString(extendedType)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
||||
@@ -476,7 +476,7 @@ func (si *schemaImpl) getKeyColumnsDescriptor(vs val.ValueStore, convertAddressC
|
||||
var handler val.TupleTypeHandler
|
||||
|
||||
_, contentHashedField := contentHashedFields[tag]
|
||||
extendedType, isExtendedType := sqlType.(gmstypes.ExtendedType)
|
||||
extendedType, isExtendedType := sqlType.(sql.ExtendedType)
|
||||
|
||||
if isExtendedType {
|
||||
encoding := EncodingFromSqlType(sqlType)
|
||||
@@ -573,7 +573,7 @@ func (si *schemaImpl) GetValueDescriptor(vs val.ValueStore) val.TupleDesc {
|
||||
collations = append(collations, sql.Collation_Unspecified)
|
||||
}
|
||||
|
||||
if extendedType, ok := sqlType.(gmstypes.ExtendedType); ok {
|
||||
if extendedType, ok := sqlType.(sql.ExtendedType); ok {
|
||||
switch encoding {
|
||||
case serial.EncodingExtendedAddr:
|
||||
handlers = append(handlers, val.NewExtendedAddressTypeHandler(vs, extendedType))
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/dolthub/go-mysql-server/sql/types"
|
||||
"github.com/dolthub/vitess/go/vt/proto/query"
|
||||
|
||||
"github.com/dolthub/dolt/go/gen/fb/serial"
|
||||
@@ -26,11 +25,11 @@ import (
|
||||
|
||||
// EncodingFromSqlType returns a serial.Encoding for a sql.Type.
|
||||
func EncodingFromSqlType(typ sql.Type) serial.Encoding {
|
||||
if extendedType, ok := typ.(types.ExtendedType); ok {
|
||||
if extendedType, ok := typ.(sql.ExtendedType); ok {
|
||||
switch extendedType.MaxSerializedWidth() {
|
||||
case types.ExtendedTypeSerializedWidth_64K:
|
||||
case sql.ExtendedTypeSerializedWidth_64K:
|
||||
return serial.EncodingExtended
|
||||
case types.ExtendedTypeSerializedWidth_Unbounded:
|
||||
case sql.ExtendedTypeSerializedWidth_Unbounded:
|
||||
// Always uses adaptive encoding for extended types, regardless of the setting of UseAdaptiveEncoding below.
|
||||
return serial.EncodingExtendedAdaptive
|
||||
default:
|
||||
@@ -112,6 +111,8 @@ func EncodingFromQueryType(typ query.Type) serial.Encoding {
|
||||
return serial.EncodingStringAdaptive
|
||||
}
|
||||
return serial.EncodingStringAddr
|
||||
case query.Type_VECTOR:
|
||||
return serial.EncodingBytesAdaptive
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown encoding %v", typ))
|
||||
}
|
||||
|
||||
@@ -121,8 +121,8 @@ func (d doltTypeCompatibilityChecker) IsTypeChangeCompatible(from, to typeinfo.T
|
||||
|
||||
// The TypeCompatibility checkers don't support ExtendedTypes added by integrators, so if we see
|
||||
// one, return early and report the types are not compatible.
|
||||
_, fromExtendedType := fromSqlType.(types.ExtendedType)
|
||||
_, toExtendedType := toSqlType.(types.ExtendedType)
|
||||
_, fromExtendedType := fromSqlType.(sql.ExtendedType)
|
||||
_, toExtendedType := toSqlType.(sql.ExtendedType)
|
||||
if fromExtendedType || toExtendedType {
|
||||
return res
|
||||
}
|
||||
|
||||
@@ -367,7 +367,7 @@ func mustCreateType(sqlType sql.Type) typeinfo.TypeInfo {
|
||||
// extendedType is a no-op implementation of gmstypes.ExtendedType, used for testing type compatibility with extended types.
|
||||
type extendedType struct{}
|
||||
|
||||
var _ gmstypes.ExtendedType = extendedType{}
|
||||
var _ sql.ExtendedType = extendedType{}
|
||||
|
||||
func (e extendedType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
|
||||
panic("unimplemented")
|
||||
@@ -381,7 +381,7 @@ func (e extendedType) Convert(ctx context.Context, i interface{}) (interface{},
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (e extendedType) ConvertToType(ctx *sql.Context, typ gmstypes.ExtendedType, val any) (any, error) {
|
||||
func (e extendedType) ConvertToType(ctx *sql.Context, typ sql.ExtendedType, val any) (any, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
@@ -433,6 +433,6 @@ func (e extendedType) FormatValue(val any) (string, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (e extendedType) MaxSerializedWidth() gmstypes.ExtendedTypeSerializedWidth {
|
||||
func (e extendedType) MaxSerializedWidth() sql.ExtendedTypeSerializedWidth {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ const (
|
||||
// extendedType is a type that refers to an ExtendedType in GMS. These are only supported in the new format, and have many
|
||||
// more limitations than traditional types (for now).
|
||||
type extendedType struct {
|
||||
sqlExtendedType gmstypes.ExtendedType
|
||||
sqlExtendedType sql.ExtendedType
|
||||
}
|
||||
|
||||
var _ TypeInfo = (*extendedType)(nil)
|
||||
@@ -49,7 +49,7 @@ func CreateExtendedTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
}
|
||||
|
||||
// CreateExtendedTypeFromSqlType creates a TypeInfo from the given extended type.
|
||||
func CreateExtendedTypeFromSqlType(typ gmstypes.ExtendedType) TypeInfo {
|
||||
func CreateExtendedTypeFromSqlType(typ sql.ExtendedType) TypeInfo {
|
||||
return &extendedType{typ}
|
||||
}
|
||||
|
||||
@@ -110,7 +110,7 @@ func (ti *extendedType) NomsKind() types.NomsKind {
|
||||
|
||||
// Promote implements the TypeInfo interface.
|
||||
func (ti *extendedType) Promote() TypeInfo {
|
||||
return &extendedType{ti.sqlExtendedType.Promote().(gmstypes.ExtendedType)}
|
||||
return &extendedType{ti.sqlExtendedType.Promote().(sql.ExtendedType)}
|
||||
}
|
||||
|
||||
// String implements the TypeInfo interface.
|
||||
|
||||
@@ -46,6 +46,7 @@ const (
|
||||
UuidTypeIdentifier Identifier = "uuid"
|
||||
VarBinaryTypeIdentifier Identifier = "varbinary"
|
||||
VarStringTypeIdentifier Identifier = "varstring"
|
||||
VectorTypeIdentifier Identifier = "vector"
|
||||
YearTypeIdentifier Identifier = "year"
|
||||
GeometryTypeIdentifier Identifier = "geometry"
|
||||
PointTypeIdentifier Identifier = "point"
|
||||
@@ -136,7 +137,7 @@ type TypeInfo interface {
|
||||
|
||||
// FromSqlType takes in a sql.Type and returns the most relevant TypeInfo.
|
||||
func FromSqlType(sqlType sql.Type) (TypeInfo, error) {
|
||||
if gmsExtendedType, ok := sqlType.(gmstypes.ExtendedType); ok {
|
||||
if gmsExtendedType, ok := sqlType.(sql.ExtendedType); ok {
|
||||
return CreateExtendedTypeFromSqlType(gmsExtendedType), nil
|
||||
}
|
||||
sqlType, err := fillInCollationWithDefault(sqlType)
|
||||
@@ -273,6 +274,12 @@ func FromSqlType(sqlType sql.Type) (TypeInfo, error) {
|
||||
return nil, fmt.Errorf(`expected "SetTypeIdentifier" from SQL basetype "Set"`)
|
||||
}
|
||||
return &setType{setSQLType}, nil
|
||||
case sqltypes.Vector:
|
||||
vectorSQLType, ok := sqlType.(gmstypes.VectorType)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`expected "VectorTypeIdentifier" from SQL basetype "Vector"`)
|
||||
}
|
||||
return &vectorType{vectorSQLType}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf(`no type info can be created from SQL base type "%v"`, sqlType.String())
|
||||
}
|
||||
|
||||
@@ -0,0 +1,162 @@
|
||||
// Copyright 2025 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package typeinfo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
"github.com/dolthub/go-mysql-server/sql/values"
|
||||
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
const (
|
||||
vectorTypeParam_Length = "length"
|
||||
)
|
||||
|
||||
// As a type, this is modeled more after MySQL's story for binary data. There, it's treated
|
||||
// as a string that is interpreted as raw bytes, rather than as a bespoke data structure,
|
||||
// and thus this is mirrored here in its implementation. This will minimize any differences
|
||||
// that could arise.
|
||||
//
|
||||
// This type handles the BLOB types. BINARY and VARBINARY are handled by inlineBlobType.
|
||||
type vectorType struct {
|
||||
sqlVectorType gmstypes.VectorType
|
||||
}
|
||||
|
||||
var _ TypeInfo = (*vectorType)(nil)
|
||||
|
||||
// ConvertNomsValueToValue implements TypeInfo interface.
|
||||
func (ti *vectorType) ConvertNomsValueToValue(v types.Value) (interface{}, error) {
|
||||
if val, ok := v.(types.Blob); ok {
|
||||
return fromBlob(val)
|
||||
}
|
||||
if _, ok := v.(types.Null); ok || v == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf(`"%v" cannot convert NomsKind "%v" to a value`, ti, v.Kind())
|
||||
}
|
||||
|
||||
// ReadFrom reads a go value from a noms types.CodecReader directly
|
||||
func (ti *vectorType) ReadFrom(_ *types.NomsBinFormat, reader types.CodecReader) (interface{}, error) {
|
||||
k := reader.PeekKind()
|
||||
switch k {
|
||||
case types.BlobKind:
|
||||
val, err := reader.ReadBlob()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fromBlob(val)
|
||||
case types.NullKind:
|
||||
_ = reader.ReadKind()
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf(`"%v" cannot convert NomsKind "%v" to a value`, ti.String(), k)
|
||||
}
|
||||
|
||||
// ConvertValueToNomsValue implements TypeInfo interface.
|
||||
func (ti *vectorType) ConvertValueToNomsValue(ctx context.Context, vrw types.ValueReadWriter, v interface{}) (types.Value, error) {
|
||||
if v == nil {
|
||||
return types.NullValue, nil
|
||||
}
|
||||
strVal, _, err := ti.sqlVectorType.Convert(ctx, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
val, ok := strVal.([]byte)
|
||||
if ok {
|
||||
return types.NewBlob(ctx, vrw, strings.NewReader(string(val)))
|
||||
}
|
||||
return nil, fmt.Errorf(`"%v" cannot convert value "%v" of type "%T" as it is invalid`, ti.String(), v, v)
|
||||
}
|
||||
|
||||
// Equals implements TypeInfo interface.
|
||||
func (ti *vectorType) Equals(other TypeInfo) bool {
|
||||
if other == nil {
|
||||
return false
|
||||
}
|
||||
if ti2, ok := other.(*vectorType); ok {
|
||||
return ti.sqlVectorType.Dimensions == ti2.sqlVectorType.Dimensions
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FormatValue implements TypeInfo interface.
|
||||
func (ti *vectorType) FormatValue(v types.Value) (*string, error) {
|
||||
if val, ok := v.(types.Blob); ok {
|
||||
resStr, err := fromBlob(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// This is safe (See https://go101.org/article/unsafe.html)
|
||||
return (*string)(unsafe.Pointer(&resStr)), nil
|
||||
}
|
||||
if _, ok := v.(types.Null); ok || v == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf(`"%v" cannot convert NomsKind "%v" to a string`, ti, v.Kind())
|
||||
}
|
||||
|
||||
// GetTypeIdentifier implements TypeInfo interface.
|
||||
func (ti *vectorType) GetTypeIdentifier() Identifier {
|
||||
return VectorTypeIdentifier
|
||||
}
|
||||
|
||||
// GetTypeParams implements TypeInfo interface.
|
||||
func (ti *vectorType) GetTypeParams() map[string]string {
|
||||
return map[string]string{
|
||||
vectorTypeParam_Length: strconv.FormatInt(int64(ti.sqlVectorType.Dimensions), 10),
|
||||
}
|
||||
}
|
||||
|
||||
// IsValid implements TypeInfo interface.
|
||||
func (ti *vectorType) IsValid(v types.Value) bool {
|
||||
if val, ok := v.(types.Blob); ok {
|
||||
if int(val.Len()) == ti.sqlVectorType.Dimensions*int(values.Float32Size) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if _, ok := v.(types.Null); ok || v == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// NomsKind implements TypeInfo interface.
|
||||
func (ti *vectorType) NomsKind() types.NomsKind {
|
||||
return types.BlobKind
|
||||
}
|
||||
|
||||
// Promote implements TypeInfo interface.
|
||||
func (ti *vectorType) Promote() TypeInfo {
|
||||
return ti
|
||||
}
|
||||
|
||||
// String implements TypeInfo interface.
|
||||
func (ti *vectorType) String() string {
|
||||
return fmt.Sprintf(`Vector(%v)`, ti.sqlVectorType.Dimensions)
|
||||
}
|
||||
|
||||
// ToSqlType implements TypeInfo interface.
|
||||
func (ti *vectorType) ToSqlType() sql.Type {
|
||||
return ti.sqlVectorType
|
||||
}
|
||||
@@ -408,66 +408,85 @@ func ResolveSchemaConflicts(ctx *sql.Context, ddb *doltdb.DoltDB, ws *doltdb.Wor
|
||||
return ws.WithWorkingRoot(root).WithUnmergableTables(unmerged).WithMergedTables(merged), nil
|
||||
}
|
||||
|
||||
func ResolveDataConflictsForTable(ctx *sql.Context, root doltdb.RootValue, tblName doltdb.TableName, ours bool, getEditorOpts func() (editor.Options, error)) (doltdb.RootValue, bool, error) {
|
||||
tbl, ok, err := root.GetTable(ctx, tblName)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, false, doltdb.ErrTableNotFound
|
||||
}
|
||||
|
||||
if has, err := tbl.HasConflicts(ctx); err != nil {
|
||||
return nil, false, err
|
||||
} else if !has {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
sch, err := tbl.GetSchema(ctx)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
_, ourSch, theirSch, err := tbl.GetConflictSchemas(ctx, tblName)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if ours && !schema.ColCollsAreEqual(sch.GetAllCols(), ourSch.GetAllCols()) {
|
||||
return nil, false, ErrConfSchIncompatible
|
||||
} else if !ours && !schema.ColCollsAreEqual(sch.GetAllCols(), theirSch.GetAllCols()) {
|
||||
return nil, false, ErrConfSchIncompatible
|
||||
}
|
||||
|
||||
if !ours {
|
||||
if tbl.Format() == types.Format_DOLT {
|
||||
tbl, err = resolveProllyConflicts(ctx, tbl, tblName, ourSch, sch)
|
||||
} else {
|
||||
opts, err := getEditorOpts()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
tbl, err = resolveNomsConflicts(ctx, opts, tbl, tblName.Name, sch)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
}
|
||||
|
||||
newRoot, err := clearTableAndUpdateRoot(ctx, root, tbl, tblName)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
err = validateConstraintViolations(ctx, root, newRoot, tblName)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return newRoot, true, nil
|
||||
}
|
||||
|
||||
func ResolveDataConflicts(ctx *sql.Context, dSess *dsess.DoltSession, root doltdb.RootValue, dbName string, ours bool, tblNames []doltdb.TableName) error {
|
||||
getEditorOpts := func() (editor.Options, error) {
|
||||
state, _, err := dSess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return editor.Options{}, err
|
||||
}
|
||||
var opts editor.Options
|
||||
if ws := state.WriteSession(); ws != nil {
|
||||
opts = ws.GetOptions()
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
for _, tblName := range tblNames {
|
||||
tbl, ok, err := root.GetTable(ctx, tblName)
|
||||
newRoot, hasConflicts, err := ResolveDataConflictsForTable(ctx, root, tblName, ours, getEditorOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return doltdb.ErrTableNotFound
|
||||
}
|
||||
|
||||
if has, err := tbl.HasConflicts(ctx); err != nil {
|
||||
return err
|
||||
} else if !has {
|
||||
if !hasConflicts {
|
||||
continue
|
||||
}
|
||||
|
||||
sch, err := tbl.GetSchema(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, ourSch, theirSch, err := tbl.GetConflictSchemas(ctx, tblName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ours && !schema.ColCollsAreEqual(sch.GetAllCols(), ourSch.GetAllCols()) {
|
||||
return ErrConfSchIncompatible
|
||||
} else if !ours && !schema.ColCollsAreEqual(sch.GetAllCols(), theirSch.GetAllCols()) {
|
||||
return ErrConfSchIncompatible
|
||||
}
|
||||
|
||||
if !ours {
|
||||
if tbl.Format() == types.Format_DOLT {
|
||||
tbl, err = resolveProllyConflicts(ctx, tbl, tblName, ourSch, sch)
|
||||
} else {
|
||||
state, _, err := dSess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var opts editor.Options
|
||||
if ws := state.WriteSession(); ws != nil {
|
||||
opts = ws.GetOptions()
|
||||
}
|
||||
tbl, err = resolveNomsConflicts(ctx, opts, tbl, tblName.Name, sch)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
newRoot, err := clearTableAndUpdateRoot(ctx, root, tbl, tblName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateConstraintViolations(ctx, root, newRoot, tblName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
root = newRoot
|
||||
}
|
||||
return dSess.SetWorkingRoot(ctx, dbName, root)
|
||||
@@ -513,9 +532,6 @@ func DoDoltConflictsResolve(ctx *sql.Context, args []string) (int, error) {
|
||||
|
||||
if len(strTableNames) == 1 && strTableNames[0] == "." {
|
||||
all := actions.GetAllTableNames(ctx, ws.WorkingRoot())
|
||||
if err != nil {
|
||||
return 1, nil
|
||||
}
|
||||
tableNames = all
|
||||
} else {
|
||||
for _, tblName := range strTableNames {
|
||||
|
||||
@@ -16,12 +16,15 @@ package dtablefunctions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/dolthub/go-mysql-server/sql/expression"
|
||||
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
"gopkg.in/src-d/go-errors.v1"
|
||||
|
||||
"github.com/dolthub/dolt/go/cmd/dolt/cli"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/diff"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/merge"
|
||||
@@ -32,6 +35,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dtables"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
|
||||
dolttable "github.com/dolthub/dolt/go/libraries/doltcore/table"
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
@@ -56,6 +60,8 @@ type DiffTableFunction struct {
|
||||
fromDate *types.Timestamp
|
||||
toDate *types.Timestamp
|
||||
sqlSch sql.Schema
|
||||
showSkinny bool
|
||||
includeCols map[string]struct{}
|
||||
}
|
||||
|
||||
// NewInstance creates a new instance of TableFunction interface
|
||||
@@ -100,26 +106,25 @@ func (dtf *DiffTableFunction) WithDatabase(database sql.Database) (sql.Node, err
|
||||
|
||||
// Expressions implements the sql.Expressioner interface
|
||||
func (dtf *DiffTableFunction) Expressions() []sql.Expression {
|
||||
exprs := []sql.Expression{}
|
||||
|
||||
if dtf.dotCommitExpr != nil {
|
||||
return []sql.Expression{
|
||||
dtf.dotCommitExpr, dtf.tableNameExpr,
|
||||
}
|
||||
}
|
||||
return []sql.Expression{
|
||||
dtf.fromCommitExpr, dtf.toCommitExpr, dtf.tableNameExpr,
|
||||
exprs = append(exprs, dtf.dotCommitExpr, dtf.tableNameExpr)
|
||||
} else {
|
||||
exprs = append(exprs, dtf.fromCommitExpr, dtf.toCommitExpr, dtf.tableNameExpr)
|
||||
}
|
||||
return exprs
|
||||
}
|
||||
|
||||
// WithExpressions implements the sql.Expressioner interface
|
||||
func (dtf *DiffTableFunction) WithExpressions(expression ...sql.Expression) (sql.Node, error) {
|
||||
if len(expression) < 2 {
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(dtf.Name(), "2 to 3", len(expression))
|
||||
}
|
||||
|
||||
func (dtf *DiffTableFunction) WithExpressions(expressions ...sql.Expression) (sql.Node, error) {
|
||||
newDtf := *dtf
|
||||
// TODO: For now, we will only support literal / fully-resolved arguments to the
|
||||
// DiffTableFunction to avoid issues where the schema is needed in the analyzer
|
||||
// before the arguments could be resolved.
|
||||
for _, expr := range expression {
|
||||
var exprStrs []string
|
||||
strToExpr := map[string]sql.Expression{}
|
||||
for _, expr := range expressions {
|
||||
if !expr.Resolved() {
|
||||
return nil, ErrInvalidNonLiteralArgument.New(dtf.Name(), expr.String())
|
||||
}
|
||||
@@ -127,22 +132,52 @@ func (dtf *DiffTableFunction) WithExpressions(expression ...sql.Expression) (sql
|
||||
if _, ok := expr.(sql.FunctionExpression); ok {
|
||||
return nil, ErrInvalidNonLiteralArgument.New(dtf.Name(), expr.String())
|
||||
}
|
||||
strVal := expr.String()
|
||||
if lit, ok := expr.(*expression.Literal); ok { // rm quotes from string literals
|
||||
strVal = fmt.Sprintf("%v", lit.Value())
|
||||
}
|
||||
exprStrs = append(exprStrs, strVal) // args extracted from apr later to filter out options
|
||||
strToExpr[strVal] = expr
|
||||
}
|
||||
|
||||
newDtf := *dtf
|
||||
if strings.Contains(expression[0].String(), "..") {
|
||||
if len(expression) != 2 {
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(fmt.Sprintf("%v with .. or ...", newDtf.Name()), 2, len(expression))
|
||||
apr, err := cli.CreateDiffArgParser(true).Parse(exprStrs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if apr.Contains(cli.SkinnyFlag) {
|
||||
newDtf.showSkinny = true
|
||||
}
|
||||
|
||||
if cols, ok := apr.GetValueList(cli.IncludeCols); ok {
|
||||
newDtf.includeCols = make(map[string]struct{})
|
||||
for _, col := range cols {
|
||||
newDtf.includeCols[col] = struct{}{}
|
||||
}
|
||||
newDtf.dotCommitExpr = expression[0]
|
||||
newDtf.tableNameExpr = expression[1]
|
||||
}
|
||||
|
||||
expressions = []sql.Expression{}
|
||||
for _, posArg := range apr.Args {
|
||||
expressions = append(expressions, strToExpr[posArg])
|
||||
}
|
||||
|
||||
if len(expressions) < 2 {
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(dtf.Name(), "2 to 3", len(expressions))
|
||||
}
|
||||
|
||||
if strings.Contains(expressions[0].String(), "..") {
|
||||
if len(expressions) != 2 {
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(fmt.Sprintf("%v with .. or ...", newDtf.Name()), 2, len(expressions))
|
||||
}
|
||||
newDtf.dotCommitExpr = expressions[0]
|
||||
newDtf.tableNameExpr = expressions[1]
|
||||
} else {
|
||||
if len(expression) != 3 {
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(newDtf.Name(), 3, len(expression))
|
||||
if len(expressions) != 3 {
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(newDtf.Name(), 3, len(expressions))
|
||||
}
|
||||
newDtf.fromCommitExpr = expression[0]
|
||||
newDtf.toCommitExpr = expression[1]
|
||||
newDtf.tableNameExpr = expression[2]
|
||||
newDtf.fromCommitExpr = expressions[0]
|
||||
newDtf.toCommitExpr = expressions[1]
|
||||
newDtf.tableNameExpr = expressions[2]
|
||||
}
|
||||
|
||||
fromCommitVal, toCommitVal, dotCommitVal, tableName, err := newDtf.evaluateArguments()
|
||||
@@ -423,6 +458,110 @@ func (dtf *DiffTableFunction) evaluateArguments() (interface{}, interface{}, int
|
||||
return fromCommitVal, toCommitVal, nil, tableName, nil
|
||||
}
|
||||
|
||||
// filterDeltaSchemaToSkinnyCols creates a filtered version of the table delta that omits columns which are identical
|
||||
// in type and value across all rows in both schemas, except for primary key columns or explicitly included using the
|
||||
// include-cols option. This also updates dtf.tableDelta with the filtered result.
|
||||
func (dtf *DiffTableFunction) filterDeltaSchemaToSkinnyCols(ctx *sql.Context, delta *diff.TableDelta) (*diff.TableDelta, error) {
|
||||
if delta.FromTable == nil || delta.ToTable == nil {
|
||||
return delta, nil
|
||||
}
|
||||
|
||||
// gather map of potential cols for removal from skinny diff
|
||||
equalDiffColsIndices := map[string][2]int{}
|
||||
toCols := delta.ToSch.GetAllCols()
|
||||
for fromIdx, fromCol := range delta.FromSch.GetAllCols().GetColumns() {
|
||||
if _, ok := dtf.includeCols[fromCol.Name]; ok {
|
||||
continue // user explicitly included this column
|
||||
}
|
||||
|
||||
col, ok := delta.ToSch.GetAllCols().GetByName(fromCol.Name)
|
||||
if !ok { // column was dropped
|
||||
continue
|
||||
}
|
||||
if fromCol.TypeInfo.Equals(col.TypeInfo) {
|
||||
toIdx := toCols.TagToIdx[toCols.NameToCol[fromCol.Name].Tag]
|
||||
equalDiffColsIndices[fromCol.Name] = [2]int{fromIdx, toIdx}
|
||||
}
|
||||
}
|
||||
|
||||
fromRowData, err := delta.FromTable.GetRowData(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toRowData, err := delta.ToTable.GetRowData(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fromIter, err := dolttable.NewTableIterator(ctx, delta.FromSch, fromRowData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer fromIter.Close(ctx)
|
||||
|
||||
toIter, err := dolttable.NewTableIterator(ctx, delta.ToSch, toRowData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer toIter.Close(ctx)
|
||||
|
||||
for len(equalDiffColsIndices) > 0 {
|
||||
fromRow, fromErr := fromIter.Next(ctx)
|
||||
toRow, toErr := toIter.Next(ctx)
|
||||
|
||||
if fromErr == io.EOF && toErr == io.EOF {
|
||||
break
|
||||
}
|
||||
|
||||
if fromErr != nil && fromErr != io.EOF {
|
||||
return nil, fromErr
|
||||
}
|
||||
|
||||
if toErr != nil && toErr != io.EOF {
|
||||
return nil, toErr
|
||||
}
|
||||
|
||||
// xor: if only one is nil, then all cols are diffs
|
||||
if (fromRow == nil) != (toRow == nil) {
|
||||
equalDiffColsIndices = map[string][2]int{}
|
||||
break
|
||||
}
|
||||
|
||||
if fromRow == nil && toRow == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for colName, idx := range equalDiffColsIndices {
|
||||
if fromRow[idx[0]] != toRow[idx[1]] { // same row and col, values differ
|
||||
delete(equalDiffColsIndices, colName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var fromSkCols []schema.Column
|
||||
for _, col := range delta.FromSch.GetAllCols().GetColumns() {
|
||||
_, ok := equalDiffColsIndices[col.Name]
|
||||
if col.IsPartOfPK || !ok {
|
||||
fromSkCols = append(fromSkCols, col)
|
||||
}
|
||||
}
|
||||
|
||||
var toSkCols []schema.Column
|
||||
for _, col := range delta.ToSch.GetAllCols().GetColumns() {
|
||||
_, ok := equalDiffColsIndices[col.Name]
|
||||
if col.IsPartOfPK || !ok {
|
||||
toSkCols = append(toSkCols, col)
|
||||
}
|
||||
}
|
||||
|
||||
skDelta := *delta
|
||||
skDelta.FromSch = schema.MustSchemaFromCols(schema.NewColCollection(fromSkCols...))
|
||||
skDelta.ToSch = schema.MustSchemaFromCols(schema.NewColCollection(toSkCols...))
|
||||
dtf.tableDelta = skDelta
|
||||
return &skDelta, nil
|
||||
}
|
||||
|
||||
func (dtf *DiffTableFunction) generateSchema(ctx *sql.Context, fromCommitVal, toCommitVal, dotCommitVal interface{}, tableName string) error {
|
||||
if !dtf.Resolved() {
|
||||
return nil
|
||||
@@ -438,27 +577,27 @@ func (dtf *DiffTableFunction) generateSchema(ctx *sql.Context, fromCommitVal, to
|
||||
return err
|
||||
}
|
||||
|
||||
if dtf.showSkinny {
|
||||
skDelta, err := dtf.filterDeltaSchemaToSkinnyCols(ctx, &delta)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
delta = *skDelta
|
||||
}
|
||||
|
||||
fromTable, fromTableExists := delta.FromTable, delta.FromTable != nil
|
||||
toTable, toTableExists := delta.ToTable, delta.ToTable != nil
|
||||
|
||||
if !toTableExists && !fromTableExists {
|
||||
var format *types.NomsBinFormat
|
||||
if toTableExists {
|
||||
format = toTable.Format()
|
||||
} else if fromTableExists {
|
||||
format = fromTable.Format()
|
||||
} else {
|
||||
return sql.ErrTableNotFound.New(tableName)
|
||||
}
|
||||
|
||||
var toSchema, fromSchema schema.Schema
|
||||
var format *types.NomsBinFormat
|
||||
|
||||
if fromTableExists {
|
||||
fromSchema = delta.FromSch
|
||||
format = fromTable.Format()
|
||||
}
|
||||
|
||||
if toTableExists {
|
||||
toSchema = delta.ToSch
|
||||
format = toTable.Format()
|
||||
}
|
||||
|
||||
diffTableSch, j, err := dtables.GetDiffTableSchemaAndJoiner(format, fromSchema, toSchema)
|
||||
diffTableSch, j, err := dtables.GetDiffTableSchemaAndJoiner(format, delta.FromSch, delta.ToSch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -571,15 +710,14 @@ func (dtf *DiffTableFunction) IsReadOnly() bool {
|
||||
|
||||
// String implements the Stringer interface
|
||||
func (dtf *DiffTableFunction) String() string {
|
||||
args := []string{}
|
||||
if dtf.dotCommitExpr != nil {
|
||||
return fmt.Sprintf("DOLT_DIFF(%s, %s)",
|
||||
dtf.dotCommitExpr.String(),
|
||||
dtf.tableNameExpr.String())
|
||||
args = append(args, dtf.dotCommitExpr.String(), dtf.tableNameExpr.String())
|
||||
} else {
|
||||
args = append(args, dtf.fromCommitExpr.String(), dtf.toCommitExpr.String(), dtf.tableNameExpr.String())
|
||||
}
|
||||
return fmt.Sprintf("DOLT_DIFF(%s, %s, %s)",
|
||||
dtf.fromCommitExpr.String(),
|
||||
dtf.toCommitExpr.String(),
|
||||
dtf.tableNameExpr.String())
|
||||
|
||||
return fmt.Sprintf("DOLT_DIFF(%s)", strings.Join(args, ", "))
|
||||
}
|
||||
|
||||
// Name implements the sql.TableFunction interface
|
||||
|
||||
@@ -16,6 +16,7 @@ package dtablefunctions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
gms "github.com/dolthub/go-mysql-server"
|
||||
@@ -191,7 +192,7 @@ func (trtf *TestsRunTableFunction) RowIter(_ *sql.Context, _ sql.Row) (sql.RowIt
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, row := range *testRows {
|
||||
for _, row := range testRows {
|
||||
result, err := trtf.queryAndAssert(row)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -254,7 +255,7 @@ func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResul
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (trtf *TestsRunTableFunction) getDoltTestsData(arg string) (*[]sql.Row, error) {
|
||||
func (trtf *TestsRunTableFunction) getDoltTestsData(arg string) ([]sql.Row, error) {
|
||||
var queries []string
|
||||
|
||||
if arg == "*" {
|
||||
@@ -280,12 +281,21 @@ func (trtf *TestsRunTableFunction) getDoltTestsData(arg string) (*[]sql.Row, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := sql.RowIterToRows(trtf.ctx, iter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Calling iter.Close(ctx) will cause TrackedRowIter to cancel the context, causing problems when running with
|
||||
// dolt sql-server. Since we only support `SELECT...` queries anyway, it's not necessary to Close() the iter.
|
||||
var rows []sql.Row
|
||||
for {
|
||||
row, rErr := iter.Next(trtf.ctx)
|
||||
if rErr == io.EOF {
|
||||
break
|
||||
}
|
||||
if rErr != nil {
|
||||
return nil, rErr
|
||||
}
|
||||
rows = append(rows, row)
|
||||
}
|
||||
if len(rows) > 0 {
|
||||
return &rows, nil
|
||||
return rows, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("could not find tests for argument: %s", arg)
|
||||
|
||||
@@ -283,7 +283,7 @@ var ModifyAndChangeColumnScripts = []queries.ScriptTest{
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table people modify rating double default 'not a number'",
|
||||
ExpectedErrStr: "incompatible type for default value: error: 'not a number' is not a valid value for 'double'",
|
||||
ExpectedErrStr: "incompatible type for default value: Truncated incorrect double value: not a number",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -909,6 +909,12 @@ func TestVectorFunctions(t *testing.T) {
|
||||
enginetest.TestVectorFunctions(t, harness)
|
||||
}
|
||||
|
||||
func TestVectorType(t *testing.T) {
|
||||
harness := newDoltHarness(t)
|
||||
defer harness.Close()
|
||||
enginetest.TestVectorType(t, harness)
|
||||
}
|
||||
|
||||
func TestIndexPrefix(t *testing.T) {
|
||||
skipOldFormat(t)
|
||||
harness := newDoltHarness(t)
|
||||
|
||||
@@ -15,6 +15,9 @@
|
||||
package enginetest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/enginetest/queries"
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
@@ -818,7 +821,138 @@ var Dolt1DiffSystemTableScripts = []queries.ScriptTest{
|
||||
},
|
||||
}
|
||||
|
||||
// assertDoltDiffColumnCount returns assertions that verify a dolt_diff view
|
||||
// has the expected number of distinct data columns (excluding commit metadata).
|
||||
func assertDoltDiffColumnCount(view, selectStmt string, expected int64) []queries.ScriptTestAssertion {
|
||||
excluded := []string{
|
||||
"'to_commit'",
|
||||
"'from_commit'",
|
||||
"'to_commit_date'",
|
||||
"'from_commit_date'",
|
||||
"'diff_type'",
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT COUNT(DISTINCT REPLACE(REPLACE(column_name, 'to_', ''), 'from_', ''))
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '%s'
|
||||
AND column_name NOT IN (%s)`,
|
||||
view, strings.Join(excluded, ", "),
|
||||
)
|
||||
|
||||
return []queries.ScriptTestAssertion{
|
||||
{Query: fmt.Sprintf("DROP VIEW IF EXISTS %s;", view)},
|
||||
{Query: fmt.Sprintf("CREATE VIEW %s AS %s;", view, selectStmt)},
|
||||
{Query: query, Expected: []sql.Row{{expected}}},
|
||||
{Query: fmt.Sprintf("DROP VIEW %s;", view)},
|
||||
}
|
||||
}
|
||||
|
||||
var DiffTableFunctionScriptTests = []queries.ScriptTest{
|
||||
{
|
||||
Name: "dolt_diff: SELECT * skinny schema visibility",
|
||||
SetUpScript: []string{
|
||||
`CREATE TABLE t (
|
||||
pk BIGINT NOT NULL COMMENT 'tag:0',
|
||||
c1 BIGINT COMMENT 'tag:1',
|
||||
c2 BIGINT COMMENT 'tag:2',
|
||||
c3 BIGINT COMMENT 'tag:3',
|
||||
c4 BIGINT COMMENT 'tag:4',
|
||||
c5 BIGINT COMMENT 'tag:5',
|
||||
PRIMARY KEY (pk)
|
||||
);`,
|
||||
"call dolt_add('.')",
|
||||
"set @C0 = '';",
|
||||
"call dolt_commit_hash_out(@C0, '-m', 'Created table t');",
|
||||
"INSERT INTO t VALUES (0,1,2,3,4,5), (1,1,2,3,4,5);",
|
||||
"call dolt_add('.')",
|
||||
"set @C1 = '';",
|
||||
"call dolt_commit_hash_out(@C1, '-m', 'Added initial data');",
|
||||
|
||||
"UPDATE t SET c1=100, c3=300 WHERE pk=0;",
|
||||
"UPDATE t SET c2=200 WHERE pk=1;",
|
||||
"call dolt_add('.')",
|
||||
"set @C2 = '';",
|
||||
"call dolt_commit_hash_out(@C2, '-m', 'Updated some columns');",
|
||||
|
||||
"ALTER TABLE t ADD COLUMN c6 BIGINT;",
|
||||
"UPDATE t SET c6=600 WHERE pk=0;",
|
||||
"call dolt_add('.')",
|
||||
"set @C3 = '';",
|
||||
"call dolt_commit_hash_out(@C3, '-m', 'Added new column and updated it');",
|
||||
|
||||
"DELETE FROM t WHERE pk=1;",
|
||||
"call dolt_add('.')",
|
||||
"set @C4 = '';",
|
||||
"call dolt_commit_hash_out(@C4, '-m', 'Deleted a row');",
|
||||
},
|
||||
Assertions: func() []queries.ScriptTestAssertion {
|
||||
asserts := []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "SELECT d.to_pk, d.to_c1, d.to_c2, d.to_c3, d.to_c4, d.to_c5, d.from_pk, d.from_c1, d.from_c2, d.from_c3, d.from_c4, d.from_c5, d.diff_type " +
|
||||
"FROM (SELECT * FROM dolt_diff('--skinny', @C0, @C1, 't')) d " +
|
||||
"ORDER BY COALESCE(d.to_pk, d.from_pk)",
|
||||
Expected: []sql.Row{
|
||||
{int64(0), int64(1), int64(2), int64(3), int64(4), int64(5), interface{}(nil), interface{}(nil), interface{}(nil), interface{}(nil), interface{}(nil), interface{}(nil), "added"},
|
||||
{int64(1), int64(1), int64(2), int64(3), int64(4), int64(5), interface{}(nil), interface{}(nil), interface{}(nil), interface{}(nil), interface{}(nil), interface{}(nil), "added"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "SELECT d.to_pk, d.to_c1, d.to_c2, d.to_c3, d.from_pk, d.from_c1, d.from_c2, d.from_c3, d.diff_type " +
|
||||
"FROM (SELECT * FROM dolt_diff(@C1, @C2, 't')) d " +
|
||||
"ORDER BY COALESCE(d.to_pk, d.from_pk)",
|
||||
Expected: []sql.Row{
|
||||
{int64(0), int64(100), int64(2), int64(300), int64(0), int64(1), int64(2), int64(3), "modified"},
|
||||
{int64(1), int64(1), int64(200), int64(3), int64(1), int64(1), int64(2), int64(3), "modified"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "SELECT d.to_pk, d.to_c1, d.to_c2, d.to_c3, d.diff_type " +
|
||||
"FROM (SELECT * FROM dolt_diff('--skinny', @C1, @C2, 't')) d " +
|
||||
"ORDER BY d.to_pk",
|
||||
Expected: []sql.Row{
|
||||
{int64(0), int64(100), int64(2), int64(300), "modified"},
|
||||
{int64(1), int64(1), int64(200), int64(3), "modified"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "SELECT d.to_pk, d.to_c6, d.diff_type " +
|
||||
"FROM (SELECT * FROM dolt_diff('--skinny', @C2, @C3, 't')) d",
|
||||
Expected: []sql.Row{
|
||||
{int64(0), int64(600), "modified"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "SELECT d.to_pk, d.to_c1, d.to_c2, d.to_c6, d.diff_type " +
|
||||
"FROM (SELECT * FROM dolt_diff('--skinny', '--include-cols=c1,c2', @C2, @C3, 't')) d",
|
||||
Expected: []sql.Row{
|
||||
{int64(0), int64(100), int64(2), int64(600), "modified"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "SELECT d.from_pk, d.from_c1, d.from_c2, d.from_c3, d.from_c4, d.from_c5, d.from_c6, d.diff_type " +
|
||||
"FROM (SELECT * FROM dolt_diff('--skinny', @C3, @C4, 't')) d",
|
||||
Expected: []sql.Row{
|
||||
{int64(1), int64(1), int64(200), int64(3), int64(4), int64(5), nil, "removed"},
|
||||
},
|
||||
},
|
||||
}
|
||||
asserts = append(asserts, assertDoltDiffColumnCount("v_all_01", "SELECT * FROM dolt_diff(@C0, @C1, 't')", 6)...)
|
||||
asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_01", "SELECT * FROM dolt_diff('--skinny', @C0, @C1, 't')", 6)...)
|
||||
asserts = append(asserts, assertDoltDiffColumnCount("v_all_12", "SELECT * FROM dolt_diff(@C1, @C2, 't')", 6)...)
|
||||
asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_12", "SELECT * FROM dolt_diff('--skinny', @C1, @C2, 't')", 4)...)
|
||||
asserts = append(asserts, assertDoltDiffColumnCount("v_all_23", "SELECT * FROM dolt_diff(@C2, @C3, 't')", 7)...)
|
||||
asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_23", "SELECT * FROM dolt_diff('--skinny', @C2, @C3, 't')", 2)...)
|
||||
asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_23", "SELECT * FROM dolt_diff(@C2, @C3, 't', '--skinny')", 2)...)
|
||||
asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_23", "SELECT * FROM dolt_diff('--skinny', '--include-cols=c1,c2', @C2, @C3, 't')", 4)...)
|
||||
asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_23", "SELECT * FROM dolt_diff('--skinny', '--include-cols=c1,c2,c6', @C2, @C3, 't')", 4)...)
|
||||
asserts = append(asserts, assertDoltDiffColumnCount("v_all_34", "SELECT * FROM dolt_diff(@C3, @C4, 't')", 7)...)
|
||||
asserts = append(asserts, assertDoltDiffColumnCount("v_skinny_34", "SELECT * FROM dolt_diff('--skinny', @C3, @C4, 't')", 7)...)
|
||||
|
||||
return asserts
|
||||
}(),
|
||||
},
|
||||
{
|
||||
Name: "invalid arguments",
|
||||
SetUpScript: []string{
|
||||
|
||||
@@ -129,7 +129,6 @@ var MergeScripts = []queries.ScriptTest{
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
// When there is a constraint violation for duplicate copies of a row in a keyless table, each row
|
||||
// will violate constraint in exactly the same way. Currently, the dolt_constraint_violations_<table>
|
||||
@@ -1391,6 +1390,90 @@ var MergeScripts = []queries.ScriptTest{
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
// Customer issue repro: when a merge halts from a conflict, if there was a table rename
|
||||
// it was preventing users from being able to abort the merge.
|
||||
Name: "Abort merge with table rename and conflict",
|
||||
SetUpScript: []string{
|
||||
"SET @@autocommit=0;",
|
||||
|
||||
// Create tables on main
|
||||
"CREATE TABLE conflict_table (pk int primary key, c1 varchar(100));",
|
||||
"CREATE TABLE table1 (pk int primary key, c1 varchar(100));",
|
||||
"CALL dolt_commit('-Am', 'creating tables on main');",
|
||||
"CALL dolt_branch('branch1');",
|
||||
|
||||
"INSERT INTO conflict_table VALUES (1, 'one');",
|
||||
"CALL dolt_commit('-Am', 'adding another table on main');",
|
||||
|
||||
// Rename a table on branch1 and create a conflict to halt the merge
|
||||
"CALL dolt_checkout('branch1');",
|
||||
"INSERT INTO conflict_table VALUES (1, 'uno');",
|
||||
"INSERT INTO table1 VALUES (1, 'one');",
|
||||
"RENAME TABLE table1 to table2;",
|
||||
"CALL dolt_commit('-Am', 'renaming table on branch1');",
|
||||
"CALL dolt_checkout('main');",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_merge('branch1');",
|
||||
Expected: []sql.Row{{"", 0, 1, "conflicts found"}},
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_merge('--abort');",
|
||||
Expected: []sql.Row{{"", 0, 0, "merge aborted"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
// dolt_ignore does not work properly in Doltgres yet, so marking the dialect
|
||||
// as mysql so that it skips this test.
|
||||
Dialect: "mysql",
|
||||
Name: "Abort merge with table rename and conflict (with ignored table)",
|
||||
SetUpScript: []string{
|
||||
"SET @@autocommit=0;",
|
||||
|
||||
// Set up an ignored table
|
||||
"INSERT INTO dolt_ignore VALUES ('ignore_me', true);",
|
||||
"CREATE TABLE ignore_me (pk int primary key, c1 varchar(100));",
|
||||
"INSERT INTO ignore_me VALUES (1, 'uno');",
|
||||
|
||||
// Create tables on main
|
||||
"CREATE TABLE conflict_table (pk int primary key, c1 varchar(100));",
|
||||
"CREATE TABLE table1 (pk int primary key, c1 varchar(100));",
|
||||
"CALL dolt_commit('-Am', 'creating tables on main');",
|
||||
"CALL dolt_branch('branch1');",
|
||||
|
||||
"INSERT INTO conflict_table VALUES (1, 'one');",
|
||||
"CALL dolt_commit('-Am', 'adding another table on main');",
|
||||
|
||||
// Rename a table on branch1 and create a conflict to halt the merge
|
||||
"CALL dolt_checkout('branch1');",
|
||||
"INSERT INTO conflict_table VALUES (1, 'uno');",
|
||||
"INSERT INTO table1 VALUES (1, 'one');",
|
||||
"RENAME TABLE table1 to table2;",
|
||||
"CALL dolt_commit('-Am', 'renaming table on branch1');",
|
||||
"CALL dolt_checkout('main');",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_merge('branch1');",
|
||||
Expected: []sql.Row{{"", 0, 1, "conflicts found"}},
|
||||
},
|
||||
{
|
||||
Query: "INSERT INTO ignore_me VALUES (2, 'duex');",
|
||||
Expected: []sql.Row{{types.NewOkResult(1)}},
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_merge('--abort');",
|
||||
Expected: []sql.Row{{"", 0, 0, "merge aborted"}},
|
||||
},
|
||||
{
|
||||
Query: "SELECT * FROM ignore_me;",
|
||||
Expected: []sql.Row{{1, "uno"}, {2, "duex"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "CALL DOLT_MERGE complains when a merge overrides local changes",
|
||||
SetUpScript: []string{
|
||||
|
||||
@@ -1060,7 +1060,9 @@ func TestDoltIndexBetween(t *testing.T) {
|
||||
exprs := idx.Expressions()
|
||||
sqlIndex := sql.NewMySQLIndexBuilder(idx)
|
||||
for i := range test.greaterThanOrEqual {
|
||||
sqlIndex = sqlIndex.GreaterOrEqual(ctx, exprs[i], test.greaterThanOrEqual[i]).LessOrEqual(ctx, exprs[i], test.lessThanOrEqual[i])
|
||||
sqlIndex = sqlIndex.
|
||||
GreaterOrEqual(ctx, exprs[i], nil, test.greaterThanOrEqual[i]).
|
||||
LessOrEqual(ctx, exprs[i], nil, test.lessThanOrEqual[i])
|
||||
}
|
||||
indexLookup, err := sqlIndex.Build(ctx)
|
||||
require.NoError(t, err)
|
||||
@@ -1298,17 +1300,17 @@ func testDoltIndex(t *testing.T, ctx *sql.Context, root doltdb.RootValue, keys [
|
||||
for i, key := range keys {
|
||||
switch cmp {
|
||||
case indexComp_Eq:
|
||||
builder = builder.Equals(ctx, exprs[i], key)
|
||||
builder = builder.Equals(ctx, exprs[i], nil, key)
|
||||
case indexComp_NEq:
|
||||
builder = builder.NotEquals(ctx, exprs[i], key)
|
||||
builder = builder.NotEquals(ctx, exprs[i], nil, key)
|
||||
case indexComp_Gt:
|
||||
builder = builder.GreaterThan(ctx, exprs[i], key)
|
||||
builder = builder.GreaterThan(ctx, exprs[i], nil, key)
|
||||
case indexComp_GtE:
|
||||
builder = builder.GreaterOrEqual(ctx, exprs[i], key)
|
||||
builder = builder.GreaterOrEqual(ctx, exprs[i], nil, key)
|
||||
case indexComp_Lt:
|
||||
builder = builder.LessThan(ctx, exprs[i], key)
|
||||
builder = builder.LessThan(ctx, exprs[i], nil, key)
|
||||
case indexComp_LtE:
|
||||
builder = builder.LessOrEqual(ctx, exprs[i], key)
|
||||
builder = builder.LessOrEqual(ctx, exprs[i], nil, key)
|
||||
default:
|
||||
panic("should not be hit")
|
||||
}
|
||||
|
||||
@@ -67,6 +67,7 @@ func TestCountAgg(t *testing.T) {
|
||||
name: "reject multi parameter",
|
||||
setup: []string{
|
||||
"create table xy (x int primary key, y int, key y_idx(y))",
|
||||
"SET SESSION sql_mode = REPLACE(@@SESSION.sql_mode, 'ONLY_FULL_GROUP_BY', '');",
|
||||
},
|
||||
query: "select count(y), x from xy",
|
||||
doRowexec: false,
|
||||
@@ -84,7 +85,7 @@ func TestCountAgg(t *testing.T) {
|
||||
setup: []string{
|
||||
"create table xy (x int primary key, y int, key y_idx(y))",
|
||||
},
|
||||
query: "select count(y+1), x from xy",
|
||||
query: "select count(y+1) from xy",
|
||||
doRowexec: false,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -337,7 +337,7 @@ func DoltProceduresGetAll(ctx *sql.Context, db Database, procedureName string) (
|
||||
if procedureName == "" {
|
||||
lookup, err = sql.NewMySQLIndexBuilder(idx).IsNotNull(ctx, nameExpr).Build(ctx)
|
||||
} else {
|
||||
lookup, err = sql.NewMySQLIndexBuilder(idx).Equals(ctx, nameExpr, procedureName).Build(ctx)
|
||||
lookup, err = sql.NewMySQLIndexBuilder(idx).Equals(ctx, nameExpr, gmstypes.Text, procedureName).Build(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -471,7 +471,9 @@ func DoltProceduresGetDetails(ctx *sql.Context, tbl *WritableDoltTable, name str
|
||||
return sql.StoredProcedureDetails{}, false, fmt.Errorf("could not find primary key index on system table `%s`", doltdb.ProceduresTableName)
|
||||
}
|
||||
|
||||
indexLookup, err := sql.NewMySQLIndexBuilder(fragNameIndex).Equals(ctx, fragNameIndex.Expressions()[0], name).Build(ctx)
|
||||
indexLookup, err := sql.NewMySQLIndexBuilder(fragNameIndex).
|
||||
Equals(ctx, fragNameIndex.Expressions()[0], gmstypes.Text, name).
|
||||
Build(ctx)
|
||||
if err != nil {
|
||||
return sql.StoredProcedureDetails{}, false, err
|
||||
}
|
||||
|
||||
@@ -541,7 +541,7 @@ func interfaceValueAsSqlString(ctx *sql.Context, ti typeinfo.TypeInfo, value int
|
||||
return singleQuote + str + singleQuote, nil
|
||||
case typeinfo.DatetimeTypeIdentifier:
|
||||
return singleQuote + str + singleQuote, nil
|
||||
case typeinfo.InlineBlobTypeIdentifier, typeinfo.VarBinaryTypeIdentifier:
|
||||
case typeinfo.InlineBlobTypeIdentifier, typeinfo.VarBinaryTypeIdentifier, typeinfo.VectorTypeIdentifier:
|
||||
value, err := sql.UnwrapAny(ctx, value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
@@ -118,6 +118,8 @@ func GetSubtrees(msg serial.Message) ([]uint64, error) {
|
||||
switch id {
|
||||
case serial.ProllyTreeNodeFileID:
|
||||
return getProllyMapSubtrees(msg)
|
||||
case serial.VectorIndexNodeFileID:
|
||||
return getVectorIndexSubtrees(msg)
|
||||
case serial.AddressMapFileID:
|
||||
return getAddressMapSubtrees(msg)
|
||||
case serial.MergeArtifactsFileID:
|
||||
|
||||
+101
-101
@@ -16,6 +16,7 @@ package prolly
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
|
||||
@@ -138,33 +139,50 @@ func (p *proximityMapIter) Next(ctx context.Context) (k val.Tuple, v val.Tuple,
|
||||
return
|
||||
}
|
||||
|
||||
func getConvertToVectorFunction(keyDesc val.TupleDesc, ns tree.NodeStore) (tree.ConvertToVectorFunction, error) {
|
||||
switch keyDesc.Types[0].Enc {
|
||||
case val.JSONAddrEnc:
|
||||
return func(ctx context.Context, bytes []byte) ([]float32, error) {
|
||||
h, _ := keyDesc.GetJSONAddr(0, bytes)
|
||||
doc := tree.NewJSONDoc(h, ns)
|
||||
jsonWrapper, err := doc.ToIndexedJSONDocument(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sql.ConvertToVector(ctx, jsonWrapper)
|
||||
}, nil
|
||||
case val.BytesAdaptiveEnc:
|
||||
return func(ctx context.Context, bytes []byte) ([]float32, error) {
|
||||
vec, _, err := keyDesc.GetBytesAdaptiveValue(ctx, 0, ns, bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sql.ConvertToVector(ctx, vec)
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected encoding for vector index: %v", keyDesc.Types[0].Enc)
|
||||
}
|
||||
}
|
||||
|
||||
// NewProximityMap creates a new ProximityMap from a supplied root node.
|
||||
func NewProximityMap(ns tree.NodeStore, node tree.Node, keyDesc val.TupleDesc, valDesc val.TupleDesc, distanceType vector.DistanceType, logChunkSize uint8) ProximityMap {
|
||||
func NewProximityMap(ns tree.NodeStore, node tree.Node, keyDesc val.TupleDesc, valDesc val.TupleDesc, distanceType vector.DistanceType, logChunkSize uint8) (ProximityMap, error) {
|
||||
convertFunc, err := getConvertToVectorFunction(keyDesc, ns)
|
||||
if err != nil {
|
||||
return ProximityMap{}, err
|
||||
}
|
||||
tuples := tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]{
|
||||
Root: node,
|
||||
NodeStore: ns,
|
||||
Order: keyDesc,
|
||||
DistanceType: distanceType,
|
||||
Convert: func(ctx context.Context, bytes []byte) []float64 {
|
||||
h, _ := keyDesc.GetJSONAddr(0, bytes)
|
||||
doc := tree.NewJSONDoc(h, ns)
|
||||
jsonWrapper, err := doc.ToIndexedJSONDocument(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
floats, err := sql.ConvertToVector(ctx, jsonWrapper)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return floats
|
||||
},
|
||||
Convert: convertFunc,
|
||||
}
|
||||
return ProximityMap{
|
||||
tuples: tuples,
|
||||
keyDesc: keyDesc,
|
||||
valDesc: valDesc,
|
||||
logChunkSize: logChunkSize,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
var proximitylevelMapKeyDesc = val.NewTupleDescriptor(
|
||||
@@ -180,6 +198,10 @@ func NewProximityMapBuilder(ctx context.Context, ns tree.NodeStore, distanceType
|
||||
return ProximityMapBuilder{}, err
|
||||
}
|
||||
mutableLevelMap := newMutableMap(emptyLevelMap)
|
||||
convertFunc, err := getConvertToVectorFunction(keyDesc, ns)
|
||||
if err != nil {
|
||||
return ProximityMapBuilder{}, err
|
||||
}
|
||||
return ProximityMapBuilder{
|
||||
ns: ns,
|
||||
vectorIndexSerializer: message.NewVectorIndexSerializer(ns.Pool(), logChunkSize, distanceType),
|
||||
@@ -189,6 +211,7 @@ func NewProximityMapBuilder(ctx context.Context, ns tree.NodeStore, distanceType
|
||||
logChunkSize: logChunkSize,
|
||||
maxLevel: 0,
|
||||
levelMap: mutableLevelMap,
|
||||
convertFunc: convertFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -211,7 +234,7 @@ func NewProximityMapBuilder(ctx context.Context, ns tree.NodeStore, distanceType
|
||||
//
|
||||
// Step 2: Create `pathMaps`, a list of maps, each corresponding to a different level of the ProximityMap
|
||||
//
|
||||
// The pathMap at depth `i` has the schema (vectorAddrs[1]...vectorAddr[i], keyBytes) -> value
|
||||
// The pathMap at depth `i` has the schema (vectorAddrs[0], ..., vectorAddr[i], keyBytes) -> value
|
||||
// and contains a row for every vector whose maximum depth is i.
|
||||
// - vectorAddrs: the path of vectors visited when walking from the root to the maximum depth where the vector appears.
|
||||
// - keyBytes: a bytestring containing the bytes of the ProximityMap key (which includes the vector)
|
||||
@@ -232,11 +255,12 @@ type ProximityMapBuilder struct {
|
||||
vectorIndexSerializer message.VectorIndexSerializer
|
||||
ns tree.NodeStore
|
||||
distanceType vector.DistanceType
|
||||
levelMap *MutableMap
|
||||
keyDesc val.TupleDesc
|
||||
valDesc val.TupleDesc
|
||||
logChunkSize uint8
|
||||
maxLevel uint8
|
||||
levelMap *MutableMap
|
||||
convertFunc tree.ConvertToVectorFunction
|
||||
}
|
||||
|
||||
// Insert adds a new key-value pair to the ProximityMap under construction.
|
||||
@@ -299,7 +323,7 @@ func (b *ProximityMapBuilder) makeRootNode(ctx context.Context, keys, values [][
|
||||
return ProximityMap{}, err
|
||||
}
|
||||
|
||||
return NewProximityMap(b.ns, rootNode, b.keyDesc, b.valDesc, b.distanceType, b.logChunkSize), nil
|
||||
return NewProximityMap(b.ns, rootNode, b.keyDesc, b.valDesc, b.distanceType, b.logChunkSize)
|
||||
}
|
||||
|
||||
// Flush finishes constructing a ProximityMap. Call this after all calls to Insert.
|
||||
@@ -371,7 +395,7 @@ func (b *ProximityMapBuilder) makePathMaps(ctx context.Context, mutableLevelMap
|
||||
|
||||
// Create every val.TupleBuilder and MutableMap that we will need
|
||||
// pathMaps[i] is the pathMap for level i (and depth maxLevel - i)
|
||||
pathMaps, keyTupleBuilder, prefixTupleBuilder, err := b.createInitialPathMaps(ctx, maxLevel)
|
||||
pathMaps, keyTupleBuilder, err := b.createInitialPathMaps(ctx, maxLevel)
|
||||
|
||||
// Next, visit each key-value pair in decreasing order of level / increasing order of depth.
|
||||
// When visiting a pair from depth `i`, we use each of the previous `i` pathMaps to compute a path of `i` index keys.
|
||||
@@ -382,55 +406,53 @@ func (b *ProximityMapBuilder) makePathMaps(ctx context.Context, mutableLevelMap
|
||||
depth := int(maxLevel - level)
|
||||
|
||||
// hashPath is a list of concatenated hashes, representing the sequence of closest vectors at each level of the tree.
|
||||
var hashPath []byte
|
||||
keyToInsert, _ := mutableLevelMap.keyDesc.GetBytes(1, levelMapKey)
|
||||
vectorHashToInsert, _ := b.keyDesc.GetJSONAddr(0, keyToInsert)
|
||||
vectorToInsert, err := getVectorFromHash(ctx, b.ns, vectorHashToInsert)
|
||||
vectorToInsert, err := b.convertFunc(ctx, keyToInsert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Compute the path that this row will have in the vector index, starting at the root.
|
||||
// A key-value pair at depth D will have a path D prior keys.
|
||||
// This path is computed in steps, by performing a lookup in each of the prior pathMaps.
|
||||
// Each iteration sets another column in |keyTupleBuilder|, then does a prefix lookup in the next pathMap
|
||||
// with all currently set columns.
|
||||
for pathDepth := 0; pathDepth < depth; pathDepth++ {
|
||||
lookupLevel := int(maxLevel) - pathDepth
|
||||
pathMap := pathMaps[lookupLevel]
|
||||
|
||||
pathMapIter, err := b.getNextPathSegmentCandidates(ctx, pathMap, prefixTupleBuilder, hashPath)
|
||||
pathMapIter, err := b.getNextPathSegmentCandidates(ctx, pathMap, keyTupleBuilder.Desc.PrefixDesc(pathDepth), keyTupleBuilder.BuildPrefixNoRecycle(b.ns.Pool(), pathDepth))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create an iterator that yields every candidate vector
|
||||
nextCandidate, stopIter := iter.Pull2(func(yield func(hash.Hash, error) bool) {
|
||||
nextCandidate, stopIter := iter.Pull2(func(yield func([]byte, error) bool) {
|
||||
for {
|
||||
pathMapKey, _, err := pathMapIter.Next(ctx)
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
yield(hash.Hash{}, err)
|
||||
yield(nil, err)
|
||||
}
|
||||
originalKey, _ := pathMap.keyDesc.GetBytes(1, pathMapKey)
|
||||
candidateVectorHash, _ := b.keyDesc.GetJSONAddr(0, originalKey)
|
||||
yield(candidateVectorHash, nil)
|
||||
originalKey, _ := pathMap.keyDesc.GetBytes(pathDepth, pathMapKey)
|
||||
yield(originalKey, nil)
|
||||
}
|
||||
})
|
||||
defer stopIter()
|
||||
|
||||
closestVectorHash, _ := b.getClosestVector(ctx, vectorToInsert, nextCandidate)
|
||||
closestVectorEncoding, _, err := b.getClosestVector(ctx, vectorToInsert, nextCandidate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hashPath = append(hashPath, closestVectorHash[:]...)
|
||||
keyTupleBuilder.PutByteString(pathDepth, closestVectorEncoding)
|
||||
}
|
||||
|
||||
// Once we have the path for this key, we turn it into a tuple and add it to the next pathMap.
|
||||
keyTupleBuilder.PutByteString(0, hashPath)
|
||||
keyTupleBuilder.PutByteString(1, keyToInsert)
|
||||
keyTupleBuilder.PutByteString(depth, keyToInsert)
|
||||
|
||||
keyTuple, err := keyTupleBuilder.Build(b.ns.Pool())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyTuple := keyTupleBuilder.BuildPrefixNoRecycle(b.ns.Pool(), depth+1)
|
||||
err = pathMaps[level].Put(ctx, keyTuple, levelMapValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -441,14 +463,9 @@ func (b *ProximityMapBuilder) makePathMaps(ctx context.Context, mutableLevelMap
|
||||
childLevel := level - 1
|
||||
if level > 0 {
|
||||
for {
|
||||
hashPath = append(hashPath, vectorHashToInsert[:]...)
|
||||
keyTupleBuilder.PutByteString(0, hashPath)
|
||||
keyTupleBuilder.PutByteString(1, keyToInsert)
|
||||
|
||||
childKeyTuple, err := keyTupleBuilder.Build(b.ns.Pool())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
depth++
|
||||
keyTupleBuilder.PutByteString(depth, keyToInsert)
|
||||
childKeyTuple := keyTupleBuilder.BuildPrefixNoRecycle(b.ns.Pool(), depth+1)
|
||||
err = pathMaps[childLevel].Put(ctx, childKeyTuple, levelMapValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -472,79 +489,75 @@ func (b *ProximityMapBuilder) makePathMaps(ctx context.Context, mutableLevelMap
|
||||
}
|
||||
|
||||
// createInitialPathMaps creates a list of MutableMaps that will eventually store a single level of the to-be-built ProximityMap
|
||||
func (b *ProximityMapBuilder) createInitialPathMaps(ctx context.Context, maxLevel uint8) (pathMaps []*MutableMap, keyTupleBuilder, prefixTupleBuilder *val.TupleBuilder, err error) {
|
||||
func (b *ProximityMapBuilder) createInitialPathMaps(ctx context.Context, maxLevel uint8) (pathMaps []*MutableMap, keyTupleBuilder *val.TupleBuilder, err error) {
|
||||
pathMaps = make([]*MutableMap, maxLevel+1)
|
||||
|
||||
pathMapKeyDescTypes := []val.Type{{Enc: val.ByteStringEnc, Nullable: false}, {Enc: val.ByteStringEnc, Nullable: false}}
|
||||
|
||||
pathMapKeyDesc := val.NewTupleDescriptor(pathMapKeyDescTypes...)
|
||||
|
||||
emptyPathMap, err := NewMapFromTuples(ctx, b.ns, pathMapKeyDesc, b.valDesc)
|
||||
|
||||
keyTupleBuilder = val.NewTupleBuilder(pathMapKeyDesc, b.ns)
|
||||
prefixTupleBuilder = val.NewTupleBuilder(val.NewTupleDescriptor(pathMapKeyDescTypes[0]), b.ns)
|
||||
|
||||
for i := uint8(0); i <= maxLevel; i++ {
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
pathMaps[i] = newMutableMap(emptyPathMap)
|
||||
pathMapKeyDescTypes := make([]val.Type, maxLevel+1)
|
||||
for i := range pathMapKeyDescTypes {
|
||||
pathMapKeyDescTypes[i] = val.Type{Enc: val.ByteStringEnc, Nullable: false}
|
||||
}
|
||||
|
||||
return pathMaps, keyTupleBuilder, prefixTupleBuilder, nil
|
||||
for level := uint8(0); level <= maxLevel; level++ {
|
||||
depth := maxLevel - level
|
||||
pathMapKeyDesc := val.NewTupleDescriptor(pathMapKeyDescTypes[:depth+1]...)
|
||||
|
||||
emptyPathMap, err := NewMapFromTuples(ctx, b.ns, pathMapKeyDesc, b.valDesc)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pathMaps[level] = newMutableMap(emptyPathMap)
|
||||
}
|
||||
|
||||
keyTupleBuilder = val.NewTupleBuilder(val.NewTupleDescriptor(pathMapKeyDescTypes...), b.ns)
|
||||
|
||||
return pathMaps, keyTupleBuilder, nil
|
||||
}
|
||||
|
||||
// getNextPathSegmentCandidates takes a list of keys, representing a path into the ProximityMap from the root.
|
||||
// It returns an iter over all possible keys that could be the next path segment.
|
||||
func (b *ProximityMapBuilder) getNextPathSegmentCandidates(ctx context.Context, pathMap *MutableMap, prefixTupleBuilder *val.TupleBuilder, currentPath []byte) (MapIter, error) {
|
||||
prefixTupleBuilder.PutByteString(0, currentPath)
|
||||
prefixTuple, err := prefixTupleBuilder.Build(b.ns.Pool())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
prefixRange := PrefixRange(ctx, prefixTuple, prefixTupleBuilder.Desc)
|
||||
func (b *ProximityMapBuilder) getNextPathSegmentCandidates(ctx context.Context, pathMap *MutableMap, prefixTupleDesc val.TupleDesc, prefixTuple val.Tuple) (MapIter, error) {
|
||||
prefixRange := PrefixRange(ctx, prefixTuple, prefixTupleDesc)
|
||||
return pathMap.IterRange(ctx, prefixRange)
|
||||
}
|
||||
|
||||
// getClosestVector iterates over a range of candidate vectors to determine which one is the closest to the target.
|
||||
func (b *ProximityMapBuilder) getClosestVector(ctx context.Context, targetVector []float64, nextCandidate func() (candidate hash.Hash, err error, valid bool)) (hash.Hash, error) {
|
||||
func (b *ProximityMapBuilder) getClosestVector(ctx context.Context, targetVector []float32, nextCandidate func() (candidate []byte, err error, valid bool)) (closestVectorEncoding []byte, closestVector []float32, err error) {
|
||||
// First call to nextCandidate is guaranteed to be valid because there's at least one vector in the set.
|
||||
// (non-root nodes inherit the first vector from their parent)
|
||||
candidateVectorHash, err, _ := nextCandidate()
|
||||
closestVectorEncoding, err, _ = nextCandidate()
|
||||
if err != nil {
|
||||
return hash.Hash{}, err
|
||||
return nil, nil, err
|
||||
}
|
||||
closestVector, err = b.convertFunc(ctx, closestVectorEncoding)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
candidateVector, err := getVectorFromHash(ctx, b.ns, candidateVectorHash)
|
||||
closestDistance, err := b.distanceType.Eval(targetVector, closestVector)
|
||||
if err != nil {
|
||||
return hash.Hash{}, err
|
||||
}
|
||||
closestVectorHash := candidateVectorHash
|
||||
closestDistance, err := b.distanceType.Eval(targetVector, candidateVector)
|
||||
if err != nil {
|
||||
return hash.Hash{}, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for {
|
||||
candidateVectorHash, err, valid := nextCandidate()
|
||||
candidateVectorEncoding, err, valid := nextCandidate()
|
||||
if err != nil {
|
||||
return hash.Hash{}, err
|
||||
return nil, nil, err
|
||||
}
|
||||
if !valid {
|
||||
return closestVectorHash, nil
|
||||
return closestVectorEncoding, closestVector, nil
|
||||
}
|
||||
candidateVector, err = getVectorFromHash(ctx, b.ns, candidateVectorHash)
|
||||
candidateVector, err := b.convertFunc(ctx, candidateVectorEncoding)
|
||||
if err != nil {
|
||||
return hash.Hash{}, err
|
||||
return nil, nil, err
|
||||
}
|
||||
candidateDistance, err := b.distanceType.Eval(targetVector, candidateVector)
|
||||
if err != nil {
|
||||
return hash.Hash{}, err
|
||||
return nil, nil, err
|
||||
}
|
||||
if candidateDistance < closestDistance {
|
||||
closestVectorHash = candidateVectorHash
|
||||
closestVector = candidateVector
|
||||
closestVectorEncoding = candidateVectorEncoding
|
||||
closestDistance = candidateDistance
|
||||
}
|
||||
}
|
||||
@@ -580,9 +593,8 @@ func (b *ProximityMapBuilder) makeProximityMapFromPathMaps(ctx context.Context,
|
||||
if err != nil {
|
||||
return ProximityMap{}, err
|
||||
}
|
||||
originalKey, _ := rootPathMap.keyDesc.GetBytes(1, key)
|
||||
path, _ := b.keyDesc.GetJSONAddr(0, originalKey)
|
||||
_, nodeCount, nodeHash, err := chunker.Next(ctx, b.ns, b.vectorIndexSerializer, path, maxLevel-1, 1, b.keyDesc)
|
||||
originalKey, _ := rootPathMap.keyDesc.GetBytes(0, key)
|
||||
_, nodeCount, nodeHash, err := chunker.Next(ctx, b.ns, b.vectorIndexSerializer, originalKey, maxLevel-1, 1, b.keyDesc)
|
||||
if err != nil {
|
||||
return ProximityMap{}, err
|
||||
}
|
||||
@@ -592,15 +604,3 @@ func (b *ProximityMapBuilder) makeProximityMapFromPathMaps(ctx context.Context,
|
||||
}
|
||||
return b.makeRootNode(ctx, topLevelKeys, topLevelValues, topLevelSubtrees, maxLevel)
|
||||
}
|
||||
|
||||
func getJsonValueFromHash(ctx context.Context, ns tree.NodeStore, h hash.Hash) (sql.JSONWrapper, error) {
|
||||
return tree.NewJSONDoc(h, ns).ToIndexedJSONDocument(ctx)
|
||||
}
|
||||
|
||||
func getVectorFromHash(ctx context.Context, ns tree.NodeStore, h hash.Hash) ([]float64, error) {
|
||||
otherValue, err := getJsonValueFromHash(ctx, ns, h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sql.ConvertToVector(ctx, otherValue)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -18,7 +18,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/dolthub/go-mysql-server/sql/expression/function/vector"
|
||||
|
||||
"github.com/dolthub/dolt/go/gen/fb/serial"
|
||||
@@ -56,18 +55,9 @@ func (f ProximityFlusher) ApplyMutationsWithSerializer(
|
||||
keyDesc := mutableMap.keyDesc
|
||||
valDesc := mutableMap.valDesc
|
||||
ns := mutableMap.NodeStore()
|
||||
convert := func(ctx context.Context, bytes []byte) []float64 {
|
||||
h, _ := keyDesc.GetJSONAddr(0, bytes)
|
||||
doc := tree.NewJSONDoc(h, ns)
|
||||
jsonWrapper, err := doc.ToIndexedJSONDocument(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
floats, err := sql.ConvertToVector(ctx, jsonWrapper)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return floats
|
||||
convertFunc, err := getConvertToVectorFunction(keyDesc, ns)
|
||||
if err != nil {
|
||||
return tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]{}, err
|
||||
}
|
||||
edits := make([]VectorIndexKV, 0, mutableMap.tuples.Edits.Count())
|
||||
editIter := mutableMap.tuples.Mutations()
|
||||
@@ -86,7 +76,6 @@ func (f ProximityFlusher) ApplyMutationsWithSerializer(
|
||||
mutation = editIter.NextMutation(ctx)
|
||||
}
|
||||
var newRoot tree.Node
|
||||
var err error
|
||||
root := mutableMap.tuples.Static.Root
|
||||
distanceType := mutableMap.tuples.Static.DistanceType
|
||||
if root.Count() == 0 {
|
||||
@@ -96,7 +85,11 @@ func (f ProximityFlusher) ApplyMutationsWithSerializer(
|
||||
// The root node has changed, or there may be a new level to the tree. We need to rebuild the tree.
|
||||
newRoot, _, err = f.rebuildNode(ctx, ns, root, edits, distanceType, keyDesc, valDesc, maxEditLevel)
|
||||
} else {
|
||||
newRoot, _, err = f.visitNode(ctx, serializer, ns, root, edits, convert, distanceType, keyDesc, valDesc)
|
||||
root, err = root.LoadSubtrees()
|
||||
if err != nil {
|
||||
return tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]{}, err
|
||||
}
|
||||
newRoot, _, err = f.visitNode(ctx, serializer, ns, root, edits, convertFunc, distanceType, keyDesc, valDesc)
|
||||
|
||||
}
|
||||
if err != nil {
|
||||
@@ -106,7 +99,7 @@ func (f ProximityFlusher) ApplyMutationsWithSerializer(
|
||||
Root: newRoot,
|
||||
NodeStore: ns,
|
||||
DistanceType: distanceType,
|
||||
Convert: convert,
|
||||
Convert: convertFunc,
|
||||
Order: keyDesc,
|
||||
}, nil
|
||||
}
|
||||
@@ -161,7 +154,7 @@ func (f ProximityFlusher) visitNode(
|
||||
ns tree.NodeStore,
|
||||
node tree.Node,
|
||||
edits []VectorIndexKV,
|
||||
convert func(context.Context, []byte) []float64,
|
||||
convert tree.ConvertToVectorFunction,
|
||||
distanceType vector.DistanceType,
|
||||
keyDesc val.TupleDesc,
|
||||
valDesc val.TupleDesc,
|
||||
@@ -177,18 +170,29 @@ func (f ProximityFlusher) visitNode(
|
||||
childEdits := make(map[int]childEditList)
|
||||
for _, edit := range edits {
|
||||
key := edit.key
|
||||
editVector := convert(ctx, key)
|
||||
editVector, err := convert(ctx, key)
|
||||
if err != nil {
|
||||
return tree.Node{}, 0, err
|
||||
}
|
||||
level := edit.level
|
||||
// visit each child in the node to determine which is closest
|
||||
closestIdx := 0
|
||||
childKey := node.GetKey(0)
|
||||
closestDistance, err := distanceType.Eval(convert(ctx, childKey), editVector)
|
||||
childVector, err := convert(ctx, childKey)
|
||||
if err != nil {
|
||||
return tree.Node{}, 0, err
|
||||
}
|
||||
closestDistance, err := distanceType.Eval(childVector, editVector)
|
||||
if err != nil {
|
||||
return tree.Node{}, 0, err
|
||||
}
|
||||
for i := 1; i < node.Count(); i++ {
|
||||
childKey = node.GetKey(i)
|
||||
newDistance, err := distanceType.Eval(convert(ctx, childKey), editVector)
|
||||
childVector, err = convert(ctx, childKey)
|
||||
if err != nil {
|
||||
return tree.Node{}, 0, err
|
||||
}
|
||||
newDistance, err := distanceType.Eval(childVector, editVector)
|
||||
if err != nil {
|
||||
return tree.Node{}, 0, err
|
||||
}
|
||||
@@ -215,6 +219,8 @@ func (f ProximityFlusher) visitNode(
|
||||
if len(childEditList.edits) == 0 {
|
||||
// No edits affected this node, leave it as is.
|
||||
values = append(values, childValue)
|
||||
childSubtrees := node.GetSubtreeCount(i)
|
||||
nodeSubtrees = append(nodeSubtrees, uint64(childSubtrees))
|
||||
} else {
|
||||
childNodeAddress := hash.New(childValue)
|
||||
childNode, err := ns.Read(ctx, childNodeAddress)
|
||||
@@ -226,6 +232,10 @@ func (f ProximityFlusher) visitNode(
|
||||
if childEditList.mustRebuild {
|
||||
newChildNode, childSubtrees, err = f.rebuildNode(ctx, ns, childNode, childEditList.edits, distanceType, keyDesc, valDesc, uint8(childNode.Level()))
|
||||
} else {
|
||||
childNode, err = childNode.LoadSubtrees()
|
||||
if err != nil {
|
||||
return tree.Node{}, 0, err
|
||||
}
|
||||
newChildNode, childSubtrees, err = f.visitNode(ctx, serializer, ns, childNode, childEditList.edits, convert, distanceType, keyDesc, valDesc)
|
||||
}
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ func MapInterfaceFromValue(ctx context.Context, v types.Value, sch schema.Schema
|
||||
// TODO: We should read the distance function and chunk size from the message.
|
||||
// Currently, vector.DistanceL2Squared{} and prolly.DefaultLogChunkSize are the only values that can be written,
|
||||
// but this may not be true in the future.
|
||||
return prolly.NewProximityMap(ns, root, kd, vd, vector.DistanceL2Squared{}, prolly.DefaultLogChunkSize), nil
|
||||
return prolly.NewProximityMap(ns, root, kd, vd, vector.DistanceL2Squared{}, prolly.DefaultLogChunkSize)
|
||||
default:
|
||||
return prolly.NewMap(root, ns, kd, vd), nil
|
||||
}
|
||||
@@ -83,7 +83,7 @@ func MapFromValueWithDescriptors(v types.Value, kd, vd val.TupleDesc, ns tree.No
|
||||
// TODO: We should read the distance function and chunk size from the message.
|
||||
// Currently, vector.DistanceL2Squared{} and prolly.DefaultLogChunkSize are the only values that can be written,
|
||||
// but this may not be true in the future.
|
||||
return prolly.NewProximityMap(ns, root, kd, vd, vector.DistanceL2Squared{}, prolly.DefaultLogChunkSize), nil
|
||||
return prolly.NewProximityMap(ns, root, kd, vd, vector.DistanceL2Squared{}, prolly.DefaultLogChunkSize)
|
||||
default:
|
||||
return prolly.NewMap(root, ns, kd, vd), nil
|
||||
}
|
||||
|
||||
@@ -188,11 +188,10 @@ func TestWriteImmutableTree(t *testing.T) {
|
||||
assert.Equal(t, tt.inputSize, byteCnt)
|
||||
assert.Equal(t, expUnfilled, unfilledCnt)
|
||||
if expLevel > 0 {
|
||||
root, err = root.loadSubtrees()
|
||||
root, err = root.LoadSubtrees()
|
||||
require.NoError(t, err)
|
||||
for i := range expSubtrees {
|
||||
sc, err := root.getSubtreeCount(i)
|
||||
require.NoError(t, err)
|
||||
sc := root.GetSubtreeCount(i)
|
||||
assert.Equal(t, expSubtrees[i], sc)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -290,15 +290,12 @@ func insertNode[K ~[]byte, S message.Serializer, O Ordering[K]](ctx context.Cont
|
||||
}
|
||||
}
|
||||
} else {
|
||||
nd, err = nd.loadSubtrees()
|
||||
nd, err = nd.LoadSubtrees()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i := 0; i < nd.Count(); i++ {
|
||||
subtreeCount, err := nd.getSubtreeCount(i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
subtreeCount := nd.GetSubtreeCount(i)
|
||||
err = insertNode[K, S, O](ctx, tc, nil, K(nd.GetKey(i)), nd.getAddress(i), subtreeCount, level-1, order)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -160,7 +160,7 @@ func (nd Node) GetValue(i int) Item {
|
||||
return nd.values.GetItem(i, nd.msg)
|
||||
}
|
||||
|
||||
func (nd Node) loadSubtrees() (Node, error) {
|
||||
func (nd Node) LoadSubtrees() (Node, error) {
|
||||
var err error
|
||||
if nd.subtrees == nil {
|
||||
// deserializing subtree counts requires a malloc,
|
||||
@@ -174,12 +174,12 @@ func (nd Node) loadSubtrees() (Node, error) {
|
||||
return nd, err
|
||||
}
|
||||
|
||||
func (nd Node) getSubtreeCount(i int) (uint64, error) {
|
||||
func (nd Node) GetSubtreeCount(i int) uint64 {
|
||||
if nd.IsLeaf() {
|
||||
return 1, nil
|
||||
return 1
|
||||
}
|
||||
// this will panic unless subtrees were loaded.
|
||||
return (*nd.subtrees)[i], nil
|
||||
return (*nd.subtrees)[i]
|
||||
}
|
||||
|
||||
// getAddress returns the |ith| address of this node.
|
||||
|
||||
@@ -111,10 +111,10 @@ func newCursorAtOrdinal(ctx context.Context, ns NodeStore, nd Node, ord uint64)
|
||||
if nd.IsLeaf() {
|
||||
return int(distance)
|
||||
}
|
||||
nd, _ = nd.loadSubtrees()
|
||||
nd, _ = nd.LoadSubtrees()
|
||||
|
||||
for idx = 0; idx < nd.Count(); idx++ {
|
||||
cnt, _ := nd.getSubtreeCount(idx)
|
||||
cnt := nd.GetSubtreeCount(idx)
|
||||
card := int64(cnt)
|
||||
if (distance - card) < 0 {
|
||||
break
|
||||
@@ -144,16 +144,13 @@ func getOrdinalOfCursor(curr *cursor) (ord uint64, err error) {
|
||||
return 0, fmt.Errorf("found invalid parent cursor behind node start")
|
||||
}
|
||||
|
||||
curr.nd, err = curr.nd.loadSubtrees()
|
||||
curr.nd, err = curr.nd.LoadSubtrees()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for idx := curr.idx - 1; idx >= 0; idx-- {
|
||||
cnt, err := curr.nd.getSubtreeCount(idx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
cnt := curr.nd.GetSubtreeCount(idx)
|
||||
ord += cnt
|
||||
}
|
||||
}
|
||||
@@ -289,15 +286,12 @@ func recursiveFetchLeafNodeSpan(ctx context.Context, ns NodeStore, nodes []Node,
|
||||
|
||||
var err error
|
||||
for _, nd := range nodes {
|
||||
if nd, err = nd.loadSubtrees(); err != nil {
|
||||
if nd, err = nd.LoadSubtrees(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
for i := 0; i < nd.Count(); i++ {
|
||||
card, err := nd.getSubtreeCount(i)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
card := nd.GetSubtreeCount(i)
|
||||
|
||||
if acc == 0 && card < start {
|
||||
start -= card
|
||||
@@ -388,11 +382,11 @@ func (cur *cursor) currentSubtreeSize() (uint64, error) {
|
||||
return 1, nil
|
||||
}
|
||||
var err error
|
||||
cur.nd, err = cur.nd.loadSubtrees()
|
||||
cur.nd, err = cur.nd.LoadSubtrees()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return cur.nd.getSubtreeCount(cur.idx)
|
||||
return cur.nd.GetSubtreeCount(cur.idx), nil
|
||||
}
|
||||
|
||||
func (cur *cursor) firstKey() Item {
|
||||
|
||||
@@ -418,7 +418,7 @@ func (td *PatchGenerator[K, O]) split(ctx context.Context) (patch Patch, diffTyp
|
||||
if err != nil {
|
||||
return Patch{}, NoDiff, err
|
||||
}
|
||||
toChild, err = toChild.loadSubtrees()
|
||||
toChild, err = toChild.LoadSubtrees()
|
||||
if err != nil {
|
||||
return Patch{}, NoDiff, err
|
||||
}
|
||||
|
||||
@@ -137,7 +137,7 @@ func GetField(ctx context.Context, td val.TupleDesc, i int, tup val.Tuple, ns No
|
||||
v = val.NewTextStorage(ctx, h, ns)
|
||||
}
|
||||
case val.BytesAdaptiveEnc:
|
||||
v, ok, err = td.GetBytesAdaptiveValue(i, ns, tup)
|
||||
v, ok, err = td.GetBytesAdaptiveValue(ctx, i, ns, tup)
|
||||
case val.StringAdaptiveEnc:
|
||||
v, ok, err = td.GetStringAdaptiveValue(i, ns, tup)
|
||||
case val.CommitAddrEnc:
|
||||
|
||||
@@ -29,13 +29,15 @@ import (
|
||||
|
||||
type KeyValueDistanceFn[K, V ~[]byte] func(key K, value V, distance float64) error
|
||||
|
||||
type ConvertToVectorFunction func(context.Context, []byte) ([]float32, error)
|
||||
|
||||
// ProximityMap is a static Prolly Tree where the position of a key in the tree is based on proximity, as opposed to a traditional ordering.
|
||||
// O provides the ordering only within a node.
|
||||
type ProximityMap[K, V ~[]byte, O Ordering[K]] struct {
|
||||
NodeStore NodeStore
|
||||
DistanceType vector.DistanceType
|
||||
Order O
|
||||
Convert func(context.Context, []byte) []float64
|
||||
Convert ConvertToVectorFunction
|
||||
Root Node
|
||||
}
|
||||
|
||||
@@ -94,7 +96,10 @@ func (t ProximityMap[K, V, O]) WalkNodes(ctx context.Context, cb NodeCb) error {
|
||||
func (t ProximityMap[K, V, O]) Get(ctx context.Context, query K, cb KeyValueFn[K, V]) (err error) {
|
||||
nd := t.Root
|
||||
|
||||
queryVector := t.Convert(ctx, query)
|
||||
queryVector, err := t.Convert(ctx, query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Find the child with the minimum distance.
|
||||
|
||||
@@ -105,7 +110,11 @@ func (t ProximityMap[K, V, O]) Get(ctx context.Context, query K, cb KeyValueFn[K
|
||||
|
||||
for i := 0; i < int(nd.count); i++ {
|
||||
k := nd.GetKey(i)
|
||||
newDistance, err := t.DistanceType.Eval(t.Convert(ctx, k), queryVector)
|
||||
vec, err := t.Convert(ctx, k)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newDistance, err := t.DistanceType.Eval(vec, queryVector)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -201,7 +210,11 @@ func (t ProximityMap[K, V, O]) GetClosest(ctx context.Context, query interface{}
|
||||
|
||||
for i := 0; i < int(t.Root.count); i++ {
|
||||
k := t.Root.GetKey(i)
|
||||
newDistance, err := t.DistanceType.Eval(t.Convert(ctx, k), queryVector)
|
||||
vec, err := t.Convert(ctx, k)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newDistance, err := t.DistanceType.Eval(vec, queryVector)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -222,7 +235,11 @@ func (t ProximityMap[K, V, O]) GetClosest(ctx context.Context, query interface{}
|
||||
// TODO: We don't need to recompute the distance when visiting the same key as the parent.
|
||||
for i := 0; i < int(node.count); i++ {
|
||||
k := node.GetKey(i)
|
||||
newDistance, err := t.DistanceType.Eval(t.Convert(ctx, k), queryVector)
|
||||
vec, err := t.Convert(ctx, k)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newDistance, err := t.DistanceType.Eval(vec, queryVector)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -265,15 +282,3 @@ func (t ProximityMap[K, V, O]) IterAll(ctx context.Context) (*OrderedTreeIter[K,
|
||||
|
||||
return &OrderedTreeIter[K, V]{curr: c, stop: stop, step: c.advance}, nil
|
||||
}
|
||||
|
||||
func getJsonValueFromHash(ctx context.Context, ns NodeStore, h hash.Hash) (interface{}, error) {
|
||||
return NewJSONDoc(h, ns).ToIndexedJSONDocument(ctx)
|
||||
}
|
||||
|
||||
func getVectorFromHash(ctx context.Context, ns NodeStore, h hash.Hash) ([]float64, error) {
|
||||
otherValue, err := getJsonValueFromHash(ctx, ns, h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sql.ConvertToVector(ctx, otherValue)
|
||||
}
|
||||
|
||||
@@ -81,11 +81,10 @@ func histLevelCount(t *testing.T, nodes []Node) int {
|
||||
case 0:
|
||||
cnt += n.Count()
|
||||
default:
|
||||
n, err := n.loadSubtrees()
|
||||
n, err := n.LoadSubtrees()
|
||||
require.NoError(t, err)
|
||||
for i := 0; i < n.Count(); i++ {
|
||||
subCnt, err := n.getSubtreeCount(i)
|
||||
require.NoError(t, err)
|
||||
subCnt := n.GetSubtreeCount(i)
|
||||
cnt += int(subCnt)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
package prolly
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
|
||||
@@ -27,6 +28,7 @@ import (
|
||||
// vectorIndexChunker is a stateful chunker that iterates over |pathMap|, a map that contains an element
|
||||
// for every key-value pair for a given level of a ProximityMap, and provides the path of keys to reach
|
||||
// that pair from the root. It uses this iterator to build each of the ProximityMap nodes for that level.
|
||||
// A linked list of N vectorIndexChunkers is used in order to build a vector index with N levels.
|
||||
type vectorIndexChunker struct {
|
||||
pathMapIter MapIter
|
||||
pathMap *MutableMap
|
||||
@@ -34,8 +36,10 @@ type vectorIndexChunker struct {
|
||||
lastKey []byte
|
||||
lastValue []byte
|
||||
lastSubtreeCount uint64
|
||||
lastPathSegment hash.Hash
|
||||
atEnd bool
|
||||
// lastPathSegment is the last observed parent key. When Next() is called with a different value for |parentPathSegment|,
|
||||
// we know that the chunker needs to end the previous chunk and start a new one.
|
||||
lastPathSegment []byte
|
||||
atEnd bool
|
||||
}
|
||||
|
||||
func newVectorIndexChunker(ctx context.Context, pathMap *MutableMap, childChunker *vectorIndexChunker) (*vectorIndexChunker, error) {
|
||||
@@ -56,9 +60,8 @@ func newVectorIndexChunker(ctx context.Context, pathMap *MutableMap, childChunke
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path, _ := pathMap.keyDesc.GetBytes(0, firstKey)
|
||||
lastPathSegment := hash.New(path[len(path)-20:])
|
||||
originalKey, _ := pathMap.keyDesc.GetBytes(1, firstKey)
|
||||
lastPathSegment, _ := pathMap.keyDesc.GetBytes(pathMap.keyDesc.Count()-2, firstKey)
|
||||
originalKey, _ := pathMap.keyDesc.GetBytes(pathMap.keyDesc.Count()-1, firstKey)
|
||||
return &vectorIndexChunker{
|
||||
pathMap: pathMap,
|
||||
pathMapIter: pathMapIter,
|
||||
@@ -70,14 +73,15 @@ func newVectorIndexChunker(ctx context.Context, pathMap *MutableMap, childChunke
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *vectorIndexChunker) Next(ctx context.Context, ns tree.NodeStore, serializer message.VectorIndexSerializer, parentPathSegment hash.Hash, level, depth int, originalKeyDesc val.TupleDesc) (tree.Node, uint64, hash.Hash, error) {
|
||||
// Next produces the next tree node for the corresponding level of the tree.
|
||||
func (c *vectorIndexChunker) Next(ctx context.Context, ns tree.NodeStore, serializer message.VectorIndexSerializer, parentPathSegment []byte, level, depth int, originalKeyDesc val.TupleDesc) (tree.Node, uint64, hash.Hash, error) {
|
||||
var indexMapKeys [][]byte
|
||||
var indexMapValues [][]byte
|
||||
var indexMapSubtrees []uint64
|
||||
subtreeSum := uint64(0)
|
||||
|
||||
for {
|
||||
if c.atEnd || c.lastPathSegment != parentPathSegment {
|
||||
if c.atEnd || !bytes.Equal(c.lastPathSegment, parentPathSegment) {
|
||||
msg := serializer.Serialize(indexMapKeys, indexMapValues, indexMapSubtrees, level)
|
||||
node, _, err := tree.NodeFromBytes(msg)
|
||||
if err != nil {
|
||||
@@ -86,9 +90,10 @@ func (c *vectorIndexChunker) Next(ctx context.Context, ns tree.NodeStore, serial
|
||||
nodeHash, err := ns.Write(ctx, node)
|
||||
return node, subtreeSum, nodeHash, err
|
||||
}
|
||||
vectorHash, _ := originalKeyDesc.GetJSONAddr(0, c.lastKey)
|
||||
if c.childChunker != nil {
|
||||
_, childCount, nodeHash, err := c.childChunker.Next(ctx, ns, serializer, vectorHash, level-1, depth+1, originalKeyDesc)
|
||||
// This chunker isn't chunking a leaf node. To insert the next key-value pair, we call Next() on the child chunker, which produces
|
||||
// a node one level down, that will be pointed to by this node.
|
||||
_, childCount, nodeHash, err := c.childChunker.Next(ctx, ns, serializer, c.lastKey, level-1, depth+1, originalKeyDesc)
|
||||
if err != nil {
|
||||
return tree.Node{}, 0, hash.Hash{}, err
|
||||
}
|
||||
@@ -107,9 +112,14 @@ func (c *vectorIndexChunker) Next(ctx context.Context, ns tree.NodeStore, serial
|
||||
} else if err != nil {
|
||||
return tree.Node{}, 0, hash.Hash{}, err
|
||||
} else {
|
||||
lastPath, _ := c.pathMap.keyDesc.GetBytes(0, nextKey)
|
||||
c.lastPathSegment = hash.New(lastPath[len(lastPath)-20:])
|
||||
c.lastKey, _ = c.pathMap.keyDesc.GetBytes(1, nextKey)
|
||||
// nextValue is a pathMap value tuple: it contains a primary key from the underlying table.
|
||||
// nextKey is a pathMap key tuple: it contains a field for each edge in the final index graph that connects
|
||||
// the root to |nextValue|, ending with the vector corresponding to the primary key in |nextValue|.
|
||||
// This chunker stores that vector in |c.lastKey|, so that it can write it into the index on the subsequent call.
|
||||
// It also stores the direct parent vector. When the direct parent vector changes, we use that as an indicator
|
||||
// To finish one chunk and begin the next one.
|
||||
c.lastPathSegment, _ = c.pathMap.keyDesc.GetBytes(c.pathMap.keyDesc.Count()-2, nextKey)
|
||||
c.lastKey, _ = c.pathMap.keyDesc.GetBytes(c.pathMap.keyDesc.Count()-1, nextKey)
|
||||
c.lastValue = nextValue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -288,7 +288,7 @@ func TestTupleBuilderAdaptiveEncodings(t *testing.T) {
|
||||
tup, err := tb.Build(testPool)
|
||||
require.NoError(t, err)
|
||||
|
||||
adaptiveEncodingBytes, _, err := td.GetBytesAdaptiveValue(0, vs, tup)
|
||||
adaptiveEncodingBytes, _, err := td.GetBytesAdaptiveValue(ctx, 0, vs, tup)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, shortByteArray, adaptiveEncodingBytes)
|
||||
})
|
||||
@@ -303,7 +303,7 @@ func TestTupleBuilderAdaptiveEncodings(t *testing.T) {
|
||||
tup, err := tb.Build(testPool)
|
||||
require.NoError(t, err)
|
||||
|
||||
adaptiveEncodingBytes, _, err := td.GetBytesAdaptiveValue(0, vs, tup)
|
||||
adaptiveEncodingBytes, _, err := td.GetBytesAdaptiveValue(ctx, 0, vs, tup)
|
||||
require.NoError(t, err)
|
||||
adaptiveEncodingByteArray := adaptiveEncodingBytes.(*ByteArray)
|
||||
outBytes, err := adaptiveEncodingByteArray.ToBytes(ctx)
|
||||
@@ -336,7 +336,7 @@ func TestTupleBuilderAdaptiveEncodings(t *testing.T) {
|
||||
|
||||
{
|
||||
// Check that first column is stored out-of-band
|
||||
adaptiveEncodingBytes, _, err := td.GetBytesAdaptiveValue(0, vs, tup)
|
||||
adaptiveEncodingBytes, _, err := td.GetBytesAdaptiveValue(ctx, 0, vs, tup)
|
||||
require.NoError(t, err)
|
||||
adaptiveEncodingByteArray := adaptiveEncodingBytes.(*ByteArray)
|
||||
outBytes, err := adaptiveEncodingByteArray.ToBytes(ctx)
|
||||
@@ -346,7 +346,7 @@ func TestTupleBuilderAdaptiveEncodings(t *testing.T) {
|
||||
|
||||
{
|
||||
// Check that second column is stored inline
|
||||
adaptiveEncodingBytes, _, err := td.GetBytesAdaptiveValue(1, vs, tup)
|
||||
adaptiveEncodingBytes, _, err := td.GetBytesAdaptiveValue(ctx, 1, vs, tup)
|
||||
require.NoError(t, err)
|
||||
adaptiveEncodingByteArray := adaptiveEncodingBytes.([]byte)
|
||||
require.Equal(t, mediumByteArray, adaptiveEncodingByteArray)
|
||||
|
||||
@@ -535,11 +535,13 @@ func (td TupleDesc) GetBytesAddr(i int, tup Tuple) (hash.Hash, bool) {
|
||||
}
|
||||
|
||||
// GetBytesAdaptiveValue returns either a []byte or a BytesWrapper, but Go doesn't allow us to use a single type for that.
|
||||
func (td TupleDesc) GetBytesAdaptiveValue(i int, vs ValueStore, tup Tuple) (interface{}, bool, error) {
|
||||
// TODO: Add context parameter
|
||||
ctx := context.Background()
|
||||
func (td TupleDesc) GetBytesAdaptiveValue(ctx context.Context, i int, vs ValueStore, tup Tuple) (interface{}, bool, error) {
|
||||
td.expectEncoding(i, BytesAdaptiveEnc)
|
||||
adaptiveValue := AdaptiveValue(td.GetField(i, tup))
|
||||
return GetBytesAdaptiveValue(ctx, vs, td.GetField(i, tup))
|
||||
}
|
||||
|
||||
func GetBytesAdaptiveValue(ctx context.Context, vs ValueStore, val []byte) (interface{}, bool, error) {
|
||||
adaptiveValue := AdaptiveValue(val)
|
||||
if len(adaptiveValue) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
#!/usr/bin/env bats
|
||||
load $BATS_TEST_DIRNAME/helper/common.bash
|
||||
|
||||
setup() {
|
||||
if [ "$SQL_ENGINE" = "remote-engine" ]; then
|
||||
skip "This test tests remote connections directly, SQL_ENGINE is not needed."
|
||||
fi
|
||||
setup_common
|
||||
}
|
||||
|
||||
teardown() {
|
||||
stop_sql_server 1 && sleep 0.5
|
||||
teardown_common
|
||||
}
|
||||
|
||||
@test "dolt-test-run: sanity test on sql-server" {
|
||||
start_sql_server
|
||||
|
||||
dolt sql -q "insert into dolt_tests values ('test', 'test', 'select 1', 'expected_rows', '==', '1');"
|
||||
run dolt sql -q "select * from dolt_test_run()"
|
||||
[ $status -eq 0 ]
|
||||
[[ $output =~ "| test | test | select 1 | PASS | |" ]] || false
|
||||
}
|
||||
@@ -960,11 +960,12 @@ CREATE TABLE \`all_types\` (
|
||||
\`v31\` varchar(255) DEFAULT NULL,
|
||||
\`v32\` varbinary(255) DEFAULT NULL,
|
||||
\`v33\` year DEFAULT NULL,
|
||||
\`v34\` vector(1) DEFAULT NULL,
|
||||
PRIMARY KEY (\`pk\`)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin;
|
||||
INSERT INTO \`all_types\` (\`pk\`,\`v1\`,\`v2\`,\`v3\`,\`v4\`,\`v5\`,\`v6\`,\`v7\`,\`v8\`,\`v9\`,\`v10\`,\`v11\`,\`v12\`,\`v13\`,\`v14\`,\`v15\`,\`v16\`,\`v17\`,\`v18\`,\`v19\`,\`v20\`,\`v21\`,\`v22\`,\`v23\`,\`v24\`,\`v25\`,\`v26\`,\`v27\`,\`v28\`,\`v29\`,\`v30\`,\`v31\`,\`v32\`,\`v33\`) VALUES
|
||||
(1,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,'null',NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
(2, 1, 1, 1 ,('abc'),'a','2022-04-05','2022-10-05 10:14:41',2.34,2.34,'s',2.34,POINT(1,2),1,'{"a":1}',LINESTRING(POINT(0,0),POINT(1,2)),('abcd'),'abcd',('ab'),1,'abc',POINT(2,1),polygon(linestring(point(1,2),point(3,4),point(5,6),point(1,2))),'one',1,'abc','10:14:41','2022-10-05 10:14:41',('a'),1,'a','abcde',1,2022);
|
||||
(1,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,'null',NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL,NULL),
|
||||
(2, 1, 1, 1 ,('abc'),'a','2022-04-05','2022-10-05 10:14:41',2.34,2.34,'s',2.34,POINT(1,2),1,'{"a":1}',LINESTRING(POINT(0,0),POINT(1,2)),('abcd'),'abcd',('ab'),1,'abc',POINT(2,1),polygon(linestring(point(1,2),point(3,4),point(5,6),point(1,2))),'one',1,'abc','10:14:41','2022-10-05 10:14:41',('a'),1,'a','abcde',1,2022,0x12345678);
|
||||
SQL
|
||||
[ "$status" -eq 0 ]
|
||||
|
||||
@@ -978,6 +979,9 @@ SQL
|
||||
run dolt sql -q "SELECT ST_AsText(v12), ST_AsText(v21), ST_AsText(v15), ST_AsText(v22) from t1;"
|
||||
[[ "$output" =~ "POINT(1 2) | POINT(2 1) | LINESTRING(0 0,1 2) | POLYGON((1 2,3 4,5 6,1 2))" ]] || false
|
||||
|
||||
run dolt sql -q "SELECT v34 from t1"
|
||||
[[ "$output" =~ "0x12345678" ]] || false
|
||||
|
||||
# need to test binary, bit and blob types
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
# dolt/integration-tests/bats/helper/sql-diff.bash
|
||||
|
||||
: "${SQL_DIFF_DEBUG:=}" # set to any value to enable debug output
|
||||
_dbg() { [ -n "$SQL_DIFF_DEBUG" ] && printf '%s\n' "$*" >&2; }
|
||||
_dbg_block() { [ -n "$SQL_DIFF_DEBUG" ] && { printf '%s\n' "$1" >&2; printf '%s\n' "$2" >&2; }; }
|
||||
|
||||
# first table header row from CLI diff (data section), as newline list
|
||||
_cli_header_cols() {
|
||||
awk '
|
||||
/^\s*\|\s*[-+<>]\s*\|/ && last_header != "" { print last_header; exit }
|
||||
/^\s*\|/ { last_header = $0 }
|
||||
' <<<"$1" \
|
||||
| tr '|' '\n' \
|
||||
| sed -e 's/^[[:space:]]*//;s/[[:space:]]*$//' \
|
||||
| grep -v -E '^(<|>|)$' \
|
||||
| grep -v '^$'
|
||||
}
|
||||
|
||||
# first table header row from SQL diff, strip to_/from_, drop metadata, as newline list
|
||||
_sql_data_header_cols() {
|
||||
echo "$1" \
|
||||
| awk '/^\|/ {print; exit}' \
|
||||
| tr '|' '\n' \
|
||||
| sed -e 's/^[[:space:] ]*//;s/[[:space:] ]*$//' \
|
||||
| grep -E '^(to_|from_)' \
|
||||
| sed -E 's/^(to_|from_)//' \
|
||||
| grep -Ev '^(commit|commit_date|diff_type)$' \
|
||||
| grep -v '^$'
|
||||
}
|
||||
|
||||
# count CLI changes by unique PK (includes +, -, <, >)
|
||||
_cli_change_count() {
|
||||
awk -F'|' '
|
||||
# start counting once we see a data row marker
|
||||
/^\s*\|\s*[-+<>]\s*\|/ { in_table=1 }
|
||||
in_table && $2 ~ /^[[:space:]]*[-+<>][[:space:]]*$/ {
|
||||
pk=$3
|
||||
gsub(/^[[:space:]]+|[[:space:]]+$/, "", pk)
|
||||
if (pk != "") seen[pk]=1
|
||||
}
|
||||
END { c=0; for (k in seen) c++; print c+0 }
|
||||
' <<<"$1"
|
||||
}
|
||||
|
||||
# count SQL data rows (lines starting with '|' minus header)
|
||||
_sql_row_count() {
|
||||
echo "$1" | awk '/^\|/ {c++} END{print (c>0?c-1:0)}'
|
||||
}
|
||||
|
||||
# compare two newline lists as sets (sorted)
|
||||
_compare_sets_or_err() {
|
||||
local name="$1" cli_cols="$2" sql_cols="$3" cli_out="$4" sql_out="$5"
|
||||
|
||||
local cli_sorted sql_sorted
|
||||
cli_sorted=$(echo "$cli_cols" | sort -u)
|
||||
sql_sorted=$(echo "$sql_cols" | sort -u)
|
||||
|
||||
_dbg_block "$name CLI columns:" "$cli_sorted"
|
||||
_dbg_block "$name SQL data columns:" "$sql_sorted"
|
||||
|
||||
if [ "$cli_sorted" != "$sql_sorted" ]; then
|
||||
echo "$name column set mismatch"
|
||||
echo "--- $name CLI columns ---"; echo "$cli_sorted"
|
||||
echo "--- $name SQL data columns ---"; echo "$sql_sorted"
|
||||
echo "--- $name CLI output ---"; echo "$cli_out"
|
||||
echo "--- $name SQL output ---"; echo "$sql_out"
|
||||
return 1
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
# compare change/row counts; on mismatch, print both outputs
|
||||
_compare_counts_or_err() {
|
||||
local name="$1" cli_out="$2" sql_out="$3" cli_count="$4" sql_count="$5"
|
||||
|
||||
_dbg "$name counts: CLI=$cli_count SQL=$sql_count"
|
||||
|
||||
if [ "$cli_count" != "$sql_count" ]; then
|
||||
echo "$name change count mismatch: CLI=$cli_count, SQL=$sql_count"
|
||||
echo "--- $name CLI output ---"; echo "$cli_out"
|
||||
echo "--- $name SQL output ---"; echo "$sql_out"
|
||||
return 1
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
# ---- main entrypoint ----
|
||||
|
||||
# Compare CLI diff with SQL dolt_diff
|
||||
# Usage: compare_dolt_diff [all dolt diff args...]
|
||||
compare_dolt_diff() {
|
||||
local args=("$@") # all arguments
|
||||
|
||||
# --- normal diff ---
|
||||
local cli_output sql_output cli_status sql_status
|
||||
cli_output=$(dolt diff "${args[@]}" 2>&1)
|
||||
cli_status=$?
|
||||
|
||||
# Build SQL argument list safely
|
||||
local sql_args=""
|
||||
for arg in "${args[@]}"; do
|
||||
if [ -z "$sql_args" ]; then
|
||||
sql_args="'$arg'"
|
||||
else
|
||||
sql_args+=", '$arg'"
|
||||
fi
|
||||
done
|
||||
sql_output=$(dolt sql -q "SELECT * FROM dolt_diff($sql_args)" 2>&1)
|
||||
sql_status=$?
|
||||
|
||||
# normally prints in bats using `run`, so no debug blocks here
|
||||
echo "$cli_output"
|
||||
echo "$sql_output"
|
||||
|
||||
if [ $cli_status -ne 0 ]; then
|
||||
_dbg "$cli_output"
|
||||
return 1
|
||||
fi
|
||||
if [ $sql_status -ne 0 ]; then
|
||||
_dbg "$sql_output"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Compare counts
|
||||
local cli_changes sql_rows
|
||||
cli_changes=$(_cli_change_count "$cli_output")
|
||||
sql_rows=$(_sql_row_count "$sql_output")
|
||||
_compare_counts_or_err "Diff" "$cli_output" "$sql_output" "$cli_changes" "$sql_rows" || return 1
|
||||
|
||||
# Compare columns
|
||||
local cli_cols sql_cols
|
||||
cli_cols=$(_cli_header_cols "$cli_output")
|
||||
sql_cols=$(_sql_data_header_cols "$sql_output")
|
||||
_compare_sets_or_err "Diff" "$cli_cols" "$sql_cols" "$cli_output" "$sql_output" || return 1
|
||||
|
||||
return 0
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env bats
|
||||
load $BATS_TEST_DIRNAME/helper/common.bash
|
||||
load $BATS_TEST_DIRNAME/helper/sql-diff.bash
|
||||
|
||||
setup() {
|
||||
setup_common
|
||||
@@ -890,3 +891,60 @@ EOF
|
||||
[ "$status" -eq 1 ]
|
||||
[[ "$output" =~ "invalid output format: sql. SQL format diffs only rendered for schema or data changes" ]] || false
|
||||
}
|
||||
|
||||
@test "sql-diff: skinny flag comparison between CLI and SQL table function" {
|
||||
dolt sql <<SQL
|
||||
CREATE TABLE test (
|
||||
pk BIGINT NOT NULL COMMENT 'tag:0',
|
||||
c1 BIGINT COMMENT 'tag:1',
|
||||
c2 BIGINT COMMENT 'tag:2',
|
||||
c3 BIGINT COMMENT 'tag:3',
|
||||
c4 BIGINT COMMENT 'tag:4',
|
||||
c5 BIGINT COMMENT 'tag:5',
|
||||
PRIMARY KEY (pk)
|
||||
);
|
||||
SQL
|
||||
dolt table import -u test `batshelper 1pk5col-ints.csv`
|
||||
dolt add test
|
||||
dolt commit -m "Added initial data"
|
||||
|
||||
compare_dolt_diff "HEAD~1" "HEAD" "test"
|
||||
compare_dolt_diff "-sk" "HEAD~1" "HEAD" "test"
|
||||
compare_dolt_diff "--skinny" "HEAD~1" "HEAD" "test"
|
||||
compare_dolt_diff "HEAD~1" "HEAD" "-sk" "test"
|
||||
|
||||
dolt sql -q "UPDATE test SET c1=100, c3=300 WHERE pk=0"
|
||||
dolt sql -q "UPDATE test SET c2=200 WHERE pk=1"
|
||||
dolt add test
|
||||
dolt commit -m "Updated some columns"
|
||||
|
||||
compare_dolt_diff "HEAD~1" "HEAD" "test"
|
||||
compare_dolt_diff "--skinny" "HEAD~1" "HEAD" "test"
|
||||
compare_dolt_diff "HEAD~1" "HEAD" "test" "--skinny"
|
||||
|
||||
dolt sql -q "ALTER TABLE test ADD COLUMN c6 BIGINT"
|
||||
dolt sql -q "UPDATE test SET c6=600 WHERE pk=0"
|
||||
dolt add test
|
||||
dolt commit -m "Added new column and updated it"
|
||||
|
||||
compare_dolt_diff "HEAD~1" "HEAD" "test"
|
||||
compare_dolt_diff "--skinny" "HEAD~1" "HEAD" "test"
|
||||
compare_dolt_diff "--skinny" "--include-cols=c1,c2" "HEAD~1" "HEAD" "test"
|
||||
compare_dolt_diff "--skinny" "HEAD~2" "HEAD" "test" "--include-cols" "c1" "c2"
|
||||
compare_dolt_diff "--skinny" "--include-cols=c1,c2" "HEAD~2" "HEAD" "test"
|
||||
|
||||
dolt sql -q "DELETE FROM test WHERE pk=1"
|
||||
dolt add test
|
||||
dolt commit -m "Deleted a row"
|
||||
|
||||
compare_dolt_diff "HEAD~1" "HEAD" "test"
|
||||
compare_dolt_diff "--skinny" "HEAD~1" "HEAD" "test"
|
||||
|
||||
run dolt sql -q "SELECT * FROM dolt_diff('-err', 'HEAD~1', 'HEAD', 'test')"
|
||||
[[ "$output" =~ "unknown option \`err" ]] || false
|
||||
[ "$status" -eq 1 ]
|
||||
|
||||
run dolt sql -q "SELECT * FROM dolt_diff('-sk', '--skinny', 'HEAD~1', 'HEAD', 'test')"
|
||||
[[ "$output" =~ "multiple values provided for \`skinny" ]] || false
|
||||
[ "$status" -eq 1 ]
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user