diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index a5ff67b8c5..e4b56bd384 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -33,6 +33,7 @@ import ( "github.com/dolthub/dolt/go/cmd/dolt/cli" "github.com/dolthub/dolt/go/cmd/dolt/commands" "github.com/dolthub/dolt/go/libraries/doltcore/env" + "github.com/dolthub/dolt/go/libraries/doltcore/ref" dsqle "github.com/dolthub/dolt/go/libraries/doltcore/sqle" _ "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dfunctions" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" @@ -222,12 +223,36 @@ func getDbStates(ctx context.Context, dbs []dsqle.SqlDatabase) ([]dsess.InitialD } func GetInitialDBStateWithDefaultBranch(ctx context.Context, db dsqle.SqlDatabase, branch string) (dsess.InitialDbState, error) { - ret, err := dsqle.GetInitialDBStateOnBranch(ctx, db, branch) + init, err := dsqle.GetInitialDBState(ctx, db) if err != nil { - err = fmt.Errorf("@@GLOBAL.dolt_default_branch (%s) is not a valid branch", branch) return dsess.InitialDbState{}, err } - return ret, nil + + ddb := init.DbData.Ddb + r := ref.NewBranchRef(branch) + + head, err := ddb.ResolveCommitRef(ctx, r) + if err != nil { + init.Err = fmt.Errorf("@@GLOBAL.dolt_default_branch (%s) is not a valid branch", branch) + } else { + init.Err = nil + } + init.HeadCommit = head + + if init.Err == nil { + workingSetRef, err := ref.WorkingSetRefForHead(r) + if err != nil { + return dsess.InitialDbState{}, err + } + + ws, err := init.DbData.Ddb.ResolveWorkingSet(ctx, workingSetRef) + if err != nil { + return dsess.InitialDbState{}, err + } + init.WorkingSet = ws + } + + return init, nil } func dsqleDBsAsSqlDBs(dbs []dsqle.SqlDatabase) []sql.Database { diff --git a/go/libraries/doltcore/sqle/database.go b/go/libraries/doltcore/sqle/database.go index 098a42fa37..3272ef68f8 100644 --- a/go/libraries/doltcore/sqle/database.go +++ b/go/libraries/doltcore/sqle/database.go @@ -28,7 +28,6 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/doltcore/env/actions/commitwalk" - "github.com/dolthub/dolt/go/libraries/doltcore/ref" "github.com/dolthub/dolt/go/libraries/doltcore/row" "github.com/dolthub/dolt/go/libraries/doltcore/schema" "github.com/dolthub/dolt/go/libraries/doltcore/schema/alterschema" @@ -221,49 +220,6 @@ func GetInitialDBState(ctx context.Context, db SqlDatabase) (dsess.InitialDbStat }, nil } -// GetInitialDBStateOnBranch returns the InitialDbState for |db|, but on the -// given branch, instead of on the default branch from the RepoStateReader. -func GetInitialDBStateOnBranch(ctx context.Context, db SqlDatabase, branch string) (dsess.InitialDbState, error) { - rsr := db.DbData().Rsr - ddb := db.DbData().Ddb - - r := ref.NewBranchRef(branch) - - headCommit, err := ddb.ResolveCommitRef(ctx, r) - if err != nil { - return dsess.InitialDbState{}, err - } - - wsRef, err := ref.WorkingSetRefForHead(r) - if err != nil { - return dsess.InitialDbState{}, err - } - - ws, err := ddb.ResolveWorkingSet(ctx, wsRef) - if err != nil { - return dsess.InitialDbState{}, err - } - - remotes, err := rsr.GetRemotes() - if err != nil { - return dsess.InitialDbState{}, err - } - - branches, err := rsr.GetBranches() - if err != nil { - return dsess.InitialDbState{}, err - } - - return dsess.InitialDbState{ - Db: db, - HeadCommit: headCommit, - WorkingSet: ws, - DbData: db.DbData(), - Remotes: remotes, - Branches: branches, - }, nil -} - // Name returns the name of this database, set at creation time. func (db Database) Name() string { return db.name diff --git a/go/libraries/doltcore/sqle/dsess/session.go b/go/libraries/doltcore/sqle/dsess/session.go index 67bbb8d192..d817d430dd 100644 --- a/go/libraries/doltcore/sqle/dsess/session.go +++ b/go/libraries/doltcore/sqle/dsess/session.go @@ -153,6 +153,11 @@ type DatabaseSessionState struct { readReplica *env.Remote TempTableRoot *doltdb.RootValue TempTableEditSession *editor.TableEditSession + + // Same as InitialDbState.Err, this signifies that this + // DatabaseSessionState is invalid. LookupDbState returning a + // DatabaseSessionState with Err != nil will return that err. + Err error } func (d DatabaseSessionState) GetRoots() doltdb.Roots { @@ -217,12 +222,10 @@ func NewSession(ctx *sql.Context, sqlSess sql.Session, pro RevisionDatabaseProvi } for _, db := range dbs { - if db.Err == nil { - err := sess.AddDB(ctx, db) + err := sess.AddDB(ctx, db) - if err != nil { - return nil, err - } + if err != nil { + return nil, err } } @@ -243,7 +246,7 @@ func DSessFromSess(sess sql.Session) *Session { // LookupDbState returns the session state for the database named // TODO(zachmu) get rid of bool return param, use a not found error or similar -func (sess *Session) LookupDbState(ctx *sql.Context, dbName string) (*DatabaseSessionState, bool, error) { +func (sess *Session) lookupDbState(ctx *sql.Context, dbName string) (*DatabaseSessionState, bool, error) { dbState, ok := sess.dbStates[dbName] if ok { return dbState, ok, nil @@ -251,7 +254,7 @@ func (sess *Session) LookupDbState(ctx *sql.Context, dbName string) (*DatabaseSe init, err := sess.provider.RevisionDbState(ctx, dbName) if err != nil { - return nil, ok, err + return nil, false, err } // TODO: this could potentially add a |sess.dbStates| entry @@ -263,9 +266,17 @@ func (sess *Session) LookupDbState(ctx *sql.Context, dbName string) (*DatabaseSe } dbState, ok = sess.dbStates[dbName] if !ok { - return nil, ok, sql.ErrDatabaseNotFound.New(dbName) + return nil, false, sql.ErrDatabaseNotFound.New(dbName) } - return dbState, ok, nil + return dbState, true, nil +} + +func (sess *Session) LookupDbState(ctx *sql.Context, dbName string) (*DatabaseSessionState, bool, error) { + s, ok, err := sess.lookupDbState(ctx, dbName) + if ok && s.Err != nil { + return nil, false, s.Err + } + return s, ok, err } // Flush flushes all changes sitting in edit sessions to the session root for the database named. This normally @@ -1060,16 +1071,13 @@ func (sess *Session) SetSessionVarDirectly(ctx *sql.Context, key string, value i // HasDB returns true if |sess| is tracking state for this database. func (sess *Session) HasDB(ctx *sql.Context, dbName string) bool { - _, ok, err := sess.LookupDbState(ctx, dbName) + _, ok, err := sess.lookupDbState(ctx, dbName) return ok && err == nil } // AddDB adds the database given to this session. This establishes a starting root value for this session, as well as // other state tracking metadata. func (sess *Session) AddDB(ctx *sql.Context, dbState InitialDbState) error { - if dbState.Err != nil { - return dbState.Err - } db := dbState.Db defineSystemVariables(db.Name()) @@ -1089,7 +1097,9 @@ func (sess *Session) AddDB(ctx *sql.Context, dbState InitialDbState) error { sessionState.EditSession = editor.CreateTableEditSession(nil, editOpts) // WorkingSet is nil in the case of a read only, detached head DB - if dbState.WorkingSet != nil { + if dbState.Err != nil { + sessionState.Err = dbState.Err + } else if dbState.WorkingSet != nil { sessionState.WorkingSet = dbState.WorkingSet workingRoot := dbState.WorkingSet.WorkingRoot() // logrus.Tracef("working root intialized to %s", workingRoot.DebugString(ctx, false)) @@ -1114,7 +1124,10 @@ func (sess *Session) AddDB(ctx *sql.Context, dbState InitialDbState) error { // After setting the initial root we have no state to commit sessionState.dirty = false - return sess.setSessionVarsForDb(ctx, db.Name()) + if sessionState.Err == nil { + return sess.setSessionVarsForDb(ctx, db.Name()) + } + return nil } // CreateTemporaryTablesRoot creates an empty root value and a table edit session for the purposes of storing @@ -1146,7 +1159,7 @@ func (sess *Session) CWBHeadRef(ctx *sql.Context, dbName string) (ref.DoltRef, e // setSessionVarsForDb updates the three session vars that track the value of the session root hashes func (sess *Session) setSessionVarsForDb(ctx *sql.Context, dbName string) error { - state, _, err := sess.LookupDbState(ctx, dbName) + state, _, err := sess.lookupDbState(ctx, dbName) if err != nil { return err } diff --git a/integration-tests/bats/deleted-branches.bats b/integration-tests/bats/deleted-branches.bats index 876bc218cd..018ba9e57c 100644 --- a/integration-tests/bats/deleted-branches.bats +++ b/integration-tests/bats/deleted-branches.bats @@ -3,6 +3,8 @@ load $BATS_TEST_DIRNAME/helper/common.bash load $BATS_TEST_DIRNAME/helper/query-server-common.bash setup() { + skiponwindows "Has dependencies that are missing on the Jenkins Windows installation." + setup_common } @@ -11,16 +13,26 @@ teardown() { stop_sql_server } -@test "deleted-branches: can checkout existing branch after checked out branch is deleted" { +make_it() { + dolt sql -q 'create table test (id int primary key);' + dolt add . + dolt commit -m 'initial commit' + dolt branch -c main to_keep +} + +@test "deleted-branches: can checkout existing branch after checked out branch is deleted" { + make_it + dolt sql -q 'delete from dolt_branches where name = "main"' + + dolt branch -av + dolt checkout to_keep } @test "deleted-branches: can SQL connect with dolt_default_branch set to existing branch when checked out branch is deleted" { - skiponwindows "Has dependencies that are missing on the Jenkins Windows installation." - - dolt branch -c main to_keep + make_it start_sql_server "dolt_repo_$$" @@ -28,22 +40,80 @@ teardown() { server_query "dolt_repo_$$" 1 'delete from dolt_branches where name = "main"' "" - server_query "dolt_repo_$$" 1 "SELECT 2+2 FROM dual" "2+2\n4" + server_query "dolt_repo_$$" 1 "SELECT * FROM test" "id\n" } @test "deleted-branches: can SQL connect with existing branch revision specifier when checked out branch is deleted" { - skiponwindows "Has dependencies that are missing on the Jenkins Windows installation." - - dolt branch -c main to_keep + make_it start_sql_server "dolt_repo_$$" server_query "dolt_repo_$$" 1 'delete from dolt_branches where name = "main"' "" # Against the default branch it fails - run server_query "dolt_repo_$$" 1 "SELECT 2+2 FROM dual" "2+2\n4" + run server_query "dolt_repo_$$" 1 "SELECT * FROM test" "id\n" [ "$status" -eq 1 ] || fail "expected query against the default branch, which was deleted, to fail" # Against to_keep it succeeds - server_query "dolt_repo_$$/to_keep" 1 "SELECT 2+2 FROM dual" "2+2\n4" + server_query "dolt_repo_$$/to_keep" 1 "SELECT * FROM test" "id\n" +} + +@test "deleted-branches: can SQL connect with existing branch revision specifier when dolt_default_branch is invalid" { + make_it + + start_sql_server "dolt_repo_$$" + + server_query "dolt_repo_$$" 1 "SET @@GLOBAL.dolt_default_branch = 'this_branch_does_not_exist'" + + # Against the default branch it fails + run server_query "dolt_repo_$$" 1 "SELECT * FROM test" "" + [ "$status" -eq 1 ] # || (echo "expected query against the default branch, which does not exist, to fail"; exit 1) + + # Against main, which exists it succeeds + server_query "dolt_repo_$$/main" 1 "SELECT * FROM test" "id\n" +} + +@test "deleted-branches: can DOLT_CHECKOUT on SQL connection with existing branch revision specifier when dolt_default_branch is invalid" { + make_it + + start_sql_server "dolt_repo_$$" + + server_query "dolt_repo_$$" 1 "SET @@GLOBAL.dolt_default_branch = 'this_branch_does_not_exist'" + + multi_query "dolt_repo_$$/main" 1 " +SELECT * FROM test; +SELECT DOLT_CHECKOUT('to_keep'); +SELECT * FROM test;" +} + +@test "deleted-branches: can DOLT_CHECKOUT on SQL connection with existing branch revision specifier set to existing branch when checked out branch is deleted" { + make_it + + dolt branch -c to_keep to_checkout + + start_sql_server "dolt_repo_$$" + + server_query "dolt_repo_$$" 1 'delete from dolt_branches where name = "main"' "" + + multi_query "dolt_repo_$$/to_keep" 1 " +SELECT * FROM test; +SELECT DOLT_CHECKOUT('to_checkout'); +SELECT * FROM test;" +} + +@test "deleted-branches: can DOLT_CHECKOUT on SQL connecttion with dolt_default_branch set to existing branch when checked out branch is deleted" { + make_it + + dolt branch -c to_keep to_checkout + + start_sql_server "dolt_repo_$$" + + server_query "dolt_repo_$$" 1 "SET @@GLOBAL.dolt_default_branch = 'to_keep'" + + server_query "dolt_repo_$$" 1 'delete from dolt_branches where name = "main"' "" + + multi_query "dolt_repo_$$" 1 " +SELECT * FROM test; +SELECT DOLT_CHECKOUT('to_checkout'); +SELECT * FROM test;" }