diff --git a/types/future.go b/types/future.go index 356eb7b98c..2a3aa56998 100644 --- a/types/future.go +++ b/types/future.go @@ -2,6 +2,7 @@ package types import ( "github.com/attic-labs/noms/chunks" + . "github.com/attic-labs/noms/dbg" "github.com/attic-labs/noms/ref" ) @@ -25,3 +26,12 @@ func futuresEqual(f1, f2 future) bool { return f1.Ref() == f2.Ref() } } + +func futureEqualsValue(f future, v Value) bool { + Chk.NotNil(v) + if f.Val() != nil { + return f.Val().Equals(v) + } else { + return f.Ref() == v.Ref() + } +} diff --git a/types/map.go b/types/map.go index daeff3b72d..d2f85705d1 100644 --- a/types/map.go +++ b/types/map.go @@ -30,14 +30,14 @@ func (fm Map) Len() uint64 { func (fm Map) Has(key Value) bool { idx := indexMapData(fm.m, key.Ref()) - return idx < len(fm.m) && fm.m[idx].key.Ref() == key.Ref() + return idx < len(fm.m) && futureEqualsValue(fm.m[idx].key, key) } func (fm Map) Get(key Value) Value { idx := indexMapData(fm.m, key.Ref()) if idx < len(fm.m) { entry := fm.m[idx] - if entry.key.Ref() == key.Ref() { + if futureEqualsValue(entry.key, key) { v, err := entry.value.Deref(fm.cs) Chk.NoError(err) return v @@ -56,7 +56,7 @@ func (fm Map) SetM(kv ...Value) Map { func (fm Map) Remove(k Value) Map { idx := indexMapData(fm.m, k.Ref()) - if idx == len(fm.m) || fm.m[idx].key.Ref() != k.Ref() { + if idx == len(fm.m) || !futureEqualsValue(fm.m[idx].key, k) { return fm } @@ -109,7 +109,7 @@ func buildMapData(oldData mapData, futures []future) mapData { v := futures[i+1] e := mapEntry{k, v} idx := indexMapData(m, k.Ref()) - if idx != len(m) && m[idx].key.Ref() == k.Ref() { + if idx != len(m) && futuresEqual(m[idx].key, k) { m[idx] = e } else { m = append(m, e) diff --git a/types/set.go b/types/set.go index 58d8b7da41..99f1b1342f 100644 --- a/types/set.go +++ b/types/set.go @@ -34,7 +34,7 @@ func (fs Set) Len() uint64 { func (fs Set) Has(v Value) bool { idx := indexSetData(fs.m, v.Ref()) - return idx < len(fs.m) && fs.m[idx].Ref() == v.Ref() + return idx < len(fs.m) && futureEqualsValue(fs.m[idx], v) } func (fs Set) Insert(values ...Value) Set { @@ -46,7 +46,7 @@ func (fs Set) Remove(values ...Value) Set { for _, v := range values { if v != nil { idx := indexSetData(fs.m, v.Ref()) - if idx < len(fs.m) && fs.m[idx].Ref() == v.Ref() { + if idx < len(fs.m) && futureEqualsValue(fs.m[idx], v) { m2 = append(m2[:idx], m2[idx+1:]...) } } @@ -125,7 +125,7 @@ func buildSetData(old setData, futures []future) setData { copy(r, old) for _, f := range futures { idx := indexSetData(r, f.Ref()) - if idx < len(r) && r[idx].Ref() == f.Ref() { + if idx < len(r) && futuresEqual(r[idx], f) { // We already have this fellow. continue } else {