cosmos-sdk/codec/unknownproto/unknown_fields.go

440 lines
14 KiB
Go

package unknownproto
import (
"bytes"
"compress/gzip"
"errors"
"fmt"
"io/ioutil"
"reflect"
"strings"
"sync"
"github.com/gogo/protobuf/jsonpb"
"github.com/gogo/protobuf/proto"
"github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
"google.golang.org/protobuf/encoding/protowire"
"github.com/cosmos/cosmos-sdk/codec/types"
)
// bit11NonCritical masks bit 11 (counting from 1) of a field number, i.e. the value 1024.
// RejectUnknownFields treats an unknown field whose tag number has this bit set as
// non-critical; every other unknown field is critical.
const bit11NonCritical = 0x400

// descriptorIface is implemented by message types that can report their gzipped
// FileDescriptorProto bytes together with the index path locating the message's
// DescriptorProto inside that file (top-level index first, then nested-type indices).
type descriptorIface interface {
	Descriptor() ([]byte, []int)
}
// RejectUnknownFieldsStrict returns an error if bz contains any field that the proto.Message
// type msg does not declare, even if that field is marked non-critical (bit 11 set).
// Messages nested via google.protobuf.Any are traversed as well, so an AnyResolver must be
// supplied. bz is only inspected; nothing is deserialized into msg.
func RejectUnknownFieldsStrict(bz []byte, msg proto.Message, resolver jsonpb.AnyResolver) error {
	// Strict mode: non-critical unknown fields are rejected just like critical ones.
	const allowUnknownNonCriticals = false
	if _, err := RejectUnknownFields(bz, msg, allowUnknownNonCriticals, resolver); err != nil {
		return err
	}
	return nil
}
// RejectUnknownFields walks the protobuf wire encoding bz and returns an error when it finds
// a field that the proto.Message type msg does not declare, or a declared field whose wire
// type cannot carry the declared field type.
//
// Unknown fields whose tag number has bit 11 set (see bit11NonCritical) are non-critical.
// When allowUnknownNonCriticals is true they are skipped; otherwise they are rejected like
// any other unknown field. In either case hasUnknownNonCriticals reports whether at least one
// non-critical unknown field was encountered during traversal. Callers can use this flag to
// treat such messages differently in different security contexts (such as transaction signing).
//
// Nested messages are checked recursively, including messages packed inside
// google.protobuf.Any, whose TypeUrl is resolved through resolver. If an Any is encountered
// and resolver is nil, an error is returned instead of panicking. msg is used only as a type
// prototype: bz is never deserialized into it.
func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals bool, resolver jsonpb.AnyResolver) (hasUnknownNonCriticals bool, err error) {
	if len(bz) == 0 {
		return hasUnknownNonCriticals, nil
	}

	desc, ok := msg.(descriptorIface)
	if !ok {
		return hasUnknownNonCriticals, fmt.Errorf("%T does not have a Descriptor() method", msg)
	}

	fieldDescProtoFromTagNum, _, err := getDescriptorInfo(desc, msg)
	if err != nil {
		return hasUnknownNonCriticals, err
	}

	for len(bz) > 0 {
		tagNum, wireType, m := protowire.ConsumeTag(bz)
		if m < 0 {
			return hasUnknownNonCriticals, errors.New("invalid length")
		}

		fieldDescProto, known := fieldDescProtoFromTagNum[int32(tagNum)]
		if known {
			// Known field: its wire type must be able to carry the declared type.
			if !canEncodeType(wireType, fieldDescProto.GetType()) {
				return hasUnknownNonCriticals, &errMismatchedWireType{
					Type:         reflect.ValueOf(msg).Type().String(),
					TagNum:       tagNum,
					GotWireType:  wireType,
					WantWireType: protowire.Type(fieldDescProto.WireType()),
				}
			}
		} else {
			isCriticalField := tagNum&bit11NonCritical == 0
			if !isCriticalField {
				hasUnknownNonCriticals = true
			}
			// Critical unknown fields are always rejected; non-critical ones only when the
			// caller did not opt in to allowing them.
			if isCriticalField || !allowUnknownNonCriticals {
				return hasUnknownNonCriticals, &errUnknownField{
					Type:     reflect.ValueOf(msg).Type().String(),
					TagNum:   tagNum,
					WireType: wireType,
				}
			}
		}

		// Skip over the bytes that store the field number and wire type.
		bz = bz[m:]
		n := protowire.ConsumeFieldValue(tagNum, wireType, bz)
		if n < 0 {
			return hasUnknownNonCriticals, fmt.Errorf("could not consume field value for tagNum: %d, wireType: %q; %w",
				tagNum, wireTypeToString(wireType), protowire.ParseError(n))
		}
		fieldBytes := bz[:n]
		bz = bz[n:]

		// An allowed unknown non-critical field, or a scalar (numeric, bool or enum) field:
		// there is nothing nested to descend into.
		if fieldDescProto == nil || fieldDescProto.IsScalar() {
			continue
		}

		protoMessageName := fieldDescProto.GetTypeName()
		if protoMessageName == "" {
			switch typ := fieldDescProto.GetType(); typ {
			case descriptor.FieldDescriptorProto_TYPE_STRING, descriptor.FieldDescriptorProto_TYPE_BYTES:
				// At this point only TYPE_STRING is expected to be unregistered, since FieldDescriptorProto.IsScalar() returns false for
				// TYPE_BYTES and TYPE_STRING as per
				// https://github.com/gogo/protobuf/blob/5628607bb4c51c3157aacc3a50f0ab707582b805/protoc-gen-gogo/descriptor/descriptor.go#L95-L118
			default:
				return hasUnknownNonCriticals, fmt.Errorf("failed to get typename for message of type %v, can only be TYPE_STRING or TYPE_BYTES", typ)
			}
			continue
		}

		// Strip the length prefix so fieldBytes holds only the nested message. ConsumeFieldValue
		// has already validated the field, but guard anyway rather than slice with a negative offset.
		_, o := protowire.ConsumeVarint(fieldBytes)
		if o < 0 {
			return hasUnknownNonCriticals, fmt.Errorf("could not consume length prefix for tagNum: %d; %w",
				tagNum, protowire.ParseError(o))
		}
		fieldBytes = fieldBytes[o:]

		var nestedMsg proto.Message
		if protoMessageName == ".google.protobuf.Any" {
			var anyHasUnknownNonCriticals bool
			nestedMsg, fieldBytes, anyHasUnknownNonCriticals, err = unpackAnyForTypecheck(fieldBytes, allowUnknownNonCriticals, resolver)
			hasUnknownNonCriticals = hasUnknownNonCriticals || anyHasUnknownNonCriticals
		} else {
			// Descriptor type names are fully qualified with a leading '.', whereas
			// protoMessageForTypeName expects names like "testdata.TestVersionFD1".
			nestedMsg, err = protoMessageForTypeName(strings.TrimPrefix(protoMessageName, "."))
		}
		if err != nil {
			return hasUnknownNonCriticals, err
		}

		// Recursively traverse and typecheck the nested message.
		childHasUnknownNonCriticals, childErr := RejectUnknownFields(fieldBytes, nestedMsg, allowUnknownNonCriticals, resolver)
		hasUnknownNonCriticals = hasUnknownNonCriticals || childHasUnknownNonCriticals
		if childErr != nil {
			return hasUnknownNonCriticals, childErr
		}
	}

	return hasUnknownNonCriticals, nil
}

// unpackAnyForTypecheck type-checks the google.protobuf.Any encoded in anyBytes, then resolves
// its TypeUrl into a prototype message. It returns that prototype, the packed value bytes, and
// whether the Any wrapper itself contained non-critical unknown fields.
func unpackAnyForTypecheck(anyBytes []byte, allowUnknownNonCriticals bool, resolver jsonpb.AnyResolver) (proto.Message, []byte, bool, error) {
	// Firstly typecheck types.Any itself to ensure nothing snuck in.
	hasUnknownNonCriticals, err := RejectUnknownFields(anyBytes, (*types.Any)(nil), allowUnknownNonCriticals, resolver)
	if err != nil {
		return nil, nil, hasUnknownNonCriticals, err
	}

	// Only now decode the Any to extract the TypeUrl and the packed Value.
	anyMsg := new(types.Any)
	if err := proto.Unmarshal(anyBytes, anyMsg); err != nil {
		return nil, nil, hasUnknownNonCriticals, err
	}

	if resolver == nil {
		return nil, nil, hasUnknownNonCriticals, fmt.Errorf("cannot resolve %q: no AnyResolver provided", anyMsg.TypeUrl)
	}
	msg, err := resolver.Resolve(anyMsg.TypeUrl)
	if err != nil {
		return nil, nil, hasUnknownNonCriticals, err
	}
	return msg, anyMsg.Value, hasUnknownNonCriticals, nil
}
var (
	// protoMessageForTypeNameMu guards protoMessageForTypeNameCache.
	protoMessageForTypeNameMu sync.RWMutex
	// protoMessageForTypeNameCache memoizes protoMessageForTypeName results by message name.
	protoMessageForTypeNameCache = make(map[string]proto.Message)
)

// protoMessageForTypeName maps a fully qualified message name without a leading dot
// (e.g. testdata.TestVersionFD1) to a prototype value of the Go type registered under that
// name. The prototype is only used for type checking. It is the zero value of the registered
// type (a typed nil when that type is a pointer), so it must not be dereferenced.
// Results are cached.
func protoMessageForTypeName(protoMessageName string) (proto.Message, error) {
	protoMessageForTypeNameMu.RLock()
	cached, found := protoMessageForTypeNameCache[protoMessageName]
	protoMessageForTypeNameMu.RUnlock()
	if found {
		return cached, nil
	}

	goType := proto.MessageType(protoMessageName)
	if goType == nil {
		return nil, fmt.Errorf("failed to retrieve the message of type %q", protoMessageName)
	}

	// reflect.Zero yields the same zero value as reflect.New(goType).Elem().
	prototype, isMsg := reflect.Zero(goType).Interface().(proto.Message)
	if !isMsg {
		return nil, fmt.Errorf("%q does not implement proto.Message", protoMessageName)
	}

	protoMessageForTypeNameMu.Lock()
	defer protoMessageForTypeNameMu.Unlock()
	protoMessageForTypeNameCache[protoMessageName] = prototype
	return prototype, nil
}
// checks maps each protowire.Type (used as the array index) to the set of
// descriptor.FieldDescriptorProto_Type values that may legally be encoded with it.
// It is implemented this way so as to have constant time lookups and avoid the overhead
// from O(n) walking of switch. The change to using this mapping boosts throughput by about 200%.
var checks = [...]map[descriptor.FieldDescriptorProto_Type]bool{
	// "0 Varint: int32, int64, uint32, uint64, sint32, sint64, bool, enum"
	0: {
		descriptor.FieldDescriptorProto_TYPE_INT32:  true,
		descriptor.FieldDescriptorProto_TYPE_INT64:  true,
		descriptor.FieldDescriptorProto_TYPE_UINT32: true,
		descriptor.FieldDescriptorProto_TYPE_UINT64: true,
		descriptor.FieldDescriptorProto_TYPE_SINT32: true,
		descriptor.FieldDescriptorProto_TYPE_SINT64: true,
		descriptor.FieldDescriptorProto_TYPE_BOOL:   true,
		descriptor.FieldDescriptorProto_TYPE_ENUM:   true,
	},

	// "1 64-bit: fixed64, sfixed64, double"
	1: {
		descriptor.FieldDescriptorProto_TYPE_FIXED64:  true,
		descriptor.FieldDescriptorProto_TYPE_SFIXED64: true,
		descriptor.FieldDescriptorProto_TYPE_DOUBLE:   true,
	},

	// "2 Length-delimited: string, bytes, embedded messages, packed repeated fields"
	2: {
		descriptor.FieldDescriptorProto_TYPE_STRING:  true,
		descriptor.FieldDescriptorProto_TYPE_BYTES:   true,
		descriptor.FieldDescriptorProto_TYPE_MESSAGE: true,
		// The following types can be packed repeated.
		// ref: "Only repeated fields of primitive numeric types (types which use the varint, 32-bit, or 64-bit wire types) can be declared "packed"."
		// ref: https://developers.google.com/protocol-buffers/docs/encoding#packed
		// Packable varint types:
		descriptor.FieldDescriptorProto_TYPE_INT32:  true,
		descriptor.FieldDescriptorProto_TYPE_INT64:  true,
		descriptor.FieldDescriptorProto_TYPE_UINT32: true,
		descriptor.FieldDescriptorProto_TYPE_UINT64: true,
		descriptor.FieldDescriptorProto_TYPE_SINT32: true,
		descriptor.FieldDescriptorProto_TYPE_SINT64: true,
		descriptor.FieldDescriptorProto_TYPE_BOOL:   true,
		descriptor.FieldDescriptorProto_TYPE_ENUM:   true,
		// Packable 64-bit types:
		descriptor.FieldDescriptorProto_TYPE_FIXED64:  true,
		descriptor.FieldDescriptorProto_TYPE_SFIXED64: true,
		descriptor.FieldDescriptorProto_TYPE_DOUBLE:   true,
		// Packable 32-bit types. These were previously missing, so valid packed encodings
		// of repeated fixed32/sfixed32/float fields were rejected as mismatched wire types.
		descriptor.FieldDescriptorProto_TYPE_FIXED32:  true,
		descriptor.FieldDescriptorProto_TYPE_SFIXED32: true,
		descriptor.FieldDescriptorProto_TYPE_FLOAT:    true,
	},

	// "3 Start group: groups (deprecated)"
	3: {
		descriptor.FieldDescriptorProto_TYPE_GROUP: true,
	},

	// "4 End group: groups (deprecated)"
	4: {
		descriptor.FieldDescriptorProto_TYPE_GROUP: true,
	},

	// "5 32-bit: fixed32, sfixed32, float"
	5: {
		descriptor.FieldDescriptorProto_TYPE_FIXED32:  true,
		descriptor.FieldDescriptorProto_TYPE_SFIXED32: true,
		descriptor.FieldDescriptorProto_TYPE_FLOAT:    true,
	},
}
// canEncodeType reports whether a field declared with descType may legally appear on the
// wire with wireType, according to the checks table.
// See https://developers.google.com/protocol-buffers/docs/encoding#structure.
// Wire types outside the table's index range are never acceptable.
func canEncodeType(wireType protowire.Type, descType descriptor.FieldDescriptorProto_Type) bool {
	idx := int(wireType)
	if idx >= 0 && idx < len(checks) {
		return checks[idx][descType]
	}
	return false
}
// errMismatchedWireType reports that the field with tag number TagNum, in a message whose
// Go type name is Type, was encoded with wire type GotWireType. That wire type cannot carry
// the field's declared type; WantWireType is the wire type the field descriptor reports.
type errMismatchedWireType struct {
	Type         string
	GotWireType  protowire.Type
	WantWireType protowire.Type
	TagNum       protowire.Number
}

var _ error = (*errMismatchedWireType)(nil)

// Error implements the error interface; it returns the same text as String.
func (e *errMismatchedWireType) Error() string { return e.String() }

// String implements fmt.Stringer.
func (e *errMismatchedWireType) String() string {
	got := wireTypeToString(e.GotWireType)
	want := wireTypeToString(e.WantWireType)
	return fmt.Sprintf("Mismatched %q: {TagNum: %d, GotWireType: %q != WantWireType: %q}",
		e.Type, e.TagNum, got, want)
}
// wireTypeToString returns a human-readable name for the protobuf wire type wt. Values
// outside the six defined wire types (0-5) are rendered as "unknown type: <n>".
func wireTypeToString(wt protowire.Type) string {
	names := [...]string{
		0: "varint",
		1: "fixed64",
		2: "bytes",
		3: "start_group",
		4: "end_group",
		5: "fixed32",
	}
	if wt >= 0 && int(wt) < len(names) {
		return names[wt]
	}
	return fmt.Sprintf("unknown type: %d", wt)
}
// errUnknownField reports that a message whose Go type name is Type contained a field with
// tag number TagNum (encoded with wire type WireType) that the message does not declare.
type errUnknownField struct {
	Type     string
	TagNum   protowire.Number
	WireType protowire.Type
}

var _ error = (*errUnknownField)(nil)

// Error implements the error interface; it returns the same text as String.
func (e *errUnknownField) Error() string { return e.String() }

// String implements fmt.Stringer.
func (e *errUnknownField) String() string {
	wire := wireTypeToString(e.WireType)
	return fmt.Sprintf("errUnknownField %q: {TagNum: %d, WireType:%q}", e.Type, e.TagNum, wire)
}
var (
	// protoFileToDesc caches decoded FileDescriptorProtos, keyed by the raw gzipped
	// descriptor bytes they were decoded from.
	protoFileToDesc = make(map[string]*descriptor.FileDescriptorProto)
	// protoFileToDescMu guards protoFileToDesc.
	protoFileToDescMu sync.RWMutex
)

// unnestDesc follows the index path indices through mdescs. The first index selects a
// top-level message; each further index selects a nested type of the previously selected
// message. It panics if indices is empty or any index is out of range.
func unnestDesc(mdescs []*descriptor.DescriptorProto, indices []int) *descriptor.DescriptorProto {
	current := mdescs[indices[0]]
	for i := 1; i < len(indices); i++ {
		current = current.NestedType[indices[i]]
	}
	return current
}
// extractFileDescMessageDesc decodes the gzipped FileDescriptorProto returned by
// desc.Descriptor(). It returns that file descriptor together with the DescriptorProto of
// the message located at the returned index path.
//
// Invoking descriptor.ForMessage(proto.Message.(Descriptor).Descriptor()) is incredibly slow
// for every single message, thus the need for a hand-rolled custom version that's performant
// and cacheable. Decoded file descriptors are cached in protoFileToDesc, keyed by their raw
// gzipped bytes.
func extractFileDescMessageDesc(desc descriptorIface) (*descriptor.FileDescriptorProto, *descriptor.DescriptorProto, error) {
	gzippedPb, indices := desc.Descriptor()

	protoFileToDescMu.RLock()
	cached, ok := protoFileToDesc[string(gzippedPb)]
	protoFileToDescMu.RUnlock()
	if ok {
		return cached, unnestDesc(cached.MessageType, indices), nil
	}

	// Time to gunzip the content of the FileDescriptor and then proto unmarshal it.
	gzr, err := gzip.NewReader(bytes.NewReader(gzippedPb))
	if err != nil {
		return nil, nil, err
	}
	protoBlob, err := ioutil.ReadAll(gzr)
	// Always close the gzip reader. Its Close error is reported only if the read itself
	// succeeded, so the original read error is not masked.
	if closeErr := gzr.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		return nil, nil, err
	}

	fdesc := new(descriptor.FileDescriptorProto)
	if err := proto.Unmarshal(protoBlob, fdesc); err != nil {
		return nil, nil, err
	}

	// Now cache the FileDescriptor.
	protoFileToDescMu.Lock()
	protoFileToDesc[string(gzippedPb)] = fdesc
	protoFileToDescMu.Unlock()

	// Unnest the type if necessary.
	return fdesc, unnestDesc(fdesc.MessageType, indices), nil
}
// descriptorMatch is the cached per-Go-type result of getDescriptorInfo.
type descriptorMatch struct {
	// cache maps each field number of the message to its field descriptor.
	cache map[int32]*descriptor.FieldDescriptorProto
	// desc is the message's own descriptor.
	desc *descriptor.DescriptorProto
}

var (
	// descprotoCacheMu guards descprotoCache.
	descprotoCacheMu sync.RWMutex
	// descprotoCache memoizes getDescriptorInfo, keyed by the Go type of the message.
	descprotoCache = make(map[reflect.Type]*descriptorMatch)
)

// getDescriptorInfo returns a map from field number to field descriptor for the message
// type of msg, plus the message descriptor itself. The cache is keyed by msg's Go type but
// filled from desc, so desc is expected to be msg itself viewed as a descriptorIface
// (which is how RejectUnknownFields calls it).
func getDescriptorInfo(desc descriptorIface, msg proto.Message) (map[int32]*descriptor.FieldDescriptorProto, *descriptor.DescriptorProto, error) {
	msgType := reflect.ValueOf(msg).Type()

	descprotoCacheMu.RLock()
	hit, found := descprotoCache[msgType]
	descprotoCacheMu.RUnlock()
	if found {
		return hit.cache, hit.desc, nil
	}

	// Cache miss: compute the field index and remember it.
	_, msgDesc, err := extractFileDescMessageDesc(desc)
	if err != nil {
		return nil, nil, err
	}

	byNumber := make(map[int32]*descriptor.FieldDescriptorProto, len(msgDesc.Field))
	for _, fd := range msgDesc.Field {
		byNumber[fd.GetNumber()] = fd
	}

	entry := &descriptorMatch{cache: byNumber, desc: msgDesc}
	descprotoCacheMu.Lock()
	descprotoCache[msgType] = entry
	descprotoCacheMu.Unlock()

	return entry.cache, entry.desc, nil
}
// DefaultAnyResolver is a default implementation of AnyResolver which uses
// the default encoding of type URLs as specified by the protobuf specification.
type DefaultAnyResolver struct{}

var _ jsonpb.AnyResolver = DefaultAnyResolver{}

// Resolve implements jsonpb.AnyResolver. It takes the message name from the part of typeURL
// after its last '/' (or the whole string when there is no '/'). It then looks that name up
// with proto.MessageType and returns a new, empty instance of the registered type.
func (DefaultAnyResolver) Resolve(typeURL string) (proto.Message, error) {
	// LastIndex returns -1 when there is no slash, so +1 keeps the whole string.
	mname := typeURL[strings.LastIndex(typeURL, "/")+1:]

	registered := proto.MessageType(mname)
	if registered == nil {
		return nil, fmt.Errorf("unknown message type %q", mname)
	}
	// registered is expected to be a pointer type (*T): allocate a fresh T and return its address.
	return reflect.New(registered.Elem()).Interface().(proto.Message), nil
}