diff --git a/types/assert.go b/types/assert.go new file mode 100644 index 0000000000..8f7593d9db --- /dev/null +++ b/types/assert.go @@ -0,0 +1,31 @@ +package types + +import "github.com/attic-labs/noms/d" + +func assertType(t TypeRef, v ...Value) { + if t.Kind() != ValueKind { + for _, v := range v { + d.Chk.True(t.Equals(v.TypeRef())) + } + } +} + +func assertSetsSameType(s Set, v ...Set) { + if s.elemType().Kind() != ValueKind { + t := s.TypeRef() + for _, v := range v { + d.Chk.True(t.Equals(v.TypeRef())) + } + } +} + +func assertMapElemTypes(m Map, v ...Value) { + elemTypes := m.elemTypes() + keyType := elemTypes[0] + valueType := elemTypes[0] + if keyType.Kind() != ValueKind || valueType.Kind() != ValueKind { + for i, v := range v { + d.Chk.True(elemTypes[i%2].Equals(v.TypeRef())) + } + } +} diff --git a/types/list_leaf.go b/types/list_leaf.go index 68a1ba9e4d..853e256206 100644 --- a/types/list_leaf.go +++ b/types/list_leaf.go @@ -125,6 +125,7 @@ func (l listLeaf) Slice(start uint64, end uint64) List { } func (l listLeaf) Set(idx uint64, v Value) List { + assertType(l.elemType(), v) values := make([]Value, len(l.values)) copy(values, l.values) values[idx] = v @@ -132,11 +133,13 @@ func (l listLeaf) Set(idx uint64, v Value) List { } func (l listLeaf) Append(v ...Value) List { + assertType(l.elemType(), v...) values := append(l.values, v...) return newListLeafNoCopy(values, l.t) } func (l listLeaf) Insert(idx uint64, v ...Value) List { + assertType(l.elemType(), v...) values := make([]Value, len(l.values)+len(v)) copy(values, l.values[:idx]) copy(values[idx:], v) @@ -178,3 +181,7 @@ func (l listLeaf) Chunks() (chunks []ref.Ref) { func (l listLeaf) TypeRef() TypeRef { return l.t } + +func (l listLeaf) elemType() TypeRef { + return l.t.Desc.(CompoundDesc).ElemTypes[0] +} diff --git a/types/list_test.go b/types/list_test.go index a8b3fc3a18..431f59ea43 100644 --- a/types/list_test.go +++ b/types/list_test.go @@ -341,19 +341,24 @@ func TestListTypeRef(t *testing.T) { tr := MakeCompoundTypeRef(ListKind, MakePrimitiveTypeRef(UInt8Kind)) l2 := newListLeafNoCopy([]Value{UInt8(0), UInt8(1)}, tr) assert.Equal(tr, l2.TypeRef()) - l2 = l2.Slice(0, 1) - assert.Equal(tr, l2.TypeRef()) - l2 = l2.Set(0, UInt8(11)) - assert.Equal(tr, l2.TypeRef()) - l2 = l2.Append(UInt8(2)) - assert.Equal(tr, l2.TypeRef()) - l2 = l2.Insert(0, UInt8(3)) - assert.Equal(tr, l2.TypeRef()) - l2 = l2.Remove(0, 1) - assert.Equal(tr, l2.TypeRef()) - l2 = l2.RemoveAt(0) - assert.Equal(tr, l2.TypeRef()) + l3 := l2.Slice(0, 1) + assert.True(tr.Equals(l3.TypeRef())) + l3 = l2.Remove(0, 1) + assert.True(tr.Equals(l3.TypeRef())) + l3 = l2.RemoveAt(0) + assert.True(tr.Equals(l3.TypeRef())) + + l3 = l2.Set(0, UInt8(11)) + assert.True(tr.Equals(l3.TypeRef())) + l3 = l2.Append(UInt8(2)) + assert.True(tr.Equals(l3.TypeRef())) + l3 = l2.Insert(0, UInt8(3)) + assert.True(tr.Equals(l3.TypeRef())) + + assert.Panics(func() { l2.Set(0, NewString("")) }) + assert.Panics(func() { l2.Append(NewString("")) }) + assert.Panics(func() { l2.Insert(0, NewString("")) }) } func TestListChunks(t *testing.T) { diff --git a/types/map.go b/types/map.go index 6ffc30b7f2..98633322c5 100644 --- a/types/map.go +++ b/types/map.go @@ -60,10 +60,14 @@ func (m Map) MaybeGet(key Value) (v Value, ok bool) { } func (m Map) Set(key Value, val Value) Map { + elemTypes := m.t.Desc.(CompoundDesc).ElemTypes + assertType(elemTypes[0], key) + assertType(elemTypes[1], val) return newMapFromData(buildMapData(m.data, []Value{key, val}), m.t) } func (m Map) SetM(kv ...Value) Map { + assertMapElemTypes(m, kv...) return newMapFromData(buildMapData(m.data, kv), m.t) } @@ -132,6 +136,10 @@ func (m Map) TypeRef() TypeRef { return m.t } +func (m Map) elemTypes() []TypeRef { + return m.t.Desc.(CompoundDesc).ElemTypes +} + func init() { RegisterFromValFunction(mapTypeRef, func(v Value) Value { return v.(Map) diff --git a/types/map_test.go b/types/map_test.go index eddefbf97e..8e3a763277 100644 --- a/types/map_test.go +++ b/types/map_test.go @@ -220,23 +220,28 @@ func TestMapTypeRef(t *testing.T) { m := NewMap() assert.True(m.TypeRef().Equals(MakeCompoundTypeRef(MapKind, MakePrimitiveTypeRef(ValueKind), MakePrimitiveTypeRef(ValueKind)))) - tr := MakeCompoundTypeRef(MapKind, MakePrimitiveTypeRef(StringKind), MakePrimitiveTypeRef(Int64Kind)) + tr := MakeCompoundTypeRef(MapKind, MakePrimitiveTypeRef(StringKind), MakePrimitiveTypeRef(UInt64Kind)) m = newMapFromData(mapData{}, tr) assert.Equal(tr, m.TypeRef()) - m = m.Set(NewString("A"), UInt64(1)) - assert.Equal(tr, m.TypeRef()) - - m = m.SetM(NewString("B"), UInt64(2), NewString("C"), UInt64(2)) - assert.Equal(tr, m.TypeRef()) - - m = m.Remove(NewString("B")) - assert.Equal(tr, m.TypeRef()) + m2 := m.Remove(NewString("B")) + assert.True(tr.Equals(m2.TypeRef())) m = m.Filter(func(k, v Value) bool { return true }) - assert.Equal(tr, m.TypeRef()) + assert.True(tr.Equals(m2.TypeRef())) + + m2 = m.Set(NewString("A"), UInt64(1)) + assert.True(tr.Equals(m2.TypeRef())) + + m2 = m.SetM(NewString("B"), UInt64(2), NewString("C"), UInt64(2)) + assert.True(tr.Equals(m2.TypeRef())) + + assert.Panics(func() { m.Set(NewString("A"), UInt8(1)) }) + assert.Panics(func() { m.Set(Bool(true), UInt64(1)) }) + assert.Panics(func() { m.SetM(NewString("B"), UInt64(2), NewString("A"), UInt8(1)) }) + assert.Panics(func() { m.SetM(NewString("B"), UInt64(2), Bool(true), UInt64(1)) }) } func TestMapChunks(t *testing.T) { diff --git a/types/ref.go b/types/ref.go index 5134bc1c44..87104eb483 100644 --- a/types/ref.go +++ b/types/ref.go @@ -52,5 +52,6 @@ func (r Ref) TargetValue(cs chunks.ChunkSource) Value { } func (r Ref) SetTargetValue(val Value, cs chunks.ChunkSink) Ref { + assertType(r.t.Desc.(CompoundDesc).ElemTypes[0], val) return newRef(WriteValue(val, cs), r.t) } diff --git a/types/ref_test.go b/types/ref_test.go index 438d6f2918..865296e64e 100644 --- a/types/ref_test.go +++ b/types/ref_test.go @@ -51,4 +51,13 @@ func TestRefTypeRef(t *testing.T) { m := NewMap() r2 := r.SetTargetValue(m, cs) assert.True(r2.TypeRef().Equals(tr)) + + b := Bool(true) + r2 = r.SetTargetValue(b, cs) + r2.t = MakeCompoundTypeRef(RefKind, b.TypeRef()) + + r3 := r2.SetTargetValue(Bool(false), cs) + assert.True(r2.TypeRef().Equals(r3.TypeRef())) + + assert.Panics(func() { r2.SetTargetValue(Int16(1), cs) }) } diff --git a/types/set.go b/types/set.go index c7579d2b4a..fde5215d94 100644 --- a/types/set.go +++ b/types/set.go @@ -32,6 +32,7 @@ func (s Set) Has(v Value) bool { } func (s Set) Insert(values ...Value) Set { + assertType(s.elemType(), values...) return newSetFromData(buildSetData(s.data, values), s.t) } @@ -49,6 +50,7 @@ func (s Set) Remove(values ...Value) Set { } func (s Set) Union(others ...Set) Set { + assertSetsSameType(s, others...) result := s for _, other := range others { other.Iter(func(v Value) (stop bool) { @@ -129,6 +131,10 @@ func (s Set) TypeRef() TypeRef { return s.t } +func (s Set) elemType() TypeRef { + return s.t.Desc.(CompoundDesc).ElemTypes[0] +} + func init() { RegisterFromValFunction(setTypeRef, func(v Value) Value { return v.(Set) diff --git a/types/set_test.go b/types/set_test.go index 144a48cf62..f9c045a6ba 100644 --- a/types/set_test.go +++ b/types/set_test.go @@ -193,22 +193,30 @@ func TestSetTypeRef(t *testing.T) { s = newSetFromData(setData{}, tr) assert.Equal(tr, s.TypeRef()) - s = s.Insert(UInt64(0), UInt64(1)) - assert.Equal(tr, s.TypeRef()) + s2 := s.Remove(UInt64(1)) + assert.True(tr.Equals(s2.TypeRef())) - s = s.Remove(UInt64(1)) - assert.Equal(tr, s.TypeRef()) + s2 = s.Subtract(s) + assert.True(tr.Equals(s2.TypeRef())) - s = s.Union(s) - assert.Equal(tr, s.TypeRef()) - - s = s.Subtract(s) - assert.Equal(tr, s.TypeRef()) - - s = s.Filter(func(v Value) bool { + s2 = s.Filter(func(v Value) bool { return true }) - assert.Equal(tr, s.TypeRef()) + assert.True(tr.Equals(s2.TypeRef())) + + s2 = s.Insert(UInt64(0), UInt64(1)) + assert.True(tr.Equals(s2.TypeRef())) + + s3 := NewSet(UInt64(2)) + s3.t = s2.t + s2 = s.Union(s3) + assert.True(tr.Equals(s2.TypeRef())) + + assert.Panics(func() { s.Insert(Bool(true)) }) + assert.Panics(func() { s.Insert(UInt64(3), Bool(true)) }) + assert.Panics(func() { s.Union(NewSet(UInt64(2))) }) + assert.Panics(func() { s.Union(NewSet(Bool(true))) }) + assert.Panics(func() { s.Union(s, NewSet(Bool(true))) }) } func TestSetChunks(t *testing.T) {