diff --git a/go/cmd/dolt/cli/arg_parser_helpers.go b/go/cmd/dolt/cli/arg_parser_helpers.go index ec319ce3fc..73d464e129 100644 --- a/go/cmd/dolt/cli/arg_parser_helpers.go +++ b/go/cmd/dolt/cli/arg_parser_helpers.go @@ -141,3 +141,12 @@ func CreateRevertArgParser() *argparser.ArgParser { return ap } + +func CreatePullArgParser() *argparser.ArgParser { + ap := argparser.NewArgParser() + ap.SupportsFlag(SquashParam, "", "Merges changes to the working set without updating the commit history") + ap.SupportsFlag(NoFFParam, "", "Create a merge commit even when the merge resolves as a fast-forward.") + ap.SupportsFlag(ForceFlag, "f", "Ignores any foreign key warnings and proceeds with the commit.") + + return ap +} diff --git a/go/cmd/dolt/commands/fetch.go b/go/cmd/dolt/commands/fetch.go index 9d4c747393..13de126a87 100644 --- a/go/cmd/dolt/commands/fetch.go +++ b/go/cmd/dolt/commands/fetch.go @@ -121,7 +121,7 @@ func getRefSpecs(args []string, dEnv *env.DoltEnv, remotes map[string]env.Remote if len(args) != 0 { rs, verr = parseRSFromArgs(remName, args) } else { - rs, err = dEnv.GetRefSpecs(remName) + rs, err = env.GetRefSpecs(dEnv.RepoStateReader(), remName) if err != nil { verr = errhand.VerboseErrorFromError(err) } @@ -203,7 +203,7 @@ func fetchRefSpecs(ctx context.Context, mode ref.UpdateMode, dEnv *env.DoltEnv, if remoteTrackRef != nil { rsSeen = true - srcDBCommit, err := actions.FetchRemoteBranch(ctx, dEnv, rem, srcDB, dEnv.DoltDB, branchRef, remoteTrackRef, runProgFuncs, stopProgFuncs) + srcDBCommit, err := actions.FetchRemoteBranch(ctx, dEnv.TempTableFilesDir(), rem, srcDB, dEnv.DoltDB, branchRef, remoteTrackRef, runProgFuncs, stopProgFuncs) if err != nil { return errhand.VerboseErrorFromError(err) @@ -242,7 +242,7 @@ func fetchRefSpecs(ctx context.Context, mode ref.UpdateMode, dEnv *env.DoltEnv, } } - err = actions.FetchFollowTags(ctx, dEnv, srcDB, dEnv.DoltDB, runProgFuncs, stopProgFuncs) + err = actions.FetchFollowTags(ctx, dEnv.TempTableFilesDir(), srcDB, dEnv.DoltDB, runProgFuncs, stopProgFuncs) if err != nil { return errhand.VerboseErrorFromError(err) diff --git a/go/cmd/dolt/commands/merge.go b/go/cmd/dolt/commands/merge.go index 47087acd89..2d32261873 100644 --- a/go/cmd/dolt/commands/merge.go +++ b/go/cmd/dolt/commands/merge.go @@ -158,7 +158,17 @@ func (cmd MergeCmd) Exec(ctx context.Context, commandStr string, args []string, msg = m } - mergeSpec, ok, err := merge.ParseMergeSpec(ctx, dEnv, msg, commitSpecStr, apr.Contains(cli.SquashParam), apr.Contains(cli.NoFFParam), apr.Contains(cli.ForceFlag), t) + roots, err := dEnv.Roots(ctx) + if err != nil { + return handleCommitErr(ctx, dEnv, err, usage) + } + + name, email, err := env.GetNameAndEmail(dEnv.Config) + if err != nil { + return handleCommitErr(ctx, dEnv, err, usage) + } + + spec, ok, err := merge.NewMergeSpec(ctx, dEnv.RepoStateReader(), dEnv.DoltDB, roots, name, email, msg, commitSpecStr, apr.Contains(cli.SquashParam), apr.Contains(cli.NoFFParam), apr.Contains(cli.ForceFlag), t) if err != nil { return handleCommitErr(ctx, dEnv, errhand.VerboseErrorFromError(err), usage) } @@ -167,12 +177,12 @@ func (cmd MergeCmd) Exec(ctx context.Context, commandStr string, args []string, return handleCommitErr(ctx, dEnv, nil, usage) } - err = mergePrinting(ctx, dEnv, mergeSpec) + err = mergePrinting(ctx, dEnv, spec) if err != nil { return handleCommitErr(ctx, dEnv, err, usage) } - tblToStats, err := merge.MergeCommitSpec(ctx, dEnv, mergeSpec) + tblToStats, err := merge.MergeCommitSpec(ctx, dEnv, spec) hasConflicts, hasConstraintViolations := printSuccessStats(tblToStats) if hasConflicts && hasConstraintViolations { cli.Println("Automatic merge failed; fix conflicts and constraint violations and then commit the result.") @@ -200,47 +210,47 @@ func (cmd MergeCmd) Exec(ctx context.Context, commandStr string, args []string, return handleCommitErr(ctx, dEnv, verr, usage) } -func mergePrinting(ctx context.Context, dEnv *env.DoltEnv, mergeSpec *merge.MergeSpec) errhand.VerboseError { - if mergeSpec.H1 == mergeSpec.H2 { +func mergePrinting(ctx context.Context, dEnv *env.DoltEnv, spec *merge.MergeSpec) errhand.VerboseError { + if spec.HeadH == spec.MergeH { //TODO - why is this different for merge/pull? // cli.Println("Already up to date.") cli.Println("Everything up-to-date.") return nil } - cli.Println("Updating", mergeSpec.H1.String()+".."+mergeSpec.H2.String()) + cli.Println("Updating", spec.HeadH.String()+".."+spec.MergeH.String()) - if mergeSpec.Squash { + if spec.Squash { cli.Println("Squash commit -- not updating HEAD") } - if len(mergeSpec.TblNames) != 0 { + if len(spec.TblNames) != 0 { bldr := errhand.BuildDError("error: Your local changes to the following tables would be overwritten by merge:") - for _, tName := range mergeSpec.TblNames { + for _, tName := range spec.TblNames { bldr.AddDetails(tName) } bldr.AddDetails("Please commit your changes before you merge.") return bldr.Build() } - if ok, err := mergeSpec.Cm1.CanFastForwardTo(ctx, mergeSpec.Cm2); ok { - ancRoot, err := mergeSpec.Cm1.GetRootValue() + if ok, err := spec.HeadC.CanFastForwardTo(ctx, spec.MergeC); ok { + ancRoot, err := spec.HeadC.GetRootValue() if err != nil { return errhand.VerboseErrorFromError(err) } - mergedRoot, err := mergeSpec.Cm2.GetRootValue() + mergedRoot, err := spec.MergeC.GetRootValue() if err != nil { return errhand.VerboseErrorFromError(err) } if _, err := merge.MayHaveConstraintViolations(ctx, ancRoot, mergedRoot); err != nil { return errhand.VerboseErrorFromError(err) } - if mergeSpec.Noff { - if mergeSpec.Msg == "" { + if spec.Noff { + if spec.Msg == "" { msg, err := getCommitMessageFromEditor(ctx, dEnv) if err != nil { return errhand.VerboseErrorFromError(err) } - mergeSpec.Msg = msg + spec.Msg = msg } } else { cli.Println("Fast-forward") diff --git a/go/cmd/dolt/commands/pull.go b/go/cmd/dolt/commands/pull.go index a3d1c70e4a..3785dc101b 100644 --- a/go/cmd/dolt/commands/pull.go +++ b/go/cmd/dolt/commands/pull.go @@ -54,7 +54,7 @@ func (cmd PullCmd) Description() string { // CreateMarkdown creates a markdown file containing the helptext for the command at the given path func (cmd PullCmd) CreateMarkdown(fs filesys.Filesys, path, commandStr string) error { - ap := cmd.createArgParser() + ap := cli.CreatePullArgParser() return CreateMarkdown(fs, path, cli.GetCommandDocumentation(commandStr, pullDocs, ap)) } @@ -72,7 +72,7 @@ func (cmd PullCmd) EventType() eventsapi.ClientEventType { // Exec executes the command func (cmd PullCmd) Exec(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEnv) int { - ap := cmd.createArgParser() + ap := cli.CreatePullArgParser() help, usage := cli.HelpAndUsagePrinters(cli.GetCommandDocumentation(commandStr, pullDocs, ap)) apr := cli.ParseArgsOrDie(ap, args, help) @@ -87,7 +87,7 @@ func (cmd PullCmd) Exec(ctx context.Context, commandStr string, args []string, d remoteName = apr.Arg(0) } - pullSpec, err := env.ParsePullSpec(ctx, dEnv, remoteName, apr.Contains(cli.SquashParam), apr.Contains(cli.NoFFParam), apr.Contains(cli.ForceFlag)) + pullSpec, err := env.NewPullSpec(ctx, dEnv.RepoStateReader(), remoteName, apr.Contains(cli.SquashParam), apr.Contains(cli.NoFFParam), apr.Contains(cli.ForceFlag)) if err != nil { return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } @@ -111,7 +111,7 @@ func pullHelper(ctx context.Context, dEnv *env.DoltEnv, pullSpec *env.PullSpec) if remoteTrackRef != nil { - srcDBCommit, err := actions.FetchRemoteBranch(ctx, dEnv, pullSpec.Remote, srcDB, dEnv.DoltDB, pullSpec.Branch, remoteTrackRef, runProgFuncs, stopProgFuncs) + srcDBCommit, err := actions.FetchRemoteBranch(ctx, dEnv.TempTableFilesDir(), pullSpec.Remote, srcDB, dEnv.DoltDB, pullSpec.Branch, remoteTrackRef, runProgFuncs, stopProgFuncs) if err != nil { return err } @@ -122,7 +122,18 @@ func pullHelper(ctx context.Context, dEnv *env.DoltEnv, pullSpec *env.PullSpec) } t := doltdb.CommitNowFunc() - mergeSpec, ok, err := merge.ParseMergeSpec(ctx, dEnv, pullSpec.Msg, remoteTrackRef.String(), pullSpec.Squash, pullSpec.Noff, pullSpec.Force, t) + + roots, err := dEnv.Roots(ctx) + if err != nil { + return err + } + + name, email, err := env.GetNameAndEmail(dEnv.Config) + if err != nil { + return err + } + + mergeSpec, ok, err := merge.NewMergeSpec(ctx, dEnv.RepoStateReader(), dEnv.DoltDB, roots, name, email, pullSpec.Msg, remoteTrackRef.String(), pullSpec.Squash, pullSpec.Noff, pullSpec.Force, t) if err != nil { return err } @@ -142,12 +153,10 @@ func pullHelper(ctx context.Context, dEnv *env.DoltEnv, pullSpec *env.PullSpec) } } - srcDB, err = pullSpec.Remote.GetRemoteDB(ctx, dEnv.DoltDB.ValueReadWriter().Format()) - if err != nil { return err } - err = actions.FetchFollowTags(ctx, dEnv, srcDB, dEnv.DoltDB, runProgFuncs, stopProgFuncs) + err = actions.FetchFollowTags(ctx, dEnv.TempTableFilesDir(), srcDB, dEnv.DoltDB, runProgFuncs, stopProgFuncs) if err != nil { return err diff --git a/go/cmd/dolt/commands/push.go b/go/cmd/dolt/commands/push.go index 554d30e491..5a0b8a373a 100644 --- a/go/cmd/dolt/commands/push.go +++ b/go/cmd/dolt/commands/push.go @@ -96,7 +96,7 @@ func (cmd PushCmd) Exec(ctx context.Context, commandStr string, args []string, d case env.ErrNoUpstreamForBranch: currentBranch := dEnv.RepoStateReader().CWBHeadRef() remoteName := "" - if defRemote, verr := dEnv.GetDefaultRemote(); verr == nil { + if defRemote, verr := env.GetDefaultRemote(dEnv.RepoStateReader()); verr == nil { remoteName = defRemote.Name } verr = errhand.BuildDError("fatal: The current branch " + currentBranch.GetPath() + " has no upstream branch.\n" + @@ -166,7 +166,7 @@ func (ts *TextSpinner) next() string { return string([]rune{spinnerSeq[ts.seqPos]}) } -func pullerProgFunc(pullerEventCh chan datas.PullerEvent) { +func pullerProgFunc(ctx context.Context, pullerEventCh chan datas.PullerEvent) { var pos int var currentTreeLevel int var percentBuffered float64 @@ -177,6 +177,11 @@ func pullerProgFunc(pullerEventCh chan datas.PullerEvent) { uploadRate := "" for evt := range pullerEventCh { + select { + case <-ctx.Done(): + return + default: + } switch evt.EventType { case datas.NewLevelTWEvent: if evt.TWEventDetails.TreeLevel != 1 { @@ -228,19 +233,25 @@ func pullerProgFunc(pullerEventCh chan datas.PullerEvent) { } } -func progFunc(progChan chan datas.PullProgress) { +func progFunc(ctx context.Context, progChan chan datas.PullProgress) { var latest datas.PullProgress last := time.Now().UnixNano() - 1 lenPrinted := 0 done := false for !done { select { + case <-ctx.Done(): + return + default: + } + select { + case <-ctx.Done(): + return case progress, ok := <-progChan: if !ok { done = true } latest = progress - case <-time.After(250 * time.Millisecond): break } @@ -262,7 +273,7 @@ func progFunc(progChan chan datas.PullProgress) { } } -func runProgFuncs() (*sync.WaitGroup, chan datas.PullProgress, chan datas.PullerEvent) { +func runProgFuncs(ctx context.Context) (*sync.WaitGroup, chan datas.PullProgress, chan datas.PullerEvent) { pullerEventCh := make(chan datas.PullerEvent, 128) progChan := make(chan datas.PullProgress, 128) wg := &sync.WaitGroup{} @@ -270,19 +281,20 @@ func runProgFuncs() (*sync.WaitGroup, chan datas.PullProgress, chan datas.Puller wg.Add(1) go func() { defer wg.Done() - progFunc(progChan) + progFunc(ctx, progChan) }() wg.Add(1) go func() { defer wg.Done() - pullerProgFunc(pullerEventCh) + pullerProgFunc(ctx, pullerEventCh) }() return wg, progChan, pullerEventCh } -func stopProgFuncs(wg *sync.WaitGroup, progChan chan datas.PullProgress, pullerEventCh chan datas.PullerEvent) { +func stopProgFuncs(cancel context.CancelFunc, wg *sync.WaitGroup, progChan chan datas.PullProgress, pullerEventCh chan datas.PullerEvent) { + cancel() close(progChan) close(pullerEventCh) wg.Wait() diff --git a/go/cmd/dolt/commands/read_tables.go b/go/cmd/dolt/commands/read_tables.go index 255e2e9220..72663961db 100644 --- a/go/cmd/dolt/commands/read_tables.go +++ b/go/cmd/dolt/commands/read_tables.go @@ -171,20 +171,15 @@ func pullTableValue(ctx context.Context, dEnv *env.DoltEnv, srcDB *doltdb.DoltDB return nil, errhand.BuildDError("Unable to read from remote database.").AddCause(err).Build() } + newCtx, cancelFunc := context.WithCancel(ctx) cli.Println("Retrieving", tblName) - wg, progChan, pullerEventCh := runProgFuncs() + wg, progChan, pullerEventCh := runProgFuncs(newCtx) err = dEnv.DoltDB.PushChunksForRefHash(ctx, dEnv.TempTableFilesDir(), srcDB, tblHash, pullerEventCh) - + stopProgFuncs(cancelFunc, wg, progChan, pullerEventCh) if err != nil { return nil, errhand.BuildDError("Failed reading chunks for remote table '%s' at '%s'", tblName, commitStr).AddCause(err).Build() } - stopProgFuncs(wg, progChan, pullerEventCh) - - if err != nil { - return nil, errhand.BuildDError("Failed to pull chunks.").AddCause(err).Build() - } - destRoot, err = destRoot.SetTableHash(ctx, tblName, tblHash) if err != nil { diff --git a/go/libraries/doltcore/env/actions/branch.go b/go/libraries/doltcore/env/actions/branch.go index a4186329d0..bf8e0eecb8 100644 --- a/go/libraries/doltcore/env/actions/branch.go +++ b/go/libraries/doltcore/env/actions/branch.go @@ -160,7 +160,7 @@ func DeleteBranchOnDB(ctx context.Context, ddb *doltdb.DoltDB, dref ref.DoltRef, } isMerged, _ := master.CanFastReverseTo(ctx, cm) - if err != nil && err != doltdb.ErrUpToDate { + if err != nil && errors.Is(err, doltdb.ErrUpToDate) { return err } if !isMerged { diff --git a/go/libraries/doltcore/env/actions/remotes.go b/go/libraries/doltcore/env/actions/remotes.go index 01c0b548d4..59f628e425 100644 --- a/go/libraries/doltcore/env/actions/remotes.go +++ b/go/libraries/doltcore/env/actions/remotes.go @@ -39,8 +39,8 @@ var ErrFailedToDeleteRemote = errors.New("failed to delete remote") var ErrFailedToGetRemoteDb = errors.New("failed to get remote db") var ErrUnknownPushErr = errors.New("unknown push error") -type ProgStarter func() (*sync.WaitGroup, chan datas.PullProgress, chan datas.PullerEvent) -type ProgStopper func(wg *sync.WaitGroup, progChan chan datas.PullProgress, pullerEventCh chan datas.PullerEvent) +type ProgStarter func(ctx context.Context) (*sync.WaitGroup, chan datas.PullProgress, chan datas.PullerEvent) +type ProgStopper func(cancel context.CancelFunc, wg *sync.WaitGroup, progChan chan datas.PullProgress, pullerEventCh chan datas.PullerEvent) // Push will update a destination branch, in a given destination database if it can be done as a fast forward merge. // This is accomplished first by verifying that the remote tracking reference for the source database can be updated to @@ -189,9 +189,10 @@ func PushToRemoteBranch(ctx context.Context, dEnv *env.DoltEnv, mode ref.UpdateM return fmt.Errorf("%w; refspec not found: '%s'; %s", ref.ErrInvalidRefSpec, srcRef.GetPath(), err.Error()) } - wg, progChan, pullerEventCh := progStarter() + newCtx, cancelFunc := context.WithCancel(ctx) + wg, progChan, pullerEventCh := progStarter(newCtx) err = Push(ctx, dEnv, mode, destRef.(ref.BranchRef), remoteRef.(ref.RemoteRef), localDB, remoteDB, cm, progChan, pullerEventCh) - progStopper(wg, progChan, pullerEventCh) + progStopper(cancelFunc, wg, progChan, pullerEventCh) if err != nil { switch err { @@ -212,11 +213,10 @@ func pushTagToRemote(ctx context.Context, dEnv *env.DoltEnv, srcRef, destRef ref return err } - wg, progChan, pullerEventCh := progStarter() - + newCtx, cancelFunc := context.WithCancel(ctx) + wg, progChan, pullerEventCh := progStarter(newCtx) err = PushTag(ctx, dEnv, destRef.(ref.TagRef), localDB, remoteDB, tg, progChan, pullerEventCh) - - progStopper(wg, progChan, pullerEventCh) + progStopper(cancelFunc, wg, progChan, pullerEventCh) if err != nil { return err @@ -252,25 +252,25 @@ func DeleteRemoteBranch(ctx context.Context, targetRef ref.BranchRef, remoteRef } // FetchCommit takes a fetches a commit and all underlying data from a remote source database to the local destination database. -func FetchCommit(ctx context.Context, dEnv *env.DoltEnv, srcDB, destDB *doltdb.DoltDB, srcDBCommit *doltdb.Commit, progChan chan datas.PullProgress, pullerEventCh chan datas.PullerEvent) error { +func FetchCommit(ctx context.Context, tempTablesDir string, srcDB, destDB *doltdb.DoltDB, srcDBCommit *doltdb.Commit, progChan chan datas.PullProgress, pullerEventCh chan datas.PullerEvent) error { stRef, err := srcDBCommit.GetStRef() if err != nil { return err } - return destDB.PullChunks(ctx, dEnv.TempTableFilesDir(), srcDB, stRef, progChan, pullerEventCh) + return destDB.PullChunks(ctx, tempTablesDir, srcDB, stRef, progChan, pullerEventCh) } // FetchCommit takes a fetches a commit tag and all underlying data from a remote source database to the local destination database. -func FetchTag(ctx context.Context, dEnv *env.DoltEnv, srcDB, destDB *doltdb.DoltDB, srcDBTag *doltdb.Tag, progChan chan datas.PullProgress, pullerEventCh chan datas.PullerEvent) error { +func FetchTag(ctx context.Context, tempTableDir string, srcDB, destDB *doltdb.DoltDB, srcDBTag *doltdb.Tag, progChan chan datas.PullProgress, pullerEventCh chan datas.PullerEvent) error { stRef, err := srcDBTag.GetStRef() if err != nil { return err } - return destDB.PullChunks(ctx, dEnv.TempTableFilesDir(), srcDB, stRef, progChan, pullerEventCh) + return destDB.PullChunks(ctx, tempTableDir, srcDB, stRef, progChan, pullerEventCh) } // Clone pulls all data from a remote source database to a local destination database. @@ -281,7 +281,7 @@ func Clone(ctx context.Context, srcDB, destDB *doltdb.DoltDB, eventCh chan<- dat // fetchFollowTags fetches all tags from the source DB whose commits have already // been fetched into the destination DB. // todo: potentially too expensive to iterate over all srcDB tags -func FetchFollowTags(ctx context.Context, dEnv *env.DoltEnv, srcDB, destDB *doltdb.DoltDB, progStarter ProgStarter, progStopper ProgStopper) error { +func FetchFollowTags(ctx context.Context, tempTableDir string, srcDB, destDB *doltdb.DoltDB, progStarter ProgStarter, progStopper ProgStopper) error { err := IterResolvedTags(ctx, srcDB, func(tag *doltdb.Tag) (stop bool, err error) { stRef, err := tag.GetStRef() if err != nil { @@ -313,9 +313,10 @@ func FetchFollowTags(ctx context.Context, dEnv *env.DoltEnv, srcDB, destDB *dolt return false, nil } - wg, progChan, pullerEventCh := progStarter() - err = FetchTag(ctx, dEnv, srcDB, destDB, tag, progChan, pullerEventCh) - progStopper(wg, progChan, pullerEventCh) + newCtx, cancelFunc := context.WithCancel(ctx) + wg, progChan, pullerEventCh := progStarter(newCtx) + err = FetchTag(ctx, tempTableDir, srcDB, destDB, tag, progChan, pullerEventCh) + progStopper(cancelFunc, wg, progChan, pullerEventCh) if err != nil { return true, err @@ -333,7 +334,7 @@ func FetchFollowTags(ctx context.Context, dEnv *env.DoltEnv, srcDB, destDB *dolt return nil } -func FetchRemoteBranch(ctx context.Context, dEnv *env.DoltEnv, rem env.Remote, srcDB, destDB *doltdb.DoltDB, srcRef, destRef ref.DoltRef, progStarter ProgStarter, progStopper ProgStopper) (*doltdb.Commit, error) { +func FetchRemoteBranch(ctx context.Context, tempTablesDir string, rem env.Remote, srcDB, destDB *doltdb.DoltDB, srcRef, destRef ref.DoltRef, progStarter ProgStarter, progStopper ProgStopper) (*doltdb.Commit, error) { evt := events.GetEventFromContext(ctx) u, err := earl.Parse(rem.Url) @@ -351,9 +352,10 @@ func FetchRemoteBranch(ctx context.Context, dEnv *env.DoltEnv, rem env.Remote, s return nil, fmt.Errorf("unable to find '%s' on '%s'; %w", srcRef.GetPath(), rem.Name, err) } - wg, progChan, pullerEventCh := progStarter() - err = FetchCommit(ctx, dEnv, srcDB, destDB, srcDBCommit, progChan, pullerEventCh) - progStopper(wg, progChan, pullerEventCh) + newCtx, cancelFunc := context.WithCancel(ctx) + wg, progChan, pullerEventCh := progStarter(newCtx) + err = FetchCommit(ctx, tempTablesDir, srcDB, destDB, srcDBCommit, progChan, pullerEventCh) + progStopper(cancelFunc, wg, progChan, pullerEventCh) if err != nil { return nil, err diff --git a/go/libraries/doltcore/env/environment.go b/go/libraries/doltcore/env/environment.go index 5f7350f6d8..54072dc5f2 100644 --- a/go/libraries/doltcore/env/environment.go +++ b/go/libraries/doltcore/env/environment.go @@ -550,19 +550,15 @@ func (dEnv *DoltEnv) UpdateWorkingSet(ctx context.Context, ws *doltdb.WorkingSet } type repoStateReader struct { - dEnv *DoltEnv + *DoltEnv } func (r *repoStateReader) CWBHeadRef() ref.DoltRef { - return r.dEnv.RepoState.CWBHeadRef() + return r.RepoState.CWBHeadRef() } func (r *repoStateReader) CWBHeadSpec() *doltdb.CommitSpec { - return r.dEnv.RepoState.CWBHeadSpec() -} - -func (r *repoStateReader) GetRemotes() (map[string]Remote, error) { - return r.dEnv.GetRemotes() + return r.RepoState.CWBHeadSpec() } func (dEnv *DoltEnv) RepoStateReader() RepoStateReader { @@ -971,13 +967,17 @@ func (dEnv *DoltEnv) FindRef(ctx context.Context, refStr string) (ref.DoltRef, e // GetRefSpecs takes an optional remoteName and returns all refspecs associated with that remote. If "" is passed as // the remoteName then the default remote is used. -func (dEnv *DoltEnv) GetRefSpecs(remoteName string) ([]ref.RemoteRefSpec, error) { +func GetRefSpecs(rsr RepoStateReader, remoteName string) ([]ref.RemoteRefSpec, error) { var remote Remote var err error + remotes, err := rsr.GetRemotes() + if err != nil { + return nil, err + } if remoteName == "" { - remote, err = dEnv.GetDefaultRemote() - } else if r, ok := dEnv.RepoState.Remotes[remoteName]; ok { + remote, err = GetDefaultRemote(rsr) + } else if r, ok := remotes[remoteName]; ok { remote = r } else { err = ErrUnknownRemote @@ -1014,8 +1014,11 @@ var ErrCantDetermineDefault = errors.New("unable to determine the default remote // GetDefaultRemote gets the default remote for the environment. Not fully implemented yet. Needs to support multiple // repos and a configurable default. -func (dEnv *DoltEnv) GetDefaultRemote() (Remote, error) { - remotes := dEnv.RepoState.Remotes +func GetDefaultRemote(rsr RepoStateReader) (Remote, error) { + remotes, err := rsr.GetRemotes() + if err != nil { + return NoRemote, err + } if len(remotes) == 0 { return NoRemote, ErrNoRemote @@ -1025,7 +1028,7 @@ func (dEnv *DoltEnv) GetDefaultRemote() (Remote, error) { } } - if remote, ok := dEnv.RepoState.Remotes["origin"]; ok { + if remote, ok := remotes["origin"]; ok { return remote, nil } diff --git a/go/libraries/doltcore/env/remotes.go b/go/libraries/doltcore/env/remotes.go index e94b952d88..15d8ea3143 100644 --- a/go/libraries/doltcore/env/remotes.go +++ b/go/libraries/doltcore/env/remotes.go @@ -286,10 +286,10 @@ type PullSpec struct { Branch ref.DoltRef } -func ParsePullSpec(ctx context.Context, dEnv *DoltEnv, remoteName string, squash, noff, force bool) (*PullSpec, error) { - branch := dEnv.RepoStateReader().CWBHeadRef() +func NewPullSpec(ctx context.Context, rsr RepoStateReader, remoteName string, squash, noff, force bool) (*PullSpec, error) { + branch := rsr.CWBHeadRef() - refSpecs, err := dEnv.GetRefSpecs(remoteName) + refSpecs, err := GetRefSpecs(rsr, remoteName) if err != nil { return nil, err } @@ -298,7 +298,11 @@ func ParsePullSpec(ctx context.Context, dEnv *DoltEnv, remoteName string, squash return nil, ErrNoRefSpecForRemote } - remote := dEnv.RepoState.Remotes[refSpecs[0].GetRemote()] + remotes, err := rsr.GetRemotes() + if err != nil { + return nil, err + } + remote := remotes[refSpecs[0].GetRemote()] return &PullSpec{ Squash: squash, diff --git a/go/libraries/doltcore/env/repo_state.go b/go/libraries/doltcore/env/repo_state.go index dd7c3523a9..654f97c18e 100644 --- a/go/libraries/doltcore/env/repo_state.go +++ b/go/libraries/doltcore/env/repo_state.go @@ -39,6 +39,7 @@ type RepoStateWriter interface { SetCWBHeadRef(context.Context, ref.MarshalableRef) error AddRemote(name string, url string, fetchSpecs []string, params map[string]string) error RemoveRemote(ctx context.Context, name string) error + TempTableFilesDir() string } type DocsReadWriter interface { diff --git a/go/libraries/doltcore/merge/action.go b/go/libraries/doltcore/merge/action.go index 9e35c1d448..32df96f417 100644 --- a/go/libraries/doltcore/merge/action.go +++ b/go/libraries/doltcore/merge/action.go @@ -35,10 +35,10 @@ var ErrMergeFailedToUpdateRepoState = errors.New("unable to execute repo state u var ErrFailedToDetermineMergeability = errors.New("failed to determine mergeability") type MergeSpec struct { - H1 hash.Hash - H2 hash.Hash - Cm1 *doltdb.Commit - Cm2 *doltdb.Commit + HeadH hash.Hash + MergeH hash.Hash + HeadC *doltdb.Commit + MergeC *doltdb.Commit TblNames []string WorkingDiffs map[string]hash.Hash Squash bool @@ -51,13 +51,13 @@ type MergeSpec struct { Date time.Time } -func ParseMergeSpec(ctx context.Context, dEnv *env.DoltEnv, msg string, commitSpecStr string, squash bool, noff bool, force bool, date time.Time) (*MergeSpec, bool, error) { +func NewMergeSpec(ctx context.Context, rsr env.RepoStateReader, ddb *doltdb.DoltDB, roots doltdb.Roots, name, email, msg string, commitSpecStr string, squash bool, noff bool, force bool, date time.Time) (*MergeSpec, bool, error) { cs1, err := doltdb.NewCommitSpec("HEAD") if err != nil { return nil, false, err } - cm1, err := dEnv.DoltDB.Resolve(context.TODO(), cs1, dEnv.RepoStateReader().CWBHeadRef()) + cm1, err := ddb.Resolve(context.TODO(), cs1, rsr.CWBHeadRef()) if err != nil { return nil, false, err } @@ -67,7 +67,7 @@ func ParseMergeSpec(ctx context.Context, dEnv *env.DoltEnv, msg string, commitSp return nil, false, err } - cm2, err := dEnv.DoltDB.Resolve(context.TODO(), cs2, dEnv.RepoStateReader().CWBHeadRef()) + cm2, err := ddb.Resolve(context.TODO(), cs2, rsr.CWBHeadRef()) if err != nil { return nil, false, err } @@ -83,26 +83,16 @@ func ParseMergeSpec(ctx context.Context, dEnv *env.DoltEnv, msg string, commitSp } - roots, err := dEnv.Roots(ctx) - if err != nil { - return nil, false, err - } - tblNames, workingDiffs, err := MergeWouldStompChanges(ctx, roots, cm2) if err != nil { return nil, false, fmt.Errorf("%w; %s", ErrFailedToDetermineMergeability, err.Error()) } - name, email, err := env.GetNameAndEmail(dEnv.Config) - if err != nil { - return nil, false, err - } - return &MergeSpec{ - H1: h1, - H2: h2, - Cm1: cm1, - Cm2: cm2, + HeadH: h1, + MergeH: h2, + HeadC: cm1, + MergeC: cm2, TblNames: tblNames, WorkingDiffs: workingDiffs, Squash: squash, @@ -115,42 +105,42 @@ func ParseMergeSpec(ctx context.Context, dEnv *env.DoltEnv, msg string, commitSp }, true, nil } -func MergeCommitSpec(ctx context.Context, dEnv *env.DoltEnv, mergeSpec *MergeSpec) (map[string]*MergeStats, error) { - if ok, err := mergeSpec.Cm1.CanFastForwardTo(ctx, mergeSpec.Cm2); ok { - ancRoot, err := mergeSpec.Cm1.GetRootValue() +func MergeCommitSpec(ctx context.Context, dEnv *env.DoltEnv, spec *MergeSpec) (map[string]*MergeStats, error) { + if ok, err := spec.HeadC.CanFastForwardTo(ctx, spec.MergeC); ok { + ancRoot, err := spec.HeadC.GetRootValue() if err != nil { return nil, err } - mergedRoot, err := mergeSpec.Cm2.GetRootValue() + mergedRoot, err := spec.MergeC.GetRootValue() if err != nil { return nil, err } if cvPossible, err := MayHaveConstraintViolations(ctx, ancRoot, mergedRoot); err != nil { return nil, err } else if cvPossible { - return ExecuteMerge(ctx, dEnv, mergeSpec) + return ExecuteMerge(ctx, dEnv, spec) } - if mergeSpec.Noff { - return ExecNoFFMerge(ctx, dEnv, mergeSpec) + if spec.Noff { + return ExecNoFFMerge(ctx, dEnv, spec) } else { - return nil, ExecuteFFMerge(ctx, dEnv, mergeSpec) + return nil, ExecuteFFMerge(ctx, dEnv, spec) } } else if err == doltdb.ErrUpToDate || err == doltdb.ErrIsAhead { return nil, err } else { - return ExecuteMerge(ctx, dEnv, mergeSpec) + return ExecuteMerge(ctx, dEnv, spec) } } -func ExecNoFFMerge(ctx context.Context, dEnv *env.DoltEnv, mergeSpec *MergeSpec) (map[string]*MergeStats, error) { - mergedRoot, err := mergeSpec.Cm2.GetRootValue() +func ExecNoFFMerge(ctx context.Context, dEnv *env.DoltEnv, spec *MergeSpec) (map[string]*MergeStats, error) { + mergedRoot, err := spec.MergeC.GetRootValue() if err != nil { return nil, ErrFailedToReadDatabase } tblToStats := make(map[string]*MergeStats) - err = mergedRootToWorking(ctx, false, dEnv, mergedRoot, mergeSpec.WorkingDiffs, mergeSpec.Cm2, tblToStats) + err = mergedRootToWorking(ctx, false, dEnv, mergedRoot, spec.WorkingDiffs, spec.MergeC, tblToStats) if err != nil { return tblToStats, err @@ -173,12 +163,12 @@ func ExecNoFFMerge(ctx context.Context, dEnv *env.DoltEnv, mergeSpec *MergeSpec) } _, err = actions.CommitStaged(ctx, roots, ws.MergeActive(), mergeParentCommits, dEnv.DbData(), actions.CommitStagedProps{ - Message: mergeSpec.Msg, - Date: mergeSpec.Date, - AllowEmpty: mergeSpec.AllowEmpty, - Force: mergeSpec.Force, - Name: mergeSpec.Name, - Email: mergeSpec.Email, + Message: spec.Msg, + Date: spec.Date, + AllowEmpty: spec.AllowEmpty, + Force: spec.Force, + Name: spec.Name, + Email: spec.Email, }) if err != nil { @@ -209,16 +199,16 @@ func applyChanges(ctx context.Context, root *doltdb.RootValue, workingDiffs map[ func ExecuteFFMerge( ctx context.Context, dEnv *env.DoltEnv, - mergeSpec *MergeSpec, + spec *MergeSpec, ) error { - stagedRoot, err := mergeSpec.Cm2.GetRootValue() + stagedRoot, err := spec.MergeC.GetRootValue() if err != nil { return err } workingRoot := stagedRoot - if len(mergeSpec.WorkingDiffs) > 0 { - workingRoot, err = applyChanges(ctx, stagedRoot, mergeSpec.WorkingDiffs) + if len(spec.WorkingDiffs) > 0 { + workingRoot, err = applyChanges(ctx, stagedRoot, spec.WorkingDiffs) if err != nil { //return errhand.BuildDError("Failed to re-apply working changes.").AddCause(err).Build() @@ -231,8 +221,8 @@ func ExecuteFFMerge( return err } - if !mergeSpec.Squash { - err = dEnv.DoltDB.FastForward(ctx, dEnv.RepoStateReader().CWBHeadRef(), mergeSpec.Cm2) + if !spec.Squash { + err = dEnv.DoltDB.FastForward(ctx, dEnv.RepoStateReader().CWBHeadRef(), spec.MergeC) if err != nil { return err @@ -257,9 +247,9 @@ func ExecuteFFMerge( return nil } -func ExecuteMerge(ctx context.Context, dEnv *env.DoltEnv, mergeSpec *MergeSpec) (map[string]*MergeStats, error) { +func ExecuteMerge(ctx context.Context, dEnv *env.DoltEnv, spec *MergeSpec) (map[string]*MergeStats, error) { opts := editor.Options{Deaf: dEnv.DbEaFactory()} - mergedRoot, tblToStats, err := MergeCommits(ctx, mergeSpec.Cm1, mergeSpec.Cm2, opts) + mergedRoot, tblToStats, err := MergeCommits(ctx, spec.HeadC, spec.MergeC, opts) if err != nil { switch err { case doltdb.ErrUpToDate: @@ -270,7 +260,7 @@ func ExecuteMerge(ctx context.Context, dEnv *env.DoltEnv, mergeSpec *MergeSpec) return tblToStats, err } - return tblToStats, mergedRootToWorking(ctx, mergeSpec.Squash, dEnv, mergedRoot, mergeSpec.WorkingDiffs, mergeSpec.Cm2, tblToStats) + return tblToStats, mergedRootToWorking(ctx, spec.Squash, dEnv, mergedRoot, spec.WorkingDiffs, spec.MergeC, tblToStats) } // TODO: change this to be functional and not write to repo state diff --git a/go/libraries/doltcore/sqle/dfunctions/dolt_merge.go b/go/libraries/doltcore/sqle/dfunctions/dolt_merge.go index a6b20e0c6f..8b71b8f0cf 100644 --- a/go/libraries/doltcore/sqle/dfunctions/dolt_merge.go +++ b/go/libraries/doltcore/sqle/dfunctions/dolt_merge.go @@ -40,8 +40,10 @@ type DoltMergeFunc struct { const DoltConflictWarningCode int = 1105 // Since this our own custom warning we'll use 1105, the code for an unknown error -const hasConflicts = 0 -const noConflicts = 1 +const ( + hasConflicts int = 0 + noConflicts int = 1 +) func (d DoltMergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { dbName := ctx.GetCurrentDatabase() @@ -51,18 +53,6 @@ func (d DoltMergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) } sess := dsess.DSessFromSess(ctx.Session) - dbData, ok := sess.GetDbData(ctx, dbName) - - if !ok { - return noConflicts, fmt.Errorf("Could not load database %s", dbName) - } - - dbState, ok, err := sess.LookupDbState(ctx, dbName) - if err != nil { - return noConflicts, err - } else if !ok { - return noConflicts, fmt.Errorf("Could not load database %s", dbName) - } ap := cli.CreateMergeArgParser() args, err := getDoltArgs(ctx, row, d.Children()) @@ -85,11 +75,8 @@ func (d DoltMergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return noConflicts, err } roots, ok := sess.GetRoots(ctx, dbName) - - // logrus.Errorf("heads are working: %s\nhead: %s", roots.Working.DebugString(ctx, true), roots.Head.DebugString(ctx, true)) - if !ok { - return noConflicts, fmt.Errorf("Could not load database %s", dbName) + return noConflicts, sql.ErrDatabaseNotFound.New(dbName) } if apr.Contains(cli.AbortParam) { @@ -110,108 +97,120 @@ func (d DoltMergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return noConflicts, nil } - ddb, ok := sess.GetDoltDB(ctx, dbName) - if !ok { - return noConflicts, sql.ErrDatabaseNotFound.New(dbName) + branchName := apr.Arg(0) + + mergeSpec, err := createMergeSpec(ctx, sess, dbName, apr, branchName) + if err != nil { + return noConflicts, err + } + ws, conflicts, err := mergeIntoWorkingSet(ctx, sess, roots, ws, dbName, mergeSpec) + if err != nil { + return conflicts, err } - if hasConflicts, err := roots.Working.HasConflicts(ctx); err != nil { - return noConflicts, err - } else if hasConflicts { - return noConflicts, doltdb.ErrUnresolvedConflicts + err = sess.SetWorkingSet(ctx, dbName, ws, nil) + if err != nil { + return conflicts, err + } + + return conflicts, nil +} + +// mergeIntoWorkingSet encapsulates server merge logic, switching between fast-forward, no fast-forward, merge commit, +// and merging into working set. Returns a new WorkingSet and whether there were merge conflicts. This currently +// persists merge commits in the database, but expects the caller to update the working set. +func mergeIntoWorkingSet(ctx *sql.Context, sess *dsess.Session, roots doltdb.Roots, ws *doltdb.WorkingSet, dbName string, spec *merge.MergeSpec) (*doltdb.WorkingSet, int, error) { + if conflicts, err := roots.Working.HasConflicts(ctx); err != nil { + return ws, noConflicts, err + } else if conflicts { + return ws, hasConflicts, doltdb.ErrUnresolvedConflicts } if hasConstraintViolations, err := roots.Working.HasConstraintViolations(ctx); err != nil { - return noConflicts, err + return ws, hasConflicts, err } else if hasConstraintViolations { - return noConflicts, doltdb.ErrUnresolvedConstraintViolations + return ws, hasConflicts, doltdb.ErrUnresolvedConstraintViolations } if ws.MergeActive() { - return noConflicts, doltdb.ErrMergeActive + return ws, noConflicts, doltdb.ErrMergeActive } - err = checkForUncommittedChanges(roots.Working, roots.Head) + err := checkForUncommittedChanges(roots.Working, roots.Head) if err != nil { - return noConflicts, err + return ws, noConflicts, err } - branchName := apr.Arg(0) - mergeCommit, _, err := getBranchCommit(ctx, branchName, ddb) + canFF, err := spec.HeadC.CanFastForwardTo(ctx, spec.MergeC) if err != nil { - return noConflicts, err - } - - headCommit, err := sess.GetHeadCommit(ctx, dbName) - if err != nil { - return noConflicts, err - } - - canFF, err := headCommit.CanFastForwardTo(ctx, mergeCommit) - if err != nil { - return noConflicts, err + return ws, noConflicts, err } if canFF { - headRoot, err := headCommit.GetRootValue() + headRoot, err := spec.HeadC.GetRootValue() if err != nil { - return noConflicts, err + return ws, noConflicts, err } - mergeRoot, err := mergeCommit.GetRootValue() + mergeRoot, err := spec.MergeC.GetRootValue() if err != nil { - return noConflicts, err + return ws, noConflicts, err } if cvPossible, err := merge.MayHaveConstraintViolations(ctx, headRoot, mergeRoot); err != nil { - return noConflicts, err + return ws, noConflicts, err } else if !cvPossible { - if apr.Contains(cli.NoFFParam) { - ws, err = executeNoFFMerge(ctx, sess, apr, dbName, ws, dbData, headCommit, mergeCommit) + dbData, ok := sess.GetDbData(ctx, dbName) + if !ok { + return ws, noConflicts, fmt.Errorf("could not load database %s", dbName) + } + if spec.Noff { + ws, err = executeNoFFMerge(ctx, sess, spec, dbName, ws, dbData) if err == doltdb.ErrUnresolvedConflicts { // if there are unresolved conflicts, write the resulting working set back to the session and return an // error message wsErr := sess.SetWorkingSet(ctx, dbName, ws, nil) if wsErr != nil { - return hasConflicts, wsErr + return ws, hasConflicts, wsErr } ctx.Warn(DoltConflictWarningCode, err.Error()) // Return 0 indicating there are conflicts - return hasConflicts, nil + return ws, hasConflicts, nil } } else { - err = executeFFMerge(ctx, sess, apr.Contains(cli.SquashParam), dbName, ws, dbData, mergeCommit) + ws, err = executeFFMerge(ctx, spec.Squash, ws, dbData, spec.MergeC) } if err != nil { - return noConflicts, err + return ws, noConflicts, err } - return noConflicts, err + return ws, noConflicts, err } } - ws, err = executeMerge(ctx, apr.Contains(cli.SquashParam), headCommit, mergeCommit, ws, dbState.EditSession.Opts) + dbState, ok, err := sess.LookupDbState(ctx, dbName) + if err != nil { + return ws, noConflicts, err + } else if !ok { + return ws, noConflicts, fmt.Errorf("could not load database %s", dbName) + } + + ws, err = executeMerge(ctx, spec.Squash, spec.HeadC, spec.MergeC, ws, dbState.EditSession.Opts) if err == doltdb.ErrUnresolvedConflicts { // if there are unresolved conflicts, write the resulting working set back to the session and return an // error message wsErr := sess.SetWorkingSet(ctx, dbName, ws, nil) if wsErr != nil { - return hasConflicts, wsErr + return ws, hasConflicts, wsErr } ctx.Warn(DoltConflictWarningCode, err.Error()) - return hasConflicts, nil + return ws, hasConflicts, nil } else if err != nil { - return noConflicts, err + return ws, noConflicts, err } - - err = sess.SetWorkingSet(ctx, dbName, ws, nil) - if err != nil { - return noConflicts, err - } - - return noConflicts, nil + return ws, noConflicts, nil } func abortMerge(ctx *sql.Context, workingSet *doltdb.WorkingSet, roots doltdb.Roots) (*doltdb.WorkingSet, error) { @@ -247,10 +246,10 @@ func executeMerge(ctx *sql.Context, squash bool, head, cm *doltdb.Commit, ws *do return mergeRootToWorking(squash, ws, mergeRoot, cm, mergeStats) } -func executeFFMerge(ctx *sql.Context, sess *dsess.Session, squash bool, dbName string, ws *doltdb.WorkingSet, dbData env.DbData, cm2 *doltdb.Commit) error { +func executeFFMerge(ctx *sql.Context, squash bool, ws *doltdb.WorkingSet, dbData env.DbData, cm2 *doltdb.Commit) (*doltdb.WorkingSet, error) { rv, err := cm2.GetRootValue() if err != nil { - return err + return ws, err } // TODO: This is all incredibly suspect, needs to be replaced with library code that is functional instead of @@ -258,71 +257,33 @@ func executeFFMerge(ctx *sql.Context, sess *dsess.Session, squash bool, dbName s if !squash { err = dbData.Ddb.FastForward(ctx, dbData.Rsr.CWBHeadRef(), cm2) if err != nil { - return err + return ws, err } } - ws = ws.WithWorkingRoot(rv).WithStagedRoot(rv) - - return sess.SetWorkingSet(ctx, dbName, ws, nil) + return ws.WithWorkingRoot(rv).WithStagedRoot(rv), nil } func executeNoFFMerge( ctx *sql.Context, dSess *dsess.Session, - apr *argparser.ArgParseResults, + spec *merge.MergeSpec, dbName string, ws *doltdb.WorkingSet, dbData env.DbData, - headCommit, mergeCommit *doltdb.Commit, + //headCommit, mergeCommit *doltdb.Commit, ) (*doltdb.WorkingSet, error) { - mergeRoot, err := mergeCommit.GetRootValue() + mergeRoot, err := spec.MergeC.GetRootValue() if err != nil { return nil, err } - ws, err = mergeRootToWorking(false, ws, mergeRoot, mergeCommit, map[string]*merge.MergeStats{}) + ws, err = mergeRootToWorking(false, ws, mergeRoot, spec.MergeC, map[string]*merge.MergeStats{}) if err != nil { // This error is recoverable, so we return a working set value along with the error return ws, err } - msg, msgOk := apr.GetValue(cli.CommitMessageArg) - if !msgOk { - hh, err := headCommit.HashOf() - if err != nil { - return nil, err - } - - cmh, err := mergeCommit.HashOf() - if err != nil { - return nil, err - } - - msg = fmt.Sprintf("SQL Generated commit merging %s into %s", hh.String(), cmh.String()) - } - - // TODO: refactor, redundant - var name, email string - if authorStr, ok := apr.GetValue(cli.AuthorParam); ok { - name, email, err = cli.ParseAuthor(authorStr) - if err != nil { - return nil, err - } - } else { - name = dSess.Username - email = dSess.Email - } - - t := ctx.QueryTime() - if commitTimeStr, ok := apr.GetValue(cli.DateParam); ok { - var err error - t, err = cli.ParseDate(commitTimeStr) - if err != nil { - return nil, err - } - } - // Save our work so far in the session, as it will be referenced by the commit call below (badly in need of a // refactoring) err = dSess.SetWorkingSet(ctx, dbName, ws, nil) @@ -341,12 +302,12 @@ func executeNoFFMerge( // TODO: this does several session state updates, and it really needs to just do one // We also need to commit any pending transaction before we do this. _, err = actions.CommitStaged(ctx, roots, ws.MergeActive(), mergeParentCommits, dbData, actions.CommitStagedProps{ - Message: msg, - Date: t, - AllowEmpty: apr.Contains(cli.AllowEmptyFlag), - Force: apr.Contains(cli.ForceFlag), - Name: name, - Email: email, + Message: spec.Msg, + Date: spec.Date, + AllowEmpty: spec.AllowEmpty, + Force: spec.Force, + Name: spec.Name, + Email: spec.Email, }) if err != nil { return nil, err @@ -355,6 +316,49 @@ func executeNoFFMerge( return ws, dSess.SetWorkingSet(ctx, dbName, ws.ClearMerge(), nil) } +func createMergeSpec(ctx *sql.Context, sess *dsess.Session, dbName string, apr *argparser.ArgParseResults, commitSpecStr string) (*merge.MergeSpec, error) { + ddb, ok := sess.GetDoltDB(ctx, dbName) + + dbData, ok := sess.GetDbData(ctx, dbName) + + msg, ok := apr.GetValue(cli.CommitMessageArg) + if !ok { + // TODO probably change, but we can't open editor so it'll have to be automated + msg = "automatic SQL merge" + } + + var err error + var name, email string + if authorStr, ok := apr.GetValue(cli.AuthorParam); ok { + name, email, err = cli.ParseAuthor(authorStr) + if err != nil { + return nil, err + } + } else { + name = sess.Username + email = sess.Email + } + + t := ctx.QueryTime() + if commitTimeStr, ok := apr.GetValue(cli.DateParam); ok { + t, err = cli.ParseDate(commitTimeStr) + if err != nil { + return nil, err + } + } + + roots, ok := sess.GetRoots(ctx, dbName) + if !ok { + return nil, sql.ErrDatabaseNotFound.New(dbName) + } + + mergeSpec, _, err := merge.NewMergeSpec(ctx, dbData.Rsr, ddb, roots, name, email, msg, commitSpecStr, apr.Contains(cli.SquashParam), apr.Contains(cli.NoFFParam), apr.Contains(cli.ForceFlag), t) + if err != nil { + return nil, err + } + return mergeSpec, nil +} + // TODO: this copied from commands/merge.go because the latter isn't reusable. Fix that. func mergeRootToWorking( squash bool, diff --git a/go/libraries/doltcore/sqle/dfunctions/dolt_pull.go b/go/libraries/doltcore/sqle/dfunctions/dolt_pull.go new file mode 100644 index 0000000000..375cd7f7cf --- /dev/null +++ b/go/libraries/doltcore/sqle/dfunctions/dolt_pull.go @@ -0,0 +1,213 @@ +// Copyright 2021 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dfunctions + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + + "github.com/dolthub/dolt/go/cmd/dolt/cli" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" + "github.com/dolthub/dolt/go/libraries/doltcore/env" + "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" + "github.com/dolthub/dolt/go/store/datas" +) + +const DoltPullFuncName = "dolt_pull" + +type DoltPullFunc struct { + expression.NaryExpression +} + +// NewPullFunc creates a new PullFunc expression. +func NewPullFunc(ctx *sql.Context, args ...sql.Expression) (sql.Expression, error) { + return &DoltPullFunc{expression.NaryExpression{ChildExpressions: args}}, nil +} + +func (d DoltPullFunc) String() string { + childrenStrings := make([]string, len(d.Children())) + + for i, child := range d.Children() { + childrenStrings[i] = child.String() + } + + return fmt.Sprintf("DOLT_PULL(%s)", strings.Join(childrenStrings, ",")) +} + +func (d DoltPullFunc) Type() sql.Type { + return sql.Boolean +} + +func (d DoltPullFunc) WithChildren(ctx *sql.Context, children ...sql.Expression) (sql.Expression, error) { + return NewPullFunc(ctx, children...) +} + +func (d DoltPullFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + dbName := ctx.GetCurrentDatabase() + + if len(dbName) == 0 { + return noConflicts, fmt.Errorf("empty database name.") + } + + sess := dsess.DSessFromSess(ctx.Session) + dbData, ok := sess.GetDbData(ctx, dbName) + if !ok { + return noConflicts, sql.ErrDatabaseNotFound.New(dbName) + } + + ap := cli.CreatePullArgParser() + args, err := getDoltArgs(ctx, row, d.Children()) + + apr, err := ap.Parse(args) + if err != nil { + return noConflicts, err + } + + if apr.NArg() > 1 { + return noConflicts, actions.ErrInvalidPullArgs + } + + var remoteName string + if apr.NArg() == 1 { + remoteName = apr.Arg(0) + } + + pullSpec, err := env.NewPullSpec(ctx, dbData.Rsr, remoteName, apr.Contains(cli.SquashParam), apr.Contains(cli.NoFFParam), apr.Contains(cli.ForceFlag)) + if err != nil { + return noConflicts, err + } + + srcDB, err := pullSpec.Remote.GetRemoteDBWithoutCaching(ctx, dbData.Ddb.ValueReadWriter().Format()) + if err != nil { + return noConflicts, fmt.Errorf("failed to get remote db; %w", err) + } + + ws, err := sess.WorkingSet(ctx, dbName) + if err != nil { + return noConflicts, err + } + + var conflicts interface{} + for _, refSpec := range pullSpec.RefSpecs { + remoteTrackRef := refSpec.DestRef(pullSpec.Branch) + + if remoteTrackRef != nil { + + // todo: can we pass nil for either of the channels? + srcDBCommit, err := actions.FetchRemoteBranch(ctx, dbData.Rsw.TempTableFilesDir(), pullSpec.Remote, srcDB, dbData.Ddb, pullSpec.Branch, remoteTrackRef, runProgFuncs, stopProgFuncs) + if err != nil { + return noConflicts, err + } + + // TODO: this could be replaced with a canFF check to test for error + err = dbData.Ddb.FastForward(ctx, remoteTrackRef, srcDBCommit) + if err != nil { + return noConflicts, fmt.Errorf("fetch failed; %w", err) + } + + roots, ok := sess.GetRoots(ctx, dbName) + if !ok { + return noConflicts, sql.ErrDatabaseNotFound.New(dbName) + } + + mergeSpec, err := createMergeSpec(ctx, sess, dbName, apr, remoteTrackRef.String()) + if err != nil { + return noConflicts, err + } + ws, conflicts, err = mergeIntoWorkingSet(ctx, sess, roots, ws, dbName, mergeSpec) + if err != nil && !errors.Is(doltdb.ErrUpToDate, err) { + return conflicts, err + } + + err = sess.SetWorkingSet(ctx, dbName, ws, nil) + if err != nil { + return conflicts, err + } + } + } + + err = actions.FetchFollowTags(ctx, dbData.Rsw.TempTableFilesDir(), srcDB, dbData.Ddb, runProgFuncs, stopProgFuncs) + if err != nil { + return noConflicts, err + } + + return noConflicts, nil +} + +func pullerProgFunc(ctx context.Context, pullerEventCh <-chan datas.PullerEvent) { + for { + select { + case <-ctx.Done(): + return + default: + } + select { + case <-ctx.Done(): + return + case <-pullerEventCh: + default: + } + } +} + +func progFunc(ctx context.Context, progChan <-chan datas.PullProgress) { + for { + select { + case <-ctx.Done(): + return + default: + } + select { + case <-ctx.Done(): + return + case <-progChan: + default: + } + } +} + +func runProgFuncs(ctx context.Context) (*sync.WaitGroup, chan datas.PullProgress, chan datas.PullerEvent) { + pullerEventCh := make(chan datas.PullerEvent) + progChan := make(chan datas.PullProgress) + wg := &sync.WaitGroup{} + + wg.Add(1) + go func() { + defer wg.Done() + progFunc(ctx, progChan) + }() + + wg.Add(1) + go func() { + defer wg.Done() + pullerProgFunc(ctx, pullerEventCh) + }() + + return wg, progChan, pullerEventCh +} + +func stopProgFuncs(cancel context.CancelFunc, wg *sync.WaitGroup, progChan chan datas.PullProgress, pullerEventCh chan datas.PullerEvent) { + cancel() + close(progChan) + close(pullerEventCh) + wg.Wait() +} diff --git a/go/libraries/doltcore/sqle/dfunctions/init.go b/go/libraries/doltcore/sqle/dfunctions/init.go index 18f3384781..4a23543dae 100644 --- a/go/libraries/doltcore/sqle/dfunctions/init.go +++ b/go/libraries/doltcore/sqle/dfunctions/init.go @@ -33,6 +33,7 @@ var DoltFunctions = []sql.Function{ sql.FunctionN{Name: ConstraintsVerifyFuncName, Fn: NewConstraintsVerifyFunc}, sql.FunctionN{Name: ConstraintsVerifyAllFuncName, Fn: NewConstraintsVerifyAllFunc}, sql.FunctionN{Name: RevertFuncName, Fn: NewRevertFunc}, + sql.FunctionN{Name: DoltPullFuncName, Fn: NewPullFunc}, } // These are the DoltFunctions that get exposed to Dolthub Api. diff --git a/go/libraries/doltcore/sqle/dsess/session_state_adapter.go b/go/libraries/doltcore/sqle/dsess/session_state_adapter.go old mode 100755 new mode 100644 index 7669e12bbc..7234c2ce1f --- a/go/libraries/doltcore/sqle/dsess/session_state_adapter.go +++ b/go/libraries/doltcore/sqle/dsess/session_state_adapter.go @@ -17,12 +17,15 @@ package dsess import ( "context" "fmt" + "path/filepath" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/dolt/go/libraries/doltcore/dbfactory" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/doltcore/ref" + "github.com/dolthub/dolt/go/libraries/utils/filesys" ) // SessionStateAdapter is an adapter for env.RepoStateReader in SQL contexts, getting information about the repo state @@ -124,3 +127,16 @@ func (s SessionStateAdapter) AddRemote(name string, url string, fetchSpecs []str func (s SessionStateAdapter) RemoveRemote(ctx context.Context, name string) error { return fmt.Errorf("cannot delete remote in an SQL session") } + +func (s SessionStateAdapter) TempTableFilesDir() string { + //todo: save tempfile in dbState on server startup? + return mustAbs(dbfactory.DoltDir, "temptf") +} + +func mustAbs(path ...string) string { + absPath, err := filesys.LocalFS.Abs(filepath.Join(path...)) + if err != nil { + panic(err) + } + return absPath +} diff --git a/integration-tests/bats/remotes.bats b/integration-tests/bats/remotes.bats index 043d7a921f..090b9ea056 100644 --- a/integration-tests/bats/remotes.bats +++ b/integration-tests/bats/remotes.bats @@ -1110,4 +1110,4 @@ setup_ref_test() { run dolt fetch dadasdfasdfa [ "$status" -eq 1 ] [[ "$output" =~ "error: dadasdfasdfa does not appear to be a dolt database" ]] || false -} \ No newline at end of file +} diff --git a/integration-tests/bats/sql-pull.bats b/integration-tests/bats/sql-pull.bats new file mode 100644 index 0000000000..645b6acc50 --- /dev/null +++ b/integration-tests/bats/sql-pull.bats @@ -0,0 +1,229 @@ +#!/usr/bin/env bats +load $BATS_TEST_DIRNAME/helper/common.bash + +setup() { + setup_common + TMPDIRS=$(pwd)/tmpdirs + mkdir -p $TMPDIRS/{rem1,tmp1} + + # tmp1 -> rem1 -> tmp2 + cd $TMPDIRS/tmp1 + dolt init + dolt branch feature + dolt remote add origin file://../rem1 + dolt remote add test-remote file://../rem1 + dolt push origin master + + cd $TMPDIRS + dolt clone file://rem1 tmp2 + cd $TMPDIRS/tmp2 + dolt log + dolt branch feature + dolt remote add test-remote file://../rem1 + + # table and comits only present on tmp1, rem1 at start + cd $TMPDIRS/tmp1 + dolt sql -q "create table t1 (a int primary key, b int)" + dolt commit -am "First commit" + dolt sql -q "insert into t1 values (0,0)" + dolt commit -am "Second commit" + dolt push origin master + cd $TMPDIRS +} + +teardown() { + teardown_common + rm -rf $TMPDIRS + cd $BATS_TMPDIR +} + +@test "sql-pull: dolt_pull master" { + cd tmp2 + dolt sql -q "select dolt_pull('origin')" + run dolt sql -q "show tables" -r csv + [ "$status" -eq 0 ] + [ "${#lines[@]}" -eq 2 ] + [[ "$output" =~ "Table" ]] || false + [[ "$output" =~ "t1" ]] || false +} + +@test "sql-pull: dolt_pull custom remote" { + cd tmp2 + dolt sql -q "select dolt_pull('test-remote')" + run dolt sql -q "show tables" -r csv + [ "$status" -eq 0 ] + [ "${#lines[@]}" -eq 2 ] + [[ "$output" =~ "Table" ]] || false + [[ "$output" =~ "t1" ]] || false +} + +@test "sql-pull: dolt_pull default origin" { + cd tmp2 + dolt remote remove test-remote + dolt sql -q "select dolt_pull()" + run dolt sql -q "show tables" -r csv + [ "$status" -eq 0 ] + [ "${#lines[@]}" -eq 2 ] + [[ "$output" =~ "Table" ]] || false + [[ "$output" =~ "t1" ]] || false +} + +@test "sql-pull: dolt_pull default custom remote" { + cd tmp2 + dolt remote remove origin + dolt sql -q "select dolt_pull()" + run dolt sql -q "show tables" -r csv + [ "$status" -eq 0 ] + [ "${#lines[@]}" -eq 2 ] + [[ "$output" =~ "Table" ]] || false + [[ "$output" =~ "t1" ]] || false +} + +@test "sql-pull: dolt_pull up to date does not error" { + cd tmp2 + dolt sql -q "select dolt_pull('origin')" + dolt sql -q "select dolt_pull('origin')" + run dolt sql -q "show tables" -r csv + [ "$status" -eq 0 ] + [ "${#lines[@]}" -eq 2 ] + [[ "$output" =~ "Table" ]] || false + [[ "$output" =~ "t1" ]] || false +} + +@test "sql-pull: dolt_pull unknown remote fails" { + cd tmp2 + run dolt sql -q "select dolt_pull('unknown')" + [ "$status" -eq 1 ] + [[ "$output" =~ "unknown remote" ]] || false + [[ ! "$output" =~ "panic" ]] || false +} +@test "sql-pull: dolt_pull unknown feature branch fails" { + cd tmp2 + dolt checkout feature + run dolt sql -q "select dolt_pull('origin')" + [ "$status" -eq 1 ] + [[ "$output" =~ "branch not found" ]] || false + [[ ! "$output" =~ "panic" ]] || false +} + +@test "sql-pull: dolt_pull feature branch" { + cd tmp1 + dolt checkout feature + dolt merge master + dolt push origin feature + + cd ../tmp2 + dolt checkout feature + dolt sql -q "select dolt_pull('origin')" + run dolt sql -q "show tables" -r csv + [ "$status" -eq 0 ] + [ "${#lines[@]}" -eq 2 ] + [[ "$output" =~ "Table" ]] || false + [[ "$output" =~ "t1" ]] || false +} + +@test "sql-pull: dolt_pull force" { + skip "todo: support dolt pull --force (cli too)" + cd tmp2 + dolt sql -q "create table t2 (a int)" + dolt commit -am "2.0 commit" + dolt push origin master + + cd ../tmp1 + dolt sql -q "create table t2 (a int primary key)" + dolt sql -q "create table t3 (a int primary key)" + dolt commit -am "2.1 commit" + dolt push -f origin master + + cd ../tmp2 + run dolt sql -q "select dolt_pull('origin')" + [ "$status" -eq 1 ] + [[ ! "$output" =~ "panic" ]] || false + [[ "$output" =~ "fetch failed; dataset head is not ancestor of commit" ]] || false + + dolt sql -q "select dolt_pull('-f', 'origin')" + + run dolt log -n 1 + [ "$status" -eq 0 ] + [[ "$output" =~ "2.1 commit" ]] || false + + run dolt sql -q "show tables" -r csv + [ "${#lines[@]}" -eq 4 ] + [[ "$output" =~ "t3" ]] || false +} + +@test "sql-pull: dolt_pull squash" { + skip "todo: support dolt pull --squash (cli too)" + cd tmp2 + dolt sql -q "select dolt_pull('--squash', 'origin')" + run dolt sql -q "show tables" -r csv + [ "$status" -eq 0 ] + [ "${#lines[@]}" -eq 2 ] + [[ "$output" =~ "Table" ]] || false + [[ "$output" =~ "t1" ]] || false +} + +@test "sql-pull: dolt_pull --noff flag" { + cd tmp2 + dolt sql -q "select dolt_pull('--no-ff', 'origin')" + dolt status + run dolt log -n 1 + [ "$status" -eq 0 ] + # TODO change the default message name + [[ "$output" =~ "automatic SQL merge" ]] || false + + run dolt sql -q "show tables" -r csv + [ "$status" -eq 0 ] + [ "${#lines[@]}" -eq 2 ] + [[ "$output" =~ "Table" ]] || false + [[ "$output" =~ "t1" ]] || false +} + +@test "sql-pull: empty remote name does not panic" { + cd tmp2 + dolt sql -q "select dolt_pull('')" +} + +@test "sql-pull: dolt_pull dirty working set fails" { + cd tmp2 + dolt sql -q "create table t2 (a int)" + run dolt sql -q "select dolt_pull('origin')" + [ "$status" -eq 1 ] + [[ "$output" =~ "cannot merge with uncommitted changes" ]] || false +} + +@test "sql-pull: dolt_pull tag" { + cd tmp1 + dolt tag v1 + dolt push origin v1 + dolt tag + + cd ../tmp2 + dolt sql -q "select dolt_pull('origin')" + run dolt tag + [ "$status" -eq 0 ] + [[ "$output" =~ "v1" ]] || false +} + +@test "sql-pull: dolt_pull tags only for resolved commits" { + cd tmp1 + dolt tag v1 head + dolt tag v2 head^ + dolt push origin v1 + dolt push origin v2 + + dolt checkout feature + dolt sql -q "create table t2 (a int)" + dolt commit -am "feature commit" + dolt tag v3 + dolt push origin v3 + + cd ../tmp2 + dolt sql -q "select dolt_pull('origin')" + run dolt tag + [ "$status" -eq 0 ] + [[ "$output" =~ "v1" ]] || false + [[ "$output" =~ "v2" ]] || false + [[ ! "$output" =~ "v3" ]] || false +} +