go: sqle: dsess: Make DoltSession Lifecycle aware. Move towards a GCSafepointController which can work with it.

This commit is contained in:
Aaron Son
2025-01-22 10:59:40 -08:00
parent 366e466a7e
commit 915e392b10
9 changed files with 631 additions and 123 deletions
+3
View File
@@ -16,3 +16,6 @@ SysbenchDockerfile.dockerignore
sysbench-runner-tests-entrypoint.sh
config.json
integration-tests/bats/batsee_results
*~
.dir-locals.el
+6 -4
View File
@@ -125,13 +125,15 @@ func NewSqlEngine(
locations = append(locations, nil)
}
gcSafepointController := dsess.NewGCSafepointController()
b := env.GetDefaultInitBranch(mrEnv.Config())
pro, err := dsqle.NewDoltDatabaseProviderWithDatabases(b, mrEnv.FileSystem(), all, locations)
if err != nil {
return nil, err
}
pro = pro.WithRemoteDialer(mrEnv.RemoteDialProvider())
pro.RegisterProcedure(dprocedures.NewDoltGCProcedure())
pro.RegisterProcedure(dprocedures.NewDoltGCProcedure(gcSafepointController))
config.ClusterController.RegisterStoredProcedures(pro)
if config.ClusterController != nil {
@@ -191,7 +193,7 @@ func NewSqlEngine(
engine.Analyzer.Catalog.StatsProvider = statsPro
engine.Analyzer.ExecBuilder = rowexec.NewOverrideBuilder(kvexec.Builder{})
sessFactory := doltSessionFactory(pro, statsPro, mrEnv.Config(), bcController, config.Autocommit)
sessFactory := doltSessionFactory(pro, statsPro, mrEnv.Config(), bcController, gcSafepointController, config.Autocommit)
sqlEngine.provider = pro
sqlEngine.contextFactory = sqlContextFactory()
sqlEngine.dsessFactory = sessFactory
@@ -415,9 +417,9 @@ func sqlContextFactory() contextFactory {
}
// doltSessionFactory returns a sessionFactory that creates a new DoltSession
func doltSessionFactory(pro *dsqle.DoltDatabaseProvider, statsPro sql.StatsProvider, config config.ReadWriteConfig, bc *branch_control.Controller, autocommit bool) sessionFactory {
func doltSessionFactory(pro *dsqle.DoltDatabaseProvider, statsPro sql.StatsProvider, config config.ReadWriteConfig, bc *branch_control.Controller, gcSafepointController *dsess.GCSafepointController, autocommit bool) sessionFactory {
return func(mysqlSess *sql.BaseSession, provider sql.DatabaseProvider) (*dsess.DoltSession, error) {
doltSession, err := dsess.NewDoltSession(mysqlSess, pro, config, bc, statsPro, writer.NewWriteSession)
doltSession, err := dsess.NewDoltSession(mysqlSess, pro, config, bc, statsPro, writer.NewWriteSession, gcSafepointController)
if err != nil {
return nil, err
}
@@ -45,15 +45,17 @@ func init() {
var DoltGCFeatureFlag = true
func NewDoltGCProcedure() sql.ExternalStoredProcedureDetails {
impl := &DoltGCProcedure{}
func NewDoltGCProcedure(gcSafepointController *dsess.GCSafepointController) sql.ExternalStoredProcedureDetails {
impl := &DoltGCProcedure{
gcSafepointController: gcSafepointController,
}
return sql.ExternalStoredProcedureDetails{
Name: "dolt_gc",
Schema: int64Schema("status"),
Function: impl.Run,
ReadOnly: true,
Name: "dolt_gc",
Schema: int64Schema("status"),
Function: impl.Run,
ReadOnly: true,
AdminOnly: true,
}
}
}
// doltGC is the stored procedure to run online garbage collection on a database.
@@ -70,30 +72,88 @@ func (p *DoltGCProcedure) Run(ctx *sql.Context, args ...string) (sql.RowIter, er
var ErrServerPerformedGC = errors.New("this connection was established when this server performed an online garbage collection. this connection can no longer be used. please reconnect.")
type safepointController struct {
begin func(context.Context, func(hash.Hash) bool) error
preFinalize func(context.Context) error
postFinalize func(context.Context) error
cancel func()
// The original behavior safepoint controller, which kills all connections right as the GC process is being finalized.
// The only connection which is left up is the connection on which dolt_gc is called, but that connection is
// invalidated in such a way that all future queries on it return an error.
type killConnectionsSafepointController struct {
callCtx *sql.Context
origEpoch int
}
func (sc safepointController) BeginGC(ctx context.Context, keeper func(hash.Hash) bool) error {
return sc.begin(ctx, keeper)
func (sc killConnectionsSafepointController) BeginGC(ctx context.Context, keeper func(hash.Hash) bool) error {
return nil
}
func (sc safepointController) EstablishPreFinalizeSafepoint(ctx context.Context) error {
return sc.preFinalize(ctx)
func (sc killConnectionsSafepointController) EstablishPreFinalizeSafepoint(ctx context.Context) error {
return nil
}
func (sc safepointController) EstablishPostFinalizeSafepoint(ctx context.Context) error {
return sc.postFinalize(ctx)
func (sc killConnectionsSafepointController) EstablishPostFinalizeSafepoint(ctx context.Context) error {
// Here we need to sanity check role and epoch.
if sc.origEpoch != -1 {
if _, role, ok := sql.SystemVariables.GetGlobal(dsess.DoltClusterRoleVariable); ok {
if role.(string) != "primary" {
return fmt.Errorf("dolt_gc failed: when we began we were a primary in a cluster, but now our role is %s", role.(string))
}
_, epoch, ok := sql.SystemVariables.GetGlobal(dsess.DoltClusterRoleEpochVariable)
if !ok {
return fmt.Errorf("dolt_gc failed: when we began we were a primary in a cluster, but we can no longer read the cluster role epoch.")
}
if sc.origEpoch != epoch.(int) {
return fmt.Errorf("dolt_gc failed: when we began we were primary in the cluster at epoch %d, but now we are at epoch %d. for gc to safely finalize, our role and epoch must not change throughout the gc.", sc.origEpoch, epoch.(int))
}
} else {
return fmt.Errorf("dolt_gc failed: when we began we were a primary in a cluster, but we can no longer read the cluster role.")
}
}
killed := make(map[uint32]struct{})
processes := sc.callCtx.ProcessList.Processes()
for _, p := range processes {
if p.Connection != sc.callCtx.Session.ID() {
// Kill any inflight query.
sc.callCtx.ProcessList.Kill(p.Connection)
// Tear down the connection itself.
sc.callCtx.KillConnection(p.Connection)
killed[p.Connection] = struct{}{}
}
}
// Look in processes until the connections are actually gone.
params := backoff.NewExponentialBackOff()
params.InitialInterval = 1 * time.Millisecond
params.MaxInterval = 25 * time.Millisecond
params.MaxElapsedTime = 3 * time.Second
err := backoff.Retry(func() error {
processes := sc.callCtx.ProcessList.Processes()
allgood := true
for _, p := range processes {
if _, ok := killed[p.Connection]; ok {
allgood = false
sc.callCtx.ProcessList.Kill(p.Connection)
}
}
if !allgood {
return errors.New("unable to establish safepoint.")
}
return nil
}, params)
if err != nil {
return err
}
sc.callCtx.Session.SetTransaction(nil)
dsess.DSessFromSess(sc.callCtx.Session).SetValidateErr(ErrServerPerformedGC)
return nil
}
func (sc safepointController) CancelSafepoint() {
sc.cancel()
func (sc killConnectionsSafepointController) CancelSafepoint() {
}
type DoltGCProcedure struct {
// Used by the implementation to visit existing sessions, find them
// at a quiesced state and ensure that their in-memory state makes
// it to the GC process.
gcSafepointController *dsess.GCSafepointController
}
func (*DoltGCProcedure) doGC(ctx *sql.Context, args []string) (int, error) {
@@ -136,7 +196,6 @@ func (*DoltGCProcedure) doGC(ctx *sql.Context, args []string) (int, error) {
// We assert that we are the primary here before we begin, and
// we assert again that we are the primary at the same epoch as
// we establish the safepoint.
origepoch := -1
if _, role, ok := sql.SystemVariables.GetGlobal(dsess.DoltClusterRoleVariable); ok {
// TODO: magic constant...
@@ -155,71 +214,10 @@ func (*DoltGCProcedure) doGC(ctx *sql.Context, args []string) (int, error) {
mode = types.GCModeFull
}
// TODO: Implement safepointController so that begin can capture inflight sessions
// and preFinalize can ensure they're all in a good place before returning.
sc := safepointController{
begin: func(context.Context, func(hash.Hash) bool) error { return nil },
preFinalize: func(context.Context) error { return nil },
postFinalize: func(context.Context) error {
if origepoch != -1 {
// Here we need to sanity check role and epoch.
if _, role, ok := sql.SystemVariables.GetGlobal(dsess.DoltClusterRoleVariable); ok {
if role.(string) != "primary" {
return fmt.Errorf("dolt_gc failed: when we began we were a primary in a cluster, but now our role is %s", role.(string))
}
_, epoch, ok := sql.SystemVariables.GetGlobal(dsess.DoltClusterRoleEpochVariable)
if !ok {
return fmt.Errorf("dolt_gc failed: when we began we were a primary in a cluster, but we can no longer read the cluster role epoch.")
}
if origepoch != epoch.(int) {
return fmt.Errorf("dolt_gc failed: when we began we were primary in the cluster at epoch %d, but now we are at epoch %d. for gc to safely finalize, our role and epoch must not change throughout the gc.", origepoch, epoch.(int))
}
} else {
return fmt.Errorf("dolt_gc failed: when we began we were a primary in a cluster, but we can no longer read the cluster role.")
}
}
killed := make(map[uint32]struct{})
processes := ctx.ProcessList.Processes()
for _, p := range processes {
if p.Connection != ctx.Session.ID() {
// Kill any inflight query.
ctx.ProcessList.Kill(p.Connection)
// Tear down the connection itself.
ctx.KillConnection(p.Connection)
killed[p.Connection] = struct{}{}
}
}
// Look in processes until the connections are actually gone.
params := backoff.NewExponentialBackOff()
params.InitialInterval = 1 * time.Millisecond
params.MaxInterval = 25 * time.Millisecond
params.MaxElapsedTime = 3 * time.Second
err := backoff.Retry(func() error {
processes := ctx.ProcessList.Processes()
allgood := true
for _, p := range processes {
if _, ok := killed[p.Connection]; ok {
allgood = false
ctx.ProcessList.Kill(p.Connection)
}
}
if !allgood {
return errors.New("unable to establish safepoint.")
}
return nil
}, params)
if err != nil {
return err
}
ctx.Session.SetTransaction(nil)
dsess.DSessFromSess(ctx.Session).SetValidateErr(ErrServerPerformedGC)
return nil
},
cancel: func() {},
sc := killConnectionsSafepointController{
origEpoch: origepoch,
callCtx: ctx,
}
err = ddb.GC(ctx, mode, sc)
if err != nil {
return cmdFailure, err
@@ -0,0 +1,191 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dsess
import (
"context"
"sync"
"sync/atomic"
)
type GCSafepointController struct {
mu sync.Mutex
// All known sessions. The first command registers the session
// here, and SessionEnd causes it to be removed.
sessions map[*DoltSession]*GCSafepointSessionState
}
type GCSafepointSessionState struct {
// True when a command is outstanding on the session,
// false otherwise.
OutstandingCommand bool
// Registered when we create a GCSafepointWaiter, this
// will be called when the session's SessionCommandEnd
// function is hit.
CommandEndCallback func()
// When this channel is non-nil, it means that an
// outstanding visit session call is ongoing for this
// session. The CommandBegin callback will block until
// that call has completed.
QuiesceCallbackDone atomic.Value // chan struct{}
}
type GCSafepointWaiter struct {
controller *GCSafepointController
wg sync.WaitGroup
}
func NewGCSafepointController() *GCSafepointController {
return &GCSafepointController{
sessions: make(map[*DoltSession]*GCSafepointSessionState),
}
}
// The GCSafepointController is keeping track of *DoltSession instances that have ever had work done.
// By pairing up CommandBegin and CommandEnd callbacks, it can identify quiesced sessions--ones that
// are not currently running work. Calling |Waiter| asks the controller to concurrently call
// |visitQuiescedSession| on each known session as soon as it is safe and possible. The returned
// |Waiter| is used to |Wait| for all of those to be completed. A call is not made for |thisSession|,
// since, if that session corresponds to an ongoing SQL procedure call, for example, that session
// will never quiesce.
//
// After creating a Waiter, it is an error to create a new Waiter before the |Wait| method of the
// original watier has returned. This error is not guaranteed to always be detected.
func (c *GCSafepointController) Waiter(thisSession *DoltSession, visitQuiescedSession func(*DoltSession)) *GCSafepointWaiter {
c.mu.Lock()
defer c.mu.Unlock()
ret := &GCSafepointWaiter{controller: c}
for sess, state := range c.sessions {
if state.CommandEndCallback != nil {
panic("Attempt to create more than one GCSafepointWaiter.")
}
if sess == thisSession {
continue
}
if state.OutstandingCommand {
ret.wg.Add(1)
state.CommandEndCallback = func() {
state.QuiesceCallbackDone.Store(make(chan struct{}))
go func() {
visitQuiescedSession(sess)
ret.wg.Done()
toClose := state.QuiesceCallbackDone.Load().(chan struct{})
close(toClose)
}()
}
} else {
ret.wg.Add(1)
state.QuiesceCallbackDone.Store(make(chan struct{}))
go func() {
visitQuiescedSession(sess)
ret.wg.Done()
toClose := state.QuiesceCallbackDone.Load().(chan struct{})
close(toClose)
}()
}
}
return ret
}
func (w *GCSafepointWaiter) Wait(ctx context.Context) error {
done := make(chan struct{})
go func() {
w.wg.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
w.controller.mu.Lock()
for _, state := range w.controller.sessions {
if state.CommandEndCallback != nil {
// Do not visit the session, but do
// count down the WaitGroup so that
// the goroutine above still completes.
w.wg.Done()
state.CommandEndCallback = nil
}
}
w.controller.mu.Unlock()
// Once a session visit callback has started, we
// cannot cancel it. So we wait for all the inflight
// callbacks to be completed here, before returning.
<-done
return context.Cause(ctx)
}
}
var closedCh = make(chan struct{})
func init() {
close(closedCh)
}
func (c *GCSafepointController) SessionCommandBegin(s *DoltSession) error {
c.mu.Lock()
defer c.mu.Unlock()
var state *GCSafepointSessionState
if state = c.sessions[s]; state == nil {
state = &GCSafepointSessionState{}
state.QuiesceCallbackDone.Store(closedCh)
c.sessions[s] = state
}
if state.OutstandingCommand {
panic("SesisonBeginCommand called on a session that already had an outstanding command.")
}
toWait := state.QuiesceCallbackDone.Load().(chan struct{})
select {
case <-toWait:
default:
c.mu.Unlock()
<-toWait
c.mu.Lock()
}
state.OutstandingCommand = true
return nil
}
func (c *GCSafepointController) SessionCommandEnd(s *DoltSession) {
c.mu.Lock()
defer c.mu.Unlock()
state := c.sessions[s]
if state == nil {
panic("SessionCommandEnd called on a session that was not registered")
}
if state.OutstandingCommand != true {
panic("SessionCommandEnd called on a session that did not have an outstanding command.")
}
if state.CommandEndCallback != nil {
state.CommandEndCallback()
state.CommandEndCallback = nil
}
state.OutstandingCommand = false
}
// Because we only register sessions when the BeginCommand, it is technically
// possible to get a SessionEnd callback for a session that was never registered.
// However, if there is a corresponding session, it is certainly an error for
// us to get this callback and have OutstandingCommand == true.
func (c *GCSafepointController) SessionEnd(s *DoltSession) {
c.mu.Lock()
defer c.mu.Unlock()
state := c.sessions[s]
if state != nil && state.OutstandingCommand == true {
panic("SessionEnd called on a session that had an outstanding command.")
}
delete(c.sessions, s)
}
@@ -0,0 +1,287 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dsess
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestGCSafepointController(t *testing.T) {
t.Parallel()
t.Run("SessionEnd", func(t *testing.T) {
t.Parallel()
t.Run("UnknownSession", func(t *testing.T) {
t.Parallel()
controller := NewGCSafepointController()
controller.SessionEnd(&DoltSession{})
})
t.Run("KnownSession", func(t *testing.T) {
t.Parallel()
controller := NewGCSafepointController()
sess := &DoltSession{}
controller.SessionCommandBegin(sess)
controller.SessionCommandEnd(sess)
controller.SessionEnd(sess)
})
t.Run("RunningSession", func(t *testing.T) {
t.Parallel()
controller := NewGCSafepointController()
sess := &DoltSession{}
controller.SessionCommandBegin(sess)
require.Panics(t, func() {
controller.SessionEnd(sess)
})
})
})
t.Run("CommandBegin", func(t *testing.T) {
t.Parallel()
t.Run("RunningSession", func(t *testing.T) {
t.Parallel()
controller := NewGCSafepointController()
sess := &DoltSession{}
controller.SessionCommandBegin(sess)
require.Panics(t, func() {
controller.SessionCommandBegin(sess)
})
})
t.Run("AfterCommandEnd", func(t *testing.T) {
t.Parallel()
controller := NewGCSafepointController()
sess := &DoltSession{}
controller.SessionCommandBegin(sess)
controller.SessionCommandEnd(sess)
controller.SessionCommandBegin(sess)
})
})
t.Run("CommandEnd", func(t *testing.T) {
t.Parallel()
t.Run("NotKnown", func(t *testing.T) {
t.Parallel()
controller := NewGCSafepointController()
sess := &DoltSession{}
require.Panics(t, func() {
controller.SessionCommandEnd(sess)
})
})
t.Run("NotRunning", func(t *testing.T) {
t.Parallel()
controller := NewGCSafepointController()
sess := &DoltSession{}
controller.SessionCommandBegin(sess)
controller.SessionCommandEnd(sess)
require.Panics(t, func() {
controller.SessionCommandEnd(sess)
})
})
})
t.Run("Waiter", func(t *testing.T) {
t.Parallel()
t.Run("Empty", func(t *testing.T) {
t.Parallel()
var nilCh chan struct{}
block := func(*DoltSession) {
<-nilCh
}
controller := NewGCSafepointController()
waiter := controller.Waiter(nil, block)
waiter.Wait(context.Background())
})
t.Run("OnlyThisSession", func(t *testing.T) {
t.Parallel()
var nilCh chan struct{}
block := func(*DoltSession) {
<-nilCh
}
sess := &DoltSession{}
controller := NewGCSafepointController()
controller.SessionCommandBegin(sess)
waiter := controller.Waiter(sess, block)
waiter.Wait(context.Background())
controller.SessionCommandEnd(sess)
controller.SessionEnd(sess)
})
t.Run("OneQuiescedOneNot", func(t *testing.T) {
t.Parallel()
// A test case where one session is known
// but not within a command and another one
// is within a command at the time the
// waiter is created.
quiesced := &DoltSession{}
running := &DoltSession{}
controller := NewGCSafepointController()
controller.SessionCommandBegin(quiesced)
controller.SessionCommandBegin(running)
controller.SessionCommandEnd(quiesced)
sawQuiesced, sawRunning, waitDone := make(chan struct{}), make(chan struct{}), make(chan struct{})
wait := func(s *DoltSession) {
if s == quiesced {
close(sawQuiesced)
} else if s == running {
close(sawRunning)
} else {
panic("saw unexpected session")
}
}
waiter := controller.Waiter(nil, wait)
go func() {
waiter.Wait(context.Background())
close(waitDone)
}()
<-sawQuiesced
select {
case <-sawRunning:
require.FailNow(t, "unexpected saw running session on callback before it was quiesced")
case <-time.After(50 * time.Millisecond):
}
controller.SessionCommandEnd(running)
<-sawRunning
<-waitDone
controller.SessionCommandBegin(quiesced)
controller.SessionCommandBegin(running)
controller.SessionCommandEnd(quiesced)
controller.SessionCommandEnd(running)
})
t.Run("OneQuiescedOneNotCanceledContext", func(t *testing.T) {
t.Parallel()
// When the Wait context is canceled, we do not block on
// the running sessions and they never get visited.
quiesced := &DoltSession{}
running := &DoltSession{}
controller := NewGCSafepointController()
controller.SessionCommandBegin(quiesced)
controller.SessionCommandBegin(running)
controller.SessionCommandEnd(quiesced)
sawQuiesced, sawRunning, waitDone := make(chan struct{}), make(chan struct{}), make(chan struct{})
wait := func(s *DoltSession) {
if s == quiesced {
close(sawQuiesced)
} else if s == running {
close(sawRunning)
} else {
panic("saw unexpected session")
}
}
waiter := controller.Waiter(nil, wait)
var waitErr error
go func() {
ctx, cancel := context.WithCancel(context.Background())
cancel()
waitErr = waiter.Wait(ctx)
close(waitDone)
}()
<-sawQuiesced
<-waitDone
require.Error(t, waitErr)
select {
case <-sawRunning:
require.FailNow(t, "unexpected saw running session on callback before it was quiesced")
case <-time.After(50 * time.Millisecond):
}
controller.SessionCommandEnd(running)
select {
case <-sawRunning:
require.FailNow(t, "unexpected saw running session on callback before it was quiesced")
case <-time.After(50 * time.Millisecond):
}
controller.SessionCommandBegin(quiesced)
controller.SessionCommandBegin(running)
controller.SessionCommandEnd(quiesced)
controller.SessionCommandEnd(running)
})
t.Run("BeginBlocksUntilVisitFinished", func(t *testing.T) {
t.Parallel()
quiesced := &DoltSession{}
running := &DoltSession{}
controller := NewGCSafepointController()
controller.SessionCommandBegin(quiesced)
controller.SessionCommandEnd(quiesced)
controller.SessionCommandBegin(running)
finishQuiesced, finishRunning := make(chan struct{}), make(chan struct{})
sawQuiesced, sawRunning := make(chan struct{}), make(chan struct{})
wait := func(s *DoltSession) {
if s == quiesced {
close(sawQuiesced)
<-finishQuiesced
} else if s == running {
close(sawRunning)
<-finishRunning
} else {
panic("saw unexpected session")
}
}
waiter := controller.Waiter(nil, wait)
waitDone := make(chan struct{})
go func() {
waiter.Wait(context.Background())
close(waitDone)
}()
beginDone := make(chan struct{})
go func() {
controller.SessionCommandBegin(quiesced)
close(beginDone)
}()
<-sawQuiesced
select {
case <-beginDone:
require.FailNow(t, "unexpected beginDone")
case <-time.After(50 * time.Millisecond):
}
newSession := &DoltSession{}
controller.SessionCommandBegin(newSession)
controller.SessionCommandEnd(newSession)
controller.SessionEnd(newSession)
close(finishQuiesced)
<-beginDone
beginDone = make(chan struct{})
go func() {
controller.SessionCommandEnd(running)
<-sawRunning
controller.SessionCommandBegin(running)
close(beginDone)
}()
select {
case <-beginDone:
require.FailNow(t, "unexpected beginDone")
case <-time.After(50 * time.Millisecond):
}
close(finishRunning)
<-beginDone
<-waitDone
controller.SessionCommandEnd(quiesced)
controller.SessionCommandEnd(running)
controller.SessionCommandBegin(quiesced)
controller.SessionCommandBegin(running)
controller.SessionCommandEnd(quiesced)
controller.SessionCommandEnd(running)
controller.SessionEnd(quiesced)
controller.SessionEnd(running)
err := controller.Waiter(nil, func(*DoltSession) {
panic("unexpected registered session")
}).Wait(context.Background())
require.NoError(t, err)
})
})
}
+50 -26
View File
@@ -50,19 +50,20 @@ var ErrSessionNotPersistable = errors.New("session is not persistable")
// DoltSession is the sql.Session implementation used by dolt. It is accessible through a *sql.Context instance
type DoltSession struct {
sql.Session
DoltgresSessObj any // This is used by Doltgres to persist objects in the session. This is not used by Dolt.
username string
email string
dbStates map[string]*DatabaseSessionState
dbCache *DatabaseCache
provider DoltDatabaseProvider
tempTables map[string][]sql.Table
globalsConf config.ReadWriteConfig
branchController *branch_control.Controller
statsProv sql.StatsProvider
mu *sync.Mutex
fs filesys.Filesys
writeSessProv WriteSessFunc
DoltgresSessObj any // This is used by Doltgres to persist objects in the session. This is not used by Dolt.
username string
email string
dbStates map[string]*DatabaseSessionState
dbCache *DatabaseCache
provider DoltDatabaseProvider
tempTables map[string][]sql.Table
globalsConf config.ReadWriteConfig
branchController *branch_control.Controller
statsProv sql.StatsProvider
mu *sync.Mutex
fs filesys.Filesys
writeSessProv WriteSessFunc
gcSafepointController *GCSafepointController
// If non-nil, this will be returned from ValidateSession.
// Used by sqle/cluster to put a session into a terminal err state.
@@ -100,25 +101,27 @@ func NewDoltSession(
branchController *branch_control.Controller,
statsProvider sql.StatsProvider,
writeSessProv WriteSessFunc,
gcSafepointController *GCSafepointController,
) (*DoltSession, error) {
username := conf.GetStringOrDefault(config.UserNameKey, "")
email := conf.GetStringOrDefault(config.UserEmailKey, "")
globals := config.NewPrefixConfig(conf, env.SqlServerGlobalsPrefix)
sess := &DoltSession{
Session: sqlSess,
username: username,
email: email,
dbStates: make(map[string]*DatabaseSessionState),
dbCache: newDatabaseCache(),
provider: pro,
tempTables: make(map[string][]sql.Table),
globalsConf: globals,
branchController: branchController,
statsProv: statsProvider,
mu: &sync.Mutex{},
fs: pro.FileSystem(),
writeSessProv: writeSessProv,
Session: sqlSess,
username: username,
email: email,
dbStates: make(map[string]*DatabaseSessionState),
dbCache: newDatabaseCache(),
provider: pro,
tempTables: make(map[string][]sql.Table),
globalsConf: globals,
branchController: branchController,
statsProv: statsProvider,
mu: &sync.Mutex{},
fs: pro.FileSystem(),
writeSessProv: writeSessProv,
gcSafepointController: gcSafepointController,
}
return sess, nil
@@ -1628,6 +1631,27 @@ func (d *DoltSession) GetController() *branch_control.Controller {
return d.branchController
}
// Implement sql.LifecycleAwareSession, allowing for GC safepoints to be aware of
// outstanding SQL operations.
func (d *DoltSession) CommandBegin() error {
if d.gcSafepointController != nil {
return d.gcSafepointController.SessionCommandBegin(d)
}
return nil
}
func (d *DoltSession) CommandEnd() {
if d.gcSafepointController != nil {
d.gcSafepointController.SessionCommandEnd(d)
}
}
func (d *DoltSession) SessionEnd() {
if d.gcSafepointController != nil {
d.gcSafepointController.SessionEnd(d)
}
}
// validatePersistedSysVar checks whether a system variable exists and is dynamic
func validatePersistableSysVar(name string) (sql.SystemVariable, interface{}, error) {
sysVar, val, ok := sql.SystemVariables.GetGlobal(name)
@@ -34,6 +34,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/dtestutils"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dprocedures"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/kvexec"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/statsnoms"
@@ -246,7 +247,7 @@ func (d *DoltHarness) NewEngine(t *testing.T) (enginetest.QueryEngine, error) {
d.statsPro = statsProv
var err error
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), d.provider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession)
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), d.provider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession, nil)
require.NoError(t, err)
e, err := enginetest.NewEngine(t, d, d.provider, d.setupData, d.statsPro)
@@ -273,7 +274,7 @@ func (d *DoltHarness) NewEngine(t *testing.T) (enginetest.QueryEngine, error) {
}
// Get a fresh session after running setup scripts, since some setup scripts can change the session state
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), d.provider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession)
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), d.provider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession, nil)
require.NoError(t, err)
}
@@ -314,7 +315,7 @@ func (d *DoltHarness) NewEngine(t *testing.T) (enginetest.QueryEngine, error) {
e, err := enginetest.RunSetupScripts(ctx, d.engine, d.resetScripts(), d.SupportsNativeIndexCreation())
// Get a fresh session after running setup scripts, since some setup scripts can change the session state
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), d.provider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession)
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), d.provider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession, nil)
require.NoError(t, err)
return e, err
@@ -396,7 +397,7 @@ func (d *DoltHarness) newSessionWithClient(client sql.Client) *dsess.DoltSession
localConfig := d.multiRepoEnv.Config()
pro := d.session.Provider()
dSession, err := dsess.NewDoltSession(sql.NewBaseSessionWithClientServer("address", client, 1), pro.(dsess.DoltDatabaseProvider), localConfig, d.branchControl, d.statsPro, writer.NewWriteSession)
dSession, err := dsess.NewDoltSession(sql.NewBaseSessionWithClientServer("address", client, 1), pro.(dsess.DoltDatabaseProvider), localConfig, d.branchControl, d.statsPro, writer.NewWriteSession, nil)
dSession.SetCurrentDatabase("mydb")
require.NoError(d.t, err)
return dSession
@@ -428,7 +429,7 @@ func (d *DoltHarness) NewDatabases(names ...string) []sql.Database {
d.statsPro = statspro.NewProvider(doltProvider, statsnoms.NewNomsStatsFactory(d.multiRepoEnv.RemoteDialProvider()))
var err error
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), doltProvider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession)
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), doltProvider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession, nil)
require.NoError(d.t, err)
// TODO: the engine tests should do this for us
@@ -486,7 +487,7 @@ func (d *DoltHarness) NewReadOnlyEngine(provider sql.DatabaseProvider) (enginete
}
// reset the session as well since we have swapped out the database provider, which invalidates caching assumptions
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), readOnlyProvider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession)
d.session, err = dsess.NewDoltSession(enginetest.NewBaseSession(), readOnlyProvider, d.multiRepoEnv.Config(), d.branchControl, d.statsPro, writer.NewWriteSession, nil)
require.NoError(d.t, err)
return enginetest.NewEngineWithProvider(nil, d, readOnlyProvider), nil
@@ -531,6 +532,8 @@ func (d *DoltHarness) newProvider() sql.MutableDatabaseProvider {
b := env.GetDefaultInitBranch(d.multiRepoEnv.Config())
pro, err := sqle.NewDoltDatabaseProvider(b, d.multiRepoEnv.FileSystem())
require.NoError(d.t, err)
gcSafepointController := dsess.NewGCSafepointController()
pro.Register(dprocedures.NewDoltGCProcedure(gcSafepointController))
return pro
}
+1 -1
View File
@@ -1113,7 +1113,7 @@ func newTestEngine(ctx context.Context, dEnv *env.DoltEnv) (*gms.Engine, *sql.Co
panic(err)
}
doltSession, err := dsess.NewDoltSession(sql.NewBaseSession(), pro, dEnv.Config.WriteableConfig(), nil, nil, writer.NewWriteSession)
doltSession, err := dsess.NewDoltSession(sql.NewBaseSession(), pro, dEnv.Config.WriteableConfig(), nil, nil, writer.NewWriteSession, nil)
if err != nil {
panic(err)
}
+1 -1
View File
@@ -116,7 +116,7 @@ func ExecuteSql(dEnv *env.DoltEnv, root doltdb.RootValue, statements string) (do
}
func NewTestSQLCtxWithProvider(ctx context.Context, pro dsess.DoltDatabaseProvider, statsPro sql.StatsProvider) *sql.Context {
s, err := dsess.NewDoltSession(sql.NewBaseSession(), pro, config.NewMapConfig(make(map[string]string)), branch_control.CreateDefaultController(ctx), statsPro, writer.NewWriteSession)
s, err := dsess.NewDoltSession(sql.NewBaseSession(), pro, config.NewMapConfig(make(map[string]string)), branch_control.CreateDefaultController(ctx), statsPro, writer.NewWriteSession, nil)
if err != nil {
panic(err)
}