Merge pull request #10501 from rhafer/issue/10495

fix(graph): Use the correct opaqueId when Statting OCM shares
This commit is contained in:
Michael Barz
2024-11-07 18:03:00 +01:00
committed by GitHub
11 changed files with 1624 additions and 26 deletions

View File

@@ -0,0 +1,7 @@
Bugfix: Fixed `sharedWithMe` response for OCM shares
OCM shares returned in the `sharedWithMe` response did not have the `mimeType` property
populated correctly.
https://github.com/owncloud/ocis/pull/10501
https://github.com/owncloud/ocis/issues/10495

View File

@@ -634,7 +634,7 @@ func (api DriveItemPermissionsApi) Invite(w http.ResponseWriter, r *http.Request
return
}
permission, err := api.driveItemPermissionsService.Invite(ctx, &itemID, *driveItemInvite)
permission, err := api.driveItemPermissionsService.Invite(ctx, itemID, *driveItemInvite)
if err != nil {
errorcode.RenderError(w, r, err)
return
@@ -698,7 +698,7 @@ func (api DriveItemPermissionsApi) ListPermissions(w http.ResponseWriter, r *htt
ctx := r.Context()
permissions, err := api.driveItemPermissionsService.ListPermissions(ctx, &itemID, listFederatedRoles, selectRoles)
permissions, err := api.driveItemPermissionsService.ListPermissions(ctx, itemID, listFederatedRoles, selectRoles)
if err != nil {
errorcode.RenderError(w, r, err)
return
@@ -774,7 +774,7 @@ func (api DriveItemPermissionsApi) DeletePermission(w http.ResponseWriter, r *ht
}
ctx := r.Context()
err = api.driveItemPermissionsService.DeletePermission(ctx, &itemID, permissionID)
err = api.driveItemPermissionsService.DeletePermission(ctx, itemID, permissionID)
if err != nil {
errorcode.RenderError(w, r, err)
return
@@ -841,7 +841,7 @@ func (api DriveItemPermissionsApi) UpdatePermission(w http.ResponseWriter, r *ht
return
}
updatedPermission, err := api.driveItemPermissionsService.UpdatePermission(ctx, &itemID, permissionID, permission)
updatedPermission, err := api.driveItemPermissionsService.UpdatePermission(ctx, itemID, permissionID, permission)
if err != nil {
errorcode.RenderError(w, r, err)
return

View File

@@ -166,7 +166,7 @@ func (api DriveItemPermissionsApi) CreateLink(w http.ResponseWriter, r *http.Req
return
}
perm, err := api.driveItemPermissionsService.CreateLink(r.Context(), &driveItemID, createLink)
perm, err := api.driveItemPermissionsService.CreateLink(r.Context(), driveItemID, createLink)
if err != nil {
errorcode.RenderError(w, r, err)
return
@@ -228,7 +228,7 @@ func (api DriveItemPermissionsApi) SetLinkPassword(w http.ResponseWriter, r *htt
return
}
newPermission, err := api.driveItemPermissionsService.SetPublicLinkPassword(ctx, &itemID, permissionID, password.GetPassword())
newPermission, err := api.driveItemPermissionsService.SetPublicLinkPassword(ctx, itemID, permissionID, password.GetPassword())
if err != nil {
errorcode.RenderError(w, r, err)
return

View File

@@ -499,7 +499,7 @@ func (api DrivesDriveItemApi) CreateDriveItem(w http.ResponseWriter, r *http.Req
return
}
if !IsShareJail(driveID) {
if !IsShareJail(&driveID) {
api.logger.Debug().Interface("driveID", driveID).Msg(ErrNotAShareJail.Error())
ErrNotAShareJail.Render(w, r)
return

View File

@@ -2,6 +2,7 @@ package svc
import (
"context"
"encoding/base64"
"encoding/json"
"io"
"net/http"
@@ -44,8 +45,8 @@ func IsSpaceRoot(rid *storageprovider.ResourceId) bool {
// GetDriveAndItemIDParam parses the driveID and itemID from the request,
// validates the common fields and returns the parsed IDs if ok.
func GetDriveAndItemIDParam(r *http.Request, logger *log.Logger) (storageprovider.ResourceId, storageprovider.ResourceId, error) {
empty := storageprovider.ResourceId{}
func GetDriveAndItemIDParam(r *http.Request, logger *log.Logger) (*storageprovider.ResourceId, *storageprovider.ResourceId, error) {
empty := &storageprovider.ResourceId{}
driveID, err := parseIDParam(r, "driveID")
if err != nil {
@@ -60,16 +61,16 @@ func GetDriveAndItemIDParam(r *http.Request, logger *log.Logger) (storageprovide
}
if itemID.GetOpaqueId() == "" {
logger.Debug().Interface("driveID", driveID).Interface("itemID", itemID).Msg("empty item opaqueID")
logger.Debug().Interface("driveID", &driveID).Interface("itemID", &itemID).Msg("empty item opaqueID")
return empty, empty, errorcode.New(errorcode.InvalidRequest, "invalid itemID")
}
if driveID.GetStorageId() != itemID.GetStorageId() || driveID.GetSpaceId() != itemID.GetSpaceId() {
logger.Debug().Interface("driveID", driveID).Interface("itemID", itemID).Msg("driveID and itemID do not match")
logger.Debug().Interface("driveID", &driveID).Interface("itemID", &itemID).Msg("driveID and itemID do not match")
return empty, empty, errorcode.New(errorcode.ItemNotFound, "driveID and itemID do not match")
}
return driveID, itemID, nil
return &driveID, &itemID, nil
}
// GetFilterParam returns the $filter query parameter from the request. If you need to parse the filter use godata.ParseRequest
@@ -95,7 +96,7 @@ func (g Graph) GetGatewayClient(w http.ResponseWriter, r *http.Request) (gateway
}
// IsShareJail returns true if given id is a share jail id.
func IsShareJail(id storageprovider.ResourceId) bool {
func IsShareJail(id *storageprovider.ResourceId) bool {
return id.GetStorageId() == utils.ShareStorageProviderID && id.GetSpaceId() == utils.ShareStorageSpaceID
}
@@ -510,7 +511,7 @@ func federatedRoleConditionForResourceType(ri *storageprovider.ResourceInfo) (st
// ExtractShareIdFromResourceId is a bit of a hack.
// We should not rely on a specific format of the item id.
// But currently there is no other way to get the ShareID.
func ExtractShareIdFromResourceId(rid storageprovider.ResourceId) *collaboration.ShareId {
func ExtractShareIdFromResourceId(rid *storageprovider.ResourceId) *collaboration.ShareId {
return &collaboration.ShareId{
OpaqueId: rid.GetOpaqueId(),
}
@@ -538,14 +539,21 @@ func cs3ReceivedOCMSharesToDriveItems(ctx context.Context,
group.Go(func() error {
var err error // redeclare
// for OCM shares the opaqueID is the '/' for shared directories and '/filename' for
// file shares
resOpaqueID := "/"
if receivedShares[0].GetResourceType() == storageprovider.ResourceType_RESOURCE_TYPE_FILE {
resOpaqueID += receivedShares[0].GetName()
}
shareStat, err := gatewayClient.Stat(ctx, &storageprovider.StatRequest{
Ref: &storageprovider.Reference{
ResourceId: &storageprovider.ResourceId{
// TODO maybe the reference is wrong
StorageId: utils.OCMStorageProviderID,
SpaceId: receivedShares[0].GetId().GetOpaqueId(),
OpaqueId: "", // in OCM resources the opaque id is the base64 encoded path
//OpaqueId: maybe ? receivedShares[0].GetId().GetOpaqueId(),
OpaqueId: base64.StdEncoding.EncodeToString([]byte(resOpaqueID)),
},
},
})

View File

@@ -15,6 +15,7 @@ import (
provider "github.com/cs3org/go-cs3apis/cs3/storage/provider/v1beta1"
"github.com/cs3org/reva/v2/pkg/storagespace"
"google.golang.org/protobuf/testing/protocmp"
"github.com/owncloud/ocis/v2/ocis-pkg/conversions"
"github.com/owncloud/ocis/v2/ocis-pkg/log"
@@ -42,10 +43,10 @@ var _ = Describe("Utils", func() {
case true:
Expect(err).To(BeNil())
parsedItemID, _ := storagespace.ParseID(itemID)
Expect(extractedItemID).To(Equal(parsedItemID))
Expect(extractedItemID).To(BeComparableTo(&parsedItemID, protocmp.Transform()))
parsedDriveID, _ := storagespace.ParseID(driveID)
Expect(extractedDriveID).To(Equal(parsedDriveID))
Expect(extractedDriveID).To(BeComparableTo(&parsedDriveID, protocmp.Transform()))
default:
Expect(err).ToNot(BeNil())
}
@@ -82,29 +83,29 @@ var _ = Describe("Utils", func() {
)
DescribeTable("IsShareJail",
func(resourceID provider.ResourceId, isShareJail bool) {
func(resourceID *provider.ResourceId, isShareJail bool) {
Expect(service.IsShareJail(resourceID)).To(Equal(isShareJail))
},
Entry("valid: share jail", provider.ResourceId{
Entry("valid: share jail", &provider.ResourceId{
StorageId: utils.ShareStorageProviderID,
SpaceId: utils.ShareStorageSpaceID,
}, true),
Entry("invalid: empty storageId", provider.ResourceId{
Entry("invalid: empty storageId", &provider.ResourceId{
SpaceId: utils.ShareStorageSpaceID,
}, false),
Entry("invalid: empty spaceId", provider.ResourceId{
Entry("invalid: empty spaceId", &provider.ResourceId{
StorageId: utils.ShareStorageProviderID,
}, false),
Entry("invalid: empty storageId and spaceId", provider.ResourceId{}, false),
Entry("invalid: non share jail storageId", provider.ResourceId{
Entry("invalid: empty storageId and spaceId", &provider.ResourceId{}, false),
Entry("invalid: non share jail storageId", &provider.ResourceId{
StorageId: "123",
SpaceId: utils.ShareStorageSpaceID,
}, false),
Entry("invalid: non share jail spaceId", provider.ResourceId{
Entry("invalid: non share jail spaceId", &provider.ResourceId{
StorageId: utils.ShareStorageProviderID,
SpaceId: "123",
}, false),
Entry("invalid: non share jail storageID and spaceId", provider.ResourceId{
Entry("invalid: non share jail storageID and spaceId", &provider.ResourceId{
StorageId: "123",
SpaceId: "123",
}, false),

View File

@@ -0,0 +1,261 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package msgfmt implements a text marshaler combining the desirable features
// of both the JSON and proto text formats.
// It is optimized for human readability and has no associated deserializer.
package msgfmt
import (
"bytes"
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"time"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/detrand"
"google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/internal/order"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
)
// Format returns a formatted string for the message.
func Format(m proto.Message) string {
return string(appendMessage(nil, m.ProtoReflect()))
}
// FormatValue returns a formatted string for an arbitrary value.
func FormatValue(v protoreflect.Value, fd protoreflect.FieldDescriptor) string {
return string(appendValue(nil, v, fd))
}
func appendValue(b []byte, v protoreflect.Value, fd protoreflect.FieldDescriptor) []byte {
switch v := v.Interface().(type) {
case nil:
return append(b, "<invalid>"...)
case bool, int32, int64, uint32, uint64, float32, float64:
return append(b, fmt.Sprint(v)...)
case string:
return append(b, strconv.Quote(string(v))...)
case []byte:
return append(b, strconv.Quote(string(v))...)
case protoreflect.EnumNumber:
return appendEnum(b, v, fd)
case protoreflect.Message:
return appendMessage(b, v)
case protoreflect.List:
return appendList(b, v, fd)
case protoreflect.Map:
return appendMap(b, v, fd)
default:
panic(fmt.Sprintf("invalid type: %T", v))
}
}
func appendEnum(b []byte, v protoreflect.EnumNumber, fd protoreflect.FieldDescriptor) []byte {
if fd != nil {
if ev := fd.Enum().Values().ByNumber(v); ev != nil {
return append(b, ev.Name()...)
}
}
return strconv.AppendInt(b, int64(v), 10)
}
func appendMessage(b []byte, m protoreflect.Message) []byte {
if b2 := appendKnownMessage(b, m); b2 != nil {
return b2
}
b = append(b, '{')
order.RangeFields(m, order.IndexNameFieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
b = append(b, fd.TextName()...)
b = append(b, ':')
b = appendValue(b, v, fd)
b = append(b, delim()...)
return true
})
b = appendUnknown(b, m.GetUnknown())
b = bytes.TrimRight(b, delim())
b = append(b, '}')
return b
}
var protocmpMessageType = reflect.TypeOf(map[string]any(nil))
func appendKnownMessage(b []byte, m protoreflect.Message) []byte {
md := m.Descriptor()
fds := md.Fields()
switch md.FullName() {
case genid.Any_message_fullname:
var msgVal protoreflect.Message
url := m.Get(fds.ByNumber(genid.Any_TypeUrl_field_number)).String()
if v := reflect.ValueOf(m); v.Type().ConvertibleTo(protocmpMessageType) {
// For protocmp.Message, directly obtain the sub-message value
// which is stored in structured form, rather than as raw bytes.
m2 := v.Convert(protocmpMessageType).Interface().(map[string]any)
v, ok := m2[string(genid.Any_Value_field_name)].(proto.Message)
if !ok {
return nil
}
msgVal = v.ProtoReflect()
} else {
val := m.Get(fds.ByNumber(genid.Any_Value_field_number)).Bytes()
mt, err := protoregistry.GlobalTypes.FindMessageByURL(url)
if err != nil {
return nil
}
msgVal = mt.New()
err = proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(val, msgVal.Interface())
if err != nil {
return nil
}
}
b = append(b, '{')
b = append(b, "["+url+"]"...)
b = append(b, ':')
b = appendMessage(b, msgVal)
b = append(b, '}')
return b
case genid.Timestamp_message_fullname:
secs := m.Get(fds.ByNumber(genid.Timestamp_Seconds_field_number)).Int()
nanos := m.Get(fds.ByNumber(genid.Timestamp_Nanos_field_number)).Int()
if nanos < 0 || nanos >= 1e9 {
return nil
}
t := time.Unix(secs, nanos).UTC()
x := t.Format("2006-01-02T15:04:05.000000000") // RFC 3339
x = strings.TrimSuffix(x, "000")
x = strings.TrimSuffix(x, "000")
x = strings.TrimSuffix(x, ".000")
return append(b, x+"Z"...)
case genid.Duration_message_fullname:
sign := ""
secs := m.Get(fds.ByNumber(genid.Duration_Seconds_field_number)).Int()
nanos := m.Get(fds.ByNumber(genid.Duration_Nanos_field_number)).Int()
if nanos <= -1e9 || nanos >= 1e9 || (secs > 0 && nanos < 0) || (secs < 0 && nanos > 0) {
return nil
}
if secs < 0 || nanos < 0 {
sign, secs, nanos = "-", -1*secs, -1*nanos
}
x := fmt.Sprintf("%s%d.%09d", sign, secs, nanos)
x = strings.TrimSuffix(x, "000")
x = strings.TrimSuffix(x, "000")
x = strings.TrimSuffix(x, ".000")
return append(b, x+"s"...)
case genid.BoolValue_message_fullname,
genid.Int32Value_message_fullname,
genid.Int64Value_message_fullname,
genid.UInt32Value_message_fullname,
genid.UInt64Value_message_fullname,
genid.FloatValue_message_fullname,
genid.DoubleValue_message_fullname,
genid.StringValue_message_fullname,
genid.BytesValue_message_fullname:
fd := fds.ByNumber(genid.WrapperValue_Value_field_number)
return appendValue(b, m.Get(fd), fd)
}
return nil
}
func appendUnknown(b []byte, raw protoreflect.RawFields) []byte {
rs := make(map[protoreflect.FieldNumber][]protoreflect.RawFields)
for len(raw) > 0 {
num, _, n := protowire.ConsumeField(raw)
rs[num] = append(rs[num], raw[:n])
raw = raw[n:]
}
var ns []protoreflect.FieldNumber
for n := range rs {
ns = append(ns, n)
}
sort.Slice(ns, func(i, j int) bool { return ns[i] < ns[j] })
for _, n := range ns {
var leftBracket, rightBracket string
if len(rs[n]) > 1 {
leftBracket, rightBracket = "[", "]"
}
b = strconv.AppendInt(b, int64(n), 10)
b = append(b, ':')
b = append(b, leftBracket...)
for _, r := range rs[n] {
num, typ, n := protowire.ConsumeTag(r)
r = r[n:]
switch typ {
case protowire.VarintType:
v, _ := protowire.ConsumeVarint(r)
b = strconv.AppendInt(b, int64(v), 10)
case protowire.Fixed32Type:
v, _ := protowire.ConsumeFixed32(r)
b = append(b, fmt.Sprintf("0x%08x", v)...)
case protowire.Fixed64Type:
v, _ := protowire.ConsumeFixed64(r)
b = append(b, fmt.Sprintf("0x%016x", v)...)
case protowire.BytesType:
v, _ := protowire.ConsumeBytes(r)
b = strconv.AppendQuote(b, string(v))
case protowire.StartGroupType:
v, _ := protowire.ConsumeGroup(num, r)
b = append(b, '{')
b = appendUnknown(b, v)
b = bytes.TrimRight(b, delim())
b = append(b, '}')
default:
panic(fmt.Sprintf("invalid type: %v", typ))
}
b = append(b, delim()...)
}
b = bytes.TrimRight(b, delim())
b = append(b, rightBracket...)
b = append(b, delim()...)
}
return b
}
func appendList(b []byte, v protoreflect.List, fd protoreflect.FieldDescriptor) []byte {
b = append(b, '[')
for i := 0; i < v.Len(); i++ {
b = appendValue(b, v.Get(i), fd)
b = append(b, delim()...)
}
b = bytes.TrimRight(b, delim())
b = append(b, ']')
return b
}
func appendMap(b []byte, v protoreflect.Map, fd protoreflect.FieldDescriptor) []byte {
b = append(b, '{')
order.RangeEntries(v, order.GenericKeyOrder, func(k protoreflect.MapKey, v protoreflect.Value) bool {
b = appendValue(b, k.Value(), fd.MapKey())
b = append(b, ':')
b = appendValue(b, v, fd.MapValue())
b = append(b, delim()...)
return true
})
b = bytes.TrimRight(b, delim())
b = append(b, '}')
return b
}
func delim() string {
// Deliberately introduce instability into the message string to
// discourage users from depending on it.
if detrand.Bool() {
return " "
}
return ", "
}

View File

@@ -0,0 +1,258 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package protocmp
import (
"reflect"
"sort"
"strconv"
"strings"
"google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"
)
func reflectValueOf(v any) protoreflect.Value {
switch v := v.(type) {
case Enum:
return protoreflect.ValueOfEnum(v.Number())
case Message:
return protoreflect.ValueOfMessage(v.ProtoReflect())
case []byte:
return protoreflect.ValueOfBytes(v) // avoid overlap with reflect.Slice check below
default:
switch rv := reflect.ValueOf(v); {
case rv.Kind() == reflect.Slice:
return protoreflect.ValueOfList(reflectList{rv})
case rv.Kind() == reflect.Map:
return protoreflect.ValueOfMap(reflectMap{rv})
default:
return protoreflect.ValueOf(v)
}
}
}
type reflectMessage Message
func (m reflectMessage) stringKey(fd protoreflect.FieldDescriptor) string {
if m.Descriptor() != fd.ContainingMessage() {
panic("mismatching containing message")
}
return fd.TextName()
}
func (m reflectMessage) Descriptor() protoreflect.MessageDescriptor {
return (Message)(m).Descriptor()
}
func (m reflectMessage) Type() protoreflect.MessageType {
return reflectMessageType{m.Descriptor()}
}
func (m reflectMessage) New() protoreflect.Message {
return m.Type().New()
}
func (m reflectMessage) Interface() protoreflect.ProtoMessage {
return Message(m)
}
func (m reflectMessage) Range(f func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool) {
// Range over populated known fields.
fds := m.Descriptor().Fields()
for i := 0; i < fds.Len(); i++ {
fd := fds.Get(i)
if m.Has(fd) && !f(fd, m.Get(fd)) {
return
}
}
// Range over populated extension fields.
for _, xd := range m[messageTypeKey].(messageMeta).xds {
if m.Has(xd) && !f(xd, m.Get(xd)) {
return
}
}
}
func (m reflectMessage) Has(fd protoreflect.FieldDescriptor) bool {
_, ok := m[m.stringKey(fd)]
return ok
}
func (m reflectMessage) Clear(protoreflect.FieldDescriptor) {
panic("invalid mutation of read-only message")
}
func (m reflectMessage) Get(fd protoreflect.FieldDescriptor) protoreflect.Value {
v, ok := m[m.stringKey(fd)]
if !ok {
switch {
case fd.IsList():
return protoreflect.ValueOfList(reflectList{})
case fd.IsMap():
return protoreflect.ValueOfMap(reflectMap{})
case fd.Message() != nil:
return protoreflect.ValueOfMessage(reflectMessage{
messageTypeKey: messageMeta{md: fd.Message()},
})
default:
return fd.Default()
}
}
// The transformation may leave Any messages in structured form.
// If so, convert them back to a raw-encoded form.
if fd.FullName() == genid.Any_Value_field_fullname {
if m, ok := v.(Message); ok {
b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m)
if err != nil {
panic("BUG: " + err.Error())
}
return protoreflect.ValueOfBytes(b)
}
}
return reflectValueOf(v)
}
func (m reflectMessage) Set(protoreflect.FieldDescriptor, protoreflect.Value) {
panic("invalid mutation of read-only message")
}
func (m reflectMessage) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
panic("invalid mutation of read-only message")
}
func (m reflectMessage) NewField(protoreflect.FieldDescriptor) protoreflect.Value {
panic("not implemented")
}
func (m reflectMessage) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
if m.Descriptor().Oneofs().ByName(od.Name()) != od {
panic("oneof descriptor does not belong to this message")
}
fds := od.Fields()
for i := 0; i < fds.Len(); i++ {
fd := fds.Get(i)
if _, ok := m[m.stringKey(fd)]; ok {
return fd
}
}
return nil
}
func (m reflectMessage) GetUnknown() protoreflect.RawFields {
var nums []protoreflect.FieldNumber
for k := range m {
if len(strings.Trim(k, "0123456789")) == 0 {
n, _ := strconv.ParseUint(k, 10, 32)
nums = append(nums, protoreflect.FieldNumber(n))
}
}
sort.Slice(nums, func(i, j int) bool { return nums[i] < nums[j] })
var raw protoreflect.RawFields
for _, num := range nums {
b, _ := m[strconv.FormatUint(uint64(num), 10)].(protoreflect.RawFields)
raw = append(raw, b...)
}
return raw
}
func (m reflectMessage) SetUnknown(protoreflect.RawFields) {
panic("invalid mutation of read-only message")
}
func (m reflectMessage) IsValid() bool {
invalid, _ := m[messageInvalidKey].(bool)
return !invalid
}
func (m reflectMessage) ProtoMethods() *protoiface.Methods {
return nil
}
type reflectMessageType struct{ protoreflect.MessageDescriptor }
func (t reflectMessageType) New() protoreflect.Message {
panic("not implemented")
}
func (t reflectMessageType) Zero() protoreflect.Message {
panic("not implemented")
}
func (t reflectMessageType) Descriptor() protoreflect.MessageDescriptor {
return t.MessageDescriptor
}
type reflectList struct{ v reflect.Value }
func (ls reflectList) Len() int {
if !ls.IsValid() {
return 0
}
return ls.v.Len()
}
func (ls reflectList) Get(i int) protoreflect.Value {
return reflectValueOf(ls.v.Index(i).Interface())
}
func (ls reflectList) Set(int, protoreflect.Value) {
panic("invalid mutation of read-only list")
}
func (ls reflectList) Append(protoreflect.Value) {
panic("invalid mutation of read-only list")
}
func (ls reflectList) AppendMutable() protoreflect.Value {
panic("invalid mutation of read-only list")
}
func (ls reflectList) Truncate(int) {
panic("invalid mutation of read-only list")
}
func (ls reflectList) NewElement() protoreflect.Value {
panic("not implemented")
}
func (ls reflectList) IsValid() bool {
return ls.v.IsValid()
}
type reflectMap struct{ v reflect.Value }
func (ms reflectMap) Len() int {
if !ms.IsValid() {
return 0
}
return ms.v.Len()
}
func (ms reflectMap) Range(f func(protoreflect.MapKey, protoreflect.Value) bool) {
if !ms.IsValid() {
return
}
ks := ms.v.MapKeys()
for _, k := range ks {
pk := reflectValueOf(k.Interface()).MapKey()
pv := reflectValueOf(ms.v.MapIndex(k).Interface())
if !f(pk, pv) {
return
}
}
}
func (ms reflectMap) Has(k protoreflect.MapKey) bool {
if !ms.IsValid() {
return false
}
return ms.v.MapIndex(reflect.ValueOf(k.Interface())).IsValid()
}
func (ms reflectMap) Clear(protoreflect.MapKey) {
panic("invalid mutation of read-only list")
}
func (ms reflectMap) Get(k protoreflect.MapKey) protoreflect.Value {
if !ms.IsValid() {
return protoreflect.Value{}
}
v := ms.v.MapIndex(reflect.ValueOf(k.Interface()))
if !v.IsValid() {
return protoreflect.Value{}
}
return reflectValueOf(v.Interface())
}
func (ms reflectMap) Set(protoreflect.MapKey, protoreflect.Value) {
panic("invalid mutation of read-only list")
}
func (ms reflectMap) Mutable(k protoreflect.MapKey) protoreflect.Value {
panic("invalid mutation of read-only list")
}
func (ms reflectMap) NewValue() protoreflect.Value {
panic("not implemented")
}
func (ms reflectMap) IsValid() bool {
return ms.v.IsValid()
}

View File

@@ -0,0 +1,684 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package protocmp
import (
"bytes"
"fmt"
"math"
"reflect"
"strings"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)
var (
enumReflectType = reflect.TypeOf(Enum{})
messageReflectType = reflect.TypeOf(Message{})
)
// FilterEnum filters opt to only be applicable on a standalone [Enum],
// singular fields of enums, list fields of enums, or map fields of enum values,
// where the enum is the same type as the specified enum.
//
// The Go type of the last path step may be an:
// - [Enum] for singular fields, elements of a repeated field,
// values of a map field, or standalone [Enum] values
// - [][Enum] for list fields
// - map[K][Enum] for map fields
// - any for a [Message] map entry value
//
// This must be used in conjunction with [Transform].
func FilterEnum(enum protoreflect.Enum, opt cmp.Option) cmp.Option {
return FilterDescriptor(enum.Descriptor(), opt)
}
// FilterMessage filters opt to only be applicable on a standalone [Message] values,
// singular fields of messages, list fields of messages, or map fields of
// message values, where the message is the same type as the specified message.
//
// The Go type of the last path step may be an:
// - [Message] for singular fields, elements of a repeated field,
// values of a map field, or standalone [Message] values
// - [][Message] for list fields
// - map[K][Message] for map fields
// - any for a [Message] map entry value
//
// This must be used in conjunction with [Transform].
func FilterMessage(message proto.Message, opt cmp.Option) cmp.Option {
return FilterDescriptor(message.ProtoReflect().Descriptor(), opt)
}
// FilterField filters opt to only be applicable on the specified field
// in the message. It panics if a field of the given name does not exist.
//
// The Go type of the last path step may be an:
// - T for singular fields
// - []T for list fields
// - map[K]T for map fields
// - any for a [Message] map entry value
//
// This must be used in conjunction with [Transform].
func FilterField(message proto.Message, name protoreflect.Name, opt cmp.Option) cmp.Option {
md := message.ProtoReflect().Descriptor()
return FilterDescriptor(mustFindFieldDescriptor(md, name), opt)
}
// FilterOneof filters opt to only be applicable on all fields within the
// specified oneof in the message. It panics if a oneof of the given name
// does not exist.
//
// The Go type of the last path step may be an:
// - T for singular fields
// - []T for list fields
// - map[K]T for map fields
// - any for a [Message] map entry value
//
// This must be used in conjunction with [Transform].
func FilterOneof(message proto.Message, name protoreflect.Name, opt cmp.Option) cmp.Option {
md := message.ProtoReflect().Descriptor()
return FilterDescriptor(mustFindOneofDescriptor(md, name), opt)
}
// FilterDescriptor ignores the specified descriptor.
//
// The following descriptor types may be specified:
// - [protoreflect.EnumDescriptor]
// - [protoreflect.MessageDescriptor]
// - [protoreflect.FieldDescriptor]
// - [protoreflect.OneofDescriptor]
//
// For the behavior of each, see the corresponding filter function.
// Since this filter accepts a [protoreflect.FieldDescriptor], it can be used
// to also filter for extension fields as a [protoreflect.ExtensionDescriptor]
// is just an alias to [protoreflect.FieldDescriptor].
//
// This must be used in conjunction with [Transform].
func FilterDescriptor(desc protoreflect.Descriptor, opt cmp.Option) cmp.Option {
f := newNameFilters(desc)
return cmp.FilterPath(f.Filter, opt)
}
// IgnoreEnums ignores all enums of the specified types.
// It is equivalent to FilterEnum(enum, cmp.Ignore()) for each enum.
//
// This must be used in conjunction with [Transform].
func IgnoreEnums(enums ...protoreflect.Enum) cmp.Option {
var ds []protoreflect.Descriptor
for _, e := range enums {
ds = append(ds, e.Descriptor())
}
return IgnoreDescriptors(ds...)
}
// IgnoreMessages ignores all messages of the specified types.
// It is equivalent to [FilterMessage](message, [cmp.Ignore]()) for each message.
//
// This must be used in conjunction with [Transform].
func IgnoreMessages(messages ...proto.Message) cmp.Option {
var ds []protoreflect.Descriptor
for _, m := range messages {
ds = append(ds, m.ProtoReflect().Descriptor())
}
return IgnoreDescriptors(ds...)
}
// IgnoreFields ignores the specified fields in the specified message.
// It is equivalent to [FilterField](message, name, [cmp.Ignore]()) for each field
// in the message.
//
// This must be used in conjunction with [Transform].
func IgnoreFields(message proto.Message, names ...protoreflect.Name) cmp.Option {
var ds []protoreflect.Descriptor
md := message.ProtoReflect().Descriptor()
for _, s := range names {
ds = append(ds, mustFindFieldDescriptor(md, s))
}
return IgnoreDescriptors(ds...)
}
// IgnoreOneofs ignores fields of the specified oneofs in the specified message.
// It is equivalent to FilterOneof(message, name, cmp.Ignore()) for each oneof
// in the message.
//
// This must be used in conjunction with [Transform].
func IgnoreOneofs(message proto.Message, names ...protoreflect.Name) cmp.Option {
var ds []protoreflect.Descriptor
md := message.ProtoReflect().Descriptor()
for _, s := range names {
ds = append(ds, mustFindOneofDescriptor(md, s))
}
return IgnoreDescriptors(ds...)
}
// IgnoreDescriptors ignores the specified set of descriptors.
// It is equivalent to [FilterDescriptor](desc, [cmp.Ignore]()) for each descriptor.
//
// This must be used in conjunction with [Transform].
func IgnoreDescriptors(descs ...protoreflect.Descriptor) cmp.Option {
return cmp.FilterPath(newNameFilters(descs...).Filter, cmp.Ignore())
}
func mustFindFieldDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.FieldDescriptor {
d := findDescriptor(md, s)
if fd, ok := d.(protoreflect.FieldDescriptor); ok && fd.TextName() == string(s) {
return fd
}
var suggestion string
switch d := d.(type) {
case protoreflect.FieldDescriptor:
suggestion = fmt.Sprintf("; consider specifying field %q instead", d.TextName())
case protoreflect.OneofDescriptor:
suggestion = fmt.Sprintf("; consider specifying oneof %q with IgnoreOneofs instead", d.Name())
}
panic(fmt.Sprintf("message %q has no field %q%s", md.FullName(), s, suggestion))
}
func mustFindOneofDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.OneofDescriptor {
d := findDescriptor(md, s)
if od, ok := d.(protoreflect.OneofDescriptor); ok && d.Name() == s {
return od
}
var suggestion string
switch d := d.(type) {
case protoreflect.OneofDescriptor:
suggestion = fmt.Sprintf("; consider specifying oneof %q instead", d.Name())
case protoreflect.FieldDescriptor:
suggestion = fmt.Sprintf("; consider specifying field %q with IgnoreFields instead", d.TextName())
}
panic(fmt.Sprintf("message %q has no oneof %q%s", md.FullName(), s, suggestion))
}
func findDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.Descriptor {
// Exact match.
if fd := md.Fields().ByTextName(string(s)); fd != nil {
return fd
}
if od := md.Oneofs().ByName(s); od != nil && !od.IsSynthetic() {
return od
}
// Best-effort match.
//
// It's a common user mistake to use the CamelCased field name as it appears
// in the generated Go struct. Instead of complaining that it doesn't exist,
// suggest the real protobuf name that the user may have desired.
normalize := func(s protoreflect.Name) string {
return strings.Replace(strings.ToLower(string(s)), "_", "", -1)
}
for i := 0; i < md.Fields().Len(); i++ {
if fd := md.Fields().Get(i); normalize(fd.Name()) == normalize(s) {
return fd
}
}
for i := 0; i < md.Oneofs().Len(); i++ {
if od := md.Oneofs().Get(i); normalize(od.Name()) == normalize(s) {
return od
}
}
return nil
}
type nameFilters struct {
names map[protoreflect.FullName]bool
}
func newNameFilters(descs ...protoreflect.Descriptor) *nameFilters {
f := &nameFilters{names: make(map[protoreflect.FullName]bool)}
for _, d := range descs {
switch d := d.(type) {
case protoreflect.EnumDescriptor:
f.names[d.FullName()] = true
case protoreflect.MessageDescriptor:
f.names[d.FullName()] = true
case protoreflect.FieldDescriptor:
f.names[d.FullName()] = true
case protoreflect.OneofDescriptor:
for i := 0; i < d.Fields().Len(); i++ {
f.names[d.Fields().Get(i).FullName()] = true
}
default:
panic("invalid descriptor type")
}
}
return f
}
func (f *nameFilters) Filter(p cmp.Path) bool {
vx, vy := p.Last().Values()
return (f.filterValue(vx) && f.filterValue(vy)) || f.filterFields(p)
}
func (f *nameFilters) filterFields(p cmp.Path) bool {
// Trim off trailing type-assertions so that the filter can match on the
// concrete value held within an interface value.
if _, ok := p.Last().(cmp.TypeAssertion); ok {
p = p[:len(p)-1]
}
// Filter for Message maps.
mi, ok := p.Index(-1).(cmp.MapIndex)
if !ok {
return false
}
ps := p.Index(-2)
if ps.Type() != messageReflectType {
return false
}
// Check field name.
vx, vy := ps.Values()
mx := vx.Interface().(Message)
my := vy.Interface().(Message)
k := mi.Key().String()
if f.filterFieldName(mx, k) && f.filterFieldName(my, k) {
return true
}
// Check field value.
vx, vy = mi.Values()
if f.filterFieldValue(vx) && f.filterFieldValue(vy) {
return true
}
return false
}
func (f *nameFilters) filterFieldName(m Message, k string) bool {
if _, ok := m[k]; !ok {
return true // treat missing fields as already filtered
}
var fd protoreflect.FieldDescriptor
switch mm := m[messageTypeKey].(messageMeta); {
case protoreflect.Name(k).IsValid():
fd = mm.md.Fields().ByTextName(k)
default:
fd = mm.xds[k]
}
if fd != nil {
return f.names[fd.FullName()]
}
return false
}
func (f *nameFilters) filterFieldValue(v reflect.Value) bool {
if !v.IsValid() {
return true // implies missing slice element or map entry
}
v = v.Elem() // map entries are always populated values
switch t := v.Type(); {
case t == enumReflectType || t == messageReflectType:
// Check for singular message or enum field.
return f.filterValue(v)
case t.Kind() == reflect.Slice && (t.Elem() == enumReflectType || t.Elem() == messageReflectType):
// Check for list field of enum or message type.
return f.filterValue(v.Index(0))
case t.Kind() == reflect.Map && (t.Elem() == enumReflectType || t.Elem() == messageReflectType):
// Check for map field of enum or message type.
return f.filterValue(v.MapIndex(v.MapKeys()[0]))
}
return false
}
func (f *nameFilters) filterValue(v reflect.Value) bool {
if !v.IsValid() {
return true // implies missing slice element or map entry
}
if !v.CanInterface() {
return false // implies unexported struct field
}
switch v := v.Interface().(type) {
case Enum:
return v.Descriptor() != nil && f.names[v.Descriptor().FullName()]
case Message:
return v.Descriptor() != nil && f.names[v.Descriptor().FullName()]
}
return false
}
// IgnoreDefaultScalars ignores singular scalars that are unpopulated or
// explicitly set to the default value.
// This option does not effect elements in a list or entries in a map.
//
// This must be used in conjunction with [Transform].
func IgnoreDefaultScalars() cmp.Option {
return cmp.FilterPath(func(p cmp.Path) bool {
// Filter for Message maps.
mi, ok := p.Index(-1).(cmp.MapIndex)
if !ok {
return false
}
ps := p.Index(-2)
if ps.Type() != messageReflectType {
return false
}
// Check whether both fields are default or unpopulated scalars.
vx, vy := ps.Values()
mx := vx.Interface().(Message)
my := vy.Interface().(Message)
k := mi.Key().String()
return isDefaultScalar(mx, k) && isDefaultScalar(my, k)
}, cmp.Ignore())
}
func isDefaultScalar(m Message, k string) bool {
if _, ok := m[k]; !ok {
return true
}
var fd protoreflect.FieldDescriptor
switch mm := m[messageTypeKey].(messageMeta); {
case protoreflect.Name(k).IsValid():
fd = mm.md.Fields().ByTextName(k)
default:
fd = mm.xds[k]
}
if fd == nil || !fd.Default().IsValid() {
return false
}
switch fd.Kind() {
case protoreflect.BytesKind:
v, ok := m[k].([]byte)
return ok && bytes.Equal(fd.Default().Bytes(), v)
case protoreflect.FloatKind:
v, ok := m[k].(float32)
return ok && equalFloat64(fd.Default().Float(), float64(v))
case protoreflect.DoubleKind:
v, ok := m[k].(float64)
return ok && equalFloat64(fd.Default().Float(), float64(v))
case protoreflect.EnumKind:
v, ok := m[k].(Enum)
return ok && fd.Default().Enum() == v.Number()
default:
return reflect.DeepEqual(fd.Default().Interface(), m[k])
}
}
func equalFloat64(x, y float64) bool {
return x == y || (math.IsNaN(x) && math.IsNaN(y))
}
// IgnoreEmptyMessages ignores messages that are empty or unpopulated.
// It applies to standalone [Message] values, singular message fields,
// list fields of messages, and map fields of message values.
//
// This must be used in conjunction with [Transform].
func IgnoreEmptyMessages() cmp.Option {
return cmp.FilterPath(func(p cmp.Path) bool {
vx, vy := p.Last().Values()
return (isEmptyMessage(vx) && isEmptyMessage(vy)) || isEmptyMessageFields(p)
}, cmp.Ignore())
}
func isEmptyMessageFields(p cmp.Path) bool {
// Filter for Message maps.
mi, ok := p.Index(-1).(cmp.MapIndex)
if !ok {
return false
}
ps := p.Index(-2)
if ps.Type() != messageReflectType {
return false
}
// Check field value.
vx, vy := mi.Values()
if isEmptyMessageFieldValue(vx) && isEmptyMessageFieldValue(vy) {
return true
}
return false
}
func isEmptyMessageFieldValue(v reflect.Value) bool {
if !v.IsValid() {
return true // implies missing slice element or map entry
}
v = v.Elem() // map entries are always populated values
switch t := v.Type(); {
case t == messageReflectType:
// Check singular field for empty message.
if !isEmptyMessage(v) {
return false
}
case t.Kind() == reflect.Slice && t.Elem() == messageReflectType:
// Check list field for all empty message elements.
for i := 0; i < v.Len(); i++ {
if !isEmptyMessage(v.Index(i)) {
return false
}
}
case t.Kind() == reflect.Map && t.Elem() == messageReflectType:
// Check map field for all empty message values.
for _, k := range v.MapKeys() {
if !isEmptyMessage(v.MapIndex(k)) {
return false
}
}
default:
return false
}
return true
}
func isEmptyMessage(v reflect.Value) bool {
if !v.IsValid() {
return true // implies missing slice element or map entry
}
if !v.CanInterface() {
return false // implies unexported struct field
}
if m, ok := v.Interface().(Message); ok {
for k := range m {
if k != messageTypeKey && k != messageInvalidKey {
return false
}
}
return true
}
return false
}
// IgnoreUnknown ignores unknown fields in all messages.
//
// This must be used in conjunction with [Transform].
func IgnoreUnknown() cmp.Option {
return cmp.FilterPath(func(p cmp.Path) bool {
// Filter for Message maps.
mi, ok := p.Index(-1).(cmp.MapIndex)
if !ok {
return false
}
ps := p.Index(-2)
if ps.Type() != messageReflectType {
return false
}
// Filter for unknown fields (which always have a numeric map key).
return strings.Trim(mi.Key().String(), "0123456789") == ""
}, cmp.Ignore())
}
// SortRepeated sorts repeated fields of the specified element type.
// The less function must be of the form "func(T, T) bool" where T is the
// Go element type for the repeated field kind.
//
// The element type T can be one of the following:
// - Go type for a protobuf scalar kind except for an enum
// (i.e., bool, int32, int64, uint32, uint64, float32, float64, string, and []byte)
// - E where E is a concrete enum type that implements [protoreflect.Enum]
// - M where M is a concrete message type that implement [proto.Message]
//
// This option only applies to repeated fields within a protobuf message.
// It does not operate on higher-order Go types that seem like a repeated field.
// For example, a []T outside the context of a protobuf message will not be
// handled by this option. To sort Go slices that are not repeated fields,
// consider using [github.com/google/go-cmp/cmp/cmpopts.SortSlices] instead.
//
// This must be used in conjunction with [Transform].
func SortRepeated(lessFunc any) cmp.Option {
t, ok := checkTTBFunc(lessFunc)
if !ok {
panic(fmt.Sprintf("invalid less function: %T", lessFunc))
}
var opt cmp.Option
var sliceType reflect.Type
switch vf := reflect.ValueOf(lessFunc); {
case t.Implements(enumV2Type):
et := reflect.Zero(t).Interface().(protoreflect.Enum).Type()
lessFunc = func(x, y Enum) bool {
vx := reflect.ValueOf(et.New(x.Number()))
vy := reflect.ValueOf(et.New(y.Number()))
return vf.Call([]reflect.Value{vx, vy})[0].Bool()
}
opt = FilterDescriptor(et.Descriptor(), cmpopts.SortSlices(lessFunc))
sliceType = reflect.SliceOf(enumReflectType)
case t.Implements(messageV2Type):
mt := reflect.Zero(t).Interface().(protoreflect.ProtoMessage).ProtoReflect().Type()
lessFunc = func(x, y Message) bool {
mx := mt.New().Interface()
my := mt.New().Interface()
proto.Merge(mx, x)
proto.Merge(my, y)
vx := reflect.ValueOf(mx)
vy := reflect.ValueOf(my)
return vf.Call([]reflect.Value{vx, vy})[0].Bool()
}
opt = FilterDescriptor(mt.Descriptor(), cmpopts.SortSlices(lessFunc))
sliceType = reflect.SliceOf(messageReflectType)
default:
switch t {
case reflect.TypeOf(bool(false)):
case reflect.TypeOf(int32(0)):
case reflect.TypeOf(int64(0)):
case reflect.TypeOf(uint32(0)):
case reflect.TypeOf(uint64(0)):
case reflect.TypeOf(float32(0)):
case reflect.TypeOf(float64(0)):
case reflect.TypeOf(string("")):
case reflect.TypeOf([]byte(nil)):
default:
panic(fmt.Sprintf("invalid element type: %v", t))
}
opt = cmpopts.SortSlices(lessFunc)
sliceType = reflect.SliceOf(t)
}
return cmp.FilterPath(func(p cmp.Path) bool {
// Filter to only apply to repeated fields within a message.
if t := p.Index(-1).Type(); t == nil || t != sliceType {
return false
}
if t := p.Index(-2).Type(); t == nil || t.Kind() != reflect.Interface {
return false
}
if t := p.Index(-3).Type(); t == nil || t != messageReflectType {
return false
}
return true
}, opt)
}
func checkTTBFunc(lessFunc any) (reflect.Type, bool) {
switch t := reflect.TypeOf(lessFunc); {
case t == nil:
return nil, false
case t.NumIn() != 2 || t.In(0) != t.In(1) || t.IsVariadic():
return nil, false
case t.NumOut() != 1 || t.Out(0) != reflect.TypeOf(false):
return nil, false
default:
return t.In(0), true
}
}
// SortRepeatedFields sorts the specified repeated fields.
// Sorting a repeated field is useful for treating the list as a multiset
// (i.e., a set where each value can appear multiple times).
// It panics if the field does not exist or is not a repeated field.
//
// The sort ordering is as follows:
// - Booleans are sorted where false is sorted before true.
// - Integers are sorted in ascending order.
// - Floating-point numbers are sorted in ascending order according to
// the total ordering defined by IEEE-754 (section 5.10).
// - Strings and bytes are sorted lexicographically in ascending order.
// - [Enum] values are sorted in ascending order based on its numeric value.
// - [Message] values are sorted according to some arbitrary ordering
// which is undefined and may change in future implementations.
//
// The ordering chosen for repeated messages is unlikely to be aesthetically
// preferred by humans. Consider using a custom sort function:
//
// FilterField(m, "foo_field", SortRepeated(func(x, y *foopb.MyMessage) bool {
// ... // user-provided definition for less
// }))
//
// This must be used in conjunction with [Transform].
func SortRepeatedFields(message proto.Message, names ...protoreflect.Name) cmp.Option {
var opts cmp.Options
md := message.ProtoReflect().Descriptor()
for _, name := range names {
fd := mustFindFieldDescriptor(md, name)
if !fd.IsList() {
panic(fmt.Sprintf("message field %q is not repeated", fd.FullName()))
}
var lessFunc any
switch fd.Kind() {
case protoreflect.BoolKind:
lessFunc = func(x, y bool) bool { return !x && y }
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
lessFunc = func(x, y int32) bool { return x < y }
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
lessFunc = func(x, y int64) bool { return x < y }
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
lessFunc = func(x, y uint32) bool { return x < y }
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
lessFunc = func(x, y uint64) bool { return x < y }
case protoreflect.FloatKind:
lessFunc = lessF32
case protoreflect.DoubleKind:
lessFunc = lessF64
case protoreflect.StringKind:
lessFunc = func(x, y string) bool { return x < y }
case protoreflect.BytesKind:
lessFunc = func(x, y []byte) bool { return bytes.Compare(x, y) < 0 }
case protoreflect.EnumKind:
lessFunc = func(x, y Enum) bool { return x.Number() < y.Number() }
case protoreflect.MessageKind, protoreflect.GroupKind:
lessFunc = func(x, y Message) bool { return x.String() < y.String() }
default:
panic(fmt.Sprintf("invalid kind: %v", fd.Kind()))
}
opts = append(opts, FilterDescriptor(fd, cmpopts.SortSlices(lessFunc)))
}
return opts
}
func lessF32(x, y float32) bool {
// Bit-wise implementation of IEEE-754, section 5.10.
xi := int32(math.Float32bits(x))
yi := int32(math.Float32bits(y))
xi ^= int32(uint32(xi>>31) >> 1)
yi ^= int32(uint32(yi>>31) >> 1)
return xi < yi
}
func lessF64(x, y float64) bool {
// Bit-wise implementation of IEEE-754, section 5.10.
xi := int64(math.Float64bits(x))
yi := int64(math.Float64bits(y))
xi ^= int64(uint64(xi>>63) >> 1)
yi ^= int64(uint64(yi>>63) >> 1)
return xi < yi
}

View File

@@ -0,0 +1,377 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package protocmp provides protobuf specific options for the
// [github.com/google/go-cmp/cmp] package.
//
// The primary feature is the [Transform] option, which transform [proto.Message]
// types into a [Message] map that is suitable for cmp to introspect upon.
// All other options in this package must be used in conjunction with [Transform].
package protocmp
import (
"reflect"
"strconv"
"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/internal/msgfmt"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/runtime/protoiface"
"google.golang.org/protobuf/runtime/protoimpl"
)
var (
enumV2Type = reflect.TypeOf((*protoreflect.Enum)(nil)).Elem()
messageV1Type = reflect.TypeOf((*protoiface.MessageV1)(nil)).Elem()
messageV2Type = reflect.TypeOf((*proto.Message)(nil)).Elem()
)
// Enum is a dynamic representation of a protocol buffer enum that is
// suitable for [cmp.Equal] and [cmp.Diff] to compare upon.
type Enum struct {
num protoreflect.EnumNumber
ed protoreflect.EnumDescriptor
}
// Descriptor returns the enum descriptor.
// It returns nil for a zero Enum value.
func (e Enum) Descriptor() protoreflect.EnumDescriptor {
return e.ed
}
// Number returns the enum value as an integer.
func (e Enum) Number() protoreflect.EnumNumber {
return e.num
}
// Equal reports whether e1 and e2 represent the same enum value.
func (e1 Enum) Equal(e2 Enum) bool {
if e1.ed.FullName() != e2.ed.FullName() {
return false
}
return e1.num == e2.num
}
// String returns the name of the enum value if known (e.g., "ENUM_VALUE"),
// otherwise it returns the formatted decimal enum number (e.g., "14").
func (e Enum) String() string {
if ev := e.ed.Values().ByNumber(e.num); ev != nil {
return string(ev.Name())
}
return strconv.Itoa(int(e.num))
}
const (
// messageTypeKey indicates the protobuf message type.
// The value type is always messageMeta.
// From the public API, it presents itself as only the type, but the
// underlying data structure holds arbitrary metadata about the message.
messageTypeKey = "@type"
// messageInvalidKey indicates that the message is invalid.
// The value is always the boolean "true".
messageInvalidKey = "@invalid"
)
type messageMeta struct {
m proto.Message
md protoreflect.MessageDescriptor
xds map[string]protoreflect.ExtensionDescriptor
}
func (t messageMeta) String() string {
return string(t.md.FullName())
}
func (t1 messageMeta) Equal(t2 messageMeta) bool {
return t1.md.FullName() == t2.md.FullName()
}
// Message is a dynamic representation of a protocol buffer message that is
// suitable for [cmp.Equal] and [cmp.Diff] to directly operate upon.
//
// Every populated known field (excluding extension fields) is stored in the map
// with the key being the short name of the field (e.g., "field_name") and
// the value determined by the kind and cardinality of the field.
//
// Singular scalars are represented by the same Go type as [protoreflect.Value],
// singular messages are represented by the [Message] type,
// singular enums are represented by the [Enum] type,
// list fields are represented as a Go slice, and
// map fields are represented as a Go map.
//
// Every populated extension field is stored in the map with the key being the
// full name of the field surrounded by brackets (e.g., "[extension.full.name]")
// and the value determined according to the same rules as known fields.
//
// Every unknown field is stored in the map with the key being the field number
// encoded as a decimal string (e.g., "132") and the value being the raw bytes
// of the encoded field (as the [protoreflect.RawFields] type).
//
// Message values must not be created by or mutated by users.
type Message map[string]any
// Unwrap returns the original message value.
// It returns nil if this Message was not constructed from another message.
func (m Message) Unwrap() proto.Message {
mm, _ := m[messageTypeKey].(messageMeta)
return mm.m
}
// Descriptor return the message descriptor.
// It returns nil for a zero Message value.
func (m Message) Descriptor() protoreflect.MessageDescriptor {
mm, _ := m[messageTypeKey].(messageMeta)
return mm.md
}
// ProtoReflect returns a reflective view of m.
// It only implements the read-only operations of [protoreflect.Message].
// Calling any mutating operations on m panics.
func (m Message) ProtoReflect() protoreflect.Message {
return (reflectMessage)(m)
}
// ProtoMessage is a marker method from the legacy message interface.
func (m Message) ProtoMessage() {}
// Reset is the required Reset method from the legacy message interface.
func (m Message) Reset() {
panic("invalid mutation of a read-only message")
}
// String returns a formatted string for the message.
// It is intended for human debugging and has no guarantees about its
// exact format or the stability of its output.
func (m Message) String() string {
switch {
case m == nil:
return "<nil>"
case !m.ProtoReflect().IsValid():
return "<invalid>"
default:
return msgfmt.Format(m)
}
}
type transformer struct {
resolver protoregistry.MessageTypeResolver
}
func newTransformer(opts ...option) *transformer {
xf := &transformer{
resolver: protoregistry.GlobalTypes,
}
for _, opt := range opts {
opt(xf)
}
return xf
}
type option func(*transformer)
// MessageTypeResolver overrides the resolver used for messages packed
// inside Any. The default is protoregistry.GlobalTypes, which is
// sufficient for all compiled-in Protobuf messages. Overriding the
// resolver is useful in tests that dynamically create Protobuf
// descriptors and messages, e.g. in proxies using dynamicpb.
func MessageTypeResolver(r protoregistry.MessageTypeResolver) option {
return func(xf *transformer) {
xf.resolver = r
}
}
// Transform returns a [cmp.Option] that converts each [proto.Message] to a [Message].
// The transformation does not mutate nor alias any converted messages.
//
// The google.protobuf.Any message is automatically unmarshaled such that the
// "value" field is a [Message] representing the underlying message value
// assuming it could be resolved and properly unmarshaled.
//
// This does not directly transform higher-order composite Go types.
// For example, []*foopb.Message is not transformed into []Message,
// but rather the individual message elements of the slice are transformed.
func Transform(opts ...option) cmp.Option {
xf := newTransformer(opts...)
// addrType returns a pointer to t if t isn't a pointer or interface.
addrType := func(t reflect.Type) reflect.Type {
if k := t.Kind(); k == reflect.Interface || k == reflect.Ptr {
return t
}
return reflect.PtrTo(t)
}
// TODO: Should this transform protoreflect.Enum types to Enum as well?
return cmp.FilterPath(func(p cmp.Path) bool {
ps := p.Last()
if isMessageType(addrType(ps.Type())) {
return true
}
// Check whether the concrete values of an interface both satisfy
// the Message interface.
if ps.Type().Kind() == reflect.Interface {
vx, vy := ps.Values()
if !vx.IsValid() || vx.IsNil() || !vy.IsValid() || vy.IsNil() {
return false
}
return isMessageType(addrType(vx.Elem().Type())) && isMessageType(addrType(vy.Elem().Type()))
}
return false
}, cmp.Transformer("protocmp.Transform", func(v any) Message {
// For user convenience, shallow copy the message value if necessary
// in order for it to implement the message interface.
if rv := reflect.ValueOf(v); rv.IsValid() && rv.Kind() != reflect.Ptr && !isMessageType(rv.Type()) {
pv := reflect.New(rv.Type())
pv.Elem().Set(rv)
v = pv.Interface()
}
m := protoimpl.X.MessageOf(v)
switch {
case m == nil:
return nil
case !m.IsValid():
return Message{messageTypeKey: messageMeta{m: m.Interface(), md: m.Descriptor()}, messageInvalidKey: true}
default:
return xf.transformMessage(m)
}
}))
}
func isMessageType(t reflect.Type) bool {
// Avoid transforming the Message itself.
if t == reflect.TypeOf(Message(nil)) || t == reflect.TypeOf((*Message)(nil)) {
return false
}
return t.Implements(messageV1Type) || t.Implements(messageV2Type)
}
func (xf *transformer) transformMessage(m protoreflect.Message) Message {
mx := Message{}
mt := messageMeta{m: m.Interface(), md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)}
// Handle known and extension fields.
m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
s := fd.TextName()
if fd.IsExtension() {
mt.xds[s] = fd
}
switch {
case fd.IsList():
mx[s] = xf.transformList(fd, v.List())
case fd.IsMap():
mx[s] = xf.transformMap(fd, v.Map())
default:
mx[s] = xf.transformSingular(fd, v)
}
return true
})
// Handle unknown fields.
for b := m.GetUnknown(); len(b) > 0; {
num, _, n := protowire.ConsumeField(b)
s := strconv.Itoa(int(num))
b2, _ := mx[s].(protoreflect.RawFields)
mx[s] = append(b2, b[:n]...)
b = b[n:]
}
// Expand Any messages.
if mt.md.FullName() == genid.Any_message_fullname {
s, _ := mx[string(genid.Any_TypeUrl_field_name)].(string)
b, _ := mx[string(genid.Any_Value_field_name)].([]byte)
mt, err := xf.resolver.FindMessageByURL(s)
if mt != nil && err == nil {
m2 := mt.New()
err := proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(b, m2.Interface())
if err == nil {
mx[string(genid.Any_Value_field_name)] = xf.transformMessage(m2)
}
}
}
mx[messageTypeKey] = mt
return mx
}
func (xf *transformer) transformList(fd protoreflect.FieldDescriptor, lv protoreflect.List) any {
t := protoKindToGoType(fd.Kind())
rv := reflect.MakeSlice(reflect.SliceOf(t), lv.Len(), lv.Len())
for i := 0; i < lv.Len(); i++ {
v := reflect.ValueOf(xf.transformSingular(fd, lv.Get(i)))
rv.Index(i).Set(v)
}
return rv.Interface()
}
func (xf *transformer) transformMap(fd protoreflect.FieldDescriptor, mv protoreflect.Map) any {
kfd := fd.MapKey()
vfd := fd.MapValue()
kt := protoKindToGoType(kfd.Kind())
vt := protoKindToGoType(vfd.Kind())
rv := reflect.MakeMapWithSize(reflect.MapOf(kt, vt), mv.Len())
mv.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
kv := reflect.ValueOf(xf.transformSingular(kfd, k.Value()))
vv := reflect.ValueOf(xf.transformSingular(vfd, v))
rv.SetMapIndex(kv, vv)
return true
})
return rv.Interface()
}
func (xf *transformer) transformSingular(fd protoreflect.FieldDescriptor, v protoreflect.Value) any {
switch fd.Kind() {
case protoreflect.EnumKind:
return Enum{num: v.Enum(), ed: fd.Enum()}
case protoreflect.MessageKind, protoreflect.GroupKind:
return xf.transformMessage(v.Message())
case protoreflect.BytesKind:
// The protoreflect API does not specify whether an empty bytes is
// guaranteed to be nil or not. Always return non-nil bytes to avoid
// leaking information about the concrete proto.Message implementation.
if len(v.Bytes()) == 0 {
return []byte{}
}
return v.Bytes()
default:
return v.Interface()
}
}
func protoKindToGoType(k protoreflect.Kind) reflect.Type {
switch k {
case protoreflect.BoolKind:
return reflect.TypeOf(bool(false))
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
return reflect.TypeOf(int32(0))
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
return reflect.TypeOf(int64(0))
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
return reflect.TypeOf(uint32(0))
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
return reflect.TypeOf(uint64(0))
case protoreflect.FloatKind:
return reflect.TypeOf(float32(0))
case protoreflect.DoubleKind:
return reflect.TypeOf(float64(0))
case protoreflect.StringKind:
return reflect.TypeOf(string(""))
case protoreflect.BytesKind:
return reflect.TypeOf([]byte(nil))
case protoreflect.EnumKind:
return reflect.TypeOf(Enum{})
case protoreflect.MessageKind, protoreflect.GroupKind:
return reflect.TypeOf(Message{})
default:
panic("invalid kind")
}
}

2
vendor/modules.txt vendored
View File

@@ -2353,6 +2353,7 @@ google.golang.org/protobuf/internal/filetype
google.golang.org/protobuf/internal/flags
google.golang.org/protobuf/internal/genid
google.golang.org/protobuf/internal/impl
google.golang.org/protobuf/internal/msgfmt
google.golang.org/protobuf/internal/order
google.golang.org/protobuf/internal/pragma
google.golang.org/protobuf/internal/set
@@ -2365,6 +2366,7 @@ google.golang.org/protobuf/reflect/protoreflect
google.golang.org/protobuf/reflect/protoregistry
google.golang.org/protobuf/runtime/protoiface
google.golang.org/protobuf/runtime/protoimpl
google.golang.org/protobuf/testing/protocmp
google.golang.org/protobuf/types/descriptorpb
google.golang.org/protobuf/types/gofeaturespb
google.golang.org/protobuf/types/known/anypb