More clean up of session setup

This commit is contained in:
Zach Musgrave
2021-11-04 09:51:30 -07:00
parent e487a84ea2
commit 4a236171e7
7 changed files with 53 additions and 34 deletions
+14 -21
View File
@@ -27,6 +27,7 @@ import (
"syscall"
"github.com/abiosoft/readline"
"github.com/dolthub/dolt/go/libraries/utils/config"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
sqle "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/auth"
@@ -258,17 +259,17 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE
}
if multiStatementMode {
verr := execMultiStatements(ctx, dEnv.Config, dEnv.FS, continueOnError, mrEnv, os.Stdin, format)
verr := execMultiStatements(ctx, continueOnError, mrEnv, os.Stdin, format)
if verr != nil {
return HandleVErrAndExitCode(verr, usage)
}
} else if runInBatchMode {
verr := execBatch(ctx, dEnv.Config, dEnv.FS, continueOnError, mrEnv, os.Stdin, format)
verr := execBatch(ctx, continueOnError, mrEnv, os.Stdin, format)
if verr != nil {
return HandleVErrAndExitCode(verr, usage)
}
} else {
verr := execShell(ctx, dEnv.Config, dEnv.FS, mrEnv, format)
verr := execShell(ctx, mrEnv, format)
if verr != nil {
return HandleVErrAndExitCode(verr, usage)
}
@@ -291,7 +292,7 @@ func listSavedQueriesMode(ctx context.Context, initialRoots map[string]*doltdb.R
}
query := "SELECT * FROM " + doltdb.DoltQueryCatalogTableName
return HandleVErrAndExitCode(execQuery(ctx, dEnv.Config, dEnv.FS, mrEnv, query, format), usage)
return HandleVErrAndExitCode(execQuery(ctx, mrEnv, query, format), usage)
}
func savedQueryMode(ctx context.Context, initialRoots map[string]*doltdb.RootValue, currentDb string, savedQueryName string, usage cli.UsagePrinter, dEnv *env.DoltEnv, mrEnv *env.MultiRepoEnv, format resultFormat) int {
@@ -302,7 +303,7 @@ func savedQueryMode(ctx context.Context, initialRoots map[string]*doltdb.RootVal
}
cli.PrintErrf("Executing saved query '%s':\n%s\n", savedQueryName, sq.Query)
return HandleVErrAndExitCode(execQuery(ctx, dEnv.Config, dEnv.FS, mrEnv, sq.Query, format), usage)
return HandleVErrAndExitCode(execQuery(ctx, mrEnv, sq.Query, format), usage)
}
func queryMode(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResults, query string, continueOnError bool, mrEnv *env.MultiRepoEnv, format resultFormat, usage cli.UsagePrinter, initialRoots map[string]*doltdb.RootValue, currentDb string) int {
@@ -311,18 +312,18 @@ func queryMode(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseRe
if multiStatementMode {
batchInput := strings.NewReader(query)
verr := execMultiStatements(ctx, dEnv.Config, dEnv.FS, continueOnError, mrEnv, batchInput, format)
verr := execMultiStatements(ctx, continueOnError, mrEnv, batchInput, format)
if verr != nil {
return HandleVErrAndExitCode(verr, usage)
}
} else if batchMode {
batchInput := strings.NewReader(query)
verr := execBatch(ctx, dEnv.Config, dEnv.FS, continueOnError, mrEnv, batchInput, format)
verr := execBatch(ctx, continueOnError, mrEnv, batchInput, format)
if verr != nil {
return HandleVErrAndExitCode(verr, usage)
}
} else {
verr := execQuery(ctx, dEnv.Config, dEnv.FS, mrEnv, query, format)
verr := execQuery(ctx, mrEnv, query, format)
if verr != nil {
return HandleVErrAndExitCode(verr, usage)
}
@@ -371,8 +372,6 @@ func getMultiRepoEnv(ctx context.Context, apr *argparser.ArgParseResults, dEnv *
func execShell(
ctx context.Context,
config *env.DoltCliConfig,
fs filesys.Filesys,
mrEnv *env.MultiRepoEnv,
format resultFormat,
) errhand.VerboseError {
@@ -380,7 +379,7 @@ func execShell(
if err != nil {
return errhand.VerboseErrorFromError(err)
}
se, err := newSqlEngine(ctx, config, fs, format, dbs...)
se, err := newSqlEngine(ctx, mrEnv.Config(), mrEnv.FileSystem(), format, dbs...)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
@@ -394,8 +393,6 @@ func execShell(
func execBatch(
ctx context.Context,
config *env.DoltCliConfig,
fs filesys.Filesys,
continueOnErr bool,
mrEnv *env.MultiRepoEnv,
batchInput io.Reader,
@@ -406,7 +403,7 @@ func execBatch(
return errhand.VerboseErrorFromError(err)
}
se, err := newSqlEngine(ctx, config, fs, format, dbs...)
se, err := newSqlEngine(ctx, mrEnv.Config(), mrEnv.FileSystem(), format, dbs...)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
@@ -439,8 +436,6 @@ func execBatch(
func execMultiStatements(
ctx context.Context,
config *env.DoltCliConfig,
fs filesys.Filesys,
continueOnErr bool,
mrEnv *env.MultiRepoEnv,
batchInput io.Reader,
@@ -450,7 +445,7 @@ func execMultiStatements(
if err != nil {
return errhand.VerboseErrorFromError(err)
}
se, err := newSqlEngine(ctx, config, fs, format, dbs...)
se, err := newSqlEngine(ctx, mrEnv.Config(), mrEnv.FileSystem(), format, dbs...)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
@@ -478,8 +473,6 @@ func newDatabase(name string, dEnv *env.DoltEnv) dsqle.Database {
func execQuery(
ctx context.Context,
config *env.DoltCliConfig,
fs filesys.Filesys,
mrEnv *env.MultiRepoEnv,
query string,
format resultFormat,
@@ -488,7 +481,7 @@ func execQuery(
if err != nil {
return errhand.VerboseErrorFromError(err)
}
se, err := newSqlEngine(ctx, config, fs, format, dbs...)
se, err := newSqlEngine(ctx, mrEnv.Config(), mrEnv.FileSystem(), format, dbs...)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
@@ -1439,7 +1432,7 @@ var ErrDBNotFoundKind = errors.NewKind("database '%s' not found")
// sqlEngine packages up the context necessary to run sql queries against sqle.
func newSqlEngine(
ctx context.Context,
config *env.DoltCliConfig,
config config.ReadWriteConfig,
fs filesys.Filesys,
format resultFormat,
dbs ...dsqle.SqlDatabase,
+8 -1
View File
@@ -21,6 +21,7 @@ import (
"strconv"
"time"
"github.com/dolthub/dolt/go/libraries/utils/config"
sqle "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/auth"
"github.com/dolthub/go-mysql-server/server"
@@ -199,7 +200,13 @@ func newSessionBuilder(sqlEngine *sqle.Engine, dConf *env.DoltCliConfig, pro dsq
return nil, err
}
doltSess, err := dsess.NewDoltSession(tmpSqlCtx, mysqlSess, pro, dConf, dbStates...)
localConfig, ok := dConf.GetConfig(env.LocalConfig)
if !ok {
logrus.Warn("No local config available, config persistence disabled")
localConfig = config.NewEmptyMapConfig()
}
doltSess, err := dsess.NewDoltSession(tmpSqlCtx, mysqlSess, pro, localConfig, dbStates...)
if err != nil {
return nil, err
}
+9
View File
@@ -51,6 +51,15 @@ type NamedEnv struct {
env *DoltEnv
}
func (mrEnv *MultiRepoEnv) FileSystem() filesys.Filesys {
return mrEnv.fs
}
func (mrEnv *MultiRepoEnv) Config() config.ReadWriteConfig {
return mrEnv.cfg
}
// TODO: un export
// AddEnv adds an environment to the MultiRepoEnv by name
func (mrEnv *MultiRepoEnv) AddEnv(name string, dEnv *DoltEnv) {
mrEnv.envs = append(mrEnv.envs, NamedEnv{
@@ -36,20 +36,13 @@ var _ sql.Session = (*DoltSession)(nil)
var _ sql.PersistableSession = (*DoltSession)(nil)
// NewDoltSession creates a DoltSession object from a standard sql.Session and 0 or more Database objects.
func NewDoltSession(ctx *sql.Context, sqlSess *sql.BaseSession, pro RevisionDatabaseProvider, conf *env.DoltCliConfig, dbs ...InitialDbState) (*DoltSession, error) {
func NewDoltSession(ctx *sql.Context, sqlSess *sql.BaseSession, pro RevisionDatabaseProvider, conf config.ReadWriteConfig, dbs ...InitialDbState) (*DoltSession, error) {
sess, err := NewSession(ctx, sqlSess, pro, conf, dbs...)
if err != nil {
return nil, err
}
var globals config.ReadWriteConfig
if localConf, ok := conf.GetConfig(env.LocalConfig); !ok {
ctx.Warn(NonpersistableSessionCode, "configured mode does not support persistable sessions; SET PERSIST will not write to file")
globals = config.NewMapConfig(make(map[string]string))
} else {
globals = config.NewPrefixConfig(localConf, env.SqlServerGlobalsPrefix)
}
globals := config.NewPrefixConfig(conf, env.SqlServerGlobalsPrefix)
return sess.NewDoltSession(globals), nil
}
@@ -20,6 +20,7 @@ import (
"strings"
"testing"
"github.com/dolthub/dolt/go/libraries/utils/config"
"github.com/dolthub/go-mysql-server/enginetest"
"github.com/dolthub/go-mysql-server/sql"
"github.com/stretchr/testify/require"
@@ -58,7 +59,13 @@ var _ enginetest.ReadOnlyDatabaseHarness = (*DoltHarness)(nil)
func newDoltHarness(t *testing.T) *DoltHarness {
dEnv := dtestutils.CreateTestEnv()
pro := sqle.NewDoltDatabaseProvider(dEnv.Config, dEnv.FS)
session, err := dsess.NewDoltSession(sql.NewEmptyContext(), enginetest.NewBaseSession(), pro, dEnv.Config)
localConfig, ok := dEnv.Config.GetConfig(env.LocalConfig)
if !ok {
localConfig = config.NewEmptyMapConfig()
}
session, err := dsess.NewDoltSession(sql.NewEmptyContext(), enginetest.NewBaseSession(), pro, localConfig)
require.NoError(t, err)
return &DoltHarness{
t: t,
@@ -132,12 +139,17 @@ func (d *DoltHarness) NewSession() *sql.Context {
dbs := dsqleDBsAsSqlDBs(d.databases)
pro := d.NewDatabaseProvider(dbs...)
localConfig, ok := d.env.Config.GetConfig(env.LocalConfig)
if !ok {
localConfig = config.NewEmptyMapConfig()
}
var err error
d.session, err = dsess.NewDoltSession(
enginetest.NewContext(d),
enginetest.NewBaseSession(),
pro.(dsess.RevisionDatabaseProvider),
d.env.Config,
localConfig,
states...,
)
require.NoError(d.t, err)
@@ -20,6 +20,7 @@ import (
"strings"
"testing"
"github.com/dolthub/dolt/go/libraries/utils/config"
sqle "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/sql"
"github.com/stretchr/testify/require"
@@ -104,7 +105,7 @@ func setupIndexes(t *testing.T, tableName, insertQuery string) (*sqle.Engine, *e
// Get an updated root to use for the rest of the test
ctx := sql.NewEmptyContext()
sess, err := dsess.NewDoltSession(ctx, ctx.Session.(*sql.BaseSession), pro, dEnv.Config, getDbState(t, db, dEnv))
sess, err := dsess.NewDoltSession(ctx, ctx.Session.(*sql.BaseSession), pro, config.NewEmptyMapConfig(), getDbState(t, db, dEnv))
require.NoError(t, err)
roots, ok := sess.GetRoots(ctx, tiDb.Name())
require.True(t, ok)
+4
View File
@@ -30,6 +30,10 @@ func NewMapConfig(properties map[string]string) *MapConfig {
return &MapConfig{properties}
}
func NewEmptyMapConfig() *MapConfig {
return &MapConfig{make(map[string]string)}
}
// GetString retrieves a value for a given key.
func (mc *MapConfig) GetString(k string) (string, error) {
if val, ok := mc.properties[k]; ok {