diff --git a/go/cmd/dolt/commands/config.go b/go/cmd/dolt/commands/config.go index 8d6b40595a..3ab6521b07 100644 --- a/go/cmd/dolt/commands/config.go +++ b/go/cmd/dolt/commands/config.go @@ -29,9 +29,8 @@ import ( ) const ( - globalParamName = "global" - localParamName = "local" - + globalParamName = "global" + localParamName = "local" addOperationStr = "add" listOperationStr = "list" getOperationStr = "get" @@ -100,7 +99,7 @@ func (cmd ConfigCmd) Exec(ctx context.Context, commandStr string, args []string, cfgTypes := apr.FlagsEqualTo([]string{globalParamName, localParamName}, true) ops := apr.FlagsEqualTo([]string{addOperationStr, listOperationStr, getOperationStr, unsetOperationStr}, true) - if cfgTypes.Size() == 2 { + if cfgTypes.Size() > 1 { cli.PrintErrln(color.RedString("Specifying both -local and -global is not valid. Exactly one may be set")) usage() } else { @@ -145,8 +144,7 @@ func getOperation(dEnv *env.DoltEnv, setCfgTypes *set.StrSet, args []string, pri cfgTypesSl := setCfgTypes.AsSlice() for _, cfgType := range cfgTypesSl { - isGlobal := cfgType == globalParamName - if _, ok := dEnv.Config.GetConfig(newCfgElement(isGlobal)); !ok { + if _, ok := dEnv.Config.GetConfig(newCfgElement(cfgType)); !ok { cli.PrintErrln(color.RedString("Unable to read config.")) return 1 } @@ -157,8 +155,7 @@ func getOperation(dEnv *env.DoltEnv, setCfgTypes *set.StrSet, args []string, pri } for _, cfgType := range cfgTypesSl { - isGlobal := cfgType == globalParamName - cfg, ok := dEnv.Config.GetConfig(newCfgElement(isGlobal)) + cfg, ok := dEnv.Config.GetConfig(newCfgElement(cfgType)) if ok { if val, err := cfg.GetString(args[0]); err == nil { printFn(args[0], &val) @@ -180,34 +177,46 @@ func addOperation(dEnv *env.DoltEnv, setCfgTypes *set.StrSet, args []string, usa return 1 } - isGlobal := setCfgTypes.Contains(globalParamName) updates := make(map[string]string) - for i := 0; i < len(args); i += 2 { updates[strings.ToLower(args[i])] = args[i+1] } - if cfg, ok := dEnv.Config.GetConfig(newCfgElement(isGlobal)); !ok { - if !isGlobal { - err := dEnv.Config.CreateLocalConfig(updates) + var cfgType string + switch setCfgTypes.Size() { + case 0: + cfgType = localParamName + case 1: + cfgType = setCfgTypes.AsSlice()[0] + default: + cli.Println("error: cannot add to multiple configs simultaneously") + return 1 + } + cfg, ok := dEnv.Config.GetConfig(newCfgElement(cfgType)) + if !ok { + switch cfgType { + case globalParamName: + panic("Should not have been able to get this far without a global config.") + case localParamName: + err := dEnv.Config.CreateLocalConfig(updates) if err != nil { cli.PrintErrln(color.RedString("Unable to create repo local config file")) return 1 } - - } else { - panic("Should not have been able to get this far without a global config.") - } - } else { - err := cfg.SetStrings(updates) - - if err != nil { - cli.PrintErrln(color.RedString("Failed to update config.")) + return 0 + default: + cli.Println("error: unknown config flag") return 1 } } + err := cfg.SetStrings(updates) + if err != nil { + cli.PrintErrln(color.RedString("Failed to update config.")) + return 1 + } + cli.Println(color.CyanString("Config successfully updated.")) return 0 } @@ -223,8 +232,18 @@ func unsetOperation(dEnv *env.DoltEnv, setCfgTypes *set.StrSet, args []string, u args[i] = strings.ToLower(a) } - isGlobal := setCfgTypes.Contains(globalParamName) - if cfg, ok := dEnv.Config.GetConfig(newCfgElement(isGlobal)); !ok { + var cfgType string + switch setCfgTypes.Size() { + case 0: + cfgType = localParamName + case 1: + cfgType = setCfgTypes.AsSlice()[0] + default: + cli.Println("error: cannot unset from multiple configs simultaneously") + return 1 + } + + if cfg, ok := dEnv.Config.GetConfig(newCfgElement(cfgType)); !ok { cli.PrintErrln(color.RedString("Unable to read config.")) return 1 } else { @@ -249,8 +268,7 @@ func listOperation(dEnv *env.DoltEnv, setCfgTypes *set.StrSet, args []string, us cfgTypesSl := setCfgTypes.AsSlice() for _, cfgType := range cfgTypesSl { - isGlobal := cfgType == globalParamName - if _, ok := dEnv.Config.GetConfig(newCfgElement(isGlobal)); !ok { + if _, ok := dEnv.Config.GetConfig(newCfgElement(cfgType)); !ok { cli.PrintErrln(color.RedString("Unable to read config.")) return 1 } @@ -261,12 +279,9 @@ func listOperation(dEnv *env.DoltEnv, setCfgTypes *set.StrSet, args []string, us } for _, cfgType := range cfgTypesSl { - isGlobal := cfgType == globalParamName - cfg, ok := dEnv.Config.GetConfig(newCfgElement(isGlobal)) - if ok { - cfg.Iter(func(name string, val string) (stop bool) { + if cfg, ok := dEnv.Config.GetConfig(newCfgElement(cfgType)); ok { + cfg.Iter(func(name, val string) bool { printFn(name, val) - return false }) } @@ -275,10 +290,13 @@ func listOperation(dEnv *env.DoltEnv, setCfgTypes *set.StrSet, args []string, us return 0 } -func newCfgElement(isGlobal bool) env.DoltConfigElement { - if isGlobal { +func newCfgElement(configFlag string) env.ConfigScope { + switch configFlag { + case localParamName: + return env.LocalConfig + case globalParamName: return env.GlobalConfig + default: + return env.LocalConfig } - - return env.LocalConfig } diff --git a/go/cmd/dolt/commands/config_test.go b/go/cmd/dolt/commands/config_test.go index 194badb672..545b252f0f 100644 --- a/go/cmd/dolt/commands/config_test.go +++ b/go/cmd/dolt/commands/config_test.go @@ -19,6 +19,8 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/assert" + "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/utils/config" "github.com/dolthub/dolt/go/libraries/utils/set" @@ -26,7 +28,251 @@ import ( var globalCfg = set.NewStrSet([]string{globalParamName}) var localCfg = set.NewStrSet([]string{localParamName}) +var multiCfg = set.NewStrSet([]string{globalParamName, localParamName}) +func initializeConfigs(dEnv *env.DoltEnv, element env.ConfigScope) { + switch element { + case env.GlobalConfig: + globalCfg, _ := dEnv.Config.GetConfig(env.GlobalConfig) + globalCfg.SetStrings(map[string]string{"title": "senior dufus"}) + case env.LocalConfig: + dEnv.Config.CreateLocalConfig(map[string]string{"title": "senior dufus"}) + } +} +func TestConfigAdd(t *testing.T) { + tests := []struct { + Name string + CfgSet *set.StrSet + Scope env.ConfigScope + Args []string + Code int + }{ + { + Name: "local", + CfgSet: localCfg, + Scope: env.LocalConfig, + Args: []string{"title", "senior dufus"}, + }, + { + Name: "global", + CfgSet: globalCfg, + Scope: env.GlobalConfig, + Args: []string{"title", "senior dufus"}, + }, + { + Name: "default", + CfgSet: &set.StrSet{}, + Scope: env.LocalConfig, + Args: []string{"title", "senior dufus"}, + }, + { + Name: "multi error", + CfgSet: multiCfg, + Scope: env.LocalConfig, + Args: []string{"title", "senior dufus"}, + Code: 1, + }, + { + Name: "no args", + CfgSet: multiCfg, + Scope: env.LocalConfig, + Args: []string{}, + Code: 1, + }, + { + Name: "odd args", + CfgSet: multiCfg, + Scope: env.LocalConfig, + Args: []string{"title"}, + Code: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + dEnv := createTestEnv() + resCode := addOperation(dEnv, tt.CfgSet, tt.Args, func() {}) + + if tt.Code == 1 { + assert.Equal(t, tt.Code, resCode) + + } else if cfg, ok := dEnv.Config.GetConfig(tt.Scope); ok { + resVal := cfg.GetStringOrDefault("title", "") + assert.Equal(t, "senior dufus", resVal) + } else { + t.Error("comparison config not found") + } + }) + } +} + +func TestConfigGet(t *testing.T) { + tests := []struct { + Name string + CfgSet *set.StrSet + ConfigElem env.ConfigScope + Key string + Code int + }{ + { + Name: "local", + CfgSet: localCfg, + ConfigElem: env.LocalConfig, + Key: "title", + }, + { + Name: "global", + CfgSet: globalCfg, + ConfigElem: env.GlobalConfig, + Key: "title", + }, + { + Name: "default", + CfgSet: &set.StrSet{}, + ConfigElem: env.LocalConfig, + Key: "title", + }, + { + Name: "multi", + CfgSet: multiCfg, + ConfigElem: env.LocalConfig, + Key: "title", + }, + { + Name: "missing param", + CfgSet: multiCfg, + ConfigElem: env.LocalConfig, + Key: "unknown", + Code: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + dEnv := createTestEnv() + initializeConfigs(dEnv, tt.ConfigElem) + + var resVal string + resCode := getOperation(dEnv, tt.CfgSet, []string{tt.Key}, func(k string, v *string) { resVal = *v }) + + if tt.Code == 1 { + assert.Equal(t, tt.Code, resCode) + } else { + assert.Equal(t, "senior dufus", resVal) + } + }) + } +} + +func TestConfigUnset(t *testing.T) { + tests := []struct { + Name string + CfgSet *set.StrSet + ConfigElem env.ConfigScope + Key string + Code int + }{ + { + Name: "local", + CfgSet: localCfg, + ConfigElem: env.LocalConfig, + Key: "title", + }, + { + Name: "global", + CfgSet: globalCfg, + ConfigElem: env.GlobalConfig, + Key: "title", + }, + { + Name: "default", + CfgSet: &set.StrSet{}, + ConfigElem: env.LocalConfig, + Key: "title", + }, + { + Name: "multi", + CfgSet: multiCfg, + ConfigElem: env.LocalConfig, + Key: "title", + Code: 1, + }, + { + Name: "missing param", + CfgSet: multiCfg, + ConfigElem: env.LocalConfig, + Key: "unknown", + Code: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + dEnv := createTestEnv() + initializeConfigs(dEnv, tt.ConfigElem) + + resCode := unsetOperation(dEnv, tt.CfgSet, []string{tt.Key}, func() {}) + + if tt.Code == 1 { + assert.Equal(t, tt.Code, resCode) + } else if cfg, ok := dEnv.Config.GetConfig(tt.ConfigElem); ok { + _, err := cfg.GetString(tt.Key) + assert.Error(t, err) + } else { + t.Error("comparison config not found") + } + }) + } +} + +func TestConfigList(t *testing.T) { + tests := []struct { + Name string + CfgSet *set.StrSet + ConfigElem env.ConfigScope + }{ + { + Name: "local", + CfgSet: localCfg, + ConfigElem: env.LocalConfig, + }, + { + Name: "global", + CfgSet: globalCfg, + ConfigElem: env.GlobalConfig, + }, + { + Name: "default", + CfgSet: &set.StrSet{}, + ConfigElem: env.LocalConfig, + }, + { + Name: "multi", + CfgSet: multiCfg, + ConfigElem: env.LocalConfig, + }, + } + + keys := []string{"title"} + values := []string{"senior dufus"} + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + dEnv := createTestEnv() + initializeConfigs(dEnv, tt.ConfigElem) + + var resKeys []string + var resVals []string + resCode := listOperation(dEnv, tt.CfgSet, []string{}, func() {}, func(k, v string) { + resKeys = append(resKeys, k) + resVals = append(resVals, v) + }) + assert.Equal(t, 0, resCode) + assert.Equal(t, keys, resKeys) + assert.Equal(t, values, resVals) + }) + } +} func TestConfig(t *testing.T) { ctx := context.TODO() dEnv := createTestEnv() diff --git a/go/cmd/dolt/commands/filter-branch.go b/go/cmd/dolt/commands/filter-branch.go index 62e6ba503c..e2fbcc6a1f 100644 --- a/go/cmd/dolt/commands/filter-branch.go +++ b/go/cmd/dolt/commands/filter-branch.go @@ -40,6 +40,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" "github.com/dolthub/dolt/go/libraries/doltcore/table/editor" "github.com/dolthub/dolt/go/libraries/utils/argparser" + "github.com/dolthub/dolt/go/libraries/utils/config" "github.com/dolthub/dolt/go/libraries/utils/tracing" "github.com/dolthub/dolt/go/store/hash" ) @@ -233,7 +234,7 @@ func processFilterQuery(ctx context.Context, dEnv *env.DoltEnv, cm *doltdb.Commi // we set manually with the one at the working set of the HEAD being rebased. // Some functionality will not work on this kind of engine, e.g. many DOLT_ functions. func rebaseSqlEngine(ctx context.Context, dEnv *env.DoltEnv, cm *doltdb.Commit) (*sql.Context, *sqlEngine, error) { - sess := dsess.DefaultSession() + sess := dsess.DefaultSession().NewDoltSession(config.NewMapConfig(make(map[string]string))) sqlCtx := sql.NewContext(ctx, sql.WithSession(sess), diff --git a/go/cmd/dolt/commands/sql.go b/go/cmd/dolt/commands/sql.go index a4645751a9..e5c73ba247 100644 --- a/go/cmd/dolt/commands/sql.go +++ b/go/cmd/dolt/commands/sql.go @@ -113,6 +113,30 @@ func init() { Type: sql.NewSystemIntType(currentBatchModeKey, -9223372036854775808, 9223372036854775807, false), Default: int64(0), }, + { + Name: dsess.DoltDefaultBranchKey, + Scope: sql.SystemVariableScope_Global, + Dynamic: true, + SetVarHintApplies: false, + Type: sql.NewSystemStringType(dsess.DoltDefaultBranchKey), + Default: "", + }, + { + Name: doltdb.ReplicateToRemoteKey, + Scope: sql.SystemVariableScope_Global, + Dynamic: true, + SetVarHintApplies: false, + Type: sql.NewSystemStringType(doltdb.ReplicateToRemoteKey), + Default: "", + }, + { + Name: doltdb.DoltReadReplicaKey, + Scope: sql.SystemVariableScope_Global, + Dynamic: true, + SetVarHintApplies: false, + Type: sql.NewSystemStringType(doltdb.DoltReadReplicaKey), + Default: "", + }, }) } @@ -403,6 +427,7 @@ func execBatch( if err != nil { return errhand.VerboseErrorFromError(err) } + se, err := newSqlEngine(ctx, dEnv, roots, readOnly, format, dbs...) if err != nil { return errhand.VerboseErrorFromError(err) @@ -518,8 +543,20 @@ func CollectDBs(ctx context.Context, mrEnv env.MultiRepoEnv) ([]dsqle.SqlDatabas dbs := make([]dsqle.SqlDatabase, 0, len(mrEnv)) var db dsqle.SqlDatabase err := mrEnv.Iter(func(name string, dEnv *env.DoltEnv) (stop bool, err error) { + postCommitHooks, err := env.GetCommitHooks(ctx, dEnv) + if err != nil { + return true, err + } + dEnv.DoltDB.SetCommitHooks(ctx, postCommitHooks) + db = newDatabase(name, dEnv) - if remoteName := dEnv.Config.GetStringOrDefault(dsqle.DoltReadReplicaKey, ""); remoteName != "" { + + if _, val, ok := sql.SystemVariables.GetGlobal(doltdb.DoltReadReplicaKey); ok && val != "" { + remoteName, ok := val.(string) + if !ok { + return true, sql.ErrInvalidSystemVariableValue.New(val) + } + db, err = dsqle.NewReadReplicaDatabase(ctx, db.(dsqle.Database), remoteName, dEnv.RepoStateReader(), dEnv.TempTableFilesDir(), doltdb.TodoWorkingSetMeta()) if err != nil { return true, err @@ -1415,7 +1452,7 @@ func mergeResultIntoStats(statement sqlparser.Statement, rowIter sql.RowIter, s type sqlEngine struct { dbs map[string]dsqle.SqlDatabase - sess *dsess.Session + sess *dsess.DoltSession contextFactory func(ctx context.Context) (*sql.Context, error) engine *sqle.Engine resultFormat resultFormat @@ -1473,7 +1510,10 @@ func newSqlEngine( } // TODO: not having user and email for this command should probably be an error or warning, it disables certain functionality - sess, err := dsess.NewSession(sql.NewEmptyContext(), sql.NewBaseSession(), pro, dEnv.Config, dbStates...) + sess, err := dsess.NewDoltSession(sql.NewEmptyContext(), sql.NewBaseSession(), pro, dEnv.Config, dbStates...) + if err != nil { + return nil, err + } // TODO: this should just be the session default like it is with MySQL err = sess.SetSessionVariable(sql.NewContext(ctx), sql.AutoCommitSessionVar, true) @@ -1490,7 +1530,7 @@ func newSqlEngine( }, nil } -func newSqlContext(sess *dsess.Session, cat sql.Catalog) func(ctx context.Context) (*sql.Context, error) { +func newSqlContext(sess *dsess.DoltSession, cat sql.Catalog) func(ctx context.Context) (*sql.Context, error) { return func(ctx context.Context) (*sql.Context, error) { sqlCtx := sql.NewContext(ctx, sql.WithSession(sess), diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index e4b56bd384..0644a807fa 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -83,6 +83,15 @@ func Serve(ctx context.Context, version string, serverConfig ServerConfig, serve permissions = auth.ReadPerm } + serverConf := server.Config{Protocol: "tcp"} + + if serverConfig.PersistenceBehavior() == loadPerisistentGlobals { + serverConf, startError = serverConf.NewConfig() + if startError != nil { + return + } + } + userAuth := auth.NewNativeSingle(serverConfig.User(), serverConfig.Password(), permissions) var mrEnv env.MultiRepoEnv @@ -128,19 +137,18 @@ func Serve(ctx context.Context, version string, serverConfig ServerConfig, serve return nil, err } + // Do not set the value of Version. Let it default to what go-mysql-server uses. This should be equivalent + // to the value of mysql that we support. + serverConf.Address = hostPort + serverConf.Auth = userAuth + serverConf.ConnReadTimeout = readTimeout + serverConf.ConnWriteTimeout = writeTimeout + serverConf.MaxConnections = serverConfig.MaxConnections() + serverConf.TLSConfig = tlsConfig + serverConf.RequireSecureTransport = serverConfig.RequireSecureTransport() + mySQLServer, startError = server.NewServer( - server.Config{ - Protocol: "tcp", - Address: hostPort, - Auth: userAuth, - ConnReadTimeout: readTimeout, - ConnWriteTimeout: writeTimeout, - MaxConnections: serverConfig.MaxConnections(), - TLSConfig: tlsConfig, - RequireSecureTransport: serverConfig.RequireSecureTransport(), - // Do not set the value of Version. Let it default to what go-mysql-server uses. This should be equivalent - // to the value of mysql that we support. - }, + serverConf, sqlEngine, newSessionBuilder(sqlEngine, dEnv.Config, pro, mrEnv, serverConfig.AutoCommit()), ) @@ -174,14 +182,14 @@ func newSessionBuilder(sqlEngine *sqle.Engine, dConf *env.DoltCliConfig, pro dsq tmpSqlCtx := sql.NewEmptyContext() client := sql.Client{Address: conn.RemoteAddr().String(), User: conn.User, Capabilities: conn.Capabilities} - mysqlSess := sql.NewSession(host, client, conn.ConnectionID) + mysqlSess := sql.NewBaseSessionWithClientServer(host, client, conn.ConnectionID) doltDbs := dsqle.DbsAsDSQLDBs(sqlEngine.Analyzer.Catalog.AllDatabases()) dbStates, err := getDbStates(ctx, doltDbs) if err != nil { return nil, err } - doltSess, err := dsess.NewSession(tmpSqlCtx, mysqlSess, pro, dConf, dbStates...) + doltSess, err := dsess.NewDoltSession(tmpSqlCtx, mysqlSess, pro, dConf, dbStates...) if err != nil { return nil, err } diff --git a/go/cmd/dolt/commands/sqlserver/server_test.go b/go/cmd/dolt/commands/sqlserver/server_test.go index 7bc57894e5..06e0b0ad84 100644 --- a/go/cmd/dolt/commands/sqlserver/server_test.go +++ b/go/cmd/dolt/commands/sqlserver/server_test.go @@ -26,10 +26,12 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/net/context" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils" "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils/testcommands" "github.com/dolthub/dolt/go/libraries/doltcore/env" - dsqle "github.com/dolthub/dolt/go/libraries/doltcore/sqle" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" + "github.com/dolthub/dolt/go/libraries/utils/config" ) type testPerson struct { @@ -393,11 +395,12 @@ func TestReadReplica(t *testing.T) { readReplicaDbName := multiSetup.DbNames[0] sourceDbName := multiSetup.DbNames[1] - replicaCfg, ok := multiSetup.MrEnv[readReplicaDbName].Config.GetConfig(env.LocalConfig) + localCfg, ok := multiSetup.MrEnv[readReplicaDbName].Config.GetConfig(env.LocalConfig) if !ok { t.Fatal("local config does not exist") } - replicaCfg.SetStrings(map[string]string{dsqle.DoltReadReplicaKey: "remote1"}) + config.NewPrefixConfig(localCfg, env.SqlServerGlobalsPrefix).SetStrings(map[string]string{doltdb.DoltReadReplicaKey: "remote1"}) + dsess.InitPersistedSystemVars(multiSetup.MrEnv[readReplicaDbName]) // start server as read replica sc := CreateServerController() diff --git a/go/cmd/dolt/commands/sqlserver/serverconfig.go b/go/cmd/dolt/commands/sqlserver/serverconfig.go index ecaa20944a..a59c41d84c 100644 --- a/go/cmd/dolt/commands/sqlserver/serverconfig.go +++ b/go/cmd/dolt/commands/sqlserver/serverconfig.go @@ -35,16 +35,22 @@ const ( ) const ( - defaultHost = "localhost" - defaultPort = 3306 - defaultUser = "root" - defaultPass = "" - defaultTimeout = 8 * 60 * 60 * 1000 // 8 hours, same as MySQL - defaultReadOnly = false - defaultLogLevel = LogLevel_Info - defaultAutoCommit = true - defaultMaxConnections = 100 - defaultQueryParallelism = 2 + defaultHost = "localhost" + defaultPort = 3306 + defaultUser = "root" + defaultPass = "" + defaultTimeout = 8 * 60 * 60 * 1000 // 8 hours, same as MySQL + defaultReadOnly = false + defaultLogLevel = LogLevel_Info + defaultAutoCommit = true + defaultMaxConnections = 100 + defaultQueryParallelism = 2 + defaultPersistenceBahavior = loadPerisistentGlobals +) + +const ( + ignorePeristentGlobals = "ignore" + loadPerisistentGlobals = "load" ) // String returns the string representation of the log level. @@ -101,6 +107,8 @@ type ServerConfig interface { TLSCert() string // RequireSecureTransport is true if the server should reject non-TLS connections. RequireSecureTransport() bool + // PersistenceBehavior is "load" if we include persisted system globals on server init + PersistenceBehavior() string } type commandLineServerConfig struct { @@ -118,6 +126,7 @@ type commandLineServerConfig struct { tlsKey string tlsCert string requireSecureTransport bool + persistenceBehavior string } // Host returns the domain that the server will run on. Accepts an IPv4 or IPv6 address, in addition to localhost. @@ -175,6 +184,11 @@ func (cfg *commandLineServerConfig) QueryParallelism() int { return cfg.queryParallelism } +// PersistenceBehavior returns whether to autoload persisted server configuration +func (cfg *commandLineServerConfig) PersistenceBehavior() string { + return cfg.persistenceBehavior +} + func (cfg *commandLineServerConfig) TLSKey() string { return cfg.tlsKey } @@ -254,19 +268,25 @@ func (cfg *commandLineServerConfig) withDBNamesAndPaths(dbNamesAndPaths []env.En return cfg } +func (cfg *commandLineServerConfig) withPersistenceBehavior(persistenceBehavior string) *commandLineServerConfig { + cfg.persistenceBehavior = persistenceBehavior + return cfg +} + // DefaultServerConfig creates a `*ServerConfig` that has all of the options set to their default values. func DefaultServerConfig() *commandLineServerConfig { return &commandLineServerConfig{ - host: defaultHost, - port: defaultPort, - user: defaultUser, - password: defaultPass, - timeout: defaultTimeout, - readOnly: defaultReadOnly, - logLevel: defaultLogLevel, - autoCommit: defaultAutoCommit, - maxConnections: defaultMaxConnections, - queryParallelism: defaultQueryParallelism, + host: defaultHost, + port: defaultPort, + user: defaultUser, + password: defaultPass, + timeout: defaultTimeout, + readOnly: defaultReadOnly, + logLevel: defaultLogLevel, + autoCommit: defaultAutoCommit, + maxConnections: defaultMaxConnections, + queryParallelism: defaultQueryParallelism, + persistenceBehavior: defaultPersistenceBahavior, } } diff --git a/go/cmd/dolt/commands/sqlserver/sqlserver.go b/go/cmd/dolt/commands/sqlserver/sqlserver.go index 1c4639e907..f7076f7bf0 100644 --- a/go/cmd/dolt/commands/sqlserver/sqlserver.go +++ b/go/cmd/dolt/commands/sqlserver/sqlserver.go @@ -33,18 +33,19 @@ import ( ) const ( - hostFlag = "host" - portFlag = "port" - userFlag = "user" - passwordFlag = "password" - timeoutFlag = "timeout" - readonlyFlag = "readonly" - logLevelFlag = "loglevel" - multiDBDirFlag = "multi-db-dir" - noAutoCommitFlag = "no-auto-commit" - configFileFlag = "config" - queryParallelismFlag = "query-parallelism" - maxConnectionsFlag = "max-connections" + hostFlag = "host" + portFlag = "port" + userFlag = "user" + passwordFlag = "password" + timeoutFlag = "timeout" + readonlyFlag = "readonly" + logLevelFlag = "loglevel" + multiDBDirFlag = "multi-db-dir" + noAutoCommitFlag = "no-auto-commit" + configFileFlag = "config" + queryParallelismFlag = "query-parallelism" + maxConnectionsFlag = "max-connections" + persistenceBehaviorFlag = "persistence-behavior" ) func indentLines(s string) string { @@ -141,6 +142,8 @@ func (cmd SqlServerCmd) CreateArgParser() *argparser.ArgParser { ap.SupportsFlag(noAutoCommitFlag, "", "When provided sessions will not automatically commit their changes to the working set. Anything not manually committed will be lost.") ap.SupportsInt(queryParallelismFlag, "", "num-go-routines", fmt.Sprintf("Set the number of go routines spawned to handle each query (default `%d`)", serverConfig.QueryParallelism())) ap.SupportsInt(maxConnectionsFlag, "", "max-connections", fmt.Sprintf("Set the number of connections handled by the server (default `%d`)", serverConfig.MaxConnections())) + ap.SupportsInt(persistenceBehaviorFlag, "", "persistence-behavior", fmt.Sprintf("Indicate whether to `load` or `ignore` persisted global variables (default `%s`)", serverConfig.PersistenceBehavior())) + return ap } @@ -203,12 +206,9 @@ func startServer(ctx context.Context, versionStr, commandStr string, args []stri } func GetServerConfig(dEnv *env.DoltEnv, apr *argparser.ArgParseResults, requiresRepo bool) (ServerConfig, error) { - cfgFile, ok := apr.GetValue(configFileFlag) - - if ok { + if cfgFile, ok := apr.GetValue(configFileFlag); ok { return getYAMLServerConfig(dEnv.FS, cfgFile) } - return getCommandLineServerConfig(dEnv, apr, requiresRepo) } @@ -265,6 +265,10 @@ func getCommandLineServerConfig(dEnv *env.DoltEnv, apr *argparser.ArgParseResult } serverConfig.autoCommit = !apr.Contains(noAutoCommitFlag) + if persistenceBehavior, ok := apr.GetValue(persistenceBehaviorFlag); ok { + serverConfig.withPersistenceBehavior(persistenceBehavior) + } + return serverConfig, nil } diff --git a/go/cmd/dolt/commands/sqlserver/yaml_config.go b/go/cmd/dolt/commands/sqlserver/yaml_config.go index 2dfbeb0fec..d9e4b33a93 100644 --- a/go/cmd/dolt/commands/sqlserver/yaml_config.go +++ b/go/cmd/dolt/commands/sqlserver/yaml_config.go @@ -58,6 +58,8 @@ func intPtr(n int) *int { type BehaviorYAMLConfig struct { ReadOnly *bool `yaml:"read_only"` AutoCommit *bool + // PersistenceBehavior regulates loading persisted system variable configuration. + PersistenceBehavior *string `yaml:"persistence_behavior"` } // UserYAMLConfig contains server configuration regarding the user account clients must use to connect @@ -110,9 +112,13 @@ func NewYamlConfig(configFileData []byte) (YAMLConfig, error) { func serverConfigAsYAMLConfig(cfg ServerConfig) YAMLConfig { return YAMLConfig{ - LogLevelStr: strPtr(string(cfg.LogLevel())), - BehaviorConfig: BehaviorYAMLConfig{boolPtr(cfg.ReadOnly()), boolPtr(cfg.AutoCommit())}, - UserConfig: UserYAMLConfig{strPtr(cfg.User()), strPtr(cfg.Password())}, + LogLevelStr: strPtr(string(cfg.LogLevel())), + BehaviorConfig: BehaviorYAMLConfig{ + boolPtr(cfg.ReadOnly()), + boolPtr(cfg.AutoCommit()), + strPtr(cfg.PersistenceBehavior()), + }, + UserConfig: UserYAMLConfig{strPtr(cfg.User()), strPtr(cfg.Password())}, ListenerConfig: ListenerYAMLConfig{ strPtr(cfg.Host()), intPtr(cfg.Port()), @@ -290,3 +296,10 @@ func (cfg YAMLConfig) RequireSecureTransport() bool { } return *cfg.ListenerConfig.RequireSecureTransport } + +func (cfg YAMLConfig) PersistenceBehavior() string { + if cfg.BehaviorConfig.PersistenceBehavior == nil { + return loadPerisistentGlobals + } + return *cfg.BehaviorConfig.PersistenceBehavior +} diff --git a/go/cmd/dolt/commands/sqlserver/yaml_config_test.go b/go/cmd/dolt/commands/sqlserver/yaml_config_test.go index 31179b1320..90e4bf462d 100644 --- a/go/cmd/dolt/commands/sqlserver/yaml_config_test.go +++ b/go/cmd/dolt/commands/sqlserver/yaml_config_test.go @@ -30,6 +30,7 @@ log_level: info behavior: read_only: false autocommit: true + persistence_behavior: load user: name: root diff --git a/go/cmd/dolt/dolt.go b/go/cmd/dolt/dolt.go index 0cf43ebc2c..630aef1770 100644 --- a/go/cmd/dolt/dolt.go +++ b/go/cmd/dolt/dolt.go @@ -43,6 +43,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dfunctions" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" "github.com/dolthub/dolt/go/libraries/events" "github.com/dolthub/dolt/go/libraries/utils/filesys" "github.com/dolthub/dolt/go/store/util/tempfiles" @@ -301,6 +302,10 @@ func runMain() int { defer tempfiles.MovableTempFileProvider.Clean() if dEnv.DoltDB != nil { + err := dsess.InitPersistedSystemVars(dEnv) + if err != nil { + cli.Printf("error: failed to load persisted global variables: %s\n", err.Error()) + } dEnv.DoltDB.SetCommitHookLogger(ctx, cli.OutStream) } diff --git a/go/go.mod b/go/go.mod index 2298f5ab0d..0dd21526fa 100644 --- a/go/go.mod +++ b/go/go.mod @@ -19,7 +19,7 @@ require ( github.com/denisbrodbeck/machineid v1.0.1 github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20201005193433-3ee972b1d078 github.com/dolthub/fslock v0.0.3 - github.com/dolthub/go-mysql-server v0.11.1-0.20211028230555-ba7d8488a2e4 + github.com/dolthub/go-mysql-server v0.11.1-0.20211029162420-43a1226b27d3 github.com/dolthub/ishell v0.0.0-20210205014355-16a4ce758446 github.com/dolthub/mmap-go v1.0.4-0.20201107010347-f9f2a9588a66 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 @@ -60,7 +60,6 @@ require ( github.com/spf13/cobra v1.0.0 github.com/stretchr/testify v1.7.0 github.com/tealeg/xlsx v1.0.5 - github.com/tidwall/pretty v1.0.1 // indirect github.com/tklauser/go-sysconf v0.3.5 // indirect github.com/uber/jaeger-client-go v2.25.0+incompatible github.com/uber/jaeger-lib v2.4.0+incompatible // indirect diff --git a/go/go.sum b/go/go.sum index 48090905e1..944efac4c1 100644 --- a/go/go.sum +++ b/go/go.sum @@ -144,8 +144,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= -github.com/dolthub/go-mysql-server v0.11.1-0.20211028230555-ba7d8488a2e4 h1:vGcWTyjRIQKfmFyKhf5kUiUHmc/Wd4CtFQ5dalSD8Mw= -github.com/dolthub/go-mysql-server v0.11.1-0.20211028230555-ba7d8488a2e4/go.mod h1:+XgR49p1y6TGpVpjbmd23Ljpqr4L/06nNpf39TcP56M= +github.com/dolthub/go-mysql-server v0.11.1-0.20211029162420-43a1226b27d3 h1:xS2HsEKXij0yE/I2plvuFpG9/jp1PnlGdRhbbLH6DqM= +github.com/dolthub/go-mysql-server v0.11.1-0.20211029162420-43a1226b27d3/go.mod h1:NXWOVk1RyZI/mR7bghGYU+Zmb58mo37420r91O7aKGk= github.com/dolthub/ishell v0.0.0-20210205014355-16a4ce758446 h1:0ol5pj+QlKUKAtqs1LiPM3ZJKs+rHPgLSsMXmhTrCAM= github.com/dolthub/ishell v0.0.0-20210205014355-16a4ce758446/go.mod h1:dhGBqcCEfK5kuFmeO5+WOx3hqc1k3M29c1oS/R7N4ms= github.com/dolthub/jsonpath v0.0.0-20210609232853-d49537a30474 h1:xTrR+l5l+1Lfq0NvhiEsctylXinUMFhhsqaEcl414p8= @@ -170,7 +170,6 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/fastly/go-utils v0.0.0-20180712184237-d95a45783239/go.mod h1:Gdwt2ce0yfBxPvZrHkprdPPTTS3N5rwmLE8T22KBXlw= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s= github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= @@ -366,7 +365,6 @@ github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09 github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/golang-lru v0.5.3/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= @@ -382,7 +380,6 @@ github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANyt github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= github.com/jedib0t/go-pretty v4.3.1-0.20191104025401-85fe5d6a7c4d+incompatible h1:SwOdF+2qzbZnEUsoEv1v0VkoQvoQ2pZLVDjNDzL6nto= github.com/jedib0t/go-pretty v4.3.1-0.20191104025401-85fe5d6a7c4d+incompatible/go.mod h1:XemHduiw8R651AF9Pt4FwCTKeG3oo7hrHJAoznj9nag= -github.com/jehiah/go-strftime v0.0.0-20171201141054-1d33003b3869/go.mod h1:cJ6Cj7dQo+O6GJNiMx+Pa94qKj+TG8ONdKHgMNIyyag= github.com/jingyugao/rowserrcheck v0.0.0-20191204022205-72ab7603b68a/go.mod h1:xRskid8CManxVta/ALEhJha/pweKBaVG6fWgc0yH25s= github.com/jirfag/go-printf-func-name v0.0.0-20191110105641-45db9963cdd3/go.mod h1:HEWGJkRDzjJY2sqdDwxccsGicWEf9BQOZsq2tV+xzM0= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= @@ -426,7 +423,6 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kyoh86/exportloopref v0.1.7/go.mod h1:h1rDl2Kdj97+Kwh4gdz3ujE7XHmH51Q0lUiZ1z4NLj8= github.com/lestrrat-go/envload v0.0.0-20180220234015-a3eb8ddeffcc h1:RKf14vYWi2ttpEmkA4aQ3j4u9dStX2t4M8UM6qqNsG8= github.com/lestrrat-go/envload v0.0.0-20180220234015-a3eb8ddeffcc/go.mod h1:kopuH9ugFRkIXf3YoqHKyrJ9YfUFsckUU9S7B+XP+is= -github.com/lestrrat-go/strftime v1.0.1/go.mod h1:E1nN3pCbtMSu1yjSVeyuRFVm/U0xoR76fd03sz+Qz4g= github.com/lestrrat-go/strftime v1.0.4 h1:T1Rb9EPkAhgxKqbcMIPguPq8glqXTA1koF8n9BHElA8= github.com/lestrrat-go/strftime v1.0.4/go.mod h1:E1nN3pCbtMSu1yjSVeyuRFVm/U0xoR76fd03sz+Qz4g= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -472,7 +468,6 @@ github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrk github.com/mitchellh/go-ps v1.0.0/go.mod h1:J4lOc8z8yJs6vUwklHw2XEIiT4z4C40KtWVN3nvg8Pg= github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= -github.com/mitchellh/hashstructure v1.0.0/go.mod h1:QjSHrPWS+BGUVBYkbTZWEnOh3G1DutKwClXU/ABz6AQ= github.com/mitchellh/hashstructure v1.1.0 h1:P6P1hdjqAAknpY/M1CGipelZgp+4y9ja9kmUZPXP+H0= github.com/mitchellh/hashstructure v1.1.0/go.mod h1:xUDAozZz0Wmdiufv0uyhnHkUTN6/6d8ulp4AwfLKrmA= github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= @@ -591,7 +586,6 @@ github.com/shirou/gopsutil v0.0.0-20190901111213-e4ec7b275ada/go.mod h1:WWnYX4lz github.com/shirou/gopsutil v3.21.2+incompatible h1:U+YvJfjCh6MslYlIAXvPtzhW3YZEtc9uncueUNpD/0A= github.com/shirou/gopsutil v3.21.2+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4/go.mod h1:qsXQc7+bwAM3Q1u/4XEfrquwF8Lw7D7y5cD8CuHnfIc= -github.com/shopspring/decimal v0.0.0-20191130220710-360f2bc03045/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= @@ -648,11 +642,9 @@ github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69 github.com/tdakkota/asciicheck v0.0.0-20200416190851-d7f85be797a2/go.mod h1:yHp0ai0Z9gUljN3o0xMhYJnH/IcvkdTBOX2fmJ93JEM= github.com/tealeg/xlsx v1.0.5 h1:+f8oFmvY8Gw1iUXzPk+kz+4GpbDZPK1FhPiQRd+ypgE= github.com/tealeg/xlsx v1.0.5/go.mod h1:btRS8dz54TDnvKNosuAqxrM1QgN1udgk9O34bDCnORM= -github.com/tebeka/strftime v0.1.4/go.mod h1:7wJm3dZlpr4l/oVK0t1HYIc4rMzQ2XJlOMIUJUJH6XQ= github.com/tetafro/godot v0.4.8/go.mod h1:/7NLHhv08H1+8DNj0MElpAACw1ajsCuf3TKNQxA5S+0= +github.com/tidwall/pretty v1.0.0 h1:HsD+QiTn7sK6flMKIvNmpqz1qrpP3Ps6jOKIKMooyg4= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= -github.com/tidwall/pretty v1.0.1 h1:WE4RBSZ1x6McVVC8S/Md+Qse8YUv6HRObAx6ke00NY8= -github.com/tidwall/pretty v1.0.1/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/timakin/bodyclose v0.0.0-20190930140734-f7f2e9bca95e/go.mod h1:Qimiffbc6q9tBWlVV6x0P9sat/ao1xEkREYPPj9hphk= github.com/tklauser/go-sysconf v0.3.5 h1:uu3Xl4nkLzQfXNsWn15rPc/HQCJKObbt1dKJeWp3vU4= github.com/tklauser/go-sysconf v0.3.5/go.mod h1:MkWzOF4RMCshBAMXuhXJs64Rte09mITnppBXY/rYEFI= diff --git a/go/libraries/doltcore/doltdb/commit_hooks.go b/go/libraries/doltcore/doltdb/commit_hooks.go index 9a3dd703b7..ddae0402e6 100644 --- a/go/libraries/doltcore/doltdb/commit_hooks.go +++ b/go/libraries/doltcore/doltdb/commit_hooks.go @@ -23,7 +23,10 @@ import ( "github.com/dolthub/dolt/go/store/datas" ) -const ReplicateToRemoteKey = "dolt_replicate_to_remote" +const ( + ReplicateToRemoteKey = "dolt_replicate_to_remote" + DoltReadReplicaKey = "dolt_read_replica_remote" +) type ReplicateHook struct { destDB datas.Database diff --git a/go/libraries/doltcore/env/config.go b/go/libraries/doltcore/env/config.go index 37cd784ecb..6afdfa0729 100644 --- a/go/libraries/doltcore/env/config.go +++ b/go/libraries/doltcore/env/config.go @@ -54,20 +54,25 @@ const ( var LocalConfigWhitelist = set.NewStrSet([]string{UserNameKey, UserEmailKey}) var GlobalConfigWhitelist = set.NewStrSet([]string{UserNameKey, UserEmailKey}) -// DoltConfigElement is an enum representing the elements that make up the ConfigHierarchy -type DoltConfigElement int +// ConfigScope is an enum representing the elements that make up the ConfigHierarchy +type ConfigScope int const ( // LocalConfig is the repository's local config portion of the ConfigHierarchy - LocalConfig DoltConfigElement = iota + LocalConfig ConfigScope = iota // GlobalConfig is the user's global config portion of the ConfigHierarchy GlobalConfig ) +const ( + // SqlServerGlobalsPrefix is config namespace accessible by the SQL engine (ex: sqlserver.global.key) + SqlServerGlobalsPrefix = "sqlserver.global" +) + // String gives the string name of an element that was used when it was added to the ConfigHierarchy, which is the // same name that is used to retrieve that element of the string hierarchy. -func (ce DoltConfigElement) String() string { +func (ce ConfigScope) String() string { switch ce { case LocalConfig: return localConfigName @@ -155,8 +160,13 @@ func (dcc *DoltCliConfig) createLocalConfigAt(dir string, vals map[string]string } // GetConfig retrieves a specific element of the config hierarchy. -func (dcc *DoltCliConfig) GetConfig(element DoltConfigElement) (config.ReadWriteConfig, bool) { - return dcc.ch.GetConfig(element.String()) +func (dcc *DoltCliConfig) GetConfig(element ConfigScope) (config.ReadWriteConfig, bool) { + switch element { + case LocalConfig, GlobalConfig: + return dcc.ch.GetConfig(element.String()) + default: + return nil, false + } } // GetStringOrDefault retrieves a string from the config hierarchy and returns it if available. Otherwise it returns diff --git a/go/libraries/doltcore/env/environment.go b/go/libraries/doltcore/env/environment.go index 21c825ceec..5b5d33628b 100644 --- a/go/libraries/doltcore/env/environment.go +++ b/go/libraries/doltcore/env/environment.go @@ -25,6 +25,7 @@ import ( "time" "unicode" + "github.com/dolthub/go-mysql-server/sql" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -57,11 +58,14 @@ const ( tempTablesDir = "temptf" ) -func getCommitHooks(ctx context.Context, dEnv *DoltEnv) ([]datas.CommitHook, error) { +func GetCommitHooks(ctx context.Context, dEnv *DoltEnv) ([]datas.CommitHook, error) { postCommitHooks := make([]datas.CommitHook, 0) + if _, val, ok := sql.SystemVariables.GetGlobal(doltdb.ReplicateToRemoteKey); ok && val != "" { + backupName, ok := val.(string) + if !ok { + return nil, sql.ErrInvalidSystemVariableValue.New(val) + } - backupName := dEnv.Config.GetStringOrDefault(doltdb.ReplicateToRemoteKey, "") - if backupName != "" { remotes, err := dEnv.GetRemotes() if err != nil { return nil, err @@ -206,15 +210,6 @@ func Load(ctx context.Context, hdp HomeDirProvider, fs filesys.Filesys, urlStr, } } - if dbLoadErr == nil { - postCommitHooks, dbLoadErr := getCommitHooks(ctx, dEnv) - if dbLoadErr != nil { - dEnv.DBLoadError = dbLoadErr - } else { - dEnv.DoltDB.SetCommitHooks(ctx, postCommitHooks) - } - } - return dEnv } diff --git a/go/libraries/doltcore/sqle/database.go b/go/libraries/doltcore/sqle/database.go index 3272ef68f8..4bae3ecf85 100644 --- a/go/libraries/doltcore/sqle/database.go +++ b/go/libraries/doltcore/sqle/database.go @@ -43,8 +43,6 @@ var ErrInvalidTableName = errors.NewKind("Invalid table name %s. Table names mus var ErrReservedTableName = errors.NewKind("Invalid table name %s. Table names beginning with `dolt_` are reserved for internal use") var ErrSystemTableAlter = errors.NewKind("Cannot alter table %s: system tables cannot be dropped or altered") -const DoltReadReplicaKey = "dolt_read_replica_remote" - type SqlDatabase interface { sql.Database GetRoot(*sql.Context) (*doltdb.RootValue, error) @@ -354,7 +352,7 @@ func (db Database) GetTableInsensitiveWithRoot(ctx *sql.Context, root *doltdb.Ro case doltdb.CommitAncestorsTableName: dt, found = dtables.NewCommitAncestorsTable(ctx, db.ddb), true case doltdb.StatusTableName: - dt, found = dtables.NewStatusTable(ctx, db.name, db.ddb, dsess.NewSessionStateAdapter(sess, db.name, map[string]env.Remote{}, map[string]env.BranchConfig{}), db.drw), true + dt, found = dtables.NewStatusTable(ctx, db.name, db.ddb, dsess.NewSessionStateAdapter(sess.Session, db.name, map[string]env.Remote{}, map[string]env.BranchConfig{}), db.drw), true } if found { return dt, found, nil @@ -789,7 +787,7 @@ func (db Database) createTempSQLTable(ctx *sql.Context, tableName string, sch sq return db.createTempDoltTable(ctx, tableName, tempTableRootValue, doltSch, sess) } -func (db Database) createTempDoltTable(ctx *sql.Context, tableName string, root *doltdb.RootValue, doltSch schema.Schema, dsess *dsess.Session) error { +func (db Database) createTempDoltTable(ctx *sql.Context, tableName string, root *doltdb.RootValue, doltSch schema.Schema, dsess *dsess.DoltSession) error { if exists, err := root.HasTable(ctx, tableName); err != nil { return err } else if exists { diff --git a/go/libraries/doltcore/sqle/dfunctions/commit.go b/go/libraries/doltcore/sqle/dfunctions/commit.go index dda14fa2fc..92e0bf00a1 100644 --- a/go/libraries/doltcore/sqle/dfunctions/commit.go +++ b/go/libraries/doltcore/sqle/dfunctions/commit.go @@ -60,8 +60,8 @@ func (cf *CommitFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } } else { - name = dSess.Username - email = dSess.Email + name = dSess.Username() + email = dSess.Email() } // Get the commit message. diff --git a/go/libraries/doltcore/sqle/dfunctions/dolt_commit.go b/go/libraries/doltcore/sqle/dfunctions/dolt_commit.go index cae03156f2..79c3ec86b9 100644 --- a/go/libraries/doltcore/sqle/dfunctions/dolt_commit.go +++ b/go/libraries/doltcore/sqle/dfunctions/dolt_commit.go @@ -76,8 +76,8 @@ func (d DoltCommitFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return nil, err } } else { - name = dSess.Username - email = dSess.Email + name = dSess.Username() + email = dSess.Email() } msg, msgOk := apr.GetValue(cli.CommitMessageArg) diff --git a/go/libraries/doltcore/sqle/dfunctions/dolt_merge.go b/go/libraries/doltcore/sqle/dfunctions/dolt_merge.go index c8513c4423..7360689fdc 100644 --- a/go/libraries/doltcore/sqle/dfunctions/dolt_merge.go +++ b/go/libraries/doltcore/sqle/dfunctions/dolt_merge.go @@ -120,7 +120,7 @@ func (d DoltMergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) // and merging into working set. Returns a new WorkingSet and whether there were merge conflicts. This currently // persists merge commits in the database, but expects the caller to update the working set. // TODO FF merging commit with constraint violations requires `constraint verify` -func mergeIntoWorkingSet(ctx *sql.Context, sess *dsess.Session, roots doltdb.Roots, ws *doltdb.WorkingSet, dbName string, spec *merge.MergeSpec) (*doltdb.WorkingSet, int, error) { +func mergeIntoWorkingSet(ctx *sql.Context, sess *dsess.DoltSession, roots doltdb.Roots, ws *doltdb.WorkingSet, dbName string, spec *merge.MergeSpec) (*doltdb.WorkingSet, int, error) { if conflicts, err := roots.Working.HasConflicts(ctx); err != nil { return ws, noConflicts, err } else if conflicts { @@ -259,7 +259,7 @@ func executeFFMerge(ctx *sql.Context, squash bool, ws *doltdb.WorkingSet, dbData func executeNoFFMerge( ctx *sql.Context, - dSess *dsess.Session, + dSess *dsess.DoltSession, spec *merge.MergeSpec, dbName string, ws *doltdb.WorkingSet, @@ -309,7 +309,7 @@ func executeNoFFMerge( return ws, dSess.SetWorkingSet(ctx, dbName, ws.ClearMerge(), nil) } -func createMergeSpec(ctx *sql.Context, sess *dsess.Session, dbName string, apr *argparser.ArgParseResults, commitSpecStr string) (*merge.MergeSpec, error) { +func createMergeSpec(ctx *sql.Context, sess *dsess.DoltSession, dbName string, apr *argparser.ArgParseResults, commitSpecStr string) (*merge.MergeSpec, error) { ddb, ok := sess.GetDoltDB(ctx, dbName) dbData, ok := sess.GetDbData(ctx, dbName) @@ -328,8 +328,8 @@ func createMergeSpec(ctx *sql.Context, sess *dsess.Session, dbName string, apr * return nil, err } } else { - name = sess.Username - email = sess.Email + name = sess.Username() + email = sess.Email() } t := ctx.QueryTime() diff --git a/go/libraries/doltcore/sqle/dfunctions/merge.go b/go/libraries/doltcore/sqle/dfunctions/merge.go index 2cd7d46e6e..0a3556b503 100644 --- a/go/libraries/doltcore/sqle/dfunctions/merge.go +++ b/go/libraries/doltcore/sqle/dfunctions/merge.go @@ -66,8 +66,8 @@ func (cf *MergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } } else { - name = sess.Username - email = sess.Email + name = sess.Username() + email = sess.Email() } dbName := sess.GetCurrentDatabase() @@ -202,7 +202,7 @@ func getBranchCommit(ctx *sql.Context, val interface{}, ddb *doltdb.DoltDB) (*do return cm, cmh, nil } -func getHead(ctx *sql.Context, sess *dsess.Session, dbName string) (*doltdb.Commit, hash.Hash, *doltdb.RootValue, error) { +func getHead(ctx *sql.Context, sess *dsess.DoltSession, dbName string) (*doltdb.Commit, hash.Hash, *doltdb.RootValue, error) { head, err := sess.GetHeadCommit(ctx, dbName) if err != nil { return nil, hash.Hash{}, nil, err diff --git a/go/libraries/doltcore/sqle/dsess/dolt_session.go b/go/libraries/doltcore/sqle/dsess/dolt_session.go new file mode 100644 index 0000000000..21b57226ab --- /dev/null +++ b/go/libraries/doltcore/sqle/dsess/dolt_session.go @@ -0,0 +1,229 @@ +// Copyright 2021 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dsess + +import ( + "fmt" + "strconv" + "sync" + + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/dolt/go/cmd/dolt/cli" + "github.com/dolthub/dolt/go/libraries/doltcore/env" + "github.com/dolthub/dolt/go/libraries/utils/config" +) + +type DoltSession struct { + *Session + globalsConf config.ReadWriteConfig + mu *sync.Mutex +} + +var _ sql.Session = (*DoltSession)(nil) +var _ sql.PersistableSession = (*DoltSession)(nil) + +// NewDoltSession creates a DoltSession object from a standard sql.Session and 0 or more Database objects. +func NewDoltSession(ctx *sql.Context, sqlSess *sql.BaseSession, pro RevisionDatabaseProvider, conf *env.DoltCliConfig, dbs ...InitialDbState) (*DoltSession, error) { + sess, err := NewSession(ctx, sqlSess, pro, conf, dbs...) + if err != nil { + return nil, err + } + + var globals config.ReadWriteConfig + if localConf, ok := conf.GetConfig(env.LocalConfig); !ok { + ctx.Warn(NonpersistableSessionCode, "configured mode does not support persistable sessions; SET PERSIST will not write to file") + globals = config.NewMapConfig(make(map[string]string)) + } else { + globals = config.NewPrefixConfig(localConf, env.SqlServerGlobalsPrefix) + } + + return sess.NewDoltSession(globals), nil +} + +// PersistGlobal implements sql.PersistableSession +func (s *DoltSession) PersistGlobal(sysVarName string, value interface{}) error { + sysVar, _, err := validatePersistableSysVar(sysVarName) + if err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + return setPersistedValue(s.globalsConf, sysVar.Name, value) +} + +// RemovePersistedGlobal implements sql.PersistableSession +func (s *DoltSession) RemovePersistedGlobal(sysVarName string) error { + sysVar, _, err := validatePersistableSysVar(sysVarName) + if err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + return s.globalsConf.Unset([]string{sysVar.Name}) +} + +// RemoveAllPersistedGlobals implements sql.PersistableSession +func (s *DoltSession) RemoveAllPersistedGlobals() error { + allVars := make([]string, s.globalsConf.Size()) + i := 0 + s.globalsConf.Iter(func(k, v string) bool { + allVars[i] = k + i++ + return false + }) + + s.mu.Lock() + defer s.mu.Unlock() + return s.globalsConf.Unset(allVars) +} + +// RemoveAllPersistedGlobals implements sql.PersistableSession +func (s *DoltSession) GetPersistedValue(k string) (interface{}, error) { + return getPersistedValue(s.globalsConf, k) +} + +// SystemVariablesInConfig returns a list of System Variables associated with the session +func (s *DoltSession) SystemVariablesInConfig() ([]sql.SystemVariable, error) { + return SystemVariablesInConfig(s.globalsConf) +} + +// validatePersistedSysVar checks whether a system variable exists and is dynamic +func validatePersistableSysVar(name string) (sql.SystemVariable, interface{}, error) { + sysVar, val, ok := sql.SystemVariables.GetGlobal(name) + if !ok { + return sql.SystemVariable{}, nil, sql.ErrUnknownSystemVariable.New(name) + } + if !sysVar.Dynamic { + return sql.SystemVariable{}, nil, sql.ErrSystemVariableReadOnly.New(name) + } + return sysVar, val, nil +} + +// getPersistedValue reads and converts a config value to the associated SystemVariable type +func getPersistedValue(conf config.ReadableConfig, k string) (interface{}, error) { + v, err := conf.GetString(k) + if err != nil { + return nil, err + } + + _, value, err := validatePersistableSysVar(k) + if err != nil { + return nil, err + } + + var res interface{} + switch value.(type) { + case int, int8, int16, int32, int64: + res, err = strconv.ParseInt(v, 10, 64) + case uint, uint8, uint16, uint32, uint64: + res, err = strconv.ParseUint(v, 10, 64) + case float32, float64: + res, err = strconv.ParseFloat(v, 64) + case string: + return v, nil + default: + return nil, sql.ErrInvalidType.New(value) + } + + if err != nil { + return nil, err + } + + return res, nil +} + +// setPersistedValue casts and persists a key value pair assuming thread safety +func setPersistedValue(conf config.WritableConfig, key string, value interface{}) error { + switch v := value.(type) { + case int: + return config.SetInt(conf, key, int64(v)) + case int8: + return config.SetInt(conf, key, int64(v)) + case int16: + return config.SetInt(conf, key, int64(v)) + case int32: + return config.SetInt(conf, key, int64(v)) + case int64: + return config.SetInt(conf, key, v) + case uint: + return config.SetUint(conf, key, uint64(v)) + case uint8: + return config.SetUint(conf, key, uint64(v)) + case uint16: + return config.SetUint(conf, key, uint64(v)) + case uint32: + return config.SetUint(conf, key, uint64(v)) + case uint64: + return config.SetUint(conf, key, v) + case float32: + return config.SetFloat(conf, key, float64(v)) + case float64: + return config.SetFloat(conf, key, v) + case string: + return config.SetString(conf, key, v) + default: + return sql.ErrInvalidType.New(v) + } +} + +// SystemVariablesInConfig returns system variables from the persisted config +func SystemVariablesInConfig(conf config.ReadableConfig) ([]sql.SystemVariable, error) { + allVars := make([]sql.SystemVariable, conf.Size()) + i := 0 + var err error + var sysVar sql.SystemVariable + var def interface{} + conf.Iter(func(k, v string) bool { + def, err = getPersistedValue(conf, k) + if err != nil { + err = fmt.Errorf("key: '%s'; %w", k, err) + return true + } + // getPeristedVal already checked for errors + sysVar, _, _ = sql.SystemVariables.GetGlobal(k) + sysVar.Default = def + allVars[i] = sysVar + i++ + return false + }) + if err != nil { + return nil, err + } + return allVars, nil +} + +var initMu = sync.Mutex{} + +func InitPersistedSystemVars(dEnv *env.DoltEnv) error { + initMu.Lock() + defer initMu.Unlock() + + var globals config.ReadWriteConfig + if localConf, ok := dEnv.Config.GetConfig(env.LocalConfig); !ok { + cli.Println("warning: multi-db mode does not support persistable sessions") + globals = config.NewMapConfig(make(map[string]string)) + } else { + globals = config.NewPrefixConfig(localConf, env.SqlServerGlobalsPrefix) + } + persistedGlobalVars, err := SystemVariablesInConfig(globals) + if err != nil { + return err + } + sql.SystemVariables.AddSystemVariables(persistedGlobalVars) + return nil +} diff --git a/go/libraries/doltcore/sqle/dsess/dolt_session_test.go b/go/libraries/doltcore/sqle/dsess/dolt_session_test.go new file mode 100644 index 0000000000..91a000f6d2 --- /dev/null +++ b/go/libraries/doltcore/sqle/dsess/dolt_session_test.go @@ -0,0 +1,194 @@ +// Copyright 2021 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dsess + +import ( + "testing" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/stretchr/testify/assert" + "gopkg.in/src-d/go-errors.v1" + + "github.com/dolthub/dolt/go/libraries/utils/config" +) + +func TestDoltSessionInit(t *testing.T) { + sess := DefaultSession() + conf := config.NewMapConfig(make(map[string]string)) + dsess := sess.NewDoltSession(conf) + assert.Equal(t, conf, dsess.globalsConf) +} + +func TestNewPersistedSystemVariables(t *testing.T) { + sess := DefaultSession() + conf := config.NewMapConfig(map[string]string{"max_connections": "1000"}) + dsess := sess.NewDoltSession(conf) + sysVars, err := dsess.SystemVariablesInConfig() + assert.NoError(t, err) + + maxConRes := sysVars[0] + assert.Equal(t, "max_connections", maxConRes.Name) + assert.Equal(t, int64(1000), maxConRes.Default) + +} + +func TestValidatePeristableSystemVar(t *testing.T) { + tests := []struct { + Name string + Err *errors.Kind + }{ + { + Name: "max_connections", + Err: nil, + }, + { + Name: "init_file", + Err: sql.ErrSystemVariableReadOnly, + }, + { + Name: "unknown", + Err: sql.ErrUnknownSystemVariable, + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + if sysVar, _, err := validatePersistableSysVar(tt.Name); tt.Err != nil { + assert.True(t, tt.Err.Is(err)) + } else { + assert.Equal(t, tt.Name, sysVar.Name) + + } + }) + } +} + +func TestSetPersistedValue(t *testing.T) { + tests := []struct { + Name string + Value interface{} + ExpectedRes interface{} + Err *errors.Kind + }{ + { + Name: "int", + Value: 7, + }, + { + Name: "int8", + Value: int8(7), + }, + { + Name: "int16", + Value: int16(7), + }, + { + Name: "int32", + Value: int32(7), + }, + { + Name: "int64", + Value: int64(7), + }, + { + Name: "uint", + Value: uint(7), + }, + { + Name: "uint8", + Value: uint8(7), + }, + { + Name: "uint16", + Value: uint16(7), + }, + { + Name: "uint32", + Value: uint32(7), + }, + { + Name: "uint64", + Value: uint64(7), + }, + { + Name: "float32", + Value: float32(7), + ExpectedRes: "7.00000000", + }, + { + Name: "float64", + Value: float64(7), + ExpectedRes: "7.00000000", + }, + { + Name: "string", + Value: "7", + }, + { + Value: complex64(7), + Err: sql.ErrInvalidType, + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + conf := config.NewMapConfig(make(map[string]string)) + if err := setPersistedValue(conf, "key", tt.Value); tt.Err != nil { + assert.True(t, tt.Err.Is(err)) + } else if tt.ExpectedRes == nil { + assert.Equal(t, "7", conf.GetStringOrDefault("key", "")) + } else { + assert.Equal(t, tt.ExpectedRes, conf.GetStringOrDefault("key", "")) + + } + }) + } +} + +func TestGetPersistedValue(t *testing.T) { + tests := []struct { + Name string + ExpectedRes interface{} + Err *errors.Kind + }{ + { + Name: "long_query_time", + ExpectedRes: float64(7), + }, + { + Name: "tls_ciphersuites", + ExpectedRes: "7", + }, + { + Name: "max_connections", + ExpectedRes: int64(7), + }, + { + Name: "tmp_table_size", + ExpectedRes: uint64(7), + }, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + conf := config.NewMapConfig(map[string]string{tt.Name: "7"}) + if val, err := getPersistedValue(conf, tt.Name); tt.Err != nil { + assert.True(t, tt.Err.Is(err)) + } else { + assert.Equal(t, tt.ExpectedRes, val) + } + }) + } +} diff --git a/go/libraries/doltcore/sqle/dsess/session.go b/go/libraries/doltcore/sqle/dsess/session.go index a421e13f89..0a983abd6e 100644 --- a/go/libraries/doltcore/sqle/dsess/session.go +++ b/go/libraries/doltcore/sqle/dsess/session.go @@ -19,6 +19,7 @@ import ( "fmt" "os" "strings" + "sync" "github.com/dolthub/go-mysql-server/sql" @@ -46,6 +47,8 @@ const ( DoltDefaultBranchKey = "dolt_default_branch" ) +const NonpersistableSessionCode = 1105 // default + var transactionMergeStomp = false type batchMode int8 @@ -132,10 +135,9 @@ func IsWorkingKey(key string) (bool, string) { // Session is the sql.Session implementation used by dolt. It is accessible through a *sql.Context instance type Session struct { sql.Session - BatchMode batchMode - Username string - Email string - Config config.ReadableConfig + batchMode batchMode + username string + email string dbStates map[string]*DatabaseSessionState provider RevisionDatabaseProvider } @@ -181,8 +183,8 @@ var _ sql.Session = &Session{} func DefaultSession() *Session { sess := &Session{ Session: sql.NewBaseSession(), - Username: "", - Email: "", + username: "", + email: "", dbStates: make(map[string]*DatabaseSessionState), provider: emptyRevisionDatabaseProvider{}, } @@ -208,15 +210,13 @@ type InitialDbState struct { } // NewSession creates a Session object from a standard sql.Session and 0 or more Database objects. -func NewSession(ctx *sql.Context, sqlSess sql.Session, pro RevisionDatabaseProvider, conf config.ReadableConfig, dbs ...InitialDbState) (*Session, error) { +func NewSession(ctx *sql.Context, sqlSess *sql.BaseSession, pro RevisionDatabaseProvider, conf config.ReadableConfig, dbs ...InitialDbState) (*Session, error) { username := conf.GetStringOrDefault(env.UserNameKey, "") email := conf.GetStringOrDefault(env.UserEmailKey, "") - sess := &Session{ Session: sqlSess, - Username: username, - Email: email, - Config: conf, + username: username, + email: email, dbStates: make(map[string]*DatabaseSessionState), provider: pro, } @@ -236,12 +236,12 @@ func NewSession(ctx *sql.Context, sqlSess sql.Session, pro RevisionDatabaseProvi // Sessions operating in batched mode don't flush any edit buffers except when told to do so explicitly, or when a // transaction commits. Disable @@autocommit to prevent edit buffers from being flushed prematurely in this mode. func (sess *Session) EnableBatchedMode() { - sess.BatchMode = Batched + sess.batchMode = Batched } // DSessFromSess retrieves a dolt session from a standard sql.Session -func DSessFromSess(sess sql.Session) *Session { - return sess.(*Session) +func DSessFromSess(sess sql.Session) *DoltSession { + return sess.(*DoltSession) } // LookupDbState returns the session state for the database named @@ -279,6 +279,10 @@ func (sess *Session) LookupDbState(ctx *sql.Context, dbName string) (*DatabaseSe return s, ok, err } +func (sess *Session) GetDbStates() map[string]*DatabaseSessionState { + return sess.dbStates +} + // Flush flushes all changes sitting in edit sessions to the session root for the database named. This normally // happens automatically as part of statement execution, and is only necessary when the session is manually batched (as // for bulk SQL import) @@ -364,7 +368,7 @@ func (sess *Session) newWorkingSetForHead(ctx *sql.Context, wsRef ref.WorkingSet // CommitTransaction commits the in-progress transaction for the database named. Depending on session settings, this // may write only a new working set, or may additionally create a new dolt commit for the current HEAD. func (sess *Session) CommitTransaction(ctx *sql.Context, dbName string, tx sql.Transaction) error { - if sess.BatchMode == Batched { + if sess.BatchMode() == Batched { err := sess.Flush(ctx, dbName) if err != nil { return err @@ -397,8 +401,8 @@ func (sess *Session) CommitTransaction(ctx *sql.Context, dbName string, tx sql.T Date: ctx.QueryTime(), AllowEmpty: false, Force: false, - Name: sess.Username, - Email: sess.Email, + Name: sess.Username(), + Email: sess.Email(), }) if err != nil { return err @@ -1157,6 +1161,18 @@ func (sess *Session) CWBHeadRef(ctx *sql.Context, dbName string) (ref.DoltRef, e return dbState.WorkingSet.Ref().ToHeadRef() } +func (sess *Session) Username() string { + return sess.username +} + +func (sess *Session) Email() string { + return sess.email +} + +func (sess *Session) BatchMode() batchMode { + return sess.batchMode +} + // setSessionVarsForDb updates the three session vars that track the value of the session root hashes func (sess *Session) setSessionVarsForDb(ctx *sql.Context, dbName string) error { state, _, err := sess.lookupDbState(ctx, dbName) @@ -1208,6 +1224,11 @@ func (sess *Session) setSessionVarsForDb(ctx *sql.Context, dbName string) error return nil } +// NewDoltSession creates a persistable DoltSession with the given config arg +func (sess *Session) NewDoltSession(conf config.ReadWriteConfig) *DoltSession { + return &DoltSession{Session: sess, globalsConf: conf, mu: &sync.Mutex{}} +} + // defineSystemVariables defines dolt-session variables in the engine as necessary func defineSystemVariables(name string) { if _, _, ok := sql.SystemVariables.GetGlobal(name + HeadKeySuffix); !ok { @@ -1244,14 +1265,6 @@ func defineSystemVariables(name string) { Type: sql.NewSystemStringType(StagedKey(name)), Default: "", }, - { - Name: DoltDefaultBranchKey, - Scope: sql.SystemVariableScope_Global, - Dynamic: true, - SetVarHintApplies: false, - Type: sql.NewSystemStringType(DoltDefaultBranchKey), - Default: "", - }, }) } } diff --git a/go/libraries/doltcore/sqle/dsess/session_state_adapter.go b/go/libraries/doltcore/sqle/dsess/session_state_adapter.go index 4e5cc3a48d..b0c7b0ffa0 100644 --- a/go/libraries/doltcore/sqle/dsess/session_state_adapter.go +++ b/go/libraries/doltcore/sqle/dsess/session_state_adapter.go @@ -86,11 +86,11 @@ func NewSessionStateAdapter(session *Session, dbName string, remotes map[string] } func (s SessionStateAdapter) GetRoots(ctx context.Context) (doltdb.Roots, error) { - return s.session.dbStates[s.dbName].GetRoots(), nil + return s.session.GetDbStates()[s.dbName].GetRoots(), nil } func (s SessionStateAdapter) CWBHeadRef() ref.DoltRef { - workingSet := s.session.dbStates[s.dbName].WorkingSet + workingSet := s.session.GetDbStates()[s.dbName].WorkingSet headRef, err := workingSet.Ref().ToHeadRef() // TODO: fix this interface if err != nil { @@ -110,15 +110,15 @@ func (s SessionStateAdapter) CWBHeadSpec() *doltdb.CommitSpec { } func (s SessionStateAdapter) IsMergeActive(ctx context.Context) (bool, error) { - return s.session.dbStates[s.dbName].WorkingSet.MergeActive(), nil + return s.session.GetDbStates()[s.dbName].WorkingSet.MergeActive(), nil } func (s SessionStateAdapter) GetMergeCommit(ctx context.Context) (*doltdb.Commit, error) { - return s.session.dbStates[s.dbName].WorkingSet.MergeState().Commit(), nil + return s.session.GetDbStates()[s.dbName].WorkingSet.MergeState().Commit(), nil } func (s SessionStateAdapter) GetPreMergeWorking(ctx context.Context) (*doltdb.RootValue, error) { - return s.session.dbStates[s.dbName].WorkingSet.MergeState().PreMergeWorkingRoot(), nil + return s.session.GetDbStates()[s.dbName].WorkingSet.MergeState().PreMergeWorkingRoot(), nil } func (s SessionStateAdapter) GetRemotes() (map[string]env.Remote, error) { diff --git a/go/libraries/doltcore/sqle/dsess/transactions.go b/go/libraries/doltcore/sqle/dsess/transactions.go index 0eafab42ba..af7aa26a67 100644 --- a/go/libraries/doltcore/sqle/dsess/transactions.go +++ b/go/libraries/doltcore/sqle/dsess/transactions.go @@ -364,8 +364,8 @@ func (tx *DoltTransaction) ClearSavepoint(name string) *doltdb.RootValue { func (tx DoltTransaction) getWorkingSetMeta(ctx *sql.Context) *doltdb.WorkingSetMeta { sess := DSessFromSess(ctx.Session) return &doltdb.WorkingSetMeta{ - User: sess.Username, - Email: sess.Email, + User: sess.Username(), + Email: sess.Email(), Timestamp: uint64(time.Now().Unix()), Description: "sql transaction", } diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go b/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go index 2e001666d5..161dc4946b 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go @@ -19,8 +19,13 @@ import ( "github.com/dolthub/go-mysql-server/enginetest" "github.com/dolthub/go-mysql-server/sql" + "github.com/stretchr/testify/require" + "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils" + "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/doltcore/sqle" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" + "github.com/dolthub/dolt/go/libraries/utils/config" ) func init() { @@ -447,3 +452,19 @@ func TestTestReadOnlyDatabases(t *testing.T) { func TestAddDropPks(t *testing.T) { enginetest.TestAddDropPks(t, newDoltHarness(t)) } + +func TestPersist(t *testing.T) { + harness := newDoltHarness(t) + dEnv := dtestutils.CreateTestEnv() + localConf, ok := dEnv.Config.GetConfig(env.LocalConfig) + require.True(t, ok) + globals := config.NewPrefixConfig(localConf, env.SqlServerGlobalsPrefix) + newPersistableSession := func(ctx *sql.Context) sql.PersistableSession { + session := ctx.Session.(*dsess.DoltSession).Session.NewDoltSession(globals) + err := session.RemoveAllPersistedGlobals() + require.NoError(t, err) + return session + } + + enginetest.TestPersist(t, harness, newPersistableSession) +} diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_harness.go b/go/libraries/doltcore/sqle/enginetest/dolt_harness.go index 65e514f87d..105198d348 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_harness.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_harness.go @@ -40,7 +40,7 @@ const ( type DoltHarness struct { t *testing.T env *env.DoltEnv - session *dsess.Session + session *dsess.DoltSession databases []sqle.Database databaseGlobalStates []globalstate.GlobalState parallelism int @@ -58,7 +58,7 @@ var _ enginetest.ReadOnlyDatabaseHarness = (*DoltHarness)(nil) func newDoltHarness(t *testing.T) *DoltHarness { dEnv := dtestutils.CreateTestEnv() pro := sqle.NewDoltDatabaseProvider(dEnv.Config) - session, err := dsess.NewSession(sql.NewEmptyContext(), enginetest.NewBaseSession(), pro, dEnv.Config) + session, err := dsess.NewDoltSession(sql.NewEmptyContext(), enginetest.NewBaseSession(), pro, dEnv.Config) require.NoError(t, err) return &DoltHarness{ t: t, @@ -127,13 +127,13 @@ func (d *DoltHarness) NewContext() *sql.Context { func (d *DoltHarness) NewSession() *sql.Context { states := make([]dsess.InitialDbState, len(d.databases)) for i, db := range d.databases { - states[i] = getDbState(d.t, db, d.env, d.databaseGlobalStates[i]) + states[i] = getDbState(d.t, db, d.env) } dbs := dsqleDBsAsSqlDBs(d.databases) pro := d.NewDatabaseProvider(dbs...) var err error - d.session, err = dsess.NewSession( + d.session, err = dsess.NewDoltSession( enginetest.NewContext(d), enginetest.NewBaseSession(), pro.(dsess.RevisionDatabaseProvider), @@ -196,7 +196,7 @@ func (d *DoltHarness) NewDatabaseProvider(dbs ...sql.Database) sql.MutableDataba return sqle.NewDoltDatabaseProvider(d.env.Config, dbs...) } -func getDbState(t *testing.T, db sqle.Database, dEnv *env.DoltEnv, globalState globalstate.GlobalState) dsess.InitialDbState { +func getDbState(t *testing.T, db sqle.Database, dEnv *env.DoltEnv) dsess.InitialDbState { ctx := context.Background() head := dEnv.RepoStateReader().CWBHeadSpec() diff --git a/go/libraries/doltcore/sqle/logictest/dolt/doltharness.go b/go/libraries/doltcore/sqle/logictest/dolt/doltharness.go index 9d7b3030db..80931f8e50 100644 --- a/go/libraries/doltcore/sqle/logictest/dolt/doltharness.go +++ b/go/libraries/doltcore/sqle/logictest/dolt/doltharness.go @@ -50,7 +50,7 @@ const ( type DoltHarness struct { Version string engine *sqle.Engine - sess *dsess.Session + sess *dsess.DoltSession } func (h *DoltHarness) EngineStr() string { @@ -133,11 +133,8 @@ func innerInit(h *DoltHarness, dEnv *env.DoltEnv) error { return err } - h.sess = dsess.DefaultSession() - - ctx := sql.NewContext( - context.Background(), - sql.WithSession(h.sess)) + ctx := dsql.NewTestSQLCtx(context.Background()) + h.sess = ctx.Session.(*dsess.DoltSession) dbs := h.engine.Analyzer.Catalog.AllDatabases() dsqlDBs := make([]dsql.Database, len(dbs)) diff --git a/go/libraries/doltcore/sqle/mergeable_indexes_setup_test.go b/go/libraries/doltcore/sqle/mergeable_indexes_setup_test.go index 51c6942f56..38e3134357 100644 --- a/go/libraries/doltcore/sqle/mergeable_indexes_setup_test.go +++ b/go/libraries/doltcore/sqle/mergeable_indexes_setup_test.go @@ -104,7 +104,9 @@ func setupMergeableIndexes(t *testing.T, tableName, insertQuery string) (*sqle.E // Get an updated root to use for the rest of the test ctx := sql.NewEmptyContext() - roots, ok := dsess.DSessFromSess(sqlCtx.Session).GetRoots(ctx, mergeableDb.Name()) + sess, err := dsess.NewDoltSession(ctx, ctx.Session.(*sql.BaseSession), pro, dEnv.Config, getDbState(t, db, dEnv)) + require.NoError(t, err) + roots, ok := sess.GetRoots(ctx, mergeableDb.Name()) require.True(t, ok) return engine, dEnv, mergeableDb, []*indexTuple{ diff --git a/go/libraries/doltcore/sqle/sqlpersist_test.go b/go/libraries/doltcore/sqle/sqlpersist_test.go new file mode 100644 index 0000000000..7343f76ebf --- /dev/null +++ b/go/libraries/doltcore/sqle/sqlpersist_test.go @@ -0,0 +1,121 @@ +// Copyright 2020 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqle + +import ( + "context" + "testing" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils" + "github.com/dolthub/dolt/go/libraries/doltcore/row" + "github.com/dolthub/dolt/go/libraries/doltcore/schema" + "github.com/dolthub/dolt/go/store/types" +) + +// Structure for a test of a insert query +type PersistTest struct { + // The name of this test. Names should be unique and descriptive. + Name string + // The insert query to run + PersistQuery string + // The insert query to run + SelectQuery string + // The schema of the result of the query, nil if an error is expected + ExpectedSchema schema.Schema + // The rows this query should return, nil if an error is expected + ExpectedRows []sql.Row + // The rows this query should return, nil if an error is expected + ExpectedConfig map[string]string + // An expected error string + ExpectedErr string + // Setup logic to run before executing this test, after initial tables have been created and populated + AdditionalSetup SetupFn +} + +const maxConnTag = 0 + +var MaxConnSchema = createMaxConnSchema() + +func createMaxConnSchema() schema.Schema { + colColl := schema.NewColCollection( + schema.NewColumn("@@GLOBAL.max_connections", maxConnTag, types.IntKind, false, schema.NotNullConstraint{}), + ) + return schema.MustSchemaFromCols(colColl) +} + +func NewMaxConnRow(value int) row.Row { + vals := row.TaggedValues{ + maxConnTag: types.Int(value), + } + + r, _ := row.New(types.Format_Default, MaxConnSchema, vals) + return r +} + +func TestExecutePersist(t *testing.T) { + var persistTests = []PersistTest{ + { + Name: "SET PERSIST a system variable", + PersistQuery: "SET PERSIST max_connections = 1000;", + ExpectedConfig: map[string]string{"max_connections": "1000"}, + SelectQuery: "SELECT @@GLOBAL.max_connections", + ExpectedRows: ToSqlRows(MaxConnSchema, NewMaxConnRow(1000)), + }, + { + Name: "PERSIST ONLY a system variable", + PersistQuery: "SET PERSIST_ONLY max_connections = 1000;", + ExpectedConfig: map[string]string{"max_connections": "1000"}, + SelectQuery: "SELECT @@GLOBAL.max_connections", + ExpectedRows: ToSqlRows(MaxConnSchema, NewMaxConnRow(151)), + }, + } + for _, test := range persistTests { + t.Run(test.Name, func(t *testing.T) { + testPersistQuery(t, test) + }) + } +} + +// Tests the given query on a freshly created dataset, asserting that the result has the given schema and rows. If +// expectedErr is set, asserts instead that the execution returns an error that matches. +func testPersistQuery(t *testing.T, test PersistTest) { + dEnv := dtestutils.CreateTestEnv() + CreateEmptyTestDatabase(dEnv, t) + + if test.AdditionalSetup != nil { + test.AdditionalSetup(t, dEnv) + } + + sql.InitSystemVariables() + + var err error + root, _ := dEnv.WorkingRoot(context.Background()) + root, err = executeModify(t, context.Background(), dEnv, root, test.PersistQuery) + if len(test.ExpectedErr) > 0 { + require.Error(t, err) + return + } else { + require.NoError(t, err) + } + + actualRows, _, err := executeSelect(t, context.Background(), dEnv, root, test.SelectQuery) + require.NoError(t, err) + + assert.Equal(t, test.ExpectedRows, actualRows) +} diff --git a/go/libraries/doltcore/sqle/table_editor.go b/go/libraries/doltcore/sqle/table_editor.go index 3bcb18888e..f46649c134 100644 --- a/go/libraries/doltcore/sqle/table_editor.go +++ b/go/libraries/doltcore/sqle/table_editor.go @@ -207,7 +207,7 @@ func (te *sqlTableEditor) Close(ctx *sql.Context) error { sess := dsess.DSessFromSess(ctx.Session) // If we're running in batched mode, don't flush the edits until explicitly told to do so - if sess.BatchMode == dsess.Batched { + if sess.BatchMode() == dsess.Batched { return nil } return te.flush(ctx) diff --git a/go/libraries/doltcore/sqle/tables.go b/go/libraries/doltcore/sqle/tables.go index badb4a175c..a970da9055 100644 --- a/go/libraries/doltcore/sqle/tables.go +++ b/go/libraries/doltcore/sqle/tables.go @@ -510,7 +510,7 @@ func (t *WritableDoltTable) getTableEditor(ctx *sql.Context) (*sqlTableEditor, e sess := dsess.DSessFromSess(ctx.Session) // In batched mode, reuse the same table editor. Otherwise, hand out a new one - if sess.BatchMode == dsess.Batched { + if sess.BatchMode() == dsess.Batched { if t.ed != nil { return t.ed, nil } diff --git a/go/libraries/doltcore/sqle/testutil.go b/go/libraries/doltcore/sqle/testutil.go index e53c908986..4323386331 100644 --- a/go/libraries/doltcore/sqle/testutil.go +++ b/go/libraries/doltcore/sqle/testutil.go @@ -31,6 +31,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" "github.com/dolthub/dolt/go/libraries/doltcore/table/editor" + config2 "github.com/dolthub/dolt/go/libraries/utils/config" ) // ExecuteSql executes all the SQL non-select statements given in the string against the root value given and returns @@ -97,9 +98,10 @@ func ExecuteSql(t *testing.T, dEnv *env.DoltEnv, root *doltdb.RootValue, stateme // NewTestSQLCtx returns a new *sql.Context with a default DoltSession, a new IndexRegistry, and a new ViewRegistry func NewTestSQLCtx(ctx context.Context) *sql.Context { session := dsess.DefaultSession() + dsess := session.NewDoltSession(config2.NewMapConfig(make(map[string]string))) sqlCtx := sql.NewContext( ctx, - sql.WithSession(session), + sql.WithSession(dsess), ).WithCurrentDB("dolt") return sqlCtx diff --git a/go/libraries/utils/config/config.go b/go/libraries/utils/config/config.go index 023a92cead..f8c5f6d6f1 100644 --- a/go/libraries/utils/config/config.go +++ b/go/libraries/utils/config/config.go @@ -107,6 +107,10 @@ func SetStrings(c WritableConfig, updates map[string]string) error { return c.SetStrings(updates) } +func SetString(c WritableConfig, key string, val string) error { + return c.SetStrings(map[string]string{key: val}) +} + // SetInt sets a value in the WritableConfig for a given key to the string converted value of an integer func SetInt(c WritableConfig, key string, val int64) error { s := strconv.FormatInt(val, 10) diff --git a/go/libraries/utils/config/config_hierarchy.go b/go/libraries/utils/config/config_hierarchy.go index 4de77774ad..336c6cc913 100644 --- a/go/libraries/utils/config/config_hierarchy.go +++ b/go/libraries/utils/config/config_hierarchy.go @@ -23,6 +23,8 @@ const ( namespaceSep = "::" ) +var ErrUnknownConfig = errors.New("config not found") + // ConfigHierarchy is a hierarchical read-only configuration store. When a key is looked up in the ConfigHierarchy it // will go through its configs in order and will return the first value for a given key that is found. Configs are // iterated in order, so the configurations added first have the highest priority. diff --git a/go/libraries/utils/config/config_test.go b/go/libraries/utils/config/map_config_test.go similarity index 100% rename from go/libraries/utils/config/config_test.go rename to go/libraries/utils/config/map_config_test.go diff --git a/go/libraries/utils/config/prefix_config.go b/go/libraries/utils/config/prefix_config.go new file mode 100644 index 0000000000..05ba67e030 --- /dev/null +++ b/go/libraries/utils/config/prefix_config.go @@ -0,0 +1,77 @@ +// Copyright 2021 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "strings" +) + +// PrefixConfig decorates read and write access to the underlying config by appending a prefix to the accessed keys +// on reads and writes. Used for the sqlserver.global persisted system variables. +type PrefixConfig struct { + c ReadWriteConfig + prefix string +} + +func NewPrefixConfig(cfg ReadWriteConfig, prefix string) PrefixConfig { + return PrefixConfig{c: cfg, prefix: prefix} +} + +func (nsc PrefixConfig) path(key string) string { + return fmt.Sprintf("%s.%s", nsc.prefix, key) +} + +func (nsc PrefixConfig) GetString(key string) (value string, err error) { + return nsc.c.GetString(nsc.path(key)) +} + +func (nsc PrefixConfig) GetStringOrDefault(key, defStr string) string { + return nsc.c.GetStringOrDefault(nsc.path(key), defStr) +} + +func (nsc PrefixConfig) SetStrings(updates map[string]string) error { + for k, v := range updates { + delete(updates, k) + updates[nsc.path(k)] = v + } + return nsc.c.SetStrings(updates) +} + +func (nsc PrefixConfig) Iter(cb func(string, string) (stop bool)) { + nsc.c.Iter(func(k, v string) (stop bool) { + if strings.HasPrefix(k, nsc.prefix+".") { + return cb(strings.TrimPrefix(k, nsc.prefix+"."), v) + } + return false + }) + return +} + +func (nsc PrefixConfig) Size() int { + count := 0 + nsc.Iter(func(k, v string) (stop bool) { + count += 1 + return false + }) + return count +} + +func (nsc PrefixConfig) Unset(params []string) error { + for i, k := range params { + params[i] = nsc.path(k) + } + return nsc.c.Unset(params) +} diff --git a/go/libraries/utils/config/prefix_config_test.go b/go/libraries/utils/config/prefix_config_test.go new file mode 100644 index 0000000000..e53956090c --- /dev/null +++ b/go/libraries/utils/config/prefix_config_test.go @@ -0,0 +1,112 @@ +// Copyright 2020 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "sort" + "testing" + + "github.com/stretchr/testify/assert" +) + +var ConfigVals = map[string]string{ + "scopeA.k1": "v1", + "scopeA.k2": "v2", + "scopeB.k3": "v3", + "k1": "v1", +} + +func newConfigVals() map[string]string { + newConfig := make(map[string]string) + for k, v := range ConfigVals { + newConfig[k] = v + } + return newConfig +} + +func newPrefixConfig(prefix string) PrefixConfig { + + mc := NewMapConfig(newConfigVals()) + return NewPrefixConfig(mc, prefix) + +} + +func TestPrefixConfigSet(t *testing.T) { + conf := newPrefixConfig("test") + conf.SetStrings(newConfigVals()) + v1, _ := conf.c.GetString("test.k1") + assert.Equal(t, v1, "v1") +} + +func TestPrefixConfigGet(t *testing.T) { + t.Run("test GetString", func(t *testing.T) { + conf := newPrefixConfig("scopeA") + v1, _ := conf.GetString("k1") + assert.Equal(t, "v1", v1) + }) + + t.Run("test GetString fails out of scope", func(t *testing.T) { + conf := newPrefixConfig("scopeA") + _, err := conf.GetString("k3") + assert.Equal(t, err, ErrConfigParamNotFound) + }) + + t.Run("test GetStringofDefault", func(t *testing.T) { + conf := newPrefixConfig("scopeA") + v1, _ := conf.GetString("k1") + assert.Equal(t, "v1", v1) + }) + + t.Run("test GetStringOrDefault fails out of scope", func(t *testing.T) { + conf := newPrefixConfig("scopeA") + res := conf.GetStringOrDefault("k3", "default") + assert.Equal(t, "default", res) + }) +} + +func TestPrefixConfigUnset(t *testing.T) { + t.Run("test Unset", func(t *testing.T) { + conf := newPrefixConfig("scopeA") + err := conf.Unset([]string{"k1"}) + assert.NoError(t, err) + res := conf.GetStringOrDefault("k3", "default") + assert.Equal(t, "default", res) + }) + + t.Run("test Unset doesn't affect other scope", func(t *testing.T) { + conf := newPrefixConfig("scopeA") + err := conf.Unset([]string{"k1"}) + assert.NoError(t, err) + res := conf.c.GetStringOrDefault("k1", "") + assert.Equal(t, "v1", res) + }) +} + +func TestPrefixConfigSize(t *testing.T) { + conf := newPrefixConfig("scopeA") + size := conf.Size() + assert.Equal(t, size, 2) +} + +func TestPrefixConfigIter(t *testing.T) { + conf := newPrefixConfig("scopeA") + keys := make([]string, 0, 6) + conf.Iter(func(k, v string) bool { + keys = append(keys, k) + return false + }) + sort.Strings(keys) + assert.Equal(t, []string{"k1", "k2"}, keys) +} diff --git a/integration-tests/bats/replication.bats b/integration-tests/bats/replication.bats index 8aa949fc19..e60b779c65 100644 --- a/integration-tests/bats/replication.bats +++ b/integration-tests/bats/replication.bats @@ -30,12 +30,24 @@ teardown() { [ ! -d "../bac1/.dolt" ] || false } -@test "replication: push on commit" { +@test "replication: no push on cli commit" { + cd repo1 - dolt config --local --add DOLT_REPLICATE_TO_REMOTE backup1 + dolt config --local --add sqlserver.global.DOLT_REPLICATE_TO_REMOTE backup1 dolt sql -q "create table t1 (a int primary key)" dolt commit -am "cm" + cd .. + run dolt clone file://./bac1 repo2 + [ "$status" -eq 1 ] +} + +@test "replication: push on cli engine commit" { + cd repo1 + dolt config --local --add sqlserver.global.DOLT_REPLICATE_TO_REMOTE backup1 + dolt sql -q "create table t1 (a int primary key)" + dolt sql -q "select dolt_commit('-am', 'cm')" + cd .. dolt clone file://./bac1 repo2 cd repo2 @@ -47,7 +59,7 @@ teardown() { @test "replication: no tags" { cd repo1 - dolt config --local --add DOLT_REPLICATE_TO_REMOTE backup1 + dolt config --local --add sqlserver.global.DOLT_REPLICATE_TO_REMOTE backup1 dolt tag [ ! -d "../bac1/.dolt" ] || false @@ -66,7 +78,7 @@ teardown() { [ "${#lines[@]}" -eq 1 ] [[ ! "$output" =~ "t1" ]] || false - dolt config --local --add DOLT_READ_REPLICA_REMOTE remote1 + dolt config --local --add sqlserver.global.DOLT_READ_REPLICA_REMOTE remote1 run dolt sql -q "show tables" -r csv [ "$status" -eq 0 ] [ "${#lines[@]}" -eq 2 ] @@ -75,10 +87,9 @@ teardown() { @test "replication: replicate on branch table update" { cd repo1 - dolt config --local --add DOLT_REPLICATE_TO_REMOTE backup1 + dolt config --local --add sqlserver.global.DOLT_REPLICATE_TO_REMOTE backup1 dolt sql -q "create table t1 (a int primary key)" dolt sql -q "UPDATE dolt_branches SET hash = COMMIT('--author', '{user_name} <{email_address}>','-m', 'cm') WHERE name = 'main' AND hash = @@repo1_head" - noms ds ../bac1/.dolt cd .. dolt clone file://./bac1 repo2 diff --git a/integration-tests/bats/sql-config.bats b/integration-tests/bats/sql-config.bats new file mode 100644 index 0000000000..1309f4f068 --- /dev/null +++ b/integration-tests/bats/sql-config.bats @@ -0,0 +1,112 @@ +#!/usr/bin/env bats +load $BATS_TEST_DIRNAME/helper/common.bash + +setup() { + mkdir $BATS_TMPDIR/dolt-repo-$$ + cd $BATS_TMPDIR/dolt-repo-$$ + setup_common +} + +teardown() { + teardown_common + rm -rf "$BATS_TMPDIR/config-test$$" +} + +@test "sql-config: query persisted variable with cli engine" { + echo '{"sqlserver.global.max_connections":"1000"}' > .dolt/config.json + run dolt sql -q "SELECT @@GLOBAL.max_connections" -r csv + [ "$status" -eq 0 ] + [ "${#lines[@]}" -eq 2 ] + [[ "${lines[0]}" =~ "@@GLOBAL.max_connections" ]] || false + [[ "${lines[1]}" =~ "1000" ]] || false +} + +@test "sql-config: set persist global variable with cli engine" { + dolt sql -q "SET PERSIST max_connections = 1000" + run dolt config --local --list + [ "$status" -eq 0 ] + [[ "$output" =~ "sqlserver.global.max_connections = 1000" ]] || false +} + +@test "sql-config: set persist multiple global variables with cli engine" { + dolt sql -q "SET PERSIST max_connections = 1000" + dolt sql -q "SET PERSIST auto_increment_increment = 2" + run dolt config --local --list + [ "$status" -eq 0 ] + [[ "$output" =~ "sqlserver.global.auto_increment_increment = 2" ]] || false + [[ "$output" =~ "sqlserver.global.max_connections = 1000" ]] || false +} + +@test "sql-config: persist only global variable with cli engine" { + dolt sql -q "SET PERSIST_ONLY max_connections = 1000" + run dolt config --local --list + [ "$status" -eq 0 ] + [[ "$output" =~ "sqlserver.global.max_connections = 1000" ]] || false +} + + +@test "sql-config: remove persisted variable with cli engine" { + skip "TODO parser support for RESET PERSIST" + + dolt sql -q "SET PERSIST_ONLY max_connections = 1000" + dolt sql -q "RESET PERSIST max_connections" + + run dolt config --local --list + [ "$status" -eq 0 ] + [[ ! "$output" =~ "sqlserver.global.max_connections = 1000" ]] || false +} + +@test "sql-config: remove all persisted variables with cli engine" { + skip "TODO parser support for RESET PERSIST" + + dolt sql -q "SET PERSIST_ONLY max_connections = 1000" + dolt sql -q "SET PERSIST_ONLY auto_increment_increment = 2" + dolt sql -q "RESET PERSIST" + + run dolt config --local --list + [ "$status" -eq 0 ] + [[ ! "$output" =~ "sqlserver.global.max_connections = 1000" ]] || false + [[ ! "$output" =~ "sqlserver.global.auto_increment_increment = 2" ]] || false +} + +@test "sql-config: persist dolt specific global variable" { + mkdir repo1 + cd repo1 + dolt init + dolt sql -q "SET PERSIST_ONLY repo1_head = 1000" + run dolt config --local --list + [ "$status" -eq 0 ] + [[ "$output" =~ "sqlserver.global.repo1_head = 1000" ]] || false +} + +@test "sql-config: persist invalid variable name" { + run dolt sql -q "SET PERSIST unknown = 1000" + [ "$status" -eq 1 ] + [[ ! "$output" =~ "panic" ]] || false + [[ "$output" =~ "Unknown system variable 'unknown'" ]] || false +} + +@test "sql-config: persist invalid variable type" { + run dolt sql -q "SET PERSIST max_connections = string" + [ "$status" -eq 1 ] + [[ ! "$output" =~ "panic" ]] || false + [[ "$output" =~ "Variable 'max_connections' can't be set to the value of 'string'" ]] || false +} + +@test "sql-config: invalid persisted system variable name errors on cli sql command" { + echo '{"sqlserver.global.unknown":"1000"}' > .dolt/config.json + run dolt sql -q "SELECT @@GLOBAL.unknown" -r csv + [ "$status" -eq 1 ] + [[ ! "$output" =~ "panic" ]] + [[ "$output" =~ "Unknown system variable 'unknown'" ]] || false +} + +@test "sql-config: invalid persisted system variable type errors on cli sql command" { + echo '{"sqlserver.global.max_connections":"string"}' > .dolt/config.json + run dolt sql -q "SELECT @@GLOBAL.max_connections" -r csv + [ "$status" -eq 0 ] + [[ ! "$output" =~ "panic" ]] + [[ "$output" =~ "failed to load persisted global variables: key: 'max_connections'" ]] || false + [[ "$output" =~ "invalid syntax" ]] || false + [[ "$output" =~ "151" ]] +} diff --git a/integration-tests/bats/sql-server-config.bats b/integration-tests/bats/sql-server-config.bats new file mode 100644 index 0000000000..15e36e2ea4 --- /dev/null +++ b/integration-tests/bats/sql-server-config.bats @@ -0,0 +1,113 @@ +#!/usr/bin/env bats +load $BATS_TEST_DIRNAME/helper/common.bash +load $BATS_TEST_DIRNAME/helper/query-server-common.bash + +make_repo() { + mkdir "$1" + cd "$1" + dolt init + cd .. +} + +setup() { + setup_no_dolt_init + make_repo repo1 + make_repo repo2 +} + +teardown() { + stop_sql_server + teardown_common +} + +@test "sql-server-config: persist global variable before server startup" { + cd repo1 + echo '{"sqlserver.global.max_connections":"1000"}' > .dolt/config.json + start_sql_server repo1 + + server_query repo1 1 "select @@GLOBAL.max_connections" "@@GLOBAL.max_connections\n1000" + +} + +@test "sql-server-config: invalid persisted global variable name throws warning on server startup, but does not crash" { + cd repo1 + echo '{"sqlserver.global.unknown":"1000"}' > .dolt/config.json + start_sql_server repo1 +} + +@test "sql-server-config: invalid persisted global variable value throws warning on server startup, but does not crash" { + cd repo1 + echo '{"server.max_connections":"string"}' > .dolt/config.json + start_sql_server repo1 +} + +@test "sql-server-config: persisted global variable in server" { + cd repo1 + start_sql_server repo1 + + insert_query repo1 1 "SET @@PERSIST.max_connections = 1000" + server_query repo1 1 "select @@GLOBAL.max_connections" "@@GLOBAL.max_connections\n1000" + + run dolt config --local --list + [ "$status" -eq 0 ] + [[ "$output" =~ "sqlserver.global.max_connections = 1000" ]] || false +} + +@test "sql-server-config: persist only global variable during server session" { + cd repo1 + start_sql_server repo1 + + insert_query repo1 1 "SET PERSIST max_connections = 1000" + insert_query repo1 1 "SET PERSIST_ONLY max_connections = 7777" + server_query repo1 1 "select @@GLOBAL.max_connections" "@@GLOBAL.max_connections\n1000" + + run dolt config --local --list + [ "$status" -eq 0 ] + [[ "$output" =~ "sqlserver.global.max_connections = 7777" ]] || false +} + +@test "sql-server-config: persist invalid global variable name during server session" { + cd repo1 + start_sql_server repo1 + run insert_query repo1 1 "SET @@PERSIST.unknown = 1000" + [ "$status" -eq 1 ] + [[ ! "$output" =~ "panic" ]] + [[ "$output" =~ "Unknown system variable 'unknown'" ]] || false +} + +@test "sql-server-config: persist invalid global variable value during server session" { + cd repo1 + start_sql_server repo1 + run insert_query repo1 1 "SET @@PERSIST.max_connections = 'string'" + [ "$status" -eq 1 ] + [[ ! "$output" =~ "panic" ]] + [[ "$output" =~ "Variable 'max_connections' can't be set to the value of 'string'" ]] || false +} + +@test "sql-server-config: reset persisted variable" { + skip "TODO: parser support for RESET PERSIST" + cd repo1 + start_sql_server repo1 + + insert_query repo1 1 "SET @@PERSIST.max_connections = 1000" + insert_query repo1 1 "RESET @@PERSIST.max_connections" + + run dolt config --local --list + [ "$status" -eq 0 ] + [[ ! "$output" =~ "sqlserver.global.max_connections = 1000" ]] || false +} + +@test "sql-server-config: reset all persisted variables" { + skip "TODO: parser support for RESET PERSIST" + cd repo1 + start_sql_server repo1 + + insert_query repo1 1 "SET @@PERSIST.max_connections = 1000" + insert_query repo1 1 "SET @@PERSIST.auto_increment_increment = 1000" + insert_query repo1 1 "RESET PERSIST" + + run dolt config --local --list + [ "$status" -eq 0 ] + [[ ! "$output" =~ "sqlserver.global.max_connections = 1000" ]] || false + [[ ! "$output" =~ "sqlserver.global.auto_increment_increment = 1000" ]] || false +} diff --git a/integration-tests/bats/sql-server.bats b/integration-tests/bats/sql-server.bats index 9decb7a0f9..5504fde717 100644 --- a/integration-tests/bats/sql-server.bats +++ b/integration-tests/bats/sql-server.bats @@ -1055,7 +1055,7 @@ while True: mkdir bac1 cd repo1 dolt remote add backup1 file://../bac1 - dolt config --local --add DOLT_REPLICATE_TO_REMOTE backup1 + dolt config --local --add sqlserver.global.DOLT_REPLICATE_TO_REMOTE backup1 start_sql_server repo1 multi_query repo1 1 " @@ -1097,7 +1097,7 @@ while True: dolt push -u remote1 main cd ../repo1 - dolt config --local --add DOLT_READ_REPLICA_REMOTE remote1 + dolt config --local --add sqlserver.global.DOLT_READ_REPLICA_REMOTE remote1 start_sql_server repo1 server_query repo1 1 "show tables" "Table\ntest" diff --git a/integration-tests/bats/sql-use.expect b/integration-tests/bats/sql-use.expect index bb0da9d586..fc5ab801db 100644 --- a/integration-tests/bats/sql-use.expect +++ b/integration-tests/bats/sql-use.expect @@ -1,6 +1,6 @@ #!/usr/bin/expect -set timeout 1 +set timeout 2 spawn dolt sql expect { "doltsql> " { send "use `doltsql/test`;\r"; }