Enforce the expiration height in Chain middleware
This commit is contained in:
parent
b6197a1c12
commit
765f52e402
|
@ -21,6 +21,7 @@ var (
|
|||
errUnknownTxType = fmt.Errorf("Tx type unknown")
|
||||
errInvalidFormat = fmt.Errorf("Invalid format")
|
||||
errUnknownModule = fmt.Errorf("Unknown module")
|
||||
errExpired = fmt.Errorf("Tx expired")
|
||||
)
|
||||
|
||||
// some crazy reflection to unwrap any generated struct.
|
||||
|
@ -130,3 +131,10 @@ func ErrTooLarge() TMError {
|
|||
func IsTooLargeErr(err error) bool {
|
||||
return IsSameError(errTooLarge, err)
|
||||
}
|
||||
|
||||
func ErrExpired() TMError {
|
||||
return WithCode(errExpired, abci.CodeType_Unauthorized)
|
||||
}
|
||||
func IsExpiredErr(err error) bool {
|
||||
return IsSameError(errExpired, err)
|
||||
}
|
||||
|
|
|
@ -26,7 +26,7 @@ var _ stack.Middleware = Chain{}
|
|||
|
||||
// CheckTx makes sure we are on the proper chain - fulfills Middlware interface
|
||||
func (c Chain) CheckTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx, next basecoin.Checker) (res basecoin.Result, err error) {
|
||||
stx, err := c.checkChain(ctx.ChainID(), tx)
|
||||
stx, err := c.checkChain(ctx.ChainID(), ctx.BlockHeight(), tx)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
@ -35,7 +35,7 @@ func (c Chain) CheckTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx
|
|||
|
||||
// DeliverTx makes sure we are on the proper chain - fulfills Middlware interface
|
||||
func (c Chain) DeliverTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx, next basecoin.Deliver) (res basecoin.Result, err error) {
|
||||
stx, err := c.checkChain(ctx.ChainID(), tx)
|
||||
stx, err := c.checkChain(ctx.ChainID(), ctx.BlockHeight(), tx)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
@ -43,7 +43,7 @@ func (c Chain) DeliverTx(ctx basecoin.Context, store state.KVStore, tx basecoin.
|
|||
}
|
||||
|
||||
// checkChain makes sure the tx is a Chain Tx and is on the proper chain
|
||||
func (c Chain) checkChain(chainID string, tx basecoin.Tx) (basecoin.Tx, error) {
|
||||
func (c Chain) checkChain(chainID string, height uint64, tx basecoin.Tx) (basecoin.Tx, error) {
|
||||
// make sure it is a chaintx
|
||||
ctx, ok := tx.Unwrap().(ChainTx)
|
||||
if !ok {
|
||||
|
@ -60,5 +60,8 @@ func (c Chain) checkChain(chainID string, tx basecoin.Tx) (basecoin.Tx, error) {
|
|||
if ctx.ChainID != chainID {
|
||||
return tx, errors.ErrWrongChain(ctx.ChainID)
|
||||
}
|
||||
if ctx.ExpiresAt != 0 && ctx.ExpiresAt <= height {
|
||||
return tx, errors.ErrExpired()
|
||||
}
|
||||
return ctx.Tx, nil
|
||||
}
|
||||
|
|
|
@ -48,6 +48,7 @@ func TestChain(t *testing.T) {
|
|||
assert := assert.New(t)
|
||||
msg := "got it"
|
||||
chainID := "my-chain"
|
||||
height := uint64(100)
|
||||
|
||||
raw := stack.NewRawTx([]byte{1, 2, 3, 4})
|
||||
cases := []struct {
|
||||
|
@ -55,13 +56,22 @@ func TestChain(t *testing.T) {
|
|||
valid bool
|
||||
errorMsg string
|
||||
}{
|
||||
// check the chain ids are validated
|
||||
{NewChainTx(chainID, 0, raw), true, ""},
|
||||
{NewChainTx("someone-else", 0, raw), false, "someone-else"},
|
||||
// non-matching chainid, or impossible chain id
|
||||
{NewChainTx("someone-else", 0, raw), false, "someone-else: Wrong chain"},
|
||||
{NewChainTx("Inval$$d:CH%%n", 0, raw), false, "Wrong chain"},
|
||||
// Wrong tx type
|
||||
{raw, false, "No chain id provided"},
|
||||
// Check different heights - must be 0 or higher than current height
|
||||
{NewChainTx(chainID, height+1, raw), true, ""},
|
||||
{NewChainTx(chainID, height, raw), false, "Tx expired"},
|
||||
{NewChainTx(chainID, 1, raw), false, "expired"},
|
||||
{NewChainTx(chainID, 0, raw), true, ""},
|
||||
}
|
||||
|
||||
// generic args here...
|
||||
ctx := stack.NewContext(chainID, 100, log.NewNopLogger())
|
||||
ctx := stack.NewContext(chainID, height, log.NewNopLogger())
|
||||
store := state.NewMemKVStore()
|
||||
|
||||
// build the stack
|
||||
|
|
|
@ -69,7 +69,12 @@ var (
|
|||
|
||||
//nolint - TxInner Functions
|
||||
func NewChainTx(chainID string, expires uint64, tx basecoin.Tx) basecoin.Tx {
|
||||
return (ChainTx{Tx: tx, ChainID: chainID}).Wrap()
|
||||
c := ChainTx{
|
||||
ChainID: chainID,
|
||||
ExpiresAt: expires,
|
||||
Tx: tx,
|
||||
}
|
||||
return c.Wrap()
|
||||
}
|
||||
func (c ChainTx) Wrap() basecoin.Tx {
|
||||
return basecoin.Tx{c}
|
||||
|
|
Loading…
Reference in New Issue