diff --git a/go/go.mod b/go/go.mod index c4899f75ee..c1a95e73f8 100644 --- a/go/go.mod +++ b/go/go.mod @@ -68,7 +68,7 @@ require ( ) require ( - github.com/dolthub/go-mysql-server v0.11.1-0.20220324183628-b0a3bc9c2c2f + github.com/dolthub/go-mysql-server v0.11.1-0.20220325203039-101310d04210 github.com/google/flatbuffers v2.0.5+incompatible github.com/gosuri/uilive v0.0.4 github.com/kch42/buzhash v0.0.0-20160816060738-9bdec3dec7c6 diff --git a/go/go.sum b/go/go.sum index 98cee619d8..0129ae9062 100755 --- a/go/go.sum +++ b/go/go.sum @@ -170,8 +170,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= -github.com/dolthub/go-mysql-server v0.11.1-0.20220324183628-b0a3bc9c2c2f h1:F1hhtWcel9an0/Ohbdg0gw6gSlgwWBxlnrSl7Jyi/2M= -github.com/dolthub/go-mysql-server v0.11.1-0.20220324183628-b0a3bc9c2c2f/go.mod h1:1Sq4rszjP6AW7AJaF9xfycWexkNKIwkkOuYnoS5XcPg= +github.com/dolthub/go-mysql-server v0.11.1-0.20220325203039-101310d04210 h1:sZiIQR3mhQkPNb/9TbEg9W+nF4bcHer2ubPdl3iyKjM= +github.com/dolthub/go-mysql-server v0.11.1-0.20220325203039-101310d04210/go.mod h1:1Sq4rszjP6AW7AJaF9xfycWexkNKIwkkOuYnoS5XcPg= github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371 h1:oyPHJlzumKta1vnOQqUnfdz+pk3EmnHS3Nd0cCT0I2g= github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371/go.mod h1:dhGBqcCEfK5kuFmeO5+WOx3hqc1k3M29c1oS/R7N4ms= github.com/dolthub/jsonpath v0.0.0-20210609232853-d49537a30474 h1:xTrR+l5l+1Lfq0NvhiEsctylXinUMFhhsqaEcl414p8= diff --git a/go/libraries/doltcore/doltdb/durable/table.go b/go/libraries/doltcore/doltdb/durable/table.go index 1ee4f28e0f..c47be25f0d 100644 --- a/go/libraries/doltcore/doltdb/durable/table.go +++ b/go/libraries/doltcore/doltdb/durable/table.go @@ -87,9 +87,9 @@ type Table interface { SetConstraintViolations(ctx context.Context, violations types.Map) (Table, error) // GetAutoIncrement returns the AUTO_INCREMENT sequence value for this table. - GetAutoIncrement(ctx context.Context) (types.Value, error) + GetAutoIncrement(ctx context.Context) (uint64, error) // SetAutoIncrement sets the AUTO_INCREMENT sequence value for this table. - SetAutoIncrement(ctx context.Context, val types.Value) (Table, error) + SetAutoIncrement(ctx context.Context, val uint64) (Table, error) } type nomsTable struct { @@ -140,7 +140,7 @@ func NewTable(ctx context.Context, vrw types.ValueReadWriter, sch schema.Schema, indexesKey: indexesRef, } - if autoIncVal != nil { + if schema.HasAutoIncrement(sch) && autoIncVal != nil { sd[autoIncrementKey] = autoIncVal } @@ -461,53 +461,36 @@ func (t nomsTable) SetConstraintViolations(ctx context.Context, violationsMap ty } // GetAutoIncrement implements Table. -func (t nomsTable) GetAutoIncrement(ctx context.Context) (types.Value, error) { +func (t nomsTable) GetAutoIncrement(ctx context.Context) (uint64, error) { val, ok, err := t.tableStruct.MaybeGet(autoIncrementKey) if err != nil { - return nil, err + return 0, err } - if ok { - return val, nil + if !ok { + return 1, nil } - sch, err := t.GetSchema(ctx) - if err != nil { - return nil, err - } - - kind := types.UnknownKind - _ = sch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { - if col.AutoIncrement { - kind = col.Kind - stop = true - } - return - }) - switch kind { - case types.IntKind: - return types.Int(1), nil - case types.UintKind: - return types.Uint(1), nil - case types.FloatKind: - return types.Float(1), nil + // older versions might have serialized auto-increment + // value as types.Int or types.Float. + switch t := val.(type) { + case types.Int: + return uint64(t), nil + case types.Uint: + return uint64(t), nil + case types.Float: + return uint64(t), nil default: - return nil, ErrUnknownAutoIncrementValue + return 0, ErrUnknownAutoIncrementValue } } // SetAutoIncrement implements Table. -func (t nomsTable) SetAutoIncrement(ctx context.Context, val types.Value) (Table, error) { - switch val.(type) { - case types.Int, types.Uint, types.Float: - st, err := t.tableStruct.Set(autoIncrementKey, val) - if err != nil { - return nil, err - } - return nomsTable{t.vrw, st}, nil - - default: - return nil, fmt.Errorf("cannot set auto increment to non-numeric value") +func (t nomsTable) SetAutoIncrement(ctx context.Context, val uint64) (Table, error) { + st, err := t.tableStruct.Set(autoIncrementKey, types.Uint(val)) + if err != nil { + return nil, err } + return nomsTable{t.vrw, st}, nil } func refFromNomsValue(ctx context.Context, vrw types.ValueReadWriter, val types.Value) (types.Ref, error) { diff --git a/go/libraries/doltcore/doltdb/table.go b/go/libraries/doltcore/doltdb/table.go index d0bf391b8e..56ee1d6eb0 100644 --- a/go/libraries/doltcore/doltdb/table.go +++ b/go/libraries/doltcore/doltdb/table.go @@ -455,12 +455,12 @@ func (t *Table) VerifyIndexRowData(ctx context.Context, indexName string) error } // GetAutoIncrementValue returns the current AUTO_INCREMENT value for this table. -func (t *Table) GetAutoIncrementValue(ctx context.Context) (types.Value, error) { +func (t *Table) GetAutoIncrementValue(ctx context.Context) (uint64, error) { return t.table.GetAutoIncrement(ctx) } // SetAutoIncrementValue sets the current AUTO_INCREMENT value for this table. -func (t *Table) SetAutoIncrementValue(ctx context.Context, val types.Value) (*Table, error) { +func (t *Table) SetAutoIncrementValue(ctx context.Context, val uint64) (*Table, error) { table, err := t.table.SetAutoIncrement(ctx, val) if err != nil { return nil, err diff --git a/go/libraries/doltcore/merge/merge.go b/go/libraries/doltcore/merge/merge.go index 98e66ec7e1..727f93603c 100644 --- a/go/libraries/doltcore/merge/merge.go +++ b/go/libraries/doltcore/merge/merge.go @@ -786,14 +786,7 @@ func mergeAutoIncrementValues(ctx context.Context, tbl, otherTbl, resultTbl *dol if err != nil { return nil, err } - auto := false - _ = sch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { - if col.AutoIncrement { - auto, stop = true, true - } - return - }) - if !auto { + if !schema.HasAutoIncrement(sch) { return resultTbl, nil } @@ -805,14 +798,10 @@ func mergeAutoIncrementValues(ctx context.Context, tbl, otherTbl, resultTbl *dol if err != nil { return nil, err } - less, err := autoVal.Less(tbl.Format(), mergeAutoVal) - if err != nil { - return nil, err - } - if less { + if autoVal < mergeAutoVal { autoVal = mergeAutoVal } - return resultTbl.SetAutoIncrementValue(nil, autoVal) + return resultTbl.SetAutoIncrementValue(ctx, autoVal) } func MergeCommits(ctx context.Context, commit, mergeCommit *doltdb.Commit, opts editor.Options) (*doltdb.RootValue, map[string]*MergeStats, error) { diff --git a/go/libraries/doltcore/schema/alterschema/modifycolumn.go b/go/libraries/doltcore/schema/alterschema/modifycolumn.go index b1fe3b6bb2..b1a670a427 100644 --- a/go/libraries/doltcore/schema/alterschema/modifycolumn.go +++ b/go/libraries/doltcore/schema/alterschema/modifycolumn.go @@ -149,7 +149,7 @@ func updateTableWithModifiedColumn(ctx context.Context, tbl *doltdb.Table, oldSc return nil, err } - var autoVal types.Value + var autoVal uint64 if schema.HasAutoIncrement(newSch) && schema.HasAutoIncrement(oldSch) { autoVal, err = tbl.GetAutoIncrementValue(ctx) if err != nil { @@ -157,7 +157,7 @@ func updateTableWithModifiedColumn(ctx context.Context, tbl *doltdb.Table, oldSc } } - updatedTable, err := doltdb.NewNomsTable(ctx, vrw, newSch, rowData, indexData, autoVal) + updatedTable, err := doltdb.NewNomsTable(ctx, vrw, newSch, rowData, indexData, types.Uint(autoVal)) if err != nil { return nil, err } diff --git a/go/libraries/doltcore/sqle/database.go b/go/libraries/doltcore/sqle/database.go index e9a5680fbc..7228d51ecd 100644 --- a/go/libraries/doltcore/sqle/database.go +++ b/go/libraries/doltcore/sqle/database.go @@ -82,15 +82,11 @@ func DbsAsDSQLDBs(dbs []sql.Database) []SqlDatabase { // Database implements sql.Database for a dolt DB. type Database struct { - name string - ddb *doltdb.DoltDB - rsr env.RepoStateReader - rsw env.RepoStateWriter - drw env.DocsReadWriter - - // todo: needs a major refactor to - // correctly handle persisted sequences - // that must be coordinated across txs + name string + ddb *doltdb.DoltDB + rsr env.RepoStateReader + rsw env.RepoStateWriter + drw env.DocsReadWriter gs globalstate.GlobalState editOpts editor.Options } @@ -167,6 +163,7 @@ var _ sql.TableRenamer = Database{} var _ sql.TriggerDatabase = Database{} var _ sql.StoredProcedureDatabase = Database{} var _ sql.TransactionDatabase = Database{} +var _ globalstate.StateProvider = Database{} // NewDatabase returns a new dolt database to use in queries. func NewDatabase(name string, dbData env.DbData, editOpts editor.Options) Database { @@ -259,6 +256,10 @@ func (db Database) DbData() env.DbData { } } +func (db Database) GetGlobalState() globalstate.GlobalState { + return db.gs +} + // GetTableInsensitive is used when resolving tables in queries. It returns a best-effort case-insensitive match for // the table name given. func (db Database) GetTableInsensitive(ctx *sql.Context, tblName string) (sql.Table, bool, error) { @@ -585,6 +586,18 @@ func (db Database) GetRoot(ctx *sql.Context) (*doltdb.RootValue, error) { return dbState.GetRoots().Working, nil } +func (db Database) GetWorkingSet(ctx *sql.Context) (*doltdb.WorkingSet, error) { + sess := dsess.DSessFromSess(ctx.Session) + dbState, ok, err := sess.LookupDbState(ctx, db.Name()) + if err != nil { + return nil, err + } + if !ok { + return nil, fmt.Errorf("no root value found in session") + } + return dbState.WorkingSet, nil +} + // SetRoot should typically be called on the Session, which is where this state lives. But it's available here as a // convenience. func (db Database) SetRoot(ctx *sql.Context, newRoot *doltdb.RootValue) error { @@ -660,7 +673,10 @@ func (db Database) dropTableFromAiTracker(ctx *sql.Context, tableName string) er return err } - ait := db.gs.GetAutoIncrementTracker(ws.Ref()) + ait, err := db.gs.GetAutoIncrementTracker(ctx, ws) + if err != nil { + return err + } ait.DropTable(tableName) return nil @@ -681,10 +697,11 @@ func (db Database) CreateTable(ctx *sql.Context, tableName string, sch sql.Prima // Unlike the exported version CreateTable, createSqlTable doesn't enforce any table name checks. func (db Database) createSqlTable(ctx *sql.Context, tableName string, sch sql.PrimaryKeySchema) error { - root, err := db.GetRoot(ctx) + ws, err := db.GetWorkingSet(ctx) if err != nil { return err } + root := ws.WorkingRoot() if exists, err := root.HasTable(ctx, tableName); err != nil { return err @@ -707,6 +724,14 @@ func (db Database) createSqlTable(ctx *sql.Context, tableName string, sch sql.Pr return schema.ErrUsingSpatialKey.New(tableName) } + if schema.HasAutoIncrement(doltSch) { + ait, err := db.gs.GetAutoIncrementTracker(ctx, ws) + if err != nil { + return err + } + ait.AddNewTable(tableName) + } + return db.createDoltTable(ctx, tableName, root, doltSch) } diff --git a/go/libraries/doltcore/sqle/database_provider.go b/go/libraries/doltcore/sqle/database_provider.go index 7e9f94df63..f73ba42bf1 100644 --- a/go/libraries/doltcore/sqle/database_provider.go +++ b/go/libraries/doltcore/sqle/database_provider.go @@ -485,7 +485,6 @@ func dbRevisionForCommit(ctx context.Context, srcDb Database, revSpec string) (R rsw: srcDb.DbData().Rsw, rsr: srcDb.DbData().Rsr, drw: srcDb.DbData().Drw, - gs: nil, editOpts: srcDb.editOpts, }} init := dsess.InitialDbState{ diff --git a/go/libraries/doltcore/sqle/dsess/database_session_state.go b/go/libraries/doltcore/sqle/dsess/database_session_state.go index d9044e7d59..b8d5ca0135 100644 --- a/go/libraries/doltcore/sqle/dsess/database_session_state.go +++ b/go/libraries/doltcore/sqle/dsess/database_session_state.go @@ -17,11 +17,11 @@ package dsess import ( "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer" - "github.com/dolthub/dolt/go/libraries/doltcore/table/editor" - "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/env" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer" + "github.com/dolthub/dolt/go/libraries/doltcore/table/editor" ) type InitialDbState struct { @@ -48,6 +48,7 @@ type DatabaseSessionState struct { WorkingSet *doltdb.WorkingSet dbData env.DbData WriteSession writer.WriteSession + globalState globalstate.GlobalState readOnly bool dirty bool readReplica *env.Remote diff --git a/go/libraries/doltcore/sqle/dsess/session.go b/go/libraries/doltcore/sqle/dsess/session.go index 347ae0e565..0482c40313 100644 --- a/go/libraries/doltcore/sqle/dsess/session.go +++ b/go/libraries/doltcore/sqle/dsess/session.go @@ -27,6 +27,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" "github.com/dolthub/dolt/go/libraries/doltcore/ref" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer" "github.com/dolthub/dolt/go/libraries/doltcore/table/editor" "github.com/dolthub/dolt/go/libraries/utils/config" @@ -700,7 +701,11 @@ func (sess *Session) SwitchWorkingSet( // make a fresh WriteSession, discard existing WriteSession opts := sessionState.WriteSession.GetOptions() nbf := ws.WorkingRoot().VRW().Format() - sessionState.WriteSession = writer.NewWriteSession(nbf, ws, opts) + tracker, err := sessionState.globalState.GetAutoIncrementTracker(ctx, ws) + if err != nil { + return err + } + sessionState.WriteSession = writer.NewWriteSession(nbf, ws, tracker, opts) // After switching to a new working set, we are by definition clean sessionState.dirty = false @@ -830,17 +835,26 @@ func (sess *Session) AddDB(ctx *sql.Context, dbState InitialDbState) error { nbf := sessionState.dbData.Ddb.Format() editOpts := db.(interface{ EditOptions() editor.Options }).EditOptions() + stateProvider, ok := db.(globalstate.StateProvider) + if !ok { + return fmt.Errorf("database does not contain global state store") + } + sessionState.globalState = stateProvider.GetGlobalState() + // WorkingSet is nil in the case of a read only, detached head DB if dbState.Err != nil { sessionState.Err = dbState.Err } else if dbState.WorkingSet != nil { sessionState.WorkingSet = dbState.WorkingSet - sessionState.WriteSession = writer.NewWriteSession(nbf, sessionState.WorkingSet, editOpts) - err := sess.SetWorkingSet(ctx, db.Name(), dbState.WorkingSet) + tracker, err := sessionState.globalState.GetAutoIncrementTracker(ctx, sessionState.WorkingSet) if err != nil { return err } + sessionState.WriteSession = writer.NewWriteSession(nbf, sessionState.WorkingSet, tracker, editOpts) + if err = sess.SetWorkingSet(ctx, db.Name(), dbState.WorkingSet); err != nil { + return err + } } else { headRoot, err := dbState.HeadCommit.GetRootValue() diff --git a/go/libraries/doltcore/sqle/globalstate/auto_increment_tracker.go b/go/libraries/doltcore/sqle/globalstate/auto_increment_tracker.go index 66d1c0e3c8..cdec60ba43 100644 --- a/go/libraries/doltcore/sqle/globalstate/auto_increment_tracker.go +++ b/go/libraries/doltcore/sqle/globalstate/auto_increment_tracker.go @@ -15,146 +15,118 @@ package globalstate import ( - "fmt" + "context" + "math" "sync" + + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/libraries/doltcore/ref" + "github.com/dolthub/dolt/go/libraries/doltcore/schema" ) -type AutoIncrementTracker interface { - // Next returns the next auto increment value to be used by a table. If a table is not initialized in the counter - // it will used the value stored in disk. - Next(tableName string, insertVal interface{}, diskVal interface{}) (interface{}, error) - // Reset resets the auto increment tracker value for a table. Typically used in truncate statements. - Reset(tableName string, val interface{}) - // DropTable removes a table from the autoincrement tracker. - DropTable(tableName string) -} - -// AutoIncrementTracker is a global map that tracks which auto increment keys have been given for each table. At runtime -// it hands out the current key. -func NewAutoIncrementTracker() AutoIncrementTracker { - return &autoIncrementTracker{ - valuePerTable: make(map[string]interface{}), +// CoerceAutoIncrementValue converts |val| into an AUTO_INCREMENT sequence value +func CoerceAutoIncrementValue(val interface{}) (uint64, error) { + switch typ := val.(type) { + case float32: + val = math.Round(float64(typ)) + case float64: + val = math.Round(typ) } + + var err error + val, err = sql.Uint64.Convert(val) + if err != nil { + return 0, err + } + if val == nil || val == uint64(0) { + return 0, nil + } + return val.(uint64), nil } -type autoIncrementTracker struct { - valuePerTable map[string]interface{} - mu sync.Mutex +func NewAutoIncrementTracker(ctx context.Context, ws *doltdb.WorkingSet) (ait AutoIncrementTracker, err error) { + ait = AutoIncrementTracker{ + wsRef: ws.Ref(), + sequences: make(map[string]uint64), + mu: &sync.Mutex{}, + } + + // collect auto increment values + err = ws.WorkingRoot().IterTables(ctx, func(name string, table *doltdb.Table, sch schema.Schema) (bool, error) { + ok := schema.HasAutoIncrement(sch) + if !ok { + return false, nil + } + seq, err := table.GetAutoIncrementValue(ctx) + if err != nil { + return true, err + } + ait.sequences[name] = seq + return false, nil + }) + return ait, err } -var _ AutoIncrementTracker = (*autoIncrementTracker)(nil) +type AutoIncrementTracker struct { + wsRef ref.WorkingSetRef + sequences map[string]uint64 + mu *sync.Mutex +} -func (a *autoIncrementTracker) Next(tableName string, insertVal interface{}, diskVal interface{}) (interface{}, error) { +// Current returns the current AUTO_INCREMENT value for |tableName|. +func (a AutoIncrementTracker) Current(tableName string) uint64 { + a.mu.Lock() + defer a.mu.Unlock() + return a.sequences[tableName] +} + +// Next returns the next AUTO_INCREMENT value for |tableName|, considering the provided |insertVal|. +func (a AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, error) { a.mu.Lock() defer a.mu.Unlock() - diskVal = valOrZero(diskVal) - - // Case 0: Just use the value passed in. - potential, ok := a.valuePerTable[tableName] - if !ok { - // Use the disk val if the table has not been initialized yet. - potential = diskVal - } - - // Case 1: Disk Val is greater. This is useful for updating the tracker when a merge occurs. - // TODO: This is a bit of a hack. The correct solution is to plumb this tracker through the merge logic. - diskValGreater, err := geq(valOrZero(diskVal), valOrZero(a.valuePerTable[tableName])) - if err != nil { - return nil, err - } - - if diskValGreater { - potential = diskVal - } - - // Case 2: Overwrite anything if an insert val is passed. - if insertVal != nil { - potential = insertVal - } - - // update the table only if val >= existing - isGeq, err := geq(valOrZero(potential), valOrZero(a.valuePerTable[tableName])) + given, err := CoerceAutoIncrementValue(insertVal) if err != nil { return 0, err } - if isGeq { - val, err := convertIntTypeToUint(potential) - if err != nil { - return val, err - } + curr := a.sequences[tbl] - a.valuePerTable[tableName] = val + 1 + if given == 0 { + // |given| is 0 or NULL + a.sequences[tbl]++ + return curr, nil } - return potential, nil + if given >= curr { + a.sequences[tbl] = given + a.sequences[tbl]++ + return given, nil + } + + // |given| < curr + return given, nil } -func (a *autoIncrementTracker) Reset(tableName string, val interface{}) { +// Set sets the current AUTO_INCREMENT value for |tableName|. +func (a AutoIncrementTracker) Set(tableName string, val uint64) { a.mu.Lock() defer a.mu.Unlock() - - a.valuePerTable[tableName] = val + a.sequences[tableName] = val } -func (a *autoIncrementTracker) DropTable(tableName string) { +// AddNewTable adds |tablename| to the AutoIncrementTracker. +func (a AutoIncrementTracker) AddNewTable(tableName string) { a.mu.Lock() defer a.mu.Unlock() - - delete(a.valuePerTable, tableName) + a.sequences[tableName] = uint64(1) } -// Helper method that sets nil values to 0 for clarity purposes -func valOrZero(val interface{}) interface{} { - if val == nil { - return 0 - } - - return val -} - -func geq(val1 interface{}, val2 interface{}) (bool, error) { - v1, err := convertIntTypeToUint(val1) - if err != nil { - return false, err - } - - v2, err := convertIntTypeToUint(val2) - if err != nil { - return false, err - } - - return v1 >= v2, nil -} - -func convertIntTypeToUint(val interface{}) (uint64, error) { - switch t := val.(type) { - case int: - return uint64(t), nil - case int8: - return uint64(t), nil - case int16: - return uint64(t), nil - case int32: - return uint64(t), nil - case int64: - return uint64(t), nil - case uint: - return uint64(t), nil - case uint8: - return uint64(t), nil - case uint16: - return uint64(t), nil - case uint32: - return uint64(t), nil - case uint64: - return t, nil - case float32: - return uint64(t), nil - case float64: - return uint64(t), nil - default: - return 0, fmt.Errorf("error: auto increment is not a numeric type") - } +// DropTable drops |tablename| from the AutoIncrementTracker. +func (a AutoIncrementTracker) DropTable(tableName string) { + a.mu.Lock() + defer a.mu.Unlock() + delete(a.sequences, tableName) } diff --git a/go/libraries/doltcore/sqle/globalstate/auto_increment_tracker_test.go b/go/libraries/doltcore/sqle/globalstate/auto_increment_tracker_test.go index 0c6194a491..5caf9079ef 100644 --- a/go/libraries/doltcore/sqle/globalstate/auto_increment_tracker_test.go +++ b/go/libraries/doltcore/sqle/globalstate/auto_increment_tracker_test.go @@ -15,42 +15,58 @@ package globalstate import ( - "sync" + "fmt" "testing" - "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" ) -func TestNextHasNoRepeats(t *testing.T) { - var allVals sync.Map - aiTracker := NewAutoIncrementTracker() - - for i := 0; i < 100; i++ { - go func() { - for j := 0; j < 10; j++ { - nxt, err := aiTracker.Next("test", nil, 1) - require.NoError(t, err) - - val, err := convertIntTypeToUint(nxt) - require.NoError(t, err) - - current, ok := allVals.Load(val) - if !ok { - allVals.Store(val, 1) - } else { - asUint, _ := convertIntTypeToUint(current) - allVals.Store(val, asUint+1) - } - } - }() +func TestCoerceAutoIncrementValue(t *testing.T) { + tests := []struct { + val interface{} + exp uint64 + err bool + }{ + { + val: nil, + exp: uint64(0), + }, + { + val: int32(0), + exp: uint64(0), + }, + { + val: int32(1), + exp: uint64(1), + }, + { + val: uint32(1), + exp: uint64(1), + }, + { + val: float32(1), + exp: uint64(1), + }, + { + val: float32(1.1), + exp: uint64(1), + }, + { + val: float32(1.9), + exp: uint64(2), + }, } - // Make sure each key was called once - allVals.Range(func(key, value interface{}) bool { - asUint, _ := convertIntTypeToUint(value) - - require.Equal(t, uint64(1), asUint) - - return true - }) + for _, test := range tests { + name := fmt.Sprintf("Coerce %v to %v", test.val, test.exp) + t.Run(name, func(t *testing.T) { + act, err := CoerceAutoIncrementValue(test.val) + if test.err { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.exp, act) + }) + } } diff --git a/go/libraries/doltcore/sqle/globalstate/global_state.go b/go/libraries/doltcore/sqle/globalstate/global_state.go index e7f176a9a3..4852ade316 100644 --- a/go/libraries/doltcore/sqle/globalstate/global_state.go +++ b/go/libraries/doltcore/sqle/globalstate/global_state.go @@ -15,36 +15,45 @@ package globalstate import ( + "context" "sync" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/libraries/doltcore/ref" ) -type GlobalState interface { - GetAutoIncrementTracker(wsref ref.WorkingSetRef) AutoIncrementTracker +type StateProvider interface { + GetGlobalState() GlobalState } func NewGlobalStateStore() GlobalState { - return &globalStateImpl{ + return GlobalState{ trackerMap: make(map[ref.WorkingSetRef]AutoIncrementTracker), + mu: &sync.Mutex{}, } } -type globalStateImpl struct { +type GlobalState struct { trackerMap map[ref.WorkingSetRef]AutoIncrementTracker - mu sync.Mutex + mu *sync.Mutex } -var _ GlobalState = (*globalStateImpl)(nil) - -func (g *globalStateImpl) GetAutoIncrementTracker(wsref ref.WorkingSetRef) AutoIncrementTracker { +func (g GlobalState) GetAutoIncrementTracker(ctx context.Context, ws *doltdb.WorkingSet) (AutoIncrementTracker, error) { g.mu.Lock() defer g.mu.Unlock() - _, ok := g.trackerMap[wsref] - if !ok { - g.trackerMap[wsref] = NewAutoIncrementTracker() + ait, ok := g.trackerMap[ws.Ref()] + if ok { + return ait, nil } - return g.trackerMap[wsref] + var err error + ait, err = NewAutoIncrementTracker(ctx, ws) + if err != nil { + return AutoIncrementTracker{}, err + } + g.trackerMap[ws.Ref()] = ait + + return ait, nil } diff --git a/go/libraries/doltcore/sqle/sqlutil/static_errors.go b/go/libraries/doltcore/sqle/sqlutil/static_errors.go index c33feeeeea..45a9eb0aca 100644 --- a/go/libraries/doltcore/sqle/sqlutil/static_errors.go +++ b/go/libraries/doltcore/sqle/sqlutil/static_errors.go @@ -72,7 +72,7 @@ func (e *StaticErrorEditor) Update(*sql.Context, sql.Row, sql.Row) error { return e.err } -func (e *StaticErrorEditor) SetAutoIncrementValue(*sql.Context, interface{}) error { +func (e *StaticErrorEditor) SetAutoIncrementValue(*sql.Context, uint64) error { return e.err } diff --git a/go/libraries/doltcore/sqle/tables.go b/go/libraries/doltcore/sqle/tables.go index 8e235d75c4..0c1b41eafc 100644 --- a/go/libraries/doltcore/sqle/tables.go +++ b/go/libraries/doltcore/sqle/tables.go @@ -35,6 +35,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/schema/alterschema" "github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/index" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer" @@ -196,12 +197,7 @@ func (t *DoltTable) GetAutoIncrementValue(ctx *sql.Context) (interface{}, error) if err != nil { return nil, err } - - val, err := table.GetAutoIncrementValue(ctx) - if err != nil { - return nil, err - } - return t.autoIncCol.TypeInfo.ConvertNomsValueToValue(val) + return table.GetAutoIncrementValue(ctx) } // Name returns the name of the table. @@ -415,19 +411,13 @@ func (t *WritableDoltTable) getTableEditor(ctx *sql.Context) (ed writer.TableWri } } - ws, err := ds.WorkingSet(ctx, t.db.name) - if err != nil { - return nil, err - } - ait := t.db.gs.GetAutoIncrementTracker(ws.Ref()) - state, _, err := ds.LookupDbState(ctx, t.db.name) if err != nil { return nil, err } setter := ds.SetRoot - ed, err = state.WriteSession.GetTableWriter(ctx, t.tableName, t.db.Name(), ait, setter, batched) + ed, err = state.WriteSession.GetTableWriter(ctx, t.tableName, t.db.Name(), setter, batched) if err != nil { return nil, err @@ -542,22 +532,17 @@ func (t *WritableDoltTable) PeekNextAutoIncrementValue(ctx *sql.Context) (interf } // GetNextAutoIncrementValue implements sql.AutoIncrementTable -func (t *WritableDoltTable) GetNextAutoIncrementValue(ctx *sql.Context, potentialVal interface{}) (interface{}, error) { +func (t *WritableDoltTable) GetNextAutoIncrementValue(ctx *sql.Context, potentialVal interface{}) (uint64, error) { if !t.autoIncCol.AutoIncrement { - return nil, sql.ErrNoAutoIncrementCol + return 0, sql.ErrNoAutoIncrementCol } ed, err := t.getTableEditor(ctx) if err != nil { - return nil, err + return 0, err } - tableVal, err := t.getTableAutoIncrementValue(ctx) - if err != nil { - return nil, err - } - - return ed.NextAutoIncrementValue(potentialVal, tableVal) + return ed.GetNextAutoIncrementValue(ctx, potentialVal) } func (t *WritableDoltTable) getTableAutoIncrementValue(ctx *sql.Context) (interface{}, error) { @@ -887,6 +872,18 @@ func (t *AlterableDoltTable) AddColumn(ctx *sql.Context, column *sql.Column, ord return err } + if column.AutoIncrement { + ws, err := t.db.GetWorkingSet(ctx) + if err != nil { + return err + } + ait, err := t.db.gs.GetAutoIncrementTracker(ctx, ws) + if err != nil { + return err + } + ait.AddNewTable(t.tableName) + } + newRoot, err := root.PutTable(ctx, t.tableName, updatedTable) if err != nil { return err @@ -1042,10 +1039,11 @@ func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, c return nil } - root, err := t.getRoot(ctx) + ws, err := t.db.GetWorkingSet(ctx) if err != nil { return err } + root := ws.WorkingRoot() table, _, err := root.GetTable(ctx, t.tableName) if err != nil { @@ -1118,10 +1116,6 @@ func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, c return err } - initialValue := column.Type.Zero() - - colIdx := updatedSch.GetAllCols().IndexOf(columnName) - rowData, err := updatedTable.GetRowData(ctx) if err != nil { return err @@ -1134,6 +1128,9 @@ func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, c return err } + initialValue := column.Type.Zero() + colIdx := updatedSch.GetAllCols().IndexOf(columnName) + for { r, err := rowIter.Next(ctx) if err == io.EOF { @@ -1146,23 +1143,28 @@ func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, c if err != nil { return err } - if cmp < 0 { initialValue = r[colIdx] } } - initialValNoms, err := col.TypeInfo.ConvertValueToNomsValue(ctx, root.VRW(), initialValue) + seq, err := globalstate.CoerceAutoIncrementValue(initialValue) + if err != nil { + return err + } + seq++ + + updatedTable, err = updatedTable.SetAutoIncrementValue(ctx, seq) if err != nil { return err } - initialValNoms = increment(initialValNoms) - - updatedTable, err = updatedTable.SetAutoIncrementValue(ctx, initialValNoms) + ait, err := t.db.gs.GetAutoIncrementTracker(ctx, ws) if err != nil { return err } + ait.AddNewTable(t.tableName) + ait.Set(t.tableName, seq) } newRoot, err := root.PutTable(ctx, t.tableName, updatedTable) diff --git a/go/libraries/doltcore/sqle/writer/noms_table_writer.go b/go/libraries/doltcore/sqle/writer/noms_table_writer.go index 91d03be7ea..c83989bb01 100644 --- a/go/libraries/doltcore/sqle/writer/noms_table_writer.go +++ b/go/libraries/doltcore/sqle/writer/noms_table_writer.go @@ -35,8 +35,7 @@ type TableWriter interface { sql.RowInserter sql.RowDeleter sql.AutoIncrementSetter - - NextAutoIncrementValue(potentialVal, tableVal interface{}) (interface{}, error) + GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{}) (uint64, error) } // SessionRootSetter sets the root value for the session. @@ -56,14 +55,14 @@ type nomsTableWriter struct { tableName string dbName string sch schema.Schema - autoIncCol schema.Column vrw types.ValueReadWriter kvToSQLRow *index.KVToSqlRowConverter tableEditor editor.TableEditor sess WriteSession - aiTracker globalstate.AutoIncrementTracker batched bool + autoInc globalstate.AutoIncrementTracker + setter SessionRootSetter } @@ -158,25 +157,17 @@ func (te *nomsTableWriter) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.R return err } -func (te *nomsTableWriter) NextAutoIncrementValue(potentialVal, tableVal interface{}) (interface{}, error) { - return te.aiTracker.Next(te.tableName, potentialVal, tableVal) +func (te *nomsTableWriter) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{}) (uint64, error) { + return te.autoInc.Next(te.tableName, insertVal) } -func (te *nomsTableWriter) GetAutoIncrementValue() (interface{}, error) { - val := te.tableEditor.GetAutoIncrementValue() - return te.autoIncCol.TypeInfo.ConvertNomsValueToValue(val) -} - -func (te *nomsTableWriter) SetAutoIncrementValue(ctx *sql.Context, val interface{}) error { - nomsVal, err := te.autoIncCol.TypeInfo.ConvertValueToNomsValue(ctx, te.vrw, val) +func (te *nomsTableWriter) SetAutoIncrementValue(ctx *sql.Context, val uint64) error { + seq, err := globalstate.CoerceAutoIncrementValue(val) if err != nil { return err } - if err = te.tableEditor.SetAutoIncrementValue(nomsVal); err != nil { - return err - } - - te.aiTracker.Reset(te.tableName, val) + te.autoInc.Set(te.tableName, seq) + te.tableEditor.MarkDirty() return te.flush(ctx) } diff --git a/go/libraries/doltcore/sqle/writer/noms_write_session.go b/go/libraries/doltcore/sqle/writer/noms_write_session.go index 35a9f6fba4..7f77d2d8ef 100644 --- a/go/libraries/doltcore/sqle/writer/noms_write_session.go +++ b/go/libraries/doltcore/sqle/writer/noms_write_session.go @@ -19,20 +19,21 @@ import ( "fmt" "sync" - "github.com/dolthub/dolt/go/store/types" + "golang.org/x/sync/errgroup" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/schema" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/index" "github.com/dolthub/dolt/go/libraries/doltcore/table/editor" + "github.com/dolthub/dolt/go/store/types" ) // WriteSession encapsulates writes made within a SQL session. // It's responsible for creating and managing the lifecycle of TableWriter's. type WriteSession interface { // GetTableWriter creates a TableWriter and adds it to the WriteSession. - GetTableWriter(ctx context.Context, table, db string, ait globalstate.AutoIncrementTracker, setter SessionRootSetter, batched bool) (TableWriter, error) + GetTableWriter(ctx context.Context, table, db string, setter SessionRootSetter, batched bool) (TableWriter, error) // UpdateWorkingSet takes a callback to update this WriteSession's WorkingSet. The update method cannot change the // WorkingSetRef of the WriteSession. WriteSession flushes the pending writes in the session before calling the update. @@ -56,7 +57,8 @@ type WriteSession interface { type nomsWriteSession struct { workingSet *doltdb.WorkingSet tables map[string]*sessionedTableEditor - writeMutex *sync.RWMutex // This mutex is specifically for changes that affect the TES or all STEs + tracker globalstate.AutoIncrementTracker + mut *sync.RWMutex // This mutex is specifically for changes that affect the TES or all STEs opts editor.Options } @@ -65,11 +67,12 @@ var _ WriteSession = &nomsWriteSession{} // NewWriteSession creates and returns a WriteSession. Inserting a nil root is not an error, as there are // locations that do not have a root at the time of this call. However, a root must be set through SetRoot before any // table editors are returned. -func NewWriteSession(nbf *types.NomsBinFormat, ws *doltdb.WorkingSet, opts editor.Options) WriteSession { +func NewWriteSession(nbf *types.NomsBinFormat, ws *doltdb.WorkingSet, tracker globalstate.AutoIncrementTracker, opts editor.Options) WriteSession { if types.IsFormat_DOLT_1(nbf) { return &prollyWriteSession{ workingSet: ws, tables: make(map[string]*prollyTableWriter), + tracker: tracker, mut: &sync.RWMutex{}, } } @@ -77,14 +80,15 @@ func NewWriteSession(nbf *types.NomsBinFormat, ws *doltdb.WorkingSet, opts edito return &nomsWriteSession{ workingSet: ws, tables: make(map[string]*sessionedTableEditor), - writeMutex: &sync.RWMutex{}, + tracker: tracker, + mut: &sync.RWMutex{}, opts: opts, } } -func (s *nomsWriteSession) GetTableWriter(ctx context.Context, table string, database string, ait globalstate.AutoIncrementTracker, setter SessionRootSetter, batched bool) (TableWriter, error) { - s.writeMutex.Lock() - defer s.writeMutex.Unlock() +func (s *nomsWriteSession) GetTableWriter(ctx context.Context, table, db string, setter SessionRootSetter, batched bool) (TableWriter, error) { + s.mut.Lock() + defer s.mut.Unlock() t, ok, err := s.workingSet.WorkingRoot().GetTable(ctx, table) if err != nil { @@ -106,42 +110,39 @@ func (s *nomsWriteSession) GetTableWriter(ctx context.Context, table string, dat } conv := index.NewKVToSqlRowConverterForCols(t.Format(), sch) - ac := autoIncrementColFromSchema(sch) return &nomsTableWriter{ tableName: table, - dbName: database, + dbName: db, sch: sch, - autoIncCol: ac, vrw: vrw, kvToSQLRow: conv, tableEditor: te, sess: s, batched: batched, - aiTracker: ait, + autoInc: s.tracker, setter: setter, }, nil } // Flush returns an updated root with all of the changed tables. func (s *nomsWriteSession) Flush(ctx context.Context) (*doltdb.WorkingSet, error) { - s.writeMutex.Lock() - defer s.writeMutex.Unlock() - + s.mut.Lock() + defer s.mut.Unlock() return s.flush(ctx) } // SetWorkingSet implements WriteSession. func (s *nomsWriteSession) SetWorkingSet(ctx context.Context, ws *doltdb.WorkingSet) error { - s.writeMutex.Lock() - defer s.writeMutex.Unlock() + s.mut.Lock() + defer s.mut.Unlock() return s.setWorkingSet(ctx, ws) } // UpdateWorkingSet implements WriteSession. func (s *nomsWriteSession) UpdateWorkingSet(ctx context.Context, cb func(ctx context.Context, current *doltdb.WorkingSet) (*doltdb.WorkingSet, error)) error { - s.writeMutex.Lock() - defer s.writeMutex.Unlock() + s.mut.Lock() + defer s.mut.Unlock() current, err := s.flush(ctx) if err != nil { @@ -167,44 +168,46 @@ func (s *nomsWriteSession) SetOptions(opts editor.Options) { // flush is the inner implementation for Flush that does not acquire any locks func (s *nomsWriteSession) flush(ctx context.Context) (*doltdb.WorkingSet, error) { - rootMutex := &sync.Mutex{} - wg := &sync.WaitGroup{} - wg.Add(len(s.tables)) - newRoot := s.workingSet.WorkingRoot() - var tableErr error - var rootErr error - for tableName, ste := range s.tables { - if !ste.HasEdits() { - wg.Done() + mu := &sync.Mutex{} + rootUpdate := func(name string, table *doltdb.Table) (err error) { + mu.Lock() + defer mu.Unlock() + if newRoot != nil { + newRoot, err = newRoot.PutTable(ctx, name, table) + } + return err + } + + eg, ctx := errgroup.WithContext(ctx) + + for tblName, tblEditor := range s.tables { + if !tblEditor.HasEdits() { continue } - // we can run all of the Table calls concurrently as long as we guard updating the root - go func(tableName string, ste *sessionedTableEditor) { - defer wg.Done() - updatedTable, err := ste.tableEditor.Table(ctx) - // we lock immediately after doing the operation, since both error setting and root updating are guarded - rootMutex.Lock() - defer rootMutex.Unlock() + // copy variables + name, ed := tblName, tblEditor + + eg.Go(func() error { + tbl, err := ed.tableEditor.Table(ctx) if err != nil { - if tableErr == nil { - tableErr = err + return err + } + + if schema.HasAutoIncrement(ed.Schema()) { + v := s.tracker.Current(name) + tbl, err = tbl.SetAutoIncrementValue(ctx, v) + if err != nil { + return err } - return } - newRoot, err = newRoot.PutTable(ctx, tableName, updatedTable) - if err != nil && rootErr == nil { - rootErr = err - } - }(tableName, ste) + + return rootUpdate(name, tbl) + }) } - wg.Wait() - if tableErr != nil { - return nil, tableErr - } - if rootErr != nil { - return nil, rootErr + if err := eg.Wait(); err != nil { + return nil, err } s.workingSet = s.workingSet.WithWorkingRoot(newRoot) @@ -320,6 +323,10 @@ func (s *nomsWriteSession) setWorkingSet(ctx context.Context, ws *doltdb.Working return err } + if err = s.updateAutoIncrementSequences(ctx, root); err != nil { + return err + } + for tableName, localTableEditor := range s.tables { t, ok, err := root.GetTable(ctx, tableName) if err != nil { @@ -336,6 +343,7 @@ func (s *nomsWriteSession) setWorkingSet(ctx context.Context, ws *doltdb.Working if err != nil { return err } + newTableEditor, err := editor.NewTableEditor(ctx, t, tSch, tableName, s.opts) if err != nil { return err @@ -354,3 +362,17 @@ func (s *nomsWriteSession) setWorkingSet(ctx context.Context, ws *doltdb.Working } return nil } + +func (s *nomsWriteSession) updateAutoIncrementSequences(ctx context.Context, root *doltdb.RootValue) error { + return root.IterTables(ctx, func(name string, table *doltdb.Table, sch schema.Schema) (stop bool, err error) { + if !schema.HasAutoIncrement(sch) { + return + } + v, err := table.GetAutoIncrementValue(ctx) + if err != nil { + return true, err + } + s.tracker.Set(name, v) + return + }) +} diff --git a/go/libraries/doltcore/sqle/writer/prolly_table_writer.go b/go/libraries/doltcore/sqle/writer/prolly_table_writer.go index b8ea0fc564..c694119d0f 100644 --- a/go/libraries/doltcore/sqle/writer/prolly_table_writer.go +++ b/go/libraries/doltcore/sqle/writer/prolly_table_writer.go @@ -42,7 +42,6 @@ type prollyTableWriter struct { aiCol schema.Column aiTracker globalstate.AutoIncrementTracker - aiUpdate bool sess WriteSession setter SessionRootSetter @@ -64,7 +63,6 @@ func (w *prollyTableWriter) Insert(ctx *sql.Context, sqlRow sql.Row) error { if err := w.primary.Insert(ctx, sqlRow); err != nil { return err } - w.aiUpdate = true return nil } @@ -97,28 +95,27 @@ func (w *prollyTableWriter) Update(ctx *sql.Context, oldRow sql.Row, newRow sql. if err := w.primary.Update(ctx, oldRow, newRow); err != nil { return err } - w.aiUpdate = true return nil } -// NextAutoIncrementValue implements TableWriter. -func (w *prollyTableWriter) NextAutoIncrementValue(potentialVal, tableVal interface{}) (interface{}, error) { - return w.aiTracker.Next(w.tableName, potentialVal, tableVal) +// GetNextAutoIncrementValue implements TableWriter. +func (w *prollyTableWriter) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{}) (uint64, error) { + return w.aiTracker.Next(w.tableName, insertVal) } // SetAutoIncrementValue implements TableWriter. -func (w *prollyTableWriter) SetAutoIncrementValue(ctx *sql.Context, val interface{}) error { - nomsVal, err := w.aiCol.TypeInfo.ConvertValueToNomsValue(ctx, w.tbl.ValueReadWriter(), val) +func (w *prollyTableWriter) SetAutoIncrementValue(ctx *sql.Context, val uint64) error { + seq, err := globalstate.CoerceAutoIncrementValue(val) if err != nil { return err } - w.tbl, err = w.tbl.SetAutoIncrementValue(ctx, nomsVal) + // todo(andy) set here or in flush? + w.tbl, err = w.tbl.SetAutoIncrementValue(ctx, seq) if err != nil { return err } - - w.aiTracker.Reset(w.tableName, val) + w.aiTracker.Set(w.tableName, seq) return w.flush(ctx) } @@ -129,7 +126,6 @@ func (w *prollyTableWriter) Close(ctx *sql.Context) error { if w.batched { return nil } - return w.flush(ctx) } @@ -187,23 +183,12 @@ func (w *prollyTableWriter) table(ctx context.Context) (t *doltdb.Table, err err return nil, err } - if w.aiCol.AutoIncrement && w.aiUpdate { - seq, err := w.aiTracker.Next(w.tableName, nil, nil) + if w.aiCol.AutoIncrement { + seq := w.aiTracker.Current(w.tableName) + t, err = t.SetAutoIncrementValue(ctx, seq) if err != nil { return nil, err } - vrw := w.tbl.ValueReadWriter() - - v, err := w.aiCol.TypeInfo.ConvertValueToNomsValue(ctx, vrw, seq) - if err != nil { - return nil, err - } - - t, err = t.SetAutoIncrementValue(ctx, v) - if err != nil { - return nil, err - } - w.aiUpdate = false } return t, nil @@ -289,7 +274,9 @@ func (m prollyIndexWriter) Map(ctx context.Context) (prolly.Map, error) { func (m prollyIndexWriter) Insert(ctx *sql.Context, sqlRow sql.Row) error { for to := range m.keyMap { from := m.keyMap.MapOrdinal(to) - index.PutField(m.keyBld, to, sqlRow[from]) + if err := index.PutField(m.keyBld, to, sqlRow[from]); err != nil { + return err + } } k := m.keyBld.Build(sharePool) @@ -302,7 +289,9 @@ func (m prollyIndexWriter) Insert(ctx *sql.Context, sqlRow sql.Row) error { for to := range m.valMap { from := m.valMap.MapOrdinal(to) - index.PutField(m.valBld, to, sqlRow[from]) + if err = index.PutField(m.valBld, to, sqlRow[from]); err != nil { + return err + } } v := m.valBld.Build(sharePool) @@ -312,7 +301,9 @@ func (m prollyIndexWriter) Insert(ctx *sql.Context, sqlRow sql.Row) error { func (m prollyIndexWriter) Delete(ctx *sql.Context, sqlRow sql.Row) error { for to := range m.keyMap { from := m.keyMap.MapOrdinal(to) - index.PutField(m.keyBld, to, sqlRow[from]) + if err := index.PutField(m.keyBld, to, sqlRow[from]); err != nil { + return err + } } k := m.keyBld.Build(sharePool) @@ -322,7 +313,9 @@ func (m prollyIndexWriter) Delete(ctx *sql.Context, sqlRow sql.Row) error { func (m prollyIndexWriter) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Row) error { for to := range m.keyMap { from := m.keyMap.MapOrdinal(to) - index.PutField(m.keyBld, to, oldRow[from]) + if err := index.PutField(m.keyBld, to, oldRow[from]); err != nil { + return err + } } oldKey := m.keyBld.Build(sharePool) @@ -334,7 +327,9 @@ func (m prollyIndexWriter) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.R for to := range m.keyMap { from := m.keyMap.MapOrdinal(to) - index.PutField(m.keyBld, to, newRow[from]) + if err := index.PutField(m.keyBld, to, newRow[from]); err != nil { + return err + } } newKey := m.keyBld.Build(sharePool) @@ -347,7 +342,9 @@ func (m prollyIndexWriter) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.R for to := range m.valMap { from := m.valMap.MapOrdinal(to) - index.PutField(m.valBld, to, newRow[from]) + if err = index.PutField(m.valBld, to, newRow[from]); err != nil { + return err + } } v := m.valBld.Build(sharePool) diff --git a/go/libraries/doltcore/sqle/writer/prolly_write_session.go b/go/libraries/doltcore/sqle/writer/prolly_write_session.go index 4189b4fd71..32e3267077 100644 --- a/go/libraries/doltcore/sqle/writer/prolly_write_session.go +++ b/go/libraries/doltcore/sqle/writer/prolly_write_session.go @@ -21,6 +21,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/libraries/doltcore/schema" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil" "github.com/dolthub/dolt/go/libraries/doltcore/table/editor" @@ -31,13 +32,14 @@ import ( type prollyWriteSession struct { workingSet *doltdb.WorkingSet tables map[string]*prollyTableWriter + tracker globalstate.AutoIncrementTracker mut *sync.RWMutex } var _ WriteSession = &prollyWriteSession{} // GetTableWriter implemented WriteSession. -func (s *prollyWriteSession) GetTableWriter(ctx context.Context, table string, database string, ait globalstate.AutoIncrementTracker, setter SessionRootSetter, batched bool) (TableWriter, error) { +func (s *prollyWriteSession) GetTableWriter(ctx context.Context, table, db string, setter SessionRootSetter, batched bool) (TableWriter, error) { s.mut.Lock() defer s.mut.Unlock() @@ -76,13 +78,13 @@ func (s *prollyWriteSession) GetTableWriter(ctx context.Context, table string, d twr := &prollyTableWriter{ tableName: table, - dbName: database, + dbName: db, primary: pw, secondary: sws, tbl: t, sch: sch, aiCol: autoCol, - aiTracker: ait, + aiTracker: s.tracker, sess: s, setter: setter, batched: batched, @@ -150,6 +152,14 @@ func (s *prollyWriteSession) flush(ctx context.Context) (*doltdb.WorkingSet, err return err } + if schema.HasAutoIncrement(wr.sch) { + v := s.tracker.Current(name) + t, err = t.SetAutoIncrementValue(ctx, v) + if err != nil { + return err + } + } + mu.Lock() defer mu.Unlock() tables[name] = t diff --git a/go/libraries/doltcore/sqle/writer/sessioned_table_editor.go b/go/libraries/doltcore/sqle/writer/sessioned_table_editor.go index 61ffcffe63..47a4893776 100644 --- a/go/libraries/doltcore/sqle/writer/sessioned_table_editor.go +++ b/go/libraries/doltcore/sqle/writer/sessioned_table_editor.go @@ -43,8 +43,8 @@ type sessionedTableEditor struct { var _ editor.TableEditor = &sessionedTableEditor{} func (ste *sessionedTableEditor) InsertKeyVal(ctx context.Context, key, val types.Tuple, tagToVal map[uint64]types.Value, errFunc editor.PKDuplicateErrFunc) error { - ste.tableEditSession.writeMutex.RLock() - defer ste.tableEditSession.writeMutex.RUnlock() + ste.tableEditSession.mut.RLock() + defer ste.tableEditSession.mut.RUnlock() err := ste.validateForInsert(ctx, tagToVal) if err != nil { @@ -56,8 +56,8 @@ func (ste *sessionedTableEditor) InsertKeyVal(ctx context.Context, key, val type } func (ste *sessionedTableEditor) DeleteByKey(ctx context.Context, key types.Tuple, tagToVal map[uint64]types.Value) error { - ste.tableEditSession.writeMutex.RLock() - defer ste.tableEditSession.writeMutex.RUnlock() + ste.tableEditSession.mut.RLock() + defer ste.tableEditSession.mut.RUnlock() if !ste.tableEditSession.opts.ForeignKeyChecksDisabled && len(ste.referencingTables) > 0 { err := ste.onDeleteHandleRowsReferencingValues(ctx, key, tagToVal) @@ -72,8 +72,8 @@ func (ste *sessionedTableEditor) DeleteByKey(ctx context.Context, key types.Tupl // InsertRow adds the given row to the table. If the row already exists, use UpdateRow. func (ste *sessionedTableEditor) InsertRow(ctx context.Context, dRow row.Row, errFunc editor.PKDuplicateErrFunc) error { - ste.tableEditSession.writeMutex.RLock() - defer ste.tableEditSession.writeMutex.RUnlock() + ste.tableEditSession.mut.RLock() + defer ste.tableEditSession.mut.RUnlock() dRowTaggedVals, err := dRow.TaggedValues() if err != nil { @@ -90,8 +90,8 @@ func (ste *sessionedTableEditor) InsertRow(ctx context.Context, dRow row.Row, er // DeleteRow removes the given key from the table. func (ste *sessionedTableEditor) DeleteRow(ctx context.Context, r row.Row) error { - ste.tableEditSession.writeMutex.RLock() - defer ste.tableEditSession.writeMutex.RUnlock() + ste.tableEditSession.mut.RLock() + defer ste.tableEditSession.mut.RUnlock() if !ste.tableEditSession.opts.ForeignKeyChecksDisabled && len(ste.referencingTables) > 0 { err := ste.handleReferencingRowsOnDelete(ctx, r) @@ -107,8 +107,8 @@ func (ste *sessionedTableEditor) DeleteRow(ctx context.Context, r row.Row) error // UpdateRow takes the current row and new row, and updates it accordingly. Any applicable rows from tables that have a // foreign key referencing this table will also be updated. func (ste *sessionedTableEditor) UpdateRow(ctx context.Context, dOldRow row.Row, dNewRow row.Row, errFunc editor.PKDuplicateErrFunc) error { - ste.tableEditSession.writeMutex.RLock() - defer ste.tableEditSession.writeMutex.RUnlock() + ste.tableEditSession.mut.RLock() + defer ste.tableEditSession.mut.RUnlock() return ste.updateRow(ctx, dOldRow, dNewRow, true, errFunc) } @@ -123,15 +123,9 @@ func (ste *sessionedTableEditor) HasEdits() bool { return ste.tableEditor.HasEdits() } -// GetAutoIncrementValue implements TableEditor. -func (ste *sessionedTableEditor) GetAutoIncrementValue() types.Value { - return ste.tableEditor.GetAutoIncrementValue() -} - -// SetAutoIncrementValue implements TableEditor. -func (ste *sessionedTableEditor) SetAutoIncrementValue(v types.Value) error { - ste.dirty = true - return ste.tableEditor.SetAutoIncrementValue(v) +// MarkDirty implements TableEditor. +func (ste *sessionedTableEditor) MarkDirty() { + ste.tableEditor.MarkDirty() } // Table implements TableEditor. diff --git a/go/libraries/doltcore/table/editor/keyless_table_editor.go b/go/libraries/doltcore/table/editor/keyless_table_editor.go index 6596b38be9..ac92ebe242 100644 --- a/go/libraries/doltcore/table/editor/keyless_table_editor.go +++ b/go/libraries/doltcore/table/editor/keyless_table_editor.go @@ -250,6 +250,13 @@ func (kte *keylessTableEditor) HasEdits() bool { return kte.dirty } +// MarkDirty implements TableEditor. +func (kte *keylessTableEditor) MarkDirty() { + kte.mu.Lock() + defer kte.mu.Unlock() + kte.dirty = true +} + // GetAutoIncrementValue implements TableEditor, AUTO_INCREMENT is not yet supported for keyless tables. func (kte *keylessTableEditor) GetAutoIncrementValue() types.Value { return types.NullValue diff --git a/go/libraries/doltcore/table/editor/pk_table_editor.go b/go/libraries/doltcore/table/editor/pk_table_editor.go index a783d723f5..a385d422f7 100644 --- a/go/libraries/doltcore/table/editor/pk_table_editor.go +++ b/go/libraries/doltcore/table/editor/pk_table_editor.go @@ -62,10 +62,10 @@ type TableEditor interface { InsertRow(ctx context.Context, r row.Row, errFunc PKDuplicateErrFunc) error UpdateRow(ctx context.Context, old, new row.Row, errFunc PKDuplicateErrFunc) error DeleteRow(ctx context.Context, r row.Row) error - HasEdits() bool - GetAutoIncrementValue() types.Value - SetAutoIncrementValue(v types.Value) (err error) + HasEdits() bool + MarkDirty() + SetConstraintViolation(ctx context.Context, k types.LesserValuable, v types.Valuable) error Table(ctx context.Context) (*doltdb.Table, error) @@ -128,10 +128,6 @@ type pkTableEditor struct { // Whenever any write operation occurs on the table editor, this is set to true for the lifetime of the editor. dirty uint32 - hasAutoInc bool - autoIncCol schema.Column - autoIncVal types.Value - // This mutex blocks on each operation, so that map reads and updates are serialized writeMutex *sync.Mutex } @@ -165,22 +161,6 @@ func newPkTableEditor(ctx context.Context, t *doltdb.Table, tableSch schema.Sche te.indexEds[i] = NewIndexEditor(ctx, index, indexData, tableSch, opts) } - err = tableSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { - if col.AutoIncrement { - te.autoIncVal, err = t.GetAutoIncrementValue(ctx) - if err != nil { - return true, err - } - te.hasAutoInc = true - te.autoIncCol = col - return true, err - } - return false, nil - }) - if err != nil { - return nil, err - } - return te, nil } @@ -446,32 +426,7 @@ func (te *pkTableEditor) insertKeyVal(ctx context.Context, keyHash hash.Hash, ke return err } - if te.hasAutoInc { - insertVal, ok := tagToVal[te.autoIncCol.Tag] - - if ok { - var less bool - - // float auto increment values should be rounded before comparing to the current auto increment values - if te.autoIncVal.Kind() == types.FloatKind { - rounded := types.Round(insertVal) - less, err = rounded.Less(te.nbf, te.autoIncVal) - } else { - less, err = insertVal.Less(te.nbf, te.autoIncVal) - } - - if err != nil { - return err - } - - if !less { - te.autoIncVal = types.Round(insertVal) - te.autoIncVal = types.Increment(te.autoIncVal) - } - } - } - - te.setDirty(true) + te.MarkDirty() return nil } @@ -546,7 +501,7 @@ func (te *pkTableEditor) DeleteByKey(ctx context.Context, key types.Tuple, tagTo return err } - te.setDirty(true) + te.MarkDirty() return te.tea.Delete(keyHash, key) } @@ -644,7 +599,7 @@ func (te *pkTableEditor) UpdateRow(ctx context.Context, dOldRow row.Row, dNewRow return err } - te.setDirty(true) + te.MarkDirty() if kvp, pkExists, err := te.tea.Get(ctx, newHash, dNewKeyVal); err != nil { return err @@ -655,37 +610,14 @@ func (te *pkTableEditor) UpdateRow(ctx context.Context, dOldRow row.Row, dNewRow return te.tea.Insert(newHash, dNewKeyVal, dNewRowVal) } -func (te *pkTableEditor) GetAutoIncrementValue() types.Value { - return te.autoIncVal -} - -func (te *pkTableEditor) SetAutoIncrementValue(v types.Value) (err error) { - te.writeMutex.Lock() - defer te.writeMutex.Unlock() - - te.setDirty(true) - te.autoIncVal = v - te.t, err = te.t.SetAutoIncrementValue(nil, te.autoIncVal) - - return err -} - // Table returns a Table based on the edits given, if any. If Flush() was not called prior, it will be called here. func (te *pkTableEditor) Table(ctx context.Context) (*doltdb.Table, error) { if !te.HasEdits() { return te.t, nil } - var err error - if te.hasAutoInc { - te.t, err = te.t.SetAutoIncrementValue(nil, te.autoIncVal) - if err != nil { - return nil, err - } - } - var tbl *doltdb.Table - err = func() error { + err := func() error { te.writeMutex.Lock() defer te.writeMutex.Unlock() @@ -816,7 +748,7 @@ func (te *pkTableEditor) SetConstraintViolation(ctx context.Context, k types.Les te.cvEditor = cvMap.Edit() } te.cvEditor.Set(k, v) - te.setDirty(true) + te.MarkDirty() return nil } @@ -845,13 +777,10 @@ func (te *pkTableEditor) Close(ctx context.Context) error { return nil } -func (te *pkTableEditor) setDirty(dirty bool) { - var val uint32 - if dirty { - val = 1 - } - - atomic.StoreUint32(&te.dirty, val) +// MarkDirty implements TableEditor. +func (te *pkTableEditor) MarkDirty() { + dirty := uint32(1) + atomic.StoreUint32(&te.dirty, dirty) } // hasEdits returns whether the table editor has had any successful write operations. This does not track whether the diff --git a/integration-tests/bats/sql-show.bats b/integration-tests/bats/sql-show.bats index 63b2a16162..e9b27a8106 100644 --- a/integration-tests/bats/sql-show.bats +++ b/integration-tests/bats/sql-show.bats @@ -13,12 +13,12 @@ teardown() { @test "sql-show: show table status on auto-increment table" { dolt sql -q "CREATE TABLE test(pk int NOT NULL AUTO_INCREMENT, c1 int, PRIMARY KEY (pk))" - run dolt sql -q "show table status where \`Auto_increment\`=1;" + run dolt sql -q "show table status;" [ "$status" -eq 0 ] [[ "$output" =~ "test" ]] || false dolt sql -q "INSERT INTO test (c1) VALUES (0)" - run dolt sql -q "show table status where \`Auto_increment\`=2;" + run dolt sql -q "show table status;" [ "$status" -eq 0 ] [[ "$output" =~ "test" ]] || false }