From 0ca2ff6702b0c73c1fedcfd4591635fdee8feb9f Mon Sep 17 00:00:00 2001 From: Ethan Frey Date: Thu, 7 May 2020 11:08:20 +0200 Subject: [PATCH] Make QueryPlugins configurable for keeper --- app/app.go | 5 +- x/wasm/internal/keeper/handler_plugin.go | 7 +- x/wasm/internal/keeper/keeper.go | 24 ++--- x/wasm/internal/keeper/keeper_test.go | 26 +++--- x/wasm/internal/keeper/mask_test.go | 4 +- x/wasm/internal/keeper/querier_test.go | 4 +- x/wasm/internal/keeper/query_plugins.go | 113 +++++++++++++---------- x/wasm/internal/keeper/test_common.go | 4 +- x/wasm/module_test.go | 2 +- 9 files changed, 104 insertions(+), 85 deletions(-) diff --git a/app/app.go b/app/app.go index 5c02689..9325ad2 100644 --- a/app/app.go +++ b/app/app.go @@ -224,8 +224,9 @@ func NewWasmApp( } wasmConfig := wasmWrap.Wasm - // The last argument can contain custom message handlers, if we want to allow any custom messages - app.wasmKeeper = wasm.NewKeeper(app.cdc, keys[wasm.StoreKey], app.accountKeeper, app.bankKeeper, wasmRouter, wasmDir, wasmConfig, nil) + // The last arguments can contain custom message handlers, and custom query handlers, + // if we want to allow any custom callbacks + app.wasmKeeper = wasm.NewKeeper(app.cdc, keys[wasm.StoreKey], app.accountKeeper, app.bankKeeper, wasmRouter, wasmDir, wasmConfig, nil, nil) // create evidence keeper with evidence router evidenceKeeper := evidence.NewKeeper( diff --git a/x/wasm/internal/keeper/handler_plugin.go b/x/wasm/internal/keeper/handler_plugin.go index 6909ad2..399993e 100644 --- a/x/wasm/internal/keeper/handler_plugin.go +++ b/x/wasm/internal/keeper/handler_plugin.go @@ -16,7 +16,7 @@ type MessageHandler struct { encoders MessageEncoders } -func NewMessageHandler(router sdk.Router, customEncoders MessageEncoders) MessageHandler { +func NewMessageHandler(router sdk.Router, customEncoders *MessageEncoders) MessageHandler { encoders := DefaultEncoders().Merge(customEncoders) return MessageHandler{ router: router, @@ -40,7 +40,10 @@ func DefaultEncoders() MessageEncoders { } } -func (e MessageEncoders) Merge(o MessageEncoders) MessageEncoders { +func (e MessageEncoders) Merge(o *MessageEncoders) MessageEncoders { + if o == nil { + return e + } if o.Bank != nil { e.Bank = o.Bank } diff --git a/x/wasm/internal/keeper/keeper.go b/x/wasm/internal/keeper/keeper.go index 6bff3f5..c77d301 100644 --- a/x/wasm/internal/keeper/keeper.go +++ b/x/wasm/internal/keeper/keeper.go @@ -33,9 +33,9 @@ type Keeper struct { accountKeeper auth.AccountKeeper bankKeeper bank.Keeper - wasmer wasm.Wasmer - queryMods QueryModules - messenger MessageHandler + wasmer wasm.Wasmer + queryPlugins QueryPlugins + messenger MessageHandler // queryGasLimit is the max wasm gas that can be spent on executing a query with a contract queryGasLimit uint64 } @@ -43,18 +43,14 @@ type Keeper struct { // NewKeeper creates a new contract Keeper instance // If customEncoders is non-nil, we can use this to override some of the message handler, especially custom func NewKeeper(cdc *codec.Codec, storeKey sdk.StoreKey, accountKeeper auth.AccountKeeper, bankKeeper bank.Keeper, - router sdk.Router, homeDir string, wasmConfig types.WasmConfig, customEncoders *MessageEncoders) Keeper { + router sdk.Router, homeDir string, wasmConfig types.WasmConfig, customEncoders *MessageEncoders, customPlugins *QueryPlugins) Keeper { wasmer, err := wasm.NewWasmer(filepath.Join(homeDir, "wasm"), wasmConfig.CacheSize) if err != nil { panic(err) } - if customEncoders == nil { - customEncoders = &MessageEncoders{} - } - messenger := NewMessageHandler(router, *customEncoders) - // TODO: make this configurable also - queryMods := DefaultQueryModules(bankKeeper) + messenger := NewMessageHandler(router, customEncoders) + queryPlugins := DefaultQueryPlugins(bankKeeper).Merge(customPlugins) return Keeper{ storeKey: storeKey, @@ -63,7 +59,7 @@ func NewKeeper(cdc *codec.Codec, storeKey sdk.StoreKey, accountKeeper auth.Accou accountKeeper: accountKeeper, bankKeeper: bankKeeper, messenger: messenger, - queryMods: queryMods, + queryPlugins: queryPlugins, queryGasLimit: wasmConfig.SmartQueryGasLimit, } } @@ -135,7 +131,7 @@ func (k Keeper) Instantiate(ctx sdk.Context, codeID uint64, creator sdk.AccAddre // prepare querier querier := QueryHandler{ Ctx: ctx, - Modules: k.queryMods, + Plugins: k.queryPlugins, } // instantiate wasm contract @@ -183,7 +179,7 @@ func (k Keeper) Execute(ctx sdk.Context, contractAddress sdk.AccAddress, caller // prepare querier querier := QueryHandler{ Ctx: ctx, - Modules: k.queryMods, + Plugins: k.queryPlugins, } gas := gasForContract(ctx) @@ -218,7 +214,7 @@ func (k Keeper) QuerySmart(ctx sdk.Context, contractAddr sdk.AccAddress, req []b // prepare querier querier := QueryHandler{ Ctx: ctx, - Modules: k.queryMods, + Plugins: k.queryPlugins, } queryResult, gasUsed, qErr := k.wasmer.Query(codeInfo.CodeHash, req, prefixStore, cosmwasmAPI, querier, gasForContract(ctx)) if qErr != nil { diff --git a/x/wasm/internal/keeper/keeper_test.go b/x/wasm/internal/keeper/keeper_test.go index 924ebf9..bd439d1 100644 --- a/x/wasm/internal/keeper/keeper_test.go +++ b/x/wasm/internal/keeper/keeper_test.go @@ -23,7 +23,7 @@ func TestNewKeeper(t *testing.T) { tempDir, err := ioutil.TempDir("", "wasm") require.NoError(t, err) defer os.RemoveAll(tempDir) - _, _, keeper := CreateTestInput(t, false, tempDir, nil) + _, _, keeper := CreateTestInput(t, false, tempDir, nil, nil) require.NotNil(t, keeper) } @@ -31,7 +31,7 @@ func TestCreate(t *testing.T) { tempDir, err := ioutil.TempDir("", "wasm") require.NoError(t, err) defer os.RemoveAll(tempDir) - ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil) + ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil, nil) deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) creator := createFakeFundedAccount(ctx, accKeeper, deposit) @@ -52,7 +52,7 @@ func TestCreateDuplicate(t *testing.T) { tempDir, err := ioutil.TempDir("", "wasm") require.NoError(t, err) defer os.RemoveAll(tempDir) - ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil) + ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil, nil) deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) creator := createFakeFundedAccount(ctx, accKeeper, deposit) @@ -83,7 +83,7 @@ func TestCreateWithSimulation(t *testing.T) { tempDir, err := ioutil.TempDir("", "wasm") require.NoError(t, err) defer os.RemoveAll(tempDir) - ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil) + ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil, nil) ctx = ctx.WithBlockHeader(abci.Header{Height: 1}). WithGasMeter(stypes.NewInfiniteGasMeter()) @@ -99,7 +99,7 @@ func TestCreateWithSimulation(t *testing.T) { require.Equal(t, uint64(1), contractID) // then try to create it in non-simulation mode (should not fail) - ctx, accKeeper, keeper = CreateTestInput(t, false, tempDir, nil) + ctx, accKeeper, keeper = CreateTestInput(t, false, tempDir, nil, nil) contractID, err = keeper.Create(ctx, creator, wasmCode, "https://github.com/cosmwasm/wasmd/blob/master/x/wasm/testdata/escrow.wasm", "confio/cosmwasm-opt:0.7.2") require.NoError(t, err) require.Equal(t, uint64(1), contractID) @@ -139,7 +139,7 @@ func TestCreateWithGzippedPayload(t *testing.T) { tempDir, err := ioutil.TempDir("", "wasm") require.NoError(t, err) defer os.RemoveAll(tempDir) - ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil) + ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil, nil) deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) creator := createFakeFundedAccount(ctx, accKeeper, deposit) @@ -162,7 +162,7 @@ func TestInstantiate(t *testing.T) { tempDir, err := ioutil.TempDir("", "wasm") require.NoError(t, err) defer os.RemoveAll(tempDir) - ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil) + ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil, nil) deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) creator := createFakeFundedAccount(ctx, accKeeper, deposit) @@ -206,7 +206,7 @@ func TestInstantiateWithNonExistingCodeID(t *testing.T) { tempDir, err := ioutil.TempDir("", "wasm") require.NoError(t, err) defer os.RemoveAll(tempDir) - ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil) + ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil, nil) deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) creator := createFakeFundedAccount(ctx, accKeeper, deposit) @@ -227,7 +227,7 @@ func TestExecute(t *testing.T) { tempDir, err := ioutil.TempDir("", "wasm") require.NoError(t, err) defer os.RemoveAll(tempDir) - ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil) + ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil, nil) deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) topUp := sdk.NewCoins(sdk.NewInt64Coin("denom", 5000)) @@ -304,7 +304,7 @@ func TestExecuteWithNonExistingAddress(t *testing.T) { tempDir, err := ioutil.TempDir("", "wasm") require.NoError(t, err) defer os.RemoveAll(tempDir) - ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil) + ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil, nil) deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) creator := createFakeFundedAccount(ctx, accKeeper, deposit.Add(deposit...)) @@ -319,7 +319,7 @@ func TestExecuteWithPanic(t *testing.T) { tempDir, err := ioutil.TempDir("", "wasm") require.NoError(t, err) defer os.RemoveAll(tempDir) - ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil) + ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil, nil) deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) topUp := sdk.NewCoins(sdk.NewInt64Coin("denom", 5000)) @@ -352,7 +352,7 @@ func TestExecuteWithCpuLoop(t *testing.T) { tempDir, err := ioutil.TempDir("", "wasm") require.NoError(t, err) defer os.RemoveAll(tempDir) - ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil) + ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil, nil) deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) topUp := sdk.NewCoins(sdk.NewInt64Coin("denom", 5000)) @@ -395,7 +395,7 @@ func TestExecuteWithStorageLoop(t *testing.T) { tempDir, err := ioutil.TempDir("", "wasm") require.NoError(t, err) defer os.RemoveAll(tempDir) - ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil) + ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil, nil) deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) topUp := sdk.NewCoins(sdk.NewInt64Coin("denom", 5000)) diff --git a/x/wasm/internal/keeper/mask_test.go b/x/wasm/internal/keeper/mask_test.go index 129e730..9e93033 100644 --- a/x/wasm/internal/keeper/mask_test.go +++ b/x/wasm/internal/keeper/mask_test.go @@ -37,7 +37,7 @@ func TestMaskReflectContractSend(t *testing.T) { tempDir, err := ioutil.TempDir("", "wasm") require.NoError(t, err) defer os.RemoveAll(tempDir) - ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, maskEncoders(MakeTestCodec())) + ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, maskEncoders(MakeTestCodec()), nil) deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) creator := createFakeFundedAccount(ctx, accKeeper, deposit) @@ -120,7 +120,7 @@ func TestMaskReflectCustom(t *testing.T) { tempDir, err := ioutil.TempDir("", "wasm") require.NoError(t, err) defer os.RemoveAll(tempDir) - ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, maskEncoders(MakeTestCodec())) + ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, maskEncoders(MakeTestCodec()), nil) deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) creator := createFakeFundedAccount(ctx, accKeeper, deposit) diff --git a/x/wasm/internal/keeper/querier_test.go b/x/wasm/internal/keeper/querier_test.go index d98b271..c90c7b2 100644 --- a/x/wasm/internal/keeper/querier_test.go +++ b/x/wasm/internal/keeper/querier_test.go @@ -19,7 +19,7 @@ func TestQueryContractState(t *testing.T) { tempDir, err := ioutil.TempDir("", "wasm") require.NoError(t, err) defer os.RemoveAll(tempDir) - ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil) + ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil, nil) deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 100000)) topUp := sdk.NewCoins(sdk.NewInt64Coin("denom", 5000)) @@ -147,7 +147,7 @@ func TestListContractByCodeOrdering(t *testing.T) { tempDir, err := ioutil.TempDir("", "wasm") require.NoError(t, err) defer os.RemoveAll(tempDir) - ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil) + ctx, accKeeper, keeper := CreateTestInput(t, false, tempDir, nil, nil) deposit := sdk.NewCoins(sdk.NewInt64Coin("denom", 1000000)) topUp := sdk.NewCoins(sdk.NewInt64Coin("denom", 500)) diff --git a/x/wasm/internal/keeper/query_plugins.go b/x/wasm/internal/keeper/query_plugins.go index 076616e..efcfd41 100644 --- a/x/wasm/internal/keeper/query_plugins.go +++ b/x/wasm/internal/keeper/query_plugins.go @@ -10,33 +10,17 @@ import ( type QueryHandler struct { Ctx sdk.Context - Modules QueryModules + Plugins QueryPlugins } var _ wasmTypes.Querier = QueryHandler{} -// Fill out more modules -// Rethink interfaces -type QueryModules struct { - Bank bank.ViewKeeper -} - -func DefaultQueryModules(bank bank.ViewKeeper) QueryModules { - return QueryModules{ - Bank: bank, - } -} - -//type QueryPlugins struct { -// Bank func(msg *wasmTypes.BankQuery) ([]byte, error) -// Custom func(msg json.RawMessage) ([]byte, error) -// Staking func(msg *wasmTypes.StakingQuery) ([]byte, error) -// Wasm func(msg *wasmTypes.WasmQuery) ([]byte, error) -//} - func (q QueryHandler) Query(request wasmTypes.QueryRequest) ([]byte, error) { if request.Bank != nil { - return q.QueryBank(request.Bank) + if q.Plugins.Bank == nil { + return nil, wasmTypes.UnsupportedRequest{"bank"} + } + return q.Plugins.Bank(q.Ctx, request.Bank) } // TODO: below if request.Custom != nil { @@ -51,34 +35,69 @@ func (q QueryHandler) Query(request wasmTypes.QueryRequest) ([]byte, error) { return nil, wasmTypes.Unknown{} } -func (q QueryHandler) QueryBank(request *wasmTypes.BankQuery) ([]byte, error) { - if request.AllBalances != nil { - addr, err := sdk.AccAddressFromBech32(request.AllBalances.Address) - if err != nil { - return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidAddress, request.AllBalances.Address) - } - coins := q.Modules.Bank.GetCoins(q.Ctx, addr) - res := wasmTypes.AllBalancesResponse{ - Amount: convertSdkCoinToWasmCoin(coins), - } - return json.Marshal(res) +type QueryPlugins struct { + Bank func(ctx sdk.Context, msg *wasmTypes.BankQuery) ([]byte, error) + Custom func(ctx sdk.Context, msg json.RawMessage) ([]byte, error) + Staking func(ctx sdk.Context, msg *wasmTypes.StakingQuery) ([]byte, error) + Wasm func(ctx sdk.Context, msg *wasmTypes.WasmQuery) ([]byte, error) +} + +func DefaultQueryPlugins(bank bank.ViewKeeper) QueryPlugins { + return QueryPlugins{ + Bank: BankQuerier(bank), } - if request.Balance != nil { - addr, err := sdk.AccAddressFromBech32(request.Balance.Address) - if err != nil { - return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidAddress, request.Balance.Address) - } - coins := q.Modules.Bank.GetCoins(q.Ctx, addr) - amount := coins.AmountOf(request.Balance.Denom) - res := wasmTypes.BalanceResponse{ - Amount: wasmTypes.Coin{ - Denom: request.Balance.Denom, - Amount: amount.String(), - }, - } - return json.Marshal(res) +} + +func (e QueryPlugins) Merge(o *QueryPlugins) QueryPlugins { + // only update if this is non-nil and then only set values + if o == nil { + return e + } + if o.Bank != nil { + e.Bank = o.Bank + } + if o.Custom != nil { + e.Custom = o.Custom + } + if o.Staking != nil { + e.Staking = o.Staking + } + if o.Wasm != nil { + e.Wasm = o.Wasm + } + return e +} + +func BankQuerier(bank bank.ViewKeeper) func(ctx sdk.Context, request *wasmTypes.BankQuery) ([]byte, error) { + return func(ctx sdk.Context, request *wasmTypes.BankQuery) ([]byte, error) { + if request.AllBalances != nil { + addr, err := sdk.AccAddressFromBech32(request.AllBalances.Address) + if err != nil { + return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidAddress, request.AllBalances.Address) + } + coins := bank.GetCoins(ctx, addr) + res := wasmTypes.AllBalancesResponse{ + Amount: convertSdkCoinToWasmCoin(coins), + } + return json.Marshal(res) + } + if request.Balance != nil { + addr, err := sdk.AccAddressFromBech32(request.Balance.Address) + if err != nil { + return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidAddress, request.Balance.Address) + } + coins := bank.GetCoins(ctx, addr) + amount := coins.AmountOf(request.Balance.Denom) + res := wasmTypes.BalanceResponse{ + Amount: wasmTypes.Coin{ + Denom: request.Balance.Denom, + Amount: amount.String(), + }, + } + return json.Marshal(res) + } + return nil, wasmTypes.UnsupportedRequest{"unknown BankQuery variant"} } - return nil, wasmTypes.UnsupportedRequest{"unknown BankQuery variant"} } func convertSdkCoinToWasmCoin(coins []sdk.Coin) wasmTypes.Coins { diff --git a/x/wasm/internal/keeper/test_common.go b/x/wasm/internal/keeper/test_common.go index 8da5f2f..5f6efb0 100644 --- a/x/wasm/internal/keeper/test_common.go +++ b/x/wasm/internal/keeper/test_common.go @@ -41,7 +41,7 @@ func MakeTestCodec() *codec.Codec { } // encoders can be nil to accept the defaults, or set it to override some of the message handlers (like default) -func CreateTestInput(t *testing.T, isCheckTx bool, tempDir string, encoders *MessageEncoders) (sdk.Context, auth.AccountKeeper, Keeper) { +func CreateTestInput(t *testing.T, isCheckTx bool, tempDir string, encoders *MessageEncoders, queriers *QueryPlugins) (sdk.Context, auth.AccountKeeper, Keeper) { keyContract := sdk.NewKVStoreKey(types.StoreKey) keyAcc := sdk.NewKVStoreKey(auth.StoreKey) keyParams := sdk.NewKVStoreKey(params.StoreKey) @@ -86,7 +86,7 @@ func CreateTestInput(t *testing.T, isCheckTx bool, tempDir string, encoders *Mes // Load default wasm config wasmConfig := wasmTypes.DefaultWasmConfig() - keeper := NewKeeper(cdc, keyContract, accountKeeper, bk, router, tempDir, wasmConfig, encoders) + keeper := NewKeeper(cdc, keyContract, accountKeeper, bk, router, tempDir, wasmConfig, encoders, queriers) // add wasm handler so we can loop-back (contracts calling contracts) router.AddRoute(wasmTypes.RouterKey, TestHandler(keeper)) diff --git a/x/wasm/module_test.go b/x/wasm/module_test.go index e82643d..4a6f2c8 100644 --- a/x/wasm/module_test.go +++ b/x/wasm/module_test.go @@ -34,7 +34,7 @@ func setupTest(t *testing.T) (testData, func()) { tempDir, err := ioutil.TempDir("", "wasm") require.NoError(t, err) - ctx, acctKeeper, keeper := CreateTestInput(t, false, tempDir, nil) + ctx, acctKeeper, keeper := CreateTestInput(t, false, tempDir, nil, nil) data := testData{ module: NewAppModule(keeper), ctx: ctx,