Merge pull request #9237 from dolthub/nathan/sql-shell-warnings

Sql shell warnings
This commit is contained in:
Nathan Gabrielson
2025-05-22 13:47:47 -07:00
committed by GitHub
13 changed files with 323 additions and 63 deletions

View File

@@ -90,6 +90,12 @@ type Queryist interface {
QueryWithBindings(ctx *sql.Context, query string, parsed sqlparser.Statement, bindings map[string]sqlparser.Expr, qFlags *sql.QueryFlags) (sql.Schema, sql.RowIter, *sql.QueryFlags, error)
}
// ShellServerQueryist is used to gather warnings in the sql-shell context when a server is running.
// We call an extra "show warnings" query, but want to avoid this in other cases, (i.e. dolt sql -q)
type ShellServerQueryist interface {
EnableGatherWarnings()
}
// This type is to store the content of a documented command, elsewhere we can transform this struct into
// other structs that are used to generate documentation at the command line and in markdown files.
type CommandDocumentationContent struct {

View File

@@ -130,7 +130,7 @@ func (cmd BlameCmd) Exec(ctx context.Context, commandStr string, args []string,
return 1
}
err = engine.PrettyPrintResults(sqlCtx, engine.FormatTabular, schema, ri, false)
err = engine.PrettyPrintResults(sqlCtx, engine.FormatTabular, schema, ri, false, false)
if err != nil {
iohelp.WriteLine(cli.CliOut, err.Error())
return 1

View File

@@ -181,7 +181,7 @@ func printViolationsForTable(ctx *sql.Context, dbName, tblName string, tbl *dolt
limitItr := &sqlLimitIter{itr: sqlItr, limit: 50}
err = engine.PrettyPrintResults(ctx, engine.FormatTabular, sqlSch, limitItr, false)
err = engine.PrettyPrintResults(ctx, engine.FormatTabular, sqlSch, limitItr, false, false)
if err != nil {
return errhand.BuildDError("Error outputting rows").AddCause(err).Build()
}

View File

@@ -23,6 +23,7 @@ import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
"github.com/fatih/color"
"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/libraries/doltcore/row"
@@ -55,16 +56,17 @@ const (
)
// PrettyPrintResults prints the result of a query in the format provided
func PrettyPrintResults(ctx *sql.Context, resultFormat PrintResultFormat, sqlSch sql.Schema, rowIter sql.RowIter, pageResults bool) (rerr error) {
return prettyPrintResultsWithSummary(ctx, resultFormat, sqlSch, rowIter, PrintNoSummary, pageResults)
func PrettyPrintResults(ctx *sql.Context, resultFormat PrintResultFormat, sqlSch sql.Schema, rowIter sql.RowIter, pageResults bool, showWarnings bool) (rerr error) {
return prettyPrintResultsWithSummary(ctx, resultFormat, sqlSch, rowIter, PrintNoSummary, pageResults, showWarnings)
}
// PrettyPrintResultsExtended prints the result of a query in the format provided, including row count and timing info
func PrettyPrintResultsExtended(ctx *sql.Context, resultFormat PrintResultFormat, sqlSch sql.Schema, rowIter sql.RowIter, pageResults bool) (rerr error) {
return prettyPrintResultsWithSummary(ctx, resultFormat, sqlSch, rowIter, PrintRowCountAndTiming, pageResults)
func PrettyPrintResultsExtended(ctx *sql.Context, resultFormat PrintResultFormat, sqlSch sql.Schema, rowIter sql.RowIter, pageResults bool, showWarnings bool) (rerr error) {
return prettyPrintResultsWithSummary(ctx, resultFormat, sqlSch, rowIter, PrintRowCountAndTiming, pageResults, showWarnings)
}
func prettyPrintResultsWithSummary(ctx *sql.Context, resultFormat PrintResultFormat, sqlSch sql.Schema, rowIter sql.RowIter, summary PrintSummaryBehavior, pageResults bool) (rerr error) {
func prettyPrintResultsWithSummary(ctx *sql.Context, resultFormat PrintResultFormat, sqlSch sql.Schema, rowIter sql.RowIter, summary PrintSummaryBehavior, pageResults bool, showWarnings bool) (rerr error) {
defer func() {
closeErr := rowIter.Close(ctx)
if rerr == nil && closeErr != nil {
@@ -139,7 +141,15 @@ func prettyPrintResultsWithSummary(ctx *sql.Context, resultFormat PrintResultFor
}
if summary == PrintRowCountAndTiming {
err = printResultSetSummary(numRows, start)
warnings := ""
if showWarnings {
warnings = "\n"
for _, warn := range ctx.Session.Warnings() {
warnings += color.YellowString(fmt.Sprintf("\nWarning (Code %d): %s", warn.Code, warn.Message))
}
}
err = printResultSetSummary(numRows, ctx.WarningCount(), warnings, start)
if err != nil {
return err
}
@@ -154,9 +164,19 @@ func prettyPrintResultsWithSummary(ctx *sql.Context, resultFormat PrintResultFor
}
}
func printResultSetSummary(numRows int, start time.Time) error {
func printResultSetSummary(numRows int, numWarnings uint16, warningsList string, start time.Time) error {
warning := ""
if numWarnings > 0 {
plural := ""
if numWarnings > 1 {
plural = "s"
}
warning = fmt.Sprintf(", %d warning%s", numWarnings, plural)
}
if numRows == 0 {
printEmptySetResult(start)
printEmptySetResult(start, warning)
return nil
}
@@ -166,7 +186,7 @@ func printResultSetSummary(numRows int, start time.Time) error {
}
secondsSinceStart := secondsSince(start, time.Now())
err := iohelp.WriteLine(cli.CliOut, fmt.Sprintf("%d %s in set (%.2f sec)", numRows, noun, secondsSinceStart))
err := iohelp.WriteLine(cli.CliOut, fmt.Sprintf("%d %s in set%s (%.2f sec) %s", numRows, noun, warning, secondsSinceStart, warningsList))
if err != nil {
return err
}
@@ -216,9 +236,9 @@ type nullWriter struct{}
func (n nullWriter) WriteSqlRow(ctx *sql.Context, r sql.Row) error { return nil }
func (n nullWriter) Close(ctx context.Context) error { return nil }
func printEmptySetResult(start time.Time) {
func printEmptySetResult(start time.Time, warning string) {
seconds := secondsSince(start, time.Now())
cli.Printf("Empty set (%.2f sec)\n", seconds)
cli.Printf("Empty set%s (%.2f sec)\n", warning, seconds)
}
func printOKResult(ctx *sql.Context, iter sql.RowIter, start time.Time) error {

View File

@@ -140,7 +140,7 @@ func (cmd TagsCmd) Exec(ctx context.Context, commandStr string, args []string, d
}
sqlCtx := sql.NewContext(ctx)
err = engine.PrettyPrintResults(sqlCtx, outputFmt, headerSchema, sql.RowsToRowIter(rows...), false)
err = engine.PrettyPrintResults(sqlCtx, outputFmt, headerSchema, sql.RowsToRowIter(rows...), false, false)
return commands.HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
}

View File

@@ -447,7 +447,7 @@ func execSingleQuery(
}
if rowIter != nil {
err = engine.PrettyPrintResults(sqlCtx, format, sqlSch, rowIter, false)
err = engine.PrettyPrintResults(sqlCtx, format, sqlSch, rowIter, false, false)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
@@ -663,7 +663,7 @@ func execBatchMode(ctx *sql.Context, qryist cli.Queryist, input io.Reader, conti
fileReadProg.printNewLineIfNeeded()
}
}
err = engine.PrettyPrintResults(ctx, format, sqlSch, rowIter, false)
err = engine.PrettyPrintResults(ctx, format, sqlSch, rowIter, false, false)
if err != nil {
err = buildBatchSqlErr(scanner.state.statementStartLine, query, err)
if !continueOnErr {
@@ -753,6 +753,12 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu
initialCtx := sqlCtx.Context
//We want to gather the warnings if a server is running, as the connection queryist does not automatically cache them
if c, ok := qryist.(cli.ShellServerQueryist); ok {
c.EnableGatherWarnings()
}
toggleWarnings := true
pagerEnabled := false
// Used for the \edit command.
lastSqlCmd := ""
@@ -796,6 +802,9 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu
}
if cmdType == DoltCliCommand {
_, okOn := subCmd.(WarningOn)
_, okOff := subCmd.(WarningOff)
if _, ok := subCmd.(SlashPager); ok {
p, err := handlePagerCommand(query)
if err != nil {
@@ -803,6 +812,18 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu
} else {
pagerEnabled = p
}
} else if okOn || okOff {
w, err := handleWarningCommand(query)
if err != nil {
shell.Println(color.RedString(err.Error()))
} else {
toggleWarnings = w
if toggleWarnings {
cli.Println("Show warnings enabled")
} else {
cli.Println("Show warnings disabled")
}
}
} else {
err := handleSlashCommand(sqlCtx, subCmd, query, cliCtx)
if err != nil {
@@ -823,9 +844,9 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu
} else if rowIter != nil {
switch closureFormat {
case engine.FormatTabular, engine.FormatVertical:
err = engine.PrettyPrintResultsExtended(sqlCtx, closureFormat, sqlSch, rowIter, pagerEnabled)
err = engine.PrettyPrintResultsExtended(sqlCtx, closureFormat, sqlSch, rowIter, pagerEnabled, toggleWarnings)
default:
err = engine.PrettyPrintResults(sqlCtx, closureFormat, sqlSch, rowIter, pagerEnabled)
err = engine.PrettyPrintResults(sqlCtx, closureFormat, sqlSch, rowIter, pagerEnabled, toggleWarnings)
}
if err != nil {
@@ -951,8 +972,12 @@ func formattedPrompts(db, branch string, dirty bool) (string, string) {
// along the way by printing red error messages to the CLI. If there was an issue getting the db name, the ok return
// value will be false and the strings will be empty.
func getDBBranchFromSession(sqlCtx *sql.Context, qryist cli.Queryist) (db string, branch string, ok bool) {
sqlCtx.Session.LockWarnings()
defer sqlCtx.Session.UnlockWarnings()
_, _, _, err := qryist.Query(sqlCtx, "set lock_warnings = 1")
if err != nil {
cli.Println(color.RedString(err.Error()))
return "", "", false
}
defer qryist.Query(sqlCtx, "set lock_warnings = 0")
_, resp, _, err := qryist.Query(sqlCtx, "select database() as db, active_branch() as branch")
if err != nil {
@@ -993,8 +1018,11 @@ func getDBBranchFromSession(sqlCtx *sql.Context, qryist cli.Queryist) (db string
// isDirty returns true if the workspace is dirty, false otherwise. This function _assumes_ you are on a database
// with a branch. If you are not, you will get an error.
func isDirty(sqlCtx *sql.Context, qryist cli.Queryist) (bool, error) {
sqlCtx.Session.LockWarnings()
defer sqlCtx.Session.UnlockWarnings()
_, _, _, err := qryist.Query(sqlCtx, "set lock_warnings = 1")
if err != nil {
return false, err
}
defer qryist.Query(sqlCtx, "set lock_warnings = 0")
_, resp, _, err := qryist.Query(sqlCtx, "select count(table_name) > 0 as dirty from dolt_status")

View File

@@ -41,6 +41,8 @@ var slashCmds = []cli.Command{
SlashHelp{},
SlashEdit{},
SlashPager{},
WarningOn{},
WarningOff{},
}
// parseSlashCmd parses a command line string into a slice of strings, splitting on spaces, but allowing spaces within
@@ -208,8 +210,12 @@ func (s SlashEdit) Exec(ctx context.Context, commandStr string, args []string, d
}
func (s SlashEdit) Docs() *cli.CommandDocumentation {
//TODO implement me
return &cli.CommandDocumentation{}
return &cli.CommandDocumentation{
ShortDesc: "Use $EDITOR to edit the last command.",
LongDesc: "Start a text editor to edit your last command. Command will be executed after you finish editing.",
Synopsis: []string{},
ArgParser: s.ArgParser(),
}
}
func (s SlashEdit) ArgParser() *argparser.ArgParser {
@@ -220,8 +226,12 @@ func (s SlashEdit) ArgParser() *argparser.ArgParser {
type SlashPager struct{}
func (s SlashPager) Docs() *cli.CommandDocumentation {
//TODO
return &cli.CommandDocumentation{}
return &cli.CommandDocumentation{
ShortDesc: "Enable or Disable the result pager",
LongDesc: "Returns results in pager form. Use pager [on|off].",
Synopsis: []string{},
ArgParser: s.ArgParser(),
}
}
func (s SlashPager) ArgParser() *argparser.ArgParser {
@@ -266,3 +276,64 @@ func handlePagerCommand(fullCmd string) (bool, error) {
return false, fmt.Errorf("Usage: \\pager [on|off]")
}
type WarningCmd struct{}
func (s WarningCmd) Docs() *cli.CommandDocumentation {
return &cli.CommandDocumentation{
ShortDesc: "Toggle display of generated warnings after sql command.",
LongDesc: "Displays a detailed list of the warnings generated after each sql command. Use \\W and \\w to enable and disable the setting, respectively.",
Synopsis: []string{},
ArgParser: s.ArgParser(),
}
}
func (s WarningCmd) ArgParser() *argparser.ArgParser {
return &argparser.ArgParser{}
}
// Exec should never be called on warning command; It only changes which information is displayed.
// handleWarningCommand should be used instead
func (s WarningCmd) Exec(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEnv, cliCtx cli.CliContext) int {
panic("runtime error. Exec should never be called on warning display commands.")
}
type WarningOn struct {
WarningCmd
}
var _ cli.Command = WarningOn{}
func (s WarningOn) Name() string { return "W" }
func (s WarningOn) Description() string {
return "Show generated warnings after sql command"
}
type WarningOff struct {
WarningCmd
}
var _ cli.Command = WarningOff{}
func (s WarningOff) Name() string { return "w" }
func (s WarningOff) Description() string {
return "Hide generated warnings after sql command"
}
func handleWarningCommand(fullCmd string) (bool, error) {
tokens := strings.Split(fullCmd, " ")
if len(tokens) == 0 || (tokens[0] != "\\w" && tokens[0] != "\\W") {
return false, fmt.Errorf("runtime error: Expected \\w or \\W command.")
} else if len(tokens) > 1 {
return false, fmt.Errorf("Usage: \\w \\w to toggle warnings")
}
if tokens[0] == "\\W" {
return true, nil
} else {
return false, nil
}
}

View File

@@ -19,6 +19,8 @@ import (
sql2 "database/sql"
"fmt"
"io"
"regexp"
"strings"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
@@ -63,7 +65,8 @@ func BuildConnectionStringQueryist(ctx context.Context, cwdFS filesys.Filesys, c
conn := &dbr.Connection{DB: sql2.OpenDB(mysqlConnector), EventReceiver: nil, Dialect: dialect.MySQL}
queryist := ConnectionQueryist{connection: conn}
gatherWarnings := false
queryist := ConnectionQueryist{connection: conn, gatherWarnings: &gatherWarnings}
var lateBind cli.LateBindQueryist = func(ctx context.Context) (cli.Queryist, *sql.Context, func(), error) {
sqlCtx := sql.NewContext(ctx)
@@ -78,20 +81,55 @@ func BuildConnectionStringQueryist(ctx context.Context, cwdFS filesys.Filesys, c
// ConnectionQueryist executes queries by connecting to a running mySql server.
type ConnectionQueryist struct {
connection *dbr.Connection
connection *dbr.Connection
gatherWarnings *bool
}
var _ cli.Queryist = ConnectionQueryist{}
var _ cli.Queryist = &ConnectionQueryist{}
func (c ConnectionQueryist) EnableGatherWarnings() {
*c.gatherWarnings = true
}
func (c ConnectionQueryist) Query(ctx *sql.Context, query string) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
rows, err := c.connection.QueryContext(ctx, query)
if err != nil {
return nil, nil, nil, err
}
rowIter, err := NewMysqlRowWrapper(rows)
if err != nil {
return nil, nil, nil, err
}
if c.gatherWarnings != nil && *c.gatherWarnings == true {
ctx.ClearWarnings()
re := regexp.MustCompile(`\s+`)
noSpace := strings.TrimSpace(re.ReplaceAllString(query, " "))
isShowWarnings := strings.EqualFold(noSpace, "show warnings")
if !isShowWarnings {
warnRows, err := c.connection.QueryContext(ctx, "show warnings")
if err != nil {
return nil, nil, nil, err
}
for warnRows.Next() {
var code int
var msg string
var level string
err = warnRows.Scan(&level, &code, &msg)
if err != nil {
return nil, nil, nil, err
}
ctx.Warn(code, "%s", msg)
}
}
}
return rowIter.Schema(), rowIter, nil, nil
}
@@ -100,23 +138,23 @@ func (c ConnectionQueryist) QueryWithBindings(ctx *sql.Context, query string, _
}
type MysqlRowWrapper struct {
rows *sql2.Rows
schema sql.Schema
finished bool
vRow []*string
iRow []interface{}
rows []sql.Row
schema sql.Schema
numRows int
curRow int
}
var _ sql.RowIter = (*MysqlRowWrapper)(nil)
func NewMysqlRowWrapper(rows *sql2.Rows) (*MysqlRowWrapper, error) {
colNames, err := rows.Columns()
func NewMysqlRowWrapper(sqlRows *sql2.Rows) (*MysqlRowWrapper, error) {
colNames, err := sqlRows.Columns()
if err != nil {
return nil, err
}
schema := make(sql.Schema, len(colNames))
vRow := make([]*string, len(colNames))
iRow := make([]interface{}, len(colNames))
rows := make([]sql.Row, 0)
for i, colName := range colNames {
schema[i] = &sql.Column{
Name: colName,
@@ -125,12 +163,32 @@ func NewMysqlRowWrapper(rows *sql2.Rows) (*MysqlRowWrapper, error) {
}
iRow[i] = &vRow[i]
}
for sqlRows.Next() {
err := sqlRows.Scan(iRow...)
if err != nil {
return nil, err
}
sqlRow := make(sql.Row, len(vRow))
for i, val := range vRow {
if val != nil {
sqlRow[i] = *val
}
}
rows = append(rows, sqlRow)
}
closeErr := sqlRows.Close()
if closeErr != nil {
return nil, err
}
return &MysqlRowWrapper{
rows: rows,
schema: schema,
finished: !rows.Next(),
vRow: vRow,
iRow: iRow,
rows: rows,
schema: schema,
numRows: len(rows),
curRow: 0,
}, nil
}
@@ -139,27 +197,19 @@ func (s *MysqlRowWrapper) Schema() sql.Schema {
}
func (s *MysqlRowWrapper) Next(*sql.Context) (sql.Row, error) {
if s.finished {
if s.NoMoreRows() {
return nil, io.EOF
}
err := s.rows.Scan(s.iRow...)
if err != nil {
return nil, err
}
sqlRow := make(sql.Row, len(s.vRow))
for i, val := range s.vRow {
if val != nil {
sqlRow[i] = *val
}
}
s.finished = !s.rows.Next()
return sqlRow, nil
s.curRow++
return s.rows[s.curRow-1], nil
}
func (s *MysqlRowWrapper) HasMoreRows() bool {
return !s.finished
func (s *MysqlRowWrapper) NoMoreRows() bool {
return s.curRow >= s.numRows
}
func (s *MysqlRowWrapper) Close(*sql.Context) error {
return s.rows.Close()
s.curRow = s.numRows
return nil
}

View File

@@ -516,6 +516,13 @@ func AddDoltSystemVariables() {
Type: types.NewSystemBoolType("gpgsign"),
Default: int8(0),
},
&sql.MysqlSystemVariable{
Name: "sql_warnings",
Dynamic: true,
Scope: sql.GetMysqlScope(sql.SystemVariableScope_Both),
Type: types.NewSystemBoolType("sql_warnings"),
Default: int8(1),
},
})
sql.SystemVariables.AddSystemVariables(DoltSystemVariables)
}

View File

@@ -21,21 +21,43 @@ teardown() {
teardown_common
}
# bats test_tags=no_lambda
@test "sql-shell: warnings are not suppressed" {
skiponwindows "Need to install expect and make this script work on windows."
if [ "$SQL_ENGINE" = "remote-engine" ]; then
skip "session ctx in shell is no the same as session in server"
fi
run $BATS_TEST_DIRNAME/sql-shell-warnings.expect
echo "$output"
[[ "$output" =~ "Warning" ]] || false
[[ "$output" =~ "1365" ]] || false
[[ "$output" =~ "Division by 0" ]] || false
}
# bats test_tags=no_lambda
@test "sql-shell: can toggle warning details" {
skiponwindows "Need to install expect and make this script work on windows."
run $BATS_TEST_DIRNAME/sql-warning-summary.expect
[ "$status" -eq 0 ]
! [[ "$output" =~ "Warning (Code 1365): Division by 0\nWarning (Code 1365): Division by 0" ]] || false
}
# bats test_tags=no_lambda
@test "sql-shell: can toggle warning summary" {
skiponwindows "Need to install expect and make this script work on windows."
skip " set sql_warnings currently doesn't work --- needs more communication between server & shell"
run $BATS_TEST_DIRNAME/sql-warning-detail.expect
[ "$status" -eq 0 ]
! [[ "$output" =~ "1 row in set, 3 warnings" ]] || false
}
# bats test_tags=no_lambda
@test "sql-shell: show warnings hides warning summary, and removes whitespace" {
skiponwindows "Need to install expect and make this script work on windows."
run $BATS_TEST_DIRNAME/sql-show-warnings.expect
[ "$status" -eq 0 ]
}
@test "sql-shell: use user without privileges, and no superuser created" {
rm -rf .doltcfg
@@ -1007,4 +1029,4 @@ expect eof
[ $status -eq 0 ]
[[ "$output" =~ "github.com/dolthub/dolt/go" ]] || false
[[ "$output" =~ "github.com/dolthub/go-mysql-server" ]] || false
}
}

View File

@@ -0,0 +1,16 @@
#!/usr/bin/expect
set timeout 5
set env(NO_COLOR) 1
source "$env(BATS_CWD)/helper/common_expect_functions.tcl"
spawn dolt sql
expect_with_defaults {dolt-repo-.*} { send " shoW warNinGs ; \r"; }
expect_with_defaults_2 {Empty set} {dolt-repo-.*} { send "select 1/0;\r"; }
expect_with_defaults {dolt-repo-.*} { send " shoW warNinGs ; \r"; }
expect_with_defaults_2 {Division by 0 } {dolt-repo-.*} { send "exit;\r"; }

View File

@@ -0,0 +1,21 @@
#!/usr/bin/expect
set timeout 5
set env(NO_COLOR) 1
source "$env(BATS_CWD)/helper/common_expect_functions.tcl"
spawn dolt sql
expect_with_defaults {dolt-repo-.*} { send "select 1/0;\r"; }
expect_with_defaults_2 {1 row in set, 1 warning} {dolt-repo-.*} { send "select 1/0, 1/0;\r"; }
expect_with_defaults_2 {1 row in set, 2 warnings} {dolt-repo-.*} { send "set sql_warnings = 0;\r"; }
expect_with_defaults {dolt-repo-.*} { send "select 1/0, 1/0, 1/0;\r"; }
expect_with_defaults {dolt-repo-.*} { send "exit;\r"; }
expect eof
exit

View File

@@ -0,0 +1,19 @@
#!/usr/bin/expect
set timeout 5
set env(NO_COLOR) 1
source "$env(BATS_CWD)/helper/common_expect_functions.tcl"
spawn dolt sql
expect_with_defaults {dolt-repo-.*>} { send "select 1/0;\r"; }
expect_with_defaults_2 {Warning \(Code 1365\): Division by 0} {dolt-repo-.*>} { send "\\w\r"; }
expect_with_defaults_2 {Show warnings disabled} {dolt-repo-.*>} { send "select 1/0,1/0;\r"; }
expect_with_defaults {dolt-repo-.*>} { send "quit;\r"; }
expect eof
exit