From 0f0723dd78df0ebbfcda6169b78a5bf34f3bd3e0 Mon Sep 17 00:00:00 2001 From: Anil Kumar Kammari Date: Wed, 1 Jul 2020 19:34:26 +0530 Subject: [PATCH] Add filtered pagination (#6514) * Add filtered pagination * Add example for filtered pagination * Fix typo * Fix counter * Fix paginate * fix accumulate * Fix example * rename count to numHits * Add default PageRequest * Add tests for filtered pagination * Fix filteredPaginate tests * Add test cases * Add example for filtered pagination * Add more test case * Add iterate error * Add error check for iterator * Update godoc * Update godoc Co-authored-by: sahith-narahari Co-authored-by: Aaron Craelius --- types/query/filtered_pagination.go | 113 +++++++++++++++ types/query/filtered_pagination_test.go | 178 ++++++++++++++++++++++++ types/query/pagination.go | 20 ++- types/query/pagination_test.go | 15 +- 4 files changed, 316 insertions(+), 10 deletions(-) create mode 100644 types/query/filtered_pagination.go create mode 100644 types/query/filtered_pagination_test.go diff --git a/types/query/filtered_pagination.go b/types/query/filtered_pagination.go new file mode 100644 index 000000000..2a7ee925b --- /dev/null +++ b/types/query/filtered_pagination.go @@ -0,0 +1,113 @@ +package query + +import ( + "fmt" + + "github.com/cosmos/cosmos-sdk/store/types" +) + +// FilteredPaginate does pagination of all the results in the PrefixStore based on the +// provided PageRequest. onResult should be used to do actual unmarshaling and filter the results. +// If key is provided, the pagination uses the optimized querying. +// If offset is used, the pagination uses lazy filtering i.e., searches through all the records. +// The accumulate parameter represents if the response is valid based on the offset given. +// It will be false for the results (filtered) < offset and true for `offset > accumulate <= end`. +// When accumulate is set to true the current result should be appended to the result set returned +// to the client. +func FilteredPaginate( + prefixStore types.KVStore, + req *PageRequest, + onResult func(key []byte, value []byte, accumulate bool) (bool, error), +) (*PageResponse, error) { + + // if the PageRequest is nil, use default PageRequest + if req == nil { + req = &PageRequest{} + } + + offset := req.Offset + key := req.Key + limit := req.Limit + countTotal := req.CountTotal + + if offset > 0 && key != nil { + return nil, fmt.Errorf("invalid request, either offset or key is expected, got both") + } + + if limit == 0 { + limit = defaultLimit + + // count total results when the limit is zero/not supplied + countTotal = true + } + + if len(key) != 0 { + iterator := prefixStore.Iterator(key, nil) + defer iterator.Close() + + var numHits uint64 + var nextKey []byte + + for ; iterator.Valid(); iterator.Next() { + if numHits == limit { + nextKey = iterator.Key() + break + } + + if iterator.Error() != nil { + return nil, iterator.Error() + } + + hit, err := onResult(iterator.Key(), iterator.Value(), true) + if err != nil { + return nil, err + } + + if hit { + numHits++ + } + } + + return &PageResponse{ + NextKey: nextKey, + }, nil + } + + iterator := prefixStore.Iterator(nil, nil) + defer iterator.Close() + + end := offset + limit + + var numHits uint64 + var nextKey []byte + + for ; iterator.Valid(); iterator.Next() { + if iterator.Error() != nil { + return nil, iterator.Error() + } + accumulate := numHits >= offset && numHits < end + hit, err := onResult(iterator.Key(), iterator.Value(), accumulate) + if err != nil { + return nil, err + } + + if hit { + numHits++ + } + + if numHits == end { + nextKey = iterator.Key() + + if !countTotal { + break + } + } + } + + res := &PageResponse{NextKey: nextKey} + if countTotal { + res.Total = numHits + } + + return res, nil +} diff --git a/types/query/filtered_pagination_test.go b/types/query/filtered_pagination_test.go new file mode 100644 index 000000000..fc4e6cf13 --- /dev/null +++ b/types/query/filtered_pagination_test.go @@ -0,0 +1,178 @@ +package query_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/cosmos/cosmos-sdk/codec" + "github.com/cosmos/cosmos-sdk/store/prefix" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/query" + authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + "github.com/cosmos/cosmos-sdk/x/bank/types" +) + +var addr1 = sdk.AccAddress([]byte("addr1")) + +func TestFilteredPaginations(t *testing.T) { + app, ctx, appCodec := setupTest() + + var balances sdk.Coins + + for i := 0; i < numBalances; i++ { + denom := fmt.Sprintf("foo%ddenom", i) + balances = append(balances, sdk.NewInt64Coin(denom, 100)) + } + + for i := 0; i < 4; i++ { + denom := fmt.Sprintf("test%ddenom", i) + balances = append(balances, sdk.NewInt64Coin(denom, 250)) + } + + addr1 := sdk.AccAddress([]byte("addr1")) + acc1 := app.AccountKeeper.NewAccountWithAddress(ctx, addr1) + app.AccountKeeper.SetAccount(ctx, acc1) + require.NoError(t, app.BankKeeper.SetBalances(ctx, addr1, balances)) + store := ctx.KVStore(app.GetKey(authtypes.StoreKey)) + + // verify pagination with limit > total values + pageReq := &query.PageRequest{Key: nil, Limit: 5, CountTotal: true} + balances, res, err := execFilterPaginate(store, pageReq, appCodec) + + require.NoError(t, err) + require.NotNil(t, res) + require.Equal(t, 4, len(balances)) + + // verify empty request + balances, res, err = execFilterPaginate(store, nil, appCodec) + + require.NoError(t, err) + require.NotNil(t, res) + require.Equal(t, 4, len(balances)) + require.Equal(t, uint64(4), res.Total) + + // verify next key is returned + pageReq = &query.PageRequest{Key: nil, Limit: 2, CountTotal: true} + balances, res, err = execFilterPaginate(store, pageReq, appCodec) + + require.NoError(t, err) + require.NotNil(t, res) + require.Equal(t, 2, len(balances)) + require.NotNil(t, res.NextKey) + require.Equal(t, uint64(4), res.Total) + + // use next key for query + pageReq = &query.PageRequest{Key: res.NextKey, Limit: 2, CountTotal: true} + balances, res, err = execFilterPaginate(store, pageReq, appCodec) + + require.NoError(t, err) + require.NotNil(t, res) + require.Equal(t, 2, len(balances)) + require.NotNil(t, res.NextKey) + + // verify both key and offset can't be given + pageReq = &query.PageRequest{Key: res.NextKey, Limit: 1, Offset: 2, CountTotal: true} + balances, res, err = execFilterPaginate(store, pageReq, appCodec) + require.Error(t, err) + + // verify default limit + pageReq = &query.PageRequest{Key: nil, Limit: 0} + balances, res, err = execFilterPaginate(store, pageReq, appCodec) + + require.NoError(t, err) + require.NotNil(t, res) + require.Equal(t, 4, len(balances)) + require.Equal(t, uint64(4), res.Total) + + // verify offset + pageReq = &query.PageRequest{Offset: 2, Limit: 2} + balances, res, err = execFilterPaginate(store, pageReq, appCodec) + + require.NoError(t, err) + require.NotNil(t, res) + require.LessOrEqual(t, len(balances), 2) +} + +func ExampleFilteredPaginate() { + app, ctx, appCodec := setupTest() + + var balances sdk.Coins + + for i := 0; i < numBalances; i++ { + denom := fmt.Sprintf("foo%ddenom", i) + balances = append(balances, sdk.NewInt64Coin(denom, 100)) + } + + for i := 0; i < 5; i++ { + denom := fmt.Sprintf("test%ddenom", i) + balances = append(balances, sdk.NewInt64Coin(denom, 250)) + } + addr1 := sdk.AccAddress([]byte("addr1")) + acc1 := app.AccountKeeper.NewAccountWithAddress(ctx, addr1) + app.AccountKeeper.SetAccount(ctx, acc1) + err := app.BankKeeper.SetBalances(ctx, addr1, balances) + if err != nil { // should return no error + fmt.Println(err) + } + + pageReq := &query.PageRequest{Key: nil, Limit: 1, CountTotal: true} + store := ctx.KVStore(app.GetKey(authtypes.StoreKey)) + balancesStore := prefix.NewStore(store, types.BalancesPrefix) + accountStore := prefix.NewStore(balancesStore, addr1.Bytes()) + + var balResult sdk.Coins + res, err := query.FilteredPaginate(accountStore, pageReq, func(key []byte, value []byte, accumulate bool) (bool, error) { + var bal sdk.Coin + err := appCodec.UnmarshalBinaryBare(value, &bal) + if err != nil { + return false, err + } + + // filter balances with amount greater than 100 + if bal.Amount.Int64() > int64(100) { + if accumulate { + balResult = append(balResult, bal) + } + + return true, nil + } + + return false, nil + }) + + if err != nil { // should return no error + fmt.Println(err) + } + fmt.Println(&types.QueryAllBalancesResponse{Balances: balResult, Res: res}) + // Output: + // balances: res: +} + +func execFilterPaginate(store sdk.KVStore, pageReq *query.PageRequest, appCodec codec.Marshaler) (balances sdk.Coins, res *query.PageResponse, err error) { + balancesStore := prefix.NewStore(store, types.BalancesPrefix) + accountStore := prefix.NewStore(balancesStore, addr1.Bytes()) + + var balResult sdk.Coins + res, err = query.FilteredPaginate(accountStore, pageReq, func(key []byte, value []byte, accumulate bool) (bool, error) { + var bal sdk.Coin + err := appCodec.UnmarshalBinaryBare(value, &bal) + if err != nil { + return false, err + } + + // filter balances with amount greater than 100 + if bal.Amount.Int64() > int64(100) { + if accumulate { + balResult = append(balResult, bal) + } + + return true, nil + } + + return false, nil + }) + + return balResult, res, err +} diff --git a/types/query/pagination.go b/types/query/pagination.go index 595fe29b9..0505e708b 100644 --- a/types/query/pagination.go +++ b/types/query/pagination.go @@ -17,6 +17,12 @@ func Paginate( req *PageRequest, onResult func(key []byte, value []byte) error, ) (*PageResponse, error) { + + // if the PageRequest is nil, use default PageRequest + if req == nil { + req = &PageRequest{} + } + offset := req.Offset key := req.Key limit := req.Limit @@ -45,7 +51,9 @@ func Paginate( nextKey = iterator.Key() break } - + if iterator.Error() != nil { + return nil, iterator.Error() + } err := onResult(iterator.Key(), iterator.Value()) if err != nil { return nil, err @@ -78,9 +86,15 @@ func Paginate( if err != nil { return nil, err } - } else if !countTotal { + } else if count == end+1 { nextKey = iterator.Key() - break + + if !countTotal { + break + } + } + if iterator.Error() != nil { + return nil, iterator.Error() } } diff --git a/types/query/pagination_test.go b/types/query/pagination_test.go index 191fc8a64..d4585103b 100644 --- a/types/query/pagination_test.go +++ b/types/query/pagination_test.go @@ -5,6 +5,7 @@ import ( "fmt" "testing" + "github.com/cosmos/cosmos-sdk/codec" "github.com/cosmos/cosmos-sdk/store/prefix" authkeeper "github.com/cosmos/cosmos-sdk/x/auth/keeper" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" @@ -36,7 +37,7 @@ const ( ) func TestPagination(t *testing.T) { - app, ctx := setupTest() + app, ctx, _ := setupTest() queryHelper := baseapp.NewQueryServerTestHelper(ctx) types.RegisterQueryServer(queryHelper, app.BankKeeper) queryClient := types.NewQueryClient(queryHelper) @@ -59,7 +60,7 @@ func TestPagination(t *testing.T) { res, err := queryClient.AllBalances(gocontext.Background(), request) require.NoError(t, err) require.Equal(t, res.Res.Total, uint64(numBalances)) - require.Nil(t, res.Res.NextKey) + require.NotNil(t, res.Res.NextKey) require.LessOrEqual(t, res.Balances.Len(), defaultLimit) t.Log("verify page request with limit > defaultLimit, returns less or equal to `limit` records") @@ -77,7 +78,7 @@ func TestPagination(t *testing.T) { res, err = queryClient.AllBalances(gocontext.Background(), request) require.NoError(t, err) require.Equal(t, res.Balances.Len(), underLimit) - require.Nil(t, res.Res.NextKey) + require.NotNil(t, res.Res.NextKey) require.Equal(t, res.Res.Total, uint64(numBalances)) t.Log("verify paginate with custom limit and countTotal false") @@ -144,7 +145,7 @@ func TestPagination(t *testing.T) { } func ExamplePaginate() { - app, ctx := setupTest() + app, ctx, _ := setupTest() var balances sdk.Coins @@ -181,10 +182,10 @@ func ExamplePaginate() { } fmt.Println(&types.QueryAllBalancesResponse{Balances: balResult, Res: res}) // Output: - // balances: res: + // balances: res: } -func setupTest() (*simapp.SimApp, sdk.Context) { +func setupTest() (*simapp.SimApp, sdk.Context, codec.Marshaler) { app := simapp.Setup(false) ctx := app.BaseApp.NewContext(false, abci.Header{Height: 1}) appCodec := app.AppCodec() @@ -209,5 +210,5 @@ func setupTest() (*simapp.SimApp, sdk.Context) { app.GetSubspace(types.ModuleName), make(map[string]bool), ) - return app, ctx + return app, ctx, appCodec }