collect all invariants for a block before stopping (#4707)

This commit is contained in:
colin axner 2019-07-11 03:56:43 -07:00 committed by Alessio Treglia
parent 3f519832a7
commit 280734d0e3
13 changed files with 172 additions and 163 deletions

View File

@ -42,15 +42,16 @@ func init() {
flag.BoolVar(&commit, "SimulationCommit", false, "have the simulation commit") flag.BoolVar(&commit, "SimulationCommit", false, "have the simulation commit")
flag.IntVar(&period, "SimulationPeriod", 1, "run slow invariants only once every period assertions") flag.IntVar(&period, "SimulationPeriod", 1, "run slow invariants only once every period assertions")
flag.BoolVar(&onOperation, "SimulateEveryOperation", false, "run slow invariants every operation") flag.BoolVar(&onOperation, "SimulateEveryOperation", false, "run slow invariants every operation")
flag.BoolVar(&allInvariants, "PrintAllInvariants", false, "print all invariants if a broken invariant is found")
} }
// helper function for populating input for SimulateFromSeed // helper function for populating input for SimulateFromSeed
func getSimulateFromSeedInput(tb testing.TB, w io.Writer, app *SimApp) ( func getSimulateFromSeedInput(tb testing.TB, w io.Writer, app *SimApp) (
testing.TB, io.Writer, *baseapp.BaseApp, simulation.AppStateFn, int64, testing.TB, io.Writer, *baseapp.BaseApp, simulation.AppStateFn, int64,
simulation.WeightedOperations, sdk.Invariants, int, int, bool, bool, bool) { simulation.WeightedOperations, sdk.Invariants, int, int, bool, bool, bool, bool) {
return tb, w, app.BaseApp, appStateFn, seed, return tb, w, app.BaseApp, appStateFn, seed,
testAndRunTxs(app), invariants(app), numBlocks, blockSize, commit, lean, onOperation testAndRunTxs(app), invariants(app), numBlocks, blockSize, commit, lean, onOperation, allInvariants
} }
func appStateFn( func appStateFn(
@ -602,6 +603,7 @@ func TestAppStateDeterminism(t *testing.T) {
true, true,
false, false,
false, false,
false,
) )
appHash := app.LastCommitID().Hash appHash := app.LastCommitID().Hash
appHashList[j] = appHash appHashList[j] = appHash
@ -627,7 +629,7 @@ func BenchmarkInvariants(b *testing.B) {
// 2. Run parameterized simulation (w/o invariants) // 2. Run parameterized simulation (w/o invariants)
_, err := simulation.SimulateFromSeed( _, err := simulation.SimulateFromSeed(
b, ioutil.Discard, app.BaseApp, appStateFn, seed, testAndRunTxs(app), b, ioutil.Discard, app.BaseApp, appStateFn, seed, testAndRunTxs(app),
[]sdk.Invariant{}, numBlocks, blockSize, commit, lean, onOperation, []sdk.Invariant{}, numBlocks, blockSize, commit, lean, onOperation, false,
) )
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
@ -642,8 +644,8 @@ func BenchmarkInvariants(b *testing.B) {
// their respective metadata which makes it useful for testing/benchmarking. // their respective metadata which makes it useful for testing/benchmarking.
for _, cr := range app.crisisKeeper.Routes() { for _, cr := range app.crisisKeeper.Routes() {
b.Run(fmt.Sprintf("%s/%s", cr.ModuleName, cr.Route), func(b *testing.B) { b.Run(fmt.Sprintf("%s/%s", cr.ModuleName, cr.Route), func(b *testing.B) {
if err := cr.Invar(ctx); err != nil { if res, stop := cr.Invar(ctx); stop {
fmt.Printf("broken invariant at block %d of %d\n%s", ctx.BlockHeight()-1, numBlocks, err) fmt.Printf("broken invariant at block %d of %d\n%s", ctx.BlockHeight()-1, numBlocks, res)
b.FailNow() b.FailNow()
} }
}) })

View File

@ -37,17 +37,18 @@ import (
) )
var ( var (
genesisFile string genesisFile string
paramsFile string paramsFile string
seed int64 seed int64
numBlocks int numBlocks int
blockSize int blockSize int
enabled bool enabled bool
verbose bool verbose bool
lean bool lean bool
commit bool commit bool
period int period int
onOperation bool // TODO Remove in favor of binary search for invariant violation onOperation bool // TODO Remove in favor of binary search for invariant violation
allInvariants bool
) )
// NewSimAppUNSAFE is used for debugging purposes only. // NewSimAppUNSAFE is used for debugging purposes only.

View File

@ -1,10 +1,12 @@
package types package types
import "fmt"
// An Invariant is a function which tests a particular invariant. // An Invariant is a function which tests a particular invariant.
// If the invariant has been broken, it should return an error // If the invariant has been broken, it should return an error
// containing a descriptive message about what happened. // containing a descriptive message about what happened.
// The simulator will then halt and print the logs. // The simulator will then halt and print the logs.
type Invariant func(ctx Context) error type Invariant func(ctx Context) (string, bool)
// Invariants defines a group of invariants // Invariants defines a group of invariants
type Invariants []Invariant type Invariants []Invariant
@ -13,3 +15,10 @@ type Invariants []Invariant
type InvariantRegistry interface { type InvariantRegistry interface {
RegisterRoute(moduleName, route string, invar Invariant) RegisterRoute(moduleName, route string, invar Invariant)
} }
// FormatInvariant returns a standardized invariant message along with
// broken boolean.
func FormatInvariant(module, name, msg string, broken bool) (string, bool) {
return fmt.Sprintf("%s: %s invariant\n%s\nInvariant Broken: %v\n",
module, name, msg, broken), broken
}

View File

@ -1,11 +1,9 @@
package keeper package keeper
import ( import (
"errors"
"fmt" "fmt"
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/auth/exported"
"github.com/cosmos/cosmos-sdk/x/bank/internal/types" "github.com/cosmos/cosmos-sdk/x/bank/internal/types"
) )
@ -17,36 +15,23 @@ func RegisterInvariants(ir sdk.InvariantRegistry, ak types.AccountKeeper) {
// NonnegativeBalanceInvariant checks that all accounts in the application have non-negative balances // NonnegativeBalanceInvariant checks that all accounts in the application have non-negative balances
func NonnegativeBalanceInvariant(ak types.AccountKeeper) sdk.Invariant { func NonnegativeBalanceInvariant(ak types.AccountKeeper) sdk.Invariant {
return func(ctx sdk.Context) error { return func(ctx sdk.Context) (string, bool) {
var msg string
var amt int
accts := ak.GetAllAccounts(ctx) accts := ak.GetAllAccounts(ctx)
for _, acc := range accts { for _, acc := range accts {
coins := acc.GetCoins() coins := acc.GetCoins()
if coins.IsAnyNegative() { if coins.IsAnyNegative() {
return fmt.Errorf("%s has a negative denomination of %s", amt++
msg += fmt.Sprintf("\t%s has a negative denomination of %s\n",
acc.GetAddress().String(), acc.GetAddress().String(),
coins.String()) coins.String())
} }
} }
return nil broken := amt != 0
}
} return sdk.FormatInvariant(types.ModuleName, "nonnegative-outstanding",
fmt.Sprintf("amount of negative accounts found %d\n%s", amt, msg), broken)
// TotalCoinsInvariant checks that the sum of the coins across all accounts
// is what is expected
func TotalCoinsInvariant(ak types.AccountKeeper, totalSupplyFn func() sdk.Coins) sdk.Invariant {
return func(ctx sdk.Context) error {
totalCoins := sdk.NewCoins()
chkAccount := func(acc exported.Account) bool {
coins := acc.GetCoins()
totalCoins = totalCoins.Add(coins)
return false
}
ak.IterateAccounts(ctx, chkAccount)
if !totalSupplyFn().IsEqual(totalCoins) {
return errors.New("total calculated coins doesn't equal expected coins")
}
return nil
} }
} }

View File

@ -40,10 +40,11 @@ func handleMsgVerifyInvariant(ctx sdk.Context, msg types.MsgVerifyInvariant, k k
found := false found := false
msgFullRoute := msg.FullInvariantRoute() msgFullRoute := msg.FullInvariantRoute()
var invarianceErr error var res string
var stop bool
for _, invarRoute := range k.Routes() { for _, invarRoute := range k.Routes() {
if invarRoute.FullRoute() == msgFullRoute { if invarRoute.FullRoute() == msgFullRoute {
invarianceErr = invarRoute.Invar(cacheCtx) res, stop = invarRoute.Invar(cacheCtx)
found = true found = true
break break
} }
@ -53,7 +54,7 @@ func handleMsgVerifyInvariant(ctx sdk.Context, msg types.MsgVerifyInvariant, k k
return types.ErrUnknownInvariant(types.DefaultCodespace).Result() return types.ErrUnknownInvariant(types.DefaultCodespace).Result()
} }
if invarianceErr != nil { if stop {
// NOTE currently, because the chain halts here, this transaction will never be included // NOTE currently, because the chain halts here, this transaction will never be included
// in the blockchain thus the constant fee will have never been deducted. Thus no // in the blockchain thus the constant fee will have never been deducted. Thus no
// refund is required. // refund is required.
@ -70,7 +71,7 @@ func handleMsgVerifyInvariant(ctx sdk.Context, msg types.MsgVerifyInvariant, k k
//} //}
// TODO replace with circuit breaker // TODO replace with circuit breaker
panic(invarianceErr) panic(res)
} }
ctx.EventManager().EmitEvents(sdk.Events{ ctx.EventManager().EmitEvents(sdk.Events{

View File

@ -1,7 +1,6 @@
package crisis_test package crisis_test
import ( import (
"errors"
"fmt" "fmt"
"strings" "strings"
"testing" "testing"
@ -17,8 +16,8 @@ import (
var ( var (
testModuleName = "dummy" testModuleName = "dummy"
dummyRouteWhichPasses = crisis.NewInvarRoute(testModuleName, "which-passes", func(_ sdk.Context) error { return nil }) dummyRouteWhichPasses = crisis.NewInvarRoute(testModuleName, "which-passes", func(_ sdk.Context) (string, bool) { return "", false })
dummyRouteWhichFails = crisis.NewInvarRoute(testModuleName, "which-fails", func(_ sdk.Context) error { return errors.New("whoops") }) dummyRouteWhichFails = crisis.NewInvarRoute(testModuleName, "which-fails", func(_ sdk.Context) (string, bool) { return "whoops", true })
addrs = distr.TestAddrs addrs = distr.TestAddrs
) )

View File

@ -67,13 +67,13 @@ func (k Keeper) AssertInvariants(ctx sdk.Context) {
start := time.Now() start := time.Now()
invarRoutes := k.Routes() invarRoutes := k.Routes()
for _, ir := range invarRoutes { for _, ir := range invarRoutes {
if err := ir.Invar(ctx); err != nil { if res, stop := ir.Invar(ctx); stop {
// TODO: Include app name as part of context to allow for this to be // TODO: Include app name as part of context to allow for this to be
// variable. // variable.
panic(fmt.Errorf("invariant broken: %s\n"+ panic(fmt.Errorf("invariant broken: %s\n"+
"\tCRITICAL please submit the following transaction:\n"+ "\tCRITICAL please submit the following transaction:\n"+
"\t\t tx crisis invariant-broken %v %v", err, ir.ModuleName, ir.Route)) "\t\t tx crisis invariant-broken %s %s", res, ir.ModuleName, ir.Route))
} }
} }

View File

@ -22,18 +22,18 @@ func RegisterInvariants(ir sdk.InvariantRegistry, k Keeper) {
// AllInvariants runs all invariants of the distribution module // AllInvariants runs all invariants of the distribution module
func AllInvariants(k Keeper) sdk.Invariant { func AllInvariants(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error { return func(ctx sdk.Context) (string, bool) {
err := CanWithdrawInvariant(k)(ctx) res, stop := CanWithdrawInvariant(k)(ctx)
if err != nil { if stop {
return err return res, stop
} }
err = NonNegativeOutstandingInvariant(k)(ctx) res, stop = NonNegativeOutstandingInvariant(k)(ctx)
if err != nil { if stop {
return err return res, stop
} }
err = ReferenceCountInvariant(k)(ctx) res, stop = ReferenceCountInvariant(k)(ctx)
if err != nil { if stop {
return err return res, stop
} }
return ModuleAccountInvariant(k)(ctx) return ModuleAccountInvariant(k)(ctx)
} }
@ -41,30 +41,29 @@ func AllInvariants(k Keeper) sdk.Invariant {
// NonNegativeOutstandingInvariant checks that outstanding unwithdrawn fees are never negative // NonNegativeOutstandingInvariant checks that outstanding unwithdrawn fees are never negative
func NonNegativeOutstandingInvariant(k Keeper) sdk.Invariant { func NonNegativeOutstandingInvariant(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error { return func(ctx sdk.Context) (string, bool) {
var msg string
var amt int
var outstanding sdk.DecCoins var outstanding sdk.DecCoins
k.IterateValidatorOutstandingRewards(ctx, func(_ sdk.ValAddress, rewards types.ValidatorOutstandingRewards) (stop bool) { k.IterateValidatorOutstandingRewards(ctx, func(addr sdk.ValAddress, rewards types.ValidatorOutstandingRewards) (stop bool) {
outstanding = rewards outstanding = rewards
if outstanding.IsAnyNegative() { if outstanding.IsAnyNegative() {
return true amt++
msg += fmt.Sprintf("\t%v has negative outstanding coins: %v\n", addr, outstanding)
} }
return false return false
}) })
broken := amt != 0
if outstanding.IsAnyNegative() { return sdk.FormatInvariant(types.ModuleName, "nonnegative outstanding",
return fmt.Errorf("negative outstanding coins: %v", outstanding) fmt.Sprintf("found %d validators with negative outstanding rewards\n%s", amt, msg), broken)
}
return nil
} }
} }
// CanWithdrawInvariant checks that current rewards can be completely withdrawn // CanWithdrawInvariant checks that current rewards can be completely withdrawn
func CanWithdrawInvariant(k Keeper) sdk.Invariant { func CanWithdrawInvariant(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error { return func(ctx sdk.Context) (string, bool) {
// cache, we don't want to write changes // cache, we don't want to write changes
ctx, _ = ctx.CacheContext() ctx, _ = ctx.CacheContext()
@ -98,17 +97,15 @@ func CanWithdrawInvariant(k Keeper) sdk.Invariant {
return false return false
}) })
if len(remaining) > 0 && remaining[0].Amount.LT(sdk.ZeroDec()) { broken := len(remaining) > 0 && remaining[0].Amount.LT(sdk.ZeroDec())
return fmt.Errorf("negative remaining coins: %v", remaining) return sdk.FormatInvariant(types.ModuleName, "can withdraw",
} fmt.Sprintf("remaining coins: %v", remaining), broken)
return nil
} }
} }
// ReferenceCountInvariant checks that the number of historical rewards records is correct // ReferenceCountInvariant checks that the number of historical rewards records is correct
func ReferenceCountInvariant(k Keeper) sdk.Invariant { func ReferenceCountInvariant(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error { return func(ctx sdk.Context) (string, bool) {
valCount := uint64(0) valCount := uint64(0)
k.stakingKeeper.IterateValidators(ctx, func(_ int64, val exported.ValidatorI) (stop bool) { k.stakingKeeper.IterateValidators(ctx, func(_ int64, val exported.ValidatorI) (stop bool) {
@ -127,21 +124,18 @@ func ReferenceCountInvariant(k Keeper) sdk.Invariant {
// delegation (previous period), one record per slash (previous period) // delegation (previous period), one record per slash (previous period)
expected := valCount + uint64(len(dels)) + slashCount expected := valCount + uint64(len(dels)) + slashCount
count := k.GetValidatorHistoricalReferenceCount(ctx) count := k.GetValidatorHistoricalReferenceCount(ctx)
broken := count != expected
if count != expected { return sdk.FormatInvariant(types.ModuleName, "reference count", fmt.Sprintf("unexpected number of historical rewards records: "+
return fmt.Errorf("unexpected number of historical rewards records: "+ "expected %v (%v vals + %v dels + %v slashes), got %v",
"expected %v (%v vals + %v dels + %v slashes), got %v", expected, valCount, len(dels), slashCount, count), broken)
expected, valCount, len(dels), slashCount, count)
}
return nil
} }
} }
// ModuleAccountInvariant checks that the coins held by the distr ModuleAccount // ModuleAccountInvariant checks that the coins held by the distr ModuleAccount
// is consistent with the sum of validator outstanding rewards // is consistent with the sum of validator outstanding rewards
func ModuleAccountInvariant(k Keeper) sdk.Invariant { func ModuleAccountInvariant(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error { return func(ctx sdk.Context) (string, bool) {
var expectedCoins sdk.DecCoins var expectedCoins sdk.DecCoins
k.IterateValidatorOutstandingRewards(ctx, func(_ sdk.ValAddress, rewards types.ValidatorOutstandingRewards) (stop bool) { k.IterateValidatorOutstandingRewards(ctx, func(_ sdk.ValAddress, rewards types.ValidatorOutstandingRewards) (stop bool) {
@ -154,12 +148,8 @@ func ModuleAccountInvariant(k Keeper) sdk.Invariant {
macc := k.GetDistributionAccount(ctx) macc := k.GetDistributionAccount(ctx)
if !macc.GetCoins().IsEqual(expectedInt) { broken := !macc.GetCoins().IsEqual(expectedInt)
return fmt.Errorf("distribution ModuleAccount coins invariance:\n"+ return sdk.FormatInvariant(types.ModuleName, "ModuleAccount coins", fmt.Sprintf("expected ModuleAccount coins: %s\n"+
"\texpected ModuleAccount coins: %s\n"+ "\tdistribution ModuleAccount coins : %s", expectedInt, macc.GetCoins()), broken)
"\tdistribution ModuleAccount coins : %s", expectedInt, macc.GetCoins())
}
return nil
} }
} }

View File

@ -14,7 +14,7 @@ func RegisterInvariants(ir sdk.InvariantRegistry, keeper Keeper) {
// AllInvariants runs all invariants of the governance module // AllInvariants runs all invariants of the governance module
func AllInvariants(keeper Keeper) sdk.Invariant { func AllInvariants(keeper Keeper) sdk.Invariant {
return func(ctx sdk.Context) error { return func(ctx sdk.Context) (string, bool) {
return ModuleAccountInvariant(keeper)(ctx) return ModuleAccountInvariant(keeper)(ctx)
} }
} }
@ -22,7 +22,7 @@ func AllInvariants(keeper Keeper) sdk.Invariant {
// ModuleAccountInvariant checks that the module account coins reflects the sum of // ModuleAccountInvariant checks that the module account coins reflects the sum of
// deposit amounts held on store // deposit amounts held on store
func ModuleAccountInvariant(keeper Keeper) sdk.Invariant { func ModuleAccountInvariant(keeper Keeper) sdk.Invariant {
return func(ctx sdk.Context) error { return func(ctx sdk.Context) (string, bool) {
var expectedDeposits sdk.Coins var expectedDeposits sdk.Coins
keeper.IterateAllDeposits(ctx, func(deposit types.Deposit) bool { keeper.IterateAllDeposits(ctx, func(deposit types.Deposit) bool {
@ -31,12 +31,10 @@ func ModuleAccountInvariant(keeper Keeper) sdk.Invariant {
}) })
macc := keeper.GetGovernanceAccount(ctx) macc := keeper.GetGovernanceAccount(ctx)
if !macc.GetCoins().IsEqual(expectedDeposits) { broken := !macc.GetCoins().IsEqual(expectedDeposits)
return fmt.Errorf("deposits invariance:\n"+
"\tgov ModuleAccount coins: %s\n"+
"\tsum of deposit amounts: %s", macc.GetCoins(), expectedDeposits)
}
return nil return sdk.FormatInvariant(types.ModuleName, "deposits",
fmt.Sprintf("\tgov ModuleAccount coins: %s\n\tsum of deposit amounts: %s",
macc.GetCoins(), expectedDeposits), broken)
} }
} }

View File

@ -46,7 +46,7 @@ func SimulateFromSeed(
tb testing.TB, w io.Writer, app *baseapp.BaseApp, tb testing.TB, w io.Writer, app *baseapp.BaseApp,
appStateFn AppStateFn, seed int64, ops WeightedOperations, appStateFn AppStateFn, seed int64, ops WeightedOperations,
invariants sdk.Invariants, invariants sdk.Invariants,
numBlocks, blockSize int, commit, lean, onOperation bool, numBlocks, blockSize int, commit, lean, onOperation, allInvariants bool,
) (stopEarly bool, simError error) { ) (stopEarly bool, simError error) {
// in case we have to end early, don't os.Exit so that we can run cleanup code. // in case we have to end early, don't os.Exit so that we can run cleanup code.
@ -108,7 +108,7 @@ func SimulateFromSeed(
blockSimulator := createBlockSimulator( blockSimulator := createBlockSimulator(
testingMode, tb, t, w, params, eventStats.tally, invariants, testingMode, tb, t, w, params, eventStats.tally, invariants,
ops, operationQueue, timeOperationQueue, ops, operationQueue, timeOperationQueue,
numBlocks, blockSize, logWriter, lean, onOperation) numBlocks, blockSize, logWriter, lean, onOperation, allInvariants)
if !testingMode { if !testingMode {
b.ResetTimer() b.ResetTimer()
@ -137,7 +137,7 @@ func SimulateFromSeed(
app.BeginBlock(request) app.BeginBlock(request)
if testingMode { if testingMode {
assertAllInvariants(t, app, invariants, "BeginBlock", logWriter) assertAllInvariants(t, app, invariants, "BeginBlock", logWriter, allInvariants)
} }
ctx := app.NewContext(false, header) ctx := app.NewContext(false, header)
@ -152,14 +152,14 @@ func SimulateFromSeed(
tb, r, app, ctx, accs, logWriter, eventStats.tally, lean) tb, r, app, ctx, accs, logWriter, eventStats.tally, lean)
if testingMode && onOperation { if testingMode && onOperation {
assertAllInvariants(t, app, invariants, "QueuedOperations", logWriter) assertAllInvariants(t, app, invariants, "QueuedOperations", logWriter, allInvariants)
} }
// run standard operations // run standard operations
operations := blockSimulator(r, app, ctx, accs, header) operations := blockSimulator(r, app, ctx, accs, header)
opCount += operations + numQueuedOpsRan + numQueuedTimeOpsRan opCount += operations + numQueuedOpsRan + numQueuedTimeOpsRan
if testingMode { if testingMode {
assertAllInvariants(t, app, invariants, "StandardOperations", logWriter) assertAllInvariants(t, app, invariants, "StandardOperations", logWriter, allInvariants)
} }
res := app.EndBlock(abci.RequestEndBlock{}) res := app.EndBlock(abci.RequestEndBlock{})
@ -172,7 +172,7 @@ func SimulateFromSeed(
logWriter.AddEntry(EndBlockEntry(int64(height))) logWriter.AddEntry(EndBlockEntry(int64(height)))
if testingMode { if testingMode {
assertAllInvariants(t, app, invariants, "EndBlock", logWriter) assertAllInvariants(t, app, invariants, "EndBlock", logWriter, allInvariants)
} }
if commit { if commit {
app.Commit() app.Commit()
@ -221,7 +221,7 @@ type blockSimFn func(r *rand.Rand, app *baseapp.BaseApp, ctx sdk.Context,
func createBlockSimulator(testingMode bool, tb testing.TB, t *testing.T, w io.Writer, params Params, func createBlockSimulator(testingMode bool, tb testing.TB, t *testing.T, w io.Writer, params Params,
event func(string), invariants sdk.Invariants, ops WeightedOperations, event func(string), invariants sdk.Invariants, ops WeightedOperations,
operationQueue OperationQueue, timeOperationQueue []FutureOperation, operationQueue OperationQueue, timeOperationQueue []FutureOperation,
totalNumBlocks, avgBlockSize int, logWriter LogWriter, lean, onOperation bool) blockSimFn { totalNumBlocks, avgBlockSize int, logWriter LogWriter, lean, onOperation, allInvariants bool) blockSimFn {
lastBlocksizeState := 0 // state for [4 * uniform distribution] lastBlocksizeState := 0 // state for [4 * uniform distribution]
blocksize := 0 blocksize := 0
@ -269,7 +269,7 @@ func createBlockSimulator(testingMode bool, tb testing.TB, t *testing.T, w io.Wr
fmt.Fprintf(w, "\rSimulating... block %d/%d, operation %d/%d. ", fmt.Fprintf(w, "\rSimulating... block %d/%d, operation %d/%d. ",
header.Height, totalNumBlocks, opCount, blocksize) header.Height, totalNumBlocks, opCount, blocksize)
eventStr := fmt.Sprintf("operation: %v", opMsg.String()) eventStr := fmt.Sprintf("operation: %v", opMsg.String())
assertAllInvariants(t, app, invariants, eventStr, logWriter) assertAllInvariants(t, app, invariants, eventStr, logWriter, allInvariants)
} else if opCount%50 == 0 { } else if opCount%50 == 0 {
fmt.Fprintf(w, "\rSimulating... block %d/%d, operation %d/%d. ", fmt.Fprintf(w, "\rSimulating... block %d/%d, operation %d/%d. ",
header.Height, totalNumBlocks, opCount, blocksize) header.Height, totalNumBlocks, opCount, blocksize)

View File

@ -14,17 +14,30 @@ import (
// assertAll asserts the all invariants against application state // assertAll asserts the all invariants against application state
func assertAllInvariants(t *testing.T, app *baseapp.BaseApp, invs sdk.Invariants, func assertAllInvariants(t *testing.T, app *baseapp.BaseApp, invs sdk.Invariants,
event string, logWriter LogWriter) { event string, logWriter LogWriter, allInvariants bool) {
ctx := app.NewContext(false, abci.Header{Height: app.LastBlockHeight() + 1}) ctx := app.NewContext(false, abci.Header{Height: app.LastBlockHeight() + 1})
var broken bool
var invariantResults []string
for i := 0; i < len(invs); i++ { for i := 0; i < len(invs); i++ {
if err := invs[i](ctx); err != nil { res, stop := invs[i](ctx)
fmt.Printf("Invariants broken after %s\n%s\n", event, err.Error()) if stop {
logWriter.PrintLogs() broken = true
t.Fatal() invariantResults = append(invariantResults, res)
} else if allInvariants {
invariantResults = append(invariantResults, res)
} }
} }
if broken {
fmt.Printf("Invariants broken after %s\n\n", event)
for _, res := range invariantResults {
fmt.Printf("%s\n", res)
}
logWriter.PrintLogs()
t.Fatal()
}
} }
func getTestingMode(tb testing.TB) (testingMode bool, t *testing.T, b *testing.B) { func getTestingMode(tb testing.TB) (testingMode bool, t *testing.T, b *testing.B) {
@ -66,11 +79,11 @@ func getBlockSize(r *rand.Rand, params Params,
func PeriodicInvariants(invariants []sdk.Invariant, period, offset int) []sdk.Invariant { func PeriodicInvariants(invariants []sdk.Invariant, period, offset int) []sdk.Invariant {
var outInvariants []sdk.Invariant var outInvariants []sdk.Invariant
for _, invariant := range invariants { for _, invariant := range invariants {
outInvariant := func(ctx sdk.Context) error { outInvariant := func(ctx sdk.Context) (string, bool) {
if int(ctx.BlockHeight())%period == offset { if int(ctx.BlockHeight())%period == offset {
return invariant(ctx) return invariant(ctx)
} }
return nil return "", false
} }
outInvariants = append(outInvariants, outInvariant) outInvariants = append(outInvariants, outInvariant)
} }

View File

@ -25,20 +25,20 @@ func RegisterInvariants(ir sdk.InvariantRegistry, k Keeper) {
// AllInvariants runs all invariants of the staking module. // AllInvariants runs all invariants of the staking module.
func AllInvariants(k Keeper) sdk.Invariant { func AllInvariants(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error { return func(ctx sdk.Context) (string, bool) {
err := ModuleAccountInvariants(k)(ctx) res, stop := ModuleAccountInvariants(k)(ctx)
if err != nil { if stop {
return err return res, stop
} }
err = NonNegativePowerInvariant(k)(ctx) res, stop = NonNegativePowerInvariant(k)(ctx)
if err != nil { if stop {
return err return res, stop
} }
err = PositiveDelegationInvariant(k)(ctx) res, stop = PositiveDelegationInvariant(k)(ctx)
if err != nil { if stop {
return err return res, stop
} }
return DelegatorSharesInvariant(k)(ctx) return DelegatorSharesInvariant(k)(ctx)
@ -48,7 +48,7 @@ func AllInvariants(k Keeper) sdk.Invariant {
// ModuleAccountInvariants checks that the bonded and notBonded ModuleAccounts pools // ModuleAccountInvariants checks that the bonded and notBonded ModuleAccounts pools
// reflects the tokens actively bonded and not bonded // reflects the tokens actively bonded and not bonded
func ModuleAccountInvariants(k Keeper) sdk.Invariant { func ModuleAccountInvariants(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error { return func(ctx sdk.Context) (string, bool) {
bonded := sdk.ZeroInt() bonded := sdk.ZeroInt()
notBonded := sdk.ZeroInt() notBonded := sdk.ZeroInt()
bondedPool := k.GetBondedPool(ctx) bondedPool := k.GetBondedPool(ctx)
@ -76,31 +76,29 @@ func ModuleAccountInvariants(k Keeper) sdk.Invariant {
poolBonded := bondedPool.GetCoins().AmountOf(bondDenom) poolBonded := bondedPool.GetCoins().AmountOf(bondDenom)
poolNotBonded := notBondedPool.GetCoins().AmountOf(bondDenom) poolNotBonded := notBondedPool.GetCoins().AmountOf(bondDenom)
broken := !poolBonded.Equal(bonded) || !poolNotBonded.Equal(notBonded)
// Bonded tokens should equal sum of tokens with bonded validators // Bonded tokens should equal sum of tokens with bonded validators
// Not-bonded tokens should equal unbonding delegations plus tokens on unbonded validators // Not-bonded tokens should equal unbonding delegations plus tokens on unbonded validators
if !poolBonded.Equal(bonded) || !poolNotBonded.Equal(notBonded) { return sdk.FormatInvariant(types.ModuleName, "bonded and not bonded module account coins", fmt.Sprintf(
return fmt.Errorf( "\tPool's bonded tokens: %v\n"+
"bonded token invariance:\n"+ "\tsum of bonded tokens: %v\n"+
"\tPool's bonded tokens: %v\n"+ "not bonded token invariance:\n"+
"\tsum of bonded tokens: %v\n"+ "\tPool's not bonded tokens: %v\n"+
"not bonded token invariance:\n"+ "\tsum of not bonded tokens: %v\n"+
"\tPool's not bonded tokens: %v\n"+ "module accounts total (bonded + not bonded):\n"+
"\tsum of not bonded tokens: %v\n"+ "\tModule Accounts' tokens: %v\n"+
"module accounts total (bonded + not bonded):\n"+ "\tsum tokens: %v\n",
"\tModule Accounts' tokens: %v\n"+ poolBonded, bonded, poolNotBonded, notBonded, poolBonded.Add(poolNotBonded), bonded.Add(notBonded)), broken)
"\tsum tokens: %v\n",
poolBonded, bonded, poolNotBonded, notBonded, poolBonded.Add(poolNotBonded), bonded.Add(notBonded),
)
}
return nil
} }
} }
// NonNegativePowerInvariant checks that all stored validators have >= 0 power. // NonNegativePowerInvariant checks that all stored validators have >= 0 power.
func NonNegativePowerInvariant(k Keeper) sdk.Invariant { func NonNegativePowerInvariant(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error { return func(ctx sdk.Context) (string, bool) {
var msg string
var broken bool
iterator := k.ValidatorsPowerStoreIterator(ctx) iterator := k.ValidatorsPowerStoreIterator(ctx)
for ; iterator.Valid(); iterator.Next() { for ; iterator.Valid(); iterator.Next() {
@ -112,42 +110,54 @@ func NonNegativePowerInvariant(k Keeper) sdk.Invariant {
powerKey := types.GetValidatorsByPowerIndexKey(validator) powerKey := types.GetValidatorsByPowerIndexKey(validator)
if !bytes.Equal(iterator.Key(), powerKey) { if !bytes.Equal(iterator.Key(), powerKey) {
return fmt.Errorf("power store invariance:\n\tvalidator.Power: %v"+ broken = true
msg += fmt.Sprintf("power store invariance:\n\tvalidator.Power: %v"+
"\n\tkey should be: %v\n\tkey in store: %v", "\n\tkey should be: %v\n\tkey in store: %v",
validator.GetConsensusPower(), powerKey, iterator.Key()) validator.GetConsensusPower(), powerKey, iterator.Key())
} }
if validator.Tokens.IsNegative() { if validator.Tokens.IsNegative() {
return fmt.Errorf("negative tokens for validator: %v", validator) broken = true
msg += fmt.Sprintf("negative tokens for validator: %v", validator)
} }
} }
iterator.Close() iterator.Close()
return nil return sdk.FormatInvariant(types.ModuleName, "nonnegative power", msg, broken)
} }
} }
// PositiveDelegationInvariant checks that all stored delegations have > 0 shares. // PositiveDelegationInvariant checks that all stored delegations have > 0 shares.
func PositiveDelegationInvariant(k Keeper) sdk.Invariant { func PositiveDelegationInvariant(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error { return func(ctx sdk.Context) (string, bool) {
var msg string
var amt int
delegations := k.GetAllDelegations(ctx) delegations := k.GetAllDelegations(ctx)
for _, delegation := range delegations { for _, delegation := range delegations {
if delegation.Shares.IsNegative() { if delegation.Shares.IsNegative() {
return fmt.Errorf("delegation with negative shares: %+v", delegation) amt++
msg += fmt.Sprintf("\tdelegation with negative shares: %+v\n", delegation)
} }
if delegation.Shares.IsZero() { if delegation.Shares.IsZero() {
return fmt.Errorf("delegation with zero shares: %+v", delegation) amt++
msg += fmt.Sprintf("\tdelegation with zero shares: %+v\n", delegation)
} }
} }
broken := amt != 0
return nil return sdk.FormatInvariant(types.ModuleName, "positive delegations", fmt.Sprintf(
"%d invalid delegations found\n%s", amt, msg), broken)
} }
} }
// DelegatorSharesInvariant checks whether all the delegator shares which persist // DelegatorSharesInvariant checks whether all the delegator shares which persist
// in the delegator object add up to the correct total delegator shares // in the delegator object add up to the correct total delegator shares
// amount stored in each validator // amount stored in each validator.
func DelegatorSharesInvariant(k Keeper) sdk.Invariant { func DelegatorSharesInvariant(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error { return func(ctx sdk.Context) (string, bool) {
var msg string
var broken bool
validators := k.GetAllValidators(ctx) validators := k.GetAllValidators(ctx)
for _, validator := range validators { for _, validator := range validators {
@ -160,11 +170,12 @@ func DelegatorSharesInvariant(k Keeper) sdk.Invariant {
} }
if !valTotalDelShares.Equal(totalDelShares) { if !valTotalDelShares.Equal(totalDelShares) {
return fmt.Errorf("broken delegator shares invariance:\n"+ broken = true
msg += fmt.Sprintf("broken delegator shares invariance:\n"+
"\tvalidator.DelegatorShares: %v\n"+ "\tvalidator.DelegatorShares: %v\n"+
"\tsum of Delegator.Shares: %v", valTotalDelShares, totalDelShares) "\tsum of Delegator.Shares: %v", valTotalDelShares, totalDelShares)
} }
} }
return nil return sdk.FormatInvariant(types.ModuleName, "delegator shares", msg, broken)
} }
} }

View File

@ -15,14 +15,14 @@ func RegisterInvariants(ir sdk.InvariantRegistry, k Keeper) {
// AllInvariants runs all invariants of the supply module. // AllInvariants runs all invariants of the supply module.
func AllInvariants(k Keeper) sdk.Invariant { func AllInvariants(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error { return func(ctx sdk.Context) (string, bool) {
return TotalSupply(k)(ctx) return TotalSupply(k)(ctx)
} }
} }
// TotalSupply checks that the total supply reflects all the coins held in accounts // TotalSupply checks that the total supply reflects all the coins held in accounts
func TotalSupply(k Keeper) sdk.Invariant { func TotalSupply(k Keeper) sdk.Invariant {
return func(ctx sdk.Context) error { return func(ctx sdk.Context) (string, bool) {
var expectedTotal sdk.Coins var expectedTotal sdk.Coins
supply := k.GetSupply(ctx) supply := k.GetSupply(ctx)
@ -31,12 +31,12 @@ func TotalSupply(k Keeper) sdk.Invariant {
return false return false
}) })
if !expectedTotal.IsEqual(supply.Total) { broken := !expectedTotal.IsEqual(supply.Total)
return fmt.Errorf("total supply invariance:\n"+
"\tsum of accounts coins: %v\n"+
"\tsupply.Total: %v", expectedTotal, supply.Total)
}
return nil return sdk.FormatInvariant(types.ModuleName, "total supply",
fmt.Sprintf(
"\tsum of accounts coins: %v\n"+
"\tsupply.Total: %v",
expectedTotal, supply.Total), broken)
} }
} }