package tx import ( "fmt" "google.golang.org/protobuf/encoding/protowire" "github.com/cosmos/cosmos-sdk/codec" "github.com/cosmos/cosmos-sdk/codec/unknownproto" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" "github.com/cosmos/cosmos-sdk/types/tx" ) // DefaultTxDecoder returns a default protobuf TxDecoder using the provided Marshaler. func DefaultTxDecoder(cdc codec.ProtoCodecMarshaler) sdk.TxDecoder { return func(txBytes []byte) (sdk.Tx, error) { // Make sure txBytes follow ADR-027. err := rejectNonADR027TxRaw(txBytes) if err != nil { return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error()) } var raw tx.TxRaw // reject all unknown proto fields in the root TxRaw err = unknownproto.RejectUnknownFieldsStrict(txBytes, &raw, cdc.InterfaceRegistry()) if err != nil { return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error()) } err = cdc.Unmarshal(txBytes, &raw) if err != nil { return nil, err } var body tx.TxBody // allow non-critical unknown fields in TxBody txBodyHasUnknownNonCriticals, err := unknownproto.RejectUnknownFields(raw.BodyBytes, &body, true, cdc.InterfaceRegistry()) if err != nil { return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error()) } err = cdc.Unmarshal(raw.BodyBytes, &body) if err != nil { return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error()) } var authInfo tx.AuthInfo // reject all unknown proto fields in AuthInfo err = unknownproto.RejectUnknownFieldsStrict(raw.AuthInfoBytes, &authInfo, cdc.InterfaceRegistry()) if err != nil { return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error()) } err = cdc.Unmarshal(raw.AuthInfoBytes, &authInfo) if err != nil { return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error()) } theTx := &tx.Tx{ Body: &body, AuthInfo: &authInfo, Signatures: raw.Signatures, } return &wrapper{ tx: theTx, bodyBz: raw.BodyBytes, authInfoBz: raw.AuthInfoBytes, txBodyHasUnknownNonCriticals: txBodyHasUnknownNonCriticals, }, nil } } // DefaultJSONTxDecoder returns a default protobuf JSON TxDecoder using the provided Marshaler. func DefaultJSONTxDecoder(cdc codec.ProtoCodecMarshaler) sdk.TxDecoder { return func(txBytes []byte) (sdk.Tx, error) { var theTx tx.Tx err := cdc.UnmarshalJSON(txBytes, &theTx) if err != nil { return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error()) } return &wrapper{ tx: &theTx, }, nil } } // rejectNonADR027TxRaw rejects txBytes that do not follow ADR-027. This is NOT // a generic ADR-027 checker, it only applies decoding TxRaw. Specifically, it // only checks that: // - field numbers are in ascending order (1, 2, and potentially multiple 3s), // - and varints are as short as possible. // All other ADR-027 edge cases (e.g. default values) are not applicable with // TxRaw. func rejectNonADR027TxRaw(txBytes []byte) error { // Make sure all fields are ordered in ascending order with this variable. prevTagNum := protowire.Number(0) for len(txBytes) > 0 { tagNum, wireType, m := protowire.ConsumeTag(txBytes) if m < 0 { return fmt.Errorf("invalid length; %w", protowire.ParseError(m)) } // TxRaw only has bytes fields. if wireType != protowire.BytesType { return fmt.Errorf("expected %d wire type, got %d", protowire.BytesType, wireType) } // Make sure fields are ordered in ascending order. if tagNum < prevTagNum { return fmt.Errorf("txRaw must follow ADR-027, got tagNum %d after tagNum %d", tagNum, prevTagNum) } prevTagNum = tagNum // All 3 fields of TxRaw have wireType == 2, so their next component // is a varint, so we can safely call ConsumeVarint here. // Byte structure: // Inner fields are verified in `DefaultTxDecoder` lengthPrefix, m := protowire.ConsumeVarint(txBytes[m:]) if m < 0 { return fmt.Errorf("invalid length; %w", protowire.ParseError(m)) } // We make sure that this varint is as short as possible. n := varintMinLength(lengthPrefix) if n != m { return fmt.Errorf("length prefix varint for tagNum %d is not as short as possible, read %d, only need %d", tagNum, m, n) } // Skip over the bytes that store fieldNumber and wireType bytes. _, _, m = protowire.ConsumeField(txBytes) if m < 0 { return fmt.Errorf("invalid length; %w", protowire.ParseError(m)) } txBytes = txBytes[m:] } return nil } // varintMinLength returns the minimum number of bytes necessary to encode an // uint using varint encoding. func varintMinLength(n uint64) int { switch { // Note: 1<