mirror of
https://github.com/dolthub/dolt.git
synced 2026-02-10 18:49:02 -06:00
added sess.LookupDbState()
This commit is contained in:
@@ -49,7 +49,6 @@ func IsValidCommitHash(s string) bool {
|
||||
return hashRegex.MatchString(s)
|
||||
}
|
||||
|
||||
|
||||
type commitSpecType string
|
||||
|
||||
const (
|
||||
|
||||
@@ -493,8 +493,11 @@ var hashType = sql.MustCreateString(query.Type_TEXT, 32, sql.Collation_ascii_bin
|
||||
// GetRoot returns the root value for this database session
|
||||
func (db Database) GetRoot(ctx *sql.Context) (*doltdb.RootValue, error) {
|
||||
sess := dsess.DSessFromSess(ctx.Session)
|
||||
dbState, dbRootOk := sess.DbStates[db.name]
|
||||
if !dbRootOk {
|
||||
dbState, ok, err := sess.LookupDbState(ctx, db.Name())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no root value found in session")
|
||||
}
|
||||
|
||||
@@ -746,7 +749,11 @@ func (db Database) RenameTable(ctx *sql.Context, oldName, newName string) error
|
||||
// Flush flushes the current batch of outstanding changes and returns any errors.
|
||||
func (db Database) Flush(ctx *sql.Context) error {
|
||||
sess := dsess.DSessFromSess(ctx.Session)
|
||||
editSession := sess.DbStates[db.name].EditSession
|
||||
dbState, _, err := sess.LookupDbState(ctx, db.Name())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
editSession := dbState.EditSession
|
||||
|
||||
newRoot, err := editSession.Flush(ctx)
|
||||
if err != nil {
|
||||
@@ -760,7 +767,7 @@ func (db Database) Flush(ctx *sql.Context) error {
|
||||
|
||||
// Flush any changes made to temporary tables
|
||||
// TODO: Shouldn't always be updating both roots. Needs to update either both roots or neither of them, atomically
|
||||
tempTableEditSession := sess.DbStates[db.name].TempTableEditSession
|
||||
tempTableEditSession := dbState.TempTableEditSession
|
||||
if tempTableEditSession != nil {
|
||||
newTempTableRoot, err := tempTableEditSession.Flush(ctx)
|
||||
if err != nil {
|
||||
@@ -964,7 +971,12 @@ func (db Database) addFragToSchemasTable(ctx *sql.Context, fragType, name, defin
|
||||
|
||||
// If rows exist, then grab the highest id and add 1 to get the new id
|
||||
indexToUse := int64(1)
|
||||
te, err := db.TableEditSession(ctx, tbl.IsTemporary()).GetTableEditor(ctx, doltdb.SchemasTableName, tbl.sch)
|
||||
ts, err := db.TableEditSession(ctx, tbl.IsTemporary())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
te, err := ts.GetTableEditor(ctx, doltdb.SchemasTableName, tbl.sch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1028,20 +1040,29 @@ func (db Database) dropFragFromSchemasTable(ctx *sql.Context, fragType, name str
|
||||
}
|
||||
|
||||
// TableEditSession returns the TableEditSession for this database from the given context.
|
||||
func (db Database) TableEditSession(ctx *sql.Context, isTemporary bool) *editor.TableEditSession {
|
||||
if isTemporary {
|
||||
return dsess.DSessFromSess(ctx.Session).DbStates[db.name].TempTableEditSession
|
||||
func (db Database) TableEditSession(ctx *sql.Context, isTemporary bool) (*editor.TableEditSession, error) {
|
||||
sess := dsess.DSessFromSess(ctx.Session)
|
||||
dbState, _, err := sess.LookupDbState(ctx, db.Name())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dsess.DSessFromSess(ctx.Session).DbStates[db.name].EditSession
|
||||
|
||||
if isTemporary {
|
||||
return dbState.TempTableEditSession, nil
|
||||
}
|
||||
return dbState.EditSession, nil
|
||||
}
|
||||
|
||||
// GetAllTemporaryTables returns all temporary tables
|
||||
func (db Database) GetAllTemporaryTables(ctx *sql.Context) ([]sql.Table, error) {
|
||||
sess := dsess.DSessFromSess(ctx.Session)
|
||||
dbState, _, err := sess.LookupDbState(ctx, db.Name())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tables := make([]sql.Table, 0)
|
||||
|
||||
root := sess.DbStates[db.name].TempTableRoot
|
||||
root := dbState.TempTableRoot
|
||||
if root != nil {
|
||||
tNames, err := root.GetTableNames(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -290,4 +290,3 @@ type staticRepoState struct {
|
||||
func (s staticRepoState) CWBHeadRef() ref.DoltRef {
|
||||
return s.branch
|
||||
}
|
||||
|
||||
|
||||
@@ -38,13 +38,13 @@ func (ab *ActiveBranchFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, er
|
||||
dbName := ctx.GetCurrentDatabase()
|
||||
dSess := dsess.DSessFromSess(ctx.Session)
|
||||
|
||||
ddb, ok := dSess.GetDoltDB(dbName)
|
||||
ddb, ok := dSess.GetDoltDB(ctx, dbName)
|
||||
|
||||
if !ok {
|
||||
return nil, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
|
||||
currentBranchRef, err := dSess.CWBHeadRef(dbName)
|
||||
currentBranchRef, err := dSess.CWBHeadRef(ctx, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -75,10 +75,11 @@ func (cf *CommitFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
root, ok := dSess.GetRoot(dbName)
|
||||
roots, ok := dSess.GetRoots(ctx, dbName)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown database '%s'", dbName)
|
||||
}
|
||||
root := roots.Working
|
||||
|
||||
// Update the superschema to with any new information from the table map.
|
||||
tblNames, err := root.GetTableNames(ctx)
|
||||
@@ -91,7 +92,7 @@ func (cf *CommitFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ddb, ok := dSess.GetDoltDB(dbName)
|
||||
ddb, ok := dSess.GetDoltDB(ctx, dbName)
|
||||
|
||||
if !ok {
|
||||
return nil, sql.ErrDatabaseNotFound.New(dbName)
|
||||
|
||||
@@ -53,7 +53,7 @@ func (d DoltAddFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
|
||||
allFlag := apr.Contains(cli.AllFlag)
|
||||
|
||||
dSess := dsess.DSessFromSess(ctx.Session)
|
||||
roots, ok := dSess.GetRoots(dbName)
|
||||
roots, ok := dSess.GetRoots(ctx, dbName)
|
||||
if apr.NArg() == 0 && !allFlag {
|
||||
return 1, fmt.Errorf("Nothing specified, nothing added. Maybe you wanted to say 'dolt add .'?")
|
||||
} else if allFlag || apr.NArg() == 1 && apr.Arg(0) == "." {
|
||||
|
||||
@@ -63,12 +63,12 @@ func (d DoltCheckoutFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, erro
|
||||
|
||||
// Checking out new branch.
|
||||
dSess := dsess.DSessFromSess(ctx.Session)
|
||||
dbData, ok := dSess.GetDbData(dbName)
|
||||
dbData, ok := dSess.GetDbData(ctx, dbName)
|
||||
if !ok {
|
||||
return 1, fmt.Errorf("Could not load database %s", dbName)
|
||||
}
|
||||
|
||||
roots, ok := dSess.GetRoots(dbName)
|
||||
roots, ok := dSess.GetRoots(ctx, dbName)
|
||||
if !ok {
|
||||
return 1, fmt.Errorf("Could not load database %s", dbName)
|
||||
}
|
||||
|
||||
@@ -60,8 +60,7 @@ func (d DoltCommitFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
|
||||
}
|
||||
|
||||
dSess := dsess.DSessFromSess(ctx.Session)
|
||||
|
||||
roots, ok := dSess.GetRoots(dbName)
|
||||
roots, ok := dSess.GetRoots(nil, dbName)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Could not load database %s", dbName)
|
||||
}
|
||||
@@ -178,7 +177,11 @@ func CommitToDolt(
|
||||
// repo state writer, so we're never persisting the new working set to disk like in a command line context.
|
||||
// TODO: fix this mess
|
||||
|
||||
ws := dSess.WorkingSet(ctx, dbName)
|
||||
ws, err := dSess.WorkingSet(ctx, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// StartTransaction sets the working set for the session, and we want the one we previous had, not the one on disk
|
||||
// Updating the working set like this also updates the head commit and root info for the session
|
||||
tx, err := dSess.StartTransaction(ctx, dbName)
|
||||
@@ -186,6 +189,10 @@ func CommitToDolt(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ws, err = dSess.WorkingSet(ctx, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = dSess.SetWorkingSet(ctx, dbName, ws.ClearMerge(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -46,7 +46,7 @@ func (d DoltMergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
|
||||
}
|
||||
|
||||
sess := dsess.DSessFromSess(ctx.Session)
|
||||
dbData, ok := sess.GetDbData(dbName)
|
||||
dbData, ok := sess.GetDbData(nil, dbName)
|
||||
|
||||
if !ok {
|
||||
return 1, fmt.Errorf("Could not load database %s", dbName)
|
||||
@@ -68,8 +68,11 @@ func (d DoltMergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
|
||||
return 1, fmt.Errorf("error: Flags '--%s' and '--%s' cannot be used together.\n", cli.SquashParam, cli.NoFFParam)
|
||||
}
|
||||
|
||||
ws := sess.WorkingSet(ctx, dbName)
|
||||
roots, ok := sess.GetRoots(dbName)
|
||||
ws, err := sess.WorkingSet(ctx, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roots, ok := sess.GetRoots(nil, dbName)
|
||||
|
||||
// logrus.Errorf("heads are working: %s\nhead: %s", roots.Working.DebugString(ctx, true), roots.Head.DebugString(ctx, true))
|
||||
|
||||
@@ -95,7 +98,7 @@ func (d DoltMergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
|
||||
return "Merge aborted", nil
|
||||
}
|
||||
|
||||
ddb, ok := sess.GetDoltDB(dbName)
|
||||
ddb, ok := sess.GetDoltDB(nil, dbName)
|
||||
if !ok {
|
||||
return nil, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
@@ -317,7 +320,7 @@ func executeNoFFMerge(
|
||||
}
|
||||
|
||||
// The roots need refreshing after the above
|
||||
roots, _ := dSess.GetRoots(dbName)
|
||||
roots, _ := dSess.GetRoots(ctx, dbName)
|
||||
|
||||
// TODO: this does several session state updates, and it really needs to just do one
|
||||
// We also need to commit any pending transaction before we do this.
|
||||
|
||||
@@ -84,11 +84,11 @@ func resolveRefSpecs(ctx *sql.Context, leftSpec, rightSpec string) (left, right
|
||||
sess := dsess.DSessFromSess(ctx.Session)
|
||||
dbName := ctx.GetCurrentDatabase()
|
||||
|
||||
dbData, ok := sess.GetDbData(dbName)
|
||||
dbData, ok := sess.GetDbData(ctx, dbName)
|
||||
if !ok {
|
||||
return nil, nil, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
doltDB, ok := sess.GetDoltDB(dbName)
|
||||
doltDB, ok := sess.GetDoltDB(ctx, dbName)
|
||||
if !ok {
|
||||
return nil, nil, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func (d DoltResetFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
|
||||
}
|
||||
|
||||
dSess := dsess.DSessFromSess(ctx.Session)
|
||||
dbData, ok := dSess.GetDbData(dbName)
|
||||
dbData, ok := dSess.GetDbData(ctx, dbName)
|
||||
|
||||
if !ok {
|
||||
return 1, fmt.Errorf("Could not load database %s", dbName)
|
||||
@@ -64,7 +64,7 @@ func (d DoltResetFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
|
||||
}
|
||||
|
||||
// Get all the needed roots.
|
||||
roots, ok := dSess.GetRoots(dbName)
|
||||
roots, ok := dSess.GetRoots(nil, dbName)
|
||||
if !ok {
|
||||
return 1, fmt.Errorf("Could not load database %s", dbName)
|
||||
}
|
||||
@@ -91,7 +91,10 @@ func (d DoltResetFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
|
||||
}
|
||||
}
|
||||
|
||||
ws := dSess.WorkingSet(ctx, dbName)
|
||||
ws, err := dSess.WorkingSet(ctx, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = dSess.SetWorkingSet(ctx, dbName, ws.WithWorkingRoot(roots.Working).WithStagedRoot(roots.Staged), nil)
|
||||
if err != nil {
|
||||
return 1, err
|
||||
|
||||
@@ -63,7 +63,7 @@ func (t *HashOf) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
|
||||
}
|
||||
|
||||
dbName := ctx.GetCurrentDatabase()
|
||||
ddb, ok := dsess.DSessFromSess(ctx.Session).GetDoltDB(dbName)
|
||||
ddb, ok := dsess.DSessFromSess(ctx.Session).GetDoltDB(ctx, dbName)
|
||||
if !ok {
|
||||
return nil, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
|
||||
@@ -71,12 +71,12 @@ func (cf *MergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
|
||||
}
|
||||
|
||||
dbName := sess.GetCurrentDatabase()
|
||||
ddb, ok := sess.GetDoltDB(dbName)
|
||||
ddb, ok := sess.GetDoltDB(nil, dbName)
|
||||
if !ok {
|
||||
return nil, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
|
||||
root, ok := sess.GetRoot(dbName)
|
||||
roots, ok := sess.GetRoots(ctx, dbName)
|
||||
if !ok {
|
||||
return nil, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
@@ -86,7 +86,7 @@ func (cf *MergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = checkForUncommittedChanges(root, headRoot)
|
||||
err = checkForUncommittedChanges(roots.Working, headRoot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -48,12 +48,12 @@ func (s SquashFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
|
||||
}
|
||||
|
||||
dbName := sess.GetCurrentDatabase()
|
||||
ddb, ok := sess.GetDoltDB(dbName)
|
||||
ddb, ok := sess.GetDoltDB(ctx, dbName)
|
||||
if !ok {
|
||||
return nil, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
|
||||
root, ok := sess.GetRoot(dbName)
|
||||
roots, ok := sess.GetRoots(ctx, dbName)
|
||||
if !ok {
|
||||
return nil, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
@@ -63,7 +63,7 @@ func (s SquashFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = checkForUncommittedChanges(root, headRoot)
|
||||
err = checkForUncommittedChanges(roots.Working, headRoot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -128,7 +128,7 @@ type Session struct {
|
||||
Username string
|
||||
Email string
|
||||
// TODO: make this private again
|
||||
DbStates map[string]*DatabaseSessionState
|
||||
dbStates map[string]*DatabaseSessionState
|
||||
}
|
||||
|
||||
type DatabaseSessionState struct {
|
||||
@@ -161,7 +161,7 @@ func DefaultSession() *Session {
|
||||
Session: sql.NewBaseSession(),
|
||||
Username: "",
|
||||
Email: "",
|
||||
DbStates: make(map[string]*DatabaseSessionState),
|
||||
dbStates: make(map[string]*DatabaseSessionState),
|
||||
}
|
||||
return sess
|
||||
}
|
||||
@@ -180,7 +180,7 @@ func NewSession(ctx *sql.Context, sqlSess sql.Session, username, email string, d
|
||||
Session: sqlSess,
|
||||
Username: username,
|
||||
Email: email,
|
||||
DbStates: make(map[string]*DatabaseSessionState),
|
||||
dbStates: make(map[string]*DatabaseSessionState),
|
||||
}
|
||||
|
||||
for _, db := range dbs {
|
||||
@@ -206,12 +206,21 @@ func DSessFromSess(sess sql.Session) *Session {
|
||||
return sess.(*Session)
|
||||
}
|
||||
|
||||
func (sess *Session) LookupDbState(ctx *sql.Context, dbName string) (*DatabaseSessionState, bool, error) {
|
||||
dbState, ok := sess.dbStates[dbName]
|
||||
return dbState, ok, nil
|
||||
}
|
||||
|
||||
// Flush flushes all changes sitting in edit sessions to the session root for the database named. This normally
|
||||
// happens automatically as part of statement execution, and is only necessary when the session is manually batched (as
|
||||
// for bulk SQL import)
|
||||
func (sess *Session) Flush(ctx *sql.Context, dbName string) error {
|
||||
editSession := sess.DbStates[dbName].EditSession
|
||||
newRoot, err := editSession.Flush(ctx)
|
||||
dbState, _, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newRoot, err := dbState.EditSession.Flush(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -232,7 +241,10 @@ func (sess *Session) StartTransaction(ctx *sql.Context, dbName string) (sql.Tran
|
||||
return DisabledTransaction{}, nil
|
||||
}
|
||||
|
||||
sessionState := sess.DbStates[dbName]
|
||||
sessionState, _, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wsRef := sessionState.WorkingSet.Ref()
|
||||
ws, err := sessionState.dbData.Ddb.ResolveWorkingSet(ctx, wsRef)
|
||||
@@ -257,7 +269,7 @@ func (sess *Session) StartTransaction(ctx *sql.Context, dbName string) (sql.Tran
|
||||
}
|
||||
|
||||
func (sess *Session) newWorkingSetForHead(ctx *sql.Context, wsRef ref.WorkingSetRef, dbName string) (*doltdb.WorkingSet, error) {
|
||||
dbData, _ := sess.GetDbData(dbName)
|
||||
dbData, _ := sess.GetDbData(nil, dbName)
|
||||
|
||||
headSpec, _ := doltdb.NewCommitSpec("HEAD")
|
||||
headRef, err := wsRef.ToHeadRef()
|
||||
@@ -287,11 +299,16 @@ func (sess *Session) CommitTransaction(ctx *sql.Context, dbName string, tx sql.T
|
||||
}
|
||||
}
|
||||
|
||||
dbState, _, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if TransactionsDisabled(ctx) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !sess.DbStates[dbName].dirty {
|
||||
if !dbState.dirty {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -301,7 +318,11 @@ func (sess *Session) CommitTransaction(ctx *sql.Context, dbName string, tx sql.T
|
||||
return nil
|
||||
}
|
||||
|
||||
dbstate, ok := sess.DbStates[dbName]
|
||||
dbState, ok, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// It's possible that this returns false if the user has created an in-Memory database. Moreover,
|
||||
// the analyzer will check for us whether a db exists or not.
|
||||
// TODO: fix this
|
||||
@@ -320,7 +341,7 @@ func (sess *Session) CommitTransaction(ctx *sql.Context, dbName string, tx sql.T
|
||||
// TODO: actual logging
|
||||
// logrus.Errorf("working root to commit is %s", dbstate.workingSet.WorkingRoot().DebugString(ctx, true))
|
||||
|
||||
mergedWorkingSet, err := dtx.Commit(ctx, dbstate.WorkingSet)
|
||||
mergedWorkingSet, err := dtx.Commit(ctx, dbState.WorkingSet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -337,7 +358,7 @@ func (sess *Session) CommitTransaction(ctx *sql.Context, dbName string, tx sql.T
|
||||
return err
|
||||
}
|
||||
|
||||
dbstate.dirty = false
|
||||
dbState.dirty = false
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -347,7 +368,10 @@ func (sess *Session) CommitToDolt(
|
||||
dbName string,
|
||||
props actions.CommitStagedProps,
|
||||
) (*doltdb.Commit, error) {
|
||||
sessionState := sess.DbStates[dbName]
|
||||
sessionState, _, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dbData := sessionState.dbData
|
||||
|
||||
// TODO: this does several session state updates, and it really needs to just do one
|
||||
@@ -368,7 +392,10 @@ func (sess *Session) CommitToDolt(
|
||||
// repo state writer, so we're never persisting the new working set to disk like in a command line context.
|
||||
// TODO: fix this mess
|
||||
|
||||
ws := sess.WorkingSet(ctx, dbName)
|
||||
ws, err := sess.WorkingSet(ctx, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// StartTransaction sets the working set for the session, and we want the one we previous had, not the one on disk
|
||||
// Updating the working set like this also updates the head commit and root info for the session
|
||||
tx, err := sess.StartTransaction(ctx, dbName)
|
||||
@@ -406,7 +433,10 @@ func (sess *Session) CreateDoltCommit(ctx *sql.Context, dbName string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
sessionState := sess.DbStates[dbName]
|
||||
sessionState, _, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
roots := sessionState.GetRoots()
|
||||
|
||||
roots, err = actions.StageAllTablesNoDocs(ctx, roots)
|
||||
@@ -435,7 +465,12 @@ func (sess *Session) RollbackTransaction(ctx *sql.Context, dbName string, tx sql
|
||||
return nil
|
||||
}
|
||||
|
||||
if !sess.DbStates[dbName].dirty {
|
||||
dbState, ok, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !dbState.dirty {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -444,12 +479,12 @@ func (sess *Session) RollbackTransaction(ctx *sql.Context, dbName string, tx sql
|
||||
return fmt.Errorf("expected a DoltTransaction")
|
||||
}
|
||||
|
||||
err := sess.SetRoot(ctx, dbName, dtx.startState.WorkingRoot())
|
||||
err = sess.SetRoot(ctx, dbName, dtx.startState.WorkingRoot())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sess.DbStates[dbName].dirty = false
|
||||
dbState.dirty = false
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -465,7 +500,12 @@ func (sess *Session) CreateSavepoint(ctx *sql.Context, savepointName, dbName str
|
||||
return fmt.Errorf("expected a DoltTransaction")
|
||||
}
|
||||
|
||||
dtx.CreateSavepoint(savepointName, sess.DbStates[dbName].GetRoots().Working)
|
||||
dbState, ok, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dtx.CreateSavepoint(savepointName, dbState.GetRoots().Working)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -515,40 +555,16 @@ func (sess *Session) ReleaseSavepoint(ctx *sql.Context, savepointName, dbName st
|
||||
}
|
||||
|
||||
// GetDoltDB returns the *DoltDB for a given database by name
|
||||
func (sess *Session) GetDoltDB(dbName string) (*doltdb.DoltDB, bool) {
|
||||
dbstate, ok := sess.DbStates[dbName]
|
||||
func (sess *Session) GetDoltDB(ctx *sql.Context, dbName string) (*doltdb.DoltDB, bool) {
|
||||
dbState, ok, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return dbstate.dbData.Ddb, true
|
||||
}
|
||||
|
||||
func (sess *Session) GetDoltDBRepoStateWriter(dbName string) (env.RepoStateWriter, bool) {
|
||||
d, ok := sess.DbStates[dbName]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return d.dbData.Rsw, true
|
||||
}
|
||||
|
||||
func (sess *Session) GetDoltDBRepoStateReader(dbName string) (env.RepoStateReader, bool) {
|
||||
d, ok := sess.DbStates[dbName]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return d.dbData.Rsr, true
|
||||
}
|
||||
|
||||
func (sess *Session) GetDoltDBDocsReadWriter(dbName string) (env.DocsReadWriter, bool) {
|
||||
d, ok := sess.DbStates[dbName]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return d.dbData.Drw, true
|
||||
return dbState.dbData.Ddb, true
|
||||
}
|
||||
|
||||
func (sess *Session) GetDoltDbAutoIncrementTracker(dbName string) (globalstate.AutoIncrementTracker, bool) {
|
||||
@@ -569,33 +585,29 @@ func (sess *Session) GetDoltDbAutoIncrementTracker(dbName string) (globalstate.A
|
||||
return tracker, true
|
||||
}
|
||||
|
||||
func (sess *Session) GetDbData(dbName string) (env.DbData, bool) {
|
||||
sessionState, ok := sess.DbStates[dbName]
|
||||
func (sess *Session) GetDbData(ctx *sql.Context, dbName string) (env.DbData, bool) {
|
||||
dbState, ok, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return env.DbData{}, false
|
||||
}
|
||||
if !ok {
|
||||
return env.DbData{}, false
|
||||
}
|
||||
|
||||
return sessionState.dbData, true
|
||||
}
|
||||
|
||||
// GetRoot returns the current working *RootValue for a given database associated with the session
|
||||
func (sess *Session) GetRoot(dbName string) (*doltdb.RootValue, bool) {
|
||||
dbstate, ok := sess.DbStates[dbName]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return dbstate.GetRoots().Working, true
|
||||
return dbState.dbData, true
|
||||
}
|
||||
|
||||
// GetRoots returns the current roots for a given database associated with the session
|
||||
func (sess *Session) GetRoots(dbName string) (doltdb.Roots, bool) {
|
||||
dbstate, ok := sess.DbStates[dbName]
|
||||
func (sess *Session) GetRoots(ctx *sql.Context, dbName string) (doltdb.Roots, bool) {
|
||||
dbState, ok, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return doltdb.Roots{}, false
|
||||
}
|
||||
if !ok {
|
||||
return doltdb.Roots{}, false
|
||||
}
|
||||
|
||||
return dbstate.GetRoots(), true
|
||||
return dbState.GetRoots(), true
|
||||
}
|
||||
|
||||
// SetRoot sets a new root value for the session for the database named. This is the primary mechanism by which data
|
||||
@@ -604,7 +616,11 @@ func (sess *Session) GetRoots(dbName string) (doltdb.Roots, bool) {
|
||||
// Data changes contained in the |newRoot| aren't persisted until this session is committed.
|
||||
// TODO: rename to SetWorkingRoot
|
||||
func (sess *Session) SetRoot(ctx *sql.Context, dbName string, newRoot *doltdb.RootValue) error {
|
||||
sessionState := sess.DbStates[dbName]
|
||||
sessionState, _, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rootsEqual(sessionState.GetRoots().Working, newRoot) {
|
||||
return nil
|
||||
}
|
||||
@@ -616,7 +632,10 @@ func (sess *Session) SetRoot(ctx *sql.Context, dbName string, newRoot *doltdb.Ro
|
||||
func (sess *Session) setRoot(ctx *sql.Context, dbName string, newRoot *doltdb.RootValue) error {
|
||||
// logrus.Tracef("setting root value %s", newRoot.DebugString(ctx, true))
|
||||
|
||||
sessionState := sess.DbStates[dbName]
|
||||
sessionState, _, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h, err := newRoot.HashOf()
|
||||
if err != nil {
|
||||
@@ -645,7 +664,12 @@ func (sess *Session) setRoot(ctx *sql.Context, dbName string, newRoot *doltdb.Ro
|
||||
// Unlike setting the only the working root, this method always marks the database state dirty.
|
||||
func (sess *Session) SetRoots(ctx *sql.Context, dbName string, roots doltdb.Roots) error {
|
||||
// TODO: handle HEAD here?
|
||||
workingSet := sess.DbStates[dbName].WorkingSet.WithWorkingRoot(roots.Working).WithStagedRoot(roots.Staged)
|
||||
sessionState, _, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
workingSet := sessionState.WorkingSet.WithWorkingRoot(roots.Working).WithStagedRoot(roots.Staged)
|
||||
return sess.SetWorkingSet(ctx, dbName, workingSet, nil)
|
||||
}
|
||||
|
||||
@@ -662,7 +686,10 @@ func (sess *Session) SetWorkingSet(
|
||||
panic("attempted to set a nil working set for the session")
|
||||
}
|
||||
|
||||
sessionState := sess.DbStates[dbName]
|
||||
sessionState, _, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sessionState.WorkingSet = ws
|
||||
|
||||
if headRoot == nil && !sessionState.detachedHead {
|
||||
@@ -693,7 +720,7 @@ func (sess *Session) SetWorkingSet(
|
||||
sessionState.headRoot = headRoot
|
||||
}
|
||||
|
||||
err := sess.setSessionVarsForDb(ctx, dbName)
|
||||
err = sess.setSessionVarsForDb(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -716,7 +743,10 @@ func (sess *Session) SwitchWorkingSet(
|
||||
ctx *sql.Context,
|
||||
dbName string,
|
||||
wsRef ref.WorkingSetRef) error {
|
||||
sessionState := sess.DbStates[dbName]
|
||||
sessionState, _, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sessionState.dirty {
|
||||
return fmt.Errorf("Cannot switch working set, session state is dirty. " +
|
||||
@@ -778,33 +808,46 @@ func (sess *Session) SwitchWorkingSet(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sess *Session) WorkingSet(ctx *sql.Context, dbName string) *doltdb.WorkingSet {
|
||||
sessionState := sess.DbStates[dbName]
|
||||
return sessionState.WorkingSet
|
||||
func (sess *Session) WorkingSet(ctx *sql.Context, dbName string) (*doltdb.WorkingSet, error) {
|
||||
sessionState, _, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sessionState.WorkingSet, nil
|
||||
}
|
||||
|
||||
func (sess *Session) GetTempTableRootValue(ctx *sql.Context, dbName string) (*doltdb.RootValue, bool) {
|
||||
dbstate, ok := sess.DbStates[dbName]
|
||||
dbState, ok, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if dbstate.TempTableRoot == nil {
|
||||
if dbState.TempTableRoot == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return dbstate.TempTableRoot, true
|
||||
return dbState.TempTableRoot, true
|
||||
}
|
||||
|
||||
func (sess *Session) SetTempTableRoot(ctx *sql.Context, dbName string, newRoot *doltdb.RootValue) error {
|
||||
sess.DbStates[dbName].TempTableRoot = newRoot
|
||||
return sess.DbStates[dbName].TempTableEditSession.SetRoot(ctx, newRoot)
|
||||
dbState, _, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dbState.TempTableRoot = newRoot
|
||||
return dbState.TempTableEditSession.SetRoot(ctx, newRoot)
|
||||
}
|
||||
|
||||
// GetHeadCommit returns the parent commit of the current session.
|
||||
func (sess *Session) GetHeadCommit(ctx *sql.Context, dbName string) (*doltdb.Commit, error) {
|
||||
dbState, dbFound := sess.DbStates[dbName]
|
||||
if !dbFound {
|
||||
dbState, ok, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
|
||||
@@ -822,7 +865,12 @@ func (sess *Session) SetSessionVariable(ctx *sql.Context, key string, value inte
|
||||
return err
|
||||
}
|
||||
|
||||
sess.DbStates[dbName].detachedHead = true
|
||||
dbState, _, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbState.detachedHead = true
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -849,11 +897,11 @@ func (sess *Session) setForeignKeyChecksSessionVar(ctx *sql.Context, key string,
|
||||
intVal = convertedVal.(int64)
|
||||
}
|
||||
if intVal == 0 {
|
||||
for _, dbState := range sess.DbStates {
|
||||
for _, dbState := range sess.dbStates {
|
||||
dbState.EditSession.Props.ForeignKeyChecksDisabled = true
|
||||
}
|
||||
} else if intVal == 1 {
|
||||
for _, dbState := range sess.DbStates {
|
||||
for _, dbState := range sess.dbStates {
|
||||
dbState.EditSession.Props.ForeignKeyChecksDisabled = false
|
||||
}
|
||||
} else {
|
||||
@@ -869,12 +917,15 @@ func (sess *Session) setWorkingSessionVar(ctx *sql.Context, value interface{}, d
|
||||
return doltdb.ErrInvalidHash
|
||||
}
|
||||
|
||||
dbstate, dbFound := sess.DbStates[dbName]
|
||||
if !dbFound {
|
||||
dbState, ok, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
|
||||
root, err := dbstate.dbData.Ddb.ReadRootValue(ctx, hash.Parse(valStr))
|
||||
root, err := dbState.dbData.Ddb.ReadRootValue(ctx, hash.Parse(valStr))
|
||||
if errors.Is(doltdb.ErrNoRootValAtHash, err) {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
@@ -885,8 +936,11 @@ func (sess *Session) setWorkingSessionVar(ctx *sql.Context, value interface{}, d
|
||||
}
|
||||
|
||||
func (sess *Session) setHeadSessionVar(ctx *sql.Context, value interface{}, dbName string) error {
|
||||
dbstate, dbFound := sess.DbStates[dbName]
|
||||
if !dbFound {
|
||||
dbState, ok, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
|
||||
@@ -901,19 +955,19 @@ func (sess *Session) setHeadSessionVar(ctx *sql.Context, value interface{}, dbNa
|
||||
return err
|
||||
}
|
||||
|
||||
cm, err := dbstate.dbData.Ddb.Resolve(ctx, cs, nil)
|
||||
cm, err := dbState.dbData.Ddb.Resolve(ctx, cs, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbstate.headCommit = cm
|
||||
dbState.headCommit = cm
|
||||
|
||||
root, err := cm.GetRootValue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbstate.headRoot = root
|
||||
dbState.headRoot = root
|
||||
|
||||
err = sess.Session.SetSessionVariable(ctx, HeadKey(dbName), value)
|
||||
if err != nil {
|
||||
@@ -937,7 +991,7 @@ func (sess *Session) AddDB(ctx *sql.Context, dbState InitialDbState) error {
|
||||
defineSystemVariables(db.Name())
|
||||
|
||||
sessionState := &DatabaseSessionState{}
|
||||
sess.DbStates[db.Name()] = sessionState
|
||||
sess.dbStates[db.Name()] = sessionState
|
||||
|
||||
// TODO: get rid of all repo state reader / writer stuff. Until we do, swap out the reader with one of our own, and
|
||||
// the writer with one that errors out
|
||||
@@ -991,19 +1045,30 @@ func (sess *Session) CreateTemporaryTablesRoot(ctx *sql.Context, dbName string,
|
||||
return err
|
||||
}
|
||||
|
||||
sess.DbStates[dbName].TempTableEditSession = editor.CreateTableEditSession(newRoot, editor.TableEditSessionProps{})
|
||||
dbState, _, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dbState.TempTableEditSession = editor.CreateTableEditSession(newRoot, editor.TableEditSessionProps{})
|
||||
|
||||
return sess.SetTempTableRoot(ctx, dbName, newRoot)
|
||||
}
|
||||
|
||||
// CWBHeadRef returns the branch ref for this session HEAD for the database named
|
||||
func (sess *Session) CWBHeadRef(dbName string) (ref.DoltRef, error) {
|
||||
return sess.DbStates[dbName].WorkingSet.Ref().ToHeadRef()
|
||||
func (sess *Session) CWBHeadRef(ctx *sql.Context, dbName string) (ref.DoltRef, error) {
|
||||
dbState, _, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dbState.WorkingSet.Ref().ToHeadRef()
|
||||
}
|
||||
|
||||
// setSessionVarsForDb updates the three session vars that track the value of the session root hashes
|
||||
func (sess *Session) setSessionVarsForDb(ctx *sql.Context, dbName string) error {
|
||||
state := sess.DbStates[dbName]
|
||||
state, _, err := sess.LookupDbState(ctx, dbName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
roots := state.GetRoots()
|
||||
|
||||
h, err := roots.Working.HashOf()
|
||||
|
||||
@@ -33,13 +33,13 @@ type SessionStateAdapter struct {
|
||||
}
|
||||
|
||||
func (s SessionStateAdapter) UpdateStagedRoot(ctx context.Context, newRoot *doltdb.RootValue) error {
|
||||
roots, _ := s.session.GetRoots(s.dbName)
|
||||
roots, _ := s.session.GetRoots(nil, s.dbName)
|
||||
roots.Staged = newRoot
|
||||
return s.session.SetRoots(ctx.(*sql.Context), s.dbName, roots)
|
||||
}
|
||||
|
||||
func (s SessionStateAdapter) UpdateWorkingRoot(ctx context.Context, newRoot *doltdb.RootValue) error {
|
||||
roots, _ := s.session.GetRoots(s.dbName)
|
||||
roots, _ := s.session.GetRoots(nil, s.dbName)
|
||||
roots.Working = newRoot
|
||||
return s.session.SetRoots(ctx.(*sql.Context), s.dbName, roots)
|
||||
}
|
||||
@@ -69,11 +69,11 @@ func NewSessionStateAdapter(session *Session, dbName string) SessionStateAdapter
|
||||
}
|
||||
|
||||
func (s SessionStateAdapter) GetRoots(ctx context.Context) (doltdb.Roots, error) {
|
||||
return s.session.DbStates[s.dbName].GetRoots(), nil
|
||||
return s.session.dbStates[s.dbName].GetRoots(), nil
|
||||
}
|
||||
|
||||
func (s SessionStateAdapter) CWBHeadRef() ref.DoltRef {
|
||||
workingSet := s.session.DbStates[s.dbName].WorkingSet
|
||||
workingSet := s.session.dbStates[s.dbName].WorkingSet
|
||||
headRef, err := workingSet.Ref().ToHeadRef()
|
||||
// TODO: fix this interface
|
||||
if err != nil {
|
||||
@@ -93,13 +93,13 @@ func (s SessionStateAdapter) CWBHeadSpec() *doltdb.CommitSpec {
|
||||
}
|
||||
|
||||
func (s SessionStateAdapter) IsMergeActive(ctx context.Context) (bool, error) {
|
||||
return s.session.DbStates[s.dbName].WorkingSet.MergeActive(), nil
|
||||
return s.session.dbStates[s.dbName].WorkingSet.MergeActive(), nil
|
||||
}
|
||||
|
||||
func (s SessionStateAdapter) GetMergeCommit(ctx context.Context) (*doltdb.Commit, error) {
|
||||
return s.session.DbStates[s.dbName].WorkingSet.MergeState().Commit(), nil
|
||||
return s.session.dbStates[s.dbName].WorkingSet.MergeState().Commit(), nil
|
||||
}
|
||||
|
||||
func (s SessionStateAdapter) GetPreMergeWorking(ctx context.Context) (*doltdb.RootValue, error) {
|
||||
return s.session.DbStates[s.dbName].WorkingSet.MergeState().PreMergeWorkingRoot(), nil
|
||||
return s.session.dbStates[s.dbName].WorkingSet.MergeState().PreMergeWorkingRoot(), nil
|
||||
}
|
||||
|
||||
@@ -100,7 +100,7 @@ func setupMergeableIndexes(t *testing.T, tableName, insertQuery string) (*sqle.E
|
||||
engine.AddDatabase(mergeableDb)
|
||||
|
||||
// Get an updated root to use for the rest of the test
|
||||
root, _ = dsess.DSessFromSess(sqlCtx.Session).GetRoot(mergeableDb.Name())
|
||||
root, _ = dsess.DSessFromSess(sqlCtx.Session).GetRoot(nil, mergeableDb.Name())
|
||||
|
||||
return engine, dEnv, mergeableDb, []*indexTuple{
|
||||
idxv1ToTuple,
|
||||
|
||||
@@ -57,7 +57,10 @@ var _ sql.RowInserter = (*sqlTableEditor)(nil)
|
||||
var _ sql.RowDeleter = (*sqlTableEditor)(nil)
|
||||
|
||||
func newSqlTableEditor(ctx *sql.Context, t *WritableDoltTable) (*sqlTableEditor, error) {
|
||||
sess := t.db.TableEditSession(ctx, t.IsTemporary())
|
||||
sess, err := t.db.TableEditSession(ctx, t.IsTemporary())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tableEditor, err := sess.GetTableEditor(ctx, t.tableName, t.sch)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user