From 5fa0978734aad58c1b85486ef07f0ca5baf7d135 Mon Sep 17 00:00:00 2001 From: Andy Arthur Date: Sun, 4 Jul 2021 09:03:28 -0700 Subject: [PATCH] added sess.LookupDbState() --- go/libraries/doltcore/doltdb/commit_spec.go | 1 - go/libraries/doltcore/sqle/database.go | 43 ++- .../doltcore/sqle/database_provider.go | 1 - .../doltcore/sqle/dfunctions/active_branch.go | 4 +- .../doltcore/sqle/dfunctions/commit.go | 5 +- .../doltcore/sqle/dfunctions/dolt_add.go | 2 +- .../doltcore/sqle/dfunctions/dolt_checkout.go | 4 +- .../doltcore/sqle/dfunctions/dolt_commit.go | 13 +- .../doltcore/sqle/dfunctions/dolt_merge.go | 13 +- .../sqle/dfunctions/dolt_merge_base.go | 4 +- .../doltcore/sqle/dfunctions/dolt_reset.go | 9 +- .../doltcore/sqle/dfunctions/hashof.go | 2 +- .../doltcore/sqle/dfunctions/merge.go | 6 +- .../doltcore/sqle/dfunctions/squash.go | 6 +- go/libraries/doltcore/sqle/dsess/session.go | 257 +++++++++++------- .../sqle/dsess/session_state_adapter.go | 14 +- .../sqle/mergeable_indexes_setup_test.go | 2 +- go/libraries/doltcore/sqle/table_editor.go | 5 +- 18 files changed, 246 insertions(+), 145 deletions(-) diff --git a/go/libraries/doltcore/doltdb/commit_spec.go b/go/libraries/doltcore/doltdb/commit_spec.go index 129831d0c7..d1c51e4b48 100644 --- a/go/libraries/doltcore/doltdb/commit_spec.go +++ b/go/libraries/doltcore/doltdb/commit_spec.go @@ -49,7 +49,6 @@ func IsValidCommitHash(s string) bool { return hashRegex.MatchString(s) } - type commitSpecType string const ( diff --git a/go/libraries/doltcore/sqle/database.go b/go/libraries/doltcore/sqle/database.go index 8daf4da06a..b61cf4d570 100644 --- a/go/libraries/doltcore/sqle/database.go +++ b/go/libraries/doltcore/sqle/database.go @@ -493,8 +493,11 @@ var hashType = sql.MustCreateString(query.Type_TEXT, 32, sql.Collation_ascii_bin // GetRoot returns the root value for this database session func (db Database) GetRoot(ctx *sql.Context) (*doltdb.RootValue, error) { sess := dsess.DSessFromSess(ctx.Session) - dbState, dbRootOk := sess.DbStates[db.name] - if !dbRootOk { + dbState, ok, err := sess.LookupDbState(ctx, db.Name()) + if err != nil { + return nil, err + } + if !ok { return nil, fmt.Errorf("no root value found in session") } @@ -746,7 +749,11 @@ func (db Database) RenameTable(ctx *sql.Context, oldName, newName string) error // Flush flushes the current batch of outstanding changes and returns any errors. func (db Database) Flush(ctx *sql.Context) error { sess := dsess.DSessFromSess(ctx.Session) - editSession := sess.DbStates[db.name].EditSession + dbState, _, err := sess.LookupDbState(ctx, db.Name()) + if err != nil { + return err + } + editSession := dbState.EditSession newRoot, err := editSession.Flush(ctx) if err != nil { @@ -760,7 +767,7 @@ func (db Database) Flush(ctx *sql.Context) error { // Flush any changes made to temporary tables // TODO: Shouldn't always be updating both roots. Needs to update either both roots or neither of them, atomically - tempTableEditSession := sess.DbStates[db.name].TempTableEditSession + tempTableEditSession := dbState.TempTableEditSession if tempTableEditSession != nil { newTempTableRoot, err := tempTableEditSession.Flush(ctx) if err != nil { @@ -964,7 +971,12 @@ func (db Database) addFragToSchemasTable(ctx *sql.Context, fragType, name, defin // If rows exist, then grab the highest id and add 1 to get the new id indexToUse := int64(1) - te, err := db.TableEditSession(ctx, tbl.IsTemporary()).GetTableEditor(ctx, doltdb.SchemasTableName, tbl.sch) + ts, err := db.TableEditSession(ctx, tbl.IsTemporary()) + if err != nil { + return err + } + + te, err := ts.GetTableEditor(ctx, doltdb.SchemasTableName, tbl.sch) if err != nil { return err } @@ -1028,20 +1040,29 @@ func (db Database) dropFragFromSchemasTable(ctx *sql.Context, fragType, name str } // TableEditSession returns the TableEditSession for this database from the given context. -func (db Database) TableEditSession(ctx *sql.Context, isTemporary bool) *editor.TableEditSession { - if isTemporary { - return dsess.DSessFromSess(ctx.Session).DbStates[db.name].TempTableEditSession +func (db Database) TableEditSession(ctx *sql.Context, isTemporary bool) (*editor.TableEditSession, error) { + sess := dsess.DSessFromSess(ctx.Session) + dbState, _, err := sess.LookupDbState(ctx, db.Name()) + if err != nil { + return nil, err } - return dsess.DSessFromSess(ctx.Session).DbStates[db.name].EditSession + + if isTemporary { + return dbState.TempTableEditSession, nil + } + return dbState.EditSession, nil } // GetAllTemporaryTables returns all temporary tables func (db Database) GetAllTemporaryTables(ctx *sql.Context) ([]sql.Table, error) { sess := dsess.DSessFromSess(ctx.Session) + dbState, _, err := sess.LookupDbState(ctx, db.Name()) + if err != nil { + return nil, err + } tables := make([]sql.Table, 0) - - root := sess.DbStates[db.name].TempTableRoot + root := dbState.TempTableRoot if root != nil { tNames, err := root.GetTableNames(ctx) if err != nil { diff --git a/go/libraries/doltcore/sqle/database_provider.go b/go/libraries/doltcore/sqle/database_provider.go index 9b8b9bdc18..d5d1af451d 100644 --- a/go/libraries/doltcore/sqle/database_provider.go +++ b/go/libraries/doltcore/sqle/database_provider.go @@ -290,4 +290,3 @@ type staticRepoState struct { func (s staticRepoState) CWBHeadRef() ref.DoltRef { return s.branch } - diff --git a/go/libraries/doltcore/sqle/dfunctions/active_branch.go b/go/libraries/doltcore/sqle/dfunctions/active_branch.go index c0bc631069..6a4fbda46a 100644 --- a/go/libraries/doltcore/sqle/dfunctions/active_branch.go +++ b/go/libraries/doltcore/sqle/dfunctions/active_branch.go @@ -38,13 +38,13 @@ func (ab *ActiveBranchFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, er dbName := ctx.GetCurrentDatabase() dSess := dsess.DSessFromSess(ctx.Session) - ddb, ok := dSess.GetDoltDB(dbName) + ddb, ok := dSess.GetDoltDB(ctx, dbName) if !ok { return nil, sql.ErrDatabaseNotFound.New(dbName) } - currentBranchRef, err := dSess.CWBHeadRef(dbName) + currentBranchRef, err := dSess.CWBHeadRef(ctx, dbName) if err != nil { return nil, err } diff --git a/go/libraries/doltcore/sqle/dfunctions/commit.go b/go/libraries/doltcore/sqle/dfunctions/commit.go index 868d6657f7..084bef554a 100644 --- a/go/libraries/doltcore/sqle/dfunctions/commit.go +++ b/go/libraries/doltcore/sqle/dfunctions/commit.go @@ -75,10 +75,11 @@ func (cf *CommitFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - root, ok := dSess.GetRoot(dbName) + roots, ok := dSess.GetRoots(ctx, dbName) if !ok { return nil, fmt.Errorf("unknown database '%s'", dbName) } + root := roots.Working // Update the superschema to with any new information from the table map. tblNames, err := root.GetTableNames(ctx) @@ -91,7 +92,7 @@ func (cf *CommitFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - ddb, ok := dSess.GetDoltDB(dbName) + ddb, ok := dSess.GetDoltDB(ctx, dbName) if !ok { return nil, sql.ErrDatabaseNotFound.New(dbName) diff --git a/go/libraries/doltcore/sqle/dfunctions/dolt_add.go b/go/libraries/doltcore/sqle/dfunctions/dolt_add.go index 9dc823bb38..9f2286021a 100644 --- a/go/libraries/doltcore/sqle/dfunctions/dolt_add.go +++ b/go/libraries/doltcore/sqle/dfunctions/dolt_add.go @@ -53,7 +53,7 @@ func (d DoltAddFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { allFlag := apr.Contains(cli.AllFlag) dSess := dsess.DSessFromSess(ctx.Session) - roots, ok := dSess.GetRoots(dbName) + roots, ok := dSess.GetRoots(ctx, dbName) if apr.NArg() == 0 && !allFlag { return 1, fmt.Errorf("Nothing specified, nothing added. Maybe you wanted to say 'dolt add .'?") } else if allFlag || apr.NArg() == 1 && apr.Arg(0) == "." { diff --git a/go/libraries/doltcore/sqle/dfunctions/dolt_checkout.go b/go/libraries/doltcore/sqle/dfunctions/dolt_checkout.go index b148303d61..3cd29fbed5 100644 --- a/go/libraries/doltcore/sqle/dfunctions/dolt_checkout.go +++ b/go/libraries/doltcore/sqle/dfunctions/dolt_checkout.go @@ -63,12 +63,12 @@ func (d DoltCheckoutFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, erro // Checking out new branch. dSess := dsess.DSessFromSess(ctx.Session) - dbData, ok := dSess.GetDbData(dbName) + dbData, ok := dSess.GetDbData(ctx, dbName) if !ok { return 1, fmt.Errorf("Could not load database %s", dbName) } - roots, ok := dSess.GetRoots(dbName) + roots, ok := dSess.GetRoots(ctx, dbName) if !ok { return 1, fmt.Errorf("Could not load database %s", dbName) } diff --git a/go/libraries/doltcore/sqle/dfunctions/dolt_commit.go b/go/libraries/doltcore/sqle/dfunctions/dolt_commit.go index 1f0c94d7c8..8fcb17549c 100644 --- a/go/libraries/doltcore/sqle/dfunctions/dolt_commit.go +++ b/go/libraries/doltcore/sqle/dfunctions/dolt_commit.go @@ -60,8 +60,7 @@ func (d DoltCommitFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) } dSess := dsess.DSessFromSess(ctx.Session) - - roots, ok := dSess.GetRoots(dbName) + roots, ok := dSess.GetRoots(nil, dbName) if !ok { return nil, fmt.Errorf("Could not load database %s", dbName) } @@ -178,7 +177,11 @@ func CommitToDolt( // repo state writer, so we're never persisting the new working set to disk like in a command line context. // TODO: fix this mess - ws := dSess.WorkingSet(ctx, dbName) + ws, err := dSess.WorkingSet(ctx, dbName) + if err != nil { + return nil, err + } + // StartTransaction sets the working set for the session, and we want the one we previous had, not the one on disk // Updating the working set like this also updates the head commit and root info for the session tx, err := dSess.StartTransaction(ctx, dbName) @@ -186,6 +189,10 @@ func CommitToDolt( return nil, err } + ws, err = dSess.WorkingSet(ctx, dbName) + if err != nil { + return nil, err + } err = dSess.SetWorkingSet(ctx, dbName, ws.ClearMerge(), nil) if err != nil { return nil, err diff --git a/go/libraries/doltcore/sqle/dfunctions/dolt_merge.go b/go/libraries/doltcore/sqle/dfunctions/dolt_merge.go index 3838887bfe..35692b46dd 100644 --- a/go/libraries/doltcore/sqle/dfunctions/dolt_merge.go +++ b/go/libraries/doltcore/sqle/dfunctions/dolt_merge.go @@ -46,7 +46,7 @@ func (d DoltMergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) } sess := dsess.DSessFromSess(ctx.Session) - dbData, ok := sess.GetDbData(dbName) + dbData, ok := sess.GetDbData(nil, dbName) if !ok { return 1, fmt.Errorf("Could not load database %s", dbName) @@ -68,8 +68,11 @@ func (d DoltMergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return 1, fmt.Errorf("error: Flags '--%s' and '--%s' cannot be used together.\n", cli.SquashParam, cli.NoFFParam) } - ws := sess.WorkingSet(ctx, dbName) - roots, ok := sess.GetRoots(dbName) + ws, err := sess.WorkingSet(ctx, dbName) + if err != nil { + return nil, err + } + roots, ok := sess.GetRoots(nil, dbName) // logrus.Errorf("heads are working: %s\nhead: %s", roots.Working.DebugString(ctx, true), roots.Head.DebugString(ctx, true)) @@ -95,7 +98,7 @@ func (d DoltMergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return "Merge aborted", nil } - ddb, ok := sess.GetDoltDB(dbName) + ddb, ok := sess.GetDoltDB(nil, dbName) if !ok { return nil, sql.ErrDatabaseNotFound.New(dbName) } @@ -317,7 +320,7 @@ func executeNoFFMerge( } // The roots need refreshing after the above - roots, _ := dSess.GetRoots(dbName) + roots, _ := dSess.GetRoots(ctx, dbName) // TODO: this does several session state updates, and it really needs to just do one // We also need to commit any pending transaction before we do this. diff --git a/go/libraries/doltcore/sqle/dfunctions/dolt_merge_base.go b/go/libraries/doltcore/sqle/dfunctions/dolt_merge_base.go index 766b82aadd..f074fa0021 100644 --- a/go/libraries/doltcore/sqle/dfunctions/dolt_merge_base.go +++ b/go/libraries/doltcore/sqle/dfunctions/dolt_merge_base.go @@ -84,11 +84,11 @@ func resolveRefSpecs(ctx *sql.Context, leftSpec, rightSpec string) (left, right sess := dsess.DSessFromSess(ctx.Session) dbName := ctx.GetCurrentDatabase() - dbData, ok := sess.GetDbData(dbName) + dbData, ok := sess.GetDbData(ctx, dbName) if !ok { return nil, nil, sql.ErrDatabaseNotFound.New(dbName) } - doltDB, ok := sess.GetDoltDB(dbName) + doltDB, ok := sess.GetDoltDB(ctx, dbName) if !ok { return nil, nil, sql.ErrDatabaseNotFound.New(dbName) } diff --git a/go/libraries/doltcore/sqle/dfunctions/dolt_reset.go b/go/libraries/doltcore/sqle/dfunctions/dolt_reset.go index 45756a9496..b977e523b9 100644 --- a/go/libraries/doltcore/sqle/dfunctions/dolt_reset.go +++ b/go/libraries/doltcore/sqle/dfunctions/dolt_reset.go @@ -40,7 +40,7 @@ func (d DoltResetFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) } dSess := dsess.DSessFromSess(ctx.Session) - dbData, ok := dSess.GetDbData(dbName) + dbData, ok := dSess.GetDbData(ctx, dbName) if !ok { return 1, fmt.Errorf("Could not load database %s", dbName) @@ -64,7 +64,7 @@ func (d DoltResetFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) } // Get all the needed roots. - roots, ok := dSess.GetRoots(dbName) + roots, ok := dSess.GetRoots(nil, dbName) if !ok { return 1, fmt.Errorf("Could not load database %s", dbName) } @@ -91,7 +91,10 @@ func (d DoltResetFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) } } - ws := dSess.WorkingSet(ctx, dbName) + ws, err := dSess.WorkingSet(ctx, dbName) + if err != nil { + return nil, err + } err = dSess.SetWorkingSet(ctx, dbName, ws.WithWorkingRoot(roots.Working).WithStagedRoot(roots.Staged), nil) if err != nil { return 1, err diff --git a/go/libraries/doltcore/sqle/dfunctions/hashof.go b/go/libraries/doltcore/sqle/dfunctions/hashof.go index e926bcb4bb..b2c9bb71b7 100644 --- a/go/libraries/doltcore/sqle/dfunctions/hashof.go +++ b/go/libraries/doltcore/sqle/dfunctions/hashof.go @@ -63,7 +63,7 @@ func (t *HashOf) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } dbName := ctx.GetCurrentDatabase() - ddb, ok := dsess.DSessFromSess(ctx.Session).GetDoltDB(dbName) + ddb, ok := dsess.DSessFromSess(ctx.Session).GetDoltDB(ctx, dbName) if !ok { return nil, sql.ErrDatabaseNotFound.New(dbName) } diff --git a/go/libraries/doltcore/sqle/dfunctions/merge.go b/go/libraries/doltcore/sqle/dfunctions/merge.go index 3e796d0adc..05bbb1b543 100644 --- a/go/libraries/doltcore/sqle/dfunctions/merge.go +++ b/go/libraries/doltcore/sqle/dfunctions/merge.go @@ -71,12 +71,12 @@ func (cf *MergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } dbName := sess.GetCurrentDatabase() - ddb, ok := sess.GetDoltDB(dbName) + ddb, ok := sess.GetDoltDB(nil, dbName) if !ok { return nil, sql.ErrDatabaseNotFound.New(dbName) } - root, ok := sess.GetRoot(dbName) + roots, ok := sess.GetRoots(ctx, dbName) if !ok { return nil, sql.ErrDatabaseNotFound.New(dbName) } @@ -86,7 +86,7 @@ func (cf *MergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - err = checkForUncommittedChanges(root, headRoot) + err = checkForUncommittedChanges(roots.Working, headRoot) if err != nil { return nil, err } diff --git a/go/libraries/doltcore/sqle/dfunctions/squash.go b/go/libraries/doltcore/sqle/dfunctions/squash.go index 552469e324..6f2bfa916f 100644 --- a/go/libraries/doltcore/sqle/dfunctions/squash.go +++ b/go/libraries/doltcore/sqle/dfunctions/squash.go @@ -48,12 +48,12 @@ func (s SquashFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } dbName := sess.GetCurrentDatabase() - ddb, ok := sess.GetDoltDB(dbName) + ddb, ok := sess.GetDoltDB(ctx, dbName) if !ok { return nil, sql.ErrDatabaseNotFound.New(dbName) } - root, ok := sess.GetRoot(dbName) + roots, ok := sess.GetRoots(ctx, dbName) if !ok { return nil, sql.ErrDatabaseNotFound.New(dbName) } @@ -63,7 +63,7 @@ func (s SquashFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - err = checkForUncommittedChanges(root, headRoot) + err = checkForUncommittedChanges(roots.Working, headRoot) if err != nil { return nil, err } diff --git a/go/libraries/doltcore/sqle/dsess/session.go b/go/libraries/doltcore/sqle/dsess/session.go index fed2677bf3..f0f0bcc738 100644 --- a/go/libraries/doltcore/sqle/dsess/session.go +++ b/go/libraries/doltcore/sqle/dsess/session.go @@ -128,7 +128,7 @@ type Session struct { Username string Email string // TODO: make this private again - DbStates map[string]*DatabaseSessionState + dbStates map[string]*DatabaseSessionState } type DatabaseSessionState struct { @@ -161,7 +161,7 @@ func DefaultSession() *Session { Session: sql.NewBaseSession(), Username: "", Email: "", - DbStates: make(map[string]*DatabaseSessionState), + dbStates: make(map[string]*DatabaseSessionState), } return sess } @@ -180,7 +180,7 @@ func NewSession(ctx *sql.Context, sqlSess sql.Session, username, email string, d Session: sqlSess, Username: username, Email: email, - DbStates: make(map[string]*DatabaseSessionState), + dbStates: make(map[string]*DatabaseSessionState), } for _, db := range dbs { @@ -206,12 +206,21 @@ func DSessFromSess(sess sql.Session) *Session { return sess.(*Session) } +func (sess *Session) LookupDbState(ctx *sql.Context, dbName string) (*DatabaseSessionState, bool, error) { + dbState, ok := sess.dbStates[dbName] + return dbState, ok, nil +} + // Flush flushes all changes sitting in edit sessions to the session root for the database named. This normally // happens automatically as part of statement execution, and is only necessary when the session is manually batched (as // for bulk SQL import) func (sess *Session) Flush(ctx *sql.Context, dbName string) error { - editSession := sess.DbStates[dbName].EditSession - newRoot, err := editSession.Flush(ctx) + dbState, _, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return err + } + + newRoot, err := dbState.EditSession.Flush(ctx) if err != nil { return err } @@ -232,7 +241,10 @@ func (sess *Session) StartTransaction(ctx *sql.Context, dbName string) (sql.Tran return DisabledTransaction{}, nil } - sessionState := sess.DbStates[dbName] + sessionState, _, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return nil, err + } wsRef := sessionState.WorkingSet.Ref() ws, err := sessionState.dbData.Ddb.ResolveWorkingSet(ctx, wsRef) @@ -257,7 +269,7 @@ func (sess *Session) StartTransaction(ctx *sql.Context, dbName string) (sql.Tran } func (sess *Session) newWorkingSetForHead(ctx *sql.Context, wsRef ref.WorkingSetRef, dbName string) (*doltdb.WorkingSet, error) { - dbData, _ := sess.GetDbData(dbName) + dbData, _ := sess.GetDbData(nil, dbName) headSpec, _ := doltdb.NewCommitSpec("HEAD") headRef, err := wsRef.ToHeadRef() @@ -287,11 +299,16 @@ func (sess *Session) CommitTransaction(ctx *sql.Context, dbName string, tx sql.T } } + dbState, _, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return err + } + if TransactionsDisabled(ctx) { return nil } - if !sess.DbStates[dbName].dirty { + if !dbState.dirty { return nil } @@ -301,7 +318,11 @@ func (sess *Session) CommitTransaction(ctx *sql.Context, dbName string, tx sql.T return nil } - dbstate, ok := sess.DbStates[dbName] + dbState, ok, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return err + } + // It's possible that this returns false if the user has created an in-Memory database. Moreover, // the analyzer will check for us whether a db exists or not. // TODO: fix this @@ -320,7 +341,7 @@ func (sess *Session) CommitTransaction(ctx *sql.Context, dbName string, tx sql.T // TODO: actual logging // logrus.Errorf("working root to commit is %s", dbstate.workingSet.WorkingRoot().DebugString(ctx, true)) - mergedWorkingSet, err := dtx.Commit(ctx, dbstate.WorkingSet) + mergedWorkingSet, err := dtx.Commit(ctx, dbState.WorkingSet) if err != nil { return err } @@ -337,7 +358,7 @@ func (sess *Session) CommitTransaction(ctx *sql.Context, dbName string, tx sql.T return err } - dbstate.dirty = false + dbState.dirty = false return nil } @@ -347,7 +368,10 @@ func (sess *Session) CommitToDolt( dbName string, props actions.CommitStagedProps, ) (*doltdb.Commit, error) { - sessionState := sess.DbStates[dbName] + sessionState, _, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return nil, err + } dbData := sessionState.dbData // TODO: this does several session state updates, and it really needs to just do one @@ -368,7 +392,10 @@ func (sess *Session) CommitToDolt( // repo state writer, so we're never persisting the new working set to disk like in a command line context. // TODO: fix this mess - ws := sess.WorkingSet(ctx, dbName) + ws, err := sess.WorkingSet(ctx, dbName) + if err != nil { + return nil, err + } // StartTransaction sets the working set for the session, and we want the one we previous had, not the one on disk // Updating the working set like this also updates the head commit and root info for the session tx, err := sess.StartTransaction(ctx, dbName) @@ -406,7 +433,10 @@ func (sess *Session) CreateDoltCommit(ctx *sql.Context, dbName string) error { return err } - sessionState := sess.DbStates[dbName] + sessionState, _, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return err + } roots := sessionState.GetRoots() roots, err = actions.StageAllTablesNoDocs(ctx, roots) @@ -435,7 +465,12 @@ func (sess *Session) RollbackTransaction(ctx *sql.Context, dbName string, tx sql return nil } - if !sess.DbStates[dbName].dirty { + dbState, ok, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return err + } + + if !dbState.dirty { return nil } @@ -444,12 +479,12 @@ func (sess *Session) RollbackTransaction(ctx *sql.Context, dbName string, tx sql return fmt.Errorf("expected a DoltTransaction") } - err := sess.SetRoot(ctx, dbName, dtx.startState.WorkingRoot()) + err = sess.SetRoot(ctx, dbName, dtx.startState.WorkingRoot()) if err != nil { return err } - sess.DbStates[dbName].dirty = false + dbState.dirty = false return nil } @@ -465,7 +500,12 @@ func (sess *Session) CreateSavepoint(ctx *sql.Context, savepointName, dbName str return fmt.Errorf("expected a DoltTransaction") } - dtx.CreateSavepoint(savepointName, sess.DbStates[dbName].GetRoots().Working) + dbState, ok, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return err + } + + dtx.CreateSavepoint(savepointName, dbState.GetRoots().Working) return nil } @@ -515,40 +555,16 @@ func (sess *Session) ReleaseSavepoint(ctx *sql.Context, savepointName, dbName st } // GetDoltDB returns the *DoltDB for a given database by name -func (sess *Session) GetDoltDB(dbName string) (*doltdb.DoltDB, bool) { - dbstate, ok := sess.DbStates[dbName] +func (sess *Session) GetDoltDB(ctx *sql.Context, dbName string) (*doltdb.DoltDB, bool) { + dbState, ok, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return nil, false + } if !ok { return nil, false } - return dbstate.dbData.Ddb, true -} - -func (sess *Session) GetDoltDBRepoStateWriter(dbName string) (env.RepoStateWriter, bool) { - d, ok := sess.DbStates[dbName] - if !ok { - return nil, false - } - - return d.dbData.Rsw, true -} - -func (sess *Session) GetDoltDBRepoStateReader(dbName string) (env.RepoStateReader, bool) { - d, ok := sess.DbStates[dbName] - if !ok { - return nil, false - } - - return d.dbData.Rsr, true -} - -func (sess *Session) GetDoltDBDocsReadWriter(dbName string) (env.DocsReadWriter, bool) { - d, ok := sess.DbStates[dbName] - if !ok { - return nil, false - } - - return d.dbData.Drw, true + return dbState.dbData.Ddb, true } func (sess *Session) GetDoltDbAutoIncrementTracker(dbName string) (globalstate.AutoIncrementTracker, bool) { @@ -569,33 +585,29 @@ func (sess *Session) GetDoltDbAutoIncrementTracker(dbName string) (globalstate.A return tracker, true } -func (sess *Session) GetDbData(dbName string) (env.DbData, bool) { - sessionState, ok := sess.DbStates[dbName] +func (sess *Session) GetDbData(ctx *sql.Context, dbName string) (env.DbData, bool) { + dbState, ok, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return env.DbData{}, false + } if !ok { return env.DbData{}, false } - return sessionState.dbData, true -} - -// GetRoot returns the current working *RootValue for a given database associated with the session -func (sess *Session) GetRoot(dbName string) (*doltdb.RootValue, bool) { - dbstate, ok := sess.DbStates[dbName] - if !ok { - return nil, false - } - - return dbstate.GetRoots().Working, true + return dbState.dbData, true } // GetRoots returns the current roots for a given database associated with the session -func (sess *Session) GetRoots(dbName string) (doltdb.Roots, bool) { - dbstate, ok := sess.DbStates[dbName] +func (sess *Session) GetRoots(ctx *sql.Context, dbName string) (doltdb.Roots, bool) { + dbState, ok, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return doltdb.Roots{}, false + } if !ok { return doltdb.Roots{}, false } - return dbstate.GetRoots(), true + return dbState.GetRoots(), true } // SetRoot sets a new root value for the session for the database named. This is the primary mechanism by which data @@ -604,7 +616,11 @@ func (sess *Session) GetRoots(dbName string) (doltdb.Roots, bool) { // Data changes contained in the |newRoot| aren't persisted until this session is committed. // TODO: rename to SetWorkingRoot func (sess *Session) SetRoot(ctx *sql.Context, dbName string, newRoot *doltdb.RootValue) error { - sessionState := sess.DbStates[dbName] + sessionState, _, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return err + } + if rootsEqual(sessionState.GetRoots().Working, newRoot) { return nil } @@ -616,7 +632,10 @@ func (sess *Session) SetRoot(ctx *sql.Context, dbName string, newRoot *doltdb.Ro func (sess *Session) setRoot(ctx *sql.Context, dbName string, newRoot *doltdb.RootValue) error { // logrus.Tracef("setting root value %s", newRoot.DebugString(ctx, true)) - sessionState := sess.DbStates[dbName] + sessionState, _, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return err + } h, err := newRoot.HashOf() if err != nil { @@ -645,7 +664,12 @@ func (sess *Session) setRoot(ctx *sql.Context, dbName string, newRoot *doltdb.Ro // Unlike setting the only the working root, this method always marks the database state dirty. func (sess *Session) SetRoots(ctx *sql.Context, dbName string, roots doltdb.Roots) error { // TODO: handle HEAD here? - workingSet := sess.DbStates[dbName].WorkingSet.WithWorkingRoot(roots.Working).WithStagedRoot(roots.Staged) + sessionState, _, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return err + } + + workingSet := sessionState.WorkingSet.WithWorkingRoot(roots.Working).WithStagedRoot(roots.Staged) return sess.SetWorkingSet(ctx, dbName, workingSet, nil) } @@ -662,7 +686,10 @@ func (sess *Session) SetWorkingSet( panic("attempted to set a nil working set for the session") } - sessionState := sess.DbStates[dbName] + sessionState, _, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return err + } sessionState.WorkingSet = ws if headRoot == nil && !sessionState.detachedHead { @@ -693,7 +720,7 @@ func (sess *Session) SetWorkingSet( sessionState.headRoot = headRoot } - err := sess.setSessionVarsForDb(ctx, dbName) + err = sess.setSessionVarsForDb(ctx, dbName) if err != nil { return err } @@ -716,7 +743,10 @@ func (sess *Session) SwitchWorkingSet( ctx *sql.Context, dbName string, wsRef ref.WorkingSetRef) error { - sessionState := sess.DbStates[dbName] + sessionState, _, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return err + } if sessionState.dirty { return fmt.Errorf("Cannot switch working set, session state is dirty. " + @@ -778,33 +808,46 @@ func (sess *Session) SwitchWorkingSet( return nil } -func (sess *Session) WorkingSet(ctx *sql.Context, dbName string) *doltdb.WorkingSet { - sessionState := sess.DbStates[dbName] - return sessionState.WorkingSet +func (sess *Session) WorkingSet(ctx *sql.Context, dbName string) (*doltdb.WorkingSet, error) { + sessionState, _, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return nil, err + } + return sessionState.WorkingSet, nil } func (sess *Session) GetTempTableRootValue(ctx *sql.Context, dbName string) (*doltdb.RootValue, bool) { - dbstate, ok := sess.DbStates[dbName] + dbState, ok, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return nil, false + } if !ok { return nil, false } - if dbstate.TempTableRoot == nil { + if dbState.TempTableRoot == nil { return nil, false } - return dbstate.TempTableRoot, true + return dbState.TempTableRoot, true } func (sess *Session) SetTempTableRoot(ctx *sql.Context, dbName string, newRoot *doltdb.RootValue) error { - sess.DbStates[dbName].TempTableRoot = newRoot - return sess.DbStates[dbName].TempTableEditSession.SetRoot(ctx, newRoot) + dbState, _, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return err + } + dbState.TempTableRoot = newRoot + return dbState.TempTableEditSession.SetRoot(ctx, newRoot) } // GetHeadCommit returns the parent commit of the current session. func (sess *Session) GetHeadCommit(ctx *sql.Context, dbName string) (*doltdb.Commit, error) { - dbState, dbFound := sess.DbStates[dbName] - if !dbFound { + dbState, ok, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return nil, err + } + if !ok { return nil, sql.ErrDatabaseNotFound.New(dbName) } @@ -822,7 +865,12 @@ func (sess *Session) SetSessionVariable(ctx *sql.Context, key string, value inte return err } - sess.DbStates[dbName].detachedHead = true + dbState, _, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return err + } + + dbState.detachedHead = true return nil } @@ -849,11 +897,11 @@ func (sess *Session) setForeignKeyChecksSessionVar(ctx *sql.Context, key string, intVal = convertedVal.(int64) } if intVal == 0 { - for _, dbState := range sess.DbStates { + for _, dbState := range sess.dbStates { dbState.EditSession.Props.ForeignKeyChecksDisabled = true } } else if intVal == 1 { - for _, dbState := range sess.DbStates { + for _, dbState := range sess.dbStates { dbState.EditSession.Props.ForeignKeyChecksDisabled = false } } else { @@ -869,12 +917,15 @@ func (sess *Session) setWorkingSessionVar(ctx *sql.Context, value interface{}, d return doltdb.ErrInvalidHash } - dbstate, dbFound := sess.DbStates[dbName] - if !dbFound { + dbState, ok, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return err + } + if !ok { return sql.ErrDatabaseNotFound.New(dbName) } - root, err := dbstate.dbData.Ddb.ReadRootValue(ctx, hash.Parse(valStr)) + root, err := dbState.dbData.Ddb.ReadRootValue(ctx, hash.Parse(valStr)) if errors.Is(doltdb.ErrNoRootValAtHash, err) { return nil } else if err != nil { @@ -885,8 +936,11 @@ func (sess *Session) setWorkingSessionVar(ctx *sql.Context, value interface{}, d } func (sess *Session) setHeadSessionVar(ctx *sql.Context, value interface{}, dbName string) error { - dbstate, dbFound := sess.DbStates[dbName] - if !dbFound { + dbState, ok, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return err + } + if !ok { return sql.ErrDatabaseNotFound.New(dbName) } @@ -901,19 +955,19 @@ func (sess *Session) setHeadSessionVar(ctx *sql.Context, value interface{}, dbNa return err } - cm, err := dbstate.dbData.Ddb.Resolve(ctx, cs, nil) + cm, err := dbState.dbData.Ddb.Resolve(ctx, cs, nil) if err != nil { return err } - dbstate.headCommit = cm + dbState.headCommit = cm root, err := cm.GetRootValue() if err != nil { return err } - dbstate.headRoot = root + dbState.headRoot = root err = sess.Session.SetSessionVariable(ctx, HeadKey(dbName), value) if err != nil { @@ -937,7 +991,7 @@ func (sess *Session) AddDB(ctx *sql.Context, dbState InitialDbState) error { defineSystemVariables(db.Name()) sessionState := &DatabaseSessionState{} - sess.DbStates[db.Name()] = sessionState + sess.dbStates[db.Name()] = sessionState // TODO: get rid of all repo state reader / writer stuff. Until we do, swap out the reader with one of our own, and // the writer with one that errors out @@ -991,19 +1045,30 @@ func (sess *Session) CreateTemporaryTablesRoot(ctx *sql.Context, dbName string, return err } - sess.DbStates[dbName].TempTableEditSession = editor.CreateTableEditSession(newRoot, editor.TableEditSessionProps{}) + dbState, _, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return err + } + dbState.TempTableEditSession = editor.CreateTableEditSession(newRoot, editor.TableEditSessionProps{}) return sess.SetTempTableRoot(ctx, dbName, newRoot) } // CWBHeadRef returns the branch ref for this session HEAD for the database named -func (sess *Session) CWBHeadRef(dbName string) (ref.DoltRef, error) { - return sess.DbStates[dbName].WorkingSet.Ref().ToHeadRef() +func (sess *Session) CWBHeadRef(ctx *sql.Context, dbName string) (ref.DoltRef, error) { + dbState, _, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return nil, err + } + return dbState.WorkingSet.Ref().ToHeadRef() } // setSessionVarsForDb updates the three session vars that track the value of the session root hashes func (sess *Session) setSessionVarsForDb(ctx *sql.Context, dbName string) error { - state := sess.DbStates[dbName] + state, _, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return err + } roots := state.GetRoots() h, err := roots.Working.HashOf() diff --git a/go/libraries/doltcore/sqle/dsess/session_state_adapter.go b/go/libraries/doltcore/sqle/dsess/session_state_adapter.go index 551b324afb..300fcbaa92 100755 --- a/go/libraries/doltcore/sqle/dsess/session_state_adapter.go +++ b/go/libraries/doltcore/sqle/dsess/session_state_adapter.go @@ -33,13 +33,13 @@ type SessionStateAdapter struct { } func (s SessionStateAdapter) UpdateStagedRoot(ctx context.Context, newRoot *doltdb.RootValue) error { - roots, _ := s.session.GetRoots(s.dbName) + roots, _ := s.session.GetRoots(nil, s.dbName) roots.Staged = newRoot return s.session.SetRoots(ctx.(*sql.Context), s.dbName, roots) } func (s SessionStateAdapter) UpdateWorkingRoot(ctx context.Context, newRoot *doltdb.RootValue) error { - roots, _ := s.session.GetRoots(s.dbName) + roots, _ := s.session.GetRoots(nil, s.dbName) roots.Working = newRoot return s.session.SetRoots(ctx.(*sql.Context), s.dbName, roots) } @@ -69,11 +69,11 @@ func NewSessionStateAdapter(session *Session, dbName string) SessionStateAdapter } func (s SessionStateAdapter) GetRoots(ctx context.Context) (doltdb.Roots, error) { - return s.session.DbStates[s.dbName].GetRoots(), nil + return s.session.dbStates[s.dbName].GetRoots(), nil } func (s SessionStateAdapter) CWBHeadRef() ref.DoltRef { - workingSet := s.session.DbStates[s.dbName].WorkingSet + workingSet := s.session.dbStates[s.dbName].WorkingSet headRef, err := workingSet.Ref().ToHeadRef() // TODO: fix this interface if err != nil { @@ -93,13 +93,13 @@ func (s SessionStateAdapter) CWBHeadSpec() *doltdb.CommitSpec { } func (s SessionStateAdapter) IsMergeActive(ctx context.Context) (bool, error) { - return s.session.DbStates[s.dbName].WorkingSet.MergeActive(), nil + return s.session.dbStates[s.dbName].WorkingSet.MergeActive(), nil } func (s SessionStateAdapter) GetMergeCommit(ctx context.Context) (*doltdb.Commit, error) { - return s.session.DbStates[s.dbName].WorkingSet.MergeState().Commit(), nil + return s.session.dbStates[s.dbName].WorkingSet.MergeState().Commit(), nil } func (s SessionStateAdapter) GetPreMergeWorking(ctx context.Context) (*doltdb.RootValue, error) { - return s.session.DbStates[s.dbName].WorkingSet.MergeState().PreMergeWorkingRoot(), nil + return s.session.dbStates[s.dbName].WorkingSet.MergeState().PreMergeWorkingRoot(), nil } diff --git a/go/libraries/doltcore/sqle/mergeable_indexes_setup_test.go b/go/libraries/doltcore/sqle/mergeable_indexes_setup_test.go index b580c16543..c761f7fc79 100644 --- a/go/libraries/doltcore/sqle/mergeable_indexes_setup_test.go +++ b/go/libraries/doltcore/sqle/mergeable_indexes_setup_test.go @@ -100,7 +100,7 @@ func setupMergeableIndexes(t *testing.T, tableName, insertQuery string) (*sqle.E engine.AddDatabase(mergeableDb) // Get an updated root to use for the rest of the test - root, _ = dsess.DSessFromSess(sqlCtx.Session).GetRoot(mergeableDb.Name()) + root, _ = dsess.DSessFromSess(sqlCtx.Session).GetRoot(nil, mergeableDb.Name()) return engine, dEnv, mergeableDb, []*indexTuple{ idxv1ToTuple, diff --git a/go/libraries/doltcore/sqle/table_editor.go b/go/libraries/doltcore/sqle/table_editor.go index 188e92d65c..02882632aa 100644 --- a/go/libraries/doltcore/sqle/table_editor.go +++ b/go/libraries/doltcore/sqle/table_editor.go @@ -57,7 +57,10 @@ var _ sql.RowInserter = (*sqlTableEditor)(nil) var _ sql.RowDeleter = (*sqlTableEditor)(nil) func newSqlTableEditor(ctx *sql.Context, t *WritableDoltTable) (*sqlTableEditor, error) { - sess := t.db.TableEditSession(ctx, t.IsTemporary()) + sess, err := t.db.TableEditSession(ctx, t.IsTemporary()) + if err != nil { + return nil, err + } tableEditor, err := sess.GetTableEditor(ctx, t.tableName, t.sch) if err != nil {