mirror of
https://github.com/dolthub/dolt.git
synced 2026-01-08 00:39:48 -06:00
Merge pull request #613 from liquidata-inc/zachmu/sql-batch-errors
Real statement tokenization for SQL statements in batch processing mode, instead of line-based scans with hacks for embedded semicolons. Print out the line number of failed queries during batch processing failure.
This commit is contained in:
@@ -52,7 +52,27 @@ insert into test values poop;
|
||||
SQL
|
||||
[ "$status" -ne 0 ]
|
||||
[[ "$output" =~ "Error processing batch" ]] || false
|
||||
skip "No line number and query on error"
|
||||
[[ "$output" =~ " 3 " ]] || false
|
||||
[[ "$output" =~ "insert into test values poop;" ]] || false
|
||||
[[ "$output" =~ "error on line 3 for query" ]] || false
|
||||
[[ "$output" =~ "insert into test values poop" ]] || false
|
||||
|
||||
run dolt sql <<SQL
|
||||
insert into test values (0,0,0,0,0,0);
|
||||
|
||||
insert into test values (1,0,
|
||||
0,0,0,0);
|
||||
|
||||
insert into
|
||||
test values (2,0,0,0,0,0)
|
||||
;
|
||||
|
||||
insert into
|
||||
test values
|
||||
poop;
|
||||
|
||||
insert into test values (3,0,0,0,0,0);
|
||||
SQL
|
||||
[ "$status" -ne 0 ]
|
||||
[[ "$output" =~ "Error processing batch" ]] || false
|
||||
[[ "$output" =~ "error on line 10 for query" ]] || false
|
||||
[[ "$output" =~ "poop" ]] || false
|
||||
}
|
||||
|
||||
@@ -383,12 +383,12 @@ CREATE TABLE test (
|
||||
PRIMARY KEY (pk));
|
||||
SQL
|
||||
[ $status -ne 0 ]
|
||||
[[ "${lines[0]}" =~ "Cannot create column pk, the tag 1234 was already used in table aaa" ]] || false
|
||||
[[ "${lines[1]}" =~ "Cannot create column c1, the tag 5678 was already used in table bbb" ]] || false
|
||||
[[ "$output" =~ "Cannot create column pk, the tag 1234 was already used in table aaa" ]] || false
|
||||
[[ "$output" =~ "Cannot create column c1, the tag 5678 was already used in table bbb" ]] || false
|
||||
|
||||
run dolt sql -q "ALTER TABLE aaa ADD COLUMN c1 INT COMMENT 'tag:5678';"
|
||||
[ $status -ne 0 ]
|
||||
[[ "${lines[0]}" =~ "Cannot create column c1, the tag 5678 was already used in table bbb" ]] || false
|
||||
[[ "$output" =~ "Cannot create column c1, the tag 5678 was already used in table bbb" ]] || false
|
||||
}
|
||||
|
||||
@test "Deterministic tag generation produces consistent results" {
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -371,7 +369,7 @@ func execQuery(sqlCtx *sql.Context, mrEnv env.MultiRepoEnv, roots map[string]*do
|
||||
|
||||
sqlSch, rowIter, err := processQuery(sqlCtx, query, se)
|
||||
if err != nil {
|
||||
verr := formatQueryError(query, err)
|
||||
verr := formatQueryError("", err)
|
||||
return nil, verr
|
||||
}
|
||||
|
||||
@@ -404,7 +402,7 @@ func CollectDBs(mrEnv env.MultiRepoEnv, createDB createDBFunc) []dsqle.Database
|
||||
return dbs
|
||||
}
|
||||
|
||||
func formatQueryError(query string, err error) errhand.VerboseError {
|
||||
func formatQueryError(message string, err error) errhand.VerboseError {
|
||||
const (
|
||||
maxStatementLen = 128
|
||||
maxPosWhenTruncated = 64
|
||||
@@ -456,6 +454,9 @@ func formatQueryError(query string, err error) errhand.VerboseError {
|
||||
|
||||
return verrBuilder.Build()
|
||||
} else {
|
||||
if len(message) > 0 {
|
||||
err = fmt.Errorf("%s: %s", message, err.Error())
|
||||
}
|
||||
return errhand.VerboseErrorFromError(err)
|
||||
}
|
||||
}
|
||||
@@ -553,31 +554,9 @@ func saveQuery(ctx context.Context, root *doltdb.RootValue, dEnv *env.DoltEnv, q
|
||||
return newRoot, nil
|
||||
}
|
||||
|
||||
// scanStatements is a bufio.SplitFunc that tokenizes its input on ';', yielding one SQL statement per token with the
// terminating semicolon removed. It has no notion of quoting, so a semicolon inside a string literal also ends the
// token; handling that correctly requires a state machine.
func scanStatements(data []byte, atEOF bool) (advance int, token []byte, err error) {
	end := bytes.IndexByte(data, ';')

	switch {
	case end >= 0:
		// A complete ;-terminated statement is buffered: consume it along with its terminator.
		return end + 1, data[:end], nil
	case atEOF && len(data) > 0:
		// Input is exhausted, so whatever is left is a final, unterminated statement.
		return len(data), data, nil
	default:
		// Either nothing is left at all, or the statement is incomplete and more data is required.
		return 0, nil, nil
	}
}
|
||||
|
||||
// runBatchMode processes queries until EOF. The Root of the sqlEngine may be updated.
|
||||
func runBatchMode(ctx *sql.Context, se *sqlEngine, input io.Reader) error {
|
||||
scanner := bufio.NewScanner(input)
|
||||
const maxCapacity = 512 * 1024
|
||||
buf := make([]byte, maxCapacity)
|
||||
scanner.Buffer(buf, maxCapacity)
|
||||
scanner.Split(scanStatements)
|
||||
scanner := NewSqlStatementScanner(input)
|
||||
|
||||
var query string
|
||||
for scanner.Scan() {
|
||||
@@ -585,13 +564,10 @@ func runBatchMode(ctx *sql.Context, se *sqlEngine, input io.Reader) error {
|
||||
if len(query) == 0 || query == "\n" {
|
||||
continue
|
||||
}
|
||||
if !batchInsertEarlySemicolon(query) {
|
||||
query += ";"
|
||||
// TODO: We should fix this problem by properly implementing a state machine for scanStatements
|
||||
continue
|
||||
}
|
||||
if err := processBatchQuery(ctx, query, se); err != nil {
|
||||
verr := formatQueryError(query, err)
|
||||
// TODO: this line number will not be accurate for errors that occur when flushing a batch of inserts (as opposed
|
||||
// to processing the query)
|
||||
verr := formatQueryError(fmt.Sprintf("error on line %d for query %s", scanner.statementStartLine, query), err)
|
||||
cli.PrintErrln(verr.Verbose())
|
||||
return err
|
||||
}
|
||||
@@ -607,36 +583,6 @@ func runBatchMode(ctx *sql.Context, se *sqlEngine, input io.Reader) error {
|
||||
return flushBatchedEdits(ctx, se)
|
||||
}
|
||||
|
||||
// batchInsertEarlySemicolon loops through a string to check if Scan stopped too early on a semicolon
|
||||
func batchInsertEarlySemicolon(query string) bool {
|
||||
quotes := []uint8{'\'', '"'}
|
||||
midQuote := false
|
||||
queryLength := len(query)
|
||||
for i := 0; i < queryLength; i++ {
|
||||
for _, quote := range quotes {
|
||||
if query[i] == quote {
|
||||
i++
|
||||
midQuote = true
|
||||
inEscapeMode := false
|
||||
for ; i < queryLength; i++ {
|
||||
if inEscapeMode {
|
||||
inEscapeMode = false
|
||||
} else {
|
||||
if query[i] == quote {
|
||||
midQuote = false
|
||||
break
|
||||
} else if query[i] == '\\' {
|
||||
inEscapeMode = true
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return !midQuote
|
||||
}
|
||||
|
||||
// runShell starts a SQL shell. Returns when the user exits the shell. The Root of the sqlEngine may
|
||||
// be updated by any queries which were processed.
|
||||
func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error {
|
||||
@@ -695,7 +641,7 @@ func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error {
|
||||
}
|
||||
|
||||
if sqlSch, rowIter, err := processQuery(ctx, query, se); err != nil {
|
||||
verr := formatQueryError(query, err)
|
||||
verr := formatQueryError("", err)
|
||||
shell.Println(verr.Verbose())
|
||||
} else if rowIter != nil {
|
||||
defer rowIter.Close()
|
||||
|
||||
151
go/cmd/dolt/commands/sql_statement_scanner.go
Executable file
151
go/cmd/dolt/commands/sql_statement_scanner.go
Executable file
@@ -0,0 +1,151 @@
|
||||
// Copyright 2020 Liquidata, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package commands
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// statementScanner is a bufio.Scanner whose tokens are individual SQL statements rather than lines. Alongside each
// token it records the line of the input on which that statement began, so callers can report where a failing
// statement came from.
type statementScanner struct {
	*bufio.Scanner
	statementStartLine int // the line number of the first line of the last parsed statement
	startLineNum       int // the line number we began parsing the most recent token at
	lineNum            int // the current line number being parsed
}

// NewSqlStatementScanner returns a scanner over |input| that yields one SQL statement per call to Scan. Line numbers
// are 1-based. The scanner's buffer is capped at 512KiB, which bounds the size of a single statement.
func NewSqlStatementScanner(input io.Reader) *statementScanner {
	scanner := bufio.NewScanner(input)
	const maxCapacity = 512 * 1024
	buf := make([]byte, maxCapacity)
	scanner.Buffer(buf, maxCapacity)

	s := &statementScanner{
		Scanner: scanner,
		lineNum: 1,
	}
	scanner.Split(s.scanStatements)

	return s
}

// Characters with special meaning to the statement tokenizer.
const (
	sQuote    byte = '\''
	dQuote         = '"'
	backslash      = '\\'
	backtick       = '`'
)

// scanStatements is a split function for a Scanner that returns each SQL statement in the input as a token. A
// statement ends at the first ';' that is not inside a single-quoted, double-quoted or backtick-quoted string; the
// terminating ';' is consumed but not included in the token. Inside a quoted string, a quote character is treated as
// escaped (and so does not end the string) when it is doubled or when it is preceded by an odd number of backslashes.
// At EOF any remaining input is returned as a final, unterminated statement.
//
// The split function is re-invoked from the start of the unconsumed data whenever more input is needed, so all
// per-token parse state lives in locals, and the line counter is rewound (see resetState) before asking for more data.
func (s *statementScanner) scanStatements(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if atEOF && len(data) == 0 {
		return 0, nil, nil
	}

	var (
		quoteChar          byte // the opening quote character of the string being parsed, or 0 when not inside a string
		skipNext           bool // whether the next character is the second half of a doubled quote and must be ignored
		pendingBackslashes int  // the length of the run of backslashes immediately preceding the current character
		seenNonWhitespace  bool // whether a non-whitespace character has been seen since the last returned token
	)

	s.startLineNum = s.lineNum

	for i, c := range data {
		if skipNext {
			skipNext = false
			continue
		}

		// this doesn't handle unicode characters correctly and will break on some things, but it's only used for line
		// number reporting.
		if !seenNonWhitespace && !unicode.IsSpace(rune(c)) {
			seenNonWhitespace = true
			s.statementStartLine = s.lineNum
		}

		if c == backslash {
			pendingBackslashes++
			continue
		}

		// Any character other than a backslash ends the current run of backslashes. The run must be cleared for every
		// such character (not only for quotes and ordinary characters), otherwise a backslash that precedes a ';' or a
		// newline is still counted when a later quote is examined, flipping its escaped / unescaped status.
		escaped := pendingBackslashes%2 == 1
		pendingBackslashes = 0

		switch c {
		case '\n':
			s.lineNum++
		case ';':
			if quoteChar == 0 {
				// The next token starts on the line this one ended on.
				s.startLineNum = s.lineNum
				return i + 1, data[:i], nil
			}
		case sQuote, dQuote, backtick:
			switch {
			case escaped:
				// backslash-escaped quote character, nothing to do
			case quoteChar == 0:
				// open quote
				quoteChar = c
			case quoteChar != c:
				// embedded quote of a different kind ('"' or "'"), nothing to do
			case i+1 < len(data) && data[i+1] == quoteChar:
				// two consecutive quote characters (a form of escaping quote chars). skip the next character
				skipNext = true
			case atEOF || i+1 < len(data):
				// end quote
				quoteChar = 0
			default:
				// The quote is the last buffered byte and more input may follow, so we can't yet tell an end quote
				// from the first half of a doubled quote. Need more data to make a decision.
				return s.resetState()
			}
		}
	}

	// If we're at EOF, we have a final, non-terminated line. Return it.
	if atEOF {
		return len(data), data, nil
	}

	// Request more data.
	return s.resetState()
}

// resetState resets the internal state of the scanner and returns the "more data" response for a split function
func (s *statementScanner) resetState() (advance int, token []byte, err error) {
	// rewind the line number to where we started parsing this token, since the same bytes will be scanned again
	s.lineNum = s.startLineNum
	return 0, nil, nil
}
|
||||
202
go/cmd/dolt/commands/sql_statement_scanner_test.go
Executable file
202
go/cmd/dolt/commands/sql_statement_scanner_test.go
Executable file
@@ -0,0 +1,202 @@
|
||||
// Copyright 2020 Liquidata, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package commands
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestScanStatements(t *testing.T) {
|
||||
type testcase struct {
|
||||
input string
|
||||
statements []string
|
||||
lineNums []int
|
||||
}
|
||||
|
||||
// Some of these include malformed input (e.g. strings that aren't properly terminated)
|
||||
testcases := []testcase{
|
||||
{
|
||||
input: `insert into foo values (";;';'");`,
|
||||
statements: []string{
|
||||
`insert into foo values (";;';'")`,
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `select ''';;'; select ";\;"`,
|
||||
statements: []string{
|
||||
`select ''';;'`,
|
||||
`select ";\;"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `select ''';;'; select ";\;`,
|
||||
statements: []string{
|
||||
`select ''';;'`,
|
||||
`select ";\;`,
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `select ''';;'; select ";\;
|
||||
;`,
|
||||
statements: []string{
|
||||
`select ''';;'`,
|
||||
`select ";\;
|
||||
;`,
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `select '\\'''; select '";";'; select 1`,
|
||||
statements: []string{
|
||||
`select '\\'''`,
|
||||
`select '";";'`,
|
||||
`select 1`,
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `select '\\''; select '";";'; select 1`,
|
||||
statements: []string{
|
||||
`select '\\''; select '";"`,
|
||||
`'; select 1`,
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `insert into foo values(''); select 1`,
|
||||
statements: []string{
|
||||
`insert into foo values('')`,
|
||||
`select 1`,
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `insert into foo values('''); select 1`,
|
||||
statements: []string{
|
||||
`insert into foo values('''); select 1`,
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `insert into foo values(''''); select 1`,
|
||||
statements: []string{
|
||||
`insert into foo values('''')`,
|
||||
`select 1`,
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `insert into foo values(""); select 1`,
|
||||
statements: []string{
|
||||
`insert into foo values("")`,
|
||||
`select 1`,
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `insert into foo values("""); select 1`,
|
||||
statements: []string{
|
||||
`insert into foo values("""); select 1`,
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `insert into foo values(""""); select 1`,
|
||||
statements: []string{
|
||||
`insert into foo values("""")`,
|
||||
`select 1`,
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `select '\''; select "hell\"o"`,
|
||||
statements: []string{
|
||||
`select '\''`,
|
||||
`select "hell\"o"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `select * from foo; select baz from foo;
|
||||
select
|
||||
a from b; select 1`,
|
||||
statements: []string{
|
||||
"select * from foo",
|
||||
"select baz from foo",
|
||||
"select\na from b",
|
||||
"select 1",
|
||||
},
|
||||
lineNums: []int{
|
||||
1, 1, 2, 3,
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "create table dumb (`hell\\`o;` int primary key);",
|
||||
statements: []string{
|
||||
"create table dumb (`hell\\`o;` int primary key)",
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "create table dumb (`hell``o;` int primary key); select \n" +
|
||||
"baz from foo;\n" +
|
||||
"\n" +
|
||||
"select\n" +
|
||||
"a from b; select 1\n\n",
|
||||
statements: []string{
|
||||
"create table dumb (`hell``o;` int primary key)",
|
||||
"select \nbaz from foo",
|
||||
"select\na from b",
|
||||
"select 1",
|
||||
},
|
||||
lineNums: []int{
|
||||
1, 1, 4, 5,
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `insert into foo values ('a', "b;", 'c;;""
|
||||
'); update foo set baz = bar,
|
||||
qux = '"hello"""' where xyzzy = ";;';'";
|
||||
|
||||
|
||||
create table foo (a int not null default ';',
|
||||
primary key (a));`,
|
||||
statements: []string{
|
||||
`insert into foo values ('a', "b;", 'c;;""
|
||||
')`,
|
||||
`update foo set baz = bar,
|
||||
qux = '"hello"""' where xyzzy = ";;';'"`,
|
||||
`create table foo (a int not null default ';',
|
||||
primary key (a))`,
|
||||
},
|
||||
lineNums: []int{
|
||||
1, 2, 6,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testcases {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
reader := strings.NewReader(tt.input)
|
||||
scanner := NewSqlStatementScanner(reader)
|
||||
var i int
|
||||
for scanner.Scan() {
|
||||
require.True(t, i < len(tt.statements))
|
||||
assert.Equal(t, tt.statements[i], strings.TrimSpace(scanner.Text()))
|
||||
if tt.lineNums != nil {
|
||||
assert.Equal(t, tt.lineNums[i], scanner.statementStartLine)
|
||||
} else {
|
||||
assert.Equal(t, 1, scanner.statementStartLine)
|
||||
}
|
||||
i++
|
||||
}
|
||||
|
||||
require.NoError(t, scanner.Err())
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user