diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b48c46a1..d62e3eed1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -67,6 +67,7 @@ FIXES * \#1367 - set ChainID in InitChain * \#1353 - CLI: Show pool shares fractions in human-readable format * \#1258 - printing big.rat's can no longer overflow int64 +* \#887 - limit the size of rationals that can be passed in from user input IMPROVEMENTS * bank module uses go-wire codec instead of 'encoding/json' diff --git a/types/int.go b/types/int.go index 760fc607b..d04c6a80c 100644 --- a/types/int.go +++ b/types/int.go @@ -227,6 +227,10 @@ func (i Int) Neg() (res Int) { return Int{neg(i.i)} } +func (i Int) String() string { + return i.i.String() +} + // MarshalAmino defines custom encoding scheme func (i Int) MarshalAmino() (string, error) { if i.i == nil { // Necessary since default Uint initialization has i.i as nil diff --git a/types/rational.go b/types/rational.go index a192aa316..f24831f89 100644 --- a/types/rational.go +++ b/types/rational.go @@ -38,8 +38,8 @@ func NewRat(Numerator int64, Denominator ...int64) Rat { } // create a rational from decimal string or integer string -func NewRatFromDecimal(decimalStr string) (f Rat, err Error) { - +// precision is the number of values after the decimal point which should be read +func NewRatFromDecimal(decimalStr string, prec int) (f Rat, err Error) { // first extract any negative symbol neg := false if string(decimalStr[0]) == "-" { @@ -61,6 +61,9 @@ func NewRatFromDecimal(decimalStr string) (f Rat, err Error) { if len(str[0]) == 0 || len(str[1]) == 0 { return f, ErrUnknownRequest("not a decimal string") } + if len(str[1]) > prec { + return f, ErrUnknownRequest("string has too many decimals") + } numStr = str[0] + str[1] len := int64(len(str[1])) denom = new(big.Int).Exp(big.NewInt(10), big.NewInt(len), nil).Int64() @@ -69,8 +72,20 @@ func NewRatFromDecimal(decimalStr string) (f Rat, err Error) { } num, errConv := strconv.Atoi(numStr) - if errConv != nil { - return f, ErrUnknownRequest(errConv.Error()) + if errConv != nil && strings.HasSuffix(errConv.Error(), "value out of range") { + // resort to big int, don't make this default option for efficiency + numBig, success := new(big.Int).SetString(numStr, 10) + if success != true { + return f, ErrUnknownRequest("not a decimal string") + } + + if neg { + numBig.Neg(numBig) + } + + return NewRatFromBigInt(numBig, big.NewInt(denom)), nil + } else if errConv != nil { + return f, ErrUnknownRequest("not a decimal string") } if neg { @@ -105,9 +120,9 @@ func NewRatFromInt(num Int, denom ...Int) Rat { } //nolint -func (r Rat) Num() int64 { return r.Rat.Num().Int64() } // Num - return the numerator -func (r Rat) Denom() int64 { return r.Rat.Denom().Int64() } // Denom - return the denominator -func (r Rat) IsZero() bool { return r.Num() == 0 } // IsZero - Is the Rat equal to zero +func (r Rat) Num() Int { return Int{r.Rat.Num()} } // Num - return the numerator +func (r Rat) Denom() Int { return Int{r.Rat.Denom()} } // Denom - return the denominator +func (r Rat) IsZero() bool { return r.Num().IsZero() } // IsZero - Is the Rat equal to zero func (r Rat) Equal(r2 Rat) bool { return (r.Rat).Cmp(r2.Rat) == 0 } func (r Rat) GT(r2 Rat) bool { return (r.Rat).Cmp(r2.Rat) == 1 } // greater than func (r Rat) GTE(r2 Rat) bool { return !r.LT(r2) } // greater than or equal diff --git a/types/rational_test.go b/types/rational_test.go index 43c9ddd57..a137ca498 100644 --- a/types/rational_test.go +++ b/types/rational_test.go @@ -21,6 +21,8 @@ func TestNew(t *testing.T) { } func TestNewFromDecimal(t *testing.T) { + largeBigInt, success := new(big.Int).SetString("3109736052979742687701388262607869", 10) + require.True(t, success) tests := []struct { decimalStr string expErr bool @@ -31,7 +33,13 @@ func TestNewFromDecimal(t *testing.T) { {"1.1", false, NewRat(11, 10)}, {"0.75", false, NewRat(3, 4)}, {"0.8", false, NewRat(4, 5)}, - {"0.11111", false, NewRat(11111, 100000)}, + {"0.11111", true, NewRat(1111, 10000)}, + {"628240629832763.5738930323617075341", true, NewRat(3141203149163817869, 5000)}, + {"621947210595948537540277652521.5738930323617075341", + true, NewRatFromBigInt(largeBigInt, big.NewInt(5000))}, + {"628240629832763.5738", false, NewRat(3141203149163817869, 5000)}, + {"621947210595948537540277652521.5738", + false, NewRatFromBigInt(largeBigInt, big.NewInt(5000))}, {".", true, Rat{}}, {".0", true, Rat{}}, {"1.", true, Rat{}}, @@ -41,22 +49,21 @@ func TestNewFromDecimal(t *testing.T) { } for _, tc := range tests { - - res, err := NewRatFromDecimal(tc.decimalStr) + res, err := NewRatFromDecimal(tc.decimalStr, 4) if tc.expErr { assert.NotNil(t, err, tc.decimalStr) } else { - assert.Nil(t, err) - assert.True(t, res.Equal(tc.exp)) + require.Nil(t, err, tc.decimalStr) + require.True(t, res.Equal(tc.exp), tc.decimalStr) } // negative tc - res, err = NewRatFromDecimal("-" + tc.decimalStr) + res, err = NewRatFromDecimal("-"+tc.decimalStr, 4) if tc.expErr { assert.NotNil(t, err, tc.decimalStr) } else { - assert.Nil(t, err) - assert.True(t, res.Equal(tc.exp.Mul(NewRat(-1)))) + assert.Nil(t, err, tc.decimalStr) + assert.True(t, res.Equal(tc.exp.Mul(NewRat(-1))), tc.decimalStr) } } } @@ -133,7 +140,7 @@ func TestArithmetic(t *testing.T) { assert.True(t, tc.resAdd.Equal(tc.r1.Add(tc.r2)), "r1 %v, r2 %v", tc.r1.Rat, tc.r2.Rat) assert.True(t, tc.resSub.Equal(tc.r1.Sub(tc.r2)), "r1 %v, r2 %v", tc.r1.Rat, tc.r2.Rat) - if tc.r2.Num() == 0 { // panic for divide by zero + if tc.r2.Num().IsZero() { // panic for divide by zero assert.Panics(t, func() { tc.r1.Quo(tc.r2) }) } else { assert.True(t, tc.resDiv.Equal(tc.r1.Quo(tc.r2)), "r1 %v, r2 %v", tc.r1.Rat, tc.r2.Rat) diff --git a/x/stake/client/cli/tx.go b/x/stake/client/cli/tx.go index b7941e2bb..b0fa2e524 100644 --- a/x/stake/client/cli/tx.go +++ b/x/stake/client/cli/tx.go @@ -12,6 +12,7 @@ import ( "github.com/cosmos/cosmos-sdk/wire" authcmd "github.com/cosmos/cosmos-sdk/x/auth/client/cli" "github.com/cosmos/cosmos-sdk/x/stake" + "github.com/cosmos/cosmos-sdk/x/stake/types" ) // create create validator command @@ -219,7 +220,7 @@ func getShares(storeName string, cdc *wire.Codec, sharesAmountStr, sharesPercent case sharesAmountStr == "" && sharesPercentStr == "": return sharesAmount, errors.Errorf("can either specify the amount OR the percent of the shares, not both") case sharesAmountStr != "": - sharesAmount, err = sdk.NewRatFromDecimal(sharesAmountStr) + sharesAmount, err = sdk.NewRatFromDecimal(sharesAmountStr, types.MaxBondDenominatorPrecision) if err != nil { return sharesAmount, err } @@ -228,7 +229,7 @@ func getShares(storeName string, cdc *wire.Codec, sharesAmountStr, sharesPercent } case sharesPercentStr != "": var sharesPercent sdk.Rat - sharesPercent, err = sdk.NewRatFromDecimal(sharesPercentStr) + sharesPercent, err = sdk.NewRatFromDecimal(sharesPercentStr, types.MaxBondDenominatorPrecision) if err != nil { return sharesAmount, err } diff --git a/x/stake/client/rest/tx.go b/x/stake/client/rest/tx.go index 1d8547662..51a854528 100644 --- a/x/stake/client/rest/tx.go +++ b/x/stake/client/rest/tx.go @@ -14,6 +14,7 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/wire" "github.com/cosmos/cosmos-sdk/x/stake" + "github.com/cosmos/cosmos-sdk/x/stake/types" ) func registerTxRoutes(ctx context.CoreContext, r *mux.Router, cdc *wire.Codec, kb keys.Keybase) { @@ -145,7 +146,7 @@ func editDelegationsRequestHandlerFn(cdc *wire.Codec, kb keys.Keybase, ctx conte w.Write([]byte(fmt.Sprintf("Couldn't decode validator. Error: %s", err.Error()))) return } - shares, err := sdk.NewRatFromDecimal(msg.SharesAmount) + shares, err := sdk.NewRatFromDecimal(msg.SharesAmount, types.MaxBondDenominatorPrecision) if err != nil { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte(fmt.Sprintf("Couldn't decode shares amount. Error: %s", err.Error()))) @@ -210,7 +211,7 @@ func editDelegationsRequestHandlerFn(cdc *wire.Codec, kb keys.Keybase, ctx conte w.Write([]byte(fmt.Sprintf("Couldn't decode validator. Error: %s", err.Error()))) return } - shares, err := sdk.NewRatFromDecimal(msg.SharesAmount) + shares, err := sdk.NewRatFromDecimal(msg.SharesAmount, types.MaxBondDenominatorPrecision) if err != nil { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte(fmt.Sprintf("Couldn't decode shares amount. Error: %s", err.Error()))) diff --git a/x/stake/types/errors.go b/x/stake/types/errors.go index 622bd0e1a..2914741f4 100644 --- a/x/stake/types/errors.go +++ b/x/stake/types/errors.go @@ -79,6 +79,12 @@ func ErrNotEnoughDelegationShares(codespace sdk.CodespaceType, shares string) sd func ErrBadSharesAmount(codespace sdk.CodespaceType) sdk.Error { return sdk.NewError(codespace, CodeInvalidDelegation, "shares must be > 0") } +func ErrBadSharesPrecision(codespace sdk.CodespaceType) sdk.Error { + return sdk.NewError(codespace, CodeInvalidDelegation, + fmt.Sprintf("shares denominator must be < %s, try reducing the number of decimal points", + maximumBondingRationalDenominator.String()), + ) +} func ErrBadSharesPercent(codespace sdk.CodespaceType) sdk.Error { return sdk.NewError(codespace, CodeInvalidDelegation, "shares percent must be >0 and <=1") } diff --git a/x/stake/types/msg.go b/x/stake/types/msg.go index f8b8e6717..878c1ba17 100644 --- a/x/stake/types/msg.go +++ b/x/stake/types/msg.go @@ -1,6 +1,8 @@ package types import ( + "math" + sdk "github.com/cosmos/cosmos-sdk/types" "github.com/tendermint/tendermint/crypto" ) @@ -8,11 +10,18 @@ import ( // name to idetify transaction types const MsgType = "stake" -//Verify interface at compile time +// Maximum amount of decimal points in the decimal representation of rationals +// used in MsgBeginUnbonding / MsgBeginRedelegate +const MaxBondDenominatorPrecision = 8 + +// Verify interface at compile time var _, _, _ sdk.Msg = &MsgCreateValidator{}, &MsgEditValidator{}, &MsgDelegate{} var _, _ sdk.Msg = &MsgBeginUnbonding{}, &MsgCompleteUnbonding{} var _, _ sdk.Msg = &MsgBeginRedelegate{}, &MsgCompleteRedelegate{} +// Initialize Int for the denominator +var maximumBondingRationalDenominator sdk.Int = sdk.NewInt(int64(math.Pow10(MaxBondDenominatorPrecision))) + //______________________________________________________________________ // MsgCreateValidator - struct for unbonding transactions @@ -234,6 +243,9 @@ func (msg MsgBeginRedelegate) ValidateBasic() sdk.Error { if msg.SharesAmount.LTE(sdk.ZeroRat()) { return ErrBadSharesAmount(DefaultCodespace) } + if msg.SharesAmount.Denom().GT(maximumBondingRationalDenominator) { + return ErrBadSharesPrecision(DefaultCodespace) + } return nil } @@ -340,6 +352,9 @@ func (msg MsgBeginUnbonding) ValidateBasic() sdk.Error { if msg.SharesAmount.LTE(sdk.ZeroRat()) { return ErrBadSharesAmount(DefaultCodespace) } + if msg.SharesAmount.Denom().GT(maximumBondingRationalDenominator) { + return ErrBadSharesPrecision(DefaultCodespace) + } return nil }