diff --git a/bats/sql-batch.bats b/bats/sql-batch.bats index 6b8d1eb058..4559e1c250 100644 --- a/bats/sql-batch.bats +++ b/bats/sql-batch.bats @@ -52,7 +52,27 @@ insert into test values poop; SQL [ "$status" -ne 0 ] [[ "$output" =~ "Error processing batch" ]] || false - skip "No line number and query on error" - [[ "$output" =~ " 3 " ]] || false - [[ "$output" =~ "insert into test values poop;" ]] || false + [[ "$output" =~ "error on line 3 for query" ]] || false + [[ "$output" =~ "insert into test values poop" ]] || false + + run dolt sql < 0 { + err = fmt.Errorf("%s: %s", message, err.Error()) + } return errhand.VerboseErrorFromError(err) } } @@ -553,31 +554,9 @@ func saveQuery(ctx context.Context, root *doltdb.RootValue, dEnv *env.DoltEnv, q return newRoot, nil } -// ScanStatements is a split function for a Scanner that returns each SQL statement in the input as a token. It doesn't -// work for strings that contain semi-colons. Supporting that requires implementing a state machine. -func scanStatements(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := bytes.IndexByte(data, ';'); i >= 0 { - // We have a full ;-terminated line. - return i + 1, data[0:i], nil - } - // If we're at EOF, we have a final, non-terminated line. Return it. - if atEOF { - return len(data), data, nil - } - // Request more data. - return 0, nil, nil -} - // runBatchMode processes queries until EOF. The Root of the sqlEngine may be updated. func runBatchMode(ctx *sql.Context, se *sqlEngine, input io.Reader) error { - scanner := bufio.NewScanner(input) - const maxCapacity = 512 * 1024 - buf := make([]byte, maxCapacity) - scanner.Buffer(buf, maxCapacity) - scanner.Split(scanStatements) + scanner := NewSqlStatementScanner(input) var query string for scanner.Scan() { @@ -585,13 +564,10 @@ func runBatchMode(ctx *sql.Context, se *sqlEngine, input io.Reader) error { if len(query) == 0 || query == "\n" { continue } - if !batchInsertEarlySemicolon(query) { - query += ";" - // TODO: We should fix this problem by properly implementing a state machine for scanStatements - continue - } if err := processBatchQuery(ctx, query, se); err != nil { - verr := formatQueryError(query, err) + // TODO: this line number will not be accurate for errors that occur when flushing a batch of inserts (as opposed + // to processing the query) + verr := formatQueryError(fmt.Sprintf("error on line %d for query %s", scanner.statementStartLine, query), err) cli.PrintErrln(verr.Verbose()) return err } @@ -607,36 +583,6 @@ func runBatchMode(ctx *sql.Context, se *sqlEngine, input io.Reader) error { return flushBatchedEdits(ctx, se) } -// batchInsertEarlySemicolon loops through a string to check if Scan stopped too early on a semicolon -func batchInsertEarlySemicolon(query string) bool { - quotes := []uint8{'\'', '"'} - midQuote := false - queryLength := len(query) - for i := 0; i < queryLength; i++ { - for _, quote := range quotes { - if query[i] == quote { - i++ - midQuote = true - inEscapeMode := false - for ; i < queryLength; i++ { - if inEscapeMode { - inEscapeMode = false - } else { - if query[i] == quote { - midQuote = false - break - } else if query[i] == '\\' { - inEscapeMode = true - } - } - } - break - } - } - } - return !midQuote -} - // runShell starts a SQL shell. Returns when the user exits the shell. The Root of the sqlEngine may // be updated by any queries which were processed. func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error { @@ -695,7 +641,7 @@ func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error { } if sqlSch, rowIter, err := processQuery(ctx, query, se); err != nil { - verr := formatQueryError(query, err) + verr := formatQueryError("", err) shell.Println(verr.Verbose()) } else if rowIter != nil { defer rowIter.Close() diff --git a/go/cmd/dolt/commands/sql_statement_scanner.go b/go/cmd/dolt/commands/sql_statement_scanner.go new file mode 100755 index 0000000000..a5820b8fe1 --- /dev/null +++ b/go/cmd/dolt/commands/sql_statement_scanner.go @@ -0,0 +1,151 @@ +// Copyright 2020 Liquidata, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package commands + +import ( + "bufio" + "io" + "unicode" +) + +type statementScanner struct { + *bufio.Scanner + statementStartLine int // the line number of the first line of the last parsed statement + startLineNum int // the line number we began parsing the most recent token at + lineNum int // the current line number being parsed +} + +func NewSqlStatementScanner(input io.Reader) *statementScanner { + scanner := bufio.NewScanner(input) + const maxCapacity = 512 * 1024 + buf := make([]byte, maxCapacity) + scanner.Buffer(buf, maxCapacity) + + s := &statementScanner{ + Scanner: scanner, + lineNum: 1, + } + scanner.Split(s.scanStatements) + + return s +} + +const ( + sQuote byte = '\'' + dQuote = '"' + backslash = '\\' + backtick = '`' +) + +// ScanStatements is a split function for a Scanner that returns each SQL statement in the input as a token. +func (s *statementScanner) scanStatements(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + + var ( + quoteChar byte // the opening quote character of the current quote being parsed, or 0 if the current parse location isn't inside a quoted string + lastChar byte // the last character parsed + ignoreNextChar bool // whether to ignore the next character + numConsecutiveBackslashes int // the number of consecutive backslashes encountered + seenNonWhitespaceChar bool // whether we have encountered a non-whitespace character since we returned the last token + ) + + s.startLineNum = s.lineNum + + for i := range data { + if !ignoreNextChar { + // this doesn't handle unicode characters correctly and will break on some things, but it's only used for line + // number reporting. + if !seenNonWhitespaceChar && !unicode.IsSpace(rune(data[i])) { + seenNonWhitespaceChar = true + s.statementStartLine = s.lineNum + } + + switch data[i] { + case '\n': + s.lineNum++ + case ';': + if quoteChar == 0 { + s.startLineNum = s.lineNum + _, _, _ = s.resetState() + return i + 1, data[0:i], nil + } + case backslash: + numConsecutiveBackslashes++ + case sQuote, dQuote, backtick: + prevNumConsecutiveBackslashes := numConsecutiveBackslashes + numConsecutiveBackslashes = 0 + + // escaped quote character + if lastChar == backslash && prevNumConsecutiveBackslashes%2 == 1 { + break + } + + // currently in a quoted string + if quoteChar != 0 { + + // end quote or two consecutive quote characters (a form of escaping quote chars) + if quoteChar == data[i] { + var nextChar byte = 0 + if i+1 < len(data) { + nextChar = data[i+1] + } + + if nextChar == quoteChar { + // escaped quote. skip the next character + ignoreNextChar = true + break + } else if atEOF || i+1 < len(data) { + // end quote + quoteChar = 0 + break + } else { + // need more data to make a decision + return s.resetState() + } + } + + // embedded quote ('"' or "'") + break + } + + // open quote + quoteChar = data[i] + default: + numConsecutiveBackslashes = 0 + } + } else { + ignoreNextChar = false + } + + lastChar = data[i] + } + + // If we're at EOF, we have a final, non-terminated line. Return it. + if atEOF { + return len(data), data, nil + } + + // Request more data. + return s.resetState() +} + +// resetState resets the internal state of the scanner and returns the "more data" response for a split function +func (s *statementScanner) resetState() (advance int, token []byte, err error) { + // rewind the line number to where we started parsing this token + s.lineNum = s.startLineNum + return 0, nil, nil +} diff --git a/go/cmd/dolt/commands/sql_statement_scanner_test.go b/go/cmd/dolt/commands/sql_statement_scanner_test.go new file mode 100755 index 0000000000..a76d609b1a --- /dev/null +++ b/go/cmd/dolt/commands/sql_statement_scanner_test.go @@ -0,0 +1,202 @@ +// Copyright 2020 Liquidata, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package commands + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestScanStatements(t *testing.T) { + type testcase struct { + input string + statements []string + lineNums []int + } + + // Some of these include malformed input (e.g. strings that aren't properly terminated) + testcases := []testcase{ + { + input: `insert into foo values (";;';'");`, + statements: []string{ + `insert into foo values (";;';'")`, + }, + }, + { + input: `select ''';;'; select ";\;"`, + statements: []string{ + `select ''';;'`, + `select ";\;"`, + }, + }, + { + input: `select ''';;'; select ";\;`, + statements: []string{ + `select ''';;'`, + `select ";\;`, + }, + }, + { + input: `select ''';;'; select ";\; +;`, + statements: []string{ + `select ''';;'`, + `select ";\; +;`, + }, + }, + { + input: `select '\\'''; select '";";'; select 1`, + statements: []string{ + `select '\\'''`, + `select '";";'`, + `select 1`, + }, + }, + { + input: `select '\\''; select '";";'; select 1`, + statements: []string{ + `select '\\''; select '";"`, + `'; select 1`, + }, + }, + { + input: `insert into foo values(''); select 1`, + statements: []string{ + `insert into foo values('')`, + `select 1`, + }, + }, + { + input: `insert into foo values('''); select 1`, + statements: []string{ + `insert into foo values('''); select 1`, + }, + }, + { + input: `insert into foo values(''''); select 1`, + statements: []string{ + `insert into foo values('''')`, + `select 1`, + }, + }, + { + input: `insert into foo values(""); select 1`, + statements: []string{ + `insert into foo values("")`, + `select 1`, + }, + }, + { + input: `insert into foo values("""); select 1`, + statements: []string{ + `insert into foo values("""); select 1`, + }, + }, + { + input: `insert into foo values(""""); select 1`, + statements: []string{ + `insert into foo values("""")`, + `select 1`, + }, + }, + { + input: `select '\''; select "hell\"o"`, + statements: []string{ + `select '\''`, + `select "hell\"o"`, + }, + }, + { + input: `select * from foo; select baz from foo; +select +a from b; select 1`, + statements: []string{ + "select * from foo", + "select baz from foo", + "select\na from b", + "select 1", + }, + lineNums: []int{ + 1, 1, 2, 3, + }, + }, + { + input: "create table dumb (`hell\\`o;` int primary key);", + statements: []string{ + "create table dumb (`hell\\`o;` int primary key)", + }, + }, + { + input: "create table dumb (`hell``o;` int primary key); select \n" + + "baz from foo;\n" + + "\n" + + "select\n" + + "a from b; select 1\n\n", + statements: []string{ + "create table dumb (`hell``o;` int primary key)", + "select \nbaz from foo", + "select\na from b", + "select 1", + }, + lineNums: []int{ + 1, 1, 4, 5, + }, + }, + { + input: `insert into foo values ('a', "b;", 'c;;"" +'); update foo set baz = bar, +qux = '"hello"""' where xyzzy = ";;';'"; + + +create table foo (a int not null default ';', +primary key (a));`, + statements: []string{ + `insert into foo values ('a', "b;", 'c;;"" +')`, + `update foo set baz = bar, +qux = '"hello"""' where xyzzy = ";;';'"`, + `create table foo (a int not null default ';', +primary key (a))`, + }, + lineNums: []int{ + 1, 2, 6, + }, + }, + } + + for _, tt := range testcases { + t.Run(tt.input, func(t *testing.T) { + reader := strings.NewReader(tt.input) + scanner := NewSqlStatementScanner(reader) + var i int + for scanner.Scan() { + require.True(t, i < len(tt.statements)) + assert.Equal(t, tt.statements[i], strings.TrimSpace(scanner.Text())) + if tt.lineNums != nil { + assert.Equal(t, tt.lineNums[i], scanner.statementStartLine) + } else { + assert.Equal(t, 1, scanner.statementStartLine) + } + i++ + } + + require.NoError(t, scanner.Err()) + }) + } +}