Merge PR #4797: blacklist module accounts from receiving txs

This commit is contained in:
Federico Kunze 2019-07-31 17:39:02 +02:00 committed by Alexander Bezobchuk
parent 83b1a9fc22
commit 8c989fd424
20 changed files with 313 additions and 143 deletions

View File

@ -0,0 +1,2 @@
#4795 restrict module accounts from receiving transactions.
Allowing this would cause an invariant on the module account coins.

View File

@ -2,10 +2,11 @@
## Supply ## Supply
The `supply` module: The `supply` module:
- passively tracks the total supply of coins within a chain,
- provides a pattern for modules to hold/interact with `Coins`, and - passively tracks the total supply of coins within a chain,
- introduces the invariant check to verify a chain's total supply. - provides a pattern for modules to hold/interact with `Coins`, and
- introduces the invariant check to verify a chain's total supply.
### Total Supply ### Total Supply
@ -32,11 +33,14 @@ The `ModuleAccount` interface is defined as follows:
type ModuleAccount interface { type ModuleAccount interface {
auth.Account // same methods as the Account interface auth.Account // same methods as the Account interface
GetName() string // name of the module; used to obtain the address GetName() string // name of the module; used to obtain the address
GetPermissions() []string // permissions of module account GetPermissions() []string // permissions of module account
HasPermission(string) bool HasPermission(string) bool
} }
``` ```
> **WARNING!**
Any module or message handler that allows either direct or indirect sending of funds must explicitly guarantee those funds cannot be sent to module accounts (unless allowed).
The supply `Keeper` also introduces new wrapper functions for the auth `Keeper` The supply `Keeper` also introduces new wrapper functions for the auth `Keeper`
and the bank `Keeper` that are related to `ModuleAccount`s in order to be able and the bank `Keeper` that are related to `ModuleAccount`s in order to be able
to: to:

View File

@ -141,13 +141,13 @@ func NewSimApp(
// add keepers // add keepers
app.accountKeeper = auth.NewAccountKeeper(app.cdc, keys[auth.StoreKey], authSubspace, auth.ProtoBaseAccount) app.accountKeeper = auth.NewAccountKeeper(app.cdc, keys[auth.StoreKey], authSubspace, auth.ProtoBaseAccount)
app.bankKeeper = bank.NewBaseKeeper(app.accountKeeper, bankSubspace, bank.DefaultCodespace) app.bankKeeper = bank.NewBaseKeeper(app.accountKeeper, bankSubspace, bank.DefaultCodespace, app.ModuleAccountAddrs())
app.supplyKeeper = supply.NewKeeper(app.cdc, keys[supply.StoreKey], app.accountKeeper, app.bankKeeper, supply.DefaultCodespace, maccPerms) app.supplyKeeper = supply.NewKeeper(app.cdc, keys[supply.StoreKey], app.accountKeeper, app.bankKeeper, supply.DefaultCodespace, maccPerms)
stakingKeeper := staking.NewKeeper(app.cdc, keys[staking.StoreKey], tkeys[staking.TStoreKey], stakingKeeper := staking.NewKeeper(app.cdc, keys[staking.StoreKey], tkeys[staking.TStoreKey],
app.supplyKeeper, stakingSubspace, staking.DefaultCodespace) app.supplyKeeper, stakingSubspace, staking.DefaultCodespace)
app.mintKeeper = mint.NewKeeper(app.cdc, keys[mint.StoreKey], mintSubspace, &stakingKeeper, app.supplyKeeper, auth.FeeCollectorName) app.mintKeeper = mint.NewKeeper(app.cdc, keys[mint.StoreKey], mintSubspace, &stakingKeeper, app.supplyKeeper, auth.FeeCollectorName)
app.distrKeeper = distr.NewKeeper(app.cdc, keys[distr.StoreKey], distrSubspace, &stakingKeeper, app.distrKeeper = distr.NewKeeper(app.cdc, keys[distr.StoreKey], distrSubspace, &stakingKeeper,
app.supplyKeeper, distr.DefaultCodespace, auth.FeeCollectorName) app.supplyKeeper, distr.DefaultCodespace, auth.FeeCollectorName, app.ModuleAccountAddrs())
app.slashingKeeper = slashing.NewKeeper(app.cdc, keys[slashing.StoreKey], &stakingKeeper, app.slashingKeeper = slashing.NewKeeper(app.cdc, keys[slashing.StoreKey], &stakingKeeper,
slashingSubspace, slashing.DefaultCodespace) slashingSubspace, slashing.DefaultCodespace)
app.crisisKeeper = crisis.NewKeeper(crisisSubspace, invCheckPeriod, app.supplyKeeper, auth.FeeCollectorName) app.crisisKeeper = crisis.NewKeeper(crisisSubspace, invCheckPeriod, app.supplyKeeper, auth.FeeCollectorName)

View File

@ -5,11 +5,8 @@ import (
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/auth" "github.com/cosmos/cosmos-sdk/x/auth"
"github.com/cosmos/cosmos-sdk/x/bank"
"github.com/cosmos/cosmos-sdk/x/bank/internal/keeper"
"github.com/cosmos/cosmos-sdk/x/bank/internal/types" "github.com/cosmos/cosmos-sdk/x/bank/internal/types"
"github.com/cosmos/cosmos-sdk/x/mock" "github.com/cosmos/cosmos-sdk/x/mock"
"github.com/cosmos/cosmos-sdk/x/supply"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -50,6 +47,7 @@ var (
freeFee = auth.NewStdFee(100000, sdk.Coins{sdk.NewInt64Coin("foocoin", 0)}) freeFee = auth.NewStdFee(100000, sdk.Coins{sdk.NewInt64Coin("foocoin", 0)})
sendMsg1 = types.NewMsgSend(addr1, addr2, coins) sendMsg1 = types.NewMsgSend(addr1, addr2, coins)
sendMsg2 = types.NewMsgSend(addr1, moduleAccAddr, coins)
multiSendMsg1 = types.MsgMultiSend{ multiSendMsg1 = types.MsgMultiSend{
Inputs: []types.Input{types.NewInput(addr1, coins)}, Inputs: []types.Input{types.NewInput(addr1, coins)},
@ -88,26 +86,15 @@ var (
types.NewOutput(addr2, manyCoins), types.NewOutput(addr2, manyCoins),
}, },
} }
) multiSendMsg6 = types.MsgMultiSend{
Inputs: []types.Input{
// initialize the mock application for this module types.NewInput(addr1, coins),
func getMockApp(t *testing.T) *mock.App { },
mapp, err := getBenchmarkMockApp() Outputs: []types.Output{
supply.RegisterCodec(mapp.Cdc) types.NewOutput(moduleAccAddr, coins),
require.NoError(t, err) },
return mapp
}
// overwrite the mock init chainer
func getInitChainer(mapp *mock.App, keeper keeper.BaseKeeper) sdk.InitChainer {
return func(ctx sdk.Context, req abci.RequestInitChain) abci.ResponseInitChain {
mapp.InitChainer(ctx, req)
bankGenesis := bank.DefaultGenesisState()
bank.InitGenesis(ctx, keeper, bankGenesis)
return abci.ResponseInitChain{}
} }
} )
func TestSendNotEnoughBalance(t *testing.T) { func TestSendNotEnoughBalance(t *testing.T) {
mapp := getMockApp(t) mapp := getMockApp(t)
@ -140,6 +127,41 @@ func TestSendNotEnoughBalance(t *testing.T) {
require.True(t, res2.GetSequence() == origSeq+1) require.True(t, res2.GetSequence() == origSeq+1)
} }
func TestSendToModuleAcc(t *testing.T) {
mapp := getMockApp(t)
acc := &auth.BaseAccount{
Address: addr1,
Coins: coins,
}
macc := &auth.BaseAccount{
Address: moduleAccAddr,
}
mock.SetGenesis(mapp, []auth.Account{acc, macc})
ctxCheck := mapp.BaseApp.NewContext(true, abci.Header{})
res1 := mapp.AccountKeeper.GetAccount(ctxCheck, addr1)
require.NotNil(t, res1)
require.Equal(t, acc, res1.(*auth.BaseAccount))
origAccNum := res1.GetAccountNumber()
origSeq := res1.GetSequence()
header := abci.Header{Height: mapp.LastBlockHeight() + 1}
mock.SignCheckDeliver(t, mapp.Cdc, mapp.BaseApp, header, []sdk.Msg{sendMsg2}, []uint64{origAccNum}, []uint64{origSeq}, false, false, priv1)
mock.CheckBalance(t, mapp, addr1, coins)
mock.CheckBalance(t, mapp, moduleAccAddr, sdk.Coins(nil))
res2 := mapp.AccountKeeper.GetAccount(mapp.NewContext(true, abci.Header{}), addr1)
require.NotNil(t, res2)
require.True(t, res2.GetAccountNumber() == origAccNum)
require.True(t, res2.GetSequence() == origSeq+1)
}
func TestMsgMultiSendWithAccounts(t *testing.T) { func TestMsgMultiSendWithAccounts(t *testing.T) {
mapp := getMockApp(t) mapp := getMockApp(t)
acc := &auth.BaseAccount{ acc := &auth.BaseAccount{
@ -147,7 +169,11 @@ func TestMsgMultiSendWithAccounts(t *testing.T) {
Coins: sdk.Coins{sdk.NewInt64Coin("foocoin", 67)}, Coins: sdk.Coins{sdk.NewInt64Coin("foocoin", 67)},
} }
mock.SetGenesis(mapp, []auth.Account{acc}) macc := &auth.BaseAccount{
Address: moduleAccAddr,
}
mock.SetGenesis(mapp, []auth.Account{acc, macc})
ctxCheck := mapp.BaseApp.NewContext(true, abci.Header{}) ctxCheck := mapp.BaseApp.NewContext(true, abci.Header{})
@ -176,6 +202,14 @@ func TestMsgMultiSendWithAccounts(t *testing.T) {
expPass: false, expPass: false,
privKeys: []crypto.PrivKey{priv1}, privKeys: []crypto.PrivKey{priv1},
}, },
{
msgs: []sdk.Msg{multiSendMsg6},
accNums: []uint64{0},
accSeqs: []uint64{0},
expSimPass: false,
expPass: false,
privKeys: []crypto.PrivKey{priv1},
},
} }
for _, tc := range testCases { for _, tc := range testCases {

View File

@ -3,6 +3,7 @@ package bank_test
import ( import (
"testing" "testing"
"github.com/stretchr/testify/require"
abci "github.com/tendermint/tendermint/abci/types" abci "github.com/tendermint/tendermint/abci/types"
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
@ -11,18 +12,44 @@ import (
"github.com/cosmos/cosmos-sdk/x/bank/internal/keeper" "github.com/cosmos/cosmos-sdk/x/bank/internal/keeper"
"github.com/cosmos/cosmos-sdk/x/bank/internal/types" "github.com/cosmos/cosmos-sdk/x/bank/internal/types"
"github.com/cosmos/cosmos-sdk/x/mock" "github.com/cosmos/cosmos-sdk/x/mock"
"github.com/cosmos/cosmos-sdk/x/supply"
) )
var moduleAccAddr = sdk.AccAddress([]byte("moduleAcc"))
// initialize the mock application for this module
func getMockApp(t *testing.T) *mock.App {
mapp, err := getBenchmarkMockApp()
supply.RegisterCodec(mapp.Cdc)
require.NoError(t, err)
return mapp
}
// overwrite the mock init chainer
func getInitChainer(mapp *mock.App, keeper keeper.BaseKeeper) sdk.InitChainer {
return func(ctx sdk.Context, req abci.RequestInitChain) abci.ResponseInitChain {
mapp.InitChainer(ctx, req)
bankGenesis := bank.DefaultGenesisState()
bank.InitGenesis(ctx, keeper, bankGenesis)
return abci.ResponseInitChain{}
}
}
// getBenchmarkMockApp initializes a mock application for this module, for purposes of benchmarking // getBenchmarkMockApp initializes a mock application for this module, for purposes of benchmarking
// Any long term API support commitments do not apply to this function. // Any long term API support commitments do not apply to this function.
func getBenchmarkMockApp() (*mock.App, error) { func getBenchmarkMockApp() (*mock.App, error) {
mapp := mock.NewApp() mapp := mock.NewApp()
types.RegisterCodec(mapp.Cdc) types.RegisterCodec(mapp.Cdc)
blacklistedAddrs := make(map[string]bool)
blacklistedAddrs[moduleAccAddr.String()] = true
bankKeeper := keeper.NewBaseKeeper( bankKeeper := keeper.NewBaseKeeper(
mapp.AccountKeeper, mapp.AccountKeeper,
mapp.ParamsKeeper.Subspace(types.DefaultParamspace), mapp.ParamsKeeper.Subspace(types.DefaultParamspace),
types.DefaultCodespace, types.DefaultCodespace,
blacklistedAddrs,
) )
mapp.Router().AddRoute(types.RouterKey, bank.NewHandler(bankKeeper)) mapp.Router().AddRoute(types.RouterKey, bank.NewHandler(bankKeeper))
mapp.SetInitChainer(getInitChainer(mapp, bankKeeper)) mapp.SetInitChainer(getInitChainer(mapp, bankKeeper))

View File

@ -33,6 +33,10 @@ func handleMsgSend(ctx sdk.Context, k keeper.Keeper, msg types.MsgSend) sdk.Resu
return types.ErrSendDisabled(k.Codespace()).Result() return types.ErrSendDisabled(k.Codespace()).Result()
} }
if k.BlacklistedAddr(msg.ToAddress) {
return sdk.ErrUnauthorized(fmt.Sprintf("%s is not allowed to receive transactions", msg.ToAddress)).Result()
}
err := k.SendCoins(ctx, msg.FromAddress, msg.ToAddress, msg.Amount) err := k.SendCoins(ctx, msg.FromAddress, msg.ToAddress, msg.Amount)
if err != nil { if err != nil {
return err.Result() return err.Result()
@ -55,6 +59,12 @@ func handleMsgMultiSend(ctx sdk.Context, k keeper.Keeper, msg types.MsgMultiSend
return types.ErrSendDisabled(k.Codespace()).Result() return types.ErrSendDisabled(k.Codespace()).Result()
} }
for _, out := range msg.Outputs {
if k.BlacklistedAddr(out.Address) {
return sdk.ErrUnauthorized(fmt.Sprintf("%s is not allowed to receive transactions", out.Address)).Result()
}
}
err := k.InputOutputCoins(ctx, msg.Inputs, msg.Outputs) err := k.InputOutputCoins(ctx, msg.Inputs, msg.Outputs)
if err != nil { if err != nil {
return err.Result() return err.Result()

View File

@ -0,0 +1,59 @@
package keeper
// DONTCOVER
import (
abci "github.com/tendermint/tendermint/abci/types"
dbm "github.com/tendermint/tendermint/libs/db"
"github.com/tendermint/tendermint/libs/log"
"github.com/cosmos/cosmos-sdk/codec"
"github.com/cosmos/cosmos-sdk/store"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/auth"
"github.com/cosmos/cosmos-sdk/x/bank/internal/types"
"github.com/cosmos/cosmos-sdk/x/params"
)
type testInput struct {
cdc *codec.Codec
ctx sdk.Context
k Keeper
ak auth.AccountKeeper
pk params.Keeper
}
func setupTestInput() testInput {
db := dbm.NewMemDB()
cdc := codec.New()
auth.RegisterCodec(cdc)
codec.RegisterCrypto(cdc)
authCapKey := sdk.NewKVStoreKey("authCapKey")
keyParams := sdk.NewKVStoreKey("params")
tkeyParams := sdk.NewTransientStoreKey("transient_params")
ms := store.NewCommitMultiStore(db)
ms.MountStoreWithDB(authCapKey, sdk.StoreTypeIAVL, db)
ms.MountStoreWithDB(keyParams, sdk.StoreTypeIAVL, db)
ms.MountStoreWithDB(tkeyParams, sdk.StoreTypeTransient, db)
ms.LoadLatestVersion()
blacklistedAddrs := make(map[string]bool)
blacklistedAddrs[sdk.AccAddress([]byte("moduleAcc")).String()] = true
pk := params.NewKeeper(cdc, keyParams, tkeyParams, params.DefaultCodespace)
ak := auth.NewAccountKeeper(
cdc, authCapKey, pk.Subspace(auth.DefaultParamspace), auth.ProtoBaseAccount,
)
ctx := sdk.NewContext(ms, abci.Header{ChainID: "test-chain-id"}, false, log.NewNopLogger())
ak.SetParams(ctx, auth.DefaultParams())
bankKeeper := NewBaseKeeper(ak, pk.Subspace(types.DefaultParamspace), types.DefaultCodespace, blacklistedAddrs)
bankKeeper.SetSendEnabled(ctx, true)
return testInput{cdc: cdc, ctx: ctx, k: bankKeeper, ak: ak, pk: pk}
}

View File

@ -7,7 +7,7 @@ import (
"github.com/cosmos/cosmos-sdk/x/bank/internal/types" "github.com/cosmos/cosmos-sdk/x/bank/internal/types"
) )
// register bank invariants // RegisterInvariants registers the bank module invariants
func RegisterInvariants(ir sdk.InvariantRegistry, ak types.AccountKeeper) { func RegisterInvariants(ir sdk.InvariantRegistry, ak types.AccountKeeper) {
ir.RegisterRoute(types.ModuleName, "nonnegative-outstanding", ir.RegisterRoute(types.ModuleName, "nonnegative-outstanding",
NonnegativeBalanceInvariant(ak)) NonnegativeBalanceInvariant(ak))

View File

@ -34,11 +34,11 @@ type BaseKeeper struct {
// NewBaseKeeper returns a new BaseKeeper // NewBaseKeeper returns a new BaseKeeper
func NewBaseKeeper(ak types.AccountKeeper, func NewBaseKeeper(ak types.AccountKeeper,
paramSpace params.Subspace, paramSpace params.Subspace,
codespace sdk.CodespaceType) BaseKeeper { codespace sdk.CodespaceType, blacklistedAddrs map[string]bool) BaseKeeper {
ps := paramSpace.WithKeyTable(types.ParamKeyTable()) ps := paramSpace.WithKeyTable(types.ParamKeyTable())
return BaseKeeper{ return BaseKeeper{
BaseSendKeeper: NewBaseSendKeeper(ak, ps, codespace), BaseSendKeeper: NewBaseSendKeeper(ak, ps, codespace, blacklistedAddrs),
ak: ak, ak: ak,
paramSpace: ps, paramSpace: ps,
} }
@ -145,6 +145,8 @@ type SendKeeper interface {
GetSendEnabled(ctx sdk.Context) bool GetSendEnabled(ctx sdk.Context) bool
SetSendEnabled(ctx sdk.Context, enabled bool) SetSendEnabled(ctx sdk.Context, enabled bool)
BlacklistedAddr(addr sdk.AccAddress) bool
} }
var _ SendKeeper = (*BaseSendKeeper)(nil) var _ SendKeeper = (*BaseSendKeeper)(nil)
@ -156,16 +158,20 @@ type BaseSendKeeper struct {
ak types.AccountKeeper ak types.AccountKeeper
paramSpace params.Subspace paramSpace params.Subspace
// list of addresses that are restricted from receiving transactions
blacklistedAddrs map[string]bool
} }
// NewBaseSendKeeper returns a new BaseSendKeeper. // NewBaseSendKeeper returns a new BaseSendKeeper.
func NewBaseSendKeeper(ak types.AccountKeeper, func NewBaseSendKeeper(ak types.AccountKeeper,
paramSpace params.Subspace, codespace sdk.CodespaceType) BaseSendKeeper { paramSpace params.Subspace, codespace sdk.CodespaceType, blacklistedAddrs map[string]bool) BaseSendKeeper {
return BaseSendKeeper{ return BaseSendKeeper{
BaseViewKeeper: NewBaseViewKeeper(ak, codespace), BaseViewKeeper: NewBaseViewKeeper(ak, codespace),
ak: ak, ak: ak,
paramSpace: paramSpace, paramSpace: paramSpace,
blacklistedAddrs: blacklistedAddrs,
} }
} }
@ -321,6 +327,12 @@ func (keeper BaseSendKeeper) SetSendEnabled(ctx sdk.Context, enabled bool) {
keeper.paramSpace.Set(ctx, types.ParamStoreKeySendEnabled, &enabled) keeper.paramSpace.Set(ctx, types.ParamStoreKeySendEnabled, &enabled)
} }
// BlacklistedAddr checks if a given address is blacklisted (i.e restricted from
// receiving funds)
func (keeper BaseSendKeeper) BlacklistedAddr(addr sdk.AccAddress) bool {
return keeper.blacklistedAddrs[addr.String()]
}
var _ ViewKeeper = (*BaseViewKeeper)(nil) var _ ViewKeeper = (*BaseViewKeeper)(nil)
// ViewKeeper defines a module interface that facilitates read only access to // ViewKeeper defines a module interface that facilitates read only access to

View File

@ -7,58 +7,13 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
abci "github.com/tendermint/tendermint/abci/types" abci "github.com/tendermint/tendermint/abci/types"
dbm "github.com/tendermint/tendermint/libs/db"
"github.com/tendermint/tendermint/libs/log"
tmtime "github.com/tendermint/tendermint/types/time" tmtime "github.com/tendermint/tendermint/types/time"
"github.com/cosmos/cosmos-sdk/codec"
"github.com/cosmos/cosmos-sdk/store"
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/auth" "github.com/cosmos/cosmos-sdk/x/auth"
"github.com/cosmos/cosmos-sdk/x/bank/internal/types" "github.com/cosmos/cosmos-sdk/x/bank/internal/types"
"github.com/cosmos/cosmos-sdk/x/params"
) )
type testInput struct {
cdc *codec.Codec
ctx sdk.Context
k Keeper
ak auth.AccountKeeper
pk params.Keeper
}
func setupTestInput() testInput {
db := dbm.NewMemDB()
cdc := codec.New()
auth.RegisterCodec(cdc)
codec.RegisterCrypto(cdc)
authCapKey := sdk.NewKVStoreKey("authCapKey")
keyParams := sdk.NewKVStoreKey("params")
tkeyParams := sdk.NewTransientStoreKey("transient_params")
ms := store.NewCommitMultiStore(db)
ms.MountStoreWithDB(authCapKey, sdk.StoreTypeIAVL, db)
ms.MountStoreWithDB(keyParams, sdk.StoreTypeIAVL, db)
ms.MountStoreWithDB(tkeyParams, sdk.StoreTypeTransient, db)
ms.LoadLatestVersion()
pk := params.NewKeeper(cdc, keyParams, tkeyParams, params.DefaultCodespace)
ak := auth.NewAccountKeeper(
cdc, authCapKey, pk.Subspace(auth.DefaultParamspace), auth.ProtoBaseAccount,
)
ctx := sdk.NewContext(ms, abci.Header{ChainID: "test-chain-id"}, false, log.NewNopLogger())
ak.SetParams(ctx, auth.DefaultParams())
bankKeeper := NewBaseKeeper(ak, pk.Subspace(types.DefaultParamspace), types.DefaultCodespace)
bankKeeper.SetSendEnabled(ctx, true)
return testInput{cdc: cdc, ctx: ctx, k: bankKeeper, ak: ak, pk: pk}
}
func TestKeeper(t *testing.T) { func TestKeeper(t *testing.T) {
input := setupTestInput() input := setupTestInput()
ctx := input.ctx ctx := input.ctx
@ -140,8 +95,11 @@ func TestKeeper(t *testing.T) {
func TestSendKeeper(t *testing.T) { func TestSendKeeper(t *testing.T) {
input := setupTestInput() input := setupTestInput()
ctx := input.ctx ctx := input.ctx
blacklistedAddrs := make(map[string]bool)
paramSpace := input.pk.Subspace("newspace") paramSpace := input.pk.Subspace("newspace")
sendKeeper := NewBaseSendKeeper(input.ak, paramSpace, types.DefaultCodespace) sendKeeper := NewBaseSendKeeper(input.ak, paramSpace, types.DefaultCodespace, blacklistedAddrs)
input.k.SetSendEnabled(ctx, true) input.k.SetSendEnabled(ctx, true)
addr := sdk.AccAddress([]byte("addr1")) addr := sdk.AccAddress([]byte("addr1"))

View File

@ -19,16 +19,17 @@ type Keeper struct {
stakingKeeper types.StakingKeeper stakingKeeper types.StakingKeeper
supplyKeeper types.SupplyKeeper supplyKeeper types.SupplyKeeper
// codespace
codespace sdk.CodespaceType codespace sdk.CodespaceType
blacklistedAddrs map[string]bool
feeCollectorName string // name of the FeeCollector ModuleAccount feeCollectorName string // name of the FeeCollector ModuleAccount
} }
// NewKeeper creates a new distribution Keeper instance // NewKeeper creates a new distribution Keeper instance
func NewKeeper(cdc *codec.Codec, key sdk.StoreKey, paramSpace params.Subspace, func NewKeeper(cdc *codec.Codec, key sdk.StoreKey, paramSpace params.Subspace,
sk types.StakingKeeper, supplyKeeper types.SupplyKeeper, codespace sdk.CodespaceType, sk types.StakingKeeper, supplyKeeper types.SupplyKeeper, codespace sdk.CodespaceType,
feeCollectorName string) Keeper { feeCollectorName string, blacklistedAddrs map[string]bool) Keeper {
// ensure distribution module account is set // ensure distribution module account is set
if addr := supplyKeeper.GetModuleAddress(types.ModuleName); addr == nil { if addr := supplyKeeper.GetModuleAddress(types.ModuleName); addr == nil {
@ -43,6 +44,7 @@ func NewKeeper(cdc *codec.Codec, key sdk.StoreKey, paramSpace params.Subspace,
supplyKeeper: supplyKeeper, supplyKeeper: supplyKeeper,
codespace: codespace, codespace: codespace,
feeCollectorName: feeCollectorName, feeCollectorName: feeCollectorName,
blacklistedAddrs: blacklistedAddrs,
} }
} }
@ -51,8 +53,12 @@ func (k Keeper) Logger(ctx sdk.Context) log.Logger {
return ctx.Logger().With("module", fmt.Sprintf("x/%s", types.ModuleName)) return ctx.Logger().With("module", fmt.Sprintf("x/%s", types.ModuleName))
} }
// set withdraw address // SetWithdrawAddr sets a new address that will receive the rewards upon withdrawal
func (k Keeper) SetWithdrawAddr(ctx sdk.Context, delegatorAddr sdk.AccAddress, withdrawAddr sdk.AccAddress) sdk.Error { func (k Keeper) SetWithdrawAddr(ctx sdk.Context, delegatorAddr sdk.AccAddress, withdrawAddr sdk.AccAddress) sdk.Error {
if k.blacklistedAddrs[withdrawAddr.String()] {
return sdk.ErrUnauthorized(fmt.Sprintf("%s is blacklisted from receiving external funds", withdrawAddr))
}
if !k.GetWithdrawAddrEnabled(ctx) { if !k.GetWithdrawAddrEnabled(ctx) {
return types.ErrSetWithdrawAddrDisabled(k.codespace) return types.ErrSetWithdrawAddrDisabled(k.codespace)
} }

View File

@ -20,6 +20,9 @@ func TestSetWithdrawAddr(t *testing.T) {
err = keeper.SetWithdrawAddr(ctx, delAddr1, delAddr2) err = keeper.SetWithdrawAddr(ctx, delAddr1, delAddr2)
require.Nil(t, err) require.Nil(t, err)
keeper.blacklistedAddrs[distrAcc.GetAddress().String()] = true
require.Error(t, keeper.SetWithdrawAddr(ctx, delAddr1, distrAcc.GetAddress()))
} }
func TestWithdrawValidatorCommission(t *testing.T) { func TestWithdrawValidatorCommission(t *testing.T) {

View File

@ -59,6 +59,8 @@ var (
emptyDelAddr sdk.AccAddress emptyDelAddr sdk.AccAddress
emptyValAddr sdk.ValAddress emptyValAddr sdk.ValAddress
emptyPubkey crypto.PubKey emptyPubkey crypto.PubKey
distrAcc = supply.NewEmptyModuleAccount(types.ModuleName)
) )
// create a codec used only for testing // create a codec used only for testing
@ -114,12 +116,22 @@ func CreateTestInputAdvanced(t *testing.T, isCheckTx bool, initPower int64,
err := ms.LoadLatestVersion() err := ms.LoadLatestVersion()
require.Nil(t, err) require.Nil(t, err)
feeCollectorAcc := supply.NewEmptyModuleAccount(auth.FeeCollectorName)
notBondedPool := supply.NewEmptyModuleAccount(staking.NotBondedPoolName, supply.Burner, supply.Staking)
bondPool := supply.NewEmptyModuleAccount(staking.BondedPoolName, supply.Burner, supply.Staking)
blacklistedAddrs := make(map[string]bool)
blacklistedAddrs[feeCollectorAcc.String()] = true
blacklistedAddrs[notBondedPool.String()] = true
blacklistedAddrs[bondPool.String()] = true
blacklistedAddrs[distrAcc.String()] = true
cdc := MakeTestCodec() cdc := MakeTestCodec()
pk := params.NewKeeper(cdc, keyParams, tkeyParams, params.DefaultCodespace) pk := params.NewKeeper(cdc, keyParams, tkeyParams, params.DefaultCodespace)
ctx := sdk.NewContext(ms, abci.Header{ChainID: "foochainid"}, isCheckTx, log.NewNopLogger()) ctx := sdk.NewContext(ms, abci.Header{ChainID: "foochainid"}, isCheckTx, log.NewNopLogger())
accountKeeper := auth.NewAccountKeeper(cdc, keyAcc, pk.Subspace(auth.DefaultParamspace), auth.ProtoBaseAccount) accountKeeper := auth.NewAccountKeeper(cdc, keyAcc, pk.Subspace(auth.DefaultParamspace), auth.ProtoBaseAccount)
bankKeeper := bank.NewBaseKeeper(accountKeeper, pk.Subspace(bank.DefaultParamspace), bank.DefaultCodespace) bankKeeper := bank.NewBaseKeeper(accountKeeper, pk.Subspace(bank.DefaultParamspace), bank.DefaultCodespace, blacklistedAddrs)
maccPerms := map[string][]string{ maccPerms := map[string][]string{
auth.FeeCollectorName: nil, auth.FeeCollectorName: nil,
types.ModuleName: nil, types.ModuleName: nil,
@ -131,7 +143,7 @@ func CreateTestInputAdvanced(t *testing.T, isCheckTx bool, initPower int64,
sk := staking.NewKeeper(cdc, keyStaking, tkeyStaking, supplyKeeper, pk.Subspace(staking.DefaultParamspace), staking.DefaultCodespace) sk := staking.NewKeeper(cdc, keyStaking, tkeyStaking, supplyKeeper, pk.Subspace(staking.DefaultParamspace), staking.DefaultCodespace)
sk.SetParams(ctx, staking.DefaultParams()) sk.SetParams(ctx, staking.DefaultParams())
keeper := NewKeeper(cdc, keyDistr, pk.Subspace(DefaultParamspace), sk, supplyKeeper, types.DefaultCodespace, auth.FeeCollectorName) keeper := NewKeeper(cdc, keyDistr, pk.Subspace(DefaultParamspace), sk, supplyKeeper, types.DefaultCodespace, auth.FeeCollectorName, blacklistedAddrs)
initCoins := sdk.NewCoins(sdk.NewCoin(sk.BondDenom(ctx), initTokens)) initCoins := sdk.NewCoins(sdk.NewCoin(sk.BondDenom(ctx), initTokens))
totalSupply := sdk.NewCoins(sdk.NewCoin(sk.BondDenom(ctx), initTokens.MulRaw(int64(len(TestAddrs))))) totalSupply := sdk.NewCoins(sdk.NewCoin(sk.BondDenom(ctx), initTokens.MulRaw(int64(len(TestAddrs)))))
@ -143,12 +155,7 @@ func CreateTestInputAdvanced(t *testing.T, isCheckTx bool, initPower int64,
require.Nil(t, err) require.Nil(t, err)
} }
// create module accounts // set module accounts
feeCollectorAcc := supply.NewEmptyModuleAccount(auth.FeeCollectorName)
notBondedPool := supply.NewEmptyModuleAccount(staking.NotBondedPoolName, supply.Burner, supply.Staking)
bondPool := supply.NewEmptyModuleAccount(staking.BondedPoolName, supply.Burner, supply.Staking)
distrAcc := supply.NewEmptyModuleAccount(types.ModuleName)
keeper.supplyKeeper.SetModuleAccount(ctx, feeCollectorAcc) keeper.supplyKeeper.SetModuleAccount(ctx, feeCollectorAcc)
keeper.supplyKeeper.SetModuleAccount(ctx, notBondedPool) keeper.supplyKeeper.SetModuleAccount(ctx, notBondedPool)
keeper.supplyKeeper.SetModuleAccount(ctx, bondPool) keeper.supplyKeeper.SetModuleAccount(ctx, bondPool)

View File

@ -21,6 +21,7 @@ import (
"github.com/cosmos/cosmos-sdk/x/mock" "github.com/cosmos/cosmos-sdk/x/mock"
"github.com/cosmos/cosmos-sdk/x/staking" "github.com/cosmos/cosmos-sdk/x/staking"
"github.com/cosmos/cosmos-sdk/x/supply" "github.com/cosmos/cosmos-sdk/x/supply"
supplyexported "github.com/cosmos/cosmos-sdk/x/supply/exported"
) )
var ( var (
@ -52,12 +53,21 @@ func getMockApp(t *testing.T, numGenAccs int, genState GenesisState, genAccs []a
keyGov := sdk.NewKVStoreKey(StoreKey) keyGov := sdk.NewKVStoreKey(StoreKey)
keySupply := sdk.NewKVStoreKey(supply.StoreKey) keySupply := sdk.NewKVStoreKey(supply.StoreKey)
govAcc := supply.NewEmptyModuleAccount(types.ModuleName, supply.Burner)
notBondedPool := supply.NewEmptyModuleAccount(staking.NotBondedPoolName, supply.Burner, supply.Staking)
bondPool := supply.NewEmptyModuleAccount(staking.BondedPoolName, supply.Burner, supply.Staking)
blacklistedAddrs := make(map[string]bool)
blacklistedAddrs[govAcc.String()] = true
blacklistedAddrs[notBondedPool.String()] = true
blacklistedAddrs[bondPool.String()] = true
pk := mApp.ParamsKeeper pk := mApp.ParamsKeeper
rtr := NewRouter(). rtr := NewRouter().
AddRoute(RouterKey, ProposalHandler) AddRoute(RouterKey, ProposalHandler)
bk := bank.NewBaseKeeper(mApp.AccountKeeper, mApp.ParamsKeeper.Subspace(bank.DefaultParamspace), bank.DefaultCodespace) bk := bank.NewBaseKeeper(mApp.AccountKeeper, mApp.ParamsKeeper.Subspace(bank.DefaultParamspace), bank.DefaultCodespace, blacklistedAddrs)
maccPerms := map[string][]string{ maccPerms := map[string][]string{
types.ModuleName: {supply.Burner}, types.ModuleName: {supply.Burner},
@ -67,13 +77,14 @@ func getMockApp(t *testing.T, numGenAccs int, genState GenesisState, genAccs []a
supplyKeeper := supply.NewKeeper(mApp.Cdc, keySupply, mApp.AccountKeeper, bk, supply.DefaultCodespace, maccPerms) supplyKeeper := supply.NewKeeper(mApp.Cdc, keySupply, mApp.AccountKeeper, bk, supply.DefaultCodespace, maccPerms)
sk := staking.NewKeeper(mApp.Cdc, keyStaking, tKeyStaking, supplyKeeper, pk.Subspace(staking.DefaultParamspace), staking.DefaultCodespace) sk := staking.NewKeeper(mApp.Cdc, keyStaking, tKeyStaking, supplyKeeper, pk.Subspace(staking.DefaultParamspace), staking.DefaultCodespace)
keeper := NewKeeper(mApp.Cdc, keyGov, pk, pk.Subspace("testgov"), supplyKeeper, sk, DefaultCodespace, rtr) keeper := NewKeeper(mApp.Cdc, keyGov, pk, pk.Subspace(DefaultParamspace), supplyKeeper, sk, DefaultCodespace, rtr)
mApp.Router().AddRoute(RouterKey, NewHandler(keeper)) mApp.Router().AddRoute(RouterKey, NewHandler(keeper))
mApp.QueryRouter().AddRoute(QuerierRoute, NewQuerier(keeper)) mApp.QueryRouter().AddRoute(QuerierRoute, NewQuerier(keeper))
mApp.SetEndBlocker(getEndBlocker(keeper)) mApp.SetEndBlocker(getEndBlocker(keeper))
mApp.SetInitChainer(getInitChainer(mApp, keeper, sk, supplyKeeper, genAccs, genState)) mApp.SetInitChainer(getInitChainer(mApp, keeper, sk, supplyKeeper, genAccs, genState,
[]supplyexported.ModuleAccountI{govAcc, notBondedPool, bondPool}))
require.NoError(t, mApp.CompleteSetup(keyStaking, tKeyStaking, keyGov, keySupply)) require.NoError(t, mApp.CompleteSetup(keyStaking, tKeyStaking, keyGov, keySupply))
@ -101,7 +112,8 @@ func getEndBlocker(keeper Keeper) sdk.EndBlocker {
} }
// gov and staking initchainer // gov and staking initchainer
func getInitChainer(mapp *mock.App, keeper Keeper, stakingKeeper staking.Keeper, supplyKeeper supply.Keeper, accs []auth.Account, genState GenesisState) sdk.InitChainer { func getInitChainer(mapp *mock.App, keeper Keeper, stakingKeeper staking.Keeper, supplyKeeper supply.Keeper, accs []auth.Account, genState GenesisState,
blacklistedAddrs []supplyexported.ModuleAccountI) sdk.InitChainer {
return func(ctx sdk.Context, req abci.RequestInitChain) abci.ResponseInitChain { return func(ctx sdk.Context, req abci.RequestInitChain) abci.ResponseInitChain {
mapp.InitChainer(ctx, req) mapp.InitChainer(ctx, req)
@ -111,13 +123,9 @@ func getInitChainer(mapp *mock.App, keeper Keeper, stakingKeeper staking.Keeper,
supplyKeeper.SetSupply(ctx, supply.NewSupply(totalSupply)) supplyKeeper.SetSupply(ctx, supply.NewSupply(totalSupply))
// set module accounts // set module accounts
govAcc := supply.NewEmptyModuleAccount(types.ModuleName, supply.Burner) for _, macc := range blacklistedAddrs {
notBondedPool := supply.NewEmptyModuleAccount(staking.NotBondedPoolName, supply.Burner, supply.Staking) supplyKeeper.SetModuleAccount(ctx, macc)
bondPool := supply.NewEmptyModuleAccount(staking.BondedPoolName, supply.Burner, supply.Staking) }
supplyKeeper.SetModuleAccount(ctx, govAcc)
supplyKeeper.SetModuleAccount(ctx, notBondedPool)
supplyKeeper.SetModuleAccount(ctx, bondPool)
validators := staking.InitGenesis(ctx, stakingKeeper, mapp.AccountKeeper, supplyKeeper, stakingGenesis) validators := staking.InitGenesis(ctx, stakingKeeper, mapp.AccountKeeper, supplyKeeper, stakingGenesis)
if genState.IsEmpty() { if genState.IsEmpty() {

View File

@ -54,9 +54,20 @@ func newTestInput(t *testing.T) testInput {
ctx := sdk.NewContext(ms, abci.Header{Time: time.Unix(0, 0)}, false, log.NewTMLogger(os.Stdout)) ctx := sdk.NewContext(ms, abci.Header{Time: time.Unix(0, 0)}, false, log.NewTMLogger(os.Stdout))
feeCollectorAcc := supply.NewEmptyModuleAccount(auth.FeeCollectorName)
notBondedPool := supply.NewEmptyModuleAccount(staking.NotBondedPoolName, supply.Burner, supply.Staking)
bondPool := supply.NewEmptyModuleAccount(staking.BondedPoolName, supply.Burner, supply.Staking)
minterAcc := supply.NewEmptyModuleAccount(types.ModuleName, supply.Minter)
blacklistedAddrs := make(map[string]bool)
blacklistedAddrs[feeCollectorAcc.String()] = true
blacklistedAddrs[notBondedPool.String()] = true
blacklistedAddrs[bondPool.String()] = true
blacklistedAddrs[minterAcc.String()] = true
paramsKeeper := params.NewKeeper(types.ModuleCdc, keyParams, tkeyParams, params.DefaultCodespace) paramsKeeper := params.NewKeeper(types.ModuleCdc, keyParams, tkeyParams, params.DefaultCodespace)
accountKeeper := auth.NewAccountKeeper(types.ModuleCdc, keyAcc, paramsKeeper.Subspace(auth.DefaultParamspace), auth.ProtoBaseAccount) accountKeeper := auth.NewAccountKeeper(types.ModuleCdc, keyAcc, paramsKeeper.Subspace(auth.DefaultParamspace), auth.ProtoBaseAccount)
bankKeeper := bank.NewBaseKeeper(accountKeeper, paramsKeeper.Subspace(bank.DefaultParamspace), bank.DefaultCodespace) bankKeeper := bank.NewBaseKeeper(accountKeeper, paramsKeeper.Subspace(bank.DefaultParamspace), bank.DefaultCodespace, blacklistedAddrs)
maccPerms := map[string][]string{ maccPerms := map[string][]string{
auth.FeeCollectorName: nil, auth.FeeCollectorName: nil,
types.ModuleName: {supply.Minter}, types.ModuleName: {supply.Minter},
@ -72,11 +83,6 @@ func newTestInput(t *testing.T) testInput {
mintKeeper := NewKeeper(types.ModuleCdc, keyMint, paramsKeeper.Subspace(types.DefaultParamspace), &stakingKeeper, supplyKeeper, auth.FeeCollectorName) mintKeeper := NewKeeper(types.ModuleCdc, keyMint, paramsKeeper.Subspace(types.DefaultParamspace), &stakingKeeper, supplyKeeper, auth.FeeCollectorName)
// set module accounts // set module accounts
feeCollectorAcc := supply.NewEmptyModuleAccount(auth.FeeCollectorName)
minterAcc := supply.NewEmptyModuleAccount(types.ModuleName, supply.Minter)
notBondedPool := supply.NewEmptyModuleAccount(staking.NotBondedPoolName, supply.Burner)
bondPool := supply.NewEmptyModuleAccount(staking.BondedPoolName, supply.Burner)
supplyKeeper.SetModuleAccount(ctx, feeCollectorAcc) supplyKeeper.SetModuleAccount(ctx, feeCollectorAcc)
supplyKeeper.SetModuleAccount(ctx, minterAcc) supplyKeeper.SetModuleAccount(ctx, minterAcc)
supplyKeeper.SetModuleAccount(ctx, notBondedPool) supplyKeeper.SetModuleAccount(ctx, notBondedPool)

View File

@ -16,6 +16,7 @@ import (
"github.com/cosmos/cosmos-sdk/x/staking" "github.com/cosmos/cosmos-sdk/x/staking"
"github.com/cosmos/cosmos-sdk/x/staking/types" "github.com/cosmos/cosmos-sdk/x/staking/types"
"github.com/cosmos/cosmos-sdk/x/supply" "github.com/cosmos/cosmos-sdk/x/supply"
supplyexported "github.com/cosmos/cosmos-sdk/x/supply/exported"
) )
var ( var (
@ -37,7 +38,16 @@ func getMockApp(t *testing.T) (*mock.App, staking.Keeper, Keeper) {
keySlashing := sdk.NewKVStoreKey(StoreKey) keySlashing := sdk.NewKVStoreKey(StoreKey)
keySupply := sdk.NewKVStoreKey(supply.StoreKey) keySupply := sdk.NewKVStoreKey(supply.StoreKey)
bankKeeper := bank.NewBaseKeeper(mapp.AccountKeeper, mapp.ParamsKeeper.Subspace(bank.DefaultParamspace), bank.DefaultCodespace) feeCollector := supply.NewEmptyModuleAccount(auth.FeeCollectorName)
notBondedPool := supply.NewEmptyModuleAccount(types.NotBondedPoolName, supply.Burner, supply.Staking)
bondPool := supply.NewEmptyModuleAccount(types.BondedPoolName, supply.Burner, supply.Staking)
blacklistedAddrs := make(map[string]bool)
blacklistedAddrs[feeCollector.String()] = true
blacklistedAddrs[notBondedPool.String()] = true
blacklistedAddrs[bondPool.String()] = true
bankKeeper := bank.NewBaseKeeper(mapp.AccountKeeper, mapp.ParamsKeeper.Subspace(bank.DefaultParamspace), bank.DefaultCodespace, blacklistedAddrs)
maccPerms := map[string][]string{ maccPerms := map[string][]string{
auth.FeeCollectorName: nil, auth.FeeCollectorName: nil,
staking.NotBondedPoolName: {supply.Burner, supply.Staking}, staking.NotBondedPoolName: {supply.Burner, supply.Staking},
@ -50,7 +60,8 @@ func getMockApp(t *testing.T) (*mock.App, staking.Keeper, Keeper) {
mapp.Router().AddRoute(RouterKey, NewHandler(keeper)) mapp.Router().AddRoute(RouterKey, NewHandler(keeper))
mapp.SetEndBlocker(getEndBlocker(stakingKeeper)) mapp.SetEndBlocker(getEndBlocker(stakingKeeper))
mapp.SetInitChainer(getInitChainer(mapp, stakingKeeper, mapp.AccountKeeper, supplyKeeper)) mapp.SetInitChainer(getInitChainer(mapp, stakingKeeper, mapp.AccountKeeper, supplyKeeper,
[]supplyexported.ModuleAccountI{feeCollector, notBondedPool, bondPool}))
require.NoError(t, mapp.CompleteSetup(keyStaking, tkeyStaking, keySupply, keySlashing)) require.NoError(t, mapp.CompleteSetup(keyStaking, tkeyStaking, keySupply, keySlashing))
@ -68,16 +79,13 @@ func getEndBlocker(keeper staking.Keeper) sdk.EndBlocker {
} }
// overwrite the mock init chainer // overwrite the mock init chainer
func getInitChainer(mapp *mock.App, keeper staking.Keeper, accountKeeper types.AccountKeeper, supplyKeeper types.SupplyKeeper) sdk.InitChainer { func getInitChainer(mapp *mock.App, keeper staking.Keeper, accountKeeper types.AccountKeeper, supplyKeeper types.SupplyKeeper,
blacklistedAddrs []supplyexported.ModuleAccountI) sdk.InitChainer {
return func(ctx sdk.Context, req abci.RequestInitChain) abci.ResponseInitChain { return func(ctx sdk.Context, req abci.RequestInitChain) abci.ResponseInitChain {
// set module accounts // set module accounts
feeCollector := supply.NewEmptyModuleAccount(auth.FeeCollectorName) for _, macc := range blacklistedAddrs {
notBondedPool := supply.NewEmptyModuleAccount(types.NotBondedPoolName, supply.Burner, supply.Staking) supplyKeeper.SetModuleAccount(ctx, macc)
bondPool := supply.NewEmptyModuleAccount(types.BondedPoolName, supply.Burner, supply.Staking) }
supplyKeeper.SetModuleAccount(ctx, feeCollector)
supplyKeeper.SetModuleAccount(ctx, bondPool)
supplyKeeper.SetModuleAccount(ctx, notBondedPool)
mapp.InitChainer(ctx, req) mapp.InitChainer(ctx, req)
stakingGenesis := staking.DefaultGenesisState() stakingGenesis := staking.DefaultGenesisState()

View File

@ -63,7 +63,9 @@ func CreateTestInput(t *testing.T, defaults types.Params) (sdk.Context, bank.Kee
keySupply := sdk.NewKVStoreKey(supply.StoreKey) keySupply := sdk.NewKVStoreKey(supply.StoreKey)
keyParams := sdk.NewKVStoreKey(params.StoreKey) keyParams := sdk.NewKVStoreKey(params.StoreKey)
tkeyParams := sdk.NewTransientStoreKey(params.TStoreKey) tkeyParams := sdk.NewTransientStoreKey(params.TStoreKey)
db := dbm.NewMemDB() db := dbm.NewMemDB()
ms := store.NewCommitMultiStore(db) ms := store.NewCommitMultiStore(db)
ms.MountStoreWithDB(keyAcc, sdk.StoreTypeIAVL, db) ms.MountStoreWithDB(keyAcc, sdk.StoreTypeIAVL, db)
ms.MountStoreWithDB(tkeyStaking, sdk.StoreTypeTransient, nil) ms.MountStoreWithDB(tkeyStaking, sdk.StoreTypeTransient, nil)
@ -72,14 +74,26 @@ func CreateTestInput(t *testing.T, defaults types.Params) (sdk.Context, bank.Kee
ms.MountStoreWithDB(keySlashing, sdk.StoreTypeIAVL, db) ms.MountStoreWithDB(keySlashing, sdk.StoreTypeIAVL, db)
ms.MountStoreWithDB(keyParams, sdk.StoreTypeIAVL, db) ms.MountStoreWithDB(keyParams, sdk.StoreTypeIAVL, db)
ms.MountStoreWithDB(tkeyParams, sdk.StoreTypeTransient, db) ms.MountStoreWithDB(tkeyParams, sdk.StoreTypeTransient, db)
err := ms.LoadLatestVersion() err := ms.LoadLatestVersion()
require.Nil(t, err) require.Nil(t, err)
ctx := sdk.NewContext(ms, abci.Header{Time: time.Unix(0, 0)}, false, log.NewNopLogger()) ctx := sdk.NewContext(ms, abci.Header{Time: time.Unix(0, 0)}, false, log.NewNopLogger())
cdc := createTestCodec() cdc := createTestCodec()
feeCollectorAcc := supply.NewEmptyModuleAccount(auth.FeeCollectorName)
notBondedPool := supply.NewEmptyModuleAccount(staking.NotBondedPoolName, supply.Burner, supply.Staking)
bondPool := supply.NewEmptyModuleAccount(staking.BondedPoolName, supply.Burner, supply.Staking)
blacklistedAddrs := make(map[string]bool)
blacklistedAddrs[feeCollectorAcc.String()] = true
blacklistedAddrs[notBondedPool.String()] = true
blacklistedAddrs[bondPool.String()] = true
paramsKeeper := params.NewKeeper(cdc, keyParams, tkeyParams, params.DefaultCodespace) paramsKeeper := params.NewKeeper(cdc, keyParams, tkeyParams, params.DefaultCodespace)
accountKeeper := auth.NewAccountKeeper(cdc, keyAcc, paramsKeeper.Subspace(auth.DefaultParamspace), auth.ProtoBaseAccount) accountKeeper := auth.NewAccountKeeper(cdc, keyAcc, paramsKeeper.Subspace(auth.DefaultParamspace), auth.ProtoBaseAccount)
bk := bank.NewBaseKeeper(accountKeeper, paramsKeeper.Subspace(bank.DefaultParamspace), bank.DefaultCodespace) bk := bank.NewBaseKeeper(accountKeeper, paramsKeeper.Subspace(bank.DefaultParamspace), bank.DefaultCodespace, blacklistedAddrs)
maccPerms := map[string][]string{ maccPerms := map[string][]string{
auth.FeeCollectorName: nil, auth.FeeCollectorName: nil,
staking.NotBondedPoolName: {supply.Burner, supply.Staking}, staking.NotBondedPoolName: {supply.Burner, supply.Staking},
@ -94,10 +108,6 @@ func CreateTestInput(t *testing.T, defaults types.Params) (sdk.Context, bank.Kee
genesis := staking.DefaultGenesisState() genesis := staking.DefaultGenesisState()
// set module accounts // set module accounts
feeCollectorAcc := supply.NewEmptyModuleAccount(auth.FeeCollectorName)
notBondedPool := supply.NewEmptyModuleAccount(staking.NotBondedPoolName, supply.Burner, supply.Staking)
bondPool := supply.NewEmptyModuleAccount(staking.BondedPoolName, supply.Burner, supply.Staking)
supplyKeeper.SetModuleAccount(ctx, feeCollectorAcc) supplyKeeper.SetModuleAccount(ctx, feeCollectorAcc)
supplyKeeper.SetModuleAccount(ctx, bondPool) supplyKeeper.SetModuleAccount(ctx, bondPool)
supplyKeeper.SetModuleAccount(ctx, notBondedPool) supplyKeeper.SetModuleAccount(ctx, notBondedPool)

View File

@ -12,6 +12,7 @@ import (
"github.com/cosmos/cosmos-sdk/x/mock" "github.com/cosmos/cosmos-sdk/x/mock"
"github.com/cosmos/cosmos-sdk/x/staking/types" "github.com/cosmos/cosmos-sdk/x/staking/types"
"github.com/cosmos/cosmos-sdk/x/supply" "github.com/cosmos/cosmos-sdk/x/supply"
supplyexported "github.com/cosmos/cosmos-sdk/x/supply/exported"
) )
// getMockApp returns an initialized mock application for this module. // getMockApp returns an initialized mock application for this module.
@ -25,7 +26,16 @@ func getMockApp(t *testing.T) (*mock.App, Keeper) {
tkeyStaking := sdk.NewTransientStoreKey(TStoreKey) tkeyStaking := sdk.NewTransientStoreKey(TStoreKey)
keySupply := sdk.NewKVStoreKey(supply.StoreKey) keySupply := sdk.NewKVStoreKey(supply.StoreKey)
bankKeeper := bank.NewBaseKeeper(mApp.AccountKeeper, mApp.ParamsKeeper.Subspace(bank.DefaultParamspace), bank.DefaultCodespace) feeCollector := supply.NewEmptyModuleAccount(auth.FeeCollectorName)
notBondedPool := supply.NewEmptyModuleAccount(types.NotBondedPoolName, supply.Burner, supply.Staking)
bondPool := supply.NewEmptyModuleAccount(types.BondedPoolName, supply.Burner, supply.Staking)
blacklistedAddrs := make(map[string]bool)
blacklistedAddrs[feeCollector.String()] = true
blacklistedAddrs[notBondedPool.String()] = true
blacklistedAddrs[bondPool.String()] = true
bankKeeper := bank.NewBaseKeeper(mApp.AccountKeeper, mApp.ParamsKeeper.Subspace(bank.DefaultParamspace), bank.DefaultCodespace, blacklistedAddrs)
maccPerms := map[string][]string{ maccPerms := map[string][]string{
auth.FeeCollectorName: nil, auth.FeeCollectorName: nil,
types.NotBondedPoolName: {supply.Burner, supply.Staking}, types.NotBondedPoolName: {supply.Burner, supply.Staking},
@ -36,7 +46,8 @@ func getMockApp(t *testing.T) (*mock.App, Keeper) {
mApp.Router().AddRoute(RouterKey, NewHandler(keeper)) mApp.Router().AddRoute(RouterKey, NewHandler(keeper))
mApp.SetEndBlocker(getEndBlocker(keeper)) mApp.SetEndBlocker(getEndBlocker(keeper))
mApp.SetInitChainer(getInitChainer(mApp, keeper, mApp.AccountKeeper, supplyKeeper)) mApp.SetInitChainer(getInitChainer(mApp, keeper, mApp.AccountKeeper, supplyKeeper,
[]supplyexported.ModuleAccountI{feeCollector, notBondedPool, bondPool}))
require.NoError(t, mApp.CompleteSetup(keyStaking, tkeyStaking, keySupply)) require.NoError(t, mApp.CompleteSetup(keyStaking, tkeyStaking, keySupply))
return mApp, keeper return mApp, keeper
@ -55,18 +66,15 @@ func getEndBlocker(keeper Keeper) sdk.EndBlocker {
// getInitChainer initializes the chainer of the mock app and sets the genesis // getInitChainer initializes the chainer of the mock app and sets the genesis
// state. It returns an empty ResponseInitChain. // state. It returns an empty ResponseInitChain.
func getInitChainer(mapp *mock.App, keeper Keeper, accountKeeper types.AccountKeeper, supplyKeeper types.SupplyKeeper) sdk.InitChainer { func getInitChainer(mapp *mock.App, keeper Keeper, accountKeeper types.AccountKeeper, supplyKeeper types.SupplyKeeper,
blacklistedAddrs []supplyexported.ModuleAccountI) sdk.InitChainer {
return func(ctx sdk.Context, req abci.RequestInitChain) abci.ResponseInitChain { return func(ctx sdk.Context, req abci.RequestInitChain) abci.ResponseInitChain {
mapp.InitChainer(ctx, req) mapp.InitChainer(ctx, req)
// set module accounts // set module accounts
feeCollector := supply.NewEmptyModuleAccount(auth.FeeCollectorName) for _, macc := range blacklistedAddrs {
notBondedPool := supply.NewEmptyModuleAccount(types.NotBondedPoolName, supply.Burner, supply.Staking) supplyKeeper.SetModuleAccount(ctx, macc)
bondPool := supply.NewEmptyModuleAccount(types.BondedPoolName, supply.Burner, supply.Staking) }
supplyKeeper.SetModuleAccount(ctx, feeCollector)
supplyKeeper.SetModuleAccount(ctx, bondPool)
supplyKeeper.SetModuleAccount(ctx, notBondedPool)
stakingGenesis := DefaultGenesisState() stakingGenesis := DefaultGenesisState()
validators := InitGenesis(ctx, keeper, accountKeeper, supplyKeeper, stakingGenesis) validators := InitGenesis(ctx, keeper, accountKeeper, supplyKeeper, stakingGenesis)

View File

@ -109,6 +109,15 @@ func CreateTestInput(t *testing.T, isCheckTx bool, initPower int64) (sdk.Context
) )
cdc := MakeTestCodec() cdc := MakeTestCodec()
feeCollectorAcc := supply.NewEmptyModuleAccount(auth.FeeCollectorName)
notBondedPool := supply.NewEmptyModuleAccount(types.NotBondedPoolName, supply.Burner, supply.Staking)
bondPool := supply.NewEmptyModuleAccount(types.BondedPoolName, supply.Burner, supply.Staking)
blacklistedAddrs := make(map[string]bool)
blacklistedAddrs[feeCollectorAcc.String()] = true
blacklistedAddrs[notBondedPool.String()] = true
blacklistedAddrs[bondPool.String()] = true
pk := params.NewKeeper(cdc, keyParams, tkeyParams, params.DefaultCodespace) pk := params.NewKeeper(cdc, keyParams, tkeyParams, params.DefaultCodespace)
accountKeeper := auth.NewAccountKeeper( accountKeeper := auth.NewAccountKeeper(
@ -122,6 +131,7 @@ func CreateTestInput(t *testing.T, isCheckTx bool, initPower int64) (sdk.Context
accountKeeper, accountKeeper,
pk.Subspace(bank.DefaultParamspace), pk.Subspace(bank.DefaultParamspace),
bank.DefaultCodespace, bank.DefaultCodespace,
blacklistedAddrs,
) )
maccPerms := map[string][]string{ maccPerms := map[string][]string{
@ -141,10 +151,6 @@ func CreateTestInput(t *testing.T, isCheckTx bool, initPower int64) (sdk.Context
keeper.SetParams(ctx, types.DefaultParams()) keeper.SetParams(ctx, types.DefaultParams())
// set module accounts // set module accounts
feeCollectorAcc := supply.NewEmptyModuleAccount(auth.FeeCollectorName)
notBondedPool := supply.NewEmptyModuleAccount(types.NotBondedPoolName, supply.Burner, supply.Staking)
bondPool := supply.NewEmptyModuleAccount(types.BondedPoolName, supply.Burner, supply.Staking)
err = notBondedPool.SetCoins(totalSupply) err = notBondedPool.SetCoins(totalSupply)
require.NoError(t, err) require.NoError(t, err)

View File

@ -70,9 +70,11 @@ func createTestInput(t *testing.T, isCheckTx bool, initPower int64, nAccs int64)
) )
cdc := makeTestCodec() cdc := makeTestCodec()
blacklistedAddrs := make(map[string]bool)
pk := params.NewKeeper(cdc, keyParams, tkeyParams, params.DefaultCodespace) pk := params.NewKeeper(cdc, keyParams, tkeyParams, params.DefaultCodespace)
ak := auth.NewAccountKeeper(cdc, keyAcc, pk.Subspace(auth.DefaultParamspace), auth.ProtoBaseAccount) ak := auth.NewAccountKeeper(cdc, keyAcc, pk.Subspace(auth.DefaultParamspace), auth.ProtoBaseAccount)
bk := bank.NewBaseKeeper(ak, pk.Subspace(bank.DefaultParamspace), bank.DefaultCodespace) bk := bank.NewBaseKeeper(ak, pk.Subspace(bank.DefaultParamspace), bank.DefaultCodespace, blacklistedAddrs)
valTokens := sdk.TokensFromConsensusPower(initPower) valTokens := sdk.TokensFromConsensusPower(initPower)