Reject unknown fields in TxDecoder and sign mode handlers (#6883)

* WIP on unknown field rejection in TxDecoder

* WIP on unknown field rejection in TxDecoder

* WIP

* WIP

* WIP

* WIP

* Fix bugs with RejectUnknownFields

* Fix tests

* Fix bug and update docs

* Lint

* Add tests

* Add unknown field tests

* Lint

* Address review comments
This commit is contained in:
Aaron Craelius 2020-08-03 15:47:25 -04:00 committed by GitHub
parent 57cd7d62b3
commit 6d937443b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 2598 additions and 968 deletions

View File

@ -65,7 +65,7 @@ type TestSuite struct {
func (s *TestSuite) SetupSuite() {
encCfg := simapp.MakeEncodingConfig()
s.encCfg = encCfg
s.protoCfg = tx.NewTxConfig(codec.NewProtoCodec(encCfg.InterfaceRegistry), std.DefaultPublicKeyCodec{}, tx.DefaultSignModeHandler())
s.protoCfg = tx.NewTxConfig(codec.NewProtoCodec(encCfg.InterfaceRegistry), std.DefaultPublicKeyCodec{}, tx.DefaultSignModes)
s.aminoCfg = types3.StdTxConfig{Cdc: encCfg.Amino}
}

View File

@ -53,11 +53,10 @@ func benchmarkRejectUnknownFields(b *testing.B, parallel bool) {
b.ReportAllocs()
if !parallel {
ckr := new(unknownproto.Checker)
b.ResetTimer()
for i := 0; i < b.N; i++ {
n1A := new(testdata.Nested1A)
if err := ckr.RejectUnknownFields(n1BBlob, n1A); err == nil {
if err := unknownproto.RejectUnknownFieldsStrict(n1BBlob, n1A); err == nil {
b.Fatal("expected an error")
}
b.SetBytes(int64(len(n1BBlob)))
@ -66,11 +65,10 @@ func benchmarkRejectUnknownFields(b *testing.B, parallel bool) {
var mu sync.Mutex
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ckr := new(unknownproto.Checker)
for pb.Next() {
// To simulate the conditions of multiple transactions being processed in parallel.
n1A := new(testdata.Nested1A)
if err := ckr.RejectUnknownFields(n1BBlob, n1A); err == nil {
if err := unknownproto.RejectUnknownFieldsStrict(n1BBlob, n1A); err == nil {
b.Fatal("expected an error")
}
mu.Lock()

View File

@ -6,22 +6,18 @@ a) Unknown fields in the stream -- this is indicative of mismatched services, pe
b) Mismatched wire types for a field -- this is indicative of mismatched services
Its API signature is similar to proto.Unmarshal([]byte, proto.Message) as
Its API signature is similar to proto.Unmarshal([]byte, proto.Message) in the strict case
ckr := new(unknownproto.Checker)
if err := ckr.RejectUnknownFields(protoBlob, protoMessage); err != nil {
if err := RejectUnknownFieldsStrict(protoBlob, protoMessage, false); err != nil {
// Handle the error.
}
and ideally should be added before invoking proto.Unmarshal, if you'd like to enforce the features mentioned above.
By default, for security we report every single field that's unknown, whether a non-critical field or not. To customize
this behavior, please create a Checker and set the AllowUnknownNonCriticals to true, for example:
this behavior, please set the boolean parameter allowUnknownNonCriticals to true to RejectUnknownFields:
ckr := &unknownproto.Checker{
AllowUnknownNonCriticals: true,
}
if err := ckr.RejectUnknownFields(protoBlob, protoMessage); err != nil {
if err := RejectUnknownFields(protoBlob, protoMessage, true); err != nil {
// Handle the error.
}
*/

View File

@ -22,30 +22,37 @@ type descriptorIface interface {
Descriptor() ([]byte, []int)
}
type Checker struct {
// AllowUnknownNonCriticals when set will skip over non-critical fields that are unknown.
AllowUnknownNonCriticals bool
// RejectUnknownFieldsStrict rejects any bytes bz with an error that has unknown fields for the provided proto.Message type.
// This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
func RejectUnknownFieldsStrict(bz []byte, msg proto.Message) error {
_, err := RejectUnknownFields(bz, msg, false)
return err
}
func (ckr *Checker) RejectUnknownFields(b []byte, msg proto.Message) error {
if len(b) == 0 {
return nil
// RejectUnknownFields rejects any bytes bz with an error that has unknown fields for the provided proto.Message type with an
// option to allow non-critical fields (specified as those fields with bit 11) to pass through. In either case, the
// hasUnknownNonCriticals will be set to true if non-critical fields were encountered during traversal. This flag can be
// used to treat a message with non-critical field different in different security contexts (such as transaction signing).
// This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals bool) (hasUnknownNonCriticals bool, err error) {
if len(bz) == 0 {
return hasUnknownNonCriticals, nil
}
desc, ok := msg.(descriptorIface)
if !ok {
return fmt.Errorf("%T does not have a Descriptor() method", msg)
return hasUnknownNonCriticals, fmt.Errorf("%T does not have a Descriptor() method", msg)
}
fieldDescProtoFromTagNum, _, err := getDescriptorInfo(desc, msg)
if err != nil {
return err
return hasUnknownNonCriticals, err
}
for len(b) > 0 {
tagNum, wireType, n := protowire.ConsumeField(b)
if n < 0 {
return errors.New("invalid length")
for len(bz) > 0 {
tagNum, wireType, m := protowire.ConsumeTag(bz)
if m < 0 {
return hasUnknownNonCriticals, errors.New("invalid length")
}
fieldDescProto, ok := fieldDescProtoFromTagNum[int32(tagNum)]
@ -53,7 +60,7 @@ func (ckr *Checker) RejectUnknownFields(b []byte, msg proto.Message) error {
case ok:
// Assert that the wireTypes match.
if !canEncodeType(wireType, fieldDescProto.GetType()) {
return &errMismatchedWireType{
return hasUnknownNonCriticals, &errMismatchedWireType{
Type: reflect.ValueOf(msg).Type().String(),
TagNum: tagNum,
GotWireType: wireType,
@ -62,9 +69,15 @@ func (ckr *Checker) RejectUnknownFields(b []byte, msg proto.Message) error {
}
default:
if !ckr.AllowUnknownNonCriticals || tagNum&bit11NonCritical == 0 {
isCriticalField := tagNum&bit11NonCritical == 0
if !isCriticalField {
hasUnknownNonCriticals = true
}
if isCriticalField || !allowUnknownNonCriticals {
// The tag is critical, so report it.
return &errUnknownField{
return hasUnknownNonCriticals, &errUnknownField{
Type: reflect.ValueOf(msg).Type().String(),
TagNum: tagNum,
WireType: wireType,
@ -72,9 +85,11 @@ func (ckr *Checker) RejectUnknownFields(b []byte, msg proto.Message) error {
}
}
// Skip over the 2 bytes that store fieldNumber and wireType bytes.
fieldBytes := b[2:n]
b = b[n:]
// Skip over the bytes that store fieldNumber and wireType bytes.
bz = bz[m:]
n := protowire.ConsumeFieldValue(tagNum, wireType, bz)
fieldBytes := bz[:n]
bz = bz[n:]
// An unknown but non-critical field or just a scalar type (aka *INT and BYTES like).
if fieldDescProto == nil || fieldDescProto.IsScalar() {
@ -89,22 +104,28 @@ func (ckr *Checker) RejectUnknownFields(b []byte, msg proto.Message) error {
// TYPE_BYTES and TYPE_STRING as per
// https://github.com/gogo/protobuf/blob/5628607bb4c51c3157aacc3a50f0ab707582b805/protoc-gen-gogo/descriptor/descriptor.go#L95-L118
default:
return fmt.Errorf("failed to get typename for message of type %v, can only be TYPE_STRING or TYPE_BYTES", typ)
return hasUnknownNonCriticals, fmt.Errorf("failed to get typename for message of type %v, can only be TYPE_STRING or TYPE_BYTES", typ)
}
continue
}
// Let's recursively traverse and typecheck the field.
// consume length prefix of nested message
_, o := protowire.ConsumeVarint(fieldBytes)
fieldBytes = fieldBytes[o:]
if protoMessageName == ".google.protobuf.Any" {
// Firstly typecheck types.Any to ensure nothing snuck in.
if err := ckr.RejectUnknownFields(fieldBytes, (*types.Any)(nil)); err != nil {
return err
hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, (*types.Any)(nil), allowUnknownNonCriticals)
hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
if err != nil {
return hasUnknownNonCriticals, err
}
// And finally we can extract the TypeURL containing the protoMessageName.
any := new(types.Any)
if err := proto.Unmarshal(fieldBytes, any); err != nil {
return err
return hasUnknownNonCriticals, err
}
protoMessageName = any.TypeUrl
fieldBytes = any.Value
@ -112,14 +133,17 @@ func (ckr *Checker) RejectUnknownFields(b []byte, msg proto.Message) error {
msg, err := protoMessageForTypeName(protoMessageName[1:])
if err != nil {
return err
return hasUnknownNonCriticals, err
}
if err := ckr.RejectUnknownFields(fieldBytes, msg); err != nil {
return err
hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, msg, allowUnknownNonCriticals)
hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
if err != nil {
return hasUnknownNonCriticals, err
}
}
return nil
return hasUnknownNonCriticals, nil
}
var protoMessageForTypeNameMu sync.RWMutex

View File

@ -4,6 +4,8 @@ import (
"reflect"
"testing"
"github.com/stretchr/testify/require"
"github.com/gogo/protobuf/proto"
"github.com/cosmos/cosmos-sdk/codec/types"
@ -17,6 +19,7 @@ func TestRejectUnknownFieldsRepeated(t *testing.T) {
recv proto.Message
wantErr error
allowUnknownNonCriticals bool
hasUnknownNonCriticals bool
}{
{
name: "Unknown field in midst of repeated values",
@ -172,6 +175,7 @@ func TestRejectUnknownFieldsRepeated(t *testing.T) {
TagNum: 1031,
WireType: 2,
},
hasUnknownNonCriticals: true,
},
{
name: "Unknown field in midst of repeated values, non-critical field ignored",
@ -213,8 +217,9 @@ func TestRejectUnknownFieldsRepeated(t *testing.T) {
},
},
},
recv: new(testdata.TestVersion1),
wantErr: nil,
recv: new(testdata.TestVersion1),
wantErr: nil,
hasUnknownNonCriticals: true,
},
}
@ -225,11 +230,9 @@ func TestRejectUnknownFieldsRepeated(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ckr := &Checker{AllowUnknownNonCriticals: tt.allowUnknownNonCriticals}
gotErr := ckr.RejectUnknownFields(protoBlob, tt.recv)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Fatalf("Error mismatch\nGot:\n%v\n\nWant:\n%v", gotErr, tt.wantErr)
}
hasUnknownNonCriticals, gotErr := RejectUnknownFields(protoBlob, tt.recv, tt.allowUnknownNonCriticals)
require.Equal(t, tt.wantErr, gotErr)
require.Equal(t, tt.hasUnknownNonCriticals, hasUnknownNonCriticals)
})
}
}
@ -263,7 +266,7 @@ func TestRejectUnknownFields_allowUnknownNonCriticals(t *testing.T) {
wantErr: nil,
},
{
name: "Unkown fields that are critical, but with allowUnknownNonCriticals set",
name: "Unknown fields that are critical, but with allowUnknownNonCriticals set",
allowUnknownNonCriticals: true,
in: &testdata.Customer2{
Id: 289,
@ -285,9 +288,8 @@ func TestRejectUnknownFields_allowUnknownNonCriticals(t *testing.T) {
t.Fatalf("Failed to marshal input: %v", err)
}
ckr := &Checker{AllowUnknownNonCriticals: tt.allowUnknownNonCriticals}
c1 := new(testdata.Customer1)
gotErr := ckr.RejectUnknownFields(blob, c1)
_, gotErr := RejectUnknownFields(blob, c1, tt.allowUnknownNonCriticals)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Fatalf("Error mismatch\nGot:\n%s\n\nWant:\n%s", gotErr, tt.wantErr)
}
@ -498,8 +500,7 @@ func TestRejectUnknownFieldsNested(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ckr := new(Checker)
gotErr := ckr.RejectUnknownFields(protoBlob, tt.recv)
gotErr := RejectUnknownFieldsStrict(protoBlob, tt.recv)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Fatalf("Error mismatch\nGot:\n%s\n\nWant:\n%s", gotErr, tt.wantErr)
}
@ -652,8 +653,7 @@ func TestRejectUnknownFieldsFlat(t *testing.T) {
}
c1 := new(testdata.Customer1)
ckr := new(Checker)
gotErr := ckr.RejectUnknownFields(blob, c1)
gotErr := RejectUnknownFieldsStrict(blob, c1)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Fatalf("Error mismatch\nGot:\n%s\n\nWant:\n%s", gotErr, tt.wantErr)
}
@ -738,8 +738,7 @@ func TestMismatchedTypes_Nested(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ckr := new(Checker)
gotErr := ckr.RejectUnknownFields(protoBlob, tt.recv)
_, gotErr := RejectUnknownFields(protoBlob, tt.recv, false)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Fatalf("Error mismatch\nGot:\n%s\n\nWant:\n%s", gotErr, tt.wantErr)
}

1
go.mod
View File

@ -12,7 +12,6 @@ require (
github.com/cosmos/go-bip39 v0.0.0-20180819234021-555e2067c45d
github.com/cosmos/ledger-cosmos-go v0.11.1
github.com/enigmampc/btcutil v1.0.3-0.20200723161021-e2fb6adb2a25
github.com/gibson042/canonicaljson-go v1.0.3
github.com/gogo/protobuf v1.3.1
github.com/golang/mock v1.4.4
github.com/golang/protobuf v1.4.2

4
go.sum
View File

@ -144,8 +144,6 @@ github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/gibson042/canonicaljson-go v1.0.3 h1:EAyF8L74AWabkyUmrvEFHEt/AGFQeD6RfwbAuf0j1bI=
github.com/gibson042/canonicaljson-go v1.0.3/go.mod h1:DsLpJTThXyGNO+KZlI85C1/KDcImpP67k/RKVjcaEqo=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
@ -480,8 +478,6 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An
github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE=
github.com/spf13/viper v1.6.2/go.mod h1:t3iDnF5Jlj76alVNuyFBk5oUMCvsrkbvZK0WQdfDi5k=
github.com/spf13/viper v1.6.3/go.mod h1:jUMtyi0/lB5yZH/FjyGAoH7IMNrIhlBf6pXZmbMDvzw=
github.com/spf13/viper v1.7.0 h1:xVKxvI7ouOI5I+U9s2eeiUfMaWBVoXA3AWskkrqK0VM=
github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg=
github.com/spf13/viper v1.7.1 h1:pM5oEahlgWv/WnHXpgbKz7iLIxRf65tye2Ci+XFK5sk=
github.com/spf13/viper v1.7.1/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg=
github.com/streadway/amqp v0.0.0-20190404075320-75d898a42a94/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw=

View File

@ -14,7 +14,7 @@ func MakeEncodingConfig() EncodingConfig {
amino := codec.New()
interfaceRegistry := types.NewInterfaceRegistry()
marshaler := codec.NewHybridCodec(amino, interfaceRegistry)
txGen := tx.NewTxConfig(codec.NewProtoCodec(interfaceRegistry), std.DefaultPublicKeyCodec{}, tx.DefaultSignModeHandler())
txGen := tx.NewTxConfig(codec.NewProtoCodec(interfaceRegistry), std.DefaultPublicKeyCodec{}, tx.DefaultSignModes)
return EncodingConfig{
InterfaceRegistry: interfaceRegistry,

View File

@ -4,8 +4,9 @@ import (
"container/list"
"errors"
"github.com/cosmos/cosmos-sdk/types/kv"
dbm "github.com/tendermint/tm-db"
"github.com/cosmos/cosmos-sdk/types/kv"
)
// Iterates over iterKVCache items.

File diff suppressed because it is too large Load Diff

View File

@ -3,6 +3,7 @@ package testdata;
import "gogoproto/gogo.proto";
import "google/protobuf/any.proto";
import "cosmos/tx/tx.proto";
option go_package = "github.com/cosmos/cosmos-sdk/testutil/testdata";
@ -341,3 +342,28 @@ message AnyWithExtra {
int64 b = 3;
int64 c = 4;
}
message TestUpdatedTxRaw {
bytes body_bytes = 1;
bytes auth_info_bytes = 2;
repeated bytes signatures = 3;
bytes new_field_5 = 5;
bytes new_field_1024 = 1024;
}
message TestUpdatedTxBody {
repeated google.protobuf.Any messages = 1;
string memo = 2;
int64 timeout_height = 3;
uint64 some_new_field = 4;
string some_new_field_non_critical_field = 1050;
repeated google.protobuf.Any extension_options = 1023;
repeated google.protobuf.Any non_critical_extension_options = 2047;
}
message TestUpdatedAuthInfo {
repeated cosmos.tx.SignerInfo signer_infos = 1;
cosmos.tx.Fee fee = 2;
bytes new_field_3 = 3;
bytes new_field_1024 = 1024;
}

View File

@ -1,66 +0,0 @@
package direct
import (
"fmt"
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
sdk "github.com/cosmos/cosmos-sdk/types"
types "github.com/cosmos/cosmos-sdk/types/tx"
"github.com/cosmos/cosmos-sdk/x/auth/signing"
)
// ProtoTx defines an interface which protobuf transactions must implement for
// signature verification via SignModeDirect
type ProtoTx interface {
// GetBodyBytes returns the raw serialized bytes for TxBody
GetBodyBytes() []byte
// GetBodyBytes returns the raw serialized bytes for AuthInfo
GetAuthInfoBytes() []byte
}
// ModeHandler defines the SIGN_MODE_DIRECT SignModeHandler
type ModeHandler struct{}
var _ signing.SignModeHandler = ModeHandler{}
// DefaultMode implements SignModeHandler.DefaultMode
func (ModeHandler) DefaultMode() signingtypes.SignMode {
return signingtypes.SignMode_SIGN_MODE_DIRECT
}
// Modes implements SignModeHandler.Modes
func (ModeHandler) Modes() []signingtypes.SignMode {
return []signingtypes.SignMode{signingtypes.SignMode_SIGN_MODE_DIRECT}
}
// GetSignBytes implements SignModeHandler.GetSignBytes
func (ModeHandler) GetSignBytes(mode signingtypes.SignMode, data signing.SignerData, tx sdk.Tx) ([]byte, error) {
if mode != signingtypes.SignMode_SIGN_MODE_DIRECT {
return nil, fmt.Errorf("expected %s, got %s", signingtypes.SignMode_SIGN_MODE_DIRECT, mode)
}
protoTx, ok := tx.(ProtoTx)
if !ok {
return nil, fmt.Errorf("can only get direct sign bytes for a ProtoTx, got %T", tx)
}
bodyBz := protoTx.GetBodyBytes()
authInfoBz := protoTx.GetAuthInfoBytes()
return SignBytes(bodyBz, authInfoBz, data.ChainID, data.AccountNumber, data.AccountSequence)
}
// SignBytes returns the SIGN_MODE_DIRECT sign bytes for the provided TxBody bytes, AuthInfo bytes, chain ID,
// account number and sequence.
func SignBytes(bodyBytes, authInfoBytes []byte, chainID string, accnum, sequence uint64) ([]byte, error) {
signDoc := types.SignDoc{
BodyBytes: bodyBytes,
AuthInfoBytes: authInfoBytes,
ChainId: chainID,
AccountNumber: accnum,
AccountSequence: sequence,
}
return signDoc.Marshal()
}

View File

@ -18,7 +18,7 @@ func MakeTestHandlerMap() signing.SignModeHandler {
return signing.NewSignModeHandlerMap(
signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON,
[]signing.SignModeHandler{
authtypes.LegacyAminoJSONHandler{},
authtypes.NewStdTxSignModeHandler(),
},
)
}
@ -58,7 +58,7 @@ func TestHandlerMap_GetSignBytes(t *testing.T) {
)
handler := MakeTestHandlerMap()
aminoJSONHandler := authtypes.LegacyAminoJSONHandler{}
aminoJSONHandler := authtypes.NewStdTxSignModeHandler()
signingData := signing.SignerData{
ChainID: chainId,

View File

@ -15,7 +15,6 @@ import (
"github.com/cosmos/cosmos-sdk/types/tx"
"github.com/cosmos/cosmos-sdk/types/tx/signing"
authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing"
"github.com/cosmos/cosmos-sdk/x/auth/signing/direct"
)
type builder struct {
@ -34,12 +33,13 @@ type builder struct {
pubKeys []crypto.PubKey
pubkeyCodec types.PublicKeyCodec
txBodyHasUnknownNonCriticals bool
}
var (
_ authsigning.SigFeeMemoTx = &builder{}
_ client.TxBuilder = &builder{}
_ direct.ProtoTx = &builder{}
)
func newBuilder(pubkeyCodec types.PublicKeyCodec) *builder {
@ -122,7 +122,7 @@ func (t *builder) ValidateBasic() error {
return nil
}
func (t *builder) GetBodyBytes() []byte {
func (t *builder) getBodyBytes() []byte {
if len(t.bodyBz) == 0 {
// if bodyBz is empty, then marshal the body. bodyBz will generally
// be set to nil whenever SetBody is called so the result of calling
@ -138,7 +138,7 @@ func (t *builder) GetBodyBytes() []byte {
return t.bodyBz
}
func (t *builder) GetAuthInfoBytes() []byte {
func (t *builder) getAuthInfoBytes() []byte {
if len(t.authInfoBz) == 0 {
// if authInfoBz is empty, then marshal the body. authInfoBz will generally
// be set to nil whenever SetAuthInfo is called so the result of calling

View File

@ -53,7 +53,7 @@ func TestTxBuilder(t *testing.T) {
fee := txtypes.Fee{Amount: sdk.NewCoins(sdk.NewInt64Coin("atom", 150)), GasLimit: 20000}
t.Log("verify that authInfo bytes encoded with DefaultTxEncoder and decoded with DefaultTxDecoder can be retrieved from GetAuthInfoBytes")
t.Log("verify that authInfo bytes encoded with DefaultTxEncoder and decoded with DefaultTxDecoder can be retrieved from getAuthInfoBytes")
authInfo := &txtypes.AuthInfo{
Fee: &fee,
SignerInfos: signerInfo,
@ -63,7 +63,7 @@ func TestTxBuilder(t *testing.T) {
require.NotEmpty(t, authInfoBytes)
t.Log("verify that body bytes encoded with DefaultTxEncoder and decoded with DefaultTxDecoder can be retrieved from GetBodyBytes")
t.Log("verify that body bytes encoded with DefaultTxEncoder and decoded with DefaultTxDecoder can be retrieved from getBodyBytes")
anys := make([]*codectypes.Any, len(msgs))
for i, msg := range msgs {
@ -80,29 +80,29 @@ func TestTxBuilder(t *testing.T) {
}
bodyBytes := marshaler.MustMarshalBinaryBare(txBody)
require.NotEmpty(t, bodyBytes)
require.Empty(t, txBuilder.GetBodyBytes())
require.Empty(t, txBuilder.getBodyBytes())
t.Log("verify that calling the SetMsgs, SetMemo results in the correct GetBodyBytes")
require.NotEqual(t, bodyBytes, txBuilder.GetBodyBytes())
t.Log("verify that calling the SetMsgs, SetMemo results in the correct getBodyBytes")
require.NotEqual(t, bodyBytes, txBuilder.getBodyBytes())
err = txBuilder.SetMsgs(msgs...)
require.NoError(t, err)
require.NotEqual(t, bodyBytes, txBuilder.GetBodyBytes())
require.NotEqual(t, bodyBytes, txBuilder.getBodyBytes())
txBuilder.SetMemo(memo)
require.Equal(t, bodyBytes, txBuilder.GetBodyBytes())
require.Equal(t, bodyBytes, txBuilder.getBodyBytes())
require.Equal(t, len(msgs), len(txBuilder.GetMsgs()))
require.Equal(t, 0, len(txBuilder.GetPubKeys()))
t.Log("verify that updated AuthInfo results in the correct GetAuthInfoBytes and GetPubKeys")
require.NotEqual(t, authInfoBytes, txBuilder.GetAuthInfoBytes())
t.Log("verify that updated AuthInfo results in the correct getAuthInfoBytes and GetPubKeys")
require.NotEqual(t, authInfoBytes, txBuilder.getAuthInfoBytes())
txBuilder.SetFeeAmount(fee.Amount)
require.NotEqual(t, authInfoBytes, txBuilder.GetAuthInfoBytes())
require.NotEqual(t, authInfoBytes, txBuilder.getAuthInfoBytes())
txBuilder.SetGasLimit(fee.GasLimit)
require.NotEqual(t, authInfoBytes, txBuilder.GetAuthInfoBytes())
require.NotEqual(t, authInfoBytes, txBuilder.getAuthInfoBytes())
err = txBuilder.SetSignatures(sig)
require.NoError(t, err)
// once fee, gas and signerInfos are all set, AuthInfo bytes should match
require.Equal(t, authInfoBytes, txBuilder.GetAuthInfoBytes())
require.Equal(t, authInfoBytes, txBuilder.getAuthInfoBytes())
require.Equal(t, len(msgs), len(txBuilder.GetMsgs()))
require.Equal(t, 1, len(txBuilder.GetPubKeys()))
@ -230,24 +230,3 @@ func TestBuilderValidateBasic(t *testing.T) {
err = txBuilder.ValidateBasic()
require.Error(t, err)
}
func TestDefaultTxDecoderError(t *testing.T) {
registry := codectypes.NewInterfaceRegistry()
pubKeyCdc := std.DefaultPublicKeyCodec{}
encoder := DefaultTxEncoder()
decoder := DefaultTxDecoder(registry, pubKeyCdc)
builder := newBuilder(pubKeyCdc)
err := builder.SetMsgs(testdata.NewTestMsg())
require.NoError(t, err)
txBz, err := encoder(builder.GetTx())
require.NoError(t, err)
_, err = decoder(txBz)
require.EqualError(t, err, "no registered implementations of type types.Msg: tx parse error")
registry.RegisterImplementations((*sdk.Msg)(nil), &testdata.TestMsg{})
_, err = decoder(txBz)
require.NoError(t, err)
}

View File

@ -3,6 +3,8 @@ package tx
import (
"fmt"
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/codec"
"github.com/cosmos/cosmos-sdk/client"
@ -11,7 +13,7 @@ import (
"github.com/cosmos/cosmos-sdk/x/auth/signing"
)
type generator struct {
type config struct {
pubkeyCodec types.PublicKeyCodec
handler signing.SignModeHandler
decoder sdk.TxDecoder
@ -21,11 +23,12 @@ type generator struct {
protoCodec *codec.ProtoCodec
}
// NewTxConfig returns a new protobuf TxConfig using the provided ProtoCodec, PublicKeyCodec and SignModeHandler.
func NewTxConfig(protoCodec *codec.ProtoCodec, pubkeyCodec types.PublicKeyCodec, signModeHandler signing.SignModeHandler) client.TxConfig {
return &generator{
// NewTxConfig returns a new protobuf TxConfig using the provided ProtoCodec, PublicKeyCodec and sign modes. The
// first enabled sign mode will become the default sign mode.
func NewTxConfig(protoCodec *codec.ProtoCodec, pubkeyCodec types.PublicKeyCodec, enabledSignModes []signingtypes.SignMode) client.TxConfig {
return &config{
pubkeyCodec: pubkeyCodec,
handler: signModeHandler,
handler: makeSignModeHandler(enabledSignModes),
decoder: DefaultTxDecoder(protoCodec, pubkeyCodec),
encoder: DefaultTxEncoder(),
jsonDecoder: DefaultJSONTxDecoder(protoCodec, pubkeyCodec),
@ -34,12 +37,12 @@ func NewTxConfig(protoCodec *codec.ProtoCodec, pubkeyCodec types.PublicKeyCodec,
}
}
func (g generator) NewTxBuilder() client.TxBuilder {
func (g config) NewTxBuilder() client.TxBuilder {
return newBuilder(g.pubkeyCodec)
}
// WrapTxBuilder returns a builder from provided transaction
func (g generator) WrapTxBuilder(newTx sdk.Tx) (client.TxBuilder, error) {
func (g config) WrapTxBuilder(newTx sdk.Tx) (client.TxBuilder, error) {
newBuilder, ok := newTx.(*builder)
if !ok {
return nil, fmt.Errorf("expected %T, got %T", &builder{}, newTx)
@ -48,22 +51,22 @@ func (g generator) WrapTxBuilder(newTx sdk.Tx) (client.TxBuilder, error) {
return newBuilder, nil
}
func (g generator) SignModeHandler() signing.SignModeHandler {
func (g config) SignModeHandler() signing.SignModeHandler {
return g.handler
}
func (g generator) TxEncoder() sdk.TxEncoder {
func (g config) TxEncoder() sdk.TxEncoder {
return g.encoder
}
func (g generator) TxDecoder() sdk.TxDecoder {
func (g config) TxDecoder() sdk.TxDecoder {
return g.decoder
}
func (g generator) TxJSONEncoder() sdk.TxEncoder {
func (g config) TxJSONEncoder() sdk.TxEncoder {
return g.jsonEncoder
}
func (g generator) TxJSONDecoder() sdk.TxDecoder {
func (g config) TxJSONDecoder() sdk.TxDecoder {
return g.jsonDecoder
}

View File

@ -20,6 +20,5 @@ func TestGenerator(t *testing.T) {
interfaceRegistry.RegisterImplementations((*sdk.Msg)(nil), &testdata.TestMsg{})
marshaler := codec.NewProtoCodec(interfaceRegistry)
pubKeyCodec := std.DefaultPublicKeyCodec{}
signModeHandler := DefaultSignModeHandler()
suite.Run(t, testutil.NewTxConfigTestSuite(NewTxConfig(marshaler, pubKeyCodec, signModeHandler)))
suite.Run(t, testutil.NewTxConfigTestSuite(NewTxConfig(marshaler, pubKeyCodec, DefaultSignModes)))
}

View File

@ -3,8 +3,9 @@ package tx
import (
"github.com/tendermint/tendermint/crypto"
"github.com/cosmos/cosmos-sdk/codec/unknownproto"
"github.com/cosmos/cosmos-sdk/codec"
"github.com/cosmos/cosmos-sdk/codec/types"
cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
@ -12,39 +13,71 @@ import (
)
// DefaultTxDecoder returns a default protobuf TxDecoder using the provided Marshaler and PublicKeyCodec
func DefaultTxDecoder(anyUnpacker types.AnyUnpacker, keyCodec cryptotypes.PublicKeyCodec) sdk.TxDecoder {
cdc := codec.NewProtoCodec(anyUnpacker)
func DefaultTxDecoder(cdc *codec.ProtoCodec, keyCodec cryptotypes.PublicKeyCodec) sdk.TxDecoder {
return func(txBytes []byte) (sdk.Tx, error) {
var raw tx.TxRaw
err := cdc.UnmarshalBinaryBare(txBytes, &raw)
// reject all unknown proto fields in the root TxRaw
err := unknownproto.RejectUnknownFieldsStrict(txBytes, &raw)
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
}
var theTx tx.Tx
err = cdc.UnmarshalBinaryBare(txBytes, &theTx)
err = cdc.UnmarshalBinaryBare(txBytes, &raw)
if err != nil {
return nil, err
}
var body tx.TxBody
// allow non-critical unknown fields in TxBody
txBodyHasUnknownNonCriticals, err := unknownproto.RejectUnknownFields(raw.BodyBytes, &body, true)
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
}
err = cdc.UnmarshalBinaryBare(raw.BodyBytes, &body)
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
}
var authInfo tx.AuthInfo
// reject all unknown proto fields in AuthInfo
err = unknownproto.RejectUnknownFieldsStrict(raw.AuthInfoBytes, &authInfo)
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
}
err = cdc.UnmarshalBinaryBare(raw.AuthInfoBytes, &authInfo)
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
}
theTx := &tx.Tx{
Body: &body,
AuthInfo: &authInfo,
Signatures: raw.Signatures,
}
pks, err := extractPubKeys(theTx, keyCodec)
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
}
return &builder{
tx: &theTx,
bodyBz: raw.BodyBytes,
authInfoBz: raw.AuthInfoBytes,
pubKeys: pks,
pubkeyCodec: keyCodec,
tx: theTx,
bodyBz: raw.BodyBytes,
authInfoBz: raw.AuthInfoBytes,
pubKeys: pks,
pubkeyCodec: keyCodec,
txBodyHasUnknownNonCriticals: txBodyHasUnknownNonCriticals,
}, nil
}
}
// DefaultTxDecoder returns a default protobuf JSON TxDecoder using the provided Marshaler and PublicKeyCodec
func DefaultJSONTxDecoder(anyUnpacker types.AnyUnpacker, keyCodec cryptotypes.PublicKeyCodec) sdk.TxDecoder {
cdc := codec.NewProtoCodec(anyUnpacker)
func DefaultJSONTxDecoder(cdc *codec.ProtoCodec, keyCodec cryptotypes.PublicKeyCodec) sdk.TxDecoder {
return func(txBytes []byte) (sdk.Tx, error) {
var theTx tx.Tx
err := cdc.UnmarshalJSON(txBytes, &theTx)
@ -52,7 +85,7 @@ func DefaultJSONTxDecoder(anyUnpacker types.AnyUnpacker, keyCodec cryptotypes.Pu
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
}
pks, err := extractPubKeys(theTx, keyCodec)
pks, err := extractPubKeys(&theTx, keyCodec)
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
}
@ -65,7 +98,7 @@ func DefaultJSONTxDecoder(anyUnpacker types.AnyUnpacker, keyCodec cryptotypes.Pu
}
}
func extractPubKeys(tx tx.Tx, keyCodec cryptotypes.PublicKeyCodec) ([]crypto.PubKey, error) {
func extractPubKeys(tx *tx.Tx, keyCodec cryptotypes.PublicKeyCodec) ([]crypto.PubKey, error) {
if tx.AuthInfo == nil {
return []crypto.PubKey{}, nil
}

56
x/auth/tx/direct.go Normal file
View File

@ -0,0 +1,56 @@
package tx
import (
"fmt"
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
sdk "github.com/cosmos/cosmos-sdk/types"
types "github.com/cosmos/cosmos-sdk/types/tx"
"github.com/cosmos/cosmos-sdk/x/auth/signing"
)
// signModeDirectHandler defines the SIGN_MODE_DIRECT SignModeHandler
type signModeDirectHandler struct{}
var _ signing.SignModeHandler = signModeDirectHandler{}
// DefaultMode implements SignModeHandler.DefaultMode
func (signModeDirectHandler) DefaultMode() signingtypes.SignMode {
return signingtypes.SignMode_SIGN_MODE_DIRECT
}
// Modes implements SignModeHandler.Modes
func (signModeDirectHandler) Modes() []signingtypes.SignMode {
return []signingtypes.SignMode{signingtypes.SignMode_SIGN_MODE_DIRECT}
}
// GetSignBytes implements SignModeHandler.GetSignBytes
func (signModeDirectHandler) GetSignBytes(mode signingtypes.SignMode, data signing.SignerData, tx sdk.Tx) ([]byte, error) {
if mode != signingtypes.SignMode_SIGN_MODE_DIRECT {
return nil, fmt.Errorf("expected %s, got %s", signingtypes.SignMode_SIGN_MODE_DIRECT, mode)
}
protoTx, ok := tx.(*builder)
if !ok {
return nil, fmt.Errorf("can only handle a protobuf Tx, got %T", tx)
}
bodyBz := protoTx.getBodyBytes()
authInfoBz := protoTx.getAuthInfoBytes()
return DirectSignBytes(bodyBz, authInfoBz, data.ChainID, data.AccountNumber, data.AccountSequence)
}
// DirectSignBytes returns the SIGN_MODE_DIRECT sign bytes for the provided TxBody bytes, AuthInfo bytes, chain ID,
// account number and sequence.
func DirectSignBytes(bodyBytes, authInfoBytes []byte, chainID string, accnum, sequence uint64) ([]byte, error) {
signDoc := types.SignDoc{
BodyBytes: bodyBytes,
AuthInfoBytes: authInfoBytes,
ChainId: chainID,
AccountNumber: accnum,
AccountSequence: sequence,
}
return signDoc.Marshal()
}

View File

@ -1,13 +1,9 @@
package direct_test
package tx
import (
"fmt"
"testing"
"github.com/cosmos/cosmos-sdk/x/auth/tx"
"github.com/cosmos/cosmos-sdk/x/auth/signing/direct"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
"github.com/cosmos/cosmos-sdk/codec"
@ -30,8 +26,8 @@ func TestDirectModeHandler(t *testing.T) {
marshaler := codec.NewProtoCodec(interfaceRegistry)
pubKeyCdc := std.DefaultPublicKeyCodec{}
txGen := tx.NewTxConfig(marshaler, pubKeyCdc, tx.DefaultSignModeHandler())
txBuilder := txGen.NewTxBuilder()
txConfig := NewTxConfig(marshaler, pubKeyCdc, []signingtypes.SignMode{signingtypes.SignMode_SIGN_MODE_DIRECT})
txBuilder := txConfig.NewTxBuilder()
memo := "sometestmemo"
msgs := []sdk.Msg{testdata.NewTestMsg(addr)}
@ -71,9 +67,9 @@ func TestDirectModeHandler(t *testing.T) {
require.NoError(t, err)
t.Log("verify modes and default-mode")
directModeHandler := direct.ModeHandler{}
require.Equal(t, directModeHandler.DefaultMode(), signingtypes.SignMode_SIGN_MODE_DIRECT)
require.Len(t, directModeHandler.Modes(), 1)
modeHandler := txConfig.SignModeHandler()
require.Equal(t, modeHandler.DefaultMode(), signingtypes.SignMode_SIGN_MODE_DIRECT)
require.Len(t, modeHandler.Modes(), 1)
signingData := signing.SignerData{
ChainID: "test-chain",
@ -81,7 +77,7 @@ func TestDirectModeHandler(t *testing.T) {
AccountSequence: 1,
}
signBytes, err := directModeHandler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_DIRECT, signingData, txBuilder.GetTx())
signBytes, err := modeHandler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_DIRECT, signingData, txBuilder.GetTx())
require.NoError(t, err)
require.NotNil(t, signBytes)
@ -127,7 +123,7 @@ func TestDirectModeHandler(t *testing.T) {
require.NoError(t, err)
err = txBuilder.SetSignatures(sig)
require.NoError(t, err)
signBytes, err = directModeHandler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_DIRECT, signingData, txBuilder.GetTx())
signBytes, err = modeHandler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_DIRECT, signingData, txBuilder.GetTx())
require.NoError(t, err)
require.Equal(t, expectedSignBytes, signBytes)
@ -146,7 +142,7 @@ func TestDirectModeHandler_nonDIRECT_MODE(t *testing.T) {
}
for _, invalidMode := range invalidModes {
t.Run(invalidMode.String(), func(t *testing.T) {
var dh direct.ModeHandler
var dh signModeDirectHandler
var signingData signing.SignerData
_, err := dh.GetSignBytes(invalidMode, signingData, nil)
require.Error(t, err)
@ -164,11 +160,11 @@ func (npt *nonProtoTx) ValidateBasic() error { return nil }
var _ sdk.Tx = (*nonProtoTx)(nil)
func TestDirectModeHandler_nonProtoTx(t *testing.T) {
var dh direct.ModeHandler
var dh signModeDirectHandler
var signingData signing.SignerData
tx := new(nonProtoTx)
_, err := dh.GetSignBytes(signingtypes.SignMode_SIGN_MODE_DIRECT, signingData, tx)
require.Error(t, err)
wantErr := fmt.Errorf("can only get direct sign bytes for a ProtoTx, got %T", tx)
wantErr := fmt.Errorf("can only handle a protobuf Tx, got %T", tx)
require.Equal(t, err, wantErr)
}

View File

@ -0,0 +1,165 @@
package tx
import (
"fmt"
"testing"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/types/tx"
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/signing"
"github.com/stretchr/testify/require"
"github.com/cosmos/cosmos-sdk/codec"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
"github.com/cosmos/cosmos-sdk/std"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
sdk "github.com/cosmos/cosmos-sdk/types"
)
func TestDefaultTxDecoderError(t *testing.T) {
registry := codectypes.NewInterfaceRegistry()
cdc := codec.NewProtoCodec(registry)
pubKeyCdc := std.DefaultPublicKeyCodec{}
encoder := DefaultTxEncoder()
decoder := DefaultTxDecoder(cdc, pubKeyCdc)
builder := newBuilder(pubKeyCdc)
err := builder.SetMsgs(testdata.NewTestMsg())
require.NoError(t, err)
txBz, err := encoder(builder.GetTx())
require.NoError(t, err)
_, err = decoder(txBz)
require.EqualError(t, err, "no registered implementations of type types.Msg: tx parse error")
registry.RegisterImplementations((*sdk.Msg)(nil), &testdata.TestMsg{})
_, err = decoder(txBz)
require.NoError(t, err)
}
func TestUnknownFields(t *testing.T) {
registry := codectypes.NewInterfaceRegistry()
cdc := codec.NewProtoCodec(registry)
pubKeyCdc := std.DefaultPublicKeyCodec{}
decoder := DefaultTxDecoder(cdc, pubKeyCdc)
tests := []struct {
name string
body *testdata.TestUpdatedTxBody
authInfo *testdata.TestUpdatedAuthInfo
shouldErr bool
shouldAminoErr string
}{
{
name: "no new fields should pass",
body: &testdata.TestUpdatedTxBody{
Memo: "foo",
},
authInfo: &testdata.TestUpdatedAuthInfo{},
shouldErr: false,
},
{
name: "non-critical fields in TxBody should not error on decode, but should error with amino",
body: &testdata.TestUpdatedTxBody{
Memo: "foo",
SomeNewFieldNonCriticalField: "blah",
},
authInfo: &testdata.TestUpdatedAuthInfo{},
shouldErr: false,
shouldAminoErr: fmt.Sprintf("%s: %s", aminoNonCriticalFieldsError, sdkerrors.ErrInvalidRequest.Error()),
},
{
name: "critical fields in TxBody should error on decode",
body: &testdata.TestUpdatedTxBody{
Memo: "foo",
SomeNewField: 10,
},
authInfo: &testdata.TestUpdatedAuthInfo{},
shouldErr: true,
},
{
name: "critical fields in AuthInfo should error on decode",
body: &testdata.TestUpdatedTxBody{
Memo: "foo",
},
authInfo: &testdata.TestUpdatedAuthInfo{
NewField_3: []byte("xyz"),
},
shouldErr: true,
},
{
name: "non-critical fields in AuthInfo should error on decode",
body: &testdata.TestUpdatedTxBody{
Memo: "foo",
},
authInfo: &testdata.TestUpdatedAuthInfo{
NewField_1024: []byte("xyz"),
},
shouldErr: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
bodyBz, err := tt.body.Marshal()
require.NoError(t, err)
authInfoBz, err := tt.authInfo.Marshal()
require.NoError(t, err)
txRaw := &tx.TxRaw{
BodyBytes: bodyBz,
AuthInfoBytes: authInfoBz,
}
txBz, err := txRaw.Marshal()
require.NoError(t, err)
_, err = decoder(txBz)
if tt.shouldErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
if tt.shouldAminoErr != "" {
handler := signModeLegacyAminoJSONHandler{}
decoder := DefaultTxDecoder(codec.NewProtoCodec(codectypes.NewInterfaceRegistry()), std.DefaultPublicKeyCodec{})
theTx, err := decoder(txBz)
require.NoError(t, err)
_, err = handler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, signing.SignerData{}, theTx)
require.EqualError(t, err, tt.shouldAminoErr)
}
})
}
t.Log("test TxRaw no new fields, should succeed")
txRaw := &testdata.TestUpdatedTxRaw{}
txBz, err := txRaw.Marshal()
require.NoError(t, err)
_, err = decoder(txBz)
require.NoError(t, err)
t.Log("new field in TxRaw should fail")
txRaw = &testdata.TestUpdatedTxRaw{
NewField_5: []byte("abc"),
}
txBz, err = txRaw.Marshal()
require.NoError(t, err)
_, err = decoder(txBz)
require.Error(t, err)
//
t.Log("new \"non-critical\" field in TxRaw should fail")
txRaw = &testdata.TestUpdatedTxRaw{
NewField_1024: []byte("abc"),
}
txBz, err = txRaw.Marshal()
require.NoError(t, err)
_, err = decoder(txBz)
require.Error(t, err)
}

View File

@ -19,8 +19,8 @@ func DefaultTxEncoder() types.TxEncoder {
}
raw := &txtypes.TxRaw{
BodyBytes: wrapper.GetBodyBytes(),
AuthInfoBytes: wrapper.GetAuthInfoBytes(),
BodyBytes: wrapper.getBodyBytes(),
AuthInfoBytes: wrapper.getAuthInfoBytes(),
Signatures: wrapper.tx.Signatures,
}

View File

@ -0,0 +1,61 @@
package tx
import (
"fmt"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/x/auth/types"
sdk "github.com/cosmos/cosmos-sdk/types"
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/signing"
)
// signModeLegacyAminoJSONHandler defines the SIGN_MODE_LEGACY_AMINO_JSON SignModeHandler
type signModeLegacyAminoJSONHandler struct{}
func (s signModeLegacyAminoJSONHandler) DefaultMode() signingtypes.SignMode {
return signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON
}
func (s signModeLegacyAminoJSONHandler) Modes() []signingtypes.SignMode {
return []signingtypes.SignMode{signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON}
}
const aminoNonCriticalFieldsError = "protobuf transaction contains unknown non-critical fields. This is a transaction malleability issue and SIGN_MODE_LEGACY_AMINO_JSON cannot be used."
func (s signModeLegacyAminoJSONHandler) GetSignBytes(mode signingtypes.SignMode, data signing.SignerData, tx sdk.Tx) ([]byte, error) {
if mode != signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON {
return nil, fmt.Errorf("expected %s, got %s", signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, mode)
}
protoTx, ok := tx.(*builder)
if !ok {
return nil, fmt.Errorf("can only handle a protobuf Tx, got %T", tx)
}
if protoTx.txBodyHasUnknownNonCriticals {
return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidRequest, aminoNonCriticalFieldsError)
}
body := protoTx.tx.Body
if len(body.ExtensionOptions) != 0 || len(body.NonCriticalExtensionOptions) != 0 {
return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidRequest,
"SIGN_MODE_LEGACY_AMINO_JSON does not support protobuf extension options.")
}
if body.TimeoutHeight != 0 {
return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidRequest,
"SIGN_MODE_LEGACY_AMINO_JSON does not support timeout height.")
}
return types.StdSignBytes(
data.ChainID, data.AccountNumber, data.AccountSequence,
//nolint:staticcheck
types.StdFee{Amount: protoTx.GetFee(), Gas: protoTx.GetGas()},
tx.GetMsgs(), protoTx.GetMemo(),
), nil
}
var _ signing.SignModeHandler = signModeLegacyAminoJSONHandler{}

View File

@ -0,0 +1,101 @@
package tx
import (
"testing"
cdctypes "github.com/cosmos/cosmos-sdk/codec/types"
"github.com/stretchr/testify/require"
"github.com/cosmos/cosmos-sdk/std"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
sdk "github.com/cosmos/cosmos-sdk/types"
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/signing"
"github.com/cosmos/cosmos-sdk/x/auth/types"
)
var (
_, _, addr1 = testdata.KeyTestPubAddr()
_, _, addr2 = testdata.KeyTestPubAddr()
coins = sdk.Coins{sdk.NewInt64Coin("foocoin", 10)}
gas = uint64(10000)
msg = testdata.NewTestMsg(addr1, addr2)
memo = "foo"
)
func buildTx(t *testing.T, bldr *builder) {
bldr.SetFeeAmount(coins)
bldr.SetGasLimit(gas)
bldr.SetMemo(memo)
require.NoError(t, bldr.SetMsgs(msg))
}
func TestLegacyAminoJSONHandler_GetSignBytes(t *testing.T) {
bldr := newBuilder(std.DefaultPublicKeyCodec{})
buildTx(t, bldr)
tx := bldr.GetTx()
var (
chainId = "test-chain"
accNum uint64 = 7
seqNum uint64 = 7
)
handler := signModeLegacyAminoJSONHandler{}
signingData := signing.SignerData{
ChainID: chainId,
AccountNumber: accNum,
AccountSequence: seqNum,
}
signBz, err := handler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, signingData, tx)
require.NoError(t, err)
expectedSignBz := types.StdSignBytes(chainId, accNum, seqNum, types.StdFee{
Amount: coins,
Gas: gas,
}, []sdk.Msg{msg}, memo)
require.Equal(t, expectedSignBz, signBz)
// expect error with wrong sign mode
_, err = handler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_DIRECT, signingData, tx)
require.Error(t, err)
// expect error with timeout height
bldr = newBuilder(std.DefaultPublicKeyCodec{})
buildTx(t, bldr)
bldr.tx.Body.TimeoutHeight = 10
tx = bldr.GetTx()
signBz, err = handler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, signingData, tx)
require.Error(t, err)
// expect error with extension options
bldr = newBuilder(std.DefaultPublicKeyCodec{})
buildTx(t, bldr)
any, err := cdctypes.NewAnyWithValue(testdata.NewTestMsg())
require.NoError(t, err)
bldr.tx.Body.ExtensionOptions = []*cdctypes.Any{any}
tx = bldr.GetTx()
signBz, err = handler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, signingData, tx)
require.Error(t, err)
// expect error with non-critical extension options
bldr = newBuilder(std.DefaultPublicKeyCodec{})
buildTx(t, bldr)
bldr.tx.Body.NonCriticalExtensionOptions = []*cdctypes.Any{any}
tx = bldr.GetTx()
signBz, err = handler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, signingData, tx)
require.Error(t, err)
}
func TestLegacyAminoJSONHandler_DefaultMode(t *testing.T) {
handler := signModeLegacyAminoJSONHandler{}
require.Equal(t, signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, handler.DefaultMode())
}
func TestLegacyAminoJSONHandler_Modes(t *testing.T) {
handler := signModeLegacyAminoJSONHandler{}
require.Equal(t, []signingtypes.SignMode{signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON}, handler.Modes())
}

View File

@ -1,20 +1,40 @@
package tx
import (
signing2 "github.com/cosmos/cosmos-sdk/types/tx/signing"
"fmt"
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/signing"
"github.com/cosmos/cosmos-sdk/x/auth/signing/direct"
authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
)
// DefaultSignModeHandler returns the default protobuf SignModeHandler supporting
// DefaultSignModes are the default sign modes enabled for protobuf transactions.
var DefaultSignModes = []signingtypes.SignMode{
signingtypes.SignMode_SIGN_MODE_DIRECT,
signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON,
}
// makeSignModeHandler returns the default protobuf SignModeHandler supporting
// SIGN_MODE_DIRECT and SIGN_MODE_LEGACY_AMINO_JSON.
func DefaultSignModeHandler() signing.SignModeHandler {
func makeSignModeHandler(modes []signingtypes.SignMode) signing.SignModeHandler {
if len(modes) < 1 {
panic(fmt.Errorf("no sign modes enabled"))
}
handlers := make([]signing.SignModeHandler, len(modes))
for i, mode := range modes {
switch mode {
case signingtypes.SignMode_SIGN_MODE_DIRECT:
handlers[i] = signModeDirectHandler{}
case signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON:
handlers[i] = signModeLegacyAminoJSONHandler{}
default:
panic(fmt.Errorf("unsupported sign mode %+v", mode))
}
}
return signing.NewSignModeHandlerMap(
signing2.SignMode_SIGN_MODE_DIRECT,
[]signing.SignModeHandler{
authtypes.LegacyAminoJSONHandler{},
direct.ModeHandler{},
},
modes[0],
handlers,
)
}

View File

@ -105,7 +105,7 @@ func decodeMultisignatures(bz []byte) ([][]byte, error) {
return multisig.Signatures, nil
}
func (g generator) MarshalSignatureJSON(sigs []signing.SignatureV2) ([]byte, error) {
func (g config) MarshalSignatureJSON(sigs []signing.SignatureV2) ([]byte, error) {
descs := make([]*signing.SignatureDescriptor, len(sigs))
for i, sig := range sigs {
@ -127,7 +127,7 @@ func (g generator) MarshalSignatureJSON(sigs []signing.SignatureV2) ([]byte, err
return codec.ProtoMarshalJSON(toJSON)
}
func (g generator) UnmarshalSignatureJSON(bz []byte) ([]signing.SignatureV2, error) {
func (g config) UnmarshalSignatureJSON(bz []byte) ([]signing.SignatureV2, error) {
var sigDescs signing.SignatureDescriptors
err := g.protoCodec.UnmarshalJSON(bz, &sigDescs)
if err != nil {

View File

@ -8,38 +8,37 @@ import (
"github.com/cosmos/cosmos-sdk/x/auth/signing"
)
// LegacyAminoJSONHandler is a SignModeHandler that handles SIGN_MODE_LEGACY_AMINO_JSON
type LegacyAminoJSONHandler struct{}
// stdTxSignModeHandler is a SignModeHandler that handles SIGN_MODE_LEGACY_AMINO_JSON
type stdTxSignModeHandler struct{}
var _ signing.SignModeHandler = LegacyAminoJSONHandler{}
func NewStdTxSignModeHandler() signing.SignModeHandler {
return &stdTxSignModeHandler{}
}
var _ signing.SignModeHandler = stdTxSignModeHandler{}
// DefaultMode implements SignModeHandler.DefaultMode
func (h LegacyAminoJSONHandler) DefaultMode() signingtypes.SignMode {
func (h stdTxSignModeHandler) DefaultMode() signingtypes.SignMode {
return signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON
}
// Modes implements SignModeHandler.Modes
func (LegacyAminoJSONHandler) Modes() []signingtypes.SignMode {
func (stdTxSignModeHandler) Modes() []signingtypes.SignMode {
return []signingtypes.SignMode{signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON}
}
// DefaultMode implements SignModeHandler.GetSignBytes
func (LegacyAminoJSONHandler) GetSignBytes(mode signingtypes.SignMode, data signing.SignerData, tx sdk.Tx) ([]byte, error) {
func (stdTxSignModeHandler) GetSignBytes(mode signingtypes.SignMode, data signing.SignerData, tx sdk.Tx) ([]byte, error) {
if mode != signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON {
return nil, fmt.Errorf("expected %s, got %s", signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, mode)
}
feeTx, ok := tx.(sdk.FeeTx)
stdTx, ok := tx.(StdTx)
if !ok {
return nil, fmt.Errorf("expected FeeTx, got %T", tx)
}
memoTx, ok := tx.(sdk.TxWithMemo)
if !ok {
return nil, fmt.Errorf("expected TxWithMemo, got %T", tx)
return nil, fmt.Errorf("expected %T, got %T", StdTx{}, tx)
}
return StdSignBytes(
data.ChainID, data.AccountNumber, data.AccountSequence, StdFee{Amount: feeTx.GetFee(), Gas: feeTx.GetGas()}, tx.GetMsgs(), memoTx.GetMemo(),
data.ChainID, data.AccountNumber, data.AccountSequence, StdFee{Amount: stdTx.GetFee(), Gas: stdTx.GetGas()}, tx.GetMsgs(), stdTx.GetMemo(),
), nil
}

View File

@ -1,18 +1,16 @@
package types_test
package types
import (
"testing"
"github.com/cosmos/cosmos-sdk/x/auth/signing"
"github.com/cosmos/cosmos-sdk/x/auth/types"
"github.com/stretchr/testify/require"
"github.com/tendermint/tendermint/crypto/secp256k1"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
sdk "github.com/cosmos/cosmos-sdk/types"
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
banktypes "github.com/cosmos/cosmos-sdk/x/bank/types"
"github.com/cosmos/cosmos-sdk/x/auth/signing"
)
func TestLegacyAminoJSONHandler_GetSignBytes(t *testing.T) {
@ -23,20 +21,16 @@ func TestLegacyAminoJSONHandler_GetSignBytes(t *testing.T) {
coins := sdk.Coins{sdk.NewInt64Coin("foocoin", 10)}
fee := types.StdFee{
fee := StdFee{
Amount: coins,
Gas: 10000,
}
memo := "foo"
msgs := []sdk.Msg{
&banktypes.MsgSend{
FromAddress: addr1,
ToAddress: addr2,
Amount: coins,
},
testdata.NewTestMsg(addr1, addr2),
}
tx := types.StdTx{
tx := StdTx{
Msgs: msgs,
Fee: fee,
Signatures: nil,
@ -49,7 +43,7 @@ func TestLegacyAminoJSONHandler_GetSignBytes(t *testing.T) {
seqNum uint64 = 7
)
handler := types.LegacyAminoJSONHandler{}
handler := stdTxSignModeHandler{}
signingData := signing.SignerData{
ChainID: chainId,
AccountNumber: accNum,
@ -58,7 +52,7 @@ func TestLegacyAminoJSONHandler_GetSignBytes(t *testing.T) {
signBz, err := handler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, signingData, tx)
require.NoError(t, err)
expectedSignBz := types.StdSignBytes(chainId, accNum, seqNum, fee, msgs, memo)
expectedSignBz := StdSignBytes(chainId, accNum, seqNum, fee, msgs, memo)
require.Equal(t, expectedSignBz, signBz)
@ -68,11 +62,11 @@ func TestLegacyAminoJSONHandler_GetSignBytes(t *testing.T) {
}
func TestLegacyAminoJSONHandler_DefaultMode(t *testing.T) {
handler := types.LegacyAminoJSONHandler{}
handler := stdTxSignModeHandler{}
require.Equal(t, signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, handler.DefaultMode())
}
func TestLegacyAminoJSONHandler_Modes(t *testing.T) {
handler := types.LegacyAminoJSONHandler{}
handler := stdTxSignModeHandler{}
require.Equal(t, []signingtypes.SignMode{signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON}, handler.Modes())
}

View File

@ -146,5 +146,5 @@ func (s StdTxConfig) UnmarshalSignatureJSON(bz []byte) ([]signing.SignatureV2, e
}
func (s StdTxConfig) SignModeHandler() authsigning.SignModeHandler {
return LegacyAminoJSONHandler{}
return stdTxSignModeHandler{}
}

File diff suppressed because it is too large Load Diff