diff --git a/go/cmd/dolt/commands/clean.go b/go/cmd/dolt/commands/clean.go index 779f9b50c8..235ab60da2 100644 --- a/go/cmd/dolt/commands/clean.go +++ b/go/cmd/dolt/commands/clean.go @@ -84,7 +84,7 @@ func (cmd CleanCmd) Exec(ctx context.Context, commandStr string, args []string, return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } - roots, err = actions.CleanUntracked(ctx, roots, apr.Args, apr.Contains(DryrunCleanParam)) + roots, err = actions.CleanUntracked(ctx, roots, apr.Args, apr.Contains(DryrunCleanParam), false) if err != nil { return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } diff --git a/go/libraries/doltcore/doltdb/doltdb.go b/go/libraries/doltcore/doltdb/doltdb.go index 17a809d4a1..7eb4502ced 100644 --- a/go/libraries/doltcore/doltdb/doltdb.go +++ b/go/libraries/doltcore/doltdb/doltdb.go @@ -374,6 +374,35 @@ func (ddb *DoltDB) ResolveCommitRef(ctx context.Context, ref ref.DoltRef) (*Comm return NewCommit(ctx, ddb.vrw, ddb.ns, commitVal) } +// ResolveBranchRoots returns the Roots for the branch given +func (ddb *DoltDB) ResolveBranchRoots(ctx context.Context, branch ref.BranchRef) (Roots, error) { + commitRef, err := ddb.ResolveCommitRef(ctx, branch) + if err != nil { + return Roots{}, err + } + + headRoot, err := commitRef.GetRootValue(ctx) + if err != nil { + return Roots{}, err + } + + wsRef, err := ref.WorkingSetRefForHead(branch) + if err != nil { + return Roots{}, err + } + + ws, err := ddb.ResolveWorkingSet(ctx, wsRef) + if err != nil { + return Roots{}, err + } + + return Roots{ + Head: headRoot, + Working: ws.WorkingRoot(), + Staged: ws.StagedRoot(), + }, nil +} + // ResolveTag takes a TagRef and returns the corresponding Tag object. func (ddb *DoltDB) ResolveTag(ctx context.Context, tagRef ref.TagRef) (*Tag, error) { ds, err := ddb.db.GetDataset(ctx, tagRef.String()) diff --git a/go/libraries/doltcore/env/actions/branch.go b/go/libraries/doltcore/env/actions/branch.go index 0cef5d2128..9e6198fc83 100644 --- a/go/libraries/doltcore/env/actions/branch.go +++ b/go/libraries/doltcore/env/actions/branch.go @@ -327,9 +327,9 @@ func createBranch(ctx context.Context, dbData env.DbData, newBranch, startingPoi return CreateBranchOnDB(ctx, dbData.Ddb, newBranch, startingPoint, force, dbData.Rsr.CWBHeadRef()) } -// RootsForBranch returns the roots needed for a branch checkout. |roots.Head| should be the pre-checkout head. The +// rootsForBranch returns the roots needed for a branch checkout. |roots.Head| should be the pre-checkout head. The // returned roots struct has |Head| set to |branchRoot|. -func RootsForBranch(ctx context.Context, roots doltdb.Roots, branchRoot *doltdb.RootValue, force bool) (doltdb.Roots, error) { +func rootsForBranch(ctx context.Context, roots doltdb.Roots, branchRoot *doltdb.RootValue, force bool) (doltdb.Roots, error) { conflicts := set.NewStrSet([]string{}) if roots.Head == nil { roots.Working = branchRoot @@ -352,12 +352,12 @@ func RootsForBranch(ctx context.Context, roots doltdb.Roots, branchRoot *doltdb. return doltdb.Roots{}, CheckoutWouldOverwrite{conflicts.AsSlice()} } - roots.Working, err = overwriteRoot(ctx, branchRoot, wrkTblHashes) + roots.Working, err = writeTableHashes(ctx, branchRoot, wrkTblHashes) if err != nil { return doltdb.Roots{}, err } - roots.Staged, err = overwriteRoot(ctx, branchRoot, stgTblHashes) + roots.Staged, err = writeTableHashes(ctx, branchRoot, stgTblHashes) if err != nil { return doltdb.Roots{}, err } @@ -383,44 +383,90 @@ func CheckoutBranch(ctx context.Context, dEnv *env.DoltEnv, brName string, force return doltdb.ErrAlreadyOnBranch } - branchRoot, err := BranchRoot(ctx, db, brName) + branchHead, err := branchHeadRoot(ctx, db, brName) if err != nil { return err } + workingSetExists := true initialWs, err := dEnv.WorkingSet(ctx) + if err == doltdb.ErrWorkingSetNotFound { + // ignore, but don't reset the working set + workingSetExists = false + } else if err != nil { + return err + } - if err != nil { - // working set does not exist, ignore error and skip the compatibility check below - } else if !force { - err = checkWorkingSetCompatibility(ctx, dEnv, branchRef, initialWs) + if !force { + if checkoutWouldStompWorkingSetChanges(ctx, dEnv, branchRef) { + return ErrWorkingSetsOnBothBranches + } + } + + initialRoots, err := dEnv.Roots(ctx) + + // roots will be empty/nil if the working set is not set (working set is not set if the current branch was deleted) + if errors.Is(err, doltdb.ErrBranchNotFound) || errors.Is(err, doltdb.ErrWorkingSetNotFound) { + workingSetExists = false + } else if err != nil { + return err + } + + hasChanges := false + if workingSetExists { + hasChanges, _, _, err = rootHasUncommittedChanges(initialRoots) if err != nil { return err } } - shouldResetWorkingSet := true - initialRoots, err := dEnv.Roots(ctx) - - // roots will be empty/nil if the working set is not set (working set is not set if the current branch was deleted) - if errors.Is(err, doltdb.ErrBranchNotFound) || errors.Is(err, doltdb.ErrWorkingSetNotFound) { - initialRoots, _ = dEnv.RecoveryRoots(ctx) - shouldResetWorkingSet = false - } else if err != nil { - return err + // Only if the current working set has uncommitted changes do we carry them forward to the branch being checked out. + // If this is the case, then the destination branch must *not* have any uncommitted changes, as checked by + // checkoutWouldStompWorkingSetChanges + if hasChanges { + err = transferWorkingChanges(ctx, dEnv, initialRoots, branchHead, branchRef, force) + if err != nil { + return err + } + } else { + err = dEnv.RepoStateWriter().SetCWBHeadRef(ctx, ref.MarshalableRef{Ref: branchRef}) + if err != nil { + return err + } } - newRoots, err := RootsForBranch(ctx, initialRoots, branchRoot, force) + if workingSetExists && hasChanges { + err = cleanOldWorkingSet(ctx, dEnv, initialRoots, initialHeadRef, initialWs) + if err != nil { + return err + } + } + + return nil +} + +func transferWorkingChanges( + ctx context.Context, + dEnv *env.DoltEnv, + initialRoots doltdb.Roots, + branchHead *doltdb.RootValue, + branchRef ref.BranchRef, + force bool, +) error { + newRoots, err := rootsForBranch(ctx, initialRoots, branchHead, force) if err != nil { return err } + // important to not update the checked out branch until after we have done the error checking above, otherwise we + // potentially leave the client in a bad state err = dEnv.RepoStateWriter().SetCWBHeadRef(ctx, ref.MarshalableRef{Ref: branchRef}) if err != nil { return err } ws, err := dEnv.WorkingSet(ctx) + // For backwards compatibility we support the branch not having a working set, but generally speaking it already // should have one if err == doltdb.ErrWorkingSetNotFound { @@ -438,20 +484,71 @@ func CheckoutBranch(ctx context.Context, dEnv *env.DoltEnv, brName string, force return err } - if shouldResetWorkingSet { - // reset the source branch's working set to the branch head, leaving the source branch unchanged - err = ResetHard(ctx, dEnv, "", initialRoots, initialHeadRef, initialWs) - if err != nil { - return err - } - } - return nil } -// BranchRoot returns the root value at the branch with the name given -// TODO: this belongs in DoltDB, maybe -func BranchRoot(ctx context.Context, db *doltdb.DoltDB, brName string) (*doltdb.RootValue, error) { +// cleanOldWorkingSet resets the source branch's working set to the branch head, leaving the source branch unchanged +func cleanOldWorkingSet( + ctx context.Context, + dEnv *env.DoltEnv, + initialRoots doltdb.Roots, + initialHeadRef ref.DoltRef, + initialWs *doltdb.WorkingSet, +) error { + // reset the source branch's working set to the branch head, leaving the source branch unchanged + err := ResetHard(ctx, dEnv, "", initialRoots, initialHeadRef, initialWs) + if err != nil { + return err + } + + // Annoyingly, after the ResetHard above we need to get all the roots again, because the working set has changed + cm, err := dEnv.DoltDB.ResolveCommitRef(ctx, initialHeadRef) + if err != nil { + return err + } + + headRoot, err := cm.ResolveRootValue(ctx) + if err != nil { + return err + } + + workingSet, err := dEnv.DoltDB.ResolveWorkingSet(ctx, initialWs.Ref()) + if err != nil { + return err + } + + resetRoots := doltdb.Roots{ + Head: headRoot, + Working: workingSet.WorkingRoot(), + Staged: workingSet.StagedRoot(), + } + + // we also have to do a clean, because we the ResetHard won't touch any new tables (tables only in the working set) + newRoots, err := CleanUntracked(ctx, resetRoots, []string{}, false, true) + if err != nil { + return err + } + + h, err := workingSet.HashOf() + if err != nil { + return err + } + + err = dEnv.DoltDB.UpdateWorkingSet( + ctx, + initialWs.Ref(), + initialWs.WithWorkingRoot(newRoots.Working).WithStagedRoot(newRoots.Staged).ClearMerge(), + h, + dEnv.NewWorkingSetMeta("reset hard"), + ) + if err != nil { + return err + } + return nil +} + +// branchHeadRoot returns the root value at the branch head with the name given +func branchHeadRoot(ctx context.Context, db *doltdb.DoltDB, brName string) (*doltdb.RootValue, error) { cs, err := doltdb.NewCommitSpec(brName) if err != nil { return nil, doltdb.RootValueUnreadable{RootType: doltdb.HeadRoot, Cause: err} @@ -541,9 +638,9 @@ func moveModifiedTables(ctx context.Context, oldRoot, newRoot, changedRoot *dolt return resultMap, nil } -// overwriteRoot writes new table hash values for the root given and returns it. +// writeTableHashes writes new table hash values for the root given and returns it. // This is an inexpensive and convenient way of replacing all the tables at once. -func overwriteRoot(ctx context.Context, head *doltdb.RootValue, tblHashes map[string]hash.Hash) (*doltdb.RootValue, error) { +func writeTableHashes(ctx context.Context, head *doltdb.RootValue, tblHashes map[string]hash.Hash) (*doltdb.RootValue, error) { names, err := head.GetTableNames(ctx) if err != nil { return nil, err @@ -575,53 +672,55 @@ func overwriteRoot(ctx context.Context, head *doltdb.RootValue, tblHashes map[st return head, nil } -// checkWorkingSetCompatibility checks that the current working set is "compatible" with the dest working set. +// checkoutWouldStompWorkingSetChanges checks that the current working set is "compatible" with the dest working set. // This means that if both working sets are present (ie there are changes on both source and dest branches), // we check if the changes are identical before allowing a clobbering checkout. // Working set errors are ignored by this function, because they are properly handled elsewhere. -func checkWorkingSetCompatibility(ctx context.Context, dEnv *env.DoltEnv, branchRef ref.BranchRef, currentWs *doltdb.WorkingSet) error { - db := dEnv.DoltDB - destWsRef, err := ref.WorkingSetRefForHead(branchRef) +func checkoutWouldStompWorkingSetChanges(ctx context.Context, dEnv *env.DoltEnv, branchRef ref.BranchRef) bool { + sourceRoots, err := dEnv.Roots(ctx) if err != nil { - // dest working set does not exist, skip check - return nil - } - destWs, err := db.ResolveWorkingSet(ctx, destWsRef) - if err != nil { - // dest working set does not resolve, skip check - return nil + return false } - sourceHasChanges, sourceHash, err := detectWorkingSetChanges(currentWs) + destRoots, err := dEnv.DoltDB.ResolveBranchRoots(ctx, branchRef) if err != nil { - // error detecting source changes, skip check - return nil + return false } - destHasChanges, destHash, err := detectWorkingSetChanges(destWs) - if err != nil { - // error detecting dest changes, skip check - return nil - } - areHashesEqual := sourceHash.Equal(destHash) - if sourceHasChanges && destHasChanges && !areHashesEqual { - return ErrWorkingSetsOnBothBranches + sourceHasChanges, sourceWorkingHash, sourceStagedHash, err := rootHasUncommittedChanges(sourceRoots) + if err != nil { + return false } - return nil + + destHasChanges, destWorkingHash, destStagedHash, err := rootHasUncommittedChanges(destRoots) + if err != nil { + return false + } + + // This is a stomping checkout operation if both the source and dest have uncommitted changes, and they're not the + // same uncommitted changes + return sourceHasChanges && destHasChanges && (sourceWorkingHash != destWorkingHash || sourceStagedHash != destStagedHash) } -// detectWorkingSetChanges returns a boolean indicating whether the working set has changes, and a hash of the changes -func detectWorkingSetChanges(ws *doltdb.WorkingSet) (hasChanges bool, wrHash hash.Hash, err error) { - wrHash, err = ws.WorkingRoot().HashOf() +// rootHasUncommittedChanges returns whether the roots given have uncommitted changes, and the hashes of the working and staged roots +func rootHasUncommittedChanges(roots doltdb.Roots) (hasChanges bool, workingHash hash.Hash, stagedHash hash.Hash, err error) { + headHash, err := roots.Head.HashOf() if err != nil { - return false, hash.Hash{}, err + return false, hash.Hash{}, hash.Hash{}, err } - srHash, err := ws.StagedRoot().HashOf() + + workingHash, err = roots.Working.HashOf() if err != nil { - return false, hash.Hash{}, err + return false, hash.Hash{}, hash.Hash{}, err } - hasChanges = !wrHash.Equal(srHash) - return hasChanges, wrHash, nil + + stagedHash, err = roots.Staged.HashOf() + if err != nil { + return false, hash.Hash{}, hash.Hash{}, err + } + + hasChanges = workingHash != stagedHash || stagedHash != headHash + return hasChanges, workingHash, stagedHash, nil } func IsBranch(ctx context.Context, ddb *doltdb.DoltDB, str string) (bool, error) { diff --git a/go/libraries/doltcore/env/actions/reset.go b/go/libraries/doltcore/env/actions/reset.go index 02a3c99ee2..0a601ff365 100644 --- a/go/libraries/doltcore/env/actions/reset.go +++ b/go/libraries/doltcore/env/actions/reset.go @@ -282,7 +282,7 @@ func IsValidRef(ctx context.Context, cSpecStr string, ddb *doltdb.DoltDB, rsr en // CleanUntracked deletes untracked tables from the working root. // Evaluates untracked tables as: all working tables - all staged tables. -func CleanUntracked(ctx context.Context, roots doltdb.Roots, tables []string, dryrun bool) (doltdb.Roots, error) { +func CleanUntracked(ctx context.Context, roots doltdb.Roots, tables []string, dryrun bool, force bool) (doltdb.Roots, error) { untrackedTables := make(map[string]struct{}) var err error @@ -318,7 +318,7 @@ func CleanUntracked(ctx context.Context, roots doltdb.Roots, tables []string, dr toDelete = append(toDelete, t) } - newRoot, err = newRoot.RemoveTables(ctx, false, false, toDelete...) + newRoot, err = newRoot.RemoveTables(ctx, force, force, toDelete...) if err != nil { return doltdb.Roots{}, fmt.Errorf("failed to remove tables; %w", err) } diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_clean.go b/go/libraries/doltcore/sqle/dprocedures/dolt_clean.go index ccc379345b..8f24d1726b 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_clean.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_clean.go @@ -57,7 +57,7 @@ func doDoltClean(ctx *sql.Context, args []string) (int, error) { return 1, fmt.Errorf("Could not load database %s", dbName) } - roots, err = actions.CleanUntracked(ctx, roots, apr.Args, apr.ContainsAll(cli.DryRunFlag)) + roots, err = actions.CleanUntracked(ctx, roots, apr.Args, apr.ContainsAll(cli.DryRunFlag), false) if err != nil { return 1, fmt.Errorf("failed to clean; %w", err) } diff --git a/integration-tests/bats/checkout.bats b/integration-tests/bats/checkout.bats index 352f3dbeb1..af408aff94 100755 --- a/integration-tests/bats/checkout.bats +++ b/integration-tests/bats/checkout.bats @@ -66,42 +66,73 @@ SQL dolt sql <