cmd/dolt/commands/sqlserver: Restructure the start up sequence for sql-server.

We explicitly model Services, which can have an Init step, a Run step and a
Stop step. Every registered service get initialized in the order they were
registered in, then they all run concurrently until Stop is called, when they
all get Stopped in reverse order. It's possible for clients to wait for init to
be completed and be delivered any errors encountered on startup. They can also
wait for stop, to be delivered any errors encountered on shutdown.
This commit is contained in:
Aaron Son
2023-11-14 16:30:55 -08:00
parent 0cc42b8440
commit 23dc3ed014
10 changed files with 987 additions and 454 deletions

View File

@@ -48,6 +48,7 @@ import (
_ "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dfunctions"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqlserver"
"github.com/dolthub/dolt/go/libraries/utils/svcs"
)
const (
@@ -63,335 +64,470 @@ func Serve(
ctx context.Context,
version string,
serverConfig ServerConfig,
serverController *ServerController,
controller *svcs.Controller,
dEnv *env.DoltEnv,
) (startError error, closeError error) {
// Code is easier to work through if we assume that serverController is never nil
if serverController == nil {
serverController = NewServerController()
if controller == nil {
controller = svcs.NewController()
}
var mySQLServer *server.Server
// This guarantees unblocking on any routines with a waiting `ServerController`
defer func() {
if mySQLServer != nil {
serverController.registerCloseFunction(startError, mySQLServer.Close)
} else {
serverController.registerCloseFunction(startError, func() error { return nil })
}
serverController.StopServer()
serverController.serverStopped(closeError)
sqlserver.UnsetRunningServer()
}()
if startError = ValidateConfig(serverConfig); startError != nil {
return startError, nil
ValidateConfigStep := &svcs.Service{
Init: func(context.Context) error {
return ValidateConfig(serverConfig)
},
}
controller.Register(ValidateConfigStep)
lgr := logrus.StandardLogger()
lgr.SetOutput(cli.CliErr)
InitLogging := &svcs.Service{
Init: func(context.Context) error {
level, err := logrus.ParseLevel(serverConfig.LogLevel().String())
if err != nil {
return err
}
logrus.SetLevel(level)
if serverConfig.LogLevel() != LogLevel_Info {
var level logrus.Level
level, startError = logrus.ParseLevel(serverConfig.LogLevel().String())
if startError != nil {
cli.PrintErr(startError)
return
}
logrus.SetLevel(level)
}
logrus.SetFormatter(LogFormat{})
sql.SystemVariables.AddSystemVariables([]sql.SystemVariable{
{
Name: dsess.DoltLogLevel,
Scope: sql.SystemVariableScope_Global,
Dynamic: true,
SetVarHintApplies: false,
Type: types.NewSystemEnumType(dsess.DoltLogLevel,
logrus.PanicLevel.String(),
logrus.FatalLevel.String(),
logrus.ErrorLevel.String(),
logrus.WarnLevel.String(),
logrus.InfoLevel.String(),
logrus.DebugLevel.String(),
logrus.TraceLevel.String(),
),
Default: logrus.GetLevel().String(),
NotifyChanged: func(scope sql.SystemVariableScope, v sql.SystemVarValue) error {
level, err := logrus.ParseLevel(v.Val.(string))
if err != nil {
return fmt.Errorf("could not parse requested log level %s as a log level. dolt_log_level variable value and logging behavior will diverge.", v.Val.(string))
}
sql.SystemVariables.AddSystemVariables([]sql.SystemVariable{
{
Name: dsess.DoltLogLevel,
Scope: sql.SystemVariableScope_Global,
Dynamic: true,
SetVarHintApplies: false,
Type: types.NewSystemEnumType(dsess.DoltLogLevel,
logrus.PanicLevel.String(),
logrus.FatalLevel.String(),
logrus.ErrorLevel.String(),
logrus.WarnLevel.String(),
logrus.InfoLevel.String(),
logrus.DebugLevel.String(),
logrus.TraceLevel.String(),
),
Default: logrus.GetLevel().String(),
NotifyChanged: func(scope sql.SystemVariableScope, v sql.SystemVarValue) error {
level, err := logrus.ParseLevel(v.Val.(string))
if err != nil {
return fmt.Errorf("could not parse requested log level %s as a log level. dolt_log_level variable value and logging behavior will diverge.", v.Val.(string))
}
logrus.SetLevel(level)
return nil
},
logrus.SetLevel(level)
return nil
},
},
})
return nil
},
})
}
controller.Register(InitLogging)
fs := dEnv.FS
InitDataDir := &svcs.Service{
Init: func(context.Context) error {
if len(serverConfig.DataDir()) > 0 && serverConfig.DataDir() != "." {
var err error
fs, err = dEnv.FS.WithWorkingDir(serverConfig.DataDir())
if err != nil {
return err
}
dEnv.FS = fs
}
return nil
},
}
controller.Register(InitDataDir)
var serverLock *env.DBLock
InitGlobalServerLock := &svcs.Service{
Init: func(context.Context) error {
var err error
serverLock, err = acquireGlobalSqlServerLock(serverConfig.Port(), dEnv)
return err
},
Stop: func() error {
dEnv.FS.Delete(dEnv.LockFile(), false)
return nil
},
}
controller.Register(InitGlobalServerLock)
var mrEnv *env.MultiRepoEnv
var err error
fs := dEnv.FS
if len(serverConfig.DataDir()) > 0 && serverConfig.DataDir() != "." {
fs, err = dEnv.FS.WithWorkingDir(serverConfig.DataDir())
if err != nil {
return err, nil
}
dEnv.FS = fs
InitMultiEnv := &svcs.Service{
Init: func(ctx context.Context) error {
var err error
mrEnv, err = env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), fs, dEnv.Version, dEnv.IgnoreLockFile, dEnv)
return err
},
}
controller.Register(InitMultiEnv)
serverLock, startError := acquireGlobalSqlServerLock(serverConfig.Port(), dEnv)
if startError != nil {
return
var clusterController *cluster.Controller
InitClusterController := &svcs.Service{
Init: func(context.Context) error {
var err error
clusterController, err = cluster.NewController(lgr, serverConfig.ClusterConfig(), mrEnv.Config())
return err
},
}
defer dEnv.FS.Delete(dEnv.LockFile(), false)
controller.Register(InitClusterController)
mrEnv, err = env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), fs, dEnv.Version, dEnv.IgnoreLockFile, dEnv)
if err != nil {
return err, nil
}
clusterController, err := cluster.NewController(lgr, serverConfig.ClusterConfig(), mrEnv.Config())
if err != nil {
return err, nil
}
serverConf, sErr, cErr := getConfigFromServerConfig(serverConfig)
if cErr != nil {
return nil, cErr
} else if sErr != nil {
return sErr, nil
var serverConf server.Config
LoadServerConfig := &svcs.Service{
Init: func(context.Context) error {
var err error
serverConf, err = getConfigFromServerConfig(serverConfig)
return err
},
}
controller.Register(LoadServerConfig)
// Create SQL Engine with users
config := &engine.SqlEngineConfig{
IsReadOnly: serverConfig.ReadOnly(),
PrivFilePath: serverConfig.PrivilegeFilePath(),
BranchCtrlFilePath: serverConfig.BranchControlFilePath(),
DoltCfgDirPath: serverConfig.CfgDir(),
ServerUser: serverConfig.User(),
ServerPass: serverConfig.Password(),
ServerHost: serverConfig.Host(),
Autocommit: serverConfig.AutoCommit(),
DoltTransactionCommit: serverConfig.DoltTransactionCommit(),
JwksConfig: serverConfig.JwksConfig(),
SystemVariables: serverConfig.SystemVars(),
ClusterController: clusterController,
BinlogReplicaController: binlogreplication.DoltBinlogReplicaController,
}
esStatus, err := getEventSchedulerStatus(serverConfig.EventSchedulerStatus())
if err != nil {
return err, nil
}
config.EventSchedulerStatus = esStatus
sqlEngine, err := engine.NewSqlEngine(
ctx,
mrEnv,
config,
)
if err != nil {
return err, nil
var config *engine.SqlEngineConfig
InitSqlEngineConfig := &svcs.Service{
Init: func(context.Context) error {
config = &engine.SqlEngineConfig{
IsReadOnly: serverConfig.ReadOnly(),
PrivFilePath: serverConfig.PrivilegeFilePath(),
BranchCtrlFilePath: serverConfig.BranchControlFilePath(),
DoltCfgDirPath: serverConfig.CfgDir(),
ServerUser: serverConfig.User(),
ServerPass: serverConfig.Password(),
ServerHost: serverConfig.Host(),
Autocommit: serverConfig.AutoCommit(),
DoltTransactionCommit: serverConfig.DoltTransactionCommit(),
JwksConfig: serverConfig.JwksConfig(),
SystemVariables: serverConfig.SystemVars(),
ClusterController: clusterController,
BinlogReplicaController: binlogreplication.DoltBinlogReplicaController,
}
return nil
},
}
defer sqlEngine.Close()
controller.Register(InitSqlEngineConfig)
var esStatus eventscheduler.SchedulerStatus
InitEventSchedulerStatus := &svcs.Service{
Init: func(context.Context) error {
var err error
esStatus, err = getEventSchedulerStatus(serverConfig.EventSchedulerStatus())
if err != nil {
return err
}
config.EventSchedulerStatus = esStatus
return nil
},
}
controller.Register(InitEventSchedulerStatus)
var sqlEngine *engine.SqlEngine
InitSqlEngine := &svcs.Service{
Init: func(ctx context.Context) error {
var err error
sqlEngine, err = engine.NewSqlEngine(
ctx,
mrEnv,
config,
)
return err
},
Stop: func() error {
sqlEngine.Close()
return nil
},
}
controller.Register(InitSqlEngine)
// Add superuser if specified user exists; add root superuser if no user specified and no existing privileges
userSpecified := config.ServerUser != ""
var userSpecified bool
mysqlDb := sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb
ed := mysqlDb.Editor()
var numUsers int
ed.VisitUsers(func(*mysql_db.User) { numUsers += 1 })
privsExist := numUsers != 0
if userSpecified {
superuser := mysqlDb.GetUser(ed, config.ServerUser, "%", false)
if userSpecified && superuser == nil {
mysqlDb.AddSuperUser(ed, config.ServerUser, "%", config.ServerPass)
}
} else if !privsExist {
mysqlDb.AddSuperUser(ed, defaultUser, "%", defaultPass)
InitSuperUser := &svcs.Service{
Init: func(context.Context) error {
userSpecified = config.ServerUser != ""
mysqlDb := sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb
ed := mysqlDb.Editor()
var numUsers int
ed.VisitUsers(func(*mysql_db.User) { numUsers += 1 })
privsExist := numUsers != 0
if userSpecified {
superuser := mysqlDb.GetUser(ed, config.ServerUser, "%", false)
if userSpecified && superuser == nil {
mysqlDb.AddSuperUser(ed, config.ServerUser, "%", config.ServerPass)
}
} else if !privsExist {
mysqlDb.AddSuperUser(ed, defaultUser, "%", defaultPass)
}
ed.Close()
return nil
},
}
ed.Close()
labels := serverConfig.MetricsLabels()
controller.Register(InitSuperUser)
var listener *metricsListener
listener, startError = newMetricsListener(labels, version, clusterController)
if startError != nil {
cli.Println(startError)
return
InitMetricsListener := &svcs.Service{
Init: func(context.Context) error {
labels := serverConfig.MetricsLabels()
var err error
listener, err = newMetricsListener(labels, version, clusterController)
return err
},
Stop: func() error {
listener.Close()
return nil
},
}
defer listener.Close()
controller.Register(InitMetricsListener)
v, ok := serverConfig.(validatingServerConfig)
if ok && v.goldenMysqlConnectionString() != "" {
mySQLServer, startError = server.NewValidatingServer(
serverConf,
sqlEngine.GetUnderlyingEngine(),
newSessionBuilder(sqlEngine, serverConfig),
listener,
v.goldenMysqlConnectionString(),
)
} else {
mySQLServer, startError = server.NewServer(
serverConf,
sqlEngine.GetUnderlyingEngine(),
newSessionBuilder(sqlEngine, serverConfig),
listener,
)
var mySQLServer *server.Server
InitSQLServer := &svcs.Service{
Init: func(context.Context) error {
var err error
v, ok := serverConfig.(validatingServerConfig)
if ok && v.goldenMysqlConnectionString() != "" {
mySQLServer, err = server.NewValidatingServer(
serverConf,
sqlEngine.GetUnderlyingEngine(),
newSessionBuilder(sqlEngine, serverConfig),
listener,
v.goldenMysqlConnectionString(),
)
} else {
mySQLServer, err = server.NewServer(
serverConf,
sqlEngine.GetUnderlyingEngine(),
newSessionBuilder(sqlEngine, serverConfig),
listener,
)
}
if errors.Is(err, server.UnixSocketInUseError) {
lgr.Warn("unix socket set up failed: file already in use: ", serverConf.Socket)
err = nil
}
return err
},
}
controller.Register(InitSQLServer)
if errors.Is(startError, server.UnixSocketInUseError) {
lgr.Warn("unix socket set up failed: file already in use: ", serverConf.Socket)
startError = nil
} else if startError != nil {
cli.PrintErr(startError)
return
LockMultiRepoEnv := &svcs.Service{
Init: func(context.Context) error {
if ok, f := mrEnv.IsLocked(); ok {
return env.ErrActiveServerLock.New(f)
}
if err := mrEnv.Lock(serverLock); err != nil {
return err
}
return nil
},
Stop: func() error {
if err := mrEnv.Unlock(); err != nil {
cli.PrintErr(err)
}
return nil
},
}
controller.Register(LockMultiRepoEnv)
sqlserver.SetRunningServer(mySQLServer, serverLock)
ed = mysqlDb.Editor()
mysqlDb.AddSuperUser(ed, LocalConnectionUser, "localhost", serverLock.Secret)
ed.Close()
if ExternalDisableUsers {
mysqlDb.SetEnabled(false)
InitLockSuperUser := &svcs.Service{
Init: func(context.Context) error {
mysqlDb := sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb
ed := mysqlDb.Editor()
mysqlDb.AddSuperUser(ed, LocalConnectionUser, "localhost", serverLock.Secret)
ed.Close()
return nil
},
}
controller.Register(InitLockSuperUser)
DisableMySQLDbIfRequired := &svcs.Service{
Init: func(context.Context) error {
if ExternalDisableUsers {
mysqlDb := sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb
mysqlDb.SetEnabled(false)
}
return nil
},
}
controller.Register(DisableMySQLDbIfRequired)
var metSrv *http.Server
if serverConfig.MetricsHost() != "" && serverConfig.MetricsPort() > 0 {
mux := http.NewServeMux()
mux.Handle("/metrics", promhttp.Handler())
RunMetricsServer := &svcs.Service{
Run: func(context.Context) {
if serverConfig.MetricsHost() != "" && serverConfig.MetricsPort() > 0 {
mux := http.NewServeMux()
mux.Handle("/metrics", promhttp.Handler())
metSrv = &http.Server{
Addr: fmt.Sprintf("%s:%d", serverConfig.MetricsHost(), serverConfig.MetricsPort()),
Handler: mux,
}
metSrv = &http.Server{
Addr: fmt.Sprintf("%s:%d", serverConfig.MetricsHost(), serverConfig.MetricsPort()),
Handler: mux,
}
go func() {
_ = metSrv.ListenAndServe()
}()
_ = metSrv.ListenAndServe()
}
},
Stop: func() error {
if metSrv != nil {
metSrv.Close()
}
return nil
},
}
controller.Register(RunMetricsServer)
var remoteSrv *remotesrv.Server
if serverConfig.RemotesapiPort() != nil {
port := *serverConfig.RemotesapiPort()
listenaddr := fmt.Sprintf(":%d", port)
args, err := sqle.RemoteSrvServerArgs(sqlEngine.NewDefaultContext, remotesrv.ServerArgs{
Logger: logrus.NewEntry(lgr),
ReadOnly: true,
HttpListenAddr: listenaddr,
GrpcListenAddr: listenaddr,
})
if err != nil {
lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err)
startError = err
return
}
var remoteSrvListeners remotesrv.Listeners
RunRemoteSrv := &svcs.Service{
Init: func(ctx context.Context) error {
if serverConfig.RemotesapiPort() == nil {
return nil
}
ctxFactory := func() (*sql.Context, error) { return sqlEngine.NewDefaultContext(ctx) }
authenticator := newAuthenticator(ctxFactory, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb)
args = sqle.WithUserPasswordAuth(args, authenticator)
port := *serverConfig.RemotesapiPort()
listenaddr := fmt.Sprintf(":%d", port)
args, err := sqle.RemoteSrvServerArgs(sqlEngine.NewDefaultContext, remotesrv.ServerArgs{
Logger: logrus.NewEntry(lgr),
ReadOnly: true,
HttpListenAddr: listenaddr,
GrpcListenAddr: listenaddr,
})
if err != nil {
lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err)
return err
}
args.TLSConfig = serverConf.TLSConfig
remoteSrv, err = remotesrv.NewServer(args)
if err != nil {
lgr.Errorf("error creating remotesapi server on port %d: %v", port, err)
startError = err
return
}
listeners, err := remoteSrv.Listeners()
if err != nil {
lgr.Errorf("error starting remotesapi server listeners on port %d: %v", port, err)
startError = err
return
} else {
go remoteSrv.Serve(listeners)
}
ctxFactory := func() (*sql.Context, error) { return sqlEngine.NewDefaultContext(ctx) }
authenticator := newAuthenticator(ctxFactory, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb)
args = sqle.WithUserPasswordAuth(args, authenticator)
args.TLSConfig = serverConf.TLSConfig
remoteSrv, err = remotesrv.NewServer(args)
if err != nil {
lgr.Errorf("error creating remotesapi server on port %d: %v", port, err)
return err
}
remoteSrvListeners, err = remoteSrv.Listeners()
if err != nil {
lgr.Errorf("error starting remotesapi server listeners on port %d: %v", port, err)
return err
}
return nil
},
Run: func(ctx context.Context) {
if remoteSrv == nil {
return
}
remoteSrv.Serve(remoteSrvListeners)
},
Stop: func() error {
if remoteSrv == nil {
return nil
}
remoteSrv.GracefulStop()
return nil
},
}
controller.Register(RunRemoteSrv)
var clusterRemoteSrv *remotesrv.Server
if clusterController != nil {
args, err := clusterController.RemoteSrvServerArgs(sqlEngine.NewDefaultContext, remotesrv.ServerArgs{
Logger: logrus.NewEntry(lgr),
})
if err != nil {
lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err)
startError = err
return
}
var clusterRemoteSrvListeners remotesrv.Listeners
RunClusterRemoteSrv := &svcs.Service{
Init: func(context.Context) error {
if clusterController == nil {
return nil
}
clusterRemoteSrvTLSConfig, err := LoadClusterTLSConfig(serverConfig.ClusterConfig())
if err != nil {
lgr.Errorf("error starting remotesapi server for cluster config, could not load tls config: %v", err)
startError = err
return
}
args.TLSConfig = clusterRemoteSrvTLSConfig
args, err := clusterController.RemoteSrvServerArgs(sqlEngine.NewDefaultContext, remotesrv.ServerArgs{
Logger: logrus.NewEntry(lgr),
})
if err != nil {
lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err)
return err
}
clusterRemoteSrv, err = remotesrv.NewServer(args)
if err != nil {
lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err)
startError = err
return
}
clusterController.RegisterGrpcServices(sqlEngine.NewDefaultContext, clusterRemoteSrv.GrpcServer())
clusterRemoteSrvTLSConfig, err := LoadClusterTLSConfig(serverConfig.ClusterConfig())
if err != nil {
lgr.Errorf("error starting remotesapi server for cluster config, could not load tls config: %v", err)
return err
}
args.TLSConfig = clusterRemoteSrvTLSConfig
listeners, err := clusterRemoteSrv.Listeners()
if err != nil {
lgr.Errorf("error starting remotesapi server listeners for cluster config on %s: %v", clusterController.RemoteSrvListenAddr(), err)
startError = err
return
}
clusterRemoteSrv, err = remotesrv.NewServer(args)
if err != nil {
lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err)
return err
}
clusterController.RegisterGrpcServices(sqlEngine.NewDefaultContext, clusterRemoteSrv.GrpcServer())
go clusterRemoteSrv.Serve(listeners)
go clusterController.Run()
clusterController.ManageQueryConnections(
mySQLServer.SessionManager().Iter,
sqlEngine.GetUnderlyingEngine().ProcessList.Kill,
mySQLServer.SessionManager().KillConnection,
)
}
if ok, f := mrEnv.IsLocked(); ok {
startError = env.ErrActiveServerLock.New(f)
return
}
if err = mrEnv.Lock(serverLock); err != nil {
startError = err
return
}
serverController.registerCloseFunction(startError, func() error {
if metSrv != nil {
metSrv.Close()
}
if remoteSrv != nil {
remoteSrv.GracefulStop()
}
if clusterRemoteSrv != nil {
clusterRemoteSrvListeners, err = clusterRemoteSrv.Listeners()
if err != nil {
lgr.Errorf("error starting remotesapi server listeners for cluster config on %s: %v", clusterController.RemoteSrvListenAddr(), err)
return err
}
return nil
},
Run: func(context.Context) {
if clusterRemoteSrv == nil {
return
}
clusterRemoteSrv.Serve(clusterRemoteSrvListeners)
},
Stop: func() error {
if clusterRemoteSrv == nil {
return nil
}
clusterRemoteSrv.GracefulStop()
}
if clusterController != nil {
return nil
},
}
controller.Register(RunClusterRemoteSrv)
RunClusterController := &svcs.Service{
Init: func(context.Context) error {
if clusterController == nil {
return nil
}
clusterController.ManageQueryConnections(
mySQLServer.SessionManager().Iter,
sqlEngine.GetUnderlyingEngine().ProcessList.Kill,
mySQLServer.SessionManager().KillConnection,
)
return nil
},
Run: func(context.Context) {
if clusterController == nil {
return
}
clusterController.Run()
},
Stop: func() error {
if clusterController == nil {
return nil
}
clusterController.GracefulStop()
}
return mySQLServer.Close()
})
closeError = mySQLServer.Start()
if closeError != nil {
cli.PrintErr(closeError)
}
if err := mrEnv.Unlock(); err != nil {
cli.PrintErr(err)
return nil
},
}
controller.Register(RunClusterController)
return
RunSQLServer := &svcs.Service{
Run: func(context.Context) {
sqlserver.SetRunningServer(mySQLServer, serverLock)
defer sqlserver.UnsetRunningServer()
mySQLServer.Start()
},
Stop: func() error {
return mySQLServer.Close()
},
}
controller.Register(RunSQLServer)
go controller.Start(ctx)
err := controller.WaitForStart()
if err != nil {
return err, nil
}
return nil, controller.WaitForStop()
}
// acquireGlobalSqlServerLock attempts to acquire a global lock on the SQL server. If no error is returned, then the lock was acquired.
@@ -514,10 +650,10 @@ func newSessionBuilder(se *engine.SqlEngine, config ServerConfig) server.Session
}
// getConfigFromServerConfig processes ServerConfig and returns server.Config for sql-server.
func getConfigFromServerConfig(serverConfig ServerConfig) (server.Config, error, error) {
func getConfigFromServerConfig(serverConfig ServerConfig) (server.Config, error) {
serverConf, err := handleProtocolAndAddress(serverConfig)
if err != nil {
return server.Config{}, err, nil
return server.Config{}, err
}
serverConf.DisableClientMultiStatements = serverConfig.DisableClientMultiStatements()
@@ -527,7 +663,7 @@ func getConfigFromServerConfig(serverConfig ServerConfig) (server.Config, error,
tlsConfig, err := LoadTLSConfig(serverConfig)
if err != nil {
return server.Config{}, nil, err
return server.Config{}, err
}
// if persist is 'load' we use currently set persisted global variable,
@@ -535,12 +671,12 @@ func getConfigFromServerConfig(serverConfig ServerConfig) (server.Config, error,
if serverConfig.PersistenceBehavior() == loadPerisistentGlobals {
serverConf, err = serverConf.NewConfig()
if err != nil {
return server.Config{}, err, nil
return server.Config{}, err
}
} else {
err = sql.SystemVariables.SetGlobal("max_connections", serverConfig.MaxConnections())
if err != nil {
return server.Config{}, err, nil
return server.Config{}, err
}
}
@@ -554,7 +690,7 @@ func getConfigFromServerConfig(serverConfig ServerConfig) (server.Config, error,
serverConf.MaxLoggedQueryLen = serverConfig.MaxLoggedQueryLen()
serverConf.EncodeLoggedQuery = serverConfig.ShouldEncodeLoggedQuery()
return serverConf, nil, nil
return serverConf, nil
}
// handleProtocolAndAddress returns new server.Config object with only Protocol and Address defined.

View File

@@ -32,6 +32,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/utils/config"
"github.com/dolthub/dolt/go/libraries/utils/svcs"
)
//TODO: server tests need to expose a higher granularity for server interactions:
@@ -60,7 +61,7 @@ var (
)
func TestServerArgs(t *testing.T) {
serverController := NewServerController()
controller := svcs.NewController()
dEnv, err := sqle.CreateEnvWithSeedData()
require.NoError(t, err)
defer func() {
@@ -75,16 +76,16 @@ func TestServerArgs(t *testing.T) {
"-t", "5",
"-l", "info",
"-r",
}, dEnv, serverController)
}, dEnv, controller)
}()
err = serverController.WaitForStart()
err = controller.WaitForStart()
require.NoError(t, err)
conn, err := dbr.Open("mysql", "username:password@tcp(localhost:15200)/", nil)
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
serverController.StopServer()
err = serverController.WaitForClose()
controller.Stop()
err = controller.WaitForStop()
assert.NoError(t, err)
}
@@ -110,22 +111,22 @@ listener:
defer func() {
assert.NoError(t, dEnv.DoltDB.Close())
}()
serverController := NewServerController()
controller := svcs.NewController()
go func() {
dEnv.FS.WriteFile("config.yaml", []byte(yamlConfig), os.ModePerm)
startServer(context.Background(), "0.0.0", "dolt sql-server", []string{
"--config", "config.yaml",
}, dEnv, serverController)
}, dEnv, controller)
}()
err = serverController.WaitForStart()
err = controller.WaitForStart()
require.NoError(t, err)
conn, err := dbr.Open("mysql", "username:password@tcp(localhost:15200)/", nil)
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
serverController.StopServer()
err = serverController.WaitForClose()
controller.Stop()
err = controller.WaitForStop()
assert.NoError(t, err)
}
@@ -145,18 +146,15 @@ func TestServerBadArgs(t *testing.T) {
}
for _, test := range tests {
test := test
t.Run(strings.Join(test, " "), func(t *testing.T) {
serverController := NewServerController()
go func(serverController *ServerController) {
startServer(context.Background(), "test", "dolt sql-server", test, env, serverController)
}(serverController)
// In the event that a test fails, we need to prevent a test from hanging due to a running server
err := serverController.WaitForStart()
require.Error(t, err)
serverController.StopServer()
err = serverController.WaitForClose()
assert.NoError(t, err)
controller := svcs.NewController()
go func() {
startServer(context.Background(), "test", "dolt sql-server", test, env, controller)
}()
if !assert.Error(t, controller.WaitForStart()) {
controller.Stop()
}
})
}
}
@@ -186,8 +184,8 @@ func TestServerGoodParams(t *testing.T) {
for _, test := range tests {
t.Run(ConfigInfo(test), func(t *testing.T) {
sc := NewServerController()
go func(config ServerConfig, sc *ServerController) {
sc := svcs.NewController()
go func(config ServerConfig, sc *svcs.Controller) {
_, _ = Serve(context.Background(), "0.0.0", config, sc, env)
}(test, sc)
err := sc.WaitForStart()
@@ -196,8 +194,8 @@ func TestServerGoodParams(t *testing.T) {
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
sc.StopServer()
err = sc.WaitForClose()
sc.Stop()
err = sc.WaitForStop()
assert.NoError(t, err)
})
}
@@ -212,8 +210,8 @@ func TestServerSelect(t *testing.T) {
serverConfig := DefaultServerConfig().withLogLevel(LogLevel_Fatal).WithPort(15300)
sc := NewServerController()
defer sc.StopServer()
sc := svcs.NewController()
defer sc.Stop()
go func() {
_, _ = Serve(context.Background(), "0.0.0", serverConfig, sc, env)
}()
@@ -261,7 +259,7 @@ func TestServerSelect(t *testing.T) {
// If a port is already in use, throw error "Port XXXX already in use."
func TestServerFailsIfPortInUse(t *testing.T) {
serverController := NewServerController()
controller := svcs.NewController()
server := &http.Server{
Addr: ":15200",
Handler: http.DefaultServeMux,
@@ -287,10 +285,10 @@ func TestServerFailsIfPortInUse(t *testing.T) {
"-t", "5",
"-l", "info",
"-r",
}, dEnv, serverController)
}, dEnv, controller)
}()
err = serverController.WaitForStart()
err = controller.WaitForStart()
require.Error(t, err)
server.Close()
wg.Wait()
@@ -311,8 +309,8 @@ func TestServerSetDefaultBranch(t *testing.T) {
serverConfig := DefaultServerConfig().withLogLevel(LogLevel_Fatal).WithPort(15302)
sc := NewServerController()
defer sc.StopServer()
sc := svcs.NewController()
defer sc.Stop()
go func() {
_, _ = Serve(context.Background(), "0.0.0", serverConfig, sc, dEnv)
}()
@@ -470,7 +468,7 @@ func TestReadReplica(t *testing.T) {
dsess.InitPersistedSystemVars(multiSetup.GetEnv(readReplicaDbName))
// start server as read replica
sc := NewServerController()
sc := svcs.NewController()
serverConfig := DefaultServerConfig().withLogLevel(LogLevel_Fatal).WithPort(15303)
// set socket to nil to force tcp
@@ -482,7 +480,7 @@ func TestReadReplica(t *testing.T) {
require.NoError(t, err)
}()
require.NoError(t, sc.WaitForStart())
defer sc.StopServer()
defer sc.Stop()
replicatedTable := "new_table"
multiSetup.CreateTable(ctx, sourceDbName, replicatedTable)

View File

@@ -1,98 +0,0 @@
// Copyright 2019 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlserver
import (
"sync"
)
type ServerController struct {
startCh chan struct{}
closeCh chan struct{}
closeCalled *sync.Once
closeRegistered *sync.Once
stopRegistered *sync.Once
closeFunction func() error
startError error
closeError error
}
// NewServerController creates a `ServerController` for use with synchronizing on `Serve`.
func NewServerController() *ServerController {
return &ServerController{
startCh: make(chan struct{}),
closeCh: make(chan struct{}),
closeCalled: &sync.Once{},
closeRegistered: &sync.Once{},
stopRegistered: &sync.Once{},
}
}
// registerCloseFunction is called within `Serve` to associate the close function with a future `StopServer` call.
// Only the first call will register and unblock, thus it is safe to be called multiple times.
func (controller *ServerController) registerCloseFunction(startError error, closeFunc func() error) {
controller.closeRegistered.Do(func() {
if startError != nil {
controller.startError = startError
}
controller.closeFunction = closeFunc
close(controller.startCh)
})
}
// serverStopped is called within `Serve` to signal that the server has stopped and set the exit code.
// Only the first call will register and unblock, thus it is safe to be called multiple times.
func (controller *ServerController) serverStopped(closeError error) {
controller.stopRegistered.Do(func() {
if closeError != nil {
controller.closeError = closeError
}
close(controller.closeCh)
})
}
// StopServer stops the server if it is running. Only the first call will trigger the stop, thus it is safe for
// multiple goroutines to call this function.
func (controller *ServerController) StopServer() {
if controller.closeFunction != nil {
controller.closeCalled.Do(func() {
if err := controller.closeFunction(); err != nil {
controller.closeError = err
}
})
}
}
// WaitForClose blocks the caller until the server has closed. The return is the last error encountered, if any.
func (controller *ServerController) WaitForClose() error {
select {
case <-controller.closeCh:
break
}
return controller.closeError
}
// WaitForStart blocks the caller until the server has started. An error is returned if one was encountered.
func (controller *ServerController) WaitForStart() error {
select {
case <-controller.startCh:
break
case <-controller.closeCh:
break
}
return controller.startError
}

View File

@@ -41,6 +41,7 @@ import (
"github.com/dolthub/dolt/go/libraries/utils/argparser"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/libraries/utils/iohelp"
"github.com/dolthub/dolt/go/libraries/utils/svcs"
)
const (
@@ -106,7 +107,7 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri
apr := cli.ParseArgsOrDie(ap, args, help)
var serverConfig ServerConfig
var serverController *ServerController
var svcsController *svcs.Controller
var err error
cli.Println(color.YellowString("WARNING: This command is being deprecated and is not recommended for general use.\n" +
@@ -148,11 +149,11 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri
}
cli.PrintErrf("Starting server with Config %v\n", ConfigInfo(serverConfig))
serverController = NewServerController()
svcsController = svcs.NewController()
go func() {
_, _ = Serve(ctx, cmd.VersionStr, serverConfig, serverController, dEnv)
_, _ = Serve(ctx, cmd.VersionStr, serverConfig, svcsController, dEnv)
}()
err = serverController.WaitForStart()
err = svcsController.WaitForStart()
if err != nil {
cli.PrintErrln(err.Error())
return 1
@@ -375,8 +376,8 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri
cli.PrintErrln(err.Error())
}
if apr.Contains(sqlClientDualFlag) {
serverController.StopServer()
err = serverController.WaitForClose()
svcsController.Stop()
err = svcsController.WaitForStop()
if err != nil {
cli.PrintErrln(err.Error())
}

View File

@@ -30,6 +30,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/utils/argparser"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/libraries/utils/svcs"
)
const (
@@ -186,12 +187,20 @@ func (cmd SqlServerCmd) RequiresRepo() bool {
// Exec executes the command
func (cmd SqlServerCmd) Exec(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEnv, cliCtx cli.CliContext) int {
controller := NewServerController()
controller := svcs.NewController()
newCtx, cancelF := context.WithCancel(context.Background())
go func() {
<-ctx.Done()
controller.StopServer()
cancelF()
// Here we only forward along the SIGINT if the server starts
// up successfully. If the service does not start up
// successfully, or if WaitForStart() blocks indefinitely, then
// startServer() should have returned an error and we do not
// need to Stop the running server or deal with our canceled
// parent context.
if controller.WaitForStart() == nil {
<-ctx.Done()
controller.Stop()
cancelF()
}
}()
return startServer(newCtx, cmd.VersionStr, commandStr, args, dEnv, controller)
}
@@ -208,7 +217,7 @@ func validateSqlServerArgs(apr *argparser.ArgParseResults) error {
return nil
}
func startServer(ctx context.Context, versionStr, commandStr string, args []string, dEnv *env.DoltEnv, serverController *ServerController) int {
func startServer(ctx context.Context, versionStr, commandStr string, args []string, dEnv *env.DoltEnv, controller *svcs.Controller) int {
ap := SqlServerCmd{}.ArgParser()
help, _ := cli.HelpAndUsagePrinters(cli.CommandDocsForCommandString(commandStr, sqlServerDocs, ap))
@@ -222,21 +231,11 @@ func startServer(ctx context.Context, versionStr, commandStr string, args []stri
}
serverConfig, err := GetServerConfig(dEnv.FS, apr)
if err != nil {
if serverController != nil {
serverController.StopServer()
serverController.serverStopped(err)
}
cli.PrintErrln(color.RedString("Failed to start server. Bad Configuration"))
cli.PrintErrln(err.Error())
return 1
}
if err = SetupDoltConfig(dEnv, apr, serverConfig); err != nil {
if serverController != nil {
serverController.StopServer()
serverController.serverStopped(err)
}
cli.PrintErrln(color.RedString("Failed to start server. Bad Configuration"))
cli.PrintErrln(err.Error())
return 1
@@ -244,7 +243,7 @@ func startServer(ctx context.Context, versionStr, commandStr string, args []stri
cli.PrintErrf("Starting server with Config %v\n", ConfigInfo(serverConfig))
if startError, closeError := Serve(ctx, versionStr, serverConfig, serverController, dEnv); startError != nil || closeError != nil {
if startError, closeError := Serve(ctx, versionStr, serverConfig, controller, dEnv); startError != nil || closeError != nil {
if startError != nil {
cli.PrintErrln(startError)
}

View File

@@ -32,6 +32,7 @@ import (
"github.com/dolthub/dolt/go/cmd/dolt/commands/sqlserver"
"github.com/dolthub/dolt/go/libraries/doltcore/dtestutils"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/utils/svcs"
)
// DoltBranchMultiSessionScriptTests contain tests that need to be run in a multi-session server environment
@@ -539,8 +540,8 @@ func testMultiSessionScriptTests(t *testing.T, tests []queries.ScriptTest) {
require.NoError(t, conn1.Close())
require.NoError(t, conn2.Close())
sc.StopServer()
err = sc.WaitForClose()
sc.Stop()
err = sc.WaitForStop()
require.NoError(t, err)
})
}
@@ -595,8 +596,8 @@ func testSerialSessionScriptTests(t *testing.T, tests []queries.ScriptTest) {
require.NoError(t, conn1.Close())
sc.StopServer()
err = sc.WaitForClose()
sc.Stop()
err = sc.WaitForStop()
require.NoError(t, err)
})
}
@@ -657,7 +658,7 @@ func assertResultsEqual(t *testing.T, expected []sql.Row, rows *gosql.Rows) {
}
// startServer will start sql-server with given host, unix socket file path and whether to use specific port, which is defined randomly.
func startServer(t *testing.T, withPort bool, host string, unixSocketPath string) (*env.DoltEnv, *sqlserver.ServerController, sqlserver.ServerConfig) {
func startServer(t *testing.T, withPort bool, host string, unixSocketPath string) (*env.DoltEnv, *svcs.Controller, sqlserver.ServerConfig) {
dEnv := dtestutils.CreateTestEnv()
serverConfig := sqlserver.DefaultServerConfig()
if withPort {
@@ -676,8 +677,8 @@ func startServer(t *testing.T, withPort bool, host string, unixSocketPath string
return dEnv, onEnv, config
}
func startServerOnEnv(t *testing.T, serverConfig sqlserver.ServerConfig, dEnv *env.DoltEnv) (*sqlserver.ServerController, sqlserver.ServerConfig) {
sc := sqlserver.NewServerController()
func startServerOnEnv(t *testing.T, serverConfig sqlserver.ServerConfig, dEnv *env.DoltEnv) (*svcs.Controller, sqlserver.ServerConfig) {
sc := svcs.NewController()
go func() {
_, _ = sqlserver.Serve(context.Background(), "0.0.0", serverConfig, sc, dEnv)
}()
@@ -745,8 +746,8 @@ func TestDoltServerRunningUnixSocket(t *testing.T) {
require.NoError(t, localConn.Close())
// Stopping unix socket server
sc.StopServer()
err = sc.WaitForClose()
sc.Stop()
err = sc.WaitForStop()
require.NoError(t, err)
require.NoFileExists(t, defaultUnixSocketPath)
@@ -773,7 +774,7 @@ func TestDoltServerRunningUnixSocket(t *testing.T) {
})
// Stopping TCP socket server
tcpSc.StopServer()
err = tcpSc.WaitForClose()
tcpSc.Stop()
err = tcpSc.WaitForStop()
require.NoError(t, err)
}

View File

@@ -0,0 +1,189 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package svcs
import (
"context"
"errors"
"sync"
)
// A Service is a runnable unit of functionality that a Controller can
// take responsibility for. It has an |Init| function, which can error, and
// which should do all of the initialization and validation work necessary to
// bring the service up. It has a |Run| function, which will be called in a
// separate go-routine and should run and provide the functionality associated
// with the service until the |Stop| function is called.
type Service struct {
Init func(context.Context) error
Run func(context.Context)
Stop func() error
}
// A Controller is responsible for initializing a number of registered
// services, running them all, and stopping them all when requested. Services
// are registered with |Register(*Service)|. When |Start| is called, the
// services are all initialized, in the order of their registration, and if
// every service initializes successfully, they are |Run| concurrently. When
// |Stop| is called, services are stopped in reverse-registration order. |Stop|
// does not block for the goroutines spawned by |Start| to complete, although
// typically a Service's |Stop| function should do that. |Stop| only returns an
// error if the Controller is in an illegal state where it is not valid to Stop
// it. In particular, it does not return an error seen by a Service on Stop().
// That error is returned from Start() and from WaitForStop().
//
// Any attempt to register a service after |Start| or |Stop| has been called
// will return an error.
//
// If an error occurs when initializing the services of a Controller, the
// Stop functions of any already initialized Services are called in
// reverse-order. The error which caused the initialization error is returned.
//
// In the case that all Services Init successfully, the error returned from
// |Start| is the first non-nil error which is returned from the |Stop|
// functions, in the order they are called.
//
// WaitForStart() can be called at any time on a Controller. It will
// block until |Start| is called. After |Start| is called, if all the services
// succesfully initialize, it will return |nil|. Otherwise it will return the
// same error |Start| returned.
//
// WaitForStop() can be called at any time on a Controller. It will block
// until |Start| is called and initialization fails, or until |Stop| is called.
// It will return the same error which |Start| returned.
type Controller struct {
mu sync.Mutex
services []*Service
initErr error
stopErr error
started bool
startCh chan struct{}
stopped bool
stopCh chan struct{}
stoppedCh chan struct{}
}
func NewController() *Controller {
return &Controller{
startCh: make(chan struct{}),
stopCh: make(chan struct{}),
stoppedCh: make(chan struct{}),
}
}
func (c *Controller) WaitForStart() error {
<-c.startCh
c.mu.Lock()
err := c.initErr
c.mu.Unlock()
return err
}
func (c *Controller) WaitForStop() error {
<-c.stoppedCh
c.mu.Lock()
var err error
if c.initErr != nil {
err = c.initErr
} else if c.stopErr != nil {
err = c.stopErr
}
c.mu.Unlock()
return err
}
func (c *Controller) Register(svc *Service) error {
c.mu.Lock()
if c.started {
c.mu.Unlock()
return errors.New("Controller: cannot Register a service on a controller which was already started")
}
c.services = append(c.services, svc)
c.mu.Unlock()
return nil
}
func (c *Controller) Stop() error {
c.mu.Lock()
if !c.started {
c.mu.Unlock()
return errors.New("Controller: cannot Stop a controller which was never started")
}
if c.stopped {
c.mu.Unlock()
return errors.New("Controller: cannot Stop a controller which was already stopped or which failed to initialize all its services")
}
c.stopped = true
close(c.stopCh)
c.mu.Unlock()
<-c.stoppedCh
return nil
}
func (c *Controller) Start(ctx context.Context) error {
c.mu.Lock()
if c.started {
return errors.New("Controller: cannot start service controller twice")
}
c.started = true
svcs := make([]*Service, len(c.services))
copy(svcs, c.services)
c.mu.Unlock()
for i, s := range svcs {
if s.Init == nil {
continue
}
err := s.Init(ctx)
if err != nil {
for j := i - 1; j >= 0; j-- {
if svcs[j].Stop != nil {
svcs[j].Stop()
}
}
c.mu.Lock()
c.stopped = true
c.initErr = err
close(c.startCh)
close(c.stoppedCh)
c.mu.Unlock()
return err
}
}
close(c.startCh)
for _, s := range svcs {
if s.Run == nil {
continue
}
go s.Run(ctx)
}
<-c.stopCh
var stopErr error
for i := len(svcs) - 1; i >= 0; i-- {
if svcs[i].Stop == nil {
continue
}
err := svcs[i].Stop()
if err != nil && stopErr == nil {
stopErr = err
}
}
c.mu.Lock()
if stopErr != nil {
c.stopErr = stopErr
}
close(c.stoppedCh)
c.mu.Unlock()
return stopErr
}

View File

@@ -0,0 +1,305 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package svcs
import (
"context"
"errors"
"sync"
"testing"
"github.com/stretchr/testify/require"
)
func TestController(t *testing.T) {
t.Run("NewController", func(t *testing.T) {
c := NewController()
require.NotNil(t, c)
})
t.Run("Stop", func(t *testing.T) {
t.Run("CalledBeforeStart", func(t *testing.T) {
c := NewController()
require.Error(t, c.Stop())
})
t.Run("ReturnsFirstError", func(t *testing.T) {
c := NewController()
ctx := context.Background()
err := errors.New("first")
require.NoError(t, c.Register(&Service{
Init: func(context.Context) error { return nil },
Run: func(context.Context) {},
Stop: func() error { return errors.New("second") },
}))
require.NoError(t, c.Register(&Service{
Init: func(context.Context) error { return nil },
Run: func(context.Context) {},
Stop: func() error { return err },
}))
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
require.NoError(t, c.WaitForStart())
require.NoError(t, c.Stop())
}()
require.ErrorIs(t, c.Start(ctx), err)
require.ErrorIs(t, c.WaitForStop(), err)
wg.Wait()
})
})
t.Run("EmptyServices", func(t *testing.T) {
c := NewController()
ctx := context.Background()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
require.NoError(t, c.WaitForStart())
require.NoError(t, c.Stop())
}()
require.NoError(t, c.Start(ctx))
require.NoError(t, c.WaitForStop())
wg.Wait()
})
t.Run("Register", func(t *testing.T) {
t.Run("AfterStartCalled", func(t *testing.T) {
c := NewController()
ctx := context.Background()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
require.NoError(t, c.WaitForStart())
require.Error(t, c.Register(&Service{
Init: func(context.Context) error { return nil },
Run: func(context.Context) {},
Stop: func() error { return nil },
}))
require.NoError(t, c.Stop())
}()
require.NoError(t, c.Start(ctx))
require.NoError(t, c.WaitForStop())
wg.Wait()
})
})
t.Run("Start", func(t *testing.T) {
t.Run("CallsInitInOrder", func(t *testing.T) {
c := NewController()
var inited []int
require.NoError(t, c.Register(&Service{
Init: func(context.Context) error {
inited = append(inited, 0)
return nil
},
Run: func(context.Context) {},
Stop: func() error { return nil },
}))
require.NoError(t, c.Register(&Service{
Init: func(context.Context) error {
inited = append(inited, 1)
return nil
},
Run: func(context.Context) {},
Stop: func() error { return nil },
}))
require.NoError(t, c.Register(&Service{
Init: func(context.Context) error {
inited = append(inited, 2)
return nil
},
Run: func(context.Context) {},
Stop: func() error { return nil },
}))
ctx := context.Background()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
require.NoError(t, c.WaitForStart())
require.NoError(t, c.Stop())
}()
require.NoError(t, c.Start(ctx))
require.NoError(t, c.WaitForStop())
require.Equal(t, inited, []int{0, 1, 2})
wg.Wait()
})
t.Run("StopsCallingInitOnFirstError", func(t *testing.T) {
err := errors.New("first error")
c := NewController()
var inited []int
require.NoError(t, c.Register(&Service{
Init: func(context.Context) error {
inited = append(inited, 0)
return nil
},
Run: func(context.Context) {},
Stop: func() error { return nil },
}))
require.NoError(t, c.Register(&Service{
Init: func(context.Context) error {
inited = append(inited, 1)
return nil
},
Run: func(context.Context) {},
Stop: func() error { return nil },
}))
require.NoError(t, c.Register(&Service{
Init: func(context.Context) error {
return err
},
Run: func(context.Context) {},
Stop: func() error { return nil },
}))
require.NoError(t, c.Register(&Service{
Init: func(context.Context) error {
inited = append(inited, 2)
return nil
},
Run: func(context.Context) {},
Stop: func() error { return nil },
}))
ctx := context.Background()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
require.ErrorIs(t, c.WaitForStart(), err)
require.NotErrorIs(t, c.Stop(), err)
}()
require.ErrorIs(t, c.Start(ctx), err)
require.ErrorIs(t, c.WaitForStop(), err)
require.Equal(t, inited, []int{0, 1})
wg.Wait()
})
t.Run("CallsStopWhenInitErrors", func(t *testing.T) {
err := errors.New("first error")
c := NewController()
var stopped []int
require.NoError(t, c.Register(&Service{
Init: func(context.Context) error {
return nil
},
Run: func(context.Context) {},
Stop: func() error {
stopped = append(stopped, 0)
return nil
},
}))
require.NoError(t, c.Register(&Service{
Init: func(context.Context) error {
return nil
},
Run: func(context.Context) {},
Stop: func() error {
stopped = append(stopped, 1)
return nil
},
}))
require.NoError(t, c.Register(&Service{
Init: func(context.Context) error {
return err
},
Run: func(context.Context) {},
Stop: func() error {
stopped = append(stopped, 2)
return nil
},
}))
require.NoError(t, c.Register(&Service{
Init: func(context.Context) error {
return nil
},
Run: func(context.Context) {},
Stop: func() error {
stopped = append(stopped, 3)
return nil
},
}))
ctx := context.Background()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
require.ErrorIs(t, c.WaitForStart(), err)
require.NotErrorIs(t, c.Stop(), err)
}()
require.ErrorIs(t, c.Start(ctx), err)
require.ErrorIs(t, c.WaitForStop(), err)
require.Equal(t, stopped, []int{1, 0})
wg.Wait()
})
t.Run("RunsServices", func(t *testing.T) {
c := NewController()
var wg sync.WaitGroup
wg.Add(2)
require.NoError(t, c.Register(&Service{
Init: func(context.Context) error { return nil },
Run: func(context.Context) { wg.Done() },
Stop: func() error { return nil },
}))
require.NoError(t, c.Register(&Service{
Init: func(context.Context) error { return nil },
Run: func(context.Context) { wg.Done() },
Stop: func() error { return nil },
}))
ctx := context.Background()
var cwg sync.WaitGroup
cwg.Add(1)
go func() {
defer cwg.Done()
require.NoError(t, c.WaitForStart())
require.NoError(t, c.Stop())
}()
require.NoError(t, c.Start(ctx))
require.NoError(t, c.WaitForStop())
wg.Wait()
cwg.Wait()
})
t.Run("StopsAllServices", func(t *testing.T) {
c := NewController()
var wg sync.WaitGroup
err := errors.New("first error")
wg.Add(2)
require.NoError(t, c.Register(&Service{
Init: func(context.Context) error { return nil },
Run: func(context.Context) {},
Stop: func() error {
wg.Done()
return errors.New("second error")
},
}))
require.NoError(t, c.Register(&Service{
Init: func(context.Context) error { return nil },
Run: func(context.Context) {},
Stop: func() error {
wg.Done()
return err
},
}))
ctx := context.Background()
var cwg sync.WaitGroup
cwg.Add(1)
go func() {
defer cwg.Done()
require.NoError(t, c.WaitForStart())
require.NoError(t, c.Stop())
}()
require.ErrorIs(t, c.Start(ctx), err)
require.ErrorIs(t, c.WaitForStop(), err)
wg.Wait()
cwg.Wait()
})
})
}

View File

@@ -31,6 +31,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/dtestutils/testcommands"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/utils/svcs"
)
type query string
@@ -163,13 +164,13 @@ func getProfFile(b *testing.B) *os.File {
}
func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv, cfg srv.ServerConfig, queries []query) {
serverController := srv.NewServerController()
sc := svcs.NewController()
eg, ctx := errgroup.WithContext(ctx)
//b.Logf("Starting server with Config %v\n", srv.ConfigInfo(cfg))
eg.Go(func() (err error) {
startErr, closeErr := srv.Serve(ctx, "", cfg, serverController, dEnv)
startErr, closeErr := srv.Serve(ctx, "", cfg, sc, dEnv)
if startErr != nil {
return startErr
}
@@ -178,7 +179,7 @@ func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv,
}
return nil
})
if err := serverController.WaitForStart(); err != nil {
if err := sc.WaitForStart(); err != nil {
b.Fatal(err)
}
@@ -188,8 +189,8 @@ func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv,
}
}
serverController.StopServer()
if err := serverController.WaitForClose(); err != nil {
sc.Stop()
if err := sc.WaitForStop(); err != nil {
b.Fatal(err)
}
if err := eg.Wait(); err != nil {

View File

@@ -31,6 +31,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/libraries/utils/svcs"
"github.com/dolthub/dolt/go/store/types"
)
@@ -175,13 +176,13 @@ func getProfFile(b *testing.B) *os.File {
}
func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv, cfg srv.ServerConfig, queries []query) {
serverController := srv.NewServerController()
sc := svcs.NewController()
eg, ctx := errgroup.WithContext(ctx)
//b.Logf("Starting server with Config %v\n", srv.ConfigInfo(cfg))
eg.Go(func() (err error) {
startErr, closeErr := srv.Serve(ctx, "", cfg, serverController, dEnv)
startErr, closeErr := srv.Serve(ctx, "", cfg, sc, dEnv)
if startErr != nil {
return startErr
}
@@ -190,7 +191,7 @@ func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv,
}
return nil
})
if err := serverController.WaitForStart(); err != nil {
if err := sc.WaitForStart(); err != nil {
b.Fatal(err)
}
@@ -200,8 +201,8 @@ func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv,
}
}
serverController.StopServer()
if err := serverController.WaitForClose(); err != nil {
sc.Stop()
if err := sc.WaitForStop(); err != nil {
b.Fatal(err)
}
if err := eg.Wait(); err != nil {