sql server dolt_pull function first pass (#2102)

* sql_pull compiles

* Fix sql-merge bats

* Fix merge bats tests, add pull

* Cleanup comments

* Some of zach's comments

* Fresh set of sql-pull unittests

* bats name typo

* Windows doesn't support file:/// with 3 slashes

* Add one more comment line

* fix one of the check errors

* Add short circuits to cli progress goroutines

* [ga-format-pr] Run go/utils/repofmt/format_repo.sh and go/Godeps/update.sh

Co-authored-by: max-hoffman <max-hoffman@users.noreply.github.com>
This commit is contained in:
Maximilian Hoffman
2021-09-10 14:27:37 -07:00
committed by GitHub
parent 4e8bb133d5
commit 72cd1109fa
18 changed files with 745 additions and 247 deletions

View File

@@ -141,3 +141,12 @@ func CreateRevertArgParser() *argparser.ArgParser {
return ap
}
// CreatePullArgParser creates the argument parser for the dolt pull command.
// It supports the same merge-related flags as the merge parser above:
// --squash, --no-ff, and -f/--force.
// NOTE(review): the --force help text says "proceeds with the commit" — this
// wording appears copied from the merge parser; confirm it reads correctly
// in the context of pull.
func CreatePullArgParser() *argparser.ArgParser {
ap := argparser.NewArgParser()
ap.SupportsFlag(SquashParam, "", "Merges changes to the working set without updating the commit history")
ap.SupportsFlag(NoFFParam, "", "Create a merge commit even when the merge resolves as a fast-forward.")
ap.SupportsFlag(ForceFlag, "f", "Ignores any foreign key warnings and proceeds with the commit.")
return ap
}

View File

@@ -121,7 +121,7 @@ func getRefSpecs(args []string, dEnv *env.DoltEnv, remotes map[string]env.Remote
if len(args) != 0 {
rs, verr = parseRSFromArgs(remName, args)
} else {
rs, err = dEnv.GetRefSpecs(remName)
rs, err = env.GetRefSpecs(dEnv.RepoStateReader(), remName)
if err != nil {
verr = errhand.VerboseErrorFromError(err)
}
@@ -203,7 +203,7 @@ func fetchRefSpecs(ctx context.Context, mode ref.UpdateMode, dEnv *env.DoltEnv,
if remoteTrackRef != nil {
rsSeen = true
srcDBCommit, err := actions.FetchRemoteBranch(ctx, dEnv, rem, srcDB, dEnv.DoltDB, branchRef, remoteTrackRef, runProgFuncs, stopProgFuncs)
srcDBCommit, err := actions.FetchRemoteBranch(ctx, dEnv.TempTableFilesDir(), rem, srcDB, dEnv.DoltDB, branchRef, remoteTrackRef, runProgFuncs, stopProgFuncs)
if err != nil {
return errhand.VerboseErrorFromError(err)
@@ -242,7 +242,7 @@ func fetchRefSpecs(ctx context.Context, mode ref.UpdateMode, dEnv *env.DoltEnv,
}
}
err = actions.FetchFollowTags(ctx, dEnv, srcDB, dEnv.DoltDB, runProgFuncs, stopProgFuncs)
err = actions.FetchFollowTags(ctx, dEnv.TempTableFilesDir(), srcDB, dEnv.DoltDB, runProgFuncs, stopProgFuncs)
if err != nil {
return errhand.VerboseErrorFromError(err)

View File

@@ -158,7 +158,17 @@ func (cmd MergeCmd) Exec(ctx context.Context, commandStr string, args []string,
msg = m
}
mergeSpec, ok, err := merge.ParseMergeSpec(ctx, dEnv, msg, commitSpecStr, apr.Contains(cli.SquashParam), apr.Contains(cli.NoFFParam), apr.Contains(cli.ForceFlag), t)
roots, err := dEnv.Roots(ctx)
if err != nil {
return handleCommitErr(ctx, dEnv, err, usage)
}
name, email, err := env.GetNameAndEmail(dEnv.Config)
if err != nil {
return handleCommitErr(ctx, dEnv, err, usage)
}
spec, ok, err := merge.NewMergeSpec(ctx, dEnv.RepoStateReader(), dEnv.DoltDB, roots, name, email, msg, commitSpecStr, apr.Contains(cli.SquashParam), apr.Contains(cli.NoFFParam), apr.Contains(cli.ForceFlag), t)
if err != nil {
return handleCommitErr(ctx, dEnv, errhand.VerboseErrorFromError(err), usage)
}
@@ -167,12 +177,12 @@ func (cmd MergeCmd) Exec(ctx context.Context, commandStr string, args []string,
return handleCommitErr(ctx, dEnv, nil, usage)
}
err = mergePrinting(ctx, dEnv, mergeSpec)
err = mergePrinting(ctx, dEnv, spec)
if err != nil {
return handleCommitErr(ctx, dEnv, err, usage)
}
tblToStats, err := merge.MergeCommitSpec(ctx, dEnv, mergeSpec)
tblToStats, err := merge.MergeCommitSpec(ctx, dEnv, spec)
hasConflicts, hasConstraintViolations := printSuccessStats(tblToStats)
if hasConflicts && hasConstraintViolations {
cli.Println("Automatic merge failed; fix conflicts and constraint violations and then commit the result.")
@@ -200,47 +210,47 @@ func (cmd MergeCmd) Exec(ctx context.Context, commandStr string, args []string,
return handleCommitErr(ctx, dEnv, verr, usage)
}
func mergePrinting(ctx context.Context, dEnv *env.DoltEnv, mergeSpec *merge.MergeSpec) errhand.VerboseError {
if mergeSpec.H1 == mergeSpec.H2 {
func mergePrinting(ctx context.Context, dEnv *env.DoltEnv, spec *merge.MergeSpec) errhand.VerboseError {
if spec.HeadH == spec.MergeH {
//TODO - why is this different for merge/pull?
// cli.Println("Already up to date.")
cli.Println("Everything up-to-date.")
return nil
}
cli.Println("Updating", mergeSpec.H1.String()+".."+mergeSpec.H2.String())
cli.Println("Updating", spec.HeadH.String()+".."+spec.MergeH.String())
if mergeSpec.Squash {
if spec.Squash {
cli.Println("Squash commit -- not updating HEAD")
}
if len(mergeSpec.TblNames) != 0 {
if len(spec.TblNames) != 0 {
bldr := errhand.BuildDError("error: Your local changes to the following tables would be overwritten by merge:")
for _, tName := range mergeSpec.TblNames {
for _, tName := range spec.TblNames {
bldr.AddDetails(tName)
}
bldr.AddDetails("Please commit your changes before you merge.")
return bldr.Build()
}
if ok, err := mergeSpec.Cm1.CanFastForwardTo(ctx, mergeSpec.Cm2); ok {
ancRoot, err := mergeSpec.Cm1.GetRootValue()
if ok, err := spec.HeadC.CanFastForwardTo(ctx, spec.MergeC); ok {
ancRoot, err := spec.HeadC.GetRootValue()
if err != nil {
return errhand.VerboseErrorFromError(err)
}
mergedRoot, err := mergeSpec.Cm2.GetRootValue()
mergedRoot, err := spec.MergeC.GetRootValue()
if err != nil {
return errhand.VerboseErrorFromError(err)
}
if _, err := merge.MayHaveConstraintViolations(ctx, ancRoot, mergedRoot); err != nil {
return errhand.VerboseErrorFromError(err)
}
if mergeSpec.Noff {
if mergeSpec.Msg == "" {
if spec.Noff {
if spec.Msg == "" {
msg, err := getCommitMessageFromEditor(ctx, dEnv)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
mergeSpec.Msg = msg
spec.Msg = msg
}
} else {
cli.Println("Fast-forward")

View File

@@ -54,7 +54,7 @@ func (cmd PullCmd) Description() string {
// CreateMarkdown creates a markdown file containing the helptext for the command at the given path
func (cmd PullCmd) CreateMarkdown(fs filesys.Filesys, path, commandStr string) error {
ap := cmd.createArgParser()
ap := cli.CreatePullArgParser()
return CreateMarkdown(fs, path, cli.GetCommandDocumentation(commandStr, pullDocs, ap))
}
@@ -72,7 +72,7 @@ func (cmd PullCmd) EventType() eventsapi.ClientEventType {
// Exec executes the command
func (cmd PullCmd) Exec(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEnv) int {
ap := cmd.createArgParser()
ap := cli.CreatePullArgParser()
help, usage := cli.HelpAndUsagePrinters(cli.GetCommandDocumentation(commandStr, pullDocs, ap))
apr := cli.ParseArgsOrDie(ap, args, help)
@@ -87,7 +87,7 @@ func (cmd PullCmd) Exec(ctx context.Context, commandStr string, args []string, d
remoteName = apr.Arg(0)
}
pullSpec, err := env.ParsePullSpec(ctx, dEnv, remoteName, apr.Contains(cli.SquashParam), apr.Contains(cli.NoFFParam), apr.Contains(cli.ForceFlag))
pullSpec, err := env.NewPullSpec(ctx, dEnv.RepoStateReader(), remoteName, apr.Contains(cli.SquashParam), apr.Contains(cli.NoFFParam), apr.Contains(cli.ForceFlag))
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
}
@@ -111,7 +111,7 @@ func pullHelper(ctx context.Context, dEnv *env.DoltEnv, pullSpec *env.PullSpec)
if remoteTrackRef != nil {
srcDBCommit, err := actions.FetchRemoteBranch(ctx, dEnv, pullSpec.Remote, srcDB, dEnv.DoltDB, pullSpec.Branch, remoteTrackRef, runProgFuncs, stopProgFuncs)
srcDBCommit, err := actions.FetchRemoteBranch(ctx, dEnv.TempTableFilesDir(), pullSpec.Remote, srcDB, dEnv.DoltDB, pullSpec.Branch, remoteTrackRef, runProgFuncs, stopProgFuncs)
if err != nil {
return err
}
@@ -122,7 +122,18 @@ func pullHelper(ctx context.Context, dEnv *env.DoltEnv, pullSpec *env.PullSpec)
}
t := doltdb.CommitNowFunc()
mergeSpec, ok, err := merge.ParseMergeSpec(ctx, dEnv, pullSpec.Msg, remoteTrackRef.String(), pullSpec.Squash, pullSpec.Noff, pullSpec.Force, t)
roots, err := dEnv.Roots(ctx)
if err != nil {
return err
}
name, email, err := env.GetNameAndEmail(dEnv.Config)
if err != nil {
return err
}
mergeSpec, ok, err := merge.NewMergeSpec(ctx, dEnv.RepoStateReader(), dEnv.DoltDB, roots, name, email, pullSpec.Msg, remoteTrackRef.String(), pullSpec.Squash, pullSpec.Noff, pullSpec.Force, t)
if err != nil {
return err
}
@@ -142,12 +153,10 @@ func pullHelper(ctx context.Context, dEnv *env.DoltEnv, pullSpec *env.PullSpec)
}
}
srcDB, err = pullSpec.Remote.GetRemoteDB(ctx, dEnv.DoltDB.ValueReadWriter().Format())
if err != nil {
return err
}
err = actions.FetchFollowTags(ctx, dEnv, srcDB, dEnv.DoltDB, runProgFuncs, stopProgFuncs)
err = actions.FetchFollowTags(ctx, dEnv.TempTableFilesDir(), srcDB, dEnv.DoltDB, runProgFuncs, stopProgFuncs)
if err != nil {
return err

View File

@@ -96,7 +96,7 @@ func (cmd PushCmd) Exec(ctx context.Context, commandStr string, args []string, d
case env.ErrNoUpstreamForBranch:
currentBranch := dEnv.RepoStateReader().CWBHeadRef()
remoteName := "<remote>"
if defRemote, verr := dEnv.GetDefaultRemote(); verr == nil {
if defRemote, verr := env.GetDefaultRemote(dEnv.RepoStateReader()); verr == nil {
remoteName = defRemote.Name
}
verr = errhand.BuildDError("fatal: The current branch " + currentBranch.GetPath() + " has no upstream branch.\n" +
@@ -166,7 +166,7 @@ func (ts *TextSpinner) next() string {
return string([]rune{spinnerSeq[ts.seqPos]})
}
func pullerProgFunc(pullerEventCh chan datas.PullerEvent) {
func pullerProgFunc(ctx context.Context, pullerEventCh chan datas.PullerEvent) {
var pos int
var currentTreeLevel int
var percentBuffered float64
@@ -177,6 +177,11 @@ func pullerProgFunc(pullerEventCh chan datas.PullerEvent) {
uploadRate := ""
for evt := range pullerEventCh {
select {
case <-ctx.Done():
return
default:
}
switch evt.EventType {
case datas.NewLevelTWEvent:
if evt.TWEventDetails.TreeLevel != 1 {
@@ -228,19 +233,25 @@ func pullerProgFunc(pullerEventCh chan datas.PullerEvent) {
}
}
func progFunc(progChan chan datas.PullProgress) {
func progFunc(ctx context.Context, progChan chan datas.PullProgress) {
var latest datas.PullProgress
last := time.Now().UnixNano() - 1
lenPrinted := 0
done := false
for !done {
select {
case <-ctx.Done():
return
default:
}
select {
case <-ctx.Done():
return
case progress, ok := <-progChan:
if !ok {
done = true
}
latest = progress
case <-time.After(250 * time.Millisecond):
break
}
@@ -262,7 +273,7 @@ func progFunc(progChan chan datas.PullProgress) {
}
}
func runProgFuncs() (*sync.WaitGroup, chan datas.PullProgress, chan datas.PullerEvent) {
func runProgFuncs(ctx context.Context) (*sync.WaitGroup, chan datas.PullProgress, chan datas.PullerEvent) {
pullerEventCh := make(chan datas.PullerEvent, 128)
progChan := make(chan datas.PullProgress, 128)
wg := &sync.WaitGroup{}
@@ -270,19 +281,20 @@ func runProgFuncs() (*sync.WaitGroup, chan datas.PullProgress, chan datas.Puller
wg.Add(1)
go func() {
defer wg.Done()
progFunc(progChan)
progFunc(ctx, progChan)
}()
wg.Add(1)
go func() {
defer wg.Done()
pullerProgFunc(pullerEventCh)
pullerProgFunc(ctx, pullerEventCh)
}()
return wg, progChan, pullerEventCh
}
func stopProgFuncs(wg *sync.WaitGroup, progChan chan datas.PullProgress, pullerEventCh chan datas.PullerEvent) {
func stopProgFuncs(cancel context.CancelFunc, wg *sync.WaitGroup, progChan chan datas.PullProgress, pullerEventCh chan datas.PullerEvent) {
cancel()
close(progChan)
close(pullerEventCh)
wg.Wait()

View File

@@ -171,20 +171,15 @@ func pullTableValue(ctx context.Context, dEnv *env.DoltEnv, srcDB *doltdb.DoltDB
return nil, errhand.BuildDError("Unable to read from remote database.").AddCause(err).Build()
}
newCtx, cancelFunc := context.WithCancel(ctx)
cli.Println("Retrieving", tblName)
wg, progChan, pullerEventCh := runProgFuncs()
wg, progChan, pullerEventCh := runProgFuncs(newCtx)
err = dEnv.DoltDB.PushChunksForRefHash(ctx, dEnv.TempTableFilesDir(), srcDB, tblHash, pullerEventCh)
stopProgFuncs(cancelFunc, wg, progChan, pullerEventCh)
if err != nil {
return nil, errhand.BuildDError("Failed reading chunks for remote table '%s' at '%s'", tblName, commitStr).AddCause(err).Build()
}
stopProgFuncs(wg, progChan, pullerEventCh)
if err != nil {
return nil, errhand.BuildDError("Failed to pull chunks.").AddCause(err).Build()
}
destRoot, err = destRoot.SetTableHash(ctx, tblName, tblHash)
if err != nil {

View File

@@ -160,7 +160,7 @@ func DeleteBranchOnDB(ctx context.Context, ddb *doltdb.DoltDB, dref ref.DoltRef,
}
isMerged, _ := master.CanFastReverseTo(ctx, cm)
if err != nil && err != doltdb.ErrUpToDate {
if err != nil && errors.Is(err, doltdb.ErrUpToDate) {
return err
}
if !isMerged {

View File

@@ -39,8 +39,8 @@ var ErrFailedToDeleteRemote = errors.New("failed to delete remote")
var ErrFailedToGetRemoteDb = errors.New("failed to get remote db")
var ErrUnknownPushErr = errors.New("unknown push error")
type ProgStarter func() (*sync.WaitGroup, chan datas.PullProgress, chan datas.PullerEvent)
type ProgStopper func(wg *sync.WaitGroup, progChan chan datas.PullProgress, pullerEventCh chan datas.PullerEvent)
type ProgStarter func(ctx context.Context) (*sync.WaitGroup, chan datas.PullProgress, chan datas.PullerEvent)
type ProgStopper func(cancel context.CancelFunc, wg *sync.WaitGroup, progChan chan datas.PullProgress, pullerEventCh chan datas.PullerEvent)
// Push will update a destination branch in a given destination database if it can be done as a fast forward merge.
// This is accomplished first by verifying that the remote tracking reference for the source database can be updated to
@@ -189,9 +189,10 @@ func PushToRemoteBranch(ctx context.Context, dEnv *env.DoltEnv, mode ref.UpdateM
return fmt.Errorf("%w; refspec not found: '%s'; %s", ref.ErrInvalidRefSpec, srcRef.GetPath(), err.Error())
}
wg, progChan, pullerEventCh := progStarter()
newCtx, cancelFunc := context.WithCancel(ctx)
wg, progChan, pullerEventCh := progStarter(newCtx)
err = Push(ctx, dEnv, mode, destRef.(ref.BranchRef), remoteRef.(ref.RemoteRef), localDB, remoteDB, cm, progChan, pullerEventCh)
progStopper(wg, progChan, pullerEventCh)
progStopper(cancelFunc, wg, progChan, pullerEventCh)
if err != nil {
switch err {
@@ -212,11 +213,10 @@ func pushTagToRemote(ctx context.Context, dEnv *env.DoltEnv, srcRef, destRef ref
return err
}
wg, progChan, pullerEventCh := progStarter()
newCtx, cancelFunc := context.WithCancel(ctx)
wg, progChan, pullerEventCh := progStarter(newCtx)
err = PushTag(ctx, dEnv, destRef.(ref.TagRef), localDB, remoteDB, tg, progChan, pullerEventCh)
progStopper(wg, progChan, pullerEventCh)
progStopper(cancelFunc, wg, progChan, pullerEventCh)
if err != nil {
return err
@@ -252,25 +252,25 @@ func DeleteRemoteBranch(ctx context.Context, targetRef ref.BranchRef, remoteRef
}
// FetchCommit fetches a commit and all underlying data from a remote source database to the local destination database.
func FetchCommit(ctx context.Context, dEnv *env.DoltEnv, srcDB, destDB *doltdb.DoltDB, srcDBCommit *doltdb.Commit, progChan chan datas.PullProgress, pullerEventCh chan datas.PullerEvent) error {
func FetchCommit(ctx context.Context, tempTablesDir string, srcDB, destDB *doltdb.DoltDB, srcDBCommit *doltdb.Commit, progChan chan datas.PullProgress, pullerEventCh chan datas.PullerEvent) error {
stRef, err := srcDBCommit.GetStRef()
if err != nil {
return err
}
return destDB.PullChunks(ctx, dEnv.TempTableFilesDir(), srcDB, stRef, progChan, pullerEventCh)
return destDB.PullChunks(ctx, tempTablesDir, srcDB, stRef, progChan, pullerEventCh)
}
// FetchTag fetches a commit tag and all underlying data from a remote source database to the local destination database.
func FetchTag(ctx context.Context, dEnv *env.DoltEnv, srcDB, destDB *doltdb.DoltDB, srcDBTag *doltdb.Tag, progChan chan datas.PullProgress, pullerEventCh chan datas.PullerEvent) error {
func FetchTag(ctx context.Context, tempTableDir string, srcDB, destDB *doltdb.DoltDB, srcDBTag *doltdb.Tag, progChan chan datas.PullProgress, pullerEventCh chan datas.PullerEvent) error {
stRef, err := srcDBTag.GetStRef()
if err != nil {
return err
}
return destDB.PullChunks(ctx, dEnv.TempTableFilesDir(), srcDB, stRef, progChan, pullerEventCh)
return destDB.PullChunks(ctx, tempTableDir, srcDB, stRef, progChan, pullerEventCh)
}
// Clone pulls all data from a remote source database to a local destination database.
@@ -281,7 +281,7 @@ func Clone(ctx context.Context, srcDB, destDB *doltdb.DoltDB, eventCh chan<- dat
// FetchFollowTags fetches all tags from the source DB whose commits have already
// been fetched into the destination DB.
// todo: potentially too expensive to iterate over all srcDB tags
func FetchFollowTags(ctx context.Context, dEnv *env.DoltEnv, srcDB, destDB *doltdb.DoltDB, progStarter ProgStarter, progStopper ProgStopper) error {
func FetchFollowTags(ctx context.Context, tempTableDir string, srcDB, destDB *doltdb.DoltDB, progStarter ProgStarter, progStopper ProgStopper) error {
err := IterResolvedTags(ctx, srcDB, func(tag *doltdb.Tag) (stop bool, err error) {
stRef, err := tag.GetStRef()
if err != nil {
@@ -313,9 +313,10 @@ func FetchFollowTags(ctx context.Context, dEnv *env.DoltEnv, srcDB, destDB *dolt
return false, nil
}
wg, progChan, pullerEventCh := progStarter()
err = FetchTag(ctx, dEnv, srcDB, destDB, tag, progChan, pullerEventCh)
progStopper(wg, progChan, pullerEventCh)
newCtx, cancelFunc := context.WithCancel(ctx)
wg, progChan, pullerEventCh := progStarter(newCtx)
err = FetchTag(ctx, tempTableDir, srcDB, destDB, tag, progChan, pullerEventCh)
progStopper(cancelFunc, wg, progChan, pullerEventCh)
if err != nil {
return true, err
@@ -333,7 +334,7 @@ func FetchFollowTags(ctx context.Context, dEnv *env.DoltEnv, srcDB, destDB *dolt
return nil
}
func FetchRemoteBranch(ctx context.Context, dEnv *env.DoltEnv, rem env.Remote, srcDB, destDB *doltdb.DoltDB, srcRef, destRef ref.DoltRef, progStarter ProgStarter, progStopper ProgStopper) (*doltdb.Commit, error) {
func FetchRemoteBranch(ctx context.Context, tempTablesDir string, rem env.Remote, srcDB, destDB *doltdb.DoltDB, srcRef, destRef ref.DoltRef, progStarter ProgStarter, progStopper ProgStopper) (*doltdb.Commit, error) {
evt := events.GetEventFromContext(ctx)
u, err := earl.Parse(rem.Url)
@@ -351,9 +352,10 @@ func FetchRemoteBranch(ctx context.Context, dEnv *env.DoltEnv, rem env.Remote, s
return nil, fmt.Errorf("unable to find '%s' on '%s'; %w", srcRef.GetPath(), rem.Name, err)
}
wg, progChan, pullerEventCh := progStarter()
err = FetchCommit(ctx, dEnv, srcDB, destDB, srcDBCommit, progChan, pullerEventCh)
progStopper(wg, progChan, pullerEventCh)
newCtx, cancelFunc := context.WithCancel(ctx)
wg, progChan, pullerEventCh := progStarter(newCtx)
err = FetchCommit(ctx, tempTablesDir, srcDB, destDB, srcDBCommit, progChan, pullerEventCh)
progStopper(cancelFunc, wg, progChan, pullerEventCh)
if err != nil {
return nil, err

View File

@@ -550,19 +550,15 @@ func (dEnv *DoltEnv) UpdateWorkingSet(ctx context.Context, ws *doltdb.WorkingSet
}
type repoStateReader struct {
dEnv *DoltEnv
*DoltEnv
}
func (r *repoStateReader) CWBHeadRef() ref.DoltRef {
return r.dEnv.RepoState.CWBHeadRef()
return r.RepoState.CWBHeadRef()
}
func (r *repoStateReader) CWBHeadSpec() *doltdb.CommitSpec {
return r.dEnv.RepoState.CWBHeadSpec()
}
func (r *repoStateReader) GetRemotes() (map[string]Remote, error) {
return r.dEnv.GetRemotes()
return r.RepoState.CWBHeadSpec()
}
func (dEnv *DoltEnv) RepoStateReader() RepoStateReader {
@@ -971,13 +967,17 @@ func (dEnv *DoltEnv) FindRef(ctx context.Context, refStr string) (ref.DoltRef, e
// GetRefSpecs takes an optional remoteName and returns all refspecs associated with that remote. If "" is passed as
// the remoteName then the default remote is used.
func (dEnv *DoltEnv) GetRefSpecs(remoteName string) ([]ref.RemoteRefSpec, error) {
func GetRefSpecs(rsr RepoStateReader, remoteName string) ([]ref.RemoteRefSpec, error) {
var remote Remote
var err error
remotes, err := rsr.GetRemotes()
if err != nil {
return nil, err
}
if remoteName == "" {
remote, err = dEnv.GetDefaultRemote()
} else if r, ok := dEnv.RepoState.Remotes[remoteName]; ok {
remote, err = GetDefaultRemote(rsr)
} else if r, ok := remotes[remoteName]; ok {
remote = r
} else {
err = ErrUnknownRemote
@@ -1014,8 +1014,11 @@ var ErrCantDetermineDefault = errors.New("unable to determine the default remote
// GetDefaultRemote gets the default remote for the environment. Not fully implemented yet. Needs to support multiple
// repos and a configurable default.
func (dEnv *DoltEnv) GetDefaultRemote() (Remote, error) {
remotes := dEnv.RepoState.Remotes
func GetDefaultRemote(rsr RepoStateReader) (Remote, error) {
remotes, err := rsr.GetRemotes()
if err != nil {
return NoRemote, err
}
if len(remotes) == 0 {
return NoRemote, ErrNoRemote
@@ -1025,7 +1028,7 @@ func (dEnv *DoltEnv) GetDefaultRemote() (Remote, error) {
}
}
if remote, ok := dEnv.RepoState.Remotes["origin"]; ok {
if remote, ok := remotes["origin"]; ok {
return remote, nil
}

View File

@@ -286,10 +286,10 @@ type PullSpec struct {
Branch ref.DoltRef
}
func ParsePullSpec(ctx context.Context, dEnv *DoltEnv, remoteName string, squash, noff, force bool) (*PullSpec, error) {
branch := dEnv.RepoStateReader().CWBHeadRef()
func NewPullSpec(ctx context.Context, rsr RepoStateReader, remoteName string, squash, noff, force bool) (*PullSpec, error) {
branch := rsr.CWBHeadRef()
refSpecs, err := dEnv.GetRefSpecs(remoteName)
refSpecs, err := GetRefSpecs(rsr, remoteName)
if err != nil {
return nil, err
}
@@ -298,7 +298,11 @@ func ParsePullSpec(ctx context.Context, dEnv *DoltEnv, remoteName string, squash
return nil, ErrNoRefSpecForRemote
}
remote := dEnv.RepoState.Remotes[refSpecs[0].GetRemote()]
remotes, err := rsr.GetRemotes()
if err != nil {
return nil, err
}
remote := remotes[refSpecs[0].GetRemote()]
return &PullSpec{
Squash: squash,

View File

@@ -39,6 +39,7 @@ type RepoStateWriter interface {
SetCWBHeadRef(context.Context, ref.MarshalableRef) error
AddRemote(name string, url string, fetchSpecs []string, params map[string]string) error
RemoveRemote(ctx context.Context, name string) error
TempTableFilesDir() string
}
type DocsReadWriter interface {

View File

@@ -35,10 +35,10 @@ var ErrMergeFailedToUpdateRepoState = errors.New("unable to execute repo state u
var ErrFailedToDetermineMergeability = errors.New("failed to determine mergeability")
type MergeSpec struct {
H1 hash.Hash
H2 hash.Hash
Cm1 *doltdb.Commit
Cm2 *doltdb.Commit
HeadH hash.Hash
MergeH hash.Hash
HeadC *doltdb.Commit
MergeC *doltdb.Commit
TblNames []string
WorkingDiffs map[string]hash.Hash
Squash bool
@@ -51,13 +51,13 @@ type MergeSpec struct {
Date time.Time
}
func ParseMergeSpec(ctx context.Context, dEnv *env.DoltEnv, msg string, commitSpecStr string, squash bool, noff bool, force bool, date time.Time) (*MergeSpec, bool, error) {
func NewMergeSpec(ctx context.Context, rsr env.RepoStateReader, ddb *doltdb.DoltDB, roots doltdb.Roots, name, email, msg string, commitSpecStr string, squash bool, noff bool, force bool, date time.Time) (*MergeSpec, bool, error) {
cs1, err := doltdb.NewCommitSpec("HEAD")
if err != nil {
return nil, false, err
}
cm1, err := dEnv.DoltDB.Resolve(context.TODO(), cs1, dEnv.RepoStateReader().CWBHeadRef())
cm1, err := ddb.Resolve(context.TODO(), cs1, rsr.CWBHeadRef())
if err != nil {
return nil, false, err
}
@@ -67,7 +67,7 @@ func ParseMergeSpec(ctx context.Context, dEnv *env.DoltEnv, msg string, commitSp
return nil, false, err
}
cm2, err := dEnv.DoltDB.Resolve(context.TODO(), cs2, dEnv.RepoStateReader().CWBHeadRef())
cm2, err := ddb.Resolve(context.TODO(), cs2, rsr.CWBHeadRef())
if err != nil {
return nil, false, err
}
@@ -83,26 +83,16 @@ func ParseMergeSpec(ctx context.Context, dEnv *env.DoltEnv, msg string, commitSp
}
roots, err := dEnv.Roots(ctx)
if err != nil {
return nil, false, err
}
tblNames, workingDiffs, err := MergeWouldStompChanges(ctx, roots, cm2)
if err != nil {
return nil, false, fmt.Errorf("%w; %s", ErrFailedToDetermineMergeability, err.Error())
}
name, email, err := env.GetNameAndEmail(dEnv.Config)
if err != nil {
return nil, false, err
}
return &MergeSpec{
H1: h1,
H2: h2,
Cm1: cm1,
Cm2: cm2,
HeadH: h1,
MergeH: h2,
HeadC: cm1,
MergeC: cm2,
TblNames: tblNames,
WorkingDiffs: workingDiffs,
Squash: squash,
@@ -115,42 +105,42 @@ func ParseMergeSpec(ctx context.Context, dEnv *env.DoltEnv, msg string, commitSp
}, true, nil
}
func MergeCommitSpec(ctx context.Context, dEnv *env.DoltEnv, mergeSpec *MergeSpec) (map[string]*MergeStats, error) {
if ok, err := mergeSpec.Cm1.CanFastForwardTo(ctx, mergeSpec.Cm2); ok {
ancRoot, err := mergeSpec.Cm1.GetRootValue()
func MergeCommitSpec(ctx context.Context, dEnv *env.DoltEnv, spec *MergeSpec) (map[string]*MergeStats, error) {
if ok, err := spec.HeadC.CanFastForwardTo(ctx, spec.MergeC); ok {
ancRoot, err := spec.HeadC.GetRootValue()
if err != nil {
return nil, err
}
mergedRoot, err := mergeSpec.Cm2.GetRootValue()
mergedRoot, err := spec.MergeC.GetRootValue()
if err != nil {
return nil, err
}
if cvPossible, err := MayHaveConstraintViolations(ctx, ancRoot, mergedRoot); err != nil {
return nil, err
} else if cvPossible {
return ExecuteMerge(ctx, dEnv, mergeSpec)
return ExecuteMerge(ctx, dEnv, spec)
}
if mergeSpec.Noff {
return ExecNoFFMerge(ctx, dEnv, mergeSpec)
if spec.Noff {
return ExecNoFFMerge(ctx, dEnv, spec)
} else {
return nil, ExecuteFFMerge(ctx, dEnv, mergeSpec)
return nil, ExecuteFFMerge(ctx, dEnv, spec)
}
} else if err == doltdb.ErrUpToDate || err == doltdb.ErrIsAhead {
return nil, err
} else {
return ExecuteMerge(ctx, dEnv, mergeSpec)
return ExecuteMerge(ctx, dEnv, spec)
}
}
func ExecNoFFMerge(ctx context.Context, dEnv *env.DoltEnv, mergeSpec *MergeSpec) (map[string]*MergeStats, error) {
mergedRoot, err := mergeSpec.Cm2.GetRootValue()
func ExecNoFFMerge(ctx context.Context, dEnv *env.DoltEnv, spec *MergeSpec) (map[string]*MergeStats, error) {
mergedRoot, err := spec.MergeC.GetRootValue()
if err != nil {
return nil, ErrFailedToReadDatabase
}
tblToStats := make(map[string]*MergeStats)
err = mergedRootToWorking(ctx, false, dEnv, mergedRoot, mergeSpec.WorkingDiffs, mergeSpec.Cm2, tblToStats)
err = mergedRootToWorking(ctx, false, dEnv, mergedRoot, spec.WorkingDiffs, spec.MergeC, tblToStats)
if err != nil {
return tblToStats, err
@@ -173,12 +163,12 @@ func ExecNoFFMerge(ctx context.Context, dEnv *env.DoltEnv, mergeSpec *MergeSpec)
}
_, err = actions.CommitStaged(ctx, roots, ws.MergeActive(), mergeParentCommits, dEnv.DbData(), actions.CommitStagedProps{
Message: mergeSpec.Msg,
Date: mergeSpec.Date,
AllowEmpty: mergeSpec.AllowEmpty,
Force: mergeSpec.Force,
Name: mergeSpec.Name,
Email: mergeSpec.Email,
Message: spec.Msg,
Date: spec.Date,
AllowEmpty: spec.AllowEmpty,
Force: spec.Force,
Name: spec.Name,
Email: spec.Email,
})
if err != nil {
@@ -209,16 +199,16 @@ func applyChanges(ctx context.Context, root *doltdb.RootValue, workingDiffs map[
func ExecuteFFMerge(
ctx context.Context,
dEnv *env.DoltEnv,
mergeSpec *MergeSpec,
spec *MergeSpec,
) error {
stagedRoot, err := mergeSpec.Cm2.GetRootValue()
stagedRoot, err := spec.MergeC.GetRootValue()
if err != nil {
return err
}
workingRoot := stagedRoot
if len(mergeSpec.WorkingDiffs) > 0 {
workingRoot, err = applyChanges(ctx, stagedRoot, mergeSpec.WorkingDiffs)
if len(spec.WorkingDiffs) > 0 {
workingRoot, err = applyChanges(ctx, stagedRoot, spec.WorkingDiffs)
if err != nil {
//return errhand.BuildDError("Failed to re-apply working changes.").AddCause(err).Build()
@@ -231,8 +221,8 @@ func ExecuteFFMerge(
return err
}
if !mergeSpec.Squash {
err = dEnv.DoltDB.FastForward(ctx, dEnv.RepoStateReader().CWBHeadRef(), mergeSpec.Cm2)
if !spec.Squash {
err = dEnv.DoltDB.FastForward(ctx, dEnv.RepoStateReader().CWBHeadRef(), spec.MergeC)
if err != nil {
return err
@@ -257,9 +247,9 @@ func ExecuteFFMerge(
return nil
}
func ExecuteMerge(ctx context.Context, dEnv *env.DoltEnv, mergeSpec *MergeSpec) (map[string]*MergeStats, error) {
func ExecuteMerge(ctx context.Context, dEnv *env.DoltEnv, spec *MergeSpec) (map[string]*MergeStats, error) {
opts := editor.Options{Deaf: dEnv.DbEaFactory()}
mergedRoot, tblToStats, err := MergeCommits(ctx, mergeSpec.Cm1, mergeSpec.Cm2, opts)
mergedRoot, tblToStats, err := MergeCommits(ctx, spec.HeadC, spec.MergeC, opts)
if err != nil {
switch err {
case doltdb.ErrUpToDate:
@@ -270,7 +260,7 @@ func ExecuteMerge(ctx context.Context, dEnv *env.DoltEnv, mergeSpec *MergeSpec)
return tblToStats, err
}
return tblToStats, mergedRootToWorking(ctx, mergeSpec.Squash, dEnv, mergedRoot, mergeSpec.WorkingDiffs, mergeSpec.Cm2, tblToStats)
return tblToStats, mergedRootToWorking(ctx, spec.Squash, dEnv, mergedRoot, spec.WorkingDiffs, spec.MergeC, tblToStats)
}
// TODO: change this to be functional and not write to repo state

View File

@@ -40,8 +40,10 @@ type DoltMergeFunc struct {
const DoltConflictWarningCode int = 1105 // Since this our own custom warning we'll use 1105, the code for an unknown error
const hasConflicts = 0
const noConflicts = 1
const (
hasConflicts int = 0
noConflicts int = 1
)
func (d DoltMergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
dbName := ctx.GetCurrentDatabase()
@@ -51,18 +53,6 @@ func (d DoltMergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
}
sess := dsess.DSessFromSess(ctx.Session)
dbData, ok := sess.GetDbData(ctx, dbName)
if !ok {
return noConflicts, fmt.Errorf("Could not load database %s", dbName)
}
dbState, ok, err := sess.LookupDbState(ctx, dbName)
if err != nil {
return noConflicts, err
} else if !ok {
return noConflicts, fmt.Errorf("Could not load database %s", dbName)
}
ap := cli.CreateMergeArgParser()
args, err := getDoltArgs(ctx, row, d.Children())
@@ -85,11 +75,8 @@ func (d DoltMergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
return noConflicts, err
}
roots, ok := sess.GetRoots(ctx, dbName)
// logrus.Errorf("heads are working: %s\nhead: %s", roots.Working.DebugString(ctx, true), roots.Head.DebugString(ctx, true))
if !ok {
return noConflicts, fmt.Errorf("Could not load database %s", dbName)
return noConflicts, sql.ErrDatabaseNotFound.New(dbName)
}
if apr.Contains(cli.AbortParam) {
@@ -110,108 +97,120 @@ func (d DoltMergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
return noConflicts, nil
}
ddb, ok := sess.GetDoltDB(ctx, dbName)
if !ok {
return noConflicts, sql.ErrDatabaseNotFound.New(dbName)
branchName := apr.Arg(0)
mergeSpec, err := createMergeSpec(ctx, sess, dbName, apr, branchName)
if err != nil {
return noConflicts, err
}
ws, conflicts, err := mergeIntoWorkingSet(ctx, sess, roots, ws, dbName, mergeSpec)
if err != nil {
return conflicts, err
}
if hasConflicts, err := roots.Working.HasConflicts(ctx); err != nil {
return noConflicts, err
} else if hasConflicts {
return noConflicts, doltdb.ErrUnresolvedConflicts
err = sess.SetWorkingSet(ctx, dbName, ws, nil)
if err != nil {
return conflicts, err
}
return conflicts, nil
}
// mergeIntoWorkingSet encapsulates server merge logic, switching between fast-forward, no fast-forward, merge commit,
// and merging into working set. Returns a new WorkingSet and whether there were merge conflicts. This currently
// persists merge commits in the database, but expects the caller to update the working set.
func mergeIntoWorkingSet(ctx *sql.Context, sess *dsess.Session, roots doltdb.Roots, ws *doltdb.WorkingSet, dbName string, spec *merge.MergeSpec) (*doltdb.WorkingSet, int, error) {
if conflicts, err := roots.Working.HasConflicts(ctx); err != nil {
return ws, noConflicts, err
} else if conflicts {
return ws, hasConflicts, doltdb.ErrUnresolvedConflicts
}
if hasConstraintViolations, err := roots.Working.HasConstraintViolations(ctx); err != nil {
return noConflicts, err
return ws, hasConflicts, err
} else if hasConstraintViolations {
return noConflicts, doltdb.ErrUnresolvedConstraintViolations
return ws, hasConflicts, doltdb.ErrUnresolvedConstraintViolations
}
if ws.MergeActive() {
return noConflicts, doltdb.ErrMergeActive
return ws, noConflicts, doltdb.ErrMergeActive
}
err = checkForUncommittedChanges(roots.Working, roots.Head)
err := checkForUncommittedChanges(roots.Working, roots.Head)
if err != nil {
return noConflicts, err
return ws, noConflicts, err
}
branchName := apr.Arg(0)
mergeCommit, _, err := getBranchCommit(ctx, branchName, ddb)
canFF, err := spec.HeadC.CanFastForwardTo(ctx, spec.MergeC)
if err != nil {
return noConflicts, err
}
headCommit, err := sess.GetHeadCommit(ctx, dbName)
if err != nil {
return noConflicts, err
}
canFF, err := headCommit.CanFastForwardTo(ctx, mergeCommit)
if err != nil {
return noConflicts, err
return ws, noConflicts, err
}
if canFF {
headRoot, err := headCommit.GetRootValue()
headRoot, err := spec.HeadC.GetRootValue()
if err != nil {
return noConflicts, err
return ws, noConflicts, err
}
mergeRoot, err := mergeCommit.GetRootValue()
mergeRoot, err := spec.MergeC.GetRootValue()
if err != nil {
return noConflicts, err
return ws, noConflicts, err
}
if cvPossible, err := merge.MayHaveConstraintViolations(ctx, headRoot, mergeRoot); err != nil {
return noConflicts, err
return ws, noConflicts, err
} else if !cvPossible {
if apr.Contains(cli.NoFFParam) {
ws, err = executeNoFFMerge(ctx, sess, apr, dbName, ws, dbData, headCommit, mergeCommit)
dbData, ok := sess.GetDbData(ctx, dbName)
if !ok {
return ws, noConflicts, fmt.Errorf("could not load database %s", dbName)
}
if spec.Noff {
ws, err = executeNoFFMerge(ctx, sess, spec, dbName, ws, dbData)
if err == doltdb.ErrUnresolvedConflicts {
// if there are unresolved conflicts, write the resulting working set back to the session and return an
// error message
wsErr := sess.SetWorkingSet(ctx, dbName, ws, nil)
if wsErr != nil {
return hasConflicts, wsErr
return ws, hasConflicts, wsErr
}
ctx.Warn(DoltConflictWarningCode, err.Error())
// Return 0 indicating there are conflicts
return hasConflicts, nil
return ws, hasConflicts, nil
}
} else {
err = executeFFMerge(ctx, sess, apr.Contains(cli.SquashParam), dbName, ws, dbData, mergeCommit)
ws, err = executeFFMerge(ctx, spec.Squash, ws, dbData, spec.MergeC)
}
if err != nil {
return noConflicts, err
return ws, noConflicts, err
}
return noConflicts, err
return ws, noConflicts, err
}
}
ws, err = executeMerge(ctx, apr.Contains(cli.SquashParam), headCommit, mergeCommit, ws, dbState.EditSession.Opts)
dbState, ok, err := sess.LookupDbState(ctx, dbName)
if err != nil {
return ws, noConflicts, err
} else if !ok {
return ws, noConflicts, fmt.Errorf("could not load database %s", dbName)
}
ws, err = executeMerge(ctx, spec.Squash, spec.HeadC, spec.MergeC, ws, dbState.EditSession.Opts)
if err == doltdb.ErrUnresolvedConflicts {
// if there are unresolved conflicts, write the resulting working set back to the session and return an
// error message
wsErr := sess.SetWorkingSet(ctx, dbName, ws, nil)
if wsErr != nil {
return hasConflicts, wsErr
return ws, hasConflicts, wsErr
}
ctx.Warn(DoltConflictWarningCode, err.Error())
return hasConflicts, nil
return ws, hasConflicts, nil
} else if err != nil {
return noConflicts, err
return ws, noConflicts, err
}
err = sess.SetWorkingSet(ctx, dbName, ws, nil)
if err != nil {
return noConflicts, err
}
return noConflicts, nil
return ws, noConflicts, nil
}
func abortMerge(ctx *sql.Context, workingSet *doltdb.WorkingSet, roots doltdb.Roots) (*doltdb.WorkingSet, error) {
@@ -247,10 +246,10 @@ func executeMerge(ctx *sql.Context, squash bool, head, cm *doltdb.Commit, ws *do
return mergeRootToWorking(squash, ws, mergeRoot, cm, mergeStats)
}
func executeFFMerge(ctx *sql.Context, sess *dsess.Session, squash bool, dbName string, ws *doltdb.WorkingSet, dbData env.DbData, cm2 *doltdb.Commit) error {
func executeFFMerge(ctx *sql.Context, squash bool, ws *doltdb.WorkingSet, dbData env.DbData, cm2 *doltdb.Commit) (*doltdb.WorkingSet, error) {
rv, err := cm2.GetRootValue()
if err != nil {
return err
return ws, err
}
// TODO: This is all incredibly suspect, needs to be replaced with library code that is functional instead of
@@ -258,71 +257,33 @@ func executeFFMerge(ctx *sql.Context, sess *dsess.Session, squash bool, dbName s
if !squash {
err = dbData.Ddb.FastForward(ctx, dbData.Rsr.CWBHeadRef(), cm2)
if err != nil {
return err
return ws, err
}
}
ws = ws.WithWorkingRoot(rv).WithStagedRoot(rv)
return sess.SetWorkingSet(ctx, dbName, ws, nil)
return ws.WithWorkingRoot(rv).WithStagedRoot(rv), nil
}
func executeNoFFMerge(
ctx *sql.Context,
dSess *dsess.Session,
apr *argparser.ArgParseResults,
spec *merge.MergeSpec,
dbName string,
ws *doltdb.WorkingSet,
dbData env.DbData,
headCommit, mergeCommit *doltdb.Commit,
//headCommit, mergeCommit *doltdb.Commit,
) (*doltdb.WorkingSet, error) {
mergeRoot, err := mergeCommit.GetRootValue()
mergeRoot, err := spec.MergeC.GetRootValue()
if err != nil {
return nil, err
}
ws, err = mergeRootToWorking(false, ws, mergeRoot, mergeCommit, map[string]*merge.MergeStats{})
ws, err = mergeRootToWorking(false, ws, mergeRoot, spec.MergeC, map[string]*merge.MergeStats{})
if err != nil {
// This error is recoverable, so we return a working set value along with the error
return ws, err
}
msg, msgOk := apr.GetValue(cli.CommitMessageArg)
if !msgOk {
hh, err := headCommit.HashOf()
if err != nil {
return nil, err
}
cmh, err := mergeCommit.HashOf()
if err != nil {
return nil, err
}
msg = fmt.Sprintf("SQL Generated commit merging %s into %s", hh.String(), cmh.String())
}
// TODO: refactor, redundant
var name, email string
if authorStr, ok := apr.GetValue(cli.AuthorParam); ok {
name, email, err = cli.ParseAuthor(authorStr)
if err != nil {
return nil, err
}
} else {
name = dSess.Username
email = dSess.Email
}
t := ctx.QueryTime()
if commitTimeStr, ok := apr.GetValue(cli.DateParam); ok {
var err error
t, err = cli.ParseDate(commitTimeStr)
if err != nil {
return nil, err
}
}
// Save our work so far in the session, as it will be referenced by the commit call below (badly in need of a
// refactoring)
err = dSess.SetWorkingSet(ctx, dbName, ws, nil)
@@ -341,12 +302,12 @@ func executeNoFFMerge(
// TODO: this does several session state updates, and it really needs to just do one
// We also need to commit any pending transaction before we do this.
_, err = actions.CommitStaged(ctx, roots, ws.MergeActive(), mergeParentCommits, dbData, actions.CommitStagedProps{
Message: msg,
Date: t,
AllowEmpty: apr.Contains(cli.AllowEmptyFlag),
Force: apr.Contains(cli.ForceFlag),
Name: name,
Email: email,
Message: spec.Msg,
Date: spec.Date,
AllowEmpty: spec.AllowEmpty,
Force: spec.Force,
Name: spec.Name,
Email: spec.Email,
})
if err != nil {
return nil, err
@@ -355,6 +316,49 @@ func executeNoFFMerge(
return ws, dSess.SetWorkingSet(ctx, dbName, ws.ClearMerge(), nil)
}
// createMergeSpec builds a merge.MergeSpec describing a merge of commitSpecStr
// into the current head of dbName. The commit message, author, and date come
// from the parsed arguments when present, falling back to a canned message,
// the session identity, and the query time respectively.
func createMergeSpec(ctx *sql.Context, sess *dsess.Session, dbName string, apr *argparser.ArgParseResults, commitSpecStr string) (*merge.MergeSpec, error) {
	// Bug fix: the `ok` results were previously discarded, so a missing
	// database would have produced a nil-pointer panic downstream.
	ddb, ok := sess.GetDoltDB(ctx, dbName)
	if !ok {
		return nil, sql.ErrDatabaseNotFound.New(dbName)
	}

	dbData, ok := sess.GetDbData(ctx, dbName)
	if !ok {
		return nil, sql.ErrDatabaseNotFound.New(dbName)
	}

	msg, ok := apr.GetValue(cli.CommitMessageArg)
	if !ok {
		// TODO probably change, but we can't open editor so it'll have to be automated
		msg = "automatic SQL merge"
	}

	var err error
	var name, email string
	if authorStr, ok := apr.GetValue(cli.AuthorParam); ok {
		name, email, err = cli.ParseAuthor(authorStr)
		if err != nil {
			return nil, err
		}
	} else {
		name = sess.Username
		email = sess.Email
	}

	t := ctx.QueryTime()
	if commitTimeStr, ok := apr.GetValue(cli.DateParam); ok {
		t, err = cli.ParseDate(commitTimeStr)
		if err != nil {
			return nil, err
		}
	}

	roots, ok := sess.GetRoots(ctx, dbName)
	if !ok {
		return nil, sql.ErrDatabaseNotFound.New(dbName)
	}

	mergeSpec, _, err := merge.NewMergeSpec(ctx, dbData.Rsr, ddb, roots, name, email, msg, commitSpecStr, apr.Contains(cli.SquashParam), apr.Contains(cli.NoFFParam), apr.Contains(cli.ForceFlag), t)
	if err != nil {
		return nil, err
	}
	return mergeSpec, nil
}
// TODO: this copied from commands/merge.go because the latter isn't reusable. Fix that.
func mergeRootToWorking(
squash bool,

View File

@@ -0,0 +1,213 @@
// Copyright 2021 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dfunctions
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/store/datas"
)
// DoltPullFuncName is the SQL name under which DoltPullFunc is registered.
const DoltPullFuncName = "dolt_pull"

// DoltPullFunc is the SQL function `dolt_pull(...)`: it fetches from a remote
// and merges the fetched branch into the current working set, mirroring the
// `dolt pull` CLI command. Arguments are passed through as child expressions.
type DoltPullFunc struct {
	expression.NaryExpression
}
// NewPullFunc creates a new PullFunc expression wrapping args as its children.
// The ctx parameter is unused and the error is always nil; both exist to
// satisfy the function-constructor signature expected by the SQL engine.
func NewPullFunc(ctx *sql.Context, args ...sql.Expression) (sql.Expression, error) {
	return &DoltPullFunc{expression.NaryExpression{ChildExpressions: args}}, nil
}
// String implements fmt.Stringer, rendering the call as
// DOLT_PULL(arg1,arg2,...).
func (d DoltPullFunc) String() string {
	args := make([]string, 0, len(d.Children()))
	for _, expr := range d.Children() {
		args = append(args, expr.String())
	}
	return "DOLT_PULL(" + strings.Join(args, ",") + ")"
}
// Type implements the sql.Expression interface, declaring the function's SQL
// result type as boolean.
func (d DoltPullFunc) Type() sql.Type {
	return sql.Boolean
}
// WithChildren implements the sql.Expression interface, returning a fresh
// DoltPullFunc over the given child expressions.
func (d DoltPullFunc) WithChildren(ctx *sql.Context, children ...sql.Expression) (sql.Expression, error) {
	return NewPullFunc(ctx, children...)
}
// Eval implements the sql.Expression interface. It pulls for the current
// database: resolves the remote named by the (at most one) argument, fetches
// the current branch, fast-forwards the remote tracking ref, merges it into
// the working set, and finally fetches tags reachable from fetched heads.
// Returns noConflicts on success; merge conflicts surface through the merge
// result written back to the session.
func (d DoltPullFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
	dbName := ctx.GetCurrentDatabase()
	if len(dbName) == 0 {
		return noConflicts, fmt.Errorf("empty database name")
	}

	sess := dsess.DSessFromSess(ctx.Session)
	dbData, ok := sess.GetDbData(ctx, dbName)
	if !ok {
		return noConflicts, sql.ErrDatabaseNotFound.New(dbName)
	}

	ap := cli.CreatePullArgParser()
	args, err := getDoltArgs(ctx, row, d.Children())
	// Bug fix: this error was previously overwritten by ap.Parse without being checked.
	if err != nil {
		return noConflicts, err
	}

	apr, err := ap.Parse(args)
	if err != nil {
		return noConflicts, err
	}

	if apr.NArg() > 1 {
		return noConflicts, actions.ErrInvalidPullArgs
	}

	var remoteName string
	if apr.NArg() == 1 {
		remoteName = apr.Arg(0)
	}

	pullSpec, err := env.NewPullSpec(ctx, dbData.Rsr, remoteName, apr.Contains(cli.SquashParam), apr.Contains(cli.NoFFParam), apr.Contains(cli.ForceFlag))
	if err != nil {
		return noConflicts, err
	}

	srcDB, err := pullSpec.Remote.GetRemoteDBWithoutCaching(ctx, dbData.Ddb.ValueReadWriter().Format())
	if err != nil {
		return noConflicts, fmt.Errorf("failed to get remote db; %w", err)
	}

	ws, err := sess.WorkingSet(ctx, dbName)
	if err != nil {
		return noConflicts, err
	}

	var conflicts interface{}
	for _, refSpec := range pullSpec.RefSpecs {
		remoteTrackRef := refSpec.DestRef(pullSpec.Branch)
		if remoteTrackRef == nil {
			continue
		}

		// todo: can we pass nil for either of the channels?
		srcDBCommit, err := actions.FetchRemoteBranch(ctx, dbData.Rsw.TempTableFilesDir(), pullSpec.Remote, srcDB, dbData.Ddb, pullSpec.Branch, remoteTrackRef, runProgFuncs, stopProgFuncs)
		if err != nil {
			return noConflicts, err
		}

		// TODO: this could be replaced with a canFF check to test for error
		err = dbData.Ddb.FastForward(ctx, remoteTrackRef, srcDBCommit)
		if err != nil {
			return noConflicts, fmt.Errorf("fetch failed; %w", err)
		}

		roots, ok := sess.GetRoots(ctx, dbName)
		if !ok {
			return noConflicts, sql.ErrDatabaseNotFound.New(dbName)
		}

		mergeSpec, err := createMergeSpec(ctx, sess, dbName, apr, remoteTrackRef.String())
		if err != nil {
			return noConflicts, err
		}

		ws, conflicts, err = mergeIntoWorkingSet(ctx, sess, roots, ws, dbName, mergeSpec)
		// Bug fix: errors.Is takes (err, target); the arguments were reversed,
		// so an up-to-date pull would have returned an error instead of succeeding.
		if err != nil && !errors.Is(err, doltdb.ErrUpToDate) {
			return conflicts, err
		}

		err = sess.SetWorkingSet(ctx, dbName, ws, nil)
		if err != nil {
			return conflicts, err
		}
	}

	// Fetch tags pointing at commits we now have locally.
	err = actions.FetchFollowTags(ctx, dbData.Rsw.TempTableFilesDir(), srcDB, dbData.Ddb, runProgFuncs, stopProgFuncs)
	if err != nil {
		return noConflicts, err
	}

	return noConflicts, nil
}
// pullerProgFunc discards puller events until ctx is cancelled. The SQL server
// has no terminal to render progress on, so events are simply drained to keep
// senders from blocking.
//
// Fix: the previous version used a `default` case in the receive select,
// turning this goroutine into a busy-wait that spun a CPU core for the whole
// duration of the fetch. Blocking on the two cases yields the same observable
// behavior (drain until cancel) without the spin.
func pullerProgFunc(ctx context.Context, pullerEventCh <-chan datas.PullerEvent) {
	for {
		select {
		case <-ctx.Done():
			return
		case <-pullerEventCh:
		}
	}
}
// progFunc discards pull progress updates until ctx is cancelled; progress is
// not rendered in SQL contexts, so the channel only needs to be drained.
//
// Fix: the previous version polled with a `default` case, busy-spinning a CPU
// core for the duration of the pull. A blocking select drains identically
// while sleeping between events.
func progFunc(ctx context.Context, progChan <-chan datas.PullProgress) {
	for {
		select {
		case <-ctx.Done():
			return
		case <-progChan:
		}
	}
}
// runProgFuncs starts the two progress-draining goroutines used during a pull
// and returns the WaitGroup that tracks them along with the channels they
// consume. Callers shut everything down via stopProgFuncs.
func runProgFuncs(ctx context.Context) (*sync.WaitGroup, chan datas.PullProgress, chan datas.PullerEvent) {
	progChan := make(chan datas.PullProgress)
	pullerEventCh := make(chan datas.PullerEvent)

	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		progFunc(ctx, progChan)
	}()
	go func() {
		defer wg.Done()
		pullerProgFunc(ctx, pullerEventCh)
	}()

	return &wg, progChan, pullerEventCh
}
// stopProgFuncs tears down the goroutines started by runProgFuncs. The context
// is cancelled first so the drain loops observe Done, then both channels are
// closed to unblock any pending receives, and finally we wait for the
// goroutines to exit before returning.
func stopProgFuncs(cancel context.CancelFunc, wg *sync.WaitGroup, progChan chan datas.PullProgress, pullerEventCh chan datas.PullerEvent) {
	cancel()
	close(progChan)
	close(pullerEventCh)
	wg.Wait()
}

View File

@@ -33,6 +33,7 @@ var DoltFunctions = []sql.Function{
sql.FunctionN{Name: ConstraintsVerifyFuncName, Fn: NewConstraintsVerifyFunc},
sql.FunctionN{Name: ConstraintsVerifyAllFuncName, Fn: NewConstraintsVerifyAllFunc},
sql.FunctionN{Name: RevertFuncName, Fn: NewRevertFunc},
sql.FunctionN{Name: DoltPullFuncName, Fn: NewPullFunc},
}
// These are the DoltFunctions that get exposed to Dolthub Api.

View File

@@ -17,12 +17,15 @@ package dsess
import (
"context"
"fmt"
"path/filepath"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
)
// SessionStateAdapter is an adapter for env.RepoStateReader in SQL contexts, getting information about the repo state
@@ -124,3 +127,16 @@ func (s SessionStateAdapter) AddRemote(name string, url string, fetchSpecs []str
// RemoveRemote always returns an error: remote configuration is read-only
// inside a SQL session and must be changed outside the server.
func (s SessionStateAdapter) RemoveRemote(ctx context.Context, name string) error {
	return fmt.Errorf("cannot delete remote in an SQL session")
}
// TempTableFilesDir returns the absolute directory used for temporary table
// files during SQL sessions. It is a fixed path under the dolt directory,
// shared across databases rather than tracked per session state.
func (s SessionStateAdapter) TempTableFilesDir() string {
	//todo: save tempfile in dbState on server startup?
	return mustAbs(dbfactory.DoltDir, "temptf")
}
// mustAbs joins the given path elements and resolves the result to an absolute
// path on the local filesystem, panicking if resolution fails (Must-style
// helper; resolution errors here indicate a broken environment).
func mustAbs(path ...string) string {
	joined := filepath.Join(path...)
	abs, err := filesys.LocalFS.Abs(joined)
	if err != nil {
		panic(err)
	}
	return abs
}

View File

@@ -1110,4 +1110,4 @@ setup_ref_test() {
run dolt fetch dadasdfasdfa
[ "$status" -eq 1 ]
[[ "$output" =~ "error: dadasdfasdfa does not appear to be a dolt database" ]] || false
}
}

View File

@@ -0,0 +1,229 @@
#!/usr/bin/env bats
load $BATS_TEST_DIRNAME/helper/common.bash
# Builds two working clones that share a file-based remote:
#   tmp1 -> rem1 -> tmp2
# tmp1 creates a table and two commits and pushes them to rem1 AFTER tmp2 was
# cloned, so tmp2 starts behind and each test exercises pulling that delta.
setup() {
    setup_common
    TMPDIRS=$(pwd)/tmpdirs
    mkdir -p $TMPDIRS/{rem1,tmp1}

    # tmp1 -> rem1 -> tmp2
    cd $TMPDIRS/tmp1
    dolt init
    dolt branch feature
    dolt remote add origin file://../rem1
    dolt remote add test-remote file://../rem1
    dolt push origin master

    cd $TMPDIRS
    dolt clone file://rem1 tmp2
    cd $TMPDIRS/tmp2
    dolt log
    dolt branch feature
    dolt remote add test-remote file://../rem1

    # table and commits only present on tmp1, rem1 at start
    cd $TMPDIRS/tmp1
    dolt sql -q "create table t1 (a int primary key, b int)"
    dolt commit -am "First commit"
    dolt sql -q "insert into t1 values (0,0)"
    dolt commit -am "Second commit"
    dolt push origin master
    cd $TMPDIRS
}
# Removes the per-test directory tree created by setup.
# Fix: leave the tree before deleting it — the original ran `rm -rf $TMPDIRS`
# while the cwd could still be inside it, then cd'd afterwards. Quoting guards
# against whitespace in $BATS_TMPDIR-derived paths.
teardown() {
    teardown_common
    cd "$BATS_TMPDIR"
    rm -rf "$TMPDIRS"
}
# Pull from an explicitly named remote: tmp2 picks up the table pushed by tmp1.
@test "sql-pull: dolt_pull master" {
    cd tmp2
    dolt sql -q "select dolt_pull('origin')"

    run dolt sql -q "show tables" -r csv
    [ "$status" -eq 0 ]
    [ "${#lines[@]}" -eq 2 ]
    [[ "$output" =~ "Table" ]] || false
    [[ "$output" =~ "t1" ]] || false
}

# Same pull through the second configured remote name.
@test "sql-pull: dolt_pull custom remote" {
    cd tmp2
    dolt sql -q "select dolt_pull('test-remote')"

    run dolt sql -q "show tables" -r csv
    [ "$status" -eq 0 ]
    [ "${#lines[@]}" -eq 2 ]
    [[ "$output" =~ "Table" ]] || false
    [[ "$output" =~ "t1" ]] || false
}

# With no argument and a single remote left, dolt_pull defaults to origin.
@test "sql-pull: dolt_pull default origin" {
    cd tmp2
    dolt remote remove test-remote
    dolt sql -q "select dolt_pull()"

    run dolt sql -q "show tables" -r csv
    [ "$status" -eq 0 ]
    [ "${#lines[@]}" -eq 2 ]
    [[ "$output" =~ "Table" ]] || false
    [[ "$output" =~ "t1" ]] || false
}

# With origin removed, the remaining custom remote becomes the default.
@test "sql-pull: dolt_pull default custom remote" {
    cd tmp2
    dolt remote remove origin
    dolt sql -q "select dolt_pull()"

    run dolt sql -q "show tables" -r csv
    [ "$status" -eq 0 ]
    [ "${#lines[@]}" -eq 2 ]
    [[ "$output" =~ "Table" ]] || false
    [[ "$output" =~ "t1" ]] || false
}

# A second pull with nothing new to fetch must succeed quietly.
@test "sql-pull: dolt_pull up to date does not error" {
    cd tmp2
    dolt sql -q "select dolt_pull('origin')"
    dolt sql -q "select dolt_pull('origin')"

    run dolt sql -q "show tables" -r csv
    [ "$status" -eq 0 ]
    [ "${#lines[@]}" -eq 2 ]
    [[ "$output" =~ "Table" ]] || false
    [[ "$output" =~ "t1" ]] || false
}
# An unconfigured remote name must produce a clean error, never a panic.
@test "sql-pull: dolt_pull unknown remote fails" {
    cd tmp2
    run dolt sql -q "select dolt_pull('unknown')"
    [ "$status" -eq 1 ]
    [[ "$output" =~ "unknown remote" ]] || false
    [[ ! "$output" =~ "panic" ]] || false
}

# Pulling while on a branch that does not exist on the remote fails cleanly.
@test "sql-pull: dolt_pull unknown feature branch fails" {
    cd tmp2
    dolt checkout feature
    run dolt sql -q "select dolt_pull('origin')"
    [ "$status" -eq 1 ]
    [[ "$output" =~ "branch not found" ]] || false
    [[ ! "$output" =~ "panic" ]] || false
}
# Push the feature branch from tmp1, then pull it into tmp2's feature checkout.
@test "sql-pull: dolt_pull feature branch" {
    cd tmp1
    dolt checkout feature
    dolt merge master
    dolt push origin feature

    cd ../tmp2
    dolt checkout feature
    dolt sql -q "select dolt_pull('origin')"
    run dolt sql -q "show tables" -r csv
    [ "$status" -eq 0 ]
    [ "${#lines[@]}" -eq 2 ]
    [[ "$output" =~ "Table" ]] || false
    [[ "$output" =~ "t1" ]] || false
}

# Diverged histories: a plain pull must fail; '-f' should adopt the new head.
@test "sql-pull: dolt_pull force" {
    skip "todo: support dolt pull --force (cli too)"
    cd tmp2
    dolt sql -q "create table t2 (a int)"
    dolt commit -am "2.0 commit"
    dolt push origin master

    cd ../tmp1
    dolt sql -q "create table t2 (a int primary key)"
    dolt sql -q "create table t3 (a int primary key)"
    dolt commit -am "2.1 commit"
    dolt push -f origin master

    cd ../tmp2
    run dolt sql -q "select dolt_pull('origin')"
    [ "$status" -eq 1 ]
    [[ ! "$output" =~ "panic" ]] || false
    [[ "$output" =~ "fetch failed; dataset head is not ancestor of commit" ]] || false

    dolt sql -q "select dolt_pull('-f', 'origin')"
    run dolt log -n 1
    [ "$status" -eq 0 ]
    [[ "$output" =~ "2.1 commit" ]] || false

    run dolt sql -q "show tables" -r csv
    [ "${#lines[@]}" -eq 4 ]
    [[ "$output" =~ "t3" ]] || false
}
# --squash is not implemented yet; skipped placeholder coverage.
@test "sql-pull: dolt_pull squash" {
    skip "todo: support dolt pull --squash (cli too)"
    cd tmp2
    dolt sql -q "select dolt_pull('--squash', 'origin')"

    run dolt sql -q "show tables" -r csv
    [ "$status" -eq 0 ]
    [ "${#lines[@]}" -eq 2 ]
    [[ "$output" =~ "Table" ]] || false
    [[ "$output" =~ "t1" ]] || false
}

# --no-ff should create a merge commit carrying the generated default message.
@test "sql-pull: dolt_pull --noff flag" {
    cd tmp2
    dolt sql -q "select dolt_pull('--no-ff', 'origin')"
    dolt status
    run dolt log -n 1
    [ "$status" -eq 0 ]
    # TODO change the default message name
    [[ "$output" =~ "automatic SQL merge" ]] || false

    run dolt sql -q "show tables" -r csv
    [ "$status" -eq 0 ]
    [ "${#lines[@]}" -eq 2 ]
    [[ "$output" =~ "Table" ]] || false
    [[ "$output" =~ "t1" ]] || false
}

# Regression guard: an empty remote-name argument must not panic the server.
@test "sql-pull: empty remote name does not panic" {
    cd tmp2
    dolt sql -q "select dolt_pull('')"
}
# Pulling over uncommitted local changes is rejected.
@test "sql-pull: dolt_pull dirty working set fails" {
    cd tmp2
    dolt sql -q "create table t2 (a int)"
    run dolt sql -q "select dolt_pull('origin')"
    [ "$status" -eq 1 ]
    [[ "$output" =~ "cannot merge with uncommitted changes" ]] || false
}

# Tags pushed to the remote arrive alongside the pulled branch.
@test "sql-pull: dolt_pull tag" {
    cd tmp1
    dolt tag v1
    dolt push origin v1
    dolt tag

    cd ../tmp2
    dolt sql -q "select dolt_pull('origin')"
    run dolt tag
    [ "$status" -eq 0 ]
    [[ "$output" =~ "v1" ]] || false
}

# Only tags on commits reachable from the pulled branch are fetched:
# v1/v2 sit on master, v3 sits on the un-pulled feature branch.
@test "sql-pull: dolt_pull tags only for resolved commits" {
    cd tmp1
    dolt tag v1 head
    dolt tag v2 head^
    dolt push origin v1
    dolt push origin v2

    dolt checkout feature
    dolt sql -q "create table t2 (a int)"
    dolt commit -am "feature commit"
    dolt tag v3
    dolt push origin v3

    cd ../tmp2
    dolt sql -q "select dolt_pull('origin')"
    run dolt tag
    [ "$status" -eq 0 ]
    [[ "$output" =~ "v1" ]] || false
    [[ "$output" =~ "v2" ]] || false
    [[ ! "$output" =~ "v3" ]] || false
}