Merge pull request #613 from liquidata-inc/zachmu/sql-batch-errors

Real statement tokenization for SQL statements in batch processing mode, instead of line-based scans with hacks for embedded semicolons. Print out the line number of failed queries during batch processing failure.
This commit is contained in:
Zach Musgrave
2020-04-23 10:23:53 -07:00
committed by GitHub
5 changed files with 389 additions and 70 deletions

View File

@@ -52,7 +52,27 @@ insert into test values poop;
SQL
[ "$status" -ne 0 ]
[[ "$output" =~ "Error processing batch" ]] || false
skip "No line number and query on error"
[[ "$output" =~ " 3 " ]] || false
[[ "$output" =~ "insert into test values poop;" ]] || false
[[ "$output" =~ "error on line 3 for query" ]] || false
[[ "$output" =~ "insert into test values poop" ]] || false
run dolt sql <<SQL
insert into test values (0,0,0,0,0,0);
insert into test values (1,0,
0,0,0,0);
insert into
test values (2,0,0,0,0,0)
;
insert into
test values
poop;
insert into test values (3,0,0,0,0,0);
SQL
[ "$status" -ne 0 ]
[[ "$output" =~ "Error processing batch" ]] || false
[[ "$output" =~ "error on line 10 for query" ]] || false
[[ "$output" =~ "poop" ]] || false
}

View File

@@ -383,12 +383,12 @@ CREATE TABLE test (
PRIMARY KEY (pk));
SQL
[ $status -ne 0 ]
[[ "${lines[0]}" =~ "Cannot create column pk, the tag 1234 was already used in table aaa" ]] || false
[[ "${lines[1]}" =~ "Cannot create column c1, the tag 5678 was already used in table bbb" ]] || false
[[ "$output" =~ "Cannot create column pk, the tag 1234 was already used in table aaa" ]] || false
[[ "$output" =~ "Cannot create column c1, the tag 5678 was already used in table bbb" ]] || false
run dolt sql -q "ALTER TABLE aaa ADD COLUMN c1 INT COMMENT 'tag:5678';"
[ $status -ne 0 ]
[[ "${lines[0]}" =~ "Cannot create column c1, the tag 5678 was already used in table bbb" ]] || false
[[ "$output" =~ "Cannot create column c1, the tag 5678 was already used in table bbb" ]] || false
}
@test "Deterministic tag generation produces consistent results" {

View File

@@ -15,8 +15,6 @@
package commands
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
@@ -371,7 +369,7 @@ func execQuery(sqlCtx *sql.Context, mrEnv env.MultiRepoEnv, roots map[string]*do
sqlSch, rowIter, err := processQuery(sqlCtx, query, se)
if err != nil {
verr := formatQueryError(query, err)
verr := formatQueryError("", err)
return nil, verr
}
@@ -404,7 +402,7 @@ func CollectDBs(mrEnv env.MultiRepoEnv, createDB createDBFunc) []dsqle.Database
return dbs
}
func formatQueryError(query string, err error) errhand.VerboseError {
func formatQueryError(message string, err error) errhand.VerboseError {
const (
maxStatementLen = 128
maxPosWhenTruncated = 64
@@ -456,6 +454,9 @@ func formatQueryError(query string, err error) errhand.VerboseError {
return verrBuilder.Build()
} else {
if len(message) > 0 {
err = fmt.Errorf("%s: %s", message, err.Error())
}
return errhand.VerboseErrorFromError(err)
}
}
@@ -553,31 +554,9 @@ func saveQuery(ctx context.Context, root *doltdb.RootValue, dEnv *env.DoltEnv, q
return newRoot, nil
}
// scanStatements is a bufio.SplitFunc that yields one token per ';'-terminated chunk of the input. The delimiter
// itself is consumed but not included in the token. It has no notion of quoting, so a ';' inside a string literal
// also ends the token; handling that correctly requires a state machine.
func scanStatements(data []byte, atEOF bool) (advance int, token []byte, err error) {
	semi := bytes.IndexByte(data, ';')

	switch {
	case semi >= 0:
		// Everything before the first ';' is a complete statement.
		return semi + 1, data[:semi], nil
	case atEOF && len(data) > 0:
		// No delimiter left, but the input is exhausted: hand back the trailing, unterminated statement.
		return len(data), data, nil
	default:
		// Either nothing is left at EOF, or more input is needed before a statement can be completed.
		return 0, nil, nil
	}
}
// runBatchMode processes queries until EOF. The Root of the sqlEngine may be updated.
func runBatchMode(ctx *sql.Context, se *sqlEngine, input io.Reader) error {
scanner := bufio.NewScanner(input)
const maxCapacity = 512 * 1024
buf := make([]byte, maxCapacity)
scanner.Buffer(buf, maxCapacity)
scanner.Split(scanStatements)
scanner := NewSqlStatementScanner(input)
var query string
for scanner.Scan() {
@@ -585,13 +564,10 @@ func runBatchMode(ctx *sql.Context, se *sqlEngine, input io.Reader) error {
if len(query) == 0 || query == "\n" {
continue
}
if !batchInsertEarlySemicolon(query) {
query += ";"
// TODO: We should fix this problem by properly implementing a state machine for scanStatements
continue
}
if err := processBatchQuery(ctx, query, se); err != nil {
verr := formatQueryError(query, err)
// TODO: this line number will not be accurate for errors that occur when flushing a batch of inserts (as opposed
// to processing the query)
verr := formatQueryError(fmt.Sprintf("error on line %d for query %s", scanner.statementStartLine, query), err)
cli.PrintErrln(verr.Verbose())
return err
}
@@ -607,36 +583,6 @@ func runBatchMode(ctx *sql.Context, se *sqlEngine, input io.Reader) error {
return flushBatchedEdits(ctx, se)
}
// batchInsertEarlySemicolon reports whether query is free of unterminated string literals. It returns false when
// the text ends while still inside a single- or double-quoted string, which indicates that the scanner split the
// statement on a ';' that was actually part of a string. Inside a string a backslash escapes the character that
// follows it, so an escaped quote does not close the string.
func batchInsertEarlySemicolon(query string) bool {
	n := len(query)
	pos := 0

	for pos < n {
		opener := query[pos]
		pos++
		if opener != '\'' && opener != '"' {
			continue
		}

		// Walk forward looking for the matching, unescaped closing quote.
		closed, escaped := false, false
		for ; pos < n; pos++ {
			ch := query[pos]
			switch {
			case escaped:
				escaped = false
			case ch == opener:
				closed = true
			case ch == '\\':
				escaped = true
			}
			if closed {
				break
			}
		}

		if !closed {
			// Ran off the end of the query while still inside the string.
			return false
		}
		pos++ // step past the closing quote
	}

	return true
}
// runShell starts a SQL shell. Returns when the user exits the shell. The Root of the sqlEngine may
// be updated by any queries which were processed.
func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error {
@@ -695,7 +641,7 @@ func runShell(ctx *sql.Context, se *sqlEngine, mrEnv env.MultiRepoEnv) error {
}
if sqlSch, rowIter, err := processQuery(ctx, query, se); err != nil {
verr := formatQueryError(query, err)
verr := formatQueryError("", err)
shell.Println(verr.Verbose())
} else if rowIter != nil {
defer rowIter.Close()

View File

@@ -0,0 +1,151 @@
// Copyright 2020 Liquidata, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package commands
import (
"bufio"
"io"
"unicode"
)
// statementScanner wraps a bufio.Scanner whose tokens are individual SQL statements, delimited by ';'. A ';' that
// appears inside a quoted string or identifier does not end a statement. The scanner also tracks line numbers so
// callers can report where the most recently returned statement began.
type statementScanner struct {
	*bufio.Scanner
	statementStartLine int // the line number of the first line of the last parsed statement
	startLineNum       int // the line number we began parsing the most recent token at
	lineNum            int // the current line number being parsed
}

// NewSqlStatementScanner returns a statementScanner reading from input. Line numbering starts at 1, and the
// underlying bufio.Scanner is given a 512KiB buffer, which is therefore the largest statement it can return.
func NewSqlStatementScanner(input io.Reader) *statementScanner {
	scanner := bufio.NewScanner(input)
	const maxCapacity = 512 * 1024
	buf := make([]byte, maxCapacity)
	scanner.Buffer(buf, maxCapacity)

	s := &statementScanner{
		Scanner: scanner,
		lineNum: 1,
	}
	scanner.Split(s.scanStatements)

	return s
}

// Characters with special meaning to the statement scanner.
const (
	sQuote    byte = '\''
	dQuote         = '"'
	backslash      = '\\'
	backtick       = '`'
)

// scanStatements is a split function for a Scanner that returns each SQL statement in the input as a token. The
// terminating ';' is consumed but is not part of the token. Inside a quoted region (opened by ', " or `) a ';' is
// ordinary text. A quote character is treated as escaped, and so neither opens nor closes a quoted region, when it
// is preceded by an odd number of backslashes, and a doubled quote character inside a region quoted with that same
// character stands for one literal quote.
//
// As a side effect it maintains s.lineNum and s.statementStartLine. When it has to ask for more data it rewinds
// the line counter, because the Scanner will call it again with the same bytes plus whatever was read since.
func (s *statementScanner) scanStatements(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if atEOF && len(data) == 0 {
		return 0, nil, nil
	}

	var (
		quoteChar                 byte // the opening quote character of the current quote being parsed, or 0 if the current parse location isn't inside a quoted string
		ignoreNextChar            bool // whether to ignore the next character
		numConsecutiveBackslashes int  // the number of backslashes immediately preceding the current character
		seenNonWhitespaceChar     bool // whether we have encountered a non-whitespace character since we returned the last token
	)

	s.startLineNum = s.lineNum

	for i := range data {
		if ignoreNextChar {
			// second half of a doubled quote character, already accounted for
			ignoreNextChar = false
			continue
		}

		// this doesn't handle unicode characters correctly and will break on some things, but it's only used for line
		// number reporting.
		if !seenNonWhitespaceChar && !unicode.IsSpace(rune(data[i])) {
			seenNonWhitespaceChar = true
			s.statementStartLine = s.lineNum
		}

		if data[i] == backslash {
			numConsecutiveBackslashes++
			continue
		}

		// Any character other than a backslash ends the current run of backslashes. This must happen for newlines
		// and semicolons as well as ordinary characters: otherwise a run interrupted by one of them keeps counting,
		// and a later escaped quote (as in ";\;\"") is seen with an even count and mistaken for a closing quote.
		precedingBackslashes := numConsecutiveBackslashes
		numConsecutiveBackslashes = 0

		switch data[i] {
		case '\n':
			s.lineNum++
		case ';':
			if quoteChar == 0 {
				// the next token starts on whatever line we have reached
				s.startLineNum = s.lineNum
				return i + 1, data[0:i], nil
			}
		case sQuote, dQuote, backtick:
			// escaped quote character
			if precedingBackslashes%2 == 1 {
				break
			}

			// open quote
			if quoteChar == 0 {
				quoteChar = data[i]
				break
			}

			// embedded quote of a different kind ('"' or "'")
			if quoteChar != data[i] {
				break
			}

			// end quote or two consecutive quote characters (a form of escaping quote chars)
			switch {
			case i+1 < len(data) && data[i+1] == quoteChar:
				// escaped quote. skip the next character
				ignoreNextChar = true
			case atEOF || i+1 < len(data):
				// end quote
				quoteChar = 0
			default:
				// the quote is the last byte we have, so we can't yet tell an end quote from the first half of
				// a doubled quote: need more data to make a decision
				return s.resetState()
			}
		}
	}

	// If we're at EOF, we have a final, non-terminated line. Return it.
	if atEOF {
		return len(data), data, nil
	}

	// Request more data.
	return s.resetState()
}

// resetState resets the internal state of the scanner and returns the "more data" response for a split function
func (s *statementScanner) resetState() (advance int, token []byte, err error) {
	// rewind the line number to where we started parsing this token
	s.lineNum = s.startLineNum
	return 0, nil, nil
}

View File

@@ -0,0 +1,202 @@
// Copyright 2020 Liquidata, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package commands
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestScanStatements feeds a series of inputs through a statement scanner and checks the statements it produces,
// after trimming surrounding whitespace, as well as the line each statement is reported to start on. Cases that
// leave lineNums nil expect every statement to start on line 1.
func TestScanStatements(t *testing.T) {
	type testcase struct {
		input      string
		statements []string
		lineNums   []int
	}

	// Some of these include malformed input (e.g. strings that aren't properly terminated)
	testcases := []testcase{
		{
			input: `insert into foo values (";;';'");`,
			statements: []string{
				`insert into foo values (";;';'")`,
			},
		},
		{
			input: `select ''';;'; select ";\;"`,
			statements: []string{
				`select ''';;'`,
				`select ";\;"`,
			},
		},
		{
			input: `select ''';;'; select ";\;`,
			statements: []string{
				`select ''';;'`,
				`select ";\;`,
			},
		},
		{
			input: `select ''';;'; select ";\;
;`,
			statements: []string{
				`select ''';;'`,
				`select ";\;
;`,
			},
		},
		{
			input: `select '\\'''; select '";";'; select 1`,
			statements: []string{
				`select '\\'''`,
				`select '";";'`,
				`select 1`,
			},
		},
		{
			input: `select '\\''; select '";";'; select 1`,
			statements: []string{
				`select '\\''; select '";"`,
				`'; select 1`,
			},
		},
		{
			input: `insert into foo values(''); select 1`,
			statements: []string{
				`insert into foo values('')`,
				`select 1`,
			},
		},
		{
			input: `insert into foo values('''); select 1`,
			statements: []string{
				`insert into foo values('''); select 1`,
			},
		},
		{
			input: `insert into foo values(''''); select 1`,
			statements: []string{
				`insert into foo values('''')`,
				`select 1`,
			},
		},
		{
			input: `insert into foo values(""); select 1`,
			statements: []string{
				`insert into foo values("")`,
				`select 1`,
			},
		},
		{
			input: `insert into foo values("""); select 1`,
			statements: []string{
				`insert into foo values("""); select 1`,
			},
		},
		{
			input: `insert into foo values(""""); select 1`,
			statements: []string{
				`insert into foo values("""")`,
				`select 1`,
			},
		},
		{
			input: `select '\''; select "hell\"o"`,
			statements: []string{
				`select '\''`,
				`select "hell\"o"`,
			},
		},
		{
			input: `select * from foo; select baz from foo;
select
a from b; select 1`,
			statements: []string{
				"select * from foo",
				"select baz from foo",
				"select\na from b",
				"select 1",
			},
			lineNums: []int{
				1, 1, 2, 3,
			},
		},
		{
			input: "create table dumb (`hell\\`o;` int primary key);",
			statements: []string{
				"create table dumb (`hell\\`o;` int primary key)",
			},
		},
		{
			input: "create table dumb (`hell``o;` int primary key); select \n" +
				"baz from foo;\n" +
				"\n" +
				"select\n" +
				"a from b; select 1\n\n",
			statements: []string{
				"create table dumb (`hell``o;` int primary key)",
				"select \nbaz from foo",
				"select\na from b",
				"select 1",
			},
			lineNums: []int{
				1, 1, 4, 5,
			},
		},
		{
			// The two blank lines before "create table" are significant: they put that statement on line 6.
			input: `insert into foo values ('a', "b;", 'c;;""
'); update foo set baz = bar,
qux = '"hello"""' where xyzzy = ";;';'";


create table foo (a int not null default ';',
primary key (a));`,
			statements: []string{
				`insert into foo values ('a', "b;", 'c;;""
')`,
				`update foo set baz = bar,
qux = '"hello"""' where xyzzy = ";;';'"`,
				`create table foo (a int not null default ';',
primary key (a))`,
			},
			lineNums: []int{
				1, 2, 6,
			},
		},
	}

	for _, tt := range testcases {
		t.Run(tt.input, func(t *testing.T) {
			reader := strings.NewReader(tt.input)
			scanner := NewSqlStatementScanner(reader)
			var i int
			for scanner.Scan() {
				// guard the index before using it so an extra statement fails cleanly instead of panicking
				require.True(t, i < len(tt.statements))
				assert.Equal(t, tt.statements[i], strings.TrimSpace(scanner.Text()))
				if tt.lineNums != nil {
					assert.Equal(t, tt.lineNums[i], scanner.statementStartLine)
				} else {
					assert.Equal(t, 1, scanner.statementStartLine)
				}
				i++
			}

			require.NoError(t, scanner.Err())
			// without this, a scanner that stops early and drops trailing statements would still pass
			assert.Equal(t, len(tt.statements), i, "scanner returned an unexpected number of statements")
		})
	}
}