diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index 33ca45494b..dd7c02ec01 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -48,6 +48,7 @@ import ( _ "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dfunctions" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" "github.com/dolthub/dolt/go/libraries/doltcore/sqlserver" + "github.com/dolthub/dolt/go/libraries/utils/svcs" ) const ( @@ -63,335 +64,470 @@ func Serve( ctx context.Context, version string, serverConfig ServerConfig, - serverController *ServerController, + controller *svcs.Controller, dEnv *env.DoltEnv, ) (startError error, closeError error) { // Code is easier to work through if we assume that serverController is never nil - if serverController == nil { - serverController = NewServerController() + if controller == nil { + controller = svcs.NewController() } - var mySQLServer *server.Server - // This guarantees unblocking on any routines with a waiting `ServerController` - defer func() { - if mySQLServer != nil { - serverController.registerCloseFunction(startError, mySQLServer.Close) - } else { - serverController.registerCloseFunction(startError, func() error { return nil }) - } - serverController.StopServer() - serverController.serverStopped(closeError) - sqlserver.UnsetRunningServer() - }() - - if startError = ValidateConfig(serverConfig); startError != nil { - return startError, nil + ValidateConfigStep := &svcs.Service{ + Init: func(context.Context) error { + return ValidateConfig(serverConfig) + }, } + controller.Register(ValidateConfigStep) lgr := logrus.StandardLogger() lgr.SetOutput(cli.CliErr) + InitLogging := &svcs.Service{ + Init: func(context.Context) error { + level, err := logrus.ParseLevel(serverConfig.LogLevel().String()) + if err != nil { + return err + } + logrus.SetLevel(level) - if serverConfig.LogLevel() != LogLevel_Info { - var level logrus.Level - level, startError = logrus.ParseLevel(serverConfig.LogLevel().String()) - if startError != nil { - cli.PrintErr(startError) - return - } - logrus.SetLevel(level) - } - logrus.SetFormatter(LogFormat{}) + sql.SystemVariables.AddSystemVariables([]sql.SystemVariable{ + { + Name: dsess.DoltLogLevel, + Scope: sql.SystemVariableScope_Global, + Dynamic: true, + SetVarHintApplies: false, + Type: types.NewSystemEnumType(dsess.DoltLogLevel, + logrus.PanicLevel.String(), + logrus.FatalLevel.String(), + logrus.ErrorLevel.String(), + logrus.WarnLevel.String(), + logrus.InfoLevel.String(), + logrus.DebugLevel.String(), + logrus.TraceLevel.String(), + ), + Default: logrus.GetLevel().String(), + NotifyChanged: func(scope sql.SystemVariableScope, v sql.SystemVarValue) error { + level, err := logrus.ParseLevel(v.Val.(string)) + if err != nil { + return fmt.Errorf("could not parse requested log level %s as a log level. dolt_log_level variable value and logging behavior will diverge.", v.Val.(string)) + } - sql.SystemVariables.AddSystemVariables([]sql.SystemVariable{ - { - Name: dsess.DoltLogLevel, - Scope: sql.SystemVariableScope_Global, - Dynamic: true, - SetVarHintApplies: false, - Type: types.NewSystemEnumType(dsess.DoltLogLevel, - logrus.PanicLevel.String(), - logrus.FatalLevel.String(), - logrus.ErrorLevel.String(), - logrus.WarnLevel.String(), - logrus.InfoLevel.String(), - logrus.DebugLevel.String(), - logrus.TraceLevel.String(), - ), - Default: logrus.GetLevel().String(), - NotifyChanged: func(scope sql.SystemVariableScope, v sql.SystemVarValue) error { - level, err := logrus.ParseLevel(v.Val.(string)) - if err != nil { - return fmt.Errorf("could not parse requested log level %s as a log level. dolt_log_level variable value and logging behavior will diverge.", v.Val.(string)) - } - - logrus.SetLevel(level) - return nil - }, + logrus.SetLevel(level) + return nil + }, + }, + }) + return nil }, - }) + } + controller.Register(InitLogging) + + fs := dEnv.FS + InitDataDir := &svcs.Service{ + Init: func(context.Context) error { + if len(serverConfig.DataDir()) > 0 && serverConfig.DataDir() != "." { + var err error + fs, err = dEnv.FS.WithWorkingDir(serverConfig.DataDir()) + if err != nil { + return err + } + dEnv.FS = fs + } + return nil + }, + } + controller.Register(InitDataDir) + + var serverLock *env.DBLock + InitGlobalServerLock := &svcs.Service{ + Init: func(context.Context) error { + var err error + serverLock, err = acquireGlobalSqlServerLock(serverConfig.Port(), dEnv) + return err + }, + Stop: func() error { + dEnv.FS.Delete(dEnv.LockFile(), false) + return nil + }, + } + controller.Register(InitGlobalServerLock) var mrEnv *env.MultiRepoEnv - var err error - fs := dEnv.FS - - if len(serverConfig.DataDir()) > 0 && serverConfig.DataDir() != "." { - fs, err = dEnv.FS.WithWorkingDir(serverConfig.DataDir()) - if err != nil { - return err, nil - } - dEnv.FS = fs + InitMultiEnv := &svcs.Service{ + Init: func(ctx context.Context) error { + var err error + mrEnv, err = env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), fs, dEnv.Version, dEnv.IgnoreLockFile, dEnv) + return err + }, } + controller.Register(InitMultiEnv) - serverLock, startError := acquireGlobalSqlServerLock(serverConfig.Port(), dEnv) - if startError != nil { - return + var clusterController *cluster.Controller + InitClusterController := &svcs.Service{ + Init: func(context.Context) error { + var err error + clusterController, err = cluster.NewController(lgr, serverConfig.ClusterConfig(), mrEnv.Config()) + return err + }, } - defer dEnv.FS.Delete(dEnv.LockFile(), false) + controller.Register(InitClusterController) - mrEnv, err = env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), fs, dEnv.Version, dEnv.IgnoreLockFile, dEnv) - if err != nil { - return err, nil - } - - clusterController, err := cluster.NewController(lgr, serverConfig.ClusterConfig(), mrEnv.Config()) - if err != nil { - return err, nil - } - - serverConf, sErr, cErr := getConfigFromServerConfig(serverConfig) - if cErr != nil { - return nil, cErr - } else if sErr != nil { - return sErr, nil + var serverConf server.Config + LoadServerConfig := &svcs.Service{ + Init: func(context.Context) error { + var err error + serverConf, err = getConfigFromServerConfig(serverConfig) + return err + }, } + controller.Register(LoadServerConfig) // Create SQL Engine with users - config := &engine.SqlEngineConfig{ - IsReadOnly: serverConfig.ReadOnly(), - PrivFilePath: serverConfig.PrivilegeFilePath(), - BranchCtrlFilePath: serverConfig.BranchControlFilePath(), - DoltCfgDirPath: serverConfig.CfgDir(), - ServerUser: serverConfig.User(), - ServerPass: serverConfig.Password(), - ServerHost: serverConfig.Host(), - Autocommit: serverConfig.AutoCommit(), - DoltTransactionCommit: serverConfig.DoltTransactionCommit(), - JwksConfig: serverConfig.JwksConfig(), - SystemVariables: serverConfig.SystemVars(), - ClusterController: clusterController, - BinlogReplicaController: binlogreplication.DoltBinlogReplicaController, - } - esStatus, err := getEventSchedulerStatus(serverConfig.EventSchedulerStatus()) - if err != nil { - return err, nil - } - config.EventSchedulerStatus = esStatus - sqlEngine, err := engine.NewSqlEngine( - ctx, - mrEnv, - config, - ) - if err != nil { - return err, nil + var config *engine.SqlEngineConfig + InitSqlEngineConfig := &svcs.Service{ + Init: func(context.Context) error { + config = &engine.SqlEngineConfig{ + IsReadOnly: serverConfig.ReadOnly(), + PrivFilePath: serverConfig.PrivilegeFilePath(), + BranchCtrlFilePath: serverConfig.BranchControlFilePath(), + DoltCfgDirPath: serverConfig.CfgDir(), + ServerUser: serverConfig.User(), + ServerPass: serverConfig.Password(), + ServerHost: serverConfig.Host(), + Autocommit: serverConfig.AutoCommit(), + DoltTransactionCommit: serverConfig.DoltTransactionCommit(), + JwksConfig: serverConfig.JwksConfig(), + SystemVariables: serverConfig.SystemVars(), + ClusterController: clusterController, + BinlogReplicaController: binlogreplication.DoltBinlogReplicaController, + } + return nil + }, } - defer sqlEngine.Close() + controller.Register(InitSqlEngineConfig) + + var esStatus eventscheduler.SchedulerStatus + InitEventSchedulerStatus := &svcs.Service{ + Init: func(context.Context) error { + var err error + esStatus, err = getEventSchedulerStatus(serverConfig.EventSchedulerStatus()) + if err != nil { + return err + } + config.EventSchedulerStatus = esStatus + return nil + }, + } + controller.Register(InitEventSchedulerStatus) + + var sqlEngine *engine.SqlEngine + InitSqlEngine := &svcs.Service{ + Init: func(ctx context.Context) error { + var err error + sqlEngine, err = engine.NewSqlEngine( + ctx, + mrEnv, + config, + ) + return err + }, + Stop: func() error { + sqlEngine.Close() + return nil + }, + } + controller.Register(InitSqlEngine) // Add superuser if specified user exists; add root superuser if no user specified and no existing privileges - userSpecified := config.ServerUser != "" + var userSpecified bool - mysqlDb := sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb - ed := mysqlDb.Editor() - var numUsers int - ed.VisitUsers(func(*mysql_db.User) { numUsers += 1 }) - privsExist := numUsers != 0 - if userSpecified { - superuser := mysqlDb.GetUser(ed, config.ServerUser, "%", false) - if userSpecified && superuser == nil { - mysqlDb.AddSuperUser(ed, config.ServerUser, "%", config.ServerPass) - } - } else if !privsExist { - mysqlDb.AddSuperUser(ed, defaultUser, "%", defaultPass) + InitSuperUser := &svcs.Service{ + Init: func(context.Context) error { + userSpecified = config.ServerUser != "" + + mysqlDb := sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb + ed := mysqlDb.Editor() + var numUsers int + ed.VisitUsers(func(*mysql_db.User) { numUsers += 1 }) + privsExist := numUsers != 0 + if userSpecified { + superuser := mysqlDb.GetUser(ed, config.ServerUser, "%", false) + if userSpecified && superuser == nil { + mysqlDb.AddSuperUser(ed, config.ServerUser, "%", config.ServerPass) + } + } else if !privsExist { + mysqlDb.AddSuperUser(ed, defaultUser, "%", defaultPass) + } + ed.Close() + + return nil + }, } - ed.Close() - - labels := serverConfig.MetricsLabels() + controller.Register(InitSuperUser) var listener *metricsListener - listener, startError = newMetricsListener(labels, version, clusterController) - if startError != nil { - cli.Println(startError) - return + InitMetricsListener := &svcs.Service{ + Init: func(context.Context) error { + labels := serverConfig.MetricsLabels() + var err error + listener, err = newMetricsListener(labels, version, clusterController) + return err + }, + Stop: func() error { + listener.Close() + return nil + }, } - defer listener.Close() + controller.Register(InitMetricsListener) - v, ok := serverConfig.(validatingServerConfig) - if ok && v.goldenMysqlConnectionString() != "" { - mySQLServer, startError = server.NewValidatingServer( - serverConf, - sqlEngine.GetUnderlyingEngine(), - newSessionBuilder(sqlEngine, serverConfig), - listener, - v.goldenMysqlConnectionString(), - ) - } else { - mySQLServer, startError = server.NewServer( - serverConf, - sqlEngine.GetUnderlyingEngine(), - newSessionBuilder(sqlEngine, serverConfig), - listener, - ) + var mySQLServer *server.Server + InitSQLServer := &svcs.Service{ + Init: func(context.Context) error { + var err error + v, ok := serverConfig.(validatingServerConfig) + if ok && v.goldenMysqlConnectionString() != "" { + mySQLServer, err = server.NewValidatingServer( + serverConf, + sqlEngine.GetUnderlyingEngine(), + newSessionBuilder(sqlEngine, serverConfig), + listener, + v.goldenMysqlConnectionString(), + ) + } else { + mySQLServer, err = server.NewServer( + serverConf, + sqlEngine.GetUnderlyingEngine(), + newSessionBuilder(sqlEngine, serverConfig), + listener, + ) + } + if errors.Is(err, server.UnixSocketInUseError) { + lgr.Warn("unix socket set up failed: file already in use: ", serverConf.Socket) + err = nil + } + return err + }, } + controller.Register(InitSQLServer) - if errors.Is(startError, server.UnixSocketInUseError) { - lgr.Warn("unix socket set up failed: file already in use: ", serverConf.Socket) - startError = nil - } else if startError != nil { - cli.PrintErr(startError) - return + LockMultiRepoEnv := &svcs.Service{ + Init: func(context.Context) error { + if ok, f := mrEnv.IsLocked(); ok { + return env.ErrActiveServerLock.New(f) + } + if err := mrEnv.Lock(serverLock); err != nil { + return err + } + return nil + }, + Stop: func() error { + if err := mrEnv.Unlock(); err != nil { + cli.PrintErr(err) + } + return nil + }, } + controller.Register(LockMultiRepoEnv) - sqlserver.SetRunningServer(mySQLServer, serverLock) - - ed = mysqlDb.Editor() - mysqlDb.AddSuperUser(ed, LocalConnectionUser, "localhost", serverLock.Secret) - ed.Close() - if ExternalDisableUsers { - mysqlDb.SetEnabled(false) + InitLockSuperUser := &svcs.Service{ + Init: func(context.Context) error { + mysqlDb := sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb + ed := mysqlDb.Editor() + mysqlDb.AddSuperUser(ed, LocalConnectionUser, "localhost", serverLock.Secret) + ed.Close() + return nil + }, } + controller.Register(InitLockSuperUser) + + DisableMySQLDbIfRequired := &svcs.Service{ + Init: func(context.Context) error { + if ExternalDisableUsers { + mysqlDb := sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb + mysqlDb.SetEnabled(false) + } + return nil + }, + } + controller.Register(DisableMySQLDbIfRequired) var metSrv *http.Server - if serverConfig.MetricsHost() != "" && serverConfig.MetricsPort() > 0 { - mux := http.NewServeMux() - mux.Handle("/metrics", promhttp.Handler()) + RunMetricsServer := &svcs.Service{ + Run: func(context.Context) { + if serverConfig.MetricsHost() != "" && serverConfig.MetricsPort() > 0 { + mux := http.NewServeMux() + mux.Handle("/metrics", promhttp.Handler()) - metSrv = &http.Server{ - Addr: fmt.Sprintf("%s:%d", serverConfig.MetricsHost(), serverConfig.MetricsPort()), - Handler: mux, - } + metSrv = &http.Server{ + Addr: fmt.Sprintf("%s:%d", serverConfig.MetricsHost(), serverConfig.MetricsPort()), + Handler: mux, + } - go func() { - _ = metSrv.ListenAndServe() - }() + _ = metSrv.ListenAndServe() + } + }, + Stop: func() error { + if metSrv != nil { + metSrv.Close() + } + return nil + }, } + controller.Register(RunMetricsServer) var remoteSrv *remotesrv.Server - if serverConfig.RemotesapiPort() != nil { - port := *serverConfig.RemotesapiPort() - listenaddr := fmt.Sprintf(":%d", port) - args, err := sqle.RemoteSrvServerArgs(sqlEngine.NewDefaultContext, remotesrv.ServerArgs{ - Logger: logrus.NewEntry(lgr), - ReadOnly: true, - HttpListenAddr: listenaddr, - GrpcListenAddr: listenaddr, - }) - if err != nil { - lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err) - startError = err - return - } + var remoteSrvListeners remotesrv.Listeners + RunRemoteSrv := &svcs.Service{ + Init: func(ctx context.Context) error { + if serverConfig.RemotesapiPort() == nil { + return nil + } - ctxFactory := func() (*sql.Context, error) { return sqlEngine.NewDefaultContext(ctx) } - authenticator := newAuthenticator(ctxFactory, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb) - args = sqle.WithUserPasswordAuth(args, authenticator) + port := *serverConfig.RemotesapiPort() + listenaddr := fmt.Sprintf(":%d", port) + args, err := sqle.RemoteSrvServerArgs(sqlEngine.NewDefaultContext, remotesrv.ServerArgs{ + Logger: logrus.NewEntry(lgr), + ReadOnly: true, + HttpListenAddr: listenaddr, + GrpcListenAddr: listenaddr, + }) + if err != nil { + lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err) + return err + } - args.TLSConfig = serverConf.TLSConfig - remoteSrv, err = remotesrv.NewServer(args) - if err != nil { - lgr.Errorf("error creating remotesapi server on port %d: %v", port, err) - startError = err - return - } - listeners, err := remoteSrv.Listeners() - if err != nil { - lgr.Errorf("error starting remotesapi server listeners on port %d: %v", port, err) - startError = err - return - } else { - go remoteSrv.Serve(listeners) - } + ctxFactory := func() (*sql.Context, error) { return sqlEngine.NewDefaultContext(ctx) } + authenticator := newAuthenticator(ctxFactory, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb) + args = sqle.WithUserPasswordAuth(args, authenticator) + args.TLSConfig = serverConf.TLSConfig + + remoteSrv, err = remotesrv.NewServer(args) + if err != nil { + lgr.Errorf("error creating remotesapi server on port %d: %v", port, err) + return err + } + remoteSrvListeners, err = remoteSrv.Listeners() + if err != nil { + lgr.Errorf("error starting remotesapi server listeners on port %d: %v", port, err) + return err + } + return nil + }, + Run: func(ctx context.Context) { + if remoteSrv == nil { + return + } + remoteSrv.Serve(remoteSrvListeners) + }, + Stop: func() error { + if remoteSrv == nil { + return nil + } + remoteSrv.GracefulStop() + return nil + }, } + controller.Register(RunRemoteSrv) var clusterRemoteSrv *remotesrv.Server - if clusterController != nil { - args, err := clusterController.RemoteSrvServerArgs(sqlEngine.NewDefaultContext, remotesrv.ServerArgs{ - Logger: logrus.NewEntry(lgr), - }) - if err != nil { - lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err) - startError = err - return - } + var clusterRemoteSrvListeners remotesrv.Listeners + RunClusterRemoteSrv := &svcs.Service{ + Init: func(context.Context) error { + if clusterController == nil { + return nil + } - clusterRemoteSrvTLSConfig, err := LoadClusterTLSConfig(serverConfig.ClusterConfig()) - if err != nil { - lgr.Errorf("error starting remotesapi server for cluster config, could not load tls config: %v", err) - startError = err - return - } - args.TLSConfig = clusterRemoteSrvTLSConfig + args, err := clusterController.RemoteSrvServerArgs(sqlEngine.NewDefaultContext, remotesrv.ServerArgs{ + Logger: logrus.NewEntry(lgr), + }) + if err != nil { + lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err) + return err + } - clusterRemoteSrv, err = remotesrv.NewServer(args) - if err != nil { - lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err) - startError = err - return - } - clusterController.RegisterGrpcServices(sqlEngine.NewDefaultContext, clusterRemoteSrv.GrpcServer()) + clusterRemoteSrvTLSConfig, err := LoadClusterTLSConfig(serverConfig.ClusterConfig()) + if err != nil { + lgr.Errorf("error starting remotesapi server for cluster config, could not load tls config: %v", err) + return err + } + args.TLSConfig = clusterRemoteSrvTLSConfig - listeners, err := clusterRemoteSrv.Listeners() - if err != nil { - lgr.Errorf("error starting remotesapi server listeners for cluster config on %s: %v", clusterController.RemoteSrvListenAddr(), err) - startError = err - return - } + clusterRemoteSrv, err = remotesrv.NewServer(args) + if err != nil { + lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err) + return err + } + clusterController.RegisterGrpcServices(sqlEngine.NewDefaultContext, clusterRemoteSrv.GrpcServer()) - go clusterRemoteSrv.Serve(listeners) - go clusterController.Run() - - clusterController.ManageQueryConnections( - mySQLServer.SessionManager().Iter, - sqlEngine.GetUnderlyingEngine().ProcessList.Kill, - mySQLServer.SessionManager().KillConnection, - ) - } - - if ok, f := mrEnv.IsLocked(); ok { - startError = env.ErrActiveServerLock.New(f) - return - } - - if err = mrEnv.Lock(serverLock); err != nil { - startError = err - return - } - - serverController.registerCloseFunction(startError, func() error { - if metSrv != nil { - metSrv.Close() - } - if remoteSrv != nil { - remoteSrv.GracefulStop() - } - if clusterRemoteSrv != nil { + clusterRemoteSrvListeners, err = clusterRemoteSrv.Listeners() + if err != nil { + lgr.Errorf("error starting remotesapi server listeners for cluster config on %s: %v", clusterController.RemoteSrvListenAddr(), err) + return err + } + return nil + }, + Run: func(context.Context) { + if clusterRemoteSrv == nil { + return + } + clusterRemoteSrv.Serve(clusterRemoteSrvListeners) + }, + Stop: func() error { + if clusterRemoteSrv == nil { + return nil + } clusterRemoteSrv.GracefulStop() - } - if clusterController != nil { + return nil + }, + } + controller.Register(RunClusterRemoteSrv) + + RunClusterController := &svcs.Service{ + Init: func(context.Context) error { + if clusterController == nil { + return nil + } + clusterController.ManageQueryConnections( + mySQLServer.SessionManager().Iter, + sqlEngine.GetUnderlyingEngine().ProcessList.Kill, + mySQLServer.SessionManager().KillConnection, + ) + return nil + }, + Run: func(context.Context) { + if clusterController == nil { + return + } + clusterController.Run() + }, + Stop: func() error { + if clusterController == nil { + return nil + } clusterController.GracefulStop() - } - - return mySQLServer.Close() - }) - - closeError = mySQLServer.Start() - if closeError != nil { - cli.PrintErr(closeError) - } - if err := mrEnv.Unlock(); err != nil { - cli.PrintErr(err) + return nil + }, } + controller.Register(RunClusterController) - return + RunSQLServer := &svcs.Service{ + Run: func(context.Context) { + sqlserver.SetRunningServer(mySQLServer, serverLock) + defer sqlserver.UnsetRunningServer() + mySQLServer.Start() + }, + Stop: func() error { + return mySQLServer.Close() + }, + } + controller.Register(RunSQLServer) + + go controller.Start(ctx) + err := controller.WaitForStart() + if err != nil { + return err, nil + } + return nil, controller.WaitForStop() } // acquireGlobalSqlServerLock attempts to acquire a global lock on the SQL server. If no error is returned, then the lock was acquired. @@ -514,10 +650,10 @@ func newSessionBuilder(se *engine.SqlEngine, config ServerConfig) server.Session } // getConfigFromServerConfig processes ServerConfig and returns server.Config for sql-server. -func getConfigFromServerConfig(serverConfig ServerConfig) (server.Config, error, error) { +func getConfigFromServerConfig(serverConfig ServerConfig) (server.Config, error) { serverConf, err := handleProtocolAndAddress(serverConfig) if err != nil { - return server.Config{}, err, nil + return server.Config{}, err } serverConf.DisableClientMultiStatements = serverConfig.DisableClientMultiStatements() @@ -527,7 +663,7 @@ func getConfigFromServerConfig(serverConfig ServerConfig) (server.Config, error, tlsConfig, err := LoadTLSConfig(serverConfig) if err != nil { - return server.Config{}, nil, err + return server.Config{}, err } // if persist is 'load' we use currently set persisted global variable, @@ -535,12 +671,12 @@ func getConfigFromServerConfig(serverConfig ServerConfig) (server.Config, error, if serverConfig.PersistenceBehavior() == loadPerisistentGlobals { serverConf, err = serverConf.NewConfig() if err != nil { - return server.Config{}, err, nil + return server.Config{}, err } } else { err = sql.SystemVariables.SetGlobal("max_connections", serverConfig.MaxConnections()) if err != nil { - return server.Config{}, err, nil + return server.Config{}, err } } @@ -554,7 +690,7 @@ func getConfigFromServerConfig(serverConfig ServerConfig) (server.Config, error, serverConf.MaxLoggedQueryLen = serverConfig.MaxLoggedQueryLen() serverConf.EncodeLoggedQuery = serverConfig.ShouldEncodeLoggedQuery() - return serverConf, nil, nil + return serverConf, nil } // handleProtocolAndAddress returns new server.Config object with only Protocol and Address defined. diff --git a/go/cmd/dolt/commands/sqlserver/server_test.go b/go/cmd/dolt/commands/sqlserver/server_test.go index 9dacd076c3..d1f73436c2 100644 --- a/go/cmd/dolt/commands/sqlserver/server_test.go +++ b/go/cmd/dolt/commands/sqlserver/server_test.go @@ -32,6 +32,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/sqle" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" "github.com/dolthub/dolt/go/libraries/utils/config" + "github.com/dolthub/dolt/go/libraries/utils/svcs" ) //TODO: server tests need to expose a higher granularity for server interactions: @@ -60,7 +61,7 @@ var ( ) func TestServerArgs(t *testing.T) { - serverController := NewServerController() + controller := svcs.NewController() dEnv, err := sqle.CreateEnvWithSeedData() require.NoError(t, err) defer func() { @@ -75,16 +76,16 @@ func TestServerArgs(t *testing.T) { "-t", "5", "-l", "info", "-r", - }, dEnv, serverController) + }, dEnv, controller) }() - err = serverController.WaitForStart() + err = controller.WaitForStart() require.NoError(t, err) conn, err := dbr.Open("mysql", "username:password@tcp(localhost:15200)/", nil) require.NoError(t, err) err = conn.Close() require.NoError(t, err) - serverController.StopServer() - err = serverController.WaitForClose() + controller.Stop() + err = controller.WaitForStop() assert.NoError(t, err) } @@ -110,22 +111,22 @@ listener: defer func() { assert.NoError(t, dEnv.DoltDB.Close()) }() - serverController := NewServerController() + controller := svcs.NewController() go func() { dEnv.FS.WriteFile("config.yaml", []byte(yamlConfig), os.ModePerm) startServer(context.Background(), "0.0.0", "dolt sql-server", []string{ "--config", "config.yaml", - }, dEnv, serverController) + }, dEnv, controller) }() - err = serverController.WaitForStart() + err = controller.WaitForStart() require.NoError(t, err) conn, err := dbr.Open("mysql", "username:password@tcp(localhost:15200)/", nil) require.NoError(t, err) err = conn.Close() require.NoError(t, err) - serverController.StopServer() - err = serverController.WaitForClose() + controller.Stop() + err = controller.WaitForStop() assert.NoError(t, err) } @@ -145,18 +146,15 @@ func TestServerBadArgs(t *testing.T) { } for _, test := range tests { + test := test t.Run(strings.Join(test, " "), func(t *testing.T) { - serverController := NewServerController() - go func(serverController *ServerController) { - startServer(context.Background(), "test", "dolt sql-server", test, env, serverController) - }(serverController) - - // In the event that a test fails, we need to prevent a test from hanging due to a running server - err := serverController.WaitForStart() - require.Error(t, err) - serverController.StopServer() - err = serverController.WaitForClose() - assert.NoError(t, err) + controller := svcs.NewController() + go func() { + startServer(context.Background(), "test", "dolt sql-server", test, env, controller) + }() + if !assert.Error(t, controller.WaitForStart()) { + controller.Stop() + } }) } } @@ -186,8 +184,8 @@ func TestServerGoodParams(t *testing.T) { for _, test := range tests { t.Run(ConfigInfo(test), func(t *testing.T) { - sc := NewServerController() - go func(config ServerConfig, sc *ServerController) { + sc := svcs.NewController() + go func(config ServerConfig, sc *svcs.Controller) { _, _ = Serve(context.Background(), "0.0.0", config, sc, env) }(test, sc) err := sc.WaitForStart() @@ -196,8 +194,8 @@ func TestServerGoodParams(t *testing.T) { require.NoError(t, err) err = conn.Close() require.NoError(t, err) - sc.StopServer() - err = sc.WaitForClose() + sc.Stop() + err = sc.WaitForStop() assert.NoError(t, err) }) } @@ -212,8 +210,8 @@ func TestServerSelect(t *testing.T) { serverConfig := DefaultServerConfig().withLogLevel(LogLevel_Fatal).WithPort(15300) - sc := NewServerController() - defer sc.StopServer() + sc := svcs.NewController() + defer sc.Stop() go func() { _, _ = Serve(context.Background(), "0.0.0", serverConfig, sc, env) }() @@ -261,7 +259,7 @@ func TestServerSelect(t *testing.T) { // If a port is already in use, throw error "Port XXXX already in use." func TestServerFailsIfPortInUse(t *testing.T) { - serverController := NewServerController() + controller := svcs.NewController() server := &http.Server{ Addr: ":15200", Handler: http.DefaultServeMux, @@ -287,10 +285,10 @@ func TestServerFailsIfPortInUse(t *testing.T) { "-t", "5", "-l", "info", "-r", - }, dEnv, serverController) + }, dEnv, controller) }() - err = serverController.WaitForStart() + err = controller.WaitForStart() require.Error(t, err) server.Close() wg.Wait() @@ -311,8 +309,8 @@ func TestServerSetDefaultBranch(t *testing.T) { serverConfig := DefaultServerConfig().withLogLevel(LogLevel_Fatal).WithPort(15302) - sc := NewServerController() - defer sc.StopServer() + sc := svcs.NewController() + defer sc.Stop() go func() { _, _ = Serve(context.Background(), "0.0.0", serverConfig, sc, dEnv) }() @@ -470,7 +468,7 @@ func TestReadReplica(t *testing.T) { dsess.InitPersistedSystemVars(multiSetup.GetEnv(readReplicaDbName)) // start server as read replica - sc := NewServerController() + sc := svcs.NewController() serverConfig := DefaultServerConfig().withLogLevel(LogLevel_Fatal).WithPort(15303) // set socket to nil to force tcp @@ -482,7 +480,7 @@ func TestReadReplica(t *testing.T) { require.NoError(t, err) }() require.NoError(t, sc.WaitForStart()) - defer sc.StopServer() + defer sc.Stop() replicatedTable := "new_table" multiSetup.CreateTable(ctx, sourceDbName, replicatedTable) diff --git a/go/cmd/dolt/commands/sqlserver/servercontroller.go b/go/cmd/dolt/commands/sqlserver/servercontroller.go deleted file mode 100644 index 47cecc3572..0000000000 --- a/go/cmd/dolt/commands/sqlserver/servercontroller.go +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2019 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlserver - -import ( - "sync" -) - -type ServerController struct { - startCh chan struct{} - closeCh chan struct{} - closeCalled *sync.Once - closeRegistered *sync.Once - stopRegistered *sync.Once - closeFunction func() error - startError error - closeError error -} - -// NewServerController creates a `ServerController` for use with synchronizing on `Serve`. -func NewServerController() *ServerController { - return &ServerController{ - startCh: make(chan struct{}), - closeCh: make(chan struct{}), - closeCalled: &sync.Once{}, - closeRegistered: &sync.Once{}, - stopRegistered: &sync.Once{}, - } -} - -// registerCloseFunction is called within `Serve` to associate the close function with a future `StopServer` call. -// Only the first call will register and unblock, thus it is safe to be called multiple times. -func (controller *ServerController) registerCloseFunction(startError error, closeFunc func() error) { - controller.closeRegistered.Do(func() { - if startError != nil { - controller.startError = startError - } - controller.closeFunction = closeFunc - close(controller.startCh) - }) -} - -// serverStopped is called within `Serve` to signal that the server has stopped and set the exit code. -// Only the first call will register and unblock, thus it is safe to be called multiple times. -func (controller *ServerController) serverStopped(closeError error) { - controller.stopRegistered.Do(func() { - if closeError != nil { - controller.closeError = closeError - } - close(controller.closeCh) - }) -} - -// StopServer stops the server if it is running. Only the first call will trigger the stop, thus it is safe for -// multiple goroutines to call this function. -func (controller *ServerController) StopServer() { - if controller.closeFunction != nil { - controller.closeCalled.Do(func() { - if err := controller.closeFunction(); err != nil { - controller.closeError = err - } - }) - } -} - -// WaitForClose blocks the caller until the server has closed. The return is the last error encountered, if any. -func (controller *ServerController) WaitForClose() error { - select { - case <-controller.closeCh: - break - } - - return controller.closeError -} - -// WaitForStart blocks the caller until the server has started. An error is returned if one was encountered. -func (controller *ServerController) WaitForStart() error { - select { - case <-controller.startCh: - break - case <-controller.closeCh: - break - } - - return controller.startError -} diff --git a/go/cmd/dolt/commands/sqlserver/sqlclient.go b/go/cmd/dolt/commands/sqlserver/sqlclient.go index 25ceb5e722..3ac9f81a2e 100644 --- a/go/cmd/dolt/commands/sqlserver/sqlclient.go +++ b/go/cmd/dolt/commands/sqlserver/sqlclient.go @@ -41,6 +41,7 @@ import ( "github.com/dolthub/dolt/go/libraries/utils/argparser" "github.com/dolthub/dolt/go/libraries/utils/filesys" "github.com/dolthub/dolt/go/libraries/utils/iohelp" + "github.com/dolthub/dolt/go/libraries/utils/svcs" ) const ( @@ -106,7 +107,7 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri apr := cli.ParseArgsOrDie(ap, args, help) var serverConfig ServerConfig - var serverController *ServerController + var svcsController *svcs.Controller var err error cli.Println(color.YellowString("WARNING: This command is being deprecated and is not recommended for general use.\n" + @@ -148,11 +149,11 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri } cli.PrintErrf("Starting server with Config %v\n", ConfigInfo(serverConfig)) - serverController = NewServerController() + svcsController = svcs.NewController() go func() { - _, _ = Serve(ctx, cmd.VersionStr, serverConfig, serverController, dEnv) + _, _ = Serve(ctx, cmd.VersionStr, serverConfig, svcsController, dEnv) }() - err = serverController.WaitForStart() + err = svcsController.WaitForStart() if err != nil { cli.PrintErrln(err.Error()) return 1 @@ -375,8 +376,8 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri cli.PrintErrln(err.Error()) } if apr.Contains(sqlClientDualFlag) { - serverController.StopServer() - err = serverController.WaitForClose() + svcsController.Stop() + err = svcsController.WaitForStop() if err != nil { cli.PrintErrln(err.Error()) } diff --git a/go/cmd/dolt/commands/sqlserver/sqlserver.go b/go/cmd/dolt/commands/sqlserver/sqlserver.go index a1c02c3e9a..609015d468 100644 --- a/go/cmd/dolt/commands/sqlserver/sqlserver.go +++ b/go/cmd/dolt/commands/sqlserver/sqlserver.go @@ -30,6 +30,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/utils/argparser" "github.com/dolthub/dolt/go/libraries/utils/filesys" + "github.com/dolthub/dolt/go/libraries/utils/svcs" ) const ( @@ -186,12 +187,20 @@ func (cmd SqlServerCmd) RequiresRepo() bool { // Exec executes the command func (cmd SqlServerCmd) Exec(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEnv, cliCtx cli.CliContext) int { - controller := NewServerController() + controller := svcs.NewController() newCtx, cancelF := context.WithCancel(context.Background()) go func() { - <-ctx.Done() - controller.StopServer() - cancelF() + // Here we only forward along the SIGINT if the server starts + // up successfully. If the service does not start up + // successfully, or if WaitForStart() blocks indefinitely, then + // startServer() should have returned an error and we do not + // need to Stop the running server or deal with our canceled + // parent context. + if controller.WaitForStart() == nil { + <-ctx.Done() + controller.Stop() + cancelF() + } }() return startServer(newCtx, cmd.VersionStr, commandStr, args, dEnv, controller) } @@ -208,7 +217,7 @@ func validateSqlServerArgs(apr *argparser.ArgParseResults) error { return nil } -func startServer(ctx context.Context, versionStr, commandStr string, args []string, dEnv *env.DoltEnv, serverController *ServerController) int { +func startServer(ctx context.Context, versionStr, commandStr string, args []string, dEnv *env.DoltEnv, controller *svcs.Controller) int { ap := SqlServerCmd{}.ArgParser() help, _ := cli.HelpAndUsagePrinters(cli.CommandDocsForCommandString(commandStr, sqlServerDocs, ap)) @@ -222,21 +231,11 @@ func startServer(ctx context.Context, versionStr, commandStr string, args []stri } serverConfig, err := GetServerConfig(dEnv.FS, apr) if err != nil { - if serverController != nil { - serverController.StopServer() - serverController.serverStopped(err) - } - cli.PrintErrln(color.RedString("Failed to start server. Bad Configuration")) cli.PrintErrln(err.Error()) return 1 } if err = SetupDoltConfig(dEnv, apr, serverConfig); err != nil { - if serverController != nil { - serverController.StopServer() - serverController.serverStopped(err) - } - cli.PrintErrln(color.RedString("Failed to start server. Bad Configuration")) cli.PrintErrln(err.Error()) return 1 @@ -244,7 +243,7 @@ func startServer(ctx context.Context, versionStr, commandStr string, args []stri cli.PrintErrf("Starting server with Config %v\n", ConfigInfo(serverConfig)) - if startError, closeError := Serve(ctx, versionStr, serverConfig, serverController, dEnv); startError != nil || closeError != nil { + if startError, closeError := Serve(ctx, versionStr, serverConfig, controller, dEnv); startError != nil || closeError != nil { if startError != nil { cli.PrintErrln(startError) } diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_server_test.go b/go/libraries/doltcore/sqle/enginetest/dolt_server_test.go index 75a54598f0..4e2887811a 100755 --- a/go/libraries/doltcore/sqle/enginetest/dolt_server_test.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_server_test.go @@ -32,6 +32,7 @@ import ( "github.com/dolthub/dolt/go/cmd/dolt/commands/sqlserver" "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils" "github.com/dolthub/dolt/go/libraries/doltcore/env" + "github.com/dolthub/dolt/go/libraries/utils/svcs" ) // DoltBranchMultiSessionScriptTests contain tests that need to be run in a multi-session server environment @@ -539,8 +540,8 @@ func testMultiSessionScriptTests(t *testing.T, tests []queries.ScriptTest) { require.NoError(t, conn1.Close()) require.NoError(t, conn2.Close()) - sc.StopServer() - err = sc.WaitForClose() + sc.Stop() + err = sc.WaitForStop() require.NoError(t, err) }) } @@ -595,8 +596,8 @@ func testSerialSessionScriptTests(t *testing.T, tests []queries.ScriptTest) { require.NoError(t, conn1.Close()) - sc.StopServer() - err = sc.WaitForClose() + sc.Stop() + err = sc.WaitForStop() require.NoError(t, err) }) } @@ -657,7 +658,7 @@ func assertResultsEqual(t *testing.T, expected []sql.Row, rows *gosql.Rows) { } // startServer will start sql-server with given host, unix socket file path and whether to use specific port, which is defined randomly. -func startServer(t *testing.T, withPort bool, host string, unixSocketPath string) (*env.DoltEnv, *sqlserver.ServerController, sqlserver.ServerConfig) { +func startServer(t *testing.T, withPort bool, host string, unixSocketPath string) (*env.DoltEnv, *svcs.Controller, sqlserver.ServerConfig) { dEnv := dtestutils.CreateTestEnv() serverConfig := sqlserver.DefaultServerConfig() if withPort { @@ -676,8 +677,8 @@ func startServer(t *testing.T, withPort bool, host string, unixSocketPath string return dEnv, onEnv, config } -func startServerOnEnv(t *testing.T, serverConfig sqlserver.ServerConfig, dEnv *env.DoltEnv) (*sqlserver.ServerController, sqlserver.ServerConfig) { - sc := sqlserver.NewServerController() +func startServerOnEnv(t *testing.T, serverConfig sqlserver.ServerConfig, dEnv *env.DoltEnv) (*svcs.Controller, sqlserver.ServerConfig) { + sc := svcs.NewController() go func() { _, _ = sqlserver.Serve(context.Background(), "0.0.0", serverConfig, sc, dEnv) }() @@ -745,8 +746,8 @@ func TestDoltServerRunningUnixSocket(t *testing.T) { require.NoError(t, localConn.Close()) // Stopping unix socket server - sc.StopServer() - err = sc.WaitForClose() + sc.Stop() + err = sc.WaitForStop() require.NoError(t, err) require.NoFileExists(t, defaultUnixSocketPath) @@ -773,7 +774,7 @@ func TestDoltServerRunningUnixSocket(t *testing.T) { }) // Stopping TCP socket server - tcpSc.StopServer() - err = tcpSc.WaitForClose() + tcpSc.Stop() + err = tcpSc.WaitForStop() require.NoError(t, err) } diff --git a/go/libraries/utils/svcs/controller.go b/go/libraries/utils/svcs/controller.go new file mode 100644 index 0000000000..dcab672f32 --- /dev/null +++ b/go/libraries/utils/svcs/controller.go @@ -0,0 +1,189 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package svcs + +import ( + "context" + "errors" + "sync" +) + +// A Service is a runnable unit of functionality that a Controller can +// take responsibility for. It has an |Init| function, which can error, and +// which should do all of the initialization and validation work necessary to +// bring the service up. It has a |Run| function, which will be called in a +// separate go-routine and should run and provide the functionality associated +// with the service until the |Stop| function is called. +type Service struct { + Init func(context.Context) error + Run func(context.Context) + Stop func() error +} + +// A Controller is responsible for initializing a number of registered +// services, running them all, and stopping them all when requested. Services +// are registered with |Register(*Service)|. When |Start| is called, the +// services are all initialized, in the order of their registration, and if +// every service initializes successfully, they are |Run| concurrently. When +// |Stop| is called, services are stopped in reverse-registration order. |Stop| +// does not block for the goroutines spawned by |Start| to complete, although +// typically a Service's |Stop| function should do that. |Stop| only returns an +// error if the Controller is in an illegal state where it is not valid to Stop +// it. In particular, it does not return an error seen by a Service on Stop(). +// That error is returned from Start() and from WaitForStop(). +// +// Any attempt to register a service after |Start| or |Stop| has been called +// will return an error. +// +// If an error occurs when initializing the services of a Controller, the +// Stop functions of any already initialized Services are called in +// reverse-order. The error which caused the initialization error is returned. +// +// In the case that all Services Init successfully, the error returned from +// |Start| is the first non-nil error which is returned from the |Stop| +// functions, in the order they are called. +// +// WaitForStart() can be called at any time on a Controller. It will +// block until |Start| is called. After |Start| is called, if all the services +// succesfully initialize, it will return |nil|. Otherwise it will return the +// same error |Start| returned. +// +// WaitForStop() can be called at any time on a Controller. It will block +// until |Start| is called and initialization fails, or until |Stop| is called. +// It will return the same error which |Start| returned. +type Controller struct { + mu sync.Mutex + services []*Service + initErr error + stopErr error + started bool + startCh chan struct{} + stopped bool + stopCh chan struct{} + stoppedCh chan struct{} +} + +func NewController() *Controller { + return &Controller{ + startCh: make(chan struct{}), + stopCh: make(chan struct{}), + stoppedCh: make(chan struct{}), + } +} + +func (c *Controller) WaitForStart() error { + <-c.startCh + c.mu.Lock() + err := c.initErr + c.mu.Unlock() + return err +} + +func (c *Controller) WaitForStop() error { + <-c.stoppedCh + c.mu.Lock() + var err error + if c.initErr != nil { + err = c.initErr + } else if c.stopErr != nil { + err = c.stopErr + } + c.mu.Unlock() + return err +} + +func (c *Controller) Register(svc *Service) error { + c.mu.Lock() + if c.started { + c.mu.Unlock() + return errors.New("Controller: cannot Register a service on a controller which was already started") + } + c.services = append(c.services, svc) + c.mu.Unlock() + return nil +} + +func (c *Controller) Stop() error { + c.mu.Lock() + if !c.started { + c.mu.Unlock() + return errors.New("Controller: cannot Stop a controller which was never started") + } + if c.stopped { + c.mu.Unlock() + return errors.New("Controller: cannot Stop a controller which was already stopped or which failed to initialize all its services") + } + c.stopped = true + close(c.stopCh) + c.mu.Unlock() + <-c.stoppedCh + return nil +} + +func (c *Controller) Start(ctx context.Context) error { + c.mu.Lock() + if c.started { + return errors.New("Controller: cannot start service controller twice") + } + c.started = true + svcs := make([]*Service, len(c.services)) + copy(svcs, c.services) + c.mu.Unlock() + for i, s := range svcs { + if s.Init == nil { + continue + } + err := s.Init(ctx) + if err != nil { + for j := i - 1; j >= 0; j-- { + if svcs[j].Stop != nil { + svcs[j].Stop() + } + } + c.mu.Lock() + c.stopped = true + c.initErr = err + close(c.startCh) + close(c.stoppedCh) + c.mu.Unlock() + return err + } + } + close(c.startCh) + for _, s := range svcs { + if s.Run == nil { + continue + } + go s.Run(ctx) + } + <-c.stopCh + var stopErr error + for i := len(svcs) - 1; i >= 0; i-- { + if svcs[i].Stop == nil { + continue + } + err := svcs[i].Stop() + if err != nil && stopErr == nil { + stopErr = err + } + } + c.mu.Lock() + if stopErr != nil { + c.stopErr = stopErr + } + close(c.stoppedCh) + c.mu.Unlock() + return stopErr +} diff --git a/go/libraries/utils/svcs/controller_test.go b/go/libraries/utils/svcs/controller_test.go new file mode 100644 index 0000000000..e5f147511b --- /dev/null +++ b/go/libraries/utils/svcs/controller_test.go @@ -0,0 +1,305 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package svcs + +import ( + "context" + "errors" + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestController(t *testing.T) { + t.Run("NewController", func(t *testing.T) { + c := NewController() + require.NotNil(t, c) + }) + t.Run("Stop", func(t *testing.T) { + t.Run("CalledBeforeStart", func(t *testing.T) { + c := NewController() + require.Error(t, c.Stop()) + }) + t.Run("ReturnsFirstError", func(t *testing.T) { + c := NewController() + ctx := context.Background() + err := errors.New("first") + require.NoError(t, c.Register(&Service{ + Init: func(context.Context) error { return nil }, + Run: func(context.Context) {}, + Stop: func() error { return errors.New("second") }, + })) + require.NoError(t, c.Register(&Service{ + Init: func(context.Context) error { return nil }, + Run: func(context.Context) {}, + Stop: func() error { return err }, + })) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(t, c.WaitForStart()) + require.NoError(t, c.Stop()) + }() + require.ErrorIs(t, c.Start(ctx), err) + require.ErrorIs(t, c.WaitForStop(), err) + wg.Wait() + }) + }) + t.Run("EmptyServices", func(t *testing.T) { + c := NewController() + ctx := context.Background() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(t, c.WaitForStart()) + require.NoError(t, c.Stop()) + }() + require.NoError(t, c.Start(ctx)) + require.NoError(t, c.WaitForStop()) + wg.Wait() + }) + t.Run("Register", func(t *testing.T) { + t.Run("AfterStartCalled", func(t *testing.T) { + c := NewController() + ctx := context.Background() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(t, c.WaitForStart()) + require.Error(t, c.Register(&Service{ + Init: func(context.Context) error { return nil }, + Run: func(context.Context) {}, + Stop: func() error { return nil }, + })) + require.NoError(t, c.Stop()) + }() + require.NoError(t, c.Start(ctx)) + require.NoError(t, c.WaitForStop()) + wg.Wait() + }) + }) + t.Run("Start", func(t *testing.T) { + t.Run("CallsInitInOrder", func(t *testing.T) { + c := NewController() + var inited []int + require.NoError(t, c.Register(&Service{ + Init: func(context.Context) error { + inited = append(inited, 0) + return nil + }, + Run: func(context.Context) {}, + Stop: func() error { return nil }, + })) + require.NoError(t, c.Register(&Service{ + Init: func(context.Context) error { + inited = append(inited, 1) + return nil + }, + Run: func(context.Context) {}, + Stop: func() error { return nil }, + })) + require.NoError(t, c.Register(&Service{ + Init: func(context.Context) error { + inited = append(inited, 2) + return nil + }, + Run: func(context.Context) {}, + Stop: func() error { return nil }, + })) + ctx := context.Background() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(t, c.WaitForStart()) + require.NoError(t, c.Stop()) + }() + require.NoError(t, c.Start(ctx)) + require.NoError(t, c.WaitForStop()) + require.Equal(t, inited, []int{0, 1, 2}) + wg.Wait() + }) + t.Run("StopsCallingInitOnFirstError", func(t *testing.T) { + err := errors.New("first error") + c := NewController() + var inited []int + require.NoError(t, c.Register(&Service{ + Init: func(context.Context) error { + inited = append(inited, 0) + return nil + }, + Run: func(context.Context) {}, + Stop: func() error { return nil }, + })) + require.NoError(t, c.Register(&Service{ + Init: func(context.Context) error { + inited = append(inited, 1) + return nil + }, + Run: func(context.Context) {}, + Stop: func() error { return nil }, + })) + require.NoError(t, c.Register(&Service{ + Init: func(context.Context) error { + return err + }, + Run: func(context.Context) {}, + Stop: func() error { return nil }, + })) + require.NoError(t, c.Register(&Service{ + Init: func(context.Context) error { + inited = append(inited, 2) + return nil + }, + Run: func(context.Context) {}, + Stop: func() error { return nil }, + })) + ctx := context.Background() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.ErrorIs(t, c.WaitForStart(), err) + require.NotErrorIs(t, c.Stop(), err) + }() + require.ErrorIs(t, c.Start(ctx), err) + require.ErrorIs(t, c.WaitForStop(), err) + require.Equal(t, inited, []int{0, 1}) + wg.Wait() + }) + t.Run("CallsStopWhenInitErrors", func(t *testing.T) { + err := errors.New("first error") + c := NewController() + var stopped []int + require.NoError(t, c.Register(&Service{ + Init: func(context.Context) error { + return nil + }, + Run: func(context.Context) {}, + Stop: func() error { + stopped = append(stopped, 0) + return nil + }, + })) + require.NoError(t, c.Register(&Service{ + Init: func(context.Context) error { + return nil + }, + Run: func(context.Context) {}, + Stop: func() error { + stopped = append(stopped, 1) + return nil + }, + })) + require.NoError(t, c.Register(&Service{ + Init: func(context.Context) error { + return err + }, + Run: func(context.Context) {}, + Stop: func() error { + stopped = append(stopped, 2) + return nil + }, + })) + require.NoError(t, c.Register(&Service{ + Init: func(context.Context) error { + return nil + }, + Run: func(context.Context) {}, + Stop: func() error { + stopped = append(stopped, 3) + return nil + }, + })) + ctx := context.Background() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.ErrorIs(t, c.WaitForStart(), err) + require.NotErrorIs(t, c.Stop(), err) + }() + require.ErrorIs(t, c.Start(ctx), err) + require.ErrorIs(t, c.WaitForStop(), err) + require.Equal(t, stopped, []int{1, 0}) + wg.Wait() + }) + t.Run("RunsServices", func(t *testing.T) { + c := NewController() + var wg sync.WaitGroup + wg.Add(2) + require.NoError(t, c.Register(&Service{ + Init: func(context.Context) error { return nil }, + Run: func(context.Context) { wg.Done() }, + Stop: func() error { return nil }, + })) + require.NoError(t, c.Register(&Service{ + Init: func(context.Context) error { return nil }, + Run: func(context.Context) { wg.Done() }, + Stop: func() error { return nil }, + })) + ctx := context.Background() + var cwg sync.WaitGroup + cwg.Add(1) + go func() { + defer cwg.Done() + require.NoError(t, c.WaitForStart()) + require.NoError(t, c.Stop()) + }() + require.NoError(t, c.Start(ctx)) + require.NoError(t, c.WaitForStop()) + wg.Wait() + cwg.Wait() + }) + t.Run("StopsAllServices", func(t *testing.T) { + c := NewController() + var wg sync.WaitGroup + err := errors.New("first error") + wg.Add(2) + require.NoError(t, c.Register(&Service{ + Init: func(context.Context) error { return nil }, + Run: func(context.Context) {}, + Stop: func() error { + wg.Done() + return errors.New("second error") + }, + })) + require.NoError(t, c.Register(&Service{ + Init: func(context.Context) error { return nil }, + Run: func(context.Context) {}, + Stop: func() error { + wg.Done() + return err + }, + })) + ctx := context.Background() + var cwg sync.WaitGroup + cwg.Add(1) + go func() { + defer cwg.Done() + require.NoError(t, c.WaitForStart()) + require.NoError(t, c.Stop()) + }() + require.ErrorIs(t, c.Start(ctx), err) + require.ErrorIs(t, c.WaitForStop(), err) + wg.Wait() + cwg.Wait() + }) + }) +} diff --git a/go/performance/replicationbench/replica_test.go b/go/performance/replicationbench/replica_test.go index 618e41e73a..38f755babc 100644 --- a/go/performance/replicationbench/replica_test.go +++ b/go/performance/replicationbench/replica_test.go @@ -31,6 +31,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils/testcommands" "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" + "github.com/dolthub/dolt/go/libraries/utils/svcs" ) type query string @@ -163,13 +164,13 @@ func getProfFile(b *testing.B) *os.File { } func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv, cfg srv.ServerConfig, queries []query) { - serverController := srv.NewServerController() + sc := svcs.NewController() eg, ctx := errgroup.WithContext(ctx) //b.Logf("Starting server with Config %v\n", srv.ConfigInfo(cfg)) eg.Go(func() (err error) { - startErr, closeErr := srv.Serve(ctx, "", cfg, serverController, dEnv) + startErr, closeErr := srv.Serve(ctx, "", cfg, sc, dEnv) if startErr != nil { return startErr } @@ -178,7 +179,7 @@ func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv, } return nil }) - if err := serverController.WaitForStart(); err != nil { + if err := sc.WaitForStart(); err != nil { b.Fatal(err) } @@ -188,8 +189,8 @@ func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv, } } - serverController.StopServer() - if err := serverController.WaitForClose(); err != nil { + sc.Stop() + if err := sc.WaitForStop(); err != nil { b.Fatal(err) } if err := eg.Wait(); err != nil { diff --git a/go/performance/serverbench/bench_test.go b/go/performance/serverbench/bench_test.go index c556eed89e..d415bd5247 100644 --- a/go/performance/serverbench/bench_test.go +++ b/go/performance/serverbench/bench_test.go @@ -31,6 +31,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/utils/filesys" + "github.com/dolthub/dolt/go/libraries/utils/svcs" "github.com/dolthub/dolt/go/store/types" ) @@ -175,13 +176,13 @@ func getProfFile(b *testing.B) *os.File { } func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv, cfg srv.ServerConfig, queries []query) { - serverController := srv.NewServerController() + sc := svcs.NewController() eg, ctx := errgroup.WithContext(ctx) //b.Logf("Starting server with Config %v\n", srv.ConfigInfo(cfg)) eg.Go(func() (err error) { - startErr, closeErr := srv.Serve(ctx, "", cfg, serverController, dEnv) + startErr, closeErr := srv.Serve(ctx, "", cfg, sc, dEnv) if startErr != nil { return startErr } @@ -190,7 +191,7 @@ func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv, } return nil }) - if err := serverController.WaitForStart(); err != nil { + if err := sc.WaitForStart(); err != nil { b.Fatal(err) } @@ -200,8 +201,8 @@ func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv, } } - serverController.StopServer() - if err := serverController.WaitForClose(); err != nil { + sc.Stop() + if err := sc.WaitForStop(); err != nil { b.Fatal(err) } if err := eg.Wait(); err != nil {