mirror of
https://github.com/dolthub/dolt.git
synced 2026-04-26 03:30:09 -05:00
Merge pull request #5604 from dolthub/zachmu/multidb
[no-release-notes] Refactored db / session initialization logic
This commit is contained in:
@@ -621,7 +621,7 @@ func diffUserTable(
|
||||
}
|
||||
|
||||
if dArgs.diffParts&SchemaOnlyDiff != 0 {
|
||||
err := dw.WriteTableSchemaDiff(ctx, dArgs.toRoot, td)
|
||||
err := dw.WriteTableSchemaDiff(ctx, dArgs.fromRoot, dArgs.toRoot, td)
|
||||
if err != nil {
|
||||
return errhand.VerboseErrorFromError(err)
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/typed/json"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/untyped/sqlexport"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/untyped/tabular"
|
||||
@@ -44,7 +45,7 @@ type diffWriter interface {
|
||||
// BeginTable is called when a new table is about to be written, before any schema or row diffs are written
|
||||
BeginTable(ctx context.Context, td diff.TableDelta) error
|
||||
// WriteTableSchemaDiff is called to write a schema diff for the table given (if requested by args)
|
||||
WriteTableSchemaDiff(ctx context.Context, toRoot *doltdb.RootValue, td diff.TableDelta) error
|
||||
WriteTableSchemaDiff(ctx context.Context, fromRoot *doltdb.RootValue, toRoot *doltdb.RootValue, td diff.TableDelta) error
|
||||
// WriteTriggerDiff is called to write a trigger diff
|
||||
WriteTriggerDiff(ctx context.Context, triggerName, oldDefn, newDefn string) error
|
||||
// WriteViewDiff is called to write a view diff
|
||||
@@ -221,17 +222,12 @@ func (t tabularDiffWriter) BeginTable(ctx context.Context, td diff.TableDelta) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t tabularDiffWriter) WriteTableSchemaDiff(ctx context.Context, toRoot *doltdb.RootValue, td diff.TableDelta) error {
|
||||
fromSch, toSch, err := td.GetSchemas(ctx)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("cannot retrieve schema for table %s", td.ToName).AddCause(err).Build()
|
||||
}
|
||||
|
||||
func (t tabularDiffWriter) WriteTableSchemaDiff(ctx context.Context, fromRoot *doltdb.RootValue, toRoot *doltdb.RootValue, td diff.TableDelta) error {
|
||||
var fromCreateStmt = ""
|
||||
if td.FromTable != nil {
|
||||
// TODO: use UserSpaceDatabase for these, no reason for this separate database implementation
|
||||
sqlDb := sqle.NewSingleTableDatabase(td.FromName, fromSch, td.FromFks, td.FromFksParentSch)
|
||||
sqlDb := sqle.NewUserSpaceDatabase(fromRoot, editor.Options{})
|
||||
sqlCtx, engine, _ := sqle.PrepareCreateTableStmt(ctx, sqlDb)
|
||||
var err error
|
||||
fromCreateStmt, err = sqle.GetCreateTableStmt(sqlCtx, engine, td.FromName)
|
||||
if err != nil {
|
||||
return errhand.VerboseErrorFromError(err)
|
||||
@@ -240,8 +236,9 @@ func (t tabularDiffWriter) WriteTableSchemaDiff(ctx context.Context, toRoot *dol
|
||||
|
||||
var toCreateStmt = ""
|
||||
if td.ToTable != nil {
|
||||
sqlDb := sqle.NewSingleTableDatabase(td.ToName, toSch, td.ToFks, td.ToFksParentSch)
|
||||
sqlDb := sqle.NewUserSpaceDatabase(toRoot, editor.Options{})
|
||||
sqlCtx, engine, _ := sqle.PrepareCreateTableStmt(ctx, sqlDb)
|
||||
var err error
|
||||
toCreateStmt, err = sqle.GetCreateTableStmt(sqlCtx, engine, td.ToName)
|
||||
if err != nil {
|
||||
return errhand.VerboseErrorFromError(err)
|
||||
@@ -298,7 +295,7 @@ func (s sqlDiffWriter) BeginTable(ctx context.Context, td diff.TableDelta) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s sqlDiffWriter) WriteTableSchemaDiff(ctx context.Context, toRoot *doltdb.RootValue, td diff.TableDelta) error {
|
||||
func (s sqlDiffWriter) WriteTableSchemaDiff(ctx context.Context, fromRoot *doltdb.RootValue, toRoot *doltdb.RootValue, td diff.TableDelta) error {
|
||||
toSchemas, err := toRoot.GetAllSchemas(ctx)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("could not read schemas from toRoot").AddCause(err).Build()
|
||||
@@ -417,7 +414,7 @@ func (j *jsonDiffWriter) BeginTable(ctx context.Context, td diff.TableDelta) err
|
||||
return err
|
||||
}
|
||||
|
||||
func (j *jsonDiffWriter) WriteTableSchemaDiff(ctx context.Context, toRoot *doltdb.RootValue, td diff.TableDelta) error {
|
||||
func (j *jsonDiffWriter) WriteTableSchemaDiff(ctx context.Context, fromRoot *doltdb.RootValue, toRoot *doltdb.RootValue, td diff.TableDelta) error {
|
||||
toSchemas, err := toRoot.GetAllSchemas(ctx)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("could not read schemas from toRoot").AddCause(err).Build()
|
||||
|
||||
@@ -306,30 +306,6 @@ func rebaseSqlEngine(ctx context.Context, dEnv *env.DoltEnv, cm *doltdb.Commit)
|
||||
parallelism := runtime.GOMAXPROCS(0)
|
||||
azr := analyzer.NewBuilder(pro).WithParallelism(parallelism).Build()
|
||||
|
||||
head := dEnv.RepoStateReader().CWBHeadSpec()
|
||||
headCommit, err := dEnv.DoltDB.Resolve(ctx, head, dEnv.RepoStateReader().CWBHeadRef())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
ws, err := dEnv.WorkingSet(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
dbState := dsess.InitialDbState{
|
||||
Db: db,
|
||||
HeadCommit: headCommit,
|
||||
WorkingSet: ws,
|
||||
DbData: dEnv.DbData(),
|
||||
Remotes: dEnv.RepoState.Remotes,
|
||||
}
|
||||
|
||||
err = sess.AddDB(sqlCtx, dbState)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
root, err := cm.GetRootValue(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -27,7 +27,6 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/typed/json"
|
||||
@@ -143,15 +142,13 @@ func (cmd CatCmd) Exec(ctx context.Context, commandStr string, args []string, dE
|
||||
}
|
||||
|
||||
func (cmd CatCmd) prettyPrintResults(ctx context.Context, doltSch schema.Schema, idx durable.Index) error {
|
||||
|
||||
wr, err := getTableWriter(cmd.resultFormat, doltSch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer wr.Close(ctx)
|
||||
|
||||
sess := dsess.DefaultSession(dsess.EmptyDatabaseProvider())
|
||||
sqlCtx := sql.NewContext(ctx, sql.WithSession(sess))
|
||||
sqlCtx := sql.NewEmptyContext()
|
||||
|
||||
rowItr, err := table.NewTableIterator(ctx, doltSch, idx, 0)
|
||||
if err != nil {
|
||||
|
||||
@@ -36,6 +36,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/untyped/csv"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/argparser"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/funcitr"
|
||||
@@ -301,8 +302,12 @@ func importSchema(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPars
|
||||
}
|
||||
|
||||
tblName := impArgs.tableName
|
||||
// inferred schemas have no foreign keys
|
||||
sqlDb := sqle.NewSingleTableDatabase(tblName, sch, nil, nil)
|
||||
root, verr = putEmptyTableWithSchema(ctx, tblName, root, sch)
|
||||
if verr != nil {
|
||||
return verr
|
||||
}
|
||||
|
||||
sqlDb := sqle.NewUserSpaceDatabase(root, editor.Options{})
|
||||
sqlCtx, engine, _ := sqle.PrepareCreateTableStmt(ctx, sqlDb)
|
||||
|
||||
stmt, err := sqle.GetCreateTableStmt(sqlCtx, engine, tblName)
|
||||
@@ -312,39 +317,6 @@ func importSchema(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPars
|
||||
cli.Println(stmt)
|
||||
|
||||
if !apr.Contains(dryRunFlag) {
|
||||
tbl, tblExists, err := root.GetTable(ctx, tblName)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("error: failed to get table.").AddCause(err).Build()
|
||||
}
|
||||
|
||||
empty, err := durable.NewEmptyIndex(ctx, root.VRW(), root.NodeStore(), sch)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("error: failed to get table.").AddCause(err).Build()
|
||||
}
|
||||
|
||||
var indexSet durable.IndexSet
|
||||
if tblExists {
|
||||
indexSet, err = tbl.GetIndexSet(ctx)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("error: failed to create table.").AddCause(err).Build()
|
||||
}
|
||||
} else {
|
||||
indexSet, err = durable.NewIndexSetWithEmptyIndexes(ctx, root.VRW(), root.NodeStore(), sch)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("error: failed to get table.").AddCause(err).Build()
|
||||
}
|
||||
}
|
||||
|
||||
tbl, err = doltdb.NewTable(ctx, root.VRW(), root.NodeStore(), sch, empty, indexSet, nil)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("error: failed to get table.").AddCause(err).Build()
|
||||
}
|
||||
|
||||
root, err = root.PutTable(ctx, tblName, tbl)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("error: failed to add table.").AddCause(err).Build()
|
||||
}
|
||||
|
||||
err = dEnv.UpdateWorkingRoot(ctx, root)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("error: failed to update the working set.").AddCause(err).Build()
|
||||
@@ -356,6 +328,43 @@ func importSchema(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPars
|
||||
return nil
|
||||
}
|
||||
|
||||
func putEmptyTableWithSchema(ctx context.Context, tblName string, root *doltdb.RootValue, sch schema.Schema) (*doltdb.RootValue, errhand.VerboseError) {
|
||||
tbl, tblExists, err := root.GetTable(ctx, tblName)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("error: failed to get table.").AddCause(err).Build()
|
||||
}
|
||||
|
||||
empty, err := durable.NewEmptyIndex(ctx, root.VRW(), root.NodeStore(), sch)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("error: failed to get table.").AddCause(err).Build()
|
||||
}
|
||||
|
||||
var indexSet durable.IndexSet
|
||||
if tblExists {
|
||||
indexSet, err = tbl.GetIndexSet(ctx)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("error: failed to create table.").AddCause(err).Build()
|
||||
}
|
||||
} else {
|
||||
indexSet, err = durable.NewIndexSetWithEmptyIndexes(ctx, root.VRW(), root.NodeStore(), sch)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("error: failed to get table.").AddCause(err).Build()
|
||||
}
|
||||
}
|
||||
|
||||
tbl, err = doltdb.NewTable(ctx, root.VRW(), root.NodeStore(), sch, empty, indexSet, nil)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("error: failed to get table.").AddCause(err).Build()
|
||||
}
|
||||
|
||||
root, err = root.PutTable(ctx, tblName, tbl)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("error: failed to add table.").AddCause(err).Build()
|
||||
}
|
||||
|
||||
return root, nil
|
||||
}
|
||||
|
||||
func inferSchemaFromFile(ctx context.Context, nbf *types.NomsBinFormat, impOpts *importOptions, root *doltdb.RootValue) (schema.Schema, errhand.VerboseError) {
|
||||
if impOpts.fileType[0] == '.' {
|
||||
impOpts.fileType = impOpts.fileType[1:]
|
||||
|
||||
@@ -83,7 +83,7 @@ func (database) IsReadOnly() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (db database) InitialDBState(ctx context.Context, branch string) (dsess.InitialDbState, error) {
|
||||
func (db database) InitialDBState(ctx *sql.Context, branch string) (dsess.InitialDbState, error) {
|
||||
// TODO: almost none of this state is actually used, but is necessary because the current session setup requires a
|
||||
// repo state writer
|
||||
return dsess.InitialDbState{
|
||||
@@ -111,6 +111,18 @@ func (db database) EditOptions() editor.Options {
|
||||
return editor.Options{}
|
||||
}
|
||||
|
||||
func (db database) Revision() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (db database) RevisionType() dsess.RevisionType {
|
||||
return dsess.RevisionTypeNone
|
||||
}
|
||||
|
||||
func (db database) BaseName() string {
|
||||
return db.Name()
|
||||
}
|
||||
|
||||
type noopRepoStateWriter struct{}
|
||||
|
||||
func (n noopRepoStateWriter) UpdateStagedRoot(ctx context.Context, newRoot *doltdb.RootValue) error {
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/dolthub/go-mysql-server/sql/mysql_db"
|
||||
"github.com/dolthub/go-mysql-server/sql/parse"
|
||||
"github.com/dolthub/go-mysql-server/sql/plan"
|
||||
"github.com/dolthub/go-mysql-server/sql/types"
|
||||
@@ -50,6 +49,7 @@ var ErrSystemTableAlter = errors.NewKind("Cannot alter table %s: system tables c
|
||||
type SqlDatabase interface {
|
||||
sql.Database
|
||||
dsess.SessionDatabase
|
||||
dsess.RevisionDatabase
|
||||
|
||||
// TODO: get rid of this, it's managed by the session, not the DB
|
||||
GetRoot(*sql.Context) (*doltdb.RootValue, error)
|
||||
@@ -58,34 +58,6 @@ type SqlDatabase interface {
|
||||
EditOptions() editor.Options
|
||||
}
|
||||
|
||||
// AllDbs returns all the databases in the given provider.
|
||||
func AllDbs(ctx *sql.Context, pro sql.DatabaseProvider) []SqlDatabase {
|
||||
dbs := pro.AllDatabases(ctx)
|
||||
dsqlDBs := make([]SqlDatabase, 0, len(dbs))
|
||||
for _, db := range dbs {
|
||||
var sqlDb SqlDatabase
|
||||
if sqlDatabase, ok := db.(SqlDatabase); ok {
|
||||
sqlDb = sqlDatabase
|
||||
} else if privDatabase, ok := db.(mysql_db.PrivilegedDatabase); ok {
|
||||
if sqlDatabase, ok := privDatabase.Unwrap().(SqlDatabase); ok {
|
||||
sqlDb = sqlDatabase
|
||||
}
|
||||
}
|
||||
if sqlDb == nil {
|
||||
continue
|
||||
}
|
||||
switch v := sqlDb.(type) {
|
||||
case ReadReplicaDatabase, Database:
|
||||
dsqlDBs = append(dsqlDBs, v)
|
||||
case ReadOnlyDatabase, *UserSpaceDatabase:
|
||||
default:
|
||||
// esoteric analyzer errors occur if we silently drop databases, usually caused by pointer receivers
|
||||
panic("cannot cast to SqlDatabase")
|
||||
}
|
||||
}
|
||||
return dsqlDBs
|
||||
}
|
||||
|
||||
// Database implements sql.Database for a dolt DB.
|
||||
type Database struct {
|
||||
name string
|
||||
@@ -95,6 +67,7 @@ type Database struct {
|
||||
gs globalstate.GlobalState
|
||||
editOpts editor.Options
|
||||
revision string
|
||||
revType dsess.RevisionType
|
||||
}
|
||||
|
||||
var _ SqlDatabase = Database{}
|
||||
@@ -123,11 +96,24 @@ func (r ReadOnlyDatabase) IsReadOnly() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (r ReadOnlyDatabase) InitialDBState(ctx *sql.Context, branch string) (dsess.InitialDbState, error) {
|
||||
return initialDBState(ctx, r, branch)
|
||||
}
|
||||
|
||||
// Revision implements dsess.RevisionDatabase
|
||||
func (db Database) Revision() string {
|
||||
return db.revision
|
||||
}
|
||||
|
||||
func (db Database) RevisionType() dsess.RevisionType {
|
||||
return db.revType
|
||||
}
|
||||
|
||||
func (db Database) BaseName() string {
|
||||
base, _ := splitRevisionDbName(db)
|
||||
return base
|
||||
}
|
||||
|
||||
func (db Database) EditOptions() editor.Options {
|
||||
return db.editOpts
|
||||
}
|
||||
@@ -149,76 +135,18 @@ func NewDatabase(ctx context.Context, name string, dbData env.DbData, editOpts e
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetInitialDBState returns the InitialDbState for |db|.
|
||||
func GetInitialDBState(ctx context.Context, db SqlDatabase, branch string) (dsess.InitialDbState, error) {
|
||||
switch db := db.(type) {
|
||||
case *UserSpaceDatabase, *SingleTableInfoDatabase:
|
||||
return getInitialDBStateForUserSpaceDb(ctx, db)
|
||||
// initialDBState returns the InitialDbState for |db|. Other implementations of SqlDatabase outside this file should
|
||||
// implement their own method for an initial db state and not rely on this method.
|
||||
func initialDBState(ctx *sql.Context, db SqlDatabase, branch string) (dsess.InitialDbState, error) {
|
||||
if len(db.Revision()) > 0 {
|
||||
return initialStateForRevisionDb(ctx, db)
|
||||
}
|
||||
|
||||
rsr := db.DbData().Rsr
|
||||
ddb := db.DbData().Ddb
|
||||
|
||||
var r ref.DoltRef
|
||||
if len(branch) > 0 {
|
||||
r = ref.NewBranchRef(branch)
|
||||
} else {
|
||||
r = rsr.CWBHeadRef()
|
||||
}
|
||||
|
||||
var retainedErr error
|
||||
|
||||
headCommit, err := ddb.ResolveCommitRef(ctx, r)
|
||||
if err == doltdb.ErrBranchNotFound {
|
||||
retainedErr = err
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
var ws *doltdb.WorkingSet
|
||||
if retainedErr == nil {
|
||||
workingSetRef, err := ref.WorkingSetRefForHead(r)
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
ws, err = db.DbData().Ddb.ResolveWorkingSet(ctx, workingSetRef)
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
}
|
||||
|
||||
remotes, err := rsr.GetRemotes()
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
backups, err := rsr.GetBackups()
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
branches, err := rsr.GetBranches()
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
return dsess.InitialDbState{
|
||||
Db: db,
|
||||
HeadCommit: headCommit,
|
||||
WorkingSet: ws,
|
||||
DbData: db.DbData(),
|
||||
Remotes: remotes,
|
||||
Branches: branches,
|
||||
Backups: backups,
|
||||
Err: retainedErr,
|
||||
}, nil
|
||||
return initialDbState(ctx, db, branch)
|
||||
}
|
||||
|
||||
func (db Database) InitialDBState(ctx context.Context, branch string) (dsess.InitialDbState, error) {
|
||||
return GetInitialDBState(ctx, db, branch)
|
||||
func (db Database) InitialDBState(ctx *sql.Context, branch string) (dsess.InitialDbState, error) {
|
||||
return initialDBState(ctx, db, branch)
|
||||
}
|
||||
|
||||
// Name returns the name of this database, set at creation time.
|
||||
@@ -1425,19 +1353,6 @@ func (db Database) SetCollation(ctx *sql.Context, collation sql.CollationID) err
|
||||
return db.SetRoot(ctx, newRoot)
|
||||
}
|
||||
|
||||
// TODO: this is a hack to make user space DBs appear to the analyzer as full DBs with state etc., but the state is
|
||||
// really skeletal. We need to reexamine the DB / session initialization to make this simpler -- most of these things
|
||||
// aren't needed at initialization time and for most code paths.
|
||||
func getInitialDBStateForUserSpaceDb(ctx context.Context, db SqlDatabase) (dsess.InitialDbState, error) {
|
||||
return dsess.InitialDbState{
|
||||
Db: db,
|
||||
DbData: env.DbData{
|
||||
Rsw: noopRepoStateWriter{},
|
||||
},
|
||||
ReadOnly: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// noopRepoStateWriter is a minimal implementation of RepoStateWriter that does nothing
|
||||
type noopRepoStateWriter struct{}
|
||||
|
||||
|
||||
@@ -184,39 +184,20 @@ func (p DoltDatabaseProvider) FileSystemForDatabase(dbname string) (filesys.File
|
||||
}
|
||||
|
||||
// Database implements the sql.DatabaseProvider interface
|
||||
func (p DoltDatabaseProvider) Database(ctx *sql.Context, name string) (db sql.Database, err error) {
|
||||
var ok bool
|
||||
p.mu.RLock()
|
||||
db, ok = p.databases[formatDbMapKeyName(name)]
|
||||
standby := *p.isStandby
|
||||
p.mu.RUnlock()
|
||||
if ok {
|
||||
return wrapForStandby(db, standby), nil
|
||||
}
|
||||
|
||||
// Revision databases aren't tracked in the map, just instantiated on demand
|
||||
db, _, ok, err = p.databaseForRevision(ctx, name)
|
||||
func (p DoltDatabaseProvider) Database(ctx *sql.Context, name string) (sql.Database, error) {
|
||||
database, b, err := p.SessionDatabase(ctx, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// A final check: if the database doesn't exist and this is a read replica, attempt to clone it from the remote
|
||||
if !ok {
|
||||
db, err = p.databaseForClone(ctx, name)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if db == nil {
|
||||
return nil, sql.ErrDatabaseNotFound.New(name)
|
||||
}
|
||||
if !b {
|
||||
return nil, sql.ErrDatabaseNotFound.New(name)
|
||||
}
|
||||
|
||||
return wrapForStandby(db, standby), nil
|
||||
return database, nil
|
||||
}
|
||||
|
||||
func wrapForStandby(db sql.Database, standby bool) sql.Database {
|
||||
func wrapForStandby(db SqlDatabase, standby bool) SqlDatabase {
|
||||
if !standby {
|
||||
return db
|
||||
}
|
||||
@@ -315,7 +296,7 @@ func (p DoltDatabaseProvider) AllDatabases(ctx *sql.Context) (all []sql.Database
|
||||
|
||||
// If the current database is not one of the primary databases, it must be a transitory revision database
|
||||
if !foundDatabase && currDb != "" {
|
||||
revDb, _, ok, err := p.databaseForRevision(ctx, currDb)
|
||||
revDb, ok, err := p.databaseForRevision(ctx, currDb)
|
||||
if err != nil {
|
||||
// We can't return an error from this interface function, so just log a message
|
||||
ctx.GetLogger().Warnf("unable to load %q as a database revision: %s", ctx.GetCurrentDatabase(), err.Error())
|
||||
@@ -343,7 +324,7 @@ func (p DoltDatabaseProvider) allRevisionDbs(ctx *sql.Context, db SqlDatabase) (
|
||||
|
||||
revDbs := make([]sql.Database, len(branches))
|
||||
for i, branch := range branches {
|
||||
revDb, _, ok, err := p.databaseForRevision(ctx, fmt.Sprintf("%s/%s", db.Name(), branch.GetPath()))
|
||||
revDb, ok, err := p.databaseForRevision(ctx, fmt.Sprintf("%s/%s", db.Name(), branch.GetPath()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -464,12 +445,7 @@ func (p DoltDatabaseProvider) CreateCollatedDatabase(ctx *sql.Context, name stri
|
||||
p.databases[formattedName] = db
|
||||
p.dbLocations[formattedName] = newEnv.FS
|
||||
|
||||
dbstate, err := GetInitialDBState(ctx, db, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return sess.AddDB(ctx, dbstate)
|
||||
return nil
|
||||
}
|
||||
|
||||
type InitDatabaseHook func(ctx *sql.Context, pro DoltDatabaseProvider, name string, env *env.DoltEnv) error
|
||||
@@ -593,7 +569,6 @@ func (p DoltDatabaseProvider) cloneDatabaseFromRemote(
|
||||
Remote: remoteName,
|
||||
})
|
||||
|
||||
sess := dsess.DSessFromSess(ctx.Session)
|
||||
fkChecks, err := ctx.GetSessionVariable(ctx, "foreign_key_checks")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -612,22 +587,12 @@ func (p DoltDatabaseProvider) cloneDatabaseFromRemote(
|
||||
|
||||
p.databases[formatDbMapKeyName(db.Name())] = db
|
||||
|
||||
dbstate, err := GetInitialDBState(ctx, db, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = sess.AddDB(ctx, dbstate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dEnv, nil
|
||||
}
|
||||
|
||||
// DropDatabase implements the sql.MutableDatabaseProvider interface
|
||||
func (p DoltDatabaseProvider) DropDatabase(ctx *sql.Context, name string) error {
|
||||
isRevisionDatabase, err := p.IsRevisionDatabase(ctx, name)
|
||||
isRevisionDatabase, err := p.isRevisionDatabase(ctx, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -715,6 +680,13 @@ func (p DoltDatabaseProvider) DropDatabase(ctx *sql.Context, name string) error
|
||||
// invalidateDbStateInAllSessions removes the db state for this database from every session. This is necessary when a
|
||||
// database is dropped, so that other sessions don't use stale db state.
|
||||
func (p DoltDatabaseProvider) invalidateDbStateInAllSessions(ctx *sql.Context, name string) error {
|
||||
// Remove the db state from the current session
|
||||
err := dsess.DSessFromSess(ctx.Session).RemoveDbState(ctx, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If we have a running server, remove it from other sessions as well
|
||||
runningServer := sqlserver.GetRunningServer()
|
||||
if runningServer != nil {
|
||||
sessionManager := runningServer.SessionManager()
|
||||
@@ -740,9 +712,9 @@ func (p DoltDatabaseProvider) invalidateDbStateInAllSessions(ctx *sql.Context, n
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p DoltDatabaseProvider) databaseForRevision(ctx *sql.Context, revDB string) (sql.Database, dsess.InitialDbState, bool, error) {
|
||||
func (p DoltDatabaseProvider) databaseForRevision(ctx *sql.Context, revDB string) (SqlDatabase, bool, error) {
|
||||
if !strings.Contains(revDB, dbRevisionDelimiter) {
|
||||
return nil, dsess.InitialDbState{}, false, nil
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
parts := strings.SplitN(revDB, dbRevisionDelimiter, 2)
|
||||
@@ -752,95 +724,232 @@ func (p DoltDatabaseProvider) databaseForRevision(ctx *sql.Context, revDB string
|
||||
candidate, ok := p.databases[formatDbMapKeyName(dbName)]
|
||||
p.mu.RUnlock()
|
||||
if !ok {
|
||||
return nil, dsess.InitialDbState{}, false, nil
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
srcDb, ok := candidate.(SqlDatabase)
|
||||
if !ok {
|
||||
return nil, dsess.InitialDbState{}, false, nil
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
resolvedRevSpec, err := p.resolveAncestorSpec(ctx, revSpec, srcDb.DbData().Ddb)
|
||||
dbType, resolvedRevSpec, err := revisionDbType(ctx, srcDb, revSpec)
|
||||
if err != nil {
|
||||
return nil, dsess.InitialDbState{}, false, err
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
caseSensitiveBranchName, isBranch, err := isBranch(ctx, srcDb, resolvedRevSpec)
|
||||
if err != nil {
|
||||
return nil, dsess.InitialDbState{}, false, err
|
||||
}
|
||||
|
||||
if isBranch {
|
||||
switch dbType {
|
||||
case dsess.RevisionTypeBranch:
|
||||
// fetch the upstream head if this is a replicated db
|
||||
if replicaDb, ok := srcDb.(ReadReplicaDatabase); ok {
|
||||
// TODO move this out of analysis phase, should only happen at read time, when the transaction begins (like is
|
||||
// the case with a branch that already exists locally)
|
||||
err := p.ensureReplicaHeadExists(ctx, resolvedRevSpec, replicaDb)
|
||||
if err != nil {
|
||||
return nil, dsess.InitialDbState{}, false, err
|
||||
return nil, false, err
|
||||
}
|
||||
}
|
||||
|
||||
db, init, err := dbRevisionForBranch(ctx, srcDb, caseSensitiveBranchName)
|
||||
db, err := revisionDbForBranch(ctx, srcDb, resolvedRevSpec)
|
||||
// preserve original user case in the case of not found
|
||||
if sql.ErrDatabaseNotFound.Is(err) {
|
||||
return nil, dsess.InitialDbState{}, false, sql.ErrDatabaseNotFound.New(revDB)
|
||||
return nil, false, sql.ErrDatabaseNotFound.New(revDB)
|
||||
} else if err != nil {
|
||||
return nil, dsess.InitialDbState{}, false, err
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return db, init, true, nil
|
||||
return db, true, nil
|
||||
case dsess.RevisionTypeTag:
|
||||
// TODO: this should be an interface, not a struct
|
||||
replicaDb, ok := srcDb.(ReadReplicaDatabase)
|
||||
|
||||
if ok {
|
||||
srcDb = replicaDb.Database
|
||||
}
|
||||
|
||||
srcDb, ok = srcDb.(Database)
|
||||
if !ok {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
db, err := revisionDbForTag(ctx, srcDb.(Database), resolvedRevSpec)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return db, true, nil
|
||||
case dsess.RevisionTypeCommit:
|
||||
// TODO: this should be an interface, not a struct
|
||||
replicaDb, ok := srcDb.(ReadReplicaDatabase)
|
||||
if ok {
|
||||
srcDb = replicaDb.Database
|
||||
}
|
||||
|
||||
srcDb, ok = srcDb.(Database)
|
||||
if !ok {
|
||||
return nil, false, nil
|
||||
}
|
||||
db, err := revisionDbForCommit(ctx, srcDb.(Database), revSpec)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return db, true, nil
|
||||
case dsess.RevisionTypeNone:
|
||||
// not an error, ok = false will get handled as a not found error in a layer above as appropriate
|
||||
return nil, false, nil
|
||||
default:
|
||||
return nil, false, fmt.Errorf("unrecognized revision type for revision spec %s", revSpec)
|
||||
}
|
||||
}
|
||||
|
||||
// revisionDbType returns the type of revision spec given for the database given, and the resolved revision spec
|
||||
func revisionDbType(ctx *sql.Context, srcDb SqlDatabase, revSpec string) (revType dsess.RevisionType, resolvedRevSpec string, err error) {
|
||||
resolvedRevSpec, err = resolveAncestorSpec(ctx, revSpec, srcDb.DbData().Ddb)
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
|
||||
caseSensitiveBranchName, isBranch, err := isBranch(ctx, srcDb, resolvedRevSpec)
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
|
||||
if isBranch {
|
||||
return dsess.RevisionTypeBranch, caseSensitiveBranchName, nil
|
||||
}
|
||||
|
||||
isTag, err := isTag(ctx, srcDb, resolvedRevSpec)
|
||||
if err != nil {
|
||||
return nil, dsess.InitialDbState{}, false, err
|
||||
return 0, "", err
|
||||
}
|
||||
|
||||
if isTag {
|
||||
// TODO: this should be an interface, not a struct
|
||||
replicaDb, ok := srcDb.(ReadReplicaDatabase)
|
||||
|
||||
if ok {
|
||||
srcDb = replicaDb.Database
|
||||
}
|
||||
|
||||
srcDb, ok = srcDb.(Database)
|
||||
if !ok {
|
||||
return nil, dsess.InitialDbState{}, false, nil
|
||||
}
|
||||
|
||||
db, init, err := dbRevisionForTag(ctx, srcDb.(Database), resolvedRevSpec)
|
||||
if err != nil {
|
||||
return nil, dsess.InitialDbState{}, false, err
|
||||
}
|
||||
return db, init, true, nil
|
||||
return dsess.RevisionTypeTag, resolvedRevSpec, nil
|
||||
}
|
||||
|
||||
if doltdb.IsValidCommitHash(resolvedRevSpec) {
|
||||
// TODO: this should be an interface, not a struct
|
||||
replicaDb, ok := srcDb.(ReadReplicaDatabase)
|
||||
if ok {
|
||||
srcDb = replicaDb.Database
|
||||
}
|
||||
|
||||
srcDb, ok = srcDb.(Database)
|
||||
if !ok {
|
||||
return nil, dsess.InitialDbState{}, false, nil
|
||||
}
|
||||
db, init, err := dbRevisionForCommit(ctx, srcDb.(Database), revSpec)
|
||||
if err != nil {
|
||||
return nil, dsess.InitialDbState{}, false, err
|
||||
}
|
||||
return db, init, true, nil
|
||||
return dsess.RevisionTypeCommit, resolvedRevSpec, nil
|
||||
}
|
||||
|
||||
return nil, dsess.InitialDbState{}, false, nil
|
||||
return dsess.RevisionTypeNone, "", nil
|
||||
}
|
||||
|
||||
func initialDbState(ctx context.Context, db SqlDatabase, branch string) (dsess.InitialDbState, error) {
|
||||
rsr := db.DbData().Rsr
|
||||
ddb := db.DbData().Ddb
|
||||
|
||||
var r ref.DoltRef
|
||||
if len(branch) > 0 {
|
||||
r = ref.NewBranchRef(branch)
|
||||
} else {
|
||||
r = rsr.CWBHeadRef()
|
||||
}
|
||||
|
||||
var retainedErr error
|
||||
|
||||
headCommit, err := ddb.ResolveCommitRef(ctx, r)
|
||||
if err == doltdb.ErrBranchNotFound {
|
||||
retainedErr = err
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
var ws *doltdb.WorkingSet
|
||||
if retainedErr == nil {
|
||||
workingSetRef, err := ref.WorkingSetRefForHead(r)
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
ws, err = db.DbData().Ddb.ResolveWorkingSet(ctx, workingSetRef)
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
}
|
||||
|
||||
remotes, err := rsr.GetRemotes()
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
backups, err := rsr.GetBackups()
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
branches, err := rsr.GetBranches()
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
return dsess.InitialDbState{
|
||||
Db: db,
|
||||
HeadCommit: headCommit,
|
||||
WorkingSet: ws,
|
||||
DbData: db.DbData(),
|
||||
Remotes: remotes,
|
||||
Branches: branches,
|
||||
Backups: backups,
|
||||
Err: retainedErr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func initialStateForRevisionDb(ctx *sql.Context, db SqlDatabase) (dsess.InitialDbState, error) {
|
||||
switch db.RevisionType() {
|
||||
case dsess.RevisionTypeBranch:
|
||||
init, err := initialStateForBranchDb(ctx, db)
|
||||
// preserve original user case in the case of not found
|
||||
if sql.ErrDatabaseNotFound.Is(err) {
|
||||
return dsess.InitialDbState{}, sql.ErrDatabaseNotFound.New(db.Name())
|
||||
} else if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
return init, nil
|
||||
case dsess.RevisionTypeTag:
|
||||
// TODO: this should be an interface, not a struct
|
||||
replicaDb, ok := db.(ReadReplicaDatabase)
|
||||
|
||||
if ok {
|
||||
db = replicaDb.Database
|
||||
}
|
||||
|
||||
db, ok = db.(ReadOnlyDatabase)
|
||||
if !ok {
|
||||
return dsess.InitialDbState{}, fmt.Errorf("expected a ReadOnlyDatabase, got %T", db)
|
||||
}
|
||||
|
||||
init, err := initialStateForTagDb(ctx, db.(ReadOnlyDatabase))
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
return init, nil
|
||||
case dsess.RevisionTypeCommit:
|
||||
// TODO: this should be an interface, not a struct
|
||||
replicaDb, ok := db.(ReadReplicaDatabase)
|
||||
if ok {
|
||||
db = replicaDb.Database
|
||||
}
|
||||
|
||||
db, ok = db.(ReadOnlyDatabase)
|
||||
if !ok {
|
||||
return dsess.InitialDbState{}, fmt.Errorf("expected a ReadOnlyDatabase, got %T", db)
|
||||
}
|
||||
|
||||
init, err := initialStateForCommit(ctx, db.(ReadOnlyDatabase))
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
return init, nil
|
||||
default:
|
||||
return dsess.InitialDbState{}, fmt.Errorf("unrecognized revision type for revision spec %s: %v", db.Revision(), db.RevisionType())
|
||||
}
|
||||
}
|
||||
|
||||
// databaseForClone returns a newly cloned database if read replication is enabled and a remote DB exists, or an error
|
||||
// otherwise
|
||||
func (p DoltDatabaseProvider) databaseForClone(ctx *sql.Context, revDB string) (sql.Database, error) {
|
||||
func (p DoltDatabaseProvider) databaseForClone(ctx *sql.Context, revDB string) (SqlDatabase, error) {
|
||||
if !readReplicationActive(ctx) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -860,7 +969,8 @@ func (p DoltDatabaseProvider) databaseForClone(ctx *sql.Context, revDB string) (
|
||||
}
|
||||
|
||||
// now that the database has been cloned, retry the Database call
|
||||
return p.Database(ctx, revDB)
|
||||
database, err := p.Database(ctx, revDB)
|
||||
return database.(SqlDatabase), err
|
||||
}
|
||||
|
||||
// TODO: figure out the right contract: which variables must be set? What happens if they aren't all set?
|
||||
@@ -881,7 +991,7 @@ func readReplicationActive(ctx *sql.Context) bool {
|
||||
// resolveAncestorSpec resolves the specified revSpec to a specific commit hash if it contains an ancestor reference
|
||||
// such as ~ or ^. If no ancestor reference is present, the specified revSpec is returned as is. If any unexpected
|
||||
// problems are encountered, an error is returned.
|
||||
func (p DoltDatabaseProvider) resolveAncestorSpec(ctx *sql.Context, revSpec string, ddb *doltdb.DoltDB) (string, error) {
|
||||
func resolveAncestorSpec(ctx *sql.Context, revSpec string, ddb *doltdb.DoltDB) (string, error) {
|
||||
refname, ancestorSpec, err := doltdb.SplitAncestorSpec(revSpec)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -913,42 +1023,37 @@ func (p DoltDatabaseProvider) resolveAncestorSpec(ctx *sql.Context, revSpec stri
|
||||
return hash.String(), nil
|
||||
}
|
||||
|
||||
func (p DoltDatabaseProvider) RevisionDbState(ctx *sql.Context, revDB string) (dsess.InitialDbState, error) {
|
||||
_, init, ok, err := p.databaseForRevision(ctx, revDB)
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
} else if !ok {
|
||||
return dsess.InitialDbState{}, sql.ErrDatabaseNotFound.New(revDB)
|
||||
}
|
||||
|
||||
return init, nil
|
||||
}
|
||||
|
||||
func (p DoltDatabaseProvider) stateForDatabase(ctx *sql.Context, dbName string, branch string) (dsess.InitialDbState, bool, error) {
|
||||
// SessionDatabase implements dsess.SessionDatabaseProvider
|
||||
func (p DoltDatabaseProvider) SessionDatabase(ctx *sql.Context, name string) (dsess.SessionDatabase, bool, error) {
|
||||
var ok bool
|
||||
p.mu.RLock()
|
||||
db, ok := p.databases[formatDbMapKeyName(dbName)]
|
||||
db, ok := p.databases[formatDbMapKeyName(name)]
|
||||
standby := *p.isStandby
|
||||
p.mu.RUnlock()
|
||||
if ok {
|
||||
return wrapForStandby(db, standby), true, nil
|
||||
}
|
||||
|
||||
// Revision databases aren't tracked in the map, just instantiated on demand
|
||||
db, ok, err := p.databaseForRevision(ctx, name)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
// A final check: if the database doesn't exist and this is a read replica, attempt to clone it from the remote
|
||||
if !ok {
|
||||
return dsess.InitialDbState{}, false, nil
|
||||
db, err = p.databaseForClone(ctx, name)
|
||||
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if db == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
dbState, err := db.InitialDBState(ctx, branch)
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, false, err
|
||||
}
|
||||
|
||||
return dbState, true, nil
|
||||
}
|
||||
|
||||
func (p DoltDatabaseProvider) DbState(ctx *sql.Context, dbName string, defaultBranch string) (dsess.InitialDbState, error) {
|
||||
init, ok, err := p.stateForDatabase(ctx, dbName, defaultBranch)
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
} else if !ok {
|
||||
return dsess.InitialDbState{}, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
|
||||
return init, nil
|
||||
return wrapForStandby(db, standby), true, nil
|
||||
}
|
||||
|
||||
// Function implements the FunctionProvider interface
|
||||
@@ -998,34 +1103,35 @@ func (p DoltDatabaseProvider) TableFunction(_ *sql.Context, name string) (sql.Ta
|
||||
return nil, sql.ErrTableFunctionNotFound.New(name)
|
||||
}
|
||||
|
||||
// GetRevisionForRevisionDatabase implements dsess.RevisionDatabaseProvider
|
||||
func (p DoltDatabaseProvider) GetRevisionForRevisionDatabase(ctx *sql.Context, dbName string) (string, string, error) {
|
||||
db, err := p.Database(ctx, dbName)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
sqldb, ok := db.(dsess.RevisionDatabase)
|
||||
// splitRevisionDbName splits the given database name into its base and revision parts and returns them. Non-revision
|
||||
// DBs use their full name as the base name, and empty string as the revision.
|
||||
func splitRevisionDbName(db sql.Database) (string, string) {
|
||||
sqldb, ok := db.(SqlDatabase)
|
||||
if !ok {
|
||||
return db.Name(), "", nil
|
||||
return db.Name(), ""
|
||||
}
|
||||
|
||||
dbName := db.Name()
|
||||
if sqldb.Revision() != "" {
|
||||
dbName = strings.TrimSuffix(dbName, dbRevisionDelimiter+sqldb.Revision())
|
||||
}
|
||||
|
||||
return dbName, sqldb.Revision(), nil
|
||||
return dbName, sqldb.Revision()
|
||||
}
|
||||
|
||||
// IsRevisionDatabase returns true if the specified dbName represents a database that is tied to a specific
|
||||
// isRevisionDatabase returns true if the specified dbName represents a database that is tied to a specific
|
||||
// branch or commit from a database (e.g. "dolt/branch1").
|
||||
func (p DoltDatabaseProvider) IsRevisionDatabase(ctx *sql.Context, dbName string) (bool, error) {
|
||||
_, revision, err := p.GetRevisionForRevisionDatabase(ctx, dbName)
|
||||
func (p DoltDatabaseProvider) isRevisionDatabase(ctx *sql.Context, dbName string) (bool, error) {
|
||||
db, ok, err := p.SessionDatabase(ctx, dbName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !ok {
|
||||
return false, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
|
||||
return revision != "", nil
|
||||
_, rev := splitRevisionDbName(db)
|
||||
return rev != "", nil
|
||||
}
|
||||
|
||||
// ensureReplicaHeadExists tries to pull the latest version of a remote branch. Will fail if the branch
|
||||
@@ -1036,14 +1142,9 @@ func (p DoltDatabaseProvider) ensureReplicaHeadExists(ctx *sql.Context, branch s
|
||||
|
||||
// isBranch returns whether a branch with the given name is in scope for the database given
|
||||
func isBranch(ctx context.Context, db SqlDatabase, branchName string) (string, bool, error) {
|
||||
var ddbs []*doltdb.DoltDB
|
||||
|
||||
if rdb, ok := db.(ReadReplicaDatabase); ok {
|
||||
ddbs = append(ddbs, rdb.ddb, rdb.srcDB)
|
||||
} else if ddb, ok := db.(Database); ok {
|
||||
ddbs = append(ddbs, ddb.ddb)
|
||||
} else {
|
||||
return "", false, fmt.Errorf("unrecognized type of database %T", db)
|
||||
ddbs, err := doltDbs(db)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
brName, branchExists, err := isLocalBranch(ctx, ddbs, branchName)
|
||||
@@ -1065,6 +1166,21 @@ func isBranch(ctx context.Context, db SqlDatabase, branchName string) (string, b
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
func doltDbs(db SqlDatabase) ([]*doltdb.DoltDB, error) {
|
||||
var ddbs []*doltdb.DoltDB
|
||||
switch db := db.(type) {
|
||||
case ReadReplicaDatabase:
|
||||
ddbs = append(ddbs, db.ddb, db.srcDB)
|
||||
case Database:
|
||||
ddbs = append(ddbs, db.ddb)
|
||||
case ReadOnlyDatabase:
|
||||
ddbs = append(ddbs, db.ddb)
|
||||
default:
|
||||
return nil, fmt.Errorf("unrecognized type of database %T", db)
|
||||
}
|
||||
return ddbs, nil
|
||||
}
|
||||
|
||||
func isLocalBranch(ctx context.Context, ddbs []*doltdb.DoltDB, branchName string) (string, bool, error) {
|
||||
for _, ddb := range ddbs {
|
||||
brName, branchExists, err := ddb.HasBranch(ctx, branchName)
|
||||
@@ -1098,14 +1214,9 @@ func isRemoteBranch(ctx context.Context, ddbs []*doltdb.DoltDB, branchName strin
|
||||
|
||||
// isTag returns whether a tag with the given name is in scope for the database given
|
||||
func isTag(ctx context.Context, db SqlDatabase, tagName string) (bool, error) {
|
||||
var ddbs []*doltdb.DoltDB
|
||||
|
||||
if rdb, ok := db.(ReadReplicaDatabase); ok {
|
||||
ddbs = append(ddbs, rdb.ddb, rdb.srcDB)
|
||||
} else if ddb, ok := db.(Database); ok {
|
||||
ddbs = append(ddbs, ddb.ddb)
|
||||
} else {
|
||||
return false, fmt.Errorf("unrecognized type of database %T", db)
|
||||
ddbs, err := doltDbs(db)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, ddb := range ddbs {
|
||||
@@ -1122,23 +1233,9 @@ func isTag(ctx context.Context, db SqlDatabase, tagName string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func dbRevisionForBranch(ctx context.Context, srcDb SqlDatabase, revSpec string) (SqlDatabase, dsess.InitialDbState, error) {
|
||||
// revisionDbForBranch returns a new database that is tied to the branch named by revSpec
|
||||
func revisionDbForBranch(ctx context.Context, srcDb SqlDatabase, revSpec string) (SqlDatabase, error) {
|
||||
branch := ref.NewBranchRef(revSpec)
|
||||
cm, err := srcDb.DbData().Ddb.ResolveCommitRef(ctx, branch)
|
||||
if err != nil {
|
||||
return Database{}, dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
wsRef, err := ref.WorkingSetRefForHead(branch)
|
||||
if err != nil {
|
||||
return Database{}, dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
ws, err := srcDb.DbData().Ddb.ResolveWorkingSet(ctx, wsRef)
|
||||
if err != nil {
|
||||
return Database{}, dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
dbName := srcDb.Name() + dbRevisionDelimiter + revSpec
|
||||
|
||||
static := staticRepoState{
|
||||
@@ -1159,6 +1256,7 @@ func dbRevisionForBranch(ctx context.Context, srcDb SqlDatabase, revSpec string)
|
||||
gs: v.gs,
|
||||
editOpts: v.editOpts,
|
||||
revision: revSpec,
|
||||
revType: dsess.RevisionTypeBranch,
|
||||
}
|
||||
case ReadReplicaDatabase:
|
||||
db = ReadReplicaDatabase{
|
||||
@@ -1170,6 +1268,7 @@ func dbRevisionForBranch(ctx context.Context, srcDb SqlDatabase, revSpec string)
|
||||
gs: v.gs,
|
||||
editOpts: v.editOpts,
|
||||
revision: revSpec,
|
||||
revType: dsess.RevisionTypeBranch,
|
||||
},
|
||||
remote: v.remote,
|
||||
srcDB: v.srcDB,
|
||||
@@ -1178,23 +1277,51 @@ func dbRevisionForBranch(ctx context.Context, srcDb SqlDatabase, revSpec string)
|
||||
}
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func initialStateForBranchDb(ctx context.Context, srcDb SqlDatabase) (dsess.InitialDbState, error) {
|
||||
_, revSpec := splitRevisionDbName(srcDb)
|
||||
|
||||
branch := ref.NewBranchRef(revSpec)
|
||||
cm, err := srcDb.DbData().Ddb.ResolveCommitRef(ctx, branch)
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
wsRef, err := ref.WorkingSetRefForHead(branch)
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
ws, err := srcDb.DbData().Ddb.ResolveWorkingSet(ctx, wsRef)
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
static := staticRepoState{
|
||||
branch: branch,
|
||||
RepoStateWriter: srcDb.DbData().Rsw,
|
||||
RepoStateReader: srcDb.DbData().Rsr,
|
||||
}
|
||||
|
||||
remotes, err := static.GetRemotes()
|
||||
if err != nil {
|
||||
return nil, dsess.InitialDbState{}, err
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
branches, err := static.GetBranches()
|
||||
if err != nil {
|
||||
return nil, dsess.InitialDbState{}, err
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
backups, err := static.GetBackups()
|
||||
if err != nil {
|
||||
return nil, dsess.InitialDbState{}, err
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
init := dsess.InitialDbState{
|
||||
Db: db,
|
||||
Db: srcDb,
|
||||
HeadCommit: cm,
|
||||
WorkingSet: ws,
|
||||
DbData: env.DbData{
|
||||
@@ -1205,31 +1332,37 @@ func dbRevisionForBranch(ctx context.Context, srcDb SqlDatabase, revSpec string)
|
||||
Remotes: remotes,
|
||||
Branches: branches,
|
||||
Backups: backups,
|
||||
//ReadReplica: //todo
|
||||
}
|
||||
|
||||
return db, init, nil
|
||||
return init, nil
|
||||
}
|
||||
|
||||
func dbRevisionForTag(ctx context.Context, srcDb Database, revSpec string) (ReadOnlyDatabase, dsess.InitialDbState, error) {
|
||||
func revisionDbForTag(ctx context.Context, srcDb Database, revSpec string) (ReadOnlyDatabase, error) {
|
||||
name := srcDb.Name() + dbRevisionDelimiter + revSpec
|
||||
db := ReadOnlyDatabase{Database: Database{
|
||||
name: name,
|
||||
ddb: srcDb.DbData().Ddb,
|
||||
rsw: srcDb.DbData().Rsw,
|
||||
rsr: srcDb.DbData().Rsr,
|
||||
editOpts: srcDb.editOpts,
|
||||
revision: revSpec,
|
||||
revType: dsess.RevisionTypeTag,
|
||||
}}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func initialStateForTagDb(ctx context.Context, srcDb ReadOnlyDatabase) (dsess.InitialDbState, error) {
|
||||
_, revSpec := splitRevisionDbName(srcDb)
|
||||
tag := ref.NewTagRef(revSpec)
|
||||
|
||||
cm, err := srcDb.DbData().Ddb.ResolveCommitRef(ctx, tag)
|
||||
if err != nil {
|
||||
return ReadOnlyDatabase{}, dsess.InitialDbState{}, err
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
name := srcDb.Name() + dbRevisionDelimiter + revSpec
|
||||
db := ReadOnlyDatabase{Database: Database{
|
||||
name: name,
|
||||
ddb: srcDb.DbData().Ddb,
|
||||
rsw: srcDb.DbData().Rsw,
|
||||
rsr: srcDb.DbData().Rsr,
|
||||
editOpts: srcDb.editOpts,
|
||||
revision: revSpec,
|
||||
}}
|
||||
init := dsess.InitialDbState{
|
||||
Db: db,
|
||||
Db: srcDb,
|
||||
HeadCommit: cm,
|
||||
ReadOnly: true,
|
||||
DbData: env.DbData{
|
||||
@@ -1244,32 +1377,39 @@ func dbRevisionForTag(ctx context.Context, srcDb Database, revSpec string) (Read
|
||||
// - ReadReplicas
|
||||
}
|
||||
|
||||
return db, init, nil
|
||||
return init, nil
|
||||
}
|
||||
|
||||
func dbRevisionForCommit(ctx context.Context, srcDb Database, revSpec string) (ReadOnlyDatabase, dsess.InitialDbState, error) {
|
||||
func revisionDbForCommit(ctx context.Context, srcDb Database, revSpec string) (ReadOnlyDatabase, error) {
|
||||
name := srcDb.Name() + dbRevisionDelimiter + revSpec
|
||||
db := ReadOnlyDatabase{Database: Database{
|
||||
name: name,
|
||||
ddb: srcDb.DbData().Ddb,
|
||||
rsw: srcDb.DbData().Rsw,
|
||||
rsr: srcDb.DbData().Rsr,
|
||||
editOpts: srcDb.editOpts,
|
||||
revision: revSpec,
|
||||
revType: dsess.RevisionTypeCommit,
|
||||
}}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func initialStateForCommit(ctx context.Context, srcDb ReadOnlyDatabase) (dsess.InitialDbState, error) {
|
||||
_, revSpec := splitRevisionDbName(srcDb)
|
||||
|
||||
spec, err := doltdb.NewCommitSpec(revSpec)
|
||||
if err != nil {
|
||||
return ReadOnlyDatabase{}, dsess.InitialDbState{}, err
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
cm, err := srcDb.DbData().Ddb.Resolve(ctx, spec, srcDb.DbData().Rsr.CWBHeadRef())
|
||||
if err != nil {
|
||||
return ReadOnlyDatabase{}, dsess.InitialDbState{}, err
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
name := srcDb.Name() + dbRevisionDelimiter + revSpec
|
||||
db := ReadOnlyDatabase{Database: Database{
|
||||
name: name,
|
||||
ddb: srcDb.DbData().Ddb,
|
||||
rsw: srcDb.DbData().Rsw,
|
||||
rsr: srcDb.DbData().Rsr,
|
||||
editOpts: srcDb.editOpts,
|
||||
revision: revSpec,
|
||||
}}
|
||||
|
||||
init := dsess.InitialDbState{
|
||||
Db: db,
|
||||
Db: srcDb,
|
||||
HeadCommit: cm,
|
||||
ReadOnly: true,
|
||||
DbData: env.DbData{
|
||||
@@ -1284,7 +1424,7 @@ func dbRevisionForCommit(ctx context.Context, srcDb Database, revSpec string) (R
|
||||
// - ReadReplicas
|
||||
}
|
||||
|
||||
return db, init, nil
|
||||
return init, nil
|
||||
}
|
||||
|
||||
type staticRepoState struct {
|
||||
|
||||
@@ -178,8 +178,20 @@ func getRevisionForRevisionDatabase(ctx *sql.Context, dbName string) (string, st
|
||||
return "", "", fmt.Errorf("unexpected session type: %T", ctx.Session)
|
||||
}
|
||||
|
||||
provider := doltsess.Provider()
|
||||
return provider.GetRevisionForRevisionDatabase(ctx, dbName)
|
||||
db, ok, err := doltsess.Provider().SessionDatabase(ctx, dbName)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
if !ok {
|
||||
return "", "", sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
|
||||
rdb, ok := db.(dsess.RevisionDatabase)
|
||||
if !ok {
|
||||
return dbName, "", nil
|
||||
}
|
||||
|
||||
return rdb.BaseName(), rdb.Revision(), nil
|
||||
}
|
||||
|
||||
// checkoutRemoteBranch checks out a remote branch creating a new local branch with the same name as the remote branch
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
package dsess
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
@@ -26,11 +24,19 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
)
|
||||
|
||||
// InitialDbState is the initial state of a database, as returned by SessionDatabase.InitialDBState. It is used to
|
||||
// establish the in memory state of the session for every new transaction.
|
||||
type InitialDbState struct {
|
||||
Db sql.Database
|
||||
HeadCommit *doltdb.Commit
|
||||
Db sql.Database
|
||||
// WorkingSet is the working set for this database. May be nil for databases tied to a detached root value, in which
|
||||
// case HeadCommit must be set
|
||||
WorkingSet *doltdb.WorkingSet
|
||||
// The head commit for this database. May be nil for databases tied to a detached root value, in which case
|
||||
// RootValue must be set.
|
||||
HeadCommit *doltdb.Commit
|
||||
// HeadRoot is the root value for databases without a HeadCommit. Nil for databases with a HeadCommit.
|
||||
HeadRoot *doltdb.RootValue
|
||||
ReadOnly bool
|
||||
WorkingSet *doltdb.WorkingSet
|
||||
DbData env.DbData
|
||||
ReadReplica *env.Remote
|
||||
Remotes map[string]env.Remote
|
||||
@@ -48,7 +54,7 @@ type InitialDbState struct {
|
||||
// order for the session to manage it.
|
||||
type SessionDatabase interface {
|
||||
sql.Database
|
||||
InitialDBState(ctx context.Context, branch string) (InitialDbState, error)
|
||||
InitialDBState(ctx *sql.Context, branch string) (InitialDbState, error)
|
||||
}
|
||||
|
||||
type DatabaseSessionState struct {
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
package dsess
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
@@ -22,17 +23,21 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/src-d/go-errors.v1"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/config"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/filesys"
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
func TestDoltSessionInit(t *testing.T) {
|
||||
dsess := DefaultSession(EmptyDatabaseProvider())
|
||||
dsess := DefaultSession(emptyDatabaseProvider())
|
||||
conf := config.NewMapConfig(make(map[string]string))
|
||||
assert.Equal(t, conf, dsess.globalsConf)
|
||||
}
|
||||
|
||||
func TestNewPersistedSystemVariables(t *testing.T) {
|
||||
dsess := DefaultSession(EmptyDatabaseProvider())
|
||||
dsess := DefaultSession(emptyDatabaseProvider())
|
||||
conf := config.NewMapConfig(map[string]string{"max_connections": "1000"})
|
||||
dsess = dsess.WithGlobals(conf)
|
||||
|
||||
@@ -237,3 +242,55 @@ func TestGetPersistedValue(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func emptyDatabaseProvider() DoltDatabaseProvider {
|
||||
return emptyRevisionDatabaseProvider{}
|
||||
}
|
||||
|
||||
type emptyRevisionDatabaseProvider struct {
|
||||
sql.DatabaseProvider
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) DbState(ctx *sql.Context, dbName string, defaultBranch string) (InitialDbState, error) {
|
||||
return InitialDbState{}, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) DropDatabase(ctx *sql.Context, name string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) GetRevisionForRevisionDatabase(_ *sql.Context, _ string) (string, string, error) {
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) IsRevisionDatabase(_ *sql.Context, _ string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) GetRemoteDB(ctx context.Context, format *types.NomsBinFormat, r env.Remote, withCaching bool) (*doltdb.DoltDB, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) FileSystem() filesys.Filesys {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) FileSystemForDatabase(dbname string) (filesys.Filesys, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) CloneDatabaseFromRemote(ctx *sql.Context, dbName, branch, remoteName, remoteUrl string, remoteParams map[string]string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) CreateDatabase(ctx *sql.Context, dbName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) RevisionDbState(_ *sql.Context, revDB string) (InitialDbState, error) {
|
||||
return InitialDbState{}, sql.ErrDatabaseNotFound.New(revDB)
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) SessionDatabase(ctx *sql.Context, dbName string) (SessionDatabase, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
@@ -144,32 +144,19 @@ func (d *DoltSession) lookupDbState(ctx *sql.Context, dbName string) (*DatabaseS
|
||||
}
|
||||
|
||||
// TODO: this needs to include the transaction's snapshot of the DB at tx start time
|
||||
var init InitialDbState
|
||||
var err error
|
||||
|
||||
_, val, ok := sql.SystemVariables.GetGlobal(DefaultBranchKey(dbName))
|
||||
initialBranch := ""
|
||||
if ok {
|
||||
initialBranch = val.(string)
|
||||
}
|
||||
|
||||
// First attempt to find a bare database (no revision spec)
|
||||
init, err = d.provider.DbState(ctx, dbName, initialBranch)
|
||||
if err != nil && !sql.ErrDatabaseNotFound.Is(err) {
|
||||
database, ok, err := d.provider.SessionDatabase(ctx, dbName)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
// If that doesn't work, attempt to parse the database name as a revision spec
|
||||
if err != nil {
|
||||
init, err = d.provider.RevisionDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// If we got this far, we have a valid initial database state, so add it to the session for future reuse
|
||||
if err = d.AddDB(ctx, init); err != nil {
|
||||
return nil, ok, err
|
||||
// Add the initial state to the session for future reuse
|
||||
if err = d.addDB(ctx, database); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
@@ -236,12 +223,12 @@ func (d *DoltSession) ValidateSession(ctx *sql.Context, dbName string) error {
|
||||
return d.validateErr
|
||||
}
|
||||
sessionState, ok, err := d.LookupDbState(ctx, dbName)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if sessionState.WorkingSet == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -279,24 +266,16 @@ func (d *DoltSession) StartTransaction(ctx *sql.Context, tCharacteristic sql.Tra
|
||||
// Since StartTransaction occurs before even any analysis, it's possible that this session has no state for the
|
||||
// database with the transaction being performed, so we load it here.
|
||||
if !d.HasDB(ctx, dbName) {
|
||||
db, err := d.provider.Database(ctx, dbName)
|
||||
db, ok, err := d.provider.SessionDatabase(ctx, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sdb, ok := db.(SessionDatabase)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("database %s does not support sessions", dbName)
|
||||
return nil, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
|
||||
// TODO: this needs a real branch name
|
||||
init, err := sdb.InitialDBState(ctx, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: make this take a DB, not a DBState
|
||||
err = d.AddDB(ctx, init)
|
||||
err = d.addDB(ctx, db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1113,13 +1092,9 @@ func (d *DoltSession) HasDB(_ *sql.Context, dbName string) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
// AddDB adds the database given to this session. This establishes a starting root value for this session, as well as
|
||||
// addDB adds the database given to this session. This establishes a starting root value for this session, as well as
|
||||
// other state tracking metadata.
|
||||
// TODO: the session has a database provider, we shouldn't need to add databases to it explicitly, this should be
|
||||
//
|
||||
// internal only
|
||||
func (d *DoltSession) AddDB(ctx *sql.Context, dbState InitialDbState) error {
|
||||
db := dbState.Db
|
||||
func (d *DoltSession) addDB(ctx *sql.Context, db SessionDatabase) error {
|
||||
DefineSystemVariablesForDB(db.Name())
|
||||
|
||||
sessionState := NewEmptyDatabaseSessionState()
|
||||
@@ -1129,6 +1104,18 @@ func (d *DoltSession) AddDB(ctx *sql.Context, dbState InitialDbState) error {
|
||||
sessionState.dbName = db.Name()
|
||||
sessionState.db = db
|
||||
|
||||
_, val, ok := sql.SystemVariables.GetGlobal(DefaultBranchKey(db.Name()))
|
||||
initialBranch := ""
|
||||
if ok {
|
||||
initialBranch = val.(string)
|
||||
}
|
||||
|
||||
// TODO: the branch should be already set if the DB was specified with a branch revision string
|
||||
dbState, err := db.InitialDBState(ctx, initialBranch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: get rid of all repo state reader / writer stuff. Until we do, swap out the reader with one of our own, and
|
||||
// the writer with one that errors out
|
||||
// TODO: this no longer gets called at session creation time, so the error handling below never occurs when a
|
||||
@@ -1148,23 +1135,26 @@ func (d *DoltSession) AddDB(ctx *sql.Context, dbState InitialDbState) error {
|
||||
sessionState.readOnly, sessionState.readReplica = dbState.ReadOnly, dbState.ReadReplica
|
||||
|
||||
// TODO: figure out how to cast this to dsqle.SqlDatabase without creating import cycles
|
||||
// Or better yet, get rid of EditOptions from the database, it's a session setting
|
||||
nbf := types.Format_Default
|
||||
if sessionState.dbData.Ddb != nil {
|
||||
nbf = sessionState.dbData.Ddb.Format()
|
||||
}
|
||||
editOpts := db.(interface{ EditOptions() editor.Options }).EditOptions()
|
||||
|
||||
stateProvider, ok := db.(globalstate.StateProvider)
|
||||
if !ok {
|
||||
return fmt.Errorf("database does not contain global state store")
|
||||
}
|
||||
sessionState.globalState = stateProvider.GetGlobalState()
|
||||
|
||||
// WorkingSet is nil in the case of a read only, detached head DB
|
||||
if dbState.Err != nil {
|
||||
sessionState.Err = dbState.Err
|
||||
} else if dbState.WorkingSet != nil {
|
||||
sessionState.WorkingSet = dbState.WorkingSet
|
||||
|
||||
// TODO: this is pretty clunky, there is a silly dependency between InitialDbState and globalstate.StateProvider
|
||||
// that's hard to express with the current types
|
||||
stateProvider, ok := db.(globalstate.StateProvider)
|
||||
if !ok {
|
||||
return fmt.Errorf("database does not contain global state store")
|
||||
}
|
||||
sessionState.globalState = stateProvider.GetGlobalState()
|
||||
|
||||
tracker, err := sessionState.globalState.GetAutoIncrementTracker(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1173,13 +1163,15 @@ func (d *DoltSession) AddDB(ctx *sql.Context, dbState InitialDbState) error {
|
||||
if err = d.SetWorkingSet(ctx, db.Name(), dbState.WorkingSet); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
} else {
|
||||
} else if dbState.HeadCommit != nil {
|
||||
// WorkingSet is nil in the case of a read only, detached head DB
|
||||
headRoot, err := dbState.HeadCommit.GetRootValue(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sessionState.headRoot = headRoot
|
||||
} else if dbState.HeadRoot != nil {
|
||||
sessionState.headRoot = dbState.HeadRoot
|
||||
}
|
||||
|
||||
// This has to happen after SetRoot above, since it does a stale check before its work
|
||||
@@ -1257,6 +1249,8 @@ func (d *DoltSession) setSessionVarsForDb(ctx *sql.Context, dbName string) error
|
||||
return err
|
||||
}
|
||||
|
||||
// Different DBs have different requirements for what state is set, so we are maximally permissive on what's expected
|
||||
// in the state object here
|
||||
if state.WorkingSet != nil {
|
||||
headRef, err := state.WorkingSet.Ref().ToHeadRef()
|
||||
if err != nil {
|
||||
@@ -1271,31 +1265,37 @@ func (d *DoltSession) setSessionVarsForDb(ctx *sql.Context, dbName string) error
|
||||
|
||||
roots := state.GetRoots()
|
||||
|
||||
h, err := roots.Working.HashOf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = d.Session.SetSessionVariable(ctx, WorkingKey(dbName), h.String())
|
||||
if err != nil {
|
||||
return err
|
||||
if roots.Working != nil {
|
||||
h, err := roots.Working.HashOf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = d.Session.SetSessionVariable(ctx, WorkingKey(dbName), h.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
h, err = roots.Staged.HashOf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = d.Session.SetSessionVariable(ctx, StagedKey(dbName), h.String())
|
||||
if err != nil {
|
||||
return err
|
||||
if roots.Staged != nil {
|
||||
h, err := roots.Staged.HashOf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = d.Session.SetSessionVariable(ctx, StagedKey(dbName), h.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
h, err = state.headCommit.HashOf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = d.Session.SetSessionVariable(ctx, HeadKey(dbName), h.String())
|
||||
if err != nil {
|
||||
return err
|
||||
if state.headCommit != nil {
|
||||
h, err := state.headCommit.HashOf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = d.Session.SetSessionVariable(ctx, HeadKey(dbName), h.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1382,14 +1382,17 @@ func (d *DoltSession) SystemVariablesInConfig() ([]sql.SystemVariable, error) {
|
||||
func (d *DoltSession) GetBranch() (string, error) {
|
||||
ctx := sql.NewContext(context.Background(), sql.WithSession(d))
|
||||
currentDb := d.Session.GetCurrentDatabase()
|
||||
|
||||
// no branch if there's no current db
|
||||
if currentDb == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
dbState, _, err := d.LookupDbState(ctx, currentDb)
|
||||
if err != nil {
|
||||
if len(currentDb) == 0 && sql.ErrDatabaseNotFound.Is(err) {
|
||||
// Some operations return an empty database (namely tests), so we return an empty branch in such cases
|
||||
return "", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
if dbState.WorkingSet != nil {
|
||||
branchRef, err := dbState.WorkingSet.Ref().ToHeadRef()
|
||||
if err != nil {
|
||||
|
||||
+19
-75
@@ -18,7 +18,6 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"gopkg.in/src-d/go-errors.v1"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
@@ -26,25 +25,6 @@ import (
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
// ErrRevisionDbNotFound is thrown when a RevisionDatabaseProvider cannot find a specified revision database.
|
||||
var ErrRevisionDbNotFound = errors.NewKind("revision database not found: '%s'")
|
||||
|
||||
// RevisionDatabaseProvider provides revision databases.
|
||||
// In Dolt, commits and branches can be accessed as discrete databases
|
||||
// using a Dolt-specific syntax: `my_database/my_branch`. Revision databases
|
||||
// corresponding to historical commits in the repository will be read-only
|
||||
// databases. Revision databases for branches will be read/write.
|
||||
type RevisionDatabaseProvider interface {
|
||||
// RevisionDbState provides the InitialDbState for a revision database.
|
||||
RevisionDbState(ctx *sql.Context, revDB string) (InitialDbState, error)
|
||||
// IsRevisionDatabase validates the specified dbName and returns true if it is a valid revision database.
|
||||
IsRevisionDatabase(ctx *sql.Context, dbName string) (bool, error)
|
||||
// GetRevisionForRevisionDatabase looks up the named database and returns the root database name as well as the
|
||||
// revision and any errors encountered. If the specified database is not a revision database, the root database
|
||||
// name will still be returned, and the revision will be an empty string.
|
||||
GetRevisionForRevisionDatabase(ctx *sql.Context, dbName string) (string, string, error)
|
||||
}
|
||||
|
||||
// RevisionDatabase allows callers to query a revision database for the commit, branch, or tag it is pinned to. For
|
||||
// example, when using a database with a branch revision specification, that database is only able to use that branch
|
||||
// and cannot change branch heads. Calling `Revision` on that database will return the branch name. Similarly, for
|
||||
@@ -56,8 +36,24 @@ type RevisionDatabase interface {
|
||||
// revision specifications (e.g. "HEAD~2") are not supported. If a database implements RevisionDatabase, but
|
||||
// is not pinned to a specific revision, the empty string is returned.
|
||||
Revision() string
|
||||
// RevisionType returns the type of revision this database is pinned to.
|
||||
RevisionType() RevisionType
|
||||
// BaseName returns the name of the database without the revision specifier. E.g.if the database is named
|
||||
// "myDB/master", BaseName returns "myDB".
|
||||
BaseName() string
|
||||
}
|
||||
|
||||
// RevisionType represents the type of revision a database is pinned to. For branches and tags, the revision is a
|
||||
// string naming that branch or tag. For other revision specs, e.g. "HEAD~2", the revision is a commit hash.
|
||||
type RevisionType int
|
||||
|
||||
const (
|
||||
RevisionTypeNone RevisionType = iota
|
||||
RevisionTypeBranch
|
||||
RevisionTypeTag
|
||||
RevisionTypeCommit
|
||||
)
|
||||
|
||||
// RemoteReadReplicaDatabase is a database that pulls from a connected remote when a transaction begins.
|
||||
type RemoteReadReplicaDatabase interface {
|
||||
// ValidReplicaState returns whether this read replica is in a valid state to pull from the remote
|
||||
@@ -68,9 +64,6 @@ type RemoteReadReplicaDatabase interface {
|
||||
|
||||
type DoltDatabaseProvider interface {
|
||||
sql.MutableDatabaseProvider
|
||||
RevisionDatabaseProvider
|
||||
// env.RemoteDbProvider
|
||||
|
||||
// FileSystem returns the filesystem used by this provider, rooted at the data directory for all databases.
|
||||
FileSystem() filesys.Filesys
|
||||
// FileSystemForDatabase returns a filesystem, with the working directory set to the root directory
|
||||
@@ -87,56 +80,7 @@ type DoltDatabaseProvider interface {
|
||||
// (otherwise all branches are cloned), remoteName is the name for the remote created in the new database, and
|
||||
// remoteUrl is a URL (e.g. "file:///dbs/db1") or an <org>/<database> path indicating a database hosted on DoltHub.
|
||||
CloneDatabaseFromRemote(ctx *sql.Context, dbName, branch, remoteName, remoteUrl string, remoteParams map[string]string) error
|
||||
// DbState returns the InitialDbState for the specified database and given branch. An empty branch name should use
|
||||
// the default branch for the repository.
|
||||
// TODO: make this use an ok bool instead of relying on sql.DatabaseNotFound errors
|
||||
DbState(ctx *sql.Context, dbName string, defaultBranch string) (InitialDbState, error)
|
||||
}
|
||||
|
||||
func EmptyDatabaseProvider() DoltDatabaseProvider {
|
||||
return emptyRevisionDatabaseProvider{}
|
||||
}
|
||||
|
||||
type emptyRevisionDatabaseProvider struct {
|
||||
sql.DatabaseProvider
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) DbState(ctx *sql.Context, dbName string, defaultBranch string) (InitialDbState, error) {
|
||||
return InitialDbState{}, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) DropDatabase(ctx *sql.Context, name string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) GetRevisionForRevisionDatabase(_ *sql.Context, _ string) (string, string, error) {
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) IsRevisionDatabase(_ *sql.Context, _ string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) GetRemoteDB(ctx context.Context, format *types.NomsBinFormat, r env.Remote, withCaching bool) (*doltdb.DoltDB, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) FileSystem() filesys.Filesys {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) FileSystemForDatabase(dbname string) (filesys.Filesys, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) CloneDatabaseFromRemote(ctx *sql.Context, dbName, branch, remoteName, remoteUrl string, remoteParams map[string]string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) CreateDatabase(ctx *sql.Context, dbName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e emptyRevisionDatabaseProvider) RevisionDbState(_ *sql.Context, revDB string) (InitialDbState, error) {
|
||||
return InitialDbState{}, sql.ErrDatabaseNotFound.New(revDB)
|
||||
// SessionDatabase returns the SessionDatabase for the specified database, which may name a revision of a base
|
||||
// database.
|
||||
SessionDatabase(ctx *sql.Context, dbName string) (SessionDatabase, bool, error)
|
||||
}
|
||||
@@ -2324,6 +2324,7 @@ func skipPreparedTests(t *testing.T) {
|
||||
|
||||
func newSessionBuilder(harness *DoltHarness) server.SessionBuilder {
|
||||
return func(ctx context.Context, conn *mysql.Conn, host string) (sql.Session, error) {
|
||||
return harness.session, nil
|
||||
newCtx := harness.NewSession()
|
||||
return newCtx.Session, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -367,16 +367,10 @@ func (d *DoltHarness) Close() {
|
||||
|
||||
func (d *DoltHarness) closeProvider() {
|
||||
if d.provider != nil {
|
||||
dbs := sqle.AllDbs(sql.NewEmptyContext(), d.provider)
|
||||
dbs := d.provider.AllDatabases(sql.NewEmptyContext())
|
||||
for _, db := range dbs {
|
||||
d.t.Logf("closing %v", db)
|
||||
require.NoError(d.t, db.DbData().Ddb.Close())
|
||||
}
|
||||
}
|
||||
if d.session != nil {
|
||||
dbs := sqle.AllDbs(sql.NewEmptyContext(), d.session.Provider())
|
||||
for _, db := range dbs {
|
||||
d.t.Logf("session had database %v", db)
|
||||
require.NoError(d.t, db.(sqle.SqlDatabase).DbData().Ddb.Close())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +33,6 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/dtestutils"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
|
||||
)
|
||||
@@ -183,8 +182,7 @@ var (
|
||||
)
|
||||
|
||||
func TestDoltIndexEqual(t *testing.T) {
|
||||
ddb, ctx, root, indexMap := doltIndexSetup(t)
|
||||
defer ddb.Close()
|
||||
root, indexMap := doltIndexSetup(t)
|
||||
|
||||
tests := []doltIndexTestCase{
|
||||
{
|
||||
@@ -298,6 +296,7 @@ func TestDoltIndexEqual(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%s|%v", test.indexName, test.keys), func(t *testing.T) {
|
||||
ctx := sql.NewEmptyContext()
|
||||
idx, ok := indexMap[test.indexName]
|
||||
require.True(t, ok)
|
||||
testDoltIndex(t, ctx, root, test.keys, test.expectedRows, idx, indexComp_Eq)
|
||||
@@ -306,8 +305,7 @@ func TestDoltIndexEqual(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDoltIndexGreaterThan(t *testing.T) {
|
||||
ddb, ctx, root, indexMap := doltIndexSetup(t)
|
||||
defer ddb.Close()
|
||||
root, indexMap := doltIndexSetup(t)
|
||||
|
||||
tests := []struct {
|
||||
indexName string
|
||||
@@ -440,6 +438,7 @@ func TestDoltIndexGreaterThan(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%s|%v", test.indexName, test.keys), func(t *testing.T) {
|
||||
ctx := sql.NewEmptyContext()
|
||||
index, ok := indexMap[test.indexName]
|
||||
require.True(t, ok)
|
||||
testDoltIndex(t, ctx, root, test.keys, test.expectedRows, index, indexComp_Gt)
|
||||
@@ -448,8 +447,7 @@ func TestDoltIndexGreaterThan(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDoltIndexGreaterThanOrEqual(t *testing.T) {
|
||||
ddb, ctx, root, indexMap := doltIndexSetup(t)
|
||||
defer ddb.Close()
|
||||
root, indexMap := doltIndexSetup(t)
|
||||
|
||||
tests := []struct {
|
||||
indexName string
|
||||
@@ -578,6 +576,7 @@ func TestDoltIndexGreaterThanOrEqual(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%s|%v", test.indexName, test.keys), func(t *testing.T) {
|
||||
ctx := sql.NewEmptyContext()
|
||||
index, ok := indexMap[test.indexName]
|
||||
require.True(t, ok)
|
||||
testDoltIndex(t, ctx, root, test.keys, test.expectedRows, index, indexComp_GtE)
|
||||
@@ -586,8 +585,7 @@ func TestDoltIndexGreaterThanOrEqual(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDoltIndexLessThan(t *testing.T) {
|
||||
ddb, ctx, root, indexMap := doltIndexSetup(t)
|
||||
defer ddb.Close()
|
||||
root, indexMap := doltIndexSetup(t)
|
||||
|
||||
tests := []struct {
|
||||
indexName string
|
||||
@@ -725,6 +723,7 @@ func TestDoltIndexLessThan(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%s|%v", test.indexName, test.keys), func(t *testing.T) {
|
||||
ctx := sql.NewEmptyContext()
|
||||
index, ok := indexMap[test.indexName]
|
||||
require.True(t, ok)
|
||||
testDoltIndex(t, ctx, root, test.keys, test.expectedRows, index, indexComp_Lt)
|
||||
@@ -733,8 +732,7 @@ func TestDoltIndexLessThan(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDoltIndexLessThanOrEqual(t *testing.T) {
|
||||
ddb, ctx, root, indexMap := doltIndexSetup(t)
|
||||
defer ddb.Close()
|
||||
root, indexMap := doltIndexSetup(t)
|
||||
|
||||
tests := []struct {
|
||||
indexName string
|
||||
@@ -873,6 +871,7 @@ func TestDoltIndexLessThanOrEqual(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%s|%v", test.indexName, test.keys), func(t *testing.T) {
|
||||
ctx := sql.NewEmptyContext()
|
||||
index, ok := indexMap[test.indexName]
|
||||
require.True(t, ok)
|
||||
testDoltIndex(t, ctx, root, test.keys, test.expectedRows, index, indexComp_LtE)
|
||||
@@ -881,8 +880,7 @@ func TestDoltIndexLessThanOrEqual(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDoltIndexBetween(t *testing.T) {
|
||||
ddb, ctx, root, indexMap := doltIndexSetup(t)
|
||||
defer ddb.Close()
|
||||
root, indexMap := doltIndexSetup(t)
|
||||
|
||||
tests := []doltIndexBetweenTestCase{
|
||||
{
|
||||
@@ -1052,6 +1050,8 @@ func TestDoltIndexBetween(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%s|%v%v", test.indexName, test.greaterThanOrEqual, test.lessThanOrEqual), func(t *testing.T) {
|
||||
ctx := sql.NewEmptyContext()
|
||||
|
||||
idx, ok := indexMap[test.indexName]
|
||||
require.True(t, ok)
|
||||
|
||||
@@ -1292,6 +1292,7 @@ func requireUnorderedRowsEqual(t *testing.T, s sql.Schema, rows1, rows2 []sql.Ro
|
||||
}
|
||||
|
||||
func testDoltIndex(t *testing.T, ctx *sql.Context, root *doltdb.RootValue, keys []interface{}, expectedRows []sql.Row, idx index.DoltIndex, cmp indexComp) {
|
||||
ctx = sql.NewEmptyContext()
|
||||
exprs := idx.Expressions()
|
||||
builder := sql.NewIndexBuilder(idx)
|
||||
for i, key := range keys {
|
||||
@@ -1335,8 +1336,8 @@ func testDoltIndex(t *testing.T, ctx *sql.Context, root *doltdb.RootValue, keys
|
||||
requireUnorderedRowsEqual(t, pkSch.Schema, convertSqlRowToInt64(expectedRows), readRows)
|
||||
}
|
||||
|
||||
func doltIndexSetup(t *testing.T) (*doltdb.DoltDB, *sql.Context, *doltdb.RootValue, map[string]index.DoltIndex) {
|
||||
ctx := NewTestSQLCtx(context.Background())
|
||||
func doltIndexSetup(t *testing.T) (*doltdb.RootValue, map[string]index.DoltIndex) {
|
||||
ctx := context.Background()
|
||||
dEnv := dtestutils.CreateTestEnv()
|
||||
root, err := dEnv.WorkingRoot(ctx)
|
||||
if err != nil {
|
||||
@@ -1408,18 +1409,7 @@ INSERT INTO types VALUES (1, 4, '2020-05-14 12:00:03', 1.1, 'd', 1.1, 'a,c', '00
|
||||
}
|
||||
}
|
||||
|
||||
return dEnv.DoltDB, ctx, root, indexMap
|
||||
}
|
||||
|
||||
func NewTestSQLCtx(ctx context.Context) *sql.Context {
|
||||
s := dsess.DefaultSession(dsess.EmptyDatabaseProvider())
|
||||
s.SetCurrentDatabase("dolt")
|
||||
sqlCtx := sql.NewContext(
|
||||
ctx,
|
||||
sql.WithSession(s),
|
||||
)
|
||||
|
||||
return sqlCtx
|
||||
return root, indexMap
|
||||
}
|
||||
|
||||
func mustTime(timeString string) time.Time {
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/dtestutils"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
@@ -32,11 +31,10 @@ import (
|
||||
dsqle "github.com/dolthub/dolt/go/libraries/doltcore/sqle"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/config"
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
func setupIndexes(t *testing.T, tableName, insertQuery string) (*sqle.Engine, *env.DoltEnv, *doltdb.RootValue, dsqle.Database, []*indexTuple) {
|
||||
func setupIndexes(t *testing.T, tableName, insertQuery string) (*sqle.Engine, *sql.Context, []*indexTuple) {
|
||||
dEnv := dtestutils.CreateTestEnv()
|
||||
tmpDir, err := dEnv.TempTableFilesDir()
|
||||
require.NoError(t, err)
|
||||
@@ -102,21 +100,12 @@ func setupIndexes(t *testing.T, tableName, insertQuery string) (*sqle.Engine, *e
|
||||
b := env.GetDefaultInitBranch(dEnv.Config)
|
||||
pro, err := dsqle.NewDoltDatabaseProviderWithDatabase(b, mrEnv.FileSystem(), db, dEnv.FS)
|
||||
if err != nil {
|
||||
return nil, nil, nil, dsqle.Database{}, nil
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
pro = pro.WithDbFactoryUrl(doltdb.InMemDoltDB)
|
||||
engine = sqle.NewDefault(pro)
|
||||
|
||||
// Get an updated root to use for the rest of the test
|
||||
ctx := sql.NewEmptyContext()
|
||||
controller := branch_control.CreateDefaultController()
|
||||
sess, err := dsess.NewDoltSession(ctx.Session.(*sql.BaseSession), pro, config.NewEmptyMapConfig(), controller)
|
||||
require.NoError(t, err)
|
||||
roots, ok := sess.GetRoots(ctx, db.Name())
|
||||
require.True(t, ok)
|
||||
err = sess.SetRoot(sqlCtx, db.Name(), roots.Working)
|
||||
|
||||
it := []*indexTuple{
|
||||
idxv1ToTuple,
|
||||
idxv2v1ToTuple,
|
||||
@@ -126,7 +115,7 @@ func setupIndexes(t *testing.T, tableName, insertQuery string) (*sqle.Engine, *e
|
||||
},
|
||||
}
|
||||
|
||||
return engine, dEnv, roots.Working, db, it
|
||||
return engine, sqlCtx, it
|
||||
}
|
||||
|
||||
// indexTuple converts integers into the appropriate tuple for comparison against ranges
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
package index_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
@@ -27,7 +26,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/typed/noms"
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
@@ -42,7 +40,7 @@ func TestMergeableIndexes(t *testing.T) {
|
||||
t.Skip() // this test is specific to Noms ranges
|
||||
}
|
||||
|
||||
engine, denv, root, db, indexTuples := setupIndexes(t, "test", `INSERT INTO test VALUES
|
||||
engine, sqlCtx, indexTuples := setupIndexes(t, "test", `INSERT INTO test VALUES
|
||||
(-3, NULL, NULL),
|
||||
(-2, NULL, NULL),
|
||||
(-1, NULL, NULL),
|
||||
@@ -1316,16 +1314,6 @@ func TestMergeableIndexes(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.whereStmt, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
sqlCtx := NewTestSQLCtx(ctx)
|
||||
session := dsess.DSessFromSess(sqlCtx.Session)
|
||||
dbState := getDbState(t, db, denv)
|
||||
err := session.AddDB(sqlCtx, dbState)
|
||||
require.NoError(t, err)
|
||||
sqlCtx.SetCurrentDatabase(db.Name())
|
||||
err = session.SetRoot(sqlCtx, db.Name(), root)
|
||||
require.NoError(t, err)
|
||||
|
||||
query := fmt.Sprintf(`SELECT pk FROM test WHERE %s ORDER BY 1`, test.whereStmt)
|
||||
|
||||
finalRanges, err := ReadRangesFromQuery(sqlCtx, engine, query)
|
||||
@@ -1382,7 +1370,7 @@ func TestMergeableIndexesNulls(t *testing.T) {
|
||||
t.Skip() // this test is specific to Noms ranges
|
||||
}
|
||||
|
||||
engine, denv, root, db, indexTuples := setupIndexes(t, "test", `INSERT INTO test VALUES
|
||||
engine, sqlCtx, indexTuples := setupIndexes(t, "test", `INSERT INTO test VALUES
|
||||
(0, 10, 20),
|
||||
(1, 11, 21),
|
||||
(2, NULL, NULL),
|
||||
@@ -1532,16 +1520,6 @@ func TestMergeableIndexesNulls(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.whereStmt, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
sqlCtx := NewTestSQLCtx(ctx)
|
||||
session := dsess.DSessFromSess(sqlCtx.Session)
|
||||
dbState := getDbState(t, db, denv)
|
||||
err := session.AddDB(sqlCtx, dbState)
|
||||
require.NoError(t, err)
|
||||
sqlCtx.SetCurrentDatabase(db.Name())
|
||||
err = session.SetRoot(sqlCtx, db.Name(), root)
|
||||
require.NoError(t, err)
|
||||
|
||||
query := fmt.Sprintf(`SELECT pk FROM test WHERE %s ORDER BY 1`, test.whereStmt)
|
||||
|
||||
finalRanges, err := ReadRangesFromQuery(sqlCtx, engine, query)
|
||||
|
||||
@@ -52,9 +52,9 @@ type DoltHarness struct {
|
||||
}
|
||||
|
||||
func (h *DoltHarness) Close() {
|
||||
dbs := dsql.AllDbs(sql.NewEmptyContext(), h.sess.Provider())
|
||||
dbs := h.sess.Provider().AllDatabases(sql.NewEmptyContext())
|
||||
for _, db := range dbs {
|
||||
db.DbData().Ddb.Close()
|
||||
db.(dsql.SqlDatabase).DbData().Ddb.Close()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -96,8 +96,8 @@ func (rrd ReadReplicaDatabase) ValidReplicaState(ctx *sql.Context) bool {
|
||||
// InitialDBState implements dsess.SessionDatabase
|
||||
// This seems like a pointless override from the embedded Database implementation, but it's necessary to pass the
|
||||
// correct pointer type to the session initializer.
|
||||
func (rrd ReadReplicaDatabase) InitialDBState(ctx context.Context, branch string) (dsess.InitialDbState, error) {
|
||||
return GetInitialDBState(ctx, rrd, branch)
|
||||
func (rrd ReadReplicaDatabase) InitialDBState(ctx *sql.Context, branch string) (dsess.InitialDbState, error) {
|
||||
return initialDBState(ctx, rrd, branch)
|
||||
}
|
||||
|
||||
func (rrd ReadReplicaDatabase) PullFromRemote(ctx *sql.Context) error {
|
||||
|
||||
@@ -27,27 +27,21 @@ import (
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/dtestutils"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/json"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
)
|
||||
|
||||
func TestSchemaTableMigrationOriginal(t *testing.T) {
|
||||
ctx := NewTestSQLCtx(context.Background())
|
||||
dEnv := dtestutils.CreateTestEnv()
|
||||
tmpDir, err := dEnv.TempTableFilesDir()
|
||||
require.NoError(t, err)
|
||||
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: tmpDir}
|
||||
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(), opts)
|
||||
db, err := NewDatabase(context.Background(), "dolt", dEnv.DbData(), opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
dbState, err := getDbState(db, dEnv)
|
||||
_, ctx, err := NewTestEngine(dEnv, context.Background(), db)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = dsess.DSessFromSess(ctx.Session).AddDB(ctx, dbState)
|
||||
require.NoError(t, err)
|
||||
ctx.SetCurrentDatabase(db.Name())
|
||||
|
||||
err = db.createSqlTable(ctx, doltdb.SchemasTableName, sql.NewPrimaryKeySchema(sql.Schema{ // original schema of dolt_schemas table
|
||||
{Name: doltdb.SchemasTablesTypeCol, Type: gmstypes.Text, Source: doltdb.SchemasTableName, PrimaryKey: true},
|
||||
{Name: doltdb.SchemasTablesNameCol, Type: gmstypes.Text, Source: doltdb.SchemasTableName, PrimaryKey: true},
|
||||
@@ -94,21 +88,16 @@ func TestSchemaTableMigrationOriginal(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSchemaTableMigrationV1(t *testing.T) {
|
||||
ctx := NewTestSQLCtx(context.Background())
|
||||
dEnv := dtestutils.CreateTestEnv()
|
||||
tmpDir, err := dEnv.TempTableFilesDir()
|
||||
require.NoError(t, err)
|
||||
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: tmpDir}
|
||||
db, err := NewDatabase(ctx, "dolt", dEnv.DbData(), opts)
|
||||
db, err := NewDatabase(context.Background(), "dolt", dEnv.DbData(), opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
dbState, err := getDbState(db, dEnv)
|
||||
_, ctx, err := NewTestEngine(dEnv, context.Background(), db)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = dsess.DSessFromSess(ctx.Session).AddDB(ctx, dbState)
|
||||
require.NoError(t, err)
|
||||
ctx.SetCurrentDatabase(db.Name())
|
||||
|
||||
// original schema of dolt_schemas table with the ID column
|
||||
err = db.createSqlTable(ctx, doltdb.SchemasTableName, sql.NewPrimaryKeySchema(sql.Schema{
|
||||
{Name: doltdb.SchemasTablesTypeCol, Type: gmstypes.Text, Source: doltdb.SchemasTableName, PrimaryKey: false},
|
||||
|
||||
@@ -1,321 +0,0 @@
|
||||
// Copyright 2020 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
)
|
||||
|
||||
// SingleTableInfoDatabase is intended to allow a sole schema to make use of any display functionality in `go-mysql-server`.
|
||||
// For example, you may have constructed a schema that you want a CREATE TABLE statement for, but the schema is not
|
||||
// persisted or is temporary. This allows `go-mysql-server` to interact with that sole schema as though it were a database.
|
||||
// No write operations will work with this database.
|
||||
type SingleTableInfoDatabase struct {
|
||||
tableName string
|
||||
sch schema.Schema
|
||||
foreignKeys []doltdb.ForeignKey
|
||||
parentSchs map[string]schema.Schema
|
||||
}
|
||||
|
||||
var _ doltReadOnlyTableInterface = (*SingleTableInfoDatabase)(nil)
|
||||
var _ sql.IndexedTable = (*SingleTableInfoDatabase)(nil)
|
||||
var _ SqlDatabase = (*SingleTableInfoDatabase)(nil)
|
||||
|
||||
func NewSingleTableDatabase(tableName string, sch schema.Schema, foreignKeys []doltdb.ForeignKey, parentSchs map[string]schema.Schema) *SingleTableInfoDatabase {
|
||||
return &SingleTableInfoDatabase{
|
||||
tableName: tableName,
|
||||
sch: sch,
|
||||
foreignKeys: foreignKeys,
|
||||
parentSchs: parentSchs,
|
||||
}
|
||||
}
|
||||
|
||||
// Name implements sql.Table and sql.Database.
|
||||
func (db *SingleTableInfoDatabase) Name() string {
|
||||
return db.tableName
|
||||
}
|
||||
|
||||
// GetTableInsensitive implements sql.Database.
|
||||
func (db *SingleTableInfoDatabase) GetTableInsensitive(ctx *sql.Context, tableName string) (sql.Table, bool, error) {
|
||||
if strings.ToLower(tableName) == strings.ToLower(db.tableName) {
|
||||
return db, true, nil
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// GetTableNames implements sql.Database.
|
||||
func (db *SingleTableInfoDatabase) GetTableNames(ctx *sql.Context) ([]string, error) {
|
||||
return []string{db.tableName}, nil
|
||||
}
|
||||
|
||||
// String implements sql.Table.
|
||||
func (db *SingleTableInfoDatabase) String() string {
|
||||
return db.tableName
|
||||
}
|
||||
|
||||
// Schema implements sql.Table.
|
||||
func (db *SingleTableInfoDatabase) Schema() sql.Schema {
|
||||
sqlSch, err := sqlutil.FromDoltSchema(db.tableName, db.sch)
|
||||
if err != nil {
|
||||
}
|
||||
return sqlSch.Schema
|
||||
}
|
||||
|
||||
// Collation implements sql.Table.
|
||||
func (db *SingleTableInfoDatabase) Collation() sql.CollationID {
|
||||
return sql.CollationID(db.sch.GetCollation())
|
||||
}
|
||||
|
||||
func (db *SingleTableInfoDatabase) InitialDBState(ctx context.Context, branch string) (dsess.InitialDbState, error) {
|
||||
return getInitialDBStateForUserSpaceDb(ctx, db)
|
||||
}
|
||||
|
||||
// Partitions implements sql.Table.
|
||||
func (db *SingleTableInfoDatabase) Partitions(*sql.Context) (sql.PartitionIter, error) {
|
||||
return nil, fmt.Errorf("cannot get paritions of a single table information database")
|
||||
}
|
||||
|
||||
// PartitionRows implements sql.Table.
|
||||
func (db *SingleTableInfoDatabase) PartitionRows(*sql.Context, sql.Partition) (sql.RowIter, error) {
|
||||
return nil, fmt.Errorf("cannot get partition rows of a single table information database")
|
||||
}
|
||||
|
||||
func (db *SingleTableInfoDatabase) PartitionRows2(ctx *sql.Context, part sql.Partition) (sql.RowIter2, error) {
|
||||
return nil, fmt.Errorf("cannot get partition rows of a single table information database")
|
||||
}
|
||||
|
||||
func (db *SingleTableInfoDatabase) LookupPartitions(context *sql.Context, lookup sql.IndexLookup) (sql.PartitionIter, error) {
|
||||
return nil, fmt.Errorf("cannot get paritions of a single table information database")
|
||||
}
|
||||
|
||||
// CreateIndexForForeignKey implements sql.ForeignKeyTable.
|
||||
func (db *SingleTableInfoDatabase) CreateIndexForForeignKey(ctx *sql.Context, idx sql.IndexDef) error {
|
||||
return fmt.Errorf("cannot create foreign keys on a single table information database")
|
||||
}
|
||||
|
||||
// GetDeclaredForeignKeys implements sql.ForeignKeyTable.
|
||||
func (db *SingleTableInfoDatabase) GetDeclaredForeignKeys(ctx *sql.Context) ([]sql.ForeignKeyConstraint, error) {
|
||||
fks := make([]sql.ForeignKeyConstraint, len(db.foreignKeys))
|
||||
for i, fk := range db.foreignKeys {
|
||||
if !fk.IsResolved() {
|
||||
fks[i] = sql.ForeignKeyConstraint{
|
||||
Name: fk.Name,
|
||||
Database: ctx.GetCurrentDatabase(),
|
||||
Table: fk.TableName,
|
||||
Columns: fk.UnresolvedFKDetails.TableColumns,
|
||||
ParentDatabase: ctx.GetCurrentDatabase(),
|
||||
ParentTable: fk.ReferencedTableName,
|
||||
ParentColumns: fk.UnresolvedFKDetails.ReferencedTableColumns,
|
||||
OnUpdate: toReferentialAction(fk.OnUpdate),
|
||||
OnDelete: toReferentialAction(fk.OnDelete),
|
||||
}
|
||||
continue
|
||||
}
|
||||
if parentSch, ok := db.parentSchs[fk.ReferencedTableName]; ok {
|
||||
var err error
|
||||
fks[i], err = toForeignKeyConstraint(fk, ctx.GetCurrentDatabase(), db.sch, parentSch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// We can skip here since the given schema may be purposefully incomplete (such as with diffs).
|
||||
continue
|
||||
}
|
||||
}
|
||||
return fks, nil
|
||||
}
|
||||
|
||||
// GetReferencedForeignKeys implements sql.ForeignKeyTable.
|
||||
func (db *SingleTableInfoDatabase) GetReferencedForeignKeys(ctx *sql.Context) ([]sql.ForeignKeyConstraint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// AddForeignKey implements sql.ForeignKeyTable.
|
||||
func (db *SingleTableInfoDatabase) AddForeignKey(ctx *sql.Context, fk sql.ForeignKeyConstraint) error {
|
||||
return fmt.Errorf("cannot create foreign keys on a single table information database")
|
||||
}
|
||||
|
||||
// DropForeignKey implements sql.ForeignKeyTable.
|
||||
func (db *SingleTableInfoDatabase) DropForeignKey(ctx *sql.Context, fkName string) error {
|
||||
return fmt.Errorf("cannot create foreign keys on a single table information database")
|
||||
}
|
||||
|
||||
// UpdateForeignKey implements sql.ForeignKeyTable.
|
||||
func (db *SingleTableInfoDatabase) UpdateForeignKey(ctx *sql.Context, fkName string, fk sql.ForeignKeyConstraint) error {
|
||||
return fmt.Errorf("cannot create foreign keys on a single table information database")
|
||||
}
|
||||
|
||||
// GetForeignKeyEditor implements sql.ForeignKeyTable.
|
||||
func (db *SingleTableInfoDatabase) GetForeignKeyEditor(ctx *sql.Context) sql.ForeignKeyEditor {
|
||||
return nil
|
||||
}
|
||||
|
||||
// IndexedAccess implements sql.IndexedTable.
|
||||
func (db *SingleTableInfoDatabase) IndexedAccess(lookup sql.IndexLookup) sql.IndexedTable {
|
||||
return db
|
||||
}
|
||||
|
||||
// GetIndexes implements sql.IndexedTable.
|
||||
func (db *SingleTableInfoDatabase) GetIndexes(ctx *sql.Context) ([]sql.Index, error) {
|
||||
var sqlIndexes []sql.Index
|
||||
for _, idx := range db.sch.Indexes().AllIndexes() {
|
||||
cols := make([]schema.Column, idx.Count())
|
||||
for i, tag := range idx.IndexedColumnTags() {
|
||||
cols[i], _ = idx.GetColumn(tag)
|
||||
}
|
||||
sqlIndexes = append(sqlIndexes, &fmtIndex{
|
||||
id: idx.Name(),
|
||||
db: db.Name(),
|
||||
tbl: db.tableName,
|
||||
cols: cols,
|
||||
unique: idx.IsUnique(),
|
||||
generated: false,
|
||||
comment: idx.Comment(),
|
||||
})
|
||||
}
|
||||
return sqlIndexes, nil
|
||||
}
|
||||
|
||||
func (db *SingleTableInfoDatabase) GetChecks(ctx *sql.Context) ([]sql.CheckDefinition, error) {
|
||||
return checksInSchema(db.sch), nil
|
||||
}
|
||||
|
||||
func (db *SingleTableInfoDatabase) IsTemporary() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (db *SingleTableInfoDatabase) DataLength(ctx *sql.Context) (uint64, error) {
|
||||
// TODO: to answer this accurately, we need the table as well as the schema
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (db *SingleTableInfoDatabase) RowCount(ctx *sql.Context) (uint64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (db *SingleTableInfoDatabase) PrimaryKeySchema() sql.PrimaryKeySchema {
|
||||
sqlSch, err := sqlutil.FromDoltSchema(db.tableName, db.sch)
|
||||
if err != nil {
|
||||
}
|
||||
return sqlSch
|
||||
}
|
||||
|
||||
func (db *SingleTableInfoDatabase) GetRoot(context *sql.Context) (*doltdb.RootValue, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (db *SingleTableInfoDatabase) DbData() env.DbData {
|
||||
panic("SingleTableInfoDatabase doesn't have DbData")
|
||||
}
|
||||
|
||||
func (db *SingleTableInfoDatabase) Flush(context *sql.Context) error {
|
||||
panic("SingleTableInfoDatabase cannot Flush")
|
||||
}
|
||||
|
||||
// fmtIndex is used for CREATE TABLE statements only.
|
||||
type fmtIndex struct {
|
||||
id string
|
||||
db string
|
||||
tbl string
|
||||
|
||||
cols []schema.Column
|
||||
unique bool
|
||||
spatial bool
|
||||
generated bool
|
||||
comment string
|
||||
}
|
||||
|
||||
// CanSupport implements sql.Index
|
||||
func (idx fmtIndex) CanSupport(r ...sql.Range) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// ID implements sql.Index
|
||||
func (idx fmtIndex) ID() string {
|
||||
return idx.id
|
||||
}
|
||||
|
||||
// Database implements sql.Index
|
||||
func (idx fmtIndex) Database() string {
|
||||
return idx.db
|
||||
}
|
||||
|
||||
// Table implements sql.Index
|
||||
func (idx fmtIndex) Table() string {
|
||||
return idx.tbl
|
||||
}
|
||||
|
||||
// Expressions implements sql.Index
|
||||
func (idx fmtIndex) Expressions() []string {
|
||||
strs := make([]string, len(idx.cols))
|
||||
for i, col := range idx.cols {
|
||||
strs[i] = idx.tbl + "." + col.Name
|
||||
}
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsUnique implements sql.Index
|
||||
func (idx fmtIndex) IsUnique() bool {
|
||||
return idx.unique
|
||||
}
|
||||
|
||||
// IsSpatial implements sql.Index
|
||||
func (idx fmtIndex) IsSpatial() bool {
|
||||
return idx.spatial
|
||||
}
|
||||
|
||||
// Comment implements sql.Index
|
||||
func (idx fmtIndex) Comment() string {
|
||||
return idx.comment
|
||||
}
|
||||
|
||||
// PrefixLengths implements sql.Index
|
||||
func (idx fmtIndex) PrefixLengths() []uint16 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// IndexType implements sql.Index
|
||||
func (idx fmtIndex) IndexType() string {
|
||||
return "BTREE"
|
||||
}
|
||||
|
||||
// IsGenerated implements sql.Index
|
||||
func (idx fmtIndex) IsGenerated() bool {
|
||||
return idx.generated
|
||||
}
|
||||
|
||||
func (idx fmtIndex) IndexedAccess(_ sql.IndexLookup) (sql.IndexedTable, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// ColumnExpressionTypes implements sql.Index
|
||||
func (idx fmtIndex) ColumnExpressionTypes() []sql.ColumnExpressionType {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (db *SingleTableInfoDatabase) EditOptions() editor.Options {
|
||||
return editor.Options{}
|
||||
}
|
||||
@@ -113,11 +113,6 @@ func ExecuteSql(dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*
|
||||
return db.GetRoot(ctx)
|
||||
}
|
||||
|
||||
// NewTestSQLCtx returns a new *sql.Context with a default DoltSession, a new IndexRegistry, and a new ViewRegistry
|
||||
func NewTestSQLCtx(ctx context.Context) *sql.Context {
|
||||
return NewTestSQLCtxWithProvider(ctx, dsess.EmptyDatabaseProvider())
|
||||
}
|
||||
|
||||
func NewTestSQLCtxWithProvider(ctx context.Context, pro dsess.DoltDatabaseProvider) *sql.Context {
|
||||
s, err := dsess.NewDoltSession(sql.NewBaseSession(), pro, config2.NewMapConfig(make(map[string]string)), branch_control.CreateDefaultController())
|
||||
if err != nil {
|
||||
@@ -141,44 +136,10 @@ func NewTestEngine(dEnv *env.DoltEnv, ctx context.Context, db SqlDatabase) (*sql
|
||||
|
||||
engine := sqle.NewDefault(pro)
|
||||
sqlCtx := NewTestSQLCtxWithProvider(ctx, pro)
|
||||
|
||||
dbState, err := getDbState(db, dEnv)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = dsess.DSessFromSess(sqlCtx.Session).AddDB(sqlCtx, dbState)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sqlCtx.SetCurrentDatabase(db.Name())
|
||||
return engine, sqlCtx, nil
|
||||
}
|
||||
|
||||
func getDbState(db sql.Database, dEnv *env.DoltEnv) (dsess.InitialDbState, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
head := dEnv.RepoStateReader().CWBHeadSpec()
|
||||
headCommit, err := dEnv.DoltDB.Resolve(ctx, head, dEnv.RepoStateReader().CWBHeadRef())
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
ws, err := dEnv.WorkingSet(ctx)
|
||||
if err != nil {
|
||||
return dsess.InitialDbState{}, err
|
||||
}
|
||||
|
||||
return dsess.InitialDbState{
|
||||
Db: db,
|
||||
HeadCommit: headCommit,
|
||||
WorkingSet: ws,
|
||||
DbData: dEnv.DbData(),
|
||||
Remotes: dEnv.RepoState.Remotes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExecuteSelect executes the select statement given and returns the resulting rows, or an error if one is encountered.
|
||||
func ExecuteSelect(dEnv *env.DoltEnv, root *doltdb.RootValue, query string) ([]sql.Row, error) {
|
||||
dbData := env.DbData{
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
package sqle
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
@@ -78,8 +76,15 @@ func (db *UserSpaceDatabase) GetTableNames(ctx *sql.Context) ([]string, error) {
|
||||
return resultingTblNames, nil
|
||||
}
|
||||
|
||||
func (db *UserSpaceDatabase) InitialDBState(ctx context.Context, branch string) (dsess.InitialDbState, error) {
|
||||
return getInitialDBStateForUserSpaceDb(ctx, db)
|
||||
func (db *UserSpaceDatabase) InitialDBState(ctx *sql.Context, branch string) (dsess.InitialDbState, error) {
|
||||
return dsess.InitialDbState{
|
||||
Db: db,
|
||||
ReadOnly: true,
|
||||
HeadRoot: db.RootValue,
|
||||
DbData: env.DbData{
|
||||
Rsw: noopRepoStateWriter{},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (db *UserSpaceDatabase) GetRoot(*sql.Context) (*doltdb.RootValue, error) {
|
||||
@@ -101,3 +106,15 @@ func (db *UserSpaceDatabase) Flush(ctx *sql.Context) error {
|
||||
func (db *UserSpaceDatabase) EditOptions() editor.Options {
|
||||
return db.editOpts
|
||||
}
|
||||
|
||||
func (db *UserSpaceDatabase) Revision() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (db *UserSpaceDatabase) RevisionType() dsess.RevisionType {
|
||||
return dsess.RevisionTypeNone
|
||||
}
|
||||
|
||||
func (db *UserSpaceDatabase) BaseName() string {
|
||||
return db.Name()
|
||||
}
|
||||
|
||||
@@ -22,11 +22,9 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/row"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
@@ -42,8 +40,6 @@ type tableEditorTest struct {
|
||||
selectQuery string
|
||||
// The rows this query should return, nil if an error is expected
|
||||
expectedRows []sql.Row
|
||||
// Expected error string, if any
|
||||
expectedErr string
|
||||
}
|
||||
|
||||
func TestTableEditor(t *testing.T) {
|
||||
@@ -60,7 +56,6 @@ func TestTableEditor(t *testing.T) {
|
||||
fatTony := sqle.NewPeopleRow(16, "Fat", "Tony", false, 53, 5.0)
|
||||
troyMclure := sqle.NewPeopleRow(17, "Troy", "McClure", false, 58, 7.0)
|
||||
|
||||
var expectedErr error
|
||||
// Some of these are pretty exotic use cases, but since we support all these operations it's nice to know they work
|
||||
// in tandem.
|
||||
testCases := []tableEditorTest{
|
||||
@@ -159,27 +154,18 @@ func TestTableEditor(t *testing.T) {
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
expectedErr = nil
|
||||
|
||||
dEnv, err := sqle.CreateTestDatabase()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := sqle.NewTestSQLCtx(context.Background())
|
||||
root, err := dEnv.WorkingRoot(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
tmpDir, err := dEnv.TempTableFilesDir()
|
||||
require.NoError(t, err)
|
||||
opts := editor.Options{Deaf: dEnv.DbEaFactory(), Tempdir: tmpDir}
|
||||
db, err := sqle.NewDatabase(ctx, "dolt", dEnv.DbData(), opts)
|
||||
db, err := sqle.NewDatabase(context.Background(), "dolt", dEnv.DbData(), opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = dsess.DSessFromSess(ctx.Session).AddDB(ctx, getDbState(t, db, dEnv))
|
||||
engine, ctx, err := sqle.NewTestEngine(dEnv, context.Background(), db)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx.SetCurrentDatabase(db.Name())
|
||||
err = db.SetRoot(ctx, root)
|
||||
require.NoError(t, err)
|
||||
peopleTable, _, err := db.GetTableInsensitive(ctx, "people")
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -187,20 +173,18 @@ func TestTableEditor(t *testing.T) {
|
||||
ed := dt.Updater(ctx).(writer.TableWriter)
|
||||
|
||||
test.setup(ctx, t, ed)
|
||||
if len(test.expectedErr) > 0 {
|
||||
require.Error(t, expectedErr)
|
||||
assert.Contains(t, expectedErr.Error(), test.expectedErr)
|
||||
return
|
||||
} else {
|
||||
require.NoError(t, ed.Close(ctx))
|
||||
}
|
||||
require.NoError(t, ed.Close(ctx))
|
||||
|
||||
root, err = db.GetRoot(ctx)
|
||||
root, err := db.GetRoot(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// TODO: not clear why this is necessary, the call to ed.Close should update the working set already
|
||||
require.NoError(t, dEnv.UpdateWorkingRoot(context.Background(), root))
|
||||
|
||||
actualRows, err := sqle.ExecuteSelect(dEnv, root, test.selectQuery)
|
||||
sch, rowIter, err := engine.Query(ctx, test.selectQuery)
|
||||
require.NoError(t, err)
|
||||
|
||||
actualRows, err := sql.RowIterToRows(ctx, sch, rowIter)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.expectedRows, actualRows)
|
||||
@@ -215,22 +199,3 @@ func r(r row.Row, sch schema.Schema) sql.Row {
|
||||
}
|
||||
return sqlRow
|
||||
}
|
||||
|
||||
func getDbState(t *testing.T, db sql.Database, dEnv *env.DoltEnv) dsess.InitialDbState {
|
||||
ctx := context.Background()
|
||||
|
||||
head := dEnv.RepoStateReader().CWBHeadSpec()
|
||||
headCommit, err := dEnv.DoltDB.Resolve(ctx, head, dEnv.RepoStateReader().CWBHeadRef())
|
||||
require.NoError(t, err)
|
||||
|
||||
ws, err := dEnv.WorkingSet(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
return dsess.InitialDbState{
|
||||
Db: db,
|
||||
HeadCommit: headCommit,
|
||||
WorkingSet: ws,
|
||||
DbData: dEnv.DbData(),
|
||||
Remotes: dEnv.RepoState.Remotes,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -275,7 +275,7 @@ func RunQueryAttempt(t require.TestingT, conn *sql.Conn, q driver.Query) {
|
||||
defer rows.Close()
|
||||
}
|
||||
if q.ErrorMatch != "" {
|
||||
require.Error(t, err)
|
||||
require.Error(t, err, "expected error running query %s", q.Query)
|
||||
require.Regexp(t, q.ErrorMatch, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user