Merge PR #3665: Uint overhaul

This commit is contained in:
Alessio Treglia 2019-02-19 23:59:03 +01:00 committed by Jack Zampolin
parent 95710f1c37
commit b67d024fe3
5 changed files with 427 additions and 509 deletions

View File

@ -45,6 +45,8 @@
### SDK
* [\#3665] Overhaul sdk.Uint type in preparation for Coins's Int -> Uint migration.
### Tendermint
<!--------------------------------- BUG FIXES -------------------------------->

View File

@ -318,11 +318,6 @@ func (i Int) String() string {
return i.i.String()
}
// Testing purpose random Int generator
func randomInt(i Int) Int {
return NewIntFromBigInt(random(i.BigInt()))
}
// MarshalAmino defines custom encoding scheme
func (i Int) MarshalAmino() (string, error) {
if i.i == nil { // Necessary since default Uint initialization has i.i as nil
@ -355,256 +350,6 @@ func (i *Int) UnmarshalJSON(bz []byte) error {
return unmarshalJSON(i.i, bz)
}
// Int wraps integer with 256 bit range bound
// Checks overflow, underflow and division by zero
// Exists in range from 0 to 2^256-1
type Uint struct {
i *big.Int
}
// BigInt converts Uint to big.Unt
func (i Uint) BigInt() *big.Int {
return new(big.Int).Set(i.i)
}
// NewUint constructs Uint from int64
func NewUint(n uint64) Uint {
i := new(big.Int)
i.SetUint64(n)
return Uint{i}
}
// NewUintFromBigUint constructs Uint from big.Uint
func NewUintFromBigInt(i *big.Int) Uint {
res := Uint{i}
if UintOverflow(res) {
panic("Uint overflow")
}
return res
}
// NewUintFromString constructs Uint from string
func NewUintFromString(s string) (res Uint, ok bool) {
i, ok := newIntegerFromString(s)
if !ok {
return
}
// Check overflow
if i.Sign() == -1 || i.Sign() == 1 && i.BitLen() > 256 {
ok = false
return
}
return Uint{i}, true
}
// NewUintWithDecimal constructs Uint with decimal
// Result value is n*10^dec
func NewUintWithDecimal(n uint64, dec int) Uint {
if dec < 0 {
panic("NewUintWithDecimal() decimal is negative")
}
exp := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(dec)), nil)
i := new(big.Int)
i.Mul(new(big.Int).SetUint64(n), exp)
res := Uint{i}
if UintOverflow(res) {
panic("NewUintWithDecimal() out of bound")
}
return res
}
// ZeroUint returns Uint value with zero
func ZeroUint() Uint { return Uint{big.NewInt(0)} }
// OneUint returns Uint value with one
func OneUint() Uint { return Uint{big.NewInt(1)} }
// Uint64 converts Uint to uint64
// Panics if the value is out of range
func (i Uint) Uint64() uint64 {
if !i.i.IsUint64() {
panic("Uint64() out of bound")
}
return i.i.Uint64()
}
// IsUint64 returns true if Uint64() not panics
func (i Uint) IsUint64() bool {
return i.i.IsUint64()
}
// IsZero returns true if Uint is zero
func (i Uint) IsZero() bool {
return i.i.Sign() == 0
}
// Sign returns sign of Uint
func (i Uint) Sign() int {
return i.i.Sign()
}
// Equal compares two Uints
func (i Uint) Equal(i2 Uint) bool {
return equal(i.i, i2.i)
}
// GT returns true if first Uint is greater than second
func (i Uint) GT(i2 Uint) bool {
return gt(i.i, i2.i)
}
// LT returns true if first Uint is lesser than second
func (i Uint) LT(i2 Uint) bool {
return lt(i.i, i2.i)
}
// Add adds Uint from another
func (i Uint) Add(i2 Uint) (res Uint) {
res = Uint{add(i.i, i2.i)}
if UintOverflow(res) {
panic("Uint overflow")
}
return
}
// AddRaw adds uint64 to Uint
func (i Uint) AddRaw(i2 uint64) Uint {
return i.Add(NewUint(i2))
}
// Sub subtracts Uint from another
func (i Uint) Sub(i2 Uint) (res Uint) {
res = Uint{sub(i.i, i2.i)}
if UintOverflow(res) {
panic("Uint overflow")
}
return
}
// SafeSub attempts to subtract one Uint from another. A boolean is also returned
// indicating if the result contains integer overflow.
func (i Uint) SafeSub(i2 Uint) (Uint, bool) {
res := Uint{sub(i.i, i2.i)}
if UintOverflow(res) {
return res, true
}
return res, false
}
// SubRaw subtracts uint64 from Uint
func (i Uint) SubRaw(i2 uint64) Uint {
return i.Sub(NewUint(i2))
}
// Mul multiples two Uints
func (i Uint) Mul(i2 Uint) (res Uint) {
if i.i.BitLen()+i2.i.BitLen()-1 > 256 {
panic("Uint overflow")
}
res = Uint{mul(i.i, i2.i)}
if UintOverflow(res) {
panic("Uint overflow")
}
return
}
// MulRaw multipies Uint and uint64
func (i Uint) MulRaw(i2 uint64) Uint {
return i.Mul(NewUint(i2))
}
// Div divides Uint with Uint
func (i Uint) Div(i2 Uint) (res Uint) {
// Check division-by-zero
if i2.Sign() == 0 {
panic("division-by-zero")
}
return Uint{div(i.i, i2.i)}
}
// Div divides Uint with uint64
func (i Uint) DivRaw(i2 uint64) Uint {
return i.Div(NewUint(i2))
}
// Mod returns remainder after dividing with Uint
func (i Uint) Mod(i2 Uint) Uint {
if i2.Sign() == 0 {
panic("division-by-zero")
}
return Uint{mod(i.i, i2.i)}
}
// ModRaw returns remainder after dividing with uint64
func (i Uint) ModRaw(i2 uint64) Uint {
return i.Mod(NewUint(i2))
}
// Return the minimum of the Uints
func MinUint(i1, i2 Uint) Uint {
return Uint{min(i1.BigInt(), i2.BigInt())}
}
// MaxUint returns the maximum between two unsigned integers.
func MaxUint(i, i2 Uint) Uint {
return Uint{max(i.BigInt(), i2.BigInt())}
}
// Human readable string
func (i Uint) String() string {
return i.i.String()
}
// Testing purpose random Uint generator
func randomUint(i Uint) Uint {
return NewUintFromBigInt(random(i.BigInt()))
}
// MarshalAmino defines custom encoding scheme
func (i Uint) MarshalAmino() (string, error) {
if i.i == nil { // Necessary since default Uint initialization has i.i as nil
i.i = new(big.Int)
}
return marshalAmino(i.i)
}
// UnmarshalAmino defines custom decoding scheme
func (i *Uint) UnmarshalAmino(text string) error {
if i.i == nil { // Necessary since default Uint initialization has i.i as nil
i.i = new(big.Int)
}
return unmarshalAmino(i.i, text)
}
// MarshalJSON defines custom encoding scheme
func (i Uint) MarshalJSON() ([]byte, error) {
if i.i == nil { // Necessary since default Uint initialization has i.i as nil
i.i = new(big.Int)
}
return marshalJSON(i.i)
}
// UnmarshalJSON defines custom decoding scheme
func (i *Uint) UnmarshalJSON(bz []byte) error {
if i.i == nil { // Necessary since default Uint initialization has i.i as nil
i.i = new(big.Int)
}
return unmarshalJSON(i.i, bz)
}
//__________________________________________________________________________
// UintOverflow returns true if a given unsigned integer overflows and false
// otherwise.
func UintOverflow(x Uint) bool {
return x.i.Sign() == -1 || x.i.Sign() == 1 && x.i.BitLen() > 256
}
// intended to be used with require/assert: require.True(IntEq(...))
func IntEq(t *testing.T, exp, got Int) (*testing.T, bool, string, string, string) {
return t, exp.Equal(got), "expected:\t%v\ngot:\t\t%v", exp.String(), got.String()

View File

@ -1,7 +1,6 @@
package types
import (
"math"
"math/big"
"math/rand"
"strconv"
@ -73,45 +72,6 @@ func TestIntPanic(t *testing.T) {
require.Panics(t, func() { i1.Div(NewInt(0)) })
}
func TestUintPanic(t *testing.T) {
// Max Uint = 1.15e+77
// Min Uint = 0
require.NotPanics(t, func() { NewUintWithDecimal(5, 76) })
i1 := NewUintWithDecimal(5, 76)
require.NotPanics(t, func() { NewUintWithDecimal(10, 76) })
i2 := NewUintWithDecimal(10, 76)
require.NotPanics(t, func() { NewUintWithDecimal(11, 76) })
i3 := NewUintWithDecimal(11, 76)
require.Panics(t, func() { NewUintWithDecimal(12, 76) })
require.Panics(t, func() { NewUintWithDecimal(1, 80) })
// Overflow check
require.NotPanics(t, func() { i1.Add(i1) })
require.Panics(t, func() { i2.Add(i2) })
require.Panics(t, func() { i3.Add(i3) })
require.Panics(t, func() { i1.Mul(i1) })
require.Panics(t, func() { i2.Mul(i2) })
require.Panics(t, func() { i3.Mul(i3) })
// Underflow check
require.NotPanics(t, func() { i2.Sub(i1) })
require.NotPanics(t, func() { i2.Sub(i2) })
require.Panics(t, func() { i2.Sub(i3) })
// Bound check
uintmax := NewUintFromBigInt(new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil), big.NewInt(1)))
uintmin := NewUint(0)
require.NotPanics(t, func() { uintmax.Add(ZeroUint()) })
require.NotPanics(t, func() { uintmin.Sub(ZeroUint()) })
require.Panics(t, func() { uintmax.Add(OneUint()) })
require.Panics(t, func() { uintmin.Sub(OneUint()) })
// Division-by-zero check
require.Panics(t, func() { i1.Div(uintmin) })
}
// Tests below uses randomness
// Since we are using *big.Int as underlying value
// and (U/)Int is immutable value(see TestImmutability(U/)Int)
@ -205,28 +165,6 @@ func TestCompInt(t *testing.T) {
}
}
func TestIdentUint(t *testing.T) {
for d := 0; d < 1000; d++ {
n := rand.Uint64()
i := NewUint(n)
ifromstr, ok := NewUintFromString(strconv.FormatUint(n, 10))
require.True(t, ok)
cases := []uint64{
i.Uint64(),
i.BigInt().Uint64(),
ifromstr.Uint64(),
NewUintFromBigInt(new(big.Int).SetUint64(n)).Uint64(),
NewUintWithDecimal(n, 0).Uint64(),
}
for tcnum, tc := range cases {
require.Equal(t, n, tc, "Uint is modified during conversion. tc #%d", tcnum)
}
}
}
func minuint(i1, i2 uint64) uint64 {
if i1 < i2 {
return i1
@ -241,71 +179,6 @@ func maxuint(i1, i2 uint64) uint64 {
return i2
}
func TestArithUint(t *testing.T) {
for d := 0; d < 1000; d++ {
n1 := uint64(rand.Uint32())
i1 := NewUint(n1)
n2 := uint64(rand.Uint32())
i2 := NewUint(n2)
cases := []struct {
ires Uint
nres uint64
}{
{i1.Add(i2), n1 + n2},
{i1.Mul(i2), n1 * n2},
{i1.Div(i2), n1 / n2},
{i1.AddRaw(n2), n1 + n2},
{i1.MulRaw(n2), n1 * n2},
{i1.DivRaw(n2), n1 / n2},
{MinUint(i1, i2), minuint(n1, n2)},
{MaxUint(i1, i2), maxuint(n1, n2)},
}
for tcnum, tc := range cases {
require.Equal(t, tc.nres, tc.ires.Uint64(), "Uint arithmetic operation does not match with uint64 operation. tc #%d", tcnum)
}
if n2 > n1 {
continue
}
subs := []struct {
ires Uint
nres uint64
}{
{i1.Sub(i2), n1 - n2},
{i1.SubRaw(n2), n1 - n2},
}
for tcnum, tc := range subs {
require.Equal(t, tc.nres, tc.ires.Uint64(), "Uint subtraction does not match with uint64 operation. tc #%d", tcnum)
}
}
}
func TestCompUint(t *testing.T) {
for d := 0; d < 1000; d++ {
n1 := rand.Uint64()
i1 := NewUint(n1)
n2 := rand.Uint64()
i2 := NewUint(n2)
cases := []struct {
ires bool
nres bool
}{
{i1.Equal(i2), n1 == n2},
{i1.GT(i2), n1 > n2},
{i1.LT(i2), n1 < n2},
}
for tcnum, tc := range cases {
require.Equal(t, tc.nres, tc.ires, "Uint comparison operation does not match with uint64 operation. tc #%d", tcnum)
}
}
}
func randint() Int {
return NewInt(rand.Int63())
}
@ -394,108 +267,6 @@ func TestImmutabilityArithInt(t *testing.T) {
}
}
}
func TestImmutabilityAllUint(t *testing.T) {
ops := []func(*Uint){
func(i *Uint) { _ = i.Add(NewUint(rand.Uint64())) },
func(i *Uint) { _ = i.Sub(NewUint(rand.Uint64() % i.Uint64())) },
func(i *Uint) { _ = i.Mul(randuint()) },
func(i *Uint) { _ = i.Div(randuint()) },
func(i *Uint) { _ = i.AddRaw(rand.Uint64()) },
func(i *Uint) { _ = i.SubRaw(rand.Uint64() % i.Uint64()) },
func(i *Uint) { _ = i.MulRaw(rand.Uint64()) },
func(i *Uint) { _ = i.DivRaw(rand.Uint64()) },
func(i *Uint) { _ = i.IsZero() },
func(i *Uint) { _ = i.Sign() },
func(i *Uint) { _ = i.Equal(randuint()) },
func(i *Uint) { _ = i.GT(randuint()) },
func(i *Uint) { _ = i.LT(randuint()) },
func(i *Uint) { _ = i.String() },
}
for i := 0; i < 1000; i++ {
n := rand.Uint64()
ni := NewUint(n)
for opnum, op := range ops {
op(&ni)
require.Equal(t, n, ni.Uint64(), "Uint is modified by operation. #%d", opnum)
require.Equal(t, NewUint(n), ni, "Uint is modified by operation. #%d", opnum)
}
}
}
type uintop func(Uint, *big.Int) (Uint, *big.Int)
func uintarith(uifn func(Uint, Uint) Uint, bifn func(*big.Int, *big.Int, *big.Int) *big.Int, sub bool) uintop {
return func(ui Uint, bi *big.Int) (Uint, *big.Int) {
r := rand.Uint64()
if sub && ui.IsUint64() {
if ui.IsZero() {
return ui, bi
}
r = r % ui.Uint64()
}
ur := NewUint(r)
br := new(big.Int).SetUint64(r)
return uifn(ui, ur), bifn(new(big.Int), bi, br)
}
}
func uintarithraw(uifn func(Uint, uint64) Uint, bifn func(*big.Int, *big.Int, *big.Int) *big.Int, sub bool) uintop {
return func(ui Uint, bi *big.Int) (Uint, *big.Int) {
r := rand.Uint64()
if sub && ui.IsUint64() {
if ui.IsZero() {
return ui, bi
}
r = r % ui.Uint64()
}
br := new(big.Int).SetUint64(r)
mui := ui.ModRaw(math.MaxUint64)
mbi := new(big.Int).Mod(bi, new(big.Int).SetUint64(math.MaxUint64))
return uifn(mui, r), bifn(new(big.Int), mbi, br)
}
}
func TestImmutabilityArithUint(t *testing.T) {
size := 500
ops := []uintop{
uintarith(Uint.Add, (*big.Int).Add, false),
uintarith(Uint.Sub, (*big.Int).Sub, true),
uintarith(Uint.Mul, (*big.Int).Mul, false),
uintarith(Uint.Div, (*big.Int).Div, false),
uintarithraw(Uint.AddRaw, (*big.Int).Add, false),
uintarithraw(Uint.SubRaw, (*big.Int).Sub, true),
uintarithraw(Uint.MulRaw, (*big.Int).Mul, false),
uintarithraw(Uint.DivRaw, (*big.Int).Div, false),
}
for i := 0; i < 100; i++ {
uis := make([]Uint, size)
bis := make([]*big.Int, size)
n := rand.Uint64()
ui := NewUint(n)
bi := new(big.Int).SetUint64(n)
for j := 0; j < size; j++ {
op := ops[rand.Intn(len(ops))]
uis[j], bis[j] = op(ui, bi)
}
for j := 0; j < size; j++ {
require.Equal(t, 0, bis[j].Cmp(uis[j].BigInt()), "Int is different from *big.Int. tc #%d, Int %s, *big.Int %s", j, uis[j].String(), bis[j].String())
require.Equal(t, NewUintFromBigInt(bis[j]), uis[j], "Int is different from *big.Int. tc #%d, Int %s, *big.Int %s", j, uis[j].String(), bis[j].String())
require.True(t, uis[j].i != bis[j], "Pointer addresses are equal. tc #%d, Int %s, *big.Int %s", j, uis[j].String(), bis[j].String())
}
}
}
func randuint() Uint {
return NewUint(rand.Uint64())
}
func TestEncodingRandom(t *testing.T) {
for i := 0; i < 1000; i++ {
@ -607,31 +378,6 @@ func TestEncodingTableUint(t *testing.T) {
}
}
func TestSafeSub(t *testing.T) {
testCases := []struct {
x, y Uint
expected uint64
overflow bool
}{
{NewUint(0), NewUint(0), 0, false},
{NewUint(10), NewUint(5), 5, false},
{NewUint(5), NewUint(10), 5, true},
{NewUint(math.MaxUint64), NewUint(0), math.MaxUint64, false},
}
for i, tc := range testCases {
res, overflow := tc.x.SafeSub(tc.y)
require.Equal(
t, tc.overflow, overflow,
"invalid overflow result; x: %s, y: %s, tc: #%d", tc.x, tc.y, i,
)
require.Equal(
t, tc.expected, res.BigInt().Uint64(),
"invalid subtraction result; x: %s, y: %s, tc: #%d", tc.x, tc.y, i,
)
}
}
func TestSerializationOverflow(t *testing.T) {
bx, _ := new(big.Int).SetString("91888242871839275229946405745257275988696311157297823662689937894645226298583", 10)
x := Int{bx}

172
types/uint.go Normal file
View File

@ -0,0 +1,172 @@
package types
import (
"errors"
"fmt"
"math/big"
)
// Uint wraps integer with 256 bit range bound
// Checks overflow, underflow and division by zero
// Exists in range from 0 to 2^256-1
type Uint struct {
i *big.Int
}
// NewUintFromBigUint constructs Uint from big.Uint
func NewUintFromBigInt(i *big.Int) Uint {
u, err := checkNewUint(i)
if err != nil {
panic(fmt.Errorf("overflow: %s", err))
}
return u
}
// NewUint constructs Uint from int64
func NewUint(n uint64) Uint {
i := new(big.Int)
i.SetUint64(n)
return NewUintFromBigInt(i)
}
// NewUintFromString constructs Uint from string
func NewUintFromString(s string) Uint {
u, err := ParseUint(s)
if err != nil {
panic(err)
}
return u
}
// ZeroUint returns unsigned zero.
func ZeroUint() Uint { return Uint{big.NewInt(0)} }
// OneUint returns Uint value with one.
func OneUint() Uint { return Uint{big.NewInt(1)} }
// Uint64 converts Uint to uint64
// Panics if the value is out of range
func (u Uint) Uint64() uint64 {
if !u.i.IsUint64() {
panic("Uint64() out of bound")
}
return u.i.Uint64()
}
// IsZero returns 1 if the uint equals to 0.
func (u Uint) IsZero() bool { return u.Equal(ZeroUint()) }
// Equal compares two Uints
func (u Uint) Equal(u2 Uint) bool { return equal(u.i, u2.i) }
// GT returns true if first Uint is greater than second
func (u Uint) GT(u2 Uint) bool { return gt(u.i, u2.i) }
// GTE returns true if first Uint is greater than second
func (u Uint) GTE(u2 Uint) bool { return u.GT(u2) || u.Equal(u2) }
// LT returns true if first Uint is lesser than second
func (u Uint) LT(u2 Uint) bool { return lt(u.i, u2.i) }
// LTE returns true if first Uint is lesser than or equal to the second
func (u Uint) LTE(u2 Uint) bool { return !u.GTE(u2) }
// Add adds Uint from another
func (u Uint) Add(u2 Uint) Uint { return NewUintFromBigInt(new(big.Int).Add(u.i, u2.i)) }
// Add convert uint64 and add it to Uint
func (u Uint) AddUint64(u2 uint64) Uint { return u.Add(NewUint(u2)) }
// Sub adds Uint from another
func (u Uint) Sub(u2 Uint) Uint { return NewUintFromBigInt(new(big.Int).Sub(u.i, u2.i)) }
// SubUint64 adds Uint from another
func (u Uint) SubUint64(u2 uint64) Uint { return u.Sub(NewUint(u2)) }
// Mul multiplies two Uints
func (u Uint) Mul(u2 Uint) (res Uint) {
return NewUintFromBigInt(new(big.Int).Mul(u.i, u2.i))
}
// Mul multiplies two Uints
func (u Uint) MulUint64(u2 uint64) (res Uint) { return u.Mul(NewUint(u2)) }
// Div divides Uint with Uint
func (u Uint) Div(u2 Uint) (res Uint) { return NewUintFromBigInt(div(u.i, u2.i)) }
// Div divides Uint with uint64
func (u Uint) DivUint64(u2 uint64) Uint { return u.Div(NewUint(u2)) }
// Return the minimum of the Uints
func MinUint(u1, u2 Uint) Uint { return NewUintFromBigInt(min(u1.i, u2.i)) }
// Return the maximum of the Uints
func MaxUint(u1, u2 Uint) Uint { return NewUintFromBigInt(max(u1.i, u2.i)) }
// Human readable string
func (u Uint) String() string { return u.i.String() }
// Testing purpose random Uint generator
func randomUint(u Uint) Uint { return NewUintFromBigInt(random(u.i)) }
// MarshalAmino defines custom encoding scheme
func (u Uint) MarshalAmino() (string, error) {
if u.i == nil { // Necessary since default Uint initialization has i.i as nil
u.i = new(big.Int)
}
return marshalAmino(u.i)
}
// UnmarshalAmino defines custom decoding scheme
func (u *Uint) UnmarshalAmino(text string) error {
if u.i == nil { // Necessary since default Uint initialization has i.i as nil
u.i = new(big.Int)
}
return unmarshalAmino(u.i, text)
}
// MarshalJSON defines custom encoding scheme
func (u Uint) MarshalJSON() ([]byte, error) {
if u.i == nil { // Necessary since default Uint initialization has i.i as nil
u.i = new(big.Int)
}
return marshalJSON(u.i)
}
// UnmarshalJSON defines custom decoding scheme
func (u *Uint) UnmarshalJSON(bz []byte) error {
if u.i == nil { // Necessary since default Uint initialization has i.i as nil
u.i = new(big.Int)
}
return unmarshalJSON(u.i, bz)
}
//__________________________________________________________________________
// UintOverflow returns true if a given unsigned integer overflows and false
// otherwise.
func UintOverflow(i *big.Int) error {
if i.Sign() < 0 {
return errors.New("non-positive integer")
}
if i.BitLen() > 256 {
return fmt.Errorf("bit length %d greater than 256", i.BitLen())
}
return nil
}
// ParseUint reads a string-encoded Uint value and return a Uint.
func ParseUint(s string) (Uint, error) {
i, ok := new(big.Int).SetString(s, 0)
if !ok {
return Uint{}, fmt.Errorf("cannot convert %q to big.Int", s)
}
return checkNewUint(i)
}
func checkNewUint(i *big.Int) (Uint, error) {
if err := UintOverflow(i); err != nil {
return Uint{}, err
}
return Uint{i}, nil
}

253
types/uint_test.go Normal file
View File

@ -0,0 +1,253 @@
package types
import (
"math"
"math/big"
"math/rand"
"strconv"
"testing"
"github.com/stretchr/testify/require"
)
func TestUintPanics(t *testing.T) {
// Max Uint = 1.15e+77
// Min Uint = 0
u1 := NewUint(0)
u2 := OneUint()
require.Equal(t, uint64(0), u1.Uint64())
require.Equal(t, uint64(1), u2.Uint64())
require.Panics(t, func() { NewUintFromBigInt(big.NewInt(-5)) })
require.Panics(t, func() { NewUintFromString("-1") })
require.NotPanics(t, func() {
require.True(t, NewUintFromString("0").Equal(ZeroUint()))
require.True(t, NewUintFromString("5").Equal(NewUint(5)))
})
// Overflow check
require.True(t, u1.Add(u1).Equal(ZeroUint()))
require.True(t, u1.Add(OneUint()).Equal(OneUint()))
require.Equal(t, uint64(0), u1.Uint64())
require.Equal(t, uint64(1), OneUint().Uint64())
require.Panics(t, func() { u1.SubUint64(2) })
require.True(t, u1.SubUint64(0).Equal(ZeroUint()))
require.True(t, u2.Add(OneUint()).Sub(OneUint()).Equal(OneUint())) // i2 == 1
require.True(t, u2.Add(OneUint()).Mul(NewUint(5)).Equal(NewUint(10))) // i2 == 10
require.True(t, NewUint(7).Div(NewUint(2)).Equal(NewUint(3)))
require.True(t, NewUint(0).Div(NewUint(2)).Equal(ZeroUint()))
require.True(t, NewUint(5).MulUint64(4).Equal(NewUint(20)))
require.True(t, NewUint(5).MulUint64(0).Equal(ZeroUint()))
uintmax := NewUintFromBigInt(new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil), big.NewInt(1)))
uintmin := ZeroUint()
// divs by zero
require.Panics(t, func() { OneUint().Mul(ZeroUint().SubUint64(uint64(1))) })
require.Panics(t, func() { OneUint().DivUint64(0) })
require.Panics(t, func() { OneUint().Div(ZeroUint()) })
require.Panics(t, func() { ZeroUint().DivUint64(0) })
require.Panics(t, func() { OneUint().Div(ZeroUint().Sub(OneUint())) })
require.Panics(t, func() { uintmax.Add(OneUint()) })
require.Panics(t, func() { uintmin.Sub(OneUint()) })
require.Equal(t, uint64(0), MinUint(ZeroUint(), OneUint()).Uint64())
require.Equal(t, uint64(1), MaxUint(ZeroUint(), OneUint()).Uint64())
// comparison ops
require.True(t,
OneUint().GT(ZeroUint()),
)
require.False(t,
OneUint().LT(ZeroUint()),
)
require.True(t,
OneUint().GTE(ZeroUint()),
)
require.False(t,
OneUint().LTE(ZeroUint()),
)
require.False(t, ZeroUint().GT(OneUint()))
require.True(t, ZeroUint().LT(OneUint()))
require.False(t, ZeroUint().GTE(OneUint()))
require.True(t, ZeroUint().LTE(OneUint()))
}
func TestIdentUint(t *testing.T) {
for d := 0; d < 1000; d++ {
n := rand.Uint64()
i := NewUint(n)
ifromstr := NewUintFromString(strconv.FormatUint(n, 10))
cases := []uint64{
i.Uint64(),
i.i.Uint64(),
ifromstr.Uint64(),
NewUintFromBigInt(new(big.Int).SetUint64(n)).Uint64(),
}
for tcnum, tc := range cases {
require.Equal(t, n, tc, "Uint is modified during conversion. tc #%d", tcnum)
}
}
}
func TestArithUint(t *testing.T) {
for d := 0; d < 1000; d++ {
n1 := uint64(rand.Uint32())
u1 := NewUint(n1)
n2 := uint64(rand.Uint32())
u2 := NewUint(n2)
cases := []struct {
ures Uint
nres uint64
}{
{u1.Add(u2), n1 + n2},
{u1.Mul(u2), n1 * n2},
{u1.Div(u2), n1 / n2},
{u1.AddUint64(n2), n1 + n2},
{u1.MulUint64(n2), n1 * n2},
{u1.DivUint64(n2), n1 / n2},
{MinUint(u1, u2), minuint(n1, n2)},
{MaxUint(u1, u2), maxuint(n1, n2)},
}
for tcnum, tc := range cases {
require.Equal(t, tc.nres, tc.ures.Uint64(), "Uint arithmetic operation does not match with uint64 operation. tc #%d", tcnum)
}
if n2 > n1 {
n1, n2 = n2, n1
u1, u2 = NewUint(n1), NewUint(n2)
}
subs := []struct {
ures Uint
nres uint64
}{
{u1.Sub(u2), n1 - n2},
{u1.SubUint64(n2), n1 - n2},
}
for tcnum, tc := range subs {
require.Equal(t, tc.nres, tc.ures.Uint64(), "Uint subtraction does not match with uint64 operation. tc #%d", tcnum)
}
}
}
func TestCompUint(t *testing.T) {
for d := 0; d < 1000; d++ {
n1 := rand.Uint64()
i1 := NewUint(n1)
n2 := rand.Uint64()
i2 := NewUint(n2)
cases := []struct {
ires bool
nres bool
}{
{i1.Equal(i2), n1 == n2},
{i1.GT(i2), n1 > n2},
{i1.LT(i2), n1 < n2},
{i1.GTE(i2), !i1.LT(i2)},
{!i1.GTE(i2), i1.LT(i2)},
}
for tcnum, tc := range cases {
require.Equal(t, tc.nres, tc.ires, "Uint comparison operation does not match with uint64 operation. tc #%d", tcnum)
}
}
}
func TestImmutabilityAllUint(t *testing.T) {
ops := []func(*Uint){
func(i *Uint) { _ = i.Add(NewUint(rand.Uint64())) },
func(i *Uint) { _ = i.Sub(NewUint(rand.Uint64() % i.Uint64())) },
func(i *Uint) { _ = i.Mul(randuint()) },
func(i *Uint) { _ = i.Div(randuint()) },
func(i *Uint) { _ = i.AddUint64(rand.Uint64()) },
func(i *Uint) { _ = i.SubUint64(rand.Uint64() % i.Uint64()) },
func(i *Uint) { _ = i.MulUint64(rand.Uint64()) },
func(i *Uint) { _ = i.DivUint64(rand.Uint64()) },
func(i *Uint) { _ = i.IsZero() },
func(i *Uint) { _ = i.Equal(randuint()) },
func(i *Uint) { _ = i.GT(randuint()) },
func(i *Uint) { _ = i.GTE(randuint()) },
func(i *Uint) { _ = i.LT(randuint()) },
func(i *Uint) { _ = i.LTE(randuint()) },
func(i *Uint) { _ = i.String() },
}
for i := 0; i < 1000; i++ {
n := rand.Uint64()
ni := NewUint(n)
for opnum, op := range ops {
op(&ni)
require.Equal(t, n, ni.Uint64(), "Uint is modified by operation. #%d", opnum)
require.Equal(t, NewUint(n), ni, "Uint is modified by operation. #%d", opnum)
}
}
}
func TestSafeSub(t *testing.T) {
testCases := []struct {
x, y Uint
expected uint64
panic bool
}{
{NewUint(0), NewUint(0), 0, false},
{NewUint(10), NewUint(5), 5, false},
{NewUint(5), NewUint(10), 5, true},
{NewUint(math.MaxUint64), NewUint(0), math.MaxUint64, false},
}
for i, tc := range testCases {
if tc.panic {
require.Panics(t, func() { tc.x.Sub(tc.y) })
continue
}
require.Equal(
t, tc.expected, tc.x.Sub(tc.y).Uint64(),
"invalid subtraction result; x: %s, y: %s, tc: #%d", tc.x, tc.y, i,
)
}
}
func TestParseUint(t *testing.T) {
type args struct {
s string
}
tests := []struct {
name string
args args
want Uint
wantErr bool
}{
{"malformed", args{"malformed"}, Uint{}, true},
{"empty", args{""}, Uint{}, true},
{"positive", args{"50"}, NewUint(uint64(50)), false},
{"negative", args{"-1"}, Uint{}, true},
{"zero", args{"0"}, ZeroUint(), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseUint(tt.args.s)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.True(t, got.Equal(tt.want))
})
}
}
func randuint() Uint {
return NewUint(rand.Uint64())
}