diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index f4413fb547..3b3075f259 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -386,13 +386,18 @@ func Serve( } controller.Register(RunMetricsServer) - var remoteSrv *remotesrv.Server - var remoteSrvListeners remotesrv.Listeners + type RemoteSrvService struct { + state svcs.ServiceState + lis remotesrv.Listeners + srv *remotesrv.Server + } + var remoteSrv RemoteSrvService RunRemoteSrv := &svcs.AnonService{ InitF: func(ctx context.Context) error { if serverConfig.RemotesapiPort() == nil { return nil } + remoteSrv.state.Swap(svcs.ServiceState_Init) port := *serverConfig.RemotesapiPort() listenaddr := fmt.Sprintf(":%d", port) @@ -412,12 +417,12 @@ func Serve( args = sqle.WithUserPasswordAuth(args, authenticator) args.TLSConfig = serverConf.TLSConfig - remoteSrv, err = remotesrv.NewServer(args) + remoteSrv.srv, err = remotesrv.NewServer(args) if err != nil { lgr.Errorf("error creating remotesapi server on port %d: %v", port, err) return err } - remoteSrvListeners, err = remoteSrv.Listeners() + remoteSrv.lis, err = remoteSrv.srv.Listeners() if err != nil { lgr.Errorf("error starting remotesapi server listeners on port %d: %v", port, err) return err @@ -425,28 +430,29 @@ func Serve( return nil }, RunF: func(ctx context.Context) { - if remoteSrv == nil { - return + if remoteSrv.state.CompareAndSwap(svcs.ServiceState_Init, svcs.ServiceState_Run) { + remoteSrv.srv.Serve(remoteSrv.lis) } - remoteSrv.Serve(remoteSrvListeners) }, StopF: func() error { - if remoteSrv == nil { - return nil + state := remoteSrv.state.Swap(svcs.ServiceState_Stopped) + if state == svcs.ServiceState_Run { + remoteSrv.srv.GracefulStop() + } else if state == svcs.ServiceState_Init { + remoteSrv.lis.Close() } - remoteSrv.GracefulStop() return nil }, } controller.Register(RunRemoteSrv) - var clusterRemoteSrv *remotesrv.Server - var clusterRemoteSrvListeners remotesrv.Listeners + var clusterRemoteSrv RemoteSrvService RunClusterRemoteSrv := &svcs.AnonService{ InitF: func(context.Context) error { if clusterController == nil { return nil } + clusterRemoteSrv.state.Swap(svcs.ServiceState_Init) args, err := clusterController.RemoteSrvServerArgs(sqlEngine.NewDefaultContext, remotesrv.ServerArgs{ Logger: logrus.NewEntry(lgr), @@ -463,14 +469,14 @@ func Serve( } args.TLSConfig = clusterRemoteSrvTLSConfig - clusterRemoteSrv, err = remotesrv.NewServer(args) + clusterRemoteSrv.srv, err = remotesrv.NewServer(args) if err != nil { lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err) return err } - clusterController.RegisterGrpcServices(sqlEngine.NewDefaultContext, clusterRemoteSrv.GrpcServer()) + clusterController.RegisterGrpcServices(sqlEngine.NewDefaultContext, clusterRemoteSrv.srv.GrpcServer()) - clusterRemoteSrvListeners, err = clusterRemoteSrv.Listeners() + clusterRemoteSrv.lis, err = clusterRemoteSrv.srv.Listeners() if err != nil { lgr.Errorf("error starting remotesapi server listeners for cluster config on %s: %v", clusterController.RemoteSrvListenAddr(), err) return err @@ -478,16 +484,17 @@ func Serve( return nil }, RunF: func(context.Context) { - if clusterRemoteSrv == nil { - return + if clusterRemoteSrv.state.CompareAndSwap(svcs.ServiceState_Init, svcs.ServiceState_Run) { + clusterRemoteSrv.srv.Serve(clusterRemoteSrv.lis) } - clusterRemoteSrv.Serve(clusterRemoteSrvListeners) }, StopF: func() error { - if clusterRemoteSrv == nil { - return nil + state := clusterRemoteSrv.state.Swap(svcs.ServiceState_Stopped) + if state == svcs.ServiceState_Run { + clusterRemoteSrv.srv.GracefulStop() + } else if state == svcs.ServiceState_Init { + clusterRemoteSrv.lis.Close() } - clusterRemoteSrv.GracefulStop() return nil }, } diff --git a/go/libraries/doltcore/remotesrv/server.go b/go/libraries/doltcore/remotesrv/server.go index a9fa8428ff..b3d5b37eca 100644 --- a/go/libraries/doltcore/remotesrv/server.go +++ b/go/libraries/doltcore/remotesrv/server.go @@ -137,6 +137,22 @@ type Listeners struct { grpc net.Listener } +func (l Listeners) Close() error { + if l.http != nil { + err := l.http.Close() + if err != nil { + if l.grpc != nil { + l.grpc.Close() + } + return err + } + } + if l.grpc != nil { + return l.grpc.Close() + } + return nil +} + func (s *Server) Listeners() (Listeners, error) { var httpListener net.Listener var grpcListener net.Listener