diff --git a/modules/nonce/replaycheck.go b/modules/nonce/replaycheck.go index 8a401e79a..b9f28d253 100644 --- a/modules/nonce/replaycheck.go +++ b/modules/nonce/replaycheck.go @@ -28,46 +28,30 @@ var _ stack.Middleware = ReplayCheck{} func (r ReplayCheck) CheckTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx, next basecoin.Checker) (res basecoin.Result, err error) { - stx, err := r.checkNonceTx(ctx, store, tx) + stx, err := r.checkIncrementNonceTx(ctx, store, tx) if err != nil { return res, err } - res, err = next.CheckTx(ctx, store, stx) - if err != nil { - return res, err - } - - err = r.incrementNonceTx(ctx, store, tx) - if err != nil { - return res, err - } - return + return next.CheckTx(ctx, store, stx) } // DeliverTx verifies tx is not being replayed - fulfills Middlware interface +// NOTE It is okay to modify the sequence before running the wrapped TX because if the +// wrapped Tx fails, the state changes are not applied func (r ReplayCheck) DeliverTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx, next basecoin.Deliver) (res basecoin.Result, err error) { - stx, err := r.checkNonceTx(ctx, store, tx) + stx, err := r.checkIncrementNonceTx(ctx, store, tx) if err != nil { return res, err } - res, err = next.DeliverTx(ctx, store, stx) - if err != nil { - return res, err - } - - err = r.incrementNonceTx(ctx, store, tx) - if err != nil { - return res, err - } - return + return next.DeliverTx(ctx, store, stx) } -// checkNonceTx varifies the nonce sequence -func (r ReplayCheck) checkNonceTx(ctx basecoin.Context, store state.KVStore, +// checkNonceTx varifies the nonce sequence, an increment sequence number +func (r ReplayCheck) checkIncrementNonceTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx) (basecoin.Tx, error) { // make sure it is a the nonce Tx (Tx from this package) @@ -77,27 +61,9 @@ func (r ReplayCheck) checkNonceTx(ctx basecoin.Context, store state.KVStore, } // check the nonce sequence number - err := nonceTx.CheckSeq(ctx, store) + err := nonceTx.CheckIncrementSeq(ctx, store) if err != nil { return tx, err } return nonceTx.Tx, nil } - -// incrementNonceTx increases the nonce sequence number -func (r ReplayCheck) incrementNonceTx(ctx basecoin.Context, store state.KVStore, - tx basecoin.Tx) error { - - // make sure it is a the nonce Tx (Tx from this package) - nonceTx, ok := tx.Unwrap().(Tx) - if !ok { - return errors.ErrNoNonce() - } - - // check the nonce sequence number - err := nonceTx.IncrementSeq(ctx, store) - if err != nil { - return err - } - return nil -} diff --git a/modules/nonce/tx.go b/modules/nonce/tx.go index 8e7df262b..bb3d19729 100644 --- a/modules/nonce/tx.go +++ b/modules/nonce/tx.go @@ -59,9 +59,11 @@ func (n Tx) ValidateBasic() error { return n.Tx.ValidateBasic() } -// CheckSeq - Check that the sequence number is one more than the state sequence number +// CheckIncrementSeq - Check that the sequence number is one more than the state sequence number // and further increment the sequence number -func (n Tx) CheckSeq(ctx basecoin.Context, store state.KVStore) error { +// NOTE It is okay to modify the sequence before running the wrapped TX because if the +// wrapped Tx fails, the state changes are not applied +func (n Tx) CheckIncrementSeq(ctx basecoin.Context, store state.KVStore) error { seqKey := n.getSeqKey() @@ -80,19 +82,6 @@ func (n Tx) CheckSeq(ctx basecoin.Context, store state.KVStore) error { return errors.ErrNotMember() } } - return nil -} - -// IncrementSeq - increment the sequence for a group of actors -func (n Tx) IncrementSeq(ctx basecoin.Context, store state.KVStore) error { - - seqKey := n.getSeqKey() - - // check the current state - cur, err := getSeq(store, seqKey) - if err != nil { - return err - } // increment the sequence by 1 err = setSeq(store, seqKey, cur+1) diff --git a/modules/nonce/tx_test.go b/modules/nonce/tx_test.go index ecd1fc40f..2d4406207 100644 --- a/modules/nonce/tx_test.go +++ b/modules/nonce/tx_test.go @@ -78,7 +78,7 @@ func TestNonce(t *testing.T) { nonceTx, ok := tx.Unwrap().(Tx) require.True(ok) - err := nonceTx.CheckSeq(myCtx, store) + err := nonceTx.CheckIncrementSeq(myCtx, store) if test.valid { assert.Nil(err, "%d: %+v", i, err) } else {