mirror of
https://github.com/dolthub/dolt.git
synced 2025-12-30 16:12:39 -06:00
Write to root on every loop of sql shell. (#1215)
This pr fixes a problem where autocommit was not turned on for every single iteration of the shell loop.
This commit is contained in:
@@ -27,6 +27,20 @@ teardown() {
|
||||
[[ "$output" =~ "pk" ]] || false
|
||||
}
|
||||
|
||||
@test "sql shell writes to disk after every iteration (autocommit)" {
|
||||
skiponwindows "Need to install expect and make this script work on windows."
|
||||
run $BATS_TEST_DIRNAME/sql-shell.expect
|
||||
echo "$output"
|
||||
|
||||
# 2 tables are created. 1 from above and 1 in the expect file.
|
||||
[[ "$output" =~ "+-------------+" ]] || false
|
||||
[[ "$output" =~ "| Table |" ]] || false
|
||||
[[ "$output" =~ "+-------------+" ]] || false
|
||||
[[ "$output" =~ "| test |" ]] || false
|
||||
[[ "$output" =~ "| test_expect |" ]] || false
|
||||
[[ "$output" =~ "+-------------+" ]] || false
|
||||
}
|
||||
|
||||
@test "bad sql in sql shell should error" {
|
||||
run dolt sql <<< "This is bad sql"
|
||||
[ $status -eq 1 ]
|
||||
|
||||
21
bats/sql-shell.expect
Executable file
21
bats/sql-shell.expect
Executable file
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/expect
|
||||
|
||||
set timeout 1
|
||||
spawn dolt sql
|
||||
set id $spawn_id
|
||||
|
||||
expect -i id "doltsql> "
|
||||
send -i id "CREATE TABLE test_expect (pk int primary key);\r"
|
||||
|
||||
expect -i id "doltsql> "
|
||||
|
||||
# spawn the second process
|
||||
spawn dolt sql
|
||||
set id2 $spawn_id
|
||||
|
||||
# Todo: Should this be a dolt ls instead ?
|
||||
expect -i id2 "doltsql> "
|
||||
send -i id2 "show tables;\r"
|
||||
|
||||
expect -i id eof
|
||||
expect -i id2 eof
|
||||
@@ -251,9 +251,9 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE
|
||||
|
||||
if batchMode {
|
||||
batchInput := strings.NewReader(query)
|
||||
roots, verr = execBatch(sqlCtx, readOnly, mrEnv, roots, batchInput, format)
|
||||
verr = execBatch(sqlCtx, readOnly, mrEnv, roots, batchInput, format)
|
||||
} else {
|
||||
roots, verr = execQuery(sqlCtx, readOnly, mrEnv, roots, query, format)
|
||||
verr = execQuery(sqlCtx, readOnly, mrEnv, roots, query, format)
|
||||
|
||||
if verr != nil {
|
||||
return HandleVErrAndExitCode(verr, usage)
|
||||
@@ -263,7 +263,8 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE
|
||||
|
||||
if saveName != "" {
|
||||
saveMessage := apr.GetValueOrDefault(messageFlag, "")
|
||||
roots[currentDB], verr = saveQuery(ctx, roots[currentDB], dEnv, query, saveName, saveMessage)
|
||||
roots[currentDB], verr = saveQuery(ctx, roots[currentDB], query, saveName, saveMessage)
|
||||
verr = UpdateWorkingWithVErr(mrEnv[currentDB], roots[currentDB])
|
||||
}
|
||||
}
|
||||
} else if savedQueryName, exOk := apr.GetValue(executeFlag); exOk {
|
||||
@@ -274,7 +275,7 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE
|
||||
}
|
||||
|
||||
cli.PrintErrf("Executing saved query '%s':\n%s\n", savedQueryName, sq.Query)
|
||||
roots, verr = execQuery(sqlCtx, readOnly, mrEnv, roots, sq.Query, format)
|
||||
verr = execQuery(sqlCtx, readOnly, mrEnv, roots, sq.Query, format)
|
||||
} else if apr.Contains(listSavedFlag) {
|
||||
hasQC, err := roots[currentDB].HasTable(ctx, doltdb.DoltQueryCatalogTableName)
|
||||
|
||||
@@ -288,7 +289,7 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE
|
||||
}
|
||||
|
||||
query := "SELECT * FROM " + doltdb.DoltQueryCatalogTableName
|
||||
_, verr = execQuery(sqlCtx, readOnly, mrEnv, roots, query, format)
|
||||
verr = execQuery(sqlCtx, readOnly, mrEnv, roots, query, format)
|
||||
} else {
|
||||
// Run in either batch mode for piped input, or shell mode for interactive
|
||||
runInBatchMode := true
|
||||
@@ -303,9 +304,9 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE
|
||||
}
|
||||
|
||||
if runInBatchMode {
|
||||
roots, verr = execBatch(sqlCtx, readOnly, mrEnv, roots, os.Stdin, format)
|
||||
verr = execBatch(sqlCtx, readOnly, mrEnv, roots, os.Stdin, format)
|
||||
} else {
|
||||
roots, verr = execShell(sqlCtx, readOnly, mrEnv, roots, format)
|
||||
verr = execShell(sqlCtx, readOnly, mrEnv, roots, format)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -313,56 +314,36 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE
|
||||
return HandleVErrAndExitCode(verr, usage)
|
||||
}
|
||||
|
||||
// If the SQL session wrote a new root value, update the working set with it
|
||||
for name, origRoot := range initialRoots {
|
||||
root := roots[name]
|
||||
if origRoot != root {
|
||||
currEnv := mrEnv[name]
|
||||
verr = UpdateWorkingWithVErr(currEnv, root)
|
||||
}
|
||||
}
|
||||
|
||||
return HandleVErrAndExitCode(verr, usage)
|
||||
}
|
||||
|
||||
func execShell(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, format resultFormat) (map[string]*doltdb.RootValue, errhand.VerboseError) {
|
||||
func execShell(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, format resultFormat) errhand.VerboseError {
|
||||
dbs := CollectDBs(mrEnv, newDatabase)
|
||||
se, err := newSqlEngine(sqlCtx, readOnly, mrEnv, roots, format, dbs...)
|
||||
if err != nil {
|
||||
return nil, errhand.VerboseErrorFromError(err)
|
||||
return errhand.VerboseErrorFromError(err)
|
||||
}
|
||||
|
||||
err = runShell(sqlCtx, se, mrEnv)
|
||||
err = runShell(sqlCtx, se, mrEnv, roots)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("unable to start shell").AddCause(err).Build()
|
||||
return errhand.BuildDError(err.Error()).Build()
|
||||
}
|
||||
|
||||
newRoots, err := se.getRoots(sqlCtx)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("failed to get roots").AddCause(err).Build()
|
||||
}
|
||||
|
||||
return newRoots, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func execBatch(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, batchInput io.Reader, format resultFormat) (map[string]*doltdb.RootValue, errhand.VerboseError) {
|
||||
func execBatch(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, batchInput io.Reader, format resultFormat) errhand.VerboseError {
|
||||
dbs := CollectDBs(mrEnv, newBatchedDatabase)
|
||||
se, err := newSqlEngine(sqlCtx, readOnly, mrEnv, roots, format, dbs...)
|
||||
if err != nil {
|
||||
return nil, errhand.VerboseErrorFromError(err)
|
||||
return errhand.VerboseErrorFromError(err)
|
||||
}
|
||||
|
||||
err = runBatchMode(sqlCtx, se, batchInput)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("Error processing batch").Build()
|
||||
return errhand.BuildDError("Error processing batch").Build()
|
||||
}
|
||||
|
||||
newRoots, err := se.getRoots(sqlCtx)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("failed to get roots").AddCause(err).Build()
|
||||
}
|
||||
|
||||
return newRoots, nil
|
||||
return writeRoots(sqlCtx, se, mrEnv, roots)
|
||||
}
|
||||
|
||||
type createDBFunc func(name string, dEnv *env.DoltEnv) dsqle.Database
|
||||
@@ -375,32 +356,26 @@ func newBatchedDatabase(name string, dEnv *env.DoltEnv) dsqle.Database {
|
||||
return dsqle.NewBatchedDatabase(name, dEnv.DbData())
|
||||
}
|
||||
|
||||
func execQuery(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, query string, format resultFormat) (newRoot map[string]*doltdb.RootValue, verr errhand.VerboseError) {
|
||||
func execQuery(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, query string, format resultFormat) errhand.VerboseError {
|
||||
dbs := CollectDBs(mrEnv, newDatabase)
|
||||
se, err := newSqlEngine(sqlCtx, readOnly, mrEnv, roots, format, dbs...)
|
||||
if err != nil {
|
||||
return nil, errhand.VerboseErrorFromError(err)
|
||||
return errhand.VerboseErrorFromError(err)
|
||||
}
|
||||
|
||||
sqlSch, rowIter, err := processQuery(sqlCtx, query, se)
|
||||
if err != nil {
|
||||
verr := formatQueryError("", err)
|
||||
return nil, verr
|
||||
return formatQueryError("", err)
|
||||
}
|
||||
|
||||
if rowIter != nil {
|
||||
err = PrettyPrintResults(sqlCtx, se.resultFormat, sqlSch, rowIter)
|
||||
if err != nil {
|
||||
return nil, errhand.VerboseErrorFromError(err)
|
||||
return errhand.VerboseErrorFromError(err)
|
||||
}
|
||||
}
|
||||
|
||||
newRoots, err := se.getRoots(sqlCtx)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("failed to get roots").AddCause(err).Build()
|
||||
}
|
||||
|
||||
return newRoots, nil
|
||||
return writeRoots(sqlCtx, se, mrEnv, roots)
|
||||
}
|
||||
|
||||
// CollectDBs takes a MultiRepoEnv and creates Database objects from each environment and returns a slice of these
|
||||
@@ -561,7 +536,7 @@ func validateSqlArgs(apr *argparser.ArgParseResults) error {
|
||||
}
|
||||
|
||||
// Saves the query given to the catalog with the name and message given.
|
||||
func saveQuery(ctx context.Context, root *doltdb.RootValue, dEnv *env.DoltEnv, query string, name string, message string) (*doltdb.RootValue, errhand.VerboseError) {
|
||||
func saveQuery(ctx context.Context, root *doltdb.RootValue, query string, name string, message string) (*doltdb.RootValue, errhand.VerboseError) {
|
||||
_, newRoot, err := dtables.NewQueryCatalogEntryWithNameAsID(ctx, root, name, query, message)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("Couldn't save query").AddCause(err).Build()
|
||||
@@ -601,7 +576,7 @@ func runBatchMode(ctx *sql.Context, se *sqlEngine, input io.Reader) error {
|
||||
|
||||
// runShell starts a SQL shell. Returns when the user exits the shell. The Root of the sqlEngine may
|
||||
// be updated by any queries which were processed.
|
||||
func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error {
|
||||
func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv, initialRoots map[string]*doltdb.RootValue) error {
|
||||
_ = iohelp.WriteLine(cli.CliOut, welcomeMsg)
|
||||
currentDB := ctx.Session.GetCurrentDatabase()
|
||||
currEnv := mrEnv[currentDB]
|
||||
@@ -650,6 +625,7 @@ func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error {
|
||||
}
|
||||
})
|
||||
|
||||
var returnedVerr errhand.VerboseError = nil // Verr that cannot be just printed but needs to be returned.
|
||||
shell.Uninterpreted(func(c *ishell.Context) {
|
||||
query := c.Args[0]
|
||||
if len(strings.TrimSpace(query)) == 0 {
|
||||
@@ -661,7 +637,9 @@ func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error {
|
||||
// https://github.com/cockroachdb/cockroach/issues/15460
|
||||
// For now, we store all history entries as single-line strings to avoid the issue.
|
||||
singleLine := strings.ReplaceAll(query, "\n", " ")
|
||||
if err := shell.AddHistory(singleLine); err != nil {
|
||||
|
||||
var err error
|
||||
if err = shell.AddHistory(singleLine); err != nil {
|
||||
// TODO: handle better, like by turning off history writing for the rest of the session
|
||||
shell.Println(color.RedString(err.Error()))
|
||||
}
|
||||
@@ -676,6 +654,14 @@ func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
returnedVerr = writeRoots(ctx, se, mrEnv, initialRoots)
|
||||
|
||||
if returnedVerr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
currPrompt := fmt.Sprintf("%s> ", ctx.GetCurrentDatabase())
|
||||
shell.SetPrompt(currPrompt)
|
||||
shell.SetMultiPrompt(fmt.Sprintf(fmt.Sprintf("%%%ds", len(currPrompt)), "-> "))
|
||||
@@ -684,7 +670,32 @@ func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error {
|
||||
shell.Run()
|
||||
_ = iohelp.WriteLine(cli.CliOut, "Bye")
|
||||
|
||||
return nil
|
||||
return returnedVerr
|
||||
}
|
||||
|
||||
// writeRoots updates the working root values using the sql context, the sql engine, a multi repo env and a root_val map.
|
||||
func writeRoots(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv, initialRoots map[string]*doltdb.RootValue) errhand.VerboseError {
|
||||
roots, err := se.getRoots(ctx)
|
||||
|
||||
if err != nil {
|
||||
return errhand.BuildDError("failed to get roots").AddCause(err).Build()
|
||||
}
|
||||
|
||||
// If the SQL session wrote a new root value, update the working set with it
|
||||
var verr errhand.VerboseError
|
||||
for name, origRoot := range initialRoots {
|
||||
root := roots[name]
|
||||
if origRoot != root {
|
||||
currEnv := mrEnv[name]
|
||||
verr = UpdateWorkingWithVErr(currEnv, root)
|
||||
|
||||
if verr != nil {
|
||||
return verr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return verr
|
||||
}
|
||||
|
||||
// Returns a new auto completer with table names, column names, and SQL keywords.
|
||||
|
||||
Reference in New Issue
Block a user