Add filtered pagination (#6514)
* Add filtered pagination * Add example for filtered pagination * Fix typo * Fix counter * Fix paginate * fix accumulate * Fix example * rename count to numHits * Add default PageRequest * Add tests for filtered pagination * Fix filteredPaginate tests * Add test cases * Add example for filtered pagination * Add more test case * Add iterate error * Add error check for iterator * Update godoc * Update godoc Co-authored-by: sahith-narahari <sahithnarahari@gmail.com> Co-authored-by: Aaron Craelius <aaron@regen.network>
This commit is contained in:
parent
6a52c5a569
commit
0f0723dd78
|
@ -0,0 +1,113 @@
|
|||
package query
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/store/types"
|
||||
)
|
||||
|
||||
// FilteredPaginate does pagination of all the results in the PrefixStore based on the
|
||||
// provided PageRequest. onResult should be used to do actual unmarshaling and filter the results.
|
||||
// If key is provided, the pagination uses the optimized querying.
|
||||
// If offset is used, the pagination uses lazy filtering i.e., searches through all the records.
|
||||
// The accumulate parameter represents if the response is valid based on the offset given.
|
||||
// It will be false for the results (filtered) < offset and true for `offset > accumulate <= end`.
|
||||
// When accumulate is set to true the current result should be appended to the result set returned
|
||||
// to the client.
|
||||
func FilteredPaginate(
|
||||
prefixStore types.KVStore,
|
||||
req *PageRequest,
|
||||
onResult func(key []byte, value []byte, accumulate bool) (bool, error),
|
||||
) (*PageResponse, error) {
|
||||
|
||||
// if the PageRequest is nil, use default PageRequest
|
||||
if req == nil {
|
||||
req = &PageRequest{}
|
||||
}
|
||||
|
||||
offset := req.Offset
|
||||
key := req.Key
|
||||
limit := req.Limit
|
||||
countTotal := req.CountTotal
|
||||
|
||||
if offset > 0 && key != nil {
|
||||
return nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
|
||||
}
|
||||
|
||||
if limit == 0 {
|
||||
limit = defaultLimit
|
||||
|
||||
// count total results when the limit is zero/not supplied
|
||||
countTotal = true
|
||||
}
|
||||
|
||||
if len(key) != 0 {
|
||||
iterator := prefixStore.Iterator(key, nil)
|
||||
defer iterator.Close()
|
||||
|
||||
var numHits uint64
|
||||
var nextKey []byte
|
||||
|
||||
for ; iterator.Valid(); iterator.Next() {
|
||||
if numHits == limit {
|
||||
nextKey = iterator.Key()
|
||||
break
|
||||
}
|
||||
|
||||
if iterator.Error() != nil {
|
||||
return nil, iterator.Error()
|
||||
}
|
||||
|
||||
hit, err := onResult(iterator.Key(), iterator.Value(), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if hit {
|
||||
numHits++
|
||||
}
|
||||
}
|
||||
|
||||
return &PageResponse{
|
||||
NextKey: nextKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
iterator := prefixStore.Iterator(nil, nil)
|
||||
defer iterator.Close()
|
||||
|
||||
end := offset + limit
|
||||
|
||||
var numHits uint64
|
||||
var nextKey []byte
|
||||
|
||||
for ; iterator.Valid(); iterator.Next() {
|
||||
if iterator.Error() != nil {
|
||||
return nil, iterator.Error()
|
||||
}
|
||||
accumulate := numHits >= offset && numHits < end
|
||||
hit, err := onResult(iterator.Key(), iterator.Value(), accumulate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if hit {
|
||||
numHits++
|
||||
}
|
||||
|
||||
if numHits == end {
|
||||
nextKey = iterator.Key()
|
||||
|
||||
if !countTotal {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
res := &PageResponse{NextKey: nextKey}
|
||||
if countTotal {
|
||||
res.Total = numHits
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
|
@ -0,0 +1,178 @@
|
|||
package query_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/codec"
|
||||
"github.com/cosmos/cosmos-sdk/store/prefix"
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
"github.com/cosmos/cosmos-sdk/types/query"
|
||||
authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
|
||||
"github.com/cosmos/cosmos-sdk/x/bank/types"
|
||||
)
|
||||
|
||||
var addr1 = sdk.AccAddress([]byte("addr1"))
|
||||
|
||||
func TestFilteredPaginations(t *testing.T) {
|
||||
app, ctx, appCodec := setupTest()
|
||||
|
||||
var balances sdk.Coins
|
||||
|
||||
for i := 0; i < numBalances; i++ {
|
||||
denom := fmt.Sprintf("foo%ddenom", i)
|
||||
balances = append(balances, sdk.NewInt64Coin(denom, 100))
|
||||
}
|
||||
|
||||
for i := 0; i < 4; i++ {
|
||||
denom := fmt.Sprintf("test%ddenom", i)
|
||||
balances = append(balances, sdk.NewInt64Coin(denom, 250))
|
||||
}
|
||||
|
||||
addr1 := sdk.AccAddress([]byte("addr1"))
|
||||
acc1 := app.AccountKeeper.NewAccountWithAddress(ctx, addr1)
|
||||
app.AccountKeeper.SetAccount(ctx, acc1)
|
||||
require.NoError(t, app.BankKeeper.SetBalances(ctx, addr1, balances))
|
||||
store := ctx.KVStore(app.GetKey(authtypes.StoreKey))
|
||||
|
||||
// verify pagination with limit > total values
|
||||
pageReq := &query.PageRequest{Key: nil, Limit: 5, CountTotal: true}
|
||||
balances, res, err := execFilterPaginate(store, pageReq, appCodec)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
require.Equal(t, 4, len(balances))
|
||||
|
||||
// verify empty request
|
||||
balances, res, err = execFilterPaginate(store, nil, appCodec)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
require.Equal(t, 4, len(balances))
|
||||
require.Equal(t, uint64(4), res.Total)
|
||||
|
||||
// verify next key is returned
|
||||
pageReq = &query.PageRequest{Key: nil, Limit: 2, CountTotal: true}
|
||||
balances, res, err = execFilterPaginate(store, pageReq, appCodec)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
require.Equal(t, 2, len(balances))
|
||||
require.NotNil(t, res.NextKey)
|
||||
require.Equal(t, uint64(4), res.Total)
|
||||
|
||||
// use next key for query
|
||||
pageReq = &query.PageRequest{Key: res.NextKey, Limit: 2, CountTotal: true}
|
||||
balances, res, err = execFilterPaginate(store, pageReq, appCodec)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
require.Equal(t, 2, len(balances))
|
||||
require.NotNil(t, res.NextKey)
|
||||
|
||||
// verify both key and offset can't be given
|
||||
pageReq = &query.PageRequest{Key: res.NextKey, Limit: 1, Offset: 2, CountTotal: true}
|
||||
balances, res, err = execFilterPaginate(store, pageReq, appCodec)
|
||||
require.Error(t, err)
|
||||
|
||||
// verify default limit
|
||||
pageReq = &query.PageRequest{Key: nil, Limit: 0}
|
||||
balances, res, err = execFilterPaginate(store, pageReq, appCodec)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
require.Equal(t, 4, len(balances))
|
||||
require.Equal(t, uint64(4), res.Total)
|
||||
|
||||
// verify offset
|
||||
pageReq = &query.PageRequest{Offset: 2, Limit: 2}
|
||||
balances, res, err = execFilterPaginate(store, pageReq, appCodec)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
require.LessOrEqual(t, len(balances), 2)
|
||||
}
|
||||
|
||||
func ExampleFilteredPaginate() {
|
||||
app, ctx, appCodec := setupTest()
|
||||
|
||||
var balances sdk.Coins
|
||||
|
||||
for i := 0; i < numBalances; i++ {
|
||||
denom := fmt.Sprintf("foo%ddenom", i)
|
||||
balances = append(balances, sdk.NewInt64Coin(denom, 100))
|
||||
}
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
denom := fmt.Sprintf("test%ddenom", i)
|
||||
balances = append(balances, sdk.NewInt64Coin(denom, 250))
|
||||
}
|
||||
addr1 := sdk.AccAddress([]byte("addr1"))
|
||||
acc1 := app.AccountKeeper.NewAccountWithAddress(ctx, addr1)
|
||||
app.AccountKeeper.SetAccount(ctx, acc1)
|
||||
err := app.BankKeeper.SetBalances(ctx, addr1, balances)
|
||||
if err != nil { // should return no error
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
pageReq := &query.PageRequest{Key: nil, Limit: 1, CountTotal: true}
|
||||
store := ctx.KVStore(app.GetKey(authtypes.StoreKey))
|
||||
balancesStore := prefix.NewStore(store, types.BalancesPrefix)
|
||||
accountStore := prefix.NewStore(balancesStore, addr1.Bytes())
|
||||
|
||||
var balResult sdk.Coins
|
||||
res, err := query.FilteredPaginate(accountStore, pageReq, func(key []byte, value []byte, accumulate bool) (bool, error) {
|
||||
var bal sdk.Coin
|
||||
err := appCodec.UnmarshalBinaryBare(value, &bal)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// filter balances with amount greater than 100
|
||||
if bal.Amount.Int64() > int64(100) {
|
||||
if accumulate {
|
||||
balResult = append(balResult, bal)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
})
|
||||
|
||||
if err != nil { // should return no error
|
||||
fmt.Println(err)
|
||||
}
|
||||
fmt.Println(&types.QueryAllBalancesResponse{Balances: balResult, Res: res})
|
||||
// Output:
|
||||
// balances:<denom:"test0denom" amount:"250" > res:<next_key:"test0denom" total:5 >
|
||||
}
|
||||
|
||||
func execFilterPaginate(store sdk.KVStore, pageReq *query.PageRequest, appCodec codec.Marshaler) (balances sdk.Coins, res *query.PageResponse, err error) {
|
||||
balancesStore := prefix.NewStore(store, types.BalancesPrefix)
|
||||
accountStore := prefix.NewStore(balancesStore, addr1.Bytes())
|
||||
|
||||
var balResult sdk.Coins
|
||||
res, err = query.FilteredPaginate(accountStore, pageReq, func(key []byte, value []byte, accumulate bool) (bool, error) {
|
||||
var bal sdk.Coin
|
||||
err := appCodec.UnmarshalBinaryBare(value, &bal)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// filter balances with amount greater than 100
|
||||
if bal.Amount.Int64() > int64(100) {
|
||||
if accumulate {
|
||||
balResult = append(balResult, bal)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
})
|
||||
|
||||
return balResult, res, err
|
||||
}
|
|
@ -17,6 +17,12 @@ func Paginate(
|
|||
req *PageRequest,
|
||||
onResult func(key []byte, value []byte) error,
|
||||
) (*PageResponse, error) {
|
||||
|
||||
// if the PageRequest is nil, use default PageRequest
|
||||
if req == nil {
|
||||
req = &PageRequest{}
|
||||
}
|
||||
|
||||
offset := req.Offset
|
||||
key := req.Key
|
||||
limit := req.Limit
|
||||
|
@ -45,7 +51,9 @@ func Paginate(
|
|||
nextKey = iterator.Key()
|
||||
break
|
||||
}
|
||||
|
||||
if iterator.Error() != nil {
|
||||
return nil, iterator.Error()
|
||||
}
|
||||
err := onResult(iterator.Key(), iterator.Value())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -78,9 +86,15 @@ func Paginate(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if !countTotal {
|
||||
} else if count == end+1 {
|
||||
nextKey = iterator.Key()
|
||||
break
|
||||
|
||||
if !countTotal {
|
||||
break
|
||||
}
|
||||
}
|
||||
if iterator.Error() != nil {
|
||||
return nil, iterator.Error()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/codec"
|
||||
"github.com/cosmos/cosmos-sdk/store/prefix"
|
||||
authkeeper "github.com/cosmos/cosmos-sdk/x/auth/keeper"
|
||||
authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
|
||||
|
@ -36,7 +37,7 @@ const (
|
|||
)
|
||||
|
||||
func TestPagination(t *testing.T) {
|
||||
app, ctx := setupTest()
|
||||
app, ctx, _ := setupTest()
|
||||
queryHelper := baseapp.NewQueryServerTestHelper(ctx)
|
||||
types.RegisterQueryServer(queryHelper, app.BankKeeper)
|
||||
queryClient := types.NewQueryClient(queryHelper)
|
||||
|
@ -59,7 +60,7 @@ func TestPagination(t *testing.T) {
|
|||
res, err := queryClient.AllBalances(gocontext.Background(), request)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, res.Res.Total, uint64(numBalances))
|
||||
require.Nil(t, res.Res.NextKey)
|
||||
require.NotNil(t, res.Res.NextKey)
|
||||
require.LessOrEqual(t, res.Balances.Len(), defaultLimit)
|
||||
|
||||
t.Log("verify page request with limit > defaultLimit, returns less or equal to `limit` records")
|
||||
|
@ -77,7 +78,7 @@ func TestPagination(t *testing.T) {
|
|||
res, err = queryClient.AllBalances(gocontext.Background(), request)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, res.Balances.Len(), underLimit)
|
||||
require.Nil(t, res.Res.NextKey)
|
||||
require.NotNil(t, res.Res.NextKey)
|
||||
require.Equal(t, res.Res.Total, uint64(numBalances))
|
||||
|
||||
t.Log("verify paginate with custom limit and countTotal false")
|
||||
|
@ -144,7 +145,7 @@ func TestPagination(t *testing.T) {
|
|||
}
|
||||
|
||||
func ExamplePaginate() {
|
||||
app, ctx := setupTest()
|
||||
app, ctx, _ := setupTest()
|
||||
|
||||
var balances sdk.Coins
|
||||
|
||||
|
@ -181,10 +182,10 @@ func ExamplePaginate() {
|
|||
}
|
||||
fmt.Println(&types.QueryAllBalancesResponse{Balances: balResult, Res: res})
|
||||
// Output:
|
||||
// balances:<denom:"foo0denom" amount:"100" > res:<total:2 >
|
||||
// balances:<denom:"foo0denom" amount:"100" > res:<next_key:"foo1denom" total:2 >
|
||||
}
|
||||
|
||||
func setupTest() (*simapp.SimApp, sdk.Context) {
|
||||
func setupTest() (*simapp.SimApp, sdk.Context, codec.Marshaler) {
|
||||
app := simapp.Setup(false)
|
||||
ctx := app.BaseApp.NewContext(false, abci.Header{Height: 1})
|
||||
appCodec := app.AppCodec()
|
||||
|
@ -209,5 +210,5 @@ func setupTest() (*simapp.SimApp, sdk.Context) {
|
|||
app.GetSubspace(types.ModuleName), make(map[string]bool),
|
||||
)
|
||||
|
||||
return app, ctx
|
||||
return app, ctx, appCodec
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue