find common ancestor using transitive ref closure

This commit is contained in:
Andy Arthur
2021-07-31 10:50:29 -07:00
parent 58bf08407a
commit d9b5509e3d
3 changed files with 199 additions and 35 deletions
+37
View File
@@ -133,6 +133,43 @@ func FindCommonAncestor(ctx context.Context, c1, c2 types.Ref, vr1, vr2 types.Va
return a, ok, nil
}
// todo comment doc
func FindClosureCommonAncestor(ctx context.Context, cl RefClosure, rf types.Ref, vr types.ValueReader) (a types.Ref, ok bool, err error) {
t, err := types.TypeOf(rf)
if err != nil {
return types.Ref{}, false, err
}
// precondition checks
if !IsRefOfCommitType(rf.Format(), t) {
d.Panic("reference is not a commit")
}
q := &RefByHeightHeap{rf}
var curr types.RefSlice
for !q.Empty() {
curr = q.PopRefsOfHeight(q.MaxHeight())
for _, r := range curr {
ok, err = cl.Contains(r)
if err != nil {
return types.Ref{}, false, err
}
if ok {
return r, ok, nil
}
}
err = parentsToQueue(ctx, curr, q, vr)
if err != nil {
return types.Ref{}, false, err
}
}
return types.Ref{}, false, nil
}
func parentsToQueue(ctx context.Context, refs types.RefSlice, q *RefByHeightHeap, vr types.ValueReader) error {
seen := make(map[hash.Hash]bool)
for _, r := range refs {
+65 -35
View File
@@ -23,6 +23,7 @@ package datas
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
@@ -241,6 +242,59 @@ func toRefList(vrw types.ValueReadWriter, commits ...types.Struct) (types.List,
return le.List(context.Background())
}
func commonAncWithSetClosure(ctx context.Context, c1, c2 types.Ref, vr1, vr2 types.ValueReader) (a types.Ref, ok bool, err error) {
var closure RefClosure
closure, err = NewSetRefClosure(ctx, vr1, c1)
if err != nil {
return types.Ref{}, false, err
}
return FindClosureCommonAncestor(ctx, closure, c2, vr2)
}
func commonAncWithLazyClosure(ctx context.Context, c1, c2 types.Ref, vr1, vr2 types.ValueReader) (a types.Ref, ok bool, err error) {
closure := NewLazyRefClousure(ctx, c1, vr1)
return FindClosureCommonAncestor(ctx, closure, c2, vr2)
}
// Assert that c is the common ancestor of a and b
func assertCommonAncestor(t *testing.T, expected, a, b types.Struct, ldb, rdb Database) {
assert := assert.New(t)
type caFinder func(ctx context.Context, c1, c2 types.Ref, vr1, vr2 types.ValueReader) (a types.Ref, ok bool, err error)
methods := map[string]caFinder{
"FindCommonAncestor": FindCommonAncestor,
"SetClosure": commonAncWithSetClosure,
}
for name, method := range methods {
tn := fmt.Sprintf("find common ancestor using %s", name)
t.Run(tn, func(t *testing.T) {
found, ok, err := method(context.Background(), mustRef(types.NewRef(a, types.Format_7_18)), mustRef(types.NewRef(b, types.Format_7_18)), ldb, rdb)
assert.NoError(err)
if assert.True(ok) {
tv, err := found.TargetValue(context.Background(), ldb)
assert.NoError(err)
ancestor := tv.(types.Struct)
expV, _, _ := expected.MaybeGet(ValueField)
aV, _, _ := a.MaybeGet(ValueField)
bV, _, _ := b.MaybeGet(ValueField)
ancV, _, _ := ancestor.MaybeGet(ValueField)
assert.True(
expected.Equals(ancestor),
"%s should be common ancestor of %s, %s. Got %s",
expV,
aV,
bV,
ancV,
)
}
})
}
}
func TestFindCommonAncestor(t *testing.T) {
assert := assert.New(t)
@@ -253,30 +307,6 @@ func TestFindCommonAncestor(t *testing.T) {
return mustHead(ds)
}
// Assert that c is the common ancestor of a and b
assertCommonAncestor := func(expected, a, b types.Struct, ldb, rdb Database) {
found, ok, err := FindCommonAncestor(context.Background(), mustRef(types.NewRef(a, types.Format_7_18)), mustRef(types.NewRef(b, types.Format_7_18)), ldb, rdb)
assert.NoError(err)
if assert.True(ok) {
tv, err := found.TargetValue(context.Background(), ldb)
assert.NoError(err)
ancestor := tv.(types.Struct)
expV, _, _ := expected.MaybeGet(ValueField)
aV, _, _ := a.MaybeGet(ValueField)
bV, _, _ := b.MaybeGet(ValueField)
ancV, _, _ := ancestor.MaybeGet(ValueField)
assert.True(
expected.Equals(ancestor),
"%s should be common ancestor of %s, %s. Got %s",
expV,
aV,
bV,
ancV,
)
}
}
storage := &chunks.TestStorage{}
db := NewDatabase(storage.NewView())
@@ -310,11 +340,11 @@ func TestFindCommonAncestor(t *testing.T) {
b5 := addCommit(db, b, "b5", b4, a3)
a6 := addCommit(db, a, "a6", a5, b5)
assertCommonAncestor(a1, a1, a1, db, db) // All self
assertCommonAncestor(a1, a1, a2, db, db) // One side self
assertCommonAncestor(a2, a3, b3, db, db) // Common parent
assertCommonAncestor(a2, a4, b4, db, db) // Common grandparent
assertCommonAncestor(a1, a6, c3, db, db) // Traversing multiple parents on both sides
assertCommonAncestor(t, a1, a1, a1, db, db) // All self
assertCommonAncestor(t, a1, a1, a2, db, db) // One side self
assertCommonAncestor(t, a2, a3, b3, db, db) // Common parent
assertCommonAncestor(t, a2, a4, b4, db, db) // Common grandparent
assertCommonAncestor(t, a1, a6, c3, db, db) // Traversing multiple parents on both sides
// No common ancestor
found, ok, err := FindCommonAncestor(context.Background(), mustRef(types.NewRef(d2, types.Format_7_18)), mustRef(types.NewRef(a6, types.Format_7_18)), db, db)
@@ -386,13 +416,13 @@ func TestFindCommonAncestor(t *testing.T) {
ra8 := addCommit(rdb, a, "ra8", ra7)
ra9 := addCommit(rdb, a, "ra9", ra8)
assertCommonAncestor(a1, a1, a1, db, rdb) // All self
assertCommonAncestor(a1, a1, a2, db, rdb) // One side self
assertCommonAncestor(a2, a3, b3, db, rdb) // Common parent
assertCommonAncestor(a2, a4, b4, db, rdb) // Common grandparent
assertCommonAncestor(a1, a6, c3, db, rdb) // Traversing multiple parents on both sides
assertCommonAncestor(t, a1, a1, a1, db, rdb) // All self
assertCommonAncestor(t, a1, a1, a2, db, rdb) // One side self
assertCommonAncestor(t, a2, a3, b3, db, rdb) // Common parent
assertCommonAncestor(t, a2, a4, b4, db, rdb) // Common grandparent
assertCommonAncestor(t, a1, a6, c3, db, rdb) // Traversing multiple parents on both sides
assertCommonAncestor(a6, a9, ra9, db, rdb) // Common third parent
assertCommonAncestor(t, a6, a9, ra9, db, rdb) // Common third parent
_, _, err = FindCommonAncestor(context.Background(), mustRef(types.NewRef(a9, types.Format_7_18)), mustRef(types.NewRef(ra9, types.Format_7_18)), rdb, db)
assert.Error(err)
+97
View File
@@ -0,0 +1,97 @@
// Copyright 2019 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datas
import (
"context"
"github.com/dolthub/dolt/go/store/hash"
"github.com/dolthub/dolt/go/store/types"
)
// todo comment doc
type RefClosure interface {
Contains(ref types.Ref) (bool, error)
}
// todo comment doc
func NewSetRefClosure(ctx context.Context, vr types.ValueReader, ref types.Ref) (RefClosure, error) {
s, err := transitiveClosure(ctx, vr, ref)
if err != nil {
return setRefClosure{}, err
}
return setRefClosure{HashSet: s}, nil
}
type setRefClosure struct {
hash.HashSet
}
var _ RefClosure = setRefClosure{}
func (s setRefClosure) Contains(ref types.Ref) (ok bool, err error) {
ok = s.HashSet.Has(ref.TargetHash())
return
}
func transitiveClosure(ctx context.Context, vr types.ValueReader, ref types.Ref) (s hash.HashSet, err error) {
h := &RefByHeightHeap{ref}
s = hash.NewHashSet()
var curr types.RefSlice
for !h.Empty() {
curr = h.PopRefsOfHeight(h.MaxHeight())
for _, r := range curr {
s.Insert(r.TargetHash())
}
err = parentsToQueue(ctx, curr, h, vr)
if err != nil {
return nil, err
}
}
return s, nil
}
func NewLazyRefClousure(ctx context.Context, ref types.Ref, vr types.ValueReader) lazyRefClosure {
return lazyRefClosure{
partial: hash.NewHashSet(ref.TargetHash()),
bottom: RefByHeightHeap{ref},
}
}
type lazyRefClosure struct {
partial hash.HashSet
bottom []types.Ref
depth uint64
}
var _ RefClosure = lazyRefClosure{}
func (l lazyRefClosure) Contains(ref types.Ref) (ok bool, err error) {
if ref.Height() < l.depth {
err = traverseToDepth(ref.Height(), l.bottom, l.partial)
}
if err != nil {
return false, err
}
return l.partial.Has(ref.TargetHash()), nil
}
func traverseToDepth(depth uint64, roots []types.Ref, visited hash.HashSet) error {
panic("todo")
}