Merge pull request #4071 from dolthub/zachmu/ai

global auto increment tracking
This commit is contained in:
Zach Musgrave
2022-08-17 16:34:18 -07:00
committed by GitHub
36 changed files with 964 additions and 366 deletions
+13 -7
View File
@@ -44,9 +44,12 @@ func CollectDBs(ctx context.Context, mrEnv *env.MultiRepoEnv, useBulkEditor bool
}
dEnv.DoltDB.SetCommitHooks(ctx, postCommitHooks)
db = newDatabase(name, dEnv, useBulkEditor)
db, err = newDatabase(ctx, name, dEnv, useBulkEditor)
if err != nil {
return false, err
}
if _, remote, ok := sql.SystemVariables.GetGlobal(dsess.ReadReplicaRemoteKey); ok && remote != "" {
if _, remote, ok := sql.SystemVariables.GetGlobal(dsess.ReadReplicaRemote); ok && remote != "" {
remoteName, ok := remote.(string)
if !ok {
return true, sql.ErrInvalidSystemVariableValue.New(remote)
@@ -89,7 +92,7 @@ func GetCommitHooks(ctx context.Context, dEnv *env.DoltEnv) ([]doltdb.CommitHook
return postCommitHooks, nil
}
func newDatabase(name string, dEnv *env.DoltEnv, useBulkEditor bool) sqle.Database {
func newDatabase(ctx context.Context, name string, dEnv *env.DoltEnv, useBulkEditor bool) (sqle.Database, error) {
deaf := dEnv.DbEaFactory()
if useBulkEditor {
deaf = dEnv.BulkDbEaFactory()
@@ -98,7 +101,7 @@ func newDatabase(name string, dEnv *env.DoltEnv, useBulkEditor bool) sqle.Databa
Deaf: deaf,
Tempdir: dEnv.TempTableFilesDir(),
}
return sqle.NewDatabase(name, dEnv.DbData(), opts)
return sqle.NewDatabase(ctx, name, dEnv.DbData(), opts)
}
// newReplicaDatabase creates a new dsqle.ReadReplicaDatabase. If the doltdb.SkipReplicationErrorsKey global variable is set,
@@ -110,7 +113,10 @@ func newReplicaDatabase(ctx context.Context, name string, remoteName string, dEn
Tempdir: dEnv.TempTableFilesDir(),
}
db := sqle.NewDatabase(name, dEnv.DbData(), opts)
db, err := sqle.NewDatabase(ctx, name, dEnv.DbData(), opts)
if err != nil {
return sqle.ReadReplicaDatabase{}, err
}
rrd, err := sqle.NewReadReplicaDatabase(ctx, db, remoteName, dEnv)
if err != nil {
@@ -125,9 +131,9 @@ func newReplicaDatabase(ctx context.Context, name string, remoteName string, dEn
}
func getPushOnWriteHook(ctx context.Context, dEnv *env.DoltEnv) (*doltdb.PushOnWriteHook, error) {
_, val, ok := sql.SystemVariables.GetGlobal(dsess.ReplicateToRemoteKey)
_, val, ok := sql.SystemVariables.GetGlobal(dsess.ReplicateToRemote)
if !ok {
return nil, sql.ErrUnknownSystemVariable.New(dsess.ReplicateToRemoteKey)
return nil, sql.ErrUnknownSystemVariable.New(dsess.ReplicateToRemote)
} else if val == "" {
return nil, nil
}
+4 -1
View File
@@ -226,7 +226,10 @@ func processFilterQuery(ctx context.Context, dEnv *env.DoltEnv, cm *doltdb.Commi
// Some functionality will not work on this kind of engine, e.g. many DOLT_ functions.
func rebaseSqlEngine(ctx context.Context, dEnv *env.DoltEnv, cm *doltdb.Commit) (*sql.Context, *engine.SqlEngine, error) {
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()}
db := dsqle.NewDatabase(dbName, dEnv.DbData(), opts)
db, err := dsqle.NewDatabase(ctx, dbName, dEnv.DbData(), opts)
if err != nil {
return nil, nil, err
}
mrEnv, err := env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), dEnv.FS, dEnv.Version, dEnv.IgnoreLockFile, dEnv)
if err != nil {
+2 -2
View File
@@ -568,8 +568,8 @@ func TestDelete(t *testing.T) {
func TestCommitHooksNoErrors(t *testing.T) {
dEnv := dtestutils.CreateEnvWithSeedData(t)
sqle.AddDoltSystemVariables()
sql.SystemVariables.SetGlobal(dsess.SkipReplicationErrorsKey, true)
sql.SystemVariables.SetGlobal(dsess.ReplicateToRemoteKey, "unknown")
sql.SystemVariables.SetGlobal(dsess.SkipReplicationErrors, true)
sql.SystemVariables.SetGlobal(dsess.ReplicateToRemote, "unknown")
hooks, err := engine.GetCommitHooks(context.Background(), dEnv)
assert.NoError(t, err)
if len(hooks) < 1 {
@@ -407,7 +407,7 @@ func TestReadReplica(t *testing.T) {
if !ok {
t.Fatal("local config does not exist")
}
config.NewPrefixConfig(localCfg, env.SqlServerGlobalsPrefix).SetStrings(map[string]string{dsess.ReadReplicaRemoteKey: "remote1", dsess.ReplicateHeadsKey: "main,feature"})
config.NewPrefixConfig(localCfg, env.SqlServerGlobalsPrefix).SetStrings(map[string]string{dsess.ReadReplicaRemote: "remote1", dsess.ReplicateHeads: "main,feature"})
dsess.InitPersistedSystemVars(multiSetup.MrEnv.GetEnv(readReplicaDbName))
// start server as read replica
@@ -182,7 +182,9 @@ func (q Query) Exec(t *testing.T, dEnv *env.DoltEnv) error {
root, err := dEnv.WorkingRoot(context.Background())
require.NoError(t, err)
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()}
sqlDb := dsqle.NewDatabase("dolt", dEnv.DbData(), opts)
sqlDb, err := dsqle.NewDatabase(context.Background(), "dolt", dEnv.DbData(), opts)
require.NoError(t, err)
engine, sqlCtx, err := dsqle.NewTestEngine(t, dEnv, context.Background(), sqlDb, root)
require.NoError(t, err)
@@ -578,7 +578,9 @@ func TestDropPks(t *testing.T) {
ctx := context.Background()
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()}
db := NewDatabase("dolt", dEnv.DbData(), opts)
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(), opts)
require.NoError(t, err)
root, _ := dEnv.WorkingRoot(ctx)
engine, sqlCtx, err := NewTestEngine(t, dEnv, ctx, db, root)
require.NoError(t, err)
@@ -118,7 +118,9 @@ func parseTime(timestampLayout bool, value string) time.Time {
func executeSelect(t *testing.T, ctx context.Context, dEnv *env.DoltEnv, root *doltdb.RootValue, query string) ([]interface{}, error) {
var err error
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()}
db := sqle.NewDatabase("dolt", dEnv.DbData(), opts)
db, err := sqle.NewDatabase(ctx, "dolt", dEnv.DbData(), opts)
require.NoError(t, err)
engine, sqlCtx, err := sqle.NewTestEngine(t, dEnv, ctx, db, root)
if err != nil {
return nil, err
@@ -147,7 +149,9 @@ func executeSelect(t *testing.T, ctx context.Context, dEnv *env.DoltEnv, root *d
func executeModify(t *testing.T, ctx context.Context, dEnv *env.DoltEnv, root *doltdb.RootValue, query string) (*doltdb.RootValue, error) {
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()}
db := sqle.NewDatabase("dolt", dEnv.DbData(), opts)
db, err := sqle.NewDatabase(ctx, "dolt", dEnv.DbData(), opts)
require.NoError(t, err)
engine, sqlCtx, err := sqle.NewTestEngine(t, dEnv, ctx, db, root)
if err != nil {
return nil, err
+6 -2
View File
@@ -42,7 +42,9 @@ type SetupFn func(t *testing.T, dEnv *env.DoltEnv)
func executeSelect(t *testing.T, ctx context.Context, dEnv *env.DoltEnv, root *doltdb.RootValue, query string) ([]sql.Row, sql.Schema, error) {
var err error
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()}
db := NewDatabase("dolt", dEnv.DbData(), opts)
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(), opts)
require.NoError(t, err)
engine, sqlCtx, err := NewTestEngine(t, dEnv, ctx, db, root)
if err != nil {
return nil, nil, err
@@ -69,7 +71,9 @@ func executeSelect(t *testing.T, ctx context.Context, dEnv *env.DoltEnv, root *d
// Runs the query given and returns the error (if any).
func executeModify(t *testing.T, ctx context.Context, dEnv *env.DoltEnv, root *doltdb.RootValue, query string) (*doltdb.RootValue, error) {
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()}
db := NewDatabase("dolt", dEnv.DbData(), opts)
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(), opts)
require.NoError(t, err)
engine, sqlCtx, err := NewTestEngine(t, dEnv, ctx, db, root)
if err != nil {
+70 -11
View File
@@ -164,15 +164,20 @@ var _ sql.TransactionDatabase = Database{}
var _ globalstate.StateProvider = Database{}
// NewDatabase returns a new dolt database to use in queries.
func NewDatabase(name string, dbData env.DbData, editOpts editor.Options) Database {
func NewDatabase(ctx context.Context, name string, dbData env.DbData, editOpts editor.Options) (Database, error) {
globalState, err := globalstate.NewGlobalStateStoreForDb(ctx, dbData.Ddb)
if err != nil {
return Database{}, err
}
return Database{
name: name,
ddb: dbData.Ddb,
rsr: dbData.Rsr,
rsw: dbData.Rsw,
gs: globalstate.NewGlobalStateStore(),
gs: globalState,
editOpts: editOpts,
}
}, nil
}
// GetInitialDBState returns the InitialDbState for |db|.
@@ -724,12 +729,13 @@ func (db Database) DropTable(ctx *sql.Context, tableName string) error {
return nil
}
root, err := db.GetRoot(ctx)
ws, err := db.GetWorkingSet(ctx)
if err != nil {
return err
}
tableExists, err := root.HasTable(ctx, tableName)
root := ws.WorkingRoot()
tbl, tableExists, err := root.GetTable(ctx, tableName)
if err != nil {
return err
}
@@ -743,20 +749,73 @@ func (db Database) DropTable(ctx *sql.Context, tableName string) error {
return err
}
ws, err := ds.WorkingSet(ctx, db.Name())
sch, err := tbl.GetSchema(ctx)
if err != nil {
return err
}
ait, err := db.gs.GetAutoIncrementTracker(ctx, ws)
if err != nil {
return err
if schema.HasAutoIncrement(sch) {
ddb, _ := ds.GetDoltDB(ctx, db.name)
err = db.removeTableFromAutoIncrementTracker(ctx, tableName, ddb, ws.Ref())
if err != nil {
return err
}
}
ait.DropTable(tableName)
return db.SetRoot(ctx, newRoot)
}
// removeTableFromAutoIncrementTracker updates the global auto increment tracking as necessary to deal with the table
// given being dropped or truncated. The auto increment value for this table after this operation will either be reset
// back to 1 if this table only exists in the working set given, or to the highest value in all other working sets
// otherwise. This operation is expensive if there are many branches, since it resolves the working set of every branch.
func (db Database) removeTableFromAutoIncrementTracker(
ctx *sql.Context,
tableName string,
ddb *doltdb.DoltDB,
ws ref.WorkingSetRef,
) error {
branches, err := ddb.GetBranches(ctx)
if err != nil {
return err
}
var wses []*doltdb.WorkingSet
for _, b := range branches {
wsRef, err := ref.WorkingSetRefForHead(b)
if err != nil {
return err
}
if wsRef == ws {
// skip this branch, we've deleted it here
continue
}
ws, err := ddb.ResolveWorkingSet(ctx, wsRef)
if err == doltdb.ErrWorkingSetNotFound {
// skip, continue working on other branches
continue
} else if err != nil {
return err
}
wses = append(wses, ws)
}
ait, err := db.gs.GetAutoIncrementTracker(ctx)
if err != nil {
return err
}
err = ait.DropTable(ctx, tableName, wses...)
if err != nil {
return err
}
return nil
}
// CreateTable creates a table with the name and schema given.
func (db Database) CreateTable(ctx *sql.Context, tableName string, sch sql.PrimaryKeySchema) error {
if strings.ToLower(tableName) == doltdb.DocTableName {
@@ -805,7 +864,7 @@ func (db Database) createSqlTable(ctx *sql.Context, tableName string, sch sql.Pr
}
if schema.HasAutoIncrement(doltSch) {
ait, err := db.gs.GetAutoIncrementTracker(ctx, ws)
ait, err := db.gs.GetAutoIncrementTracker(ctx)
if err != nil {
return err
}
@@ -261,7 +261,11 @@ func (p DoltDatabaseProvider) CreateDatabase(ctx *sql.Context, name string) erro
ForeignKeyChecksDisabled: fkChecks.(int8) == 0,
}
db := NewDatabase(name, newEnv.DbData(), opts)
db, err := NewDatabase(ctx, name, newEnv.DbData(), opts)
if err != nil {
return err
}
formattedName := formatDbMapKeyName(db.Name())
p.databases[formattedName] = db
p.dbLocations[formattedName] = newEnv.FS
@@ -344,7 +348,11 @@ func (p DoltDatabaseProvider) cloneDatabaseFromRemote(ctx *sql.Context, dbName,
ForeignKeyChecksDisabled: fkChecks.(int8) == 0,
}
db := NewDatabase(dbName, dEnv.DbData(), opts)
db, err := NewDatabase(ctx, dbName, dEnv.DbData(), opts)
if err != nil {
return err
}
p.databases[formatDbMapKeyName(db.Name())] = db
dbstate, err := GetInitialDBState(ctx, db)
@@ -129,19 +129,19 @@ func DoDoltBackup(ctx *sql.Context, args []string) (int, error) {
return statusErr, err
}
credsFile, _ := sess.GetSessionVariable(ctx, dsess.AwsCredsFileKey)
credsFile, _ := sess.GetSessionVariable(ctx, dsess.AwsCredsFile)
credsFileStr, isStr := credsFile.(string)
if isStr && len(credsFileStr) > 0 {
params[dbfactory.AWSCredsFileParam] = credsFileStr
}
credsProfile, err := sess.GetSessionVariable(ctx, dsess.AwsCredsProfileKey)
credsProfile, err := sess.GetSessionVariable(ctx, dsess.AwsCredsProfile)
profStr, isStr := credsProfile.(string)
if isStr && len(profStr) > 0 {
params[dbfactory.AWSCredsProfile] = profStr
}
credsRegion, err := sess.GetSessionVariable(ctx, dsess.AwsCredsRegionKey)
credsRegion, err := sess.GetSessionVariable(ctx, dsess.AwsCredsRegion)
regionStr, isStr := credsRegion.(string)
if isStr && len(regionStr) > 0 {
params[dbfactory.AWSRegionParam] = regionStr
+2 -14
View File
@@ -44,18 +44,6 @@ const (
Batched
)
const (
ReplicateToRemoteKey = "dolt_replicate_to_remote"
ReadReplicaRemoteKey = "dolt_read_replica_remote"
SkipReplicationErrorsKey = "dolt_skip_replication_errors"
ReplicateHeadsKey = "dolt_replicate_heads"
ReplicateAllHeadsKey = "dolt_replicate_all_heads"
AsyncReplicationKey = "dolt_async_replication"
AwsCredsFileKey = "aws_credentials_file"
AwsCredsProfileKey = "aws_credentials_profile"
AwsCredsRegionKey = "aws_credentials_region"
)
var ErrWorkingSetChanges = goerrors.NewKind("Cannot switch working set, session state is dirty. " +
"Rollback or commit changes before changing working sets.")
var ErrSessionNotPeristable = errors.New("session is not persistable")
@@ -782,7 +770,7 @@ func (d *DoltSession) SwitchWorkingSet(
// make a fresh WriteSession, discard existing WriteSession
opts := sessionState.WriteSession.GetOptions()
nbf := ws.WorkingRoot().VRW().Format()
tracker, err := sessionState.globalState.GetAutoIncrementTracker(ctx, ws)
tracker, err := sessionState.globalState.GetAutoIncrementTracker(ctx)
if err != nil {
return err
}
@@ -942,7 +930,7 @@ func (d *DoltSession) AddDB(ctx *sql.Context, dbState InitialDbState) error {
} else if dbState.WorkingSet != nil {
sessionState.WorkingSet = dbState.WorkingSet
tracker, err := sessionState.globalState.GetAutoIncrementTracker(ctx, sessionState.WorkingSet)
tracker, err := sessionState.globalState.GetAutoIncrementTracker(ctx)
if err != nil {
return err
}
+28 -77
View File
@@ -15,11 +15,13 @@
package dsess
import (
"fmt"
"strings"
"github.com/dolthub/go-mysql-server/sql"
)
// Per-DB system variables
const (
HeadKeySuffix = "_head"
HeadRefKeySuffix = "_head_ref"
@@ -28,83 +30,24 @@ const (
DefaultBranchKeySuffix = "_default_branch"
)
// General system variables
const (
DoltCommitOnTransactionCommit = "dolt_transaction_commit"
TransactionsDisabledSysVar = "dolt_transactions_disabled"
ForceTransactionCommit = "dolt_force_transaction_commit"
CurrentBatchModeKey = "batch_mode"
AllowCommitConflicts = "dolt_allow_commit_conflicts"
ReplicateToRemote = "dolt_replicate_to_remote"
ReadReplicaRemote = "dolt_read_replica_remote"
SkipReplicationErrors = "dolt_skip_replication_errors"
ReplicateHeads = "dolt_replicate_heads"
ReplicateAllHeads = "dolt_replicate_all_heads"
AsyncReplication = "dolt_async_replication"
AwsCredsFile = "aws_credentials_file"
AwsCredsProfile = "aws_credentials_profile"
AwsCredsRegion = "aws_credentials_region"
)
func init() {
sql.SystemVariables.AddSystemVariables([]sql.SystemVariable{
{ // If true, causes a Dolt commit to occur when you commit a transaction.
Name: DoltCommitOnTransactionCommit,
Scope: sql.SystemVariableScope_Both,
Dynamic: true,
SetVarHintApplies: false,
Type: sql.NewSystemBoolType(DoltCommitOnTransactionCommit),
Default: int8(0),
},
{
Name: TransactionsDisabledSysVar,
Scope: sql.SystemVariableScope_Session,
Dynamic: true,
SetVarHintApplies: false,
Type: sql.NewSystemBoolType(TransactionsDisabledSysVar),
Default: int8(0),
},
{ // If true, disables the conflict and constraint violation check when you commit a transaction.
Name: ForceTransactionCommit,
Scope: sql.SystemVariableScope_Both,
Dynamic: true,
SetVarHintApplies: false,
Type: sql.NewSystemBoolType(ForceTransactionCommit),
Default: int8(0),
},
{
Name: CurrentBatchModeKey,
Scope: sql.SystemVariableScope_Session,
Dynamic: true,
SetVarHintApplies: false,
Type: sql.NewSystemIntType(CurrentBatchModeKey, -9223372036854775808, 9223372036854775807, false),
Default: int64(0),
},
{ // If true, disables the conflict violation check when you commit a transaction.
Name: AllowCommitConflicts,
Scope: sql.SystemVariableScope_Session,
Dynamic: true,
SetVarHintApplies: false,
Type: sql.NewSystemBoolType(AllowCommitConflicts),
Default: int8(0),
},
{
Name: AwsCredsFileKey,
Scope: sql.SystemVariableScope_Session,
Dynamic: false,
SetVarHintApplies: false,
Type: sql.NewSystemStringType(AwsCredsFileKey),
Default: nil,
},
{
Name: AwsCredsProfileKey,
Scope: sql.SystemVariableScope_Session,
Dynamic: false,
SetVarHintApplies: false,
Type: sql.NewSystemStringType(AwsCredsProfileKey),
Default: nil,
},
{
Name: AwsCredsRegionKey,
Scope: sql.SystemVariableScope_Session,
Dynamic: false,
SetVarHintApplies: false,
Type: sql.NewSystemStringType(AwsCredsRegionKey),
Default: nil,
},
})
}
// DefineSystemVariablesForDB defines per database dolt-session variables in the engine as necessary
func DefineSystemVariablesForDB(name string) {
if _, _, ok := sql.SystemVariables.GetGlobal(name + HeadKeySuffix); !ok {
@@ -199,16 +142,24 @@ func IsWorkingKey(key string) (bool, string) {
return false, ""
}
func IsDefaultBranchKey(key string) (bool, string) {
if strings.HasSuffix(key, DefaultBranchKeySuffix) {
return true, key[:len(key)-len(DefaultBranchKeySuffix)]
}
return false, ""
}
func IsReadOnlyVersionKey(key string) bool {
return strings.HasSuffix(key, HeadKeySuffix) ||
strings.HasSuffix(key, StagedKeySuffix) ||
strings.HasSuffix(key, WorkingKeySuffix)
}
// GetBooleanSystemVar returns a boolean value for the system variable named, returning an error if the variable
// doesn't exist in the session or has a non-boolean type.
func GetBooleanSystemVar(ctx *sql.Context, varName string) (bool, error) {
val, err := ctx.GetSessionVariable(ctx, varName)
if err != nil {
return false, err
}
i8, isInt8 := val.(int8)
if !isInt8 {
return false, fmt.Errorf("unexpected type for variable %s: %T", varName, val)
}
return i8 == 1, nil
}
@@ -89,58 +89,80 @@ func TestSingleScript(t *testing.T) {
var scripts = []queries.ScriptTest{
{
Name: "Temp playground for collation testing",
Name: "truncate table",
SetUpScript: []string{
"CREATE TABLE test1 (pk BIGINT PRIMARY KEY, v1 VARCHAR(255) COLLATE utf16_unicode_ci, INDEX(v1));",
"CREATE TABLE test2 (pk BIGINT PRIMARY KEY, v1 VARCHAR(255) COLLATE utf8mb4_0900_bin, INDEX(v1));",
"INSERT INTO test1 VALUES (1, 'abc'), (2, 'ABC'), (3, 'aBc'), (4, 'AbC');",
"INSERT INTO test2 VALUES (1, 'abc'), (2, 'ABC'), (3, 'aBc'), (4, 'AbC');",
"create table t (a int primary key auto_increment, b int)",
"call dolt_commit('-am', 'empty table')",
"call dolt_branch('branch1')",
"call dolt_branch('branch2')",
"insert into t (b) values (1), (2)",
"call dolt_commit('-am', 'two values on main')",
"call dolt_checkout('branch1')",
"insert into t (b) values (3), (4)",
"call dolt_commit('-am', 'two values on branch1')",
"call dolt_checkout('branch2')",
"insert into t (b) values (5), (6)",
"call dolt_checkout('branch1')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "SELECT v1, pk FROM test1 ORDER BY pk;",
Query: "truncate table t",
Expected: []sql.Row{{sql.NewOkResult(2)}},
},
{
Query: "call dolt_checkout('main')",
SkipResultsCheck: true,
},
{
// highest value in any branch is 6
Query: "insert into t (b) values (7), (8)",
Expected: []sql.Row{{sql.OkResult{RowsAffected: 2, InsertID: 7}}},
},
{
Query: "select * from t order by a",
Expected: []sql.Row{
{"abc", 1}, {"ABC", 2}, {"aBc", 3}, {"AbC", 4},
{1, 1},
{2, 2},
{7, 7},
{8, 8},
},
},
{
Query: "SELECT v1, pk FROM test1 ORDER BY v1, pk;",
Query: "truncate table t",
Expected: []sql.Row{{sql.NewOkResult(4)}},
},
{
Query: "call dolt_checkout('branch2')",
SkipResultsCheck: true,
},
{
// highest value in any branch is still 6 (truncated table above)
Query: "insert into t (b) values (7), (8)",
Expected: []sql.Row{{sql.OkResult{RowsAffected: 2, InsertID: 7}}},
},
{
Query: "select * from t order by a",
Expected: []sql.Row{
{"abc", 1}, {"ABC", 2}, {"aBc", 3}, {"AbC", 4},
{5, 5},
{6, 6},
{7, 7},
{8, 8},
},
},
{
Query: "SELECT v1, pk FROM test1 WHERE v1 > 'AbC' ORDER BY v1, pk;",
Expected: []sql.Row{},
Query: "truncate table t",
Expected: []sql.Row{{sql.NewOkResult(4)}},
},
{
Query: "SELECT v1, pk FROM test1 WHERE v1 >= 'AbC' ORDER BY v1, pk;",
Expected: []sql.Row{
{"abc", 1}, {"ABC", 2}, {"aBc", 3}, {"AbC", 4},
},
// no value on any branch
Query: "insert into t (b) values (1), (2)",
Expected: []sql.Row{{sql.OkResult{RowsAffected: 2, InsertID: 1}}},
},
{
Query: "SELECT v1, pk FROM test2 ORDER BY pk;",
Query: "select * from t order by a",
Expected: []sql.Row{
{"abc", 1}, {"ABC", 2}, {"aBc", 3}, {"AbC", 4},
},
},
{
Query: "SELECT v1, pk FROM test2 ORDER BY v1, pk;",
Expected: []sql.Row{
{"ABC", 2}, {"AbC", 4}, {"aBc", 3}, {"abc", 1},
},
},
{
Query: "SELECT v1, pk FROM test2 WHERE v1 > 'AbC' ORDER BY v1, pk;",
Expected: []sql.Row{
{"aBc", 3}, {"abc", 1},
},
},
{
Query: "SELECT v1, pk FROM test2 WHERE v1 >= 'AbC' ORDER BY v1, pk;",
Expected: []sql.Row{
{"AbC", 4}, {"aBc", 3}, {"abc", 1},
{1, 1},
{2, 2},
},
},
},
@@ -746,6 +768,34 @@ func TestDoltMerge(t *testing.T) {
}
}
func TestDoltAutoIncrement(t *testing.T) {
for _, script := range DoltAutoIncrementTests {
// doing commits on different branches is antagonistic to engine reuse, use a new engine on each script
enginetest.TestScript(t, newDoltHarness(t), script)
}
for _, script := range BrokenAutoIncrementTests {
t.Run(script.Name, func(t *testing.T) {
t.Skip()
enginetest.TestScript(t, newDoltHarness(t), script)
})
}
}
func TestDoltAutoIncrementPrepared(t *testing.T) {
for _, script := range DoltAutoIncrementTests {
// doing commits on different branches is antagonistic to engine reuse, use a new engine on each script
enginetest.TestScriptPrepared(t, newDoltHarness(t), script)
}
for _, script := range BrokenAutoIncrementTests {
t.Run(script.Name, func(t *testing.T) {
t.Skip()
enginetest.TestScriptPrepared(t, newDoltHarness(t), script)
})
}
}
func TestDoltConflictsTableNameTable(t *testing.T) {
for _, script := range DoltConflictTableNameTableTests {
enginetest.TestScript(t, newDoltHarness(t), script)
@@ -34,7 +34,6 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/store/types"
@@ -46,20 +45,19 @@ const (
)
type DoltHarness struct {
t *testing.T
multiRepoEnv *env.MultiRepoEnv
createdEnvs map[string]*env.DoltEnv
session *dsess.DoltSession
databases []sqle.Database
databaseGlobalStates []globalstate.GlobalState
hashes []string
parallelism int
skippedQueries []string
setupData []setup.SetupScript
resetData []setup.SetupScript
initDbs map[string]struct{}
autoInc bool
engine *gms.Engine
t *testing.T
multiRepoEnv *env.MultiRepoEnv
createdEnvs map[string]*env.DoltEnv
session *dsess.DoltSession
databases []sqle.Database
hashes []string
parallelism int
skippedQueries []string
setupData []setup.SetupScript
resetData []setup.SetupScript
initDbs map[string]struct{}
autoInc bool
engine *gms.Engine
}
var _ enginetest.Harness = (*DoltHarness)(nil)
@@ -114,19 +112,51 @@ func (d *DoltHarness) Setup(setupData ...[]setup.SetupScript) {
// resetScripts returns a set of queries that will reset the given database
// names. If [autoInc], the queries for resetting autoincrement tables are
// included.
func resetScripts(dbs []string, autoInc bool) []setup.SetupScript {
var resetCmds setup.SetupScript
func (d *DoltHarness) resetScripts() []setup.SetupScript {
ctx := enginetest.NewContext(d)
_, res := enginetest.MustQuery(ctx, d.engine, "select schema_name from information_schema.schemata where schema_name not in ('information_schema');")
var dbs []string
for i := range res {
dbs = append(dbs, res[i][0].(string))
}
var resetCmds []setup.SetupScript
for i := range dbs {
db := dbs[i]
resetCmds = append(resetCmds, fmt.Sprintf("use %s", db))
resetCmds = append(resetCmds, "call dclean()")
resetCmds = append(resetCmds, "call dreset('--hard', 'head')")
if autoInc {
resetCmds = append(resetCmds, setup.AutoincrementData[0]...)
resetCmds = append(resetCmds, setup.SetupScript{fmt.Sprintf("use %s", db)})
// Any auto increment tables must be dropped and recreated to get a fresh state for the global auto increment
// sequence trackers
_, aiTables := enginetest.MustQuery(ctx, d.engine,
fmt.Sprintf("select distinct table_name from information_schema.columns where extra = 'auto_increment' and table_schema = '%s';", db))
for _, tableNameRow := range aiTables {
tableName := tableNameRow[0].(string)
// special handling for auto_increment_tbl, which is expected to start with particular values
if strings.ToLower(tableName) == "auto_increment_tbl" {
resetCmds = append(resetCmds, setup.AutoincrementData...)
continue
}
resetCmds = append(resetCmds, setup.SetupScript{fmt.Sprintf("drop table %s", tableName)})
ctx := enginetest.NewContext(d).WithCurrentDB(db)
_, showCreateResult := enginetest.MustQuery(ctx, d.engine, fmt.Sprintf("show create table %s;", tableName))
var createTableStatement strings.Builder
for _, row := range showCreateResult {
createTableStatement.WriteString(row[1].(string))
}
resetCmds = append(resetCmds, setup.SetupScript{createTableStatement.String()})
}
resetCmds = append(resetCmds, setup.SetupScript{"call dclean()"})
resetCmds = append(resetCmds, setup.SetupScript{"call dreset('--hard', 'head')"})
}
resetCmds = append(resetCmds, "use mydb")
return []setup.SetupScript{resetCmds}
resetCmds = append(resetCmds, setup.SetupScript{"use mydb"})
return resetCmds
}
// commitScripts returns a set of queries that will commit the workingsets
@@ -161,11 +191,7 @@ func (d *DoltHarness) NewEngine(t *testing.T) (*gms.Engine, error) {
d.engine = e
var res []sql.Row
// todo(max): need better way to reset autoincrement regardless of test type
ctx := enginetest.NewContext(d)
_, res = enginetest.MustQuery(ctx, e, "select count(*) from information_schema.tables where table_name = 'auto_increment_tbl';")
d.autoInc = res[0][0].(int64) > 0
_, res = enginetest.MustQuery(ctx, e, "select schema_name from information_schema.schemata where schema_name not in ('information_schema');")
var dbs []string
for i := range res {
@@ -186,13 +212,9 @@ func (d *DoltHarness) NewEngine(t *testing.T) (*gms.Engine, error) {
//todo(max): easier if tests specify their databases ahead of time
ctx := enginetest.NewContext(d)
_, res := enginetest.MustQuery(ctx, d.engine, "select schema_name from information_schema.schemata where schema_name not in ('information_schema');")
var dbs []string
for i := range res {
dbs = append(dbs, res[i][0].(string))
}
e, err := enginetest.RunEngineScripts(ctx, d.engine, d.resetScripts(), d.SupportsNativeIndexCreation())
return enginetest.RunEngineScripts(ctx, d.engine, resetScripts(dbs, d.autoInc), d.SupportsNativeIndexCreation())
return e, err
}
// WithParallelism returns a copy of the harness with parallelism set to the given number of threads. A value of 0 or
@@ -290,7 +312,6 @@ func (d *DoltHarness) NewDatabase(name string) sql.Database {
func (d *DoltHarness) NewDatabases(names ...string) []sql.Database {
d.databases = nil
d.databaseGlobalStates = nil
for _, name := range names {
dEnv := dtestutils.CreateTestEnvWithName(name)
@@ -298,11 +319,10 @@ func (d *DoltHarness) NewDatabases(names ...string) []sql.Database {
store.SetValidateContentAddresses(true)
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()}
db := sqle.NewDatabase(name, dEnv.DbData(), opts)
d.databases = append(d.databases, db)
db, err := sqle.NewDatabase(context.Background(), name, dEnv.DbData(), opts)
require.NoError(d.t, err)
globalState := globalstate.NewGlobalStateStore()
d.databaseGlobalStates = append(d.databaseGlobalStates, globalState)
d.databases = append(d.databases, db)
d.multiRepoEnv.AddOrReplaceEnv(name, dEnv)
d.createdEnvs[db.Name()] = dEnv
@@ -1736,7 +1736,7 @@ var MergeScripts = []queries.ScriptTest{
"INSERT INTO t (pk,c0) VALUES (3,3), (4,4);",
"CALL dolt_commit('-a', '-m', 'cm2');",
"CALL dolt_checkout('main');",
"INSERT INTO t (c0) VALUES (2);",
"INSERT INTO t (c0) VALUES (5);",
"CALL dolt_commit('-a', '-m', 'cm3');",
},
Assertions: []queries.ScriptTestAssertion{
@@ -1745,19 +1745,19 @@ var MergeScripts = []queries.ScriptTest{
Expected: []sql.Row{{0, 0}},
},
{
Query: "INSERT INTO t VALUES (NULL,5),(6,6),(NULL,7);",
Expected: []sql.Row{{sql.OkResult{RowsAffected: 3, InsertID: 5}}},
Query: "INSERT INTO t VALUES (NULL,6),(7,7),(NULL,8);",
Expected: []sql.Row{{sql.OkResult{RowsAffected: 3, InsertID: 6}}},
},
{
Query: "SELECT * FROM t ORDER BY pk;",
Expected: []sql.Row{
{1, 1},
{2, 2},
{3, 3},
{4, 4},
{5, 5},
{6, 6},
{7, 7},
{8, 8},
},
},
},
@@ -1807,7 +1807,7 @@ var MergeScripts = []queries.ScriptTest{
"INSERT INTO t VALUES (4,4), (5,5);",
"CALL dolt_commit('-am', 'cm2');",
"CALL dolt_checkout('main');",
"INSERT INTO t (c0) VALUES (2);",
"INSERT INTO t (c0) VALUES (6);",
"CALL dolt_commit('-am', 'cm3');",
},
Assertions: []queries.ScriptTestAssertion{
@@ -1816,18 +1816,18 @@ var MergeScripts = []queries.ScriptTest{
Expected: []sql.Row{{0, 0}},
},
{
Query: "INSERT INTO t VALUES (3,3),(NULL,6);",
Query: "INSERT INTO t VALUES (3,3),(NULL,7);",
Expected: []sql.Row{{sql.OkResult{RowsAffected: 2, InsertID: 3}}},
},
{
Query: "SELECT * FROM t ORDER BY pk;",
Expected: []sql.Row{
{1, 1},
{2, 2},
{3, 3},
{4, 4},
{5, 5},
{6, 6},
{7, 7},
},
},
},
@@ -5243,3 +5243,228 @@ var DoltRemoteTestScripts = []queries.ScriptTest{
},
},
}
// DoltAutoIncrementTests is tests of dolt's global auto increment logic
var DoltAutoIncrementTests = []queries.ScriptTest{
{
Name: "insert on different branches",
SetUpScript: []string{
"create table t (a int primary key auto_increment, b int)",
"call dolt_commit('-am', 'empty table')",
"call dolt_branch('branch1')",
"call dolt_branch('branch2')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "insert into t (b) values (1), (2)",
Expected: []sql.Row{{sql.OkResult{RowsAffected: 2, InsertID: 1}}},
},
{
Query: "call dolt_commit('-am', 'two values on main')",
SkipResultsCheck: true,
},
{
Query: "call dolt_checkout('branch1')",
SkipResultsCheck: true,
},
{
Query: "insert into t (b) values (3), (4)",
Expected: []sql.Row{{sql.OkResult{RowsAffected: 2, InsertID: 3}}},
},
{
Query: "select * from t order by a",
Expected: []sql.Row{
{3, 3},
{4, 4},
},
},
{
Query: "call dolt_commit('-am', 'two values on branch1')",
SkipResultsCheck: true,
},
{
Query: "call dolt_checkout('branch2')",
SkipResultsCheck: true,
},
{
Query: "insert into t (b) values (5), (6)",
Expected: []sql.Row{{sql.OkResult{RowsAffected: 2, InsertID: 5}}},
},
{
Query: "select * from t order by a",
Expected: []sql.Row{
{5, 5},
{6, 6},
},
},
},
},
{
Name: "drop table",
SetUpScript: []string{
"create table t (a int primary key auto_increment, b int)",
"call dolt_commit('-am', 'empty table')",
"call dolt_branch('branch1')",
"call dolt_branch('branch2')",
"insert into t (b) values (1), (2)",
"call dolt_commit('-am', 'two values on main')",
"call dolt_checkout('branch1')",
"insert into t (b) values (3), (4)",
"call dolt_commit('-am', 'two values on branch1')",
"call dolt_checkout('branch2')",
"insert into t (b) values (5), (6)",
"call dolt_checkout('branch1')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "drop table t",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "call dolt_checkout('main')",
SkipResultsCheck: true,
},
{
// highest value in any branch is 6
Query: "insert into t (b) values (7), (8)",
Expected: []sql.Row{{sql.OkResult{RowsAffected: 2, InsertID: 7}}},
},
{
Query: "select * from t order by a",
Expected: []sql.Row{
{1, 1},
{2, 2},
{7, 7},
{8, 8},
},
},
{
Query: "drop table t",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "call dolt_checkout('branch2')",
SkipResultsCheck: true,
},
{
// highest value in any branch is still 6 (dropped table above)
Query: "insert into t (b) values (7), (8)",
Expected: []sql.Row{{sql.OkResult{RowsAffected: 2, InsertID: 7}}},
},
{
Query: "select * from t order by a",
Expected: []sql.Row{
{5, 5},
{6, 6},
{7, 7},
{8, 8},
},
},
{
Query: "drop table t",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "create table t (a int primary key auto_increment, b int)",
SkipResultsCheck: true,
},
{
// no value on any branch
Query: "insert into t (b) values (1), (2)",
Expected: []sql.Row{{sql.OkResult{RowsAffected: 2, InsertID: 1}}},
},
{
Query: "select * from t order by a",
Expected: []sql.Row{
{1, 1},
{2, 2},
},
},
},
},
}
// BrokenAutoIncrementTests holds auto increment test scripts that currently fail; each element's
// leading comment describes the known bug it exercises.
var BrokenAutoIncrementTests = []queries.ScriptTest{
	{
		// truncate table doesn't reset the persisted auto increment counter of tables on other branches, which leads to
		// the value not resetting to 1 after a truncate if the table exists on other branches, even if truncated on every
		// branch
		Name: "truncate table",
		SetUpScript: []string{
			"create table t (a int primary key auto_increment, b int)",
			"call dolt_commit('-am', 'empty table')",
			"call dolt_branch('branch1')",
			"call dolt_branch('branch2')",
			"insert into t (b) values (1), (2)",
			"call dolt_commit('-am', 'two values on main')",
			"call dolt_checkout('branch1')",
			"insert into t (b) values (3), (4)",
			"call dolt_commit('-am', 'two values on branch1')",
			"call dolt_checkout('branch2')",
			"insert into t (b) values (5), (6)",
			"call dolt_checkout('branch1')",
		},
		Assertions: []queries.ScriptTestAssertion{
			{
				Query:    "truncate table t",
				Expected: []sql.Row{{sql.NewOkResult(2)}},
			},
			{
				Query:            "call dolt_checkout('main')",
				SkipResultsCheck: true,
			},
			{
				// highest value in any branch is 6
				Query:    "insert into t (b) values (7), (8)",
				Expected: []sql.Row{{sql.OkResult{RowsAffected: 2, InsertID: 7}}},
			},
			{
				Query: "select * from t order by a",
				Expected: []sql.Row{
					{1, 1},
					{2, 2},
					{7, 7},
					{8, 8},
				},
			},
			{
				Query:    "truncate table t",
				Expected: []sql.Row{{sql.NewOkResult(4)}},
			},
			{
				Query:            "call dolt_checkout('branch2')",
				SkipResultsCheck: true,
			},
			{
				// highest value in any branch is still 6 (truncated table above)
				Query:    "insert into t (b) values (7), (8)",
				Expected: []sql.Row{{sql.OkResult{RowsAffected: 2, InsertID: 7}}},
			},
			{
				Query: "select * from t order by a",
				Expected: []sql.Row{
					{5, 5},
					{6, 6},
					{7, 7},
					{8, 8},
				},
			},
			{
				Query:    "truncate table t",
				Expected: []sql.Row{{sql.NewOkResult(4)}},
			},
			{
				// no value on any branch
				Query:    "insert into t (b) values (1), (2)",
				Expected: []sql.Row{{sql.OkResult{RowsAffected: 2, InsertID: 1}}},
			},
			{
				Query: "select * from t order by a",
				Expected: []sql.Row{
					{1, 1},
					{2, 2},
				},
			},
		},
	},
}
@@ -17,76 +17,72 @@ package globalstate
import (
"context"
"math"
"strings"
"sync"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
)
// CoerceAutoIncrementValue converts |val| into an AUTO_INCREMENT sequence value.
// Floats are rounded to the nearest integer before conversion; nil and zero both
// yield 0, meaning "no explicit value supplied".
func CoerceAutoIncrementValue(val interface{}) (uint64, error) {
	// Round floating-point inputs to the nearest whole number first.
	if f32, isF32 := val.(float32); isF32 {
		val = math.Round(float64(f32))
	} else if f64, isF64 := val.(float64); isF64 {
		val = math.Round(f64)
	}

	converted, err := sql.Uint64.Convert(val)
	if err != nil {
		return 0, err
	}
	if converted == nil || converted == uint64(0) {
		return 0, nil
	}

	return converted.(uint64), nil
}
// NewAutoIncrementTracker returns a new autoincrement tracker for the working set given.
// NOTE(review): a variadic function with the same name appears later in this file (merge
// residue) — only one definition can remain; confirm which is current.
func NewAutoIncrementTracker(ctx context.Context, ws *doltdb.WorkingSet) (AutoIncrementTracker, error) {
	ait := AutoIncrementTracker{
		wsRef:     ws.Ref(),
		sequences: make(map[string]uint64),
		mu:        &sync.Mutex{},
	}

	// collect auto increment values from every table in the working root that has one
	err := ws.WorkingRoot().IterTables(ctx, func(name string, table *doltdb.Table, sch schema.Schema) (bool, error) {
		ok := schema.HasAutoIncrement(sch)
		if !ok {
			return false, nil
		}

		seq, err := table.GetAutoIncrementValue(ctx)
		if err != nil {
			// true stops the iteration and surfaces the error
			return true, err
		}

		ait.sequences[name] = seq
		return false, nil
	})

	return ait, err
}
// AutoIncrementTracker records the auto increment sequence value for each table.
type AutoIncrementTracker struct {
	wsRef     ref.WorkingSetRef // working set this tracker was built from
	sequences map[string]uint64 // table name -> sequence value
	mu        *sync.Mutex       // guards sequences
}
// NewAutoIncrementTracker returns a new autoincrement tracker for the working sets given. All working sets must be
// considered because the auto increment value for a table is tracked globally, across all branches.
func NewAutoIncrementTracker(ctx context.Context, wses ...*doltdb.WorkingSet) (AutoIncrementTracker, error) {
	tracker := AutoIncrementTracker{
		sequences: make(map[string]uint64),
		mu:        &sync.Mutex{},
	}

	for _, ws := range wses {
		iterErr := ws.WorkingRoot().IterTables(ctx, func(tableName string, table *doltdb.Table, sch schema.Schema) (bool, error) {
			if !schema.HasAutoIncrement(sch) {
				return false, nil
			}

			seq, err := table.GetAutoIncrementValue(ctx)
			if err != nil {
				// returning true halts iteration and surfaces the error
				return true, err
			}

			// Sequences are keyed case-insensitively; keep the max across branches.
			key := strings.ToLower(tableName)
			if seq > tracker.sequences[key] {
				tracker.sequences[key] = seq
			}
			return false, nil
		})
		if iterErr != nil {
			return AutoIncrementTracker{}, iterErr
		}
	}

	return tracker, nil
}
// Current returns the next value to be generated in the auto increment sequence for the table named.
func (a AutoIncrementTracker) Current(tableName string) uint64 {
	a.mu.Lock()
	defer a.mu.Unlock()
	// Merge artifact removed: a stale case-sensitive `return a.sequences[tableName]`
	// preceded this line, making the case-insensitive lookup unreachable. Sequences
	// are keyed by lower-cased table name throughout this tracker.
	return a.sequences[strings.ToLower(tableName)]
}
// Next returns the next auto increment value for the table named using the provided value from an insert (which may
// be null or 0, in which case it will be generated from the sequence).
func (a AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, error) {
a.mu.Lock()
defer a.mu.Unlock()
tbl = strings.ToLower(tbl)
given, err := CoerceAutoIncrementValue(insertVal)
if err != nil {
return 0, err
@@ -110,20 +106,90 @@ func (a AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, e
return given, nil
}
// CoerceAutoIncrementValue converts |val| into an AUTO_INCREMENT sequence value
func CoerceAutoIncrementValue(val interface{}) (uint64, error) {
	// Round floats to the nearest integer before converting to uint64.
	switch typ := val.(type) {
	case float32:
		val = math.Round(float64(typ))
	case float64:
		val = math.Round(typ)
	}

	var err error
	val, err = sql.Uint64.Convert(val)
	if err != nil {
		return 0, err
	}
	// nil and 0 both mean "no explicit value given"; the sequence supplies one.
	if val == nil || val == uint64(0) {
		return 0, nil
	}

	return val.(uint64), nil
}
// Set sets the auto increment value for the table named, if it's greater than the one already registered for this
// table. Otherwise, the update is silently disregarded. So far this matches the MySQL behavior, but Dolt uses the
// maximum value for this table across all branches.
func (a AutoIncrementTracker) Set(tableName string, val uint64) {
	a.mu.Lock()
	defer a.mu.Unlock()

	// Merge artifact removed: a stale unconditional `a.sequences[tableName] = val`
	// preceded this guard and would have clobbered a larger value from another branch.
	tableName = strings.ToLower(tableName)
	if val > a.sequences[tableName] {
		// tableName is already lower-cased; no need to lower it again in the index.
		a.sequences[tableName] = val
	}
}
// AddNewTable initializes a new table with an auto increment column to the tracker, as necessary
func (a AutoIncrementTracker) AddNewTable(tableName string) {
	a.mu.Lock()
	defer a.mu.Unlock()

	// Merge artifact removed: a stale unconditional `a.sequences[tableName] = uint64(1)`
	// preceded this guard and would have reset a sequence already seeded from another branch.
	tableName = strings.ToLower(tableName)
	// only initialize the sequence for this table if no other branch has such a table
	if _, ok := a.sequences[tableName]; !ok {
		a.sequences[tableName] = uint64(1)
	}
}
func (a AutoIncrementTracker) DropTable(tableName string) {
// DropTable drops the table with the name given.
// To establish the new auto increment value, callers must also pass all other working sets in scope that may include
// a table with the same name, omitting the working set that just deleted the table named.
func (a AutoIncrementTracker) DropTable(ctx context.Context, tableName string, wses ...*doltdb.WorkingSet) error {
a.mu.Lock()
defer a.mu.Unlock()
tableName = strings.ToLower(tableName)
delete(a.sequences, tableName)
// Get the new highest value from all tables in the working sets given
for _, ws := range wses {
table, _, exists, err := ws.WorkingRoot().GetTableInsensitive(ctx, tableName)
if err != nil {
return err
}
if !exists {
continue
}
sch, err := table.GetSchema(ctx)
if err != nil {
return err
}
if schema.HasAutoIncrement(sch) {
seq, err := table.GetAutoIncrementValue(ctx)
if err != nil {
return err
}
if seq > a.sequences[tableName] {
a.sequences[tableName] = seq
}
}
}
return nil
}
@@ -18,8 +18,9 @@ import (
"context"
"sync"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
)
@@ -27,33 +28,46 @@ type StateProvider interface {
GetGlobalState() GlobalState
}
func NewGlobalStateStore() GlobalState {
return GlobalState{
trackerMap: make(map[ref.WorkingSetRef]AutoIncrementTracker),
mu: &sync.Mutex{},
func NewGlobalStateStoreForDb(ctx context.Context, db *doltdb.DoltDB) (GlobalState, error) {
branches, err := db.GetBranches(ctx)
if err != nil {
return GlobalState{}, err
}
var wses []*doltdb.WorkingSet
for _, b := range branches {
wsRef, err := ref.WorkingSetRefForHead(b)
if err != nil {
return GlobalState{}, err
}
ws, err := db.ResolveWorkingSet(ctx, wsRef)
if err == doltdb.ErrWorkingSetNotFound {
// skip, continue working on other branches
continue
} else if err != nil {
return GlobalState{}, err
}
wses = append(wses, ws)
}
tracker, err := NewAutoIncrementTracker(ctx, wses...)
if err != nil {
return GlobalState{}, err
}
return GlobalState{
aiTracker: tracker,
mu: &sync.Mutex{},
}, nil
}
type GlobalState struct {
trackerMap map[ref.WorkingSetRef]AutoIncrementTracker
mu *sync.Mutex
aiTracker AutoIncrementTracker
mu *sync.Mutex
}
func (g GlobalState) GetAutoIncrementTracker(ctx context.Context, ws *doltdb.WorkingSet) (AutoIncrementTracker, error) {
g.mu.Lock()
defer g.mu.Unlock()
ait, ok := g.trackerMap[ws.Ref()]
if ok {
return ait, nil
}
var err error
ait, err = NewAutoIncrementTracker(ctx, ws)
if err != nil {
return AutoIncrementTracker{}, err
}
g.trackerMap[ws.Ref()] = ait
return ait, nil
func (g GlobalState) GetAutoIncrementTracker(ctx *sql.Context) (AutoIncrementTracker, error) {
return g.aiTracker, nil
}
@@ -40,7 +40,9 @@ func setupIndexes(t *testing.T, tableName, insertQuery string) (*sqle.Engine, *e
root, err := dEnv.WorkingRoot(context.Background())
require.NoError(t, err)
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()}
db := dsqle.NewDatabase("dolt", dEnv.DbData(), opts)
db, err := dsqle.NewDatabase(context.Background(), "dolt", dEnv.DbData(), opts)
require.NoError(t, err)
engine, sqlCtx, err := dsqle.NewTestEngine(t, dEnv, context.Background(), db, root)
require.NoError(t, err)
@@ -307,7 +307,11 @@ func schemaToSchemaString(sch sql.Schema) (string, error) {
func sqlNewEngine(dEnv *env.DoltEnv) (*sqle.Engine, error) {
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()}
db := dsql.NewDatabase("dolt", dEnv.DbData(), opts)
db, err := dsql.NewDatabase(context.Background(), "dolt", dEnv.DbData(), opts)
if err != nil {
return nil, err
}
mrEnv, err := env.MultiEnvForDirectory(context.Background(), dEnv.Config.WriteableConfig(), dEnv.FS, dEnv.Version, dEnv.IgnoreLockFile, dEnv)
if err != nil {
return nil, err
@@ -98,14 +98,14 @@ func (rrd ReadReplicaDatabase) StartTransaction(ctx *sql.Context, tCharacteristi
}
func (rrd ReadReplicaDatabase) PullFromRemote(ctx *sql.Context) error {
_, headsArg, ok := sql.SystemVariables.GetGlobal(dsess.ReplicateHeadsKey)
_, headsArg, ok := sql.SystemVariables.GetGlobal(dsess.ReplicateHeads)
if !ok {
return sql.ErrUnknownSystemVariable.New(dsess.ReplicateHeadsKey)
return sql.ErrUnknownSystemVariable.New(dsess.ReplicateHeads)
}
_, allHeads, ok := sql.SystemVariables.GetGlobal(dsess.ReplicateAllHeadsKey)
_, allHeads, ok := sql.SystemVariables.GetGlobal(dsess.ReplicateAllHeads)
if !ok {
return sql.ErrUnknownSystemVariable.New(dsess.ReplicateAllHeadsKey)
return sql.ErrUnknownSystemVariable.New(dsess.ReplicateAllHeads)
}
dSess := dsess.DSessFromSess(ctx.Session)
@@ -119,7 +119,7 @@ func (rrd ReadReplicaDatabase) PullFromRemote(ctx *sql.Context) error {
case headsArg != "":
heads, ok := headsArg.(string)
if !ok {
return sql.ErrInvalidSystemVariableValue.New(dsess.ReplicateHeadsKey)
return sql.ErrInvalidSystemVariableValue.New(dsess.ReplicateHeads)
}
branches := parseBranches(heads)
err := rrd.srcDB.Rebase(ctx)
+9 -6
View File
@@ -30,9 +30,9 @@ import (
)
func getPushOnWriteHook(ctx context.Context, bThreads *sql.BackgroundThreads, dEnv *env.DoltEnv, logger io.Writer) (doltdb.CommitHook, error) {
_, val, ok := sql.SystemVariables.GetGlobal(dsess.ReplicateToRemoteKey)
_, val, ok := sql.SystemVariables.GetGlobal(dsess.ReplicateToRemote)
if !ok {
return nil, sql.ErrUnknownSystemVariable.New(dsess.ReplicateToRemoteKey)
return nil, sql.ErrUnknownSystemVariable.New(dsess.ReplicateToRemote)
} else if val == "" {
return nil, nil
}
@@ -57,8 +57,8 @@ func getPushOnWriteHook(ctx context.Context, bThreads *sql.BackgroundThreads, dE
return nil, err
}
_, val, ok = sql.SystemVariables.GetGlobal(dsess.AsyncReplicationKey)
if _, val, ok = sql.SystemVariables.GetGlobal(dsess.AsyncReplicationKey); ok && val == SysVarTrue {
_, val, ok = sql.SystemVariables.GetGlobal(dsess.AsyncReplication)
if _, val, ok = sql.SystemVariables.GetGlobal(dsess.AsyncReplication); ok && val == SysVarTrue {
return doltdb.NewAsyncPushOnWriteHook(bThreads, ddb, dEnv.TempTableFilesDir(), logger)
}
@@ -95,7 +95,10 @@ func newReplicaDatabase(ctx context.Context, name string, remoteName string, dEn
Deaf: dEnv.DbEaFactory(),
}
db := NewDatabase(name, dEnv.DbData(), opts)
db, err := NewDatabase(ctx, name, dEnv.DbData(), opts)
if err != nil {
return ReadReplicaDatabase{}, err
}
rrd, err := NewReadReplicaDatabase(ctx, db, remoteName, dEnv)
if err != nil {
@@ -123,7 +126,7 @@ func ApplyReplicationConfig(ctx context.Context, bThreads *sql.BackgroundThreads
}
dEnv.DoltDB.SetCommitHooks(ctx, postCommitHooks)
if _, remote, ok := sql.SystemVariables.GetGlobal(dsess.ReadReplicaRemoteKey); ok && remote != "" {
if _, remote, ok := sql.SystemVariables.GetGlobal(dsess.ReadReplicaRemote); ok && remote != "" {
remoteName, ok := remote.(string)
if !ok {
return nil, sql.ErrInvalidSystemVariableValue.New(remote)
@@ -31,8 +31,8 @@ import (
func TestCommitHooksNoErrors(t *testing.T) {
dEnv := dtestutils.CreateEnvWithSeedData(t)
AddDoltSystemVariables()
sql.SystemVariables.SetGlobal(dsess.SkipReplicationErrorsKey, true)
sql.SystemVariables.SetGlobal(dsess.ReplicateToRemoteKey, "unknown")
sql.SystemVariables.SetGlobal(dsess.SkipReplicationErrors, true)
sql.SystemVariables.SetGlobal(dsess.ReplicateToRemote, "unknown")
bThreads := sql.NewBackgroundThreads()
hooks, err := GetCommitHooks(context.Background(), bThreads, dEnv, &buffer.Buffer{})
assert.NoError(t, err)
@@ -35,9 +35,11 @@ func TestSchemaTableRecreationOlder(t *testing.T) {
ctx := NewTestSQLCtx(context.Background())
dEnv := dtestutils.CreateTestEnv()
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()}
db := NewDatabase("dolt", dEnv.DbData(), opts)
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(), opts)
require.NoError(t, err)
dbState := getDbState(t, db, dEnv)
err := dsess.DSessFromSess(ctx.Session).AddDB(ctx, dbState)
err = dsess.DSessFromSess(ctx.Session).AddDB(ctx, dbState)
require.NoError(t, err)
ctx.SetCurrentDatabase(db.Name())
@@ -111,9 +113,11 @@ func TestSchemaTableRecreation(t *testing.T) {
ctx := NewTestSQLCtx(context.Background())
dEnv := dtestutils.CreateTestEnv()
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()}
db := NewDatabase("dolt", dEnv.DbData(), opts)
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(), opts)
require.NoError(t, err)
dbState := getDbState(t, db, dEnv)
err := dsess.DSessFromSess(ctx.Session).AddDB(ctx, dbState)
err = dsess.DSessFromSess(ctx.Session).AddDB(ctx, dbState)
require.NoError(t, err)
ctx.SetCurrentDatabase(db.Name())
+9 -3
View File
@@ -65,7 +65,9 @@ func TestSqlBatchInserts(t *testing.T) {
root, _ := dEnv.WorkingRoot(ctx)
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()}
db := NewDatabase("dolt", dEnv.DbData(), opts)
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(), opts)
require.NoError(t, err)
engine, sqlCtx, err := NewTestEngine(t, dEnv, ctx, db, root)
require.NoError(t, err)
dsess.DSessFromSess(sqlCtx.Session).EnableBatchedMode()
@@ -155,7 +157,9 @@ func TestSqlBatchInsertIgnoreReplace(t *testing.T) {
root, _ := dEnv.WorkingRoot(ctx)
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()}
db := NewDatabase("dolt", dEnv.DbData(), opts)
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(), opts)
require.NoError(t, err)
engine, sqlCtx, err := NewTestEngine(t, dEnv, ctx, db, root)
require.NoError(t, err)
dsess.DSessFromSess(sqlCtx.Session).EnableBatchedMode()
@@ -195,7 +199,9 @@ func TestSqlBatchInsertErrors(t *testing.T) {
root, _ := dEnv.WorkingRoot(ctx)
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()}
db := NewDatabase("dolt", dEnv.DbData(), opts)
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(), opts)
require.NoError(t, err)
engine, sqlCtx, err := NewTestEngine(t, dEnv, ctx, db, root)
require.NoError(t, err)
dsess.DSessFromSess(sqlCtx.Session).EnableBatchedMode()
+77 -13
View File
@@ -32,58 +32,122 @@ func init() {
// AddDoltSystemVariables registers all dolt-specific system variables (replication,
// transaction, batch-mode, and AWS credential settings) with the engine's global registry.
// (Merge artifact removed: each pre-existing variable literal carried duplicate Name/Type
// field keys — the old `…Key` constants alongside the renamed ones — which is a compile
// error; only the post-merge names are kept.)
func AddDoltSystemVariables() {
	sql.SystemVariables.AddSystemVariables([]sql.SystemVariable{
		{
			Name:              dsess.ReplicateToRemote,
			Scope:             sql.SystemVariableScope_Global,
			Dynamic:           true,
			SetVarHintApplies: false,
			Type:              sql.NewSystemStringType(dsess.ReplicateToRemote),
			Default:           "",
		},
		{
			Name:              dsess.ReadReplicaRemote,
			Scope:             sql.SystemVariableScope_Global,
			Dynamic:           true,
			SetVarHintApplies: false,
			Type:              sql.NewSystemStringType(dsess.ReadReplicaRemote),
			Default:           "",
		},
		{
			Name:              dsess.SkipReplicationErrors,
			Scope:             sql.SystemVariableScope_Global,
			Dynamic:           true,
			SetVarHintApplies: false,
			Type:              sql.NewSystemBoolType(dsess.SkipReplicationErrors),
			Default:           int8(0),
		},
		{
			Name:              dsess.ReplicateHeads,
			Scope:             sql.SystemVariableScope_Both,
			Dynamic:           true,
			SetVarHintApplies: false,
			Type:              sql.NewSystemStringType(dsess.ReplicateHeads),
			Default:           "",
		},
		{
			Name:              dsess.ReplicateAllHeads,
			Scope:             sql.SystemVariableScope_Both,
			Dynamic:           true,
			SetVarHintApplies: false,
			Type:              sql.NewSystemBoolType(dsess.ReplicateAllHeads),
			Default:           int8(0),
		},
		{
			Name:              dsess.AsyncReplication,
			Scope:             sql.SystemVariableScope_Both,
			Dynamic:           true,
			SetVarHintApplies: false,
			Type:              sql.NewSystemBoolType(dsess.AsyncReplication),
			Default:           int8(0),
		},
		{ // If true, causes a Dolt commit to occur when you commit a transaction.
			Name:              dsess.DoltCommitOnTransactionCommit,
			Scope:             sql.SystemVariableScope_Both,
			Dynamic:           true,
			SetVarHintApplies: false,
			Type:              sql.NewSystemBoolType(dsess.DoltCommitOnTransactionCommit),
			Default:           int8(0),
		},
		{
			Name:              dsess.TransactionsDisabledSysVar,
			Scope:             sql.SystemVariableScope_Session,
			Dynamic:           true,
			SetVarHintApplies: false,
			Type:              sql.NewSystemBoolType(dsess.TransactionsDisabledSysVar),
			Default:           int8(0),
		},
		{ // If true, disables the conflict and constraint violation check when you commit a transaction.
			Name:              dsess.ForceTransactionCommit,
			Scope:             sql.SystemVariableScope_Both,
			Dynamic:           true,
			SetVarHintApplies: false,
			Type:              sql.NewSystemBoolType(dsess.ForceTransactionCommit),
			Default:           int8(0),
		},
		{
			Name:              dsess.CurrentBatchModeKey,
			Scope:             sql.SystemVariableScope_Session,
			Dynamic:           true,
			SetVarHintApplies: false,
			Type:              sql.NewSystemIntType(dsess.CurrentBatchModeKey, -9223372036854775808, 9223372036854775807, false),
			Default:           int64(0),
		},
		{ // If true, disables the conflict violation check when you commit a transaction.
			Name:              dsess.AllowCommitConflicts,
			Scope:             sql.SystemVariableScope_Session,
			Dynamic:           true,
			SetVarHintApplies: false,
			Type:              sql.NewSystemBoolType(dsess.AllowCommitConflicts),
			Default:           int8(0),
		},
		{
			Name:              dsess.AwsCredsFile,
			Scope:             sql.SystemVariableScope_Session,
			Dynamic:           false,
			SetVarHintApplies: false,
			Type:              sql.NewSystemStringType(dsess.AwsCredsFile),
			Default:           nil,
		},
		{
			Name:              dsess.AwsCredsProfile,
			Scope:             sql.SystemVariableScope_Session,
			Dynamic:           false,
			SetVarHintApplies: false,
			Type:              sql.NewSystemStringType(dsess.AwsCredsProfile),
			Default:           nil,
		},
		{
			Name:              dsess.AwsCredsRegion,
			Scope:             sql.SystemVariableScope_Session,
			Dynamic:           false,
			SetVarHintApplies: false,
			Type:              sql.NewSystemStringType(dsess.AwsCredsRegion),
			Default:           nil,
		},
	})
}
func SkipReplicationWarnings() bool {
_, skip, ok := sql.SystemVariables.GetGlobal(dsess.SkipReplicationErrorsKey)
_, skip, ok := sql.SystemVariables.GetGlobal(dsess.SkipReplicationErrors)
if !ok {
panic("dolt system variables not loaded")
}
+37 -10
View File
@@ -676,7 +676,8 @@ func (t *WritableDoltTable) Truncate(ctx *sql.Context) (int, error) {
}
numOfRows := int(rowData.Count())
newTable, err := truncate(ctx, table, sch)
sess := dsess.DSessFromSess(ctx.Session)
newTable, err := t.truncate(ctx, table, sch, sess)
if err != nil {
return 0, err
}
@@ -700,7 +701,12 @@ func (t *WritableDoltTable) Truncate(ctx *sql.Context) (int, error) {
// truncate returns an empty copy of the table given by setting the rows and indexes to empty. The schema can be
// updated at the same time.
func truncate(ctx *sql.Context, table *doltdb.Table, sch schema.Schema) (*doltdb.Table, error) {
func (t *WritableDoltTable) truncate(
ctx *sql.Context,
table *doltdb.Table,
sch schema.Schema,
sess *dsess.DoltSession,
) (*doltdb.Table, error) {
empty, err := durable.NewEmptyIndex(ctx, table.ValueReadWriter(), table.NodeStore(), sch)
if err != nil {
return nil, err
@@ -718,6 +724,19 @@ func truncate(ctx *sql.Context, table *doltdb.Table, sch schema.Schema) (*doltdb
}
}
ws, err := sess.WorkingSet(ctx, t.db.name)
if err != nil {
return nil, err
}
if schema.HasAutoIncrement(sch) {
ddb, _ := sess.GetDoltDB(ctx, t.db.name)
err = t.db.removeTableFromAutoIncrementTracker(ctx, t.Name(), ddb, ws.Ref())
if err != nil {
return nil, err
}
}
// truncate table resets auto-increment value
return doltdb.NewTable(ctx, table.ValueReadWriter(), table.NodeStore(), sch, empty, idxSet, nil)
}
@@ -1160,11 +1179,7 @@ func (t *AlterableDoltTable) AddColumn(ctx *sql.Context, column *sql.Column, ord
}
if column.AutoIncrement {
ws, err := t.db.GetWorkingSet(ctx)
if err != nil {
return err
}
ait, err := t.db.gs.GetAutoIncrementTracker(ctx, ws)
ait, err := t.db.gs.GetAutoIncrementTracker(ctx)
if err != nil {
return err
}
@@ -1345,7 +1360,8 @@ func (t *AlterableDoltTable) RewriteInserter(
})
}
dt, err = truncate(ctx, dt, newSch)
// TODO: test for this when the table is auto increment and exists on another branch
dt, err = t.truncate(ctx, dt, newSch, sess)
if err != nil {
return nil, err
}
@@ -1371,7 +1387,7 @@ func (t *AlterableDoltTable) RewriteInserter(
// TODO: figure out locking. Other DBs automatically lock a table during this kind of operation, we should probably
// do the same. We're messing with global auto-increment values here and it's not safe.
ait, err := t.db.gs.GetAutoIncrementTracker(ctx, newWs)
ait, err := t.db.gs.GetAutoIncrementTracker(ctx)
if err != nil {
return nil, err
}
@@ -1688,7 +1704,7 @@ func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, c
return err
}
ait, err := t.db.gs.GetAutoIncrementTracker(ctx, ws)
ait, err := t.db.gs.GetAutoIncrementTracker(ctx)
if err != nil {
return err
}
@@ -1698,6 +1714,17 @@ func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, c
ait.Set(t.tableName, seq)
}
// If we're removing an auto inc property, we just need to update global auto increment tracking
if existingCol.AutoIncrement && !col.AutoIncrement {
// TODO: this isn't transactional, and it should be
sess := dsess.DSessFromSess(ctx.Session)
ddb, _ := sess.GetDoltDB(ctx, t.db.name)
err = t.db.removeTableFromAutoIncrementTracker(ctx, t.Name(), ddb, ws.Ref())
if err != nil {
return err
}
}
newRoot, err := root.PutTable(ctx, t.tableName, updatedTable)
if err != nil {
return err
+2 -4
View File
@@ -104,8 +104,7 @@ func NewTempTable(
newWs := ws.WithWorkingRoot(newRoot)
gs := globalstate.NewGlobalStateStore()
ait, err := gs.GetAutoIncrementTracker(ctx, newWs)
ait, err := globalstate.NewAutoIncrementTracker(ctx, newWs)
if err != nil {
return nil, err
}
@@ -152,8 +151,7 @@ func setTempTableRoot(t *TempTable) func(ctx *sql.Context, dbName string, newRoo
ws := dbState.WorkingSet
newWs := ws.WithWorkingRoot(newRoot)
gs := globalstate.NewGlobalStateStore()
ait, err := gs.GetAutoIncrementTracker(ctx, newWs)
ait, err := globalstate.NewAutoIncrementTracker(ctx, newWs)
if err != nil {
return err
}
+5 -2
View File
@@ -43,7 +43,8 @@ import (
// the updated root, or an error. Statements in the input string are split by `;\n`
func ExecuteSql(t *testing.T, dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*doltdb.RootValue, error) {
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()}
db := NewDatabase("dolt", dEnv.DbData(), opts)
db, err := NewDatabase(context.Background(), "dolt", dEnv.DbData(), opts)
require.NoError(t, err)
engine, ctx, err := NewTestEngine(t, dEnv, context.Background(), db, root)
dsess.DSessFromSess(ctx.Session).EnableBatchedMode()
@@ -182,7 +183,9 @@ func ExecuteSelect(t *testing.T, dEnv *env.DoltEnv, ddb *doltdb.DoltDB, root *do
}
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()}
db := NewDatabase("dolt", dbData, opts)
db, err := NewDatabase(context.Background(), "dolt", dbData, opts)
require.NoError(t, err)
engine, ctx, err := NewTestEngine(t, dEnv, context.Background(), db, root)
if err != nil {
return nil, err
@@ -163,8 +163,10 @@ func TestTableEditor(t *testing.T) {
ctx := sqle.NewTestSQLCtx(context.Background())
root, _ := dEnv.WorkingRoot(context.Background())
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: dEnv.TempTableFilesDir()}
db := sqle.NewDatabase("dolt", dEnv.DbData(), opts)
err := dsess.DSessFromSess(ctx.Session).AddDB(ctx, getDbState(t, db, dEnv))
db, err := sqle.NewDatabase(ctx, "dolt", dEnv.DbData(), opts)
require.NoError(t, err)
err = dsess.DSessFromSess(ctx.Session).AddDB(ctx, getDbState(t, db, dEnv))
require.NoError(t, err)
ctx.SetCurrentDatabase(db.Name())
@@ -59,7 +59,7 @@ type WriteSessionFlusher interface {
// nomsWriteSession manages table editors for a single working set in the noms storage format.
// (Merge artifact removed: the struct carried both the superseded `tracker` field and its
// renamed replacement `aiTracker`; sibling code in this change references only aiTracker.)
type nomsWriteSession struct {
	workingSet *doltdb.WorkingSet
	tables     map[string]*sessionedTableEditor
	aiTracker  globalstate.AutoIncrementTracker
	mut        *sync.RWMutex // This mutex is specifically for changes that affect the TES or all STEs
	opts       editor.Options
}
@@ -69,12 +69,12 @@ var _ WriteSession = &nomsWriteSession{}
// NewWriteSession creates and returns a WriteSession. Inserting a nil root is not an error, as there are
// locations that do not have a root at the time of this call. However, a root must be set through SetRoot before any
// table editors are returned.
func NewWriteSession(nbf *types.NomsBinFormat, ws *doltdb.WorkingSet, tracker globalstate.AutoIncrementTracker, opts editor.Options) WriteSession {
func NewWriteSession(nbf *types.NomsBinFormat, ws *doltdb.WorkingSet, aiTracker globalstate.AutoIncrementTracker, opts editor.Options) WriteSession {
if types.IsFormat_DOLT(nbf) {
return &prollyWriteSession{
workingSet: ws,
tables: make(map[string]*prollyTableWriter),
tracker: tracker,
aiTracker: aiTracker,
mut: &sync.RWMutex{},
}
}
@@ -82,7 +82,7 @@ func NewWriteSession(nbf *types.NomsBinFormat, ws *doltdb.WorkingSet, tracker gl
return &nomsWriteSession{
workingSet: ws,
tables: make(map[string]*sessionedTableEditor),
tracker: tracker,
aiTracker: aiTracker,
mut: &sync.RWMutex{},
opts: opts,
}
@@ -127,7 +127,7 @@ func (s *nomsWriteSession) GetTableWriter(ctx context.Context, table, db string,
tableEditor: te,
flusher: s,
batched: batched,
autoInc: s.tracker,
autoInc: s.aiTracker,
setter: setter,
}, nil
}
@@ -186,7 +186,7 @@ func (s *nomsWriteSession) flush(ctx context.Context) (*doltdb.WorkingSet, error
// Update the auto increment value for the table if a tracker was provided
// TODO: the table probably needs an autoincrement tracker no matter what
if schema.HasAutoIncrement(ed.Schema()) {
v := s.tracker.Current(name)
v := s.aiTracker.Current(name)
tbl, err = tbl.SetAutoIncrementValue(ctx, v)
if err != nil {
return err
@@ -264,10 +264,6 @@ func (s *nomsWriteSession) setWorkingSet(ctx context.Context, ws *doltdb.Working
s.workingSet = ws
root := ws.WorkingRoot()
if err := updateAutoIncrementSequences(ctx, root, s.tracker); err != nil {
return err
}
for tableName, localTableEditor := range s.tables {
t, ok, err := root.GetTable(ctx, tableName)
if err != nil {
@@ -32,7 +32,7 @@ import (
// prollyWriteSession manages table writers for a single working set in the DOLT storage format.
// (Merge artifact removed: the struct carried both the superseded `tracker` field and its
// renamed replacement `aiTracker`; sibling code in this change references only aiTracker.)
type prollyWriteSession struct {
	workingSet *doltdb.WorkingSet
	tables     map[string]*prollyTableWriter
	aiTracker  globalstate.AutoIncrementTracker
	mut        *sync.RWMutex
}
@@ -96,7 +96,7 @@ func (s *prollyWriteSession) GetTableWriter(ctx context.Context, table, db strin
sch: sch,
sqlSch: pkSch.Schema,
aiCol: autoCol,
aiTracker: s.tracker,
aiTracker: s.aiTracker,
flusher: s,
setter: setter,
batched: batched,
@@ -146,7 +146,7 @@ func (s *prollyWriteSession) flush(ctx context.Context) (*doltdb.WorkingSet, err
}
if schema.HasAutoIncrement(wr.sch) {
t, err = t.SetAutoIncrementValue(ctx2, s.tracker.Current(name))
t, err = t.SetAutoIncrementValue(ctx2, s.aiTracker.Current(name))
if err != nil {
return err
}
@@ -179,10 +179,6 @@ func (s *prollyWriteSession) flush(ctx context.Context) (*doltdb.WorkingSet, err
// setRoot is the inner implementation for SetRoot that does not acquire any locks
func (s *prollyWriteSession) setWorkingSet(ctx context.Context, ws *doltdb.WorkingSet) error {
root := ws.WorkingRoot()
if err := updateAutoIncrementSequences(ctx, root, s.tracker); err != nil {
return err
}
for tableName, tableWriter := range s.tables {
t, ok, err := root.GetTable(ctx, tableName)
if err != nil {
@@ -205,17 +201,3 @@ func (s *prollyWriteSession) setWorkingSet(ctx context.Context, ws *doltdb.Worki
s.workingSet = ws
return nil
}
// updateAutoIncrementSequences records the current auto increment value of every
// auto-increment table in |root| into the tracker |t|.
// NOTE(review): both visible callers of this function were removed in this change —
// confirm whether this helper is still referenced or should be deleted.
func updateAutoIncrementSequences(ctx context.Context, root *doltdb.RootValue, t globalstate.AutoIncrementTracker) error {
	return root.IterTables(ctx, func(name string, table *doltdb.Table, sch schema.Schema) (stop bool, err error) {
		if !schema.HasAutoIncrement(sch) {
			// naked return: stop=false, err=nil — keep iterating
			return
		}
		v, err := table.GetAutoIncrementValue(ctx)
		if err != nil {
			return true, err
		}
		t.Set(name, v)
		return
	})
}
@@ -98,7 +98,7 @@ func getAsyncEnvAndConfig(ctx context.Context, b *testing.B) (dEnv *env.DoltEnv,
if !ok {
b.Fatal("local config does not exist")
}
localCfg.SetStrings(map[string]string{fmt.Sprintf("%s.%s", env.SqlServerGlobalsPrefix, dsess.ReplicateToRemoteKey): "remote1", fmt.Sprintf("%s.%s", env.SqlServerGlobalsPrefix, dsess.AsyncReplicationKey): "1"})
localCfg.SetStrings(map[string]string{fmt.Sprintf("%s.%s", env.SqlServerGlobalsPrefix, dsess.ReplicateToRemote): "remote1", fmt.Sprintf("%s.%s", env.SqlServerGlobalsPrefix, dsess.AsyncReplication): "1"})
yaml := []byte(fmt.Sprintf(`
log_level: warning
@@ -120,7 +120,7 @@ func getEnvAndConfig(ctx context.Context, b *testing.B) (dEnv *env.DoltEnv, cfg
if !ok {
b.Fatal("local config does not exist")
}
localCfg.SetStrings(map[string]string{dsess.ReplicateToRemoteKey: "remote1"})
localCfg.SetStrings(map[string]string{dsess.ReplicateToRemote: "remote1"})
yaml := []byte(fmt.Sprintf(`
log_level: warning
@@ -694,3 +694,79 @@ SQL
[ $status -eq 0 ]
[[ "$output" =~ "NOT NULL AUTO_INCREMENT" ]] || false
}
# Rows inserted on different branches must receive globally distinct
# auto-increment keys: the counter is tracked per table across all branches,
# not per branch head.
@test "auto_increment: globally distinct auto increment values" {
# NOTE(review): assumes a table `test` with an auto_increment primary key and
# an int column c0 was created by this file's setup — confirm against setup().
# Seed three branches from the same empty table, then insert two rows on each
# branch, all within a single `dolt sql` session.
dolt sql <<SQL
call dolt_commit('-am', 'empty table');
call dolt_branch('branch1');
call dolt_branch('branch2');
insert into test (c0) values (1), (2);
call dolt_commit('-am', 'main values');
call dolt_checkout('branch1');
insert into test (c0) values (3), (4);
call dolt_commit('-am', 'branch1 values');
call dolt_checkout('branch2');
insert into test (c0) values (5), (6);
call dolt_commit('-am', 'branch2 values');
SQL
# main got keys 1-2, branch1 got 3-4, branch2 got 5-6: no branch reused a key
# that was handed out on another branch.
run dolt sql -q 'select * from test' -r csv
[ $status -eq 0 ]
[[ "$output" =~ "1,1" ]] || false
[[ "$output" =~ "2,2" ]] || false
dolt checkout branch1
run dolt sql -q 'select * from test' -r csv
[ $status -eq 0 ]
[[ "$output" =~ "3,3" ]] || false
[[ "$output" =~ "4,4" ]] || false
dolt checkout branch2
run dolt sql -q 'select * from test' -r csv
[ $status -eq 0 ]
[[ "$output" =~ "5,5" ]] || false
[[ "$output" =~ "6,6" ]] || false
# Should have the same result across multiple invocations of sql as well
dolt checkout main
dolt sql <<SQL
create table t1 (ai bigint primary key auto_increment, c0 int);
call dolt_commit('-am', 'empty table');
call dolt_branch('branch3');
call dolt_branch('branch4');
insert into t1 (c0) values (1), (2);
call dolt_commit('-am', 'main values');
SQL
dolt sql <<SQL
call dolt_checkout('branch3');
insert into t1 (c0) values (3), (4);
call dolt_commit('-am', 'branch3 values');
SQL
dolt sql <<SQL
call dolt_checkout('branch4');
insert into t1 (c0) values (5), (6);
call dolt_commit('-am', 'branch4 values');
SQL
# Same distinctness expectations as above, but each branch's inserts ran in a
# separate `dolt sql` process, so the sequence state must persist between
# processes rather than living only in one session's memory.
run dolt sql -q 'select * from t1' -r csv
[ $status -eq 0 ]
[[ "$output" =~ "1,1" ]] || false
[[ "$output" =~ "2,2" ]] || false
dolt checkout branch3
run dolt sql -q 'select * from t1' -r csv
[ $status -eq 0 ]
[[ "$output" =~ "3,3" ]] || false
[[ "$output" =~ "4,4" ]] || false
dolt checkout branch4
run dolt sql -q 'select * from t1' -r csv
[ $status -eq 0 ]
[[ "$output" =~ "5,5" ]] || false
[[ "$output" =~ "6,6" ]] || false
}
+29
View File
@@ -1117,6 +1117,35 @@ END""")
server_query repo1 1 "SELECT * FROM t1" "pk,val\n1,1\n2,2"
}
# Running sql-server: auto-increment keys must be globally distinct across
# branches even when the inserts arrive over separate connections, and the
# sequence must not reset when the table is dropped and recreated on one
# branch while other branches still hold rows.
@test "sql-server: auto increment is globally distinct across branches and connections" {
skiponwindows "Missing dependencies"
cd repo1
start_sql_server repo1
server_query repo1 1 "CREATE TABLE t1(pk bigint primary key auto_increment, val int)" ""
insert_query repo1 1 "INSERT INTO t1 (val) VALUES (1)"
server_query repo1 1 "SELECT * FROM t1" "pk,val\n1,1"
insert_query repo1 1 "INSERT INTO t1 (val) VALUES (2)"
server_query repo1 1 "SELECT * FROM t1" "pk,val\n1,1\n2,2"
# NOTE(review): `run` discards status/output here, so failures of the commit
# and branch calls would go unnoticed — presumably best-effort; confirm.
run server_query repo1 1 "call dolt_commit('-am', 'table with two values')"
run server_query repo1 1 "call dolt_branch('new_branch')"
# An insert on new_branch takes key 3; the next insert on main must then get
# key 4 rather than reusing 3.
insert_query repo1/new_branch 1 "INSERT INTO t1 (val) VALUES (3)"
server_query repo1/new_branch 1 "SELECT * FROM t1" "pk,val\n1,1\n2,2\n3,3"
insert_query repo1 1 "INSERT INTO t1 (val) VALUES (4)"
server_query repo1 1 "SELECT * FROM t1" "pk,val\n1,1\n2,2\n4,4"
# drop the table on main, should keep counting from 4
server_query repo1 1 "drop table t1;"
server_query repo1 1 "CREATE TABLE t1(pk bigint primary key auto_increment, val int)" ""
insert_query repo1 1 "INSERT INTO t1 (val) VALUES (4)"
server_query repo1 1 "SELECT * FROM t1" "pk,val\n4,4"
}
@test "sql-server: sql-push --set-remote within session" {
skiponwindows "Missing dependencies"