mirror of
https://github.com/dolthub/dolt.git
synced 2026-05-05 02:45:34 -05:00
Merge pull request #11 from liquidata-inc/zachmu/sql-fast-inserts
Batch inserts for SQL import. This results in a 60x speed increase and much less disk usage. Also in this change: (1) changed how memory / CPU profiling is triggered -- it is now a command line flag rather than a source change; (2) a better integration test for joins (stock market data); (3) better reporting during SQL import.
This commit is contained in:
@@ -15,27 +15,9 @@ teardown() {
|
||||
rm -rf "$BATS_TMPDIR/dolt-repo-$$"
|
||||
}
|
||||
|
||||
@test "start a sql shell and exit using exit" {
|
||||
skiponwindows "Works on Windows command prompt but not the WSL terminal used during bats"
|
||||
run bash -c "echo exit | dolt sql"
|
||||
[ $status -eq 0 ]
|
||||
[[ "$output" =~ "# Welcome to the DoltSQL shell." ]] || false
|
||||
[[ "$output" =~ "Bye" ]] || false
|
||||
}
|
||||
|
||||
@test "start a sql shell and exit using quit" {
|
||||
skiponwindows "Works on Windows command prompt but not the WSL terminal used during bats"
|
||||
run bash -c "echo quit | dolt sql"
|
||||
[ $status -eq 0 ]
|
||||
[[ "$output" =~ "# Welcome to the DoltSQL shell." ]] || false
|
||||
[[ "$output" =~ "Bye" ]] || false
|
||||
}
|
||||
|
||||
@test "run a query in sql shell" {
|
||||
skiponwindows "Works on Windows command prompt but not the WSL terminal used during bats"
|
||||
run bash -c "echo 'select * from test;' | dolt sql"
|
||||
[ $status -eq 0 ]
|
||||
[[ "$output" =~ "# Welcome to the DoltSQL shell." ]] || false
|
||||
[[ "$output" =~ "pk" ]] || false
|
||||
[[ "$output" =~ "Bye" ]] || false
|
||||
}
|
||||
+139
-25
@@ -15,8 +15,9 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@@ -115,8 +116,13 @@ func Sql(commandStr string, args []string, dEnv *env.DoltEnv) int {
|
||||
}
|
||||
}
|
||||
|
||||
// start an interactive shell
|
||||
root = runShell(dEnv, root)
|
||||
// Run in either batch mode for piped input, or shell mode for interactive
|
||||
fi, _ := os.Stdin.Stat()
|
||||
if (fi.Mode() & os.ModeCharDevice) == 0 {
|
||||
root = runBatchMode(dEnv, root)
|
||||
} else {
|
||||
root = runShell(dEnv, root)
|
||||
}
|
||||
|
||||
// If the SQL session wrote a new root value, update the working set with it
|
||||
if root != nil {
|
||||
@@ -126,6 +132,51 @@ func Sql(commandStr string, args []string, dEnv *env.DoltEnv) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// scanStatements is a split function for a bufio.Scanner that returns each SQL statement in the input as a token,
// without its terminating semicolon. It doesn't work for strings that contain semicolons. Supporting that requires
// implementing a state machine.
//
// Input left over after the last semicolon is returned as a final, unterminated statement at EOF, unless it is only
// whitespace (such as the newline that ends most files), in which case it is consumed without producing a token.
func scanStatements(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if atEOF && len(data) == 0 {
		return 0, nil, nil
	}
	if i := bytes.IndexByte(data, ';'); i >= 0 {
		// We have a full ;-terminated statement.
		return i + 1, data[0:i], nil
	}
	if atEOF {
		// Trailing whitespace is not a statement: skip it rather than handing the caller an empty query.
		if len(bytes.TrimSpace(data)) == 0 {
			return len(data), nil, nil
		}
		// A final, non-terminated statement. Return it.
		return len(data), data, nil
	}
	// Request more data.
	return 0, nil, nil
}
|
||||
|
||||
// runBatchMode processes queries until EOF and returns the resulting root value
|
||||
func runBatchMode(dEnv *env.DoltEnv, root *doltdb.RootValue) *doltdb.RootValue {
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
scanner.Split(scanStatements)
|
||||
|
||||
batcher := dsql.NewSqlBatcher(dEnv.DoltDB, root)
|
||||
|
||||
for scanner.Scan() {
|
||||
query := scanner.Text()
|
||||
if newRoot, err := processBatchQuery(query, dEnv, root, batcher); newRoot != nil {
|
||||
root = newRoot
|
||||
} else if err != nil {
|
||||
_, _ = fmt.Fprintf(cli.CliErr, "Error processing query '%s': %s\n", query, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
cli.Println(err.Error())
|
||||
}
|
||||
|
||||
if newRoot, _ := batcher.Commit(context.Background()); newRoot != nil {
|
||||
root = newRoot
|
||||
}
|
||||
|
||||
return root
|
||||
}
|
||||
|
||||
// runShell starts a SQL shell. Returns when the user exits the shell with the root value resulting from any queries.
|
||||
func runShell(dEnv *env.DoltEnv, root *doltdb.RootValue) *doltdb.RootValue {
|
||||
_ = iohelp.WriteLine(cli.CliOut, welcomeMsg)
|
||||
@@ -157,6 +208,7 @@ func runShell(dEnv *env.DoltEnv, root *doltdb.RootValue) *doltdb.RootValue {
|
||||
shell.EOF(func(c *ishell.Context) {
|
||||
c.Stop()
|
||||
})
|
||||
|
||||
shell.Interrupt(func(c *ishell.Context, count int, input string) {
|
||||
if count > 1 {
|
||||
c.Stop()
|
||||
@@ -181,7 +233,6 @@ func runShell(dEnv *env.DoltEnv, root *doltdb.RootValue) *doltdb.RootValue {
|
||||
// Longer term we need to switch to a new readline library, like in this bug:
|
||||
// https://github.com/cockroachdb/cockroach/issues/15460
|
||||
// For now, we store all history entries as single-line strings to avoid the issue.
|
||||
// TODO: only store history if it's a tty
|
||||
singleLine := strings.ReplaceAll(query, "\n", " ")
|
||||
if err := shell.AddHistory(singleLine); err != nil {
|
||||
// TODO: handle better, like by turning off history writing for the rest of the session
|
||||
@@ -290,7 +341,7 @@ func prepend(s string, ss []string) []string {
|
||||
func processQuery(query string, dEnv *env.DoltEnv, root *doltdb.RootValue) (*doltdb.RootValue, error) {
|
||||
sqlStatement, err := sqlparser.Parse(query)
|
||||
if err != nil {
|
||||
return nil, errFmt("Error parsing SQL: %v.", err.Error())
|
||||
return nil, fmt.Errorf("Error parsing SQL: %v.", err.Error())
|
||||
}
|
||||
|
||||
switch s := sqlStatement.(type) {
|
||||
@@ -303,7 +354,7 @@ func processQuery(query string, dEnv *env.DoltEnv, root *doltdb.RootValue) (*dol
|
||||
}
|
||||
return nil, err
|
||||
case *sqlparser.Insert:
|
||||
return sqlInsert(dEnv, root, s, query)
|
||||
return sqlInsert(dEnv, root, s)
|
||||
case *sqlparser.Update:
|
||||
return sqlUpdate(dEnv, root, s, query)
|
||||
case *sqlparser.Delete:
|
||||
@@ -311,11 +362,43 @@ func processQuery(query string, dEnv *env.DoltEnv, root *doltdb.RootValue) (*dol
|
||||
case *sqlparser.DDL:
|
||||
_, err := sqlparser.ParseStrictDDL(query)
|
||||
if err != nil {
|
||||
return nil, errFmt("Error parsing DDL: %v.", err.Error())
|
||||
return nil, fmt.Errorf("Error parsing DDL: %v.", err.Error())
|
||||
}
|
||||
return sqlDDL(dEnv, root, s, query)
|
||||
default:
|
||||
return nil, errFmt("Unsupported SQL statement: '%v'.", query)
|
||||
return nil, fmt.Errorf("Unsupported SQL statement: '%v'.", query)
|
||||
}
|
||||
}
|
||||
|
||||
// Processes a single query in batch mode and returns the result. The RootValue may or may not be changed.
|
||||
func processBatchQuery(query string, dEnv *env.DoltEnv, root *doltdb.RootValue, batcher *dsql.SqlBatcher) (*doltdb.RootValue, error) {
|
||||
sqlStatement, err := sqlparser.Parse(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error parsing SQL: %v.", err.Error())
|
||||
}
|
||||
|
||||
switch s := sqlStatement.(type) {
|
||||
case *sqlparser.Insert:
|
||||
return sqlInsertBatch(dEnv, root, s, batcher)
|
||||
default:
|
||||
// For any other kind of statement, we need to commit whatever batch edit we've accumulated so far before executing
|
||||
// the query
|
||||
newRoot, err := batcher.Commit(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newRoot, err = processQuery(query, dEnv, newRoot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if newRoot != nil {
|
||||
root = newRoot
|
||||
if err := batcher.UpdateRoot(root); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return root, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -396,10 +479,10 @@ func prettyPrintResults(nbf *types.NomsBinFormat, sqlSch sql.Schema, rowIter sql
|
||||
|
||||
p.Start()
|
||||
if err := p.Wait(); err != nil {
|
||||
return errFmt("error processing results: %v", err)
|
||||
return fmt.Errorf("error processing results: %v", err)
|
||||
}
|
||||
if chanErr != io.EOF {
|
||||
return errFmt("error processing results: %v", chanErr)
|
||||
return fmt.Errorf("error processing results: %v", chanErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -433,17 +516,17 @@ func runPrintingPipeline(nbf *types.NomsBinFormat, p *pipeline.Pipeline, untyped
|
||||
|
||||
p.Start()
|
||||
if err := p.Wait(); err != nil {
|
||||
return errFmt("error processing results: %v", err)
|
||||
return fmt.Errorf("error processing results: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Executes a SQL insert statement and prints the result to the CLI. Returns the new root value to be written as appropriate.
|
||||
func sqlInsert(dEnv *env.DoltEnv, root *doltdb.RootValue, stmt *sqlparser.Insert, query string) (*doltdb.RootValue, error) {
|
||||
result, err := dsql.ExecuteInsert(context.Background(), dEnv.DoltDB, root, stmt, query)
|
||||
func sqlInsert(dEnv *env.DoltEnv, root *doltdb.RootValue, stmt *sqlparser.Insert) (*doltdb.RootValue, error) {
|
||||
result, err := dsql.ExecuteInsert(context.Background(), dEnv.DoltDB, root, stmt)
|
||||
if err != nil {
|
||||
return nil, errFmt("Error inserting rows: %v", err.Error())
|
||||
return nil, fmt.Errorf("Error inserting rows: %v", err.Error())
|
||||
}
|
||||
|
||||
cli.Println(fmt.Sprintf("Rows inserted: %v", result.NumRowsInserted))
|
||||
@@ -457,11 +540,46 @@ func sqlInsert(dEnv *env.DoltEnv, root *doltdb.RootValue, stmt *sqlparser.Insert
|
||||
return result.Root, nil
|
||||
}
|
||||
|
||||
// stats holds running totals of the outcomes of insert statements.
type stats struct {
	// Number of rows inserted
	numRowsInserted int
	// Number of rows updated
	numRowsUpdated int
	// Number of errors that were ignored
	numErrorsIgnored int
}

var (
	// batchEditStats is the running total that sqlInsertBatch merges each insert result into.
	batchEditStats stats
	// displayStrLen holds the value returned by the last cli.DeleteAndPrint call made by sqlInsertBatch.
	displayStrLen int
)
|
||||
|
||||
// Executes a SQL insert statement in batch mode and returns the new root value (which is usually unchanged) or an
|
||||
// error. No output is written to the console in batch mode.
|
||||
func sqlInsertBatch(dEnv *env.DoltEnv, root *doltdb.RootValue, stmt *sqlparser.Insert, batcher *dsql.SqlBatcher) (*doltdb.RootValue, error) {
|
||||
result, err := dsql.ExecuteBatchInsert(context.Background(), root, stmt, batcher)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error inserting rows: %v", err.Error())
|
||||
}
|
||||
mergeResultIntoStats(result, &batchEditStats)
|
||||
|
||||
displayStr := fmt.Sprintf("Rows inserted: %d, Updated: %d, Errors: %d",
|
||||
batchEditStats.numRowsInserted, batchEditStats.numRowsUpdated, batchEditStats.numErrorsIgnored)
|
||||
displayStrLen = cli.DeleteAndPrint(displayStrLen, displayStr)
|
||||
|
||||
if result.Root != nil {
|
||||
root = result.Root
|
||||
}
|
||||
|
||||
return root, nil
|
||||
}
|
||||
|
||||
func mergeResultIntoStats(result *dsql.InsertResult, stats *stats) {
|
||||
stats.numRowsInserted += result.NumRowsInserted
|
||||
stats.numRowsUpdated += result.NumRowsUpdated
|
||||
stats.numErrorsIgnored += result.NumErrorsIgnored
|
||||
}
|
||||
|
||||
// Executes a SQL update statement and prints the result to the CLI. Returns the new root value to be written as appropriate.
|
||||
func sqlUpdate(dEnv *env.DoltEnv, root *doltdb.RootValue, update *sqlparser.Update, query string) (*doltdb.RootValue, error) {
|
||||
result, err := dsql.ExecuteUpdate(context.Background(), dEnv.DoltDB, root, update, query)
|
||||
if err != nil {
|
||||
return nil, errFmt("Error during update: %v", err.Error())
|
||||
return nil, fmt.Errorf("Error during update: %v", err.Error())
|
||||
}
|
||||
|
||||
cli.Println(fmt.Sprintf("Rows updated: %v", result.NumRowsUpdated))
|
||||
@@ -476,7 +594,7 @@ func sqlUpdate(dEnv *env.DoltEnv, root *doltdb.RootValue, update *sqlparser.Upda
|
||||
func sqlDelete(dEnv *env.DoltEnv, root *doltdb.RootValue, update *sqlparser.Delete, query string) (*doltdb.RootValue, error) {
|
||||
result, err := dsql.ExecuteDelete(context.Background(), dEnv.DoltDB, root, update, query)
|
||||
if err != nil {
|
||||
return nil, errFmt("Error during update: %v", err.Error())
|
||||
return nil, fmt.Errorf("Error during update: %v", err.Error())
|
||||
}
|
||||
|
||||
cli.Println(fmt.Sprintf("Rows deleted: %v", result.NumRowsDeleted))
|
||||
@@ -490,28 +608,24 @@ func sqlDDL(dEnv *env.DoltEnv, root *doltdb.RootValue, ddl *sqlparser.DDL, query
|
||||
case sqlparser.CreateStr:
|
||||
newRoot, _, err := dsql.ExecuteCreate(context.Background(), dEnv.DoltDB, root, ddl, query)
|
||||
if err != nil {
|
||||
return nil, errFmt("Error creating table: %v", err)
|
||||
return nil, fmt.Errorf("Error creating table: %v", err)
|
||||
}
|
||||
return newRoot, nil
|
||||
case sqlparser.AlterStr, sqlparser.RenameStr:
|
||||
newRoot, err := dsql.ExecuteAlter(context.Background(), dEnv.DoltDB, root, ddl, query)
|
||||
if err != nil {
|
||||
return nil, errFmt("Error altering table: %v", err)
|
||||
return nil, fmt.Errorf("Error altering table: %v", err)
|
||||
}
|
||||
return newRoot, nil
|
||||
case sqlparser.DropStr:
|
||||
newRoot, err := dsql.ExecuteDrop(context.Background(), dEnv.DoltDB, root, ddl, query)
|
||||
if err != nil {
|
||||
return nil, errFmt("Error dropping table: %v", err)
|
||||
return nil, fmt.Errorf("Error dropping table: %v", err)
|
||||
}
|
||||
return newRoot, nil
|
||||
case sqlparser.TruncateStr:
|
||||
return nil, errFmt("Unhandled DDL action %v in query %v", ddl.Action, query)
|
||||
return nil, fmt.Errorf("Unhandled DDL action %v in query %v", ddl.Action, query)
|
||||
default:
|
||||
return nil, errFmt("Unhandled DDL action %v in query %v", ddl.Action, query)
|
||||
return nil, fmt.Errorf("Unhandled DDL action %v in query %v", ddl.Action, query)
|
||||
}
|
||||
}
|
||||
|
||||
// errFmt returns a new error whose message is fmtMsg formatted with args, as by fmt.Sprintf.
func errFmt(fmtMsg string, args ...interface{}) error {
	msg := fmt.Sprintf(fmtMsg, args...)
	return errors.New(msg)
}
|
||||
|
||||
+30
-12
@@ -66,25 +66,43 @@ var doltCommand = cli.GenSubCommandHandler([]*cli.Command{
|
||||
{Name: "conflicts", Desc: "Commands for viewing and resolving merge conflicts.", Func: cnfcmds.Commands, ReqRepo: false},
|
||||
})
|
||||
|
||||
var cpuProf = false
|
||||
var memProf = false
|
||||
const profFlag = "--prof"
|
||||
const cpuProf = "cpu"
|
||||
const memProf = "mem"
|
||||
const blockingProf = "blocking"
|
||||
const traceProf = "trace"
|
||||
|
||||
func main() {
|
||||
os.Exit(runMain())
|
||||
}
|
||||
|
||||
func runMain() int {
|
||||
if cpuProf {
|
||||
fmt.Println("cpu profiling enabled.")
|
||||
defer profile.Start(profile.CPUProfile).Stop()
|
||||
}
|
||||
|
||||
if memProf {
|
||||
fmt.Println("mem profiling enabled.")
|
||||
defer profile.Start(profile.MemProfile).Stop()
|
||||
}
|
||||
|
||||
args := os.Args[1:]
|
||||
|
||||
if len(args) > 0 && args[0] == profFlag {
|
||||
if len(os.Args) <= 2 {
|
||||
panic("Expected a profile arg after " + profFlag)
|
||||
}
|
||||
prof := args[1]
|
||||
switch prof {
|
||||
case cpuProf:
|
||||
fmt.Println("cpu profiling enabled.")
|
||||
defer profile.Start(profile.CPUProfile).Stop()
|
||||
case memProf:
|
||||
fmt.Println("mem profiling enabled.")
|
||||
defer profile.Start(profile.MemProfile).Stop()
|
||||
case blockingProf:
|
||||
fmt.Println("block profiling enabled")
|
||||
defer profile.Start(profile.BlockProfile).Stop()
|
||||
case traceProf:
|
||||
fmt.Println("trace profiling enabled")
|
||||
defer profile.Start(profile.TraceProfile).Stop()
|
||||
default:
|
||||
panic("Unexpected prof flag: " + prof)
|
||||
}
|
||||
args = args[2:]
|
||||
}
|
||||
|
||||
// Currently goland doesn't support running with a different working directory when using go modules.
|
||||
// This is a hack that allows a different working directory to be set after the application starts using
|
||||
// chdir=<DIR>. The syntax is not flexible and must match exactly this.
|
||||
|
||||
@@ -0,0 +1,218 @@
|
||||
// Copyright 2019 Liquidata, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/liquidata-inc/dolt/go/store/hash"
|
||||
"github.com/liquidata-inc/dolt/go/store/types"
|
||||
)
|
||||
|
||||
// SqlBatcher knows how to efficiently batch insert / update statements, e.g. when doing a SQL import. It does this by
|
||||
// using a single MapEditor per table that isn't persisted until Commit is called.
|
||||
type SqlBatcher struct {
|
||||
// The database we are editing
|
||||
db *doltdb.DoltDB
|
||||
// The root value we are editing
|
||||
root *doltdb.RootValue
|
||||
// The set of tables under edit
|
||||
tables map[string]*doltdb.Table
|
||||
// The schemas of tables under edit
|
||||
schemas map[string]schema.Schema
|
||||
// The row data for tables being edited
|
||||
rowData map[string]types.Map
|
||||
// The editors applying updates to the tables
|
||||
editors map[string]*types.MapEditor
|
||||
// The hashes of primary keys being inserted to the tables
|
||||
hashes map[string]map[hash.Hash]bool
|
||||
}
|
||||
|
||||
// Returns a new SqlBatcher for the given environment and root value.
|
||||
func NewSqlBatcher(db *doltdb.DoltDB, root *doltdb.RootValue) *SqlBatcher {
|
||||
batcher := &SqlBatcher{
|
||||
db: db,
|
||||
root: root,
|
||||
}
|
||||
batcher.resetState()
|
||||
return batcher
|
||||
}
|
||||
|
||||
// Updates this batcher with a new root value. If there are outstanding edits, returns an error.
|
||||
func (b *SqlBatcher) UpdateRoot(root *doltdb.RootValue) error {
|
||||
if b.isDirty() {
|
||||
return errors.New("UpdateRoot called with outstanding edits")
|
||||
}
|
||||
b.root = root
|
||||
|
||||
// resetting the state shouldn't be necessary here because of the isDirty check, but if the client chooses to ignore
|
||||
// the returned error, we'll at least have a clean state going forward
|
||||
b.resetState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// isDirty returns whether there are outstanding edits that haven't been committed.
|
||||
func (b *SqlBatcher) isDirty() bool {
|
||||
return len(b.editors) > 0
|
||||
}
|
||||
|
||||
// resetState flushes the cache of outstanding edits and other data
|
||||
func (b *SqlBatcher) resetState() {
|
||||
b.tables = make(map[string]*doltdb.Table)
|
||||
b.schemas = make(map[string]schema.Schema)
|
||||
b.rowData = make(map[string]types.Map)
|
||||
b.editors = make(map[string]*types.MapEditor)
|
||||
b.hashes = make(map[string]map[hash.Hash]bool)
|
||||
}
|
||||
|
||||
// InsertOptions controls how SqlBatcher.Insert treats rows whose primary key is already in use.
type InsertOptions struct {
	// Replace, when set, silently replaces any existing row that has the same primary key as an inserted row.
	Replace bool
}
|
||||
|
||||
// BatchInsertResult describes the outcome of a single SqlBatcher.Insert call.
type BatchInsertResult struct {
	// RowInserted is whether the row was counted as an insertion
	RowInserted bool
	// RowUpdated is whether the row was counted as an update
	RowUpdated bool
}
|
||||
|
||||
func (b *SqlBatcher) Insert(ctx context.Context, tableName string, r row.Row, opt InsertOptions) (*BatchInsertResult, error) {
|
||||
sch, err := b.GetSchema(ctx, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rowData, err := b.getRowData(ctx, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ed, err := b.getEditor(ctx, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := r.NomsMapKey(sch).Value(ctx)
|
||||
|
||||
rowExists := rowData.Get(ctx, key) != nil
|
||||
hashes := b.getHashes(ctx, tableName)
|
||||
rowAlreadyTouched := hashes[key.Hash(b.root.VRW().Format())]
|
||||
|
||||
if rowExists || rowAlreadyTouched {
|
||||
if !opt.Replace {
|
||||
return nil, fmt.Errorf("Duplicate primary key: '%v'", getPrimaryKeyString(r, sch))
|
||||
}
|
||||
}
|
||||
|
||||
ed.Set(key, r.NomsMapValue(sch))
|
||||
hashes[key.Hash(b.root.VRW().Format())] = true
|
||||
|
||||
return &BatchInsertResult{RowInserted: !rowExists, RowUpdated: rowExists || rowAlreadyTouched}, nil
|
||||
}
|
||||
|
||||
// GetTable returns the table with the name given. This method is offered because reading the table from the root value
|
||||
// is relatively expensive, and SqlBatcher caches Tables to avoid the overhead.
|
||||
func (b *SqlBatcher) GetTable(ctx context.Context, tableName string) (*doltdb.Table, error) {
|
||||
if table, ok := b.tables[tableName]; ok {
|
||||
return table, nil
|
||||
}
|
||||
|
||||
if !b.root.HasTable(ctx, tableName) {
|
||||
return nil, fmt.Errorf("Unknown table %v", tableName)
|
||||
}
|
||||
|
||||
table, _ := b.root.GetTable(ctx, tableName)
|
||||
b.tables[tableName] = table
|
||||
return table, nil
|
||||
}
|
||||
|
||||
// GetSchema returns the schema for the table name given. This method is offered because reading the schema from disk
|
||||
// is actually relatively expensive -- SqlBatcher caches the schema values per table to avoid the overhead.
|
||||
func (b *SqlBatcher) GetSchema(ctx context.Context, tableName string) (schema.Schema, error) {
|
||||
if schema, ok := b.schemas[tableName]; ok {
|
||||
return schema, nil
|
||||
}
|
||||
|
||||
table, err := b.GetTable(ctx, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sch := table.GetSchema(ctx)
|
||||
b.schemas[tableName] = sch
|
||||
return sch, nil
|
||||
}
|
||||
|
||||
func (b *SqlBatcher) getEditor(ctx context.Context, tableName string) (*types.MapEditor, error) {
|
||||
if ed, ok := b.editors[tableName]; ok {
|
||||
return ed, nil
|
||||
}
|
||||
|
||||
rowData, err := b.getRowData(ctx, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ed := rowData.Edit()
|
||||
b.editors[tableName] = ed
|
||||
return ed, nil
|
||||
}
|
||||
|
||||
func (b *SqlBatcher) getRowData(ctx context.Context, tableName string) (types.Map, error) {
|
||||
if rowData, ok := b.rowData[tableName]; ok {
|
||||
return rowData, nil
|
||||
}
|
||||
|
||||
table, err := b.GetTable(ctx, tableName)
|
||||
if err != nil {
|
||||
return types.EmptyMap, err
|
||||
}
|
||||
|
||||
rowData := table.GetRowData(ctx)
|
||||
b.rowData[tableName] = rowData
|
||||
return rowData, nil
|
||||
}
|
||||
|
||||
func (b *SqlBatcher) getHashes(ctx context.Context, tableName string) map[hash.Hash]bool {
|
||||
if hashes, ok := b.hashes[tableName]; ok {
|
||||
return hashes
|
||||
}
|
||||
|
||||
hashes := make(map[hash.Hash]bool)
|
||||
b.hashes[tableName] = hashes
|
||||
return hashes
|
||||
}
|
||||
|
||||
// Commit writes a new root value for every table under edit and returns the new root value. Tables are written in an
|
||||
// arbitrary order.
|
||||
func (b *SqlBatcher) Commit(ctx context.Context) (*doltdb.RootValue, error) {
|
||||
root := b.root
|
||||
|
||||
for tableName, ed := range b.editors {
|
||||
newMap := ed.Map(ctx)
|
||||
table := b.tables[tableName]
|
||||
table = table.UpdateRows(ctx, newMap)
|
||||
root = root.PutTable(ctx, b.db, tableName, table)
|
||||
}
|
||||
|
||||
b.root = root
|
||||
b.resetState()
|
||||
|
||||
return root, nil
|
||||
}
|
||||
@@ -0,0 +1,265 @@
|
||||
// Copyright 2019 Liquidata, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"vitess.io/vitess/go/vt/sqlparser"
|
||||
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/dtestutils"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
|
||||
. "github.com/liquidata-inc/dolt/go/libraries/doltcore/sql/sqltestutil"
|
||||
"github.com/liquidata-inc/dolt/go/store/types"
|
||||
)
|
||||
|
||||
func TestSqlBatchInserts(t *testing.T) {
|
||||
insertStatements := []string{
|
||||
`insert into people (id, first, last, is_married, age, rating, uuid, num_episodes) values
|
||||
(7, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000007', 677)`,
|
||||
`insert into people values
|
||||
(8, "Milhouse", "VanHouten", false, 1, 5.1, '00000000-0000-0000-0000-000000000008', 677)`,
|
||||
`insert into people (id, first, last) values (9, "Clancey", "Wiggum")`,
|
||||
`insert into people (id, first, last) values
|
||||
(10, "Montgomery", "Burns"), (11, "Ned", "Flanders")`,
|
||||
`insert into episodes (id, name) values (5, "Bart the General"), (6, "Moaning Lisa")`,
|
||||
`insert into episodes (id, name) values (7, "The Call of the Simpsons"), (8, "The Telltale Head")`,
|
||||
`insert into episodes (id, name) values (9, "Life on the Fast Lane")`,
|
||||
`insert into appearances (character_id, episode_id) values (7,5), (7,6)`,
|
||||
`insert into appearances (character_id, episode_id) values (8,7)`,
|
||||
`insert into appearances (character_id, episode_id) values (9,8), (9,9)`,
|
||||
`insert into appearances (character_id, episode_id) values (10,5), (10,6)`,
|
||||
`insert into appearances (character_id, episode_id) values (11,9)`,
|
||||
}
|
||||
|
||||
// Shuffle the inserts so that different tables are interleaved. We're not giving a seed here, so this is
|
||||
// deterministic.
|
||||
rand.Shuffle(len(insertStatements),
|
||||
func(i, j int) {
|
||||
insertStatements[i], insertStatements[j] = insertStatements[j], insertStatements[i]
|
||||
})
|
||||
|
||||
dEnv := dtestutils.CreateTestEnv()
|
||||
ctx := context.Background()
|
||||
|
||||
CreateTestDatabase(dEnv, t)
|
||||
root, _ := dEnv.WorkingRoot(ctx)
|
||||
|
||||
batcher := NewSqlBatcher(dEnv.DoltDB, root)
|
||||
for _, stmt := range insertStatements {
|
||||
statement, err := sqlparser.Parse(stmt)
|
||||
require.NoError(t, err)
|
||||
insertStmt, ok := statement.(*sqlparser.Insert)
|
||||
require.True(t, ok)
|
||||
result, err := ExecuteBatchInsert(context.Background(), root, insertStmt, batcher)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result.NumRowsInserted > 0)
|
||||
assert.Equal(t, 0, result.NumRowsUpdated)
|
||||
assert.Equal(t, 0, result.NumErrorsIgnored)
|
||||
}
|
||||
|
||||
// Before committing the batch, the database should be unchanged from its original state
|
||||
allPeopleRows := GetAllRows(root, PeopleTableName)
|
||||
allEpsRows := GetAllRows(root, EpisodesTableName)
|
||||
allAppearanceRows := GetAllRows(root, AppearancesTableName)
|
||||
|
||||
assert.ElementsMatch(t, AllPeopleRows, allPeopleRows)
|
||||
assert.ElementsMatch(t, AllEpsRows, allEpsRows)
|
||||
assert.ElementsMatch(t, AllAppsRows, allAppearanceRows)
|
||||
|
||||
// Now commit the batch and check for new rows
|
||||
root, err := batcher.Commit(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
var expectedPeople, expectedEpisodes, expectedAppearances []row.Row
|
||||
|
||||
expectedPeople = append(expectedPeople, AllPeopleRows...)
|
||||
expectedPeople = append(expectedPeople,
|
||||
NewPeopleRowWithOptionalFields(7, "Maggie", "Simpson", false, 1, 5.1, uuid.MustParse("00000000-0000-0000-0000-000000000007"), 677),
|
||||
NewPeopleRowWithOptionalFields(8, "Milhouse", "VanHouten", false, 1, 5.1, uuid.MustParse("00000000-0000-0000-0000-000000000008"), 677),
|
||||
newPeopleRow(9, "Clancey", "Wiggum"),
|
||||
newPeopleRow(10, "Montgomery", "Burns"),
|
||||
newPeopleRow(11, "Ned", "Flanders"),
|
||||
)
|
||||
|
||||
expectedEpisodes = append(expectedEpisodes, AllEpsRows...)
|
||||
expectedEpisodes = append(expectedEpisodes,
|
||||
newEpsRow(5, "Bart the General"),
|
||||
newEpsRow(6, "Moaning Lisa"),
|
||||
newEpsRow(7, "The Call of the Simpsons"),
|
||||
newEpsRow(8, "The Telltale Head"),
|
||||
newEpsRow(9, "Life on the Fast Lane"),
|
||||
)
|
||||
|
||||
expectedAppearances = append(expectedAppearances, AllAppsRows...)
|
||||
expectedAppearances = append(expectedAppearances,
|
||||
newAppsRow(7, 5),
|
||||
newAppsRow(7, 6),
|
||||
newAppsRow(8, 7),
|
||||
newAppsRow(9, 8),
|
||||
newAppsRow(9, 9),
|
||||
newAppsRow(10, 5),
|
||||
newAppsRow(10, 6),
|
||||
newAppsRow(11, 9),
|
||||
)
|
||||
|
||||
allPeopleRows = GetAllRows(root, PeopleTableName)
|
||||
allEpsRows = GetAllRows(root, EpisodesTableName)
|
||||
allAppearanceRows = GetAllRows(root, AppearancesTableName)
|
||||
|
||||
assertRowSetsEqual(t, expectedPeople, allPeopleRows)
|
||||
assertRowSetsEqual(t, expectedEpisodes, allEpsRows)
|
||||
assertRowSetsEqual(t, expectedAppearances, allAppearanceRows)
|
||||
}
|
||||
|
||||
func TestSqlBatchInsertIgnoreReplace(t *testing.T) {
|
||||
insertStatements := []string{
|
||||
`replace into people (id, first, last, is_married, age, rating, uuid, num_episodes) values
|
||||
(0, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000007', 677)`,
|
||||
`insert ignore into people values
|
||||
(2, "Milhouse", "VanHouten", false, 1, 5.1, '00000000-0000-0000-0000-000000000008', 677)`,
|
||||
}
|
||||
numRowsUpdated := []int{1, 0}
|
||||
numErrorsIgnored := []int{0, 1}
|
||||
|
||||
dEnv := dtestutils.CreateTestEnv()
|
||||
ctx := context.Background()
|
||||
|
||||
CreateTestDatabase(dEnv, t)
|
||||
root, _ := dEnv.WorkingRoot(ctx)
|
||||
|
||||
batcher := NewSqlBatcher(dEnv.DoltDB, root)
|
||||
for i := range insertStatements {
|
||||
stmt := insertStatements[i]
|
||||
statement, err := sqlparser.Parse(stmt)
|
||||
require.NoError(t, err)
|
||||
insertStmt, ok := statement.(*sqlparser.Insert)
|
||||
require.True(t, ok)
|
||||
result, err := ExecuteBatchInsert(context.Background(), root, insertStmt, batcher)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, result.NumRowsInserted)
|
||||
assert.Equal(t, numRowsUpdated[i], result.NumRowsUpdated)
|
||||
assert.Equal(t, numErrorsIgnored[i], result.NumErrorsIgnored)
|
||||
}
|
||||
|
||||
// Before committing the batch, the database should be unchanged from its original state
|
||||
allPeopleRows := GetAllRows(root, PeopleTableName)
|
||||
assert.ElementsMatch(t, AllPeopleRows, allPeopleRows)
|
||||
|
||||
// Now commit the batch and check for new rows
|
||||
root, err := batcher.Commit(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
var expectedPeople []row.Row
|
||||
|
||||
expectedPeople = append(expectedPeople, AllPeopleRows[1:]...) // skip homer
|
||||
expectedPeople = append(expectedPeople,
|
||||
NewPeopleRowWithOptionalFields(0, "Maggie", "Simpson", false, 1, 5.1, uuid.MustParse("00000000-0000-0000-0000-000000000007"), 677),
|
||||
)
|
||||
|
||||
allPeopleRows = GetAllRows(root, PeopleTableName)
|
||||
assertRowSetsEqual(t, expectedPeople, allPeopleRows)
|
||||
}
|
||||
|
||||
func TestSqlBatchInsertErrors(t *testing.T) {
|
||||
insertStatements := []string{
|
||||
`insert into people (id, first, last, is_married, age, rating, uuid, num_episodes) values
|
||||
(0, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000007', 677)`,
|
||||
`insert into people values
|
||||
(2, "Milhouse", "VanHouten", false, 1, 5.1, true, 677)`,
|
||||
}
|
||||
|
||||
dEnv := dtestutils.CreateTestEnv()
|
||||
ctx := context.Background()
|
||||
|
||||
CreateTestDatabase(dEnv, t)
|
||||
root, _ := dEnv.WorkingRoot(ctx)
|
||||
|
||||
batcher := NewSqlBatcher(dEnv.DoltDB, root)
|
||||
for i := range insertStatements {
|
||||
stmt := insertStatements[i]
|
||||
statement, err := sqlparser.Parse(stmt)
|
||||
require.NoError(t, err)
|
||||
insertStmt, ok := statement.(*sqlparser.Insert)
|
||||
require.True(t, ok)
|
||||
_, err = ExecuteBatchInsert(context.Background(), root, insertStmt, batcher)
|
||||
require.Error(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func assertRowSetsEqual(t *testing.T, expected, actual []row.Row) {
|
||||
equal, diff := rowSetsEqual(expected, actual)
|
||||
assert.True(t, equal, diff)
|
||||
}
|
||||
|
||||
// Returns whether the two slices of rows contain the same elements using set semantics (no duplicates), and an error
|
||||
// string if they aren't.
|
||||
func rowSetsEqual(expected, actual []row.Row) (bool, string) {
|
||||
if len(expected) != len(actual) {
|
||||
return false, fmt.Sprintf("Sets have different sizes: expected %d, was %d", len(expected), len(actual))
|
||||
}
|
||||
|
||||
for _, ex := range expected {
|
||||
if !containsRow(actual, ex) {
|
||||
return false, fmt.Sprintf("Missing row: %v", ex)
|
||||
}
|
||||
}
|
||||
|
||||
return true, ""
|
||||
}
|
||||
|
||||
func containsRow(rs []row.Row, r row.Row) bool {
|
||||
for _, r2 := range rs {
|
||||
equal, _ := rowsEqual(r, r2)
|
||||
if equal {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func newPeopleRow(id int, first, last string) row.Row {
|
||||
vals := row.TaggedValues{
|
||||
IdTag: types.Int(id),
|
||||
FirstTag: types.String(first),
|
||||
LastTag: types.String(last),
|
||||
}
|
||||
|
||||
return row.New(types.Format_7_18, PeopleTestSchema, vals)
|
||||
}
|
||||
|
||||
func newEpsRow(id int, name string) row.Row {
|
||||
vals := row.TaggedValues{
|
||||
EpisodeIdTag: types.Int(id),
|
||||
EpNameTag: types.String(name),
|
||||
}
|
||||
|
||||
return row.New(types.Format_7_18, EpisodesTestSchema, vals)
|
||||
}
|
||||
|
||||
func newAppsRow(charId, epId int) row.Row {
|
||||
vals := row.TaggedValues{
|
||||
AppCharacterTag: types.Int(charId),
|
||||
AppEpTag: types.Int(epId),
|
||||
}
|
||||
|
||||
return row.New(types.Format_7_18, AppearancesTestSchema, vals)
|
||||
}
|
||||
@@ -25,7 +25,6 @@ import (
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/liquidata-inc/dolt/go/store/hash"
|
||||
"github.com/liquidata-inc/dolt/go/store/types"
|
||||
)
|
||||
|
||||
@@ -39,14 +38,21 @@ type InsertResult struct {
|
||||
var ErrMissingPrimaryKeys = errors.New("One or more primary key columns missing from insert statement")
|
||||
var ConstraintFailedFmt = "Constraint failed for column '%v': %v"
|
||||
|
||||
// ExecuteSelect executes the given select query and returns the resultant rows accompanied by their output schema.
|
||||
func ExecuteInsert(ctx context.Context, db *doltdb.DoltDB, root *doltdb.RootValue, s *sqlparser.Insert, query string) (*InsertResult, error) {
|
||||
// ExecuteBatchInsert executes the given insert statement in batch mode and returns the result. The table is not changed
|
||||
// until the batch is Committed. The InsertResult returned similarly doesn't have a Root set, since the root isn't
|
||||
// modified by this function.
|
||||
func ExecuteBatchInsert(
|
||||
ctx context.Context,
|
||||
root *doltdb.RootValue,
|
||||
s *sqlparser.Insert,
|
||||
batcher *SqlBatcher,
|
||||
) (*InsertResult, error) {
|
||||
|
||||
tableName := s.Table.Name.String()
|
||||
if !root.HasTable(ctx, tableName) {
|
||||
return errInsert("Unknown table %v", tableName)
|
||||
tableSch, err := batcher.GetSchema(ctx, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
table, _ := root.GetTable(ctx, tableName)
|
||||
tableSch := table.GetSchema(ctx)
|
||||
|
||||
// Parser supports overwrite on insert with both the replace keyword (from MySQL) as well as the ignore keyword
|
||||
replace := s.Action == sqlparser.ReplaceStr
|
||||
@@ -62,13 +68,13 @@ func ExecuteInsert(ctx context.Context, db *doltdb.DoltDB, root *doltdb.RootValu
|
||||
for i, colName := range s.Columns {
|
||||
for _, c := range cols {
|
||||
if c.Name == colName.String() {
|
||||
return errInsert("Repeated column: '%v'", c.Name)
|
||||
return nil, fmt.Errorf("Repeated column: '%v'", c.Name)
|
||||
}
|
||||
}
|
||||
|
||||
col, ok := tableSch.GetAllCols().GetByName(colName.String())
|
||||
if !ok {
|
||||
return errInsert(UnknownColumnErrFmt, colName)
|
||||
return nil, fmt.Errorf(UnknownColumnErrFmt, colName)
|
||||
}
|
||||
cols[i] = col
|
||||
}
|
||||
@@ -81,24 +87,21 @@ func ExecuteInsert(ctx context.Context, db *doltdb.DoltDB, root *doltdb.RootValu
|
||||
var err error
|
||||
rows, err = prepareInsertVals(root.VRW().Format(), cols, &queryRows, tableSch)
|
||||
if err != nil {
|
||||
return &InsertResult{}, err
|
||||
return nil, err
|
||||
}
|
||||
case *sqlparser.Select:
|
||||
return errInsert("Insert as select not supported")
|
||||
return nil, fmt.Errorf("Insert as select not supported")
|
||||
case *sqlparser.ParenSelect:
|
||||
return errInsert("Parenthesized select expressions in insert not supported")
|
||||
return nil, fmt.Errorf("Parenthesized select expressions in insert not supported")
|
||||
case *sqlparser.Union:
|
||||
return errInsert("Union not supported")
|
||||
return nil, fmt.Errorf("Union not supported")
|
||||
default:
|
||||
return errInsert("Unrecognized type for insertRows: %v", queryRows)
|
||||
return nil, fmt.Errorf("Unrecognized type for insert: %v", queryRows)
|
||||
}
|
||||
|
||||
// Perform the insert
|
||||
rowData := table.GetRowData(ctx)
|
||||
me := rowData.Edit()
|
||||
var result InsertResult
|
||||
|
||||
insertedPKHashes := make(map[hash.Hash]struct{})
|
||||
opt := InsertOptions{replace}
|
||||
for _, r := range rows {
|
||||
if !row.IsValid(r, tableSch) {
|
||||
if ignore {
|
||||
@@ -106,37 +109,58 @@ func ExecuteInsert(ctx context.Context, db *doltdb.DoltDB, root *doltdb.RootValu
|
||||
continue
|
||||
} else {
|
||||
col, constraint := row.GetInvalidConstraint(r, tableSch)
|
||||
return nil, errFmt(ConstraintFailedFmt, col.Name, constraint)
|
||||
return nil, fmt.Errorf(ConstraintFailedFmt, col.Name, constraint)
|
||||
}
|
||||
}
|
||||
|
||||
key := r.NomsMapKey(tableSch).Value(ctx)
|
||||
|
||||
rowExists := rowData.Get(ctx, key) != nil
|
||||
_, rowInserted := insertedPKHashes[key.Hash(root.VRW().Format())]
|
||||
|
||||
if rowExists || rowInserted {
|
||||
if replace {
|
||||
result.NumRowsUpdated += 1
|
||||
} else if ignore {
|
||||
insertResult, err := batcher.Insert(ctx, tableName, r, opt)
|
||||
if err != nil {
|
||||
if ignore {
|
||||
result.NumErrorsIgnored += 1
|
||||
continue
|
||||
} else {
|
||||
return errInsert("Duplicate primary key: '%v'", getPrimaryKeyString(r, tableSch))
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
me.Set(key, r.NomsMapValue(tableSch))
|
||||
|
||||
insertedPKHashes[key.Hash(root.VRW().Format())] = struct{}{}
|
||||
if insertResult.RowInserted {
|
||||
result.NumRowsInserted++
|
||||
}
|
||||
if insertResult.RowUpdated {
|
||||
result.NumRowsUpdated++
|
||||
}
|
||||
}
|
||||
newMap := me.Map(ctx)
|
||||
table = table.UpdateRows(ctx, newMap)
|
||||
|
||||
result.NumRowsInserted = int(newMap.Len() - rowData.Len())
|
||||
result.Root = root.PutTable(ctx, db, tableName, table)
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ExecuteInsert executes the given select insert statement and returns the result.
|
||||
func ExecuteInsert(
|
||||
ctx context.Context,
|
||||
db *doltdb.DoltDB,
|
||||
root *doltdb.RootValue,
|
||||
s *sqlparser.Insert,
|
||||
) (*InsertResult, error) {
|
||||
|
||||
batcher := NewSqlBatcher(db, root)
|
||||
insertResult, err := ExecuteBatchInsert(ctx, root, s, batcher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRoot, err := batcher.Commit(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &InsertResult{
|
||||
Root: newRoot,
|
||||
NumRowsInserted: insertResult.NumRowsInserted,
|
||||
NumRowsUpdated: insertResult.NumRowsUpdated,
|
||||
NumErrorsIgnored: insertResult.NumErrorsIgnored,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Returns a primary key summary of the row given
|
||||
func getPrimaryKeyString(r row.Row, tableSch schema.Schema) string {
|
||||
var sb strings.Builder
|
||||
|
||||
@@ -315,7 +315,7 @@ func TestExecuteInsert(t *testing.T) {
|
||||
sqlStatement, _ := sqlparser.Parse(tt.query)
|
||||
s := sqlStatement.(*sqlparser.Insert)
|
||||
|
||||
result, err := ExecuteInsert(ctx, dEnv.DoltDB, root, s, tt.query)
|
||||
result, err := ExecuteInsert(ctx, dEnv.DoltDB, root, s)
|
||||
|
||||
if len(tt.expectedErr) > 0 {
|
||||
require.Error(t, err)
|
||||
@@ -334,7 +334,7 @@ func TestExecuteInsert(t *testing.T) {
|
||||
|
||||
for _, expectedRow := range tt.insertedValues {
|
||||
foundRow, ok := table.GetRow(ctx, expectedRow.NomsMapKey(PeopleTestSchema).Value(ctx).(types.Tuple), PeopleTestSchema)
|
||||
assert.True(t, ok, "Row not found: %v", expectedRow)
|
||||
require.True(t, ok, "Row not found: %v", expectedRow)
|
||||
eq, diff := rowsEqual(expectedRow, foundRow)
|
||||
assert.True(t, eq, "Rows not equals, found diff %v", diff)
|
||||
}
|
||||
|
||||
@@ -15,11 +15,13 @@
|
||||
package sqltestutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/dtestutils"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/env"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
|
||||
@@ -46,16 +48,16 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
episodeIdTag = iota
|
||||
epNameTag
|
||||
epAirDateTag
|
||||
epRatingTag
|
||||
EpisodeIdTag = iota
|
||||
EpNameTag
|
||||
EpAirDateTag
|
||||
EpRatingTag
|
||||
)
|
||||
|
||||
const (
|
||||
appCharacterTag = iota
|
||||
appEpTag
|
||||
appCommentsTag
|
||||
AppCharacterTag = iota
|
||||
AppEpTag
|
||||
AppCommentsTag
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -73,11 +75,11 @@ var PeopleTableName = "people"
|
||||
|
||||
var EpisodesTestSchema = createEpisodesTestSchema()
|
||||
var untypedEpisodesSch = untyped.UntypeUnkeySchema(EpisodesTestSchema)
|
||||
var episodesTableName = "episodes"
|
||||
var EpisodesTableName = "episodes"
|
||||
|
||||
var AppearancesTestSchema = createAppearancesTestSchema()
|
||||
var untypedAppearacesSch = untyped.UntypeUnkeySchema(AppearancesTestSchema)
|
||||
var appearancesTableName = "appearances"
|
||||
var AppearancesTableName = "appearances"
|
||||
|
||||
func createPeopleTestSchema() schema.Schema {
|
||||
colColl, _ := schema.NewColCollection(
|
||||
@@ -96,19 +98,19 @@ func createPeopleTestSchema() schema.Schema {
|
||||
|
||||
func createEpisodesTestSchema() schema.Schema {
|
||||
colColl, _ := schema.NewColCollection(
|
||||
schema.NewColumn("id", episodeIdTag, types.IntKind, true, schema.NotNullConstraint{}),
|
||||
schema.NewColumn("name", epNameTag, types.StringKind, false, schema.NotNullConstraint{}),
|
||||
schema.NewColumn("air_date", epAirDateTag, types.IntKind, false),
|
||||
schema.NewColumn("rating", epRatingTag, types.FloatKind, false),
|
||||
schema.NewColumn("id", EpisodeIdTag, types.IntKind, true, schema.NotNullConstraint{}),
|
||||
schema.NewColumn("name", EpNameTag, types.StringKind, false, schema.NotNullConstraint{}),
|
||||
schema.NewColumn("air_date", EpAirDateTag, types.IntKind, false),
|
||||
schema.NewColumn("rating", EpRatingTag, types.FloatKind, false),
|
||||
)
|
||||
return schema.SchemaFromCols(colColl)
|
||||
}
|
||||
|
||||
func createAppearancesTestSchema() schema.Schema {
|
||||
colColl, _ := schema.NewColCollection(
|
||||
schema.NewColumn("character_id", appCharacterTag, types.IntKind, true, schema.NotNullConstraint{}),
|
||||
schema.NewColumn("episode_id", appEpTag, types.IntKind, true, schema.NotNullConstraint{}),
|
||||
schema.NewColumn("comments", appCommentsTag, types.StringKind, false),
|
||||
schema.NewColumn("character_id", AppCharacterTag, types.IntKind, true, schema.NotNullConstraint{}),
|
||||
schema.NewColumn("episode_id", AppEpTag, types.IntKind, true, schema.NotNullConstraint{}),
|
||||
schema.NewColumn("comments", AppCommentsTag, types.StringKind, false),
|
||||
)
|
||||
return schema.SchemaFromCols(colColl)
|
||||
}
|
||||
@@ -128,10 +130,10 @@ func NewPeopleRow(id int, first, last string, isMarried bool, age int, rating fl
|
||||
|
||||
func newEpsRow(id int, name string, airdate int, rating float32) row.Row {
|
||||
vals := row.TaggedValues{
|
||||
episodeIdTag: types.Int(id),
|
||||
epNameTag: types.String(name),
|
||||
epAirDateTag: types.Int(airdate),
|
||||
epRatingTag: types.Float(rating),
|
||||
EpisodeIdTag: types.Int(id),
|
||||
EpNameTag: types.String(name),
|
||||
EpAirDateTag: types.Int(airdate),
|
||||
EpRatingTag: types.Float(rating),
|
||||
}
|
||||
|
||||
return row.New(types.Format_7_18, EpisodesTestSchema, vals)
|
||||
@@ -139,9 +141,9 @@ func newEpsRow(id int, name string, airdate int, rating float32) row.Row {
|
||||
|
||||
func newAppsRow(charId, epId int, comment string) row.Row {
|
||||
vals := row.TaggedValues{
|
||||
appCharacterTag: types.Int(charId),
|
||||
appEpTag: types.Int(epId),
|
||||
appCommentsTag: types.String(comment),
|
||||
AppCharacterTag: types.Int(charId),
|
||||
AppEpTag: types.Int(epId),
|
||||
AppCommentsTag: types.String(comment),
|
||||
}
|
||||
|
||||
return row.New(types.Format_7_18, AppearancesTestSchema, vals)
|
||||
@@ -177,7 +179,7 @@ var Ep1 = newEpsRow(1, "Simpsons Roasting On an Open Fire", 629953200, 8.0)
|
||||
var Ep2 = newEpsRow(2, "Bart the Genius", 632372400, 9.0)
|
||||
var Ep3 = newEpsRow(3, "Homer's Odyssey", 632977200, 7.0)
|
||||
var Ep4 = newEpsRow(4, "There's No Disgrace Like Home", 633582000, 8.5)
|
||||
var allEpsRows = Rs(Ep1, Ep2, Ep3, Ep4)
|
||||
var AllEpsRows = Rs(Ep1, Ep2, Ep3, Ep4)
|
||||
|
||||
// These are made up, not the actual show data
|
||||
var app1 = newAppsRow(homerId, 1, "Homer is great in this one")
|
||||
@@ -266,9 +268,25 @@ func MutateRow(r row.Row, tagsAndVals ...interface{}) row.Row {
|
||||
return mutated
|
||||
}
|
||||
|
||||
func GetAllRows(root *doltdb.RootValue, tableName string) []row.Row {
|
||||
ctx := context.Background()
|
||||
table, _ := root.GetTable(ctx, tableName)
|
||||
rowData := table.GetRowData(ctx)
|
||||
sch := table.GetSchema(ctx)
|
||||
|
||||
var rows []row.Row
|
||||
rowData.Iter(ctx, func(key, value types.Value) (stop bool) {
|
||||
r := row.FromNoms(sch, key.(types.Tuple), value.(types.Tuple))
|
||||
rows = append(rows, r)
|
||||
return false
|
||||
})
|
||||
|
||||
return rows
|
||||
}
|
||||
|
||||
// Creates a test database with the test data set in it
|
||||
func CreateTestDatabase(dEnv *env.DoltEnv, t *testing.T) {
|
||||
dtestutils.CreateTestTable(t, dEnv, PeopleTableName, PeopleTestSchema, AllPeopleRows...)
|
||||
dtestutils.CreateTestTable(t, dEnv, episodesTableName, EpisodesTestSchema, allEpsRows...)
|
||||
dtestutils.CreateTestTable(t, dEnv, appearancesTableName, AppearancesTestSchema, AllAppsRows...)
|
||||
dtestutils.CreateTestTable(t, dEnv, EpisodesTableName, EpisodesTestSchema, AllEpsRows...)
|
||||
dtestutils.CreateTestTable(t, dEnv, AppearancesTableName, AppearancesTestSchema, AllAppsRows...)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -34,6 +34,8 @@ import (
|
||||
// Executes all the SQL non-select statements given in the string against the root value given and returns the updated
|
||||
// root, or an error. Statements in the input string are split by `;\n`
|
||||
func ExecuteSql(dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*doltdb.RootValue, error) {
|
||||
batcher := dsql.NewSqlBatcher(dEnv.DoltDB, root)
|
||||
|
||||
for _, query := range strings.Split(statements, ";\n") {
|
||||
if len(strings.Trim(query, " ")) == 0 {
|
||||
continue
|
||||
@@ -51,23 +53,19 @@ func ExecuteSql(dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*
|
||||
case *sqlparser.Select, *sqlparser.OtherRead:
|
||||
return nil, errors.New("Select statements aren't handled")
|
||||
case *sqlparser.Insert:
|
||||
var result *dsql.InsertResult
|
||||
result, execErr = dsql.ExecuteInsert(context.Background(), dEnv.DoltDB, root, s, query)
|
||||
root = result.Root
|
||||
case *sqlparser.Update:
|
||||
var result *dsql.UpdateResult
|
||||
result, execErr = dsql.ExecuteUpdate(context.Background(), dEnv.DoltDB, root, s, query)
|
||||
root = result.Root
|
||||
case *sqlparser.Delete:
|
||||
var result *dsql.DeleteResult
|
||||
result, execErr = dsql.ExecuteDelete(context.Background(), dEnv.DoltDB, root, s, query)
|
||||
root = result.Root
|
||||
_, execErr = dsql.ExecuteBatchInsert(context.Background(), root, s, batcher)
|
||||
case *sqlparser.DDL:
|
||||
if root, err = batcher.Commit(context.Background()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, execErr = sqlparser.ParseStrictDDL(query)
|
||||
if execErr != nil {
|
||||
return nil, fmt.Errorf("Error parsing DDL: %v.", execErr.Error())
|
||||
}
|
||||
root, execErr = sqlDDL(dEnv, root, s, query)
|
||||
if err := batcher.UpdateRoot(root); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("Unsupported SQL statement: '%v'.", query)
|
||||
}
|
||||
@@ -76,6 +74,13 @@ func ExecuteSql(dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*
|
||||
return nil, execErr
|
||||
}
|
||||
}
|
||||
|
||||
if newRoot, err := batcher.Commit(context.Background()); newRoot != nil {
|
||||
root = newRoot
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return root, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user