Merge remote-tracking branch 'origin/main' into elian/6329b

This commit is contained in:
elianddb
2026-02-13 16:57:51 -08:00
182 changed files with 2269 additions and 9818 deletions
+1
View File
@@ -6,6 +6,7 @@ ARG DOLT_VERSION
RUN apt update -y && \
apt install -y \
curl \
git \
tini \
ca-certificates && \
apt clean && \
+1 -1
View File
@@ -4,7 +4,7 @@ FROM debian:bookworm-slim AS base
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update -y && \
apt-get install -y --no-install-recommends \
curl tini ca-certificates && \
curl git tini ca-certificates && \
rm -rf /var/lib/apt/lists/*
+6
View File
@@ -61,6 +61,7 @@ func CreateCommitArgParser(supportsBranchFlag bool) *argparser.ArgParser {
ap.SupportsFlag(UpperCaseAllFlag, "A", "Adds all tables and databases (including new tables) in the working set to the staged set.")
ap.SupportsFlag(AmendFlag, "", "Amend previous commit")
ap.SupportsOptionalString(SignFlag, "S", "key-id", "Sign the commit using GPG. If no key-id is provided the key-id is taken from 'user.signingkey' the in the configuration")
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification")
if supportsBranchFlag {
ap.SupportsString(BranchParam, "", "branch", "Commit to the specified branch instead of the current branch.")
}
@@ -96,6 +97,7 @@ func CreateMergeArgParser() *argparser.ArgParser {
ap.SupportsFlag(NoCommitFlag, "", "Perform the merge and stop just before creating a merge commit. Note this will not prevent a fast-forward merge; use the --no-ff arg together with the --no-commit arg to prevent both fast-forwards and merge commits.")
ap.SupportsFlag(NoEditFlag, "", "Use an auto-generated commit message when creating a merge commit. The default for interactive CLI sessions is to open an editor.")
ap.SupportsString(AuthorParam, "", "author", "Specify an explicit author using the standard A U Thor {{.LessThan}}author@example.com{{.GreaterThan}} format.")
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before merge")
return ap
}
@@ -116,6 +118,7 @@ func CreateRebaseArgParser() *argparser.ArgParser {
ap.SupportsFlag(AbortParam, "", "Abort an interactive rebase and return the working set to the pre-rebase state")
ap.SupportsFlag(ContinueFlag, "", "Continue an interactive rebase after adjusting the rebase plan")
ap.SupportsFlag(InteractiveFlag, "i", "Start an interactive rebase")
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before rebase")
return ap
}
@@ -174,6 +177,7 @@ func CreateRemoteArgParser() *argparser.ArgParser {
func CreateCleanArgParser() *argparser.ArgParser {
ap := argparser.NewArgParserWithVariableArgs("clean")
ap.SupportsFlag(DryRunFlag, "", "Tests removing untracked tables without modifying the working set.")
ap.SupportsFlag(ExcludeIgnoreRulesFlag, "x", "Do not respect dolt_ignore; remove untracked tables that match dolt_ignore. dolt_nonlocal_tables is always respected.")
return ap
}
@@ -192,6 +196,7 @@ func CreateCherryPickArgParser() *argparser.ArgParser {
ap.SupportsFlag(AllowEmptyFlag, "", "Allow empty commits to be cherry-picked. "+
"Note that use of this option only keeps commits that were initially empty. "+
"Commits which become empty, due to a previous commit, will cause cherry-pick to fail.")
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before cherry-pick")
ap.TooManyArgsErrorFunc = func(receivedArgs []string) error {
return errors.New("cherry-picking multiple commits is not supported yet.")
}
@@ -229,6 +234,7 @@ func CreatePullArgParser() *argparser.ArgParser {
ap.SupportsString(UserFlag, "", "user", "User name to use when authenticating with the remote. Gets password from the environment variable {{.EmphasisLeft}}DOLT_REMOTE_PASSWORD{{.EmphasisRight}}.")
ap.SupportsFlag(PruneFlag, "p", "After fetching, remove any remote-tracking references that don't exist on the remote.")
ap.SupportsFlag(SilentFlag, "", "Suppress progress information.")
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before merge")
return ap
}
+73 -71
View File
@@ -17,77 +17,79 @@ package cli
// Constants for command line flags names. These tend to be used in multiple places, so defining
// them low in the package dependency tree makes sense.
const (
AbortParam = "abort"
AllFlag = "all"
AllowEmptyFlag = "allow-empty"
AmendFlag = "amend"
AuthorParam = "author"
ArchiveLevelParam = "archive-level"
BranchParam = "branch"
CachedFlag = "cached"
CheckoutCreateBranch = "b"
CreateResetBranch = "B"
CommitFlag = "commit"
ContinueFlag = "continue"
CopyFlag = "copy"
DateParam = "date"
DecorateFlag = "decorate"
DeleteFlag = "delete"
DeleteForceFlag = "D"
DepthFlag = "depth"
DryRunFlag = "dry-run"
EmptyParam = "empty"
ForceFlag = "force"
FullFlag = "full"
GraphFlag = "graph"
HardResetParam = "hard"
HostFlag = "host"
IncludeUntrackedFlag = "include-untracked"
InteractiveFlag = "interactive"
JobFlag = "job"
ListFlag = "list"
MergesFlag = "merges"
MessageArg = "message"
MinParentsFlag = "min-parents"
MoveFlag = "move"
NoCommitFlag = "no-commit"
NoEditFlag = "no-edit"
NoFFParam = "no-ff"
FFOnlyParam = "ff-only"
NoPrettyFlag = "no-pretty"
NoTLSFlag = "no-tls"
NoJsonMergeFlag = "dont-merge-json"
NotFlag = "not"
NumberFlag = "number"
OneLineFlag = "oneline"
OursFlag = "ours"
OutputOnlyFlag = "output-only"
ParentsFlag = "parents"
PatchFlag = "patch"
PasswordFlag = "password"
PortFlag = "port"
PruneFlag = "prune"
QuietFlag = "quiet"
RemoteParam = "remote"
SetUpstreamFlag = "set-upstream"
SetUpstreamToFlag = "set-upstream-to"
ShallowFlag = "shallow"
ShowIgnoredFlag = "ignored"
ShowSignatureFlag = "show-signature"
SignFlag = "gpg-sign"
SilentFlag = "silent"
SingleBranchFlag = "single-branch"
SkipEmptyFlag = "skip-empty"
SoftResetParam = "soft"
SquashParam = "squash"
StagedFlag = "staged"
StatFlag = "stat"
SystemFlag = "system"
TablesFlag = "tables"
TheirsFlag = "theirs"
TrackFlag = "track"
UpperCaseAllFlag = "ALL"
UserFlag = "user"
AbortParam = "abort"
AllFlag = "all"
AllowEmptyFlag = "allow-empty"
AmendFlag = "amend"
AuthorParam = "author"
ArchiveLevelParam = "archive-level"
BranchParam = "branch"
CachedFlag = "cached"
CheckoutCreateBranch = "b"
CreateResetBranch = "B"
CommitFlag = "commit"
ContinueFlag = "continue"
CopyFlag = "copy"
DateParam = "date"
DecorateFlag = "decorate"
DeleteFlag = "delete"
DeleteForceFlag = "D"
DepthFlag = "depth"
DryRunFlag = "dry-run"
EmptyParam = "empty"
ExcludeIgnoreRulesFlag = "x"
ForceFlag = "force"
FullFlag = "full"
GraphFlag = "graph"
HardResetParam = "hard"
HostFlag = "host"
IncludeUntrackedFlag = "include-untracked"
InteractiveFlag = "interactive"
JobFlag = "job"
ListFlag = "list"
MergesFlag = "merges"
MessageArg = "message"
MinParentsFlag = "min-parents"
MoveFlag = "move"
NoCommitFlag = "no-commit"
NoEditFlag = "no-edit"
NoFFParam = "no-ff"
FFOnlyParam = "ff-only"
NoPrettyFlag = "no-pretty"
NoTLSFlag = "no-tls"
NoJsonMergeFlag = "dont-merge-json"
NotFlag = "not"
NumberFlag = "number"
OneLineFlag = "oneline"
OursFlag = "ours"
OutputOnlyFlag = "output-only"
ParentsFlag = "parents"
PatchFlag = "patch"
PasswordFlag = "password"
PortFlag = "port"
PruneFlag = "prune"
QuietFlag = "quiet"
RemoteParam = "remote"
SetUpstreamFlag = "set-upstream"
SetUpstreamToFlag = "set-upstream-to"
ShallowFlag = "shallow"
ShowIgnoredFlag = "ignored"
ShowSignatureFlag = "show-signature"
SignFlag = "gpg-sign"
SilentFlag = "silent"
SingleBranchFlag = "single-branch"
SkipEmptyFlag = "skip-empty"
SkipVerificationFlag = "skip-verification"
SoftResetParam = "soft"
SquashParam = "squash"
StagedFlag = "staged"
StatFlag = "stat"
SystemFlag = "system"
TablesFlag = "tables"
TheirsFlag = "theirs"
TrackFlag = "track"
UpperCaseAllFlag = "ALL"
UserFlag = "user"
)
// Flags used by `dolt diff` command and `dolt_diff()` table function.
+13 -5
View File
@@ -32,16 +32,17 @@ const (
var cleanDocContent = cli.CommandDocumentationContent{
ShortDesc: "Deletes untracked working tables",
LongDesc: "{{.EmphasisLeft}}dolt clean [--dry-run]{{.EmphasisRight}}\n\n" +
LongDesc: "{{.EmphasisLeft}}dolt clean [--dry-run] [-x]{{.EmphasisRight}}\n\n" +
"The default (parameterless) form clears the values for all untracked working {{.LessThan}}tables{{.GreaterThan}} ." +
"This command permanently deletes unstaged or uncommitted tables.\n\n" +
"This command permanently deletes unstaged or uncommitted tables. By default, tables matching dolt_ignore or dolt_nonlocal_tables are not removed.\n\n" +
"The {{.EmphasisLeft}}--dry-run{{.EmphasisRight}} flag can be used to test whether the clean can succeed without " +
"deleting any tables from the current working set.\n\n" +
"{{.EmphasisLeft}}dolt clean [--dry-run] {{.LessThan}}tables{{.GreaterThan}}...{{.EmphasisRight}}\n\n" +
"The {{.EmphasisLeft}}-x{{.EmphasisRight}} flag causes dolt_ignore to be ignored so that untracked tables matching dolt_ignore are removed; dolt_nonlocal_tables is always respected (similar to git clean -x).\n\n" +
"{{.EmphasisLeft}}dolt clean [--dry-run] [-x] {{.LessThan}}tables{{.GreaterThan}}...{{.EmphasisRight}}\n\n" +
"If {{.LessThan}}tables{{.GreaterThan}} is specified, only those table names are considered for deleting.\n\n",
Synopsis: []string{
"[--dry-run]",
"[--dry-run] {{.LessThan}}tables{{.GreaterThan}}...",
"[--dry-run] [-x]",
"[--dry-run] [-x] {{.LessThan}}tables{{.GreaterThan}}...",
},
}
@@ -87,6 +88,13 @@ func (cmd CleanCmd) Exec(ctx context.Context, commandStr string, args []string,
buffer.WriteString("\"--dry-run\"")
firstParamDone = true
}
if apr.Contains(cli.ExcludeIgnoreRulesFlag) {
if firstParamDone {
buffer.WriteString(", ")
}
buffer.WriteString("\"-x\"")
firstParamDone = true
}
if apr.NArg() > 0 {
// loop over apr.Args() and add them to the buffer
for i := 0; i < apr.NArg(); i++ {
+4
View File
@@ -266,6 +266,10 @@ func constructParametrizedDoltCommitQuery(msg string, apr *argparser.ArgParseRes
writeToBuffer("--skip-empty")
}
if apr.Contains(cli.SkipVerificationFlag) {
writeToBuffer("--skip-verification")
}
cfgSign := cliCtx.Config().GetStringOrDefault("sqlserver.global.gpgsign", "")
if apr.Contains(cli.SignFlag) || strings.ToLower(cfgSign) == "true" {
writeToBuffer("--gpg-sign")
-15
View File
@@ -26,7 +26,6 @@ import (
"github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/vitess/go/sqltypes"
ast "github.com/dolthub/vitess/go/vt/sqlparser"
"github.com/go-sql-driver/mysql"
"github.com/gocraft/dbr/v2"
"github.com/gocraft/dbr/v2/dialect"
@@ -549,20 +548,6 @@ func getTableNamesAtRef(queryist cli.Queryist, sqlCtx *sql.Context, ref string)
return tableNames, nil
}
func isTableNotFoundError(err error) bool {
if sql.ErrTableNotFound.Is(err) {
return true
}
mse, ok := err.(*mysql.MySQLError)
if ok {
if strings.HasPrefix(mse.Message, "table not found:") {
return true
}
}
return false
}
// applyDiffRoots applies the appropriate |from| and |to| root values to the receiver and returns the table names
// (if any) given to the command.
func (dArgs *diffArgs) applyDiffRoots(queryist cli.Queryist, sqlCtx *sql.Context, args []string, isCached, useMergeBase bool) ([]string, error) {
+1 -11
View File
@@ -623,16 +623,6 @@ func dumpTable(ctx *sql.Context, dEnv *env.DoltEnv, engine *sqle.Engine, root do
}
func getTableWriter(ctx context.Context, dEnv *env.DoltEnv, tblOpts *tableOptions, outSch schema.Schema, filePath string) (table.SqlRowWriter, errhand.VerboseError) {
tmpDir, err := dEnv.TempTableFilesDir()
if err != nil {
return nil, errhand.BuildDError("error: ").AddCause(err).Build()
}
deaf, err := dEnv.DbEaFactory(ctx)
if err != nil {
return nil, errhand.BuildDError("error: ").AddCause(err).Build()
}
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
writer, err := dEnv.FS.OpenForWriteAppend(filePath, os.ModePerm)
if err != nil {
return nil, errhand.BuildDError("Error opening writer for %s.", tblOpts.DestName()).AddCause(err).Build()
@@ -643,7 +633,7 @@ func getTableWriter(ctx context.Context, dEnv *env.DoltEnv, tblOpts *tableOption
return nil, errhand.BuildDError("Could not create table writer for %s", tblOpts.tableName).AddCause(err).Build()
}
wr, err := tblOpts.dest.NewCreatingWriter(ctx, tblOpts, root, outSch, opts, writer)
wr, err := tblOpts.dest.NewCreatingWriter(ctx, tblOpts, root, outSch, editor.Options{}, writer)
if err != nil {
return nil, errhand.BuildDError("Could not create table writer for %s", tblOpts.tableName).AddCause(err).Build()
}
+1 -1
View File
@@ -143,7 +143,7 @@ func NewSqlEngine(
})
}
dbs, locations, err := CollectDBs(ctx, mrEnv, config.Bulk)
dbs, locations, err := CollectDBs(ctx, mrEnv)
if err != nil {
return nil, err
}
+4 -22
View File
@@ -26,13 +26,13 @@ import (
// CollectDBs takes a MultiRepoEnv and creates Database objects from each environment and returns a slice of these
// objects.
func CollectDBs(ctx context.Context, mrEnv *env.MultiRepoEnv, useBulkEditor bool) ([]dsess.SqlDatabase, []filesys.Filesys, error) {
func CollectDBs(ctx context.Context, mrEnv *env.MultiRepoEnv) ([]dsess.SqlDatabase, []filesys.Filesys, error) {
var dbs []dsess.SqlDatabase
var locations []filesys.Filesys
var db dsess.SqlDatabase
err := mrEnv.Iter(func(name string, dEnv *env.DoltEnv) (stop bool, err error) {
db, err = newDatabase(ctx, name, dEnv, useBulkEditor)
db, err = newDatabase(ctx, name, dEnv)
if err != nil {
return false, err
}
@@ -50,25 +50,7 @@ func CollectDBs(ctx context.Context, mrEnv *env.MultiRepoEnv, useBulkEditor bool
return dbs, locations, nil
}
func newDatabase(ctx context.Context, name string, dEnv *env.DoltEnv, useBulkEditor bool) (sqle.Database, error) {
var deaf editor.DbEaFactory
var err error
if useBulkEditor {
deaf, err = dEnv.BulkDbEaFactory(ctx)
} else {
deaf, err = dEnv.DbEaFactory(ctx)
}
if err != nil {
return sqle.Database{}, err
}
tmpDir, err := dEnv.TempTableFilesDir()
if err != nil {
return sqle.Database{}, err
}
opts := editor.Options{
Deaf: deaf,
Tempdir: tmpDir,
}
func newDatabase(ctx context.Context, name string, dEnv *env.DoltEnv) (sqle.Database, error) {
dbdata := dEnv.DbData(ctx)
// Databases registered with the SQL engine are always
// configured for FatalBehaviorCrash. These are local
@@ -81,5 +63,5 @@ func newDatabase(ctx context.Context, name string, dEnv *env.DoltEnv, useBulkEdi
// See also sqle/database_provider.go, where we do this when
// creating new databases as well.
dbdata.Ddb.SetCrashOnFatalError()
return sqle.NewDatabase(ctx, name, dbdata, opts)
return sqle.NewDatabase(ctx, name, dbdata, editor.Options{})
}
+1 -10
View File
@@ -335,16 +335,7 @@ func processFilterQuery(ctx context.Context, dEnv *env.DoltEnv, root doltdb.Root
// we set manually with the one at the working set of the HEAD being rebased.
// Some functionality will not work on this kind of engine, e.g. many DOLT_ functions.
func rebaseSqlEngine(ctx context.Context, dEnv *env.DoltEnv, root doltdb.RootValue) (*sql.Context, *engine.SqlEngine, error) {
tmpDir, err := dEnv.TempTableFilesDir()
if err != nil {
return nil, nil, err
}
deaf, err := dEnv.DbEaFactory(ctx)
if err != nil {
return nil, nil, err
}
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
db, err := dsqle.NewDatabase(ctx, filterDbName, dEnv.DbData(ctx), opts)
db, err := dsqle.NewDatabase(ctx, filterDbName, dEnv.DbData(ctx), editor.Options{})
if err != nil {
return nil, nil, err
}
+2 -10
View File
@@ -87,15 +87,7 @@ func (cmd RebuildCmd) Exec(ctx context.Context, commandStr string, args []string
if !ok {
return HandleErr(errhand.BuildDError("The table `%s` does not exist.", tableName).Build(), nil)
}
tmpDir, err := dEnv.TempTableFilesDir()
if err != nil {
return HandleErr(errhand.BuildDError("error: ").AddCause(err).Build(), nil)
}
deaf, err := dEnv.DbEaFactory(ctx)
if err != nil {
return HandleErr(errhand.BuildDError("error: ").AddCause(err).Build(), nil)
}
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
sch, err := table.GetSchema(ctx)
if err != nil {
return HandleErr(errhand.BuildDError("could not get table schema").AddCause(err).Build(), nil)
@@ -104,7 +96,7 @@ func (cmd RebuildCmd) Exec(ctx context.Context, commandStr string, args []string
if idxSch == nil {
return HandleErr(errhand.BuildDError("the index `%s` does not exist on table `%s`", indexName, tableName).Build(), nil)
}
indexRowData, err := creation.BuildSecondaryIndex(sql.NewContext(ctx), table, idxSch, tableName, opts)
indexRowData, err := creation.BuildSecondaryIndex(sql.NewContext(ctx), table, idxSch, tableName, editor.Options{})
if err != nil {
return HandleErr(errhand.BuildDError("Unable to rebuild index `%s` on table `%s`.", indexName, tableName).AddCause(err).Build(), nil)
}
+4
View File
@@ -318,6 +318,10 @@ func constructInterpolatedDoltMergeQuery(apr *argparser.ArgParseResults, cliCtx
params = append(params, msg)
}
if apr.Contains(cli.SkipVerificationFlag) {
writeToBuffer("--skip-verification", false)
}
if !apr.Contains(cli.AbortParam) && !apr.Contains(cli.SquashParam) {
writeToBuffer("?", true)
params = append(params, apr.Arg(0))
+1 -10
View File
@@ -134,16 +134,7 @@ func exportSchemas(ctx context.Context, apr *argparser.ArgParseResults, root dol
}
for _, tn := range tablesToExport {
tmpDir, err := dEnv.TempTableFilesDir()
if err != nil {
return errhand.BuildDError("error: ").AddCause(err).Build()
}
deaf, err := dEnv.DbEaFactory(ctx)
if err != nil {
return errhand.BuildDError("error: ").AddCause(err).Build()
}
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
verr := exportTblSchema(ctx, tn, root, wr, opts)
verr := exportTblSchema(ctx, tn, root, wr, editor.Options{})
if verr != nil {
return verr
}
+1 -10
View File
@@ -133,16 +133,7 @@ func printSchemas(ctx context.Context, apr *argparser.ArgParseResults, dEnv *env
}
}
tmpDir, err := dEnv.TempTableFilesDir()
if err != nil {
return errhand.BuildDError("error: ").AddCause(err).Build()
}
deaf, err := dEnv.DbEaFactory(ctx)
if err != nil {
return errhand.BuildDError("error: ").AddCause(err).Build()
}
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
sqlCtx, engine, _ := dsqle.PrepareCreateTableStmt(ctx, dsqle.NewUserSpaceDatabase(root, opts))
sqlCtx, engine, _ := dsqle.PrepareCreateTableStmt(ctx, dsqle.NewUserSpaceDatabase(root, editor.Options{}))
var notFound []string
for _, tblName := range tables {
+3
View File
@@ -156,6 +156,9 @@ func mcpRun(cfg *Config, lgr *logrus.Logger, state *svcs.ServiceState, cancelPtr
logger,
dbConf,
*cfg.MCP.Port,
nil, // jwkClaimsMap
"", // jwkUrl
nil, // tlsConfig
toolsets.WithToolSet(&toolsets.PrimitiveToolSetV1{}),
)
if err != nil {
+1 -5
View File
@@ -258,11 +258,7 @@ func getTableWriter(ctx context.Context, root doltdb.RootValue, dEnv *env.DoltEn
return nil, errhand.BuildDError("Error opening writer for %s.", exOpts.DestName()).AddCause(err).Build()
}
deaf, err := dEnv.DbEaFactory(ctx)
if err != nil {
return nil, errhand.BuildDError("Error opening writer for %s.", exOpts.DestName()).AddCause(err).Build()
}
wr, err := exOpts.dest.NewCreatingWriter(ctx, exOpts, root, rdSchema, editor.Options{Deaf: deaf}, writer)
wr, err := exOpts.dest.NewCreatingWriter(ctx, exOpts, root, rdSchema, editor.Options{}, writer)
if err != nil {
return nil, errhand.BuildDError("Error opening writer for %s.", exOpts.DestName()).AddCause(err).Build()
}
+1 -1
View File
@@ -15,5 +15,5 @@
package doltversion
const (
Version = "1.81.8"
Version = "1.81.10"
)
+16 -1
View File
@@ -579,7 +579,19 @@ func (rcv *RebaseState) MutateRebasingStarted(n bool) bool {
return rcv._tab.MutateBoolSlot(16, n)
}
const RebaseStateNumFields = 7
func (rcv *RebaseState) SkipVerification() bool {
o := flatbuffers.UOffsetT(rcv._tab.Offset(18))
if o != 0 {
return rcv._tab.GetBool(o + rcv._tab.Pos)
}
return false
}
func (rcv *RebaseState) MutateSkipVerification(n bool) bool {
return rcv._tab.MutateBoolSlot(18, n)
}
const RebaseStateNumFields = 8
func RebaseStateStart(builder *flatbuffers.Builder) {
builder.StartObject(RebaseStateNumFields)
@@ -614,6 +626,9 @@ func RebaseStateAddLastAttemptedStep(builder *flatbuffers.Builder, lastAttempted
func RebaseStateAddRebasingStarted(builder *flatbuffers.Builder, rebasingStarted bool) {
builder.PrependBoolSlot(6, rebasingStarted, false)
}
func RebaseStateAddSkipVerification(builder *flatbuffers.Builder, skipVerification bool) {
builder.PrependBoolSlot(7, skipVerification, false)
}
func RebaseStateEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT {
return builder.EndObject()
}
+2 -2
View File
@@ -58,10 +58,10 @@ require (
github.com/cespare/xxhash/v2 v2.3.0
github.com/creasty/defaults v1.6.0
github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12
github.com/dolthub/dolt-mcp v0.2.2
github.com/dolthub/dolt-mcp v0.3.4
github.com/dolthub/eventsapi_schema v0.0.0-20260205214132-a7a3c84c84a1
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2
github.com/dolthub/go-mysql-server v0.20.1-0.20260211220532-85072e590dc7
github.com/dolthub/go-mysql-server v0.20.1-0.20260212215527-0cb492ad7051
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63
github.com/edsrzf/mmap-go v1.2.0
github.com/esote/minmaxheap v1.0.0
+4 -4
View File
@@ -186,8 +186,8 @@ github.com/denisenkom/go-mssqldb v0.10.0 h1:QykgLZBorFE95+gO3u9esLd0BmbvpWp0/waN
github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12 h1:IdqX7J8vi/Kn3T3Ee0VzqnLqwFmgA2hr8WZETPcQjfM=
github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12/go.mod h1:rN7X8BHwkjPcfMQQ2QTAq/xM3leUSGLfb+1Js7Y6TVo=
github.com/dolthub/dolt-mcp v0.2.2 h1:bpROmam74n95uU4EA3BpOIVlTDT0pzeFMBwe/YRq2mI=
github.com/dolthub/dolt-mcp v0.2.2/go.mod h1:S++DJ4QWTAXq+0TNzFa7Oq3IhoT456DJHwAINFAHgDQ=
github.com/dolthub/dolt-mcp v0.3.4 h1:AyG5cw+fNWXDHXujtQnqUPZrpWtPg6FN6yYtjv1pP44=
github.com/dolthub/dolt-mcp v0.3.4/go.mod h1:bCZ7KHvDYs+M0e+ySgmGiNvLhcwsN7bbf5YCyillLrk=
github.com/dolthub/eventsapi_schema v0.0.0-20260205214132-a7a3c84c84a1 h1:QePoMpa5qlquwUqRVyF9KAHsJAlYbE2+eZkMPAxeBXc=
github.com/dolthub/eventsapi_schema v0.0.0-20260205214132-a7a3c84c84a1/go.mod h1:evuptFmr/0/j0X/g+3cveHEEOM5tqyRA15FNgirtOY0=
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww=
@@ -196,8 +196,8 @@ github.com/dolthub/fslock v0.0.0-20251215194149-ef20baba2318 h1:n+vdH5G5Db+1qnDC
github.com/dolthub/fslock v0.0.0-20251215194149-ef20baba2318/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
github.com/dolthub/go-icu-regex v0.0.0-20250916051405-78a38d478790 h1:zxMsH7RLiG+dlZ/y0LgJHTV26XoiSJcuWq+em6t6VVc=
github.com/dolthub/go-icu-regex v0.0.0-20250916051405-78a38d478790/go.mod h1:F3cnm+vMRK1HaU6+rNqQrOCyR03HHhR1GWG2gnPOqaE=
github.com/dolthub/go-mysql-server v0.20.1-0.20260211220532-85072e590dc7 h1:9xC+/i949mi2wwsu6BKgvnDnuRcYy4KysrIb2x7DaSo=
github.com/dolthub/go-mysql-server v0.20.1-0.20260211220532-85072e590dc7/go.mod h1:LEWdXw6LKjdonOv2X808RpUc8wZVtQx4ZEPvmDWkvY4=
github.com/dolthub/go-mysql-server v0.20.1-0.20260212215527-0cb492ad7051 h1:7vNnl/Z2HhFFUTdXNOySd8KFODBztPlmCITrRIKDgTw=
github.com/dolthub/go-mysql-server v0.20.1-0.20260212215527-0cb492ad7051/go.mod h1:LEWdXw6LKjdonOv2X808RpUc8wZVtQx4ZEPvmDWkvY4=
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI=
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q=
github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE=
@@ -52,6 +52,9 @@ type CherryPickOptions struct {
// and Dolt cherry-pick implementations, the default action is to fail when an empty commit is specified. In Git
// and Dolt rebase implementations, the default action is to keep commits that start off as empty.
EmptyCommitHandling doltdb.EmptyCommitHandling
// SkipVerification controls whether test validation should be skipped before creating commits.
SkipVerification bool
}
// NewCherryPickOptions creates a new CherryPickOptions instance, filled out with default values for cherry-pick.
@@ -61,6 +64,7 @@ func NewCherryPickOptions() CherryPickOptions {
CommitMessage: "",
CommitBecomesEmptyHandling: doltdb.ErrorOnEmptyCommit,
EmptyCommitHandling: doltdb.ErrorOnEmptyCommit,
SkipVerification: false,
}
}
@@ -159,9 +163,10 @@ func CreateCommitStagedPropsFromCherryPickOptions(ctx *sql.Context, options Cher
}
commitProps := actions.CommitStagedProps{
Date: originalMeta.Time(),
Name: originalMeta.Name,
Email: originalMeta.Email,
Date: originalMeta.Time(),
Name: originalMeta.Name,
Email: originalMeta.Email,
SkipVerification: options.SkipVerification,
}
if options.CommitMessage != "" {
-156
View File
@@ -1,156 +0,0 @@
// Copyright 2019 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package conflict
import (
"context"
"errors"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/schema/encoding"
"github.com/dolthub/dolt/go/store/types"
)
type ConflictSchema struct {
Base schema.Schema
Schema schema.Schema
MergeSchema schema.Schema
}
func NewConflictSchema(base, sch, mergeSch schema.Schema) ConflictSchema {
return ConflictSchema{
Base: base,
Schema: sch,
MergeSchema: mergeSch,
}
}
func ValueFromConflictSchema(ctx context.Context, vrw types.ValueReadWriter, cs ConflictSchema) (types.Value, error) {
b, err := serializeSchema(ctx, vrw, cs.Base)
if err != nil {
return nil, err
}
s, err := serializeSchema(ctx, vrw, cs.Schema)
if err != nil {
return nil, err
}
m, err := serializeSchema(ctx, vrw, cs.MergeSchema)
if err != nil {
return nil, err
}
return types.NewTuple(vrw.Format(), b, s, m)
}
func ConflictSchemaFromValue(ctx context.Context, vrw types.ValueReadWriter, v types.Value) (cs ConflictSchema, err error) {
tup, ok := v.(types.Tuple)
if !ok {
err = errors.New("conflict schema value must be types.Struct")
return ConflictSchema{}, err
}
b, err := tup.Get(0)
if err != nil {
return ConflictSchema{}, err
}
cs.Base, err = deserializeSchema(ctx, vrw, b)
if err != nil {
return ConflictSchema{}, err
}
s, err := tup.Get(1)
if err != nil {
return ConflictSchema{}, err
}
cs.Schema, err = deserializeSchema(ctx, vrw, s)
if err != nil {
return ConflictSchema{}, err
}
m, err := tup.Get(2)
if err != nil {
return ConflictSchema{}, err
}
cs.MergeSchema, err = deserializeSchema(ctx, vrw, m)
if err != nil {
return ConflictSchema{}, err
}
return cs, nil
}
func serializeSchema(ctx context.Context, vrw types.ValueReadWriter, sch schema.Schema) (types.Ref, error) {
st, err := encoding.MarshalSchema(ctx, vrw, sch)
if err != nil {
return types.Ref{}, err
}
return vrw.WriteValue(ctx, st)
}
func deserializeSchema(ctx context.Context, vrw types.ValueReadWriter, v types.Value) (schema.Schema, error) {
r, ok := v.(types.Ref)
if !ok {
return nil, errors.New("conflict schemas field value is unexpected type")
}
return encoding.UnmarshalSchemaAtAddr(ctx, vrw, r.TargetHash())
}
type Conflict struct {
Base types.Value
Value types.Value
MergeValue types.Value
}
func NewConflict(base, value, mergeValue types.Value) Conflict {
if base == nil {
base = types.NullValue
}
if value == nil {
value = types.NullValue
}
if mergeValue == nil {
mergeValue = types.NullValue
}
return Conflict{base, value, mergeValue}
}
func ConflictFromTuple(tpl types.Tuple) (Conflict, error) {
base, err := tpl.Get(0)
if err != nil {
return Conflict{}, err
}
val, err := tpl.Get(1)
if err != nil {
return Conflict{}, err
}
mv, err := tpl.Get(2)
if err != nil {
return Conflict{}, err
}
return Conflict{base, val, mv}, nil
}
func (c Conflict) ToNomsList(vrw types.ValueReadWriter) (types.Tuple, error) {
return types.NewTuple(vrw.Format(), c.Base, c.Value, c.MergeValue)
}
-285
View File
@@ -1,285 +0,0 @@
// Copyright 2019 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package diff
import (
"context"
"fmt"
"time"
"golang.org/x/sync/errgroup"
"github.com/dolthub/dolt/go/libraries/doltcore/row"
"github.com/dolthub/dolt/go/libraries/utils/async"
"github.com/dolthub/dolt/go/store/diff"
"github.com/dolthub/dolt/go/store/types"
)
// todo: make package private
type AsyncDiffer struct {
diffChan chan diff.Difference
bufferSize int
eg *errgroup.Group
egCtx context.Context
egCancel func()
diffStats map[types.DiffChangeType]uint64
}
var _ RowDiffer = &AsyncDiffer{}
// todo: make package private once dolthub is migrated
func NewAsyncDiffer(bufferedDiffs int) *AsyncDiffer {
return &AsyncDiffer{
diffChan: make(chan diff.Difference, bufferedDiffs),
bufferSize: bufferedDiffs,
egCtx: context.Background(),
egCancel: func() {},
diffStats: make(map[types.DiffChangeType]uint64),
}
}
func tableDontDescendLists(v1, v2 types.Value) bool {
kind := v1.Kind()
return !types.IsPrimitiveKind(kind) && kind != types.TupleKind && kind == v2.Kind() && kind != types.RefKind
}
func (ad *AsyncDiffer) Start(ctx context.Context, from, to types.Map) {
ad.start(ctx, func(ctx context.Context) error {
return diff.Diff(ctx, from, to, ad.diffChan, true, tableDontDescendLists)
})
}
func (ad *AsyncDiffer) StartWithRange(ctx context.Context, from, to types.Map, start types.Value, inRange types.ValueInRange) {
ad.start(ctx, func(ctx context.Context) error {
return diff.DiffMapRange(ctx, from, to, start, inRange, ad.diffChan, true, tableDontDescendLists)
})
}
func (ad *AsyncDiffer) start(ctx context.Context, diffFunc func(ctx context.Context) error) {
ad.eg, ad.egCtx = errgroup.WithContext(ctx)
ad.egCancel = async.GoWithCancel(ad.egCtx, ad.eg, func(ctx context.Context) (err error) {
defer close(ad.diffChan)
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in diff.Diff: %v", r)
}
}()
return diffFunc(ctx)
})
}
func (ad *AsyncDiffer) Close() error {
ad.egCancel()
return ad.eg.Wait()
}
func (ad *AsyncDiffer) getDiffs(numDiffs int, timeoutChan <-chan time.Time, pred diffPredicate) ([]*diff.Difference, bool, error) {
diffs := make([]*diff.Difference, 0, numDiffs)
for {
select {
case d, more := <-ad.diffChan:
if more {
if pred(&d) {
ad.diffStats[d.ChangeType]++
diffs = append(diffs, &d)
}
if numDiffs != 0 && numDiffs == len(diffs) {
return diffs, true, nil
}
} else {
return diffs, false, ad.eg.Wait()
}
case <-timeoutChan:
return diffs, true, nil
case <-ad.egCtx.Done():
return nil, false, ad.eg.Wait()
}
}
}
var forever <-chan time.Time = make(chan time.Time)
type diffPredicate func(*diff.Difference) bool
var alwaysTruePredicate diffPredicate = func(*diff.Difference) bool {
return true
}
func hasChangeTypePredicate(changeType types.DiffChangeType) diffPredicate {
return func(d *diff.Difference) bool {
return d.ChangeType == changeType
}
}
func (ad *AsyncDiffer) GetDiffs(numDiffs int, timeout time.Duration) ([]*diff.Difference, bool, error) {
if timeout < 0 {
return ad.GetDiffsWithoutTimeout(numDiffs)
}
return ad.getDiffs(numDiffs, time.After(timeout), alwaysTruePredicate)
}
func (ad *AsyncDiffer) GetDiffsWithFilter(numDiffs int, timeout time.Duration, filterByChangeType types.DiffChangeType) ([]*diff.Difference, bool, error) {
if timeout < 0 {
return ad.GetDiffsWithoutTimeoutWithFilter(numDiffs, filterByChangeType)
}
return ad.getDiffs(numDiffs, time.After(timeout), hasChangeTypePredicate(filterByChangeType))
}
// GetDiffsWithoutTimeoutWithFilter blocks until |numDiffs| differences of
// change type |filterByChangeType| are found or the diff stream is exhausted.
func (ad *AsyncDiffer) GetDiffsWithoutTimeoutWithFilter(numDiffs int, filterByChangeType types.DiffChangeType) ([]*diff.Difference, bool, error) {
	pred := hasChangeTypePredicate(filterByChangeType)
	return ad.getDiffs(numDiffs, forever, pred)
}
// GetDiffsWithoutTimeout blocks until |numDiffs| differences are accumulated
// or the diff stream is exhausted.
func (ad *AsyncDiffer) GetDiffsWithoutTimeout(numDiffs int) ([]*diff.Difference, bool, error) {
	res, hasMore, err := ad.getDiffs(numDiffs, forever, alwaysTruePredicate)
	return res, hasMore, err
}
// keylessDiffer wraps AsyncDiffer for keyless tables, where one map-level
// Difference can stand for multiple identical rows (its cardinality).
type keylessDiffer struct {
	*AsyncDiffer
	df diff.Difference // the Difference currently being copied out to callers
	copiesLeft uint64 // remaining copies of |df| still to be emitted
}
var _ RowDiffer = &keylessDiffer{}
// getDiffs accumulates up to |numDiffs| row-level differences that satisfy
// |pred|. Keyless tables encode duplicate rows as a single map entry with a
// cardinality, so each map-level Difference read from the channel may be
// expanded into several identical row diffs (tracked by kd.df/kd.copiesLeft,
// which persist across calls). Returns the collected diffs, whether more may
// remain, and any error from the underlying differ.
func (kd *keylessDiffer) getDiffs(numDiffs int, timeoutChan <-chan time.Time, pred diffPredicate) ([]*diff.Difference, bool, error) {
	diffs := make([]*diff.Difference, numDiffs)
	idx := 0
	for {
		// first populate |diffs| with copies of |kd.df|
		cpy := kd.df // save a copy of kd.df to reference
		for (idx < numDiffs) && (kd.copiesLeft > 0) {
			diffs[idx] = &cpy
			idx++
			kd.copiesLeft--
		}
		if idx == numDiffs {
			return diffs, true, nil
		}
		// then find the next Difference that satisfies |pred|
		match := false
		for !match {
			select {
			case <-timeoutChan:
				// BUGFIX: return only the populated prefix. |diffs| was
				// allocated with length numDiffs, so entries beyond |idx|
				// are nil pointers and must not be handed to callers.
				return diffs[:idx], true, nil
			case <-kd.egCtx.Done():
				return nil, false, kd.eg.Wait()
			case d, more := <-kd.diffChan:
				if !more {
					return diffs[:idx], more, nil
				}
				var err error
				kd.df, kd.copiesLeft, err = convertDiff(d)
				if err != nil {
					return nil, false, err
				}
				match = pred(&kd.df)
			}
		}
	}
}
// GetDiffs returns up to |numDiffs| row diffs, waiting at most |timeout|.
// A negative |timeout| waits indefinitely.
func (kd *keylessDiffer) GetDiffs(numDiffs int, timeout time.Duration) ([]*diff.Difference, bool, error) {
	deadline := forever
	if timeout >= 0 {
		deadline = time.After(timeout)
	}
	return kd.getDiffs(numDiffs, deadline, alwaysTruePredicate)
}
// GetDiffsWithFilter returns up to |numDiffs| row diffs of change type
// |filterByChangeType|, waiting at most |timeout|. A negative |timeout| waits
// indefinitely.
func (kd *keylessDiffer) GetDiffsWithFilter(numDiffs int, timeout time.Duration, filterByChangeType types.DiffChangeType) ([]*diff.Difference, bool, error) {
	deadline := forever
	if timeout >= 0 {
		deadline = time.After(timeout)
	}
	return kd.getDiffs(numDiffs, deadline, hasChangeTypePredicate(filterByChangeType))
}
// convertDiff reports the cardinality of a change (how many identical keyless
// rows it represents) and converts updates to adds or deletes: for a keyless
// table a "modification" is really a change in row count, so a positive
// cardinality delta becomes an add and a negative one a delete.
//
// Returns the (possibly rewritten) Difference, the number of rows it stands
// for, and an error for malformed input. A modification with a zero delta is
// reported as an error rather than a panic, matching how callers (and the
// sibling reportNomsKeylessChanges) already handle error returns.
func convertDiff(df diff.Difference) (diff.Difference, uint64, error) {
	var oldCard uint64
	if df.OldValue != nil {
		// Cardinality lives at a fixed index in the value tuple.
		v, err := df.OldValue.(types.Tuple).Get(row.KeylessCardinalityValIdx)
		if err != nil {
			return df, 0, err
		}
		oldCard = uint64(v.(types.Uint))
	}
	var newCard uint64
	if df.NewValue != nil {
		v, err := df.NewValue.(types.Tuple).Get(row.KeylessCardinalityValIdx)
		if err != nil {
			return df, 0, err
		}
		newCard = uint64(v.(types.Uint))
	}
	switch df.ChangeType {
	case types.DiffChangeRemoved:
		return df, oldCard, nil
	case types.DiffChangeAdded:
		return df, newCard, nil
	case types.DiffChangeModified:
		delta := int64(newCard) - int64(oldCard)
		if delta > 0 {
			df.ChangeType = types.DiffChangeAdded
			df.OldValue = nil
			return df, uint64(delta), nil
		} else if delta < 0 {
			df.ChangeType = types.DiffChangeRemoved
			df.NewValue = nil
			return df, uint64(-delta), nil
		}
		// Should be impossible: a modification implies the cardinality moved.
		// Surface as an error instead of panicking in library code.
		return df, 0, fmt.Errorf("diff with delta = 0 for key: %s", df.KeyValue.HumanReadableString())
	default:
		return df, 0, fmt.Errorf("unexpected DiffChange type %d", df.ChangeType)
	}
}
// EmptyRowDiffer is a no-op RowDiffer that yields no differences.
type EmptyRowDiffer struct {
}
var _ RowDiffer = &EmptyRowDiffer{}
// Start is a no-op.
func (e EmptyRowDiffer) Start(ctx context.Context, from, to types.Map) {
}
// StartWithRange is a no-op.
func (e EmptyRowDiffer) StartWithRange(ctx context.Context, from, to types.Map, start types.Value, inRange types.ValueInRange) {
}
// GetDiffs always reports no diffs, no more results, and no error.
func (e EmptyRowDiffer) GetDiffs(numDiffs int, timeout time.Duration) ([]*diff.Difference, bool, error) {
	return nil, false, nil
}
// GetDiffsWithFilter always reports no diffs, no more results, and no error.
func (e EmptyRowDiffer) GetDiffsWithFilter(numDiffs int, timeout time.Duration, filterByChangeType types.DiffChangeType) ([]*diff.Difference, bool, error) {
	return nil, false, nil
}
// Close is a no-op.
func (e EmptyRowDiffer) Close() error {
	return nil
}
@@ -1,345 +0,0 @@
// Copyright 2021 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package diff
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/store/chunks"
"github.com/dolthub/dolt/go/store/constants"
"github.com/dolthub/dolt/go/store/types"
)
// TestAsyncDiffer exercises AsyncDiffer over in-memory noms maps: full
// iteration, ranged iteration, early Close, change-type filtering, and
// keyless-table (cardinality-based) diffing via keylessDiffer.
func TestAsyncDiffer(t *testing.T) {
	ctx := context.Background()
	storage := &chunks.MemoryStorage{}
	vrw := types.NewValueStore(storage.NewView())
	// Base map; the second map below adds, removes, and modifies entries.
	vals := []types.Value{
		types.Uint(0), types.String("a"),
		types.Uint(1), types.String("b"),
		types.Uint(3), types.String("d"),
		types.Uint(4), types.String("e"),
		types.Uint(6), types.String("g"),
		types.Uint(7), types.String("h"),
		types.Uint(9), types.String("j"),
		types.Uint(10), types.String("k"),
		types.Uint(12), types.String("m"),
		types.Uint(13), types.String("n"),
		types.Uint(15), types.String("p"),
		types.Uint(16), types.String("q"),
		types.Uint(18), types.String("s"),
		types.Uint(19), types.String("t"),
		types.Uint(21), types.String("v"),
		types.Uint(22), types.String("w"),
		types.Uint(24), types.String("y"),
		types.Uint(25), types.String("z"),
	}
	m1, err := types.NewMap(ctx, vrw, vals...)
	require.NoError(t, err)
	// Edited map: totals are 8 adds, 9 deletes, 5 changes (per the comments).
	vals = []types.Value{
		types.Uint(0), types.String("a"), // unchanged
		//types.Uint(1), types.String("b"), // deleted
		types.Uint(2), types.String("c"), // added
		types.Uint(3), types.String("d"), // unchanged
		//types.Uint(4), types.String("e"), // deleted
		types.Uint(5), types.String("f"), // added
		types.Uint(6), types.String("g"), // unchanged
		//types.Uint(7), types.String("h"), // deleted
		types.Uint(8), types.String("i"), // added
		types.Uint(9), types.String("j"), // unchanged
		//types.Uint(10), types.String("k"), // deleted
		types.Uint(11), types.String("l"), // added
		types.Uint(12), types.String("m2"), // changed
		//types.Uint(13), types.String("n"), // deleted
		types.Uint(14), types.String("o"), // added
		types.Uint(15), types.String("p2"), // changed
		//types.Uint(16), types.String("q"), // deleted
		types.Uint(17), types.String("r"), // added
		types.Uint(18), types.String("s2"), // changed
		//types.Uint(19), types.String("t"), // deleted
		types.Uint(20), types.String("u"), // added
		types.Uint(21), types.String("v2"), // changed
		//types.Uint(22), types.String("w"), // deleted
		types.Uint(23), types.String("x"), // added
		types.Uint(24), types.String("y2"), // changed
		//types.Uint(25), types.String("z"), // deleted
	}
	m2, err := types.NewMap(ctx, vrw, vals...)
	require.NoError(t, err)
	tests := []struct {
		name string
		createdStarted func(ctx context.Context, m1, m2 types.Map) *AsyncDiffer
		expectedStats map[types.DiffChangeType]uint64
	}{
		{
			name: "iter all",
			createdStarted: func(ctx context.Context, m1, m2 types.Map) *AsyncDiffer {
				ad := NewAsyncDiffer(4)
				ad.Start(ctx, m1, m2)
				return ad
			},
			expectedStats: map[types.DiffChangeType]uint64{
				types.DiffChangeModified: 5,
				types.DiffChangeAdded: 8,
				types.DiffChangeRemoved: 9,
			},
		},
		{
			name: "iter range starting with nil",
			createdStarted: func(ctx context.Context, m1, m2 types.Map) *AsyncDiffer {
				ad := NewAsyncDiffer(4)
				ad.StartWithRange(ctx, m1, m2, nil, func(ctx context.Context, value types.Value) (bool, bool, error) {
					return true, false, nil
				})
				return ad
			},
			expectedStats: map[types.DiffChangeType]uint64{
				types.DiffChangeModified: 5,
				types.DiffChangeAdded: 8,
				types.DiffChangeRemoved: 9,
			},
		},
		{
			name: "iter range staring with Null Value",
			createdStarted: func(ctx context.Context, m1, m2 types.Map) *AsyncDiffer {
				ad := NewAsyncDiffer(4)
				ad.StartWithRange(ctx, m1, m2, types.NullValue, func(ctx context.Context, value types.Value) (bool, bool, error) {
					return true, false, nil
				})
				return ad
			},
			expectedStats: map[types.DiffChangeType]uint64{
				types.DiffChangeModified: 5,
				types.DiffChangeAdded: 8,
				types.DiffChangeRemoved: 9,
			},
		},
		{
			// NOTE(review): the name says "less than 17" but |end| is 27, which
			// spans the entire key range — the expected stats match end=27.
			name: "iter range less than 17",
			createdStarted: func(ctx context.Context, m1, m2 types.Map) *AsyncDiffer {
				ad := NewAsyncDiffer(4)
				end := types.Uint(27)
				ad.StartWithRange(ctx, m1, m2, types.NullValue, func(ctx context.Context, value types.Value) (bool, bool, error) {
					valid, err := value.Less(ctx, vrw.Format(), end)
					return valid, false, err
				})
				return ad
			},
			expectedStats: map[types.DiffChangeType]uint64{
				types.DiffChangeModified: 5,
				types.DiffChangeAdded: 8,
				types.DiffChangeRemoved: 9,
			},
		},
		{
			name: "iter range less than 15",
			createdStarted: func(ctx context.Context, m1, m2 types.Map) *AsyncDiffer {
				ad := NewAsyncDiffer(4)
				end := types.Uint(15)
				ad.StartWithRange(ctx, m1, m2, types.NullValue, func(ctx context.Context, value types.Value) (bool, bool, error) {
					valid, err := value.Less(ctx, vrw.Format(), end)
					return valid, false, err
				})
				return ad
			},
			expectedStats: map[types.DiffChangeType]uint64{
				types.DiffChangeModified: 1,
				types.DiffChangeAdded: 5,
				types.DiffChangeRemoved: 5,
			},
		},
		{
			name: "iter range 10 < 15",
			createdStarted: func(ctx context.Context, m1, m2 types.Map) *AsyncDiffer {
				ad := NewAsyncDiffer(4)
				start := types.Uint(10)
				end := types.Uint(15)
				ad.StartWithRange(ctx, m1, m2, start, func(ctx context.Context, value types.Value) (bool, bool, error) {
					valid, err := value.Less(ctx, vrw.Format(), end)
					return valid, false, err
				})
				return ad
			},
			expectedStats: map[types.DiffChangeType]uint64{
				types.DiffChangeModified: 1,
				types.DiffChangeAdded: 2,
				types.DiffChangeRemoved: 2,
			},
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			ctx := context.Background()
			ad := test.createdStarted(ctx, m1, m2)
			err := readAll(ad)
			require.NoError(t, err)
			require.Equal(t, test.expectedStats, ad.diffStats)
		})
	}
	t.Run("can close without reading all", func(t *testing.T) {
		ad := NewAsyncDiffer(1)
		ad.Start(ctx, m1, m2)
		res, more, err := ad.GetDiffs(1, -1)
		require.NoError(t, err)
		assert.True(t, more)
		assert.Len(t, res, 1)
		err = ad.Close()
		assert.NoError(t, err)
	})
	t.Run("can filter based on change type", func(t *testing.T) {
		ad := NewAsyncDiffer(20)
		ad.Start(ctx, m1, m2)
		res, more, err := ad.GetDiffs(10, -1)
		require.NoError(t, err)
		assert.True(t, more)
		assert.Len(t, res, 10)
		err = ad.Close()
		assert.NoError(t, err)
		// Only 5 modifications exist, so the filtered fetch drains the stream.
		ad = NewAsyncDiffer(20)
		ad.Start(ctx, m1, m2)
		res, more, err = ad.GetDiffsWithFilter(10, 20*time.Second, types.DiffChangeModified)
		require.NoError(t, err)
		assert.False(t, more)
		assert.Len(t, res, 5)
		err = ad.Close()
		assert.NoError(t, err)
		ad = NewAsyncDiffer(20)
		ad.Start(ctx, m1, m2)
		res, more, err = ad.GetDiffsWithFilter(6, -1, types.DiffChangeAdded)
		require.NoError(t, err)
		assert.True(t, more)
		assert.Len(t, res, 6)
		err = ad.Close()
		assert.NoError(t, err)
	})
	// Keyless fixtures: one row in k1; k2 deletes it and adds two rows.
	k1Row1Vals := []types.Value{c1Tag, types.Uint(3), c2Tag, types.String("d")}
	k1Vals, err := getKeylessRow(ctx, k1Row1Vals)
	assert.NoError(t, err)
	k1, err := types.NewMap(ctx, vrw, k1Vals...)
	assert.NoError(t, err)
	// Delete one row, add two rows
	k2Row1Vals := []types.Value{c1Tag, types.Uint(4), c2Tag, types.String("d")}
	k2Vals1, err := getKeylessRow(ctx, k2Row1Vals)
	assert.NoError(t, err)
	k2Row2Vals := []types.Value{c1Tag, types.Uint(1), c2Tag, types.String("e")}
	k2Vals2, err := getKeylessRow(ctx, k2Row2Vals)
	assert.NoError(t, err)
	k2Vals := append(k2Vals1, k2Vals2...)
	k2, err := types.NewMap(ctx, vrw, k2Vals...)
	require.NoError(t, err)
	t.Run("can diff and filter keyless tables", func(t *testing.T) {
		kd := &keylessDiffer{AsyncDiffer: NewAsyncDiffer(20)}
		kd.Start(ctx, k1, k2)
		res, more, err := kd.GetDiffs(10, 20*time.Second)
		require.NoError(t, err)
		assert.False(t, more)
		assert.Len(t, res, 3)
		err = kd.Close()
		assert.NoError(t, err)
		kd = &keylessDiffer{AsyncDiffer: NewAsyncDiffer(20)}
		kd.Start(ctx, k1, k2)
		res, more, err = kd.GetDiffsWithFilter(10, 20*time.Second, types.DiffChangeModified)
		require.NoError(t, err)
		assert.False(t, more)
		assert.Len(t, res, 0)
		err = kd.Close()
		assert.NoError(t, err)
		kd = &keylessDiffer{AsyncDiffer: NewAsyncDiffer(20)}
		kd.Start(ctx, k1, k2)
		res, more, err = kd.GetDiffsWithFilter(6, -1, types.DiffChangeAdded)
		require.NoError(t, err)
		assert.False(t, more)
		assert.Len(t, res, 2)
		err = kd.Close()
		assert.NoError(t, err)
	})
}
// readAll drains |ad| to completion, discarding the diffs, and returns the
// first error encountered.
func readAll(ad *AsyncDiffer) error {
	more := true
	for more {
		var err error
		_, more, err = ad.GetDiffs(10, -1)
		if err != nil {
			return err
		}
	}
	return nil
}
// Column and schema tags used to construct keyless test rows.
var c1Tag = types.Uint(1)
var c2Tag = types.Uint(2)
var cardTag = types.Uint(schema.KeylessRowCardinalityTag)
var rowIdTag = types.Uint(schema.KeylessRowIdTag)
// getKeylessRow builds the (key, value) tuple pair for a keyless row holding
// |vals| with an initial cardinality of 1. The key is a hash-derived row id;
// the value tuple is prefixed with the cardinality tag and count.
func getKeylessRow(ctx context.Context, vals []types.Value) ([]types.Value, error) {
	nbf, err := types.GetFormatForVersionString(constants.FormatDefaultString)
	if err != nil {
		return []types.Value{}, err
	}
	// The row id is hashed from the raw column values, before prefixing.
	rowId, err := types.UUIDHashedFromValues(nbf, vals...)
	if err != nil {
		return []types.Value{}, err
	}
	valTuple := append([]types.Value{cardTag, types.Uint(1)}, vals...)
	return []types.Value{
		mustTuple(rowIdTag, rowId),
		mustTuple(valTuple...),
	}, nil
}
// mustTuple builds a tuple in the default format, panicking on failure; for
// test fixtures only.
func mustTuple(vals ...types.Value) types.Tuple {
	tup, err := types.NewTuple(types.Format_Default, vals...)
	if err == nil {
		return tup
	}
	panic(err)
}
-99
View File
@@ -19,13 +19,10 @@ import (
"errors"
"fmt"
"io"
"time"
"github.com/dolthub/dolt/go/cmd/dolt/errhand"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
"github.com/dolthub/dolt/go/libraries/doltcore/row"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/store/diff"
"github.com/dolthub/dolt/go/store/prolly"
"github.com/dolthub/dolt/go/store/prolly/tree"
"github.com/dolthub/dolt/go/store/types"
@@ -39,7 +36,6 @@ type DiffStatProgress struct {
}
type prollyReporter func(ctx context.Context, vMapping val.OrdinalMapping, fromD, toD *val.TupleDesc, change tree.Diff, ch chan<- DiffStatProgress) error
type nomsReporter func(ctx context.Context, change *diff.Difference, fromSch, toSch schema.Schema, ch chan<- DiffStatProgress) error
// Stat reports a stat of diff changes between two values
// todo: make package private once dolthub is migrated
@@ -168,38 +164,6 @@ func diffProllyTrees(ctx context.Context, ch chan DiffStatProgress, keyless bool
return nil
}
// statWithReporter streams noms map differences between |from| and |to|
// through |rpr|, which converts each raw difference into DiffStatProgress
// values sent on |ch|. Blocks until the diff stream is exhausted or an error
// occurs. An error from closing the differ is surfaced through the named
// return only when no earlier error occurred.
func statWithReporter(ctx context.Context, ch chan DiffStatProgress, from, to types.Map, rpr nomsReporter, fromSch, toSch schema.Schema) (err error) {
	ad := NewAsyncDiffer(1024)
	ad.Start(ctx, from, to)
	defer func() {
		// Don't clobber an earlier error with the Close error.
		if cerr := ad.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()
	var more bool
	var diffs []*diff.Difference
	for {
		// Poll in batches of 100 with a 1ms timeout per batch.
		diffs, more, err = ad.GetDiffs(100, time.Millisecond)
		if err != nil {
			return err
		}
		for _, df := range diffs {
			err = rpr(ctx, df, fromSch, toSch, ch)
			if err != nil {
				return err
			}
		}
		if !more {
			break
		}
	}
	return nil
}
func reportPkChanges(ctx context.Context, vMapping val.OrdinalMapping, fromD, toD *val.TupleDesc, change tree.Diff, ch chan<- DiffStatProgress) error {
var stat DiffStatProgress
switch change.Type {
@@ -280,66 +244,3 @@ func prollyCountCellDiff(ctx context.Context, mapping val.OrdinalMapping, fromD,
changed += newCols
return changed
}
// reportNomsPkChanges converts one primary-key row difference into a
// DiffStatProgress and sends it on |ch|, honoring context cancellation.
func reportNomsPkChanges(ctx context.Context, change *diff.Difference, fromSch, toSch schema.Schema, ch chan<- DiffStatProgress) error {
	var stat DiffStatProgress
	switch change.ChangeType {
	case types.DiffChangeAdded:
		stat.Adds = 1
	case types.DiffChangeRemoved:
		stat.Removes = 1
	case types.DiffChangeModified:
		// Count how many individual cells differ between the two row tuples.
		cellChanges, err := row.CountCellDiffs(change.OldValue.(types.Tuple), change.NewValue.(types.Tuple), fromSch, toSch)
		if err != nil {
			return err
		}
		stat.Changes = 1
		stat.CellChanges = cellChanges
	default:
		return errors.New("unknown change type")
	}
	select {
	case ch <- stat:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
// reportNomsKeylessChanges converts a keyless row difference into a
// DiffStatProgress and sends it on |ch|, honoring context cancellation. For
// keyless tables a "change" is a change in row cardinality: a positive delta
// is reported as adds, a negative delta as removes, and a zero delta is
// malformed and returns an error.
func reportNomsKeylessChanges(ctx context.Context, change *diff.Difference, fromSch, toSch schema.Schema, ch chan<- DiffStatProgress) error {
	var oldCard uint64
	if change.OldValue != nil {
		// Cardinality is stored at a fixed index in the value tuple.
		v, err := change.OldValue.(types.Tuple).Get(row.KeylessCardinalityValIdx)
		if err != nil {
			return err
		}
		oldCard = uint64(v.(types.Uint))
	}
	var newCard uint64
	if change.NewValue != nil {
		v, err := change.NewValue.(types.Tuple).Get(row.KeylessCardinalityValIdx)
		if err != nil {
			return err
		}
		newCard = uint64(v.(types.Uint))
	}
	var stat DiffStatProgress
	delta := int64(newCard) - int64(oldCard)
	if delta > 0 {
		stat = DiffStatProgress{Adds: uint64(delta)}
	} else if delta < 0 {
		stat = DiffStatProgress{Removes: uint64(-delta)}
	} else {
		return fmt.Errorf("diff with delta = 0 for key: %s", change.KeyValue.HumanReadableString())
	}
	select {
	case ch <- stat:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
@@ -118,31 +118,6 @@ func mapQuerySchemaToTargetSchema(query, target sql.Schema) (mapping []int, err
return
}
// mapToAndFromColumns maps each "from_X" column index to its "to_X"
// counterpart and vice versa. The final column must be named "diff_type" and
// is excluded from the mapping. Unmatched columns map to -1.
func mapToAndFromColumns(query sql.Schema) (mapping []int, err error) {
	if query[len(query)-1].Name != "diff_type" {
		return nil, errors.New("expected last diff column to be 'diff_type'")
	}
	cols := query[:len(query)-1]
	mapping = make([]int, len(cols))
	for i, col := range cols {
		switch {
		case strings.HasPrefix(col.Name, fromPrefix):
			// map "from_..." column to "to_..." column
			mapping[i] = cols.IndexOfColName(toPrefix + col.Name[len(fromPrefix):])
		case strings.HasPrefix(col.Name, toPrefix):
			// map "to_..." column to "from_..." column
			mapping[i] = cols.IndexOfColName(fromPrefix + col.Name[len(toPrefix):])
		default:
			return nil, errors.New("expected column prefix of 'to_' or 'from_' (" + col.Name + ")")
		}
	}
	// |mapping| will contain -1 for unmapped columns
	return mapping, nil
}
func (ds DiffSplitter) SplitDiffResultRow(ctx *sql.Context, row sql.Row) (from, to RowDiff, err error) {
from = RowDiff{ColDiffs: make([]ChangeType, len(ds.targetSch))}
to = RowDiff{ColDiffs: make([]ChangeType, len(ds.targetSch))}
-9
View File
@@ -81,15 +81,6 @@ func NewCommit(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeStore
return &Commit{vrw, ns, parents, commit}, nil
}
// NewCommitFromValue generates a new Commit object that wraps a supplied types.Value.
func NewCommitFromValue(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeStore, value types.Value) (*Commit, error) {
commit, err := datas.CommitFromValue(vrw.Format(), value)
if err != nil {
return nil, err
}
return NewCommit(ctx, vrw, ns, commit)
}
// HashOf returns the hash of the commit
func (c *Commit) HashOf() (hash.Hash, error) {
return c.dCommit.Addr(), nil
-33
View File
@@ -27,7 +27,6 @@ var ErrInvTableName = errors.New("not a valid table name")
var ErrInvHash = errors.New("not a valid hash")
var ErrInvalidAncestorSpec = errors.New("invalid ancestor spec")
var ErrInvalidBranchOrHash = errors.New("string is not a valid branch or hash")
var ErrInvalidHash = errors.New("string is not a valid hash")
var ErrFoundHashNotACommit = errors.New("the value retrieved for this hash is not a commit")
var ErrHashNotFound = errors.New("could not find a value for this hash")
@@ -39,7 +38,6 @@ var ErrWorkspaceNotFound = errors.New("workspace not found")
var ErrTableNotFound = errors.New("table not found")
var ErrTableExists = errors.New("table already exists")
var ErrAlreadyOnBranch = errors.New("Already on branch")
var ErrAlreadyOnWorkspace = errors.New("Already on workspace")
var ErrNomsIO = errors.New("error reading from or writing to noms")
@@ -109,37 +107,6 @@ func (rt RootType) String() string {
return "unknown"
}
// RootTypeSet is a set of RootType values.
type RootTypeSet map[RootType]struct{}

// NewRootTypeSet returns a set containing each of |rts|.
func NewRootTypeSet(rts ...RootType) RootTypeSet {
	set := make(RootTypeSet, len(rts))
	for _, rt := range rts {
		set[rt] = struct{}{}
	}
	return set
}

// Contains reports whether |rt| is a member of the set.
func (rts RootTypeSet) Contains(rt RootType) bool {
	_, ok := rts[rt]
	return ok
}

// First returns the first element of |rtList| that is in the set, or
// InvalidRoot when none are.
func (rts RootTypeSet) First(rtList []RootType) RootType {
	for _, rt := range rtList {
		if rts.Contains(rt) {
			return rt
		}
	}
	return InvalidRoot
}

// IsEmpty reports whether the set has no members.
func (rts RootTypeSet) IsEmpty() bool {
	return len(rts) == 0
}
type RootValueUnreadable struct {
RootType RootType
Cause error
@@ -51,7 +51,8 @@ func getNonlocalTablesRef(_ context.Context, valDesc *val.TupleDesc, valTuple va
return result
}
func GetGlobalTablePatterns(ctx context.Context, root RootValue, schema string, cb func(string)) error {
// GetNonlocalTablePatterns invokes |cb| once for each table name pattern in dolt_nonlocal_tables on |root| and |schema|.
func GetNonlocalTablePatterns(ctx context.Context, root RootValue, schema string, cb func(string)) error {
table_name := TableName{Name: NonlocalTableName, Schema: schema}
table, found, err := root.GetTable(ctx, table_name)
if err != nil {
@@ -472,6 +472,10 @@ func encodeTableNameForSerialization(name TableName) string {
// decodeTableNameFromSerialization decodes a table name from a serialized string. See notes on serialization in
// |encodeTableNameForSerialization|
func decodeTableNameFromSerialization(encodedName string) (TableName, bool) {
if len(encodedName) == 0 {
return TableName{}, false
}
if encodedName[0] != 0 {
return TableName{Name: encodedName}, true
} else if len(encodedName) >= 4 { // 2 null bytes plus at least one char for schema and table name
@@ -43,6 +43,35 @@ func MatchTablePattern(pattern string, table string) (bool, error) {
return re.MatchString(table), nil
}
// CompiledTablePatterns holds compiled table name patterns for reuse when
// matching many names without recompiling.
type CompiledTablePatterns []*regexp.Regexp

// CompileTablePatterns compiles each of |patterns| once and returns them for
// use with TableMatchesAny. Returns (nil, nil) when |patterns| is empty.
func CompileTablePatterns(patterns []string) (CompiledTablePatterns, error) {
	if len(patterns) == 0 {
		return nil, nil
	}
	compiled := make(CompiledTablePatterns, len(patterns))
	for i, p := range patterns {
		re, err := compilePattern(p)
		if err != nil {
			return nil, err
		}
		compiled[i] = re
	}
	return compiled, nil
}

// TableMatchesAny reports whether |table| matches any of the patterns in |c|.
func (c CompiledTablePatterns) TableMatchesAny(table string) bool {
	for _, re := range c {
		if re.MatchString(table) {
			return true
		}
	}
	return false
}
// GetMatchingTables returns all tables that match a pattern
func GetMatchingTables(ctx *sql.Context, root RootValue, schemaName string, pattern string) (results []string, err error) {
// If the pattern doesn't contain any special characters, look up that name.
+10 -2
View File
@@ -75,6 +75,8 @@ type RebaseState struct {
// rebasingStarted is true once the rebase plan has been started to execute. Once rebasingStarted is true, the
// value in lastAttemptedStep has been initialized and is valid to read.
rebasingStarted bool
// skipVerification indicates whether test validation should be skipped during rebase operations.
skipVerification bool
}
// Branch returns the name of the branch being actively rebased. This is the branch that will be updated to point
@@ -120,6 +122,10 @@ func (rs RebaseState) WithRebasingStarted(rebasingStarted bool) *RebaseState {
return &rs
}
func (rs RebaseState) SkipVerification() bool {
return rs.skipVerification
}
type MergeState struct {
// the source commit
commit *Commit
@@ -322,13 +328,14 @@ func (ws WorkingSet) StartMerge(commit *Commit, commitSpecStr string) *WorkingSe
// the branch that is being rebased, and |previousRoot| is root value of the branch being rebased. The HEAD and STAGED
// root values of the branch being rebased must match |previousRoot|; WORKING may be a different root value, but ONLY
// if it contains only ignored tables.
func (ws WorkingSet) StartRebase(ctx *sql.Context, ontoCommit *Commit, branch string, previousRoot RootValue, commitBecomesEmptyHandling EmptyCommitHandling, emptyCommitHandling EmptyCommitHandling) (*WorkingSet, error) {
func (ws WorkingSet) StartRebase(ctx *sql.Context, ontoCommit *Commit, branch string, previousRoot RootValue, commitBecomesEmptyHandling EmptyCommitHandling, emptyCommitHandling EmptyCommitHandling, skipVerification bool) (*WorkingSet, error) {
ws.rebaseState = &RebaseState{
ontoCommit: ontoCommit,
preRebaseWorking: previousRoot,
branch: branch,
commitBecomesEmptyHandling: commitBecomesEmptyHandling,
emptyCommitHandling: emptyCommitHandling,
skipVerification: skipVerification,
}
ontoRoot, err := ontoCommit.GetRootValue(ctx)
@@ -549,6 +556,7 @@ func newWorkingSet(ctx context.Context, name string, vrw types.ValueReadWriter,
emptyCommitHandling: EmptyCommitHandling(dsws.RebaseState.EmptyCommitHandling(ctx)),
lastAttemptedStep: dsws.RebaseState.LastAttemptedStep(ctx),
rebasingStarted: dsws.RebaseState.RebasingStarted(ctx),
skipVerification: dsws.RebaseState.SkipVerification(ctx),
}
}
@@ -646,7 +654,7 @@ func (ws *WorkingSet) writeValues(ctx context.Context, db *DoltDB, meta *datas.W
rebaseState = datas.NewRebaseState(preRebaseWorking.TargetHash(), dCommit.Addr(), ws.rebaseState.branch,
uint8(ws.rebaseState.commitBecomesEmptyHandling), uint8(ws.rebaseState.emptyCommitHandling),
ws.rebaseState.lastAttemptedStep, ws.rebaseState.rebasingStarted)
ws.rebaseState.lastAttemptedStep, ws.rebaseState.rebasingStarted, ws.rebaseState.skipVerification)
}
return &datas.WorkingSetSpec{
+1 -1
View File
@@ -216,7 +216,7 @@ func CleanOldWorkingSet(
}
// we also have to do a clean, because we the ResetHard won't touch any new tables (tables only in the working set)
newRoots, err := CleanUntracked(ctx, resetRoots, []string{}, false, true)
newRoots, err := CleanUntracked(ctx, resetRoots, []string{}, false, true, false)
if err != nil {
return err
}
+108 -8
View File
@@ -15,8 +15,12 @@
package actions
import (
"fmt"
"io"
"strings"
"time"
gms "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/libraries/doltcore/diff"
@@ -25,14 +29,42 @@ import (
)
type CommitStagedProps struct {
Message string
Date time.Time
AllowEmpty bool
SkipEmpty bool
Amend bool
Force bool
Name string
Email string
Message string
Date time.Time
AllowEmpty bool
SkipEmpty bool
Amend bool
Force bool
Name string
Email string
SkipVerification bool
}
const (
// System variable name, defined here to avoid circular imports
DoltCommitVerificationGroups = "dolt_commit_verification_groups"
)
// GetCommitRunTestGroups returns the test groups to run for commit operations
// Returns empty slice if no tests should be run, ["*"] if all tests should be run,
// or specific group names if only those groups should be run
// GetCommitRunTestGroups returns the test groups to run for commit operations.
// Returns nil if no tests should be run, ["*"] if all tests should be run, or
// the specific (whitespace-trimmed) group names if only those should be run.
func GetCommitRunTestGroups() []string {
	_, raw, ok := sql.SystemVariables.GetGlobal(DoltCommitVerificationGroups)
	if !ok {
		return nil
	}
	s, isString := raw.(string)
	if !isString || s == "" {
		return nil
	}
	if s == "*" {
		return []string{"*"}
	}
	// Comma-separated list of group names; trim whitespace around each.
	groups := strings.Split(s, ",")
	for i := range groups {
		groups[i] = strings.TrimSpace(groups[i])
	}
	return groups
}
// GetCommitStaged returns a new pending commit with the roots and commit properties given.
@@ -114,6 +146,16 @@ func GetCommitStaged(
}
}
if !props.SkipVerification {
testGroups := GetCommitRunTestGroups()
if len(testGroups) > 0 {
err := runCommitVerification(ctx, testGroups)
if err != nil {
return nil, err
}
}
}
meta, err := datas.NewCommitMetaWithUserTS(props.Name, props.Email, props.Message, props.Date)
if err != nil {
return nil, err
@@ -121,3 +163,61 @@ func GetCommitStaged(
return db.NewPendingCommit(ctx, roots, mergeParents, props.Amend, meta)
}
// runCommitVerification runs the configured verification test groups using a
// SQL engine built from the current session's database provider. Returns an
// error when the session cannot supply a provider or any test fails.
func runCommitVerification(ctx *sql.Context, testGroups []string) error {
	// Narrow, locally-declared interface: we only need a session that can
	// hand us its database provider.
	type sessionInterface interface {
		sql.Session
		GenericProvider() sql.MutableDatabaseProvider
	}
	session, ok := ctx.Session.(sessionInterface)
	if !ok {
		return fmt.Errorf("session does not provide database provider interface")
	}
	provider := session.GenericProvider()
	engine := gms.NewDefault(provider)
	return runTestsUsingDtablefunctions(ctx, engine, testGroups)
}
// runTestsUsingDtablefunctions runs tests using the dtablefunctions package against the staged root
func runTestsUsingDtablefunctions(ctx *sql.Context, engine *gms.Engine, testGroups []string) error {
if len(testGroups) == 0 {
return nil
}
var allFailures []string
for _, group := range testGroups {
query := fmt.Sprintf("SELECT * FROM dolt_test_run('%s')", group)
_, iter, _, err := engine.Query(ctx, query)
if err != nil {
return fmt.Errorf("failed to run dolt_test_run for group %s: %w", group, err)
}
for {
row, rErr := iter.Next(ctx)
if rErr == io.EOF {
break
}
if rErr != nil {
return fmt.Errorf("error reading test results: %w", rErr)
}
// Extract status (column 3)
status := fmt.Sprintf("%v", row[3])
if status != "PASS" {
testName := fmt.Sprintf("%v", row[0])
message := fmt.Sprintf("%v", row[4])
allFailures = append(allFailures, fmt.Sprintf("%s (%s)", testName, message))
}
}
}
if len(allFailures) > 0 {
return fmt.Errorf("commit verification failed: %s", strings.Join(allFailures, ", "))
}
return nil
}
+42 -28
View File
@@ -270,60 +270,74 @@ func getUnionedTables(ctx context.Context, tables []doltdb.TableName, stagedRoot
return tables, nil
}
// CleanUntracked deletes untracked tables from the working root.
// Evaluates untracked tables as: all working tables - all staged tables.
func CleanUntracked(ctx *sql.Context, roots doltdb.Roots, tables []string, dryrun bool, force bool) (doltdb.Roots, error) {
// CleanUntracked deletes from the working root the tables that are untracked (in working but not in staged/head). If
// |tables| is non-empty it uses only those names as candidates; otherwise it uses all working tables. Tables matching
// dolt_nonlocal_tables are always excluded. When |respectIgnoreRules| is true, tables matching dolt_ignore are also excluded. Does nothing when |dryrun| is true.
func CleanUntracked(ctx *sql.Context, roots doltdb.Roots, tables []string, dryrun bool, force bool, respectIgnoreRules bool) (doltdb.Roots, error) {
untrackedTables := make(map[doltdb.TableName]struct{})
for _, name := range tables {
resolvedName, tblExists, err := resolve.TableName(ctx, roots.Working, name)
if err != nil {
return doltdb.Roots{}, err
}
if !tblExists {
return doltdb.Roots{}, fmt.Errorf("%w: '%s'", doltdb.ErrTableNotFound, name)
}
untrackedTables[resolvedName] = struct{}{}
}
var err error
if len(tables) == 0 {
allTableNames, err := roots.Working.GetAllTableNames(ctx, true)
if err != nil {
return doltdb.Roots{}, nil
return doltdb.Roots{}, err
}
for _, tableName := range allTableNames {
untrackedTables[tableName] = struct{}{}
}
} else {
for i := range tables {
name := tables[i]
resolvedName, tblExists, err := resolve.TableName(ctx, roots.Working, name)
var candidates []doltdb.TableName
if respectIgnoreRules {
candidates, err = doltdb.ExcludeIgnoredTables(ctx, roots, allTableNames)
if err != nil {
return doltdb.Roots{}, err
}
if !tblExists {
return doltdb.Roots{}, fmt.Errorf("%w: '%s'", doltdb.ErrTableNotFound, name)
} else {
candidates = allTableNames
}
var nonlocalPatterns []string
err = doltdb.GetNonlocalTablePatterns(ctx, roots.Working, doltdb.DefaultSchemaName, func(p string) {
nonlocalPatterns = append(nonlocalPatterns, p)
})
if err != nil {
return doltdb.Roots{}, err
}
compiled, err := doltdb.CompileTablePatterns(nonlocalPatterns)
if err != nil {
return doltdb.Roots{}, err
}
for _, tableName := range candidates {
if compiled.TableMatchesAny(tableName.Name) {
continue
}
untrackedTables[resolvedName] = struct{}{}
untrackedTables[tableName] = struct{}{}
}
}
// untracked tables = working tables - staged tables
headTblNames := GetAllTableNames(ctx, roots.Staged)
if err != nil {
return doltdb.Roots{}, err
}
for _, name := range headTblNames {
delete(untrackedTables, name)
}
newRoot := roots.Working
var toDelete []doltdb.TableName
toDelete := make([]doltdb.TableName, 0, len(untrackedTables))
for t := range untrackedTables {
toDelete = append(toDelete, t)
}
newRoot, err = newRoot.RemoveTables(ctx, force, force, toDelete...)
if err != nil {
return doltdb.Roots{}, fmt.Errorf("failed to remove tables; %w", err)
}
if dryrun {
return roots, nil
}
roots.Working = newRoot
newRoot, err := roots.Working.RemoveTables(ctx, force, force, toDelete...)
if err != nil {
return doltdb.Roots{}, fmt.Errorf("failed to remove tables; %w", err)
}
roots.Working = newRoot
return roots, nil
}
-447
View File
@@ -1,447 +0,0 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package actions
import (
"fmt"
"io"
"strconv"
"time"
"github.com/dolthub/go-mysql-server/sql"
"github.com/shopspring/decimal"
"golang.org/x/exp/constraints"
"github.com/dolthub/dolt/go/store/val"
)
// Assertion type identifiers accepted by AssertData. These match the assertion
// strings stored in dolt_tests rows for dolt_test_run.
const (
AssertionExpectedRows = "expected_rows"
AssertionExpectedColumns = "expected_columns"
AssertionExpectedSingleValue = "expected_single_value"
)
// AssertData evaluates a single test assertion against a query result.
// Valid comparisons are: "==", "!=", "<", ">", "<=", and ">=".
// testPassed reports whether the assertion held. message is non-empty only on
// assertion failure and does not halt the overall run; err signals a runtime
// failure and stops dolt_test_run from proceeding.
func AssertData(sqlCtx *sql.Context, assertion string, comparison string, value *string, queryResult sql.RowIter) (testPassed bool, message string, err error) {
	var failMsg string
	var runErr error
	switch assertion {
	case AssertionExpectedRows:
		failMsg, runErr = expectRows(sqlCtx, comparison, value, queryResult)
	case AssertionExpectedColumns:
		failMsg, runErr = expectColumns(sqlCtx, comparison, value, queryResult)
	case AssertionExpectedSingleValue:
		failMsg, runErr = expectSingleValue(sqlCtx, comparison, value, queryResult)
	default:
		return false, fmt.Sprintf("%s is not a valid assertion type", assertion), nil
	}
	if runErr != nil {
		return false, "", runErr
	}
	if failMsg != "" {
		return false, failMsg, nil
	}
	return true, "", nil
}
// expectSingleValue asserts that queryResult contains exactly one row with one
// column, then compares that single cell against the expected value, dispatching
// on the cell's dynamic type. A non-empty message indicates an assertion failure
// (the run continues); err indicates a runtime failure.
// Fix: constant failure strings were wrapped in no-arg fmt.Sprintf calls
// (staticcheck S1039); they are plain literals now — output is unchanged.
func expectSingleValue(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
	row, err := queryResult.Next(sqlCtx)
	if err == io.EOF {
		// Zero rows is an assertion failure, not a runtime error.
		return "expected_single_value expects exactly one cell. Received 0 rows", nil
	} else if err != nil {
		return "", err
	}
	if len(row) != 1 {
		return "expected_single_value expects exactly one cell. Received multiple columns", nil
	}
	_, err = queryResult.Next(sqlCtx)
	if err == nil { // If multiple rows were given, we should error out
		return "expected_single_value expects exactly one cell. Received multiple rows", nil
	} else if err != io.EOF { // "True" error, so we should quit out
		return "", err
	}
	if value == nil { // If we're expecting a null value, we don't need to type switch
		return compareNullValue(comparison, row[0], AssertionExpectedSingleValue), nil
	}
	// Check if the expected value is a boolean string, and if so, coerce the actual value to boolean,
	// with the exception of "0" and "1", which are valid integers and are covered below.
	if *value != "0" && *value != "1" {
		if expectedBool, err := strconv.ParseBool(*value); err == nil {
			actualBool, boolErr := getInterfaceAsBool(row[0])
			if boolErr != nil {
				return fmt.Sprintf("Could not convert value to boolean: %v", boolErr), nil
			}
			return compareBooleans(comparison, expectedBool, actualBool, AssertionExpectedSingleValue), nil
		}
	}
	// Parse the expected string into the same type as the actual cell, then compare.
	switch actualValue := row[0].(type) {
	case int8:
		expectedInt, err := strconv.ParseInt(*value, 10, 64)
		if err != nil {
			return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
		}
		return compareTestAssertion(comparison, int8(expectedInt), actualValue, AssertionExpectedSingleValue), nil
	case int16:
		expectedInt, err := strconv.ParseInt(*value, 10, 64)
		if err != nil {
			return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
		}
		return compareTestAssertion(comparison, int16(expectedInt), actualValue, AssertionExpectedSingleValue), nil
	case int32:
		expectedInt, err := strconv.ParseInt(*value, 10, 64)
		if err != nil {
			return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
		}
		return compareTestAssertion(comparison, int32(expectedInt), actualValue, AssertionExpectedSingleValue), nil
	case int64:
		expectedInt, err := strconv.ParseInt(*value, 10, 64)
		if err != nil {
			return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
		}
		return compareTestAssertion(comparison, expectedInt, actualValue, AssertionExpectedSingleValue), nil
	case int:
		expectedInt, err := strconv.ParseInt(*value, 10, 64)
		if err != nil {
			return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
		}
		return compareTestAssertion(comparison, int(expectedInt), actualValue, AssertionExpectedSingleValue), nil
	case uint8:
		expectedUint, err := strconv.ParseUint(*value, 10, 32)
		if err != nil {
			return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
		}
		return compareTestAssertion(comparison, uint8(expectedUint), actualValue, AssertionExpectedSingleValue), nil
	case uint16:
		expectedUint, err := strconv.ParseUint(*value, 10, 32)
		if err != nil {
			return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
		}
		return compareTestAssertion(comparison, uint16(expectedUint), actualValue, AssertionExpectedSingleValue), nil
	case uint32:
		expectedUint, err := strconv.ParseUint(*value, 10, 32)
		if err != nil {
			return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
		}
		return compareTestAssertion(comparison, uint32(expectedUint), actualValue, AssertionExpectedSingleValue), nil
	case uint64:
		expectedUint, err := strconv.ParseUint(*value, 10, 64)
		if err != nil {
			return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
		}
		return compareTestAssertion(comparison, expectedUint, actualValue, AssertionExpectedSingleValue), nil
	case uint:
		expectedUint, err := strconv.ParseUint(*value, 10, 64)
		if err != nil {
			return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
		}
		return compareTestAssertion(comparison, uint(expectedUint), actualValue, AssertionExpectedSingleValue), nil
	case float64:
		expectedFloat, err := strconv.ParseFloat(*value, 64)
		if err != nil {
			return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil
		}
		return compareTestAssertion(comparison, expectedFloat, actualValue, AssertionExpectedSingleValue), nil
	case float32:
		expectedFloat, err := strconv.ParseFloat(*value, 32)
		if err != nil {
			return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil
		}
		return compareTestAssertion(comparison, float32(expectedFloat), actualValue, AssertionExpectedSingleValue), nil
	case decimal.Decimal:
		expectedDecimal, err := decimal.NewFromString(*value)
		if err != nil {
			return fmt.Sprintf("Could not compare non decimal value '%s', with %s", *value, actualValue), nil
		}
		return compareDecimals(comparison, expectedDecimal, actualValue, AssertionExpectedSingleValue), nil
	case time.Time:
		expectedTime, format, err := parseTestsDate(*value)
		if err != nil {
			return fmt.Sprintf("%s does not appear to be a valid date", *value), nil
		}
		return compareDates(comparison, expectedTime, actualValue, format, AssertionExpectedSingleValue), nil
	case *val.TextStorage, string:
		actualString, err := GetStringColAsString(sqlCtx, actualValue)
		if err != nil {
			return "", err
		}
		return compareTestAssertion(comparison, *value, *actualString, AssertionExpectedSingleValue), nil
	default:
		return fmt.Sprintf("Type %T is not supported. Open an issue at https://github.com/dolthub/dolt/issues to see it added", actualValue), nil
	}
}
// expectRows drains the query result, counting its rows, and compares that
// count against the expected integer parsed from value. A non-empty message
// indicates an assertion failure; err indicates a runtime failure.
func expectRows(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
	if value == nil {
		return "null is not a valid assertion for expected_rows", nil
	}
	want, convErr := strconv.Atoi(*value)
	if convErr != nil {
		return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil
	}
	// Drain the iterator to count every row.
	count := 0
	for {
		_, nextErr := queryResult.Next(sqlCtx)
		if nextErr == io.EOF {
			break
		}
		if nextErr != nil {
			return "", nextErr
		}
		count++
	}
	return compareTestAssertion(comparison, want, count, AssertionExpectedRows), nil
}
// expectColumns compares the number of columns in the query result's first row
// against the expected integer parsed from value. A non-empty message indicates
// an assertion failure; err indicates a runtime failure.
// Fix: the null-value failure message previously said "expected_rows"
// (copy-paste from expectRows); it now correctly names expected_columns.
func expectColumns(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
	if value == nil {
		return "null is not a valid assertion for expected_columns", nil
	}
	expectedColumns, err := strconv.Atoi(*value)
	if err != nil {
		return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil
	}
	// Only the first row is needed to count columns; io.EOF (no rows) is
	// tolerated and yields a nil row, i.e. zero columns.
	row, err := queryResult.Next(sqlCtx)
	if err != nil && err != io.EOF {
		return "", err
	}
	numColumns := len(row)
	return compareTestAssertion(comparison, expectedColumns, numColumns, AssertionExpectedColumns), nil
}
// compareTestAssertion compares an expected and actual value of any ordered
// type (strings, ints, floats) using one of: "==", "!=", "<", ">", "<=", ">=".
// It returns "" when the assertion holds, otherwise a message describing the
// failure (or the invalid comparison operator).
func compareTestAssertion[T constraints.Ordered](comparison string, expectedValue, actualValue T, assertionType string) string {
	var passed bool
	var relation string
	switch comparison {
	case "==":
		passed, relation = actualValue == expectedValue, "equal to"
	case "!=":
		passed, relation = actualValue != expectedValue, "not equal to"
	case "<":
		passed, relation = actualValue < expectedValue, "less than"
	case "<=":
		passed, relation = actualValue <= expectedValue, "less than or equal to"
	case ">":
		passed, relation = actualValue > expectedValue, "greater than"
	case ">=":
		passed, relation = actualValue >= expectedValue, "greater than or equal to"
	default:
		return fmt.Sprintf("%s is not a valid comparison type", comparison)
	}
	if passed {
		return ""
	}
	return fmt.Sprintf("Assertion failed: %s %s %v, got %v", assertionType, relation, expectedValue, actualValue)
}
// parseTestsDate is an internal function that parses the queried string according to allowed time formats for dolt_tests.
// It returns the parsed time, the format that succeeded, and an error if applicable.
func parseTestsDate(value string) (parsedTime time.Time, format string, err error) {
// List of valid formats
formats := []string{
time.DateOnly,
time.DateTime,
time.TimeOnly,
time.RFC3339,
time.RFC1123Z,
}
for _, format := range formats {
if parsedTime, parseErr := time.Parse(format, value); parseErr == nil {
return parsedTime, format, nil
} else {
err = parseErr
}
}
return time.Time{}, "", err
}
// compareDates is a function used for comparing time values.
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
func compareDates(comparison string, expectedValue, realValue time.Time, format string, assertionType string) string {
expectedStr := expectedValue.Format(format)
realStr := realValue.Format(format)
switch comparison {
case "==":
if !expectedValue.Equal(realValue) {
return fmt.Sprintf("Assertion failed: %s equal to %s, got %s", assertionType, expectedStr, realStr)
}
case "!=":
if expectedValue.Equal(realValue) {
return fmt.Sprintf("Assertion failed: %s not equal to %s, got %s", assertionType, expectedStr, realStr)
}
case "<":
if realValue.Equal(expectedValue) || realValue.After(expectedValue) {
return fmt.Sprintf("Assertion failed: %s less than %s, got %s", assertionType, expectedStr, realStr)
}
case "<=":
if realValue.After(expectedValue) {
return fmt.Sprintf("Assertion failed: %s less than or equal to %s, got %s", assertionType, expectedStr, realStr)
}
case ">":
if realValue.Before(expectedValue) || realValue.Equal(expectedValue) {
return fmt.Sprintf("Assertion failed: %s greater than %s, got %s", assertionType, expectedStr, realStr)
}
case ">=":
if realValue.Before(expectedValue) {
return fmt.Sprintf("Assertion failed: %s greater than or equal to %s, got %s", assertionType, expectedStr, realStr)
}
default:
return fmt.Sprintf("%s is not a valid comparison type", comparison)
}
return ""
}
// compareDecimals compares two decimal values using one of: "==", "!=", "<",
// ">", "<=", ">=". It returns "" when the assertion holds, otherwise a message
// describing the failure (or the invalid comparison operator).
func compareDecimals(comparison string, expectedValue, realValue decimal.Decimal, assertionType string) string {
	var passed bool
	var relation string
	switch comparison {
	case "==":
		passed, relation = expectedValue.Equal(realValue), "equal to"
	case "!=":
		passed, relation = !expectedValue.Equal(realValue), "not equal to"
	case "<":
		passed, relation = realValue.LessThan(expectedValue), "less than"
	case "<=":
		passed, relation = realValue.LessThanOrEqual(expectedValue), "less than or equal to"
	case ">":
		passed, relation = realValue.GreaterThan(expectedValue), "greater than"
	case ">=":
		passed, relation = realValue.GreaterThanOrEqual(expectedValue), "greater than or equal to"
	default:
		return fmt.Sprintf("%s is not a valid comparison type", comparison)
	}
	if passed {
		return ""
	}
	return fmt.Sprintf("Assertion failed: %s %s %v, got %v", assertionType, relation, expectedValue, realValue)
}
// getInterfaceAsBool coerces a column value to bool. The query engine may hand
// back a tinyint column as a bool, any int width, or a string, so each is
// mapped here ("1" / 1 means true). Mirrors GetTinyIntColAsBool in
// commands/utils.go, which can't be imported here due to package cycles.
func getInterfaceAsBool(col interface{}) (bool, error) {
	switch value := col.(type) {
	case bool:
		return value, nil
	case string:
		return value == "1", nil
	case int:
		return value == 1, nil
	case int8:
		return value == 1, nil
	case int16:
		return value == 1, nil
	case int32:
		return value == 1, nil
	case int64:
		return value == 1, nil
	case uint:
		return value == 1, nil
	case uint8:
		return value == 1, nil
	case uint16:
		return value == 1, nil
	case uint32:
		return value == 1, nil
	case uint64:
		return value == 1, nil
	default:
		return false, fmt.Errorf("unexpected type %T, was expecting bool, int, or string", value)
	}
}
// compareBooleans compares two boolean values. Only "==" and "!=" are
// meaningful for booleans; any other operator yields an explanatory message.
// It returns "" when the assertion holds, otherwise a failure message.
func compareBooleans(comparison string, expectedValue, realValue bool, assertionType string) string {
	switch comparison {
	case "==":
		if expectedValue == realValue {
			return ""
		}
		return fmt.Sprintf("Assertion failed: %s equal to %t, got %t", assertionType, expectedValue, realValue)
	case "!=":
		if expectedValue != realValue {
			return ""
		}
		return fmt.Sprintf("Assertion failed: %s not equal to %t, got %t", assertionType, expectedValue, realValue)
	}
	return fmt.Sprintf("%s is not a valid comparison for boolean values. Only '==' and '!=' are supported", comparison)
}
// compareNullValue checks a cell against an expected NULL. Only "==" and "!="
// are meaningful against NULL; any other operator yields an explanatory
// message. It returns "" when the assertion holds, otherwise a failure message.
func compareNullValue(comparison string, actualValue interface{}, assertionType string) string {
	isNull := actualValue == nil
	switch comparison {
	case "==":
		if !isNull {
			return fmt.Sprintf("Assertion failed: %s equal to NULL, got %v", assertionType, actualValue)
		}
	case "!=":
		if isNull {
			return fmt.Sprintf("Assertion failed: %s not equal to NULL, got NULL", assertionType)
		}
	default:
		return fmt.Sprintf("%s is not a valid comparison for NULL values", comparison)
	}
	return ""
}
// GetStringColAsString returns a text column value as a *string. The
// dolt_tests system table can surface text as *val.TextStorage in some
// situations, so that case is unwrapped via the sql context; plain strings
// pass through, nil maps to nil, and any other type is an error.
func GetStringColAsString(sqlCtx *sql.Context, tableValue interface{}) (*string, error) {
	switch v := tableValue.(type) {
	case *val.TextStorage:
		s, err := v.Unwrap(sqlCtx)
		return &s, err
	case string:
		return &v, nil
	case nil:
		return nil, nil
	default:
		return nil, fmt.Errorf("unexpected type %T, was expecting string", tableValue)
	}
}
-35
View File
@@ -33,7 +33,6 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/grpcendpoint"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
"github.com/dolthub/dolt/go/libraries/utils/config"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
@@ -1350,40 +1349,6 @@ func (dEnv *DoltEnv) TempTableFilesDir() (string, error) {
return absPath, nil
}
func (dEnv *DoltEnv) DbEaFactory(ctx context.Context) (editor.DbEaFactory, error) {
tmpDir, err := dEnv.TempTableFilesDir()
if err != nil {
return nil, err
}
db := dEnv.DoltDB(ctx)
if db == nil {
if dEnv.DBLoadError != nil {
return nil, dEnv.DBLoadError
}
return nil, errors.New("DoltDB failed to initialize but no error was recorded")
}
return editor.NewDbEaFactory(tmpDir, db.ValueReadWriter()), nil
}
func (dEnv *DoltEnv) BulkDbEaFactory(ctx context.Context) (editor.DbEaFactory, error) {
tmpDir, err := dEnv.TempTableFilesDir()
if err != nil {
return nil, err
}
db := dEnv.DoltDB(ctx)
if db == nil {
if dEnv.DBLoadError != nil {
return nil, dEnv.DBLoadError
}
return nil, errors.New("DoltDB failed to initialize but no error was recorded")
}
return editor.NewBulkImportTEAFactory(db.ValueReadWriter(), tmpDir), nil
}
func (dEnv *DoltEnv) IsAccessModeReadOnly(ctx context.Context) (bool, error) {
db := dEnv.DoltDB(ctx)
if db == nil {
-248
View File
@@ -1,248 +0,0 @@
// Copyright 2021 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package env
import (
"context"
"fmt"
"os"
"time"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
"github.com/dolthub/dolt/go/libraries/utils/config"
"github.com/dolthub/dolt/go/store/chunks"
"github.com/dolthub/dolt/go/store/datas"
"github.com/dolthub/dolt/go/store/hash"
)
// NewMemoryDbData builds a fully in-memory DbData: a memory-backed DoltDB on
// the configured default init branch, with one MemoryRepoState serving as both
// repo state reader and writer.
func NewMemoryDbData(ctx context.Context, cfg config.ReadableConfig) (DbData[context.Context], error) {
	branch := GetDefaultInitBranch(cfg)
	ddb, err := NewMemoryDoltDB(ctx, branch)
	if err != nil {
		return DbData[context.Context]{}, err
	}
	repoState, err := NewMemoryRepoState(ctx, ddb, branch)
	if err != nil {
		return DbData[context.Context]{}, err
	}
	return DbData[context.Context]{
		Ddb: ddb,
		Rsw: repoState,
		Rsr: repoState,
	}, nil
}
// NewMemoryDoltDB creates a DoltDB backed by in-memory test chunk storage and
// writes an empty repo whose default branch is initBranch.
func NewMemoryDoltDB(ctx context.Context, initBranch string) (*doltdb.DoltDB, error) {
	storage := &chunks.TestStorage{}
	ddb, err := doltdb.DoltDBFromCS(storage.NewViewWithDefaultFormat(), "")
	if err != nil {
		return nil, err
	}
	// Committer name and email are both the placeholder "memory".
	const committer = "memory"
	err = ddb.WriteEmptyRepoWithCommitTimeAndDefaultBranch(ctx, committer, committer, datas.CommitterDate(), ref.NewBranchRef(initBranch))
	if err != nil {
		return nil, err
	}
	return ddb, nil
}
// NewMemoryRepoState builds a MemoryRepoState pointed at initBranch and seeds
// both the working and staged roots from that branch's head commit.
func NewMemoryRepoState(ctx context.Context, ddb *doltdb.DoltDB, initBranch string) (MemoryRepoState, error) {
	headRef := ref.NewBranchRef(initBranch)
	state := MemoryRepoState{
		DoltDB: ddb,
		Head:   headRef,
	}
	headCommit, err := ddb.ResolveCommitRef(ctx, headRef)
	if err != nil {
		return MemoryRepoState{}, err
	}
	headRoot, err := headCommit.GetRootValue(ctx)
	if err != nil {
		return MemoryRepoState{}, err
	}
	if err = state.UpdateWorkingRoot(ctx, headRoot); err != nil {
		return MemoryRepoState{}, err
	}
	if err = state.UpdateStagedRoot(ctx, headRoot); err != nil {
		return MemoryRepoState{}, err
	}
	return state, nil
}
// MemoryRepoState is an in-memory repo state implementation backed entirely by
// a DoltDB working set — no repo state files are read or written on disk.
type MemoryRepoState struct {
// DoltDB holds the (memory-backed) database whose working set stores state.
DoltDB *doltdb.DoltDB
// Head is the ref of the current working branch.
Head ref.DoltRef
}
// Compile-time checks that MemoryRepoState satisfies the reader and writer
// interfaces with value receivers.
var _ RepoStateReader[context.Context] = MemoryRepoState{}
var _ RepoStateWriter = MemoryRepoState{}
// CWBHeadRef returns the ref of the current working branch's head. It never
// fails for the in-memory implementation.
func (m MemoryRepoState) CWBHeadRef(context.Context) (ref.DoltRef, error) {
return m.Head, nil
}
// CWBHeadSpec returns a commit spec addressing the current working branch head.
func (m MemoryRepoState) CWBHeadSpec(ctx context.Context) (*doltdb.CommitSpec, error) {
	headRef, err := m.CWBHeadRef(ctx)
	if err != nil {
		return nil, err
	}
	return doltdb.NewCommitSpec(headRef.GetPath())
}
// UpdateStagedRoot writes newRoot as the staged root of the current working
// set. If no working set exists yet, a fresh one is created with both working
// and staged roots set to newRoot (and a zero expected hash); otherwise the
// existing set's hash is passed so UpdateWorkingSet performs a compare-and-set.
func (m MemoryRepoState) UpdateStagedRoot(ctx context.Context, newRoot doltdb.RootValue) error {
var h hash.Hash
var wsRef ref.WorkingSetRef
ws, err := m.WorkingSet(ctx)
if err == doltdb.ErrWorkingSetNotFound {
// first time updating root: derive the working set ref from the head branch
headRef, err := m.CWBHeadRef(ctx)
if err != nil {
return err
}
wsRef, err = ref.WorkingSetRefForHead(headRef)
if err != nil {
return err
}
ws = doltdb.EmptyWorkingSet(wsRef).WithWorkingRoot(newRoot).WithStagedRoot(newRoot)
} else if err != nil {
return err
} else {
// Existing working set: capture its current hash for the CAS update below.
h, err = ws.HashOf()
if err != nil {
return err
}
wsRef = ws.Ref()
}
return m.DoltDB.UpdateWorkingSet(ctx, wsRef, ws.WithStagedRoot(newRoot), h, m.workingSetMeta(), nil)
}
// UpdateWorkingRoot writes newRoot as the working root of the current working
// set. If no working set exists yet, a fresh one is created with both working
// and staged roots set to newRoot (and a zero expected hash); otherwise the
// existing set's hash is passed so UpdateWorkingSet performs a compare-and-set.
func (m MemoryRepoState) UpdateWorkingRoot(ctx context.Context, newRoot doltdb.RootValue) error {
var h hash.Hash
var wsRef ref.WorkingSetRef
ws, err := m.WorkingSet(ctx)
if err == doltdb.ErrWorkingSetNotFound {
// first time updating root: derive the working set ref from the head branch
headRef, err := m.CWBHeadRef(ctx)
if err != nil {
return err
}
wsRef, err = ref.WorkingSetRefForHead(headRef)
if err != nil {
return err
}
ws = doltdb.EmptyWorkingSet(wsRef).WithWorkingRoot(newRoot).WithStagedRoot(newRoot)
} else if err != nil {
return err
} else {
// Existing working set: capture its current hash for the CAS update below.
h, err = ws.HashOf()
if err != nil {
return err
}
wsRef = ws.Ref()
}
return m.DoltDB.UpdateWorkingSet(ctx, wsRef, ws.WithWorkingRoot(newRoot), h, m.workingSetMeta(), nil)
}
// WorkingSet resolves and returns the working set associated with the current
// head branch ref.
func (m MemoryRepoState) WorkingSet(ctx context.Context) (*doltdb.WorkingSet, error) {
	headRef, err := m.CWBHeadRef(ctx)
	if err != nil {
		return nil, err
	}
	wsRef, err := ref.WorkingSetRefForHead(headRef)
	if err != nil {
		return nil, err
	}
	return m.DoltDB.ResolveWorkingSet(ctx, wsRef)
}
// workingSetMeta builds the metadata recorded alongside each working set
// update: the current unix timestamp and a fixed description.
func (m MemoryRepoState) workingSetMeta() *datas.WorkingSetMeta {
return &datas.WorkingSetMeta{
Timestamp: uint64(time.Now().Unix()),
Description: "updated from dolt environment",
}
}
// SetCWBHeadRef sets the current working branch head ref.
// NOTE(review): value receiver — the assignment below mutates a copy of m and
// is discarded when the method returns, so callers never observe the new head.
// Confirm whether this no-op is intentional for the in-memory implementation.
func (m MemoryRepoState) SetCWBHeadRef(_ context.Context, r ref.MarshalableRef) (err error) {
m.Head = r.Ref
return
}
// GetRemotes returns an empty map; memory databases track no remotes.
func (m MemoryRepoState) GetRemotes() (*concurrentmap.Map[string, Remote], error) {
return concurrentmap.New[string, Remote](), nil
}
// AddRemote always fails; remotes are unsupported for memory databases.
func (m MemoryRepoState) AddRemote(r Remote) error {
return fmt.Errorf("cannot insert a remote in a memory database")
}
// GetBranches returns an empty map; no per-branch config is persisted.
func (m MemoryRepoState) GetBranches() (*concurrentmap.Map[string, BranchConfig], error) {
return concurrentmap.New[string, BranchConfig](), nil
}
// UpdateBranch is a no-op for memory databases.
func (m MemoryRepoState) UpdateBranch(name string, new BranchConfig) error {
return nil
}
// RemoveRemote always fails; remotes are unsupported for memory databases.
func (m MemoryRepoState) RemoveRemote(ctx context.Context, name string) error {
return fmt.Errorf("cannot delete a remote from a memory database")
}
// TempTableFilesDir returns the OS temp dir for temporary table files.
func (m MemoryRepoState) TempTableFilesDir() (string, error) {
return os.TempDir(), nil
}
// GetBackups panics; backups are unsupported for memory databases.
func (m MemoryRepoState) GetBackups() (*concurrentmap.Map[string, Remote], error) {
panic("cannot get backups on in memory database")
}
// AddBackup panics; backups are unsupported for memory databases.
func (m MemoryRepoState) AddBackup(r Remote) error {
panic("cannot add backup to in memory database")
}
// RemoveBackup panics; backups are unsupported for memory databases.
func (m MemoryRepoState) RemoveBackup(ctx context.Context, name string) error {
panic("cannot remove backup from in memory database")
}
@@ -127,7 +127,6 @@ func testDataMergeHelper(t *testing.T, tests []dataMergeTest, flipSides bool) {
var mo merge.MergeOpts
var eo editor.Options
eo = eo.WithDeaf(editor.NewInMemDeaf(a.VRW()))
// attempt merge before skipping to assert no panics
result, err := merge.MergeRoots(sql.NewContext(ctx), doltdb.SimpleTableResolver{}, l, r, a, rootish{r}, rootish{a}, eo, mo)
@@ -147,7 +146,6 @@ func testDataMergeHelper(t *testing.T, tests []dataMergeTest, flipSides bool) {
func setupDataMergeTest(ctx context.Context, t *testing.T, schema namedSchema, test dataTest) (anc, left, right, merged doltdb.RootValue) {
denv := dtestutils.CreateTestEnv()
var eo editor.Options
eo = eo.WithDeaf(editor.NewInMemDeaf(denv.DoltDB(ctx).ValueReadWriter()))
ancestorTable := tbl(schema, test.ancestor...)
anc = makeRootWithTable(t, denv.DoltDB(ctx), eo, *ancestorTable)
@@ -458,18 +458,6 @@ type keylessEntry struct {
c2 int
}
func (e keylessEntries) toTupleSet() tupleSet {
tups := make([]types.Tuple, len(e))
for i, t := range e {
tups[i] = t.ToNomsTuple()
}
return mustTupleSet(tups...)
}
func (e keylessEntry) ToNomsTuple() types.Tuple {
return dtu.MustTuple(cardTag, types.Uint(e.card), c1Tag, types.Int(e.c1), c2Tag, types.Int(e.c2))
}
func (e keylessEntry) HashAndValue() ([]byte, val.Tuple, error) {
valBld.PutUint64(0, uint64(e.card))
valBld.PutInt64(1, int64(e.c1))
@@ -497,14 +485,6 @@ func (e conflictEntries) toConflictSet(t *testing.T) conflictSet {
return s
}
func (e conflictEntries) toTupleSet() tupleSet {
tups := make([]types.Tuple, len(e))
for i, t := range e {
tups[i] = t.ToNomsTuple()
}
return mustTupleSet(tups...)
}
func (e conflictEntry) Key(t *testing.T) (h [16]byte) {
if e.base != nil {
h2, _, err := e.base.HashAndValue()
@@ -528,34 +508,6 @@ func (e conflictEntry) Key(t *testing.T) (h [16]byte) {
return
}
func (e conflictEntry) ToNomsTuple() types.Tuple {
var b, o, t types.Value = types.NullValue, types.NullValue, types.NullValue
if e.base != nil {
b = e.base.ToNomsTuple()
}
if e.ours != nil {
o = e.ours.ToNomsTuple()
}
if e.theirs != nil {
t = e.theirs.ToNomsTuple()
}
return dtu.MustTuple(b, o, t)
}
type tupleSet map[hash.Hash]types.Tuple
func mustTupleSet(tt ...types.Tuple) (s tupleSet) {
s = make(tupleSet, len(tt))
for _, tup := range tt {
h, err := tup.Hash(types.Format_Default)
if err != nil {
panic(err)
}
s[h] = tup
}
return
}
type hash128Set map[[16]byte]val.Tuple
func mustHash128Set(entries ...keylessEntry) (s hash128Set) {
-8
View File
@@ -413,14 +413,6 @@ type ArtifactStatus struct {
ConstraintViolationsTables []string
}
// HasConflicts reports whether any data or schema conflict tables were recorded.
func (as ArtifactStatus) HasConflicts() bool {
return len(as.DataConflictTables) > 0 || len(as.SchemaConflictsTables) > 0
}
// HasConstraintViolations reports whether any constraint violation tables were recorded.
func (as ArtifactStatus) HasConstraintViolations() bool {
return len(as.ConstraintViolationsTables) > 0
}
// MergeWouldStompChanges returns list of table names that are stomped and the diffs map between head and working set.
func MergeWouldStompChanges(ctx context.Context, roots doltdb.Roots, mergeCommit *doltdb.Commit) ([]doltdb.TableName, map[doltdb.TableName]hash.Hash, error) {
mergeRoot, err := mergeCommit.GetRootValue(ctx)
-101
View File
@@ -32,7 +32,6 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor/creation"
filesys2 "github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/libraries/utils/valutil"
"github.com/dolthub/dolt/go/store/datas"
"github.com/dolthub/dolt/go/store/pool"
"github.com/dolthub/dolt/go/store/prolly"
@@ -76,10 +75,6 @@ func (v rowV) value() val.Tuple {
return tup
}
func (v rowV) nomsValue() types.Value {
return valsToTestTupleWithoutPks([]types.Value{types.Int(v.col1), types.Int(v.col2)})
}
const (
NoopAction ActionType = iota
InsertAction
@@ -505,68 +500,6 @@ func rebuildAllProllyIndexes(ctx *sql.Context, tbl *doltdb.Table) (*doltdb.Table
return tbl.SetIndexSet(ctx, indexes)
}
func calcExpectedStats(t *testing.T) *MergeStats {
s := &MergeStats{Operation: TableModified}
for _, testCase := range testRows {
if (testCase.leftAction == InsertAction) != (testCase.rightAction == InsertAction) {
if testCase.leftAction == UpdateAction || testCase.rightAction == UpdateAction ||
testCase.leftAction == DeleteAction || testCase.rightAction == DeleteAction {
// Either the row exists in the ancestor commit and we are
// deleting or updating it, or the row doesn't exist and we are
// inserting it.
t.Fatalf("it's impossible for an insert to be paired with an update or delete")
}
}
if testCase.leftAction == NoopAction {
switch testCase.rightAction {
case NoopAction:
case DeleteAction:
s.Deletes++
case InsertAction:
s.Adds++
case UpdateAction:
s.Modifications++
}
continue
}
if testCase.rightAction == NoopAction {
switch testCase.leftAction {
case NoopAction:
case DeleteAction:
s.Deletes++
case InsertAction:
s.Adds++
case UpdateAction:
s.Modifications++
}
continue
}
if testCase.conflict {
// (UpdateAction, DeleteAction),
// (DeleteAction, UpdateAction),
// (UpdateAction, UpdateAction) with conflict,
// (InsertAction, InsertAction) with conflict
s.DataConflicts++
continue
}
if testCase.leftAction == InsertAction && testCase.rightAction == InsertAction {
// Equivalent inserts
continue
}
if !valutil.NilSafeEqCheck(unwrapNoms(testCase.leftValue), unwrapNoms(testCase.rightValue)) {
s.Modifications++
continue
}
}
return s
}
func mustMakeEmptyRepo(t *testing.T) *doltdb.DoltDB {
ddb, _ := doltdb.LoadDoltDB(context.Background(), types.Format_Default, doltdb.InMemDoltDB, filesys2.LocalFS)
err := ddb.WriteEmptyRepo(context.Background(), env.DefaultInitBranch, name, email)
@@ -645,40 +578,6 @@ func key(i int) val.Tuple {
return tup
}
func nomsKey(i int) types.Value {
return mustTuple(types.NewTuple(types.Format_Default, types.Uint(idTag), types.Int(i)))
}
func unwrap(v *rowV) val.Tuple {
if v == nil {
return nil
}
return v.value()
}
func unwrapNoms(v *rowV) types.Value {
if v == nil {
return nil
}
return v.nomsValue()
}
func mustTuple(tpl types.Tuple, err error) types.Tuple {
if err != nil {
panic(err)
}
return tpl
}
func mustString(str string, err error) string {
if err != nil {
panic(err)
}
return str
}
func MustDebugFormatProlly(t *testing.T, m prolly.Map) string {
s, err := prolly.DebugFormat(context.Background(), m)
require.NoError(t, err)
@@ -26,15 +26,6 @@ import (
"github.com/dolthub/dolt/go/store/val"
)
type nomsRowMergeTest struct {
name string
row, mergeRow, ancRow types.Value
sch schema.Schema
expectedResult types.Value
expectCellMerge bool
expectConflict bool
}
type rowMergeTest struct {
name string
row, mergeRow, ancRow val.Tuple
@@ -69,39 +60,6 @@ func build(ints ...int) []*int {
return out
}
var convergentEditCases = []testCase{
{
"add same row",
build(1, 2),
build(1, 2),
nil,
2, 2, 2,
build(1, 2),
false,
false,
},
{
"both delete row",
nil,
nil,
build(1, 2),
2, 2, 2,
nil,
false,
false,
},
{
"modify row to equal value",
build(2, 2),
build(2, 2),
build(1, 1),
2, 2, 2,
build(2, 2),
false,
false,
},
}
var testCases = []testCase{
{
"insert different rows",
@@ -221,35 +179,6 @@ func TestRowMerge(t *testing.T) {
}
}
func valsToTestTupleWithoutPks(vals []types.Value) types.Value {
return valsToTestTuple(vals, false)
}
// valsToTestTupleWithPks converts |vals| to a test tuple whose tags are
// shifted by one to leave room for a single primary-key tag.
func valsToTestTupleWithPks(vals []types.Value) types.Value {
	return valsToTestTuple(vals, true)
}
// valsToTestTuple packs |vals| into a noms tuple of alternating (tag, value)
// pairs, skipping NULL values. Tags are the slice indices; when
// includePrimaryKeys is set every tag is shifted by one on the assumption of
// a single primary-key tag. A nil input yields a nil tuple.
func valsToTestTuple(vals []types.Value, includePrimaryKeys bool) types.Value {
	if vals == nil {
		return nil
	}
	tupleElems := make([]types.Value, 0, 2*len(vals))
	for idx, v := range vals {
		if types.IsNull(v) {
			continue
		}
		tag := idx
		if includePrimaryKeys {
			// Assume one primary key tag; shift every other tag by one.
			tag++
		}
		tupleElems = append(tupleElems, types.Uint(tag), v)
	}
	return mustTuple(types.NewTuple(types.Format_Default, tupleElems...))
}
func createRowMergeStruct(t testCase) rowMergeTest {
mergedSch := calcMergedSchema(t)
leftSch := calcSchema(t.rowCnt)
@@ -269,16 +198,6 @@ func createRowMergeStruct(t testCase) rowMergeTest {
t.expectConflict}
}
// createNomsRowMergeStruct adapts a generic testCase into a noms-format
// nomsRowMergeTest by converting each of its row descriptions into tagged
// tuples (with a primary-key tag offset) against the merged schema.
func createNomsRowMergeStruct(t testCase) nomsRowMergeTest {
	sch := calcMergedSchema(t)
	tpl := valsToTestTupleWithPks(toVals(t.row))
	mergeTpl := valsToTestTupleWithPks(toVals(t.mergeRow))
	ancTpl := valsToTestTupleWithPks(toVals(t.ancRow))
	expectedTpl := valsToTestTupleWithPks(toVals(t.expectedResult))
	return nomsRowMergeTest{t.name, tpl, mergeTpl, ancTpl, sch, expectedTpl, t.expectCellMerge, t.expectConflict}
}
func calcMergedSchema(t testCase) schema.Schema {
longest := t.rowCnt
if t.mRowCnt > longest {
@@ -323,20 +242,3 @@ func buildTup(sch schema.Schema, r []*int) val.Tuple {
}
return tup
}
// toVals converts a slice of optional ints into noms values, mapping each nil
// pointer to types.NullValue. A nil input yields a nil slice.
func toVals(ints []*int) []types.Value {
	if ints == nil {
		return nil
	}
	out := make([]types.Value, len(ints))
	for i, p := range ints {
		if p != nil {
			out[i] = types.Int(*p)
		} else {
			out[i] = types.NullValue
		}
	}
	return out
}
@@ -1650,7 +1650,6 @@ func testSchemaMergeHelper(t *testing.T, tests []schemaMergeTest, flipSides bool
var mo merge.MergeOpts
var eo editor.Options
eo = eo.WithDeaf(editor.NewInMemDeaf(a.VRW()))
// attempt merge before skipping to assert no panics
result, err := merge.MergeRoots(sql.NewContext(ctx), doltdb.SimpleTableResolver{}, l, r, a, rootish{r}, rootish{a}, eo, mo)
maybeSkip(t, test, flipSides)
@@ -1693,7 +1692,6 @@ func testSchemaMergeHelper(t *testing.T, tests []schemaMergeTest, flipSides bool
func setupSchemaMergeTest(ctx context.Context, t *testing.T, test schemaMergeTest) (anc, left, right, merged doltdb.RootValue) {
denv := dtestutils.CreateTestEnv()
var eo editor.Options
eo = eo.WithDeaf(editor.NewInMemDeaf(denv.DoltDB(ctx).ValueReadWriter()))
anc = makeRootWithTable(t, denv.DoltDB(ctx), eo, test.ancestor)
assert.NotNil(t, anc)
if test.left != nil {
@@ -20,12 +20,10 @@ import (
"fmt"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
json2 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/json"
"github.com/dolthub/dolt/go/store/hash"
"github.com/dolthub/dolt/go/store/prolly"
"github.com/dolthub/dolt/go/store/types"
@@ -543,16 +541,3 @@ func foreignKeyCVJson(foreignKey doltdb.ForeignKey, sch, refSch schema.Schema) (
return d, nil
}
func jsonDataToNomsValue(ctx context.Context, vrw types.ValueReadWriter, data []byte) (types.JSON, error) {
var doc interface{}
if err := json.Unmarshal(data, &doc); err != nil {
return types.JSON{}, err
}
sqlDoc := gmstypes.JSONDocument{Val: doc}
nomsJson, err := json2.NomsJSONFromJSONValue(ctx, vrw, sqlDoc)
if err != nil {
return types.JSON{}, err
}
return types.JSON(nomsJson), nil
}
+4 -14
View File
@@ -86,7 +86,7 @@ func TestBasics(t *testing.T) {
{NewDataLocation("file.csv", ""), CsvFile.ReadableStr() + ":file.csv", true},
{NewDataLocation("file.psv", ""), PsvFile.ReadableStr() + ":file.psv", true},
{NewDataLocation("file.json", ""), JsonFile.ReadableStr() + ":file.json", true},
//{NewDataLocation("file.nbf", ""), NbfFile, "file.nbf", true},
// {NewDataLocation("file.nbf", ""), NbfFile, "file.nbf", true},
}
for _, test := range tests {
@@ -133,7 +133,7 @@ func TestExists(t *testing.T) {
NewDataLocation("file.csv", ""),
NewDataLocation("file.psv", ""),
NewDataLocation("file.json", ""),
//NewDataLocation("file.nbf", ""),
// NewDataLocation("file.nbf", ""),
}
ddb, root, fs := createRootAndFS()
@@ -192,7 +192,7 @@ func TestCreateRdWr(t *testing.T) {
{NewDataLocation("file.csv", ""), reflect.TypeOf((*csv.CSVReader)(nil)).Elem(), reflect.TypeOf((*csv.CSVWriter)(nil)).Elem()},
{NewDataLocation("file.psv", ""), reflect.TypeOf((*csv.CSVReader)(nil)).Elem(), reflect.TypeOf((*csv.CSVWriter)(nil)).Elem()},
{NewDataLocation("file.json", ""), reflect.TypeOf((*json.JSONReader)(nil)).Elem(), reflect.TypeOf((*json.RowWriter)(nil)).Elem()},
//{NewDataLocation("file.nbf", ""), reflect.TypeOf((*nbf.NBFReader)(nil)).Elem(), reflect.TypeOf((*nbf.NBFWriter)(nil)).Elem()},
// {NewDataLocation("file.nbf", ""), reflect.TypeOf((*nbf.NBFReader)(nil)).Elem(), reflect.TypeOf((*nbf.NBFWriter)(nil)).Elem()},
}
ctx := context.Background()
@@ -220,16 +220,6 @@ func TestCreateRdWr(t *testing.T) {
loc := test.dl
tmpDir, tdErr := dEnv.TempTableFilesDir()
if tdErr != nil {
t.Fatal("Unexpected error accessing .dolt directory.", tdErr)
}
deaf, err := dEnv.DbEaFactory(ctx)
if err != nil {
t.Fatal("Unexpected error accessing .dolt directory.", err)
}
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
filePath, fpErr := dEnv.FS.Abs(strings.Split(loc.String(), ":")[1])
if fpErr != nil {
t.Fatal("Unexpected error getting filepath", fpErr)
@@ -240,7 +230,7 @@ func TestCreateRdWr(t *testing.T) {
t.Fatal("Unexpected error opening file for writer.", wrErr)
}
wr, wErr := loc.NewCreatingWriter(context.Background(), mvOpts, root, fakeSchema, opts, writer)
wr, wErr := loc.NewCreatingWriter(context.Background(), mvOpts, root, fakeSchema, editor.Options{}, writer)
if wErr != nil {
t.Fatal("Unexpected error creating writer.", wErr)
}
@@ -31,7 +31,6 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/overrides"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
"github.com/dolthub/dolt/go/libraries/doltcore/table/typed/noms"
"github.com/dolthub/dolt/go/store/types"
)
@@ -40,6 +39,9 @@ const (
tableWriterStatUpdateRate = 64 * 1024
)
// StatsCb is a callback for reporting stats about the rows that have been processed so far
type StatsCb func(types.AppliedEditStats)
// SqlEngineTableWriter is a utility for importing a set of rows through the sql engine.
type SqlEngineTableWriter struct {
se *sqle.Engine
@@ -51,7 +53,7 @@ type SqlEngineTableWriter struct {
force bool
disableFks bool
statsCB noms.StatsCB
statsCB StatsCb
stats types.AppliedEditStats
statOps int32
@@ -60,7 +62,7 @@ type SqlEngineTableWriter struct {
rowOperationSchema sql.PrimaryKeySchema
}
func NewSqlEngineTableWriter(ctx *sql.Context, engine *sqle.Engine, createTableSchema, rowOperationSchema schema.Schema, options *MoverOptions, statsCB noms.StatsCB) (*SqlEngineTableWriter, error) {
func NewSqlEngineTableWriter(ctx *sql.Context, engine *sqle.Engine, createTableSchema, rowOperationSchema schema.Schema, options *MoverOptions, statsCB StatsCb) (*SqlEngineTableWriter, error) {
if engine.IsReadOnly() {
// SqlEngineTableWriter does not respect read only mode
return nil, analyzererrors.ErrReadOnlyDatabase.New(ctx.GetCurrentDatabase())
-158
View File
@@ -15,8 +15,6 @@
package row
import (
"context"
"errors"
"fmt"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
@@ -24,8 +22,6 @@ import (
"github.com/dolthub/dolt/go/store/types"
)
var ErrRowNotValid = errors.New("invalid row for current schema")
type Row interface {
// Iterates over all the columns in the row. Columns that have no value set will not be visited.
IterCols(cb func(tag uint64, val types.Value) (stop bool, err error)) (bool, error)
@@ -82,21 +78,6 @@ func FromNoms(sch schema.Schema, nomsKey, nomsVal types.Tuple) (Row, error) {
return pkRowFromNoms(sch, nomsKey, nomsVal)
}
// ToNoms returns the storage-layer tuples corresponding to |r|: the key tuple
// and the value tuple derived from |sch|. On error the zero tuples are
// returned alongside it.
func ToNoms(ctx context.Context, sch schema.Schema, r Row) (key, val types.Tuple, err error) {
	var k, v types.Value
	if k, err = r.NomsMapKey(sch).Value(ctx); err != nil {
		return key, val, err
	}
	if v, err = r.NomsMapValue(sch).Value(ctx); err != nil {
		return key, val, err
	}
	return k.(types.Tuple), v.(types.Tuple), nil
}
func GetFieldByName(colName string, r Row, sch schema.Schema) (types.Value, bool) {
col, ok := sch.GetAllCols().GetByName(colName)
@@ -123,71 +104,6 @@ func GetFieldByNameWithDefault(colName string, defVal types.Value, r Row, sch sc
}
}
// ReduceToIndexKeysFromTagMap creates a full key and a partial key from the given map of tags (first tuple being the
// full key). Please refer to the note in the index editor for more information regarding partial keys.
//
// Both keys are built from the same alternating (tag, value) slice: the full
// key covers every tag of the index, the partial key only the first
// idx.Count() indexed columns. Missing tags are encoded as NULL. When |tf| is
// nil the tuples are created directly with |nbf|; otherwise the factory is
// used. The previous implementation duplicated the two-tuple construction in
// both branches of the tf check; the shared logic now lives in one helper
// closure.
func ReduceToIndexKeysFromTagMap(nbf *types.NomsBinFormat, idx schema.Index, tagToVal map[uint64]types.Value, tf *types.TupleFactory) (types.Tuple, types.Tuple, error) {
	vals := make([]types.Value, 0, len(idx.AllTags())*2)
	for _, tag := range idx.AllTags() {
		val, ok := tagToVal[tag]
		if !ok {
			val = types.NullValue
		}
		vals = append(vals, types.Uint(tag), val)
	}
	// newTuple abstracts over the two construction paths (raw format vs. factory).
	newTuple := func(vs []types.Value) (types.Tuple, error) {
		if tf == nil {
			return types.NewTuple(nbf, vs...)
		}
		return tf.Create(vs...)
	}
	fullKey, err := newTuple(vals)
	if err != nil {
		return types.Tuple{}, types.Tuple{}, err
	}
	partialKey, err := newTuple(vals[:idx.Count()*2])
	if err != nil {
		return types.Tuple{}, types.Tuple{}, err
	}
	return fullKey, partialKey, nil
}
// ReduceToIndexPartialKey creates an index record from a primary storage record.
// The result is a tuple of alternating (tag, value) pairs over the requested
// tags; columns absent from |r| are encoded as NULL.
// NOTE(review): when |idx| has a non-empty name the caller-supplied |tags|
// slice is discarded in favor of idx.IndexedColumnTags() — confirm callers
// rely on this override.
func ReduceToIndexPartialKey(tags []uint64, idx schema.Index, r Row) (types.Tuple, error) {
	var vals []types.Value
	if idx.Name() != "" {
		tags = idx.IndexedColumnTags()
	}
	for _, tag := range tags {
		val, ok := r.GetColVal(tag)
		if !ok {
			val = types.NullValue
		}
		vals = append(vals, types.Uint(tag), val)
	}
	return types.NewTuple(r.Format(), vals...)
}
// IsEmpty reports whether |r| has no column values set. Iteration stops at
// the first column encountered.
func IsEmpty(r Row) (b bool) {
	sawCol := false
	_, _ = r.IterCols(func(uint64, types.Value) (bool, error) {
		sawCol = true
		return true, nil
	})
	return !sawCol
}
// IsValid returns whether the row given matches the types and satisfies all the constraints of the schema given.
func IsValid(r Row, sch schema.Schema) (bool, error) {
column, constraint, err := findInvalidCol(r, sch)
@@ -265,77 +181,3 @@ func AreEqual(row1, row2 Row, sch schema.Schema) bool {
return true
}
// TaggedValsEqualForSch returns whether |tv| and |other| hold equal values for
// every column tag of |sch|. Two nil maps are equal; a nil map never equals a
// non-nil one. Tags not present in the schema are ignored.
func TaggedValsEqualForSch(tv, other TaggedValues, sch schema.Schema) bool {
	if tv == nil && other == nil {
		return true
	} else if tv == nil || other == nil {
		return false
	}
	for _, tag := range sch.GetAllCols().Tags {
		// A missing tag yields a nil value; NilSafeEqCheck treats two nils as
		// equal. (The previous `val1, _ := tv[tag]` form tripped staticcheck
		// S1005 — the blank identifier in a map read is unnecessary.)
		if !valutil.NilSafeEqCheck(tv[tag], other[tag]) {
			return false
		}
	}
	return true
}
// KeyAndTaggedValuesForRow returns the primary-key tuple of |r| along with a
// TaggedValues map holding all of the row's columns from |sch|.
//
// For a keyed (nomsRow) row, the key tuple is rebuilt from the schema's PK
// columns in PK order, and an error is returned if any key column is missing
// or NULL; non-PK values are folded into the returned map as well. For a
// keyless row, the stored key tuple is returned as-is. Any other Row
// implementation panics.
func KeyAndTaggedValuesForRow(r Row, sch schema.Schema) (types.Tuple, TaggedValues, error) {
	switch typed := r.(type) {
	case nomsRow:
		pkCols := sch.GetPKCols()
		keyVals := make([]types.Value, 0, pkCols.Size()*2)
		tv := make(TaggedValues)
		err := pkCols.Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
			val, ok := typed.key[tag]
			if !ok || types.IsNull(val) {
				return false, errors.New("invalid key contains null values")
			}
			tv[tag] = val
			// Key tuples are stored as alternating (tag, value) pairs.
			keyVals = append(keyVals, types.Uint(tag))
			keyVals = append(keyVals, val)
			return false, nil
		})
		if err != nil {
			return types.Tuple{}, nil, err
		}
		nonPkCols := sch.GetNonPKCols()
		// Collect only values whose tags belong to non-PK schema columns.
		_, err = typed.value.Iter(func(tag uint64, val types.Value) (stop bool, err error) {
			if _, ok := nonPkCols.TagToIdx[tag]; ok {
				tv[tag] = val
			}
			return false, nil
		})
		if err != nil {
			return types.Tuple{}, nil, err
		}
		t, err := types.NewTuple(r.Format(), keyVals...)
		if err != nil {
			return types.Tuple{}, nil, err
		}
		return t, tv, nil
	case keylessRow:
		// Keyless rows already carry their key tuple; just expand the values.
		tv, err := typed.TaggedValues()
		if err != nil {
			return types.Tuple{}, nil, err
		}
		return typed.key, tv, nil
	default:
		panic("unknown row type")
	}
}
@@ -191,60 +191,6 @@ func TaggedValuesFromTupleValueSlice(vals types.TupleValueSlice) (TaggedValues,
return taggedTuple, nil
}
// TaggedValuesFromTupleKeyAndValue folds both the key and value tuples into a
// single TaggedValues map, key pairs first.
func TaggedValuesFromTupleKeyAndValue(key, value types.Tuple) (TaggedValues, error) {
	tv := make(TaggedValues)
	for _, tpl := range []types.Tuple{key, value} {
		if err := AddToTaggedVals(tv, tpl); err != nil {
			return nil, err
		}
	}
	return tv, nil
}
// AddToTaggedVals merges every (tag, value) pair of tuple |t| into |tv|,
// overwriting existing entries with the same tag.
func AddToTaggedVals(tv TaggedValues, t types.Tuple) error {
	return IterDoltTuple(t, func(tag uint64, val types.Value) error {
		tv[tag] = val
		return nil
	})
}
// IterDoltTuple walks |t| as alternating (uint64 tag, value) pairs, invoking
// |cb| once per pair. Iteration stops at the first error returned by the
// iterator or by the callback.
func IterDoltTuple(t types.Tuple, cb func(tag uint64, val types.Value) error) error {
	itr, err := t.Iterator()
	if err != nil {
		return err
	}
	for itr.HasMore() {
		// Tags occupy the even positions and are stored as uint64s...
		_, tag, err := itr.NextUint64()
		if err != nil {
			return err
		}
		// ...each followed by its corresponding value.
		_, currVal, err := itr.Next()
		if err != nil {
			return err
		}
		err = cb(tag, currVal)
		if err != nil {
			return err
		}
	}
	return nil
}
func (tt TaggedValues) String() string {
str := "{"
for k, v := range tt {
@@ -260,36 +206,3 @@ func (tt TaggedValues) String() string {
str += "\n}"
return str
}
// CountCellDiffs returns the number of fields that are different between two
// tuples and does not panic if tuples are different lengths.
func CountCellDiffs(from, to types.Tuple, fromSch, toSch schema.Schema) (uint64, error) {
	fromColLen := len(fromSch.GetAllCols().GetColumns())
	toColLen := len(toSch.GetAllCols().GetColumns())
	changed := 0
	f, err := ParseTaggedValues(from)
	if err != nil {
		return 0, err
	}
	t, err := ParseTaggedValues(to)
	if err != nil {
		return 0, err
	}
	for i, v := range f {
		ov, ok := t[i]
		// A tag missing from |t| means the cell became NULL; count that as a
		// modification only when both schemas have the same column count,
		// i.e. the NULL was not caused by a dropped or added column.
		if (!ok && fromColLen == toColLen) || (ok && !v.Equals(ov)) {
			changed++
		}
	}
	// Tags present only in |to| (f[i] is nil for a missing map key) count as
	// changes unconditionally.
	// NOTE(review): unlike the loop above, this direction does not check the
	// column counts — confirm the asymmetry is intentional.
	for i := range t {
		if f[i] == nil {
			changed++
		}
	}
	return uint64(changed), nil
}
@@ -18,12 +18,9 @@ import (
"context"
"fmt"
"github.com/dolthub/dolt/go/libraries/doltcore/row"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
"github.com/dolthub/dolt/go/store/types"
"github.com/dolthub/go-mysql-server/sql"
)
var IdentityConverter = &RowConverter{nil, true, nil}
@@ -96,56 +93,6 @@ func panicOnDuplicateMappings(mapping *FieldMapping) {
}
}
// ConvertWithWarnings takes an input row, maps its columns to their destination columns, performing any type
// conversions needed to create a row of the expected destination schema, and uses the optional WarnFunction
// callback to let callers handle logging a warning when a field cannot be cleanly converted.
// It is a thin exported wrapper around the unexported convert method.
func (rc *RowConverter) ConvertWithWarnings(inRow row.Row, warnFn WarnFunction) (row.Row, error) {
	return rc.convert(inRow, warnFn)
}
// convert takes a row and maps its columns to their destination columns, automatically performing any type conversion
// needed, and using the optional WarnFunction to let callers log warnings on any type conversion errors.
func (rc *RowConverter) convert(inRow row.Row, warnFn WarnFunction) (row.Row, error) {
	// Identity converters perform no mapping; hand the row back untouched.
	if rc.IdentityConverter {
		return inRow, nil
	}
	outTaggedVals := make(row.TaggedValues, len(rc.SrcToDest))
	_, err := inRow.IterCols(func(tag uint64, val types.Value) (stop bool, err error) {
		// Columns without a registered conversion function are dropped.
		convFunc, ok := rc.ConvFuncs[tag]
		if ok {
			outTag := rc.SrcToDest[tag]
			outVal, err := convFunc(val)
			// On a coercion failure, downgrade the error to a warning (when a
			// warn callback was supplied) and store NULL in its place.
			if sql.ErrInvalidValue.Is(err) && warnFn != nil {
				col, _ := rc.SrcSch.GetAllCols().GetByTag(tag)
				warnFn(DatatypeCoercionFailureWarningCode, DatatypeCoercionFailureWarning, col.Name)
				outVal = types.NullValue
				err = nil
			}
			if err != nil {
				return false, err
			}
			// NULL results are simply omitted from the output row.
			if types.IsNull(outVal) {
				return false, nil
			}
			outTaggedVals[outTag] = outVal
		}
		return false, nil
	})
	if err != nil {
		return nil, err
	}
	return row.New(inRow.Format(), rc.DestSch, outTaggedVals)
}
func IsNecessary(srcSch, destSch schema.Schema, destToSrc map[uint64]uint64) (bool, error) {
srcCols := srcSch.GetAllCols()
destCols := destSch.GetAllCols()
@@ -42,18 +42,6 @@ const (
NotNullConstraintType = "not_null"
)
// ColConstraintFromTypeAndParams takes in a string representing the type of
// the constraint and a map of parameters that can be used to determine the
// behavior of the constraint. For example, a range-validating constraint might
// use the type "in_range_constraint" with parameters {"min": -10, "max": 10}.
// An unrecognized constraint type causes a panic.
func ColConstraintFromTypeAndParams(colCnstType string, params map[string]string) ColConstraint {
	if colCnstType == NotNullConstraintType {
		return NotNullConstraint{}
	}
	panic("Unknown column constraint type: " + colCnstType)
}
// NotNullConstraint validates that a value is not null. It does not restrict 0 length strings, or 0 valued ints, or
// anything other than non nil values. The type is stateless, so the zero value is ready to use.
type NotNullConstraint struct{}
@@ -24,7 +24,6 @@ import (
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/vitess/go/sqltypes"
"github.com/dolthub/dolt/go/store/types"
)
@@ -45,38 +44,10 @@ type blobStringType struct {
var _ TypeInfo = (*blobStringType)(nil)
var (
TinyTextType TypeInfo = &blobStringType{sqlStringType: gmstypes.TinyText}
TextType TypeInfo = &blobStringType{sqlStringType: gmstypes.Text}
MediumTextType TypeInfo = &blobStringType{sqlStringType: gmstypes.MediumText}
LongTextType TypeInfo = &blobStringType{sqlStringType: gmstypes.LongText}
TextType TypeInfo = &blobStringType{sqlStringType: gmstypes.Text}
LongTextType TypeInfo = &blobStringType{sqlStringType: gmstypes.LongText}
)
// CreateBlobStringTypeFromParams builds a blob-backed string TypeInfo from a
// serialized parameter map. Both the collation and max-length params are
// required; an error is returned when either is missing or unparsable.
func CreateBlobStringTypeFromParams(params map[string]string) (TypeInfo, error) {
	collationStr, ok := params[blobStringTypeParam_Collate]
	if !ok {
		return nil, fmt.Errorf(`create blobstring type info is missing param "%v"`, blobStringTypeParam_Collate)
	}
	collation, err := sql.ParseCollation("", collationStr, false)
	if err != nil {
		return nil, err
	}
	maxLengthStr, ok := params[blobStringTypeParam_Length]
	if !ok {
		return nil, fmt.Errorf(`create blobstring type info is missing param "%v"`, blobStringTypeParam_Length)
	}
	length, err := strconv.ParseInt(maxLengthStr, 10, 64)
	if err != nil {
		return nil, err
	}
	// The underlying SQL type is always a Text variant sized by |length|.
	sqlType, err := gmstypes.CreateString(sqltypes.Text, length, collation)
	if err != nil {
		return nil, err
	}
	return &blobStringType{sqlType}, nil
}
// ConvertNomsValueToValue implements TypeInfo interface.
func (ti *blobStringType) ConvertNomsValueToValue(v types.Value) (interface{}, error) {
if val, ok := v.(types.Blob); ok {
@@ -116,29 +116,6 @@ func generateSetType(t *testing.T, numOfElements int) *setType {
return &setType{gmstypes.MustCreateSetType(vals, sql.Collation_Default)}
}
// generateInlineBlobTypes produces |numOfTypes| inline blob TypeInfos with
// lengths drawn from the 1..500 range, alternating padded and non-padded
// behavior by parity of the length.
func generateInlineBlobTypes(t *testing.T, numOfTypes uint16) []TypeInfo {
	var res []TypeInfo
	loop(t, 1, 500, numOfTypes, func(i int64) {
		pad := false
		// Even lengths request the fixed-width (padded) binary form.
		if i%2 == 0 {
			pad = true
		}
		res = append(res, generateInlineBlobType(t, i, pad))
	})
	return res
}
// generateInlineBlobType builds an inlineBlobType of the given length.
// When |pad| is set it first attempts a fixed-width BINARY type and silently
// falls back to VARBINARY if that length is rejected (preserved behavior).
func generateInlineBlobType(t *testing.T, length int64, pad bool) *inlineBlobType {
	require.True(t, length > 0)
	if pad {
		// Named binTyp rather than t: the original shadowed the *testing.T
		// parameter here, which the go vet shadow analyzer flags.
		binTyp, err := gmstypes.CreateBinary(sqltypes.Binary, length)
		if err == nil {
			return &inlineBlobType{binTyp}
		}
	}
	return &inlineBlobType{gmstypes.MustCreateBinary(sqltypes.VarBinary, length)}
}
func generateVarStringTypes(t *testing.T, numOfTypes uint16) []TypeInfo {
var res []TypeInfo
loop(t, 1, 500, numOfTypes, func(i int64) {
@@ -17,7 +17,6 @@ package typeinfo
import (
"context"
"fmt"
"strconv"
"time"
"github.com/dolthub/go-mysql-server/sql"
@@ -50,39 +49,6 @@ func CreateDatetimeTypeFromSqlType(typ sql.DatetimeType) *datetimeType {
return &datetimeType{typ}
}
// CreateDatetimeTypeFromParams builds a date/datetime/timestamp TypeInfo from
// a serialized parameter map. The sql-flavor param is required; the precision
// param is optional and defaults to 6 (maximum fractional seconds).
func CreateDatetimeTypeFromParams(params map[string]string) (TypeInfo, error) {
	if sqlType, ok := params[datetimeTypeParam_SQL]; ok {
		precision := 6
		if precisionParam, ok := params[datetimeTypeParam_Precision]; ok {
			var err error
			precision, err = strconv.Atoi(precisionParam)
			if err != nil {
				return nil, err
			}
		}
		switch sqlType {
		case datetimeTypeParam_SQL_Date:
			// DATE carries no precision; reuse the shared singleton.
			return DateType, nil
		case datetimeTypeParam_SQL_Datetime:
			gmsType, err := gmstypes.CreateDatetimeType(sqltypes.Datetime, precision)
			if err != nil {
				return nil, err
			}
			return CreateDatetimeTypeFromSqlType(gmsType), nil
		case datetimeTypeParam_SQL_Timestamp:
			gmsType, err := gmstypes.CreateDatetimeType(sqltypes.Timestamp, precision)
			if err != nil {
				return nil, err
			}
			return CreateDatetimeTypeFromSqlType(gmsType), nil
		default:
			return nil, fmt.Errorf(`create datetime type info has invalid param "%v"`, sqlType)
		}
	} else {
		return nil, fmt.Errorf(`create datetime type info is missing param "%v"`, datetimeTypeParam_SQL)
	}
}
// ConvertNomsValueToValue implements TypeInfo interface.
func (ti *datetimeType) ConvertNomsValueToValue(v types.Value) (interface{}, error) {
if val, ok := v.(types.Timestamp); ok {
@@ -121,7 +87,7 @@ func (ti *datetimeType) ReadFrom(_ *types.NomsBinFormat, reader types.CodecReade
// ConvertValueToNomsValue implements TypeInfo interface.
func (ti *datetimeType) ConvertValueToNomsValue(ctx context.Context, vrw types.ValueReadWriter, v interface{}) (types.Value, error) {
//TODO: handle the zero value as a special case that is valid for all ranges
// TODO: handle the zero value as a special case that is valid for all ranges
if v == nil {
return types.NullValue, nil
}
@@ -16,12 +16,10 @@ package typeinfo
import (
"context"
"encoding/gob"
"fmt"
"strings"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/dolt/go/store/types"
)
@@ -39,34 +37,6 @@ type enumType struct {
var _ TypeInfo = (*enumType)(nil)
// CreateEnumTypeFromParams builds an enum TypeInfo from a serialized parameter
// map. Both the collation param and the gob-encoded values param are required.
func CreateEnumTypeFromParams(params map[string]string) (TypeInfo, error) {
	var collation sql.CollationID
	var err error
	if collationStr, ok := params[enumTypeParam_Collation]; ok {
		collation, err = sql.ParseCollation("", collationStr, false)
		if err != nil {
			return nil, err
		}
	} else {
		return nil, fmt.Errorf(`create enum type info is missing param "%v"`, enumTypeParam_Collation)
	}
	var values []string
	if valuesStr, ok := params[enumTypeParam_Values]; ok {
		// The allowed enum values are stored as a gob-encoded []string.
		dec := gob.NewDecoder(strings.NewReader(valuesStr))
		err = dec.Decode(&values)
		if err != nil {
			return nil, err
		}
	} else {
		return nil, fmt.Errorf(`create enum type info is missing param "%v"`, enumTypeParam_Values)
	}
	sqlEnumType, err := gmstypes.CreateEnumType(values, collation)
	if err != nil {
		return nil, err
	}
	return CreateEnumTypeFromSqlEnumType(sqlEnumType), nil
}
// CreateEnumTypeFromSqlEnumType wraps a GMS enum type in a TypeInfo.
func CreateEnumTypeFromSqlEnumType(sqlEnumType sql.EnumType) TypeInfo {
	return &enumType{sqlEnumType}
}
@@ -19,15 +19,10 @@ import (
"fmt"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/dolt/go/store/types"
)
const (
extendedTypeParams_string_encoded = "string_encoded"
)
// extendedType is a type that refers to an ExtendedType in GMS. These are only supported in the new format, and have many
// more limitations than traditional types (for now).
type extendedType struct {
@@ -36,18 +31,6 @@ type extendedType struct {
var _ TypeInfo = (*extendedType)(nil)
// CreateExtendedTypeFromParams creates a TypeInfo from the given parameter map.
// The string-encoded form of the extended type is required; an error is
// returned when the param is absent or fails to deserialize.
func CreateExtendedTypeFromParams(params map[string]string) (TypeInfo, error) {
	if encodedString, ok := params[extendedTypeParams_string_encoded]; ok {
		t, err := gmstypes.DeserializeTypeFromString(encodedString)
		if err != nil {
			return nil, err
		}
		return &extendedType{t}, nil
	}
	return nil, fmt.Errorf(`create extended type info is missing "%v" param`, extendedTypeParams_string_encoded)
}
// CreateExtendedTypeFromSqlType creates a TypeInfo from the given extended type.
func CreateExtendedTypeFromSqlType(typ sql.ExtendedType) TypeInfo {
return &extendedType{typ}
@@ -30,12 +30,6 @@ import (
type FloatWidth int8
const (
floatTypeParam_Width = "width"
floatTypeParam_Width_32 = "32"
floatTypeParam_Width_64 = "64"
)
type floatType struct {
sqlFloatType sql.NumberType
}
@@ -46,20 +40,6 @@ var (
Float64Type = &floatType{gmstypes.Float64}
)
// CreateFloatTypeFromParams resolves a float TypeInfo from a serialized
// parameter map; the width param is required and must be "32" or "64".
func CreateFloatTypeFromParams(params map[string]string) (TypeInfo, error) {
	width, ok := params[floatTypeParam_Width]
	if !ok {
		return nil, fmt.Errorf(`create float type info is missing "%v" param`, floatTypeParam_Width)
	}
	switch width {
	case floatTypeParam_Width_32:
		return Float32Type, nil
	case floatTypeParam_Width_64:
		return Float64Type, nil
	default:
		return nil, fmt.Errorf(`create float type info has "%v" param with value "%v"`, floatTypeParam_Width, width)
	}
}
// ConvertNomsValueToValue implements TypeInfo interface.
func (ti *floatType) ConvertNomsValueToValue(v types.Value) (interface{}, error) {
if val, ok := v.(types.Float); ok {
@@ -17,7 +17,6 @@ package typeinfo
import (
"context"
"fmt"
"strconv"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
@@ -289,28 +288,6 @@ func geometryTypeConverter(ctx context.Context, src *geometryType, destTi TypeIn
}
}
// CreateGeometryTypeFromParams builds a geometry TypeInfo from a serialized
// parameter map. "SRID" (base-10 uint32) and "DefinedSRID" (bool) are both
// optional and default to their zero values; a present-but-unparsable value
// returns an error.
func CreateGeometryTypeFromParams(params map[string]string) (TypeInfo, error) {
	var (
		err     error
		sridVal uint64
		def     bool
	)
	if s, ok := params["SRID"]; ok {
		sridVal, err = strconv.ParseUint(s, 10, 32)
		if err != nil {
			return nil, err
		}
	}
	if d, ok := params["DefinedSRID"]; ok {
		def, err = strconv.ParseBool(d)
		if err != nil {
			return nil, err
		}
	}
	return CreateGeometryTypeFromSqlGeometryType(gmstypes.GeometryType{SRID: uint32(sridVal), DefinedSRID: def}), nil
}
// CreateGeometryTypeFromSqlGeometryType wraps a GMS geometry type in a TypeInfo.
func CreateGeometryTypeFromSqlGeometryType(sqlGeometryType gmstypes.GeometryType) TypeInfo {
	return &geometryType{sqlGeometryType: sqlGeometryType}
}
@@ -17,7 +17,6 @@ package typeinfo
import (
"context"
"fmt"
"strconv"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
@@ -197,25 +196,3 @@ func geomcollTypeConverter(ctx context.Context, src *geomcollType, destTi TypeIn
return nil, false, UnhandledTypeConversion.New(src.String(), destTi.String())
}
}
// CreateGeomCollTypeFromParams builds a geometry-collection TypeInfo from a
// serialized parameter map; the optional "SRID" and "DefinedSRID" params
// default to their zero values when absent.
func CreateGeomCollTypeFromParams(params map[string]string) (TypeInfo, error) {
	var (
		err     error
		sridVal uint64
		def     bool
	)
	if s, ok := params["SRID"]; ok {
		sridVal, err = strconv.ParseUint(s, 10, 32)
		if err != nil {
			return nil, err
		}
	}
	if d, ok := params["DefinedSRID"]; ok {
		def, err = strconv.ParseBool(d)
		if err != nil {
			return nil, err
		}
	}
	return &geomcollType{sqlGeomCollType: gmstypes.GeomCollType{SRID: uint32(sridVal), DefinedSRID: def}}, nil
}
@@ -22,19 +22,11 @@ import (
"unsafe"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/vitess/go/sqltypes"
"github.com/dolthub/dolt/go/store/types"
)
const (
inlineBlobTypeParam_Length = "length"
inlineBlobTypeParam_SQL = "sql"
inlineBlobTypeParam_SQL_Binary = "bin"
inlineBlobTypeParam_SQL_VarBinary = "varbin"
)
// inlineBlobType handles BINARY and VARBINARY. BLOB types are handled by varBinaryType.
type inlineBlobType struct {
sqlBinaryType sql.StringType
@@ -42,40 +34,6 @@ type inlineBlobType struct {
var _ TypeInfo = (*inlineBlobType)(nil)
var (
VarbinaryDefaultType = &inlineBlobType{gmstypes.MustCreateBinary(sqltypes.VarBinary, 16383)}
)
// CreateInlineBlobTypeFromParams builds a BINARY/VARBINARY TypeInfo from a
// serialized parameter map. Both the length param and the sql-flavor param
// are required.
func CreateInlineBlobTypeFromParams(params map[string]string) (TypeInfo, error) {
	var length int64
	var err error
	if lengthStr, ok := params[inlineBlobTypeParam_Length]; ok {
		length, err = strconv.ParseInt(lengthStr, 10, 64)
		if err != nil {
			return nil, err
		}
	} else {
		return nil, fmt.Errorf(`create inlineblob type info is missing param "%v"`, inlineBlobTypeParam_Length)
	}
	if sqlStr, ok := params[inlineBlobTypeParam_SQL]; ok {
		var sqlType sql.StringType
		switch sqlStr {
		case inlineBlobTypeParam_SQL_Binary:
			// Fixed-width BINARY(length).
			sqlType, err = gmstypes.CreateBinary(sqltypes.Binary, length)
		case inlineBlobTypeParam_SQL_VarBinary:
			// Variable-width VARBINARY(length).
			sqlType, err = gmstypes.CreateBinary(sqltypes.VarBinary, length)
		default:
			return nil, fmt.Errorf(`create inlineblob type info has "%v" param with value "%v"`, inlineBlobTypeParam_SQL, sqlStr)
		}
		if err != nil {
			return nil, err
		}
		return &inlineBlobType{sqlType}, nil
	} else {
		return nil, fmt.Errorf(`create inlineblob type info is missing param "%v"`, inlineBlobTypeParam_SQL)
	}
}
// ConvertNomsValueToValue implements TypeInfo interface.
func (ti *inlineBlobType) ConvertNomsValueToValue(v types.Value) (interface{}, error) {
if val, ok := v.(types.InlineBlob); ok {
@@ -26,15 +26,6 @@ import (
"github.com/dolthub/dolt/go/store/types"
)
const (
intTypeParams_Width = "width"
intTypeParams_Width_8 = "8"
intTypeParams_Width_16 = "16"
intTypeParams_Width_24 = "24"
intTypeParams_Width_32 = "32"
intTypeParams_Width_64 = "64"
)
type intType struct {
sqlIntType sql.NumberType
}
@@ -48,26 +39,6 @@ var (
Int64Type = &intType{gmstypes.Int64}
)
// CreateIntTypeFromParams resolves an int TypeInfo from a serialized parameter
// map; the width param is required and must be one of 8, 16, 24, 32 or 64.
func CreateIntTypeFromParams(params map[string]string) (TypeInfo, error) {
	width, ok := params[intTypeParams_Width]
	if !ok {
		return nil, fmt.Errorf(`create int type info is missing "%v" param`, intTypeParams_Width)
	}
	switch width {
	case intTypeParams_Width_8:
		return Int8Type, nil
	case intTypeParams_Width_16:
		return Int16Type, nil
	case intTypeParams_Width_24:
		return Int24Type, nil
	case intTypeParams_Width_32:
		return Int32Type, nil
	case intTypeParams_Width_64:
		return Int64Type, nil
	default:
		return nil, fmt.Errorf(`create int type info has "%v" param with value "%v"`, intTypeParams_Width, width)
	}
}
// ConvertNomsValueToValue implements TypeInfo interface.
func (ti *intType) ConvertNomsValueToValue(v types.Value) (interface{}, error) {
if val, ok := v.(types.Int); ok {
@@ -17,7 +17,6 @@ package typeinfo
import (
"context"
"fmt"
"strconv"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
@@ -197,24 +196,3 @@ func linestringTypeConverter(ctx context.Context, src *linestringType, destTi Ty
return nil, false, UnhandledTypeConversion.New(src.String(), destTi.String())
}
}
// CreateLineStringTypeFromParams builds a linestring TypeInfo from a
// serialized parameter map; the optional "SRID" and "DefinedSRID" params
// default to their zero values when absent.
func CreateLineStringTypeFromParams(params map[string]string) (TypeInfo, error) {
	var (
		err     error
		sridVal uint64
		def     bool
	)
	if s, ok := params["SRID"]; ok {
		sridVal, err = strconv.ParseUint(s, 10, 32)
		if err != nil {
			return nil, err
		}
	}
	if d, ok := params["DefinedSRID"]; ok {
		def, err = strconv.ParseBool(d)
		if err != nil {
			return nil, err
		}
	}
	return &linestringType{sqlLineStringType: gmstypes.LineStringType{SRID: uint32(sridVal), DefinedSRID: def}}, nil
}
@@ -17,7 +17,6 @@ package typeinfo
import (
"context"
"fmt"
"strconv"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
@@ -197,25 +196,3 @@ func multilinestringTypeConverter(ctx context.Context, src *multilinestringType,
return nil, false, UnhandledTypeConversion.New(src.String(), destTi.String())
}
}
// CreateMultiLineStringTypeFromParams builds a multilinestring TypeInfo from a
// serialized parameter map; the optional "SRID" and "DefinedSRID" params
// default to their zero values when absent.
func CreateMultiLineStringTypeFromParams(params map[string]string) (TypeInfo, error) {
	var (
		err     error
		sridVal uint64
		def     bool
	)
	if s, ok := params["SRID"]; ok {
		sridVal, err = strconv.ParseUint(s, 10, 32)
		if err != nil {
			return nil, err
		}
	}
	if d, ok := params["DefinedSRID"]; ok {
		def, err = strconv.ParseBool(d)
		if err != nil {
			return nil, err
		}
	}
	return &multilinestringType{sqlMultiLineStringType: gmstypes.MultiLineStringType{SRID: uint32(sridVal), DefinedSRID: def}}, nil
}
@@ -17,7 +17,6 @@ package typeinfo
import (
"context"
"fmt"
"strconv"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
@@ -197,24 +196,3 @@ func multipointTypeConverter(ctx context.Context, src *multipointType, destTi Ty
return nil, false, UnhandledTypeConversion.New(src.String(), destTi.String())
}
}
// CreateMultiPointTypeFromParams builds a multipoint TypeInfo from a
// serialized parameter map; the optional "SRID" and "DefinedSRID" params
// default to their zero values when absent.
func CreateMultiPointTypeFromParams(params map[string]string) (TypeInfo, error) {
	var (
		err     error
		sridVal uint64
		def     bool
	)
	if s, ok := params["SRID"]; ok {
		sridVal, err = strconv.ParseUint(s, 10, 32)
		if err != nil {
			return nil, err
		}
	}
	if d, ok := params["DefinedSRID"]; ok {
		def, err = strconv.ParseBool(d)
		if err != nil {
			return nil, err
		}
	}
	return &multipointType{sqlMultiPointType: gmstypes.MultiPointType{SRID: uint32(sridVal), DefinedSRID: def}}, nil
}
@@ -17,7 +17,6 @@ package typeinfo
import (
"context"
"fmt"
"strconv"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
@@ -197,25 +196,3 @@ func multipolygonTypeConverter(ctx context.Context, src *multipolygonType, destT
return nil, false, UnhandledTypeConversion.New(src.String(), destTi.String())
}
}
// CreateMultiPolygonTypeFromParams builds a multipolygon TypeInfo from a
// serialized parameter map; the optional "SRID" and "DefinedSRID" params
// default to their zero values when absent.
func CreateMultiPolygonTypeFromParams(params map[string]string) (TypeInfo, error) {
	var (
		err     error
		sridVal uint64
		def     bool
	)
	if s, ok := params["SRID"]; ok {
		sridVal, err = strconv.ParseUint(s, 10, 32)
		if err != nil {
			return nil, err
		}
	}
	if d, ok := params["DefinedSRID"]; ok {
		def, err = strconv.ParseBool(d)
		if err != nil {
			return nil, err
		}
	}
	return &multipolygonType{sqlMultiPolygonType: gmstypes.MultiPolygonType{SRID: uint32(sridVal), DefinedSRID: def}}, nil
}
@@ -17,7 +17,6 @@ package typeinfo
import (
"context"
"fmt"
"strconv"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
@@ -198,28 +197,6 @@ func pointTypeConverter(ctx context.Context, src *pointType, destTi TypeInfo) (t
}
}
// CreatePointTypeFromParams constructs a point TypeInfo from a serialized parameter
// map. The optional "SRID" entry is parsed as a uint32 and the optional
// "DefinedSRID" entry as a bool; absent entries default to their zero values.
func CreatePointTypeFromParams(params map[string]string) (TypeInfo, error) {
	var srid uint32
	if raw, ok := params["SRID"]; ok {
		parsed, err := strconv.ParseUint(raw, 10, 32)
		if err != nil {
			return nil, err
		}
		srid = uint32(parsed)
	}
	var definedSRID bool
	if raw, ok := params["DefinedSRID"]; ok {
		parsed, err := strconv.ParseBool(raw)
		if err != nil {
			return nil, err
		}
		definedSRID = parsed
	}
	return CreatePointTypeFromSqlPointType(gmstypes.PointType{SRID: srid, DefinedSRID: definedSRID}), nil
}
// CreatePointTypeFromSqlPointType wraps the given go-mysql-server point type in a
// dolt TypeInfo.
func CreatePointTypeFromSqlPointType(sqlPointType gmstypes.PointType) TypeInfo {
	return &pointType{sqlPointType: sqlPointType}
}
@@ -17,7 +17,6 @@ package typeinfo
import (
"context"
"fmt"
"strconv"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
@@ -197,25 +196,3 @@ func polygonTypeConverter(ctx context.Context, src *polygonType, destTi TypeInfo
return nil, false, UnhandledTypeConversion.New(src.String(), destTi.String())
}
}
// CreatePolygonTypeFromParams constructs a polygon TypeInfo from a serialized
// parameter map. The optional "SRID" entry is parsed as a uint32 and the optional
// "DefinedSRID" entry as a bool; absent entries default to their zero values.
func CreatePolygonTypeFromParams(params map[string]string) (TypeInfo, error) {
	var srid uint32
	if raw, ok := params["SRID"]; ok {
		parsed, err := strconv.ParseUint(raw, 10, 32)
		if err != nil {
			return nil, err
		}
		srid = uint32(parsed)
	}
	var definedSRID bool
	if raw, ok := params["DefinedSRID"]; ok {
		parsed, err := strconv.ParseBool(raw)
		if err != nil {
			return nil, err
		}
		definedSRID = parsed
	}
	return &polygonType{sqlPolygonType: gmstypes.PolygonType{SRID: srid, DefinedSRID: definedSRID}}, nil
}
@@ -16,21 +16,14 @@ package typeinfo
import (
"context"
"encoding/gob"
"fmt"
"strings"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/dolt/go/store/types"
)
const (
setTypeParam_Collation = "collate"
setTypeParam_Values = "vals"
)
// This is a dolt implementation of the MySQL type Set, thus most of the functionality
// within is directly reliant on the go-mysql-server implementation.
type setType struct {
@@ -39,33 +32,6 @@ type setType struct {
var _ TypeInfo = (*setType)(nil)
// CreateSetTypeFromParams deserializes a SET TypeInfo from its parameter map.
// The "collate" param names the collation and the "vals" param holds the
// gob-encoded []string of set members; both are required.
func CreateSetTypeFromParams(params map[string]string) (TypeInfo, error) {
	collationStr, ok := params[setTypeParam_Collation]
	if !ok {
		return nil, fmt.Errorf(`create set type info is missing param "%v"`, setTypeParam_Collation)
	}
	collation, err := sql.ParseCollation("", collationStr, false)
	if err != nil {
		return nil, err
	}
	encodedVals, ok := params[setTypeParam_Values]
	if !ok {
		return nil, fmt.Errorf(`create set type info is missing param "%v"`, setTypeParam_Values)
	}
	// The set members were serialized with encoding/gob; decode them back
	// into a string slice before rebuilding the SQL type.
	var members []string
	if err = gob.NewDecoder(strings.NewReader(encodedVals)).Decode(&members); err != nil {
		return nil, err
	}
	sqlSetType, err := gmstypes.CreateSetType(members, collation)
	if err != nil {
		return nil, err
	}
	return CreateSetTypeFromSqlSetType(sqlSetType), nil
}
// CreateSetTypeFromSqlSetType wraps the given go-mysql-server SET type in a dolt
// TypeInfo.
func CreateSetTypeFromSqlSetType(sqlSetType sql.SetType) TypeInfo {
	return &setType{sqlSetType}
}
@@ -26,15 +26,6 @@ import (
"github.com/dolthub/dolt/go/store/types"
)
const (
uintTypeParam_Width = "width"
uintTypeParam_Width_8 = "8"
uintTypeParam_Width_16 = "16"
uintTypeParam_Width_24 = "24"
uintTypeParam_Width_32 = "32"
uintTypeParam_Width_64 = "64"
)
type uintType struct {
sqlUintType sql.NumberType
}
@@ -48,26 +39,6 @@ var (
Uint64Type = &uintType{gmstypes.Uint64}
)
// CreateUintTypeFromParams returns the unsigned-int TypeInfo matching the required
// "width" param, which must be one of "8", "16", "24", "32", or "64".
func CreateUintTypeFromParams(params map[string]string) (TypeInfo, error) {
	width, ok := params[uintTypeParam_Width]
	if !ok {
		return nil, fmt.Errorf(`create uint type info is missing "%v" param`, uintTypeParam_Width)
	}
	// Map the serialized width onto the matching shared singleton type.
	switch width {
	case uintTypeParam_Width_8:
		return Uint8Type, nil
	case uintTypeParam_Width_16:
		return Uint16Type, nil
	case uintTypeParam_Width_24:
		return Uint24Type, nil
	case uintTypeParam_Width_32:
		return Uint32Type, nil
	case uintTypeParam_Width_64:
		return Uint64Type, nil
	default:
		return nil, fmt.Errorf(`create uint type info has "%v" param with value "%v"`, uintTypeParam_Width, width)
	}
}
// ConvertNomsValueToValue implements TypeInfo interface.
func (ti *uintType) ConvertNomsValueToValue(v types.Value) (interface{}, error) {
if val, ok := v.(types.Uint); ok {
@@ -16,7 +16,6 @@ package typeinfo
import (
"context"
"encoding/binary"
"fmt"
"io"
"strconv"
@@ -24,16 +23,10 @@ import (
"unsafe"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/vitess/go/sqltypes"
"github.com/dolthub/dolt/go/store/types"
)
const (
varBinaryTypeParam_Length = "length"
)
// As a type, this is modeled more after MySQL's story for binary data. There, it's treated
// as a string that is interpreted as raw bytes, rather than as a bespoke data structure,
// and thus this is mirrored here in its implementation. This will minimize any differences
@@ -46,31 +39,6 @@ type varBinaryType struct {
var _ TypeInfo = (*varBinaryType)(nil)
var (
TinyBlobType TypeInfo = &varBinaryType{sqlBinaryType: gmstypes.TinyBlob}
BlobType TypeInfo = &varBinaryType{sqlBinaryType: gmstypes.Blob}
MediumBlobType TypeInfo = &varBinaryType{sqlBinaryType: gmstypes.MediumBlob}
LongBlobType TypeInfo = &varBinaryType{sqlBinaryType: gmstypes.LongBlob}
)
// CreateVarBinaryTypeFromParams deserializes a varbinary/blob TypeInfo from its
// parameter map. The required "length" param is the maximum byte length.
func CreateVarBinaryTypeFromParams(params map[string]string) (TypeInfo, error) {
	lengthStr, ok := params[varBinaryTypeParam_Length]
	if !ok {
		return nil, fmt.Errorf(`create varbinary type info is missing param "%v"`, varBinaryTypeParam_Length)
	}
	length, err := strconv.ParseInt(lengthStr, 10, 64)
	if err != nil {
		return nil, err
	}
	sqlType, err := gmstypes.CreateBinary(sqltypes.Blob, length)
	if err != nil {
		return nil, err
	}
	return &varBinaryType{sqlType}, nil
}
// ConvertNomsValueToValue implements TypeInfo interface.
func (ti *varBinaryType) ConvertNomsValueToValue(v types.Value) (interface{}, error) {
if val, ok := v.(types.Blob); ok {
@@ -199,29 +167,6 @@ func fromBlob(b types.Blob) ([]byte, error) {
return str, nil
}
// hasPrefix finds out if a Blob has a prefixed integer. Initially blobs for varBinary prepended an integer indicating
// the length, which was unnecessary (as the underlying sequence tracks the total size). It's been removed, but this
// may be used to see if a Blob is one of those older Blobs. A false positive is possible, but EXTREMELY unlikely.
// hasPrefix finds out if a Blob has a prefixed integer. Initially blobs for varBinary
// prepended an integer indicating the length, which was unnecessary (as the underlying
// sequence tracks the total size). It's been removed, but this may be used to see if a
// Blob is one of those older Blobs. A false positive is possible, but EXTREMELY unlikely.
func hasPrefix(b types.Blob, ctx context.Context) (bool, error) {
	total := b.Len()
	// Too short to even hold the 8-byte count prefix.
	if total < 8 {
		return false, nil
	}
	header := make([]byte, 8)
	read, err := b.ReadAt(ctx, header, 0)
	if err != nil {
		return false, err
	}
	if read != 8 {
		return false, fmt.Errorf("wanted 8 bytes from blob for count, got %d", read)
	}
	// An old-style blob stores its payload length (total minus the prefix itself)
	// as a little-endian uint64 in the first 8 bytes.
	return binary.LittleEndian.Uint64(header) == total-8, nil
}
// varBinaryTypeConverter is an internal function for GetTypeConverter that handles the specific type as the source TypeInfo.
func varBinaryTypeConverter(ctx context.Context, src *varBinaryType, destTi TypeInfo) (tc TypeConverter, needsConversion bool, err error) {
switch dest := destTi.(type) {
@@ -28,15 +28,6 @@ import (
"github.com/dolthub/dolt/go/store/types"
)
const (
varStringTypeParam_Collate = "collate"
varStringTypeParam_Length = "length"
varStringTypeParam_SQL = "sql"
varStringTypeParam_SQL_Char = "char"
varStringTypeParam_SQL_VarChar = "varchar"
varStringTypeParam_SQL_Text = "text"
)
// varStringType handles CHAR and VARCHAR. The TEXT types are handled by blobStringType. For any repositories that were
// created before the introduction of blobStringType, they will use varStringType for TEXT types. As varStringType makes
// use of the String Value type, it does not actually support all viable lengths of a TEXT string, meaning all such
@@ -60,48 +51,6 @@ func CreateVarStringTypeFromSqlType(stringType sql.StringType) TypeInfo {
return &varStringType{stringType}
}
// CreateVarStringTypeFromParams deserializes a CHAR/VARCHAR/TEXT TypeInfo from its
// parameter map. All three params are required: "collate" (collation name),
// "length" (maximum length), and "sql" (one of "char", "varchar", "text").
func CreateVarStringTypeFromParams(params map[string]string) (TypeInfo, error) {
	collationStr, ok := params[varStringTypeParam_Collate]
	if !ok {
		return nil, fmt.Errorf(`create varstring type info is missing param "%v"`, varStringTypeParam_Collate)
	}
	collation, err := sql.ParseCollation("", collationStr, false)
	if err != nil {
		return nil, err
	}
	maxLengthStr, ok := params[varStringTypeParam_Length]
	if !ok {
		return nil, fmt.Errorf(`create varstring type info is missing param "%v"`, varStringTypeParam_Length)
	}
	length, err := strconv.ParseInt(maxLengthStr, 10, 64)
	if err != nil {
		return nil, err
	}
	sqlStr, ok := params[varStringTypeParam_SQL]
	if !ok {
		// Bug fix: this error previously reported the "length" param as the
		// missing one; it is actually the "sql" param that is absent here.
		return nil, fmt.Errorf(`create varstring type info is missing param "%v"`, varStringTypeParam_SQL)
	}
	var sqlType sql.StringType
	switch sqlStr {
	case varStringTypeParam_SQL_Char:
		sqlType, err = gmstypes.CreateString(sqltypes.Char, length, collation)
	case varStringTypeParam_SQL_VarChar:
		sqlType, err = gmstypes.CreateString(sqltypes.VarChar, length, collation)
	case varStringTypeParam_SQL_Text:
		sqlType, err = gmstypes.CreateString(sqltypes.Text, length, collation)
	default:
		return nil, fmt.Errorf(`create varstring type info has "%v" param with value "%v"`, varStringTypeParam_SQL, sqlStr)
	}
	if err != nil {
		return nil, err
	}
	return &varStringType{sqlType}, nil
}
// ConvertNomsValueToValue implements TypeInfo interface.
func (ti *varStringType) ConvertNomsValueToValue(v types.Value) (interface{}, error) {
if val, ok := v.(types.String); ok {
@@ -434,12 +434,8 @@ func TestDropPks(t *testing.T) {
ctx := context.Background()
dEnv := dtestutils.CreateTestEnv()
defer dEnv.DoltDB(ctx).Close()
tmpDir, err := dEnv.TempTableFilesDir()
require.NoError(t, err)
deaf, err := dEnv.DbEaFactory(ctx)
require.NoError(t, err)
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(ctx), opts)
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(ctx), editor.Options{})
require.NoError(t, err)
root, _ := dEnv.WorkingRoot(ctx)
@@ -841,9 +841,7 @@ func getTableWriter(ctx *sql.Context, engine *gms.Engine, tableName, databaseNam
return nil, nil, err
}
options := sqlDatabase.EditOptions()
options.ForeignKeyChecksDisabled = foreignKeyChecksDisabled
writeSession := writer.NewWriteSession(binFormat, ws, tracker, options)
writeSession := writer.NewWriteSession(binFormat, ws, tracker, sqlDatabase.EditOptions())
ds := dsess.DSessFromSess(ctx.Session)
setter := ds.SetWorkingRoot
+2 -12
View File
@@ -39,12 +39,7 @@ type SetupFn func(t *testing.T, dEnv *env.DoltEnv)
// Runs the query given and returns the result. The schema result of the query's execution is currently ignored, and
// the targetSchema given is used to prepare all rows.
func executeSelect(t *testing.T, ctx context.Context, dEnv *env.DoltEnv, root doltdb.RootValue, query string) ([]sql.Row, sql.Schema, error) {
tmpDir, err := dEnv.TempTableFilesDir()
require.NoError(t, err)
deaf, err := dEnv.DbEaFactory(ctx)
require.NoError(t, err)
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(ctx), opts)
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(ctx), editor.Options{})
require.NoError(t, err)
engine, sqlCtx, err := NewTestEngine(dEnv, ctx, db)
@@ -72,12 +67,7 @@ func executeSelect(t *testing.T, ctx context.Context, dEnv *env.DoltEnv, root do
// Runs the query given and returns the error (if any).
func executeModify(t *testing.T, ctx context.Context, dEnv *env.DoltEnv, root doltdb.RootValue, query string) (doltdb.RootValue, error) {
tmpDir, err := dEnv.TempTableFilesDir()
require.NoError(t, err)
deaf, err := dEnv.DbEaFactory(ctx)
require.NoError(t, err)
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(ctx), opts)
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(ctx), editor.Options{})
require.NoError(t, err)
engine, sqlCtx, err := NewTestEngine(dEnv, ctx, db)
+4 -1
View File
@@ -1955,7 +1955,10 @@ func (db Database) CreateTable(ctx *sql.Context, tableName string, sch sql.Prima
return err
}
if doltdb.IsSystemTable(doltdb.TableName{Name: tableName, Schema: db.schemaName}) && !doltdb.IsFullTextTable(tableName) && !doltdb.HasDoltCIPrefix(tableName) {
if doltdb.IsSystemTable(doltdb.TableName{Name: tableName, Schema: db.schemaName}) &&
!doltdb.IsFullTextTable(tableName) &&
!doltdb.HasDoltCIPrefix(tableName) &&
tableName != doltdb.TestsTableName { // NM4 - determine why this is required now.
return ErrReservedTableName.New(tableName)
}
@@ -990,23 +990,7 @@ func (p *DoltDatabaseProvider) registerNewDatabase(ctx *sql.Context, name string
// Ensure any provider-supplied DB load params are applied before any lazy DB load occurs.
p.applyDBLoadParamsToEnv(newEnv)
fkChecks, err := ctx.GetSessionVariable(ctx, "foreign_key_checks")
if err != nil {
return err
}
deaf, err := newEnv.DbEaFactory(ctx)
if err != nil {
return err
}
opts := editor.Options{
Deaf: deaf,
// TODO: this doesn't seem right, why is this getting set in the constructor to the DB
ForeignKeyChecksDisabled: fkChecks.(int8) == 0,
}
db, err := NewDatabase(ctx, name, newEnv.DbData(ctx), opts)
db, err := NewDatabase(ctx, name, newEnv.DbData(ctx), editor.Options{})
if err != nil {
return err
}
@@ -44,12 +44,8 @@ func TestDatabaseProvider(t *testing.T) {
setup := func(t *testing.T) (*sqle.Engine, *sql.Context, *DoltDatabaseProvider) {
ctx := context.Background()
dEnv := dtestutils.CreateTestEnv()
tmpDir, err := dEnv.TempTableFilesDir()
require.NoError(t, err)
deaf, err := dEnv.DbEaFactory(ctx)
require.NoError(t, err)
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
db, err := NewDatabase(context.Background(), "dolt", dEnv.DbData(ctx), opts)
db, err := NewDatabase(context.Background(), "dolt", dEnv.DbData(ctx), editor.Options{})
require.NoError(t, err)
engine, sqlCtx, err := NewTestEngine(dEnv, context.Background(), db)
+1 -7
View File
@@ -45,15 +45,9 @@ func TestIsKeyFuncs(t *testing.T) {
func TestNeedsToReloadEvents(t *testing.T) {
ctx := context.Background()
dEnv := dtestutils.CreateTestEnv()
tmpDir, err := dEnv.TempTableFilesDir()
require.NoError(t, err)
deaf, err := dEnv.DbEaFactory(ctx)
require.NoError(t, err)
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
timestamp := time.Now().Truncate(time.Minute).UTC()
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(ctx), opts)
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(ctx), editor.Options{})
require.NoError(t, err)
_, sqlCtx, err := NewTestEngine(dEnv, ctx, db)
@@ -103,6 +103,8 @@ func doDoltCherryPick(ctx *sql.Context, args []string) (string, int, int, int, e
cherryPickOptions.EmptyCommitHandling = doltdb.KeepEmptyCommit
}
cherryPickOptions.SkipVerification = apr.Contains(cli.SkipVerificationFlag)
commit, mergeResult, err := cherry_pick.CherryPick(ctx, cherryStr, cherryPickOptions)
if err != nil {
return "", 0, 0, 0, err
@@ -57,7 +57,8 @@ func doDoltClean(ctx *sql.Context, args []string) (int, error) {
return 1, fmt.Errorf("Could not load database %s", dbName)
}
roots, err = actions.CleanUntracked(ctx, roots, apr.Args, apr.ContainsAll(cli.DryRunFlag), false)
respectIgnoreRules := !apr.Contains(cli.ExcludeIgnoreRulesFlag)
roots, err = actions.CleanUntracked(ctx, roots, apr.Args, apr.ContainsAll(cli.DryRunFlag), false, respectIgnoreRules)
if err != nil {
return 1, fmt.Errorf("failed to clean; %w", err)
}
@@ -113,6 +113,13 @@ func getDirectoryAndUrlString(apr *argparser.ArgParseResults) (string, string, e
} else if dir == "/" {
return "", "", errhand.BuildDError("Could not infer repo name. Please explicitly define a directory for this url").Build()
}
// Match `dolt clone` behavior: strip a trailing `.git` from inferred names.
if strings.HasSuffix(dir, ".git") {
dir = strings.TrimSuffix(dir, ".git")
if dir == "" {
return "", "", errhand.BuildDError("Could not infer repo name. Please explicitly define a directory for this url").Build()
}
}
}
return dir, urlStr, nil
@@ -163,14 +163,15 @@ func doDoltCommit(ctx *sql.Context, args []string) (string, bool, error) {
}
csp := actions.CommitStagedProps{
Message: msg,
Date: t,
AllowEmpty: apr.Contains(cli.AllowEmptyFlag),
SkipEmpty: apr.Contains(cli.SkipEmptyFlag),
Amend: amend,
Force: apr.Contains(cli.ForceFlag),
Name: name,
Email: email,
Message: msg,
Date: t,
AllowEmpty: apr.Contains(cli.AllowEmptyFlag),
SkipEmpty: apr.Contains(cli.SkipEmptyFlag),
Amend: amend,
Force: apr.Contains(cli.ForceFlag),
Name: name,
Email: email,
SkipVerification: apr.Contains(cli.SkipVerificationFlag),
}
shouldSign, err := dsess.GetBooleanSystemVar(ctx, "gpgsign")
@@ -180,7 +180,7 @@ func doDoltMerge(ctx *sql.Context, args []string) (string, int, int, string, err
msg = userMsg
}
ws, commit, conflicts, fastForward, message, err := performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg)
ws, commit, conflicts, fastForward, message, err := performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg, apr.Contains(cli.SkipVerificationFlag))
if err != nil {
return commit, conflicts, fastForward, "", err
}
@@ -205,6 +205,7 @@ func performMerge(
spec *merge.MergeSpec,
noCommit bool,
msg string,
skipVerification bool,
) (*doltdb.WorkingSet, string, int, int, string, error) {
// todo: allow merges even when an existing merge is uncommitted
if ws.MergeActive() {
@@ -234,7 +235,7 @@ func performMerge(
if canFF {
if spec.FFMode == merge.NoFastForward {
var commit *doltdb.Commit
ws, commit, err = executeNoFFMerge(ctx, sess, spec, msg, dbName, ws, noCommit)
ws, commit, err = executeNoFFMerge(ctx, sess, spec, msg, dbName, ws, noCommit, skipVerification)
if err == doltdb.ErrUnresolvedConflictsOrViolations {
// if there are unresolved conflicts, write the resulting working set back to the session and return an
// error message
@@ -306,7 +307,10 @@ func performMerge(
author := fmt.Sprintf("%s <%s>", spec.Name, spec.Email)
args := []string{"-m", msg, "--author", author}
if spec.Force {
args = append(args, "--force")
args = append(args, "--"+cli.ForceFlag)
}
if skipVerification {
args = append(args, "--"+cli.SkipVerificationFlag)
}
commit, _, err = doDoltCommit(ctx, args)
if err != nil {
@@ -405,6 +409,7 @@ func executeNoFFMerge(
dbName string,
ws *doltdb.WorkingSet,
noCommit bool,
skipVerification bool,
) (*doltdb.WorkingSet, *doltdb.Commit, error) {
mergeRoot, err := spec.MergeC.GetRootValue(ctx)
if err != nil {
@@ -444,11 +449,12 @@ func executeNoFFMerge(
}
pendingCommit, err := dSess.NewPendingCommit(ctx, dbName, roots, actions.CommitStagedProps{
Message: msg,
Date: spec.Date,
Force: spec.Force,
Name: spec.Name,
Email: spec.Email,
Message: msg,
Date: spec.Date,
Force: spec.Force,
Name: spec.Name,
Email: spec.Email,
SkipVerification: skipVerification,
})
if err != nil {
return nil, nil, err
@@ -237,7 +237,7 @@ func doDoltPull(ctx *sql.Context, args []string) (int, int, string, error) {
return noConflictsOrViolations, threeWayMerge, "", ErrUncommittedChanges.New()
}
ws, _, conflicts, fastForward, message, err = performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg)
ws, _, conflicts, fastForward, message, err = performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg, apr.Contains(cli.SkipVerificationFlag))
if err != nil && !errors.Is(doltdb.ErrUpToDate, err) {
return conflicts, fastForward, "", err
}
@@ -216,7 +216,9 @@ func doDoltRebase(ctx *sql.Context, args []string) (int, string, error) {
} else if apr.NArg() > 1 {
return 1, "", fmt.Errorf("too many args")
}
err = startRebase(ctx, apr.Arg(0), commitBecomesEmptyHandling, emptyCommitHandling)
skipVerification := apr.Contains(cli.SkipVerificationFlag)
err = startRebase(ctx, apr.Arg(0), commitBecomesEmptyHandling, emptyCommitHandling, skipVerification)
if err != nil {
return 1, "", err
}
@@ -263,7 +265,7 @@ func processCommitBecomesEmptyParams(apr *argparser.ArgParseResults) (doltdb.Emp
// startRebase starts a new interactive rebase operation. |upstreamPoint| specifies the commit where the new rebased
// commits will be based off of, |commitBecomesEmptyHandling| specifies how to handle commits that are not empty, but
// do not produce any changes when applied, and |emptyCommitHandling| specifies how to handle empty commits.
func startRebase(ctx *sql.Context, upstreamPoint string, commitBecomesEmptyHandling doltdb.EmptyCommitHandling, emptyCommitHandling doltdb.EmptyCommitHandling) error {
func startRebase(ctx *sql.Context, upstreamPoint string, commitBecomesEmptyHandling doltdb.EmptyCommitHandling, emptyCommitHandling doltdb.EmptyCommitHandling, skipVerification bool) error {
if upstreamPoint == "" {
return fmt.Errorf("no upstream branch specified")
}
@@ -351,7 +353,7 @@ func startRebase(ctx *sql.Context, upstreamPoint string, commitBecomesEmptyHandl
}
newWorkingSet, err := workingSet.StartRebase(ctx, upstreamCommit, rebaseBranch, branchRoots.Working,
commitBecomesEmptyHandling, emptyCommitHandling)
commitBecomesEmptyHandling, emptyCommitHandling, skipVerification)
if err != nil {
return err
}
@@ -716,7 +718,8 @@ func continueRebase(ctx *sql.Context) rebaseResult {
result := processRebasePlanStep(ctx, &step,
workingSet.RebaseState().CommitBecomesEmptyHandling(),
workingSet.RebaseState().EmptyCommitHandling())
workingSet.RebaseState().EmptyCommitHandling(),
workingSet.RebaseState().SkipVerification())
if result.err != nil || result.status != 0 || result.halt {
return result
}
@@ -803,7 +806,7 @@ func commitManuallyStagedChangesForStep(ctx *sql.Context, step rebase.RebasePlan
}
options, err := createCherryPickOptionsForRebaseStep(ctx, &step, workingSet.RebaseState().CommitBecomesEmptyHandling(),
workingSet.RebaseState().EmptyCommitHandling())
workingSet.RebaseState().EmptyCommitHandling(), workingSet.RebaseState().SkipVerification())
doltDB, ok := doltSession.GetDoltDB(ctx, ctx.GetCurrentDatabase())
if !ok {
@@ -861,6 +864,7 @@ func processRebasePlanStep(
planStep *rebase.RebasePlanStep,
commitBecomesEmptyHandling doltdb.EmptyCommitHandling,
emptyCommitHandling doltdb.EmptyCommitHandling,
skipVerification bool,
) rebaseResult {
// Make sure we have a transaction opened for the session
// NOTE: After our first call to cherry-pick, the tx is committed, so a new tx needs to be started
@@ -878,7 +882,7 @@ func processRebasePlanStep(
return newRebaseSuccess("")
}
options, err := createCherryPickOptionsForRebaseStep(ctx, planStep, commitBecomesEmptyHandling, emptyCommitHandling)
options, err := createCherryPickOptionsForRebaseStep(ctx, planStep, commitBecomesEmptyHandling, emptyCommitHandling, skipVerification)
if err != nil {
return newRebaseError(err)
}
@@ -886,12 +890,19 @@ func processRebasePlanStep(
return handleRebaseCherryPick(ctx, planStep, *options)
}
func createCherryPickOptionsForRebaseStep(ctx *sql.Context, planStep *rebase.RebasePlanStep, commitBecomesEmptyHandling doltdb.EmptyCommitHandling, emptyCommitHandling doltdb.EmptyCommitHandling) (*cherry_pick.CherryPickOptions, error) {
func createCherryPickOptionsForRebaseStep(
ctx *sql.Context,
planStep *rebase.RebasePlanStep,
commitBecomesEmptyHandling doltdb.EmptyCommitHandling,
emptyCommitHandling doltdb.EmptyCommitHandling,
skipVerification bool,
) (*cherry_pick.CherryPickOptions, error) {
// Override the default empty commit handling options for cherry-pick, since
// rebase has slightly different defaults
options := cherry_pick.NewCherryPickOptions()
options.CommitBecomesEmptyHandling = commitBecomesEmptyHandling
options.EmptyCommitHandling = emptyCommitHandling
options.SkipVerification = skipVerification
switch planStep.Action {
case rebase.RebaseActionDrop, rebase.RebaseActionPick, rebase.RebaseActionEdit:
@@ -392,14 +392,6 @@ func parseStashIndex(apr *argparser.ArgParseResults) (int, error) {
return idx, nil
}
// bulkDbEaFactory builds a bulk-import table edit accumulator factory for the given
// database, storing temp files under the repo's temp table files directory.
// NOTE(review): a TempTableFilesDir error is silently swallowed and nil is returned,
// so callers must tolerate a nil factory — confirm this is intentional.
func bulkDbEaFactory(dbData env.DbData[*sql.Context]) editor.DbEaFactory {
	tmpDir, err := dbData.Rsw.TempTableFilesDir()
	if err != nil {
		return nil
	}
	return editor.NewBulkImportTEAFactory(dbData.Ddb.ValueReadWriter(), tmpDir)
}
func updateWorkingRoot(ctx *sql.Context, dbData env.DbData[*sql.Context], newRoot doltdb.RootValue) error {
var h hash.Hash
var wsRef ref.WorkingSetRef
@@ -510,16 +502,12 @@ func handleMerge(ctx *sql.Context, dbName string, dbData env.DbData[*sql.Context
return nil, nil, nil, err
}
tmpDir, err := dbData.Rsw.TempTableFilesDir()
if err != nil {
return nil, nil, nil, err
}
tableResolver, err := dsess.GetTableResolver(ctx, dbName)
if err != nil {
return nil, nil, nil, err
}
opts := editor.Options{Deaf: bulkDbEaFactory(dbData), Tempdir: tmpDir}
opts := editor.Options{}
result, err := merge.MergeRoots(ctx, tableResolver, curWorkingRoot, stashRoot, parentRoot, stashRoot, parentCommit, opts, merge.MergeOpts{IsCherryPick: false})
if err != nil {
return nil, nil, nil, err
@@ -24,7 +24,6 @@ import (
"time"
"github.com/dolthub/go-mysql-server/sql"
sqltypes "github.com/dolthub/go-mysql-server/sql/types"
"github.com/shopspring/decimal"
"github.com/dolthub/dolt/go/cmd/dolt/cli"
@@ -454,32 +453,6 @@ func (d *DoltSession) clear() {
}
}
// newWorkingSetForHead builds a fresh, empty working set for |wsRef| whose working
// and staged roots are both set to the root value of the corresponding branch's
// HEAD commit.
func (d *DoltSession) newWorkingSetForHead(ctx *sql.Context, wsRef ref.WorkingSetRef, dbName string) (*doltdb.WorkingSet, error) {
	// Error deliberately ignored; presumably dbName is already validated by the
	// caller — TODO confirm.
	dbData, _ := d.GetDbData(nil, dbName)
	// "HEAD" is a fixed, well-formed spec, so the parse error is ignored.
	headSpec, _ := doltdb.NewCommitSpec("HEAD")
	headRef, err := wsRef.ToHeadRef()
	if err != nil {
		return nil, err
	}
	optCmt, err := dbData.Ddb.Resolve(ctx, headSpec, headRef)
	if err != nil {
		return nil, err
	}
	// Resolve may yield a ghost (not-locally-present) commit; treat that as an error.
	headCommit, ok := optCmt.ToCommit()
	if !ok {
		return nil, doltdb.ErrGhostCommitEncountered
	}
	headRoot, err := headCommit.GetRootValue(ctx)
	if err != nil {
		return nil, err
	}
	return doltdb.EmptyWorkingSet(wsRef).WithWorkingRoot(headRoot).WithStagedRoot(headRoot), nil
}
// CommitTransaction commits the in-progress transaction. Depending on session settings, this may write only a new
// working set, or may additionally create a new dolt commit for the current HEAD. If more than one branch head has
// changes, the transaction is rejected.
@@ -1328,40 +1301,6 @@ func (d *DoltSession) setHeadRefSessionVar(ctx *sql.Context, db, value string) e
// setForeignKeyChecksSessionVar applies a new value of the foreign_key_checks
// session variable: it validates the value (only 0 and 1 are legal), pushes the
// corresponding ForeignKeyChecksDisabled flag into every open write session across
// all database heads, and finally stores the variable on the underlying session.
func (d *DoltSession) setForeignKeyChecksSessionVar(ctx *sql.Context, key string, value interface{}) error {
	d.mu.Lock()
	defer d.mu.Unlock()
	convertedVal, _, err := sqltypes.Int64.Convert(ctx, value)
	if err != nil {
		return err
	}
	var intVal int64
	if convertedVal != nil {
		intVal = convertedVal.(int64)
	}
	if intVal != 0 && intVal != 1 {
		return sql.ErrInvalidSystemVariableValue.New("foreign_key_checks", intVal)
	}
	// 0 disables FK checks, 1 enables them; propagate the flag to every open
	// write session on every branch head.
	disabled := intVal == 0
	for _, dbState := range d.dbStates {
		for _, branchState := range dbState.heads {
			if ws := branchState.WriteSession(); ws != nil {
				opts := ws.GetOptions()
				opts.ForeignKeyChecksDisabled = disabled
				ws.SetOptions(opts)
			}
		}
	}
	return d.Session.SetSessionVariable(ctx, key, value)
}
@@ -363,15 +363,6 @@ func interfaceToString(r interface{}) (string, error) {
return str, nil
}
// resolveRoot resolves |hashStr| to a root value through the session and packages
// it, together with its commit time, into a refDetails. The hashStr is stored on
// the result verbatim.
func resolveRoot(ctx *sql.Context, sess *dsess.DoltSession, dbName, hashStr string) (*refDetails, error) {
	root, commitTime, _, err := sess.ResolveRootForRef(ctx, dbName, hashStr)
	if err != nil {
		return nil, err
	}
	return &refDetails{root: root, hashStr: hashStr, commitTime: commitTime}, nil
}
func resolveCommit(ctx *sql.Context, ddb *doltdb.DoltDB, headRef ref.DoltRef, cSpecStr string) (*doltdb.Commit, error) {
cs, err := doltdb.NewCommitSpec(cSpecStr)
if err != nil {
@@ -17,7 +17,9 @@ package dtablefunctions
import (
"fmt"
"io"
"strconv"
"strings"
"time"
gms "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/sql"
@@ -26,10 +28,13 @@ import (
"github.com/dolthub/vitess/go/vt/sqlparser"
"github.com/gocraft/dbr/v2"
"github.com/gocraft/dbr/v2/dialect"
"github.com/shopspring/decimal"
"golang.org/x/exp/constraints"
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/overrides"
"github.com/dolthub/dolt/go/store/val"
)
const testsRunDefaultRowCount = 10
@@ -39,12 +44,13 @@ var _ sql.CatalogTableFunction = (*TestsRunTableFunction)(nil)
var _ sql.ExecSourceRel = (*TestsRunTableFunction)(nil)
var _ sql.AuthorizationCheckerNode = (*TestsRunTableFunction)(nil)
type testResult struct {
testName string
groupName string
query string
status string
message string
// TestResult represents the result of running a single test
type TestResult struct {
TestName string
GroupName string
Query string
Status string
Message string
}
type TestsRunTableFunction struct {
@@ -199,7 +205,7 @@ func (trtf *TestsRunTableFunction) RowIter(_ *sql.Context, _ sql.Row) (sql.RowIt
return nil, err
}
resultRow := sql.NewRow(result.testName, result.groupName, result.query, result.status, result.message)
resultRow := sql.NewRow(result.TestName, result.GroupName, result.Query, result.Status, result.Message)
resultRows = append(resultRows, resultRow)
}
}
@@ -220,7 +226,7 @@ func (trtf *TestsRunTableFunction) RowCount(_ *sql.Context) (uint64, bool, error
return testsRunDefaultRowCount, false, nil
}
func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResult, err error) {
func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result TestResult, err error) {
testName, groupName, query, assertion, comparison, value, err := parseDoltTestsRow(trtf.ctx, row)
if err != nil {
return
@@ -237,9 +243,9 @@ func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResul
if err != nil {
message = fmt.Sprintf("Query error: %s", err.Error())
} else {
testPassed, message, err = actions.AssertData(trtf.ctx, *assertion, *comparison, value, queryResult)
testPassed, message, err = AssertData(trtf.ctx, *assertion, *comparison, value, queryResult)
if err != nil {
return testResult{}, err
return TestResult{}, err
}
}
}
@@ -253,11 +259,49 @@ func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResul
if groupName != nil {
groupString = *groupName
}
result = testResult{*testName, groupString, *query, status, message}
result = TestResult{*testName, groupString, *query, status, message}
return result, nil
}
// queryAndAssertWithFunc runs the dolt test described by |row|, evaluates the query
// result with the supplied assertion function, and returns a populated TestResult.
// Validation and query errors are reported via the result's Message/Status fields;
// only assertion-machinery failures are returned as an error.
func (trtf *TestsRunTableFunction) queryAndAssertWithFunc(row sql.Row, assertDataFunc AssertDataFunc) (result TestResult, err error) {
	testName, groupName, query, assertion, comparison, value, err := parseDoltTestsRow(trtf.ctx, row)
	if err != nil {
		return
	}

	message, err := validateQuery(trtf.ctx, trtf.catalog, *query)
	if err != nil && message == "" {
		message = fmt.Sprintf("query error: %s", err.Error())
	}

	testPassed := false
	if message == "" {
		_, rowIter, _, queryErr := trtf.engine.Query(trtf.ctx, *query)
		if queryErr != nil {
			message = fmt.Sprintf("Query error: %s", queryErr.Error())
		} else {
			var assertErr error
			testPassed, message, assertErr = assertDataFunc(trtf.ctx, *assertion, *comparison, value, rowIter)
			if assertErr != nil {
				return TestResult{}, assertErr
			}
		}
	}

	status := "PASS"
	if !testPassed {
		status = "FAIL"
	}
	groupString := ""
	if groupName != nil {
		groupString = *groupName
	}
	return TestResult{*testName, groupString, *query, status, message}, nil
}
func (trtf *TestsRunTableFunction) getDoltTestsData(arg string) ([]sql.Row, error) {
// Original behavior when root is nil - use SQL queries against current session
var queries []string
if arg == "*" {
@@ -320,28 +364,31 @@ func IsWriteQuery(query string, ctx *sql.Context, catalog sql.Catalog) (bool, er
}
func parseDoltTestsRow(ctx *sql.Context, row sql.Row) (testName, groupName, query, assertion, comparison, value *string, err error) {
if testName, err = actions.GetStringColAsString(ctx, row[0]); err != nil {
if testName, err = getStringColAsString(ctx, row[0]); err != nil {
return
}
if groupName, err = actions.GetStringColAsString(ctx, row[1]); err != nil {
if groupName, err = getStringColAsString(ctx, row[1]); err != nil {
return
}
if query, err = actions.GetStringColAsString(ctx, row[2]); err != nil {
if query, err = getStringColAsString(ctx, row[2]); err != nil {
return
}
if assertion, err = actions.GetStringColAsString(ctx, row[3]); err != nil {
if assertion, err = getStringColAsString(ctx, row[3]); err != nil {
return
}
if comparison, err = actions.GetStringColAsString(ctx, row[4]); err != nil {
if comparison, err = getStringColAsString(ctx, row[4]); err != nil {
return
}
if value, err = actions.GetStringColAsString(ctx, row[5]); err != nil {
if value, err = getStringColAsString(ctx, row[5]); err != nil {
return
}
return testName, groupName, query, assertion, comparison, value, nil
}
// AssertDataFunc defines the function signature for asserting test data.
// Implementations return testPassed for the assertion outcome, a non-empty
// message describing a failure (which does not abort the test run), and err
// for runtime failures that should halt dolt_test_run entirely.
type AssertDataFunc func(sqlCtx *sql.Context, assertion string, comparison string, value *string, queryResult sql.RowIter) (testPassed bool, message string, err error)
func validateQuery(ctx *sql.Context, catalog sql.Catalog, query string) (string, error) {
// We first check if the query contains multiple sql statements
if statements, err := sqlparser.SplitStatementToPieces(query); err != nil {
@@ -361,3 +408,455 @@ func validateQuery(ctx *sql.Context, catalog sql.Catalog, query string) (string,
}
return "", nil
}
// Simple inline assertion constants to avoid circular imports.
// These are the assertion types accepted in the dolt_tests "assertion" column.
const (
	// AssertionExpectedRows asserts on the number of rows a query returns.
	AssertionExpectedRows = "expected_rows"
	// AssertionExpectedColumns asserts on the number of columns in the result.
	AssertionExpectedColumns = "expected_columns"
	// AssertionExpectedSingleValue asserts on the value of a single-cell result.
	AssertionExpectedSingleValue = "expected_single_value"
)
// getStringColAsString safely converts a sql value to a *string.
// A nil input yields (nil, nil); *val.TextStorage values are unwrapped through
// the session; plain strings pass through. Any other type is an error.
func getStringColAsString(sqlCtx *sql.Context, tableValue interface{}) (*string, error) {
	switch v := tableValue.(type) {
	case nil:
		return nil, nil
	case *val.TextStorage:
		unwrapped, err := v.Unwrap(sqlCtx)
		if err != nil {
			return nil, err
		}
		return &unwrapped, nil
	case string:
		return &v, nil
	default:
		return nil, fmt.Errorf("unexpected type %T, was expecting string", tableValue)
	}
}
// readTableDataFromDoltTable reads test data directly from a dolt table.
// NOTE(review): this is an intentional stub — it always returns an error and
// never reads the table.
func (trtf *TestsRunTableFunction) readTableDataFromDoltTable(table *doltdb.Table, arg string) ([]sql.Row, error) {
	// This is a complex implementation that requires reading table data directly from dolt storage
	// For now, return an error that clearly indicates this needs to be implemented
	// The table scan would involve:
	// 1. Getting the table schema
	// 2. Creating a table iterator
	// 3. Reading and filtering rows based on the arg (test_name or test_group)
	// 4. Converting dolt storage format to SQL rows
	//
	// This is a significant implementation that requires understanding dolt's storage internals
	return nil, fmt.Errorf("direct table reading from dolt storage not yet implemented for table scan of dolt_tests - this requires implementing table iteration and row conversion from dolt's internal storage format")
}
// AssertData evaluates a dolt_tests assertion against a query result.
// Valid comparisons are "==", "!=", "<", ">", "<=", and ">=".
// testPassed reports whether the assertion held. message carries a
// human-readable failure description (empty on success) and does not halt the
// overall process; err reports a runtime failure and stops dolt_test_run.
func AssertData(sqlCtx *sql.Context, assertion string, comparison string, value *string, queryResult sql.RowIter) (testPassed bool, message string, err error) {
	switch assertion {
	case AssertionExpectedRows:
		message, err = expectRows(sqlCtx, comparison, value, queryResult)
	case AssertionExpectedColumns:
		message, err = expectColumns(sqlCtx, comparison, value, queryResult)
	case AssertionExpectedSingleValue:
		message, err = expectSingleValue(sqlCtx, comparison, value, queryResult)
	default:
		return false, fmt.Sprintf("%s is not a valid assertion type", assertion), nil
	}

	switch {
	case err != nil:
		return false, "", err
	case message != "":
		return false, message, nil
	default:
		return true, "", nil
	}
}
// expectSingleValue asserts that queryResult contains exactly one row with one
// column, then compares that cell against the expected value string. The
// expected string is coerced to the cell's runtime type (bool, integer, float,
// decimal, date/time, or string) before comparing.
// message is non-empty when the assertion fails; err reports runtime failures.
func expectSingleValue(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
	row, err := queryResult.Next(sqlCtx)
	if err == io.EOF {
		return "expected_single_value expects exactly one cell. Received 0 rows", nil
	} else if err != nil {
		return "", err
	}
	if len(row) != 1 {
		return "expected_single_value expects exactly one cell. Received multiple columns", nil
	}
	// Exactly one row is allowed; a second successful Next means the query
	// returned too many rows.
	_, err = queryResult.Next(sqlCtx)
	if err == nil {
		return "expected_single_value expects exactly one cell. Received multiple rows", nil
	} else if err != io.EOF { // "True" error, so we should quit out
		return "", err
	}
	if value == nil { // A NULL expectation needs no type coercion.
		return compareNullValue(comparison, row[0], AssertionExpectedSingleValue), nil
	}
	// If the expected value parses as a boolean keyword, compare as booleans.
	// "0" and "1" are excluded: they are valid integers and handled below.
	if *value != "0" && *value != "1" {
		if expectedBool, err := strconv.ParseBool(*value); err == nil {
			actualBool, boolErr := getInterfaceAsBool(row[0])
			if boolErr != nil {
				return fmt.Sprintf("Could not convert value to boolean: %v", boolErr), nil
			}
			return compareBooleans(comparison, expectedBool, actualBool, AssertionExpectedSingleValue), nil
		}
	}
	switch actualValue := row[0].(type) {
	case int8:
		return compareSignedValue(comparison, *value, actualValue, 8), nil
	case int16:
		return compareSignedValue(comparison, *value, actualValue, 16), nil
	case int32:
		return compareSignedValue(comparison, *value, actualValue, 32), nil
	case int64:
		return compareSignedValue(comparison, *value, actualValue, 64), nil
	case int:
		return compareSignedValue(comparison, *value, actualValue, strconv.IntSize), nil
	case uint8:
		return compareUnsignedValue(comparison, *value, actualValue, 8), nil
	case uint16:
		return compareUnsignedValue(comparison, *value, actualValue, 16), nil
	case uint32:
		return compareUnsignedValue(comparison, *value, actualValue, 32), nil
	case uint64:
		return compareUnsignedValue(comparison, *value, actualValue, 64), nil
	case uint:
		return compareUnsignedValue(comparison, *value, actualValue, strconv.IntSize), nil
	case float64:
		return compareFloatValue(comparison, *value, actualValue, 64), nil
	case float32:
		return compareFloatValue(comparison, *value, actualValue, 32), nil
	case decimal.Decimal:
		expectedDecimal, err := decimal.NewFromString(*value)
		if err != nil {
			return fmt.Sprintf("Could not compare non decimal value '%s', with %s", *value, actualValue), nil
		}
		return compareDecimals(comparison, expectedDecimal, actualValue, AssertionExpectedSingleValue), nil
	case time.Time:
		expectedTime, format, err := parseTestsDate(*value)
		if err != nil {
			return fmt.Sprintf("%s does not appear to be a valid date", *value), nil
		}
		return compareDates(comparison, expectedTime, actualValue, format, AssertionExpectedSingleValue), nil
	case *val.TextStorage, string:
		actualString, err := GetStringColAsString(sqlCtx, actualValue)
		if err != nil {
			return "", err
		}
		return compareTestAssertion(comparison, *value, *actualString, AssertionExpectedSingleValue), nil
	default:
		return fmt.Sprintf("Type %T is not supported. Open an issue at https://github.com/dolthub/dolt/issues to see it added", actualValue), nil
	}
}

// compareSignedValue parses expected as a signed integer that must fit in
// bitSize bits and compares it with actual. Parsing with the exact bit width
// avoids silent truncation: previously an expected value of "300" wrapped to
// 44 when compared against an int8 column; now it is reported as a failure.
func compareSignedValue[T int | int8 | int16 | int32 | int64](comparison, expected string, actual T, bitSize int) string {
	parsed, err := strconv.ParseInt(expected, 10, bitSize)
	if err != nil {
		return fmt.Sprintf("Could not compare non integer value '%s', with %d", expected, actual)
	}
	return compareTestAssertion(comparison, T(parsed), actual, AssertionExpectedSingleValue)
}

// compareUnsignedValue is the unsigned counterpart of compareSignedValue.
// Previously uint8/uint16/uint32/uint64 all parsed with a fixed 32- or 64-bit
// width, silently truncating out-of-range expected values.
func compareUnsignedValue[T uint | uint8 | uint16 | uint32 | uint64](comparison, expected string, actual T, bitSize int) string {
	parsed, err := strconv.ParseUint(expected, 10, bitSize)
	if err != nil {
		return fmt.Sprintf("Could not compare non integer value '%s', with %d", expected, actual)
	}
	return compareTestAssertion(comparison, T(parsed), actual, AssertionExpectedSingleValue)
}

// compareFloatValue parses expected as a float of the given precision and
// compares it with actual.
func compareFloatValue[T float32 | float64](comparison, expected string, actual T, bitSize int) string {
	parsed, err := strconv.ParseFloat(expected, bitSize)
	if err != nil {
		return fmt.Sprintf("Could not compare non float value '%s', with %f", expected, actual)
	}
	return compareTestAssertion(comparison, T(parsed), actual, AssertionExpectedSingleValue)
}
// expectRows asserts on the total number of rows produced by queryResult.
// The expected value must be a non-nil integer string; the iterator is fully
// drained to count rows. message is non-empty when the assertion fails; err
// reports runtime failures.
func expectRows(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
	if value == nil {
		return "null is not a valid assertion for expected_rows", nil
	}
	expectedRows, convErr := strconv.Atoi(*value)
	if convErr != nil {
		return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil
	}
	numRows := 0
	for {
		if _, nextErr := queryResult.Next(sqlCtx); nextErr == io.EOF {
			break
		} else if nextErr != nil {
			return "", nextErr
		}
		numRows++
	}
	return compareTestAssertion(comparison, expectedRows, numRows, AssertionExpectedRows), nil
}
// expectColumns asserts on the number of columns in queryResult's first row.
// The expected value must be a non-nil integer string. An empty result
// (io.EOF on the first row) counts as zero columns.
// message is non-empty when the assertion fails; err reports runtime failures.
func expectColumns(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
	if value == nil {
		// BUG FIX: this message previously referenced expected_rows.
		return "null is not a valid assertion for expected_columns", nil
	}
	expectedColumns, err := strconv.Atoi(*value)
	if err != nil {
		return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil
	}
	row, err := queryResult.Next(sqlCtx)
	if err != nil && err != io.EOF {
		return "", err
	}
	// len(nil) == 0, so an EOF (no rows) is asserted as zero columns.
	numColumns := len(row)
	return compareTestAssertion(comparison, expectedColumns, numColumns, AssertionExpectedColumns), nil
}
// compareTestAssertion is a generic comparison for strings, ints, and floats.
// comparison is one of "==", "!=", "<", ">", "<=", ">=". The returned string
// is empty when the assertion passed, or describes the failure otherwise.
func compareTestAssertion[T constraints.Ordered](comparison string, expectedValue, actualValue T, assertionType string) string {
	var failed bool
	var desc string
	switch comparison {
	case "==":
		failed, desc = actualValue != expectedValue, "equal to"
	case "!=":
		failed, desc = actualValue == expectedValue, "not equal to"
	case "<":
		failed, desc = actualValue >= expectedValue, "less than"
	case "<=":
		failed, desc = actualValue > expectedValue, "less than or equal to"
	case ">":
		failed, desc = actualValue <= expectedValue, "greater than"
	case ">=":
		failed, desc = actualValue < expectedValue, "greater than or equal to"
	default:
		return fmt.Sprintf("%s is not a valid comparison type", comparison)
	}
	if failed {
		return fmt.Sprintf("Assertion failed: %s %s %v, got %v", assertionType, desc, expectedValue, actualValue)
	}
	return ""
}
// parseTestsDate is an internal function that parses the queried string according to allowed time formats for dolt_tests.
// It returns the parsed time, the format that succeeded, and an error if applicable.
func parseTestsDate(value string) (parsedTime time.Time, format string, err error) {
// List of valid formats
formats := []string{
time.DateOnly,
time.DateTime,
time.TimeOnly,
time.RFC3339,
time.RFC1123Z,
}
for _, format := range formats {
if parsedTime, parseErr := time.Parse(format, value); parseErr == nil {
return parsedTime, format, nil
} else {
err = parseErr
}
}
return time.Time{}, "", err
}
// compareDates is a function used for comparing time values.
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
func compareDates(comparison string, expectedValue, realValue time.Time, format string, assertionType string) string {
expectedStr := expectedValue.Format(format)
realStr := realValue.Format(format)
switch comparison {
case "==":
if !expectedValue.Equal(realValue) {
return fmt.Sprintf("Assertion failed: %s equal to %s, got %s", assertionType, expectedStr, realStr)
}
case "!=":
if expectedValue.Equal(realValue) {
return fmt.Sprintf("Assertion failed: %s not equal to %s, got %s", assertionType, expectedStr, realStr)
}
case "<":
if realValue.Equal(expectedValue) || realValue.After(expectedValue) {
return fmt.Sprintf("Assertion failed: %s less than %s, got %s", assertionType, expectedStr, realStr)
}
case "<=":
if realValue.After(expectedValue) {
return fmt.Sprintf("Assertion failed: %s less than or equal to %s, got %s", assertionType, expectedStr, realStr)
}
case ">":
if realValue.Before(expectedValue) || realValue.Equal(expectedValue) {
return fmt.Sprintf("Assertion failed: %s greater than %s, got %s", assertionType, expectedStr, realStr)
}
case ">=":
if realValue.Before(expectedValue) {
return fmt.Sprintf("Assertion failed: %s greater than or equal to %s, got %s", assertionType, expectedStr, realStr)
}
default:
return fmt.Sprintf("%s is not a valid comparison type", comparison)
}
return ""
}
// compareDecimals compares two decimals with decimal.Decimal's own comparison
// methods. comparison is one of "==", "!=", "<", ">", "<=", ">=". The returned
// string is empty when the assertion passed, or describes the failure
// otherwise.
func compareDecimals(comparison string, expectedValue, realValue decimal.Decimal, assertionType string) string {
	var failed bool
	var desc string
	switch comparison {
	case "==":
		failed, desc = !expectedValue.Equal(realValue), "equal to"
	case "!=":
		failed, desc = expectedValue.Equal(realValue), "not equal to"
	case "<":
		failed, desc = realValue.GreaterThanOrEqual(expectedValue), "less than"
	case "<=":
		failed, desc = realValue.GreaterThan(expectedValue), "less than or equal to"
	case ">":
		failed, desc = realValue.LessThanOrEqual(expectedValue), "greater than"
	case ">=":
		failed, desc = realValue.LessThan(expectedValue), "greater than or equal to"
	default:
		return fmt.Sprintf("%s is not a valid comparison type", comparison)
	}
	if failed {
		return fmt.Sprintf("Assertion failed: %s %s %v, got %v", assertionType, desc, expectedValue, realValue)
	}
	return ""
}
// getInterfaceAsBool returns the value interface{} as a bool.
// The query engine may return a tinyint column as a bool, any integer type,
// or a string; integer and string values are treated as true exactly when
// they equal 1 / "1". Based on GetTinyIntColAsBool from commands/utils.go,
// which we can't depend on here due to package cycles.
func getInterfaceAsBool(col interface{}) (bool, error) {
	switch v := col.(type) {
	case bool:
		return v, nil
	case string:
		return v == "1", nil
	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
		// All integer widths share one rule: true iff the value is exactly 1.
		// Formatting with %d gives the same answer for every integer type.
		return fmt.Sprintf("%d", v) == "1", nil
	default:
		return false, fmt.Errorf("unexpected type %T, was expecting bool, int, or string", v)
	}
}
// compareBooleans compares two boolean values. Only "==" and "!=" are valid
// comparisons for booleans. The returned string is empty when the assertion
// passed, or describes the failure otherwise.
func compareBooleans(comparison string, expectedValue, realValue bool, assertionType string) string {
	switch comparison {
	case "==":
		if expectedValue == realValue {
			return ""
		}
		return fmt.Sprintf("Assertion failed: %s equal to %t, got %t", assertionType, expectedValue, realValue)
	case "!=":
		if expectedValue != realValue {
			return ""
		}
		return fmt.Sprintf("Assertion failed: %s not equal to %t, got %t", assertionType, expectedValue, realValue)
	default:
		return fmt.Sprintf("%s is not a valid comparison for boolean values. Only '==' and '!=' are supported", comparison)
	}
}
// compareNullValue compares a value against an expected NULL. Only "==" and
// "!=" are valid comparisons for NULL. The returned string is empty when the
// assertion passed, or describes the failure otherwise.
func compareNullValue(comparison string, actualValue interface{}, assertionType string) string {
	if comparison == "==" {
		if actualValue == nil {
			return ""
		}
		return fmt.Sprintf("Assertion failed: %s equal to NULL, got %v", assertionType, actualValue)
	}
	if comparison == "!=" {
		if actualValue != nil {
			return ""
		}
		return fmt.Sprintf("Assertion failed: %s not equal to NULL, got NULL", assertionType)
	}
	return fmt.Sprintf("%s is not a valid comparison for NULL values", comparison)
}
// GetStringColAsString returns a text column as a *string.
// The dolt_tests system table returns *val.TextStorage values under certain
// situations, so those are unwrapped through the session; plain strings pass
// through; nil yields (nil, nil); any other type is an error.
func GetStringColAsString(sqlCtx *sql.Context, tableValue interface{}) (*string, error) {
	switch v := tableValue.(type) {
	case *val.TextStorage:
		unwrapped, err := v.Unwrap(sqlCtx)
		return &unwrapped, err
	case string:
		return &v, nil
	case nil:
		return nil, nil
	default:
		return nil, fmt.Errorf("unexpected type %T, was expecting string", tableValue)
	}
}
@@ -1,185 +0,0 @@
// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dtables
import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/vitess/go/sqltypes"
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
)
// Names of the binlog system tables, formed from the access and namespace
// table names with a "_binlog" suffix.
const (
	AccessBinlogTableName = AccessTableName + "_binlog"
	NamespaceBinlogTableName = NamespaceTableName + "_binlog"
)
// accessBinlogSchema is the schema for the "dolt_branch_control_binlog" table.
// Each row records one operation ("insert" or "delete") against the branch
// access-control table, keyed by a monotonically increasing index.
var accessBinlogSchema = sql.Schema{
	&sql.Column{
		Name:       "index",
		Type:       types.Int64,
		Source:     AccessBinlogTableName,
		PrimaryKey: true,
	},
	&sql.Column{
		Name:       "operation",
		Type:       types.MustCreateEnumType([]string{"insert", "delete"}, sql.Collation_utf8mb4_0900_bin),
		Source:     AccessBinlogTableName,
		PrimaryKey: false,
	},
	&sql.Column{
		Name:       "branch",
		Type:       types.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
		Source:     AccessBinlogTableName,
		PrimaryKey: false,
	},
	&sql.Column{
		// User names compare case-sensitively (bin collation), unlike
		// branch and host.
		Name:       "user",
		Type:       types.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_bin),
		Source:     AccessBinlogTableName,
		PrimaryKey: false,
	},
	&sql.Column{
		Name:       "host",
		Type:       types.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
		Source:     AccessBinlogTableName,
		PrimaryKey: false,
	},
	&sql.Column{
		Name:       "permissions",
		Type:       types.MustCreateSetType(PermissionsStrings, sql.Collation_utf8mb4_0900_ai_ci),
		Source:     AccessBinlogTableName,
		PrimaryKey: false,
	},
}
// namespaceBinlogSchema is the schema for the "dolt_branch_namespace_control_binlog" table.
// Identical to accessBinlogSchema except that namespace entries carry no
// permissions column.
var namespaceBinlogSchema = sql.Schema{
	&sql.Column{
		Name:       "index",
		Type:       types.Int64,
		Source:     NamespaceBinlogTableName,
		PrimaryKey: true,
	},
	&sql.Column{
		Name:       "operation",
		Type:       types.MustCreateEnumType([]string{"insert", "delete"}, sql.Collation_utf8mb4_0900_bin),
		Source:     NamespaceBinlogTableName,
		PrimaryKey: false,
	},
	&sql.Column{
		Name:       "branch",
		Type:       types.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
		Source:     NamespaceBinlogTableName,
		PrimaryKey: false,
	},
	&sql.Column{
		// User names compare case-sensitively (bin collation), unlike
		// branch and host.
		Name:       "user",
		Type:       types.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_bin),
		Source:     NamespaceBinlogTableName,
		PrimaryKey: false,
	},
	&sql.Column{
		Name:       "host",
		Type:       types.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
		Source:     NamespaceBinlogTableName,
		PrimaryKey: false,
	},
}
// BinlogTable provides a queryable view over the Binlog.
type BinlogTable struct {
	// Log is the underlying branch-control binlog being exposed.
	Log *branch_control.Binlog
	// IsAccess selects which binlog this table represents: the access
	// binlog when true, the namespace binlog when false.
	IsAccess bool
}

var _ sql.Table = BinlogTable{}
// Name implements the interface sql.Table.
func (b BinlogTable) Name() string {
	if !b.IsAccess {
		return NamespaceBinlogTableName
	}
	return AccessBinlogTableName
}
// String implements the interface sql.Table.
func (b BinlogTable) String() string {
	if !b.IsAccess {
		return NamespaceBinlogTableName
	}
	return AccessBinlogTableName
}
// Schema implements the interface sql.Table.
func (b BinlogTable) Schema() sql.Schema {
	if !b.IsAccess {
		return namespaceBinlogSchema
	}
	return accessBinlogSchema
}
// Collation implements the interface sql.Table. The binlog tables always use
// the server default collation.
func (b BinlogTable) Collation() sql.CollationID {
	return sql.Collation_Default
}
// Partitions implements the interface sql.Table. The binlog is exposed as a
// single partition.
func (b BinlogTable) Partitions(context *sql.Context) (sql.PartitionIter, error) {
	return index.SinglePartitionIterFromNomsMap(nil), nil
}
// PartitionRows implements the interface sql.Table. It materializes every
// binlog entry into a sql.Row under the binlog's read lock; access rows carry
// a trailing permissions column that namespace rows lack.
func (b BinlogTable) PartitionRows(context *sql.Context, partition sql.Partition) (sql.RowIter, error) {
	b.Log.RWMutex.RLock()
	defer b.Log.RWMutex.RUnlock()

	binlogRows := b.Log.Rows()
	rows := make([]sql.Row, 0, len(binlogRows))
	for i, logRow := range binlogRows {
		// Enum values are 1-indexed: 1 = "insert", 2 = "delete".
		operation := uint16(2)
		if logRow.IsInsert {
			operation = 1
		}
		row := sql.Row{
			int64(i),
			operation,
			logRow.Branch,
			logRow.User,
			logRow.Host,
		}
		if b.IsAccess {
			row = append(row, logRow.Permissions)
		}
		rows = append(rows, row)
	}
	return sql.RowsToRowIter(rows...), nil
}
@@ -24,9 +24,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/diff"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
"github.com/dolthub/dolt/go/libraries/doltcore/row"
"github.com/dolthub/dolt/go/libraries/doltcore/rowconv"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/expreval"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
@@ -591,38 +589,6 @@ func (dt *DiffTable) PreciseMatch() bool {
return false
}
// tableData returns the primary index for the specified table (or an empty
// index if tbl is nil) and the schema of the table (or EmptySchema if tbl is
// nil).
func tableData(ctx *sql.Context, tbl *doltdb.Table, ddb *doltdb.DoltDB) (durable.Index, schema.Schema, error) {
	if tbl == nil {
		// Missing table: synthesize an empty index with an empty schema.
		data, err := durable.NewEmptyPrimaryIndex(ctx, ddb.ValueReadWriter(), ddb.NodeStore(), schema.EmptySchema)
		if err != nil {
			return nil, nil, err
		}
		return data, schema.EmptySchema, nil
	}

	data, err := tbl.GetRowData(ctx)
	if err != nil {
		return nil, nil, err
	}
	sch, err := tbl.GetSchema(ctx)
	if err != nil {
		return nil, nil, err
	}
	return data, sch, nil
}
type TblInfoAtCommit struct {
date *types.Timestamp
tbl *doltdb.Table
@@ -891,20 +857,6 @@ func (dps *DiffPartitions) Close(*sql.Context) error {
return nil
}
// rowConvForSchema creates a RowConverter for transforming rows with the given schema a target schema.
// When the source schema is empty, the identity converter is returned and no
// mapping is computed.
func (dp DiffPartition) rowConvForSchema(ctx context.Context, vrw types.ValueReadWriter, targetSch, srcSch schema.Schema) (*rowconv.RowConverter, error) {
	if schema.SchemasAreEqual(srcSch, schema.EmptySchema) {
		return rowconv.IdentityConverter, nil
	}

	// Map source fields onto target fields by tag and name before building
	// the converter.
	fm, err := rowconv.TagMappingByTagAndName(srcSch, targetSch)
	if err != nil {
		return nil, err
	}
	return rowconv.NewRowConverter(ctx, vrw, fm)
}
// GetDiffTableSchemaAndJoiner returns the schema for the diff table given a
// target schema for a row |sch|. In the old storage format, it also returns the
// associated joiner.
@@ -1,61 +0,0 @@
// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dtables
import (
"strings"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/dolt/go/libraries/utils/set"
)
// Predicate reports whether a filter expression satisfies some condition.
type Predicate func(sql.Expression) bool

// ColumnPredicate returns a predicate function for expressions on the column names given
func ColumnPredicate(colNameSet *set.StrSet) Predicate {
	return func(filter sql.Expression) bool {
		// Walk the expression tree; the filter qualifies only if every
		// column reference it contains is in colNameSet.
		qualifies := true
		sql.Inspect(filter, func(e sql.Expression) bool {
			if gf, ok := e.(*expression.GetField); ok && !colNameSet.Contains(strings.ToLower(gf.Name())) {
				qualifies = false
				return false
			}
			return true
		})
		return qualifies
	}
}
// FilterFilters returns the subset of the expressions given that match the given predicate
func FilterFilters(filters []sql.Expression, predicate func(filter sql.Expression) bool) []sql.Expression {
	// Always return a non-nil slice, even when nothing matches.
	matching := make([]sql.Expression, 0, len(filters))
	for _, filter := range filters {
		if !predicate(filter) {
			continue
		}
		matching = append(matching, filter)
	}
	return matching
}
@@ -156,49 +156,3 @@ func (itr *StashItr) Next(*sql.Context) (sql.Row, error) {
// Close is a no-op; the iterator holds no resources to release.
func (itr *StashItr) Close(*sql.Context) error {
	return nil
}
// Compile-time checks that stashWriter satisfies the row-editing interfaces.
// All of its mutating methods reject writes, which makes the stashes system
// table read-only through SQL.
var _ sql.RowReplacer = stashWriter{nil}
var _ sql.RowUpdater = stashWriter{nil}
var _ sql.RowInserter = stashWriter{nil}
var _ sql.RowDeleter = stashWriter{nil}

// stashWriter is the (rejecting) write interface for the stashes table.
type stashWriter struct {
	// rt is the table this writer was created for; never mutated here.
	rt *StashesTable
}
// Insert inserts the row given, returning an error if it cannot. Insert will be called once for each row to process
// for the insert operation, which may involve many rows. After all rows in an operation have been processed, Close
// is called.
// For stashWriter this always fails: the table is read-only.
func (bWr stashWriter) Insert(_ *sql.Context, _ sql.Row) error {
	return fmt.Errorf("the dolt_stashes table is read-only; use the dolt_stash stored procedure to edit stashes")
}
// Update the given row. Provides both the old and new rows.
// For stashWriter this always fails: the table is read-only.
func (bWr stashWriter) Update(_ *sql.Context, _ sql.Row, _ sql.Row) error {
	// Message fixed to name the dolt_stashes system table, matching Insert.
	return fmt.Errorf("the dolt_stashes table is read-only; use the dolt_stash stored procedure to edit stashes")
}
// Delete deletes the given row. Returns ErrDeleteRowNotFound if the row was not found. Delete will be called once for
// each row to process for the delete operation, which may involve many rows. After all rows have been processed,
// Close is called.
// For stashWriter this always fails: the table is read-only.
func (bWr stashWriter) Delete(_ *sql.Context, _ sql.Row) error {
	// Message fixed to name the dolt_stashes system table, matching Insert.
	return fmt.Errorf("the dolt_stashes table is read-only; use the dolt_stash stored procedure to edit stashes")
}
// StatementBegin implements the interface sql.TableEditor. Currently a no-op.
func (bWr stashWriter) StatementBegin(*sql.Context) {}

// DiscardChanges implements the interface sql.TableEditor. Currently a no-op,
// since no write ever succeeds and there is nothing to discard.
func (bWr stashWriter) DiscardChanges(_ *sql.Context, _ error) error {
	return nil
}

// StatementComplete implements the interface sql.TableEditor. Currently a no-op.
func (bWr stashWriter) StatementComplete(*sql.Context) error {
	return nil
}

// Close finalizes the delete operation, persisting the result. Currently a
// no-op because this writer never accumulates changes.
func (bWr stashWriter) Close(*sql.Context) error {
	return nil
}
@@ -23,7 +23,6 @@ import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/transform"
"github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/dolt/go/libraries/doltcore/diff"
@@ -462,64 +461,6 @@ func (itr *doltDiffCommitHistoryRowItr) Close(*sql.Context) error {
return nil
}
// isTableDataEmpty return true if the table does not contain any data
func isTableDataEmpty(ctx *sql.Context, table *doltdb.Table) (bool, error) {
	rowData, err := table.GetRowData(ctx)
	if err != nil {
		return false, err
	}
	// The emptiness check (and any error it produces) is delegated to the index.
	return rowData.Empty()
}
// commitFilterForDiffTableFilterExprs returns CommitFilter used for CommitItr.
// The returned filter evaluates each expression against the row
// (commit hash, committer name, commit time) and filters the commit out when
// any expression evaluates to false.
func commitFilterForDiffTableFilterExprs(filters []sql.Expression) (doltdb.CommitFilter[*sql.Context], error) {
	filters = transformFilters(filters...)

	return func(ctx *sql.Context, h hash.Hash, optCmt *doltdb.OptionalCommit) (filterOut bool, err error) {
		cm, ok := optCmt.ToCommit()
		if !ok {
			return false, doltdb.ErrGhostCommitEncountered
		}

		meta, err := cm.GetCommitMeta(ctx)
		if err != nil {
			return false, err
		}
		for _, filter := range filters {
			res, err := filter.Eval(ctx, sql.Row{h.String(), meta.Name, meta.Time()})
			if err != nil {
				return false, err
			}
			// Non-boolean results are ignored; only an explicit false
			// filters the commit out.
			b, ok := res.(bool)
			if ok && !b {
				return true, nil
			}
		}

		// err is nil at this point: all filters passed, keep the commit.
		return false, err
	}, nil
}
// transformFilters return filter expressions with index specified for rows used in CommitFilter.
// NOTE: the input slice is rewritten in place as well as returned.
func transformFilters(filters ...sql.Expression) []sql.Expression {
	for i := range filters {
		filters[i], _, _ = transform.Expr(filters[i], func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
			gf, ok := e.(*expression.GetField)
			if !ok {
				return e, transform.SameTree, nil
			}
			// Only the commit-hash column is re-indexed (to position 0 of
			// the CommitFilter row); other fields are left untouched.
			switch gf.Name() {
			case commitHashCol:
				return gf.WithIndex(0), transform.NewTree, nil
			default:
				return gf, transform.SameTree, nil
			}
		})
	}
	return filters
}
func getCommitsFromCommitHashEquality(ctx *sql.Context, ddb *doltdb.DoltDB, filters []sql.Expression) ([]*doltdb.Commit, bool) {
var commits []*doltdb.Commit
var isCommitHashEquality bool
@@ -1210,6 +1210,11 @@ func TestDoltDdlScripts(t *testing.T) {
RunDoltDdlScripts(t, harness)
}
func TestDoltCommitVerificationScripts(t *testing.T) {
harness := newDoltEnginetestHarness(t)
RunDoltCommitVerificationScripts(t, harness)
}
func TestBrokenDdlScripts(t *testing.T) {
for _, script := range BrokenDDLScripts {
t.Skip(script.Name)
@@ -512,7 +512,8 @@ func RunStoredProceduresTest(t *testing.T, h DoltEnginetestHarness) {
}
func RunDoltStoredProceduresTest(t *testing.T, h DoltEnginetestHarness) {
for _, script := range DoltProcedureTests {
scripts := append(DoltProcedureTests, DoltCleanProcedureScripts...)
for _, script := range scripts {
func() {
h := h.NewHarness(t)
h.UseLocalFileSystem()
@@ -523,7 +524,8 @@ func RunDoltStoredProceduresTest(t *testing.T, h DoltEnginetestHarness) {
}
func RunDoltStoredProceduresPreparedTest(t *testing.T, h DoltEnginetestHarness) {
for _, script := range DoltProcedureTests {
scripts := append(DoltProcedureTests, DoltCleanProcedureScripts...)
for _, script := range scripts {
func() {
h := h.NewHarness(t)
h.UseLocalFileSystem()
@@ -2145,3 +2147,12 @@ func RunTransactionTestsWithEngineSetup(t *testing.T, setupEngine func(*gms.Engi
})
}
}
// RunDoltCommitVerificationScripts executes every script in
// DoltCommitVerificationScripts, deriving a fresh harness per script and
// closing it once that script has run.
func RunDoltCommitVerificationScripts(t *testing.T, harness DoltEnginetestHarness) {
	for _, script := range DoltCommitVerificationScripts {
		h := harness.NewHarness(t)
		enginetest.TestScript(t, h, script)
		h.Close()
	}
}
@@ -190,7 +190,6 @@ func (d *DoltHarness) resetScripts() []setup.SetupScript {
for i := range dbs {
db := dbs[i]
resetCmds = append(resetCmds, setup.SetupScript{fmt.Sprintf("use %s", db)})
// Any auto increment tables must be dropped and recreated to get a fresh state for the global auto increment
// sequence trackers
_, aiTables := enginetest.MustQuery(ctx, d.engine,
@@ -218,6 +217,7 @@ func (d *DoltHarness) resetScripts() []setup.SetupScript {
resetCmds = append(resetCmds, setup.SetupScript{fmt.Sprintf("drop database if exists %s", db)})
}
}
resetCmds = append(resetCmds, setup.SetupScript{"use mydb"})
return resetCmds
}
@@ -229,7 +229,7 @@ func commitScripts(dbs []string) []setup.SetupScript {
db := dbs[i]
commitCmds = append(commitCmds, fmt.Sprintf("use %s", db))
commitCmds = append(commitCmds, "call dolt_add('.')")
commitCmds = append(commitCmds, fmt.Sprintf("call dolt_commit('--allow-empty', '-am', 'checkpoint enginetest database %s', '--date', '1970-01-01T12:00:00')", db))
commitCmds = append(commitCmds, fmt.Sprintf("call dolt_commit('--allow-empty', '-am', 'checkpoint enginetest database %s', '--date', '1970-01-01T12:00:00', '--skip-verification')", db))
}
commitCmds = append(commitCmds, "use mydb")
return []setup.SetupScript{commitCmds}
@@ -0,0 +1,80 @@
// Copyright 2026 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package enginetest
import (
"github.com/dolthub/go-mysql-server/enginetest/queries"
"github.com/dolthub/go-mysql-server/sql"
)
// DoltCleanProcedureScripts are script tests for the dolt_clean procedure.
//
// They cover dolt_ignore interaction: by default dolt_clean leaves untracked
// tables that match a dolt_ignore pattern alone, while the -x flag drops them.
var DoltCleanProcedureScripts = []queries.ScriptTest{
	{
		Name: "dolt_clean does not drop tables matching dolt_ignore",
		SetUpScript: []string{
			"CREATE TABLE ignored_foo (id int primary key);",
			"INSERT INTO ignored_foo VALUES (1);",
			"INSERT INTO dolt_ignore VALUES ('ignored_*', true);",
			"CALL dolt_add('dolt_ignore');",
			"CALL dolt_commit('-m', 'add dolt_ignore');",
		},
		Assertions: []queries.ScriptTestAssertion{
			{
				Query: "SELECT * FROM ignored_foo;",
				Expected: []sql.Row{{1}},
			},
			{
				// ignored_foo matches the 'ignored_*' pattern, so plain
				// dolt_clean must not touch it.
				Query: "CALL dolt_clean();",
				Expected: []sql.Row{{0}},
			},
			{
				Query: "SELECT * FROM ignored_foo;",
				Expected: []sql.Row{{1}},
			},
			{
				Query: "SHOW TABLES;",
				Expected: []sql.Row{{"ignored_foo"}},
			},
		},
	},
	{
		Name: "dolt_clean -x drops tables matching dolt_ignore",
		SetUpScript: []string{
			"CREATE TABLE ignored_bar (id int primary key);",
			"INSERT INTO ignored_bar VALUES (1);",
			"INSERT INTO dolt_ignore VALUES ('ignored_*', true);",
			"CALL dolt_add('dolt_ignore');",
			"CALL dolt_commit('-m', 'add dolt_ignore');",
		},
		Assertions: []queries.ScriptTestAssertion{
			{
				Query: "SELECT * FROM ignored_bar;",
				Expected: []sql.Row{{1}},
			},
			{
				// -x overrides dolt_ignore: the untracked ignored table is dropped.
				Query: "CALL dolt_clean('-x');",
				Expected: []sql.Row{{0}},
			},
			{
				Query: "SELECT * FROM ignored_bar;",
				ExpectedErrStr: "table not found: ignored_bar",
			},
			{
				Query: "SHOW TABLES;",
				Expected: []sql.Row{},
			},
		},
	},
}
@@ -0,0 +1,538 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package enginetest
import (
"regexp"
"github.com/dolthub/go-mysql-server/enginetest"
"github.com/dolthub/go-mysql-server/enginetest/queries"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/dolt/go/store/hash"
)
// commitHashValidator accepts any value that is a string parseable as a Dolt
// commit hash (as determined by hash.MaybeParse).
type commitHashValidator struct{}

var _ enginetest.CustomValueValidator = &commitHashValidator{}

// Validate reports whether val is a well-formed commit hash string.
// Non-string values are simply rejected, never treated as errors.
func (chv *commitHashValidator) Validate(val interface{}) (bool, error) {
	s, isString := val.(string)
	if !isString {
		return false, nil
	}
	_, parsed := hash.MaybeParse(s)
	return parsed, nil
}
// successfulRebaseMessageValidator accepts any message beginning with
// "Successfully rebased".
type successfulRebaseMessageValidator struct{}

var _ enginetest.CustomValueValidator = &successfulRebaseMessageValidator{}

// regexp.MatchString is unanchored at the end, so the previous trailing ".*"
// was redundant; anchoring only the start expresses the prefix match directly.
var successfulRebaseRegex = regexp.MustCompile(`^Successfully rebased`)

// Validate reports whether val is a string starting with the successful
// rebase message prefix. Non-string values are rejected, never errors.
func (srmv *successfulRebaseMessageValidator) Validate(val interface{}) (bool, error) {
	message, ok := val.(string)
	if !ok {
		return false, nil
	}
	return successfulRebaseRegex.MatchString(message), nil
}
// Shared validator instances referenced by the expected rows in the scripts below.
var commitHash = &commitHashValidator{}
var successfulRebaseMessage = &successfulRebaseMessageValidator{}
// DoltCommitVerificationScripts are script tests for commit verification:
// the dolt_commit_verification_groups global system variable and the
// --skip-verification flag, as applied to dolt_commit, dolt_cherry_pick,
// dolt_rebase, and dolt_merge.
//
// Because the test harness bleeds GLOBAL variable changes across tests, every
// script ends by resetting dolt_commit_verification_groups to ''.
var DoltCommitVerificationScripts = []queries.ScriptTest{
	{
		Name: "test verification system variables exist and have correct defaults",
		Assertions: []queries.ScriptTestAssertion{
			{
				Query: "SHOW GLOBAL VARIABLES LIKE 'dolt_commit_verification_groups'",
				Expected: []sql.Row{
					{"dolt_commit_verification_groups", ""},
				},
			},
			{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
				Query: "SET GLOBAL dolt_commit_verification_groups = ''",
				SkipResultsCheck: true,
			},
		},
	},
	{
		Name: "test verification system variables can be set",
		Assertions: []queries.ScriptTestAssertion{
			{
				Query: "SET GLOBAL dolt_commit_verification_groups = '*'",
				Expected: []sql.Row{{types.OkResult{}}},
			},
			{
				Query: "SHOW GLOBAL VARIABLES LIKE 'dolt_commit_verification_groups'",
				Expected: []sql.Row{
					{"dolt_commit_verification_groups", "*"},
				},
			},
			{
				Query: "SET GLOBAL dolt_commit_verification_groups = 'unit,integration'",
				Expected: []sql.Row{{types.OkResult{}}},
			},
			{
				Query: "SHOW GLOBAL VARIABLES LIKE 'dolt_commit_verification_groups'",
				Expected: []sql.Row{
					{"dolt_commit_verification_groups", "unit,integration"},
				},
			},
			{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
				Query: "SET GLOBAL dolt_commit_verification_groups = ''",
				SkipResultsCheck: true,
			},
		},
	},
	{
		Name: "commit verification enabled - all tests pass",
		SetUpScript: []string{
			"SET GLOBAL dolt_commit_verification_groups = '*'",
			"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
			"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
			"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
				"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " +
				"('test_alice_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = \"Alice\"', 'expected_single_value', '==', '1')",
			"CALL dolt_add('.')",
		},
		Assertions: []queries.ScriptTestAssertion{
			{
				Query: "CALL dolt_commit('-m', 'Commit with passing tests')",
				ExpectedColumns: sql.Schema{
					{Name: "hash", Type: types.LongText, Nullable: false},
				},
				Expected: []sql.Row{{commitHash}},
			},
			{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
				Query: "SET GLOBAL dolt_commit_verification_groups = ''",
				SkipResultsCheck: true,
			},
		},
	},
	{
		Name: "commit verification enabled - tests fail, commit aborted",
		SetUpScript: []string{
			"SET GLOBAL dolt_commit_verification_groups = '*'",
			"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
			"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
			"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
				"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " +
				"('test_will_fail', 'integration', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
			"CALL dolt_add('.')",
		},
		Assertions: []queries.ScriptTestAssertion{
			{
				Query: "CALL dolt_commit('-m', 'Commit that should fail verification')",
				ExpectedErrStr: "commit verification failed: test_will_fail (Assertion failed: expected_single_value equal to 999, got 2)",
			},
			{
				// --skip-verification bypasses the failing test and lets the commit through.
				Query: "CALL dolt_commit('--skip-verification','-m', 'skip verification')",
				Expected: []sql.Row{{commitHash}},
			},
			{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
				Query: "SET GLOBAL dolt_commit_verification_groups = ''",
				SkipResultsCheck: true,
			},
		},
	},
	{
		Name: "commit with test verification - specific test groups",
		SetUpScript: []string{
			"SET GLOBAL dolt_commit_verification_groups = 'unit'",
			"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
			"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
			"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
				"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " +
				"('test_will_fail', 'integration', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
			"CALL dolt_add('.')",
		},
		Assertions: []queries.ScriptTestAssertion{
			{
				// Only the 'unit' group runs here, so the failing 'integration' test is ignored.
				Query: "CALL dolt_commit('-m', 'Commit with unit tests only')",
				Expected: []sql.Row{{commitHash}},
			},
			{
				Query: "SET GLOBAL dolt_commit_verification_groups = 'integration'",
				SkipResultsCheck: true,
			},
			{
				Query: "CALL dolt_commit('--allow-empty', '--amend', '-m', 'fail please')",
				ExpectedErrStr: "commit verification failed: test_will_fail (Assertion failed: expected_single_value equal to 999, got 2)",
			},
			{
				Query: "CALL dolt_commit('--allow-empty', '--amend', '--skip-verification', '-m', 'skip the tests')",
				Expected: []sql.Row{{commitHash}},
			},
			{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
				Query: "SET GLOBAL dolt_commit_verification_groups = ''",
				SkipResultsCheck: true,
			},
		},
	},
	{
		// NOTE(review): despite the name, this script also covers the failure
		// case — the second cherry-pick is expected to fail verification.
		Name: "cherry-pick with test verification enabled - tests pass",
		SetUpScript: []string{
			"SET GLOBAL dolt_commit_verification_groups = '*'",
			"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
			"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
			"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
				"('test_user_count_update', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
			"CALL dolt_add('.')",
			"CALL dolt_commit('--skip-verification', '-m', 'add test')",
			"CALL dolt_checkout('-b', 'feature')",
			"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
			"UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_user_count_update'",
			"CALL dolt_add('.')",
			"call dolt_commit_hash_out(@commit_1_hash,'--skip-verification', '-m', 'Add Bob and update test')",
			"INSERT INTO users VALUES (3, 'Charlie', 'chuck@exampl.com')",
			"CALL dolt_add('.')",
			"call dolt_commit_hash_out(@commit_2_hash,'--skip-verification', '-m', 'Add Charlie')",
			"CALL dolt_checkout('main')",
		},
		Assertions: []queries.ScriptTestAssertion{
			{
				// commit_1 brings the updated test along, so verification passes.
				Query: "CALL dolt_cherry_pick(@commit_1_hash)",
				Expected: []sql.Row{{commitHash, int64(0), int64(0), int64(0)}},
			},
			{
				// commit_2 adds a third row without updating the test, so verification fails.
				Query: "CALL dolt_cherry_pick(@commit_2_hash)",
				ExpectedErrStr: "commit verification failed: test_user_count_update (Assertion failed: expected_single_value equal to 2, got 3)",
			},
			{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
				Query: "SET GLOBAL dolt_commit_verification_groups = ''",
				SkipResultsCheck: true,
			},
		},
	},
	{
		Name: "cherry-pick with test verification enabled - tests fail, aborted",
		SetUpScript: []string{
			"SET GLOBAL dolt_commit_verification_groups = '*'",
			"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
			"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
			"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
				"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
			"CALL dolt_add('.')",
			"CALL dolt_commit('-m', 'Initial commit')",
			"CALL dolt_checkout('-b', 'feature')",
			"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
			"CALL dolt_add('.')",
			"call dolt_commit_hash_out(@commit_hash,'--skip-verification', '-m', 'Add Bob but dont update test')",
			"CALL dolt_checkout('main')",
		},
		Assertions: []queries.ScriptTestAssertion{
			{
				Query: "CALL dolt_cherry_pick(@commit_hash)",
				ExpectedErrStr: "commit verification failed: test_users_count (Assertion failed: expected_single_value equal to 1, got 2)",
			},
			{
				Query: "CALL dolt_cherry_pick('--skip-verification', @commit_hash)",
				Expected: []sql.Row{{commitHash, int64(0), int64(0), int64(0)}},
			},
			{
				// Confirm the skipped test is indeed failing after the cherry-pick landed.
				Query: "select * from dolt_test_run('*')",
				Expected: []sql.Row{
					{"test_users_count", "unit", "SELECT COUNT(*) FROM users", "FAIL", "Assertion failed: expected_single_value equal to 1, got 2"},
				},
			},
			{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
				Query: "SET GLOBAL dolt_commit_verification_groups = ''",
				SkipResultsCheck: true,
			},
		},
	},
	{
		Name: "rebase with test verification enabled - tests pass",
		SetUpScript: []string{
			"SET GLOBAL dolt_commit_verification_groups = '*'",
			"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
			"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
			"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
				"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
			"CALL dolt_add('.')",
			"CALL dolt_commit('-m', 'Initial commit')",
			"DELETE FROM users where id = 1",
			"INSERT INTO users VALUES (1, 'Zed', 'zed@example.com')",
			"CALL dolt_commit('-am', 'drop Alice, add Zed')", // tests still pass here.
			"CALL dolt_checkout('-b', 'feature', 'HEAD~1')",
			"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
			"UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_users_count'",
			"CALL dolt_add('.')",
			"CALL dolt_commit('-m', 'Add Bob and update test')",
			"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
			"UPDATE dolt_tests SET assertion_value = '3' WHERE test_name = 'test_users_count'",
			"CALL dolt_add('.')",
			"CALL dolt_commit('-m', 'Add Charlie, update test')",
		},
		Assertions: []queries.ScriptTestAssertion{
			{
				// Each replayed commit keeps its test in sync, so the rebase succeeds.
				Query: "CALL dolt_rebase('main')",
				Expected: []sql.Row{{int64(0), successfulRebaseMessage}},
			},
			{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
				Query: "SET GLOBAL dolt_commit_verification_groups = ''",
				SkipResultsCheck: true,
			},
		},
	},
	{
		Name: "rebase with test verification enabled - tests fail, aborted",
		SetUpScript: []string{
			"SET GLOBAL dolt_commit_verification_groups = '*'",
			"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
			"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
			"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
				"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
			"CALL dolt_add('.')",
			"CALL dolt_commit('-m', 'Initial commit')",
			"CALL dolt_checkout('-b', 'feature')",
			"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
			"UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_users_count'",
			"CALL dolt_add('.')",
			"CALL dolt_commit('-m', 'Add Bob but dont update test')",
			"CALL dolt_checkout('main')",
			"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
			"CALL dolt_add('.')",
			"CALL dolt_commit('--skip-verification', '-m', 'Add Charlie')", // this will trip the existing test.
			"CALL dolt_checkout('feature')",
		},
		Assertions: []queries.ScriptTestAssertion{
			{
				Query: "CALL dolt_rebase('main')",
				ExpectedErrStr: "commit verification failed: test_users_count (Assertion failed: expected_single_value equal to 2, got 3)",
			},
			{
				// NOTE(review): untyped 0 here vs int64(0) elsewhere in this file — confirm intentional.
				Query: "CALL dolt_rebase('--abort')",
				Expected: []sql.Row{{0, "Interactive rebase aborted"}},
			},
			{
				Query: "CALL dolt_rebase('--skip-verification', 'main')",
				Expected: []sql.Row{{int64(0), successfulRebaseMessage}},
			},
			{
				Query: "select * from dolt_test_run('*')",
				Expected: []sql.Row{
					{"test_users_count", "unit", "SELECT COUNT(*) FROM users", "FAIL", "Assertion failed: expected_single_value equal to 2, got 3"},
				},
			},
			{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
				Query: "SET GLOBAL dolt_commit_verification_groups = ''",
				SkipResultsCheck: true,
			},
		},
	},
	{
		Name: "interactive rebase with --skip-verification flag should persist across continue operations",
		SetUpScript: []string{
			"SET GLOBAL dolt_commit_verification_groups = '*'",
			"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
			"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
			"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
				"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
			"CALL dolt_add('.')",
			"CALL dolt_commit('--skip-verification', '-m', 'Initial commit')",
			"CALL dolt_checkout('-b', 'feature')",
			"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
			"CALL dolt_add('.')",
			"CALL dolt_commit('--skip-verification', '-m', 'Add Bob but dont update test')", // This will cause test to fail
			"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
			"CALL dolt_add('.')",
			"CALL dolt_commit('--skip-verification', '-m', 'Add Charlie')",
			"CALL dolt_checkout('main')",
			"INSERT INTO users VALUES (4, 'David', 'david@example.com')", // Add a commit to main to create divergence
			"CALL dolt_add('.')",
			"CALL dolt_commit('--skip-verification', '-m', 'Add David on main')",
			"CALL dolt_checkout('feature')",
		},
		Assertions: []queries.ScriptTestAssertion{
			{
				// NOTE(review): untyped 0 here vs int64(0) elsewhere in this file — confirm intentional.
				Query: "CALL dolt_rebase('--interactive', '--skip-verification', 'main')",
				Expected: []sql.Row{{0, "interactive rebase started on branch dolt_rebase_feature; adjust the rebase plan in the dolt_rebase table, then continue rebasing by calling dolt_rebase('--continue')"}},
			},
			{
				Query: "CALL dolt_rebase('--continue')", // This should NOT require --skip-verification flag but should still skip tests
				Expected: []sql.Row{{int64(0), successfulRebaseMessage}},
			},
			{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
				Query: "SET GLOBAL dolt_commit_verification_groups = ''",
				SkipResultsCheck: true,
			},
		},
	},
	{
		Name: "test verification with no dolt_tests errors",
		SetUpScript: []string{
			"SET GLOBAL dolt_commit_verification_groups = '*'",
			"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
			"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
			"CALL dolt_add('.')",
		},
		Assertions: []queries.ScriptTestAssertion{
			{
				// Verification with no dolt_tests table is an error, not a silent pass.
				Query: "CALL dolt_commit('-m', 'Commit without dolt_tests table')",
				ExpectedErrStr: "failed to run dolt_test_run for group *: could not find tests for argument: *",
			},
			{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
				Query: "SET GLOBAL dolt_commit_verification_groups = ''",
				SkipResultsCheck: true,
			},
		},
	},
	{
		Name: "test verification with mixed test groups - only specified groups run",
		SetUpScript: []string{
			"SET GLOBAL dolt_commit_verification_groups = 'unit'",
			"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
			"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
			"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
				"('test_users_unit', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " +
				"('test_users_integration', 'integration', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
			"CALL dolt_add('.')",
		},
		Assertions: []queries.ScriptTestAssertion{
			{
				Query: "CALL dolt_commit('-m', 'Commit with unit tests only - should pass')",
				Expected: []sql.Row{{commitHash}},
			},
			{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
				Query: "SET GLOBAL dolt_commit_verification_groups = ''",
				SkipResultsCheck: true,
			},
		},
	},
	{
		Name: "test verification error message includes test details",
		SetUpScript: []string{
			"SET GLOBAL dolt_commit_verification_groups = '*'",
			"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
			"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
			"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
				"('test_specific_failure', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
			"CALL dolt_add('.')",
		},
		Assertions: []queries.ScriptTestAssertion{
			{
				Query: "CALL dolt_commit('-m', 'Commit with specific test failure')",
				ExpectedErrStr: "commit verification failed: test_specific_failure (Assertion failed: expected_single_value equal to 999, got 2)",
			},
			{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
				Query: "SET GLOBAL dolt_commit_verification_groups = ''",
				SkipResultsCheck: true,
			},
		},
	},
	{
		Name: "merge with test verification enabled - tests pass",
		SetUpScript: []string{
			"SET GLOBAL dolt_commit_verification_groups = '*'",
			"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
			"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
			"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
				"('test_alice_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = \"Alice\"', 'expected_single_value', '==', '1')",
			"CALL dolt_add('.')",
			"CALL dolt_commit('-m', 'Initial commit')",
			"CALL dolt_checkout('-b', 'feature')",
			"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
			"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
				"('test_bob_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = \"Bob\"', 'expected_single_value', '==', '1')",
			"CALL dolt_add('.')",
			"CALL dolt_commit('--skip-verification', '-m', 'Add Bob')",
			"CALL dolt_checkout('main')",
		},
		Assertions: []queries.ScriptTestAssertion{
			{
				Query: "CALL dolt_merge('feature')",
				Expected: []sql.Row{{commitHash, int64(1), int64(0), "merge successful"}},
			},
			{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
				Query: "SET GLOBAL dolt_commit_verification_groups = ''",
				SkipResultsCheck: true,
			},
		},
	},
	{
		Name: "merge with test verification enabled - tests fail, merge aborted",
		SetUpScript: []string{
			"SET GLOBAL dolt_commit_verification_groups = '*'",
			"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
			"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
			"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
				"('test_will_fail', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
			"CALL dolt_add('.')",
			"CALL dolt_commit('--skip-verification', '-m', 'Initial commit with failing test')",
			"CALL dolt_checkout('-b', 'feature')",
			"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
			"CALL dolt_add('.')",
			"CALL dolt_commit('--skip-verification', '-m', 'Add Bob')",
			"CALL dolt_checkout('main')",
			"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
			"CALL dolt_add('.')",
			"CALL dolt_commit('--skip-verification', '-m', 'Add Charlie to force non-FF merge')",
		},
		Assertions: []queries.ScriptTestAssertion{
			{
				Query: "CALL dolt_merge('feature')",
				ExpectedErrStr: "commit verification failed: test_will_fail (Assertion failed: expected_single_value equal to 999, got 3)",
			},
			{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
				Query: "SET GLOBAL dolt_commit_verification_groups = ''",
				SkipResultsCheck: true,
			},
		},
	},
	{
		Name: "merge with --skip-verification flag bypasses verification",
		SetUpScript: []string{
			"SET GLOBAL dolt_commit_verification_groups = '*'",
			"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
			"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
			"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
				"('test_will_fail', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
			"CALL dolt_add('.')",
			"CALL dolt_commit('--skip-verification', '-m', 'Initial commit with failing test')",
			"CALL dolt_checkout('-b', 'feature')",
			"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
			"CALL dolt_add('.')",
			"CALL dolt_commit('--skip-verification', '-m', 'Add Bob')",
			"CALL dolt_checkout('main')",
			"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
			"CALL dolt_add('.')",
			"CALL dolt_commit('--skip-verification', '-m', 'Add Charlie to force non-FF merge')",
		},
		Assertions: []queries.ScriptTestAssertion{
			{
				Query: "CALL dolt_merge('--skip-verification', 'feature')",
				Expected: []sql.Row{{commitHash, int64(0), int64(0), "merge successful"}},
			},
			{
				// The failing test is still failing — the merge simply skipped it.
				Query: "select * from dolt_test_run('*')",
				Expected: []sql.Row{
					{"test_will_fail", "unit", "SELECT COUNT(*) FROM users", "FAIL", "Assertion failed: expected_single_value equal to 999, got 3"},
				},
			},
			{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
				Query: "SET GLOBAL dolt_commit_verification_groups = ''",
				SkipResultsCheck: true,
			},
		},
	},
}

Some files were not shown because too many files have changed in this diff Show More