mirror of
https://github.com/dolthub/dolt.git
synced 2026-05-13 03:10:03 -05:00
Merge remote-tracking branch 'origin/main' into elian/6329b
This commit is contained in:
@@ -6,6 +6,7 @@ ARG DOLT_VERSION
|
||||
RUN apt update -y && \
|
||||
apt install -y \
|
||||
curl \
|
||||
git \
|
||||
tini \
|
||||
ca-certificates && \
|
||||
apt clean && \
|
||||
|
||||
@@ -4,7 +4,7 @@ FROM debian:bookworm-slim AS base
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
RUN apt-get update -y && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
curl tini ca-certificates && \
|
||||
curl git tini ca-certificates && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
|
||||
@@ -61,6 +61,7 @@ func CreateCommitArgParser(supportsBranchFlag bool) *argparser.ArgParser {
|
||||
ap.SupportsFlag(UpperCaseAllFlag, "A", "Adds all tables and databases (including new tables) in the working set to the staged set.")
|
||||
ap.SupportsFlag(AmendFlag, "", "Amend previous commit")
|
||||
ap.SupportsOptionalString(SignFlag, "S", "key-id", "Sign the commit using GPG. If no key-id is provided the key-id is taken from 'user.signingkey' the in the configuration")
|
||||
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification")
|
||||
if supportsBranchFlag {
|
||||
ap.SupportsString(BranchParam, "", "branch", "Commit to the specified branch instead of the current branch.")
|
||||
}
|
||||
@@ -96,6 +97,7 @@ func CreateMergeArgParser() *argparser.ArgParser {
|
||||
ap.SupportsFlag(NoCommitFlag, "", "Perform the merge and stop just before creating a merge commit. Note this will not prevent a fast-forward merge; use the --no-ff arg together with the --no-commit arg to prevent both fast-forwards and merge commits.")
|
||||
ap.SupportsFlag(NoEditFlag, "", "Use an auto-generated commit message when creating a merge commit. The default for interactive CLI sessions is to open an editor.")
|
||||
ap.SupportsString(AuthorParam, "", "author", "Specify an explicit author using the standard A U Thor {{.LessThan}}author@example.com{{.GreaterThan}} format.")
|
||||
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before merge")
|
||||
|
||||
return ap
|
||||
}
|
||||
@@ -116,6 +118,7 @@ func CreateRebaseArgParser() *argparser.ArgParser {
|
||||
ap.SupportsFlag(AbortParam, "", "Abort an interactive rebase and return the working set to the pre-rebase state")
|
||||
ap.SupportsFlag(ContinueFlag, "", "Continue an interactive rebase after adjusting the rebase plan")
|
||||
ap.SupportsFlag(InteractiveFlag, "i", "Start an interactive rebase")
|
||||
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before rebase")
|
||||
return ap
|
||||
}
|
||||
|
||||
@@ -174,6 +177,7 @@ func CreateRemoteArgParser() *argparser.ArgParser {
|
||||
func CreateCleanArgParser() *argparser.ArgParser {
|
||||
ap := argparser.NewArgParserWithVariableArgs("clean")
|
||||
ap.SupportsFlag(DryRunFlag, "", "Tests removing untracked tables without modifying the working set.")
|
||||
ap.SupportsFlag(ExcludeIgnoreRulesFlag, "x", "Do not respect dolt_ignore; remove untracked tables that match dolt_ignore. dolt_nonlocal_tables is always respected.")
|
||||
return ap
|
||||
}
|
||||
|
||||
@@ -192,6 +196,7 @@ func CreateCherryPickArgParser() *argparser.ArgParser {
|
||||
ap.SupportsFlag(AllowEmptyFlag, "", "Allow empty commits to be cherry-picked. "+
|
||||
"Note that use of this option only keeps commits that were initially empty. "+
|
||||
"Commits which become empty, due to a previous commit, will cause cherry-pick to fail.")
|
||||
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before cherry-pick")
|
||||
ap.TooManyArgsErrorFunc = func(receivedArgs []string) error {
|
||||
return errors.New("cherry-picking multiple commits is not supported yet.")
|
||||
}
|
||||
@@ -229,6 +234,7 @@ func CreatePullArgParser() *argparser.ArgParser {
|
||||
ap.SupportsString(UserFlag, "", "user", "User name to use when authenticating with the remote. Gets password from the environment variable {{.EmphasisLeft}}DOLT_REMOTE_PASSWORD{{.EmphasisRight}}.")
|
||||
ap.SupportsFlag(PruneFlag, "p", "After fetching, remove any remote-tracking references that don't exist on the remote.")
|
||||
ap.SupportsFlag(SilentFlag, "", "Suppress progress information.")
|
||||
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before merge")
|
||||
return ap
|
||||
}
|
||||
|
||||
|
||||
+73
-71
@@ -17,77 +17,79 @@ package cli
|
||||
// Constants for command line flags names. These tend to be used in multiple places, so defining
|
||||
// them low in the package dependency tree makes sense.
|
||||
const (
|
||||
AbortParam = "abort"
|
||||
AllFlag = "all"
|
||||
AllowEmptyFlag = "allow-empty"
|
||||
AmendFlag = "amend"
|
||||
AuthorParam = "author"
|
||||
ArchiveLevelParam = "archive-level"
|
||||
BranchParam = "branch"
|
||||
CachedFlag = "cached"
|
||||
CheckoutCreateBranch = "b"
|
||||
CreateResetBranch = "B"
|
||||
CommitFlag = "commit"
|
||||
ContinueFlag = "continue"
|
||||
CopyFlag = "copy"
|
||||
DateParam = "date"
|
||||
DecorateFlag = "decorate"
|
||||
DeleteFlag = "delete"
|
||||
DeleteForceFlag = "D"
|
||||
DepthFlag = "depth"
|
||||
DryRunFlag = "dry-run"
|
||||
EmptyParam = "empty"
|
||||
ForceFlag = "force"
|
||||
FullFlag = "full"
|
||||
GraphFlag = "graph"
|
||||
HardResetParam = "hard"
|
||||
HostFlag = "host"
|
||||
IncludeUntrackedFlag = "include-untracked"
|
||||
InteractiveFlag = "interactive"
|
||||
JobFlag = "job"
|
||||
ListFlag = "list"
|
||||
MergesFlag = "merges"
|
||||
MessageArg = "message"
|
||||
MinParentsFlag = "min-parents"
|
||||
MoveFlag = "move"
|
||||
NoCommitFlag = "no-commit"
|
||||
NoEditFlag = "no-edit"
|
||||
NoFFParam = "no-ff"
|
||||
FFOnlyParam = "ff-only"
|
||||
NoPrettyFlag = "no-pretty"
|
||||
NoTLSFlag = "no-tls"
|
||||
NoJsonMergeFlag = "dont-merge-json"
|
||||
NotFlag = "not"
|
||||
NumberFlag = "number"
|
||||
OneLineFlag = "oneline"
|
||||
OursFlag = "ours"
|
||||
OutputOnlyFlag = "output-only"
|
||||
ParentsFlag = "parents"
|
||||
PatchFlag = "patch"
|
||||
PasswordFlag = "password"
|
||||
PortFlag = "port"
|
||||
PruneFlag = "prune"
|
||||
QuietFlag = "quiet"
|
||||
RemoteParam = "remote"
|
||||
SetUpstreamFlag = "set-upstream"
|
||||
SetUpstreamToFlag = "set-upstream-to"
|
||||
ShallowFlag = "shallow"
|
||||
ShowIgnoredFlag = "ignored"
|
||||
ShowSignatureFlag = "show-signature"
|
||||
SignFlag = "gpg-sign"
|
||||
SilentFlag = "silent"
|
||||
SingleBranchFlag = "single-branch"
|
||||
SkipEmptyFlag = "skip-empty"
|
||||
SoftResetParam = "soft"
|
||||
SquashParam = "squash"
|
||||
StagedFlag = "staged"
|
||||
StatFlag = "stat"
|
||||
SystemFlag = "system"
|
||||
TablesFlag = "tables"
|
||||
TheirsFlag = "theirs"
|
||||
TrackFlag = "track"
|
||||
UpperCaseAllFlag = "ALL"
|
||||
UserFlag = "user"
|
||||
AbortParam = "abort"
|
||||
AllFlag = "all"
|
||||
AllowEmptyFlag = "allow-empty"
|
||||
AmendFlag = "amend"
|
||||
AuthorParam = "author"
|
||||
ArchiveLevelParam = "archive-level"
|
||||
BranchParam = "branch"
|
||||
CachedFlag = "cached"
|
||||
CheckoutCreateBranch = "b"
|
||||
CreateResetBranch = "B"
|
||||
CommitFlag = "commit"
|
||||
ContinueFlag = "continue"
|
||||
CopyFlag = "copy"
|
||||
DateParam = "date"
|
||||
DecorateFlag = "decorate"
|
||||
DeleteFlag = "delete"
|
||||
DeleteForceFlag = "D"
|
||||
DepthFlag = "depth"
|
||||
DryRunFlag = "dry-run"
|
||||
EmptyParam = "empty"
|
||||
ExcludeIgnoreRulesFlag = "x"
|
||||
ForceFlag = "force"
|
||||
FullFlag = "full"
|
||||
GraphFlag = "graph"
|
||||
HardResetParam = "hard"
|
||||
HostFlag = "host"
|
||||
IncludeUntrackedFlag = "include-untracked"
|
||||
InteractiveFlag = "interactive"
|
||||
JobFlag = "job"
|
||||
ListFlag = "list"
|
||||
MergesFlag = "merges"
|
||||
MessageArg = "message"
|
||||
MinParentsFlag = "min-parents"
|
||||
MoveFlag = "move"
|
||||
NoCommitFlag = "no-commit"
|
||||
NoEditFlag = "no-edit"
|
||||
NoFFParam = "no-ff"
|
||||
FFOnlyParam = "ff-only"
|
||||
NoPrettyFlag = "no-pretty"
|
||||
NoTLSFlag = "no-tls"
|
||||
NoJsonMergeFlag = "dont-merge-json"
|
||||
NotFlag = "not"
|
||||
NumberFlag = "number"
|
||||
OneLineFlag = "oneline"
|
||||
OursFlag = "ours"
|
||||
OutputOnlyFlag = "output-only"
|
||||
ParentsFlag = "parents"
|
||||
PatchFlag = "patch"
|
||||
PasswordFlag = "password"
|
||||
PortFlag = "port"
|
||||
PruneFlag = "prune"
|
||||
QuietFlag = "quiet"
|
||||
RemoteParam = "remote"
|
||||
SetUpstreamFlag = "set-upstream"
|
||||
SetUpstreamToFlag = "set-upstream-to"
|
||||
ShallowFlag = "shallow"
|
||||
ShowIgnoredFlag = "ignored"
|
||||
ShowSignatureFlag = "show-signature"
|
||||
SignFlag = "gpg-sign"
|
||||
SilentFlag = "silent"
|
||||
SingleBranchFlag = "single-branch"
|
||||
SkipEmptyFlag = "skip-empty"
|
||||
SkipVerificationFlag = "skip-verification"
|
||||
SoftResetParam = "soft"
|
||||
SquashParam = "squash"
|
||||
StagedFlag = "staged"
|
||||
StatFlag = "stat"
|
||||
SystemFlag = "system"
|
||||
TablesFlag = "tables"
|
||||
TheirsFlag = "theirs"
|
||||
TrackFlag = "track"
|
||||
UpperCaseAllFlag = "ALL"
|
||||
UserFlag = "user"
|
||||
)
|
||||
|
||||
// Flags used by `dolt diff` command and `dolt_diff()` table function.
|
||||
|
||||
@@ -32,16 +32,17 @@ const (
|
||||
|
||||
var cleanDocContent = cli.CommandDocumentationContent{
|
||||
ShortDesc: "Deletes untracked working tables",
|
||||
LongDesc: "{{.EmphasisLeft}}dolt clean [--dry-run]{{.EmphasisRight}}\n\n" +
|
||||
LongDesc: "{{.EmphasisLeft}}dolt clean [--dry-run] [-x]{{.EmphasisRight}}\n\n" +
|
||||
"The default (parameterless) form clears the values for all untracked working {{.LessThan}}tables{{.GreaterThan}} ." +
|
||||
"This command permanently deletes unstaged or uncommitted tables.\n\n" +
|
||||
"This command permanently deletes unstaged or uncommitted tables. By default, tables matching dolt_ignore or dolt_nonlocal_tables are not removed.\n\n" +
|
||||
"The {{.EmphasisLeft}}--dry-run{{.EmphasisRight}} flag can be used to test whether the clean can succeed without " +
|
||||
"deleting any tables from the current working set.\n\n" +
|
||||
"{{.EmphasisLeft}}dolt clean [--dry-run] {{.LessThan}}tables{{.GreaterThan}}...{{.EmphasisRight}}\n\n" +
|
||||
"The {{.EmphasisLeft}}-x{{.EmphasisRight}} flag causes dolt_ignore to be ignored so that untracked tables matching dolt_ignore are removed; dolt_nonlocal_tables is always respected (similar to git clean -x).\n\n" +
|
||||
"{{.EmphasisLeft}}dolt clean [--dry-run] [-x] {{.LessThan}}tables{{.GreaterThan}}...{{.EmphasisRight}}\n\n" +
|
||||
"If {{.LessThan}}tables{{.GreaterThan}} is specified, only those table names are considered for deleting.\n\n",
|
||||
Synopsis: []string{
|
||||
"[--dry-run]",
|
||||
"[--dry-run] {{.LessThan}}tables{{.GreaterThan}}...",
|
||||
"[--dry-run] [-x]",
|
||||
"[--dry-run] [-x] {{.LessThan}}tables{{.GreaterThan}}...",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -87,6 +88,13 @@ func (cmd CleanCmd) Exec(ctx context.Context, commandStr string, args []string,
|
||||
buffer.WriteString("\"--dry-run\"")
|
||||
firstParamDone = true
|
||||
}
|
||||
if apr.Contains(cli.ExcludeIgnoreRulesFlag) {
|
||||
if firstParamDone {
|
||||
buffer.WriteString(", ")
|
||||
}
|
||||
buffer.WriteString("\"-x\"")
|
||||
firstParamDone = true
|
||||
}
|
||||
if apr.NArg() > 0 {
|
||||
// loop over apr.Args() and add them to the buffer
|
||||
for i := 0; i < apr.NArg(); i++ {
|
||||
|
||||
@@ -266,6 +266,10 @@ func constructParametrizedDoltCommitQuery(msg string, apr *argparser.ArgParseRes
|
||||
writeToBuffer("--skip-empty")
|
||||
}
|
||||
|
||||
if apr.Contains(cli.SkipVerificationFlag) {
|
||||
writeToBuffer("--skip-verification")
|
||||
}
|
||||
|
||||
cfgSign := cliCtx.Config().GetStringOrDefault("sqlserver.global.gpgsign", "")
|
||||
if apr.Contains(cli.SignFlag) || strings.ToLower(cfgSign) == "true" {
|
||||
writeToBuffer("--gpg-sign")
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
"github.com/dolthub/go-mysql-server/sql/types"
|
||||
"github.com/dolthub/vitess/go/sqltypes"
|
||||
ast "github.com/dolthub/vitess/go/vt/sqlparser"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/gocraft/dbr/v2"
|
||||
"github.com/gocraft/dbr/v2/dialect"
|
||||
|
||||
@@ -549,20 +548,6 @@ func getTableNamesAtRef(queryist cli.Queryist, sqlCtx *sql.Context, ref string)
|
||||
return tableNames, nil
|
||||
}
|
||||
|
||||
func isTableNotFoundError(err error) bool {
|
||||
if sql.ErrTableNotFound.Is(err) {
|
||||
return true
|
||||
}
|
||||
mse, ok := err.(*mysql.MySQLError)
|
||||
if ok {
|
||||
if strings.HasPrefix(mse.Message, "table not found:") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// applyDiffRoots applies the appropriate |from| and |to| root values to the receiver and returns the table names
|
||||
// (if any) given to the command.
|
||||
func (dArgs *diffArgs) applyDiffRoots(queryist cli.Queryist, sqlCtx *sql.Context, args []string, isCached, useMergeBase bool) ([]string, error) {
|
||||
|
||||
@@ -623,16 +623,6 @@ func dumpTable(ctx *sql.Context, dEnv *env.DoltEnv, engine *sqle.Engine, root do
|
||||
}
|
||||
|
||||
func getTableWriter(ctx context.Context, dEnv *env.DoltEnv, tblOpts *tableOptions, outSch schema.Schema, filePath string) (table.SqlRowWriter, errhand.VerboseError) {
|
||||
tmpDir, err := dEnv.TempTableFilesDir()
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("error: ").AddCause(err).Build()
|
||||
}
|
||||
deaf, err := dEnv.DbEaFactory(ctx)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("error: ").AddCause(err).Build()
|
||||
}
|
||||
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
|
||||
|
||||
writer, err := dEnv.FS.OpenForWriteAppend(filePath, os.ModePerm)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("Error opening writer for %s.", tblOpts.DestName()).AddCause(err).Build()
|
||||
@@ -643,7 +633,7 @@ func getTableWriter(ctx context.Context, dEnv *env.DoltEnv, tblOpts *tableOption
|
||||
return nil, errhand.BuildDError("Could not create table writer for %s", tblOpts.tableName).AddCause(err).Build()
|
||||
}
|
||||
|
||||
wr, err := tblOpts.dest.NewCreatingWriter(ctx, tblOpts, root, outSch, opts, writer)
|
||||
wr, err := tblOpts.dest.NewCreatingWriter(ctx, tblOpts, root, outSch, editor.Options{}, writer)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("Could not create table writer for %s", tblOpts.tableName).AddCause(err).Build()
|
||||
}
|
||||
|
||||
@@ -143,7 +143,7 @@ func NewSqlEngine(
|
||||
})
|
||||
}
|
||||
|
||||
dbs, locations, err := CollectDBs(ctx, mrEnv, config.Bulk)
|
||||
dbs, locations, err := CollectDBs(ctx, mrEnv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -26,13 +26,13 @@ import (
|
||||
|
||||
// CollectDBs takes a MultiRepoEnv and creates Database objects from each environment and returns a slice of these
|
||||
// objects.
|
||||
func CollectDBs(ctx context.Context, mrEnv *env.MultiRepoEnv, useBulkEditor bool) ([]dsess.SqlDatabase, []filesys.Filesys, error) {
|
||||
func CollectDBs(ctx context.Context, mrEnv *env.MultiRepoEnv) ([]dsess.SqlDatabase, []filesys.Filesys, error) {
|
||||
var dbs []dsess.SqlDatabase
|
||||
var locations []filesys.Filesys
|
||||
var db dsess.SqlDatabase
|
||||
|
||||
err := mrEnv.Iter(func(name string, dEnv *env.DoltEnv) (stop bool, err error) {
|
||||
db, err = newDatabase(ctx, name, dEnv, useBulkEditor)
|
||||
db, err = newDatabase(ctx, name, dEnv)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -50,25 +50,7 @@ func CollectDBs(ctx context.Context, mrEnv *env.MultiRepoEnv, useBulkEditor bool
|
||||
return dbs, locations, nil
|
||||
}
|
||||
|
||||
func newDatabase(ctx context.Context, name string, dEnv *env.DoltEnv, useBulkEditor bool) (sqle.Database, error) {
|
||||
var deaf editor.DbEaFactory
|
||||
var err error
|
||||
if useBulkEditor {
|
||||
deaf, err = dEnv.BulkDbEaFactory(ctx)
|
||||
} else {
|
||||
deaf, err = dEnv.DbEaFactory(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return sqle.Database{}, err
|
||||
}
|
||||
tmpDir, err := dEnv.TempTableFilesDir()
|
||||
if err != nil {
|
||||
return sqle.Database{}, err
|
||||
}
|
||||
opts := editor.Options{
|
||||
Deaf: deaf,
|
||||
Tempdir: tmpDir,
|
||||
}
|
||||
func newDatabase(ctx context.Context, name string, dEnv *env.DoltEnv) (sqle.Database, error) {
|
||||
dbdata := dEnv.DbData(ctx)
|
||||
// Databases registered with the SQL engine are always
|
||||
// configured for FatalBehaviorCrash. These are local
|
||||
@@ -81,5 +63,5 @@ func newDatabase(ctx context.Context, name string, dEnv *env.DoltEnv, useBulkEdi
|
||||
// See also sqle/database_provider.go, where we do this when
|
||||
// creating new databases as well.
|
||||
dbdata.Ddb.SetCrashOnFatalError()
|
||||
return sqle.NewDatabase(ctx, name, dbdata, opts)
|
||||
return sqle.NewDatabase(ctx, name, dbdata, editor.Options{})
|
||||
}
|
||||
|
||||
@@ -335,16 +335,7 @@ func processFilterQuery(ctx context.Context, dEnv *env.DoltEnv, root doltdb.Root
|
||||
// we set manually with the one at the working set of the HEAD being rebased.
|
||||
// Some functionality will not work on this kind of engine, e.g. many DOLT_ functions.
|
||||
func rebaseSqlEngine(ctx context.Context, dEnv *env.DoltEnv, root doltdb.RootValue) (*sql.Context, *engine.SqlEngine, error) {
|
||||
tmpDir, err := dEnv.TempTableFilesDir()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
deaf, err := dEnv.DbEaFactory(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
|
||||
db, err := dsqle.NewDatabase(ctx, filterDbName, dEnv.DbData(ctx), opts)
|
||||
db, err := dsqle.NewDatabase(ctx, filterDbName, dEnv.DbData(ctx), editor.Options{})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
@@ -87,15 +87,7 @@ func (cmd RebuildCmd) Exec(ctx context.Context, commandStr string, args []string
|
||||
if !ok {
|
||||
return HandleErr(errhand.BuildDError("The table `%s` does not exist.", tableName).Build(), nil)
|
||||
}
|
||||
tmpDir, err := dEnv.TempTableFilesDir()
|
||||
if err != nil {
|
||||
return HandleErr(errhand.BuildDError("error: ").AddCause(err).Build(), nil)
|
||||
}
|
||||
deaf, err := dEnv.DbEaFactory(ctx)
|
||||
if err != nil {
|
||||
return HandleErr(errhand.BuildDError("error: ").AddCause(err).Build(), nil)
|
||||
}
|
||||
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
|
||||
|
||||
sch, err := table.GetSchema(ctx)
|
||||
if err != nil {
|
||||
return HandleErr(errhand.BuildDError("could not get table schema").AddCause(err).Build(), nil)
|
||||
@@ -104,7 +96,7 @@ func (cmd RebuildCmd) Exec(ctx context.Context, commandStr string, args []string
|
||||
if idxSch == nil {
|
||||
return HandleErr(errhand.BuildDError("the index `%s` does not exist on table `%s`", indexName, tableName).Build(), nil)
|
||||
}
|
||||
indexRowData, err := creation.BuildSecondaryIndex(sql.NewContext(ctx), table, idxSch, tableName, opts)
|
||||
indexRowData, err := creation.BuildSecondaryIndex(sql.NewContext(ctx), table, idxSch, tableName, editor.Options{})
|
||||
if err != nil {
|
||||
return HandleErr(errhand.BuildDError("Unable to rebuild index `%s` on table `%s`.", indexName, tableName).AddCause(err).Build(), nil)
|
||||
}
|
||||
|
||||
@@ -318,6 +318,10 @@ func constructInterpolatedDoltMergeQuery(apr *argparser.ArgParseResults, cliCtx
|
||||
params = append(params, msg)
|
||||
}
|
||||
|
||||
if apr.Contains(cli.SkipVerificationFlag) {
|
||||
writeToBuffer("--skip-verification", false)
|
||||
}
|
||||
|
||||
if !apr.Contains(cli.AbortParam) && !apr.Contains(cli.SquashParam) {
|
||||
writeToBuffer("?", true)
|
||||
params = append(params, apr.Arg(0))
|
||||
|
||||
@@ -134,16 +134,7 @@ func exportSchemas(ctx context.Context, apr *argparser.ArgParseResults, root dol
|
||||
}
|
||||
|
||||
for _, tn := range tablesToExport {
|
||||
tmpDir, err := dEnv.TempTableFilesDir()
|
||||
if err != nil {
|
||||
return errhand.BuildDError("error: ").AddCause(err).Build()
|
||||
}
|
||||
deaf, err := dEnv.DbEaFactory(ctx)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("error: ").AddCause(err).Build()
|
||||
}
|
||||
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
|
||||
verr := exportTblSchema(ctx, tn, root, wr, opts)
|
||||
verr := exportTblSchema(ctx, tn, root, wr, editor.Options{})
|
||||
if verr != nil {
|
||||
return verr
|
||||
}
|
||||
|
||||
@@ -133,16 +133,7 @@ func printSchemas(ctx context.Context, apr *argparser.ArgParseResults, dEnv *env
|
||||
}
|
||||
}
|
||||
|
||||
tmpDir, err := dEnv.TempTableFilesDir()
|
||||
if err != nil {
|
||||
return errhand.BuildDError("error: ").AddCause(err).Build()
|
||||
}
|
||||
deaf, err := dEnv.DbEaFactory(ctx)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("error: ").AddCause(err).Build()
|
||||
}
|
||||
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
|
||||
sqlCtx, engine, _ := dsqle.PrepareCreateTableStmt(ctx, dsqle.NewUserSpaceDatabase(root, opts))
|
||||
sqlCtx, engine, _ := dsqle.PrepareCreateTableStmt(ctx, dsqle.NewUserSpaceDatabase(root, editor.Options{}))
|
||||
|
||||
var notFound []string
|
||||
for _, tblName := range tables {
|
||||
|
||||
@@ -156,6 +156,9 @@ func mcpRun(cfg *Config, lgr *logrus.Logger, state *svcs.ServiceState, cancelPtr
|
||||
logger,
|
||||
dbConf,
|
||||
*cfg.MCP.Port,
|
||||
nil, // jwkClaimsMap
|
||||
"", // jwkUrl
|
||||
nil, // tlsConfig
|
||||
toolsets.WithToolSet(&toolsets.PrimitiveToolSetV1{}),
|
||||
)
|
||||
if err != nil {
|
||||
|
||||
@@ -258,11 +258,7 @@ func getTableWriter(ctx context.Context, root doltdb.RootValue, dEnv *env.DoltEn
|
||||
return nil, errhand.BuildDError("Error opening writer for %s.", exOpts.DestName()).AddCause(err).Build()
|
||||
}
|
||||
|
||||
deaf, err := dEnv.DbEaFactory(ctx)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("Error opening writer for %s.", exOpts.DestName()).AddCause(err).Build()
|
||||
}
|
||||
wr, err := exOpts.dest.NewCreatingWriter(ctx, exOpts, root, rdSchema, editor.Options{Deaf: deaf}, writer)
|
||||
wr, err := exOpts.dest.NewCreatingWriter(ctx, exOpts, root, rdSchema, editor.Options{}, writer)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("Error opening writer for %s.", exOpts.DestName()).AddCause(err).Build()
|
||||
}
|
||||
|
||||
@@ -15,5 +15,5 @@
|
||||
package doltversion
|
||||
|
||||
const (
|
||||
Version = "1.81.8"
|
||||
Version = "1.81.10"
|
||||
)
|
||||
|
||||
@@ -579,7 +579,19 @@ func (rcv *RebaseState) MutateRebasingStarted(n bool) bool {
|
||||
return rcv._tab.MutateBoolSlot(16, n)
|
||||
}
|
||||
|
||||
const RebaseStateNumFields = 7
|
||||
func (rcv *RebaseState) SkipVerification() bool {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(18))
|
||||
if o != 0 {
|
||||
return rcv._tab.GetBool(o + rcv._tab.Pos)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (rcv *RebaseState) MutateSkipVerification(n bool) bool {
|
||||
return rcv._tab.MutateBoolSlot(18, n)
|
||||
}
|
||||
|
||||
const RebaseStateNumFields = 8
|
||||
|
||||
func RebaseStateStart(builder *flatbuffers.Builder) {
|
||||
builder.StartObject(RebaseStateNumFields)
|
||||
@@ -614,6 +626,9 @@ func RebaseStateAddLastAttemptedStep(builder *flatbuffers.Builder, lastAttempted
|
||||
func RebaseStateAddRebasingStarted(builder *flatbuffers.Builder, rebasingStarted bool) {
|
||||
builder.PrependBoolSlot(6, rebasingStarted, false)
|
||||
}
|
||||
func RebaseStateAddSkipVerification(builder *flatbuffers.Builder, skipVerification bool) {
|
||||
builder.PrependBoolSlot(7, skipVerification, false)
|
||||
}
|
||||
func RebaseStateEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT {
|
||||
return builder.EndObject()
|
||||
}
|
||||
|
||||
@@ -58,10 +58,10 @@ require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0
|
||||
github.com/creasty/defaults v1.6.0
|
||||
github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12
|
||||
github.com/dolthub/dolt-mcp v0.2.2
|
||||
github.com/dolthub/dolt-mcp v0.3.4
|
||||
github.com/dolthub/eventsapi_schema v0.0.0-20260205214132-a7a3c84c84a1
|
||||
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2
|
||||
github.com/dolthub/go-mysql-server v0.20.1-0.20260211220532-85072e590dc7
|
||||
github.com/dolthub/go-mysql-server v0.20.1-0.20260212215527-0cb492ad7051
|
||||
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63
|
||||
github.com/edsrzf/mmap-go v1.2.0
|
||||
github.com/esote/minmaxheap v1.0.0
|
||||
|
||||
@@ -186,8 +186,8 @@ github.com/denisenkom/go-mssqldb v0.10.0 h1:QykgLZBorFE95+gO3u9esLd0BmbvpWp0/waN
|
||||
github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
|
||||
github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12 h1:IdqX7J8vi/Kn3T3Ee0VzqnLqwFmgA2hr8WZETPcQjfM=
|
||||
github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12/go.mod h1:rN7X8BHwkjPcfMQQ2QTAq/xM3leUSGLfb+1Js7Y6TVo=
|
||||
github.com/dolthub/dolt-mcp v0.2.2 h1:bpROmam74n95uU4EA3BpOIVlTDT0pzeFMBwe/YRq2mI=
|
||||
github.com/dolthub/dolt-mcp v0.2.2/go.mod h1:S++DJ4QWTAXq+0TNzFa7Oq3IhoT456DJHwAINFAHgDQ=
|
||||
github.com/dolthub/dolt-mcp v0.3.4 h1:AyG5cw+fNWXDHXujtQnqUPZrpWtPg6FN6yYtjv1pP44=
|
||||
github.com/dolthub/dolt-mcp v0.3.4/go.mod h1:bCZ7KHvDYs+M0e+ySgmGiNvLhcwsN7bbf5YCyillLrk=
|
||||
github.com/dolthub/eventsapi_schema v0.0.0-20260205214132-a7a3c84c84a1 h1:QePoMpa5qlquwUqRVyF9KAHsJAlYbE2+eZkMPAxeBXc=
|
||||
github.com/dolthub/eventsapi_schema v0.0.0-20260205214132-a7a3c84c84a1/go.mod h1:evuptFmr/0/j0X/g+3cveHEEOM5tqyRA15FNgirtOY0=
|
||||
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww=
|
||||
@@ -196,8 +196,8 @@ github.com/dolthub/fslock v0.0.0-20251215194149-ef20baba2318 h1:n+vdH5G5Db+1qnDC
|
||||
github.com/dolthub/fslock v0.0.0-20251215194149-ef20baba2318/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
|
||||
github.com/dolthub/go-icu-regex v0.0.0-20250916051405-78a38d478790 h1:zxMsH7RLiG+dlZ/y0LgJHTV26XoiSJcuWq+em6t6VVc=
|
||||
github.com/dolthub/go-icu-regex v0.0.0-20250916051405-78a38d478790/go.mod h1:F3cnm+vMRK1HaU6+rNqQrOCyR03HHhR1GWG2gnPOqaE=
|
||||
github.com/dolthub/go-mysql-server v0.20.1-0.20260211220532-85072e590dc7 h1:9xC+/i949mi2wwsu6BKgvnDnuRcYy4KysrIb2x7DaSo=
|
||||
github.com/dolthub/go-mysql-server v0.20.1-0.20260211220532-85072e590dc7/go.mod h1:LEWdXw6LKjdonOv2X808RpUc8wZVtQx4ZEPvmDWkvY4=
|
||||
github.com/dolthub/go-mysql-server v0.20.1-0.20260212215527-0cb492ad7051 h1:7vNnl/Z2HhFFUTdXNOySd8KFODBztPlmCITrRIKDgTw=
|
||||
github.com/dolthub/go-mysql-server v0.20.1-0.20260212215527-0cb492ad7051/go.mod h1:LEWdXw6LKjdonOv2X808RpUc8wZVtQx4ZEPvmDWkvY4=
|
||||
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI=
|
||||
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q=
|
||||
github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE=
|
||||
|
||||
@@ -52,6 +52,9 @@ type CherryPickOptions struct {
|
||||
// and Dolt cherry-pick implementations, the default action is to fail when an empty commit is specified. In Git
|
||||
// and Dolt rebase implementations, the default action is to keep commits that start off as empty.
|
||||
EmptyCommitHandling doltdb.EmptyCommitHandling
|
||||
|
||||
// SkipVerification controls whether test validation should be skipped before creating commits.
|
||||
SkipVerification bool
|
||||
}
|
||||
|
||||
// NewCherryPickOptions creates a new CherryPickOptions instance, filled out with default values for cherry-pick.
|
||||
@@ -61,6 +64,7 @@ func NewCherryPickOptions() CherryPickOptions {
|
||||
CommitMessage: "",
|
||||
CommitBecomesEmptyHandling: doltdb.ErrorOnEmptyCommit,
|
||||
EmptyCommitHandling: doltdb.ErrorOnEmptyCommit,
|
||||
SkipVerification: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,9 +163,10 @@ func CreateCommitStagedPropsFromCherryPickOptions(ctx *sql.Context, options Cher
|
||||
}
|
||||
|
||||
commitProps := actions.CommitStagedProps{
|
||||
Date: originalMeta.Time(),
|
||||
Name: originalMeta.Name,
|
||||
Email: originalMeta.Email,
|
||||
Date: originalMeta.Time(),
|
||||
Name: originalMeta.Name,
|
||||
Email: originalMeta.Email,
|
||||
SkipVerification: options.SkipVerification,
|
||||
}
|
||||
|
||||
if options.CommitMessage != "" {
|
||||
|
||||
@@ -1,156 +0,0 @@
|
||||
// Copyright 2019 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package conflict
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema/encoding"
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
type ConflictSchema struct {
|
||||
Base schema.Schema
|
||||
Schema schema.Schema
|
||||
MergeSchema schema.Schema
|
||||
}
|
||||
|
||||
func NewConflictSchema(base, sch, mergeSch schema.Schema) ConflictSchema {
|
||||
return ConflictSchema{
|
||||
Base: base,
|
||||
Schema: sch,
|
||||
MergeSchema: mergeSch,
|
||||
}
|
||||
}
|
||||
|
||||
func ValueFromConflictSchema(ctx context.Context, vrw types.ValueReadWriter, cs ConflictSchema) (types.Value, error) {
|
||||
b, err := serializeSchema(ctx, vrw, cs.Base)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s, err := serializeSchema(ctx, vrw, cs.Schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m, err := serializeSchema(ctx, vrw, cs.MergeSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return types.NewTuple(vrw.Format(), b, s, m)
|
||||
}
|
||||
|
||||
func ConflictSchemaFromValue(ctx context.Context, vrw types.ValueReadWriter, v types.Value) (cs ConflictSchema, err error) {
|
||||
tup, ok := v.(types.Tuple)
|
||||
if !ok {
|
||||
err = errors.New("conflict schema value must be types.Struct")
|
||||
return ConflictSchema{}, err
|
||||
}
|
||||
|
||||
b, err := tup.Get(0)
|
||||
if err != nil {
|
||||
return ConflictSchema{}, err
|
||||
}
|
||||
cs.Base, err = deserializeSchema(ctx, vrw, b)
|
||||
if err != nil {
|
||||
return ConflictSchema{}, err
|
||||
}
|
||||
|
||||
s, err := tup.Get(1)
|
||||
if err != nil {
|
||||
return ConflictSchema{}, err
|
||||
}
|
||||
cs.Schema, err = deserializeSchema(ctx, vrw, s)
|
||||
if err != nil {
|
||||
return ConflictSchema{}, err
|
||||
}
|
||||
|
||||
m, err := tup.Get(2)
|
||||
if err != nil {
|
||||
return ConflictSchema{}, err
|
||||
}
|
||||
cs.MergeSchema, err = deserializeSchema(ctx, vrw, m)
|
||||
if err != nil {
|
||||
return ConflictSchema{}, err
|
||||
}
|
||||
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
func serializeSchema(ctx context.Context, vrw types.ValueReadWriter, sch schema.Schema) (types.Ref, error) {
|
||||
st, err := encoding.MarshalSchema(ctx, vrw, sch)
|
||||
if err != nil {
|
||||
return types.Ref{}, err
|
||||
}
|
||||
|
||||
return vrw.WriteValue(ctx, st)
|
||||
}
|
||||
|
||||
func deserializeSchema(ctx context.Context, vrw types.ValueReadWriter, v types.Value) (schema.Schema, error) {
|
||||
r, ok := v.(types.Ref)
|
||||
if !ok {
|
||||
return nil, errors.New("conflict schemas field value is unexpected type")
|
||||
}
|
||||
|
||||
return encoding.UnmarshalSchemaAtAddr(ctx, vrw, r.TargetHash())
|
||||
}
|
||||
|
||||
type Conflict struct {
|
||||
Base types.Value
|
||||
Value types.Value
|
||||
MergeValue types.Value
|
||||
}
|
||||
|
||||
func NewConflict(base, value, mergeValue types.Value) Conflict {
|
||||
if base == nil {
|
||||
base = types.NullValue
|
||||
}
|
||||
if value == nil {
|
||||
value = types.NullValue
|
||||
}
|
||||
if mergeValue == nil {
|
||||
mergeValue = types.NullValue
|
||||
}
|
||||
return Conflict{base, value, mergeValue}
|
||||
}
|
||||
|
||||
func ConflictFromTuple(tpl types.Tuple) (Conflict, error) {
|
||||
base, err := tpl.Get(0)
|
||||
|
||||
if err != nil {
|
||||
return Conflict{}, err
|
||||
}
|
||||
|
||||
val, err := tpl.Get(1)
|
||||
|
||||
if err != nil {
|
||||
return Conflict{}, err
|
||||
}
|
||||
|
||||
mv, err := tpl.Get(2)
|
||||
|
||||
if err != nil {
|
||||
return Conflict{}, err
|
||||
}
|
||||
return Conflict{base, val, mv}, nil
|
||||
}
|
||||
|
||||
func (c Conflict) ToNomsList(vrw types.ValueReadWriter) (types.Tuple, error) {
|
||||
return types.NewTuple(vrw.Format(), c.Base, c.Value, c.MergeValue)
|
||||
}
|
||||
@@ -1,285 +0,0 @@
|
||||
// Copyright 2019 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package diff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/row"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/async"
|
||||
"github.com/dolthub/dolt/go/store/diff"
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
// todo: make package private
|
||||
type AsyncDiffer struct {
|
||||
diffChan chan diff.Difference
|
||||
bufferSize int
|
||||
|
||||
eg *errgroup.Group
|
||||
egCtx context.Context
|
||||
egCancel func()
|
||||
|
||||
diffStats map[types.DiffChangeType]uint64
|
||||
}
|
||||
|
||||
var _ RowDiffer = &AsyncDiffer{}
|
||||
|
||||
// todo: make package private once dolthub is migrated
|
||||
func NewAsyncDiffer(bufferedDiffs int) *AsyncDiffer {
|
||||
return &AsyncDiffer{
|
||||
diffChan: make(chan diff.Difference, bufferedDiffs),
|
||||
bufferSize: bufferedDiffs,
|
||||
egCtx: context.Background(),
|
||||
egCancel: func() {},
|
||||
diffStats: make(map[types.DiffChangeType]uint64),
|
||||
}
|
||||
}
|
||||
|
||||
func tableDontDescendLists(v1, v2 types.Value) bool {
|
||||
kind := v1.Kind()
|
||||
return !types.IsPrimitiveKind(kind) && kind != types.TupleKind && kind == v2.Kind() && kind != types.RefKind
|
||||
}
|
||||
|
||||
func (ad *AsyncDiffer) Start(ctx context.Context, from, to types.Map) {
|
||||
ad.start(ctx, func(ctx context.Context) error {
|
||||
return diff.Diff(ctx, from, to, ad.diffChan, true, tableDontDescendLists)
|
||||
})
|
||||
}
|
||||
|
||||
func (ad *AsyncDiffer) StartWithRange(ctx context.Context, from, to types.Map, start types.Value, inRange types.ValueInRange) {
|
||||
ad.start(ctx, func(ctx context.Context) error {
|
||||
return diff.DiffMapRange(ctx, from, to, start, inRange, ad.diffChan, true, tableDontDescendLists)
|
||||
})
|
||||
}
|
||||
|
||||
func (ad *AsyncDiffer) start(ctx context.Context, diffFunc func(ctx context.Context) error) {
|
||||
ad.eg, ad.egCtx = errgroup.WithContext(ctx)
|
||||
ad.egCancel = async.GoWithCancel(ad.egCtx, ad.eg, func(ctx context.Context) (err error) {
|
||||
defer close(ad.diffChan)
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in diff.Diff: %v", r)
|
||||
}
|
||||
}()
|
||||
return diffFunc(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func (ad *AsyncDiffer) Close() error {
|
||||
ad.egCancel()
|
||||
return ad.eg.Wait()
|
||||
}
|
||||
|
||||
func (ad *AsyncDiffer) getDiffs(numDiffs int, timeoutChan <-chan time.Time, pred diffPredicate) ([]*diff.Difference, bool, error) {
|
||||
diffs := make([]*diff.Difference, 0, numDiffs)
|
||||
for {
|
||||
select {
|
||||
case d, more := <-ad.diffChan:
|
||||
if more {
|
||||
if pred(&d) {
|
||||
ad.diffStats[d.ChangeType]++
|
||||
diffs = append(diffs, &d)
|
||||
}
|
||||
if numDiffs != 0 && numDiffs == len(diffs) {
|
||||
return diffs, true, nil
|
||||
}
|
||||
} else {
|
||||
return diffs, false, ad.eg.Wait()
|
||||
}
|
||||
case <-timeoutChan:
|
||||
return diffs, true, nil
|
||||
case <-ad.egCtx.Done():
|
||||
return nil, false, ad.eg.Wait()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var forever <-chan time.Time = make(chan time.Time)
|
||||
|
||||
type diffPredicate func(*diff.Difference) bool
|
||||
|
||||
var alwaysTruePredicate diffPredicate = func(*diff.Difference) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func hasChangeTypePredicate(changeType types.DiffChangeType) diffPredicate {
|
||||
return func(d *diff.Difference) bool {
|
||||
return d.ChangeType == changeType
|
||||
}
|
||||
}
|
||||
|
||||
func (ad *AsyncDiffer) GetDiffs(numDiffs int, timeout time.Duration) ([]*diff.Difference, bool, error) {
|
||||
if timeout < 0 {
|
||||
return ad.GetDiffsWithoutTimeout(numDiffs)
|
||||
}
|
||||
return ad.getDiffs(numDiffs, time.After(timeout), alwaysTruePredicate)
|
||||
}
|
||||
|
||||
func (ad *AsyncDiffer) GetDiffsWithFilter(numDiffs int, timeout time.Duration, filterByChangeType types.DiffChangeType) ([]*diff.Difference, bool, error) {
|
||||
if timeout < 0 {
|
||||
return ad.GetDiffsWithoutTimeoutWithFilter(numDiffs, filterByChangeType)
|
||||
}
|
||||
return ad.getDiffs(numDiffs, time.After(timeout), hasChangeTypePredicate(filterByChangeType))
|
||||
}
|
||||
|
||||
func (ad *AsyncDiffer) GetDiffsWithoutTimeoutWithFilter(numDiffs int, filterByChangeType types.DiffChangeType) ([]*diff.Difference, bool, error) {
|
||||
return ad.getDiffs(numDiffs, forever, hasChangeTypePredicate(filterByChangeType))
|
||||
}
|
||||
|
||||
func (ad *AsyncDiffer) GetDiffsWithoutTimeout(numDiffs int) ([]*diff.Difference, bool, error) {
|
||||
return ad.getDiffs(numDiffs, forever, alwaysTruePredicate)
|
||||
}
|
||||
|
||||
type keylessDiffer struct {
|
||||
*AsyncDiffer
|
||||
|
||||
df diff.Difference
|
||||
copiesLeft uint64
|
||||
}
|
||||
|
||||
var _ RowDiffer = &keylessDiffer{}
|
||||
|
||||
func (kd *keylessDiffer) getDiffs(numDiffs int, timeoutChan <-chan time.Time, pred diffPredicate) ([]*diff.Difference, bool, error) {
|
||||
diffs := make([]*diff.Difference, numDiffs)
|
||||
idx := 0
|
||||
|
||||
for {
|
||||
// first populate |diffs| with copies of |kd.df|
|
||||
|
||||
cpy := kd.df // save a copy of kd.df to reference
|
||||
for (idx < numDiffs) && (kd.copiesLeft > 0) {
|
||||
diffs[idx] = &cpy
|
||||
idx++
|
||||
kd.copiesLeft--
|
||||
}
|
||||
if idx == numDiffs {
|
||||
return diffs, true, nil
|
||||
}
|
||||
|
||||
// then find the next Difference the satisfies |pred|
|
||||
match := false
|
||||
for !match {
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
return diffs, true, nil
|
||||
|
||||
case <-kd.egCtx.Done():
|
||||
return nil, false, kd.eg.Wait()
|
||||
|
||||
case d, more := <-kd.diffChan:
|
||||
if !more {
|
||||
return diffs[:idx], more, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
kd.df, kd.copiesLeft, err = convertDiff(d)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
match = pred(&kd.df)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (kd *keylessDiffer) GetDiffs(numDiffs int, timeout time.Duration) ([]*diff.Difference, bool, error) {
|
||||
if timeout < 0 {
|
||||
return kd.getDiffs(numDiffs, forever, alwaysTruePredicate)
|
||||
}
|
||||
return kd.getDiffs(numDiffs, time.After(timeout), alwaysTruePredicate)
|
||||
}
|
||||
|
||||
func (kd *keylessDiffer) GetDiffsWithFilter(numDiffs int, timeout time.Duration, filterByChangeType types.DiffChangeType) ([]*diff.Difference, bool, error) {
|
||||
if timeout < 0 {
|
||||
return kd.getDiffs(numDiffs, forever, hasChangeTypePredicate(filterByChangeType))
|
||||
}
|
||||
return kd.getDiffs(numDiffs, time.After(timeout), hasChangeTypePredicate(filterByChangeType))
|
||||
}
|
||||
|
||||
// convertDiff reports the cardinality of a change,
|
||||
// and converts updates to adds or deletes
|
||||
func convertDiff(df diff.Difference) (diff.Difference, uint64, error) {
|
||||
var oldCard uint64
|
||||
if df.OldValue != nil {
|
||||
v, err := df.OldValue.(types.Tuple).Get(row.KeylessCardinalityValIdx)
|
||||
if err != nil {
|
||||
return df, 0, err
|
||||
}
|
||||
oldCard = uint64(v.(types.Uint))
|
||||
}
|
||||
|
||||
var newCard uint64
|
||||
if df.NewValue != nil {
|
||||
v, err := df.NewValue.(types.Tuple).Get(row.KeylessCardinalityValIdx)
|
||||
if err != nil {
|
||||
return df, 0, err
|
||||
}
|
||||
newCard = uint64(v.(types.Uint))
|
||||
}
|
||||
|
||||
switch df.ChangeType {
|
||||
case types.DiffChangeRemoved:
|
||||
return df, oldCard, nil
|
||||
|
||||
case types.DiffChangeAdded:
|
||||
return df, newCard, nil
|
||||
|
||||
case types.DiffChangeModified:
|
||||
delta := int64(newCard) - int64(oldCard)
|
||||
if delta > 0 {
|
||||
df.ChangeType = types.DiffChangeAdded
|
||||
df.OldValue = nil
|
||||
return df, uint64(delta), nil
|
||||
} else if delta < 0 {
|
||||
df.ChangeType = types.DiffChangeRemoved
|
||||
df.NewValue = nil
|
||||
return df, uint64(-delta), nil
|
||||
} else {
|
||||
panic(fmt.Sprintf("diff with delta = 0 for key: %s", df.KeyValue.HumanReadableString()))
|
||||
}
|
||||
default:
|
||||
return df, 0, fmt.Errorf("unexpected DiffChange type %d", df.ChangeType)
|
||||
}
|
||||
}
|
||||
|
||||
type EmptyRowDiffer struct {
|
||||
}
|
||||
|
||||
var _ RowDiffer = &EmptyRowDiffer{}
|
||||
|
||||
func (e EmptyRowDiffer) Start(ctx context.Context, from, to types.Map) {
|
||||
}
|
||||
|
||||
func (e EmptyRowDiffer) StartWithRange(ctx context.Context, from, to types.Map, start types.Value, inRange types.ValueInRange) {
|
||||
|
||||
}
|
||||
|
||||
func (e EmptyRowDiffer) GetDiffs(numDiffs int, timeout time.Duration) ([]*diff.Difference, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (e EmptyRowDiffer) GetDiffsWithFilter(numDiffs int, timeout time.Duration, filterByChangeType types.DiffChangeType) ([]*diff.Difference, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (e EmptyRowDiffer) Close() error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,345 +0,0 @@
|
||||
// Copyright 2021 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package diff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/dolthub/dolt/go/store/chunks"
|
||||
"github.com/dolthub/dolt/go/store/constants"
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
func TestAsyncDiffer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
storage := &chunks.MemoryStorage{}
|
||||
vrw := types.NewValueStore(storage.NewView())
|
||||
|
||||
vals := []types.Value{
|
||||
types.Uint(0), types.String("a"),
|
||||
types.Uint(1), types.String("b"),
|
||||
types.Uint(3), types.String("d"),
|
||||
types.Uint(4), types.String("e"),
|
||||
types.Uint(6), types.String("g"),
|
||||
types.Uint(7), types.String("h"),
|
||||
types.Uint(9), types.String("j"),
|
||||
types.Uint(10), types.String("k"),
|
||||
types.Uint(12), types.String("m"),
|
||||
types.Uint(13), types.String("n"),
|
||||
types.Uint(15), types.String("p"),
|
||||
types.Uint(16), types.String("q"),
|
||||
types.Uint(18), types.String("s"),
|
||||
types.Uint(19), types.String("t"),
|
||||
types.Uint(21), types.String("v"),
|
||||
types.Uint(22), types.String("w"),
|
||||
types.Uint(24), types.String("y"),
|
||||
types.Uint(25), types.String("z"),
|
||||
}
|
||||
|
||||
m1, err := types.NewMap(ctx, vrw, vals...)
|
||||
require.NoError(t, err)
|
||||
|
||||
vals = []types.Value{
|
||||
types.Uint(0), types.String("a"), // unchanged
|
||||
//types.Uint(1), types.String("b"), // deleted
|
||||
types.Uint(2), types.String("c"), // added
|
||||
types.Uint(3), types.String("d"), // unchanged
|
||||
//types.Uint(4), types.String("e"), // deleted
|
||||
types.Uint(5), types.String("f"), // added
|
||||
types.Uint(6), types.String("g"), // unchanged
|
||||
//types.Uint(7), types.String("h"), // deleted
|
||||
types.Uint(8), types.String("i"), // added
|
||||
types.Uint(9), types.String("j"), // unchanged
|
||||
//types.Uint(10), types.String("k"), // deleted
|
||||
types.Uint(11), types.String("l"), // added
|
||||
types.Uint(12), types.String("m2"), // changed
|
||||
//types.Uint(13), types.String("n"), // deleted
|
||||
types.Uint(14), types.String("o"), // added
|
||||
types.Uint(15), types.String("p2"), // changed
|
||||
//types.Uint(16), types.String("q"), // deleted
|
||||
types.Uint(17), types.String("r"), // added
|
||||
types.Uint(18), types.String("s2"), // changed
|
||||
//types.Uint(19), types.String("t"), // deleted
|
||||
types.Uint(20), types.String("u"), // added
|
||||
types.Uint(21), types.String("v2"), // changed
|
||||
//types.Uint(22), types.String("w"), // deleted
|
||||
types.Uint(23), types.String("x"), // added
|
||||
types.Uint(24), types.String("y2"), // changed
|
||||
//types.Uint(25), types.String("z"), // deleted
|
||||
}
|
||||
m2, err := types.NewMap(ctx, vrw, vals...)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
createdStarted func(ctx context.Context, m1, m2 types.Map) *AsyncDiffer
|
||||
expectedStats map[types.DiffChangeType]uint64
|
||||
}{
|
||||
{
|
||||
name: "iter all",
|
||||
createdStarted: func(ctx context.Context, m1, m2 types.Map) *AsyncDiffer {
|
||||
ad := NewAsyncDiffer(4)
|
||||
ad.Start(ctx, m1, m2)
|
||||
return ad
|
||||
},
|
||||
expectedStats: map[types.DiffChangeType]uint64{
|
||||
types.DiffChangeModified: 5,
|
||||
types.DiffChangeAdded: 8,
|
||||
types.DiffChangeRemoved: 9,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "iter range starting with nil",
|
||||
createdStarted: func(ctx context.Context, m1, m2 types.Map) *AsyncDiffer {
|
||||
ad := NewAsyncDiffer(4)
|
||||
ad.StartWithRange(ctx, m1, m2, nil, func(ctx context.Context, value types.Value) (bool, bool, error) {
|
||||
return true, false, nil
|
||||
})
|
||||
return ad
|
||||
},
|
||||
expectedStats: map[types.DiffChangeType]uint64{
|
||||
types.DiffChangeModified: 5,
|
||||
types.DiffChangeAdded: 8,
|
||||
types.DiffChangeRemoved: 9,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "iter range staring with Null Value",
|
||||
createdStarted: func(ctx context.Context, m1, m2 types.Map) *AsyncDiffer {
|
||||
ad := NewAsyncDiffer(4)
|
||||
ad.StartWithRange(ctx, m1, m2, types.NullValue, func(ctx context.Context, value types.Value) (bool, bool, error) {
|
||||
return true, false, nil
|
||||
})
|
||||
return ad
|
||||
},
|
||||
expectedStats: map[types.DiffChangeType]uint64{
|
||||
types.DiffChangeModified: 5,
|
||||
types.DiffChangeAdded: 8,
|
||||
types.DiffChangeRemoved: 9,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "iter range less than 17",
|
||||
createdStarted: func(ctx context.Context, m1, m2 types.Map) *AsyncDiffer {
|
||||
ad := NewAsyncDiffer(4)
|
||||
end := types.Uint(27)
|
||||
ad.StartWithRange(ctx, m1, m2, types.NullValue, func(ctx context.Context, value types.Value) (bool, bool, error) {
|
||||
valid, err := value.Less(ctx, vrw.Format(), end)
|
||||
return valid, false, err
|
||||
})
|
||||
return ad
|
||||
},
|
||||
expectedStats: map[types.DiffChangeType]uint64{
|
||||
types.DiffChangeModified: 5,
|
||||
types.DiffChangeAdded: 8,
|
||||
types.DiffChangeRemoved: 9,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "iter range less than 15",
|
||||
createdStarted: func(ctx context.Context, m1, m2 types.Map) *AsyncDiffer {
|
||||
ad := NewAsyncDiffer(4)
|
||||
end := types.Uint(15)
|
||||
ad.StartWithRange(ctx, m1, m2, types.NullValue, func(ctx context.Context, value types.Value) (bool, bool, error) {
|
||||
valid, err := value.Less(ctx, vrw.Format(), end)
|
||||
return valid, false, err
|
||||
})
|
||||
return ad
|
||||
},
|
||||
expectedStats: map[types.DiffChangeType]uint64{
|
||||
types.DiffChangeModified: 1,
|
||||
types.DiffChangeAdded: 5,
|
||||
types.DiffChangeRemoved: 5,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "iter range 10 < 15",
|
||||
createdStarted: func(ctx context.Context, m1, m2 types.Map) *AsyncDiffer {
|
||||
ad := NewAsyncDiffer(4)
|
||||
start := types.Uint(10)
|
||||
end := types.Uint(15)
|
||||
ad.StartWithRange(ctx, m1, m2, start, func(ctx context.Context, value types.Value) (bool, bool, error) {
|
||||
valid, err := value.Less(ctx, vrw.Format(), end)
|
||||
return valid, false, err
|
||||
})
|
||||
return ad
|
||||
},
|
||||
expectedStats: map[types.DiffChangeType]uint64{
|
||||
types.DiffChangeModified: 1,
|
||||
types.DiffChangeAdded: 2,
|
||||
types.DiffChangeRemoved: 2,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ad := test.createdStarted(ctx, m1, m2)
|
||||
err := readAll(ad)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expectedStats, ad.diffStats)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("can close without reading all", func(t *testing.T) {
|
||||
ad := NewAsyncDiffer(1)
|
||||
ad.Start(ctx, m1, m2)
|
||||
res, more, err := ad.GetDiffs(1, -1)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, more)
|
||||
assert.Len(t, res, 1)
|
||||
err = ad.Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("can filter based on change type", func(t *testing.T) {
|
||||
ad := NewAsyncDiffer(20)
|
||||
ad.Start(ctx, m1, m2)
|
||||
res, more, err := ad.GetDiffs(10, -1)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, more)
|
||||
assert.Len(t, res, 10)
|
||||
err = ad.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
ad = NewAsyncDiffer(20)
|
||||
ad.Start(ctx, m1, m2)
|
||||
res, more, err = ad.GetDiffsWithFilter(10, 20*time.Second, types.DiffChangeModified)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, more)
|
||||
assert.Len(t, res, 5)
|
||||
err = ad.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
ad = NewAsyncDiffer(20)
|
||||
ad.Start(ctx, m1, m2)
|
||||
res, more, err = ad.GetDiffsWithFilter(6, -1, types.DiffChangeAdded)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, more)
|
||||
assert.Len(t, res, 6)
|
||||
err = ad.Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
k1Row1Vals := []types.Value{c1Tag, types.Uint(3), c2Tag, types.String("d")}
|
||||
k1Vals, err := getKeylessRow(ctx, k1Row1Vals)
|
||||
assert.NoError(t, err)
|
||||
k1, err := types.NewMap(ctx, vrw, k1Vals...)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Delete one row, add two rows
|
||||
k2Row1Vals := []types.Value{c1Tag, types.Uint(4), c2Tag, types.String("d")}
|
||||
k2Vals1, err := getKeylessRow(ctx, k2Row1Vals)
|
||||
assert.NoError(t, err)
|
||||
k2Row2Vals := []types.Value{c1Tag, types.Uint(1), c2Tag, types.String("e")}
|
||||
k2Vals2, err := getKeylessRow(ctx, k2Row2Vals)
|
||||
assert.NoError(t, err)
|
||||
k2Vals := append(k2Vals1, k2Vals2...)
|
||||
k2, err := types.NewMap(ctx, vrw, k2Vals...)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("can diff and filter keyless tables", func(t *testing.T) {
|
||||
kd := &keylessDiffer{AsyncDiffer: NewAsyncDiffer(20)}
|
||||
kd.Start(ctx, k1, k2)
|
||||
res, more, err := kd.GetDiffs(10, 20*time.Second)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, more)
|
||||
assert.Len(t, res, 3)
|
||||
err = kd.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
kd = &keylessDiffer{AsyncDiffer: NewAsyncDiffer(20)}
|
||||
kd.Start(ctx, k1, k2)
|
||||
res, more, err = kd.GetDiffsWithFilter(10, 20*time.Second, types.DiffChangeModified)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, more)
|
||||
assert.Len(t, res, 0)
|
||||
err = kd.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
kd = &keylessDiffer{AsyncDiffer: NewAsyncDiffer(20)}
|
||||
kd.Start(ctx, k1, k2)
|
||||
res, more, err = kd.GetDiffsWithFilter(6, -1, types.DiffChangeAdded)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, more)
|
||||
assert.Len(t, res, 2)
|
||||
err = kd.Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func readAll(ad *AsyncDiffer) error {
|
||||
for {
|
||||
_, more, err := ad.GetDiffs(10, -1)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !more {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var c1Tag = types.Uint(1)
|
||||
var c2Tag = types.Uint(2)
|
||||
var cardTag = types.Uint(schema.KeylessRowCardinalityTag)
|
||||
var rowIdTag = types.Uint(schema.KeylessRowIdTag)
|
||||
|
||||
func getKeylessRow(ctx context.Context, vals []types.Value) ([]types.Value, error) {
|
||||
nbf, err := types.GetFormatForVersionString(constants.FormatDefaultString)
|
||||
if err != nil {
|
||||
return []types.Value{}, err
|
||||
}
|
||||
|
||||
id1, err := types.UUIDHashedFromValues(nbf, vals...)
|
||||
if err != nil {
|
||||
return []types.Value{}, err
|
||||
}
|
||||
|
||||
prefix := []types.Value{
|
||||
cardTag,
|
||||
types.Uint(1),
|
||||
}
|
||||
vals = append(prefix, vals...)
|
||||
|
||||
return []types.Value{
|
||||
mustTuple(rowIdTag, id1),
|
||||
mustTuple(vals...),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func mustTuple(vals ...types.Value) types.Tuple {
|
||||
tup, err := types.NewTuple(types.Format_Default, vals...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return tup
|
||||
}
|
||||
@@ -19,13 +19,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/dolthub/dolt/go/cmd/dolt/errhand"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/row"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/dolthub/dolt/go/store/diff"
|
||||
"github.com/dolthub/dolt/go/store/prolly"
|
||||
"github.com/dolthub/dolt/go/store/prolly/tree"
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
@@ -39,7 +36,6 @@ type DiffStatProgress struct {
|
||||
}
|
||||
|
||||
type prollyReporter func(ctx context.Context, vMapping val.OrdinalMapping, fromD, toD *val.TupleDesc, change tree.Diff, ch chan<- DiffStatProgress) error
|
||||
type nomsReporter func(ctx context.Context, change *diff.Difference, fromSch, toSch schema.Schema, ch chan<- DiffStatProgress) error
|
||||
|
||||
// Stat reports a stat of diff changes between two values
|
||||
// todo: make package private once dolthub is migrated
|
||||
@@ -168,38 +164,6 @@ func diffProllyTrees(ctx context.Context, ch chan DiffStatProgress, keyless bool
|
||||
return nil
|
||||
}
|
||||
|
||||
func statWithReporter(ctx context.Context, ch chan DiffStatProgress, from, to types.Map, rpr nomsReporter, fromSch, toSch schema.Schema) (err error) {
|
||||
ad := NewAsyncDiffer(1024)
|
||||
ad.Start(ctx, from, to)
|
||||
defer func() {
|
||||
if cerr := ad.Close(); cerr != nil && err == nil {
|
||||
err = cerr
|
||||
}
|
||||
}()
|
||||
|
||||
var more bool
|
||||
var diffs []*diff.Difference
|
||||
for {
|
||||
diffs, more, err = ad.GetDiffs(100, time.Millisecond)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, df := range diffs {
|
||||
err = rpr(ctx, df, fromSch, toSch, ch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if !more {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func reportPkChanges(ctx context.Context, vMapping val.OrdinalMapping, fromD, toD *val.TupleDesc, change tree.Diff, ch chan<- DiffStatProgress) error {
|
||||
var stat DiffStatProgress
|
||||
switch change.Type {
|
||||
@@ -280,66 +244,3 @@ func prollyCountCellDiff(ctx context.Context, mapping val.OrdinalMapping, fromD,
|
||||
changed += newCols
|
||||
return changed
|
||||
}
|
||||
|
||||
func reportNomsPkChanges(ctx context.Context, change *diff.Difference, fromSch, toSch schema.Schema, ch chan<- DiffStatProgress) error {
|
||||
var stat DiffStatProgress
|
||||
switch change.ChangeType {
|
||||
case types.DiffChangeAdded:
|
||||
stat = DiffStatProgress{Adds: 1}
|
||||
case types.DiffChangeRemoved:
|
||||
stat = DiffStatProgress{Removes: 1}
|
||||
case types.DiffChangeModified:
|
||||
oldTuple := change.OldValue.(types.Tuple)
|
||||
newTuple := change.NewValue.(types.Tuple)
|
||||
cellChanges, err := row.CountCellDiffs(oldTuple, newTuple, fromSch, toSch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stat = DiffStatProgress{Changes: 1, CellChanges: cellChanges}
|
||||
default:
|
||||
return errors.New("unknown change type")
|
||||
}
|
||||
select {
|
||||
case ch <- stat:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func reportNomsKeylessChanges(ctx context.Context, change *diff.Difference, fromSch, toSch schema.Schema, ch chan<- DiffStatProgress) error {
|
||||
var oldCard uint64
|
||||
if change.OldValue != nil {
|
||||
v, err := change.OldValue.(types.Tuple).Get(row.KeylessCardinalityValIdx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
oldCard = uint64(v.(types.Uint))
|
||||
}
|
||||
|
||||
var newCard uint64
|
||||
if change.NewValue != nil {
|
||||
v, err := change.NewValue.(types.Tuple).Get(row.KeylessCardinalityValIdx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newCard = uint64(v.(types.Uint))
|
||||
}
|
||||
|
||||
var stat DiffStatProgress
|
||||
delta := int64(newCard) - int64(oldCard)
|
||||
if delta > 0 {
|
||||
stat = DiffStatProgress{Adds: uint64(delta)}
|
||||
} else if delta < 0 {
|
||||
stat = DiffStatProgress{Removes: uint64(-delta)}
|
||||
} else {
|
||||
return fmt.Errorf("diff with delta = 0 for key: %s", change.KeyValue.HumanReadableString())
|
||||
}
|
||||
|
||||
select {
|
||||
case ch <- stat:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,31 +118,6 @@ func mapQuerySchemaToTargetSchema(query, target sql.Schema) (mapping []int, err
|
||||
return
|
||||
}
|
||||
|
||||
func mapToAndFromColumns(query sql.Schema) (mapping []int, err error) {
|
||||
last := query[len(query)-1]
|
||||
if last.Name != "diff_type" {
|
||||
return nil, errors.New("expected last diff column to be 'diff_type'")
|
||||
}
|
||||
query = query[:len(query)-1]
|
||||
|
||||
mapping = make([]int, len(query))
|
||||
for i, col := range query {
|
||||
if strings.HasPrefix(col.Name, fromPrefix) {
|
||||
// map "from_..." column to "to_..." column
|
||||
base := col.Name[len(fromPrefix):]
|
||||
mapping[i] = query.IndexOfColName(toPrefix + base)
|
||||
} else if strings.HasPrefix(col.Name, toPrefix) {
|
||||
// map "to_..." column to "from_..." column
|
||||
base := col.Name[len(toPrefix):]
|
||||
mapping[i] = query.IndexOfColName(fromPrefix + base)
|
||||
} else {
|
||||
return nil, errors.New("expected column prefix of 'to_' or 'from_' (" + col.Name + ")")
|
||||
}
|
||||
}
|
||||
// |mapping| will contain -1 for unmapped columns
|
||||
return
|
||||
}
|
||||
|
||||
func (ds DiffSplitter) SplitDiffResultRow(ctx *sql.Context, row sql.Row) (from, to RowDiff, err error) {
|
||||
from = RowDiff{ColDiffs: make([]ChangeType, len(ds.targetSch))}
|
||||
to = RowDiff{ColDiffs: make([]ChangeType, len(ds.targetSch))}
|
||||
|
||||
@@ -81,15 +81,6 @@ func NewCommit(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeStore
|
||||
return &Commit{vrw, ns, parents, commit}, nil
|
||||
}
|
||||
|
||||
// NewCommitFromValue generates a new Commit object that wraps a supplied types.Value.
|
||||
func NewCommitFromValue(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeStore, value types.Value) (*Commit, error) {
|
||||
commit, err := datas.CommitFromValue(vrw.Format(), value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewCommit(ctx, vrw, ns, commit)
|
||||
}
|
||||
|
||||
// HashOf returns the hash of the commit
|
||||
func (c *Commit) HashOf() (hash.Hash, error) {
|
||||
return c.dCommit.Addr(), nil
|
||||
|
||||
@@ -27,7 +27,6 @@ var ErrInvTableName = errors.New("not a valid table name")
|
||||
var ErrInvHash = errors.New("not a valid hash")
|
||||
var ErrInvalidAncestorSpec = errors.New("invalid ancestor spec")
|
||||
var ErrInvalidBranchOrHash = errors.New("string is not a valid branch or hash")
|
||||
var ErrInvalidHash = errors.New("string is not a valid hash")
|
||||
|
||||
var ErrFoundHashNotACommit = errors.New("the value retrieved for this hash is not a commit")
|
||||
var ErrHashNotFound = errors.New("could not find a value for this hash")
|
||||
@@ -39,7 +38,6 @@ var ErrWorkspaceNotFound = errors.New("workspace not found")
|
||||
var ErrTableNotFound = errors.New("table not found")
|
||||
var ErrTableExists = errors.New("table already exists")
|
||||
var ErrAlreadyOnBranch = errors.New("Already on branch")
|
||||
var ErrAlreadyOnWorkspace = errors.New("Already on workspace")
|
||||
|
||||
var ErrNomsIO = errors.New("error reading from or writing to noms")
|
||||
|
||||
@@ -109,37 +107,6 @@ func (rt RootType) String() string {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
type RootTypeSet map[RootType]struct{}
|
||||
|
||||
func NewRootTypeSet(rts ...RootType) RootTypeSet {
|
||||
mp := make(map[RootType]struct{})
|
||||
|
||||
for _, rt := range rts {
|
||||
mp[rt] = struct{}{}
|
||||
}
|
||||
|
||||
return RootTypeSet(mp)
|
||||
}
|
||||
|
||||
func (rts RootTypeSet) Contains(rt RootType) bool {
|
||||
_, ok := rts[rt]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (rts RootTypeSet) First(rtList []RootType) RootType {
|
||||
for _, rt := range rtList {
|
||||
if _, ok := rts[rt]; ok {
|
||||
return rt
|
||||
}
|
||||
}
|
||||
|
||||
return InvalidRoot
|
||||
}
|
||||
|
||||
func (rts RootTypeSet) IsEmpty() bool {
|
||||
return len(rts) == 0
|
||||
}
|
||||
|
||||
type RootValueUnreadable struct {
|
||||
RootType RootType
|
||||
Cause error
|
||||
|
||||
@@ -51,7 +51,8 @@ func getNonlocalTablesRef(_ context.Context, valDesc *val.TupleDesc, valTuple va
|
||||
return result
|
||||
}
|
||||
|
||||
func GetGlobalTablePatterns(ctx context.Context, root RootValue, schema string, cb func(string)) error {
|
||||
// GetNonlocalTablePatterns invokes |cb| once for each table name pattern in dolt_nonlocal_tables on |root| and |schema|.
|
||||
func GetNonlocalTablePatterns(ctx context.Context, root RootValue, schema string, cb func(string)) error {
|
||||
table_name := TableName{Name: NonlocalTableName, Schema: schema}
|
||||
table, found, err := root.GetTable(ctx, table_name)
|
||||
if err != nil {
|
||||
|
||||
@@ -472,6 +472,10 @@ func encodeTableNameForSerialization(name TableName) string {
|
||||
// decodeTableNameFromSerialization decodes a table name from a serialized string. See notes on serialization in
|
||||
// |encodeTableNameForSerialization|
|
||||
func decodeTableNameFromSerialization(encodedName string) (TableName, bool) {
|
||||
if len(encodedName) == 0 {
|
||||
return TableName{}, false
|
||||
}
|
||||
|
||||
if encodedName[0] != 0 {
|
||||
return TableName{Name: encodedName}, true
|
||||
} else if len(encodedName) >= 4 { // 2 null bytes plus at least one char for schema and table name
|
||||
|
||||
@@ -43,6 +43,35 @@ func MatchTablePattern(pattern string, table string) (bool, error) {
|
||||
return re.MatchString(table), nil
|
||||
}
|
||||
|
||||
// CompiledTablePatterns holds compiled table name patterns for reuse when matching many names without recompiling.
|
||||
type CompiledTablePatterns []*regexp.Regexp
|
||||
|
||||
// CompileTablePatterns compiles each of |patterns| once and returns them for use with TableMatchesAny. Returns (nil, nil) when |patterns| is empty.
|
||||
func CompileTablePatterns(patterns []string) (CompiledTablePatterns, error) {
|
||||
if len(patterns) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
compiled := make(CompiledTablePatterns, 0, len(patterns))
|
||||
for _, p := range patterns {
|
||||
re, err := compilePattern(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
compiled = append(compiled, re)
|
||||
}
|
||||
return compiled, nil
|
||||
}
|
||||
|
||||
// TableMatchesAny reports whether |table| matches any of the patterns in |c|.
|
||||
func (c CompiledTablePatterns) TableMatchesAny(table string) bool {
|
||||
for _, re := range c {
|
||||
if re.MatchString(table) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetMatchingTables returns all tables that match a pattern
|
||||
func GetMatchingTables(ctx *sql.Context, root RootValue, schemaName string, pattern string) (results []string, err error) {
|
||||
// If the pattern doesn't contain any special characters, look up that name.
|
||||
|
||||
@@ -75,6 +75,8 @@ type RebaseState struct {
|
||||
// rebasingStarted is true once the rebase plan has been started to execute. Once rebasingStarted is true, the
|
||||
// value in lastAttemptedStep has been initialized and is valid to read.
|
||||
rebasingStarted bool
|
||||
// skipVerification indicates whether test validation should be skipped during rebase operations.
|
||||
skipVerification bool
|
||||
}
|
||||
|
||||
// Branch returns the name of the branch being actively rebased. This is the branch that will be updated to point
|
||||
@@ -120,6 +122,10 @@ func (rs RebaseState) WithRebasingStarted(rebasingStarted bool) *RebaseState {
|
||||
return &rs
|
||||
}
|
||||
|
||||
func (rs RebaseState) SkipVerification() bool {
|
||||
return rs.skipVerification
|
||||
}
|
||||
|
||||
type MergeState struct {
|
||||
// the source commit
|
||||
commit *Commit
|
||||
@@ -322,13 +328,14 @@ func (ws WorkingSet) StartMerge(commit *Commit, commitSpecStr string) *WorkingSe
|
||||
// the branch that is being rebased, and |previousRoot| is root value of the branch being rebased. The HEAD and STAGED
|
||||
// root values of the branch being rebased must match |previousRoot|; WORKING may be a different root value, but ONLY
|
||||
// if it contains only ignored tables.
|
||||
func (ws WorkingSet) StartRebase(ctx *sql.Context, ontoCommit *Commit, branch string, previousRoot RootValue, commitBecomesEmptyHandling EmptyCommitHandling, emptyCommitHandling EmptyCommitHandling) (*WorkingSet, error) {
|
||||
func (ws WorkingSet) StartRebase(ctx *sql.Context, ontoCommit *Commit, branch string, previousRoot RootValue, commitBecomesEmptyHandling EmptyCommitHandling, emptyCommitHandling EmptyCommitHandling, skipVerification bool) (*WorkingSet, error) {
|
||||
ws.rebaseState = &RebaseState{
|
||||
ontoCommit: ontoCommit,
|
||||
preRebaseWorking: previousRoot,
|
||||
branch: branch,
|
||||
commitBecomesEmptyHandling: commitBecomesEmptyHandling,
|
||||
emptyCommitHandling: emptyCommitHandling,
|
||||
skipVerification: skipVerification,
|
||||
}
|
||||
|
||||
ontoRoot, err := ontoCommit.GetRootValue(ctx)
|
||||
@@ -549,6 +556,7 @@ func newWorkingSet(ctx context.Context, name string, vrw types.ValueReadWriter,
|
||||
emptyCommitHandling: EmptyCommitHandling(dsws.RebaseState.EmptyCommitHandling(ctx)),
|
||||
lastAttemptedStep: dsws.RebaseState.LastAttemptedStep(ctx),
|
||||
rebasingStarted: dsws.RebaseState.RebasingStarted(ctx),
|
||||
skipVerification: dsws.RebaseState.SkipVerification(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -646,7 +654,7 @@ func (ws *WorkingSet) writeValues(ctx context.Context, db *DoltDB, meta *datas.W
|
||||
|
||||
rebaseState = datas.NewRebaseState(preRebaseWorking.TargetHash(), dCommit.Addr(), ws.rebaseState.branch,
|
||||
uint8(ws.rebaseState.commitBecomesEmptyHandling), uint8(ws.rebaseState.emptyCommitHandling),
|
||||
ws.rebaseState.lastAttemptedStep, ws.rebaseState.rebasingStarted)
|
||||
ws.rebaseState.lastAttemptedStep, ws.rebaseState.rebasingStarted, ws.rebaseState.skipVerification)
|
||||
}
|
||||
|
||||
return &datas.WorkingSetSpec{
|
||||
|
||||
+1
-1
@@ -216,7 +216,7 @@ func CleanOldWorkingSet(
|
||||
}
|
||||
|
||||
// we also have to do a clean, because we the ResetHard won't touch any new tables (tables only in the working set)
|
||||
newRoots, err := CleanUntracked(ctx, resetRoots, []string{}, false, true)
|
||||
newRoots, err := CleanUntracked(ctx, resetRoots, []string{}, false, true, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
+108
-8
@@ -15,8 +15,12 @@
|
||||
package actions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
gms "github.com/dolthub/go-mysql-server"
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/diff"
|
||||
@@ -25,14 +29,42 @@ import (
|
||||
)
|
||||
|
||||
type CommitStagedProps struct {
|
||||
Message string
|
||||
Date time.Time
|
||||
AllowEmpty bool
|
||||
SkipEmpty bool
|
||||
Amend bool
|
||||
Force bool
|
||||
Name string
|
||||
Email string
|
||||
Message string
|
||||
Date time.Time
|
||||
AllowEmpty bool
|
||||
SkipEmpty bool
|
||||
Amend bool
|
||||
Force bool
|
||||
Name string
|
||||
Email string
|
||||
SkipVerification bool
|
||||
}
|
||||
|
||||
const (
|
||||
// System variable name, defined here to avoid circular imports
|
||||
DoltCommitVerificationGroups = "dolt_commit_verification_groups"
|
||||
)
|
||||
|
||||
// GetCommitRunTestGroups returns the test groups to run for commit operations
|
||||
// Returns empty slice if no tests should be run, ["*"] if all tests should be run,
|
||||
// or specific group names if only those groups should be run
|
||||
func GetCommitRunTestGroups() []string {
|
||||
_, val, ok := sql.SystemVariables.GetGlobal(DoltCommitVerificationGroups)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if stringVal, ok := val.(string); ok && stringVal != "" {
|
||||
if stringVal == "*" {
|
||||
return []string{"*"}
|
||||
}
|
||||
// Split by comma and trim whitespace
|
||||
groups := strings.Split(stringVal, ",")
|
||||
for i, group := range groups {
|
||||
groups[i] = strings.TrimSpace(group)
|
||||
}
|
||||
return groups
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCommitStaged returns a new pending commit with the roots and commit properties given.
|
||||
@@ -114,6 +146,16 @@ func GetCommitStaged(
|
||||
}
|
||||
}
|
||||
|
||||
if !props.SkipVerification {
|
||||
testGroups := GetCommitRunTestGroups()
|
||||
if len(testGroups) > 0 {
|
||||
err := runCommitVerification(ctx, testGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
meta, err := datas.NewCommitMetaWithUserTS(props.Name, props.Email, props.Message, props.Date)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -121,3 +163,61 @@ func GetCommitStaged(
|
||||
|
||||
return db.NewPendingCommit(ctx, roots, mergeParents, props.Amend, meta)
|
||||
}
|
||||
|
||||
func runCommitVerification(ctx *sql.Context, testGroups []string) error {
|
||||
type sessionInterface interface {
|
||||
sql.Session
|
||||
GenericProvider() sql.MutableDatabaseProvider
|
||||
}
|
||||
|
||||
session, ok := ctx.Session.(sessionInterface)
|
||||
if !ok {
|
||||
return fmt.Errorf("session does not provide database provider interface")
|
||||
}
|
||||
|
||||
provider := session.GenericProvider()
|
||||
engine := gms.NewDefault(provider)
|
||||
|
||||
return runTestsUsingDtablefunctions(ctx, engine, testGroups)
|
||||
}
|
||||
|
||||
// runTestsUsingDtablefunctions runs tests using the dtablefunctions package against the staged root
|
||||
func runTestsUsingDtablefunctions(ctx *sql.Context, engine *gms.Engine, testGroups []string) error {
|
||||
if len(testGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var allFailures []string
|
||||
|
||||
for _, group := range testGroups {
|
||||
query := fmt.Sprintf("SELECT * FROM dolt_test_run('%s')", group)
|
||||
_, iter, _, err := engine.Query(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to run dolt_test_run for group %s: %w", group, err)
|
||||
}
|
||||
|
||||
for {
|
||||
row, rErr := iter.Next(ctx)
|
||||
if rErr == io.EOF {
|
||||
break
|
||||
}
|
||||
if rErr != nil {
|
||||
return fmt.Errorf("error reading test results: %w", rErr)
|
||||
}
|
||||
|
||||
// Extract status (column 3)
|
||||
status := fmt.Sprintf("%v", row[3])
|
||||
if status != "PASS" {
|
||||
testName := fmt.Sprintf("%v", row[0])
|
||||
message := fmt.Sprintf("%v", row[4])
|
||||
allFailures = append(allFailures, fmt.Sprintf("%s (%s)", testName, message))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(allFailures) > 0 {
|
||||
return fmt.Errorf("commit verification failed: %s", strings.Join(allFailures, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
+42
-28
@@ -270,60 +270,74 @@ func getUnionedTables(ctx context.Context, tables []doltdb.TableName, stagedRoot
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// CleanUntracked deletes untracked tables from the working root.
|
||||
// Evaluates untracked tables as: all working tables - all staged tables.
|
||||
func CleanUntracked(ctx *sql.Context, roots doltdb.Roots, tables []string, dryrun bool, force bool) (doltdb.Roots, error) {
|
||||
// CleanUntracked deletes from the working root the tables that are untracked (in working but not in staged/head). If
|
||||
// |tables| is non-empty it uses only those names as candidates; otherwise it uses all working tables. Tables matching
|
||||
// dolt_nonlocal_tables are always excluded. When |respectIgnoreRules| is true, tables matching dolt_ignore are also excluded. Does nothing when |dryrun| is true.
|
||||
func CleanUntracked(ctx *sql.Context, roots doltdb.Roots, tables []string, dryrun bool, force bool, respectIgnoreRules bool) (doltdb.Roots, error) {
|
||||
untrackedTables := make(map[doltdb.TableName]struct{})
|
||||
for _, name := range tables {
|
||||
resolvedName, tblExists, err := resolve.TableName(ctx, roots.Working, name)
|
||||
if err != nil {
|
||||
return doltdb.Roots{}, err
|
||||
}
|
||||
if !tblExists {
|
||||
return doltdb.Roots{}, fmt.Errorf("%w: '%s'", doltdb.ErrTableNotFound, name)
|
||||
}
|
||||
untrackedTables[resolvedName] = struct{}{}
|
||||
}
|
||||
|
||||
var err error
|
||||
if len(tables) == 0 {
|
||||
allTableNames, err := roots.Working.GetAllTableNames(ctx, true)
|
||||
if err != nil {
|
||||
return doltdb.Roots{}, nil
|
||||
return doltdb.Roots{}, err
|
||||
}
|
||||
for _, tableName := range allTableNames {
|
||||
untrackedTables[tableName] = struct{}{}
|
||||
}
|
||||
} else {
|
||||
for i := range tables {
|
||||
name := tables[i]
|
||||
resolvedName, tblExists, err := resolve.TableName(ctx, roots.Working, name)
|
||||
var candidates []doltdb.TableName
|
||||
if respectIgnoreRules {
|
||||
candidates, err = doltdb.ExcludeIgnoredTables(ctx, roots, allTableNames)
|
||||
if err != nil {
|
||||
return doltdb.Roots{}, err
|
||||
}
|
||||
if !tblExists {
|
||||
return doltdb.Roots{}, fmt.Errorf("%w: '%s'", doltdb.ErrTableNotFound, name)
|
||||
} else {
|
||||
candidates = allTableNames
|
||||
}
|
||||
var nonlocalPatterns []string
|
||||
err = doltdb.GetNonlocalTablePatterns(ctx, roots.Working, doltdb.DefaultSchemaName, func(p string) {
|
||||
nonlocalPatterns = append(nonlocalPatterns, p)
|
||||
})
|
||||
if err != nil {
|
||||
return doltdb.Roots{}, err
|
||||
}
|
||||
compiled, err := doltdb.CompileTablePatterns(nonlocalPatterns)
|
||||
if err != nil {
|
||||
return doltdb.Roots{}, err
|
||||
}
|
||||
for _, tableName := range candidates {
|
||||
if compiled.TableMatchesAny(tableName.Name) {
|
||||
continue
|
||||
}
|
||||
untrackedTables[resolvedName] = struct{}{}
|
||||
untrackedTables[tableName] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// untracked tables = working tables - staged tables
|
||||
headTblNames := GetAllTableNames(ctx, roots.Staged)
|
||||
if err != nil {
|
||||
return doltdb.Roots{}, err
|
||||
}
|
||||
|
||||
for _, name := range headTblNames {
|
||||
delete(untrackedTables, name)
|
||||
}
|
||||
|
||||
newRoot := roots.Working
|
||||
var toDelete []doltdb.TableName
|
||||
toDelete := make([]doltdb.TableName, 0, len(untrackedTables))
|
||||
for t := range untrackedTables {
|
||||
toDelete = append(toDelete, t)
|
||||
}
|
||||
|
||||
newRoot, err = newRoot.RemoveTables(ctx, force, force, toDelete...)
|
||||
if err != nil {
|
||||
return doltdb.Roots{}, fmt.Errorf("failed to remove tables; %w", err)
|
||||
}
|
||||
|
||||
if dryrun {
|
||||
return roots, nil
|
||||
}
|
||||
roots.Working = newRoot
|
||||
|
||||
newRoot, err := roots.Working.RemoveTables(ctx, force, force, toDelete...)
|
||||
if err != nil {
|
||||
return doltdb.Roots{}, fmt.Errorf("failed to remove tables; %w", err)
|
||||
}
|
||||
roots.Working = newRoot
|
||||
return roots, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,447 +0,0 @@
|
||||
// Copyright 2025 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/shopspring/decimal"
|
||||
"golang.org/x/exp/constraints"
|
||||
|
||||
"github.com/dolthub/dolt/go/store/val"
|
||||
)
|
||||
|
||||
const (
|
||||
AssertionExpectedRows = "expected_rows"
|
||||
AssertionExpectedColumns = "expected_columns"
|
||||
AssertionExpectedSingleValue = "expected_single_value"
|
||||
)
|
||||
|
||||
// AssertData parses an assertion, comparison, and value, then returns the status of the test.
|
||||
// Valid comparison are: "==", "!=", "<", ">", "<=", and ">=".
|
||||
// testPassed indicates whether the test was successful or not.
|
||||
// message is a string used to indicate test failures, and will not halt the overall process.
|
||||
// message will be empty if the test passed.
|
||||
// err indicates runtime failures and will stop dolt_test_run from proceeding.
|
||||
func AssertData(sqlCtx *sql.Context, assertion string, comparison string, value *string, queryResult sql.RowIter) (testPassed bool, message string, err error) {
|
||||
switch assertion {
|
||||
case AssertionExpectedRows:
|
||||
message, err = expectRows(sqlCtx, comparison, value, queryResult)
|
||||
case AssertionExpectedColumns:
|
||||
message, err = expectColumns(sqlCtx, comparison, value, queryResult)
|
||||
case AssertionExpectedSingleValue:
|
||||
message, err = expectSingleValue(sqlCtx, comparison, value, queryResult)
|
||||
default:
|
||||
return false, fmt.Sprintf("%s is not a valid assertion type", assertion), nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
} else if message != "" {
|
||||
return false, message, nil
|
||||
}
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
func expectSingleValue(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
|
||||
row, err := queryResult.Next(sqlCtx)
|
||||
if err == io.EOF {
|
||||
return fmt.Sprintf("expected_single_value expects exactly one cell. Received 0 rows"), nil
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(row) != 1 {
|
||||
return fmt.Sprintf("expected_single_value expects exactly one cell. Received multiple columns"), nil
|
||||
}
|
||||
_, err = queryResult.Next(sqlCtx)
|
||||
if err == nil { //If multiple rows were given, we should error out
|
||||
return fmt.Sprintf("expected_single_value expects exactly one cell. Received multiple rows"), nil
|
||||
} else if err != io.EOF { // "True" error, so we should quit out
|
||||
return "", err
|
||||
}
|
||||
|
||||
if value == nil { // If we're expecting a null value, we don't need to type switch
|
||||
return compareNullValue(comparison, row[0], AssertionExpectedSingleValue), nil
|
||||
}
|
||||
|
||||
// Check if the expected value is a boolean string, and if so, coerce the actual value to boolean, with the exception
|
||||
// of "0" and "1", which are valid integers and are covered below.
|
||||
if *value != "0" && *value != "1" {
|
||||
if expectedBool, err := strconv.ParseBool(*value); err == nil {
|
||||
actualBool, boolErr := getInterfaceAsBool(row[0])
|
||||
if boolErr != nil {
|
||||
return fmt.Sprintf("Could not convert value to boolean: %v", boolErr), nil
|
||||
}
|
||||
return compareBooleans(comparison, expectedBool, actualBool, AssertionExpectedSingleValue), nil
|
||||
}
|
||||
}
|
||||
|
||||
switch actualValue := row[0].(type) {
|
||||
case int8:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, int8(expectedInt), actualValue, AssertionExpectedSingleValue), nil
|
||||
case int16:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, int16(expectedInt), actualValue, AssertionExpectedSingleValue), nil
|
||||
case int32:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, int32(expectedInt), actualValue, AssertionExpectedSingleValue), nil
|
||||
case int64:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, expectedInt, actualValue, AssertionExpectedSingleValue), nil
|
||||
case int:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, int(expectedInt), actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint8:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, uint8(expectedUint), actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint16:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, uint16(expectedUint), actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint32:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, uint32(expectedUint), actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint64:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, expectedUint, actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, uint(expectedUint), actualValue, AssertionExpectedSingleValue), nil
|
||||
case float64:
|
||||
expectedFloat, err := strconv.ParseFloat(*value, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, expectedFloat, actualValue, AssertionExpectedSingleValue), nil
|
||||
case float32:
|
||||
expectedFloat, err := strconv.ParseFloat(*value, 32)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, float32(expectedFloat), actualValue, AssertionExpectedSingleValue), nil
|
||||
case decimal.Decimal:
|
||||
expectedDecimal, err := decimal.NewFromString(*value)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non decimal value '%s', with %s", *value, actualValue), nil
|
||||
}
|
||||
return compareDecimals(comparison, expectedDecimal, actualValue, AssertionExpectedSingleValue), nil
|
||||
case time.Time:
|
||||
expectedTime, format, err := parseTestsDate(*value)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("%s does not appear to be a valid date", *value), nil
|
||||
}
|
||||
return compareDates(comparison, expectedTime, actualValue, format, AssertionExpectedSingleValue), nil
|
||||
case *val.TextStorage, string:
|
||||
actualString, err := GetStringColAsString(sqlCtx, actualValue)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return compareTestAssertion(comparison, *value, *actualString, AssertionExpectedSingleValue), nil
|
||||
default:
|
||||
return fmt.Sprintf("Type %T is not supported. Open an issue at https://github.com/dolthub/dolt/issues to see it added", actualValue), nil
|
||||
}
|
||||
}
|
||||
|
||||
func expectRows(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
|
||||
if value == nil {
|
||||
return "null is not a valid assertion for expected_rows", nil
|
||||
}
|
||||
expectedRows, err := strconv.Atoi(*value)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil
|
||||
}
|
||||
|
||||
var numRows int
|
||||
for {
|
||||
_, err := queryResult.Next(sqlCtx)
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
numRows++
|
||||
}
|
||||
return compareTestAssertion(comparison, expectedRows, numRows, AssertionExpectedRows), nil
|
||||
}
|
||||
|
||||
func expectColumns(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
|
||||
if value == nil {
|
||||
return "null is not a valid assertion for expected_rows", nil
|
||||
}
|
||||
expectedColumns, err := strconv.Atoi(*value)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil
|
||||
}
|
||||
|
||||
var numColumns int
|
||||
row, err := queryResult.Next(sqlCtx)
|
||||
if err != nil && err != io.EOF {
|
||||
return "", err
|
||||
}
|
||||
numColumns = len(row)
|
||||
return compareTestAssertion(comparison, expectedColumns, numColumns, AssertionExpectedColumns), nil
|
||||
}
|
||||
|
||||
// compareTestAssertion is a generic function used for comparing string, ints, floats.
|
||||
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareTestAssertion[T constraints.Ordered](comparison string, expectedValue, actualValue T, assertionType string) string {
|
||||
switch comparison {
|
||||
case "==":
|
||||
if actualValue != expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case "!=":
|
||||
if actualValue == expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case "<":
|
||||
if actualValue >= expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s less than %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case "<=":
|
||||
if actualValue > expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s less than or equal to %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case ">":
|
||||
if actualValue <= expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case ">=":
|
||||
if actualValue < expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than or equal to %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison type", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// parseTestsDate is an internal function that parses the queried string according to allowed time formats for dolt_tests.
|
||||
// It returns the parsed time, the format that succeeded, and an error if applicable.
|
||||
func parseTestsDate(value string) (parsedTime time.Time, format string, err error) {
|
||||
// List of valid formats
|
||||
formats := []string{
|
||||
time.DateOnly,
|
||||
time.DateTime,
|
||||
time.TimeOnly,
|
||||
time.RFC3339,
|
||||
time.RFC1123Z,
|
||||
}
|
||||
|
||||
for _, format := range formats {
|
||||
if parsedTime, parseErr := time.Parse(format, value); parseErr == nil {
|
||||
return parsedTime, format, nil
|
||||
} else {
|
||||
err = parseErr
|
||||
}
|
||||
}
|
||||
return time.Time{}, "", err
|
||||
}
|
||||
|
||||
// compareDates is a function used for comparing time values.
|
||||
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareDates(comparison string, expectedValue, realValue time.Time, format string, assertionType string) string {
|
||||
expectedStr := expectedValue.Format(format)
|
||||
realStr := realValue.Format(format)
|
||||
switch comparison {
|
||||
case "==":
|
||||
if !expectedValue.Equal(realValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case "!=":
|
||||
if expectedValue.Equal(realValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case "<":
|
||||
if realValue.Equal(expectedValue) || realValue.After(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s less than %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case "<=":
|
||||
if realValue.After(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s less than or equal to %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case ">":
|
||||
if realValue.Before(expectedValue) || realValue.Equal(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case ">=":
|
||||
if realValue.Before(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than or equal to %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison type", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// compareDecimals is a function used for comparing decimals.
|
||||
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareDecimals(comparison string, expectedValue, realValue decimal.Decimal, assertionType string) string {
|
||||
switch comparison {
|
||||
case "==":
|
||||
if !expectedValue.Equal(realValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case "!=":
|
||||
if expectedValue.Equal(realValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case "<":
|
||||
if realValue.GreaterThanOrEqual(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s less than %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case "<=":
|
||||
if realValue.GreaterThan(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s less than or equal to %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case ">":
|
||||
if realValue.LessThanOrEqual(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case ">=":
|
||||
if realValue.LessThan(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than or equal to %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison type", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getTinyIntColAsBool returns the value interface{} as a bool
|
||||
// This is necessary because the query engine may return a tinyint column as a bool, int, or other types.
|
||||
// Based on GetTinyIntColAsBool from commands/utils.go, which we can't depend on here due to package cycles.
|
||||
func getInterfaceAsBool(col interface{}) (bool, error) {
|
||||
switch v := col.(type) {
|
||||
case bool:
|
||||
return v, nil
|
||||
case int:
|
||||
return v == 1, nil
|
||||
case int8:
|
||||
return v == 1, nil
|
||||
case int16:
|
||||
return v == 1, nil
|
||||
case int32:
|
||||
return v == 1, nil
|
||||
case int64:
|
||||
return v == 1, nil
|
||||
case uint:
|
||||
return v == 1, nil
|
||||
case uint8:
|
||||
return v == 1, nil
|
||||
case uint16:
|
||||
return v == 1, nil
|
||||
case uint32:
|
||||
return v == 1, nil
|
||||
case uint64:
|
||||
return v == 1, nil
|
||||
case string:
|
||||
return v == "1", nil
|
||||
default:
|
||||
return false, fmt.Errorf("unexpected type %T, was expecting bool, int, or string", v)
|
||||
}
|
||||
}
|
||||
|
||||
// compareBooleans is a function used for comparing boolean values.
|
||||
// It takes in a comparison string from one of: "==", "!="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareBooleans(comparison string, expectedValue, realValue bool, assertionType string) string {
|
||||
switch comparison {
|
||||
case "==":
|
||||
if expectedValue != realValue {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to %t, got %t", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case "!=":
|
||||
if expectedValue == realValue {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to %t, got %t", assertionType, expectedValue, realValue)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison for boolean values. Only '==' and '!=' are supported", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// compareNullValue is a function used for comparing a null value.
|
||||
// It takes in a comparison string from one of: "==", "!="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareNullValue(comparison string, actualValue interface{}, assertionType string) string {
|
||||
switch comparison {
|
||||
case "==":
|
||||
if actualValue != nil {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to NULL, got %v", assertionType, actualValue)
|
||||
}
|
||||
case "!=":
|
||||
if actualValue == nil {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to NULL, got NULL", assertionType)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison for NULL values", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetStringColAsString is a function that returns a text column as a string.
|
||||
// This is necessary as the dolt_tests system table returns *val.TextStorage types under certain situations,
|
||||
// so we use a special parser to get the correct string values
|
||||
func GetStringColAsString(sqlCtx *sql.Context, tableValue interface{}) (*string, error) {
|
||||
if ts, ok := tableValue.(*val.TextStorage); ok {
|
||||
str, err := ts.Unwrap(sqlCtx)
|
||||
return &str, err
|
||||
} else if str, ok := tableValue.(string); ok {
|
||||
return &str, nil
|
||||
} else if tableValue == nil {
|
||||
return nil, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("unexpected type %T, was expecting string", tableValue)
|
||||
}
|
||||
}
|
||||
-35
@@ -33,7 +33,6 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/grpcendpoint"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/config"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/filesys"
|
||||
@@ -1350,40 +1349,6 @@ func (dEnv *DoltEnv) TempTableFilesDir() (string, error) {
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
func (dEnv *DoltEnv) DbEaFactory(ctx context.Context) (editor.DbEaFactory, error) {
|
||||
tmpDir, err := dEnv.TempTableFilesDir()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db := dEnv.DoltDB(ctx)
|
||||
if db == nil {
|
||||
if dEnv.DBLoadError != nil {
|
||||
return nil, dEnv.DBLoadError
|
||||
}
|
||||
return nil, errors.New("DoltDB failed to initialize but no error was recorded")
|
||||
}
|
||||
|
||||
return editor.NewDbEaFactory(tmpDir, db.ValueReadWriter()), nil
|
||||
}
|
||||
|
||||
func (dEnv *DoltEnv) BulkDbEaFactory(ctx context.Context) (editor.DbEaFactory, error) {
|
||||
tmpDir, err := dEnv.TempTableFilesDir()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db := dEnv.DoltDB(ctx)
|
||||
if db == nil {
|
||||
if dEnv.DBLoadError != nil {
|
||||
return nil, dEnv.DBLoadError
|
||||
}
|
||||
return nil, errors.New("DoltDB failed to initialize but no error was recorded")
|
||||
}
|
||||
|
||||
return editor.NewBulkImportTEAFactory(db.ValueReadWriter(), tmpDir), nil
|
||||
}
|
||||
|
||||
func (dEnv *DoltEnv) IsAccessModeReadOnly(ctx context.Context) (bool, error) {
|
||||
db := dEnv.DoltDB(ctx)
|
||||
if db == nil {
|
||||
|
||||
Vendored
-248
@@ -1,248 +0,0 @@
|
||||
// Copyright 2021 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package env
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/config"
|
||||
"github.com/dolthub/dolt/go/store/chunks"
|
||||
"github.com/dolthub/dolt/go/store/datas"
|
||||
"github.com/dolthub/dolt/go/store/hash"
|
||||
)
|
||||
|
||||
func NewMemoryDbData(ctx context.Context, cfg config.ReadableConfig) (DbData[context.Context], error) {
|
||||
branchName := GetDefaultInitBranch(cfg)
|
||||
|
||||
ddb, err := NewMemoryDoltDB(ctx, branchName)
|
||||
if err != nil {
|
||||
return DbData[context.Context]{}, err
|
||||
}
|
||||
|
||||
rs, err := NewMemoryRepoState(ctx, ddb, branchName)
|
||||
if err != nil {
|
||||
return DbData[context.Context]{}, err
|
||||
}
|
||||
|
||||
return DbData[context.Context]{
|
||||
Ddb: ddb,
|
||||
Rsw: rs,
|
||||
Rsr: rs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewMemoryDoltDB(ctx context.Context, initBranch string) (*doltdb.DoltDB, error) {
|
||||
ts := &chunks.TestStorage{}
|
||||
cs := ts.NewViewWithDefaultFormat()
|
||||
ddb, err := doltdb.DoltDBFromCS(cs, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := "memory"
|
||||
branchRef := ref.NewBranchRef(initBranch)
|
||||
err = ddb.WriteEmptyRepoWithCommitTimeAndDefaultBranch(ctx, m, m, datas.CommitterDate(), branchRef)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ddb, nil
|
||||
}
|
||||
|
||||
func NewMemoryRepoState(ctx context.Context, ddb *doltdb.DoltDB, initBranch string) (MemoryRepoState, error) {
|
||||
head := ref.NewBranchRef(initBranch)
|
||||
rs := MemoryRepoState{
|
||||
DoltDB: ddb,
|
||||
Head: head,
|
||||
}
|
||||
|
||||
commit, err := ddb.ResolveCommitRef(ctx, head)
|
||||
if err != nil {
|
||||
return MemoryRepoState{}, err
|
||||
}
|
||||
|
||||
root, err := commit.GetRootValue(ctx)
|
||||
if err != nil {
|
||||
return MemoryRepoState{}, err
|
||||
}
|
||||
|
||||
err = rs.UpdateWorkingRoot(ctx, root)
|
||||
if err != nil {
|
||||
return MemoryRepoState{}, err
|
||||
}
|
||||
|
||||
err = rs.UpdateStagedRoot(ctx, root)
|
||||
if err != nil {
|
||||
return MemoryRepoState{}, err
|
||||
}
|
||||
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
type MemoryRepoState struct {
|
||||
DoltDB *doltdb.DoltDB
|
||||
Head ref.DoltRef
|
||||
}
|
||||
|
||||
var _ RepoStateReader[context.Context] = MemoryRepoState{}
|
||||
var _ RepoStateWriter = MemoryRepoState{}
|
||||
|
||||
func (m MemoryRepoState) CWBHeadRef(context.Context) (ref.DoltRef, error) {
|
||||
return m.Head, nil
|
||||
}
|
||||
|
||||
func (m MemoryRepoState) CWBHeadSpec(ctx context.Context) (*doltdb.CommitSpec, error) {
|
||||
headRef, err := m.CWBHeadRef(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
spec, err := doltdb.NewCommitSpec(headRef.GetPath())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return spec, nil
|
||||
}
|
||||
|
||||
func (m MemoryRepoState) UpdateStagedRoot(ctx context.Context, newRoot doltdb.RootValue) error {
|
||||
var h hash.Hash
|
||||
var wsRef ref.WorkingSetRef
|
||||
|
||||
ws, err := m.WorkingSet(ctx)
|
||||
if err == doltdb.ErrWorkingSetNotFound {
|
||||
// first time updating root
|
||||
headRef, err := m.CWBHeadRef(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
wsRef, err = ref.WorkingSetRefForHead(headRef)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ws = doltdb.EmptyWorkingSet(wsRef).WithWorkingRoot(newRoot).WithStagedRoot(newRoot)
|
||||
} else if err != nil {
|
||||
return err
|
||||
} else {
|
||||
h, err = ws.HashOf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
wsRef = ws.Ref()
|
||||
}
|
||||
|
||||
return m.DoltDB.UpdateWorkingSet(ctx, wsRef, ws.WithStagedRoot(newRoot), h, m.workingSetMeta(), nil)
|
||||
}
|
||||
|
||||
func (m MemoryRepoState) UpdateWorkingRoot(ctx context.Context, newRoot doltdb.RootValue) error {
|
||||
var h hash.Hash
|
||||
var wsRef ref.WorkingSetRef
|
||||
|
||||
ws, err := m.WorkingSet(ctx)
|
||||
if err == doltdb.ErrWorkingSetNotFound {
|
||||
// first time updating root
|
||||
headRef, err := m.CWBHeadRef(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
wsRef, err = ref.WorkingSetRefForHead(headRef)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ws = doltdb.EmptyWorkingSet(wsRef).WithWorkingRoot(newRoot).WithStagedRoot(newRoot)
|
||||
} else if err != nil {
|
||||
return err
|
||||
} else {
|
||||
h, err = ws.HashOf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
wsRef = ws.Ref()
|
||||
}
|
||||
|
||||
return m.DoltDB.UpdateWorkingSet(ctx, wsRef, ws.WithWorkingRoot(newRoot), h, m.workingSetMeta(), nil)
|
||||
}
|
||||
|
||||
func (m MemoryRepoState) WorkingSet(ctx context.Context) (*doltdb.WorkingSet, error) {
|
||||
headRef, err := m.CWBHeadRef(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
workingSetRef, err := ref.WorkingSetRefForHead(headRef)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
workingSet, err := m.DoltDB.ResolveWorkingSet(ctx, workingSetRef)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return workingSet, nil
|
||||
}
|
||||
|
||||
func (m MemoryRepoState) workingSetMeta() *datas.WorkingSetMeta {
|
||||
return &datas.WorkingSetMeta{
|
||||
Timestamp: uint64(time.Now().Unix()),
|
||||
Description: "updated from dolt environment",
|
||||
}
|
||||
}
|
||||
|
||||
func (m MemoryRepoState) SetCWBHeadRef(_ context.Context, r ref.MarshalableRef) (err error) {
|
||||
m.Head = r.Ref
|
||||
return
|
||||
}
|
||||
|
||||
func (m MemoryRepoState) GetRemotes() (*concurrentmap.Map[string, Remote], error) {
|
||||
return concurrentmap.New[string, Remote](), nil
|
||||
}
|
||||
|
||||
func (m MemoryRepoState) AddRemote(r Remote) error {
|
||||
return fmt.Errorf("cannot insert a remote in a memory database")
|
||||
}
|
||||
|
||||
func (m MemoryRepoState) GetBranches() (*concurrentmap.Map[string, BranchConfig], error) {
|
||||
return concurrentmap.New[string, BranchConfig](), nil
|
||||
}
|
||||
|
||||
func (m MemoryRepoState) UpdateBranch(name string, new BranchConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m MemoryRepoState) RemoveRemote(ctx context.Context, name string) error {
|
||||
return fmt.Errorf("cannot delete a remote from a memory database")
|
||||
}
|
||||
|
||||
func (m MemoryRepoState) TempTableFilesDir() (string, error) {
|
||||
return os.TempDir(), nil
|
||||
}
|
||||
|
||||
func (m MemoryRepoState) GetBackups() (*concurrentmap.Map[string, Remote], error) {
|
||||
panic("cannot get backups on in memory database")
|
||||
}
|
||||
|
||||
func (m MemoryRepoState) AddBackup(r Remote) error {
|
||||
panic("cannot add backup to in memory database")
|
||||
}
|
||||
|
||||
func (m MemoryRepoState) RemoveBackup(ctx context.Context, name string) error {
|
||||
panic("cannot remove backup from in memory database")
|
||||
}
|
||||
@@ -127,7 +127,6 @@ func testDataMergeHelper(t *testing.T, tests []dataMergeTest, flipSides bool) {
|
||||
|
||||
var mo merge.MergeOpts
|
||||
var eo editor.Options
|
||||
eo = eo.WithDeaf(editor.NewInMemDeaf(a.VRW()))
|
||||
// attempt merge before skipping to assert no panics
|
||||
result, err := merge.MergeRoots(sql.NewContext(ctx), doltdb.SimpleTableResolver{}, l, r, a, rootish{r}, rootish{a}, eo, mo)
|
||||
|
||||
@@ -147,7 +146,6 @@ func testDataMergeHelper(t *testing.T, tests []dataMergeTest, flipSides bool) {
|
||||
func setupDataMergeTest(ctx context.Context, t *testing.T, schema namedSchema, test dataTest) (anc, left, right, merged doltdb.RootValue) {
|
||||
denv := dtestutils.CreateTestEnv()
|
||||
var eo editor.Options
|
||||
eo = eo.WithDeaf(editor.NewInMemDeaf(denv.DoltDB(ctx).ValueReadWriter()))
|
||||
|
||||
ancestorTable := tbl(schema, test.ancestor...)
|
||||
anc = makeRootWithTable(t, denv.DoltDB(ctx), eo, *ancestorTable)
|
||||
|
||||
@@ -458,18 +458,6 @@ type keylessEntry struct {
|
||||
c2 int
|
||||
}
|
||||
|
||||
func (e keylessEntries) toTupleSet() tupleSet {
|
||||
tups := make([]types.Tuple, len(e))
|
||||
for i, t := range e {
|
||||
tups[i] = t.ToNomsTuple()
|
||||
}
|
||||
return mustTupleSet(tups...)
|
||||
}
|
||||
|
||||
func (e keylessEntry) ToNomsTuple() types.Tuple {
|
||||
return dtu.MustTuple(cardTag, types.Uint(e.card), c1Tag, types.Int(e.c1), c2Tag, types.Int(e.c2))
|
||||
}
|
||||
|
||||
func (e keylessEntry) HashAndValue() ([]byte, val.Tuple, error) {
|
||||
valBld.PutUint64(0, uint64(e.card))
|
||||
valBld.PutInt64(1, int64(e.c1))
|
||||
@@ -497,14 +485,6 @@ func (e conflictEntries) toConflictSet(t *testing.T) conflictSet {
|
||||
return s
|
||||
}
|
||||
|
||||
func (e conflictEntries) toTupleSet() tupleSet {
|
||||
tups := make([]types.Tuple, len(e))
|
||||
for i, t := range e {
|
||||
tups[i] = t.ToNomsTuple()
|
||||
}
|
||||
return mustTupleSet(tups...)
|
||||
}
|
||||
|
||||
func (e conflictEntry) Key(t *testing.T) (h [16]byte) {
|
||||
if e.base != nil {
|
||||
h2, _, err := e.base.HashAndValue()
|
||||
@@ -528,34 +508,6 @@ func (e conflictEntry) Key(t *testing.T) (h [16]byte) {
|
||||
return
|
||||
}
|
||||
|
||||
func (e conflictEntry) ToNomsTuple() types.Tuple {
|
||||
var b, o, t types.Value = types.NullValue, types.NullValue, types.NullValue
|
||||
if e.base != nil {
|
||||
b = e.base.ToNomsTuple()
|
||||
}
|
||||
if e.ours != nil {
|
||||
o = e.ours.ToNomsTuple()
|
||||
}
|
||||
if e.theirs != nil {
|
||||
t = e.theirs.ToNomsTuple()
|
||||
}
|
||||
return dtu.MustTuple(b, o, t)
|
||||
}
|
||||
|
||||
type tupleSet map[hash.Hash]types.Tuple
|
||||
|
||||
func mustTupleSet(tt ...types.Tuple) (s tupleSet) {
|
||||
s = make(tupleSet, len(tt))
|
||||
for _, tup := range tt {
|
||||
h, err := tup.Hash(types.Format_Default)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
s[h] = tup
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type hash128Set map[[16]byte]val.Tuple
|
||||
|
||||
func mustHash128Set(entries ...keylessEntry) (s hash128Set) {
|
||||
|
||||
@@ -413,14 +413,6 @@ type ArtifactStatus struct {
|
||||
ConstraintViolationsTables []string
|
||||
}
|
||||
|
||||
func (as ArtifactStatus) HasConflicts() bool {
|
||||
return len(as.DataConflictTables) > 0 || len(as.SchemaConflictsTables) > 0
|
||||
}
|
||||
|
||||
func (as ArtifactStatus) HasConstraintViolations() bool {
|
||||
return len(as.ConstraintViolationsTables) > 0
|
||||
}
|
||||
|
||||
// MergeWouldStompChanges returns list of table names that are stomped and the diffs map between head and working set.
|
||||
func MergeWouldStompChanges(ctx context.Context, roots doltdb.Roots, mergeCommit *doltdb.Commit) ([]doltdb.TableName, map[doltdb.TableName]hash.Hash, error) {
|
||||
mergeRoot, err := mergeCommit.GetRootValue(ctx)
|
||||
|
||||
@@ -32,7 +32,6 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor/creation"
|
||||
filesys2 "github.com/dolthub/dolt/go/libraries/utils/filesys"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/valutil"
|
||||
"github.com/dolthub/dolt/go/store/datas"
|
||||
"github.com/dolthub/dolt/go/store/pool"
|
||||
"github.com/dolthub/dolt/go/store/prolly"
|
||||
@@ -76,10 +75,6 @@ func (v rowV) value() val.Tuple {
|
||||
return tup
|
||||
}
|
||||
|
||||
func (v rowV) nomsValue() types.Value {
|
||||
return valsToTestTupleWithoutPks([]types.Value{types.Int(v.col1), types.Int(v.col2)})
|
||||
}
|
||||
|
||||
const (
|
||||
NoopAction ActionType = iota
|
||||
InsertAction
|
||||
@@ -505,68 +500,6 @@ func rebuildAllProllyIndexes(ctx *sql.Context, tbl *doltdb.Table) (*doltdb.Table
|
||||
return tbl.SetIndexSet(ctx, indexes)
|
||||
}
|
||||
|
||||
func calcExpectedStats(t *testing.T) *MergeStats {
|
||||
s := &MergeStats{Operation: TableModified}
|
||||
for _, testCase := range testRows {
|
||||
if (testCase.leftAction == InsertAction) != (testCase.rightAction == InsertAction) {
|
||||
if testCase.leftAction == UpdateAction || testCase.rightAction == UpdateAction ||
|
||||
testCase.leftAction == DeleteAction || testCase.rightAction == DeleteAction {
|
||||
// Either the row exists in the ancestor commit and we are
|
||||
// deleting or updating it, or the row doesn't exist and we are
|
||||
// inserting it.
|
||||
t.Fatalf("it's impossible for an insert to be paired with an update or delete")
|
||||
}
|
||||
}
|
||||
|
||||
if testCase.leftAction == NoopAction {
|
||||
switch testCase.rightAction {
|
||||
case NoopAction:
|
||||
case DeleteAction:
|
||||
s.Deletes++
|
||||
case InsertAction:
|
||||
s.Adds++
|
||||
case UpdateAction:
|
||||
s.Modifications++
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if testCase.rightAction == NoopAction {
|
||||
switch testCase.leftAction {
|
||||
case NoopAction:
|
||||
case DeleteAction:
|
||||
s.Deletes++
|
||||
case InsertAction:
|
||||
s.Adds++
|
||||
case UpdateAction:
|
||||
s.Modifications++
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if testCase.conflict {
|
||||
// (UpdateAction, DeleteAction),
|
||||
// (DeleteAction, UpdateAction),
|
||||
// (UpdateAction, UpdateAction) with conflict,
|
||||
// (InsertAction, InsertAction) with conflict
|
||||
s.DataConflicts++
|
||||
continue
|
||||
}
|
||||
|
||||
if testCase.leftAction == InsertAction && testCase.rightAction == InsertAction {
|
||||
// Equivalent inserts
|
||||
continue
|
||||
}
|
||||
|
||||
if !valutil.NilSafeEqCheck(unwrapNoms(testCase.leftValue), unwrapNoms(testCase.rightValue)) {
|
||||
s.Modifications++
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func mustMakeEmptyRepo(t *testing.T) *doltdb.DoltDB {
|
||||
ddb, _ := doltdb.LoadDoltDB(context.Background(), types.Format_Default, doltdb.InMemDoltDB, filesys2.LocalFS)
|
||||
err := ddb.WriteEmptyRepo(context.Background(), env.DefaultInitBranch, name, email)
|
||||
@@ -645,40 +578,6 @@ func key(i int) val.Tuple {
|
||||
return tup
|
||||
}
|
||||
|
||||
func nomsKey(i int) types.Value {
|
||||
return mustTuple(types.NewTuple(types.Format_Default, types.Uint(idTag), types.Int(i)))
|
||||
}
|
||||
|
||||
func unwrap(v *rowV) val.Tuple {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
return v.value()
|
||||
}
|
||||
|
||||
func unwrapNoms(v *rowV) types.Value {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
return v.nomsValue()
|
||||
}
|
||||
|
||||
func mustTuple(tpl types.Tuple, err error) types.Tuple {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return tpl
|
||||
}
|
||||
|
||||
func mustString(str string, err error) string {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
func MustDebugFormatProlly(t *testing.T, m prolly.Map) string {
|
||||
s, err := prolly.DebugFormat(context.Background(), m)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -26,15 +26,6 @@ import (
|
||||
"github.com/dolthub/dolt/go/store/val"
|
||||
)
|
||||
|
||||
type nomsRowMergeTest struct {
|
||||
name string
|
||||
row, mergeRow, ancRow types.Value
|
||||
sch schema.Schema
|
||||
expectedResult types.Value
|
||||
expectCellMerge bool
|
||||
expectConflict bool
|
||||
}
|
||||
|
||||
type rowMergeTest struct {
|
||||
name string
|
||||
row, mergeRow, ancRow val.Tuple
|
||||
@@ -69,39 +60,6 @@ func build(ints ...int) []*int {
|
||||
return out
|
||||
}
|
||||
|
||||
var convergentEditCases = []testCase{
|
||||
{
|
||||
"add same row",
|
||||
build(1, 2),
|
||||
build(1, 2),
|
||||
nil,
|
||||
2, 2, 2,
|
||||
build(1, 2),
|
||||
false,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"both delete row",
|
||||
nil,
|
||||
nil,
|
||||
build(1, 2),
|
||||
2, 2, 2,
|
||||
nil,
|
||||
false,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"modify row to equal value",
|
||||
build(2, 2),
|
||||
build(2, 2),
|
||||
build(1, 1),
|
||||
2, 2, 2,
|
||||
build(2, 2),
|
||||
false,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
var testCases = []testCase{
|
||||
{
|
||||
"insert different rows",
|
||||
@@ -221,35 +179,6 @@ func TestRowMerge(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func valsToTestTupleWithoutPks(vals []types.Value) types.Value {
|
||||
return valsToTestTuple(vals, false)
|
||||
}
|
||||
|
||||
func valsToTestTupleWithPks(vals []types.Value) types.Value {
|
||||
return valsToTestTuple(vals, true)
|
||||
}
|
||||
|
||||
func valsToTestTuple(vals []types.Value, includePrimaryKeys bool) types.Value {
|
||||
if vals == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tplVals := make([]types.Value, 0, 2*len(vals))
|
||||
for i, val := range vals {
|
||||
if !types.IsNull(val) {
|
||||
tag := i
|
||||
// Assume one primary key tag, add 1 to all other tags
|
||||
if includePrimaryKeys {
|
||||
tag++
|
||||
}
|
||||
tplVals = append(tplVals, types.Uint(tag))
|
||||
tplVals = append(tplVals, val)
|
||||
}
|
||||
}
|
||||
|
||||
return mustTuple(types.NewTuple(types.Format_Default, tplVals...))
|
||||
}
|
||||
|
||||
func createRowMergeStruct(t testCase) rowMergeTest {
|
||||
mergedSch := calcMergedSchema(t)
|
||||
leftSch := calcSchema(t.rowCnt)
|
||||
@@ -269,16 +198,6 @@ func createRowMergeStruct(t testCase) rowMergeTest {
|
||||
t.expectConflict}
|
||||
}
|
||||
|
||||
func createNomsRowMergeStruct(t testCase) nomsRowMergeTest {
|
||||
sch := calcMergedSchema(t)
|
||||
|
||||
tpl := valsToTestTupleWithPks(toVals(t.row))
|
||||
mergeTpl := valsToTestTupleWithPks(toVals(t.mergeRow))
|
||||
ancTpl := valsToTestTupleWithPks(toVals(t.ancRow))
|
||||
expectedTpl := valsToTestTupleWithPks(toVals(t.expectedResult))
|
||||
return nomsRowMergeTest{t.name, tpl, mergeTpl, ancTpl, sch, expectedTpl, t.expectCellMerge, t.expectConflict}
|
||||
}
|
||||
|
||||
func calcMergedSchema(t testCase) schema.Schema {
|
||||
longest := t.rowCnt
|
||||
if t.mRowCnt > longest {
|
||||
@@ -323,20 +242,3 @@ func buildTup(sch schema.Schema, r []*int) val.Tuple {
|
||||
}
|
||||
return tup
|
||||
}
|
||||
|
||||
func toVals(ints []*int) []types.Value {
|
||||
if ints == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
v := make([]types.Value, len(ints))
|
||||
for i, d := range ints {
|
||||
if d == nil {
|
||||
v[i] = types.NullValue
|
||||
continue
|
||||
}
|
||||
|
||||
v[i] = types.Int(*d)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -1650,7 +1650,6 @@ func testSchemaMergeHelper(t *testing.T, tests []schemaMergeTest, flipSides bool
|
||||
|
||||
var mo merge.MergeOpts
|
||||
var eo editor.Options
|
||||
eo = eo.WithDeaf(editor.NewInMemDeaf(a.VRW()))
|
||||
// attempt merge before skipping to assert no panics
|
||||
result, err := merge.MergeRoots(sql.NewContext(ctx), doltdb.SimpleTableResolver{}, l, r, a, rootish{r}, rootish{a}, eo, mo)
|
||||
maybeSkip(t, test, flipSides)
|
||||
@@ -1693,7 +1692,6 @@ func testSchemaMergeHelper(t *testing.T, tests []schemaMergeTest, flipSides bool
|
||||
func setupSchemaMergeTest(ctx context.Context, t *testing.T, test schemaMergeTest) (anc, left, right, merged doltdb.RootValue) {
|
||||
denv := dtestutils.CreateTestEnv()
|
||||
var eo editor.Options
|
||||
eo = eo.WithDeaf(editor.NewInMemDeaf(denv.DoltDB(ctx).ValueReadWriter()))
|
||||
anc = makeRootWithTable(t, denv.DoltDB(ctx), eo, test.ancestor)
|
||||
assert.NotNil(t, anc)
|
||||
if test.left != nil {
|
||||
|
||||
@@ -20,12 +20,10 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
json2 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/json"
|
||||
"github.com/dolthub/dolt/go/store/hash"
|
||||
"github.com/dolthub/dolt/go/store/prolly"
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
@@ -543,16 +541,3 @@ func foreignKeyCVJson(foreignKey doltdb.ForeignKey, sch, refSch schema.Schema) (
|
||||
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func jsonDataToNomsValue(ctx context.Context, vrw types.ValueReadWriter, data []byte) (types.JSON, error) {
|
||||
var doc interface{}
|
||||
if err := json.Unmarshal(data, &doc); err != nil {
|
||||
return types.JSON{}, err
|
||||
}
|
||||
sqlDoc := gmstypes.JSONDocument{Val: doc}
|
||||
nomsJson, err := json2.NomsJSONFromJSONValue(ctx, vrw, sqlDoc)
|
||||
if err != nil {
|
||||
return types.JSON{}, err
|
||||
}
|
||||
return types.JSON(nomsJson), nil
|
||||
}
|
||||
|
||||
@@ -86,7 +86,7 @@ func TestBasics(t *testing.T) {
|
||||
{NewDataLocation("file.csv", ""), CsvFile.ReadableStr() + ":file.csv", true},
|
||||
{NewDataLocation("file.psv", ""), PsvFile.ReadableStr() + ":file.psv", true},
|
||||
{NewDataLocation("file.json", ""), JsonFile.ReadableStr() + ":file.json", true},
|
||||
//{NewDataLocation("file.nbf", ""), NbfFile, "file.nbf", true},
|
||||
// {NewDataLocation("file.nbf", ""), NbfFile, "file.nbf", true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
@@ -133,7 +133,7 @@ func TestExists(t *testing.T) {
|
||||
NewDataLocation("file.csv", ""),
|
||||
NewDataLocation("file.psv", ""),
|
||||
NewDataLocation("file.json", ""),
|
||||
//NewDataLocation("file.nbf", ""),
|
||||
// NewDataLocation("file.nbf", ""),
|
||||
}
|
||||
|
||||
ddb, root, fs := createRootAndFS()
|
||||
@@ -192,7 +192,7 @@ func TestCreateRdWr(t *testing.T) {
|
||||
{NewDataLocation("file.csv", ""), reflect.TypeOf((*csv.CSVReader)(nil)).Elem(), reflect.TypeOf((*csv.CSVWriter)(nil)).Elem()},
|
||||
{NewDataLocation("file.psv", ""), reflect.TypeOf((*csv.CSVReader)(nil)).Elem(), reflect.TypeOf((*csv.CSVWriter)(nil)).Elem()},
|
||||
{NewDataLocation("file.json", ""), reflect.TypeOf((*json.JSONReader)(nil)).Elem(), reflect.TypeOf((*json.RowWriter)(nil)).Elem()},
|
||||
//{NewDataLocation("file.nbf", ""), reflect.TypeOf((*nbf.NBFReader)(nil)).Elem(), reflect.TypeOf((*nbf.NBFWriter)(nil)).Elem()},
|
||||
// {NewDataLocation("file.nbf", ""), reflect.TypeOf((*nbf.NBFReader)(nil)).Elem(), reflect.TypeOf((*nbf.NBFWriter)(nil)).Elem()},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -220,16 +220,6 @@ func TestCreateRdWr(t *testing.T) {
|
||||
|
||||
loc := test.dl
|
||||
|
||||
tmpDir, tdErr := dEnv.TempTableFilesDir()
|
||||
if tdErr != nil {
|
||||
t.Fatal("Unexpected error accessing .dolt directory.", tdErr)
|
||||
}
|
||||
deaf, err := dEnv.DbEaFactory(ctx)
|
||||
if err != nil {
|
||||
t.Fatal("Unexpected error accessing .dolt directory.", err)
|
||||
}
|
||||
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
|
||||
|
||||
filePath, fpErr := dEnv.FS.Abs(strings.Split(loc.String(), ":")[1])
|
||||
if fpErr != nil {
|
||||
t.Fatal("Unexpected error getting filepath", fpErr)
|
||||
@@ -240,7 +230,7 @@ func TestCreateRdWr(t *testing.T) {
|
||||
t.Fatal("Unexpected error opening file for writer.", wrErr)
|
||||
}
|
||||
|
||||
wr, wErr := loc.NewCreatingWriter(context.Background(), mvOpts, root, fakeSchema, opts, writer)
|
||||
wr, wErr := loc.NewCreatingWriter(context.Background(), mvOpts, root, fakeSchema, editor.Options{}, writer)
|
||||
if wErr != nil {
|
||||
t.Fatal("Unexpected error creating writer.", wErr)
|
||||
}
|
||||
|
||||
@@ -31,7 +31,6 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/overrides"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/typed/noms"
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
@@ -40,6 +39,9 @@ const (
|
||||
tableWriterStatUpdateRate = 64 * 1024
|
||||
)
|
||||
|
||||
// StatsCb is a callback for reporting stats about the rows that have been processed so far
|
||||
type StatsCb func(types.AppliedEditStats)
|
||||
|
||||
// SqlEngineTableWriter is a utility for importing a set of rows through the sql engine.
|
||||
type SqlEngineTableWriter struct {
|
||||
se *sqle.Engine
|
||||
@@ -51,7 +53,7 @@ type SqlEngineTableWriter struct {
|
||||
force bool
|
||||
disableFks bool
|
||||
|
||||
statsCB noms.StatsCB
|
||||
statsCB StatsCb
|
||||
stats types.AppliedEditStats
|
||||
statOps int32
|
||||
|
||||
@@ -60,7 +62,7 @@ type SqlEngineTableWriter struct {
|
||||
rowOperationSchema sql.PrimaryKeySchema
|
||||
}
|
||||
|
||||
func NewSqlEngineTableWriter(ctx *sql.Context, engine *sqle.Engine, createTableSchema, rowOperationSchema schema.Schema, options *MoverOptions, statsCB noms.StatsCB) (*SqlEngineTableWriter, error) {
|
||||
func NewSqlEngineTableWriter(ctx *sql.Context, engine *sqle.Engine, createTableSchema, rowOperationSchema schema.Schema, options *MoverOptions, statsCB StatsCb) (*SqlEngineTableWriter, error) {
|
||||
if engine.IsReadOnly() {
|
||||
// SqlEngineTableWriter does not respect read only mode
|
||||
return nil, analyzererrors.ErrReadOnlyDatabase.New(ctx.GetCurrentDatabase())
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
package row
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
@@ -24,8 +22,6 @@ import (
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
var ErrRowNotValid = errors.New("invalid row for current schema")
|
||||
|
||||
type Row interface {
|
||||
// Iterates over all the columns in the row. Columns that have no value set will not be visited.
|
||||
IterCols(cb func(tag uint64, val types.Value) (stop bool, err error)) (bool, error)
|
||||
@@ -82,21 +78,6 @@ func FromNoms(sch schema.Schema, nomsKey, nomsVal types.Tuple) (Row, error) {
|
||||
return pkRowFromNoms(sch, nomsKey, nomsVal)
|
||||
}
|
||||
|
||||
// ToNoms returns the storage-layer tuples corresponding to |r|.
|
||||
func ToNoms(ctx context.Context, sch schema.Schema, r Row) (key, val types.Tuple, err error) {
|
||||
k, err := r.NomsMapKey(sch).Value(ctx)
|
||||
if err != nil {
|
||||
return key, val, err
|
||||
}
|
||||
|
||||
v, err := r.NomsMapValue(sch).Value(ctx)
|
||||
if err != nil {
|
||||
return key, val, err
|
||||
}
|
||||
|
||||
return k.(types.Tuple), v.(types.Tuple), nil
|
||||
}
|
||||
|
||||
func GetFieldByName(colName string, r Row, sch schema.Schema) (types.Value, bool) {
|
||||
col, ok := sch.GetAllCols().GetByName(colName)
|
||||
|
||||
@@ -123,71 +104,6 @@ func GetFieldByNameWithDefault(colName string, defVal types.Value, r Row, sch sc
|
||||
}
|
||||
}
|
||||
|
||||
// ReduceToIndexKeysFromTagMap creates a full key and a partial key from the given map of tags (first tuple being the
|
||||
// full key). Please refer to the note in the index editor for more information regarding partial keys.
|
||||
func ReduceToIndexKeysFromTagMap(nbf *types.NomsBinFormat, idx schema.Index, tagToVal map[uint64]types.Value, tf *types.TupleFactory) (types.Tuple, types.Tuple, error) {
|
||||
vals := make([]types.Value, 0, len(idx.AllTags())*2)
|
||||
for _, tag := range idx.AllTags() {
|
||||
val, ok := tagToVal[tag]
|
||||
if !ok {
|
||||
val = types.NullValue
|
||||
}
|
||||
vals = append(vals, types.Uint(tag), val)
|
||||
}
|
||||
|
||||
if tf == nil {
|
||||
fullKey, err := types.NewTuple(nbf, vals...)
|
||||
if err != nil {
|
||||
return types.Tuple{}, types.Tuple{}, err
|
||||
}
|
||||
|
||||
partialKey, err := types.NewTuple(nbf, vals[:idx.Count()*2]...)
|
||||
if err != nil {
|
||||
return types.Tuple{}, types.Tuple{}, err
|
||||
}
|
||||
|
||||
return fullKey, partialKey, nil
|
||||
} else {
|
||||
fullKey, err := tf.Create(vals...)
|
||||
if err != nil {
|
||||
return types.Tuple{}, types.Tuple{}, err
|
||||
}
|
||||
|
||||
partialKey, err := tf.Create(vals[:idx.Count()*2]...)
|
||||
if err != nil {
|
||||
return types.Tuple{}, types.Tuple{}, err
|
||||
}
|
||||
|
||||
return fullKey, partialKey, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ReduceToIndexPartialKey creates an index record from a primary storage record.
|
||||
func ReduceToIndexPartialKey(tags []uint64, idx schema.Index, r Row) (types.Tuple, error) {
|
||||
var vals []types.Value
|
||||
if idx.Name() != "" {
|
||||
tags = idx.IndexedColumnTags()
|
||||
}
|
||||
for _, tag := range tags {
|
||||
val, ok := r.GetColVal(tag)
|
||||
if !ok {
|
||||
val = types.NullValue
|
||||
}
|
||||
vals = append(vals, types.Uint(tag), val)
|
||||
}
|
||||
|
||||
return types.NewTuple(r.Format(), vals...)
|
||||
}
|
||||
|
||||
func IsEmpty(r Row) (b bool) {
|
||||
b = true
|
||||
_, _ = r.IterCols(func(_ uint64, _ types.Value) (stop bool, err error) {
|
||||
b = false
|
||||
return true, nil
|
||||
})
|
||||
return b
|
||||
}
|
||||
|
||||
// IsValid returns whether the row given matches the types and satisfies all the constraints of the schema given.
|
||||
func IsValid(r Row, sch schema.Schema) (bool, error) {
|
||||
column, constraint, err := findInvalidCol(r, sch)
|
||||
@@ -265,77 +181,3 @@ func AreEqual(row1, row2 Row, sch schema.Schema) bool {
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func TaggedValsEqualForSch(tv, other TaggedValues, sch schema.Schema) bool {
|
||||
if tv == nil && other == nil {
|
||||
return true
|
||||
} else if tv == nil || other == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, tag := range sch.GetAllCols().Tags {
|
||||
val1, _ := tv[tag]
|
||||
val2, _ := other[tag]
|
||||
|
||||
if !valutil.NilSafeEqCheck(val1, val2) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func KeyAndTaggedValuesForRow(r Row, sch schema.Schema) (types.Tuple, TaggedValues, error) {
|
||||
switch typed := r.(type) {
|
||||
case nomsRow:
|
||||
pkCols := sch.GetPKCols()
|
||||
keyVals := make([]types.Value, 0, pkCols.Size()*2)
|
||||
tv := make(TaggedValues)
|
||||
err := pkCols.Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
|
||||
val, ok := typed.key[tag]
|
||||
if !ok || types.IsNull(val) {
|
||||
return false, errors.New("invalid key contains null values")
|
||||
}
|
||||
|
||||
tv[tag] = val
|
||||
keyVals = append(keyVals, types.Uint(tag))
|
||||
keyVals = append(keyVals, val)
|
||||
return false, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return types.Tuple{}, nil, err
|
||||
}
|
||||
|
||||
nonPkCols := sch.GetNonPKCols()
|
||||
_, err = typed.value.Iter(func(tag uint64, val types.Value) (stop bool, err error) {
|
||||
if _, ok := nonPkCols.TagToIdx[tag]; ok {
|
||||
tv[tag] = val
|
||||
}
|
||||
|
||||
return false, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return types.Tuple{}, nil, err
|
||||
}
|
||||
|
||||
t, err := types.NewTuple(r.Format(), keyVals...)
|
||||
if err != nil {
|
||||
return types.Tuple{}, nil, err
|
||||
}
|
||||
|
||||
return t, tv, nil
|
||||
|
||||
case keylessRow:
|
||||
tv, err := typed.TaggedValues()
|
||||
if err != nil {
|
||||
return types.Tuple{}, nil, err
|
||||
}
|
||||
|
||||
return typed.key, tv, nil
|
||||
|
||||
default:
|
||||
panic("unknown row type")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,60 +191,6 @@ func TaggedValuesFromTupleValueSlice(vals types.TupleValueSlice) (TaggedValues,
|
||||
return taggedTuple, nil
|
||||
}
|
||||
|
||||
func TaggedValuesFromTupleKeyAndValue(key, value types.Tuple) (TaggedValues, error) {
|
||||
tv := make(TaggedValues)
|
||||
err := AddToTaggedVals(tv, key)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = AddToTaggedVals(tv, value)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tv, nil
|
||||
}
|
||||
|
||||
func AddToTaggedVals(tv TaggedValues, t types.Tuple) error {
|
||||
return IterDoltTuple(t, func(tag uint64, val types.Value) error {
|
||||
tv[tag] = val
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func IterDoltTuple(t types.Tuple, cb func(tag uint64, val types.Value) error) error {
|
||||
itr, err := t.Iterator()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for itr.HasMore() {
|
||||
_, tag, err := itr.NextUint64()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, currVal, err := itr.Next()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = cb(tag, currVal)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tt TaggedValues) String() string {
|
||||
str := "{"
|
||||
for k, v := range tt {
|
||||
@@ -260,36 +206,3 @@ func (tt TaggedValues) String() string {
|
||||
str += "\n}"
|
||||
return str
|
||||
}
|
||||
|
||||
// CountCellDiffs returns the number of fields that are different between two
|
||||
// tuples and does not panic if tuples are different lengths.
|
||||
func CountCellDiffs(from, to types.Tuple, fromSch, toSch schema.Schema) (uint64, error) {
|
||||
fromColLen := len(fromSch.GetAllCols().GetColumns())
|
||||
toColLen := len(toSch.GetAllCols().GetColumns())
|
||||
changed := 0
|
||||
f, err := ParseTaggedValues(from)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
t, err := ParseTaggedValues(to)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for i, v := range f {
|
||||
ov, ok := t[i]
|
||||
// !ok means t[i] has NULL value, and it is not cell modify if it was from drop column or add column
|
||||
if (!ok && fromColLen == toColLen) || (ok && !v.Equals(ov)) {
|
||||
changed++
|
||||
}
|
||||
}
|
||||
|
||||
for i := range t {
|
||||
if f[i] == nil {
|
||||
changed++
|
||||
}
|
||||
}
|
||||
|
||||
return uint64(changed), nil
|
||||
}
|
||||
|
||||
@@ -18,12 +18,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/row"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
)
|
||||
|
||||
var IdentityConverter = &RowConverter{nil, true, nil}
|
||||
@@ -96,56 +93,6 @@ func panicOnDuplicateMappings(mapping *FieldMapping) {
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertWithWarnings takes an input row, maps its columns to their destination columns, performing any type
|
||||
// conversions needed to create a row of the expected destination schema, and uses the optional WarnFunction
|
||||
// callback to let callers handle logging a warning when a field cannot be cleanly converted.
|
||||
func (rc *RowConverter) ConvertWithWarnings(inRow row.Row, warnFn WarnFunction) (row.Row, error) {
|
||||
return rc.convert(inRow, warnFn)
|
||||
}
|
||||
|
||||
// convert takes a row and maps its columns to their destination columns, automatically performing any type conversion
|
||||
// needed, and using the optional WarnFunction to let callers log warnings on any type conversion errors.
|
||||
func (rc *RowConverter) convert(inRow row.Row, warnFn WarnFunction) (row.Row, error) {
|
||||
if rc.IdentityConverter {
|
||||
return inRow, nil
|
||||
}
|
||||
|
||||
outTaggedVals := make(row.TaggedValues, len(rc.SrcToDest))
|
||||
_, err := inRow.IterCols(func(tag uint64, val types.Value) (stop bool, err error) {
|
||||
convFunc, ok := rc.ConvFuncs[tag]
|
||||
|
||||
if ok {
|
||||
outTag := rc.SrcToDest[tag]
|
||||
outVal, err := convFunc(val)
|
||||
|
||||
if sql.ErrInvalidValue.Is(err) && warnFn != nil {
|
||||
col, _ := rc.SrcSch.GetAllCols().GetByTag(tag)
|
||||
warnFn(DatatypeCoercionFailureWarningCode, DatatypeCoercionFailureWarning, col.Name)
|
||||
outVal = types.NullValue
|
||||
err = nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if types.IsNull(outVal) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
outTaggedVals[outTag] = outVal
|
||||
}
|
||||
|
||||
return false, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return row.New(inRow.Format(), rc.DestSch, outTaggedVals)
|
||||
}
|
||||
|
||||
func IsNecessary(srcSch, destSch schema.Schema, destToSrc map[uint64]uint64) (bool, error) {
|
||||
srcCols := srcSch.GetAllCols()
|
||||
destCols := destSch.GetAllCols()
|
||||
|
||||
@@ -42,18 +42,6 @@ const (
|
||||
NotNullConstraintType = "not_null"
|
||||
)
|
||||
|
||||
// ColConstraintFromTypeAndParams takes in a string representing the type of the constraint and a map of parameters
|
||||
// that can be used to determine the behavior of the constraint. An example might be a constraint which validated
|
||||
// a value is in a given range. For this the constraint type might by "in_range_constraint", and the parameters might
|
||||
// be {"min": -10, "max": 10}
|
||||
func ColConstraintFromTypeAndParams(colCnstType string, params map[string]string) ColConstraint {
|
||||
switch colCnstType {
|
||||
case NotNullConstraintType:
|
||||
return NotNullConstraint{}
|
||||
}
|
||||
panic("Unknown column constraint type: " + colCnstType)
|
||||
}
|
||||
|
||||
// NotNullConstraint validates that a value is not null. It does not restrict 0 length strings, or 0 valued ints, or
|
||||
// anything other than non nil values
|
||||
type NotNullConstraint struct{}
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
"github.com/dolthub/vitess/go/sqltypes"
|
||||
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
@@ -45,38 +44,10 @@ type blobStringType struct {
|
||||
var _ TypeInfo = (*blobStringType)(nil)
|
||||
|
||||
var (
|
||||
TinyTextType TypeInfo = &blobStringType{sqlStringType: gmstypes.TinyText}
|
||||
TextType TypeInfo = &blobStringType{sqlStringType: gmstypes.Text}
|
||||
MediumTextType TypeInfo = &blobStringType{sqlStringType: gmstypes.MediumText}
|
||||
LongTextType TypeInfo = &blobStringType{sqlStringType: gmstypes.LongText}
|
||||
TextType TypeInfo = &blobStringType{sqlStringType: gmstypes.Text}
|
||||
LongTextType TypeInfo = &blobStringType{sqlStringType: gmstypes.LongText}
|
||||
)
|
||||
|
||||
func CreateBlobStringTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
collationStr, ok := params[blobStringTypeParam_Collate]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`create blobstring type info is missing param "%v"`, blobStringTypeParam_Collate)
|
||||
}
|
||||
collation, err := sql.ParseCollation("", collationStr, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxLengthStr, ok := params[blobStringTypeParam_Length]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`create blobstring type info is missing param "%v"`, blobStringTypeParam_Length)
|
||||
}
|
||||
length, err := strconv.ParseInt(maxLengthStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sqlType, err := gmstypes.CreateString(sqltypes.Text, length, collation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &blobStringType{sqlType}, nil
|
||||
}
|
||||
|
||||
// ConvertNomsValueToValue implements TypeInfo interface.
|
||||
func (ti *blobStringType) ConvertNomsValueToValue(v types.Value) (interface{}, error) {
|
||||
if val, ok := v.(types.Blob); ok {
|
||||
|
||||
@@ -116,29 +116,6 @@ func generateSetType(t *testing.T, numOfElements int) *setType {
|
||||
return &setType{gmstypes.MustCreateSetType(vals, sql.Collation_Default)}
|
||||
}
|
||||
|
||||
func generateInlineBlobTypes(t *testing.T, numOfTypes uint16) []TypeInfo {
|
||||
var res []TypeInfo
|
||||
loop(t, 1, 500, numOfTypes, func(i int64) {
|
||||
pad := false
|
||||
if i%2 == 0 {
|
||||
pad = true
|
||||
}
|
||||
res = append(res, generateInlineBlobType(t, i, pad))
|
||||
})
|
||||
return res
|
||||
}
|
||||
|
||||
func generateInlineBlobType(t *testing.T, length int64, pad bool) *inlineBlobType {
|
||||
require.True(t, length > 0)
|
||||
if pad {
|
||||
t, err := gmstypes.CreateBinary(sqltypes.Binary, length)
|
||||
if err == nil {
|
||||
return &inlineBlobType{t}
|
||||
}
|
||||
}
|
||||
return &inlineBlobType{gmstypes.MustCreateBinary(sqltypes.VarBinary, length)}
|
||||
}
|
||||
|
||||
func generateVarStringTypes(t *testing.T, numOfTypes uint16) []TypeInfo {
|
||||
var res []TypeInfo
|
||||
loop(t, 1, 500, numOfTypes, func(i int64) {
|
||||
|
||||
@@ -17,7 +17,6 @@ package typeinfo
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
@@ -50,39 +49,6 @@ func CreateDatetimeTypeFromSqlType(typ sql.DatetimeType) *datetimeType {
|
||||
return &datetimeType{typ}
|
||||
}
|
||||
|
||||
func CreateDatetimeTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
if sqlType, ok := params[datetimeTypeParam_SQL]; ok {
|
||||
precision := 6
|
||||
if precisionParam, ok := params[datetimeTypeParam_Precision]; ok {
|
||||
var err error
|
||||
precision, err = strconv.Atoi(precisionParam)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
switch sqlType {
|
||||
case datetimeTypeParam_SQL_Date:
|
||||
return DateType, nil
|
||||
case datetimeTypeParam_SQL_Datetime:
|
||||
gmsType, err := gmstypes.CreateDatetimeType(sqltypes.Datetime, precision)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return CreateDatetimeTypeFromSqlType(gmsType), nil
|
||||
case datetimeTypeParam_SQL_Timestamp:
|
||||
gmsType, err := gmstypes.CreateDatetimeType(sqltypes.Timestamp, precision)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return CreateDatetimeTypeFromSqlType(gmsType), nil
|
||||
default:
|
||||
return nil, fmt.Errorf(`create datetime type info has invalid param "%v"`, sqlType)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf(`create datetime type info is missing param "%v"`, datetimeTypeParam_SQL)
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertNomsValueToValue implements TypeInfo interface.
|
||||
func (ti *datetimeType) ConvertNomsValueToValue(v types.Value) (interface{}, error) {
|
||||
if val, ok := v.(types.Timestamp); ok {
|
||||
@@ -121,7 +87,7 @@ func (ti *datetimeType) ReadFrom(_ *types.NomsBinFormat, reader types.CodecReade
|
||||
|
||||
// ConvertValueToNomsValue implements TypeInfo interface.
|
||||
func (ti *datetimeType) ConvertValueToNomsValue(ctx context.Context, vrw types.ValueReadWriter, v interface{}) (types.Value, error) {
|
||||
//TODO: handle the zero value as a special case that is valid for all ranges
|
||||
// TODO: handle the zero value as a special case that is valid for all ranges
|
||||
if v == nil {
|
||||
return types.NullValue, nil
|
||||
}
|
||||
|
||||
@@ -16,12 +16,10 @@ package typeinfo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
@@ -39,34 +37,6 @@ type enumType struct {
|
||||
|
||||
var _ TypeInfo = (*enumType)(nil)
|
||||
|
||||
func CreateEnumTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
var collation sql.CollationID
|
||||
var err error
|
||||
if collationStr, ok := params[enumTypeParam_Collation]; ok {
|
||||
collation, err = sql.ParseCollation("", collationStr, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf(`create enum type info is missing param "%v"`, enumTypeParam_Collation)
|
||||
}
|
||||
var values []string
|
||||
if valuesStr, ok := params[enumTypeParam_Values]; ok {
|
||||
dec := gob.NewDecoder(strings.NewReader(valuesStr))
|
||||
err = dec.Decode(&values)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf(`create enum type info is missing param "%v"`, enumTypeParam_Values)
|
||||
}
|
||||
sqlEnumType, err := gmstypes.CreateEnumType(values, collation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return CreateEnumTypeFromSqlEnumType(sqlEnumType), nil
|
||||
}
|
||||
|
||||
func CreateEnumTypeFromSqlEnumType(sqlEnumType sql.EnumType) TypeInfo {
|
||||
return &enumType{sqlEnumType}
|
||||
}
|
||||
|
||||
@@ -19,15 +19,10 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
const (
|
||||
extendedTypeParams_string_encoded = "string_encoded"
|
||||
)
|
||||
|
||||
// extendedType is a type that refers to an ExtendedType in GMS. These are only supported in the new format, and have many
|
||||
// more limitations than traditional types (for now).
|
||||
type extendedType struct {
|
||||
@@ -36,18 +31,6 @@ type extendedType struct {
|
||||
|
||||
var _ TypeInfo = (*extendedType)(nil)
|
||||
|
||||
// CreateExtendedTypeFromParams creates a TypeInfo from the given parameter map.
|
||||
func CreateExtendedTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
if encodedString, ok := params[extendedTypeParams_string_encoded]; ok {
|
||||
t, err := gmstypes.DeserializeTypeFromString(encodedString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &extendedType{t}, nil
|
||||
}
|
||||
return nil, fmt.Errorf(`create extended type info is missing "%v" param`, extendedTypeParams_string_encoded)
|
||||
}
|
||||
|
||||
// CreateExtendedTypeFromSqlType creates a TypeInfo from the given extended type.
|
||||
func CreateExtendedTypeFromSqlType(typ sql.ExtendedType) TypeInfo {
|
||||
return &extendedType{typ}
|
||||
|
||||
@@ -30,12 +30,6 @@ import (
|
||||
|
||||
type FloatWidth int8
|
||||
|
||||
const (
|
||||
floatTypeParam_Width = "width"
|
||||
floatTypeParam_Width_32 = "32"
|
||||
floatTypeParam_Width_64 = "64"
|
||||
)
|
||||
|
||||
type floatType struct {
|
||||
sqlFloatType sql.NumberType
|
||||
}
|
||||
@@ -46,20 +40,6 @@ var (
|
||||
Float64Type = &floatType{gmstypes.Float64}
|
||||
)
|
||||
|
||||
func CreateFloatTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
if width, ok := params[floatTypeParam_Width]; ok {
|
||||
switch width {
|
||||
case floatTypeParam_Width_32:
|
||||
return Float32Type, nil
|
||||
case floatTypeParam_Width_64:
|
||||
return Float64Type, nil
|
||||
default:
|
||||
return nil, fmt.Errorf(`create float type info has "%v" param with value "%v"`, floatTypeParam_Width, width)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf(`create float type info is missing "%v" param`, floatTypeParam_Width)
|
||||
}
|
||||
|
||||
// ConvertNomsValueToValue implements TypeInfo interface.
|
||||
func (ti *floatType) ConvertNomsValueToValue(v types.Value) (interface{}, error) {
|
||||
if val, ok := v.(types.Float); ok {
|
||||
|
||||
@@ -17,7 +17,6 @@ package typeinfo
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
@@ -289,28 +288,6 @@ func geometryTypeConverter(ctx context.Context, src *geometryType, destTi TypeIn
|
||||
}
|
||||
}
|
||||
|
||||
func CreateGeometryTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
var (
|
||||
err error
|
||||
sridVal uint64
|
||||
def bool
|
||||
)
|
||||
if s, ok := params["SRID"]; ok {
|
||||
sridVal, err = strconv.ParseUint(s, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if d, ok := params["DefinedSRID"]; ok {
|
||||
def, err = strconv.ParseBool(d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return CreateGeometryTypeFromSqlGeometryType(gmstypes.GeometryType{SRID: uint32(sridVal), DefinedSRID: def}), nil
|
||||
}
|
||||
|
||||
func CreateGeometryTypeFromSqlGeometryType(sqlGeometryType gmstypes.GeometryType) TypeInfo {
|
||||
return &geometryType{sqlGeometryType: sqlGeometryType}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ package typeinfo
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
@@ -197,25 +196,3 @@ func geomcollTypeConverter(ctx context.Context, src *geomcollType, destTi TypeIn
|
||||
return nil, false, UnhandledTypeConversion.New(src.String(), destTi.String())
|
||||
}
|
||||
}
|
||||
|
||||
func CreateGeomCollTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
var (
|
||||
err error
|
||||
sridVal uint64
|
||||
def bool
|
||||
)
|
||||
if s, ok := params["SRID"]; ok {
|
||||
sridVal, err = strconv.ParseUint(s, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if d, ok := params["DefinedSRID"]; ok {
|
||||
def, err = strconv.ParseBool(d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &geomcollType{sqlGeomCollType: gmstypes.GeomCollType{SRID: uint32(sridVal), DefinedSRID: def}}, nil
|
||||
}
|
||||
|
||||
@@ -22,19 +22,11 @@ import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
"github.com/dolthub/vitess/go/sqltypes"
|
||||
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
const (
|
||||
inlineBlobTypeParam_Length = "length"
|
||||
inlineBlobTypeParam_SQL = "sql"
|
||||
inlineBlobTypeParam_SQL_Binary = "bin"
|
||||
inlineBlobTypeParam_SQL_VarBinary = "varbin"
|
||||
)
|
||||
|
||||
// inlineBlobType handles BINARY and VARBINARY. BLOB types are handled by varBinaryType.
|
||||
type inlineBlobType struct {
|
||||
sqlBinaryType sql.StringType
|
||||
@@ -42,40 +34,6 @@ type inlineBlobType struct {
|
||||
|
||||
var _ TypeInfo = (*inlineBlobType)(nil)
|
||||
|
||||
var (
|
||||
VarbinaryDefaultType = &inlineBlobType{gmstypes.MustCreateBinary(sqltypes.VarBinary, 16383)}
|
||||
)
|
||||
|
||||
func CreateInlineBlobTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
var length int64
|
||||
var err error
|
||||
if lengthStr, ok := params[inlineBlobTypeParam_Length]; ok {
|
||||
length, err = strconv.ParseInt(lengthStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf(`create inlineblob type info is missing param "%v"`, inlineBlobTypeParam_Length)
|
||||
}
|
||||
if sqlStr, ok := params[inlineBlobTypeParam_SQL]; ok {
|
||||
var sqlType sql.StringType
|
||||
switch sqlStr {
|
||||
case inlineBlobTypeParam_SQL_Binary:
|
||||
sqlType, err = gmstypes.CreateBinary(sqltypes.Binary, length)
|
||||
case inlineBlobTypeParam_SQL_VarBinary:
|
||||
sqlType, err = gmstypes.CreateBinary(sqltypes.VarBinary, length)
|
||||
default:
|
||||
return nil, fmt.Errorf(`create inlineblob type info has "%v" param with value "%v"`, inlineBlobTypeParam_SQL, sqlStr)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &inlineBlobType{sqlType}, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf(`create inlineblob type info is missing param "%v"`, inlineBlobTypeParam_SQL)
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertNomsValueToValue implements TypeInfo interface.
|
||||
func (ti *inlineBlobType) ConvertNomsValueToValue(v types.Value) (interface{}, error) {
|
||||
if val, ok := v.(types.InlineBlob); ok {
|
||||
|
||||
@@ -26,15 +26,6 @@ import (
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
const (
|
||||
intTypeParams_Width = "width"
|
||||
intTypeParams_Width_8 = "8"
|
||||
intTypeParams_Width_16 = "16"
|
||||
intTypeParams_Width_24 = "24"
|
||||
intTypeParams_Width_32 = "32"
|
||||
intTypeParams_Width_64 = "64"
|
||||
)
|
||||
|
||||
type intType struct {
|
||||
sqlIntType sql.NumberType
|
||||
}
|
||||
@@ -48,26 +39,6 @@ var (
|
||||
Int64Type = &intType{gmstypes.Int64}
|
||||
)
|
||||
|
||||
func CreateIntTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
if width, ok := params[intTypeParams_Width]; ok {
|
||||
switch width {
|
||||
case intTypeParams_Width_8:
|
||||
return Int8Type, nil
|
||||
case intTypeParams_Width_16:
|
||||
return Int16Type, nil
|
||||
case intTypeParams_Width_24:
|
||||
return Int24Type, nil
|
||||
case intTypeParams_Width_32:
|
||||
return Int32Type, nil
|
||||
case intTypeParams_Width_64:
|
||||
return Int64Type, nil
|
||||
default:
|
||||
return nil, fmt.Errorf(`create int type info has "%v" param with value "%v"`, intTypeParams_Width, width)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf(`create int type info is missing "%v" param`, intTypeParams_Width)
|
||||
}
|
||||
|
||||
// ConvertNomsValueToValue implements TypeInfo interface.
|
||||
func (ti *intType) ConvertNomsValueToValue(v types.Value) (interface{}, error) {
|
||||
if val, ok := v.(types.Int); ok {
|
||||
|
||||
@@ -17,7 +17,6 @@ package typeinfo
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
@@ -197,24 +196,3 @@ func linestringTypeConverter(ctx context.Context, src *linestringType, destTi Ty
|
||||
return nil, false, UnhandledTypeConversion.New(src.String(), destTi.String())
|
||||
}
|
||||
}
|
||||
|
||||
func CreateLineStringTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
var (
|
||||
err error
|
||||
sridVal uint64
|
||||
def bool
|
||||
)
|
||||
if s, ok := params["SRID"]; ok {
|
||||
sridVal, err = strconv.ParseUint(s, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if d, ok := params["DefinedSRID"]; ok {
|
||||
def, err = strconv.ParseBool(d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &linestringType{sqlLineStringType: gmstypes.LineStringType{SRID: uint32(sridVal), DefinedSRID: def}}, nil
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ package typeinfo
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
@@ -197,25 +196,3 @@ func multilinestringTypeConverter(ctx context.Context, src *multilinestringType,
|
||||
return nil, false, UnhandledTypeConversion.New(src.String(), destTi.String())
|
||||
}
|
||||
}
|
||||
|
||||
func CreateMultiLineStringTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
var (
|
||||
err error
|
||||
sridVal uint64
|
||||
def bool
|
||||
)
|
||||
if s, ok := params["SRID"]; ok {
|
||||
sridVal, err = strconv.ParseUint(s, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if d, ok := params["DefinedSRID"]; ok {
|
||||
def, err = strconv.ParseBool(d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &multilinestringType{sqlMultiLineStringType: gmstypes.MultiLineStringType{SRID: uint32(sridVal), DefinedSRID: def}}, nil
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ package typeinfo
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
@@ -197,24 +196,3 @@ func multipointTypeConverter(ctx context.Context, src *multipointType, destTi Ty
|
||||
return nil, false, UnhandledTypeConversion.New(src.String(), destTi.String())
|
||||
}
|
||||
}
|
||||
|
||||
func CreateMultiPointTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
var (
|
||||
err error
|
||||
sridVal uint64
|
||||
def bool
|
||||
)
|
||||
if s, ok := params["SRID"]; ok {
|
||||
sridVal, err = strconv.ParseUint(s, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if d, ok := params["DefinedSRID"]; ok {
|
||||
def, err = strconv.ParseBool(d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &multipointType{sqlMultiPointType: gmstypes.MultiPointType{SRID: uint32(sridVal), DefinedSRID: def}}, nil
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ package typeinfo
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
@@ -197,25 +196,3 @@ func multipolygonTypeConverter(ctx context.Context, src *multipolygonType, destT
|
||||
return nil, false, UnhandledTypeConversion.New(src.String(), destTi.String())
|
||||
}
|
||||
}
|
||||
|
||||
func CreateMultiPolygonTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
var (
|
||||
err error
|
||||
sridVal uint64
|
||||
def bool
|
||||
)
|
||||
if s, ok := params["SRID"]; ok {
|
||||
sridVal, err = strconv.ParseUint(s, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if d, ok := params["DefinedSRID"]; ok {
|
||||
def, err = strconv.ParseBool(d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &multipolygonType{sqlMultiPolygonType: gmstypes.MultiPolygonType{SRID: uint32(sridVal), DefinedSRID: def}}, nil
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ package typeinfo
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
@@ -198,28 +197,6 @@ func pointTypeConverter(ctx context.Context, src *pointType, destTi TypeInfo) (t
|
||||
}
|
||||
}
|
||||
|
||||
func CreatePointTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
var (
|
||||
err error
|
||||
sridVal uint64
|
||||
def bool
|
||||
)
|
||||
if s, ok := params["SRID"]; ok {
|
||||
sridVal, err = strconv.ParseUint(s, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if d, ok := params["DefinedSRID"]; ok {
|
||||
def, err = strconv.ParseBool(d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return CreatePointTypeFromSqlPointType(gmstypes.PointType{SRID: uint32(sridVal), DefinedSRID: def}), nil
|
||||
}
|
||||
|
||||
func CreatePointTypeFromSqlPointType(sqlPointType gmstypes.PointType) TypeInfo {
|
||||
return &pointType{sqlPointType: sqlPointType}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ package typeinfo
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
@@ -197,25 +196,3 @@ func polygonTypeConverter(ctx context.Context, src *polygonType, destTi TypeInfo
|
||||
return nil, false, UnhandledTypeConversion.New(src.String(), destTi.String())
|
||||
}
|
||||
}
|
||||
|
||||
func CreatePolygonTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
var (
|
||||
err error
|
||||
sridVal uint64
|
||||
def bool
|
||||
)
|
||||
if s, ok := params["SRID"]; ok {
|
||||
sridVal, err = strconv.ParseUint(s, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if d, ok := params["DefinedSRID"]; ok {
|
||||
def, err = strconv.ParseBool(d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &polygonType{sqlPolygonType: gmstypes.PolygonType{SRID: uint32(sridVal), DefinedSRID: def}}, nil
|
||||
}
|
||||
|
||||
@@ -16,21 +16,14 @@ package typeinfo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
const (
|
||||
setTypeParam_Collation = "collate"
|
||||
setTypeParam_Values = "vals"
|
||||
)
|
||||
|
||||
// This is a dolt implementation of the MySQL type Set, thus most of the functionality
|
||||
// within is directly reliant on the go-mysql-server implementation.
|
||||
type setType struct {
|
||||
@@ -39,33 +32,6 @@ type setType struct {
|
||||
|
||||
var _ TypeInfo = (*setType)(nil)
|
||||
|
||||
func CreateSetTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
collationStr, ok := params[setTypeParam_Collation]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`create set type info is missing param "%v"`, setTypeParam_Collation)
|
||||
}
|
||||
collation, err := sql.ParseCollation("", collationStr, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
valuesStr, ok := params[setTypeParam_Values]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`create set type info is missing param "%v"`, setTypeParam_Values)
|
||||
}
|
||||
var values []string
|
||||
dec := gob.NewDecoder(strings.NewReader(valuesStr))
|
||||
if err = dec.Decode(&values); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sqlSetType, err := gmstypes.CreateSetType(values, collation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return CreateSetTypeFromSqlSetType(sqlSetType), nil
|
||||
}
|
||||
|
||||
func CreateSetTypeFromSqlSetType(sqlSetType sql.SetType) TypeInfo {
|
||||
return &setType{sqlSetType}
|
||||
}
|
||||
|
||||
@@ -26,15 +26,6 @@ import (
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
const (
|
||||
uintTypeParam_Width = "width"
|
||||
uintTypeParam_Width_8 = "8"
|
||||
uintTypeParam_Width_16 = "16"
|
||||
uintTypeParam_Width_24 = "24"
|
||||
uintTypeParam_Width_32 = "32"
|
||||
uintTypeParam_Width_64 = "64"
|
||||
)
|
||||
|
||||
type uintType struct {
|
||||
sqlUintType sql.NumberType
|
||||
}
|
||||
@@ -48,26 +39,6 @@ var (
|
||||
Uint64Type = &uintType{gmstypes.Uint64}
|
||||
)
|
||||
|
||||
func CreateUintTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
if width, ok := params[uintTypeParam_Width]; ok {
|
||||
switch width {
|
||||
case uintTypeParam_Width_8:
|
||||
return Uint8Type, nil
|
||||
case uintTypeParam_Width_16:
|
||||
return Uint16Type, nil
|
||||
case uintTypeParam_Width_24:
|
||||
return Uint24Type, nil
|
||||
case uintTypeParam_Width_32:
|
||||
return Uint32Type, nil
|
||||
case uintTypeParam_Width_64:
|
||||
return Uint64Type, nil
|
||||
default:
|
||||
return nil, fmt.Errorf(`create uint type info has "%v" param with value "%v"`, uintTypeParam_Width, width)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf(`create uint type info is missing "%v" param`, uintTypeParam_Width)
|
||||
}
|
||||
|
||||
// ConvertNomsValueToValue implements TypeInfo interface.
|
||||
func (ti *uintType) ConvertNomsValueToValue(v types.Value) (interface{}, error) {
|
||||
if val, ok := v.(types.Uint); ok {
|
||||
|
||||
@@ -16,7 +16,6 @@ package typeinfo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
@@ -24,16 +23,10 @@ import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
"github.com/dolthub/vitess/go/sqltypes"
|
||||
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
const (
|
||||
varBinaryTypeParam_Length = "length"
|
||||
)
|
||||
|
||||
// As a type, this is modeled more after MySQL's story for binary data. There, it's treated
|
||||
// as a string that is interpreted as raw bytes, rather than as a bespoke data structure,
|
||||
// and thus this is mirrored here in its implementation. This will minimize any differences
|
||||
@@ -46,31 +39,6 @@ type varBinaryType struct {
|
||||
|
||||
var _ TypeInfo = (*varBinaryType)(nil)
|
||||
|
||||
var (
|
||||
TinyBlobType TypeInfo = &varBinaryType{sqlBinaryType: gmstypes.TinyBlob}
|
||||
BlobType TypeInfo = &varBinaryType{sqlBinaryType: gmstypes.Blob}
|
||||
MediumBlobType TypeInfo = &varBinaryType{sqlBinaryType: gmstypes.MediumBlob}
|
||||
LongBlobType TypeInfo = &varBinaryType{sqlBinaryType: gmstypes.LongBlob}
|
||||
)
|
||||
|
||||
func CreateVarBinaryTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
var length int64
|
||||
var err error
|
||||
if lengthStr, ok := params[varBinaryTypeParam_Length]; ok {
|
||||
length, err = strconv.ParseInt(lengthStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf(`create varbinary type info is missing param "%v"`, varBinaryTypeParam_Length)
|
||||
}
|
||||
sqlType, err := gmstypes.CreateBinary(sqltypes.Blob, length)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &varBinaryType{sqlType}, nil
|
||||
}
|
||||
|
||||
// ConvertNomsValueToValue implements TypeInfo interface.
|
||||
func (ti *varBinaryType) ConvertNomsValueToValue(v types.Value) (interface{}, error) {
|
||||
if val, ok := v.(types.Blob); ok {
|
||||
@@ -199,29 +167,6 @@ func fromBlob(b types.Blob) ([]byte, error) {
|
||||
return str, nil
|
||||
}
|
||||
|
||||
// hasPrefix finds out if a Blob has a prefixed integer. Initially blobs for varBinary prepended an integer indicating
|
||||
// the length, which was unnecessary (as the underlying sequence tracks the total size). It's been removed, but this
|
||||
// may be used to see if a Blob is one of those older Blobs. A false positive is possible, but EXTREMELY unlikely.
|
||||
func hasPrefix(b types.Blob, ctx context.Context) (bool, error) {
|
||||
blobLength := b.Len()
|
||||
if blobLength < 8 {
|
||||
return false, nil
|
||||
}
|
||||
countBytes := make([]byte, 8)
|
||||
n, err := b.ReadAt(ctx, countBytes, 0)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if n != 8 {
|
||||
return false, fmt.Errorf("wanted 8 bytes from blob for count, got %d", n)
|
||||
}
|
||||
prefixedLength := binary.LittleEndian.Uint64(countBytes)
|
||||
if prefixedLength == blobLength-8 {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// varBinaryTypeConverter is an internal function for GetTypeConverter that handles the specific type as the source TypeInfo.
|
||||
func varBinaryTypeConverter(ctx context.Context, src *varBinaryType, destTi TypeInfo) (tc TypeConverter, needsConversion bool, err error) {
|
||||
switch dest := destTi.(type) {
|
||||
|
||||
@@ -28,15 +28,6 @@ import (
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
const (
|
||||
varStringTypeParam_Collate = "collate"
|
||||
varStringTypeParam_Length = "length"
|
||||
varStringTypeParam_SQL = "sql"
|
||||
varStringTypeParam_SQL_Char = "char"
|
||||
varStringTypeParam_SQL_VarChar = "varchar"
|
||||
varStringTypeParam_SQL_Text = "text"
|
||||
)
|
||||
|
||||
// varStringType handles CHAR and VARCHAR. The TEXT types are handled by blobStringType. For any repositories that were
|
||||
// created before the introduction of blobStringType, they will use varStringType for TEXT types. As varStringType makes
|
||||
// use of the String Value type, it does not actually support all viable lengths of a TEXT string, meaning all such
|
||||
@@ -60,48 +51,6 @@ func CreateVarStringTypeFromSqlType(stringType sql.StringType) TypeInfo {
|
||||
return &varStringType{stringType}
|
||||
}
|
||||
|
||||
func CreateVarStringTypeFromParams(params map[string]string) (TypeInfo, error) {
|
||||
var length int64
|
||||
var collation sql.CollationID
|
||||
var err error
|
||||
if collationStr, ok := params[varStringTypeParam_Collate]; ok {
|
||||
collation, err = sql.ParseCollation("", collationStr, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf(`create varstring type info is missing param "%v"`, varStringTypeParam_Collate)
|
||||
}
|
||||
if maxLengthStr, ok := params[varStringTypeParam_Length]; ok {
|
||||
length, err = strconv.ParseInt(maxLengthStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
} else {
|
||||
return nil, fmt.Errorf(`create varstring type info is missing param "%v"`, varStringTypeParam_Length)
|
||||
}
|
||||
if sqlStr, ok := params[varStringTypeParam_SQL]; ok {
|
||||
var sqlType sql.StringType
|
||||
switch sqlStr {
|
||||
case varStringTypeParam_SQL_Char:
|
||||
sqlType, err = gmstypes.CreateString(sqltypes.Char, length, collation)
|
||||
case varStringTypeParam_SQL_VarChar:
|
||||
sqlType, err = gmstypes.CreateString(sqltypes.VarChar, length, collation)
|
||||
case varStringTypeParam_SQL_Text:
|
||||
sqlType, err = gmstypes.CreateString(sqltypes.Text, length, collation)
|
||||
default:
|
||||
return nil, fmt.Errorf(`create varstring type info has "%v" param with value "%v"`, varStringTypeParam_SQL, sqlStr)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &varStringType{sqlType}, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf(`create varstring type info is missing param "%v"`, varStringTypeParam_Length)
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertNomsValueToValue implements TypeInfo interface.
|
||||
func (ti *varStringType) ConvertNomsValueToValue(v types.Value) (interface{}, error) {
|
||||
if val, ok := v.(types.String); ok {
|
||||
|
||||
@@ -434,12 +434,8 @@ func TestDropPks(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dEnv := dtestutils.CreateTestEnv()
|
||||
defer dEnv.DoltDB(ctx).Close()
|
||||
tmpDir, err := dEnv.TempTableFilesDir()
|
||||
require.NoError(t, err)
|
||||
deaf, err := dEnv.DbEaFactory(ctx)
|
||||
require.NoError(t, err)
|
||||
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
|
||||
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(ctx), opts)
|
||||
|
||||
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(ctx), editor.Options{})
|
||||
require.NoError(t, err)
|
||||
|
||||
root, _ := dEnv.WorkingRoot(ctx)
|
||||
|
||||
@@ -841,9 +841,7 @@ func getTableWriter(ctx *sql.Context, engine *gms.Engine, tableName, databaseNam
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
options := sqlDatabase.EditOptions()
|
||||
options.ForeignKeyChecksDisabled = foreignKeyChecksDisabled
|
||||
writeSession := writer.NewWriteSession(binFormat, ws, tracker, options)
|
||||
writeSession := writer.NewWriteSession(binFormat, ws, tracker, sqlDatabase.EditOptions())
|
||||
|
||||
ds := dsess.DSessFromSess(ctx.Session)
|
||||
setter := ds.SetWorkingRoot
|
||||
|
||||
@@ -39,12 +39,7 @@ type SetupFn func(t *testing.T, dEnv *env.DoltEnv)
|
||||
// Runs the query given and returns the result. The schema result of the query's execution is currently ignored, and
|
||||
// the targetSchema given is used to prepare all rows.
|
||||
func executeSelect(t *testing.T, ctx context.Context, dEnv *env.DoltEnv, root doltdb.RootValue, query string) ([]sql.Row, sql.Schema, error) {
|
||||
tmpDir, err := dEnv.TempTableFilesDir()
|
||||
require.NoError(t, err)
|
||||
deaf, err := dEnv.DbEaFactory(ctx)
|
||||
require.NoError(t, err)
|
||||
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
|
||||
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(ctx), opts)
|
||||
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(ctx), editor.Options{})
|
||||
require.NoError(t, err)
|
||||
|
||||
engine, sqlCtx, err := NewTestEngine(dEnv, ctx, db)
|
||||
@@ -72,12 +67,7 @@ func executeSelect(t *testing.T, ctx context.Context, dEnv *env.DoltEnv, root do
|
||||
|
||||
// Runs the query given and returns the error (if any).
|
||||
func executeModify(t *testing.T, ctx context.Context, dEnv *env.DoltEnv, root doltdb.RootValue, query string) (doltdb.RootValue, error) {
|
||||
tmpDir, err := dEnv.TempTableFilesDir()
|
||||
require.NoError(t, err)
|
||||
deaf, err := dEnv.DbEaFactory(ctx)
|
||||
require.NoError(t, err)
|
||||
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
|
||||
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(ctx), opts)
|
||||
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(ctx), editor.Options{})
|
||||
require.NoError(t, err)
|
||||
|
||||
engine, sqlCtx, err := NewTestEngine(dEnv, ctx, db)
|
||||
|
||||
@@ -1955,7 +1955,10 @@ func (db Database) CreateTable(ctx *sql.Context, tableName string, sch sql.Prima
|
||||
return err
|
||||
}
|
||||
|
||||
if doltdb.IsSystemTable(doltdb.TableName{Name: tableName, Schema: db.schemaName}) && !doltdb.IsFullTextTable(tableName) && !doltdb.HasDoltCIPrefix(tableName) {
|
||||
if doltdb.IsSystemTable(doltdb.TableName{Name: tableName, Schema: db.schemaName}) &&
|
||||
!doltdb.IsFullTextTable(tableName) &&
|
||||
!doltdb.HasDoltCIPrefix(tableName) &&
|
||||
tableName != doltdb.TestsTableName { // NM4 - determine why this is required now.
|
||||
return ErrReservedTableName.New(tableName)
|
||||
}
|
||||
|
||||
|
||||
@@ -990,23 +990,7 @@ func (p *DoltDatabaseProvider) registerNewDatabase(ctx *sql.Context, name string
|
||||
// Ensure any provider-supplied DB load params are applied before any lazy DB load occurs.
|
||||
p.applyDBLoadParamsToEnv(newEnv)
|
||||
|
||||
fkChecks, err := ctx.GetSessionVariable(ctx, "foreign_key_checks")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
deaf, err := newEnv.DbEaFactory(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
opts := editor.Options{
|
||||
Deaf: deaf,
|
||||
// TODO: this doesn't seem right, why is this getting set in the constructor to the DB
|
||||
ForeignKeyChecksDisabled: fkChecks.(int8) == 0,
|
||||
}
|
||||
|
||||
db, err := NewDatabase(ctx, name, newEnv.DbData(ctx), opts)
|
||||
db, err := NewDatabase(ctx, name, newEnv.DbData(ctx), editor.Options{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -44,12 +44,8 @@ func TestDatabaseProvider(t *testing.T) {
|
||||
setup := func(t *testing.T) (*sqle.Engine, *sql.Context, *DoltDatabaseProvider) {
|
||||
ctx := context.Background()
|
||||
dEnv := dtestutils.CreateTestEnv()
|
||||
tmpDir, err := dEnv.TempTableFilesDir()
|
||||
require.NoError(t, err)
|
||||
deaf, err := dEnv.DbEaFactory(ctx)
|
||||
require.NoError(t, err)
|
||||
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
|
||||
db, err := NewDatabase(context.Background(), "dolt", dEnv.DbData(ctx), opts)
|
||||
|
||||
db, err := NewDatabase(context.Background(), "dolt", dEnv.DbData(ctx), editor.Options{})
|
||||
require.NoError(t, err)
|
||||
|
||||
engine, sqlCtx, err := NewTestEngine(dEnv, context.Background(), db)
|
||||
|
||||
@@ -45,15 +45,9 @@ func TestIsKeyFuncs(t *testing.T) {
|
||||
func TestNeedsToReloadEvents(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dEnv := dtestutils.CreateTestEnv()
|
||||
tmpDir, err := dEnv.TempTableFilesDir()
|
||||
require.NoError(t, err)
|
||||
deaf, err := dEnv.DbEaFactory(ctx)
|
||||
require.NoError(t, err)
|
||||
opts := editor.Options{Deaf: deaf, Tempdir: tmpDir}
|
||||
|
||||
timestamp := time.Now().Truncate(time.Minute).UTC()
|
||||
|
||||
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(ctx), opts)
|
||||
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(ctx), editor.Options{})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, sqlCtx, err := NewTestEngine(dEnv, ctx, db)
|
||||
|
||||
@@ -103,6 +103,8 @@ func doDoltCherryPick(ctx *sql.Context, args []string) (string, int, int, int, e
|
||||
cherryPickOptions.EmptyCommitHandling = doltdb.KeepEmptyCommit
|
||||
}
|
||||
|
||||
cherryPickOptions.SkipVerification = apr.Contains(cli.SkipVerificationFlag)
|
||||
|
||||
commit, mergeResult, err := cherry_pick.CherryPick(ctx, cherryStr, cherryPickOptions)
|
||||
if err != nil {
|
||||
return "", 0, 0, 0, err
|
||||
|
||||
@@ -57,7 +57,8 @@ func doDoltClean(ctx *sql.Context, args []string) (int, error) {
|
||||
return 1, fmt.Errorf("Could not load database %s", dbName)
|
||||
}
|
||||
|
||||
roots, err = actions.CleanUntracked(ctx, roots, apr.Args, apr.ContainsAll(cli.DryRunFlag), false)
|
||||
respectIgnoreRules := !apr.Contains(cli.ExcludeIgnoreRulesFlag)
|
||||
roots, err = actions.CleanUntracked(ctx, roots, apr.Args, apr.ContainsAll(cli.DryRunFlag), false, respectIgnoreRules)
|
||||
if err != nil {
|
||||
return 1, fmt.Errorf("failed to clean; %w", err)
|
||||
}
|
||||
|
||||
@@ -113,6 +113,13 @@ func getDirectoryAndUrlString(apr *argparser.ArgParseResults) (string, string, e
|
||||
} else if dir == "/" {
|
||||
return "", "", errhand.BuildDError("Could not infer repo name. Please explicitly define a directory for this url").Build()
|
||||
}
|
||||
// Match `dolt clone` behavior: strip a trailing `.git` from inferred names.
|
||||
if strings.HasSuffix(dir, ".git") {
|
||||
dir = strings.TrimSuffix(dir, ".git")
|
||||
if dir == "" {
|
||||
return "", "", errhand.BuildDError("Could not infer repo name. Please explicitly define a directory for this url").Build()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return dir, urlStr, nil
|
||||
|
||||
@@ -163,14 +163,15 @@ func doDoltCommit(ctx *sql.Context, args []string) (string, bool, error) {
|
||||
}
|
||||
|
||||
csp := actions.CommitStagedProps{
|
||||
Message: msg,
|
||||
Date: t,
|
||||
AllowEmpty: apr.Contains(cli.AllowEmptyFlag),
|
||||
SkipEmpty: apr.Contains(cli.SkipEmptyFlag),
|
||||
Amend: amend,
|
||||
Force: apr.Contains(cli.ForceFlag),
|
||||
Name: name,
|
||||
Email: email,
|
||||
Message: msg,
|
||||
Date: t,
|
||||
AllowEmpty: apr.Contains(cli.AllowEmptyFlag),
|
||||
SkipEmpty: apr.Contains(cli.SkipEmptyFlag),
|
||||
Amend: amend,
|
||||
Force: apr.Contains(cli.ForceFlag),
|
||||
Name: name,
|
||||
Email: email,
|
||||
SkipVerification: apr.Contains(cli.SkipVerificationFlag),
|
||||
}
|
||||
|
||||
shouldSign, err := dsess.GetBooleanSystemVar(ctx, "gpgsign")
|
||||
|
||||
@@ -180,7 +180,7 @@ func doDoltMerge(ctx *sql.Context, args []string) (string, int, int, string, err
|
||||
msg = userMsg
|
||||
}
|
||||
|
||||
ws, commit, conflicts, fastForward, message, err := performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg)
|
||||
ws, commit, conflicts, fastForward, message, err := performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg, apr.Contains(cli.SkipVerificationFlag))
|
||||
if err != nil {
|
||||
return commit, conflicts, fastForward, "", err
|
||||
}
|
||||
@@ -205,6 +205,7 @@ func performMerge(
|
||||
spec *merge.MergeSpec,
|
||||
noCommit bool,
|
||||
msg string,
|
||||
skipVerification bool,
|
||||
) (*doltdb.WorkingSet, string, int, int, string, error) {
|
||||
// todo: allow merges even when an existing merge is uncommitted
|
||||
if ws.MergeActive() {
|
||||
@@ -234,7 +235,7 @@ func performMerge(
|
||||
if canFF {
|
||||
if spec.FFMode == merge.NoFastForward {
|
||||
var commit *doltdb.Commit
|
||||
ws, commit, err = executeNoFFMerge(ctx, sess, spec, msg, dbName, ws, noCommit)
|
||||
ws, commit, err = executeNoFFMerge(ctx, sess, spec, msg, dbName, ws, noCommit, skipVerification)
|
||||
if err == doltdb.ErrUnresolvedConflictsOrViolations {
|
||||
// if there are unresolved conflicts, write the resulting working set back to the session and return an
|
||||
// error message
|
||||
@@ -306,7 +307,10 @@ func performMerge(
|
||||
author := fmt.Sprintf("%s <%s>", spec.Name, spec.Email)
|
||||
args := []string{"-m", msg, "--author", author}
|
||||
if spec.Force {
|
||||
args = append(args, "--force")
|
||||
args = append(args, "--"+cli.ForceFlag)
|
||||
}
|
||||
if skipVerification {
|
||||
args = append(args, "--"+cli.SkipVerificationFlag)
|
||||
}
|
||||
commit, _, err = doDoltCommit(ctx, args)
|
||||
if err != nil {
|
||||
@@ -405,6 +409,7 @@ func executeNoFFMerge(
|
||||
dbName string,
|
||||
ws *doltdb.WorkingSet,
|
||||
noCommit bool,
|
||||
skipVerification bool,
|
||||
) (*doltdb.WorkingSet, *doltdb.Commit, error) {
|
||||
mergeRoot, err := spec.MergeC.GetRootValue(ctx)
|
||||
if err != nil {
|
||||
@@ -444,11 +449,12 @@ func executeNoFFMerge(
|
||||
}
|
||||
|
||||
pendingCommit, err := dSess.NewPendingCommit(ctx, dbName, roots, actions.CommitStagedProps{
|
||||
Message: msg,
|
||||
Date: spec.Date,
|
||||
Force: spec.Force,
|
||||
Name: spec.Name,
|
||||
Email: spec.Email,
|
||||
Message: msg,
|
||||
Date: spec.Date,
|
||||
Force: spec.Force,
|
||||
Name: spec.Name,
|
||||
Email: spec.Email,
|
||||
SkipVerification: skipVerification,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -237,7 +237,7 @@ func doDoltPull(ctx *sql.Context, args []string) (int, int, string, error) {
|
||||
return noConflictsOrViolations, threeWayMerge, "", ErrUncommittedChanges.New()
|
||||
}
|
||||
|
||||
ws, _, conflicts, fastForward, message, err = performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg)
|
||||
ws, _, conflicts, fastForward, message, err = performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg, apr.Contains(cli.SkipVerificationFlag))
|
||||
if err != nil && !errors.Is(doltdb.ErrUpToDate, err) {
|
||||
return conflicts, fastForward, "", err
|
||||
}
|
||||
|
||||
@@ -216,7 +216,9 @@ func doDoltRebase(ctx *sql.Context, args []string) (int, string, error) {
|
||||
} else if apr.NArg() > 1 {
|
||||
return 1, "", fmt.Errorf("too many args")
|
||||
}
|
||||
err = startRebase(ctx, apr.Arg(0), commitBecomesEmptyHandling, emptyCommitHandling)
|
||||
|
||||
skipVerification := apr.Contains(cli.SkipVerificationFlag)
|
||||
err = startRebase(ctx, apr.Arg(0), commitBecomesEmptyHandling, emptyCommitHandling, skipVerification)
|
||||
if err != nil {
|
||||
return 1, "", err
|
||||
}
|
||||
@@ -263,7 +265,7 @@ func processCommitBecomesEmptyParams(apr *argparser.ArgParseResults) (doltdb.Emp
|
||||
// startRebase starts a new interactive rebase operation. |upstreamPoint| specifies the commit where the new rebased
|
||||
// commits will be based off of, |commitBecomesEmptyHandling| specifies how to handle commits that are not empty, but
|
||||
// do not produce any changes when applied, and |emptyCommitHandling| specifies how to handle empty commits.
|
||||
func startRebase(ctx *sql.Context, upstreamPoint string, commitBecomesEmptyHandling doltdb.EmptyCommitHandling, emptyCommitHandling doltdb.EmptyCommitHandling) error {
|
||||
func startRebase(ctx *sql.Context, upstreamPoint string, commitBecomesEmptyHandling doltdb.EmptyCommitHandling, emptyCommitHandling doltdb.EmptyCommitHandling, skipVerification bool) error {
|
||||
if upstreamPoint == "" {
|
||||
return fmt.Errorf("no upstream branch specified")
|
||||
}
|
||||
@@ -351,7 +353,7 @@ func startRebase(ctx *sql.Context, upstreamPoint string, commitBecomesEmptyHandl
|
||||
}
|
||||
|
||||
newWorkingSet, err := workingSet.StartRebase(ctx, upstreamCommit, rebaseBranch, branchRoots.Working,
|
||||
commitBecomesEmptyHandling, emptyCommitHandling)
|
||||
commitBecomesEmptyHandling, emptyCommitHandling, skipVerification)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -716,7 +718,8 @@ func continueRebase(ctx *sql.Context) rebaseResult {
|
||||
|
||||
result := processRebasePlanStep(ctx, &step,
|
||||
workingSet.RebaseState().CommitBecomesEmptyHandling(),
|
||||
workingSet.RebaseState().EmptyCommitHandling())
|
||||
workingSet.RebaseState().EmptyCommitHandling(),
|
||||
workingSet.RebaseState().SkipVerification())
|
||||
if result.err != nil || result.status != 0 || result.halt {
|
||||
return result
|
||||
}
|
||||
@@ -803,7 +806,7 @@ func commitManuallyStagedChangesForStep(ctx *sql.Context, step rebase.RebasePlan
|
||||
}
|
||||
|
||||
options, err := createCherryPickOptionsForRebaseStep(ctx, &step, workingSet.RebaseState().CommitBecomesEmptyHandling(),
|
||||
workingSet.RebaseState().EmptyCommitHandling())
|
||||
workingSet.RebaseState().EmptyCommitHandling(), workingSet.RebaseState().SkipVerification())
|
||||
|
||||
doltDB, ok := doltSession.GetDoltDB(ctx, ctx.GetCurrentDatabase())
|
||||
if !ok {
|
||||
@@ -861,6 +864,7 @@ func processRebasePlanStep(
|
||||
planStep *rebase.RebasePlanStep,
|
||||
commitBecomesEmptyHandling doltdb.EmptyCommitHandling,
|
||||
emptyCommitHandling doltdb.EmptyCommitHandling,
|
||||
skipVerification bool,
|
||||
) rebaseResult {
|
||||
// Make sure we have a transaction opened for the session
|
||||
// NOTE: After our first call to cherry-pick, the tx is committed, so a new tx needs to be started
|
||||
@@ -878,7 +882,7 @@ func processRebasePlanStep(
|
||||
return newRebaseSuccess("")
|
||||
}
|
||||
|
||||
options, err := createCherryPickOptionsForRebaseStep(ctx, planStep, commitBecomesEmptyHandling, emptyCommitHandling)
|
||||
options, err := createCherryPickOptionsForRebaseStep(ctx, planStep, commitBecomesEmptyHandling, emptyCommitHandling, skipVerification)
|
||||
if err != nil {
|
||||
return newRebaseError(err)
|
||||
}
|
||||
@@ -886,12 +890,19 @@ func processRebasePlanStep(
|
||||
return handleRebaseCherryPick(ctx, planStep, *options)
|
||||
}
|
||||
|
||||
func createCherryPickOptionsForRebaseStep(ctx *sql.Context, planStep *rebase.RebasePlanStep, commitBecomesEmptyHandling doltdb.EmptyCommitHandling, emptyCommitHandling doltdb.EmptyCommitHandling) (*cherry_pick.CherryPickOptions, error) {
|
||||
func createCherryPickOptionsForRebaseStep(
|
||||
ctx *sql.Context,
|
||||
planStep *rebase.RebasePlanStep,
|
||||
commitBecomesEmptyHandling doltdb.EmptyCommitHandling,
|
||||
emptyCommitHandling doltdb.EmptyCommitHandling,
|
||||
skipVerification bool,
|
||||
) (*cherry_pick.CherryPickOptions, error) {
|
||||
// Override the default empty commit handling options for cherry-pick, since
|
||||
// rebase has slightly different defaults
|
||||
options := cherry_pick.NewCherryPickOptions()
|
||||
options.CommitBecomesEmptyHandling = commitBecomesEmptyHandling
|
||||
options.EmptyCommitHandling = emptyCommitHandling
|
||||
options.SkipVerification = skipVerification
|
||||
|
||||
switch planStep.Action {
|
||||
case rebase.RebaseActionDrop, rebase.RebaseActionPick, rebase.RebaseActionEdit:
|
||||
|
||||
@@ -392,14 +392,6 @@ func parseStashIndex(apr *argparser.ArgParseResults) (int, error) {
|
||||
return idx, nil
|
||||
}
|
||||
|
||||
func bulkDbEaFactory(dbData env.DbData[*sql.Context]) editor.DbEaFactory {
|
||||
tmpDir, err := dbData.Rsw.TempTableFilesDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return editor.NewBulkImportTEAFactory(dbData.Ddb.ValueReadWriter(), tmpDir)
|
||||
}
|
||||
|
||||
func updateWorkingRoot(ctx *sql.Context, dbData env.DbData[*sql.Context], newRoot doltdb.RootValue) error {
|
||||
var h hash.Hash
|
||||
var wsRef ref.WorkingSetRef
|
||||
@@ -510,16 +502,12 @@ func handleMerge(ctx *sql.Context, dbName string, dbData env.DbData[*sql.Context
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
tmpDir, err := dbData.Rsw.TempTableFilesDir()
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
tableResolver, err := dsess.GetTableResolver(ctx, dbName)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
opts := editor.Options{Deaf: bulkDbEaFactory(dbData), Tempdir: tmpDir}
|
||||
|
||||
opts := editor.Options{}
|
||||
result, err := merge.MergeRoots(ctx, tableResolver, curWorkingRoot, stashRoot, parentRoot, stashRoot, parentCommit, opts, merge.MergeOpts{IsCherryPick: false})
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
sqltypes "github.com/dolthub/go-mysql-server/sql/types"
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"github.com/dolthub/dolt/go/cmd/dolt/cli"
|
||||
@@ -454,32 +453,6 @@ func (d *DoltSession) clear() {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DoltSession) newWorkingSetForHead(ctx *sql.Context, wsRef ref.WorkingSetRef, dbName string) (*doltdb.WorkingSet, error) {
|
||||
dbData, _ := d.GetDbData(nil, dbName)
|
||||
|
||||
headSpec, _ := doltdb.NewCommitSpec("HEAD")
|
||||
headRef, err := wsRef.ToHeadRef()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
optCmt, err := dbData.Ddb.Resolve(ctx, headSpec, headRef)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
headCommit, ok := optCmt.ToCommit()
|
||||
if !ok {
|
||||
return nil, doltdb.ErrGhostCommitEncountered
|
||||
}
|
||||
|
||||
headRoot, err := headCommit.GetRootValue(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return doltdb.EmptyWorkingSet(wsRef).WithWorkingRoot(headRoot).WithStagedRoot(headRoot), nil
|
||||
}
|
||||
|
||||
// CommitTransaction commits the in-progress transaction. Depending on session settings, this may write only a new
|
||||
// working set, or may additionally create a new dolt commit for the current HEAD. If more than one branch head has
|
||||
// changes, the transaction is rejected.
|
||||
@@ -1328,40 +1301,6 @@ func (d *DoltSession) setHeadRefSessionVar(ctx *sql.Context, db, value string) e
|
||||
func (d *DoltSession) setForeignKeyChecksSessionVar(ctx *sql.Context, key string, value interface{}) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
convertedVal, _, err := sqltypes.Int64.Convert(ctx, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
intVal := int64(0)
|
||||
if convertedVal != nil {
|
||||
intVal = convertedVal.(int64)
|
||||
}
|
||||
|
||||
if intVal == 0 {
|
||||
for _, dbState := range d.dbStates {
|
||||
for _, branchState := range dbState.heads {
|
||||
if ws := branchState.WriteSession(); ws != nil {
|
||||
opts := ws.GetOptions()
|
||||
opts.ForeignKeyChecksDisabled = true
|
||||
ws.SetOptions(opts)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if intVal == 1 {
|
||||
for _, dbState := range d.dbStates {
|
||||
for _, branchState := range dbState.heads {
|
||||
if ws := branchState.WriteSession(); ws != nil {
|
||||
opts := ws.GetOptions()
|
||||
opts.ForeignKeyChecksDisabled = false
|
||||
ws.SetOptions(opts)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return sql.ErrInvalidSystemVariableValue.New("foreign_key_checks", intVal)
|
||||
}
|
||||
|
||||
return d.Session.SetSessionVariable(ctx, key, value)
|
||||
}
|
||||
|
||||
|
||||
@@ -363,15 +363,6 @@ func interfaceToString(r interface{}) (string, error) {
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func resolveRoot(ctx *sql.Context, sess *dsess.DoltSession, dbName, hashStr string) (*refDetails, error) {
|
||||
root, commitTime, _, err := sess.ResolveRootForRef(ctx, dbName, hashStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &refDetails{root: root, hashStr: hashStr, commitTime: commitTime}, nil
|
||||
}
|
||||
|
||||
func resolveCommit(ctx *sql.Context, ddb *doltdb.DoltDB, headRef ref.DoltRef, cSpecStr string) (*doltdb.Commit, error) {
|
||||
cs, err := doltdb.NewCommitSpec(cSpecStr)
|
||||
if err != nil {
|
||||
|
||||
@@ -17,7 +17,9 @@ package dtablefunctions
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
gms "github.com/dolthub/go-mysql-server"
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
@@ -26,10 +28,13 @@ import (
|
||||
"github.com/dolthub/vitess/go/vt/sqlparser"
|
||||
"github.com/gocraft/dbr/v2"
|
||||
"github.com/gocraft/dbr/v2/dialect"
|
||||
"github.com/shopspring/decimal"
|
||||
"golang.org/x/exp/constraints"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/overrides"
|
||||
"github.com/dolthub/dolt/go/store/val"
|
||||
)
|
||||
|
||||
const testsRunDefaultRowCount = 10
|
||||
@@ -39,12 +44,13 @@ var _ sql.CatalogTableFunction = (*TestsRunTableFunction)(nil)
|
||||
var _ sql.ExecSourceRel = (*TestsRunTableFunction)(nil)
|
||||
var _ sql.AuthorizationCheckerNode = (*TestsRunTableFunction)(nil)
|
||||
|
||||
type testResult struct {
|
||||
testName string
|
||||
groupName string
|
||||
query string
|
||||
status string
|
||||
message string
|
||||
// TestResult represents the result of running a single test
|
||||
type TestResult struct {
|
||||
TestName string
|
||||
GroupName string
|
||||
Query string
|
||||
Status string
|
||||
Message string
|
||||
}
|
||||
|
||||
type TestsRunTableFunction struct {
|
||||
@@ -199,7 +205,7 @@ func (trtf *TestsRunTableFunction) RowIter(_ *sql.Context, _ sql.Row) (sql.RowIt
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resultRow := sql.NewRow(result.testName, result.groupName, result.query, result.status, result.message)
|
||||
resultRow := sql.NewRow(result.TestName, result.GroupName, result.Query, result.Status, result.Message)
|
||||
resultRows = append(resultRows, resultRow)
|
||||
}
|
||||
}
|
||||
@@ -220,7 +226,7 @@ func (trtf *TestsRunTableFunction) RowCount(_ *sql.Context) (uint64, bool, error
|
||||
return testsRunDefaultRowCount, false, nil
|
||||
}
|
||||
|
||||
func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResult, err error) {
|
||||
func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result TestResult, err error) {
|
||||
testName, groupName, query, assertion, comparison, value, err := parseDoltTestsRow(trtf.ctx, row)
|
||||
if err != nil {
|
||||
return
|
||||
@@ -237,9 +243,9 @@ func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResul
|
||||
if err != nil {
|
||||
message = fmt.Sprintf("Query error: %s", err.Error())
|
||||
} else {
|
||||
testPassed, message, err = actions.AssertData(trtf.ctx, *assertion, *comparison, value, queryResult)
|
||||
testPassed, message, err = AssertData(trtf.ctx, *assertion, *comparison, value, queryResult)
|
||||
if err != nil {
|
||||
return testResult{}, err
|
||||
return TestResult{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -253,11 +259,49 @@ func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResul
|
||||
if groupName != nil {
|
||||
groupString = *groupName
|
||||
}
|
||||
result = testResult{*testName, groupString, *query, status, message}
|
||||
result = TestResult{*testName, groupString, *query, status, message}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (trtf *TestsRunTableFunction) queryAndAssertWithFunc(row sql.Row, assertDataFunc AssertDataFunc) (result TestResult, err error) {
|
||||
testName, groupName, query, assertion, comparison, value, err := parseDoltTestsRow(trtf.ctx, row)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
message, err := validateQuery(trtf.ctx, trtf.catalog, *query)
|
||||
if err != nil && message == "" {
|
||||
message = fmt.Sprintf("query error: %s", err.Error())
|
||||
}
|
||||
|
||||
var testPassed bool
|
||||
if message == "" {
|
||||
_, queryResult, _, err := trtf.engine.Query(trtf.ctx, *query)
|
||||
if err != nil {
|
||||
message = fmt.Sprintf("Query error: %s", err.Error())
|
||||
} else {
|
||||
testPassed, message, err = assertDataFunc(trtf.ctx, *assertion, *comparison, value, queryResult)
|
||||
if err != nil {
|
||||
return TestResult{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
status := "PASS"
|
||||
if !testPassed {
|
||||
status = "FAIL"
|
||||
}
|
||||
|
||||
var groupString string
|
||||
if groupName != nil {
|
||||
groupString = *groupName
|
||||
}
|
||||
result = TestResult{*testName, groupString, *query, status, message}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (trtf *TestsRunTableFunction) getDoltTestsData(arg string) ([]sql.Row, error) {
|
||||
// Original behavior when root is nil - use SQL queries against current session
|
||||
var queries []string
|
||||
|
||||
if arg == "*" {
|
||||
@@ -320,28 +364,31 @@ func IsWriteQuery(query string, ctx *sql.Context, catalog sql.Catalog) (bool, er
|
||||
}
|
||||
|
||||
func parseDoltTestsRow(ctx *sql.Context, row sql.Row) (testName, groupName, query, assertion, comparison, value *string, err error) {
|
||||
if testName, err = actions.GetStringColAsString(ctx, row[0]); err != nil {
|
||||
if testName, err = getStringColAsString(ctx, row[0]); err != nil {
|
||||
return
|
||||
}
|
||||
if groupName, err = actions.GetStringColAsString(ctx, row[1]); err != nil {
|
||||
if groupName, err = getStringColAsString(ctx, row[1]); err != nil {
|
||||
return
|
||||
}
|
||||
if query, err = actions.GetStringColAsString(ctx, row[2]); err != nil {
|
||||
if query, err = getStringColAsString(ctx, row[2]); err != nil {
|
||||
return
|
||||
}
|
||||
if assertion, err = actions.GetStringColAsString(ctx, row[3]); err != nil {
|
||||
if assertion, err = getStringColAsString(ctx, row[3]); err != nil {
|
||||
return
|
||||
}
|
||||
if comparison, err = actions.GetStringColAsString(ctx, row[4]); err != nil {
|
||||
if comparison, err = getStringColAsString(ctx, row[4]); err != nil {
|
||||
return
|
||||
}
|
||||
if value, err = actions.GetStringColAsString(ctx, row[5]); err != nil {
|
||||
if value, err = getStringColAsString(ctx, row[5]); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return testName, groupName, query, assertion, comparison, value, nil
|
||||
}
|
||||
|
||||
// AssertDataFunc defines the function signature for asserting test data
|
||||
type AssertDataFunc func(sqlCtx *sql.Context, assertion string, comparison string, value *string, queryResult sql.RowIter) (testPassed bool, message string, err error)
|
||||
|
||||
func validateQuery(ctx *sql.Context, catalog sql.Catalog, query string) (string, error) {
|
||||
// We first check if the query contains multiple sql statements
|
||||
if statements, err := sqlparser.SplitStatementToPieces(query); err != nil {
|
||||
@@ -361,3 +408,455 @@ func validateQuery(ctx *sql.Context, catalog sql.Catalog, query string) (string,
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Simple inline assertion constants to avoid circular imports
|
||||
const (
|
||||
AssertionExpectedRows = "expected_rows"
|
||||
AssertionExpectedColumns = "expected_columns"
|
||||
AssertionExpectedSingleValue = "expected_single_value"
|
||||
)
|
||||
|
||||
// getStringColAsString safely converts a sql value to string
|
||||
func getStringColAsString(sqlCtx *sql.Context, tableValue interface{}) (*string, error) {
|
||||
if tableValue == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if ts, ok := tableValue.(*val.TextStorage); ok {
|
||||
str, err := ts.Unwrap(sqlCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &str, nil
|
||||
} else if str, ok := tableValue.(string); ok {
|
||||
return &str, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("unexpected type %T, was expecting string", tableValue)
|
||||
}
|
||||
}
|
||||
|
||||
// readTableDataFromDoltTable reads test data directly from a dolt table
|
||||
func (trtf *TestsRunTableFunction) readTableDataFromDoltTable(table *doltdb.Table, arg string) ([]sql.Row, error) {
|
||||
// This is a complex implementation that requires reading table data directly from dolt storage
|
||||
// For now, return an error that clearly indicates this needs to be implemented
|
||||
// The table scan would involve:
|
||||
// 1. Getting the table schema
|
||||
// 2. Creating a table iterator
|
||||
// 3. Reading and filtering rows based on the arg (test_name or test_group)
|
||||
// 4. Converting dolt storage format to SQL rows
|
||||
//
|
||||
// This is a significant implementation that requires understanding dolt's storage internals
|
||||
return nil, fmt.Errorf("direct table reading from dolt storage not yet implemented for table scan of dolt_tests - this requires implementing table iteration and row conversion from dolt's internal storage format")
|
||||
}
|
||||
|
||||
// AssertData parses an assertion, comparison, and value, then returns the status of the test.
|
||||
// Valid comparison are: "==", "!=", "<", ">", "<=", and ">=".
|
||||
// testPassed indicates whether the test was successful or not.
|
||||
// message is a string used to indicate test failures, and will not halt the overall process.
|
||||
// message will be empty if the test passed.
|
||||
// err indicates runtime failures and will stop dolt_test_run from proceeding.
|
||||
func AssertData(sqlCtx *sql.Context, assertion string, comparison string, value *string, queryResult sql.RowIter) (testPassed bool, message string, err error) {
|
||||
switch assertion {
|
||||
case AssertionExpectedRows:
|
||||
message, err = expectRows(sqlCtx, comparison, value, queryResult)
|
||||
case AssertionExpectedColumns:
|
||||
message, err = expectColumns(sqlCtx, comparison, value, queryResult)
|
||||
case AssertionExpectedSingleValue:
|
||||
message, err = expectSingleValue(sqlCtx, comparison, value, queryResult)
|
||||
default:
|
||||
return false, fmt.Sprintf("%s is not a valid assertion type", assertion), nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
} else if message != "" {
|
||||
return false, message, nil
|
||||
}
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
func expectSingleValue(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
|
||||
row, err := queryResult.Next(sqlCtx)
|
||||
if err == io.EOF {
|
||||
return fmt.Sprintf("expected_single_value expects exactly one cell. Received 0 rows"), nil
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(row) != 1 {
|
||||
return fmt.Sprintf("expected_single_value expects exactly one cell. Received multiple columns"), nil
|
||||
}
|
||||
_, err = queryResult.Next(sqlCtx)
|
||||
if err == nil { //If multiple rows were given, we should error out
|
||||
return fmt.Sprintf("expected_single_value expects exactly one cell. Received multiple rows"), nil
|
||||
} else if err != io.EOF { // "True" error, so we should quit out
|
||||
return "", err
|
||||
}
|
||||
|
||||
if value == nil { // If we're expecting a null value, we don't need to type switch
|
||||
return compareNullValue(comparison, row[0], AssertionExpectedSingleValue), nil
|
||||
}
|
||||
|
||||
// Check if the expected value is a boolean string, and if so, coerce the actual value to boolean, with the exception
|
||||
// of "0" and "1", which are valid integers and are covered below.
|
||||
if *value != "0" && *value != "1" {
|
||||
if expectedBool, err := strconv.ParseBool(*value); err == nil {
|
||||
actualBool, boolErr := getInterfaceAsBool(row[0])
|
||||
if boolErr != nil {
|
||||
return fmt.Sprintf("Could not convert value to boolean: %v", boolErr), nil
|
||||
}
|
||||
return compareBooleans(comparison, expectedBool, actualBool, AssertionExpectedSingleValue), nil
|
||||
}
|
||||
}
|
||||
|
||||
switch actualValue := row[0].(type) {
|
||||
case int8:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, int8(expectedInt), actualValue, AssertionExpectedSingleValue), nil
|
||||
case int16:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, int16(expectedInt), actualValue, AssertionExpectedSingleValue), nil
|
||||
case int32:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, int32(expectedInt), actualValue, AssertionExpectedSingleValue), nil
|
||||
case int64:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, expectedInt, actualValue, AssertionExpectedSingleValue), nil
|
||||
case int:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, int(expectedInt), actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint8:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, uint8(expectedUint), actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint16:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, uint16(expectedUint), actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint32:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, uint32(expectedUint), actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint64:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, expectedUint, actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, uint(expectedUint), actualValue, AssertionExpectedSingleValue), nil
|
||||
case float64:
|
||||
expectedFloat, err := strconv.ParseFloat(*value, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, expectedFloat, actualValue, AssertionExpectedSingleValue), nil
|
||||
case float32:
|
||||
expectedFloat, err := strconv.ParseFloat(*value, 32)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, float32(expectedFloat), actualValue, AssertionExpectedSingleValue), nil
|
||||
case decimal.Decimal:
|
||||
expectedDecimal, err := decimal.NewFromString(*value)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non decimal value '%s', with %s", *value, actualValue), nil
|
||||
}
|
||||
return compareDecimals(comparison, expectedDecimal, actualValue, AssertionExpectedSingleValue), nil
|
||||
case time.Time:
|
||||
expectedTime, format, err := parseTestsDate(*value)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("%s does not appear to be a valid date", *value), nil
|
||||
}
|
||||
return compareDates(comparison, expectedTime, actualValue, format, AssertionExpectedSingleValue), nil
|
||||
case *val.TextStorage, string:
|
||||
actualString, err := GetStringColAsString(sqlCtx, actualValue)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return compareTestAssertion(comparison, *value, *actualString, AssertionExpectedSingleValue), nil
|
||||
default:
|
||||
return fmt.Sprintf("Type %T is not supported. Open an issue at https://github.com/dolthub/dolt/issues to see it added", actualValue), nil
|
||||
}
|
||||
}
|
||||
|
||||
func expectRows(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
|
||||
if value == nil {
|
||||
return "null is not a valid assertion for expected_rows", nil
|
||||
}
|
||||
expectedRows, err := strconv.Atoi(*value)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil
|
||||
}
|
||||
|
||||
var numRows int
|
||||
for {
|
||||
_, err := queryResult.Next(sqlCtx)
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
numRows++
|
||||
}
|
||||
return compareTestAssertion(comparison, expectedRows, numRows, AssertionExpectedRows), nil
|
||||
}
|
||||
|
||||
func expectColumns(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
|
||||
if value == nil {
|
||||
return "null is not a valid assertion for expected_rows", nil
|
||||
}
|
||||
expectedColumns, err := strconv.Atoi(*value)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil
|
||||
}
|
||||
|
||||
var numColumns int
|
||||
row, err := queryResult.Next(sqlCtx)
|
||||
if err != nil && err != io.EOF {
|
||||
return "", err
|
||||
}
|
||||
numColumns = len(row)
|
||||
return compareTestAssertion(comparison, expectedColumns, numColumns, AssertionExpectedColumns), nil
|
||||
}
|
||||
|
||||
// compareTestAssertion is a generic function used for comparing string, ints, floats.
|
||||
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareTestAssertion[T constraints.Ordered](comparison string, expectedValue, actualValue T, assertionType string) string {
|
||||
switch comparison {
|
||||
case "==":
|
||||
if actualValue != expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case "!=":
|
||||
if actualValue == expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case "<":
|
||||
if actualValue >= expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s less than %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case "<=":
|
||||
if actualValue > expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s less than or equal to %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case ">":
|
||||
if actualValue <= expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case ">=":
|
||||
if actualValue < expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than or equal to %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison type", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// parseTestsDate is an internal function that parses the queried string according to allowed time formats for dolt_tests.
|
||||
// It returns the parsed time, the format that succeeded, and an error if applicable.
|
||||
func parseTestsDate(value string) (parsedTime time.Time, format string, err error) {
|
||||
// List of valid formats
|
||||
formats := []string{
|
||||
time.DateOnly,
|
||||
time.DateTime,
|
||||
time.TimeOnly,
|
||||
time.RFC3339,
|
||||
time.RFC1123Z,
|
||||
}
|
||||
|
||||
for _, format := range formats {
|
||||
if parsedTime, parseErr := time.Parse(format, value); parseErr == nil {
|
||||
return parsedTime, format, nil
|
||||
} else {
|
||||
err = parseErr
|
||||
}
|
||||
}
|
||||
return time.Time{}, "", err
|
||||
}
|
||||
|
||||
// compareDates is a function used for comparing time values.
|
||||
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareDates(comparison string, expectedValue, realValue time.Time, format string, assertionType string) string {
|
||||
expectedStr := expectedValue.Format(format)
|
||||
realStr := realValue.Format(format)
|
||||
switch comparison {
|
||||
case "==":
|
||||
if !expectedValue.Equal(realValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case "!=":
|
||||
if expectedValue.Equal(realValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case "<":
|
||||
if realValue.Equal(expectedValue) || realValue.After(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s less than %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case "<=":
|
||||
if realValue.After(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s less than or equal to %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case ">":
|
||||
if realValue.Before(expectedValue) || realValue.Equal(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case ">=":
|
||||
if realValue.Before(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than or equal to %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison type", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// compareDecimals is a function used for comparing decimals.
|
||||
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareDecimals(comparison string, expectedValue, realValue decimal.Decimal, assertionType string) string {
|
||||
switch comparison {
|
||||
case "==":
|
||||
if !expectedValue.Equal(realValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case "!=":
|
||||
if expectedValue.Equal(realValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case "<":
|
||||
if realValue.GreaterThanOrEqual(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s less than %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case "<=":
|
||||
if realValue.GreaterThan(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s less than or equal to %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case ">":
|
||||
if realValue.LessThanOrEqual(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case ">=":
|
||||
if realValue.LessThan(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than or equal to %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison type", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getTinyIntColAsBool returns the value interface{} as a bool
|
||||
// This is necessary because the query engine may return a tinyint column as a bool, int, or other types.
|
||||
// Based on GetTinyIntColAsBool from commands/utils.go, which we can't depend on here due to package cycles.
|
||||
func getInterfaceAsBool(col interface{}) (bool, error) {
|
||||
switch v := col.(type) {
|
||||
case bool:
|
||||
return v, nil
|
||||
case int:
|
||||
return v == 1, nil
|
||||
case int8:
|
||||
return v == 1, nil
|
||||
case int16:
|
||||
return v == 1, nil
|
||||
case int32:
|
||||
return v == 1, nil
|
||||
case int64:
|
||||
return v == 1, nil
|
||||
case uint:
|
||||
return v == 1, nil
|
||||
case uint8:
|
||||
return v == 1, nil
|
||||
case uint16:
|
||||
return v == 1, nil
|
||||
case uint32:
|
||||
return v == 1, nil
|
||||
case uint64:
|
||||
return v == 1, nil
|
||||
case string:
|
||||
return v == "1", nil
|
||||
default:
|
||||
return false, fmt.Errorf("unexpected type %T, was expecting bool, int, or string", v)
|
||||
}
|
||||
}
|
||||
|
||||
// compareBooleans is a function used for comparing boolean values.
|
||||
// It takes in a comparison string from one of: "==", "!="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareBooleans(comparison string, expectedValue, realValue bool, assertionType string) string {
|
||||
switch comparison {
|
||||
case "==":
|
||||
if expectedValue != realValue {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to %t, got %t", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case "!=":
|
||||
if expectedValue == realValue {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to %t, got %t", assertionType, expectedValue, realValue)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison for boolean values. Only '==' and '!=' are supported", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// compareNullValue is a function used for comparing a null value.
|
||||
// It takes in a comparison string from one of: "==", "!="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareNullValue(comparison string, actualValue interface{}, assertionType string) string {
|
||||
switch comparison {
|
||||
case "==":
|
||||
if actualValue != nil {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to NULL, got %v", assertionType, actualValue)
|
||||
}
|
||||
case "!=":
|
||||
if actualValue == nil {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to NULL, got NULL", assertionType)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison for NULL values", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetStringColAsString is a function that returns a text column as a string.
|
||||
// This is necessary as the dolt_tests system table returns *val.TextStorage types under certain situations,
|
||||
// so we use a special parser to get the correct string values
|
||||
func GetStringColAsString(sqlCtx *sql.Context, tableValue interface{}) (*string, error) {
|
||||
if ts, ok := tableValue.(*val.TextStorage); ok {
|
||||
str, err := ts.Unwrap(sqlCtx)
|
||||
return &str, err
|
||||
} else if str, ok := tableValue.(string); ok {
|
||||
return &str, nil
|
||||
} else if tableValue == nil {
|
||||
return nil, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("unexpected type %T, was expecting string", tableValue)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,185 +0,0 @@
|
||||
// Copyright 2022 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dtables
|
||||
|
||||
import (
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/dolthub/go-mysql-server/sql/types"
|
||||
"github.com/dolthub/vitess/go/sqltypes"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
|
||||
)
|
||||
|
||||
const (
|
||||
AccessBinlogTableName = AccessTableName + "_binlog"
|
||||
NamespaceBinlogTableName = NamespaceTableName + "_binlog"
|
||||
)
|
||||
|
||||
// accessBinlogSchema is the schema for the "dolt_branch_control_binlog" table.
|
||||
var accessBinlogSchema = sql.Schema{
|
||||
&sql.Column{
|
||||
Name: "index",
|
||||
Type: types.Int64,
|
||||
Source: AccessBinlogTableName,
|
||||
PrimaryKey: true,
|
||||
},
|
||||
&sql.Column{
|
||||
Name: "operation",
|
||||
Type: types.MustCreateEnumType([]string{"insert", "delete"}, sql.Collation_utf8mb4_0900_bin),
|
||||
Source: AccessBinlogTableName,
|
||||
PrimaryKey: false,
|
||||
},
|
||||
&sql.Column{
|
||||
Name: "branch",
|
||||
Type: types.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
|
||||
Source: AccessBinlogTableName,
|
||||
PrimaryKey: false,
|
||||
},
|
||||
&sql.Column{
|
||||
Name: "user",
|
||||
Type: types.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_bin),
|
||||
Source: AccessBinlogTableName,
|
||||
PrimaryKey: false,
|
||||
},
|
||||
&sql.Column{
|
||||
Name: "host",
|
||||
Type: types.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
|
||||
Source: AccessBinlogTableName,
|
||||
PrimaryKey: false,
|
||||
},
|
||||
&sql.Column{
|
||||
Name: "permissions",
|
||||
Type: types.MustCreateSetType(PermissionsStrings, sql.Collation_utf8mb4_0900_ai_ci),
|
||||
Source: AccessBinlogTableName,
|
||||
PrimaryKey: false,
|
||||
},
|
||||
}
|
||||
|
||||
// namespaceBinlogSchema is the schema for the "dolt_branch_namespace_control_binlog" table.
|
||||
var namespaceBinlogSchema = sql.Schema{
|
||||
&sql.Column{
|
||||
Name: "index",
|
||||
Type: types.Int64,
|
||||
Source: NamespaceBinlogTableName,
|
||||
PrimaryKey: true,
|
||||
},
|
||||
&sql.Column{
|
||||
Name: "operation",
|
||||
Type: types.MustCreateEnumType([]string{"insert", "delete"}, sql.Collation_utf8mb4_0900_bin),
|
||||
Source: NamespaceBinlogTableName,
|
||||
PrimaryKey: false,
|
||||
},
|
||||
&sql.Column{
|
||||
Name: "branch",
|
||||
Type: types.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
|
||||
Source: NamespaceBinlogTableName,
|
||||
PrimaryKey: false,
|
||||
},
|
||||
&sql.Column{
|
||||
Name: "user",
|
||||
Type: types.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_bin),
|
||||
Source: NamespaceBinlogTableName,
|
||||
PrimaryKey: false,
|
||||
},
|
||||
&sql.Column{
|
||||
Name: "host",
|
||||
Type: types.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
|
||||
Source: NamespaceBinlogTableName,
|
||||
PrimaryKey: false,
|
||||
},
|
||||
}
|
||||
|
||||
// BinlogTable provides a queryable view over the Binlog.
|
||||
type BinlogTable struct {
|
||||
Log *branch_control.Binlog
|
||||
IsAccess bool
|
||||
}
|
||||
|
||||
var _ sql.Table = BinlogTable{}
|
||||
|
||||
// Name implements the interface sql.Table.
|
||||
func (b BinlogTable) Name() string {
|
||||
if b.IsAccess {
|
||||
return AccessBinlogTableName
|
||||
} else {
|
||||
return NamespaceBinlogTableName
|
||||
}
|
||||
}
|
||||
|
||||
// String implements the interface sql.Table.
|
||||
func (b BinlogTable) String() string {
|
||||
if b.IsAccess {
|
||||
return AccessBinlogTableName
|
||||
} else {
|
||||
return NamespaceBinlogTableName
|
||||
}
|
||||
}
|
||||
|
||||
// Schema implements the interface sql.Table.
|
||||
func (b BinlogTable) Schema() sql.Schema {
|
||||
if b.IsAccess {
|
||||
return accessBinlogSchema
|
||||
} else {
|
||||
return namespaceBinlogSchema
|
||||
}
|
||||
}
|
||||
|
||||
// Collation implements the interface sql.Table.
|
||||
func (b BinlogTable) Collation() sql.CollationID {
|
||||
return sql.Collation_Default
|
||||
}
|
||||
|
||||
// Partitions implements the interface sql.Table.
|
||||
func (b BinlogTable) Partitions(context *sql.Context) (sql.PartitionIter, error) {
|
||||
return index.SinglePartitionIterFromNomsMap(nil), nil
|
||||
}
|
||||
|
||||
// PartitionRows implements the interface sql.Table.
|
||||
func (b BinlogTable) PartitionRows(context *sql.Context, partition sql.Partition) (sql.RowIter, error) {
|
||||
b.Log.RWMutex.RLock()
|
||||
defer b.Log.RWMutex.RUnlock()
|
||||
|
||||
binlogRows := b.Log.Rows()
|
||||
rows := make([]sql.Row, len(binlogRows))
|
||||
for i := 0; i < len(binlogRows); i++ {
|
||||
logRow := binlogRows[i]
|
||||
operation := uint16(1)
|
||||
if !logRow.IsInsert {
|
||||
operation = 2
|
||||
}
|
||||
|
||||
if b.IsAccess {
|
||||
rows[i] = sql.Row{
|
||||
int64(i),
|
||||
operation,
|
||||
logRow.Branch,
|
||||
logRow.User,
|
||||
logRow.Host,
|
||||
logRow.Permissions,
|
||||
}
|
||||
} else {
|
||||
rows[i] = sql.Row{
|
||||
int64(i),
|
||||
operation,
|
||||
logRow.Branch,
|
||||
logRow.User,
|
||||
logRow.Host,
|
||||
}
|
||||
}
|
||||
}
|
||||
return sql.RowsToRowIter(rows...), nil
|
||||
}
|
||||
@@ -24,9 +24,7 @@ import (
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/diff"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/row"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/rowconv"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/expreval"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
|
||||
@@ -591,38 +589,6 @@ func (dt *DiffTable) PreciseMatch() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// tableData returns the map of primary key to values for the specified table (or an empty map if the tbl is null)
|
||||
// and the schema of the table (or EmptySchema if tbl is null).
|
||||
func tableData(ctx *sql.Context, tbl *doltdb.Table, ddb *doltdb.DoltDB) (durable.Index, schema.Schema, error) {
|
||||
var data durable.Index
|
||||
var err error
|
||||
|
||||
if tbl == nil {
|
||||
data, err = durable.NewEmptyPrimaryIndex(ctx, ddb.ValueReadWriter(), ddb.NodeStore(), schema.EmptySchema)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
} else {
|
||||
data, err = tbl.GetRowData(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var sch schema.Schema
|
||||
if tbl == nil {
|
||||
sch = schema.EmptySchema
|
||||
} else {
|
||||
sch, err = tbl.GetSchema(ctx)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return data, sch, nil
|
||||
}
|
||||
|
||||
type TblInfoAtCommit struct {
|
||||
date *types.Timestamp
|
||||
tbl *doltdb.Table
|
||||
@@ -891,20 +857,6 @@ func (dps *DiffPartitions) Close(*sql.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// rowConvForSchema creates a RowConverter for transforming rows with the given schema a target schema.
|
||||
func (dp DiffPartition) rowConvForSchema(ctx context.Context, vrw types.ValueReadWriter, targetSch, srcSch schema.Schema) (*rowconv.RowConverter, error) {
|
||||
if schema.SchemasAreEqual(srcSch, schema.EmptySchema) {
|
||||
return rowconv.IdentityConverter, nil
|
||||
}
|
||||
|
||||
fm, err := rowconv.TagMappingByTagAndName(srcSch, targetSch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return rowconv.NewRowConverter(ctx, vrw, fm)
|
||||
}
|
||||
|
||||
// GetDiffTableSchemaAndJoiner returns the schema for the diff table given a
|
||||
// target schema for a row |sch|. In the old storage format, it also returns the
|
||||
// associated joiner.
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
// Copyright 2022 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dtables
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/dolthub/go-mysql-server/sql/expression"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/utils/set"
|
||||
)
|
||||
|
||||
type Predicate func(sql.Expression) bool
|
||||
|
||||
// ColumnPredicate returns a predicate function for expressions on the column names given
|
||||
func ColumnPredicate(colNameSet *set.StrSet) Predicate {
|
||||
return func(filter sql.Expression) bool {
|
||||
isCommitFilter := true
|
||||
sql.Inspect(filter, func(e sql.Expression) (cont bool) {
|
||||
if e == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
switch val := e.(type) {
|
||||
case *expression.GetField:
|
||||
if !colNameSet.Contains(strings.ToLower(val.Name())) {
|
||||
isCommitFilter = false
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return isCommitFilter
|
||||
}
|
||||
}
|
||||
|
||||
// FilterFilters returns the subset of the expressions given that match the given predicate
|
||||
func FilterFilters(filters []sql.Expression, predicate func(filter sql.Expression) bool) []sql.Expression {
|
||||
matching := make([]sql.Expression, 0, len(filters))
|
||||
for _, f := range filters {
|
||||
if predicate(f) {
|
||||
matching = append(matching, f)
|
||||
}
|
||||
}
|
||||
return matching
|
||||
}
|
||||
@@ -156,49 +156,3 @@ func (itr *StashItr) Next(*sql.Context) (sql.Row, error) {
|
||||
func (itr *StashItr) Close(*sql.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ sql.RowReplacer = stashWriter{nil}
|
||||
var _ sql.RowUpdater = stashWriter{nil}
|
||||
var _ sql.RowInserter = stashWriter{nil}
|
||||
var _ sql.RowDeleter = stashWriter{nil}
|
||||
|
||||
type stashWriter struct {
|
||||
rt *StashesTable
|
||||
}
|
||||
|
||||
// Insert inserts the row given, returning an error if it cannot. Insert will be called once for each row to process
|
||||
// for the insert operation, which may involve many rows. After all rows in an operation have been processed, Close
|
||||
// is called.
|
||||
func (bWr stashWriter) Insert(_ *sql.Context, _ sql.Row) error {
|
||||
return fmt.Errorf("the dolt_stashes table is read-only; use the dolt_stash stored procedure to edit stashes")
|
||||
}
|
||||
|
||||
// Update the given row. Provides both the old and new rows.
|
||||
func (bWr stashWriter) Update(_ *sql.Context, _ sql.Row, _ sql.Row) error {
|
||||
return fmt.Errorf("the dolt_stash table is read-only; use the dolt_stash stored procedure to edit stashes")
|
||||
}
|
||||
|
||||
// Delete deletes the given row. Returns ErrDeleteRowNotFound if the row was not found. Delete will be called once for
|
||||
// each row to process for the delete operation, which may involve many rows. After all rows have been processed,
|
||||
// Close is called.
|
||||
func (bWr stashWriter) Delete(_ *sql.Context, _ sql.Row) error {
|
||||
return fmt.Errorf("the dolt_stash table is read-only; use the dolt_stash stored procedure to edit stashes")
|
||||
}
|
||||
|
||||
// StatementBegin implements the interface sql.TableEditor. Currently a no-op.
|
||||
func (bWr stashWriter) StatementBegin(*sql.Context) {}
|
||||
|
||||
// DiscardChanges implements the interface sql.TableEditor. Currently a no-op.
|
||||
func (bWr stashWriter) DiscardChanges(_ *sql.Context, _ error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// StatementComplete implements the interface sql.TableEditor. Currently a no-op.
|
||||
func (bWr stashWriter) StatementComplete(*sql.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close finalizes the delete operation, persisting the result.
|
||||
func (bWr stashWriter) Close(*sql.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/dolthub/go-mysql-server/sql/expression"
|
||||
"github.com/dolthub/go-mysql-server/sql/plan"
|
||||
"github.com/dolthub/go-mysql-server/sql/transform"
|
||||
"github.com/dolthub/go-mysql-server/sql/types"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/diff"
|
||||
@@ -462,64 +461,6 @@ func (itr *doltDiffCommitHistoryRowItr) Close(*sql.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// isTableDataEmpty return true if the table does not contain any data
|
||||
func isTableDataEmpty(ctx *sql.Context, table *doltdb.Table) (bool, error) {
|
||||
rowData, err := table.GetRowData(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return rowData.Empty()
|
||||
}
|
||||
|
||||
// commitFilterForDiffTableFilterExprs returns CommitFilter used for CommitItr.
|
||||
func commitFilterForDiffTableFilterExprs(filters []sql.Expression) (doltdb.CommitFilter[*sql.Context], error) {
|
||||
filters = transformFilters(filters...)
|
||||
|
||||
return func(ctx *sql.Context, h hash.Hash, optCmt *doltdb.OptionalCommit) (filterOut bool, err error) {
|
||||
cm, ok := optCmt.ToCommit()
|
||||
if !ok {
|
||||
return false, doltdb.ErrGhostCommitEncountered
|
||||
}
|
||||
|
||||
meta, err := cm.GetCommitMeta(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, filter := range filters {
|
||||
res, err := filter.Eval(ctx, sql.Row{h.String(), meta.Name, meta.Time()})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
b, ok := res.(bool)
|
||||
if ok && !b {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, err
|
||||
}, nil
|
||||
}
|
||||
|
||||
// transformFilters return filter expressions with index specified for rows used in CommitFilter.
|
||||
func transformFilters(filters ...sql.Expression) []sql.Expression {
|
||||
for i := range filters {
|
||||
filters[i], _, _ = transform.Expr(filters[i], func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
|
||||
gf, ok := e.(*expression.GetField)
|
||||
if !ok {
|
||||
return e, transform.SameTree, nil
|
||||
}
|
||||
switch gf.Name() {
|
||||
case commitHashCol:
|
||||
return gf.WithIndex(0), transform.NewTree, nil
|
||||
default:
|
||||
return gf, transform.SameTree, nil
|
||||
}
|
||||
})
|
||||
}
|
||||
return filters
|
||||
}
|
||||
|
||||
func getCommitsFromCommitHashEquality(ctx *sql.Context, ddb *doltdb.DoltDB, filters []sql.Expression) ([]*doltdb.Commit, bool) {
|
||||
var commits []*doltdb.Commit
|
||||
var isCommitHashEquality bool
|
||||
|
||||
@@ -1210,6 +1210,11 @@ func TestDoltDdlScripts(t *testing.T) {
|
||||
RunDoltDdlScripts(t, harness)
|
||||
}
|
||||
|
||||
func TestDoltCommitVerificationScripts(t *testing.T) {
|
||||
harness := newDoltEnginetestHarness(t)
|
||||
RunDoltCommitVerificationScripts(t, harness)
|
||||
}
|
||||
|
||||
func TestBrokenDdlScripts(t *testing.T) {
|
||||
for _, script := range BrokenDDLScripts {
|
||||
t.Skip(script.Name)
|
||||
|
||||
@@ -512,7 +512,8 @@ func RunStoredProceduresTest(t *testing.T, h DoltEnginetestHarness) {
|
||||
}
|
||||
|
||||
func RunDoltStoredProceduresTest(t *testing.T, h DoltEnginetestHarness) {
|
||||
for _, script := range DoltProcedureTests {
|
||||
scripts := append(DoltProcedureTests, DoltCleanProcedureScripts...)
|
||||
for _, script := range scripts {
|
||||
func() {
|
||||
h := h.NewHarness(t)
|
||||
h.UseLocalFileSystem()
|
||||
@@ -523,7 +524,8 @@ func RunDoltStoredProceduresTest(t *testing.T, h DoltEnginetestHarness) {
|
||||
}
|
||||
|
||||
func RunDoltStoredProceduresPreparedTest(t *testing.T, h DoltEnginetestHarness) {
|
||||
for _, script := range DoltProcedureTests {
|
||||
scripts := append(DoltProcedureTests, DoltCleanProcedureScripts...)
|
||||
for _, script := range scripts {
|
||||
func() {
|
||||
h := h.NewHarness(t)
|
||||
h.UseLocalFileSystem()
|
||||
@@ -2145,3 +2147,12 @@ func RunTransactionTestsWithEngineSetup(t *testing.T, setupEngine func(*gms.Engi
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func RunDoltCommitVerificationScripts(t *testing.T, harness DoltEnginetestHarness) {
|
||||
for _, script := range DoltCommitVerificationScripts {
|
||||
harness := harness.NewHarness(t)
|
||||
|
||||
enginetest.TestScript(t, harness, script)
|
||||
harness.Close()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,7 +190,6 @@ func (d *DoltHarness) resetScripts() []setup.SetupScript {
|
||||
for i := range dbs {
|
||||
db := dbs[i]
|
||||
resetCmds = append(resetCmds, setup.SetupScript{fmt.Sprintf("use %s", db)})
|
||||
|
||||
// Any auto increment tables must be dropped and recreated to get a fresh state for the global auto increment
|
||||
// sequence trackers
|
||||
_, aiTables := enginetest.MustQuery(ctx, d.engine,
|
||||
@@ -218,6 +217,7 @@ func (d *DoltHarness) resetScripts() []setup.SetupScript {
|
||||
resetCmds = append(resetCmds, setup.SetupScript{fmt.Sprintf("drop database if exists %s", db)})
|
||||
}
|
||||
}
|
||||
|
||||
resetCmds = append(resetCmds, setup.SetupScript{"use mydb"})
|
||||
return resetCmds
|
||||
}
|
||||
@@ -229,7 +229,7 @@ func commitScripts(dbs []string) []setup.SetupScript {
|
||||
db := dbs[i]
|
||||
commitCmds = append(commitCmds, fmt.Sprintf("use %s", db))
|
||||
commitCmds = append(commitCmds, "call dolt_add('.')")
|
||||
commitCmds = append(commitCmds, fmt.Sprintf("call dolt_commit('--allow-empty', '-am', 'checkpoint enginetest database %s', '--date', '1970-01-01T12:00:00')", db))
|
||||
commitCmds = append(commitCmds, fmt.Sprintf("call dolt_commit('--allow-empty', '-am', 'checkpoint enginetest database %s', '--date', '1970-01-01T12:00:00', '--skip-verification')", db))
|
||||
}
|
||||
commitCmds = append(commitCmds, "use mydb")
|
||||
return []setup.SetupScript{commitCmds}
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
// Copyright 2026 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package enginetest
|
||||
|
||||
import (
|
||||
"github.com/dolthub/go-mysql-server/enginetest/queries"
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
)
|
||||
|
||||
// DoltCleanProcedureScripts are script tests for the dolt_clean procedure.
|
||||
var DoltCleanProcedureScripts = []queries.ScriptTest{
|
||||
{
|
||||
Name: "dolt_clean does not drop tables matching dolt_ignore",
|
||||
SetUpScript: []string{
|
||||
"CREATE TABLE ignored_foo (id int primary key);",
|
||||
"INSERT INTO ignored_foo VALUES (1);",
|
||||
"INSERT INTO dolt_ignore VALUES ('ignored_*', true);",
|
||||
"CALL dolt_add('dolt_ignore');",
|
||||
"CALL dolt_commit('-m', 'add dolt_ignore');",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "SELECT * FROM ignored_foo;",
|
||||
Expected: []sql.Row{{1}},
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_clean();",
|
||||
Expected: []sql.Row{{0}},
|
||||
},
|
||||
{
|
||||
Query: "SELECT * FROM ignored_foo;",
|
||||
Expected: []sql.Row{{1}},
|
||||
},
|
||||
{
|
||||
Query: "SHOW TABLES;",
|
||||
Expected: []sql.Row{{"ignored_foo"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "dolt_clean -x drops tables matching dolt_ignore",
|
||||
SetUpScript: []string{
|
||||
"CREATE TABLE ignored_bar (id int primary key);",
|
||||
"INSERT INTO ignored_bar VALUES (1);",
|
||||
"INSERT INTO dolt_ignore VALUES ('ignored_*', true);",
|
||||
"CALL dolt_add('dolt_ignore');",
|
||||
"CALL dolt_commit('-m', 'add dolt_ignore');",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "SELECT * FROM ignored_bar;",
|
||||
Expected: []sql.Row{{1}},
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_clean('-x');",
|
||||
Expected: []sql.Row{{0}},
|
||||
},
|
||||
{
|
||||
Query: "SELECT * FROM ignored_bar;",
|
||||
ExpectedErrStr: "table not found: ignored_bar",
|
||||
},
|
||||
{
|
||||
Query: "SHOW TABLES;",
|
||||
Expected: []sql.Row{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,538 @@
|
||||
// Copyright 2025 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package enginetest
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/enginetest"
|
||||
"github.com/dolthub/go-mysql-server/enginetest/queries"
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/dolthub/go-mysql-server/sql/types"
|
||||
|
||||
"github.com/dolthub/dolt/go/store/hash"
|
||||
)
|
||||
|
||||
// commitHashValidator validates commit hash format (32 character hex)
|
||||
type commitHashValidator struct{}
|
||||
|
||||
var _ enginetest.CustomValueValidator = &commitHashValidator{}
|
||||
|
||||
func (chv *commitHashValidator) Validate(val interface{}) (bool, error) {
|
||||
h, ok := val.(string)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
_, ok = hash.MaybeParse(h)
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
// successfulRebaseMessageValidator validates successful rebase message format
|
||||
type successfulRebaseMessageValidator struct{}
|
||||
|
||||
var _ enginetest.CustomValueValidator = &successfulRebaseMessageValidator{}
|
||||
var successfulRebaseRegex = regexp.MustCompile(`^Successfully rebased.*`)
|
||||
|
||||
func (srmv *successfulRebaseMessageValidator) Validate(val interface{}) (bool, error) {
|
||||
message, ok := val.(string)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
return successfulRebaseRegex.MatchString(message), nil
|
||||
}
|
||||
|
||||
var commitHash = &commitHashValidator{}
|
||||
var successfulRebaseMessage = &successfulRebaseMessageValidator{}
|
||||
|
||||
var DoltCommitVerificationScripts = []queries.ScriptTest{
|
||||
{
|
||||
Name: "test verification system variables exist and have correct defaults",
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "SHOW GLOBAL VARIABLES LIKE 'dolt_commit_verification_groups'",
|
||||
Expected: []sql.Row{
|
||||
{"dolt_commit_verification_groups", ""},
|
||||
},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "test verification system variables can be set",
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
Expected: []sql.Row{{types.OkResult{}}},
|
||||
},
|
||||
{
|
||||
Query: "SHOW GLOBAL VARIABLES LIKE 'dolt_commit_verification_groups'",
|
||||
Expected: []sql.Row{
|
||||
{"dolt_commit_verification_groups", "*"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = 'unit,integration'",
|
||||
Expected: []sql.Row{{types.OkResult{}}},
|
||||
},
|
||||
{
|
||||
Query: "SHOW GLOBAL VARIABLES LIKE 'dolt_commit_verification_groups'",
|
||||
Expected: []sql.Row{
|
||||
{"dolt_commit_verification_groups", "unit,integration"},
|
||||
},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "commit verification enabled - all tests pass",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " +
|
||||
"('test_alice_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = \"Alice\"', 'expected_single_value', '==', '1')",
|
||||
"CALL dolt_add('.')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_commit('-m', 'Commit with passing tests')",
|
||||
ExpectedColumns: sql.Schema{
|
||||
{Name: "hash", Type: types.LongText, Nullable: false},
|
||||
},
|
||||
Expected: []sql.Row{{commitHash}},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "commit verification enabled - tests fail, commit aborted",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " +
|
||||
"('test_will_fail', 'integration', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
|
||||
"CALL dolt_add('.')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_commit('-m', 'Commit that should fail verification')",
|
||||
ExpectedErrStr: "commit verification failed: test_will_fail (Assertion failed: expected_single_value equal to 999, got 2)",
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_commit('--skip-verification','-m', 'skip verification')",
|
||||
Expected: []sql.Row{{commitHash}},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "commit with test verification - specific test groups",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = 'unit'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " +
|
||||
"('test_will_fail', 'integration', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
|
||||
"CALL dolt_add('.')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_commit('-m', 'Commit with unit tests only')",
|
||||
Expected: []sql.Row{{commitHash}},
|
||||
},
|
||||
{
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = 'integration'",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_commit('--allow-empty', '--amend', '-m', 'fail please')",
|
||||
ExpectedErrStr: "commit verification failed: test_will_fail (Assertion failed: expected_single_value equal to 999, got 2)",
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_commit('--allow-empty', '--amend', '--skip-verification', '-m', 'skip the tests')",
|
||||
Expected: []sql.Row{{commitHash}},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "cherry-pick with test verification enabled - tests pass",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_user_count_update', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'add test')",
|
||||
"CALL dolt_checkout('-b', 'feature')",
|
||||
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
|
||||
"UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_user_count_update'",
|
||||
"CALL dolt_add('.')",
|
||||
"call dolt_commit_hash_out(@commit_1_hash,'--skip-verification', '-m', 'Add Bob and update test')",
|
||||
"INSERT INTO users VALUES (3, 'Charlie', 'chuck@exampl.com')",
|
||||
"CALL dolt_add('.')",
|
||||
"call dolt_commit_hash_out(@commit_2_hash,'--skip-verification', '-m', 'Add Charlie')",
|
||||
"CALL dolt_checkout('main')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_cherry_pick(@commit_1_hash)",
|
||||
Expected: []sql.Row{{commitHash, int64(0), int64(0), int64(0)}},
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_cherry_pick(@commit_2_hash)",
|
||||
ExpectedErrStr: "commit verification failed: test_user_count_update (Assertion failed: expected_single_value equal to 2, got 3)",
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "cherry-pick with test verification enabled - tests fail, aborted",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('-m', 'Initial commit')",
|
||||
"CALL dolt_checkout('-b', 'feature')",
|
||||
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
|
||||
"CALL dolt_add('.')",
|
||||
"call dolt_commit_hash_out(@commit_hash,'--skip-verification', '-m', 'Add Bob but dont update test')",
|
||||
"CALL dolt_checkout('main')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_cherry_pick(@commit_hash)",
|
||||
ExpectedErrStr: "commit verification failed: test_users_count (Assertion failed: expected_single_value equal to 1, got 2)",
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_cherry_pick('--skip-verification', @commit_hash)",
|
||||
Expected: []sql.Row{{commitHash, int64(0), int64(0), int64(0)}},
|
||||
},
|
||||
{
|
||||
Query: "select * from dolt_test_run('*')",
|
||||
Expected: []sql.Row{
|
||||
{"test_users_count", "unit", "SELECT COUNT(*) FROM users", "FAIL", "Assertion failed: expected_single_value equal to 1, got 2"},
|
||||
},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "rebase with test verification enabled - tests pass",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('-m', 'Initial commit')",
|
||||
"DELETE FROM users where id = 1",
|
||||
"INSERT INTO users VALUES (1, 'Zed', 'zed@example.com')",
|
||||
"CALL dolt_commit('-am', 'drop Alice, add Zed')", // tests still pass here.
|
||||
"CALL dolt_checkout('-b', 'feature', 'HEAD~1')",
|
||||
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
|
||||
"UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_users_count'",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('-m', 'Add Bob and update test')",
|
||||
"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
|
||||
"UPDATE dolt_tests SET assertion_value = '3' WHERE test_name = 'test_users_count'",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('-m', 'Add Charlie, update test')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_rebase('main')",
|
||||
Expected: []sql.Row{{int64(0), successfulRebaseMessage}},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
Name: "rebase with test verification enabled - tests fail, aborted",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('-m', 'Initial commit')",
|
||||
"CALL dolt_checkout('-b', 'feature')",
|
||||
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
|
||||
"UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_users_count'",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('-m', 'Add Bob but dont update test')",
|
||||
"CALL dolt_checkout('main')",
|
||||
"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Add Charlie')", // this will trip the existing test.
|
||||
"CALL dolt_checkout('feature')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_rebase('main')",
|
||||
ExpectedErrStr: "commit verification failed: test_users_count (Assertion failed: expected_single_value equal to 2, got 3)",
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_rebase('--abort')",
|
||||
Expected: []sql.Row{{0, "Interactive rebase aborted"}},
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_rebase('--skip-verification', 'main')",
|
||||
Expected: []sql.Row{{int64(0), successfulRebaseMessage}},
|
||||
},
|
||||
{
|
||||
Query: "select * from dolt_test_run('*')",
|
||||
Expected: []sql.Row{
|
||||
{"test_users_count", "unit", "SELECT COUNT(*) FROM users", "FAIL", "Assertion failed: expected_single_value equal to 2, got 3"},
|
||||
},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "interactive rebase with --skip-verification flag should persist across continue operations",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Initial commit')",
|
||||
"CALL dolt_checkout('-b', 'feature')",
|
||||
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Add Bob but dont update test')", // This will cause test to fail
|
||||
"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Add Charlie')",
|
||||
"CALL dolt_checkout('main')",
|
||||
"INSERT INTO users VALUES (4, 'David', 'david@example.com')", // Add a commit to main to create divergence
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Add David on main')",
|
||||
"CALL dolt_checkout('feature')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_rebase('--interactive', '--skip-verification', 'main')",
|
||||
Expected: []sql.Row{{0, "interactive rebase started on branch dolt_rebase_feature; adjust the rebase plan in the dolt_rebase table, then continue rebasing by calling dolt_rebase('--continue')"}},
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_rebase('--continue')", // This should NOT require --skip-verification flag but should still skip tests
|
||||
Expected: []sql.Row{{int64(0), successfulRebaseMessage}},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "test verification with no dolt_tests errors",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
|
||||
"CALL dolt_add('.')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_commit('-m', 'Commit without dolt_tests table')",
|
||||
ExpectedErrStr: "failed to run dolt_test_run for group *: could not find tests for argument: *",
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "test verification with mixed test groups - only specified groups run",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = 'unit'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_users_unit', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " +
|
||||
"('test_users_integration', 'integration', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
|
||||
"CALL dolt_add('.')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_commit('-m', 'Commit with unit tests only - should pass')",
|
||||
Expected: []sql.Row{{commitHash}},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "test verification error message includes test details",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_specific_failure', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
|
||||
"CALL dolt_add('.')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_commit('-m', 'Commit with specific test failure')",
|
||||
ExpectedErrStr: "commit verification failed: test_specific_failure (Assertion failed: expected_single_value equal to 999, got 2)",
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "merge with test verification enabled - tests pass",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_alice_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = \"Alice\"', 'expected_single_value', '==', '1')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('-m', 'Initial commit')",
|
||||
"CALL dolt_checkout('-b', 'feature')",
|
||||
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_bob_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = \"Bob\"', 'expected_single_value', '==', '1')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Add Bob')",
|
||||
"CALL dolt_checkout('main')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_merge('feature')",
|
||||
Expected: []sql.Row{{commitHash, int64(1), int64(0), "merge successful"}},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "merge with test verification enabled - tests fail, merge aborted",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_will_fail', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Initial commit with failing test')",
|
||||
"CALL dolt_checkout('-b', 'feature')",
|
||||
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Add Bob')",
|
||||
"CALL dolt_checkout('main')",
|
||||
"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Add Charlie to force non-FF merge')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_merge('feature')",
|
||||
ExpectedErrStr: "commit verification failed: test_will_fail (Assertion failed: expected_single_value equal to 999, got 3)",
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "merge with --skip-verification flag bypasses verification",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_will_fail', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Initial commit with failing test')",
|
||||
"CALL dolt_checkout('-b', 'feature')",
|
||||
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Add Bob')",
|
||||
"CALL dolt_checkout('main')",
|
||||
"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Add Charlie to force non-FF merge')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_merge('--skip-verification', 'feature')",
|
||||
Expected: []sql.Row{{commitHash, int64(0), int64(0), "merge successful"}},
|
||||
},
|
||||
{
|
||||
Query: "select * from dolt_test_run('*')",
|
||||
Expected: []sql.Row{
|
||||
{"test_will_fail", "unit", "SELECT COUNT(*) FROM users", "FAIL", "Assertion failed: expected_single_value equal to 999, got 3"},
|
||||
},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user