refactor: simplify hooks implementation (#13396)

This commit is contained in:
Julien Robert 2022-09-27 20:37:26 +02:00 committed by GitHub
parent dcb0c9c04c
commit 53519ea5b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 209 additions and 212 deletions

View File

@ -146,6 +146,7 @@ cosmovisor:
mocks: $(MOCKS_DIR) mocks: $(MOCKS_DIR)
@go install github.com/golang/mock/mockgen@v1.6.0
sh ./scripts/mockgen.sh sh ./scripts/mockgen.sh
.PHONY: mocks .PHONY: mocks

View File

@ -1,9 +1,8 @@
#!/usr/bin/env bash #!/usr/bin/env bash
mockgen_cmd="go run github.com/golang/mock/mockgen" mockgen_cmd="mockgen"
$mockgen_cmd -source=client/account_retriever.go -package mock -destination testutil/mock/account_retriever.go $mockgen_cmd -source=client/account_retriever.go -package mock -destination testutil/mock/account_retriever.go
$mockgen_cmd -package mock -destination testutil/mock/tendermint_tm_db_DB.go github.com/tendermint/tm-db DB $mockgen_cmd -package mock -destination testutil/mock/tendermint_tm_db_DB.go github.com/tendermint/tm-db DB
$mockgen_cmd -source db/types.go -package mock -destination testutil/mock/db/types.go
$mockgen_cmd -source=types/module/module.go -package mock -destination testutil/mock/types_module_module.go $mockgen_cmd -source=types/module/module.go -package mock -destination testutil/mock/types_module_module.go
$mockgen_cmd -source=types/invariant.go -package mock -destination testutil/mock/types_invariant.go $mockgen_cmd -source=types/invariant.go -package mock -destination testutil/mock/types_invariant.go
$mockgen_cmd -source=types/router.go -package mock -destination testutil/mock/types_router.go $mockgen_cmd -source=types/router.go -package mock -destination testutil/mock/types_router.go

View File

@ -88,6 +88,9 @@ func GetNonKey(allkeys [][]byte, loc tmproofs.Where) []byte {
// returns a list of all keys in sorted order // returns a list of all keys in sorted order
func BuildTree(size int) (tree *iavl.MutableTree, keys [][]byte, err error) { func BuildTree(size int) (tree *iavl.MutableTree, keys [][]byte, err error) {
tree, err = iavl.NewMutableTree(db.NewMemDB(), 0) tree, err = iavl.NewMutableTree(db.NewMemDB(), 0)
if err != nil {
return nil, nil, err
}
// insert lots of info and store the bytes // insert lots of info and store the bytes
keys = make([][]byte, size) keys = make([][]byte, size)

View File

@ -14,7 +14,9 @@ type Hooks struct {
var _ stakingtypes.StakingHooks = Hooks{} var _ stakingtypes.StakingHooks = Hooks{}
// Create new distribution hooks // Create new distribution hooks
func (k Keeper) Hooks() Hooks { return Hooks{k} } func (k Keeper) Hooks() Hooks {
return Hooks{k}
}
// initialize validator distribution record // initialize validator distribution record
func (h Hooks) AfterValidatorCreated(ctx sdk.Context, valAddr sdk.ValAddress) error { func (h Hooks) AfterValidatorCreated(ctx sdk.Context, valAddr sdk.ValAddress) error {
@ -109,7 +111,10 @@ func (h Hooks) BeforeValidatorSlashed(ctx sdk.Context, valAddr sdk.ValAddress, f
return nil return nil
} }
func (h Hooks) BeforeValidatorModified(_ sdk.Context, _ sdk.ValAddress) error { return nil } func (h Hooks) BeforeValidatorModified(_ sdk.Context, _ sdk.ValAddress) error {
return nil
}
func (h Hooks) AfterValidatorBonded(_ sdk.Context, _ sdk.ConsAddress, _ sdk.ValAddress) error { func (h Hooks) AfterValidatorBonded(_ sdk.Context, _ sdk.ConsAddress, _ sdk.ValAddress) error {
return nil return nil
} }

View File

@ -24,7 +24,7 @@ func EndBlocker(ctx sdk.Context, keeper *keeper.Keeper) {
keeper.RefundAndDeleteDeposits(ctx, proposal.Id) // refund deposit if proposal got removed without getting 100% of the proposal keeper.RefundAndDeleteDeposits(ctx, proposal.Id) // refund deposit if proposal got removed without getting 100% of the proposal
// called when proposal become inactive // called when proposal become inactive
keeper.AfterProposalFailedMinDeposit(ctx, proposal.Id) keeper.Hooks().AfterProposalFailedMinDeposit(ctx, proposal.Id)
ctx.EventManager().EmitEvent( ctx.EventManager().EmitEvent(
sdk.NewEvent( sdk.NewEvent(
@ -104,7 +104,7 @@ func EndBlocker(ctx sdk.Context, keeper *keeper.Keeper) {
keeper.RemoveFromActiveProposalQueue(ctx, proposal.Id, *proposal.VotingEndTime) keeper.RemoveFromActiveProposalQueue(ctx, proposal.Id, *proposal.VotingEndTime)
// when proposal become active // when proposal become active
keeper.AfterProposalVotingPeriodEnded(ctx, proposal.Id) keeper.Hooks().AfterProposalVotingPeriodEnded(ctx, proposal.Id)
logger.Info( logger.Info(
"proposal tallied", "proposal tallied",

View File

@ -147,7 +147,7 @@ func (keeper Keeper) AddDeposit(ctx sdk.Context, proposalID uint64, depositorAdd
} }
// called when deposit has been added to a proposal, however the proposal may not be active // called when deposit has been added to a proposal, however the proposal may not be active
keeper.AfterProposalDeposit(ctx, proposalID, depositorAddr) keeper.Hooks().AfterProposalDeposit(ctx, proposalID, depositorAddr)
ctx.EventManager().EmitEvent( ctx.EventManager().EmitEvent(
sdk.NewEvent( sdk.NewEvent(

View File

@ -1,44 +0,0 @@
package keeper
import (
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/gov/types"
)
// Implements GovHooks interface
var _ types.GovHooks = Keeper{}
// AfterProposalSubmission - call hook if registered
func (keeper Keeper) AfterProposalSubmission(ctx sdk.Context, proposalID uint64) {
if keeper.hooks != nil {
keeper.hooks.AfterProposalSubmission(ctx, proposalID)
}
}
// AfterProposalDeposit - call hook if registered
func (keeper Keeper) AfterProposalDeposit(ctx sdk.Context, proposalID uint64, depositorAddr sdk.AccAddress) {
if keeper.hooks != nil {
keeper.hooks.AfterProposalDeposit(ctx, proposalID, depositorAddr)
}
}
// AfterProposalVote - call hook if registered
func (keeper Keeper) AfterProposalVote(ctx sdk.Context, proposalID uint64, voterAddr sdk.AccAddress) {
if keeper.hooks != nil {
keeper.hooks.AfterProposalVote(ctx, proposalID, voterAddr)
}
}
// AfterProposalFailedMinDeposit - call hook if registered
func (keeper Keeper) AfterProposalFailedMinDeposit(ctx sdk.Context, proposalID uint64) {
if keeper.hooks != nil {
keeper.hooks.AfterProposalFailedMinDeposit(ctx, proposalID)
}
}
// AfterProposalVotingPeriodEnded - call hook if registered
func (keeper Keeper) AfterProposalVotingPeriodEnded(ctx sdk.Context, proposalID uint64) {
if keeper.hooks != nil {
keeper.hooks.AfterProposalVotingPeriodEnded(ctx, proposalID)
}
}

View File

@ -92,6 +92,16 @@ func NewKeeper(
} }
} }
// Hooks gets the hooks for governance *Keeper {
func (keeper *Keeper) Hooks() types.GovHooks {
if keeper.hooks == nil {
// return a no-op implementation if no hooks are set
return types.MultiGovHooks{}
}
return keeper.hooks
}
// SetHooks sets the hooks for governance // SetHooks sets the hooks for governance
func (keeper *Keeper) SetHooks(gh types.GovHooks) *Keeper { func (keeper *Keeper) SetHooks(gh types.GovHooks) *Keeper {
if keeper.hooks != nil { if keeper.hooks != nil {

View File

@ -82,7 +82,7 @@ func (keeper Keeper) SubmitProposal(ctx sdk.Context, messages []sdk.Msg, metadat
keeper.SetProposalID(ctx, proposalID+1) keeper.SetProposalID(ctx, proposalID+1)
// called right after a proposal is submitted // called right after a proposal is submitted
keeper.AfterProposalSubmission(ctx, proposalID) keeper.Hooks().AfterProposalSubmission(ctx, proposalID)
ctx.EventManager().EmitEvent( ctx.EventManager().EmitEvent(
sdk.NewEvent( sdk.NewEvent(

View File

@ -33,7 +33,7 @@ func (keeper Keeper) AddVote(ctx sdk.Context, proposalID uint64, voterAddr sdk.A
keeper.SetVote(ctx, vote) keeper.SetVote(ctx, vote)
// called after a vote on a proposal is cast // called after a vote on a proposal is cast
keeper.AfterProposalVote(ctx, proposalID, voterAddr) keeper.Hooks().AfterProposalVote(ctx, proposalID, voterAddr)
ctx.EventManager().EmitEvent( ctx.EventManager().EmitEvent(
sdk.NewEvent( sdk.NewEvent(

View File

@ -9,14 +9,26 @@ import (
"github.com/cosmos/cosmos-sdk/x/slashing/types" "github.com/cosmos/cosmos-sdk/x/slashing/types"
) )
func (k Keeper) AfterValidatorBonded(ctx sdk.Context, address sdk.ConsAddress, _ sdk.ValAddress) error { var _ types.StakingHooks = Hooks{}
// Update the signing info start height or create a new signing info
signingInfo, found := k.GetValidatorSigningInfo(ctx, address) // Hooks wrapper struct for slashing keeper
type Hooks struct {
k Keeper
}
// Return the slashing hooks
func (k Keeper) Hooks() Hooks {
return Hooks{k}
}
// AfterValidatorBonded updates the signing info start height or create a new signing info
func (h Hooks) AfterValidatorBonded(ctx sdk.Context, consAddr sdk.ConsAddress, valAddr sdk.ValAddress) error {
signingInfo, found := h.k.GetValidatorSigningInfo(ctx, consAddr)
if found { if found {
signingInfo.StartHeight = ctx.BlockHeight() signingInfo.StartHeight = ctx.BlockHeight()
} else { } else {
signingInfo = types.NewValidatorSigningInfo( signingInfo = types.NewValidatorSigningInfo(
address, consAddr,
ctx.BlockHeight(), ctx.BlockHeight(),
0, 0,
time.Unix(0, 0), time.Unix(0, 0),
@ -25,53 +37,26 @@ func (k Keeper) AfterValidatorBonded(ctx sdk.Context, address sdk.ConsAddress, _
) )
} }
k.SetValidatorSigningInfo(ctx, address, signingInfo) h.k.SetValidatorSigningInfo(ctx, consAddr, signingInfo)
return nil return nil
} }
// AfterValidatorRemoved deletes the address-pubkey relation when a validator is removed,
func (h Hooks) AfterValidatorRemoved(ctx sdk.Context, consAddr sdk.ConsAddress, _ sdk.ValAddress) error {
h.k.deleteAddrPubkeyRelation(ctx, crypto.Address(consAddr))
return nil
}
// AfterValidatorCreated adds the address-pubkey relation when a validator is created. // AfterValidatorCreated adds the address-pubkey relation when a validator is created.
func (k Keeper) AfterValidatorCreated(ctx sdk.Context, valAddr sdk.ValAddress) error { func (h Hooks) AfterValidatorCreated(ctx sdk.Context, valAddr sdk.ValAddress) error {
validator := k.sk.Validator(ctx, valAddr) validator := h.k.sk.Validator(ctx, valAddr)
consPk, err := validator.ConsPubKey() consPk, err := validator.ConsPubKey()
if err != nil { if err != nil {
return err return err
} }
return k.AddPubkey(ctx, consPk) return h.k.AddPubkey(ctx, consPk)
}
// AfterValidatorRemoved deletes the address-pubkey relation when a validator is removed,
func (k Keeper) AfterValidatorRemoved(ctx sdk.Context, address sdk.ConsAddress) error {
k.deleteAddrPubkeyRelation(ctx, crypto.Address(address))
return nil
}
// Hooks wrapper struct for slashing keeper
type Hooks struct {
k Keeper
}
var _ types.StakingHooks = Hooks{}
// Return the wrapper struct
func (k Keeper) Hooks() Hooks {
return Hooks{k}
}
// Implements sdk.ValidatorHooks
func (h Hooks) AfterValidatorBonded(ctx sdk.Context, consAddr sdk.ConsAddress, valAddr sdk.ValAddress) error {
return h.k.AfterValidatorBonded(ctx, consAddr, valAddr)
}
// Implements sdk.ValidatorHooks
func (h Hooks) AfterValidatorRemoved(ctx sdk.Context, consAddr sdk.ConsAddress, _ sdk.ValAddress) error {
return h.k.AfterValidatorRemoved(ctx, consAddr)
}
// Implements sdk.ValidatorHooks
func (h Hooks) AfterValidatorCreated(ctx sdk.Context, valAddr sdk.ValAddress) error {
return h.k.AfterValidatorCreated(ctx, valAddr)
} }
func (h Hooks) AfterValidatorBeginUnbonding(_ sdk.Context, _ sdk.ConsAddress, _ sdk.ValAddress) error { func (h Hooks) AfterValidatorBeginUnbonding(_ sdk.Context, _ sdk.ConsAddress, _ sdk.ValAddress) error {

View File

@ -11,7 +11,7 @@ func (s *KeeperTestSuite) TestAfterValidatorBonded() {
require := s.Require() require := s.Require()
valAddr := sdk.ValAddress(consAddr.Bytes()) valAddr := sdk.ValAddress(consAddr.Bytes())
keeper.AfterValidatorBonded(ctx, consAddr, valAddr) keeper.Hooks().AfterValidatorBonded(ctx, consAddr, valAddr)
_, ok := keeper.GetValidatorSigningInfo(ctx, consAddr) _, ok := keeper.GetValidatorSigningInfo(ctx, consAddr)
require.True(ok) require.True(ok)
@ -28,14 +28,14 @@ func (s *KeeperTestSuite) TestAfterValidatorCreatedOrRemoved() {
require.NoError(err) require.NoError(err)
s.stakingKeeper.EXPECT().Validator(ctx, valAddr).Return(validator) s.stakingKeeper.EXPECT().Validator(ctx, valAddr).Return(validator)
err = keeper.AfterValidatorCreated(ctx, valAddr) err = keeper.Hooks().AfterValidatorCreated(ctx, valAddr)
require.NoError(err) require.NoError(err)
ePubKey, err := keeper.GetPubkey(ctx, addr.Bytes()) ePubKey, err := keeper.GetPubkey(ctx, addr.Bytes())
require.NoError(err) require.NoError(err)
require.Equal(ePubKey, pubKey) require.Equal(ePubKey, pubKey)
err = keeper.AfterValidatorRemoved(ctx, sdk.ConsAddress(addr)) err = keeper.Hooks().AfterValidatorRemoved(ctx, sdk.ConsAddress(addr), nil)
require.NoError(err) require.NoError(err)
_, err = keeper.GetPubkey(ctx, addr.Bytes()) _, err = keeper.GetPubkey(ctx, addr.Bytes())

View File

@ -396,6 +396,34 @@ func (m *MockStakingHooks) EXPECT() *MockStakingHooksMockRecorder {
return m.recorder return m.recorder
} }
// AfterDelegationModified mocks base method.
func (m *MockStakingHooks) AfterDelegationModified(ctx types.Context, delAddr types.AccAddress, valAddr types.ValAddress) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AfterDelegationModified", ctx, delAddr, valAddr)
ret0, _ := ret[0].(error)
return ret0
}
// AfterDelegationModified indicates an expected call of AfterDelegationModified.
func (mr *MockStakingHooksMockRecorder) AfterDelegationModified(ctx, delAddr, valAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AfterDelegationModified", reflect.TypeOf((*MockStakingHooks)(nil).AfterDelegationModified), ctx, delAddr, valAddr)
}
// AfterValidatorBeginUnbonding mocks base method.
func (m *MockStakingHooks) AfterValidatorBeginUnbonding(ctx types.Context, consAddr types.ConsAddress, valAddr types.ValAddress) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AfterValidatorBeginUnbonding", ctx, consAddr, valAddr)
ret0, _ := ret[0].(error)
return ret0
}
// AfterValidatorBeginUnbonding indicates an expected call of AfterValidatorBeginUnbonding.
func (mr *MockStakingHooksMockRecorder) AfterValidatorBeginUnbonding(ctx, consAddr, valAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AfterValidatorBeginUnbonding", reflect.TypeOf((*MockStakingHooks)(nil).AfterValidatorBeginUnbonding), ctx, consAddr, valAddr)
}
// AfterValidatorBonded mocks base method. // AfterValidatorBonded mocks base method.
func (m *MockStakingHooks) AfterValidatorBonded(ctx types.Context, consAddr types.ConsAddress, valAddr types.ValAddress) error { func (m *MockStakingHooks) AfterValidatorBonded(ctx types.Context, consAddr types.ConsAddress, valAddr types.ValAddress) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -437,3 +465,73 @@ func (mr *MockStakingHooksMockRecorder) AfterValidatorRemoved(ctx, consAddr, val
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AfterValidatorRemoved", reflect.TypeOf((*MockStakingHooks)(nil).AfterValidatorRemoved), ctx, consAddr, valAddr) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AfterValidatorRemoved", reflect.TypeOf((*MockStakingHooks)(nil).AfterValidatorRemoved), ctx, consAddr, valAddr)
} }
// BeforeDelegationCreated mocks base method.
func (m *MockStakingHooks) BeforeDelegationCreated(ctx types.Context, delAddr types.AccAddress, valAddr types.ValAddress) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BeforeDelegationCreated", ctx, delAddr, valAddr)
ret0, _ := ret[0].(error)
return ret0
}
// BeforeDelegationCreated indicates an expected call of BeforeDelegationCreated.
func (mr *MockStakingHooksMockRecorder) BeforeDelegationCreated(ctx, delAddr, valAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeforeDelegationCreated", reflect.TypeOf((*MockStakingHooks)(nil).BeforeDelegationCreated), ctx, delAddr, valAddr)
}
// BeforeDelegationRemoved mocks base method.
func (m *MockStakingHooks) BeforeDelegationRemoved(ctx types.Context, delAddr types.AccAddress, valAddr types.ValAddress) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BeforeDelegationRemoved", ctx, delAddr, valAddr)
ret0, _ := ret[0].(error)
return ret0
}
// BeforeDelegationRemoved indicates an expected call of BeforeDelegationRemoved.
func (mr *MockStakingHooksMockRecorder) BeforeDelegationRemoved(ctx, delAddr, valAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeforeDelegationRemoved", reflect.TypeOf((*MockStakingHooks)(nil).BeforeDelegationRemoved), ctx, delAddr, valAddr)
}
// BeforeDelegationSharesModified mocks base method.
func (m *MockStakingHooks) BeforeDelegationSharesModified(ctx types.Context, delAddr types.AccAddress, valAddr types.ValAddress) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BeforeDelegationSharesModified", ctx, delAddr, valAddr)
ret0, _ := ret[0].(error)
return ret0
}
// BeforeDelegationSharesModified indicates an expected call of BeforeDelegationSharesModified.
func (mr *MockStakingHooksMockRecorder) BeforeDelegationSharesModified(ctx, delAddr, valAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeforeDelegationSharesModified", reflect.TypeOf((*MockStakingHooks)(nil).BeforeDelegationSharesModified), ctx, delAddr, valAddr)
}
// BeforeValidatorModified mocks base method.
func (m *MockStakingHooks) BeforeValidatorModified(ctx types.Context, valAddr types.ValAddress) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BeforeValidatorModified", ctx, valAddr)
ret0, _ := ret[0].(error)
return ret0
}
// BeforeValidatorModified indicates an expected call of BeforeValidatorModified.
func (mr *MockStakingHooksMockRecorder) BeforeValidatorModified(ctx, valAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeforeValidatorModified", reflect.TypeOf((*MockStakingHooks)(nil).BeforeValidatorModified), ctx, valAddr)
}
// BeforeValidatorSlashed mocks base method.
func (m *MockStakingHooks) BeforeValidatorSlashed(ctx types.Context, valAddr types.ValAddress, fraction types.Dec) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BeforeValidatorSlashed", ctx, valAddr, fraction)
ret0, _ := ret[0].(error)
return ret0
}
// BeforeValidatorSlashed indicates an expected call of BeforeValidatorSlashed.
func (mr *MockStakingHooksMockRecorder) BeforeValidatorSlashed(ctx, valAddr, fraction interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeforeValidatorSlashed", reflect.TypeOf((*MockStakingHooks)(nil).BeforeValidatorSlashed), ctx, valAddr, fraction)
}

View File

@ -59,7 +59,15 @@ type StakingKeeper interface {
// StakingHooks event hooks for staking validator object (noalias) // StakingHooks event hooks for staking validator object (noalias)
type StakingHooks interface { type StakingHooks interface {
AfterValidatorCreated(ctx sdk.Context, valAddr sdk.ValAddress) error // Must be called when a validator is created AfterValidatorCreated(ctx sdk.Context, valAddr sdk.ValAddress) error // Must be called when a validator is created
BeforeValidatorModified(ctx sdk.Context, valAddr sdk.ValAddress) error // Must be called when a validator's state changes
AfterValidatorRemoved(ctx sdk.Context, consAddr sdk.ConsAddress, valAddr sdk.ValAddress) error // Must be called when a validator is deleted AfterValidatorRemoved(ctx sdk.Context, consAddr sdk.ConsAddress, valAddr sdk.ValAddress) error // Must be called when a validator is deleted
AfterValidatorBonded(ctx sdk.Context, consAddr sdk.ConsAddress, valAddr sdk.ValAddress) error // Must be called when a validator is bonded AfterValidatorBonded(ctx sdk.Context, consAddr sdk.ConsAddress, valAddr sdk.ValAddress) error // Must be called when a validator is bonded
AfterValidatorBeginUnbonding(ctx sdk.Context, consAddr sdk.ConsAddress, valAddr sdk.ValAddress) error // Must be called when a validator begins unbonding
BeforeDelegationCreated(ctx sdk.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) error // Must be called when a delegation is created
BeforeDelegationSharesModified(ctx sdk.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) error // Must be called when a delegation's shares are modified
BeforeDelegationRemoved(ctx sdk.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) error // Must be called when a delegation is removed
AfterDelegationModified(ctx sdk.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) error
BeforeValidatorSlashed(ctx sdk.Context, valAddr sdk.ValAddress, fraction sdk.Dec) error
} }

View File

@ -103,7 +103,7 @@ func (k Keeper) RemoveDelegation(ctx sdk.Context, delegation types.Delegation) e
delegatorAddress := sdk.MustAccAddressFromBech32(delegation.DelegatorAddress) delegatorAddress := sdk.MustAccAddressFromBech32(delegation.DelegatorAddress)
// TODO: Consider calling hooks outside of the store wrapper functions, it's unobvious. // TODO: Consider calling hooks outside of the store wrapper functions, it's unobvious.
if err := k.BeforeDelegationRemoved(ctx, delegatorAddress, delegation.GetValidatorAddr()); err != nil { if err := k.Hooks().BeforeDelegationRemoved(ctx, delegatorAddress, delegation.GetValidatorAddr()); err != nil {
return err return err
} }
@ -630,9 +630,9 @@ func (k Keeper) Delegate(
// call the appropriate hook if present // call the appropriate hook if present
if found { if found {
err = k.BeforeDelegationSharesModified(ctx, delAddr, validator.GetOperator()) err = k.Hooks().BeforeDelegationSharesModified(ctx, delAddr, validator.GetOperator())
} else { } else {
err = k.BeforeDelegationCreated(ctx, delAddr, validator.GetOperator()) err = k.Hooks().BeforeDelegationCreated(ctx, delAddr, validator.GetOperator())
} }
if err != nil { if err != nil {
@ -689,7 +689,7 @@ func (k Keeper) Delegate(
k.SetDelegation(ctx, delegation) k.SetDelegation(ctx, delegation)
// Call the after-modification hook // Call the after-modification hook
if err := k.AfterDelegationModified(ctx, delegatorAddress, delegation.GetValidatorAddr()); err != nil { if err := k.Hooks().AfterDelegationModified(ctx, delegatorAddress, delegation.GetValidatorAddr()); err != nil {
return newShares, err return newShares, err
} }
@ -707,7 +707,7 @@ func (k Keeper) Unbond(
} }
// call the before-delegation-modified hook // call the before-delegation-modified hook
if err := k.BeforeDelegationSharesModified(ctx, delAddr, valAddr); err != nil { if err := k.Hooks().BeforeDelegationSharesModified(ctx, delAddr, valAddr); err != nil {
return amount, err return amount, err
} }
@ -745,7 +745,7 @@ func (k Keeper) Unbond(
} else { } else {
k.SetDelegation(ctx, delegation) k.SetDelegation(ctx, delegation)
// call the after delegation modification hook // call the after delegation modification hook
err = k.AfterDelegationModified(ctx, delegatorAddress, delegation.GetValidatorAddr()) err = k.Hooks().AfterDelegationModified(ctx, delegatorAddress, delegation.GetValidatorAddr())
} }
if err != nil { if err != nil {

View File

@ -610,7 +610,7 @@ func (s *KeeperTestSuite) TestRedelegationMaxEntries() {
require.Equal(valTokens, issuedShares.RoundInt()) require.Equal(valTokens, issuedShares.RoundInt())
s.bankKeeper.EXPECT().SendCoinsFromModuleToModule(gomock.Any(), stakingtypes.NotBondedPoolName, stakingtypes.BondedPoolName, gomock.Any()) s.bankKeeper.EXPECT().SendCoinsFromModuleToModule(gomock.Any(), stakingtypes.NotBondedPoolName, stakingtypes.BondedPoolName, gomock.Any())
validator = stakingkeeper.TestingUpdateValidator(keeper, ctx, validator, true) _ = stakingkeeper.TestingUpdateValidator(keeper, ctx, validator, true)
val0AccAddr := sdk.AccAddress(addrVals[0].Bytes()) val0AccAddr := sdk.AccAddress(addrVals[0].Bytes())
selfDelegation := stakingtypes.NewDelegation(val0AccAddr, addrVals[0], issuedShares) selfDelegation := stakingtypes.NewDelegation(val0AccAddr, addrVals[0], issuedShares)
keeper.SetDelegation(ctx, selfDelegation) keeper.SetDelegation(ctx, selfDelegation)
@ -732,7 +732,7 @@ func (s *KeeperTestSuite) TestRedelegateFromUnbondingValidator() {
validator2, issuedShares = validator2.AddTokensFromDel(valTokens) validator2, issuedShares = validator2.AddTokensFromDel(valTokens)
require.Equal(valTokens, issuedShares.RoundInt()) require.Equal(valTokens, issuedShares.RoundInt())
s.bankKeeper.EXPECT().SendCoinsFromModuleToModule(gomock.Any(), stakingtypes.NotBondedPoolName, stakingtypes.BondedPoolName, gomock.Any()) s.bankKeeper.EXPECT().SendCoinsFromModuleToModule(gomock.Any(), stakingtypes.NotBondedPoolName, stakingtypes.BondedPoolName, gomock.Any())
validator2 = stakingkeeper.TestingUpdateValidator(keeper, ctx, validator2, true) _ = stakingkeeper.TestingUpdateValidator(keeper, ctx, validator2, true)
header := ctx.BlockHeader() header := ctx.BlockHeader()
blockHeight := int64(10) blockHeight := int64(10)

View File

@ -41,7 +41,7 @@ func (k Keeper) InitGenesis(ctx sdk.Context, data *types.GenesisState) (res []ab
// Call the creation hook if not exported // Call the creation hook if not exported
if !data.Exported { if !data.Exported {
if err := k.AfterValidatorCreated(ctx, validator.GetOperator()); err != nil { if err := k.Hooks().AfterValidatorCreated(ctx, validator.GetOperator()); err != nil {
panic(err) panic(err)
} }
} }
@ -68,7 +68,7 @@ func (k Keeper) InitGenesis(ctx sdk.Context, data *types.GenesisState) (res []ab
// Call the before-creation hook if not exported // Call the before-creation hook if not exported
if !data.Exported { if !data.Exported {
if err := k.BeforeDelegationCreated(ctx, delegatorAddress, delegation.GetValidatorAddr()); err != nil { if err := k.Hooks().BeforeDelegationCreated(ctx, delegatorAddress, delegation.GetValidatorAddr()); err != nil {
panic(err) panic(err)
} }
} }
@ -77,7 +77,7 @@ func (k Keeper) InitGenesis(ctx sdk.Context, data *types.GenesisState) (res []ab
// Call the after-modification hook if not exported // Call the after-modification hook if not exported
if !data.Exported { if !data.Exported {
if err := k.AfterDelegationModified(ctx, delegatorAddress, delegation.GetValidatorAddr()); err != nil { if err := k.Hooks().AfterDelegationModified(ctx, delegatorAddress, delegation.GetValidatorAddr()); err != nil {
panic(err) panic(err)
} }
} }

View File

@ -1,89 +0,0 @@
package keeper
import (
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/staking/types"
)
// Implements StakingHooks interface
var _ types.StakingHooks = Keeper{}
// AfterValidatorCreated - call hook if registered
func (k Keeper) AfterValidatorCreated(ctx sdk.Context, valAddr sdk.ValAddress) error {
if k.hooks != nil {
return k.hooks.AfterValidatorCreated(ctx, valAddr)
}
return nil
}
// BeforeValidatorModified - call hook if registered
func (k Keeper) BeforeValidatorModified(ctx sdk.Context, valAddr sdk.ValAddress) error {
if k.hooks != nil {
return k.hooks.BeforeValidatorModified(ctx, valAddr)
}
return nil
}
// AfterValidatorRemoved - call hook if registered
func (k Keeper) AfterValidatorRemoved(ctx sdk.Context, consAddr sdk.ConsAddress, valAddr sdk.ValAddress) error {
if k.hooks != nil {
return k.hooks.AfterValidatorRemoved(ctx, consAddr, valAddr)
}
return nil
}
// AfterValidatorBonded - call hook if registered
func (k Keeper) AfterValidatorBonded(ctx sdk.Context, consAddr sdk.ConsAddress, valAddr sdk.ValAddress) error {
if k.hooks != nil {
return k.hooks.AfterValidatorBonded(ctx, consAddr, valAddr)
}
return nil
}
// AfterValidatorBeginUnbonding - call hook if registered
func (k Keeper) AfterValidatorBeginUnbonding(ctx sdk.Context, consAddr sdk.ConsAddress, valAddr sdk.ValAddress) error {
if k.hooks != nil {
return k.hooks.AfterValidatorBeginUnbonding(ctx, consAddr, valAddr)
}
return nil
}
// BeforeDelegationCreated - call hook if registered
func (k Keeper) BeforeDelegationCreated(ctx sdk.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) error {
if k.hooks != nil {
return k.hooks.BeforeDelegationCreated(ctx, delAddr, valAddr)
}
return nil
}
// BeforeDelegationSharesModified - call hook if registered
func (k Keeper) BeforeDelegationSharesModified(ctx sdk.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) error {
if k.hooks != nil {
return k.hooks.BeforeDelegationSharesModified(ctx, delAddr, valAddr)
}
return nil
}
// BeforeDelegationRemoved - call hook if registered
func (k Keeper) BeforeDelegationRemoved(ctx sdk.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) error {
if k.hooks != nil {
k.hooks.BeforeDelegationRemoved(ctx, delAddr, valAddr)
}
return nil
}
// AfterDelegationModified - call hook if registered
func (k Keeper) AfterDelegationModified(ctx sdk.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) error {
if k.hooks != nil {
return k.hooks.AfterDelegationModified(ctx, delAddr, valAddr)
}
return nil
}
// BeforeValidatorSlashed - call hook if registered
func (k Keeper) BeforeValidatorSlashed(ctx sdk.Context, valAddr sdk.ValAddress, fraction sdk.Dec) error {
if k.hooks != nil {
return k.hooks.BeforeValidatorSlashed(ctx, valAddr, fraction)
}
return nil
}

View File

@ -66,6 +66,16 @@ func (k Keeper) Logger(ctx sdk.Context) log.Logger {
return ctx.Logger().With("module", "x/"+types.ModuleName) return ctx.Logger().With("module", "x/"+types.ModuleName)
} }
// Hooks gets the hooks for staking *Keeper {
func (keeper *Keeper) Hooks() types.StakingHooks {
if keeper.hooks == nil {
// return a no-op implementation if no hooks are set
return types.MultiStakingHooks{}
}
return keeper.hooks
}
// SetHooks Set the validator hooks // SetHooks Set the validator hooks
func (k *Keeper) SetHooks(sh types.StakingHooks) { func (k *Keeper) SetHooks(sh types.StakingHooks) {
if k.hooks != nil { if k.hooks != nil {

View File

@ -114,7 +114,7 @@ func (k msgServer) CreateValidator(goCtx context.Context, msg *types.MsgCreateVa
k.SetNewValidatorByPowerIndex(ctx, validator) k.SetNewValidatorByPowerIndex(ctx, validator)
// call the after-creation hook // call the after-creation hook
if err := k.AfterValidatorCreated(ctx, validator.GetOperator()); err != nil { if err := k.Hooks().AfterValidatorCreated(ctx, validator.GetOperator()); err != nil {
return nil, err return nil, err
} }
@ -170,7 +170,7 @@ func (k msgServer) EditValidator(goCtx context.Context, msg *types.MsgEditValida
} }
// call the before-modification hook since we're about to update the commission // call the before-modification hook since we're about to update the commission
if err := k.BeforeValidatorModified(ctx, valAddr); err != nil { if err := k.Hooks().BeforeValidatorModified(ctx, valAddr); err != nil {
return nil, err return nil, err
} }

View File

@ -65,7 +65,9 @@ func (k Keeper) Slash(ctx sdk.Context, consAddr sdk.ConsAddress, infractionHeigh
operatorAddress := validator.GetOperator() operatorAddress := validator.GetOperator()
// call the before-modification hook // call the before-modification hook
k.BeforeValidatorModified(ctx, operatorAddress) if err := k.Hooks().BeforeValidatorModified(ctx, operatorAddress); err != nil {
k.Logger(ctx).Error("failed to call before validator modified hook", "error", err)
}
// Track remaining slash amount for the validator // Track remaining slash amount for the validator
// This will decrease when we slash unbondings and // This will decrease when we slash unbondings and
@ -123,7 +125,9 @@ func (k Keeper) Slash(ctx sdk.Context, consAddr sdk.ConsAddress, infractionHeigh
effectiveFraction = math.LegacyOneDec() effectiveFraction = math.LegacyOneDec()
} }
// call the before-slashed hook // call the before-slashed hook
k.BeforeValidatorSlashed(ctx, operatorAddress, effectiveFraction) if err := k.Hooks().BeforeValidatorSlashed(ctx, operatorAddress, effectiveFraction); err != nil {
k.Logger(ctx).Error("failed to call before validator slashed hook", "error", err)
}
} }
// Deduct from validator's bonded tokens and update the validator. // Deduct from validator's bonded tokens and update the validator.

View File

@ -298,7 +298,10 @@ func (k Keeper) bondValidator(ctx sdk.Context, validator types.Validator) (types
if err != nil { if err != nil {
return validator, err return validator, err
} }
k.AfterValidatorBonded(ctx, consAddr, validator.GetOperator())
if err := k.Hooks().AfterValidatorBonded(ctx, consAddr, validator.GetOperator()); err != nil {
return validator, err
}
return validator, err return validator, err
} }
@ -333,7 +336,10 @@ func (k Keeper) beginUnbondingValidator(ctx sdk.Context, validator types.Validat
if err != nil { if err != nil {
return validator, err return validator, err
} }
k.AfterValidatorBeginUnbonding(ctx, consAddr, validator.GetOperator())
if err := k.Hooks().AfterValidatorBeginUnbonding(ctx, consAddr, validator.GetOperator()); err != nil {
return validator, err
}
return validator, nil return validator, nil
} }

View File

@ -183,7 +183,9 @@ func (k Keeper) RemoveValidator(ctx sdk.Context, address sdk.ValAddress) {
store.Delete(types.GetValidatorsByPowerIndexKey(validator, k.PowerReduction(ctx))) store.Delete(types.GetValidatorsByPowerIndexKey(validator, k.PowerReduction(ctx)))
// call hooks // call hooks
k.AfterValidatorRemoved(ctx, valConsAddr, validator.GetOperator()) if err := k.Hooks().AfterValidatorRemoved(ctx, valConsAddr, validator.GetOperator()); err != nil {
k.Logger(ctx).Error("error in after validator removed hook", "error", err)
}
} }
// get groups of validators // get groups of validators

View File

@ -10,7 +10,6 @@ import (
stakingkeeper "github.com/cosmos/cosmos-sdk/x/staking/keeper" stakingkeeper "github.com/cosmos/cosmos-sdk/x/staking/keeper"
"github.com/cosmos/cosmos-sdk/x/staking/teststaking" "github.com/cosmos/cosmos-sdk/x/staking/teststaking"
"github.com/cosmos/cosmos-sdk/x/staking/types"
stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"
abci "github.com/tendermint/tendermint/abci/types" abci "github.com/tendermint/tendermint/abci/types"
) )
@ -88,11 +87,11 @@ func (s *KeeperTestSuite) TestValidatorBasics() {
require := s.Require() require := s.Require()
// construct the validators // construct the validators
var validators [3]types.Validator var validators [3]stakingtypes.Validator
powers := []int64{9, 8, 7} powers := []int64{9, 8, 7}
for i, power := range powers { for i, power := range powers {
validators[i] = teststaking.NewValidator(s.T(), sdk.ValAddress(PKs[i].Address().Bytes()), PKs[i]) validators[i] = teststaking.NewValidator(s.T(), sdk.ValAddress(PKs[i].Address().Bytes()), PKs[i])
validators[i].Status = types.Unbonded validators[i].Status = stakingtypes.Unbonded
validators[i].Tokens = math.ZeroInt() validators[i].Tokens = math.ZeroInt()
tokens := keeper.TokensFromConsensusPower(ctx, power) tokens := keeper.TokensFromConsensusPower(ctx, power)
@ -131,11 +130,11 @@ func (s *KeeperTestSuite) TestValidatorBasics() {
resVals = keeper.GetLastValidators(ctx) resVals = keeper.GetLastValidators(ctx)
require.Equal(1, len(resVals)) require.Equal(1, len(resVals))
require.True(validators[0].MinEqual(&resVals[0])) require.True(validators[0].MinEqual(&resVals[0]))
require.Equal(types.Bonded, validators[0].Status) require.Equal(stakingtypes.Bonded, validators[0].Status)
require.True(keeper.TokensFromConsensusPower(ctx, 9).Equal(validators[0].BondedTokens())) require.True(keeper.TokensFromConsensusPower(ctx, 9).Equal(validators[0].BondedTokens()))
// modify a records, save, and retrieve // modify a records, save, and retrieve
validators[0].Status = types.Bonded validators[0].Status = stakingtypes.Bonded
validators[0].Tokens = keeper.TokensFromConsensusPower(ctx, 10) validators[0].Tokens = keeper.TokensFromConsensusPower(ctx, 10)
validators[0].DelegatorShares = sdk.NewDecFromInt(validators[0].Tokens) validators[0].DelegatorShares = sdk.NewDecFromInt(validators[0].Tokens)
validators[0] = stakingkeeper.TestingUpdateValidator(keeper, ctx, validators[0], true) validators[0] = stakingkeeper.TestingUpdateValidator(keeper, ctx, validators[0], true)
@ -169,7 +168,7 @@ func (s *KeeperTestSuite) TestValidatorBasics() {
func() { keeper.RemoveValidator(ctx, validators[1].GetOperator()) }) func() { keeper.RemoveValidator(ctx, validators[1].GetOperator()) })
// shouldn't be able to remove if there are still tokens left // shouldn't be able to remove if there are still tokens left
validators[1].Status = types.Unbonded validators[1].Status = stakingtypes.Unbonded
keeper.SetValidator(ctx, validators[1]) keeper.SetValidator(ctx, validators[1])
require.PanicsWithValue("attempting to remove a validator which still contains tokens", require.PanicsWithValue("attempting to remove a validator which still contains tokens",
func() { keeper.RemoveValidator(ctx, validators[1].GetOperator()) }) func() { keeper.RemoveValidator(ctx, validators[1].GetOperator()) })
@ -229,7 +228,7 @@ func (s *KeeperTestSuite) TestApplyAndReturnValidatorSetUpdatesPowerDecrease() {
require := s.Require() require := s.Require()
powers := []int64{100, 100} powers := []int64{100, 100}
var validators [2]types.Validator var validators [2]stakingtypes.Validator
for i, power := range powers { for i, power := range powers {
validators[i] = teststaking.NewValidator(s.T(), sdk.ValAddress(PKs[i].Address().Bytes()), PKs[i]) validators[i] = teststaking.NewValidator(s.T(), sdk.ValAddress(PKs[i].Address().Bytes()), PKs[i])