Merge remote-tracking branch 'origin/main' into nicktobey/lazy-load

Nick Tobey
2025-02-03 12:45:10 -08:00
37 changed files with 2388 additions and 1260 deletions

.gitignore vendored

@@ -16,3 +16,6 @@ SysbenchDockerfile.dockerignore
sysbench-runner-tests-entrypoint.sh
config.json
integration-tests/bats/batsee_results
*~
.dir-locals.el


@@ -47,6 +47,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/statspro"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer"
"github.com/dolthub/dolt/go/libraries/utils/config"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
)
// SqlEngine packages up the context necessary to run sql queries against dsqle.
@@ -55,6 +56,7 @@ type SqlEngine struct {
contextFactory contextFactory
dsessFactory sessionFactory
engine *gms.Engine
fs filesys.Filesys
}
type sessionFactory func(mysqlSess *sql.BaseSession, pro sql.DatabaseProvider) (*dsess.DoltSession, error)
@@ -124,6 +126,8 @@ func NewSqlEngine(
locations = append(locations, nil)
}
gcSafepointController := dsess.NewGCSafepointController()
b := env.GetDefaultInitBranch(mrEnv.Config())
pro, err := dsqle.NewDoltDatabaseProviderWithDatabases(b, mrEnv.FileSystem(), all, locations)
if err != nil {
@@ -189,11 +193,12 @@ func NewSqlEngine(
engine.Analyzer.Catalog.StatsProvider = statsPro
engine.Analyzer.ExecBuilder = rowexec.NewOverrideBuilder(kvexec.Builder{})
sessFactory := doltSessionFactory(pro, statsPro, mrEnv.Config(), bcController, config.Autocommit)
sessFactory := doltSessionFactory(pro, statsPro, mrEnv.Config(), bcController, gcSafepointController, config.Autocommit)
sqlEngine.provider = pro
sqlEngine.contextFactory = sqlContextFactory()
sqlEngine.dsessFactory = sessFactory
sqlEngine.engine = engine
sqlEngine.fs = pro.FileSystem()
// configuring stats depends on sessionBuilder
// sessionBuilder needs ref to statsProv
@@ -314,6 +319,10 @@ func (se *SqlEngine) GetUnderlyingEngine() *gms.Engine {
return se.engine
}
func (se *SqlEngine) FileSystem() filesys.Filesys {
return se.fs
}
func (se *SqlEngine) Close() error {
if se.engine != nil {
if se.engine.Analyzer.Catalog.BinlogReplicaController != nil {
@@ -413,9 +422,9 @@ func sqlContextFactory() contextFactory {
}
// doltSessionFactory returns a sessionFactory that creates a new DoltSession
func doltSessionFactory(pro *dsqle.DoltDatabaseProvider, statsPro sql.StatsProvider, config config.ReadWriteConfig, bc *branch_control.Controller, autocommit bool) sessionFactory {
func doltSessionFactory(pro *dsqle.DoltDatabaseProvider, statsPro sql.StatsProvider, config config.ReadWriteConfig, bc *branch_control.Controller, gcSafepointController *dsess.GCSafepointController, autocommit bool) sessionFactory {
return func(mysqlSess *sql.BaseSession, provider sql.DatabaseProvider) (*dsess.DoltSession, error) {
doltSession, err := dsess.NewDoltSession(mysqlSess, pro, config, bc, statsPro, writer.NewWriteSession)
doltSession, err := dsess.NewDoltSession(mysqlSess, pro, config, bc, statsPro, writer.NewWriteSession, gcSafepointController)
if err != nil {
return nil, err
}


@@ -559,21 +559,27 @@ func ConfigureServices(
}
listenaddr := fmt.Sprintf(":%d", port)
sqlContextInterceptor := sqle.SqlContextServerInterceptor{
Factory: sqlEngine.NewDefaultContext,
}
args := remotesrv.ServerArgs{
Logger: logrus.NewEntry(lgr),
ReadOnly: apiReadOnly || serverConfig.ReadOnly(),
HttpListenAddr: listenaddr,
GrpcListenAddr: listenaddr,
ConcurrencyControl: remotesapi.PushConcurrencyControl_PUSH_CONCURRENCY_CONTROL_ASSERT_WORKING_SET,
Options: sqlContextInterceptor.Options(),
HttpInterceptor: sqlContextInterceptor.HTTP(nil),
}
var err error
args.FS, args.DBCache, err = sqle.RemoteSrvFSAndDBCache(sqlEngine.NewDefaultContext, sqle.DoNotCreateUnknownDatabases)
args.FS = sqlEngine.FileSystem()
args.DBCache, err = sqle.RemoteSrvDBCache(sqle.GetInterceptorSqlContext, sqle.DoNotCreateUnknownDatabases)
if err != nil {
lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err)
return err
}
authenticator := newAccessController(sqlEngine.NewDefaultContext, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb)
authenticator := newAccessController(sqle.GetInterceptorSqlContext, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb)
args = sqle.WithUserPasswordAuth(args, authenticator)
args.TLSConfig = serverConf.TLSConfig
@@ -621,6 +627,7 @@ func ConfigureServices(
lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err)
return err
}
args.FS = sqlEngine.FileSystem()
clusterRemoteSrvTLSConfig, err := LoadClusterTLSConfig(serverConfig.ClusterConfig())
if err != nil {
@@ -634,7 +641,7 @@ func ConfigureServices(
lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err)
return err
}
clusterController.RegisterGrpcServices(sqlEngine.NewDefaultContext, clusterRemoteSrv.srv.GrpcServer())
clusterController.RegisterGrpcServices(sqle.GetInterceptorSqlContext, clusterRemoteSrv.srv.GrpcServer())
clusterRemoteSrv.lis, err = clusterRemoteSrv.srv.Listeners()
if err != nil {
@@ -711,7 +718,13 @@ func ConfigureServices(
AutoStartBinlogReplica := &svcs.AnonService{
InitF: func(ctx context.Context) error {
// If we're unable to restart replication, log an error, but don't prevent the server from starting up
if err := binlogreplication.DoltBinlogReplicaController.AutoStart(ctx); err != nil {
sqlCtx, err := sqlEngine.NewDefaultContext(ctx)
if err != nil {
logrus.Errorf("unable to restart replication, could not create session: %s", err.Error())
return nil
}
defer sql.SessionEnd(sqlCtx.Session)
if err := binlogreplication.DoltBinlogReplicaController.AutoStart(sqlCtx); err != nil {
logrus.Errorf("unable to restart replication: %s", err.Error())
}
return nil


@@ -56,7 +56,7 @@ require (
github.com/cespare/xxhash/v2 v2.2.0
github.com/creasty/defaults v1.6.0
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2
github.com/dolthub/go-mysql-server v0.19.1-0.20250128182847-3f5bb8c52cd8
github.com/dolthub/go-mysql-server v0.19.1-0.20250131110511-67aa2a430366
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63
github.com/dolthub/swiss v0.1.0
github.com/esote/minmaxheap v1.0.0


@@ -179,8 +179,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90 h1:Sni8jrP0sy/w9ZYXoff4g/ixe+7bFCZlfCqXKJSU+zM=
github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA=
github.com/dolthub/go-mysql-server v0.19.1-0.20250128182847-3f5bb8c52cd8 h1:eEGYHOC5Ft+56yPaH26gsdbonrZ2EiTwQLy8Oj3TAFE=
github.com/dolthub/go-mysql-server v0.19.1-0.20250128182847-3f5bb8c52cd8/go.mod h1:jYEJ8tNkA7K3k39X8iMqaX3MSMmViRgh222JSLHDgVc=
github.com/dolthub/go-mysql-server v0.19.1-0.20250131110511-67aa2a430366 h1:pJ+upgX6hrhyqgpkmk9Ye9lIPSualMHZcUMs8kWknV4=
github.com/dolthub/go-mysql-server v0.19.1-0.20250131110511-67aa2a430366/go.mod h1:jYEJ8tNkA7K3k39X8iMqaX3MSMmViRgh222JSLHDgVc=
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI=
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q=
github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE=


@@ -45,4 +45,8 @@ const (
EnvDbNameReplace = "DOLT_DBNAME_REPLACE"
EnvDoltRootHost = "DOLT_ROOT_HOST"
EnvDoltRootPassword = "DOLT_ROOT_PASSWORD"
// If set, must be "kill_connections" or "session_aware"
// Will go away after session_aware is made default-and-only.
EnvGCSafepointControllerChoice = "DOLT_GC_SAFEPOINT_CONTROLLER_CHOICE"
)


@@ -706,6 +706,45 @@ func (ddb *DoltDB) writeRootValue(ctx context.Context, rv RootValue) (RootValue,
return rv, ref, nil
}
// Persists all relevant root values of the WorkingSet to the database and returns all hashes reachable
// from the working set. This is used in GC, for example, where all dependencies of the in-memory working
// set value need to be accounted for.
func (ddb *DoltDB) WorkingSetHashes(ctx context.Context, ws *WorkingSet) ([]hash.Hash, error) {
spec, err := ws.writeValues(ctx, ddb, nil)
if err != nil {
return nil, err
}
ret := make([]hash.Hash, 0)
ret = append(ret, spec.StagedRoot.TargetHash())
ret = append(ret, spec.WorkingRoot.TargetHash())
if spec.MergeState != nil {
fromCommit, err := spec.MergeState.FromCommit(ctx, ddb.vrw)
if err != nil {
return nil, err
}
h, err := fromCommit.NomsValue().Hash(ddb.db.Format())
if err != nil {
return nil, err
}
ret = append(ret, h)
h, err = spec.MergeState.PreMergeWorkingAddr(ctx, ddb.vrw)
if err != nil {
return nil, err
}
ret = append(ret, h)
}
if spec.RebaseState != nil {
ret = append(ret, spec.RebaseState.PreRebaseWorkingAddr())
commit, err := spec.RebaseState.OntoCommit(ctx, ddb.vrw)
if err != nil {
return nil, err
}
h, err := commit.NomsValue().Hash(ddb.db.Format())
if err != nil {
return nil, err
}
ret = append(ret, h)
}
return ret, nil
}
// ReadRootValue reads the RootValue associated with the hash given and returns it. Returns an error if the value cannot
// be read, or if the hash given doesn't represent a dolt RootValue.
func (ddb *DoltDB) ReadRootValue(ctx context.Context, h hash.Hash) (RootValue, error) {

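As a usage sketch of the new WorkingSetHashes method above (not code from this commit; obtaining the *WorkingSet and the keeper callback is assumed to happen elsewhere, for example in the GC safepoint visit path), every hash it returns can simply be reported to a GC keeper with the same func(hash.Hash) bool shape used by the safepoint controllers later in this diff:

package example

import (
	"context"

	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
	"github.com/dolthub/dolt/go/store/hash"
)

// markWorkingSetLive persists the in-memory working set via WorkingSetHashes and
// reports every reachable hash to the keeper so GC will not discard it.
func markWorkingSetLive(ctx context.Context, ddb *doltdb.DoltDB, ws *doltdb.WorkingSet, keep func(hash.Hash) bool) error {
	hashes, err := ddb.WorkingSetHashes(ctx, ws)
	if err != nil {
		return err
	}
	for _, h := range hashes {
		keep(h) // the keeper's bool return value is not needed here
	}
	return nil
}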

@@ -590,7 +590,6 @@ func (ws *WorkingSet) writeValues(ctx context.Context, db *DoltDB, meta *datas.W
return nil, fmt.Errorf("StagedRoot and workingRoot must be set. This is a bug.")
}
var r RootValue
r, workingRoot, err := db.writeRootValue(ctx, ws.workingRoot)
if err != nil {
return nil, err


@@ -105,8 +105,6 @@ func persistReplicaRunningState(ctx *sql.Context, state replicaRunningState) err
// loadReplicationConfiguration loads the replication configuration for default channel ("") from
// the "mysql" database, |mysqlDb|.
func loadReplicationConfiguration(ctx *sql.Context, mysqlDb *mysql_db.MySQLDb) (*mysql_db.ReplicaSourceInfo, error) {
sql.SessionCommandBegin(ctx.Session)
defer sql.SessionCommandEnd(ctx.Session)
rd := mysqlDb.Reader()
defer rd.Close()


@@ -146,8 +146,9 @@ func (a *binlogReplicaApplier) connectAndStartReplicationEventStream(ctx *sql.Co
var conn *mysql.Conn
var err error
for connectionAttempts := uint64(0); ; connectionAttempts++ {
sql.SessionCommandBegin(ctx.Session)
replicaSourceInfo, err := loadReplicationConfiguration(ctx, a.engine.Analyzer.Catalog.MySQLDb)
sql.SessionCommandEnd(ctx.Session)
if replicaSourceInfo == nil {
err = ErrServerNotConfiguredAsReplica
DoltBinlogReplicaController.setIoError(ERFatalReplicaError, err.Error())


@@ -15,7 +15,6 @@
package binlogreplication
import (
"context"
"fmt"
"strings"
"sync"
@@ -158,8 +157,6 @@ func (d *doltBinlogReplicaController) StartReplica(ctx *sql.Context) error {
// created and locked to disable log ins, and if it does exist, but is missing super privs or is not
// locked, it will be given superuser privs and locked.
func (d *doltBinlogReplicaController) configureReplicationUser(ctx *sql.Context) {
sql.SessionCommandBegin(ctx.Session)
defer sql.SessionCommandEnd(ctx.Session)
mySQLDb := d.engine.Analyzer.Catalog.MySQLDb
ed := mySQLDb.Editor()
defer ed.Close()
@@ -417,8 +414,10 @@ func (d *doltBinlogReplicaController) setSqlError(errno uint, message string) {
// replication is not configured, hasn't been started, or has been stopped before the server was
// shutdown, then this method will not start replication. This method should only be called during
// the server startup process and should not be invoked after that.
func (d *doltBinlogReplicaController) AutoStart(_ context.Context) error {
runningState, err := loadReplicationRunningState(d.ctx)
func (d *doltBinlogReplicaController) AutoStart(ctx *sql.Context) error {
sql.SessionCommandBegin(ctx.Session)
defer sql.SessionCommandEnd(ctx.Session)
runningState, err := loadReplicationRunningState(ctx)
if err != nil {
logrus.Errorf("Unable to load replication running state: %s", err.Error())
return err
@@ -430,7 +429,7 @@ func (d *doltBinlogReplicaController) AutoStart(_ context.Context) error {
}
logrus.Info("auto-starting binlog replication from source...")
return d.StartReplica(d.ctx)
return d.StartReplica(ctx)
}
// Release all resources, such as replication threads, associated with the replication.

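A minimal sketch of the bracketing pattern the AutoStart change above adopts (the helper name is illustrative, not from this commit): session-scoped work is wrapped in SessionCommandBegin/SessionCommandEnd so the session-aware GC safepoint controller introduced elsewhere in this merge can tell when a session is idle between commands.

package example

import "github.com/dolthub/go-mysql-server/sql"

// withSessionCommand brackets a unit of session work the same way AutoStart now
// does, guaranteeing SessionCommandEnd runs even if the work returns an error.
func withSessionCommand(ctx *sql.Context, work func(*sql.Context) error) error {
	sql.SessionCommandBegin(ctx.Session)
	defer sql.SessionCommandEnd(ctx.Session)
	return work(ctx)
}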

@@ -27,27 +27,27 @@ import (
// TestBinlogReplicationForAllTypes tests that operations (inserts, updates, and deletes) on all SQL
// data types can be successfully replicated.
func TestBinlogReplicationForAllTypes(t *testing.T) {
defer teardown(t)
startSqlServersWithDoltSystemVars(t, doltReplicaSystemVars)
startReplicationAndCreateTestDb(t, mySqlPort)
h := newHarness(t)
h.startSqlServersWithDoltSystemVars(doltReplicaSystemVars)
h.startReplicationAndCreateTestDb(h.mySqlPort)
// Set the session's timezone to UTC, to avoid TIMESTAMP test values changing
// when they are converted to UTC for storage.
primaryDatabase.MustExec("SET @@time_zone = '+0:00';")
h.primaryDatabase.MustExec("SET @@time_zone = '+0:00';")
// Create the test table
tableName := "alltypes"
createTableStatement := generateCreateTableStatement(tableName)
primaryDatabase.MustExec(createTableStatement)
h.primaryDatabase.MustExec(createTableStatement)
// Make inserts on the primary small, large, and null values
primaryDatabase.MustExec(generateInsertValuesStatement(tableName, 0))
primaryDatabase.MustExec(generateInsertValuesStatement(tableName, 1))
primaryDatabase.MustExec(generateInsertNullValuesStatement(tableName))
h.primaryDatabase.MustExec(generateInsertValuesStatement(tableName, 0))
h.primaryDatabase.MustExec(generateInsertValuesStatement(tableName, 1))
h.primaryDatabase.MustExec(generateInsertNullValuesStatement(tableName))
// Verify inserts on replica
waitForReplicaToCatchUp(t)
rows, err := replicaDatabase.Queryx("select * from db01.alltypes order by pk asc;")
h.waitForReplicaToCatchUp()
rows, err := h.replicaDatabase.Queryx("select * from db01.alltypes order by pk asc;")
require.NoError(t, err)
row := convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "1", row["pk"])
@@ -62,14 +62,14 @@ func TestBinlogReplicationForAllTypes(t *testing.T) {
require.NoError(t, rows.Close())
// Make updates on the primary
primaryDatabase.MustExec(generateUpdateToNullValuesStatement(tableName, 1))
primaryDatabase.MustExec(generateUpdateValuesStatement(tableName, 2, 0))
primaryDatabase.MustExec(generateUpdateValuesStatement(tableName, 3, 1))
h.primaryDatabase.MustExec(generateUpdateToNullValuesStatement(tableName, 1))
h.primaryDatabase.MustExec(generateUpdateValuesStatement(tableName, 2, 0))
h.primaryDatabase.MustExec(generateUpdateValuesStatement(tableName, 3, 1))
// Verify updates on the replica
waitForReplicaToCatchUp(t)
replicaDatabase.MustExec("use db01;")
rows, err = replicaDatabase.Queryx("select * from db01.alltypes order by pk asc;")
h.waitForReplicaToCatchUp()
h.replicaDatabase.MustExec("use db01;")
rows, err = h.replicaDatabase.Queryx("select * from db01.alltypes order by pk asc;")
require.NoError(t, err)
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "1", row["pk"])
@@ -84,13 +84,13 @@ func TestBinlogReplicationForAllTypes(t *testing.T) {
require.NoError(t, rows.Close())
// Make deletes on the primary
primaryDatabase.MustExec("delete from alltypes where pk=1;")
primaryDatabase.MustExec("delete from alltypes where pk=2;")
primaryDatabase.MustExec("delete from alltypes where pk=3;")
h.primaryDatabase.MustExec("delete from alltypes where pk=1;")
h.primaryDatabase.MustExec("delete from alltypes where pk=2;")
h.primaryDatabase.MustExec("delete from alltypes where pk=3;")
// Verify deletes on the replica
waitForReplicaToCatchUp(t)
rows, err = replicaDatabase.Queryx("select * from db01.alltypes order by pk asc;")
h.waitForReplicaToCatchUp()
rows, err = h.replicaDatabase.Queryx("select * from db01.alltypes order by pk asc;")
require.NoError(t, err)
require.False(t, rows.Next())
require.NoError(t, rows.Close())


@@ -24,37 +24,37 @@ import (
// TestBinlogReplicationFilters_ignoreTablesOnly tests that the ignoreTables replication
// filtering option is correctly applied and honored.
func TestBinlogReplicationFilters_ignoreTablesOnly(t *testing.T) {
defer teardown(t)
startSqlServersWithDoltSystemVars(t, doltReplicaSystemVars)
startReplicationAndCreateTestDb(t, mySqlPort)
h := newHarness(t)
h.startSqlServersWithDoltSystemVars(doltReplicaSystemVars)
h.startReplicationAndCreateTestDb(h.mySqlPort)
// Ignore replication events for db01.t2. Also tests that the first filter setting is overwritten by
// the second, and that db and table names are case-insensitive.
replicaDatabase.MustExec("CHANGE REPLICATION FILTER REPLICATE_IGNORE_TABLE=(db01.t1);")
replicaDatabase.MustExec("CHANGE REPLICATION FILTER REPLICATE_IGNORE_TABLE=(DB01.T2);")
h.replicaDatabase.MustExec("CHANGE REPLICATION FILTER REPLICATE_IGNORE_TABLE=(db01.t1);")
h.replicaDatabase.MustExec("CHANGE REPLICATION FILTER REPLICATE_IGNORE_TABLE=(DB01.T2);")
// Assert that status shows replication filters
status := showReplicaStatus(t)
status := h.showReplicaStatus()
require.Equal(t, "db01.t2", status["Replicate_Ignore_Table"])
require.Equal(t, "", status["Replicate_Do_Table"])
// Make changes on the primary
primaryDatabase.MustExec("CREATE TABLE db01.t1 (pk INT PRIMARY KEY);")
primaryDatabase.MustExec("CREATE TABLE db01.t2 (pk INT PRIMARY KEY);")
h.primaryDatabase.MustExec("CREATE TABLE db01.t1 (pk INT PRIMARY KEY);")
h.primaryDatabase.MustExec("CREATE TABLE db01.t2 (pk INT PRIMARY KEY);")
for i := 1; i < 12; i++ {
primaryDatabase.MustExec(fmt.Sprintf("INSERT INTO db01.t1 VALUES (%d);", i))
primaryDatabase.MustExec(fmt.Sprintf("INSERT INTO db01.t2 VALUES (%d);", i))
h.primaryDatabase.MustExec(fmt.Sprintf("INSERT INTO db01.t1 VALUES (%d);", i))
h.primaryDatabase.MustExec(fmt.Sprintf("INSERT INTO db01.t2 VALUES (%d);", i))
}
primaryDatabase.MustExec("UPDATE db01.t1 set pk = pk-1;")
primaryDatabase.MustExec("UPDATE db01.t2 set pk = pk-1;")
primaryDatabase.MustExec("DELETE FROM db01.t1 WHERE pk = 10;")
primaryDatabase.MustExec("DELETE FROM db01.t2 WHERE pk = 10;")
h.primaryDatabase.MustExec("UPDATE db01.t1 set pk = pk-1;")
h.primaryDatabase.MustExec("UPDATE db01.t2 set pk = pk-1;")
h.primaryDatabase.MustExec("DELETE FROM db01.t1 WHERE pk = 10;")
h.primaryDatabase.MustExec("DELETE FROM db01.t2 WHERE pk = 10;")
// Pause to let the replica catch up
waitForReplicaToCatchUp(t)
h.waitForReplicaToCatchUp()
// Verify that all changes from t1 were applied on the replica
rows, err := replicaDatabase.Queryx("SELECT COUNT(pk) as count, MIN(pk) as min, MAX(pk) as max from db01.t1;")
rows, err := h.replicaDatabase.Queryx("SELECT COUNT(pk) as count, MIN(pk) as min, MAX(pk) as max from db01.t1;")
require.NoError(t, err)
row := convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "10", row["count"])
@@ -63,7 +63,7 @@ func TestBinlogReplicationFilters_ignoreTablesOnly(t *testing.T) {
require.NoError(t, rows.Close())
// Verify that no changes from t2 were applied on the replica
rows, err = replicaDatabase.Queryx("SELECT COUNT(pk) as count, MIN(pk) as min, MAX(pk) as max from db01.t2;")
rows, err = h.replicaDatabase.Queryx("SELECT COUNT(pk) as count, MIN(pk) as min, MAX(pk) as max from db01.t2;")
require.NoError(t, err)
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "0", row["count"])
@@ -75,37 +75,37 @@ func TestBinlogReplicationFilters_ignoreTablesOnly(t *testing.T) {
// TestBinlogReplicationFilters_doTablesOnly tests that the doTables replication
// filtering option is correctly applied and honored.
func TestBinlogReplicationFilters_doTablesOnly(t *testing.T) {
defer teardown(t)
startSqlServersWithDoltSystemVars(t, doltReplicaSystemVars)
startReplicationAndCreateTestDb(t, mySqlPort)
h := newHarness(t)
h.startSqlServersWithDoltSystemVars(doltReplicaSystemVars)
h.startReplicationAndCreateTestDb(h.mySqlPort)
// Do replication events for db01.t1. Also tests that the first filter setting is overwritten by
// the second, and that db and table names are case-insensitive.
replicaDatabase.MustExec("CHANGE REPLICATION FILTER REPLICATE_DO_TABLE=(db01.t2);")
replicaDatabase.MustExec("CHANGE REPLICATION FILTER REPLICATE_DO_TABLE=(DB01.T1);")
h.replicaDatabase.MustExec("CHANGE REPLICATION FILTER REPLICATE_DO_TABLE=(db01.t2);")
h.replicaDatabase.MustExec("CHANGE REPLICATION FILTER REPLICATE_DO_TABLE=(DB01.T1);")
// Assert that status shows replication filters
status := showReplicaStatus(t)
status := h.showReplicaStatus()
require.Equal(t, "db01.t1", status["Replicate_Do_Table"])
require.Equal(t, "", status["Replicate_Ignore_Table"])
// Make changes on the primary
primaryDatabase.MustExec("CREATE TABLE db01.t1 (pk INT PRIMARY KEY);")
primaryDatabase.MustExec("CREATE TABLE db01.t2 (pk INT PRIMARY KEY);")
h.primaryDatabase.MustExec("CREATE TABLE db01.t1 (pk INT PRIMARY KEY);")
h.primaryDatabase.MustExec("CREATE TABLE db01.t2 (pk INT PRIMARY KEY);")
for i := 1; i < 12; i++ {
primaryDatabase.MustExec(fmt.Sprintf("INSERT INTO db01.t1 VALUES (%d);", i))
primaryDatabase.MustExec(fmt.Sprintf("INSERT INTO db01.t2 VALUES (%d);", i))
h.primaryDatabase.MustExec(fmt.Sprintf("INSERT INTO db01.t1 VALUES (%d);", i))
h.primaryDatabase.MustExec(fmt.Sprintf("INSERT INTO db01.t2 VALUES (%d);", i))
}
primaryDatabase.MustExec("UPDATE db01.t1 set pk = pk-1;")
primaryDatabase.MustExec("UPDATE db01.t2 set pk = pk-1;")
primaryDatabase.MustExec("DELETE FROM db01.t1 WHERE pk = 10;")
primaryDatabase.MustExec("DELETE FROM db01.t2 WHERE pk = 10;")
h.primaryDatabase.MustExec("UPDATE db01.t1 set pk = pk-1;")
h.primaryDatabase.MustExec("UPDATE db01.t2 set pk = pk-1;")
h.primaryDatabase.MustExec("DELETE FROM db01.t1 WHERE pk = 10;")
h.primaryDatabase.MustExec("DELETE FROM db01.t2 WHERE pk = 10;")
// Pause to let the replica catch up
waitForReplicaToCatchUp(t)
h.waitForReplicaToCatchUp()
// Verify that all changes from t1 were applied on the replica
rows, err := replicaDatabase.Queryx("SELECT COUNT(pk) as count, MIN(pk) as min, MAX(pk) as max from db01.t1;")
rows, err := h.replicaDatabase.Queryx("SELECT COUNT(pk) as count, MIN(pk) as min, MAX(pk) as max from db01.t1;")
require.NoError(t, err)
row := convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "10", row["count"])
@@ -114,7 +114,7 @@ func TestBinlogReplicationFilters_doTablesOnly(t *testing.T) {
require.NoError(t, rows.Close())
// Verify that no changes from t2 were applied on the replica
rows, err = replicaDatabase.Queryx("SELECT COUNT(pk) as count, MIN(pk) as min, MAX(pk) as max from db01.t2;")
rows, err = h.replicaDatabase.Queryx("SELECT COUNT(pk) as count, MIN(pk) as min, MAX(pk) as max from db01.t2;")
require.NoError(t, err)
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "0", row["count"])
@@ -126,38 +126,38 @@ func TestBinlogReplicationFilters_doTablesOnly(t *testing.T) {
// TestBinlogReplicationFilters_doTablesAndIgnoreTables tests that the doTables and ignoreTables
// replication filtering options are correctly applied and honored when used together.
func TestBinlogReplicationFilters_doTablesAndIgnoreTables(t *testing.T) {
defer teardown(t)
startSqlServersWithDoltSystemVars(t, doltReplicaSystemVars)
startReplicationAndCreateTestDb(t, mySqlPort)
h := newHarness(t)
h.startSqlServersWithDoltSystemVars(doltReplicaSystemVars)
h.startReplicationAndCreateTestDb(h.mySqlPort)
// Do replication events for db01.t1, and db01.t2
replicaDatabase.MustExec("CHANGE REPLICATION FILTER REPLICATE_DO_TABLE=(db01.t1, db01.t2);")
h.replicaDatabase.MustExec("CHANGE REPLICATION FILTER REPLICATE_DO_TABLE=(db01.t1, db01.t2);")
// Ignore replication events for db01.t2
replicaDatabase.MustExec("CHANGE REPLICATION FILTER REPLICATE_IGNORE_TABLE=(db01.t2);")
h.replicaDatabase.MustExec("CHANGE REPLICATION FILTER REPLICATE_IGNORE_TABLE=(db01.t2);")
// Assert that replica status shows replication filters
status := showReplicaStatus(t)
status := h.showReplicaStatus()
require.True(t, status["Replicate_Do_Table"] == "db01.t1,db01.t2" ||
status["Replicate_Do_Table"] == "db01.t2,db01.t1")
require.Equal(t, "db01.t2", status["Replicate_Ignore_Table"])
// Make changes on the primary
primaryDatabase.MustExec("CREATE TABLE db01.t1 (pk INT PRIMARY KEY);")
primaryDatabase.MustExec("CREATE TABLE db01.t2 (pk INT PRIMARY KEY);")
h.primaryDatabase.MustExec("CREATE TABLE db01.t1 (pk INT PRIMARY KEY);")
h.primaryDatabase.MustExec("CREATE TABLE db01.t2 (pk INT PRIMARY KEY);")
for i := 1; i < 12; i++ {
primaryDatabase.MustExec(fmt.Sprintf("INSERT INTO db01.t1 VALUES (%d);", i))
primaryDatabase.MustExec(fmt.Sprintf("INSERT INTO db01.t2 VALUES (%d);", i))
h.primaryDatabase.MustExec(fmt.Sprintf("INSERT INTO db01.t1 VALUES (%d);", i))
h.primaryDatabase.MustExec(fmt.Sprintf("INSERT INTO db01.t2 VALUES (%d);", i))
}
primaryDatabase.MustExec("UPDATE db01.t1 set pk = pk-1;")
primaryDatabase.MustExec("UPDATE db01.t2 set pk = pk-1;")
primaryDatabase.MustExec("DELETE FROM db01.t1 WHERE pk = 10;")
primaryDatabase.MustExec("DELETE FROM db01.t2 WHERE pk = 10;")
h.primaryDatabase.MustExec("UPDATE db01.t1 set pk = pk-1;")
h.primaryDatabase.MustExec("UPDATE db01.t2 set pk = pk-1;")
h.primaryDatabase.MustExec("DELETE FROM db01.t1 WHERE pk = 10;")
h.primaryDatabase.MustExec("DELETE FROM db01.t2 WHERE pk = 10;")
// Pause to let the replica catch up
waitForReplicaToCatchUp(t)
h.waitForReplicaToCatchUp()
// Verify that all changes from t1 were applied on the replica
rows, err := replicaDatabase.Queryx("SELECT COUNT(pk) as count, MIN(pk) as min, MAX(pk) as max from db01.t1;")
rows, err := h.replicaDatabase.Queryx("SELECT COUNT(pk) as count, MIN(pk) as min, MAX(pk) as max from db01.t1;")
require.NoError(t, err)
row := convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "10", row["count"])
@@ -166,7 +166,7 @@ func TestBinlogReplicationFilters_doTablesAndIgnoreTables(t *testing.T) {
require.NoError(t, rows.Close())
// Verify that no changes from t2 were applied on the replica
rows, err = replicaDatabase.Queryx("SELECT COUNT(pk) as count, MIN(pk) as min, MAX(pk) as max from db01.t2;")
rows, err = h.replicaDatabase.Queryx("SELECT COUNT(pk) as count, MIN(pk) as min, MAX(pk) as max from db01.t2;")
require.NoError(t, err)
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "0", row["count"])
@@ -177,15 +177,15 @@ func TestBinlogReplicationFilters_doTablesAndIgnoreTables(t *testing.T) {
// TestBinlogReplicationFilters_errorCases test returned errors for various error cases.
func TestBinlogReplicationFilters_errorCases(t *testing.T) {
defer teardown(t)
startSqlServers(t)
h := newHarness(t)
h.startSqlServers()
// All tables must be qualified with a database
_, err := replicaDatabase.Queryx("CHANGE REPLICATION FILTER REPLICATE_DO_TABLE=(t1);")
_, err := h.replicaDatabase.Queryx("CHANGE REPLICATION FILTER REPLICATE_DO_TABLE=(t1);")
require.Error(t, err)
require.ErrorContains(t, err, "no database specified for table")
_, err = replicaDatabase.Queryx("CHANGE REPLICATION FILTER REPLICATE_IGNORE_TABLE=(t1);")
_, err = h.replicaDatabase.Queryx("CHANGE REPLICATION FILTER REPLICATE_IGNORE_TABLE=(t1);")
require.Error(t, err)
require.ErrorContains(t, err, "no database specified for table")
}


@@ -23,30 +23,30 @@ import (
// TestBinlogReplicationMultiDb tests that binlog events spanning multiple databases are correctly
// applied by a replica.
func TestBinlogReplicationMultiDb(t *testing.T) {
defer teardown(t)
startSqlServersWithDoltSystemVars(t, doltReplicaSystemVars)
startReplicationAndCreateTestDb(t, mySqlPort)
h := newHarness(t)
h.startSqlServersWithDoltSystemVars(doltReplicaSystemVars)
h.startReplicationAndCreateTestDb(h.mySqlPort)
// Make changes on the primary to db01 and db02
primaryDatabase.MustExec("create database db02;")
primaryDatabase.MustExec("use db01;")
primaryDatabase.MustExec("create table t01 (pk int primary key, c1 int default (0))")
primaryDatabase.MustExec("use db02;")
primaryDatabase.MustExec("create table t02 (pk int primary key, c1 int default (0))")
primaryDatabase.MustExec("use db01;")
primaryDatabase.MustExec("insert into t01 (pk) values (1), (3), (5), (8), (9);")
primaryDatabase.MustExec("use db02;")
primaryDatabase.MustExec("insert into t02 (pk) values (2), (4), (6), (7), (10);")
primaryDatabase.MustExec("use db01;")
primaryDatabase.MustExec("delete from t01 where pk=9;")
primaryDatabase.MustExec("delete from db02.t02 where pk=10;")
primaryDatabase.MustExec("use db02;")
primaryDatabase.MustExec("update db01.t01 set pk=7 where pk=8;")
primaryDatabase.MustExec("update t02 set pk=8 where pk=7;")
h.primaryDatabase.MustExec("create database db02;")
h.primaryDatabase.MustExec("use db01;")
h.primaryDatabase.MustExec("create table t01 (pk int primary key, c1 int default (0))")
h.primaryDatabase.MustExec("use db02;")
h.primaryDatabase.MustExec("create table t02 (pk int primary key, c1 int default (0))")
h.primaryDatabase.MustExec("use db01;")
h.primaryDatabase.MustExec("insert into t01 (pk) values (1), (3), (5), (8), (9);")
h.primaryDatabase.MustExec("use db02;")
h.primaryDatabase.MustExec("insert into t02 (pk) values (2), (4), (6), (7), (10);")
h.primaryDatabase.MustExec("use db01;")
h.primaryDatabase.MustExec("delete from t01 where pk=9;")
h.primaryDatabase.MustExec("delete from db02.t02 where pk=10;")
h.primaryDatabase.MustExec("use db02;")
h.primaryDatabase.MustExec("update db01.t01 set pk=7 where pk=8;")
h.primaryDatabase.MustExec("update t02 set pk=8 where pk=7;")
// Verify the changes in db01 on the replica
waitForReplicaToCatchUp(t)
rows, err := replicaDatabase.Queryx("select * from db01.t01 order by pk asc;")
h.waitForReplicaToCatchUp()
rows, err := h.replicaDatabase.Queryx("select * from db01.t01 order by pk asc;")
require.NoError(t, err)
row := convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "1", row["pk"])
@@ -61,8 +61,8 @@ func TestBinlogReplicationMultiDb(t *testing.T) {
require.NoError(t, rows.Close())
// Verify db01.dolt_diff
replicaDatabase.MustExec("use db01;")
rows, err = replicaDatabase.Queryx("select * from db01.dolt_diff;")
h.replicaDatabase.MustExec("use db01;")
rows, err = h.replicaDatabase.Queryx("select * from db01.dolt_diff;")
require.NoError(t, err)
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "t01", row["table_name"])
@@ -85,8 +85,8 @@ func TestBinlogReplicationMultiDb(t *testing.T) {
require.NoError(t, rows.Close())
// Verify the changes in db02 on the replica
replicaDatabase.MustExec("use db02;")
rows, err = replicaDatabase.Queryx("select * from db02.t02 order by pk asc;")
h.replicaDatabase.MustExec("use db02;")
rows, err = h.replicaDatabase.Queryx("select * from db02.t02 order by pk asc;")
require.NoError(t, err)
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "2", row["pk"])
@@ -100,7 +100,7 @@ func TestBinlogReplicationMultiDb(t *testing.T) {
require.NoError(t, rows.Close())
// Verify db02.dolt_diff
rows, err = replicaDatabase.Queryx("select * from db02.dolt_diff;")
rows, err = h.replicaDatabase.Queryx("select * from db02.dolt_diff;")
require.NoError(t, err)
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "t02", row["table_name"])
@@ -125,28 +125,28 @@ func TestBinlogReplicationMultiDb(t *testing.T) {
// TestBinlogReplicationMultiDbTransactions tests that binlog events for transactions that span
// multiple DBs are applied correctly to a replica.
func TestBinlogReplicationMultiDbTransactions(t *testing.T) {
defer teardown(t)
startSqlServersWithDoltSystemVars(t, doltReplicaSystemVars)
startReplicationAndCreateTestDb(t, mySqlPort)
h := newHarness(t)
h.startSqlServersWithDoltSystemVars(doltReplicaSystemVars)
h.startReplicationAndCreateTestDb(h.mySqlPort)
// Make changes on the primary to db01 and db02
primaryDatabase.MustExec("create database db02;")
primaryDatabase.MustExec("create table db01.t01 (pk int primary key, c1 int default (0))")
primaryDatabase.MustExec("create table db02.t02 (pk int primary key, c1 int default (0))")
primaryDatabase.MustExec("set @autocommit = 0;")
h.primaryDatabase.MustExec("create database db02;")
h.primaryDatabase.MustExec("create table db01.t01 (pk int primary key, c1 int default (0))")
h.primaryDatabase.MustExec("create table db02.t02 (pk int primary key, c1 int default (0))")
h.primaryDatabase.MustExec("set @autocommit = 0;")
primaryDatabase.MustExec("start transaction;")
primaryDatabase.MustExec("insert into db01.t01 (pk) values (1), (3), (5), (8), (9);")
primaryDatabase.MustExec("insert into db02.t02 (pk) values (2), (4), (6), (7), (10);")
primaryDatabase.MustExec("delete from db01.t01 where pk=9;")
primaryDatabase.MustExec("delete from db02.t02 where pk=10;")
primaryDatabase.MustExec("update db01.t01 set pk=7 where pk=8;")
primaryDatabase.MustExec("update db02.t02 set pk=8 where pk=7;")
primaryDatabase.MustExec("commit;")
h.primaryDatabase.MustExec("start transaction;")
h.primaryDatabase.MustExec("insert into db01.t01 (pk) values (1), (3), (5), (8), (9);")
h.primaryDatabase.MustExec("insert into db02.t02 (pk) values (2), (4), (6), (7), (10);")
h.primaryDatabase.MustExec("delete from db01.t01 where pk=9;")
h.primaryDatabase.MustExec("delete from db02.t02 where pk=10;")
h.primaryDatabase.MustExec("update db01.t01 set pk=7 where pk=8;")
h.primaryDatabase.MustExec("update db02.t02 set pk=8 where pk=7;")
h.primaryDatabase.MustExec("commit;")
// Verify the changes in db01 on the replica
waitForReplicaToCatchUp(t)
rows, err := replicaDatabase.Queryx("select * from db01.t01 order by pk asc;")
h.waitForReplicaToCatchUp()
rows, err := h.replicaDatabase.Queryx("select * from db01.t01 order by pk asc;")
require.NoError(t, err)
row := convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "1", row["pk"])
@@ -160,8 +160,8 @@ func TestBinlogReplicationMultiDbTransactions(t *testing.T) {
require.NoError(t, rows.Close())
// Verify db01.dolt_diff
replicaDatabase.MustExec("use db01;")
rows, err = replicaDatabase.Queryx("select * from db01.dolt_diff;")
h.replicaDatabase.MustExec("use db01;")
rows, err = h.replicaDatabase.Queryx("select * from db01.dolt_diff;")
require.NoError(t, err)
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "t01", row["table_name"])
@@ -175,9 +175,9 @@ func TestBinlogReplicationMultiDbTransactions(t *testing.T) {
require.NoError(t, rows.Close())
// Verify the changes in db02 on the replica
waitForReplicaToCatchUp(t)
replicaDatabase.MustExec("use db02;")
rows, err = replicaDatabase.Queryx("select * from db02.t02 order by pk asc;")
h.waitForReplicaToCatchUp()
h.replicaDatabase.MustExec("use db02;")
rows, err = h.replicaDatabase.Queryx("select * from db02.t02 order by pk asc;")
require.NoError(t, err)
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "2", row["pk"])
@@ -191,7 +191,7 @@ func TestBinlogReplicationMultiDbTransactions(t *testing.T) {
require.NoError(t, rows.Close())
// Verify db02.dolt_diff
rows, err = replicaDatabase.Queryx("select * from db02.dolt_diff;")
rows, err = h.replicaDatabase.Queryx("select * from db02.dolt_diff;")
require.NoError(t, err)
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "t02", row["table_name"])


@@ -28,38 +28,34 @@ import (
"github.com/stretchr/testify/require"
)
var toxiClient *toxiproxyclient.Client
var mysqlProxy *toxiproxyclient.Proxy
var proxyPort int
// TestBinlogReplicationAutoReconnect tests that the replica's connection to the primary is correctly
// reestablished if it drops.
func TestBinlogReplicationAutoReconnect(t *testing.T) {
defer teardown(t)
startSqlServersWithDoltSystemVars(t, doltReplicaSystemVars)
configureToxiProxy(t)
configureFastConnectionRetry(t)
startReplicationAndCreateTestDb(t, proxyPort)
h := newHarness(t)
h.startSqlServersWithDoltSystemVars(doltReplicaSystemVars)
h.configureToxiProxy()
h.configureFastConnectionRetry()
h.startReplicationAndCreateTestDb(h.proxyPort)
// Get the replica started up and ensure it's in sync with the primary before turning on the limit_data toxic
testInitialReplicaStatus(t)
primaryDatabase.MustExec("create table reconnect_test(pk int primary key, c1 varchar(255));")
waitForReplicaToCatchUp(t)
turnOnLimitDataToxic(t)
h.testInitialReplicaStatus()
h.primaryDatabase.MustExec("create table reconnect_test(pk int primary key, c1 varchar(255));")
h.waitForReplicaToCatchUp()
h.turnOnLimitDataToxic()
for i := 0; i < 1000; i++ {
value := "foobarbazbashfoobarbazbashfoobarbazbashfoobarbazbashfoobarbazbash"
primaryDatabase.MustExec(fmt.Sprintf("insert into reconnect_test values (%v, %q)", i, value))
h.primaryDatabase.MustExec(fmt.Sprintf("insert into reconnect_test values (%v, %q)", i, value))
}
// Remove the limit_data toxic so that a connection can be reestablished
err := mysqlProxy.RemoveToxic("limit_data")
err := h.mysqlProxy.RemoveToxic("limit_data")
require.NoError(t, err)
t.Logf("Toxiproxy proxy limit_data toxic removed")
// Assert that all records get written to the table
waitForReplicaToCatchUp(t)
h.waitForReplicaToCatchUp()
rows, err := replicaDatabase.Queryx("select min(pk) as min, max(pk) as max, count(pk) as count from db01.reconnect_test;")
rows, err := h.replicaDatabase.Queryx("select min(pk) as min, max(pk) as max, count(pk) as count from db01.reconnect_test;")
require.NoError(t, err)
row := convertMapScanResultToStrings(readNextRow(t, rows))
@@ -69,7 +65,7 @@ func TestBinlogReplicationAutoReconnect(t *testing.T) {
require.NoError(t, rows.Close())
// Assert that show replica status show reconnection IO error
status := showReplicaStatus(t)
status := h.showReplicaStatus()
require.Equal(t, "1158", status["Last_IO_Errno"])
require.True(t, strings.Contains(status["Last_IO_Error"].(string), "EOF"))
requireRecentTimeString(t, status["Last_IO_Error_Timestamp"])
@@ -77,54 +73,54 @@ func TestBinlogReplicationAutoReconnect(t *testing.T) {
// configureFastConnectionRetry configures the replica to retry a failed connection after 5s, instead of the default 60s
// connection retry interval. This is used for testing connection retry logic without waiting the full default period.
func configureFastConnectionRetry(_ *testing.T) {
replicaDatabase.MustExec(
func (h *harness) configureFastConnectionRetry() {
h.replicaDatabase.MustExec(
"change replication source to SOURCE_CONNECT_RETRY=5;")
}
// testInitialReplicaStatus tests the data returned by SHOW REPLICA STATUS and errors
// out if any values are not what is expected for a replica that has just connected
// to a MySQL primary.
func testInitialReplicaStatus(t *testing.T) {
status := showReplicaStatus(t)
func (h *harness) testInitialReplicaStatus() {
status := h.showReplicaStatus()
// Positioning settings
require.Equal(t, "1", status["Auto_Position"])
require.Equal(h.t, "1", status["Auto_Position"])
// Connection settings
require.Equal(t, "5", status["Connect_Retry"])
require.Equal(t, "86400", status["Source_Retry_Count"])
require.Equal(t, "localhost", status["Source_Host"])
require.NotEmpty(t, status["Source_Port"])
require.NotEmpty(t, status["Source_User"])
require.Equal(h.t, "5", status["Connect_Retry"])
require.Equal(h.t, "86400", status["Source_Retry_Count"])
require.Equal(h.t, "localhost", status["Source_Host"])
require.NotEmpty(h.t, status["Source_Port"])
require.NotEmpty(h.t, status["Source_User"])
// Error status
require.Equal(t, "0", status["Last_Errno"])
require.Equal(t, "", status["Last_Error"])
require.Equal(t, "0", status["Last_IO_Errno"])
require.Equal(t, "", status["Last_IO_Error"])
require.Equal(t, "", status["Last_IO_Error_Timestamp"])
require.Equal(t, "0", status["Last_SQL_Errno"])
require.Equal(t, "", status["Last_SQL_Error"])
require.Equal(t, "", status["Last_SQL_Error_Timestamp"])
require.Equal(h.t, "0", status["Last_Errno"])
require.Equal(h.t, "", status["Last_Error"])
require.Equal(h.t, "0", status["Last_IO_Errno"])
require.Equal(h.t, "", status["Last_IO_Error"])
require.Equal(h.t, "", status["Last_IO_Error_Timestamp"])
require.Equal(h.t, "0", status["Last_SQL_Errno"])
require.Equal(h.t, "", status["Last_SQL_Error"])
require.Equal(h.t, "", status["Last_SQL_Error_Timestamp"])
// Empty filter configuration
require.Equal(t, "", status["Replicate_Do_Table"])
require.Equal(t, "", status["Replicate_Ignore_Table"])
require.Equal(h.t, "", status["Replicate_Do_Table"])
require.Equal(h.t, "", status["Replicate_Ignore_Table"])
// Thread status
require.True(t,
require.True(h.t,
status["Replica_IO_Running"] == "Yes" ||
status["Replica_IO_Running"] == "Connecting")
require.Equal(t, "Yes", status["Replica_SQL_Running"])
require.Equal(h.t, "Yes", status["Replica_SQL_Running"])
// Unsupported fields
require.Equal(t, "INVALID", status["Source_Log_File"])
require.Equal(t, "Ignored", status["Source_SSL_Allowed"])
require.Equal(t, "None", status["Until_Condition"])
require.Equal(t, "0", status["SQL_Delay"])
require.Equal(t, "0", status["SQL_Remaining_Delay"])
require.Equal(t, "0", status["Seconds_Behind_Source"])
require.Equal(h.t, "INVALID", status["Source_Log_File"])
require.Equal(h.t, "Ignored", status["Source_SSL_Allowed"])
require.Equal(h.t, "None", status["Until_Condition"])
require.Equal(h.t, "0", status["SQL_Delay"])
require.Equal(h.t, "0", status["SQL_Remaining_Delay"])
require.Equal(h.t, "0", status["Seconds_Behind_Source"])
}
// requireRecentTimeString asserts that the specified |datetime| is a non-nil timestamp string
@@ -141,14 +137,14 @@ func requireRecentTimeString(t *testing.T, datetime interface{}) {
// showReplicaStatus returns a map with the results of SHOW REPLICA STATUS, keyed by the
// name of each column.
func showReplicaStatus(t *testing.T) map[string]interface{} {
rows, err := replicaDatabase.Queryx("show replica status;")
require.NoError(t, err)
func (h *harness) showReplicaStatus() map[string]interface{} {
rows, err := h.replicaDatabase.Queryx("show replica status;")
require.NoError(h.t, err)
defer rows.Close()
return convertMapScanResultToStrings(readNextRow(t, rows))
return convertMapScanResultToStrings(readNextRow(h.t, rows))
}
func configureToxiProxy(t *testing.T) {
func (h *harness) configureToxiProxy() {
toxiproxyPort := findFreePort()
metrics := toxiproxy.NewMetricsContainer(prometheus.NewRegistry())
@@ -157,31 +153,31 @@ func configureToxiProxy(t *testing.T) {
toxiproxyServer.Listen("localhost", strconv.Itoa(toxiproxyPort))
}()
time.Sleep(500 * time.Millisecond)
t.Logf("Toxiproxy control plane running on port %d", toxiproxyPort)
h.t.Logf("Toxiproxy control plane running on port %d", toxiproxyPort)
toxiClient = toxiproxyclient.NewClient(fmt.Sprintf("localhost:%d", toxiproxyPort))
h.toxiClient = toxiproxyclient.NewClient(fmt.Sprintf("localhost:%d", toxiproxyPort))
proxyPort = findFreePort()
h.proxyPort = findFreePort()
var err error
mysqlProxy, err = toxiClient.CreateProxy("mysql",
fmt.Sprintf("localhost:%d", proxyPort), // downstream
fmt.Sprintf("localhost:%d", mySqlPort)) // upstream
h.mysqlProxy, err = h.toxiClient.CreateProxy("mysql",
fmt.Sprintf("localhost:%d", h.proxyPort), // downstream
fmt.Sprintf("localhost:%d", h.mySqlPort)) // upstream
if err != nil {
panic(fmt.Sprintf("unable to create toxiproxy: %v", err.Error()))
}
t.Logf("Toxiproxy proxy started on port %d", proxyPort)
h.t.Logf("Toxiproxy proxy started on port %d", h.proxyPort)
}
// turnOnLimitDataToxic adds a limit_data toxic to the active Toxiproxy, which prevents more than 1KB of data
// from being sent from the primary through the proxy to the replica. Callers MUST call configureToxiProxy
// before calling this function.
func turnOnLimitDataToxic(t *testing.T) {
require.NotNil(t, mysqlProxy)
_, err := mysqlProxy.AddToxic("limit_data", "limit_data", "downstream", 1.0, toxiproxyclient.Attributes{
func (h *harness) turnOnLimitDataToxic() {
require.NotNil(h.t, h.mysqlProxy)
_, err := h.mysqlProxy.AddToxic("limit_data", "limit_data", "downstream", 1.0, toxiproxyclient.Attributes{
"bytes": 1_000,
})
require.NoError(t, err)
t.Logf("Toxiproxy proxy with limit_data toxic (1KB) started on port %d", proxyPort)
require.NoError(h.t, err)
h.t.Logf("Toxiproxy proxy with limit_data toxic (1KB) started on port %d", h.proxyPort)
}
// convertMapScanResultToStrings converts each value in the specified map |m| into a string.


@@ -25,11 +25,11 @@ import (
// TestBinlogReplicationServerRestart tests that a replica can be configured and started, then the
// server process can be restarted and replica can be restarted without problems.
func TestBinlogReplicationServerRestart(t *testing.T) {
defer teardown(t)
startSqlServersWithDoltSystemVars(t, doltReplicaSystemVars)
startReplicationAndCreateTestDb(t, mySqlPort)
h := newHarness(t)
h.startSqlServersWithDoltSystemVars(doltReplicaSystemVars)
h.startReplicationAndCreateTestDb(h.mySqlPort)
primaryDatabase.MustExec("create table t (pk int auto_increment primary key)")
h.primaryDatabase.MustExec("create table t (pk int auto_increment primary key)")
// Launch a goroutine that inserts data for 5 seconds
var wg sync.WaitGroup
@@ -38,22 +38,22 @@ func TestBinlogReplicationServerRestart(t *testing.T) {
defer wg.Done()
limit := 5 * time.Second
for startTime := time.Now(); time.Now().Sub(startTime) <= limit; {
primaryDatabase.MustExec("insert into t values (DEFAULT);")
h.primaryDatabase.MustExec("insert into t values (DEFAULT);")
time.Sleep(100 * time.Millisecond)
}
}()
// Let the replica process a few transactions, then stop the server and pause a second
waitForReplicaToReachGtid(t, 3)
stopDoltSqlServer(t)
h.waitForReplicaToReachGtid(3)
h.stopDoltSqlServer()
time.Sleep(1000 * time.Millisecond)
var err error
doltPort, doltProcess, err = startDoltSqlServer(t, testDir, nil)
h.doltPort, h.doltProcess, err = h.startDoltSqlServer(nil)
require.NoError(t, err)
// Check replication status on the replica and assert configuration persisted
status := showReplicaStatus(t)
status := h.showReplicaStatus()
// The default Connect_Retry interval is 60s; but some tests configure a faster connection retry interval
require.True(t, status["Connect_Retry"] == "5" || status["Connect_Retry"] == "60")
require.Equal(t, "86400", status["Source_Retry_Count"])
@@ -64,16 +64,16 @@ func TestBinlogReplicationServerRestart(t *testing.T) {
// Restart replication on replica
// TODO: For now, we have to set server_id each time we start the service.
// Turn this into a persistent sys var
replicaDatabase.MustExec("set @@global.server_id=123;")
replicaDatabase.MustExec("START REPLICA")
h.replicaDatabase.MustExec("set @@global.server_id=123;")
h.replicaDatabase.MustExec("START REPLICA")
// Assert that all changes have replicated from the primary
wg.Wait()
waitForReplicaToCatchUp(t)
h.waitForReplicaToCatchUp()
countMaxQuery := "SELECT COUNT(pk) AS count, MAX(pk) as max FROM db01.t;"
primaryRows, err := primaryDatabase.Queryx(countMaxQuery)
primaryRows, err := h.primaryDatabase.Queryx(countMaxQuery)
require.NoError(t, err)
replicaRows, err := replicaDatabase.Queryx(countMaxQuery)
replicaRows, err := h.replicaDatabase.Queryx(countMaxQuery)
require.NoError(t, err)
primaryRow := convertMapScanResultToStrings(readNextRow(t, primaryRows))
replicaRow := convertMapScanResultToStrings(readNextRow(t, replicaRows))


@@ -0,0 +1,72 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd
// +build darwin dragonfly freebsd linux netbsd openbsd
package binlogreplication
import (
"os"
"os/exec"
"os/signal"
"syscall"
"time"
)
func ApplyCmdAttributes(cmd *exec.Cmd) {
// Nothing...
}
func StopProcess(proc *os.Process) error {
err := proc.Signal(syscall.SIGTERM)
if err != nil {
return err
}
_, err = proc.Wait()
return err
}
// These tests spawn child process for go compiling, dolt sql-server
// and for mysqld. We would like to clean up these child processes
// when the program exits. In general, we use *testing.T.Cleanup to
// terminate any running processes associated with the test.
//
// On a shell, when a user runs 'go test .', and then they deliver
// an interrupt, '^C', the shell delivers a SIGINT to the process
// group of the foreground process. In our case, `dolt`, `go`, and
// the default signal handler for the golang runtime (this test
// program) will all terminate the program on delivery of a SIGINT.
// `mysqld`, however, does not terminate on receiving SIGINT. Thus,
// we install a handler here, and we translate the Interrupt into
// a SIGTERM against the process group. That will get `mysqld` to
// shutdown as well.
func InstallSignalHandlers() {
interrupts := make(chan os.Signal, 1)
signal.Notify(interrupts, os.Interrupt)
go func() {
<-interrupts
// |mysqld| will exit on SIGTERM
syscall.Kill(-os.Getpid(), syscall.SIGTERM)
time.Sleep(1 * time.Second)
// Canceling this context will cause os.Process.Kill
// to get called on any still-running processes.
commandCtxCancel()
time.Sleep(1 * time.Second)
// Redeliver SIGINT to ourselves with the default
// signal handler restored.
signal.Reset(os.Interrupt)
syscall.Kill(-os.Getpid(), syscall.SIGINT)
}()
}

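A hedged sketch of how a test harness might use the helpers above when spawning a child process such as mysqld (the function and its arguments are illustrative, not part of this commit): platform attributes are applied before Start, and StopProcess is registered as a cleanup so the child is terminated even when a test fails.

package binlogreplication

import (
	"os/exec"
	"testing"
)

// startChildProcess starts a command with the platform-specific attributes from
// ApplyCmdAttributes and arranges for StopProcess to stop it at test cleanup.
func startChildProcess(t *testing.T, name string, args ...string) *exec.Cmd {
	cmd := exec.Command(name, args...)
	ApplyCmdAttributes(cmd) // new process group on Windows, no-op on Unix
	if err := cmd.Start(); err != nil {
		t.Fatalf("unable to start %s: %v", name, err)
	}
	t.Cleanup(func() {
		_ = StopProcess(cmd.Process) // SIGTERM on Unix, CTRL_BREAK_EVENT on Windows
	})
	return cmd
}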

@@ -0,0 +1,50 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build windows
// +build windows
package binlogreplication
import (
"os"
"os/exec"
"syscall"
"golang.org/x/sys/windows"
)
func ApplyCmdAttributes(cmd *exec.Cmd) {
// Creating a new process group for the process will allow GracefulStop to send the break signal to that process
// without also killing the parent process
cmd.SysProcAttr = &syscall.SysProcAttr{
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
}
}
func StopProcess(proc *os.Process) error {
err := windows.GenerateConsoleCtrlEvent(windows.CTRL_BREAK_EVENT, uint32(proc.Pid))
if err != nil {
return err
}
_, err = proc.Wait()
return err
}
// I don't know if there is any magic necessary here, but regardless,
// we don't run these tests on windows, so there are never child
// mysqld processes to worry about.
func InstallSignalHandlers() {
}


@@ -688,9 +688,14 @@ func (c *Controller) RemoteSrvServerArgs(ctxFactory func(context.Context) (*sql.
listenaddr := c.RemoteSrvListenAddr()
args.HttpListenAddr = listenaddr
args.GrpcListenAddr = listenaddr
args.Options = c.ServerOptions()
ctxInterceptor := sqle.SqlContextServerInterceptor{
Factory: ctxFactory,
}
args.Options = append(args.Options, ctxInterceptor.Options()...)
args.Options = append(args.Options, c.ServerOptions()...)
args.HttpInterceptor = ctxInterceptor.HTTP(args.HttpInterceptor)
var err error
args.FS, args.DBCache, err = sqle.RemoteSrvFSAndDBCache(ctxFactory, sqle.CreateUnknownDatabases)
args.DBCache, err = sqle.RemoteSrvDBCache(sqle.GetInterceptorSqlContext, sqle.CreateUnknownDatabases)
if err != nil {
return remotesrv.ServerArgs{}, err
}
@@ -699,7 +704,7 @@ func (c *Controller) RemoteSrvServerArgs(ctxFactory func(context.Context) (*sql.
keyID := creds.PubKeyToKID(c.pub)
keyIDStr := creds.B32CredsEncoding.EncodeToString(keyID)
args.HttpInterceptor = JWKSHandlerInterceptor(keyIDStr, c.pub)
args.HttpInterceptor = JWKSHandlerInterceptor(args.HttpInterceptor, keyIDStr, c.pub)
return args, nil
}


@@ -46,16 +46,21 @@ func (h JWKSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Write(b)
}
func JWKSHandlerInterceptor(keyID string, pub ed25519.PublicKey) func(http.Handler) http.Handler {
func JWKSHandlerInterceptor(existing func(http.Handler) http.Handler, keyID string, pub ed25519.PublicKey) func(http.Handler) http.Handler {
jh := JWKSHandler{KeyID: keyID, PublicKey: pub}
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
this := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.EscapedPath() == "/.well-known/jwks.json" {
jh.ServeHTTP(w, r)
return
}
h.ServeHTTP(w, r)
})
if existing != nil {
return existing(this)
} else {
return this
}
}
}

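A standalone sketch of the composition order the new existing parameter gives JWKSHandlerInterceptor (handlers and names here are illustrative, not from the Dolt tree): the previously installed interceptor stays outermost, the JWKS path check runs next, and all other requests fall through to the final handler.

package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// chain mirrors the shape of JWKSHandlerInterceptor: serve one path itself,
// delegate everything else, and wrap with any previously installed interceptor.
func chain(existing func(http.Handler) http.Handler, path string, intercepted http.Handler) func(http.Handler) http.Handler {
	return func(h http.Handler) http.Handler {
		this := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.URL.EscapedPath() == path {
				intercepted.ServeHTTP(w, r)
				return
			}
			h.ServeHTTP(w, r)
		})
		if existing != nil {
			return existing(this)
		}
		return this
	}
}

func main() {
	// Stand-in for an interceptor that was already installed, e.g. a logging or SQL-context interceptor.
	logging := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			fmt.Println("outer interceptor saw:", r.URL.Path)
			next.ServeHTTP(w, r)
		})
	}
	jwks := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, `{"keys":[]}`)
	})
	final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, "remotesapi traffic")
	})

	srv := httptest.NewServer(chain(logging, "/.well-known/jwks.json", jwks)(final))
	defer srv.Close()

	if resp, err := http.Get(srv.URL + "/.well-known/jwks.json"); err == nil {
		fmt.Println("jwks endpoint status:", resp.StatusCode)
		resp.Body.Close()
	}
}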

@@ -171,6 +171,10 @@ func (p *DoltDatabaseProvider) WithFunctions(fns []sql.Function) *DoltDatabasePr
return &cp
}
func (p *DoltDatabaseProvider) RegisterProcedure(procedure sql.ExternalStoredProcedureDetails) {
p.externalProcedures.Register(procedure)
}
// WithDbFactoryUrl returns a copy of this provider with the DbFactoryUrl set as provided.
// The URL is used when creating new databases.
// See doltdb.InMemDoltDB, doltdb.LocalDirDoltDB
@@ -684,7 +688,8 @@ func (p *DoltDatabaseProvider) CloneDatabaseFromRemote(
if exists {
deleteErr := p.fs.Delete(dbName, true)
if deleteErr != nil {
err = fmt.Errorf("%s: unable to clean up failed clone in directory '%s'", err.Error(), dbName)
err = fmt.Errorf("%s: unable to clean up failed clone in directory '%s': %s",
err.Error(), dbName, deleteErr.Error())
}
}
return err


@@ -37,10 +37,19 @@ const (
cmdSuccess = 0
)
var useSessionAwareSafepointController bool
func init() {
if os.Getenv(dconfig.EnvDisableGcProcedure) != "" {
DoltGCFeatureFlag = false
}
if choice := os.Getenv(dconfig.EnvGCSafepointControllerChoice); choice != "" {
if choice == "session_aware" {
useSessionAwareSafepointController = true
} else if choice != "kill_connections" {
panic("Invalid value for " + dconfig.EnvGCSafepointControllerChoice + ". must be session_aware or kill_connections")
}
}
}
var DoltGCFeatureFlag = true
@@ -59,27 +68,127 @@ func doltGC(ctx *sql.Context, args ...string) (sql.RowIter, error) {
var ErrServerPerformedGC = errors.New("this connection was established when this server performed an online garbage collection. this connection can no longer be used. please reconnect.")
type safepointController struct {
begin func(context.Context, func(hash.Hash) bool) error
preFinalize func(context.Context) error
postFinalize func(context.Context) error
cancel func()
// The original behavior safepoint controller, which kills all connections right as the GC process is being finalized.
// The only connection which is left up is the connection on which dolt_gc is called, but that connection is
// invalidated in such a way that all future queries on it return an error.
type killConnectionsSafepointController struct {
callCtx *sql.Context
origEpoch int
}
func (sc safepointController) BeginGC(ctx context.Context, keeper func(hash.Hash) bool) error {
return sc.begin(ctx, keeper)
func (sc killConnectionsSafepointController) BeginGC(ctx context.Context, keeper func(hash.Hash) bool) error {
return nil
}
func (sc safepointController) EstablishPreFinalizeSafepoint(ctx context.Context) error {
return sc.preFinalize(ctx)
func (sc killConnectionsSafepointController) EstablishPreFinalizeSafepoint(ctx context.Context) error {
return nil
}
func (sc safepointController) EstablishPostFinalizeSafepoint(ctx context.Context) error {
return sc.postFinalize(ctx)
func checkEpochSame(origEpoch int) error {
// Here we need to sanity check role and epoch.
if origEpoch != -1 {
if _, role, ok := sql.SystemVariables.GetGlobal(dsess.DoltClusterRoleVariable); ok {
if role.(string) != "primary" {
return fmt.Errorf("dolt_gc failed: when we began we were a primary in a cluster, but now our role is %s", role.(string))
}
_, epoch, ok := sql.SystemVariables.GetGlobal(dsess.DoltClusterRoleEpochVariable)
if !ok {
return fmt.Errorf("dolt_gc failed: when we began we were a primary in a cluster, but we can no longer read the cluster role epoch.")
}
if origEpoch != epoch.(int) {
return fmt.Errorf("dolt_gc failed: when we began we were primary in the cluster at epoch %d, but now we are at epoch %d. for gc to safely finalize, our role and epoch must not change throughout the gc.", origEpoch, epoch.(int))
}
} else {
return fmt.Errorf("dolt_gc failed: when we began we were a primary in a cluster, but we can no longer read the cluster role.")
}
}
return nil
}
func (sc safepointController) CancelSafepoint() {
sc.cancel()
func (sc killConnectionsSafepointController) EstablishPostFinalizeSafepoint(ctx context.Context) error {
err := checkEpochSame(sc.origEpoch)
if err != nil {
return err
}
killed := make(map[uint32]struct{})
processes := sc.callCtx.ProcessList.Processes()
for _, p := range processes {
if p.Connection != sc.callCtx.Session.ID() {
// Kill any inflight query.
sc.callCtx.ProcessList.Kill(p.Connection)
// Tear down the connection itself.
sc.callCtx.KillConnection(p.Connection)
killed[p.Connection] = struct{}{}
}
}
// Look in processes until the connections are actually gone.
params := backoff.NewExponentialBackOff()
params.InitialInterval = 1 * time.Millisecond
params.MaxInterval = 25 * time.Millisecond
params.MaxElapsedTime = 3 * time.Second
err = backoff.Retry(func() error {
processes := sc.callCtx.ProcessList.Processes()
allgood := true
for _, p := range processes {
if _, ok := killed[p.Connection]; ok {
allgood = false
sc.callCtx.ProcessList.Kill(p.Connection)
}
}
if !allgood {
return errors.New("unable to establish safepoint.")
}
return nil
}, params)
if err != nil {
return err
}
sc.callCtx.Session.SetTransaction(nil)
dsess.DSessFromSess(sc.callCtx.Session).SetValidateErr(ErrServerPerformedGC)
return nil
}
func (sc killConnectionsSafepointController) CancelSafepoint() {
}
type sessionAwareSafepointController struct {
controller *dsess.GCSafepointController
callCtx *sql.Context
origEpoch int
waiter *dsess.GCSafepointWaiter
keeper func(hash.Hash) bool
}
func (sc *sessionAwareSafepointController) visit(ctx context.Context, sess *dsess.DoltSession) error {
return sess.VisitGCRoots(ctx, sc.callCtx.GetCurrentDatabase(), sc.keeper)
}
func (sc *sessionAwareSafepointController) BeginGC(ctx context.Context, keeper func(hash.Hash) bool) error {
sc.keeper = keeper
thisSess := dsess.DSessFromSess(sc.callCtx.Session)
err := sc.visit(ctx, thisSess)
if err != nil {
return err
}
sc.waiter = sc.controller.Waiter(ctx, thisSess, sc.visit)
return nil
}
func (sc *sessionAwareSafepointController) EstablishPreFinalizeSafepoint(ctx context.Context) error {
return sc.waiter.Wait(ctx)
}
func (sc *sessionAwareSafepointController) EstablishPostFinalizeSafepoint(ctx context.Context) error {
return checkEpochSame(sc.origEpoch)
}
func (sc *sessionAwareSafepointController) CancelSafepoint() {
canceledCtx, cancel := context.WithCancel(context.Background())
cancel()
sc.waiter.Wait(canceledCtx)
}
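Both controllers plug into ddb.GC through the same contract; a sketch of that interface, inferred from the methods implemented above (the authoritative definition lives in the types package, so treat this as an approximation):
type GCSafepointController interface {
	BeginGC(ctx context.Context, keeper func(hash.Hash) bool) error
	EstablishPreFinalizeSafepoint(ctx context.Context) error
	EstablishPostFinalizeSafepoint(ctx context.Context) error
	CancelSafepoint()
}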
func doDoltGC(ctx *sql.Context, args []string) (int, error) {
@@ -122,7 +231,6 @@ func doDoltGC(ctx *sql.Context, args []string) (int, error) {
// We assert that we are the primary here before we begin, and
// we assert again that we are the primary at the same epoch as
// we establish the safepoint.
origepoch := -1
if _, role, ok := sql.SystemVariables.GetGlobal(dsess.DoltClusterRoleVariable); ok {
// TODO: magic constant...
@@ -141,71 +249,20 @@ func doDoltGC(ctx *sql.Context, args []string) (int, error) {
mode = types.GCModeFull
}
// TODO: Implement safepointController so that begin can capture inflight sessions
// and preFinalize can ensure they're all in a good place before returning.
sc := safepointController{
begin: func(context.Context, func(hash.Hash) bool) error { return nil },
preFinalize: func(context.Context) error { return nil },
postFinalize: func(context.Context) error {
if origepoch != -1 {
// Here we need to sanity check role and epoch.
if _, role, ok := sql.SystemVariables.GetGlobal(dsess.DoltClusterRoleVariable); ok {
if role.(string) != "primary" {
return fmt.Errorf("dolt_gc failed: when we began we were a primary in a cluster, but now our role is %s", role.(string))
}
_, epoch, ok := sql.SystemVariables.GetGlobal(dsess.DoltClusterRoleEpochVariable)
if !ok {
return fmt.Errorf("dolt_gc failed: when we began we were a primary in a cluster, but we can no longer read the cluster role epoch.")
}
if origepoch != epoch.(int) {
return fmt.Errorf("dolt_gc failed: when we began we were primary in the cluster at epoch %d, but now we are at epoch %d. for gc to safely finalize, our role and epoch must not change throughout the gc.", origepoch, epoch.(int))
}
} else {
return fmt.Errorf("dolt_gc failed: when we began we were a primary in a cluster, but we can no longer read the cluster role.")
}
}
killed := make(map[uint32]struct{})
processes := ctx.ProcessList.Processes()
for _, p := range processes {
if p.Connection != ctx.Session.ID() {
// Kill any inflight query.
ctx.ProcessList.Kill(p.Connection)
// Tear down the connection itself.
ctx.KillConnection(p.Connection)
killed[p.Connection] = struct{}{}
}
}
// Look in processes until the connections are actually gone.
params := backoff.NewExponentialBackOff()
params.InitialInterval = 1 * time.Millisecond
params.MaxInterval = 25 * time.Millisecond
params.MaxElapsedTime = 3 * time.Second
err := backoff.Retry(func() error {
processes := ctx.ProcessList.Processes()
allgood := true
for _, p := range processes {
if _, ok := killed[p.Connection]; ok {
allgood = false
ctx.ProcessList.Kill(p.Connection)
}
}
if !allgood {
return errors.New("unable to establish safepoint.")
}
return nil
}, params)
if err != nil {
return err
}
ctx.Session.SetTransaction(nil)
dsess.DSessFromSess(ctx.Session).SetValidateErr(ErrServerPerformedGC)
return nil
},
cancel: func() {},
var sc types.GCSafepointController
if useSessionAwareSafepointController {
gcSafepointController := dSess.GCSafepointController()
sc = &sessionAwareSafepointController{
origEpoch: origepoch,
callCtx: ctx,
controller: gcSafepointController,
}
} else {
sc = killConnectionsSafepointController{
origEpoch: origepoch,
callCtx: ctx,
}
}
err = ddb.GC(ctx, mode, sc)
if err != nil {
return cmdFailure, err

View File

@@ -36,7 +36,6 @@ var DoltProcedures = []sql.ExternalStoredProcedureDetails{
{Name: "dolt_purge_dropped_databases", Schema: int64Schema("status"), Function: doltPurgeDroppedDatabases, AdminOnly: true},
{Name: "dolt_rebase", Schema: doltRebaseProcedureSchema, Function: doltRebase},
// dolt_gc is enabled behind a feature flag for now, see dolt_gc.go
{Name: "dolt_gc", Schema: int64Schema("status"), Function: doltGC, ReadOnly: true, AdminOnly: true},
{Name: "dolt_merge", Schema: doltMergeSchema, Function: doltMerge},

View File

@@ -0,0 +1,306 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dsess
import (
"context"
"errors"
"sync"
"sync/atomic"
)
type GCSafepointController struct {
mu sync.Mutex
// All known sessions. The first command registers the session
// here and SessionEnd causes it to be removed.
sessions map[*DoltSession]*GCSafepointSessionState
}
type GCSafepointSessionState struct {
// True when a command is outstanding on the session,
// false otherwise.
OutstandingCommand bool
// Registered when we create a GCSafepointWaiter if
// there is an outstanding command on the session
// at the time. This will be called when the session's
// SessionCommandEnd function is called.
CommandEndCallback func()
// When the channel stored here is open (not yet closed), an
// outstanding visit call is ongoing for this session. The
// CommandBegin callback will block until that call has
// completed.
QuiesceCallbackDone atomic.Value // chan struct{}
}
// Make it so that HasOutstandingVisitCall will return true and
// BlockForOutstandingVisitCall will block until
// EndOutstandingVisitCall is called.
func (state *GCSafepointSessionState) BeginOutstandingVisitCall() {
state.QuiesceCallbackDone.Store(make(chan struct{}))
}
// Bracket the end of an outstanding visit call. Unblocks any
// callers to |BlockForOutstandingVisitCall|. Must be paired
// one-for-one with calls to |BeginOutstandingVisitCall|.
func (state *GCSafepointSessionState) EndOutstandingVisitCall() {
close(state.QuiesceCallbackDone.Load().(chan struct{}))
}
// Peek whether |BlockForOutstandingVisitCall| would block.
func (state *GCSafepointSessionState) HasOutstandingVisitCall() bool {
ch := state.QuiesceCallbackDone.Load().(chan struct{})
select {
case <-ch:
return false
default:
return true
}
}
func (state *GCSafepointSessionState) BlockForOutstandingVisitCall() {
ch := state.QuiesceCallbackDone.Load().(chan struct{})
<-ch
}
var closedCh = make(chan struct{})
func init() {
close(closedCh)
}
func NewGCSafepointSessionState() *GCSafepointSessionState {
state := &GCSafepointSessionState{}
state.QuiesceCallbackDone.Store(closedCh)
return state
}
type GCSafepointWaiter struct {
controller *GCSafepointController
wg sync.WaitGroup
mu sync.Mutex
err error
}
func NewGCSafepointController() *GCSafepointController {
return &GCSafepointController{
sessions: make(map[*DoltSession]*GCSafepointSessionState),
}
}
// The GCSafepointController is keeping track of *DoltSession instances that have ever had work done.
// By pairing up CommandBegin and CommandEnd callbacks, it can identify quiesced sessions--ones that
// are not currently running work. Calling |Waiter| asks the controller to concurrently call
// |visitQuiescedSession| on each known session as soon as it is safe and possible. The returned
// |Waiter| is used to |Wait| for all of those to be completed. A call is not made for |thisSession|,
// since, if that session corresponds to an ongoing SQL procedure call, for example, that session
// will never quiesce. Instead, the caller should ensure that |visitQuiescedSession| is called on
// its own session.
//
// After creating a Waiter, it is an error to create a new Waiter before the |Wait| method of the
// original waiter has returned. This error is not guaranteed to always be detected.
func (c *GCSafepointController) Waiter(ctx context.Context, thisSession *DoltSession, visitQuiescedSession func(context.Context, *DoltSession) error) *GCSafepointWaiter {
c.mu.Lock()
defer c.mu.Unlock()
ret := &GCSafepointWaiter{controller: c}
for sess, state := range c.sessions {
// If an existing session already has a |CommandEndCallback| registered,
// then more than one |Waiter| would be outstanding on this
// SafepointController. This is an error and is not supported.
if state.CommandEndCallback != nil {
panic("Attempt to create more than one GCSafepointWaiter.")
}
if sess == thisSession {
continue
}
// When this session's |visit| call is done, it will count down this
// waitgroup. The |Wait| method, in turn, will block on this waitgroup
// completing to know that all callbacks are done.
ret.wg.Add(1)
// The work we do when we visit the session, including bookkeeping.
work := func() {
// We don't set this until the callback is actually called.
// If we did set this outside of the callback, Wait's
// cleanup logic would need to change to ensure that the
// session is in a usable state when the callback gets
// canceled before ever being called.
state.BeginOutstandingVisitCall()
go func() {
err := visitQuiescedSession(ctx, sess)
ret.accumulateErrors(err)
ret.wg.Done()
state.EndOutstandingVisitCall()
}()
}
if state.OutstandingCommand {
// If a command is currently running on the session, register
// our work to run as soon as the command is done.
state.CommandEndCallback = work
} else {
// When no command is running on the session, we can immediately
// visit it.
work()
}
}
return ret
}
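A minimal caller sketch for this API, mirroring how sessionAwareSafepointController uses it in dolt_gc.go above; the helper name is hypothetical and imports (context, hash) are elided:
func visitAllSessionRoots(ctx context.Context, c *GCSafepointController, thisSession *DoltSession, dbName string, keeper func(hash.Hash) bool) error {
	// Visit the calling session's roots directly, since it never quiesces.
	if err := thisSession.VisitGCRoots(ctx, dbName, keeper); err != nil {
		return err
	}
	// Ask the controller to visit every other known session as it quiesces,
	// then block until all of those visits have completed.
	waiter := c.Waiter(ctx, thisSession, func(ctx context.Context, sess *DoltSession) error {
		return sess.VisitGCRoots(ctx, dbName, keeper)
	})
	return waiter.Wait(ctx)
}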
func (w *GCSafepointWaiter) accumulateErrors(err error) {
if err != nil {
w.mu.Lock()
w.err = errors.Join(w.err, err)
w.mu.Unlock()
}
}
// |Wait| will block on the Waiter's waitgroup. A successful
// return from this method signals that all sessions that were known
// about when the waiter was created have been visited by the
// |visitQuiescedSession| callback that was given to |Waiter|.
//
// This function will return early, and with an error, if the
// supplied |ctx|'s |Done| channel delivers. In that case,
// all sessions will not necessarily have been visited, but
// any visit callbacks which were started will still have
// completed.
//
// In addition to returning an error if the passed in |ctx|
// is |Done| before the wait is finished, this function also
// returns accumulated errors as seen from each
// |visitQuiescedSession| callback. No attempt is made to
// cancel callbacks or to return early in the case that errors
// are seen from the callback functions.
func (w *GCSafepointWaiter) Wait(ctx context.Context) error {
done := make(chan struct{})
go func() {
w.wg.Wait()
close(done)
}()
select {
case <-done:
return w.err
case <-ctx.Done():
w.controller.mu.Lock()
for _, state := range w.controller.sessions {
if state.CommandEndCallback != nil {
// Do not visit the session, but do
// count down the WaitGroup so that
// the goroutine above still completes.
w.wg.Done()
state.CommandEndCallback = nil
}
}
w.controller.mu.Unlock()
// Once a session visit callback has started, we
// cannot cancel it. So we wait for all the inflight
// callbacks to be completed here, before returning.
<-done
return errors.Join(context.Cause(ctx), w.err)
}
}
// Beginning a command on a session has three effects:
//
// 1. It registers the Session in the set of all
// known sessions, |c.sessions|, if this is our
// first time seeing it.
//
// 2. It blocks for any existing call to |CommandEndCallback|
// on this session to complete. If a call to |CommandEndCallback|
// is outstanding, a read from our |QuiesceCallbackDone|
// channel will block.
//
// 3. It sets |OutstandingCommand| for the Session to true. Only
// one command can be outstanding at a time, and whether a command
// is outstanding controls how |Waiter| treats the Session when it
// is setting up all Sessions to visit their GC roots.
func (c *GCSafepointController) SessionCommandBegin(s *DoltSession) error {
c.mu.Lock()
defer c.mu.Unlock()
var state *GCSafepointSessionState
if state = c.sessions[s]; state == nil {
// Step #1: keep track of all seen sessions.
state = NewGCSafepointSessionState()
c.sessions[s] = state
}
if state.OutstandingCommand {
panic("SessionBeginCommand called on a session that already had an outstanding command.")
}
// Step #2: If receiving from QuiesceCallbackDone blocks, then
// the callback for this Session is still outstanding. We
// don't want to block on this work finishing while holding
// the controller-wide lock, so we release it while we block.
if state.HasOutstandingVisitCall() {
c.mu.Unlock()
state.BlockForOutstandingVisitCall()
c.mu.Lock()
if state.OutstandingCommand {
// Concurrent calls to SessionCommandBegin. Bad times...
panic("SessionBeginCommand called on a session that already had an outstanding command.")
}
}
// Step #3. Record that a command is running so that Waiter
// will populate CommandEndCallback instead of running the
// visit logic immediately.
state.OutstandingCommand = true
return nil
}
// SessionCommandEnd marks the end of a session command. Its effects
// are that the session no longer has an OutstandingCommand and,
// if CommandEndCallback was non-nil, the callback itself has been
// called and the CommandEndCallback field has been reset to |nil|.
func (c *GCSafepointController) SessionCommandEnd(s *DoltSession) {
c.mu.Lock()
defer c.mu.Unlock()
state := c.sessions[s]
if state == nil {
panic("SessionCommandEnd called on a session that was not registered")
}
if !state.OutstandingCommand {
panic("SessionCommandEnd called on a session that did not have an outstanding command.")
}
if state.CommandEndCallback != nil {
state.CommandEndCallback()
state.CommandEndCallback = nil
}
state.OutstandingCommand = false
}
// SessionEnd will remove the session from our tracked session state,
// if we already knew about it. It is an error to call this on a
// session which currently has an outstanding command.
//
// Because we only register sessions in SessionCommandBegin, it is
// possible to get a SessionEnd callback for a session that was
// never registered.
//
// This callback does not block for any outstanding |visitQuiescedSession|
// callback to be completed before allowing the session to unregister
// itself. It is an error for the application to call |SessionCommandBegin|
// on a session after it has called |SessionEnd| on it, but that error
// is not necessarily detected.
func (c *GCSafepointController) SessionEnd(s *DoltSession) {
c.mu.Lock()
defer c.mu.Unlock()
state := c.sessions[s]
if state != nil {
if state.OutstandingCommand {
panic("SessionEnd called on a session that had an outstanding command.")
}
delete(c.sessions, s)
}
}
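To illustrate the bracketing contract the callbacks above expect, a hypothetical helper (not part of this diff) that runs one statement per begin/end pair:
func runStatement(c *GCSafepointController, sess *DoltSession, stmt func() error) error {
	if err := c.SessionCommandBegin(sess); err != nil {
		return err
	}
	// SessionCommandEnd also runs any CommandEndCallback a Waiter registered
	// while this statement was outstanding.
	defer c.SessionCommandEnd(sess)
	return stmt()
}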

View File

@@ -0,0 +1,292 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dsess
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestGCSafepointController(t *testing.T) {
t.Parallel()
t.Run("SessionEnd", func(t *testing.T) {
t.Parallel()
t.Run("UnknownSession", func(t *testing.T) {
t.Parallel()
controller := NewGCSafepointController()
controller.SessionEnd(&DoltSession{})
})
t.Run("KnownSession", func(t *testing.T) {
t.Parallel()
controller := NewGCSafepointController()
sess := &DoltSession{}
controller.SessionCommandBegin(sess)
controller.SessionCommandEnd(sess)
controller.SessionEnd(sess)
})
t.Run("RunningSession", func(t *testing.T) {
t.Parallel()
controller := NewGCSafepointController()
sess := &DoltSession{}
controller.SessionCommandBegin(sess)
require.Panics(t, func() {
controller.SessionEnd(sess)
})
})
})
t.Run("CommandBegin", func(t *testing.T) {
t.Parallel()
t.Run("RunningSession", func(t *testing.T) {
t.Parallel()
controller := NewGCSafepointController()
sess := &DoltSession{}
controller.SessionCommandBegin(sess)
require.Panics(t, func() {
controller.SessionCommandBegin(sess)
})
})
t.Run("AfterCommandEnd", func(t *testing.T) {
t.Parallel()
controller := NewGCSafepointController()
sess := &DoltSession{}
controller.SessionCommandBegin(sess)
controller.SessionCommandEnd(sess)
controller.SessionCommandBegin(sess)
})
})
t.Run("CommandEnd", func(t *testing.T) {
t.Parallel()
t.Run("NotKnown", func(t *testing.T) {
t.Parallel()
controller := NewGCSafepointController()
sess := &DoltSession{}
require.Panics(t, func() {
controller.SessionCommandEnd(sess)
})
})
t.Run("NotRunning", func(t *testing.T) {
t.Parallel()
controller := NewGCSafepointController()
sess := &DoltSession{}
controller.SessionCommandBegin(sess)
controller.SessionCommandEnd(sess)
require.Panics(t, func() {
controller.SessionCommandEnd(sess)
})
})
})
t.Run("Waiter", func(t *testing.T) {
t.Parallel()
t.Run("Empty", func(t *testing.T) {
t.Parallel()
var nilCh chan struct{}
block := func(context.Context, *DoltSession) error {
<-nilCh
return nil
}
controller := NewGCSafepointController()
waiter := controller.Waiter(context.Background(), nil, block)
waiter.Wait(context.Background())
})
t.Run("OnlyThisSession", func(t *testing.T) {
t.Parallel()
var nilCh chan struct{}
block := func(context.Context, *DoltSession) error {
<-nilCh
return nil
}
sess := &DoltSession{}
controller := NewGCSafepointController()
controller.SessionCommandBegin(sess)
waiter := controller.Waiter(context.Background(), sess, block)
waiter.Wait(context.Background())
controller.SessionCommandEnd(sess)
controller.SessionEnd(sess)
})
t.Run("OneQuiescedOneNot", func(t *testing.T) {
t.Parallel()
// A test case where one session is known
// but not within a command and another one
// is within a command at the time the
// waiter is created.
quiesced := &DoltSession{}
running := &DoltSession{}
controller := NewGCSafepointController()
controller.SessionCommandBegin(quiesced)
controller.SessionCommandBegin(running)
controller.SessionCommandEnd(quiesced)
sawQuiesced, sawRunning, waitDone := make(chan struct{}), make(chan struct{}), make(chan struct{})
wait := func(_ context.Context, s *DoltSession) error {
if s == quiesced {
close(sawQuiesced)
} else if s == running {
close(sawRunning)
} else {
panic("saw unexpected session")
}
return nil
}
waiter := controller.Waiter(context.Background(), nil, wait)
go func() {
waiter.Wait(context.Background())
close(waitDone)
}()
<-sawQuiesced
select {
case <-sawRunning:
require.FailNow(t, "unexpected saw running session on callback before it was quiesced")
case <-time.After(50 * time.Millisecond):
}
controller.SessionCommandEnd(running)
<-sawRunning
<-waitDone
controller.SessionCommandBegin(quiesced)
controller.SessionCommandBegin(running)
controller.SessionCommandEnd(quiesced)
controller.SessionCommandEnd(running)
})
t.Run("OneQuiescedOneNotCanceledContext", func(t *testing.T) {
t.Parallel()
// When the Wait context is canceled, we do not block on
// the running sessions and they never get visited.
quiesced := &DoltSession{}
running := &DoltSession{}
controller := NewGCSafepointController()
controller.SessionCommandBegin(quiesced)
controller.SessionCommandBegin(running)
controller.SessionCommandEnd(quiesced)
sawQuiesced, sawRunning, waitDone := make(chan struct{}), make(chan struct{}), make(chan struct{})
wait := func(_ context.Context, s *DoltSession) error {
if s == quiesced {
close(sawQuiesced)
} else if s == running {
close(sawRunning)
} else {
panic("saw unexpected session")
}
return nil
}
waiter := controller.Waiter(context.Background(), nil, wait)
var waitErr error
go func() {
ctx, cancel := context.WithCancel(context.Background())
cancel()
waitErr = waiter.Wait(ctx)
close(waitDone)
}()
<-sawQuiesced
<-waitDone
require.Error(t, waitErr)
select {
case <-sawRunning:
require.FailNow(t, "unexpected saw running session on callback before it was quiesced")
case <-time.After(50 * time.Millisecond):
}
controller.SessionCommandEnd(running)
select {
case <-sawRunning:
require.FailNow(t, "unexpected saw running session on callback before it was quiesced")
case <-time.After(50 * time.Millisecond):
}
controller.SessionCommandBegin(quiesced)
controller.SessionCommandBegin(running)
controller.SessionCommandEnd(quiesced)
controller.SessionCommandEnd(running)
})
t.Run("BeginBlocksUntilVisitFinished", func(t *testing.T) {
t.Parallel()
quiesced := &DoltSession{}
running := &DoltSession{}
controller := NewGCSafepointController()
controller.SessionCommandBegin(quiesced)
controller.SessionCommandEnd(quiesced)
controller.SessionCommandBegin(running)
finishQuiesced, finishRunning := make(chan struct{}), make(chan struct{})
sawQuiesced, sawRunning := make(chan struct{}), make(chan struct{})
wait := func(_ context.Context, s *DoltSession) error {
if s == quiesced {
close(sawQuiesced)
<-finishQuiesced
} else if s == running {
close(sawRunning)
<-finishRunning
} else {
panic("saw unexpected session")
}
return nil
}
waiter := controller.Waiter(context.Background(), nil, wait)
waitDone := make(chan struct{})
go func() {
waiter.Wait(context.Background())
close(waitDone)
}()
beginDone := make(chan struct{})
go func() {
controller.SessionCommandBegin(quiesced)
close(beginDone)
}()
<-sawQuiesced
select {
case <-beginDone:
require.FailNow(t, "unexpected beginDone")
case <-time.After(50 * time.Millisecond):
}
newSession := &DoltSession{}
controller.SessionCommandBegin(newSession)
controller.SessionCommandEnd(newSession)
controller.SessionEnd(newSession)
close(finishQuiesced)
<-beginDone
beginDone = make(chan struct{})
go func() {
controller.SessionCommandEnd(running)
<-sawRunning
controller.SessionCommandBegin(running)
close(beginDone)
}()
select {
case <-beginDone:
require.FailNow(t, "unexpected beginDone")
case <-time.After(50 * time.Millisecond):
}
close(finishRunning)
<-beginDone
<-waitDone
controller.SessionCommandEnd(quiesced)
controller.SessionCommandEnd(running)
controller.SessionCommandBegin(quiesced)
controller.SessionCommandBegin(running)
controller.SessionCommandEnd(quiesced)
controller.SessionCommandEnd(running)
controller.SessionEnd(quiesced)
controller.SessionEnd(running)
err := controller.Waiter(context.Background(), nil, func(context.Context, *DoltSession) error {
panic("unexpected registered session")
}).Wait(context.Background())
require.NoError(t, err)
})
})
}

View File

@@ -50,19 +50,20 @@ var ErrSessionNotPersistable = errors.New("session is not persistable")
// DoltSession is the sql.Session implementation used by dolt. It is accessible through a *sql.Context instance
type DoltSession struct {
sql.Session
DoltgresSessObj any // This is used by Doltgres to persist objects in the session. This is not used by Dolt.
username string
email string
dbStates map[string]*DatabaseSessionState
dbCache *DatabaseCache
provider DoltDatabaseProvider
tempTables map[string][]sql.Table
globalsConf config.ReadWriteConfig
branchController *branch_control.Controller
statsProv sql.StatsProvider
mu *sync.Mutex
fs filesys.Filesys
writeSessProv WriteSessFunc
DoltgresSessObj any // This is used by Doltgres to persist objects in the session. This is not used by Dolt.
username string
email string
dbStates map[string]*DatabaseSessionState
dbCache *DatabaseCache
provider DoltDatabaseProvider
tempTables map[string][]sql.Table
globalsConf config.ReadWriteConfig
branchController *branch_control.Controller
statsProv sql.StatsProvider
mu *sync.Mutex
fs filesys.Filesys
writeSessProv WriteSessFunc
gcSafepointController *GCSafepointController
// If non-nil, this will be returned from ValidateSession.
// Used by sqle/cluster to put a session into a terminal err state.
@@ -100,25 +101,27 @@ func NewDoltSession(
branchController *branch_control.Controller,
statsProvider sql.StatsProvider,
writeSessProv WriteSessFunc,
gcSafepointController *GCSafepointController,
) (*DoltSession, error) {
username := conf.GetStringOrDefault(config.UserNameKey, "")
email := conf.GetStringOrDefault(config.UserEmailKey, "")
globals := config.NewPrefixConfig(conf, env.SqlServerGlobalsPrefix)
sess := &DoltSession{
Session: sqlSess,
username: username,
email: email,
dbStates: make(map[string]*DatabaseSessionState),
dbCache: newDatabaseCache(),
provider: pro,
tempTables: make(map[string][]sql.Table),
globalsConf: globals,
branchController: branchController,
statsProv: statsProvider,
mu: &sync.Mutex{},
fs: pro.FileSystem(),
writeSessProv: writeSessProv,
Session: sqlSess,
username: username,
email: email,
dbStates: make(map[string]*DatabaseSessionState),
dbCache: newDatabaseCache(),
provider: pro,
tempTables: make(map[string][]sql.Table),
globalsConf: globals,
branchController: branchController,
statsProv: statsProvider,
mu: &sync.Mutex{},
fs: pro.FileSystem(),
writeSessProv: writeSessProv,
gcSafepointController: gcSafepointController,
}
return sess, nil
@@ -790,6 +793,94 @@ func (d *DoltSession) Rollback(ctx *sql.Context, tx sql.Transaction) error {
return nil
}
// As part of GC, ongoing *DoltSessions are asked to make their roots available to the GC process.
// A *DoltSession has the following roots:
// 1) All of the branchStates for the database.
// 2) If there is an active transaction, the initial root for that transaction and any roots for any savepoints of that transaction.
// 3) Working set roots in any writeSession.
func (d *DoltSession) VisitGCRoots(ctx context.Context, dbName string, keep func(hash.Hash) bool) error {
dbName = strings.ToLower(dbName)
dbName, _ = SplitRevisionDbName(dbName)
d.mu.Lock()
dbState, dbStateFound := d.dbStates[dbName]
d.mu.Unlock()
if dbStateFound {
for _, head := range dbState.heads {
if head.headRoot != nil {
h, err := head.headRoot.HashOf()
if err != nil {
return err
}
if keep(h) {
panic("gc safepoint establishment found inconsistent state; process could not guarantee it would be able to keep a chunk if we continue")
}
} else if head.headCommit != nil {
h, err := head.headCommit.HashOf()
if err != nil {
return err
}
if keep(h) {
panic("gc safepoint establishment found inconsistent state; process could not guarantee it would be able to keep a chunk if we continue")
}
} else if head.workingSet != nil {
hashes, err := head.dbData.Ddb.WorkingSetHashes(ctx, head.workingSet)
if err != nil {
return err
}
for _, h := range hashes {
if keep(h) {
panic("gc safepoint establishment found inconsistent state; process could not guarantee it would be able to keep a chunk if we continue")
}
}
}
if head.writeSession != nil {
ws := head.writeSession.GetWorkingSet()
hashes, err := head.dbData.Ddb.WorkingSetHashes(ctx, ws)
if err != nil {
return err
}
for _, h := range hashes {
if keep(h) {
panic("gc safepoint establishment found inconsistent state; process could not guarantee it would be able to keep a chunk if we continue")
}
}
}
}
}
tx := d.GetTransaction()
if tx == nil {
return nil
}
dtx, ok := tx.(*DoltTransaction)
if !ok {
// weird...
return nil
}
h, has := dtx.GetInitialRoot(dbName)
if has && keep(h) {
panic("gc safepoint establishment found inconsistent state; process could not guarantee it could would be able to keep a chunk if we continue")
}
for _, savepoint := range dtx.savepoints {
rv, ok := savepoint.roots[dbName]
if ok {
h, err := rv.HashOf()
if err != nil {
return err
}
if keep(h) {
panic("gc safepoint establishment found inconsistent state; process could not guarantee it could would be able to keep a chunk if we continue")
}
}
}
return nil
}
// CreateSavepoint creates a new savepoint for this transaction with the name given. A previously created savepoint
// with the same name will be overwritten.
func (d *DoltSession) CreateSavepoint(ctx *sql.Context, tx sql.Transaction, savepointName string) error {
@@ -1628,6 +1719,33 @@ func (d *DoltSession) GetController() *branch_control.Controller {
return d.branchController
}
// Implement sql.LifecycleAwareSession, allowing for GC safepoints to be aware of
// outstanding SQL operations.
func (d *DoltSession) CommandBegin() error {
if d.gcSafepointController != nil {
return d.gcSafepointController.SessionCommandBegin(d)
}
return nil
}
func (d *DoltSession) CommandEnd() {
if d.gcSafepointController != nil {
d.gcSafepointController.SessionCommandEnd(d)
}
}
func (d *DoltSession) SessionEnd() {
if d.gcSafepointController != nil {
d.gcSafepointController.SessionEnd(d)
}
}
// dolt_gc accesses the safepoint controller for the current
// sql engine through here.
func (d *DoltSession) GCSafepointController() *GCSafepointController {
return d.gcSafepointController
}
// validatePersistedSysVar checks whether a system variable exists and is dynamic
func validatePersistableSysVar(name string) (sql.SystemVariable, interface{}, error) {
sysVar, val, ok := sql.SystemVariables.GetGlobal(name)

View File

@@ -93,7 +93,7 @@ type dbRoot struct {
type savepoint struct {
name string
// TODO: we need a root value per DB here
// Maps from db name to the root value for that database.
roots map[string]doltdb.RootValue
}

View File

@@ -44,22 +44,23 @@ import (
)
type DoltHarness struct {
t *testing.T
provider dsess.DoltDatabaseProvider
statsPro sql.StatsProvider
multiRepoEnv *env.MultiRepoEnv
session *dsess.DoltSession
branchControl *branch_control.Controller
parallelism int
skippedQueries []string
setupData []setup.SetupScript
resetData []setup.SetupScript
engine *gms.Engine
setupDbs map[string]struct{}
skipSetupCommit bool
configureStats bool
useLocalFilesystem bool
setupTestProcedures bool
t *testing.T
provider dsess.DoltDatabaseProvider
statsPro sql.StatsProvider
multiRepoEnv *env.MultiRepoEnv
session *dsess.DoltSession
branchControl *branch_control.Controller
gcSafepointController *dsess.GCSafepointController
parallelism int
skippedQueries []string
setupData []setup.SetupScript
resetData []setup.SetupScript
engine *gms.Engine
setupDbs map[string]struct{}
skipSetupCommit bool
configureStats bool
useLocalFilesystem bool
setupTestProcedures bool
}
func (d *DoltHarness) UseLocalFileSystem() {
@@ -243,11 +244,13 @@ func (d *DoltHarness) NewEngine(t *testing.T) (enginetest.QueryEngine, error) {
require.True(t, ok)
d.provider = doltProvider
d.gcSafepointController = dsess.NewGCSafepointController()
statsProv := statspro.NewProvider(d.provider.(*sqle.DoltDatabaseProvider), statsnoms.NewNomsStatsFactory(d.multiRepoEnv.RemoteDialProvider()))
d.statsPro = statsProv
var err error
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), d.provider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession)
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), d.provider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession, d.gcSafepointController)
require.NoError(t, err)
e, err := enginetest.NewEngine(t, d, d.provider, d.setupData, d.statsPro)
@@ -274,7 +277,7 @@ func (d *DoltHarness) NewEngine(t *testing.T) (enginetest.QueryEngine, error) {
}
// Get a fresh session after running setup scripts, since some setup scripts can change the session state
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), d.provider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession)
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), d.provider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession, nil)
require.NoError(t, err)
}
@@ -315,7 +318,7 @@ func (d *DoltHarness) NewEngine(t *testing.T) (enginetest.QueryEngine, error) {
e, err := enginetest.RunSetupScripts(sqlCtx, d.engine, d.resetScripts(), d.SupportsNativeIndexCreation())
// Get a fresh session after running setup scripts, since some setup scripts can change the session state
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), d.provider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession)
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), d.provider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession, nil)
require.NoError(t, err)
return e, err
@@ -397,7 +400,7 @@ func (d *DoltHarness) newSessionWithClient(client sql.Client) *dsess.DoltSession
localConfig := d.multiRepoEnv.Config()
pro := d.session.Provider()
dSession, err := dsess.NewDoltSession(sql.NewBaseSessionWithClientServer("address", client, 1), pro.(dsess.DoltDatabaseProvider), localConfig, d.branchControl, d.statsPro, writer.NewWriteSession)
dSession, err := dsess.NewDoltSession(sql.NewBaseSessionWithClientServer("address", client, 1), pro.(dsess.DoltDatabaseProvider), localConfig, d.branchControl, d.statsPro, writer.NewWriteSession, nil)
dSession.SetCurrentDatabase("mydb")
require.NoError(d.t, err)
return dSession
@@ -430,7 +433,7 @@ func (d *DoltHarness) NewDatabases(names ...string) []sql.Database {
d.statsPro = statspro.NewProvider(doltProvider, statsnoms.NewNomsStatsFactory(d.multiRepoEnv.RemoteDialProvider()))
var err error
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), doltProvider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession)
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), doltProvider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession, nil)
require.NoError(d.t, err)
// TODO: the engine tests should do this for us
@@ -487,7 +490,7 @@ func (d *DoltHarness) NewReadOnlyEngine(provider sql.DatabaseProvider) (enginete
}
// reset the session as well since we have swapped out the database provider, which invalidates caching assumptions
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), readOnlyProvider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession)
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), readOnlyProvider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession, d.gcSafepointController)
require.NoError(d.t, err)
return enginetest.NewEngineWithProvider(nil, d, readOnlyProvider), nil

View File

@@ -144,7 +144,7 @@ func innerInit(h *DoltHarness, dEnv *env.DoltEnv) error {
return err
}
sqlCtx := dsql.NewTestSQLCtxWithProvider(ctx, pro, statspro.NewProvider(pro.(*dsql.DoltDatabaseProvider), statsnoms.NewNomsStatsFactory(env.NewGRPCDialProviderFromDoltEnv(dEnv))))
sqlCtx := dsql.NewTestSQLCtxWithProvider(ctx, pro, statspro.NewProvider(pro.(*dsql.DoltDatabaseProvider), statsnoms.NewNomsStatsFactory(env.NewGRPCDialProviderFromDoltEnv(dEnv))), dsess.NewGCSafepointController())
h.sess = sqlCtx.Session.(*dsess.DoltSession)
dbs := h.engine.Analyzer.Catalog.AllDatabases(sqlCtx)

View File

@@ -16,13 +16,15 @@ package sqle
import (
"context"
"errors"
"net/http"
"github.com/dolthub/go-mysql-server/sql"
"google.golang.org/grpc"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/remotesrv"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/store/datas"
)
@@ -81,17 +83,12 @@ type CreateUnknownDatabasesSetting bool
const CreateUnknownDatabases CreateUnknownDatabasesSetting = true
const DoNotCreateUnknownDatabases CreateUnknownDatabasesSetting = false
// Considers |args| and returns a new |remotesrv.ServerArgs| instance which
// will serve databases accessible through |ctxFactory|.
func RemoteSrvFSAndDBCache(ctxFactory func(context.Context) (*sql.Context, error), createSetting CreateUnknownDatabasesSetting) (filesys.Filesys, remotesrv.DBCache, error) {
sqlCtx, err := ctxFactory(context.Background())
if err != nil {
return nil, nil, err
}
sess := dsess.DSessFromSess(sqlCtx.Session)
fs := sess.Provider().FileSystem()
// Returns a remotesrv.DBCache instance which will use the *sql.Context
// returned from |ctxFactory| to access a database in the session
// DatabaseProvider.
func RemoteSrvDBCache(ctxFactory func(context.Context) (*sql.Context, error), createSetting CreateUnknownDatabasesSetting) (remotesrv.DBCache, error) {
dbcache := remotesrvStore{ctxFactory, bool(createSetting)}
return fs, dbcache, nil
return dbcache, nil
}
func WithUserPasswordAuth(args remotesrv.ServerArgs, authnz remotesrv.AccessControl) remotesrv.ServerArgs {
@@ -102,3 +99,88 @@ func WithUserPasswordAuth(args remotesrv.ServerArgs, authnz remotesrv.AccessCont
args.Options = append(args.Options, si.Options()...)
return args
}
type SqlContextServerInterceptor struct {
Factory func(context.Context) (*sql.Context, error)
}
type serverStreamWrapper struct {
grpc.ServerStream
ctx context.Context
}
func (s serverStreamWrapper) Context() context.Context {
return s.ctx
}
type sqlContextInterceptorKey struct{}
func GetInterceptorSqlContext(ctx context.Context) (*sql.Context, error) {
if v := ctx.Value(sqlContextInterceptorKey{}); v != nil {
return v.(*sql.Context), nil
}
return nil, errors.New("misconfiguration; a sql.Context should always be available from the interceptor chain.")
}
func (si SqlContextServerInterceptor) Stream() grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
sqlCtx, err := si.Factory(ss.Context())
if err != nil {
return err
}
defer sql.SessionEnd(sqlCtx.Session)
sql.SessionCommandBegin(sqlCtx.Session)
defer sql.SessionCommandEnd(sqlCtx.Session)
newCtx := context.WithValue(ss.Context(), sqlContextInterceptorKey{}, sqlCtx)
newSs := serverStreamWrapper{
ServerStream: ss,
ctx: newCtx,
}
return handler(srv, newSs)
}
}
func (si SqlContextServerInterceptor) Unary() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
sqlCtx, err := si.Factory(ctx)
if err != nil {
return nil, err
}
defer sql.SessionEnd(sqlCtx.Session)
sql.SessionCommandBegin(sqlCtx.Session)
defer sql.SessionCommandEnd(sqlCtx.Session)
newCtx := context.WithValue(ctx, sqlContextInterceptorKey{}, sqlCtx)
return handler(newCtx, req)
}
}
func (si SqlContextServerInterceptor) HTTP(existing func(http.Handler) http.Handler) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
this := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
sqlCtx, err := si.Factory(ctx)
if err != nil {
http.Error(w, "could not initialize sql.Context", http.StatusInternalServerError)
return
}
defer sql.SessionEnd(sqlCtx.Session)
sql.SessionCommandBegin(sqlCtx.Session)
defer sql.SessionCommandEnd(sqlCtx.Session)
newCtx := context.WithValue(ctx, sqlContextInterceptorKey{}, sqlCtx)
newReq := r.WithContext(newCtx)
h.ServeHTTP(w, newReq)
})
if existing != nil {
return existing(this)
} else {
return this
}
}
}
func (si SqlContextServerInterceptor) Options() []grpc.ServerOption {
return []grpc.ServerOption{
grpc.ChainUnaryInterceptor(si.Unary()),
grpc.ChainStreamInterceptor(si.Stream()),
}
}
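A minimal wiring sketch, assuming a hypothetical ctxFactory with the same signature as Factory and a hypothetical apiHandler http.Handler; it shows how both the gRPC and HTTP paths get the session command bracketing applied:
si := SqlContextServerInterceptor{Factory: ctxFactory}
// gRPC: chain the unary and stream interceptors into the server.
grpcServer := grpc.NewServer(si.Options()...)
// HTTP: wrap an existing handler so each request runs inside a session command.
httpHandler := si.HTTP(nil)(apiHandler)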

View File

@@ -1108,13 +1108,14 @@ func newTestEngine(ctx context.Context, dEnv *env.DoltEnv) (*gms.Engine, *sql.Co
if err != nil {
panic(err)
}
mrEnv, err := env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), dEnv.FS, dEnv.Version, dEnv)
if err != nil {
panic(err)
}
doltSession, err := dsess.NewDoltSession(sql.NewBaseSession(), pro, dEnv.Config.WriteableConfig(), nil, nil, writer.NewWriteSession)
gcSafepointController := dsess.NewGCSafepointController()
doltSession, err := dsess.NewDoltSession(sql.NewBaseSession(), pro, dEnv.Config.WriteableConfig(), nil, nil, writer.NewWriteSession, gcSafepointController)
if err != nil {
panic(err)
}

View File

@@ -115,8 +115,8 @@ func ExecuteSql(ctx context.Context, dEnv *env.DoltEnv, root doltdb.RootValue, s
return db.GetRoot(sqlCtx)
}
func NewTestSQLCtxWithProvider(ctx context.Context, pro dsess.DoltDatabaseProvider, statsPro sql.StatsProvider) *sql.Context {
s, err := dsess.NewDoltSession(sql.NewBaseSession(), pro, config.NewMapConfig(make(map[string]string)), branch_control.CreateDefaultController(ctx), statsPro, writer.NewWriteSession)
func NewTestSQLCtxWithProvider(ctx context.Context, pro dsess.DoltDatabaseProvider, statsPro sql.StatsProvider, gcSafepointController *dsess.GCSafepointController) *sql.Context {
s, err := dsess.NewDoltSession(sql.NewBaseSession(), pro, config.NewMapConfig(make(map[string]string)), branch_control.CreateDefaultController(ctx), statsPro, writer.NewWriteSession, gcSafepointController)
if err != nil {
panic(err)
}
@@ -135,10 +135,11 @@ func NewTestEngine(dEnv *env.DoltEnv, ctx context.Context, db dsess.SqlDatabase)
if err != nil {
return nil, nil, err
}
gcSafepointController := dsess.NewGCSafepointController()
engine := sqle.NewDefault(pro)
sqlCtx := NewTestSQLCtxWithProvider(ctx, pro, nil)
sqlCtx := NewTestSQLCtxWithProvider(ctx, pro, nil, gcSafepointController)
sqlCtx.SetCurrentDatabase(db.Name())
return engine, sqlCtx, nil
}

View File

@@ -769,10 +769,10 @@ func (sm SerialMessage) WalkAddrs(nbf *NomsBinFormat, cb func(addr hash.Hash) er
return err
}
}
case serial.TableSchemaFileID, serial.ForeignKeyCollectionFileID:
case serial.TableSchemaFileID, serial.ForeignKeyCollectionFileID, serial.TupleFileID:
// no further references from these file types
return nil
case serial.ProllyTreeNodeFileID, serial.AddressMapFileID, serial.MergeArtifactsFileID, serial.BlobFileID, serial.CommitClosureFileID:
case serial.ProllyTreeNodeFileID, serial.AddressMapFileID, serial.MergeArtifactsFileID, serial.BlobFileID, serial.CommitClosureFileID, serial.VectorIndexNodeFileID:
return message.WalkAddresses(context.TODO(), serial.Message(sm), func(ctx context.Context, addr hash.Hash) error {
return cb(addr)
})

View File

@@ -430,3 +430,14 @@ SQL
[[ "$output" =~ "pk1" ]] || false
[[ "${#lines[@]}" = "1" ]] || false
}
@test "vector-index: can GC" {
dolt sql <<SQL
CREATE VECTOR INDEX idx_v1 ON onepk(v1);
INSERT INTO onepk VALUES (1, '[99, 51]'), (2, '[11, 55]'), (3, '[88, 52]'), (4, '[22, 54]'), (5, '[77, 53]');
SQL
dolt gc
dolt sql <<SQL
INSERT INTO onepk VALUES (6, '[99, 51]'), (7, '[11, 55]'), (8, '[88, 52]'), (9, '[22, 54]'), (10, '[77, 53]');
SQL
}

View File

@@ -31,49 +31,64 @@ import (
)
func TestConcurrentGC(t *testing.T) {
t.Run("NoCommits", func(t *testing.T) {
t.Run("Normal", func(t *testing.T) {
var gct = gcTest{
numThreads: 8,
duration: 10 * time.Second,
}
gct.run(t)
})
t.Run("Full", func(t *testing.T) {
var gct = gcTest{
numThreads: 8,
duration: 10 * time.Second,
full: true,
}
gct.run(t)
})
})
t.Run("WithCommits", func(t *testing.T) {
t.Run("Normal", func(t *testing.T) {
var gct = gcTest{
numThreads: 8,
duration: 10 * time.Second,
commit: true,
}
gct.run(t)
})
t.Run("Full", func(t *testing.T) {
var gct = gcTest{
numThreads: 8,
duration: 10 * time.Second,
commit: true,
full: true,
}
gct.run(t)
})
})
type dimension struct {
names []string
factors func(gcTest) []gcTest
}
commits := dimension{
names: []string{"NoCommits", "WithCommits"},
factors: func(base gcTest) []gcTest {
no, yes := base, base
no.commit = false
yes.commit = true
return []gcTest{no, yes}
},
}
full := dimension{
names: []string{"NotFull", "Full"},
factors: func(base gcTest) []gcTest {
no, yes := base, base
no.full = false
yes.full = true
return []gcTest{no, yes}
},
}
safepoint := dimension{
names: []string{"KillConnections", "SessionAware"},
factors: func(base gcTest) []gcTest {
no, yes := base, base
no.sessionAware = false
yes.sessionAware = true
return []gcTest{no, yes}
},
}
var doDimensions func(t *testing.T, base gcTest, dims []dimension)
doDimensions = func(t *testing.T, base gcTest, dims []dimension) {
if len(dims) == 0 {
base.run(t)
return
}
dim, dims := dims[0], dims[1:]
dimf := dim.factors(base)
for i := range dim.names {
t.Run(dim.names[i], func(t *testing.T) {
doDimensions(t, dimf[i], dims)
})
}
}
dimensions := []dimension{commits, full, safepoint}
doDimensions(t, gcTest{
numThreads: 8,
duration: 10 * time.Second,
}, dimensions)
}
type gcTest struct {
numThreads int
duration time.Duration
commit bool
full bool
numThreads int
duration time.Duration
commit bool
full bool
sessionAware bool
}
func (gct gcTest) createDB(t *testing.T, ctx context.Context, db *sql.DB) {
@@ -96,13 +111,19 @@ func (gct gcTest) createDB(t *testing.T, ctx context.Context, db *sql.DB) {
func (gct gcTest) doUpdate(t *testing.T, ctx context.Context, db *sql.DB, i int) error {
conn, err := db.Conn(ctx)
if err != nil {
if gct.sessionAware {
if !assert.NoError(t, err) {
return nil
}
} else if err != nil {
t.Logf("err in Conn: %v", err)
return nil
}
defer conn.Close()
_, err = conn.ExecContext(ctx, "update vals set val = val+1 where id = ?", i)
if err != nil {
if gct.sessionAware {
assert.NoError(t, err)
} else if err != nil {
if !assert.NotContains(t, err.Error(), "dangling ref") {
return err
}
@@ -116,7 +137,9 @@ func (gct gcTest) doUpdate(t *testing.T, ctx context.Context, db *sql.DB, i int)
}
if gct.commit {
_, err = conn.ExecContext(ctx, fmt.Sprintf("call dolt_commit('-am', 'increment vals id = %d')", i))
if err != nil {
if gct.sessionAware {
assert.NoError(t, err)
} else if err != nil {
if !assert.NotContains(t, err.Error(), "dangling ref") {
return err
}
@@ -134,16 +157,24 @@ func (gct gcTest) doUpdate(t *testing.T, ctx context.Context, db *sql.DB, i int)
func (gct gcTest) doGC(t *testing.T, ctx context.Context, db *sql.DB) error {
conn, err := db.Conn(ctx)
if err != nil {
if gct.sessionAware {
if !assert.NoError(t, err) {
return nil
}
} else if err != nil {
t.Logf("err in Conn for dolt_gc: %v", err)
return nil
}
defer func() {
// After calling dolt_gc, the connection is bad. Remove it from the connection pool.
conn.Raw(func(_ any) error {
return sqldriver.ErrBadConn
})
}()
if !gct.sessionAware {
defer func() {
// After calling dolt_gc, the connection is bad. Remove it from the connection pool.
conn.Raw(func(_ any) error {
return sqldriver.ErrBadConn
})
}()
} else {
defer conn.Close()
}
b := time.Now()
if !gct.full {
_, err = conn.ExecContext(ctx, "call dolt_gc()")
@@ -204,7 +235,14 @@ func (gct gcTest) run(t *testing.T) {
repo, err := rs.MakeRepo("concurrent_gc_test")
require.NoError(t, err)
server := MakeServer(t, repo, &driver.Server{})
srvSettings := &driver.Server{}
if gct.sessionAware {
srvSettings.Envs = append(srvSettings.Envs, "DOLT_GC_SAFEPOINT_CONTROLLER_CHOICE=session_aware")
} else {
srvSettings.Envs = append(srvSettings.Envs, "DOLT_GC_SAFEPOINT_CONTROLLER_CHOICE=kill_connections")
}
server := MakeServer(t, repo, srvSettings)
server.DBName = "concurrent_gc_test"
db, err := server.DB(driver.Connection{User: "root"})