mirror of
https://github.com/dolthub/dolt.git
synced 2026-05-04 19:41:26 -05:00
2178d674e8
Signed-off-by: Zach Musgrave <zach@liquidata.co>
1370 lines
40 KiB
Go
1370 lines
40 KiB
Go
// Copyright 2019-2020 Liquidata, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package commands
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/abiosoft/readline"
|
|
"github.com/fatih/color"
|
|
"github.com/flynn-archive/go-shlex"
|
|
"github.com/liquidata-inc/ishell"
|
|
sqle "github.com/src-d/go-mysql-server"
|
|
"github.com/src-d/go-mysql-server/sql"
|
|
"gopkg.in/src-d/go-errors.v1"
|
|
"vitess.io/vitess/go/vt/sqlparser"
|
|
"vitess.io/vitess/go/vt/vterrors"
|
|
|
|
"github.com/liquidata-inc/dolt/go/cmd/dolt/cli"
|
|
"github.com/liquidata-inc/dolt/go/cmd/dolt/errhand"
|
|
eventsapi "github.com/liquidata-inc/dolt/go/gen/proto/dolt/services/eventsapi/v1alpha1"
|
|
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
|
|
"github.com/liquidata-inc/dolt/go/libraries/doltcore/env"
|
|
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
|
|
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema"
|
|
dsqle "github.com/liquidata-inc/dolt/go/libraries/doltcore/sqle"
|
|
_ "github.com/liquidata-inc/dolt/go/libraries/doltcore/sqle/dfunctions"
|
|
"github.com/liquidata-inc/dolt/go/libraries/doltcore/table"
|
|
"github.com/liquidata-inc/dolt/go/libraries/doltcore/table/pipeline"
|
|
"github.com/liquidata-inc/dolt/go/libraries/doltcore/table/typed/json"
|
|
"github.com/liquidata-inc/dolt/go/libraries/doltcore/table/untyped"
|
|
"github.com/liquidata-inc/dolt/go/libraries/doltcore/table/untyped/csv"
|
|
"github.com/liquidata-inc/dolt/go/libraries/doltcore/table/untyped/fwt"
|
|
"github.com/liquidata-inc/dolt/go/libraries/doltcore/table/untyped/nullprinter"
|
|
"github.com/liquidata-inc/dolt/go/libraries/doltcore/table/untyped/tabular"
|
|
"github.com/liquidata-inc/dolt/go/libraries/utils/argparser"
|
|
"github.com/liquidata-inc/dolt/go/libraries/utils/filesys"
|
|
"github.com/liquidata-inc/dolt/go/libraries/utils/iohelp"
|
|
"github.com/liquidata-inc/dolt/go/libraries/utils/osutil"
|
|
"github.com/liquidata-inc/dolt/go/store/types"
|
|
)
|
|
|
|
var sqlDocs = cli.CommandDocumentationContent{
|
|
ShortDesc: "Runs a SQL query",
|
|
LongDesc: "Runs a SQL query you specify. With no arguments, begins an interactive shell to run queries and view " +
|
|
"the results. With the {{.EmphasisLeft}}-q{{.EmphasisRight}} option, runs the given query and prints any " +
|
|
"results, then exits.\n" +
|
|
"\n" +
|
|
"By default, {{.EmphasisLeft}}-q{{.EmphasisRight}} executes a single statement. To execute multiple SQL " +
|
|
"statements separated by semicolons, use {{.EmphasisLeft}}-b{{.EmphasisRight}} to enable batch mode. Queries can" +
|
|
"be saved with {{.EmphasisLeft}}-s{{.EmphasisRight}}.\n" +
|
|
"\n" +
|
|
"Alternatively {{.EmpahasisLeft}}-x{{.EmphasisRight}} can be used to execute a saved query by name.\n" +
|
|
"\n" +
|
|
"Pipe SQL statements to dolt sql (no {{.EmphasisLeft}}-q{{.EmphasisRight}}) to execute a SQL import or update " +
|
|
"script.\n" +
|
|
"\n" +
|
|
"By default this command uses the dolt data repository in the current working directory as the one and only " +
|
|
"database. Running with {{.EmphasisLeft}}--multi-db-dir {{.LessThan}}directory{{.GreaterThan}}{{.EmphasisRight}} " +
|
|
"uses each of the subdirectories of the supplied directory (each subdirectory must be a valid dolt data repository) " +
|
|
"as databases. Subdirectories starting with '.' are ignored." +
|
|
"Known limitations:\n" +
|
|
"* No support for creating indexes\n" +
|
|
"* No support for foreign keys\n" +
|
|
"* No support for column constraints besides NOT NULL\n" +
|
|
"* No support for default values\n" +
|
|
"* Joins can only use indexes for two table joins. Three or more tables in a join query will use a non-indexed " +
|
|
"join, which is very slow.",
|
|
|
|
Synopsis: []string{
|
|
"[--multi-db-dir {{.LessThan}}directory{{.GreaterThan}}] [-r {{.LessThan}}result format{{.GreaterThan}}]",
|
|
"-q {{.LessThan}}query;query{{.GreaterThan}} [-r {{.LessThan}}result format{{.GreaterThan}}] -s {{.LessThan}}name{{.GreaterThan}} -m {{.LessThan}}message{{.GreaterThan}} [-b]",
|
|
"-q {{.LessThan}}query;query{{.GreaterThan}} --multi-db-dir {{.LessThan}}directory{{.GreaterThan}} [-r {{.LessThan}}result format{{.GreaterThan}}] [-b]",
|
|
"-x {{.LessThan}}name{{.GreaterThan}}",
|
|
"--list-saved",
|
|
},
|
|
}
|
|
|
|
const (
	// Names of the flags/options accepted by the sql command; registered in
	// createArgParser and checked for compatibility in validateSqlArgs.
	queryFlag      = "query"
	formatFlag     = "result-format"
	saveFlag       = "save"
	executeFlag    = "execute"
	listSavedFlag  = "list-saved"
	messageFlag    = "message"
	batchFlag      = "batch"
	multiDBDirFlag = "multi-db-dir"

	// welcomeMsg is printed when the interactive SQL shell starts.
	welcomeMsg = `# Welcome to the DoltSQL shell.
# Statements must be terminated with ';'.
# "exit" or "quit" (or Ctrl-D) to exit.`
)
|
|
|
|
// SqlCmd implements the `dolt sql` CLI command.
type SqlCmd struct {
	// VersionStr is the dolt version string; passed through when loading
	// multi-repo environments (see Exec).
	VersionStr string
}
|
|
|
|
// Name returns the name of the Dolt cli command. This is what is used on the command line to invoke the command.
func (cmd SqlCmd) Name() string {
	return "sql"
}
|
|
|
|
// Description returns a short, one-line description of the command for help output.
func (cmd SqlCmd) Description() string {
	return "Run a SQL query against tables in repository."
}
|
|
|
|
// CreateMarkdown creates a markdown file containing the helptext for the command at the given path
|
|
func (cmd SqlCmd) CreateMarkdown(fs filesys.Filesys, path, commandStr string) error {
|
|
ap := cmd.createArgParser()
|
|
return CreateMarkdown(fs, path, cli.GetCommandDocumentation(commandStr, sqlDocs, ap))
|
|
}
|
|
|
|
func (cmd SqlCmd) createArgParser() *argparser.ArgParser {
|
|
ap := argparser.NewArgParser()
|
|
ap.SupportsString(queryFlag, "q", "SQL query to run", "Runs a single query and exits")
|
|
ap.SupportsString(formatFlag, "r", "result output format", "How to format result output. Valid values are tabular, csv, json. Defaults to tabular. ")
|
|
ap.SupportsString(saveFlag, "s", "saved query name", "Used with --query, save the query to the query catalog with the name provided. Saved queries can be examined in the dolt_query_catalog system table.")
|
|
ap.SupportsString(executeFlag, "x", "saved query name", "Executes a saved query with the given name")
|
|
ap.SupportsFlag(listSavedFlag, "l", "Lists all saved queries")
|
|
ap.SupportsString(messageFlag, "m", "saved query description", "Used with --query and --save, saves the query with the descriptive message given. See also --name")
|
|
ap.SupportsFlag(batchFlag, "b", "batch mode, to run more than one query with --query, separated by ';'. Piping input to sql with no arguments also uses batch mode")
|
|
ap.SupportsString(multiDBDirFlag, "", "directory", "Defines a directory whose subdirectories should all be dolt data repositories accessible as independent databases within ")
|
|
return ap
|
|
}
|
|
|
|
// EventType returns the type of the event to log for metrics/telemetry.
func (cmd SqlCmd) EventType() eventsapi.ClientEventType {
	return eventsapi.ClientEventType_SQL
}
|
|
|
|
// RequiresRepo indicates that this command does not have to be run from within a dolt data repository directory.
// In this case it is because this command supports the multiDBDirFlag which can pass in a directory. In the event that
// that parameter is not provided there is additional error handling within this command to make sure that this was in
// fact run from within a dolt data repository directory.
func (cmd SqlCmd) RequiresRepo() bool {
	return false
}
|
|
|
|
// Exec executes the command. It parses and validates arguments, builds the
// SQL session/context over one or many databases, then dispatches to one of
// four modes: single/batch query (-q), saved query execution (-x), listing
// saved queries (--list-saved), or interactive/piped input. Returns a process
// exit code.
func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEnv) int {
	ap := cmd.createArgParser()
	help, usage := cli.HelpAndUsagePrinters(cli.GetCommandDocumentation(commandStr, sqlDocs, ap))

	apr := cli.ParseArgs(ap, args, help)
	err := validateSqlArgs(apr)
	if err != nil {
		return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
	}

	args = apr.Args()

	var verr errhand.VerboseError
	// Result output defaults to tabular unless -r/--result-format overrides it.
	format := formatTabular
	if formatSr, ok := apr.GetValue(formatFlag); ok {
		format, verr = getFormat(formatSr)
		if verr != nil {
			return HandleVErrAndExitCode(errhand.VerboseErrorFromError(verr), usage)
		}
	}

	dsess := dsqle.DefaultDoltSession()

	// Without --multi-db-dir, the current directory must be a valid dolt repo
	// and becomes the single database; with it, each subdirectory of the given
	// directory is loaded as an independent database.
	var mrEnv env.MultiRepoEnv
	var initialRoots map[string]*doltdb.RootValue
	if multiDir, ok := apr.GetValue(multiDBDirFlag); !ok {
		if !cli.CheckEnvIsValid(dEnv) {
			return 2
		}

		mrEnv = env.DoltEnvAsMultiEnv(dEnv)
		initialRoots, err = mrEnv.GetWorkingRoots(ctx)

		if err != nil {
			return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
		}

		// Session identity comes from the local dolt config (empty if unset).
		dsess.Username = *dEnv.Config.GetStringOrDefault(env.UserNameKey, "")
		dsess.Email = *dEnv.Config.GetStringOrDefault(env.UserEmailKey, "")
	} else {
		mrEnv, err = env.LoadMultiEnvFromDir(ctx, env.GetCurrentUserHomeDir, dEnv.FS, multiDir, cmd.VersionStr)

		if err != nil {
			return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
		}

		initialRoots, err = mrEnv.GetWorkingRoots(ctx)

		if err != nil {
			return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
		}
	}

	sqlCtx := sql.NewContext(ctx,
		sql.WithSession(dsess),
		sql.WithIndexRegistry(sql.NewIndexRegistry()),
		sql.WithViewRegistry(sql.NewViewRegistry()))
	sqlCtx.Set(sqlCtx, sql.AutoCommitSessionVar, sql.Boolean, true)

	// Work on a copy of the initial roots so changed roots can be detected
	// (and written back to the working set) after queries run.
	roots := make(map[string]*doltdb.RootValue)

	var name string
	var root *doltdb.RootValue
	for name, root = range initialRoots {
		roots[name] = root
	}

	// With exactly one database, select it automatically; `name` still holds
	// the single key from the loop above.
	var currentDB string
	if len(initialRoots) == 1 {
		sqlCtx.SetCurrentDatabase(name)
		currentDB = name
	}

	// NOTE(review): err is always nil here — every error path above already
	// returned — so this check is dead code, and the type assertion would
	// panic if err were ever a non-VerboseError. Candidate for removal.
	if err != nil {
		return HandleVErrAndExitCode(err.(errhand.VerboseError), usage)
	}

	if query, queryOK := apr.GetValue(queryFlag); queryOK {
		batchMode := apr.Contains(batchFlag)

		if batchMode {
			// -b: treat the -q string as a semicolon-separated batch script.
			batchInput := strings.NewReader(query)
			roots, verr = execBatch(sqlCtx, mrEnv, roots, batchInput, format)
		} else {
			roots, verr = execQuery(sqlCtx, mrEnv, roots, query, format)

			if verr != nil {
				return HandleVErrAndExitCode(verr, usage)
			}

			// Optionally save the query to the query catalog (-s, with -m).
			saveName := apr.GetValueOrDefault(saveFlag, "")

			if saveName != "" {
				saveMessage := apr.GetValueOrDefault(messageFlag, "")
				roots[currentDB], verr = saveQuery(ctx, roots[currentDB], dEnv, query, saveName, saveMessage)
			}
		}
	} else if savedQueryName, exOk := apr.GetValue(executeFlag); exOk {
		// -x: look up a previously saved query by name and run it.
		sq, err := dsqle.RetrieveFromQueryCatalog(ctx, roots[currentDB], savedQueryName)

		if err != nil {
			return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
		}

		cli.PrintErrf("Executing saved query '%s':\n%s\n", savedQueryName, sq.Query)
		roots, verr = execQuery(sqlCtx, mrEnv, roots, sq.Query, format)
	} else if apr.Contains(listSavedFlag) {
		hasQC, err := roots[currentDB].HasTable(ctx, doltdb.DoltQueryCatalogTableName)

		if err != nil {
			verr := errhand.BuildDError("error: Failed to read from repository.").AddCause(err).Build()
			return HandleVErrAndExitCode(verr, usage)
		}

		// No query catalog table means there are no saved queries to list.
		if !hasQC {
			return 0
		}

		query := "SELECT * FROM " + doltdb.DoltQueryCatalogTableName
		_, verr = execQuery(sqlCtx, mrEnv, roots, query, format)
	} else {
		// Run in either batch mode for piped input, or shell mode for interactive
		runInBatchMode := true
		fi, err := os.Stdin.Stat()

		if err != nil {
			// On Windows a failed Stat on STDIN is tolerated and defaults to
			// batch mode; on other platforms it is treated as a bug.
			if !osutil.IsWindows {
				return HandleVErrAndExitCode(errhand.BuildDError("Couldn't stat STDIN. This is a bug.").Build(), usage)
			}
		} else {
			// A character device means an interactive terminal; anything else
			// (pipe, redirect) runs in batch mode.
			runInBatchMode = fi.Mode()&os.ModeCharDevice == 0
		}

		if runInBatchMode {
			roots, verr = execBatch(sqlCtx, mrEnv, roots, os.Stdin, format)
		} else {
			roots, verr = execShell(sqlCtx, mrEnv, roots, format)
		}
	}

	if verr != nil {
		return HandleVErrAndExitCode(verr, usage)
	}

	// If the SQL session wrote a new root value, update the working set with it
	for name, origRoot := range initialRoots {
		root := roots[name]
		if origRoot != root {
			currEnv := mrEnv[name]
			verr = UpdateWorkingWithVErr(currEnv, root)
		}
	}

	return HandleVErrAndExitCode(verr, usage)
}
|
|
|
|
func execShell(sqlCtx *sql.Context, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, format resultFormat) (map[string]*doltdb.RootValue, errhand.VerboseError) {
|
|
dbs := CollectDBs(mrEnv, newDatabase)
|
|
se, err := newSqlEngine(sqlCtx, mrEnv, roots, format, dbs...)
|
|
if err != nil {
|
|
return nil, errhand.VerboseErrorFromError(err)
|
|
}
|
|
|
|
err = runShell(sqlCtx, se, mrEnv)
|
|
if err != nil {
|
|
return nil, errhand.BuildDError("unable to start shell").AddCause(err).Build()
|
|
}
|
|
|
|
newRoots, err := se.getRoots(sqlCtx)
|
|
if err != nil {
|
|
return nil, errhand.BuildDError("failed to get roots").AddCause(err).Build()
|
|
}
|
|
|
|
return newRoots, nil
|
|
}
|
|
|
|
func execBatch(sqlCtx *sql.Context, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, batchInput io.Reader, format resultFormat) (map[string]*doltdb.RootValue, errhand.VerboseError) {
|
|
dbs := CollectDBs(mrEnv, newBatchedDatabase)
|
|
se, err := newSqlEngine(sqlCtx, mrEnv, roots, format, dbs...)
|
|
if err != nil {
|
|
return nil, errhand.VerboseErrorFromError(err)
|
|
}
|
|
|
|
err = runBatchMode(sqlCtx, se, batchInput)
|
|
if err != nil {
|
|
return nil, errhand.BuildDError("Error processing batch").Build()
|
|
}
|
|
|
|
newRoots, err := se.getRoots(sqlCtx)
|
|
if err != nil {
|
|
return nil, errhand.BuildDError("failed to get roots").AddCause(err).Build()
|
|
}
|
|
|
|
return newRoots, nil
|
|
}
|
|
|
|
// createDBFunc constructs a Database for the named dolt environment; used to
// select between the standard and batched database constructors below.
type createDBFunc func(name string, dEnv *env.DoltEnv) dsqle.Database
|
|
|
|
// newDatabase creates a standard (unbatched) SQL database backed by the given
// dolt environment.
func newDatabase(name string, dEnv *env.DoltEnv) dsqle.Database {
	return dsqle.NewDatabase(name, dEnv.DoltDB, dEnv.RepoState, dEnv.RepoStateWriter())
}
|
|
|
|
// newBatchedDatabase creates a batched SQL database backed by the given dolt
// environment; used for batch-mode execution (see execBatch).
func newBatchedDatabase(name string, dEnv *env.DoltEnv) dsqle.Database {
	return dsqle.NewBatchedDatabase(name, dEnv.DoltDB, dEnv.RepoState, dEnv.RepoStateWriter())
}
|
|
|
|
// execQuery runs a single SQL query against the given roots, pretty-printing
// any result rows in the requested format, and returns the (possibly updated)
// roots.
func execQuery(sqlCtx *sql.Context, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, query string, format resultFormat) (map[string]*doltdb.RootValue, errhand.VerboseError) {
	dbs := CollectDBs(mrEnv, newDatabase)
	se, err := newSqlEngine(sqlCtx, mrEnv, roots, format, dbs...)
	if err != nil {
		return nil, errhand.VerboseErrorFromError(err)
	}

	sqlSch, rowIter, err := processQuery(sqlCtx, query, se)
	if err != nil {
		verr := formatQueryError("", err)
		return nil, verr
	}

	// A nil row iterator means the statement produced no result set to print.
	if rowIter != nil {
		defer rowIter.Close()
		err = se.prettyPrintResults(sqlCtx, sqlSch, rowIter)
		if err != nil {
			return nil, errhand.VerboseErrorFromError(err)
		}
	}

	newRoots, err := se.getRoots(sqlCtx)
	if err != nil {
		return nil, errhand.BuildDError("failed to get roots").AddCause(err).Build()
	}

	return newRoots, nil
}
|
|
|
|
// CollectDBs takes a MultiRepoEnv and creates Database objects from each environment and returns a slice of these
|
|
// objects.
|
|
func CollectDBs(mrEnv env.MultiRepoEnv, createDB createDBFunc) []dsqle.Database {
|
|
dbs := make([]dsqle.Database, 0, len(mrEnv))
|
|
_ = mrEnv.Iter(func(name string, dEnv *env.DoltEnv) (stop bool, err error) {
|
|
db := createDB(name, dEnv)
|
|
dbs = append(dbs, db)
|
|
return false, nil
|
|
})
|
|
|
|
return dbs
|
|
}
|
|
|
|
// formatQueryError converts err into a VerboseError, prefixed with message if
// non-empty. SQL syntax errors get special treatment: the line containing the
// error is isolated (and truncated if very long), and a '^' marker line is
// added pointing at the error position.
func formatQueryError(message string, err error) errhand.VerboseError {
	const (
		// maxStatementLen is the longest line printed before truncation.
		maxStatementLen = 128
		// maxPosWhenTruncated caps how far into a truncated line the error
		// marker may sit, keeping the error position visible.
		maxPosWhenTruncated = 64
	)

	if se, ok := vterrors.AsSyntaxError(err); ok {
		verrBuilder := errhand.BuildDError("Error parsing SQL")
		verrBuilder.AddDetails(se.Message)

		statement := se.Statement
		position := se.Position

		// Walk forward through newlines until the line containing the error
		// is isolated; `position` becomes an offset within that line, and
		// `prevLines` accumulates the lines before it for context.
		prevLines := ""
		for {
			idxNewline := strings.IndexRune(statement, '\n')

			if idxNewline == -1 {
				break
			} else if idxNewline < position {
				position -= idxNewline + 1
				prevLines += statement[:idxNewline+1]
				statement = statement[idxNewline+1:]
			} else {
				statement = statement[:idxNewline]
				break
			}
		}

		// Truncate overly long lines, shifting the window so the error
		// position stays within the printed portion.
		if len(statement) > maxStatementLen {
			if position > maxPosWhenTruncated {
				statement = statement[position-maxPosWhenTruncated:]
				position = maxPosWhenTruncated
			}

			if len(statement) > maxStatementLen {
				statement = statement[:maxStatementLen]
			}
		}

		verrBuilder.AddDetails(prevLines + statement)

		// Build the "    ^" marker line under the error position.
		marker := make([]rune, position+1)
		for i := 0; i < position; i++ {
			marker[i] = ' '
		}

		marker[position] = '^'
		verrBuilder.AddDetails(string(marker))

		return verrBuilder.Build()
	} else {
		if len(message) > 0 {
			err = fmt.Errorf("%s: %s", message, err.Error())
		}
		return errhand.VerboseErrorFromError(err)
	}
}
|
|
|
|
func getFormat(format string) (resultFormat, errhand.VerboseError) {
|
|
switch strings.ToLower(format) {
|
|
case "tabular":
|
|
return formatTabular, nil
|
|
case "csv":
|
|
return formatCsv, nil
|
|
case "json":
|
|
return formatJson, nil
|
|
default:
|
|
return formatTabular, errhand.BuildDError("Invalid argument for --result-format. Valid values are tabular, csv, json").Build()
|
|
}
|
|
}
|
|
|
|
func validateSqlArgs(apr *argparser.ArgParseResults) error {
|
|
_, query := apr.GetValue(queryFlag)
|
|
_, save := apr.GetValue(saveFlag)
|
|
_, msg := apr.GetValue(messageFlag)
|
|
_, batch := apr.GetValue(batchFlag)
|
|
_, list := apr.GetValue(listSavedFlag)
|
|
_, execute := apr.GetValue(executeFlag)
|
|
_, multiDB := apr.GetValue(multiDBDirFlag)
|
|
|
|
if len(apr.Args()) > 0 && !query {
|
|
return errhand.BuildDError("Invalid Argument: use --query or -q to pass inline SQL queries").Build()
|
|
}
|
|
|
|
if execute {
|
|
if list {
|
|
return errhand.BuildDError("Invalid Argument: --execute|-x is not compatible with --list-saved").Build()
|
|
} else if query {
|
|
return errhand.BuildDError("Invalid Argument: --execute|-x is not compatible with --query|-q").Build()
|
|
} else if msg {
|
|
return errhand.BuildDError("Invalid Argument: --execute|-x is not compatible with --message|-m").Build()
|
|
} else if save {
|
|
return errhand.BuildDError("Invalid Argument: --execute|-x is not compatible with --save|-s").Build()
|
|
} else if multiDB {
|
|
return errhand.BuildDError("Invalid Argument: --execute|-x is not compatible with --multi-db-dir").Build()
|
|
}
|
|
}
|
|
|
|
if list {
|
|
if execute {
|
|
return errhand.BuildDError("Invalid Argument: --list-saved is not compatible with --executed|x").Build()
|
|
} else if query {
|
|
return errhand.BuildDError("Invalid Argument: --list-saved is not compatible with --query|-q").Build()
|
|
} else if msg {
|
|
return errhand.BuildDError("Invalid Argument: --list-saved is not compatible with --message|-m").Build()
|
|
} else if save {
|
|
return errhand.BuildDError("Invalid Argument: --list-saved is not compatible with --save|-s").Build()
|
|
} else if multiDB {
|
|
return errhand.BuildDError("Invalid Argument: --execute|-x is not compatible with --multi-db-dir").Build()
|
|
}
|
|
}
|
|
|
|
if save && multiDB {
|
|
return errhand.BuildDError("Invalid Argument: --multi-db-dir queries cannot be saved").Build()
|
|
}
|
|
|
|
if batch {
|
|
if !query {
|
|
return errhand.BuildDError("Invalid Argument: --batch|-b must be used with --query|-q").Build()
|
|
}
|
|
if save || msg {
|
|
return errhand.BuildDError("Invalid Argument: --batch|-b is not compatible with --save|-s or --message|-m").Build()
|
|
}
|
|
}
|
|
|
|
if query {
|
|
if !save && msg {
|
|
return errhand.BuildDError("Invalid Argument: --message|-m is only used with --query|-q and --save|-s").Build()
|
|
}
|
|
} else {
|
|
if save {
|
|
return errhand.BuildDError("Invalid Argument: --save|-s is only used with --query|-q").Build()
|
|
}
|
|
if msg {
|
|
return errhand.BuildDError("Invalid Argument: --message|-m is only used with --query|-q and --save|-s").Build()
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Saves the query given to the catalog with the name and message given.
|
|
func saveQuery(ctx context.Context, root *doltdb.RootValue, dEnv *env.DoltEnv, query string, name string, message string) (*doltdb.RootValue, errhand.VerboseError) {
|
|
_, newRoot, err := dsqle.NewQueryCatalogEntryWithNameAsID(ctx, root, name, query, message)
|
|
if err != nil {
|
|
return nil, errhand.BuildDError("Couldn't save query").AddCause(err).Build()
|
|
}
|
|
|
|
return newRoot, nil
|
|
}
|
|
|
|
// runBatchMode processes queries until EOF. The Root of the sqlEngine may be updated.
func runBatchMode(ctx *sql.Context, se *sqlEngine, input io.Reader) error {
	scanner := NewSqlStatementScanner(input)

	var query string
	for scanner.Scan() {
		query += scanner.Text()
		// Skip statements that are empty or a lone newline.
		if len(query) == 0 || query == "\n" {
			continue
		}
		if err := processBatchQuery(ctx, query, se); err != nil {
			// TODO: this line number will not be accurate for errors that occur when flushing a batch of inserts (as opposed
			// to processing the query)
			verr := formatQueryError(fmt.Sprintf("error on line %d for query %s", scanner.statementStartLine, query), err)
			cli.PrintErrln(verr.Verbose())
			return err
		}
		query = ""
	}

	// Emit a final progress line for any edits not yet reported.
	updateBatchInsertOutput()

	// NOTE(review): scanner errors are printed but do not fail the batch, and
	// accumulated edits are still flushed below — confirm this is intentional.
	if err := scanner.Err(); err != nil {
		cli.Println(err.Error())
	}

	return flushBatchedEdits(ctx, se)
}
|
|
|
|
// runShell starts a SQL shell. Returns when the user exits the shell. The Root of the sqlEngine may
// be updated by any queries which were processed.
func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error {
	_ = iohelp.WriteLine(cli.CliOut, welcomeMsg)
	currentDB := ctx.Session.GetCurrentDatabase()
	currEnv := mrEnv[currentDB]

	// start the doltsql shell
	historyFile := filepath.Join(".sqlhistory") // history file written to working dir
	initialPrompt := fmt.Sprintf("%s> ", ctx.GetCurrentDatabase())
	// The multiline continuation prompt ("-> ") is right-aligned to the width
	// of the primary prompt so statement lines line up.
	initialMultilinePrompt := fmt.Sprintf(fmt.Sprintf("%%%ds", len(initialPrompt)), "-> ")

	rlConf := readline.Config{
		Prompt:                 initialPrompt,
		Stdout:                 cli.CliOut,
		Stderr:                 cli.CliOut,
		HistoryFile:            historyFile,
		HistoryLimit:           500,
		HistorySearchFold:      true,
		DisableAutoSaveHistory: true,
	}
	shellConf := ishell.UninterpretedConfig{
		ReadlineConfig: &rlConf,
		QuitKeywords: []string{
			"quit", "exit", "quit()", "exit()",
		},
		LineTerminator: ";",
	}

	shell := ishell.NewUninterpreted(&shellConf)
	shell.SetMultiPrompt(initialMultilinePrompt)
	// TODO: update completer on create / drop / alter statements
	completer, err := newCompleter(ctx, currEnv)
	if err != nil {
		return err
	}

	shell.CustomCompleter(completer)

	// Ctrl-D exits immediately.
	shell.EOF(func(c *ishell.Context) {
		c.Stop()
	})

	// First Ctrl-C warns; a second interrupt exits.
	shell.Interrupt(func(c *ishell.Context, count int, input string) {
		if count > 1 {
			c.Stop()
		} else {
			c.Println("Received SIGINT. Interrupt again to exit, or use ^D, quit, or exit")
		}
	})

	// Every ';'-terminated input line is run as a SQL statement.
	shell.Uninterpreted(func(c *ishell.Context) {
		query := c.Args[0]
		if len(strings.TrimSpace(query)) == 0 {
			return
		}

		if sqlSch, rowIter, err := processQuery(ctx, query, se); err != nil {
			verr := formatQueryError("", err)
			shell.Println(verr.Verbose())
		} else if rowIter != nil {
			defer rowIter.Close()
			err = se.prettyPrintResults(ctx, sqlSch, rowIter)
			if err != nil {
				shell.Println(color.RedString(err.Error()))
			}
		}

		// TODO: there's a bug in the readline library when editing multi-line history entries.
		// Longer term we need to switch to a new readline library, like in this bug:
		// https://github.com/cockroachdb/cockroach/issues/15460
		// For now, we store all history entries as single-line strings to avoid the issue.
		singleLine := strings.ReplaceAll(query, "\n", " ")
		if err := shell.AddHistory(singleLine); err != nil {
			// TODO: handle better, like by turning off history writing for the rest of the session
			shell.Println(color.RedString(err.Error()))
		}

		// Refresh prompts in case the current database changed (e.g. via USE).
		currPrompt := fmt.Sprintf("%s> ", ctx.GetCurrentDatabase())
		shell.SetPrompt(currPrompt)
		shell.SetMultiPrompt(fmt.Sprintf(fmt.Sprintf("%%%ds", len(currPrompt)), "-> "))
	})

	shell.Run()
	_ = iohelp.WriteLine(cli.CliOut, "Bye")

	return nil
}
|
|
|
|
// Returns a new auto completer with table names, column names, and SQL keywords.
func newCompleter(ctx context.Context, dEnv *env.DoltEnv) (*sqlCompleter, error) {
	var completionWords []string

	root, err := dEnv.WorkingRoot(ctx)
	if err != nil {
		// Completion is best-effort: with no readable working root, return an
		// empty completer rather than failing shell startup.
		return &sqlCompleter{}, nil
	}

	tableNames, err := root.GetTableNames(ctx)

	if err != nil {
		return nil, err
	}

	completionWords = append(completionWords, tableNames...)
	var columnNames []string
	for _, tableName := range tableNames {
		tbl, _, err := root.GetTable(ctx, tableName)

		if err != nil {
			return nil, err
		}

		sch, err := tbl.GetSchema(ctx)

		if err != nil {
			return nil, err
		}

		// Column names go into both the general word list and the
		// column-specific list used for "alias." completion (see getWords).
		err = sch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
			completionWords = append(completionWords, col.Name)
			columnNames = append(columnNames, col.Name)
			return false, nil
		})

		if err != nil {
			return nil, err
		}
	}

	completionWords = append(completionWords, dsqle.CommonKeywords...)

	return &sqlCompleter{
		allWords:    completionWords,
		columnNames: columnNames,
	}, nil
}
|
|
|
|
// sqlCompleter provides tab-completion candidates for the SQL shell.
type sqlCompleter struct {
	// allWords holds every completion candidate: table names, column names,
	// and common SQL keywords.
	allWords []string
	// columnNames holds column names only, used for "alias.column" completion.
	columnNames []string
}
|
|
|
|
// Do function for autocompletion, defined by the Readline library. Mostly stolen from ishell.
// Returns suggestion suffixes (the text remaining after the typed prefix) and
// the length of that prefix.
func (c *sqlCompleter) Do(line []rune, pos int) (newLine [][]rune, length int) {
	// Tokenize the input; shlex handles quoting, with whitespace splitting as
	// a fallback when shlex fails.
	var words []string
	if w, err := shlex.Split(string(line)); err == nil {
		words = w
	} else {
		// fall back
		words = strings.Fields(string(line))
	}

	var cWords []string
	prefix := ""
	lastWord := ""
	// If the cursor is mid-word, complete against that word's lowercased
	// prefix; if it follows a space, suggest everything relevant to the
	// previous word with no prefix filter.
	if len(words) > 0 && pos > 0 && line[pos-1] != ' ' {
		lastWord = words[len(words)-1]
		prefix = strings.ToLower(lastWord)
	} else if len(words) > 0 {
		lastWord = words[len(words)-1]
	}

	cWords = c.getWords(lastWord)

	var suggestions [][]rune
	for _, w := range cWords {
		lowered := strings.ToLower(w)
		if strings.HasPrefix(lowered, prefix) {
			// Suggestions carry only the text remaining after the prefix.
			suggestions = append(suggestions, []rune(strings.TrimPrefix(lowered, prefix)))
		}
	}
	// A single exact match completes to a trailing space.
	if len(suggestions) == 1 && prefix != "" && string(suggestions[0]) == "" {
		suggestions = [][]rune{[]rune(" ")}
	}

	return suggestions, len(prefix)
}
|
|
|
|
// Simple suggestion function. Returns column name suggestions if the last word in the input has exactly one '.' in it,
|
|
// otherwise returns all tables, columns, and reserved words.
|
|
func (c *sqlCompleter) getWords(lastWord string) (s []string) {
|
|
lastDot := strings.LastIndex(lastWord, ".")
|
|
if lastDot > 0 && strings.Count(lastWord, ".") == 1 {
|
|
alias := lastWord[:lastDot]
|
|
return prepend(alias+".", c.columnNames)
|
|
}
|
|
|
|
return c.allWords
|
|
}
|
|
|
|
// prepend returns a new slice in which every element of ss has been prefixed
// with s. The input slice is left unmodified.
func prepend(s string, ss []string) []string {
	out := make([]string, 0, len(ss))
	for _, elem := range ss {
		out = append(out, s+elem)
	}
	return out
}
|
|
|
|
// Processes a single query. The Root of the sqlEngine will be updated if necessary.
// Returns the schema and the row iterator for the results, which may be nil, and an error if one occurs.
func processQuery(ctx *sql.Context, query string, se *sqlEngine) (sql.Schema, sql.RowIter, error) {
	sqlStatement, err := sqlparser.Parse(query)
	if err == sqlparser.ErrEmpty {
		// silently skip empty statements
		return nil, nil, nil
	} else if err != nil {
		return nil, nil, err
	}

	// Dispatch by statement type; only a subset of SQL is supported.
	switch s := sqlStatement.(type) {
	case *sqlparser.Select, *sqlparser.Insert, *sqlparser.Update, *sqlparser.OtherRead, *sqlparser.Show, *sqlparser.Explain, *sqlparser.Union:
		return se.query(ctx, query)
	case *sqlparser.Use, *sqlparser.Set:
		sch, rowIter, err := se.query(ctx, query)

		// USE/SET produce no user-visible rows; close the iterator eagerly.
		if rowIter != nil {
			_ = rowIter.Close()
		}

		if err != nil {
			return nil, nil, err
		}

		switch sqlStatement.(type) {
		case *sqlparser.Use:
			cli.Println("Database changed")
		}

		return sch, nil, err
	case *sqlparser.Delete:
		// A whole-table delete may be satisfiable by a fast drop-all-rows
		// path; fall back to a normal query if that path doesn't apply.
		ok := se.checkThenDeleteAllRows(ctx, s)
		if ok {
			return nil, nil, err
		}
		return se.query(ctx, query)
	case *sqlparser.DDL:
		// Re-parse in strict mode so DDL syntax errors are surfaced with
		// position information rather than silently tolerated.
		_, err := sqlparser.ParseStrictDDL(query)
		if err != nil {
			// NOTE: this `se` shadows the sqlEngine parameter with the syntax error.
			if se, ok := vterrors.AsSyntaxError(err); ok {
				return nil, nil, vterrors.SyntaxError{Message: "While Parsing DDL: " + se.Message, Position: se.Position, Statement: se.Statement}
			} else {
				return nil, nil, fmt.Errorf("Error parsing DDL: %v.", err.Error())
			}
		}
		return se.ddl(ctx, s, query)
	default:
		return nil, nil, fmt.Errorf("Unsupported SQL statement: '%v'.", query)
	}
}
|
|
|
|
// stats tracks the cumulative effect of statements executed in batch mode.
type stats struct {
	rowsInserted int
	rowsUpdated  int
	rowsDeleted  int
	// unflushedEdits counts edits accumulated since the last flush to the
	// database (see shouldFlush / flushBatchedEdits).
	unflushedEdits int
	// unprintedEdits counts edits accumulated since the last progress line
	// was printed (see shouldUpdateBatchModeOutput).
	unprintedEdits int
}
|
|
|
|
// batchEditStats accumulates edit counts across the whole batch session.
var batchEditStats = &stats{}

// displayStrLen is the length of the last batch-progress line printed; a
// non-zero value means a newline is needed before regular output (presumably
// maintained by updateBatchInsertOutput, defined elsewhere — confirm).
var displayStrLen int

// maxBatchSize is the number of unflushed edits that triggers a flush.
const maxBatchSize = 200000

// updateInterval is the number of edits between batch progress updates.
const updateInterval = 1000
|
|
|
|
// numUpdates returns the total number of row edits (inserts, updates, and
// deletes) recorded so far.
func (s *stats) numUpdates() int {
	return s.rowsUpdated + s.rowsDeleted + s.rowsInserted
}
|
|
|
|
// shouldUpdateBatchModeOutput reports whether enough edits have accumulated
// since the last progress line to warrant printing a new one.
func (s *stats) shouldUpdateBatchModeOutput() bool {
	return s.unprintedEdits >= updateInterval
}
|
|
|
|
// shouldFlush reports whether the unflushed edit count has reached the batch
// size limit and the accumulated edits should be written out.
func (s *stats) shouldFlush() bool {
	return s.unflushedEdits >= maxBatchSize
}
|
|
|
|
func flushBatchedEdits(ctx *sql.Context, se *sqlEngine) error {
|
|
err := se.iterDBs(func(_ string, db dsqle.Database) (bool, error) {
|
|
err := db.Flush(ctx)
|
|
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return false, nil
|
|
})
|
|
|
|
batchEditStats.unflushedEdits = 0
|
|
|
|
return err
|
|
}
|
|
|
|
// Processes a single query in batch mode. The Root of the sqlEngine may or may not be changed.
func processBatchQuery(ctx *sql.Context, query string, se *sqlEngine) error {
	sqlStatement, err := sqlparser.Parse(query)
	if err == sqlparser.ErrEmpty {
		// silently skip empty statements
		return nil
	} else if err != nil {
		return fmt.Errorf("Error parsing SQL: %v.", err.Error())
	}

	// Simple value inserts can be accumulated and flushed lazily; any other
	// statement forces a flush before and after it runs.
	if canProcessAsBatchInsert(sqlStatement) {
		err = processBatchInsert(ctx, se, query, sqlStatement)
		if err != nil {
			return err
		}
	} else {
		err := processNonInsertBatchQuery(ctx, se, query, sqlStatement)
		if err != nil {
			return err
		}
	}

	// Periodically report progress so long-running batches aren't silent.
	if batchEditStats.shouldUpdateBatchModeOutput() {
		updateBatchInsertOutput()
	}

	return nil
}
|
|
|
|
// processNonInsertBatchQuery executes a batch statement that can't be handled
// as a simple insert. Pending batched edits are flushed both before and after
// the statement so it always sees, and leaves behind, up-to-date data.
func processNonInsertBatchQuery(ctx *sql.Context, se *sqlEngine, query string, sqlStatement sqlparser.Statement) error {
	// We need to commit whatever batch edits we've accumulated so far before executing the query
	err := flushBatchedEdits(ctx, se)
	if err != nil {
		return err
	}

	sqlSch, rowIter, err := processQuery(ctx, query, se)
	if err != nil {
		return err
	}

	if rowIter != nil {
		defer rowIter.Close()
		// Fold the statement's row counts into the shared batch stats.
		err = mergeResultIntoStats(sqlStatement, rowIter, batchEditStats)
		if err != nil {
			return fmt.Errorf("error executing statement: %v", err.Error())
		}

		// Some statement types should print results, even in batch mode.
		switch sqlStatement.(type) {
		case *sqlparser.Select, *sqlparser.OtherRead, *sqlparser.Show, *sqlparser.Explain, *sqlparser.Union:
			if displayStrLen > 0 {
				// If we've been printing in batch mode, print a newline to put the regular output on its own line
				cli.Print("\n")
				displayStrLen = 0
			}
			err = se.prettyPrintResults(ctx, sqlSch, rowIter)
			if err != nil {
				return err
			}
		}
	}

	// And flush again afterwards, to make sure any following insert statements have the latest data
	return flushBatchedEdits(ctx, se)
}
|
|
|
|
func processBatchInsert(ctx *sql.Context, se *sqlEngine, query string, sqlStatement sqlparser.Statement) error {
|
|
_, rowIter, err := se.query(ctx, query)
|
|
if err != nil {
|
|
return fmt.Errorf("Error inserting rows: %v", err.Error())
|
|
}
|
|
|
|
if rowIter != nil {
|
|
defer rowIter.Close()
|
|
err = mergeResultIntoStats(sqlStatement, rowIter, batchEditStats)
|
|
if err != nil {
|
|
return fmt.Errorf("Error inserting rows: %v", err.Error())
|
|
}
|
|
}
|
|
|
|
if batchEditStats.shouldFlush() {
|
|
return flushBatchedEdits(ctx, se)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// canProcessBatchInsert returns whether the given statement can be processed as a batch insert. Only simple inserts
|
|
// (inserting a list of values) can be processed in this way. Other kinds of insert (notably INSERT INTO SELECT AS) need
|
|
// a flushed root and can't benefit from batch optimizations.
|
|
func canProcessAsBatchInsert(sqlStatement sqlparser.Statement) bool {
|
|
switch s := sqlStatement.(type) {
|
|
case *sqlparser.Insert:
|
|
if _, ok := s.Rows.(sqlparser.Values); ok {
|
|
return true
|
|
}
|
|
return false
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func updateBatchInsertOutput() {
|
|
displayStr := fmt.Sprintf("Rows inserted: %d Rows updated: %d Rows deleted: %d\n",
|
|
batchEditStats.rowsInserted, batchEditStats.rowsUpdated, batchEditStats.rowsDeleted)
|
|
displayStrLen = cli.DeleteAndPrint(displayStrLen, displayStr)
|
|
batchEditStats.unprintedEdits = 0
|
|
}
|
|
|
|
// Updates the batch insert stats with the results of an INSERT, UPDATE, or DELETE statement.
|
|
func mergeResultIntoStats(statement sqlparser.Statement, rowIter sql.RowIter, s *stats) error {
|
|
switch statement.(type) {
|
|
case *sqlparser.Insert, *sqlparser.Delete, *sqlparser.Update:
|
|
break
|
|
default:
|
|
return nil
|
|
}
|
|
|
|
for {
|
|
row, err := rowIter.Next()
|
|
if err == io.EOF {
|
|
return nil
|
|
} else if err != nil {
|
|
return err
|
|
} else {
|
|
okResult := row[0].(sql.OkResult)
|
|
s.unflushedEdits += int(okResult.RowsAffected)
|
|
s.unprintedEdits += int(okResult.RowsAffected)
|
|
switch statement.(type) {
|
|
case *sqlparser.Insert:
|
|
s.rowsInserted += int(okResult.RowsAffected)
|
|
case *sqlparser.Delete:
|
|
s.rowsDeleted += int(okResult.RowsAffected)
|
|
case *sqlparser.Update:
|
|
s.rowsUpdated += int(okResult.RowsAffected)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// resultFormat selects how query results are rendered when printed.
type resultFormat byte

const (
	// formatTabular renders results as a fixed-width text table.
	formatTabular resultFormat = iota
	// formatCsv renders results as comma-separated values.
	formatCsv
	// formatJson renders results as JSON.
	formatJson
)
|
|
|
|
// sqlEngine packages up the context necessary to run sql queries against sqle.
type sqlEngine struct {
	// dbs maps database name to its dolt-backed database implementation.
	dbs map[string]dsqle.Database
	// mrEnv is the multi-repo environment the databases belong to.
	mrEnv env.MultiRepoEnv
	// engine is the underlying go-mysql-server query engine.
	engine *sqle.Engine
	// resultFormat controls how query results are printed.
	resultFormat resultFormat
}
|
|
|
|
// ErrDBNotFoundKind is returned when a query references a database name that
// is not registered with this engine.
var ErrDBNotFoundKind = errors.NewKind("database '%s' not found")
|
|
|
|
// newSqlEngine constructs a sqlEngine over the given databases. Each database
// is added to a new default sqle engine and to the session, has its root
// value set from the roots map, and has its schema fragments registered. A
// Dolt index driver covering all the databases is registered on the context,
// and indexes are loaded before the engine is returned.
func newSqlEngine(sqlCtx *sql.Context, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, format resultFormat, dbs ...dsqle.Database) (*sqlEngine, error) {
	engine := sqle.NewDefault()
	engine.AddDatabase(sql.NewInformationSchemaDatabase(engine.Catalog))

	dsess := dsqle.DSessFromSess(sqlCtx.Session)

	nameToDB := make(map[string]dsqle.Database)
	for _, db := range dbs {
		nameToDB[db.Name()] = db
		root := roots[db.Name()]
		engine.AddDatabase(db)
		err := dsess.AddDB(sqlCtx, db)

		if err != nil {
			return nil, err
		}

		err = db.SetRoot(sqlCtx, root)
		if err != nil {
			return nil, err
		}

		err = dsqle.RegisterSchemaFragments(sqlCtx, db, root)

		if err != nil {
			return nil, err
		}
	}

	sqlCtx.RegisterIndexDriver(dsqle.NewDoltIndexDriver(dbs...))
	err := sqlCtx.LoadIndexes(sqlCtx, engine.Catalog.AllDatabases())
	if err != nil {
		return nil, err
	}

	return &sqlEngine{nameToDB, mrEnv, engine, format}, nil
}
|
|
|
|
func (se *sqlEngine) getDB(name string) (dsqle.Database, error) {
|
|
db, ok := se.dbs[name]
|
|
|
|
if !ok {
|
|
return dsqle.Database{}, ErrDBNotFoundKind.New(name)
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
func (se *sqlEngine) iterDBs(cb func(name string, db dsqle.Database) (stop bool, err error)) error {
|
|
for name, db := range se.dbs {
|
|
stop, err := cb(name, db)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if stop {
|
|
break
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (se *sqlEngine) getRoots(sqlCtx *sql.Context) (map[string]*doltdb.RootValue, error) {
|
|
newRoots := make(map[string]*doltdb.RootValue)
|
|
for name := range se.mrEnv {
|
|
db, err := se.getDB(name)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
newRoots[name], err = db.GetRoot(sqlCtx)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return newRoots, nil
|
|
}
|
|
|
|
// query executes the given SQL statement against the underlying engine and
// returns the result schema and row iterator for printing.
func (se *sqlEngine) query(ctx *sql.Context, query string) (sql.Schema, sql.RowIter, error) {
	return se.engine.Query(ctx, query)
}
|
|
|
|
// prettyPrintResults renders the rows from rowIter to the CLI in the engine's
// configured result format (tabular, CSV, or JSON). OkResult status results
// are printed as a summary line instead. Rows are streamed from the iterator
// through a transformation pipeline to a format-specific table writer; the
// first error from either the pipeline or the iterator is returned.
func (se *sqlEngine) prettyPrintResults(ctx context.Context, sqlSch sql.Schema, rowIter sql.RowIter) error {
	if isOkResult(sqlSch) {
		return printOKResult(ctx, rowIter)
	}

	nbf := types.Format_Default

	// Convert the SQL result schema into a dolt schema, then into an untyped
	// (string-typed) schema that the text writers consume.
	doltSch, err := dsqle.SqlSchemaToDoltResultSchema(sqlSch)
	if err != nil {
		return err
	}

	untypedSch, err := untyped.UntypeUnkeySchema(doltSch)
	if err != nil {
		return err
	}

	rowChannel := make(chan row.Row)
	p := pipeline.NewPartialPipeline(pipeline.InFuncForChannel(rowChannel))

	// Parts of the pipeline depend on the output format, such as how we print null values and whether we pad strings.
	switch se.resultFormat {
	case formatCsv:
		nullPrinter := nullprinter.NewNullPrinterWithNullString(untypedSch, "")
		p.AddStage(pipeline.NewNamedTransform(nullprinter.NullPrintingStage, nullPrinter.ProcessRow))

	case formatTabular:
		nullPrinter := nullprinter.NewNullPrinter(untypedSch)
		p.AddStage(pipeline.NewNamedTransform(nullprinter.NullPrintingStage, nullPrinter.ProcessRow))
		autoSizeTransform := fwt.NewAutoSizingFWTTransformer(untypedSch, fwt.PrintAllWhenTooLong, 10000)
		p.AddStage(pipeline.NamedTransform{Name: fwtStageName, Func: autoSizeTransform.TransformToFWT})
	}

	// Redirect output to the CLI
	cliWr := iohelp.NopWrCloser(cli.CliOut)

	var wr table.TableWriteCloser

	// Choose the sink writer for the configured format.
	switch se.resultFormat {
	case formatTabular:
		wr, err = tabular.NewTextTableWriter(cliWr, untypedSch)
	case formatCsv:
		wr, err = csv.NewCSVWriter(cliWr, untypedSch, csv.NewCSVInfo())
	case formatJson:
		wr, err = json.NewJSONWriter(cliWr, untypedSch)
	default:
		panic("unimplemented output format type")
	}

	if err != nil {
		return err
	}

	// Close the writer when the pipeline finishes so buffered output is emitted.
	p.RunAfter(func() { wr.Close(ctx) })

	cliSink := pipeline.ProcFuncForWriter(ctx, wr)
	p.SetOutput(cliSink)

	// A row that fails transformation aborts the pipeline (quit == true).
	p.SetBadRowCallback(func(tff *pipeline.TransformRowFailure) (quit bool) {
		cli.PrintErrln(color.RedString("error: failed to transform row %s.", row.Fmt(ctx, tff.Row, untypedSch)))
		return true
	})

	colNames, err := schema.ExtractAllColNames(untypedSch)

	if err != nil {
		return err
	}

	// Build a header row whose values are the column names themselves.
	r, err := untyped.NewRowFromTaggedStrings(nbf, untypedSch, colNames)

	if err != nil {
		return err
	}

	// Insert the table header row at the appropriate stage
	if se.resultFormat == formatTabular {
		p.InjectRow(fwtStageName, r)
	}

	// For some output formats, we want to convert everything to strings to be processed by the pipeline. For others,
	// we want to leave types alone and let the writer figure out how to format it for output.
	var rowFn func(r sql.Row) (row.Row, error)
	switch se.resultFormat {
	case formatJson:
		rowFn = func(r sql.Row) (r2 row.Row, err error) {
			return dsqle.SqlRowToDoltRow(nbf, r, doltSch)
		}
	default:
		rowFn = func(r sql.Row) (row.Row, error) {
			taggedVals := make(row.TaggedValues)
			for i, col := range r {
				// nil columns are simply omitted; the null printer stage fills them in.
				if col != nil {
					taggedVals[uint64(i)] = types.String(fmt.Sprintf("%v", col))
				}
			}
			return row.New(nbf, untypedSch, taggedVals)
		}
	}

	var iterErr error

	// Read rows off the row iter and pass them to the pipeline channel
	go func() {
		defer close(rowChannel)
		var sqlRow sql.Row
		for sqlRow, iterErr = rowIter.Next(); iterErr == nil; sqlRow, iterErr = rowIter.Next() {
			var r row.Row
			r, iterErr = rowFn(sqlRow)

			if iterErr == nil {
				rowChannel <- r
			}
		}
	}()

	p.Start()
	if err := p.Wait(); err != nil {
		return fmt.Errorf("error processing results: %v", err)
	}

	// iterErr is written by the reader goroutine above and read here after
	// p.Wait() returns. NOTE(review): this presumes the goroutine has finished
	// by the time Wait returns (the channel is closed when it exits) — confirm
	// the pipeline drains the channel fully before Wait unblocks.
	if iterErr != io.EOF {
		return fmt.Errorf("error processing results: %v", iterErr)
	}

	return nil
}
|
|
|
|
func printOKResult(ctx context.Context, iter sql.RowIter) error {
|
|
row, err := iter.Next()
|
|
defer iter.Close()
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if okResult, ok := row[0].(sql.OkResult); ok {
|
|
rowNoun := "row"
|
|
if okResult.RowsAffected != 1 {
|
|
rowNoun = "rows"
|
|
}
|
|
cli.Printf("Query OK, %d %s affected\n", okResult.RowsAffected, rowNoun)
|
|
|
|
if okResult.Info != nil {
|
|
cli.Printf("%s\n", okResult.Info)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// isOkResult returns whether the given schema is the sentinel OkResult
// schema, indicating the statement produced a status result rather than rows.
func isOkResult(sch sql.Schema) bool {
	return sch.Equals(sql.OkResultSchema)
}
|
|
|
|
// checkThenDeleteAllRows checks if the query is a naked delete (DELETE FROM a
// single table with no WHERE, LIMIT, or PARTITION clause) and then deletes all
// rows if so, by replacing the table's row data with an empty map instead of
// executing the statement through the SQL engine. Returns true if it did so,
// false otherwise. Any error along the way yields false, deferring to the SQL
// engine to execute the statement and surface the error normally.
func (se *sqlEngine) checkThenDeleteAllRows(ctx *sql.Context, s *sqlparser.Delete) bool {
	if s.Where == nil && s.Limit == nil && s.Partitions == nil && len(s.TableExprs) == 1 {
		if ate, ok := s.TableExprs[0].(*sqlparser.AliasedTableExpr); ok {
			if ste, ok := ate.Expr.(sqlparser.TableName); ok {
				// Use the session's current database unless the table name is qualified.
				dbName := ctx.Session.GetCurrentDatabase()
				if !ste.Qualifier.IsEmpty() {
					dbName = ste.Qualifier.String()
				}

				roots, err := se.getRoots(ctx)
				if err != nil {
					return false
				}

				root, ok := roots[dbName]

				if !ok {
					return false
				}

				tName := ste.Name.String()
				table, ok, err := root.GetTable(ctx, tName)
				if err == nil && ok {

					// Let the SQL engine handle system table deletes to avoid duplicating business logic here
					if doltdb.HasDoltPrefix(tName) {
						return false
					}

					rowData, err := table.GetRowData(ctx)
					if err != nil {
						return false
					}

					// Report the full current row count, mimicking the engine's OK result.
					result := sql.OkResult{RowsAffected: rowData.Len()}
					printRowIter := sql.RowsToRowIter(sql.NewRow(result))

					// Truncate by swapping in an empty row map rather than deleting row by row.
					emptyMap, err := types.NewMap(ctx, root.VRW())
					if err != nil {
						return false
					}

					newTable, err := table.UpdateRows(ctx, emptyMap)
					if err != nil {
						return false
					}

					newRoot, err := root.PutTable(ctx, tName, newTable)
					if err != nil {
						return false
					}

					// Best-effort print; the delete still proceeds if printing fails.
					_ = se.prettyPrintResults(ctx, sql.OkResultSchema, printRowIter)

					db, err := se.getDB(dbName)
					if err != nil {
						return false
					}

					err = db.SetRoot(ctx, newRoot)

					if err != nil {
						return false
					}

					return true
				}
			}
		}
	}

	return false
}
|
|
|
|
// Executes a SQL DDL statement (create, update, etc.). Updates the new root value in
|
|
// the sqlEngine if necessary.
|
|
func (se *sqlEngine) ddl(ctx *sql.Context, ddl *sqlparser.DDL, query string) (sql.Schema, sql.RowIter, error) {
|
|
switch ddl.Action {
|
|
case sqlparser.CreateStr, sqlparser.DropStr, sqlparser.AlterStr, sqlparser.RenameStr:
|
|
_, ri, err := se.query(ctx, query)
|
|
if err == nil {
|
|
ri.Close()
|
|
}
|
|
return nil, nil, err
|
|
default:
|
|
return nil, nil, fmt.Errorf("Unhandled DDL action %v in query %v", ddl.Action, query)
|
|
}
|
|
}
|