Support per user session variables in the server config

This commit is contained in:
Brian Hendriks
2022-07-11 15:26:38 -07:00
parent efdf426a66
commit fdb5104efe
14 changed files with 203 additions and 23 deletions
+1 -1
View File
@@ -362,7 +362,7 @@ func getMultiRepoEnv(ctx context.Context, apr *argparser.ArgParseResults, dEnv *
multiDir, multiDbMode := apr.GetValue(multiDBDirFlag)
if multiDbMode {
var err error
mrEnv, err = env.LoadMultiEnvFromDir(ctx, env.GetCurrentUserHomeDir, dEnv.Config.WriteableConfig(), dEnv.FS, multiDir, cmd.VersionStr)
mrEnv, err = env.LoadMultiEnvFromDir(ctx, env.GetCurrentUserHomeDir, dEnv.Config.WriteableConfig(), dEnv.FS, multiDir, cmd.VersionStr, dEnv.IgnoreLocks)
if err != nil {
return nil, errhand.VerboseErrorFromError(err)
}
+31 -5
View File
@@ -95,7 +95,7 @@ func Serve(
}
// TODO: this should be the global config, probably?
mrEnv, err = env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), fs, dEnv.Version)
mrEnv, err = env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), fs, dEnv.Version, dEnv.IgnoreLocks)
if err != nil {
return err, nil
}
@@ -117,7 +117,7 @@ func Serve(
}
// TODO: this should be the global config, probably?
mrEnv, err = env.LoadMultiEnv(ctx, env.GetCurrentUserHomeDir, dEnv.Config.WriteableConfig(), fs, version, dbNamesAndPaths...)
mrEnv, err = env.LoadMultiEnv(ctx, env.GetCurrentUserHomeDir, dEnv.Config.WriteableConfig(), fs, version, dEnv.IgnoreLocks, dbNamesAndPaths...)
if err != nil {
return err, nil
@@ -158,7 +158,7 @@ func Serve(
mySQLServer, startError = server.NewServer(
serverConf,
sqlEngine.GetUnderlyingEngine(),
newSessionBuilder(sqlEngine),
newSessionBuilder(sqlEngine, serverConfig),
listener,
)
@@ -222,7 +222,13 @@ func portInUse(hostPort string) bool {
return false
}
func newSessionBuilder(se *engine.SqlEngine) server.SessionBuilder {
func newSessionBuilder(se *engine.SqlEngine, config ServerConfig) server.SessionBuilder {
userToSessionVars := make(map[string]map[string]string)
userVars := config.UserVars()
for _, curr := range userVars {
userToSessionVars[curr.Name] = curr.Vars
}
return func(ctx context.Context, conn *mysql.Conn, host string) (sql.Session, error) {
mysqlSess, err := server.DefaultSessionBuilder(ctx, conn, host)
if err != nil {
@@ -233,7 +239,27 @@ func newSessionBuilder(se *engine.SqlEngine) server.SessionBuilder {
return nil, fmt.Errorf("unknown GMS base session type")
}
return se.NewDoltSession(ctx, mysqlBaseSess)
dsess, err := se.NewDoltSession(ctx, mysqlBaseSess)
if err != nil {
return nil, err
}
varsForUser := userToSessionVars[conn.User]
if len(varsForUser) > 0 {
sqlCtx, err := se.NewContext(ctx)
if err != nil {
return nil, err
}
for key, val := range varsForUser {
err = dsess.InitSessionVariable(sqlCtx, key, val)
if err != nil {
return nil, err
}
}
}
return dsess, nil
}
}
@@ -125,6 +125,8 @@ type ServerConfig interface {
// PrivilegeFilePath returns the path to the file which contains all needed privilege information in the form of a
// JSON string.
PrivilegeFilePath() string
// UserVars is an array containing user specific session variables
UserVars() []UserSessionVars
}
type commandLineServerConfig struct {
@@ -241,6 +243,10 @@ func (cfg *commandLineServerConfig) PrivilegeFilePath() string {
return cfg.privilegeFilePath
}
func (cfg *commandLineServerConfig) UserVars() []UserSessionVars {
return nil
}
// DatabaseNamesAndPaths returns an array of env.EnvNameAndPathObjects corresponding to the databases to be loaded in
// a multiple db configuration. If nil is returned the server will look for a database in the current directory and
// give it a name automatically.
@@ -107,6 +107,11 @@ type MetricsYAMLConfig struct {
Port *int `yaml:"port"`
}
type UserSessionVars struct {
Name string `yaml:"name"`
Vars map[string]string `yaml:"vars"`
}
// YAMLConfig is a ServerConfig implementation which is read from a yaml file
type YAMLConfig struct {
LogLevelStr *string `yaml:"log_level"`
@@ -118,6 +123,7 @@ type YAMLConfig struct {
DataDirStr *string `yaml:"data_dir"`
MetricsConfig MetricsYAMLConfig `yaml:"metrics"`
PrivilegeFile *string `yaml:"privilege_file"`
Vars []UserSessionVars `yaml:"user_session_vars"`
}
var _ ServerConfig = YAMLConfig{}
@@ -324,6 +330,14 @@ func (cfg YAMLConfig) PrivilegeFilePath() string {
return ""
}
func (cfg YAMLConfig) UserVars() []UserSessionVars {
if cfg.Vars != nil {
return cfg.Vars
}
return nil
}
// QueryParallelism returns the parallelism that should be used by the go-mysql-server analyzer
func (cfg YAMLConfig) QueryParallelism() int {
if cfg.PerformanceConfig.QueryParallelism == nil {
@@ -59,6 +59,18 @@ metrics:
label1: value1
label2: 2
label3: true
user_session_vars:
- name: user0
vars:
var1: val0_1
var2: val0_2
var3: val0_3
- name: user1
vars:
var1: val1_1
var2: val1_2
var4: val1_4
`
expected := serverConfigAsYAMLConfig(DefaultServerConfig())
@@ -82,6 +94,24 @@ metrics:
},
}
expected.DataDirStr = strPtr("some nonsense")
expected.Vars = []UserSessionVars{
{
Name: "user0",
Vars: map[string]string{
"var1": "val0_1",
"var2": "val0_2",
"var3": "val0_3",
},
},
{
Name: "user1",
Vars: map[string]string{
"var1": "val1_1",
"var2": "val1_2",
"var4": "val1_4",
},
},
}
config, err := NewYamlConfig([]byte(testStr))
require.NoError(t, err)
+8
View File
@@ -123,6 +123,8 @@ const stdInFlag = "--stdin"
const stdOutFlag = "--stdout"
const stdErrFlag = "--stderr"
const stdOutAndErrFlag = "--out-and-err"
const ignoreLocksFlag = "--ignore-locks"
const cpuProf = "cpu"
const memProf = "mem"
const blockingProf = "blocking"
@@ -138,6 +140,7 @@ func runMain() int {
args := os.Args[1:]
csMetrics := false
ignoreLocks := false
if len(args) > 0 {
var doneDebugFlags bool
for !doneDebugFlags && len(args) > 0 {
@@ -274,6 +277,10 @@ func runMain() int {
csMetrics = true
args = args[1:]
case ignoreLocksFlag:
ignoreLocks = true
args = args[1:]
case featureVersionFlag:
if featureVersion, err := strconv.Atoi(args[1]); err == nil {
doltdb.DoltFeatureVersion = doltdb.FeatureVersion(featureVersion)
@@ -297,6 +304,7 @@ func runMain() int {
ctx := context.Background()
dEnv := env.Load(ctx, env.GetCurrentUserHomeDir, filesys.LocalFS, doltdb.LocalDirDoltDB, Version)
dEnv.IgnoreLocks = ignoreLocks
root, err := env.GetCurrentUserHomeDir()
if err != nil {
+1 -1
View File
@@ -59,7 +59,7 @@ require (
)
require (
github.com/dolthub/go-mysql-server v0.12.1-0.20220708213239-a9724caf9408
github.com/dolthub/go-mysql-server v0.12.1-0.20220711205846-1d24942a3ec6
github.com/google/flatbuffers v2.0.6+incompatible
github.com/gosuri/uilive v0.0.4
github.com/kch42/buzhash v0.0.0-20160816060738-9bdec3dec7c6
+3 -2
View File
@@ -175,8 +175,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm
github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no=
github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
github.com/dolthub/go-mysql-server v0.12.1-0.20220708213239-a9724caf9408 h1:d9S6ejy+EF2fbAOXl7K9DZCm79mfZyZY/dgtoZu21DI=
github.com/dolthub/go-mysql-server v0.12.1-0.20220708213239-a9724caf9408/go.mod h1:fhyVDvV0K59cdk9N7TQsPjr2Hp/Qseej8+R9tVqPDCg=
github.com/dolthub/go-mysql-server v0.12.1-0.20220711205846-1d24942a3ec6 h1:ut47h+4QLL+WatIf/N0MAJGtrFJ1PF9fDsy++/J3blw=
github.com/dolthub/go-mysql-server v0.12.1-0.20220711205846-1d24942a3ec6/go.mod h1:fhyVDvV0K59cdk9N7TQsPjr2Hp/Qseej8+R9tVqPDCg=
github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371 h1:oyPHJlzumKta1vnOQqUnfdz+pk3EmnHS3Nd0cCT0I2g=
github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371/go.mod h1:dhGBqcCEfK5kuFmeO5+WOx3hqc1k3M29c1oS/R7N4ms=
github.com/dolthub/jsonpath v0.0.0-20210609232853-d49537a30474 h1:xTrR+l5l+1Lfq0NvhiEsctylXinUMFhhsqaEcl414p8=
@@ -238,6 +238,7 @@ github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.1/go.mod h1:7FAglXiTm7HKlQRDeOQ6ZNUHidzCWXuZWq/1dTyBNF8=
+14
View File
@@ -98,6 +98,8 @@ type DoltEnv struct {
FS filesys.Filesys
urlStr string
hdp HomeDirProvider
IgnoreLocks bool
}
// Load loads the DoltEnv for the current directory of the cli
@@ -1288,12 +1290,20 @@ func (dEnv *DoltEnv) LockFile() string {
// IsLocked returns true if this database's lockfile exists
func (dEnv *DoltEnv) IsLocked() bool {
if dEnv.IgnoreLocks {
return false
}
ok, _ := dEnv.FS.Exists(dEnv.LockFile())
return ok
}
// Lock writes this database's lockfile or errors if it already exists
func (dEnv *DoltEnv) Lock() error {
if dEnv.IgnoreLocks {
return nil
}
if dEnv.IsLocked() {
return ErrActiveServerLock.New(dEnv.LockFile())
}
@@ -1302,5 +1312,9 @@ func (dEnv *DoltEnv) Lock() error {
// Unlock deletes this database's lockfile
func (dEnv *DoltEnv) Unlock() error {
if dEnv.IgnoreLocks {
return nil
}
return dEnv.FS.DeleteFile(dEnv.LockFile())
}
+29 -11
View File
@@ -45,9 +45,10 @@ type EnvNameAndPath struct {
// MultiRepoEnv is a type used to store multiple environments which can be retrieved by name
type MultiRepoEnv struct {
envs []NamedEnv
fs filesys.Filesys
cfg config.ReadWriteConfig
envs []NamedEnv
fs filesys.Filesys
cfg config.ReadWriteConfig
ignoreLocks bool
}
type NamedEnv struct {
@@ -141,6 +142,10 @@ func (mrEnv *MultiRepoEnv) GetWorkingRoots(ctx context.Context) (map[string]*dol
// IsLocked returns true if any env is locked
func (mrEnv *MultiRepoEnv) IsLocked() (bool, string) {
if mrEnv.ignoreLocks {
return false, ""
}
for _, e := range mrEnv.envs {
if e.env.IsLocked() {
return true, e.env.LockFile()
@@ -152,6 +157,10 @@ func (mrEnv *MultiRepoEnv) IsLocked() (bool, string) {
// Lock locks all child envs. If an error is returned, all
// child envs will be returned with their initial lock state.
func (mrEnv *MultiRepoEnv) Lock() error {
if mrEnv.ignoreLocks {
return nil
}
if ok, f := mrEnv.IsLocked(); ok {
return ErrActiveServerLock.New(f)
}
@@ -169,6 +178,10 @@ func (mrEnv *MultiRepoEnv) Lock() error {
// Unlock unlocks all child envs.
func (mrEnv *MultiRepoEnv) Unlock() error {
if mrEnv.ignoreLocks {
return nil
}
var err, retErr error
for _, e := range mrEnv.envs {
err = e.env.Unlock()
@@ -221,7 +234,7 @@ func getRepoRootDir(path, pathSeparator string) string {
func DoltEnvAsMultiEnv(ctx context.Context, dEnv *DoltEnv) (*MultiRepoEnv, error) {
if !dEnv.Valid() {
cfg := dEnv.Config.WriteableConfig()
return MultiEnvForDirectory(ctx, cfg, dEnv.FS, dEnv.Version)
return MultiEnvForDirectory(ctx, cfg, dEnv.FS, dEnv.Version, dEnv.IgnoreLocks)
}
dbName := "dolt"
@@ -295,11 +308,13 @@ func MultiEnvForDirectory(
config config.ReadWriteConfig,
fs filesys.Filesys,
version string,
ignoreLocks bool,
) (*MultiRepoEnv, error) {
mrEnv := &MultiRepoEnv{
envs: make([]NamedEnv, 0),
fs: fs,
cfg: config,
envs: make([]NamedEnv, 0),
fs: fs,
cfg: config,
ignoreLocks: ignoreLocks,
}
// If there are other directories in the directory, try to load them as additional databases
@@ -334,6 +349,7 @@ func LoadMultiEnv(
cfg config.ReadWriteConfig,
fs filesys.Filesys,
version string,
ignoreLocks bool,
envNamesAndPaths ...EnvNameAndPath,
) (*MultiRepoEnv, error) {
nameToPath := make(map[string]string)
@@ -352,9 +368,10 @@ func LoadMultiEnv(
}
mrEnv := &MultiRepoEnv{
envs: make([]NamedEnv, 0),
fs: fs,
cfg: cfg,
envs: make([]NamedEnv, 0),
fs: fs,
cfg: cfg,
ignoreLocks: ignoreLocks,
}
for name, path := range nameToPath {
@@ -419,6 +436,7 @@ func LoadMultiEnvFromDir(
cfg config.ReadWriteConfig,
fs filesys.Filesys,
path, version string,
ignoreLocks bool,
) (*MultiRepoEnv, error) {
envNamesAndPaths, err := DBNamesAndPathsFromDir(fs, path)
@@ -431,7 +449,7 @@ func LoadMultiEnvFromDir(
return nil, errhand.VerboseErrorFromError(err)
}
return LoadMultiEnv(ctx, hdp, cfg, multiDbDirFs, version, envNamesAndPaths...)
return LoadMultiEnv(ctx, hdp, cfg, multiDbDirFs, version, ignoreLocks, envNamesAndPaths...)
}
func dirToDBName(dirName string) string {
+2 -2
View File
@@ -183,7 +183,7 @@ func TestLoadMultiEnv(t *testing.T) {
envNamesAndPaths[i] = EnvNameAndPath{name, filepath.Join(rootPath, name)}
}
mrEnv, err := LoadMultiEnv(context.Background(), hdp, config.NewEmptyMapConfig(), filesys.LocalFS, "test", envNamesAndPaths...)
mrEnv, err := LoadMultiEnv(context.Background(), hdp, config.NewEmptyMapConfig(), filesys.LocalFS, "test", false, envNamesAndPaths...)
require.NoError(t, err)
for _, name := range names {
@@ -205,7 +205,7 @@ func TestLoadMultiEnvFromDir(t *testing.T) {
}
rootPath, hdp, envs := initMultiEnv(t, "TestLoadMultiEnvFromDir", names)
mrEnv, err := LoadMultiEnvFromDir(context.Background(), hdp, config.NewEmptyMapConfig(), filesys.LocalFS, rootPath, "test")
mrEnv, err := LoadMultiEnvFromDir(context.Background(), hdp, config.NewEmptyMapConfig(), filesys.LocalFS, rootPath, "test", false)
require.NoError(t, err)
assert.Len(t, mrEnv.envs, len(names))
@@ -57,6 +57,8 @@ const (
ReplicateHeadsKey = "dolt_replicate_heads"
ReplicateAllHeadsKey = "dolt_replicate_all_heads"
AsyncReplicationKey = "dolt_async_replication"
AwsCredsFileKey = "aws_credentials_file"
AwsCredsProfileKey = "aws_credentials_profile"
// Transactions merges will stomp if either if the below keys are set
TransactionMergeStompKey = "dolt_transaction_merge_stomp"
TransactionMergeStompEnvKey = "DOLT_TRANSACTION_MERGE_STOMP"
@@ -137,6 +137,22 @@ func defineSystemVariables(name string) {
Type: sql.NewSystemStringType(DefaultBranchKey(name)),
Default: "",
},
{
Name: AwsCredsFileKey,
Scope: sql.SystemVariableScope_Session,
Dynamic: false,
SetVarHintApplies: false,
Type: sql.NewSystemStringType(AwsCredsFileKey),
Default: nil,
},
{
Name: AwsCredsProfileKey,
Scope: sql.SystemVariableScope_Session,
Dynamic: false,
SetVarHintApplies: false,
Type: sql.NewSystemStringType(AwsCredsProfileKey),
Default: nil,
},
})
}
}
+46 -1
View File
@@ -21,6 +21,51 @@ teardown() {
teardown_common
}
@test "sql-server: user session variables from config" {
cd repo1
echo "
privilege_file: privs.json
user_session_vars:
- name: user0
vars:
aws_credentials_file: /Users/user0/.aws/config
aws_credentials_profile: default
- name: user1
vars:
aws_credentials_file: /Users/user1/.aws/config
aws_credentials_profile: lddev" > server.yaml
dolt sql --privilege-file=privs.json -q "CREATE USER dolt@'127.0.0.1'"
dolt sql --privilege-file=privs.json -q "CREATE USER user0@'127.0.0.1' IDENTIFIED BY 'pass0'"
dolt sql --privilege-file=privs.json -q "CREATE USER user1@'127.0.0.1' IDENTIFIED BY 'pass1'"
dolt sql --privilege-file=privs.json -q "CREATE USER user2@'127.0.0.1' IDENTIFIED BY 'pass2'"
start_sql_server_with_config "" server.yaml
run mysql --host=127.0.0.1 --port=$PORT --user=user0 --password=pass0<<SQL
SELECT @@aws_credentials_file, @@aws_credentials_profile;
SQL
[[ "$output" =~ /Users/user0/.aws/config.*default ]] || false
run mysql --host=127.0.0.1 --port=$PORT --user=user1 --password=pass1<<SQL
SELECT @@aws_credentials_file, @@aws_credentials_profile;
SQL
[[ "$output" =~ /Users/user1/.aws/config.*lddev ]] || false
run mysql --host=127.0.0.1 --port=$PORT --user=user2 --password=pass2<<SQL
SELECT @@aws_credentials_file, @@aws_credentials_profile;
SQL
[[ "$output" =~ NULL.*NULL ]] || false
run mysql --host=127.0.0.1 --port=$PORT --user=user2 --password=pass2<<SQL
SET @@aws_credentials_file="/Users/should_fail";
SQL
[[ "$output" =~ "Variable 'aws_credentials_file' is a read only variable" ]] || false
}
@test "sql-server: port in use" {
cd repo1
@@ -1345,4 +1390,4 @@ databases:
let PORT="$$ % (65536-1024) + 1024"
run dolt sql-server -P $PORT
[ "$status" -eq 1 ]
}
}