mirror of
https://github.com/dolthub/dolt.git
synced 2026-04-26 11:39:05 -05:00
Merge pull request #233 from liquidata-inc/zachmu/sql-batch
Implemented batch insert semantics for the new SQL engine.
This commit is contained in:
@@ -137,7 +137,7 @@ if rows[2] != "9,8,7,6,5,4".split(","):
|
||||
[ "$output" = "invalid column name c6" ]
|
||||
run dolt sql -q "insert into test (pk,c1,c2,c3,c4,c5) values (0,6,6,6,6,6)"
|
||||
[ "$status" -eq 1 ]
|
||||
[ "$output" = "duplicate primary key given" ] || false
|
||||
[[ "$output" =~ "duplicate primary key" ]] || false
|
||||
}
|
||||
|
||||
@test "dolt sql insert same column twice" {
|
||||
|
||||
@@ -85,7 +85,7 @@ teardown() {
|
||||
[[ ! "$output" =~ "6" ]] || false
|
||||
run dolt sql -q "insert into test (pk1,pk2,c1,c2,c3,c4,c5) values (0,1,7,7,7,7,7)"
|
||||
[ "$status" -eq 1 ]
|
||||
[ "$output" = "duplicate primary key given" ] || false
|
||||
[[ "$output" =~ "duplicate primary key" ]] || false
|
||||
run dolt sql -q "insert into test (pk1,c1,c2,c3,c4,c5) values (0,6,6,6,6,6)"
|
||||
[ "$status" -eq 1 ]
|
||||
[ "$output" = "column name 'pk2' is non-nullable but attempted to set default value of null" ] || false
|
||||
|
||||
+88
-58
@@ -106,14 +106,14 @@ func Sql(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEn
|
||||
return HandleVErrAndExitCode(verr, usage)
|
||||
}
|
||||
|
||||
se, err := newSqlEngine(dEnv, root)
|
||||
if err != nil {
|
||||
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
|
||||
}
|
||||
origRoot := se.db.Root()
|
||||
origRoot := root
|
||||
|
||||
// run a single command and exit
|
||||
if query, ok := apr.GetValue(queryFlag); ok {
|
||||
se, err := newSqlEngine(dEnv, dsqle.NewDatabase("dolt", root, dEnv))
|
||||
if err != nil {
|
||||
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
|
||||
}
|
||||
if err := processQuery(ctx, query, se); err != nil {
|
||||
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
|
||||
} else if se.db.Root() != origRoot {
|
||||
@@ -125,8 +125,13 @@ func Sql(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEn
|
||||
|
||||
// Run in either batch mode for piped input, or shell mode for interactive
|
||||
fi, err := os.Stdin.Stat()
|
||||
var se *sqlEngine
|
||||
// Windows has a bug where STDIN can't be statted in some cases, see https://github.com/golang/go/issues/33570
|
||||
if (err != nil && osutil.IsWindows) || (fi.Mode()&os.ModeCharDevice) == 0 {
|
||||
se, err = newSqlEngine(dEnv, dsqle.NewBatchedDatabase("dolt", root, dEnv))
|
||||
if err != nil {
|
||||
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
|
||||
}
|
||||
err = runBatchMode(ctx, se)
|
||||
if err != nil {
|
||||
return 1
|
||||
@@ -134,7 +139,11 @@ func Sql(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEn
|
||||
} else if err != nil {
|
||||
HandleVErrAndExitCode(errhand.BuildDError("Couldn't stat STDIN. This is a bug.").Build(), usage)
|
||||
} else {
|
||||
err := runShell(ctx, se)
|
||||
se, err = newSqlEngine(dEnv, dsqle.NewDatabase("dolt", root, dEnv))
|
||||
if err != nil {
|
||||
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
|
||||
}
|
||||
err = runShell(ctx, se)
|
||||
if err != nil {
|
||||
return HandleVErrAndExitCode(errhand.BuildDError("unable to start shell").AddCause(err).Build(), usage)
|
||||
}
|
||||
@@ -174,8 +183,6 @@ func runBatchMode(ctx context.Context, se *sqlEngine) error {
|
||||
scanner.Buffer(buf, maxCapacity)
|
||||
scanner.Split(scanStatements)
|
||||
|
||||
batcher := dsql.NewSqlBatcher(se.dEnv.DoltDB, se.db.Root())
|
||||
|
||||
var query string
|
||||
for scanner.Scan() {
|
||||
query += scanner.Text()
|
||||
@@ -184,21 +191,24 @@ func runBatchMode(ctx context.Context, se *sqlEngine) error {
|
||||
}
|
||||
if !batchInsertEarlySemicolon(query) {
|
||||
query += ";"
|
||||
// TODO: We should fix this problem by properly implementing a state machine for scanStatements
|
||||
continue
|
||||
}
|
||||
if err := processBatchQuery(ctx, query, se, batcher); err != nil {
|
||||
if err := processBatchQuery(ctx, query, se); err != nil {
|
||||
_, _ = fmt.Fprintf(cli.CliErr, "Error processing query '%s': %s\n", query, err.Error())
|
||||
return err
|
||||
}
|
||||
query = ""
|
||||
}
|
||||
|
||||
updateBatchInsertOutput()
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
cli.Println(err.Error())
|
||||
}
|
||||
|
||||
if newRoot, _ := batcher.Commit(ctx); newRoot != nil {
|
||||
se.db.SetRoot(newRoot)
|
||||
if err := se.db.Flush(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -420,7 +430,10 @@ func prepend(s string, ss []string) []string {
|
||||
// Processes a single query. The Root of the sqlEngine will be updated if necessary.
|
||||
func processQuery(ctx context.Context, query string, se *sqlEngine) error {
|
||||
sqlStatement, err := sqlparser.Parse(query)
|
||||
if err != nil {
|
||||
if err == sqlparser.ErrEmpty {
|
||||
// silently skip empty statements
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("Error parsing SQL: %v.", err.Error())
|
||||
}
|
||||
|
||||
@@ -454,36 +467,89 @@ func processQuery(ctx context.Context, query string, se *sqlEngine) error {
|
||||
}
|
||||
}
|
||||
|
||||
type stats struct {
|
||||
numRowsInserted int
|
||||
numRowsUpdated int
|
||||
numErrorsIgnored int
|
||||
}
|
||||
|
||||
var batchEditStats stats
|
||||
var displayStrLen int
|
||||
|
||||
const maxBatchSize = 50000
|
||||
const updateInterval = 500
|
||||
|
||||
// Processes a single query in batch mode. The Root of the sqlEngine may or may not be changed.
|
||||
func processBatchQuery(ctx context.Context, query string, se *sqlEngine, batcher *dsql.SqlBatcher) error {
|
||||
func processBatchQuery(ctx context.Context, query string, se *sqlEngine) error {
|
||||
sqlStatement, err := sqlparser.Parse(query)
|
||||
if err != nil {
|
||||
if err == sqlparser.ErrEmpty {
|
||||
// silently skip empty statements
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("Error parsing SQL: %v.", err.Error())
|
||||
}
|
||||
|
||||
switch s := sqlStatement.(type) {
|
||||
switch sqlStatement.(type) {
|
||||
case *sqlparser.Insert:
|
||||
return se.insertBatch(ctx, s, batcher)
|
||||
_, rowIter, err := se.query(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error inserting rows: %v", err.Error())
|
||||
}
|
||||
|
||||
err = mergeInsertResultIntoStats(rowIter, &batchEditStats)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error inserting rows: %v", err.Error())
|
||||
}
|
||||
|
||||
if batchEditStats.numRowsInserted%maxBatchSize == 0 {
|
||||
err := se.db.Flush(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if batchEditStats.numRowsInserted%updateInterval == 0 {
|
||||
updateBatchInsertOutput()
|
||||
}
|
||||
|
||||
return nil
|
||||
default:
|
||||
// For any other kind of statement, we need to commit whatever batch edit we've accumulated so far before executing
|
||||
// the query
|
||||
newRoot, err := batcher.Commit(ctx)
|
||||
err := se.db.Flush(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
se.db.SetRoot(newRoot)
|
||||
|
||||
err = processQuery(ctx, query, se)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := batcher.UpdateRoot(se.db.Root()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func updateBatchInsertOutput() {
|
||||
displayStr := fmt.Sprintf("Rows inserted: %d", batchEditStats.numRowsInserted)
|
||||
displayStrLen = cli.DeleteAndPrint(displayStrLen, displayStr)
|
||||
}
|
||||
|
||||
// Updates the batch insert stats with the results of an insert operation.
|
||||
func mergeInsertResultIntoStats(rowIter sql.RowIter, s *stats) error {
|
||||
for {
|
||||
row, err := rowIter.Next()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
} else {
|
||||
updated := row[0].(int64)
|
||||
s.numRowsInserted += int(updated)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type sqlEngine struct {
|
||||
db *dsqle.Database
|
||||
dEnv *env.DoltEnv
|
||||
@@ -491,8 +557,7 @@ type sqlEngine struct {
|
||||
}
|
||||
|
||||
// sqlEngine packages up the context necessary to run sql queries against sqle.
|
||||
func newSqlEngine(dEnv *env.DoltEnv, root *doltdb.RootValue) (*sqlEngine, error) {
|
||||
db := dsqle.NewDatabase("dolt", root, dEnv)
|
||||
func newSqlEngine(dEnv *env.DoltEnv, db *dsqle.Database) (*sqlEngine, error) {
|
||||
engine := sqle.NewDefault()
|
||||
engine.AddDatabase(db)
|
||||
|
||||
@@ -663,41 +728,6 @@ func runPrintingPipeline(ctx context.Context, nbf *types.NomsBinFormat, p *pipel
|
||||
return nil
|
||||
}
|
||||
|
||||
type stats struct {
|
||||
numRowsInserted int
|
||||
numRowsUpdated int
|
||||
numErrorsIgnored int
|
||||
}
|
||||
|
||||
var batchEditStats stats
|
||||
var displayStrLen int
|
||||
|
||||
// Executes a SQL insert statement in batch mode. If the root value changes, sqlEngine's root will be updated.
|
||||
// No output is written to the console in batch mode.
|
||||
func (se *sqlEngine) insertBatch(ctx context.Context, stmt *sqlparser.Insert, batcher *dsql.SqlBatcher) error {
|
||||
result, err := dsql.ExecuteBatchInsert(ctx, se.db.Root(), stmt, batcher)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error inserting rows: %v", err.Error())
|
||||
}
|
||||
mergeResultIntoStats(result, &batchEditStats)
|
||||
|
||||
displayStr := fmt.Sprintf("Rows inserted: %d, Updated: %d, Errors: %d",
|
||||
batchEditStats.numRowsInserted, batchEditStats.numRowsUpdated, batchEditStats.numErrorsIgnored)
|
||||
displayStrLen = cli.DeleteAndPrint(displayStrLen, displayStr)
|
||||
|
||||
if result.Root != nil {
|
||||
se.db.SetRoot(result.Root)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func mergeResultIntoStats(result *dsql.InsertResult, stats *stats) {
|
||||
stats.numRowsInserted += result.NumRowsInserted
|
||||
stats.numRowsUpdated += result.NumRowsUpdated
|
||||
stats.numErrorsIgnored += result.NumErrorsIgnored
|
||||
}
|
||||
|
||||
// Checks if the query is a naked delete and then deletes all rows if so. Returns true if it did so, false otherwise.
|
||||
func (se *sqlEngine) checkThenDeleteAllRows(ctx context.Context, s *sqlparser.Delete) bool {
|
||||
if s.Where == nil && s.Limit == nil && s.Partitions == nil && len(s.TableExprs) == 1 {
|
||||
|
||||
@@ -101,25 +101,6 @@ func TestSqlShow(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedRes int
|
||||
}{
|
||||
{"no query", []string{"-q", ""}, 1},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
dEnv := createEnvWithSeedData(t)
|
||||
|
||||
commandStr := "dolt sql"
|
||||
result := Sql(context.TODO(), commandStr, test.args, dEnv)
|
||||
assert.Equal(t, test.expectedRes, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Tests of the create table SQL command, mostly a smoke test for errors in the command line handler. Most tests of
|
||||
// create table SQL command are in the sql package.
|
||||
func TestCreateTable(t *testing.T) {
|
||||
|
||||
@@ -59,7 +59,7 @@ require (
|
||||
|
||||
replace github.com/liquidata-inc/dolt/go/gen/proto/dolt/services/eventsapi => ./gen/proto/dolt/services/eventsapi
|
||||
|
||||
replace github.com/src-d/go-mysql-server => github.com/liquidata-inc/go-mysql-server v0.4.1-0.20191125230736-ebf9aee2fe67
|
||||
replace github.com/src-d/go-mysql-server => github.com/liquidata-inc/go-mysql-server v0.4.1-0.20191204003907-576e8690465b
|
||||
|
||||
replace vitess.io/vitess => github.com/liquidata-inc/vitess v0.0.0-20191125220844-c7181e70ee98
|
||||
|
||||
|
||||
@@ -203,8 +203,8 @@ github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFB
|
||||
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0=
|
||||
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/liquidata-inc/go-mysql-server v0.4.1-0.20191125230736-ebf9aee2fe67 h1:8ZDcYhWV0/yC9ATVsxJXbaFPsR4s2b2VSb83vTYobGs=
|
||||
github.com/liquidata-inc/go-mysql-server v0.4.1-0.20191125230736-ebf9aee2fe67/go.mod h1:hzni/QFitaMZ9H7hbZ3K+C4BEl75iDKqTIY+Hbvj6VQ=
|
||||
github.com/liquidata-inc/go-mysql-server v0.4.1-0.20191204003907-576e8690465b h1:LaTfDUa9JSuS4sHaDqfHgwFitInog3+wYlWOViNrfyw=
|
||||
github.com/liquidata-inc/go-mysql-server v0.4.1-0.20191204003907-576e8690465b/go.mod h1:hzni/QFitaMZ9H7hbZ3K+C4BEl75iDKqTIY+Hbvj6VQ=
|
||||
github.com/liquidata-inc/ishell v0.0.0-20190514193646-693241f1f2a0 h1:phMgajKClMUiIr+hF2LGt8KRuUa2Vd2GI1sNgHgSXoU=
|
||||
github.com/liquidata-inc/ishell v0.0.0-20190514193646-693241f1f2a0/go.mod h1:YC1rI9k5gx8D02ljlbxDfZe80s/iq8bGvaaQsvR+qxs=
|
||||
github.com/liquidata-inc/mmap-go v1.0.3 h1:2LndAeAtup9rpvUmu4wZSYCsjCQ0Zpc+NqE+6+PnT7g=
|
||||
@@ -244,7 +244,6 @@ github.com/olekukonko/tablewriter v0.0.0-20160115111002-cca8bbc07984/go.mod h1:v
|
||||
github.com/oliveagle/jsonpath v0.0.0-20180606110733-2e52cf6e6852 h1:Yl0tPBa8QPjGmesFh1D0rDy+q1Twx6FyU7VWHi8wZbI=
|
||||
github.com/oliveagle/jsonpath v0.0.0-20180606110733-2e52cf6e6852/go.mod h1:eqOVx5Vwu4gd2mmMZvVZsgIqNSaW3xxRThUJ0k/TPk4=
|
||||
github.com/opentracing-contrib/go-grpc v0.0.0-20180928155321-4b5a12d3ff02/go.mod h1:JNdpVEzCpXBgIiv4ds+TzhN1hrtxq6ClLrTlT9OQRSc=
|
||||
github.com/opentracing/opentracing-go v1.0.2 h1:3jA2P6O1F9UOrWVpwrIo17pu01KWvNWg4X946/Y5Zwg=
|
||||
github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
|
||||
github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU=
|
||||
github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
|
||||
@@ -431,7 +430,6 @@ golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBn
|
||||
golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
||||
golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0 h1:Dh6fw+p6FyRl5x/FvNswO1ji0lIGzm3KP8Y9VkS9PTE=
|
||||
golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
||||
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
|
||||
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
|
||||
@@ -442,7 +440,6 @@ google.golang.org/api v0.13.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsb
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/appengine v1.6.1 h1:QzqyMA1tlu6CgqCDUtU9V+ZKhLFT2dkJuANu5QaxI3I=
|
||||
google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0=
|
||||
google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
|
||||
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
@@ -458,7 +455,6 @@ google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZi
|
||||
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
|
||||
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
|
||||
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
|
||||
google.golang.org/grpc v1.24.0 h1:vb/1TCsVn3DcJlQ0Gs1yB1pKI6Do2/QNwxdKqmc/b0s=
|
||||
google.golang.org/grpc v1.24.0/go.mod h1:XDChyiUovWa60DnaeDeZmSW86xtLtjtZbwvSiRnRtcA=
|
||||
google.golang.org/grpc v1.25.1 h1:wdKvqQk7IttEw92GoRyKG2IDrUIpgpj6H6m81yfeMW0=
|
||||
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
|
||||
|
||||
@@ -1,272 +0,0 @@
|
||||
// Copyright 2019 Liquidata, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/liquidata-inc/dolt/go/store/hash"
|
||||
"github.com/liquidata-inc/dolt/go/store/types"
|
||||
)
|
||||
|
||||
// SqlBatcher knows how to efficiently batch insert / update statements, e.g. when doing a SQL import. It does this by
|
||||
// using a single MapEditor per table that isn't persisted until Commit is called.
|
||||
type SqlBatcher struct {
|
||||
// The database we are editing
|
||||
db *doltdb.DoltDB
|
||||
// The root value we are editing
|
||||
root *doltdb.RootValue
|
||||
// The set of tables under edit
|
||||
tables map[string]*doltdb.Table
|
||||
// The schemas of tables under edit
|
||||
schemas map[string]schema.Schema
|
||||
// The row data for tables being edited
|
||||
rowData map[string]types.Map
|
||||
// The editors applying updates to the tables
|
||||
editors map[string]*types.MapEditor
|
||||
// The hashes of primary keys being inserted to the tables
|
||||
hashes map[string]map[hash.Hash]bool
|
||||
}
|
||||
|
||||
// Returns a new SqlBatcher for the given environment and root value.
|
||||
func NewSqlBatcher(db *doltdb.DoltDB, root *doltdb.RootValue) *SqlBatcher {
|
||||
batcher := &SqlBatcher{
|
||||
db: db,
|
||||
root: root,
|
||||
}
|
||||
batcher.resetState()
|
||||
return batcher
|
||||
}
|
||||
|
||||
// Updates this batcher with a new root value. If there are outstanding edits, returns an error.
|
||||
func (b *SqlBatcher) UpdateRoot(root *doltdb.RootValue) error {
|
||||
if b.isDirty() {
|
||||
return errors.New("UpdateRoot called with outstanding edits")
|
||||
}
|
||||
b.root = root
|
||||
|
||||
// resetting the state shouldn't be necessary here because of the isDirty check, but if the client chooses to ignore
|
||||
// the returned error, we'll at least have a clean state going forward
|
||||
b.resetState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// isDirty returns whether there are outstanding edits that haven't been committed.
|
||||
func (b *SqlBatcher) isDirty() bool {
|
||||
return len(b.editors) > 0
|
||||
}
|
||||
|
||||
// resetState flushes the cache of outstanding edits and other data
|
||||
func (b *SqlBatcher) resetState() {
|
||||
b.tables = make(map[string]*doltdb.Table)
|
||||
b.schemas = make(map[string]schema.Schema)
|
||||
b.rowData = make(map[string]types.Map)
|
||||
b.editors = make(map[string]*types.MapEditor)
|
||||
b.hashes = make(map[string]map[hash.Hash]bool)
|
||||
}
|
||||
|
||||
type InsertOptions struct {
|
||||
// Whether to silently replace any existing rows with the same primary key as rows inserted
|
||||
Replace bool
|
||||
}
|
||||
|
||||
type BatchInsertResult struct {
|
||||
RowInserted bool
|
||||
RowUpdated bool
|
||||
}
|
||||
|
||||
func (b *SqlBatcher) Insert(ctx context.Context, tableName string, r row.Row, opt InsertOptions) (*BatchInsertResult, error) {
|
||||
sch, err := b.GetSchema(ctx, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rowData, err := b.getRowData(ctx, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ed, err := b.getEditor(ctx, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := r.NomsMapKey(sch).Value(ctx)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keyVal, _, err := rowData.MaybeGet(ctx, key)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rowExists := keyVal != nil
|
||||
hashes := b.getHashes(ctx, tableName)
|
||||
h, err := key.Hash(b.root.VRW().Format())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rowAlreadyTouched := hashes[h]
|
||||
|
||||
if rowExists || rowAlreadyTouched {
|
||||
if !opt.Replace {
|
||||
return nil, fmt.Errorf("Duplicate primary key: '%v'", getPrimaryKeyString(r, sch))
|
||||
}
|
||||
}
|
||||
|
||||
ed.Set(key, r.NomsMapValue(sch))
|
||||
h, err = key.Hash(b.root.VRW().Format())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hashes[h] = true
|
||||
|
||||
return &BatchInsertResult{RowInserted: !rowExists, RowUpdated: rowExists || rowAlreadyTouched}, nil
|
||||
}
|
||||
|
||||
// GetTable returns the table with the name given. This method is offered because reading the table from the root value
|
||||
// is relatively expensive, and SqlBatcher caches Tables to avoid the overhead.
|
||||
func (b *SqlBatcher) GetTable(ctx context.Context, tableName string) (*doltdb.Table, error) {
|
||||
if table, ok := b.tables[tableName]; ok {
|
||||
return table, nil
|
||||
}
|
||||
|
||||
if has, err := b.root.HasTable(ctx, tableName); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, fmt.Errorf("Unknown table %v", tableName)
|
||||
}
|
||||
|
||||
table, _, err := b.root.GetTable(ctx, tableName)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.tables[tableName] = table
|
||||
return table, nil
|
||||
}
|
||||
|
||||
// GetSchema returns the schema for the table name given. This method is offered because reading the schema from disk
|
||||
// is actually relatively expensive -- SqlBatcher caches the schema values per table to avoid the overhead.
|
||||
func (b *SqlBatcher) GetSchema(ctx context.Context, tableName string) (schema.Schema, error) {
|
||||
if schema, ok := b.schemas[tableName]; ok {
|
||||
return schema, nil
|
||||
}
|
||||
|
||||
table, err := b.GetTable(ctx, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sch, err := table.GetSchema(ctx)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.schemas[tableName] = sch
|
||||
return sch, nil
|
||||
}
|
||||
|
||||
func (b *SqlBatcher) getEditor(ctx context.Context, tableName string) (*types.MapEditor, error) {
|
||||
if ed, ok := b.editors[tableName]; ok {
|
||||
return ed, nil
|
||||
}
|
||||
|
||||
rowData, err := b.getRowData(ctx, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ed := rowData.Edit()
|
||||
b.editors[tableName] = ed
|
||||
return ed, nil
|
||||
}
|
||||
|
||||
func (b *SqlBatcher) getRowData(ctx context.Context, tableName string) (types.Map, error) {
|
||||
if rowData, ok := b.rowData[tableName]; ok {
|
||||
return rowData, nil
|
||||
}
|
||||
|
||||
table, err := b.GetTable(ctx, tableName)
|
||||
if err != nil {
|
||||
return types.EmptyMap, err
|
||||
}
|
||||
|
||||
rowData, err := table.GetRowData(ctx)
|
||||
|
||||
if err != nil {
|
||||
return types.EmptyMap, err
|
||||
}
|
||||
|
||||
b.rowData[tableName] = rowData
|
||||
return rowData, nil
|
||||
}
|
||||
|
||||
func (b *SqlBatcher) getHashes(ctx context.Context, tableName string) map[hash.Hash]bool {
|
||||
if hashes, ok := b.hashes[tableName]; ok {
|
||||
return hashes
|
||||
}
|
||||
|
||||
hashes := make(map[hash.Hash]bool)
|
||||
b.hashes[tableName] = hashes
|
||||
return hashes
|
||||
}
|
||||
|
||||
// Commit writes a new root value for every table under edit and returns the new root value. Tables are written in an
|
||||
// arbitrary order.
|
||||
func (b *SqlBatcher) Commit(ctx context.Context) (*doltdb.RootValue, error) {
|
||||
root := b.root
|
||||
|
||||
for tableName, ed := range b.editors {
|
||||
newMap, err := ed.Map(ctx)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
table := b.tables[tableName]
|
||||
table, err = table.UpdateRows(ctx, newMap)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
root, err = root.PutTable(ctx, tableName, table)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
b.root = root
|
||||
b.resetState()
|
||||
|
||||
return root, nil
|
||||
}
|
||||
@@ -1,331 +0,0 @@
|
||||
// Copyright 2019 Liquidata, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"vitess.io/vitess/go/vt/sqlparser"
|
||||
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/liquidata-inc/dolt/go/store/types"
|
||||
)
|
||||
|
||||
type InsertResult struct {
|
||||
Root *doltdb.RootValue
|
||||
NumRowsInserted int
|
||||
NumRowsUpdated int
|
||||
NumErrorsIgnored int
|
||||
}
|
||||
|
||||
var ErrMissingPrimaryKeys = errors.New("One or more primary key columns missing from insert statement")
|
||||
var ConstraintFailedFmt = "Constraint failed for column '%v': %v"
|
||||
|
||||
// ExecuteInsertBatch executes the given insert statement in batch mode and returns the result. The table is not changed
|
||||
// until the batch is Committed. The InsertResult returned similarly doesn't have a Root set, since the root isn't
|
||||
// modified by this function.
|
||||
func ExecuteBatchInsert(
|
||||
ctx context.Context,
|
||||
root *doltdb.RootValue,
|
||||
s *sqlparser.Insert,
|
||||
batcher *SqlBatcher,
|
||||
) (*InsertResult, error) {
|
||||
|
||||
tableName := s.Table.Name.String()
|
||||
tableSch, err := batcher.GetSchema(ctx, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parser supports overwrite on insert with both the replace keyword (from MySQL) as well as the ignore keyword
|
||||
replace := s.Action == sqlparser.ReplaceStr
|
||||
ignore := s.Ignore != ""
|
||||
|
||||
// Get the list of columns to insert into. We support both naked inserts (no column list specified) as well as
|
||||
// explicit column lists.
|
||||
var cols []schema.Column
|
||||
if s.Columns == nil || len(s.Columns) == 0 {
|
||||
cols = tableSch.GetAllCols().GetColumns()
|
||||
} else {
|
||||
cols = make([]schema.Column, len(s.Columns))
|
||||
for i, colName := range s.Columns {
|
||||
for _, c := range cols {
|
||||
if c.Name == colName.String() {
|
||||
return nil, fmt.Errorf("Repeated column: '%v'", c.Name)
|
||||
}
|
||||
}
|
||||
|
||||
col, ok := tableSch.GetAllCols().GetByName(colName.String())
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(UnknownColumnErrFmt, colName)
|
||||
}
|
||||
cols[i] = col
|
||||
}
|
||||
}
|
||||
|
||||
var rows []row.Row // your boat
|
||||
|
||||
switch queryRows := s.Rows.(type) {
|
||||
case sqlparser.Values:
|
||||
var err error
|
||||
rows, err = prepareInsertVals(root.VRW().Format(), cols, &queryRows, tableSch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case *sqlparser.Select:
|
||||
return nil, fmt.Errorf("Insert as select not supported")
|
||||
case *sqlparser.ParenSelect:
|
||||
return nil, fmt.Errorf("Parenthesized select expressions in insert not supported")
|
||||
case *sqlparser.Union:
|
||||
return nil, fmt.Errorf("Union not supported")
|
||||
default:
|
||||
return nil, fmt.Errorf("Unrecognized type for insert: %v", queryRows)
|
||||
}
|
||||
|
||||
// Perform the insert
|
||||
var result InsertResult
|
||||
opt := InsertOptions{replace}
|
||||
for _, r := range rows {
|
||||
if has, err := row.IsValid(r, tableSch); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
if ignore {
|
||||
result.NumErrorsIgnored += 1
|
||||
continue
|
||||
} else {
|
||||
col, constraint, err := row.GetInvalidConstraint(r, tableSch)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(ConstraintFailedFmt, "unknown", "unknown")
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf(ConstraintFailedFmt, col.Name, constraint)
|
||||
}
|
||||
}
|
||||
|
||||
insertResult, err := batcher.Insert(ctx, tableName, r, opt)
|
||||
if err != nil {
|
||||
if ignore {
|
||||
result.NumErrorsIgnored += 1
|
||||
continue
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if insertResult.RowInserted {
|
||||
result.NumRowsInserted++
|
||||
}
|
||||
if insertResult.RowUpdated {
|
||||
result.NumRowsUpdated++
|
||||
}
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ExecuteInsert executes the given select insert statement and returns the result.
|
||||
func ExecuteInsert(
|
||||
ctx context.Context,
|
||||
db *doltdb.DoltDB,
|
||||
root *doltdb.RootValue,
|
||||
s *sqlparser.Insert,
|
||||
) (*InsertResult, error) {
|
||||
|
||||
batcher := NewSqlBatcher(db, root)
|
||||
insertResult, err := ExecuteBatchInsert(ctx, root, s, batcher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRoot, err := batcher.Commit(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &InsertResult{
|
||||
Root: newRoot,
|
||||
NumRowsInserted: insertResult.NumRowsInserted,
|
||||
NumRowsUpdated: insertResult.NumRowsUpdated,
|
||||
NumErrorsIgnored: insertResult.NumErrorsIgnored,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Returns a primary key summary of the row given
|
||||
func getPrimaryKeyString(r row.Row, tableSch schema.Schema) string {
|
||||
var sb strings.Builder
|
||||
first := true
|
||||
err := tableSch.GetPKCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
|
||||
if !first {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString(col.Name)
|
||||
sb.WriteString(": ")
|
||||
val, ok := r.GetColVal(tag)
|
||||
if ok {
|
||||
sb.WriteString(fmt.Sprintf("%v", val))
|
||||
} else {
|
||||
sb.WriteString("null")
|
||||
}
|
||||
|
||||
first = false
|
||||
return false, nil
|
||||
})
|
||||
|
||||
// TODO: fix panics
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// Returns rows to insert from the set of values given
|
||||
func prepareInsertVals(nbf *types.NomsBinFormat, cols []schema.Column, values *sqlparser.Values, tableSch schema.Schema) ([]row.Row, error) {
|
||||
|
||||
// Lack of primary keys is its own special kind of failure that we can detect before creating any rows
|
||||
allKeysFound := true
|
||||
err := tableSch.GetPKCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
|
||||
for _, insertCol := range cols {
|
||||
if insertCol.Tag == tag {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
allKeysFound = false
|
||||
return true, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !allKeysFound {
|
||||
return nil, ErrMissingPrimaryKeys
|
||||
}
|
||||
|
||||
rows := make([]row.Row, len(*values))
|
||||
|
||||
for i, valTuple := range *values {
|
||||
r, err := makeRow(nbf, cols, tableSch, valTuple)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows[i] = r
|
||||
}
|
||||
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func makeRow(nbf *types.NomsBinFormat, columns []schema.Column, tableSch schema.Schema, tuple sqlparser.ValTuple) (row.Row, error) {
|
||||
if len(columns) != len(tuple) {
|
||||
return errInsertRow("Wrong number of values for tuple %v", nodeToString(tuple))
|
||||
}
|
||||
|
||||
taggedVals := make(row.TaggedValues)
|
||||
for i, expr := range tuple {
|
||||
column := columns[i]
|
||||
switch val := expr.(type) {
|
||||
case *sqlparser.SQLVal:
|
||||
nomsVal, err := extractNomsValueFromSQLVal(val, column.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
taggedVals[column.Tag] = nomsVal
|
||||
case *sqlparser.NullVal:
|
||||
// nothing to do, just don't set a tagged value for this column
|
||||
case sqlparser.BoolVal:
|
||||
if column.Kind != types.BoolKind {
|
||||
return errInsertRow("Type mismatch: boolean value but non-boolean column: %v", nodeToString(val))
|
||||
}
|
||||
taggedVals[column.Tag] = types.Bool(val)
|
||||
case *sqlparser.UnaryExpr:
|
||||
nomsVal, err := extractNomsValueFromUnaryExpr(val, column.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
taggedVals[column.Tag] = nomsVal
|
||||
|
||||
// Many of these shouldn't be possible in the grammar, but all cases included for completeness
|
||||
case *sqlparser.ComparisonExpr:
|
||||
return errInsertRow("Comparison expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.AndExpr:
|
||||
return errInsertRow("And expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.OrExpr:
|
||||
return errInsertRow("Or expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.NotExpr:
|
||||
return errInsertRow("Not expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.ParenExpr:
|
||||
return errInsertRow("Parenthetical expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.RangeCond:
|
||||
return errInsertRow("Range expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.IsExpr:
|
||||
return errInsertRow("Is expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.ExistsExpr:
|
||||
return errInsertRow("Exists expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.ColName:
|
||||
// unquoted strings are interpreted by the parser as column names, give a hint
|
||||
return errInsertRow("Column name expressions not supported in insert values. Did you forget to quote a string? %v", nodeToString(tuple))
|
||||
case sqlparser.ValTuple:
|
||||
return errInsertRow("Tuple expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.Subquery:
|
||||
return errInsertRow("Subquery expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.ListArg:
|
||||
return errInsertRow("List expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.BinaryExpr:
|
||||
return errInsertRow("Binary expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.IntervalExpr:
|
||||
return errInsertRow("Interval expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.CollateExpr:
|
||||
return errInsertRow("Collate expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.FuncExpr:
|
||||
return errInsertRow("Function expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.CaseExpr:
|
||||
return errInsertRow("Case expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.ValuesFuncExpr:
|
||||
return errInsertRow("Values func expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.ConvertExpr:
|
||||
return errInsertRow("Conversion expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.SubstrExpr:
|
||||
return errInsertRow("Substr expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.ConvertUsingExpr:
|
||||
return errInsertRow("Convert expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.MatchExpr:
|
||||
return errInsertRow("Match expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.GroupConcatExpr:
|
||||
return errInsertRow("Group concat expressions not supported in insert values: %v", nodeToString(tuple))
|
||||
case *sqlparser.Default:
|
||||
return errInsertRow("Unrecognized expression: %v", nodeToString(tuple))
|
||||
default:
|
||||
return errInsertRow("Unrecognized expression: %v", nodeToString(tuple))
|
||||
}
|
||||
}
|
||||
|
||||
return row.New(nbf, tableSch, taggedVals)
|
||||
}
|
||||
|
||||
// Returns an error result with return type to match ExecuteInsert
|
||||
func errInsert(errorFmt string, args ...interface{}) (*InsertResult, error) {
|
||||
return nil, errors.New(fmt.Sprintf(errorFmt, args...))
|
||||
}
|
||||
|
||||
// Returns an error result with return type to match ExecuteInsert
|
||||
func errInsertRow(errorFmt string, args ...interface{}) (row.Row, error) {
|
||||
return nil, errors.New(fmt.Sprintf(errorFmt, args...))
|
||||
}
|
||||
@@ -1,384 +0,0 @@
|
||||
// Copyright 2019 Liquidata, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"vitess.io/vitess/go/vt/sqlparser"
|
||||
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/dtestutils"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
|
||||
. "github.com/liquidata-inc/dolt/go/libraries/doltcore/sql/sqltestutil"
|
||||
"github.com/liquidata-inc/dolt/go/store/types"
|
||||
)
|
||||
|
||||
func mustRow(r row.Row, err error) row.Row {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func TestExecuteInsert(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
insertedValues []row.Row
|
||||
expectedResult InsertResult // root is not compared, but it used for other assertions
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
name: "insert one row, all columns",
|
||||
query: `insert into people (id, first, last, is_married, age, rating, uuid, num_episodes) values
|
||||
(7, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000005', 677)`,
|
||||
insertedValues: []row.Row{NewPeopleRowWithOptionalFields(7, "Maggie", "Simpson", false, 1, 5.1, uuid.MustParse("00000000-0000-0000-0000-000000000005"), 677)},
|
||||
expectedResult: InsertResult{NumRowsInserted: 1},
|
||||
},
|
||||
{
|
||||
name: "insert one row, all columns, negative values",
|
||||
query: `insert into people (id, first, last, is_married, age, rating, uuid, num_episodes) values
|
||||
(-7, "Maggie", "Simpson", false, -1, -5.1, '00000000-0000-0000-0000-000000000005', 677)`,
|
||||
insertedValues: []row.Row{NewPeopleRowWithOptionalFields(-7, "Maggie", "Simpson", false, -1, -5.1, uuid.MustParse("00000000-0000-0000-0000-000000000005"), 677)},
|
||||
expectedResult: InsertResult{NumRowsInserted: 1},
|
||||
},
|
||||
{
|
||||
name: "insert one row, no column list",
|
||||
query: `insert into people values
|
||||
(7, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000005', 677)`,
|
||||
insertedValues: []row.Row{NewPeopleRowWithOptionalFields(7, "Maggie", "Simpson", false, 1, 5.1, uuid.MustParse("00000000-0000-0000-0000-000000000005"), 677)},
|
||||
expectedResult: InsertResult{NumRowsInserted: 1},
|
||||
},
|
||||
{
|
||||
name: "insert one row out of order",
|
||||
query: `insert into people (rating, first, id, last, age, is_married) values
|
||||
(5.1, "Maggie", 7, "Simpson", 1, false)`,
|
||||
insertedValues: []row.Row{NewPeopleRow(7, "Maggie", "Simpson", false, 1, 5.1)},
|
||||
expectedResult: InsertResult{NumRowsInserted: 1},
|
||||
},
|
||||
{
|
||||
name: "insert one row, null values",
|
||||
query: `insert into people (id, first, last, is_married, age, rating) values
|
||||
(7, "Maggie", "Simpson", null, null, null)`,
|
||||
insertedValues: []row.Row{mustRow(row.New(types.Format_7_18, PeopleTestSchema, row.TaggedValues{IdTag: types.Int(7), FirstTag: types.String("Maggie"), LastTag: types.String("Simpson")}))},
|
||||
expectedResult: InsertResult{NumRowsInserted: 1},
|
||||
},
|
||||
{
|
||||
name: "insert one row, null constraint failure",
|
||||
query: `insert into people (id, first, last, is_married, age, rating) values
|
||||
(7, "Maggie", null, null, null, null)`,
|
||||
expectedErr: "Constraint failed for column 'last': Not null",
|
||||
},
|
||||
{
|
||||
name: "duplicate column list",
|
||||
query: `insert into people (id, first, last, is_married, first, age, rating) values
|
||||
(7, "Maggie", "Simpson", null, null, null, null)`,
|
||||
expectedErr: "Repeated column: 'first'",
|
||||
},
|
||||
{
|
||||
name: "insert two rows, all columns",
|
||||
query: `insert into people (id, first, last, is_married, age, rating) values
|
||||
(7, "Maggie", "Simpson", false, 1, 5.1),
|
||||
(8, "Milhouse", "Van Houten", false, 8, 3.5)`,
|
||||
insertedValues: []row.Row{
|
||||
NewPeopleRow(7, "Maggie", "Simpson", false, 1, 5.1),
|
||||
NewPeopleRow(8, "Milhouse", "Van Houten", false, 8, 3.5),
|
||||
},
|
||||
expectedResult: InsertResult{NumRowsInserted: 2},
|
||||
},
|
||||
{
|
||||
name: "insert two rows, one with null constraint failure",
|
||||
query: `insert into people (id, first, last, is_married, age, rating) values
|
||||
(7, "Maggie", "Simpson", false, 1, 5.1),
|
||||
(8, "Milhouse", null, false, 8, 3.5)`,
|
||||
expectedErr: "Constraint failed for column 'last': Not null",
|
||||
},
|
||||
{
|
||||
name: "type mismatch int -> string",
|
||||
query: `insert into people (id, first, last, is_married, age, rating) values
|
||||
(7, "Maggie", 100, false, 1, 5.1)`,
|
||||
expectedErr: "Type mismatch",
|
||||
},
|
||||
{
|
||||
name: "type mismatch int -> bool",
|
||||
query: `insert into people (id, first, last, is_married, age, rating) values
|
||||
(7, "Maggie", "Simpson", 10, 1, 5.1)`,
|
||||
expectedErr: "Type mismatch",
|
||||
},
|
||||
{
|
||||
name: "type mismatch int -> uuid",
|
||||
query: `insert into people (id, first, last, is_married, age, uuid) values
|
||||
(7, "Maggie", "Simpson", false, 1, 100)`,
|
||||
expectedErr: "Type mismatch",
|
||||
},
|
||||
{
|
||||
name: "type mismatch string -> int",
|
||||
query: `insert into people (id, first, last, is_married, age, rating) values
|
||||
("7", "Maggie", "Simpson", false, 1, 5.1)`,
|
||||
expectedErr: "Type mismatch",
|
||||
},
|
||||
{
|
||||
name: "type mismatch string -> float",
|
||||
query: `insert into people (id, first, last, is_married, age, rating) values
|
||||
(7, "Maggie", "Simpson", false, 1, "5.1")`,
|
||||
expectedErr: "Type mismatch",
|
||||
},
|
||||
{
|
||||
name: "type mismatch string -> uint",
|
||||
query: `insert into people (id, first, last, is_married, age, num_episodes) values
|
||||
(7, "Maggie", "Simpson", false, 1, "100")`,
|
||||
expectedErr: "Type mismatch",
|
||||
},
|
||||
{
|
||||
name: "type mismatch string -> uuid",
|
||||
query: `insert into people (id, first, last, is_married, age, uuid) values
|
||||
(7, "Maggie", "Simpson", false, 1, "a uuid but idk what im doing")`,
|
||||
expectedErr: "Type mismatch",
|
||||
},
|
||||
{
|
||||
name: "type mismatch float -> string",
|
||||
query: `insert into people (id, first, last, is_married, age, rating) values
|
||||
(7, 8.1, "Simpson", false, 1, 5.1)`,
|
||||
expectedErr: "Type mismatch",
|
||||
},
|
||||
{
|
||||
name: "type mismatch float -> bool",
|
||||
query: `insert into people (id, first, last, is_married, age, rating) values
|
||||
(7, "Maggie", "Simpson", 0.5, 1, 5.1)`,
|
||||
expectedErr: "Type mismatch",
|
||||
},
|
||||
{
|
||||
name: "type mismatch float -> int",
|
||||
query: `insert into people (id, first, last, is_married, age, rating) values
|
||||
(7, "Maggie", "Simpson", false, 1.0, 5.1)`,
|
||||
expectedErr: "Type mismatch",
|
||||
},
|
||||
{
|
||||
name: "type mismatch bool -> int",
|
||||
query: `insert into people (id, first, last, is_married, age, rating) values
|
||||
(true, "Maggie", "Simpson", false, 1, 5.1)`,
|
||||
expectedErr: "Type mismatch",
|
||||
},
|
||||
{
|
||||
name: "type mismatch bool -> float",
|
||||
query: `insert into people (id, first, last, is_married, age, rating) values
|
||||
(7, "Maggie", "Simpson", false, 1, true)`,
|
||||
expectedErr: "Type mismatch",
|
||||
},
|
||||
{
|
||||
name: "type mismatch bool -> string",
|
||||
query: `insert into people (id, first, last, is_married, age, rating) values
|
||||
(7, true, "Simpson", false, 1, 5.1)`,
|
||||
expectedErr: "Type mismatch",
|
||||
},
|
||||
{
|
||||
name: "type mismatch bool -> uuid",
|
||||
query: `insert into people (id, first, last, is_married, age, uuid) values
|
||||
(7, "Maggie", "Simpson", false, 1, true)`,
|
||||
expectedErr: "Type mismatch",
|
||||
},
|
||||
{
|
||||
name: "insert two rows with ignore, one with null constraint failure",
|
||||
query: `insert ignore into people (id, first, last, is_married, age, rating) values
|
||||
(7, "Maggie", null, false, 1, 5.1),
|
||||
(8, "Milhouse", "Van Houten", false, 8, 3.5)`,
|
||||
insertedValues: []row.Row{
|
||||
NewPeopleRow(8, "Milhouse", "Van Houten", false, 8, 3.5),
|
||||
},
|
||||
expectedResult: InsertResult{NumRowsInserted: 1, NumErrorsIgnored: 1},
|
||||
},
|
||||
{
|
||||
name: "insert existing rows without ignore / replace",
|
||||
query: `insert into people (id, first, last, is_married, age, rating) values
|
||||
(0, "Homer", "Simpson", true, 45, 100)`,
|
||||
expectedErr: "Duplicate primary key: 'id: 0'",
|
||||
},
|
||||
{
|
||||
name: "insert two rows with ignore, one existing in table",
|
||||
query: `insert ignore into people (id, first, last, is_married, age, rating) values
|
||||
(0, "Homer", "Simpson", true, 45, 100),
|
||||
(8, "Milhouse", "Van Houten", false, 8, 3.5)`,
|
||||
insertedValues: []row.Row{
|
||||
NewPeopleRow(8, "Milhouse", "Van Houten", false, 8, 3.5),
|
||||
Homer, // verify that homer is unchanged by the insert
|
||||
},
|
||||
expectedResult: InsertResult{NumRowsInserted: 1, NumErrorsIgnored: 1},
|
||||
},
|
||||
{
|
||||
name: "insert two rows with replace, one existing in table",
|
||||
query: `replace into people (id, first, last, is_married, age, rating) values
|
||||
(0, "Homer", "Simpson", true, 45, 100),
|
||||
(8, "Milhouse", "Van Houten", false, 8, 3.5)`,
|
||||
insertedValues: []row.Row{
|
||||
NewPeopleRow(0, "Homer", "Simpson", true, 45, 100),
|
||||
NewPeopleRow(8, "Milhouse", "Van Houten", false, 8, 3.5),
|
||||
},
|
||||
expectedResult: InsertResult{NumRowsInserted: 1, NumRowsUpdated: 1},
|
||||
},
|
||||
{
|
||||
name: "insert two rows with replace ignore, one with errors",
|
||||
query: `replace ignore into people (id, first, last, is_married, age, rating) values
|
||||
(0, "Homer", "Simpson", true, 45, 100),
|
||||
(8, "Milhouse", "Van Houten", false, 8, 3.5),
|
||||
(7, "Maggie", null, false, 1, 5.1)`,
|
||||
insertedValues: []row.Row{
|
||||
NewPeopleRow(0, "Homer", "Simpson", true, 45, 100),
|
||||
NewPeopleRow(8, "Milhouse", "Van Houten", false, 8, 3.5),
|
||||
},
|
||||
expectedResult: InsertResult{NumRowsInserted: 1, NumRowsUpdated: 1, NumErrorsIgnored: 1},
|
||||
},
|
||||
{
|
||||
name: "insert two rows with replace, one with errors",
|
||||
query: `replace into people (id, first, last, is_married, age, rating) values
|
||||
(0, "Homer", "Simpson", true, 45, 100),
|
||||
(8, "Milhouse", "Van Houten", false, 8, 3.5),
|
||||
(7, "Maggie", null, false, 1, 5.1)`,
|
||||
expectedErr: "Constraint failed for column 'last': Not null",
|
||||
},
|
||||
{
|
||||
name: "insert five rows, all columns",
|
||||
query: `insert into people (id, first, last, is_married, age, rating) values
|
||||
(7, "Maggie", "Simpson", false, 1, 5.1),
|
||||
(8, "Milhouse", "Van Houten", false, 8, 3.5),
|
||||
(9, "Jacqueline", "Bouvier", true, 80, 2),
|
||||
(10, "Patty", "Bouvier", false, 40, 7),
|
||||
(11, "Selma", "Bouvier", false, 40, 7)`,
|
||||
insertedValues: []row.Row{
|
||||
NewPeopleRow(7, "Maggie", "Simpson", false, 1, 5.1),
|
||||
NewPeopleRow(8, "Milhouse", "Van Houten", false, 8, 3.5),
|
||||
NewPeopleRow(9, "Jacqueline", "Bouvier", true, 80, 2),
|
||||
NewPeopleRow(10, "Patty", "Bouvier", false, 40, 7),
|
||||
NewPeopleRow(11, "Selma", "Bouvier", false, 40, 7),
|
||||
},
|
||||
expectedResult: InsertResult{NumRowsInserted: 5},
|
||||
},
|
||||
{
|
||||
name: "insert two rows, only required columns",
|
||||
query: `insert into people (id, first, last) values
|
||||
(7, "Maggie", "Simpson"),
|
||||
(8, "Milhouse", "Van Houten")`,
|
||||
insertedValues: []row.Row{
|
||||
mustRow(row.New(types.Format_7_18, PeopleTestSchema, row.TaggedValues{IdTag: types.Int(7), FirstTag: types.String("Maggie"), LastTag: types.String("Simpson")})),
|
||||
mustRow(row.New(types.Format_7_18, PeopleTestSchema, row.TaggedValues{IdTag: types.Int(8), FirstTag: types.String("Milhouse"), LastTag: types.String("Van Houten")})),
|
||||
},
|
||||
expectedResult: InsertResult{NumRowsInserted: 2},
|
||||
},
|
||||
|
||||
{
|
||||
name: "insert two rows, duplicate id",
|
||||
query: `insert into people (id, first, last, is_married, age, rating) values
|
||||
(7, "Maggie", "Simpson", false, 1, 5.1),
|
||||
(7, "Milhouse", "Van Houten", false, 8, 3.5)`,
|
||||
expectedErr: "Duplicate primary key: 'id: 7'",
|
||||
},
|
||||
{
|
||||
name: "insert two rows, duplicate id with ignore",
|
||||
query: `insert ignore into people (id, first, last, is_married, age, rating) values
|
||||
(7, "Maggie", "Simpson", false, 1, 5.1),
|
||||
(7, "Milhouse", "Van Houten", false, 8, 3.5)`,
|
||||
insertedValues: []row.Row{
|
||||
NewPeopleRow(7, "Maggie", "Simpson", false, 1, 5.1),
|
||||
},
|
||||
expectedResult: InsertResult{NumRowsInserted: 1, NumErrorsIgnored: 1},
|
||||
},
|
||||
{
|
||||
name: "insert no primary keys",
|
||||
query: `insert into people (age) values (7), (8)`,
|
||||
expectedErr: "One or more primary key columns missing from insert statement",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
if len(tt.expectedErr) > 0 {
|
||||
require.Equal(t, InsertResult{}, tt.expectedResult, "incorrect test setup: cannot assert both an error and expected results")
|
||||
require.Nil(t, tt.insertedValues, "incorrect test setup: cannot assert both an error and inserted values")
|
||||
}
|
||||
|
||||
dEnv := dtestutils.CreateTestEnv()
|
||||
ctx := context.Background()
|
||||
|
||||
CreateTestDatabase(dEnv, t)
|
||||
root, _ := dEnv.WorkingRoot(ctx)
|
||||
|
||||
sqlStatement, _ := sqlparser.Parse(tt.query)
|
||||
s := sqlStatement.(*sqlparser.Insert)
|
||||
|
||||
result, err := ExecuteInsert(ctx, dEnv.DoltDB, root, s)
|
||||
|
||||
if len(tt.expectedErr) > 0 {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.expectedErr)
|
||||
return
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.expectedResult.NumRowsInserted, result.NumRowsInserted)
|
||||
assert.Equal(t, tt.expectedResult.NumErrorsIgnored, result.NumErrorsIgnored)
|
||||
assert.Equal(t, tt.expectedResult.NumRowsUpdated, result.NumRowsUpdated)
|
||||
|
||||
table, ok, err := result.Root.GetTable(ctx, PeopleTableName)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
for _, expectedRow := range tt.insertedValues {
|
||||
v, err := expectedRow.NomsMapKey(PeopleTestSchema).Value(ctx)
|
||||
assert.NoError(t, err)
|
||||
foundRow, ok, err := table.GetRow(ctx, v.(types.Tuple), PeopleTestSchema)
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok, "Row not found: %v", expectedRow)
|
||||
eq, diff := rowsEqual(expectedRow, foundRow)
|
||||
assert.True(t, eq, "Rows not equals, found diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func rowsEqual(expected, actual row.Row) (bool, string) {
|
||||
er, ar := make(map[uint64]types.Value), make(map[uint64]types.Value)
|
||||
_, err := expected.IterCols(func(t uint64, v types.Value) (bool, error) {
|
||||
er[t] = v
|
||||
return false, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
_, err = actual.IterCols(func(t uint64, v types.Value) (bool, error) {
|
||||
ar[t] = v
|
||||
return false, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
opts := cmp.Options{cmp.AllowUnexported(), dtestutils.FloatComparer}
|
||||
eq := cmp.Equal(er, ar, opts)
|
||||
var diff string
|
||||
if !eq {
|
||||
diff = cmp.Diff(er, ar, opts)
|
||||
}
|
||||
return eq, diff
|
||||
}
|
||||
@@ -29,6 +29,9 @@ import (
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/table/untyped/resultset"
|
||||
)
|
||||
|
||||
var ErrMissingPrimaryKeys = errors.New("One or more primary key columns missing from insert statement")
|
||||
var ConstraintFailedFmt = "Constraint failed for column '%v': %v"
|
||||
|
||||
type UpdateResult struct {
|
||||
Root *doltdb.RootValue
|
||||
NumRowsUpdated int
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -372,3 +373,32 @@ func TestExecuteUpdate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func rowsEqual(expected, actual row.Row) (bool, string) {
|
||||
er, ar := make(map[uint64]types.Value), make(map[uint64]types.Value)
|
||||
_, err := expected.IterCols(func(t uint64, v types.Value) (bool, error) {
|
||||
er[t] = v
|
||||
return false, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
_, err = actual.IterCols(func(t uint64, v types.Value) (bool, error) {
|
||||
ar[t] = v
|
||||
return false, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
opts := cmp.Options{cmp.AllowUnexported(), dtestutils.FloatComparer}
|
||||
eq := cmp.Equal(er, ar, opts)
|
||||
var diff string
|
||||
if !eq {
|
||||
diff = cmp.Diff(er, ar, opts)
|
||||
}
|
||||
return eq, diff
|
||||
}
|
||||
|
||||
@@ -29,20 +29,42 @@ import (
|
||||
|
||||
var _ sql.Database = (*Database)(nil)
|
||||
|
||||
type batchMode bool
|
||||
|
||||
const (
|
||||
batched batchMode = true
|
||||
single batchMode = false
|
||||
)
|
||||
|
||||
// Database implements sql.Database for a dolt DB.
|
||||
type Database struct {
|
||||
sql.Database
|
||||
name string
|
||||
root *doltdb.RootValue
|
||||
dEnv *env.DoltEnv
|
||||
name string
|
||||
root *doltdb.RootValue
|
||||
dEnv *env.DoltEnv
|
||||
batchMode batchMode
|
||||
tables map[string]*DoltTable
|
||||
}
|
||||
|
||||
// NewDatabase returns a new dolt databae to use in queries.
|
||||
// NewDatabase returns a new dolt database to use in queries.
|
||||
func NewDatabase(name string, root *doltdb.RootValue, dEnv *env.DoltEnv) *Database {
|
||||
return &Database{
|
||||
name: name,
|
||||
root: root,
|
||||
dEnv: dEnv,
|
||||
name: name,
|
||||
root: root,
|
||||
dEnv: dEnv,
|
||||
batchMode: single,
|
||||
tables: make(map[string]*DoltTable),
|
||||
}
|
||||
}
|
||||
|
||||
// NewBatchedDatabase returns a new dolt database executing in batch insert mode. Integrators must call Flush() to
|
||||
// commit any outstanding edits.
|
||||
func NewBatchedDatabase(name string, root *doltdb.RootValue, dEnv *env.DoltEnv) *Database {
|
||||
return &Database{
|
||||
name: name,
|
||||
root: root,
|
||||
dEnv: dEnv,
|
||||
batchMode: batched,
|
||||
tables: make(map[string]*DoltTable),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,12 +102,16 @@ func (db *Database) GetTableInsensitive(ctx context.Context, tblName string) (sq
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
if table, ok := db.tables[exactName]; ok {
|
||||
return table, true, nil
|
||||
}
|
||||
|
||||
tbl, ok, err := db.root.GetTable(ctx, exactName)
|
||||
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
} else if !ok {
|
||||
panic("Name '" + exactName + "'had already been verified... This is a bug")
|
||||
panic("Name '" + exactName + "' had already been verified... This is a bug")
|
||||
}
|
||||
|
||||
sch, err := tbl.GetSchema(ctx)
|
||||
@@ -94,7 +120,9 @@ func (db *Database) GetTableInsensitive(ctx context.Context, tblName string) (sq
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return &DoltTable{name: tblName, table: tbl, sch: sch, db: db}, true, nil
|
||||
table := &DoltTable{name: exactName, table: tbl, sch: sch, db: db}
|
||||
db.tables[exactName] = table
|
||||
return table, true, nil
|
||||
}
|
||||
|
||||
func (db *Database) GetTableNames(ctx context.Context) ([]string, error) {
|
||||
@@ -129,6 +157,8 @@ func (db *Database) DropTable(ctx *sql.Context, tableName string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(db.tables, tableName)
|
||||
|
||||
db.SetRoot(newRoot)
|
||||
|
||||
return nil
|
||||
@@ -176,3 +206,14 @@ func (db *Database) CreateTable(ctx *sql.Context, tableName string, schema sql.S
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flushes the current batch of outstanding changes and returns any errors.
|
||||
func (db *Database) Flush(ctx context.Context) error {
|
||||
for name, table := range db.tables {
|
||||
if err := table.flushBatchedEdits(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
delete(db.tables, name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
+81
-43
@@ -12,18 +12,21 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sql
|
||||
package sqle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/uuid"
|
||||
sqle "github.com/src-d/go-mysql-server"
|
||||
"github.com/src-d/go-mysql-server/sql"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"vitess.io/vitess/go/vt/sqlparser"
|
||||
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/dtestutils"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
|
||||
@@ -63,17 +66,14 @@ func TestSqlBatchInserts(t *testing.T) {
|
||||
CreateTestDatabase(dEnv, t)
|
||||
root, _ := dEnv.WorkingRoot(ctx)
|
||||
|
||||
batcher := NewSqlBatcher(dEnv.DoltDB, root)
|
||||
engine := sqle.NewDefault()
|
||||
db := NewBatchedDatabase("dolt", root, dEnv)
|
||||
engine.AddDatabase(db)
|
||||
|
||||
for _, stmt := range insertStatements {
|
||||
statement, err := sqlparser.Parse(stmt)
|
||||
_, rowIter, err := engine.Query(sql.NewEmptyContext(), stmt)
|
||||
require.NoError(t, err)
|
||||
insertStmt, ok := statement.(*sqlparser.Insert)
|
||||
require.True(t, ok)
|
||||
result, err := ExecuteBatchInsert(context.Background(), root, insertStmt, batcher)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result.NumRowsInserted > 0)
|
||||
assert.Equal(t, 0, result.NumRowsUpdated)
|
||||
assert.Equal(t, 0, result.NumErrorsIgnored)
|
||||
require.NoError(t, drainIter(rowIter))
|
||||
}
|
||||
|
||||
// Before committing the batch, the database should be unchanged from its original state
|
||||
@@ -89,7 +89,7 @@ func TestSqlBatchInserts(t *testing.T) {
|
||||
assert.ElementsMatch(t, AllAppsRows, allAppearanceRows)
|
||||
|
||||
// Now commit the batch and check for new rows
|
||||
root, err = batcher.Commit(ctx)
|
||||
err = db.Flush(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
var expectedPeople, expectedEpisodes, expectedAppearances []row.Row
|
||||
@@ -124,6 +124,7 @@ func TestSqlBatchInserts(t *testing.T) {
|
||||
newAppsRow(11, 9),
|
||||
)
|
||||
|
||||
root = db.Root()
|
||||
allPeopleRows, err = GetAllRows(root, PeopleTableName)
|
||||
require.NoError(t, err)
|
||||
allEpsRows, err = GetAllRows(root, EpisodesTableName)
|
||||
@@ -136,15 +137,27 @@ func TestSqlBatchInserts(t *testing.T) {
|
||||
assertRowSetsEqual(t, expectedAppearances, allAppearanceRows)
|
||||
}
|
||||
|
||||
func drainIter(iter sql.RowIter) error {
|
||||
var returnedErr error
|
||||
for {
|
||||
_, err := iter.Next()
|
||||
if err == io.EOF {
|
||||
return returnedErr
|
||||
} else if err != nil {
|
||||
returnedErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlBatchInsertIgnoreReplace(t *testing.T) {
|
||||
t.Skip("Skipped until insert ignore statements supported in go-mysql-server")
|
||||
|
||||
insertStatements := []string{
|
||||
`replace into people (id, first, last, is_married, age, rating, uuid, num_episodes) values
|
||||
(0, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000007', 677)`,
|
||||
`insert ignore into people values
|
||||
(2, "Milhouse", "VanHouten", false, 1, 5.1, '00000000-0000-0000-0000-000000000008', 677)`,
|
||||
}
|
||||
numRowsUpdated := []int{1, 0}
|
||||
numErrorsIgnored := []int{0, 1}
|
||||
|
||||
dEnv := dtestutils.CreateTestEnv()
|
||||
ctx := context.Background()
|
||||
@@ -152,18 +165,14 @@ func TestSqlBatchInsertIgnoreReplace(t *testing.T) {
|
||||
CreateTestDatabase(dEnv, t)
|
||||
root, _ := dEnv.WorkingRoot(ctx)
|
||||
|
||||
batcher := NewSqlBatcher(dEnv.DoltDB, root)
|
||||
for i := range insertStatements {
|
||||
stmt := insertStatements[i]
|
||||
statement, err := sqlparser.Parse(stmt)
|
||||
engine := sqle.NewDefault()
|
||||
db := NewBatchedDatabase("dolt", root, dEnv)
|
||||
engine.AddDatabase(db)
|
||||
|
||||
for _, stmt := range insertStatements {
|
||||
_, rowIter, err := engine.Query(sql.NewEmptyContext(), stmt)
|
||||
require.NoError(t, err)
|
||||
insertStmt, ok := statement.(*sqlparser.Insert)
|
||||
require.True(t, ok)
|
||||
result, err := ExecuteBatchInsert(context.Background(), root, insertStmt, batcher)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, result.NumRowsInserted)
|
||||
assert.Equal(t, numRowsUpdated[i], result.NumRowsUpdated)
|
||||
assert.Equal(t, numErrorsIgnored[i], result.NumErrorsIgnored)
|
||||
drainIter(rowIter)
|
||||
}
|
||||
|
||||
// Before committing the batch, the database should be unchanged from its original state
|
||||
@@ -172,7 +181,7 @@ func TestSqlBatchInsertIgnoreReplace(t *testing.T) {
|
||||
assert.ElementsMatch(t, AllPeopleRows, allPeopleRows)
|
||||
|
||||
// Now commit the batch and check for new rows
|
||||
root, err = batcher.Commit(ctx)
|
||||
err = db.Flush(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
var expectedPeople []row.Row
|
||||
@@ -188,29 +197,29 @@ func TestSqlBatchInsertIgnoreReplace(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlBatchInsertErrors(t *testing.T) {
|
||||
insertStatements := []string{
|
||||
`insert into people (id, first, last, is_married, age, rating, uuid, num_episodes) values
|
||||
(0, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000007', 677)`,
|
||||
`insert into people values
|
||||
(2, "Milhouse", "VanHouten", false, 1, 5.1, true, 677)`,
|
||||
}
|
||||
|
||||
dEnv := dtestutils.CreateTestEnv()
|
||||
ctx := context.Background()
|
||||
|
||||
CreateTestDatabase(dEnv, t)
|
||||
root, _ := dEnv.WorkingRoot(ctx)
|
||||
|
||||
batcher := NewSqlBatcher(dEnv.DoltDB, root)
|
||||
for i := range insertStatements {
|
||||
stmt := insertStatements[i]
|
||||
statement, err := sqlparser.Parse(stmt)
|
||||
require.NoError(t, err)
|
||||
insertStmt, ok := statement.(*sqlparser.Insert)
|
||||
require.True(t, ok)
|
||||
_, err = ExecuteBatchInsert(context.Background(), root, insertStmt, batcher)
|
||||
require.Error(t, err)
|
||||
}
|
||||
engine := sqle.NewDefault()
|
||||
db := NewBatchedDatabase("dolt", root, dEnv)
|
||||
engine.AddDatabase(db)
|
||||
|
||||
_, rowIter, err := engine.Query(sql.NewEmptyContext(), `insert into people (id, first, last, is_married, age, rating, uuid, num_episodes) values
|
||||
(0, "Maggie", "Simpson", false, 1, 5.1, '00000000-0000-0000-0000-000000000007', 677)`)
|
||||
// This won't generate an error until we commit the batch (duplicate key)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, drainIter(rowIter))
|
||||
|
||||
// This generates an error at insert time because of the bad type for the uuid column
|
||||
_, _, err = engine.Query(sql.NewEmptyContext(), `insert into people values
|
||||
(2, "Milhouse", "VanHouten", false, 1, 5.1, true, 677)`)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Error from the first statement appears here
|
||||
assert.Error(t, db.Flush(ctx))
|
||||
}
|
||||
|
||||
func assertRowSetsEqual(t *testing.T, expected, actual []row.Row) {
|
||||
@@ -244,6 +253,35 @@ func containsRow(rs []row.Row, r row.Row) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func rowsEqual(expected, actual row.Row) (bool, string) {
|
||||
er, ar := make(map[uint64]types.Value), make(map[uint64]types.Value)
|
||||
_, err := expected.IterCols(func(t uint64, v types.Value) (bool, error) {
|
||||
er[t] = v
|
||||
return false, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
_, err = actual.IterCols(func(t uint64, v types.Value) (bool, error) {
|
||||
ar[t] = v
|
||||
return false, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
opts := cmp.Options{cmp.AllowUnexported(), dtestutils.FloatComparer}
|
||||
eq := cmp.Equal(er, ar, opts)
|
||||
var diff string
|
||||
if !eq {
|
||||
diff = cmp.Diff(er, ar, opts)
|
||||
}
|
||||
return eq, diff
|
||||
}
|
||||
|
||||
func newPeopleRow(id int, first, last string) row.Row {
|
||||
vals := row.TaggedValues{
|
||||
IdTag: types.Int(id),
|
||||
@@ -34,7 +34,9 @@ import (
|
||||
// Executes all the SQL non-select statements given in the string against the root value given and returns the updated
|
||||
// root, or an error. Statements in the input string are split by `;\n`
|
||||
func ExecuteSql(dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*doltdb.RootValue, error) {
|
||||
batcher := dsql.NewSqlBatcher(dEnv.DoltDB, root)
|
||||
engine := sqle.NewDefault()
|
||||
db := dsqle.NewBatchedDatabase("dolt", root, dEnv)
|
||||
engine.AddDatabase(db)
|
||||
|
||||
for _, query := range strings.Split(statements, ";\n") {
|
||||
if len(strings.Trim(query, " ")) == 0 {
|
||||
@@ -53,9 +55,13 @@ func ExecuteSql(dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*
|
||||
case *sqlparser.Select, *sqlparser.OtherRead:
|
||||
return nil, errors.New("Select statements aren't handled")
|
||||
case *sqlparser.Insert:
|
||||
_, execErr = dsql.ExecuteBatchInsert(context.Background(), root, s, batcher)
|
||||
var rowIter sql.RowIter
|
||||
_, rowIter, execErr = engine.Query(sql.NewEmptyContext(), query)
|
||||
if execErr == nil {
|
||||
execErr = drainIter(rowIter)
|
||||
}
|
||||
case *sqlparser.DDL:
|
||||
if root, err = batcher.Commit(context.Background()); err != nil {
|
||||
if err = db.Flush(context.Background()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, execErr = sqlparser.ParseStrictDDL(query)
|
||||
@@ -63,9 +69,7 @@ func ExecuteSql(dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*
|
||||
return nil, fmt.Errorf("Error parsing DDL: %v.", execErr.Error())
|
||||
}
|
||||
root, execErr = sqlDDL(dEnv, root, s, query)
|
||||
if err := batcher.UpdateRoot(root); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db.SetRoot(root)
|
||||
default:
|
||||
return nil, fmt.Errorf("Unsupported SQL statement: '%v'.", query)
|
||||
}
|
||||
@@ -75,13 +79,11 @@ func ExecuteSql(dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*
|
||||
}
|
||||
}
|
||||
|
||||
if newRoot, err := batcher.Commit(context.Background()); newRoot != nil {
|
||||
root = newRoot
|
||||
} else if err != nil {
|
||||
if err := db.Flush(context.Background()); err == nil {
|
||||
return db.Root(), nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return root, nil
|
||||
}
|
||||
|
||||
func sqlDDL(dEnv *env.DoltEnv, root *doltdb.RootValue, ddl *sqlparser.DDL, query string) (*doltdb.RootValue, error) {
|
||||
@@ -136,3 +138,15 @@ func ExecuteSelect(root *doltdb.RootValue, query string) ([]sql.Row, error) {
|
||||
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func drainIter(iter sql.RowIter) error {
|
||||
var returnedErr error
|
||||
for {
|
||||
_, err := iter.Next()
|
||||
if err == io.EOF {
|
||||
return returnedErr
|
||||
} else if err != nil {
|
||||
returnedErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,224 @@
|
||||
// Copyright 2019 Liquidata, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/src-d/go-mysql-server/sql"
|
||||
|
||||
"github.com/liquidata-inc/dolt/go/cmd/dolt/errhand"
|
||||
"github.com/liquidata-inc/dolt/go/store/hash"
|
||||
"github.com/liquidata-inc/dolt/go/store/types"
|
||||
)
|
||||
|
||||
var ErrDuplicatePrimaryKeyFmt = "duplicate primary key given: (%v)"
|
||||
|
||||
// tableEditor supports making multiple row edits (inserts, updates, deletes) to a table. It does error checking for key
|
||||
// collision etc. in the Close() method, as well as during Insert / Update.
|
||||
// Right now a table editor allows you to combine inserts, updates, and deletes in any order, and makes reasonable
|
||||
// attempts to produce correct results when doing so. But this probably (definitely) doesn't work in every case, and
|
||||
// higher-level clients should carefully flush the editor when necessary (i.e. before an update after many inserts).
|
||||
type tableEditor struct {
|
||||
t *DoltTable
|
||||
ed *types.MapEditor
|
||||
insertedKeys map[hash.Hash]types.Value
|
||||
addedKeys map[hash.Hash]types.Value
|
||||
removedKeys map[hash.Hash]types.Value
|
||||
}
|
||||
|
||||
var _ sql.RowReplacer = (*tableEditor)(nil)
|
||||
var _ sql.RowUpdater = (*tableEditor)(nil)
|
||||
var _ sql.RowInserter = (*tableEditor)(nil)
|
||||
var _ sql.RowDeleter = (*tableEditor)(nil)
|
||||
|
||||
func newTableEditor(t *DoltTable) *tableEditor {
|
||||
return &tableEditor{
|
||||
t: t,
|
||||
insertedKeys: make(map[hash.Hash]types.Value),
|
||||
addedKeys: make(map[hash.Hash]types.Value),
|
||||
removedKeys: make(map[hash.Hash]types.Value),
|
||||
}
|
||||
}
|
||||
|
||||
func (te *tableEditor) Insert(ctx *sql.Context, sqlRow sql.Row) error {
|
||||
dRow, err := SqlRowToDoltRow(te.t.table.Format(), sqlRow, te.t.sch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key, err := dRow.NomsMapKey(te.t.sch).Value(ctx)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("failed to get row key").AddCause(err).Build()
|
||||
}
|
||||
|
||||
hash, err := key.Hash(dRow.Format())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If we've already inserted this key as part of this insert operation, that's an error. Inserting a row that already
|
||||
// exists in the table will be handled in Close().
|
||||
if _, ok := te.addedKeys[hash]; ok {
|
||||
value, err := types.EncodedValue(ctx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf(ErrDuplicatePrimaryKeyFmt, value)
|
||||
}
|
||||
te.insertedKeys[hash] = key
|
||||
te.addedKeys[hash] = key
|
||||
|
||||
if te.ed == nil {
|
||||
te.ed, err = te.t.newMapEditor(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
te.ed = te.ed.Set(key, dRow.NomsMapValue(te.t.sch))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (te *tableEditor) Delete(ctx *sql.Context, sqlRow sql.Row) error {
|
||||
dRow, err := SqlRowToDoltRow(te.t.table.Format(), sqlRow, te.t.sch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key, err := dRow.NomsMapKey(te.t.sch).Value(ctx)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("failed to get row key").AddCause(err).Build()
|
||||
}
|
||||
hash, err := key.Hash(dRow.Format())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(te.addedKeys, hash)
|
||||
te.removedKeys[hash] = key
|
||||
|
||||
if te.ed == nil {
|
||||
te.ed, err = te.t.newMapEditor(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
te.ed = te.ed.Remove(key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *DoltTable) newMapEditor(ctx context.Context) (*types.MapEditor, error) {
|
||||
typesMap, err := t.table.GetRowData(ctx)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("failed to get row data.").AddCause(err).Build()
|
||||
}
|
||||
|
||||
return typesMap.Edit(), nil
|
||||
}
|
||||
|
||||
func (te *tableEditor) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Row) error {
|
||||
dOldRow, err := SqlRowToDoltRow(te.t.table.Format(), oldRow, te.t.sch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dNewRow, err := SqlRowToDoltRow(te.t.table.Format(), newRow, te.t.sch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the PK is changed then we need to delete the old value and insert the new one
|
||||
dOldKey := dOldRow.NomsMapKey(te.t.sch)
|
||||
dOldKeyVal, err := dOldKey.Value(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dNewKey := dNewRow.NomsMapKey(te.t.sch)
|
||||
dNewKeyVal, err := dNewKey.Value(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !dOldKeyVal.Equals(dNewKeyVal) {
|
||||
oldHash, err := dOldKeyVal.Hash(dOldRow.Format())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newHash, err := dNewKeyVal.Hash(dNewRow.Format())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the old value of the primary key we just updated was previously inserted, then we need to remove it now.
|
||||
if _, ok := te.insertedKeys[oldHash]; ok {
|
||||
delete(te.insertedKeys, oldHash)
|
||||
te.ed.Remove(dOldKeyVal)
|
||||
}
|
||||
|
||||
te.addedKeys[newHash] = dNewKeyVal
|
||||
te.removedKeys[oldHash] = dOldKeyVal
|
||||
}
|
||||
|
||||
if te.ed == nil {
|
||||
te.ed, err = te.t.newMapEditor(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
te.ed.Set(dNewKeyVal, dNewRow.NomsMapValue(te.t.sch))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close implements Closer
|
||||
func (te *tableEditor) Close(ctx *sql.Context) error {
|
||||
// If we're running in batched mode, don't flush the edits until explicitly told to do so by the parent table.
|
||||
if te.t.db.batchMode == batched {
|
||||
return nil
|
||||
}
|
||||
return te.flush(ctx)
|
||||
}
|
||||
|
||||
func (te *tableEditor) flush(ctx context.Context) error {
|
||||
// For all added keys, check for and report a collision
|
||||
for hash, addedKey := range te.addedKeys {
|
||||
if _, ok := te.removedKeys[hash]; !ok {
|
||||
_, rowExists, err := te.t.table.GetRow(ctx, addedKey.(types.Tuple), te.t.sch)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("failed to read table").AddCause(err).Build()
|
||||
}
|
||||
if rowExists {
|
||||
value, err := types.EncodedValue(ctx, addedKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf(ErrDuplicatePrimaryKeyFmt, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
// For all removed keys, remove the map entries that weren't added elsewhere by other updates
|
||||
for hash, removedKey := range te.removedKeys {
|
||||
if _, ok := te.addedKeys[hash]; !ok {
|
||||
te.ed.Remove(removedKey)
|
||||
}
|
||||
}
|
||||
|
||||
if te.ed != nil {
|
||||
return te.t.updateTable(ctx, te.ed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,190 @@
|
||||
// Copyright 2019 Liquidata, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/src-d/go-mysql-server/sql"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/dtestutils"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema"
|
||||
. "github.com/liquidata-inc/dolt/go/libraries/doltcore/sql/sqltestutil"
|
||||
)
|
||||
|
||||
type tableEditorTest struct {
|
||||
// The name of this test. Names should be unique and descriptive.
|
||||
name string
|
||||
// Test setup to run
|
||||
setup func(ctx *sql.Context, t *testing.T, ed *tableEditor)
|
||||
// The select query to run to verify the results
|
||||
selectQuery string
|
||||
// The rows this query should return, nil if an error is expected
|
||||
expectedRows []row.Row
|
||||
// Expected error string, if any
|
||||
expectedErr string
|
||||
}
|
||||
|
||||
func TestTableEditor(t *testing.T) {
|
||||
edna := NewPeopleRow(10, "Edna", "Krabapple", false, 38, 8.0)
|
||||
krusty := NewPeopleRow(11, "Krusty", "Klown", false, 48, 9.5)
|
||||
smithers := NewPeopleRow(12, "Waylon", "Smithers", false, 44, 7.1)
|
||||
ralph := NewPeopleRow(13, "Ralph", "Wiggum", false, 9, 9.1)
|
||||
martin := NewPeopleRow(14, "Martin", "Prince", false, 11, 6.3)
|
||||
skinner := NewPeopleRow(15, "Seymore", "Skinner", false, 50, 8.3)
|
||||
fatTony := NewPeopleRow(16, "Fat", "Tony", false, 53, 5.0)
|
||||
troyMclure := NewPeopleRow(17, "Troy", "McClure", false, 58, 7.0)
|
||||
|
||||
var expectedErr error
|
||||
// Some of these are pretty exotic use cases, but since we support all these operations it's nice to know they work
|
||||
// in tandem.
|
||||
testCases := []tableEditorTest{
|
||||
{
|
||||
name: "all inserts",
|
||||
setup: func(ctx *sql.Context, t *testing.T, ed *tableEditor) {
|
||||
require.NoError(t, ed.Insert(ctx, r(edna, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Insert(ctx, r(krusty, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Insert(ctx, r(smithers, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Insert(ctx, r(ralph, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Insert(ctx, r(martin, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Insert(ctx, r(skinner, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Insert(ctx, r(fatTony, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Insert(ctx, r(troyMclure, PeopleTestSchema)))
|
||||
},
|
||||
selectQuery: "select * from people where id >= 10",
|
||||
expectedRows: CompressRows(PeopleTestSchema,
|
||||
edna, krusty, smithers, ralph, martin, skinner, fatTony, troyMclure,
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "inserts and deletes",
|
||||
setup: func(ctx *sql.Context, t *testing.T, ed *tableEditor) {
|
||||
require.NoError(t, ed.Insert(ctx, r(edna, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Insert(ctx, r(krusty, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Delete(ctx, r(edna, PeopleTestSchema)))
|
||||
},
|
||||
selectQuery: "select * from people where id >= 10",
|
||||
expectedRows: CompressRows(PeopleTestSchema,
|
||||
krusty,
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "inserts and deletes 2",
|
||||
setup: func(ctx *sql.Context, t *testing.T, ed *tableEditor) {
|
||||
require.NoError(t, ed.Insert(ctx, r(edna, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Insert(ctx, r(krusty, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Delete(ctx, r(edna, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Insert(ctx, r(fatTony, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Delete(ctx, r(Homer, PeopleTestSchema)))
|
||||
},
|
||||
selectQuery: "select * from people where id >= 10 or id = 0",
|
||||
expectedRows: CompressRows(PeopleTestSchema,
|
||||
krusty, fatTony,
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "inserts and updates",
|
||||
setup: func(ctx *sql.Context, t *testing.T, ed *tableEditor) {
|
||||
require.NoError(t, ed.Insert(ctx, r(edna, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Insert(ctx, r(krusty, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Update(ctx, r(edna, PeopleTestSchema), r(MutateRow(edna, AgeTag, 1), PeopleTestSchema)))
|
||||
},
|
||||
selectQuery: "select * from people where id >= 10",
|
||||
expectedRows: CompressRows(PeopleTestSchema,
|
||||
MutateRow(edna, AgeTag, 1),
|
||||
krusty,
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "inserts updates and deletes",
|
||||
setup: func(ctx *sql.Context, t *testing.T, ed *tableEditor) {
|
||||
require.NoError(t, ed.Insert(ctx, r(edna, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Insert(ctx, r(krusty, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Update(ctx, r(edna, PeopleTestSchema), r(MutateRow(edna, AgeTag, 1), PeopleTestSchema)))
|
||||
require.NoError(t, ed.Insert(ctx, r(smithers, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Insert(ctx, r(ralph, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Update(ctx, r(smithers, PeopleTestSchema), r(MutateRow(smithers, AgeTag, 1), PeopleTestSchema)))
|
||||
require.NoError(t, ed.Delete(ctx, r(smithers, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Insert(ctx, r(skinner, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Delete(ctx, r(ralph, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Insert(ctx, r(ralph, PeopleTestSchema)))
|
||||
},
|
||||
selectQuery: "select * from people where id >= 10",
|
||||
expectedRows: CompressRows(PeopleTestSchema,
|
||||
MutateRow(edna, AgeTag, 1),
|
||||
krusty,
|
||||
ralph,
|
||||
skinner,
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "inserts and updates to primary key",
|
||||
setup: func(ctx *sql.Context, t *testing.T, ed *tableEditor) {
|
||||
require.NoError(t, ed.Insert(ctx, r(edna, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Insert(ctx, r(krusty, PeopleTestSchema)))
|
||||
require.NoError(t, ed.Update(ctx, r(edna, PeopleTestSchema), r(MutateRow(edna, IdTag, 30), PeopleTestSchema)))
|
||||
},
|
||||
selectQuery: "select * from people where id >= 10",
|
||||
expectedRows: CompressRows(PeopleTestSchema,
|
||||
krusty,
|
||||
MutateRow(edna, IdTag, 30),
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
expectedErr = nil
|
||||
|
||||
dEnv := dtestutils.CreateTestEnv()
|
||||
CreateTestDatabase(dEnv, t)
|
||||
|
||||
ctx := sql.NewEmptyContext()
|
||||
root, _ := dEnv.WorkingRoot(context.Background())
|
||||
db := NewDatabase("dolt", root, dEnv)
|
||||
peopleTable, _, err := db.GetTableInsensitive(ctx, "people")
|
||||
require.NoError(t, err)
|
||||
|
||||
dt := peopleTable.(sql.UpdatableTable)
|
||||
ed := dt.Updater(ctx).(*tableEditor)
|
||||
|
||||
test.setup(ctx, t, ed)
|
||||
if len(test.expectedErr) > 0 {
|
||||
require.Error(t, expectedErr)
|
||||
assert.Contains(t, expectedErr.Error(), test.expectedErr)
|
||||
return
|
||||
} else {
|
||||
require.NoError(t, ed.Close(ctx))
|
||||
}
|
||||
|
||||
root = db.Root()
|
||||
actualRows, _, err := executeSelect(context.Background(), dEnv, CompressSchema(PeopleTestSchema), root, test.selectQuery)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expectedRows, actualRows)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func r(row row.Row, sch schema.Schema) sql.Row {
|
||||
sqlRow, err := doltRowToSqlRow(row, sch)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return sqlRow
|
||||
}
|
||||
@@ -25,24 +25,25 @@ import (
|
||||
"github.com/liquidata-inc/dolt/go/cmd/dolt/errhand"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/liquidata-inc/dolt/go/store/hash"
|
||||
"github.com/liquidata-inc/dolt/go/store/types"
|
||||
)
|
||||
|
||||
// DoltTable implements the sql.Table interface and gives access to dolt table rows and schema.
|
||||
type DoltTable struct {
|
||||
name string
|
||||
table *doltdb.Table
|
||||
sch schema.Schema
|
||||
sqlSch sql.Schema
|
||||
db *Database
|
||||
ed *tableEditor
|
||||
}
|
||||
|
||||
var _ sql.Table = (*DoltTable)(nil)
|
||||
var _ sql.UpdatableTable = (*DoltTable)(nil)
|
||||
var _ sql.DeletableTable = (*DoltTable)(nil)
|
||||
var _ sql.InsertableTable = (*DoltTable)(nil)
|
||||
var _ sql.ReplaceableTable = (*DoltTable)(nil)
|
||||
|
||||
// DoltTable implements the sql.Table interface and gives access to dolt table rows and schema.
|
||||
type DoltTable struct {
|
||||
name string
|
||||
table *doltdb.Table
|
||||
sch schema.Schema
|
||||
db *Database
|
||||
}
|
||||
|
||||
// Implements sql.IndexableTable
|
||||
func (t *DoltTable) WithIndexLookup(lookup sql.IndexLookup) sql.Table {
|
||||
dil, ok := lookup.(*doltIndexLookup)
|
||||
@@ -78,20 +79,21 @@ func (t *DoltTable) String() string {
|
||||
|
||||
// Schema returns the schema for this table.
|
||||
func (t *DoltTable) Schema() sql.Schema {
|
||||
// TODO: fix panics
|
||||
sch, err := t.table.GetSchema(context.TODO())
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// TODO: fix panics
|
||||
sqlSch, err := doltSchemaToSqlSchema(t.name, sch)
|
||||
return t.sqlSchema()
|
||||
}
|
||||
|
||||
func (t *DoltTable) sqlSchema() sql.Schema {
|
||||
if t.sqlSch != nil {
|
||||
return t.sqlSch
|
||||
}
|
||||
|
||||
// TODO: fix panics
|
||||
sqlSch, err := doltSchemaToSqlSchema(t.name, t.sch)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
t.sqlSch = sqlSch
|
||||
return sqlSch
|
||||
}
|
||||
|
||||
@@ -106,206 +108,44 @@ func (t *DoltTable) PartitionRows(ctx *sql.Context, _ sql.Partition) (sql.RowIte
|
||||
return newRowIterator(t, ctx)
|
||||
}
|
||||
|
||||
type tableEditor struct {
|
||||
t *DoltTable
|
||||
ed *types.MapEditor
|
||||
addedKeys map[hash.Hash]bool
|
||||
deletedKeys map[hash.Hash]bool
|
||||
}
|
||||
|
||||
func newTableEditor(t *DoltTable) *tableEditor {
|
||||
return &tableEditor{
|
||||
t: t,
|
||||
addedKeys: make(map[hash.Hash]bool),
|
||||
deletedKeys: make(map[hash.Hash]bool),
|
||||
}
|
||||
}
|
||||
|
||||
var _ sql.RowReplacer = (*tableEditor)(nil)
|
||||
var _ sql.RowUpdater = (*tableUpdater)(nil)
|
||||
var _ sql.RowInserter = (*tableEditor)(nil)
|
||||
var _ sql.RowDeleter = (*tableEditor)(nil)
|
||||
|
||||
func (te *tableEditor) Insert(ctx *sql.Context, sqlRow sql.Row) error {
|
||||
dRow, err := SqlRowToDoltRow(te.t.table.Format(), sqlRow, te.t.sch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key, err := dRow.NomsMapKey(te.t.sch).Value(ctx)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("failed to get row key").AddCause(err).Build()
|
||||
}
|
||||
_, rowExists, err := te.t.table.GetRow(ctx, key.(types.Tuple), te.t.sch)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("failed to read table").AddCause(err).Build()
|
||||
}
|
||||
|
||||
hash, err := key.Hash(dRow.Format())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if (rowExists && !te.deletedKeys[hash]) || te.addedKeys[hash] {
|
||||
return errors.New("duplicate primary key given")
|
||||
}
|
||||
te.addedKeys[hash] = true
|
||||
|
||||
if te.ed == nil {
|
||||
te.ed, err = te.t.newMapEditor(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
te.ed = te.ed.Set(key, dRow.NomsMapValue(te.t.sch))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *DoltTable) newMapEditor(ctx context.Context) (*types.MapEditor, error) {
|
||||
typesMap, err := t.table.GetRowData(ctx)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("failed to get row data.").AddCause(err).Build()
|
||||
}
|
||||
|
||||
return typesMap.Edit(), nil
|
||||
}
|
||||
|
||||
func (te *tableEditor) Delete(ctx *sql.Context, sqlRow sql.Row) error {
|
||||
dRow, err := SqlRowToDoltRow(te.t.table.Format(), sqlRow, te.t.sch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key, err := dRow.NomsMapKey(te.t.sch).Value(ctx)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("failed to get row key").AddCause(err).Build()
|
||||
}
|
||||
hash, err := key.Hash(dRow.Format())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
delete(te.addedKeys, hash)
|
||||
te.deletedKeys[hash] = true
|
||||
|
||||
if te.ed == nil {
|
||||
te.ed, err = te.t.newMapEditor(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
te.ed = te.ed.Remove(key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// tableUpdater wraps tableEditor to override the close method, necessary to enforce primary key constraints when
|
||||
// updates to primary key columns are applied in an arbitrary order
|
||||
type tableUpdater struct {
|
||||
t *DoltTable
|
||||
ed *types.MapEditor
|
||||
addedKeys map[hash.Hash]types.LesserValuable
|
||||
removedKeys map[hash.Hash]types.LesserValuable
|
||||
}
|
||||
|
||||
func (tu *tableUpdater) Close(ctx *sql.Context) error {
|
||||
// For all added keys, check for and report a collision
|
||||
for hash, addedKey := range tu.addedKeys {
|
||||
if _, ok := tu.removedKeys[hash]; !ok {
|
||||
_, rowExists, err := tu.t.table.GetRow(ctx, addedKey.(types.Tuple), tu.t.sch)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("failed to read table").AddCause(err).Build()
|
||||
}
|
||||
if rowExists {
|
||||
return fmt.Errorf("primary key collision: (%v)", addedKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
// For all removed keys, remove the map entries that weren't added elsewhere by other updates
|
||||
for hash, removedKey := range tu.removedKeys {
|
||||
if _, ok := tu.addedKeys[hash]; !ok {
|
||||
tu.ed.Remove(removedKey)
|
||||
}
|
||||
}
|
||||
|
||||
if tu.ed != nil {
|
||||
return tu.t.updateTable(ctx, tu.ed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tu *tableUpdater) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Row) error {
|
||||
dOldRow, err := SqlRowToDoltRow(tu.t.table.Format(), oldRow, tu.t.sch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dNewRow, err := SqlRowToDoltRow(tu.t.table.Format(), newRow, tu.t.sch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the PK is changed then we have to delete the old row first
|
||||
dOldKey := dOldRow.NomsMapKey(tu.t.sch)
|
||||
dOldKeyVal, err := dOldKey.Value(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dNewKey := dNewRow.NomsMapKey(tu.t.sch)
|
||||
dNewKeyVal, err := dNewKey.Value(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if tu.ed == nil {
|
||||
tu.ed, err = tu.t.newMapEditor(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if !dOldKeyVal.Equals(dNewKeyVal) {
|
||||
oldHash, err := dOldKeyVal.Hash(dOldRow.Format())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newHash, err := dNewKeyVal.Hash(dNewRow.Format())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tu.addedKeys[newHash] = dNewKeyVal
|
||||
tu.removedKeys[oldHash] = dOldKey
|
||||
}
|
||||
|
||||
tu.ed.Set(dNewKey, dNewRow.NomsMapValue(tu.t.sch))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (te *tableEditor) Close(ctx *sql.Context) error {
|
||||
if te.ed != nil {
|
||||
return te.t.updateTable(ctx, te.ed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Inserter implements sql.InsertableTable
|
||||
func (t *DoltTable) Inserter(ctx *sql.Context) sql.RowInserter {
|
||||
return newTableEditor(t)
|
||||
return t.getTableEditor()
|
||||
}
|
||||
|
||||
func (t *DoltTable) Deleter(*sql.Context) sql.RowDeleter {
|
||||
return newTableEditor(t)
|
||||
}
|
||||
|
||||
func (t *DoltTable) Replacer(ctx *sql.Context) sql.RowReplacer {
|
||||
return newTableEditor(t)
|
||||
}
|
||||
|
||||
func (t *DoltTable) Updater(ctx *sql.Context) sql.RowUpdater {
|
||||
return &tableUpdater{
|
||||
t: t,
|
||||
addedKeys: make(map[hash.Hash]types.LesserValuable),
|
||||
removedKeys: make(map[hash.Hash]types.LesserValuable),
|
||||
func (t *DoltTable) getTableEditor() *tableEditor {
|
||||
if t.db.batchMode == batched {
|
||||
if t.ed != nil {
|
||||
return t.ed
|
||||
}
|
||||
t.ed = newTableEditor(t)
|
||||
return t.ed
|
||||
}
|
||||
return newTableEditor(t)
|
||||
}
|
||||
|
||||
func (t *DoltTable) flushBatchedEdits(ctx context.Context) error {
|
||||
if t.ed != nil {
|
||||
err := t.ed.flush(ctx)
|
||||
t.ed = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deleter implements sql.DeletableTable
|
||||
func (t *DoltTable) Deleter(*sql.Context) sql.RowDeleter {
|
||||
return t.getTableEditor()
|
||||
}
|
||||
|
||||
// Replacer implements sql.ReplaceableTable
|
||||
func (t *DoltTable) Replacer(ctx *sql.Context) sql.RowReplacer {
|
||||
return t.getTableEditor()
|
||||
}
|
||||
|
||||
// Updater implements sql.UpdatableTable
|
||||
func (t *DoltTable) Updater(ctx *sql.Context) sql.RowUpdater {
|
||||
return t.getTableEditor()
|
||||
}
|
||||
|
||||
// doltTablePartitionIter, an object that knows how to return the single partition exactly once.
|
||||
@@ -342,7 +182,7 @@ func (p doltTablePartition) Key() []byte {
|
||||
return []byte(partitionName)
|
||||
}
|
||||
|
||||
func (t *DoltTable) updateTable(ctx *sql.Context, mapEditor *types.MapEditor) error {
|
||||
func (t *DoltTable) updateTable(ctx context.Context, mapEditor *types.MapEditor) error {
|
||||
updated, err := mapEditor.Map(ctx)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("failed to modify table").AddCause(err).Build()
|
||||
|
||||
Reference in New Issue
Block a user