diff --git a/nomdl/codegen/codegen.go b/nomdl/codegen/codegen.go index dcbc731ba6..d379927ba3 100644 --- a/nomdl/codegen/codegen.go +++ b/nomdl/codegen/codegen.go @@ -392,7 +392,7 @@ func (gen *codeGen) writeStruct(t types.Type, ordinal int) { nil, len(desc.Union) != 0, types.MakePrimitiveType(types.Uint32Kind), - gen.canUseDef(t), + gen.canUseDef(t, gen.pkg.Package), } if data.HasUnion { @@ -423,7 +423,7 @@ func (gen *codeGen) writeList(t types.Type) { gen.generator.UserName(t), t, elemTypes[0], - gen.canUseDef(t), + gen.canUseDef(t, gen.pkg.Package), } gen.writeTemplate("list.tmpl", t, data) gen.writeLater(elemTypes[0]) @@ -444,7 +444,7 @@ func (gen *codeGen) writeMap(t types.Type) { t, elemTypes[0], elemTypes[1], - gen.canUseDef(t), + gen.canUseDef(t, gen.pkg.Package), } gen.writeTemplate("map.tmpl", t, data) gen.writeLater(elemTypes[0]) @@ -481,7 +481,7 @@ func (gen *codeGen) writeSet(t types.Type) { gen.generator.UserName(t), t, elemTypes[0], - gen.canUseDef(t), + gen.canUseDef(t, gen.pkg.Package), } gen.writeTemplate("set.tmpl", t, data) gen.writeLater(elemTypes[0]) @@ -506,32 +506,31 @@ func (gen *codeGen) writeEnum(t types.Type, ordinal int) { gen.writeTemplate("enum.tmpl", t, data) } -func (gen *codeGen) canUseDef(t types.Type) bool { +func (gen *codeGen) canUseDef(t types.Type, p types.Package) bool { cache := map[string]bool{} var rec func(t types.Type, p types.Package) bool rec = func(t types.Type, p types.Package) bool { - if t.HasPackageRef() { - p = gen.deps[t.PackageRef()] - d.Chk.NotNil(p) - } - rt := resolveInPackage(t, &p) - switch rt.Kind() { + switch t.Kind() { + case types.UnresolvedKind: + t2, p2 := gen.resolveInPackage(t, p) + d.Chk.False(t2.IsUnresolved()) + return rec(t2, p2) case types.ListKind: - return rec(rt.Desc.(types.CompoundDesc).ElemTypes[0], p) + return rec(t.Desc.(types.CompoundDesc).ElemTypes[0], p) case types.SetKind: - elemType := rt.Desc.(types.CompoundDesc).ElemTypes[0] - return !gen.containsNonComparable(elemType) && rec(elemType, p) + elemType := t.Desc.(types.CompoundDesc).ElemTypes[0] + return !gen.containsNonComparable(elemType, p) && rec(elemType, p) case types.MapKind: - elemTypes := rt.Desc.(types.CompoundDesc).ElemTypes - return !gen.containsNonComparable(elemTypes[0]) && rec(elemTypes[0], p) && rec(elemTypes[1], p) + elemTypes := t.Desc.(types.CompoundDesc).ElemTypes + return !gen.containsNonComparable(elemTypes[0], p) && rec(elemTypes[0], p) && rec(elemTypes[1], p) case types.StructKind: userName := gen.generator.UserName(t) if b, ok := cache[userName]; ok { return b } cache[userName] = true - for _, f := range rt.Desc.(types.StructDesc).Fields { + for _, f := range t.Desc.(types.StructDesc).Fields { if f.T.Equals(t) || !rec(f.T, p) { cache[userName] = false return false @@ -543,22 +542,21 @@ func (gen *codeGen) canUseDef(t types.Type) bool { } } - return rec(t, gen.pkg.Package) + return rec(t, p) } // We use a go map as the def for Set and Map. These cannot have a key that is a // Set, Map or a List because slices and maps are not comparable in go. -func (gen *codeGen) containsNonComparable(t types.Type) bool { +func (gen *codeGen) containsNonComparable(t types.Type, p types.Package) bool { cache := map[string]bool{} var rec func(t types.Type, p types.Package) bool rec = func(t types.Type, p types.Package) bool { - if t.HasPackageRef() { - p = gen.deps[t.PackageRef()] - d.Chk.NotNil(p) - } - t = resolveInPackage(t, &p) switch t.Desc.Kind() { + case types.UnresolvedKind: + t2, p2 := gen.resolveInPackage(t, p) + d.Chk.False(t2.IsUnresolved()) + return rec(t2, p2) case types.ListKind, types.MapKind, types.SetKind: return true case types.StructKind: @@ -567,8 +565,7 @@ func (gen *codeGen) containsNonComparable(t types.Type) bool { if b, ok := cache[userName]; ok { return b } - // If we get here in a recursive call we will mark it as not having a non comparable value. If it does then that will - // get handled higher up in the call chain. + // If we get here in a recursive call we will mark it as not having a non comparable value. If it does then that will get handled higher up in the call chain. cache[userName] = false for _, f := range t.Desc.(types.StructDesc).Fields { if rec(f.T, p) { @@ -582,12 +579,17 @@ func (gen *codeGen) containsNonComparable(t types.Type) bool { } } - return rec(t, gen.pkg.Package) + return rec(t, p) } -func resolveInPackage(t types.Type, p *types.Package) types.Type { - if !t.IsUnresolved() { - return t +func (gen *codeGen) resolveInPackage(t types.Type, p types.Package) (types.Type, types.Package) { + d.Chk.True(t.IsUnresolved()) + + // For unresolved types that references types in the same package the ref is empty and we need to use the passed in package. + if t.HasPackageRef() { + p = gen.deps[t.PackageRef()] + d.Chk.NotNil(p) } - return p.Types()[t.Ordinal()] + + return p.Types()[t.Ordinal()], p } diff --git a/nomdl/codegen/codegen_test.go b/nomdl/codegen/codegen_test.go index 552a64cfec..1711cb4d75 100644 --- a/nomdl/codegen/codegen_test.go +++ b/nomdl/codegen/codegen_test.go @@ -71,10 +71,10 @@ func TestCanUseDef(t *testing.T) { pkg := pkg.ParseNomDL("fakefile", bytes.NewBufferString(s), "", emptyCS) gen := newCodeGen(nil, "fakefile", map[string]bool{}, depsMap{}, pkg) for _, t := range pkg.UsingDeclarations { - assert.Equal(using, gen.canUseDef(t)) + assert.Equal(using, gen.canUseDef(t, gen.pkg.Package)) } for _, t := range pkg.Types() { - assert.Equal(named, gen.canUseDef(t)) + assert.Equal(named, gen.canUseDef(t, gen.pkg.Package)) } } diff --git a/nomdl/pkg/imports_test.go b/nomdl/pkg/imports_test.go index a61f8951f8..e9fc2605aa 100644 --- a/nomdl/pkg/imports_test.go +++ b/nomdl/pkg/imports_test.go @@ -199,3 +199,40 @@ func (suite *ImportTestSuite) TestImports() { suite.EqualValues(types.ListKind, usings[1].Kind()) suite.EqualValues(0, usings[1].Desc.(types.CompoundDesc).ElemTypes[0].Ordinal()) } + +func (suite *ImportTestSuite) TestImportWithLocalRef() { + dir, err := ioutil.TempDir("", "") + suite.NoError(err) + defer os.RemoveAll(dir) + + byPathNomDL := filepath.Join(dir, "filedep.noms") + err = ioutil.WriteFile(byPathNomDL, []byte("struct FromFile{i:Int8}"), 0600) + suite.NoError(err) + + r1 := strings.NewReader(` + struct A { + B: B + } + struct B { + X: Int64 + }`) + pkg1 := ParseNomDL("test1", r1, dir, suite.cs) + pkgRef1 := types.WriteValue(pkg1.Package, suite.cs) + + r2 := strings.NewReader(fmt.Sprintf(` + alias Other = import "%s" + struct C { + C: Map + } + `, pkgRef1)) + pkg2 := ParseNomDL("test2", r2, dir, suite.cs) + + ts := pkg2.Types() + suite.Len(ts, 1) + suite.EqualValues(types.StructKind, ts[0].Kind()) + mapType := ts[0].Desc.(types.StructDesc).Fields[0].T + suite.EqualValues(types.MapKind, mapType.Kind()) + otherAType := mapType.Desc.(types.CompoundDesc).ElemTypes[1] + suite.EqualValues(types.UnresolvedKind, otherAType.Kind()) + suite.EqualValues(pkgRef1, otherAType.PackageRef()) +} diff --git a/nomdl/pkg/parse.go b/nomdl/pkg/parse.go index deceb3b3d5..3f67acb247 100644 --- a/nomdl/pkg/parse.go +++ b/nomdl/pkg/parse.go @@ -59,8 +59,9 @@ func resolveLocalOrdinals(p *intermediate) { if t.Namespace() == "" && !t.HasOrdinal() { ordinal := indexOf(t, p.Types) d.Chk.True(ordinal >= 0 && int(ordinal) < len(p.Types), "Invalid reference: %s", t.Name()) - return types.MakeType(t.PackageRef(), int16(ordinal)) + return types.MakeType(ref.Ref{}, int16(ordinal)) } + return t } @@ -119,6 +120,8 @@ func resolveNamespaces(p *intermediate, aliases map[string]ref.Ref, deps map[ref t = resolveNamespace(t, aliases, deps) } switch t.Kind() { + case types.UnresolvedKind: + d.Chk.True(t.HasPackageRef(), "should resolve again") case types.ListKind, types.SetKind, types.RefKind: return types.MakeCompoundType(t.Kind(), rec(t.Desc.(types.CompoundDesc).ElemTypes[0])) case types.MapKind: @@ -129,7 +132,10 @@ func resolveNamespaces(p *intermediate, aliases map[string]ref.Ref, deps map[ref resolveFields(t.Desc.(types.StructDesc).Union) } - d.Chk.True(!t.IsUnresolved() || t.HasOrdinal()) + if t.IsUnresolved() { + return rec(t) + } + return t } @@ -155,6 +161,7 @@ func resolveNamespace(t types.Type, aliases map[string]ref.Ref, deps map[ref.Ref d.Chk.NotEqual("", t.Name()) ordinal := deps[pkgRef].GetOrdinal(t.Name()) d.Exp.NotEqual(int64(-1), ordinal, "Could not find type %s in package %s (aliased to %s).", t.Name(), pkgRef.String(), t.Namespace()) + d.Chk.False(pkgRef.IsEmpty()) return types.MakeType(pkgRef, int16(ordinal)) }