refactor: migrate ante hanlders to middlewares (#10028)

<!--
The default pull request template is for types feat, fix, or refactor.
For other templates, add one of the following parameters to the url:
- template=docs.md
- template=other.md
-->

## Description

Closes: #9585 

<!-- Add a description of the changes that this PR introduces and the files that
are the most critical to review. -->

---

### Author Checklist

*All items are required. Please add a note to the item if the item is not applicable and
please add links to any relevant follow up issues.*

I have...

- [ ] included the correct [type prefix](https://github.com/commitizen/conventional-commit-types/blob/v3.0.0/index.json) in the PR title
- [ ] added `!` to the type prefix if API or client breaking change
- [ ] targeted the correct branch (see [PR Targeting](https://github.com/cosmos/cosmos-sdk/blob/master/CONTRIBUTING.md#pr-targeting))
- [ ] provided a link to the relevant issue or specification
- [ ] followed the guidelines for [building modules](https://github.com/cosmos/cosmos-sdk/blob/master/docs/building-modules)
- [ ] included the necessary unit and integration [tests](https://github.com/cosmos/cosmos-sdk/blob/master/CONTRIBUTING.md#testing)
- [ ] added a changelog entry to `CHANGELOG.md`
- [ ] included comments for [documenting Go code](https://blog.golang.org/godoc)
- [ ] updated the relevant documentation or specification
- [ ] reviewed "Files changed" and left comments if necessary
- [ ] confirmed all CI checks have passed

### Reviewers Checklist

*All items are required. Please add a note if the item is not applicable and please add
your handle next to the items reviewed if you only reviewed selected items.*

I have...

- [ ] confirmed the correct [type prefix](https://github.com/commitizen/conventional-commit-types/blob/v3.0.0/index.json) in the PR title
- [ ] confirmed `!` in the type prefix if API or client breaking change
- [ ] confirmed all author checklist items have been addressed 
- [ ] reviewed state machine logic
- [ ] reviewed API design and naming
- [ ] reviewed documentation is accurate
- [ ] reviewed tests and test coverage
- [ ] manually tested (if applicable)
This commit is contained in:
atheeshp 2021-10-01 20:00:22 +05:30 committed by GitHub
parent 9833bf14c1
commit f726a2398a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 2367 additions and 2122 deletions

View File

@ -1,4 +1,4 @@
package baseapp package baseapp_test
import ( import (
"fmt" "fmt"
@ -9,6 +9,7 @@ import (
tmprototypes "github.com/tendermint/tendermint/proto/tendermint/types" tmprototypes "github.com/tendermint/tendermint/proto/tendermint/types"
dbm "github.com/tendermint/tm-db" dbm "github.com/tendermint/tm-db"
"github.com/cosmos/cosmos-sdk/baseapp"
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
) )
@ -18,81 +19,81 @@ func TestGetBlockRentionHeight(t *testing.T) {
name := t.Name() name := t.Name()
testCases := map[string]struct { testCases := map[string]struct {
bapp *BaseApp bapp *baseapp.BaseApp
maxAgeBlocks int64 maxAgeBlocks int64
commitHeight int64 commitHeight int64
expected int64 expected int64
}{ }{
"defaults": { "defaults": {
bapp: NewBaseApp(name, logger, db, nil), bapp: baseapp.NewBaseApp(name, logger, db, nil),
maxAgeBlocks: 0, maxAgeBlocks: 0,
commitHeight: 499000, commitHeight: 499000,
expected: 0, expected: 0,
}, },
"pruning unbonding time only": { "pruning unbonding time only": {
bapp: NewBaseApp(name, logger, db, nil, SetMinRetainBlocks(1)), bapp: baseapp.NewBaseApp(name, logger, db, nil, baseapp.SetMinRetainBlocks(1)),
maxAgeBlocks: 362880, maxAgeBlocks: 362880,
commitHeight: 499000, commitHeight: 499000,
expected: 136120, expected: 136120,
}, },
"pruning iavl snapshot only": { "pruning iavl snapshot only": {
bapp: NewBaseApp( bapp: baseapp.NewBaseApp(
name, logger, db, nil, name, logger, db, nil,
SetPruning(sdk.PruningOptions{KeepEvery: 10000}), baseapp.SetPruning(sdk.PruningOptions{KeepEvery: 10000}),
SetMinRetainBlocks(1), baseapp.SetMinRetainBlocks(1),
), ),
maxAgeBlocks: 0, maxAgeBlocks: 0,
commitHeight: 499000, commitHeight: 499000,
expected: 490000, expected: 490000,
}, },
"pruning state sync snapshot only": { "pruning state sync snapshot only": {
bapp: NewBaseApp( bapp: baseapp.NewBaseApp(
name, logger, db, nil, name, logger, db, nil,
SetSnapshotInterval(50000), baseapp.SetSnapshotInterval(50000),
SetSnapshotKeepRecent(3), baseapp.SetSnapshotKeepRecent(3),
SetMinRetainBlocks(1), baseapp.SetMinRetainBlocks(1),
), ),
maxAgeBlocks: 0, maxAgeBlocks: 0,
commitHeight: 499000, commitHeight: 499000,
expected: 349000, expected: 349000,
}, },
"pruning min retention only": { "pruning min retention only": {
bapp: NewBaseApp( bapp: baseapp.NewBaseApp(
name, logger, db, nil, name, logger, db, nil,
SetMinRetainBlocks(400000), baseapp.SetMinRetainBlocks(400000),
), ),
maxAgeBlocks: 0, maxAgeBlocks: 0,
commitHeight: 499000, commitHeight: 499000,
expected: 99000, expected: 99000,
}, },
"pruning all conditions": { "pruning all conditions": {
bapp: NewBaseApp( bapp: baseapp.NewBaseApp(
name, logger, db, nil, name, logger, db, nil,
SetPruning(sdk.PruningOptions{KeepEvery: 10000}), baseapp.SetPruning(sdk.PruningOptions{KeepEvery: 10000}),
SetMinRetainBlocks(400000), baseapp.SetMinRetainBlocks(400000),
SetSnapshotInterval(50000), SetSnapshotKeepRecent(3), baseapp.SetSnapshotInterval(50000), baseapp.SetSnapshotKeepRecent(3),
), ),
maxAgeBlocks: 362880, maxAgeBlocks: 362880,
commitHeight: 499000, commitHeight: 499000,
expected: 99000, expected: 99000,
}, },
"no pruning due to no persisted state": { "no pruning due to no persisted state": {
bapp: NewBaseApp( bapp: baseapp.NewBaseApp(
name, logger, db, nil, name, logger, db, nil,
SetPruning(sdk.PruningOptions{KeepEvery: 10000}), baseapp.SetPruning(sdk.PruningOptions{KeepEvery: 10000}),
SetMinRetainBlocks(400000), baseapp.SetMinRetainBlocks(400000),
SetSnapshotInterval(50000), SetSnapshotKeepRecent(3), baseapp.SetSnapshotInterval(50000), baseapp.SetSnapshotKeepRecent(3),
), ),
maxAgeBlocks: 362880, maxAgeBlocks: 362880,
commitHeight: 10000, commitHeight: 10000,
expected: 0, expected: 0,
}, },
"disable pruning": { "disable pruning": {
bapp: NewBaseApp( bapp: baseapp.NewBaseApp(
name, logger, db, nil, name, logger, db, nil,
SetPruning(sdk.PruningOptions{KeepEvery: 10000}), baseapp.SetPruning(sdk.PruningOptions{KeepEvery: 10000}),
SetMinRetainBlocks(0), baseapp.SetMinRetainBlocks(0),
SetSnapshotInterval(50000), SetSnapshotKeepRecent(3), baseapp.SetSnapshotInterval(50000), baseapp.SetSnapshotKeepRecent(3),
), ),
maxAgeBlocks: 362880, maxAgeBlocks: 362880,
commitHeight: 499000, commitHeight: 499000,
@ -126,14 +127,14 @@ func TestBaseAppCreateQueryContextRejectsNegativeHeights(t *testing.T) {
logger := defaultLogger() logger := defaultLogger()
db := dbm.NewMemDB() db := dbm.NewMemDB()
name := t.Name() name := t.Name()
app := NewBaseApp(name, logger, db, nil) app := baseapp.NewBaseApp(name, logger, db, nil)
proves := []bool{ proves := []bool{
false, true, false, true,
} }
for _, prove := range proves { for _, prove := range proves {
t.Run(fmt.Sprintf("prove=%t", prove), func(t *testing.T) { t.Run(fmt.Sprintf("prove=%t", prove), func(t *testing.T) {
sctx, err := app.createQueryContext(-10, true) sctx, err := app.CreateQueryContext(-10, true)
require.Error(t, err) require.Error(t, err)
require.Equal(t, sctx, sdk.Context{}) require.Equal(t, sctx, sdk.Context{})
}) })

View File

@ -1,4 +1,4 @@
package baseapp package baseapp_test
import ( import (
"bytes" "bytes"
@ -22,6 +22,7 @@ import (
tmproto "github.com/tendermint/tendermint/proto/tendermint/types" tmproto "github.com/tendermint/tendermint/proto/tendermint/types"
dbm "github.com/tendermint/tm-db" dbm "github.com/tendermint/tm-db"
"github.com/cosmos/cosmos-sdk/baseapp"
"github.com/cosmos/cosmos-sdk/codec" "github.com/cosmos/cosmos-sdk/codec"
"github.com/cosmos/cosmos-sdk/snapshots" "github.com/cosmos/cosmos-sdk/snapshots"
snapshottypes "github.com/cosmos/cosmos-sdk/snapshots/types" snapshottypes "github.com/cosmos/cosmos-sdk/snapshots/types"
@ -30,6 +31,7 @@ import (
"github.com/cosmos/cosmos-sdk/testutil/testdata" "github.com/cosmos/cosmos-sdk/testutil/testdata"
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/types/tx"
"github.com/cosmos/cosmos-sdk/x/auth/middleware" "github.com/cosmos/cosmos-sdk/x/auth/middleware"
"github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx" "github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx"
) )
@ -82,12 +84,12 @@ func defaultLogger() log.Logger {
return log.NewTMLogger(log.NewSyncWriter(os.Stdout)).With("module", "sdk/app") return log.NewTMLogger(log.NewSyncWriter(os.Stdout)).With("module", "sdk/app")
} }
func newBaseApp(name string, options ...func(*BaseApp)) *BaseApp { func newBaseApp(name string, options ...func(*baseapp.BaseApp)) *baseapp.BaseApp {
logger := defaultLogger() logger := defaultLogger()
db := dbm.NewMemDB() db := dbm.NewMemDB()
codec := codec.NewLegacyAmino() codec := codec.NewLegacyAmino()
registerTestCodec(codec) registerTestCodec(codec)
return NewBaseApp(name, logger, db, testTxDecoder(codec), options...) return baseapp.NewBaseApp(name, logger, db, testTxDecoder(codec), options...)
} }
func registerTestCodec(cdc *codec.LegacyAmino) { func registerTestCodec(cdc *codec.LegacyAmino) {
@ -111,7 +113,7 @@ func aminoTxEncoder() sdk.TxEncoder {
} }
// simple one store baseapp // simple one store baseapp
func setupBaseApp(t *testing.T, options ...func(*BaseApp)) *BaseApp { func setupBaseApp(t *testing.T, options ...func(*baseapp.BaseApp)) *baseapp.BaseApp {
app := newBaseApp(t.Name(), options...) app := newBaseApp(t.Name(), options...)
require.Equal(t, t.Name(), app.Name()) require.Equal(t, t.Name(), app.Name())
@ -124,23 +126,37 @@ func setupBaseApp(t *testing.T, options ...func(*BaseApp)) *BaseApp {
return app return app
} }
// testTxHandler is a tx.Handler used for the mock app, it does not
// contain any signature verification logic.
func testTxHandler(options middleware.TxHandlerOptions, customTxHandlerMiddleware handlerFun) tx.Handler {
return middleware.ComposeMiddlewares(
middleware.NewRunMsgsTxHandler(options.MsgServiceRouter, options.LegacyRouter),
middleware.GasTxMiddleware,
middleware.RecoveryTxMiddleware,
middleware.NewIndexEventsTxMiddleware(options.IndexEvents),
middleware.ValidateBasicMiddleware,
CustomTxHandlerMiddleware(customTxHandlerMiddleware),
)
}
// simple one store baseapp with data and snapshots. Each tx is 1 MB in size (uncompressed). // simple one store baseapp with data and snapshots. Each tx is 1 MB in size (uncompressed).
func setupBaseAppWithSnapshots(t *testing.T, blocks uint, blockTxs int, options ...func(*BaseApp)) (*BaseApp, func()) { func setupBaseAppWithSnapshots(t *testing.T, blocks uint, blockTxs int, options ...func(*baseapp.BaseApp)) (*baseapp.BaseApp, func()) {
codec := codec.NewLegacyAmino() codec := codec.NewLegacyAmino()
registerTestCodec(codec) registerTestCodec(codec)
routerOpt := func(bapp *BaseApp) { routerOpt := func(bapp *baseapp.BaseApp) {
legacyRouter := middleware.NewLegacyRouter() legacyRouter := middleware.NewLegacyRouter()
legacyRouter.AddRoute(sdk.NewRoute(routeMsgKeyValue, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { legacyRouter.AddRoute(sdk.NewRoute(routeMsgKeyValue, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
kv := msg.(*msgKeyValue) kv := msg.(*msgKeyValue)
bapp.cms.GetCommitKVStore(capKey2).Set(kv.Key, kv.Value) bapp.CMS().GetCommitKVStore(capKey2).Set(kv.Key, kv.Value)
return &sdk.Result{}, nil return &sdk.Result{}, nil
})) }))
txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ txHandler := testTxHandler(
LegacyAnteHandler: func(ctx sdk.Context, tx sdk.Tx, simulate bool) (sdk.Context, error) { return ctx, nil }, middleware.TxHandlerOptions{
LegacyRouter: legacyRouter, LegacyRouter: legacyRouter,
MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry),
}) },
require.NoError(t, err) func(ctx sdk.Context, tx sdk.Tx, simulate bool) (sdk.Context, error) { return ctx, nil },
)
bapp.SetTxHandler(txHandler) bapp.SetTxHandler(txHandler)
} }
@ -155,9 +171,9 @@ func setupBaseAppWithSnapshots(t *testing.T, blocks uint, blockTxs int, options
} }
app := setupBaseApp(t, append(options, app := setupBaseApp(t, append(options,
SetSnapshotStore(snapshotStore), baseapp.SetSnapshotStore(snapshotStore),
SetSnapshotInterval(snapshotInterval), baseapp.SetSnapshotInterval(snapshotInterval),
SetPruning(sdk.PruningOptions{KeepEvery: 1}), baseapp.SetPruning(sdk.PruningOptions{KeepEvery: 1}),
routerOpt)...) routerOpt)...)
app.InitChain(abci.RequestInitChain{}) app.InitChain(abci.RequestInitChain{})
@ -208,9 +224,9 @@ func TestMountStores(t *testing.T) {
app := setupBaseApp(t) app := setupBaseApp(t)
// check both stores // check both stores
store1 := app.cms.GetCommitKVStore(capKey1) store1 := app.CMS().GetCommitKVStore(capKey1)
require.NotNil(t, store1) require.NotNil(t, store1)
store2 := app.cms.GetCommitKVStore(capKey2) store2 := app.CMS().GetCommitKVStore(capKey2)
require.NotNil(t, store2) require.NotNil(t, store2)
} }
@ -218,10 +234,10 @@ func TestMountStores(t *testing.T) {
// Test that LoadLatestVersion actually does. // Test that LoadLatestVersion actually does.
func TestLoadVersion(t *testing.T) { func TestLoadVersion(t *testing.T) {
logger := defaultLogger() logger := defaultLogger()
pruningOpt := SetPruning(store.PruneNothing) pruningOpt := baseapp.SetPruning(store.PruneNothing)
db := dbm.NewMemDB() db := dbm.NewMemDB()
name := t.Name() name := t.Name()
app := NewBaseApp(name, logger, db, nil, pruningOpt) app := baseapp.NewBaseApp(name, logger, db, nil, pruningOpt)
// make a cap key and mount the store // make a cap key and mount the store
err := app.LoadLatestVersion() // needed to make stores non-nil err := app.LoadLatestVersion() // needed to make stores non-nil
@ -248,7 +264,7 @@ func TestLoadVersion(t *testing.T) {
commitID2 := sdk.CommitID{Version: 2, Hash: res.Data} commitID2 := sdk.CommitID{Version: 2, Hash: res.Data}
// reload with LoadLatestVersion // reload with LoadLatestVersion
app = NewBaseApp(name, logger, db, nil, pruningOpt) app = baseapp.NewBaseApp(name, logger, db, nil, pruningOpt)
app.MountStores() app.MountStores()
err = app.LoadLatestVersion() err = app.LoadLatestVersion()
require.Nil(t, err) require.Nil(t, err)
@ -256,7 +272,7 @@ func TestLoadVersion(t *testing.T) {
// reload with LoadVersion, see if you can commit the same block and get // reload with LoadVersion, see if you can commit the same block and get
// the same result // the same result
app = NewBaseApp(name, logger, db, nil, pruningOpt) app = baseapp.NewBaseApp(name, logger, db, nil, pruningOpt)
err = app.LoadVersion(1) err = app.LoadVersion(1)
require.Nil(t, err) require.Nil(t, err)
testLoadVersionHelper(t, app, int64(1), commitID1) testLoadVersionHelper(t, app, int64(1), commitID1)
@ -265,8 +281,8 @@ func TestLoadVersion(t *testing.T) {
testLoadVersionHelper(t, app, int64(2), commitID2) testLoadVersionHelper(t, app, int64(2), commitID2)
} }
func useDefaultLoader(app *BaseApp) { func useDefaultLoader(app *baseapp.BaseApp) {
app.SetStoreLoader(DefaultStoreLoader) app.SetStoreLoader(baseapp.DefaultStoreLoader)
} }
func initStore(t *testing.T, db dbm.DB, storeKey string, k, v []byte) { func initStore(t *testing.T, db dbm.DB, storeKey string, k, v []byte) {
@ -305,7 +321,7 @@ func checkStore(t *testing.T, db dbm.DB, ver int64, storeKey string, k, v []byte
// Test that LoadLatestVersion actually does. // Test that LoadLatestVersion actually does.
func TestSetLoader(t *testing.T) { func TestSetLoader(t *testing.T) {
cases := map[string]struct { cases := map[string]struct {
setLoader func(*BaseApp) setLoader func(*baseapp.BaseApp)
origStoreKey string origStoreKey string
loadStoreKey string loadStoreKey string
}{ }{
@ -331,11 +347,11 @@ func TestSetLoader(t *testing.T) {
initStore(t, db, tc.origStoreKey, k, v) initStore(t, db, tc.origStoreKey, k, v)
// load the app with the existing db // load the app with the existing db
opts := []func(*BaseApp){SetPruning(store.PruneNothing)} opts := []func(*baseapp.BaseApp){baseapp.SetPruning(store.PruneNothing)}
if tc.setLoader != nil { if tc.setLoader != nil {
opts = append(opts, tc.setLoader) opts = append(opts, tc.setLoader)
} }
app := NewBaseApp(t.Name(), defaultLogger(), db, nil, opts...) app := baseapp.NewBaseApp(t.Name(), defaultLogger(), db, nil, opts...)
app.MountStores(sdk.NewKVStoreKey(tc.loadStoreKey)) app.MountStores(sdk.NewKVStoreKey(tc.loadStoreKey))
err := app.LoadLatestVersion() err := app.LoadLatestVersion()
require.Nil(t, err) require.Nil(t, err)
@ -354,10 +370,10 @@ func TestSetLoader(t *testing.T) {
func TestVersionSetterGetter(t *testing.T) { func TestVersionSetterGetter(t *testing.T) {
logger := defaultLogger() logger := defaultLogger()
pruningOpt := SetPruning(store.PruneDefault) pruningOpt := baseapp.SetPruning(store.PruneDefault)
db := dbm.NewMemDB() db := dbm.NewMemDB()
name := t.Name() name := t.Name()
app := NewBaseApp(name, logger, db, nil, pruningOpt) app := baseapp.NewBaseApp(name, logger, db, nil, pruningOpt)
require.Equal(t, "", app.Version()) require.Equal(t, "", app.Version())
res := app.Query(abci.RequestQuery{Path: "app/version"}) res := app.Query(abci.RequestQuery{Path: "app/version"})
@ -374,10 +390,10 @@ func TestVersionSetterGetter(t *testing.T) {
func TestLoadVersionInvalid(t *testing.T) { func TestLoadVersionInvalid(t *testing.T) {
logger := log.NewNopLogger() logger := log.NewNopLogger()
pruningOpt := SetPruning(store.PruneNothing) pruningOpt := baseapp.SetPruning(store.PruneNothing)
db := dbm.NewMemDB() db := dbm.NewMemDB()
name := t.Name() name := t.Name()
app := NewBaseApp(name, logger, db, nil, pruningOpt) app := baseapp.NewBaseApp(name, logger, db, nil, pruningOpt)
err := app.LoadLatestVersion() err := app.LoadLatestVersion()
require.Nil(t, err) require.Nil(t, err)
@ -392,7 +408,7 @@ func TestLoadVersionInvalid(t *testing.T) {
commitID1 := sdk.CommitID{Version: 1, Hash: res.Data} commitID1 := sdk.CommitID{Version: 1, Hash: res.Data}
// create a new app with the stores mounted under the same cap key // create a new app with the stores mounted under the same cap key
app = NewBaseApp(name, logger, db, nil, pruningOpt) app = baseapp.NewBaseApp(name, logger, db, nil, pruningOpt)
// require we can load the latest version // require we can load the latest version
err = app.LoadVersion(1) err = app.LoadVersion(1)
@ -411,10 +427,10 @@ func TestLoadVersionPruning(t *testing.T) {
KeepEvery: 3, KeepEvery: 3,
Interval: 1, Interval: 1,
} }
pruningOpt := SetPruning(pruningOptions) pruningOpt := baseapp.SetPruning(pruningOptions)
db := dbm.NewMemDB() db := dbm.NewMemDB()
name := t.Name() name := t.Name()
app := NewBaseApp(name, logger, db, nil, pruningOpt) app := baseapp.NewBaseApp(name, logger, db, nil, pruningOpt)
// make a cap key and mount the store // make a cap key and mount the store
capKey := sdk.NewKVStoreKey("key1") capKey := sdk.NewKVStoreKey("key1")
@ -442,17 +458,17 @@ func TestLoadVersionPruning(t *testing.T) {
} }
for _, v := range []int64{1, 2, 4} { for _, v := range []int64{1, 2, 4} {
_, err = app.cms.CacheMultiStoreWithVersion(v) _, err = app.CMS().CacheMultiStoreWithVersion(v)
require.NoError(t, err) require.NoError(t, err)
} }
for _, v := range []int64{3, 5, 6, 7} { for _, v := range []int64{3, 5, 6, 7} {
_, err = app.cms.CacheMultiStoreWithVersion(v) _, err = app.CMS().CacheMultiStoreWithVersion(v)
require.NoError(t, err) require.NoError(t, err)
} }
// reload with LoadLatestVersion, check it loads last version // reload with LoadLatestVersion, check it loads last version
app = NewBaseApp(name, logger, db, nil, pruningOpt) app = baseapp.NewBaseApp(name, logger, db, nil, pruningOpt)
app.MountStores(capKey) app.MountStores(capKey)
err = app.LoadLatestVersion() err = app.LoadLatestVersion()
@ -460,7 +476,7 @@ func TestLoadVersionPruning(t *testing.T) {
testLoadVersionHelper(t, app, int64(7), lastCommitID) testLoadVersionHelper(t, app, int64(7), lastCommitID)
} }
func testLoadVersionHelper(t *testing.T, app *BaseApp, expectedHeight int64, expectedID sdk.CommitID) { func testLoadVersionHelper(t *testing.T, app *baseapp.BaseApp, expectedHeight int64, expectedID sdk.CommitID) {
lastHeight := app.LastBlockHeight() lastHeight := app.LastBlockHeight()
lastID := app.LastCommitID() lastID := app.LastCommitID()
require.Equal(t, expectedHeight, lastHeight) require.Equal(t, expectedHeight, lastHeight)
@ -470,13 +486,13 @@ func testLoadVersionHelper(t *testing.T, app *BaseApp, expectedHeight int64, exp
func TestOptionFunction(t *testing.T) { func TestOptionFunction(t *testing.T) {
logger := defaultLogger() logger := defaultLogger()
db := dbm.NewMemDB() db := dbm.NewMemDB()
bap := NewBaseApp("starting name", logger, db, nil, testChangeNameHelper("new name")) bap := baseapp.NewBaseApp("starting name", logger, db, nil, testChangeNameHelper("new name"))
require.Equal(t, bap.name, "new name", "BaseApp should have had name changed via option function") require.Equal(t, bap.GetName(), "new name", "BaseApp should have had name changed via option function")
} }
func testChangeNameHelper(name string) func(*BaseApp) { func testChangeNameHelper(name string) func(*baseapp.BaseApp) {
return func(bap *BaseApp) { return func(bap *baseapp.BaseApp) {
bap.name = name bap.SetName(name)
} }
} }
@ -490,7 +506,7 @@ func TestTxDecoder(t *testing.T) {
tx := newTxCounter(1, 0) tx := newTxCounter(1, 0)
txBytes := codec.MustMarshal(tx) txBytes := codec.MustMarshal(tx)
dTx, err := app.txDecoder(txBytes) dTx, err := app.TxDecoder(txBytes)
require.NoError(t, err) require.NoError(t, err)
cTx := dTx.(txTest) cTx := dTx.(txTest)
@ -555,8 +571,8 @@ func TestBaseAppOptionSeal(t *testing.T) {
func TestSetMinGasPrices(t *testing.T) { func TestSetMinGasPrices(t *testing.T) {
minGasPrices := sdk.DecCoins{sdk.NewInt64DecCoin("stake", 5000)} minGasPrices := sdk.DecCoins{sdk.NewInt64DecCoin("stake", 5000)}
app := newBaseApp(t.Name(), SetMinGasPrices(minGasPrices.String())) app := newBaseApp(t.Name(), baseapp.SetMinGasPrices(minGasPrices.String()))
require.Equal(t, minGasPrices, app.minGasPrices) require.Equal(t, minGasPrices, app.MinGasPrices())
} }
func TestInitChainer(t *testing.T) { func TestInitChainer(t *testing.T) {
@ -565,7 +581,7 @@ func TestInitChainer(t *testing.T) {
// we can reload the same app later // we can reload the same app later
db := dbm.NewMemDB() db := dbm.NewMemDB()
logger := defaultLogger() logger := defaultLogger()
app := NewBaseApp(name, logger, db, nil) app := baseapp.NewBaseApp(name, logger, db, nil)
capKey := sdk.NewKVStoreKey("main") capKey := sdk.NewKVStoreKey("main")
capKey2 := sdk.NewKVStoreKey("key2") capKey2 := sdk.NewKVStoreKey("key2")
app.MountStores(capKey, capKey2) app.MountStores(capKey, capKey2)
@ -608,10 +624,10 @@ func TestInitChainer(t *testing.T) {
) )
// assert that chainID is set correctly in InitChain // assert that chainID is set correctly in InitChain
chainID := app.deliverState.ctx.ChainID() chainID := app.DeliverState().Context().ChainID()
require.Equal(t, "test-chain-id", chainID, "ChainID in deliverState not set correctly in InitChain") require.Equal(t, "test-chain-id", chainID, "ChainID in deliverState not set correctly in InitChain")
chainID = app.checkState.ctx.ChainID() chainID = app.CheckState().Context().ChainID()
require.Equal(t, "test-chain-id", chainID, "ChainID in checkState not set correctly in InitChain") require.Equal(t, "test-chain-id", chainID, "ChainID in checkState not set correctly in InitChain")
app.Commit() app.Commit()
@ -620,7 +636,7 @@ func TestInitChainer(t *testing.T) {
require.Equal(t, value, res.Value) require.Equal(t, value, res.Value)
// reload app // reload app
app = NewBaseApp(name, logger, db, nil) app = baseapp.NewBaseApp(name, logger, db, nil)
app.SetInitChainer(initChainer) app.SetInitChainer(initChainer)
app.MountStores(capKey, capKey2) app.MountStores(capKey, capKey2)
err = app.LoadLatestVersion() // needed to make stores non-nil err = app.LoadLatestVersion() // needed to make stores non-nil
@ -644,7 +660,7 @@ func TestInitChain_WithInitialHeight(t *testing.T) {
name := t.Name() name := t.Name()
db := dbm.NewMemDB() db := dbm.NewMemDB()
logger := defaultLogger() logger := defaultLogger()
app := NewBaseApp(name, logger, db, nil) app := baseapp.NewBaseApp(name, logger, db, nil)
app.InitChain( app.InitChain(
abci.RequestInitChain{ abci.RequestInitChain{
@ -660,7 +676,7 @@ func TestBeginBlock_WithInitialHeight(t *testing.T) {
name := t.Name() name := t.Name()
db := dbm.NewMemDB() db := dbm.NewMemDB()
logger := defaultLogger() logger := defaultLogger()
app := NewBaseApp(name, logger, db, nil) app := baseapp.NewBaseApp(name, logger, db, nil)
app.InitChain( app.InitChain(
abci.RequestInitChain{ abci.RequestInitChain{
@ -711,6 +727,9 @@ func (tx txTest) ValidateBasic() error { return nil }
// Implements GasTx // Implements GasTx
func (tx txTest) GetGas() uint64 { return tx.GasLimit } func (tx txTest) GetGas() uint64 { return tx.GasLimit }
// Implements TxWithTimeoutHeight
func (tx txTest) GetTimeoutHeight() uint64 { return 0 }
const ( const (
routeMsgCounter = "msgCounter" routeMsgCounter = "msgCounter"
routeMsgCounter2 = "msgCounter2" routeMsgCounter2 = "msgCounter2"
@ -826,7 +845,7 @@ func testTxDecoder(cdc *codec.LegacyAmino) sdk.TxDecoder {
} }
} }
func anteHandlerTxTest(t *testing.T, capKey sdk.StoreKey, storeKey []byte) sdk.AnteHandler { func customHandlerTxTest(t *testing.T, capKey sdk.StoreKey, storeKey []byte) handlerFun {
return func(ctx sdk.Context, tx sdk.Tx, simulate bool) (sdk.Context, error) { return func(ctx sdk.Context, tx sdk.Tx, simulate bool) (sdk.Context, error) {
store := ctx.KVStore(capKey) store := ctx.KVStore(capKey)
txTest := tx.(txTest) txTest := tx.(txTest)
@ -841,7 +860,7 @@ func anteHandlerTxTest(t *testing.T, capKey sdk.StoreKey, storeKey []byte) sdk.A
} }
ctx.EventManager().EmitEvents( ctx.EventManager().EmitEvents(
counterEvent("ante_handler", txTest.Counter), counterEvent("post_handlers", txTest.Counter),
) )
return ctx, nil return ctx, nil
@ -929,18 +948,19 @@ func TestCheckTx(t *testing.T) {
// This ensures changes to the kvstore persist across successive CheckTx. // This ensures changes to the kvstore persist across successive CheckTx.
counterKey := []byte("counter-key") counterKey := []byte("counter-key")
txHandlerOpt := func(bapp *BaseApp) { txHandlerOpt := func(bapp *baseapp.BaseApp) {
legacyRouter := middleware.NewLegacyRouter() legacyRouter := middleware.NewLegacyRouter()
// TODO: can remove this once CheckTx doesnt process msgs. // TODO: can remove this once CheckTx doesnt process msgs.
legacyRouter.AddRoute(sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { legacyRouter.AddRoute(sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
return &sdk.Result{}, nil return &sdk.Result{}, nil
})) }))
txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ txHandler := testTxHandler(
middleware.TxHandlerOptions{
LegacyRouter: legacyRouter, LegacyRouter: legacyRouter,
LegacyAnteHandler: anteHandlerTxTest(t, capKey1, counterKey),
MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry),
}) },
require.NoError(t, err) customHandlerTxTest(t, capKey1, counterKey),
)
bapp.SetTxHandler(txHandler) bapp.SetTxHandler(txHandler)
} }
@ -962,23 +982,23 @@ func TestCheckTx(t *testing.T) {
require.True(t, r.IsOK(), fmt.Sprintf("%v", r)) require.True(t, r.IsOK(), fmt.Sprintf("%v", r))
} }
checkStateStore := app.checkState.ctx.KVStore(capKey1) checkStateStore := app.CheckState().Context().KVStore(capKey1)
storedCounter := getIntFromStore(checkStateStore, counterKey) storedCounter := getIntFromStore(checkStateStore, counterKey)
// Ensure AnteHandler ran // Ensure storedCounter
require.Equal(t, nTxs, storedCounter) require.Equal(t, nTxs, storedCounter)
// If a block is committed, CheckTx state should be reset. // If a block is committed, CheckTx state should be reset.
header := tmproto.Header{Height: 1} header := tmproto.Header{Height: 1}
app.BeginBlock(abci.RequestBeginBlock{Header: header, Hash: []byte("hash")}) app.BeginBlock(abci.RequestBeginBlock{Header: header, Hash: []byte("hash")})
require.NotNil(t, app.checkState.ctx.BlockGasMeter(), "block gas meter should have been set to checkState") require.NotNil(t, app.CheckState().Context().BlockGasMeter(), "block gas meter should have been set to checkState")
require.NotEmpty(t, app.checkState.ctx.HeaderHash()) require.NotEmpty(t, app.CheckState().Context().HeaderHash())
app.EndBlock(abci.RequestEndBlock{}) app.EndBlock(abci.RequestEndBlock{})
app.Commit() app.Commit()
checkStateStore = app.checkState.ctx.KVStore(capKey1) checkStateStore = app.CheckState().Context().KVStore(capKey1)
storedBytes := checkStateStore.Get(counterKey) storedBytes := checkStateStore.Get(counterKey)
require.Nil(t, storedBytes) require.Nil(t, storedBytes)
} }
@ -986,20 +1006,21 @@ func TestCheckTx(t *testing.T) {
// Test that successive DeliverTx can see each others' effects // Test that successive DeliverTx can see each others' effects
// on the store, both within and across blocks. // on the store, both within and across blocks.
func TestDeliverTx(t *testing.T) { func TestDeliverTx(t *testing.T) {
// test increments in the ante // test increments in the post txHandler
anteKey := []byte("ante-key") anteKey := []byte("ante-key")
// test increments in the handler // test increments in the handler
deliverKey := []byte("deliver-key") deliverKey := []byte("deliver-key")
txHandlerOpt := func(bapp *BaseApp) { txHandlerOpt := func(bapp *baseapp.BaseApp) {
legacyRouter := middleware.NewLegacyRouter() legacyRouter := middleware.NewLegacyRouter()
r := sdk.NewRoute(routeMsgCounter, handlerMsgCounter(t, capKey1, deliverKey)) r := sdk.NewRoute(routeMsgCounter, handlerMsgCounter(t, capKey1, deliverKey))
legacyRouter.AddRoute(r) legacyRouter.AddRoute(r)
txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ txHandler := testTxHandler(
middleware.TxHandlerOptions{
LegacyRouter: legacyRouter, LegacyRouter: legacyRouter,
LegacyAnteHandler: anteHandlerTxTest(t, capKey1, anteKey),
MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry),
}) },
require.NoError(t, err) customHandlerTxTest(t, capKey1, anteKey),
)
bapp.SetTxHandler(txHandler) bapp.SetTxHandler(txHandler)
} }
app := setupBaseApp(t, txHandlerOpt) app := setupBaseApp(t, txHandlerOpt)
@ -1027,7 +1048,7 @@ func TestDeliverTx(t *testing.T) {
require.True(t, res.IsOK(), fmt.Sprintf("%v", res)) require.True(t, res.IsOK(), fmt.Sprintf("%v", res))
events := res.GetEvents() events := res.GetEvents()
require.Len(t, events, 3, "should contain ante handler, message type and counter events respectively") require.Len(t, events, 3, "should contain ante handler, message type and counter events respectively")
require.Equal(t, sdk.MarkEventsToIndex(counterEvent("ante_handler", counter).ToABCIEvents(), map[string]struct{}{})[0], events[0], "ante handler event") require.Equal(t, sdk.MarkEventsToIndex(counterEvent("post_handlers", counter).ToABCIEvents(), map[string]struct{}{})[0], events[0], "ante handler event")
require.Equal(t, sdk.MarkEventsToIndex(counterEvent(sdk.EventTypeMessage, counter).ToABCIEvents(), map[string]struct{}{})[0], events[2], "msg handler update counter event") require.Equal(t, sdk.MarkEventsToIndex(counterEvent(sdk.EventTypeMessage, counter).ToABCIEvents(), map[string]struct{}{})[0], events[2], "msg handler update counter event")
} }
@ -1049,18 +1070,19 @@ func TestMultiMsgDeliverTx(t *testing.T) {
// increment the msg counter // increment the msg counter
deliverKey := []byte("deliver-key") deliverKey := []byte("deliver-key")
deliverKey2 := []byte("deliver-key2") deliverKey2 := []byte("deliver-key2")
txHandlerOpt := func(bapp *BaseApp) { txHandlerOpt := func(bapp *baseapp.BaseApp) {
legacyRouter := middleware.NewLegacyRouter() legacyRouter := middleware.NewLegacyRouter()
r1 := sdk.NewRoute(routeMsgCounter, handlerMsgCounter(t, capKey1, deliverKey)) r1 := sdk.NewRoute(routeMsgCounter, handlerMsgCounter(t, capKey1, deliverKey))
r2 := sdk.NewRoute(routeMsgCounter2, handlerMsgCounter(t, capKey1, deliverKey2)) r2 := sdk.NewRoute(routeMsgCounter2, handlerMsgCounter(t, capKey1, deliverKey2))
legacyRouter.AddRoute(r1) legacyRouter.AddRoute(r1)
legacyRouter.AddRoute(r2) legacyRouter.AddRoute(r2)
txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ txHandler := testTxHandler(
middleware.TxHandlerOptions{
LegacyRouter: legacyRouter, LegacyRouter: legacyRouter,
LegacyAnteHandler: anteHandlerTxTest(t, capKey1, anteKey),
MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry),
}) },
require.NoError(t, err) customHandlerTxTest(t, capKey1, anteKey),
)
bapp.SetTxHandler(txHandler) bapp.SetTxHandler(txHandler)
} }
app := setupBaseApp(t, txHandlerOpt) app := setupBaseApp(t, txHandlerOpt)
@ -1080,7 +1102,7 @@ func TestMultiMsgDeliverTx(t *testing.T) {
res := app.DeliverTx(abci.RequestDeliverTx{Tx: txBytes}) res := app.DeliverTx(abci.RequestDeliverTx{Tx: txBytes})
require.True(t, res.IsOK(), fmt.Sprintf("%v", res)) require.True(t, res.IsOK(), fmt.Sprintf("%v", res))
store := app.deliverState.ctx.KVStore(capKey1) store := app.DeliverState().Context().KVStore(capKey1)
// tx counter only incremented once // tx counter only incremented once
txCounter := getIntFromStore(store, anteKey) txCounter := getIntFromStore(store, anteKey)
@ -1100,7 +1122,7 @@ func TestMultiMsgDeliverTx(t *testing.T) {
res = app.DeliverTx(abci.RequestDeliverTx{Tx: txBytes}) res = app.DeliverTx(abci.RequestDeliverTx{Tx: txBytes})
require.True(t, res.IsOK(), fmt.Sprintf("%v", res)) require.True(t, res.IsOK(), fmt.Sprintf("%v", res))
store = app.deliverState.ctx.KVStore(capKey1) store = app.DeliverState().Context().KVStore(capKey1)
// tx counter only incremented once // tx counter only incremented once
txCounter = getIntFromStore(store, anteKey) txCounter = getIntFromStore(store, anteKey)
@ -1127,19 +1149,20 @@ func TestConcurrentCheckDeliver(t *testing.T) {
func TestSimulateTx(t *testing.T) { func TestSimulateTx(t *testing.T) {
gasConsumed := uint64(5) gasConsumed := uint64(5)
txHandlerOpt := func(bapp *BaseApp) { txHandlerOpt := func(bapp *baseapp.BaseApp) {
legacyRouter := middleware.NewLegacyRouter() legacyRouter := middleware.NewLegacyRouter()
r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
ctx.GasMeter().ConsumeGas(gasConsumed, "test") ctx.GasMeter().ConsumeGas(gasConsumed, "test")
return &sdk.Result{}, nil return &sdk.Result{}, nil
}) })
legacyRouter.AddRoute(r) legacyRouter.AddRoute(r)
txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ txHandler := testTxHandler(
middleware.TxHandlerOptions{
LegacyRouter: legacyRouter, LegacyRouter: legacyRouter,
LegacyAnteHandler: func(ctx sdk.Context, tx sdk.Tx, simulate bool) (sdk.Context, error) { return ctx, nil },
MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry),
}) },
require.NoError(t, err) func(ctx sdk.Context, tx sdk.Tx, simulate bool) (sdk.Context, error) { return ctx, nil },
)
bapp.SetTxHandler(txHandler) bapp.SetTxHandler(txHandler)
} }
app := setupBaseApp(t, txHandlerOpt) app := setupBaseApp(t, txHandlerOpt)
@ -1195,20 +1218,21 @@ func TestSimulateTx(t *testing.T) {
} }
func TestRunInvalidTransaction(t *testing.T) { func TestRunInvalidTransaction(t *testing.T) {
txHandlerOpt := func(bapp *BaseApp) { txHandlerOpt := func(bapp *baseapp.BaseApp) {
legacyRouter := middleware.NewLegacyRouter() legacyRouter := middleware.NewLegacyRouter()
r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
return &sdk.Result{}, nil return &sdk.Result{}, nil
}) })
legacyRouter.AddRoute(r) legacyRouter.AddRoute(r)
txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ txHandler := testTxHandler(
middleware.TxHandlerOptions{
LegacyRouter: legacyRouter, LegacyRouter: legacyRouter,
LegacyAnteHandler: func(ctx sdk.Context, tx sdk.Tx, simulate bool) (newCtx sdk.Context, err error) { MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry),
},
func(ctx sdk.Context, tx sdk.Tx, simulate bool) (newCtx sdk.Context, err error) {
return return
}, },
MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), )
})
require.NoError(t, err)
bapp.SetTxHandler(txHandler) bapp.SetTxHandler(txHandler)
} }
app := setupBaseApp(t, txHandlerOpt) app := setupBaseApp(t, txHandlerOpt)
@ -1309,7 +1333,7 @@ func TestTxGasLimits(t *testing.T) {
return ctx, nil return ctx, nil
} }
txHandlerOpt := func(bapp *BaseApp) { txHandlerOpt := func(bapp *baseapp.BaseApp) {
legacyRouter := middleware.NewLegacyRouter() legacyRouter := middleware.NewLegacyRouter()
r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
count := msg.(msgCounter).Counter count := msg.(msgCounter).Counter
@ -1317,12 +1341,14 @@ func TestTxGasLimits(t *testing.T) {
return &sdk.Result{}, nil return &sdk.Result{}, nil
}) })
legacyRouter.AddRoute(r) legacyRouter.AddRoute(r)
txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ txHandler := testTxHandler(
middleware.TxHandlerOptions{
LegacyRouter: legacyRouter, LegacyRouter: legacyRouter,
LegacyAnteHandler: ante,
MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry),
}) },
require.NoError(t, err) ante,
)
bapp.SetTxHandler(txHandler) bapp.SetTxHandler(txHandler)
} }
app := setupBaseApp(t, txHandlerOpt) app := setupBaseApp(t, txHandlerOpt)
@ -1386,7 +1412,7 @@ func TestMaxBlockGasLimits(t *testing.T) {
return ctx, nil return ctx, nil
} }
txHandlerOpt := func(bapp *BaseApp) { txHandlerOpt := func(bapp *baseapp.BaseApp) {
legacyRouter := middleware.NewLegacyRouter() legacyRouter := middleware.NewLegacyRouter()
r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
count := msg.(msgCounter).Counter count := msg.(msgCounter).Counter
@ -1394,12 +1420,13 @@ func TestMaxBlockGasLimits(t *testing.T) {
return &sdk.Result{}, nil return &sdk.Result{}, nil
}) })
legacyRouter.AddRoute(r) legacyRouter.AddRoute(r)
txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ txHandler := testTxHandler(
middleware.TxHandlerOptions{
LegacyRouter: legacyRouter, LegacyRouter: legacyRouter,
LegacyAnteHandler: ante,
MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry),
}) },
require.NoError(t, err) ante,
)
bapp.SetTxHandler(txHandler) bapp.SetTxHandler(txHandler)
} }
app := setupBaseApp(t, txHandlerOpt) app := setupBaseApp(t, txHandlerOpt)
@ -1443,7 +1470,7 @@ func TestMaxBlockGasLimits(t *testing.T) {
for j := 0; j < tc.numDelivers; j++ { for j := 0; j < tc.numDelivers; j++ {
_, result, err := app.SimDeliver(aminoTxEncoder(), tx) _, result, err := app.SimDeliver(aminoTxEncoder(), tx)
ctx := app.getState(runTxModeDeliver).ctx ctx := app.DeliverState().Context()
// check for failed transactions // check for failed transactions
if tc.fail && (j+1) > tc.failAfterDeliver { if tc.fail && (j+1) > tc.failAfterDeliver {
@ -1470,21 +1497,22 @@ func TestMaxBlockGasLimits(t *testing.T) {
} }
} }
func TestBaseAppAnteHandler(t *testing.T) { func TestBaseAppMiddleware(t *testing.T) {
anteKey := []byte("ante-key") anteKey := []byte("ante-key")
deliverKey := []byte("deliver-key") deliverKey := []byte("deliver-key")
cdc := codec.NewLegacyAmino() cdc := codec.NewLegacyAmino()
txHandlerOpt := func(bapp *BaseApp) { txHandlerOpt := func(bapp *baseapp.BaseApp) {
legacyRouter := middleware.NewLegacyRouter() legacyRouter := middleware.NewLegacyRouter()
r := sdk.NewRoute(routeMsgCounter, handlerMsgCounter(t, capKey1, deliverKey)) r := sdk.NewRoute(routeMsgCounter, handlerMsgCounter(t, capKey1, deliverKey))
legacyRouter.AddRoute(r) legacyRouter.AddRoute(r)
txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ txHandler := testTxHandler(
middleware.TxHandlerOptions{
LegacyRouter: legacyRouter, LegacyRouter: legacyRouter,
LegacyAnteHandler: anteHandlerTxTest(t, capKey1, anteKey),
MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry),
}) },
require.NoError(t, err) customHandlerTxTest(t, capKey1, anteKey),
)
bapp.SetTxHandler(txHandler) bapp.SetTxHandler(txHandler)
} }
app := setupBaseApp(t, txHandlerOpt) app := setupBaseApp(t, txHandlerOpt)
@ -1498,7 +1526,7 @@ func TestBaseAppAnteHandler(t *testing.T) {
// execute a tx that will fail ante handler execution // execute a tx that will fail ante handler execution
// //
// NOTE: State should not be mutated here. This will be implicitly checked by // NOTE: State should not be mutated here. This will be implicitly checked by
// the next txs ante handler execution (anteHandlerTxTest). // the next txs ante handler execution (customHandlerTxTest).
tx := newTxCounter(0, 0) tx := newTxCounter(0, 0)
tx.setFailOnAnte(true) tx.setFailOnAnte(true)
txBytes, err := cdc.Marshal(tx) txBytes, err := cdc.Marshal(tx)
@ -1507,7 +1535,7 @@ func TestBaseAppAnteHandler(t *testing.T) {
require.Empty(t, res.Events) require.Empty(t, res.Events)
require.False(t, res.IsOK(), fmt.Sprintf("%v", res)) require.False(t, res.IsOK(), fmt.Sprintf("%v", res))
ctx := app.getState(runTxModeDeliver).ctx ctx := app.DeliverState().Context()
store := ctx.KVStore(capKey1) store := ctx.KVStore(capKey1)
require.Equal(t, int64(0), getIntFromStore(store, anteKey)) require.Equal(t, int64(0), getIntFromStore(store, anteKey))
@ -1523,7 +1551,7 @@ func TestBaseAppAnteHandler(t *testing.T) {
require.Empty(t, res.Events) require.Empty(t, res.Events)
require.False(t, res.IsOK(), fmt.Sprintf("%v", res)) require.False(t, res.IsOK(), fmt.Sprintf("%v", res))
ctx = app.getState(runTxModeDeliver).ctx ctx = app.DeliverState().Context()
store = ctx.KVStore(capKey1) store = ctx.KVStore(capKey1)
require.Equal(t, int64(1), getIntFromStore(store, anteKey)) require.Equal(t, int64(1), getIntFromStore(store, anteKey))
require.Equal(t, int64(0), getIntFromStore(store, deliverKey)) require.Equal(t, int64(0), getIntFromStore(store, deliverKey))
@ -1539,7 +1567,7 @@ func TestBaseAppAnteHandler(t *testing.T) {
require.NotEmpty(t, res.Events) require.NotEmpty(t, res.Events)
require.True(t, res.IsOK(), fmt.Sprintf("%v", res)) require.True(t, res.IsOK(), fmt.Sprintf("%v", res))
ctx = app.getState(runTxModeDeliver).ctx ctx = app.DeliverState().Context()
store = ctx.KVStore(capKey1) store = ctx.KVStore(capKey1)
require.Equal(t, int64(2), getIntFromStore(store, anteKey)) require.Equal(t, int64(2), getIntFromStore(store, anteKey))
require.Equal(t, int64(1), getIntFromStore(store, deliverKey)) require.Equal(t, int64(1), getIntFromStore(store, deliverKey))
@ -1564,7 +1592,7 @@ func TestGasConsumptionBadTx(t *testing.T) {
cdc := codec.NewLegacyAmino() cdc := codec.NewLegacyAmino()
registerTestCodec(cdc) registerTestCodec(cdc)
txHandlerOpt := func(bapp *BaseApp) { txHandlerOpt := func(bapp *baseapp.BaseApp) {
legacyRouter := middleware.NewLegacyRouter() legacyRouter := middleware.NewLegacyRouter()
r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
count := msg.(msgCounter).Counter count := msg.(msgCounter).Counter
@ -1572,12 +1600,13 @@ func TestGasConsumptionBadTx(t *testing.T) {
return &sdk.Result{}, nil return &sdk.Result{}, nil
}) })
legacyRouter.AddRoute(r) legacyRouter.AddRoute(r)
txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ txHandler := testTxHandler(
middleware.TxHandlerOptions{
LegacyRouter: legacyRouter, LegacyRouter: legacyRouter,
LegacyAnteHandler: ante,
MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry),
}) },
require.NoError(t, err) ante,
)
bapp.SetTxHandler(txHandler) bapp.SetTxHandler(txHandler)
} }
app := setupBaseApp(t, txHandlerOpt) app := setupBaseApp(t, txHandlerOpt)
@ -1617,7 +1646,7 @@ func TestGasConsumptionBadTx(t *testing.T) {
func TestQuery(t *testing.T) { func TestQuery(t *testing.T) {
key, value := []byte("hello"), []byte("goodbye") key, value := []byte("hello"), []byte("goodbye")
txHandlerOpt := func(bapp *BaseApp) { txHandlerOpt := func(bapp *baseapp.BaseApp) {
legacyRouter := middleware.NewLegacyRouter() legacyRouter := middleware.NewLegacyRouter()
r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
store := ctx.KVStore(capKey1) store := ctx.KVStore(capKey1)
@ -1625,16 +1654,17 @@ func TestQuery(t *testing.T) {
return &sdk.Result{}, nil return &sdk.Result{}, nil
}) })
legacyRouter.AddRoute(r) legacyRouter.AddRoute(r)
txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ txHandler := testTxHandler(
middleware.TxHandlerOptions{
LegacyRouter: legacyRouter, LegacyRouter: legacyRouter,
LegacyAnteHandler: func(ctx sdk.Context, tx sdk.Tx, simulate bool) (newCtx sdk.Context, err error) { MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry),
},
func(ctx sdk.Context, tx sdk.Tx, simulate bool) (newCtx sdk.Context, err error) {
store := ctx.KVStore(capKey1) store := ctx.KVStore(capKey1)
store.Set(key, value) store.Set(key, value)
return return
}, },
MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry), )
})
require.NoError(t, err)
bapp.SetTxHandler(txHandler) bapp.SetTxHandler(txHandler)
} }
app := setupBaseApp(t, txHandlerOpt) app := setupBaseApp(t, txHandlerOpt)
@ -1678,7 +1708,7 @@ func TestQuery(t *testing.T) {
} }
func TestGRPCQuery(t *testing.T) { func TestGRPCQuery(t *testing.T) {
grpcQueryOpt := func(bapp *BaseApp) { grpcQueryOpt := func(bapp *baseapp.BaseApp) {
testdata.RegisterQueryServer( testdata.RegisterQueryServer(
bapp.GRPCQueryRouter(), bapp.GRPCQueryRouter(),
testdata.QueryImpl{}, testdata.QueryImpl{},
@ -1713,14 +1743,14 @@ func TestGRPCQuery(t *testing.T) {
// Test p2p filter queries // Test p2p filter queries
func TestP2PQuery(t *testing.T) { func TestP2PQuery(t *testing.T) {
addrPeerFilterOpt := func(bapp *BaseApp) { addrPeerFilterOpt := func(bapp *baseapp.BaseApp) {
bapp.SetAddrPeerFilter(func(addrport string) abci.ResponseQuery { bapp.SetAddrPeerFilter(func(addrport string) abci.ResponseQuery {
require.Equal(t, "1.1.1.1:8000", addrport) require.Equal(t, "1.1.1.1:8000", addrport)
return abci.ResponseQuery{Code: uint32(3)} return abci.ResponseQuery{Code: uint32(3)}
}) })
} }
idPeerFilterOpt := func(bapp *BaseApp) { idPeerFilterOpt := func(bapp *baseapp.BaseApp) {
bapp.SetIDPeerFilter(func(id string) abci.ResponseQuery { bapp.SetIDPeerFilter(func(id string) abci.ResponseQuery {
require.Equal(t, "testid", id) require.Equal(t, "testid", id)
return abci.ResponseQuery{Code: uint32(4)} return abci.ResponseQuery{Code: uint32(4)}
@ -1748,16 +1778,16 @@ func TestGetMaximumBlockGas(t *testing.T) {
ctx := app.NewContext(true, tmproto.Header{}) ctx := app.NewContext(true, tmproto.Header{})
app.StoreConsensusParams(ctx, &abci.ConsensusParams{Block: &abci.BlockParams{MaxGas: 0}}) app.StoreConsensusParams(ctx, &abci.ConsensusParams{Block: &abci.BlockParams{MaxGas: 0}})
require.Equal(t, uint64(0), app.getMaximumBlockGas(ctx)) require.Equal(t, uint64(0), app.GetMaximumBlockGas(ctx))
app.StoreConsensusParams(ctx, &abci.ConsensusParams{Block: &abci.BlockParams{MaxGas: -1}}) app.StoreConsensusParams(ctx, &abci.ConsensusParams{Block: &abci.BlockParams{MaxGas: -1}})
require.Equal(t, uint64(0), app.getMaximumBlockGas(ctx)) require.Equal(t, uint64(0), app.GetMaximumBlockGas(ctx))
app.StoreConsensusParams(ctx, &abci.ConsensusParams{Block: &abci.BlockParams{MaxGas: 5000000}}) app.StoreConsensusParams(ctx, &abci.ConsensusParams{Block: &abci.BlockParams{MaxGas: 5000000}})
require.Equal(t, uint64(5000000), app.getMaximumBlockGas(ctx)) require.Equal(t, uint64(5000000), app.GetMaximumBlockGas(ctx))
app.StoreConsensusParams(ctx, &abci.ConsensusParams{Block: &abci.BlockParams{MaxGas: -5000000}}) app.StoreConsensusParams(ctx, &abci.ConsensusParams{Block: &abci.BlockParams{MaxGas: -5000000}})
require.Panics(t, func() { app.getMaximumBlockGas(ctx) }) require.Panics(t, func() { app.GetMaximumBlockGas(ctx) })
} }
func TestListSnapshots(t *testing.T) { func TestListSnapshots(t *testing.T) {
@ -1940,21 +1970,14 @@ func (rtr *testCustomRouter) Route(ctx sdk.Context, path string) sdk.Handler {
} }
func TestWithRouter(t *testing.T) { func TestWithRouter(t *testing.T) {
// test increments in the ante
anteKey := []byte("ante-key")
// test increments in the handler // test increments in the handler
deliverKey := []byte("deliver-key") deliverKey := []byte("deliver-key")
txHandlerOpt := func(bapp *BaseApp) { txHandlerOpt := func(bapp *baseapp.BaseApp) {
customRouter := &testCustomRouter{routes: sync.Map{}} customRouter := &testCustomRouter{routes: sync.Map{}}
r := sdk.NewRoute(routeMsgCounter, handlerMsgCounter(t, capKey1, deliverKey)) r := sdk.NewRoute(routeMsgCounter, handlerMsgCounter(t, capKey1, deliverKey))
customRouter.AddRoute(r) customRouter.AddRoute(r)
txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ txHandler := middleware.NewRunMsgsTxHandler(middleware.NewMsgServiceRouter(interfaceRegistry), customRouter)
LegacyRouter: customRouter,
LegacyAnteHandler: anteHandlerTxTest(t, capKey1, anteKey),
MsgServiceRouter: middleware.NewMsgServiceRouter(interfaceRegistry),
})
require.NoError(t, err)
bapp.SetTxHandler(txHandler) bapp.SetTxHandler(txHandler)
} }
app := setupBaseApp(t, txHandlerOpt) app := setupBaseApp(t, txHandlerOpt)
@ -1998,7 +2021,7 @@ func TestBaseApp_EndBlock(t *testing.T) {
}, },
} }
app := NewBaseApp(name, logger, db, nil) app := baseapp.NewBaseApp(name, logger, db, nil)
app.SetParamStore(&paramStore{db: dbm.NewMemDB()}) app.SetParamStore(&paramStore{db: dbm.NewMemDB()})
app.InitChain(abci.RequestInitChain{ app.InitChain(abci.RequestInitChain{
ConsensusParams: cp, ConsensusParams: cp,

View File

@ -0,0 +1,117 @@
package baseapp_test
import (
"context"
"fmt"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/types/tx"
abci "github.com/tendermint/tendermint/abci/types"
"github.com/tendermint/tendermint/crypto/tmhash"
)
type handlerFun func(ctx sdk.Context, tx sdk.Tx, simulate bool) (newCtx sdk.Context, err error)
type customTxHandler struct {
handler handlerFun
next tx.Handler
}
var _ tx.Handler = customTxHandler{}
// CustomTxMiddleware is being used in tests for testing
// custom pre-`runMsgs` logic (also called antehandlers before).
func CustomTxHandlerMiddleware(handler handlerFun) tx.Middleware {
return func(txHandler tx.Handler) tx.Handler {
return customTxHandler{
handler: handler,
next: txHandler,
}
}
}
// CheckTx implements tx.Handler.CheckTx method.
func (txh customTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) {
sdkCtx, err := txh.runHandler(ctx, tx, req.Tx, false)
if err != nil {
return abci.ResponseCheckTx{}, err
}
return txh.next.CheckTx(sdk.WrapSDKContext(sdkCtx), tx, req)
}
// DeliverTx implements tx.Handler.DeliverTx method.
func (txh customTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) {
sdkCtx, err := txh.runHandler(ctx, tx, req.Tx, false)
if err != nil {
return abci.ResponseDeliverTx{}, err
}
return txh.next.DeliverTx(sdk.WrapSDKContext(sdkCtx), tx, req)
}
// SimulateTx implements tx.Handler.SimulateTx method.
func (txh customTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) {
sdkCtx, err := txh.runHandler(ctx, sdkTx, req.TxBytes, true)
if err != nil {
return tx.ResponseSimulateTx{}, err
}
return txh.next.SimulateTx(sdk.WrapSDKContext(sdkCtx), sdkTx, req)
}
func (txh customTxHandler) runHandler(ctx context.Context, tx sdk.Tx, txBytes []byte, isSimulate bool) (sdk.Context, error) {
sdkCtx := sdk.UnwrapSDKContext(ctx)
if txh.handler == nil {
return sdkCtx, nil
}
ms := sdkCtx.MultiStore()
// Branch context before Handler call in case it aborts.
// This is required for both CheckTx and DeliverTx.
// Ref: https://github.com/cosmos/cosmos-sdk/issues/2772
//
// NOTE: Alternatively, we could require that Handler ensures that
// writes do not happen if aborted/failed. This may have some
// performance benefits, but it'll be more difficult to get right.
cacheCtx, msCache := cacheTxContext(sdkCtx, txBytes)
cacheCtx = cacheCtx.WithEventManager(sdk.NewEventManager())
newCtx, err := txh.handler(cacheCtx, tx, isSimulate)
if err != nil {
return sdk.Context{}, err
}
if !newCtx.IsZero() {
// At this point, newCtx.MultiStore() is a store branch, or something else
// replaced by the Handler. We want the original multistore.
//
// Also, in the case of the tx aborting, we need to track gas consumed via
// the instantiated gas meter in the Handler, so we update the context
// prior to returning.
sdkCtx = newCtx.WithMultiStore(ms)
}
msCache.Write()
return sdkCtx, nil
}
// cacheTxContext returns a new context based off of the provided context with
// a branched multi-store.
func cacheTxContext(sdkCtx sdk.Context, txBytes []byte) (sdk.Context, sdk.CacheMultiStore) {
ms := sdkCtx.MultiStore()
// TODO: https://github.com/cosmos/cosmos-sdk/issues/2824
msCache := ms.CacheMultiStore()
if msCache.TracingEnabled() {
msCache = msCache.SetTracingContext(
sdk.TraceContext(
map[string]interface{}{
"txHash": fmt.Sprintf("%X", tmhash.Sum(txBytes)),
},
),
).(sdk.CacheMultiStore)
}
return sdkCtx.WithMultiStore(msCache), msCache
}

View File

@ -1,4 +1,4 @@
package baseapp package baseapp_test
import ( import (
"testing" "testing"
@ -7,6 +7,7 @@ import (
abci "github.com/tendermint/tendermint/abci/types" abci "github.com/tendermint/tendermint/abci/types"
"github.com/cosmos/cosmos-sdk/baseapp"
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
) )
@ -15,7 +16,7 @@ var testQuerier = func(_ sdk.Context, _ []string, _ abci.RequestQuery) ([]byte,
} }
func TestQueryRouter(t *testing.T) { func TestQueryRouter(t *testing.T) {
qr := NewQueryRouter() qr := baseapp.NewQueryRouter()
// require panic on invalid route // require panic on invalid route
require.Panics(t, func() { require.Panics(t, func() {

67
baseapp/util_test.go Normal file
View File

@ -0,0 +1,67 @@
package baseapp
import (
"github.com/cosmos/cosmos-sdk/types"
sdk "github.com/cosmos/cosmos-sdk/types"
)
// TODO: Can be removed once we move all middleware tests into x/auth/middleware
// ref: #https://github.com/cosmos/cosmos-sdk/issues/10282
// CheckState is an exported method to be able to access baseapp's
// checkState in tests.
//
// This method is only accessible in baseapp tests.
func (app *BaseApp) CheckState() *state {
return app.checkState
}
// DeliverState is an exported method to be able to access baseapp's
// deliverState in tests.
//
// This method is only accessible in baseapp tests.
func (app *BaseApp) DeliverState() *state {
return app.deliverState
}
// CMS is an exported method to be able to access baseapp's cms in tests.
//
// This method is only accessible in baseapp tests.
func (app *BaseApp) CMS() types.CommitMultiStore {
return app.cms
}
// GetMaximumBlockGas return maximum blocks gas.
//
// This method is only accessible in baseapp tests.
func (app *BaseApp) GetMaximumBlockGas(ctx sdk.Context) uint64 {
return app.getMaximumBlockGas(ctx)
}
// GetName return name.
//
// This method is only accessible in baseapp tests.
func (app *BaseApp) GetName() string {
return app.name
}
// GetName return name.
//
// This method is only accessible in baseapp tests.
func (app *BaseApp) TxDecoder(txBytes []byte) (sdk.Tx, error) {
return app.txDecoder(txBytes)
}
// CreateQueryContext calls app's createQueryContext.
//
// This method is only accessible in baseapp tests.
func (app *BaseApp) CreateQueryContext(height int64, prove bool) (sdk.Context, error) {
return app.createQueryContext(height, prove)
}
// MinGasPrices returns minGasPrices.
//
// This method is only accessible in baseapp tests.
func (app *BaseApp) MinGasPrices() sdk.DecCoins {
return app.minGasPrices
}

View File

@ -1,7 +1,7 @@
[ [
{ {
"account_identifier": { "account_identifier": {
"address":"cosmos1y3awd3vl7g29q44uvz0yrevcduf2exvkwxk3uq" "address":"cosmos1wy36cv7hveh7xt4ushy2twp5czqxnz5v6rn3xw"
}, },
"currency":{ "currency":{
"symbol":"stake", "symbol":"stake",

View File

@ -15,9 +15,19 @@ import (
"github.com/cosmos/cosmos-sdk/codec" "github.com/cosmos/cosmos-sdk/codec"
"github.com/cosmos/cosmos-sdk/simapp" "github.com/cosmos/cosmos-sdk/simapp"
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/types/tx"
"github.com/cosmos/cosmos-sdk/x/auth/middleware" "github.com/cosmos/cosmos-sdk/x/auth/middleware"
) )
func testTxHandler(options middleware.TxHandlerOptions) tx.Handler {
return middleware.ComposeMiddlewares(
middleware.NewRunMsgsTxHandler(options.MsgServiceRouter, options.LegacyRouter),
middleware.GasTxMiddleware,
middleware.RecoveryTxMiddleware,
middleware.NewIndexEventsTxMiddleware(options.IndexEvents),
)
}
// NewApp creates a simple mock kvstore app for testing. It should work // NewApp creates a simple mock kvstore app for testing. It should work
// similar to a real app. Make sure rootDir is empty before running the test, // similar to a real app. Make sure rootDir is empty before running the test,
// in order to guarantee consistent results // in order to guarantee consistent results
@ -44,13 +54,12 @@ func NewApp(rootDir string, logger log.Logger) (abci.Application, error) {
// We're adding a test legacy route here, which accesses the kvstore // We're adding a test legacy route here, which accesses the kvstore
// and simply sets the Msg's key/value pair in the kvstore. // and simply sets the Msg's key/value pair in the kvstore.
legacyRouter.AddRoute(sdk.NewRoute("kvstore", KVStoreHandler(capKeyMainStore))) legacyRouter.AddRoute(sdk.NewRoute("kvstore", KVStoreHandler(capKeyMainStore)))
txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{ txHandler := testTxHandler(
middleware.TxHandlerOptions{
LegacyRouter: legacyRouter, LegacyRouter: legacyRouter,
MsgServiceRouter: middleware.NewMsgServiceRouter(encCfg.InterfaceRegistry), MsgServiceRouter: middleware.NewMsgServiceRouter(encCfg.InterfaceRegistry),
}) },
if err != nil { )
return nil, err
}
baseApp.SetTxHandler(txHandler) baseApp.SetTxHandler(txHandler)
// Load latest version. // Load latest version.

View File

@ -30,7 +30,6 @@ import (
"github.com/cosmos/cosmos-sdk/types/module" "github.com/cosmos/cosmos-sdk/types/module"
"github.com/cosmos/cosmos-sdk/version" "github.com/cosmos/cosmos-sdk/version"
"github.com/cosmos/cosmos-sdk/x/auth" "github.com/cosmos/cosmos-sdk/x/auth"
"github.com/cosmos/cosmos-sdk/x/auth/ante"
authkeeper "github.com/cosmos/cosmos-sdk/x/auth/keeper" authkeeper "github.com/cosmos/cosmos-sdk/x/auth/keeper"
authsims "github.com/cosmos/cosmos-sdk/x/auth/simulation" authsims "github.com/cosmos/cosmos-sdk/x/auth/simulation"
authtx "github.com/cosmos/cosmos-sdk/x/auth/tx" authtx "github.com/cosmos/cosmos-sdk/x/auth/tx"
@ -401,19 +400,6 @@ func NewSimApp(
} }
func (app *SimApp) setTxHandler(txConfig client.TxConfig, indexEventsStr []string) { func (app *SimApp) setTxHandler(txConfig client.TxConfig, indexEventsStr []string) {
anteHandler, err := ante.NewAnteHandler(
ante.HandlerOptions{
AccountKeeper: app.AccountKeeper,
BankKeeper: app.BankKeeper,
SignModeHandler: txConfig.SignModeHandler(),
FeegrantKeeper: app.FeeGrantKeeper,
SigGasConsumer: ante.DefaultSigVerificationGasConsumer,
},
)
if err != nil {
panic(err)
}
indexEvents := map[string]struct{}{} indexEvents := map[string]struct{}{}
for _, e := range indexEventsStr { for _, e := range indexEventsStr {
indexEvents[e] = struct{}{} indexEvents[e] = struct{}{}
@ -423,7 +409,11 @@ func (app *SimApp) setTxHandler(txConfig client.TxConfig, indexEventsStr []strin
IndexEvents: indexEvents, IndexEvents: indexEvents,
LegacyRouter: app.legacyRouter, LegacyRouter: app.legacyRouter,
MsgServiceRouter: app.msgSvcRouter, MsgServiceRouter: app.msgSvcRouter,
LegacyAnteHandler: anteHandler, AccountKeeper: app.AccountKeeper,
BankKeeper: app.BankKeeper,
FeegrantKeeper: app.FeeGrantKeeper,
SignModeHandler: txConfig.SignModeHandler(),
SigGasConsumer: authmiddleware.DefaultSigVerificationGasConsumer,
}) })
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -1,57 +0,0 @@
package ante
import (
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/types/tx/signing"
authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing"
"github.com/cosmos/cosmos-sdk/x/auth/types"
)
// HandlerOptions are the options required for constructing a default SDK AnteHandler.
type HandlerOptions struct {
AccountKeeper AccountKeeper
BankKeeper types.BankKeeper
FeegrantKeeper FeegrantKeeper
SignModeHandler authsigning.SignModeHandler
SigGasConsumer func(meter sdk.GasMeter, sig signing.SignatureV2, params types.Params) error
}
// NewAnteHandler returns an AnteHandler that checks and increments sequence
// numbers, checks signatures & account numbers, and deducts fees from the first
// signer.
func NewAnteHandler(options HandlerOptions) (sdk.AnteHandler, error) {
if options.AccountKeeper == nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrLogic, "account keeper is required for ante builder")
}
if options.BankKeeper == nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrLogic, "bank keeper is required for ante builder")
}
if options.SignModeHandler == nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrLogic, "sign mode handler is required for ante builder")
}
var sigGasConsumer = options.SigGasConsumer
if sigGasConsumer == nil {
sigGasConsumer = DefaultSigVerificationGasConsumer
}
anteDecorators := []sdk.AnteDecorator{
NewRejectExtensionOptionsDecorator(),
NewMempoolFeeDecorator(),
NewValidateBasicDecorator(),
NewTxTimeoutHeightDecorator(),
NewValidateMemoDecorator(options.AccountKeeper),
NewConsumeGasForTxSizeDecorator(options.AccountKeeper),
NewDeductFeeDecorator(options.AccountKeeper, options.BankKeeper, options.FeegrantKeeper),
NewSetPubKeyDecorator(options.AccountKeeper), // SetPubKeyDecorator must be called before all signature verification decorators
NewValidateSigCountDecorator(options.AccountKeeper),
NewSigGasConsumeDecorator(options.AccountKeeper, sigGasConsumer),
NewSigVerificationDecorator(options.AccountKeeper, options.SignModeHandler),
NewIncrementSequenceDecorator(options.AccountKeeper),
}
return sdk.ChainAnteDecorators(anteDecorators...), nil
}

View File

@ -1,207 +0,0 @@
package ante
import (
"github.com/cosmos/cosmos-sdk/codec/legacy"
"github.com/cosmos/cosmos-sdk/crypto/keys/multisig"
cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx"
authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing"
)
// ValidateBasicDecorator will call tx.ValidateBasic, msg.ValidateBasic(for each msg inside tx)
// and return any non-nil error.
// If ValidateBasic passes, decorator calls next AnteHandler in chain. Note,
// ValidateBasicDecorator decorator will not get executed on ReCheckTx since it
// is not dependent on application state.
type ValidateBasicDecorator struct{}
func NewValidateBasicDecorator() ValidateBasicDecorator {
return ValidateBasicDecorator{}
}
func (vbd ValidateBasicDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) {
// no need to validate basic on recheck tx, call next antehandler
if ctx.IsReCheckTx() {
return next(ctx, tx, simulate)
}
if err := tx.ValidateBasic(); err != nil {
return ctx, err
}
return next(ctx, tx, simulate)
}
// ValidateMemoDecorator will validate memo given the parameters passed in
// If memo is too large decorator returns with error, otherwise call next AnteHandler
// CONTRACT: Tx must implement TxWithMemo interface
type ValidateMemoDecorator struct {
ak AccountKeeper
}
func NewValidateMemoDecorator(ak AccountKeeper) ValidateMemoDecorator {
return ValidateMemoDecorator{
ak: ak,
}
}
func (vmd ValidateMemoDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) {
memoTx, ok := tx.(sdk.TxWithMemo)
if !ok {
return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type")
}
params := vmd.ak.GetParams(ctx)
memoLength := len(memoTx.GetMemo())
if uint64(memoLength) > params.MaxMemoCharacters {
return ctx, sdkerrors.Wrapf(sdkerrors.ErrMemoTooLarge,
"maximum number of characters is %d but received %d characters",
params.MaxMemoCharacters, memoLength,
)
}
return next(ctx, tx, simulate)
}
// ConsumeTxSizeGasDecorator will take in parameters and consume gas proportional
// to the size of tx before calling next AnteHandler. Note, the gas costs will be
// slightly over estimated due to the fact that any given signing account may need
// to be retrieved from state.
//
// CONTRACT: If simulate=true, then signatures must either be completely filled
// in or empty.
// CONTRACT: To use this decorator, signatures of transaction must be represented
// as legacytx.StdSignature otherwise simulate mode will incorrectly estimate gas cost.
type ConsumeTxSizeGasDecorator struct {
ak AccountKeeper
}
func NewConsumeGasForTxSizeDecorator(ak AccountKeeper) ConsumeTxSizeGasDecorator {
return ConsumeTxSizeGasDecorator{
ak: ak,
}
}
func (cgts ConsumeTxSizeGasDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) {
sigTx, ok := tx.(authsigning.SigVerifiableTx)
if !ok {
return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid tx type")
}
params := cgts.ak.GetParams(ctx)
ctx.GasMeter().ConsumeGas(params.TxSizeCostPerByte*sdk.Gas(len(ctx.TxBytes())), "txSize")
// simulate gas cost for signatures in simulate mode
if simulate {
// in simulate mode, each element should be a nil signature
sigs, err := sigTx.GetSignaturesV2()
if err != nil {
return ctx, err
}
n := len(sigs)
for i, signer := range sigTx.GetSigners() {
// if signature is already filled in, no need to simulate gas cost
if i < n && !isIncompleteSignature(sigs[i].Data) {
continue
}
var pubkey cryptotypes.PubKey
acc := cgts.ak.GetAccount(ctx, signer)
// use placeholder simSecp256k1Pubkey if sig is nil
if acc == nil || acc.GetPubKey() == nil {
pubkey = simSecp256k1Pubkey
} else {
pubkey = acc.GetPubKey()
}
// use stdsignature to mock the size of a full signature
simSig := legacytx.StdSignature{ //nolint:staticcheck // this will be removed when proto is ready
Signature: simSecp256k1Sig[:],
PubKey: pubkey,
}
sigBz := legacy.Cdc.MustMarshal(simSig)
cost := sdk.Gas(len(sigBz) + 6)
// If the pubkey is a multi-signature pubkey, then we estimate for the maximum
// number of signers.
if _, ok := pubkey.(*multisig.LegacyAminoPubKey); ok {
cost *= params.TxSigLimit
}
ctx.GasMeter().ConsumeGas(params.TxSizeCostPerByte*cost, "txSize")
}
}
return next(ctx, tx, simulate)
}
// isIncompleteSignature tests whether SignatureData is fully filled in for simulation purposes
func isIncompleteSignature(data signing.SignatureData) bool {
if data == nil {
return true
}
switch data := data.(type) {
case *signing.SingleSignatureData:
return len(data.Signature) == 0
case *signing.MultiSignatureData:
if len(data.Signatures) == 0 {
return true
}
for _, s := range data.Signatures {
if isIncompleteSignature(s) {
return true
}
}
}
return false
}
type (
// TxTimeoutHeightDecorator defines an AnteHandler decorator that checks for a
// tx height timeout.
TxTimeoutHeightDecorator struct{}
// TxWithTimeoutHeight defines the interface a tx must implement in order for
// TxHeightTimeoutDecorator to process the tx.
TxWithTimeoutHeight interface {
sdk.Tx
GetTimeoutHeight() uint64
}
)
// TxTimeoutHeightDecorator defines an AnteHandler decorator that checks for a
// tx height timeout.
func NewTxTimeoutHeightDecorator() TxTimeoutHeightDecorator {
return TxTimeoutHeightDecorator{}
}
// AnteHandle implements an AnteHandler decorator for the TxHeightTimeoutDecorator
// type where the current block height is checked against the tx's height timeout.
// If a height timeout is provided (non-zero) and is less than the current block
// height, then an error is returned.
func (txh TxTimeoutHeightDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) {
timeoutTx, ok := tx.(TxWithTimeoutHeight)
if !ok {
return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "expected tx to implement TxWithTimeoutHeight")
}
timeoutHeight := timeoutTx.GetTimeoutHeight()
if timeoutHeight > 0 && uint64(ctx.BlockHeight()) > timeoutHeight {
return ctx, sdkerrors.Wrapf(
sdkerrors.ErrTxTimeoutHeight, "block height: %d, timeout height: %d", ctx.BlockHeight(), timeoutHeight,
)
}
return next(ctx, tx, simulate)
}

View File

@ -1,224 +0,0 @@
package ante_test
import (
"strings"
cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types"
"github.com/cosmos/cosmos-sdk/crypto/types/multisig"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/ante"
)
func (suite *AnteTestSuite) TestValidateBasic() {
suite.SetupTest(true) // setup
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
// keys and addresses
priv1, _, addr1 := testdata.KeyTestPubAddr()
// msg and signatures
msg := testdata.NewTestMsg(addr1)
feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit()
suite.Require().NoError(suite.txBuilder.SetMsgs(msg))
suite.txBuilder.SetFeeAmount(feeAmount)
suite.txBuilder.SetGasLimit(gasLimit)
privs, accNums, accSeqs := []cryptotypes.PrivKey{}, []uint64{}, []uint64{}
invalidTx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID())
suite.Require().NoError(err)
vbd := ante.NewValidateBasicDecorator()
antehandler := sdk.ChainAnteDecorators(vbd)
_, err = antehandler(suite.ctx, invalidTx, false)
suite.Require().NotNil(err, "Did not error on invalid tx")
privs, accNums, accSeqs = []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0}
validTx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID())
suite.Require().NoError(err)
_, err = antehandler(suite.ctx, validTx, false)
suite.Require().Nil(err, "ValidateBasicDecorator returned error on valid tx. err: %v", err)
// test decorator skips on recheck
suite.ctx = suite.ctx.WithIsReCheckTx(true)
// decorator should skip processing invalidTx on recheck and thus return nil-error
_, err = antehandler(suite.ctx, invalidTx, false)
suite.Require().Nil(err, "ValidateBasicDecorator ran on ReCheck")
}
func (suite *AnteTestSuite) TestValidateMemo() {
suite.SetupTest(true) // setup
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
// keys and addresses
priv1, _, addr1 := testdata.KeyTestPubAddr()
// msg and signatures
msg := testdata.NewTestMsg(addr1)
feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit()
suite.Require().NoError(suite.txBuilder.SetMsgs(msg))
suite.txBuilder.SetFeeAmount(feeAmount)
suite.txBuilder.SetGasLimit(gasLimit)
privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0}
suite.txBuilder.SetMemo(strings.Repeat("01234567890", 500))
invalidTx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID())
suite.Require().NoError(err)
// require that long memos get rejected
vmd := ante.NewValidateMemoDecorator(suite.app.AccountKeeper)
antehandler := sdk.ChainAnteDecorators(vmd)
_, err = antehandler(suite.ctx, invalidTx, false)
suite.Require().NotNil(err, "Did not error on tx with high memo")
suite.txBuilder.SetMemo(strings.Repeat("01234567890", 10))
validTx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID())
suite.Require().NoError(err)
// require small memos pass ValidateMemo Decorator
_, err = antehandler(suite.ctx, validTx, false)
suite.Require().Nil(err, "ValidateBasicDecorator returned error on valid tx. err: %v", err)
}
func (suite *AnteTestSuite) TestConsumeGasForTxSize() {
suite.SetupTest(true) // setup
// keys and addresses
priv1, _, addr1 := testdata.KeyTestPubAddr()
// msg and signatures
msg := testdata.NewTestMsg(addr1)
feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit()
cgtsd := ante.NewConsumeGasForTxSizeDecorator(suite.app.AccountKeeper)
antehandler := sdk.ChainAnteDecorators(cgtsd)
testCases := []struct {
name string
sigV2 signing.SignatureV2
}{
{"SingleSignatureData", signing.SignatureV2{PubKey: priv1.PubKey()}},
{"MultiSignatureData", signing.SignatureV2{PubKey: priv1.PubKey(), Data: multisig.NewMultisig(2)}},
}
for _, tc := range testCases {
suite.Run(tc.name, func() {
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
suite.Require().NoError(suite.txBuilder.SetMsgs(msg))
suite.txBuilder.SetFeeAmount(feeAmount)
suite.txBuilder.SetGasLimit(gasLimit)
suite.txBuilder.SetMemo(strings.Repeat("01234567890", 10))
privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0}
tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID())
suite.Require().NoError(err)
txBytes, err := suite.clientCtx.TxConfig.TxJSONEncoder()(tx)
suite.Require().Nil(err, "Cannot marshal tx: %v", err)
params := suite.app.AccountKeeper.GetParams(suite.ctx)
expectedGas := sdk.Gas(len(txBytes)) * params.TxSizeCostPerByte
// Set suite.ctx with TxBytes manually
suite.ctx = suite.ctx.WithTxBytes(txBytes)
// track how much gas is necessary to retrieve parameters
beforeGas := suite.ctx.GasMeter().GasConsumed()
suite.app.AccountKeeper.GetParams(suite.ctx)
afterGas := suite.ctx.GasMeter().GasConsumed()
expectedGas += afterGas - beforeGas
beforeGas = suite.ctx.GasMeter().GasConsumed()
suite.ctx, err = antehandler(suite.ctx, tx, false)
suite.Require().Nil(err, "ConsumeTxSizeGasDecorator returned error: %v", err)
// require that decorator consumes expected amount of gas
consumedGas := suite.ctx.GasMeter().GasConsumed() - beforeGas
suite.Require().Equal(expectedGas, consumedGas, "Decorator did not consume the correct amount of gas")
// simulation must not underestimate gas of this decorator even with nil signatures
txBuilder, err := suite.clientCtx.TxConfig.WrapTxBuilder(tx)
suite.Require().NoError(err)
suite.Require().NoError(txBuilder.SetSignatures(tc.sigV2))
tx = txBuilder.GetTx()
simTxBytes, err := suite.clientCtx.TxConfig.TxJSONEncoder()(tx)
suite.Require().Nil(err, "Cannot marshal tx: %v", err)
// require that simulated tx is smaller than tx with signatures
suite.Require().True(len(simTxBytes) < len(txBytes), "simulated tx still has signatures")
// Set suite.ctx with smaller simulated TxBytes manually
suite.ctx = suite.ctx.WithTxBytes(simTxBytes)
beforeSimGas := suite.ctx.GasMeter().GasConsumed()
// run antehandler with simulate=true
suite.ctx, err = antehandler(suite.ctx, tx, true)
consumedSimGas := suite.ctx.GasMeter().GasConsumed() - beforeSimGas
// require that antehandler passes and does not underestimate decorator cost
suite.Require().Nil(err, "ConsumeTxSizeGasDecorator returned error: %v", err)
suite.Require().True(consumedSimGas >= expectedGas, "Simulate mode underestimates gas on AnteDecorator. Simulated cost: %d, expected cost: %d", consumedSimGas, expectedGas)
})
}
}
func (suite *AnteTestSuite) TestTxHeightTimeoutDecorator() {
suite.SetupTest(true)
antehandler := sdk.ChainAnteDecorators(ante.NewTxTimeoutHeightDecorator())
// keys and addresses
priv1, _, addr1 := testdata.KeyTestPubAddr()
// msg and signatures
msg := testdata.NewTestMsg(addr1)
feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit()
testCases := []struct {
name string
timeout uint64
height int64
expectErr bool
}{
{"default value", 0, 10, false},
{"no timeout (greater height)", 15, 10, false},
{"no timeout (same height)", 10, 10, false},
{"timeout (smaller height)", 9, 10, true},
}
for _, tc := range testCases {
tc := tc
suite.Run(tc.name, func() {
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
suite.Require().NoError(suite.txBuilder.SetMsgs(msg))
suite.txBuilder.SetFeeAmount(feeAmount)
suite.txBuilder.SetGasLimit(gasLimit)
suite.txBuilder.SetMemo(strings.Repeat("01234567890", 10))
suite.txBuilder.SetTimeoutHeight(tc.timeout)
privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0}
tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID())
suite.Require().NoError(err)
ctx := suite.ctx.WithBlockHeight(tc.height)
_, err = antehandler(ctx, tx, true)
suite.Require().Equal(tc.expectErr, err != nil, err)
})
}
}

View File

@ -1,36 +0,0 @@
package ante
import (
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
"github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
)
type HasExtensionOptionsTx interface {
GetExtensionOptions() []*codectypes.Any
GetNonCriticalExtensionOptions() []*codectypes.Any
}
// RejectExtensionOptionsDecorator is an AnteDecorator that rejects all extension
// options which can optionally be included in protobuf transactions. Users that
// need extension options should create a custom AnteHandler chain that handles
// needed extension options properly and rejects unknown ones.
type RejectExtensionOptionsDecorator struct{}
// NewRejectExtensionOptionsDecorator creates a new RejectExtensionOptionsDecorator
func NewRejectExtensionOptionsDecorator() RejectExtensionOptionsDecorator {
return RejectExtensionOptionsDecorator{}
}
var _ types.AnteDecorator = RejectExtensionOptionsDecorator{}
// AnteHandle implements the AnteDecorator.AnteHandle method
func (r RejectExtensionOptionsDecorator) AnteHandle(ctx types.Context, tx types.Tx, simulate bool, next types.AnteHandler) (newCtx types.Context, err error) {
if hasExtOptsTx, ok := tx.(HasExtensionOptionsTx); ok {
if len(hasExtOptsTx.GetExtensionOptions()) != 0 {
return ctx, sdkerrors.ErrUnknownExtensionOptions
}
}
return next(ctx, tx, simulate)
}

View File

@ -1,36 +0,0 @@
package ante_test
import (
"github.com/cosmos/cosmos-sdk/codec/types"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/auth/ante"
"github.com/cosmos/cosmos-sdk/x/auth/tx"
)
func (suite *AnteTestSuite) TestRejectExtensionOptionsDecorator() {
suite.SetupTest(true) // setup
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
reod := ante.NewRejectExtensionOptionsDecorator()
antehandler := sdk.ChainAnteDecorators(reod)
// no extension options should not trigger an error
theTx := suite.txBuilder.GetTx()
_, err := antehandler(suite.ctx, theTx, false)
suite.Require().NoError(err)
extOptsTxBldr, ok := suite.txBuilder.(tx.ExtensionOptionsTxBuilder)
if !ok {
// if we can't set extension options, this decorator doesn't apply and we're done
return
}
// setting any extension option should cause an error
any, err := types.NewAnyWithValue(testdata.NewTestMsg())
suite.Require().NoError(err)
extOptsTxBldr.SetExtensionOptions(any)
theTx = suite.txBuilder.GetTx()
_, err = antehandler(suite.ctx, theTx, false)
suite.Require().EqualError(err, "unknown extension options")
}

View File

@ -1,140 +0,0 @@
package ante
import (
"fmt"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/x/auth/types"
)
// MempoolFeeDecorator will check if the transaction's fee is at least as large
// as the local validator's minimum gasFee (defined in validator config).
// If fee is too low, decorator returns error and tx is rejected from mempool.
// Note this only applies when ctx.CheckTx = true
// If fee is high enough or not CheckTx, then call next AnteHandler
// CONTRACT: Tx must implement FeeTx to use MempoolFeeDecorator
type MempoolFeeDecorator struct{}
func NewMempoolFeeDecorator() MempoolFeeDecorator {
return MempoolFeeDecorator{}
}
func (mfd MempoolFeeDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (newCtx sdk.Context, err error) {
feeTx, ok := tx.(sdk.FeeTx)
if !ok {
return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "Tx must be a FeeTx")
}
feeCoins := feeTx.GetFee()
gas := feeTx.GetGas()
// Ensure that the provided fees meet a minimum threshold for the validator,
// if this is a CheckTx. This is only for local mempool purposes, and thus
// is only ran on check tx.
if ctx.IsCheckTx() && !simulate {
minGasPrices := ctx.MinGasPrices()
if !minGasPrices.IsZero() {
requiredFees := make(sdk.Coins, len(minGasPrices))
// Determine the required fees by multiplying each required minimum gas
// price by the gas limit, where fee = ceil(minGasPrice * gasLimit).
glDec := sdk.NewDec(int64(gas))
for i, gp := range minGasPrices {
fee := gp.Amount.Mul(glDec)
requiredFees[i] = sdk.NewCoin(gp.Denom, fee.Ceil().RoundInt())
}
if !feeCoins.IsAnyGTE(requiredFees) {
return ctx, sdkerrors.Wrapf(sdkerrors.ErrInsufficientFee, "insufficient fees; got: %s required: %s", feeCoins, requiredFees)
}
}
}
return next(ctx, tx, simulate)
}
// DeductFeeDecorator deducts fees from the first signer of the tx
// If the first signer does not have the funds to pay for the fees, return with InsufficientFunds error
// Call next AnteHandler if fees successfully deducted
// CONTRACT: Tx must implement FeeTx interface to use DeductFeeDecorator
type DeductFeeDecorator struct {
ak AccountKeeper
bankKeeper types.BankKeeper
feegrantKeeper FeegrantKeeper
}
func NewDeductFeeDecorator(ak AccountKeeper, bk types.BankKeeper, fk FeegrantKeeper) DeductFeeDecorator {
return DeductFeeDecorator{
ak: ak,
bankKeeper: bk,
feegrantKeeper: fk,
}
}
func (dfd DeductFeeDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (newCtx sdk.Context, err error) {
feeTx, ok := tx.(sdk.FeeTx)
if !ok {
return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "Tx must be a FeeTx")
}
if addr := dfd.ak.GetModuleAddress(types.FeeCollectorName); addr == nil {
panic(fmt.Sprintf("%s module account has not been set", types.FeeCollectorName))
}
fee := feeTx.GetFee()
feePayer := feeTx.FeePayer()
feeGranter := feeTx.FeeGranter()
deductFeesFrom := feePayer
// if feegranter set deduct fee from feegranter account.
// this works with only when feegrant enabled.
if feeGranter != nil {
if dfd.feegrantKeeper == nil {
return ctx, sdkerrors.Wrap(sdkerrors.ErrInvalidRequest, "fee grants are not enabled")
} else if !feeGranter.Equals(feePayer) {
err := dfd.feegrantKeeper.UseGrantedFees(ctx, feeGranter, feePayer, fee, tx.GetMsgs())
if err != nil {
return ctx, sdkerrors.Wrapf(err, "%s not allowed to pay fees from %s", feeGranter, feePayer)
}
}
deductFeesFrom = feeGranter
}
deductFeesFromAcc := dfd.ak.GetAccount(ctx, deductFeesFrom)
if deductFeesFromAcc == nil {
return ctx, sdkerrors.Wrapf(sdkerrors.ErrUnknownAddress, "fee payer address: %s does not exist", deductFeesFrom)
}
// deduct the fees
if !feeTx.GetFee().IsZero() {
err = DeductFees(dfd.bankKeeper, ctx, deductFeesFromAcc, feeTx.GetFee())
if err != nil {
return ctx, err
}
}
events := sdk.Events{sdk.NewEvent(sdk.EventTypeTx,
sdk.NewAttribute(sdk.AttributeKeyFee, feeTx.GetFee().String()),
)}
ctx.EventManager().EmitEvents(events)
return next(ctx, tx, simulate)
}
// DeductFees deducts fees from the given account.
func DeductFees(bankKeeper types.BankKeeper, ctx sdk.Context, acc types.AccountI, fees sdk.Coins) error {
if !fees.IsValid() {
return sdkerrors.Wrapf(sdkerrors.ErrInsufficientFee, "invalid fee amount: %s", fees)
}
err := bankKeeper.SendCoinsFromAccountToModule(ctx, acc.GetAddress(), types.FeeCollectorName, fees)
if err != nil {
return sdkerrors.Wrapf(sdkerrors.ErrInsufficientFunds, err.Error())
}
return nil
}

View File

@ -1,104 +0,0 @@
package ante_test
import (
cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/auth/ante"
"github.com/cosmos/cosmos-sdk/x/bank/testutil"
)
func (suite *AnteTestSuite) TestEnsureMempoolFees() {
suite.SetupTest(true) // setup
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
mfd := ante.NewMempoolFeeDecorator()
antehandler := sdk.ChainAnteDecorators(mfd)
// keys and addresses
priv1, _, addr1 := testdata.KeyTestPubAddr()
// msg and signatures
msg := testdata.NewTestMsg(addr1)
feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit()
suite.Require().NoError(suite.txBuilder.SetMsgs(msg))
suite.txBuilder.SetFeeAmount(feeAmount)
suite.txBuilder.SetGasLimit(gasLimit)
privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0}
tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID())
suite.Require().NoError(err)
// Set high gas price so standard test fee fails
atomPrice := sdk.NewDecCoinFromDec("atom", sdk.NewDec(200).Quo(sdk.NewDec(100000)))
highGasPrice := []sdk.DecCoin{atomPrice}
suite.ctx = suite.ctx.WithMinGasPrices(highGasPrice)
// Set IsCheckTx to true
suite.ctx = suite.ctx.WithIsCheckTx(true)
// antehandler errors with insufficient fees
_, err = antehandler(suite.ctx, tx, false)
suite.Require().NotNil(err, "Decorator should have errored on too low fee for local gasPrice")
// Set IsCheckTx to false
suite.ctx = suite.ctx.WithIsCheckTx(false)
// antehandler should not error since we do not check minGasPrice in DeliverTx
_, err = antehandler(suite.ctx, tx, false)
suite.Require().Nil(err, "MempoolFeeDecorator returned error in DeliverTx")
// Set IsCheckTx back to true for testing sufficient mempool fee
suite.ctx = suite.ctx.WithIsCheckTx(true)
atomPrice = sdk.NewDecCoinFromDec("atom", sdk.NewDec(0).Quo(sdk.NewDec(100000)))
lowGasPrice := []sdk.DecCoin{atomPrice}
suite.ctx = suite.ctx.WithMinGasPrices(lowGasPrice)
_, err = antehandler(suite.ctx, tx, false)
suite.Require().Nil(err, "Decorator should not have errored on fee higher than local gasPrice")
}
func (suite *AnteTestSuite) TestDeductFees() {
suite.SetupTest(false) // setup
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
// keys and addresses
priv1, _, addr1 := testdata.KeyTestPubAddr()
// msg and signatures
msg := testdata.NewTestMsg(addr1)
feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit()
suite.Require().NoError(suite.txBuilder.SetMsgs(msg))
suite.txBuilder.SetFeeAmount(feeAmount)
suite.txBuilder.SetGasLimit(gasLimit)
privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0}
tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID())
suite.Require().NoError(err)
// Set account with insufficient funds
acc := suite.app.AccountKeeper.NewAccountWithAddress(suite.ctx, addr1)
suite.app.AccountKeeper.SetAccount(suite.ctx, acc)
coins := sdk.NewCoins(sdk.NewCoin("atom", sdk.NewInt(10)))
err = testutil.FundAccount(suite.app.BankKeeper, suite.ctx, addr1, coins)
suite.Require().NoError(err)
dfd := ante.NewDeductFeeDecorator(suite.app.AccountKeeper, suite.app.BankKeeper, nil)
antehandler := sdk.ChainAnteDecorators(dfd)
_, err = antehandler(suite.ctx, tx, false)
suite.Require().NotNil(err, "Tx did not error when fee payer had insufficient funds")
// Set account with sufficient funds
suite.app.AccountKeeper.SetAccount(suite.ctx, acc)
err = testutil.FundAccount(suite.app.BankKeeper, suite.ctx, addr1, sdk.NewCoins(sdk.NewCoin("atom", sdk.NewInt(200))))
suite.Require().NoError(err)
_, err = antehandler(suite.ctx, tx, false)
suite.Require().Nil(err, "Tx errored after account has been set with sufficient funds")
}

View File

@ -1,76 +0,0 @@
package ante
import (
"fmt"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx"
)
var (
_ GasTx = (*legacytx.StdTx)(nil) // assert StdTx implements GasTx
)
// GasTx defines a Tx with a GetGas() method which is needed to use SetUpContextDecorator
type GasTx interface {
sdk.Tx
GetGas() uint64
}
// SetUpContextDecorator sets the GasMeter in the Context and wraps the next AnteHandler with a defer clause
// to recover from any downstream OutOfGas panics in the AnteHandler chain to return an error with information
// on gas provided and gas used.
// CONTRACT: Must be first decorator in the chain
// CONTRACT: Tx must implement GasTx interface
type SetUpContextDecorator struct{}
func NewSetUpContextDecorator() SetUpContextDecorator {
return SetUpContextDecorator{}
}
func (sud SetUpContextDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (newCtx sdk.Context, err error) {
// all transactions must implement GasTx
gasTx, ok := tx.(GasTx)
if !ok {
// Set a gas meter with limit 0 as to prevent an infinite gas meter attack
// during runTx.
newCtx = SetGasMeter(simulate, ctx, 0)
return newCtx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "Tx must be GasTx")
}
newCtx = SetGasMeter(simulate, ctx, gasTx.GetGas())
// Decorator will catch an OutOfGasPanic caused in the next antehandler
// AnteHandlers must have their own defer/recover in order for the BaseApp
// to know how much gas was used! This is because the GasMeter is created in
// the AnteHandler, but if it panics the context won't be set properly in
// runTx's recover call.
defer func() {
if r := recover(); r != nil {
switch rType := r.(type) {
case sdk.ErrorOutOfGas:
log := fmt.Sprintf(
"insufficient gas, gasOffered: %d, gasRequired: %d, code location: %v",
gasTx.GetGas(), newCtx.GasMeter().GasConsumed(), rType.Descriptor)
err = sdkerrors.Wrap(sdkerrors.ErrOutOfGas, log)
default:
panic(r)
}
}
}()
return next(newCtx, tx, simulate)
}
// SetGasMeter returns a new context with a gas meter set from a given context.
func SetGasMeter(simulate bool, ctx sdk.Context, gasLimit uint64) sdk.Context {
// In various cases such as simulation and during the genesis block, we do not
// meter any gas utilization.
if simulate || ctx.BlockHeight() == 0 {
return ctx.WithGasMeter(sdk.NewInfiniteGasMeter())
}
return ctx.WithGasMeter(sdk.NewGasMeter(gasLimit))
}

View File

@ -1,214 +0,0 @@
package ante_test
import (
"errors"
"fmt"
"testing"
minttypes "github.com/cosmos/cosmos-sdk/x/mint/types"
"github.com/stretchr/testify/suite"
tmproto "github.com/tendermint/tendermint/proto/tendermint/types"
"github.com/cosmos/cosmos-sdk/client"
"github.com/cosmos/cosmos-sdk/client/tx"
cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types"
"github.com/cosmos/cosmos-sdk/simapp"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/ante"
xauthsigning "github.com/cosmos/cosmos-sdk/x/auth/signing"
"github.com/cosmos/cosmos-sdk/x/auth/types"
authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
)
// TestAccount represents an account used in the tests in x/auth/ante.
type TestAccount struct {
acc types.AccountI
priv cryptotypes.PrivKey
}
// AnteTestSuite is a test suite to be used with ante handler tests.
type AnteTestSuite struct {
suite.Suite
app *simapp.SimApp
anteHandler sdk.AnteHandler
ctx sdk.Context
clientCtx client.Context
txBuilder client.TxBuilder
}
// returns context and app with params set on account keeper
func createTestApp(t *testing.T, isCheckTx bool) (*simapp.SimApp, sdk.Context) {
app := simapp.Setup(t, isCheckTx)
ctx := app.BaseApp.NewContext(isCheckTx, tmproto.Header{})
app.AccountKeeper.SetParams(ctx, authtypes.DefaultParams())
return app, ctx
}
// SetupTest setups a new test, with new app, context, and anteHandler.
func (suite *AnteTestSuite) SetupTest(isCheckTx bool) {
suite.app, suite.ctx = createTestApp(suite.T(), isCheckTx)
suite.ctx = suite.ctx.WithBlockHeight(1)
// Set up TxConfig.
encodingConfig := simapp.MakeTestEncodingConfig()
// We're using TestMsg encoding in some tests, so register it here.
encodingConfig.Amino.RegisterConcrete(&testdata.TestMsg{}, "testdata.TestMsg", nil)
testdata.RegisterInterfaces(encodingConfig.InterfaceRegistry)
suite.clientCtx = client.Context{}.
WithTxConfig(encodingConfig.TxConfig)
// We're not using ante.NewAnteHandler here because:
// - ante.NewAnteHandler doesn't have SetUpContextDecorator, as it has been
// moved to the gas TxMiddleware
// - whereas these tests have not been migrated to middlewares yet, so
// still need the SetUpContextDecorator.
//
// TODO: migrate all antehandler tests to middleware tests.
// https://github.com/cosmos/cosmos-sdk/issues/9585
anteDecorators := []sdk.AnteDecorator{
ante.NewSetUpContextDecorator(),
ante.NewRejectExtensionOptionsDecorator(),
ante.NewMempoolFeeDecorator(),
ante.NewValidateBasicDecorator(),
ante.NewTxTimeoutHeightDecorator(),
ante.NewValidateMemoDecorator(suite.app.AccountKeeper),
ante.NewConsumeGasForTxSizeDecorator(suite.app.AccountKeeper),
ante.NewDeductFeeDecorator(suite.app.AccountKeeper, suite.app.BankKeeper, suite.app.FeeGrantKeeper),
// SetPubKeyDecorator must be called before all signature verification decorators
ante.NewSetPubKeyDecorator(suite.app.AccountKeeper),
ante.NewValidateSigCountDecorator(suite.app.AccountKeeper),
ante.NewSigGasConsumeDecorator(suite.app.AccountKeeper, ante.DefaultSigVerificationGasConsumer),
ante.NewSigVerificationDecorator(suite.app.AccountKeeper, encodingConfig.TxConfig.SignModeHandler()),
ante.NewIncrementSequenceDecorator(suite.app.AccountKeeper),
}
suite.anteHandler = sdk.ChainAnteDecorators(anteDecorators...)
}
// CreateTestAccounts creates `numAccs` accounts, and return all relevant
// information about them including their private keys.
func (suite *AnteTestSuite) CreateTestAccounts(numAccs int) []TestAccount {
var accounts []TestAccount
for i := 0; i < numAccs; i++ {
priv, _, addr := testdata.KeyTestPubAddr()
acc := suite.app.AccountKeeper.NewAccountWithAddress(suite.ctx, addr)
err := acc.SetAccountNumber(uint64(i))
suite.Require().NoError(err)
suite.app.AccountKeeper.SetAccount(suite.ctx, acc)
someCoins := sdk.Coins{
sdk.NewInt64Coin("atom", 10000000),
}
err = suite.app.BankKeeper.MintCoins(suite.ctx, minttypes.ModuleName, someCoins)
suite.Require().NoError(err)
err = suite.app.BankKeeper.SendCoinsFromModuleToAccount(suite.ctx, minttypes.ModuleName, addr, someCoins)
suite.Require().NoError(err)
accounts = append(accounts, TestAccount{acc, priv})
}
return accounts
}
// CreateTestTx is a helper function to create a tx given multiple inputs.
func (suite *AnteTestSuite) CreateTestTx(privs []cryptotypes.PrivKey, accNums []uint64, accSeqs []uint64, chainID string) (xauthsigning.Tx, error) {
// First round: we gather all the signer infos. We use the "set empty
// signature" hack to do that.
var sigsV2 []signing.SignatureV2
for i, priv := range privs {
sigV2 := signing.SignatureV2{
PubKey: priv.PubKey(),
Data: &signing.SingleSignatureData{
SignMode: suite.clientCtx.TxConfig.SignModeHandler().DefaultMode(),
Signature: nil,
},
Sequence: accSeqs[i],
}
sigsV2 = append(sigsV2, sigV2)
}
err := suite.txBuilder.SetSignatures(sigsV2...)
if err != nil {
return nil, err
}
// Second round: all signer infos are set, so each signer can sign.
sigsV2 = []signing.SignatureV2{}
for i, priv := range privs {
signerData := xauthsigning.SignerData{
ChainID: chainID,
AccountNumber: accNums[i],
Sequence: accSeqs[i],
}
sigV2, err := tx.SignWithPrivKey(
suite.clientCtx.TxConfig.SignModeHandler().DefaultMode(), signerData,
suite.txBuilder, priv, suite.clientCtx.TxConfig, accSeqs[i])
if err != nil {
return nil, err
}
sigsV2 = append(sigsV2, sigV2)
}
err = suite.txBuilder.SetSignatures(sigsV2...)
if err != nil {
return nil, err
}
return suite.txBuilder.GetTx(), nil
}
// TestCase represents a test case used in test tables.
type TestCase struct {
desc string
malleate func()
simulate bool
expPass bool
expErr error
}
// CreateTestTx is a helper function to create a tx given multiple inputs.
func (suite *AnteTestSuite) RunTestCase(privs []cryptotypes.PrivKey, msgs []sdk.Msg, feeAmount sdk.Coins, gasLimit uint64, accNums, accSeqs []uint64, chainID string, tc TestCase) {
suite.Run(fmt.Sprintf("Case %s", tc.desc), func() {
suite.Require().NoError(suite.txBuilder.SetMsgs(msgs...))
suite.txBuilder.SetFeeAmount(feeAmount)
suite.txBuilder.SetGasLimit(gasLimit)
// Theoretically speaking, ante handler unit tests should only test
// ante handlers, but here we sometimes also test the tx creation
// process.
tx, txErr := suite.CreateTestTx(privs, accNums, accSeqs, chainID)
newCtx, anteErr := suite.anteHandler(suite.ctx, tx, tc.simulate)
if tc.expPass {
suite.Require().NoError(txErr)
suite.Require().NoError(anteErr)
suite.Require().NotNil(newCtx)
suite.ctx = newCtx
} else {
switch {
case txErr != nil:
suite.Require().Error(txErr)
suite.Require().True(errors.Is(txErr, tc.expErr))
case anteErr != nil:
suite.Require().Error(anteErr)
suite.Require().True(errors.Is(anteErr, tc.expErr))
default:
suite.Fail("expected one of txErr,anteErr to be an error")
}
}
})
}
func TestAnteTestSuite(t *testing.T) {
suite.Run(t, new(AnteTestSuite))
}

358
x/auth/middleware/basic.go Normal file
View File

@ -0,0 +1,358 @@
package middleware
import (
"context"
"github.com/cosmos/cosmos-sdk/codec/legacy"
"github.com/cosmos/cosmos-sdk/crypto/keys/multisig"
cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/types/tx"
"github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx"
authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing"
abci "github.com/tendermint/tendermint/abci/types"
)
type validateBasicTxHandler struct {
next tx.Handler
}
// ValidateBasicMiddleware will call tx.ValidateBasic, msg.ValidateBasic(for each msg inside tx)
// and return any non-nil error.
// If ValidateBasic passes, middleware calls next middleware in chain. Note,
// validateBasicTxHandler will not get executed on ReCheckTx since it
// is not dependent on application state.
func ValidateBasicMiddleware(txh tx.Handler) tx.Handler {
return validateBasicTxHandler{
next: txh,
}
}
var _ tx.Handler = validateBasicTxHandler{}
// validateBasicTxMsgs executes basic validator calls for messages.
func validateBasicTxMsgs(msgs []sdk.Msg) error {
if len(msgs) == 0 {
return sdkerrors.Wrap(sdkerrors.ErrInvalidRequest, "must contain at least one message")
}
for _, msg := range msgs {
err := msg.ValidateBasic()
if err != nil {
return err
}
}
return nil
}
// CheckTx implements tx.Handler.CheckTx.
func (txh validateBasicTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) {
// no need to validate basic on recheck tx, call next middleware
if req.Type == abci.CheckTxType_Recheck {
return txh.next.CheckTx(ctx, tx, req)
}
if err := validateBasicTxMsgs(tx.GetMsgs()); err != nil {
return abci.ResponseCheckTx{}, err
}
if err := tx.ValidateBasic(); err != nil {
return abci.ResponseCheckTx{}, err
}
return txh.next.CheckTx(ctx, tx, req)
}
// DeliverTx implements tx.Handler.DeliverTx.
func (txh validateBasicTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) {
if err := tx.ValidateBasic(); err != nil {
return abci.ResponseDeliverTx{}, err
}
if err := validateBasicTxMsgs(tx.GetMsgs()); err != nil {
return abci.ResponseDeliverTx{}, err
}
return txh.next.DeliverTx(ctx, tx, req)
}
// SimulateTx implements tx.Handler.SimulateTx.
func (txh validateBasicTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) {
if err := sdkTx.ValidateBasic(); err != nil {
return tx.ResponseSimulateTx{}, err
}
if err := validateBasicTxMsgs(sdkTx.GetMsgs()); err != nil {
return tx.ResponseSimulateTx{}, err
}
return txh.next.SimulateTx(ctx, sdkTx, req)
}
var _ tx.Handler = txTimeoutHeightTxHandler{}
type txTimeoutHeightTxHandler struct {
next tx.Handler
}
// TxTimeoutHeightMiddleware defines a middleware that checks for a
// tx height timeout.
func TxTimeoutHeightMiddleware(txh tx.Handler) tx.Handler {
return txTimeoutHeightTxHandler{
next: txh,
}
}
func checkTimeout(ctx context.Context, tx sdk.Tx) error {
sdkCtx := sdk.UnwrapSDKContext(ctx)
timeoutTx, ok := tx.(sdk.TxWithTimeoutHeight)
if !ok {
return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "expected tx to implement TxWithTimeoutHeight")
}
timeoutHeight := timeoutTx.GetTimeoutHeight()
if timeoutHeight > 0 && uint64(sdkCtx.BlockHeight()) > timeoutHeight {
return sdkerrors.Wrapf(
sdkerrors.ErrTxTimeoutHeight, "block height: %d, timeout height: %d", sdkCtx.BlockHeight(), timeoutHeight,
)
}
return nil
}
// CheckTx implements tx.Handler.CheckTx.
func (txh txTimeoutHeightTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) {
if err := checkTimeout(ctx, tx); err != nil {
return abci.ResponseCheckTx{}, err
}
return txh.next.CheckTx(ctx, tx, req)
}
// DeliverTx implements tx.Handler.DeliverTx.
func (txh txTimeoutHeightTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) {
if err := checkTimeout(ctx, tx); err != nil {
return abci.ResponseDeliverTx{}, err
}
return txh.next.DeliverTx(ctx, tx, req)
}
// SimulateTx implements tx.Handler.SimulateTx.
func (txh txTimeoutHeightTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) {
if err := checkTimeout(ctx, sdkTx); err != nil {
return tx.ResponseSimulateTx{}, err
}
return txh.next.SimulateTx(ctx, sdkTx, req)
}
type validateMemoTxHandler struct {
ak AccountKeeper
next tx.Handler
}
// ValidateMemoMiddleware will validate memo given the parameters passed in
// If memo is too large middleware returns with error, otherwise call next middleware
// CONTRACT: Tx must implement TxWithMemo interface
func ValidateMemoMiddleware(ak AccountKeeper) tx.Middleware {
return func(txHandler tx.Handler) tx.Handler {
return validateMemoTxHandler{
ak: ak,
next: txHandler,
}
}
}
var _ tx.Handler = validateMemoTxHandler{}
func (vmm validateMemoTxHandler) checkForValidMemo(ctx context.Context, tx sdk.Tx) error {
sdkCtx := sdk.UnwrapSDKContext(ctx)
memoTx, ok := tx.(sdk.TxWithMemo)
if !ok {
return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type")
}
params := vmm.ak.GetParams(sdkCtx)
memoLength := len(memoTx.GetMemo())
if uint64(memoLength) > params.MaxMemoCharacters {
return sdkerrors.Wrapf(sdkerrors.ErrMemoTooLarge,
"maximum number of characters is %d but received %d characters",
params.MaxMemoCharacters, memoLength,
)
}
return nil
}
// CheckTx implements tx.Handler.CheckTx method.
func (vmm validateMemoTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) {
if err := vmm.checkForValidMemo(ctx, tx); err != nil {
return abci.ResponseCheckTx{}, err
}
return vmm.next.CheckTx(ctx, tx, req)
}
// DeliverTx implements tx.Handler.DeliverTx method.
func (vmm validateMemoTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) {
if err := vmm.checkForValidMemo(ctx, tx); err != nil {
return abci.ResponseDeliverTx{}, err
}
return vmm.next.DeliverTx(ctx, tx, req)
}
// SimulateTx implements tx.Handler.SimulateTx method.
func (vmm validateMemoTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) {
if err := vmm.checkForValidMemo(ctx, sdkTx); err != nil {
return tx.ResponseSimulateTx{}, err
}
return vmm.next.SimulateTx(ctx, sdkTx, req)
}
var _ tx.Handler = consumeTxSizeGasTxHandler{}
type consumeTxSizeGasTxHandler struct {
ak AccountKeeper
next tx.Handler
}
// ConsumeTxSizeGasMiddleware will take in parameters and consume gas proportional
// to the size of tx before calling next middleware. Note, the gas costs will be
// slightly over estimated due to the fact that any given signing account may need
// to be retrieved from state.
//
// CONTRACT: If simulate=true, then signatures must either be completely filled
// in or empty.
// CONTRACT: To use this middleware, signatures of transaction must be represented
// as legacytx.StdSignature otherwise simulate mode will incorrectly estimate gas cost.
func ConsumeTxSizeGasMiddleware(ak AccountKeeper) tx.Middleware {
return func(txHandler tx.Handler) tx.Handler {
return consumeTxSizeGasTxHandler{
ak: ak,
next: txHandler,
}
}
}
func (cgts consumeTxSizeGasTxHandler) simulateSigGasCost(ctx context.Context, tx sdk.Tx) error {
sdkCtx := sdk.UnwrapSDKContext(ctx)
params := cgts.ak.GetParams(sdkCtx)
sigTx, ok := tx.(authsigning.SigVerifiableTx)
if !ok {
return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid tx type")
}
// in simulate mode, each element should be a nil signature
sigs, err := sigTx.GetSignaturesV2()
if err != nil {
return err
}
n := len(sigs)
for i, signer := range sigTx.GetSigners() {
// if signature is already filled in, no need to simulate gas cost
if i < n && !isIncompleteSignature(sigs[i].Data) {
continue
}
var pubkey cryptotypes.PubKey
acc := cgts.ak.GetAccount(sdkCtx, signer)
// use placeholder simSecp256k1Pubkey if sig is nil
if acc == nil || acc.GetPubKey() == nil {
pubkey = simSecp256k1Pubkey
} else {
pubkey = acc.GetPubKey()
}
// use stdsignature to mock the size of a full signature
simSig := legacytx.StdSignature{ //nolint:staticcheck // this will be removed when proto is ready
Signature: simSecp256k1Sig[:],
PubKey: pubkey,
}
sigBz := legacy.Cdc.MustMarshal(simSig)
cost := sdk.Gas(len(sigBz) + 6)
// If the pubkey is a multi-signature pubkey, then we estimate for the maximum
// number of signers.
if _, ok := pubkey.(*multisig.LegacyAminoPubKey); ok {
cost *= params.TxSigLimit
}
sdkCtx.GasMeter().ConsumeGas(params.TxSizeCostPerByte*cost, "txSize")
}
return nil
}
func (cgts consumeTxSizeGasTxHandler) consumeTxSizeGas(ctx context.Context, tx sdk.Tx, txBytes []byte, simulate bool) error {
sdkCtx := sdk.UnwrapSDKContext(ctx)
params := cgts.ak.GetParams(sdkCtx)
sdkCtx.GasMeter().ConsumeGas(params.TxSizeCostPerByte*sdk.Gas(len(txBytes)), "txSize")
return nil
}
// CheckTx implements tx.Handler.CheckTx.
func (cgts consumeTxSizeGasTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) {
if err := cgts.consumeTxSizeGas(ctx, tx, req.GetTx(), false); err != nil {
return abci.ResponseCheckTx{}, err
}
return cgts.next.CheckTx(ctx, tx, req)
}
// DeliverTx implements tx.Handler.DeliverTx.
func (cgts consumeTxSizeGasTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) {
if err := cgts.consumeTxSizeGas(ctx, tx, req.GetTx(), false); err != nil {
return abci.ResponseDeliverTx{}, err
}
return cgts.next.DeliverTx(ctx, tx, req)
}
// SimulateTx implements tx.Handler.SimulateTx.
func (cgts consumeTxSizeGasTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) {
if err := cgts.consumeTxSizeGas(ctx, sdkTx, req.TxBytes, true); err != nil {
return tx.ResponseSimulateTx{}, err
}
if err := cgts.simulateSigGasCost(ctx, sdkTx); err != nil {
return tx.ResponseSimulateTx{}, err
}
return cgts.next.SimulateTx(ctx, sdkTx, req)
}
// isIncompleteSignature tests whether SignatureData is fully filled in for simulation purposes
func isIncompleteSignature(data signing.SignatureData) bool {
if data == nil {
return true
}
switch data := data.(type) {
case *signing.SingleSignatureData:
return len(data.Signature) == 0
case *signing.MultiSignatureData:
if len(data.Signatures) == 0 {
return true
}
for _, s := range data.Signatures {
if isIncompleteSignature(s) {
return true
}
}
}
return false
}

View File

@ -0,0 +1,222 @@
package middleware_test
import (
"strings"
cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types"
"github.com/cosmos/cosmos-sdk/crypto/types/multisig"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/types/tx"
"github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/middleware"
"github.com/tendermint/tendermint/abci/types"
)
func (s *MWTestSuite) TestValidateBasic() {
ctx := s.SetupTest(true) // setup
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
txHandler := middleware.ComposeMiddlewares(noopTxHandler{}, middleware.ValidateBasicMiddleware)
// keys and addresses
priv1, _, addr1 := testdata.KeyTestPubAddr()
// msg and signatures
msg := testdata.NewTestMsg(addr1)
feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit()
s.Require().NoError(txBuilder.SetMsgs(msg))
txBuilder.SetFeeAmount(feeAmount)
txBuilder.SetGasLimit(gasLimit)
privs, accNums, accSeqs := []cryptotypes.PrivKey{}, []uint64{}, []uint64{}
invalidTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID())
s.Require().NoError(err)
_, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), invalidTx, types.RequestDeliverTx{})
s.Require().NotNil(err, "Did not error on invalid tx")
privs, accNums, accSeqs = []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0}
validTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID())
s.Require().NoError(err)
_, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), validTx, types.RequestDeliverTx{})
s.Require().Nil(err, "ValidateBasicMiddleware returned error on valid tx. err: %v", err)
// test middleware skips on recheck
ctx = ctx.WithIsReCheckTx(true)
// middleware should skip processing invalidTx on recheck and thus return nil-error
_, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), invalidTx, types.RequestDeliverTx{})
s.Require().Nil(err, "ValidateBasicMiddleware ran on ReCheck")
}
func (s *MWTestSuite) TestValidateMemo() {
ctx := s.SetupTest(true) // setup
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
txHandler := middleware.ComposeMiddlewares(noopTxHandler{}, middleware.ValidateMemoMiddleware(s.app.AccountKeeper))
// keys and addresses
priv1, _, addr1 := testdata.KeyTestPubAddr()
// msg and signatures
msg := testdata.NewTestMsg(addr1)
feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit()
s.Require().NoError(txBuilder.SetMsgs(msg))
txBuilder.SetFeeAmount(feeAmount)
txBuilder.SetGasLimit(gasLimit)
privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0}
txBuilder.SetMemo(strings.Repeat("01234567890", 500))
invalidTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID())
s.Require().NoError(err)
// require that long memos get rejected
_, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), invalidTx, types.RequestDeliverTx{})
s.Require().NotNil(err, "Did not error on tx with high memo")
txBuilder.SetMemo(strings.Repeat("01234567890", 10))
validTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID())
s.Require().NoError(err)
// require small memos pass ValidateMemo middleware
_, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), validTx, types.RequestDeliverTx{})
s.Require().Nil(err, "ValidateBasicMiddleware returned error on valid tx. err: %v", err)
}
func (s *MWTestSuite) TestConsumeGasForTxSize() {
ctx := s.SetupTest(true) // setup
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
txHandler := middleware.ComposeMiddlewares(noopTxHandler{}, middleware.ConsumeTxSizeGasMiddleware(s.app.AccountKeeper))
// keys and addresses
priv1, _, addr1 := testdata.KeyTestPubAddr()
// msg and signatures
msg := testdata.NewTestMsg(addr1)
feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit()
testCases := []struct {
name string
sigV2 signing.SignatureV2
}{
{"SingleSignatureData", signing.SignatureV2{PubKey: priv1.PubKey()}},
{"MultiSignatureData", signing.SignatureV2{PubKey: priv1.PubKey(), Data: multisig.NewMultisig(2)}},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
txBuilder = s.clientCtx.TxConfig.NewTxBuilder()
s.Require().NoError(txBuilder.SetMsgs(msg))
txBuilder.SetFeeAmount(feeAmount)
txBuilder.SetGasLimit(gasLimit)
txBuilder.SetMemo(strings.Repeat("01234567890", 10))
privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0}
testTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID())
s.Require().NoError(err)
txBytes, err := s.clientCtx.TxConfig.TxJSONEncoder()(testTx)
s.Require().Nil(err, "Cannot marshal tx: %v", err)
params := s.app.AccountKeeper.GetParams(ctx)
expectedGas := sdk.Gas(len(txBytes)) * params.TxSizeCostPerByte
// Set ctx with TxBytes manually
ctx = ctx.WithTxBytes(txBytes)
// track how much gas is necessary to retrieve parameters
beforeGas := ctx.GasMeter().GasConsumed()
s.app.AccountKeeper.GetParams(ctx)
afterGas := ctx.GasMeter().GasConsumed()
expectedGas += afterGas - beforeGas
beforeGas = ctx.GasMeter().GasConsumed()
_, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), testTx, types.RequestDeliverTx{Tx: txBytes})
s.Require().Nil(err, "ConsumeTxSizeGasMiddleware returned error: %v", err)
// require that middleware consumes expected amount of gas
consumedGas := ctx.GasMeter().GasConsumed() - beforeGas
s.Require().Equal(expectedGas, consumedGas, "Middleware did not consume the correct amount of gas")
// simulation must not underestimate gas of this middleware even with nil signatures
txBuilder, err := s.clientCtx.TxConfig.WrapTxBuilder(testTx)
s.Require().NoError(err)
s.Require().NoError(txBuilder.SetSignatures(tc.sigV2))
testTx = txBuilder.GetTx()
simTxBytes, err := s.clientCtx.TxConfig.TxJSONEncoder()(testTx)
s.Require().Nil(err, "Cannot marshal tx: %v", err)
// require that simulated tx is smaller than tx with signatures
s.Require().True(len(simTxBytes) < len(txBytes), "simulated tx still has signatures")
// Set s.ctx with smaller simulated TxBytes manually
ctx = ctx.WithTxBytes(simTxBytes)
beforeSimGas := ctx.GasMeter().GasConsumed()
// run txhandler in simulate mode
_, err = txHandler.SimulateTx(sdk.WrapSDKContext(ctx), testTx, tx.RequestSimulateTx{TxBytes: simTxBytes})
consumedSimGas := ctx.GasMeter().GasConsumed() - beforeSimGas
// require that txhandler passes and does not underestimate middleware cost
s.Require().Nil(err, "ConsumeTxSizeGasMiddleware returned error: %v", err)
s.Require().True(consumedSimGas >= expectedGas, "Simulate mode underestimates gas on Middleware. Simulated cost: %d, expected cost: %d", consumedSimGas, expectedGas)
})
}
}
func (s *MWTestSuite) TestTxHeightTimeoutMiddleware() {
ctx := s.SetupTest(true)
txHandler := middleware.ComposeMiddlewares(noopTxHandler{}, middleware.TxTimeoutHeightMiddleware)
// keys and addresses
priv1, _, addr1 := testdata.KeyTestPubAddr()
// msg and signatures
msg := testdata.NewTestMsg(addr1)
feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit()
testCases := []struct {
name string
timeout uint64
height int64
expectErr bool
}{
{"default value", 0, 10, false},
{"no timeout (greater height)", 15, 10, false},
{"no timeout (same height)", 10, 10, false},
{"timeout (smaller height)", 9, 10, true},
}
for _, tc := range testCases {
tc := tc
s.Run(tc.name, func() {
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
s.Require().NoError(txBuilder.SetMsgs(msg))
txBuilder.SetFeeAmount(feeAmount)
txBuilder.SetGasLimit(gasLimit)
txBuilder.SetMemo(strings.Repeat("01234567890", 10))
txBuilder.SetTimeoutHeight(tc.timeout)
privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0}
testTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID())
s.Require().NoError(err)
ctx := ctx.WithBlockHeight(tc.height)
_, err = txHandler.SimulateTx(sdk.WrapSDKContext(ctx), testTx, tx.RequestSimulateTx{})
s.Require().Equal(tc.expectErr, err != nil, err)
})
}
}

View File

@ -1,4 +1,4 @@
package ante package middleware
import ( import (
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
@ -6,7 +6,7 @@ import (
) )
// AccountKeeper defines the contract needed for AccountKeeper related APIs. // AccountKeeper defines the contract needed for AccountKeeper related APIs.
// Interface provides support to use non-sdk AccountKeeper for AnteHandler's decorators. // Interface provides support to use non-sdk AccountKeeper for TxHandler's middlewares.
type AccountKeeper interface { type AccountKeeper interface {
GetParams(ctx sdk.Context) (params types.Params) GetParams(ctx sdk.Context) (params types.Params)
GetAccount(ctx sdk.Context, addr sdk.AccAddress) types.AccountI GetAccount(ctx sdk.Context, addr sdk.AccAddress) types.AccountI

71
x/auth/middleware/ext.go Normal file
View File

@ -0,0 +1,71 @@
package middleware
import (
"context"
abci "github.com/tendermint/tendermint/abci/types"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/types/tx"
)
type HasExtensionOptionsTx interface {
GetExtensionOptions() []*codectypes.Any
GetNonCriticalExtensionOptions() []*codectypes.Any
}
type rejectExtensionOptionsTxHandler struct {
next tx.Handler
}
// RejectExtensionOptionsMiddleware creates a new rejectExtensionOptionsMiddleware.
// rejectExtensionOptionsMiddleware is a middleware that rejects all extension
// options which can optionally be included in protobuf transactions. Users that
// need extension options should create a custom middleware chain that handles
// needed extension options properly and rejects unknown ones.
func RejectExtensionOptionsMiddleware(txh tx.Handler) tx.Handler {
return rejectExtensionOptionsTxHandler{
next: txh,
}
}
var _ tx.Handler = rejectExtensionOptionsTxHandler{}
func checkExtOpts(tx sdk.Tx) error {
if hasExtOptsTx, ok := tx.(HasExtensionOptionsTx); ok {
if len(hasExtOptsTx.GetExtensionOptions()) != 0 {
return sdkerrors.ErrUnknownExtensionOptions
}
}
return nil
}
// CheckTx implements tx.Handler.CheckTx.
func (txh rejectExtensionOptionsTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) {
if err := checkExtOpts(tx); err != nil {
return abci.ResponseCheckTx{}, err
}
return txh.next.CheckTx(ctx, tx, req)
}
// DeliverTx implements tx.Handler.DeliverTx.
func (txh rejectExtensionOptionsTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) {
if err := checkExtOpts(tx); err != nil {
return abci.ResponseDeliverTx{}, err
}
return txh.next.DeliverTx(ctx, tx, req)
}
// SimulateTx implements tx.Handler.SimulateTx method.
func (txh rejectExtensionOptionsTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) {
if err := checkExtOpts(sdkTx); err != nil {
return tx.ResponseSimulateTx{}, err
}
return txh.next.SimulateTx(ctx, sdkTx, req)
}

View File

@ -0,0 +1,36 @@
package middleware_test
import (
"github.com/cosmos/cosmos-sdk/codec/types"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/auth/middleware"
"github.com/cosmos/cosmos-sdk/x/auth/tx"
abci "github.com/tendermint/tendermint/abci/types"
)
func (s *MWTestSuite) TestRejectExtensionOptionsMiddleware() {
ctx := s.SetupTest(true) // setup
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
txHandler := middleware.ComposeMiddlewares(noopTxHandler{}, middleware.RejectExtensionOptionsMiddleware)
// no extension options should not trigger an error
theTx := txBuilder.GetTx()
_, err := txHandler.CheckTx(sdk.WrapSDKContext(ctx), theTx, abci.RequestCheckTx{})
s.Require().NoError(err)
extOptsTxBldr, ok := txBuilder.(tx.ExtensionOptionsTxBuilder)
if !ok {
// if we can't set extension options, this middleware doesn't apply and we're done
return
}
// setting any extension option should cause an error
any, err := types.NewAnyWithValue(testdata.NewTestMsg())
s.Require().NoError(err)
extOptsTxBldr.SetExtensionOptions(any)
theTx = txBuilder.GetTx()
_, err = txHandler.CheckTx(sdk.WrapSDKContext(ctx), theTx, abci.RequestCheckTx{})
s.Require().EqualError(err, "unknown extension options")
}

194
x/auth/middleware/fee.go Normal file
View File

@ -0,0 +1,194 @@
package middleware
import (
"context"
"fmt"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/types/tx"
"github.com/cosmos/cosmos-sdk/x/auth/types"
abci "github.com/tendermint/tendermint/abci/types"
)
var _ tx.Handler = mempoolFeeTxHandler{}
type mempoolFeeTxHandler struct {
next tx.Handler
}
// MempoolFeeMiddleware will check if the transaction's fee is at least as large
// as the local validator's minimum gasFee (defined in validator config).
// If fee is too low, middleware returns error and tx is rejected from mempool.
// Note this only applies when ctx.CheckTx = true
// If fee is high enough or not CheckTx, then call next middleware
// CONTRACT: Tx must implement FeeTx to use MempoolFeeMiddleware
func MempoolFeeMiddleware(txh tx.Handler) tx.Handler {
return mempoolFeeTxHandler{
next: txh,
}
}
// CheckTx implements tx.Handler.CheckTx.
func (txh mempoolFeeTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) {
sdkCtx := sdk.UnwrapSDKContext(ctx)
feeTx, ok := tx.(sdk.FeeTx)
if !ok {
return abci.ResponseCheckTx{}, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "Tx must be a FeeTx")
}
feeCoins := feeTx.GetFee()
gas := feeTx.GetGas()
// Ensure that the provided fees meet a minimum threshold for the validator,
// if this is a CheckTx. This is only for local mempool purposes, and thus
// is only ran on check tx.
minGasPrices := sdkCtx.MinGasPrices()
if !minGasPrices.IsZero() {
requiredFees := make(sdk.Coins, len(minGasPrices))
// Determine the required fees by multiplying each required minimum gas
// price by the gas limit, where fee = ceil(minGasPrice * gasLimit).
glDec := sdk.NewDec(int64(gas))
for i, gp := range minGasPrices {
fee := gp.Amount.Mul(glDec)
requiredFees[i] = sdk.NewCoin(gp.Denom, fee.Ceil().RoundInt())
}
if !feeCoins.IsAnyGTE(requiredFees) {
return abci.ResponseCheckTx{}, sdkerrors.Wrapf(sdkerrors.ErrInsufficientFee, "insufficient fees; got: %s required: %s", feeCoins, requiredFees)
}
}
return txh.next.CheckTx(ctx, tx, req)
}
// DeliverTx implements tx.Handler.DeliverTx.
func (txh mempoolFeeTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) {
return txh.next.DeliverTx(ctx, tx, req)
}
// SimulateTx implements tx.Handler.SimulateTx.
func (txh mempoolFeeTxHandler) SimulateTx(ctx context.Context, tx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) {
return txh.next.SimulateTx(ctx, tx, req)
}
var _ tx.Handler = deductFeeTxHandler{}
type deductFeeTxHandler struct {
accountKeeper AccountKeeper
bankKeeper types.BankKeeper
feegrantKeeper FeegrantKeeper
next tx.Handler
}
// DeductFeeMiddleware deducts fees from the first signer of the tx
// If the first signer does not have the funds to pay for the fees, return with InsufficientFunds error
// Call next middleware if fees successfully deducted
// CONTRACT: Tx must implement FeeTx interface to use deductFeeTxHandler
func DeductFeeMiddleware(ak AccountKeeper, bk types.BankKeeper, fk FeegrantKeeper) tx.Middleware {
return func(txh tx.Handler) tx.Handler {
return deductFeeTxHandler{
accountKeeper: ak,
bankKeeper: bk,
feegrantKeeper: fk,
next: txh,
}
}
}
func (dfd deductFeeTxHandler) checkDeductFee(ctx context.Context, tx sdk.Tx) error {
sdkCtx := sdk.UnwrapSDKContext(ctx)
feeTx, ok := tx.(sdk.FeeTx)
if !ok {
return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "Tx must be a FeeTx")
}
if addr := dfd.accountKeeper.GetModuleAddress(types.FeeCollectorName); addr == nil {
panic(fmt.Sprintf("%s module account has not been set", types.FeeCollectorName))
}
fee := feeTx.GetFee()
feePayer := feeTx.FeePayer()
feeGranter := feeTx.FeeGranter()
deductFeesFrom := feePayer
// if feegranter set deduct fee from feegranter account.
// this works with only when feegrant enabled.
if feeGranter != nil {
if dfd.feegrantKeeper == nil {
return sdkerrors.Wrap(sdkerrors.ErrInvalidRequest, "fee grants are not enabled")
} else if !feeGranter.Equals(feePayer) {
err := dfd.feegrantKeeper.UseGrantedFees(sdkCtx, feeGranter, feePayer, fee, tx.GetMsgs())
if err != nil {
return sdkerrors.Wrapf(err, "%s not allowed to pay fees from %s", feeGranter, feePayer)
}
}
deductFeesFrom = feeGranter
}
deductFeesFromAcc := dfd.accountKeeper.GetAccount(sdkCtx, deductFeesFrom)
if deductFeesFromAcc == nil {
return sdkerrors.Wrapf(sdkerrors.ErrUnknownAddress, "fee payer address: %s does not exist", deductFeesFrom)
}
// deduct the fees
if !feeTx.GetFee().IsZero() {
err := DeductFees(dfd.bankKeeper, sdkCtx, deductFeesFromAcc, feeTx.GetFee())
if err != nil {
return err
}
}
events := sdk.Events{sdk.NewEvent(sdk.EventTypeTx,
sdk.NewAttribute(sdk.AttributeKeyFee, feeTx.GetFee().String()),
)}
sdkCtx.EventManager().EmitEvents(events)
return nil
}
// CheckTx implements tx.Handler.CheckTx.
func (dfd deductFeeTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) {
if err := dfd.checkDeductFee(ctx, tx); err != nil {
return abci.ResponseCheckTx{}, err
}
return dfd.next.CheckTx(ctx, tx, req)
}
// DeliverTx implements tx.Handler.DeliverTx.
func (dfd deductFeeTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) {
if err := dfd.checkDeductFee(ctx, tx); err != nil {
return abci.ResponseDeliverTx{}, err
}
return dfd.next.DeliverTx(ctx, tx, req)
}
func (dfd deductFeeTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) {
if err := dfd.checkDeductFee(ctx, sdkTx); err != nil {
return tx.ResponseSimulateTx{}, err
}
return dfd.next.SimulateTx(ctx, sdkTx, req)
}
// DeductFees deducts fees from the given account.
func DeductFees(bankKeeper types.BankKeeper, ctx sdk.Context, acc types.AccountI, fees sdk.Coins) error {
if !fees.IsValid() {
return sdkerrors.Wrapf(sdkerrors.ErrInsufficientFee, "invalid fee amount: %s", fees)
}
err := bankKeeper.SendCoinsFromAccountToModule(ctx, acc.GetAddress(), types.FeeCollectorName, fees)
if err != nil {
return sdkerrors.Wrapf(sdkerrors.ErrInsufficientFunds, err.Error())
}
return nil
}

View File

@ -0,0 +1,99 @@
package middleware_test
import (
cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/auth/middleware"
"github.com/cosmos/cosmos-sdk/x/bank/testutil"
abci "github.com/tendermint/tendermint/abci/types"
)
func (s *MWTestSuite) TestEnsureMempoolFees() {
ctx := s.SetupTest(true) // setup
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
txHandler := middleware.ComposeMiddlewares(noopTxHandler{}, middleware.MempoolFeeMiddleware)
// keys and addresses
priv1, _, addr1 := testdata.KeyTestPubAddr()
// msg and signatures
msg := testdata.NewTestMsg(addr1)
feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit()
s.Require().NoError(txBuilder.SetMsgs(msg))
txBuilder.SetFeeAmount(feeAmount)
txBuilder.SetGasLimit(gasLimit)
privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0}
tx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID())
s.Require().NoError(err)
// Set high gas price so standard test fee fails
atomPrice := sdk.NewDecCoinFromDec("atom", sdk.NewDec(200).Quo(sdk.NewDec(100000)))
highGasPrice := []sdk.DecCoin{atomPrice}
ctx = ctx.WithMinGasPrices(highGasPrice)
// txHandler errors with insufficient fees
_, err = txHandler.CheckTx(sdk.WrapSDKContext(ctx), tx, abci.RequestCheckTx{})
s.Require().NotNil(err, "Middleware should have errored on too low fee for local gasPrice")
// txHandler should not error since we do not check minGasPrice in DeliverTx
_, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), tx, abci.RequestDeliverTx{})
s.Require().Nil(err, "MempoolFeeMiddleware returned error in DeliverTx")
atomPrice = sdk.NewDecCoinFromDec("atom", sdk.NewDec(0).Quo(sdk.NewDec(100000)))
lowGasPrice := []sdk.DecCoin{atomPrice}
ctx = ctx.WithMinGasPrices(lowGasPrice)
_, err = txHandler.CheckTx(sdk.WrapSDKContext(ctx), tx, abci.RequestCheckTx{})
s.Require().Nil(err, "Middleware should not have errored on fee higher than local gasPrice")
}
func (s *MWTestSuite) TestDeductFees() {
ctx := s.SetupTest(false) // setup
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
txHandler := middleware.ComposeMiddlewares(
noopTxHandler{},
middleware.DeductFeeMiddleware(
s.app.AccountKeeper,
s.app.BankKeeper,
s.app.FeeGrantKeeper,
),
)
// keys and addresses
priv1, _, addr1 := testdata.KeyTestPubAddr()
// msg and signatures
msg := testdata.NewTestMsg(addr1)
feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit()
s.Require().NoError(txBuilder.SetMsgs(msg))
txBuilder.SetFeeAmount(feeAmount)
txBuilder.SetGasLimit(gasLimit)
privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0}
tx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID())
s.Require().NoError(err)
// Set account with insufficient funds
acc := s.app.AccountKeeper.NewAccountWithAddress(ctx, addr1)
s.app.AccountKeeper.SetAccount(ctx, acc)
coins := sdk.NewCoins(sdk.NewCoin("atom", sdk.NewInt(10)))
err = testutil.FundAccount(s.app.BankKeeper, ctx, addr1, coins)
s.Require().NoError(err)
_, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), tx, abci.RequestDeliverTx{})
s.Require().NotNil(err, "Tx did not error when fee payer had insufficient funds")
// Set account with sufficient funds
s.app.AccountKeeper.SetAccount(ctx, acc)
err = testutil.FundAccount(s.app.BankKeeper, ctx, addr1, sdk.NewCoins(sdk.NewCoin("atom", sdk.NewInt(200))))
s.Require().NoError(err)
_, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), tx, abci.RequestDeliverTx{})
s.Require().Nil(err, "Tx errored after account has been set with sufficient funds")
}

View File

@ -1,10 +1,11 @@
package ante_test package middleware_test
import ( import (
"math/rand" "math/rand"
"testing" "testing"
"time" "time"
abci "github.com/tendermint/tendermint/abci/types"
"github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/crypto"
"github.com/cosmos/cosmos-sdk/client" "github.com/cosmos/cosmos-sdk/client"
@ -15,7 +16,7 @@ import (
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/types/simulation" "github.com/cosmos/cosmos-sdk/types/simulation"
"github.com/cosmos/cosmos-sdk/types/tx/signing" "github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/ante" "github.com/cosmos/cosmos-sdk/x/auth/middleware"
authsign "github.com/cosmos/cosmos-sdk/x/auth/signing" authsign "github.com/cosmos/cosmos-sdk/x/auth/signing"
"github.com/cosmos/cosmos-sdk/x/auth/tx" "github.com/cosmos/cosmos-sdk/x/auth/tx"
authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
@ -23,19 +24,20 @@ import (
"github.com/cosmos/cosmos-sdk/x/feegrant" "github.com/cosmos/cosmos-sdk/x/feegrant"
) )
func (suite *AnteTestSuite) TestDeductFeesNoDelegation() { func (s *MWTestSuite) TestDeductFeesNoDelegation() {
suite.SetupTest(false) ctx := s.SetupTest(false) // setup
// setup app := s.app
app, ctx := suite.app, suite.ctx
protoTxCfg := tx.NewTxConfig(codec.NewProtoCodec(app.InterfaceRegistry()), tx.DefaultSignModes) protoTxCfg := tx.NewTxConfig(codec.NewProtoCodec(app.InterfaceRegistry()), tx.DefaultSignModes)
// this just tests our handler txHandler := middleware.ComposeMiddlewares(
dfd := ante.NewDeductFeeDecorator(app.AccountKeeper, app.BankKeeper, app.FeeGrantKeeper) noopTxHandler{},
feeAnteHandler := sdk.ChainAnteDecorators(dfd) middleware.DeductFeeMiddleware(
s.app.AccountKeeper,
// this tests the whole stack s.app.BankKeeper,
anteHandlerStack := suite.anteHandler s.app.FeeGrantKeeper,
),
)
// keys and addresses // keys and addresses
priv1, _, addr1 := testdata.KeyTestPubAddr() priv1, _, addr1 := testdata.KeyTestPubAddr()
@ -45,24 +47,24 @@ func (suite *AnteTestSuite) TestDeductFeesNoDelegation() {
priv5, _, addr5 := testdata.KeyTestPubAddr() priv5, _, addr5 := testdata.KeyTestPubAddr()
// Set addr1 with insufficient funds // Set addr1 with insufficient funds
err := testutil.FundAccount(suite.app.BankKeeper, suite.ctx, addr1, []sdk.Coin{sdk.NewCoin("atom", sdk.NewInt(10))}) err := testutil.FundAccount(s.app.BankKeeper, ctx, addr1, []sdk.Coin{sdk.NewCoin("atom", sdk.NewInt(10))})
suite.Require().NoError(err) s.Require().NoError(err)
// Set addr2 with more funds // Set addr2 with more funds
err = testutil.FundAccount(suite.app.BankKeeper, suite.ctx, addr2, []sdk.Coin{sdk.NewCoin("atom", sdk.NewInt(99999))}) err = testutil.FundAccount(s.app.BankKeeper, ctx, addr2, []sdk.Coin{sdk.NewCoin("atom", sdk.NewInt(99999))})
suite.Require().NoError(err) s.Require().NoError(err)
// grant fee allowance from `addr2` to `addr3` (plenty to pay) // grant fee allowance from `addr2` to `addr3` (plenty to pay)
err = app.FeeGrantKeeper.GrantAllowance(ctx, addr2, addr3, &feegrant.BasicAllowance{ err = app.FeeGrantKeeper.GrantAllowance(ctx, addr2, addr3, &feegrant.BasicAllowance{
SpendLimit: sdk.NewCoins(sdk.NewInt64Coin("atom", 500)), SpendLimit: sdk.NewCoins(sdk.NewInt64Coin("atom", 500)),
}) })
suite.Require().NoError(err) s.Require().NoError(err)
// grant low fee allowance (20atom), to check the tx requesting more than allowed. // grant low fee allowance (20atom), to check the tx requesting more than allowed.
err = app.FeeGrantKeeper.GrantAllowance(ctx, addr2, addr4, &feegrant.BasicAllowance{ err = app.FeeGrantKeeper.GrantAllowance(ctx, addr2, addr4, &feegrant.BasicAllowance{
SpendLimit: sdk.NewCoins(sdk.NewInt64Coin("atom", 20)), SpendLimit: sdk.NewCoins(sdk.NewInt64Coin("atom", 20)),
}) })
suite.Require().NoError(err) s.Require().NoError(err)
cases := map[string]struct { cases := map[string]struct {
signerKey cryptotypes.PrivKey signerKey cryptotypes.PrivKey
@ -133,7 +135,7 @@ func (suite *AnteTestSuite) TestDeductFeesNoDelegation() {
for name, stc := range cases { for name, stc := range cases {
tc := stc // to make scopelint happy tc := stc // to make scopelint happy
suite.T().Run(name, func(t *testing.T) { s.T().Run(name, func(t *testing.T) {
fee := sdk.NewCoins(sdk.NewInt64Coin("atom", tc.fee)) fee := sdk.NewCoins(sdk.NewInt64Coin("atom", tc.fee))
msgs := []sdk.Msg{testdata.NewTestMsg(tc.signer)} msgs := []sdk.Msg{testdata.NewTestMsg(tc.signer)}
@ -144,19 +146,22 @@ func (suite *AnteTestSuite) TestDeductFeesNoDelegation() {
} }
tx, err := genTxWithFeeGranter(protoTxCfg, msgs, fee, helpers.DefaultGenTxGas, ctx.ChainID(), accNums, seqs, tc.feeAccount, privs...) tx, err := genTxWithFeeGranter(protoTxCfg, msgs, fee, helpers.DefaultGenTxGas, ctx.ChainID(), accNums, seqs, tc.feeAccount, privs...)
suite.Require().NoError(err) s.Require().NoError(err)
_, err = feeAnteHandler(ctx, tx, false) // tests only feegrant ante
// tests only feegrant middleware
_, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), tx, abci.RequestDeliverTx{})
if tc.valid { if tc.valid {
suite.Require().NoError(err) s.Require().NoError(err)
} else { } else {
suite.Require().Error(err) s.Require().Error(err)
} }
_, err = anteHandlerStack(ctx, tx, false) // tests while stack // tests while stack
_, err = s.txHandler.DeliverTx(sdk.WrapSDKContext(ctx), tx, abci.RequestDeliverTx{})
if tc.valid { if tc.valid {
suite.Require().NoError(err) s.Require().NoError(err)
} else { } else {
suite.Require().Error(err) s.Require().Error(err)
} }
}) })
} }

View File

@ -70,7 +70,7 @@ func (s *MWTestSuite) TestSetup() {
if tc.expErr { if tc.expErr {
s.Require().EqualError(err, tc.errorStr) s.Require().EqualError(err, tc.errorStr)
} else { } else {
s.Require().Nil(err, "SetUpContextDecorator returned error") s.Require().Nil(err, "SetUpContextMiddleware returned error")
s.Require().Equal(tc.expGasLimit, uint64(res.GasWanted)) s.Require().Equal(tc.expGasLimit, uint64(res.GasWanted))
} }
}) })

View File

@ -1,115 +0,0 @@
package middleware
import (
"context"
abci "github.com/tendermint/tendermint/abci/types"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/types/tx"
)
type legacyAnteTxHandler struct {
anteHandler sdk.AnteHandler
inner tx.Handler
}
func newLegacyAnteMiddleware(anteHandler sdk.AnteHandler) tx.Middleware {
return func(txHandler tx.Handler) tx.Handler {
return legacyAnteTxHandler{
anteHandler: anteHandler,
inner: txHandler,
}
}
}
var _ tx.Handler = legacyAnteTxHandler{}
// CheckTx implements tx.Handler.CheckTx method.
func (txh legacyAnteTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) {
sdkCtx, err := txh.runAnte(ctx, tx, req.Tx, false)
if err != nil {
return abci.ResponseCheckTx{}, err
}
return txh.inner.CheckTx(sdk.WrapSDKContext(sdkCtx), tx, req)
}
// DeliverTx implements tx.Handler.DeliverTx method.
func (txh legacyAnteTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) {
sdkCtx, err := txh.runAnte(ctx, tx, req.Tx, false)
if err != nil {
return abci.ResponseDeliverTx{}, err
}
return txh.inner.DeliverTx(sdk.WrapSDKContext(sdkCtx), tx, req)
}
// SimulateTx implements tx.Handler.SimulateTx method.
func (txh legacyAnteTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) {
sdkCtx, err := txh.runAnte(ctx, sdkTx, req.TxBytes, true)
if err != nil {
return tx.ResponseSimulateTx{}, err
}
return txh.inner.SimulateTx(sdk.WrapSDKContext(sdkCtx), sdkTx, req)
}
func (txh legacyAnteTxHandler) runAnte(ctx context.Context, tx sdk.Tx, txBytes []byte, isSimulate bool) (sdk.Context, error) {
err := validateBasicTxMsgs(tx.GetMsgs())
if err != nil {
return sdk.Context{}, err
}
sdkCtx := sdk.UnwrapSDKContext(ctx)
if txh.anteHandler == nil {
return sdkCtx, nil
}
ms := sdkCtx.MultiStore()
// Branch context before AnteHandler call in case it aborts.
// This is required for both CheckTx and DeliverTx.
// Ref: https://github.com/cosmos/cosmos-sdk/issues/2772
//
// NOTE: Alternatively, we could require that AnteHandler ensures that
// writes do not happen if aborted/failed. This may have some
// performance benefits, but it'll be more difficult to get right.
anteCtx, msCache := cacheTxContext(sdkCtx, txBytes)
anteCtx = anteCtx.WithEventManager(sdk.NewEventManager())
newCtx, err := txh.anteHandler(anteCtx, tx, isSimulate)
if err != nil {
return sdk.Context{}, err
}
if !newCtx.IsZero() {
// At this point, newCtx.MultiStore() is a store branch, or something else
// replaced by the AnteHandler. We want the original multistore.
//
// Also, in the case of the tx aborting, we need to track gas consumed via
// the instantiated gas meter in the AnteHandler, so we update the context
// prior to returning.
sdkCtx = newCtx.WithMultiStore(ms)
}
msCache.Write()
return sdkCtx, nil
}
// validateBasicTxMsgs executes basic validator calls for messages.
func validateBasicTxMsgs(msgs []sdk.Msg) error {
if len(msgs) == 0 {
return sdkerrors.Wrap(sdkerrors.ErrInvalidRequest, "must contain at least one message")
}
for _, msg := range msgs {
err := msg.ValidateBasic()
if err != nil {
return err
}
}
return nil
}

View File

@ -2,7 +2,11 @@ package middleware
import ( import (
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/types/tx" "github.com/cosmos/cosmos-sdk/types/tx"
"github.com/cosmos/cosmos-sdk/types/tx/signing"
authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing"
"github.com/cosmos/cosmos-sdk/x/auth/types"
) )
// ComposeMiddlewares compose multiple middlewares on top of a tx.Handler. The // ComposeMiddlewares compose multiple middlewares on top of a tx.Handler. The
@ -35,12 +39,33 @@ type TxHandlerOptions struct {
LegacyRouter sdk.Router LegacyRouter sdk.Router
MsgServiceRouter *MsgServiceRouter MsgServiceRouter *MsgServiceRouter
LegacyAnteHandler sdk.AnteHandler AccountKeeper AccountKeeper
BankKeeper types.BankKeeper
FeegrantKeeper FeegrantKeeper
SignModeHandler authsigning.SignModeHandler
SigGasConsumer func(meter sdk.GasMeter, sig signing.SignatureV2, params types.Params) error
} }
// NewDefaultTxHandler defines a TxHandler middleware stacks that should work // NewDefaultTxHandler defines a TxHandler middleware stacks that should work
// for most applications. // for most applications.
func NewDefaultTxHandler(options TxHandlerOptions) (tx.Handler, error) { func NewDefaultTxHandler(options TxHandlerOptions) (tx.Handler, error) {
if options.AccountKeeper == nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrLogic, "account keeper is required for compose middlewares")
}
if options.BankKeeper == nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrLogic, "bank keeper is required for compose middlewares")
}
if options.SignModeHandler == nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrLogic, "sign mode handler is required for compose middlewares")
}
var sigGasConsumer = options.SigGasConsumer
if sigGasConsumer == nil {
sigGasConsumer = DefaultSigVerificationGasConsumer
}
return ComposeMiddlewares( return ComposeMiddlewares(
NewRunMsgsTxHandler(options.MsgServiceRouter, options.LegacyRouter), NewRunMsgsTxHandler(options.MsgServiceRouter, options.LegacyRouter),
// Set a new GasMeter on sdk.Context. // Set a new GasMeter on sdk.Context.
@ -55,8 +80,19 @@ func NewDefaultTxHandler(options TxHandlerOptions) (tx.Handler, error) {
// Choose which events to index in Tendermint. Make sure no events are // Choose which events to index in Tendermint. Make sure no events are
// emitted outside of this middleware. // emitted outside of this middleware.
NewIndexEventsTxMiddleware(options.IndexEvents), NewIndexEventsTxMiddleware(options.IndexEvents),
// Temporary middleware to bundle antehandlers. // Reject all extension options which can optionally be included in the
// TODO Remove in https://github.com/cosmos/cosmos-sdk/issues/9585. // tx.
newLegacyAnteMiddleware(options.LegacyAnteHandler), RejectExtensionOptionsMiddleware,
MempoolFeeMiddleware,
ValidateBasicMiddleware,
TxTimeoutHeightMiddleware,
ValidateMemoMiddleware(options.AccountKeeper),
ConsumeTxSizeGasMiddleware(options.AccountKeeper),
DeductFeeMiddleware(options.AccountKeeper, options.BankKeeper, options.FeegrantKeeper),
SetPubKeyMiddleware(options.AccountKeeper),
ValidateSigCountMiddleware(options.AccountKeeper),
SigGasConsumeMiddleware(options.AccountKeeper, sigGasConsumer),
SigVerificationMiddleware(options.AccountKeeper, options.SignModeHandler),
IncrementSequenceMiddleware(options.AccountKeeper),
), nil ), nil
} }

View File

@ -1,4 +1,4 @@
package ante_test package middleware_test
import ( import (
"encoding/json" "encoding/json"
@ -7,8 +7,6 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/require"
"github.com/cosmos/cosmos-sdk/crypto/keys/ed25519" "github.com/cosmos/cosmos-sdk/crypto/keys/ed25519"
kmultisig "github.com/cosmos/cosmos-sdk/crypto/keys/multisig" kmultisig "github.com/cosmos/cosmos-sdk/crypto/keys/multisig"
"github.com/cosmos/cosmos-sdk/crypto/keys/secp256k1" "github.com/cosmos/cosmos-sdk/crypto/keys/secp256k1"
@ -17,18 +15,21 @@ import (
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/types/tx/signing" "github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/ante" "github.com/cosmos/cosmos-sdk/x/auth/middleware"
"github.com/cosmos/cosmos-sdk/x/auth/types" "github.com/cosmos/cosmos-sdk/x/auth/types"
"github.com/cosmos/cosmos-sdk/x/bank/testutil" "github.com/cosmos/cosmos-sdk/x/bank/testutil"
minttypes "github.com/cosmos/cosmos-sdk/x/mint/types" minttypes "github.com/cosmos/cosmos-sdk/x/mint/types"
"github.com/stretchr/testify/require"
abci "github.com/tendermint/tendermint/abci/types"
) )
// Test that simulate transaction accurately estimates gas cost // Test that simulate transaction accurately estimates gas cost
func (suite *AnteTestSuite) TestSimulateGasCost() { func (s *MWTestSuite) TestSimulateGasCost() {
suite.SetupTest(false) // reset ctx := s.SetupTest(false) // reset
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
// Same data for every test cases // Same data for every test cases
accounts := suite.CreateTestAccounts(3) accounts := s.createTestAccounts(ctx, 3)
msgs := []sdk.Msg{ msgs := []sdk.Msg{
testdata.NewTestMsg(accounts[0].acc.GetAddress(), accounts[1].acc.GetAddress()), testdata.NewTestMsg(accounts[0].acc.GetAddress(), accounts[1].acc.GetAddress()),
testdata.NewTestMsg(accounts[2].acc.GetAddress(), accounts[0].acc.GetAddress()), testdata.NewTestMsg(accounts[2].acc.GetAddress(), accounts[0].acc.GetAddress()),
@ -44,8 +45,8 @@ func (suite *AnteTestSuite) TestSimulateGasCost() {
{ {
"tx with 150atom fee", "tx with 150atom fee",
func() { func() {
suite.txBuilder.SetFeeAmount(feeAmount) txBuilder.SetFeeAmount(feeAmount)
suite.txBuilder.SetGasLimit(gasLimit) txBuilder.SetGasLimit(gasLimit)
}, },
true, true,
true, true,
@ -54,11 +55,11 @@ func (suite *AnteTestSuite) TestSimulateGasCost() {
{ {
"with previously estimated gas", "with previously estimated gas",
func() { func() {
simulatedGas := suite.ctx.GasMeter().GasConsumed() simulatedGas := ctx.GasMeter().GasConsumed()
accSeqs = []uint64{1, 1, 1} accSeqs = []uint64{1, 1, 1}
suite.txBuilder.SetFeeAmount(feeAmount) txBuilder.SetFeeAmount(feeAmount)
suite.txBuilder.SetGasLimit(simulatedGas) txBuilder.SetGasLimit(simulatedGas)
}, },
false, false,
true, true,
@ -67,18 +68,18 @@ func (suite *AnteTestSuite) TestSimulateGasCost() {
} }
for _, tc := range testCases { for _, tc := range testCases {
suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { s.Run(fmt.Sprintf("Case %s", tc.desc), func() {
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
tc.malleate() tc.malleate()
suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc)
}) })
} }
} }
// Test various error cases in the AnteHandler control flow. // Test various error cases in the TxHandler control flow.
func (suite *AnteTestSuite) TestAnteHandlerSigErrors() { func (s *MWTestSuite) TestTxHandlerSigErrors() {
suite.SetupTest(false) // reset ctx := s.SetupTest(false) // reset
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
// Same data for every test cases // Same data for every test cases
priv0, _, addr0 := testdata.KeyTestPubAddr() priv0, _, addr0 := testdata.KeyTestPubAddr()
@ -105,12 +106,12 @@ func (suite *AnteTestSuite) TestAnteHandlerSigErrors() {
privs, accNums, accSeqs = []cryptotypes.PrivKey{}, []uint64{}, []uint64{} privs, accNums, accSeqs = []cryptotypes.PrivKey{}, []uint64{}, []uint64{}
// Create tx manually to test the tx's signers // Create tx manually to test the tx's signers
suite.Require().NoError(suite.txBuilder.SetMsgs(msgs...)) s.Require().NoError(txBuilder.SetMsgs(msgs...))
tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) tx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID())
suite.Require().NoError(err) s.Require().NoError(err)
// tx.GetSigners returns addresses in correct order: addr1, addr2, addr3 // tx.GetSigners returns addresses in correct order: addr1, addr2, addr3
expectedSigners := []sdk.AccAddress{addr0, addr1, addr2} expectedSigners := []sdk.AccAddress{addr0, addr1, addr2}
suite.Require().Equal(expectedSigners, tx.GetSigners()) s.Require().Equal(expectedSigners, tx.GetSigners())
}, },
false, false,
false, false,
@ -137,12 +138,12 @@ func (suite *AnteTestSuite) TestAnteHandlerSigErrors() {
{ {
"save the first account, but second is still unrecognized", "save the first account, but second is still unrecognized",
func() { func() {
acc1 := suite.app.AccountKeeper.NewAccountWithAddress(suite.ctx, addr0) acc1 := s.app.AccountKeeper.NewAccountWithAddress(ctx, addr0)
suite.app.AccountKeeper.SetAccount(suite.ctx, acc1) s.app.AccountKeeper.SetAccount(ctx, acc1)
err := suite.app.BankKeeper.MintCoins(suite.ctx, minttypes.ModuleName, feeAmount) err := s.app.BankKeeper.MintCoins(ctx, minttypes.ModuleName, feeAmount)
suite.Require().NoError(err) s.Require().NoError(err)
err = suite.app.BankKeeper.SendCoinsFromModuleToAccount(suite.ctx, minttypes.ModuleName, addr0, feeAmount) err = s.app.BankKeeper.SendCoinsFromModuleToAccount(ctx, minttypes.ModuleName, addr0, feeAmount)
suite.Require().NoError(err) s.Require().NoError(err)
}, },
false, false,
false, false,
@ -151,21 +152,21 @@ func (suite *AnteTestSuite) TestAnteHandlerSigErrors() {
} }
for _, tc := range testCases { for _, tc := range testCases {
suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { s.Run(fmt.Sprintf("Case %s", tc.desc), func() {
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
tc.malleate() tc.malleate()
suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc)
}) })
} }
} }
// Test logic around account number checking with one signer and many signers. // Test logic around account number checking with one signer and many signers.
func (suite *AnteTestSuite) TestAnteHandlerAccountNumbers() { func (s *MWTestSuite) TestTxHandlerAccountNumbers() {
suite.SetupTest(false) // reset ctx := s.SetupTest(false) // reset
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
// Same data for every test cases // Same data for every test cases
accounts := suite.CreateTestAccounts(2) accounts := s.createTestAccounts(ctx, 2)
feeAmount := testdata.NewTestFeeAmount() feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit() gasLimit := testdata.NewTestGasLimit()
@ -232,22 +233,22 @@ func (suite *AnteTestSuite) TestAnteHandlerAccountNumbers() {
} }
for _, tc := range testCases { for _, tc := range testCases {
suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { s.Run(fmt.Sprintf("Case %s", tc.desc), func() {
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
tc.malleate() tc.malleate()
suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc)
}) })
} }
} }
// Test logic around account number checking with many signers when BlockHeight is 0. // Test logic around account number checking with many signers when BlockHeight is 0.
func (suite *AnteTestSuite) TestAnteHandlerAccountNumbersAtBlockHeightZero() { func (s *MWTestSuite) TestTxHandlerAccountNumbersAtBlockHeightZero() {
suite.SetupTest(false) // setup ctx := s.SetupTest(false) // setup
suite.ctx = suite.ctx.WithBlockHeight(0) ctx = ctx.WithBlockHeight(0)
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
// Same data for every test cases // Same data for every test cases
accounts := suite.CreateTestAccounts(2) accounts := s.createTestAccounts(ctx, 2)
feeAmount := testdata.NewTestFeeAmount() feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit() gasLimit := testdata.NewTestGasLimit()
@ -316,21 +317,21 @@ func (suite *AnteTestSuite) TestAnteHandlerAccountNumbersAtBlockHeightZero() {
} }
for _, tc := range testCases { for _, tc := range testCases {
suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { s.Run(fmt.Sprintf("Case %s", tc.desc), func() {
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
tc.malleate() tc.malleate()
suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc)
}) })
} }
} }
// Test logic around sequence checking with one signer and many signers. // Test logic around sequence checking with one signer and many signers.
func (suite *AnteTestSuite) TestAnteHandlerSequences() { func (s *MWTestSuite) TestTxHandlerSequences() {
suite.SetupTest(false) // setup ctx := s.SetupTest(false) // setup
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
// Same data for every test cases // Same data for every test cases
accounts := suite.CreateTestAccounts(3) accounts := s.createTestAccounts(ctx, 3)
feeAmount := testdata.NewTestFeeAmount() feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit() gasLimit := testdata.NewTestGasLimit()
@ -428,24 +429,24 @@ func (suite *AnteTestSuite) TestAnteHandlerSequences() {
} }
for _, tc := range testCases { for _, tc := range testCases {
suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { s.Run(fmt.Sprintf("Case %s", tc.desc), func() {
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
tc.malleate() tc.malleate()
suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc)
}) })
} }
} }
// Test logic around fee deduction. // Test logic around fee deduction.
func (suite *AnteTestSuite) TestAnteHandlerFees() { func (s *MWTestSuite) TestTxHandlerFees() {
suite.SetupTest(false) // setup ctx := s.SetupTest(false) // setup
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
// Same data for every test cases // Same data for every test cases
priv0, _, addr0 := testdata.KeyTestPubAddr() priv0, _, addr0 := testdata.KeyTestPubAddr()
acc1 := suite.app.AccountKeeper.NewAccountWithAddress(suite.ctx, addr0) acc1 := s.app.AccountKeeper.NewAccountWithAddress(ctx, addr0)
suite.app.AccountKeeper.SetAccount(suite.ctx, acc1) s.app.AccountKeeper.SetAccount(ctx, acc1)
msgs := []sdk.Msg{testdata.NewTestMsg(addr0)} msgs := []sdk.Msg{testdata.NewTestMsg(addr0)}
feeAmount := testdata.NewTestFeeAmount() feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit() gasLimit := testdata.NewTestGasLimit()
@ -470,8 +471,8 @@ func (suite *AnteTestSuite) TestAnteHandlerFees() {
{ {
"signer does not have enough funds to pay the fee", "signer does not have enough funds to pay the fee",
func() { func() {
err := testutil.FundAccount(suite.app.BankKeeper, suite.ctx, addr0, sdk.NewCoins(sdk.NewInt64Coin("atom", 149))) err := testutil.FundAccount(s.app.BankKeeper, ctx, addr0, sdk.NewCoins(sdk.NewInt64Coin("atom", 149)))
suite.Require().NoError(err) s.Require().NoError(err)
}, },
false, false,
false, false,
@ -482,13 +483,13 @@ func (suite *AnteTestSuite) TestAnteHandlerFees() {
func() { func() {
accNums = []uint64{acc1.GetAccountNumber()} accNums = []uint64{acc1.GetAccountNumber()}
modAcc := suite.app.AccountKeeper.GetModuleAccount(suite.ctx, types.FeeCollectorName) modAcc := s.app.AccountKeeper.GetModuleAccount(ctx, types.FeeCollectorName)
suite.Require().True(suite.app.BankKeeper.GetAllBalances(suite.ctx, modAcc.GetAddress()).Empty()) s.Require().True(s.app.BankKeeper.GetAllBalances(ctx, modAcc.GetAddress()).Empty())
require.True(sdk.IntEq(suite.T(), suite.app.BankKeeper.GetAllBalances(suite.ctx, addr0).AmountOf("atom"), sdk.NewInt(149))) require.True(sdk.IntEq(s.T(), s.app.BankKeeper.GetAllBalances(ctx, addr0).AmountOf("atom"), sdk.NewInt(149)))
err := testutil.FundAccount(suite.app.BankKeeper, suite.ctx, addr0, sdk.NewCoins(sdk.NewInt64Coin("atom", 1))) err := testutil.FundAccount(s.app.BankKeeper, ctx, addr0, sdk.NewCoins(sdk.NewInt64Coin("atom", 1)))
suite.Require().NoError(err) s.Require().NoError(err)
}, },
false, false,
true, true,
@ -497,10 +498,10 @@ func (suite *AnteTestSuite) TestAnteHandlerFees() {
{ {
"signer doesn't have any more funds", "signer doesn't have any more funds",
func() { func() {
modAcc := suite.app.AccountKeeper.GetModuleAccount(suite.ctx, types.FeeCollectorName) modAcc := s.app.AccountKeeper.GetModuleAccount(ctx, types.FeeCollectorName)
require.True(sdk.IntEq(suite.T(), suite.app.BankKeeper.GetAllBalances(suite.ctx, modAcc.GetAddress()).AmountOf("atom"), sdk.NewInt(150))) require.True(sdk.IntEq(s.T(), s.app.BankKeeper.GetAllBalances(ctx, modAcc.GetAddress()).AmountOf("atom"), sdk.NewInt(150)))
require.True(sdk.IntEq(suite.T(), suite.app.BankKeeper.GetAllBalances(suite.ctx, addr0).AmountOf("atom"), sdk.NewInt(0))) require.True(sdk.IntEq(s.T(), s.app.BankKeeper.GetAllBalances(ctx, addr0).AmountOf("atom"), sdk.NewInt(0)))
}, },
false, false,
false, false,
@ -509,22 +510,21 @@ func (suite *AnteTestSuite) TestAnteHandlerFees() {
} }
for _, tc := range testCases { for _, tc := range testCases {
suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { s.Run(fmt.Sprintf("Case %s", tc.desc), func() {
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
tc.malleate() tc.malleate()
suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc)
}) })
} }
} }
// Test logic around memo gas consumption. // Test logic around memo gas consumption.
func (suite *AnteTestSuite) TestAnteHandlerMemoGas() { func (s *MWTestSuite) TestTxHandlerMemoGas() {
suite.SetupTest(false) // setup ctx := s.SetupTest(false) // setup
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
// Same data for every test cases // Same data for every test cases
accounts := suite.CreateTestAccounts(1) accounts := s.createTestAccounts(ctx, 1)
msgs := []sdk.Msg{testdata.NewTestMsg(accounts[0].acc.GetAddress())} msgs := []sdk.Msg{testdata.NewTestMsg(accounts[0].acc.GetAddress())}
privs, accNums, accSeqs := []cryptotypes.PrivKey{accounts[0].priv}, []uint64{0}, []uint64{0} privs, accNums, accSeqs := []cryptotypes.PrivKey{accounts[0].priv}, []uint64{0}, []uint64{0}
@ -550,7 +550,7 @@ func (suite *AnteTestSuite) TestAnteHandlerMemoGas() {
func() { func() {
feeAmount = sdk.NewCoins(sdk.NewInt64Coin("atom", 0)) feeAmount = sdk.NewCoins(sdk.NewInt64Coin("atom", 0))
gasLimit = 801 gasLimit = 801
suite.txBuilder.SetMemo("abcininasidniandsinasindiansdiansdinaisndiasndiadninsd") txBuilder.SetMemo("abcininasidniandsinasindiansdiansdinaisndiasndiadninsd")
}, },
false, false,
false, false,
@ -561,7 +561,7 @@ func (suite *AnteTestSuite) TestAnteHandlerMemoGas() {
func() { func() {
feeAmount = sdk.NewCoins(sdk.NewInt64Coin("atom", 0)) feeAmount = sdk.NewCoins(sdk.NewInt64Coin("atom", 0))
gasLimit = 50000 gasLimit = 50000
suite.txBuilder.SetMemo(strings.Repeat("01234567890", 500)) txBuilder.SetMemo(strings.Repeat("01234567890", 500))
}, },
false, false,
false, false,
@ -572,7 +572,7 @@ func (suite *AnteTestSuite) TestAnteHandlerMemoGas() {
func() { func() {
feeAmount = sdk.NewCoins(sdk.NewInt64Coin("atom", 0)) feeAmount = sdk.NewCoins(sdk.NewInt64Coin("atom", 0))
gasLimit = 50000 gasLimit = 50000
suite.txBuilder.SetMemo(strings.Repeat("0123456789", 10)) txBuilder.SetMemo(strings.Repeat("0123456789", 10))
}, },
false, false,
true, true,
@ -581,20 +581,20 @@ func (suite *AnteTestSuite) TestAnteHandlerMemoGas() {
} }
for _, tc := range testCases { for _, tc := range testCases {
suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { s.Run(fmt.Sprintf("Case %s", tc.desc), func() {
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
tc.malleate() tc.malleate()
suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc)
}) })
} }
} }
func (suite *AnteTestSuite) TestAnteHandlerMultiSigner() { func (s *MWTestSuite) TestTxHandlerMultiSigner() {
suite.SetupTest(false) // setup ctx := s.SetupTest(false) // setup
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
// Same data for every test cases // Same data for every test cases
accounts := suite.CreateTestAccounts(3) accounts := s.createTestAccounts(ctx, 3)
msg1 := testdata.NewTestMsg(accounts[0].acc.GetAddress(), accounts[1].acc.GetAddress()) msg1 := testdata.NewTestMsg(accounts[0].acc.GetAddress(), accounts[1].acc.GetAddress())
msg2 := testdata.NewTestMsg(accounts[2].acc.GetAddress(), accounts[0].acc.GetAddress()) msg2 := testdata.NewTestMsg(accounts[2].acc.GetAddress(), accounts[0].acc.GetAddress())
msg3 := testdata.NewTestMsg(accounts[1].acc.GetAddress(), accounts[2].acc.GetAddress()) msg3 := testdata.NewTestMsg(accounts[1].acc.GetAddress(), accounts[2].acc.GetAddress())
@ -615,7 +615,7 @@ func (suite *AnteTestSuite) TestAnteHandlerMultiSigner() {
func() { func() {
msgs = []sdk.Msg{msg1, msg2, msg3} msgs = []sdk.Msg{msg1, msg2, msg3}
privs, accNums, accSeqs = []cryptotypes.PrivKey{accounts[0].priv, accounts[1].priv, accounts[2].priv}, []uint64{0, 1, 2}, []uint64{0, 0, 0} privs, accNums, accSeqs = []cryptotypes.PrivKey{accounts[0].priv, accounts[1].priv, accounts[2].priv}, []uint64{0, 1, 2}, []uint64{0, 0, 0}
suite.txBuilder.SetMemo("Check signers are in expected order and different account numbers works") txBuilder.SetMemo("Check signers are in expected order and different account numbers works")
}, },
false, false,
true, true,
@ -654,20 +654,20 @@ func (suite *AnteTestSuite) TestAnteHandlerMultiSigner() {
} }
for _, tc := range testCases { for _, tc := range testCases {
suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { s.Run(fmt.Sprintf("Case %s", tc.desc), func() {
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
tc.malleate() tc.malleate()
suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc)
}) })
} }
} }
func (suite *AnteTestSuite) TestAnteHandlerBadSignBytes() { func (s *MWTestSuite) TestTxHandlerBadSignBytes() {
suite.SetupTest(false) // setup ctx := s.SetupTest(true) // setup
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
// Same data for every test cases // Same data for every test cases
accounts := suite.CreateTestAccounts(2) accounts := s.createTestAccounts(ctx, 2)
msg0 := testdata.NewTestMsg(accounts[0].acc.GetAddress()) msg0 := testdata.NewTestMsg(accounts[0].acc.GetAddress())
// Variable data per test case // Variable data per test case
@ -685,7 +685,7 @@ func (suite *AnteTestSuite) TestAnteHandlerBadSignBytes() {
{ {
"test good tx and signBytes", "test good tx and signBytes",
func() { func() {
chainID = suite.ctx.ChainID() chainID = ctx.ChainID()
feeAmount = testdata.NewTestFeeAmount() feeAmount = testdata.NewTestFeeAmount()
gasLimit = testdata.NewTestGasLimit() gasLimit = testdata.NewTestGasLimit()
msgs = []sdk.Msg{msg0} msgs = []sdk.Msg{msg0}
@ -708,7 +708,7 @@ func (suite *AnteTestSuite) TestAnteHandlerBadSignBytes() {
{ {
"test wrong accSeqs", "test wrong accSeqs",
func() { func() {
chainID = suite.ctx.ChainID() // Back to correct chainID chainID = ctx.ChainID() // Back to correct chainID
accSeqs = []uint64{2} accSeqs = []uint64{2}
}, },
false, false,
@ -780,20 +780,20 @@ func (suite *AnteTestSuite) TestAnteHandlerBadSignBytes() {
} }
for _, tc := range testCases { for _, tc := range testCases {
suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { s.Run(fmt.Sprintf("Case %s", tc.desc), func() {
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
tc.malleate() tc.malleate()
suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, chainID, tc) s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, chainID, tc)
}) })
} }
} }
func (suite *AnteTestSuite) TestAnteHandlerSetPubKey() { func (s *MWTestSuite) TestTxHandlerSetPubKey() {
suite.SetupTest(false) // setup ctx := s.SetupTest(true) // setup
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
// Same data for every test cases // Same data for every test cases
accounts := suite.CreateTestAccounts(2) accounts := s.createTestAccounts(ctx, 2)
feeAmount := testdata.NewTestFeeAmount() feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit() gasLimit := testdata.NewTestGasLimit()
@ -820,8 +820,8 @@ func (suite *AnteTestSuite) TestAnteHandlerSetPubKey() {
"make sure public key has been set (tx itself should fail because of replay protection)", "make sure public key has been set (tx itself should fail because of replay protection)",
func() { func() {
// Make sure public key has been set from previous test. // Make sure public key has been set from previous test.
acc0 := suite.app.AccountKeeper.GetAccount(suite.ctx, accounts[0].acc.GetAddress()) acc0 := s.app.AccountKeeper.GetAccount(ctx, accounts[0].acc.GetAddress())
suite.Require().Equal(acc0.GetPubKey(), accounts[0].priv.PubKey()) s.Require().Equal(acc0.GetPubKey(), accounts[0].priv.PubKey())
}, },
false, false,
false, false,
@ -841,30 +841,30 @@ func (suite *AnteTestSuite) TestAnteHandlerSetPubKey() {
"make sure public key is not set, when tx has no pubkey or signature", "make sure public key is not set, when tx has no pubkey or signature",
func() { func() {
// Make sure public key has not been set from previous test. // Make sure public key has not been set from previous test.
acc1 := suite.app.AccountKeeper.GetAccount(suite.ctx, accounts[1].acc.GetAddress()) acc1 := s.app.AccountKeeper.GetAccount(ctx, accounts[1].acc.GetAddress())
suite.Require().Nil(acc1.GetPubKey()) s.Require().Nil(acc1.GetPubKey())
privs, accNums, accSeqs = []cryptotypes.PrivKey{accounts[1].priv}, []uint64{1}, []uint64{0} privs, accNums, accSeqs = []cryptotypes.PrivKey{accounts[1].priv}, []uint64{1}, []uint64{0}
msgs = []sdk.Msg{testdata.NewTestMsg(accounts[1].acc.GetAddress())} msgs = []sdk.Msg{testdata.NewTestMsg(accounts[1].acc.GetAddress())}
suite.txBuilder.SetMsgs(msgs...) txBuilder.SetMsgs(msgs...)
suite.txBuilder.SetFeeAmount(feeAmount) txBuilder.SetFeeAmount(feeAmount)
suite.txBuilder.SetGasLimit(gasLimit) txBuilder.SetGasLimit(gasLimit)
// Manually create tx, and remove signature. // Manually create tx, and remove signature.
tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) tx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID())
suite.Require().NoError(err) s.Require().NoError(err)
txBuilder, err := suite.clientCtx.TxConfig.WrapTxBuilder(tx) txBuilder, err := s.clientCtx.TxConfig.WrapTxBuilder(tx)
suite.Require().NoError(err) s.Require().NoError(err)
suite.Require().NoError(txBuilder.SetSignatures()) s.Require().NoError(txBuilder.SetSignatures())
// Run anteHandler manually, expect ErrNoSignatures. // Run txHandler manually, expect ErrNoSignatures.
_, err = suite.anteHandler(suite.ctx, txBuilder.GetTx(), false) _, err = s.txHandler.CheckTx(sdk.WrapSDKContext(ctx), txBuilder.GetTx(), abci.RequestCheckTx{})
suite.Require().Error(err) s.Require().Error(err)
suite.Require().True(errors.Is(err, sdkerrors.ErrNoSignatures)) s.Require().True(errors.Is(err, sdkerrors.ErrNoSignatures))
// Make sure public key has not been set. // Make sure public key has not been set.
acc1 = suite.app.AccountKeeper.GetAccount(suite.ctx, accounts[1].acc.GetAddress()) acc1 = s.app.AccountKeeper.GetAccount(ctx, accounts[1].acc.GetAddress())
suite.Require().Nil(acc1.GetPubKey()) s.Require().Nil(acc1.GetPubKey())
// Set incorrect accSeq, to generate incorrect signature. // Set incorrect accSeq, to generate incorrect signature.
privs, accNums, accSeqs = []cryptotypes.PrivKey{accounts[1].priv}, []uint64{1}, []uint64{1} privs, accNums, accSeqs = []cryptotypes.PrivKey{accounts[1].priv}, []uint64{1}, []uint64{1}
@ -876,10 +876,10 @@ func (suite *AnteTestSuite) TestAnteHandlerSetPubKey() {
{ {
"make sure previous public key has been set after wrong signature", "make sure previous public key has been set after wrong signature",
func() { func() {
// Make sure public key has been set, as SetPubKeyDecorator // Make sure public key has been set, as SetPubKeyMiddleware
// is called before all signature verification decorators. // is called before all signature verification middlewares.
acc1 := suite.app.AccountKeeper.GetAccount(suite.ctx, accounts[1].acc.GetAddress()) acc1 := s.app.AccountKeeper.GetAccount(ctx, accounts[1].acc.GetAddress())
suite.Require().Equal(acc1.GetPubKey(), accounts[1].priv.PubKey()) s.Require().Equal(acc1.GetPubKey(), accounts[1].priv.PubKey())
}, },
false, false,
false, false,
@ -888,11 +888,10 @@ func (suite *AnteTestSuite) TestAnteHandlerSetPubKey() {
} }
for _, tc := range testCases { for _, tc := range testCases {
suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { s.Run(fmt.Sprintf("Case %s", tc.desc), func() {
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
tc.malleate() tc.malleate()
suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc)
}) })
} }
} }
@ -962,16 +961,17 @@ func TestCountSubkeys(t *testing.T) {
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(T *testing.T) { t.Run(tc.name, func(T *testing.T) {
require.Equal(t, tc.want, ante.CountSubKeys(tc.args.pub)) require.Equal(t, tc.want, middleware.CountSubKeys(tc.args.pub))
}) })
} }
} }
func (suite *AnteTestSuite) TestAnteHandlerSigLimitExceeded() { func (s *MWTestSuite) TestTxHandlerSigLimitExceeded() {
suite.SetupTest(false) // setup ctx := s.SetupTest(false) // setup
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
// Same data for every test cases // Same data for every test cases
accounts := suite.CreateTestAccounts(8) accounts := s.createTestAccounts(ctx, 8)
var addrs []sdk.AccAddress var addrs []sdk.AccAddress
var privs []cryptotypes.PrivKey var privs []cryptotypes.PrivKey
for i := 0; i < 8; i++ { for i := 0; i < 8; i++ {
@ -994,26 +994,25 @@ func (suite *AnteTestSuite) TestAnteHandlerSigLimitExceeded() {
} }
for _, tc := range testCases { for _, tc := range testCases {
suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { s.Run(fmt.Sprintf("Case %s", tc.desc), func() {
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
tc.malleate() tc.malleate()
suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) s.runTestCase(ctx, txBuilder, privs, msgs, feeAmount, gasLimit, accNums, accSeqs, ctx.ChainID(), tc)
}) })
} }
} }
// Test custom SignatureVerificationGasConsumer // Test custom SignatureVerificationGasConsumer
func (suite *AnteTestSuite) TestCustomSignatureVerificationGasConsumer() { func (s *MWTestSuite) TestCustomSignatureVerificationGasConsumer() {
suite.SetupTest(false) // setup ctx := s.SetupTest(false) // setup
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
// setup an ante handler that only accepts PubKeyEd25519 txHandler, err := middleware.NewDefaultTxHandler(
anteHandler, err := ante.NewAnteHandler( middleware.TxHandlerOptions{
ante.HandlerOptions{ AccountKeeper: s.app.AccountKeeper,
AccountKeeper: suite.app.AccountKeeper, BankKeeper: s.app.BankKeeper,
BankKeeper: suite.app.BankKeeper, FeegrantKeeper: s.app.FeeGrantKeeper,
FeegrantKeeper: suite.app.FeeGrantKeeper, SignModeHandler: s.clientCtx.TxConfig.SignModeHandler(),
SignModeHandler: suite.clientCtx.TxConfig.SignModeHandler(),
SigGasConsumer: func(meter sdk.GasMeter, sig signing.SignatureV2, params types.Params) error { SigGasConsumer: func(meter sdk.GasMeter, sig signing.SignatureV2, params types.Params) error {
switch pubkey := sig.PubKey.(type) { switch pubkey := sig.PubKey.(type) {
case *ed25519.PubKey: case *ed25519.PubKey:
@ -1025,19 +1024,19 @@ func (suite *AnteTestSuite) TestCustomSignatureVerificationGasConsumer() {
}, },
}, },
) )
s.Require().NoError(err)
suite.Require().NoError(err) s.Require().NoError(err)
suite.anteHandler = anteHandler
// Same data for every test cases // Same data for every test cases
accounts := suite.CreateTestAccounts(1) accounts := s.createTestAccounts(ctx, 1)
feeAmount := testdata.NewTestFeeAmount() txBuilder.SetFeeAmount(testdata.NewTestFeeAmount())
gasLimit := testdata.NewTestGasLimit() txBuilder.SetGasLimit(testdata.NewTestGasLimit())
txBuilder.SetMsgs(testdata.NewTestMsg(accounts[0].acc.GetAddress()))
// Variable data per test case // Variable data per test case
var ( var (
accNums []uint64 accNums []uint64
msgs []sdk.Msg
privs []cryptotypes.PrivKey privs []cryptotypes.PrivKey
accSeqs []uint64 accSeqs []uint64
) )
@ -1046,7 +1045,6 @@ func (suite *AnteTestSuite) TestCustomSignatureVerificationGasConsumer() {
{ {
"verify that an secp256k1 account gets rejected", "verify that an secp256k1 account gets rejected",
func() { func() {
msgs = []sdk.Msg{testdata.NewTestMsg(accounts[0].acc.GetAddress())}
privs, accNums, accSeqs = []cryptotypes.PrivKey{accounts[0].priv}, []uint64{0}, []uint64{0} privs, accNums, accSeqs = []cryptotypes.PrivKey{accounts[0].priv}, []uint64{0}, []uint64{0}
}, },
false, false,
@ -1056,54 +1054,57 @@ func (suite *AnteTestSuite) TestCustomSignatureVerificationGasConsumer() {
} }
for _, tc := range testCases { for _, tc := range testCases {
suite.Run(fmt.Sprintf("Case %s", tc.desc), func() { s.Run(fmt.Sprintf("Case %s", tc.desc), func() {
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
tc.malleate() tc.malleate()
suite.RunTestCase(privs, msgs, feeAmount, gasLimit, accNums, accSeqs, suite.ctx.ChainID(), tc) tx, txBytes, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID())
s.Require().NoError(err)
_, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), tx, abci.RequestDeliverTx{Tx: txBytes})
s.Require().Error(err)
s.Require().True(errors.Is(err, tc.expErr))
}) })
} }
} }
func (suite *AnteTestSuite) TestAnteHandlerReCheck() { func (s *MWTestSuite) TestTxHandlerReCheck() {
suite.SetupTest(false) // setup ctx := s.SetupTest(false) // setup
// Set recheck=true // Set recheck=true
suite.ctx = suite.ctx.WithIsReCheckTx(true) ctx = ctx.WithIsReCheckTx(true)
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
// Same data for every test cases // Same data for every test cases
accounts := suite.CreateTestAccounts(1) accounts := s.createTestAccounts(ctx, 1)
feeAmount := testdata.NewTestFeeAmount() feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit() gasLimit := testdata.NewTestGasLimit()
suite.txBuilder.SetFeeAmount(feeAmount) txBuilder.SetFeeAmount(feeAmount)
suite.txBuilder.SetGasLimit(gasLimit) txBuilder.SetGasLimit(gasLimit)
msg := testdata.NewTestMsg(accounts[0].acc.GetAddress()) msg := testdata.NewTestMsg(accounts[0].acc.GetAddress())
msgs := []sdk.Msg{msg} msgs := []sdk.Msg{msg}
suite.Require().NoError(suite.txBuilder.SetMsgs(msgs...)) s.Require().NoError(txBuilder.SetMsgs(msgs...))
suite.txBuilder.SetMemo("thisisatestmemo") txBuilder.SetMemo("thisisatestmemo")
// test that operations skipped on recheck do not run // test that operations skipped on recheck do not run
privs, accNums, accSeqs := []cryptotypes.PrivKey{accounts[0].priv}, []uint64{0}, []uint64{0} privs, accNums, accSeqs := []cryptotypes.PrivKey{accounts[0].priv}, []uint64{0}, []uint64{0}
tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) tx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID())
suite.Require().NoError(err) s.Require().NoError(err)
// make signature array empty which would normally cause ValidateBasicDecorator and SigVerificationDecorator fail // make signature array empty which would normally cause ValidateBasicMiddleware and SigVerificationMiddleware fail
// since these decorators don't run on recheck, the tx should pass the antehandler // since these middlewares don't run on recheck, the tx should pass the middleware
txBuilder, err := suite.clientCtx.TxConfig.WrapTxBuilder(tx) txBuilder, err = s.clientCtx.TxConfig.WrapTxBuilder(tx)
suite.Require().NoError(err) s.Require().NoError(err)
suite.Require().NoError(txBuilder.SetSignatures()) s.Require().NoError(txBuilder.SetSignatures())
_, err = suite.anteHandler(suite.ctx, txBuilder.GetTx(), false) _, err = s.txHandler.CheckTx(sdk.WrapSDKContext(ctx), txBuilder.GetTx(), abci.RequestCheckTx{Type: abci.CheckTxType_Recheck})
suite.Require().Nil(err, "AnteHandler errored on recheck unexpectedly: %v", err) s.Require().Nil(err, "TxHandler errored on recheck unexpectedly: %v", err)
tx, err = suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) tx, _, err = s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID())
suite.Require().NoError(err) s.Require().NoError(err)
txBytes, err := json.Marshal(tx) txBytes, err := json.Marshal(tx)
suite.Require().Nil(err, "Error marshalling tx: %v", err) s.Require().Nil(err, "Error marshalling tx: %v", err)
suite.ctx = suite.ctx.WithTxBytes(txBytes) ctx = ctx.WithTxBytes(txBytes)
// require that state machine param-dependent checking is still run on recheck since parameters can change between check and recheck // require that state machine param-dependent checking is still run on recheck since parameters can change between check and recheck
testCases := []struct { testCases := []struct {
@ -1114,35 +1115,37 @@ func (suite *AnteTestSuite) TestAnteHandlerReCheck() {
{"txsize check", types.NewParams(types.DefaultMaxMemoCharacters, types.DefaultTxSigLimit, 10000000, types.DefaultSigVerifyCostED25519, types.DefaultSigVerifyCostSecp256k1)}, {"txsize check", types.NewParams(types.DefaultMaxMemoCharacters, types.DefaultTxSigLimit, 10000000, types.DefaultSigVerifyCostED25519, types.DefaultSigVerifyCostSecp256k1)},
{"sig verify cost check", types.NewParams(types.DefaultMaxMemoCharacters, types.DefaultTxSigLimit, types.DefaultTxSizeCostPerByte, types.DefaultSigVerifyCostED25519, 100000000)}, {"sig verify cost check", types.NewParams(types.DefaultMaxMemoCharacters, types.DefaultTxSigLimit, types.DefaultTxSizeCostPerByte, types.DefaultSigVerifyCostED25519, 100000000)},
} }
for _, tc := range testCases { for _, tc := range testCases {
// set testcase parameters // set testcase parameters
suite.app.AccountKeeper.SetParams(suite.ctx, tc.params) s.app.AccountKeeper.SetParams(ctx, tc.params)
_, err := suite.anteHandler(suite.ctx, tx, false) _, err = s.txHandler.CheckTx(sdk.WrapSDKContext(ctx), tx, abci.RequestCheckTx{Tx: txBytes, Type: abci.CheckTxType_Recheck})
suite.Require().NotNil(err, "tx does not fail on recheck with updated params in test case: %s", tc.name) s.Require().NotNil(err, "tx does not fail on recheck with updated params in test case: %s", tc.name)
// reset parameters to default values // reset parameters to default values
suite.app.AccountKeeper.SetParams(suite.ctx, types.DefaultParams()) s.app.AccountKeeper.SetParams(ctx, types.DefaultParams())
} }
// require that local mempool fee check is still run on recheck since validator may change minFee between check and recheck // require that local mempool fee check is still run on recheck since validator may change minFee between check and recheck
// create new minimum gas price so antehandler fails on recheck // create new minimum gas price so txhandler fails on recheck
suite.ctx = suite.ctx.WithMinGasPrices([]sdk.DecCoin{{ ctx = ctx.WithMinGasPrices([]sdk.DecCoin{{
Denom: "dnecoin", // fee does not have this denom Denom: "dnecoin", // fee does not have this denom
Amount: sdk.NewDec(5), Amount: sdk.NewDec(5),
}}) }})
_, err = suite.anteHandler(suite.ctx, tx, false) _, err = s.txHandler.CheckTx(sdk.WrapSDKContext(ctx), tx, abci.RequestCheckTx{})
suite.Require().NotNil(err, "antehandler on recheck did not fail when mingasPrice was changed")
s.Require().NotNil(err, "txhandler on recheck did not fail when mingasPrice was changed")
// reset min gasprice // reset min gasprice
suite.ctx = suite.ctx.WithMinGasPrices(sdk.DecCoins{}) ctx = ctx.WithMinGasPrices(sdk.DecCoins{})
// remove funds for account so antehandler fails on recheck // remove funds for account so txhandler fails on recheck
suite.app.AccountKeeper.SetAccount(suite.ctx, accounts[0].acc) s.app.AccountKeeper.SetAccount(ctx, accounts[0].acc)
balances := suite.app.BankKeeper.GetAllBalances(suite.ctx, accounts[0].acc.GetAddress()) balances := s.app.BankKeeper.GetAllBalances(ctx, accounts[0].acc.GetAddress())
err = suite.app.BankKeeper.SendCoinsFromAccountToModule(suite.ctx, accounts[0].acc.GetAddress(), minttypes.ModuleName, balances) err = s.app.BankKeeper.SendCoinsFromAccountToModule(ctx, accounts[0].acc.GetAddress(), minttypes.ModuleName, balances)
suite.Require().NoError(err) s.Require().NoError(err)
_, err = suite.anteHandler(suite.ctx, tx, false) _, err = s.txHandler.CheckTx(sdk.WrapSDKContext(ctx), tx, abci.RequestCheckTx{})
suite.Require().NotNil(err, "antehandler on recheck did not fail once feePayer no longer has sufficient funds") s.Require().NotNil(err, "txhandler on recheck did not fail once feePayer no longer has sufficient funds")
} }

View File

@ -1,22 +1,13 @@
package middleware_test package middleware_test
import ( import (
"os"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
abci "github.com/tendermint/tendermint/abci/types"
"github.com/tendermint/tendermint/libs/log"
tmproto "github.com/tendermint/tendermint/proto/tendermint/types"
dbm "github.com/tendermint/tm-db"
"github.com/cosmos/cosmos-sdk/baseapp"
"github.com/cosmos/cosmos-sdk/client/tx"
"github.com/cosmos/cosmos-sdk/simapp" "github.com/cosmos/cosmos-sdk/simapp"
"github.com/cosmos/cosmos-sdk/testutil/testdata" "github.com/cosmos/cosmos-sdk/testutil/testdata"
"github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/middleware" "github.com/cosmos/cosmos-sdk/x/auth/middleware"
authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing"
) )
func TestRegisterMsgService(t *testing.T) { func TestRegisterMsgService(t *testing.T) {
@ -62,63 +53,3 @@ func TestRegisterMsgServiceTwice(t *testing.T) {
) )
}) })
} }
func TestMsgService(t *testing.T) {
priv, _, _ := testdata.KeyTestPubAddr()
encCfg := simapp.MakeTestEncodingConfig()
testdata.RegisterInterfaces(encCfg.InterfaceRegistry)
db := dbm.NewMemDB()
app := baseapp.NewBaseApp("test", log.NewTMLogger(log.NewSyncWriter(os.Stdout)), db, encCfg.TxConfig.TxDecoder())
app.SetInterfaceRegistry(encCfg.InterfaceRegistry)
msr := middleware.NewMsgServiceRouter(encCfg.InterfaceRegistry)
txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{
MsgServiceRouter: msr,
})
require.NoError(t, err)
app.SetTxHandler(txHandler)
testdata.RegisterMsgServer(
msr,
testdata.MsgServerImpl{},
)
_ = app.BeginBlock(abci.RequestBeginBlock{Header: tmproto.Header{Height: 1}})
msg := testdata.MsgCreateDog{Dog: &testdata.Dog{Name: "Spot"}}
txBuilder := encCfg.TxConfig.NewTxBuilder()
txBuilder.SetFeeAmount(testdata.NewTestFeeAmount())
txBuilder.SetGasLimit(testdata.NewTestGasLimit())
err = txBuilder.SetMsgs(&msg)
require.NoError(t, err)
// First round: we gather all the signer infos. We use the "set empty
// signature" hack to do that.
sigV2 := signing.SignatureV2{
PubKey: priv.PubKey(),
Data: &signing.SingleSignatureData{
SignMode: encCfg.TxConfig.SignModeHandler().DefaultMode(),
Signature: nil,
},
Sequence: 0,
}
err = txBuilder.SetSignatures(sigV2)
require.NoError(t, err)
// Second round: all signer infos are set, so each signer can sign.
signerData := authsigning.SignerData{
ChainID: "test",
AccountNumber: 0,
Sequence: 0,
}
sigV2, err = tx.SignWithPrivKey(
encCfg.TxConfig.SignModeHandler().DefaultMode(), signerData,
txBuilder, priv, encCfg.TxConfig, 0)
require.NoError(t, err)
err = txBuilder.SetSignatures(sigV2)
require.NoError(t, err)
// Send the tx to the app
txBytes, err := encCfg.TxConfig.TxEncoder()(txBuilder.GetTx())
require.NoError(t, err)
res := app.DeliverTx(abci.RequestDeliverTx{Tx: txBytes})
require.Equal(t, abci.CodeTypeOK, res.Code, "res=%+v", res)
}

View File

@ -83,6 +83,7 @@ func (txh runMsgsTxHandler) runMsgs(sdkCtx sdk.Context, msgs []sdk.Msg, txBytes
Data: make([]*sdk.MsgData, 0, len(msgs)), Data: make([]*sdk.MsgData, 0, len(msgs)),
} }
// NOTE: GasWanted is determined by the Gas TxHandler and GasUsed by the GasMeter.
for i, msg := range msgs { for i, msg := range msgs {
var ( var (
msgResult *sdk.Result msgResult *sdk.Result

View File

@ -0,0 +1,36 @@
package middleware_test
import (
"github.com/tendermint/tendermint/abci/types"
cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/auth/middleware"
)
func (s *MWTestSuite) TestRunMsgs() {
ctx := s.SetupTest(true) // setup
msr := middleware.NewMsgServiceRouter(s.clientCtx.InterfaceRegistry)
testdata.RegisterMsgServer(msr, testdata.MsgServerImpl{})
txHandler := middleware.NewRunMsgsTxHandler(msr, nil)
priv, _, _ := testdata.KeyTestPubAddr()
txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
txBuilder.SetMsgs(&testdata.MsgCreateDog{Dog: &testdata.Dog{Name: "Spot"}})
privs, accNums, accSeqs := []cryptotypes.PrivKey{priv}, []uint64{0}, []uint64{0}
tx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID())
s.Require().NoError(err)
txBytes, err := s.clientCtx.TxConfig.TxEncoder()(tx)
s.Require().NoError(err)
res, err := txHandler.DeliverTx(sdk.WrapSDKContext(ctx), tx, types.RequestDeliverTx{Tx: txBytes})
s.Require().NoError(err)
s.Require().NotEmpty(res.Data)
var txMsgData sdk.TxMsgData
err = s.clientCtx.Codec.Unmarshal(res.Data, &txMsgData)
s.Require().NoError(err)
s.Require().Len(txMsgData.Data, 1)
s.Require().Equal(sdk.MsgTypeURL(&testdata.MsgCreateDog{}), txMsgData.Data[0].MsgType)
}

View File

@ -1,9 +1,9 @@
package ante package middleware
import ( import (
"bytes" "bytes"
"context"
"encoding/base64" "encoding/base64"
"encoding/hex"
"fmt" "fmt"
"github.com/cosmos/cosmos-sdk/crypto/keys/ed25519" "github.com/cosmos/cosmos-sdk/crypto/keys/ed25519"
@ -14,10 +14,11 @@ import (
"github.com/cosmos/cosmos-sdk/crypto/types/multisig" "github.com/cosmos/cosmos-sdk/crypto/types/multisig"
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/types/tx"
"github.com/cosmos/cosmos-sdk/types/tx/signing" "github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx"
authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing"
"github.com/cosmos/cosmos-sdk/x/auth/types" "github.com/cosmos/cosmos-sdk/x/auth/types"
abci "github.com/tendermint/tendermint/abci/types"
) )
var ( var (
@ -25,44 +26,42 @@ var (
key = make([]byte, secp256k1.PubKeySize) key = make([]byte, secp256k1.PubKeySize)
simSecp256k1Pubkey = &secp256k1.PubKey{Key: key} simSecp256k1Pubkey = &secp256k1.PubKey{Key: key}
simSecp256k1Sig [64]byte simSecp256k1Sig [64]byte
_ authsigning.SigVerifiableTx = (*legacytx.StdTx)(nil) // assert StdTx implements SigVerifiableTx
) )
func init() {
// This decodes a valid hex string into a sepc256k1Pubkey for use in transaction simulation
bz, _ := hex.DecodeString("035AD6810A47F073553FF30D2FCC7E0D3B1C0B74B61A1AAA2582344037151E143A")
copy(key, bz)
simSecp256k1Pubkey.Key = key
}
// SignatureVerificationGasConsumer is the type of function that is used to both // SignatureVerificationGasConsumer is the type of function that is used to both
// consume gas when verifying signatures and also to accept or reject different types of pubkeys // consume gas when verifying signatures and also to accept or reject different types of pubkeys
// This is where apps can define their own PubKey // This is where apps can define their own PubKey
type SignatureVerificationGasConsumer = func(meter sdk.GasMeter, sig signing.SignatureV2, params types.Params) error type SignatureVerificationGasConsumer = func(meter sdk.GasMeter, sig signing.SignatureV2, params types.Params) error
// SetPubKeyDecorator sets PubKeys in context for any signer which does not already have pubkey set var _ tx.Handler = setPubKeyTxHandler{}
// PubKeys must be set in context for all signers before any other sigverify decorators run
// CONTRACT: Tx must implement SigVerifiableTx interface type setPubKeyTxHandler struct {
type SetPubKeyDecorator struct {
ak AccountKeeper ak AccountKeeper
next tx.Handler
} }
func NewSetPubKeyDecorator(ak AccountKeeper) SetPubKeyDecorator { // SetPubKeyMiddleware sets PubKeys in context for any signer which does not already have pubkey set
return SetPubKeyDecorator{ // PubKeys must be set in context for all signers before any other sigverify middlewares run
// CONTRACT: Tx must implement SigVerifiableTx interface
func SetPubKeyMiddleware(ak AccountKeeper) tx.Middleware {
return func(txh tx.Handler) tx.Handler {
return setPubKeyTxHandler{
ak: ak, ak: ak,
next: txh,
}
} }
} }
func (spkd SetPubKeyDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) { func (spkm setPubKeyTxHandler) setPubKey(ctx context.Context, tx sdk.Tx, simulate bool) error {
sdkCtx := sdk.UnwrapSDKContext(ctx)
sigTx, ok := tx.(authsigning.SigVerifiableTx) sigTx, ok := tx.(authsigning.SigVerifiableTx)
if !ok { if !ok {
return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid tx type") return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid tx type")
} }
pubkeys, err := sigTx.GetPubKeys() pubkeys, err := sigTx.GetPubKeys()
if err != nil { if err != nil {
return ctx, err return err
} }
signers := sigTx.GetSigners() signers := sigTx.GetSigners()
@ -76,13 +75,13 @@ func (spkd SetPubKeyDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate b
} }
// Only make check if simulate=false // Only make check if simulate=false
if !simulate && !bytes.Equal(pk.Address(), signers[i]) { if !simulate && !bytes.Equal(pk.Address(), signers[i]) {
return ctx, sdkerrors.Wrapf(sdkerrors.ErrInvalidPubKey, return sdkerrors.Wrapf(sdkerrors.ErrInvalidPubKey,
"pubKey does not match signer address %s with signer index: %d", signers[i], i) "pubKey does not match signer address %s with signer index: %d", signers[i], i)
} }
acc, err := GetSignerAcc(ctx, spkd.ak, signers[i]) acc, err := GetSignerAcc(sdkCtx, spkm.ak, signers[i])
if err != nil { if err != nil {
return ctx, err return err
} }
// account already has pubkey set,no need to reset // account already has pubkey set,no need to reset
if acc.GetPubKey() != nil { if acc.GetPubKey() != nil {
@ -90,9 +89,9 @@ func (spkd SetPubKeyDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate b
} }
err = acc.SetPubKey(pk) err = acc.SetPubKey(pk)
if err != nil { if err != nil {
return ctx, sdkerrors.Wrap(sdkerrors.ErrInvalidPubKey, err.Error()) return sdkerrors.Wrap(sdkerrors.ErrInvalidPubKey, err.Error())
} }
spkd.ak.SetAccount(ctx, acc) spkm.ak.SetAccount(sdkCtx, acc)
} }
// Also emit the following events, so that txs can be indexed by these // Also emit the following events, so that txs can be indexed by these
@ -101,7 +100,7 @@ func (spkd SetPubKeyDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate b
// - concat(address,"/",sequence) (via `tx.acc_seq='cosmos1abc...def/42'`). // - concat(address,"/",sequence) (via `tx.acc_seq='cosmos1abc...def/42'`).
sigs, err := sigTx.GetSignaturesV2() sigs, err := sigTx.GetSignaturesV2()
if err != nil { if err != nil {
return ctx, err return err
} }
var events sdk.Events var events sdk.Events
@ -112,7 +111,7 @@ func (spkd SetPubKeyDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate b
sigBzs, err := signatureDataToBz(sig.Data) sigBzs, err := signatureDataToBz(sig.Data)
if err != nil { if err != nil {
return ctx, err return err
} }
for _, sigBz := range sigBzs { for _, sigBz := range sigBzs {
events = append(events, sdk.NewEvent(sdk.EventTypeTx, events = append(events, sdk.NewEvent(sdk.EventTypeTx,
@ -121,264 +120,106 @@ func (spkd SetPubKeyDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate b
} }
} }
ctx.EventManager().EmitEvents(events) sdkCtx.EventManager().EmitEvents(events)
return next(ctx, tx, simulate) return nil
} }
// Consume parameter-defined amount of gas for each signature according to the passed-in SignatureVerificationGasConsumer function // CheckTx implements tx.Handler.CheckTx.
// before calling the next AnteHandler func (spkm setPubKeyTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) {
// CONTRACT: Pubkeys are set in context for all signers before this decorator runs if err := spkm.setPubKey(ctx, tx, false); err != nil {
return abci.ResponseCheckTx{}, err
}
return spkm.next.CheckTx(ctx, tx, req)
}
// DeliverTx implements tx.Handler.DeliverTx.
func (spkm setPubKeyTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) {
if err := spkm.setPubKey(ctx, tx, false); err != nil {
return abci.ResponseDeliverTx{}, err
}
return spkm.next.DeliverTx(ctx, tx, req)
}
// SimulateTx implements tx.Handler.SimulateTx.
func (spkm setPubKeyTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) {
if err := spkm.setPubKey(ctx, sdkTx, true); err != nil {
return tx.ResponseSimulateTx{}, err
}
return spkm.next.SimulateTx(ctx, sdkTx, req)
}
var _ tx.Handler = validateSigCountTxHandler{}
type validateSigCountTxHandler struct {
ak AccountKeeper
next tx.Handler
}
// ValidateSigCountMiddleware takes in Params and returns errors if there are too many signatures in the tx for the given params
// otherwise it calls next middleware
// Use this middleware to set parameterized limit on number of signatures in tx
// CONTRACT: Tx must implement SigVerifiableTx interface // CONTRACT: Tx must implement SigVerifiableTx interface
type SigGasConsumeDecorator struct { func ValidateSigCountMiddleware(ak AccountKeeper) tx.Middleware {
ak AccountKeeper return func(txh tx.Handler) tx.Handler {
sigGasConsumer SignatureVerificationGasConsumer return validateSigCountTxHandler{
}
func NewSigGasConsumeDecorator(ak AccountKeeper, sigGasConsumer SignatureVerificationGasConsumer) SigGasConsumeDecorator {
return SigGasConsumeDecorator{
ak: ak, ak: ak,
sigGasConsumer: sigGasConsumer, next: txh,
}
} }
} }
func (sgcd SigGasConsumeDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (newCtx sdk.Context, err error) { func (vscd validateSigCountTxHandler) checkSigCount(ctx context.Context, tx sdk.Tx) error {
sdkCtx := sdk.UnwrapSDKContext(ctx)
sigTx, ok := tx.(authsigning.SigVerifiableTx) sigTx, ok := tx.(authsigning.SigVerifiableTx)
if !ok { if !ok {
return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type") return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "Tx must be a sigTx")
} }
params := sgcd.ak.GetParams(ctx) params := vscd.ak.GetParams(sdkCtx)
sigs, err := sigTx.GetSignaturesV2()
if err != nil {
return ctx, err
}
// stdSigs contains the sequence number, account number, and signatures.
// When simulating, this would just be a 0-length slice.
signerAddrs := sigTx.GetSigners()
for i, sig := range sigs {
signerAcc, err := GetSignerAcc(ctx, sgcd.ak, signerAddrs[i])
if err != nil {
return ctx, err
}
pubKey := signerAcc.GetPubKey()
// In simulate mode the transaction comes with no signatures, thus if the
// account's pubkey is nil, both signature verification and gasKVStore.Set()
// shall consume the largest amount, i.e. it takes more gas to verify
// secp256k1 keys than ed25519 ones.
if simulate && pubKey == nil {
pubKey = simSecp256k1Pubkey
}
// make a SignatureV2 with PubKey filled in from above
sig = signing.SignatureV2{
PubKey: pubKey,
Data: sig.Data,
Sequence: sig.Sequence,
}
err = sgcd.sigGasConsumer(ctx.GasMeter(), sig, params)
if err != nil {
return ctx, err
}
}
return next(ctx, tx, simulate)
}
// Verify all signatures for a tx and return an error if any are invalid. Note,
// the SigVerificationDecorator decorator will not get executed on ReCheck.
//
// CONTRACT: Pubkeys are set in context for all signers before this decorator runs
// CONTRACT: Tx must implement SigVerifiableTx interface
type SigVerificationDecorator struct {
ak AccountKeeper
signModeHandler authsigning.SignModeHandler
}
func NewSigVerificationDecorator(ak AccountKeeper, signModeHandler authsigning.SignModeHandler) SigVerificationDecorator {
return SigVerificationDecorator{
ak: ak,
signModeHandler: signModeHandler,
}
}
// OnlyLegacyAminoSigners checks SignatureData to see if all
// signers are using SIGN_MODE_LEGACY_AMINO_JSON. If this is the case
// then the corresponding SignatureV2 struct will not have account sequence
// explicitly set, and we should skip the explicit verification of sig.Sequence
// in the SigVerificationDecorator's AnteHandler function.
func OnlyLegacyAminoSigners(sigData signing.SignatureData) bool {
switch v := sigData.(type) {
case *signing.SingleSignatureData:
return v.SignMode == signing.SignMode_SIGN_MODE_LEGACY_AMINO_JSON
case *signing.MultiSignatureData:
for _, s := range v.Signatures {
if !OnlyLegacyAminoSigners(s) {
return false
}
}
return true
default:
return false
}
}
func (svd SigVerificationDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (newCtx sdk.Context, err error) {
// no need to verify signatures on recheck tx
if ctx.IsReCheckTx() {
return next(ctx, tx, simulate)
}
sigTx, ok := tx.(authsigning.SigVerifiableTx)
if !ok {
return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type")
}
// stdSigs contains the sequence number, account number, and signatures.
// When simulating, this would just be a 0-length slice.
sigs, err := sigTx.GetSignaturesV2()
if err != nil {
return ctx, err
}
signerAddrs := sigTx.GetSigners()
// check that signer length and signature length are the same
if len(sigs) != len(signerAddrs) {
return ctx, sdkerrors.Wrapf(sdkerrors.ErrUnauthorized, "invalid number of signer; expected: %d, got %d", len(signerAddrs), len(sigs))
}
for i, sig := range sigs {
acc, err := GetSignerAcc(ctx, svd.ak, signerAddrs[i])
if err != nil {
return ctx, err
}
// retrieve pubkey
pubKey := acc.GetPubKey()
if !simulate && pubKey == nil {
return ctx, sdkerrors.Wrap(sdkerrors.ErrInvalidPubKey, "pubkey on account is not set")
}
// Check account sequence number.
if sig.Sequence != acc.GetSequence() {
return ctx, sdkerrors.Wrapf(
sdkerrors.ErrWrongSequence,
"account sequence mismatch, expected %d, got %d", acc.GetSequence(), sig.Sequence,
)
}
// retrieve signer data
genesis := ctx.BlockHeight() == 0
chainID := ctx.ChainID()
var accNum uint64
if !genesis {
accNum = acc.GetAccountNumber()
}
signerData := authsigning.SignerData{
ChainID: chainID,
AccountNumber: accNum,
Sequence: acc.GetSequence(),
}
if !simulate {
err := authsigning.VerifySignature(pubKey, signerData, sig.Data, svd.signModeHandler, tx)
if err != nil {
var errMsg string
if OnlyLegacyAminoSigners(sig.Data) {
// If all signers are using SIGN_MODE_LEGACY_AMINO, we rely on VerifySignature to check account sequence number,
// and therefore communicate sequence number as a potential cause of error.
errMsg = fmt.Sprintf("signature verification failed; please verify account number (%d), sequence (%d) and chain-id (%s)", accNum, acc.GetSequence(), chainID)
} else {
errMsg = fmt.Sprintf("signature verification failed; please verify account number (%d) and chain-id (%s)", accNum, chainID)
}
return ctx, sdkerrors.Wrap(sdkerrors.ErrUnauthorized, errMsg)
}
}
}
return next(ctx, tx, simulate)
}
// IncrementSequenceDecorator handles incrementing sequences of all signers.
// Use the IncrementSequenceDecorator decorator to prevent replay attacks. Note,
// there is no need to execute IncrementSequenceDecorator on RecheckTX since
// CheckTx would already bump the sequence number.
//
// NOTE: Since CheckTx and DeliverTx state are managed separately, subsequent and
// sequential txs orginating from the same account cannot be handled correctly in
// a reliable way unless sequence numbers are managed and tracked manually by a
// client. It is recommended to instead use multiple messages in a tx.
type IncrementSequenceDecorator struct {
ak AccountKeeper
}
func NewIncrementSequenceDecorator(ak AccountKeeper) IncrementSequenceDecorator {
return IncrementSequenceDecorator{
ak: ak,
}
}
func (isd IncrementSequenceDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) {
sigTx, ok := tx.(authsigning.SigVerifiableTx)
if !ok {
return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type")
}
// increment sequence of all signers
for _, addr := range sigTx.GetSigners() {
acc := isd.ak.GetAccount(ctx, addr)
if err := acc.SetSequence(acc.GetSequence() + 1); err != nil {
panic(err)
}
isd.ak.SetAccount(ctx, acc)
}
return next(ctx, tx, simulate)
}
// ValidateSigCountDecorator takes in Params and returns errors if there are too many signatures in the tx for the given params
// otherwise it calls next AnteHandler
// Use this decorator to set parameterized limit on number of signatures in tx
// CONTRACT: Tx must implement SigVerifiableTx interface
type ValidateSigCountDecorator struct {
ak AccountKeeper
}
func NewValidateSigCountDecorator(ak AccountKeeper) ValidateSigCountDecorator {
return ValidateSigCountDecorator{
ak: ak,
}
}
func (vscd ValidateSigCountDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) {
sigTx, ok := tx.(authsigning.SigVerifiableTx)
if !ok {
return ctx, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "Tx must be a sigTx")
}
params := vscd.ak.GetParams(ctx)
pubKeys, err := sigTx.GetPubKeys() pubKeys, err := sigTx.GetPubKeys()
if err != nil { if err != nil {
return ctx, err return err
} }
sigCount := 0 sigCount := 0
for _, pk := range pubKeys { for _, pk := range pubKeys {
sigCount += CountSubKeys(pk) sigCount += CountSubKeys(pk)
if uint64(sigCount) > params.TxSigLimit { if uint64(sigCount) > params.TxSigLimit {
return ctx, sdkerrors.Wrapf(sdkerrors.ErrTooManySignatures, return sdkerrors.Wrapf(sdkerrors.ErrTooManySignatures,
"signatures: %d, limit: %d", sigCount, params.TxSigLimit) "signatures: %d, limit: %d", sigCount, params.TxSigLimit)
} }
} }
return nil
}
return next(ctx, tx, simulate) // CheckTx implements tx.Handler.CheckTx.
func (vscd validateSigCountTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) {
if err := vscd.checkSigCount(ctx, tx); err != nil {
return abci.ResponseCheckTx{}, err
}
return vscd.next.CheckTx(ctx, tx, req)
}
// DeliverTx implements tx.Handler.DeliverTx.
func (vscd validateSigCountTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) {
if err := vscd.checkSigCount(ctx, sdkTx); err != nil {
return tx.ResponseSimulateTx{}, err
}
return vscd.next.SimulateTx(ctx, sdkTx, req)
}
// SimulateTx implements tx.Handler.SimulateTx.
func (vscd validateSigCountTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) {
if err := vscd.checkSigCount(ctx, tx); err != nil {
return abci.ResponseDeliverTx{}, err
}
return vscd.next.DeliverTx(ctx, tx, req)
} }
// DefaultSigVerificationGasConsumer is the default implementation of SignatureVerificationGasConsumer. It consumes gas // DefaultSigVerificationGasConsumer is the default implementation of SignatureVerificationGasConsumer. It consumes gas
@ -445,6 +286,326 @@ func ConsumeMultisignatureVerificationGas(
return nil return nil
} }
var _ tx.Handler = sigGasConsumeTxHandler{}
type sigGasConsumeTxHandler struct {
ak AccountKeeper
sigGasConsumer SignatureVerificationGasConsumer
next tx.Handler
}
// SigGasConsumeMiddleware consumes parameter-defined amount of gas for each signature according to the passed-in SignatureVerificationGasConsumer function
// before calling the next middleware
// CONTRACT: Pubkeys are set in context for all signers before this middleware runs
// CONTRACT: Tx must implement SigVerifiableTx interface
func SigGasConsumeMiddleware(ak AccountKeeper, sigGasConsumer SignatureVerificationGasConsumer) tx.Middleware {
return func(h tx.Handler) tx.Handler {
return sigGasConsumeTxHandler{
ak: ak,
sigGasConsumer: sigGasConsumer,
next: h,
}
}
}
func (sgcm sigGasConsumeTxHandler) sigGasConsume(ctx context.Context, tx sdk.Tx, simulate bool) error {
sdkCtx := sdk.UnwrapSDKContext(ctx)
sigTx, ok := tx.(authsigning.SigVerifiableTx)
if !ok {
return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type")
}
params := sgcm.ak.GetParams(sdkCtx)
sigs, err := sigTx.GetSignaturesV2()
if err != nil {
return err
}
// stdSigs contains the sequence number, account number, and signatures.
// When simulating, this would just be a 0-length slice.
signerAddrs := sigTx.GetSigners()
for i, sig := range sigs {
signerAcc, err := GetSignerAcc(sdkCtx, sgcm.ak, signerAddrs[i])
if err != nil {
return err
}
pubKey := signerAcc.GetPubKey()
// In simulate mode the transaction comes with no signatures, thus if the
// account's pubkey is nil, both signature verification and gasKVStore.Set()
// shall consume the largest amount, i.e. it takes more gas to verify
// secp256k1 keys than ed25519 ones.
if simulate && pubKey == nil {
pubKey = simSecp256k1Pubkey
}
// make a SignatureV2 with PubKey filled in from above
sig = signing.SignatureV2{
PubKey: pubKey,
Data: sig.Data,
Sequence: sig.Sequence,
}
err = sgcm.sigGasConsumer(sdkCtx.GasMeter(), sig, params)
if err != nil {
return err
}
}
return nil
}
// CheckTx implements tx.Handler.CheckTx.
func (sgcm sigGasConsumeTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) {
if err := sgcm.sigGasConsume(ctx, tx, false); err != nil {
return abci.ResponseCheckTx{}, err
}
return sgcm.next.CheckTx(ctx, tx, req)
}
// DeliverTx implements tx.Handler.DeliverTx.
func (sgcm sigGasConsumeTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) {
if err := sgcm.sigGasConsume(ctx, tx, false); err != nil {
return abci.ResponseDeliverTx{}, err
}
return sgcm.next.DeliverTx(ctx, tx, req)
}
// SimulateTx implements tx.Handler.SimulateTx.
func (sgcm sigGasConsumeTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) {
if err := sgcm.sigGasConsume(ctx, sdkTx, true); err != nil {
return tx.ResponseSimulateTx{}, err
}
return sgcm.next.SimulateTx(ctx, sdkTx, req)
}
var _ tx.Handler = sigVerificationTxHandler{}
type sigVerificationTxHandler struct {
ak AccountKeeper
signModeHandler authsigning.SignModeHandler
next tx.Handler
}
// SigVerificationMiddleware verifies all signatures for a tx and return an error if any are invalid. Note,
// the sigVerificationTxHandler middleware will not get executed on ReCheck.
//
// CONTRACT: Pubkeys are set in context for all signers before this middleware runs
// CONTRACT: Tx must implement SigVerifiableTx interface
func SigVerificationMiddleware(ak AccountKeeper, signModeHandler authsigning.SignModeHandler) tx.Middleware {
return func(h tx.Handler) tx.Handler {
return sigVerificationTxHandler{
ak: ak,
signModeHandler: signModeHandler,
next: h,
}
}
}
// OnlyLegacyAminoSigners checks SignatureData to see if all
// signers are using SIGN_MODE_LEGACY_AMINO_JSON. If this is the case
// then the corresponding SignatureV2 struct will not have account sequence
// explicitly set, and we should skip the explicit verification of sig.Sequence
// in the SigVerificationMiddleware's middleware function.
func OnlyLegacyAminoSigners(sigData signing.SignatureData) bool {
switch v := sigData.(type) {
case *signing.SingleSignatureData:
return v.SignMode == signing.SignMode_SIGN_MODE_LEGACY_AMINO_JSON
case *signing.MultiSignatureData:
for _, s := range v.Signatures {
if !OnlyLegacyAminoSigners(s) {
return false
}
}
return true
default:
return false
}
}
func (svm sigVerificationTxHandler) sigVerify(ctx context.Context, tx sdk.Tx, isReCheckTx, simulate bool) error {
sdkCtx := sdk.UnwrapSDKContext(ctx)
// no need to verify signatures on recheck tx
if isReCheckTx {
return nil
}
sigTx, ok := tx.(authsigning.SigVerifiableTx)
if !ok {
return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type")
}
// stdSigs contains the sequence number, account number, and signatures.
// When simulating, this would just be a 0-length slice.
sigs, err := sigTx.GetSignaturesV2()
if err != nil {
return err
}
signerAddrs := sigTx.GetSigners()
// check that signer length and signature length are the same
if len(sigs) != len(signerAddrs) {
return sdkerrors.Wrapf(sdkerrors.ErrUnauthorized, "invalid number of signer; expected: %d, got %d", len(signerAddrs), len(sigs))
}
for i, sig := range sigs {
acc, err := GetSignerAcc(sdkCtx, svm.ak, signerAddrs[i])
if err != nil {
return err
}
// retrieve pubkey
pubKey := acc.GetPubKey()
if !simulate && pubKey == nil {
return sdkerrors.Wrap(sdkerrors.ErrInvalidPubKey, "pubkey on account is not set")
}
// Check account sequence number.
if sig.Sequence != acc.GetSequence() {
return sdkerrors.Wrapf(
sdkerrors.ErrWrongSequence,
"account sequence mismatch, expected %d, got %d", acc.GetSequence(), sig.Sequence,
)
}
// retrieve signer data
genesis := sdkCtx.BlockHeight() == 0
chainID := sdkCtx.ChainID()
var accNum uint64
if !genesis {
accNum = acc.GetAccountNumber()
}
signerData := authsigning.SignerData{
ChainID: chainID,
AccountNumber: accNum,
Sequence: acc.GetSequence(),
}
if !simulate {
err := authsigning.VerifySignature(pubKey, signerData, sig.Data, svm.signModeHandler, tx)
if err != nil {
var errMsg string
if OnlyLegacyAminoSigners(sig.Data) {
// If all signers are using SIGN_MODE_LEGACY_AMINO, we rely on VerifySignature to check account sequence number,
// and therefore communicate sequence number as a potential cause of error.
errMsg = fmt.Sprintf("signature verification failed; please verify account number (%d), sequence (%d) and chain-id (%s)", accNum, acc.GetSequence(), chainID)
} else {
errMsg = fmt.Sprintf("signature verification failed; please verify account number (%d) and chain-id (%s)", accNum, chainID)
}
return sdkerrors.Wrap(sdkerrors.ErrUnauthorized, errMsg)
}
}
}
return nil
}
// CheckTx implements tx.Handler.CheckTx.
func (svd sigVerificationTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) {
if err := svd.sigVerify(ctx, tx, req.Type == abci.CheckTxType_Recheck, false); err != nil {
return abci.ResponseCheckTx{}, err
}
return svd.next.CheckTx(ctx, tx, req)
}
// DeliverTx implements tx.Handler.DeliverTx.
func (svd sigVerificationTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) {
if err := svd.sigVerify(ctx, tx, false, false); err != nil {
return abci.ResponseDeliverTx{}, err
}
return svd.next.DeliverTx(ctx, tx, req)
}
// SimulateTx implements tx.Handler.SimulateTx.
func (svd sigVerificationTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) {
if err := svd.sigVerify(ctx, sdkTx, false, true); err != nil {
return tx.ResponseSimulateTx{}, err
}
return svd.next.SimulateTx(ctx, sdkTx, req)
}
var _ tx.Handler = incrementSequenceTxHandler{}
type incrementSequenceTxHandler struct {
ak AccountKeeper
next tx.Handler
}
// IncrementSequenceMiddleware handles incrementing sequences of all signers.
// Use the incrementSequenceTxHandler middleware to prevent replay attacks. Note,
// there is no need to execute incrementSequenceTxHandler on RecheckTX since
// CheckTx would already bump the sequence number.
//
// NOTE: Since CheckTx and DeliverTx state are managed separately, subsequent and
// sequential txs orginating from the same account cannot be handled correctly in
// a reliable way unless sequence numbers are managed and tracked manually by a
// client. It is recommended to instead use multiple messages in a tx.
func IncrementSequenceMiddleware(ak AccountKeeper) tx.Middleware {
return func(h tx.Handler) tx.Handler {
return incrementSequenceTxHandler{
ak: ak,
next: h,
}
}
}
func (isd incrementSequenceTxHandler) incrementSeq(ctx context.Context, tx sdk.Tx) error {
sdkCtx := sdk.UnwrapSDKContext(ctx)
sigTx, ok := tx.(authsigning.SigVerifiableTx)
if !ok {
return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type")
}
// increment sequence of all signers
for _, addr := range sigTx.GetSigners() {
acc := isd.ak.GetAccount(sdkCtx, addr)
if err := acc.SetSequence(acc.GetSequence() + 1); err != nil {
panic(err)
}
isd.ak.SetAccount(sdkCtx, acc)
}
return nil
}
// CheckTx implements tx.Handler.CheckTx.
func (isd incrementSequenceTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) {
if err := isd.incrementSeq(ctx, tx); err != nil {
return abci.ResponseCheckTx{}, err
}
return isd.next.CheckTx(ctx, tx, req)
}
// DeliverTx implements tx.Handler.DeliverTx.
func (isd incrementSequenceTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) {
if err := isd.incrementSeq(ctx, tx); err != nil {
return abci.ResponseDeliverTx{}, err
}
return isd.next.DeliverTx(ctx, tx, req)
}
// SimulateTx implements tx.Handler.SimulateTx.
func (isd incrementSequenceTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) {
if err := isd.incrementSeq(ctx, sdkTx); err != nil {
return tx.ResponseSimulateTx{}, err
}
return isd.next.SimulateTx(ctx, sdkTx, req)
}
// GetSignerAcc returns an account for a given address that is expected to sign // GetSignerAcc returns an account for a given address that is expected to sign
// a transaction. // a transaction.
func GetSignerAcc(ctx sdk.Context, ak AccountKeeper, addr sdk.AccAddress) (types.AccountI, error) { func GetSignerAcc(ctx sdk.Context, ak AccountKeeper, addr sdk.AccAddress) (types.AccountI, error) {

View File

@ -1,4 +1,4 @@
package ante_test package middleware_test
import ( import (
"testing" "testing"
@ -10,7 +10,7 @@ import (
"github.com/cosmos/cosmos-sdk/crypto/keys/secp256r1" "github.com/cosmos/cosmos-sdk/crypto/keys/secp256r1"
) )
// This benchmark is used to asses the ante.Secp256k1ToR1GasFactor value // This benchmark is used to asses the middleware.Secp256k1ToR1GasFactor value
func BenchmarkSig(b *testing.B) { func BenchmarkSig(b *testing.B) {
require := require.New(b) require := require.New(b)
msg := tmcrypto.CRandBytes(1000) msg := tmcrypto.CRandBytes(1000)

View File

@ -1,10 +1,10 @@
package ante_test package middleware_test
import ( import (
"fmt" "fmt"
"github.com/cosmos/cosmos-sdk/client" "github.com/cosmos/cosmos-sdk/client"
"github.com/cosmos/cosmos-sdk/codec" "github.com/cosmos/cosmos-sdk/codec/legacy"
"github.com/cosmos/cosmos-sdk/crypto/keys/ed25519" "github.com/cosmos/cosmos-sdk/crypto/keys/ed25519"
kmultisig "github.com/cosmos/cosmos-sdk/crypto/keys/multisig" kmultisig "github.com/cosmos/cosmos-sdk/crypto/keys/multisig"
"github.com/cosmos/cosmos-sdk/crypto/keys/secp256k1" "github.com/cosmos/cosmos-sdk/crypto/keys/secp256k1"
@ -14,16 +14,22 @@ import (
"github.com/cosmos/cosmos-sdk/simapp" "github.com/cosmos/cosmos-sdk/simapp"
"github.com/cosmos/cosmos-sdk/testutil/testdata" "github.com/cosmos/cosmos-sdk/testutil/testdata"
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/types/tx"
"github.com/cosmos/cosmos-sdk/types/tx/signing" "github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/ante" "github.com/cosmos/cosmos-sdk/x/auth/middleware"
"github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx" "github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx"
"github.com/cosmos/cosmos-sdk/x/auth/types" "github.com/cosmos/cosmos-sdk/x/auth/types"
abci "github.com/tendermint/tendermint/abci/types"
) )
func (suite *AnteTestSuite) TestSetPubKey() { func (s *MWTestSuite) TestSetPubKey() {
suite.SetupTest(true) // setup ctx := s.SetupTest(true) // setup
require := suite.Require() txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() require := s.Require()
txHandler := middleware.ComposeMiddlewares(
noopTxHandler{},
middleware.SetPubKeyMiddleware(s.app.AccountKeeper),
)
// keys and addresses // keys and addresses
priv1, pub1, addr1 := testdata.KeyTestPubAddr() priv1, pub1, addr1 := testdata.KeyTestPubAddr()
@ -36,35 +42,32 @@ func (suite *AnteTestSuite) TestSetPubKey() {
msgs := make([]sdk.Msg, len(addrs)) msgs := make([]sdk.Msg, len(addrs))
// set accounts and create msg for each address // set accounts and create msg for each address
for i, addr := range addrs { for i, addr := range addrs {
acc := suite.app.AccountKeeper.NewAccountWithAddress(suite.ctx, addr) acc := s.app.AccountKeeper.NewAccountWithAddress(ctx, addr)
require.NoError(acc.SetAccountNumber(uint64(i))) require.NoError(acc.SetAccountNumber(uint64(i)))
suite.app.AccountKeeper.SetAccount(suite.ctx, acc) s.app.AccountKeeper.SetAccount(ctx, acc)
msgs[i] = testdata.NewTestMsg(addr) msgs[i] = testdata.NewTestMsg(addr)
} }
require.NoError(suite.txBuilder.SetMsgs(msgs...)) require.NoError(txBuilder.SetMsgs(msgs...))
suite.txBuilder.SetFeeAmount(testdata.NewTestFeeAmount()) txBuilder.SetFeeAmount(testdata.NewTestFeeAmount())
suite.txBuilder.SetGasLimit(testdata.NewTestGasLimit()) txBuilder.SetGasLimit(testdata.NewTestGasLimit())
privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{0, 1, 2}, []uint64{0, 0, 0} privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{0, 1, 2}, []uint64{0, 0, 0}
tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) testTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID())
require.NoError(err) require.NoError(err)
spkd := ante.NewSetPubKeyDecorator(suite.app.AccountKeeper) _, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), testTx, abci.RequestDeliverTx{})
antehandler := sdk.ChainAnteDecorators(spkd)
ctx, err := antehandler(suite.ctx, tx, false)
require.NoError(err) require.NoError(err)
// Require that all accounts have pubkey set after Decorator runs // Require that all accounts have pubkey set after middleware runs
for i, addr := range addrs { for i, addr := range addrs {
pk, err := suite.app.AccountKeeper.GetPubKey(ctx, addr) pk, err := s.app.AccountKeeper.GetPubKey(ctx, addr)
require.NoError(err, "Error on retrieving pubkey from account") require.NoError(err, "Error on retrieving pubkey from account")
require.True(pubs[i].Equals(pk), require.True(pubs[i].Equals(pk),
"Wrong Pubkey retrieved from AccountKeeper, idx=%d\nexpected=%s\n got=%s", i, pubs[i], pk) "Wrong Pubkey retrieved from AccountKeeper, idx=%d\nexpected=%s\n got=%s", i, pubs[i], pk)
} }
} }
func (suite *AnteTestSuite) TestConsumeSignatureVerificationGas() { func (s *MWTestSuite) TestConsumeSignatureVerificationGas() {
params := types.DefaultParams() params := types.DefaultParams()
msg := []byte{1, 2, 3, 4} msg := []byte{1, 2, 3, 4}
cdc := simapp.MakeTestEncodingConfig().Amino cdc := simapp.MakeTestEncodingConfig().Amino
@ -78,9 +81,9 @@ func (suite *AnteTestSuite) TestConsumeSignatureVerificationGas() {
for i := 0; i < len(pkSet1); i++ { for i := 0; i < len(pkSet1); i++ {
stdSig := legacytx.StdSignature{PubKey: pkSet1[i], Signature: sigSet1[i]} stdSig := legacytx.StdSignature{PubKey: pkSet1[i], Signature: sigSet1[i]}
sigV2, err := legacytx.StdSignatureToSignatureV2(cdc, stdSig) sigV2, err := legacytx.StdSignatureToSignatureV2(cdc, stdSig)
suite.Require().NoError(err) s.Require().NoError(err)
err = multisig.AddSignatureV2(multisignature1, sigV2, pkSet1) err = multisig.AddSignatureV2(multisignature1, sigV2, pkSet1)
suite.Require().NoError(err) s.Require().NoError(err)
} }
type args struct { type args struct {
@ -107,23 +110,30 @@ func (suite *AnteTestSuite) TestConsumeSignatureVerificationGas() {
Data: tt.args.sig, Data: tt.args.sig,
Sequence: 0, // Arbitrary account sequence Sequence: 0, // Arbitrary account sequence
} }
err := ante.DefaultSigVerificationGasConsumer(tt.args.meter, sigV2, tt.args.params) err := middleware.DefaultSigVerificationGasConsumer(tt.args.meter, sigV2, tt.args.params)
if tt.shouldErr { if tt.shouldErr {
suite.Require().NotNil(err) s.Require().NotNil(err)
} else { } else {
suite.Require().Nil(err) s.Require().Nil(err)
suite.Require().Equal(tt.gasConsumed, tt.args.meter.GasConsumed(), fmt.Sprintf("%d != %d", tt.gasConsumed, tt.args.meter.GasConsumed())) s.Require().Equal(tt.gasConsumed, tt.args.meter.GasConsumed(), fmt.Sprintf("%d != %d", tt.gasConsumed, tt.args.meter.GasConsumed()))
} }
} }
} }
func (suite *AnteTestSuite) TestSigVerification() { func (s *MWTestSuite) TestSigVerification() {
suite.SetupTest(true) // setup ctx := s.SetupTest(true) // setup
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
// make block height non-zero to ensure account numbers part of signBytes // make block height non-zero to ensure account numbers part of signBytes
suite.ctx = suite.ctx.WithBlockHeight(1) ctx = ctx.WithBlockHeight(1)
txHandler := middleware.ComposeMiddlewares(
noopTxHandler{},
middleware.SetPubKeyMiddleware(s.app.AccountKeeper),
middleware.SigVerificationMiddleware(
s.app.AccountKeeper,
s.clientCtx.TxConfig.SignModeHandler(),
),
)
// keys and addresses // keys and addresses
priv1, _, addr1 := testdata.KeyTestPubAddr() priv1, _, addr1 := testdata.KeyTestPubAddr()
@ -135,19 +145,15 @@ func (suite *AnteTestSuite) TestSigVerification() {
msgs := make([]sdk.Msg, len(addrs)) msgs := make([]sdk.Msg, len(addrs))
// set accounts and create msg for each address // set accounts and create msg for each address
for i, addr := range addrs { for i, addr := range addrs {
acc := suite.app.AccountKeeper.NewAccountWithAddress(suite.ctx, addr) acc := s.app.AccountKeeper.NewAccountWithAddress(ctx, addr)
suite.Require().NoError(acc.SetAccountNumber(uint64(i))) s.Require().NoError(acc.SetAccountNumber(uint64(i)))
suite.app.AccountKeeper.SetAccount(suite.ctx, acc) s.app.AccountKeeper.SetAccount(ctx, acc)
msgs[i] = testdata.NewTestMsg(addr) msgs[i] = testdata.NewTestMsg(addr)
} }
feeAmount := testdata.NewTestFeeAmount() feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit() gasLimit := testdata.NewTestGasLimit()
spkd := ante.NewSetPubKeyDecorator(suite.app.AccountKeeper)
svd := ante.NewSigVerificationDecorator(suite.app.AccountKeeper, suite.clientCtx.TxConfig.SignModeHandler())
antehandler := sdk.ChainAnteDecorators(spkd, svd)
type testCase struct { type testCase struct {
name string name string
privs []cryptotypes.PrivKey privs []cryptotypes.PrivKey
@ -166,21 +172,25 @@ func (suite *AnteTestSuite) TestSigVerification() {
{"no err on recheck", []cryptotypes.PrivKey{}, []uint64{}, []uint64{}, true, false}, {"no err on recheck", []cryptotypes.PrivKey{}, []uint64{}, []uint64{}, true, false},
} }
for i, tc := range testCases { for i, tc := range testCases {
suite.ctx = suite.ctx.WithIsReCheckTx(tc.recheck) ctx = ctx.WithIsReCheckTx(tc.recheck)
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() // Create new txBuilder for each test txBuilder := s.clientCtx.TxConfig.NewTxBuilder() // Create new txBuilder for each test
suite.Require().NoError(suite.txBuilder.SetMsgs(msgs...)) s.Require().NoError(txBuilder.SetMsgs(msgs...))
suite.txBuilder.SetFeeAmount(feeAmount) txBuilder.SetFeeAmount(feeAmount)
suite.txBuilder.SetGasLimit(gasLimit) txBuilder.SetGasLimit(gasLimit)
tx, err := suite.CreateTestTx(tc.privs, tc.accNums, tc.accSeqs, suite.ctx.ChainID()) testTx, _, err := s.createTestTx(txBuilder, tc.privs, tc.accNums, tc.accSeqs, ctx.ChainID())
suite.Require().NoError(err) s.Require().NoError(err)
_, err = antehandler(suite.ctx, tx, false) if tc.recheck {
if tc.shouldErr { _, err = txHandler.CheckTx(sdk.WrapSDKContext(ctx), testTx, abci.RequestCheckTx{Type: abci.CheckTxType_Recheck})
suite.Require().NotNil(err, "TestCase %d: %s did not error as expected", i, tc.name)
} else { } else {
suite.Require().Nil(err, "TestCase %d: %s errored unexpectedly. Err: %v", i, tc.name, err) _, err = txHandler.CheckTx(sdk.WrapSDKContext(ctx), testTx, abci.RequestCheckTx{})
}
if tc.shouldErr {
s.Require().NotNil(err, "TestCase %d: %s did not error as expected", i, tc.name)
} else {
s.Require().Nil(err, "TestCase %d: %s errored unexpectedly. Err: %v", i, tc.name, err)
} }
} }
} }
@ -191,35 +201,23 @@ func (suite *AnteTestSuite) TestSigVerification() {
// this, since it'll be handled by the test matrix. // this, since it'll be handled by the test matrix.
// In the meantime, we want to make double-sure amino compatibility works. // In the meantime, we want to make double-sure amino compatibility works.
// ref: https://github.com/cosmos/cosmos-sdk/issues/7229 // ref: https://github.com/cosmos/cosmos-sdk/issues/7229
func (suite *AnteTestSuite) TestSigVerification_ExplicitAmino() { func (s *MWTestSuite) TestSigVerification_ExplicitAmino() {
suite.app, suite.ctx = createTestApp(suite.T(), true) ctx := s.SetupTest(true)
suite.ctx = suite.ctx.WithBlockHeight(1) ctx = ctx.WithBlockHeight(1)
// Set up TxConfig. // Set up TxConfig.
aminoCdc := codec.NewLegacyAmino() aminoCdc := legacy.Cdc
aminoCdc.RegisterInterface((*sdk.Msg)(nil), nil)
aminoCdc.RegisterConcrete(&testdata.TestMsg{}, "testdata.TestMsg", nil)
// We're using TestMsg amino encoding in some tests, so register it here. // We're using TestMsg amino encoding in some tests, so register it here.
txConfig := legacytx.StdTxConfig{Cdc: aminoCdc} txConfig := legacytx.StdTxConfig{Cdc: aminoCdc}
suite.clientCtx = client.Context{}. s.clientCtx = client.Context{}.
WithTxConfig(txConfig) WithTxConfig(txConfig)
anteHandler, err := ante.NewAnteHandler(
ante.HandlerOptions{
AccountKeeper: suite.app.AccountKeeper,
BankKeeper: suite.app.BankKeeper,
FeegrantKeeper: suite.app.FeeGrantKeeper,
SignModeHandler: txConfig.SignModeHandler(),
SigGasConsumer: ante.DefaultSigVerificationGasConsumer,
},
)
suite.Require().NoError(err)
suite.anteHandler = anteHandler
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder()
// make block height non-zero to ensure account numbers part of signBytes // make block height non-zero to ensure account numbers part of signBytes
suite.ctx = suite.ctx.WithBlockHeight(1) ctx = ctx.WithBlockHeight(1)
// keys and addresses // keys and addresses
priv1, _, addr1 := testdata.KeyTestPubAddr() priv1, _, addr1 := testdata.KeyTestPubAddr()
@ -231,18 +229,23 @@ func (suite *AnteTestSuite) TestSigVerification_ExplicitAmino() {
msgs := make([]sdk.Msg, len(addrs)) msgs := make([]sdk.Msg, len(addrs))
// set accounts and create msg for each address // set accounts and create msg for each address
for i, addr := range addrs { for i, addr := range addrs {
acc := suite.app.AccountKeeper.NewAccountWithAddress(suite.ctx, addr) acc := s.app.AccountKeeper.NewAccountWithAddress(ctx, addr)
suite.Require().NoError(acc.SetAccountNumber(uint64(i))) s.Require().NoError(acc.SetAccountNumber(uint64(i)))
suite.app.AccountKeeper.SetAccount(suite.ctx, acc) s.app.AccountKeeper.SetAccount(ctx, acc)
msgs[i] = testdata.NewTestMsg(addr) msgs[i] = testdata.NewTestMsg(addr)
} }
feeAmount := testdata.NewTestFeeAmount() feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit() gasLimit := testdata.NewTestGasLimit()
spkd := ante.NewSetPubKeyDecorator(suite.app.AccountKeeper) txHandler := middleware.ComposeMiddlewares(
svd := ante.NewSigVerificationDecorator(suite.app.AccountKeeper, suite.clientCtx.TxConfig.SignModeHandler()) noopTxHandler{},
antehandler := sdk.ChainAnteDecorators(spkd, svd) middleware.SetPubKeyMiddleware(s.app.AccountKeeper),
middleware.SigVerificationMiddleware(
s.app.AccountKeeper,
s.clientCtx.TxConfig.SignModeHandler(),
),
)
type testCase struct { type testCase struct {
name string name string
@ -252,6 +255,7 @@ func (suite *AnteTestSuite) TestSigVerification_ExplicitAmino() {
recheck bool recheck bool
shouldErr bool shouldErr bool
} }
testCases := []testCase{ testCases := []testCase{
{"no signers", []cryptotypes.PrivKey{}, []uint64{}, []uint64{}, false, true}, {"no signers", []cryptotypes.PrivKey{}, []uint64{}, []uint64{}, false, true},
{"not enough signers", []cryptotypes.PrivKey{priv1, priv2}, []uint64{0, 1}, []uint64{0, 0}, false, true}, {"not enough signers", []cryptotypes.PrivKey{priv1, priv2}, []uint64{0, 1}, []uint64{0, 0}, false, true},
@ -261,27 +265,32 @@ func (suite *AnteTestSuite) TestSigVerification_ExplicitAmino() {
{"valid tx", []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{0, 1, 2}, []uint64{0, 0, 0}, false, false}, {"valid tx", []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{0, 1, 2}, []uint64{0, 0, 0}, false, false},
{"no err on recheck", []cryptotypes.PrivKey{}, []uint64{}, []uint64{}, true, false}, {"no err on recheck", []cryptotypes.PrivKey{}, []uint64{}, []uint64{}, true, false},
} }
for i, tc := range testCases { for i, tc := range testCases {
suite.ctx = suite.ctx.WithIsReCheckTx(tc.recheck) ctx = ctx.WithIsReCheckTx(tc.recheck)
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() // Create new txBuilder for each test txBuilder := s.clientCtx.TxConfig.NewTxBuilder() // Create new txBuilder for each test
suite.Require().NoError(suite.txBuilder.SetMsgs(msgs...)) s.Require().NoError(txBuilder.SetMsgs(msgs...))
suite.txBuilder.SetFeeAmount(feeAmount) txBuilder.SetFeeAmount(feeAmount)
suite.txBuilder.SetGasLimit(gasLimit) txBuilder.SetGasLimit(gasLimit)
tx, err := suite.CreateTestTx(tc.privs, tc.accNums, tc.accSeqs, suite.ctx.ChainID()) testTx, _, err := s.createTestTx(txBuilder, tc.privs, tc.accNums, tc.accSeqs, ctx.ChainID())
suite.Require().NoError(err) s.Require().NoError(err)
_, err = antehandler(suite.ctx, tx, false) if tc.recheck {
if tc.shouldErr { _, err = txHandler.CheckTx(sdk.WrapSDKContext(ctx), testTx, abci.RequestCheckTx{Type: abci.CheckTxType_Recheck})
suite.Require().NotNil(err, "TestCase %d: %s did not error as expected", i, tc.name)
} else { } else {
suite.Require().Nil(err, "TestCase %d: %s errored unexpectedly. Err: %v", i, tc.name, err) _, err = txHandler.CheckTx(sdk.WrapSDKContext(ctx), testTx, abci.RequestCheckTx{})
}
if tc.shouldErr {
s.Require().NotNil(err, "TestCase %d: %s did not error as expected", i, tc.name)
} else {
s.Require().Nil(err, "TestCase %d: %s errored unexpectedly. Err: %v", i, tc.name, err)
} }
} }
} }
func (suite *AnteTestSuite) TestSigIntegration() { func (s *MWTestSuite) TestSigIntegration() {
// generate private keys // generate private keys
privs := []cryptotypes.PrivKey{ privs := []cryptotypes.PrivKey{
secp256k1.GenPrivKey(), secp256k1.GenPrivKey(),
@ -291,23 +300,23 @@ func (suite *AnteTestSuite) TestSigIntegration() {
params := types.DefaultParams() params := types.DefaultParams()
initialSigCost := params.SigVerifyCostSecp256k1 initialSigCost := params.SigVerifyCostSecp256k1
initialCost, err := suite.runSigDecorators(params, false, privs...) initialCost, err := s.runSigMiddlewares(params, false, privs...)
suite.Require().Nil(err) s.Require().Nil(err)
params.SigVerifyCostSecp256k1 *= 2 params.SigVerifyCostSecp256k1 *= 2
doubleCost, err := suite.runSigDecorators(params, false, privs...) doubleCost, err := s.runSigMiddlewares(params, false, privs...)
suite.Require().Nil(err) s.Require().Nil(err)
suite.Require().Equal(initialSigCost*uint64(len(privs)), doubleCost-initialCost) s.Require().Equal(initialSigCost*uint64(len(privs)), doubleCost-initialCost)
} }
func (suite *AnteTestSuite) runSigDecorators(params types.Params, _ bool, privs ...cryptotypes.PrivKey) (sdk.Gas, error) { func (s *MWTestSuite) runSigMiddlewares(params types.Params, _ bool, privs ...cryptotypes.PrivKey) (sdk.Gas, error) {
suite.SetupTest(true) // setup ctx := s.SetupTest(true) // setup
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
// Make block-height non-zero to include accNum in SignBytes // Make block-height non-zero to include accNum in SignBytes
suite.ctx = suite.ctx.WithBlockHeight(1) ctx = ctx.WithBlockHeight(1)
suite.app.AccountKeeper.SetParams(suite.ctx, params) s.app.AccountKeeper.SetParams(ctx, params)
msgs := make([]sdk.Msg, len(privs)) msgs := make([]sdk.Msg, len(privs))
accNums := make([]uint64, len(privs)) accNums := make([]uint64, len(privs))
@ -315,76 +324,89 @@ func (suite *AnteTestSuite) runSigDecorators(params types.Params, _ bool, privs
// set accounts and create msg for each address // set accounts and create msg for each address
for i, priv := range privs { for i, priv := range privs {
addr := sdk.AccAddress(priv.PubKey().Address()) addr := sdk.AccAddress(priv.PubKey().Address())
acc := suite.app.AccountKeeper.NewAccountWithAddress(suite.ctx, addr) acc := s.app.AccountKeeper.NewAccountWithAddress(ctx, addr)
suite.Require().NoError(acc.SetAccountNumber(uint64(i))) s.Require().NoError(acc.SetAccountNumber(uint64(i)))
suite.app.AccountKeeper.SetAccount(suite.ctx, acc) s.app.AccountKeeper.SetAccount(ctx, acc)
msgs[i] = testdata.NewTestMsg(addr) msgs[i] = testdata.NewTestMsg(addr)
accNums[i] = uint64(i) accNums[i] = uint64(i)
accSeqs[i] = uint64(0) accSeqs[i] = uint64(0)
} }
suite.Require().NoError(suite.txBuilder.SetMsgs(msgs...)) s.Require().NoError(txBuilder.SetMsgs(msgs...))
feeAmount := testdata.NewTestFeeAmount() feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit() gasLimit := testdata.NewTestGasLimit()
suite.txBuilder.SetFeeAmount(feeAmount) txBuilder.SetFeeAmount(feeAmount)
suite.txBuilder.SetGasLimit(gasLimit) txBuilder.SetGasLimit(gasLimit)
tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) testTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID())
suite.Require().NoError(err) s.Require().NoError(err)
spkd := ante.NewSetPubKeyDecorator(suite.app.AccountKeeper) txHandler := middleware.ComposeMiddlewares(
svgc := ante.NewSigGasConsumeDecorator(suite.app.AccountKeeper, ante.DefaultSigVerificationGasConsumer) noopTxHandler{},
svd := ante.NewSigVerificationDecorator(suite.app.AccountKeeper, suite.clientCtx.TxConfig.SignModeHandler()) middleware.SetPubKeyMiddleware(s.app.AccountKeeper),
antehandler := sdk.ChainAnteDecorators(spkd, svgc, svd) middleware.SigGasConsumeMiddleware(s.app.AccountKeeper, middleware.DefaultSigVerificationGasConsumer),
middleware.SigVerificationMiddleware(
s.app.AccountKeeper,
s.clientCtx.TxConfig.SignModeHandler(),
),
)
// Determine gas consumption of antehandler with default params // Determine gas consumption of txhandler with default params
before := suite.ctx.GasMeter().GasConsumed() before := ctx.GasMeter().GasConsumed()
ctx, err := antehandler(suite.ctx, tx, false) _, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), testTx, abci.RequestDeliverTx{})
after := ctx.GasMeter().GasConsumed() after := ctx.GasMeter().GasConsumed()
return after - before, err return after - before, err
} }
func (suite *AnteTestSuite) TestIncrementSequenceDecorator() { func (s *MWTestSuite) TestIncrementSequenceMiddleware() {
suite.SetupTest(true) // setup ctx := s.SetupTest(true) // setup
suite.txBuilder = suite.clientCtx.TxConfig.NewTxBuilder() txBuilder := s.clientCtx.TxConfig.NewTxBuilder()
priv, _, addr := testdata.KeyTestPubAddr() priv, _, addr := testdata.KeyTestPubAddr()
acc := suite.app.AccountKeeper.NewAccountWithAddress(suite.ctx, addr) acc := s.app.AccountKeeper.NewAccountWithAddress(ctx, addr)
suite.Require().NoError(acc.SetAccountNumber(uint64(50))) s.Require().NoError(acc.SetAccountNumber(uint64(50)))
suite.app.AccountKeeper.SetAccount(suite.ctx, acc) s.app.AccountKeeper.SetAccount(ctx, acc)
msgs := []sdk.Msg{testdata.NewTestMsg(addr)} msgs := []sdk.Msg{testdata.NewTestMsg(addr)}
suite.Require().NoError(suite.txBuilder.SetMsgs(msgs...)) s.Require().NoError(txBuilder.SetMsgs(msgs...))
privs := []cryptotypes.PrivKey{priv} privs := []cryptotypes.PrivKey{priv}
accNums := []uint64{suite.app.AccountKeeper.GetAccount(suite.ctx, addr).GetAccountNumber()} accNums := []uint64{s.app.AccountKeeper.GetAccount(ctx, addr).GetAccountNumber()}
accSeqs := []uint64{suite.app.AccountKeeper.GetAccount(suite.ctx, addr).GetSequence()} accSeqs := []uint64{s.app.AccountKeeper.GetAccount(ctx, addr).GetSequence()}
feeAmount := testdata.NewTestFeeAmount() feeAmount := testdata.NewTestFeeAmount()
gasLimit := testdata.NewTestGasLimit() gasLimit := testdata.NewTestGasLimit()
suite.txBuilder.SetFeeAmount(feeAmount) txBuilder.SetFeeAmount(feeAmount)
suite.txBuilder.SetGasLimit(gasLimit) txBuilder.SetGasLimit(gasLimit)
tx, err := suite.CreateTestTx(privs, accNums, accSeqs, suite.ctx.ChainID()) testTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID())
suite.Require().NoError(err) s.Require().NoError(err)
isd := ante.NewIncrementSequenceDecorator(suite.app.AccountKeeper) txHandler := middleware.ComposeMiddlewares(
antehandler := sdk.ChainAnteDecorators(isd) noopTxHandler{},
middleware.IncrementSequenceMiddleware(s.app.AccountKeeper),
)
testCases := []struct { testCases := []struct {
ctx sdk.Context ctx sdk.Context
simulate bool simulate bool
expectedSeq uint64 expectedSeq uint64
}{ }{
{suite.ctx.WithIsReCheckTx(true), false, 1}, {ctx.WithIsReCheckTx(true), false, 1},
{suite.ctx.WithIsCheckTx(true).WithIsReCheckTx(false), false, 2}, {ctx.WithIsCheckTx(true).WithIsReCheckTx(false), false, 2},
{suite.ctx.WithIsReCheckTx(true), false, 3}, {ctx.WithIsReCheckTx(true), false, 3},
{suite.ctx.WithIsReCheckTx(true), false, 4}, {ctx.WithIsReCheckTx(true), false, 4},
{suite.ctx.WithIsReCheckTx(true), true, 5}, {ctx.WithIsReCheckTx(true), true, 5},
} }
for i, tc := range testCases { for i, tc := range testCases {
_, err := antehandler(tc.ctx, tx, tc.simulate) var err error
suite.Require().NoError(err, "unexpected error; tc #%d, %v", i, tc) if tc.simulate {
suite.Require().Equal(tc.expectedSeq, suite.app.AccountKeeper.GetAccount(suite.ctx, addr).GetSequence()) _, err = txHandler.SimulateTx(sdk.WrapSDKContext(tc.ctx), testTx, tx.RequestSimulateTx{})
} else {
_, err = txHandler.DeliverTx(sdk.WrapSDKContext(tc.ctx), testTx, abci.RequestDeliverTx{})
}
s.Require().NoError(err, "unexpected error; tc #%d, %v", i, tc)
s.Require().Equal(tc.expectedSeq, s.app.AccountKeeper.GetAccount(ctx, addr).GetSequence())
} }
} }

View File

@ -1,18 +1,24 @@
package middleware_test package middleware_test
import ( import (
"errors"
"fmt"
"testing" "testing"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/tendermint/tendermint/abci/types"
tmproto "github.com/tendermint/tendermint/proto/tendermint/types" tmproto "github.com/tendermint/tendermint/proto/tendermint/types"
"github.com/cosmos/cosmos-sdk/client" "github.com/cosmos/cosmos-sdk/client"
"github.com/cosmos/cosmos-sdk/client/tx" "github.com/cosmos/cosmos-sdk/client/tx"
"github.com/cosmos/cosmos-sdk/codec"
cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types"
"github.com/cosmos/cosmos-sdk/simapp" "github.com/cosmos/cosmos-sdk/simapp"
"github.com/cosmos/cosmos-sdk/testutil/testdata" "github.com/cosmos/cosmos-sdk/testutil/testdata"
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
txtypes "github.com/cosmos/cosmos-sdk/types/tx"
"github.com/cosmos/cosmos-sdk/types/tx/signing" "github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/middleware"
xauthsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" xauthsigning "github.com/cosmos/cosmos-sdk/x/auth/signing"
authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
minttypes "github.com/cosmos/cosmos-sdk/x/mint/types" minttypes "github.com/cosmos/cosmos-sdk/x/mint/types"
@ -30,12 +36,13 @@ type MWTestSuite struct {
app *simapp.SimApp app *simapp.SimApp
clientCtx client.Context clientCtx client.Context
txHandler txtypes.Handler
} }
// returns context and app with params set on account keeper // returns context and app with params set on account keeper
func createTestApp(t *testing.T, isCheckTx bool) (*simapp.SimApp, sdk.Context) { func createTestApp(t *testing.T, isCheckTx bool) (*simapp.SimApp, sdk.Context) {
app := simapp.Setup(t, isCheckTx) app := simapp.Setup(t, isCheckTx)
ctx := app.BaseApp.NewContext(isCheckTx, tmproto.Header{}) ctx := app.BaseApp.NewContext(isCheckTx, tmproto.Header{}).WithBlockGasMeter(sdk.NewInfiniteGasMeter())
app.AccountKeeper.SetParams(ctx, authtypes.DefaultParams()) app.AccountKeeper.SetParams(ctx, authtypes.DefaultParams())
return app, ctx return app, ctx
@ -54,14 +61,35 @@ func (s *MWTestSuite) SetupTest(isCheckTx bool) sdk.Context {
testdata.RegisterInterfaces(encodingConfig.InterfaceRegistry) testdata.RegisterInterfaces(encodingConfig.InterfaceRegistry)
s.clientCtx = client.Context{}. s.clientCtx = client.Context{}.
WithTxConfig(encodingConfig.TxConfig) WithTxConfig(encodingConfig.TxConfig).
WithInterfaceRegistry(encodingConfig.InterfaceRegistry).
WithCodec(codec.NewAminoCodec(encodingConfig.Amino))
// We don't use simapp's own txHandler. For more flexibility (i.e. around
// using testdata), we create own own txHandler for this test suite.
msr := middleware.NewMsgServiceRouter(encodingConfig.InterfaceRegistry)
testdata.RegisterMsgServer(msr, testdata.MsgServerImpl{})
legacyRouter := middleware.NewLegacyRouter()
legacyRouter.AddRoute(sdk.NewRoute((&testdata.TestMsg{}).Route(), func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) { return &sdk.Result{}, nil }))
txHandler, err := middleware.NewDefaultTxHandler(middleware.TxHandlerOptions{
Debug: s.app.Trace(),
MsgServiceRouter: msr,
LegacyRouter: legacyRouter,
AccountKeeper: s.app.AccountKeeper,
BankKeeper: s.app.BankKeeper,
FeegrantKeeper: s.app.FeeGrantKeeper,
SignModeHandler: encodingConfig.TxConfig.SignModeHandler(),
SigGasConsumer: middleware.DefaultSigVerificationGasConsumer,
})
s.Require().NoError(err)
s.txHandler = txHandler
return ctx return ctx
} }
// CreatetestAccounts creates `numAccs` accounts, and return all relevant // createTestAccounts creates `numAccs` accounts, and return all relevant
// information about them including their private keys. // information about them including their private keys.
func (s *MWTestSuite) CreatetestAccounts(ctx sdk.Context, numAccs int) []testAccount { func (s *MWTestSuite) createTestAccounts(ctx sdk.Context, numAccs int) []testAccount {
var accounts []testAccount var accounts []testAccount
for i := 0; i < numAccs; i++ { for i := 0; i < numAccs; i++ {
@ -137,6 +165,48 @@ func (s *MWTestSuite) createTestTx(txBuilder client.TxBuilder, privs []cryptotyp
return txBuilder.GetTx(), txBytes, nil return txBuilder.GetTx(), txBytes, nil
} }
func (s *MWTestSuite) runTestCase(ctx sdk.Context, txBuilder client.TxBuilder, privs []cryptotypes.PrivKey, msgs []sdk.Msg, feeAmount sdk.Coins, gasLimit uint64, accNums, accSeqs []uint64, chainID string, tc TestCase) {
s.Run(fmt.Sprintf("Case %s", tc.desc), func() {
s.Require().NoError(txBuilder.SetMsgs(msgs...))
txBuilder.SetFeeAmount(feeAmount)
txBuilder.SetGasLimit(gasLimit)
// Theoretically speaking, middleware unit tests should only test
// middlewares, but here we sometimes also test the tx creation
// process.
tx, _, txErr := s.createTestTx(txBuilder, privs, accNums, accSeqs, chainID)
newCtx, txHandlerErr := s.txHandler.DeliverTx(sdk.WrapSDKContext(ctx), tx, types.RequestDeliverTx{})
if tc.expPass {
s.Require().NoError(txErr)
s.Require().NoError(txHandlerErr)
s.Require().NotNil(newCtx)
} else {
switch {
case txErr != nil:
s.Require().Error(txErr)
s.Require().True(errors.Is(txErr, tc.expErr))
case txHandlerErr != nil:
s.Require().Error(txHandlerErr)
s.Require().True(errors.Is(txHandlerErr, tc.expErr))
default:
s.Fail("expected one of txErr,txHandlerErr to be an error")
}
}
})
}
// TestCase represents a test case used in test tables.
type TestCase struct {
desc string
malleate func()
simulate bool
expPass bool
expErr error
}
func TestMWTestSuite(t *testing.T) { func TestMWTestSuite(t *testing.T) {
suite.Run(t, new(MWTestSuite)) suite.Run(t, new(MWTestSuite))
} }

View File

@ -13,7 +13,7 @@ import (
"github.com/cosmos/cosmos-sdk/simapp" "github.com/cosmos/cosmos-sdk/simapp"
"github.com/cosmos/cosmos-sdk/testutil/testdata" "github.com/cosmos/cosmos-sdk/testutil/testdata"
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/auth/ante" "github.com/cosmos/cosmos-sdk/x/auth/middleware"
"github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx" "github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx"
"github.com/cosmos/cosmos-sdk/x/auth/signing" "github.com/cosmos/cosmos-sdk/x/auth/signing"
"github.com/cosmos/cosmos-sdk/x/auth/types" "github.com/cosmos/cosmos-sdk/x/auth/types"
@ -42,7 +42,8 @@ func TestVerifySignature(t *testing.T) {
app.AccountKeeper.SetAccount(ctx, acc1) app.AccountKeeper.SetAccount(ctx, acc1)
balances := sdk.NewCoins(sdk.NewInt64Coin("atom", 200)) balances := sdk.NewCoins(sdk.NewInt64Coin("atom", 200))
require.NoError(t, testutil.FundAccount(app.BankKeeper, ctx, addr, balances)) require.NoError(t, testutil.FundAccount(app.BankKeeper, ctx, addr, balances))
acc, err := ante.GetSignerAcc(ctx, app.AccountKeeper, addr) acc, err := middleware.GetSignerAcc(ctx, app.AccountKeeper, addr)
require.NoError(t, err)
require.NoError(t, testutil.FundAccount(app.BankKeeper, ctx, addr, balances)) require.NoError(t, testutil.FundAccount(app.BankKeeper, ctx, addr, balances))
msgs := []sdk.Msg{testdata.NewTestMsg(addr)} msgs := []sdk.Msg{testdata.NewTestMsg(addr)}

View File

@ -10,7 +10,7 @@ import (
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/types/tx" "github.com/cosmos/cosmos-sdk/types/tx"
"github.com/cosmos/cosmos-sdk/types/tx/signing" "github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/ante" "github.com/cosmos/cosmos-sdk/x/auth/middleware"
authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing"
) )
@ -33,7 +33,7 @@ type wrapper struct {
var ( var (
_ authsigning.Tx = &wrapper{} _ authsigning.Tx = &wrapper{}
_ client.TxBuilder = &wrapper{} _ client.TxBuilder = &wrapper{}
_ ante.HasExtensionOptionsTx = &wrapper{} _ middleware.HasExtensionOptionsTx = &wrapper{}
_ ExtensionOptionsTxBuilder = &wrapper{} _ ExtensionOptionsTxBuilder = &wrapper{}
) )

View File

@ -8,7 +8,7 @@ import (
"github.com/cosmos/cosmos-sdk/codec" "github.com/cosmos/cosmos-sdk/codec"
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/x/auth/ante" "github.com/cosmos/cosmos-sdk/x/auth/middleware"
"github.com/cosmos/cosmos-sdk/x/feegrant" "github.com/cosmos/cosmos-sdk/x/feegrant"
) )
@ -20,7 +20,7 @@ type Keeper struct {
authKeeper feegrant.AccountKeeper authKeeper feegrant.AccountKeeper
} }
var _ ante.FeegrantKeeper = &Keeper{} var _ middleware.FeegrantKeeper = &Keeper{}
// NewKeeper creates a fee grant Keeper // NewKeeper creates a fee grant Keeper
func NewKeeper(cdc codec.BinaryCodec, storeKey sdk.StoreKey, ak feegrant.AccountKeeper) Keeper { func NewKeeper(cdc codec.BinaryCodec, storeKey sdk.StoreKey, ak feegrant.AccountKeeper) Keeper {