diff --git a/types/context.go b/types/context.go index 3c90b016a..85eb9e493 100644 --- a/types/context.go +++ b/types/context.go @@ -16,12 +16,13 @@ The intent of Context is for it to be an immutable object that can be cloned and updated cheaply with WithValue() and passed forward to the next decorator or handler. For example, - func MsgHandler(ctx Context, tx Tx) Result { + func MsgHandler(context Context, tx Tx) Result { ... - ctx = ctx.WithValue(key, value) + context = context.WithValue(key, value) ... } */ + type Context struct { context.Context pst *thePast @@ -171,6 +172,12 @@ func (c Context) WithTxBytes(txBytes []byte) Context { return c.withValue(contextKeyTxBytes, txBytes) } +func (c Context) CacheContext() (Context, func()) { + cms := c.multiStore().CacheMultiStore() + cc := c.WithMultiStore(cms) + return cc, cms.Write +} + //---------------------------------------- // thePast diff --git a/types/context_test.go b/types/context_test.go index 36d8099b9..b40e79dc2 100644 --- a/types/context_test.go +++ b/types/context_test.go @@ -3,6 +3,11 @@ package types_test import ( "testing" + "github.com/stretchr/testify/assert" + + dbm "github.com/tendermint/tmlibs/db" + + "github.com/cosmos/cosmos-sdk/store" "github.com/cosmos/cosmos-sdk/types" abci "github.com/tendermint/abci/types" ) @@ -18,3 +23,39 @@ func TestContextGetOpShouldNeverPanic(t *testing.T) { _, _ = ctx.GetOp(index) } } + +func defaultContext(key types.StoreKey) types.Context { + db := dbm.NewMemDB() + cms := store.NewCommitMultiStore(db) + cms.MountStoreWithDB(key, types.StoreTypeIAVL, db) + cms.LoadLatestVersion() + ctx := types.NewContext(cms, abci.Header{}, false, nil) + return ctx +} + +func TestCacheContext(t *testing.T) { + key := types.NewKVStoreKey(t.Name()) + k1 := []byte("hello") + v1 := []byte("world") + k2 := []byte("key") + v2 := []byte("value") + + ctx := defaultContext(key) + store := ctx.KVStore(key) + store.Set(k1, v1) + assert.Equal(t, v1, store.Get(k1)) + assert.Nil(t, store.Get(k2)) + + cctx, write := ctx.CacheContext() + cstore := cctx.KVStore(key) + assert.Equal(t, v1, cstore.Get(k1)) + assert.Nil(t, cstore.Get(k2)) + + cstore.Set(k2, v2) + assert.Equal(t, v2, cstore.Get(k2)) + assert.Nil(t, store.Get(k2)) + + write() + + assert.Equal(t, v2, store.Get(k2)) +}