mirror of
https://github.com/dolthub/dolt.git
synced 2026-01-06 08:50:04 -06:00
Fully working batch for a batch size of 1
Signed-off-by: Zach Musgrave <zach@liquidata.co>
This commit is contained in:
@@ -2,7 +2,6 @@ package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
|
||||
@@ -11,11 +10,11 @@ import (
|
||||
"github.com/liquidata-inc/dolt/go/store/types"
|
||||
)
|
||||
|
||||
var ErrKeyExists = errors.New("key already exists")
|
||||
|
||||
// SqlBatcher knows how to efficiently batch insert / update statements, e.g. when doing a SQL import. It does this by
|
||||
// using a single MapEditor per table that isn't persisted until Commit is called.
|
||||
type SqlBatcher struct {
|
||||
// The database we are editing
|
||||
db *doltdb.DoltDB
|
||||
// The root value we are editing
|
||||
root *doltdb.RootValue
|
||||
// The set of tables under edit
|
||||
@@ -31,10 +30,12 @@ type SqlBatcher struct {
|
||||
}
|
||||
|
||||
// NewSqlBatcher returns a new SqlBatcher for the given database and root value.
|
||||
func NewSqlBatcher(root *doltdb.RootValue) *SqlBatcher {
|
||||
func NewSqlBatcher(db *doltdb.DoltDB, root *doltdb.RootValue) *SqlBatcher {
|
||||
return &SqlBatcher{
|
||||
db: db,
|
||||
root: root,
|
||||
tables: make(map[string]*doltdb.Table),
|
||||
schemas: make(map[string]schema.Schema),
|
||||
rowData: make(map[string]types.Map),
|
||||
editors: make(map[string]*types.MapEditor),
|
||||
hashes: make(map[string]map[hash.Hash]bool),
|
||||
@@ -44,9 +45,6 @@ func NewSqlBatcher(root *doltdb.RootValue) *SqlBatcher {
|
||||
type InsertOptions struct {
|
||||
// Whether to silently replace any existing rows with the same primary key as rows inserted
|
||||
Replace bool
|
||||
// Whether to ignore primary key duplication. Unlike Replace, inserts for existing keys are simply ignored, not
|
||||
// updated.
|
||||
IgnoreExisting bool
|
||||
}
|
||||
|
||||
type BatchInsertResult struct {
|
||||
@@ -77,15 +75,8 @@ func (b *SqlBatcher) Insert(ctx context.Context, tableName string, r row.Row, op
|
||||
rowAlreadyTouched := hashes[key.Hash(b.root.VRW().Format())]
|
||||
|
||||
if rowExists || rowAlreadyTouched {
|
||||
if !opt.Replace && !opt.IgnoreExisting {
|
||||
return nil, ErrKeyExists
|
||||
}
|
||||
|
||||
// If Replace and IgnoreExisting are both set, favor Replace semantics
|
||||
if opt.Replace {
|
||||
// do nothing, continue to editing
|
||||
} else if opt.IgnoreExisting {
|
||||
return &BatchInsertResult{}, nil
|
||||
if !opt.Replace {
|
||||
return nil, fmt.Errorf("Duplicate primary key: '%v'", getPrimaryKeyString(r, sch))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -168,6 +159,17 @@ func (b *SqlBatcher) Update(r row.Row) {
|
||||
|
||||
}
|
||||
|
||||
func (b *SqlBatcher) Commit() (*doltdb.RootValue, error) {
|
||||
return nil, nil
|
||||
// Commit writes a new root value for every table under edit and returns the new root value. Tables are written in an
|
||||
// arbitrary order.
|
||||
func (b *SqlBatcher) Commit(ctx context.Context) (*doltdb.RootValue, error) {
|
||||
root := b.root
|
||||
|
||||
for tableName, ed := range b.editors {
|
||||
newMap := ed.Map(ctx)
|
||||
table := b.tables[tableName]
|
||||
table = table.UpdateRows(ctx, newMap)
|
||||
root = root.PutTable(ctx, b.db, tableName, table)
|
||||
}
|
||||
|
||||
return root, nil
|
||||
}
|
||||
@@ -25,7 +25,6 @@ import (
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/liquidata-inc/dolt/go/store/hash"
|
||||
"github.com/liquidata-inc/dolt/go/store/types"
|
||||
)
|
||||
|
||||
@@ -103,7 +102,7 @@ func ExecuteBatchInsert(
|
||||
|
||||
// Perform the insert
|
||||
var result InsertResult
|
||||
opt := InsertOptions{Replace: replace, IgnoreExisting: ignore}
|
||||
opt := InsertOptions{replace}
|
||||
for _, r := range rows {
|
||||
if !row.IsValid(r, tableSch) {
|
||||
if ignore {
|
||||
@@ -145,100 +144,23 @@ func ExecuteInsert(
|
||||
query string,
|
||||
) (*InsertResult, error) {
|
||||
|
||||
tableName := s.Table.Name.String()
|
||||
if !root.HasTable(ctx, tableName) {
|
||||
return errInsert("Unknown table %v", tableName)
|
||||
}
|
||||
table, _ := root.GetTable(ctx, tableName)
|
||||
tableSch := table.GetSchema(ctx)
|
||||
|
||||
// Parser supports overwrite on insert with both the replace keyword (from MySQL) as well as the ignore keyword
|
||||
replace := s.Action == sqlparser.ReplaceStr
|
||||
ignore := s.Ignore != ""
|
||||
|
||||
// Get the list of columns to insert into. We support both naked inserts (no column list specified) as well as
|
||||
// explicit column lists.
|
||||
var cols []schema.Column
|
||||
if s.Columns == nil || len(s.Columns) == 0 {
|
||||
cols = tableSch.GetAllCols().GetColumns()
|
||||
} else {
|
||||
cols = make([]schema.Column, len(s.Columns))
|
||||
for i, colName := range s.Columns {
|
||||
for _, c := range cols {
|
||||
if c.Name == colName.String() {
|
||||
return errInsert("Repeated column: '%v'", c.Name)
|
||||
}
|
||||
}
|
||||
|
||||
col, ok := tableSch.GetAllCols().GetByName(colName.String())
|
||||
if !ok {
|
||||
return errInsert(UnknownColumnErrFmt, colName)
|
||||
}
|
||||
cols[i] = col
|
||||
}
|
||||
batcher := NewSqlBatcher(db, root)
|
||||
insertResult, err := ExecuteBatchInsert(ctx, db, root, s, batcher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rows []row.Row // your boat
|
||||
|
||||
switch queryRows := s.Rows.(type) {
|
||||
case sqlparser.Values:
|
||||
var err error
|
||||
rows, err = prepareInsertVals(root.VRW().Format(), cols, &queryRows, tableSch)
|
||||
if err != nil {
|
||||
return &InsertResult{}, err
|
||||
}
|
||||
case *sqlparser.Select:
|
||||
return errInsert("Insert as select not supported")
|
||||
case *sqlparser.ParenSelect:
|
||||
return errInsert("Parenthesized select expressions in insert not supported")
|
||||
case *sqlparser.Union:
|
||||
return errInsert("Union not supported")
|
||||
default:
|
||||
return errInsert("Unrecognized type for insertRows: %v", queryRows)
|
||||
newRoot, err := batcher.Commit(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Perform the insert
|
||||
rowData := table.GetRowData(ctx)
|
||||
me := rowData.Edit()
|
||||
var result InsertResult
|
||||
|
||||
insertedPKHashes := make(map[hash.Hash]struct{})
|
||||
for _, r := range rows {
|
||||
if !row.IsValid(r, tableSch) {
|
||||
if ignore {
|
||||
result.NumErrorsIgnored += 1
|
||||
continue
|
||||
} else {
|
||||
col, constraint := row.GetInvalidConstraint(r, tableSch)
|
||||
return nil, errFmt(ConstraintFailedFmt, col.Name, constraint)
|
||||
}
|
||||
}
|
||||
|
||||
key := r.NomsMapKey(tableSch).Value(ctx)
|
||||
|
||||
rowExists := rowData.Get(ctx, key) != nil
|
||||
_, rowInserted := insertedPKHashes[key.Hash(root.VRW().Format())]
|
||||
|
||||
if rowExists || rowInserted {
|
||||
if replace {
|
||||
result.NumRowsUpdated += 1
|
||||
} else if ignore {
|
||||
result.NumErrorsIgnored += 1
|
||||
continue
|
||||
} else {
|
||||
return errInsert("Duplicate primary key: '%v'", getPrimaryKeyString(r, tableSch))
|
||||
}
|
||||
}
|
||||
me.Set(key, r.NomsMapValue(tableSch))
|
||||
|
||||
insertedPKHashes[key.Hash(root.VRW().Format())] = struct{}{}
|
||||
}
|
||||
newMap := me.Map(ctx)
|
||||
table = table.UpdateRows(ctx, newMap)
|
||||
|
||||
result.NumRowsInserted = int(newMap.Len() - rowData.Len())
|
||||
result.Root = root.PutTable(ctx, db, tableName, table)
|
||||
return &result, nil
|
||||
return &InsertResult{
|
||||
Root: newRoot,
|
||||
NumRowsInserted: insertResult.NumRowsInserted,
|
||||
NumRowsUpdated: insertResult.NumRowsUpdated,
|
||||
NumErrorsIgnored: insertResult.NumErrorsIgnored,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Returns a primary key summary of the given row
|
||||
|
||||
@@ -334,7 +334,7 @@ func TestExecuteInsert(t *testing.T) {
|
||||
|
||||
for _, expectedRow := range tt.insertedValues {
|
||||
foundRow, ok := table.GetRow(ctx, expectedRow.NomsMapKey(PeopleTestSchema).Value(ctx).(types.Tuple), PeopleTestSchema)
|
||||
assert.True(t, ok, "Row not found: %v", expectedRow)
|
||||
require.True(t, ok, "Row not found: %v", expectedRow)
|
||||
eq, diff := rowsEqual(expectedRow, foundRow)
|
||||
assert.True(t, eq, "Rows not equals, found diff %v", diff)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user