// Copyright 2019-2020 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sqlserver

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"github.com/dolthub/go-mysql-server/eventscheduler"
	"github.com/dolthub/go-mysql-server/server"
	"github.com/dolthub/go-mysql-server/server/golden"
	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/mysql_db"
	"github.com/dolthub/go-mysql-server/sql/plan"
	"github.com/dolthub/go-mysql-server/sql/types"
	"github.com/dolthub/vitess/go/mysql"
	"github.com/google/uuid"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/sirupsen/logrus"
	goerrors "gopkg.in/src-d/go-errors.v1"

	"github.com/dolthub/dolt/go/cmd/dolt/cli"
	"github.com/dolthub/dolt/go/cmd/dolt/commands"
	"github.com/dolthub/dolt/go/cmd/dolt/commands/engine"
	eventsapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi/v1alpha1"
	remotesapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/remotesapi/v1alpha1"
	"github.com/dolthub/dolt/go/libraries/doltcore/dconfig"
	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
	"github.com/dolthub/dolt/go/libraries/doltcore/env"
	"github.com/dolthub/dolt/go/libraries/doltcore/remotesrv"
	"github.com/dolthub/dolt/go/libraries/doltcore/servercfg"
	"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/binlogreplication"
	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/cluster"
	_ "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dfunctions"
	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
	"github.com/dolthub/dolt/go/libraries/doltcore/sqlserver"
	"github.com/dolthub/dolt/go/libraries/events"
	"github.com/dolthub/dolt/go/libraries/utils/config"
	"github.com/dolthub/dolt/go/libraries/utils/filesys"
	"github.com/dolthub/dolt/go/libraries/utils/svcs"
)

const (
	// LocalConnectionUser is the name of the ephemeral superuser that ConfigureServices adds for
	// "localhost", authenticated with the secret from the server's local credentials.
	LocalConnectionUser = "__dolt_local_user__"
	// ApiSqleContextKey is the context key under which ApiAuthenticate stores the request's
	// *sql.Context and from which ApiAuthorize reads it back.
	ApiSqleContextKey = "__sqle_context__"
)

// sqlServerHeartbeatIntervalEnvVar names the environment variable that holds the duration between
// heartbeats sent to the remote server, used for testing. The value is parsed with
// time.ParseDuration; when the variable is unset the interval is "24h".
const sqlServerHeartbeatIntervalEnvVar = "DOLT_SQL_SERVER_HEARTBEAT_INTERVAL"

// ExternalDisableUsers is set by implementing applications to disable users. This is not used by Dolt itself,
// but will break compatibility with implementing applications that do not yet support users.
// When true, ConfigureServices disables the engine's MySQLDb during startup.
var ExternalDisableUsers bool = false

// ErrCouldNotLockDatabase is returned during startup for a database that is in read-only access
// mode; its single format argument is the database name.
var ErrCouldNotLockDatabase = goerrors.NewKind("database \"%s\" is locked by another dolt process; either clone the database to run a second server, or stop the dolt process which currently holds an exclusive write lock on the database")

// Serve starts a MySQL-compatible server. Returns any errors that were encountered.
func Serve( ctx context.Context, version string, serverConfig servercfg.ServerConfig, controller *svcs.Controller, dEnv *env.DoltEnv, skipRootUserInitialization bool, ) (startError error, closeError error) { // Code is easier to work through if we assume that serverController is never nil if controller == nil { controller = svcs.NewController() } ConfigureServices(serverConfig, controller, version, dEnv, skipRootUserInitialization) go controller.Start(ctx) err := controller.WaitForStart() if err != nil { return err, nil } return nil, controller.WaitForStop() } func ConfigureServices( serverConfig servercfg.ServerConfig, controller *svcs.Controller, version string, dEnv *env.DoltEnv, skipRootUserInitialization bool, ) { ValidateConfigStep := &svcs.AnonService{ InitF: func(context.Context) error { return servercfg.ValidateConfig(serverConfig) }, } controller.Register(ValidateConfigStep) lgr := logrus.StandardLogger() lgr.SetOutput(cli.CliErr) InitLogging := &svcs.AnonService{ InitF: func(context.Context) error { level, err := logrus.ParseLevel(serverConfig.LogLevel().String()) if err != nil { return err } logrus.SetLevel(level) format := strings.ToLower(fmt.Sprintf("%v", serverConfig.LogFormat())) switch format { case "json": logrus.SetFormatter(&logrus.JSONFormatter{}) default: logrus.SetFormatter(&logrus.TextFormatter{}) } sql.SystemVariables.AddSystemVariables([]sql.SystemVariable{ &sql.MysqlSystemVariable{ Name: dsess.DoltLogLevel, Scope: sql.GetMysqlScope(sql.SystemVariableScope_Global), Dynamic: true, SetVarHintApplies: false, Type: types.NewSystemEnumType(dsess.DoltLogLevel, logrus.PanicLevel.String(), logrus.FatalLevel.String(), logrus.ErrorLevel.String(), logrus.WarnLevel.String(), logrus.InfoLevel.String(), logrus.DebugLevel.String(), logrus.TraceLevel.String(), ), Default: logrus.GetLevel().String(), NotifyChanged: func(_ sql.SystemVariableScope, v sql.SystemVarValue) error { level, err := logrus.ParseLevel(v.Val.(string)) if err != nil { return 
fmt.Errorf("could not parse requested log level %s as a log level. dolt_log_level variable value and logging behavior will diverge.", v.Val.(string)) } logrus.SetLevel(level) return nil }, }, }) return nil }, } controller.Register(InitLogging) controller.Register(newHeartbeatService(version, dEnv)) fs := dEnv.FS InitFailsafes := &svcs.AnonService{ InitF: func(ctx context.Context) (err error) { dEnv.Config.SetFailsafes(env.DefaultFailsafeConfig) return nil }, } controller.Register(InitFailsafes) var mrEnv *env.MultiRepoEnv InitMultiEnv := &svcs.AnonService{ InitF: func(ctx context.Context) (err error) { mrEnv, err = env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), fs, dEnv.Version, dEnv) return err }, } controller.Register(InitMultiEnv) AssertNoDatabasesInAccessModeReadOnly := &svcs.AnonService{ InitF: func(ctx context.Context) (err error) { return mrEnv.Iter(func(name string, dEnv *env.DoltEnv) (stop bool, err error) { if dEnv.IsAccessModeReadOnly(ctx) { return true, ErrCouldNotLockDatabase.New(name) } return false, nil }) }, } controller.Register(AssertNoDatabasesInAccessModeReadOnly) var localCreds *LocalCreds InitServerLocalCreds := &svcs.AnonService{ InitF: func(context.Context) (err error) { localCreds, err = persistServerLocalCreds(serverConfig.Port(), dEnv) return err }, StopF: func() error { RemoveLocalCreds(dEnv.FS) return nil }, } controller.Register(InitServerLocalCreds) var clusterController *cluster.Controller InitClusterController := &svcs.AnonService{ InitF: func(context.Context) (err error) { clusterController, err = cluster.NewController(lgr, serverConfig.ClusterConfig(), mrEnv.Config()) return err }, } controller.Register(InitClusterController) var serverConf server.Config LoadServerConfig := &svcs.AnonService{ InitF: func(context.Context) (err error) { serverConf, err = getConfigFromServerConfig(serverConfig) return err }, } controller.Register(LoadServerConfig) // Create SQL Engine with users var config *engine.SqlEngineConfig 
InitSqlEngineConfig := &svcs.AnonService{ InitF: func(context.Context) error { config = &engine.SqlEngineConfig{ IsReadOnly: serverConfig.ReadOnly(), PrivFilePath: serverConfig.PrivilegeFilePath(), BranchCtrlFilePath: serverConfig.BranchControlFilePath(), DoltCfgDirPath: serverConfig.CfgDir(), ServerUser: serverConfig.User(), ServerPass: serverConfig.Password(), ServerHost: serverConfig.Host(), Autocommit: serverConfig.AutoCommit(), DoltTransactionCommit: serverConfig.DoltTransactionCommit(), JwksConfig: serverConfig.JwksConfig(), SystemVariables: serverConfig.SystemVars(), ClusterController: clusterController, BinlogReplicaController: binlogreplication.DoltBinlogReplicaController, SkipRootUserInitialization: skipRootUserInitialization, } return nil }, } controller.Register(InitSqlEngineConfig) var esStatus eventscheduler.SchedulerStatus InitEventSchedulerStatus := &svcs.AnonService{ InitF: func(context.Context) (err error) { esStatus, err = getEventSchedulerStatus(serverConfig.EventSchedulerStatus()) if err != nil { return err } config.EventSchedulerStatus = esStatus return nil }, } controller.Register(InitEventSchedulerStatus) InitAutoGCController := &svcs.AnonService{ InitF: func(context.Context) error { if serverConfig.AutoGCBehavior() != nil && serverConfig.AutoGCBehavior().Enable() { config.AutoGCController = sqle.NewAutoGCController(lgr) } return nil }, } controller.Register(InitAutoGCController) var sqlEngine *engine.SqlEngine InitSqlEngine := &svcs.AnonService{ InitF: func(ctx context.Context) (err error) { if statsOn, err := mrEnv.Config().GetString(env.SqlServerGlobalsPrefix + "." + dsess.DoltStatsAutoRefreshEnabled); err != nil { // Auto-stats is off by default for every command except // sql-server. Unless the config specifies a specific // behavior, enable server stats collection. 
sql.SystemVariables.SetGlobal(dsess.DoltStatsAutoRefreshEnabled, 1) } else if statsOn != "0" { // do not bootstrap if auto-stats enabled } else if _, err := mrEnv.Config().GetString(env.SqlServerGlobalsPrefix + "." + dsess.DoltStatsBootstrapEnabled); err != nil { // If we've disabled stats collection and config does not // specify bootstrap behavior, enable bootstrapping. sql.SystemVariables.SetGlobal(dsess.DoltStatsBootstrapEnabled, 1) } sqlEngine, err = engine.NewSqlEngine( ctx, mrEnv, config, ) return err }, StopF: func() error { sqlEngine.Close() return nil }, } controller.Register(InitSqlEngine) // Persist any system variables that have a non-deterministic default value (i.e. @@server_uuid) // We only do this on sql-server startup initially since we want to keep the persisted server_uuid // in the configuration files for a sql-server, and not global for the whole host. PersistNondeterministicSystemVarDefaults := &svcs.AnonService{ InitF: func(ctx context.Context) error { err := dsess.PersistSystemVarDefaults(dEnv) if err != nil { logrus.Errorf("unable to persist system variable defaults: %v", err) } // Always return nil, because we don't want an invalid config value to prevent // the server from starting up. 
return nil }, } controller.Register(PersistNondeterministicSystemVarDefaults) InitBinlogging := &svcs.AnonService{ InitF: func(context.Context) error { primaryController := sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.BinlogPrimaryController doltBinlogPrimaryController, ok := primaryController.(*binlogreplication.DoltBinlogPrimaryController) if !ok { return fmt.Errorf("unexpected type of binlog controller: %T", primaryController) } _, logBinValue, ok := sql.SystemVariables.GetGlobal("log_bin") if !ok { return fmt.Errorf("unable to load @@log_bin system variable") } logBin, ok := logBinValue.(int8) if !ok { return fmt.Errorf("unexpected type for @@log_bin system variable: %T", logBinValue) } _, logBinBranchValue, ok := sql.SystemVariables.GetGlobal("log_bin_branch") if !ok { return fmt.Errorf("unable to load @@log_bin_branch system variable") } logBinBranch, ok := logBinBranchValue.(string) if !ok { return fmt.Errorf("unexpected type for @@log_bin_branch system variable: %T", logBinBranchValue) } if logBinBranch != "" { // If an invalid branch has been configured, let the server start up so that it's // easier for customers to correct the value, but log a warning and don't enable // binlog replication. if strings.Contains(logBinBranch, "/") { logrus.Warnf("branch names containing '/' are not supported "+ "for binlog replication. 
Not enabling binlog replication; fix "+ "@@log_bin_branch value and restart Dolt (current value: %s)", logBinBranch) return nil } binlogreplication.BinlogBranch = logBinBranch } if logBin == 1 { logrus.Infof("Enabling binary logging for branch %s", logBinBranch) binlogProducer, err := binlogreplication.NewBinlogProducer(dEnv.FS) if err != nil { return err } logManager, err := binlogreplication.NewLogManager(fs) if err != nil { return err } binlogProducer.LogManager(logManager) doltdb.RegisterDatabaseUpdateListener(binlogProducer) doltBinlogPrimaryController.BinlogProducer(binlogProducer) // Register binlog hooks for database creation/deletion provider := sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.DbProvider if doltProvider, ok := provider.(*sqle.DoltDatabaseProvider); ok { doltProvider.AddInitDatabaseHook(binlogreplication.NewBinlogInitDatabaseHook(nil, doltdb.DatabaseUpdateListeners)) doltProvider.AddDropDatabaseHook(binlogreplication.NewBinlogDropDatabaseHook(nil, doltdb.DatabaseUpdateListeners)) } } return nil }, } controller.Register(InitBinlogging) // MySQL creates a root superuser when the mysql install is first initialized. Depending on the options // specified, the root superuser is created without a password, or with a random password. This varies // slightly in some OS-specific installers. Dolt initializes the root superuser the first time a // sql-server is started and initializes its privileges database. We do this on sql-server initialization, // instead of dolt db initialization, because we only want to create the privileges database when it's // used for a server, and because we want the same root initialization logic when a sql-server is started // for a clone. 
More details: https://dev.mysql.com/doc/mysql-security-excerpt/8.0/en/default-privileges.html InitImplicitRootSuperUser := &svcs.AnonService{ InitF: func(ctx context.Context) error { // If privileges.db has already been initialized, indicating that this is NOT the // first time sql-server has been launched, then don't initialize the root superuser. if permissionDbExists, err := doesPrivilegesDbExist(dEnv, serverConfig.PrivilegeFilePath()); err != nil { return err } else if permissionDbExists { logrus.Debug("privileges.db already exists, not creating root superuser") return nil } // We always persist the privileges.db file, to signal that the privileges system has been initialized mysqlDb := sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb ed := mysqlDb.Editor() defer ed.Close() // Create the root@localhost superuser, unless --skip-root-user-initialization was specified if !config.SkipRootUserInitialization { // Allow the user to override the default root host (localhost) and password (""). // This is particularly useful in a Docker container, where you need to connect // to the sql-server from outside the container and can't rely on localhost. rootHost := "localhost" doltRootHost := os.Getenv(dconfig.EnvDoltRootHost) if doltRootHost != "" { logrus.Infof("Overriding root user host with value from DOLT_ROOT_HOST: %s", doltRootHost) rootHost = doltRootHost } rootPassword := servercfg.DefaultPass doltRootPassword := os.Getenv(dconfig.EnvDoltRootPassword) if doltRootPassword != "" { logrus.Info("Overriding root user password with value from DOLT_ROOT_PASSWORD") rootPassword = doltRootPassword } logrus.Infof("Creating root@%s superuser", rootHost) mysqlDb.AddSuperUser(ed, servercfg.DefaultUser, rootHost, rootPassword) } // TODO: The in-memory filesystem doesn't work with the GMS API // for persisting the privileges database. The filesys API // is in the Dolt layer, so when the file path is passed to // GMS, it expects it to be a path on disk, and errors out. 
if _, isInMemFs := dEnv.FS.(*filesys.InMemFS); isInMemFs { return nil } else { sqlCtx, err := sqlEngine.NewDefaultContext(context.Background()) if err != nil { return err } return mysqlDb.Persist(sqlCtx, ed) } }, } controller.Register(InitImplicitRootSuperUser) var metListener *metricsListener InitMetricsListener := &svcs.AnonService{ InitF: func(context.Context) (err error) { labels := serverConfig.MetricsLabels() metListener, err = newMetricsListener(labels, version, clusterController) return err }, StopF: func() error { metListener.Close() return nil }, } controller.Register(InitMetricsListener) InitLockSuperUser := &svcs.AnonService{ InitF: func(context.Context) error { mysqlDb := sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb ed := mysqlDb.Editor() mysqlDb.AddEphemeralSuperUser(ed, LocalConnectionUser, "localhost", localCreds.Secret) ed.Close() return nil }, } controller.Register(InitLockSuperUser) DisableMySQLDbIfRequired := &svcs.AnonService{ InitF: func(context.Context) error { if ExternalDisableUsers { mysqlDb := sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb mysqlDb.SetEnabled(false) } return nil }, } controller.Register(DisableMySQLDbIfRequired) type SQLMetricsService struct { state svcs.ServiceState lis net.Listener srv *http.Server } var metSrv SQLMetricsService RunMetricsServer := &svcs.AnonService{ InitF: func(context.Context) (err error) { if serverConfig.MetricsHost() != "" && serverConfig.MetricsPort() > 0 { metSrv.state.Swap(svcs.ServiceState_Init) addr := fmt.Sprintf("%s:%d", serverConfig.MetricsHost(), serverConfig.MetricsPort()) metSrv.lis, err = net.Listen("tcp", addr) if err != nil { return err } mux := http.NewServeMux() mux.Handle("/metrics", promhttp.Handler()) metSrv.srv = &http.Server{ Addr: addr, Handler: mux, } } return nil }, RunF: func(context.Context) { if metSrv.state.CompareAndSwap(svcs.ServiceState_Init, svcs.ServiceState_Run) { _ = metSrv.srv.Serve(metSrv.lis) } }, StopF: func() error { state := 
metSrv.state.Swap(svcs.ServiceState_Stopped) if state == svcs.ServiceState_Run { metSrv.srv.Close() } else if state == svcs.ServiceState_Init { metSrv.lis.Close() } return nil }, } controller.Register(RunMetricsServer) type RemoteSrvService struct { state svcs.ServiceState lis remotesrv.Listeners srv *remotesrv.Server } var remoteSrv RemoteSrvService RunRemoteSrv := &svcs.AnonService{ InitF: func(ctx context.Context) error { if serverConfig.RemotesapiPort() == nil { return nil } remoteSrv.state.Swap(svcs.ServiceState_Init) port := *serverConfig.RemotesapiPort() apiReadOnly := false if serverConfig.RemotesapiReadOnly() != nil { apiReadOnly = *serverConfig.RemotesapiReadOnly() } listenaddr := fmt.Sprintf(":%d", port) sqlContextInterceptor := sqle.SqlContextServerInterceptor{ Factory: sqlEngine.NewDefaultContext, } args := remotesrv.ServerArgs{ Logger: logrus.NewEntry(lgr), ReadOnly: apiReadOnly || serverConfig.ReadOnly(), HttpListenAddr: listenaddr, GrpcListenAddr: listenaddr, ConcurrencyControl: remotesapi.PushConcurrencyControl_PUSH_CONCURRENCY_CONTROL_ASSERT_WORKING_SET, Options: sqlContextInterceptor.Options(), HttpInterceptor: sqlContextInterceptor.HTTP(nil), } var err error args.FS = sqlEngine.FileSystem() args.DBCache, err = sqle.RemoteSrvDBCache(sqle.GetInterceptorSqlContext, sqle.DoNotCreateUnknownDatabases) if err != nil { lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err) return err } authenticator := newAccessController(sqle.GetInterceptorSqlContext, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb) args = sqle.WithUserPasswordAuth(args, authenticator) args.TLSConfig = serverConf.TLSConfig remoteSrv.srv, err = remotesrv.NewServer(args) if err != nil { lgr.Errorf("error creating remotesapi server on port %d: %v", port, err) return err } remoteSrv.lis, err = remoteSrv.srv.Listeners() if err != nil { lgr.Errorf("error starting remotesapi server listeners on port %d: %v", port, err) return err } return nil }, RunF: func(ctx 
context.Context) { if remoteSrv.state.CompareAndSwap(svcs.ServiceState_Init, svcs.ServiceState_Run) { remoteSrv.srv.Serve(remoteSrv.lis) } }, StopF: func() error { state := remoteSrv.state.Swap(svcs.ServiceState_Stopped) if state == svcs.ServiceState_Run { remoteSrv.srv.GracefulStop() } else if state == svcs.ServiceState_Init { remoteSrv.lis.Close() } return nil }, } controller.Register(RunRemoteSrv) var clusterRemoteSrv RemoteSrvService RunClusterRemoteSrv := &svcs.AnonService{ InitF: func(context.Context) error { if clusterController == nil { return nil } clusterRemoteSrv.state.Swap(svcs.ServiceState_Init) args, err := clusterController.RemoteSrvServerArgs(sqlEngine.NewDefaultContext, remotesrv.ServerArgs{ Logger: logrus.NewEntry(lgr), }) if err != nil { lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err) return err } args.FS = sqlEngine.FileSystem() clusterRemoteSrvTLSConfig, err := LoadClusterTLSConfig(serverConfig.ClusterConfig()) if err != nil { lgr.Errorf("error starting remotesapi server for cluster config, could not load tls config: %v", err) return err } args.TLSConfig = clusterRemoteSrvTLSConfig clusterRemoteSrv.srv, err = remotesrv.NewServer(args) if err != nil { lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err) return err } clusterController.RegisterGrpcServices(sqle.GetInterceptorSqlContext, clusterRemoteSrv.srv.GrpcServer()) clusterRemoteSrv.lis, err = clusterRemoteSrv.srv.Listeners() if err != nil { lgr.Errorf("error starting remotesapi server listeners for cluster config on %s: %v", clusterController.RemoteSrvListenAddr(), err) return err } return nil }, RunF: func(context.Context) { if clusterRemoteSrv.state.CompareAndSwap(svcs.ServiceState_Init, svcs.ServiceState_Run) { clusterRemoteSrv.srv.Serve(clusterRemoteSrv.lis) } }, StopF: func() error { state := clusterRemoteSrv.state.Swap(svcs.ServiceState_Stopped) if state == svcs.ServiceState_Run { 
clusterRemoteSrv.srv.GracefulStop() } else if state == svcs.ServiceState_Init { clusterRemoteSrv.lis.Close() } return nil }, } controller.Register(RunClusterRemoteSrv) // We still have some startup to do from this point, and we do not run // the SQL server until we are fully booted. We also want to stop the // SQL server as the first thing we stop. However, if startup fails // during initialization, we want to shutdown the SQL server cleanly. // So we track whether the server has been shutdown by either service // which is responsible for it and we only do it here if it hasn't // already been Closed. var sqlServerClosed bool var mySQLServer *server.Server InitSQLServer := &svcs.AnonService{ InitF: func(context.Context) (err error) { v, ok := serverConfig.(servercfg.ValidatingServerConfig) if ok && v.GoldenMysqlConnectionString() != "" { mySQLServer, err = server.NewServerWithHandler( serverConf, sqlEngine.GetUnderlyingEngine(), newSessionBuilder(sqlEngine, serverConfig), metListener, func(h mysql.Handler) (mysql.Handler, error) { return golden.NewValidatingHandler(h, v.GoldenMysqlConnectionString(), logrus.StandardLogger()) }, ) } else { mySQLServer, err = server.NewServer( serverConf, sqlEngine.GetUnderlyingEngine(), newSessionBuilder(sqlEngine, serverConfig), metListener, ) } if errors.Is(err, server.UnixSocketInUseError) { lgr.Warn("unix socket set up failed: file already in use: ", serverConf.Socket) err = nil } return err }, StopF: func() (err error) { if !sqlServerClosed { sqlServerClosed = true return mySQLServer.Close() } return nil }, } controller.Register(InitSQLServer) // Automatically restart binlog replication if replication was enabled when the server was last shut down AutoStartBinlogReplica := &svcs.AnonService{ InitF: func(ctx context.Context) error { // If we're unable to restart replication, log an error, but don't prevent the server from starting up sqlCtx, err := sqlEngine.NewDefaultContext(ctx) if err != nil { logrus.Errorf("unable to restart 
replication, could not create session: %s", err.Error()) return nil } defer sql.SessionEnd(sqlCtx.Session) if err := binlogreplication.DoltBinlogReplicaController.AutoStart(sqlCtx); err != nil { logrus.Errorf("unable to restart replication: %s", err.Error()) } return nil }, } controller.Register(AutoStartBinlogReplica) RunClusterController := &svcs.AnonService{ InitF: func(context.Context) error { if clusterController == nil { return nil } clusterController.ManageQueryConnections( mySQLServer.SessionManager().Iter, sqlEngine.GetUnderlyingEngine().ProcessList.Kill, mySQLServer.SessionManager().KillConnection, ) return nil }, RunF: func(context.Context) { if clusterController == nil { return } clusterController.Run() }, StopF: func() error { if clusterController == nil { return nil } clusterController.GracefulStop() return nil }, } controller.Register(RunClusterController) RunSQLServer := &svcs.AnonService{ RunF: func(context.Context) { sqlserver.SetRunningServer(mySQLServer) defer sqlserver.UnsetRunningServer() mySQLServer.Start() }, StopF: func() error { sqlServerClosed = true return mySQLServer.Close() }, } controller.Register(RunSQLServer) } // heartbeatService is a service that sends a heartbeat event to the metrics server once a day type heartbeatService struct { version string eventEmitter events.Emitter interval time.Duration closer func() error } func newHeartbeatService(version string, dEnv *env.DoltEnv) *heartbeatService { metricsDisabled := dEnv.Config.GetStringOrDefault(config.MetricsDisabled, "false") disabled, err := strconv.ParseBool(metricsDisabled) if err != nil || disabled { return &heartbeatService{} // will be defunct on Run() } emitterType, ok := os.LookupEnv(events.EmitterTypeEnvVar) if !ok { emitterType = events.EmitterTypeGrpc } interval, ok := os.LookupEnv(sqlServerHeartbeatIntervalEnvVar) if !ok { interval = "24h" } duration, err := time.ParseDuration(interval) if err != nil { return &heartbeatService{} // will be defunct on Run() } 
emitter, closer, err := commands.NewEmitter(emitterType, dEnv) if err != nil { return &heartbeatService{} // will be defunct on Run() } events.SetGlobalCollector(events.NewCollector(version, emitter)) return &heartbeatService{ version: version, eventEmitter: emitter, interval: duration, closer: closer, } } func (h *heartbeatService) Init(ctx context.Context) error { return nil } func (h *heartbeatService) Stop() error { if h.closer != nil { return h.closer() } return nil } func (h *heartbeatService) Run(ctx context.Context) { // Faulty config settings or disabled metrics can cause us to not have a valid event emitter if h.eventEmitter == nil { return } ticker := time.NewTicker(h.interval) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: t := events.NowTimestamp() logrus.Debugf("sending heartbeat event to %s:%s", events.DefaultMetricsHost, events.DefaultMetricsPort) err := h.eventEmitter.LogEvents(ctx, h.version, []*eventsapi.ClientEvent{ { Id: uuid.New().String(), StartTime: t, EndTime: t, Type: eventsapi.ClientEventType_SQL_SERVER_HEARTBEAT, }, }) if err != nil { logrus.Debugf("failed to send heartbeat event: %v", err) } } } } var _ svcs.Service = &heartbeatService{} func persistServerLocalCreds(port int, dEnv *env.DoltEnv) (*LocalCreds, error) { creds := NewLocalCreds(port) err := WriteLocalCreds(dEnv.FS, creds) if err != nil { return nil, err } return creds, err } // remotesapiAuth facilitates the implementation remotesrv.AccessControl for the remotesapi server. type remotesapiAuth struct { // ctxFactory is a function that returns a new sql.Context. This will create a new context every time it is called, // so it should be called once per API request. 
ctxFactory func(context.Context) (*sql.Context, error) rawDb *mysql_db.MySQLDb } func newAccessController(ctxFactory func(context.Context) (*sql.Context, error), rawDb *mysql_db.MySQLDb) remotesrv.AccessControl { return &remotesapiAuth{ctxFactory, rawDb} } // ApiAuthenticate checks the provided credentials against the database and return a SQL context if the credentials are // valid. If the credentials are invalid, then a nil context is returned. Failures to authenticate are logged. func (r *remotesapiAuth) ApiAuthenticate(ctx context.Context) (context.Context, error) { creds, err := remotesrv.ExtractBasicAuthCreds(ctx) if err != nil { return nil, err } err = commands.ValidatePasswordWithAuthResponse(r.rawDb, creds.Username, creds.Password) if err != nil { return nil, fmt.Errorf("API Authentication Failure: %v", err) } address := creds.Address if strings.Index(address, ":") > 0 { address, _, err = net.SplitHostPort(creds.Address) if err != nil { return nil, fmt.Errorf("Invalid Host string for authentication: %s", creds.Address) } } sqlCtx, err := r.ctxFactory(ctx) if err != nil { return nil, fmt.Errorf("API Runtime error: %v", err) } sqlCtx.Session.SetClient(sql.Client{User: creds.Username, Address: address, Capabilities: 0}) updatedCtx := context.WithValue(ctx, ApiSqleContextKey, sqlCtx) return updatedCtx, nil } func (r *remotesapiAuth) ApiAuthorize(ctx context.Context, superUserRequired bool) (bool, error) { sqlCtx, ok := ctx.Value(ApiSqleContextKey).(*sql.Context) if !ok { return false, fmt.Errorf("Runtime error: could not get SQL context from context") } privOp := sql.NewDynamicPrivilegedOperation(plan.DynamicPrivilege_CloneAdmin) if superUserRequired { database := sqlCtx.GetCurrentDatabase() subject := sql.PrivilegeCheckSubject{Database: database} privOp = sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Super) } authorized := r.rawDb.UserHasPrivileges(sqlCtx, privOp) if !authorized { if superUserRequired { return false, fmt.Errorf("API Authorization 
Failure: %s has not been granted SuperUser access", sqlCtx.Session.Client().User) } return false, fmt.Errorf("API Authorization Failure: %s has not been granted CLONE_ADMIN access", sqlCtx.Session.Client().User) } return true, nil } // doesPrivilegesDbExist looks for an existing privileges database as the specified |privilegeFilePath|. If // |privilegeFilePath| is an absolute path, it is used directly. If it is a relative path, then it is resolved // relative to the root of the specified |dEnv|. func doesPrivilegesDbExist(dEnv *env.DoltEnv, privilegeFilePath string) (exists bool, err error) { if !filepath.IsAbs(privilegeFilePath) { privilegeFilePath, err = dEnv.FS.Abs(privilegeFilePath) if err != nil { return false, err } } _, err = os.Stat(privilegeFilePath) if err != nil { if os.IsNotExist(err) { return false, nil } else { return false, err } } return true, nil } func LoadClusterTLSConfig(cfg servercfg.ClusterConfig) (*tls.Config, error) { rcfg := cfg.RemotesAPIConfig() if rcfg.TLSKey() == "" && rcfg.TLSCert() == "" { return nil, nil } c, err := tls.LoadX509KeyPair(rcfg.TLSCert(), rcfg.TLSKey()) if err != nil { return nil, err } return &tls.Config{ Certificates: []tls.Certificate{ c, }, }, nil } func portInUse(hostPort string) bool { timeout := time.Second conn, _ := net.DialTimeout("tcp", hostPort, timeout) if conn != nil { defer conn.Close() return true } return false } func newSessionBuilder(se *engine.SqlEngine, config servercfg.ServerConfig) server.SessionBuilder { userToSessionVars := make(map[string]map[string]interface{}) userVars := config.UserVars() for _, curr := range userVars { userToSessionVars[curr.Name] = curr.Vars } return func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) { baseSession, err := sql.BaseSessionFromConnection(ctx, conn, addr) if err != nil { return nil, err } dsess, err := se.NewDoltSession(ctx, baseSession) if err != nil { return nil, err } varsForUser := userToSessionVars[conn.User] if len(varsForUser) 
> 0 { sqlCtx, err := se.NewContext(ctx, dsess) if err != nil { return nil, err } for key, val := range varsForUser { err = dsess.InitSessionVariable(sqlCtx, key, val) if err != nil { return nil, err } } } return dsess, nil } } // getConfigFromServerConfig processes ServerConfig and returns server.Config for sql-server. func getConfigFromServerConfig(serverConfig servercfg.ServerConfig) (server.Config, error) { serverConf, err := handleProtocolAndAddress(serverConfig) if err != nil { return server.Config{}, err } serverConf.DisableClientMultiStatements = serverConfig.DisableClientMultiStatements() readTimeout := time.Duration(serverConfig.ReadTimeout()) * time.Millisecond writeTimeout := time.Duration(serverConfig.WriteTimeout()) * time.Millisecond tlsConfig, err := servercfg.LoadTLSConfig(serverConfig) if err != nil { return server.Config{}, err } serverConf, err = serverConf.NewConfig() if err != nil { return server.Config{}, err } // Do not set the value of Version. Let it default to what go-mysql-server uses. This should be equivalent // to the value of mysql that we support. serverConf.ConnReadTimeout = readTimeout serverConf.ConnWriteTimeout = writeTimeout serverConf.MaxConnections = serverConfig.MaxConnections() serverConf.TLSConfig = tlsConfig serverConf.RequireSecureTransport = serverConfig.RequireSecureTransport() serverConf.MaxLoggedQueryLen = serverConfig.MaxLoggedQueryLen() serverConf.EncodeLoggedQuery = serverConfig.ShouldEncodeLoggedQuery() return serverConf, nil } // handleProtocolAndAddress returns new server.Config object with only Protocol and Address defined. 
func handleProtocolAndAddress(serverConfig servercfg.ServerConfig) (server.Config, error) { serverConf := server.Config{Protocol: "tcp"} portAsString := strconv.Itoa(serverConfig.Port()) hostPort := net.JoinHostPort(serverConfig.Host(), portAsString) if portInUse(hostPort) { portInUseError := fmt.Errorf("Port %s already in use.", portAsString) return server.Config{}, portInUseError } serverConf.Address = hostPort sock, useSock, err := servercfg.CheckForUnixSocket(serverConfig) if err != nil { return server.Config{}, err } if useSock { serverConf.Socket = sock } return serverConf, nil } func getEventSchedulerStatus(status string) (eventscheduler.SchedulerStatus, error) { switch strings.ToLower(status) { case "on", "1": return eventscheduler.SchedulerOn, nil case "off", "0": return eventscheduler.SchedulerOff, nil case "disabled": return eventscheduler.SchedulerDisabled, nil default: return eventscheduler.SchedulerDisabled, fmt.Errorf("Error while setting value '%s' to 'event_scheduler'.", status) } }