added rational to types
This commit is contained in:
parent
d34333b3e9
commit
1ed7206efe
|
@ -0,0 +1,199 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
wire "github.com/tendermint/go-wire"
|
||||
)
|
||||
|
||||
var cdc *wire.Codec
|
||||
|
||||
func init() {
|
||||
cdc = wire.NewCodec()
|
||||
cdc.RegisterInterface((*Rational)(nil), nil)
|
||||
cdc.RegisterConcrete(Rat{}, "rat", nil)
|
||||
}
|
||||
|
||||
// Rat - extend big.Rat
|
||||
type Rat struct {
|
||||
*big.Rat `json:"rat"`
|
||||
}
|
||||
|
||||
// Rational - big Rat with additional functionality
|
||||
type Rational interface {
|
||||
GetRat() *big.Rat
|
||||
Num() int64
|
||||
Denom() int64
|
||||
GT(Rational) bool
|
||||
LT(Rational) bool
|
||||
Equal(Rational) bool
|
||||
IsZero() bool
|
||||
Inv() Rational
|
||||
Mul(Rational) Rational
|
||||
Quo(Rational) Rational
|
||||
Add(Rational) Rational
|
||||
Sub(Rational) Rational
|
||||
Round(int64) Rational
|
||||
Evaluate() int64
|
||||
}
|
||||
|
||||
var _ Rational = Rat{} // enforce at compile time
|
||||
|
||||
// nolint - common values
|
||||
var (
|
||||
Zero = Rat{big.NewRat(0, 1)}
|
||||
One = Rat{big.NewRat(1, 1)}
|
||||
)
|
||||
|
||||
// New - create a new Rat from integers
|
||||
func New(Numerator int64, Denominator ...int64) Rat {
|
||||
switch len(Denominator) {
|
||||
case 0:
|
||||
return Rat{big.NewRat(Numerator, 1)}
|
||||
case 1:
|
||||
return Rat{big.NewRat(Numerator, Denominator[0])}
|
||||
default:
|
||||
panic("improper use of New, can only have one denominator")
|
||||
}
|
||||
}
|
||||
|
||||
//NewFromDecimal - create a rational from decimal string or integer string
|
||||
func NewFromDecimal(decimalStr string) (f Rat, err error) {
|
||||
|
||||
// first extract any negative symbol
|
||||
neg := false
|
||||
if string(decimalStr[0]) == "-" {
|
||||
neg = true
|
||||
decimalStr = decimalStr[1:]
|
||||
}
|
||||
|
||||
str := strings.Split(decimalStr, ".")
|
||||
|
||||
var numStr string
|
||||
var denom int64 = 1
|
||||
switch len(str) {
|
||||
case 1:
|
||||
if len(str[0]) == 0 {
|
||||
return f, errors.New("not a decimal string")
|
||||
}
|
||||
numStr = str[0]
|
||||
case 2:
|
||||
if len(str[0]) == 0 || len(str[1]) == 0 {
|
||||
return f, errors.New("not a decimal string")
|
||||
}
|
||||
numStr = str[0] + str[1]
|
||||
len := int64(len(str[1]))
|
||||
denom = new(big.Int).Exp(big.NewInt(10), big.NewInt(len), nil).Int64()
|
||||
default:
|
||||
return f, errors.New("not a decimal string")
|
||||
}
|
||||
|
||||
num, err := strconv.Atoi(numStr)
|
||||
if err != nil {
|
||||
return f, err
|
||||
}
|
||||
|
||||
if neg {
|
||||
num *= -1
|
||||
}
|
||||
|
||||
return Rat{big.NewRat(int64(num), denom)}, nil
|
||||
}
|
||||
|
||||
//nolint
|
||||
func (r Rat) GetRat() *big.Rat { return r.Rat } // GetRat - get big.Rat
|
||||
func (r Rat) Num() int64 { return r.Rat.Num().Int64() } // Num - return the numerator
|
||||
func (r Rat) Denom() int64 { return r.Rat.Denom().Int64() } // Denom - return the denominator
|
||||
func (r Rat) IsZero() bool { return r.Num() == 0 } // IsZero - Is the Rat equal to zero
|
||||
func (r Rat) Equal(r2 Rational) bool { return r.Rat.Cmp(r2.GetRat()) == 0 } // Equal - rationals are equal
|
||||
func (r Rat) GT(r2 Rational) bool { return r.Rat.Cmp(r2.GetRat()) == 1 } // GT - greater than
|
||||
func (r Rat) LT(r2 Rational) bool { return r.Rat.Cmp(r2.GetRat()) == -1 } // LT - less than
|
||||
func (r Rat) Inv() Rational { return Rat{new(big.Rat).Inv(r.Rat)} } // Inv - inverse
|
||||
func (r Rat) Mul(r2 Rational) Rational { return Rat{new(big.Rat).Mul(r.Rat, r2.GetRat())} } // Mul - multiplication
|
||||
func (r Rat) Quo(r2 Rational) Rational { return Rat{new(big.Rat).Quo(r.Rat, r2.GetRat())} } // Quo - quotient
|
||||
func (r Rat) Add(r2 Rational) Rational { return Rat{new(big.Rat).Add(r.Rat, r2.GetRat())} } // Add - addition
|
||||
func (r Rat) Sub(r2 Rational) Rational { return Rat{new(big.Rat).Sub(r.Rat, r2.GetRat())} } // Sub - subtraction
|
||||
|
||||
var zero = big.NewInt(0)
|
||||
var one = big.NewInt(1)
|
||||
var two = big.NewInt(2)
|
||||
var five = big.NewInt(5)
|
||||
var nFive = big.NewInt(-5)
|
||||
var ten = big.NewInt(10)
|
||||
|
||||
// EvaluateBig - evaluate the rational using bankers rounding
|
||||
func (r Rat) EvaluateBig() *big.Int {
|
||||
|
||||
num := r.Rat.Num()
|
||||
denom := r.Rat.Denom()
|
||||
|
||||
d, rem := new(big.Int), new(big.Int)
|
||||
d.QuoRem(num, denom, rem)
|
||||
if rem.Cmp(zero) == 0 { // is the remainder zero
|
||||
return d
|
||||
}
|
||||
|
||||
// evaluate the remainder using bankers rounding
|
||||
tenNum := new(big.Int).Mul(num, ten)
|
||||
tenD := new(big.Int).Mul(d, ten)
|
||||
remainderDigit := new(big.Int).Sub(new(big.Int).Quo(tenNum, denom), tenD) // get the first remainder digit
|
||||
isFinalDigit := (new(big.Int).Rem(tenNum, denom).Cmp(zero) == 0) // is this the final digit in the remainder?
|
||||
|
||||
switch {
|
||||
case isFinalDigit && (remainderDigit.Cmp(five) == 0 || remainderDigit.Cmp(nFive) == 0):
|
||||
dRem2 := new(big.Int).Rem(d, two)
|
||||
return new(big.Int).Add(d, dRem2) // always rounds to the even number
|
||||
case remainderDigit.Cmp(five) != -1: //remainderDigit >= 5:
|
||||
d.Add(d, one)
|
||||
case remainderDigit.Cmp(nFive) != 1: //remainderDigit <= -5:
|
||||
d.Sub(d, one)
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// Evaluate - evaluate the rational using bankers rounding
|
||||
func (r Rat) Evaluate() int64 {
|
||||
return r.EvaluateBig().Int64()
|
||||
}
|
||||
|
||||
// Round - round Rat with the provided precisionFactor
|
||||
func (r Rat) Round(precisionFactor int64) Rational {
|
||||
rTen := Rat{new(big.Rat).Mul(r.Rat, big.NewRat(precisionFactor, 1))}
|
||||
return Rat{big.NewRat(rTen.Evaluate(), precisionFactor)}
|
||||
}
|
||||
|
||||
//___________________________________________________________________________________
|
||||
|
||||
//TODO there has got to be a better way using native MarshalText and UnmarshalText
|
||||
|
||||
// RatMarshal - Marshable Rat Struct
|
||||
type RatMarshal struct {
|
||||
Numerator int64 `json:"numerator"`
|
||||
Denominator int64 `json:"denominator"`
|
||||
}
|
||||
|
||||
// MarshalJSON - custom implementation of JSON Marshal
|
||||
func (r Rat) MarshalJSON() ([]byte, error) {
|
||||
return cdc.MarshalJSON(RatMarshal{r.Num(), r.Denom()})
|
||||
}
|
||||
|
||||
// UnmarshalJSON - custom implementation of JSON Unmarshal
|
||||
func (r *Rat) UnmarshalJSON(data []byte) (err error) {
|
||||
defer func() {
|
||||
if rcv := recover(); rcv != nil {
|
||||
err = fmt.Errorf("Panic during UnmarshalJSON: %v", rcv)
|
||||
}
|
||||
}()
|
||||
|
||||
ratMar := new(RatMarshal)
|
||||
if err := cdc.UnmarshalJSON(data, ratMar); err != nil {
|
||||
return err
|
||||
}
|
||||
r.Rat = big.NewRat(ratMar.Numerator, ratMar.Denominator)
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,296 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
asrt "github.com/stretchr/testify/assert"
|
||||
rqr "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
assert := asrt.New(t)
|
||||
|
||||
assert.Equal(New(1), New(1, 1))
|
||||
assert.Equal(New(100), New(100, 1))
|
||||
assert.Equal(New(-1), New(-1, 1))
|
||||
assert.Equal(New(-100), New(-100, 1))
|
||||
assert.Equal(New(0), New(0, 1))
|
||||
|
||||
// do not allow for more than 2 variables
|
||||
assert.Panics(func() { New(1, 1, 1) })
|
||||
}
|
||||
|
||||
func TestNewFromDecimal(t *testing.T) {
|
||||
assert := asrt.New(t)
|
||||
|
||||
tests := []struct {
|
||||
decimalStr string
|
||||
expErr bool
|
||||
exp Rat
|
||||
}{
|
||||
{"0", false, New(0)},
|
||||
{"1", false, New(1)},
|
||||
{"1.1", false, New(11, 10)},
|
||||
{"0.75", false, New(3, 4)},
|
||||
{"0.8", false, New(4, 5)},
|
||||
{"0.11111", false, New(11111, 100000)},
|
||||
{".", true, Rat{}},
|
||||
{".0", true, Rat{}},
|
||||
{"1.", true, Rat{}},
|
||||
{"foobar", true, Rat{}},
|
||||
{"0.foobar", true, Rat{}},
|
||||
{"0.foobar.", true, Rat{}},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
|
||||
res, err := NewFromDecimal(tc.decimalStr)
|
||||
if tc.expErr {
|
||||
assert.NotNil(err, tc.decimalStr)
|
||||
} else {
|
||||
assert.Nil(err)
|
||||
assert.True(res.Equal(tc.exp))
|
||||
}
|
||||
|
||||
// negative tc
|
||||
res, err = NewFromDecimal("-" + tc.decimalStr)
|
||||
if tc.expErr {
|
||||
assert.NotNil(err, tc.decimalStr)
|
||||
} else {
|
||||
assert.Nil(err)
|
||||
assert.True(res.Equal(tc.exp.Mul(New(-1))))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEqualities(t *testing.T) {
|
||||
assert := asrt.New(t)
|
||||
|
||||
tests := []struct {
|
||||
r1, r2 Rat
|
||||
gt, lt, eq bool
|
||||
}{
|
||||
{New(0), New(0), false, false, true},
|
||||
{New(0, 100), New(0, 10000), false, false, true},
|
||||
{New(100), New(100), false, false, true},
|
||||
{New(-100), New(-100), false, false, true},
|
||||
{New(-100, -1), New(100), false, false, true},
|
||||
{New(-1, 1), New(1, -1), false, false, true},
|
||||
{New(1, -1), New(-1, 1), false, false, true},
|
||||
{New(3, 7), New(3, 7), false, false, true},
|
||||
|
||||
{New(0), New(3, 7), false, true, false},
|
||||
{New(0), New(100), false, true, false},
|
||||
{New(-1), New(3, 7), false, true, false},
|
||||
{New(-1), New(100), false, true, false},
|
||||
{New(1, 7), New(100), false, true, false},
|
||||
{New(1, 7), New(3, 7), false, true, false},
|
||||
{New(-3, 7), New(-1, 7), false, true, false},
|
||||
|
||||
{New(3, 7), New(0), true, false, false},
|
||||
{New(100), New(0), true, false, false},
|
||||
{New(3, 7), New(-1), true, false, false},
|
||||
{New(100), New(-1), true, false, false},
|
||||
{New(100), New(1, 7), true, false, false},
|
||||
{New(3, 7), New(1, 7), true, false, false},
|
||||
{New(-1, 7), New(-3, 7), true, false, false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
assert.Equal(tc.gt, tc.r1.GT(tc.r2))
|
||||
assert.Equal(tc.lt, tc.r1.LT(tc.r2))
|
||||
assert.Equal(tc.eq, tc.r1.Equal(tc.r2))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestArithmatic(t *testing.T) {
|
||||
assert := asrt.New(t)
|
||||
|
||||
tests := []struct {
|
||||
r1, r2 Rat
|
||||
resMul, resDiv, resAdd, resSub Rat
|
||||
}{
|
||||
// r1 r2 MUL DIV ADD SUB
|
||||
{New(0), New(0), New(0), New(0), New(0), New(0)},
|
||||
{New(1), New(0), New(0), New(0), New(1), New(1)},
|
||||
{New(0), New(1), New(0), New(0), New(1), New(-1)},
|
||||
{New(0), New(-1), New(0), New(0), New(-1), New(1)},
|
||||
{New(-1), New(0), New(0), New(0), New(-1), New(-1)},
|
||||
|
||||
{New(1), New(1), New(1), New(1), New(2), New(0)},
|
||||
{New(-1), New(-1), New(1), New(1), New(-2), New(0)},
|
||||
{New(1), New(-1), New(-1), New(-1), New(0), New(2)},
|
||||
{New(-1), New(1), New(-1), New(-1), New(0), New(-2)},
|
||||
|
||||
{New(3), New(7), New(21), New(3, 7), New(10), New(-4)},
|
||||
{New(2), New(4), New(8), New(1, 2), New(6), New(-2)},
|
||||
{New(100), New(100), New(10000), New(1), New(200), New(0)},
|
||||
|
||||
{New(3, 2), New(3, 2), New(9, 4), New(1), New(3), New(0)},
|
||||
{New(3, 7), New(7, 3), New(1), New(9, 49), New(58, 21), New(-40, 21)},
|
||||
{New(1, 21), New(11, 5), New(11, 105), New(5, 231), New(236, 105), New(-226, 105)},
|
||||
{New(-21), New(3, 7), New(-9), New(-49), New(-144, 7), New(-150, 7)},
|
||||
{New(100), New(1, 7), New(100, 7), New(700), New(701, 7), New(699, 7)},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
assert.True(tc.resMul.Equal(tc.r1.Mul(tc.r2)), "r1 %v, r2 %v", tc.r1.GetRat(), tc.r2.GetRat())
|
||||
assert.True(tc.resAdd.Equal(tc.r1.Add(tc.r2)), "r1 %v, r2 %v", tc.r1.GetRat(), tc.r2.GetRat())
|
||||
assert.True(tc.resSub.Equal(tc.r1.Sub(tc.r2)), "r1 %v, r2 %v", tc.r1.GetRat(), tc.r2.GetRat())
|
||||
|
||||
if tc.r2.Num() == 0 { // panic for divide by zero
|
||||
assert.Panics(func() { tc.r1.Quo(tc.r2) })
|
||||
} else {
|
||||
assert.True(tc.resDiv.Equal(tc.r1.Quo(tc.r2)), "r1 %v, r2 %v", tc.r1.GetRat(), tc.r2.GetRat())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvaluate(t *testing.T) {
|
||||
assert := asrt.New(t)
|
||||
|
||||
tests := []struct {
|
||||
r1 Rat
|
||||
res int64
|
||||
}{
|
||||
{New(0), 0},
|
||||
{New(1), 1},
|
||||
{New(1, 4), 0},
|
||||
{New(1, 2), 0},
|
||||
{New(3, 4), 1},
|
||||
{New(5, 6), 1},
|
||||
{New(3, 2), 2},
|
||||
{New(5, 2), 2},
|
||||
{New(6, 11), 1}, // 0.545-> 1 even though 5 is first decimal and 1 not even
|
||||
{New(17, 11), 2}, // 1.545
|
||||
{New(5, 11), 0},
|
||||
{New(16, 11), 1},
|
||||
{New(113, 12), 9},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
assert.Equal(tc.res, tc.r1.Evaluate(), "%v", tc.r1)
|
||||
assert.Equal(tc.res*-1, tc.r1.Mul(New(-1)).Evaluate(), "%v", tc.r1.Mul(New(-1)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRound(t *testing.T) {
|
||||
assert, require := asrt.New(t), rqr.New(t)
|
||||
|
||||
many3 := "333333333333333333333333333333333333333333333"
|
||||
many7 := "777777777777777777777777777777777777777777777"
|
||||
big3, worked := new(big.Int).SetString(many3, 10)
|
||||
require.True(worked)
|
||||
big7, worked := new(big.Int).SetString(many7, 10)
|
||||
require.True(worked)
|
||||
|
||||
tests := []struct {
|
||||
r1, res Rat
|
||||
precFactor int64
|
||||
}{
|
||||
{New(333, 777), New(429, 1000), 1000},
|
||||
{Rat{new(big.Rat).SetFrac(big3, big7)}, New(429, 1000), 1000},
|
||||
{Rat{new(big.Rat).SetFrac(big3, big7)}, New(4285714286, 10000000000), 10000000000},
|
||||
{New(1, 2), New(1, 2), 1000},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
assert.Equal(tc.res, tc.r1.Round(tc.precFactor), "%v", tc.r1)
|
||||
negR1, negRes := tc.r1.Mul(New(-1)), tc.res.Mul(New(-1))
|
||||
assert.Equal(negRes, negR1.Round(tc.precFactor), "%v", negR1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestZeroSerializationJSON(t *testing.T) {
|
||||
assert := asrt.New(t)
|
||||
|
||||
var r Rat
|
||||
err := json.Unmarshal([]byte("{\"numerator\":0,\"denominator\":1}"), &r)
|
||||
assert.Nil(err)
|
||||
err = json.Unmarshal([]byte("{\"numerator\":0,\"denominator\":0}"), &r)
|
||||
assert.NotNil(err)
|
||||
err = json.Unmarshal([]byte("{\"numerator\":1,\"denominator\":0}"), &r)
|
||||
assert.NotNil(err)
|
||||
err = json.Unmarshal([]byte("{}"), &r)
|
||||
assert.NotNil(err)
|
||||
}
|
||||
|
||||
func TestSerializationJSON(t *testing.T) {
|
||||
assert, require := asrt.New(t), rqr.New(t)
|
||||
|
||||
r := New(1, 3)
|
||||
|
||||
rMarshal, err := json.Marshal(r)
|
||||
require.Nil(err)
|
||||
|
||||
var rUnmarshal Rat
|
||||
err = json.Unmarshal(rMarshal, &rUnmarshal)
|
||||
require.Nil(err)
|
||||
|
||||
assert.True(r.Equal(rUnmarshal), "original: %v, unmarshalled: %v", r, rUnmarshal)
|
||||
}
|
||||
|
||||
func TestSerializationGoWire(t *testing.T) {
|
||||
assert, require := asrt.New(t), rqr.New(t)
|
||||
|
||||
r := New(1, 3)
|
||||
|
||||
rMarshal, err := cdc.MarshalJSON(r)
|
||||
require.Nil(err)
|
||||
|
||||
var rUnmarshal Rat
|
||||
err = cdc.UnmarshalJSON(rMarshal, &rUnmarshal)
|
||||
require.Nil(err)
|
||||
|
||||
assert.True(r.Equal(rUnmarshal), "original: %v, unmarshalled: %v", r, rUnmarshal)
|
||||
}
|
||||
|
||||
type testEmbedStruct struct {
|
||||
Field1 string `json:"f1"`
|
||||
Field2 int `json:"f2"`
|
||||
Field3 Rat `json:"f3"`
|
||||
}
|
||||
|
||||
func TestEmbeddedStructSerializationGoWire(t *testing.T) {
|
||||
assert, require := asrt.New(t), rqr.New(t)
|
||||
|
||||
r := testEmbedStruct{"foo", 10, New(1, 3)}
|
||||
|
||||
rMarshal, err := cdc.MarshalJSON(r)
|
||||
require.Nil(err)
|
||||
|
||||
var rUnmarshal testEmbedStruct
|
||||
err = cdc.UnmarshalJSON(rMarshal, &rUnmarshal)
|
||||
require.Nil(err)
|
||||
|
||||
assert.Equal(r.Field1, rUnmarshal.Field1)
|
||||
assert.Equal(r.Field2, rUnmarshal.Field2)
|
||||
assert.True(r.Field3.Equal(rUnmarshal.Field3), "original: %v, unmarshalled: %v", r, rUnmarshal)
|
||||
|
||||
}
|
||||
|
||||
type testEmbedInterface struct {
|
||||
Field1 string `json:"f1"`
|
||||
Field2 int `json:"f2"`
|
||||
Field3 Rational `json:"f3"`
|
||||
}
|
||||
|
||||
func TestEmbeddedInterfaceSerializationGoWire(t *testing.T) {
|
||||
assert, require := asrt.New(t), rqr.New(t)
|
||||
|
||||
r := testEmbedInterface{"foo", 10, New(1, 3)}
|
||||
|
||||
rMarshal, err := cdc.MarshalJSON(r)
|
||||
require.Nil(err)
|
||||
|
||||
var rUnmarshal testEmbedInterface
|
||||
err = cdc.UnmarshalJSON(rMarshal, &rUnmarshal)
|
||||
require.Nil(err)
|
||||
|
||||
assert.Equal(r.Field1, rUnmarshal.Field1)
|
||||
assert.Equal(r.Field2, rUnmarshal.Field2)
|
||||
assert.True(r.Field3.Equal(rUnmarshal.Field3), "original: %v, unmarshalled: %v", r, rUnmarshal)
|
||||
}
|
Loading…
Reference in New Issue