Merge pull request #5604 from dolthub/zachmu/multidb

[no-release-notes] Refactored db / session initialization logic
This commit is contained in:
Zach Musgrave
2023-03-30 12:32:24 -07:00
committed by GitHub
26 changed files with 687 additions and 1056 deletions
+1 -1
View File
@@ -621,7 +621,7 @@ func diffUserTable(
}
if dArgs.diffParts&SchemaOnlyDiff != 0 {
err := dw.WriteTableSchemaDiff(ctx, dArgs.toRoot, td)
err := dw.WriteTableSchemaDiff(ctx, dArgs.fromRoot, dArgs.toRoot, td)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
+9 -12
View File
@@ -32,6 +32,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/doltcore/table/typed/json"
"github.com/dolthub/dolt/go/libraries/doltcore/table/untyped/sqlexport"
"github.com/dolthub/dolt/go/libraries/doltcore/table/untyped/tabular"
@@ -44,7 +45,7 @@ type diffWriter interface {
// BeginTable is called when a new table is about to be written, before any schema or row diffs are written
BeginTable(ctx context.Context, td diff.TableDelta) error
// WriteTableSchemaDiff is called to write a schema diff for the table given (if requested by args)
WriteTableSchemaDiff(ctx context.Context, toRoot *doltdb.RootValue, td diff.TableDelta) error
WriteTableSchemaDiff(ctx context.Context, fromRoot *doltdb.RootValue, toRoot *doltdb.RootValue, td diff.TableDelta) error
// WriteTriggerDiff is called to write a trigger diff
WriteTriggerDiff(ctx context.Context, triggerName, oldDefn, newDefn string) error
// WriteViewDiff is called to write a view diff
@@ -221,17 +222,12 @@ func (t tabularDiffWriter) BeginTable(ctx context.Context, td diff.TableDelta) e
return nil
}
func (t tabularDiffWriter) WriteTableSchemaDiff(ctx context.Context, toRoot *doltdb.RootValue, td diff.TableDelta) error {
fromSch, toSch, err := td.GetSchemas(ctx)
if err != nil {
return errhand.BuildDError("cannot retrieve schema for table %s", td.ToName).AddCause(err).Build()
}
func (t tabularDiffWriter) WriteTableSchemaDiff(ctx context.Context, fromRoot *doltdb.RootValue, toRoot *doltdb.RootValue, td diff.TableDelta) error {
var fromCreateStmt = ""
if td.FromTable != nil {
// TODO: use UserSpaceDatabase for these, no reason for this separate database implementation
sqlDb := sqle.NewSingleTableDatabase(td.FromName, fromSch, td.FromFks, td.FromFksParentSch)
sqlDb := sqle.NewUserSpaceDatabase(fromRoot, editor.Options{})
sqlCtx, engine, _ := sqle.PrepareCreateTableStmt(ctx, sqlDb)
var err error
fromCreateStmt, err = sqle.GetCreateTableStmt(sqlCtx, engine, td.FromName)
if err != nil {
return errhand.VerboseErrorFromError(err)
@@ -240,8 +236,9 @@ func (t tabularDiffWriter) WriteTableSchemaDiff(ctx context.Context, toRoot *dol
var toCreateStmt = ""
if td.ToTable != nil {
sqlDb := sqle.NewSingleTableDatabase(td.ToName, toSch, td.ToFks, td.ToFksParentSch)
sqlDb := sqle.NewUserSpaceDatabase(toRoot, editor.Options{})
sqlCtx, engine, _ := sqle.PrepareCreateTableStmt(ctx, sqlDb)
var err error
toCreateStmt, err = sqle.GetCreateTableStmt(sqlCtx, engine, td.ToName)
if err != nil {
return errhand.VerboseErrorFromError(err)
@@ -298,7 +295,7 @@ func (s sqlDiffWriter) BeginTable(ctx context.Context, td diff.TableDelta) error
return nil
}
func (s sqlDiffWriter) WriteTableSchemaDiff(ctx context.Context, toRoot *doltdb.RootValue, td diff.TableDelta) error {
func (s sqlDiffWriter) WriteTableSchemaDiff(ctx context.Context, fromRoot *doltdb.RootValue, toRoot *doltdb.RootValue, td diff.TableDelta) error {
toSchemas, err := toRoot.GetAllSchemas(ctx)
if err != nil {
return errhand.BuildDError("could not read schemas from toRoot").AddCause(err).Build()
@@ -417,7 +414,7 @@ func (j *jsonDiffWriter) BeginTable(ctx context.Context, td diff.TableDelta) err
return err
}
func (j *jsonDiffWriter) WriteTableSchemaDiff(ctx context.Context, toRoot *doltdb.RootValue, td diff.TableDelta) error {
func (j *jsonDiffWriter) WriteTableSchemaDiff(ctx context.Context, fromRoot *doltdb.RootValue, toRoot *doltdb.RootValue, td diff.TableDelta) error {
toSchemas, err := toRoot.GetAllSchemas(ctx)
if err != nil {
return errhand.BuildDError("could not read schemas from toRoot").AddCause(err).Build()
-24
View File
@@ -306,30 +306,6 @@ func rebaseSqlEngine(ctx context.Context, dEnv *env.DoltEnv, cm *doltdb.Commit)
parallelism := runtime.GOMAXPROCS(0)
azr := analyzer.NewBuilder(pro).WithParallelism(parallelism).Build()
head := dEnv.RepoStateReader().CWBHeadSpec()
headCommit, err := dEnv.DoltDB.Resolve(ctx, head, dEnv.RepoStateReader().CWBHeadRef())
if err != nil {
return nil, nil, err
}
ws, err := dEnv.WorkingSet(ctx)
if err != nil {
return nil, nil, err
}
dbState := dsess.InitialDbState{
Db: db,
HeadCommit: headCommit,
WorkingSet: ws,
DbData: dEnv.DbData(),
Remotes: dEnv.RepoState.Remotes,
}
err = sess.AddDB(sqlCtx, dbState)
if err != nil {
return nil, nil, err
}
root, err := cm.GetRootValue(ctx)
if err != nil {
return nil, nil, err
+1 -4
View File
@@ -27,7 +27,6 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
"github.com/dolthub/dolt/go/libraries/doltcore/table"
"github.com/dolthub/dolt/go/libraries/doltcore/table/typed/json"
@@ -143,15 +142,13 @@ func (cmd CatCmd) Exec(ctx context.Context, commandStr string, args []string, dE
}
func (cmd CatCmd) prettyPrintResults(ctx context.Context, doltSch schema.Schema, idx durable.Index) error {
wr, err := getTableWriter(cmd.resultFormat, doltSch)
if err != nil {
return err
}
defer wr.Close(ctx)
sess := dsess.DefaultSession(dsess.EmptyDatabaseProvider())
sqlCtx := sql.NewContext(ctx, sql.WithSession(sess))
sqlCtx := sql.NewEmptyContext()
rowItr, err := table.NewTableIterator(ctx, doltSch, idx, 0)
if err != nil {
+44 -35
View File
@@ -36,6 +36,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
"github.com/dolthub/dolt/go/libraries/doltcore/table"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/doltcore/table/untyped/csv"
"github.com/dolthub/dolt/go/libraries/utils/argparser"
"github.com/dolthub/dolt/go/libraries/utils/funcitr"
@@ -301,8 +302,12 @@ func importSchema(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPars
}
tblName := impArgs.tableName
// inferred schemas have no foreign keys
sqlDb := sqle.NewSingleTableDatabase(tblName, sch, nil, nil)
root, verr = putEmptyTableWithSchema(ctx, tblName, root, sch)
if verr != nil {
return verr
}
sqlDb := sqle.NewUserSpaceDatabase(root, editor.Options{})
sqlCtx, engine, _ := sqle.PrepareCreateTableStmt(ctx, sqlDb)
stmt, err := sqle.GetCreateTableStmt(sqlCtx, engine, tblName)
@@ -312,39 +317,6 @@ func importSchema(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPars
cli.Println(stmt)
if !apr.Contains(dryRunFlag) {
tbl, tblExists, err := root.GetTable(ctx, tblName)
if err != nil {
return errhand.BuildDError("error: failed to get table.").AddCause(err).Build()
}
empty, err := durable.NewEmptyIndex(ctx, root.VRW(), root.NodeStore(), sch)
if err != nil {
return errhand.BuildDError("error: failed to get table.").AddCause(err).Build()
}
var indexSet durable.IndexSet
if tblExists {
indexSet, err = tbl.GetIndexSet(ctx)
if err != nil {
return errhand.BuildDError("error: failed to create table.").AddCause(err).Build()
}
} else {
indexSet, err = durable.NewIndexSetWithEmptyIndexes(ctx, root.VRW(), root.NodeStore(), sch)
if err != nil {
return errhand.BuildDError("error: failed to get table.").AddCause(err).Build()
}
}
tbl, err = doltdb.NewTable(ctx, root.VRW(), root.NodeStore(), sch, empty, indexSet, nil)
if err != nil {
return errhand.BuildDError("error: failed to get table.").AddCause(err).Build()
}
root, err = root.PutTable(ctx, tblName, tbl)
if err != nil {
return errhand.BuildDError("error: failed to add table.").AddCause(err).Build()
}
err = dEnv.UpdateWorkingRoot(ctx, root)
if err != nil {
return errhand.BuildDError("error: failed to update the working set.").AddCause(err).Build()
@@ -356,6 +328,43 @@ func importSchema(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPars
return nil
}
func putEmptyTableWithSchema(ctx context.Context, tblName string, root *doltdb.RootValue, sch schema.Schema) (*doltdb.RootValue, errhand.VerboseError) {
tbl, tblExists, err := root.GetTable(ctx, tblName)
if err != nil {
return nil, errhand.BuildDError("error: failed to get table.").AddCause(err).Build()
}
empty, err := durable.NewEmptyIndex(ctx, root.VRW(), root.NodeStore(), sch)
if err != nil {
return nil, errhand.BuildDError("error: failed to get table.").AddCause(err).Build()
}
var indexSet durable.IndexSet
if tblExists {
indexSet, err = tbl.GetIndexSet(ctx)
if err != nil {
return nil, errhand.BuildDError("error: failed to create table.").AddCause(err).Build()
}
} else {
indexSet, err = durable.NewIndexSetWithEmptyIndexes(ctx, root.VRW(), root.NodeStore(), sch)
if err != nil {
return nil, errhand.BuildDError("error: failed to get table.").AddCause(err).Build()
}
}
tbl, err = doltdb.NewTable(ctx, root.VRW(), root.NodeStore(), sch, empty, indexSet, nil)
if err != nil {
return nil, errhand.BuildDError("error: failed to get table.").AddCause(err).Build()
}
root, err = root.PutTable(ctx, tblName, tbl)
if err != nil {
return nil, errhand.BuildDError("error: failed to add table.").AddCause(err).Build()
}
return root, nil
}
func inferSchemaFromFile(ctx context.Context, nbf *types.NomsBinFormat, impOpts *importOptions, root *doltdb.RootValue) (schema.Schema, errhand.VerboseError) {
if impOpts.fileType[0] == '.' {
impOpts.fileType = impOpts.fileType[1:]
@@ -83,7 +83,7 @@ func (database) IsReadOnly() bool {
return true
}
func (db database) InitialDBState(ctx context.Context, branch string) (dsess.InitialDbState, error) {
func (db database) InitialDBState(ctx *sql.Context, branch string) (dsess.InitialDbState, error) {
// TODO: almost none of this state is actually used, but is necessary because the current session setup requires a
// repo state writer
return dsess.InitialDbState{
@@ -111,6 +111,18 @@ func (db database) EditOptions() editor.Options {
return editor.Options{}
}
func (db database) Revision() string {
return ""
}
func (db database) RevisionType() dsess.RevisionType {
return dsess.RevisionTypeNone
}
func (db database) BaseName() string {
return db.Name()
}
type noopRepoStateWriter struct{}
func (n noopRepoStateWriter) UpdateStagedRoot(ctx context.Context, newRoot *doltdb.RootValue) error {
+23 -108
View File
@@ -23,7 +23,6 @@ import (
"time"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/mysql_db"
"github.com/dolthub/go-mysql-server/sql/parse"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/types"
@@ -50,6 +49,7 @@ var ErrSystemTableAlter = errors.NewKind("Cannot alter table %s: system tables c
type SqlDatabase interface {
sql.Database
dsess.SessionDatabase
dsess.RevisionDatabase
// TODO: get rid of this, it's managed by the session, not the DB
GetRoot(*sql.Context) (*doltdb.RootValue, error)
@@ -58,34 +58,6 @@ type SqlDatabase interface {
EditOptions() editor.Options
}
// AllDbs returns all the databases in the given provider.
func AllDbs(ctx *sql.Context, pro sql.DatabaseProvider) []SqlDatabase {
dbs := pro.AllDatabases(ctx)
dsqlDBs := make([]SqlDatabase, 0, len(dbs))
for _, db := range dbs {
var sqlDb SqlDatabase
if sqlDatabase, ok := db.(SqlDatabase); ok {
sqlDb = sqlDatabase
} else if privDatabase, ok := db.(mysql_db.PrivilegedDatabase); ok {
if sqlDatabase, ok := privDatabase.Unwrap().(SqlDatabase); ok {
sqlDb = sqlDatabase
}
}
if sqlDb == nil {
continue
}
switch v := sqlDb.(type) {
case ReadReplicaDatabase, Database:
dsqlDBs = append(dsqlDBs, v)
case ReadOnlyDatabase, *UserSpaceDatabase:
default:
// esoteric analyzer errors occur if we silently drop databases, usually caused by pointer receivers
panic("cannot cast to SqlDatabase")
}
}
return dsqlDBs
}
// Database implements sql.Database for a dolt DB.
type Database struct {
name string
@@ -95,6 +67,7 @@ type Database struct {
gs globalstate.GlobalState
editOpts editor.Options
revision string
revType dsess.RevisionType
}
var _ SqlDatabase = Database{}
@@ -123,11 +96,24 @@ func (r ReadOnlyDatabase) IsReadOnly() bool {
return true
}
func (r ReadOnlyDatabase) InitialDBState(ctx *sql.Context, branch string) (dsess.InitialDbState, error) {
return initialDBState(ctx, r, branch)
}
// Revision implements dsess.RevisionDatabase
func (db Database) Revision() string {
return db.revision
}
func (db Database) RevisionType() dsess.RevisionType {
return db.revType
}
func (db Database) BaseName() string {
base, _ := splitRevisionDbName(db)
return base
}
func (db Database) EditOptions() editor.Options {
return db.editOpts
}
@@ -149,76 +135,18 @@ func NewDatabase(ctx context.Context, name string, dbData env.DbData, editOpts e
}, nil
}
// GetInitialDBState returns the InitialDbState for |db|.
func GetInitialDBState(ctx context.Context, db SqlDatabase, branch string) (dsess.InitialDbState, error) {
switch db := db.(type) {
case *UserSpaceDatabase, *SingleTableInfoDatabase:
return getInitialDBStateForUserSpaceDb(ctx, db)
// initialDBState returns the InitialDbState for |db|. Other implementations of SqlDatabase outside this file should
// implement their own method for an initial db state and not rely on this method.
func initialDBState(ctx *sql.Context, db SqlDatabase, branch string) (dsess.InitialDbState, error) {
if len(db.Revision()) > 0 {
return initialStateForRevisionDb(ctx, db)
}
rsr := db.DbData().Rsr
ddb := db.DbData().Ddb
var r ref.DoltRef
if len(branch) > 0 {
r = ref.NewBranchRef(branch)
} else {
r = rsr.CWBHeadRef()
}
var retainedErr error
headCommit, err := ddb.ResolveCommitRef(ctx, r)
if err == doltdb.ErrBranchNotFound {
retainedErr = err
err = nil
}
if err != nil {
return dsess.InitialDbState{}, err
}
var ws *doltdb.WorkingSet
if retainedErr == nil {
workingSetRef, err := ref.WorkingSetRefForHead(r)
if err != nil {
return dsess.InitialDbState{}, err
}
ws, err = db.DbData().Ddb.ResolveWorkingSet(ctx, workingSetRef)
if err != nil {
return dsess.InitialDbState{}, err
}
}
remotes, err := rsr.GetRemotes()
if err != nil {
return dsess.InitialDbState{}, err
}
backups, err := rsr.GetBackups()
if err != nil {
return dsess.InitialDbState{}, err
}
branches, err := rsr.GetBranches()
if err != nil {
return dsess.InitialDbState{}, err
}
return dsess.InitialDbState{
Db: db,
HeadCommit: headCommit,
WorkingSet: ws,
DbData: db.DbData(),
Remotes: remotes,
Branches: branches,
Backups: backups,
Err: retainedErr,
}, nil
return initialDbState(ctx, db, branch)
}
func (db Database) InitialDBState(ctx context.Context, branch string) (dsess.InitialDbState, error) {
return GetInitialDBState(ctx, db, branch)
func (db Database) InitialDBState(ctx *sql.Context, branch string) (dsess.InitialDbState, error) {
return initialDBState(ctx, db, branch)
}
// Name returns the name of this database, set at creation time.
@@ -1425,19 +1353,6 @@ func (db Database) SetCollation(ctx *sql.Context, collation sql.CollationID) err
return db.SetRoot(ctx, newRoot)
}
// TODO: this is a hack to make user space DBs appear to the analyzer as full DBs with state etc., but the state is
// really skeletal. We need to reexamine the DB / session initialization to make this simpler -- most of these things
// aren't needed at initialization time and for most code paths.
func getInitialDBStateForUserSpaceDb(ctx context.Context, db SqlDatabase) (dsess.InitialDbState, error) {
return dsess.InitialDbState{
Db: db,
DbData: env.DbData{
Rsw: noopRepoStateWriter{},
},
ReadOnly: true,
}, nil
}
// noopRepoStateWriter is a minimal implementation of RepoStateWriter that does nothing
type noopRepoStateWriter struct{}
+350 -210
View File
@@ -184,39 +184,20 @@ func (p DoltDatabaseProvider) FileSystemForDatabase(dbname string) (filesys.File
}
// Database implements the sql.DatabaseProvider interface
func (p DoltDatabaseProvider) Database(ctx *sql.Context, name string) (db sql.Database, err error) {
var ok bool
p.mu.RLock()
db, ok = p.databases[formatDbMapKeyName(name)]
standby := *p.isStandby
p.mu.RUnlock()
if ok {
return wrapForStandby(db, standby), nil
}
// Revision databases aren't tracked in the map, just instantiated on demand
db, _, ok, err = p.databaseForRevision(ctx, name)
func (p DoltDatabaseProvider) Database(ctx *sql.Context, name string) (sql.Database, error) {
database, b, err := p.SessionDatabase(ctx, name)
if err != nil {
return nil, err
}
// A final check: if the database doesn't exist and this is a read replica, attempt to clone it from the remote
if !ok {
db, err = p.databaseForClone(ctx, name)
if err != nil {
return nil, err
}
if db == nil {
return nil, sql.ErrDatabaseNotFound.New(name)
}
if !b {
return nil, sql.ErrDatabaseNotFound.New(name)
}
return wrapForStandby(db, standby), nil
return database, nil
}
func wrapForStandby(db sql.Database, standby bool) sql.Database {
func wrapForStandby(db SqlDatabase, standby bool) SqlDatabase {
if !standby {
return db
}
@@ -315,7 +296,7 @@ func (p DoltDatabaseProvider) AllDatabases(ctx *sql.Context) (all []sql.Database
// If the current database is not one of the primary databases, it must be a transitory revision database
if !foundDatabase && currDb != "" {
revDb, _, ok, err := p.databaseForRevision(ctx, currDb)
revDb, ok, err := p.databaseForRevision(ctx, currDb)
if err != nil {
// We can't return an error from this interface function, so just log a message
ctx.GetLogger().Warnf("unable to load %q as a database revision: %s", ctx.GetCurrentDatabase(), err.Error())
@@ -343,7 +324,7 @@ func (p DoltDatabaseProvider) allRevisionDbs(ctx *sql.Context, db SqlDatabase) (
revDbs := make([]sql.Database, len(branches))
for i, branch := range branches {
revDb, _, ok, err := p.databaseForRevision(ctx, fmt.Sprintf("%s/%s", db.Name(), branch.GetPath()))
revDb, ok, err := p.databaseForRevision(ctx, fmt.Sprintf("%s/%s", db.Name(), branch.GetPath()))
if err != nil {
return nil, err
}
@@ -464,12 +445,7 @@ func (p DoltDatabaseProvider) CreateCollatedDatabase(ctx *sql.Context, name stri
p.databases[formattedName] = db
p.dbLocations[formattedName] = newEnv.FS
dbstate, err := GetInitialDBState(ctx, db, "")
if err != nil {
return err
}
return sess.AddDB(ctx, dbstate)
return nil
}
type InitDatabaseHook func(ctx *sql.Context, pro DoltDatabaseProvider, name string, env *env.DoltEnv) error
@@ -593,7 +569,6 @@ func (p DoltDatabaseProvider) cloneDatabaseFromRemote(
Remote: remoteName,
})
sess := dsess.DSessFromSess(ctx.Session)
fkChecks, err := ctx.GetSessionVariable(ctx, "foreign_key_checks")
if err != nil {
return nil, err
@@ -612,22 +587,12 @@ func (p DoltDatabaseProvider) cloneDatabaseFromRemote(
p.databases[formatDbMapKeyName(db.Name())] = db
dbstate, err := GetInitialDBState(ctx, db, "")
if err != nil {
return nil, err
}
err = sess.AddDB(ctx, dbstate)
if err != nil {
return nil, err
}
return dEnv, nil
}
// DropDatabase implements the sql.MutableDatabaseProvider interface
func (p DoltDatabaseProvider) DropDatabase(ctx *sql.Context, name string) error {
isRevisionDatabase, err := p.IsRevisionDatabase(ctx, name)
isRevisionDatabase, err := p.isRevisionDatabase(ctx, name)
if err != nil {
return err
}
@@ -715,6 +680,13 @@ func (p DoltDatabaseProvider) DropDatabase(ctx *sql.Context, name string) error
// invalidateDbStateInAllSessions removes the db state for this database from every session. This is necessary when a
// database is dropped, so that other sessions don't use stale db state.
func (p DoltDatabaseProvider) invalidateDbStateInAllSessions(ctx *sql.Context, name string) error {
// Remove the db state from the current session
err := dsess.DSessFromSess(ctx.Session).RemoveDbState(ctx, name)
if err != nil {
return err
}
// If we have a running server, remove it from other sessions as well
runningServer := sqlserver.GetRunningServer()
if runningServer != nil {
sessionManager := runningServer.SessionManager()
@@ -740,9 +712,9 @@ func (p DoltDatabaseProvider) invalidateDbStateInAllSessions(ctx *sql.Context, n
return nil
}
func (p DoltDatabaseProvider) databaseForRevision(ctx *sql.Context, revDB string) (sql.Database, dsess.InitialDbState, bool, error) {
func (p DoltDatabaseProvider) databaseForRevision(ctx *sql.Context, revDB string) (SqlDatabase, bool, error) {
if !strings.Contains(revDB, dbRevisionDelimiter) {
return nil, dsess.InitialDbState{}, false, nil
return nil, false, nil
}
parts := strings.SplitN(revDB, dbRevisionDelimiter, 2)
@@ -752,95 +724,232 @@ func (p DoltDatabaseProvider) databaseForRevision(ctx *sql.Context, revDB string
candidate, ok := p.databases[formatDbMapKeyName(dbName)]
p.mu.RUnlock()
if !ok {
return nil, dsess.InitialDbState{}, false, nil
return nil, false, nil
}
srcDb, ok := candidate.(SqlDatabase)
if !ok {
return nil, dsess.InitialDbState{}, false, nil
return nil, false, nil
}
resolvedRevSpec, err := p.resolveAncestorSpec(ctx, revSpec, srcDb.DbData().Ddb)
dbType, resolvedRevSpec, err := revisionDbType(ctx, srcDb, revSpec)
if err != nil {
return nil, dsess.InitialDbState{}, false, err
return nil, false, err
}
caseSensitiveBranchName, isBranch, err := isBranch(ctx, srcDb, resolvedRevSpec)
if err != nil {
return nil, dsess.InitialDbState{}, false, err
}
if isBranch {
switch dbType {
case dsess.RevisionTypeBranch:
// fetch the upstream head if this is a replicated db
if replicaDb, ok := srcDb.(ReadReplicaDatabase); ok {
// TODO move this out of analysis phase, should only happen at read time, when the transaction begins (like is
// the case with a branch that already exists locally)
err := p.ensureReplicaHeadExists(ctx, resolvedRevSpec, replicaDb)
if err != nil {
return nil, dsess.InitialDbState{}, false, err
return nil, false, err
}
}
db, init, err := dbRevisionForBranch(ctx, srcDb, caseSensitiveBranchName)
db, err := revisionDbForBranch(ctx, srcDb, resolvedRevSpec)
// preserve original user case in the case of not found
if sql.ErrDatabaseNotFound.Is(err) {
return nil, dsess.InitialDbState{}, false, sql.ErrDatabaseNotFound.New(revDB)
return nil, false, sql.ErrDatabaseNotFound.New(revDB)
} else if err != nil {
return nil, dsess.InitialDbState{}, false, err
return nil, false, err
}
return db, init, true, nil
return db, true, nil
case dsess.RevisionTypeTag:
// TODO: this should be an interface, not a struct
replicaDb, ok := srcDb.(ReadReplicaDatabase)
if ok {
srcDb = replicaDb.Database
}
srcDb, ok = srcDb.(Database)
if !ok {
return nil, false, nil
}
db, err := revisionDbForTag(ctx, srcDb.(Database), resolvedRevSpec)
if err != nil {
return nil, false, err
}
return db, true, nil
case dsess.RevisionTypeCommit:
// TODO: this should be an interface, not a struct
replicaDb, ok := srcDb.(ReadReplicaDatabase)
if ok {
srcDb = replicaDb.Database
}
srcDb, ok = srcDb.(Database)
if !ok {
return nil, false, nil
}
db, err := revisionDbForCommit(ctx, srcDb.(Database), revSpec)
if err != nil {
return nil, false, err
}
return db, true, nil
case dsess.RevisionTypeNone:
// not an error, ok = false will get handled as a not found error in a layer above as appropriate
return nil, false, nil
default:
return nil, false, fmt.Errorf("unrecognized revision type for revision spec %s", revSpec)
}
}
// revisionDbType returns the type of revision spec given for the database given, and the resolved revision spec
func revisionDbType(ctx *sql.Context, srcDb SqlDatabase, revSpec string) (revType dsess.RevisionType, resolvedRevSpec string, err error) {
resolvedRevSpec, err = resolveAncestorSpec(ctx, revSpec, srcDb.DbData().Ddb)
if err != nil {
return 0, "", err
}
caseSensitiveBranchName, isBranch, err := isBranch(ctx, srcDb, resolvedRevSpec)
if err != nil {
return 0, "", err
}
if isBranch {
return dsess.RevisionTypeBranch, caseSensitiveBranchName, nil
}
isTag, err := isTag(ctx, srcDb, resolvedRevSpec)
if err != nil {
return nil, dsess.InitialDbState{}, false, err
return 0, "", err
}
if isTag {
// TODO: this should be an interface, not a struct
replicaDb, ok := srcDb.(ReadReplicaDatabase)
if ok {
srcDb = replicaDb.Database
}
srcDb, ok = srcDb.(Database)
if !ok {
return nil, dsess.InitialDbState{}, false, nil
}
db, init, err := dbRevisionForTag(ctx, srcDb.(Database), resolvedRevSpec)
if err != nil {
return nil, dsess.InitialDbState{}, false, err
}
return db, init, true, nil
return dsess.RevisionTypeTag, resolvedRevSpec, nil
}
if doltdb.IsValidCommitHash(resolvedRevSpec) {
// TODO: this should be an interface, not a struct
replicaDb, ok := srcDb.(ReadReplicaDatabase)
if ok {
srcDb = replicaDb.Database
}
srcDb, ok = srcDb.(Database)
if !ok {
return nil, dsess.InitialDbState{}, false, nil
}
db, init, err := dbRevisionForCommit(ctx, srcDb.(Database), revSpec)
if err != nil {
return nil, dsess.InitialDbState{}, false, err
}
return db, init, true, nil
return dsess.RevisionTypeCommit, resolvedRevSpec, nil
}
return nil, dsess.InitialDbState{}, false, nil
return dsess.RevisionTypeNone, "", nil
}
func initialDbState(ctx context.Context, db SqlDatabase, branch string) (dsess.InitialDbState, error) {
rsr := db.DbData().Rsr
ddb := db.DbData().Ddb
var r ref.DoltRef
if len(branch) > 0 {
r = ref.NewBranchRef(branch)
} else {
r = rsr.CWBHeadRef()
}
var retainedErr error
headCommit, err := ddb.ResolveCommitRef(ctx, r)
if err == doltdb.ErrBranchNotFound {
retainedErr = err
err = nil
}
if err != nil {
return dsess.InitialDbState{}, err
}
var ws *doltdb.WorkingSet
if retainedErr == nil {
workingSetRef, err := ref.WorkingSetRefForHead(r)
if err != nil {
return dsess.InitialDbState{}, err
}
ws, err = db.DbData().Ddb.ResolveWorkingSet(ctx, workingSetRef)
if err != nil {
return dsess.InitialDbState{}, err
}
}
remotes, err := rsr.GetRemotes()
if err != nil {
return dsess.InitialDbState{}, err
}
backups, err := rsr.GetBackups()
if err != nil {
return dsess.InitialDbState{}, err
}
branches, err := rsr.GetBranches()
if err != nil {
return dsess.InitialDbState{}, err
}
return dsess.InitialDbState{
Db: db,
HeadCommit: headCommit,
WorkingSet: ws,
DbData: db.DbData(),
Remotes: remotes,
Branches: branches,
Backups: backups,
Err: retainedErr,
}, nil
}
func initialStateForRevisionDb(ctx *sql.Context, db SqlDatabase) (dsess.InitialDbState, error) {
switch db.RevisionType() {
case dsess.RevisionTypeBranch:
init, err := initialStateForBranchDb(ctx, db)
// preserve original user case in the case of not found
if sql.ErrDatabaseNotFound.Is(err) {
return dsess.InitialDbState{}, sql.ErrDatabaseNotFound.New(db.Name())
} else if err != nil {
return dsess.InitialDbState{}, err
}
return init, nil
case dsess.RevisionTypeTag:
// TODO: this should be an interface, not a struct
replicaDb, ok := db.(ReadReplicaDatabase)
if ok {
db = replicaDb.Database
}
db, ok = db.(ReadOnlyDatabase)
if !ok {
return dsess.InitialDbState{}, fmt.Errorf("expected a ReadOnlyDatabase, got %T", db)
}
init, err := initialStateForTagDb(ctx, db.(ReadOnlyDatabase))
if err != nil {
return dsess.InitialDbState{}, err
}
return init, nil
case dsess.RevisionTypeCommit:
// TODO: this should be an interface, not a struct
replicaDb, ok := db.(ReadReplicaDatabase)
if ok {
db = replicaDb.Database
}
db, ok = db.(ReadOnlyDatabase)
if !ok {
return dsess.InitialDbState{}, fmt.Errorf("expected a ReadOnlyDatabase, got %T", db)
}
init, err := initialStateForCommit(ctx, db.(ReadOnlyDatabase))
if err != nil {
return dsess.InitialDbState{}, err
}
return init, nil
default:
return dsess.InitialDbState{}, fmt.Errorf("unrecognized revision type for revision spec %s: %v", db.Revision(), db.RevisionType())
}
}
// databaseForClone returns a newly cloned database if read replication is enabled and a remote DB exists, or an error
// otherwise
func (p DoltDatabaseProvider) databaseForClone(ctx *sql.Context, revDB string) (sql.Database, error) {
func (p DoltDatabaseProvider) databaseForClone(ctx *sql.Context, revDB string) (SqlDatabase, error) {
if !readReplicationActive(ctx) {
return nil, nil
}
@@ -860,7 +969,8 @@ func (p DoltDatabaseProvider) databaseForClone(ctx *sql.Context, revDB string) (
}
// now that the database has been cloned, retry the Database call
return p.Database(ctx, revDB)
database, err := p.Database(ctx, revDB)
return database.(SqlDatabase), err
}
// TODO: figure out the right contract: which variables must be set? What happens if they aren't all set?
@@ -881,7 +991,7 @@ func readReplicationActive(ctx *sql.Context) bool {
// resolveAncestorSpec resolves the specified revSpec to a specific commit hash if it contains an ancestor reference
// such as ~ or ^. If no ancestor reference is present, the specified revSpec is returned as is. If any unexpected
// problems are encountered, an error is returned.
func (p DoltDatabaseProvider) resolveAncestorSpec(ctx *sql.Context, revSpec string, ddb *doltdb.DoltDB) (string, error) {
func resolveAncestorSpec(ctx *sql.Context, revSpec string, ddb *doltdb.DoltDB) (string, error) {
refname, ancestorSpec, err := doltdb.SplitAncestorSpec(revSpec)
if err != nil {
return "", err
@@ -913,42 +1023,37 @@ func (p DoltDatabaseProvider) resolveAncestorSpec(ctx *sql.Context, revSpec stri
return hash.String(), nil
}
func (p DoltDatabaseProvider) RevisionDbState(ctx *sql.Context, revDB string) (dsess.InitialDbState, error) {
_, init, ok, err := p.databaseForRevision(ctx, revDB)
if err != nil {
return dsess.InitialDbState{}, err
} else if !ok {
return dsess.InitialDbState{}, sql.ErrDatabaseNotFound.New(revDB)
}
return init, nil
}
func (p DoltDatabaseProvider) stateForDatabase(ctx *sql.Context, dbName string, branch string) (dsess.InitialDbState, bool, error) {
// SessionDatabase implements dsess.SessionDatabaseProvider
func (p DoltDatabaseProvider) SessionDatabase(ctx *sql.Context, name string) (dsess.SessionDatabase, bool, error) {
var ok bool
p.mu.RLock()
db, ok := p.databases[formatDbMapKeyName(dbName)]
db, ok := p.databases[formatDbMapKeyName(name)]
standby := *p.isStandby
p.mu.RUnlock()
if ok {
return wrapForStandby(db, standby), true, nil
}
// Revision databases aren't tracked in the map, just instantiated on demand
db, ok, err := p.databaseForRevision(ctx, name)
if err != nil {
return nil, false, err
}
// A final check: if the database doesn't exist and this is a read replica, attempt to clone it from the remote
if !ok {
return dsess.InitialDbState{}, false, nil
db, err = p.databaseForClone(ctx, name)
if err != nil {
return nil, false, err
}
if db == nil {
return nil, false, nil
}
}
dbState, err := db.InitialDBState(ctx, branch)
if err != nil {
return dsess.InitialDbState{}, false, err
}
return dbState, true, nil
}
func (p DoltDatabaseProvider) DbState(ctx *sql.Context, dbName string, defaultBranch string) (dsess.InitialDbState, error) {
init, ok, err := p.stateForDatabase(ctx, dbName, defaultBranch)
if err != nil {
return dsess.InitialDbState{}, err
} else if !ok {
return dsess.InitialDbState{}, sql.ErrDatabaseNotFound.New(dbName)
}
return init, nil
return wrapForStandby(db, standby), true, nil
}
// Function implements the FunctionProvider interface
@@ -998,34 +1103,35 @@ func (p DoltDatabaseProvider) TableFunction(_ *sql.Context, name string) (sql.Ta
return nil, sql.ErrTableFunctionNotFound.New(name)
}
// GetRevisionForRevisionDatabase implements dsess.RevisionDatabaseProvider
func (p DoltDatabaseProvider) GetRevisionForRevisionDatabase(ctx *sql.Context, dbName string) (string, string, error) {
db, err := p.Database(ctx, dbName)
if err != nil {
return "", "", err
}
sqldb, ok := db.(dsess.RevisionDatabase)
// splitRevisionDbName splits the given database name into its base and revision parts and returns them. Non-revision
// DBs use their full name as the base name, and empty string as the revision.
func splitRevisionDbName(db sql.Database) (string, string) {
sqldb, ok := db.(SqlDatabase)
if !ok {
return db.Name(), "", nil
return db.Name(), ""
}
dbName := db.Name()
if sqldb.Revision() != "" {
dbName = strings.TrimSuffix(dbName, dbRevisionDelimiter+sqldb.Revision())
}
return dbName, sqldb.Revision(), nil
return dbName, sqldb.Revision()
}
// IsRevisionDatabase returns true if the specified dbName represents a database that is tied to a specific
// isRevisionDatabase returns true if the specified dbName represents a database that is tied to a specific
// branch or commit from a database (e.g. "dolt/branch1").
func (p DoltDatabaseProvider) IsRevisionDatabase(ctx *sql.Context, dbName string) (bool, error) {
_, revision, err := p.GetRevisionForRevisionDatabase(ctx, dbName)
func (p DoltDatabaseProvider) isRevisionDatabase(ctx *sql.Context, dbName string) (bool, error) {
db, ok, err := p.SessionDatabase(ctx, dbName)
if err != nil {
return false, err
}
if !ok {
return false, sql.ErrDatabaseNotFound.New(dbName)
}
return revision != "", nil
_, rev := splitRevisionDbName(db)
return rev != "", nil
}
// ensureReplicaHeadExists tries to pull the latest version of a remote branch. Will fail if the branch
@@ -1036,14 +1142,9 @@ func (p DoltDatabaseProvider) ensureReplicaHeadExists(ctx *sql.Context, branch s
// isBranch returns whether a branch with the given name is in scope for the database given
func isBranch(ctx context.Context, db SqlDatabase, branchName string) (string, bool, error) {
var ddbs []*doltdb.DoltDB
if rdb, ok := db.(ReadReplicaDatabase); ok {
ddbs = append(ddbs, rdb.ddb, rdb.srcDB)
} else if ddb, ok := db.(Database); ok {
ddbs = append(ddbs, ddb.ddb)
} else {
return "", false, fmt.Errorf("unrecognized type of database %T", db)
ddbs, err := doltDbs(db)
if err != nil {
return "", false, err
}
brName, branchExists, err := isLocalBranch(ctx, ddbs, branchName)
@@ -1065,6 +1166,21 @@ func isBranch(ctx context.Context, db SqlDatabase, branchName string) (string, b
return "", false, nil
}
func doltDbs(db SqlDatabase) ([]*doltdb.DoltDB, error) {
var ddbs []*doltdb.DoltDB
switch db := db.(type) {
case ReadReplicaDatabase:
ddbs = append(ddbs, db.ddb, db.srcDB)
case Database:
ddbs = append(ddbs, db.ddb)
case ReadOnlyDatabase:
ddbs = append(ddbs, db.ddb)
default:
return nil, fmt.Errorf("unrecognized type of database %T", db)
}
return ddbs, nil
}
func isLocalBranch(ctx context.Context, ddbs []*doltdb.DoltDB, branchName string) (string, bool, error) {
for _, ddb := range ddbs {
brName, branchExists, err := ddb.HasBranch(ctx, branchName)
@@ -1098,14 +1214,9 @@ func isRemoteBranch(ctx context.Context, ddbs []*doltdb.DoltDB, branchName strin
// isTag returns whether a tag with the given name is in scope for the database given
func isTag(ctx context.Context, db SqlDatabase, tagName string) (bool, error) {
var ddbs []*doltdb.DoltDB
if rdb, ok := db.(ReadReplicaDatabase); ok {
ddbs = append(ddbs, rdb.ddb, rdb.srcDB)
} else if ddb, ok := db.(Database); ok {
ddbs = append(ddbs, ddb.ddb)
} else {
return false, fmt.Errorf("unrecognized type of database %T", db)
ddbs, err := doltDbs(db)
if err != nil {
return false, err
}
for _, ddb := range ddbs {
@@ -1122,23 +1233,9 @@ func isTag(ctx context.Context, db SqlDatabase, tagName string) (bool, error) {
return false, nil
}
func dbRevisionForBranch(ctx context.Context, srcDb SqlDatabase, revSpec string) (SqlDatabase, dsess.InitialDbState, error) {
// revisionDbForBranch returns a new database that is tied to the branch named by revSpec
func revisionDbForBranch(ctx context.Context, srcDb SqlDatabase, revSpec string) (SqlDatabase, error) {
branch := ref.NewBranchRef(revSpec)
cm, err := srcDb.DbData().Ddb.ResolveCommitRef(ctx, branch)
if err != nil {
return Database{}, dsess.InitialDbState{}, err
}
wsRef, err := ref.WorkingSetRefForHead(branch)
if err != nil {
return Database{}, dsess.InitialDbState{}, err
}
ws, err := srcDb.DbData().Ddb.ResolveWorkingSet(ctx, wsRef)
if err != nil {
return Database{}, dsess.InitialDbState{}, err
}
dbName := srcDb.Name() + dbRevisionDelimiter + revSpec
static := staticRepoState{
@@ -1159,6 +1256,7 @@ func dbRevisionForBranch(ctx context.Context, srcDb SqlDatabase, revSpec string)
gs: v.gs,
editOpts: v.editOpts,
revision: revSpec,
revType: dsess.RevisionTypeBranch,
}
case ReadReplicaDatabase:
db = ReadReplicaDatabase{
@@ -1170,6 +1268,7 @@ func dbRevisionForBranch(ctx context.Context, srcDb SqlDatabase, revSpec string)
gs: v.gs,
editOpts: v.editOpts,
revision: revSpec,
revType: dsess.RevisionTypeBranch,
},
remote: v.remote,
srcDB: v.srcDB,
@@ -1178,23 +1277,51 @@ func dbRevisionForBranch(ctx context.Context, srcDb SqlDatabase, revSpec string)
}
}
return db, nil
}
func initialStateForBranchDb(ctx context.Context, srcDb SqlDatabase) (dsess.InitialDbState, error) {
_, revSpec := splitRevisionDbName(srcDb)
branch := ref.NewBranchRef(revSpec)
cm, err := srcDb.DbData().Ddb.ResolveCommitRef(ctx, branch)
if err != nil {
return dsess.InitialDbState{}, err
}
wsRef, err := ref.WorkingSetRefForHead(branch)
if err != nil {
return dsess.InitialDbState{}, err
}
ws, err := srcDb.DbData().Ddb.ResolveWorkingSet(ctx, wsRef)
if err != nil {
return dsess.InitialDbState{}, err
}
static := staticRepoState{
branch: branch,
RepoStateWriter: srcDb.DbData().Rsw,
RepoStateReader: srcDb.DbData().Rsr,
}
remotes, err := static.GetRemotes()
if err != nil {
return nil, dsess.InitialDbState{}, err
return dsess.InitialDbState{}, err
}
branches, err := static.GetBranches()
if err != nil {
return nil, dsess.InitialDbState{}, err
return dsess.InitialDbState{}, err
}
backups, err := static.GetBackups()
if err != nil {
return nil, dsess.InitialDbState{}, err
return dsess.InitialDbState{}, err
}
init := dsess.InitialDbState{
Db: db,
Db: srcDb,
HeadCommit: cm,
WorkingSet: ws,
DbData: env.DbData{
@@ -1205,31 +1332,37 @@ func dbRevisionForBranch(ctx context.Context, srcDb SqlDatabase, revSpec string)
Remotes: remotes,
Branches: branches,
Backups: backups,
//ReadReplica: //todo
}
return db, init, nil
return init, nil
}
func dbRevisionForTag(ctx context.Context, srcDb Database, revSpec string) (ReadOnlyDatabase, dsess.InitialDbState, error) {
func revisionDbForTag(ctx context.Context, srcDb Database, revSpec string) (ReadOnlyDatabase, error) {
name := srcDb.Name() + dbRevisionDelimiter + revSpec
db := ReadOnlyDatabase{Database: Database{
name: name,
ddb: srcDb.DbData().Ddb,
rsw: srcDb.DbData().Rsw,
rsr: srcDb.DbData().Rsr,
editOpts: srcDb.editOpts,
revision: revSpec,
revType: dsess.RevisionTypeTag,
}}
return db, nil
}
func initialStateForTagDb(ctx context.Context, srcDb ReadOnlyDatabase) (dsess.InitialDbState, error) {
_, revSpec := splitRevisionDbName(srcDb)
tag := ref.NewTagRef(revSpec)
cm, err := srcDb.DbData().Ddb.ResolveCommitRef(ctx, tag)
if err != nil {
return ReadOnlyDatabase{}, dsess.InitialDbState{}, err
return dsess.InitialDbState{}, err
}
name := srcDb.Name() + dbRevisionDelimiter + revSpec
db := ReadOnlyDatabase{Database: Database{
name: name,
ddb: srcDb.DbData().Ddb,
rsw: srcDb.DbData().Rsw,
rsr: srcDb.DbData().Rsr,
editOpts: srcDb.editOpts,
revision: revSpec,
}}
init := dsess.InitialDbState{
Db: db,
Db: srcDb,
HeadCommit: cm,
ReadOnly: true,
DbData: env.DbData{
@@ -1244,32 +1377,39 @@ func dbRevisionForTag(ctx context.Context, srcDb Database, revSpec string) (Read
// - ReadReplicas
}
return db, init, nil
return init, nil
}
func dbRevisionForCommit(ctx context.Context, srcDb Database, revSpec string) (ReadOnlyDatabase, dsess.InitialDbState, error) {
func revisionDbForCommit(ctx context.Context, srcDb Database, revSpec string) (ReadOnlyDatabase, error) {
name := srcDb.Name() + dbRevisionDelimiter + revSpec
db := ReadOnlyDatabase{Database: Database{
name: name,
ddb: srcDb.DbData().Ddb,
rsw: srcDb.DbData().Rsw,
rsr: srcDb.DbData().Rsr,
editOpts: srcDb.editOpts,
revision: revSpec,
revType: dsess.RevisionTypeCommit,
}}
return db, nil
}
func initialStateForCommit(ctx context.Context, srcDb ReadOnlyDatabase) (dsess.InitialDbState, error) {
_, revSpec := splitRevisionDbName(srcDb)
spec, err := doltdb.NewCommitSpec(revSpec)
if err != nil {
return ReadOnlyDatabase{}, dsess.InitialDbState{}, err
return dsess.InitialDbState{}, err
}
cm, err := srcDb.DbData().Ddb.Resolve(ctx, spec, srcDb.DbData().Rsr.CWBHeadRef())
if err != nil {
return ReadOnlyDatabase{}, dsess.InitialDbState{}, err
return dsess.InitialDbState{}, err
}
name := srcDb.Name() + dbRevisionDelimiter + revSpec
db := ReadOnlyDatabase{Database: Database{
name: name,
ddb: srcDb.DbData().Ddb,
rsw: srcDb.DbData().Rsw,
rsr: srcDb.DbData().Rsr,
editOpts: srcDb.editOpts,
revision: revSpec,
}}
init := dsess.InitialDbState{
Db: db,
Db: srcDb,
HeadCommit: cm,
ReadOnly: true,
DbData: env.DbData{
@@ -1284,7 +1424,7 @@ func dbRevisionForCommit(ctx context.Context, srcDb Database, revSpec string) (R
// - ReadReplicas
}
return db, init, nil
return init, nil
}
type staticRepoState struct {
@@ -178,8 +178,20 @@ func getRevisionForRevisionDatabase(ctx *sql.Context, dbName string) (string, st
return "", "", fmt.Errorf("unexpected session type: %T", ctx.Session)
}
provider := doltsess.Provider()
return provider.GetRevisionForRevisionDatabase(ctx, dbName)
db, ok, err := doltsess.Provider().SessionDatabase(ctx, dbName)
if err != nil {
return "", "", err
}
if !ok {
return "", "", sql.ErrDatabaseNotFound.New(dbName)
}
rdb, ok := db.(dsess.RevisionDatabase)
if !ok {
return dbName, "", nil
}
return rdb.BaseName(), rdb.Revision(), nil
}
// checkoutRemoteBranch checks out a remote branch creating a new local branch with the same name as the remote branch
@@ -15,8 +15,6 @@
package dsess
import (
"context"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
@@ -26,11 +24,19 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
)
// InitialDbState is the initial state of a database, as returned by SessionDatabase.InitialDBState. It is used to
// establish the in memory state of the session for every new transaction.
type InitialDbState struct {
Db sql.Database
HeadCommit *doltdb.Commit
Db sql.Database
// WorkingSet is the working set for this database. May be nil for databases tied to a detached root value, in which
// case HeadCommit must be set
WorkingSet *doltdb.WorkingSet
// The head commit for this database. May be nil for databases tied to a detached root value, in which case
// RootValue must be set.
HeadCommit *doltdb.Commit
// HeadRoot is the root value for databases without a HeadCommit. Nil for databases with a HeadCommit.
HeadRoot *doltdb.RootValue
ReadOnly bool
WorkingSet *doltdb.WorkingSet
DbData env.DbData
ReadReplica *env.Remote
Remotes map[string]env.Remote
@@ -48,7 +54,7 @@ type InitialDbState struct {
// order for the session to manage it.
type SessionDatabase interface {
sql.Database
InitialDBState(ctx context.Context, branch string) (InitialDbState, error)
InitialDBState(ctx *sql.Context, branch string) (InitialDbState, error)
}
type DatabaseSessionState struct {
@@ -15,6 +15,7 @@
package dsess
import (
"context"
"testing"
"github.com/dolthub/go-mysql-server/sql"
@@ -22,17 +23,21 @@ import (
"github.com/stretchr/testify/assert"
"gopkg.in/src-d/go-errors.v1"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/utils/config"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/store/types"
)
func TestDoltSessionInit(t *testing.T) {
dsess := DefaultSession(EmptyDatabaseProvider())
dsess := DefaultSession(emptyDatabaseProvider())
conf := config.NewMapConfig(make(map[string]string))
assert.Equal(t, conf, dsess.globalsConf)
}
func TestNewPersistedSystemVariables(t *testing.T) {
dsess := DefaultSession(EmptyDatabaseProvider())
dsess := DefaultSession(emptyDatabaseProvider())
conf := config.NewMapConfig(map[string]string{"max_connections": "1000"})
dsess = dsess.WithGlobals(conf)
@@ -237,3 +242,55 @@ func TestGetPersistedValue(t *testing.T) {
})
}
}
func emptyDatabaseProvider() DoltDatabaseProvider {
return emptyRevisionDatabaseProvider{}
}
type emptyRevisionDatabaseProvider struct {
sql.DatabaseProvider
}
func (e emptyRevisionDatabaseProvider) DbState(ctx *sql.Context, dbName string, defaultBranch string) (InitialDbState, error) {
return InitialDbState{}, sql.ErrDatabaseNotFound.New(dbName)
}
func (e emptyRevisionDatabaseProvider) DropDatabase(ctx *sql.Context, name string) error {
return nil
}
func (e emptyRevisionDatabaseProvider) GetRevisionForRevisionDatabase(_ *sql.Context, _ string) (string, string, error) {
return "", "", nil
}
func (e emptyRevisionDatabaseProvider) IsRevisionDatabase(_ *sql.Context, _ string) (bool, error) {
return false, nil
}
func (e emptyRevisionDatabaseProvider) GetRemoteDB(ctx context.Context, format *types.NomsBinFormat, r env.Remote, withCaching bool) (*doltdb.DoltDB, error) {
return nil, nil
}
func (e emptyRevisionDatabaseProvider) FileSystem() filesys.Filesys {
return nil
}
func (e emptyRevisionDatabaseProvider) FileSystemForDatabase(dbname string) (filesys.Filesys, error) {
return nil, nil
}
func (e emptyRevisionDatabaseProvider) CloneDatabaseFromRemote(ctx *sql.Context, dbName, branch, remoteName, remoteUrl string, remoteParams map[string]string) error {
return nil
}
func (e emptyRevisionDatabaseProvider) CreateDatabase(ctx *sql.Context, dbName string) error {
return nil
}
func (e emptyRevisionDatabaseProvider) RevisionDbState(_ *sql.Context, revDB string) (InitialDbState, error) {
return InitialDbState{}, sql.ErrDatabaseNotFound.New(revDB)
}
func (e emptyRevisionDatabaseProvider) SessionDatabase(ctx *sql.Context, dbName string) (SessionDatabase, bool, error) {
return nil, false, nil
}
+77 -74
View File
@@ -144,32 +144,19 @@ func (d *DoltSession) lookupDbState(ctx *sql.Context, dbName string) (*DatabaseS
}
// TODO: this needs to include the transaction's snapshot of the DB at tx start time
var init InitialDbState
var err error
_, val, ok := sql.SystemVariables.GetGlobal(DefaultBranchKey(dbName))
initialBranch := ""
if ok {
initialBranch = val.(string)
}
// First attempt to find a bare database (no revision spec)
init, err = d.provider.DbState(ctx, dbName, initialBranch)
if err != nil && !sql.ErrDatabaseNotFound.Is(err) {
database, ok, err := d.provider.SessionDatabase(ctx, dbName)
if err != nil {
return nil, false, err
}
// If that doesn't work, attempt to parse the database name as a revision spec
if err != nil {
init, err = d.provider.RevisionDbState(ctx, dbName)
if err != nil {
return nil, false, err
}
if !ok {
return nil, false, nil
}
// If we got this far, we have a valid initial database state, so add it to the session for future reuse
if err = d.AddDB(ctx, init); err != nil {
return nil, ok, err
// Add the initial state to the session for future reuse
if err = d.addDB(ctx, database); err != nil {
return nil, false, err
}
d.mu.Lock()
@@ -236,12 +223,12 @@ func (d *DoltSession) ValidateSession(ctx *sql.Context, dbName string) error {
return d.validateErr
}
sessionState, ok, err := d.LookupDbState(ctx, dbName)
if !ok {
return nil
}
if err != nil {
return err
}
if !ok {
return nil
}
if sessionState.WorkingSet == nil {
return nil
}
@@ -279,24 +266,16 @@ func (d *DoltSession) StartTransaction(ctx *sql.Context, tCharacteristic sql.Tra
// Since StartTransaction occurs before even any analysis, it's possible that this session has no state for the
// database with the transaction being performed, so we load it here.
if !d.HasDB(ctx, dbName) {
db, err := d.provider.Database(ctx, dbName)
db, ok, err := d.provider.SessionDatabase(ctx, dbName)
if err != nil {
return nil, err
}
sdb, ok := db.(SessionDatabase)
if !ok {
return nil, fmt.Errorf("database %s does not support sessions", dbName)
return nil, sql.ErrDatabaseNotFound.New(dbName)
}
// TODO: this needs a real branch name
init, err := sdb.InitialDBState(ctx, "")
if err != nil {
return nil, err
}
// TODO: make this take a DB, not a DBState
err = d.AddDB(ctx, init)
err = d.addDB(ctx, db)
if err != nil {
return nil, err
}
@@ -1113,13 +1092,9 @@ func (d *DoltSession) HasDB(_ *sql.Context, dbName string) bool {
return ok
}
// AddDB adds the database given to this session. This establishes a starting root value for this session, as well as
// addDB adds the database given to this session. This establishes a starting root value for this session, as well as
// other state tracking metadata.
// TODO: the session has a database provider, we shouldn't need to add databases to it explicitly, this should be
//
// internal only
func (d *DoltSession) AddDB(ctx *sql.Context, dbState InitialDbState) error {
db := dbState.Db
func (d *DoltSession) addDB(ctx *sql.Context, db SessionDatabase) error {
DefineSystemVariablesForDB(db.Name())
sessionState := NewEmptyDatabaseSessionState()
@@ -1129,6 +1104,18 @@ func (d *DoltSession) AddDB(ctx *sql.Context, dbState InitialDbState) error {
sessionState.dbName = db.Name()
sessionState.db = db
_, val, ok := sql.SystemVariables.GetGlobal(DefaultBranchKey(db.Name()))
initialBranch := ""
if ok {
initialBranch = val.(string)
}
// TODO: the branch should be already set if the DB was specified with a branch revision string
dbState, err := db.InitialDBState(ctx, initialBranch)
if err != nil {
return err
}
// TODO: get rid of all repo state reader / writer stuff. Until we do, swap out the reader with one of our own, and
// the writer with one that errors out
// TODO: this no longer gets called at session creation time, so the error handling below never occurs when a
@@ -1148,23 +1135,26 @@ func (d *DoltSession) AddDB(ctx *sql.Context, dbState InitialDbState) error {
sessionState.readOnly, sessionState.readReplica = dbState.ReadOnly, dbState.ReadReplica
// TODO: figure out how to cast this to dsqle.SqlDatabase without creating import cycles
// Or better yet, get rid of EditOptions from the database, it's a session setting
nbf := types.Format_Default
if sessionState.dbData.Ddb != nil {
nbf = sessionState.dbData.Ddb.Format()
}
editOpts := db.(interface{ EditOptions() editor.Options }).EditOptions()
stateProvider, ok := db.(globalstate.StateProvider)
if !ok {
return fmt.Errorf("database does not contain global state store")
}
sessionState.globalState = stateProvider.GetGlobalState()
// WorkingSet is nil in the case of a read only, detached head DB
if dbState.Err != nil {
sessionState.Err = dbState.Err
} else if dbState.WorkingSet != nil {
sessionState.WorkingSet = dbState.WorkingSet
// TODO: this is pretty clunky, there is a silly dependency between InitialDbState and globalstate.StateProvider
// that's hard to express with the current types
stateProvider, ok := db.(globalstate.StateProvider)
if !ok {
return fmt.Errorf("database does not contain global state store")
}
sessionState.globalState = stateProvider.GetGlobalState()
tracker, err := sessionState.globalState.GetAutoIncrementTracker(ctx)
if err != nil {
return err
@@ -1173,13 +1163,15 @@ func (d *DoltSession) AddDB(ctx *sql.Context, dbState InitialDbState) error {
if err = d.SetWorkingSet(ctx, db.Name(), dbState.WorkingSet); err != nil {
return err
}
} else {
} else if dbState.HeadCommit != nil {
// WorkingSet is nil in the case of a read only, detached head DB
headRoot, err := dbState.HeadCommit.GetRootValue(ctx)
if err != nil {
return err
}
sessionState.headRoot = headRoot
} else if dbState.HeadRoot != nil {
sessionState.headRoot = dbState.HeadRoot
}
// This has to happen after SetRoot above, since it does a stale check before its work
@@ -1257,6 +1249,8 @@ func (d *DoltSession) setSessionVarsForDb(ctx *sql.Context, dbName string) error
return err
}
// Different DBs have different requirements for what state is set, so we are maximally permissive on what's expected
// in the state object here
if state.WorkingSet != nil {
headRef, err := state.WorkingSet.Ref().ToHeadRef()
if err != nil {
@@ -1271,31 +1265,37 @@ func (d *DoltSession) setSessionVarsForDb(ctx *sql.Context, dbName string) error
roots := state.GetRoots()
h, err := roots.Working.HashOf()
if err != nil {
return err
}
err = d.Session.SetSessionVariable(ctx, WorkingKey(dbName), h.String())
if err != nil {
return err
if roots.Working != nil {
h, err := roots.Working.HashOf()
if err != nil {
return err
}
err = d.Session.SetSessionVariable(ctx, WorkingKey(dbName), h.String())
if err != nil {
return err
}
}
h, err = roots.Staged.HashOf()
if err != nil {
return err
}
err = d.Session.SetSessionVariable(ctx, StagedKey(dbName), h.String())
if err != nil {
return err
if roots.Staged != nil {
h, err := roots.Staged.HashOf()
if err != nil {
return err
}
err = d.Session.SetSessionVariable(ctx, StagedKey(dbName), h.String())
if err != nil {
return err
}
}
h, err = state.headCommit.HashOf()
if err != nil {
return err
}
err = d.Session.SetSessionVariable(ctx, HeadKey(dbName), h.String())
if err != nil {
return err
if state.headCommit != nil {
h, err := state.headCommit.HashOf()
if err != nil {
return err
}
err = d.Session.SetSessionVariable(ctx, HeadKey(dbName), h.String())
if err != nil {
return err
}
}
return nil
@@ -1382,14 +1382,17 @@ func (d *DoltSession) SystemVariablesInConfig() ([]sql.SystemVariable, error) {
func (d *DoltSession) GetBranch() (string, error) {
ctx := sql.NewContext(context.Background(), sql.WithSession(d))
currentDb := d.Session.GetCurrentDatabase()
// no branch if there's no current db
if currentDb == "" {
return "", nil
}
dbState, _, err := d.LookupDbState(ctx, currentDb)
if err != nil {
if len(currentDb) == 0 && sql.ErrDatabaseNotFound.Is(err) {
// Some operations return an empty database (namely tests), so we return an empty branch in such cases
return "", nil
}
return "", err
}
if dbState.WorkingSet != nil {
branchRef, err := dbState.WorkingSet.Ref().ToHeadRef()
if err != nil {
@@ -18,7 +18,6 @@ import (
"context"
"github.com/dolthub/go-mysql-server/sql"
"gopkg.in/src-d/go-errors.v1"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
@@ -26,25 +25,6 @@ import (
"github.com/dolthub/dolt/go/store/types"
)
// ErrRevisionDbNotFound is thrown when a RevisionDatabaseProvider cannot find a specified revision database.
var ErrRevisionDbNotFound = errors.NewKind("revision database not found: '%s'")
// RevisionDatabaseProvider provides revision databases.
// In Dolt, commits and branches can be accessed as discrete databases
// using a Dolt-specific syntax: `my_database/my_branch`. Revision databases
// corresponding to historical commits in the repository will be read-only
// databases. Revision databases for branches will be read/write.
type RevisionDatabaseProvider interface {
// RevisionDbState provides the InitialDbState for a revision database.
RevisionDbState(ctx *sql.Context, revDB string) (InitialDbState, error)
// IsRevisionDatabase validates the specified dbName and returns true if it is a valid revision database.
IsRevisionDatabase(ctx *sql.Context, dbName string) (bool, error)
// GetRevisionForRevisionDatabase looks up the named database and returns the root database name as well as the
// revision and any errors encountered. If the specified database is not a revision database, the root database
// name will still be returned, and the revision will be an empty string.
GetRevisionForRevisionDatabase(ctx *sql.Context, dbName string) (string, string, error)
}
// RevisionDatabase allows callers to query a revision database for the commit, branch, or tag it is pinned to. For
// example, when using a database with a branch revision specification, that database is only able to use that branch
// and cannot change branch heads. Calling `Revision` on that database will return the branch name. Similarly, for
@@ -56,8 +36,24 @@ type RevisionDatabase interface {
// revision specifications (e.g. "HEAD~2") are not supported. If a database implements RevisionDatabase, but
// is not pinned to a specific revision, the empty string is returned.
Revision() string
// RevisionType returns the type of revision this database is pinned to.
RevisionType() RevisionType
// BaseName returns the name of the database without the revision specifier. E.g.if the database is named
// "myDB/master", BaseName returns "myDB".
BaseName() string
}
// RevisionType represents the type of revision a database is pinned to. For branches and tags, the revision is a
// string naming that branch or tag. For other revision specs, e.g. "HEAD~2", the revision is a commit hash.
type RevisionType int
const (
RevisionTypeNone RevisionType = iota
RevisionTypeBranch
RevisionTypeTag
RevisionTypeCommit
)
// RemoteReadReplicaDatabase is a database that pulls from a connected remote when a transaction begins.
type RemoteReadReplicaDatabase interface {
// ValidReplicaState returns whether this read replica is in a valid state to pull from the remote
@@ -68,9 +64,6 @@ type RemoteReadReplicaDatabase interface {
type DoltDatabaseProvider interface {
sql.MutableDatabaseProvider
RevisionDatabaseProvider
// env.RemoteDbProvider
// FileSystem returns the filesystem used by this provider, rooted at the data directory for all databases.
FileSystem() filesys.Filesys
// FileSystemForDatabase returns a filesystem, with the working directory set to the root directory
@@ -87,56 +80,7 @@ type DoltDatabaseProvider interface {
// (otherwise all branches are cloned), remoteName is the name for the remote created in the new database, and
// remoteUrl is a URL (e.g. "file:///dbs/db1") or an <org>/<database> path indicating a database hosted on DoltHub.
CloneDatabaseFromRemote(ctx *sql.Context, dbName, branch, remoteName, remoteUrl string, remoteParams map[string]string) error
// DbState returns the InitialDbState for the specified database and given branch. An empty branch name should use
// the default branch for the repository.
// TODO: make this use an ok bool instead of relying on sql.DatabaseNotFound errors
DbState(ctx *sql.Context, dbName string, defaultBranch string) (InitialDbState, error)
}
func EmptyDatabaseProvider() DoltDatabaseProvider {
return emptyRevisionDatabaseProvider{}
}
type emptyRevisionDatabaseProvider struct {
sql.DatabaseProvider
}
func (e emptyRevisionDatabaseProvider) DbState(ctx *sql.Context, dbName string, defaultBranch string) (InitialDbState, error) {
return InitialDbState{}, sql.ErrDatabaseNotFound.New(dbName)
}
func (e emptyRevisionDatabaseProvider) DropDatabase(ctx *sql.Context, name string) error {
return nil
}
func (e emptyRevisionDatabaseProvider) GetRevisionForRevisionDatabase(_ *sql.Context, _ string) (string, string, error) {
return "", "", nil
}
func (e emptyRevisionDatabaseProvider) IsRevisionDatabase(_ *sql.Context, _ string) (bool, error) {
return false, nil
}
func (e emptyRevisionDatabaseProvider) GetRemoteDB(ctx context.Context, format *types.NomsBinFormat, r env.Remote, withCaching bool) (*doltdb.DoltDB, error) {
return nil, nil
}
func (e emptyRevisionDatabaseProvider) FileSystem() filesys.Filesys {
return nil
}
func (e emptyRevisionDatabaseProvider) FileSystemForDatabase(dbname string) (filesys.Filesys, error) {
return nil, nil
}
func (e emptyRevisionDatabaseProvider) CloneDatabaseFromRemote(ctx *sql.Context, dbName, branch, remoteName, remoteUrl string, remoteParams map[string]string) error {
return nil
}
func (e emptyRevisionDatabaseProvider) CreateDatabase(ctx *sql.Context, dbName string) error {
return nil
}
func (e emptyRevisionDatabaseProvider) RevisionDbState(_ *sql.Context, revDB string) (InitialDbState, error) {
return InitialDbState{}, sql.ErrDatabaseNotFound.New(revDB)
// SessionDatabase returns the SessionDatabase for the specified database, which may name a revision of a base
// database.
SessionDatabase(ctx *sql.Context, dbName string) (SessionDatabase, bool, error)
}
@@ -2324,6 +2324,7 @@ func skipPreparedTests(t *testing.T) {
func newSessionBuilder(harness *DoltHarness) server.SessionBuilder {
return func(ctx context.Context, conn *mysql.Conn, host string) (sql.Session, error) {
return harness.session, nil
newCtx := harness.NewSession()
return newCtx.Session, nil
}
}
@@ -367,16 +367,10 @@ func (d *DoltHarness) Close() {
func (d *DoltHarness) closeProvider() {
if d.provider != nil {
dbs := sqle.AllDbs(sql.NewEmptyContext(), d.provider)
dbs := d.provider.AllDatabases(sql.NewEmptyContext())
for _, db := range dbs {
d.t.Logf("closing %v", db)
require.NoError(d.t, db.DbData().Ddb.Close())
}
}
if d.session != nil {
dbs := sqle.AllDbs(sql.NewEmptyContext(), d.session.Provider())
for _, db := range dbs {
d.t.Logf("session had database %v", db)
require.NoError(d.t, db.(sqle.SqlDatabase).DbData().Ddb.Close())
}
}
}
@@ -33,7 +33,6 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/dtestutils"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
)
@@ -183,8 +182,7 @@ var (
)
func TestDoltIndexEqual(t *testing.T) {
ddb, ctx, root, indexMap := doltIndexSetup(t)
defer ddb.Close()
root, indexMap := doltIndexSetup(t)
tests := []doltIndexTestCase{
{
@@ -298,6 +296,7 @@ func TestDoltIndexEqual(t *testing.T) {
for _, test := range tests {
t.Run(fmt.Sprintf("%s|%v", test.indexName, test.keys), func(t *testing.T) {
ctx := sql.NewEmptyContext()
idx, ok := indexMap[test.indexName]
require.True(t, ok)
testDoltIndex(t, ctx, root, test.keys, test.expectedRows, idx, indexComp_Eq)
@@ -306,8 +305,7 @@ func TestDoltIndexEqual(t *testing.T) {
}
func TestDoltIndexGreaterThan(t *testing.T) {
ddb, ctx, root, indexMap := doltIndexSetup(t)
defer ddb.Close()
root, indexMap := doltIndexSetup(t)
tests := []struct {
indexName string
@@ -440,6 +438,7 @@ func TestDoltIndexGreaterThan(t *testing.T) {
for _, test := range tests {
t.Run(fmt.Sprintf("%s|%v", test.indexName, test.keys), func(t *testing.T) {
ctx := sql.NewEmptyContext()
index, ok := indexMap[test.indexName]
require.True(t, ok)
testDoltIndex(t, ctx, root, test.keys, test.expectedRows, index, indexComp_Gt)
@@ -448,8 +447,7 @@ func TestDoltIndexGreaterThan(t *testing.T) {
}
func TestDoltIndexGreaterThanOrEqual(t *testing.T) {
ddb, ctx, root, indexMap := doltIndexSetup(t)
defer ddb.Close()
root, indexMap := doltIndexSetup(t)
tests := []struct {
indexName string
@@ -578,6 +576,7 @@ func TestDoltIndexGreaterThanOrEqual(t *testing.T) {
for _, test := range tests {
t.Run(fmt.Sprintf("%s|%v", test.indexName, test.keys), func(t *testing.T) {
ctx := sql.NewEmptyContext()
index, ok := indexMap[test.indexName]
require.True(t, ok)
testDoltIndex(t, ctx, root, test.keys, test.expectedRows, index, indexComp_GtE)
@@ -586,8 +585,7 @@ func TestDoltIndexGreaterThanOrEqual(t *testing.T) {
}
func TestDoltIndexLessThan(t *testing.T) {
ddb, ctx, root, indexMap := doltIndexSetup(t)
defer ddb.Close()
root, indexMap := doltIndexSetup(t)
tests := []struct {
indexName string
@@ -725,6 +723,7 @@ func TestDoltIndexLessThan(t *testing.T) {
for _, test := range tests {
t.Run(fmt.Sprintf("%s|%v", test.indexName, test.keys), func(t *testing.T) {
ctx := sql.NewEmptyContext()
index, ok := indexMap[test.indexName]
require.True(t, ok)
testDoltIndex(t, ctx, root, test.keys, test.expectedRows, index, indexComp_Lt)
@@ -733,8 +732,7 @@ func TestDoltIndexLessThan(t *testing.T) {
}
func TestDoltIndexLessThanOrEqual(t *testing.T) {
ddb, ctx, root, indexMap := doltIndexSetup(t)
defer ddb.Close()
root, indexMap := doltIndexSetup(t)
tests := []struct {
indexName string
@@ -873,6 +871,7 @@ func TestDoltIndexLessThanOrEqual(t *testing.T) {
for _, test := range tests {
t.Run(fmt.Sprintf("%s|%v", test.indexName, test.keys), func(t *testing.T) {
ctx := sql.NewEmptyContext()
index, ok := indexMap[test.indexName]
require.True(t, ok)
testDoltIndex(t, ctx, root, test.keys, test.expectedRows, index, indexComp_LtE)
@@ -881,8 +880,7 @@ func TestDoltIndexLessThanOrEqual(t *testing.T) {
}
func TestDoltIndexBetween(t *testing.T) {
ddb, ctx, root, indexMap := doltIndexSetup(t)
defer ddb.Close()
root, indexMap := doltIndexSetup(t)
tests := []doltIndexBetweenTestCase{
{
@@ -1052,6 +1050,8 @@ func TestDoltIndexBetween(t *testing.T) {
for _, test := range tests {
t.Run(fmt.Sprintf("%s|%v%v", test.indexName, test.greaterThanOrEqual, test.lessThanOrEqual), func(t *testing.T) {
ctx := sql.NewEmptyContext()
idx, ok := indexMap[test.indexName]
require.True(t, ok)
@@ -1292,6 +1292,7 @@ func requireUnorderedRowsEqual(t *testing.T, s sql.Schema, rows1, rows2 []sql.Ro
}
func testDoltIndex(t *testing.T, ctx *sql.Context, root *doltdb.RootValue, keys []interface{}, expectedRows []sql.Row, idx index.DoltIndex, cmp indexComp) {
ctx = sql.NewEmptyContext()
exprs := idx.Expressions()
builder := sql.NewIndexBuilder(idx)
for i, key := range keys {
@@ -1335,8 +1336,8 @@ func testDoltIndex(t *testing.T, ctx *sql.Context, root *doltdb.RootValue, keys
requireUnorderedRowsEqual(t, pkSch.Schema, convertSqlRowToInt64(expectedRows), readRows)
}
func doltIndexSetup(t *testing.T) (*doltdb.DoltDB, *sql.Context, *doltdb.RootValue, map[string]index.DoltIndex) {
ctx := NewTestSQLCtx(context.Background())
func doltIndexSetup(t *testing.T) (*doltdb.RootValue, map[string]index.DoltIndex) {
ctx := context.Background()
dEnv := dtestutils.CreateTestEnv()
root, err := dEnv.WorkingRoot(ctx)
if err != nil {
@@ -1408,18 +1409,7 @@ INSERT INTO types VALUES (1, 4, '2020-05-14 12:00:03', 1.1, 'd', 1.1, 'a,c', '00
}
}
return dEnv.DoltDB, ctx, root, indexMap
}
func NewTestSQLCtx(ctx context.Context) *sql.Context {
s := dsess.DefaultSession(dsess.EmptyDatabaseProvider())
s.SetCurrentDatabase("dolt")
sqlCtx := sql.NewContext(
ctx,
sql.WithSession(s),
)
return sqlCtx
return root, indexMap
}
func mustTime(timeString string) time.Time {
@@ -24,7 +24,6 @@ import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/stretchr/testify/require"
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/dtestutils"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
@@ -32,11 +31,10 @@ import (
dsqle "github.com/dolthub/dolt/go/libraries/doltcore/sqle"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/utils/config"
"github.com/dolthub/dolt/go/store/types"
)
func setupIndexes(t *testing.T, tableName, insertQuery string) (*sqle.Engine, *env.DoltEnv, *doltdb.RootValue, dsqle.Database, []*indexTuple) {
func setupIndexes(t *testing.T, tableName, insertQuery string) (*sqle.Engine, *sql.Context, []*indexTuple) {
dEnv := dtestutils.CreateTestEnv()
tmpDir, err := dEnv.TempTableFilesDir()
require.NoError(t, err)
@@ -102,21 +100,12 @@ func setupIndexes(t *testing.T, tableName, insertQuery string) (*sqle.Engine, *e
b := env.GetDefaultInitBranch(dEnv.Config)
pro, err := dsqle.NewDoltDatabaseProviderWithDatabase(b, mrEnv.FileSystem(), db, dEnv.FS)
if err != nil {
return nil, nil, nil, dsqle.Database{}, nil
return nil, nil, nil
}
pro = pro.WithDbFactoryUrl(doltdb.InMemDoltDB)
engine = sqle.NewDefault(pro)
// Get an updated root to use for the rest of the test
ctx := sql.NewEmptyContext()
controller := branch_control.CreateDefaultController()
sess, err := dsess.NewDoltSession(ctx.Session.(*sql.BaseSession), pro, config.NewEmptyMapConfig(), controller)
require.NoError(t, err)
roots, ok := sess.GetRoots(ctx, db.Name())
require.True(t, ok)
err = sess.SetRoot(sqlCtx, db.Name(), roots.Working)
it := []*indexTuple{
idxv1ToTuple,
idxv2v1ToTuple,
@@ -126,7 +115,7 @@ func setupIndexes(t *testing.T, tableName, insertQuery string) (*sqle.Engine, *e
},
}
return engine, dEnv, roots.Working, db, it
return engine, sqlCtx, it
}
// indexTuple converts integers into the appropriate tuple for comparison against ranges
@@ -15,7 +15,6 @@
package index_test
import (
"context"
"fmt"
"testing"
@@ -27,7 +26,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
"github.com/dolthub/dolt/go/libraries/doltcore/table/typed/noms"
"github.com/dolthub/dolt/go/store/types"
@@ -42,7 +40,7 @@ func TestMergeableIndexes(t *testing.T) {
t.Skip() // this test is specific to Noms ranges
}
engine, denv, root, db, indexTuples := setupIndexes(t, "test", `INSERT INTO test VALUES
engine, sqlCtx, indexTuples := setupIndexes(t, "test", `INSERT INTO test VALUES
(-3, NULL, NULL),
(-2, NULL, NULL),
(-1, NULL, NULL),
@@ -1316,16 +1314,6 @@ func TestMergeableIndexes(t *testing.T) {
for _, test := range tests {
t.Run(test.whereStmt, func(t *testing.T) {
ctx := context.Background()
sqlCtx := NewTestSQLCtx(ctx)
session := dsess.DSessFromSess(sqlCtx.Session)
dbState := getDbState(t, db, denv)
err := session.AddDB(sqlCtx, dbState)
require.NoError(t, err)
sqlCtx.SetCurrentDatabase(db.Name())
err = session.SetRoot(sqlCtx, db.Name(), root)
require.NoError(t, err)
query := fmt.Sprintf(`SELECT pk FROM test WHERE %s ORDER BY 1`, test.whereStmt)
finalRanges, err := ReadRangesFromQuery(sqlCtx, engine, query)
@@ -1382,7 +1370,7 @@ func TestMergeableIndexesNulls(t *testing.T) {
t.Skip() // this test is specific to Noms ranges
}
engine, denv, root, db, indexTuples := setupIndexes(t, "test", `INSERT INTO test VALUES
engine, sqlCtx, indexTuples := setupIndexes(t, "test", `INSERT INTO test VALUES
(0, 10, 20),
(1, 11, 21),
(2, NULL, NULL),
@@ -1532,16 +1520,6 @@ func TestMergeableIndexesNulls(t *testing.T) {
for _, test := range tests {
t.Run(test.whereStmt, func(t *testing.T) {
ctx := context.Background()
sqlCtx := NewTestSQLCtx(ctx)
session := dsess.DSessFromSess(sqlCtx.Session)
dbState := getDbState(t, db, denv)
err := session.AddDB(sqlCtx, dbState)
require.NoError(t, err)
sqlCtx.SetCurrentDatabase(db.Name())
err = session.SetRoot(sqlCtx, db.Name(), root)
require.NoError(t, err)
query := fmt.Sprintf(`SELECT pk FROM test WHERE %s ORDER BY 1`, test.whereStmt)
finalRanges, err := ReadRangesFromQuery(sqlCtx, engine, query)
@@ -52,9 +52,9 @@ type DoltHarness struct {
}
func (h *DoltHarness) Close() {
dbs := dsql.AllDbs(sql.NewEmptyContext(), h.sess.Provider())
dbs := h.sess.Provider().AllDatabases(sql.NewEmptyContext())
for _, db := range dbs {
db.DbData().Ddb.Close()
db.(dsql.SqlDatabase).DbData().Ddb.Close()
}
}
@@ -96,8 +96,8 @@ func (rrd ReadReplicaDatabase) ValidReplicaState(ctx *sql.Context) bool {
// InitialDBState implements dsess.SessionDatabase
// This seems like a pointless override from the embedded Database implementation, but it's necessary to pass the
// correct pointer type to the session initializer.
func (rrd ReadReplicaDatabase) InitialDBState(ctx context.Context, branch string) (dsess.InitialDbState, error) {
return GetInitialDBState(ctx, rrd, branch)
func (rrd ReadReplicaDatabase) InitialDBState(ctx *sql.Context, branch string) (dsess.InitialDbState, error) {
return initialDBState(ctx, rrd, branch)
}
func (rrd ReadReplicaDatabase) PullFromRemote(ctx *sql.Context) error {
@@ -27,27 +27,21 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/dtestutils"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/json"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
)
func TestSchemaTableMigrationOriginal(t *testing.T) {
ctx := NewTestSQLCtx(context.Background())
dEnv := dtestutils.CreateTestEnv()
tmpDir, err := dEnv.TempTableFilesDir()
require.NoError(t, err)
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: tmpDir}
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(), opts)
db, err := NewDatabase(context.Background(), "dolt", dEnv.DbData(), opts)
require.NoError(t, err)
dbState, err := getDbState(db, dEnv)
_, ctx, err := NewTestEngine(dEnv, context.Background(), db)
require.NoError(t, err)
err = dsess.DSessFromSess(ctx.Session).AddDB(ctx, dbState)
require.NoError(t, err)
ctx.SetCurrentDatabase(db.Name())
err = db.createSqlTable(ctx, doltdb.SchemasTableName, sql.NewPrimaryKeySchema(sql.Schema{ // original schema of dolt_schemas table
{Name: doltdb.SchemasTablesTypeCol, Type: gmstypes.Text, Source: doltdb.SchemasTableName, PrimaryKey: true},
{Name: doltdb.SchemasTablesNameCol, Type: gmstypes.Text, Source: doltdb.SchemasTableName, PrimaryKey: true},
@@ -94,21 +88,16 @@ func TestSchemaTableMigrationOriginal(t *testing.T) {
}
func TestSchemaTableMigrationV1(t *testing.T) {
ctx := NewTestSQLCtx(context.Background())
dEnv := dtestutils.CreateTestEnv()
tmpDir, err := dEnv.TempTableFilesDir()
require.NoError(t, err)
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: tmpDir}
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(), opts)
db, err := NewDatabase(context.Background(), "dolt", dEnv.DbData(), opts)
require.NoError(t, err)
dbState, err := getDbState(db, dEnv)
_, ctx, err := NewTestEngine(dEnv, context.Background(), db)
require.NoError(t, err)
err = dsess.DSessFromSess(ctx.Session).AddDB(ctx, dbState)
require.NoError(t, err)
ctx.SetCurrentDatabase(db.Name())
// original schema of dolt_schemas table with the ID column
err = db.createSqlTable(ctx, doltdb.SchemasTableName, sql.NewPrimaryKeySchema(sql.Schema{
{Name: doltdb.SchemasTablesTypeCol, Type: gmstypes.Text, Source: doltdb.SchemasTableName, PrimaryKey: false},
@@ -1,321 +0,0 @@
// Copyright 2020 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqle
import (
"context"
"fmt"
"strings"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
)
// SingleTableInfoDatabase is intended to allow a sole schema to make use of any display functionality in `go-mysql-server`.
// For example, you may have constructed a schema that you want a CREATE TABLE statement for, but the schema is not
// persisted or is temporary. This allows `go-mysql-server` to interact with that sole schema as though it were a database.
// No write operations will work with this database.
type SingleTableInfoDatabase struct {
tableName string
sch schema.Schema
foreignKeys []doltdb.ForeignKey
parentSchs map[string]schema.Schema
}
var _ doltReadOnlyTableInterface = (*SingleTableInfoDatabase)(nil)
var _ sql.IndexedTable = (*SingleTableInfoDatabase)(nil)
var _ SqlDatabase = (*SingleTableInfoDatabase)(nil)
func NewSingleTableDatabase(tableName string, sch schema.Schema, foreignKeys []doltdb.ForeignKey, parentSchs map[string]schema.Schema) *SingleTableInfoDatabase {
return &SingleTableInfoDatabase{
tableName: tableName,
sch: sch,
foreignKeys: foreignKeys,
parentSchs: parentSchs,
}
}
// Name implements sql.Table and sql.Database.
func (db *SingleTableInfoDatabase) Name() string {
return db.tableName
}
// GetTableInsensitive implements sql.Database.
func (db *SingleTableInfoDatabase) GetTableInsensitive(ctx *sql.Context, tableName string) (sql.Table, bool, error) {
if strings.ToLower(tableName) == strings.ToLower(db.tableName) {
return db, true, nil
}
return nil, false, nil
}
// GetTableNames implements sql.Database.
func (db *SingleTableInfoDatabase) GetTableNames(ctx *sql.Context) ([]string, error) {
return []string{db.tableName}, nil
}
// String implements sql.Table.
func (db *SingleTableInfoDatabase) String() string {
return db.tableName
}
// Schema implements sql.Table.
func (db *SingleTableInfoDatabase) Schema() sql.Schema {
sqlSch, err := sqlutil.FromDoltSchema(db.tableName, db.sch)
if err != nil {
}
return sqlSch.Schema
}
// Collation implements sql.Table.
func (db *SingleTableInfoDatabase) Collation() sql.CollationID {
return sql.CollationID(db.sch.GetCollation())
}
func (db *SingleTableInfoDatabase) InitialDBState(ctx context.Context, branch string) (dsess.InitialDbState, error) {
return getInitialDBStateForUserSpaceDb(ctx, db)
}
// Partitions implements sql.Table.
func (db *SingleTableInfoDatabase) Partitions(*sql.Context) (sql.PartitionIter, error) {
return nil, fmt.Errorf("cannot get paritions of a single table information database")
}
// PartitionRows implements sql.Table.
func (db *SingleTableInfoDatabase) PartitionRows(*sql.Context, sql.Partition) (sql.RowIter, error) {
return nil, fmt.Errorf("cannot get partition rows of a single table information database")
}
func (db *SingleTableInfoDatabase) PartitionRows2(ctx *sql.Context, part sql.Partition) (sql.RowIter2, error) {
return nil, fmt.Errorf("cannot get partition rows of a single table information database")
}
func (db *SingleTableInfoDatabase) LookupPartitions(context *sql.Context, lookup sql.IndexLookup) (sql.PartitionIter, error) {
return nil, fmt.Errorf("cannot get paritions of a single table information database")
}
// CreateIndexForForeignKey implements sql.ForeignKeyTable.
func (db *SingleTableInfoDatabase) CreateIndexForForeignKey(ctx *sql.Context, idx sql.IndexDef) error {
return fmt.Errorf("cannot create foreign keys on a single table information database")
}
// GetDeclaredForeignKeys implements sql.ForeignKeyTable.
func (db *SingleTableInfoDatabase) GetDeclaredForeignKeys(ctx *sql.Context) ([]sql.ForeignKeyConstraint, error) {
fks := make([]sql.ForeignKeyConstraint, len(db.foreignKeys))
for i, fk := range db.foreignKeys {
if !fk.IsResolved() {
fks[i] = sql.ForeignKeyConstraint{
Name: fk.Name,
Database: ctx.GetCurrentDatabase(),
Table: fk.TableName,
Columns: fk.UnresolvedFKDetails.TableColumns,
ParentDatabase: ctx.GetCurrentDatabase(),
ParentTable: fk.ReferencedTableName,
ParentColumns: fk.UnresolvedFKDetails.ReferencedTableColumns,
OnUpdate: toReferentialAction(fk.OnUpdate),
OnDelete: toReferentialAction(fk.OnDelete),
}
continue
}
if parentSch, ok := db.parentSchs[fk.ReferencedTableName]; ok {
var err error
fks[i], err = toForeignKeyConstraint(fk, ctx.GetCurrentDatabase(), db.sch, parentSch)
if err != nil {
return nil, err
}
} else {
// We can skip here since the given schema may be purposefully incomplete (such as with diffs).
continue
}
}
return fks, nil
}
// GetReferencedForeignKeys implements sql.ForeignKeyTable.
func (db *SingleTableInfoDatabase) GetReferencedForeignKeys(ctx *sql.Context) ([]sql.ForeignKeyConstraint, error) {
return nil, nil
}
// AddForeignKey implements sql.ForeignKeyTable.
func (db *SingleTableInfoDatabase) AddForeignKey(ctx *sql.Context, fk sql.ForeignKeyConstraint) error {
return fmt.Errorf("cannot create foreign keys on a single table information database")
}
// DropForeignKey implements sql.ForeignKeyTable.
func (db *SingleTableInfoDatabase) DropForeignKey(ctx *sql.Context, fkName string) error {
return fmt.Errorf("cannot create foreign keys on a single table information database")
}
// UpdateForeignKey implements sql.ForeignKeyTable.
func (db *SingleTableInfoDatabase) UpdateForeignKey(ctx *sql.Context, fkName string, fk sql.ForeignKeyConstraint) error {
return fmt.Errorf("cannot create foreign keys on a single table information database")
}
// GetForeignKeyEditor implements sql.ForeignKeyTable.
func (db *SingleTableInfoDatabase) GetForeignKeyEditor(ctx *sql.Context) sql.ForeignKeyEditor {
return nil
}
// IndexedAccess implements sql.IndexedTable.
func (db *SingleTableInfoDatabase) IndexedAccess(lookup sql.IndexLookup) sql.IndexedTable {
return db
}
// GetIndexes implements sql.IndexedTable.
func (db *SingleTableInfoDatabase) GetIndexes(ctx *sql.Context) ([]sql.Index, error) {
var sqlIndexes []sql.Index
for _, idx := range db.sch.Indexes().AllIndexes() {
cols := make([]schema.Column, idx.Count())
for i, tag := range idx.IndexedColumnTags() {
cols[i], _ = idx.GetColumn(tag)
}
sqlIndexes = append(sqlIndexes, &fmtIndex{
id: idx.Name(),
db: db.Name(),
tbl: db.tableName,
cols: cols,
unique: idx.IsUnique(),
generated: false,
comment: idx.Comment(),
})
}
return sqlIndexes, nil
}
func (db *SingleTableInfoDatabase) GetChecks(ctx *sql.Context) ([]sql.CheckDefinition, error) {
return checksInSchema(db.sch), nil
}
func (db *SingleTableInfoDatabase) IsTemporary() bool {
return false
}
func (db *SingleTableInfoDatabase) DataLength(ctx *sql.Context) (uint64, error) {
// TODO: to answer this accurately, we need the table as well as the schema
return 0, nil
}
func (db *SingleTableInfoDatabase) RowCount(ctx *sql.Context) (uint64, error) {
return 0, nil
}
func (db *SingleTableInfoDatabase) PrimaryKeySchema() sql.PrimaryKeySchema {
sqlSch, err := sqlutil.FromDoltSchema(db.tableName, db.sch)
if err != nil {
}
return sqlSch
}
func (db *SingleTableInfoDatabase) GetRoot(context *sql.Context) (*doltdb.RootValue, error) {
return nil, nil
}
func (db *SingleTableInfoDatabase) DbData() env.DbData {
panic("SingleTableInfoDatabase doesn't have DbData")
}
func (db *SingleTableInfoDatabase) Flush(context *sql.Context) error {
panic("SingleTableInfoDatabase cannot Flush")
}
// fmtIndex is used for CREATE TABLE statements only.
type fmtIndex struct {
id string
db string
tbl string
cols []schema.Column
unique bool
spatial bool
generated bool
comment string
}
// CanSupport implements sql.Index
func (idx fmtIndex) CanSupport(r ...sql.Range) bool {
return true
}
// ID implements sql.Index
func (idx fmtIndex) ID() string {
return idx.id
}
// Database implements sql.Index
func (idx fmtIndex) Database() string {
return idx.db
}
// Table implements sql.Index
func (idx fmtIndex) Table() string {
return idx.tbl
}
// Expressions implements sql.Index
func (idx fmtIndex) Expressions() []string {
strs := make([]string, len(idx.cols))
for i, col := range idx.cols {
strs[i] = idx.tbl + "." + col.Name
}
return strs
}
// IsUnique implements sql.Index
func (idx fmtIndex) IsUnique() bool {
return idx.unique
}
// IsSpatial implements sql.Index
func (idx fmtIndex) IsSpatial() bool {
return idx.spatial
}
// Comment implements sql.Index
func (idx fmtIndex) Comment() string {
return idx.comment
}
// PrefixLengths implements sql.Index
func (idx fmtIndex) PrefixLengths() []uint16 {
return nil
}
// IndexType implements sql.Index
func (idx fmtIndex) IndexType() string {
return "BTREE"
}
// IsGenerated implements sql.Index
func (idx fmtIndex) IsGenerated() bool {
return idx.generated
}
func (idx fmtIndex) IndexedAccess(_ sql.IndexLookup) (sql.IndexedTable, error) {
panic("unimplemented")
}
// ColumnExpressionTypes implements sql.Index
func (idx fmtIndex) ColumnExpressionTypes() []sql.ColumnExpressionType {
panic("unimplemented")
}
func (db *SingleTableInfoDatabase) EditOptions() editor.Options {
return editor.Options{}
}
-39
View File
@@ -113,11 +113,6 @@ func ExecuteSql(dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*
return db.GetRoot(ctx)
}
// NewTestSQLCtx returns a new *sql.Context with a default DoltSession, a new IndexRegistry, and a new ViewRegistry
func NewTestSQLCtx(ctx context.Context) *sql.Context {
return NewTestSQLCtxWithProvider(ctx, dsess.EmptyDatabaseProvider())
}
func NewTestSQLCtxWithProvider(ctx context.Context, pro dsess.DoltDatabaseProvider) *sql.Context {
s, err := dsess.NewDoltSession(sql.NewBaseSession(), pro, config2.NewMapConfig(make(map[string]string)), branch_control.CreateDefaultController())
if err != nil {
@@ -141,44 +136,10 @@ func NewTestEngine(dEnv *env.DoltEnv, ctx context.Context, db SqlDatabase) (*sql
engine := sqle.NewDefault(pro)
sqlCtx := NewTestSQLCtxWithProvider(ctx, pro)
dbState, err := getDbState(db, dEnv)
if err != nil {
return nil, nil, err
}
err = dsess.DSessFromSess(sqlCtx.Session).AddDB(sqlCtx, dbState)
if err != nil {
return nil, nil, err
}
sqlCtx.SetCurrentDatabase(db.Name())
return engine, sqlCtx, nil
}
func getDbState(db sql.Database, dEnv *env.DoltEnv) (dsess.InitialDbState, error) {
ctx := context.Background()
head := dEnv.RepoStateReader().CWBHeadSpec()
headCommit, err := dEnv.DoltDB.Resolve(ctx, head, dEnv.RepoStateReader().CWBHeadRef())
if err != nil {
return dsess.InitialDbState{}, err
}
ws, err := dEnv.WorkingSet(ctx)
if err != nil {
return dsess.InitialDbState{}, err
}
return dsess.InitialDbState{
Db: db,
HeadCommit: headCommit,
WorkingSet: ws,
DbData: dEnv.DbData(),
Remotes: dEnv.RepoState.Remotes,
}, nil
}
// ExecuteSelect executes the select statement given and returns the resulting rows, or an error if one is encountered.
func ExecuteSelect(dEnv *env.DoltEnv, root *doltdb.RootValue, query string) ([]sql.Row, error) {
dbData := env.DbData{
@@ -15,8 +15,6 @@
package sqle
import (
"context"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
@@ -78,8 +76,15 @@ func (db *UserSpaceDatabase) GetTableNames(ctx *sql.Context) ([]string, error) {
return resultingTblNames, nil
}
func (db *UserSpaceDatabase) InitialDBState(ctx context.Context, branch string) (dsess.InitialDbState, error) {
return getInitialDBStateForUserSpaceDb(ctx, db)
func (db *UserSpaceDatabase) InitialDBState(ctx *sql.Context, branch string) (dsess.InitialDbState, error) {
return dsess.InitialDbState{
Db: db,
ReadOnly: true,
HeadRoot: db.RootValue,
DbData: env.DbData{
Rsw: noopRepoStateWriter{},
},
}, nil
}
func (db *UserSpaceDatabase) GetRoot(*sql.Context) (*doltdb.RootValue, error) {
@@ -101,3 +106,15 @@ func (db *UserSpaceDatabase) Flush(ctx *sql.Context) error {
func (db *UserSpaceDatabase) EditOptions() editor.Options {
return db.editOpts
}
func (db *UserSpaceDatabase) Revision() string {
return ""
}
func (db *UserSpaceDatabase) RevisionType() dsess.RevisionType {
return dsess.RevisionTypeNone
}
func (db *UserSpaceDatabase) BaseName() string {
return db.Name()
}
@@ -22,11 +22,9 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/row"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
@@ -42,8 +40,6 @@ type tableEditorTest struct {
selectQuery string
// The rows this query should return, nil if an error is expected
expectedRows []sql.Row
// Expected error string, if any
expectedErr string
}
func TestTableEditor(t *testing.T) {
@@ -60,7 +56,6 @@ func TestTableEditor(t *testing.T) {
fatTony := sqle.NewPeopleRow(16, "Fat", "Tony", false, 53, 5.0)
troyMclure := sqle.NewPeopleRow(17, "Troy", "McClure", false, 58, 7.0)
var expectedErr error
// Some of these are pretty exotic use cases, but since we support all these operations it's nice to know they work
// in tandem.
testCases := []tableEditorTest{
@@ -159,27 +154,18 @@ func TestTableEditor(t *testing.T) {
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
expectedErr = nil
dEnv, err := sqle.CreateTestDatabase()
require.NoError(t, err)
ctx := sqle.NewTestSQLCtx(context.Background())
root, err := dEnv.WorkingRoot(context.Background())
require.NoError(t, err)
tmpDir, err := dEnv.TempTableFilesDir()
require.NoError(t, err)
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: tmpDir}
db, err := sqle.NewDatabase(ctx, "dolt", dEnv.DbData(), opts)
db, err := sqle.NewDatabase(context.Background(), "dolt", dEnv.DbData(), opts)
require.NoError(t, err)
err = dsess.DSessFromSess(ctx.Session).AddDB(ctx, getDbState(t, db, dEnv))
engine, ctx, err := sqle.NewTestEngine(dEnv, context.Background(), db)
require.NoError(t, err)
ctx.SetCurrentDatabase(db.Name())
err = db.SetRoot(ctx, root)
require.NoError(t, err)
peopleTable, _, err := db.GetTableInsensitive(ctx, "people")
require.NoError(t, err)
@@ -187,20 +173,18 @@ func TestTableEditor(t *testing.T) {
ed := dt.Updater(ctx).(writer.TableWriter)
test.setup(ctx, t, ed)
if len(test.expectedErr) > 0 {
require.Error(t, expectedErr)
assert.Contains(t, expectedErr.Error(), test.expectedErr)
return
} else {
require.NoError(t, ed.Close(ctx))
}
require.NoError(t, ed.Close(ctx))
root, err = db.GetRoot(ctx)
root, err := db.GetRoot(ctx)
require.NoError(t, err)
// TODO: not clear why this is necessary, the call to ed.Close should update the working set already
require.NoError(t, dEnv.UpdateWorkingRoot(context.Background(), root))
actualRows, err := sqle.ExecuteSelect(dEnv, root, test.selectQuery)
sch, rowIter, err := engine.Query(ctx, test.selectQuery)
require.NoError(t, err)
actualRows, err := sql.RowIterToRows(ctx, sch, rowIter)
require.NoError(t, err)
assert.Equal(t, test.expectedRows, actualRows)
@@ -215,22 +199,3 @@ func r(r row.Row, sch schema.Schema) sql.Row {
}
return sqlRow
}
func getDbState(t *testing.T, db sql.Database, dEnv *env.DoltEnv) dsess.InitialDbState {
ctx := context.Background()
head := dEnv.RepoStateReader().CWBHeadSpec()
headCommit, err := dEnv.DoltDB.Resolve(ctx, head, dEnv.RepoStateReader().CWBHeadRef())
require.NoError(t, err)
ws, err := dEnv.WorkingSet(ctx)
require.NoError(t, err)
return dsess.InitialDbState{
Db: db,
HeadCommit: headCommit,
WorkingSet: ws,
DbData: dEnv.DbData(),
Remotes: dEnv.RepoState.Remotes,
}
}
@@ -275,7 +275,7 @@ func RunQueryAttempt(t require.TestingT, conn *sql.Conn, q driver.Query) {
defer rows.Close()
}
if q.ErrorMatch != "" {
require.Error(t, err)
require.Error(t, err, "expected error running query %s", q.Query)
require.Regexp(t, q.ErrorMatch, err.Error())
return
}