mirror of
https://github.com/dolthub/dolt.git
synced 2026-02-24 00:59:41 -06:00
Merge pull request #3032 from dolthub/andy/new-auto-increment
go/libraries/doltcore/sqle: Refactored Auto Increment
This commit is contained in:
@@ -68,7 +68,7 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/dolthub/go-mysql-server v0.11.1-0.20220324183628-b0a3bc9c2c2f
|
||||
github.com/dolthub/go-mysql-server v0.11.1-0.20220325203039-101310d04210
|
||||
github.com/google/flatbuffers v2.0.5+incompatible
|
||||
github.com/gosuri/uilive v0.0.4
|
||||
github.com/kch42/buzhash v0.0.0-20160816060738-9bdec3dec7c6
|
||||
|
||||
@@ -170,8 +170,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm
|
||||
github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no=
|
||||
github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
|
||||
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
|
||||
github.com/dolthub/go-mysql-server v0.11.1-0.20220324183628-b0a3bc9c2c2f h1:F1hhtWcel9an0/Ohbdg0gw6gSlgwWBxlnrSl7Jyi/2M=
|
||||
github.com/dolthub/go-mysql-server v0.11.1-0.20220324183628-b0a3bc9c2c2f/go.mod h1:1Sq4rszjP6AW7AJaF9xfycWexkNKIwkkOuYnoS5XcPg=
|
||||
github.com/dolthub/go-mysql-server v0.11.1-0.20220325203039-101310d04210 h1:sZiIQR3mhQkPNb/9TbEg9W+nF4bcHer2ubPdl3iyKjM=
|
||||
github.com/dolthub/go-mysql-server v0.11.1-0.20220325203039-101310d04210/go.mod h1:1Sq4rszjP6AW7AJaF9xfycWexkNKIwkkOuYnoS5XcPg=
|
||||
github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371 h1:oyPHJlzumKta1vnOQqUnfdz+pk3EmnHS3Nd0cCT0I2g=
|
||||
github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371/go.mod h1:dhGBqcCEfK5kuFmeO5+WOx3hqc1k3M29c1oS/R7N4ms=
|
||||
github.com/dolthub/jsonpath v0.0.0-20210609232853-d49537a30474 h1:xTrR+l5l+1Lfq0NvhiEsctylXinUMFhhsqaEcl414p8=
|
||||
|
||||
@@ -87,9 +87,9 @@ type Table interface {
|
||||
SetConstraintViolations(ctx context.Context, violations types.Map) (Table, error)
|
||||
|
||||
// GetAutoIncrement returns the AUTO_INCREMENT sequence value for this table.
|
||||
GetAutoIncrement(ctx context.Context) (types.Value, error)
|
||||
GetAutoIncrement(ctx context.Context) (uint64, error)
|
||||
// SetAutoIncrement sets the AUTO_INCREMENT sequence value for this table.
|
||||
SetAutoIncrement(ctx context.Context, val types.Value) (Table, error)
|
||||
SetAutoIncrement(ctx context.Context, val uint64) (Table, error)
|
||||
}
|
||||
|
||||
type nomsTable struct {
|
||||
@@ -140,7 +140,7 @@ func NewTable(ctx context.Context, vrw types.ValueReadWriter, sch schema.Schema,
|
||||
indexesKey: indexesRef,
|
||||
}
|
||||
|
||||
if autoIncVal != nil {
|
||||
if schema.HasAutoIncrement(sch) && autoIncVal != nil {
|
||||
sd[autoIncrementKey] = autoIncVal
|
||||
}
|
||||
|
||||
@@ -461,53 +461,36 @@ func (t nomsTable) SetConstraintViolations(ctx context.Context, violationsMap ty
|
||||
}
|
||||
|
||||
// GetAutoIncrement implements Table.
|
||||
func (t nomsTable) GetAutoIncrement(ctx context.Context) (types.Value, error) {
|
||||
func (t nomsTable) GetAutoIncrement(ctx context.Context) (uint64, error) {
|
||||
val, ok, err := t.tableStruct.MaybeGet(autoIncrementKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return 0, err
|
||||
}
|
||||
if ok {
|
||||
return val, nil
|
||||
if !ok {
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
sch, err := t.GetSchema(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
kind := types.UnknownKind
|
||||
_ = sch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
|
||||
if col.AutoIncrement {
|
||||
kind = col.Kind
|
||||
stop = true
|
||||
}
|
||||
return
|
||||
})
|
||||
switch kind {
|
||||
case types.IntKind:
|
||||
return types.Int(1), nil
|
||||
case types.UintKind:
|
||||
return types.Uint(1), nil
|
||||
case types.FloatKind:
|
||||
return types.Float(1), nil
|
||||
// older versions might have serialized auto-increment
|
||||
// value as types.Int or types.Float.
|
||||
switch t := val.(type) {
|
||||
case types.Int:
|
||||
return uint64(t), nil
|
||||
case types.Uint:
|
||||
return uint64(t), nil
|
||||
case types.Float:
|
||||
return uint64(t), nil
|
||||
default:
|
||||
return nil, ErrUnknownAutoIncrementValue
|
||||
return 0, ErrUnknownAutoIncrementValue
|
||||
}
|
||||
}
|
||||
|
||||
// SetAutoIncrement implements Table.
|
||||
func (t nomsTable) SetAutoIncrement(ctx context.Context, val types.Value) (Table, error) {
|
||||
switch val.(type) {
|
||||
case types.Int, types.Uint, types.Float:
|
||||
st, err := t.tableStruct.Set(autoIncrementKey, val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nomsTable{t.vrw, st}, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("cannot set auto increment to non-numeric value")
|
||||
func (t nomsTable) SetAutoIncrement(ctx context.Context, val uint64) (Table, error) {
|
||||
st, err := t.tableStruct.Set(autoIncrementKey, types.Uint(val))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nomsTable{t.vrw, st}, nil
|
||||
}
|
||||
|
||||
func refFromNomsValue(ctx context.Context, vrw types.ValueReadWriter, val types.Value) (types.Ref, error) {
|
||||
|
||||
@@ -455,12 +455,12 @@ func (t *Table) VerifyIndexRowData(ctx context.Context, indexName string) error
|
||||
}
|
||||
|
||||
// GetAutoIncrementValue returns the current AUTO_INCREMENT value for this table.
|
||||
func (t *Table) GetAutoIncrementValue(ctx context.Context) (types.Value, error) {
|
||||
func (t *Table) GetAutoIncrementValue(ctx context.Context) (uint64, error) {
|
||||
return t.table.GetAutoIncrement(ctx)
|
||||
}
|
||||
|
||||
// SetAutoIncrementValue sets the current AUTO_INCREMENT value for this table.
|
||||
func (t *Table) SetAutoIncrementValue(ctx context.Context, val types.Value) (*Table, error) {
|
||||
func (t *Table) SetAutoIncrementValue(ctx context.Context, val uint64) (*Table, error) {
|
||||
table, err := t.table.SetAutoIncrement(ctx, val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -786,14 +786,7 @@ func mergeAutoIncrementValues(ctx context.Context, tbl, otherTbl, resultTbl *dol
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
auto := false
|
||||
_ = sch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
|
||||
if col.AutoIncrement {
|
||||
auto, stop = true, true
|
||||
}
|
||||
return
|
||||
})
|
||||
if !auto {
|
||||
if !schema.HasAutoIncrement(sch) {
|
||||
return resultTbl, nil
|
||||
}
|
||||
|
||||
@@ -805,14 +798,10 @@ func mergeAutoIncrementValues(ctx context.Context, tbl, otherTbl, resultTbl *dol
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
less, err := autoVal.Less(tbl.Format(), mergeAutoVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if less {
|
||||
if autoVal < mergeAutoVal {
|
||||
autoVal = mergeAutoVal
|
||||
}
|
||||
return resultTbl.SetAutoIncrementValue(nil, autoVal)
|
||||
return resultTbl.SetAutoIncrementValue(ctx, autoVal)
|
||||
}
|
||||
|
||||
func MergeCommits(ctx context.Context, commit, mergeCommit *doltdb.Commit, opts editor.Options) (*doltdb.RootValue, map[string]*MergeStats, error) {
|
||||
|
||||
@@ -149,7 +149,7 @@ func updateTableWithModifiedColumn(ctx context.Context, tbl *doltdb.Table, oldSc
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var autoVal types.Value
|
||||
var autoVal uint64
|
||||
if schema.HasAutoIncrement(newSch) && schema.HasAutoIncrement(oldSch) {
|
||||
autoVal, err = tbl.GetAutoIncrementValue(ctx)
|
||||
if err != nil {
|
||||
@@ -157,7 +157,7 @@ func updateTableWithModifiedColumn(ctx context.Context, tbl *doltdb.Table, oldSc
|
||||
}
|
||||
}
|
||||
|
||||
updatedTable, err := doltdb.NewNomsTable(ctx, vrw, newSch, rowData, indexData, autoVal)
|
||||
updatedTable, err := doltdb.NewNomsTable(ctx, vrw, newSch, rowData, indexData, types.Uint(autoVal))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -82,15 +82,11 @@ func DbsAsDSQLDBs(dbs []sql.Database) []SqlDatabase {
|
||||
|
||||
// Database implements sql.Database for a dolt DB.
|
||||
type Database struct {
|
||||
name string
|
||||
ddb *doltdb.DoltDB
|
||||
rsr env.RepoStateReader
|
||||
rsw env.RepoStateWriter
|
||||
drw env.DocsReadWriter
|
||||
|
||||
// todo: needs a major refactor to
|
||||
// correctly handle persisted sequences
|
||||
// that must be coordinated across txs
|
||||
name string
|
||||
ddb *doltdb.DoltDB
|
||||
rsr env.RepoStateReader
|
||||
rsw env.RepoStateWriter
|
||||
drw env.DocsReadWriter
|
||||
gs globalstate.GlobalState
|
||||
editOpts editor.Options
|
||||
}
|
||||
@@ -167,6 +163,7 @@ var _ sql.TableRenamer = Database{}
|
||||
var _ sql.TriggerDatabase = Database{}
|
||||
var _ sql.StoredProcedureDatabase = Database{}
|
||||
var _ sql.TransactionDatabase = Database{}
|
||||
var _ globalstate.StateProvider = Database{}
|
||||
|
||||
// NewDatabase returns a new dolt database to use in queries.
|
||||
func NewDatabase(name string, dbData env.DbData, editOpts editor.Options) Database {
|
||||
@@ -259,6 +256,10 @@ func (db Database) DbData() env.DbData {
|
||||
}
|
||||
}
|
||||
|
||||
func (db Database) GetGlobalState() globalstate.GlobalState {
|
||||
return db.gs
|
||||
}
|
||||
|
||||
// GetTableInsensitive is used when resolving tables in queries. It returns a best-effort case-insensitive match for
|
||||
// the table name given.
|
||||
func (db Database) GetTableInsensitive(ctx *sql.Context, tblName string) (sql.Table, bool, error) {
|
||||
@@ -585,6 +586,18 @@ func (db Database) GetRoot(ctx *sql.Context) (*doltdb.RootValue, error) {
|
||||
return dbState.GetRoots().Working, nil
|
||||
}
|
||||
|
||||
func (db Database) GetWorkingSet(ctx *sql.Context) (*doltdb.WorkingSet, error) {
|
||||
sess := dsess.DSessFromSess(ctx.Session)
|
||||
dbState, ok, err := sess.LookupDbState(ctx, db.Name())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no root value found in session")
|
||||
}
|
||||
return dbState.WorkingSet, nil
|
||||
}
|
||||
|
||||
// SetRoot should typically be called on the Session, which is where this state lives. But it's available here as a
|
||||
// convenience.
|
||||
func (db Database) SetRoot(ctx *sql.Context, newRoot *doltdb.RootValue) error {
|
||||
@@ -660,7 +673,10 @@ func (db Database) dropTableFromAiTracker(ctx *sql.Context, tableName string) er
|
||||
return err
|
||||
}
|
||||
|
||||
ait := db.gs.GetAutoIncrementTracker(ws.Ref())
|
||||
ait, err := db.gs.GetAutoIncrementTracker(ctx, ws)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ait.DropTable(tableName)
|
||||
|
||||
return nil
|
||||
@@ -681,10 +697,11 @@ func (db Database) CreateTable(ctx *sql.Context, tableName string, sch sql.Prima
|
||||
|
||||
// Unlike the exported version CreateTable, createSqlTable doesn't enforce any table name checks.
|
||||
func (db Database) createSqlTable(ctx *sql.Context, tableName string, sch sql.PrimaryKeySchema) error {
|
||||
root, err := db.GetRoot(ctx)
|
||||
ws, err := db.GetWorkingSet(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
root := ws.WorkingRoot()
|
||||
|
||||
if exists, err := root.HasTable(ctx, tableName); err != nil {
|
||||
return err
|
||||
@@ -707,6 +724,14 @@ func (db Database) createSqlTable(ctx *sql.Context, tableName string, sch sql.Pr
|
||||
return schema.ErrUsingSpatialKey.New(tableName)
|
||||
}
|
||||
|
||||
if schema.HasAutoIncrement(doltSch) {
|
||||
ait, err := db.gs.GetAutoIncrementTracker(ctx, ws)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ait.AddNewTable(tableName)
|
||||
}
|
||||
|
||||
return db.createDoltTable(ctx, tableName, root, doltSch)
|
||||
}
|
||||
|
||||
|
||||
@@ -485,7 +485,6 @@ func dbRevisionForCommit(ctx context.Context, srcDb Database, revSpec string) (R
|
||||
rsw: srcDb.DbData().Rsw,
|
||||
rsr: srcDb.DbData().Rsr,
|
||||
drw: srcDb.DbData().Drw,
|
||||
gs: nil,
|
||||
editOpts: srcDb.editOpts,
|
||||
}}
|
||||
init := dsess.InitialDbState{
|
||||
|
||||
@@ -17,11 +17,11 @@ package dsess
|
||||
import (
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
)
|
||||
|
||||
type InitialDbState struct {
|
||||
@@ -48,6 +48,7 @@ type DatabaseSessionState struct {
|
||||
WorkingSet *doltdb.WorkingSet
|
||||
dbData env.DbData
|
||||
WriteSession writer.WriteSession
|
||||
globalState globalstate.GlobalState
|
||||
readOnly bool
|
||||
dirty bool
|
||||
readReplica *env.Remote
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/config"
|
||||
@@ -700,7 +701,11 @@ func (sess *Session) SwitchWorkingSet(
|
||||
// make a fresh WriteSession, discard existing WriteSession
|
||||
opts := sessionState.WriteSession.GetOptions()
|
||||
nbf := ws.WorkingRoot().VRW().Format()
|
||||
sessionState.WriteSession = writer.NewWriteSession(nbf, ws, opts)
|
||||
tracker, err := sessionState.globalState.GetAutoIncrementTracker(ctx, ws)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sessionState.WriteSession = writer.NewWriteSession(nbf, ws, tracker, opts)
|
||||
|
||||
// After switching to a new working set, we are by definition clean
|
||||
sessionState.dirty = false
|
||||
@@ -830,17 +835,26 @@ func (sess *Session) AddDB(ctx *sql.Context, dbState InitialDbState) error {
|
||||
nbf := sessionState.dbData.Ddb.Format()
|
||||
editOpts := db.(interface{ EditOptions() editor.Options }).EditOptions()
|
||||
|
||||
stateProvider, ok := db.(globalstate.StateProvider)
|
||||
if !ok {
|
||||
return fmt.Errorf("database does not contain global state store")
|
||||
}
|
||||
sessionState.globalState = stateProvider.GetGlobalState()
|
||||
|
||||
// WorkingSet is nil in the case of a read only, detached head DB
|
||||
if dbState.Err != nil {
|
||||
sessionState.Err = dbState.Err
|
||||
|
||||
} else if dbState.WorkingSet != nil {
|
||||
sessionState.WorkingSet = dbState.WorkingSet
|
||||
sessionState.WriteSession = writer.NewWriteSession(nbf, sessionState.WorkingSet, editOpts)
|
||||
err := sess.SetWorkingSet(ctx, db.Name(), dbState.WorkingSet)
|
||||
tracker, err := sessionState.globalState.GetAutoIncrementTracker(ctx, sessionState.WorkingSet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sessionState.WriteSession = writer.NewWriteSession(nbf, sessionState.WorkingSet, tracker, editOpts)
|
||||
if err = sess.SetWorkingSet(ctx, db.Name(), dbState.WorkingSet); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
} else {
|
||||
headRoot, err := dbState.HeadCommit.GetRootValue()
|
||||
|
||||
@@ -15,146 +15,118 @@
|
||||
package globalstate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"context"
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
)
|
||||
|
||||
type AutoIncrementTracker interface {
|
||||
// Next returns the next auto increment value to be used by a table. If a table is not initialized in the counter
|
||||
// it will used the value stored in disk.
|
||||
Next(tableName string, insertVal interface{}, diskVal interface{}) (interface{}, error)
|
||||
// Reset resets the auto increment tracker value for a table. Typically used in truncate statements.
|
||||
Reset(tableName string, val interface{})
|
||||
// DropTable removes a table from the autoincrement tracker.
|
||||
DropTable(tableName string)
|
||||
}
|
||||
|
||||
// AutoIncrementTracker is a global map that tracks which auto increment keys have been given for each table. At runtime
|
||||
// it hands out the current key.
|
||||
func NewAutoIncrementTracker() AutoIncrementTracker {
|
||||
return &autoIncrementTracker{
|
||||
valuePerTable: make(map[string]interface{}),
|
||||
// CoerceAutoIncrementValue converts |val| into an AUTO_INCREMENT sequence value
|
||||
func CoerceAutoIncrementValue(val interface{}) (uint64, error) {
|
||||
switch typ := val.(type) {
|
||||
case float32:
|
||||
val = math.Round(float64(typ))
|
||||
case float64:
|
||||
val = math.Round(typ)
|
||||
}
|
||||
|
||||
var err error
|
||||
val, err = sql.Uint64.Convert(val)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if val == nil || val == uint64(0) {
|
||||
return 0, nil
|
||||
}
|
||||
return val.(uint64), nil
|
||||
}
|
||||
|
||||
type autoIncrementTracker struct {
|
||||
valuePerTable map[string]interface{}
|
||||
mu sync.Mutex
|
||||
func NewAutoIncrementTracker(ctx context.Context, ws *doltdb.WorkingSet) (ait AutoIncrementTracker, err error) {
|
||||
ait = AutoIncrementTracker{
|
||||
wsRef: ws.Ref(),
|
||||
sequences: make(map[string]uint64),
|
||||
mu: &sync.Mutex{},
|
||||
}
|
||||
|
||||
// collect auto increment values
|
||||
err = ws.WorkingRoot().IterTables(ctx, func(name string, table *doltdb.Table, sch schema.Schema) (bool, error) {
|
||||
ok := schema.HasAutoIncrement(sch)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
seq, err := table.GetAutoIncrementValue(ctx)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
ait.sequences[name] = seq
|
||||
return false, nil
|
||||
})
|
||||
return ait, err
|
||||
}
|
||||
|
||||
var _ AutoIncrementTracker = (*autoIncrementTracker)(nil)
|
||||
type AutoIncrementTracker struct {
|
||||
wsRef ref.WorkingSetRef
|
||||
sequences map[string]uint64
|
||||
mu *sync.Mutex
|
||||
}
|
||||
|
||||
func (a *autoIncrementTracker) Next(tableName string, insertVal interface{}, diskVal interface{}) (interface{}, error) {
|
||||
// Current returns the current AUTO_INCREMENT value for |tableName|.
|
||||
func (a AutoIncrementTracker) Current(tableName string) uint64 {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
return a.sequences[tableName]
|
||||
}
|
||||
|
||||
// Next returns the next AUTO_INCREMENT value for |tableName|, considering the provided |insertVal|.
|
||||
func (a AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, error) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
diskVal = valOrZero(diskVal)
|
||||
|
||||
// Case 0: Just use the value passed in.
|
||||
potential, ok := a.valuePerTable[tableName]
|
||||
if !ok {
|
||||
// Use the disk val if the table has not been initialized yet.
|
||||
potential = diskVal
|
||||
}
|
||||
|
||||
// Case 1: Disk Val is greater. This is useful for updating the tracker when a merge occurs.
|
||||
// TODO: This is a bit of a hack. The correct solution is to plumb this tracker through the merge logic.
|
||||
diskValGreater, err := geq(valOrZero(diskVal), valOrZero(a.valuePerTable[tableName]))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if diskValGreater {
|
||||
potential = diskVal
|
||||
}
|
||||
|
||||
// Case 2: Overwrite anything if an insert val is passed.
|
||||
if insertVal != nil {
|
||||
potential = insertVal
|
||||
}
|
||||
|
||||
// update the table only if val >= existing
|
||||
isGeq, err := geq(valOrZero(potential), valOrZero(a.valuePerTable[tableName]))
|
||||
given, err := CoerceAutoIncrementValue(insertVal)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if isGeq {
|
||||
val, err := convertIntTypeToUint(potential)
|
||||
if err != nil {
|
||||
return val, err
|
||||
}
|
||||
curr := a.sequences[tbl]
|
||||
|
||||
a.valuePerTable[tableName] = val + 1
|
||||
if given == 0 {
|
||||
// |given| is 0 or NULL
|
||||
a.sequences[tbl]++
|
||||
return curr, nil
|
||||
}
|
||||
|
||||
return potential, nil
|
||||
if given >= curr {
|
||||
a.sequences[tbl] = given
|
||||
a.sequences[tbl]++
|
||||
return given, nil
|
||||
}
|
||||
|
||||
// |given| < curr
|
||||
return given, nil
|
||||
}
|
||||
|
||||
func (a *autoIncrementTracker) Reset(tableName string, val interface{}) {
|
||||
// Set sets the current AUTO_INCREMENT value for |tableName|.
|
||||
func (a AutoIncrementTracker) Set(tableName string, val uint64) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
a.valuePerTable[tableName] = val
|
||||
a.sequences[tableName] = val
|
||||
}
|
||||
|
||||
func (a *autoIncrementTracker) DropTable(tableName string) {
|
||||
// AddNewTable adds |tablename| to the AutoIncrementTracker.
|
||||
func (a AutoIncrementTracker) AddNewTable(tableName string) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
delete(a.valuePerTable, tableName)
|
||||
a.sequences[tableName] = uint64(1)
|
||||
}
|
||||
|
||||
// Helper method that sets nil values to 0 for clarity purposes
|
||||
func valOrZero(val interface{}) interface{} {
|
||||
if val == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
func geq(val1 interface{}, val2 interface{}) (bool, error) {
|
||||
v1, err := convertIntTypeToUint(val1)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
v2, err := convertIntTypeToUint(val2)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return v1 >= v2, nil
|
||||
}
|
||||
|
||||
func convertIntTypeToUint(val interface{}) (uint64, error) {
|
||||
switch t := val.(type) {
|
||||
case int:
|
||||
return uint64(t), nil
|
||||
case int8:
|
||||
return uint64(t), nil
|
||||
case int16:
|
||||
return uint64(t), nil
|
||||
case int32:
|
||||
return uint64(t), nil
|
||||
case int64:
|
||||
return uint64(t), nil
|
||||
case uint:
|
||||
return uint64(t), nil
|
||||
case uint8:
|
||||
return uint64(t), nil
|
||||
case uint16:
|
||||
return uint64(t), nil
|
||||
case uint32:
|
||||
return uint64(t), nil
|
||||
case uint64:
|
||||
return t, nil
|
||||
case float32:
|
||||
return uint64(t), nil
|
||||
case float64:
|
||||
return uint64(t), nil
|
||||
default:
|
||||
return 0, fmt.Errorf("error: auto increment is not a numeric type")
|
||||
}
|
||||
// DropTable drops |tablename| from the AutoIncrementTracker.
|
||||
func (a AutoIncrementTracker) DropTable(tableName string) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
delete(a.sequences, tableName)
|
||||
}
|
||||
|
||||
@@ -15,42 +15,58 @@
|
||||
package globalstate
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNextHasNoRepeats(t *testing.T) {
|
||||
var allVals sync.Map
|
||||
aiTracker := NewAutoIncrementTracker()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
go func() {
|
||||
for j := 0; j < 10; j++ {
|
||||
nxt, err := aiTracker.Next("test", nil, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
val, err := convertIntTypeToUint(nxt)
|
||||
require.NoError(t, err)
|
||||
|
||||
current, ok := allVals.Load(val)
|
||||
if !ok {
|
||||
allVals.Store(val, 1)
|
||||
} else {
|
||||
asUint, _ := convertIntTypeToUint(current)
|
||||
allVals.Store(val, asUint+1)
|
||||
}
|
||||
}
|
||||
}()
|
||||
func TestCoerceAutoIncrementValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
val interface{}
|
||||
exp uint64
|
||||
err bool
|
||||
}{
|
||||
{
|
||||
val: nil,
|
||||
exp: uint64(0),
|
||||
},
|
||||
{
|
||||
val: int32(0),
|
||||
exp: uint64(0),
|
||||
},
|
||||
{
|
||||
val: int32(1),
|
||||
exp: uint64(1),
|
||||
},
|
||||
{
|
||||
val: uint32(1),
|
||||
exp: uint64(1),
|
||||
},
|
||||
{
|
||||
val: float32(1),
|
||||
exp: uint64(1),
|
||||
},
|
||||
{
|
||||
val: float32(1.1),
|
||||
exp: uint64(1),
|
||||
},
|
||||
{
|
||||
val: float32(1.9),
|
||||
exp: uint64(2),
|
||||
},
|
||||
}
|
||||
|
||||
// Make sure each key was called once
|
||||
allVals.Range(func(key, value interface{}) bool {
|
||||
asUint, _ := convertIntTypeToUint(value)
|
||||
|
||||
require.Equal(t, uint64(1), asUint)
|
||||
|
||||
return true
|
||||
})
|
||||
for _, test := range tests {
|
||||
name := fmt.Sprintf("Coerce %v to %v", test.val, test.exp)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
act, err := CoerceAutoIncrementValue(test.val)
|
||||
if test.err {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, test.exp, act)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,36 +15,45 @@
|
||||
package globalstate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
|
||||
)
|
||||
|
||||
type GlobalState interface {
|
||||
GetAutoIncrementTracker(wsref ref.WorkingSetRef) AutoIncrementTracker
|
||||
type StateProvider interface {
|
||||
GetGlobalState() GlobalState
|
||||
}
|
||||
|
||||
func NewGlobalStateStore() GlobalState {
|
||||
return &globalStateImpl{
|
||||
return GlobalState{
|
||||
trackerMap: make(map[ref.WorkingSetRef]AutoIncrementTracker),
|
||||
mu: &sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
type globalStateImpl struct {
|
||||
type GlobalState struct {
|
||||
trackerMap map[ref.WorkingSetRef]AutoIncrementTracker
|
||||
mu sync.Mutex
|
||||
mu *sync.Mutex
|
||||
}
|
||||
|
||||
var _ GlobalState = (*globalStateImpl)(nil)
|
||||
|
||||
func (g *globalStateImpl) GetAutoIncrementTracker(wsref ref.WorkingSetRef) AutoIncrementTracker {
|
||||
func (g GlobalState) GetAutoIncrementTracker(ctx context.Context, ws *doltdb.WorkingSet) (AutoIncrementTracker, error) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
_, ok := g.trackerMap[wsref]
|
||||
if !ok {
|
||||
g.trackerMap[wsref] = NewAutoIncrementTracker()
|
||||
ait, ok := g.trackerMap[ws.Ref()]
|
||||
if ok {
|
||||
return ait, nil
|
||||
}
|
||||
|
||||
return g.trackerMap[wsref]
|
||||
var err error
|
||||
ait, err = NewAutoIncrementTracker(ctx, ws)
|
||||
if err != nil {
|
||||
return AutoIncrementTracker{}, err
|
||||
}
|
||||
g.trackerMap[ws.Ref()] = ait
|
||||
|
||||
return ait, nil
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ func (e *StaticErrorEditor) Update(*sql.Context, sql.Row, sql.Row) error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
func (e *StaticErrorEditor) SetAutoIncrementValue(*sql.Context, interface{}) error {
|
||||
func (e *StaticErrorEditor) SetAutoIncrementValue(*sql.Context, uint64) error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema/alterschema"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer"
|
||||
@@ -196,12 +197,7 @@ func (t *DoltTable) GetAutoIncrementValue(ctx *sql.Context) (interface{}, error)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
val, err := table.GetAutoIncrementValue(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.autoIncCol.TypeInfo.ConvertNomsValueToValue(val)
|
||||
return table.GetAutoIncrementValue(ctx)
|
||||
}
|
||||
|
||||
// Name returns the name of the table.
|
||||
@@ -415,19 +411,13 @@ func (t *WritableDoltTable) getTableEditor(ctx *sql.Context) (ed writer.TableWri
|
||||
}
|
||||
}
|
||||
|
||||
ws, err := ds.WorkingSet(ctx, t.db.name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ait := t.db.gs.GetAutoIncrementTracker(ws.Ref())
|
||||
|
||||
state, _, err := ds.LookupDbState(ctx, t.db.name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setter := ds.SetRoot
|
||||
ed, err = state.WriteSession.GetTableWriter(ctx, t.tableName, t.db.Name(), ait, setter, batched)
|
||||
ed, err = state.WriteSession.GetTableWriter(ctx, t.tableName, t.db.Name(), setter, batched)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -542,22 +532,17 @@ func (t *WritableDoltTable) PeekNextAutoIncrementValue(ctx *sql.Context) (interf
|
||||
}
|
||||
|
||||
// GetNextAutoIncrementValue implements sql.AutoIncrementTable
|
||||
func (t *WritableDoltTable) GetNextAutoIncrementValue(ctx *sql.Context, potentialVal interface{}) (interface{}, error) {
|
||||
func (t *WritableDoltTable) GetNextAutoIncrementValue(ctx *sql.Context, potentialVal interface{}) (uint64, error) {
|
||||
if !t.autoIncCol.AutoIncrement {
|
||||
return nil, sql.ErrNoAutoIncrementCol
|
||||
return 0, sql.ErrNoAutoIncrementCol
|
||||
}
|
||||
|
||||
ed, err := t.getTableEditor(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
tableVal, err := t.getTableAutoIncrementValue(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ed.NextAutoIncrementValue(potentialVal, tableVal)
|
||||
return ed.GetNextAutoIncrementValue(ctx, potentialVal)
|
||||
}
|
||||
|
||||
func (t *WritableDoltTable) getTableAutoIncrementValue(ctx *sql.Context) (interface{}, error) {
|
||||
@@ -887,6 +872,18 @@ func (t *AlterableDoltTable) AddColumn(ctx *sql.Context, column *sql.Column, ord
|
||||
return err
|
||||
}
|
||||
|
||||
if column.AutoIncrement {
|
||||
ws, err := t.db.GetWorkingSet(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ait, err := t.db.gs.GetAutoIncrementTracker(ctx, ws)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ait.AddNewTable(t.tableName)
|
||||
}
|
||||
|
||||
newRoot, err := root.PutTable(ctx, t.tableName, updatedTable)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1042,10 +1039,11 @@ func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, c
|
||||
return nil
|
||||
}
|
||||
|
||||
root, err := t.getRoot(ctx)
|
||||
ws, err := t.db.GetWorkingSet(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
root := ws.WorkingRoot()
|
||||
|
||||
table, _, err := root.GetTable(ctx, t.tableName)
|
||||
if err != nil {
|
||||
@@ -1118,10 +1116,6 @@ func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, c
|
||||
return err
|
||||
}
|
||||
|
||||
initialValue := column.Type.Zero()
|
||||
|
||||
colIdx := updatedSch.GetAllCols().IndexOf(columnName)
|
||||
|
||||
rowData, err := updatedTable.GetRowData(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1134,6 +1128,9 @@ func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, c
|
||||
return err
|
||||
}
|
||||
|
||||
initialValue := column.Type.Zero()
|
||||
colIdx := updatedSch.GetAllCols().IndexOf(columnName)
|
||||
|
||||
for {
|
||||
r, err := rowIter.Next(ctx)
|
||||
if err == io.EOF {
|
||||
@@ -1146,23 +1143,28 @@ func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, c
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cmp < 0 {
|
||||
initialValue = r[colIdx]
|
||||
}
|
||||
}
|
||||
|
||||
initialValNoms, err := col.TypeInfo.ConvertValueToNomsValue(ctx, root.VRW(), initialValue)
|
||||
seq, err := globalstate.CoerceAutoIncrementValue(initialValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
seq++
|
||||
|
||||
updatedTable, err = updatedTable.SetAutoIncrementValue(ctx, seq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
initialValNoms = increment(initialValNoms)
|
||||
|
||||
updatedTable, err = updatedTable.SetAutoIncrementValue(ctx, initialValNoms)
|
||||
ait, err := t.db.gs.GetAutoIncrementTracker(ctx, ws)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ait.AddNewTable(t.tableName)
|
||||
ait.Set(t.tableName, seq)
|
||||
}
|
||||
|
||||
newRoot, err := root.PutTable(ctx, t.tableName, updatedTable)
|
||||
|
||||
@@ -35,8 +35,7 @@ type TableWriter interface {
|
||||
sql.RowInserter
|
||||
sql.RowDeleter
|
||||
sql.AutoIncrementSetter
|
||||
|
||||
NextAutoIncrementValue(potentialVal, tableVal interface{}) (interface{}, error)
|
||||
GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{}) (uint64, error)
|
||||
}
|
||||
|
||||
// SessionRootSetter sets the root value for the session.
|
||||
@@ -56,14 +55,14 @@ type nomsTableWriter struct {
|
||||
tableName string
|
||||
dbName string
|
||||
sch schema.Schema
|
||||
autoIncCol schema.Column
|
||||
vrw types.ValueReadWriter
|
||||
kvToSQLRow *index.KVToSqlRowConverter
|
||||
tableEditor editor.TableEditor
|
||||
sess WriteSession
|
||||
aiTracker globalstate.AutoIncrementTracker
|
||||
batched bool
|
||||
|
||||
autoInc globalstate.AutoIncrementTracker
|
||||
|
||||
setter SessionRootSetter
|
||||
}
|
||||
|
||||
@@ -158,25 +157,17 @@ func (te *nomsTableWriter) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.R
|
||||
return err
|
||||
}
|
||||
|
||||
func (te *nomsTableWriter) NextAutoIncrementValue(potentialVal, tableVal interface{}) (interface{}, error) {
|
||||
return te.aiTracker.Next(te.tableName, potentialVal, tableVal)
|
||||
func (te *nomsTableWriter) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{}) (uint64, error) {
|
||||
return te.autoInc.Next(te.tableName, insertVal)
|
||||
}
|
||||
|
||||
func (te *nomsTableWriter) GetAutoIncrementValue() (interface{}, error) {
|
||||
val := te.tableEditor.GetAutoIncrementValue()
|
||||
return te.autoIncCol.TypeInfo.ConvertNomsValueToValue(val)
|
||||
}
|
||||
|
||||
func (te *nomsTableWriter) SetAutoIncrementValue(ctx *sql.Context, val interface{}) error {
|
||||
nomsVal, err := te.autoIncCol.TypeInfo.ConvertValueToNomsValue(ctx, te.vrw, val)
|
||||
func (te *nomsTableWriter) SetAutoIncrementValue(ctx *sql.Context, val uint64) error {
|
||||
seq, err := globalstate.CoerceAutoIncrementValue(val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = te.tableEditor.SetAutoIncrementValue(nomsVal); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
te.aiTracker.Reset(te.tableName, val)
|
||||
te.autoInc.Set(te.tableName, seq)
|
||||
te.tableEditor.MarkDirty()
|
||||
|
||||
return te.flush(ctx)
|
||||
}
|
||||
|
||||
@@ -19,20 +19,21 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
// WriteSession encapsulates writes made within a SQL session.
|
||||
// It's responsible for creating and managing the lifecycle of TableWriter's.
|
||||
type WriteSession interface {
|
||||
// GetTableWriter creates a TableWriter and adds it to the WriteSession.
|
||||
GetTableWriter(ctx context.Context, table, db string, ait globalstate.AutoIncrementTracker, setter SessionRootSetter, batched bool) (TableWriter, error)
|
||||
GetTableWriter(ctx context.Context, table, db string, setter SessionRootSetter, batched bool) (TableWriter, error)
|
||||
|
||||
// UpdateWorkingSet takes a callback to update this WriteSession's WorkingSet. The update method cannot change the
|
||||
// WorkingSetRef of the WriteSession. WriteSession flushes the pending writes in the session before calling the update.
|
||||
@@ -56,7 +57,8 @@ type WriteSession interface {
|
||||
type nomsWriteSession struct {
|
||||
workingSet *doltdb.WorkingSet
|
||||
tables map[string]*sessionedTableEditor
|
||||
writeMutex *sync.RWMutex // This mutex is specifically for changes that affect the TES or all STEs
|
||||
tracker globalstate.AutoIncrementTracker
|
||||
mut *sync.RWMutex // This mutex is specifically for changes that affect the TES or all STEs
|
||||
opts editor.Options
|
||||
}
|
||||
|
||||
@@ -65,11 +67,12 @@ var _ WriteSession = &nomsWriteSession{}
|
||||
// NewWriteSession creates and returns a WriteSession. Inserting a nil root is not an error, as there are
|
||||
// locations that do not have a root at the time of this call. However, a root must be set through SetRoot before any
|
||||
// table editors are returned.
|
||||
func NewWriteSession(nbf *types.NomsBinFormat, ws *doltdb.WorkingSet, opts editor.Options) WriteSession {
|
||||
func NewWriteSession(nbf *types.NomsBinFormat, ws *doltdb.WorkingSet, tracker globalstate.AutoIncrementTracker, opts editor.Options) WriteSession {
|
||||
if types.IsFormat_DOLT_1(nbf) {
|
||||
return &prollyWriteSession{
|
||||
workingSet: ws,
|
||||
tables: make(map[string]*prollyTableWriter),
|
||||
tracker: tracker,
|
||||
mut: &sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
@@ -77,14 +80,15 @@ func NewWriteSession(nbf *types.NomsBinFormat, ws *doltdb.WorkingSet, opts edito
|
||||
return &nomsWriteSession{
|
||||
workingSet: ws,
|
||||
tables: make(map[string]*sessionedTableEditor),
|
||||
writeMutex: &sync.RWMutex{},
|
||||
tracker: tracker,
|
||||
mut: &sync.RWMutex{},
|
||||
opts: opts,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *nomsWriteSession) GetTableWriter(ctx context.Context, table string, database string, ait globalstate.AutoIncrementTracker, setter SessionRootSetter, batched bool) (TableWriter, error) {
|
||||
s.writeMutex.Lock()
|
||||
defer s.writeMutex.Unlock()
|
||||
func (s *nomsWriteSession) GetTableWriter(ctx context.Context, table, db string, setter SessionRootSetter, batched bool) (TableWriter, error) {
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
|
||||
t, ok, err := s.workingSet.WorkingRoot().GetTable(ctx, table)
|
||||
if err != nil {
|
||||
@@ -106,42 +110,39 @@ func (s *nomsWriteSession) GetTableWriter(ctx context.Context, table string, dat
|
||||
}
|
||||
|
||||
conv := index.NewKVToSqlRowConverterForCols(t.Format(), sch)
|
||||
ac := autoIncrementColFromSchema(sch)
|
||||
|
||||
return &nomsTableWriter{
|
||||
tableName: table,
|
||||
dbName: database,
|
||||
dbName: db,
|
||||
sch: sch,
|
||||
autoIncCol: ac,
|
||||
vrw: vrw,
|
||||
kvToSQLRow: conv,
|
||||
tableEditor: te,
|
||||
sess: s,
|
||||
batched: batched,
|
||||
aiTracker: ait,
|
||||
autoInc: s.tracker,
|
||||
setter: setter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Flush returns an updated root with all of the changed tables.
|
||||
func (s *nomsWriteSession) Flush(ctx context.Context) (*doltdb.WorkingSet, error) {
|
||||
s.writeMutex.Lock()
|
||||
defer s.writeMutex.Unlock()
|
||||
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
return s.flush(ctx)
|
||||
}
|
||||
|
||||
// SetWorkingSet implements WriteSession.
|
||||
func (s *nomsWriteSession) SetWorkingSet(ctx context.Context, ws *doltdb.WorkingSet) error {
|
||||
s.writeMutex.Lock()
|
||||
defer s.writeMutex.Unlock()
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
return s.setWorkingSet(ctx, ws)
|
||||
}
|
||||
|
||||
// UpdateWorkingSet implements WriteSession.
|
||||
func (s *nomsWriteSession) UpdateWorkingSet(ctx context.Context, cb func(ctx context.Context, current *doltdb.WorkingSet) (*doltdb.WorkingSet, error)) error {
|
||||
s.writeMutex.Lock()
|
||||
defer s.writeMutex.Unlock()
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
|
||||
current, err := s.flush(ctx)
|
||||
if err != nil {
|
||||
@@ -167,44 +168,46 @@ func (s *nomsWriteSession) SetOptions(opts editor.Options) {
|
||||
|
||||
// flush is the inner implementation for Flush that does not acquire any locks
|
||||
func (s *nomsWriteSession) flush(ctx context.Context) (*doltdb.WorkingSet, error) {
|
||||
rootMutex := &sync.Mutex{}
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(len(s.tables))
|
||||
|
||||
newRoot := s.workingSet.WorkingRoot()
|
||||
var tableErr error
|
||||
var rootErr error
|
||||
for tableName, ste := range s.tables {
|
||||
if !ste.HasEdits() {
|
||||
wg.Done()
|
||||
mu := &sync.Mutex{}
|
||||
rootUpdate := func(name string, table *doltdb.Table) (err error) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if newRoot != nil {
|
||||
newRoot, err = newRoot.PutTable(ctx, name, table)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
for tblName, tblEditor := range s.tables {
|
||||
if !tblEditor.HasEdits() {
|
||||
continue
|
||||
}
|
||||
|
||||
// we can run all of the Table calls concurrently as long as we guard updating the root
|
||||
go func(tableName string, ste *sessionedTableEditor) {
|
||||
defer wg.Done()
|
||||
updatedTable, err := ste.tableEditor.Table(ctx)
|
||||
// we lock immediately after doing the operation, since both error setting and root updating are guarded
|
||||
rootMutex.Lock()
|
||||
defer rootMutex.Unlock()
|
||||
// copy variables
|
||||
name, ed := tblName, tblEditor
|
||||
|
||||
eg.Go(func() error {
|
||||
tbl, err := ed.tableEditor.Table(ctx)
|
||||
if err != nil {
|
||||
if tableErr == nil {
|
||||
tableErr = err
|
||||
return err
|
||||
}
|
||||
|
||||
if schema.HasAutoIncrement(ed.Schema()) {
|
||||
v := s.tracker.Current(name)
|
||||
tbl, err = tbl.SetAutoIncrementValue(ctx, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return
|
||||
}
|
||||
newRoot, err = newRoot.PutTable(ctx, tableName, updatedTable)
|
||||
if err != nil && rootErr == nil {
|
||||
rootErr = err
|
||||
}
|
||||
}(tableName, ste)
|
||||
|
||||
return rootUpdate(name, tbl)
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
if tableErr != nil {
|
||||
return nil, tableErr
|
||||
}
|
||||
if rootErr != nil {
|
||||
return nil, rootErr
|
||||
if err := eg.Wait(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.workingSet = s.workingSet.WithWorkingRoot(newRoot)
|
||||
|
||||
@@ -320,6 +323,10 @@ func (s *nomsWriteSession) setWorkingSet(ctx context.Context, ws *doltdb.Working
|
||||
return err
|
||||
}
|
||||
|
||||
if err = s.updateAutoIncrementSequences(ctx, root); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for tableName, localTableEditor := range s.tables {
|
||||
t, ok, err := root.GetTable(ctx, tableName)
|
||||
if err != nil {
|
||||
@@ -336,6 +343,7 @@ func (s *nomsWriteSession) setWorkingSet(ctx context.Context, ws *doltdb.Working
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newTableEditor, err := editor.NewTableEditor(ctx, t, tSch, tableName, s.opts)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -354,3 +362,17 @@ func (s *nomsWriteSession) setWorkingSet(ctx context.Context, ws *doltdb.Working
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *nomsWriteSession) updateAutoIncrementSequences(ctx context.Context, root *doltdb.RootValue) error {
|
||||
return root.IterTables(ctx, func(name string, table *doltdb.Table, sch schema.Schema) (stop bool, err error) {
|
||||
if !schema.HasAutoIncrement(sch) {
|
||||
return
|
||||
}
|
||||
v, err := table.GetAutoIncrementValue(ctx)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
s.tracker.Set(name, v)
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
@@ -42,7 +42,6 @@ type prollyTableWriter struct {
|
||||
|
||||
aiCol schema.Column
|
||||
aiTracker globalstate.AutoIncrementTracker
|
||||
aiUpdate bool
|
||||
|
||||
sess WriteSession
|
||||
setter SessionRootSetter
|
||||
@@ -64,7 +63,6 @@ func (w *prollyTableWriter) Insert(ctx *sql.Context, sqlRow sql.Row) error {
|
||||
if err := w.primary.Insert(ctx, sqlRow); err != nil {
|
||||
return err
|
||||
}
|
||||
w.aiUpdate = true
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -97,28 +95,27 @@ func (w *prollyTableWriter) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.
|
||||
if err := w.primary.Update(ctx, oldRow, newRow); err != nil {
|
||||
return err
|
||||
}
|
||||
w.aiUpdate = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// NextAutoIncrementValue implements TableWriter.
|
||||
func (w *prollyTableWriter) NextAutoIncrementValue(potentialVal, tableVal interface{}) (interface{}, error) {
|
||||
return w.aiTracker.Next(w.tableName, potentialVal, tableVal)
|
||||
// GetNextAutoIncrementValue implements TableWriter.
|
||||
func (w *prollyTableWriter) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{}) (uint64, error) {
|
||||
return w.aiTracker.Next(w.tableName, insertVal)
|
||||
}
|
||||
|
||||
// SetAutoIncrementValue implements TableWriter.
|
||||
func (w *prollyTableWriter) SetAutoIncrementValue(ctx *sql.Context, val interface{}) error {
|
||||
nomsVal, err := w.aiCol.TypeInfo.ConvertValueToNomsValue(ctx, w.tbl.ValueReadWriter(), val)
|
||||
func (w *prollyTableWriter) SetAutoIncrementValue(ctx *sql.Context, val uint64) error {
|
||||
seq, err := globalstate.CoerceAutoIncrementValue(val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.tbl, err = w.tbl.SetAutoIncrementValue(ctx, nomsVal)
|
||||
// todo(andy) set here or in flush?
|
||||
w.tbl, err = w.tbl.SetAutoIncrementValue(ctx, seq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.aiTracker.Reset(w.tableName, val)
|
||||
w.aiTracker.Set(w.tableName, seq)
|
||||
|
||||
return w.flush(ctx)
|
||||
}
|
||||
@@ -129,7 +126,6 @@ func (w *prollyTableWriter) Close(ctx *sql.Context) error {
|
||||
if w.batched {
|
||||
return nil
|
||||
}
|
||||
|
||||
return w.flush(ctx)
|
||||
}
|
||||
|
||||
@@ -187,23 +183,12 @@ func (w *prollyTableWriter) table(ctx context.Context) (t *doltdb.Table, err err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if w.aiCol.AutoIncrement && w.aiUpdate {
|
||||
seq, err := w.aiTracker.Next(w.tableName, nil, nil)
|
||||
if w.aiCol.AutoIncrement {
|
||||
seq := w.aiTracker.Current(w.tableName)
|
||||
t, err = t.SetAutoIncrementValue(ctx, seq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vrw := w.tbl.ValueReadWriter()
|
||||
|
||||
v, err := w.aiCol.TypeInfo.ConvertValueToNomsValue(ctx, vrw, seq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t, err = t.SetAutoIncrementValue(ctx, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w.aiUpdate = false
|
||||
}
|
||||
|
||||
return t, nil
|
||||
@@ -289,7 +274,9 @@ func (m prollyIndexWriter) Map(ctx context.Context) (prolly.Map, error) {
|
||||
func (m prollyIndexWriter) Insert(ctx *sql.Context, sqlRow sql.Row) error {
|
||||
for to := range m.keyMap {
|
||||
from := m.keyMap.MapOrdinal(to)
|
||||
index.PutField(m.keyBld, to, sqlRow[from])
|
||||
if err := index.PutField(m.keyBld, to, sqlRow[from]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
k := m.keyBld.Build(sharePool)
|
||||
|
||||
@@ -302,7 +289,9 @@ func (m prollyIndexWriter) Insert(ctx *sql.Context, sqlRow sql.Row) error {
|
||||
|
||||
for to := range m.valMap {
|
||||
from := m.valMap.MapOrdinal(to)
|
||||
index.PutField(m.valBld, to, sqlRow[from])
|
||||
if err = index.PutField(m.valBld, to, sqlRow[from]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
v := m.valBld.Build(sharePool)
|
||||
|
||||
@@ -312,7 +301,9 @@ func (m prollyIndexWriter) Insert(ctx *sql.Context, sqlRow sql.Row) error {
|
||||
func (m prollyIndexWriter) Delete(ctx *sql.Context, sqlRow sql.Row) error {
|
||||
for to := range m.keyMap {
|
||||
from := m.keyMap.MapOrdinal(to)
|
||||
index.PutField(m.keyBld, to, sqlRow[from])
|
||||
if err := index.PutField(m.keyBld, to, sqlRow[from]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
k := m.keyBld.Build(sharePool)
|
||||
|
||||
@@ -322,7 +313,9 @@ func (m prollyIndexWriter) Delete(ctx *sql.Context, sqlRow sql.Row) error {
|
||||
func (m prollyIndexWriter) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Row) error {
|
||||
for to := range m.keyMap {
|
||||
from := m.keyMap.MapOrdinal(to)
|
||||
index.PutField(m.keyBld, to, oldRow[from])
|
||||
if err := index.PutField(m.keyBld, to, oldRow[from]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
oldKey := m.keyBld.Build(sharePool)
|
||||
|
||||
@@ -334,7 +327,9 @@ func (m prollyIndexWriter) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.R
|
||||
|
||||
for to := range m.keyMap {
|
||||
from := m.keyMap.MapOrdinal(to)
|
||||
index.PutField(m.keyBld, to, newRow[from])
|
||||
if err := index.PutField(m.keyBld, to, newRow[from]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
newKey := m.keyBld.Build(sharePool)
|
||||
|
||||
@@ -347,7 +342,9 @@ func (m prollyIndexWriter) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.R
|
||||
|
||||
for to := range m.valMap {
|
||||
from := m.valMap.MapOrdinal(to)
|
||||
index.PutField(m.valBld, to, newRow[from])
|
||||
if err = index.PutField(m.valBld, to, newRow[from]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
v := m.valBld.Build(sharePool)
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
@@ -31,13 +32,14 @@ import (
|
||||
type prollyWriteSession struct {
|
||||
workingSet *doltdb.WorkingSet
|
||||
tables map[string]*prollyTableWriter
|
||||
tracker globalstate.AutoIncrementTracker
|
||||
mut *sync.RWMutex
|
||||
}
|
||||
|
||||
var _ WriteSession = &prollyWriteSession{}
|
||||
|
||||
// GetTableWriter implemented WriteSession.
|
||||
func (s *prollyWriteSession) GetTableWriter(ctx context.Context, table string, database string, ait globalstate.AutoIncrementTracker, setter SessionRootSetter, batched bool) (TableWriter, error) {
|
||||
func (s *prollyWriteSession) GetTableWriter(ctx context.Context, table, db string, setter SessionRootSetter, batched bool) (TableWriter, error) {
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
|
||||
@@ -76,13 +78,13 @@ func (s *prollyWriteSession) GetTableWriter(ctx context.Context, table string, d
|
||||
|
||||
twr := &prollyTableWriter{
|
||||
tableName: table,
|
||||
dbName: database,
|
||||
dbName: db,
|
||||
primary: pw,
|
||||
secondary: sws,
|
||||
tbl: t,
|
||||
sch: sch,
|
||||
aiCol: autoCol,
|
||||
aiTracker: ait,
|
||||
aiTracker: s.tracker,
|
||||
sess: s,
|
||||
setter: setter,
|
||||
batched: batched,
|
||||
@@ -150,6 +152,14 @@ func (s *prollyWriteSession) flush(ctx context.Context) (*doltdb.WorkingSet, err
|
||||
return err
|
||||
}
|
||||
|
||||
if schema.HasAutoIncrement(wr.sch) {
|
||||
v := s.tracker.Current(name)
|
||||
t, err = t.SetAutoIncrementValue(ctx, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
tables[name] = t
|
||||
|
||||
@@ -43,8 +43,8 @@ type sessionedTableEditor struct {
|
||||
var _ editor.TableEditor = &sessionedTableEditor{}
|
||||
|
||||
func (ste *sessionedTableEditor) InsertKeyVal(ctx context.Context, key, val types.Tuple, tagToVal map[uint64]types.Value, errFunc editor.PKDuplicateErrFunc) error {
|
||||
ste.tableEditSession.writeMutex.RLock()
|
||||
defer ste.tableEditSession.writeMutex.RUnlock()
|
||||
ste.tableEditSession.mut.RLock()
|
||||
defer ste.tableEditSession.mut.RUnlock()
|
||||
|
||||
err := ste.validateForInsert(ctx, tagToVal)
|
||||
if err != nil {
|
||||
@@ -56,8 +56,8 @@ func (ste *sessionedTableEditor) InsertKeyVal(ctx context.Context, key, val type
|
||||
}
|
||||
|
||||
func (ste *sessionedTableEditor) DeleteByKey(ctx context.Context, key types.Tuple, tagToVal map[uint64]types.Value) error {
|
||||
ste.tableEditSession.writeMutex.RLock()
|
||||
defer ste.tableEditSession.writeMutex.RUnlock()
|
||||
ste.tableEditSession.mut.RLock()
|
||||
defer ste.tableEditSession.mut.RUnlock()
|
||||
|
||||
if !ste.tableEditSession.opts.ForeignKeyChecksDisabled && len(ste.referencingTables) > 0 {
|
||||
err := ste.onDeleteHandleRowsReferencingValues(ctx, key, tagToVal)
|
||||
@@ -72,8 +72,8 @@ func (ste *sessionedTableEditor) DeleteByKey(ctx context.Context, key types.Tupl
|
||||
|
||||
// InsertRow adds the given row to the table. If the row already exists, use UpdateRow.
|
||||
func (ste *sessionedTableEditor) InsertRow(ctx context.Context, dRow row.Row, errFunc editor.PKDuplicateErrFunc) error {
|
||||
ste.tableEditSession.writeMutex.RLock()
|
||||
defer ste.tableEditSession.writeMutex.RUnlock()
|
||||
ste.tableEditSession.mut.RLock()
|
||||
defer ste.tableEditSession.mut.RUnlock()
|
||||
|
||||
dRowTaggedVals, err := dRow.TaggedValues()
|
||||
if err != nil {
|
||||
@@ -90,8 +90,8 @@ func (ste *sessionedTableEditor) InsertRow(ctx context.Context, dRow row.Row, er
|
||||
|
||||
// DeleteRow removes the given key from the table.
|
||||
func (ste *sessionedTableEditor) DeleteRow(ctx context.Context, r row.Row) error {
|
||||
ste.tableEditSession.writeMutex.RLock()
|
||||
defer ste.tableEditSession.writeMutex.RUnlock()
|
||||
ste.tableEditSession.mut.RLock()
|
||||
defer ste.tableEditSession.mut.RUnlock()
|
||||
|
||||
if !ste.tableEditSession.opts.ForeignKeyChecksDisabled && len(ste.referencingTables) > 0 {
|
||||
err := ste.handleReferencingRowsOnDelete(ctx, r)
|
||||
@@ -107,8 +107,8 @@ func (ste *sessionedTableEditor) DeleteRow(ctx context.Context, r row.Row) error
|
||||
// UpdateRow takes the current row and new row, and updates it accordingly. Any applicable rows from tables that have a
|
||||
// foreign key referencing this table will also be updated.
|
||||
func (ste *sessionedTableEditor) UpdateRow(ctx context.Context, dOldRow row.Row, dNewRow row.Row, errFunc editor.PKDuplicateErrFunc) error {
|
||||
ste.tableEditSession.writeMutex.RLock()
|
||||
defer ste.tableEditSession.writeMutex.RUnlock()
|
||||
ste.tableEditSession.mut.RLock()
|
||||
defer ste.tableEditSession.mut.RUnlock()
|
||||
|
||||
return ste.updateRow(ctx, dOldRow, dNewRow, true, errFunc)
|
||||
}
|
||||
@@ -123,15 +123,9 @@ func (ste *sessionedTableEditor) HasEdits() bool {
|
||||
return ste.tableEditor.HasEdits()
|
||||
}
|
||||
|
||||
// GetAutoIncrementValue implements TableEditor.
|
||||
func (ste *sessionedTableEditor) GetAutoIncrementValue() types.Value {
|
||||
return ste.tableEditor.GetAutoIncrementValue()
|
||||
}
|
||||
|
||||
// SetAutoIncrementValue implements TableEditor.
|
||||
func (ste *sessionedTableEditor) SetAutoIncrementValue(v types.Value) error {
|
||||
ste.dirty = true
|
||||
return ste.tableEditor.SetAutoIncrementValue(v)
|
||||
// MarkDirty implements TableEditor.
|
||||
func (ste *sessionedTableEditor) MarkDirty() {
|
||||
ste.tableEditor.MarkDirty()
|
||||
}
|
||||
|
||||
// Table implements TableEditor.
|
||||
|
||||
@@ -250,6 +250,13 @@ func (kte *keylessTableEditor) HasEdits() bool {
|
||||
return kte.dirty
|
||||
}
|
||||
|
||||
// MarkDirty implements TableEditor.
|
||||
func (kte *keylessTableEditor) MarkDirty() {
|
||||
kte.mu.Lock()
|
||||
defer kte.mu.Unlock()
|
||||
kte.dirty = true
|
||||
}
|
||||
|
||||
// GetAutoIncrementValue implements TableEditor, AUTO_INCREMENT is not yet supported for keyless tables.
|
||||
func (kte *keylessTableEditor) GetAutoIncrementValue() types.Value {
|
||||
return types.NullValue
|
||||
|
||||
@@ -62,10 +62,10 @@ type TableEditor interface {
|
||||
InsertRow(ctx context.Context, r row.Row, errFunc PKDuplicateErrFunc) error
|
||||
UpdateRow(ctx context.Context, old, new row.Row, errFunc PKDuplicateErrFunc) error
|
||||
DeleteRow(ctx context.Context, r row.Row) error
|
||||
HasEdits() bool
|
||||
|
||||
GetAutoIncrementValue() types.Value
|
||||
SetAutoIncrementValue(v types.Value) (err error)
|
||||
HasEdits() bool
|
||||
MarkDirty()
|
||||
|
||||
SetConstraintViolation(ctx context.Context, k types.LesserValuable, v types.Valuable) error
|
||||
|
||||
Table(ctx context.Context) (*doltdb.Table, error)
|
||||
@@ -128,10 +128,6 @@ type pkTableEditor struct {
|
||||
// Whenever any write operation occurs on the table editor, this is set to true for the lifetime of the editor.
|
||||
dirty uint32
|
||||
|
||||
hasAutoInc bool
|
||||
autoIncCol schema.Column
|
||||
autoIncVal types.Value
|
||||
|
||||
// This mutex blocks on each operation, so that map reads and updates are serialized
|
||||
writeMutex *sync.Mutex
|
||||
}
|
||||
@@ -165,22 +161,6 @@ func newPkTableEditor(ctx context.Context, t *doltdb.Table, tableSch schema.Sche
|
||||
te.indexEds[i] = NewIndexEditor(ctx, index, indexData, tableSch, opts)
|
||||
}
|
||||
|
||||
err = tableSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
|
||||
if col.AutoIncrement {
|
||||
te.autoIncVal, err = t.GetAutoIncrementValue(ctx)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
te.hasAutoInc = true
|
||||
te.autoIncCol = col
|
||||
return true, err
|
||||
}
|
||||
return false, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return te, nil
|
||||
}
|
||||
|
||||
@@ -446,32 +426,7 @@ func (te *pkTableEditor) insertKeyVal(ctx context.Context, keyHash hash.Hash, ke
|
||||
return err
|
||||
}
|
||||
|
||||
if te.hasAutoInc {
|
||||
insertVal, ok := tagToVal[te.autoIncCol.Tag]
|
||||
|
||||
if ok {
|
||||
var less bool
|
||||
|
||||
// float auto increment values should be rounded before comparing to the current auto increment values
|
||||
if te.autoIncVal.Kind() == types.FloatKind {
|
||||
rounded := types.Round(insertVal)
|
||||
less, err = rounded.Less(te.nbf, te.autoIncVal)
|
||||
} else {
|
||||
less, err = insertVal.Less(te.nbf, te.autoIncVal)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !less {
|
||||
te.autoIncVal = types.Round(insertVal)
|
||||
te.autoIncVal = types.Increment(te.autoIncVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
te.setDirty(true)
|
||||
te.MarkDirty()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -546,7 +501,7 @@ func (te *pkTableEditor) DeleteByKey(ctx context.Context, key types.Tuple, tagTo
|
||||
return err
|
||||
}
|
||||
|
||||
te.setDirty(true)
|
||||
te.MarkDirty()
|
||||
return te.tea.Delete(keyHash, key)
|
||||
}
|
||||
|
||||
@@ -644,7 +599,7 @@ func (te *pkTableEditor) UpdateRow(ctx context.Context, dOldRow row.Row, dNewRow
|
||||
return err
|
||||
}
|
||||
|
||||
te.setDirty(true)
|
||||
te.MarkDirty()
|
||||
|
||||
if kvp, pkExists, err := te.tea.Get(ctx, newHash, dNewKeyVal); err != nil {
|
||||
return err
|
||||
@@ -655,37 +610,14 @@ func (te *pkTableEditor) UpdateRow(ctx context.Context, dOldRow row.Row, dNewRow
|
||||
return te.tea.Insert(newHash, dNewKeyVal, dNewRowVal)
|
||||
}
|
||||
|
||||
func (te *pkTableEditor) GetAutoIncrementValue() types.Value {
|
||||
return te.autoIncVal
|
||||
}
|
||||
|
||||
func (te *pkTableEditor) SetAutoIncrementValue(v types.Value) (err error) {
|
||||
te.writeMutex.Lock()
|
||||
defer te.writeMutex.Unlock()
|
||||
|
||||
te.setDirty(true)
|
||||
te.autoIncVal = v
|
||||
te.t, err = te.t.SetAutoIncrementValue(nil, te.autoIncVal)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Table returns a Table based on the edits given, if any. If Flush() was not called prior, it will be called here.
|
||||
func (te *pkTableEditor) Table(ctx context.Context) (*doltdb.Table, error) {
|
||||
if !te.HasEdits() {
|
||||
return te.t, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if te.hasAutoInc {
|
||||
te.t, err = te.t.SetAutoIncrementValue(nil, te.autoIncVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var tbl *doltdb.Table
|
||||
err = func() error {
|
||||
err := func() error {
|
||||
te.writeMutex.Lock()
|
||||
defer te.writeMutex.Unlock()
|
||||
|
||||
@@ -816,7 +748,7 @@ func (te *pkTableEditor) SetConstraintViolation(ctx context.Context, k types.Les
|
||||
te.cvEditor = cvMap.Edit()
|
||||
}
|
||||
te.cvEditor.Set(k, v)
|
||||
te.setDirty(true)
|
||||
te.MarkDirty()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -845,13 +777,10 @@ func (te *pkTableEditor) Close(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (te *pkTableEditor) setDirty(dirty bool) {
|
||||
var val uint32
|
||||
if dirty {
|
||||
val = 1
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&te.dirty, val)
|
||||
// MarkDirty implements TableEditor.
|
||||
func (te *pkTableEditor) MarkDirty() {
|
||||
dirty := uint32(1)
|
||||
atomic.StoreUint32(&te.dirty, dirty)
|
||||
}
|
||||
|
||||
// hasEdits returns whether the table editor has had any successful write operations. This does not track whether the
|
||||
|
||||
@@ -13,12 +13,12 @@ teardown() {
|
||||
@test "sql-show: show table status on auto-increment table" {
|
||||
dolt sql -q "CREATE TABLE test(pk int NOT NULL AUTO_INCREMENT, c1 int, PRIMARY KEY (pk))"
|
||||
|
||||
run dolt sql -q "show table status where \`Auto_increment\`=1;"
|
||||
run dolt sql -q "show table status;"
|
||||
[ "$status" -eq 0 ]
|
||||
[[ "$output" =~ "test" ]] || false
|
||||
|
||||
dolt sql -q "INSERT INTO test (c1) VALUES (0)"
|
||||
run dolt sql -q "show table status where \`Auto_increment\`=2;"
|
||||
run dolt sql -q "show table status;"
|
||||
[ "$status" -eq 0 ]
|
||||
[[ "$output" =~ "test" ]] || false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user