mirror of
https://github.com/dolthub/dolt.git
synced 2026-04-23 05:13:00 -05:00
More clean up of session setup
This commit is contained in:
+14
-21
@@ -27,6 +27,7 @@ import (
|
||||
"syscall"
|
||||
|
||||
"github.com/abiosoft/readline"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/config"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/filesys"
|
||||
sqle "github.com/dolthub/go-mysql-server"
|
||||
"github.com/dolthub/go-mysql-server/auth"
|
||||
@@ -258,17 +259,17 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE
|
||||
}
|
||||
|
||||
if multiStatementMode {
|
||||
verr := execMultiStatements(ctx, dEnv.Config, dEnv.FS, continueOnError, mrEnv, os.Stdin, format)
|
||||
verr := execMultiStatements(ctx, continueOnError, mrEnv, os.Stdin, format)
|
||||
if verr != nil {
|
||||
return HandleVErrAndExitCode(verr, usage)
|
||||
}
|
||||
} else if runInBatchMode {
|
||||
verr := execBatch(ctx, dEnv.Config, dEnv.FS, continueOnError, mrEnv, os.Stdin, format)
|
||||
verr := execBatch(ctx, continueOnError, mrEnv, os.Stdin, format)
|
||||
if verr != nil {
|
||||
return HandleVErrAndExitCode(verr, usage)
|
||||
}
|
||||
} else {
|
||||
verr := execShell(ctx, dEnv.Config, dEnv.FS, mrEnv, format)
|
||||
verr := execShell(ctx, mrEnv, format)
|
||||
if verr != nil {
|
||||
return HandleVErrAndExitCode(verr, usage)
|
||||
}
|
||||
@@ -291,7 +292,7 @@ func listSavedQueriesMode(ctx context.Context, initialRoots map[string]*doltdb.R
|
||||
}
|
||||
|
||||
query := "SELECT * FROM " + doltdb.DoltQueryCatalogTableName
|
||||
return HandleVErrAndExitCode(execQuery(ctx, dEnv.Config, dEnv.FS, mrEnv, query, format), usage)
|
||||
return HandleVErrAndExitCode(execQuery(ctx, mrEnv, query, format), usage)
|
||||
}
|
||||
|
||||
func savedQueryMode(ctx context.Context, initialRoots map[string]*doltdb.RootValue, currentDb string, savedQueryName string, usage cli.UsagePrinter, dEnv *env.DoltEnv, mrEnv *env.MultiRepoEnv, format resultFormat) int {
|
||||
@@ -302,7 +303,7 @@ func savedQueryMode(ctx context.Context, initialRoots map[string]*doltdb.RootVal
|
||||
}
|
||||
|
||||
cli.PrintErrf("Executing saved query '%s':\n%s\n", savedQueryName, sq.Query)
|
||||
return HandleVErrAndExitCode(execQuery(ctx, dEnv.Config, dEnv.FS, mrEnv, sq.Query, format), usage)
|
||||
return HandleVErrAndExitCode(execQuery(ctx, mrEnv, sq.Query, format), usage)
|
||||
}
|
||||
|
||||
func queryMode(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResults, query string, continueOnError bool, mrEnv *env.MultiRepoEnv, format resultFormat, usage cli.UsagePrinter, initialRoots map[string]*doltdb.RootValue, currentDb string) int {
|
||||
@@ -311,18 +312,18 @@ func queryMode(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseRe
|
||||
|
||||
if multiStatementMode {
|
||||
batchInput := strings.NewReader(query)
|
||||
verr := execMultiStatements(ctx, dEnv.Config, dEnv.FS, continueOnError, mrEnv, batchInput, format)
|
||||
verr := execMultiStatements(ctx, continueOnError, mrEnv, batchInput, format)
|
||||
if verr != nil {
|
||||
return HandleVErrAndExitCode(verr, usage)
|
||||
}
|
||||
} else if batchMode {
|
||||
batchInput := strings.NewReader(query)
|
||||
verr := execBatch(ctx, dEnv.Config, dEnv.FS, continueOnError, mrEnv, batchInput, format)
|
||||
verr := execBatch(ctx, continueOnError, mrEnv, batchInput, format)
|
||||
if verr != nil {
|
||||
return HandleVErrAndExitCode(verr, usage)
|
||||
}
|
||||
} else {
|
||||
verr := execQuery(ctx, dEnv.Config, dEnv.FS, mrEnv, query, format)
|
||||
verr := execQuery(ctx, mrEnv, query, format)
|
||||
if verr != nil {
|
||||
return HandleVErrAndExitCode(verr, usage)
|
||||
}
|
||||
@@ -371,8 +372,6 @@ func getMultiRepoEnv(ctx context.Context, apr *argparser.ArgParseResults, dEnv *
|
||||
|
||||
func execShell(
|
||||
ctx context.Context,
|
||||
config *env.DoltCliConfig,
|
||||
fs filesys.Filesys,
|
||||
mrEnv *env.MultiRepoEnv,
|
||||
format resultFormat,
|
||||
) errhand.VerboseError {
|
||||
@@ -380,7 +379,7 @@ func execShell(
|
||||
if err != nil {
|
||||
return errhand.VerboseErrorFromError(err)
|
||||
}
|
||||
se, err := newSqlEngine(ctx, config, fs, format, dbs...)
|
||||
se, err := newSqlEngine(ctx, mrEnv.Config(), mrEnv.FileSystem(), format, dbs...)
|
||||
if err != nil {
|
||||
return errhand.VerboseErrorFromError(err)
|
||||
}
|
||||
@@ -394,8 +393,6 @@ func execShell(
|
||||
|
||||
func execBatch(
|
||||
ctx context.Context,
|
||||
config *env.DoltCliConfig,
|
||||
fs filesys.Filesys,
|
||||
continueOnErr bool,
|
||||
mrEnv *env.MultiRepoEnv,
|
||||
batchInput io.Reader,
|
||||
@@ -406,7 +403,7 @@ func execBatch(
|
||||
return errhand.VerboseErrorFromError(err)
|
||||
}
|
||||
|
||||
se, err := newSqlEngine(ctx, config, fs, format, dbs...)
|
||||
se, err := newSqlEngine(ctx, mrEnv.Config(), mrEnv.FileSystem(), format, dbs...)
|
||||
if err != nil {
|
||||
return errhand.VerboseErrorFromError(err)
|
||||
}
|
||||
@@ -439,8 +436,6 @@ func execBatch(
|
||||
|
||||
func execMultiStatements(
|
||||
ctx context.Context,
|
||||
config *env.DoltCliConfig,
|
||||
fs filesys.Filesys,
|
||||
continueOnErr bool,
|
||||
mrEnv *env.MultiRepoEnv,
|
||||
batchInput io.Reader,
|
||||
@@ -450,7 +445,7 @@ func execMultiStatements(
|
||||
if err != nil {
|
||||
return errhand.VerboseErrorFromError(err)
|
||||
}
|
||||
se, err := newSqlEngine(ctx, config, fs, format, dbs...)
|
||||
se, err := newSqlEngine(ctx, mrEnv.Config(), mrEnv.FileSystem(), format, dbs...)
|
||||
if err != nil {
|
||||
return errhand.VerboseErrorFromError(err)
|
||||
}
|
||||
@@ -478,8 +473,6 @@ func newDatabase(name string, dEnv *env.DoltEnv) dsqle.Database {
|
||||
|
||||
func execQuery(
|
||||
ctx context.Context,
|
||||
config *env.DoltCliConfig,
|
||||
fs filesys.Filesys,
|
||||
mrEnv *env.MultiRepoEnv,
|
||||
query string,
|
||||
format resultFormat,
|
||||
@@ -488,7 +481,7 @@ func execQuery(
|
||||
if err != nil {
|
||||
return errhand.VerboseErrorFromError(err)
|
||||
}
|
||||
se, err := newSqlEngine(ctx, config, fs, format, dbs...)
|
||||
se, err := newSqlEngine(ctx, mrEnv.Config(), mrEnv.FileSystem(), format, dbs...)
|
||||
if err != nil {
|
||||
return errhand.VerboseErrorFromError(err)
|
||||
}
|
||||
@@ -1439,7 +1432,7 @@ var ErrDBNotFoundKind = errors.NewKind("database '%s' not found")
|
||||
// sqlEngine packages up the context necessary to run sql queries against sqle.
|
||||
func newSqlEngine(
|
||||
ctx context.Context,
|
||||
config *env.DoltCliConfig,
|
||||
config config.ReadWriteConfig,
|
||||
fs filesys.Filesys,
|
||||
format resultFormat,
|
||||
dbs ...dsqle.SqlDatabase,
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/utils/config"
|
||||
sqle "github.com/dolthub/go-mysql-server"
|
||||
"github.com/dolthub/go-mysql-server/auth"
|
||||
"github.com/dolthub/go-mysql-server/server"
|
||||
@@ -199,7 +200,13 @@ func newSessionBuilder(sqlEngine *sqle.Engine, dConf *env.DoltCliConfig, pro dsq
|
||||
return nil, err
|
||||
}
|
||||
|
||||
doltSess, err := dsess.NewDoltSession(tmpSqlCtx, mysqlSess, pro, dConf, dbStates...)
|
||||
localConfig, ok := dConf.GetConfig(env.LocalConfig)
|
||||
if !ok {
|
||||
logrus.Warn("No local config available, config persistence disabled")
|
||||
localConfig = config.NewEmptyMapConfig()
|
||||
}
|
||||
|
||||
doltSess, err := dsess.NewDoltSession(tmpSqlCtx, mysqlSess, pro, localConfig, dbStates...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
+9
@@ -51,6 +51,15 @@ type NamedEnv struct {
|
||||
env *DoltEnv
|
||||
}
|
||||
|
||||
func (mrEnv *MultiRepoEnv) FileSystem() filesys.Filesys {
|
||||
return mrEnv.fs
|
||||
}
|
||||
|
||||
func (mrEnv *MultiRepoEnv) Config() config.ReadWriteConfig {
|
||||
return mrEnv.cfg
|
||||
}
|
||||
|
||||
// TODO: un export
|
||||
// AddEnv adds an environment to the MultiRepoEnv by name
|
||||
func (mrEnv *MultiRepoEnv) AddEnv(name string, dEnv *DoltEnv) {
|
||||
mrEnv.envs = append(mrEnv.envs, NamedEnv{
|
||||
|
||||
@@ -36,20 +36,13 @@ var _ sql.Session = (*DoltSession)(nil)
|
||||
var _ sql.PersistableSession = (*DoltSession)(nil)
|
||||
|
||||
// NewDoltSession creates a DoltSession object from a standard sql.Session and 0 or more Database objects.
|
||||
func NewDoltSession(ctx *sql.Context, sqlSess *sql.BaseSession, pro RevisionDatabaseProvider, conf *env.DoltCliConfig, dbs ...InitialDbState) (*DoltSession, error) {
|
||||
func NewDoltSession(ctx *sql.Context, sqlSess *sql.BaseSession, pro RevisionDatabaseProvider, conf config.ReadWriteConfig, dbs ...InitialDbState) (*DoltSession, error) {
|
||||
sess, err := NewSession(ctx, sqlSess, pro, conf, dbs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var globals config.ReadWriteConfig
|
||||
if localConf, ok := conf.GetConfig(env.LocalConfig); !ok {
|
||||
ctx.Warn(NonpersistableSessionCode, "configured mode does not support persistable sessions; SET PERSIST will not write to file")
|
||||
globals = config.NewMapConfig(make(map[string]string))
|
||||
} else {
|
||||
globals = config.NewPrefixConfig(localConf, env.SqlServerGlobalsPrefix)
|
||||
}
|
||||
|
||||
globals := config.NewPrefixConfig(conf, env.SqlServerGlobalsPrefix)
|
||||
return sess.NewDoltSession(globals), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/utils/config"
|
||||
"github.com/dolthub/go-mysql-server/enginetest"
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -58,7 +59,13 @@ var _ enginetest.ReadOnlyDatabaseHarness = (*DoltHarness)(nil)
|
||||
func newDoltHarness(t *testing.T) *DoltHarness {
|
||||
dEnv := dtestutils.CreateTestEnv()
|
||||
pro := sqle.NewDoltDatabaseProvider(dEnv.Config, dEnv.FS)
|
||||
session, err := dsess.NewDoltSession(sql.NewEmptyContext(), enginetest.NewBaseSession(), pro, dEnv.Config)
|
||||
|
||||
localConfig, ok := dEnv.Config.GetConfig(env.LocalConfig)
|
||||
if !ok {
|
||||
localConfig = config.NewEmptyMapConfig()
|
||||
}
|
||||
|
||||
session, err := dsess.NewDoltSession(sql.NewEmptyContext(), enginetest.NewBaseSession(), pro, localConfig)
|
||||
require.NoError(t, err)
|
||||
return &DoltHarness{
|
||||
t: t,
|
||||
@@ -132,12 +139,17 @@ func (d *DoltHarness) NewSession() *sql.Context {
|
||||
dbs := dsqleDBsAsSqlDBs(d.databases)
|
||||
pro := d.NewDatabaseProvider(dbs...)
|
||||
|
||||
localConfig, ok := d.env.Config.GetConfig(env.LocalConfig)
|
||||
if !ok {
|
||||
localConfig = config.NewEmptyMapConfig()
|
||||
}
|
||||
|
||||
var err error
|
||||
d.session, err = dsess.NewDoltSession(
|
||||
enginetest.NewContext(d),
|
||||
enginetest.NewBaseSession(),
|
||||
pro.(dsess.RevisionDatabaseProvider),
|
||||
d.env.Config,
|
||||
localConfig,
|
||||
states...,
|
||||
)
|
||||
require.NoError(d.t, err)
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/utils/config"
|
||||
sqle "github.com/dolthub/go-mysql-server"
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -104,7 +105,7 @@ func setupIndexes(t *testing.T, tableName, insertQuery string) (*sqle.Engine, *e
|
||||
|
||||
// Get an updated root to use for the rest of the test
|
||||
ctx := sql.NewEmptyContext()
|
||||
sess, err := dsess.NewDoltSession(ctx, ctx.Session.(*sql.BaseSession), pro, dEnv.Config, getDbState(t, db, dEnv))
|
||||
sess, err := dsess.NewDoltSession(ctx, ctx.Session.(*sql.BaseSession), pro, config.NewEmptyMapConfig(), getDbState(t, db, dEnv))
|
||||
require.NoError(t, err)
|
||||
roots, ok := sess.GetRoots(ctx, tiDb.Name())
|
||||
require.True(t, ok)
|
||||
|
||||
@@ -30,6 +30,10 @@ func NewMapConfig(properties map[string]string) *MapConfig {
|
||||
return &MapConfig{properties}
|
||||
}
|
||||
|
||||
func NewEmptyMapConfig() *MapConfig {
|
||||
return &MapConfig{make(map[string]string)}
|
||||
}
|
||||
|
||||
// GetString retrieves a value for a given key.
|
||||
func (mc *MapConfig) GetString(k string) (string, error) {
|
||||
if val, ok := mc.properties[k]; ok {
|
||||
|
||||
Reference in New Issue
Block a user