diff --git a/go/cmd/dolt/cli/arg_parser_helpers.go b/go/cmd/dolt/cli/arg_parser_helpers.go index 55245652bf..45183408bb 100644 --- a/go/cmd/dolt/cli/arg_parser_helpers.go +++ b/go/cmd/dolt/cli/arg_parser_helpers.go @@ -129,7 +129,9 @@ var mergeAbortDetails = `Abort the current conflict resolution process, and try If there were uncommitted working set changes present when the merge started, {{.EmphasisLeft}}dolt merge --abort{{.EmphasisRight}} will be unable to reconstruct these changes. It is therefore recommended to always commit or stash your changes before running dolt merge. ` -// Creates the argparser shared dolt commit cli and DOLT_COMMIT. +var branchForceFlagDesc = "Reset {{.LessThan}}branchname{{.GreaterThan}} to {{.LessThan}}startpoint{{.GreaterThan}}, even if {{.LessThan}}branchname{{.GreaterThan}} exists already. Without {{.EmphasisLeft}}-f{{.EmphasisRight}}, {{.EmphasisLeft}}dolt branch{{.EmphasisRight}} refuses to change an existing branch. In combination with {{.EmphasisLeft}}-d{{.EmphasisRight}} (or {{.EmphasisLeft}}--delete{{.EmphasisRight}}), allow deleting the branch irrespective of its merged status. In combination with -m (or {{.EmphasisLeft}}--move{{.EmphasisRight}}), allow renaming the branch even if the new branch name already exists, the same applies for {{.EmphasisLeft}}-c{{.EmphasisRight}} (or {{.EmphasisLeft}}--copy{{.EmphasisRight}})." + +// CreateCommitArgParser creates the argparser shared dolt commit cli and DOLT_COMMIT. func CreateCommitArgParser() *argparser.ArgParser { ap := argparser.NewArgParser() ap.SupportsString(MessageArg, "m", "msg", "Use the given {{.LessThan}}msg{{.GreaterThan}} as the commit message.") @@ -174,7 +176,7 @@ func CreatePushArgParser() *argparser.ArgParser { func CreateAddArgParser() *argparser.ArgParser { ap := argparser.NewArgParser() ap.ArgListHelp = append(ap.ArgListHelp, [2]string{"table", "Working table(s) to add to the list tables staged to be committed. The abbreviation '.' can be used to add all tables."}) - ap.SupportsFlag("all", "A", "Stages any and all changes (adds, deletes, and modifications).") + ap.SupportsFlag(AllFlag, "A", "Stages any and all changes (adds, deletes, and modifications).") return ap } @@ -228,7 +230,6 @@ func CreateCherryPickArgParser() *argparser.ArgParser { func CreateFetchArgParser() *argparser.ArgParser { ap := argparser.NewArgParser() - ap.SupportsFlag(ForceFlag, "f", "Update refs to remote branches with the current state of the remote, overwriting any conflicting history.") return ap } @@ -256,11 +257,12 @@ func CreatePullArgParser() *argparser.ArgParser { func CreateBranchArgParser() *argparser.ArgParser { ap := argparser.NewArgParser() - ap.SupportsFlag(ForceFlag, "f", "Ignores any foreign key warnings and proceeds with the commit.") + ap.SupportsFlag(ForceFlag, "f", branchForceFlagDesc) ap.SupportsFlag(CopyFlag, "c", "Create a copy of a branch.") ap.SupportsFlag(MoveFlag, "m", "Move/rename a branch") ap.SupportsFlag(DeleteFlag, "d", "Delete a branch. The branch must be fully merged in its upstream branch.") ap.SupportsFlag(DeleteForceFlag, "", "Shortcut for {{.EmphasisLeft}}--delete --force{{.EmphasisRight}}.") + ap.SupportsString(TrackFlag, "t", "", "When creating a new branch, set up 'upstream' configuration.") return ap } diff --git a/go/cmd/dolt/commands/backup.go b/go/cmd/dolt/commands/backup.go index 215d074956..18c30069b3 100644 --- a/go/cmd/dolt/commands/backup.go +++ b/go/cmd/dolt/commands/backup.go @@ -202,7 +202,7 @@ func printBackups(dEnv *env.DoltEnv, apr *argparser.ArgParseResults) errhand.Ver } for _, r := range backups { - if apr.Contains(verboseFlag) { + if apr.Contains(cli.VerboseFlag) { paramStr := make([]byte, 0) if len(r.Params) > 0 { paramStr, _ = json.Marshal(r.Params) diff --git a/go/cmd/dolt/commands/branch.go b/go/cmd/dolt/commands/branch.go index d2267fc02d..164c8fc7a3 100644 --- a/go/cmd/dolt/commands/branch.go +++ b/go/cmd/dolt/commands/branch.go @@ -34,8 +34,6 @@ import ( "github.com/dolthub/dolt/go/libraries/utils/set" ) -var branchForceFlagDesc = "Reset {{.LessThan}}branchname{{.GreaterThan}} to {{.LessThan}}startpoint{{.GreaterThan}}, even if {{.LessThan}}branchname{{.GreaterThan}} exists already. Without {{.EmphasisLeft}}-f{{.EmphasisRight}}, {{.EmphasisLeft}}dolt branch{{.EmphasisRight}} refuses to change an existing branch. In combination with {{.EmphasisLeft}}-d{{.EmphasisRight}} (or {{.EmphasisLeft}}--delete{{.EmphasisRight}}), allow deleting the branch irrespective of its merged status. In combination with -m (or {{.EmphasisLeft}}--move{{.EmphasisRight}}), allow renaming the branch even if the new branch name already exists, the same applies for {{.EmphasisLeft}}-c{{.EmphasisRight}} (or {{.EmphasisLeft}}--copy{{.EmphasisRight}})." - var branchDocs = cli.CommandDocumentationContent{ ShortDesc: `List, create, or delete branches`, LongDesc: `If {{.EmphasisLeft}}--list{{.EmphasisRight}} is given, or if there are no non-option arguments, existing branches are listed. The current branch will be highlighted with an asterisk. With no options, only local branches are listed. With {{.EmphasisLeft}}-r{{.EmphasisRight}}, only remote branches are listed. With {{.EmphasisLeft}}-a{{.EmphasisRight}} both local and remote branches are listed. {{.EmphasisLeft}}-v{{.EmphasisRight}} causes the hash of the commit that the branches are at to be printed as well. @@ -60,15 +58,7 @@ With a {{.EmphasisLeft}}-d{{.EmphasisRight}}, {{.LessThan}}branchname{{.GreaterT const ( listFlag = "list" - forceFlag = "force" - copyFlag = "copy" - moveFlag = "move" - deleteFlag = "delete" - deleteForceFlag = "D" - verboseFlag = cli.VerboseFlag - allFlag = "all" datasetsFlag = "datasets" - remoteFlag = "remote" showCurrentFlag = "show-current" ) @@ -92,20 +82,14 @@ func (cmd BranchCmd) Docs() *cli.CommandDocumentation { } func (cmd BranchCmd) ArgParser() *argparser.ArgParser { - ap := argparser.NewArgParser() + ap := cli.CreateBranchArgParser() ap.ArgListHelp = append(ap.ArgListHelp, [2]string{"start-point", "A commit that a new branch should point at."}) ap.SupportsFlag(listFlag, "", "List branches") - ap.SupportsFlag(forceFlag, "f", branchForceFlagDesc) - ap.SupportsFlag(copyFlag, "c", "Create a copy of a branch.") - ap.SupportsFlag(moveFlag, "m", "Move/rename a branch") - ap.SupportsFlag(deleteFlag, "d", "Delete a branch. The branch must be fully merged in its upstream branch.") - ap.SupportsFlag(deleteForceFlag, "", "Shortcut for {{.EmphasisLeft}}--delete --force{{.EmphasisRight}}.") - ap.SupportsFlag(verboseFlag, "v", "When in list mode, show the hash and commit subject line for each head") - ap.SupportsFlag(allFlag, "a", "When in list mode, shows remote tracked branches") + ap.SupportsFlag(cli.VerboseFlag, "v", "When in list mode, show the hash and commit subject line for each head") + ap.SupportsFlag(cli.AllFlag, "a", "When in list mode, shows remote tracked branches") ap.SupportsFlag(datasetsFlag, "", "List all datasets in the database") - ap.SupportsFlag(remoteFlag, "r", "When in list mode, show only remote tracked branches. When with -d, delete a remote tracking branch.") + ap.SupportsFlag(cli.RemoteParam, "r", "When in list mode, show only remote tracked branches. When with -d, delete a remote tracking branch.") ap.SupportsFlag(showCurrentFlag, "", "Print the name of the current branch") - ap.SupportsString(cli.TrackFlag, "t", "", "When creating a new branch, set up 'upstream' configuration.") return ap } @@ -121,13 +105,13 @@ func (cmd BranchCmd) Exec(ctx context.Context, commandStr string, args []string, apr := cli.ParseArgsOrDie(ap, args, help) switch { - case apr.Contains(moveFlag): + case apr.Contains(cli.MoveFlag): return moveBranch(ctx, dEnv, apr, usage) - case apr.Contains(copyFlag): + case apr.Contains(cli.CopyFlag): return copyBranch(ctx, dEnv, apr, usage) - case apr.Contains(deleteFlag): - return deleteBranches(ctx, dEnv, apr, usage, apr.Contains(forceFlag)) - case apr.Contains(deleteForceFlag): + case apr.Contains(cli.DeleteFlag): + return deleteBranches(ctx, dEnv, apr, usage, apr.Contains(cli.ForceFlag)) + case apr.Contains(cli.DeleteForceFlag): return deleteBranches(ctx, dEnv, apr, usage, true) case apr.Contains(listFlag): return printBranches(ctx, dEnv, apr, usage) @@ -145,9 +129,9 @@ func (cmd BranchCmd) Exec(ctx context.Context, commandStr string, args []string, func printBranches(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResults, _ cli.UsagePrinter) int { branchSet := set.NewStrSet(apr.Args) - verbose := apr.Contains(verboseFlag) + verbose := apr.Contains(cli.VerboseFlag) printRemote := apr.Contains(cli.RemoteParam) - printAll := apr.Contains(allFlag) + printAll := apr.Contains(cli.AllFlag) branches, err := dEnv.DoltDB.GetHeadRefs(ctx) @@ -186,7 +170,6 @@ func printBranches(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPar } else if branch.GetType() == ref.RemoteRefType { branchName = " " + color.RedString("remotes/"+branch.GetPath()) branchLen += len("remotes/") - } if verbose { @@ -260,7 +243,7 @@ func moveBranch(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseR return 1 } - force := apr.Contains(forceFlag) + force := apr.Contains(cli.ForceFlag) src := apr.Arg(0) dest := apr.Arg(1) err := actions.RenameBranch(ctx, dEnv.DbData(), src, apr.Arg(1), dEnv, force) @@ -290,7 +273,7 @@ func copyBranch(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseR return 1 } - force := apr.Contains(forceFlag) + force := apr.Contains(cli.ForceFlag) src := apr.Arg(0) dest := apr.Arg(1) err := actions.CopyBranch(ctx, dEnv, src, dest, force) @@ -323,7 +306,7 @@ func deleteBranches(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPa err := actions.DeleteBranch(ctx, dEnv.DbData(), brName, actions.DeleteOptions{ Force: force, - Remote: apr.Contains(remoteFlag), + Remote: apr.Contains(cli.RemoteParam), }, dEnv) if err != nil { @@ -373,9 +356,8 @@ func createBranch(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPars } if apr.NArg() == 2 { - newBranch = apr.Arg(0) - startPt = apr.Arg(1) - remote, remoteBranch = ParseRemoteBranchName(startPt) + // branchName and startPt are already set + remote, remoteBranch = actions.ParseRemoteBranchName(startPt) _, remoteOk := remotes[remote] if !remoteOk { return HandleVErrAndExitCode(errhand.BuildDError("'%s' is not a valid remote ref and a branch '%s' cannot be created from it", startPt, newBranch).Build(), usage) @@ -384,12 +366,12 @@ func createBranch(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPars // if track option is defined with no value, // the track value can either be starting point name OR branch name startPt = trackVal - remote, remoteBranch = ParseRemoteBranchName(startPt) + remote, remoteBranch = actions.ParseRemoteBranchName(startPt) _, remoteOk := remotes[remote] if !remoteOk { newBranch = trackVal startPt = apr.Arg(0) - remote, remoteBranch = ParseRemoteBranchName(startPt) + remote, remoteBranch = actions.ParseRemoteBranchName(startPt) _, remoteOk = remotes[remote] if !remoteOk { return HandleVErrAndExitCode(errhand.BuildDError("'%s' is not a valid remote ref and a branch '%s' cannot be created from it", startPt, newBranch).Build(), usage) @@ -398,7 +380,7 @@ func createBranch(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPars } } - err := actions.CreateBranchWithStartPt(ctx, dEnv.DbData(), newBranch, startPt, apr.Contains(forceFlag)) + err := actions.CreateBranchWithStartPt(ctx, dEnv.DbData(), newBranch, startPt, apr.Contains(cli.ForceFlag)) if err != nil { return HandleVErrAndExitCode(errhand.BuildDError(err.Error()).Build(), usage) } diff --git a/go/cmd/dolt/commands/checkout.go b/go/cmd/dolt/commands/checkout.go index 34a7b8b18e..05156436dc 100644 --- a/go/cmd/dolt/commands/checkout.go +++ b/go/cmd/dolt/commands/checkout.go @@ -17,7 +17,6 @@ package commands import ( "context" "fmt" - "strings" "github.com/dolthub/dolt/go/libraries/doltcore/ref" @@ -154,7 +153,7 @@ func checkoutNewBranch(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.Ar } else if trackVal == "inherit" { return errhand.VerboseErrorFromError(fmt.Errorf("--track='inherit' is not supported yet")) } - remoteName, remoteBranchName = ParseRemoteBranchName(startPt) + remoteName, remoteBranchName = actions.ParseRemoteBranchName(startPt) remotes, err := dEnv.RepoStateReader().GetRemotes() if err != nil { return errhand.BuildDError(err.Error()).Build() @@ -191,7 +190,7 @@ func checkoutNewBranch(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.Ar if err != nil { return nil } - remoteName, remoteBranchName = ParseRemoteBranchName(startPt) + remoteName, remoteBranchName = actions.ParseRemoteBranchName(startPt) _, remoteOk := remotes[remoteName] if !remoteOk { return nil @@ -326,15 +325,7 @@ func SetRemoteUpstreamForBranchRef(dEnv *env.DoltEnv, remote, remoteBranch strin return errhand.BuildDError(fmt.Errorf("%w: '%s'", err, remote).Error()).Build() } - src := refSpec.SrcRef(branchRef) - dest := refSpec.DestRef(src) - - err = dEnv.RepoStateWriter().UpdateBranch(branchRef.GetPath(), env.BranchConfig{ - Merge: ref.MarshalableRef{ - Ref: dest, - }, - Remote: remote, - }) + err = env.SetRemoteUpstreamForRefSpec(dEnv.RepoStateWriter(), refSpec, remote, branchRef) if err != nil { return errhand.BuildDError(err.Error()).Build() } @@ -348,12 +339,3 @@ func unreadableRootToVErr(err error) errhand.VerboseError { bdr := errhand.BuildDError("error: unable to read the %s", rt.String()) return bdr.AddCause(doltdb.GetUnreachableRootCause(err)).Build() } - -func ParseRemoteBranchName(startPt string) (string, string) { - startPt = strings.TrimPrefix(startPt, "remotes/") - names := strings.Split(startPt, "/") - if len(names) < 2 { - return "", "" - } - return names[0], strings.Join(names[1:], "/") -} diff --git a/go/cmd/dolt/commands/fetch.go b/go/cmd/dolt/commands/fetch.go index 74b2e1ca9a..9533833fe2 100644 --- a/go/cmd/dolt/commands/fetch.go +++ b/go/cmd/dolt/commands/fetch.go @@ -77,22 +77,14 @@ func (cmd FetchCmd) Exec(ctx context.Context, commandStr string, args []string, if err != nil { return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } - updateMode := ref.UpdateMode{Force: apr.Contains(cli.ForceFlag)} srcDB, err := r.GetRemoteDBWithoutCaching(ctx, dEnv.DbData().Ddb.ValueReadWriter().Format(), dEnv) if err != nil { return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } - err = actions.FetchRefSpecs(ctx, dEnv.DbData(), srcDB, refSpecs, r, updateMode, buildProgStarter(downloadLanguage), stopProgFuncs) - switch err { - case doltdb.ErrUpToDate: - return HandleVErrAndExitCode(nil, usage) - case actions.ErrCantFF: - verr := errhand.BuildDError("error: fetch failed, can't fast forward remote tracking ref").AddCause(err).Build() - return HandleVErrAndExitCode(verr, usage) - } - if err != nil { + err = actions.FetchRefSpecs(ctx, dEnv.DbData(), srcDB, refSpecs, r, ref.UpdateMode{Force: true}, buildProgStarter(downloadLanguage), stopProgFuncs) + if err != nil && err != doltdb.ErrUpToDate { return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } return HandleVErrAndExitCode(nil, usage) diff --git a/go/cmd/dolt/commands/filter-branch.go b/go/cmd/dolt/commands/filter-branch.go index 7cf08644f8..8e46a472f7 100644 --- a/go/cmd/dolt/commands/filter-branch.go +++ b/go/cmd/dolt/commands/filter-branch.go @@ -84,9 +84,9 @@ func (cmd FilterBranchCmd) Docs() *cli.CommandDocumentation { func (cmd FilterBranchCmd) ArgParser() *argparser.ArgParser { ap := argparser.NewArgParser() - ap.SupportsFlag(verboseFlag, "v", "logs more information") + ap.SupportsFlag(cli.VerboseFlag, "v", "logs more information") ap.SupportsFlag(branchesFlag, "b", "filter all branches") - ap.SupportsFlag(allFlag, "a", "filter all branches and tags") + ap.SupportsFlag(cli.AllFlag, "a", "filter all branches and tags") return ap } @@ -112,7 +112,7 @@ func (cmd FilterBranchCmd) Exec(ctx context.Context, commandStr string, args []s } query := apr.Arg(0) - verbose := apr.Contains(verboseFlag) + verbose := apr.Contains(cli.VerboseFlag) notFound := make(missingTbls) replay := func(ctx context.Context, commit, _, _ *doltdb.Commit) (*doltdb.RootValue, error) { @@ -160,7 +160,7 @@ func (cmd FilterBranchCmd) Exec(ctx context.Context, commandStr string, args []s switch { case apr.Contains(branchesFlag): err = rebase.AllBranches(ctx, dEnv, replay, nerf) - case apr.Contains(allFlag): + case apr.Contains(cli.AllFlag): err = rebase.AllBranchesAndTags(ctx, dEnv, replay, nerf) default: err = rebase.CurrentBranch(ctx, dEnv, replay, nerf) diff --git a/go/cmd/dolt/commands/ls.go b/go/cmd/dolt/commands/ls.go index fae751d76e..b1ea833756 100644 --- a/go/cmd/dolt/commands/ls.go +++ b/go/cmd/dolt/commands/ls.go @@ -66,9 +66,9 @@ func (cmd LsCmd) Docs() *cli.CommandDocumentation { func (cmd LsCmd) ArgParser() *argparser.ArgParser { ap := argparser.NewArgParser() - ap.SupportsFlag(verboseFlag, "v", "show the hash of the table") + ap.SupportsFlag(cli.VerboseFlag, "v", "show the hash of the table") ap.SupportsFlag(systemFlag, "s", "show system tables") - ap.SupportsFlag(allFlag, "a", "show system tables") + ap.SupportsFlag(cli.AllFlag, "a", "show system tables") return ap } @@ -83,8 +83,8 @@ func (cmd LsCmd) Exec(ctx context.Context, commandStr string, args []string, dEn help, usage := cli.HelpAndUsagePrinters(cli.CommandDocsForCommandString(commandStr, lsDocs, ap)) apr := cli.ParseArgsOrDie(ap, args, help) - if apr.Contains(systemFlag) && apr.Contains(allFlag) { - verr := errhand.BuildDError("--%s and --%s are mutually exclusive", systemFlag, allFlag).SetPrintUsage().Build() + if apr.Contains(systemFlag) && apr.Contains(cli.AllFlag) { + verr := errhand.BuildDError("--%s and --%s are mutually exclusive", systemFlag, cli.AllFlag).SetPrintUsage().Build() HandleVErrAndExitCode(verr, usage) } @@ -99,13 +99,13 @@ func (cmd LsCmd) Exec(ctx context.Context, commandStr string, args []string, dEn } if verr == nil { - if !apr.Contains(systemFlag) || apr.Contains(allFlag) { - verr = printUserTables(ctx, root, label, apr.Contains(verboseFlag)) + if !apr.Contains(systemFlag) || apr.Contains(cli.AllFlag) { + verr = printUserTables(ctx, root, label, apr.Contains(cli.VerboseFlag)) cli.Println() } - if verr == nil && (apr.Contains(systemFlag) || apr.Contains(allFlag)) { - verr = printSystemTables(ctx, root, dEnv.DoltDB, apr.Contains(verboseFlag)) + if verr == nil && (apr.Contains(systemFlag) || apr.Contains(cli.AllFlag)) { + verr = printSystemTables(ctx, root, dEnv.DoltDB, apr.Contains(cli.VerboseFlag)) cli.Println() } } diff --git a/go/cmd/dolt/commands/remote.go b/go/cmd/dolt/commands/remote.go index 14032bd08c..e9c6ecaff1 100644 --- a/go/cmd/dolt/commands/remote.go +++ b/go/cmd/dolt/commands/remote.go @@ -85,15 +85,11 @@ func (cmd RemoteCmd) Docs() *cli.CommandDocumentation { } func (cmd RemoteCmd) ArgParser() *argparser.ArgParser { - ap := argparser.NewArgParser() + ap := cli.CreateRemoteArgParser() ap.ArgListHelp = append(ap.ArgListHelp, [2]string{"region", "cloud provider region associated with this remote."}) ap.ArgListHelp = append(ap.ArgListHelp, [2]string{"creds-type", "credential type. Valid options are role, env, and file. See the help section for additional details."}) ap.ArgListHelp = append(ap.ArgListHelp, [2]string{"profile", "AWS profile to use."}) - ap.SupportsFlag(verboseFlag, "v", "When printing the list of remotes adds additional details.") - ap.SupportsString(dbfactory.AWSRegionParam, "", "region", "") - ap.SupportsValidatedString(dbfactory.AWSCredsTypeParam, "", "creds-type", "", argparser.ValidatorFromStrList(dbfactory.AWSCredsTypeParam, dbfactory.AWSCredTypes)) - ap.SupportsString(dbfactory.AWSCredsFileParam, "", "file", "AWS credentials file") - ap.SupportsString(dbfactory.AWSCredsProfile, "", "profile", "AWS profile to use") + ap.SupportsFlag(cli.VerboseFlag, "v", "When printing the list of remotes adds additional details.") ap.SupportsString(dbfactory.OSSCredsFileParam, "", "file", "OSS credentials file") ap.SupportsString(dbfactory.OSSCredsProfile, "", "profile", "OSS profile to use") return ap @@ -216,7 +212,7 @@ func printRemotes(dEnv *env.DoltEnv, apr *argparser.ArgParseResults) errhand.Ver } for _, r := range remotes { - if apr.Contains(verboseFlag) { + if apr.Contains(cli.VerboseFlag) { paramStr := make([]byte, 0) if len(r.Params) > 0 { paramStr, _ = json.Marshal(r.Params) diff --git a/go/cmd/dolt/commands/sql.go b/go/cmd/dolt/commands/sql.go index 841792d257..ea92c7112d 100644 --- a/go/cmd/dolt/commands/sql.go +++ b/go/cmd/dolt/commands/sql.go @@ -23,6 +23,7 @@ import ( "path/filepath" "strings" "syscall" + "time" "github.com/abiosoft/readline" "github.com/dolthub/go-mysql-server/sql" @@ -819,6 +820,8 @@ func runMultiStatementMode(ctx *sql.Context, se *engine.SqlEngine, input io.Read } } + // store start time for query + ctx.SetQueryTime(time.Now()) sqlSch, rowIter, err := processParsedQuery(ctx, query, se, sqlStatement) if err != nil { handleError(scanner.statementStartLine, query, err) diff --git a/go/cmd/dolt/commands/tblcmds/import.go b/go/cmd/dolt/commands/tblcmds/import.go index 556dbeee7b..db456b9f8d 100644 --- a/go/cmd/dolt/commands/tblcmds/import.go +++ b/go/cmd/dolt/commands/tblcmds/import.go @@ -68,6 +68,20 @@ const ( disableFkChecks = "disable-fk-checks" ) +var jsonInputFileHelp = "The expected JSON input file format is:" + ` + + { "rows": + [ + { + "column_name":"value" + ... + }, ... + ] + } + +where column_name is the name of a column of the table being imported and value is the data for that column in the table. +` + var importDocs = cli.CommandDocumentationContent{ ShortDesc: `Imports data into a dolt table`, LongDesc: `If {{.EmphasisLeft}}--create-table | -c{{.EmphasisRight}} is given the operation will create {{.LessThan}}table{{.GreaterThan}} and import the contents of file into it. If a table already exists at this location then the operation will fail, unless the {{.EmphasisLeft}}--force | -f{{.EmphasisRight}} flag is provided. The force flag forces the existing table to be overwritten. @@ -86,6 +100,8 @@ A mapping file can be used to map fields between the file being imported and the ` + schcmds.MappingFileHelp + ` +` + jsonInputFileHelp + + ` In create, update, and replace scenarios the file's extension is used to infer the type of the file. If a file does not have the expected extension then the {{.EmphasisLeft}}--file-type{{.EmphasisRight}} parameter should be used to explicitly define the format of the file in one of the supported formats (csv, psv, json, xlsx). For files separated by a delimiter other than a ',' (type csv) or a '|' (type psv), the --delim parameter can be used to specify a delimiter`, Synopsis: []string{ diff --git a/go/cmd/dolt/dolt.go b/go/cmd/dolt/dolt.go index d91bb50714..c87b7f0f7d 100644 --- a/go/cmd/dolt/dolt.go +++ b/go/cmd/dolt/dolt.go @@ -56,7 +56,7 @@ import ( ) const ( - Version = "0.52.19" + Version = "0.52.20" ) var dumpDocsCommand = &commands.DumpDocsCmd{} @@ -186,7 +186,7 @@ func runMain() int { cli.Println(cyanStar, " /trace: A trace of execution of the current program. You can specify the duration in the seconds GET parameter. After you get the trace file, use the go tool trace command to investigate the trace.") cli.Println() - err := http.ListenAndServe("localhost:6060", nil) + err := http.ListenAndServe("0.0.0.0:6060", nil) if err != nil { cli.Println(color.YellowString("pprof server exited with error: %v", err)) diff --git a/go/go.mod b/go/go.mod index 95323ec5c7..7096c84d80 100644 --- a/go/go.mod +++ b/go/go.mod @@ -15,7 +15,7 @@ require ( github.com/dolthub/fslock v0.0.3 github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20230201234433-864c7d109df8 + github.com/dolthub/vitess v0.0.0-20230210003150-3065f526d869 github.com/dustin/go-humanize v1.0.0 github.com/fatih/color v1.13.0 github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568 @@ -58,7 +58,7 @@ require ( github.com/cenkalti/backoff/v4 v4.1.3 github.com/cespare/xxhash v1.1.0 github.com/creasty/defaults v1.6.0 - github.com/dolthub/go-mysql-server v0.14.1-0.20230203182436-2dac5eaba602 + github.com/dolthub/go-mysql-server v0.14.1-0.20230210003917-ba6b4d6584b0 github.com/google/flatbuffers v2.0.6+incompatible github.com/kch42/buzhash v0.0.0-20160816060738-9bdec3dec7c6 github.com/mitchellh/go-ps v1.0.0 diff --git a/go/go.sum b/go/go.sum index 0dde6695d4..c3580f6a34 100644 --- a/go/go.sum +++ b/go/go.sum @@ -161,16 +161,16 @@ github.com/dolthub/flatbuffers v1.13.0-dh.1 h1:OWJdaPep22N52O/0xsUevxJ6Qfw1M2txC github.com/dolthub/flatbuffers v1.13.0-dh.1/go.mod h1:CorYGaDmXjHz1Z7i50PYXG1Ricn31GcA2wNOTFIQAKE= github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= -github.com/dolthub/go-mysql-server v0.14.1-0.20230203182436-2dac5eaba602 h1:2tKO9mNuquQNmpcqFm4YhE6vyQR5/2bepkxUZnW1X9w= -github.com/dolthub/go-mysql-server v0.14.1-0.20230203182436-2dac5eaba602/go.mod h1:aVtgxAf6Bfs0hCj+KzIH7Y1aAxg7/7FlslouCh94VVQ= +github.com/dolthub/go-mysql-server v0.14.1-0.20230210003917-ba6b4d6584b0 h1:Z8QzgtCJfgApWyrXTIyooASjoRrbBdOW24Im2JSE0Ro= +github.com/dolthub/go-mysql-server v0.14.1-0.20230210003917-ba6b4d6584b0/go.mod h1:3PGGtLcVPnJumgozqqAKZPae88QmvkOd1KGS+Z2/RXU= github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 h1:0HHu0GWJH0N6a6keStrHhUAK5/o9LVfkh44pvsV4514= github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488/go.mod h1:ehexgi1mPxRTk0Mok/pADALuHbvATulTh6gzr7NzZto= github.com/dolthub/jsonpath v0.0.0-20210609232853-d49537a30474 h1:xTrR+l5l+1Lfq0NvhiEsctylXinUMFhhsqaEcl414p8= github.com/dolthub/jsonpath v0.0.0-20210609232853-d49537a30474/go.mod h1:kMz7uXOXq4qRriCEyZ/LUeTqraLJCjf0WVZcUi6TxUY= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20230201234433-864c7d109df8 h1:h1DBe5+9JIArCVsBV14fA+RHDXWY8ynUheDL5ZVPOTg= -github.com/dolthub/vitess v0.0.0-20230201234433-864c7d109df8/go.mod h1:oVFIBdqMFEkt4Xz2fzFJBNtzKhDEjwdCF0dzde39iKs= +github.com/dolthub/vitess v0.0.0-20230210003150-3065f526d869 h1:RiSFAJqwBJmFbISgxWEdpljUak1uFtNCKG0zGT8xzA4= +github.com/dolthub/vitess v0.0.0-20230210003150-3065f526d869/go.mod h1:oVFIBdqMFEkt4Xz2fzFJBNtzKhDEjwdCF0dzde39iKs= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= diff --git a/go/libraries/doltcore/doltdb/doltdb.go b/go/libraries/doltcore/doltdb/doltdb.go index 197d2d2abb..a71f21729b 100644 --- a/go/libraries/doltcore/doltdb/doltdb.go +++ b/go/libraries/doltcore/doltdb/doltdb.go @@ -931,6 +931,7 @@ func (ddb *DoltDB) GetRefsOfType(ctx context.Context, refTypeFilter map[ref.RefT } // NewBranchAtCommit creates a new branch with HEAD at the commit given. Branch names must pass IsValidUserBranchName. +// Silently overwrites any existing branch with the same name given, if one exists. func (ddb *DoltDB) NewBranchAtCommit(ctx context.Context, branchRef ref.DoltRef, commit *Commit) error { if !IsValidBranchRef(branchRef) { panic(fmt.Sprintf("invalid branch name %s, use IsValidUserBranchName check", branchRef.String())) diff --git a/go/libraries/doltcore/doltdb/system_table.go b/go/libraries/doltcore/doltdb/system_table.go index 833db9ec75..b88a0a8fb0 100644 --- a/go/libraries/doltcore/doltdb/system_table.go +++ b/go/libraries/doltcore/doltdb/system_table.go @@ -155,6 +155,7 @@ var persistedSystemTables = []string{ var generatedSystemTables = []string{ BranchesTableName, + RemoteBranchesTableName, LogTableName, TableOfTablesInConflictName, TableOfTablesWithViolationsName, @@ -268,6 +269,9 @@ const ( // BranchesTableName is the branches system table name BranchesTableName = "dolt_branches" + // RemoteBranchesTableName is the all-branches system table name + RemoteBranchesTableName = "dolt_remote_branches" + // RemotesTableName is the remotes system table name RemotesTableName = "dolt_remotes" diff --git a/go/libraries/doltcore/env/actions/remotes.go b/go/libraries/doltcore/env/actions/remotes.go index d4b05fbdb7..7d0b0e5352 100644 --- a/go/libraries/doltcore/env/actions/remotes.go +++ b/go/libraries/doltcore/env/actions/remotes.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "strings" "sync" "github.com/dolthub/dolt/go/cmd/dolt/cli" @@ -494,3 +495,32 @@ func HandleInitRemoteStorageClientErr(name, url string, err error) error { var detail = fmt.Sprintf("the remote: %s '%s' could not be accessed", name, url) return fmt.Errorf("%w; %s; %s", ErrFailedToGetRemoteDb, detail, err.Error()) } + +// ParseRemoteBranchName takes remote branch ref name, parses it and returns remote branch name. +// For example, it parses the input string 'origin/john/mybranch' and returns remote name 'origin' and branch name 'john/mybranch'. +func ParseRemoteBranchName(startPt string) (string, string) { + startPt = strings.TrimPrefix(startPt, "remotes/") + names := strings.SplitN(startPt, "/", 2) + if len(names) < 2 { + return "", "" + } + return names[0], names[1] +} + +// GetRemoteBranchRef returns a remote ref with matching name for a branch for each remote. +func GetRemoteBranchRef(ctx context.Context, ddb *doltdb.DoltDB, name string) ([]ref.RemoteRef, error) { + remoteRefFilter := map[ref.RefType]struct{}{ref.RemoteRefType: {}} + refs, err := ddb.GetRefsOfType(ctx, remoteRefFilter) + if err != nil { + return nil, err + } + + var remoteRef []ref.RemoteRef + for _, rf := range refs { + if remRef, ok := rf.(ref.RemoteRef); ok && remRef.GetBranch() == name { + remoteRef = append(remoteRef, remRef) + } + } + + return remoteRef, nil +} diff --git a/go/libraries/doltcore/env/actions/table.go b/go/libraries/doltcore/env/actions/table.go index dd6bdae449..f17b70be77 100644 --- a/go/libraries/doltcore/env/actions/table.go +++ b/go/libraries/doltcore/env/actions/table.go @@ -17,8 +17,6 @@ package actions import ( "context" - "github.com/dolthub/dolt/go/libraries/doltcore/ref" - "github.com/dolthub/dolt/go/libraries/doltcore/diff" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/utils/set" @@ -120,21 +118,3 @@ func RemoveDocsTable(tbls []string) []string { } return result } - -// GetRemoteBranchRef returns a remote ref with matching name for a branch for each remotes. -func GetRemoteBranchRef(ctx context.Context, ddb *doltdb.DoltDB, name string) ([]ref.RemoteRef, error) { - remoteRefFilter := map[ref.RefType]struct{}{ref.RemoteRefType: {}} - refs, err := ddb.GetRefsOfType(ctx, remoteRefFilter) - if err != nil { - return nil, err - } - - var remoteRef []ref.RemoteRef - for _, rf := range refs { - if remRef, ok := rf.(ref.RemoteRef); ok && remRef.GetBranch() == name { - remoteRef = append(remoteRef, remRef) - } - } - - return remoteRef, nil -} diff --git a/go/libraries/doltcore/env/remotes.go b/go/libraries/doltcore/env/remotes.go index 7782970ec7..6008d2446e 100644 --- a/go/libraries/doltcore/env/remotes.go +++ b/go/libraries/doltcore/env/remotes.go @@ -554,3 +554,17 @@ func GetDefaultBranch(dEnv *DoltEnv, branches []ref.DoltRef) string { return branches[0].GetPath() } + +// SetRemoteUpstreamForRefSpec set upstream for given RefSpec, remote name and branch ref. It uses given RepoStateWriter +// to persist upstream tracking branch information. +func SetRemoteUpstreamForRefSpec(rsw RepoStateWriter, refSpec ref.RefSpec, remote string, branchRef ref.DoltRef) error { + src := refSpec.SrcRef(branchRef) + dest := refSpec.DestRef(src) + + return rsw.UpdateBranch(branchRef.GetPath(), BranchConfig{ + Merge: ref.MarshalableRef{ + Ref: dest, + }, + Remote: remote, + }) +} diff --git a/go/libraries/doltcore/merge/violations_fk_prolly.go b/go/libraries/doltcore/merge/violations_fk_prolly.go index a06dc72582..d1b4091e6f 100644 --- a/go/libraries/doltcore/merge/violations_fk_prolly.go +++ b/go/libraries/doltcore/merge/violations_fk_prolly.go @@ -146,39 +146,54 @@ func prollyChildSecDiffFkConstraintViolations( ctx context.Context, foreignKey doltdb.ForeignKey, postParent, postChild *constraintViolationsLoadedTable, - preChildScndryIdx prolly.Map, + preChildSecIdx prolly.Map, receiver FKViolationReceiver) error { postChildRowData := durable.ProllyMapFromIndex(postChild.RowData) - postChildScndryIdx := durable.ProllyMapFromIndex(postChild.IndexData) - parentScndryIdx := durable.ProllyMapFromIndex(postParent.IndexData) + postChildSecIdx := durable.ProllyMapFromIndex(postChild.IndexData) + parentSecIdx := durable.ProllyMapFromIndex(postParent.IndexData) - idxDesc, _ := parentScndryIdx.Descriptors() - partialDesc := idxDesc.PrefixDesc(len(foreignKey.TableColumns)) - partialKB := val.NewTupleBuilder(partialDesc) + parentSecIdxDesc, _ := parentSecIdx.Descriptors() + partialDesc := parentSecIdxDesc.PrefixDesc(len(foreignKey.TableColumns)) + childPriKD, _ := postChildRowData.Descriptors() + childPriKB := val.NewTupleBuilder(childPriKD) - primaryKD, _ := postChildRowData.Descriptors() - kb := val.NewTupleBuilder(primaryKD) - - err := prolly.DiffMaps(ctx, preChildScndryIdx, postChildScndryIdx, func(ctx context.Context, diff tree.Diff) error { + var parentSecIdxCur *tree.Cursor + err := prolly.DiffMaps(ctx, preChildSecIdx, postChildSecIdx, func(ctx context.Context, diff tree.Diff) error { switch diff.Type { case tree.AddedDiff, tree.ModifiedDiff: - k, v := val.Tuple(diff.Key), val.Tuple(diff.To) - partialKey, hasNulls := makePartialKey( - partialKB, - foreignKey.TableColumns, - postChild.Index, - postChild.IndexSchema, - k, - v, - preChildScndryIdx.Pool()) - if hasNulls { - return nil + k := val.Tuple(diff.Key) + // TODO: possible to skip this if there are not null constraints over entire index + for i := 0; i < k.Count(); i++ { + if k.FieldIsNull(i) { + return nil + } } - err := createCVIfNoPartialKeyMatchesSec(ctx, k, v, partialKey, partialDesc, primaryKD, kb, parentScndryIdx, postChildRowData, postChildRowData.Pool(), receiver) + if parentSecIdxCur == nil { + newCur, err := tree.NewCursorAtKey(ctx, parentSecIdx.NodeStore(), parentSecIdx.Node(), k, partialDesc) + if err != nil { + return err + } + if !newCur.Valid() { + return createCVForSecIdx(ctx, k, childPriKD, childPriKB, postChildRowData, postChildRowData.Pool(), receiver) + } + parentSecIdxCur = newCur + } + + err := tree.Seek(ctx, parentSecIdxCur, k, partialDesc) if err != nil { return err } + if !parentSecIdxCur.Valid() { + return createCVForSecIdx(ctx, k, childPriKD, childPriKB, postChildRowData, postChildRowData.Pool(), receiver) + } + + // possible that k is less than the smallest key in parentSecIdxCur, so still should compare + key := val.Tuple(parentSecIdxCur.CurrentKey()) + if partialDesc.Compare(k, key) != 0 { + return createCVForSecIdx(ctx, k, childPriKD, childPriKB, postChildRowData, postChildRowData.Pool(), receiver) + } + return nil case tree.RemovedDiff: default: panic("unhandled diff type") @@ -213,27 +228,14 @@ func createCVIfNoPartialKeyMatchesPri( return receiver.ProllyFKViolationFound(ctx, k, v) } -func createCVIfNoPartialKeyMatchesSec( +func createCVForSecIdx( ctx context.Context, - k, v, partialKey val.Tuple, - partialKeyDesc val.TupleDesc, + k val.Tuple, primaryKD val.TupleDesc, primaryKb *val.TupleBuilder, - idx prolly.Map, pri prolly.Map, pool pool.BuffPool, receiver FKViolationReceiver) error { - itr, err := creation.NewPrefixItr(ctx, partialKey, partialKeyDesc, idx) - if err != nil { - return err - } - _, _, err = itr.Next(ctx) - if err != nil && err != io.EOF { - return err - } - if err == nil { - return nil - } // convert secondary idx entry to primary row key // the pks of the table are the last keys of the index @@ -245,7 +247,7 @@ func createCVIfNoPartialKeyMatchesSec( primaryIdxKey := primaryKb.Build(pool) var value val.Tuple - err = pri.Get(ctx, primaryIdxKey, func(k, v val.Tuple) error { + err := pri.Get(ctx, primaryIdxKey, func(k, v val.Tuple) error { value = v return nil }) @@ -329,7 +331,8 @@ func createCVsForPartialKeyMatches( } func makePartialKey(kb *val.TupleBuilder, tags []uint64, idxSch schema.Index, tblSch schema.Schema, k, v val.Tuple, pool pool.BuffPool) (val.Tuple, bool) { - if idxSch.Name() != "" { + // Possible that the parent index (idxSch) is longer than the partial key (tags). + if idxSch.Name() != "" && len(idxSch.IndexedColumnTags()) <= len(tags) { tags = idxSch.IndexedColumnTags() } for i, tag := range tags { diff --git a/go/libraries/doltcore/ref/remote_ref.go b/go/libraries/doltcore/ref/remote_ref.go index ccd6d8f227..c6d74d2e66 100644 --- a/go/libraries/doltcore/ref/remote_ref.go +++ b/go/libraries/doltcore/ref/remote_ref.go @@ -57,7 +57,7 @@ func NewRemoteRef(remote, branch string) RemoteRef { return RemoteRef{remote, branch} } -// NewRemoteRefFromPathString creates a DoltRef from a string in the format origin/main, or remotes/origin/main, or +// NewRemoteRefFromPathStr creates a DoltRef from a string in the format origin/main, or remotes/origin/main, or // refs/remotes/origin/main func NewRemoteRefFromPathStr(remoteAndPath string) (DoltRef, error) { if IsRef(remoteAndPath) { diff --git a/go/libraries/doltcore/schema/collation_comparator.go b/go/libraries/doltcore/schema/collation_comparator.go index 30e202c4f4..fc229cab24 100644 --- a/go/libraries/doltcore/schema/collation_comparator.go +++ b/go/libraries/doltcore/schema/collation_comparator.go @@ -115,6 +115,33 @@ func collationCompare(typ val.Type, collation sql.CollationID, left, right []byt } func compareCollatedStrings(collation sql.CollationID, left, right []byte) int { + i := 0 + for i < len(left) && i < len(right) { + if left[i] != right[i] { + break + } + i++ + } + if i >= len(left) || i >= len(right) { + if len(left) < len(right) { + return -1 + } else if len(left) > len(right) { + return 1 + } else { + return 0 + } + } + + li := i + for ; li >= 0 && !utf8.RuneStart(left[li]); li-- { + } + left = left[li:] + + ri := i + for ; ri >= 0 && !utf8.RuneStart(right[ri]); ri-- { + } + right = right[ri:] + getRuneWeight := collation.Sorter() for len(left) > 0 && len(right) > 0 { // Binary strings aren't handled through this function, so it is safe to use the utf8 functions @@ -130,12 +157,14 @@ func compareCollatedStrings(collation sql.CollationID, left, right []byte) int { return 0 } } - leftWeight := getRuneWeight(leftRune) - rightWeight := getRuneWeight(rightRune) - if leftWeight < rightWeight { - return -1 - } else if leftWeight > rightWeight { - return 1 + if leftRune != rightRune { + leftWeight := getRuneWeight(leftRune) + rightWeight := getRuneWeight(rightRune) + if leftWeight < rightWeight { + return -1 + } else if leftWeight > rightWeight { + return 1 + } } left = left[leftRead:] right = right[rightRead:] diff --git a/go/libraries/doltcore/schema/collation_comparator_test.go b/go/libraries/doltcore/schema/collation_comparator_test.go new file mode 100644 index 0000000000..543baccb79 --- /dev/null +++ b/go/libraries/doltcore/schema/collation_comparator_test.go @@ -0,0 +1,65 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package schema + +import ( + "fmt" + "testing" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/stretchr/testify/require" +) + +func TestCompareCollatedStrings(t *testing.T) { + tests := []struct { + name string + left []byte + right []byte + exp int + }{ + { + left: []byte("Hello, 人"), + right: []byte("Hello, 亻"), + exp: -1, + }, + { + left: []byte("woÒ"), + right: []byte("woÓ"), + exp: 0, + }, + { + left: []byte("\u07FB"), + right: []byte("\u07FC"), + exp: -1, + }, + { + left: []byte("˧"), + right: []byte("˦"), + exp: 1, + }, + { + left: []byte("ƵƶzƸ"), + right: []byte("ƵƶzƷ"), + exp: 1, + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("%s vs %s", tt.left, tt.right), func(t *testing.T) { + cmp := compareCollatedStrings(sql.Collation_utf8mb4_0900_ai_ci, tt.left, tt.right) + require.Equal(t, tt.exp, cmp) + }) + } +} diff --git a/go/libraries/doltcore/sqle/database.go b/go/libraries/doltcore/sqle/database.go index 7fccb94d94..a3491b5a26 100644 --- a/go/libraries/doltcore/sqle/database.go +++ b/go/libraries/doltcore/sqle/database.go @@ -432,6 +432,8 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds dt, found = dtables.NewTableOfTablesConstraintViolations(ctx, root), true case doltdb.BranchesTableName: dt, found = dtables.NewBranchesTable(ctx, db.ddb), true + case doltdb.RemoteBranchesTableName: + dt, found = dtables.NewRemoteBranchesTable(ctx, db.ddb), true case doltdb.RemotesTableName: dt, found = dtables.NewRemotesTable(ctx, db.ddb), true case doltdb.CommitsTableName: diff --git a/go/libraries/doltcore/sqle/database_provider.go b/go/libraries/doltcore/sqle/database_provider.go index 09b160c6d0..aa495e2bd0 100644 --- a/go/libraries/doltcore/sqle/database_provider.go +++ b/go/libraries/doltcore/sqle/database_provider.go @@ -152,9 +152,8 @@ func (p DoltDatabaseProvider) FileSystem() filesys.Filesys { return p.fs } -// If this DatabaseProvider is set to standby |true|, it returns every dolt -// database as a read only database. Set back to |false| to get read-write -// behavior from dolt databases again. +// SetIsStandby sets whether this provider is set to standby |true|. Standbys return every dolt database as a read only +// database. Set back to |false| to get read-write behavior from dolt databases again. func (p DoltDatabaseProvider) SetIsStandby(standby bool) { p.mu.Lock() defer p.mu.Unlock() @@ -187,13 +186,16 @@ func (p DoltDatabaseProvider) Database(ctx *sql.Context, name string) (db sql.Da return wrapForStandby(db, standby), nil } + // Revision databases aren't tracked in the map, just instantiated on demand db, _, ok, err = p.databaseForRevision(ctx, name) if err != nil { return nil, err } + // A final check: if the database doesn't exist and this is a read replica, attempt to clone it from the remote if !ok { db, err = p.databaseForClone(ctx, name) + if err != nil { return nil, err } @@ -203,7 +205,6 @@ func (p DoltDatabaseProvider) Database(ctx *sql.Context, name string) (db sql.Da } } - // Don't track revision databases, just instantiate them on demand return wrapForStandby(db, standby), nil } @@ -263,7 +264,7 @@ func (p DoltDatabaseProvider) attemptCloneReplica(ctx *sql.Context, dbName strin func (p DoltDatabaseProvider) HasDatabase(ctx *sql.Context, name string) bool { _, err := p.Database(ctx, name) if err != nil && !sql.ErrDatabaseNotFound.Is(err) { - ctx.GetLogger().Errorf(err.Error()) + ctx.GetLogger().Warnf("Error getting database %s: %s", name, err.Error()) } return err == nil } @@ -725,7 +726,7 @@ func (p DoltDatabaseProvider) databaseForRevision(ctx *sql.Context, revDB string return nil, dsess.InitialDbState{}, false, err } - caseSensitiveBranchName, isBranch, err := isBranch(ctx, srcDb, resolvedRevSpec, p.remoteDialer) + caseSensitiveBranchName, isBranch, err := isBranch(ctx, srcDb, resolvedRevSpec) if err != nil { return nil, dsess.InitialDbState{}, false, err } @@ -733,8 +734,9 @@ func (p DoltDatabaseProvider) databaseForRevision(ctx *sql.Context, revDB string if isBranch { // fetch the upstream head if this is a replicated db if replicaDb, ok := srcDb.(ReadReplicaDatabase); ok { - // TODO move this out of analysis phase, should only happen at read time - err := switchAndFetchReplicaHead(ctx, resolvedRevSpec, replicaDb) + // TODO move this out of analysis phase, should only happen at read time, when the transaction begins (like is + // the case with a branch that already exists locally) + err := p.ensureReplicaHeadExists(ctx, resolvedRevSpec, replicaDb) if err != nil { return nil, dsess.InitialDbState{}, false, err } @@ -751,7 +753,7 @@ func (p DoltDatabaseProvider) databaseForRevision(ctx *sql.Context, revDB string return db, init, true, nil } - isTag, err := isTag(ctx, srcDb, resolvedRevSpec, p.remoteDialer) + isTag, err := isTag(ctx, srcDb, resolvedRevSpec) if err != nil { return nil, dsess.InitialDbState{}, false, err } @@ -759,6 +761,7 @@ func (p DoltDatabaseProvider) databaseForRevision(ctx *sql.Context, revDB string if isTag { // TODO: this should be an interface, not a struct replicaDb, ok := srcDb.(ReadReplicaDatabase) + if ok { srcDb = replicaDb.Database } @@ -981,69 +984,14 @@ func (p DoltDatabaseProvider) IsRevisionDatabase(ctx *sql.Context, dbName string return revision != "", nil } -// switchAndFetchReplicaHead tries to pull the latest version of a branch. Will fail if the branch -// does not exist on the ReadReplicaDatabase's remote. If the target branch is not a replication -// head, the new branch will not be continuously fetched. -func switchAndFetchReplicaHead(ctx *sql.Context, branch string, db ReadReplicaDatabase) error { - branchRef := ref.NewBranchRef(branch) - - var branchExists bool - branches, err := db.ddb.GetBranches(ctx) - if err != nil { - return err - } - - for _, br := range branches { - if br.String() == branch { - branchExists = true - break - } - } - - // check whether branch is on remote before creating local tracking branch - cm, err := actions.FetchRemoteBranch(ctx, db.tmpDir, db.remote, db.srcDB, db.DbData().Ddb, branchRef, actions.NoopRunProgFuncs, actions.NoopStopProgFuncs) - if err != nil { - return err - } - - cmHash, err := cm.HashOf() - if err != nil { - return err - } - - // create refs/heads/branch dataset - if !branchExists { - err = db.ddb.NewBranchAtCommit(ctx, branchRef, cm) - if err != nil { - return err - } - } - - dSess := dsess.DSessFromSess(ctx.Session) - currentBranchRef, err := dSess.CWBHeadRef(ctx, db.name) - if err != nil { - return err - } - - // create workingSets/heads/branch and update the working set - err = db.RebaseSourceDb(ctx) - if err != nil { - return err - } - - err = pullBranches(ctx, db, []doltdb.RefWithHash{{ - Ref: branchRef, - Hash: cmHash, - }}, nil, currentBranchRef, pullBehavior_fastForward) - if err != nil { - return err - } - - return nil +// ensureReplicaHeadExists tries to pull the latest version of a remote branch. Will fail if the branch +// does not exist on the ReadReplicaDatabase's remote. +func (p DoltDatabaseProvider) ensureReplicaHeadExists(ctx *sql.Context, branch string, db ReadReplicaDatabase) error { + return db.CreateLocalBranchFromRemote(ctx, ref.NewBranchRef(branch)) } // isBranch returns whether a branch with the given name is in scope for the database given -func isBranch(ctx context.Context, db SqlDatabase, branchName string, dialer dbfactory.GRPCDialProvider) (string, bool, error) { +func isBranch(ctx context.Context, db SqlDatabase, branchName string) (string, bool, error) { var ddbs []*doltdb.DoltDB if rdb, ok := db.(ReadReplicaDatabase); ok { @@ -1062,7 +1010,7 @@ func isBranch(ctx context.Context, db SqlDatabase, branchName string, dialer dbf return brName, true, nil } - brName, branchExists, err = isRemoteBranch(ctx, db, ddbs, branchName) + brName, branchExists, err = isRemoteBranch(ctx, ddbs, branchName) if err != nil { return "", false, err } @@ -1088,21 +1036,15 @@ func isLocalBranch(ctx context.Context, ddbs []*doltdb.DoltDB, branchName string return "", false, nil } -// isRemoteBranch is called when the branch in connection string is not available as a local branch, so it searches -// for a remote tracking branch. If there is only one match, it creates a new local branch from the remote tracking -// branch and sets its upstream to it. -func isRemoteBranch(ctx context.Context, srcDB SqlDatabase, ddbs []*doltdb.DoltDB, branchName string) (string, bool, error) { +// isRemoteBranch returns whether the given branch name is a remote branch on any of the databases provided. +func isRemoteBranch(ctx context.Context, ddbs []*doltdb.DoltDB, branchName string) (string, bool, error) { for _, ddb := range ddbs { - bn, branchExists, remoteRef, err := ddb.HasRemoteTrackingBranch(ctx, branchName) + bn, branchExists, _, err := ddb.HasRemoteTrackingBranch(ctx, branchName) if err != nil { return "", false, err } if branchExists { - err = createLocalBranchFromRemoteTrackingBranch(ctx, srcDB.DbData(), ddb, branchName, remoteRef) - if err != nil { - return "", false, err - } return bn, true, nil } } @@ -1110,36 +1052,8 @@ func isRemoteBranch(ctx context.Context, srcDB SqlDatabase, ddbs []*doltdb.DoltD return "", false, nil } -// createLocalBranchFromRemoteTrackingBranch creates a new local branch from given remote tracking branch -// and sets its upstream to it. -func createLocalBranchFromRemoteTrackingBranch(ctx context.Context, dbData env.DbData, ddb *doltdb.DoltDB, branchName string, remoteRef ref.RemoteRef) error { - startPt := remoteRef.GetPath() - err := actions.CreateBranchOnDB(ctx, ddb, branchName, startPt, false, remoteRef) - if err != nil { - return err - } - - // at this point the branch is created on db - branchRef := ref.NewBranchRef(branchName) - remote := remoteRef.GetRemote() - refSpec, err := ref.ParseRefSpecForRemote(remote, remoteRef.GetBranch()) - if err != nil { - return fmt.Errorf("%w: '%s'", err, remote) - } - - src := refSpec.SrcRef(branchRef) - dest := refSpec.DestRef(src) - - return dbData.Rsw.UpdateBranch(branchRef.GetPath(), env.BranchConfig{ - Merge: ref.MarshalableRef{ - Ref: dest, - }, - Remote: remote, - }) -} - // isTag returns whether a tag with the given name is in scope for the database given -func isTag(ctx context.Context, db SqlDatabase, tagName string, dialer dbfactory.GRPCDialProvider) (bool, error) { +func isTag(ctx context.Context, db SqlDatabase, tagName string) (bool, error) { var ddbs []*doltdb.DoltDB if rdb, ok := db.(ReadReplicaDatabase); ok { diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_branch.go b/go/libraries/doltcore/sqle/dprocedures/dolt_branch.go index 566a56c6f9..397fbcfda4 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_branch.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_branch.go @@ -257,26 +257,77 @@ func loadConfig(ctx *sql.Context) *env.DoltCliConfig { } func createNewBranch(ctx *sql.Context, dbData env.DbData, apr *argparser.ArgParseResults) error { - var branchName string + if apr.NArg() == 0 || apr.NArg() > 2 { + return InvalidArgErr + } + + var branchName = apr.Arg(0) var startPt = "HEAD" - if apr.NArg() == 1 { - branchName = apr.Arg(0) - } else if apr.NArg() == 2 { - branchName = apr.Arg(0) + if len(branchName) == 0 { + return EmptyBranchNameErr + } + if apr.NArg() == 2 { startPt = apr.Arg(1) if len(startPt) == 0 { return InvalidArgErr } } - if len(branchName) == 0 { - return EmptyBranchNameErr + var remoteName, remoteBranch string + var refSpec ref.RefSpec + var err error + trackVal, setTrackUpstream := apr.GetValue(cli.TrackFlag) + if setTrackUpstream { + if trackVal == "inherit" { + return fmt.Errorf("--track='inherit' is not supported yet") + } else if trackVal == "direct" && apr.NArg() != 2 { + return InvalidArgErr + } + + if apr.NArg() == 2 { + // branchName and startPt are already set + remoteName, remoteBranch = actions.ParseRemoteBranchName(startPt) + refSpec, err = ref.ParseRefSpecForRemote(remoteName, remoteBranch) + if err != nil { + return err + } + } else { + // if track option is defined with no value, + // the track value can either be starting point name OR branch name + startPt = trackVal + remoteName, remoteBranch = actions.ParseRemoteBranchName(startPt) + refSpec, err = ref.ParseRefSpecForRemote(remoteName, remoteBranch) + if err != nil { + branchName = trackVal + startPt = apr.Arg(0) + remoteName, remoteBranch = actions.ParseRemoteBranchName(startPt) + refSpec, err = ref.ParseRefSpecForRemote(remoteName, remoteBranch) + if err != nil { + return err + } + } + } } - if err := branch_control.CanCreateBranch(ctx, branchName); err != nil { + err = branch_control.CanCreateBranch(ctx, branchName) + if err != nil { return err } - return actions.CreateBranchWithStartPt(ctx, dbData, branchName, startPt, apr.Contains(cli.ForceFlag)) + + err = actions.CreateBranchWithStartPt(ctx, dbData, branchName, startPt, apr.Contains(cli.ForceFlag)) + if err != nil { + return err + } + + if setTrackUpstream { + // at this point new branch is created + err = env.SetRemoteUpstreamForRefSpec(dbData.Rsw, refSpec, remoteName, ref.NewBranchRef(branchName)) + if err != nil { + return err + } + } + + return nil } func copyBranch(ctx *sql.Context, dbData env.DbData, apr *argparser.ArgParseResults) error { diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_checkout.go b/go/libraries/doltcore/sqle/dprocedures/dolt_checkout.go index 52bc674401..ba7851c1cf 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_checkout.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_checkout.go @@ -208,15 +208,7 @@ func checkoutRemoteBranch(ctx *sql.Context, dbName string, dbData env.DbData, br return errhand.BuildDError(fmt.Errorf("%w: '%s'", err, remoteRef.GetRemote()).Error()).Build() } - src := refSpec.SrcRef(dbData.Rsr.CWBHeadRef()) - dest := refSpec.DestRef(src) - - return dbData.Rsw.UpdateBranch(src.GetPath(), env.BranchConfig{ - Merge: ref.MarshalableRef{ - Ref: dest, - }, - Remote: remoteRef.GetRemote(), - }) + return env.SetRemoteUpstreamForRefSpec(dbData.Rsw, refSpec, remoteRef.GetRemote(), dbData.Rsr.CWBHeadRef()) } else { return fmt.Errorf("'%s' matched multiple (%v) remote tracking branches", branchName, len(remoteRefs)) } diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_fetch.go b/go/libraries/doltcore/sqle/dprocedures/dolt_fetch.go index a47a9176fa..24ff6a4e6f 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_fetch.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_fetch.go @@ -62,14 +62,12 @@ func doDoltFetch(ctx *sql.Context, args []string) (int, error) { return cmdFailure, err } - updateMode := ref.UpdateMode{Force: apr.Contains(cli.ForceFlag)} - srcDB, err := sess.Provider().GetRemoteDB(ctx, dbData.Ddb.ValueReadWriter().Format(), remote, false) if err != nil { return 1, err } - err = actions.FetchRefSpecs(ctx, dbData, srcDB, refSpecs, remote, updateMode, runProgFuncs, stopProgFuncs) + err = actions.FetchRefSpecs(ctx, dbData, srcDB, refSpecs, remote, ref.UpdateMode{Force: true}, runProgFuncs, stopProgFuncs) if err != nil { return cmdFailure, fmt.Errorf("fetch failed: %w", err) } diff --git a/go/libraries/doltcore/sqle/dtables/branches_table.go b/go/libraries/doltcore/sqle/dtables/branches_table.go index e847620d55..8149bbd41e 100644 --- a/go/libraries/doltcore/sqle/dtables/branches_table.go +++ b/go/libraries/doltcore/sqle/dtables/branches_table.go @@ -22,6 +22,7 @@ import ( "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/libraries/doltcore/ref" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/index" ) @@ -31,37 +32,54 @@ var _ sql.DeletableTable = (*BranchesTable)(nil) var _ sql.InsertableTable = (*BranchesTable)(nil) var _ sql.ReplaceableTable = (*BranchesTable)(nil) -// BranchesTable is a sql.Table implementation that implements a system table which shows the dolt branches +// BranchesTable is the system table that accesses branches type BranchesTable struct { - ddb *doltdb.DoltDB + ddb *doltdb.DoltDB + remote bool } // NewBranchesTable creates a BranchesTable func NewBranchesTable(_ *sql.Context, ddb *doltdb.DoltDB) sql.Table { - return &BranchesTable{ddb} + return &BranchesTable{ddb, false} +} + +// NewRemoteBranchesTable creates a BranchesTable with only remote refs +func NewRemoteBranchesTable(_ *sql.Context, ddb *doltdb.DoltDB) sql.Table { + return &BranchesTable{ddb, true} } // Name is a sql.Table interface function which returns the name of the table which is defined by the constant // BranchesTableName func (bt *BranchesTable) Name() string { + if bt.remote { + return doltdb.RemoteBranchesTableName + } return doltdb.BranchesTableName } // String is a sql.Table interface function which returns the name of the table which is defined by the constant // BranchesTableName func (bt *BranchesTable) String() string { + if bt.remote { + return doltdb.RemoteBranchesTableName + } return doltdb.BranchesTableName } // Schema is a sql.Table interface function that gets the sql.Schema of the branches system table func (bt *BranchesTable) Schema() sql.Schema { + tableName := doltdb.BranchesTableName + if bt.remote { + tableName = doltdb.RemoteBranchesTableName + } + return []*sql.Column{ - {Name: "name", Type: types.Text, Source: doltdb.BranchesTableName, PrimaryKey: true, Nullable: false}, - {Name: "hash", Type: types.Text, Source: doltdb.BranchesTableName, PrimaryKey: false, Nullable: false}, - {Name: "latest_committer", Type: types.Text, Source: doltdb.BranchesTableName, PrimaryKey: false, Nullable: true}, - {Name: "latest_committer_email", Type: types.Text, Source: doltdb.BranchesTableName, PrimaryKey: false, Nullable: true}, - {Name: "latest_commit_date", Type: types.Datetime, Source: doltdb.BranchesTableName, PrimaryKey: false, Nullable: true}, - {Name: "latest_commit_message", Type: types.Text, Source: doltdb.BranchesTableName, PrimaryKey: false, Nullable: true}, + {Name: "name", Type: types.Text, Source: tableName, PrimaryKey: true, Nullable: false}, + {Name: "hash", Type: types.Text, Source: tableName, PrimaryKey: false, Nullable: false}, + {Name: "latest_committer", Type: types.Text, Source: tableName, PrimaryKey: false, Nullable: true}, + {Name: "latest_committer_email", Type: types.Text, Source: tableName, PrimaryKey: false, Nullable: true}, + {Name: "latest_commit_date", Type: types.Datetime, Source: tableName, PrimaryKey: false, Nullable: true}, + {Name: "latest_commit_message", Type: types.Text, Source: tableName, PrimaryKey: false, Nullable: true}, } } @@ -77,7 +95,7 @@ func (bt *BranchesTable) Partitions(*sql.Context) (sql.PartitionIter, error) { // PartitionRows is a sql.Table interface function that gets a row iterator for a partition func (bt *BranchesTable) PartitionRows(sqlCtx *sql.Context, part sql.Partition) (sql.RowIter, error) { - return NewBranchItr(sqlCtx, bt.ddb) + return NewBranchItr(sqlCtx, bt.ddb, bt.remote) } // BranchItr is a sql.RowItr implementation which iterates over each commit as if it's a row in the table. @@ -88,23 +106,37 @@ type BranchItr struct { } // NewBranchItr creates a BranchItr from the current environment. -func NewBranchItr(sqlCtx *sql.Context, ddb *doltdb.DoltDB) (*BranchItr, error) { - branches, err := ddb.GetBranches(sqlCtx) +func NewBranchItr(ctx *sql.Context, ddb *doltdb.DoltDB, remote bool) (*BranchItr, error) { + var branchRefs []ref.DoltRef + var err error - if err != nil { - return nil, err + if remote { + branchRefs, err = ddb.GetRefsOfType(ctx, map[ref.RefType]struct{}{ref.RemoteRefType: {}}) + if err != nil { + return nil, err + } + } else { + branchRefs, err = ddb.GetBranches(ctx) + if err != nil { + return nil, err + } } - branchNames := make([]string, len(branches)) - commits := make([]*doltdb.Commit, len(branches)) - for i, branch := range branches { - commit, err := ddb.ResolveCommitRef(sqlCtx, branch) + branchNames := make([]string, len(branchRefs)) + commits := make([]*doltdb.Commit, len(branchRefs)) + for i, branch := range branchRefs { + commit, err := ddb.ResolveCommitRef(ctx, branch) if err != nil { return nil, err } - branchNames[i] = branch.GetPath() + if branch.GetType() == ref.RemoteRefType { + branchNames[i] = "remotes/" + branch.GetPath() + } else { + branchNames[i] = branch.GetPath() + } + commits[i] = commit } diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_queries_merge.go b/go/libraries/doltcore/sqle/enginetest/dolt_queries_merge.go index 28857111dd..c68d57e548 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_queries_merge.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_queries_merge.go @@ -1307,6 +1307,46 @@ var Dolt1MergeScripts = []queries.ScriptTest{ }, }, }, + { + Name: "parent index is longer than child index", + SetUpScript: []string{ + "create table parent (i int primary key, x int, y int, z int, index (y, x, z));", + "create table child (y int, x int, primary key(y, x), foreign key (y, x) references parent(y, x));", + "insert into parent values (100,1,1,1), (200,2,1,2), (300,1,null,1);", + "CALL DOLT_ADD('.')", + "CALL DOLT_COMMIT('-am', 'setup');", + "CALL DOLT_BRANCH('other');", + + "DELETE from parent WHERE x = 2;", + "CALL DOLT_COMMIT('-am', 'main');", + + "CALL DOLT_CHECKOUT('other');", + "INSERT INTO child VALUES (1, 2);", + "CALL DOLT_COMMIT('-am', 'other');", + + "CALL DOLT_CHECKOUT('main');", + "set DOLT_FORCE_TRANSACTION_COMMIT = on;", + "CALL DOLT_MERGE('other');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "SELECT * from dolt_constraint_violations", + Expected: []sql.Row{ + {"child", uint64(1)}, + }, + }, + { + Query: "SELECT * from dolt_constraint_violations_parent", + Expected: []sql.Row{}, + }, + { + Query: "SELECT y, x from dolt_constraint_violations_child", + Expected: []sql.Row{ + {1, 2}, + }, + }, + }, + }, } var KeylessMergeCVsAndConflictsScripts = []queries.ScriptTest{ @@ -3212,6 +3252,38 @@ var DoltVerifyConstraintsTestScripts = []queries.ScriptTest{ }, }, }, + { + Name: "verify-constraints: Stored Procedure ignores null", + SetUpScript: []string{ + "create table parent (id bigint primary key, v1 bigint, v2 bigint, index (v1, v2))", + "create table child (id bigint primary key, v1 bigint, v2 bigint, foreign key (v1, v2) references parent(v1, v2))", + "insert into parent values (1, 1, 1), (2, 2, 2)", + "insert into child values (1, 1, 1), (2, 90, NULL)", + "set dolt_force_transaction_commit = 1;", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL DOLT_VERIFY_CONSTRAINTS('child')", + Expected: []sql.Row{{0}}, + }, + { + Query: "set foreign_key_checks = 0;", + SkipResultsCheck: true, + }, + { + Query: "insert into child values (3, 30, 30);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + Query: "set foreign_key_checks = 1;", + SkipResultsCheck: true, + }, + { + Query: "CALL DOLT_VERIFY_CONSTRAINTS('child')", + Expected: []sql.Row{{1}}, + }, + }, + }, } var errTmplNoAutomaticMerge = "table %s can't be automatically merged.\nTo merge this table, make the schema on the source and target branch equal." diff --git a/go/libraries/doltcore/sqle/read_replica_database.go b/go/libraries/doltcore/sqle/read_replica_database.go index 19c4688528..78eab60ca7 100644 --- a/go/libraries/doltcore/sqle/read_replica_database.go +++ b/go/libraries/doltcore/sqle/read_replica_database.go @@ -25,6 +25,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/env" + "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" "github.com/dolthub/dolt/go/libraries/doltcore/ref" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" "github.com/dolthub/dolt/go/store/datas" @@ -183,7 +184,7 @@ func (rrd ReadReplicaDatabase) PullFromRemote(ctx *sql.Context) error { } remoteRefs = prunedRefs - err = pullBranches(ctx, rrd, remoteRefs, localRefs, currentBranchRef, behavior) + err = pullBranchesAndUpdateWorkingSet(ctx, rrd, remoteRefs, localRefs, currentBranchRef, behavior) if err != nil && !dsess.IgnoreReplicationErrors() { return err @@ -193,7 +194,7 @@ func (rrd ReadReplicaDatabase) PullFromRemote(ctx *sql.Context) error { } case allHeads == int8(1): - err = pullBranches(ctx, rrd, remoteRefs, localRefs, currentBranchRef, behavior) + err = pullBranchesAndUpdateWorkingSet(ctx, rrd, remoteRefs, localRefs, currentBranchRef, behavior) if err != nil && !dsess.IgnoreReplicationErrors() { return err } else if err != nil { @@ -215,8 +216,55 @@ func (rrd ReadReplicaDatabase) PullFromRemote(ctx *sql.Context) error { return nil } -func (rrd ReadReplicaDatabase) RebaseSourceDb(ctx *sql.Context) error { - return rrd.srcDB.Rebase(ctx) +// CreateLocalBranchFromRemote pulls the given branch from the remote database and creates a local tracking branch for +// it. This is only used for initializing a new local branch being pulled from a remote during connection +// initialization, and doesn't do the full work of remote synchronization that happens on transaction start. +func (rrd ReadReplicaDatabase) CreateLocalBranchFromRemote(ctx *sql.Context, branchRef ref.BranchRef) error { + _, err := rrd.limiter.Run(ctx, "pullNewBranch", func() (any, error) { + // because several clients can queue up waiting to create the same local branch, double check to see if this + // work was already done and bail early if so + _, branchExists, err := rrd.ddb.HasBranch(ctx, branchRef.GetPath()) + if err != nil { + return nil, err + } + + if branchExists { + return nil, nil + } + + cm, err := actions.FetchRemoteBranch(ctx, rrd.tmpDir, rrd.remote, rrd.srcDB, rrd.ddb, branchRef, actions.NoopRunProgFuncs, actions.NoopStopProgFuncs) + if err != nil { + return nil, err + } + + cmHash, err := cm.HashOf() + if err != nil { + return nil, err + } + + // create refs/heads/branch dataset + err = rrd.ddb.NewBranchAtCommit(ctx, branchRef, cm) + if err != nil { + return nil, err + } + + err = rrd.srcDB.Rebase(ctx) + if err != nil { + return nil, err + } + + _, err = pullBranches(ctx, rrd, []doltdb.RefWithHash{{ + Ref: branchRef, + Hash: cmHash, + }}, nil, pullBehavior_fastForward) + if err != nil { + return nil, err + } + + return nil, err + }) + + return err } type pullBehavior bool @@ -224,9 +272,10 @@ type pullBehavior bool const pullBehavior_fastForward pullBehavior = false const pullBehavior_forcePull pullBehavior = true -// pullBranches pulls the remote branches named. If a corresponding local branch exists, it will be fast-forwarded. If -// it doesn't exist, it will be created. -func pullBranches( +// pullBranchesAndUpdateWorkingSet pulls the remote branches named. If a corresponding local branch exists, it will be +// fast-forwarded. If it doesn't exist, it will be created. Afterward, the working set of the current branch is +// updated if the current branch ref was updated by the pull. +func pullBranchesAndUpdateWorkingSet( ctx *sql.Context, rrd ReadReplicaDatabase, remoteRefs []doltdb.RefWithHash, @@ -234,68 +283,8 @@ func pullBranches( currentBranchRef ref.DoltRef, behavior pullBehavior, ) error { - localRefsByPath := make(map[string]doltdb.RefWithHash) - remoteRefsByPath := make(map[string]doltdb.RefWithHash) - remoteHashes := make([]hash.Hash, len(remoteRefs)) - for i, b := range remoteRefs { - remoteRefsByPath[b.Ref.GetPath()] = b - remoteHashes[i] = b.Hash - } - - for _, b := range localRefs { - localRefsByPath[b.Ref.GetPath()] = b - } - - // XXX: Our view of which remote branches to pull and what to set the - // local branches to was computed outside of the limiter, concurrently - // with other possible attempts to pull from the remote. Now we are - // applying changes based on that view. This seems capable of rolling - // back changes which were applied from another thread. - - _, err := rrd.limiter.Run(ctx, "-all", func() (any, error) { - err := rrd.ddb.PullChunks(ctx, rrd.tmpDir, rrd.srcDB, remoteHashes, nil) - - for _, remoteRef := range remoteRefs { - localRef, localRefExists := localRefsByPath[remoteRef.Ref.GetPath()] - switch { - case err != nil: - case localRefExists: - // TODO: this should work for workspaces too but doesn't, only branches - if localRef.Ref.GetType() == ref.BranchRefType { - if localRef.Hash != remoteRef.Hash { - if behavior == pullBehavior_forcePull { - err = rrd.ddb.SetHead(ctx, remoteRef.Ref, remoteRef.Hash) - if err != nil { - return nil, err - } - } else { - err = rrd.ddb.FastForwardToHash(ctx, remoteRef.Ref, remoteRef.Hash) - if err != nil { - return nil, err - } - } - } - } - default: - switch remoteRef.Ref.GetType() { - case ref.BranchRefType: - err = rrd.ddb.SetHead(ctx, remoteRef.Ref, remoteRef.Hash) - if err != nil { - return nil, err - } - case ref.TagRefType: - err = rrd.ddb.SetHead(ctx, remoteRef.Ref, remoteRef.Hash) - if err != nil { - return nil, err - } - default: - ctx.GetLogger().Warnf("skipping replication for unhandled remote ref %s", remoteRef.Ref.String()) - } - } - } - return nil, nil - }) + remoteRefsByPath, err := pullBranches(ctx, rrd, remoteRefs, localRefs, behavior) if err != nil { return err } @@ -361,6 +350,142 @@ func pullBranches( return nil } +// pullBranches pulls the remote branches named and returns the map of their hashes keyed by branch path. +func pullBranches( + ctx *sql.Context, + rrd ReadReplicaDatabase, + remoteRefs []doltdb.RefWithHash, + localRefs []doltdb.RefWithHash, + behavior pullBehavior, +) (map[string]doltdb.RefWithHash, error) { + localRefsByPath := make(map[string]doltdb.RefWithHash) + remoteRefsByPath := make(map[string]doltdb.RefWithHash) + remoteHashes := make([]hash.Hash, len(remoteRefs)) + + for i, b := range remoteRefs { + remoteRefsByPath[b.Ref.GetPath()] = b + remoteHashes[i] = b.Hash + } + + for _, b := range localRefs { + localRefsByPath[b.Ref.GetPath()] = b + } + + // XXX: Our view of which remote branches to pull and what to set the + // local branches to was computed outside of the limiter, concurrently + // with other possible attempts to pull from the remote. Now we are + // applying changes based on that view. This seems capable of rolling + // back changes which were applied from another thread. + + _, err := rrd.limiter.Run(ctx, "-all", func() (any, error) { + pullErr := rrd.ddb.PullChunks(ctx, rrd.tmpDir, rrd.srcDB, remoteHashes, nil) + + REFS: // every successful pass through the loop below must end with CONTINUE REFS to get out of the retry loop + for _, remoteRef := range remoteRefs { + trackingRef := ref.NewRemoteRef(rrd.remote.Name, remoteRef.Ref.GetPath()) + localRef, localRefExists := localRefsByPath[remoteRef.Ref.GetPath()] + + // loop on optimistic lock failures + OPTIMISTIC_RETRY: + for { + if pullErr != nil || localRefExists { + pullErr = nil + + // TODO: this should work for workspaces too but doesn't, only branches + if localRef.Ref.GetType() == ref.BranchRefType { + err := rrd.pullLocalBranch(ctx, localRef, remoteRef, trackingRef, behavior) + if errors.Is(err, datas.ErrOptimisticLockFailed) { + continue OPTIMISTIC_RETRY + } else if err != nil { + return nil, err + } + } + + continue REFS + } else { + switch remoteRef.Ref.GetType() { + case ref.BranchRefType: + err := rrd.createNewBranchFromRemote(ctx, remoteRef, trackingRef) + if errors.Is(err, datas.ErrOptimisticLockFailed) { + continue OPTIMISTIC_RETRY + } else if err != nil { + return nil, err + } + + // TODO: Establish upstream tracking for this new branch + continue REFS + case ref.TagRefType: + err := rrd.ddb.SetHead(ctx, remoteRef.Ref, remoteRef.Hash) + if errors.Is(err, datas.ErrOptimisticLockFailed) { + continue OPTIMISTIC_RETRY + } else if err != nil { + return nil, err + } + + continue REFS + default: + ctx.GetLogger().Warnf("skipping replication for unhandled remote ref %s", remoteRef.Ref.String()) + continue REFS + } + } + } + } + return nil, nil + }) + if err != nil { + return nil, err + } + + return remoteRefsByPath, nil +} + +func (rrd ReadReplicaDatabase) createNewBranchFromRemote(ctx *sql.Context, remoteRef doltdb.RefWithHash, trackingRef ref.RemoteRef) error { + ctx.GetLogger().Tracef("creating local branch %s", remoteRef.Ref.GetPath()) + + // If a local branch isn't present for the remote branch, create a new branch for it. We need to use + // NewBranchAtCommit so that the branch has its associated working set created at the same time. Creating + // branch refs without associate working sets causes errors in other places. + spec, err := doltdb.NewCommitSpec(remoteRef.Hash.String()) + if err != nil { + return err + } + + cm, err := rrd.ddb.Resolve(ctx, spec, nil) + if err != nil { + return err + } + + err = rrd.ddb.NewBranchAtCommit(ctx, remoteRef.Ref, cm) + err = rrd.ddb.SetHead(ctx, trackingRef, remoteRef.Hash) + if err != nil { + return err + } + + return rrd.ddb.SetHead(ctx, trackingRef, remoteRef.Hash) +} + +func (rrd ReadReplicaDatabase) pullLocalBranch(ctx *sql.Context, localRef doltdb.RefWithHash, remoteRef doltdb.RefWithHash, trackingRef ref.RemoteRef, behavior pullBehavior) error { + if localRef.Hash != remoteRef.Hash { + if behavior == pullBehavior_forcePull { + err := rrd.ddb.SetHead(ctx, remoteRef.Ref, remoteRef.Hash) + if err != nil { + return err + } + } else { + err := rrd.ddb.FastForwardToHash(ctx, remoteRef.Ref, remoteRef.Hash) + if err != nil { + return err + } + } + + err := rrd.ddb.SetHead(ctx, trackingRef, remoteRef.Hash) + if err != nil { + return err + } + } + return nil +} + func getReplicationRefs(ctx *sql.Context, rrd ReadReplicaDatabase) ( remoteRefs []doltdb.RefWithHash, localRefs []doltdb.RefWithHash, diff --git a/go/libraries/doltcore/sqle/tables.go b/go/libraries/doltcore/sqle/tables.go index 8e34ecc278..7da2b3437a 100644 --- a/go/libraries/doltcore/sqle/tables.go +++ b/go/libraries/doltcore/sqle/tables.go @@ -1937,6 +1937,103 @@ func (t *AlterableDoltTable) RenameIndex(ctx *sql.Context, fromIndexName string, return t.updateFromRoot(ctx, newRoot) } +// createForeignKey creates a doltdb.ForeignKey from a sql.ForeignKeyConstraint +func (t *AlterableDoltTable) createForeignKey( + ctx *sql.Context, + root *doltdb.RootValue, + tbl *doltdb.Table, + sqlFk sql.ForeignKeyConstraint, + onUpdateRefAction, onDeleteRefAction doltdb.ForeignKeyReferentialAction) (doltdb.ForeignKey, error) { + if !sqlFk.IsResolved { + return doltdb.ForeignKey{ + Name: sqlFk.Name, + TableName: sqlFk.Table, + TableIndex: "", + TableColumns: nil, + ReferencedTableName: sqlFk.ParentTable, + ReferencedTableIndex: "", + ReferencedTableColumns: nil, + OnUpdate: onUpdateRefAction, + OnDelete: onDeleteRefAction, + UnresolvedFKDetails: doltdb.UnresolvedFKDetails{ + TableColumns: sqlFk.Columns, + ReferencedTableColumns: sqlFk.ParentColumns, + }, + }, nil + } + colTags := make([]uint64, len(sqlFk.Columns)) + for i, col := range sqlFk.Columns { + tableCol, ok := t.sch.GetAllCols().GetByNameCaseInsensitive(col) + if !ok { + return doltdb.ForeignKey{}, fmt.Errorf("table `%s` does not have column `%s`", sqlFk.Table, col) + } + colTags[i] = tableCol.Tag + } + + var refTbl *doltdb.Table + var refSch schema.Schema + if sqlFk.IsSelfReferential() { + refTbl = tbl + refSch = t.sch + } else { + var ok bool + var err error + refTbl, _, ok, err = root.GetTableInsensitive(ctx, sqlFk.ParentTable) + if err != nil { + return doltdb.ForeignKey{}, err + } + if !ok { + return doltdb.ForeignKey{}, fmt.Errorf("referenced table `%s` does not exist", sqlFk.ParentTable) + } + refSch, err = refTbl.GetSchema(ctx) + if err != nil { + return doltdb.ForeignKey{}, err + } + } + + refColTags := make([]uint64, len(sqlFk.ParentColumns)) + for i, name := range sqlFk.ParentColumns { + refCol, ok := refSch.GetAllCols().GetByNameCaseInsensitive(name) + if !ok { + return doltdb.ForeignKey{}, fmt.Errorf("table `%s` does not have column `%s`", sqlFk.ParentTable, name) + } + refColTags[i] = refCol.Tag + } + + var tableIndexName, refTableIndexName string + tableIndex, ok, err := findIndexWithPrefix(t.sch, sqlFk.Columns) + if err != nil { + return doltdb.ForeignKey{}, err + } + // Use secondary index if found; otherwise it will use empty string, indicating primary key + if ok { + tableIndexName = tableIndex.Name() + } + refTableIndex, ok, err := findIndexWithPrefix(refSch, sqlFk.ParentColumns) + if err != nil { + return doltdb.ForeignKey{}, err + } + // Use secondary index if found; otherwise it will use empty string, indicating primary key + if ok { + refTableIndexName = refTableIndex.Name() + } + return doltdb.ForeignKey{ + Name: sqlFk.Name, + TableName: sqlFk.Table, + TableIndex: tableIndexName, + TableColumns: colTags, + ReferencedTableName: sqlFk.ParentTable, + ReferencedTableIndex: refTableIndexName, + ReferencedTableColumns: refColTags, + OnUpdate: onUpdateRefAction, + OnDelete: onDeleteRefAction, + UnresolvedFKDetails: doltdb.UnresolvedFKDetails{ + TableColumns: sqlFk.Columns, + ReferencedTableColumns: sqlFk.ParentColumns, + }, + }, nil +} + // AddForeignKey implements sql.ForeignKeyTable func (t *AlterableDoltTable) AddForeignKey(ctx *sql.Context, sqlFk sql.ForeignKeyConstraint) error { if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { @@ -1970,95 +2067,9 @@ func (t *AlterableDoltTable) AddForeignKey(ctx *sql.Context, sqlFk sql.ForeignKe return err } - var doltFk doltdb.ForeignKey - - if sqlFk.IsResolved { - colTags := make([]uint64, len(sqlFk.Columns)) - for i, col := range sqlFk.Columns { - tableCol, ok := t.sch.GetAllCols().GetByNameCaseInsensitive(col) - if !ok { - return fmt.Errorf("table `%s` does not have column `%s`", sqlFk.Table, col) - } - colTags[i] = tableCol.Tag - } - - var refTbl *doltdb.Table - var ok bool - var refSch schema.Schema - if sqlFk.IsSelfReferential() { - refTbl = tbl - refSch = t.sch - } else { - refTbl, _, ok, err = root.GetTableInsensitive(ctx, sqlFk.ParentTable) - if err != nil { - return err - } - if !ok { - return fmt.Errorf("referenced table `%s` does not exist", sqlFk.ParentTable) - } - refSch, err = refTbl.GetSchema(ctx) - if err != nil { - return err - } - } - - refColTags := make([]uint64, len(sqlFk.ParentColumns)) - for i, name := range sqlFk.ParentColumns { - refCol, ok := refSch.GetAllCols().GetByNameCaseInsensitive(name) - if !ok { - return fmt.Errorf("table `%s` does not have column `%s`", sqlFk.ParentTable, name) - } - refColTags[i] = refCol.Tag - } - - var tableIndexName, refTableIndexName string - tableIndex, ok, err := findIndexWithPrefix(t.sch, sqlFk.Columns) - if err != nil { - return err - } - // Use secondary index if found; otherwise it will use empty string, indicating primary key - if ok { - tableIndexName = tableIndex.Name() - } - refTableIndex, ok, err := findIndexWithPrefix(refSch, sqlFk.ParentColumns) - if err != nil { - return err - } - // Use secondary index if found; otherwise it will use empty string, indicating primary key - if ok { - refTableIndexName = refTableIndex.Name() - } - doltFk = doltdb.ForeignKey{ - Name: sqlFk.Name, - TableName: sqlFk.Table, - TableIndex: tableIndexName, - TableColumns: colTags, - ReferencedTableName: sqlFk.ParentTable, - ReferencedTableIndex: refTableIndexName, - ReferencedTableColumns: refColTags, - OnUpdate: onUpdateRefAction, - OnDelete: onDeleteRefAction, - UnresolvedFKDetails: doltdb.UnresolvedFKDetails{ - TableColumns: sqlFk.Columns, - ReferencedTableColumns: sqlFk.ParentColumns, - }, - } - } else { - doltFk = doltdb.ForeignKey{ - Name: sqlFk.Name, - TableName: sqlFk.Table, - TableIndex: "", - TableColumns: nil, - ReferencedTableName: sqlFk.ParentTable, - ReferencedTableIndex: "", - ReferencedTableColumns: nil, - OnUpdate: onUpdateRefAction, - OnDelete: onDeleteRefAction, - UnresolvedFKDetails: doltdb.UnresolvedFKDetails{ - TableColumns: sqlFk.Columns, - ReferencedTableColumns: sqlFk.ParentColumns, - }, - } + doltFk, err := t.createForeignKey(ctx, root, tbl, sqlFk, onUpdateRefAction, onDeleteRefAction) + if err != nil { + return err } fkc, err := root.GetForeignKeyCollection(ctx) @@ -2132,12 +2143,7 @@ func (t *AlterableDoltTable) UpdateForeignKey(ctx *sql.Context, fkName string, s doltFk.UnresolvedFKDetails.TableColumns = sqlFk.Columns doltFk.UnresolvedFKDetails.ReferencedTableColumns = sqlFk.ParentColumns - if doltFk.IsResolved() && !sqlFk.IsResolved { // Need to unresolve the foreign key - doltFk.TableIndex = "" - doltFk.TableColumns = nil - doltFk.ReferencedTableIndex = "" - doltFk.ReferencedTableColumns = nil - } else if !doltFk.IsResolved() && sqlFk.IsResolved { // Need to assign tags and indexes since it's resolved + if !doltFk.IsResolved() || !sqlFk.IsResolved { tbl, _, ok, err := root.GetTableInsensitive(ctx, t.tableName) if err != nil { return err @@ -2145,129 +2151,10 @@ func (t *AlterableDoltTable) UpdateForeignKey(ctx *sql.Context, fkName string, s if !ok { return sql.ErrTableNotFound.New(t.tableName) } - - colTags := make([]uint64, len(sqlFk.Columns)) - for i, col := range sqlFk.Columns { - tableCol, ok := t.sch.GetAllCols().GetByNameCaseInsensitive(col) - if !ok { - return fmt.Errorf("table `%s` does not have column `%s`", sqlFk.Table, col) - } - colTags[i] = tableCol.Tag - } - - var refTbl *doltdb.Table - var refSch schema.Schema - if sqlFk.IsSelfReferential() { - refTbl = tbl - refSch = t.sch - } else { - refTbl, _, ok, err = root.GetTableInsensitive(ctx, sqlFk.ParentTable) - if err != nil { - return err - } - if !ok { - return fmt.Errorf("referenced table `%s` does not exist", sqlFk.ParentTable) - } - refSch, err = refTbl.GetSchema(ctx) - if err != nil { - return err - } - } - - refColTags := make([]uint64, len(sqlFk.ParentColumns)) - for i, name := range sqlFk.ParentColumns { - refCol, ok := refSch.GetAllCols().GetByNameCaseInsensitive(name) - if !ok { - return fmt.Errorf("table `%s` does not have column `%s`", sqlFk.ParentTable, name) - } - refColTags[i] = refCol.Tag - } - - tableIndex, ok, err := findIndexWithPrefix(t.sch, sqlFk.Columns) + doltFk, err = t.createForeignKey(ctx, root, tbl, sqlFk, doltFk.OnUpdate, doltFk.OnDelete) if err != nil { return err } - if !ok { - // The engine matched on a primary key, and Dolt does not yet support using the primary key within the - // schema.Index interface (which is used internally to represent indexes across the codebase). In the - // meantime, we must generate a duplicate key over the primary key. - //TODO: use the primary key as-is - idxReturn, err := creation.CreateIndex( - ctx, - tbl, - "", - sqlFk.Columns, - nil, - false, - false, - "", - editor.Options{ - ForeignKeyChecksDisabled: true, - Deaf: t.opts.Deaf, - Tempdir: t.opts.Tempdir, - }) - if err != nil { - return err - } - tableIndex = idxReturn.NewIndex - tbl = idxReturn.NewTable - root, err = root.PutTable(ctx, t.tableName, idxReturn.NewTable) - if sqlFk.IsSelfReferential() { - refTbl = idxReturn.NewTable - } - } - - refTableIndex, ok, err := findIndexWithPrefix(refSch, sqlFk.ParentColumns) - if err != nil { - return err - } - if !ok { - // The engine matched on a primary key, and Dolt does not yet support using the primary key within the - // schema.Index interface (which is used internally to represent indexes across the codebase). In the - // meantime, we must generate a duplicate key over the primary key. - //TODO: use the primary key as-is - var refPkTags []uint64 - for _, i := range refSch.GetPkOrdinals() { - refPkTags = append(refPkTags, refSch.GetAllCols().GetByIndex(i).Tag) - } - - var colNames []string - for _, t := range refColTags { - c, _ := refSch.GetAllCols().GetByTag(t) - colNames = append(colNames, c.Name) - } - - // Our duplicate index is only unique if it's the entire primary key (which is by definition unique) - unique := len(refPkTags) == len(refColTags) - idxReturn, err := creation.CreateIndex( - ctx, - refTbl, - "", - colNames, - nil, - unique, - false, - "", - editor.Options{ - ForeignKeyChecksDisabled: true, - Deaf: t.opts.Deaf, - Tempdir: t.opts.Tempdir, - }) - if err != nil { - return err - } - refTbl = idxReturn.NewTable - refTableIndex = idxReturn.NewIndex - root, err = root.PutTable(ctx, sqlFk.ParentTable, idxReturn.NewTable) - if err != nil { - return err - } - } - - doltFk.TableIndex = tableIndex.Name() - doltFk.TableColumns = colTags - doltFk.ReferencedTableIndex = refTableIndex.Name() - doltFk.ReferencedTableColumns = refColTags } err = fkc.AddKeys(doltFk) diff --git a/go/libraries/doltcore/table/typed/json/reader.go b/go/libraries/doltcore/table/typed/json/reader.go index 40acebae53..47f5273967 100644 --- a/go/libraries/doltcore/table/typed/json/reader.go +++ b/go/libraries/doltcore/table/typed/json/reader.go @@ -111,7 +111,12 @@ func (r *JSONReader) ReadSqlRow(ctx context.Context) (sql.Row, error) { return nil, io.EOF } - return r.convToSqlRow(metaRow.Value.(map[string]interface{})) + mapVal, ok := metaRow.Value.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("unexpected JSON format received, expected format: { \"rows\": [ json_row_objects... ] } ") + } + + return r.convToSqlRow(mapVal) } func (r *JSONReader) convToSqlRow(rowMap map[string]interface{}) (sql.Row, error) { diff --git a/go/store/datas/commit.go b/go/store/datas/commit.go index 5c34ec391d..87fa55dd37 100644 --- a/go/store/datas/commit.go +++ b/go/store/datas/commit.go @@ -445,6 +445,9 @@ func GetCommitParents(ctx context.Context, vr types.ValueReader, cv types.Value) return nil, errors.New("GetCommitParents: provided value is not a commit.") } addrs, err := types.SerialCommitParentAddrs(vr.Format(), sm) + if err != nil { + return nil, err + } vals, err := vr.ReadManyValues(ctx, addrs) if err != nil { return nil, err @@ -485,6 +488,9 @@ func GetCommitParents(ctx context.Context, vr types.ValueReader, cv types.Value) refs = append(refs, v.(types.Ref)) return nil }) + if err != nil { + return nil, err + } } else { ps, ok, err = c.MaybeGet(parentsField) if err != nil { @@ -496,6 +502,9 @@ func GetCommitParents(ctx context.Context, vr types.ValueReader, cv types.Value) refs = append(refs, v.(types.Ref)) return nil }) + if err != nil { + return nil, err + } } } hashes := make([]hash.Hash, len(refs)) diff --git a/go/store/nbs/aws_table_persister_test.go b/go/store/nbs/aws_table_persister_test.go index 7cdc10b916..4151157455 100644 --- a/go/store/nbs/aws_table_persister_test.go +++ b/go/store/nbs/aws_table_persister_test.go @@ -45,7 +45,7 @@ func TestAWSTablePersisterPersist(t *testing.T) { mt := newMemTable(testMemTableSize) for _, c := range testChunks { - assert.True(t, mt.addChunk(computeAddr(c), c)) + assert.Equal(t, mt.addChunk(computeAddr(c), c), chunkAdded) } t.Run("PersistToS3", func(t *testing.T) { @@ -89,8 +89,8 @@ func TestAWSTablePersisterPersist(t *testing.T) { existingTable := newMemTable(testMemTableSize) for _, c := range testChunks { - assert.True(mt.addChunk(computeAddr(c), c)) - assert.True(existingTable.addChunk(computeAddr(c), c)) + assert.Equal(mt.addChunk(computeAddr(c), c), chunkAdded) + assert.Equal(existingTable.addChunk(computeAddr(c), c), chunkAdded) } s3svc, ddb := makeFakeS3(t), makeFakeDTS(makeFakeDDB(t), nil) diff --git a/go/store/nbs/bs_persister.go b/go/store/nbs/bs_persister.go index af7c97c7b3..cfb6e9136d 100644 --- a/go/store/nbs/bs_persister.go +++ b/go/store/nbs/bs_persister.go @@ -233,6 +233,11 @@ type bsTableReaderAt struct { bs blobstore.Blobstore } +func (bsTRA *bsTableReaderAt) Reader(ctx context.Context) (io.ReadCloser, error) { + rc, _, err := bsTRA.bs.Get(ctx, bsTRA.key, blobstore.AllRange) + return rc, err +} + // ReadAtWithStats is the bsTableReaderAt implementation of the tableReaderAt interface func (bsTRA *bsTableReaderAt) ReadAtWithStats(ctx context.Context, p []byte, off int64, stats *Stats) (int, error) { br := blobstore.NewBlobRange(off, int64(len(p))) diff --git a/go/store/nbs/dynamo_table_reader.go b/go/store/nbs/dynamo_table_reader.go index ce44315b2b..1c15bae08a 100644 --- a/go/store/nbs/dynamo_table_reader.go +++ b/go/store/nbs/dynamo_table_reader.go @@ -22,6 +22,7 @@ package nbs import ( + "bytes" "context" "fmt" "io" @@ -53,6 +54,14 @@ func (t tableNotInDynamoErr) Error() string { return fmt.Sprintf("NBS table %s not present in DynamoDB table %s", t.nbs, t.dynamo) } +func (dtra *dynamoTableReaderAt) Reader(ctx context.Context) (io.ReadCloser, error) { + data, err := dtra.ddb.ReadTable(ctx, dtra.h, &Stats{}) + if err != nil { + return nil, err + } + return io.NopCloser(bytes.NewReader(data)), nil +} + func (dtra *dynamoTableReaderAt) ReadAtWithStats(ctx context.Context, p []byte, off int64, stats *Stats) (n int, err error) { data, err := dtra.ddb.ReadTable(ctx, dtra.h, stats) diff --git a/go/store/nbs/empty_chunk_source.go b/go/store/nbs/empty_chunk_source.go index 13b31aab8c..ea2f082dc0 100644 --- a/go/store/nbs/empty_chunk_source.go +++ b/go/store/nbs/empty_chunk_source.go @@ -70,8 +70,8 @@ func (ecs emptyChunkSource) index() (tableIndex, error) { return onHeapTableIndex{}, nil } -func (ecs emptyChunkSource) reader(context.Context) (io.Reader, uint64, error) { - return &bytes.Buffer{}, 0, nil +func (ecs emptyChunkSource) reader(context.Context) (io.ReadCloser, uint64, error) { + return io.NopCloser(&bytes.Buffer{}), 0, nil } func (ecs emptyChunkSource) getRecordRanges(lookups []getRecord) (map[hash.Hash]Range, error) { diff --git a/go/store/nbs/file_table_persister_test.go b/go/store/nbs/file_table_persister_test.go index cb5cff1081..a9a489387d 100644 --- a/go/store/nbs/file_table_persister_test.go +++ b/go/store/nbs/file_table_persister_test.go @@ -139,7 +139,7 @@ func TestFSTablePersisterPersist(t *testing.T) { func persistTableData(p tablePersister, chunx ...[]byte) (src chunkSource, err error) { mt := newMemTable(testMemTableSize) for _, c := range chunx { - if !mt.addChunk(computeAddr(c), c) { + if mt.addChunk(computeAddr(c), c) == chunkNotAdded { return nil, fmt.Errorf("memTable too full to add %s", computeAddr(c)) } } @@ -152,8 +152,8 @@ func TestFSTablePersisterPersistNoData(t *testing.T) { existingTable := newMemTable(testMemTableSize) for _, c := range testChunks { - assert.True(mt.addChunk(computeAddr(c), c)) - assert.True(existingTable.addChunk(computeAddr(c), c)) + assert.Equal(mt.addChunk(computeAddr(c), c), chunkAdded) + assert.Equal(existingTable.addChunk(computeAddr(c), c), chunkAdded) } dir := makeTempDir(t) diff --git a/go/store/nbs/file_table_reader.go b/go/store/nbs/file_table_reader.go index d1789e1a74..2aa858e5d2 100644 --- a/go/store/nbs/file_table_reader.go +++ b/go/store/nbs/file_table_reader.go @@ -55,7 +55,7 @@ func tableFileExists(ctx context.Context, dir string, h addr) (bool, error) { func newFileTableReader(ctx context.Context, dir string, h addr, chunkCount uint32, q MemoryQuotaProvider, fc *fdCache) (cs chunkSource, err error) { path := filepath.Join(dir, h.String()) - index, err := func() (ti onHeapTableIndex, err error) { + index, sz, err := func() (ti onHeapTableIndex, sz int64, err error) { // Be careful with how |f| is used below. |RefFile| returns a cached // os.File pointer so the code needs to use f in a concurrency-safe @@ -82,7 +82,8 @@ func newFileTableReader(ctx context.Context, dir string, h addr, chunkCount uint } idxSz := int64(indexSize(chunkCount) + footerSize) - indexOffset := fi.Size() - idxSz + sz = fi.Size() + indexOffset := sz - idxSz r := io.NewSectionReader(f, indexOffset, idxSz) var b []byte @@ -122,7 +123,7 @@ func newFileTableReader(ctx context.Context, dir string, h addr, chunkCount uint return nil, errors.New("unexpected chunk count") } - tr, err := newTableReader(index, &cacheReaderAt{path, fc}, fileBlockSize) + tr, err := newTableReader(index, &cacheReaderAt{path, fc, sz}, fileBlockSize) if err != nil { index.Close() return nil, err @@ -153,6 +154,11 @@ func (mmtr *fileTableReader) clone() (chunkSource, error) { type cacheReaderAt struct { path string fc *fdCache + sz int64 +} + +func (cra *cacheReaderAt) Reader(ctx context.Context) (io.ReadCloser, error) { + return io.NopCloser(io.LimitReader(&readerAdapter{cra, 0, ctx}, cra.sz)), nil } func (cra *cacheReaderAt) ReadAtWithStats(ctx context.Context, p []byte, off int64, stats *Stats) (n int, err error) { diff --git a/go/store/nbs/journal_chunk_source.go b/go/store/nbs/journal_chunk_source.go index bc05965700..eb599139dd 100644 --- a/go/store/nbs/journal_chunk_source.go +++ b/go/store/nbs/journal_chunk_source.go @@ -193,9 +193,9 @@ func (s journalChunkSource) hash() addr { } // reader implements chunkSource. -func (s journalChunkSource) reader(context.Context) (io.Reader, uint64, error) { +func (s journalChunkSource) reader(context.Context) (io.ReadCloser, uint64, error) { rdr, sz, err := s.journal.Snapshot() - return rdr, uint64(sz), err + return io.NopCloser(rdr), uint64(sz), err } func (s journalChunkSource) getRecordRanges(requests []getRecord) (map[hash.Hash]Range, error) { diff --git a/go/store/nbs/mem_table.go b/go/store/nbs/mem_table.go index aea022b7ac..4476979ec8 100644 --- a/go/store/nbs/mem_table.go +++ b/go/store/nbs/mem_table.go @@ -33,6 +33,14 @@ import ( "github.com/dolthub/dolt/go/store/hash" ) +type addChunkResult int + +const ( + chunkExists addChunkResult = iota + chunkAdded + chunkNotAdded +) + func WriteChunks(chunks []chunks.Chunk) (string, []byte, error) { var size uint64 for _, chunk := range chunks { @@ -46,7 +54,8 @@ func WriteChunks(chunks []chunks.Chunk) (string, []byte, error) { func writeChunksToMT(mt *memTable, chunks []chunks.Chunk) (string, []byte, error) { for _, chunk := range chunks { - if !mt.addChunk(addr(chunk.Hash()), chunk.Data()) { + res := mt.addChunk(addr(chunk.Hash()), chunk.Data()) + if res == chunkNotAdded { return "", nil, errors.New("didn't create this memory table with enough space to add all the chunks") } } @@ -78,17 +87,19 @@ func newMemTable(memTableSize uint64) *memTable { return &memTable{chunks: map[addr][]byte{}, maxData: memTableSize} } -func (mt *memTable) addChunk(h addr, data []byte) bool { +func (mt *memTable) addChunk(h addr, data []byte) addChunkResult { if len(data) == 0 { panic("NBS blocks cannot be zero length") } if _, ok := mt.chunks[h]; ok { - return true + return chunkExists } + dataLen := uint64(len(data)) if mt.totalData+dataLen > mt.maxData { - return false + return chunkNotAdded } + mt.totalData += dataLen mt.chunks[h] = data mt.order = append(mt.order, hasRecord{ @@ -97,7 +108,7 @@ func (mt *memTable) addChunk(h addr, data []byte) bool { len(mt.order), false, }) - return true + return chunkAdded } func (mt *memTable) addChildRefs(addrs hash.HashSet) { diff --git a/go/store/nbs/mem_table_test.go b/go/store/nbs/mem_table_test.go index 26f233bbff..acdfff9e8e 100644 --- a/go/store/nbs/mem_table_test.go +++ b/go/store/nbs/mem_table_test.go @@ -24,6 +24,7 @@ package nbs import ( "bytes" "context" + "io" "os" "testing" @@ -89,7 +90,7 @@ func TestMemTableAddHasGetChunk(t *testing.T) { } for _, c := range chunks { - assert.True(mt.addChunk(computeAddr(c), c)) + assert.Equal(mt.addChunk(computeAddr(c), c), chunkAdded) } assertChunksInReader(chunks, mt, assert) @@ -114,9 +115,9 @@ func TestMemTableAddOverflowChunk(t *testing.T) { { bigAddr := computeAddr(big) mt := newMemTable(memTableSize) - assert.True(mt.addChunk(bigAddr, big)) + assert.Equal(mt.addChunk(bigAddr, big), chunkAdded) assert.True(mt.has(bigAddr)) - assert.False(mt.addChunk(computeAddr(little), little)) + assert.Equal(mt.addChunk(computeAddr(little), little), chunkNotAdded) assert.False(mt.has(computeAddr(little))) } @@ -124,12 +125,12 @@ func TestMemTableAddOverflowChunk(t *testing.T) { big := big[:memTableSize-1] bigAddr := computeAddr(big) mt := newMemTable(memTableSize) - assert.True(mt.addChunk(bigAddr, big)) + assert.Equal(mt.addChunk(bigAddr, big), chunkAdded) assert.True(mt.has(bigAddr)) - assert.True(mt.addChunk(computeAddr(little), little)) + assert.Equal(mt.addChunk(computeAddr(little), little), chunkAdded) assert.True(mt.has(computeAddr(little))) other := []byte("o") - assert.False(mt.addChunk(computeAddr(other), other)) + assert.Equal(mt.addChunk(computeAddr(other), other), chunkNotAdded) assert.False(mt.has(computeAddr(other))) } } @@ -146,7 +147,7 @@ func TestMemTableWrite(t *testing.T) { } for _, c := range chunks { - assert.True(mt.addChunk(computeAddr(c), c)) + assert.Equal(mt.addChunk(computeAddr(c), c), chunkAdded) } td1, _, err := buildTable(chunks[1:2]) @@ -179,15 +180,20 @@ func TestMemTableWrite(t *testing.T) { } type tableReaderAtAdapter struct { - *bytes.Reader + br *bytes.Reader } func tableReaderAtFromBytes(b []byte) tableReaderAt { return tableReaderAtAdapter{bytes.NewReader(b)} } +func (adapter tableReaderAtAdapter) Reader(ctx context.Context) (io.ReadCloser, error) { + r := *adapter.br + return io.NopCloser(&r), nil +} + func (adapter tableReaderAtAdapter) ReadAtWithStats(ctx context.Context, p []byte, off int64, stats *Stats) (n int, err error) { - return adapter.ReadAt(p, off) + return adapter.br.ReadAt(p, off) } func TestMemTableSnappyWriteOutOfLine(t *testing.T) { @@ -201,7 +207,7 @@ func TestMemTableSnappyWriteOutOfLine(t *testing.T) { } for _, c := range chunks { - assert.True(mt.addChunk(computeAddr(c), c)) + assert.Equal(mt.addChunk(computeAddr(c), c), chunkAdded) } mt.snapper = &outOfLineSnappy{[]bool{false, true, false}} // chunks[1] should trigger a panic diff --git a/go/store/nbs/s3_table_reader.go b/go/store/nbs/s3_table_reader.go index dda82631df..68079834c2 100644 --- a/go/store/nbs/s3_table_reader.go +++ b/go/store/nbs/s3_table_reader.go @@ -60,6 +60,10 @@ type s3svc interface { PutObjectWithContext(ctx aws.Context, input *s3.PutObjectInput, opts ...request.Option) (*s3.PutObjectOutput, error) } +func (s3tra *s3TableReaderAt) Reader(ctx context.Context) (io.ReadCloser, error) { + return s3tra.s3.Reader(ctx, s3tra.h) +} + func (s3tra *s3TableReaderAt) ReadAtWithStats(ctx context.Context, p []byte, off int64, stats *Stats) (n int, err error) { return s3tra.s3.ReadAt(ctx, s3tra.h, p, off, stats) } @@ -79,6 +83,10 @@ func (s3or *s3ObjectReader) key(k string) string { return k } +func (s3or *s3ObjectReader) Reader(ctx context.Context, name addr) (io.ReadCloser, error) { + return s3or.reader(ctx, name) +} + func (s3or *s3ObjectReader) ReadAt(ctx context.Context, name addr, p []byte, off int64, stats *Stats) (n int, err error) { t1 := time.Now() @@ -143,6 +151,18 @@ func (s3or *s3ObjectReader) ReadFromEnd(ctx context.Context, name addr, p []byte return s3or.readRange(ctx, name, p, fmt.Sprintf("%s=-%d", s3RangePrefix, len(p))) } +func (s3or *s3ObjectReader) reader(ctx context.Context, name addr) (io.ReadCloser, error) { + input := &s3.GetObjectInput{ + Bucket: aws.String(s3or.bucket), + Key: aws.String(s3or.key(name.String())), + } + result, err := s3or.s3.GetObjectWithContext(ctx, input) + if err != nil { + return nil, err + } + return result.Body, nil +} + func (s3or *s3ObjectReader) readRange(ctx context.Context, name addr, p []byte, rangeHeader string) (n int, sz uint64, err error) { read := func() (int, uint64, error) { if s3or.readRl != nil { diff --git a/go/store/nbs/store.go b/go/store/nbs/store.go index f1d0a97765..f0a4632f04 100644 --- a/go/store/nbs/store.go +++ b/go/store/nbs/store.go @@ -623,8 +623,8 @@ func (nbs *NomsBlockStore) addChunk(ctx context.Context, ch chunks.Chunk, addrs } a := addr(ch.Hash()) - ok := nbs.mt.addChunk(a, ch.Data()) - if !ok { + addChunkRes := nbs.mt.addChunk(a, ch.Data()) + if addChunkRes == chunkNotAdded { ts, err := nbs.tables.append(ctx, nbs.mt, checker, nbs.stats) if err != nil { if errors.Is(err, ErrDanglingRef) { @@ -634,12 +634,12 @@ func (nbs *NomsBlockStore) addChunk(ctx context.Context, ch chunks.Chunk, addrs } nbs.tables = ts nbs.mt = newMemTable(nbs.mtSize) - ok = nbs.mt.addChunk(a, ch.Data()) + addChunkRes = nbs.mt.addChunk(a, ch.Data()) } - if ok { + if addChunkRes == chunkAdded { nbs.mt.addChildRefs(addrs) } - return ok, nil + return addChunkRes == chunkAdded || addChunkRes == chunkExists, nil } // refCheck checks that no dangling references are being committed. @@ -1292,7 +1292,7 @@ func newTableFile(cs chunkSource, info tableSpec) tableFile { if err != nil { return nil, 0, err } - return io.NopCloser(r), s, nil + return r, s, nil }, } } diff --git a/go/store/nbs/table.go b/go/store/nbs/table.go index 925ab4cbb2..dec3c60bd9 100644 --- a/go/store/nbs/table.go +++ b/go/store/nbs/table.go @@ -260,7 +260,7 @@ type chunkSource interface { hash() addr // opens a Reader to the first byte of the chunkData segment of this table. - reader(context.Context) (io.Reader, uint64, error) + reader(context.Context) (io.ReadCloser, uint64, error) // getRecordRanges sets getRecord.found to true, and returns a Range for each present getRecord query. getRecordRanges(requests []getRecord) (map[hash.Hash]Range, error) diff --git a/go/store/nbs/table_index.go b/go/store/nbs/table_index.go index 80ba26a9f2..0b29c1b3ae 100644 --- a/go/store/nbs/table_index.go +++ b/go/store/nbs/table_index.go @@ -333,8 +333,6 @@ func (ti onHeapTableIndex) lookupOrdinal(h *addr) (uint32, error) { // findPrefix returns the first position in |tr.prefixes| whose value == |prefix|. // Returns |tr.chunkCount| if absent func (ti onHeapTableIndex) findPrefix(prefix uint64) (idx uint32) { - query := make([]byte, addrPrefixSize) - binary.BigEndian.PutUint64(query, prefix) // NOTE: The golang impl of sort.Search is basically inlined here. This method can be called in // an extremely tight loop and inlining the code was a significant perf improvement. idx, j := 0, ti.count @@ -342,7 +340,8 @@ func (ti onHeapTableIndex) findPrefix(prefix uint64) (idx uint32) { h := idx + (j-idx)/2 // avoid overflow when computing h // i ≤ h < j o := int64(prefixTupleSize * h) - if bytes.Compare(ti.prefixTuples[o:o+addrPrefixSize], query) < 0 { + tmp := binary.BigEndian.Uint64(ti.prefixTuples[o : o+addrPrefixSize]) + if tmp < prefix { idx = h + 1 // preserves f(i-1) == false } else { j = h // preserves f(j) == true diff --git a/go/store/nbs/table_reader.go b/go/store/nbs/table_reader.go index 9f05270676..3306103d10 100644 --- a/go/store/nbs/table_reader.go +++ b/go/store/nbs/table_reader.go @@ -130,6 +130,7 @@ func (ir indexResult) Length() uint32 { type tableReaderAt interface { ReadAtWithStats(ctx context.Context, p []byte, off int64, stats *Stats) (n int, err error) + Reader(ctx context.Context) (io.ReadCloser, error) } // tableReader implements get & has queries against a single nbs table. goroutine safe. @@ -631,10 +632,14 @@ func (tr tableReader) extract(ctx context.Context, chunks chan<- extractRecord) return nil } -func (tr tableReader) reader(ctx context.Context) (io.Reader, uint64, error) { +func (tr tableReader) reader(ctx context.Context) (io.ReadCloser, uint64, error) { i, _ := tr.index() sz := i.tableFileSize() - return io.LimitReader(&readerAdapter{tr.r, 0, ctx}, int64(sz)), sz, nil + r, err := tr.r.Reader(ctx) + if err != nil { + return nil, 0, err + } + return r, sz, nil } func (tr tableReader) getRecordRanges(requests []getRecord) (map[hash.Hash]Range, error) { diff --git a/go/store/prolly/tree/testutils.go b/go/store/prolly/tree/testutils.go index ab756790fa..a9e08f3c08 100644 --- a/go/store/prolly/tree/testutils.go +++ b/go/store/prolly/tree/testutils.go @@ -213,16 +213,16 @@ func randomField(tb *val.TupleBuilder, idx int, typ val.Type, ns NodeStore) { v := uint16(testRand.Intn(math.MaxUint16)) tb.PutUint16(idx, v) case val.Int32Enc: - v := int32(testRand.Intn(math.MaxInt32) * neg) + v := testRand.Int31() * int32(neg) tb.PutInt32(idx, v) case val.Uint32Enc: - v := uint32(testRand.Intn(math.MaxUint32)) + v := testRand.Uint32() tb.PutUint32(idx, v) case val.Int64Enc: - v := int64(testRand.Intn(math.MaxInt64) * neg) + v := testRand.Int63() * int64(neg) tb.PutInt64(idx, v) case val.Uint64Enc: - v := uint64(testRand.Uint64()) + v := testRand.Uint64() tb.PutUint64(idx, v) case val.Float32Enc: tb.PutFloat32(idx, testRand.Float32()) diff --git a/go/store/skip/list.go b/go/store/skip/list.go index 5b43a0e17a..e472f91ba3 100644 --- a/go/store/skip/list.go +++ b/go/store/skip/list.go @@ -21,7 +21,7 @@ import ( const ( maxHeight = 9 - maxCount = math.MaxUint32 - 1 + maxCount = math.MaxInt32 - 1 sentinelId = nodeId(0) initSize = 8 ) diff --git a/integration-tests/bats/constraint-violations.bats b/integration-tests/bats/constraint-violations.bats index c5c52fb7df..a59b03b138 100644 --- a/integration-tests/bats/constraint-violations.bats +++ b/integration-tests/bats/constraint-violations.bats @@ -2861,3 +2861,15 @@ SQL [[ $output =~ "Automatic merge failed; 1 table(s) are unmerged." ]] } +@test "constraint-violations: altering FKs over PKs does not create bad index" { + dolt sql < /dev/null; do sleep .1; done fi; - fi fi SERVER_PID= } diff --git a/integration-tests/bats/import-replace-tables.bats b/integration-tests/bats/import-replace-tables.bats index 0ec70b0f97..b0547f4d64 100644 --- a/integration-tests/bats/import-replace-tables.bats +++ b/integration-tests/bats/import-replace-tables.bats @@ -196,6 +196,22 @@ SQL [[ "$output" =~ "An error occurred while moving data" ]] || false } +@test "import-replace-tables: import table with unexpected JSON format" { + dolt sql < unexpected.json + run dolt table import -r employees unexpected.json + [ "$status" -eq 1 ] + [[ "$output" =~ "An error occurred while moving data" ]] || false + [[ "$output" =~ "unexpected JSON format received, expected format: { \"rows\": [ json_row_objects... ] }" ]] || false +} + @test "import-replace-tables: replace table using xlsx file" { dolt sql <rem1->repo2 cd repo2 dolt sql -q "create table t2 (a int)" dolt add . dolt commit -am "forced commit" dolt push --force origin main - cd ../repo1 + run dolt sql -q "call dolt_fetch('origin', 'main')" - [ "$status" -eq 1 ] - [[ "$output" =~ "fetch failed: can't fast forward merge" ]] || false - - dolt sql -q "call dolt_fetch('--force', 'origin', 'main')" - - dolt diff main origin/main - run dolt diff main origin/main [ "$status" -eq 0 ] - [[ "$output" =~ "deleted table" ]] || false - - run dolt sql -q "show tables as of hashof('origin/main')" -r csv - [ "${#lines[@]}" -eq 2 ] - [[ "$output" =~ "Table" ]] || false - [[ "$output" =~ "t2" ]] || false -} - -@test "sql-fetch: CALL dolt_fetch --force" { - # reverse information flow for force fetch repo1->rem1->repo2 - cd repo2 - dolt sql -q "create table t2 (a int)" - dolt add . - dolt commit -am "forced commit" - dolt push --force origin main - - cd ../repo1 - run dolt sql -q "CALL dolt_fetch('origin', 'main')" - [ "$status" -eq 1 ] - [[ "$output" =~ "fetch failed: can't fast forward merge" ]] || false - - dolt sql -q "CALL dolt_fetch('--force', 'origin', 'main')" run dolt diff main origin/main [ "$status" -eq 0 ] diff --git a/integration-tests/bats/sql-server.bats b/integration-tests/bats/sql-server.bats index 5337bf14ed..30b6798526 100644 --- a/integration-tests/bats/sql-server.bats +++ b/integration-tests/bats/sql-server.bats @@ -1873,3 +1873,13 @@ s.close() [ $status -eq 0 ] [[ "$output" =~ "$EXPECTED" ]] || false } + +@test "sql-server: binary literal is printed as hex string for utf8 charset result set" { + cd repo1 + start_sql_server + dolt sql-client -P $PORT -u dolt --use-db repo1 -q "SET character_set_results = utf8; CREATE TABLE mapping(branch_id binary(16) PRIMARY KEY, user_id binary(16) NOT NULL, company_id binary(16) NOT NULL);" + + run dolt sql-client -P $PORT -u dolt --use-db repo1 -q "EXPLAIN SELECT m.* FROM mapping m WHERE user_id = uuid_to_bin('1c4c4e33-8ad7-4421-8450-9d5182816ac3');" + [ $status -eq 0 ] + [[ "$output" =~ "0x1C4C4E338AD7442184509D5182816AC3" ]] || false +} diff --git a/integration-tests/bats/sql.bats b/integration-tests/bats/sql.bats index 53d1342cb9..9b5ba865e1 100755 --- a/integration-tests/bats/sql.bats +++ b/integration-tests/bats/sql.bats @@ -2779,3 +2779,17 @@ SQL run dolt sql -q 'INSERT INTO dts (`created_at`) VALUES ("0001-01-01 00:00:00");' [ "$status" -eq 0 ] } + +@test "sql: multi statement query returns accurate timing" { + dolt sql -q "CREATE TABLE t(a int);" + dolt sql -q "INSERT INTO t VALUES (1);" + dolt sql -q "CREATE TABLE t1(b int);" + run dolt sql < { + this.connection.query(sql, args, (err, rows) => { + if (err) return reject(err); + return resolve(rows); + }); + }); + } + close() { + this.connection.end((err) => { + if (err) { + console.error(err); + } else { + console.log("db connection closed"); + } + }); + } +} diff --git a/integration-tests/mysql-client-tests/node/helpers.js b/integration-tests/mysql-client-tests/node/helpers.js new file mode 100644 index 0000000000..8fcf0151b4 --- /dev/null +++ b/integration-tests/mysql-client-tests/node/helpers.js @@ -0,0 +1,18 @@ +const args = process.argv.slice(2); +const user = args[0]; +const port = args[1]; +const dbName = args[2]; + +export function getArgs() { + return { user, port, dbName }; +} + +export function getConfig() { + const { user, port, dbName } = getArgs(); + return { + host: "127.0.0.1", + port: port, + user: user, + database: dbName, + }; +} diff --git a/integration-tests/mysql-client-tests/node/index.js b/integration-tests/mysql-client-tests/node/index.js index f0eca0a11e..413b352e7b 100644 --- a/integration-tests/mysql-client-tests/node/index.js +++ b/integration-tests/mysql-client-tests/node/index.js @@ -1,145 +1,115 @@ -const mysql = require('mysql'); +import { Database } from "./database.js"; +import { getConfig } from "./helpers.js"; -const args = process.argv.slice(2); - -const user = args[0]; -const port = args[1]; -const db = args[2]; - -const config = { - host: '127.0.0.1', - user: user, - port: port, - database: db -}; - -class Database { - constructor( config ) { - this.connection = mysql.createConnection( config ); - this.connection.connect(); - } - - query( sql, args ) { - return new Promise( ( resolve, reject ) => { - this.connection.query( sql, args, ( err, rows ) => { - if ( err ) - return reject( err ); - return resolve( rows ); - } ); - } ); - } - close() { - this.connection.end(err => { - if (err) { - console.error(err) - } else { - console.log("db connection closed") - } - }) - } -} +const tests = [ + { + q: "create table test (pk int, `value` int, primary key(pk))", + res: { + fieldCount: 0, + affectedRows: 0, + insertId: 0, + serverStatus: 2, + warningCount: 0, + message: "", + protocol41: true, + changedRows: 0, + }, + }, + { + q: "describe test", + res: [ + { + Field: "pk", + Type: "int", + Null: "NO", + Key: "PRI", + Default: "NULL", + Extra: "", + }, + { + Field: "value", + Type: "int", + Null: "YES", + Key: "", + Default: "NULL", + Extra: "", + }, + ], + }, + { q: "select * from test", res: [] }, + { + q: "insert into test (pk, `value`) values (0,0)", + res: { + fieldCount: 0, + affectedRows: 1, + insertId: 0, + serverStatus: 2, + warningCount: 0, + message: "", + protocol41: true, + changedRows: 0, + }, + }, + { q: "select * from test", res: [{ pk: 0, value: 0 }] }, + { q: "call dolt_add('-A');", res: [{ status: 0 }] }, + { q: "call dolt_commit('-m', 'my commit')", res: [] }, + { q: "select COUNT(*) FROM dolt_log", res: [{ "COUNT(*)": 2 }] }, + { q: "call dolt_checkout('-b', 'mybranch')", res: [{ status: 0 }] }, + { + q: "insert into test (pk, `value`) values (1,1)", + res: { + fieldCount: 0, + affectedRows: 1, + insertId: 0, + serverStatus: 2, + warningCount: 0, + message: "", + protocol41: true, + changedRows: 0, + }, + }, + { q: "call dolt_commit('-a', '-m', 'my commit2')", res: [] }, + { q: "call dolt_checkout('main')", res: [{ status: 0 }] }, + { + q: "call dolt_merge('mybranch')", + res: [{ fast_forward: 1, conflicts: 0 }], + }, + { q: "select COUNT(*) FROM dolt_log", res: [{ "COUNT(*)": 3 }] }, +]; async function main() { - const queries = [ - "create table test (pk int, `value` int, primary key(pk))", - "describe test", - "select * from test", - "insert into test (pk, `value`) values (0,0)", - "select * from test", - "call dolt_add('-A');", - "call dolt_commit('-m', 'my commit')", - "select COUNT(*) FROM dolt_log", - "call dolt_checkout('-b', 'mybranch')", - "insert into test (pk, `value`) values (1,1)", - "call dolt_commit('-a', '-m', 'my commit2')", - "call dolt_checkout('main')", - "call dolt_merge('mybranch')", - "select COUNT(*) FROM dolt_log", - ]; + const database = new Database(getConfig()); - const results = [ - { - fieldCount: 0, - affectedRows: 0, - insertId: 0, - serverStatus: 2, - warningCount: 0, - message: '', - protocol41: true, - changedRows: 0 - }, - [ { Field: 'pk', - Type: 'int', - Null: 'NO', - Key: 'PRI', - Default: 'NULL', - Extra: '' }, - { Field: 'value', - Type: 'int', - Null: 'YES', - Key: '', - Default: 'NULL', - Extra: '' } - ], - [], - { - fieldCount: 0, - affectedRows: 1, - insertId: 0, - serverStatus: 2, - warningCount: 0, - message: '', - protocol41: true, - changedRows: 0 - }, - [ { pk: 0, value: 0 } ], - [ { status: 0 } ], - [], - [ { "COUNT(*)": 2 } ], - [ { status: 0 } ], - { - fieldCount: 0, - affectedRows: 1, - insertId: 0, - serverStatus: 2, - warningCount: 0, - message: '', - protocol41: true, - changedRows: 0 - }, - [], - [ { status: 0 } ], - [ { fast_forward: 1, conflicts: 0 } ], - [ { "COUNT(*)": 3 } ], - ]; + await Promise.all( + tests.map((test) => { + const expected = test.res; + return database + .query(test.q) + .then((rows) => { + const resultStr = JSON.stringify(rows); + const result = JSON.parse(resultStr); + if ( + resultStr !== JSON.stringify(expected) && + test.q.includes("dolt_commit") && + !(rows.length === 1 && rows[0].hash.length > 0) + ) { + console.log("Query:", test.q); + console.log("Results:", result); + console.log("Expected:", expected); + throw new Error("Query failed"); + } else { + console.log("Query succeeded:", test.q); + } + }) + .catch((err) => { + console.error(err); + process.exit(1); + }); + }) + ); - const database = new Database(config); - - await Promise.all(queries.map((query, idx) => { - const expected = results[idx]; - return database.query(query).then(rows => { - const resultStr = JSON.stringify(rows); - const result = JSON.parse(resultStr); - if (resultStr !== JSON.stringify(expected) && !(query.includes("dolt_commit"))) { - console.log("Query:", query); - console.log("Results:", result); - console.log("Expected:", expected); - throw new Error("Query failed") - } else { - console.log("Query succeeded:", query) - } - }).catch(err => { - console.error(err) - process.exit(1); - }); - })); - - database.close() - process.exit(0) + database.close(); + process.exit(0); } main(); - - - - diff --git a/integration-tests/mysql-client-tests/node/knex.js b/integration-tests/mysql-client-tests/node/knex.js index fdc03335ba..b1b960972e 100644 --- a/integration-tests/mysql-client-tests/node/knex.js +++ b/integration-tests/mysql-client-tests/node/knex.js @@ -1,84 +1,79 @@ -const knex = require("knex"); -const wtfnode = require("wtfnode") -Socket = require('net').Socket; - -const args = process.argv.slice(2); -const user = args[0]; -const port = args[1]; -const dbName = args[2]; +import knex from "knex"; +import wtfnode from "wtfnode"; +import { Socket } from "net"; +import { getConfig } from "./helpers.js"; const db = knex({ - client: "mysql2", - connection: { - host: "127.0.0.1", - port: port, - user: user, - database: dbName, - }, + client: "mysql2", + connection: getConfig(), }); async function createTable() { - let val = await db.schema.createTable('test2', (table) => { - table.integer('id').primary() - table.integer('foo') - }); - return val + const val = await db.schema.createTable("test2", (table) => { + table.integer("id").primary(); + table.integer("foo"); + }); + return val; } async function upsert(table, data) { - let val = await db(table).insert(data).onConflict().merge(); - return val + const val = await db(table).insert(data).onConflict().merge(); + return val; } async function select() { - let val = await db.select('id', 'foo').from('test2'); - return val + const val = await db.select("id", "foo").from("test2"); + return val; } async function main() { - await createTable(); - await Promise.all([ - upsert("test2", { id: 1, foo: 1 }), - upsert("test2", { id: 2, foo: 2 }), - ]) + await createTable(); + await Promise.all([ + upsert("test2", { id: 1, foo: 1 }), + upsert("test2", { id: 2, foo: 2 }), + ]); - let expectedResult = JSON.stringify([ { id: 1, foo: 1 }, { id: 2, foo: 2 } ]) - let result = await select(); - if (JSON.stringify(result) !== expectedResult) { - console.log("Results:", result); - console.log("Expected:", expectedResult); - process.exit(1) - throw new Error("Query failed") + const expectedResult = JSON.stringify([ + { id: 1, foo: 1 }, + { id: 2, foo: 2 }, + ]); + const result = await select(); + if (JSON.stringify(result) !== expectedResult) { + console.log("Results:", result); + console.log("Expected:", expectedResult); + process.exit(1); + } + + await db.destroy(); + + // cc: https://github.com/dolthub/dolt/issues/3752 + setTimeout(async () => { + const sockets = await getOpenSockets(); + + if (sockets.length > 0) { + wtfnode.dump(); + process.exit(1); } - - await db.destroy(); - - // cc: https://github.com/dolthub/dolt/issues/3752 - setTimeout(async () => { - let sockets = await getOpenSockets(); - - if (sockets.length > 0) { - wtfnode.dump(); - process.exit(1); - throw new Error("Database not properly destroyed. Hanging server connections"); - } - - }, 3000); + }, 3000); } // cc: https://github.com/myndzi/wtfnode/blob/master/index.js#L457 async function getOpenSockets() { - let sockets = [] - process._getActiveHandles().forEach(function (h) { - // handles can be null now? early exit to guard against this - if (!h) { return; } + const sockets = []; + process._getActiveHandles().forEach(function (h) { + // handles can be null now? early exit to guard against this + if (!h) { + return; + } - if (h instanceof Socket) { - if ((h.fd == null) && (h.localAddress) && !(h.destroyed)) { sockets.push(h); } - } - }); + if (h instanceof Socket) { + if (h.fd == null && h.localAddress && !h.destroyed) { + sockets.push(h); + } + } + }); - return sockets + return sockets; } main(); diff --git a/integration-tests/mysql-client-tests/node/package-lock.json b/integration-tests/mysql-client-tests/node/package-lock.json index 9ebdad34d7..daaa50c802 100644 --- a/integration-tests/mysql-client-tests/node/package-lock.json +++ b/integration-tests/mysql-client-tests/node/package-lock.json @@ -9,9 +9,10 @@ "version": "1.0.0", "license": "ISC", "dependencies": { - "knex": "^1.0.7", + "knex": "^2.4.0", "mysql": "^2.18.1", - "mysql2": "^2.3.3" + "mysql2": "^2.3.3", + "wtfnode": "^0.9.1" } }, "node_modules/bignumber.js": { @@ -23,9 +24,9 @@ } }, "node_modules/colorette": { - "version": "2.0.16", - "resolved": "https://registry.npmjs.org/colorette/-/colorette-2.0.16.tgz", - "integrity": "sha512-hUewv7oMjCp+wkBv5Rm0v87eJhq4woh5rSR+42YSQJKecCqgIqNkZ6lAlQms/BwHPJA5NKMRlpxPRv0n8HQW6g==" + "version": "2.0.19", + "resolved": "https://registry.npmjs.org/colorette/-/colorette-2.0.19.tgz", + "integrity": "sha512-3tlv/dIP7FWvj3BsbHrGLJ6l/oKh1O3TcgBqMn+yyCagOxc23fyzDS6HypQbgxWbkpDnf52p1LuR4eWDQ/K9WQ==" }, "node_modules/commander": { "version": "9.2.0", @@ -163,11 +164,11 @@ "integrity": "sha1-u5NdSFgsuhaMBoNJV6VKPgcSTxE=" }, "node_modules/knex": { - "version": "1.0.7", - "resolved": "https://registry.npmjs.org/knex/-/knex-1.0.7.tgz", - "integrity": "sha512-89jxuRATt4qJMb9ZyyaKBy0pQ4d5h7eOFRqiNFnUvsgU+9WZ2eIaZKrAPG1+F3mgu5UloPUnkVE5Yo2sKZUs6Q==", + "version": "2.4.0", + "resolved": "https://registry.npmjs.org/knex/-/knex-2.4.0.tgz", + "integrity": "sha512-i0GWwqYp1Hs2yvc2rlDO6nzzkLhwdyOZKRdsMTB8ZxOs2IXQyL5rBjSbS1krowCh6V65T4X9CJaKtuIfkaPGSA==", "dependencies": { - "colorette": "2.0.16", + "colorette": "2.0.19", "commander": "^9.1.0", "debug": "4.3.4", "escalade": "^3.1.1", @@ -189,9 +190,6 @@ "node": ">=12" }, "peerDependenciesMeta": { - "@vscode/sqlite3": { - "optional": true - }, "better-sqlite3": { "optional": true }, @@ -207,6 +205,9 @@ "pg-native": { "optional": true }, + "sqlite3": { + "optional": true + }, "tedious": { "optional": true } @@ -435,6 +436,14 @@ "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", "integrity": "sha1-RQ1Nyfpw3nMnYvvS1KKJgUGaDM8=" }, + "node_modules/wtfnode": { + "version": "0.9.1", + "resolved": "https://registry.npmjs.org/wtfnode/-/wtfnode-0.9.1.tgz", + "integrity": "sha512-Ip6C2KeQPl/F3aP1EfOnPoQk14Udd9lffpoqWDNH3Xt78svxPbv53ngtmtfI0q2Te3oTq79XKTnRNXVIn/GsPA==", + "bin": { + "wtfnode": "proxy.js" + } + }, "node_modules/yallist": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", @@ -448,9 +457,9 @@ "integrity": "sha512-t/OYhhJ2SD+YGBQcjY8GzzDHEk9f3nerxjtfa6tlMXfe7frs/WozhvCNoGvpM0P3bNf3Gq5ZRMlGr5f3r4/N8A==" }, "colorette": { - "version": "2.0.16", - "resolved": "https://registry.npmjs.org/colorette/-/colorette-2.0.16.tgz", - "integrity": "sha512-hUewv7oMjCp+wkBv5Rm0v87eJhq4woh5rSR+42YSQJKecCqgIqNkZ6lAlQms/BwHPJA5NKMRlpxPRv0n8HQW6g==" + "version": "2.0.19", + "resolved": "https://registry.npmjs.org/colorette/-/colorette-2.0.19.tgz", + "integrity": "sha512-3tlv/dIP7FWvj3BsbHrGLJ6l/oKh1O3TcgBqMn+yyCagOxc23fyzDS6HypQbgxWbkpDnf52p1LuR4eWDQ/K9WQ==" }, "commander": { "version": "9.2.0", @@ -553,11 +562,11 @@ "integrity": "sha1-u5NdSFgsuhaMBoNJV6VKPgcSTxE=" }, "knex": { - "version": "1.0.7", - "resolved": "https://registry.npmjs.org/knex/-/knex-1.0.7.tgz", - "integrity": "sha512-89jxuRATt4qJMb9ZyyaKBy0pQ4d5h7eOFRqiNFnUvsgU+9WZ2eIaZKrAPG1+F3mgu5UloPUnkVE5Yo2sKZUs6Q==", + "version": "2.4.0", + "resolved": "https://registry.npmjs.org/knex/-/knex-2.4.0.tgz", + "integrity": "sha512-i0GWwqYp1Hs2yvc2rlDO6nzzkLhwdyOZKRdsMTB8ZxOs2IXQyL5rBjSbS1krowCh6V65T4X9CJaKtuIfkaPGSA==", "requires": { - "colorette": "2.0.16", + "colorette": "2.0.19", "commander": "^9.1.0", "debug": "4.3.4", "escalade": "^3.1.1", @@ -758,6 +767,11 @@ "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", "integrity": "sha1-RQ1Nyfpw3nMnYvvS1KKJgUGaDM8=" }, + "wtfnode": { + "version": "0.9.1", + "resolved": "https://registry.npmjs.org/wtfnode/-/wtfnode-0.9.1.tgz", + "integrity": "sha512-Ip6C2KeQPl/F3aP1EfOnPoQk14Udd9lffpoqWDNH3Xt78svxPbv53ngtmtfI0q2Te3oTq79XKTnRNXVIn/GsPA==" + }, "yallist": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", diff --git a/integration-tests/mysql-client-tests/node/package.json b/integration-tests/mysql-client-tests/node/package.json index 0885aefdd9..f1edcc04d1 100644 --- a/integration-tests/mysql-client-tests/node/package.json +++ b/integration-tests/mysql-client-tests/node/package.json @@ -3,13 +3,14 @@ "version": "1.0.0", "description": "A simple node command line utility to show how to connect a node application to a Dolt database using the MySQL connector.", "main": "index.js", + "type": "module", "scripts": { "test": "echo \"Error: no test specified\" && exit 1" }, "author": "", "license": "ISC", "dependencies": { - "knex": "^1.0.7", + "knex": "^2.4.0", "mysql": "^2.18.1", "mysql2": "^2.3.3", "wtfnode": "^0.9.1"