reset auto increment counter on dolt_reset('--hard') (#8319)

This commit is contained in:
James Cor
2024-09-04 01:14:23 -07:00
committed by GitHub
parent 913d6b5549
commit 8dca4a5f74
5 changed files with 159 additions and 49 deletions
@@ -114,6 +114,11 @@ func doDoltReset(ctx *sql.Context, args []string) (int, error) {
if err != nil {
return 1, err
}
err = dSess.ResetGlobals(ctx, dbName, roots.Working)
if err != nil {
return 1, err
}
} else if apr.Contains(cli.SoftResetParam) {
arg := ""
if apr.NArg() > 1 {
@@ -61,41 +61,7 @@ func NewAutoIncrementTracker(ctx context.Context, dbName string, roots ...doltdb
sequences: &sync.Map{},
mm: mutexmap.NewMutexMap(),
}
for _, root := range roots {
root, err := root.ResolveRootValue(ctx)
if err != nil {
return &AutoIncrementTracker{}, err
}
err = root.IterTables(ctx, func(tableName doltdb.TableName, table *doltdb.Table, sch schema.Schema) (bool, error) {
ok := schema.HasAutoIncrement(sch)
if !ok {
return false, nil
}
tableName = tableName.ToLower()
seq, err := table.GetAutoIncrementValue(ctx)
if err != nil {
return true, err
}
// TODO: support schema name as part of the key
tableNameStr := tableName.Name
oldValue, loaded := ait.sequences.LoadOrStore(tableNameStr, seq)
if loaded && seq > oldValue.(uint64) {
ait.sequences.Store(tableNameStr, seq)
}
return false, nil
})
if err != nil {
return &AutoIncrementTracker{}, err
}
}
ait.InitWithRoots(ctx, roots...)
return &ait, nil
}
@@ -109,13 +75,13 @@ func loadAutoIncValue(sequences *sync.Map, tableName string) uint64 {
}
// Current returns the next value to be generated in the auto increment sequence for the table named
func (a AutoIncrementTracker) Current(tableName string) uint64 {
func (a *AutoIncrementTracker) Current(tableName string) uint64 {
return loadAutoIncValue(a.sequences, tableName)
}
// Next returns the next auto increment value for the table named using the provided value from an insert (which may
// be null or 0, in which case it will be generated from the sequence).
func (a AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, error) {
func (a *AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, error) {
tbl = strings.ToLower(tbl)
given, err := CoerceAutoIncrementValue(insertVal)
@@ -145,7 +111,7 @@ func (a AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, e
return given, nil
}
func (a AutoIncrementTracker) CoerceAutoIncrementValue(val interface{}) (uint64, error) {
func (a *AutoIncrementTracker) CoerceAutoIncrementValue(val interface{}) (uint64, error) {
return CoerceAutoIncrementValue(val)
}
@@ -172,7 +138,7 @@ func CoerceAutoIncrementValue(val interface{}) (uint64, error) {
// Set sets the auto increment value for the table named, if it's greater than the one already registered for this
// table. Otherwise, the update is silently disregarded. So far this matches the MySQL behavior, but Dolt uses the
// maximum value for this table across all branches.
func (a AutoIncrementTracker) Set(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) {
func (a *AutoIncrementTracker) Set(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) {
tableName = strings.ToLower(tableName)
release := a.mm.Lock(tableName)
@@ -190,7 +156,7 @@ func (a AutoIncrementTracker) Set(ctx *sql.Context, tableName string, table *dol
// deepSet sets the auto increment value for the table named, if it's greater than the one on any branch head for this
// database, ignoring the current in-memory tracker value
func (a AutoIncrementTracker) deepSet(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) {
func (a *AutoIncrementTracker) deepSet(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) {
sess := DSessFromSess(ctx.Session)
db, ok := sess.Provider().BaseDatabase(ctx, a.dbName)
@@ -371,7 +337,7 @@ func getMaxIndexValue(ctx context.Context, indexData durable.Index) (uint64, err
}
// AddNewTable initializes a new table with an auto increment column to the tracker, as necessary
func (a AutoIncrementTracker) AddNewTable(tableName string) {
func (a *AutoIncrementTracker) AddNewTable(tableName string) {
tableName = strings.ToLower(tableName)
// only initialize the sequence for this table if no other branch has such a table
a.sequences.LoadOrStore(tableName, uint64(1))
@@ -380,7 +346,7 @@ func (a AutoIncrementTracker) AddNewTable(tableName string) {
// DropTable drops the table with the name given.
// To establish the new auto increment value, callers must also pass all other working sets in scope that may include
// a table with the same name, omitting the working set that just deleted the table named.
func (a AutoIncrementTracker) DropTable(ctx *sql.Context, tableName string, wses ...*doltdb.WorkingSet) error {
func (a *AutoIncrementTracker) DropTable(ctx *sql.Context, tableName string, wses ...*doltdb.WorkingSet) error {
tableName = strings.ToLower(tableName)
release := a.mm.Lock(tableName)
@@ -430,3 +396,36 @@ func (a *AutoIncrementTracker) AcquireTableLock(ctx *sql.Context, tableName stri
a.lockMode = lockMode
return a.mm.Lock(tableName), nil
}
func (a *AutoIncrementTracker) InitWithRoots(ctx context.Context, roots ...doltdb.Rootish) error {
for _, root := range roots {
r, err := root.ResolveRootValue(ctx)
if err != nil {
return err
}
err = r.IterTables(ctx, func(tableName doltdb.TableName, table *doltdb.Table, sch schema.Schema) (bool, error) {
if !schema.HasAutoIncrement(sch) {
return false, nil
}
seq, err := table.GetAutoIncrementValue(ctx)
if err != nil {
return true, err
}
tableNameStr := tableName.ToLower().Name
if oldValue, loaded := a.sequences.LoadOrStore(tableNameStr, seq); loaded && seq > oldValue.(uint64) {
a.sequences.Store(tableNameStr, seq)
}
return false, nil
})
if err != nil {
return err
}
}
return nil
}
+26 -6
View File
@@ -1018,7 +1018,7 @@ func (d *DoltSession) SetStagingRoot(ctx *sql.Context, dbName string, newRoot do
// via setRoot. This method is for clients that need to update more of the session state, such as the dolt_ functions.
// Unlike setting the working root, this method always marks the database state dirty.
func (d *DoltSession) SetRoots(ctx *sql.Context, dbName string, roots doltdb.Roots) error {
sessionState, _, err := d.LookupDbState(ctx, dbName)
sessionState, _, err := d.lookupDbState(ctx, dbName)
if err != nil {
return err
}
@@ -1031,6 +1031,25 @@ func (d *DoltSession) SetRoots(ctx *sql.Context, dbName string, roots doltdb.Roo
return d.SetWorkingSet(ctx, dbName, workingSet)
}
func (d *DoltSession) ResetGlobals(ctx *sql.Context, dbName string, root doltdb.RootValue) error {
sessionState, _, err := d.lookupDbState(ctx, dbName)
if err != nil {
return err
}
tracker, err := sessionState.dbState.globalState.AutoIncrementTracker(ctx)
if err != nil {
return err
}
err = tracker.InitWithRoots(ctx, root)
if err != nil {
return err
}
return nil
}
func (d *DoltSession) SetFileSystem(fs filesys.Filesys) {
d.fs = fs
}
@@ -1059,8 +1078,8 @@ func (d *DoltSession) SetWorkingSet(ctx *sql.Context, dbName string, ws *doltdb.
return err
}
if writeSess := branchState.WriteSession(); writeSess != nil {
err = writeSess.SetWorkingSet(ctx, ws)
if branchState.writeSession != nil {
err = branchState.writeSession.SetWorkingSet(ctx, ws)
if err != nil {
return err
}
@@ -1484,9 +1503,10 @@ func (d *DoltSession) dbSessionVarsStale(ctx *sql.Context, state *branchState) b
return d.dbCache.CacheSessionVars(state, dtx)
}
func (d DoltSession) WithGlobals(conf config.ReadWriteConfig) *DoltSession {
d.globalsConf = conf
return &d
func (d *DoltSession) WithGlobals(conf config.ReadWriteConfig) *DoltSession {
nd := *d
nd.globalsConf = conf
return &nd
}
// PersistGlobal implements sql.PersistableSession
@@ -5637,6 +5637,89 @@ var DoltAutoIncrementTests = []queries.ScriptTest{
},
},
},
{
Name: "hard reset dropped table restores auto increment",
SetUpScript: []string{
"create table t (a int primary key auto_increment, b int)",
"insert into t (b) values (1), (2)",
"call dolt_commit('-Am', 'initialize table')",
"drop table t",
"call dolt_reset('--hard')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "insert into t(b) values (3)",
Expected: []sql.Row{
{types.OkResult{RowsAffected: 1, InsertID: 3}},
},
},
{
Query: "select * from t order by a",
Expected: []sql.Row{
{1, 1},
{2, 2},
{3, 3},
},
},
},
},
{
// this behavior aligns with how we treat branches
Name: "hard reset inserted rows continues auto increment",
SetUpScript: []string{
"create table t (a int primary key auto_increment, b int)",
"insert into t (b) values (1), (2)",
"call dolt_commit('-Am', 'initialize table')",
"insert into t (b) values (3), (4)",
"call dolt_reset('--hard')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "insert into t(b) values (5)",
Expected: []sql.Row{
{types.OkResult{RowsAffected: 1, InsertID: 5}},
},
},
{
Query: "select * from t order by a",
Expected: []sql.Row{
{1, 1},
{2, 2},
{5, 5},
},
},
},
},
{
Name: "hard reset dropped table with branch restores auto increment",
SetUpScript: []string{
"create table t (a int primary key auto_increment, b int)",
"insert into t (b) values (1), (2)",
"call dolt_commit('-Am', 'initialize table')",
"call dolt_checkout('-b', 'branch1')",
"insert into t values (100, 100)",
"call dolt_commit('-Am', 'other')",
"call dolt_checkout('main')",
"drop table t",
"call dolt_reset('--hard')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "insert into t(b) values (101)",
Expected: []sql.Row{
{types.OkResult{RowsAffected: 1, InsertID: 101}},
},
},
{
Query: "select * from t order by a",
Expected: []sql.Row{
{1, 1},
{2, 2},
{101, 101},
},
},
},
},
}
var DoltCherryPickTests = []queries.ScriptTest{
@@ -15,6 +15,8 @@
package globalstate
import (
"context"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
@@ -38,9 +40,10 @@ type AutoIncrementTracker interface {
// below the current value for this table. The table in the provided working set is assumed to already have the value
// given, so the new global maximum is computed without regard for its value in that working set.
Set(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error)
// AcquireTableLock acquires the auto increment lock on a table, and returns a callback function to release the lock.
// Depending on the value of the `innodb_autoinc_lock_mode` system variable, the engine may need to acquire and hold
// the lock for the duration of an insert statement.
AcquireTableLock(ctx *sql.Context, tableName string) (func(), error)
// InitWithRoots fills the AutoIncrementTracker with values pulled from each root in order.
InitWithRoots(ctx context.Context, roots ...doltdb.Rootish) error
}