go: cmd/dolt: sqlserver: Create a Config struct to encapsulate arguments that control sqlserver.ConfigureServices behavior.

Allow configuring the ProtocolListenerFactory through ConfigureServices.
This commit is contained in:
Aaron Son
2025-03-05 13:51:33 -08:00
parent 29ed892419
commit 60cbf4726f
6 changed files with 114 additions and 70 deletions
+65 -62
View File
@@ -76,40 +76,42 @@ var ExternalDisableUsers bool = false
var ErrCouldNotLockDatabase = goerrors.NewKind("database \"%s\" is locked by another dolt process; either clone the database to run a second server, or stop the dolt process which currently holds an exclusive write lock on the database")
type Config struct {
ServerConfig servercfg.ServerConfig
DoltEnv *env.DoltEnv
SkipRootUserInit bool
Version string
Controller *svcs.Controller
ProtocolListenerFactory server.ProtocolListenerFunc
}
// Serve starts a MySQL-compatible server. Returns any errors that were encountered.
func Serve(
ctx context.Context,
version string,
serverConfig servercfg.ServerConfig,
controller *svcs.Controller,
dEnv *env.DoltEnv,
skipRootUserInitialization bool,
cfg *Config,
) (startError error, closeError error) {
// Code is easier to work through if we assume that serverController is never nil
if controller == nil {
controller = svcs.NewController()
if cfg.Controller == nil {
cfg.Controller = svcs.NewController()
}
ConfigureServices(serverConfig, controller, version, dEnv, skipRootUserInitialization)
ConfigureServices(cfg)
go controller.Start(ctx)
err := controller.WaitForStart()
go cfg.Controller.Start(ctx)
err := cfg.Controller.WaitForStart()
if err != nil {
return err, nil
}
return nil, controller.WaitForStop()
return nil, cfg.Controller.WaitForStop()
}
func ConfigureServices(
serverConfig servercfg.ServerConfig,
controller *svcs.Controller,
version string,
dEnv *env.DoltEnv,
skipRootUserInitialization bool,
cfg *Config,
) {
controller := cfg.Controller
ValidateConfigStep := &svcs.AnonService{
InitF: func(context.Context) error {
return servercfg.ValidateConfig(serverConfig)
return servercfg.ValidateConfig(cfg.ServerConfig)
},
}
controller.Register(ValidateConfigStep)
@@ -118,18 +120,18 @@ func ConfigureServices(
lgr.SetOutput(cli.CliErr)
InitLogging := &svcs.AnonService{
InitF: func(context.Context) error {
level, err := logrus.ParseLevel(serverConfig.LogLevel().String())
level, err := logrus.ParseLevel(cfg.ServerConfig.LogLevel().String())
if err != nil {
return err
}
logrus.SetLevel(level)
switch strings.ToLower(string(serverConfig.LogFormat())) {
switch strings.ToLower(string(cfg.ServerConfig.LogFormat())) {
case string(servercfg.LogFormat_JSON):
logrus.SetFormatter(&logrus.JSONFormatter{})
case string(servercfg.LogFormat_Text):
logrus.SetFormatter(&logrus.TextFormatter{})
default:
return fmt.Errorf("unknown log format: %s", serverConfig.LogFormat())
return fmt.Errorf("unknown log format: %s", cfg.ServerConfig.LogFormat())
}
sql.SystemVariables.AddSystemVariables([]sql.SystemVariable{
@@ -164,12 +166,12 @@ func ConfigureServices(
}
controller.Register(InitLogging)
controller.Register(newHeartbeatService(version, dEnv))
controller.Register(newHeartbeatService(cfg.Version, cfg.DoltEnv))
fs := dEnv.FS
fs := cfg.DoltEnv.FS
InitFailsafes := &svcs.AnonService{
InitF: func(ctx context.Context) (err error) {
dEnv.Config.SetFailsafes(env.DefaultFailsafeConfig)
cfg.DoltEnv.Config.SetFailsafes(env.DefaultFailsafeConfig)
return nil
},
}
@@ -178,7 +180,7 @@ func ConfigureServices(
var mrEnv *env.MultiRepoEnv
InitMultiEnv := &svcs.AnonService{
InitF: func(ctx context.Context) (err error) {
mrEnv, err = env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), fs, dEnv.Version, dEnv)
mrEnv, err = env.MultiEnvForDirectory(ctx, cfg.DoltEnv.Config.WriteableConfig(), fs, cfg.DoltEnv.Version, cfg.DoltEnv)
return err
},
}
@@ -199,11 +201,11 @@ func ConfigureServices(
var localCreds *LocalCreds
InitServerLocalCreds := &svcs.AnonService{
InitF: func(context.Context) (err error) {
localCreds, err = persistServerLocalCreds(serverConfig.Port(), dEnv)
localCreds, err = persistServerLocalCreds(cfg.ServerConfig.Port(), cfg.DoltEnv)
return err
},
StopF: func() error {
RemoveLocalCreds(dEnv.FS)
RemoveLocalCreds(cfg.DoltEnv.FS)
return nil
},
}
@@ -212,7 +214,7 @@ func ConfigureServices(
var clusterController *cluster.Controller
InitClusterController := &svcs.AnonService{
InitF: func(context.Context) (err error) {
clusterController, err = cluster.NewController(lgr, serverConfig.ClusterConfig(), mrEnv.Config())
clusterController, err = cluster.NewController(lgr, cfg.ServerConfig.ClusterConfig(), mrEnv.Config())
return err
},
}
@@ -221,7 +223,7 @@ func ConfigureServices(
var serverConf server.Config
LoadServerConfig := &svcs.AnonService{
InitF: func(context.Context) (err error) {
serverConf, err = getConfigFromServerConfig(serverConfig)
serverConf, err = getConfigFromServerConfig(cfg.ServerConfig, cfg.ProtocolListenerFactory)
return err
},
}
@@ -232,20 +234,20 @@ func ConfigureServices(
InitSqlEngineConfig := &svcs.AnonService{
InitF: func(context.Context) error {
config = &engine.SqlEngineConfig{
IsReadOnly: serverConfig.ReadOnly(),
PrivFilePath: serverConfig.PrivilegeFilePath(),
BranchCtrlFilePath: serverConfig.BranchControlFilePath(),
DoltCfgDirPath: serverConfig.CfgDir(),
ServerUser: serverConfig.User(),
ServerPass: serverConfig.Password(),
ServerHost: serverConfig.Host(),
Autocommit: serverConfig.AutoCommit(),
DoltTransactionCommit: serverConfig.DoltTransactionCommit(),
JwksConfig: serverConfig.JwksConfig(),
SystemVariables: serverConfig.SystemVars(),
IsReadOnly: cfg.ServerConfig.ReadOnly(),
PrivFilePath: cfg.ServerConfig.PrivilegeFilePath(),
BranchCtrlFilePath: cfg.ServerConfig.BranchControlFilePath(),
DoltCfgDirPath: cfg.ServerConfig.CfgDir(),
ServerUser: cfg.ServerConfig.User(),
ServerPass: cfg.ServerConfig.Password(),
ServerHost: cfg.ServerConfig.Host(),
Autocommit: cfg.ServerConfig.AutoCommit(),
DoltTransactionCommit: cfg.ServerConfig.DoltTransactionCommit(),
JwksConfig: cfg.ServerConfig.JwksConfig(),
SystemVariables: cfg.ServerConfig.SystemVars(),
ClusterController: clusterController,
BinlogReplicaController: binlogreplication.DoltBinlogReplicaController,
SkipRootUserInitialization: skipRootUserInitialization,
SkipRootUserInitialization: cfg.SkipRootUserInit,
}
return nil
},
@@ -255,7 +257,7 @@ func ConfigureServices(
var esStatus eventscheduler.SchedulerStatus
InitEventSchedulerStatus := &svcs.AnonService{
InitF: func(context.Context) (err error) {
esStatus, err = getEventSchedulerStatus(serverConfig.EventSchedulerStatus())
esStatus, err = getEventSchedulerStatus(cfg.ServerConfig.EventSchedulerStatus())
if err != nil {
return err
}
@@ -267,8 +269,8 @@ func ConfigureServices(
InitAutoGCController := &svcs.AnonService{
InitF: func(context.Context) error {
if serverConfig.AutoGCBehavior() != nil &&
serverConfig.AutoGCBehavior().Enable() {
if cfg.ServerConfig.AutoGCBehavior() != nil &&
cfg.ServerConfig.AutoGCBehavior().Enable() {
config.AutoGCController = sqle.NewAutoGCController(lgr)
}
return nil
@@ -352,7 +354,7 @@ func ConfigureServices(
// in the configuration files for a sql-server, and not global for the whole host.
PersistNondeterministicSystemVarDefaults := &svcs.AnonService{
InitF: func(ctx context.Context) error {
err := dsess.PersistSystemVarDefaults(dEnv)
err := dsess.PersistSystemVarDefaults(cfg.DoltEnv)
if err != nil {
logrus.Errorf("unable to persist system variable defaults: %v", err)
}
@@ -404,7 +406,7 @@ func ConfigureServices(
if logBin == 1 {
logrus.Infof("Enabling binary logging for branch %s", logBinBranch)
binlogProducer, err := binlogreplication.NewBinlogProducer(dEnv.FS)
binlogProducer, err := binlogreplication.NewBinlogProducer(cfg.DoltEnv.FS)
if err != nil {
return err
}
@@ -441,7 +443,7 @@ func ConfigureServices(
InitF: func(ctx context.Context) error {
// If privileges.db has already been initialized, indicating that this is NOT the
// first time sql-server has been launched, then don't initialize the root superuser.
if permissionDbExists, err := doesPrivilegesDbExist(dEnv, serverConfig.PrivilegeFilePath()); err != nil {
if permissionDbExists, err := doesPrivilegesDbExist(cfg.DoltEnv, cfg.ServerConfig.PrivilegeFilePath()); err != nil {
return err
} else if permissionDbExists {
logrus.Debug("privileges.db already exists, not creating root superuser")
@@ -480,7 +482,7 @@ func ConfigureServices(
// for persisting the privileges database. The filesys API
// is in the Dolt layer, so when the file path is passed to
// GMS, it expects it to be a path on disk, and errors out.
if _, isInMemFs := dEnv.FS.(*filesys.InMemFS); isInMemFs {
if _, isInMemFs := cfg.DoltEnv.FS.(*filesys.InMemFS); isInMemFs {
return nil
} else {
sqlCtx, err := sqlEngine.NewDefaultContext(context.Background())
@@ -496,8 +498,8 @@ func ConfigureServices(
var metListener *metricsListener
InitMetricsListener := &svcs.AnonService{
InitF: func(context.Context) (err error) {
labels := serverConfig.MetricsLabels()
metListener, err = newMetricsListener(labels, version, clusterController)
labels := cfg.ServerConfig.MetricsLabels()
metListener, err = newMetricsListener(labels, cfg.Version, clusterController)
return err
},
StopF: func() error {
@@ -539,10 +541,10 @@ func ConfigureServices(
RunMetricsServer := &svcs.AnonService{
InitF: func(context.Context) (err error) {
if serverConfig.MetricsHost() != "" && serverConfig.MetricsPort() > 0 {
if cfg.ServerConfig.MetricsHost() != "" && cfg.ServerConfig.MetricsPort() > 0 {
metSrv.state.Swap(svcs.ServiceState_Init)
addr := fmt.Sprintf("%s:%d", serverConfig.MetricsHost(), serverConfig.MetricsPort())
addr := fmt.Sprintf("%s:%d", cfg.ServerConfig.MetricsHost(), cfg.ServerConfig.MetricsPort())
metSrv.lis, err = net.Listen("tcp", addr)
if err != nil {
return err
@@ -583,16 +585,16 @@ func ConfigureServices(
var remoteSrv RemoteSrvService
RunRemoteSrv := &svcs.AnonService{
InitF: func(ctx context.Context) error {
if serverConfig.RemotesapiPort() == nil {
if cfg.ServerConfig.RemotesapiPort() == nil {
return nil
}
remoteSrv.state.Swap(svcs.ServiceState_Init)
port := *serverConfig.RemotesapiPort()
port := *cfg.ServerConfig.RemotesapiPort()
apiReadOnly := false
if serverConfig.RemotesapiReadOnly() != nil {
apiReadOnly = *serverConfig.RemotesapiReadOnly()
if cfg.ServerConfig.RemotesapiReadOnly() != nil {
apiReadOnly = *cfg.ServerConfig.RemotesapiReadOnly()
}
listenaddr := fmt.Sprintf(":%d", port)
@@ -601,7 +603,7 @@ func ConfigureServices(
}
args := remotesrv.ServerArgs{
Logger: logrus.NewEntry(lgr),
ReadOnly: apiReadOnly || serverConfig.ReadOnly(),
ReadOnly: apiReadOnly || cfg.ServerConfig.ReadOnly(),
HttpListenAddr: listenaddr,
GrpcListenAddr: listenaddr,
ConcurrencyControl: remotesapi.PushConcurrencyControl_PUSH_CONCURRENCY_CONTROL_ASSERT_WORKING_SET,
@@ -666,7 +668,7 @@ func ConfigureServices(
}
args.FS = sqlEngine.FileSystem()
clusterRemoteSrvTLSConfig, err := LoadClusterTLSConfig(serverConfig.ClusterConfig())
clusterRemoteSrvTLSConfig, err := LoadClusterTLSConfig(cfg.ServerConfig.ClusterConfig())
if err != nil {
lgr.Errorf("error starting remotesapi server for cluster config, could not load tls config: %v", err)
return err
@@ -675,7 +677,7 @@ func ConfigureServices(
clusterRemoteSrv.srv, err = remotesrv.NewServer(args)
if err != nil {
lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err)
lgr.Errorf("error creating remotesapi server on port %d: %v", *cfg.ServerConfig.RemotesapiPort(), err)
return err
}
clusterController.RegisterGrpcServices(sqle.GetInterceptorSqlContext, clusterRemoteSrv.srv.GrpcServer())
@@ -715,13 +717,13 @@ func ConfigureServices(
var sqlServerClosed bool
InitSQLServer := &svcs.AnonService{
InitF: func(context.Context) (err error) {
v, ok := serverConfig.(servercfg.ValidatingServerConfig)
v, ok := cfg.ServerConfig.(servercfg.ValidatingServerConfig)
if ok && v.GoldenMysqlConnectionString() != "" {
mySQLServer, err = server.NewServerWithHandler(
serverConf,
sqlEngine.GetUnderlyingEngine(),
sql.NewContext,
newSessionBuilder(sqlEngine, serverConfig),
newSessionBuilder(sqlEngine, cfg.ServerConfig),
metListener,
func(h mysql.Handler) (mysql.Handler, error) {
return golden.NewValidatingHandler(h, v.GoldenMysqlConnectionString(), logrus.StandardLogger())
@@ -732,7 +734,7 @@ func ConfigureServices(
serverConf,
sqlEngine.GetUnderlyingEngine(),
sql.NewContext,
newSessionBuilder(sqlEngine, serverConfig),
newSessionBuilder(sqlEngine, cfg.ServerConfig),
metListener,
)
}
@@ -1065,7 +1067,7 @@ func newSessionBuilder(se *engine.SqlEngine, config servercfg.ServerConfig) serv
}
// getConfigFromServerConfig processes ServerConfig and returns server.Config for sql-server.
func getConfigFromServerConfig(serverConfig servercfg.ServerConfig) (server.Config, error) {
func getConfigFromServerConfig(serverConfig servercfg.ServerConfig, plf server.ProtocolListenerFunc) (server.Config, error) {
serverConf, err := handleProtocolAndAddress(serverConfig)
if err != nil {
return server.Config{}, err
@@ -1095,6 +1097,7 @@ func getConfigFromServerConfig(serverConfig servercfg.ServerConfig) (server.Conf
serverConf.RequireSecureTransport = serverConfig.RequireSecureTransport()
serverConf.MaxLoggedQueryLen = serverConfig.MaxLoggedQueryLen()
serverConf.EncodeLoggedQuery = serverConfig.ShouldEncodeLoggedQuery()
serverConf.ProtocolListenerFactory = plf
return serverConf, nil
}
+24 -4
View File
@@ -212,7 +212,12 @@ func TestServerGoodParams(t *testing.T) {
t.Run(servercfg.ConfigInfo(test), func(t *testing.T) {
sc := svcs.NewController()
go func(config servercfg.ServerConfig, sc *svcs.Controller) {
_, _ = Serve(context.Background(), "0.0.0", config, sc, env, false)
_, _ = Serve(context.Background(), &Config{
Version: "0.0.0",
ServerConfig: config,
Controller: sc,
DoltEnv: env,
})
}(test, sc)
err := sc.WaitForStart()
require.NoError(t, err)
@@ -240,7 +245,12 @@ func TestServerSelect(t *testing.T) {
sc := svcs.NewController()
defer sc.Stop()
go func() {
_, _ = Serve(context.Background(), "0.0.0", serverConfig, sc, env, false)
_, _ = Serve(context.Background(), &Config{
Version: "0.0.0",
ServerConfig: serverConfig,
Controller: sc,
DoltEnv: env,
})
}()
err = sc.WaitForStart()
require.NoError(t, err)
@@ -339,7 +349,12 @@ func TestServerSetDefaultBranch(t *testing.T) {
sc := svcs.NewController()
defer sc.Stop()
go func() {
_, _ = Serve(context.Background(), "0.0.0", serverConfig, sc, dEnv, false)
_, _ = Serve(context.Background(), &Config{
Version: "0.0.0",
ServerConfig: serverConfig,
Controller: sc,
DoltEnv: dEnv,
})
}()
err = sc.WaitForStart()
require.NoError(t, err)
@@ -503,7 +518,12 @@ func TestReadReplica(t *testing.T) {
os.Chdir(multiSetup.DbPaths[readReplicaDbName])
go func() {
err, _ = Serve(context.Background(), "0.0.0", serverConfig, sc, multiSetup.GetEnv(readReplicaDbName), false)
err, _ = Serve(context.Background(), &Config{
Version: "0.0.0",
ServerConfig: serverConfig,
Controller: sc,
DoltEnv: multiSetup.GetEnv(readReplicaDbName),
})
require.NoError(t, err)
}()
require.NoError(t, sc.WaitForStart())
+7 -1
View File
@@ -267,7 +267,13 @@ func StartServer(ctx context.Context, versionStr, commandStr string, args []stri
cli.Printf("Starting server with Config %v\n", servercfg.ConfigInfo(serverConfig))
skipRootUserInitialization := apr.Contains(skipRootUserInitialization)
startError, closeError := Serve(ctx, versionStr, serverConfig, controller, dEnv, skipRootUserInitialization)
startError, closeError := Serve(ctx, &Config{
Version: versionStr,
ServerConfig: serverConfig,
Controller: controller,
DoltEnv: dEnv,
SkipRootUserInit: skipRootUserInitialization,
})
if startError != nil {
return startError
}
@@ -551,7 +551,12 @@ func makeDestinationSlice(t *testing.T, columnTypes []*gosql.ColumnType) []inter
func startServerOnEnv(t *testing.T, serverConfig servercfg.ServerConfig, dEnv *env.DoltEnv) (*svcs.Controller, servercfg.ServerConfig) {
sc := svcs.NewController()
go func() {
_, _ = sqlserver.Serve(context.Background(), "0.0.0", serverConfig, sc, dEnv, false)
_, _ = sqlserver.Serve(context.Background(), &sqlserver.Config{
Version: "0.0.0",
ServerConfig: serverConfig,
Controller: sc,
DoltEnv: dEnv,
})
}()
err := sc.WaitForStart()
require.NoError(t, err)
@@ -171,7 +171,12 @@ func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv,
//b.Logf("Starting server with Config %v\n", srv.ConfigInfo(cfg))
eg.Go(func() (err error) {
startErr, closeErr := srv.Serve(ctx, "", cfg, sc, dEnv, false)
startErr, closeErr := srv.Serve(ctx, &srv.Config{
Version: "",
ServerConfig: cfg,
Controller: sc,
DoltEnv: dEnv,
})
if startErr != nil {
return startErr
}
+6 -1
View File
@@ -183,7 +183,12 @@ func executeServerQueries(ctx context.Context, b *testing.B, dEnv *env.DoltEnv,
//b.Logf("Starting server with Config %v\n", srv.ConfigInfo(cfg))
eg.Go(func() (err error) {
startErr, closeErr := srv.Serve(ctx, "", cfg, sc, dEnv, false)
startErr, closeErr := srv.Serve(ctx, &srv.Config{
Version: "",
ServerConfig: cfg,
Controller: sc,
DoltEnv: dEnv,
})
if startErr != nil {
return startErr
}