mirror of
https://github.com/dolthub/dolt.git
synced 2026-05-25 11:28:50 -05:00
reset auto increment counter on dolt_reset('--hard') (#8319)
This commit is contained in:
@@ -114,6 +114,11 @@ func doDoltReset(ctx *sql.Context, args []string) (int, error) {
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
err = dSess.ResetGlobals(ctx, dbName, roots.Working)
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
|
||||
} else if apr.Contains(cli.SoftResetParam) {
|
||||
arg := ""
|
||||
if apr.NArg() > 1 {
|
||||
|
||||
@@ -61,41 +61,7 @@ func NewAutoIncrementTracker(ctx context.Context, dbName string, roots ...doltdb
|
||||
sequences: &sync.Map{},
|
||||
mm: mutexmap.NewMutexMap(),
|
||||
}
|
||||
|
||||
for _, root := range roots {
|
||||
root, err := root.ResolveRootValue(ctx)
|
||||
if err != nil {
|
||||
return &AutoIncrementTracker{}, err
|
||||
}
|
||||
|
||||
err = root.IterTables(ctx, func(tableName doltdb.TableName, table *doltdb.Table, sch schema.Schema) (bool, error) {
|
||||
ok := schema.HasAutoIncrement(sch)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
tableName = tableName.ToLower()
|
||||
|
||||
seq, err := table.GetAutoIncrementValue(ctx)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
||||
// TODO: support schema name as part of the key
|
||||
tableNameStr := tableName.Name
|
||||
oldValue, loaded := ait.sequences.LoadOrStore(tableNameStr, seq)
|
||||
if loaded && seq > oldValue.(uint64) {
|
||||
ait.sequences.Store(tableNameStr, seq)
|
||||
}
|
||||
|
||||
return false, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return &AutoIncrementTracker{}, err
|
||||
}
|
||||
}
|
||||
|
||||
ait.InitWithRoots(ctx, roots...)
|
||||
return &ait, nil
|
||||
}
|
||||
|
||||
@@ -109,13 +75,13 @@ func loadAutoIncValue(sequences *sync.Map, tableName string) uint64 {
|
||||
}
|
||||
|
||||
// Current returns the next value to be generated in the auto increment sequence for the table named
|
||||
func (a AutoIncrementTracker) Current(tableName string) uint64 {
|
||||
func (a *AutoIncrementTracker) Current(tableName string) uint64 {
|
||||
return loadAutoIncValue(a.sequences, tableName)
|
||||
}
|
||||
|
||||
// Next returns the next auto increment value for the table named using the provided value from an insert (which may
|
||||
// be null or 0, in which case it will be generated from the sequence).
|
||||
func (a AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, error) {
|
||||
func (a *AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, error) {
|
||||
tbl = strings.ToLower(tbl)
|
||||
|
||||
given, err := CoerceAutoIncrementValue(insertVal)
|
||||
@@ -145,7 +111,7 @@ func (a AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, e
|
||||
return given, nil
|
||||
}
|
||||
|
||||
func (a AutoIncrementTracker) CoerceAutoIncrementValue(val interface{}) (uint64, error) {
|
||||
func (a *AutoIncrementTracker) CoerceAutoIncrementValue(val interface{}) (uint64, error) {
|
||||
return CoerceAutoIncrementValue(val)
|
||||
}
|
||||
|
||||
@@ -172,7 +138,7 @@ func CoerceAutoIncrementValue(val interface{}) (uint64, error) {
|
||||
// Set sets the auto increment value for the table named, if it's greater than the one already registered for this
|
||||
// table. Otherwise, the update is silently disregarded. So far this matches the MySQL behavior, but Dolt uses the
|
||||
// maximum value for this table across all branches.
|
||||
func (a AutoIncrementTracker) Set(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) {
|
||||
func (a *AutoIncrementTracker) Set(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) {
|
||||
tableName = strings.ToLower(tableName)
|
||||
|
||||
release := a.mm.Lock(tableName)
|
||||
@@ -190,7 +156,7 @@ func (a AutoIncrementTracker) Set(ctx *sql.Context, tableName string, table *dol
|
||||
|
||||
// deepSet sets the auto increment value for the table named, if it's greater than the one on any branch head for this
|
||||
// database, ignoring the current in-memory tracker value
|
||||
func (a AutoIncrementTracker) deepSet(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) {
|
||||
func (a *AutoIncrementTracker) deepSet(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) {
|
||||
sess := DSessFromSess(ctx.Session)
|
||||
db, ok := sess.Provider().BaseDatabase(ctx, a.dbName)
|
||||
|
||||
@@ -371,7 +337,7 @@ func getMaxIndexValue(ctx context.Context, indexData durable.Index) (uint64, err
|
||||
}
|
||||
|
||||
// AddNewTable initializes a new table with an auto increment column to the tracker, as necessary
|
||||
func (a AutoIncrementTracker) AddNewTable(tableName string) {
|
||||
func (a *AutoIncrementTracker) AddNewTable(tableName string) {
|
||||
tableName = strings.ToLower(tableName)
|
||||
// only initialize the sequence for this table if no other branch has such a table
|
||||
a.sequences.LoadOrStore(tableName, uint64(1))
|
||||
@@ -380,7 +346,7 @@ func (a AutoIncrementTracker) AddNewTable(tableName string) {
|
||||
// DropTable drops the table with the name given.
|
||||
// To establish the new auto increment value, callers must also pass all other working sets in scope that may include
|
||||
// a table with the same name, omitting the working set that just deleted the table named.
|
||||
func (a AutoIncrementTracker) DropTable(ctx *sql.Context, tableName string, wses ...*doltdb.WorkingSet) error {
|
||||
func (a *AutoIncrementTracker) DropTable(ctx *sql.Context, tableName string, wses ...*doltdb.WorkingSet) error {
|
||||
tableName = strings.ToLower(tableName)
|
||||
|
||||
release := a.mm.Lock(tableName)
|
||||
@@ -430,3 +396,36 @@ func (a *AutoIncrementTracker) AcquireTableLock(ctx *sql.Context, tableName stri
|
||||
a.lockMode = lockMode
|
||||
return a.mm.Lock(tableName), nil
|
||||
}
|
||||
|
||||
func (a *AutoIncrementTracker) InitWithRoots(ctx context.Context, roots ...doltdb.Rootish) error {
|
||||
for _, root := range roots {
|
||||
r, err := root.ResolveRootValue(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.IterTables(ctx, func(tableName doltdb.TableName, table *doltdb.Table, sch schema.Schema) (bool, error) {
|
||||
if !schema.HasAutoIncrement(sch) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
seq, err := table.GetAutoIncrementValue(ctx)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
||||
tableNameStr := tableName.ToLower().Name
|
||||
if oldValue, loaded := a.sequences.LoadOrStore(tableNameStr, seq); loaded && seq > oldValue.(uint64) {
|
||||
a.sequences.Store(tableNameStr, seq)
|
||||
}
|
||||
|
||||
return false, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1018,7 +1018,7 @@ func (d *DoltSession) SetStagingRoot(ctx *sql.Context, dbName string, newRoot do
|
||||
// via setRoot. This method is for clients that need to update more of the session state, such as the dolt_ functions.
|
||||
// Unlike setting the working root, this method always marks the database state dirty.
|
||||
func (d *DoltSession) SetRoots(ctx *sql.Context, dbName string, roots doltdb.Roots) error {
|
||||
sessionState, _, err := d.LookupDbState(ctx, dbName)
|
||||
sessionState, _, err := d.lookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1031,6 +1031,25 @@ func (d *DoltSession) SetRoots(ctx *sql.Context, dbName string, roots doltdb.Roo
|
||||
return d.SetWorkingSet(ctx, dbName, workingSet)
|
||||
}
|
||||
|
||||
func (d *DoltSession) ResetGlobals(ctx *sql.Context, dbName string, root doltdb.RootValue) error {
|
||||
sessionState, _, err := d.lookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tracker, err := sessionState.dbState.globalState.AutoIncrementTracker(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tracker.InitWithRoots(ctx, root)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DoltSession) SetFileSystem(fs filesys.Filesys) {
|
||||
d.fs = fs
|
||||
}
|
||||
@@ -1059,8 +1078,8 @@ func (d *DoltSession) SetWorkingSet(ctx *sql.Context, dbName string, ws *doltdb.
|
||||
return err
|
||||
}
|
||||
|
||||
if writeSess := branchState.WriteSession(); writeSess != nil {
|
||||
err = writeSess.SetWorkingSet(ctx, ws)
|
||||
if branchState.writeSession != nil {
|
||||
err = branchState.writeSession.SetWorkingSet(ctx, ws)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1484,9 +1503,10 @@ func (d *DoltSession) dbSessionVarsStale(ctx *sql.Context, state *branchState) b
|
||||
return d.dbCache.CacheSessionVars(state, dtx)
|
||||
}
|
||||
|
||||
func (d DoltSession) WithGlobals(conf config.ReadWriteConfig) *DoltSession {
|
||||
d.globalsConf = conf
|
||||
return &d
|
||||
func (d *DoltSession) WithGlobals(conf config.ReadWriteConfig) *DoltSession {
|
||||
nd := *d
|
||||
nd.globalsConf = conf
|
||||
return &nd
|
||||
}
|
||||
|
||||
// PersistGlobal implements sql.PersistableSession
|
||||
|
||||
@@ -5637,6 +5637,89 @@ var DoltAutoIncrementTests = []queries.ScriptTest{
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "hard reset dropped table restores auto increment",
|
||||
SetUpScript: []string{
|
||||
"create table t (a int primary key auto_increment, b int)",
|
||||
"insert into t (b) values (1), (2)",
|
||||
"call dolt_commit('-Am', 'initialize table')",
|
||||
"drop table t",
|
||||
"call dolt_reset('--hard')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "insert into t(b) values (3)",
|
||||
Expected: []sql.Row{
|
||||
{types.OkResult{RowsAffected: 1, InsertID: 3}},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "select * from t order by a",
|
||||
Expected: []sql.Row{
|
||||
{1, 1},
|
||||
{2, 2},
|
||||
{3, 3},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
// this behavior aligns with how we treat branches
|
||||
Name: "hard reset inserted rows continues auto increment",
|
||||
SetUpScript: []string{
|
||||
"create table t (a int primary key auto_increment, b int)",
|
||||
"insert into t (b) values (1), (2)",
|
||||
"call dolt_commit('-Am', 'initialize table')",
|
||||
"insert into t (b) values (3), (4)",
|
||||
"call dolt_reset('--hard')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "insert into t(b) values (5)",
|
||||
Expected: []sql.Row{
|
||||
{types.OkResult{RowsAffected: 1, InsertID: 5}},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "select * from t order by a",
|
||||
Expected: []sql.Row{
|
||||
{1, 1},
|
||||
{2, 2},
|
||||
{5, 5},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "hard reset dropped table with branch restores auto increment",
|
||||
SetUpScript: []string{
|
||||
"create table t (a int primary key auto_increment, b int)",
|
||||
"insert into t (b) values (1), (2)",
|
||||
"call dolt_commit('-Am', 'initialize table')",
|
||||
"call dolt_checkout('-b', 'branch1')",
|
||||
"insert into t values (100, 100)",
|
||||
"call dolt_commit('-Am', 'other')",
|
||||
"call dolt_checkout('main')",
|
||||
"drop table t",
|
||||
"call dolt_reset('--hard')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "insert into t(b) values (101)",
|
||||
Expected: []sql.Row{
|
||||
{types.OkResult{RowsAffected: 1, InsertID: 101}},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "select * from t order by a",
|
||||
Expected: []sql.Row{
|
||||
{1, 1},
|
||||
{2, 2},
|
||||
{101, 101},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var DoltCherryPickTests = []queries.ScriptTest{
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
package globalstate
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
@@ -38,9 +40,10 @@ type AutoIncrementTracker interface {
|
||||
// below the current value for this table. The table in the provided working set is assumed to already have the value
|
||||
// given, so the new global maximum is computed without regard for its value in that working set.
|
||||
Set(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error)
|
||||
|
||||
// AcquireTableLock acquires the auto increment lock on a table, and returns a callback function to release the lock.
|
||||
// Depending on the value of the `innodb_autoinc_lock_mode` system variable, the engine may need to acquire and hold
|
||||
// the lock for the duration of an insert statement.
|
||||
AcquireTableLock(ctx *sql.Context, tableName string) (func(), error)
|
||||
// InitWithRoots fills the AutoIncrementTracker with values pulled from each root in order.
|
||||
InitWithRoots(ctx context.Context, roots ...doltdb.Rootish) error
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user