From d554276ab606db1fbbed1a22140e2e6509679d46 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Tue, 27 Jul 2021 17:25:33 -0700 Subject: [PATCH] Enabled named db revisions for SQL command (previously only in server) --- go/cmd/dolt/commands/sql.go | 156 ++++++++++++------ go/cmd/dolt/commands/sqlserver/server.go | 2 +- .../doltcore/sqle/database_provider.go | 4 +- 3 files changed, 108 insertions(+), 54 deletions(-) diff --git a/go/cmd/dolt/commands/sql.go b/go/cmd/dolt/commands/sql.go index 9313694b53..7ab3d55081 100644 --- a/go/cmd/dolt/commands/sql.go +++ b/go/cmd/dolt/commands/sql.go @@ -190,11 +190,6 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE } } - sess := dsess.DefaultSession() - // TODO: not having user and email for this command should probably be an error or warning, it disables certain functionality - sess.Username = *dEnv.Config.GetStringOrDefault(env.UserNameKey, "") - sess.Email = *dEnv.Config.GetStringOrDefault(env.UserEmailKey, "") - var mrEnv env.MultiRepoEnv var initialRoots map[string]*doltdb.RootValue var readOnly = false @@ -257,16 +252,6 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE } } - sqlCtx := sql.NewContext(ctx, - sql.WithSession(sess), - sql.WithIndexRegistry(sql.NewIndexRegistry()), - sql.WithViewRegistry(sql.NewViewRegistry()), - sql.WithTracer(tracing.Tracer(ctx))) - err = sqlCtx.SetSessionVariable(sqlCtx, sql.AutoCommitSessionVar, true) - if err != nil { - return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) - } - roots := make(map[string]*doltdb.RootValue) var name string @@ -277,7 +262,6 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE var currentDB string if len(initialRoots) == 1 { - sqlCtx.SetCurrentDatabase(name) currentDB = name } @@ -288,12 +272,12 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE if multiStatementMode { batchInput := strings.NewReader(query) - verr = execMultiStatements(sqlCtx, readOnly, continueOnError, mrEnv, roots, batchInput, format) + verr = execMultiStatements(ctx, dEnv, continueOnError, mrEnv, roots, readOnly, batchInput, format) } else if batchMode { batchInput := strings.NewReader(query) - verr = execBatch(sqlCtx, readOnly, continueOnError, mrEnv, roots, batchInput, format) + verr = execBatch(ctx, dEnv, continueOnError, mrEnv, roots, readOnly, batchInput, format) } else { - verr = execQuery(sqlCtx, readOnly, mrEnv, roots, query, format) + verr = execQuery(ctx, dEnv, mrEnv, roots, readOnly, query, format) if verr != nil { return HandleVErrAndExitCode(verr, usage) @@ -315,7 +299,7 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE } cli.PrintErrf("Executing saved query '%s':\n%s\n", savedQueryName, sq.Query) - verr = execQuery(sqlCtx, readOnly, mrEnv, roots, sq.Query, format) + verr = execQuery(ctx, dEnv, mrEnv, roots, readOnly, sq.Query, format) } else if apr.Contains(listSavedFlag) { hasQC, err := roots[currentDB].HasTable(ctx, doltdb.DoltQueryCatalogTableName) @@ -329,7 +313,7 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE } query := "SELECT * FROM " + doltdb.DoltQueryCatalogTableName - verr = execQuery(sqlCtx, readOnly, mrEnv, roots, query, format) + verr = execQuery(ctx, dEnv, mrEnv, roots, readOnly, query, format) } else { // Run in either batch mode for piped input, or shell mode for interactive runInBatchMode := true @@ -345,11 +329,11 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE } if multiStatementMode { - verr = execMultiStatements(sqlCtx, readOnly, continueOnError, mrEnv, roots, os.Stdin, format) + verr = execMultiStatements(ctx, dEnv, continueOnError, mrEnv, roots, readOnly, os.Stdin, format) } else if runInBatchMode { - verr = execBatch(sqlCtx, readOnly, continueOnError, mrEnv, roots, os.Stdin, format) + verr = execBatch(ctx, dEnv, continueOnError, mrEnv, roots, readOnly, os.Stdin, format) } else { - verr = execShell(sqlCtx, readOnly, mrEnv, roots, format) + verr = execShell(ctx, dEnv, mrEnv, roots, readOnly, format) } } @@ -375,9 +359,16 @@ func parseCommitSpec(dEnv *env.DoltEnv, apr *argparser.ArgParseResults) (*doltdb return cs, nil } -func execShell(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, format resultFormat) errhand.VerboseError { +func execShell( + ctx context.Context, + dEnv *env.DoltEnv, + mrEnv env.MultiRepoEnv, + roots map[string]*doltdb.RootValue, + readOnly bool, + format resultFormat, +) errhand.VerboseError { dbs := CollectDBs(mrEnv) - se, err := newSqlEngine(sqlCtx, readOnly, mrEnv, roots, format, dbs...) + se, sqlCtx, err := newSqlEngine(ctx, dEnv, mrEnv, roots, readOnly, format, dbs...) if err != nil { return errhand.VerboseErrorFromError(err) } @@ -389,9 +380,18 @@ func execShell(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots return nil } -func execBatch(sqlCtx *sql.Context, readOnly bool, continueOnErr bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, batchInput io.Reader, format resultFormat) errhand.VerboseError { +func execBatch( + ctx context.Context, + dEnv *env.DoltEnv, + continueOnErr bool, + mrEnv env.MultiRepoEnv, + roots map[string]*doltdb.RootValue, + readOnly bool, + batchInput io.Reader, + format resultFormat, +) errhand.VerboseError { dbs := CollectDBs(mrEnv) - se, err := newSqlEngine(sqlCtx, readOnly, mrEnv, roots, format, dbs...) + se, sqlCtx, err := newSqlEngine(ctx, dEnv, mrEnv, roots, readOnly, format, dbs...) if err != nil { return errhand.VerboseErrorFromError(err) } @@ -418,16 +418,17 @@ func execBatch(sqlCtx *sql.Context, readOnly bool, continueOnErr bool, mrEnv env } func execMultiStatements( - sqlCtx *sql.Context, - readOnly bool, + ctx context.Context, + dEnv *env.DoltEnv, continueOnErr bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, + readOnly bool, batchInput io.Reader, format resultFormat, ) errhand.VerboseError { dbs := CollectDBs(mrEnv) - se, err := newSqlEngine(sqlCtx, readOnly, mrEnv, roots, format, dbs...) + se, sqlCtx, err := newSqlEngine(ctx, dEnv, mrEnv, roots, readOnly, format, dbs...) if err != nil { return errhand.VerboseErrorFromError(err) } @@ -446,15 +447,16 @@ func newDatabase(name string, dEnv *env.DoltEnv) dsqle.Database { } func execQuery( - sqlCtx *sql.Context, - readOnly bool, + ctx context.Context, + dEnv *env.DoltEnv, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, + readOnly bool, query string, format resultFormat, ) errhand.VerboseError { dbs := CollectDBs(mrEnv) - se, err := newSqlEngine(sqlCtx, readOnly, mrEnv, roots, format, dbs...) + se, sqlCtx, err := newSqlEngine(ctx, dEnv, mrEnv, roots, readOnly, format, dbs...) if err != nil { return errhand.VerboseErrorFromError(err) } @@ -1353,8 +1355,15 @@ type sqlEngine struct { var ErrDBNotFoundKind = errors.NewKind("database '%s' not found") // sqlEngine packages up the context necessary to run sql queries against sqle. -func newSqlEngine(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, roots map[string]*doltdb.RootValue, format resultFormat, dbs ...dsqle.Database) (*sqlEngine, error) { - c := sql.NewCatalog() +func newSqlEngine( + ctx context.Context, + dEnv *env.DoltEnv, + mrEnv env.MultiRepoEnv, + roots map[string]*doltdb.RootValue, // See TODO below + readOnly bool, + format resultFormat, + dbs ...dsqle.Database, +) (*sqlEngine, *sql.Context, error) { var au auth.Auth if readOnly { @@ -1363,14 +1372,17 @@ func newSqlEngine(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, ro au = new(auth.None) } - err := c.Register(dfunctions.DoltFunctions...) + parallelism := runtime.GOMAXPROCS(0) + pro := dsqle.NewDoltDatabaseProvider(dbs...) + cat := sql.NewCatalogWithDbProvider(pro) + err := cat.Register(dfunctions.DoltFunctions...) if err != nil { - return nil, err + return nil, nil, err } - parallelism := runtime.GOMAXPROCS(0) - engine := sqle.New(c, analyzer.NewBuilder(c).WithParallelism(parallelism).Build(), &sqle.Config{Auth: au}) + cat.AddDatabase(information_schema.NewInformationSchemaDatabase(cat)) + engine := sqle.New(cat, analyzer.NewBuilder(cat).WithParallelism(parallelism).Build(), &sqle.Config{Auth: au}) engine.AddDatabase(information_schema.NewInformationSchemaDatabase(engine.Catalog)) if dbg, ok := os.LookupEnv("DOLT_SQL_DEBUG_LOG"); ok && strings.ToLower(dbg) == "true" { @@ -1380,35 +1392,77 @@ func newSqlEngine(sqlCtx *sql.Context, readOnly bool, mrEnv env.MultiRepoEnv, ro } } - sess := dsess.DSessFromSess(sqlCtx.Session) - nameToDB := make(map[string]dsqle.Database) + var dbStates []dsess.InitialDbState for _, db := range dbs { nameToDB[db.Name()] = db - root := roots[db.Name()] engine.AddDatabase(db) - // TODO: this doesn't consider the root above, which may not be the HEAD of the branch + // TODO: this doesn't consider the roots provided as a param, which may not be the HEAD of the branch // To fix this, we need to pass a commit here as a separate param, and install a read-only database on it // since it isn't a current HEAD. - dbState, err := getDbState(sqlCtx, db, mrEnv) + dbState, err := getDbState(ctx, db, mrEnv) if err != nil { - return nil, err + return nil, nil, err } - err = sess.AddDB(sqlCtx, dbState) + dbStates = append(dbStates, dbState) + } + + // TODO: not having user and email for this command should probably be an error or warning, it disables certain functionality + username := *dEnv.Config.GetStringOrDefault(env.UserNameKey, "") + email := *dEnv.Config.GetStringOrDefault(env.UserEmailKey, "") + sess, err := dsess.NewSession(sql.NewEmptyContext(), sql.NewBaseSession(), pro, username, email, dbStates...) + + sqlCtx := sql.NewContext(ctx, + sql.WithSession(sess), + sql.WithIndexRegistry(sql.NewIndexRegistry()), + sql.WithViewRegistry(sql.NewViewRegistry()), + sql.WithTracer(tracing.Tracer(ctx))) + + for _, db := range dbsAsDSQLDBs(cat.AllDatabases()) { + root, err := db.GetRoot(sqlCtx) if err != nil { - return nil, err + return nil, nil, err } err = dsqle.RegisterSchemaFragments(sqlCtx, db, root) - if err != nil { - return nil, err + return nil, nil, err } } - return &sqlEngine{nameToDB, mrEnv, engine, format}, nil + err = sqlCtx.SetSessionVariable(sqlCtx, sql.AutoCommitSessionVar, true) + if err != nil { + return nil, nil, err + } + + initialRoots, err := mrEnv.GetWorkingRoots(ctx) + if err != nil { + return nil, nil, err + } + + if len(initialRoots) == 1 { + for name := range initialRoots { + sqlCtx.SetCurrentDatabase(name) + } + } + + return &sqlEngine{nameToDB, mrEnv, engine, format}, sqlCtx, nil +} + +func dbsAsDSQLDBs(dbs []sql.Database) []dsqle.Database { + dsqlDBs := make([]dsqle.Database, 0, len(dbs)) + + for _, db := range dbs { + dsqlDB, ok := db.(dsqle.Database) + + if ok { + dsqlDBs = append(dsqlDBs, dsqlDB) + } + } + + return dsqlDBs } func getDbState(ctx context.Context, db dsqle.Database, mrEnv env.MultiRepoEnv) (dsess.InitialDbState, error) { diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index db3103d451..dfef14bd04 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -204,7 +204,7 @@ func newSessionBuilder(sqlEngine *sqle.Engine, username string, email string, pr dbs := dbsAsDSQLDBs(sqlEngine.Catalog.AllDatabases()) for _, db := range dbs { root, err := db.GetRoot(sqlCtx) - if err != err { + if err != nil { cli.PrintErrln(err) return nil, nil, nil, err } diff --git a/go/libraries/doltcore/sqle/database_provider.go b/go/libraries/doltcore/sqle/database_provider.go index 5d96a26003..9c978068f8 100644 --- a/go/libraries/doltcore/sqle/database_provider.go +++ b/go/libraries/doltcore/sqle/database_provider.go @@ -34,7 +34,7 @@ const ( enableDbRevisionsEnvKey = "DOLT_ENABLE_DB_REVISIONS" ) -var dbRevisionsEnabled = false +var dbRevisionsEnabled = true func init() { val, ok := os.LookupEnv(enableDbRevisionsEnvKey) @@ -46,7 +46,7 @@ func init() { } func DbRevisionsEnabled() bool { - return dbRevisionsEnabled + return true } type DoltDatabaseProvider struct {