Merge pull request #165 from tendermint/feature/164-tx-expiration

Add optional expiration height to tx
This commit is contained in:
Ethan Frey 2017-07-11 13:51:10 +02:00 committed by GitHub
commit af9169f629
20 changed files with 204 additions and 105 deletions

View File

@ -31,6 +31,7 @@ type Basecoin struct {
state *sm.State state *sm.State
cacheState *sm.State cacheState *sm.State
handler basecoin.Handler handler basecoin.Handler
height uint64
logger log.Logger logger log.Logger
} }
@ -45,6 +46,7 @@ func NewBasecoin(handler basecoin.Handler, eyesCli *eyes.Client, logger log.Logg
eyesCli: eyesCli, eyesCli: eyesCli,
state: state, state: state,
cacheState: nil, cacheState: nil,
height: 0,
logger: logger, logger: logger,
} }
} }
@ -73,6 +75,7 @@ func (app *Basecoin) Info() abci.ResponseInfo {
if err != nil { if err != nil {
cmn.PanicCrisis(err) cmn.PanicCrisis(err)
} }
app.height = resp.LastBlockHeight
return abci.ResponseInfo{ return abci.ResponseInfo{
Data: fmt.Sprintf("Basecoin v%v", version.Version), Data: fmt.Sprintf("Basecoin v%v", version.Version),
LastBlockHeight: resp.LastBlockHeight, LastBlockHeight: resp.LastBlockHeight,
@ -111,6 +114,7 @@ func (app *Basecoin) DeliverTx(txBytes []byte) abci.Result {
cache := app.state.CacheWrap() cache := app.state.CacheWrap()
ctx := stack.NewContext( ctx := stack.NewContext(
app.state.GetChainID(), app.state.GetChainID(),
app.height,
app.logger.With("call", "delivertx"), app.logger.With("call", "delivertx"),
) )
res, err := app.handler.DeliverTx(ctx, cache, tx) res, err := app.handler.DeliverTx(ctx, cache, tx)
@ -134,6 +138,7 @@ func (app *Basecoin) CheckTx(txBytes []byte) abci.Result {
// TODO: can we abstract this setup and commit logic?? // TODO: can we abstract this setup and commit logic??
ctx := stack.NewContext( ctx := stack.NewContext(
app.state.GetChainID(), app.state.GetChainID(),
app.height,
app.logger.With("call", "checktx"), app.logger.With("call", "checktx"),
) )
// checktx generally shouldn't touch the state, but we don't care // checktx generally shouldn't touch the state, but we don't care
@ -187,6 +192,7 @@ func (app *Basecoin) InitChain(validators []*abci.Validator) {
// BeginBlock - ABCI // BeginBlock - ABCI
func (app *Basecoin) BeginBlock(hash []byte, header *abci.Header) { func (app *Basecoin) BeginBlock(hash []byte, header *abci.Header) {
app.height++
// for _, plugin := range app.plugins.GetList() { // for _, plugin := range app.plugins.GetList() {
// plugin.BeginBlock(app.state, hash, header) // plugin.BeginBlock(app.state, hash, header)
// } // }

View File

@ -44,7 +44,7 @@ func (at *appTest) getTx(seq int, coins coin.Coins) basecoin.Tx {
in := []coin.TxInput{{Address: at.acctIn.Actor(), Coins: coins, Sequence: seq}} in := []coin.TxInput{{Address: at.acctIn.Actor(), Coins: coins, Sequence: seq}}
out := []coin.TxOutput{{Address: at.acctOut.Actor(), Coins: coins}} out := []coin.TxOutput{{Address: at.acctOut.Actor(), Coins: coins}}
tx := coin.NewSendTx(in, out) tx := coin.NewSendTx(in, out)
tx = base.NewChainTx(at.chainID, tx) tx = base.NewChainTx(at.chainID, 0, tx)
stx := auth.NewMulti(tx) stx := auth.NewMulti(tx)
auth.Sign(stx, at.acctIn.Key) auth.Sign(stx, at.acctIn.Key)
return stx.Wrap() return stx.Wrap()

View File

@ -34,6 +34,7 @@ const (
FlagAmount = "amount" FlagAmount = "amount"
FlagFee = "fee" FlagFee = "fee"
FlagGas = "gas" FlagGas = "gas"
FlagExpires = "expires"
FlagSequence = "sequence" FlagSequence = "sequence"
) )
@ -42,7 +43,8 @@ func init() {
flags.String(FlagTo, "", "Destination address for the bits") flags.String(FlagTo, "", "Destination address for the bits")
flags.String(FlagAmount, "", "Coins to send in the format <amt><coin>,<amt><coin>...") flags.String(FlagAmount, "", "Coins to send in the format <amt><coin>,<amt><coin>...")
flags.String(FlagFee, "0mycoin", "Coins for the transaction fee of the format <amt><coin>") flags.String(FlagFee, "0mycoin", "Coins for the transaction fee of the format <amt><coin>")
flags.Int64(FlagGas, 0, "Amount of gas for this transaction") flags.Uint64(FlagGas, 0, "Amount of gas for this transaction")
flags.Uint64(FlagExpires, 0, "Block height at which this tx expires")
flags.Int(FlagSequence, -1, "Sequence number for this transaction") flags.Int(FlagSequence, -1, "Sequence number for this transaction")
} }
@ -63,7 +65,12 @@ func doSendTx(cmd *cobra.Command, args []string) error {
// TODO: make this more flexible for middleware // TODO: make this more flexible for middleware
// add the chain info // add the chain info
tx = base.NewChainTx(commands.GetChainID(), tx) tx, err = WrapChainTx(tx)
if err != nil {
return err
}
// Note: this is single sig (no multi sig yet)
stx := auth.NewSig(tx) stx := auth.NewSig(tx)
// Sign if needed and post. This it the work-horse // Sign if needed and post. This it the work-horse
@ -76,6 +83,17 @@ func doSendTx(cmd *cobra.Command, args []string) error {
return txcmd.OutputTx(bres) return txcmd.OutputTx(bres)
} }
// WrapChainTx will wrap the tx with a ChainTx from the standard flags
func WrapChainTx(tx basecoin.Tx) (res basecoin.Tx, err error) {
expires := viper.GetInt64(FlagExpires)
chain := commands.GetChainID()
if chain == "" {
return res, errors.New("No chain-id provided")
}
res = base.NewChainTx(chain, uint64(expires), tx)
return res, nil
}
func readSendTxFlags() (tx basecoin.Tx, err error) { func readSendTxFlags() (tx basecoin.Tx, err error) {
// parse to address // parse to address
chain, to, err := parseChainAddress(viper.GetString(FlagTo)) chain, to, err := parseChainAddress(viper.GetString(FlagTo))

View File

@ -36,4 +36,5 @@ type Context interface {
IsParent(ctx Context) bool IsParent(ctx Context) bool
Reset() Context Reset() Context
ChainID() string ChainID() string
BlockHeight() uint64
} }

View File

@ -4,13 +4,12 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
"github.com/tendermint/basecoin"
"github.com/tendermint/light-client/commands"
txcmd "github.com/tendermint/light-client/commands/txs" txcmd "github.com/tendermint/light-client/commands/txs"
"github.com/tendermint/basecoin"
bcmd "github.com/tendermint/basecoin/cmd/basecli/commands"
"github.com/tendermint/basecoin/docs/guide/counter/plugins/counter" "github.com/tendermint/basecoin/docs/guide/counter/plugins/counter"
"github.com/tendermint/basecoin/modules/auth" "github.com/tendermint/basecoin/modules/auth"
"github.com/tendermint/basecoin/modules/base"
"github.com/tendermint/basecoin/modules/coin" "github.com/tendermint/basecoin/modules/coin"
) )
@ -57,7 +56,10 @@ func counterTx(cmd *cobra.Command, args []string) error {
// TODO: make this more flexible for middleware // TODO: make this more flexible for middleware
// add the chain info // add the chain info
tx = base.NewChainTx(commands.GetChainID(), tx) tx, err = bcmd.WrapChainTx(tx)
if err != nil {
return err
}
stx := auth.NewSig(tx) stx := auth.NewSig(tx)
// Sign if needed and post. This it the work-horse // Sign if needed and post. This it the work-horse

View File

@ -42,7 +42,7 @@ func TestCounterPlugin(t *testing.T) {
// Deliver a CounterTx // Deliver a CounterTx
DeliverCounterTx := func(valid bool, counterFee coin.Coins, inputSequence int) abci.Result { DeliverCounterTx := func(valid bool, counterFee coin.Coins, inputSequence int) abci.Result {
tx := NewTx(valid, counterFee, inputSequence) tx := NewTx(valid, counterFee, inputSequence)
tx = base.NewChainTx(chainID, tx) tx = base.NewChainTx(chainID, 0, tx)
stx := auth.NewSig(tx) stx := auth.NewSig(tx)
auth.Sign(stx, acct.Key) auth.Sign(stx, acct.Key)
txBytes := wire.BinaryBytes(stx.Wrap()) txBytes := wire.BinaryBytes(stx.Wrap())

View File

@ -21,6 +21,7 @@ var (
errUnknownTxType = fmt.Errorf("Tx type unknown") errUnknownTxType = fmt.Errorf("Tx type unknown")
errInvalidFormat = fmt.Errorf("Invalid format") errInvalidFormat = fmt.Errorf("Invalid format")
errUnknownModule = fmt.Errorf("Unknown module") errUnknownModule = fmt.Errorf("Unknown module")
errExpired = fmt.Errorf("Tx expired")
) )
// some crazy reflection to unwrap any generated struct. // some crazy reflection to unwrap any generated struct.
@ -130,3 +131,10 @@ func ErrTooLarge() TMError {
func IsTooLargeErr(err error) bool { func IsTooLargeErr(err error) bool {
return IsSameError(errTooLarge, err) return IsSameError(errTooLarge, err)
} }
func ErrExpired() TMError {
return WithCode(errExpired, abci.CodeType_Unauthorized)
}
func IsExpiredErr(err error) bool {
return IsSameError(errExpired, err)
}

View File

@ -40,7 +40,7 @@ func BenchmarkCheckOneSig(b *testing.B) {
h := makeHandler() h := makeHandler()
store := state.NewMemKVStore() store := state.NewMemKVStore()
for i := 1; i <= b.N; i++ { for i := 1; i <= b.N; i++ {
ctx := stack.NewContext("foo", log.NewNopLogger()) ctx := stack.NewContext("foo", 100, log.NewNopLogger())
_, err := h.DeliverTx(ctx, store, tx) _, err := h.DeliverTx(ctx, store, tx)
// never should error // never should error
if err != nil { if err != nil {
@ -64,7 +64,7 @@ func benchmarkCheckMultiSig(b *testing.B, cnt int) {
h := makeHandler() h := makeHandler()
store := state.NewMemKVStore() store := state.NewMemKVStore()
for i := 1; i <= b.N; i++ { for i := 1; i <= b.N; i++ {
ctx := stack.NewContext("foo", log.NewNopLogger()) ctx := stack.NewContext("foo", 100, log.NewNopLogger())
_, err := h.DeliverTx(ctx, store, tx) _, err := h.DeliverTx(ctx, store, tx)
// never should error // never should error
if err != nil { if err != nil {

View File

@ -18,7 +18,7 @@ func TestSignatureChecks(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
// generic args // generic args
ctx := stack.NewContext("test-chain", log.NewNopLogger()) ctx := stack.NewContext("test-chain", 100, log.NewNopLogger())
store := state.NewMemKVStore() store := state.NewMemKVStore()
raw := stack.NewRawTx([]byte{1, 2, 3, 4}) raw := stack.NewRawTx([]byte{1, 2, 3, 4})

View File

@ -26,7 +26,7 @@ var _ stack.Middleware = Chain{}
// CheckTx makes sure we are on the proper chain - fulfills Middlware interface // CheckTx makes sure we are on the proper chain - fulfills Middlware interface
func (c Chain) CheckTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx, next basecoin.Checker) (res basecoin.Result, err error) { func (c Chain) CheckTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx, next basecoin.Checker) (res basecoin.Result, err error) {
stx, err := c.checkChain(ctx.ChainID(), tx) stx, err := c.checkChainTx(ctx.ChainID(), ctx.BlockHeight(), tx)
if err != nil { if err != nil {
return res, err return res, err
} }
@ -35,21 +35,34 @@ func (c Chain) CheckTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx
// DeliverTx makes sure we are on the proper chain - fulfills Middlware interface // DeliverTx makes sure we are on the proper chain - fulfills Middlware interface
func (c Chain) DeliverTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx, next basecoin.Deliver) (res basecoin.Result, err error) { func (c Chain) DeliverTx(ctx basecoin.Context, store state.KVStore, tx basecoin.Tx, next basecoin.Deliver) (res basecoin.Result, err error) {
stx, err := c.checkChain(ctx.ChainID(), tx) stx, err := c.checkChainTx(ctx.ChainID(), ctx.BlockHeight(), tx)
if err != nil { if err != nil {
return res, err return res, err
} }
return next.DeliverTx(ctx, store, stx) return next.DeliverTx(ctx, store, stx)
} }
// checkChain makes sure the tx is a Chain Tx and is on the proper chain // checkChainTx makes sure the tx is a Chain Tx, it is on the proper chain,
func (c Chain) checkChain(chainID string, tx basecoin.Tx) (basecoin.Tx, error) { // and it has not expired.
func (c Chain) checkChainTx(chainID string, height uint64, tx basecoin.Tx) (basecoin.Tx, error) {
// make sure it is a chaintx
ctx, ok := tx.Unwrap().(ChainTx) ctx, ok := tx.Unwrap().(ChainTx)
if !ok { if !ok {
return tx, errors.ErrNoChain() return tx, errors.ErrNoChain()
} }
// basic validation
err := ctx.ValidateBasic()
if err != nil {
return tx, err
}
// compare against state
if ctx.ChainID != chainID { if ctx.ChainID != chainID {
return tx, errors.ErrWrongChain(ctx.ChainID) return tx, errors.ErrWrongChain(ctx.ChainID)
} }
if ctx.ExpiresAt != 0 && ctx.ExpiresAt <= height {
return tx, errors.ErrExpired()
}
return ctx.Tx, nil return ctx.Tx, nil
} }

View File

@ -13,10 +13,42 @@ import (
"github.com/tendermint/basecoin/state" "github.com/tendermint/basecoin/state"
) )
func TestChainValidate(t *testing.T) {
assert := assert.New(t)
raw := stack.NewRawTx([]byte{1, 2, 3, 4})
cases := []struct {
name string
expires uint64
valid bool
}{
{"hello", 0, true},
{"one-2-three", 123, true},
{"super!@#$%@", 0, false},
{"WISH_2_be", 14, true},
{"öhhh", 54, false},
}
for _, tc := range cases {
tx := NewChainTx(tc.name, tc.expires, raw)
err := tx.ValidateBasic()
if tc.valid {
assert.Nil(err, "%s: %+v", tc.name, err)
} else {
assert.NotNil(err, tc.name)
}
}
empty := NewChainTx("okay", 0, basecoin.Tx{})
err := empty.ValidateBasic()
assert.NotNil(err)
}
func TestChain(t *testing.T) { func TestChain(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
msg := "got it" msg := "got it"
chainID := "my-chain" chainID := "my-chain"
height := uint64(100)
raw := stack.NewRawTx([]byte{1, 2, 3, 4}) raw := stack.NewRawTx([]byte{1, 2, 3, 4})
cases := []struct { cases := []struct {
@ -24,13 +56,22 @@ func TestChain(t *testing.T) {
valid bool valid bool
errorMsg string errorMsg string
}{ }{
{NewChainTx(chainID, raw), true, ""}, // check the chain ids are validated
{NewChainTx("someone-else", raw), false, "someone-else"}, {NewChainTx(chainID, 0, raw), true, ""},
// non-matching chainid, or impossible chain id
{NewChainTx("someone-else", 0, raw), false, "someone-else: Wrong chain"},
{NewChainTx("Inval$$d:CH%%n", 0, raw), false, "Wrong chain"},
// Wrong tx type
{raw, false, "No chain id provided"}, {raw, false, "No chain id provided"},
// Check different heights - must be 0 or higher than current height
{NewChainTx(chainID, height+1, raw), true, ""},
{NewChainTx(chainID, height, raw), false, "Tx expired"},
{NewChainTx(chainID, 1, raw), false, "expired"},
{NewChainTx(chainID, 0, raw), true, ""},
} }
// generic args here... // generic args here...
ctx := stack.NewContext(chainID, log.NewNopLogger()) ctx := stack.NewContext(chainID, height, log.NewNopLogger())
store := state.NewMemKVStore() store := state.NewMemKVStore()
// build the stack // build the stack

View File

@ -1,6 +1,11 @@
package base package base
import "github.com/tendermint/basecoin" import (
"regexp"
"github.com/tendermint/basecoin"
"github.com/tendermint/basecoin/errors"
)
// nolint // nolint
const ( const (
@ -51,20 +56,44 @@ func (mt MultiTx) ValidateBasic() error {
// ChainTx locks this tx to one chainTx, wrap with this before signing // ChainTx locks this tx to one chainTx, wrap with this before signing
type ChainTx struct { type ChainTx struct {
Tx basecoin.Tx `json:"tx"` // name of chain, must be [A-Za-z0-9_-]+
ChainID string `json:"chain_id"` ChainID string `json:"chain_id"`
// block height at which it is no longer valid, 0 means no expiration
ExpiresAt uint64 `json:"expires_at"`
Tx basecoin.Tx `json:"tx"`
} }
var _ basecoin.TxInner = &ChainTx{} var _ basecoin.TxInner = &ChainTx{}
//nolint - TxInner Functions var (
func NewChainTx(chainID string, tx basecoin.Tx) basecoin.Tx { chainPattern = regexp.MustCompile("^[A-Za-z0-9_-]+$")
return (ChainTx{Tx: tx, ChainID: chainID}).Wrap() )
// NewChainTx wraps a particular tx with the ChainTx wrapper,
// to enforce chain and height
func NewChainTx(chainID string, expires uint64, tx basecoin.Tx) basecoin.Tx {
c := ChainTx{
ChainID: chainID,
ExpiresAt: expires,
Tx: tx,
}
return c.Wrap()
} }
//nolint - TxInner Functions
func (c ChainTx) Wrap() basecoin.Tx { func (c ChainTx) Wrap() basecoin.Tx {
return basecoin.Tx{c} return basecoin.Tx{c}
} }
func (c ChainTx) ValidateBasic() error { func (c ChainTx) ValidateBasic() error {
if c.ChainID == "" {
return errors.ErrNoChain()
}
if !chainPattern.MatchString(c.ChainID) {
return errors.ErrWrongChain(c.ChainID)
}
if c.Tx.Empty() {
return errors.ErrUnknownTxType(c.Tx)
}
// TODO: more checks? chainID? // TODO: more checks? chainID?
return c.Tx.ValidateBasic() return c.Tx.ValidateBasic()
} }

View File

@ -25,7 +25,7 @@ func TestEncoding(t *testing.T) {
}{ }{
{raw}, {raw},
{NewMultiTx(raw, raw2)}, {NewMultiTx(raw, raw2)},
{NewChainTx("foobar", raw)}, {NewChainTx("foobar", 0, raw)},
} }
for idx, tc := range cases { for idx, tc := range cases {

View File

@ -34,7 +34,7 @@ func BenchmarkSimpleTransfer(b *testing.B) {
// now, loop... // now, loop...
for i := 1; i <= b.N; i++ { for i := 1; i <= b.N; i++ {
ctx := stack.MockContext("foo").WithPermissions(sender) ctx := stack.MockContext("foo", 100).WithPermissions(sender)
tx := makeSimpleTx(sender, receiver, Coins{{"mycoin", 2}}, i) tx := makeSimpleTx(sender, receiver, Coins{{"mycoin", 2}}, i)
_, err := h.DeliverTx(ctx, store, tx) _, err := h.DeliverTx(ctx, store, tx)
// never should error // never should error

View File

@ -74,7 +74,7 @@ func TestHandlerValidation(t *testing.T) {
} }
for i, tc := range cases { for i, tc := range cases {
ctx := stack.MockContext("base-chain").WithPermissions(tc.perms...) ctx := stack.MockContext("base-chain", 100).WithPermissions(tc.perms...)
_, err := checkTx(ctx, tc.tx) _, err := checkTx(ctx, tc.tx)
if tc.valid { if tc.valid {
assert.Nil(err, "%d: %+v", i, err) assert.Nil(err, "%d: %+v", i, err)
@ -148,7 +148,7 @@ func TestDeliverTx(t *testing.T) {
require.Nil(err, "%d: %+v", i, err) require.Nil(err, "%d: %+v", i, err)
} }
ctx := stack.MockContext("base-chain").WithPermissions(tc.perms...) ctx := stack.MockContext("base-chain", 100).WithPermissions(tc.perms...)
_, err := h.DeliverTx(ctx, store, tc.tx) _, err := h.DeliverTx(ctx, store, tc.tx)
if len(tc.final) > 0 { // valid if len(tc.final) > 0 { // valid
assert.Nil(err, "%d: %+v", i, err) assert.Nil(err, "%d: %+v", i, err)

View File

@ -1,9 +1,6 @@
package stack package stack
import ( import (
"bytes"
"math/rand"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/tendermint/tmlibs/log" "github.com/tendermint/tmlibs/log"
@ -16,28 +13,20 @@ import (
type nonce int64 type nonce int64
type secureContext struct { type secureContext struct {
id nonce app string
chain string // this exposes the log.Logger and all other methods we don't override
app string naiveContext
perms []basecoin.Actor
log.Logger
} }
// NewContext - create a new secureContext // NewContext - create a new secureContext
func NewContext(chain string, logger log.Logger) basecoin.Context { func NewContext(chain string, height uint64, logger log.Logger) basecoin.Context {
return secureContext{ return secureContext{
id: nonce(rand.Int63()), naiveContext: MockContext(chain, height).(naiveContext),
chain: chain,
Logger: logger,
} }
} }
var _ basecoin.Context = secureContext{} var _ basecoin.Context = secureContext{}
func (c secureContext) ChainID() string {
return c.chain
}
// WithPermissions will panic if they try to set permission without the proper app // WithPermissions will panic if they try to set permission without the proper app
func (c secureContext) WithPermissions(perms ...basecoin.Actor) basecoin.Context { func (c secureContext) WithPermissions(perms ...basecoin.Actor) basecoin.Context {
// the guard makes sure you only set permissions for the app you are inside // the guard makes sure you only set permissions for the app you are inside
@ -50,32 +39,18 @@ func (c secureContext) WithPermissions(perms ...basecoin.Actor) basecoin.Context
} }
return secureContext{ return secureContext{
id: c.id, app: c.app,
chain: c.chain, naiveContext: c.naiveContext.WithPermissions(perms...).(naiveContext),
app: c.app,
perms: append(c.perms, perms...),
Logger: c.Logger,
} }
} }
func (c secureContext) HasPermission(perm basecoin.Actor) bool { // Reset should clear out all permissions,
for _, p := range c.perms { // but carry on knowledge that this is a child
if perm.App == p.App && bytes.Equal(perm.Address, p.Address) { func (c secureContext) Reset() basecoin.Context {
return true return secureContext{
} app: c.app,
naiveContext: c.naiveContext.Reset().(naiveContext),
} }
return false
}
func (c secureContext) GetPermissions(chain, app string) (res []basecoin.Actor) {
for _, p := range c.perms {
if chain == p.ChainID {
if app == "" || app == p.App {
res = append(res, p)
}
}
}
return res
} }
// IsParent ensures that this is derived from the given secureClient // IsParent ensures that this is derived from the given secureClient
@ -84,19 +59,7 @@ func (c secureContext) IsParent(other basecoin.Context) bool {
if !ok { if !ok {
return false return false
} }
return c.id == so.id return c.naiveContext.IsParent(so.naiveContext)
}
// Reset should clear out all permissions,
// but carry on knowledge that this is a child
func (c secureContext) Reset() basecoin.Context {
return secureContext{
id: c.id,
chain: c.chain,
app: c.app,
perms: nil,
Logger: c.Logger,
}
} }
// withApp is a private method that we can use to properly set the // withApp is a private method that we can use to properly set the
@ -107,11 +70,8 @@ func withApp(ctx basecoin.Context, app string) basecoin.Context {
return ctx return ctx
} }
return secureContext{ return secureContext{
id: sc.id, app: app,
chain: sc.chain, naiveContext: sc.naiveContext,
app: app,
perms: sc.perms,
Logger: sc.Logger,
} }
} }

View File

@ -15,7 +15,7 @@ import (
func TestOK(t *testing.T) { func TestOK(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
ctx := NewContext("test-chain", log.NewNopLogger()) ctx := NewContext("test-chain", 20, log.NewNopLogger())
store := state.NewMemKVStore() store := state.NewMemKVStore()
data := "this looks okay" data := "this looks okay"
tx := basecoin.Tx{} tx := basecoin.Tx{}
@ -33,7 +33,7 @@ func TestOK(t *testing.T) {
func TestFail(t *testing.T) { func TestFail(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
ctx := NewContext("test-chain", log.NewNopLogger()) ctx := NewContext("test-chain", 20, log.NewNopLogger())
store := state.NewMemKVStore() store := state.NewMemKVStore()
msg := "big problem" msg := "big problem"
tx := basecoin.Tx{} tx := basecoin.Tx{}
@ -53,7 +53,7 @@ func TestFail(t *testing.T) {
func TestPanic(t *testing.T) { func TestPanic(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
ctx := NewContext("test-chain", log.NewNopLogger()) ctx := NewContext("test-chain", 20, log.NewNopLogger())
store := state.NewMemKVStore() store := state.NewMemKVStore()
msg := "system crash!" msg := "system crash!"
tx := basecoin.Tx{} tx := basecoin.Tx{}

View File

@ -22,7 +22,7 @@ func TestPermissionSandbox(t *testing.T) {
require := require.New(t) require := require.New(t)
// generic args // generic args
ctx := NewContext("test-chain", log.NewNopLogger()) ctx := NewContext("test-chain", 20, log.NewNopLogger())
store := state.NewMemKVStore() store := state.NewMemKVStore()
raw := NewRawTx([]byte{1, 2, 3, 4}) raw := NewRawTx([]byte{1, 2, 3, 4})
rawBytes, err := data.ToWire(raw) rawBytes, err := data.ToWire(raw)

View File

@ -2,40 +2,55 @@ package stack
import ( import (
"bytes" "bytes"
"math/rand"
"github.com/tendermint/tmlibs/log" "github.com/tendermint/tmlibs/log"
"github.com/tendermint/basecoin" "github.com/tendermint/basecoin"
) )
type mockContext struct { type naiveContext struct {
perms []basecoin.Actor id nonce
chain string chain string
height uint64
perms []basecoin.Actor
log.Logger log.Logger
} }
func MockContext(chain string) basecoin.Context { // MockContext returns a simple, non-checking context for test cases.
return mockContext{ //
// Always use NewContext() for production code to sandbox malicious code better
func MockContext(chain string, height uint64) basecoin.Context {
return naiveContext{
id: nonce(rand.Int63()),
chain: chain, chain: chain,
height: height,
Logger: log.NewNopLogger(), Logger: log.NewNopLogger(),
} }
} }
var _ basecoin.Context = mockContext{} var _ basecoin.Context = naiveContext{}
func (c mockContext) ChainID() string { func (c naiveContext) ChainID() string {
return c.chain return c.chain
} }
func (c naiveContext) BlockHeight() uint64 {
return c.height
}
// WithPermissions will panic if they try to set permission without the proper app // WithPermissions will panic if they try to set permission without the proper app
func (c mockContext) WithPermissions(perms ...basecoin.Actor) basecoin.Context { func (c naiveContext) WithPermissions(perms ...basecoin.Actor) basecoin.Context {
return mockContext{ return naiveContext{
id: c.id,
chain: c.chain,
height: c.height,
perms: append(c.perms, perms...), perms: append(c.perms, perms...),
Logger: c.Logger, Logger: c.Logger,
} }
} }
func (c mockContext) HasPermission(perm basecoin.Actor) bool { func (c naiveContext) HasPermission(perm basecoin.Actor) bool {
for _, p := range c.perms { for _, p := range c.perms {
if perm.App == p.App && bytes.Equal(perm.Address, p.Address) { if perm.App == p.App && bytes.Equal(perm.Address, p.Address) {
return true return true
@ -44,7 +59,7 @@ func (c mockContext) HasPermission(perm basecoin.Actor) bool {
return false return false
} }
func (c mockContext) GetPermissions(chain, app string) (res []basecoin.Actor) { func (c naiveContext) GetPermissions(chain, app string) (res []basecoin.Actor) {
for _, p := range c.perms { for _, p := range c.perms {
if chain == p.ChainID { if chain == p.ChainID {
if app == "" || app == p.App { if app == "" || app == p.App {
@ -56,15 +71,21 @@ func (c mockContext) GetPermissions(chain, app string) (res []basecoin.Actor) {
} }
// IsParent ensures that this is derived from the given secureClient // IsParent ensures that this is derived from the given secureClient
func (c mockContext) IsParent(other basecoin.Context) bool { func (c naiveContext) IsParent(other basecoin.Context) bool {
_, ok := other.(mockContext) nc, ok := other.(naiveContext)
return ok if !ok {
return false
}
return c.id == nc.id
} }
// Reset should clear out all permissions, // Reset should clear out all permissions,
// but carry on knowledge that this is a child // but carry on knowledge that this is a child
func (c mockContext) Reset() basecoin.Context { func (c naiveContext) Reset() basecoin.Context {
return mockContext{ return naiveContext{
id: c.id,
chain: c.chain,
height: c.height,
Logger: c.Logger, Logger: c.Logger,
} }
} }

View File

@ -17,7 +17,7 @@ func TestRecovery(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
// generic args here... // generic args here...
ctx := NewContext("test-chain", log.NewNopLogger()) ctx := NewContext("test-chain", 20, log.NewNopLogger())
store := state.NewMemKVStore() store := state.NewMemKVStore()
tx := basecoin.Tx{} tx := basecoin.Tx{}