Enforce the expiration height in Chain middleware

This commit is contained in:
Ethan Frey 2017-07-10 12:36:30 +02:00
parent b6197a1c12
commit 765f52e402
4 changed files with 32 additions and 6 deletions

View File

@ -21,6 +21,7 @@ var (
errUnknownTxType = fmt.Errorf("Tx type unknown")
errInvalidFormat = fmt.Errorf("Invalid format")
errUnknownModule = fmt.Errorf("Unknown module")
errExpired = fmt.Errorf("Tx expired")
)
// some crazy reflection to unwrap any generated struct.
@ -130,3 +131,10 @@ func ErrTooLarge() TMError {
func IsTooLargeErr(err error) bool {
return IsSameError(errTooLarge, err)
}
func ErrExpired() TMError {
return WithCode(errExpired, abci.CodeType_Unauthorized)
}
func IsExpiredErr(err error) bool {
return IsSameError(errExpired, err)
}

View File

@ -26,7 +26,7 @@ var _ stack.Middleware = Chain{}
// CheckTx makes sure we are on the proper chain - fulfills Middlware interface
func (c Chain) CheckTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx, next basecoin.Checker) (res basecoin.Result, err error) {
stx, err := c.checkChain(ctx.ChainID(), tx)
stx, err := c.checkChain(ctx.ChainID(), ctx.BlockHeight(), tx)
if err != nil {
return res, err
}
@ -35,7 +35,7 @@ func (c Chain) CheckTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx
// DeliverTx makes sure we are on the proper chain - fulfills Middlware interface
func (c Chain) DeliverTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx, next basecoin.Deliver) (res basecoin.Result, err error) {
stx, err := c.checkChain(ctx.ChainID(), tx)
stx, err := c.checkChain(ctx.ChainID(), ctx.BlockHeight(), tx)
if err != nil {
return res, err
}
@ -43,7 +43,7 @@ func (c Chain) DeliverTx(ctx basecoin.Context, store state.KVStore, tx basecoin.
}
// checkChain makes sure the tx is a Chain Tx and is on the proper chain
func (c Chain) checkChain(chainID string, tx basecoin.Tx) (basecoin.Tx, error) {
func (c Chain) checkChain(chainID string, height uint64, tx basecoin.Tx) (basecoin.Tx, error) {
// make sure it is a chaintx
ctx, ok := tx.Unwrap().(ChainTx)
if !ok {
@ -60,5 +60,8 @@ func (c Chain) checkChain(chainID string, tx basecoin.Tx) (basecoin.Tx, error) {
if ctx.ChainID != chainID {
return tx, errors.ErrWrongChain(ctx.ChainID)
}
if ctx.ExpiresAt != 0 && ctx.ExpiresAt <= height {
return tx, errors.ErrExpired()
}
return ctx.Tx, nil
}

View File

@ -48,6 +48,7 @@ func TestChain(t *testing.T) {
assert := assert.New(t)
msg := "got it"
chainID := "my-chain"
height := uint64(100)
raw := stack.NewRawTx([]byte{1, 2, 3, 4})
cases := []struct {
@ -55,13 +56,22 @@ func TestChain(t *testing.T) {
valid bool
errorMsg string
}{
// check the chain ids are validated
{NewChainTx(chainID, 0, raw), true, ""},
{NewChainTx("someone-else", 0, raw), false, "someone-else"},
// non-matching chainid, or impossible chain id
{NewChainTx("someone-else", 0, raw), false, "someone-else: Wrong chain"},
{NewChainTx("Inval$$d:CH%%n", 0, raw), false, "Wrong chain"},
// Wrong tx type
{raw, false, "No chain id provided"},
// Check different heights - must be 0 or higher than current height
{NewChainTx(chainID, height+1, raw), true, ""},
{NewChainTx(chainID, height, raw), false, "Tx expired"},
{NewChainTx(chainID, 1, raw), false, "expired"},
{NewChainTx(chainID, 0, raw), true, ""},
}
// generic args here...
ctx := stack.NewContext(chainID, 100, log.NewNopLogger())
ctx := stack.NewContext(chainID, height, log.NewNopLogger())
store := state.NewMemKVStore()
// build the stack

View File

@ -69,7 +69,12 @@ var (
//nolint - TxInner Functions
func NewChainTx(chainID string, expires uint64, tx basecoin.Tx) basecoin.Tx {
return (ChainTx{Tx: tx, ChainID: chainID}).Wrap()
c := ChainTx{
ChainID: chainID,
ExpiresAt: expires,
Tx: tx,
}
return c.Wrap()
}
func (c ChainTx) Wrap() basecoin.Tx {
return basecoin.Tx{c}