diff --git a/CHANGELOG.md b/CHANGELOG.md index b565059d3..7e162a3a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -86,6 +86,7 @@ Ref: https://keepachangelog.com/en/1.0.0/ * (types) [\#10948](https://github.com/cosmos/cosmos-sdk/issues/10948) Add `app-db-backend` to the `app.toml` config to replace the compile-time `types.DBbackend` variable. * (authz)[\#11060](https://github.com/cosmos/cosmos-sdk/pull/11060) Support grant with no expire time. * (rosetta) [\#11590](https://github.com/cosmos/cosmos-sdk/pull/11590) Add fee suggestion for rosetta and enable offline mode. Also force set events about Fees to Success to pass reconciliation test. +* (types) [\#11959](https://github.com/cosmos/cosmos-sdk/pull/11959) Added `sdk.Coins.Find` helper method to find a coin by denom. ### API Breaking Changes diff --git a/types/coin.go b/types/coin.go index b6cf93813..c5bf96aea 100644 --- a/types/coin.go +++ b/types/coin.go @@ -703,28 +703,37 @@ func (coins Coins) AmountOf(denom string) Int { // AmountOfNoDenomValidation returns the amount of a denom from coins // without validating the denomination. func (coins Coins) AmountOfNoDenomValidation(denom string) Int { + if ok, c := coins.Find(denom); ok { + return c.Amount + } + return ZeroInt() +} + +// Find returns true and coin if the denom exists in coins. Otherwise it returns false +// and a zero coin. Uses binary search. +// CONTRACT: coins must be valid (sorted). +func (coins Coins) Find(denom string) (bool, Coin) { switch len(coins) { case 0: - return ZeroInt() + return false, Coin{} case 1: coin := coins[0] if coin.Denom == denom { - return coin.Amount + return true, coin } - return ZeroInt() + return false, Coin{} default: - // Binary search the amount of coins remaining midIdx := len(coins) / 2 // 2:1, 3:1, 4:2 coin := coins[midIdx] switch { case denom < coin.Denom: - return coins[:midIdx].AmountOfNoDenomValidation(denom) + return coins[:midIdx].Find(denom) case denom == coin.Denom: - return coin.Amount + return true, coin default: - return coins[midIdx+1:].AmountOfNoDenomValidation(denom) + return coins[midIdx+1:].Find(denom) } } } diff --git a/types/coin_test.go b/types/coin_test.go index f405457bc..7ee9c91c6 100644 --- a/types/coin_test.go +++ b/types/coin_test.go @@ -931,7 +931,8 @@ func (s *coinTestSuite) TestSortCoins() { } } -func (s *coinTestSuite) TestAmountOf() { +func (s *coinTestSuite) TestSearch() { + require := s.Require() case0 := sdk.Coins{} case1 := sdk.Coins{ sdk.NewInt64Coin("gold", 0), @@ -949,7 +950,7 @@ func (s *coinTestSuite) TestAmountOf() { sdk.NewInt64Coin("gas", 8), } - cases := []struct { + amountOfCases := []struct { coins sdk.Coins amountOf int64 amountOfSpace int64 @@ -964,13 +965,38 @@ func (s *coinTestSuite) TestAmountOf() { {case4, 0, 0, 8, 0, 0}, } - for _, tc := range cases { - s.Require().Equal(sdk.NewInt(tc.amountOfGAS), tc.coins.AmountOf("gas")) - s.Require().Equal(sdk.NewInt(tc.amountOfMINERAL), tc.coins.AmountOf("mineral")) - s.Require().Equal(sdk.NewInt(tc.amountOfTREE), tc.coins.AmountOf("tree")) - } + s.Run("AmountOf", func() { + for i, tc := range amountOfCases { + require.Equal(sdk.NewInt(tc.amountOfGAS), tc.coins.AmountOf("gas"), i) + require.Equal(sdk.NewInt(tc.amountOfMINERAL), tc.coins.AmountOf("mineral"), i) + require.Equal(sdk.NewInt(tc.amountOfTREE), tc.coins.AmountOf("tree"), i) + } + require.Panics(func() { amountOfCases[0].coins.AmountOf("10Invalid") }) + }) - s.Require().Panics(func() { cases[0].coins.AmountOf("10Invalid") }) + zeroCoin := sdk.Coin{} + findCases := []struct { + coins sdk.Coins + denom string + expectedOk bool + expectedCoin sdk.Coin + }{ + {case0, "any", false, zeroCoin}, + {case1, "other", false, zeroCoin}, + {case1, "gold", true, case1[0]}, + {case4, "gas", true, case4[0]}, + {case2, "gas", true, case2[0]}, + {case2, "mineral", true, case2[1]}, + {case2, "tree", true, case2[2]}, + {case2, "other", false, zeroCoin}, + } + s.Run("Find", func() { + for i, tc := range findCases { + ok, c := tc.coins.Find(tc.denom) + require.Equal(tc.expectedOk, ok, i) + require.Equal(tc.expectedCoin, c, i) + } + }) } func (s *coinTestSuite) TestCoinsIsAnyGTE() {