mirror of
https://github.com/dolthub/dolt.git
synced 2026-01-09 16:16:08 -06:00
cmd/dolt/commands/sqlserver: Restructure the start up sequence for sql-server.
We explicitly model Services, which can have an Init step, a Run step and a Stop step. Every registered service get initialized in the order they were registered in, then they all run concurrently until Stop is called, when they all get Stopped in reverse order. It's possible for clients to wait for init to be completed and be delivered any errors encountered on startup. They can also wait for stop, to be delivered any errors encountered on shutdown.
This commit is contained in:
@@ -48,6 +48,7 @@ import (
|
||||
_ "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dfunctions"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqlserver"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/svcs"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -63,335 +64,470 @@ func Serve(
|
||||
ctx context.Context,
|
||||
version string,
|
||||
serverConfig ServerConfig,
|
||||
serverController *ServerController,
|
||||
controller *svcs.Controller,
|
||||
dEnv *env.DoltEnv,
|
||||
) (startError error, closeError error) {
|
||||
// Code is easier to work through if we assume that serverController is never nil
|
||||
if serverController == nil {
|
||||
serverController = NewServerController()
|
||||
if controller == nil {
|
||||
controller = svcs.NewController()
|
||||
}
|
||||
|
||||
var mySQLServer *server.Server
|
||||
// This guarantees unblocking on any routines with a waiting `ServerController`
|
||||
defer func() {
|
||||
if mySQLServer != nil {
|
||||
serverController.registerCloseFunction(startError, mySQLServer.Close)
|
||||
} else {
|
||||
serverController.registerCloseFunction(startError, func() error { return nil })
|
||||
}
|
||||
serverController.StopServer()
|
||||
serverController.serverStopped(closeError)
|
||||
sqlserver.UnsetRunningServer()
|
||||
}()
|
||||
|
||||
if startError = ValidateConfig(serverConfig); startError != nil {
|
||||
return startError, nil
|
||||
ValidateConfigStep := &svcs.Service{
|
||||
Init: func(context.Context) error {
|
||||
return ValidateConfig(serverConfig)
|
||||
},
|
||||
}
|
||||
controller.Register(ValidateConfigStep)
|
||||
|
||||
lgr := logrus.StandardLogger()
|
||||
lgr.SetOutput(cli.CliErr)
|
||||
InitLogging := &svcs.Service{
|
||||
Init: func(context.Context) error {
|
||||
level, err := logrus.ParseLevel(serverConfig.LogLevel().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logrus.SetLevel(level)
|
||||
|
||||
if serverConfig.LogLevel() != LogLevel_Info {
|
||||
var level logrus.Level
|
||||
level, startError = logrus.ParseLevel(serverConfig.LogLevel().String())
|
||||
if startError != nil {
|
||||
cli.PrintErr(startError)
|
||||
return
|
||||
}
|
||||
logrus.SetLevel(level)
|
||||
}
|
||||
logrus.SetFormatter(LogFormat{})
|
||||
sql.SystemVariables.AddSystemVariables([]sql.SystemVariable{
|
||||
{
|
||||
Name: dsess.DoltLogLevel,
|
||||
Scope: sql.SystemVariableScope_Global,
|
||||
Dynamic: true,
|
||||
SetVarHintApplies: false,
|
||||
Type: types.NewSystemEnumType(dsess.DoltLogLevel,
|
||||
logrus.PanicLevel.String(),
|
||||
logrus.FatalLevel.String(),
|
||||
logrus.ErrorLevel.String(),
|
||||
logrus.WarnLevel.String(),
|
||||
logrus.InfoLevel.String(),
|
||||
logrus.DebugLevel.String(),
|
||||
logrus.TraceLevel.String(),
|
||||
),
|
||||
Default: logrus.GetLevel().String(),
|
||||
NotifyChanged: func(scope sql.SystemVariableScope, v sql.SystemVarValue) error {
|
||||
level, err := logrus.ParseLevel(v.Val.(string))
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not parse requested log level %s as a log level. dolt_log_level variable value and logging behavior will diverge.", v.Val.(string))
|
||||
}
|
||||
|
||||
sql.SystemVariables.AddSystemVariables([]sql.SystemVariable{
|
||||
{
|
||||
Name: dsess.DoltLogLevel,
|
||||
Scope: sql.SystemVariableScope_Global,
|
||||
Dynamic: true,
|
||||
SetVarHintApplies: false,
|
||||
Type: types.NewSystemEnumType(dsess.DoltLogLevel,
|
||||
logrus.PanicLevel.String(),
|
||||
logrus.FatalLevel.String(),
|
||||
logrus.ErrorLevel.String(),
|
||||
logrus.WarnLevel.String(),
|
||||
logrus.InfoLevel.String(),
|
||||
logrus.DebugLevel.String(),
|
||||
logrus.TraceLevel.String(),
|
||||
),
|
||||
Default: logrus.GetLevel().String(),
|
||||
NotifyChanged: func(scope sql.SystemVariableScope, v sql.SystemVarValue) error {
|
||||
level, err := logrus.ParseLevel(v.Val.(string))
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not parse requested log level %s as a log level. dolt_log_level variable value and logging behavior will diverge.", v.Val.(string))
|
||||
}
|
||||
|
||||
logrus.SetLevel(level)
|
||||
return nil
|
||||
},
|
||||
logrus.SetLevel(level)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
return nil
|
||||
},
|
||||
})
|
||||
}
|
||||
controller.Register(InitLogging)
|
||||
|
||||
fs := dEnv.FS
|
||||
InitDataDir := &svcs.Service{
|
||||
Init: func(context.Context) error {
|
||||
if len(serverConfig.DataDir()) > 0 && serverConfig.DataDir() != "." {
|
||||
var err error
|
||||
fs, err = dEnv.FS.WithWorkingDir(serverConfig.DataDir())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dEnv.FS = fs
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
controller.Register(InitDataDir)
|
||||
|
||||
var serverLock *env.DBLock
|
||||
InitGlobalServerLock := &svcs.Service{
|
||||
Init: func(context.Context) error {
|
||||
var err error
|
||||
serverLock, err = acquireGlobalSqlServerLock(serverConfig.Port(), dEnv)
|
||||
return err
|
||||
},
|
||||
Stop: func() error {
|
||||
dEnv.FS.Delete(dEnv.LockFile(), false)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
controller.Register(InitGlobalServerLock)
|
||||
|
||||
var mrEnv *env.MultiRepoEnv
|
||||
var err error
|
||||
fs := dEnv.FS
|
||||
|
||||
if len(serverConfig.DataDir()) > 0 && serverConfig.DataDir() != "." {
|
||||
fs, err = dEnv.FS.WithWorkingDir(serverConfig.DataDir())
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
dEnv.FS = fs
|
||||
InitMultiEnv := &svcs.Service{
|
||||
Init: func(ctx context.Context) error {
|
||||
var err error
|
||||
mrEnv, err = env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), fs, dEnv.Version, dEnv.IgnoreLockFile, dEnv)
|
||||
return err
|
||||
},
|
||||
}
|
||||
controller.Register(InitMultiEnv)
|
||||
|
||||
serverLock, startError := acquireGlobalSqlServerLock(serverConfig.Port(), dEnv)
|
||||
if startError != nil {
|
||||
return
|
||||
var clusterController *cluster.Controller
|
||||
InitClusterController := &svcs.Service{
|
||||
Init: func(context.Context) error {
|
||||
var err error
|
||||
clusterController, err = cluster.NewController(lgr, serverConfig.ClusterConfig(), mrEnv.Config())
|
||||
return err
|
||||
},
|
||||
}
|
||||
defer dEnv.FS.Delete(dEnv.LockFile(), false)
|
||||
controller.Register(InitClusterController)
|
||||
|
||||
mrEnv, err = env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), fs, dEnv.Version, dEnv.IgnoreLockFile, dEnv)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
|
||||
clusterController, err := cluster.NewController(lgr, serverConfig.ClusterConfig(), mrEnv.Config())
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
|
||||
serverConf, sErr, cErr := getConfigFromServerConfig(serverConfig)
|
||||
if cErr != nil {
|
||||
return nil, cErr
|
||||
} else if sErr != nil {
|
||||
return sErr, nil
|
||||
var serverConf server.Config
|
||||
LoadServerConfig := &svcs.Service{
|
||||
Init: func(context.Context) error {
|
||||
var err error
|
||||
serverConf, err = getConfigFromServerConfig(serverConfig)
|
||||
return err
|
||||
},
|
||||
}
|
||||
controller.Register(LoadServerConfig)
|
||||
|
||||
// Create SQL Engine with users
|
||||
config := &engine.SqlEngineConfig{
|
||||
IsReadOnly: serverConfig.ReadOnly(),
|
||||
PrivFilePath: serverConfig.PrivilegeFilePath(),
|
||||
BranchCtrlFilePath: serverConfig.BranchControlFilePath(),
|
||||
DoltCfgDirPath: serverConfig.CfgDir(),
|
||||
ServerUser: serverConfig.User(),
|
||||
ServerPass: serverConfig.Password(),
|
||||
ServerHost: serverConfig.Host(),
|
||||
Autocommit: serverConfig.AutoCommit(),
|
||||
DoltTransactionCommit: serverConfig.DoltTransactionCommit(),
|
||||
JwksConfig: serverConfig.JwksConfig(),
|
||||
SystemVariables: serverConfig.SystemVars(),
|
||||
ClusterController: clusterController,
|
||||
BinlogReplicaController: binlogreplication.DoltBinlogReplicaController,
|
||||
}
|
||||
esStatus, err := getEventSchedulerStatus(serverConfig.EventSchedulerStatus())
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
config.EventSchedulerStatus = esStatus
|
||||
|
||||
sqlEngine, err := engine.NewSqlEngine(
|
||||
ctx,
|
||||
mrEnv,
|
||||
config,
|
||||
)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
var config *engine.SqlEngineConfig
|
||||
InitSqlEngineConfig := &svcs.Service{
|
||||
Init: func(context.Context) error {
|
||||
config = &engine.SqlEngineConfig{
|
||||
IsReadOnly: serverConfig.ReadOnly(),
|
||||
PrivFilePath: serverConfig.PrivilegeFilePath(),
|
||||
BranchCtrlFilePath: serverConfig.BranchControlFilePath(),
|
||||
DoltCfgDirPath: serverConfig.CfgDir(),
|
||||
ServerUser: serverConfig.User(),
|
||||
ServerPass: serverConfig.Password(),
|
||||
ServerHost: serverConfig.Host(),
|
||||
Autocommit: serverConfig.AutoCommit(),
|
||||
DoltTransactionCommit: serverConfig.DoltTransactionCommit(),
|
||||
JwksConfig: serverConfig.JwksConfig(),
|
||||
SystemVariables: serverConfig.SystemVars(),
|
||||
ClusterController: clusterController,
|
||||
BinlogReplicaController: binlogreplication.DoltBinlogReplicaController,
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
defer sqlEngine.Close()
|
||||
controller.Register(InitSqlEngineConfig)
|
||||
|
||||
var esStatus eventscheduler.SchedulerStatus
|
||||
InitEventSchedulerStatus := &svcs.Service{
|
||||
Init: func(context.Context) error {
|
||||
var err error
|
||||
esStatus, err = getEventSchedulerStatus(serverConfig.EventSchedulerStatus())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config.EventSchedulerStatus = esStatus
|
||||
return nil
|
||||
},
|
||||
}
|
||||
controller.Register(InitEventSchedulerStatus)
|
||||
|
||||
var sqlEngine *engine.SqlEngine
|
||||
InitSqlEngine := &svcs.Service{
|
||||
Init: func(ctx context.Context) error {
|
||||
var err error
|
||||
sqlEngine, err = engine.NewSqlEngine(
|
||||
ctx,
|
||||
mrEnv,
|
||||
config,
|
||||
)
|
||||
return err
|
||||
},
|
||||
Stop: func() error {
|
||||
sqlEngine.Close()
|
||||
return nil
|
||||
},
|
||||
}
|
||||
controller.Register(InitSqlEngine)
|
||||
|
||||
// Add superuser if specified user exists; add root superuser if no user specified and no existing privileges
|
||||
userSpecified := config.ServerUser != ""
|
||||
var userSpecified bool
|
||||
|
||||
mysqlDb := sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb
|
||||
ed := mysqlDb.Editor()
|
||||
var numUsers int
|
||||
ed.VisitUsers(func(*mysql_db.User) { numUsers += 1 })
|
||||
privsExist := numUsers != 0
|
||||
if userSpecified {
|
||||
superuser := mysqlDb.GetUser(ed, config.ServerUser, "%", false)
|
||||
if userSpecified && superuser == nil {
|
||||
mysqlDb.AddSuperUser(ed, config.ServerUser, "%", config.ServerPass)
|
||||
}
|
||||
} else if !privsExist {
|
||||
mysqlDb.AddSuperUser(ed, defaultUser, "%", defaultPass)
|
||||
InitSuperUser := &svcs.Service{
|
||||
Init: func(context.Context) error {
|
||||
userSpecified = config.ServerUser != ""
|
||||
|
||||
mysqlDb := sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb
|
||||
ed := mysqlDb.Editor()
|
||||
var numUsers int
|
||||
ed.VisitUsers(func(*mysql_db.User) { numUsers += 1 })
|
||||
privsExist := numUsers != 0
|
||||
if userSpecified {
|
||||
superuser := mysqlDb.GetUser(ed, config.ServerUser, "%", false)
|
||||
if userSpecified && superuser == nil {
|
||||
mysqlDb.AddSuperUser(ed, config.ServerUser, "%", config.ServerPass)
|
||||
}
|
||||
} else if !privsExist {
|
||||
mysqlDb.AddSuperUser(ed, defaultUser, "%", defaultPass)
|
||||
}
|
||||
ed.Close()
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
ed.Close()
|
||||
|
||||
labels := serverConfig.MetricsLabels()
|
||||
controller.Register(InitSuperUser)
|
||||
|
||||
var listener *metricsListener
|
||||
listener, startError = newMetricsListener(labels, version, clusterController)
|
||||
if startError != nil {
|
||||
cli.Println(startError)
|
||||
return
|
||||
InitMetricsListener := &svcs.Service{
|
||||
Init: func(context.Context) error {
|
||||
labels := serverConfig.MetricsLabels()
|
||||
var err error
|
||||
listener, err = newMetricsListener(labels, version, clusterController)
|
||||
return err
|
||||
},
|
||||
Stop: func() error {
|
||||
listener.Close()
|
||||
return nil
|
||||
},
|
||||
}
|
||||
defer listener.Close()
|
||||
controller.Register(InitMetricsListener)
|
||||
|
||||
v, ok := serverConfig.(validatingServerConfig)
|
||||
if ok && v.goldenMysqlConnectionString() != "" {
|
||||
mySQLServer, startError = server.NewValidatingServer(
|
||||
serverConf,
|
||||
sqlEngine.GetUnderlyingEngine(),
|
||||
newSessionBuilder(sqlEngine, serverConfig),
|
||||
listener,
|
||||
v.goldenMysqlConnectionString(),
|
||||
)
|
||||
} else {
|
||||
mySQLServer, startError = server.NewServer(
|
||||
serverConf,
|
||||
sqlEngine.GetUnderlyingEngine(),
|
||||
newSessionBuilder(sqlEngine, serverConfig),
|
||||
listener,
|
||||
)
|
||||
var mySQLServer *server.Server
|
||||
InitSQLServer := &svcs.Service{
|
||||
Init: func(context.Context) error {
|
||||
var err error
|
||||
v, ok := serverConfig.(validatingServerConfig)
|
||||
if ok && v.goldenMysqlConnectionString() != "" {
|
||||
mySQLServer, err = server.NewValidatingServer(
|
||||
serverConf,
|
||||
sqlEngine.GetUnderlyingEngine(),
|
||||
newSessionBuilder(sqlEngine, serverConfig),
|
||||
listener,
|
||||
v.goldenMysqlConnectionString(),
|
||||
)
|
||||
} else {
|
||||
mySQLServer, err = server.NewServer(
|
||||
serverConf,
|
||||
sqlEngine.GetUnderlyingEngine(),
|
||||
newSessionBuilder(sqlEngine, serverConfig),
|
||||
listener,
|
||||
)
|
||||
}
|
||||
if errors.Is(err, server.UnixSocketInUseError) {
|
||||
lgr.Warn("unix socket set up failed: file already in use: ", serverConf.Socket)
|
||||
err = nil
|
||||
}
|
||||
return err
|
||||
},
|
||||
}
|
||||
controller.Register(InitSQLServer)
|
||||
|
||||
if errors.Is(startError, server.UnixSocketInUseError) {
|
||||
lgr.Warn("unix socket set up failed: file already in use: ", serverConf.Socket)
|
||||
startError = nil
|
||||
} else if startError != nil {
|
||||
cli.PrintErr(startError)
|
||||
return
|
||||
LockMultiRepoEnv := &svcs.Service{
|
||||
Init: func(context.Context) error {
|
||||
if ok, f := mrEnv.IsLocked(); ok {
|
||||
return env.ErrActiveServerLock.New(f)
|
||||
}
|
||||
if err := mrEnv.Lock(serverLock); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Stop: func() error {
|
||||
if err := mrEnv.Unlock(); err != nil {
|
||||
cli.PrintErr(err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
controller.Register(LockMultiRepoEnv)
|
||||
|
||||
sqlserver.SetRunningServer(mySQLServer, serverLock)
|
||||
|
||||
ed = mysqlDb.Editor()
|
||||
mysqlDb.AddSuperUser(ed, LocalConnectionUser, "localhost", serverLock.Secret)
|
||||
ed.Close()
|
||||
if ExternalDisableUsers {
|
||||
mysqlDb.SetEnabled(false)
|
||||
InitLockSuperUser := &svcs.Service{
|
||||
Init: func(context.Context) error {
|
||||
mysqlDb := sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb
|
||||
ed := mysqlDb.Editor()
|
||||
mysqlDb.AddSuperUser(ed, LocalConnectionUser, "localhost", serverLock.Secret)
|
||||
ed.Close()
|
||||
return nil
|
||||
},
|
||||
}
|
||||
controller.Register(InitLockSuperUser)
|
||||
|
||||
DisableMySQLDbIfRequired := &svcs.Service{
|
||||
Init: func(context.Context) error {
|
||||
if ExternalDisableUsers {
|
||||
mysqlDb := sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb
|
||||
mysqlDb.SetEnabled(false)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
controller.Register(DisableMySQLDbIfRequired)
|
||||
|
||||
var metSrv *http.Server
|
||||
if serverConfig.MetricsHost() != "" && serverConfig.MetricsPort() > 0 {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/metrics", promhttp.Handler())
|
||||
RunMetricsServer := &svcs.Service{
|
||||
Run: func(context.Context) {
|
||||
if serverConfig.MetricsHost() != "" && serverConfig.MetricsPort() > 0 {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/metrics", promhttp.Handler())
|
||||
|
||||
metSrv = &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", serverConfig.MetricsHost(), serverConfig.MetricsPort()),
|
||||
Handler: mux,
|
||||
}
|
||||
metSrv = &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", serverConfig.MetricsHost(), serverConfig.MetricsPort()),
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = metSrv.ListenAndServe()
|
||||
}()
|
||||
_ = metSrv.ListenAndServe()
|
||||
}
|
||||
},
|
||||
Stop: func() error {
|
||||
if metSrv != nil {
|
||||
metSrv.Close()
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
controller.Register(RunMetricsServer)
|
||||
|
||||
var remoteSrv *remotesrv.Server
|
||||
if serverConfig.RemotesapiPort() != nil {
|
||||
port := *serverConfig.RemotesapiPort()
|
||||
listenaddr := fmt.Sprintf(":%d", port)
|
||||
args, err := sqle.RemoteSrvServerArgs(sqlEngine.NewDefaultContext, remotesrv.ServerArgs{
|
||||
Logger: logrus.NewEntry(lgr),
|
||||
ReadOnly: true,
|
||||
HttpListenAddr: listenaddr,
|
||||
GrpcListenAddr: listenaddr,
|
||||
})
|
||||
if err != nil {
|
||||
lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err)
|
||||
startError = err
|
||||
return
|
||||
}
|
||||
var remoteSrvListeners remotesrv.Listeners
|
||||
RunRemoteSrv := &svcs.Service{
|
||||
Init: func(ctx context.Context) error {
|
||||
if serverConfig.RemotesapiPort() == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctxFactory := func() (*sql.Context, error) { return sqlEngine.NewDefaultContext(ctx) }
|
||||
authenticator := newAuthenticator(ctxFactory, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb)
|
||||
args = sqle.WithUserPasswordAuth(args, authenticator)
|
||||
port := *serverConfig.RemotesapiPort()
|
||||
listenaddr := fmt.Sprintf(":%d", port)
|
||||
args, err := sqle.RemoteSrvServerArgs(sqlEngine.NewDefaultContext, remotesrv.ServerArgs{
|
||||
Logger: logrus.NewEntry(lgr),
|
||||
ReadOnly: true,
|
||||
HttpListenAddr: listenaddr,
|
||||
GrpcListenAddr: listenaddr,
|
||||
})
|
||||
if err != nil {
|
||||
lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
args.TLSConfig = serverConf.TLSConfig
|
||||
remoteSrv, err = remotesrv.NewServer(args)
|
||||
if err != nil {
|
||||
lgr.Errorf("error creating remotesapi server on port %d: %v", port, err)
|
||||
startError = err
|
||||
return
|
||||
}
|
||||
listeners, err := remoteSrv.Listeners()
|
||||
if err != nil {
|
||||
lgr.Errorf("error starting remotesapi server listeners on port %d: %v", port, err)
|
||||
startError = err
|
||||
return
|
||||
} else {
|
||||
go remoteSrv.Serve(listeners)
|
||||
}
|
||||
ctxFactory := func() (*sql.Context, error) { return sqlEngine.NewDefaultContext(ctx) }
|
||||
authenticator := newAuthenticator(ctxFactory, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb)
|
||||
args = sqle.WithUserPasswordAuth(args, authenticator)
|
||||
args.TLSConfig = serverConf.TLSConfig
|
||||
|
||||
remoteSrv, err = remotesrv.NewServer(args)
|
||||
if err != nil {
|
||||
lgr.Errorf("error creating remotesapi server on port %d: %v", port, err)
|
||||
return err
|
||||
}
|
||||
remoteSrvListeners, err = remoteSrv.Listeners()
|
||||
if err != nil {
|
||||
lgr.Errorf("error starting remotesapi server listeners on port %d: %v", port, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(ctx context.Context) {
|
||||
if remoteSrv == nil {
|
||||
return
|
||||
}
|
||||
remoteSrv.Serve(remoteSrvListeners)
|
||||
},
|
||||
Stop: func() error {
|
||||
if remoteSrv == nil {
|
||||
return nil
|
||||
}
|
||||
remoteSrv.GracefulStop()
|
||||
return nil
|
||||
},
|
||||
}
|
||||
controller.Register(RunRemoteSrv)
|
||||
|
||||
var clusterRemoteSrv *remotesrv.Server
|
||||
if clusterController != nil {
|
||||
args, err := clusterController.RemoteSrvServerArgs(sqlEngine.NewDefaultContext, remotesrv.ServerArgs{
|
||||
Logger: logrus.NewEntry(lgr),
|
||||
})
|
||||
if err != nil {
|
||||
lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err)
|
||||
startError = err
|
||||
return
|
||||
}
|
||||
var clusterRemoteSrvListeners remotesrv.Listeners
|
||||
RunClusterRemoteSrv := &svcs.Service{
|
||||
Init: func(context.Context) error {
|
||||
if clusterController == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
clusterRemoteSrvTLSConfig, err := LoadClusterTLSConfig(serverConfig.ClusterConfig())
|
||||
if err != nil {
|
||||
lgr.Errorf("error starting remotesapi server for cluster config, could not load tls config: %v", err)
|
||||
startError = err
|
||||
return
|
||||
}
|
||||
args.TLSConfig = clusterRemoteSrvTLSConfig
|
||||
args, err := clusterController.RemoteSrvServerArgs(sqlEngine.NewDefaultContext, remotesrv.ServerArgs{
|
||||
Logger: logrus.NewEntry(lgr),
|
||||
})
|
||||
if err != nil {
|
||||
lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
clusterRemoteSrv, err = remotesrv.NewServer(args)
|
||||
if err != nil {
|
||||
lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err)
|
||||
startError = err
|
||||
return
|
||||
}
|
||||
clusterController.RegisterGrpcServices(sqlEngine.NewDefaultContext, clusterRemoteSrv.GrpcServer())
|
||||
clusterRemoteSrvTLSConfig, err := LoadClusterTLSConfig(serverConfig.ClusterConfig())
|
||||
if err != nil {
|
||||
lgr.Errorf("error starting remotesapi server for cluster config, could not load tls config: %v", err)
|
||||
return err
|
||||
}
|
||||
args.TLSConfig = clusterRemoteSrvTLSConfig
|
||||
|
||||
listeners, err := clusterRemoteSrv.Listeners()
|
||||
if err != nil {
|
||||
lgr.Errorf("error starting remotesapi server listeners for cluster config on %s: %v", clusterController.RemoteSrvListenAddr(), err)
|
||||
startError = err
|
||||
return
|
||||
}
|
||||
clusterRemoteSrv, err = remotesrv.NewServer(args)
|
||||
if err != nil {
|
||||
lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err)
|
||||
return err
|
||||
}
|
||||
clusterController.RegisterGrpcServices(sqlEngine.NewDefaultContext, clusterRemoteSrv.GrpcServer())
|
||||
|
||||
go clusterRemoteSrv.Serve(listeners)
|
||||
go clusterController.Run()
|
||||
|
||||
clusterController.ManageQueryConnections(
|
||||
mySQLServer.SessionManager().Iter,
|
||||
sqlEngine.GetUnderlyingEngine().ProcessList.Kill,
|
||||
mySQLServer.SessionManager().KillConnection,
|
||||
)
|
||||
}
|
||||
|
||||
if ok, f := mrEnv.IsLocked(); ok {
|
||||
startError = env.ErrActiveServerLock.New(f)
|
||||
return
|
||||
}
|
||||
|
||||
if err = mrEnv.Lock(serverLock); err != nil {
|
||||
startError = err
|
||||
return
|
||||
}
|
||||
|
||||
serverController.registerCloseFunction(startError, func() error {
|
||||
if metSrv != nil {
|
||||
metSrv.Close()
|
||||
}
|
||||
if remoteSrv != nil {
|
||||
remoteSrv.GracefulStop()
|
||||
}
|
||||
if clusterRemoteSrv != nil {
|
||||
clusterRemoteSrvListeners, err = clusterRemoteSrv.Listeners()
|
||||
if err != nil {
|
||||
lgr.Errorf("error starting remotesapi server listeners for cluster config on %s: %v", clusterController.RemoteSrvListenAddr(), err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(context.Context) {
|
||||
if clusterRemoteSrv == nil {
|
||||
return
|
||||
}
|
||||
clusterRemoteSrv.Serve(clusterRemoteSrvListeners)
|
||||
},
|
||||
Stop: func() error {
|
||||
if clusterRemoteSrv == nil {
|
||||
return nil
|
||||
}
|
||||
clusterRemoteSrv.GracefulStop()
|
||||
}
|
||||
if clusterController != nil {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
controller.Register(RunClusterRemoteSrv)
|
||||
|
||||
RunClusterController := &svcs.Service{
|
||||
Init: func(context.Context) error {
|
||||
if clusterController == nil {
|
||||
return nil
|
||||
}
|
||||
clusterController.ManageQueryConnections(
|
||||
mySQLServer.SessionManager().Iter,
|
||||
sqlEngine.GetUnderlyingEngine().ProcessList.Kill,
|
||||
mySQLServer.SessionManager().KillConnection,
|
||||
)
|
||||
return nil
|
||||
},
|
||||
Run: func(context.Context) {
|
||||
if clusterController == nil {
|
||||
return
|
||||
}
|
||||
clusterController.Run()
|
||||
},
|
||||
Stop: func() error {
|
||||
if clusterController == nil {
|
||||
return nil
|
||||
}
|
||||
clusterController.GracefulStop()
|
||||
}
|
||||
|
||||
return mySQLServer.Close()
|
||||
})
|
||||
|
||||
closeError = mySQLServer.Start()
|
||||
if closeError != nil {
|
||||
cli.PrintErr(closeError)
|
||||
}
|
||||
if err := mrEnv.Unlock(); err != nil {
|
||||
cli.PrintErr(err)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
controller.Register(RunClusterController)
|
||||
|
||||
return
|
||||
RunSQLServer := &svcs.Service{
|
||||
Run: func(context.Context) {
|
||||
sqlserver.SetRunningServer(mySQLServer, serverLock)
|
||||
defer sqlserver.UnsetRunningServer()
|
||||
mySQLServer.Start()
|
||||
},
|
||||
Stop: func() error {
|
||||
return mySQLServer.Close()
|
||||
},
|
||||
}
|
||||
controller.Register(RunSQLServer)
|
||||
|
||||
go controller.Start(ctx)
|
||||
err := controller.WaitForStart()
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
return nil, controller.WaitForStop()
|
||||
}
|
||||
|
||||
// acquireGlobalSqlServerLock attempts to acquire a global lock on the SQL server. If no error is returned, then the lock was acquired.
|
||||
@@ -514,10 +650,10 @@ func newSessionBuilder(se *engine.SqlEngine, config ServerConfig) server.Session
|
||||
}
|
||||
|
||||
// getConfigFromServerConfig processes ServerConfig and returns server.Config for sql-server.
|
||||
func getConfigFromServerConfig(serverConfig ServerConfig) (server.Config, error, error) {
|
||||
func getConfigFromServerConfig(serverConfig ServerConfig) (server.Config, error) {
|
||||
serverConf, err := handleProtocolAndAddress(serverConfig)
|
||||
if err != nil {
|
||||
return server.Config{}, err, nil
|
||||
return server.Config{}, err
|
||||
}
|
||||
|
||||
serverConf.DisableClientMultiStatements = serverConfig.DisableClientMultiStatements()
|
||||
@@ -527,7 +663,7 @@ func getConfigFromServerConfig(serverConfig ServerConfig) (server.Config, error,
|
||||
|
||||
tlsConfig, err := LoadTLSConfig(serverConfig)
|
||||
if err != nil {
|
||||
return server.Config{}, nil, err
|
||||
return server.Config{}, err
|
||||
}
|
||||
|
||||
// if persist is 'load' we use currently set persisted global variable,
|
||||
@@ -535,12 +671,12 @@ func getConfigFromServerConfig(serverConfig ServerConfig) (server.Config, error,
|
||||
if serverConfig.PersistenceBehavior() == loadPerisistentGlobals {
|
||||
serverConf, err = serverConf.NewConfig()
|
||||
if err != nil {
|
||||
return server.Config{}, err, nil
|
||||
return server.Config{}, err
|
||||
}
|
||||
} else {
|
||||
err = sql.SystemVariables.SetGlobal("max_connections", serverConfig.MaxConnections())
|
||||
if err != nil {
|
||||
return server.Config{}, err, nil
|
||||
return server.Config{}, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -554,7 +690,7 @@ func getConfigFromServerConfig(serverConfig ServerConfig) (server.Config, error,
|
||||
serverConf.MaxLoggedQueryLen = serverConfig.MaxLoggedQueryLen()
|
||||
serverConf.EncodeLoggedQuery = serverConfig.ShouldEncodeLoggedQuery()
|
||||
|
||||
return serverConf, nil, nil
|
||||
return serverConf, nil
|
||||
}
|
||||
|
||||
// handleProtocolAndAddress returns new server.Config object with only Protocol and Address defined.
|
||||
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/config"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/svcs"
|
||||
)
|
||||
|
||||
//TODO: server tests need to expose a higher granularity for server interactions:
|
||||
@@ -60,7 +61,7 @@ var (
|
||||
)
|
||||
|
||||
func TestServerArgs(t *testing.T) {
|
||||
serverController := NewServerController()
|
||||
controller := svcs.NewController()
|
||||
dEnv, err := sqle.CreateEnvWithSeedData()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
@@ -75,16 +76,16 @@ func TestServerArgs(t *testing.T) {
|
||||
"-t", "5",
|
||||
"-l", "info",
|
||||
"-r",
|
||||
}, dEnv, serverController)
|
||||
}, dEnv, controller)
|
||||
}()
|
||||
err = serverController.WaitForStart()
|
||||
err = controller.WaitForStart()
|
||||
require.NoError(t, err)
|
||||
conn, err := dbr.Open("mysql", "username:password@tcp(localhost:15200)/", nil)
|
||||
require.NoError(t, err)
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
serverController.StopServer()
|
||||
err = serverController.WaitForClose()
|
||||
controller.Stop()
|
||||
err = controller.WaitForStop()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -110,22 +111,22 @@ listener:
|
||||
defer func() {
|
||||
assert.NoError(t, dEnv.DoltDB.Close())
|
||||
}()
|
||||
serverController := NewServerController()
|
||||
controller := svcs.NewController()
|
||||
go func() {
|
||||
|
||||
dEnv.FS.WriteFile("config.yaml", []byte(yamlConfig), os.ModePerm)
|
||||
startServer(context.Background(), "0.0.0", "dolt sql-server", []string{
|
||||
"--config", "config.yaml",
|
||||
}, dEnv, serverController)
|
||||
}, dEnv, controller)
|
||||
}()
|
||||
err = serverController.WaitForStart()
|
||||
err = controller.WaitForStart()
|
||||
require.NoError(t, err)
|
||||
conn, err := dbr.Open("mysql", "username:password@tcp(localhost:15200)/", nil)
|
||||
require.NoError(t, err)
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
serverController.StopServer()
|
||||
err = serverController.WaitForClose()
|
||||
controller.Stop()
|
||||
err = controller.WaitForStop()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -145,18 +146,15 @@ func TestServerBadArgs(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(strings.Join(test, " "), func(t *testing.T) {
|
||||
serverController := NewServerController()
|
||||
go func(serverController *ServerController) {
|
||||
startServer(context.Background(), "test", "dolt sql-server", test, env, serverController)
|
||||
}(serverController)
|
||||
|
||||
// In the event that a test fails, we need to prevent a test from hanging due to a running server
|
||||
err := serverController.WaitForStart()
|
||||
require.Error(t, err)
|
||||
serverController.StopServer()
|
||||
err = serverController.WaitForClose()
|
||||
assert.NoError(t, err)
|
||||
controller := svcs.NewController()
|
||||
go func() {
|
||||
startServer(context.Background(), "test", "dolt sql-server", test, env, controller)
|
||||
}()
|
||||
if !assert.Error(t, controller.WaitForStart()) {
|
||||
controller.Stop()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -186,8 +184,8 @@ func TestServerGoodParams(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(ConfigInfo(test), func(t *testing.T) {
|
||||
sc := NewServerController()
|
||||
go func(config ServerConfig, sc *ServerController) {
|
||||
sc := svcs.NewController()
|
||||
go func(config ServerConfig, sc *svcs.Controller) {
|
||||
_, _ = Serve(context.Background(), "0.0.0", config, sc, env)
|
||||
}(test, sc)
|
||||
err := sc.WaitForStart()
|
||||
@@ -196,8 +194,8 @@ func TestServerGoodParams(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
sc.StopServer()
|
||||
err = sc.WaitForClose()
|
||||
sc.Stop()
|
||||
err = sc.WaitForStop()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@@ -212,8 +210,8 @@ func TestServerSelect(t *testing.T) {
|
||||
|
||||
serverConfig := DefaultServerConfig().withLogLevel(LogLevel_Fatal).WithPort(15300)
|
||||
|
||||
sc := NewServerController()
|
||||
defer sc.StopServer()
|
||||
sc := svcs.NewController()
|
||||
defer sc.Stop()
|
||||
go func() {
|
||||
_, _ = Serve(context.Background(), "0.0.0", serverConfig, sc, env)
|
||||
}()
|
||||
@@ -261,7 +259,7 @@ func TestServerSelect(t *testing.T) {
|
||||
|
||||
// If a port is already in use, throw error "Port XXXX already in use."
|
||||
func TestServerFailsIfPortInUse(t *testing.T) {
|
||||
serverController := NewServerController()
|
||||
controller := svcs.NewController()
|
||||
server := &http.Server{
|
||||
Addr: ":15200",
|
||||
Handler: http.DefaultServeMux,
|
||||
@@ -287,10 +285,10 @@ func TestServerFailsIfPortInUse(t *testing.T) {
|
||||
"-t", "5",
|
||||
"-l", "info",
|
||||
"-r",
|
||||
}, dEnv, serverController)
|
||||
}, dEnv, controller)
|
||||
}()
|
||||
|
||||
err = serverController.WaitForStart()
|
||||
err = controller.WaitForStart()
|
||||
require.Error(t, err)
|
||||
server.Close()
|
||||
wg.Wait()
|
||||
@@ -311,8 +309,8 @@ func TestServerSetDefaultBranch(t *testing.T) {
|
||||
|
||||
serverConfig := DefaultServerConfig().withLogLevel(LogLevel_Fatal).WithPort(15302)
|
||||
|
||||
sc := NewServerController()
|
||||
defer sc.StopServer()
|
||||
sc := svcs.NewController()
|
||||
defer sc.Stop()
|
||||
go func() {
|
||||
_, _ = Serve(context.Background(), "0.0.0", serverConfig, sc, dEnv)
|
||||
}()
|
||||
@@ -470,7 +468,7 @@ func TestReadReplica(t *testing.T) {
|
||||
dsess.InitPersistedSystemVars(multiSetup.GetEnv(readReplicaDbName))
|
||||
|
||||
// start server as read replica
|
||||
sc := NewServerController()
|
||||
sc := svcs.NewController()
|
||||
serverConfig := DefaultServerConfig().withLogLevel(LogLevel_Fatal).WithPort(15303)
|
||||
|
||||
// set socket to nil to force tcp
|
||||
@@ -482,7 +480,7 @@ func TestReadReplica(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
require.NoError(t, sc.WaitForStart())
|
||||
defer sc.StopServer()
|
||||
defer sc.Stop()
|
||||
|
||||
replicatedTable := "new_table"
|
||||
multiSetup.CreateTable(ctx, sourceDbName, replicatedTable)
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
// Copyright 2019 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlserver
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type ServerController struct {
|
||||
startCh chan struct{}
|
||||
closeCh chan struct{}
|
||||
closeCalled *sync.Once
|
||||
closeRegistered *sync.Once
|
||||
stopRegistered *sync.Once
|
||||
closeFunction func() error
|
||||
startError error
|
||||
closeError error
|
||||
}
|
||||
|
||||
// NewServerController creates a `ServerController` for use with synchronizing on `Serve`.
|
||||
func NewServerController() *ServerController {
|
||||
return &ServerController{
|
||||
startCh: make(chan struct{}),
|
||||
closeCh: make(chan struct{}),
|
||||
closeCalled: &sync.Once{},
|
||||
closeRegistered: &sync.Once{},
|
||||
stopRegistered: &sync.Once{},
|
||||
}
|
||||
}
|
||||
|
||||
// registerCloseFunction is called within `Serve` to associate the close function with a future `StopServer` call.
|
||||
// Only the first call will register and unblock, thus it is safe to be called multiple times.
|
||||
func (controller *ServerController) registerCloseFunction(startError error, closeFunc func() error) {
|
||||
controller.closeRegistered.Do(func() {
|
||||
if startError != nil {
|
||||
controller.startError = startError
|
||||
}
|
||||
controller.closeFunction = closeFunc
|
||||
close(controller.startCh)
|
||||
})
|
||||
}
|
||||
|
||||
// serverStopped is called within `Serve` to signal that the server has stopped and set the exit code.
|
||||
// Only the first call will register and unblock, thus it is safe to be called multiple times.
|
||||
func (controller *ServerController) serverStopped(closeError error) {
|
||||
controller.stopRegistered.Do(func() {
|
||||
if closeError != nil {
|
||||
controller.closeError = closeError
|
||||
}
|
||||
close(controller.closeCh)
|
||||
})
|
||||
}
|
||||
|
||||
// StopServer stops the server if it is running. Only the first call will trigger the stop, thus it is safe for
|
||||
// multiple goroutines to call this function.
|
||||
func (controller *ServerController) StopServer() {
|
||||
if controller.closeFunction != nil {
|
||||
controller.closeCalled.Do(func() {
|
||||
if err := controller.closeFunction(); err != nil {
|
||||
controller.closeError = err
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForClose blocks the caller until the server has closed. The return is the last error encountered, if any.
|
||||
func (controller *ServerController) WaitForClose() error {
|
||||
select {
|
||||
case <-controller.closeCh:
|
||||
break
|
||||
}
|
||||
|
||||
return controller.closeError
|
||||
}
|
||||
|
||||
// WaitForStart blocks the caller until the server has started. An error is returned if one was encountered.
|
||||
func (controller *ServerController) WaitForStart() error {
|
||||
select {
|
||||
case <-controller.startCh:
|
||||
break
|
||||
case <-controller.closeCh:
|
||||
break
|
||||
}
|
||||
|
||||
return controller.startError
|
||||
}
|
||||
@@ -41,6 +41,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/utils/argparser"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/filesys"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/iohelp"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/svcs"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -106,7 +107,7 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri
|
||||
|
||||
apr := cli.ParseArgsOrDie(ap, args, help)
|
||||
var serverConfig ServerConfig
|
||||
var serverController *ServerController
|
||||
var svcsController *svcs.Controller
|
||||
var err error
|
||||
|
||||
cli.Println(color.YellowString("WARNING: This command is being deprecated and is not recommended for general use.\n" +
|
||||
@@ -148,11 +149,11 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri
|
||||
}
|
||||
cli.PrintErrf("Starting server with Config %v\n", ConfigInfo(serverConfig))
|
||||
|
||||
serverController = NewServerController()
|
||||
svcsController = svcs.NewController()
|
||||
go func() {
|
||||
_, _ = Serve(ctx, cmd.VersionStr, serverConfig, serverController, dEnv)
|
||||
_, _ = Serve(ctx, cmd.VersionStr, serverConfig, svcsController, dEnv)
|
||||
}()
|
||||
err = serverController.WaitForStart()
|
||||
err = svcsController.WaitForStart()
|
||||
if err != nil {
|
||||
cli.PrintErrln(err.Error())
|
||||
return 1
|
||||
@@ -375,8 +376,8 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri
|
||||
cli.PrintErrln(err.Error())
|
||||
}
|
||||
if apr.Contains(sqlClientDualFlag) {
|
||||
serverController.StopServer()
|
||||
err = serverController.WaitForClose()
|
||||
svcsController.Stop()
|
||||
err = svcsController.WaitForStop()
|
||||
if err != nil {
|
||||
cli.PrintErrln(err.Error())
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/argparser"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/filesys"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/svcs"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -186,12 +187,20 @@ func (cmd SqlServerCmd) RequiresRepo() bool {
|
||||
|
||||
// Exec executes the command
|
||||
func (cmd SqlServerCmd) Exec(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEnv, cliCtx cli.CliContext) int {
|
||||
controller := NewServerController()
|
||||
controller := svcs.NewController()
|
||||
newCtx, cancelF := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
controller.StopServer()
|
||||
cancelF()
|
||||
// Here we only forward along the SIGINT if the server starts
|
||||
// up successfully. If the service does not start up
|
||||
// successfully, or if WaitForStart() blocks indefinitely, then
|
||||
// startServer() should have returned an error and we do not
|
||||
// need to Stop the running server or deal with our canceled
|
||||
// parent context.
|
||||
if controller.WaitForStart() == nil {
|
||||
<-ctx.Done()
|
||||
controller.Stop()
|
||||
cancelF()
|
||||
}
|
||||
}()
|
||||
return startServer(newCtx, cmd.VersionStr, commandStr, args, dEnv, controller)
|
||||
}
|
||||
@@ -208,7 +217,7 @@ func validateSqlServerArgs(apr *argparser.ArgParseResults) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func startServer(ctx context.Context, versionStr, commandStr string, args []string, dEnv *env.DoltEnv, serverController *ServerController) int {
|
||||
func startServer(ctx context.Context, versionStr, commandStr string, args []string, dEnv *env.DoltEnv, controller *svcs.Controller) int {
|
||||
ap := SqlServerCmd{}.ArgParser()
|
||||
help, _ := cli.HelpAndUsagePrinters(cli.CommandDocsForCommandString(commandStr, sqlServerDocs, ap))
|
||||
|
||||
@@ -222,21 +231,11 @@ func startServer(ctx context.Context, versionStr, commandStr string, args []stri
|
||||
}
|
||||
serverConfig, err := GetServerConfig(dEnv.FS, apr)
|
||||
if err != nil {
|
||||
if serverController != nil {
|
||||
serverController.StopServer()
|
||||
serverController.serverStopped(err)
|
||||
}
|
||||
|
||||
cli.PrintErrln(color.RedString("Failed to start server. Bad Configuration"))
|
||||
cli.PrintErrln(err.Error())
|
||||
return 1
|
||||
}
|
||||
if err = SetupDoltConfig(dEnv, apr, serverConfig); err != nil {
|
||||
if serverController != nil {
|
||||
serverController.StopServer()
|
||||
serverController.serverStopped(err)
|
||||
}
|
||||
|
||||
cli.PrintErrln(color.RedString("Failed to start server. Bad Configuration"))
|
||||
cli.PrintErrln(err.Error())
|
||||
return 1
|
||||
@@ -244,7 +243,7 @@ func startServer(ctx context.Context, versionStr, commandStr string, args []stri
|
||||
|
||||
cli.PrintErrf("Starting server with Config %v\n", ConfigInfo(serverConfig))
|
||||
|
||||
if startError, closeError := Serve(ctx, versionStr, serverConfig, serverController, dEnv); startError != nil || closeError != nil {
|
||||
if startError, closeError := Serve(ctx, versionStr, serverConfig, controller, dEnv); startError != nil || closeError != nil {
|
||||
if startError != nil {
|
||||
cli.PrintErrln(startError)
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/cmd/dolt/commands/sqlserver"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/dtestutils"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/svcs"
|
||||
)
|
||||
|
||||
// DoltBranchMultiSessionScriptTests contain tests that need to be run in a multi-session server environment
|
||||
@@ -539,8 +540,8 @@ func testMultiSessionScriptTests(t *testing.T, tests []queries.ScriptTest) {
|
||||
require.NoError(t, conn1.Close())
|
||||
require.NoError(t, conn2.Close())
|
||||
|
||||
sc.StopServer()
|
||||
err = sc.WaitForClose()
|
||||
sc.Stop()
|
||||
err = sc.WaitForStop()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@@ -595,8 +596,8 @@ func testSerialSessionScriptTests(t *testing.T, tests []queries.ScriptTest) {
|
||||
|
||||
require.NoError(t, conn1.Close())
|
||||
|
||||
sc.StopServer()
|
||||
err = sc.WaitForClose()
|
||||
sc.Stop()
|
||||
err = sc.WaitForStop()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@@ -657,7 +658,7 @@ func assertResultsEqual(t *testing.T, expected []sql.Row, rows *gosql.Rows) {
|
||||
}
|
||||
|
||||
// startServer will start sql-server with given host, unix socket file path and whether to use specific port, which is defined randomly.
|
||||
func startServer(t *testing.T, withPort bool, host string, unixSocketPath string) (*env.DoltEnv, *sqlserver.ServerController, sqlserver.ServerConfig) {
|
||||
func startServer(t *testing.T, withPort bool, host string, unixSocketPath string) (*env.DoltEnv, *svcs.Controller, sqlserver.ServerConfig) {
|
||||
dEnv := dtestutils.CreateTestEnv()
|
||||
serverConfig := sqlserver.DefaultServerConfig()
|
||||
if withPort {
|
||||
@@ -676,8 +677,8 @@ func startServer(t *testing.T, withPort bool, host string, unixSocketPath string
|
||||
return dEnv, onEnv, config
|
||||
}
|
||||
|
||||
func startServerOnEnv(t *testing.T, serverConfig sqlserver.ServerConfig, dEnv *env.DoltEnv) (*sqlserver.ServerController, sqlserver.ServerConfig) {
|
||||
sc := sqlserver.NewServerController()
|
||||
func startServerOnEnv(t *testing.T, serverConfig sqlserver.ServerConfig, dEnv *env.DoltEnv) (*svcs.Controller, sqlserver.ServerConfig) {
|
||||
sc := svcs.NewController()
|
||||
go func() {
|
||||
_, _ = sqlserver.Serve(context.Background(), "0.0.0", serverConfig, sc, dEnv)
|
||||
}()
|
||||
@@ -745,8 +746,8 @@ func TestDoltServerRunningUnixSocket(t *testing.T) {
|
||||
require.NoError(t, localConn.Close())
|
||||
|
||||
// Stopping unix socket server
|
||||
sc.StopServer()
|
||||
err = sc.WaitForClose()
|
||||
sc.Stop()
|
||||
err = sc.WaitForStop()
|
||||
require.NoError(t, err)
|
||||
require.NoFileExists(t, defaultUnixSocketPath)
|
||||
|
||||
@@ -773,7 +774,7 @@ func TestDoltServerRunningUnixSocket(t *testing.T) {
|
||||
})
|
||||
|
||||
// Stopping TCP socket server
|
||||
tcpSc.StopServer()
|
||||
err = tcpSc.WaitForClose()
|
||||
tcpSc.Stop()
|
||||
err = tcpSc.WaitForStop()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
189
go/libraries/utils/svcs/controller.go
Normal file
189
go/libraries/utils/svcs/controller.go
Normal file
@@ -0,0 +1,189 @@
|
||||
// Copyright 2023 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package svcs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// A Service is a runnable unit of functionality that a Controller can
|
||||
// take responsibility for. It has an |Init| function, which can error, and
|
||||
// which should do all of the initialization and validation work necessary to
|
||||
// bring the service up. It has a |Run| function, which will be called in a
|
||||
// separate go-routine and should run and provide the functionality associated
|
||||
// with the service until the |Stop| function is called.
|
||||
type Service struct {
|
||||
Init func(context.Context) error
|
||||
Run func(context.Context)
|
||||
Stop func() error
|
||||
}
|
||||
|
||||
// A Controller is responsible for initializing a number of registered
|
||||
// services, running them all, and stopping them all when requested. Services
|
||||
// are registered with |Register(*Service)|. When |Start| is called, the
|
||||
// services are all initialized, in the order of their registration, and if
|
||||
// every service initializes successfully, they are |Run| concurrently. When
|
||||
// |Stop| is called, services are stopped in reverse-registration order. |Stop|
|
||||
// does not block for the goroutines spawned by |Start| to complete, although
|
||||
// typically a Service's |Stop| function should do that. |Stop| only returns an
|
||||
// error if the Controller is in an illegal state where it is not valid to Stop
|
||||
// it. In particular, it does not return an error seen by a Service on Stop().
|
||||
// That error is returned from Start() and from WaitForStop().
|
||||
//
|
||||
// Any attempt to register a service after |Start| or |Stop| has been called
|
||||
// will return an error.
|
||||
//
|
||||
// If an error occurs when initializing the services of a Controller, the
|
||||
// Stop functions of any already initialized Services are called in
|
||||
// reverse-order. The error which caused the initialization error is returned.
|
||||
//
|
||||
// In the case that all Services Init successfully, the error returned from
|
||||
// |Start| is the first non-nil error which is returned from the |Stop|
|
||||
// functions, in the order they are called.
|
||||
//
|
||||
// WaitForStart() can be called at any time on a Controller. It will
|
||||
// block until |Start| is called. After |Start| is called, if all the services
|
||||
// succesfully initialize, it will return |nil|. Otherwise it will return the
|
||||
// same error |Start| returned.
|
||||
//
|
||||
// WaitForStop() can be called at any time on a Controller. It will block
|
||||
// until |Start| is called and initialization fails, or until |Stop| is called.
|
||||
// It will return the same error which |Start| returned.
|
||||
type Controller struct {
|
||||
mu sync.Mutex
|
||||
services []*Service
|
||||
initErr error
|
||||
stopErr error
|
||||
started bool
|
||||
startCh chan struct{}
|
||||
stopped bool
|
||||
stopCh chan struct{}
|
||||
stoppedCh chan struct{}
|
||||
}
|
||||
|
||||
func NewController() *Controller {
|
||||
return &Controller{
|
||||
startCh: make(chan struct{}),
|
||||
stopCh: make(chan struct{}),
|
||||
stoppedCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) WaitForStart() error {
|
||||
<-c.startCh
|
||||
c.mu.Lock()
|
||||
err := c.initErr
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Controller) WaitForStop() error {
|
||||
<-c.stoppedCh
|
||||
c.mu.Lock()
|
||||
var err error
|
||||
if c.initErr != nil {
|
||||
err = c.initErr
|
||||
} else if c.stopErr != nil {
|
||||
err = c.stopErr
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Controller) Register(svc *Service) error {
|
||||
c.mu.Lock()
|
||||
if c.started {
|
||||
c.mu.Unlock()
|
||||
return errors.New("Controller: cannot Register a service on a controller which was already started")
|
||||
}
|
||||
c.services = append(c.services, svc)
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) Stop() error {
|
||||
c.mu.Lock()
|
||||
if !c.started {
|
||||
c.mu.Unlock()
|
||||
return errors.New("Controller: cannot Stop a controller which was never started")
|
||||
}
|
||||
if c.stopped {
|
||||
c.mu.Unlock()
|
||||
return errors.New("Controller: cannot Stop a controller which was already stopped or which failed to initialize all its services")
|
||||
}
|
||||
c.stopped = true
|
||||
close(c.stopCh)
|
||||
c.mu.Unlock()
|
||||
<-c.stoppedCh
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) Start(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
if c.started {
|
||||
return errors.New("Controller: cannot start service controller twice")
|
||||
}
|
||||
c.started = true
|
||||
svcs := make([]*Service, len(c.services))
|
||||
copy(svcs, c.services)
|
||||
c.mu.Unlock()
|
||||
for i, s := range svcs {
|
||||
if s.Init == nil {
|
||||
continue
|
||||
}
|
||||
err := s.Init(ctx)
|
||||
if err != nil {
|
||||
for j := i - 1; j >= 0; j-- {
|
||||
if svcs[j].Stop != nil {
|
||||
svcs[j].Stop()
|
||||
}
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.stopped = true
|
||||
c.initErr = err
|
||||
close(c.startCh)
|
||||
close(c.stoppedCh)
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
}
|
||||
close(c.startCh)
|
||||
for _, s := range svcs {
|
||||
if s.Run == nil {
|
||||
continue
|
||||
}
|
||||
go s.Run(ctx)
|
||||
}
|
||||
<-c.stopCh
|
||||
var stopErr error
|
||||
for i := len(svcs) - 1; i >= 0; i-- {
|
||||
if svcs[i].Stop == nil {
|
||||
continue
|
||||
}
|
||||
err := svcs[i].Stop()
|
||||
if err != nil && stopErr == nil {
|
||||
stopErr = err
|
||||
}
|
||||
}
|
||||
c.mu.Lock()
|
||||
if stopErr != nil {
|
||||
c.stopErr = stopErr
|
||||
}
|
||||
close(c.stoppedCh)
|
||||
c.mu.Unlock()
|
||||
return stopErr
|
||||
}
|
||||
305
go/libraries/utils/svcs/controller_test.go
Normal file
305
go/libraries/utils/svcs/controller_test.go
Normal file
@@ -0,0 +1,305 @@
|
||||
// Copyright 2023 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package svcs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestController(t *testing.T) {
|
||||
t.Run("NewController", func(t *testing.T) {
|
||||
c := NewController()
|
||||
require.NotNil(t, c)
|
||||
})
|
||||
t.Run("Stop", func(t *testing.T) {
|
||||
t.Run("CalledBeforeStart", func(t *testing.T) {
|
||||
c := NewController()
|
||||
require.Error(t, c.Stop())
|
||||
})
|
||||
t.Run("ReturnsFirstError", func(t *testing.T) {
|
||||
c := NewController()
|
||||
ctx := context.Background()
|
||||
err := errors.New("first")
|
||||
require.NoError(t, c.Register(&Service{
|
||||
Init: func(context.Context) error { return nil },
|
||||
Run: func(context.Context) {},
|
||||
Stop: func() error { return errors.New("second") },
|
||||
}))
|
||||
require.NoError(t, c.Register(&Service{
|
||||
Init: func(context.Context) error { return nil },
|
||||
Run: func(context.Context) {},
|
||||
Stop: func() error { return err },
|
||||
}))
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
require.NoError(t, c.WaitForStart())
|
||||
require.NoError(t, c.Stop())
|
||||
}()
|
||||
require.ErrorIs(t, c.Start(ctx), err)
|
||||
require.ErrorIs(t, c.WaitForStop(), err)
|
||||
wg.Wait()
|
||||
})
|
||||
})
|
||||
t.Run("EmptyServices", func(t *testing.T) {
|
||||
c := NewController()
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
require.NoError(t, c.WaitForStart())
|
||||
require.NoError(t, c.Stop())
|
||||
}()
|
||||
require.NoError(t, c.Start(ctx))
|
||||
require.NoError(t, c.WaitForStop())
|
||||
wg.Wait()
|
||||
})
|
||||
t.Run("Register", func(t *testing.T) {
|
||||
t.Run("AfterStartCalled", func(t *testing.T) {
|
||||
c := NewController()
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
require.NoError(t, c.WaitForStart())
|
||||
require.Error(t, c.Register(&Service{
|
||||
Init: func(context.Context) error { return nil },
|
||||
Run: func(context.Context) {},
|
||||
Stop: func() error { return nil },
|
||||
}))
|
||||
require.NoError(t, c.Stop())
|
||||
}()
|
||||
require.NoError(t, c.Start(ctx))
|
||||
require.NoError(t, c.WaitForStop())
|
||||
wg.Wait()
|
||||
})
|
||||
})
|
||||
t.Run("Start", func(t *testing.T) {
|
||||
t.Run("CallsInitInOrder", func(t *testing.T) {
|
||||
c := NewController()
|
||||
var inited []int
|
||||
require.NoError(t, c.Register(&Service{
|
||||
Init: func(context.Context) error {
|
||||
inited = append(inited, 0)
|
||||
return nil
|
||||
},
|
||||
Run: func(context.Context) {},
|
||||
Stop: func() error { return nil },
|
||||
}))
|
||||
require.NoError(t, c.Register(&Service{
|
||||
Init: func(context.Context) error {
|
||||
inited = append(inited, 1)
|
||||
return nil
|
||||
},
|
||||
Run: func(context.Context) {},
|
||||
Stop: func() error { return nil },
|
||||
}))
|
||||
require.NoError(t, c.Register(&Service{
|
||||
Init: func(context.Context) error {
|
||||
inited = append(inited, 2)
|
||||
return nil
|
||||
},
|
||||
Run: func(context.Context) {},
|
||||
Stop: func() error { return nil },
|
||||
}))
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
require.NoError(t, c.WaitForStart())
|
||||
require.NoError(t, c.Stop())
|
||||
}()
|
||||
require.NoError(t, c.Start(ctx))
|
||||
require.NoError(t, c.WaitForStop())
|
||||
require.Equal(t, inited, []int{0, 1, 2})
|
||||
wg.Wait()
|
||||
})
|
||||
t.Run("StopsCallingInitOnFirstError", func(t *testing.T) {
|
||||
err := errors.New("first error")
|
||||
c := NewController()
|
||||
var inited []int
|
||||
require.NoError(t, c.Register(&Service{
|
||||
Init: func(context.Context) error {
|
||||
inited = append(inited, 0)
|
||||
return nil
|
||||
},
|
||||
Run: func(context.Context) {},
|
||||
Stop: func() error { return nil },
|
||||
}))
|
||||
require.NoError(t, c.Register(&Service{
|
||||
Init: func(context.Context) error {
|
||||
inited = append(inited, 1)
|
||||
return nil
|
||||
},
|
||||
Run: func(context.Context) {},
|
||||
Stop: func() error { return nil },
|
||||
}))
|
||||
require.NoError(t, c.Register(&Service{
|
||||
Init: func(context.Context) error {
|
||||
return err
|
||||
},
|
||||
Run: func(context.Context) {},
|
||||
Stop: func() error { return nil },
|
||||
}))
|
||||
require.NoError(t, c.Register(&Service{
|
||||
Init: func(context.Context) error {
|
||||
inited = append(inited, 2)
|
||||
return nil
|
||||
},
|
||||
Run: func(context.Context) {},
|
||||
Stop: func() error { return nil },
|
||||
}))
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
require.ErrorIs(t, c.WaitForStart(), err)
|
||||
require.NotErrorIs(t, c.Stop(), err)
|
||||
}()
|
||||
require.ErrorIs(t, c.Start(ctx), err)
|
||||
require.ErrorIs(t, c.WaitForStop(), err)
|
||||
require.Equal(t, inited, []int{0, 1})
|
||||
wg.Wait()
|
||||
})
|
||||
t.Run("CallsStopWhenInitErrors", func(t *testing.T) {
|
||||
err := errors.New("first error")
|
||||
c := NewController()
|
||||
var stopped []int
|
||||
require.NoError(t, c.Register(&Service{
|
||||
Init: func(context.Context) error {
|
||||
return nil
|
||||
},
|
||||
Run: func(context.Context) {},
|
||||
Stop: func() error {
|
||||
stopped = append(stopped, 0)
|
||||
return nil
|
||||
},
|
||||
}))
|
||||
require.NoError(t, c.Register(&Service{
|
||||
Init: func(context.Context) error {
|
||||
return nil
|
||||
},
|
||||
Run: func(context.Context) {},
|
||||
Stop: func() error {
|
||||
stopped = append(stopped, 1)
|
||||
return nil
|
||||
},
|
||||
}))
|
||||
require.NoError(t, c.Register(&Service{
|
||||
Init: func(context.Context) error {
|
||||
return err
|
||||
},
|
||||
Run: func(context.Context) {},
|
||||
Stop: func() error {
|
||||
stopped = append(stopped, 2)
|
||||
return nil
|
||||
},
|
||||
}))
|
||||
require.NoError(t, c.Register(&Service{
|
||||
Init: func(context.Context) error {
|
||||
return nil
|
||||
},
|
||||
Run: func(context.Context) {},
|
||||
Stop: func() error {
|
||||
stopped = append(stopped, 3)
|
||||
return nil
|
||||
},
|
||||
}))
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
require.ErrorIs(t, c.WaitForStart(), err)
|
||||
require.NotErrorIs(t, c.Stop(), err)
|
||||
}()
|
||||
require.ErrorIs(t, c.Start(ctx), err)
|
||||
require.ErrorIs(t, c.WaitForStop(), err)
|
||||
require.Equal(t, stopped, []int{1, 0})
|
||||
wg.Wait()
|
||||
})
|
||||
t.Run("RunsServices", func(t *testing.T) {
|
||||
c := NewController()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
require.NoError(t, c.Register(&Service{
|
||||
Init: func(context.Context) error { return nil },
|
||||
Run: func(context.Context) { wg.Done() },
|
||||
Stop: func() error { return nil },
|
||||
}))
|
||||
require.NoError(t, c.Register(&Service{
|
||||
Init: func(context.Context) error { return nil },
|
||||
Run: func(context.Context) { wg.Done() },
|
||||
Stop: func() error { return nil },
|
||||
}))
|
||||
ctx := context.Background()
|
||||
var cwg sync.WaitGroup
|
||||
cwg.Add(1)
|
||||
go func() {
|
||||
defer cwg.Done()
|
||||
require.NoError(t, c.WaitForStart())
|
||||
require.NoError(t, c.Stop())
|
||||
}()
|
||||
require.NoError(t, c.Start(ctx))
|
||||
require.NoError(t, c.WaitForStop())
|
||||
wg.Wait()
|
||||
cwg.Wait()
|
||||
})
|
||||
t.Run("StopsAllServices", func(t *testing.T) {
|
||||
c := NewController()
|
||||
var wg sync.WaitGroup
|
||||
err := errors.New("first error")
|
||||
wg.Add(2)
|
||||
require.NoError(t, c.Register(&Service{
|
||||
Init: func(context.Context) error { return nil },
|
||||
Run: func(context.Context) {},
|
||||
Stop: func() error {
|
||||
wg.Done()
|
||||
return errors.New("second error")
|
||||
},
|
||||
}))
|
||||
require.NoError(t, c.Register(&Service{
|
||||
Init: func(context.Context) error { return nil },
|
||||
Run: func(context.Context) {},
|
||||
Stop: func() error {
|
||||
wg.Done()
|
||||
return err
|
||||
},
|
||||
}))
|
||||
ctx := context.Background()
|
||||
var cwg sync.WaitGroup
|
||||
cwg.Add(1)
|
||||
go func() {
|
||||
defer cwg.Done()
|
||||
require.NoError(t, c.WaitForStart())
|
||||
require.NoError(t, c.Stop())
|
||||
}()
|
||||
require.ErrorIs(t, c.Start(ctx), err)
|
||||
require.ErrorIs(t, c.WaitForStop(), err)
|
||||
wg.Wait()
|
||||
cwg.Wait()
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/dtestutils/testcommands"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/svcs"
|
||||
)
|
||||
|
||||
type query string
|
||||
@@ -163,13 +164,13 @@ func getProfFile(b *testing.B) *os.File {
|
||||
}
|
||||
|
||||
func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv, cfg srv.ServerConfig, queries []query) {
|
||||
serverController := srv.NewServerController()
|
||||
sc := svcs.NewController()
|
||||
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
//b.Logf("Starting server with Config %v\n", srv.ConfigInfo(cfg))
|
||||
eg.Go(func() (err error) {
|
||||
startErr, closeErr := srv.Serve(ctx, "", cfg, serverController, dEnv)
|
||||
startErr, closeErr := srv.Serve(ctx, "", cfg, sc, dEnv)
|
||||
if startErr != nil {
|
||||
return startErr
|
||||
}
|
||||
@@ -178,7 +179,7 @@ func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv,
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err := serverController.WaitForStart(); err != nil {
|
||||
if err := sc.WaitForStart(); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -188,8 +189,8 @@ func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv,
|
||||
}
|
||||
}
|
||||
|
||||
serverController.StopServer()
|
||||
if err := serverController.WaitForClose(); err != nil {
|
||||
sc.Stop()
|
||||
if err := sc.WaitForStop(); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if err := eg.Wait(); err != nil {
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/filesys"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/svcs"
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
@@ -175,13 +176,13 @@ func getProfFile(b *testing.B) *os.File {
|
||||
}
|
||||
|
||||
func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv, cfg srv.ServerConfig, queries []query) {
|
||||
serverController := srv.NewServerController()
|
||||
sc := svcs.NewController()
|
||||
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
//b.Logf("Starting server with Config %v\n", srv.ConfigInfo(cfg))
|
||||
eg.Go(func() (err error) {
|
||||
startErr, closeErr := srv.Serve(ctx, "", cfg, serverController, dEnv)
|
||||
startErr, closeErr := srv.Serve(ctx, "", cfg, sc, dEnv)
|
||||
if startErr != nil {
|
||||
return startErr
|
||||
}
|
||||
@@ -190,7 +191,7 @@ func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv,
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err := serverController.WaitForStart(); err != nil {
|
||||
if err := sc.WaitForStart(); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -200,8 +201,8 @@ func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv,
|
||||
}
|
||||
}
|
||||
|
||||
serverController.StopServer()
|
||||
if err := serverController.WaitForClose(); err != nil {
|
||||
sc.Stop()
|
||||
if err := sc.WaitForStop(); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if err := eg.Wait(); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user