Merge pull request #3032 from dolthub/andy/new-auto-increment

go/libraries/doltcore/sqle: Refactored Auto Increment
This commit is contained in:
AndyA
2022-03-25 14:29:32 -07:00
committed by GitHub
23 changed files with 434 additions and 474 deletions

View File

@@ -68,7 +68,7 @@ require (
)
require (
github.com/dolthub/go-mysql-server v0.11.1-0.20220324183628-b0a3bc9c2c2f
github.com/dolthub/go-mysql-server v0.11.1-0.20220325203039-101310d04210
github.com/google/flatbuffers v2.0.5+incompatible
github.com/gosuri/uilive v0.0.4
github.com/kch42/buzhash v0.0.0-20160816060738-9bdec3dec7c6

View File

@@ -170,8 +170,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm
github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no=
github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
github.com/dolthub/go-mysql-server v0.11.1-0.20220324183628-b0a3bc9c2c2f h1:F1hhtWcel9an0/Ohbdg0gw6gSlgwWBxlnrSl7Jyi/2M=
github.com/dolthub/go-mysql-server v0.11.1-0.20220324183628-b0a3bc9c2c2f/go.mod h1:1Sq4rszjP6AW7AJaF9xfycWexkNKIwkkOuYnoS5XcPg=
github.com/dolthub/go-mysql-server v0.11.1-0.20220325203039-101310d04210 h1:sZiIQR3mhQkPNb/9TbEg9W+nF4bcHer2ubPdl3iyKjM=
github.com/dolthub/go-mysql-server v0.11.1-0.20220325203039-101310d04210/go.mod h1:1Sq4rszjP6AW7AJaF9xfycWexkNKIwkkOuYnoS5XcPg=
github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371 h1:oyPHJlzumKta1vnOQqUnfdz+pk3EmnHS3Nd0cCT0I2g=
github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371/go.mod h1:dhGBqcCEfK5kuFmeO5+WOx3hqc1k3M29c1oS/R7N4ms=
github.com/dolthub/jsonpath v0.0.0-20210609232853-d49537a30474 h1:xTrR+l5l+1Lfq0NvhiEsctylXinUMFhhsqaEcl414p8=

View File

@@ -87,9 +87,9 @@ type Table interface {
SetConstraintViolations(ctx context.Context, violations types.Map) (Table, error)
// GetAutoIncrement returns the AUTO_INCREMENT sequence value for this table.
GetAutoIncrement(ctx context.Context) (types.Value, error)
GetAutoIncrement(ctx context.Context) (uint64, error)
// SetAutoIncrement sets the AUTO_INCREMENT sequence value for this table.
SetAutoIncrement(ctx context.Context, val types.Value) (Table, error)
SetAutoIncrement(ctx context.Context, val uint64) (Table, error)
}
type nomsTable struct {
@@ -140,7 +140,7 @@ func NewTable(ctx context.Context, vrw types.ValueReadWriter, sch schema.Schema,
indexesKey: indexesRef,
}
if autoIncVal != nil {
if schema.HasAutoIncrement(sch) && autoIncVal != nil {
sd[autoIncrementKey] = autoIncVal
}
@@ -461,53 +461,36 @@ func (t nomsTable) SetConstraintViolations(ctx context.Context, violationsMap ty
}
// GetAutoIncrement implements Table.
func (t nomsTable) GetAutoIncrement(ctx context.Context) (types.Value, error) {
func (t nomsTable) GetAutoIncrement(ctx context.Context) (uint64, error) {
val, ok, err := t.tableStruct.MaybeGet(autoIncrementKey)
if err != nil {
return nil, err
return 0, err
}
if ok {
return val, nil
if !ok {
return 1, nil
}
sch, err := t.GetSchema(ctx)
if err != nil {
return nil, err
}
kind := types.UnknownKind
_ = sch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
if col.AutoIncrement {
kind = col.Kind
stop = true
}
return
})
switch kind {
case types.IntKind:
return types.Int(1), nil
case types.UintKind:
return types.Uint(1), nil
case types.FloatKind:
return types.Float(1), nil
// older versions might have serialized auto-increment
// value as types.Int or types.Float.
switch t := val.(type) {
case types.Int:
return uint64(t), nil
case types.Uint:
return uint64(t), nil
case types.Float:
return uint64(t), nil
default:
return nil, ErrUnknownAutoIncrementValue
return 0, ErrUnknownAutoIncrementValue
}
}
// SetAutoIncrement implements Table.
func (t nomsTable) SetAutoIncrement(ctx context.Context, val types.Value) (Table, error) {
switch val.(type) {
case types.Int, types.Uint, types.Float:
st, err := t.tableStruct.Set(autoIncrementKey, val)
if err != nil {
return nil, err
}
return nomsTable{t.vrw, st}, nil
default:
return nil, fmt.Errorf("cannot set auto increment to non-numeric value")
func (t nomsTable) SetAutoIncrement(ctx context.Context, val uint64) (Table, error) {
st, err := t.tableStruct.Set(autoIncrementKey, types.Uint(val))
if err != nil {
return nil, err
}
return nomsTable{t.vrw, st}, nil
}
func refFromNomsValue(ctx context.Context, vrw types.ValueReadWriter, val types.Value) (types.Ref, error) {

View File

@@ -455,12 +455,12 @@ func (t *Table) VerifyIndexRowData(ctx context.Context, indexName string) error
}
// GetAutoIncrementValue returns the current AUTO_INCREMENT value for this table.
func (t *Table) GetAutoIncrementValue(ctx context.Context) (types.Value, error) {
func (t *Table) GetAutoIncrementValue(ctx context.Context) (uint64, error) {
return t.table.GetAutoIncrement(ctx)
}
// SetAutoIncrementValue sets the current AUTO_INCREMENT value for this table.
func (t *Table) SetAutoIncrementValue(ctx context.Context, val types.Value) (*Table, error) {
func (t *Table) SetAutoIncrementValue(ctx context.Context, val uint64) (*Table, error) {
table, err := t.table.SetAutoIncrement(ctx, val)
if err != nil {
return nil, err

View File

@@ -786,14 +786,7 @@ func mergeAutoIncrementValues(ctx context.Context, tbl, otherTbl, resultTbl *dol
if err != nil {
return nil, err
}
auto := false
_ = sch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
if col.AutoIncrement {
auto, stop = true, true
}
return
})
if !auto {
if !schema.HasAutoIncrement(sch) {
return resultTbl, nil
}
@@ -805,14 +798,10 @@ func mergeAutoIncrementValues(ctx context.Context, tbl, otherTbl, resultTbl *dol
if err != nil {
return nil, err
}
less, err := autoVal.Less(tbl.Format(), mergeAutoVal)
if err != nil {
return nil, err
}
if less {
if autoVal < mergeAutoVal {
autoVal = mergeAutoVal
}
return resultTbl.SetAutoIncrementValue(nil, autoVal)
return resultTbl.SetAutoIncrementValue(ctx, autoVal)
}
func MergeCommits(ctx context.Context, commit, mergeCommit *doltdb.Commit, opts editor.Options) (*doltdb.RootValue, map[string]*MergeStats, error) {

View File

@@ -149,7 +149,7 @@ func updateTableWithModifiedColumn(ctx context.Context, tbl *doltdb.Table, oldSc
return nil, err
}
var autoVal types.Value
var autoVal uint64
if schema.HasAutoIncrement(newSch) && schema.HasAutoIncrement(oldSch) {
autoVal, err = tbl.GetAutoIncrementValue(ctx)
if err != nil {
@@ -157,7 +157,7 @@ func updateTableWithModifiedColumn(ctx context.Context, tbl *doltdb.Table, oldSc
}
}
updatedTable, err := doltdb.NewNomsTable(ctx, vrw, newSch, rowData, indexData, autoVal)
updatedTable, err := doltdb.NewNomsTable(ctx, vrw, newSch, rowData, indexData, types.Uint(autoVal))
if err != nil {
return nil, err
}

View File

@@ -82,15 +82,11 @@ func DbsAsDSQLDBs(dbs []sql.Database) []SqlDatabase {
// Database implements sql.Database for a dolt DB.
type Database struct {
name string
ddb *doltdb.DoltDB
rsr env.RepoStateReader
rsw env.RepoStateWriter
drw env.DocsReadWriter
// todo: needs a major refactor to
// correctly handle persisted sequences
// that must be coordinated across txs
name string
ddb *doltdb.DoltDB
rsr env.RepoStateReader
rsw env.RepoStateWriter
drw env.DocsReadWriter
gs globalstate.GlobalState
editOpts editor.Options
}
@@ -167,6 +163,7 @@ var _ sql.TableRenamer = Database{}
var _ sql.TriggerDatabase = Database{}
var _ sql.StoredProcedureDatabase = Database{}
var _ sql.TransactionDatabase = Database{}
var _ globalstate.StateProvider = Database{}
// NewDatabase returns a new dolt database to use in queries.
func NewDatabase(name string, dbData env.DbData, editOpts editor.Options) Database {
@@ -259,6 +256,10 @@ func (db Database) DbData() env.DbData {
}
}
func (db Database) GetGlobalState() globalstate.GlobalState {
return db.gs
}
// GetTableInsensitive is used when resolving tables in queries. It returns a best-effort case-insensitive match for
// the table name given.
func (db Database) GetTableInsensitive(ctx *sql.Context, tblName string) (sql.Table, bool, error) {
@@ -585,6 +586,18 @@ func (db Database) GetRoot(ctx *sql.Context) (*doltdb.RootValue, error) {
return dbState.GetRoots().Working, nil
}
func (db Database) GetWorkingSet(ctx *sql.Context) (*doltdb.WorkingSet, error) {
sess := dsess.DSessFromSess(ctx.Session)
dbState, ok, err := sess.LookupDbState(ctx, db.Name())
if err != nil {
return nil, err
}
if !ok {
return nil, fmt.Errorf("no root value found in session")
}
return dbState.WorkingSet, nil
}
// SetRoot should typically be called on the Session, which is where this state lives. But it's available here as a
// convenience.
func (db Database) SetRoot(ctx *sql.Context, newRoot *doltdb.RootValue) error {
@@ -660,7 +673,10 @@ func (db Database) dropTableFromAiTracker(ctx *sql.Context, tableName string) er
return err
}
ait := db.gs.GetAutoIncrementTracker(ws.Ref())
ait, err := db.gs.GetAutoIncrementTracker(ctx, ws)
if err != nil {
return err
}
ait.DropTable(tableName)
return nil
@@ -681,10 +697,11 @@ func (db Database) CreateTable(ctx *sql.Context, tableName string, sch sql.Prima
// Unlike the exported version CreateTable, createSqlTable doesn't enforce any table name checks.
func (db Database) createSqlTable(ctx *sql.Context, tableName string, sch sql.PrimaryKeySchema) error {
root, err := db.GetRoot(ctx)
ws, err := db.GetWorkingSet(ctx)
if err != nil {
return err
}
root := ws.WorkingRoot()
if exists, err := root.HasTable(ctx, tableName); err != nil {
return err
@@ -707,6 +724,14 @@ func (db Database) createSqlTable(ctx *sql.Context, tableName string, sch sql.Pr
return schema.ErrUsingSpatialKey.New(tableName)
}
if schema.HasAutoIncrement(doltSch) {
ait, err := db.gs.GetAutoIncrementTracker(ctx, ws)
if err != nil {
return err
}
ait.AddNewTable(tableName)
}
return db.createDoltTable(ctx, tableName, root, doltSch)
}

View File

@@ -485,7 +485,6 @@ func dbRevisionForCommit(ctx context.Context, srcDb Database, revSpec string) (R
rsw: srcDb.DbData().Rsw,
rsr: srcDb.DbData().Rsr,
drw: srcDb.DbData().Drw,
gs: nil,
editOpts: srcDb.editOpts,
}}
init := dsess.InitialDbState{

View File

@@ -17,11 +17,11 @@ package dsess
import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
)
type InitialDbState struct {
@@ -48,6 +48,7 @@ type DatabaseSessionState struct {
WorkingSet *doltdb.WorkingSet
dbData env.DbData
WriteSession writer.WriteSession
globalState globalstate.GlobalState
readOnly bool
dirty bool
readReplica *env.Remote

View File

@@ -27,6 +27,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/utils/config"
@@ -700,7 +701,11 @@ func (sess *Session) SwitchWorkingSet(
// make a fresh WriteSession, discard existing WriteSession
opts := sessionState.WriteSession.GetOptions()
nbf := ws.WorkingRoot().VRW().Format()
sessionState.WriteSession = writer.NewWriteSession(nbf, ws, opts)
tracker, err := sessionState.globalState.GetAutoIncrementTracker(ctx, ws)
if err != nil {
return err
}
sessionState.WriteSession = writer.NewWriteSession(nbf, ws, tracker, opts)
// After switching to a new working set, we are by definition clean
sessionState.dirty = false
@@ -830,17 +835,26 @@ func (sess *Session) AddDB(ctx *sql.Context, dbState InitialDbState) error {
nbf := sessionState.dbData.Ddb.Format()
editOpts := db.(interface{ EditOptions() editor.Options }).EditOptions()
stateProvider, ok := db.(globalstate.StateProvider)
if !ok {
return fmt.Errorf("database does not contain global state store")
}
sessionState.globalState = stateProvider.GetGlobalState()
// WorkingSet is nil in the case of a read only, detached head DB
if dbState.Err != nil {
sessionState.Err = dbState.Err
} else if dbState.WorkingSet != nil {
sessionState.WorkingSet = dbState.WorkingSet
sessionState.WriteSession = writer.NewWriteSession(nbf, sessionState.WorkingSet, editOpts)
err := sess.SetWorkingSet(ctx, db.Name(), dbState.WorkingSet)
tracker, err := sessionState.globalState.GetAutoIncrementTracker(ctx, sessionState.WorkingSet)
if err != nil {
return err
}
sessionState.WriteSession = writer.NewWriteSession(nbf, sessionState.WorkingSet, tracker, editOpts)
if err = sess.SetWorkingSet(ctx, db.Name(), dbState.WorkingSet); err != nil {
return err
}
} else {
headRoot, err := dbState.HeadCommit.GetRootValue()

View File

@@ -15,146 +15,118 @@
package globalstate
import (
"fmt"
"context"
"math"
"sync"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
)
type AutoIncrementTracker interface {
// Next returns the next auto increment value to be used by a table. If a table is not initialized in the counter
// it will used the value stored in disk.
Next(tableName string, insertVal interface{}, diskVal interface{}) (interface{}, error)
// Reset resets the auto increment tracker value for a table. Typically used in truncate statements.
Reset(tableName string, val interface{})
// DropTable removes a table from the autoincrement tracker.
DropTable(tableName string)
}
// AutoIncrementTracker is a global map that tracks which auto increment keys have been given for each table. At runtime
// it hands out the current key.
func NewAutoIncrementTracker() AutoIncrementTracker {
return &autoIncrementTracker{
valuePerTable: make(map[string]interface{}),
// CoerceAutoIncrementValue converts |val| into an AUTO_INCREMENT sequence value
func CoerceAutoIncrementValue(val interface{}) (uint64, error) {
switch typ := val.(type) {
case float32:
val = math.Round(float64(typ))
case float64:
val = math.Round(typ)
}
var err error
val, err = sql.Uint64.Convert(val)
if err != nil {
return 0, err
}
if val == nil || val == uint64(0) {
return 0, nil
}
return val.(uint64), nil
}
type autoIncrementTracker struct {
valuePerTable map[string]interface{}
mu sync.Mutex
func NewAutoIncrementTracker(ctx context.Context, ws *doltdb.WorkingSet) (ait AutoIncrementTracker, err error) {
ait = AutoIncrementTracker{
wsRef: ws.Ref(),
sequences: make(map[string]uint64),
mu: &sync.Mutex{},
}
// collect auto increment values
err = ws.WorkingRoot().IterTables(ctx, func(name string, table *doltdb.Table, sch schema.Schema) (bool, error) {
ok := schema.HasAutoIncrement(sch)
if !ok {
return false, nil
}
seq, err := table.GetAutoIncrementValue(ctx)
if err != nil {
return true, err
}
ait.sequences[name] = seq
return false, nil
})
return ait, err
}
var _ AutoIncrementTracker = (*autoIncrementTracker)(nil)
type AutoIncrementTracker struct {
wsRef ref.WorkingSetRef
sequences map[string]uint64
mu *sync.Mutex
}
func (a *autoIncrementTracker) Next(tableName string, insertVal interface{}, diskVal interface{}) (interface{}, error) {
// Current returns the current AUTO_INCREMENT value for |tableName|.
func (a AutoIncrementTracker) Current(tableName string) uint64 {
a.mu.Lock()
defer a.mu.Unlock()
return a.sequences[tableName]
}
// Next returns the next AUTO_INCREMENT value for |tableName|, considering the provided |insertVal|.
func (a AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, error) {
a.mu.Lock()
defer a.mu.Unlock()
diskVal = valOrZero(diskVal)
// Case 0: Just use the value passed in.
potential, ok := a.valuePerTable[tableName]
if !ok {
// Use the disk val if the table has not been initialized yet.
potential = diskVal
}
// Case 1: Disk Val is greater. This is useful for updating the tracker when a merge occurs.
// TODO: This is a bit of a hack. The correct solution is to plumb this tracker through the merge logic.
diskValGreater, err := geq(valOrZero(diskVal), valOrZero(a.valuePerTable[tableName]))
if err != nil {
return nil, err
}
if diskValGreater {
potential = diskVal
}
// Case 2: Overwrite anything if an insert val is passed.
if insertVal != nil {
potential = insertVal
}
// update the table only if val >= existing
isGeq, err := geq(valOrZero(potential), valOrZero(a.valuePerTable[tableName]))
given, err := CoerceAutoIncrementValue(insertVal)
if err != nil {
return 0, err
}
if isGeq {
val, err := convertIntTypeToUint(potential)
if err != nil {
return val, err
}
curr := a.sequences[tbl]
a.valuePerTable[tableName] = val + 1
if given == 0 {
// |given| is 0 or NULL
a.sequences[tbl]++
return curr, nil
}
return potential, nil
if given >= curr {
a.sequences[tbl] = given
a.sequences[tbl]++
return given, nil
}
// |given| < curr
return given, nil
}
func (a *autoIncrementTracker) Reset(tableName string, val interface{}) {
// Set sets the current AUTO_INCREMENT value for |tableName|.
func (a AutoIncrementTracker) Set(tableName string, val uint64) {
a.mu.Lock()
defer a.mu.Unlock()
a.valuePerTable[tableName] = val
a.sequences[tableName] = val
}
func (a *autoIncrementTracker) DropTable(tableName string) {
// AddNewTable adds |tablename| to the AutoIncrementTracker.
func (a AutoIncrementTracker) AddNewTable(tableName string) {
a.mu.Lock()
defer a.mu.Unlock()
delete(a.valuePerTable, tableName)
a.sequences[tableName] = uint64(1)
}
// Helper method that sets nil values to 0 for clarity purposes
func valOrZero(val interface{}) interface{} {
if val == nil {
return 0
}
return val
}
func geq(val1 interface{}, val2 interface{}) (bool, error) {
v1, err := convertIntTypeToUint(val1)
if err != nil {
return false, err
}
v2, err := convertIntTypeToUint(val2)
if err != nil {
return false, err
}
return v1 >= v2, nil
}
func convertIntTypeToUint(val interface{}) (uint64, error) {
switch t := val.(type) {
case int:
return uint64(t), nil
case int8:
return uint64(t), nil
case int16:
return uint64(t), nil
case int32:
return uint64(t), nil
case int64:
return uint64(t), nil
case uint:
return uint64(t), nil
case uint8:
return uint64(t), nil
case uint16:
return uint64(t), nil
case uint32:
return uint64(t), nil
case uint64:
return t, nil
case float32:
return uint64(t), nil
case float64:
return uint64(t), nil
default:
return 0, fmt.Errorf("error: auto increment is not a numeric type")
}
// DropTable drops |tablename| from the AutoIncrementTracker.
func (a AutoIncrementTracker) DropTable(tableName string) {
a.mu.Lock()
defer a.mu.Unlock()
delete(a.sequences, tableName)
}

View File

@@ -15,42 +15,58 @@
package globalstate
import (
"sync"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/assert"
)
func TestNextHasNoRepeats(t *testing.T) {
var allVals sync.Map
aiTracker := NewAutoIncrementTracker()
for i := 0; i < 100; i++ {
go func() {
for j := 0; j < 10; j++ {
nxt, err := aiTracker.Next("test", nil, 1)
require.NoError(t, err)
val, err := convertIntTypeToUint(nxt)
require.NoError(t, err)
current, ok := allVals.Load(val)
if !ok {
allVals.Store(val, 1)
} else {
asUint, _ := convertIntTypeToUint(current)
allVals.Store(val, asUint+1)
}
}
}()
func TestCoerceAutoIncrementValue(t *testing.T) {
tests := []struct {
val interface{}
exp uint64
err bool
}{
{
val: nil,
exp: uint64(0),
},
{
val: int32(0),
exp: uint64(0),
},
{
val: int32(1),
exp: uint64(1),
},
{
val: uint32(1),
exp: uint64(1),
},
{
val: float32(1),
exp: uint64(1),
},
{
val: float32(1.1),
exp: uint64(1),
},
{
val: float32(1.9),
exp: uint64(2),
},
}
// Make sure each key was called once
allVals.Range(func(key, value interface{}) bool {
asUint, _ := convertIntTypeToUint(value)
require.Equal(t, uint64(1), asUint)
return true
})
for _, test := range tests {
name := fmt.Sprintf("Coerce %v to %v", test.val, test.exp)
t.Run(name, func(t *testing.T) {
act, err := CoerceAutoIncrementValue(test.val)
if test.err {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, test.exp, act)
})
}
}

View File

@@ -15,36 +15,45 @@
package globalstate
import (
"context"
"sync"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
)
type GlobalState interface {
GetAutoIncrementTracker(wsref ref.WorkingSetRef) AutoIncrementTracker
type StateProvider interface {
GetGlobalState() GlobalState
}
func NewGlobalStateStore() GlobalState {
return &globalStateImpl{
return GlobalState{
trackerMap: make(map[ref.WorkingSetRef]AutoIncrementTracker),
mu: &sync.Mutex{},
}
}
type globalStateImpl struct {
type GlobalState struct {
trackerMap map[ref.WorkingSetRef]AutoIncrementTracker
mu sync.Mutex
mu *sync.Mutex
}
var _ GlobalState = (*globalStateImpl)(nil)
func (g *globalStateImpl) GetAutoIncrementTracker(wsref ref.WorkingSetRef) AutoIncrementTracker {
func (g GlobalState) GetAutoIncrementTracker(ctx context.Context, ws *doltdb.WorkingSet) (AutoIncrementTracker, error) {
g.mu.Lock()
defer g.mu.Unlock()
_, ok := g.trackerMap[wsref]
if !ok {
g.trackerMap[wsref] = NewAutoIncrementTracker()
ait, ok := g.trackerMap[ws.Ref()]
if ok {
return ait, nil
}
return g.trackerMap[wsref]
var err error
ait, err = NewAutoIncrementTracker(ctx, ws)
if err != nil {
return AutoIncrementTracker{}, err
}
g.trackerMap[ws.Ref()] = ait
return ait, nil
}

View File

@@ -72,7 +72,7 @@ func (e *StaticErrorEditor) Update(*sql.Context, sql.Row, sql.Row) error {
return e.err
}
func (e *StaticErrorEditor) SetAutoIncrementValue(*sql.Context, interface{}) error {
func (e *StaticErrorEditor) SetAutoIncrementValue(*sql.Context, uint64) error {
return e.err
}

View File

@@ -35,6 +35,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/schema/alterschema"
"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer"
@@ -196,12 +197,7 @@ func (t *DoltTable) GetAutoIncrementValue(ctx *sql.Context) (interface{}, error)
if err != nil {
return nil, err
}
val, err := table.GetAutoIncrementValue(ctx)
if err != nil {
return nil, err
}
return t.autoIncCol.TypeInfo.ConvertNomsValueToValue(val)
return table.GetAutoIncrementValue(ctx)
}
// Name returns the name of the table.
@@ -415,19 +411,13 @@ func (t *WritableDoltTable) getTableEditor(ctx *sql.Context) (ed writer.TableWri
}
}
ws, err := ds.WorkingSet(ctx, t.db.name)
if err != nil {
return nil, err
}
ait := t.db.gs.GetAutoIncrementTracker(ws.Ref())
state, _, err := ds.LookupDbState(ctx, t.db.name)
if err != nil {
return nil, err
}
setter := ds.SetRoot
ed, err = state.WriteSession.GetTableWriter(ctx, t.tableName, t.db.Name(), ait, setter, batched)
ed, err = state.WriteSession.GetTableWriter(ctx, t.tableName, t.db.Name(), setter, batched)
if err != nil {
return nil, err
@@ -542,22 +532,17 @@ func (t *WritableDoltTable) PeekNextAutoIncrementValue(ctx *sql.Context) (interf
}
// GetNextAutoIncrementValue implements sql.AutoIncrementTable
func (t *WritableDoltTable) GetNextAutoIncrementValue(ctx *sql.Context, potentialVal interface{}) (interface{}, error) {
func (t *WritableDoltTable) GetNextAutoIncrementValue(ctx *sql.Context, potentialVal interface{}) (uint64, error) {
if !t.autoIncCol.AutoIncrement {
return nil, sql.ErrNoAutoIncrementCol
return 0, sql.ErrNoAutoIncrementCol
}
ed, err := t.getTableEditor(ctx)
if err != nil {
return nil, err
return 0, err
}
tableVal, err := t.getTableAutoIncrementValue(ctx)
if err != nil {
return nil, err
}
return ed.NextAutoIncrementValue(potentialVal, tableVal)
return ed.GetNextAutoIncrementValue(ctx, potentialVal)
}
func (t *WritableDoltTable) getTableAutoIncrementValue(ctx *sql.Context) (interface{}, error) {
@@ -887,6 +872,18 @@ func (t *AlterableDoltTable) AddColumn(ctx *sql.Context, column *sql.Column, ord
return err
}
if column.AutoIncrement {
ws, err := t.db.GetWorkingSet(ctx)
if err != nil {
return err
}
ait, err := t.db.gs.GetAutoIncrementTracker(ctx, ws)
if err != nil {
return err
}
ait.AddNewTable(t.tableName)
}
newRoot, err := root.PutTable(ctx, t.tableName, updatedTable)
if err != nil {
return err
@@ -1042,10 +1039,11 @@ func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, c
return nil
}
root, err := t.getRoot(ctx)
ws, err := t.db.GetWorkingSet(ctx)
if err != nil {
return err
}
root := ws.WorkingRoot()
table, _, err := root.GetTable(ctx, t.tableName)
if err != nil {
@@ -1118,10 +1116,6 @@ func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, c
return err
}
initialValue := column.Type.Zero()
colIdx := updatedSch.GetAllCols().IndexOf(columnName)
rowData, err := updatedTable.GetRowData(ctx)
if err != nil {
return err
@@ -1134,6 +1128,9 @@ func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, c
return err
}
initialValue := column.Type.Zero()
colIdx := updatedSch.GetAllCols().IndexOf(columnName)
for {
r, err := rowIter.Next(ctx)
if err == io.EOF {
@@ -1146,23 +1143,28 @@ func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, c
if err != nil {
return err
}
if cmp < 0 {
initialValue = r[colIdx]
}
}
initialValNoms, err := col.TypeInfo.ConvertValueToNomsValue(ctx, root.VRW(), initialValue)
seq, err := globalstate.CoerceAutoIncrementValue(initialValue)
if err != nil {
return err
}
seq++
updatedTable, err = updatedTable.SetAutoIncrementValue(ctx, seq)
if err != nil {
return err
}
initialValNoms = increment(initialValNoms)
updatedTable, err = updatedTable.SetAutoIncrementValue(ctx, initialValNoms)
ait, err := t.db.gs.GetAutoIncrementTracker(ctx, ws)
if err != nil {
return err
}
ait.AddNewTable(t.tableName)
ait.Set(t.tableName, seq)
}
newRoot, err := root.PutTable(ctx, t.tableName, updatedTable)

View File

@@ -35,8 +35,7 @@ type TableWriter interface {
sql.RowInserter
sql.RowDeleter
sql.AutoIncrementSetter
NextAutoIncrementValue(potentialVal, tableVal interface{}) (interface{}, error)
GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{}) (uint64, error)
}
// SessionRootSetter sets the root value for the session.
@@ -56,14 +55,14 @@ type nomsTableWriter struct {
tableName string
dbName string
sch schema.Schema
autoIncCol schema.Column
vrw types.ValueReadWriter
kvToSQLRow *index.KVToSqlRowConverter
tableEditor editor.TableEditor
sess WriteSession
aiTracker globalstate.AutoIncrementTracker
batched bool
autoInc globalstate.AutoIncrementTracker
setter SessionRootSetter
}
@@ -158,25 +157,17 @@ func (te *nomsTableWriter) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.R
return err
}
func (te *nomsTableWriter) NextAutoIncrementValue(potentialVal, tableVal interface{}) (interface{}, error) {
return te.aiTracker.Next(te.tableName, potentialVal, tableVal)
func (te *nomsTableWriter) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{}) (uint64, error) {
return te.autoInc.Next(te.tableName, insertVal)
}
func (te *nomsTableWriter) GetAutoIncrementValue() (interface{}, error) {
val := te.tableEditor.GetAutoIncrementValue()
return te.autoIncCol.TypeInfo.ConvertNomsValueToValue(val)
}
func (te *nomsTableWriter) SetAutoIncrementValue(ctx *sql.Context, val interface{}) error {
nomsVal, err := te.autoIncCol.TypeInfo.ConvertValueToNomsValue(ctx, te.vrw, val)
func (te *nomsTableWriter) SetAutoIncrementValue(ctx *sql.Context, val uint64) error {
seq, err := globalstate.CoerceAutoIncrementValue(val)
if err != nil {
return err
}
if err = te.tableEditor.SetAutoIncrementValue(nomsVal); err != nil {
return err
}
te.aiTracker.Reset(te.tableName, val)
te.autoInc.Set(te.tableName, seq)
te.tableEditor.MarkDirty()
return te.flush(ctx)
}

View File

@@ -19,20 +19,21 @@ import (
"fmt"
"sync"
"github.com/dolthub/dolt/go/store/types"
"golang.org/x/sync/errgroup"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/store/types"
)
// WriteSession encapsulates writes made within a SQL session.
// It's responsible for creating and managing the lifecycle of TableWriter's.
type WriteSession interface {
// GetTableWriter creates a TableWriter and adds it to the WriteSession.
GetTableWriter(ctx context.Context, table, db string, ait globalstate.AutoIncrementTracker, setter SessionRootSetter, batched bool) (TableWriter, error)
GetTableWriter(ctx context.Context, table, db string, setter SessionRootSetter, batched bool) (TableWriter, error)
// UpdateWorkingSet takes a callback to update this WriteSession's WorkingSet. The update method cannot change the
// WorkingSetRef of the WriteSession. WriteSession flushes the pending writes in the session before calling the update.
@@ -56,7 +57,8 @@ type WriteSession interface {
type nomsWriteSession struct {
workingSet *doltdb.WorkingSet
tables map[string]*sessionedTableEditor
writeMutex *sync.RWMutex // This mutex is specifically for changes that affect the TES or all STEs
tracker globalstate.AutoIncrementTracker
mut *sync.RWMutex // This mutex is specifically for changes that affect the TES or all STEs
opts editor.Options
}
@@ -65,11 +67,12 @@ var _ WriteSession = &nomsWriteSession{}
// NewWriteSession creates and returns a WriteSession. Inserting a nil root is not an error, as there are
// locations that do not have a root at the time of this call. However, a root must be set through SetRoot before any
// table editors are returned.
func NewWriteSession(nbf *types.NomsBinFormat, ws *doltdb.WorkingSet, opts editor.Options) WriteSession {
func NewWriteSession(nbf *types.NomsBinFormat, ws *doltdb.WorkingSet, tracker globalstate.AutoIncrementTracker, opts editor.Options) WriteSession {
if types.IsFormat_DOLT_1(nbf) {
return &prollyWriteSession{
workingSet: ws,
tables: make(map[string]*prollyTableWriter),
tracker: tracker,
mut: &sync.RWMutex{},
}
}
@@ -77,14 +80,15 @@ func NewWriteSession(nbf *types.NomsBinFormat, ws *doltdb.WorkingSet, opts edito
return &nomsWriteSession{
workingSet: ws,
tables: make(map[string]*sessionedTableEditor),
writeMutex: &sync.RWMutex{},
tracker: tracker,
mut: &sync.RWMutex{},
opts: opts,
}
}
func (s *nomsWriteSession) GetTableWriter(ctx context.Context, table string, database string, ait globalstate.AutoIncrementTracker, setter SessionRootSetter, batched bool) (TableWriter, error) {
s.writeMutex.Lock()
defer s.writeMutex.Unlock()
func (s *nomsWriteSession) GetTableWriter(ctx context.Context, table, db string, setter SessionRootSetter, batched bool) (TableWriter, error) {
s.mut.Lock()
defer s.mut.Unlock()
t, ok, err := s.workingSet.WorkingRoot().GetTable(ctx, table)
if err != nil {
@@ -106,42 +110,39 @@ func (s *nomsWriteSession) GetTableWriter(ctx context.Context, table string, dat
}
conv := index.NewKVToSqlRowConverterForCols(t.Format(), sch)
ac := autoIncrementColFromSchema(sch)
return &nomsTableWriter{
tableName: table,
dbName: database,
dbName: db,
sch: sch,
autoIncCol: ac,
vrw: vrw,
kvToSQLRow: conv,
tableEditor: te,
sess: s,
batched: batched,
aiTracker: ait,
autoInc: s.tracker,
setter: setter,
}, nil
}
// Flush returns an updated root with all of the changed tables.
func (s *nomsWriteSession) Flush(ctx context.Context) (*doltdb.WorkingSet, error) {
s.writeMutex.Lock()
defer s.writeMutex.Unlock()
s.mut.Lock()
defer s.mut.Unlock()
return s.flush(ctx)
}
// SetWorkingSet implements WriteSession.
func (s *nomsWriteSession) SetWorkingSet(ctx context.Context, ws *doltdb.WorkingSet) error {
s.writeMutex.Lock()
defer s.writeMutex.Unlock()
s.mut.Lock()
defer s.mut.Unlock()
return s.setWorkingSet(ctx, ws)
}
// UpdateWorkingSet implements WriteSession.
func (s *nomsWriteSession) UpdateWorkingSet(ctx context.Context, cb func(ctx context.Context, current *doltdb.WorkingSet) (*doltdb.WorkingSet, error)) error {
s.writeMutex.Lock()
defer s.writeMutex.Unlock()
s.mut.Lock()
defer s.mut.Unlock()
current, err := s.flush(ctx)
if err != nil {
@@ -167,44 +168,46 @@ func (s *nomsWriteSession) SetOptions(opts editor.Options) {
// flush is the inner implementation for Flush that does not acquire any locks
func (s *nomsWriteSession) flush(ctx context.Context) (*doltdb.WorkingSet, error) {
rootMutex := &sync.Mutex{}
wg := &sync.WaitGroup{}
wg.Add(len(s.tables))
newRoot := s.workingSet.WorkingRoot()
var tableErr error
var rootErr error
for tableName, ste := range s.tables {
if !ste.HasEdits() {
wg.Done()
mu := &sync.Mutex{}
rootUpdate := func(name string, table *doltdb.Table) (err error) {
mu.Lock()
defer mu.Unlock()
if newRoot != nil {
newRoot, err = newRoot.PutTable(ctx, name, table)
}
return err
}
eg, ctx := errgroup.WithContext(ctx)
for tblName, tblEditor := range s.tables {
if !tblEditor.HasEdits() {
continue
}
// we can run all of the Table calls concurrently as long as we guard updating the root
go func(tableName string, ste *sessionedTableEditor) {
defer wg.Done()
updatedTable, err := ste.tableEditor.Table(ctx)
// we lock immediately after doing the operation, since both error setting and root updating are guarded
rootMutex.Lock()
defer rootMutex.Unlock()
// copy variables
name, ed := tblName, tblEditor
eg.Go(func() error {
tbl, err := ed.tableEditor.Table(ctx)
if err != nil {
if tableErr == nil {
tableErr = err
return err
}
if schema.HasAutoIncrement(ed.Schema()) {
v := s.tracker.Current(name)
tbl, err = tbl.SetAutoIncrementValue(ctx, v)
if err != nil {
return err
}
return
}
newRoot, err = newRoot.PutTable(ctx, tableName, updatedTable)
if err != nil && rootErr == nil {
rootErr = err
}
}(tableName, ste)
return rootUpdate(name, tbl)
})
}
wg.Wait()
if tableErr != nil {
return nil, tableErr
}
if rootErr != nil {
return nil, rootErr
if err := eg.Wait(); err != nil {
return nil, err
}
s.workingSet = s.workingSet.WithWorkingRoot(newRoot)
@@ -320,6 +323,10 @@ func (s *nomsWriteSession) setWorkingSet(ctx context.Context, ws *doltdb.Working
return err
}
if err = s.updateAutoIncrementSequences(ctx, root); err != nil {
return err
}
for tableName, localTableEditor := range s.tables {
t, ok, err := root.GetTable(ctx, tableName)
if err != nil {
@@ -336,6 +343,7 @@ func (s *nomsWriteSession) setWorkingSet(ctx context.Context, ws *doltdb.Working
if err != nil {
return err
}
newTableEditor, err := editor.NewTableEditor(ctx, t, tSch, tableName, s.opts)
if err != nil {
return err
@@ -354,3 +362,17 @@ func (s *nomsWriteSession) setWorkingSet(ctx context.Context, ws *doltdb.Working
}
return nil
}
func (s *nomsWriteSession) updateAutoIncrementSequences(ctx context.Context, root *doltdb.RootValue) error {
return root.IterTables(ctx, func(name string, table *doltdb.Table, sch schema.Schema) (stop bool, err error) {
if !schema.HasAutoIncrement(sch) {
return
}
v, err := table.GetAutoIncrementValue(ctx)
if err != nil {
return true, err
}
s.tracker.Set(name, v)
return
})
}

View File

@@ -42,7 +42,6 @@ type prollyTableWriter struct {
aiCol schema.Column
aiTracker globalstate.AutoIncrementTracker
aiUpdate bool
sess WriteSession
setter SessionRootSetter
@@ -64,7 +63,6 @@ func (w *prollyTableWriter) Insert(ctx *sql.Context, sqlRow sql.Row) error {
if err := w.primary.Insert(ctx, sqlRow); err != nil {
return err
}
w.aiUpdate = true
return nil
}
@@ -97,28 +95,27 @@ func (w *prollyTableWriter) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.
if err := w.primary.Update(ctx, oldRow, newRow); err != nil {
return err
}
w.aiUpdate = true
return nil
}
// NextAutoIncrementValue implements TableWriter.
func (w *prollyTableWriter) NextAutoIncrementValue(potentialVal, tableVal interface{}) (interface{}, error) {
return w.aiTracker.Next(w.tableName, potentialVal, tableVal)
// GetNextAutoIncrementValue implements TableWriter.
func (w *prollyTableWriter) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{}) (uint64, error) {
return w.aiTracker.Next(w.tableName, insertVal)
}
// SetAutoIncrementValue implements TableWriter.
func (w *prollyTableWriter) SetAutoIncrementValue(ctx *sql.Context, val interface{}) error {
nomsVal, err := w.aiCol.TypeInfo.ConvertValueToNomsValue(ctx, w.tbl.ValueReadWriter(), val)
func (w *prollyTableWriter) SetAutoIncrementValue(ctx *sql.Context, val uint64) error {
seq, err := globalstate.CoerceAutoIncrementValue(val)
if err != nil {
return err
}
w.tbl, err = w.tbl.SetAutoIncrementValue(ctx, nomsVal)
// todo(andy) set here or in flush?
w.tbl, err = w.tbl.SetAutoIncrementValue(ctx, seq)
if err != nil {
return err
}
w.aiTracker.Reset(w.tableName, val)
w.aiTracker.Set(w.tableName, seq)
return w.flush(ctx)
}
@@ -129,7 +126,6 @@ func (w *prollyTableWriter) Close(ctx *sql.Context) error {
if w.batched {
return nil
}
return w.flush(ctx)
}
@@ -187,23 +183,12 @@ func (w *prollyTableWriter) table(ctx context.Context) (t *doltdb.Table, err err
return nil, err
}
if w.aiCol.AutoIncrement && w.aiUpdate {
seq, err := w.aiTracker.Next(w.tableName, nil, nil)
if w.aiCol.AutoIncrement {
seq := w.aiTracker.Current(w.tableName)
t, err = t.SetAutoIncrementValue(ctx, seq)
if err != nil {
return nil, err
}
vrw := w.tbl.ValueReadWriter()
v, err := w.aiCol.TypeInfo.ConvertValueToNomsValue(ctx, vrw, seq)
if err != nil {
return nil, err
}
t, err = t.SetAutoIncrementValue(ctx, v)
if err != nil {
return nil, err
}
w.aiUpdate = false
}
return t, nil
@@ -289,7 +274,9 @@ func (m prollyIndexWriter) Map(ctx context.Context) (prolly.Map, error) {
func (m prollyIndexWriter) Insert(ctx *sql.Context, sqlRow sql.Row) error {
for to := range m.keyMap {
from := m.keyMap.MapOrdinal(to)
index.PutField(m.keyBld, to, sqlRow[from])
if err := index.PutField(m.keyBld, to, sqlRow[from]); err != nil {
return err
}
}
k := m.keyBld.Build(sharePool)
@@ -302,7 +289,9 @@ func (m prollyIndexWriter) Insert(ctx *sql.Context, sqlRow sql.Row) error {
for to := range m.valMap {
from := m.valMap.MapOrdinal(to)
index.PutField(m.valBld, to, sqlRow[from])
if err = index.PutField(m.valBld, to, sqlRow[from]); err != nil {
return err
}
}
v := m.valBld.Build(sharePool)
@@ -312,7 +301,9 @@ func (m prollyIndexWriter) Insert(ctx *sql.Context, sqlRow sql.Row) error {
func (m prollyIndexWriter) Delete(ctx *sql.Context, sqlRow sql.Row) error {
for to := range m.keyMap {
from := m.keyMap.MapOrdinal(to)
index.PutField(m.keyBld, to, sqlRow[from])
if err := index.PutField(m.keyBld, to, sqlRow[from]); err != nil {
return err
}
}
k := m.keyBld.Build(sharePool)
@@ -322,7 +313,9 @@ func (m prollyIndexWriter) Delete(ctx *sql.Context, sqlRow sql.Row) error {
func (m prollyIndexWriter) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Row) error {
for to := range m.keyMap {
from := m.keyMap.MapOrdinal(to)
index.PutField(m.keyBld, to, oldRow[from])
if err := index.PutField(m.keyBld, to, oldRow[from]); err != nil {
return err
}
}
oldKey := m.keyBld.Build(sharePool)
@@ -334,7 +327,9 @@ func (m prollyIndexWriter) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.R
for to := range m.keyMap {
from := m.keyMap.MapOrdinal(to)
index.PutField(m.keyBld, to, newRow[from])
if err := index.PutField(m.keyBld, to, newRow[from]); err != nil {
return err
}
}
newKey := m.keyBld.Build(sharePool)
@@ -347,7 +342,9 @@ func (m prollyIndexWriter) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.R
for to := range m.valMap {
from := m.valMap.MapOrdinal(to)
index.PutField(m.valBld, to, newRow[from])
if err = index.PutField(m.valBld, to, newRow[from]); err != nil {
return err
}
}
v := m.valBld.Build(sharePool)

View File

@@ -21,6 +21,7 @@ import (
"golang.org/x/sync/errgroup"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
@@ -31,13 +32,14 @@ import (
type prollyWriteSession struct {
workingSet *doltdb.WorkingSet
tables map[string]*prollyTableWriter
tracker globalstate.AutoIncrementTracker
mut *sync.RWMutex
}
var _ WriteSession = &prollyWriteSession{}
// GetTableWriter implemented WriteSession.
func (s *prollyWriteSession) GetTableWriter(ctx context.Context, table string, database string, ait globalstate.AutoIncrementTracker, setter SessionRootSetter, batched bool) (TableWriter, error) {
func (s *prollyWriteSession) GetTableWriter(ctx context.Context, table, db string, setter SessionRootSetter, batched bool) (TableWriter, error) {
s.mut.Lock()
defer s.mut.Unlock()
@@ -76,13 +78,13 @@ func (s *prollyWriteSession) GetTableWriter(ctx context.Context, table string, d
twr := &prollyTableWriter{
tableName: table,
dbName: database,
dbName: db,
primary: pw,
secondary: sws,
tbl: t,
sch: sch,
aiCol: autoCol,
aiTracker: ait,
aiTracker: s.tracker,
sess: s,
setter: setter,
batched: batched,
@@ -150,6 +152,14 @@ func (s *prollyWriteSession) flush(ctx context.Context) (*doltdb.WorkingSet, err
return err
}
if schema.HasAutoIncrement(wr.sch) {
v := s.tracker.Current(name)
t, err = t.SetAutoIncrementValue(ctx, v)
if err != nil {
return err
}
}
mu.Lock()
defer mu.Unlock()
tables[name] = t

View File

@@ -43,8 +43,8 @@ type sessionedTableEditor struct {
var _ editor.TableEditor = &sessionedTableEditor{}
func (ste *sessionedTableEditor) InsertKeyVal(ctx context.Context, key, val types.Tuple, tagToVal map[uint64]types.Value, errFunc editor.PKDuplicateErrFunc) error {
ste.tableEditSession.writeMutex.RLock()
defer ste.tableEditSession.writeMutex.RUnlock()
ste.tableEditSession.mut.RLock()
defer ste.tableEditSession.mut.RUnlock()
err := ste.validateForInsert(ctx, tagToVal)
if err != nil {
@@ -56,8 +56,8 @@ func (ste *sessionedTableEditor) InsertKeyVal(ctx context.Context, key, val type
}
func (ste *sessionedTableEditor) DeleteByKey(ctx context.Context, key types.Tuple, tagToVal map[uint64]types.Value) error {
ste.tableEditSession.writeMutex.RLock()
defer ste.tableEditSession.writeMutex.RUnlock()
ste.tableEditSession.mut.RLock()
defer ste.tableEditSession.mut.RUnlock()
if !ste.tableEditSession.opts.ForeignKeyChecksDisabled && len(ste.referencingTables) > 0 {
err := ste.onDeleteHandleRowsReferencingValues(ctx, key, tagToVal)
@@ -72,8 +72,8 @@ func (ste *sessionedTableEditor) DeleteByKey(ctx context.Context, key types.Tupl
// InsertRow adds the given row to the table. If the row already exists, use UpdateRow.
func (ste *sessionedTableEditor) InsertRow(ctx context.Context, dRow row.Row, errFunc editor.PKDuplicateErrFunc) error {
ste.tableEditSession.writeMutex.RLock()
defer ste.tableEditSession.writeMutex.RUnlock()
ste.tableEditSession.mut.RLock()
defer ste.tableEditSession.mut.RUnlock()
dRowTaggedVals, err := dRow.TaggedValues()
if err != nil {
@@ -90,8 +90,8 @@ func (ste *sessionedTableEditor) InsertRow(ctx context.Context, dRow row.Row, er
// DeleteRow removes the given key from the table.
func (ste *sessionedTableEditor) DeleteRow(ctx context.Context, r row.Row) error {
ste.tableEditSession.writeMutex.RLock()
defer ste.tableEditSession.writeMutex.RUnlock()
ste.tableEditSession.mut.RLock()
defer ste.tableEditSession.mut.RUnlock()
if !ste.tableEditSession.opts.ForeignKeyChecksDisabled && len(ste.referencingTables) > 0 {
err := ste.handleReferencingRowsOnDelete(ctx, r)
@@ -107,8 +107,8 @@ func (ste *sessionedTableEditor) DeleteRow(ctx context.Context, r row.Row) error
// UpdateRow takes the current row and new row, and updates it accordingly. Any applicable rows from tables that have a
// foreign key referencing this table will also be updated.
func (ste *sessionedTableEditor) UpdateRow(ctx context.Context, dOldRow row.Row, dNewRow row.Row, errFunc editor.PKDuplicateErrFunc) error {
ste.tableEditSession.writeMutex.RLock()
defer ste.tableEditSession.writeMutex.RUnlock()
ste.tableEditSession.mut.RLock()
defer ste.tableEditSession.mut.RUnlock()
return ste.updateRow(ctx, dOldRow, dNewRow, true, errFunc)
}
@@ -123,15 +123,9 @@ func (ste *sessionedTableEditor) HasEdits() bool {
return ste.tableEditor.HasEdits()
}
// GetAutoIncrementValue implements TableEditor.
func (ste *sessionedTableEditor) GetAutoIncrementValue() types.Value {
return ste.tableEditor.GetAutoIncrementValue()
}
// SetAutoIncrementValue implements TableEditor.
func (ste *sessionedTableEditor) SetAutoIncrementValue(v types.Value) error {
ste.dirty = true
return ste.tableEditor.SetAutoIncrementValue(v)
// MarkDirty implements TableEditor.
func (ste *sessionedTableEditor) MarkDirty() {
ste.tableEditor.MarkDirty()
}
// Table implements TableEditor.

View File

@@ -250,6 +250,13 @@ func (kte *keylessTableEditor) HasEdits() bool {
return kte.dirty
}
// MarkDirty implements TableEditor.
func (kte *keylessTableEditor) MarkDirty() {
kte.mu.Lock()
defer kte.mu.Unlock()
kte.dirty = true
}
// GetAutoIncrementValue implements TableEditor, AUTO_INCREMENT is not yet supported for keyless tables.
func (kte *keylessTableEditor) GetAutoIncrementValue() types.Value {
return types.NullValue

View File

@@ -62,10 +62,10 @@ type TableEditor interface {
InsertRow(ctx context.Context, r row.Row, errFunc PKDuplicateErrFunc) error
UpdateRow(ctx context.Context, old, new row.Row, errFunc PKDuplicateErrFunc) error
DeleteRow(ctx context.Context, r row.Row) error
HasEdits() bool
GetAutoIncrementValue() types.Value
SetAutoIncrementValue(v types.Value) (err error)
HasEdits() bool
MarkDirty()
SetConstraintViolation(ctx context.Context, k types.LesserValuable, v types.Valuable) error
Table(ctx context.Context) (*doltdb.Table, error)
@@ -128,10 +128,6 @@ type pkTableEditor struct {
// Whenever any write operation occurs on the table editor, this is set to true for the lifetime of the editor.
dirty uint32
hasAutoInc bool
autoIncCol schema.Column
autoIncVal types.Value
// This mutex blocks on each operation, so that map reads and updates are serialized
writeMutex *sync.Mutex
}
@@ -165,22 +161,6 @@ func newPkTableEditor(ctx context.Context, t *doltdb.Table, tableSch schema.Sche
te.indexEds[i] = NewIndexEditor(ctx, index, indexData, tableSch, opts)
}
err = tableSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
if col.AutoIncrement {
te.autoIncVal, err = t.GetAutoIncrementValue(ctx)
if err != nil {
return true, err
}
te.hasAutoInc = true
te.autoIncCol = col
return true, err
}
return false, nil
})
if err != nil {
return nil, err
}
return te, nil
}
@@ -446,32 +426,7 @@ func (te *pkTableEditor) insertKeyVal(ctx context.Context, keyHash hash.Hash, ke
return err
}
if te.hasAutoInc {
insertVal, ok := tagToVal[te.autoIncCol.Tag]
if ok {
var less bool
// float auto increment values should be rounded before comparing to the current auto increment values
if te.autoIncVal.Kind() == types.FloatKind {
rounded := types.Round(insertVal)
less, err = rounded.Less(te.nbf, te.autoIncVal)
} else {
less, err = insertVal.Less(te.nbf, te.autoIncVal)
}
if err != nil {
return err
}
if !less {
te.autoIncVal = types.Round(insertVal)
te.autoIncVal = types.Increment(te.autoIncVal)
}
}
}
te.setDirty(true)
te.MarkDirty()
return nil
}
@@ -546,7 +501,7 @@ func (te *pkTableEditor) DeleteByKey(ctx context.Context, key types.Tuple, tagTo
return err
}
te.setDirty(true)
te.MarkDirty()
return te.tea.Delete(keyHash, key)
}
@@ -644,7 +599,7 @@ func (te *pkTableEditor) UpdateRow(ctx context.Context, dOldRow row.Row, dNewRow
return err
}
te.setDirty(true)
te.MarkDirty()
if kvp, pkExists, err := te.tea.Get(ctx, newHash, dNewKeyVal); err != nil {
return err
@@ -655,37 +610,14 @@ func (te *pkTableEditor) UpdateRow(ctx context.Context, dOldRow row.Row, dNewRow
return te.tea.Insert(newHash, dNewKeyVal, dNewRowVal)
}
func (te *pkTableEditor) GetAutoIncrementValue() types.Value {
return te.autoIncVal
}
func (te *pkTableEditor) SetAutoIncrementValue(v types.Value) (err error) {
te.writeMutex.Lock()
defer te.writeMutex.Unlock()
te.setDirty(true)
te.autoIncVal = v
te.t, err = te.t.SetAutoIncrementValue(nil, te.autoIncVal)
return err
}
// Table returns a Table based on the edits given, if any. If Flush() was not called prior, it will be called here.
func (te *pkTableEditor) Table(ctx context.Context) (*doltdb.Table, error) {
if !te.HasEdits() {
return te.t, nil
}
var err error
if te.hasAutoInc {
te.t, err = te.t.SetAutoIncrementValue(nil, te.autoIncVal)
if err != nil {
return nil, err
}
}
var tbl *doltdb.Table
err = func() error {
err := func() error {
te.writeMutex.Lock()
defer te.writeMutex.Unlock()
@@ -816,7 +748,7 @@ func (te *pkTableEditor) SetConstraintViolation(ctx context.Context, k types.Les
te.cvEditor = cvMap.Edit()
}
te.cvEditor.Set(k, v)
te.setDirty(true)
te.MarkDirty()
return nil
}
@@ -845,13 +777,10 @@ func (te *pkTableEditor) Close(ctx context.Context) error {
return nil
}
func (te *pkTableEditor) setDirty(dirty bool) {
var val uint32
if dirty {
val = 1
}
atomic.StoreUint32(&te.dirty, val)
// MarkDirty implements TableEditor.
func (te *pkTableEditor) MarkDirty() {
dirty := uint32(1)
atomic.StoreUint32(&te.dirty, dirty)
}
// hasEdits returns whether the table editor has had any successful write operations. This does not track whether the

View File

@@ -13,12 +13,12 @@ teardown() {
@test "sql-show: show table status on auto-increment table" {
dolt sql -q "CREATE TABLE test(pk int NOT NULL AUTO_INCREMENT, c1 int, PRIMARY KEY (pk))"
run dolt sql -q "show table status where \`Auto_increment\`=1;"
run dolt sql -q "show table status;"
[ "$status" -eq 0 ]
[[ "$output" =~ "test" ]] || false
dolt sql -q "INSERT INTO test (c1) VALUES (0)"
run dolt sql -q "show table status where \`Auto_increment\`=2;"
run dolt sql -q "show table status;"
[ "$status" -eq 0 ]
[[ "$output" =~ "test" ]] || false
}