Merge pull request #233 from liquidata-inc/zachmu/sql-batch

Implemented batch insert semantics for the new SQL engine.
This commit is contained in:
Zach Musgrave
2019-12-05 14:06:29 -08:00
committed by GitHub
17 changed files with 752 additions and 1352 deletions
+1 -1
View File
@@ -137,7 +137,7 @@ if rows[2] != "9,8,7,6,5,4".split(","):
[ "$output" = "invalid column name c6" ]
run dolt sql -q "insert into test (pk,c1,c2,c3,c4,c5) values (0,6,6,6,6,6)"
[ "$status" -eq 1 ]
[ "$output" = "duplicate primary key given" ] || false
[[ "$output" =~ "duplicate primary key" ]] || false
}
@test "dolt sql insert same column twice" {
+1 -1
View File
@@ -85,7 +85,7 @@ teardown() {
[[ ! "$output" =~ "6" ]] || false
run dolt sql -q "insert into test (pk1,pk2,c1,c2,c3,c4,c5) values (0,1,7,7,7,7,7)"
[ "$status" -eq 1 ]
[ "$output" = "duplicate primary key given" ] || false
[[ "$output" =~ "duplicate primary key" ]] || false
run dolt sql -q "insert into test (pk1,c1,c2,c3,c4,c5) values (0,6,6,6,6,6)"
[ "$status" -eq 1 ]
[ "$output" = "column name 'pk2' is non-nullable but attempted to set default value of null" ] || false
+88 -58
View File
@@ -106,14 +106,14 @@ func Sql(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEn
return HandleVErrAndExitCode(verr, usage)
}
se, err := newSqlEngine(dEnv, root)
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
}
origRoot := se.db.Root()
origRoot := root
// run a single command and exit
if query, ok := apr.GetValue(queryFlag); ok {
se, err := newSqlEngine(dEnv, dsqle.NewDatabase("dolt", root, dEnv))
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
}
if err := processQuery(ctx, query, se); err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
} else if se.db.Root() != origRoot {
@@ -125,8 +125,13 @@ func Sql(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEn
// Run in either batch mode for piped input, or shell mode for interactive
fi, err := os.Stdin.Stat()
var se *sqlEngine
// Windows has a bug where STDIN can't be statted in some cases, see https://github.com/golang/go/issues/33570
if (err != nil && osutil.IsWindows) || (fi.Mode()&os.ModeCharDevice) == 0 {
se, err = newSqlEngine(dEnv, dsqle.NewBatchedDatabase("dolt", root, dEnv))
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
}
err = runBatchMode(ctx, se)
if err != nil {
return 1
@@ -134,7 +139,11 @@ func Sql(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEn
} else if err != nil {
HandleVErrAndExitCode(errhand.BuildDError("Couldn't stat STDIN. This is a bug.").Build(), usage)
} else {
err := runShell(ctx, se)
se, err = newSqlEngine(dEnv, dsqle.NewDatabase("dolt", root, dEnv))
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
}
err = runShell(ctx, se)
if err != nil {
return HandleVErrAndExitCode(errhand.BuildDError("unable to start shell").AddCause(err).Build(), usage)
}
@@ -174,8 +183,6 @@ func runBatchMode(ctx context.Context, se *sqlEngine) error {
scanner.Buffer(buf, maxCapacity)
scanner.Split(scanStatements)
batcher := dsql.NewSqlBatcher(se.dEnv.DoltDB, se.db.Root())
var query string
for scanner.Scan() {
query += scanner.Text()
@@ -184,21 +191,24 @@ func runBatchMode(ctx context.Context, se *sqlEngine) error {
}
if !batchInsertEarlySemicolon(query) {
query += ";"
// TODO: We should fix this problem by properly implementing a state machine for scanStatements
continue
}
if err := processBatchQuery(ctx, query, se, batcher); err != nil {
if err := processBatchQuery(ctx, query, se); err != nil {
_, _ = fmt.Fprintf(cli.CliErr, "Error processing query '%s': %s\n", query, err.Error())
return err
}
query = ""
}
updateBatchInsertOutput()
if err := scanner.Err(); err != nil {
cli.Println(err.Error())
}
if newRoot, _ := batcher.Commit(ctx); newRoot != nil {
se.db.SetRoot(newRoot)
if err := se.db.Flush(ctx); err != nil {
return err
}
return nil
@@ -420,7 +430,10 @@ func prepend(s string, ss []string) []string {
// Processes a single query. The Root of the sqlEngine will be updated if necessary.
func processQuery(ctx context.Context, query string, se *sqlEngine) error {
sqlStatement, err := sqlparser.Parse(query)
if err != nil {
if err == sqlparser.ErrEmpty {
// silently skip empty statements
return nil
} else if err != nil {
return fmt.Errorf("Error parsing SQL: %v.", err.Error())
}
@@ -454,36 +467,89 @@ func processQuery(ctx context.Context, query string, se *sqlEngine) error {
}
}
// stats accumulates counts of rows affected during a batch import session.
type stats struct {
	numRowsInserted  int
	numRowsUpdated   int
	numErrorsIgnored int
}

// batchEditStats is the running tally for the current batch session.
var batchEditStats stats

// displayStrLen is the length of the last progress line printed, so the next
// update can erase it before printing.
var displayStrLen int

// maxBatchSize is the number of inserted rows between Flush calls during batch processing.
const maxBatchSize = 50000

// updateInterval is how often (in rows inserted) the console progress line is refreshed.
const updateInterval = 500
// Processes a single query in batch mode. The Root of the sqlEngine may or may not be changed.
func processBatchQuery(ctx context.Context, query string, se *sqlEngine, batcher *dsql.SqlBatcher) error {
func processBatchQuery(ctx context.Context, query string, se *sqlEngine) error {
sqlStatement, err := sqlparser.Parse(query)
if err != nil {
if err == sqlparser.ErrEmpty {
// silently skip empty statements
return nil
} else if err != nil {
return fmt.Errorf("Error parsing SQL: %v.", err.Error())
}
switch s := sqlStatement.(type) {
switch sqlStatement.(type) {
case *sqlparser.Insert:
return se.insertBatch(ctx, s, batcher)
_, rowIter, err := se.query(ctx, query)
if err != nil {
return fmt.Errorf("Error inserting rows: %v", err.Error())
}
err = mergeInsertResultIntoStats(rowIter, &batchEditStats)
if err != nil {
return fmt.Errorf("Error inserting rows: %v", err.Error())
}
if batchEditStats.numRowsInserted%maxBatchSize == 0 {
err := se.db.Flush(ctx)
if err != nil {
return err
}
}
if batchEditStats.numRowsInserted%updateInterval == 0 {
updateBatchInsertOutput()
}
return nil
default:
// For any other kind of statement, we need to commit whatever batch edit we've accumulated so far before executing
// the query
newRoot, err := batcher.Commit(ctx)
err := se.db.Flush(ctx)
if err != nil {
return err
}
se.db.SetRoot(newRoot)
err = processQuery(ctx, query, se)
if err != nil {
return err
}
if err := batcher.UpdateRoot(se.db.Root()); err != nil {
return err
}
return nil
}
}
// updateBatchInsertOutput rewrites the in-place console progress line with the
// current insert count from batchEditStats. displayStrLen records the printed
// length so DeleteAndPrint can erase the previous line first.
func updateBatchInsertOutput() {
	displayStr := fmt.Sprintf("Rows inserted: %d", batchEditStats.numRowsInserted)
	displayStrLen = cli.DeleteAndPrint(displayStrLen, displayStr)
}
// mergeInsertResultIntoStats drains rowIter, adding each row's reported
// update count into the batch stats. Returns nil on EOF, or the first
// iteration error encountered.
func mergeInsertResultIntoStats(rowIter sql.RowIter, s *stats) error {
	for {
		r, err := rowIter.Next()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}
		s.numRowsInserted += int(r[0].(int64))
	}
}
type sqlEngine struct {
db *dsqle.Database
dEnv *env.DoltEnv
@@ -491,8 +557,7 @@ type sqlEngine struct {
}
// sqlEngine packages up the context necessary to run sql queries against sqle.
func newSqlEngine(dEnv *env.DoltEnv, root *doltdb.RootValue) (*sqlEngine, error) {
db := dsqle.NewDatabase("dolt", root, dEnv)
func newSqlEngine(dEnv *env.DoltEnv, db *dsqle.Database) (*sqlEngine, error) {
engine := sqle.NewDefault()
engine.AddDatabase(db)
@@ -663,41 +728,6 @@ func runPrintingPipeline(ctx context.Context, nbf *types.NomsBinFormat, p *pipel
return nil
}
type stats struct {
numRowsInserted int
numRowsUpdated int
numErrorsIgnored int
}
var batchEditStats stats
var displayStrLen int
// Executes a SQL insert statement in batch mode: edits accumulate in the
// batcher and are not persisted until the batcher is committed. If the insert
// produced a new root value, sqlEngine's root will be updated. A running
// progress line (rows inserted / updated / errors) is printed to the console.
func (se *sqlEngine) insertBatch(ctx context.Context, stmt *sqlparser.Insert, batcher *dsql.SqlBatcher) error {
	result, err := dsql.ExecuteBatchInsert(ctx, se.db.Root(), stmt, batcher)
	if err != nil {
		return fmt.Errorf("Error inserting rows: %v", err.Error())
	}
	mergeResultIntoStats(result, &batchEditStats)
	displayStr := fmt.Sprintf("Rows inserted: %d, Updated: %d, Errors: %d",
		batchEditStats.numRowsInserted, batchEditStats.numRowsUpdated, batchEditStats.numErrorsIgnored)
	displayStrLen = cli.DeleteAndPrint(displayStrLen, displayStr)
	// NOTE(review): ExecuteBatchInsert is documented not to set Root; this
	// guard looks defensive — confirm whether it can ever fire.
	if result.Root != nil {
		se.db.SetRoot(result.Root)
	}
	return nil
}
// mergeResultIntoStats adds the counts from a single insert result into the
// running batch statistics.
func mergeResultIntoStats(result *dsql.InsertResult, stats *stats) {
	stats.numRowsInserted += result.NumRowsInserted
	stats.numRowsUpdated += result.NumRowsUpdated
	stats.numErrorsIgnored += result.NumErrorsIgnored
}
// Checks if the query is a naked delete and then deletes all rows if so. Returns true if it did so, false otherwise.
func (se *sqlEngine) checkThenDeleteAllRows(ctx context.Context, s *sqlparser.Delete) bool {
if s.Where == nil && s.Limit == nil && s.Partitions == nil && len(s.TableExprs) == 1 {
-19
View File
@@ -101,25 +101,6 @@ func TestSqlShow(t *testing.T) {
}
}
func TestBadInput(t *testing.T) {
tests := []struct {
name string
args []string
expectedRes int
}{
{"no query", []string{"-q", ""}, 1},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
dEnv := createEnvWithSeedData(t)
commandStr := "dolt sql"
result := Sql(context.TODO(), commandStr, test.args, dEnv)
assert.Equal(t, test.expectedRes, result)
})
}
}
// Tests of the create table SQL command, mostly a smoke test for errors in the command line handler. Most tests of
// create table SQL command are in the sql package.
func TestCreateTable(t *testing.T) {
+1 -1
View File
@@ -59,7 +59,7 @@ require (
replace github.com/liquidata-inc/dolt/go/gen/proto/dolt/services/eventsapi => ./gen/proto/dolt/services/eventsapi
replace github.com/src-d/go-mysql-server => github.com/liquidata-inc/go-mysql-server v0.4.1-0.20191125230736-ebf9aee2fe67
replace github.com/src-d/go-mysql-server => github.com/liquidata-inc/go-mysql-server v0.4.1-0.20191204003907-576e8690465b
replace vitess.io/vitess => github.com/liquidata-inc/vitess v0.0.0-20191125220844-c7181e70ee98
+2 -6
View File
@@ -203,8 +203,8 @@ github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFB
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/liquidata-inc/go-mysql-server v0.4.1-0.20191125230736-ebf9aee2fe67 h1:8ZDcYhWV0/yC9ATVsxJXbaFPsR4s2b2VSb83vTYobGs=
github.com/liquidata-inc/go-mysql-server v0.4.1-0.20191125230736-ebf9aee2fe67/go.mod h1:hzni/QFitaMZ9H7hbZ3K+C4BEl75iDKqTIY+Hbvj6VQ=
github.com/liquidata-inc/go-mysql-server v0.4.1-0.20191204003907-576e8690465b h1:LaTfDUa9JSuS4sHaDqfHgwFitInog3+wYlWOViNrfyw=
github.com/liquidata-inc/go-mysql-server v0.4.1-0.20191204003907-576e8690465b/go.mod h1:hzni/QFitaMZ9H7hbZ3K+C4BEl75iDKqTIY+Hbvj6VQ=
github.com/liquidata-inc/ishell v0.0.0-20190514193646-693241f1f2a0 h1:phMgajKClMUiIr+hF2LGt8KRuUa2Vd2GI1sNgHgSXoU=
github.com/liquidata-inc/ishell v0.0.0-20190514193646-693241f1f2a0/go.mod h1:YC1rI9k5gx8D02ljlbxDfZe80s/iq8bGvaaQsvR+qxs=
github.com/liquidata-inc/mmap-go v1.0.3 h1:2LndAeAtup9rpvUmu4wZSYCsjCQ0Zpc+NqE+6+PnT7g=
@@ -244,7 +244,6 @@ github.com/olekukonko/tablewriter v0.0.0-20160115111002-cca8bbc07984/go.mod h1:v
github.com/oliveagle/jsonpath v0.0.0-20180606110733-2e52cf6e6852 h1:Yl0tPBa8QPjGmesFh1D0rDy+q1Twx6FyU7VWHi8wZbI=
github.com/oliveagle/jsonpath v0.0.0-20180606110733-2e52cf6e6852/go.mod h1:eqOVx5Vwu4gd2mmMZvVZsgIqNSaW3xxRThUJ0k/TPk4=
github.com/opentracing-contrib/go-grpc v0.0.0-20180928155321-4b5a12d3ff02/go.mod h1:JNdpVEzCpXBgIiv4ds+TzhN1hrtxq6ClLrTlT9OQRSc=
github.com/opentracing/opentracing-go v1.0.2 h1:3jA2P6O1F9UOrWVpwrIo17pu01KWvNWg4X946/Y5Zwg=
github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU=
github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
@@ -431,7 +430,6 @@ golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBn
golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0 h1:Dh6fw+p6FyRl5x/FvNswO1ji0lIGzm3KP8Y9VkS9PTE=
golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
@@ -442,7 +440,6 @@ google.golang.org/api v0.13.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsb
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/appengine v1.6.1 h1:QzqyMA1tlu6CgqCDUtU9V+ZKhLFT2dkJuANu5QaxI3I=
google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0=
google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
@@ -458,7 +455,6 @@ google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZi
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.24.0 h1:vb/1TCsVn3DcJlQ0Gs1yB1pKI6Do2/QNwxdKqmc/b0s=
google.golang.org/grpc v1.24.0/go.mod h1:XDChyiUovWa60DnaeDeZmSW86xtLtjtZbwvSiRnRtcA=
google.golang.org/grpc v1.25.1 h1:wdKvqQk7IttEw92GoRyKG2IDrUIpgpj6H6m81yfeMW0=
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
-272
View File
@@ -1,272 +0,0 @@
// Copyright 2019 Liquidata, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sql
import (
"context"
"errors"
"fmt"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema"
"github.com/liquidata-inc/dolt/go/store/hash"
"github.com/liquidata-inc/dolt/go/store/types"
)
// SqlBatcher knows how to efficiently batch insert / update statements, e.g. when doing a SQL import. It does this by
// using a single MapEditor per table that isn't persisted until Commit is called.
// All caches below are keyed by table name. The struct holds no synchronization,
// so it is not safe for concurrent use.
type SqlBatcher struct {
	// The database we are editing
	db *doltdb.DoltDB
	// The root value we are editing
	root *doltdb.RootValue
	// The set of tables under edit
	tables map[string]*doltdb.Table
	// The schemas of tables under edit
	schemas map[string]schema.Schema
	// The row data for tables being edited
	rowData map[string]types.Map
	// The editors applying updates to the tables
	editors map[string]*types.MapEditor
	// The hashes of primary keys being inserted to the tables
	hashes map[string]map[hash.Hash]bool
}
// NewSqlBatcher returns a SqlBatcher that batches edits against the given
// database and root value, with all per-table caches freshly initialized.
func NewSqlBatcher(db *doltdb.DoltDB, root *doltdb.RootValue) *SqlBatcher {
	b := &SqlBatcher{db: db, root: root}
	b.resetState()
	return b
}
// Updates this batcher with a new root value. If there are outstanding edits, returns an error
// and leaves the root unchanged.
func (b *SqlBatcher) UpdateRoot(root *doltdb.RootValue) error {
	if b.isDirty() {
		return errors.New("UpdateRoot called with outstanding edits")
	}
	b.root = root
	// resetting the state shouldn't be necessary here because of the isDirty check, but if the client chooses to ignore
	// the returned error, we'll at least have a clean state going forward
	b.resetState()
	return nil
}
// isDirty returns whether there are outstanding edits that haven't been committed.
// Editors are created lazily only when an edit is staged, so a non-empty
// editors map implies uncommitted changes.
func (b *SqlBatcher) isDirty() bool {
	return len(b.editors) > 0
}
// resetState flushes the cache of outstanding edits and other data (tables,
// schemas, row data, editors, and touched-key hashes), leaving the batcher
// ready to accumulate a fresh batch against the current root.
func (b *SqlBatcher) resetState() {
	b.tables = make(map[string]*doltdb.Table)
	b.schemas = make(map[string]schema.Schema)
	b.rowData = make(map[string]types.Map)
	b.editors = make(map[string]*types.MapEditor)
	b.hashes = make(map[string]map[hash.Hash]bool)
}
// InsertOptions controls how Insert treats rows whose primary key already exists.
type InsertOptions struct {
	// Whether to silently replace any existing rows with the same primary key as rows inserted
	Replace bool
}
// BatchInsertResult reports the effect of a single batched row insert.
type BatchInsertResult struct {
	// RowInserted is true when the key was not already present in the table's row data.
	RowInserted bool
	// RowUpdated is true when the key existed, or was touched earlier in this batch.
	RowUpdated bool
}
// Insert stages a single row insert for the named table. The edit is buffered
// in the table's MapEditor and not persisted until Commit is called. Returns
// whether the row was newly inserted or replaced an existing / previously
// batched row. A duplicate primary key is an error unless opt.Replace is set.
func (b *SqlBatcher) Insert(ctx context.Context, tableName string, r row.Row, opt InsertOptions) (*BatchInsertResult, error) {
	sch, err := b.GetSchema(ctx, tableName)
	if err != nil {
		return nil, err
	}

	rowData, err := b.getRowData(ctx, tableName)
	if err != nil {
		return nil, err
	}

	ed, err := b.getEditor(ctx, tableName)
	if err != nil {
		return nil, err
	}

	key, err := r.NomsMapKey(sch).Value(ctx)
	if err != nil {
		return nil, err
	}

	// A row is a duplicate if its key is already in the table's persisted data...
	keyVal, _, err := rowData.MaybeGet(ctx, key)
	if err != nil {
		return nil, err
	}
	rowExists := keyVal != nil

	// ...or if an earlier statement in this batch already touched the same key.
	hashes := b.getHashes(ctx, tableName)
	h, err := key.Hash(b.root.VRW().Format())
	if err != nil {
		return nil, err
	}
	rowAlreadyTouched := hashes[h]

	if (rowExists || rowAlreadyTouched) && !opt.Replace {
		return nil, fmt.Errorf("Duplicate primary key: '%v'", getPrimaryKeyString(r, sch))
	}

	ed.Set(key, r.NomsMapValue(sch))
	// Reuse the key hash computed above; the original recomputed it here for no reason.
	hashes[h] = true

	return &BatchInsertResult{RowInserted: !rowExists, RowUpdated: rowExists || rowAlreadyTouched}, nil
}
// GetTable returns the table with the name given, caching the result because
// reading the table from the root value is relatively expensive.
func (b *SqlBatcher) GetTable(ctx context.Context, tableName string) (*doltdb.Table, error) {
	if cached, ok := b.tables[tableName]; ok {
		return cached, nil
	}

	has, err := b.root.HasTable(ctx, tableName)
	if err != nil {
		return nil, err
	}
	if !has {
		return nil, fmt.Errorf("Unknown table %v", tableName)
	}

	table, _, err := b.root.GetTable(ctx, tableName)
	if err != nil {
		return nil, err
	}

	b.tables[tableName] = table
	return table, nil
}
// GetSchema returns the schema for the table name given, caching the result
// because reading the schema from disk is relatively expensive.
func (b *SqlBatcher) GetSchema(ctx context.Context, tableName string) (schema.Schema, error) {
	// Renamed from `schema` to avoid shadowing the imported schema package.
	if cached, ok := b.schemas[tableName]; ok {
		return cached, nil
	}

	table, err := b.GetTable(ctx, tableName)
	if err != nil {
		return nil, err
	}

	sch, err := table.GetSchema(ctx)
	if err != nil {
		return nil, err
	}

	b.schemas[tableName] = sch
	return sch, nil
}
// getEditor returns the MapEditor accumulating edits for the named table,
// creating one from the table's current row data on first use.
func (b *SqlBatcher) getEditor(ctx context.Context, tableName string) (*types.MapEditor, error) {
	if existing, ok := b.editors[tableName]; ok {
		return existing, nil
	}

	rowData, err := b.getRowData(ctx, tableName)
	if err != nil {
		return nil, err
	}

	editor := rowData.Edit()
	b.editors[tableName] = editor
	return editor, nil
}
// getRowData returns the row data map for the named table, caching the result
// to avoid re-reading it from the table on subsequent calls.
func (b *SqlBatcher) getRowData(ctx context.Context, tableName string) (types.Map, error) {
	if rowData, ok := b.rowData[tableName]; ok {
		return rowData, nil
	}

	table, err := b.GetTable(ctx, tableName)
	if err != nil {
		return types.EmptyMap, err
	}

	rowData, err := table.GetRowData(ctx)
	if err != nil {
		return types.EmptyMap, err
	}

	b.rowData[tableName] = rowData
	return rowData, nil
}
// getHashes returns the set of primary-key hashes already touched in this
// batch for the named table, lazily creating an empty set on first use.
// ctx is unused but kept for symmetry with the other per-table accessors.
func (b *SqlBatcher) getHashes(ctx context.Context, tableName string) map[hash.Hash]bool {
	set, ok := b.hashes[tableName]
	if !ok {
		set = make(map[hash.Hash]bool)
		b.hashes[tableName] = set
	}
	return set
}
// Commit writes a new root value for every table under edit and returns the new root value. Tables are written in an
// arbitrary order. On success the batcher's root is advanced to the returned
// value and all cached state is reset, so it can continue batching new edits.
func (b *SqlBatcher) Commit(ctx context.Context) (*doltdb.RootValue, error) {
	root := b.root
	for tableName, ed := range b.editors {
		// Materialize the accumulated edits into a new row map.
		newMap, err := ed.Map(ctx)
		if err != nil {
			return nil, err
		}

		table := b.tables[tableName]
		table, err = table.UpdateRows(ctx, newMap)
		if err != nil {
			return nil, err
		}

		root, err = root.PutTable(ctx, tableName, table)
		if err != nil {
			return nil, err
		}
	}

	b.root = root
	b.resetState()
	return root, nil
}
-331
View File
@@ -1,331 +0,0 @@
// Copyright 2019 Liquidata, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sql
import (
"context"
"errors"
"fmt"
"strings"
"vitess.io/vitess/go/vt/sqlparser"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema"
"github.com/liquidata-inc/dolt/go/store/types"
)
// InsertResult describes the outcome of executing an insert statement.
type InsertResult struct {
	// Root is the new root value after the insert. It is only set by functions
	// that commit their edits; batch execution leaves it nil.
	Root             *doltdb.RootValue
	NumRowsInserted  int
	NumRowsUpdated   int
	NumErrorsIgnored int
}

// ErrMissingPrimaryKeys is returned when an insert's column list omits one or more primary key columns.
var ErrMissingPrimaryKeys = errors.New("One or more primary key columns missing from insert statement")

// ConstraintFailedFmt is the format string for column constraint violation errors: column name, then constraint.
var ConstraintFailedFmt = "Constraint failed for column '%v': %v"
// ExecuteBatchInsert executes the given insert statement in batch mode and returns the result. The table is not changed
// until the batch is Committed. The InsertResult returned similarly doesn't have a Root set, since the root isn't
// modified by this function.
// (Comment previously misnamed the function "ExecuteInsertBatch".)
func ExecuteBatchInsert(
	ctx context.Context,
	root *doltdb.RootValue,
	s *sqlparser.Insert,
	batcher *SqlBatcher,
) (*InsertResult, error) {
	tableName := s.Table.Name.String()

	tableSch, err := batcher.GetSchema(ctx, tableName)
	if err != nil {
		return nil, err
	}

	// Parser supports overwrite on insert with both the replace keyword (from MySQL) as well as the ignore keyword
	replace := s.Action == sqlparser.ReplaceStr
	ignore := s.Ignore != ""

	// Get the list of columns to insert into. We support both naked inserts (no column list specified) as well as
	// explicit column lists.
	var cols []schema.Column
	if len(s.Columns) == 0 {
		cols = tableSch.GetAllCols().GetColumns()
	} else {
		cols = make([]schema.Column, len(s.Columns))
		for i, colName := range s.Columns {
			name := colName.String()
			// Only entries before index i have been filled in; ranging over the
			// whole slice (as before) compared against zero-valued columns too.
			for _, c := range cols[:i] {
				if c.Name == name {
					return nil, fmt.Errorf("Repeated column: '%v'", c.Name)
				}
			}

			col, ok := tableSch.GetAllCols().GetByName(name)
			if !ok {
				return nil, fmt.Errorf(UnknownColumnErrFmt, colName)
			}
			cols[i] = col
		}
	}

	var rows []row.Row // your boat
	switch queryRows := s.Rows.(type) {
	case sqlparser.Values:
		var err error
		rows, err = prepareInsertVals(root.VRW().Format(), cols, &queryRows, tableSch)
		if err != nil {
			return nil, err
		}
	case *sqlparser.Select:
		return nil, fmt.Errorf("Insert as select not supported")
	case *sqlparser.ParenSelect:
		return nil, fmt.Errorf("Parenthesized select expressions in insert not supported")
	case *sqlparser.Union:
		return nil, fmt.Errorf("Union not supported")
	default:
		return nil, fmt.Errorf("Unrecognized type for insert: %v", queryRows)
	}

	// Perform the insert
	var result InsertResult
	opt := InsertOptions{replace}
	for _, r := range rows {
		if has, err := row.IsValid(r, tableSch); err != nil {
			return nil, err
		} else if !has {
			if ignore {
				result.NumErrorsIgnored += 1
				continue
			} else {
				col, constraint, err := row.GetInvalidConstraint(r, tableSch)
				if err != nil {
					return nil, fmt.Errorf(ConstraintFailedFmt, "unknown", "unknown")
				}
				return nil, fmt.Errorf(ConstraintFailedFmt, col.Name, constraint)
			}
		}

		insertResult, err := batcher.Insert(ctx, tableName, r, opt)
		if err != nil {
			if ignore {
				result.NumErrorsIgnored += 1
				continue
			} else {
				return nil, err
			}
		}

		if insertResult.RowInserted {
			result.NumRowsInserted++
		}
		if insertResult.RowUpdated {
			result.NumRowsUpdated++
		}
	}

	return &result, nil
}
// ExecuteInsert executes the given select insert statement and returns the result.
// This is the non-batched form: a fresh SqlBatcher is created, the insert is
// applied, and the batch is committed immediately, so the returned result
// carries the new root value.
func ExecuteInsert(
	ctx context.Context,
	db *doltdb.DoltDB,
	root *doltdb.RootValue,
	s *sqlparser.Insert,
) (*InsertResult, error) {
	batcher := NewSqlBatcher(db, root)
	insertResult, err := ExecuteBatchInsert(ctx, root, s, batcher)
	if err != nil {
		return nil, err
	}

	newRoot, err := batcher.Commit(ctx)
	if err != nil {
		return nil, err
	}

	return &InsertResult{
		Root:             newRoot,
		NumRowsInserted:  insertResult.NumRowsInserted,
		NumRowsUpdated:   insertResult.NumRowsUpdated,
		NumErrorsIgnored: insertResult.NumErrorsIgnored,
	}, nil
}
// Returns a primary key summary of the row given, formatted as
// "col1: val1, col2: val2" (missing values render as "null"). Used in
// duplicate-primary-key error messages.
func getPrimaryKeyString(r row.Row, tableSch schema.Schema) string {
	var sb strings.Builder
	first := true
	err := tableSch.GetPKCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
		if !first {
			sb.WriteString(", ")
		}

		sb.WriteString(col.Name)
		sb.WriteString(": ")
		val, ok := r.GetColVal(tag)
		if ok {
			sb.WriteString(fmt.Sprintf("%v", val))
		} else {
			sb.WriteString("null")
		}

		first = false
		return false, nil
	})

	// TODO: fix panics
	if err != nil {
		panic(err)
	}
	return sb.String()
}
// prepareInsertVals converts the parsed VALUES clause into rows to insert.
// It first verifies that every primary key column appears in the insert
// column list, since a row missing any part of its primary key can never be
// valid; that failure is reported up front as ErrMissingPrimaryKeys.
func prepareInsertVals(nbf *types.NomsBinFormat, cols []schema.Column, values *sqlparser.Values, tableSch schema.Schema) ([]row.Row, error) {
	missingKey := false
	err := tableSch.GetPKCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
		for _, insertCol := range cols {
			if insertCol.Tag == tag {
				return false, nil
			}
		}
		missingKey = true
		return true, nil
	})
	if err != nil {
		return nil, err
	}
	if missingKey {
		return nil, ErrMissingPrimaryKeys
	}

	rows := make([]row.Row, len(*values))
	for i, valTuple := range *values {
		r, err := makeRow(nbf, cols, tableSch, valTuple)
		if err != nil {
			return nil, err
		}
		rows[i] = r
	}
	return rows, nil
}
// makeRow builds a single row from one parenthesized value tuple of an INSERT
// statement. columns gives the target column for each position in the tuple.
// Literal values (SQLVal, BoolVal, unary expressions) are converted to noms
// values; NULL leaves the corresponding tagged value unset; every other
// expression kind produces a descriptive error.
func makeRow(nbf *types.NomsBinFormat, columns []schema.Column, tableSch schema.Schema, tuple sqlparser.ValTuple) (row.Row, error) {
	if len(columns) != len(tuple) {
		return errInsertRow("Wrong number of values for tuple %v", nodeToString(tuple))
	}

	taggedVals := make(row.TaggedValues)
	for i, expr := range tuple {
		column := columns[i]
		switch val := expr.(type) {
		case *sqlparser.SQLVal:
			nomsVal, err := extractNomsValueFromSQLVal(val, column.Kind)
			if err != nil {
				return nil, err
			}
			taggedVals[column.Tag] = nomsVal
		case *sqlparser.NullVal:
			// nothing to do, just don't set a tagged value for this column
		case sqlparser.BoolVal:
			if column.Kind != types.BoolKind {
				return errInsertRow("Type mismatch: boolean value but non-boolean column: %v", nodeToString(val))
			}
			taggedVals[column.Tag] = types.Bool(val)
		case *sqlparser.UnaryExpr:
			// handles negative numeric literals, e.g. -5
			nomsVal, err := extractNomsValueFromUnaryExpr(val, column.Kind)
			if err != nil {
				return nil, err
			}
			taggedVals[column.Tag] = nomsVal
		// Many of these shouldn't be possible in the grammar, but all cases included for completeness
		case *sqlparser.ComparisonExpr:
			return errInsertRow("Comparison expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.AndExpr:
			return errInsertRow("And expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.OrExpr:
			return errInsertRow("Or expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.NotExpr:
			return errInsertRow("Not expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.ParenExpr:
			return errInsertRow("Parenthetical expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.RangeCond:
			return errInsertRow("Range expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.IsExpr:
			return errInsertRow("Is expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.ExistsExpr:
			return errInsertRow("Exists expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.ColName:
			// unquoted strings are interpreted by the parser as column names, give a hint
			return errInsertRow("Column name expressions not supported in insert values. Did you forget to quote a string? %v", nodeToString(tuple))
		case sqlparser.ValTuple:
			return errInsertRow("Tuple expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.Subquery:
			return errInsertRow("Subquery expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.ListArg:
			return errInsertRow("List expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.BinaryExpr:
			return errInsertRow("Binary expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.IntervalExpr:
			return errInsertRow("Interval expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.CollateExpr:
			return errInsertRow("Collate expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.FuncExpr:
			return errInsertRow("Function expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.CaseExpr:
			return errInsertRow("Case expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.ValuesFuncExpr:
			return errInsertRow("Values func expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.ConvertExpr:
			return errInsertRow("Conversion expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.SubstrExpr:
			return errInsertRow("Substr expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.ConvertUsingExpr:
			return errInsertRow("Convert expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.MatchExpr:
			return errInsertRow("Match expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.GroupConcatExpr:
			return errInsertRow("Group concat expressions not supported in insert values: %v", nodeToString(tuple))
		case *sqlparser.Default:
			return errInsertRow("Unrecognized expression: %v", nodeToString(tuple))
		default:
			return errInsertRow("Unrecognized expression: %v", nodeToString(tuple))
		}
	}

	return row.New(nbf, tableSch, taggedVals)
}
// Returns an error result with return type to match ExecuteInsert.
// fmt.Errorf replaces the redundant errors.New(fmt.Sprintf(...)) idiom.
func errInsert(errorFmt string, args ...interface{}) (*InsertResult, error) {
	return nil, fmt.Errorf(errorFmt, args...)
}
// Returns an error result with return type to match makeRow and other
// row-producing helpers. fmt.Errorf replaces errors.New(fmt.Sprintf(...)).
func errInsertRow(errorFmt string, args ...interface{}) (row.Row, error) {
	return nil, fmt.Errorf(errorFmt, args...)
}
-384
View File
@@ -1,384 +0,0 @@
// Copyright 2019 Liquidata, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sql
import (
"context"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"vitess.io/vitess/go/vt/sqlparser"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/dtestutils"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
. "github.com/liquidata-inc/dolt/go/libraries/doltcore/sql/sqltestutil"
"github.com/liquidata-inc/dolt/go/store/types"
)
// mustRow unwraps a (row.Row, error) pair, panicking on a non-nil error.
// Test-only helper for building fixture rows inline.
func mustRow(r row.Row, err error) row.Row {
	if err == nil {
		return r
	}
	panic(err)
}
// TestExecuteInsert exercises the legacy SQL engine's insert path: single and
// multi-row inserts, column ordering, null and not-null constraint handling,
// type mismatches for every column type, and the ignore / replace variants.
func TestExecuteInsert(t *testing.T) {
	tests := []struct {
		name           string
		query          string
		insertedValues []row.Row
		expectedResult InsertResult // root is not compared, but it used for other assertions
		expectedErr    string
	}{
		{
			name: "insert one row, all columns",
			query: `insert into people (id, first, last, is_married, age, rating, uuid, num_episodes) values
					(7, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000005', 677)`,
			insertedValues: []row.Row{NewPeopleRowWithOptionalFields(7, "Maggie", "Simpson", false, 1, 5.1, uuid.MustParse("00000000-0000-0000-0000-000000000005"), 677)},
			expectedResult: InsertResult{NumRowsInserted: 1},
		},
		{
			name: "insert one row, all columns, negative values",
			query: `insert into people (id, first, last, is_married, age, rating, uuid, num_episodes) values
					(-7, "Maggie", "Simpson", false, -1, -5.1, '00000000-0000-0000-0000-000000000005', 677)`,
			insertedValues: []row.Row{NewPeopleRowWithOptionalFields(-7, "Maggie", "Simpson", false, -1, -5.1, uuid.MustParse("00000000-0000-0000-0000-000000000005"), 677)},
			expectedResult: InsertResult{NumRowsInserted: 1},
		},
		{
			name: "insert one row, no column list",
			query: `insert into people values
					(7, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000005', 677)`,
			insertedValues: []row.Row{NewPeopleRowWithOptionalFields(7, "Maggie", "Simpson", false, 1, 5.1, uuid.MustParse("00000000-0000-0000-0000-000000000005"), 677)},
			expectedResult: InsertResult{NumRowsInserted: 1},
		},
		{
			name: "insert one row out of order",
			query: `insert into people (rating, first, id, last, age, is_married) values
					(5.1, "Maggie", 7, "Simpson", 1, false)`,
			insertedValues: []row.Row{NewPeopleRow(7, "Maggie", "Simpson", false, 1, 5.1)},
			expectedResult: InsertResult{NumRowsInserted: 1},
		},
		{
			name: "insert one row, null values",
			query: `insert into people (id, first, last, is_married, age, rating) values
					(7, "Maggie", "Simpson", null, null, null)`,
			insertedValues: []row.Row{mustRow(row.New(types.Format_7_18, PeopleTestSchema, row.TaggedValues{IdTag: types.Int(7), FirstTag: types.String("Maggie"), LastTag: types.String("Simpson")}))},
			expectedResult: InsertResult{NumRowsInserted: 1},
		},
		{
			name: "insert one row, null constraint failure",
			query: `insert into people (id, first, last, is_married, age, rating) values
					(7, "Maggie", null, null, null, null)`,
			expectedErr: "Constraint failed for column 'last': Not null",
		},
		{
			name: "duplicate column list",
			query: `insert into people (id, first, last, is_married, first, age, rating) values
					(7, "Maggie", "Simpson", null, null, null, null)`,
			expectedErr: "Repeated column: 'first'",
		},
		{
			name: "insert two rows, all columns",
			query: `insert into people (id, first, last, is_married, age, rating) values
					(7, "Maggie", "Simpson", false, 1, 5.1),
					(8, "Milhouse", "Van Houten", false, 8, 3.5)`,
			insertedValues: []row.Row{
				NewPeopleRow(7, "Maggie", "Simpson", false, 1, 5.1),
				NewPeopleRow(8, "Milhouse", "Van Houten", false, 8, 3.5),
			},
			expectedResult: InsertResult{NumRowsInserted: 2},
		},
		{
			name: "insert two rows, one with null constraint failure",
			query: `insert into people (id, first, last, is_married, age, rating) values
					(7, "Maggie", "Simpson", false, 1, 5.1),
					(8, "Milhouse", null, false, 8, 3.5)`,
			expectedErr: "Constraint failed for column 'last': Not null",
		},
		{
			name: "type mismatch int -> string",
			query: `insert into people (id, first, last, is_married, age, rating) values
					(7, "Maggie", 100, false, 1, 5.1)`,
			expectedErr: "Type mismatch",
		},
		{
			name: "type mismatch int -> bool",
			query: `insert into people (id, first, last, is_married, age, rating) values
					(7, "Maggie", "Simpson", 10, 1, 5.1)`,
			expectedErr: "Type mismatch",
		},
		{
			name: "type mismatch int -> uuid",
			query: `insert into people (id, first, last, is_married, age, uuid) values
					(7, "Maggie", "Simpson", false, 1, 100)`,
			expectedErr: "Type mismatch",
		},
		{
			name: "type mismatch string -> int",
			query: `insert into people (id, first, last, is_married, age, rating) values
					("7", "Maggie", "Simpson", false, 1, 5.1)`,
			expectedErr: "Type mismatch",
		},
		{
			name: "type mismatch string -> float",
			query: `insert into people (id, first, last, is_married, age, rating) values
					(7, "Maggie", "Simpson", false, 1, "5.1")`,
			expectedErr: "Type mismatch",
		},
		{
			name: "type mismatch string -> uint",
			query: `insert into people (id, first, last, is_married, age, num_episodes) values
					(7, "Maggie", "Simpson", false, 1, "100")`,
			expectedErr: "Type mismatch",
		},
		{
			name: "type mismatch string -> uuid",
			query: `insert into people (id, first, last, is_married, age, uuid) values
					(7, "Maggie", "Simpson", false, 1, "a uuid but idk what im doing")`,
			expectedErr: "Type mismatch",
		},
		{
			name: "type mismatch float -> string",
			query: `insert into people (id, first, last, is_married, age, rating) values
					(7, 8.1, "Simpson", false, 1, 5.1)`,
			expectedErr: "Type mismatch",
		},
		{
			name: "type mismatch float -> bool",
			query: `insert into people (id, first, last, is_married, age, rating) values
					(7, "Maggie", "Simpson", 0.5, 1, 5.1)`,
			expectedErr: "Type mismatch",
		},
		{
			name: "type mismatch float -> int",
			query: `insert into people (id, first, last, is_married, age, rating) values
					(7, "Maggie", "Simpson", false, 1.0, 5.1)`,
			expectedErr: "Type mismatch",
		},
		{
			name: "type mismatch bool -> int",
			query: `insert into people (id, first, last, is_married, age, rating) values
					(true, "Maggie", "Simpson", false, 1, 5.1)`,
			expectedErr: "Type mismatch",
		},
		{
			name: "type mismatch bool -> float",
			query: `insert into people (id, first, last, is_married, age, rating) values
					(7, "Maggie", "Simpson", false, 1, true)`,
			expectedErr: "Type mismatch",
		},
		{
			name: "type mismatch bool -> string",
			query: `insert into people (id, first, last, is_married, age, rating) values
					(7, true, "Simpson", false, 1, 5.1)`,
			expectedErr: "Type mismatch",
		},
		{
			name: "type mismatch bool -> uuid",
			query: `insert into people (id, first, last, is_married, age, uuid) values
					(7, "Maggie", "Simpson", false, 1, true)`,
			expectedErr: "Type mismatch",
		},
		{
			name: "insert two rows with ignore, one with null constraint failure",
			query: `insert ignore into people (id, first, last, is_married, age, rating) values
					(7, "Maggie", null, false, 1, 5.1),
					(8, "Milhouse", "Van Houten", false, 8, 3.5)`,
			insertedValues: []row.Row{
				NewPeopleRow(8, "Milhouse", "Van Houten", false, 8, 3.5),
			},
			expectedResult: InsertResult{NumRowsInserted: 1, NumErrorsIgnored: 1},
		},
		{
			name: "insert existing rows without ignore / replace",
			query: `insert into people (id, first, last, is_married, age, rating) values
					(0, "Homer", "Simpson", true, 45, 100)`,
			expectedErr: "Duplicate primary key: 'id: 0'",
		},
		{
			name: "insert two rows with ignore, one existing in table",
			query: `insert ignore into people (id, first, last, is_married, age, rating) values
					(0, "Homer", "Simpson", true, 45, 100),
					(8, "Milhouse", "Van Houten", false, 8, 3.5)`,
			insertedValues: []row.Row{
				NewPeopleRow(8, "Milhouse", "Van Houten", false, 8, 3.5),
				Homer, // verify that homer is unchanged by the insert
			},
			expectedResult: InsertResult{NumRowsInserted: 1, NumErrorsIgnored: 1},
		},
		{
			name: "insert two rows with replace, one existing in table",
			query: `replace into people (id, first, last, is_married, age, rating) values
					(0, "Homer", "Simpson", true, 45, 100),
					(8, "Milhouse", "Van Houten", false, 8, 3.5)`,
			insertedValues: []row.Row{
				NewPeopleRow(0, "Homer", "Simpson", true, 45, 100),
				NewPeopleRow(8, "Milhouse", "Van Houten", false, 8, 3.5),
			},
			expectedResult: InsertResult{NumRowsInserted: 1, NumRowsUpdated: 1},
		},
		{
			name: "insert two rows with replace ignore, one with errors",
			query: `replace ignore into people (id, first, last, is_married, age, rating) values
					(0, "Homer", "Simpson", true, 45, 100),
					(8, "Milhouse", "Van Houten", false, 8, 3.5),
					(7, "Maggie", null, false, 1, 5.1)`,
			insertedValues: []row.Row{
				NewPeopleRow(0, "Homer", "Simpson", true, 45, 100),
				NewPeopleRow(8, "Milhouse", "Van Houten", false, 8, 3.5),
			},
			expectedResult: InsertResult{NumRowsInserted: 1, NumRowsUpdated: 1, NumErrorsIgnored: 1},
		},
		{
			name: "insert two rows with replace, one with errors",
			query: `replace into people (id, first, last, is_married, age, rating) values
					(0, "Homer", "Simpson", true, 45, 100),
					(8, "Milhouse", "Van Houten", false, 8, 3.5),
					(7, "Maggie", null, false, 1, 5.1)`,
			expectedErr: "Constraint failed for column 'last': Not null",
		},
		{
			name: "insert five rows, all columns",
			query: `insert into people (id, first, last, is_married, age, rating) values
					(7, "Maggie", "Simpson", false, 1, 5.1),
					(8, "Milhouse", "Van Houten", false, 8, 3.5),
					(9, "Jacqueline", "Bouvier", true, 80, 2),
					(10, "Patty", "Bouvier", false, 40, 7),
					(11, "Selma", "Bouvier", false, 40, 7)`,
			insertedValues: []row.Row{
				NewPeopleRow(7, "Maggie", "Simpson", false, 1, 5.1),
				NewPeopleRow(8, "Milhouse", "Van Houten", false, 8, 3.5),
				NewPeopleRow(9, "Jacqueline", "Bouvier", true, 80, 2),
				NewPeopleRow(10, "Patty", "Bouvier", false, 40, 7),
				NewPeopleRow(11, "Selma", "Bouvier", false, 40, 7),
			},
			expectedResult: InsertResult{NumRowsInserted: 5},
		},
		{
			name: "insert two rows, only required columns",
			query: `insert into people (id, first, last) values
					(7, "Maggie", "Simpson"),
					(8, "Milhouse", "Van Houten")`,
			insertedValues: []row.Row{
				mustRow(row.New(types.Format_7_18, PeopleTestSchema, row.TaggedValues{IdTag: types.Int(7), FirstTag: types.String("Maggie"), LastTag: types.String("Simpson")})),
				mustRow(row.New(types.Format_7_18, PeopleTestSchema, row.TaggedValues{IdTag: types.Int(8), FirstTag: types.String("Milhouse"), LastTag: types.String("Van Houten")})),
			},
			expectedResult: InsertResult{NumRowsInserted: 2},
		},
		{
			name: "insert two rows, duplicate id",
			query: `insert into people (id, first, last, is_married, age, rating) values
					(7, "Maggie", "Simpson", false, 1, 5.1),
					(7, "Milhouse", "Van Houten", false, 8, 3.5)`,
			expectedErr: "Duplicate primary key: 'id: 7'",
		},
		{
			name: "insert two rows, duplicate id with ignore",
			query: `insert ignore into people (id, first, last, is_married, age, rating) values
					(7, "Maggie", "Simpson", false, 1, 5.1),
					(7, "Milhouse", "Van Houten", false, 8, 3.5)`,
			insertedValues: []row.Row{
				NewPeopleRow(7, "Maggie", "Simpson", false, 1, 5.1),
			},
			expectedResult: InsertResult{NumRowsInserted: 1, NumErrorsIgnored: 1},
		},
		{
			name:        "insert no primary keys",
			query:       `insert into people (age) values (7), (8)`,
			expectedErr: "One or more primary key columns missing from insert statement",
		},
	}

	// Each case runs against a fresh test env so earlier inserts cannot leak
	// into later cases.
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Guard against contradictory test setup: a case asserts either an
			// error or results, never both.
			if len(tt.expectedErr) > 0 {
				require.Equal(t, InsertResult{}, tt.expectedResult, "incorrect test setup: cannot assert both an error and expected results")
				require.Nil(t, tt.insertedValues, "incorrect test setup: cannot assert both an error and inserted values")
			}

			dEnv := dtestutils.CreateTestEnv()
			ctx := context.Background()

			CreateTestDatabase(dEnv, t)
			root, _ := dEnv.WorkingRoot(ctx)

			sqlStatement, _ := sqlparser.Parse(tt.query)
			s := sqlStatement.(*sqlparser.Insert)

			result, err := ExecuteInsert(ctx, dEnv.DoltDB, root, s)
			if len(tt.expectedErr) > 0 {
				require.Error(t, err)
				require.Contains(t, err.Error(), tt.expectedErr)
				return
			} else {
				require.NoError(t, err)
			}

			assert.Equal(t, tt.expectedResult.NumRowsInserted, result.NumRowsInserted)
			assert.Equal(t, tt.expectedResult.NumErrorsIgnored, result.NumErrorsIgnored)
			assert.Equal(t, tt.expectedResult.NumRowsUpdated, result.NumRowsUpdated)

			table, ok, err := result.Root.GetTable(ctx, PeopleTableName)
			assert.NoError(t, err)
			assert.True(t, ok)

			// Verify each expected row is present in the resulting table and
			// matches the inserted values column-for-column.
			for _, expectedRow := range tt.insertedValues {
				v, err := expectedRow.NomsMapKey(PeopleTestSchema).Value(ctx)
				assert.NoError(t, err)
				foundRow, ok, err := table.GetRow(ctx, v.(types.Tuple), PeopleTestSchema)
				require.NoError(t, err)
				require.True(t, ok, "Row not found: %v", expectedRow)
				eq, diff := rowsEqual(expectedRow, foundRow)
				assert.True(t, eq, "Rows not equals, found diff %v", diff)
			}
		})
	}
}
// rowsEqual compares two rows column-by-column (keyed by column tag) and returns
// whether they are equal, plus a diff string when they are not.
func rowsEqual(expected, actual row.Row) (bool, string) {
	er, ar := make(map[uint64]types.Value), make(map[uint64]types.Value)
	// Collect every (tag, value) pair from each row; returning false from the
	// callback continues iteration.
	_, err := expected.IterCols(func(t uint64, v types.Value) (bool, error) {
		er[t] = v
		return false, nil
	})
	if err != nil {
		panic(err)
	}
	_, err = actual.IterCols(func(t uint64, v types.Value) (bool, error) {
		ar[t] = v
		return false, nil
	})
	if err != nil {
		panic(err)
	}
	// NOTE(review): cmp.AllowUnexported() with no arguments is a no-op —
	// presumably FloatComparer does the real work here; confirm before removing.
	opts := cmp.Options{cmp.AllowUnexported(), dtestutils.FloatComparer}
	eq := cmp.Equal(er, ar, opts)
	var diff string
	if !eq {
		diff = cmp.Diff(er, ar, opts)
	}
	return eq, diff
}
+3
View File
@@ -29,6 +29,9 @@ import (
"github.com/liquidata-inc/dolt/go/libraries/doltcore/table/untyped/resultset"
)
var ErrMissingPrimaryKeys = errors.New("One or more primary key columns missing from insert statement")
var ConstraintFailedFmt = "Constraint failed for column '%v': %v"
type UpdateResult struct {
Root *doltdb.RootValue
NumRowsUpdated int
@@ -18,6 +18,7 @@ import (
"context"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -372,3 +373,32 @@ func TestExecuteUpdate(t *testing.T) {
})
}
}
// rowsEqual reports whether two rows hold the same value for every column tag,
// returning a diff string on mismatch. Duplicated across test files in this
// package; consider consolidating into a shared test helper.
func rowsEqual(expected, actual row.Row) (bool, string) {
	er, ar := make(map[uint64]types.Value), make(map[uint64]types.Value)
	// Flatten each row into a tag -> value map for comparison.
	_, err := expected.IterCols(func(t uint64, v types.Value) (bool, error) {
		er[t] = v
		return false, nil
	})
	if err != nil {
		panic(err)
	}
	_, err = actual.IterCols(func(t uint64, v types.Value) (bool, error) {
		ar[t] = v
		return false, nil
	})
	if err != nil {
		panic(err)
	}
	// NOTE(review): cmp.AllowUnexported() takes no arguments here, which makes
	// it a no-op — verify whether it can be dropped.
	opts := cmp.Options{cmp.AllowUnexported(), dtestutils.FloatComparer}
	eq := cmp.Equal(er, ar, opts)
	var diff string
	if !eq {
		diff = cmp.Diff(er, ar, opts)
	}
	return eq, diff
}
+51 -10
View File
@@ -29,20 +29,42 @@ import (
var _ sql.Database = (*Database)(nil)
type batchMode bool
const (
batched batchMode = true
single batchMode = false
)
// Database implements sql.Database for a dolt DB.
type Database struct {
sql.Database
name string
root *doltdb.RootValue
dEnv *env.DoltEnv
name string
root *doltdb.RootValue
dEnv *env.DoltEnv
batchMode batchMode
tables map[string]*DoltTable
}
// NewDatabase returns a new dolt databae to use in queries.
// NewDatabase returns a new dolt database to use in queries.
func NewDatabase(name string, root *doltdb.RootValue, dEnv *env.DoltEnv) *Database {
return &Database{
name: name,
root: root,
dEnv: dEnv,
name: name,
root: root,
dEnv: dEnv,
batchMode: single,
tables: make(map[string]*DoltTable),
}
}
// NewBatchedDatabase returns a new dolt database executing in batch insert mode. Integrators must call Flush() to
// commit any outstanding edits.
func NewBatchedDatabase(name string, root *doltdb.RootValue, dEnv *env.DoltEnv) *Database {
	db := &Database{name: name, root: root, dEnv: dEnv}
	db.batchMode = batched
	db.tables = make(map[string]*DoltTable)
	return db
}
@@ -80,12 +102,16 @@ func (db *Database) GetTableInsensitive(ctx context.Context, tblName string) (sq
return nil, false, nil
}
if table, ok := db.tables[exactName]; ok {
return table, true, nil
}
tbl, ok, err := db.root.GetTable(ctx, exactName)
if err != nil {
return nil, false, err
} else if !ok {
panic("Name '" + exactName + "'had already been verified... This is a bug")
panic("Name '" + exactName + "' had already been verified... This is a bug")
}
sch, err := tbl.GetSchema(ctx)
@@ -94,7 +120,9 @@ func (db *Database) GetTableInsensitive(ctx context.Context, tblName string) (sq
return nil, false, err
}
return &DoltTable{name: tblName, table: tbl, sch: sch, db: db}, true, nil
table := &DoltTable{name: exactName, table: tbl, sch: sch, db: db}
db.tables[exactName] = table
return table, true, nil
}
func (db *Database) GetTableNames(ctx context.Context) ([]string, error) {
@@ -129,6 +157,8 @@ func (db *Database) DropTable(ctx *sql.Context, tableName string) error {
return err
}
delete(db.tables, tableName)
db.SetRoot(newRoot)
return nil
@@ -176,3 +206,14 @@ func (db *Database) CreateTable(ctx *sql.Context, tableName string, schema sql.S
return nil
}
// Flush writes out the current batch of outstanding changes for every cached
// table and returns the first error encountered. Flushed tables are evicted
// from the cache so subsequent lookups see the updated root.
func (db *Database) Flush(ctx context.Context) error {
	for name, table := range db.tables {
		if err := table.flushBatchedEdits(ctx); err != nil {
			return err
		}
		delete(db.tables, name)
	}
	return nil
}
@@ -12,18 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package sql
package sqle
import (
"context"
"fmt"
"io"
"math/rand"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
sqle "github.com/src-d/go-mysql-server"
"github.com/src-d/go-mysql-server/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"vitess.io/vitess/go/vt/sqlparser"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/dtestutils"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
@@ -63,17 +66,14 @@ func TestSqlBatchInserts(t *testing.T) {
CreateTestDatabase(dEnv, t)
root, _ := dEnv.WorkingRoot(ctx)
batcher := NewSqlBatcher(dEnv.DoltDB, root)
engine := sqle.NewDefault()
db := NewBatchedDatabase("dolt", root, dEnv)
engine.AddDatabase(db)
for _, stmt := range insertStatements {
statement, err := sqlparser.Parse(stmt)
_, rowIter, err := engine.Query(sql.NewEmptyContext(), stmt)
require.NoError(t, err)
insertStmt, ok := statement.(*sqlparser.Insert)
require.True(t, ok)
result, err := ExecuteBatchInsert(context.Background(), root, insertStmt, batcher)
require.NoError(t, err)
assert.True(t, result.NumRowsInserted > 0)
assert.Equal(t, 0, result.NumRowsUpdated)
assert.Equal(t, 0, result.NumErrorsIgnored)
require.NoError(t, drainIter(rowIter))
}
// Before committing the batch, the database should be unchanged from its original state
@@ -89,7 +89,7 @@ func TestSqlBatchInserts(t *testing.T) {
assert.ElementsMatch(t, AllAppsRows, allAppearanceRows)
// Now commit the batch and check for new rows
root, err = batcher.Commit(ctx)
err = db.Flush(ctx)
require.NoError(t, err)
var expectedPeople, expectedEpisodes, expectedAppearances []row.Row
@@ -124,6 +124,7 @@ func TestSqlBatchInserts(t *testing.T) {
newAppsRow(11, 9),
)
root = db.Root()
allPeopleRows, err = GetAllRows(root, PeopleTableName)
require.NoError(t, err)
allEpsRows, err = GetAllRows(root, EpisodesTableName)
@@ -136,15 +137,27 @@ func TestSqlBatchInserts(t *testing.T) {
assertRowSetsEqual(t, expectedAppearances, allAppearanceRows)
}
// drainIter exhausts a row iterator, discarding rows. io.EOF terminates the
// drain; any other error is remembered and the last one seen is returned.
func drainIter(iter sql.RowIter) error {
	var lastErr error
	for {
		_, err := iter.Next()
		switch {
		case err == io.EOF:
			return lastErr
		case err != nil:
			lastErr = err
		}
	}
}
func TestSqlBatchInsertIgnoreReplace(t *testing.T) {
t.Skip("Skipped until insert ignore statements supported in go-mysql-server")
insertStatements := []string{
`replace into people (id, first, last, is_married, age, rating, uuid, num_episodes) values
(0, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000007', 677)`,
`insert ignore into people values
(2, "Milhouse", "VanHouten", false, 1, 5.1, '00000000-0000-0000-0000-000000000008', 677)`,
}
numRowsUpdated := []int{1, 0}
numErrorsIgnored := []int{0, 1}
dEnv := dtestutils.CreateTestEnv()
ctx := context.Background()
@@ -152,18 +165,14 @@ func TestSqlBatchInsertIgnoreReplace(t *testing.T) {
CreateTestDatabase(dEnv, t)
root, _ := dEnv.WorkingRoot(ctx)
batcher := NewSqlBatcher(dEnv.DoltDB, root)
for i := range insertStatements {
stmt := insertStatements[i]
statement, err := sqlparser.Parse(stmt)
engine := sqle.NewDefault()
db := NewBatchedDatabase("dolt", root, dEnv)
engine.AddDatabase(db)
for _, stmt := range insertStatements {
_, rowIter, err := engine.Query(sql.NewEmptyContext(), stmt)
require.NoError(t, err)
insertStmt, ok := statement.(*sqlparser.Insert)
require.True(t, ok)
result, err := ExecuteBatchInsert(context.Background(), root, insertStmt, batcher)
require.NoError(t, err)
assert.Equal(t, 0, result.NumRowsInserted)
assert.Equal(t, numRowsUpdated[i], result.NumRowsUpdated)
assert.Equal(t, numErrorsIgnored[i], result.NumErrorsIgnored)
drainIter(rowIter)
}
// Before committing the batch, the database should be unchanged from its original state
@@ -172,7 +181,7 @@ func TestSqlBatchInsertIgnoreReplace(t *testing.T) {
assert.ElementsMatch(t, AllPeopleRows, allPeopleRows)
// Now commit the batch and check for new rows
root, err = batcher.Commit(ctx)
err = db.Flush(ctx)
require.NoError(t, err)
var expectedPeople []row.Row
@@ -188,29 +197,29 @@ func TestSqlBatchInsertIgnoreReplace(t *testing.T) {
}
func TestSqlBatchInsertErrors(t *testing.T) {
insertStatements := []string{
`insert into people (id, first, last, is_married, age, rating, uuid, num_episodes) values
(0, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000007', 677)`,
`insert into people values
(2, "Milhouse", "VanHouten", false, 1, 5.1, true, 677)`,
}
dEnv := dtestutils.CreateTestEnv()
ctx := context.Background()
CreateTestDatabase(dEnv, t)
root, _ := dEnv.WorkingRoot(ctx)
batcher := NewSqlBatcher(dEnv.DoltDB, root)
for i := range insertStatements {
stmt := insertStatements[i]
statement, err := sqlparser.Parse(stmt)
require.NoError(t, err)
insertStmt, ok := statement.(*sqlparser.Insert)
require.True(t, ok)
_, err = ExecuteBatchInsert(context.Background(), root, insertStmt, batcher)
require.Error(t, err)
}
engine := sqle.NewDefault()
db := NewBatchedDatabase("dolt", root, dEnv)
engine.AddDatabase(db)
_, rowIter, err := engine.Query(sql.NewEmptyContext(), `insert into people (id, first, last, is_married, age, rating, uuid, num_episodes) values
(0, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000007', 677)`)
// This won't generate an error until we commit the batch (duplicate key)
assert.NoError(t, err)
assert.NoError(t, drainIter(rowIter))
// This generates an error at insert time because of the bad type for the uuid column
_, _, err = engine.Query(sql.NewEmptyContext(), `insert into people values
(2, "Milhouse", "VanHouten", false, 1, 5.1, true, 677)`)
assert.Error(t, err)
// Error from the first statement appears here
assert.Error(t, db.Flush(ctx))
}
func assertRowSetsEqual(t *testing.T, expected, actual []row.Row) {
@@ -244,6 +253,35 @@ func containsRow(rs []row.Row, r row.Row) bool {
return false
}
// rowsEqual compares two rows by column tag and returns equality plus a diff
// string on mismatch. This helper is duplicated across several test files in
// the PR; a shared test utility would remove the repetition.
func rowsEqual(expected, actual row.Row) (bool, string) {
	er, ar := make(map[uint64]types.Value), make(map[uint64]types.Value)
	// Build tag -> value maps for both rows.
	_, err := expected.IterCols(func(t uint64, v types.Value) (bool, error) {
		er[t] = v
		return false, nil
	})
	if err != nil {
		panic(err)
	}
	_, err = actual.IterCols(func(t uint64, v types.Value) (bool, error) {
		ar[t] = v
		return false, nil
	})
	if err != nil {
		panic(err)
	}
	// NOTE(review): the zero-argument cmp.AllowUnexported() is a no-op; only
	// FloatComparer contributes to the comparison — confirm before cleanup.
	opts := cmp.Options{cmp.AllowUnexported(), dtestutils.FloatComparer}
	eq := cmp.Equal(er, ar, opts)
	var diff string
	if !eq {
		diff = cmp.Diff(er, ar, opts)
	}
	return eq, diff
}
func newPeopleRow(id int, first, last string) row.Row {
vals := row.TaggedValues{
IdTag: types.Int(id),
@@ -34,7 +34,9 @@ import (
// Executes all the SQL non-select statements given in the string against the root value given and returns the updated
// root, or an error. Statements in the input string are split by `;\n`
func ExecuteSql(dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*doltdb.RootValue, error) {
batcher := dsql.NewSqlBatcher(dEnv.DoltDB, root)
engine := sqle.NewDefault()
db := dsqle.NewBatchedDatabase("dolt", root, dEnv)
engine.AddDatabase(db)
for _, query := range strings.Split(statements, ";\n") {
if len(strings.Trim(query, " ")) == 0 {
@@ -53,9 +55,13 @@ func ExecuteSql(dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*
case *sqlparser.Select, *sqlparser.OtherRead:
return nil, errors.New("Select statements aren't handled")
case *sqlparser.Insert:
_, execErr = dsql.ExecuteBatchInsert(context.Background(), root, s, batcher)
var rowIter sql.RowIter
_, rowIter, execErr = engine.Query(sql.NewEmptyContext(), query)
if execErr == nil {
execErr = drainIter(rowIter)
}
case *sqlparser.DDL:
if root, err = batcher.Commit(context.Background()); err != nil {
if err = db.Flush(context.Background()); err != nil {
return nil, err
}
_, execErr = sqlparser.ParseStrictDDL(query)
@@ -63,9 +69,7 @@ func ExecuteSql(dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*
return nil, fmt.Errorf("Error parsing DDL: %v.", execErr.Error())
}
root, execErr = sqlDDL(dEnv, root, s, query)
if err := batcher.UpdateRoot(root); err != nil {
return nil, err
}
db.SetRoot(root)
default:
return nil, fmt.Errorf("Unsupported SQL statement: '%v'.", query)
}
@@ -75,13 +79,11 @@ func ExecuteSql(dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*
}
}
if newRoot, err := batcher.Commit(context.Background()); newRoot != nil {
root = newRoot
} else if err != nil {
if err := db.Flush(context.Background()); err == nil {
return db.Root(), nil
} else {
return nil, err
}
return root, nil
}
func sqlDDL(dEnv *env.DoltEnv, root *doltdb.RootValue, ddl *sqlparser.DDL, query string) (*doltdb.RootValue, error) {
@@ -136,3 +138,15 @@ func ExecuteSelect(root *doltdb.RootValue, query string) ([]sql.Row, error) {
return rows, nil
}
// drainIter exhausts a row iterator, discarding rows. io.EOF ends the drain;
// any other error is remembered and the last one seen is returned to the caller.
func drainIter(iter sql.RowIter) error {
	var returnedErr error
	for {
		_, err := iter.Next()
		if err == io.EOF {
			// EOF signals a clean end of iteration, not a failure.
			return returnedErr
		} else if err != nil {
			returnedErr = err
		}
	}
}
+224
View File
@@ -0,0 +1,224 @@
// Copyright 2019 Liquidata, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqle
import (
"context"
"fmt"
"github.com/src-d/go-mysql-server/sql"
"github.com/liquidata-inc/dolt/go/cmd/dolt/errhand"
"github.com/liquidata-inc/dolt/go/store/hash"
"github.com/liquidata-inc/dolt/go/store/types"
)
// ErrDuplicatePrimaryKeyFmt is the format string used when an edit collides
// with an existing primary key; the %v is the encoded key value.
var ErrDuplicatePrimaryKeyFmt = "duplicate primary key given: (%v)"
// tableEditor supports making multiple row edits (inserts, updates, deletes) to a table. It does error checking for key
// collision etc. in the Close() method, as well as during Insert / Update.
// Right now a table editor allows you to combine inserts, updates, and deletes in any order, and makes reasonable
// attempts to produce correct results when doing so. But this probably (definitely) doesn't work in every case, and
// higher-level clients should carefully flush the editor when necessary (i.e. before an update after many inserts).
type tableEditor struct {
	// t is the table being edited.
	t *DoltTable
	// ed accumulates pending map edits; nil until the first edit arrives.
	ed *types.MapEditor
	// insertedKeys tracks keys added via Insert in this batch, keyed by hash.
	insertedKeys map[hash.Hash]types.Value
	// addedKeys tracks every key added by Insert or Update, keyed by hash.
	addedKeys map[hash.Hash]types.Value
	// removedKeys tracks keys removed by Delete or displaced by a PK-changing Update.
	removedKeys map[hash.Hash]types.Value
}

// Interface assertions: tableEditor supports replace, update, insert and delete.
var _ sql.RowReplacer = (*tableEditor)(nil)
var _ sql.RowUpdater = (*tableEditor)(nil)
var _ sql.RowInserter = (*tableEditor)(nil)
var _ sql.RowDeleter = (*tableEditor)(nil)
// newTableEditor creates an empty editor over the given table. The underlying
// map editor is created lazily, on the first edit.
func newTableEditor(t *DoltTable) *tableEditor {
	te := &tableEditor{t: t}
	te.insertedKeys = map[hash.Hash]types.Value{}
	te.addedKeys = map[hash.Hash]types.Value{}
	te.removedKeys = map[hash.Hash]types.Value{}
	return te
}
// Insert adds the given row to the pending edit set. Inserting the same key
// twice within one batch is an immediate error; colliding with a row already in
// the table is detected later, in flush() (called via Close()).
func (te *tableEditor) Insert(ctx *sql.Context, sqlRow sql.Row) error {
	dRow, err := SqlRowToDoltRow(te.t.table.Format(), sqlRow, te.t.sch)
	if err != nil {
		return err
	}

	key, err := dRow.NomsMapKey(te.t.sch).Value(ctx)
	if err != nil {
		return errhand.BuildDError("failed to get row key").AddCause(err).Build()
	}

	// Renamed from `hash`, which shadowed the imported hash package.
	keyHash, err := key.Hash(dRow.Format())
	if err != nil {
		return err
	}

	// If we've already inserted this key as part of this insert operation, that's an error. Inserting a row that already
	// exists in the table will be handled in Close().
	if _, ok := te.addedKeys[keyHash]; ok {
		value, err := types.EncodedValue(ctx, key)
		if err != nil {
			return err
		}
		return fmt.Errorf(ErrDuplicatePrimaryKeyFmt, value)
	}
	te.insertedKeys[keyHash] = key
	te.addedKeys[keyHash] = key

	// Lazily create the map editor on first use.
	if te.ed == nil {
		te.ed, err = te.t.newMapEditor(ctx)
		if err != nil {
			return err
		}
	}

	te.ed = te.ed.Set(key, dRow.NomsMapValue(te.t.sch))
	return nil
}
// Delete removes the given row from the pending edit set and schedules its key
// for removal from the underlying row map.
func (te *tableEditor) Delete(ctx *sql.Context, sqlRow sql.Row) error {
	dRow, err := SqlRowToDoltRow(te.t.table.Format(), sqlRow, te.t.sch)
	if err != nil {
		return err
	}

	key, err := dRow.NomsMapKey(te.t.sch).Value(ctx)
	if err != nil {
		return errhand.BuildDError("failed to get row key").AddCause(err).Build()
	}

	// Renamed from `hash`, which shadowed the imported hash package.
	keyHash, err := key.Hash(dRow.Format())
	if err != nil {
		return err
	}

	// NOTE(review): insertedKeys is not cleaned up here — presumably harmless
	// since flush() only consults addedKeys / removedKeys, but confirm.
	delete(te.addedKeys, keyHash)
	te.removedKeys[keyHash] = key

	// Lazily create the map editor on first use.
	if te.ed == nil {
		te.ed, err = te.t.newMapEditor(ctx)
		if err != nil {
			return err
		}
	}

	te.ed = te.ed.Remove(key)
	return nil
}
// newMapEditor returns a fresh editor over the table's current row data.
func (t *DoltTable) newMapEditor(ctx context.Context) (*types.MapEditor, error) {
	rowData, err := t.table.GetRowData(ctx)
	if err == nil {
		return rowData.Edit(), nil
	}
	return nil, errhand.BuildDError("failed to get row data.").AddCause(err).Build()
}
// Update replaces oldRow with newRow in the pending edit set. If the primary
// key changed, the old key is scheduled for removal and the new one recorded
// as added.
func (te *tableEditor) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Row) error {
	dOldRow, err := SqlRowToDoltRow(te.t.table.Format(), oldRow, te.t.sch)
	if err != nil {
		return err
	}
	dNewRow, err := SqlRowToDoltRow(te.t.table.Format(), newRow, te.t.sch)
	if err != nil {
		return err
	}

	// Lazily create the map editor before any edit is applied. Hoisted above
	// the PK-change handling so te.ed is guaranteed non-nil at every use below.
	if te.ed == nil {
		te.ed, err = te.t.newMapEditor(ctx)
		if err != nil {
			return err
		}
	}

	// If the PK is changed then we need to delete the old value and insert the new one
	dOldKey := dOldRow.NomsMapKey(te.t.sch)
	dOldKeyVal, err := dOldKey.Value(ctx)
	if err != nil {
		return err
	}
	dNewKey := dNewRow.NomsMapKey(te.t.sch)
	dNewKeyVal, err := dNewKey.Value(ctx)
	if err != nil {
		return err
	}

	if !dOldKeyVal.Equals(dNewKeyVal) {
		oldHash, err := dOldKeyVal.Hash(dOldRow.Format())
		if err != nil {
			return err
		}
		newHash, err := dNewKeyVal.Hash(dNewRow.Format())
		if err != nil {
			return err
		}

		// If the old value of the primary key we just updated was previously inserted, then we need to remove it now.
		if _, ok := te.insertedKeys[oldHash]; ok {
			delete(te.insertedKeys, oldHash)
			// Keep the returned editor, consistent with Insert / Delete (the
			// result was previously discarded).
			te.ed = te.ed.Remove(dOldKeyVal)
		}

		te.addedKeys[newHash] = dNewKeyVal
		te.removedKeys[oldHash] = dOldKeyVal
	}

	// Keep the returned editor, consistent with Insert / Delete.
	te.ed = te.ed.Set(dNewKeyVal, dNewRow.NomsMapValue(te.t.sch))
	return nil
}
// Close implements Closer. In single-statement mode the pending edits are
// flushed immediately; in batched mode they are held until the parent table
// explicitly flushes them.
func (te *tableEditor) Close(ctx *sql.Context) error {
	if te.t.db.batchMode != batched {
		return te.flush(ctx)
	}
	// Batched mode: defer the flush until explicitly requested.
	return nil
}
// flush applies all accumulated edits to the underlying table, first checking
// every added key for a collision with a pre-existing row.
func (te *tableEditor) flush(ctx context.Context) error {
	// For all added keys, check for and report a collision.
	// Loop variable renamed from `hash`, which shadowed the imported hash package.
	for keyHash, addedKey := range te.addedKeys {
		if _, ok := te.removedKeys[keyHash]; !ok {
			_, rowExists, err := te.t.table.GetRow(ctx, addedKey.(types.Tuple), te.t.sch)
			if err != nil {
				return errhand.BuildDError("failed to read table").AddCause(err).Build()
			}
			if rowExists {
				value, err := types.EncodedValue(ctx, addedKey)
				if err != nil {
					return err
				}
				return fmt.Errorf(ErrDuplicatePrimaryKeyFmt, value)
			}
		}
	}

	// For all removed keys, remove the map entries that weren't added elsewhere by other updates
	for keyHash, removedKey := range te.removedKeys {
		if _, ok := te.addedKeys[keyHash]; !ok {
			// Keep the returned editor, consistent with Insert / Delete (the
			// result was previously discarded). te.ed is non-nil whenever
			// removedKeys is non-empty, since Delete and Update initialize it.
			te.ed = te.ed.Remove(removedKey)
		}
	}

	if te.ed != nil {
		return te.t.updateTable(ctx, te.ed)
	}

	return nil
}
@@ -0,0 +1,190 @@
// Copyright 2019 Liquidata, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqle
import (
"context"
"testing"
"github.com/src-d/go-mysql-server/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/dtestutils"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema"
. "github.com/liquidata-inc/dolt/go/libraries/doltcore/sql/sqltestutil"
)
// tableEditorTest describes a single table-editing scenario: a sequence of
// editor operations to run, followed by a verification query and its expected
// result (or an expected error).
type tableEditorTest struct {
	// The name of this test. Names should be unique and descriptive.
	name string
	// Test setup to run
	setup func(ctx *sql.Context, t *testing.T, ed *tableEditor)
	// The select query to run to verify the results
	selectQuery string
	// The rows this query should return, nil if an error is expected
	expectedRows []row.Row
	// Expected error string, if any
	expectedErr string
}
// TestTableEditor applies batches of inserts, updates, and deletes through a
// tableEditor, flushes them via Close, and verifies the resulting table
// contents (or the flush error) for each scenario.
func TestTableEditor(t *testing.T) {
	edna := NewPeopleRow(10, "Edna", "Krabapple", false, 38, 8.0)
	krusty := NewPeopleRow(11, "Krusty", "Klown", false, 48, 9.5)
	smithers := NewPeopleRow(12, "Waylon", "Smithers", false, 44, 7.1)
	ralph := NewPeopleRow(13, "Ralph", "Wiggum", false, 9, 9.1)
	martin := NewPeopleRow(14, "Martin", "Prince", false, 11, 6.3)
	skinner := NewPeopleRow(15, "Seymore", "Skinner", false, 50, 8.3)
	fatTony := NewPeopleRow(16, "Fat", "Tony", false, 53, 5.0)
	troyMclure := NewPeopleRow(17, "Troy", "McClure", false, 58, 7.0)

	// Some of these are pretty exotic use cases, but since we support all these operations it's nice to know they work
	// in tandem.
	testCases := []tableEditorTest{
		{
			name: "all inserts",
			setup: func(ctx *sql.Context, t *testing.T, ed *tableEditor) {
				require.NoError(t, ed.Insert(ctx, r(edna, PeopleTestSchema)))
				require.NoError(t, ed.Insert(ctx, r(krusty, PeopleTestSchema)))
				require.NoError(t, ed.Insert(ctx, r(smithers, PeopleTestSchema)))
				require.NoError(t, ed.Insert(ctx, r(ralph, PeopleTestSchema)))
				require.NoError(t, ed.Insert(ctx, r(martin, PeopleTestSchema)))
				require.NoError(t, ed.Insert(ctx, r(skinner, PeopleTestSchema)))
				require.NoError(t, ed.Insert(ctx, r(fatTony, PeopleTestSchema)))
				require.NoError(t, ed.Insert(ctx, r(troyMclure, PeopleTestSchema)))
			},
			selectQuery: "select * from people where id >= 10",
			expectedRows: CompressRows(PeopleTestSchema,
				edna, krusty, smithers, ralph, martin, skinner, fatTony, troyMclure,
			),
		},
		{
			name: "inserts and deletes",
			setup: func(ctx *sql.Context, t *testing.T, ed *tableEditor) {
				require.NoError(t, ed.Insert(ctx, r(edna, PeopleTestSchema)))
				require.NoError(t, ed.Insert(ctx, r(krusty, PeopleTestSchema)))
				require.NoError(t, ed.Delete(ctx, r(edna, PeopleTestSchema)))
			},
			selectQuery: "select * from people where id >= 10",
			expectedRows: CompressRows(PeopleTestSchema,
				krusty,
			),
		},
		{
			name: "inserts and deletes 2",
			setup: func(ctx *sql.Context, t *testing.T, ed *tableEditor) {
				require.NoError(t, ed.Insert(ctx, r(edna, PeopleTestSchema)))
				require.NoError(t, ed.Insert(ctx, r(krusty, PeopleTestSchema)))
				require.NoError(t, ed.Delete(ctx, r(edna, PeopleTestSchema)))
				require.NoError(t, ed.Insert(ctx, r(fatTony, PeopleTestSchema)))
				require.NoError(t, ed.Delete(ctx, r(Homer, PeopleTestSchema)))
			},
			selectQuery: "select * from people where id >= 10 or id = 0",
			expectedRows: CompressRows(PeopleTestSchema,
				krusty, fatTony,
			),
		},
		{
			name: "inserts and updates",
			setup: func(ctx *sql.Context, t *testing.T, ed *tableEditor) {
				require.NoError(t, ed.Insert(ctx, r(edna, PeopleTestSchema)))
				require.NoError(t, ed.Insert(ctx, r(krusty, PeopleTestSchema)))
				require.NoError(t, ed.Update(ctx, r(edna, PeopleTestSchema), r(MutateRow(edna, AgeTag, 1), PeopleTestSchema)))
			},
			selectQuery: "select * from people where id >= 10",
			expectedRows: CompressRows(PeopleTestSchema,
				MutateRow(edna, AgeTag, 1),
				krusty,
			),
		},
		{
			name: "inserts updates and deletes",
			setup: func(ctx *sql.Context, t *testing.T, ed *tableEditor) {
				require.NoError(t, ed.Insert(ctx, r(edna, PeopleTestSchema)))
				require.NoError(t, ed.Insert(ctx, r(krusty, PeopleTestSchema)))
				require.NoError(t, ed.Update(ctx, r(edna, PeopleTestSchema), r(MutateRow(edna, AgeTag, 1), PeopleTestSchema)))
				require.NoError(t, ed.Insert(ctx, r(smithers, PeopleTestSchema)))
				require.NoError(t, ed.Insert(ctx, r(ralph, PeopleTestSchema)))
				require.NoError(t, ed.Update(ctx, r(smithers, PeopleTestSchema), r(MutateRow(smithers, AgeTag, 1), PeopleTestSchema)))
				require.NoError(t, ed.Delete(ctx, r(smithers, PeopleTestSchema)))
				require.NoError(t, ed.Insert(ctx, r(skinner, PeopleTestSchema)))
				require.NoError(t, ed.Delete(ctx, r(ralph, PeopleTestSchema)))
				require.NoError(t, ed.Insert(ctx, r(ralph, PeopleTestSchema)))
			},
			selectQuery: "select * from people where id >= 10",
			expectedRows: CompressRows(PeopleTestSchema,
				MutateRow(edna, AgeTag, 1),
				krusty,
				ralph,
				skinner,
			),
		},
		{
			name: "inserts and updates to primary key",
			setup: func(ctx *sql.Context, t *testing.T, ed *tableEditor) {
				require.NoError(t, ed.Insert(ctx, r(edna, PeopleTestSchema)))
				require.NoError(t, ed.Insert(ctx, r(krusty, PeopleTestSchema)))
				require.NoError(t, ed.Update(ctx, r(edna, PeopleTestSchema), r(MutateRow(edna, IdTag, 30), PeopleTestSchema)))
			},
			selectQuery: "select * from people where id >= 10",
			expectedRows: CompressRows(PeopleTestSchema,
				krusty,
				MutateRow(edna, IdTag, 30),
			),
		},
	}

	for _, test := range testCases {
		t.Run(test.name, func(t *testing.T) {
			dEnv := dtestutils.CreateTestEnv()
			CreateTestDatabase(dEnv, t)
			ctx := sql.NewEmptyContext()

			root, err := dEnv.WorkingRoot(context.Background())
			require.NoError(t, err)
			db := NewDatabase("dolt", root, dEnv)
			peopleTable, _, err := db.GetTableInsensitive(ctx, "people")
			require.NoError(t, err)

			dt := peopleTable.(sql.UpdatableTable)
			ed := dt.Updater(ctx).(*tableEditor)

			test.setup(ctx, t, ed)

			// Constraint violations (e.g. duplicate primary keys) surface when
			// the batch is flushed by Close, not from the individual edit calls,
			// so expected errors are checked against the Close result here.
			closeErr := ed.Close(ctx)
			if len(test.expectedErr) > 0 {
				require.Error(t, closeErr)
				assert.Contains(t, closeErr.Error(), test.expectedErr)
				return
			}
			require.NoError(t, closeErr)

			root = db.Root()
			actualRows, _, err := executeSelect(context.Background(), dEnv, CompressSchema(PeopleTestSchema), root, test.selectQuery)
			require.NoError(t, err)

			assert.Equal(t, test.expectedRows, actualRows)
		})
	}
}
// r converts a dolt row to a sql.Row for the given schema, panicking on a
// conversion failure (acceptable in this test helper).
func r(row row.Row, sch schema.Schema) sql.Row {
	converted, convErr := doltRowToSqlRow(row, sch)
	if convErr != nil {
		panic(convErr)
	}
	return converted
}
+55 -215
View File
@@ -25,24 +25,25 @@ import (
"github.com/liquidata-inc/dolt/go/cmd/dolt/errhand"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema"
"github.com/liquidata-inc/dolt/go/store/hash"
"github.com/liquidata-inc/dolt/go/store/types"
)
// DoltTable implements the sql.Table interface and gives access to dolt table rows and schema.
type DoltTable struct {
	name   string
	table  *doltdb.Table
	sch    schema.Schema
	sqlSch sql.Schema   // lazily computed sql schema, cached by sqlSchema()
	db     *Database    // owning database; consulted for batch mode
	ed     *tableEditor // cached editor, reused across statements in batch mode
}
var _ sql.Table = (*DoltTable)(nil)
var _ sql.UpdatableTable = (*DoltTable)(nil)
var _ sql.DeletableTable = (*DoltTable)(nil)
var _ sql.InsertableTable = (*DoltTable)(nil)
var _ sql.ReplaceableTable = (*DoltTable)(nil)
// DoltTable implements the sql.Table interface and gives access to dolt table rows and schema.
type DoltTable struct {
	name  string        // table name
	table *doltdb.Table // underlying dolt table storage
	sch   schema.Schema // dolt schema of the table
	db    *Database     // owning database
}
// Implements sql.IndexableTable
func (t *DoltTable) WithIndexLookup(lookup sql.IndexLookup) sql.Table {
dil, ok := lookup.(*doltIndexLookup)
@@ -78,20 +79,21 @@ func (t *DoltTable) String() string {
// Schema returns the schema for this table.
func (t *DoltTable) Schema() sql.Schema {
// TODO: fix panics
sch, err := t.table.GetSchema(context.TODO())
if err != nil {
panic(err)
}
// TODO: fix panics
sqlSch, err := doltSchemaToSqlSchema(t.name, sch)
return t.sqlSchema()
}
// sqlSchema converts this table's dolt schema to a sql.Schema, caching the
// result so the conversion only runs once per table instance.
func (t *DoltTable) sqlSchema() sql.Schema {
	// Return the cached conversion if we already have one.
	if t.sqlSch != nil {
		return t.sqlSch
	}

	// TODO: fix panics
	sqlSch, err := doltSchemaToSqlSchema(t.name, t.sch)
	if err != nil {
		panic(err)
	}

	t.sqlSch = sqlSch
	return sqlSch
}
@@ -106,206 +108,44 @@ func (t *DoltTable) PartitionRows(ctx *sql.Context, _ sql.Partition) (sql.RowIte
return newRowIterator(t, ctx)
}
// tableEditor accumulates row edits for a DoltTable and applies them to the
// underlying row map when closed.
type tableEditor struct {
	t           *DoltTable
	ed          *types.MapEditor   // lazily created on the first edit
	addedKeys   map[hash.Hash]bool // hashes of keys inserted by this editor
	deletedKeys map[hash.Hash]bool // hashes of keys deleted by this editor
}
// newTableEditor returns a tableEditor for the given table with empty
// added/deleted key sets; the map editor itself is created on first use.
func newTableEditor(t *DoltTable) *tableEditor {
	return &tableEditor{
		t:           t,
		addedKeys:   make(map[hash.Hash]bool),
		deletedKeys: make(map[hash.Hash]bool),
	}
}
var _ sql.RowReplacer = (*tableEditor)(nil)
var _ sql.RowUpdater = (*tableUpdater)(nil)
var _ sql.RowInserter = (*tableEditor)(nil)
var _ sql.RowDeleter = (*tableEditor)(nil)
// Insert implements sql.RowInserter. It stages the given row for insertion,
// returning an error if the row's primary key already exists in the table or
// was already added by this editor.
func (te *tableEditor) Insert(ctx *sql.Context, sqlRow sql.Row) error {
	dRow, err := SqlRowToDoltRow(te.t.table.Format(), sqlRow, te.t.sch)
	if err != nil {
		return err
	}

	key, err := dRow.NomsMapKey(te.t.sch).Value(ctx)
	if err != nil {
		return errhand.BuildDError("failed to get row key").AddCause(err).Build()
	}

	_, rowExists, err := te.t.table.GetRow(ctx, key.(types.Tuple), te.t.sch)
	if err != nil {
		return errhand.BuildDError("failed to read table").AddCause(err).Build()
	}

	hash, err := key.Hash(dRow.Format())
	if err != nil {
		return err
	}

	// A key is a duplicate when it exists in the stored table (and wasn't
	// deleted by this editor) or was already added by this editor.
	if (rowExists && !te.deletedKeys[hash]) || te.addedKeys[hash] {
		return errors.New("duplicate primary key given")
	}
	te.addedKeys[hash] = true

	// The map editor is created lazily on the first edit.
	if te.ed == nil {
		te.ed, err = te.t.newMapEditor(ctx)
		if err != nil {
			return err
		}
	}

	te.ed = te.ed.Set(key, dRow.NomsMapValue(te.t.sch))
	return nil
}
// newMapEditor returns a fresh types.MapEditor over this table's current row data.
func (t *DoltTable) newMapEditor(ctx context.Context) (*types.MapEditor, error) {
	rowData, err := t.table.GetRowData(ctx)
	if err != nil {
		return nil, errhand.BuildDError("failed to get row data.").AddCause(err).Build()
	}
	return rowData.Edit(), nil
}
// Delete implements sql.RowDeleter. It stages removal of the given row's key
// and records the deletion so a later insert of the same key is permitted.
func (te *tableEditor) Delete(ctx *sql.Context, sqlRow sql.Row) error {
	dRow, err := SqlRowToDoltRow(te.t.table.Format(), sqlRow, te.t.sch)
	if err != nil {
		return err
	}

	key, err := dRow.NomsMapKey(te.t.sch).Value(ctx)
	if err != nil {
		return errhand.BuildDError("failed to get row key").AddCause(err).Build()
	}

	hash, err := key.Hash(dRow.Format())
	if err != nil {
		return err
	}

	// An earlier staged insert of this key is cancelled by the delete.
	delete(te.addedKeys, hash)
	te.deletedKeys[hash] = true

	// The map editor is created lazily on the first edit.
	if te.ed == nil {
		te.ed, err = te.t.newMapEditor(ctx)
		if err != nil {
			return err
		}
	}

	te.ed = te.ed.Remove(key)
	return nil
}
// tableUpdater wraps tableEditor to override the close method, necessary to enforce primary key constraints when
// updates to primary key columns are applied in an arbitrary order
type tableUpdater struct {
	t           *DoltTable
	ed          *types.MapEditor                   // lazily created on the first update
	addedKeys   map[hash.Hash]types.LesserValuable // new keys written by updates, by key hash
	removedKeys map[hash.Hash]types.LesserValuable // old keys vacated by updates, by key hash
}
// Close applies the accumulated updates, first verifying that no updated
// primary key collides with a pre-existing row that this batch didn't vacate.
func (tu *tableUpdater) Close(ctx *sql.Context) error {
	// For all added keys, check for and report a collision
	for hash, addedKey := range tu.addedKeys {
		if _, ok := tu.removedKeys[hash]; !ok {
			_, rowExists, err := tu.t.table.GetRow(ctx, addedKey.(types.Tuple), tu.t.sch)
			if err != nil {
				return errhand.BuildDError("failed to read table").AddCause(err).Build()
			}
			if rowExists {
				return fmt.Errorf("primary key collision: (%v)", addedKey)
			}
		}
	}

	// For all removed keys, remove the map entries that weren't added elsewhere by other updates
	for hash, removedKey := range tu.removedKeys {
		if _, ok := tu.addedKeys[hash]; !ok {
			tu.ed.Remove(removedKey)
		}
	}

	// tu.ed is nil when no updates were staged.
	if tu.ed != nil {
		return tu.t.updateTable(ctx, tu.ed)
	}
	return nil
}
// Update stages a replacement of oldRow with newRow. When the primary key
// changes, the old and new keys are recorded so that Close can enforce key
// uniqueness and remove vacated keys only after all updates are known.
func (tu *tableUpdater) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Row) error {
	dOldRow, err := SqlRowToDoltRow(tu.t.table.Format(), oldRow, tu.t.sch)
	if err != nil {
		return err
	}
	dNewRow, err := SqlRowToDoltRow(tu.t.table.Format(), newRow, tu.t.sch)
	if err != nil {
		return err
	}

	// If the PK is changed then we have to delete the old row first
	dOldKey := dOldRow.NomsMapKey(tu.t.sch)
	dOldKeyVal, err := dOldKey.Value(ctx)
	if err != nil {
		return err
	}
	dNewKey := dNewRow.NomsMapKey(tu.t.sch)
	dNewKeyVal, err := dNewKey.Value(ctx)
	if err != nil {
		return err
	}

	// The map editor is created lazily on the first update.
	if tu.ed == nil {
		tu.ed, err = tu.t.newMapEditor(ctx)
		if err != nil {
			return err
		}
	}

	if !dOldKeyVal.Equals(dNewKeyVal) {
		oldHash, err := dOldKeyVal.Hash(dOldRow.Format())
		if err != nil {
			return err
		}
		newHash, err := dNewKeyVal.Hash(dNewRow.Format())
		if err != nil {
			return err
		}
		// Record the key change; actual removal of the vacated key happens in
		// Close, after every update in the batch has been seen.
		tu.addedKeys[newHash] = dNewKeyVal
		tu.removedKeys[oldHash] = dOldKey
	}

	tu.ed.Set(dNewKey, dNewRow.NomsMapValue(tu.t.sch))
	return nil
}
// Close applies any outstanding edits to the underlying table.
func (te *tableEditor) Close(ctx *sql.Context) error {
	// te.ed is nil when no edits were staged.
	if te.ed != nil {
		return te.t.updateTable(ctx, te.ed)
	}
	return nil
}
// Inserter implements sql.InsertableTable
func (t *DoltTable) Inserter(ctx *sql.Context) sql.RowInserter {
return newTableEditor(t)
return t.getTableEditor()
}
// Deleter implements sql.DeletableTable, returning a fresh editor per call.
func (t *DoltTable) Deleter(*sql.Context) sql.RowDeleter {
	return newTableEditor(t)
}
// Replacer implements sql.ReplaceableTable, returning a fresh editor per call.
func (t *DoltTable) Replacer(ctx *sql.Context) sql.RowReplacer {
	return newTableEditor(t)
}
func (t *DoltTable) Updater(ctx *sql.Context) sql.RowUpdater {
return &tableUpdater{
t: t,
addedKeys: make(map[hash.Hash]types.LesserValuable),
removedKeys: make(map[hash.Hash]types.LesserValuable),
// getTableEditor returns an editor for this table. In batched mode a single
// editor is cached and reused across statements so their edits accumulate;
// otherwise every call gets a fresh editor.
func (t *DoltTable) getTableEditor() *tableEditor {
	if t.db.batchMode != batched {
		return newTableEditor(t)
	}
	if t.ed == nil {
		t.ed = newTableEditor(t)
	}
	return t.ed
}
// flushBatchedEdits flushes the cached batch editor, if any, and clears it so
// subsequent statements start a new batch. No-op when nothing is pending.
func (t *DoltTable) flushBatchedEdits(ctx context.Context) error {
	if t.ed == nil {
		return nil
	}
	pending := t.ed
	t.ed = nil
	return pending.flush(ctx)
}
// Deleter implements sql.DeletableTable. In batched mode the returned editor
// is shared with other statements in the batch.
func (t *DoltTable) Deleter(*sql.Context) sql.RowDeleter {
	return t.getTableEditor()
}
// Replacer implements sql.ReplaceableTable. In batched mode the returned
// editor is shared with other statements in the batch.
func (t *DoltTable) Replacer(ctx *sql.Context) sql.RowReplacer {
	return t.getTableEditor()
}
// Updater implements sql.UpdatableTable. In batched mode the returned editor
// is shared with other statements in the batch.
func (t *DoltTable) Updater(ctx *sql.Context) sql.RowUpdater {
	return t.getTableEditor()
}
// doltTablePartitionIter, an object that knows how to return the single partition exactly once.
@@ -342,7 +182,7 @@ func (p doltTablePartition) Key() []byte {
return []byte(partitionName)
}
func (t *DoltTable) updateTable(ctx *sql.Context, mapEditor *types.MapEditor) error {
func (t *DoltTable) updateTable(ctx context.Context, mapEditor *types.MapEditor) error {
updated, err := mapEditor.Map(ctx)
if err != nil {
return errhand.BuildDError("failed to modify table").AddCause(err).Build()