wormchain: tokenbridge: Clean up coin handling

Use `sdk.Coin` methods as much as possible rather than mixing with
big.Int and uint256.Int.
This commit is contained in:
Chirantan Ekbote 2022-08-25 17:24:14 +09:00 committed by Chirantan Ekbote
parent ce40d17c74
commit c53813ce37
3 changed files with 58 additions and 70 deletions

View File

@ -109,42 +109,39 @@ func (k msgServer) ExecuteVAA(goCtx context.Context, msg *types.MsgExecuteVAA) (
} }
} }
amount, err := types.Untruncate(unnormalizedAmount, meta) amt := sdk.NewCoin(identifier, sdk.NewIntFromBigInt(unnormalizedAmount))
if err := amt.Validate(); err != nil {
return nil, fmt.Errorf("%w: %s", types.ErrInvalidAmount, err)
}
amount, err := types.Untruncate(amt, meta)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to untruncate amount: %w", err)
} }
fee, err := types.Untruncate(unnormalizedFee, meta) f := sdk.NewCoin(identifier, sdk.NewIntFromBigInt(unnormalizedFee))
if err != nil { if err := f.Validate(); err != nil {
return nil, err return nil, fmt.Errorf("%w: %s", types.ErrInvalidFee, err)
} }
if fee.Sign() == -1 { fee, err := types.Untruncate(f, meta)
return nil, types.ErrNegativeFee if err != nil {
return nil, fmt.Errorf("failed to untruncate fee: %w", err)
}
if amount.IsLT(fee) {
return nil, types.ErrFeeTooHigh
} }
if wrapped { if wrapped {
err = k.bankKeeper.MintCoins(ctx, types.ModuleName, sdk.Coins{ if err := k.bankKeeper.MintCoins(ctx, types.ModuleName, sdk.Coins{amount}); err != nil {
{ return nil, fmt.Errorf("failed to mint coins (%s): %w", amount, err)
Denom: identifier,
Amount: sdk.NewIntFromBigInt(amount),
},
})
if err != nil {
return nil, err
} }
} }
moduleAccount := k.accountKeeper.GetModuleAddress(types.ModuleName) moduleAccount := k.accountKeeper.GetModuleAddress(types.ModuleName)
amtLessFees := sdk.Coins{ amtLessFees := amount.Sub(fee)
{
Denom: identifier,
Amount: sdk.NewIntFromBigInt(new(big.Int).Sub(amount, fee)),
},
}
err = k.bankKeeper.SendCoins(ctx, moduleAccount, to[:], amtLessFees) if err := k.bankKeeper.SendCoins(ctx, moduleAccount, to[:], sdk.Coins{amtLessFees}); err != nil {
if err != nil {
return nil, err return nil, err
} }
@ -153,15 +150,9 @@ func (k msgServer) ExecuteVAA(goCtx context.Context, msg *types.MsgExecuteVAA) (
return nil, err return nil, err
} }
// Transfer fee to tx sender if it is not 0 // Transfer fee to tx sender if it is not 0
if fee.Sign() == 1 { if fee.IsPositive() {
err = k.bankKeeper.SendCoins(ctx, moduleAccount, txSender, sdk.Coins{ if err := k.bankKeeper.SendCoins(ctx, moduleAccount, txSender, sdk.Coins{fee}); err != nil {
{ return nil, fmt.Errorf("failed to send fees (%s) to tx sender: %w", fee, err)
Denom: identifier,
Amount: sdk.NewIntFromBigInt(fee),
},
})
if err != nil {
return nil, err
} }
} }
@ -170,8 +161,8 @@ func (k msgServer) ExecuteVAA(goCtx context.Context, msg *types.MsgExecuteVAA) (
TokenAddress: tokenAddress[:], TokenAddress: tokenAddress[:],
To: sdk.AccAddress(to[:]).String(), To: sdk.AccAddress(to[:]).String(),
FeeRecipient: txSender.String(), FeeRecipient: txSender.String(),
Amount: amount.String(), Amount: amount.Amount.String(),
Fee: fee.String(), Fee: fee.Amount.String(),
LocalDenom: identifier, LocalDenom: identifier,
}) })
if err != nil { if err != nil {

View File

@ -4,13 +4,13 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/binary" "encoding/binary"
"fmt"
"math/big" "math/big"
"github.com/certusone/wormhole-chain/x/tokenbridge/types" "github.com/certusone/wormhole-chain/x/tokenbridge/types"
whtypes "github.com/certusone/wormhole-chain/x/wormhole/types" whtypes "github.com/certusone/wormhole-chain/x/wormhole/types"
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/holiman/uint256"
) )
func (k msgServer) Transfer(goCtx context.Context, msg *types.MsgTransfer) (*types.MsgTransferResponse, error) { func (k msgServer) Transfer(goCtx context.Context, msg *types.MsgTransfer) (*types.MsgTransferResponse, error) {
@ -49,31 +49,30 @@ func (k msgServer) Transfer(goCtx context.Context, msg *types.MsgTransfer) (*typ
} }
} }
bridgeBalance := new(big.Int).Set(k.bankKeeper.GetBalance(ctx, k.accountKeeper.GetModuleAddress(types.ModuleName), msg.Amount.Denom).Amount.BigInt()) bridgeBalance, err := types.Truncate(k.bankKeeper.GetBalance(ctx, moduleAddress, msg.Amount.Denom), meta)
amount := new(big.Int).Set(msg.Amount.Amount.BigInt())
fees := new(big.Int).Set(msg.Fee.Amount.BigInt())
truncAmount, err := types.Truncate(amount, meta)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to truncate bridge balance: %w", err)
}
amount, err := types.Truncate(msg.Amount, meta)
if err != nil {
return nil, fmt.Errorf("%w: %s", types.ErrInvalidAmount, err)
} }
truncFees, err := types.Truncate(fees, meta) fees, err := types.Truncate(msg.Fee, meta)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("%w: %s", types.ErrInvalidFee, err)
} }
truncBridgeBalance, err := types.Truncate(bridgeBalance, meta) if amount.IsLT(fees) {
if err != nil { return nil, types.ErrFeeTooHigh
return nil, err
} }
if !truncAmount.IsUint64() || !bridgeBalance.IsUint64() { if !amount.Amount.IsUint64() || !bridgeBalance.Amount.IsUint64() {
return nil, types.ErrAmountTooHigh return nil, types.ErrAmountTooHigh
} }
// Check that the total outflow of this asset does not exceed u64 // Check that the total outflow of this asset does not exceed u64
if !new(big.Int).Add(truncAmount, truncBridgeBalance).IsUint64() { if !bridgeBalance.Add(amount).Amount.IsUint64() {
return nil, types.ErrAmountTooHigh return nil, types.ErrAmountTooHigh
} }
@ -81,11 +80,7 @@ func (k msgServer) Transfer(goCtx context.Context, msg *types.MsgTransfer) (*typ
// PayloadID // PayloadID
buf.WriteByte(1) buf.WriteByte(1)
// Amount // Amount
tokenAmount, overflow := uint256.FromBig(truncAmount) tokenAmountBytes32 := bytes32(amount.Amount.BigInt())
if overflow {
return nil, types.ErrInvalidAmount
}
tokenAmountBytes32 := tokenAmount.Bytes32()
buf.Write(tokenAmountBytes32[:]) buf.Write(tokenAmountBytes32[:])
tokenChain, tokenAddress, err := types.GetTokenMeta(wormholeConfig, msg.Amount.Denom) tokenChain, tokenAddress, err := types.GetTokenMeta(wormholeConfig, msg.Amount.Denom)
if err != nil { if err != nil {
@ -100,18 +95,9 @@ func (k msgServer) Transfer(goCtx context.Context, msg *types.MsgTransfer) (*typ
// ToChain // ToChain
MustWrite(buf, binary.BigEndian, uint16(msg.ToChain)) MustWrite(buf, binary.BigEndian, uint16(msg.ToChain))
// Fee // Fee
fee, overflow := uint256.FromBig(truncFees) feeBytes32 := bytes32(fees.Amount.BigInt())
if overflow {
return nil, types.ErrInvalidFee
}
feeBytes32 := fee.Bytes32()
buf.Write(feeBytes32[:]) buf.Write(feeBytes32[:])
// Check that the amount is sufficient to cover the fee
if truncAmount.Cmp(truncFees) != 1 {
return nil, types.ErrFeeTooHigh
}
// Post message // Post message
emitterAddress := whtypes.EmitterAddressFromAccAddress(moduleAddress) emitterAddress := whtypes.EmitterAddressFromAccAddress(moduleAddress)
err = k.wormholeKeeper.PostMessage(ctx, emitterAddress, 0, buf.Bytes()) err = k.wormholeKeeper.PostMessage(ctx, emitterAddress, 0, buf.Bytes())
@ -121,3 +107,11 @@ func (k msgServer) Transfer(goCtx context.Context, msg *types.MsgTransfer) (*typ
return &types.MsgTransferResponse{}, nil return &types.MsgTransferResponse{}, nil
} }
func bytes32(i *big.Int) [32]byte {
var out [32]byte
i.FillBytes(out[:])
return out
}

View File

@ -9,27 +9,30 @@ import (
"strings" "strings"
whtypes "github.com/certusone/wormhole-chain/x/wormhole/types" whtypes "github.com/certusone/wormhole-chain/x/wormhole/types"
sdk "github.com/cosmos/cosmos-sdk/types"
btypes "github.com/cosmos/cosmos-sdk/x/bank/types" btypes "github.com/cosmos/cosmos-sdk/x/bank/types"
) )
// Truncate an amount // Truncate an amount
func Truncate(amount *big.Int, meta btypes.Metadata) (normalized *big.Int, err error) { func Truncate(coin sdk.Coin, meta btypes.Metadata) (normalized sdk.Coin, err error) {
factor, err := truncFactor(meta) factor, err := truncFactor(meta)
if err != nil { if err != nil {
return new(big.Int), err return normalized, err
} }
return new(big.Int).Div(amount, factor), nil amt := new(big.Int).Div(coin.Amount.BigInt(), factor)
return sdk.NewCoin(coin.Denom, sdk.NewIntFromBigInt(amt)), nil
} }
// Untruncate an amount // Untruncate an amount
func Untruncate(amount *big.Int, meta btypes.Metadata) (normalized *big.Int, err error) { func Untruncate(coin sdk.Coin, meta btypes.Metadata) (normalized sdk.Coin, err error) {
factor, err := truncFactor(meta) factor, err := truncFactor(meta)
if err != nil { if err != nil {
return new(big.Int), err return normalized, err
} }
return new(big.Int).Mul(amount, factor), nil amt := new(big.Int).Mul(coin.Amount.BigInt(), factor)
return sdk.NewCoin(coin.Denom, sdk.NewIntFromBigInt(amt)), nil
} }
// Compute truncation factor for a given token meta. // Compute truncation factor for a given token meta.
@ -55,7 +58,7 @@ func truncFactor(meta btypes.Metadata) (factor *big.Int, err error) {
if displayDenom.Exponent > 8 { if displayDenom.Exponent > 8 {
return new(big.Int).SetInt64(int64(math.Pow10(int(displayDenom.Exponent - 8)))), nil return new(big.Int).SetInt64(int64(math.Pow10(int(displayDenom.Exponent - 8)))), nil
} else { } else {
return new(big.Int).SetInt64(1), nil return big.NewInt(1), nil
} }
} }