diff --git a/bats/sql-shell.bats b/bats/sql-shell.bats index e993e487f7..43dc341a91 100644 --- a/bats/sql-shell.bats +++ b/bats/sql-shell.bats @@ -27,6 +27,20 @@ teardown() { [[ "$output" =~ "pk" ]] || false } +@test "sql shell writes to disk after every iteration (autocommit)" { + skiponwindows "Need to install expect and make this script work on windows." + run $BATS_TEST_DIRNAME/sql-shell.expect + echo "$output" + + # 2 tables are created. 1 from above and 1 in the expect file. + [[ "$output" =~ "+-------------+" ]] || false + [[ "$output" =~ "| Table |" ]] || false + [[ "$output" =~ "+-------------+" ]] || false + [[ "$output" =~ "| test |" ]] || false + [[ "$output" =~ "| test_expect |" ]] || false + [[ "$output" =~ "+-------------+" ]] || false +} + @test "bad sql in sql shell should error" { run dolt sql <<< "This is bad sql" [ $status -eq 1 ] diff --git a/bats/sql-shell.expect b/bats/sql-shell.expect new file mode 100755 index 0000000000..5fa8acd609 --- /dev/null +++ b/bats/sql-shell.expect @@ -0,0 +1,21 @@ +#!/usr/bin/expect + +set timeout 1 +spawn dolt sql +set id $spawn_id + +expect -i id "doltsql> " +send -i id "CREATE TABLE test_expect (pk int primary key);\r" + +expect -i id "doltsql> " + +# spawn the second process +spawn dolt sql +set id2 $spawn_id + +# Todo: Should this be a dolt ls instead ? +expect -i id2 "doltsql> " +send -i id2 "show tables;\r" + +expect -i id eof +expect -i id2 eof \ No newline at end of file diff --git a/go/cmd/dolt/commands/sql.go b/go/cmd/dolt/commands/sql.go index 8ac01c1214..528841da10 100644 --- a/go/cmd/dolt/commands/sql.go +++ b/go/cmd/dolt/commands/sql.go @@ -251,9 +251,9 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE if batchMode { batchInput := strings.NewReader(query) - roots, verr = execBatch(sqlCtx, readOnly, mrEnv, roots, batchInput, format) + verr = execBatch(sqlCtx, readOnly, mrEnv, roots, batchInput, format) } else { - roots, verr = execQuery(sqlCtx, readOnly, mrEnv, roots, query, format) + verr = execQuery(sqlCtx, readOnly, mrEnv, roots, query, format) if verr != nil { return HandleVErrAndExitCode(verr, usage) @@ -263,7 +263,8 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE if saveName != "" { saveMessage := apr.GetValueOrDefault(messageFlag, "") - roots[currentDB], verr = saveQuery(ctx, roots[currentDB], dEnv, query, saveName, saveMessage) + roots[currentDB], verr = saveQuery(ctx, roots[currentDB], query, saveName, saveMessage) + verr = UpdateWorkingWithVErr(mrEnv[currentDB], roots[currentDB]) } } } else if savedQueryName, exOk := apr.GetValue(executeFlag); exOk { @@ -274,7 +275,7 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE } cli.PrintErrf("Executing saved query '%s':\n%s\n", savedQueryName, sq.Query) - roots, verr = execQuery(sqlCtx, readOnly, mrEnv, roots, sq.Query, format) + verr = execQuery(sqlCtx, readOnly, mrEnv, roots, sq.Query, format) } else if apr.Contains(listSavedFlag) { hasQC, err := roots[currentDB].HasTable(ctx, doltdb.DoltQueryCatalogTableName) @@ -288,7 +289,7 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE } query := "SELECT * FROM " + doltdb.DoltQueryCatalogTableName - _, verr = execQuery(sqlCtx, readOnly, mrEnv, roots, query, format) + verr = execQuery(sqlCtx, readOnly, mrEnv, roots, query, format) } else { // Run in either batch mode for piped input, or shell mode for interactive runInBatchMode := true @@ -303,9 +304,9 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE } if runInBatchMode { - roots, verr = execBatch(sqlCtx, readOnly, mrEnv, roots, os.Stdin, format) + verr = execBatch(sqlCtx, readOnly, mrEnv, roots, os.Stdin, format) } else { - roots, verr = execShell(sqlCtx, readOnly, mrEnv, roots, format) + verr = execShell(sqlCtx, readOnly, mrEnv, roots, format) } } @@ -313,56 +314,36 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE return HandleVErrAndExitCode(verr, usage) } - // If the SQL session wrote a new root value, update the working set with it - for name, origRoot := range initialRoots { - root := roots[name] - if origRoot != root { - currEnv := mrEnv[name] - verr = UpdateWorkingWithVErr(currEnv, root) - } - } - return HandleVErrAndExitCode(verr, usage) } -func execShell(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, format resultFormat) (map[string]*doltdb.RootValue, errhand.VerboseError) { +func execShell(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, format resultFormat) errhand.VerboseError { dbs := CollectDBs(mrEnv, newDatabase) se, err := newSqlEngine(sqlCtx, readOnly, mrEnv, roots, format, dbs...) if err != nil { - return nil, errhand.VerboseErrorFromError(err) + return errhand.VerboseErrorFromError(err) } - err = runShell(sqlCtx, se, mrEnv) + err = runShell(sqlCtx, se, mrEnv, roots) if err != nil { - return nil, errhand.BuildDError("unable to start shell").AddCause(err).Build() + return errhand.BuildDError(err.Error()).Build() } - - newRoots, err := se.getRoots(sqlCtx) - if err != nil { - return nil, errhand.BuildDError("failed to get roots").AddCause(err).Build() - } - - return newRoots, nil + return nil } -func execBatch(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, batchInput io.Reader, format resultFormat) (map[string]*doltdb.RootValue, errhand.VerboseError) { +func execBatch(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, batchInput io.Reader, format resultFormat) errhand.VerboseError { dbs := CollectDBs(mrEnv, newBatchedDatabase) se, err := newSqlEngine(sqlCtx, readOnly, mrEnv, roots, format, dbs...) if err != nil { - return nil, errhand.VerboseErrorFromError(err) + return errhand.VerboseErrorFromError(err) } err = runBatchMode(sqlCtx, se, batchInput) if err != nil { - return nil, errhand.BuildDError("Error processing batch").Build() + return errhand.BuildDError("Error processing batch").Build() } - newRoots, err := se.getRoots(sqlCtx) - if err != nil { - return nil, errhand.BuildDError("failed to get roots").AddCause(err).Build() - } - - return newRoots, nil + return writeRoots(sqlCtx, se, mrEnv, roots) } type createDBFunc func(name string, dEnv *env.DoltEnv) dsqle.Database @@ -375,32 +356,26 @@ func newBatchedDatabase(name string, dEnv *env.DoltEnv) dsqle.Database { return dsqle.NewBatchedDatabase(name, dEnv.DbData()) } -func execQuery(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, query string, format resultFormat) (newRoot map[string]*doltdb.RootValue, verr errhand.VerboseError) { +func execQuery(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, query string, format resultFormat) errhand.VerboseError { dbs := CollectDBs(mrEnv, newDatabase) se, err := newSqlEngine(sqlCtx, readOnly, mrEnv, roots, format, dbs...) if err != nil { - return nil, errhand.VerboseErrorFromError(err) + return errhand.VerboseErrorFromError(err) } sqlSch, rowIter, err := processQuery(sqlCtx, query, se) if err != nil { - verr := formatQueryError("", err) - return nil, verr + return formatQueryError("", err) } if rowIter != nil { err = PrettyPrintResults(sqlCtx, se.resultFormat, sqlSch, rowIter) if err != nil { - return nil, errhand.VerboseErrorFromError(err) + return errhand.VerboseErrorFromError(err) } } - newRoots, err := se.getRoots(sqlCtx) - if err != nil { - return nil, errhand.BuildDError("failed to get roots").AddCause(err).Build() - } - - return newRoots, nil + return writeRoots(sqlCtx, se, mrEnv, roots) } // CollectDBs takes a MultiRepoEnv and creates Database objects from each environment and returns a slice of these @@ -561,7 +536,7 @@ func validateSqlArgs(apr *argparser.ArgParseResults) error { } // Saves the query given to the catalog with the name and message given. -func saveQuery(ctx context.Context, root *doltdb.RootValue, dEnv *env.DoltEnv, query string, name string, message string) (*doltdb.RootValue, errhand.VerboseError) { +func saveQuery(ctx context.Context, root *doltdb.RootValue, query string, name string, message string) (*doltdb.RootValue, errhand.VerboseError) { _, newRoot, err := dtables.NewQueryCatalogEntryWithNameAsID(ctx, root, name, query, message) if err != nil { return nil, errhand.BuildDError("Couldn't save query").AddCause(err).Build() @@ -601,7 +576,7 @@ func runBatchMode(ctx *sql.Context, se *sqlEngine, input io.Reader) error { // runShell starts a SQL shell. Returns when the user exits the shell. The Root of the sqlEngine may // be updated by any queries which were processed. -func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error { +func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv, initialRoots map[string]*doltdb.RootValue) error { _ = iohelp.WriteLine(cli.CliOut, welcomeMsg) currentDB := ctx.Session.GetCurrentDatabase() currEnv := mrEnv[currentDB] @@ -650,6 +625,7 @@ func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error { } }) + var returnedVerr errhand.VerboseError = nil // Verr that cannot be just printed but needs to be returned. shell.Uninterpreted(func(c *ishell.Context) { query := c.Args[0] if len(strings.TrimSpace(query)) == 0 { @@ -661,7 +637,9 @@ func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error { // https://github.com/cockroachdb/cockroach/issues/15460 // For now, we store all history entries as single-line strings to avoid the issue. singleLine := strings.ReplaceAll(query, "\n", " ") - if err := shell.AddHistory(singleLine); err != nil { + + var err error + if err = shell.AddHistory(singleLine); err != nil { // TODO: handle better, like by turning off history writing for the rest of the session shell.Println(color.RedString(err.Error())) } @@ -676,6 +654,14 @@ func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error { } } + if err == nil { + returnedVerr = writeRoots(ctx, se, mrEnv, initialRoots) + + if returnedVerr != nil { + return + } + } + currPrompt := fmt.Sprintf("%s> ", ctx.GetCurrentDatabase()) shell.SetPrompt(currPrompt) shell.SetMultiPrompt(fmt.Sprintf(fmt.Sprintf("%%%ds", len(currPrompt)), "-> ")) @@ -684,7 +670,32 @@ func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error { shell.Run() _ = iohelp.WriteLine(cli.CliOut, "Bye") - return nil + return returnedVerr +} + +// writeRoots updates the working root values using the sql context, the sql engine, a multi repo env and a root_val map. +func writeRoots(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv, initialRoots map[string]*doltdb.RootValue) errhand.VerboseError { + roots, err := se.getRoots(ctx) + + if err != nil { + return errhand.BuildDError("failed to get roots").AddCause(err).Build() + } + + // If the SQL session wrote a new root value, update the working set with it + var verr errhand.VerboseError + for name, origRoot := range initialRoots { + root := roots[name] + if origRoot != root { + currEnv := mrEnv[name] + verr = UpdateWorkingWithVErr(currEnv, root) + + if verr != nil { + return verr + } + } + } + + return verr } // Returns a new auto completer with table names, column names, and SQL keywords.