Merge pull request #11 from liquidata-inc/zachmu/sql-fast-inserts

Batch inserts for SQL import. This results in a 60x speed increase and much less disk usage.

Also in this change:

Changed how memory / CPU profiling is triggered -- now a command line flag rather than a source change
Better integration test for joins (stock market data)
Better reporting during SQL import
Zach Musgrave
2019-08-05 11:11:54 -07:00
committed by GitHub
10 changed files with 6058 additions and 153 deletions
-18
@@ -15,27 +15,9 @@ teardown() {
rm -rf "$BATS_TMPDIR/dolt-repo-$$"
}
@test "start a sql shell and exit using exit" {
skiponwindows "Works on Windows command prompt but not the WSL terminal used during bats"
run bash -c "echo exit | dolt sql"
[ $status -eq 0 ]
[[ "$output" =~ "# Welcome to the DoltSQL shell." ]] || false
[[ "$output" =~ "Bye" ]] || false
}
@test "start a sql shell and exit using quit" {
skiponwindows "Works on Windows command prompt but not the WSL terminal used during bats"
run bash -c "echo quit | dolt sql"
[ $status -eq 0 ]
[[ "$output" =~ "# Welcome to the DoltSQL shell." ]] || false
[[ "$output" =~ "Bye" ]] || false
}
@test "run a query in sql shell" {
skiponwindows "Works on Windows command prompt but not the WSL terminal used during bats"
run bash -c "echo 'select * from test;' | dolt sql"
[ $status -eq 0 ]
[[ "$output" =~ "# Welcome to the DoltSQL shell." ]] || false
[[ "$output" =~ "pk" ]] || false
[[ "$output" =~ "Bye" ]] || false
}
+139 -25
@@ -15,8 +15,9 @@
package commands
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
@@ -115,8 +116,13 @@ func Sql(commandStr string, args []string, dEnv *env.DoltEnv) int {
}
}
// start an interactive shell
root = runShell(dEnv, root)
// Run in batch mode for piped input, or shell mode for interactive use
fi, _ := os.Stdin.Stat()
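// If stdin isn't a character device, input is being piped or redirected rather than typed interactively.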
if (fi.Mode() & os.ModeCharDevice) == 0 {
root = runBatchMode(dEnv, root)
} else {
root = runShell(dEnv, root)
}
// If the SQL session wrote a new root value, update the working set with it
if root != nil {
@@ -126,6 +132,51 @@ func Sql(commandStr string, args []string, dEnv *env.DoltEnv) int {
return 0
}
// scanStatements is a split function for a Scanner that returns each SQL statement in the input as a token. It doesn't
// handle semicolons embedded in string literals; supporting that requires implementing a state machine.
func scanStatements(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.IndexByte(data, ';'); i >= 0 {
// We have a full ;-terminated line.
return i + 1, data[0:i], nil
}
// If we're at EOF, we have a final, non-terminated line. Return it.
if atEOF {
return len(data), data, nil
}
// Request more data.
return 0, nil, nil
}
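A quote-aware version would need the state machine the comment mentions. A minimal sketch, not part of this change, that ignores backslash escapes and backticks: track whether the scanner is inside a quoted string, and only treat semicolons outside quotes as statement boundaries.
// scanStatementsQuoteAware is a hypothetical bufio.SplitFunc that skips
// semicolons inside single- or double-quoted strings.
func scanStatementsQuoteAware(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if atEOF && len(data) == 0 {
		return 0, nil, nil
	}
	var quote byte // 0 means we're outside any quoted string
	for i := 0; i < len(data); i++ {
		switch c := data[i]; {
		case quote == 0 && (c == '\'' || c == '"'):
			quote = c // entering a quoted string
		case quote != 0 && c == quote:
			quote = 0 // leaving the quoted string (escape sequences not handled)
		case quote == 0 && c == ';':
			// Found a statement boundary outside any string literal.
			return i + 1, data[:i], nil
		}
	}
	if atEOF {
		// Final, unterminated statement.
		return len(data), data, nil
	}
	// Request more data; the grown window is re-scanned from the start on the next call.
	return 0, nil, nil
}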
// runBatchMode processes queries until EOF and returns the resulting root value
func runBatchMode(dEnv *env.DoltEnv, root *doltdb.RootValue) *doltdb.RootValue {
scanner := bufio.NewScanner(os.Stdin)
scanner.Split(scanStatements)
batcher := dsql.NewSqlBatcher(dEnv.DoltDB, root)
for scanner.Scan() {
query := scanner.Text()
if newRoot, err := processBatchQuery(query, dEnv, root, batcher); newRoot != nil {
root = newRoot
} else if err != nil {
_, _ = fmt.Fprintf(cli.CliErr, "Error processing query '%s': %s\n", query, err.Error())
}
}
if err := scanner.Err(); err != nil {
cli.Println(err.Error())
}
if newRoot, _ := batcher.Commit(context.Background()); newRoot != nil {
root = newRoot
}
return root
}
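With this change, piping input into dolt sql triggers batch mode automatically, so a SQL dump can be imported with, e.g. (file name illustrative):

dolt sql < import.sql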
// runShell starts a SQL shell and returns the root value resulting from any queries once the user exits.
func runShell(dEnv *env.DoltEnv, root *doltdb.RootValue) *doltdb.RootValue {
_ = iohelp.WriteLine(cli.CliOut, welcomeMsg)
@@ -157,6 +208,7 @@ func runShell(dEnv *env.DoltEnv, root *doltdb.RootValue) *doltdb.RootValue {
shell.EOF(func(c *ishell.Context) {
c.Stop()
})
shell.Interrupt(func(c *ishell.Context, count int, input string) {
if count > 1 {
c.Stop()
@@ -181,7 +233,6 @@ func runShell(dEnv *env.DoltEnv, root *doltdb.RootValue) *doltdb.RootValue {
// Longer term we need to switch to a new readline library, like in this bug:
// https://github.com/cockroachdb/cockroach/issues/15460
// For now, we store all history entries as single-line strings to avoid the issue.
// TODO: only store history if it's a tty
singleLine := strings.ReplaceAll(query, "\n", " ")
if err := shell.AddHistory(singleLine); err != nil {
// TODO: handle better, like by turning off history writing for the rest of the session
@@ -290,7 +341,7 @@ func prepend(s string, ss []string) []string {
func processQuery(query string, dEnv *env.DoltEnv, root *doltdb.RootValue) (*doltdb.RootValue, error) {
sqlStatement, err := sqlparser.Parse(query)
if err != nil {
return nil, errFmt("Error parsing SQL: %v.", err.Error())
return nil, fmt.Errorf("Error parsing SQL: %v.", err.Error())
}
switch s := sqlStatement.(type) {
@@ -303,7 +354,7 @@ func processQuery(query string, dEnv *env.DoltEnv, root *doltdb.RootValue) (*dol
}
return nil, err
case *sqlparser.Insert:
return sqlInsert(dEnv, root, s, query)
return sqlInsert(dEnv, root, s)
case *sqlparser.Update:
return sqlUpdate(dEnv, root, s, query)
case *sqlparser.Delete:
@@ -311,11 +362,43 @@ func processQuery(query string, dEnv *env.DoltEnv, root *doltdb.RootValue) (*dol
case *sqlparser.DDL:
_, err := sqlparser.ParseStrictDDL(query)
if err != nil {
return nil, errFmt("Error parsing DDL: %v.", err.Error())
return nil, fmt.Errorf("Error parsing DDL: %v.", err.Error())
}
return sqlDDL(dEnv, root, s, query)
default:
return nil, errFmt("Unsupported SQL statement: '%v'.", query)
return nil, fmt.Errorf("Unsupported SQL statement: '%v'.", query)
}
}
// Processes a single query in batch mode and returns the result. The RootValue may or may not be changed.
func processBatchQuery(query string, dEnv *env.DoltEnv, root *doltdb.RootValue, batcher *dsql.SqlBatcher) (*doltdb.RootValue, error) {
sqlStatement, err := sqlparser.Parse(query)
if err != nil {
return nil, fmt.Errorf("Error parsing SQL: %v.", err.Error())
}
switch s := sqlStatement.(type) {
case *sqlparser.Insert:
return sqlInsertBatch(dEnv, root, s, batcher)
default:
// For any other kind of statement, we need to commit whatever batch edit we've accumulated so far before executing
// the query
newRoot, err := batcher.Commit(context.Background())
if err != nil {
return nil, err
}
newRoot, err = processQuery(query, dEnv, newRoot)
if err != nil {
return nil, err
}
if newRoot != nil {
root = newRoot
if err := batcher.UpdateRoot(root); err != nil {
return nil, err
}
}
return root, nil
}
}
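For example, given piped input like the following (table names illustrative), the first two INSERTs accumulate in one batch, the CREATE TABLE forces that batch to be committed first, and the final INSERT starts a new batch:

insert into t values (1);
insert into t values (2);
create table u (pk int primary key);
insert into u values (1);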
@@ -396,10 +479,10 @@ func prettyPrintResults(nbf *types.NomsBinFormat, sqlSch sql.Schema, rowIter sql
p.Start()
if err := p.Wait(); err != nil {
return errFmt("error processing results: %v", err)
return fmt.Errorf("error processing results: %v", err)
}
if chanErr != io.EOF {
return errFmt("error processing results: %v", chanErr)
return fmt.Errorf("error processing results: %v", chanErr)
}
return nil
@@ -433,17 +516,17 @@ func runPrintingPipeline(nbf *types.NomsBinFormat, p *pipeline.Pipeline, untyped
p.Start()
if err := p.Wait(); err != nil {
return errFmt("error processing results: %v", err)
return fmt.Errorf("error processing results: %v", err)
}
return nil
}
// Executes a SQL insert statement and prints the result to the CLI. Returns the new root value to be written as appropriate.
func sqlInsert(dEnv *env.DoltEnv, root *doltdb.RootValue, stmt *sqlparser.Insert, query string) (*doltdb.RootValue, error) {
result, err := dsql.ExecuteInsert(context.Background(), dEnv.DoltDB, root, stmt, query)
func sqlInsert(dEnv *env.DoltEnv, root *doltdb.RootValue, stmt *sqlparser.Insert) (*doltdb.RootValue, error) {
result, err := dsql.ExecuteInsert(context.Background(), dEnv.DoltDB, root, stmt)
if err != nil {
return nil, errFmt("Error inserting rows: %v", err.Error())
return nil, fmt.Errorf("Error inserting rows: %v", err.Error())
}
cli.Println(fmt.Sprintf("Rows inserted: %v", result.NumRowsInserted))
@@ -457,11 +540,46 @@ func sqlInsert(dEnv *env.DoltEnv, root *doltdb.RootValue, stmt *sqlparser.Insert
return result.Root, nil
}
type stats struct {
numRowsInserted int
numRowsUpdated int
numErrorsIgnored int
}
var batchEditStats stats
var displayStrLen int
// Executes a SQL insert statement in batch mode and returns the new root value (which is usually unchanged) or an
// error. No output is written to the console in batch mode.
func sqlInsertBatch(dEnv *env.DoltEnv, root *doltdb.RootValue, stmt *sqlparser.Insert, batcher *dsql.SqlBatcher) (*doltdb.RootValue, error) {
result, err := dsql.ExecuteBatchInsert(context.Background(), root, stmt, batcher)
if err != nil {
return nil, fmt.Errorf("Error inserting rows: %v", err.Error())
}
mergeResultIntoStats(result, &batchEditStats)
displayStr := fmt.Sprintf("Rows inserted: %d, Updated: %d, Errors: %d",
batchEditStats.numRowsInserted, batchEditStats.numRowsUpdated, batchEditStats.numErrorsIgnored)
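// DeleteAndPrint erases the previously printed status string and prints the new one in place,
// returning the printed length for the next call (inferred from this usage).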
displayStrLen = cli.DeleteAndPrint(displayStrLen, displayStr)
if result.Root != nil {
root = result.Root
}
return root, nil
}
func mergeResultIntoStats(result *dsql.InsertResult, stats *stats) {
stats.numRowsInserted += result.NumRowsInserted
stats.numRowsUpdated += result.NumRowsUpdated
stats.numErrorsIgnored += result.NumErrorsIgnored
}
// Executes a SQL update statement and prints the result to the CLI. Returns the new root value to be written as appropriate.
func sqlUpdate(dEnv *env.DoltEnv, root *doltdb.RootValue, update *sqlparser.Update, query string) (*doltdb.RootValue, error) {
result, err := dsql.ExecuteUpdate(context.Background(), dEnv.DoltDB, root, update, query)
if err != nil {
return nil, errFmt("Error during update: %v", err.Error())
return nil, fmt.Errorf("Error during update: %v", err.Error())
}
cli.Println(fmt.Sprintf("Rows updated: %v", result.NumRowsUpdated))
@@ -476,7 +594,7 @@ func sqlUpdate(dEnv *env.DoltEnv, root *doltdb.RootValue, update *sqlparser.Upda
func sqlDelete(dEnv *env.DoltEnv, root *doltdb.RootValue, update *sqlparser.Delete, query string) (*doltdb.RootValue, error) {
result, err := dsql.ExecuteDelete(context.Background(), dEnv.DoltDB, root, update, query)
if err != nil {
return nil, errFmt("Error during update: %v", err.Error())
return nil, fmt.Errorf("Error during update: %v", err.Error())
}
cli.Println(fmt.Sprintf("Rows deleted: %v", result.NumRowsDeleted))
@@ -490,28 +608,24 @@ func sqlDDL(dEnv *env.DoltEnv, root *doltdb.RootValue, ddl *sqlparser.DDL, query
case sqlparser.CreateStr:
newRoot, _, err := dsql.ExecuteCreate(context.Background(), dEnv.DoltDB, root, ddl, query)
if err != nil {
return nil, errFmt("Error creating table: %v", err)
return nil, fmt.Errorf("Error creating table: %v", err)
}
return newRoot, nil
case sqlparser.AlterStr, sqlparser.RenameStr:
newRoot, err := dsql.ExecuteAlter(context.Background(), dEnv.DoltDB, root, ddl, query)
if err != nil {
return nil, errFmt("Error altering table: %v", err)
return nil, fmt.Errorf("Error altering table: %v", err)
}
return newRoot, nil
case sqlparser.DropStr:
newRoot, err := dsql.ExecuteDrop(context.Background(), dEnv.DoltDB, root, ddl, query)
if err != nil {
return nil, errFmt("Error dropping table: %v", err)
return nil, fmt.Errorf("Error dropping table: %v", err)
}
return newRoot, nil
case sqlparser.TruncateStr:
return nil, errFmt("Unhandled DDL action %v in query %v", ddl.Action, query)
return nil, fmt.Errorf("Unhandled DDL action %v in query %v", ddl.Action, query)
default:
return nil, errFmt("Unhandled DDL action %v in query %v", ddl.Action, query)
return nil, fmt.Errorf("Unhandled DDL action %v in query %v", ddl.Action, query)
}
}
func errFmt(fmtMsg string, args ...interface{}) error {
return errors.New(fmt.Sprintf(fmtMsg, args...))
}
+30 -12
@@ -66,25 +66,43 @@ var doltCommand = cli.GenSubCommandHandler([]*cli.Command{
{Name: "conflicts", Desc: "Commands for viewing and resolving merge conflicts.", Func: cnfcmds.Commands, ReqRepo: false},
})
var cpuProf = false
var memProf = false
const profFlag = "--prof"
const cpuProf = "cpu"
const memProf = "mem"
const blockingProf = "blocking"
const traceProf = "trace"
func main() {
os.Exit(runMain())
}
func runMain() int {
if cpuProf {
fmt.Println("cpu profiling enabled.")
defer profile.Start(profile.CPUProfile).Stop()
}
if memProf {
fmt.Println("mem profiling enabled.")
defer profile.Start(profile.MemProfile).Stop()
}
args := os.Args[1:]
if len(args) > 0 && args[0] == profFlag {
if len(args) < 2 {
panic("Expected a profile arg after " + profFlag)
}
prof := args[1]
switch prof {
case cpuProf:
fmt.Println("cpu profiling enabled.")
defer profile.Start(profile.CPUProfile).Stop()
case memProf:
fmt.Println("mem profiling enabled.")
defer profile.Start(profile.MemProfile).Stop()
case blockingProf:
fmt.Println("block profiling enabled")
defer profile.Start(profile.BlockProfile).Stop()
case traceProf:
fmt.Println("trace profiling enabled")
defer profile.Start(profile.TraceProfile).Stop()
default:
panic("Unexpected prof flag: " + prof)
}
args = args[2:]
}
// Currently GoLand doesn't support running with a different working directory when using Go modules.
// This is a hack that allows a different working directory to be set after the application starts using
// chdir=<DIR>. The syntax is not flexible and must match this form exactly.
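Based on the flag parsing above, profiling is now enabled from the command line rather than by a source change; an invocation presumably looks like (command and file names illustrative):

dolt --prof cpu sql < import.sql

where the profile type is one of cpu, mem, blocking, or trace, and the remaining arguments are the normal dolt command.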
+218
@@ -0,0 +1,218 @@
// Copyright 2019 Liquidata, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sql
import (
"context"
"errors"
"fmt"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema"
"github.com/liquidata-inc/dolt/go/store/hash"
"github.com/liquidata-inc/dolt/go/store/types"
)
// SqlBatcher knows how to efficiently batch insert / update statements, e.g. when doing a SQL import. It does this by
// using a single MapEditor per table that isn't persisted until Commit is called.
type SqlBatcher struct {
// The database we are editing
db *doltdb.DoltDB
// The root value we are editing
root *doltdb.RootValue
// The set of tables under edit
tables map[string]*doltdb.Table
// The schemas of tables under edit
schemas map[string]schema.Schema
// The row data for tables being edited
rowData map[string]types.Map
// The editors applying updates to the tables
editors map[string]*types.MapEditor
// The hashes of primary keys being inserted to the tables
hashes map[string]map[hash.Hash]bool
}
// Returns a new SqlBatcher for the given database and root value.
func NewSqlBatcher(db *doltdb.DoltDB, root *doltdb.RootValue) *SqlBatcher {
batcher := &SqlBatcher{
db: db,
root: root,
}
batcher.resetState()
return batcher
}
// Updates this batcher with a new root value. If there are outstanding edits, returns an error.
func (b *SqlBatcher) UpdateRoot(root *doltdb.RootValue) error {
if b.isDirty() {
return errors.New("UpdateRoot called with outstanding edits")
}
b.root = root
// resetting the state shouldn't be necessary here because of the isDirty check, but if the client chooses to ignore
// the returned error, we'll at least have a clean state going forward
b.resetState()
return nil
}
// isDirty returns whether there are outstanding edits that haven't been committed.
func (b *SqlBatcher) isDirty() bool {
return len(b.editors) > 0
}
// resetState flushes the cache of outstanding edits and other data
func (b *SqlBatcher) resetState() {
b.tables = make(map[string]*doltdb.Table)
b.schemas = make(map[string]schema.Schema)
b.rowData = make(map[string]types.Map)
b.editors = make(map[string]*types.MapEditor)
b.hashes = make(map[string]map[hash.Hash]bool)
}
type InsertOptions struct {
// Whether to silently replace any existing rows with the same primary key as rows inserted
Replace bool
}
type BatchInsertResult struct {
RowInserted bool
RowUpdated bool
}
func (b *SqlBatcher) Insert(ctx context.Context, tableName string, r row.Row, opt InsertOptions) (*BatchInsertResult, error) {
sch, err := b.GetSchema(ctx, tableName)
if err != nil {
return nil, err
}
rowData, err := b.getRowData(ctx, tableName)
if err != nil {
return nil, err
}
ed, err := b.getEditor(ctx, tableName)
if err != nil {
return nil, err
}
key := r.NomsMapKey(sch).Value(ctx)
rowExists := rowData.Get(ctx, key) != nil
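// Edits pending in the MapEditor aren't visible through rowData.Get, so separately track the
// hashes of keys already written in this batch.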
hashes := b.getHashes(ctx, tableName)
rowAlreadyTouched := hashes[key.Hash(b.root.VRW().Format())]
if rowExists || rowAlreadyTouched {
if !opt.Replace {
return nil, fmt.Errorf("Duplicate primary key: '%v'", getPrimaryKeyString(r, sch))
}
}
ed.Set(key, r.NomsMapValue(sch))
hashes[key.Hash(b.root.VRW().Format())] = true
return &BatchInsertResult{RowInserted: !rowExists, RowUpdated: rowExists || rowAlreadyTouched}, nil
}
// GetTable returns the table with the name given. This method is offered because reading the table from the root value
// is relatively expensive, and SqlBatcher caches Tables to avoid the overhead.
func (b *SqlBatcher) GetTable(ctx context.Context, tableName string) (*doltdb.Table, error) {
if table, ok := b.tables[tableName]; ok {
return table, nil
}
if !b.root.HasTable(ctx, tableName) {
return nil, fmt.Errorf("Unknown table %v", tableName)
}
table, _ := b.root.GetTable(ctx, tableName)
b.tables[tableName] = table
return table, nil
}
// GetSchema returns the schema for the table name given. This method is offered because reading the schema from disk
// is actually relatively expensive -- SqlBatcher caches the schema values per table to avoid the overhead.
func (b *SqlBatcher) GetSchema(ctx context.Context, tableName string) (schema.Schema, error) {
if schema, ok := b.schemas[tableName]; ok {
return schema, nil
}
table, err := b.GetTable(ctx, tableName)
if err != nil {
return nil, err
}
sch := table.GetSchema(ctx)
b.schemas[tableName] = sch
return sch, nil
}
func (b *SqlBatcher) getEditor(ctx context.Context, tableName string) (*types.MapEditor, error) {
if ed, ok := b.editors[tableName]; ok {
return ed, nil
}
rowData, err := b.getRowData(ctx, tableName)
if err != nil {
return nil, err
}
ed := rowData.Edit()
b.editors[tableName] = ed
return ed, nil
}
func (b *SqlBatcher) getRowData(ctx context.Context, tableName string) (types.Map, error) {
if rowData, ok := b.rowData[tableName]; ok {
return rowData, nil
}
table, err := b.GetTable(ctx, tableName)
if err != nil {
return types.EmptyMap, err
}
rowData := table.GetRowData(ctx)
b.rowData[tableName] = rowData
return rowData, nil
}
func (b *SqlBatcher) getHashes(ctx context.Context, tableName string) map[hash.Hash]bool {
if hashes, ok := b.hashes[tableName]; ok {
return hashes
}
hashes := make(map[hash.Hash]bool)
b.hashes[tableName] = hashes
return hashes
}
// Commit writes a new value for every table under edit, sets them in a new root value, and returns that root. Tables
// are written in an arbitrary order.
func (b *SqlBatcher) Commit(ctx context.Context) (*doltdb.RootValue, error) {
root := b.root
for tableName, ed := range b.editors {
newMap := ed.Map(ctx)
table := b.tables[tableName]
table = table.UpdateRows(ctx, newMap)
root = root.PutTable(ctx, b.db, tableName, table)
}
b.root = root
b.resetState()
return root, nil
}
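A minimal usage sketch of the API above, assuming a context, root value, and rows are already in hand (rowsToInsert and the table name are placeholders):

batcher := NewSqlBatcher(dEnv.DoltDB, root)
for _, r := range rowsToInsert {
	// Insert buffers the edit in the table's MapEditor; nothing is persisted yet.
	if _, err := batcher.Insert(ctx, "people", r, InsertOptions{Replace: false}); err != nil {
		return nil, err // e.g. a duplicate primary key without Replace
	}
}
// Only Commit writes the accumulated edits and produces a new root value.
newRoot, err := batcher.Commit(ctx)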
@@ -0,0 +1,265 @@
// Copyright 2019 Liquidata, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sql
import (
"context"
"fmt"
"math/rand"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"vitess.io/vitess/go/vt/sqlparser"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/dtestutils"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
. "github.com/liquidata-inc/dolt/go/libraries/doltcore/sql/sqltestutil"
"github.com/liquidata-inc/dolt/go/store/types"
)
func TestSqlBatchInserts(t *testing.T) {
insertStatements := []string{
`insert into people (id, first, last, is_married, age, rating, uuid, num_episodes) values
(7, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000007', 677)`,
`insert into people values
(8, "Milhouse", "VanHouten", false, 1, 5.1, '00000000-0000-0000-0000-000000000008', 677)`,
`insert into people (id, first, last) values (9, "Clancey", "Wiggum")`,
`insert into people (id, first, last) values
(10, "Montgomery", "Burns"), (11, "Ned", "Flanders")`,
`insert into episodes (id, name) values (5, "Bart the General"), (6, "Moaning Lisa")`,
`insert into episodes (id, name) values (7, "The Call of the Simpsons"), (8, "The Telltale Head")`,
`insert into episodes (id, name) values (9, "Life on the Fast Lane")`,
`insert into appearances (character_id, episode_id) values (7,5), (7,6)`,
`insert into appearances (character_id, episode_id) values (8,7)`,
`insert into appearances (character_id, episode_id) values (9,8), (9,9)`,
`insert into appearances (character_id, episode_id) values (10,5), (10,6)`,
`insert into appearances (character_id, episode_id) values (11,9)`,
}
// Shuffle the inserts so that different tables are interleaved. We're not giving a seed here, so this is
// deterministic.
rand.Shuffle(len(insertStatements),
func(i, j int) {
insertStatements[i], insertStatements[j] = insertStatements[j], insertStatements[i]
})
dEnv := dtestutils.CreateTestEnv()
ctx := context.Background()
CreateTestDatabase(dEnv, t)
root, _ := dEnv.WorkingRoot(ctx)
batcher := NewSqlBatcher(dEnv.DoltDB, root)
for _, stmt := range insertStatements {
statement, err := sqlparser.Parse(stmt)
require.NoError(t, err)
insertStmt, ok := statement.(*sqlparser.Insert)
require.True(t, ok)
result, err := ExecuteBatchInsert(context.Background(), root, insertStmt, batcher)
require.NoError(t, err)
assert.True(t, result.NumRowsInserted > 0)
assert.Equal(t, 0, result.NumRowsUpdated)
assert.Equal(t, 0, result.NumErrorsIgnored)
}
// Before committing the batch, the database should be unchanged from its original state
allPeopleRows := GetAllRows(root, PeopleTableName)
allEpsRows := GetAllRows(root, EpisodesTableName)
allAppearanceRows := GetAllRows(root, AppearancesTableName)
assert.ElementsMatch(t, AllPeopleRows, allPeopleRows)
assert.ElementsMatch(t, AllEpsRows, allEpsRows)
assert.ElementsMatch(t, AllAppsRows, allAppearanceRows)
// Now commit the batch and check for new rows
root, err := batcher.Commit(ctx)
require.NoError(t, err)
var expectedPeople, expectedEpisodes, expectedAppearances []row.Row
expectedPeople = append(expectedPeople, AllPeopleRows...)
expectedPeople = append(expectedPeople,
NewPeopleRowWithOptionalFields(7, "Maggie", "Simpson", false, 1, 5.1, uuid.MustParse("00000000-0000-0000-0000-000000000007"), 677),
NewPeopleRowWithOptionalFields(8, "Milhouse", "VanHouten", false, 1, 5.1, uuid.MustParse("00000000-0000-0000-0000-000000000008"), 677),
newPeopleRow(9, "Clancey", "Wiggum"),
newPeopleRow(10, "Montgomery", "Burns"),
newPeopleRow(11, "Ned", "Flanders"),
)
expectedEpisodes = append(expectedEpisodes, AllEpsRows...)
expectedEpisodes = append(expectedEpisodes,
newEpsRow(5, "Bart the General"),
newEpsRow(6, "Moaning Lisa"),
newEpsRow(7, "The Call of the Simpsons"),
newEpsRow(8, "The Telltale Head"),
newEpsRow(9, "Life on the Fast Lane"),
)
expectedAppearances = append(expectedAppearances, AllAppsRows...)
expectedAppearances = append(expectedAppearances,
newAppsRow(7, 5),
newAppsRow(7, 6),
newAppsRow(8, 7),
newAppsRow(9, 8),
newAppsRow(9, 9),
newAppsRow(10, 5),
newAppsRow(10, 6),
newAppsRow(11, 9),
)
allPeopleRows = GetAllRows(root, PeopleTableName)
allEpsRows = GetAllRows(root, EpisodesTableName)
allAppearanceRows = GetAllRows(root, AppearancesTableName)
assertRowSetsEqual(t, expectedPeople, allPeopleRows)
assertRowSetsEqual(t, expectedEpisodes, allEpsRows)
assertRowSetsEqual(t, expectedAppearances, allAppearanceRows)
}
func TestSqlBatchInsertIgnoreReplace(t *testing.T) {
insertStatements := []string{
`replace into people (id, first, last, is_married, age, rating, uuid, num_episodes) values
(0, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000007', 677)`,
`insert ignore into people values
(2, "Milhouse", "VanHouten", false, 1, 5.1, '00000000-0000-0000-0000-000000000008', 677)`,
}
numRowsUpdated := []int{1, 0}
numErrorsIgnored := []int{0, 1}
dEnv := dtestutils.CreateTestEnv()
ctx := context.Background()
CreateTestDatabase(dEnv, t)
root, _ := dEnv.WorkingRoot(ctx)
batcher := NewSqlBatcher(dEnv.DoltDB, root)
for i := range insertStatements {
stmt := insertStatements[i]
statement, err := sqlparser.Parse(stmt)
require.NoError(t, err)
insertStmt, ok := statement.(*sqlparser.Insert)
require.True(t, ok)
result, err := ExecuteBatchInsert(context.Background(), root, insertStmt, batcher)
require.NoError(t, err)
assert.Equal(t, 0, result.NumRowsInserted)
assert.Equal(t, numRowsUpdated[i], result.NumRowsUpdated)
assert.Equal(t, numErrorsIgnored[i], result.NumErrorsIgnored)
}
// Before committing the batch, the database should be unchanged from its original state
allPeopleRows := GetAllRows(root, PeopleTableName)
assert.ElementsMatch(t, AllPeopleRows, allPeopleRows)
// Now commit the batch and check for new rows
root, err := batcher.Commit(ctx)
require.NoError(t, err)
var expectedPeople []row.Row
expectedPeople = append(expectedPeople, AllPeopleRows[1:]...) // skip homer
expectedPeople = append(expectedPeople,
NewPeopleRowWithOptionalFields(0, "Maggie", "Simpson", false, 1, 5.1, uuid.MustParse("00000000-0000-0000-0000-000000000007"), 677),
)
allPeopleRows = GetAllRows(root, PeopleTableName)
assertRowSetsEqual(t, expectedPeople, allPeopleRows)
}
func TestSqlBatchInsertErrors(t *testing.T) {
insertStatements := []string{
`insert into people (id, first, last, is_married, age, rating, uuid, num_episodes) values
(0, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000007', 677)`,
`insert into people values
(2, "Milhouse", "VanHouten", false, 1, 5.1, true, 677)`,
}
dEnv := dtestutils.CreateTestEnv()
ctx := context.Background()
CreateTestDatabase(dEnv, t)
root, _ := dEnv.WorkingRoot(ctx)
batcher := NewSqlBatcher(dEnv.DoltDB, root)
for i := range insertStatements {
stmt := insertStatements[i]
statement, err := sqlparser.Parse(stmt)
require.NoError(t, err)
insertStmt, ok := statement.(*sqlparser.Insert)
require.True(t, ok)
_, err = ExecuteBatchInsert(context.Background(), root, insertStmt, batcher)
require.Error(t, err)
}
}
func assertRowSetsEqual(t *testing.T, expected, actual []row.Row) {
equal, diff := rowSetsEqual(expected, actual)
assert.True(t, equal, diff)
}
// Returns whether the two slices of rows contain the same elements using set semantics (no duplicates), and an error
// string if they aren't.
func rowSetsEqual(expected, actual []row.Row) (bool, string) {
if len(expected) != len(actual) {
return false, fmt.Sprintf("Sets have different sizes: expected %d, was %d", len(expected), len(actual))
}
for _, ex := range expected {
if !containsRow(actual, ex) {
return false, fmt.Sprintf("Missing row: %v", ex)
}
}
return true, ""
}
func containsRow(rs []row.Row, r row.Row) bool {
for _, r2 := range rs {
equal, _ := rowsEqual(r, r2)
if equal {
return true
}
}
return false
}
func newPeopleRow(id int, first, last string) row.Row {
vals := row.TaggedValues{
IdTag: types.Int(id),
FirstTag: types.String(first),
LastTag: types.String(last),
}
return row.New(types.Format_7_18, PeopleTestSchema, vals)
}
func newEpsRow(id int, name string) row.Row {
vals := row.TaggedValues{
EpisodeIdTag: types.Int(id),
EpNameTag: types.String(name),
}
return row.New(types.Format_7_18, EpisodesTestSchema, vals)
}
func newAppsRow(charId, epId int) row.Row {
vals := row.TaggedValues{
AppCharacterTag: types.Int(charId),
AppEpTag: types.Int(epId),
}
return row.New(types.Format_7_18, AppearancesTestSchema, vals)
}
+59 -35
@@ -25,7 +25,6 @@ import (
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema"
"github.com/liquidata-inc/dolt/go/store/hash"
"github.com/liquidata-inc/dolt/go/store/types"
)
@@ -39,14 +38,21 @@ type InsertResult struct {
var ErrMissingPrimaryKeys = errors.New("One or more primary key columns missing from insert statement")
var ConstraintFailedFmt = "Constraint failed for column '%v': %v"
// ExecuteSelect executes the given select query and returns the resultant rows accompanied by their output schema.
func ExecuteInsert(ctx context.Context, db *doltdb.DoltDB, root *doltdb.RootValue, s *sqlparser.Insert, query string) (*InsertResult, error) {
// ExecuteBatchInsert executes the given insert statement in batch mode and returns the result. The table is not changed
// until the batch is Committed. The InsertResult returned similarly doesn't have a Root set, since the root isn't
// modified by this function.
func ExecuteBatchInsert(
ctx context.Context,
root *doltdb.RootValue,
s *sqlparser.Insert,
batcher *SqlBatcher,
) (*InsertResult, error) {
tableName := s.Table.Name.String()
if !root.HasTable(ctx, tableName) {
return errInsert("Unknown table %v", tableName)
tableSch, err := batcher.GetSchema(ctx, tableName)
if err != nil {
return nil, err
}
table, _ := root.GetTable(ctx, tableName)
tableSch := table.GetSchema(ctx)
// The parser supports overwrite on insert with both the replace keyword (from MySQL) and the ignore keyword
replace := s.Action == sqlparser.ReplaceStr
@@ -62,13 +68,13 @@ func ExecuteInsert(ctx context.Context, db *doltdb.DoltDB, root *doltdb.RootValu
for i, colName := range s.Columns {
for _, c := range cols {
if c.Name == colName.String() {
return errInsert("Repeated column: '%v'", c.Name)
return nil, fmt.Errorf("Repeated column: '%v'", c.Name)
}
}
col, ok := tableSch.GetAllCols().GetByName(colName.String())
if !ok {
return errInsert(UnknownColumnErrFmt, colName)
return nil, fmt.Errorf(UnknownColumnErrFmt, colName)
}
cols[i] = col
}
@@ -81,24 +87,21 @@ func ExecuteInsert(ctx context.Context, db *doltdb.DoltDB, root *doltdb.RootValu
var err error
rows, err = prepareInsertVals(root.VRW().Format(), cols, &queryRows, tableSch)
if err != nil {
return &InsertResult{}, err
return nil, err
}
case *sqlparser.Select:
return errInsert("Insert as select not supported")
return nil, fmt.Errorf("Insert as select not supported")
case *sqlparser.ParenSelect:
return errInsert("Parenthesized select expressions in insert not supported")
return nil, fmt.Errorf("Parenthesized select expressions in insert not supported")
case *sqlparser.Union:
return errInsert("Union not supported")
return nil, fmt.Errorf("Union not supported")
default:
return errInsert("Unrecognized type for insertRows: %v", queryRows)
return nil, fmt.Errorf("Unrecognized type for insert: %v", queryRows)
}
// Perform the insert
rowData := table.GetRowData(ctx)
me := rowData.Edit()
var result InsertResult
insertedPKHashes := make(map[hash.Hash]struct{})
opt := InsertOptions{replace}
for _, r := range rows {
if !row.IsValid(r, tableSch) {
if ignore {
@@ -106,37 +109,58 @@ func ExecuteInsert(ctx context.Context, db *doltdb.DoltDB, root *doltdb.RootValu
continue
} else {
col, constraint := row.GetInvalidConstraint(r, tableSch)
return nil, errFmt(ConstraintFailedFmt, col.Name, constraint)
return nil, fmt.Errorf(ConstraintFailedFmt, col.Name, constraint)
}
}
key := r.NomsMapKey(tableSch).Value(ctx)
rowExists := rowData.Get(ctx, key) != nil
_, rowInserted := insertedPKHashes[key.Hash(root.VRW().Format())]
if rowExists || rowInserted {
if replace {
result.NumRowsUpdated += 1
} else if ignore {
insertResult, err := batcher.Insert(ctx, tableName, r, opt)
if err != nil {
if ignore {
result.NumErrorsIgnored += 1
continue
} else {
return errInsert("Duplicate primary key: '%v'", getPrimaryKeyString(r, tableSch))
return nil, err
}
}
me.Set(key, r.NomsMapValue(tableSch))
insertedPKHashes[key.Hash(root.VRW().Format())] = struct{}{}
if insertResult.RowInserted {
result.NumRowsInserted++
}
if insertResult.RowUpdated {
result.NumRowsUpdated++
}
}
newMap := me.Map(ctx)
table = table.UpdateRows(ctx, newMap)
result.NumRowsInserted = int(newMap.Len() - rowData.Len())
result.Root = root.PutTable(ctx, db, tableName, table)
return &result, nil
}
// ExecuteInsert executes the given insert statement and returns the result.
func ExecuteInsert(
ctx context.Context,
db *doltdb.DoltDB,
root *doltdb.RootValue,
s *sqlparser.Insert,
) (*InsertResult, error) {
batcher := NewSqlBatcher(db, root)
insertResult, err := ExecuteBatchInsert(ctx, root, s, batcher)
if err != nil {
return nil, err
}
newRoot, err := batcher.Commit(ctx)
if err != nil {
return nil, err
}
return &InsertResult{
Root: newRoot,
NumRowsInserted: insertResult.NumRowsInserted,
NumRowsUpdated: insertResult.NumRowsUpdated,
NumErrorsIgnored: insertResult.NumErrorsIgnored,
}, nil
}
// Returns a primary key summary of the row given
func getPrimaryKeyString(r row.Row, tableSch schema.Schema) string {
var sb strings.Builder
+2 -2
@@ -315,7 +315,7 @@ func TestExecuteInsert(t *testing.T) {
sqlStatement, _ := sqlparser.Parse(tt.query)
s := sqlStatement.(*sqlparser.Insert)
result, err := ExecuteInsert(ctx, dEnv.DoltDB, root, s, tt.query)
result, err := ExecuteInsert(ctx, dEnv.DoltDB, root, s)
if len(tt.expectedErr) > 0 {
require.Error(t, err)
@@ -334,7 +334,7 @@ func TestExecuteInsert(t *testing.T) {
for _, expectedRow := range tt.insertedValues {
foundRow, ok := table.GetRow(ctx, expectedRow.NomsMapKey(PeopleTestSchema).Value(ctx).(types.Tuple), PeopleTestSchema)
assert.True(t, ok, "Row not found: %v", expectedRow)
require.True(t, ok, "Row not found: %v", expectedRow)
eq, diff := rowsEqual(expectedRow, foundRow)
assert.True(t, eq, "Rows not equals, found diff %v", diff)
}
+44 -26
@@ -15,11 +15,13 @@
package sqltestutil
import (
"context"
"reflect"
"testing"
"github.com/google/uuid"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/dtestutils"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/env"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
@@ -46,16 +48,16 @@ const (
)
const (
episodeIdTag = iota
epNameTag
epAirDateTag
epRatingTag
EpisodeIdTag = iota
EpNameTag
EpAirDateTag
EpRatingTag
)
const (
appCharacterTag = iota
appEpTag
appCommentsTag
AppCharacterTag = iota
AppEpTag
AppCommentsTag
)
const (
@@ -73,11 +75,11 @@ var PeopleTableName = "people"
var EpisodesTestSchema = createEpisodesTestSchema()
var untypedEpisodesSch = untyped.UntypeUnkeySchema(EpisodesTestSchema)
var episodesTableName = "episodes"
var EpisodesTableName = "episodes"
var AppearancesTestSchema = createAppearancesTestSchema()
var untypedAppearacesSch = untyped.UntypeUnkeySchema(AppearancesTestSchema)
var appearancesTableName = "appearances"
var AppearancesTableName = "appearances"
func createPeopleTestSchema() schema.Schema {
colColl, _ := schema.NewColCollection(
@@ -96,19 +98,19 @@ func createPeopleTestSchema() schema.Schema {
func createEpisodesTestSchema() schema.Schema {
colColl, _ := schema.NewColCollection(
schema.NewColumn("id", episodeIdTag, types.IntKind, true, schema.NotNullConstraint{}),
schema.NewColumn("name", epNameTag, types.StringKind, false, schema.NotNullConstraint{}),
schema.NewColumn("air_date", epAirDateTag, types.IntKind, false),
schema.NewColumn("rating", epRatingTag, types.FloatKind, false),
schema.NewColumn("id", EpisodeIdTag, types.IntKind, true, schema.NotNullConstraint{}),
schema.NewColumn("name", EpNameTag, types.StringKind, false, schema.NotNullConstraint{}),
schema.NewColumn("air_date", EpAirDateTag, types.IntKind, false),
schema.NewColumn("rating", EpRatingTag, types.FloatKind, false),
)
return schema.SchemaFromCols(colColl)
}
func createAppearancesTestSchema() schema.Schema {
colColl, _ := schema.NewColCollection(
schema.NewColumn("character_id", appCharacterTag, types.IntKind, true, schema.NotNullConstraint{}),
schema.NewColumn("episode_id", appEpTag, types.IntKind, true, schema.NotNullConstraint{}),
schema.NewColumn("comments", appCommentsTag, types.StringKind, false),
schema.NewColumn("character_id", AppCharacterTag, types.IntKind, true, schema.NotNullConstraint{}),
schema.NewColumn("episode_id", AppEpTag, types.IntKind, true, schema.NotNullConstraint{}),
schema.NewColumn("comments", AppCommentsTag, types.StringKind, false),
)
return schema.SchemaFromCols(colColl)
}
@@ -128,10 +130,10 @@ func NewPeopleRow(id int, first, last string, isMarried bool, age int, rating fl
func newEpsRow(id int, name string, airdate int, rating float32) row.Row {
vals := row.TaggedValues{
episodeIdTag: types.Int(id),
epNameTag: types.String(name),
epAirDateTag: types.Int(airdate),
epRatingTag: types.Float(rating),
EpisodeIdTag: types.Int(id),
EpNameTag: types.String(name),
EpAirDateTag: types.Int(airdate),
EpRatingTag: types.Float(rating),
}
return row.New(types.Format_7_18, EpisodesTestSchema, vals)
@@ -139,9 +141,9 @@ func newEpsRow(id int, name string, airdate int, rating float32) row.Row {
func newAppsRow(charId, epId int, comment string) row.Row {
vals := row.TaggedValues{
appCharacterTag: types.Int(charId),
appEpTag: types.Int(epId),
appCommentsTag: types.String(comment),
AppCharacterTag: types.Int(charId),
AppEpTag: types.Int(epId),
AppCommentsTag: types.String(comment),
}
return row.New(types.Format_7_18, AppearancesTestSchema, vals)
@@ -177,7 +179,7 @@ var Ep1 = newEpsRow(1, "Simpsons Roasting On an Open Fire", 629953200, 8.0)
var Ep2 = newEpsRow(2, "Bart the Genius", 632372400, 9.0)
var Ep3 = newEpsRow(3, "Homer's Odyssey", 632977200, 7.0)
var Ep4 = newEpsRow(4, "There's No Disgrace Like Home", 633582000, 8.5)
var allEpsRows = Rs(Ep1, Ep2, Ep3, Ep4)
var AllEpsRows = Rs(Ep1, Ep2, Ep3, Ep4)
// These are made up, not the actual show data
var app1 = newAppsRow(homerId, 1, "Homer is great in this one")
@@ -266,9 +268,25 @@ func MutateRow(r row.Row, tagsAndVals ...interface{}) row.Row {
return mutated
}
func GetAllRows(root *doltdb.RootValue, tableName string) []row.Row {
ctx := context.Background()
table, _ := root.GetTable(ctx, tableName)
rowData := table.GetRowData(ctx)
sch := table.GetSchema(ctx)
var rows []row.Row
rowData.Iter(ctx, func(key, value types.Value) (stop bool) {
r := row.FromNoms(sch, key.(types.Tuple), value.(types.Tuple))
rows = append(rows, r)
return false
})
return rows
}
// Creates a test database with the test data set in it
func CreateTestDatabase(dEnv *env.DoltEnv, t *testing.T) {
dtestutils.CreateTestTable(t, dEnv, PeopleTableName, PeopleTestSchema, AllPeopleRows...)
dtestutils.CreateTestTable(t, dEnv, episodesTableName, EpisodesTestSchema, allEpsRows...)
dtestutils.CreateTestTable(t, dEnv, appearancesTableName, AppearancesTestSchema, AllAppsRows...)
dtestutils.CreateTestTable(t, dEnv, EpisodesTableName, EpisodesTestSchema, AllEpsRows...)
dtestutils.CreateTestTable(t, dEnv, AppearancesTableName, AppearancesTestSchema, AllAppsRows...)
}
File diff suppressed because it is too large
@@ -34,6 +34,8 @@ import (
// Executes all the SQL non-select statements given in the string against the root value given and returns the updated
// root, or an error. Statements in the input string are split by `;\n`
func ExecuteSql(dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*doltdb.RootValue, error) {
batcher := dsql.NewSqlBatcher(dEnv.DoltDB, root)
for _, query := range strings.Split(statements, ";\n") {
if len(strings.Trim(query, " ")) == 0 {
continue
@@ -51,23 +53,19 @@ func ExecuteSql(dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*
case *sqlparser.Select, *sqlparser.OtherRead:
return nil, errors.New("Select statements aren't handled")
case *sqlparser.Insert:
var result *dsql.InsertResult
result, execErr = dsql.ExecuteInsert(context.Background(), dEnv.DoltDB, root, s, query)
root = result.Root
case *sqlparser.Update:
var result *dsql.UpdateResult
result, execErr = dsql.ExecuteUpdate(context.Background(), dEnv.DoltDB, root, s, query)
root = result.Root
case *sqlparser.Delete:
var result *dsql.DeleteResult
result, execErr = dsql.ExecuteDelete(context.Background(), dEnv.DoltDB, root, s, query)
root = result.Root
_, execErr = dsql.ExecuteBatchInsert(context.Background(), root, s, batcher)
case *sqlparser.DDL:
if root, err = batcher.Commit(context.Background()); err != nil {
return nil, err
}
_, execErr = sqlparser.ParseStrictDDL(query)
if execErr != nil {
return nil, fmt.Errorf("Error parsing DDL: %v.", execErr.Error())
}
root, execErr = sqlDDL(dEnv, root, s, query)
if err := batcher.UpdateRoot(root); err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("Unsupported SQL statement: '%v'.", query)
}
@@ -76,6 +74,13 @@ func ExecuteSql(dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*
return nil, execErr
}
}
if newRoot, err := batcher.Commit(context.Background()); newRoot != nil {
root = newRoot
} else if err != nil {
return nil, err
}
return root, nil
}
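A hypothetical call from a test, using the people table from the test data above (statements separated by ;\n per the contract):

newRoot, err := ExecuteSql(dEnv, root, `insert into people (id, first, last) values (12, "Troy", "McClure");
insert into people (id, first, last) values (13, "Lionel", "Hutz")`)
require.NoError(t, err)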