cosmos-sdk/types/uint_test.go

327 lines
8.8 KiB
Go

package types_test
import (
"fmt"
"math"
"math/big"
"math/rand"
"testing"
"github.com/stretchr/testify/suite"
sdk "github.com/cosmos/cosmos-sdk/types"
)
// uintTestSuite groups the sdk.Uint tests; embedding suite.Suite gives each
// test method access to T() and Require().
type uintTestSuite struct{ suite.Suite }
// TestUnitTestSuite is the `go test` entry point that runs every method of
// uintTestSuite through testify's suite runner.
func TestUnitTestSuite(t *testing.T) {
	suite.Run(t, &uintTestSuite{})
}
// SetupSuite runs once before the suite's tests and calls Parallel on the
// *testing.T the suite is currently bound to, allowing the suite to run
// concurrently with other parallel tests in the package.
func (s *uintTestSuite) SetupSuite() {
	t := s.T()
	t.Parallel()
}
// TestUintPanics exercises construction, small-value arithmetic, Min/Max and
// the comparison operators, and checks that operations panic when they divide
// by zero or when their result would leave the representable range
// [0, 2^256-1].
func (s *uintTestSuite) TestUintPanics() {
	// Max Uint = 2^256 - 1 ≈ 1.15e+77
	// Min Uint = 0
	u1 := sdk.NewUint(0)
	u2 := sdk.OneUint()
	s.Require().Equal(uint64(0), u1.Uint64())
	s.Require().Equal(uint64(1), u2.Uint64())

	// Negative inputs are rejected at construction time.
	s.Require().Panics(func() { sdk.NewUintFromBigInt(big.NewInt(-5)) })
	s.Require().Panics(func() { sdk.NewUintFromString("-1") })
	s.Require().NotPanics(func() {
		s.Require().True(sdk.NewUintFromString("0").Equal(sdk.ZeroUint()))
		s.Require().True(sdk.NewUintFromString("5").Equal(sdk.NewUint(5)))
	})

	// Small-value arithmetic, including zero operands.
	s.Require().True(u1.Add(u1).Equal(sdk.ZeroUint()))
	s.Require().True(u1.Add(sdk.OneUint()).Equal(sdk.OneUint()))
	// The Add calls above return new values; u1 itself must still be 0.
	s.Require().Equal(uint64(0), u1.Uint64())
	s.Require().Equal(uint64(1), sdk.OneUint().Uint64())
	s.Require().True(u1.SubUint64(0).Equal(sdk.ZeroUint()))
	s.Require().True(u2.Add(sdk.OneUint()).Sub(sdk.OneUint()).Equal(sdk.OneUint()))     // (u2 + 1) - 1 == 1
	s.Require().True(u2.Add(sdk.OneUint()).Mul(sdk.NewUint(5)).Equal(sdk.NewUint(10))) // (u2 + 1) * 5 == 10
	s.Require().True(sdk.NewUint(7).Quo(sdk.NewUint(2)).Equal(sdk.NewUint(3)))         // integer division truncates
	s.Require().True(sdk.NewUint(0).Quo(sdk.NewUint(2)).Equal(sdk.ZeroUint()))
	s.Require().True(sdk.NewUint(5).MulUint64(4).Equal(sdk.NewUint(20)))
	s.Require().True(sdk.NewUint(5).MulUint64(0).Equal(sdk.ZeroUint()))

	uintmax := sdk.NewUintFromBigInt(new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil), big.NewInt(1)))
	uintmin := sdk.ZeroUint()

	// Division by zero panics.
	s.Require().Panics(func() { sdk.OneUint().QuoUint64(0) })
	s.Require().Panics(func() { sdk.OneUint().Quo(sdk.ZeroUint()) })
	s.Require().Panics(func() { sdk.ZeroUint().QuoUint64(0) })

	// Underflow below zero panics. The Mul and Quo cases panic while computing
	// their operand (0 - 1), before Mul/Quo itself is reached.
	s.Require().Panics(func() { u1.SubUint64(2) })
	s.Require().Panics(func() { sdk.OneUint().Mul(sdk.ZeroUint().SubUint64(uint64(1))) })
	s.Require().Panics(func() { sdk.OneUint().Quo(sdk.ZeroUint().Sub(sdk.OneUint())) })
	s.Require().Panics(func() { uintmin.Sub(sdk.OneUint()) })
	s.Require().Panics(func() { uintmin.Decr() })

	// Overflow above 2^256 - 1 panics.
	s.Require().Panics(func() { uintmax.Add(sdk.OneUint()) })
	s.Require().Panics(func() { uintmax.Incr() })
	s.Require().Panics(func() { uintmax.Mul(sdk.NewUint(2)) })
	s.Require().Panics(func() { uintmax.MulUint64(2) })

	s.Require().Equal(uint64(0), sdk.MinUint(sdk.ZeroUint(), sdk.OneUint()).Uint64())
	s.Require().Equal(uint64(1), sdk.MaxUint(sdk.ZeroUint(), sdk.OneUint()).Uint64())

	// Comparison operators, in both operand orders.
	s.Require().True(sdk.OneUint().GT(sdk.ZeroUint()))
	s.Require().False(sdk.OneUint().LT(sdk.ZeroUint()))
	s.Require().True(sdk.OneUint().GTE(sdk.ZeroUint()))
	s.Require().False(sdk.OneUint().LTE(sdk.ZeroUint()))

	s.Require().False(sdk.ZeroUint().GT(sdk.OneUint()))
	s.Require().True(sdk.ZeroUint().LT(sdk.OneUint()))
	s.Require().False(sdk.ZeroUint().GTE(sdk.OneUint()))
	s.Require().True(sdk.ZeroUint().LTE(sdk.OneUint()))
}
// TestArithUint cross-checks Uint arithmetic against native uint64 arithmetic
// on random operands small enough that no uint64 result overflows.
func (s *uintTestSuite) TestArithUint() {
	for d := 0; d < 1000; d++ {
		// n1 is in [0, 2^32-1] and n2 in [1, 2^32]. n2 is never zero, so the
		// Quo/QuoUint64 cases (and n1 / n2) cannot divide by zero. The largest
		// product, (2^32-1) * 2^32, is still below 2^64.
		n1 := uint64(rand.Uint32())
		u1 := sdk.NewUint(n1)
		n2 := uint64(rand.Uint32()) + 1
		u2 := sdk.NewUint(n2)

		cases := []struct {
			ures sdk.Uint
			nres uint64
		}{
			{u1.Add(u2), n1 + n2},
			{u1.Mul(u2), n1 * n2},
			{u1.Quo(u2), n1 / n2},
			{u1.AddUint64(n2), n1 + n2},
			{u1.MulUint64(n2), n1 * n2},
			{u1.QuoUint64(n2), n1 / n2},
			{sdk.MinUint(u1, u2), minuint(n1, n2)},
			{sdk.MaxUint(u1, u2), maxuint(n1, n2)},
			{u1.Incr(), n1 + 1},
		}
		for tcnum, tc := range cases {
			s.Require().Equal(tc.nres, tc.ures.Uint64(), "Uint arithmetic operation does not match with uint64 operation. tc #%d", tcnum)
		}

		// Order the operands so n1 >= n2 and subtraction cannot underflow.
		// Because n2 started at >= 1, n1 is now >= 1, so Decr is safe too.
		if n2 > n1 {
			n1, n2 = n2, n1
			u1, u2 = sdk.NewUint(n1), sdk.NewUint(n2)
		}

		subs := []struct {
			ures sdk.Uint
			nres uint64
		}{
			{u1.Sub(u2), n1 - n2},
			{u1.SubUint64(n2), n1 - n2},
			{u1.Decr(), n1 - 1},
		}
		for tcnum, tc := range subs {
			s.Require().Equal(tc.nres, tc.ures.Uint64(), "Uint subtraction does not match with uint64 operation. tc #%d", tcnum)
		}
	}
}
// TestCompUint cross-checks the Uint comparison methods against the native
// uint64 operators on random operands.
func (s *uintTestSuite) TestCompUint() {
	for d := 0; d < 10000; d++ {
		n1 := rand.Uint64()
		i1 := sdk.NewUint(n1)
		n2 := rand.Uint64()
		i2 := sdk.NewUint(n2)
		// Two independent random uint64s are practically never equal, so also
		// compare i1 against a fresh copy of itself. This exercises the "equal"
		// branch of every comparison.
		same := sdk.NewUint(n1)

		cases := []struct {
			ires bool
			nres bool
		}{
			{i1.Equal(i2), n1 == n2},
			{i1.Equal(same), true},
			{i1.GT(i2), n1 > n2},
			{i1.GT(same), false},
			{i1.LT(i2), n1 < n2},
			{i1.LT(same), false},
			{i1.GTE(i2), n1 >= n2},
			{!i1.GTE(i2), n1 < n2},
			{i1.GTE(same), true},
			{i1.LTE(i2), n1 <= n2},
			{i2.LTE(i1), n2 <= n1},
			{i1.LTE(same), true},
		}
		for tcnum, tc := range cases {
			s.Require().Equal(tc.nres, tc.ires, "Uint comparison operation does not match with uint64 operation. tc #%d", tcnum)
		}
	}
}
// TestImmutabilityAllUint applies every Uint operation to a random value and
// asserts the receiver is unchanged afterwards. Each method must return a new
// Uint rather than mutate its operand.
func (s *uintTestSuite) TestImmutabilityAllUint() {
	// nonZeroUint64 returns a random uint64 with the low bit forced on, so it is
	// never zero and cannot trigger a divide-by-zero panic when used as a divisor.
	nonZeroUint64 := func() uint64 { return rand.Uint64() | 1 }

	// The op order is significant: the failure messages report the op by index.
	ops := []func(*sdk.Uint){
		func(i *sdk.Uint) { _ = i.Add(sdk.NewUint(rand.Uint64())) },
		func(i *sdk.Uint) {
			// `x % 0` panics in Go itself, and a zero receiver has no valid
			// non-zero subtrahend, so skip that case.
			if i.IsZero() {
				return
			}
			_ = i.Sub(sdk.NewUint(rand.Uint64() % i.Uint64()))
		},
		func(i *sdk.Uint) { _ = i.Mul(randuint()) },
		func(i *sdk.Uint) { _ = i.Quo(sdk.NewUint(nonZeroUint64())) },
		func(i *sdk.Uint) { _ = i.AddUint64(rand.Uint64()) },
		func(i *sdk.Uint) {
			if i.IsZero() {
				return
			}
			_ = i.SubUint64(rand.Uint64() % i.Uint64())
		},
		func(i *sdk.Uint) { _ = i.MulUint64(rand.Uint64()) },
		func(i *sdk.Uint) { _ = i.QuoUint64(nonZeroUint64()) },
		func(i *sdk.Uint) { _ = i.IsZero() },
		func(i *sdk.Uint) { _ = i.Equal(randuint()) },
		func(i *sdk.Uint) { _ = i.GT(randuint()) },
		func(i *sdk.Uint) { _ = i.GTE(randuint()) },
		func(i *sdk.Uint) { _ = i.LT(randuint()) },
		func(i *sdk.Uint) { _ = i.LTE(randuint()) },
		func(i *sdk.Uint) { _ = i.String() },
		func(i *sdk.Uint) { _ = i.Incr() },
		func(i *sdk.Uint) {
			// Decrementing zero would underflow and panic.
			if i.IsZero() {
				return
			}
			_ = i.Decr()
		},
	}

	for i := 0; i < 1000; i++ {
		n := rand.Uint64()
		ni := sdk.NewUint(n)

		for opnum, op := range ops {
			op(&ni)

			s.Require().Equal(n, ni.Uint64(), "Uint is modified by operation. #%d", opnum)
			s.Require().Equal(sdk.NewUint(n), ni, "Uint is modified by operation. #%d", opnum)
		}
	}
}
func (s *uintTestSuite) TestSafeSub() {
testCases := []struct {
x, y sdk.Uint
expected uint64
panic bool
}{
{sdk.NewUint(0), sdk.NewUint(0), 0, false},
{sdk.NewUint(10), sdk.NewUint(5), 5, false},
{sdk.NewUint(5), sdk.NewUint(10), 5, true},
{sdk.NewUint(math.MaxUint64), sdk.NewUint(0), math.MaxUint64, false},
}
for i, tc := range testCases {
tc := tc
if tc.panic {
s.Require().Panics(func() { tc.x.Sub(tc.y) })
continue
}
s.Require().Equal(
tc.expected, tc.x.Sub(tc.y).Uint64(),
"invalid subtraction result; x: %s, y: %s, tc: #%d", tc.x, tc.y, i,
)
}
}
func (s *uintTestSuite) TestParseUint() {
type args struct {
s string
}
tests := []struct {
name string
args args
want sdk.Uint
wantErr bool
}{
{"malformed", args{"malformed"}, sdk.Uint{}, true},
{"empty", args{""}, sdk.Uint{}, true},
{"positive", args{"50"}, sdk.NewUint(uint64(50)), false},
{"negative", args{"-1"}, sdk.Uint{}, true},
{"zero", args{"0"}, sdk.ZeroUint(), false},
}
for _, tt := range tests {
got, err := sdk.ParseUint(tt.args.s)
if tt.wantErr {
s.Require().Error(err)
continue
}
s.Require().NoError(err)
s.Require().True(got.Equal(tt.want))
}
}
// randuint returns a Uint holding a value drawn from math/rand's Uint64.
// The value may be zero.
func randuint() sdk.Uint {
	v := rand.Uint64()
	return sdk.NewUint(v)
}
// TestRelativePow checks sdk.RelativePow(x, n, b) against precomputed results.
// Each row's inputs are listed in call order as {x, n, b}. The expected values
// are consistent with b * (x/b)^n, up to rounding; for example,
// (210, 2, 100) -> 441.
func (s *uintTestSuite) TestRelativePow() {
	cases := []struct {
		in  []sdk.Uint
		out sdk.Uint
	}{
		{[]sdk.Uint{sdk.ZeroUint(), sdk.ZeroUint(), sdk.OneUint()}, sdk.OneUint()},
		{[]sdk.Uint{sdk.ZeroUint(), sdk.ZeroUint(), sdk.NewUint(10)}, sdk.NewUint(10)},
		{[]sdk.Uint{sdk.ZeroUint(), sdk.OneUint(), sdk.NewUint(10)}, sdk.ZeroUint()},
		{[]sdk.Uint{sdk.NewUint(10), sdk.NewUint(2), sdk.OneUint()}, sdk.NewUint(100)},
		{[]sdk.Uint{sdk.NewUint(210), sdk.NewUint(2), sdk.NewUint(100)}, sdk.NewUint(441)},
		{[]sdk.Uint{sdk.NewUint(2100), sdk.NewUint(2), sdk.NewUint(1000)}, sdk.NewUint(4410)},
		{[]sdk.Uint{sdk.NewUint(1000000001547125958), sdk.NewUint(600), sdk.NewUint(1000000000000000000)}, sdk.NewUint(1000000928276004850)},
	}

	for idx := range cases {
		row := cases[idx]
		x, n, b := row.in[0], row.in[1], row.in[2]
		got := sdk.RelativePow(x, n, b)
		s.Require().Equal(row.out, got, "unexpected result for test case %d, input: %v, got: %v", idx, row.in, got)
	}
}
// minuint returns the smaller of i1 and i2; it is the uint64 reference for
// sdk.MinUint.
func minuint(i1, i2 uint64) uint64 {
	if i2 <= i1 {
		return i2
	}
	return i1
}
// maxuint returns the larger of i1 and i2; it is the uint64 reference for
// sdk.MaxUint.
func maxuint(i1, i2 uint64) uint64 {
	if i2 >= i1 {
		return i2
	}
	return i1
}
// TestRoundTripMarshalToUint checks that MarshalTo followed by Unmarshal
// reproduces the original value for a set of uint64 inputs. The inputs include
// the largest uint64 values, whose encodings reach the scratch buffer's size.
func TestRoundTripMarshalToUint(t *testing.T) {
	var values = []uint64{
		0,
		1,
		1 << 10,
		1<<10 - 3,
		1<<63 - 1,
		1<<32 - 7,
		1<<22 - 8,
		1 << 63,
		math.MaxUint64,
	}

	for _, value := range values {
		value := value // capture per iteration for the parallel subtest
		t.Run(fmt.Sprintf("%d", value), func(t *testing.T) {
			t.Parallel()

			// 20 bytes holds the longest decimal uint64 (math.MaxUint64 has 20
			// digits), so math.MaxUint64 fills the buffer exactly.
			var scratch [20]byte
			uv := sdk.NewUint(value)
			n, err := uv.MarshalTo(scratch[:])
			if err != nil {
				t.Fatal(err)
			}

			rt := new(sdk.Uint)
			if err := rt.Unmarshal(scratch[:n]); err != nil {
				t.Fatal(err)
			}

			if !rt.Equal(uv) {
				t.Fatalf("roundtrip=%q != original=%q", rt, uv)
			}
		})
	}
}