diff --git a/go/marshal/encode.go b/go/marshal/encode.go index 1a20825939..236b644b13 100644 --- a/go/marshal/encode.go +++ b/go/marshal/encode.go @@ -38,8 +38,12 @@ import ( // // Struct values are encoded as Noms structs (types.Struct). Each exported Go // struct field becomes a member of the Noms struct unless -// - the field's tag is "-" -// - the field is empty and its tag specifies the "omitempty" option. +// - The field's tag is "-" +// - The field is empty and its tag specifies the "omitempty" option. +// - The field has the "original" tag, in which case the field is used as an +// initial value onto which the fields of the Go type are added. When +// combined with the corresponding support for "original" in Unmarshal(), +// this allows one to find and modify any values of a known subtype. // // The empty values are false, 0, any nil pointer or interface value, and any // array, slice, map, or string of length zero. @@ -218,7 +222,7 @@ func structEncoder(t reflect.Type, parentStructTypes []reflect.Type) encoderFunc } parentStructTypes = append(parentStructTypes, t) - fields, structType := typeFields(t, parentStructTypes) + fields, structType, originalFieldIndex := typeFields(t, parentStructTypes) if structType != nil { e = func(v reflect.Value) types.Value { values := make([]types.Value, len(fields)) @@ -227,8 +231,9 @@ func structEncoder(t reflect.Type, parentStructTypes []reflect.Type) encoderFunc } return types.NewStructWithType(structType, values) } - } else { - // Cannot precompute the Noms type since there are Noms collections. + } else if originalFieldIndex == nil { + // Slower path: cannot precompute the Noms type since there are Noms collections, + // but at least there are a set number of fields name := strings.Title(t.Name()) e = func(v reflect.Value) types.Value { data := make(types.StructData, len(fields)) @@ -241,6 +246,20 @@ func structEncoder(t reflect.Type, parentStructTypes []reflect.Type) encoderFunc } return types.NewStruct(name, data) } + } else { + // Slowest path - we are extending some other struct. We need to start with the + // type of that struct and extend. + e = func(v reflect.Value) types.Value { + ret := v.FieldByIndex(originalFieldIndex).Interface().(types.Struct) + for _, f := range fields { + fv := v.Field(f.index) + if !fv.IsValid() || f.omitEmpty && isEmptyValue(fv) { + continue + } + ret = ret.Set(f.name, f.encoder(fv)) + } + return ret + } } encoderCache.set(t, e) @@ -352,7 +371,7 @@ func validateField(f reflect.StructField, t reflect.Type) { } } -func typeFields(t reflect.Type, parentStructTypes []reflect.Type) (fields fieldSlice, structType *types.Type) { +func typeFields(t reflect.Type, parentStructTypes []reflect.Type) (fields fieldSlice, structType *types.Type, originalFieldIndex []int) { canComputeStructType := true for i := 0; i < t.NumField(); i++ { f := t.Field(i) @@ -361,6 +380,12 @@ func typeFields(t reflect.Type, parentStructTypes []reflect.Type) (fields fieldS continue } + if tags.original { + canComputeStructType = false + originalFieldIndex = f.Index + continue + } + validateField(f, t) nt := nomsType(f.Type, parentStructTypes) if nt == nil { @@ -441,7 +466,7 @@ func structNomsType(t reflect.Type, parentStructTypes []reflect.Type) *types.Typ } } - _, structType := typeFields(t, parentStructTypes) + _, structType, _ := typeFields(t, parentStructTypes) return structType } diff --git a/go/marshal/encode_test.go b/go/marshal/encode_test.go index 8184873600..ad93eabf6d 100644 --- a/go/marshal/encode_test.go +++ b/go/marshal/encode_test.go @@ -647,6 +647,53 @@ func TestEncodeCanSkipUnexportedField(t *testing.T) { }).Equals(v)) } +func TestEncodeOriginal(t *testing.T) { + assert := assert.New(t) + + type S struct { + Foo int `noms:",omitempty"` + Bar types.Struct `noms:",original"` + } + + var s S + var err error + var orig types.Struct + + // New field value clobbers old field value + orig = types.NewStruct("S", types.StructData{ + "foo": types.Number(42), + }) + err = Unmarshal(orig, &s) + assert.NoError(err) + s.Foo = 43 + assert.True(MustMarshal(s).Equals(orig.Set("foo", types.Number(43)))) + + // New field extends old struct + orig = types.NewStruct("S", types.StructData{}) + err = Unmarshal(orig, &s) + assert.NoError(err) + s.Foo = 43 + assert.True(MustMarshal(s).Equals(orig.Set("foo", types.Number(43)))) + + // Old struct name always used + orig = types.NewStruct("Q", types.StructData{}) + err = Unmarshal(orig, &s) + assert.NoError(err) + s.Foo = 43 + assert.True(MustMarshal(s).Equals(orig.Set("foo", types.Number(43)))) + + // Field type of base are preserved + st := types.MakeStructType("S", []string{"foo"}, + []*types.Type{types.MakeUnionType(types.StringType, types.NumberType)}) + orig = types.NewStructWithType(st, []types.Value{types.Number(42)}) + err = Unmarshal(orig, &s) + assert.NoError(err) + s.Foo = 43 + out := MustMarshal(s) + assert.True(out.Equals(orig.Set("foo", types.Number(43)))) + assert.True(out.Type().Equals(st)) +} + type TestInterface interface { M() }