collect all invariants for a block before stopping (#4707)

This commit is contained in:
colin axner 2019-07-11 03:56:43 -07:00 committed by Alessio Treglia
parent 3f519832a7
commit 280734d0e3
13 changed files with 172 additions and 163 deletions

View File

@ -42,15 +42,16 @@ func init() {
flag.BoolVar(&commit, "SimulationCommit", false, "have the simulation commit")
flag.IntVar(&period, "SimulationPeriod", 1, "run slow invariants only once every period assertions")
flag.BoolVar(&onOperation, "SimulateEveryOperation", false, "run slow invariants every operation")
flag.BoolVar(&allInvariants, "PrintAllInvariants", false, "print all invariants if a broken invariant is found")
}
// helper function for populating input for SimulateFromSeed
func getSimulateFromSeedInput(tb testing.TB, w io.Writer, app *SimApp) (
testing.TB, io.Writer, *baseapp.BaseApp, simulation.AppStateFn, int64,
simulation.WeightedOperations, sdk.Invariants, int, int, bool, bool, bool) {
simulation.WeightedOperations, sdk.Invariants, int, int, bool, bool, bool, bool) {
return tb, w, app.BaseApp, appStateFn, seed,
testAndRunTxs(app), invariants(app), numBlocks, blockSize, commit, lean, onOperation
testAndRunTxs(app), invariants(app), numBlocks, blockSize, commit, lean, onOperation, allInvariants
}
func appStateFn(
@ -602,6 +603,7 @@ func TestAppStateDeterminism(t *testing.T) {
true,
false,
false,
false,
)
appHash := app.LastCommitID().Hash
appHashList[j] = appHash
@ -627,7 +629,7 @@ func BenchmarkInvariants(b *testing.B) {
// 2. Run parameterized simulation (w/o invariants)
_, err := simulation.SimulateFromSeed(
b, ioutil.Discard, app.BaseApp, appStateFn, seed, testAndRunTxs(app),
[]sdk.Invariant{}, numBlocks, blockSize, commit, lean, onOperation,
[]sdk.Invariant{}, numBlocks, blockSize, commit, lean, onOperation, false,
)
if err != nil {
fmt.Println(err)
@ -642,8 +644,8 @@ func BenchmarkInvariants(b *testing.B) {
// their respective metadata which makes it useful for testing/benchmarking.
for _, cr := range app.crisisKeeper.Routes() {
b.Run(fmt.Sprintf("%s/%s", cr.ModuleName, cr.Route), func(b *testing.B) {
if err := cr.Invar(ctx); err != nil {
fmt.Printf("broken invariant at block %d of %d\n%s", ctx.BlockHeight()-1, numBlocks, err)
if res, stop := cr.Invar(ctx); stop {
fmt.Printf("broken invariant at block %d of %d\n%s", ctx.BlockHeight()-1, numBlocks, res)
b.FailNow()
}
})

View File

@ -48,6 +48,7 @@ var (
commit bool
period int
onOperation bool // TODO Remove in favor of binary search for invariant violation
allInvariants bool
)
// NewSimAppUNSAFE is used for debugging purposes only.

View File

@ -1,10 +1,12 @@
package types
import "fmt"
// An Invariant is a function which tests a particular invariant.
// If the invariant has been broken, it should return an error
// containing a descriptive message about what happened.
// The simulator will then halt and print the logs.
type Invariant func(ctx Context) error
type Invariant func(ctx Context) (string, bool)
// Invariants defines a group of invariants
type Invariants []Invariant
@ -13,3 +15,10 @@ type Invariants []Invariant
type InvariantRegistry interface {
RegisterRoute(moduleName, route string, invar Invariant)
}
// FormatInvariant returns a standardized invariant message along with
// broken boolean.
func FormatInvariant(module, name, msg string, broken bool) (string, bool) {
return fmt.Sprintf("%s: %s invariant\n%s\nInvariant Broken: %v\n",
module, name, msg, broken), broken
}

View File

@ -1,11 +1,9 @@
package keeper
import (
"errors"
"fmt"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/auth/exported"
"github.com/cosmos/cosmos-sdk/x/bank/internal/types"
)
@ -17,36 +15,23 @@ func RegisterInvariants(ir sdk.InvariantRegistry, ak types.AccountKeeper) {
// NonnegativeBalanceInvariant checks that all accounts in the application have non-negative balances
func NonnegativeBalanceInvariant(ak types.AccountKeeper) sdk.Invariant {
return func(ctx sdk.Context) error {
return func(ctx sdk.Context) (string, bool) {
var msg string
var amt int
accts := ak.GetAllAccounts(ctx)
for _, acc := range accts {
coins := acc.GetCoins()
if coins.IsAnyNegative() {
return fmt.Errorf("%s has a negative denomination of %s",
amt++
msg += fmt.Sprintf("\t%s has a negative denomination of %s\n",
acc.GetAddress().String(),
coins.String())
}
}
return nil
}
}
broken := amt != 0
// TotalCoinsInvariant checks that the sum of the coins across all accounts
// is what is expected
func TotalCoinsInvariant(ak types.AccountKeeper, totalSupplyFn func() sdk.Coins) sdk.Invariant {
return func(ctx sdk.Context) error {
totalCoins := sdk.NewCoins()
chkAccount := func(acc exported.Account) bool {
coins := acc.GetCoins()
totalCoins = totalCoins.Add(coins)
return false
}
ak.IterateAccounts(ctx, chkAccount)
if !totalSupplyFn().IsEqual(totalCoins) {
return errors.New("total calculated coins doesn't equal expected coins")
}
return nil
return sdk.FormatInvariant(types.ModuleName, "nonnegative-outstanding",
fmt.Sprintf("amount of negative accounts found %d\n%s", amt, msg), broken)
}
}

View File

@ -40,10 +40,11 @@ func handleMsgVerifyInvariant(ctx sdk.Context, msg types.MsgVerifyInvariant, k k
found := false
msgFullRoute := msg.FullInvariantRoute()
var invarianceErr error
var res string
var stop bool
for _, invarRoute := range k.Routes() {
if invarRoute.FullRoute() == msgFullRoute {
invarianceErr = invarRoute.Invar(cacheCtx)
res, stop = invarRoute.Invar(cacheCtx)
found = true
break
}
@ -53,7 +54,7 @@ func handleMsgVerifyInvariant(ctx sdk.Context, msg types.MsgVerifyInvariant, k k
return types.ErrUnknownInvariant(types.DefaultCodespace).Result()
}
if invarianceErr != nil {
if stop {
// NOTE currently, because the chain halts here, this transaction will never be included
// in the blockchain thus the constant fee will have never been deducted. Thus no
// refund is required.
@ -70,7 +71,7 @@ func handleMsgVerifyInvariant(ctx sdk.Context, msg types.MsgVerifyInvariant, k k
//}
// TODO replace with circuit breaker
panic(invarianceErr)
panic(res)
}
ctx.EventManager().EmitEvents(sdk.Events{

View File

@ -1,7 +1,6 @@
package crisis_test
import (
"errors"
"fmt"
"strings"
"testing"
@ -17,8 +16,8 @@ import (
var (
testModuleName = "dummy"
dummyRouteWhichPasses = crisis.NewInvarRoute(testModuleName, "which-passes", func(_ sdk.Context) error { return nil })
dummyRouteWhichFails = crisis.NewInvarRoute(testModuleName, "which-fails", func(_ sdk.Context) error { return errors.New("whoops") })
dummyRouteWhichPasses = crisis.NewInvarRoute(testModuleName, "which-passes", func(_ sdk.Context) (string, bool) { return "", false })
dummyRouteWhichFails = crisis.NewInvarRoute(testModuleName, "which-fails", func(_ sdk.Context) (string, bool) { return "whoops", true })
addrs = distr.TestAddrs
)

View File

@ -67,13 +67,13 @@ func (k Keeper) AssertInvariants(ctx sdk.Context) {
start := time.Now()
invarRoutes := k.Routes()
for _, ir := range invarRoutes {
if err := ir.Invar(ctx); err != nil {
if res, stop := ir.Invar(ctx); stop {
// TODO: Include app name as part of context to allow for this to be
// variable.
panic(fmt.Errorf("invariant broken: %s\n"+
"\tCRITICAL please submit the following transaction:\n"+
"\t\t tx crisis invariant-broken %v %v", err, ir.ModuleName, ir.Route))
"\t\t tx crisis invariant-broken %s %s", res, ir.ModuleName, ir.Route))
}
}

View File

@ -22,18 +22,18 @@ func RegisterInvariants(ir sdk.InvariantRegistry, k Keeper) {
// AllInvariants runs all invariants of the distribution module
func AllInvariants(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error {
err := CanWithdrawInvariant(k)(ctx)
if err != nil {
return err
return func(ctx sdk.Context) (string, bool) {
res, stop := CanWithdrawInvariant(k)(ctx)
if stop {
return res, stop
}
err = NonNegativeOutstandingInvariant(k)(ctx)
if err != nil {
return err
res, stop = NonNegativeOutstandingInvariant(k)(ctx)
if stop {
return res, stop
}
err = ReferenceCountInvariant(k)(ctx)
if err != nil {
return err
res, stop = ReferenceCountInvariant(k)(ctx)
if stop {
return res, stop
}
return ModuleAccountInvariant(k)(ctx)
}
@ -41,30 +41,29 @@ func AllInvariants(k Keeper) sdk.Invariant {
// NonNegativeOutstandingInvariant checks that outstanding unwithdrawn fees are never negative
func NonNegativeOutstandingInvariant(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error {
return func(ctx sdk.Context) (string, bool) {
var msg string
var amt int
var outstanding sdk.DecCoins
k.IterateValidatorOutstandingRewards(ctx, func(_ sdk.ValAddress, rewards types.ValidatorOutstandingRewards) (stop bool) {
k.IterateValidatorOutstandingRewards(ctx, func(addr sdk.ValAddress, rewards types.ValidatorOutstandingRewards) (stop bool) {
outstanding = rewards
if outstanding.IsAnyNegative() {
return true
amt++
msg += fmt.Sprintf("\t%v has negative outstanding coins: %v\n", addr, outstanding)
}
return false
})
broken := amt != 0
if outstanding.IsAnyNegative() {
return fmt.Errorf("negative outstanding coins: %v", outstanding)
}
return nil
return sdk.FormatInvariant(types.ModuleName, "nonnegative outstanding",
fmt.Sprintf("found %d validators with negative outstanding rewards\n%s", amt, msg), broken)
}
}
// CanWithdrawInvariant checks that current rewards can be completely withdrawn
func CanWithdrawInvariant(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error {
return func(ctx sdk.Context) (string, bool) {
// cache, we don't want to write changes
ctx, _ = ctx.CacheContext()
@ -98,17 +97,15 @@ func CanWithdrawInvariant(k Keeper) sdk.Invariant {
return false
})
if len(remaining) > 0 && remaining[0].Amount.LT(sdk.ZeroDec()) {
return fmt.Errorf("negative remaining coins: %v", remaining)
}
return nil
broken := len(remaining) > 0 && remaining[0].Amount.LT(sdk.ZeroDec())
return sdk.FormatInvariant(types.ModuleName, "can withdraw",
fmt.Sprintf("remaining coins: %v", remaining), broken)
}
}
// ReferenceCountInvariant checks that the number of historical rewards records is correct
func ReferenceCountInvariant(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error {
return func(ctx sdk.Context) (string, bool) {
valCount := uint64(0)
k.stakingKeeper.IterateValidators(ctx, func(_ int64, val exported.ValidatorI) (stop bool) {
@ -127,21 +124,18 @@ func ReferenceCountInvariant(k Keeper) sdk.Invariant {
// delegation (previous period), one record per slash (previous period)
expected := valCount + uint64(len(dels)) + slashCount
count := k.GetValidatorHistoricalReferenceCount(ctx)
broken := count != expected
if count != expected {
return fmt.Errorf("unexpected number of historical rewards records: "+
return sdk.FormatInvariant(types.ModuleName, "reference count", fmt.Sprintf("unexpected number of historical rewards records: "+
"expected %v (%v vals + %v dels + %v slashes), got %v",
expected, valCount, len(dels), slashCount, count)
}
return nil
expected, valCount, len(dels), slashCount, count), broken)
}
}
// ModuleAccountInvariant checks that the coins held by the distr ModuleAccount
// is consistent with the sum of validator outstanding rewards
func ModuleAccountInvariant(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error {
return func(ctx sdk.Context) (string, bool) {
var expectedCoins sdk.DecCoins
k.IterateValidatorOutstandingRewards(ctx, func(_ sdk.ValAddress, rewards types.ValidatorOutstandingRewards) (stop bool) {
@ -154,12 +148,8 @@ func ModuleAccountInvariant(k Keeper) sdk.Invariant {
macc := k.GetDistributionAccount(ctx)
if !macc.GetCoins().IsEqual(expectedInt) {
return fmt.Errorf("distribution ModuleAccount coins invariance:\n"+
"\texpected ModuleAccount coins: %s\n"+
"\tdistribution ModuleAccount coins : %s", expectedInt, macc.GetCoins())
}
return nil
broken := !macc.GetCoins().IsEqual(expectedInt)
return sdk.FormatInvariant(types.ModuleName, "ModuleAccount coins", fmt.Sprintf("expected ModuleAccount coins: %s\n"+
"\tdistribution ModuleAccount coins : %s", expectedInt, macc.GetCoins()), broken)
}
}

View File

@ -14,7 +14,7 @@ func RegisterInvariants(ir sdk.InvariantRegistry, keeper Keeper) {
// AllInvariants runs all invariants of the governance module
func AllInvariants(keeper Keeper) sdk.Invariant {
return func(ctx sdk.Context) error {
return func(ctx sdk.Context) (string, bool) {
return ModuleAccountInvariant(keeper)(ctx)
}
}
@ -22,7 +22,7 @@ func AllInvariants(keeper Keeper) sdk.Invariant {
// ModuleAccountInvariant checks that the module account coins reflects the sum of
// deposit amounts held on store
func ModuleAccountInvariant(keeper Keeper) sdk.Invariant {
return func(ctx sdk.Context) error {
return func(ctx sdk.Context) (string, bool) {
var expectedDeposits sdk.Coins
keeper.IterateAllDeposits(ctx, func(deposit types.Deposit) bool {
@ -31,12 +31,10 @@ func ModuleAccountInvariant(keeper Keeper) sdk.Invariant {
})
macc := keeper.GetGovernanceAccount(ctx)
if !macc.GetCoins().IsEqual(expectedDeposits) {
return fmt.Errorf("deposits invariance:\n"+
"\tgov ModuleAccount coins: %s\n"+
"\tsum of deposit amounts: %s", macc.GetCoins(), expectedDeposits)
}
broken := !macc.GetCoins().IsEqual(expectedDeposits)
return nil
return sdk.FormatInvariant(types.ModuleName, "deposits",
fmt.Sprintf("\tgov ModuleAccount coins: %s\n\tsum of deposit amounts: %s",
macc.GetCoins(), expectedDeposits), broken)
}
}

View File

@ -46,7 +46,7 @@ func SimulateFromSeed(
tb testing.TB, w io.Writer, app *baseapp.BaseApp,
appStateFn AppStateFn, seed int64, ops WeightedOperations,
invariants sdk.Invariants,
numBlocks, blockSize int, commit, lean, onOperation bool,
numBlocks, blockSize int, commit, lean, onOperation, allInvariants bool,
) (stopEarly bool, simError error) {
// in case we have to end early, don't os.Exit so that we can run cleanup code.
@ -108,7 +108,7 @@ func SimulateFromSeed(
blockSimulator := createBlockSimulator(
testingMode, tb, t, w, params, eventStats.tally, invariants,
ops, operationQueue, timeOperationQueue,
numBlocks, blockSize, logWriter, lean, onOperation)
numBlocks, blockSize, logWriter, lean, onOperation, allInvariants)
if !testingMode {
b.ResetTimer()
@ -137,7 +137,7 @@ func SimulateFromSeed(
app.BeginBlock(request)
if testingMode {
assertAllInvariants(t, app, invariants, "BeginBlock", logWriter)
assertAllInvariants(t, app, invariants, "BeginBlock", logWriter, allInvariants)
}
ctx := app.NewContext(false, header)
@ -152,14 +152,14 @@ func SimulateFromSeed(
tb, r, app, ctx, accs, logWriter, eventStats.tally, lean)
if testingMode && onOperation {
assertAllInvariants(t, app, invariants, "QueuedOperations", logWriter)
assertAllInvariants(t, app, invariants, "QueuedOperations", logWriter, allInvariants)
}
// run standard operations
operations := blockSimulator(r, app, ctx, accs, header)
opCount += operations + numQueuedOpsRan + numQueuedTimeOpsRan
if testingMode {
assertAllInvariants(t, app, invariants, "StandardOperations", logWriter)
assertAllInvariants(t, app, invariants, "StandardOperations", logWriter, allInvariants)
}
res := app.EndBlock(abci.RequestEndBlock{})
@ -172,7 +172,7 @@ func SimulateFromSeed(
logWriter.AddEntry(EndBlockEntry(int64(height)))
if testingMode {
assertAllInvariants(t, app, invariants, "EndBlock", logWriter)
assertAllInvariants(t, app, invariants, "EndBlock", logWriter, allInvariants)
}
if commit {
app.Commit()
@ -221,7 +221,7 @@ type blockSimFn func(r *rand.Rand, app *baseapp.BaseApp, ctx sdk.Context,
func createBlockSimulator(testingMode bool, tb testing.TB, t *testing.T, w io.Writer, params Params,
event func(string), invariants sdk.Invariants, ops WeightedOperations,
operationQueue OperationQueue, timeOperationQueue []FutureOperation,
totalNumBlocks, avgBlockSize int, logWriter LogWriter, lean, onOperation bool) blockSimFn {
totalNumBlocks, avgBlockSize int, logWriter LogWriter, lean, onOperation, allInvariants bool) blockSimFn {
lastBlocksizeState := 0 // state for [4 * uniform distribution]
blocksize := 0
@ -269,7 +269,7 @@ func createBlockSimulator(testingMode bool, tb testing.TB, t *testing.T, w io.Wr
fmt.Fprintf(w, "\rSimulating... block %d/%d, operation %d/%d. ",
header.Height, totalNumBlocks, opCount, blocksize)
eventStr := fmt.Sprintf("operation: %v", opMsg.String())
assertAllInvariants(t, app, invariants, eventStr, logWriter)
assertAllInvariants(t, app, invariants, eventStr, logWriter, allInvariants)
} else if opCount%50 == 0 {
fmt.Fprintf(w, "\rSimulating... block %d/%d, operation %d/%d. ",
header.Height, totalNumBlocks, opCount, blocksize)

View File

@ -14,18 +14,31 @@ import (
// assertAll asserts the all invariants against application state
func assertAllInvariants(t *testing.T, app *baseapp.BaseApp, invs sdk.Invariants,
event string, logWriter LogWriter) {
event string, logWriter LogWriter, allInvariants bool) {
ctx := app.NewContext(false, abci.Header{Height: app.LastBlockHeight() + 1})
var broken bool
var invariantResults []string
for i := 0; i < len(invs); i++ {
if err := invs[i](ctx); err != nil {
fmt.Printf("Invariants broken after %s\n%s\n", event, err.Error())
res, stop := invs[i](ctx)
if stop {
broken = true
invariantResults = append(invariantResults, res)
} else if allInvariants {
invariantResults = append(invariantResults, res)
}
}
if broken {
fmt.Printf("Invariants broken after %s\n\n", event)
for _, res := range invariantResults {
fmt.Printf("%s\n", res)
}
logWriter.PrintLogs()
t.Fatal()
}
}
}
func getTestingMode(tb testing.TB) (testingMode bool, t *testing.T, b *testing.B) {
testingMode = false
@ -66,11 +79,11 @@ func getBlockSize(r *rand.Rand, params Params,
func PeriodicInvariants(invariants []sdk.Invariant, period, offset int) []sdk.Invariant {
var outInvariants []sdk.Invariant
for _, invariant := range invariants {
outInvariant := func(ctx sdk.Context) error {
outInvariant := func(ctx sdk.Context) (string, bool) {
if int(ctx.BlockHeight())%period == offset {
return invariant(ctx)
}
return nil
return "", false
}
outInvariants = append(outInvariants, outInvariant)
}

View File

@ -25,20 +25,20 @@ func RegisterInvariants(ir sdk.InvariantRegistry, k Keeper) {
// AllInvariants runs all invariants of the staking module.
func AllInvariants(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error {
err := ModuleAccountInvariants(k)(ctx)
if err != nil {
return err
return func(ctx sdk.Context) (string, bool) {
res, stop := ModuleAccountInvariants(k)(ctx)
if stop {
return res, stop
}
err = NonNegativePowerInvariant(k)(ctx)
if err != nil {
return err
res, stop = NonNegativePowerInvariant(k)(ctx)
if stop {
return res, stop
}
err = PositiveDelegationInvariant(k)(ctx)
if err != nil {
return err
res, stop = PositiveDelegationInvariant(k)(ctx)
if stop {
return res, stop
}
return DelegatorSharesInvariant(k)(ctx)
@ -48,7 +48,7 @@ func AllInvariants(k Keeper) sdk.Invariant {
// ModuleAccountInvariants checks that the bonded and notBonded ModuleAccounts pools
// reflects the tokens actively bonded and not bonded
func ModuleAccountInvariants(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error {
return func(ctx sdk.Context) (string, bool) {
bonded := sdk.ZeroInt()
notBonded := sdk.ZeroInt()
bondedPool := k.GetBondedPool(ctx)
@ -76,12 +76,11 @@ func ModuleAccountInvariants(k Keeper) sdk.Invariant {
poolBonded := bondedPool.GetCoins().AmountOf(bondDenom)
poolNotBonded := notBondedPool.GetCoins().AmountOf(bondDenom)
broken := !poolBonded.Equal(bonded) || !poolNotBonded.Equal(notBonded)
// Bonded tokens should equal sum of tokens with bonded validators
// Not-bonded tokens should equal unbonding delegations plus tokens on unbonded validators
if !poolBonded.Equal(bonded) || !poolNotBonded.Equal(notBonded) {
return fmt.Errorf(
"bonded token invariance:\n"+
return sdk.FormatInvariant(types.ModuleName, "bonded and not bonded module account coins", fmt.Sprintf(
"\tPool's bonded tokens: %v\n"+
"\tsum of bonded tokens: %v\n"+
"not bonded token invariance:\n"+
@ -90,17 +89,16 @@ func ModuleAccountInvariants(k Keeper) sdk.Invariant {
"module accounts total (bonded + not bonded):\n"+
"\tModule Accounts' tokens: %v\n"+
"\tsum tokens: %v\n",
poolBonded, bonded, poolNotBonded, notBonded, poolBonded.Add(poolNotBonded), bonded.Add(notBonded),
)
}
return nil
poolBonded, bonded, poolNotBonded, notBonded, poolBonded.Add(poolNotBonded), bonded.Add(notBonded)), broken)
}
}
// NonNegativePowerInvariant checks that all stored validators have >= 0 power.
func NonNegativePowerInvariant(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error {
return func(ctx sdk.Context) (string, bool) {
var msg string
var broken bool
iterator := k.ValidatorsPowerStoreIterator(ctx)
for ; iterator.Valid(); iterator.Next() {
@ -112,42 +110,54 @@ func NonNegativePowerInvariant(k Keeper) sdk.Invariant {
powerKey := types.GetValidatorsByPowerIndexKey(validator)
if !bytes.Equal(iterator.Key(), powerKey) {
return fmt.Errorf("power store invariance:\n\tvalidator.Power: %v"+
broken = true
msg += fmt.Sprintf("power store invariance:\n\tvalidator.Power: %v"+
"\n\tkey should be: %v\n\tkey in store: %v",
validator.GetConsensusPower(), powerKey, iterator.Key())
}
if validator.Tokens.IsNegative() {
return fmt.Errorf("negative tokens for validator: %v", validator)
broken = true
msg += fmt.Sprintf("negative tokens for validator: %v", validator)
}
}
iterator.Close()
return nil
return sdk.FormatInvariant(types.ModuleName, "nonnegative power", msg, broken)
}
}
// PositiveDelegationInvariant checks that all stored delegations have > 0 shares.
func PositiveDelegationInvariant(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error {
return func(ctx sdk.Context) (string, bool) {
var msg string
var amt int
delegations := k.GetAllDelegations(ctx)
for _, delegation := range delegations {
if delegation.Shares.IsNegative() {
return fmt.Errorf("delegation with negative shares: %+v", delegation)
amt++
msg += fmt.Sprintf("\tdelegation with negative shares: %+v\n", delegation)
}
if delegation.Shares.IsZero() {
return fmt.Errorf("delegation with zero shares: %+v", delegation)
amt++
msg += fmt.Sprintf("\tdelegation with zero shares: %+v\n", delegation)
}
}
broken := amt != 0
return nil
return sdk.FormatInvariant(types.ModuleName, "positive delegations", fmt.Sprintf(
"%d invalid delegations found\n%s", amt, msg), broken)
}
}
// DelegatorSharesInvariant checks whether all the delegator shares which persist
// in the delegator object add up to the correct total delegator shares
// amount stored in each validator
// amount stored in each validator.
func DelegatorSharesInvariant(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error {
return func(ctx sdk.Context) (string, bool) {
var msg string
var broken bool
validators := k.GetAllValidators(ctx)
for _, validator := range validators {
@ -160,11 +170,12 @@ func DelegatorSharesInvariant(k Keeper) sdk.Invariant {
}
if !valTotalDelShares.Equal(totalDelShares) {
return fmt.Errorf("broken delegator shares invariance:\n"+
broken = true
msg += fmt.Sprintf("broken delegator shares invariance:\n"+
"\tvalidator.DelegatorShares: %v\n"+
"\tsum of Delegator.Shares: %v", valTotalDelShares, totalDelShares)
}
}
return nil
return sdk.FormatInvariant(types.ModuleName, "delegator shares", msg, broken)
}
}

View File

@ -15,14 +15,14 @@ func RegisterInvariants(ir sdk.InvariantRegistry, k Keeper) {
// AllInvariants runs all invariants of the supply module.
func AllInvariants(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error {
return func(ctx sdk.Context) (string, bool) {
return TotalSupply(k)(ctx)
}
}
// TotalSupply checks that the total supply reflects all the coins held in accounts
func TotalSupply(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error {
return func(ctx sdk.Context) (string, bool) {
var expectedTotal sdk.Coins
supply := k.GetSupply(ctx)
@ -31,12 +31,12 @@ func TotalSupply(k Keeper) sdk.Invariant {
return false
})
if !expectedTotal.IsEqual(supply.Total) {
return fmt.Errorf("total supply invariance:\n"+
"\tsum of accounts coins: %v\n"+
"\tsupply.Total: %v", expectedTotal, supply.Total)
}
broken := !expectedTotal.IsEqual(supply.Total)
return nil
return sdk.FormatInvariant(types.ModuleName, "total supply",
fmt.Sprintf(
"\tsum of accounts coins: %v\n"+
"\tsupply.Total: %v",
expectedTotal, supply.Total), broken)
}
}