Write to root on every loop of sql shell. (#1215)

This pr fixes a problem where autocommit was not turned on for every single iteration of the shell loop.
This commit is contained in:
Vinai Rachakonda
2021-01-19 16:07:38 -05:00
committed by GitHub
parent f64c0286a3
commit cf037558dd
3 changed files with 97 additions and 51 deletions

View File

@@ -27,6 +27,20 @@ teardown() {
[[ "$output" =~ "pk" ]] || false
}
@test "sql shell writes to disk after every iteration (autocommit)" {
skiponwindows "Need to install expect and make this script work on windows."
run $BATS_TEST_DIRNAME/sql-shell.expect
echo "$output"
# 2 tables are created. 1 from above and 1 in the expect file.
[[ "$output" =~ "+-------------+" ]] || false
[[ "$output" =~ "| Table |" ]] || false
[[ "$output" =~ "+-------------+" ]] || false
[[ "$output" =~ "| test |" ]] || false
[[ "$output" =~ "| test_expect |" ]] || false
[[ "$output" =~ "+-------------+" ]] || false
}
@test "bad sql in sql shell should error" {
run dolt sql <<< "This is bad sql"
[ $status -eq 1 ]

21
bats/sql-shell.expect Executable file
View File

@@ -0,0 +1,21 @@
#!/usr/bin/expect
set timeout 1
spawn dolt sql
set id $spawn_id
expect -i id "doltsql> "
send -i id "CREATE TABLE test_expect (pk int primary key);\r"
expect -i id "doltsql> "
# spawn the second process
spawn dolt sql
set id2 $spawn_id
# Todo: Should this be a dolt ls instead ?
expect -i id2 "doltsql> "
send -i id2 "show tables;\r"
expect -i id eof
expect -i id2 eof

View File

@@ -251,9 +251,9 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE
if batchMode {
batchInput := strings.NewReader(query)
roots, verr = execBatch(sqlCtx, readOnly, mrEnv, roots, batchInput, format)
verr = execBatch(sqlCtx, readOnly, mrEnv, roots, batchInput, format)
} else {
roots, verr = execQuery(sqlCtx, readOnly, mrEnv, roots, query, format)
verr = execQuery(sqlCtx, readOnly, mrEnv, roots, query, format)
if verr != nil {
return HandleVErrAndExitCode(verr, usage)
@@ -263,7 +263,8 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE
if saveName != "" {
saveMessage := apr.GetValueOrDefault(messageFlag, "")
roots[currentDB], verr = saveQuery(ctx, roots[currentDB], dEnv, query, saveName, saveMessage)
roots[currentDB], verr = saveQuery(ctx, roots[currentDB], query, saveName, saveMessage)
verr = UpdateWorkingWithVErr(mrEnv[currentDB], roots[currentDB])
}
}
} else if savedQueryName, exOk := apr.GetValue(executeFlag); exOk {
@@ -274,7 +275,7 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE
}
cli.PrintErrf("Executing saved query '%s':\n%s\n", savedQueryName, sq.Query)
roots, verr = execQuery(sqlCtx, readOnly, mrEnv, roots, sq.Query, format)
verr = execQuery(sqlCtx, readOnly, mrEnv, roots, sq.Query, format)
} else if apr.Contains(listSavedFlag) {
hasQC, err := roots[currentDB].HasTable(ctx, doltdb.DoltQueryCatalogTableName)
@@ -288,7 +289,7 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE
}
query := "SELECT * FROM " + doltdb.DoltQueryCatalogTableName
_, verr = execQuery(sqlCtx, readOnly, mrEnv, roots, query, format)
verr = execQuery(sqlCtx, readOnly, mrEnv, roots, query, format)
} else {
// Run in either batch mode for piped input, or shell mode for interactive
runInBatchMode := true
@@ -303,9 +304,9 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE
}
if runInBatchMode {
roots, verr = execBatch(sqlCtx, readOnly, mrEnv, roots, os.Stdin, format)
verr = execBatch(sqlCtx, readOnly, mrEnv, roots, os.Stdin, format)
} else {
roots, verr = execShell(sqlCtx, readOnly, mrEnv, roots, format)
verr = execShell(sqlCtx, readOnly, mrEnv, roots, format)
}
}
@@ -313,56 +314,36 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE
return HandleVErrAndExitCode(verr, usage)
}
// If the SQL session wrote a new root value, update the working set with it
for name, origRoot := range initialRoots {
root := roots[name]
if origRoot != root {
currEnv := mrEnv[name]
verr = UpdateWorkingWithVErr(currEnv, root)
}
}
return HandleVErrAndExitCode(verr, usage)
}
func execShell(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, format resultFormat) (map[string]*doltdb.RootValue, errhand.VerboseError) {
func execShell(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, format resultFormat) errhand.VerboseError {
dbs := CollectDBs(mrEnv, newDatabase)
se, err := newSqlEngine(sqlCtx, readOnly, mrEnv, roots, format, dbs...)
if err != nil {
return nil, errhand.VerboseErrorFromError(err)
return errhand.VerboseErrorFromError(err)
}
err = runShell(sqlCtx, se, mrEnv)
err = runShell(sqlCtx, se, mrEnv, roots)
if err != nil {
return nil, errhand.BuildDError("unable to start shell").AddCause(err).Build()
return errhand.BuildDError(err.Error()).Build()
}
newRoots, err := se.getRoots(sqlCtx)
if err != nil {
return nil, errhand.BuildDError("failed to get roots").AddCause(err).Build()
}
return newRoots, nil
return nil
}
func execBatch(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, batchInput io.Reader, format resultFormat) (map[string]*doltdb.RootValue, errhand.VerboseError) {
func execBatch(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, batchInput io.Reader, format resultFormat) errhand.VerboseError {
dbs := CollectDBs(mrEnv, newBatchedDatabase)
se, err := newSqlEngine(sqlCtx, readOnly, mrEnv, roots, format, dbs...)
if err != nil {
return nil, errhand.VerboseErrorFromError(err)
return errhand.VerboseErrorFromError(err)
}
err = runBatchMode(sqlCtx, se, batchInput)
if err != nil {
return nil, errhand.BuildDError("Error processing batch").Build()
return errhand.BuildDError("Error processing batch").Build()
}
newRoots, err := se.getRoots(sqlCtx)
if err != nil {
return nil, errhand.BuildDError("failed to get roots").AddCause(err).Build()
}
return newRoots, nil
return writeRoots(sqlCtx, se, mrEnv, roots)
}
type createDBFunc func(name string, dEnv *env.DoltEnv) dsqle.Database
@@ -375,32 +356,26 @@ func newBatchedDatabase(name string, dEnv *env.DoltEnv) dsqle.Database {
return dsqle.NewBatchedDatabase(name, dEnv.DbData())
}
func execQuery(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, query string, format resultFormat) (newRoot map[string]*doltdb.RootValue, verr errhand.VerboseError) {
func execQuery(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, query string, format resultFormat) errhand.VerboseError {
dbs := CollectDBs(mrEnv, newDatabase)
se, err := newSqlEngine(sqlCtx, readOnly, mrEnv, roots, format, dbs...)
if err != nil {
return nil, errhand.VerboseErrorFromError(err)
return errhand.VerboseErrorFromError(err)
}
sqlSch, rowIter, err := processQuery(sqlCtx, query, se)
if err != nil {
verr := formatQueryError("", err)
return nil, verr
return formatQueryError("", err)
}
if rowIter != nil {
err = PrettyPrintResults(sqlCtx, se.resultFormat, sqlSch, rowIter)
if err != nil {
return nil, errhand.VerboseErrorFromError(err)
return errhand.VerboseErrorFromError(err)
}
}
newRoots, err := se.getRoots(sqlCtx)
if err != nil {
return nil, errhand.BuildDError("failed to get roots").AddCause(err).Build()
}
return newRoots, nil
return writeRoots(sqlCtx, se, mrEnv, roots)
}
// CollectDBs takes a MultiRepoEnv and creates Database objects from each environment and returns a slice of these
@@ -561,7 +536,7 @@ func validateSqlArgs(apr *argparser.ArgParseResults) error {
}
// Saves the query given to the catalog with the name and message given.
func saveQuery(ctx context.Context, root *doltdb.RootValue, dEnv *env.DoltEnv, query string, name string, message string) (*doltdb.RootValue, errhand.VerboseError) {
func saveQuery(ctx context.Context, root *doltdb.RootValue, query string, name string, message string) (*doltdb.RootValue, errhand.VerboseError) {
_, newRoot, err := dtables.NewQueryCatalogEntryWithNameAsID(ctx, root, name, query, message)
if err != nil {
return nil, errhand.BuildDError("Couldn't save query").AddCause(err).Build()
@@ -601,7 +576,7 @@ func runBatchMode(ctx *sql.Context, se *sqlEngine, input io.Reader) error {
// runShell starts a SQL shell. Returns when the user exits the shell. The Root of the sqlEngine may
// be updated by any queries which were processed.
func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error {
func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv, initialRoots map[string]*doltdb.RootValue) error {
_ = iohelp.WriteLine(cli.CliOut, welcomeMsg)
currentDB := ctx.Session.GetCurrentDatabase()
currEnv := mrEnv[currentDB]
@@ -650,6 +625,7 @@ func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error {
}
})
var returnedVerr errhand.VerboseError = nil // Verr that cannot be just printed but needs to be returned.
shell.Uninterpreted(func(c *ishell.Context) {
query := c.Args[0]
if len(strings.TrimSpace(query)) == 0 {
@@ -661,7 +637,9 @@ func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error {
// https://github.com/cockroachdb/cockroach/issues/15460
// For now, we store all history entries as single-line strings to avoid the issue.
singleLine := strings.ReplaceAll(query, "\n", " ")
if err := shell.AddHistory(singleLine); err != nil {
var err error
if err = shell.AddHistory(singleLine); err != nil {
// TODO: handle better, like by turning off history writing for the rest of the session
shell.Println(color.RedString(err.Error()))
}
@@ -676,6 +654,14 @@ func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error {
}
}
if err == nil {
returnedVerr = writeRoots(ctx, se, mrEnv, initialRoots)
if returnedVerr != nil {
return
}
}
currPrompt := fmt.Sprintf("%s> ", ctx.GetCurrentDatabase())
shell.SetPrompt(currPrompt)
shell.SetMultiPrompt(fmt.Sprintf(fmt.Sprintf("%%%ds", len(currPrompt)), "-> "))
@@ -684,7 +670,32 @@ func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error {
shell.Run()
_ = iohelp.WriteLine(cli.CliOut, "Bye")
return nil
return returnedVerr
}
// writeRoots updates the working root values using the sql context, the sql engine, a multi repo env and a root_val map.
func writeRoots(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv, initialRoots map[string]*doltdb.RootValue) errhand.VerboseError {
roots, err := se.getRoots(ctx)
if err != nil {
return errhand.BuildDError("failed to get roots").AddCause(err).Build()
}
// If the SQL session wrote a new root value, update the working set with it
var verr errhand.VerboseError
for name, origRoot := range initialRoots {
root := roots[name]
if origRoot != root {
currEnv := mrEnv[name]
verr = UpdateWorkingWithVErr(currEnv, root)
if verr != nil {
return verr
}
}
}
return verr
}
// Returns a new auto completer with table names, column names, and SQL keywords.