go/cmd/dolt/commands/sqlserver: Make sure to close listeners for remotesapi servers if we never get to the run step.

This commit is contained in:
Aaron Son
2023-11-15 11:58:04 -08:00
parent a6c4815eee
commit ccb5cd131a
2 changed files with 44 additions and 21 deletions

View File

@@ -386,13 +386,18 @@ func Serve(
}
controller.Register(RunMetricsServer)
var remoteSrv *remotesrv.Server
var remoteSrvListeners remotesrv.Listeners
type RemoteSrvService struct {
state svcs.ServiceState
lis remotesrv.Listeners
srv *remotesrv.Server
}
var remoteSrv RemoteSrvService
RunRemoteSrv := &svcs.AnonService{
InitF: func(ctx context.Context) error {
if serverConfig.RemotesapiPort() == nil {
return nil
}
remoteSrv.state.Swap(svcs.ServiceState_Init)
port := *serverConfig.RemotesapiPort()
listenaddr := fmt.Sprintf(":%d", port)
@@ -412,12 +417,12 @@ func Serve(
args = sqle.WithUserPasswordAuth(args, authenticator)
args.TLSConfig = serverConf.TLSConfig
remoteSrv, err = remotesrv.NewServer(args)
remoteSrv.srv, err = remotesrv.NewServer(args)
if err != nil {
lgr.Errorf("error creating remotesapi server on port %d: %v", port, err)
return err
}
remoteSrvListeners, err = remoteSrv.Listeners()
remoteSrv.lis, err = remoteSrv.srv.Listeners()
if err != nil {
lgr.Errorf("error starting remotesapi server listeners on port %d: %v", port, err)
return err
@@ -425,28 +430,29 @@ func Serve(
return nil
},
RunF: func(ctx context.Context) {
if remoteSrv == nil {
return
if remoteSrv.state.CompareAndSwap(svcs.ServiceState_Init, svcs.ServiceState_Run) {
remoteSrv.srv.Serve(remoteSrv.lis)
}
remoteSrv.Serve(remoteSrvListeners)
},
StopF: func() error {
if remoteSrv == nil {
return nil
state := remoteSrv.state.Swap(svcs.ServiceState_Stopped)
if state == svcs.ServiceState_Run {
remoteSrv.srv.GracefulStop()
} else if state == svcs.ServiceState_Init {
remoteSrv.lis.Close()
}
remoteSrv.GracefulStop()
return nil
},
}
controller.Register(RunRemoteSrv)
var clusterRemoteSrv *remotesrv.Server
var clusterRemoteSrvListeners remotesrv.Listeners
var clusterRemoteSrv RemoteSrvService
RunClusterRemoteSrv := &svcs.AnonService{
InitF: func(context.Context) error {
if clusterController == nil {
return nil
}
clusterRemoteSrv.state.Swap(svcs.ServiceState_Init)
args, err := clusterController.RemoteSrvServerArgs(sqlEngine.NewDefaultContext, remotesrv.ServerArgs{
Logger: logrus.NewEntry(lgr),
@@ -463,14 +469,14 @@ func Serve(
}
args.TLSConfig = clusterRemoteSrvTLSConfig
clusterRemoteSrv, err = remotesrv.NewServer(args)
clusterRemoteSrv.srv, err = remotesrv.NewServer(args)
if err != nil {
lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err)
return err
}
clusterController.RegisterGrpcServices(sqlEngine.NewDefaultContext, clusterRemoteSrv.GrpcServer())
clusterController.RegisterGrpcServices(sqlEngine.NewDefaultContext, clusterRemoteSrv.srv.GrpcServer())
clusterRemoteSrvListeners, err = clusterRemoteSrv.Listeners()
clusterRemoteSrv.lis, err = clusterRemoteSrv.srv.Listeners()
if err != nil {
lgr.Errorf("error starting remotesapi server listeners for cluster config on %s: %v", clusterController.RemoteSrvListenAddr(), err)
return err
@@ -478,16 +484,17 @@ func Serve(
return nil
},
RunF: func(context.Context) {
if clusterRemoteSrv == nil {
return
if clusterRemoteSrv.state.CompareAndSwap(svcs.ServiceState_Init, svcs.ServiceState_Run) {
clusterRemoteSrv.srv.Serve(clusterRemoteSrv.lis)
}
clusterRemoteSrv.Serve(clusterRemoteSrvListeners)
},
StopF: func() error {
if clusterRemoteSrv == nil {
return nil
state := clusterRemoteSrv.state.Swap(svcs.ServiceState_Stopped)
if state == svcs.ServiceState_Run {
clusterRemoteSrv.srv.GracefulStop()
} else if state == svcs.ServiceState_Init {
clusterRemoteSrv.lis.Close()
}
clusterRemoteSrv.GracefulStop()
return nil
},
}

View File

@@ -137,6 +137,22 @@ type Listeners struct {
grpc net.Listener
}
func (l Listeners) Close() error {
if l.http != nil {
err := l.http.Close()
if err != nil {
if l.grpc != nil {
l.grpc.Close()
}
return err
}
}
if l.grpc != nil {
return l.grpc.Close()
}
return nil
}
func (s *Server) Listeners() (Listeners, error) {
var httpListener net.Listener
var grpcListener net.Listener