mirror of
https://github.com/dolthub/dolt.git
synced 2026-05-21 03:24:13 -05:00
Make sure we flatten unions before calling simplify types (#3307)
We ran into an issue where we passed non flatten types to simplify types which lead to us having two `struct Commit` in a union.
This commit is contained in:
+54
-43
@@ -1,6 +1,11 @@
|
||||
package types
|
||||
|
||||
import "github.com/attic-labs/noms/go/d"
|
||||
import (
|
||||
"sort"
|
||||
|
||||
"github.com/attic-labs/noms/go/d"
|
||||
"github.com/attic-labs/noms/go/hash"
|
||||
)
|
||||
|
||||
// makeSimplifiedType returns a type that is a supertype of all the input types but is much
|
||||
// smaller and less complex than a straight union of all those types would be.
|
||||
@@ -32,22 +37,21 @@ import "github.com/attic-labs/noms/go/d"
|
||||
//
|
||||
// Anytime any of the above cases generates a union as output, the same process
|
||||
// is applied to that union recursively.
|
||||
func makeSimplifiedType(intersectStructs bool, in ...*Type) *Type {
|
||||
func (tc *TypeCache) makeSimplifiedType(intersectStructs bool, in ...*Type) *Type {
|
||||
ts := make(typeset, len(in))
|
||||
for _, t := range in {
|
||||
// De-cycle so that we handle cycles explicitly below. Otherwise, we would implicitly crawl
|
||||
// cycles and recurse forever.
|
||||
ut := ToUnresolvedType(t)
|
||||
ts[ut] = struct{}{}
|
||||
ts.Add(ToUnresolvedType(t))
|
||||
}
|
||||
|
||||
// Impl de-cycles internally.
|
||||
return makeSimplifiedTypeImpl(ts, intersectStructs)
|
||||
return tc.makeSimplifiedTypeImpl(ts, intersectStructs)
|
||||
}
|
||||
|
||||
// typeset is a helper that aggregates the unique set of input types for this algorithm, flattening
|
||||
// any unions recursively.
|
||||
type typeset map[*Type]struct{}
|
||||
type typeset map[hash.Hash]*Type
|
||||
|
||||
func (ts typeset) Add(t *Type) {
|
||||
switch t.Kind() {
|
||||
@@ -56,7 +60,7 @@ func (ts typeset) Add(t *Type) {
|
||||
ts.Add(et)
|
||||
}
|
||||
default:
|
||||
ts[t] = struct{}{}
|
||||
ts[t.Hash()] = t
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,15 +75,15 @@ func newTypeset(t ...*Type) typeset {
|
||||
// makeSimplifiedTypeImpl is an implementation detail.
|
||||
// Warning: Do not call this directly. It assumes its input has been de-cycled using
|
||||
// ToUnresolvedType() and will infinitely recurse on cyclic types otherwise.
|
||||
func makeSimplifiedTypeImpl(in typeset, intersectStructs bool) *Type {
|
||||
func (tc *TypeCache) makeSimplifiedTypeImpl(in typeset, intersectStructs bool) *Type {
|
||||
type how struct {
|
||||
k NomsKind
|
||||
n string
|
||||
}
|
||||
|
||||
out := make([]*Type, 0, len(in))
|
||||
out := make(typeSlice, 0, len(in))
|
||||
groups := map[how]typeset{}
|
||||
for t := range in {
|
||||
for _, t := range in {
|
||||
var h how
|
||||
switch t.Kind() {
|
||||
case RefKind, SetKind, ListKind, MapKind:
|
||||
@@ -100,7 +104,7 @@ func makeSimplifiedTypeImpl(in typeset, intersectStructs bool) *Type {
|
||||
|
||||
for h, ts := range groups {
|
||||
if len(ts) == 1 {
|
||||
for t := range ts {
|
||||
for _, t := range ts {
|
||||
out = append(out, t)
|
||||
}
|
||||
continue
|
||||
@@ -109,15 +113,15 @@ func makeSimplifiedTypeImpl(in typeset, intersectStructs bool) *Type {
|
||||
var r *Type
|
||||
switch h.k {
|
||||
case RefKind:
|
||||
r = simplifyRefs(ts, intersectStructs)
|
||||
r = tc.simplifyContainers(h.k, ts, intersectStructs)
|
||||
case SetKind:
|
||||
r = simplifySets(ts, intersectStructs)
|
||||
r = tc.simplifyContainers(h.k, ts, intersectStructs)
|
||||
case ListKind:
|
||||
r = simplifyLists(ts, intersectStructs)
|
||||
r = tc.simplifyContainers(h.k, ts, intersectStructs)
|
||||
case MapKind:
|
||||
r = simplifyMaps(ts, intersectStructs)
|
||||
r = tc.simplifyMaps(ts, intersectStructs)
|
||||
case StructKind:
|
||||
r = simplifyStructs(h.n, ts, intersectStructs)
|
||||
r = tc.simplifyStructs(h.n, ts, intersectStructs)
|
||||
}
|
||||
out = append(out, r)
|
||||
}
|
||||
@@ -131,47 +135,49 @@ func makeSimplifiedTypeImpl(in typeset, intersectStructs bool) *Type {
|
||||
return out[0]
|
||||
}
|
||||
|
||||
staticTypeCache.Lock()
|
||||
defer staticTypeCache.Unlock()
|
||||
return staticTypeCache.makeUnionType(out...)
|
||||
sort.Sort(out)
|
||||
|
||||
tc.Lock()
|
||||
defer tc.Unlock()
|
||||
|
||||
return tc.getCompoundType(UnionKind, out...)
|
||||
}
|
||||
|
||||
func simplifyRefs(ts typeset, intersectStructs bool) *Type {
|
||||
return simplifyContainers(RefKind, MakeRefType, ts, intersectStructs)
|
||||
}
|
||||
|
||||
func simplifySets(ts typeset, intersectStructs bool) *Type {
|
||||
return simplifyContainers(SetKind, MakeSetType, ts, intersectStructs)
|
||||
}
|
||||
|
||||
func simplifyLists(ts typeset, intersectStructs bool) *Type {
|
||||
return simplifyContainers(ListKind, MakeListType, ts, intersectStructs)
|
||||
}
|
||||
|
||||
func simplifyContainers(expectedKind NomsKind, makeContainer func(elem *Type) *Type, ts typeset, intersectStructs bool) *Type {
|
||||
func (tc *TypeCache) simplifyContainers(expectedKind NomsKind, ts typeset, intersectStructs bool) *Type {
|
||||
elemTypes := make(typeset, len(ts))
|
||||
for t := range ts {
|
||||
for _, t := range ts {
|
||||
d.Chk.True(expectedKind == t.Kind())
|
||||
elemTypes.Add(t.Desc.(CompoundDesc).ElemTypes[0])
|
||||
}
|
||||
return makeContainer(makeSimplifiedTypeImpl(elemTypes, intersectStructs))
|
||||
|
||||
elemType := tc.makeSimplifiedTypeImpl(elemTypes, intersectStructs)
|
||||
|
||||
tc.Lock()
|
||||
defer tc.Unlock()
|
||||
|
||||
return tc.getCompoundType(expectedKind, elemType)
|
||||
}
|
||||
|
||||
func simplifyMaps(ts typeset, intersectStructs bool) *Type {
|
||||
func (tc *TypeCache) simplifyMaps(ts typeset, intersectStructs bool) *Type {
|
||||
keyTypes := make(typeset, len(ts))
|
||||
valTypes := make(typeset, len(ts))
|
||||
for t := range ts {
|
||||
for _, t := range ts {
|
||||
d.Chk.True(MapKind == t.Kind())
|
||||
desc := t.Desc.(CompoundDesc)
|
||||
keyTypes.Add(desc.ElemTypes[0])
|
||||
valTypes.Add(desc.ElemTypes[1])
|
||||
}
|
||||
return MakeMapType(
|
||||
makeSimplifiedTypeImpl(keyTypes, intersectStructs),
|
||||
makeSimplifiedTypeImpl(valTypes, intersectStructs))
|
||||
|
||||
kt := tc.makeSimplifiedTypeImpl(keyTypes, intersectStructs)
|
||||
vt := tc.makeSimplifiedTypeImpl(valTypes, intersectStructs)
|
||||
|
||||
tc.Lock()
|
||||
defer tc.Unlock()
|
||||
|
||||
return tc.getCompoundType(MapKind, kt, vt)
|
||||
}
|
||||
|
||||
func simplifyStructs(expectedName string, ts typeset, intersectStructs bool) *Type {
|
||||
func (tc *TypeCache) simplifyStructs(expectedName string, ts typeset, intersectStructs bool) *Type {
|
||||
// We gather all the fields/types into allFields. If the number of
|
||||
// times a field name is present is less that then number of types we
|
||||
// are simplifying then the field must be optional.
|
||||
@@ -186,7 +192,7 @@ func simplifyStructs(expectedName string, ts typeset, intersectStructs bool) *Ty
|
||||
}
|
||||
allFields := map[string]fieldTypeInfo{}
|
||||
|
||||
for t := range ts {
|
||||
for _, t := range ts {
|
||||
d.Chk.True(StructKind == t.Kind())
|
||||
desc := t.Desc.(StructDesc)
|
||||
d.Chk.True(expectedName == desc.Name)
|
||||
@@ -212,10 +218,15 @@ func simplifyStructs(expectedName string, ts typeset, intersectStructs bool) *Ty
|
||||
for name, fti := range allFields {
|
||||
fields = append(fields, StructField{
|
||||
Name: name,
|
||||
Type: makeSimplifiedTypeImpl(fti.typeset, intersectStructs),
|
||||
Type: tc.makeSimplifiedTypeImpl(fti.typeset, intersectStructs),
|
||||
Optional: !(intersectStructs && fti.anyNonOptional) && fti.count < count,
|
||||
})
|
||||
}
|
||||
|
||||
return MakeStructType(expectedName, fields...)
|
||||
sort.Sort(fields)
|
||||
|
||||
tc.Lock()
|
||||
defer tc.Unlock()
|
||||
|
||||
return tc.makeStructType(expectedName, fields)
|
||||
}
|
||||
|
||||
@@ -15,10 +15,24 @@ import (
|
||||
// - test grouping of the various kinds
|
||||
// - test cycles
|
||||
|
||||
func simplifyRefs(ts typeset, intersectStructs bool) *Type {
|
||||
return staticTypeCache.simplifyContainers(RefKind, ts, intersectStructs)
|
||||
}
|
||||
func simplifySets(ts typeset, intersectStructs bool) *Type {
|
||||
return staticTypeCache.simplifyContainers(SetKind, ts, intersectStructs)
|
||||
}
|
||||
func simplifyLists(ts typeset, intersectStructs bool) *Type {
|
||||
return staticTypeCache.simplifyContainers(ListKind, ts, intersectStructs)
|
||||
}
|
||||
|
||||
func simplifyMaps(ts typeset, intersectStructs bool) *Type {
|
||||
return staticTypeCache.simplifyMaps(ts, intersectStructs)
|
||||
}
|
||||
|
||||
func TestSimplifyHelpers(t *testing.T) {
|
||||
structSimplifier := func(n string) func(typeset, bool) *Type {
|
||||
return func(ts typeset, intersectStructs bool) *Type {
|
||||
return simplifyStructs(n, ts, intersectStructs)
|
||||
return staticTypeCache.simplifyStructs(n, ts, intersectStructs)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -252,7 +266,7 @@ func TestMakeSimplifiedUnion(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
act := makeSimplifiedType(intersectStruct, c.in...)
|
||||
act := staticTypeCache.makeSimplifiedType(intersectStruct, c.in...)
|
||||
assert.True(t, c.out.Equals(act), "Test case as position %d - got %s, expected %s", i, act.Describe(), c.out.Describe())
|
||||
}
|
||||
}
|
||||
|
||||
+6
-34
@@ -7,8 +7,6 @@ package types
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/attic-labs/noms/go/hash"
|
||||
)
|
||||
|
||||
type TypeCache struct {
|
||||
@@ -311,14 +309,7 @@ func checkForUnresolvedCycles(t, root *Type, parentStructTypes []*Type) {
|
||||
|
||||
// MakeUnionType creates a new union type unless the elemTypes can be folded into a single non union type.
|
||||
func (tc *TypeCache) makeUnionType(elemTypes ...*Type) *Type {
|
||||
seenTypes := map[hash.Hash]bool{}
|
||||
ts := flattenUnionTypes(typeSlice(elemTypes), &seenTypes)
|
||||
if len(ts) == 1 {
|
||||
return ts[0]
|
||||
}
|
||||
// We sort the contituent types to dedup equivalent types in memory; we may need to sort again after cycles are resolved for final encoding.
|
||||
sort.Sort(ts)
|
||||
return tc.getCompoundType(UnionKind, ts...)
|
||||
return tc.makeSimplifiedType(false, elemTypes...)
|
||||
}
|
||||
|
||||
func (tc *TypeCache) getCycleType(level uint32) *Type {
|
||||
@@ -331,25 +322,6 @@ func (tc *TypeCache) getCycleType(level uint32) *Type {
|
||||
return trie.t
|
||||
}
|
||||
|
||||
func flattenUnionTypes(ts typeSlice, seenTypes *map[hash.Hash]bool) typeSlice {
|
||||
if len(ts) == 0 {
|
||||
return ts
|
||||
}
|
||||
|
||||
ts2 := make(typeSlice, 0, len(ts))
|
||||
for _, t := range ts {
|
||||
if t.Kind() == UnionKind {
|
||||
ts2 = append(ts2, flattenUnionTypes(t.Desc.(CompoundDesc).ElemTypes, seenTypes)...)
|
||||
} else {
|
||||
if !(*seenTypes)[t.Hash()] {
|
||||
(*seenTypes)[t.Hash()] = true
|
||||
ts2 = append(ts2, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
return ts2
|
||||
}
|
||||
|
||||
func MakeListType(elemType *Type) *Type {
|
||||
staticTypeCache.Lock()
|
||||
defer staticTypeCache.Unlock()
|
||||
@@ -401,17 +373,17 @@ func (s structFields) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
func (s structFields) Less(i, j int) bool { return s[i].Name < s[j].Name }
|
||||
|
||||
func MakeStructType(name string, fields ...StructField) *Type {
|
||||
staticTypeCache.Lock()
|
||||
defer staticTypeCache.Unlock()
|
||||
|
||||
fs := structFields(fields)
|
||||
sort.Sort(&fs)
|
||||
|
||||
staticTypeCache.Lock()
|
||||
defer staticTypeCache.Unlock()
|
||||
|
||||
return staticTypeCache.makeStructType(name, fs)
|
||||
}
|
||||
|
||||
func MakeUnionType(elemTypes ...*Type) *Type {
|
||||
return makeSimplifiedType(false, elemTypes...)
|
||||
return staticTypeCache.makeUnionType(elemTypes...)
|
||||
}
|
||||
|
||||
// MakeUnionTypeIntersectStructs is a bit of strange function. It creates a
|
||||
@@ -419,7 +391,7 @@ func MakeUnionType(elemTypes ...*Type) *Type {
|
||||
// types.
|
||||
// This function will go away so do not use it!
|
||||
func MakeUnionTypeIntersectStructs(elemTypes ...*Type) *Type {
|
||||
return makeSimplifiedType(true, elemTypes...)
|
||||
return staticTypeCache.makeSimplifiedType(true, elemTypes...)
|
||||
}
|
||||
|
||||
func MakeCycleType(level uint32) *Type {
|
||||
|
||||
Reference in New Issue
Block a user