Add IsValid check on sendCoins (#3212)
This commit is contained in:
parent
d3804774d4
commit
0be04aff86
|
@ -140,10 +140,11 @@ func (coins Coins) IsValid() bool {
|
||||||
case 1:
|
case 1:
|
||||||
return coins[0].IsPositive()
|
return coins[0].IsPositive()
|
||||||
default:
|
default:
|
||||||
// Check single coin case
|
// check single coin case
|
||||||
if !(Coins{coins[0]}).IsValid() {
|
if !(Coins{coins[0]}).IsValid() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
lowDenom := coins[0].Denom
|
lowDenom := coins[0].Denom
|
||||||
for _, coin := range coins[1:] {
|
for _, coin := range coins[1:] {
|
||||||
if coin.Denom <= lowDenom {
|
if coin.Denom <= lowDenom {
|
||||||
|
|
|
@ -194,8 +194,13 @@ func addCoins(ctx sdk.Context, am auth.AccountKeeper, addr sdk.AccAddress, amt s
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendCoins moves coins from one account to another
|
// SendCoins moves coins from one account to another
|
||||||
// NOTE: Make sure to revert state changes from tx on error
|
|
||||||
func sendCoins(ctx sdk.Context, am auth.AccountKeeper, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) (sdk.Tags, sdk.Error) {
|
func sendCoins(ctx sdk.Context, am auth.AccountKeeper, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) (sdk.Tags, sdk.Error) {
|
||||||
|
// Safety check ensuring that when sending coins the keeper must maintain the
|
||||||
|
// supply invariant.
|
||||||
|
if !amt.IsValid() {
|
||||||
|
return nil, sdk.ErrInvalidCoins(amt.String())
|
||||||
|
}
|
||||||
|
|
||||||
_, subTags, err := subtractCoins(ctx, am, fromAddr, amt)
|
_, subTags, err := subtractCoins(ctx, am, fromAddr, amt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -212,6 +217,12 @@ func sendCoins(ctx sdk.Context, am auth.AccountKeeper, fromAddr sdk.AccAddress,
|
||||||
// InputOutputCoins handles a list of inputs and outputs
|
// InputOutputCoins handles a list of inputs and outputs
|
||||||
// NOTE: Make sure to revert state changes from tx on error
|
// NOTE: Make sure to revert state changes from tx on error
|
||||||
func inputOutputCoins(ctx sdk.Context, am auth.AccountKeeper, inputs []Input, outputs []Output) (sdk.Tags, sdk.Error) {
|
func inputOutputCoins(ctx sdk.Context, am auth.AccountKeeper, inputs []Input, outputs []Output) (sdk.Tags, sdk.Error) {
|
||||||
|
// Safety check ensuring that when sending coins the keeper must maintain the
|
||||||
|
// supply invariant.
|
||||||
|
if err := ValidateInputsOutputs(inputs, outputs); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
allTags := sdk.EmptyTags()
|
allTags := sdk.EmptyTags()
|
||||||
|
|
||||||
for _, in := range inputs {
|
for _, in := range inputs {
|
||||||
|
|
|
@ -10,6 +10,19 @@ import (
|
||||||
dbm "github.com/tendermint/tendermint/libs/db"
|
dbm "github.com/tendermint/tendermint/libs/db"
|
||||||
"github.com/tendermint/tendermint/libs/log"
|
"github.com/tendermint/tendermint/libs/log"
|
||||||
|
|
||||||
|
codec "github.com/cosmos/cosmos-sdk/codec"
|
||||||
|
"github.com/cosmos/cosmos-sdk/store"
|
||||||
|
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||||
|
"github.com/cosmos/cosmos-sdk/x/auth"
|
||||||
|
"github.com/cosmos/cosmos-sdk/x/params"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testInput struct {
|
||||||
|
cdc *codec.Codec
|
||||||
|
ctx sdk.Context
|
||||||
|
ak auth.AccountKeeper
|
||||||
|
}
|
||||||
|
|
||||||
codec "github.com/cosmos/cosmos-sdk/codec"
|
codec "github.com/cosmos/cosmos-sdk/codec"
|
||||||
"github.com/cosmos/cosmos-sdk/store"
|
"github.com/cosmos/cosmos-sdk/store"
|
||||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||||
|
@ -108,7 +121,6 @@ func TestKeeper(t *testing.T) {
|
||||||
require.True(t, bankKeeper.GetCoins(ctx, addr).IsEqual(sdk.Coins{sdk.NewInt64Coin("barcoin", 21), sdk.NewInt64Coin("foocoin", 4)}))
|
require.True(t, bankKeeper.GetCoins(ctx, addr).IsEqual(sdk.Coins{sdk.NewInt64Coin("barcoin", 21), sdk.NewInt64Coin("foocoin", 4)}))
|
||||||
require.True(t, bankKeeper.GetCoins(ctx, addr2).IsEqual(sdk.Coins{sdk.NewInt64Coin("barcoin", 7), sdk.NewInt64Coin("foocoin", 6)}))
|
require.True(t, bankKeeper.GetCoins(ctx, addr2).IsEqual(sdk.Coins{sdk.NewInt64Coin("barcoin", 7), sdk.NewInt64Coin("foocoin", 6)}))
|
||||||
require.True(t, bankKeeper.GetCoins(ctx, addr3).IsEqual(sdk.Coins{sdk.NewInt64Coin("barcoin", 2), sdk.NewInt64Coin("foocoin", 5)}))
|
require.True(t, bankKeeper.GetCoins(ctx, addr3).IsEqual(sdk.Coins{sdk.NewInt64Coin("barcoin", 2), sdk.NewInt64Coin("foocoin", 5)}))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSendKeeper(t *testing.T) {
|
func TestSendKeeper(t *testing.T) {
|
||||||
|
@ -146,8 +158,8 @@ func TestSendKeeper(t *testing.T) {
|
||||||
require.True(t, sendKeeper.GetCoins(ctx, addr).IsEqual(sdk.Coins{sdk.NewInt64Coin("foocoin", 10)}))
|
require.True(t, sendKeeper.GetCoins(ctx, addr).IsEqual(sdk.Coins{sdk.NewInt64Coin("foocoin", 10)}))
|
||||||
require.True(t, sendKeeper.GetCoins(ctx, addr2).IsEqual(sdk.Coins{sdk.NewInt64Coin("foocoin", 5)}))
|
require.True(t, sendKeeper.GetCoins(ctx, addr2).IsEqual(sdk.Coins{sdk.NewInt64Coin("foocoin", 5)}))
|
||||||
|
|
||||||
_, err2 := sendKeeper.SendCoins(ctx, addr, addr2, sdk.Coins{sdk.NewInt64Coin("foocoin", 50)})
|
_, err := sendKeeper.SendCoins(ctx, addr, addr2, sdk.Coins{sdk.NewInt64Coin("foocoin", 50)})
|
||||||
assert.Implements(t, (*sdk.Error)(nil), err2)
|
require.Implements(t, (*sdk.Error)(nil), err)
|
||||||
require.True(t, sendKeeper.GetCoins(ctx, addr).IsEqual(sdk.Coins{sdk.NewInt64Coin("foocoin", 10)}))
|
require.True(t, sendKeeper.GetCoins(ctx, addr).IsEqual(sdk.Coins{sdk.NewInt64Coin("foocoin", 10)}))
|
||||||
require.True(t, sendKeeper.GetCoins(ctx, addr2).IsEqual(sdk.Coins{sdk.NewInt64Coin("foocoin", 5)}))
|
require.True(t, sendKeeper.GetCoins(ctx, addr2).IsEqual(sdk.Coins{sdk.NewInt64Coin("foocoin", 5)}))
|
||||||
|
|
||||||
|
@ -156,6 +168,11 @@ func TestSendKeeper(t *testing.T) {
|
||||||
require.True(t, sendKeeper.GetCoins(ctx, addr).IsEqual(sdk.Coins{sdk.NewInt64Coin("barcoin", 20), sdk.NewInt64Coin("foocoin", 5)}))
|
require.True(t, sendKeeper.GetCoins(ctx, addr).IsEqual(sdk.Coins{sdk.NewInt64Coin("barcoin", 20), sdk.NewInt64Coin("foocoin", 5)}))
|
||||||
require.True(t, sendKeeper.GetCoins(ctx, addr2).IsEqual(sdk.Coins{sdk.NewInt64Coin("barcoin", 10), sdk.NewInt64Coin("foocoin", 10)}))
|
require.True(t, sendKeeper.GetCoins(ctx, addr2).IsEqual(sdk.Coins{sdk.NewInt64Coin("barcoin", 10), sdk.NewInt64Coin("foocoin", 10)}))
|
||||||
|
|
||||||
|
// validate coins with invalid denoms or negative values cannot be sent
|
||||||
|
// NOTE: We must use the Coin literal as the constructor does not allow
|
||||||
|
// negative values.
|
||||||
|
_, err = sendKeeper.SendCoins(ctx, addr, addr2, sdk.Coins{sdk.Coin{"FOOCOIN", sdk.NewInt(-5)}})
|
||||||
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestViewKeeper(t *testing.T) {
|
func TestViewKeeper(t *testing.T) {
|
||||||
|
|
|
@ -37,25 +37,8 @@ func (msg MsgSend) ValidateBasic() sdk.Error {
|
||||||
if len(msg.Outputs) == 0 {
|
if len(msg.Outputs) == 0 {
|
||||||
return ErrNoOutputs(DefaultCodespace).TraceSDK("")
|
return ErrNoOutputs(DefaultCodespace).TraceSDK("")
|
||||||
}
|
}
|
||||||
// make sure all inputs and outputs are individually valid
|
|
||||||
var totalIn, totalOut sdk.Coins
|
return ValidateInputsOutputs(msg.Inputs, msg.Outputs)
|
||||||
for _, in := range msg.Inputs {
|
|
||||||
if err := in.ValidateBasic(); err != nil {
|
|
||||||
return err.TraceSDK("")
|
|
||||||
}
|
|
||||||
totalIn = totalIn.Plus(in.Coins)
|
|
||||||
}
|
|
||||||
for _, out := range msg.Outputs {
|
|
||||||
if err := out.ValidateBasic(); err != nil {
|
|
||||||
return err.TraceSDK("")
|
|
||||||
}
|
|
||||||
totalOut = totalOut.Plus(out.Coins)
|
|
||||||
}
|
|
||||||
// make sure inputs and outputs match
|
|
||||||
if !totalIn.IsEqual(totalOut) {
|
|
||||||
return sdk.ErrInvalidCoins(totalIn.String()).TraceSDK("inputs and outputs don't match")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Implements Msg.
|
// Implements Msg.
|
||||||
|
@ -170,3 +153,33 @@ func NewOutput(addr sdk.AccAddress, coins sdk.Coins) Output {
|
||||||
}
|
}
|
||||||
return output
|
return output
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Auxiliary
|
||||||
|
|
||||||
|
// ValidateInputsOutputs validates that each respective input and output is
|
||||||
|
// valid and that the sum of inputs is equal to the sum of outputs.
|
||||||
|
func ValidateInputsOutputs(inputs []Input, outputs []Output) sdk.Error {
|
||||||
|
var totalIn, totalOut sdk.Coins
|
||||||
|
|
||||||
|
for _, in := range inputs {
|
||||||
|
if err := in.ValidateBasic(); err != nil {
|
||||||
|
return err.TraceSDK("")
|
||||||
|
}
|
||||||
|
totalIn = totalIn.Plus(in.Coins)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, out := range outputs {
|
||||||
|
if err := out.ValidateBasic(); err != nil {
|
||||||
|
return err.TraceSDK("")
|
||||||
|
}
|
||||||
|
totalOut = totalOut.Plus(out.Coins)
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure inputs and outputs match
|
||||||
|
if !totalIn.IsEqual(totalOut) {
|
||||||
|
return sdk.ErrInvalidCoins(totalIn.String()).TraceSDK("inputs and outputs don't match")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue