Merge pull request #8830 from dolthub/aaron/sqle-database-replication-fix-double-hook

[no-release-notes] go: sqle: DatabaseProvider: Fix double-creation of push-on-write commit hooks in registerNewDatabase.
This commit is contained in:
Aaron Son
2025-02-10 17:09:18 -08:00
committed by GitHub
16 changed files with 307 additions and 115 deletions
+1 -1
View File
@@ -129,7 +129,7 @@ func NewSqlEngine(
gcSafepointController := dsess.NewGCSafepointController()
b := env.GetDefaultInitBranch(mrEnv.Config())
pro, err := dsqle.NewDoltDatabaseProviderWithDatabases(b, mrEnv.FileSystem(), all, locations)
pro, err := dsqle.NewDoltDatabaseProviderWithDatabases(b, mrEnv.FileSystem(), all, locations, bThreads)
if err != nil {
return nil, err
}
+2 -1
View File
@@ -351,7 +351,8 @@ func rebaseSqlEngine(ctx context.Context, dEnv *env.DoltEnv, root doltdb.RootVal
}
b := env.GetDefaultInitBranch(dEnv.Config)
pro, err := dsqle.NewDoltDatabaseProviderWithDatabase(b, mrEnv.FileSystem(), db, dEnv.FS)
bThreads := sql.NewBackgroundThreads()
pro, err := dsqle.NewDoltDatabaseProviderWithDatabase(b, mrEnv.FileSystem(), db, dEnv.FS, bThreads)
if err != nil {
return nil, nil, err
}
@@ -130,7 +130,7 @@ func TestPushOnWriteHook(t *testing.T) {
// setup hook
hook := NewPushOnWriteHook(destDB, tmpDir)
ddb.SetCommitHooks(ctx, []CommitHook{hook})
ddb.PrependCommitHooks(ctx, hook)
t.Run("replicate to remote", func(t *testing.T) {
srcCommit, err := ddb.Commit(context.Background(), valHash, ref.NewBranchRef(defaultBranch), meta)
@@ -293,7 +293,7 @@ func TestAsyncPushOnWrite(t *testing.T) {
// same as the call which is made after a branch delete.
counts := &countingCommitHook{make(map[string]int)}
destDB.SetCommitHooks(context.Background(), []CommitHook{counts})
destDB.PrependCommitHooks(context.Background(), counts)
bThreads := sql.NewBackgroundThreads()
hook, err := NewAsyncPushOnWriteHook(bThreads, destDB, tmpDir, &buffer.Buffer{})
+2 -7
View File
@@ -1954,13 +1954,8 @@ func (ddb *DoltDB) DatasetsByRootHash(ctx context.Context, hashof hash.Hash) (da
return ddb.db.DatasetsByRootHash(ctx, hashof)
}
func (ddb *DoltDB) SetCommitHooks(ctx context.Context, postHooks []CommitHook) *DoltDB {
ddb.db = ddb.db.SetCommitHooks(ctx, postHooks)
return ddb
}
func (ddb *DoltDB) PrependCommitHook(ctx context.Context, hook CommitHook) *DoltDB {
ddb.db = ddb.db.SetCommitHooks(ctx, append([]CommitHook{hook}, ddb.db.PostCommitHooks()...))
func (ddb *DoltDB) PrependCommitHooks(ctx context.Context, hooks ...CommitHook) *DoltDB {
ddb.db = ddb.db.SetCommitHooks(ctx, append(hooks, ddb.db.PostCommitHooks()...))
return ddb
}
@@ -315,7 +315,7 @@ func (c *Controller) applyCommitHooks(ctx context.Context, name string, bt *sql.
commitHook := newCommitHook(c.lgr, r.Name(), remote.Url, name, c.role, func(ctx context.Context) (*doltdb.DoltDB, error) {
return remote.GetRemoteDB(ctx, types.Format_Default, dialprovider)
}, denv.DoltDB(ctx), ttfdir)
denv.DoltDB(ctx).PrependCommitHook(ctx, commitHook)
denv.DoltDB(ctx).PrependCommitHooks(ctx, commitHook)
if err := commitHook.Run(bt); err != nil {
return nil, err
}
@@ -84,7 +84,7 @@ func NewInitDatabaseHook(controller *Controller, bt *sql.BackgroundThreads) sqle
return err
}
commitHook := newCommitHook(controller.lgr, r.Name(), remoteUrls[i], name, role, remoteDBs[i], denv.DoltDB(ctx), ttfdir)
denv.DoltDB(ctx).PrependCommitHook(ctx, commitHook)
denv.DoltDB(ctx).PrependCommitHooks(ctx, commitHook)
controller.registerCommitHook(commitHook)
if err := commitHook.Run(bt); err != nil {
// XXX: An error here means we are not replicating to every standby.
+65 -61
View File
@@ -92,21 +92,21 @@ func (p *DoltDatabaseProvider) WithTableFunctions(fns ...sql.TableFunction) (sql
// NewDoltDatabaseProvider returns a new provider, initialized without any databases, along with any
// errors that occurred while trying to create the database provider.
func NewDoltDatabaseProvider(defaultBranch string, fs filesys.Filesys) (*DoltDatabaseProvider, error) {
return NewDoltDatabaseProviderWithDatabases(defaultBranch, fs, nil, nil)
func NewDoltDatabaseProvider(defaultBranch string, fs filesys.Filesys, bThreads *sql.BackgroundThreads) (*DoltDatabaseProvider, error) {
return NewDoltDatabaseProviderWithDatabases(defaultBranch, fs, nil, nil, bThreads)
}
// NewDoltDatabaseProviderWithDatabase returns a new provider, initialized with one database at the
// specified location, and any error that occurred along the way.
func NewDoltDatabaseProviderWithDatabase(defaultBranch string, fs filesys.Filesys, database dsess.SqlDatabase, dbLocation filesys.Filesys) (*DoltDatabaseProvider, error) {
return NewDoltDatabaseProviderWithDatabases(defaultBranch, fs, []dsess.SqlDatabase{database}, []filesys.Filesys{dbLocation})
func NewDoltDatabaseProviderWithDatabase(defaultBranch string, fs filesys.Filesys, database dsess.SqlDatabase, dbLocation filesys.Filesys, bThreads *sql.BackgroundThreads) (*DoltDatabaseProvider, error) {
return NewDoltDatabaseProviderWithDatabases(defaultBranch, fs, []dsess.SqlDatabase{database}, []filesys.Filesys{dbLocation}, bThreads)
}
// NewDoltDatabaseProviderWithDatabases returns a new provider, initialized with the specified databases,
// at the specified locations. For every database specified, there must be a corresponding filesystem
// specified that represents where the database is located. If the number of specified databases is not the
// same as the number of specified locations, an error is returned.
func NewDoltDatabaseProviderWithDatabases(defaultBranch string, fs filesys.Filesys, databases []dsess.SqlDatabase, locations []filesys.Filesys) (*DoltDatabaseProvider, error) {
func NewDoltDatabaseProviderWithDatabases(defaultBranch string, fs filesys.Filesys, databases []dsess.SqlDatabase, locations []filesys.Filesys, bThreads *sql.BackgroundThreads) (*DoltDatabaseProvider, error) {
if len(databases) != len(locations) {
return nil, fmt.Errorf("unable to create DoltDatabaseProvider: "+
"incorrect number of databases (%d) and database locations (%d) specified", len(databases), len(locations))
@@ -154,7 +154,7 @@ func NewDoltDatabaseProviderWithDatabases(defaultBranch string, fs filesys.Files
fs: fs,
defaultBranch: defaultBranch,
dbFactoryUrl: dbFactoryUrl,
InitDatabaseHooks: []InitDatabaseHook{ConfigureReplicationDatabaseHook},
InitDatabaseHooks: []InitDatabaseHook{NewConfigureReplicationDatabaseHook(bThreads)},
isStandby: new(bool),
droppedDatabaseManager: newDroppedDatabaseManager(fs),
}, nil
@@ -612,55 +612,60 @@ func (p *DoltDatabaseProvider) CreateCollatedDatabase(ctx *sql.Context, name str
type InitDatabaseHook func(ctx *sql.Context, pro *DoltDatabaseProvider, name string, env *env.DoltEnv, db dsess.SqlDatabase) error
type DropDatabaseHook func(ctx *sql.Context, name string)
// ConfigureReplicationDatabaseHook sets up the hooks to push to a remote to replicate a newly created database.
// TODO: consider the replication heads / all heads setting
func ConfigureReplicationDatabaseHook(ctx *sql.Context, p *DoltDatabaseProvider, name string, newEnv *env.DoltEnv, _ dsess.SqlDatabase) error {
_, replicationRemoteName, _ := sql.SystemVariables.GetGlobal(dsess.ReplicateToRemote)
if replicationRemoteName == "" {
return nil
// NewConfigureReplicationDatabaseHook sets up the hooks to push to a remote to replicate a newly created database.
//
// For a new database, this hook
// 1) creates a new remote based on dsess.ReplicationRemoteURLTemplate
// 2) Installed push-on-write replication hooks based on existing sql.SystemVariables on the *DoltDB
// 3) Triggers the push-on-write hook for the default branch.
func NewConfigureReplicationDatabaseHook(bThreads *sql.BackgroundThreads) func(ctx *sql.Context, p *DoltDatabaseProvider, name string, newEnv *env.DoltEnv, _ dsess.SqlDatabase) error {
return func(ctx *sql.Context, p *DoltDatabaseProvider, name string, newEnv *env.DoltEnv, _ dsess.SqlDatabase) error {
_, replicationRemoteName, _ := sql.SystemVariables.GetGlobal(dsess.ReplicateToRemote)
if replicationRemoteName == "" {
return nil
}
remoteName, ok := replicationRemoteName.(string)
if !ok {
return nil
}
_, remoteUrlTemplate, _ := sql.SystemVariables.GetGlobal(dsess.ReplicationRemoteURLTemplate)
if remoteUrlTemplate == "" {
return nil
}
urlTemplate, ok := remoteUrlTemplate.(string)
if !ok {
return nil
}
// TODO: url sanitize name
remoteUrl := strings.Replace(urlTemplate, dsess.URLTemplateDatabasePlaceholder, name, -1)
// TODO: params for AWS, others that need them
r := env.NewRemote(remoteName, remoteUrl, nil)
err := r.Prepare(ctx, newEnv.DoltDB(ctx).Format(), p.remoteDialer)
if err != nil {
return err
}
err = newEnv.AddRemote(r)
if err != env.ErrRemoteAlreadyExists && err != nil {
return err
}
commitHooks, err := GetCommitHooks(ctx, bThreads, newEnv, cli.CliErr)
if err != nil {
return err
}
newEnv.DoltDB(ctx).PrependCommitHooks(ctx, commitHooks...)
// After setting hooks on the newly created DB, we need to do the first push manually
branchRef := ref.NewBranchRef(p.defaultBranch)
return newEnv.DoltDB(ctx).ExecuteCommitHooks(ctx, branchRef.String())
}
remoteName, ok := replicationRemoteName.(string)
if !ok {
return nil
}
_, remoteUrlTemplate, _ := sql.SystemVariables.GetGlobal(dsess.ReplicationRemoteURLTemplate)
if remoteUrlTemplate == "" {
return nil
}
urlTemplate, ok := remoteUrlTemplate.(string)
if !ok {
return nil
}
// TODO: url sanitize name
remoteUrl := strings.Replace(urlTemplate, dsess.URLTemplateDatabasePlaceholder, name, -1)
// TODO: params for AWS, others that need them
r := env.NewRemote(remoteName, remoteUrl, nil)
err := r.Prepare(ctx, newEnv.DoltDB(ctx).Format(), p.remoteDialer)
if err != nil {
return err
}
err = newEnv.AddRemote(r)
if err != env.ErrRemoteAlreadyExists && err != nil {
return err
}
// TODO: get background threads from the engine
commitHooks, err := GetCommitHooks(ctx, sql.NewBackgroundThreads(), newEnv, cli.CliErr)
if err != nil {
return err
}
newEnv.DoltDB(ctx).SetCommitHooks(ctx, commitHooks)
// After setting hooks on the newly created DB, we need to do the first push manually
branchRef := ref.NewBranchRef(p.defaultBranch)
return newEnv.DoltDB(ctx).ExecuteCommitHooks(ctx, branchRef.String())
}
// CloneDatabaseFromRemote implements DoltDatabaseProvider interface
@@ -866,7 +871,7 @@ func (p *DoltDatabaseProvider) registerNewDatabase(ctx *sql.Context, name string
}
// If we have any initialization hooks, invoke them, until any error is returned.
// By default, this will be ConfigureReplicationDatabaseHook, which will set up
// By default, this will be NewConfigureReplicationDatabaseHook, which will set up
// replication for the new database if a remote url template is set.
for _, initHook := range p.InitDatabaseHooks {
err = initHook(ctx, p, name, newEnv, db)
@@ -875,17 +880,16 @@ func (p *DoltDatabaseProvider) registerNewDatabase(ctx *sql.Context, name string
}
}
mrEnv, err := env.MultiEnvForSingleEnv(ctx, newEnv)
if err != nil {
return err
}
dbs, err := ApplyReplicationConfig(ctx, sql.NewBackgroundThreads(), mrEnv, cli.CliErr, db)
// Push replication is configured by InitDatabaseHooks, but pull-on-read
// replication is a special type of wrapper database, |ReadReplicaDatabase|.
// Transform the |db| into the replicating one if we need to.
sdb, err := applyReadReplicationConfigToDatabase(ctx, newEnv, db)
if err != nil {
return err
}
formattedName := formatDbMapKeyName(db.Name())
p.databases[formattedName] = dbs[0]
p.databases[formattedName] = sdb
p.dbLocations[formattedName] = newEnv.FS
return nil
}
@@ -0,0 +1,171 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqle
import (
"context"
"io"
"testing"
"github.com/dolthub/go-mysql-server/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/dtestutils"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/store/datas"
)
func setGlobalSqlVariable(t *testing.T, name string, val interface{}) {
_, cur, _ := sql.SystemVariables.GetGlobal(name)
t.Cleanup(func() {
sql.SystemVariables.SetGlobal(name, cur)
})
sql.SystemVariables.SetGlobal(name, val)
}
func TestDatabaseProvider(t *testing.T) {
t.Run("ReplicationConfig", func(t *testing.T) {
t.Run("CreateDatabase", func(t *testing.T) {
t.Run("NoReplication", func(t *testing.T) {
ctx := context.Background()
dEnv := dtestutils.CreateTestEnv()
tmpDir, err := dEnv.TempTableFilesDir()
require.NoError(t, err)
opts := editor.Options{Deaf: dEnv.DbEaFactory(ctx), Tempdir: tmpDir}
db, err := NewDatabase(context.Background(), "dolt", dEnv.DbData(ctx), opts)
require.NoError(t, err)
engine, sqlCtx, err := NewTestEngine(dEnv, context.Background(), db)
require.NoError(t, err)
sess := dsess.DSessFromSess(sqlCtx.Session)
sess.Provider().(*DoltDatabaseProvider).AddInitDatabaseHook(InstallSnoopingCommitHook)
err = ExecuteSqlOnEngine(sqlCtx, engine, "CREATE DATABASE mytest;")
require.NoError(t, err)
sqlDb, err := sess.Provider().Database(sqlCtx, "mytest")
require.NoError(t, err)
ddbs := sqlDb.(Database).DoltDatabases()
require.Len(t, ddbs, 1)
hooks := doltdb.HackDatasDatabaseFromDoltDB(ddbs[0]).(interface {
PostCommitHooks() []doltdb.CommitHook
}).PostCommitHooks()
assert.Len(t, hooks, 1)
_, ok := hooks[0].(*snoopingCommitHook)
assert.True(t, ok, "expect hook to be PushOnWriteHook, it is %T", hooks[0])
})
t.Run("PushOnWriteReplication", func(t *testing.T) {
ctx := context.Background()
dEnv := dtestutils.CreateTestEnv()
tmpDir, err := dEnv.TempTableFilesDir()
setGlobalSqlVariable(t, dsess.ReplicateToRemote, "fileremote")
setGlobalSqlVariable(t, dsess.ReplicationRemoteURLTemplate, "mem://remote_{database}")
require.NoError(t, err)
opts := editor.Options{Deaf: dEnv.DbEaFactory(ctx), Tempdir: tmpDir}
db, err := NewDatabase(context.Background(), "dolt", dEnv.DbData(ctx), opts)
require.NoError(t, err)
engine, sqlCtx, err := NewTestEngine(dEnv, context.Background(), db)
require.NoError(t, err)
sess := dsess.DSessFromSess(sqlCtx.Session)
sess.Provider().(*DoltDatabaseProvider).AddInitDatabaseHook(InstallSnoopingCommitHook)
err = ExecuteSqlOnEngine(sqlCtx, engine, "CREATE DATABASE mytest;")
require.NoError(t, err)
sqlDb, err := sess.Provider().Database(sqlCtx, "mytest")
require.NoError(t, err)
ddbs := sqlDb.(Database).DoltDatabases()
require.Len(t, ddbs, 1)
hooks := doltdb.HackDatasDatabaseFromDoltDB(ddbs[0]).(interface {
PostCommitHooks() []doltdb.CommitHook
}).PostCommitHooks()
require.Len(t, hooks, 2)
_, ok := hooks[0].(*snoopingCommitHook)
assert.True(t, ok, "expect hook to be snoopingCommitHook, it is %T", hooks[0])
_, ok = hooks[1].(*doltdb.PushOnWriteHook)
assert.True(t, ok, "expect hook to be PushOnWriteHook, it is %T", hooks[1])
})
t.Run("AsyncPushOnWrite", func(t *testing.T) {
ctx := context.Background()
dEnv := dtestutils.CreateTestEnv()
tmpDir, err := dEnv.TempTableFilesDir()
setGlobalSqlVariable(t, dsess.ReplicateToRemote, "fileremote")
setGlobalSqlVariable(t, dsess.ReplicationRemoteURLTemplate, "mem://remote_{database}")
setGlobalSqlVariable(t, dsess.AsyncReplication, dsess.SysVarTrue)
require.NoError(t, err)
opts := editor.Options{Deaf: dEnv.DbEaFactory(ctx), Tempdir: tmpDir}
db, err := NewDatabase(context.Background(), "dolt", dEnv.DbData(ctx), opts)
require.NoError(t, err)
engine, sqlCtx, err := NewTestEngine(dEnv, context.Background(), db)
require.NoError(t, err)
sess := dsess.DSessFromSess(sqlCtx.Session)
sess.Provider().(*DoltDatabaseProvider).AddInitDatabaseHook(InstallSnoopingCommitHook)
err = ExecuteSqlOnEngine(sqlCtx, engine, "CREATE DATABASE mytest;")
require.NoError(t, err)
sqlDb, err := sess.Provider().Database(sqlCtx, "mytest")
require.NoError(t, err)
ddbs := sqlDb.(Database).DoltDatabases()
require.Len(t, ddbs, 1)
hooks := doltdb.HackDatasDatabaseFromDoltDB(ddbs[0]).(interface {
PostCommitHooks() []doltdb.CommitHook
}).PostCommitHooks()
require.Len(t, hooks, 2)
_, ok := hooks[0].(*snoopingCommitHook)
assert.True(t, ok, "expect hook to be snoopingCommitHook, it is %T", hooks[0])
_, ok = hooks[1].(*doltdb.AsyncPushOnWriteHook)
assert.True(t, ok, "expect hook to be AsyncPushOnWriteHook, it is %T", hooks[1])
})
})
})
}
type snoopingCommitHook struct {
}
func (*snoopingCommitHook) Execute(ctx context.Context, ds datas.Dataset, db datas.Database) (func(context.Context) error, error) {
return nil, nil
}
func (*snoopingCommitHook) HandleError(ctx context.Context, err error) error {
return nil
}
func (*snoopingCommitHook) SetLogger(ctx context.Context, wr io.Writer) error {
return nil
}
func (*snoopingCommitHook) ExecuteForWorkingSets() bool {
return true
}
func InstallSnoopingCommitHook(ctx *sql.Context, pro *DoltDatabaseProvider, name string, dEnv *env.DoltEnv, db dsess.SqlDatabase) error {
dEnv.DoltDB(ctx).PrependCommitHooks(ctx, &snoopingCommitHook{})
return nil
}
@@ -166,6 +166,8 @@ func restoreBackup(ctx *sql.Context, _ env.DbData, apr *argparser.ArgParseResult
dbName := strings.TrimSpace(apr.Arg(2))
force := apr.Contains(cli.ForceFlag)
sess := dsess.DSessFromSess(ctx.Session)
remoteParams := map[string]string{}
r := env.NewRemote("", backupUrl, remoteParams)
srcDb, err := r.GetRemoteDB(ctx, types.Format_Default, nil)
@@ -173,7 +175,6 @@ func restoreBackup(ctx *sql.Context, _ env.DbData, apr *argparser.ArgParseResult
return err
}
sess := dsess.DSessFromSess(ctx.Session)
existingDbData, restoringExistingDb := sess.GetDbData(ctx, dbName)
if restoringExistingDb {
if !force {
@@ -484,7 +484,7 @@ func (d *DoltHarness) NewReadOnlyEngine(provider sql.DatabaseProvider) (enginete
locations[i] = loc
}
readOnlyProvider, err := sqle.NewDoltDatabaseProviderWithDatabases("main", ddp.FileSystem(), dbs, locations)
readOnlyProvider, err := sqle.NewDoltDatabaseProviderWithDatabases("main", ddp.FileSystem(), dbs, locations, sql.NewBackgroundThreads())
if err != nil {
return nil, err
}
@@ -533,7 +533,7 @@ func (d *DoltHarness) newProvider(ctx context.Context) sql.MutableDatabaseProvid
d.multiRepoEnv = mrEnv
b := env.GetDefaultInitBranch(d.multiRepoEnv.Config())
pro, err := sqle.NewDoltDatabaseProvider(b, d.multiRepoEnv.FileSystem())
pro, err := sqle.NewDoltDatabaseProvider(b, d.multiRepoEnv.FileSystem(), sql.NewBackgroundThreads())
require.NoError(d.t, err)
return pro
@@ -99,7 +99,7 @@ func setupIndexes(t *testing.T, tableName, insertQuery string) (*sqle.Engine, *s
mrEnv, err := env.MultiEnvForDirectory(context.Background(), dEnv.Config.WriteableConfig(), dEnv.FS, dEnv.Version, dEnv)
require.NoError(t, err)
b := env.GetDefaultInitBranch(dEnv.Config)
pro, err := dsqle.NewDoltDatabaseProviderWithDatabase(b, mrEnv.FileSystem(), db, dEnv.FS)
pro, err := dsqle.NewDoltDatabaseProviderWithDatabase(b, mrEnv.FileSystem(), db, dEnv.FS, sql.NewBackgroundThreads())
if err != nil {
return nil, nil, nil
}
@@ -144,7 +144,8 @@ func innerInit(h *DoltHarness, dEnv *env.DoltEnv) error {
return err
}
sqlCtx := dsql.NewTestSQLCtxWithProvider(ctx, pro, statspro.NewProvider(pro.(*dsql.DoltDatabaseProvider), statsnoms.NewNomsStatsFactory(env.NewGRPCDialProviderFromDoltEnv(dEnv))), dsess.NewGCSafepointController())
config, _ := dEnv.Config.GetConfig(env.GlobalConfig)
sqlCtx := dsql.NewTestSQLCtxWithProvider(ctx, pro, config, statspro.NewProvider(pro.(*dsql.DoltDatabaseProvider), statsnoms.NewNomsStatsFactory(env.NewGRPCDialProviderFromDoltEnv(dEnv))), dsess.NewGCSafepointController())
h.sess = sqlCtx.Session.(*dsess.DoltSession)
dbs := h.engine.Analyzer.Catalog.AllDatabases(sqlCtx)
@@ -300,7 +301,7 @@ func sqlNewEngine(ctx context.Context, dEnv *env.DoltEnv) (*sqle.Engine, dsess.D
}
b := env.GetDefaultInitBranch(dEnv.Config)
pro, err := dsql.NewDoltDatabaseProviderWithDatabase(b, mrEnv.FileSystem(), db, dEnv.FS)
pro, err := dsql.NewDoltDatabaseProviderWithDatabase(b, mrEnv.FileSystem(), db, dEnv.FS, sql.NewBackgroundThreads())
if err != nil {
return nil, nil, err
}
+24 -14
View File
@@ -115,6 +115,26 @@ func newReplicaDatabase(ctx context.Context, name string, remoteName string, dEn
return rrd, nil
}
// Converts |db| into a |ReadReplicaDatabase| if read replication is
// configured through sql SystemVariables. This is called both at
// startup, for the entire set of databases, and is called when
// we create new databases through |registerNewDatabases|.
func applyReadReplicationConfigToDatabase(ctx context.Context, dEnv *env.DoltEnv, db dsess.SqlDatabase) (dsess.SqlDatabase, error) {
if _, remote, ok := sql.SystemVariables.GetGlobal(dsess.ReadReplicaRemote); ok && remote != "" {
remoteName, ok := remote.(string)
if !ok {
return nil, sql.ErrInvalidSystemVariableValue.New(remote)
}
rdb, err := newReplicaDatabase(ctx, db.Name(), remoteName, dEnv)
if err == nil {
db = rdb
} else {
logrus.Errorf("invalid replication configuration, replication disabled: %v", err)
}
}
return db, nil
}
func ApplyReplicationConfig(ctx context.Context, bThreads *sql.BackgroundThreads, mrEnv *env.MultiRepoEnv, logger io.Writer, dbs ...dsess.SqlDatabase) ([]dsess.SqlDatabase, error) {
outputDbs := make([]dsess.SqlDatabase, len(dbs))
for i, db := range dbs {
@@ -127,22 +147,12 @@ func ApplyReplicationConfig(ctx context.Context, bThreads *sql.BackgroundThreads
if err != nil {
return nil, err
}
dEnv.DoltDB(ctx).SetCommitHooks(ctx, postCommitHooks)
dEnv.DoltDB(ctx).PrependCommitHooks(ctx, postCommitHooks...)
if _, remote, ok := sql.SystemVariables.GetGlobal(dsess.ReadReplicaRemote); ok && remote != "" {
remoteName, ok := remote.(string)
if !ok {
return nil, sql.ErrInvalidSystemVariableValue.New(remote)
}
rdb, err := newReplicaDatabase(ctx, db.Name(), remoteName, dEnv)
if err == nil {
db = rdb
} else {
logrus.Errorf("invalid replication configuration, replication disabled: %v", err)
}
outputDbs[i], err = applyReadReplicationConfigToDatabase(ctx, dEnv, db)
if err != nil {
return nil, err
}
outputDbs[i] = db
}
return outputDbs, nil
}
@@ -29,7 +29,7 @@ import (
// These functions cannot be in the sqlfmt package as the reliance on the sqle package creates a circular reference.
func PrepareCreateTableStmt(ctx context.Context, sqlDb dsess.SqlDatabase) (*sql.Context, *sqle.Engine, *dsess.DoltSession) {
pro, err := NewDoltDatabaseProviderWithDatabase(env.DefaultInitBranch, nil, sqlDb, nil)
pro, err := NewDoltDatabaseProviderWithDatabase(env.DefaultInitBranch, nil, sqlDb, nil, sql.NewBackgroundThreads())
if err != nil {
return nil, nil, nil
}
+1 -1
View File
@@ -1104,7 +1104,7 @@ func TestParseCreateTableStatement(t *testing.T) {
}
func newTestEngine(ctx context.Context, dEnv *env.DoltEnv) (*gms.Engine, *sql.Context) {
pro, err := NewDoltDatabaseProviderWithDatabases("main", dEnv.FS, nil, nil)
pro, err := NewDoltDatabaseProviderWithDatabases("main", dEnv.FS, nil, nil, sql.NewBackgroundThreads())
if err != nil {
panic(err)
}
+28 -19
View File
@@ -61,10 +61,18 @@ func ExecuteSql(ctx context.Context, dEnv *env.DoltEnv, root doltdb.RootValue, s
return nil, err
}
err = sqlCtx.Session.SetSessionVariable(sqlCtx, sql.AutoCommitSessionVar, false)
err = ExecuteSqlOnEngine(sqlCtx, engine, statements)
if err != nil {
return nil, err
}
return db.GetRoot(sqlCtx)
}
func ExecuteSqlOnEngine(ctx *sql.Context, engine *sqle.Engine, statements string) error {
err := ctx.Session.SetSessionVariable(ctx, sql.AutoCommitSessionVar, false)
if err != nil {
return err
}
for _, query := range strings.Split(statements, ";\n") {
if len(strings.Trim(query, " ")) == 0 {
@@ -73,50 +81,50 @@ func ExecuteSql(ctx context.Context, dEnv *env.DoltEnv, root doltdb.RootValue, s
sqlStatement, err := sqlparser.Parse(query)
if err != nil {
return nil, err
return err
}
var execErr error
switch sqlStatement.(type) {
case *sqlparser.Show:
return nil, errors.New("Show statements aren't handled")
return errors.New("Show statements aren't handled")
case *sqlparser.Select, *sqlparser.OtherRead:
return nil, errors.New("Select statements aren't handled")
return errors.New("Select statements aren't handled")
case *sqlparser.Insert:
var rowIter sql.RowIter
_, rowIter, _, execErr = engine.Query(sqlCtx, query)
_, rowIter, _, execErr = engine.Query(ctx, query)
if execErr == nil {
execErr = drainIter(sqlCtx, rowIter)
execErr = drainIter(ctx, rowIter)
}
case *sqlparser.DDL, *sqlparser.AlterTable:
case *sqlparser.DDL, *sqlparser.AlterTable, *sqlparser.DBDDL:
var rowIter sql.RowIter
_, rowIter, _, execErr = engine.Query(sqlCtx, query)
_, rowIter, _, execErr = engine.Query(ctx, query)
if execErr == nil {
execErr = drainIter(sqlCtx, rowIter)
execErr = drainIter(ctx, rowIter)
}
default:
return nil, fmt.Errorf("Unsupported SQL statement: '%v'.", query)
return fmt.Errorf("Unsupported SQL statement: '%v'.", query)
}
if execErr != nil {
return nil, execErr
return execErr
}
}
// commit leftover transaction
trx := sqlCtx.GetTransaction()
trx := ctx.GetTransaction()
if trx != nil {
err = dsess.DSessFromSess(sqlCtx.Session).CommitTransaction(sqlCtx, trx)
err = dsess.DSessFromSess(ctx.Session).CommitTransaction(ctx, trx)
if err != nil {
return nil, err
return err
}
}
return db.GetRoot(sqlCtx)
return nil
}
func NewTestSQLCtxWithProvider(ctx context.Context, pro dsess.DoltDatabaseProvider, statsPro sql.StatsProvider, gcSafepointController *dsess.GCSafepointController) *sql.Context {
s, err := dsess.NewDoltSession(sql.NewBaseSession(), pro, config.NewMapConfig(make(map[string]string)), branch_control.CreateDefaultController(ctx), statsPro, writer.NewWriteSession, gcSafepointController)
func NewTestSQLCtxWithProvider(ctx context.Context, pro dsess.DoltDatabaseProvider, config config.ReadWriteConfig, statsPro sql.StatsProvider, gcSafepointController *dsess.GCSafepointController) *sql.Context {
s, err := dsess.NewDoltSession(sql.NewBaseSession(), pro, config, branch_control.CreateDefaultController(ctx), statsPro, writer.NewWriteSession, gcSafepointController)
if err != nil {
panic(err)
}
@@ -131,7 +139,7 @@ func NewTestSQLCtxWithProvider(ctx context.Context, pro dsess.DoltDatabaseProvid
// NewTestEngine creates a new default engine, and a *sql.Context and initializes indexes and schema fragments.
func NewTestEngine(dEnv *env.DoltEnv, ctx context.Context, db dsess.SqlDatabase) (*sqle.Engine, *sql.Context, error) {
b := env.GetDefaultInitBranch(dEnv.Config)
pro, err := NewDoltDatabaseProviderWithDatabase(b, dEnv.FS, db, dEnv.FS)
pro, err := NewDoltDatabaseProviderWithDatabase(b, dEnv.FS, db, dEnv.FS, sql.NewBackgroundThreads())
if err != nil {
return nil, nil, err
}
@@ -139,7 +147,8 @@ func NewTestEngine(dEnv *env.DoltEnv, ctx context.Context, db dsess.SqlDatabase)
engine := sqle.NewDefault(pro)
sqlCtx := NewTestSQLCtxWithProvider(ctx, pro, nil, gcSafepointController)
config, _ := dEnv.Config.GetConfig(env.GlobalConfig)
sqlCtx := NewTestSQLCtxWithProvider(ctx, pro, config, nil, gcSafepointController)
sqlCtx.SetCurrentDatabase(db.Name())
return engine, sqlCtx, nil
}