mirror of
https://github.com/dolthub/dolt.git
synced 2026-02-11 02:59:34 -06:00
Fix issue with canUseDef
We were not passing the package through when we called containsComparable which made us lookup the type in the wrong package. Fixes #869
This commit is contained in:
@@ -392,7 +392,7 @@ func (gen *codeGen) writeStruct(t types.Type, ordinal int) {
|
||||
nil,
|
||||
len(desc.Union) != 0,
|
||||
types.MakePrimitiveType(types.Uint32Kind),
|
||||
gen.canUseDef(t),
|
||||
gen.canUseDef(t, gen.pkg.Package),
|
||||
}
|
||||
|
||||
if data.HasUnion {
|
||||
@@ -423,7 +423,7 @@ func (gen *codeGen) writeList(t types.Type) {
|
||||
gen.generator.UserName(t),
|
||||
t,
|
||||
elemTypes[0],
|
||||
gen.canUseDef(t),
|
||||
gen.canUseDef(t, gen.pkg.Package),
|
||||
}
|
||||
gen.writeTemplate("list.tmpl", t, data)
|
||||
gen.writeLater(elemTypes[0])
|
||||
@@ -444,7 +444,7 @@ func (gen *codeGen) writeMap(t types.Type) {
|
||||
t,
|
||||
elemTypes[0],
|
||||
elemTypes[1],
|
||||
gen.canUseDef(t),
|
||||
gen.canUseDef(t, gen.pkg.Package),
|
||||
}
|
||||
gen.writeTemplate("map.tmpl", t, data)
|
||||
gen.writeLater(elemTypes[0])
|
||||
@@ -481,7 +481,7 @@ func (gen *codeGen) writeSet(t types.Type) {
|
||||
gen.generator.UserName(t),
|
||||
t,
|
||||
elemTypes[0],
|
||||
gen.canUseDef(t),
|
||||
gen.canUseDef(t, gen.pkg.Package),
|
||||
}
|
||||
gen.writeTemplate("set.tmpl", t, data)
|
||||
gen.writeLater(elemTypes[0])
|
||||
@@ -506,32 +506,31 @@ func (gen *codeGen) writeEnum(t types.Type, ordinal int) {
|
||||
gen.writeTemplate("enum.tmpl", t, data)
|
||||
}
|
||||
|
||||
func (gen *codeGen) canUseDef(t types.Type) bool {
|
||||
func (gen *codeGen) canUseDef(t types.Type, p types.Package) bool {
|
||||
cache := map[string]bool{}
|
||||
|
||||
var rec func(t types.Type, p types.Package) bool
|
||||
rec = func(t types.Type, p types.Package) bool {
|
||||
if t.HasPackageRef() {
|
||||
p = gen.deps[t.PackageRef()]
|
||||
d.Chk.NotNil(p)
|
||||
}
|
||||
rt := resolveInPackage(t, &p)
|
||||
switch rt.Kind() {
|
||||
switch t.Kind() {
|
||||
case types.UnresolvedKind:
|
||||
t2, p2 := gen.resolveInPackage(t, p)
|
||||
d.Chk.False(t2.IsUnresolved())
|
||||
return rec(t2, p2)
|
||||
case types.ListKind:
|
||||
return rec(rt.Desc.(types.CompoundDesc).ElemTypes[0], p)
|
||||
return rec(t.Desc.(types.CompoundDesc).ElemTypes[0], p)
|
||||
case types.SetKind:
|
||||
elemType := rt.Desc.(types.CompoundDesc).ElemTypes[0]
|
||||
return !gen.containsNonComparable(elemType) && rec(elemType, p)
|
||||
elemType := t.Desc.(types.CompoundDesc).ElemTypes[0]
|
||||
return !gen.containsNonComparable(elemType, p) && rec(elemType, p)
|
||||
case types.MapKind:
|
||||
elemTypes := rt.Desc.(types.CompoundDesc).ElemTypes
|
||||
return !gen.containsNonComparable(elemTypes[0]) && rec(elemTypes[0], p) && rec(elemTypes[1], p)
|
||||
elemTypes := t.Desc.(types.CompoundDesc).ElemTypes
|
||||
return !gen.containsNonComparable(elemTypes[0], p) && rec(elemTypes[0], p) && rec(elemTypes[1], p)
|
||||
case types.StructKind:
|
||||
userName := gen.generator.UserName(t)
|
||||
if b, ok := cache[userName]; ok {
|
||||
return b
|
||||
}
|
||||
cache[userName] = true
|
||||
for _, f := range rt.Desc.(types.StructDesc).Fields {
|
||||
for _, f := range t.Desc.(types.StructDesc).Fields {
|
||||
if f.T.Equals(t) || !rec(f.T, p) {
|
||||
cache[userName] = false
|
||||
return false
|
||||
@@ -543,22 +542,21 @@ func (gen *codeGen) canUseDef(t types.Type) bool {
|
||||
}
|
||||
}
|
||||
|
||||
return rec(t, gen.pkg.Package)
|
||||
return rec(t, p)
|
||||
}
|
||||
|
||||
// We use a go map as the def for Set and Map. These cannot have a key that is a
|
||||
// Set, Map or a List because slices and maps are not comparable in go.
|
||||
func (gen *codeGen) containsNonComparable(t types.Type) bool {
|
||||
func (gen *codeGen) containsNonComparable(t types.Type, p types.Package) bool {
|
||||
cache := map[string]bool{}
|
||||
|
||||
var rec func(t types.Type, p types.Package) bool
|
||||
rec = func(t types.Type, p types.Package) bool {
|
||||
if t.HasPackageRef() {
|
||||
p = gen.deps[t.PackageRef()]
|
||||
d.Chk.NotNil(p)
|
||||
}
|
||||
t = resolveInPackage(t, &p)
|
||||
switch t.Desc.Kind() {
|
||||
case types.UnresolvedKind:
|
||||
t2, p2 := gen.resolveInPackage(t, p)
|
||||
d.Chk.False(t2.IsUnresolved())
|
||||
return rec(t2, p2)
|
||||
case types.ListKind, types.MapKind, types.SetKind:
|
||||
return true
|
||||
case types.StructKind:
|
||||
@@ -567,8 +565,7 @@ func (gen *codeGen) containsNonComparable(t types.Type) bool {
|
||||
if b, ok := cache[userName]; ok {
|
||||
return b
|
||||
}
|
||||
// If we get here in a recursive call we will mark it as not having a non comparable value. If it does then that will
|
||||
// get handled higher up in the call chain.
|
||||
// If we get here in a recursive call we will mark it as not having a non comparable value. If it does then that will get handled higher up in the call chain.
|
||||
cache[userName] = false
|
||||
for _, f := range t.Desc.(types.StructDesc).Fields {
|
||||
if rec(f.T, p) {
|
||||
@@ -582,12 +579,17 @@ func (gen *codeGen) containsNonComparable(t types.Type) bool {
|
||||
}
|
||||
}
|
||||
|
||||
return rec(t, gen.pkg.Package)
|
||||
return rec(t, p)
|
||||
}
|
||||
|
||||
func resolveInPackage(t types.Type, p *types.Package) types.Type {
|
||||
if !t.IsUnresolved() {
|
||||
return t
|
||||
func (gen *codeGen) resolveInPackage(t types.Type, p types.Package) (types.Type, types.Package) {
|
||||
d.Chk.True(t.IsUnresolved())
|
||||
|
||||
// For unresolved types that references types in the same package the ref is empty and we need to use the passed in package.
|
||||
if t.HasPackageRef() {
|
||||
p = gen.deps[t.PackageRef()]
|
||||
d.Chk.NotNil(p)
|
||||
}
|
||||
return p.Types()[t.Ordinal()]
|
||||
|
||||
return p.Types()[t.Ordinal()], p
|
||||
}
|
||||
|
||||
@@ -71,10 +71,10 @@ func TestCanUseDef(t *testing.T) {
|
||||
pkg := pkg.ParseNomDL("fakefile", bytes.NewBufferString(s), "", emptyCS)
|
||||
gen := newCodeGen(nil, "fakefile", map[string]bool{}, depsMap{}, pkg)
|
||||
for _, t := range pkg.UsingDeclarations {
|
||||
assert.Equal(using, gen.canUseDef(t))
|
||||
assert.Equal(using, gen.canUseDef(t, gen.pkg.Package))
|
||||
}
|
||||
for _, t := range pkg.Types() {
|
||||
assert.Equal(named, gen.canUseDef(t))
|
||||
assert.Equal(named, gen.canUseDef(t, gen.pkg.Package))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -199,3 +199,40 @@ func (suite *ImportTestSuite) TestImports() {
|
||||
suite.EqualValues(types.ListKind, usings[1].Kind())
|
||||
suite.EqualValues(0, usings[1].Desc.(types.CompoundDesc).ElemTypes[0].Ordinal())
|
||||
}
|
||||
|
||||
func (suite *ImportTestSuite) TestImportWithLocalRef() {
|
||||
dir, err := ioutil.TempDir("", "")
|
||||
suite.NoError(err)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
byPathNomDL := filepath.Join(dir, "filedep.noms")
|
||||
err = ioutil.WriteFile(byPathNomDL, []byte("struct FromFile{i:Int8}"), 0600)
|
||||
suite.NoError(err)
|
||||
|
||||
r1 := strings.NewReader(`
|
||||
struct A {
|
||||
B: B
|
||||
}
|
||||
struct B {
|
||||
X: Int64
|
||||
}`)
|
||||
pkg1 := ParseNomDL("test1", r1, dir, suite.cs)
|
||||
pkgRef1 := types.WriteValue(pkg1.Package, suite.cs)
|
||||
|
||||
r2 := strings.NewReader(fmt.Sprintf(`
|
||||
alias Other = import "%s"
|
||||
struct C {
|
||||
C: Map<Int64, Other.A>
|
||||
}
|
||||
`, pkgRef1))
|
||||
pkg2 := ParseNomDL("test2", r2, dir, suite.cs)
|
||||
|
||||
ts := pkg2.Types()
|
||||
suite.Len(ts, 1)
|
||||
suite.EqualValues(types.StructKind, ts[0].Kind())
|
||||
mapType := ts[0].Desc.(types.StructDesc).Fields[0].T
|
||||
suite.EqualValues(types.MapKind, mapType.Kind())
|
||||
otherAType := mapType.Desc.(types.CompoundDesc).ElemTypes[1]
|
||||
suite.EqualValues(types.UnresolvedKind, otherAType.Kind())
|
||||
suite.EqualValues(pkgRef1, otherAType.PackageRef())
|
||||
}
|
||||
|
||||
@@ -59,8 +59,9 @@ func resolveLocalOrdinals(p *intermediate) {
|
||||
if t.Namespace() == "" && !t.HasOrdinal() {
|
||||
ordinal := indexOf(t, p.Types)
|
||||
d.Chk.True(ordinal >= 0 && int(ordinal) < len(p.Types), "Invalid reference: %s", t.Name())
|
||||
return types.MakeType(t.PackageRef(), int16(ordinal))
|
||||
return types.MakeType(ref.Ref{}, int16(ordinal))
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
@@ -119,6 +120,8 @@ func resolveNamespaces(p *intermediate, aliases map[string]ref.Ref, deps map[ref
|
||||
t = resolveNamespace(t, aliases, deps)
|
||||
}
|
||||
switch t.Kind() {
|
||||
case types.UnresolvedKind:
|
||||
d.Chk.True(t.HasPackageRef(), "should resolve again")
|
||||
case types.ListKind, types.SetKind, types.RefKind:
|
||||
return types.MakeCompoundType(t.Kind(), rec(t.Desc.(types.CompoundDesc).ElemTypes[0]))
|
||||
case types.MapKind:
|
||||
@@ -129,7 +132,10 @@ func resolveNamespaces(p *intermediate, aliases map[string]ref.Ref, deps map[ref
|
||||
resolveFields(t.Desc.(types.StructDesc).Union)
|
||||
}
|
||||
|
||||
d.Chk.True(!t.IsUnresolved() || t.HasOrdinal())
|
||||
if t.IsUnresolved() {
|
||||
return rec(t)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
@@ -155,6 +161,7 @@ func resolveNamespace(t types.Type, aliases map[string]ref.Ref, deps map[ref.Ref
|
||||
d.Chk.NotEqual("", t.Name())
|
||||
ordinal := deps[pkgRef].GetOrdinal(t.Name())
|
||||
d.Exp.NotEqual(int64(-1), ordinal, "Could not find type %s in package %s (aliased to %s).", t.Name(), pkgRef.String(), t.Namespace())
|
||||
d.Chk.False(pkgRef.IsEmpty())
|
||||
return types.MakeType(pkgRef, int16(ordinal))
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user