Reject unknown fields in TxDecoder and sign mode handlers (#6883)
* WIP on unknown field rejection in TxDecoder * WIP on unknown field rejection in TxDecoder * WIP * WIP * WIP * WIP * Fix bugs with RejectUnknownFields * Fix tests * Fix bug and update docs * Lint * Add tests * Add unknown field tests * Lint * Address review comments
This commit is contained in:
parent
57cd7d62b3
commit
6d937443b2
|
@ -65,7 +65,7 @@ type TestSuite struct {
|
|||
func (s *TestSuite) SetupSuite() {
|
||||
encCfg := simapp.MakeEncodingConfig()
|
||||
s.encCfg = encCfg
|
||||
s.protoCfg = tx.NewTxConfig(codec.NewProtoCodec(encCfg.InterfaceRegistry), std.DefaultPublicKeyCodec{}, tx.DefaultSignModeHandler())
|
||||
s.protoCfg = tx.NewTxConfig(codec.NewProtoCodec(encCfg.InterfaceRegistry), std.DefaultPublicKeyCodec{}, tx.DefaultSignModes)
|
||||
s.aminoCfg = types3.StdTxConfig{Cdc: encCfg.Amino}
|
||||
}
|
||||
|
||||
|
|
|
@ -53,11 +53,10 @@ func benchmarkRejectUnknownFields(b *testing.B, parallel bool) {
|
|||
b.ReportAllocs()
|
||||
|
||||
if !parallel {
|
||||
ckr := new(unknownproto.Checker)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
n1A := new(testdata.Nested1A)
|
||||
if err := ckr.RejectUnknownFields(n1BBlob, n1A); err == nil {
|
||||
if err := unknownproto.RejectUnknownFieldsStrict(n1BBlob, n1A); err == nil {
|
||||
b.Fatal("expected an error")
|
||||
}
|
||||
b.SetBytes(int64(len(n1BBlob)))
|
||||
|
@ -66,11 +65,10 @@ func benchmarkRejectUnknownFields(b *testing.B, parallel bool) {
|
|||
var mu sync.Mutex
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
ckr := new(unknownproto.Checker)
|
||||
for pb.Next() {
|
||||
// To simulate the conditions of multiple transactions being processed in parallel.
|
||||
n1A := new(testdata.Nested1A)
|
||||
if err := ckr.RejectUnknownFields(n1BBlob, n1A); err == nil {
|
||||
if err := unknownproto.RejectUnknownFieldsStrict(n1BBlob, n1A); err == nil {
|
||||
b.Fatal("expected an error")
|
||||
}
|
||||
mu.Lock()
|
||||
|
|
|
@ -6,22 +6,18 @@ a) Unknown fields in the stream -- this is indicative of mismatched services, pe
|
|||
|
||||
b) Mismatched wire types for a field -- this is indicative of mismatched services
|
||||
|
||||
Its API signature is similar to proto.Unmarshal([]byte, proto.Message) as
|
||||
Its API signature is similar to proto.Unmarshal([]byte, proto.Message) in the strict case
|
||||
|
||||
ckr := new(unknownproto.Checker)
|
||||
if err := ckr.RejectUnknownFields(protoBlob, protoMessage); err != nil {
|
||||
if err := RejectUnknownFieldsStrict(protoBlob, protoMessage, false); err != nil {
|
||||
// Handle the error.
|
||||
}
|
||||
|
||||
and ideally should be added before invoking proto.Unmarshal, if you'd like to enforce the features mentioned above.
|
||||
|
||||
By default, for security we report every single field that's unknown, whether a non-critical field or not. To customize
|
||||
this behavior, please create a Checker and set the AllowUnknownNonCriticals to true, for example:
|
||||
this behavior, please set the boolean parameter allowUnknownNonCriticals to true to RejectUnknownFields:
|
||||
|
||||
ckr := &unknownproto.Checker{
|
||||
AllowUnknownNonCriticals: true,
|
||||
}
|
||||
if err := ckr.RejectUnknownFields(protoBlob, protoMessage); err != nil {
|
||||
if err := RejectUnknownFields(protoBlob, protoMessage, true); err != nil {
|
||||
// Handle the error.
|
||||
}
|
||||
*/
|
||||
|
|
|
@ -22,30 +22,37 @@ type descriptorIface interface {
|
|||
Descriptor() ([]byte, []int)
|
||||
}
|
||||
|
||||
type Checker struct {
|
||||
// AllowUnknownNonCriticals when set will skip over non-critical fields that are unknown.
|
||||
AllowUnknownNonCriticals bool
|
||||
// RejectUnknownFieldsStrict rejects any bytes bz with an error that has unknown fields for the provided proto.Message type.
|
||||
// This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
|
||||
func RejectUnknownFieldsStrict(bz []byte, msg proto.Message) error {
|
||||
_, err := RejectUnknownFields(bz, msg, false)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ckr *Checker) RejectUnknownFields(b []byte, msg proto.Message) error {
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
// RejectUnknownFields rejects any bytes bz with an error that has unknown fields for the provided proto.Message type with an
|
||||
// option to allow non-critical fields (specified as those fields with bit 11) to pass through. In either case, the
|
||||
// hasUnknownNonCriticals will be set to true if non-critical fields were encountered during traversal. This flag can be
|
||||
// used to treat a message with non-critical field different in different security contexts (such as transaction signing).
|
||||
// This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
|
||||
func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals bool) (hasUnknownNonCriticals bool, err error) {
|
||||
if len(bz) == 0 {
|
||||
return hasUnknownNonCriticals, nil
|
||||
}
|
||||
|
||||
desc, ok := msg.(descriptorIface)
|
||||
if !ok {
|
||||
return fmt.Errorf("%T does not have a Descriptor() method", msg)
|
||||
return hasUnknownNonCriticals, fmt.Errorf("%T does not have a Descriptor() method", msg)
|
||||
}
|
||||
|
||||
fieldDescProtoFromTagNum, _, err := getDescriptorInfo(desc, msg)
|
||||
if err != nil {
|
||||
return err
|
||||
return hasUnknownNonCriticals, err
|
||||
}
|
||||
|
||||
for len(b) > 0 {
|
||||
tagNum, wireType, n := protowire.ConsumeField(b)
|
||||
if n < 0 {
|
||||
return errors.New("invalid length")
|
||||
for len(bz) > 0 {
|
||||
tagNum, wireType, m := protowire.ConsumeTag(bz)
|
||||
if m < 0 {
|
||||
return hasUnknownNonCriticals, errors.New("invalid length")
|
||||
}
|
||||
|
||||
fieldDescProto, ok := fieldDescProtoFromTagNum[int32(tagNum)]
|
||||
|
@ -53,7 +60,7 @@ func (ckr *Checker) RejectUnknownFields(b []byte, msg proto.Message) error {
|
|||
case ok:
|
||||
// Assert that the wireTypes match.
|
||||
if !canEncodeType(wireType, fieldDescProto.GetType()) {
|
||||
return &errMismatchedWireType{
|
||||
return hasUnknownNonCriticals, &errMismatchedWireType{
|
||||
Type: reflect.ValueOf(msg).Type().String(),
|
||||
TagNum: tagNum,
|
||||
GotWireType: wireType,
|
||||
|
@ -62,9 +69,15 @@ func (ckr *Checker) RejectUnknownFields(b []byte, msg proto.Message) error {
|
|||
}
|
||||
|
||||
default:
|
||||
if !ckr.AllowUnknownNonCriticals || tagNum&bit11NonCritical == 0 {
|
||||
isCriticalField := tagNum&bit11NonCritical == 0
|
||||
|
||||
if !isCriticalField {
|
||||
hasUnknownNonCriticals = true
|
||||
}
|
||||
|
||||
if isCriticalField || !allowUnknownNonCriticals {
|
||||
// The tag is critical, so report it.
|
||||
return &errUnknownField{
|
||||
return hasUnknownNonCriticals, &errUnknownField{
|
||||
Type: reflect.ValueOf(msg).Type().String(),
|
||||
TagNum: tagNum,
|
||||
WireType: wireType,
|
||||
|
@ -72,9 +85,11 @@ func (ckr *Checker) RejectUnknownFields(b []byte, msg proto.Message) error {
|
|||
}
|
||||
}
|
||||
|
||||
// Skip over the 2 bytes that store fieldNumber and wireType bytes.
|
||||
fieldBytes := b[2:n]
|
||||
b = b[n:]
|
||||
// Skip over the bytes that store fieldNumber and wireType bytes.
|
||||
bz = bz[m:]
|
||||
n := protowire.ConsumeFieldValue(tagNum, wireType, bz)
|
||||
fieldBytes := bz[:n]
|
||||
bz = bz[n:]
|
||||
|
||||
// An unknown but non-critical field or just a scalar type (aka *INT and BYTES like).
|
||||
if fieldDescProto == nil || fieldDescProto.IsScalar() {
|
||||
|
@ -89,22 +104,28 @@ func (ckr *Checker) RejectUnknownFields(b []byte, msg proto.Message) error {
|
|||
// TYPE_BYTES and TYPE_STRING as per
|
||||
// https://github.com/gogo/protobuf/blob/5628607bb4c51c3157aacc3a50f0ab707582b805/protoc-gen-gogo/descriptor/descriptor.go#L95-L118
|
||||
default:
|
||||
return fmt.Errorf("failed to get typename for message of type %v, can only be TYPE_STRING or TYPE_BYTES", typ)
|
||||
return hasUnknownNonCriticals, fmt.Errorf("failed to get typename for message of type %v, can only be TYPE_STRING or TYPE_BYTES", typ)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Let's recursively traverse and typecheck the field.
|
||||
|
||||
// consume length prefix of nested message
|
||||
_, o := protowire.ConsumeVarint(fieldBytes)
|
||||
fieldBytes = fieldBytes[o:]
|
||||
|
||||
if protoMessageName == ".google.protobuf.Any" {
|
||||
// Firstly typecheck types.Any to ensure nothing snuck in.
|
||||
if err := ckr.RejectUnknownFields(fieldBytes, (*types.Any)(nil)); err != nil {
|
||||
return err
|
||||
hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, (*types.Any)(nil), allowUnknownNonCriticals)
|
||||
hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
|
||||
if err != nil {
|
||||
return hasUnknownNonCriticals, err
|
||||
}
|
||||
// And finally we can extract the TypeURL containing the protoMessageName.
|
||||
any := new(types.Any)
|
||||
if err := proto.Unmarshal(fieldBytes, any); err != nil {
|
||||
return err
|
||||
return hasUnknownNonCriticals, err
|
||||
}
|
||||
protoMessageName = any.TypeUrl
|
||||
fieldBytes = any.Value
|
||||
|
@ -112,14 +133,17 @@ func (ckr *Checker) RejectUnknownFields(b []byte, msg proto.Message) error {
|
|||
|
||||
msg, err := protoMessageForTypeName(protoMessageName[1:])
|
||||
if err != nil {
|
||||
return err
|
||||
return hasUnknownNonCriticals, err
|
||||
}
|
||||
if err := ckr.RejectUnknownFields(fieldBytes, msg); err != nil {
|
||||
return err
|
||||
|
||||
hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, msg, allowUnknownNonCriticals)
|
||||
hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
|
||||
if err != nil {
|
||||
return hasUnknownNonCriticals, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return hasUnknownNonCriticals, nil
|
||||
}
|
||||
|
||||
var protoMessageForTypeNameMu sync.RWMutex
|
||||
|
|
|
@ -4,6 +4,8 @@ import (
|
|||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/codec/types"
|
||||
|
@ -17,6 +19,7 @@ func TestRejectUnknownFieldsRepeated(t *testing.T) {
|
|||
recv proto.Message
|
||||
wantErr error
|
||||
allowUnknownNonCriticals bool
|
||||
hasUnknownNonCriticals bool
|
||||
}{
|
||||
{
|
||||
name: "Unknown field in midst of repeated values",
|
||||
|
@ -172,6 +175,7 @@ func TestRejectUnknownFieldsRepeated(t *testing.T) {
|
|||
TagNum: 1031,
|
||||
WireType: 2,
|
||||
},
|
||||
hasUnknownNonCriticals: true,
|
||||
},
|
||||
{
|
||||
name: "Unknown field in midst of repeated values, non-critical field ignored",
|
||||
|
@ -213,8 +217,9 @@ func TestRejectUnknownFieldsRepeated(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
recv: new(testdata.TestVersion1),
|
||||
wantErr: nil,
|
||||
recv: new(testdata.TestVersion1),
|
||||
wantErr: nil,
|
||||
hasUnknownNonCriticals: true,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -225,11 +230,9 @@ func TestRejectUnknownFieldsRepeated(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ckr := &Checker{AllowUnknownNonCriticals: tt.allowUnknownNonCriticals}
|
||||
gotErr := ckr.RejectUnknownFields(protoBlob, tt.recv)
|
||||
if !reflect.DeepEqual(gotErr, tt.wantErr) {
|
||||
t.Fatalf("Error mismatch\nGot:\n%v\n\nWant:\n%v", gotErr, tt.wantErr)
|
||||
}
|
||||
hasUnknownNonCriticals, gotErr := RejectUnknownFields(protoBlob, tt.recv, tt.allowUnknownNonCriticals)
|
||||
require.Equal(t, tt.wantErr, gotErr)
|
||||
require.Equal(t, tt.hasUnknownNonCriticals, hasUnknownNonCriticals)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -263,7 +266,7 @@ func TestRejectUnknownFields_allowUnknownNonCriticals(t *testing.T) {
|
|||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "Unkown fields that are critical, but with allowUnknownNonCriticals set",
|
||||
name: "Unknown fields that are critical, but with allowUnknownNonCriticals set",
|
||||
allowUnknownNonCriticals: true,
|
||||
in: &testdata.Customer2{
|
||||
Id: 289,
|
||||
|
@ -285,9 +288,8 @@ func TestRejectUnknownFields_allowUnknownNonCriticals(t *testing.T) {
|
|||
t.Fatalf("Failed to marshal input: %v", err)
|
||||
}
|
||||
|
||||
ckr := &Checker{AllowUnknownNonCriticals: tt.allowUnknownNonCriticals}
|
||||
c1 := new(testdata.Customer1)
|
||||
gotErr := ckr.RejectUnknownFields(blob, c1)
|
||||
_, gotErr := RejectUnknownFields(blob, c1, tt.allowUnknownNonCriticals)
|
||||
if !reflect.DeepEqual(gotErr, tt.wantErr) {
|
||||
t.Fatalf("Error mismatch\nGot:\n%s\n\nWant:\n%s", gotErr, tt.wantErr)
|
||||
}
|
||||
|
@ -498,8 +500,7 @@ func TestRejectUnknownFieldsNested(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ckr := new(Checker)
|
||||
gotErr := ckr.RejectUnknownFields(protoBlob, tt.recv)
|
||||
gotErr := RejectUnknownFieldsStrict(protoBlob, tt.recv)
|
||||
if !reflect.DeepEqual(gotErr, tt.wantErr) {
|
||||
t.Fatalf("Error mismatch\nGot:\n%s\n\nWant:\n%s", gotErr, tt.wantErr)
|
||||
}
|
||||
|
@ -652,8 +653,7 @@ func TestRejectUnknownFieldsFlat(t *testing.T) {
|
|||
}
|
||||
|
||||
c1 := new(testdata.Customer1)
|
||||
ckr := new(Checker)
|
||||
gotErr := ckr.RejectUnknownFields(blob, c1)
|
||||
gotErr := RejectUnknownFieldsStrict(blob, c1)
|
||||
if !reflect.DeepEqual(gotErr, tt.wantErr) {
|
||||
t.Fatalf("Error mismatch\nGot:\n%s\n\nWant:\n%s", gotErr, tt.wantErr)
|
||||
}
|
||||
|
@ -738,8 +738,7 @@ func TestMismatchedTypes_Nested(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ckr := new(Checker)
|
||||
gotErr := ckr.RejectUnknownFields(protoBlob, tt.recv)
|
||||
_, gotErr := RejectUnknownFields(protoBlob, tt.recv, false)
|
||||
if !reflect.DeepEqual(gotErr, tt.wantErr) {
|
||||
t.Fatalf("Error mismatch\nGot:\n%s\n\nWant:\n%s", gotErr, tt.wantErr)
|
||||
}
|
||||
|
|
1
go.mod
1
go.mod
|
@ -12,7 +12,6 @@ require (
|
|||
github.com/cosmos/go-bip39 v0.0.0-20180819234021-555e2067c45d
|
||||
github.com/cosmos/ledger-cosmos-go v0.11.1
|
||||
github.com/enigmampc/btcutil v1.0.3-0.20200723161021-e2fb6adb2a25
|
||||
github.com/gibson042/canonicaljson-go v1.0.3
|
||||
github.com/gogo/protobuf v1.3.1
|
||||
github.com/golang/mock v1.4.4
|
||||
github.com/golang/protobuf v1.4.2
|
||||
|
|
4
go.sum
4
go.sum
|
@ -144,8 +144,6 @@ github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2
|
|||
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
|
||||
github.com/gibson042/canonicaljson-go v1.0.3 h1:EAyF8L74AWabkyUmrvEFHEt/AGFQeD6RfwbAuf0j1bI=
|
||||
github.com/gibson042/canonicaljson-go v1.0.3/go.mod h1:DsLpJTThXyGNO+KZlI85C1/KDcImpP67k/RKVjcaEqo=
|
||||
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
||||
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
|
@ -480,8 +478,6 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An
|
|||
github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE=
|
||||
github.com/spf13/viper v1.6.2/go.mod h1:t3iDnF5Jlj76alVNuyFBk5oUMCvsrkbvZK0WQdfDi5k=
|
||||
github.com/spf13/viper v1.6.3/go.mod h1:jUMtyi0/lB5yZH/FjyGAoH7IMNrIhlBf6pXZmbMDvzw=
|
||||
github.com/spf13/viper v1.7.0 h1:xVKxvI7ouOI5I+U9s2eeiUfMaWBVoXA3AWskkrqK0VM=
|
||||
github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg=
|
||||
github.com/spf13/viper v1.7.1 h1:pM5oEahlgWv/WnHXpgbKz7iLIxRf65tye2Ci+XFK5sk=
|
||||
github.com/spf13/viper v1.7.1/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg=
|
||||
github.com/streadway/amqp v0.0.0-20190404075320-75d898a42a94/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw=
|
||||
|
|
|
@ -14,7 +14,7 @@ func MakeEncodingConfig() EncodingConfig {
|
|||
amino := codec.New()
|
||||
interfaceRegistry := types.NewInterfaceRegistry()
|
||||
marshaler := codec.NewHybridCodec(amino, interfaceRegistry)
|
||||
txGen := tx.NewTxConfig(codec.NewProtoCodec(interfaceRegistry), std.DefaultPublicKeyCodec{}, tx.DefaultSignModeHandler())
|
||||
txGen := tx.NewTxConfig(codec.NewProtoCodec(interfaceRegistry), std.DefaultPublicKeyCodec{}, tx.DefaultSignModes)
|
||||
|
||||
return EncodingConfig{
|
||||
InterfaceRegistry: interfaceRegistry,
|
||||
|
|
|
@ -4,8 +4,9 @@ import (
|
|||
"container/list"
|
||||
"errors"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/types/kv"
|
||||
dbm "github.com/tendermint/tm-db"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/types/kv"
|
||||
)
|
||||
|
||||
// Iterates over iterKVCache items.
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -3,6 +3,7 @@ package testdata;
|
|||
|
||||
import "gogoproto/gogo.proto";
|
||||
import "google/protobuf/any.proto";
|
||||
import "cosmos/tx/tx.proto";
|
||||
|
||||
option go_package = "github.com/cosmos/cosmos-sdk/testutil/testdata";
|
||||
|
||||
|
@ -341,3 +342,28 @@ message AnyWithExtra {
|
|||
int64 b = 3;
|
||||
int64 c = 4;
|
||||
}
|
||||
|
||||
message TestUpdatedTxRaw {
|
||||
bytes body_bytes = 1;
|
||||
bytes auth_info_bytes = 2;
|
||||
repeated bytes signatures = 3;
|
||||
bytes new_field_5 = 5;
|
||||
bytes new_field_1024 = 1024;
|
||||
}
|
||||
|
||||
message TestUpdatedTxBody {
|
||||
repeated google.protobuf.Any messages = 1;
|
||||
string memo = 2;
|
||||
int64 timeout_height = 3;
|
||||
uint64 some_new_field = 4;
|
||||
string some_new_field_non_critical_field = 1050;
|
||||
repeated google.protobuf.Any extension_options = 1023;
|
||||
repeated google.protobuf.Any non_critical_extension_options = 2047;
|
||||
}
|
||||
|
||||
message TestUpdatedAuthInfo {
|
||||
repeated cosmos.tx.SignerInfo signer_infos = 1;
|
||||
cosmos.tx.Fee fee = 2;
|
||||
bytes new_field_3 = 3;
|
||||
bytes new_field_1024 = 1024;
|
||||
}
|
||||
|
|
|
@ -1,66 +0,0 @@
|
|||
package direct
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
|
||||
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
types "github.com/cosmos/cosmos-sdk/types/tx"
|
||||
"github.com/cosmos/cosmos-sdk/x/auth/signing"
|
||||
)
|
||||
|
||||
// ProtoTx defines an interface which protobuf transactions must implement for
|
||||
// signature verification via SignModeDirect
|
||||
type ProtoTx interface {
|
||||
// GetBodyBytes returns the raw serialized bytes for TxBody
|
||||
GetBodyBytes() []byte
|
||||
|
||||
// GetBodyBytes returns the raw serialized bytes for AuthInfo
|
||||
GetAuthInfoBytes() []byte
|
||||
}
|
||||
|
||||
// ModeHandler defines the SIGN_MODE_DIRECT SignModeHandler
|
||||
type ModeHandler struct{}
|
||||
|
||||
var _ signing.SignModeHandler = ModeHandler{}
|
||||
|
||||
// DefaultMode implements SignModeHandler.DefaultMode
|
||||
func (ModeHandler) DefaultMode() signingtypes.SignMode {
|
||||
return signingtypes.SignMode_SIGN_MODE_DIRECT
|
||||
}
|
||||
|
||||
// Modes implements SignModeHandler.Modes
|
||||
func (ModeHandler) Modes() []signingtypes.SignMode {
|
||||
return []signingtypes.SignMode{signingtypes.SignMode_SIGN_MODE_DIRECT}
|
||||
}
|
||||
|
||||
// GetSignBytes implements SignModeHandler.GetSignBytes
|
||||
func (ModeHandler) GetSignBytes(mode signingtypes.SignMode, data signing.SignerData, tx sdk.Tx) ([]byte, error) {
|
||||
if mode != signingtypes.SignMode_SIGN_MODE_DIRECT {
|
||||
return nil, fmt.Errorf("expected %s, got %s", signingtypes.SignMode_SIGN_MODE_DIRECT, mode)
|
||||
}
|
||||
|
||||
protoTx, ok := tx.(ProtoTx)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("can only get direct sign bytes for a ProtoTx, got %T", tx)
|
||||
}
|
||||
|
||||
bodyBz := protoTx.GetBodyBytes()
|
||||
authInfoBz := protoTx.GetAuthInfoBytes()
|
||||
|
||||
return SignBytes(bodyBz, authInfoBz, data.ChainID, data.AccountNumber, data.AccountSequence)
|
||||
}
|
||||
|
||||
// SignBytes returns the SIGN_MODE_DIRECT sign bytes for the provided TxBody bytes, AuthInfo bytes, chain ID,
|
||||
// account number and sequence.
|
||||
func SignBytes(bodyBytes, authInfoBytes []byte, chainID string, accnum, sequence uint64) ([]byte, error) {
|
||||
signDoc := types.SignDoc{
|
||||
BodyBytes: bodyBytes,
|
||||
AuthInfoBytes: authInfoBytes,
|
||||
ChainId: chainID,
|
||||
AccountNumber: accnum,
|
||||
AccountSequence: sequence,
|
||||
}
|
||||
return signDoc.Marshal()
|
||||
}
|
|
@ -18,7 +18,7 @@ func MakeTestHandlerMap() signing.SignModeHandler {
|
|||
return signing.NewSignModeHandlerMap(
|
||||
signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON,
|
||||
[]signing.SignModeHandler{
|
||||
authtypes.LegacyAminoJSONHandler{},
|
||||
authtypes.NewStdTxSignModeHandler(),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -58,7 +58,7 @@ func TestHandlerMap_GetSignBytes(t *testing.T) {
|
|||
)
|
||||
|
||||
handler := MakeTestHandlerMap()
|
||||
aminoJSONHandler := authtypes.LegacyAminoJSONHandler{}
|
||||
aminoJSONHandler := authtypes.NewStdTxSignModeHandler()
|
||||
|
||||
signingData := signing.SignerData{
|
||||
ChainID: chainId,
|
||||
|
|
|
@ -15,7 +15,6 @@ import (
|
|||
"github.com/cosmos/cosmos-sdk/types/tx"
|
||||
"github.com/cosmos/cosmos-sdk/types/tx/signing"
|
||||
authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing"
|
||||
"github.com/cosmos/cosmos-sdk/x/auth/signing/direct"
|
||||
)
|
||||
|
||||
type builder struct {
|
||||
|
@ -34,12 +33,13 @@ type builder struct {
|
|||
pubKeys []crypto.PubKey
|
||||
|
||||
pubkeyCodec types.PublicKeyCodec
|
||||
|
||||
txBodyHasUnknownNonCriticals bool
|
||||
}
|
||||
|
||||
var (
|
||||
_ authsigning.SigFeeMemoTx = &builder{}
|
||||
_ client.TxBuilder = &builder{}
|
||||
_ direct.ProtoTx = &builder{}
|
||||
)
|
||||
|
||||
func newBuilder(pubkeyCodec types.PublicKeyCodec) *builder {
|
||||
|
@ -122,7 +122,7 @@ func (t *builder) ValidateBasic() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t *builder) GetBodyBytes() []byte {
|
||||
func (t *builder) getBodyBytes() []byte {
|
||||
if len(t.bodyBz) == 0 {
|
||||
// if bodyBz is empty, then marshal the body. bodyBz will generally
|
||||
// be set to nil whenever SetBody is called so the result of calling
|
||||
|
@ -138,7 +138,7 @@ func (t *builder) GetBodyBytes() []byte {
|
|||
return t.bodyBz
|
||||
}
|
||||
|
||||
func (t *builder) GetAuthInfoBytes() []byte {
|
||||
func (t *builder) getAuthInfoBytes() []byte {
|
||||
if len(t.authInfoBz) == 0 {
|
||||
// if authInfoBz is empty, then marshal the body. authInfoBz will generally
|
||||
// be set to nil whenever SetAuthInfo is called so the result of calling
|
||||
|
|
|
@ -53,7 +53,7 @@ func TestTxBuilder(t *testing.T) {
|
|||
|
||||
fee := txtypes.Fee{Amount: sdk.NewCoins(sdk.NewInt64Coin("atom", 150)), GasLimit: 20000}
|
||||
|
||||
t.Log("verify that authInfo bytes encoded with DefaultTxEncoder and decoded with DefaultTxDecoder can be retrieved from GetAuthInfoBytes")
|
||||
t.Log("verify that authInfo bytes encoded with DefaultTxEncoder and decoded with DefaultTxDecoder can be retrieved from getAuthInfoBytes")
|
||||
authInfo := &txtypes.AuthInfo{
|
||||
Fee: &fee,
|
||||
SignerInfos: signerInfo,
|
||||
|
@ -63,7 +63,7 @@ func TestTxBuilder(t *testing.T) {
|
|||
|
||||
require.NotEmpty(t, authInfoBytes)
|
||||
|
||||
t.Log("verify that body bytes encoded with DefaultTxEncoder and decoded with DefaultTxDecoder can be retrieved from GetBodyBytes")
|
||||
t.Log("verify that body bytes encoded with DefaultTxEncoder and decoded with DefaultTxDecoder can be retrieved from getBodyBytes")
|
||||
anys := make([]*codectypes.Any, len(msgs))
|
||||
|
||||
for i, msg := range msgs {
|
||||
|
@ -80,29 +80,29 @@ func TestTxBuilder(t *testing.T) {
|
|||
}
|
||||
bodyBytes := marshaler.MustMarshalBinaryBare(txBody)
|
||||
require.NotEmpty(t, bodyBytes)
|
||||
require.Empty(t, txBuilder.GetBodyBytes())
|
||||
require.Empty(t, txBuilder.getBodyBytes())
|
||||
|
||||
t.Log("verify that calling the SetMsgs, SetMemo results in the correct GetBodyBytes")
|
||||
require.NotEqual(t, bodyBytes, txBuilder.GetBodyBytes())
|
||||
t.Log("verify that calling the SetMsgs, SetMemo results in the correct getBodyBytes")
|
||||
require.NotEqual(t, bodyBytes, txBuilder.getBodyBytes())
|
||||
err = txBuilder.SetMsgs(msgs...)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, bodyBytes, txBuilder.GetBodyBytes())
|
||||
require.NotEqual(t, bodyBytes, txBuilder.getBodyBytes())
|
||||
txBuilder.SetMemo(memo)
|
||||
require.Equal(t, bodyBytes, txBuilder.GetBodyBytes())
|
||||
require.Equal(t, bodyBytes, txBuilder.getBodyBytes())
|
||||
require.Equal(t, len(msgs), len(txBuilder.GetMsgs()))
|
||||
require.Equal(t, 0, len(txBuilder.GetPubKeys()))
|
||||
|
||||
t.Log("verify that updated AuthInfo results in the correct GetAuthInfoBytes and GetPubKeys")
|
||||
require.NotEqual(t, authInfoBytes, txBuilder.GetAuthInfoBytes())
|
||||
t.Log("verify that updated AuthInfo results in the correct getAuthInfoBytes and GetPubKeys")
|
||||
require.NotEqual(t, authInfoBytes, txBuilder.getAuthInfoBytes())
|
||||
txBuilder.SetFeeAmount(fee.Amount)
|
||||
require.NotEqual(t, authInfoBytes, txBuilder.GetAuthInfoBytes())
|
||||
require.NotEqual(t, authInfoBytes, txBuilder.getAuthInfoBytes())
|
||||
txBuilder.SetGasLimit(fee.GasLimit)
|
||||
require.NotEqual(t, authInfoBytes, txBuilder.GetAuthInfoBytes())
|
||||
require.NotEqual(t, authInfoBytes, txBuilder.getAuthInfoBytes())
|
||||
err = txBuilder.SetSignatures(sig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// once fee, gas and signerInfos are all set, AuthInfo bytes should match
|
||||
require.Equal(t, authInfoBytes, txBuilder.GetAuthInfoBytes())
|
||||
require.Equal(t, authInfoBytes, txBuilder.getAuthInfoBytes())
|
||||
|
||||
require.Equal(t, len(msgs), len(txBuilder.GetMsgs()))
|
||||
require.Equal(t, 1, len(txBuilder.GetPubKeys()))
|
||||
|
@ -230,24 +230,3 @@ func TestBuilderValidateBasic(t *testing.T) {
|
|||
err = txBuilder.ValidateBasic()
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDefaultTxDecoderError(t *testing.T) {
|
||||
registry := codectypes.NewInterfaceRegistry()
|
||||
pubKeyCdc := std.DefaultPublicKeyCodec{}
|
||||
encoder := DefaultTxEncoder()
|
||||
decoder := DefaultTxDecoder(registry, pubKeyCdc)
|
||||
|
||||
builder := newBuilder(pubKeyCdc)
|
||||
err := builder.SetMsgs(testdata.NewTestMsg())
|
||||
require.NoError(t, err)
|
||||
|
||||
txBz, err := encoder(builder.GetTx())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = decoder(txBz)
|
||||
require.EqualError(t, err, "no registered implementations of type types.Msg: tx parse error")
|
||||
|
||||
registry.RegisterImplementations((*sdk.Msg)(nil), &testdata.TestMsg{})
|
||||
_, err = decoder(txBz)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
|
|
@ -3,6 +3,8 @@ package tx
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/codec"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/client"
|
||||
|
@ -11,7 +13,7 @@ import (
|
|||
"github.com/cosmos/cosmos-sdk/x/auth/signing"
|
||||
)
|
||||
|
||||
type generator struct {
|
||||
type config struct {
|
||||
pubkeyCodec types.PublicKeyCodec
|
||||
handler signing.SignModeHandler
|
||||
decoder sdk.TxDecoder
|
||||
|
@ -21,11 +23,12 @@ type generator struct {
|
|||
protoCodec *codec.ProtoCodec
|
||||
}
|
||||
|
||||
// NewTxConfig returns a new protobuf TxConfig using the provided ProtoCodec, PublicKeyCodec and SignModeHandler.
|
||||
func NewTxConfig(protoCodec *codec.ProtoCodec, pubkeyCodec types.PublicKeyCodec, signModeHandler signing.SignModeHandler) client.TxConfig {
|
||||
return &generator{
|
||||
// NewTxConfig returns a new protobuf TxConfig using the provided ProtoCodec, PublicKeyCodec and sign modes. The
|
||||
// first enabled sign mode will become the default sign mode.
|
||||
func NewTxConfig(protoCodec *codec.ProtoCodec, pubkeyCodec types.PublicKeyCodec, enabledSignModes []signingtypes.SignMode) client.TxConfig {
|
||||
return &config{
|
||||
pubkeyCodec: pubkeyCodec,
|
||||
handler: signModeHandler,
|
||||
handler: makeSignModeHandler(enabledSignModes),
|
||||
decoder: DefaultTxDecoder(protoCodec, pubkeyCodec),
|
||||
encoder: DefaultTxEncoder(),
|
||||
jsonDecoder: DefaultJSONTxDecoder(protoCodec, pubkeyCodec),
|
||||
|
@ -34,12 +37,12 @@ func NewTxConfig(protoCodec *codec.ProtoCodec, pubkeyCodec types.PublicKeyCodec,
|
|||
}
|
||||
}
|
||||
|
||||
func (g generator) NewTxBuilder() client.TxBuilder {
|
||||
func (g config) NewTxBuilder() client.TxBuilder {
|
||||
return newBuilder(g.pubkeyCodec)
|
||||
}
|
||||
|
||||
// WrapTxBuilder returns a builder from provided transaction
|
||||
func (g generator) WrapTxBuilder(newTx sdk.Tx) (client.TxBuilder, error) {
|
||||
func (g config) WrapTxBuilder(newTx sdk.Tx) (client.TxBuilder, error) {
|
||||
newBuilder, ok := newTx.(*builder)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected %T, got %T", &builder{}, newTx)
|
||||
|
@ -48,22 +51,22 @@ func (g generator) WrapTxBuilder(newTx sdk.Tx) (client.TxBuilder, error) {
|
|||
return newBuilder, nil
|
||||
}
|
||||
|
||||
func (g generator) SignModeHandler() signing.SignModeHandler {
|
||||
func (g config) SignModeHandler() signing.SignModeHandler {
|
||||
return g.handler
|
||||
}
|
||||
|
||||
func (g generator) TxEncoder() sdk.TxEncoder {
|
||||
func (g config) TxEncoder() sdk.TxEncoder {
|
||||
return g.encoder
|
||||
}
|
||||
|
||||
func (g generator) TxDecoder() sdk.TxDecoder {
|
||||
func (g config) TxDecoder() sdk.TxDecoder {
|
||||
return g.decoder
|
||||
}
|
||||
|
||||
func (g generator) TxJSONEncoder() sdk.TxEncoder {
|
||||
func (g config) TxJSONEncoder() sdk.TxEncoder {
|
||||
return g.jsonEncoder
|
||||
}
|
||||
|
||||
func (g generator) TxJSONDecoder() sdk.TxDecoder {
|
||||
func (g config) TxJSONDecoder() sdk.TxDecoder {
|
||||
return g.jsonDecoder
|
||||
}
|
|
@ -20,6 +20,5 @@ func TestGenerator(t *testing.T) {
|
|||
interfaceRegistry.RegisterImplementations((*sdk.Msg)(nil), &testdata.TestMsg{})
|
||||
marshaler := codec.NewProtoCodec(interfaceRegistry)
|
||||
pubKeyCodec := std.DefaultPublicKeyCodec{}
|
||||
signModeHandler := DefaultSignModeHandler()
|
||||
suite.Run(t, testutil.NewTxConfigTestSuite(NewTxConfig(marshaler, pubKeyCodec, signModeHandler)))
|
||||
suite.Run(t, testutil.NewTxConfigTestSuite(NewTxConfig(marshaler, pubKeyCodec, DefaultSignModes)))
|
||||
}
|
|
@ -3,8 +3,9 @@ package tx
|
|||
import (
|
||||
"github.com/tendermint/tendermint/crypto"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/codec/unknownproto"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/codec"
|
||||
"github.com/cosmos/cosmos-sdk/codec/types"
|
||||
cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types"
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
|
||||
|
@ -12,39 +13,71 @@ import (
|
|||
)
|
||||
|
||||
// DefaultTxDecoder returns a default protobuf TxDecoder using the provided Marshaler and PublicKeyCodec
|
||||
func DefaultTxDecoder(anyUnpacker types.AnyUnpacker, keyCodec cryptotypes.PublicKeyCodec) sdk.TxDecoder {
|
||||
cdc := codec.NewProtoCodec(anyUnpacker)
|
||||
func DefaultTxDecoder(cdc *codec.ProtoCodec, keyCodec cryptotypes.PublicKeyCodec) sdk.TxDecoder {
|
||||
return func(txBytes []byte) (sdk.Tx, error) {
|
||||
var raw tx.TxRaw
|
||||
err := cdc.UnmarshalBinaryBare(txBytes, &raw)
|
||||
|
||||
// reject all unknown proto fields in the root TxRaw
|
||||
err := unknownproto.RejectUnknownFieldsStrict(txBytes, &raw)
|
||||
if err != nil {
|
||||
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
|
||||
}
|
||||
|
||||
var theTx tx.Tx
|
||||
err = cdc.UnmarshalBinaryBare(txBytes, &theTx)
|
||||
err = cdc.UnmarshalBinaryBare(txBytes, &raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var body tx.TxBody
|
||||
|
||||
// allow non-critical unknown fields in TxBody
|
||||
txBodyHasUnknownNonCriticals, err := unknownproto.RejectUnknownFields(raw.BodyBytes, &body, true)
|
||||
if err != nil {
|
||||
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
|
||||
}
|
||||
|
||||
err = cdc.UnmarshalBinaryBare(raw.BodyBytes, &body)
|
||||
if err != nil {
|
||||
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
|
||||
}
|
||||
|
||||
var authInfo tx.AuthInfo
|
||||
|
||||
// reject all unknown proto fields in AuthInfo
|
||||
err = unknownproto.RejectUnknownFieldsStrict(raw.AuthInfoBytes, &authInfo)
|
||||
if err != nil {
|
||||
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
|
||||
}
|
||||
|
||||
err = cdc.UnmarshalBinaryBare(raw.AuthInfoBytes, &authInfo)
|
||||
if err != nil {
|
||||
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
|
||||
}
|
||||
|
||||
theTx := &tx.Tx{
|
||||
Body: &body,
|
||||
AuthInfo: &authInfo,
|
||||
Signatures: raw.Signatures,
|
||||
}
|
||||
|
||||
pks, err := extractPubKeys(theTx, keyCodec)
|
||||
if err != nil {
|
||||
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
|
||||
}
|
||||
|
||||
return &builder{
|
||||
tx: &theTx,
|
||||
bodyBz: raw.BodyBytes,
|
||||
authInfoBz: raw.AuthInfoBytes,
|
||||
pubKeys: pks,
|
||||
pubkeyCodec: keyCodec,
|
||||
tx: theTx,
|
||||
bodyBz: raw.BodyBytes,
|
||||
authInfoBz: raw.AuthInfoBytes,
|
||||
pubKeys: pks,
|
||||
pubkeyCodec: keyCodec,
|
||||
txBodyHasUnknownNonCriticals: txBodyHasUnknownNonCriticals,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultTxDecoder returns a default protobuf JSON TxDecoder using the provided Marshaler and PublicKeyCodec
|
||||
func DefaultJSONTxDecoder(anyUnpacker types.AnyUnpacker, keyCodec cryptotypes.PublicKeyCodec) sdk.TxDecoder {
|
||||
cdc := codec.NewProtoCodec(anyUnpacker)
|
||||
func DefaultJSONTxDecoder(cdc *codec.ProtoCodec, keyCodec cryptotypes.PublicKeyCodec) sdk.TxDecoder {
|
||||
return func(txBytes []byte) (sdk.Tx, error) {
|
||||
var theTx tx.Tx
|
||||
err := cdc.UnmarshalJSON(txBytes, &theTx)
|
||||
|
@ -52,7 +85,7 @@ func DefaultJSONTxDecoder(anyUnpacker types.AnyUnpacker, keyCodec cryptotypes.Pu
|
|||
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
|
||||
}
|
||||
|
||||
pks, err := extractPubKeys(theTx, keyCodec)
|
||||
pks, err := extractPubKeys(&theTx, keyCodec)
|
||||
if err != nil {
|
||||
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
|
||||
}
|
||||
|
@ -65,7 +98,7 @@ func DefaultJSONTxDecoder(anyUnpacker types.AnyUnpacker, keyCodec cryptotypes.Pu
|
|||
}
|
||||
}
|
||||
|
||||
func extractPubKeys(tx tx.Tx, keyCodec cryptotypes.PublicKeyCodec) ([]crypto.PubKey, error) {
|
||||
func extractPubKeys(tx *tx.Tx, keyCodec cryptotypes.PublicKeyCodec) ([]crypto.PubKey, error) {
|
||||
if tx.AuthInfo == nil {
|
||||
return []crypto.PubKey{}, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
package tx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
|
||||
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
types "github.com/cosmos/cosmos-sdk/types/tx"
|
||||
"github.com/cosmos/cosmos-sdk/x/auth/signing"
|
||||
)
|
||||
|
||||
// signModeDirectHandler defines the SIGN_MODE_DIRECT SignModeHandler
|
||||
type signModeDirectHandler struct{}
|
||||
|
||||
var _ signing.SignModeHandler = signModeDirectHandler{}
|
||||
|
||||
// DefaultMode implements SignModeHandler.DefaultMode
|
||||
func (signModeDirectHandler) DefaultMode() signingtypes.SignMode {
|
||||
return signingtypes.SignMode_SIGN_MODE_DIRECT
|
||||
}
|
||||
|
||||
// Modes implements SignModeHandler.Modes
|
||||
func (signModeDirectHandler) Modes() []signingtypes.SignMode {
|
||||
return []signingtypes.SignMode{signingtypes.SignMode_SIGN_MODE_DIRECT}
|
||||
}
|
||||
|
||||
// GetSignBytes implements SignModeHandler.GetSignBytes
|
||||
func (signModeDirectHandler) GetSignBytes(mode signingtypes.SignMode, data signing.SignerData, tx sdk.Tx) ([]byte, error) {
|
||||
if mode != signingtypes.SignMode_SIGN_MODE_DIRECT {
|
||||
return nil, fmt.Errorf("expected %s, got %s", signingtypes.SignMode_SIGN_MODE_DIRECT, mode)
|
||||
}
|
||||
|
||||
protoTx, ok := tx.(*builder)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("can only handle a protobuf Tx, got %T", tx)
|
||||
}
|
||||
|
||||
bodyBz := protoTx.getBodyBytes()
|
||||
authInfoBz := protoTx.getAuthInfoBytes()
|
||||
|
||||
return DirectSignBytes(bodyBz, authInfoBz, data.ChainID, data.AccountNumber, data.AccountSequence)
|
||||
}
|
||||
|
||||
// DirectSignBytes returns the SIGN_MODE_DIRECT sign bytes for the provided TxBody bytes, AuthInfo bytes, chain ID,
|
||||
// account number and sequence.
|
||||
func DirectSignBytes(bodyBytes, authInfoBytes []byte, chainID string, accnum, sequence uint64) ([]byte, error) {
|
||||
signDoc := types.SignDoc{
|
||||
BodyBytes: bodyBytes,
|
||||
AuthInfoBytes: authInfoBytes,
|
||||
ChainId: chainID,
|
||||
AccountNumber: accnum,
|
||||
AccountSequence: sequence,
|
||||
}
|
||||
return signDoc.Marshal()
|
||||
}
|
|
@ -1,13 +1,9 @@
|
|||
package direct_test
|
||||
package tx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/x/auth/tx"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/x/auth/signing/direct"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/testutil/testdata"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/codec"
|
||||
|
@ -30,8 +26,8 @@ func TestDirectModeHandler(t *testing.T) {
|
|||
marshaler := codec.NewProtoCodec(interfaceRegistry)
|
||||
pubKeyCdc := std.DefaultPublicKeyCodec{}
|
||||
|
||||
txGen := tx.NewTxConfig(marshaler, pubKeyCdc, tx.DefaultSignModeHandler())
|
||||
txBuilder := txGen.NewTxBuilder()
|
||||
txConfig := NewTxConfig(marshaler, pubKeyCdc, []signingtypes.SignMode{signingtypes.SignMode_SIGN_MODE_DIRECT})
|
||||
txBuilder := txConfig.NewTxBuilder()
|
||||
|
||||
memo := "sometestmemo"
|
||||
msgs := []sdk.Msg{testdata.NewTestMsg(addr)}
|
||||
|
@ -71,9 +67,9 @@ func TestDirectModeHandler(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
t.Log("verify modes and default-mode")
|
||||
directModeHandler := direct.ModeHandler{}
|
||||
require.Equal(t, directModeHandler.DefaultMode(), signingtypes.SignMode_SIGN_MODE_DIRECT)
|
||||
require.Len(t, directModeHandler.Modes(), 1)
|
||||
modeHandler := txConfig.SignModeHandler()
|
||||
require.Equal(t, modeHandler.DefaultMode(), signingtypes.SignMode_SIGN_MODE_DIRECT)
|
||||
require.Len(t, modeHandler.Modes(), 1)
|
||||
|
||||
signingData := signing.SignerData{
|
||||
ChainID: "test-chain",
|
||||
|
@ -81,7 +77,7 @@ func TestDirectModeHandler(t *testing.T) {
|
|||
AccountSequence: 1,
|
||||
}
|
||||
|
||||
signBytes, err := directModeHandler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_DIRECT, signingData, txBuilder.GetTx())
|
||||
signBytes, err := modeHandler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_DIRECT, signingData, txBuilder.GetTx())
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, signBytes)
|
||||
|
@ -127,7 +123,7 @@ func TestDirectModeHandler(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
err = txBuilder.SetSignatures(sig)
|
||||
require.NoError(t, err)
|
||||
signBytes, err = directModeHandler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_DIRECT, signingData, txBuilder.GetTx())
|
||||
signBytes, err = modeHandler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_DIRECT, signingData, txBuilder.GetTx())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedSignBytes, signBytes)
|
||||
|
||||
|
@ -146,7 +142,7 @@ func TestDirectModeHandler_nonDIRECT_MODE(t *testing.T) {
|
|||
}
|
||||
for _, invalidMode := range invalidModes {
|
||||
t.Run(invalidMode.String(), func(t *testing.T) {
|
||||
var dh direct.ModeHandler
|
||||
var dh signModeDirectHandler
|
||||
var signingData signing.SignerData
|
||||
_, err := dh.GetSignBytes(invalidMode, signingData, nil)
|
||||
require.Error(t, err)
|
||||
|
@ -164,11 +160,11 @@ func (npt *nonProtoTx) ValidateBasic() error { return nil }
|
|||
var _ sdk.Tx = (*nonProtoTx)(nil)
|
||||
|
||||
func TestDirectModeHandler_nonProtoTx(t *testing.T) {
|
||||
var dh direct.ModeHandler
|
||||
var dh signModeDirectHandler
|
||||
var signingData signing.SignerData
|
||||
tx := new(nonProtoTx)
|
||||
_, err := dh.GetSignBytes(signingtypes.SignMode_SIGN_MODE_DIRECT, signingData, tx)
|
||||
require.Error(t, err)
|
||||
wantErr := fmt.Errorf("can only get direct sign bytes for a ProtoTx, got %T", tx)
|
||||
wantErr := fmt.Errorf("can only handle a protobuf Tx, got %T", tx)
|
||||
require.Equal(t, err, wantErr)
|
||||
}
|
|
@ -0,0 +1,165 @@
|
|||
package tx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/types/tx"
|
||||
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
|
||||
"github.com/cosmos/cosmos-sdk/x/auth/signing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/codec"
|
||||
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
|
||||
"github.com/cosmos/cosmos-sdk/std"
|
||||
"github.com/cosmos/cosmos-sdk/testutil/testdata"
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
)
|
||||
|
||||
func TestDefaultTxDecoderError(t *testing.T) {
|
||||
registry := codectypes.NewInterfaceRegistry()
|
||||
cdc := codec.NewProtoCodec(registry)
|
||||
pubKeyCdc := std.DefaultPublicKeyCodec{}
|
||||
encoder := DefaultTxEncoder()
|
||||
decoder := DefaultTxDecoder(cdc, pubKeyCdc)
|
||||
|
||||
builder := newBuilder(pubKeyCdc)
|
||||
err := builder.SetMsgs(testdata.NewTestMsg())
|
||||
require.NoError(t, err)
|
||||
|
||||
txBz, err := encoder(builder.GetTx())
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = decoder(txBz)
|
||||
require.EqualError(t, err, "no registered implementations of type types.Msg: tx parse error")
|
||||
|
||||
registry.RegisterImplementations((*sdk.Msg)(nil), &testdata.TestMsg{})
|
||||
_, err = decoder(txBz)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestUnknownFields(t *testing.T) {
|
||||
registry := codectypes.NewInterfaceRegistry()
|
||||
cdc := codec.NewProtoCodec(registry)
|
||||
pubKeyCdc := std.DefaultPublicKeyCodec{}
|
||||
decoder := DefaultTxDecoder(cdc, pubKeyCdc)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body *testdata.TestUpdatedTxBody
|
||||
authInfo *testdata.TestUpdatedAuthInfo
|
||||
shouldErr bool
|
||||
shouldAminoErr string
|
||||
}{
|
||||
{
|
||||
name: "no new fields should pass",
|
||||
body: &testdata.TestUpdatedTxBody{
|
||||
Memo: "foo",
|
||||
},
|
||||
authInfo: &testdata.TestUpdatedAuthInfo{},
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-critical fields in TxBody should not error on decode, but should error with amino",
|
||||
body: &testdata.TestUpdatedTxBody{
|
||||
Memo: "foo",
|
||||
SomeNewFieldNonCriticalField: "blah",
|
||||
},
|
||||
authInfo: &testdata.TestUpdatedAuthInfo{},
|
||||
shouldErr: false,
|
||||
shouldAminoErr: fmt.Sprintf("%s: %s", aminoNonCriticalFieldsError, sdkerrors.ErrInvalidRequest.Error()),
|
||||
},
|
||||
{
|
||||
name: "critical fields in TxBody should error on decode",
|
||||
body: &testdata.TestUpdatedTxBody{
|
||||
Memo: "foo",
|
||||
SomeNewField: 10,
|
||||
},
|
||||
authInfo: &testdata.TestUpdatedAuthInfo{},
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "critical fields in AuthInfo should error on decode",
|
||||
body: &testdata.TestUpdatedTxBody{
|
||||
Memo: "foo",
|
||||
},
|
||||
authInfo: &testdata.TestUpdatedAuthInfo{
|
||||
NewField_3: []byte("xyz"),
|
||||
},
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-critical fields in AuthInfo should error on decode",
|
||||
body: &testdata.TestUpdatedTxBody{
|
||||
Memo: "foo",
|
||||
},
|
||||
authInfo: &testdata.TestUpdatedAuthInfo{
|
||||
NewField_1024: []byte("xyz"),
|
||||
},
|
||||
shouldErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
bodyBz, err := tt.body.Marshal()
|
||||
require.NoError(t, err)
|
||||
|
||||
authInfoBz, err := tt.authInfo.Marshal()
|
||||
require.NoError(t, err)
|
||||
|
||||
txRaw := &tx.TxRaw{
|
||||
BodyBytes: bodyBz,
|
||||
AuthInfoBytes: authInfoBz,
|
||||
}
|
||||
txBz, err := txRaw.Marshal()
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = decoder(txBz)
|
||||
if tt.shouldErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
if tt.shouldAminoErr != "" {
|
||||
handler := signModeLegacyAminoJSONHandler{}
|
||||
decoder := DefaultTxDecoder(codec.NewProtoCodec(codectypes.NewInterfaceRegistry()), std.DefaultPublicKeyCodec{})
|
||||
theTx, err := decoder(txBz)
|
||||
require.NoError(t, err)
|
||||
_, err = handler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, signing.SignerData{}, theTx)
|
||||
require.EqualError(t, err, tt.shouldAminoErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Log("test TxRaw no new fields, should succeed")
|
||||
txRaw := &testdata.TestUpdatedTxRaw{}
|
||||
txBz, err := txRaw.Marshal()
|
||||
require.NoError(t, err)
|
||||
_, err = decoder(txBz)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Log("new field in TxRaw should fail")
|
||||
txRaw = &testdata.TestUpdatedTxRaw{
|
||||
NewField_5: []byte("abc"),
|
||||
}
|
||||
txBz, err = txRaw.Marshal()
|
||||
require.NoError(t, err)
|
||||
_, err = decoder(txBz)
|
||||
require.Error(t, err)
|
||||
|
||||
//
|
||||
t.Log("new \"non-critical\" field in TxRaw should fail")
|
||||
txRaw = &testdata.TestUpdatedTxRaw{
|
||||
NewField_1024: []byte("abc"),
|
||||
}
|
||||
txBz, err = txRaw.Marshal()
|
||||
require.NoError(t, err)
|
||||
_, err = decoder(txBz)
|
||||
require.Error(t, err)
|
||||
}
|
|
@ -19,8 +19,8 @@ func DefaultTxEncoder() types.TxEncoder {
|
|||
}
|
||||
|
||||
raw := &txtypes.TxRaw{
|
||||
BodyBytes: wrapper.GetBodyBytes(),
|
||||
AuthInfoBytes: wrapper.GetAuthInfoBytes(),
|
||||
BodyBytes: wrapper.getBodyBytes(),
|
||||
AuthInfoBytes: wrapper.getAuthInfoBytes(),
|
||||
Signatures: wrapper.tx.Signatures,
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
package tx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
|
||||
"github.com/cosmos/cosmos-sdk/x/auth/types"
|
||||
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
|
||||
"github.com/cosmos/cosmos-sdk/x/auth/signing"
|
||||
)
|
||||
|
||||
// signModeLegacyAminoJSONHandler defines the SIGN_MODE_LEGACY_AMINO_JSON SignModeHandler
|
||||
type signModeLegacyAminoJSONHandler struct{}
|
||||
|
||||
func (s signModeLegacyAminoJSONHandler) DefaultMode() signingtypes.SignMode {
|
||||
return signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON
|
||||
}
|
||||
|
||||
func (s signModeLegacyAminoJSONHandler) Modes() []signingtypes.SignMode {
|
||||
return []signingtypes.SignMode{signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON}
|
||||
}
|
||||
|
||||
const aminoNonCriticalFieldsError = "protobuf transaction contains unknown non-critical fields. This is a transaction malleability issue and SIGN_MODE_LEGACY_AMINO_JSON cannot be used."
|
||||
|
||||
func (s signModeLegacyAminoJSONHandler) GetSignBytes(mode signingtypes.SignMode, data signing.SignerData, tx sdk.Tx) ([]byte, error) {
|
||||
if mode != signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON {
|
||||
return nil, fmt.Errorf("expected %s, got %s", signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, mode)
|
||||
}
|
||||
|
||||
protoTx, ok := tx.(*builder)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("can only handle a protobuf Tx, got %T", tx)
|
||||
}
|
||||
|
||||
if protoTx.txBodyHasUnknownNonCriticals {
|
||||
return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidRequest, aminoNonCriticalFieldsError)
|
||||
}
|
||||
|
||||
body := protoTx.tx.Body
|
||||
|
||||
if len(body.ExtensionOptions) != 0 || len(body.NonCriticalExtensionOptions) != 0 {
|
||||
return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidRequest,
|
||||
"SIGN_MODE_LEGACY_AMINO_JSON does not support protobuf extension options.")
|
||||
}
|
||||
|
||||
if body.TimeoutHeight != 0 {
|
||||
return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidRequest,
|
||||
"SIGN_MODE_LEGACY_AMINO_JSON does not support timeout height.")
|
||||
}
|
||||
|
||||
return types.StdSignBytes(
|
||||
data.ChainID, data.AccountNumber, data.AccountSequence,
|
||||
//nolint:staticcheck
|
||||
types.StdFee{Amount: protoTx.GetFee(), Gas: protoTx.GetGas()},
|
||||
tx.GetMsgs(), protoTx.GetMemo(),
|
||||
), nil
|
||||
}
|
||||
|
||||
var _ signing.SignModeHandler = signModeLegacyAminoJSONHandler{}
|
|
@ -0,0 +1,101 @@
|
|||
package tx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
cdctypes "github.com/cosmos/cosmos-sdk/codec/types"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/std"
|
||||
"github.com/cosmos/cosmos-sdk/testutil/testdata"
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
|
||||
"github.com/cosmos/cosmos-sdk/x/auth/signing"
|
||||
"github.com/cosmos/cosmos-sdk/x/auth/types"
|
||||
)
|
||||
|
||||
var (
|
||||
_, _, addr1 = testdata.KeyTestPubAddr()
|
||||
_, _, addr2 = testdata.KeyTestPubAddr()
|
||||
|
||||
coins = sdk.Coins{sdk.NewInt64Coin("foocoin", 10)}
|
||||
gas = uint64(10000)
|
||||
msg = testdata.NewTestMsg(addr1, addr2)
|
||||
memo = "foo"
|
||||
)
|
||||
|
||||
func buildTx(t *testing.T, bldr *builder) {
|
||||
bldr.SetFeeAmount(coins)
|
||||
bldr.SetGasLimit(gas)
|
||||
bldr.SetMemo(memo)
|
||||
require.NoError(t, bldr.SetMsgs(msg))
|
||||
}
|
||||
|
||||
func TestLegacyAminoJSONHandler_GetSignBytes(t *testing.T) {
|
||||
bldr := newBuilder(std.DefaultPublicKeyCodec{})
|
||||
buildTx(t, bldr)
|
||||
tx := bldr.GetTx()
|
||||
|
||||
var (
|
||||
chainId = "test-chain"
|
||||
accNum uint64 = 7
|
||||
seqNum uint64 = 7
|
||||
)
|
||||
|
||||
handler := signModeLegacyAminoJSONHandler{}
|
||||
signingData := signing.SignerData{
|
||||
ChainID: chainId,
|
||||
AccountNumber: accNum,
|
||||
AccountSequence: seqNum,
|
||||
}
|
||||
signBz, err := handler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, signingData, tx)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedSignBz := types.StdSignBytes(chainId, accNum, seqNum, types.StdFee{
|
||||
Amount: coins,
|
||||
Gas: gas,
|
||||
}, []sdk.Msg{msg}, memo)
|
||||
|
||||
require.Equal(t, expectedSignBz, signBz)
|
||||
|
||||
// expect error with wrong sign mode
|
||||
_, err = handler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_DIRECT, signingData, tx)
|
||||
require.Error(t, err)
|
||||
|
||||
// expect error with timeout height
|
||||
bldr = newBuilder(std.DefaultPublicKeyCodec{})
|
||||
buildTx(t, bldr)
|
||||
bldr.tx.Body.TimeoutHeight = 10
|
||||
tx = bldr.GetTx()
|
||||
signBz, err = handler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, signingData, tx)
|
||||
require.Error(t, err)
|
||||
|
||||
// expect error with extension options
|
||||
bldr = newBuilder(std.DefaultPublicKeyCodec{})
|
||||
buildTx(t, bldr)
|
||||
any, err := cdctypes.NewAnyWithValue(testdata.NewTestMsg())
|
||||
require.NoError(t, err)
|
||||
bldr.tx.Body.ExtensionOptions = []*cdctypes.Any{any}
|
||||
tx = bldr.GetTx()
|
||||
signBz, err = handler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, signingData, tx)
|
||||
require.Error(t, err)
|
||||
|
||||
// expect error with non-critical extension options
|
||||
bldr = newBuilder(std.DefaultPublicKeyCodec{})
|
||||
buildTx(t, bldr)
|
||||
bldr.tx.Body.NonCriticalExtensionOptions = []*cdctypes.Any{any}
|
||||
tx = bldr.GetTx()
|
||||
signBz, err = handler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, signingData, tx)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestLegacyAminoJSONHandler_DefaultMode(t *testing.T) {
|
||||
handler := signModeLegacyAminoJSONHandler{}
|
||||
require.Equal(t, signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, handler.DefaultMode())
|
||||
}
|
||||
|
||||
func TestLegacyAminoJSONHandler_Modes(t *testing.T) {
|
||||
handler := signModeLegacyAminoJSONHandler{}
|
||||
require.Equal(t, []signingtypes.SignMode{signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON}, handler.Modes())
|
||||
}
|
|
@ -1,20 +1,40 @@
|
|||
package tx
|
||||
|
||||
import (
|
||||
signing2 "github.com/cosmos/cosmos-sdk/types/tx/signing"
|
||||
"fmt"
|
||||
|
||||
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
|
||||
"github.com/cosmos/cosmos-sdk/x/auth/signing"
|
||||
"github.com/cosmos/cosmos-sdk/x/auth/signing/direct"
|
||||
authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
|
||||
)
|
||||
|
||||
// DefaultSignModeHandler returns the default protobuf SignModeHandler supporting
|
||||
// DefaultSignModes are the default sign modes enabled for protobuf transactions.
|
||||
var DefaultSignModes = []signingtypes.SignMode{
|
||||
signingtypes.SignMode_SIGN_MODE_DIRECT,
|
||||
signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON,
|
||||
}
|
||||
|
||||
// makeSignModeHandler returns the default protobuf SignModeHandler supporting
|
||||
// SIGN_MODE_DIRECT and SIGN_MODE_LEGACY_AMINO_JSON.
|
||||
func DefaultSignModeHandler() signing.SignModeHandler {
|
||||
func makeSignModeHandler(modes []signingtypes.SignMode) signing.SignModeHandler {
|
||||
if len(modes) < 1 {
|
||||
panic(fmt.Errorf("no sign modes enabled"))
|
||||
}
|
||||
|
||||
handlers := make([]signing.SignModeHandler, len(modes))
|
||||
|
||||
for i, mode := range modes {
|
||||
switch mode {
|
||||
case signingtypes.SignMode_SIGN_MODE_DIRECT:
|
||||
handlers[i] = signModeDirectHandler{}
|
||||
case signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON:
|
||||
handlers[i] = signModeLegacyAminoJSONHandler{}
|
||||
default:
|
||||
panic(fmt.Errorf("unsupported sign mode %+v", mode))
|
||||
}
|
||||
}
|
||||
|
||||
return signing.NewSignModeHandlerMap(
|
||||
signing2.SignMode_SIGN_MODE_DIRECT,
|
||||
[]signing.SignModeHandler{
|
||||
authtypes.LegacyAminoJSONHandler{},
|
||||
direct.ModeHandler{},
|
||||
},
|
||||
modes[0],
|
||||
handlers,
|
||||
)
|
||||
}
|
||||
|
|
|
@ -105,7 +105,7 @@ func decodeMultisignatures(bz []byte) ([][]byte, error) {
|
|||
return multisig.Signatures, nil
|
||||
}
|
||||
|
||||
func (g generator) MarshalSignatureJSON(sigs []signing.SignatureV2) ([]byte, error) {
|
||||
func (g config) MarshalSignatureJSON(sigs []signing.SignatureV2) ([]byte, error) {
|
||||
descs := make([]*signing.SignatureDescriptor, len(sigs))
|
||||
|
||||
for i, sig := range sigs {
|
||||
|
@ -127,7 +127,7 @@ func (g generator) MarshalSignatureJSON(sigs []signing.SignatureV2) ([]byte, err
|
|||
return codec.ProtoMarshalJSON(toJSON)
|
||||
}
|
||||
|
||||
func (g generator) UnmarshalSignatureJSON(bz []byte) ([]signing.SignatureV2, error) {
|
||||
func (g config) UnmarshalSignatureJSON(bz []byte) ([]signing.SignatureV2, error) {
|
||||
var sigDescs signing.SignatureDescriptors
|
||||
err := g.protoCodec.UnmarshalJSON(bz, &sigDescs)
|
||||
if err != nil {
|
||||
|
|
|
@ -8,38 +8,37 @@ import (
|
|||
"github.com/cosmos/cosmos-sdk/x/auth/signing"
|
||||
)
|
||||
|
||||
// LegacyAminoJSONHandler is a SignModeHandler that handles SIGN_MODE_LEGACY_AMINO_JSON
|
||||
type LegacyAminoJSONHandler struct{}
|
||||
// stdTxSignModeHandler is a SignModeHandler that handles SIGN_MODE_LEGACY_AMINO_JSON
|
||||
type stdTxSignModeHandler struct{}
|
||||
|
||||
var _ signing.SignModeHandler = LegacyAminoJSONHandler{}
|
||||
func NewStdTxSignModeHandler() signing.SignModeHandler {
|
||||
return &stdTxSignModeHandler{}
|
||||
}
|
||||
|
||||
var _ signing.SignModeHandler = stdTxSignModeHandler{}
|
||||
|
||||
// DefaultMode implements SignModeHandler.DefaultMode
|
||||
func (h LegacyAminoJSONHandler) DefaultMode() signingtypes.SignMode {
|
||||
func (h stdTxSignModeHandler) DefaultMode() signingtypes.SignMode {
|
||||
return signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON
|
||||
}
|
||||
|
||||
// Modes implements SignModeHandler.Modes
|
||||
func (LegacyAminoJSONHandler) Modes() []signingtypes.SignMode {
|
||||
func (stdTxSignModeHandler) Modes() []signingtypes.SignMode {
|
||||
return []signingtypes.SignMode{signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON}
|
||||
}
|
||||
|
||||
// DefaultMode implements SignModeHandler.GetSignBytes
|
||||
func (LegacyAminoJSONHandler) GetSignBytes(mode signingtypes.SignMode, data signing.SignerData, tx sdk.Tx) ([]byte, error) {
|
||||
func (stdTxSignModeHandler) GetSignBytes(mode signingtypes.SignMode, data signing.SignerData, tx sdk.Tx) ([]byte, error) {
|
||||
if mode != signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON {
|
||||
return nil, fmt.Errorf("expected %s, got %s", signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, mode)
|
||||
}
|
||||
|
||||
feeTx, ok := tx.(sdk.FeeTx)
|
||||
stdTx, ok := tx.(StdTx)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected FeeTx, got %T", tx)
|
||||
}
|
||||
|
||||
memoTx, ok := tx.(sdk.TxWithMemo)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected TxWithMemo, got %T", tx)
|
||||
return nil, fmt.Errorf("expected %T, got %T", StdTx{}, tx)
|
||||
}
|
||||
|
||||
return StdSignBytes(
|
||||
data.ChainID, data.AccountNumber, data.AccountSequence, StdFee{Amount: feeTx.GetFee(), Gas: feeTx.GetGas()}, tx.GetMsgs(), memoTx.GetMemo(),
|
||||
data.ChainID, data.AccountNumber, data.AccountSequence, StdFee{Amount: stdTx.GetFee(), Gas: stdTx.GetGas()}, tx.GetMsgs(), stdTx.GetMemo(),
|
||||
), nil
|
||||
}
|
||||
|
|
|
@ -1,18 +1,16 @@
|
|||
package types_test
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/x/auth/signing"
|
||||
"github.com/cosmos/cosmos-sdk/x/auth/types"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/tendermint/tendermint/crypto/secp256k1"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/testutil/testdata"
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
|
||||
banktypes "github.com/cosmos/cosmos-sdk/x/bank/types"
|
||||
"github.com/cosmos/cosmos-sdk/x/auth/signing"
|
||||
)
|
||||
|
||||
func TestLegacyAminoJSONHandler_GetSignBytes(t *testing.T) {
|
||||
|
@ -23,20 +21,16 @@ func TestLegacyAminoJSONHandler_GetSignBytes(t *testing.T) {
|
|||
|
||||
coins := sdk.Coins{sdk.NewInt64Coin("foocoin", 10)}
|
||||
|
||||
fee := types.StdFee{
|
||||
fee := StdFee{
|
||||
Amount: coins,
|
||||
Gas: 10000,
|
||||
}
|
||||
memo := "foo"
|
||||
msgs := []sdk.Msg{
|
||||
&banktypes.MsgSend{
|
||||
FromAddress: addr1,
|
||||
ToAddress: addr2,
|
||||
Amount: coins,
|
||||
},
|
||||
testdata.NewTestMsg(addr1, addr2),
|
||||
}
|
||||
|
||||
tx := types.StdTx{
|
||||
tx := StdTx{
|
||||
Msgs: msgs,
|
||||
Fee: fee,
|
||||
Signatures: nil,
|
||||
|
@ -49,7 +43,7 @@ func TestLegacyAminoJSONHandler_GetSignBytes(t *testing.T) {
|
|||
seqNum uint64 = 7
|
||||
)
|
||||
|
||||
handler := types.LegacyAminoJSONHandler{}
|
||||
handler := stdTxSignModeHandler{}
|
||||
signingData := signing.SignerData{
|
||||
ChainID: chainId,
|
||||
AccountNumber: accNum,
|
||||
|
@ -58,7 +52,7 @@ func TestLegacyAminoJSONHandler_GetSignBytes(t *testing.T) {
|
|||
signBz, err := handler.GetSignBytes(signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, signingData, tx)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedSignBz := types.StdSignBytes(chainId, accNum, seqNum, fee, msgs, memo)
|
||||
expectedSignBz := StdSignBytes(chainId, accNum, seqNum, fee, msgs, memo)
|
||||
|
||||
require.Equal(t, expectedSignBz, signBz)
|
||||
|
||||
|
@ -68,11 +62,11 @@ func TestLegacyAminoJSONHandler_GetSignBytes(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLegacyAminoJSONHandler_DefaultMode(t *testing.T) {
|
||||
handler := types.LegacyAminoJSONHandler{}
|
||||
handler := stdTxSignModeHandler{}
|
||||
require.Equal(t, signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON, handler.DefaultMode())
|
||||
}
|
||||
|
||||
func TestLegacyAminoJSONHandler_Modes(t *testing.T) {
|
||||
handler := types.LegacyAminoJSONHandler{}
|
||||
handler := stdTxSignModeHandler{}
|
||||
require.Equal(t, []signingtypes.SignMode{signingtypes.SignMode_SIGN_MODE_LEGACY_AMINO_JSON}, handler.Modes())
|
||||
}
|
||||
|
|
|
@ -146,5 +146,5 @@ func (s StdTxConfig) UnmarshalSignatureJSON(bz []byte) ([]signing.SignatureV2, e
|
|||
}
|
||||
|
||||
func (s StdTxConfig) SignModeHandler() authsigning.SignModeHandler {
|
||||
return LegacyAminoJSONHandler{}
|
||||
return stdTxSignModeHandler{}
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue