From 69c3a7640bc48957ed3984596da276264b3f1038 Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Mon, 25 Dec 2017 18:38:35 -0600 Subject: [PATCH] add safeAdd & safeSub plus quickcheck tests --- types/validator_set.go | 39 +++++++++++++++++++++++++++++-------- types/validator_set_test.go | 35 ++++++++++++++++++++++++++------- 2 files changed, 59 insertions(+), 15 deletions(-) diff --git a/types/validator_set.go b/types/validator_set.go index 9aaa6830..a9b98659 100644 --- a/types/validator_set.go +++ b/types/validator_set.go @@ -52,10 +52,15 @@ func (valSet *ValidatorSet) IncrementAccum(times int) { // Add VotingPower * times to each validator and order into heap. validatorsHeap := cmn.NewHeap() for _, val := range valSet.Validators { - res, overflow := signedMulWithOverflowCheck(val.VotingPower, int64(times)) // check for overflow both multiplication and sum - if !overflow && val.Accum <= mostPositive-res { - val.Accum += res + res, overflow := safeMul(val.VotingPower, int64(times)) + if !overflow { + res2, overflow2 := safeAdd(val.Accum, res) + if !overflow2 { + val.Accum = res2 + } else { + val.Accum = mostPositive + } } else { val.Accum = mostPositive } @@ -70,8 +75,9 @@ func (valSet *ValidatorSet) IncrementAccum(times int) { } // mind underflow - if mostest.Accum >= mostNegative+valSet.TotalVotingPower() { - mostest.Accum -= valSet.TotalVotingPower() + res, underflow := safeSub(mostest.Accum, valSet.TotalVotingPower()) + if !underflow { + mostest.Accum = res } else { mostest.Accum = mostNegative } @@ -129,8 +135,9 @@ func (valSet *ValidatorSet) TotalVotingPower() int64 { if valSet.totalVotingPower == 0 { for _, val := range valSet.Validators { // mind overflow - if valSet.totalVotingPower <= mostPositive-val.VotingPower { - valSet.totalVotingPower += val.VotingPower + res, overflow := safeAdd(valSet.totalVotingPower, val.VotingPower) + if !overflow { + valSet.totalVotingPower = res } else { valSet.totalVotingPower = mostPositive return valSet.totalVotingPower @@ -443,10 +450,13 @@ func RandValidatorSet(numValidators int, votingPower int64) (*ValidatorSet, []*P return valSet, privValidators } +/////////////////////////////////////////////////////////////////////////////// +// Safe multiplication and addition/subtraction + const mostNegative int64 = -mostPositive - 1 const mostPositive int64 = 1<<63 - 1 -func signedMulWithOverflowCheck(a, b int64) (int64, bool) { +func safeMul(a, b int64) (int64, bool) { if a == 0 || b == 0 { return 0, false } @@ -462,3 +472,16 @@ func signedMulWithOverflowCheck(a, b int64) (int64, bool) { c := a * b return c, c/b != a } + +func safeAdd(a, b int64) (int64, bool) { + if b > 0 && a > mostPositive-b { + return -1, true + } else if b < 0 && a < mostNegative-b { + return -1, true + } + return a + b, false +} + +func safeSub(a, b int64) (int64, bool) { + return safeAdd(a, -b) +} diff --git a/types/validator_set_test.go b/types/validator_set_test.go index c65f507f..dd2a5999 100644 --- a/types/validator_set_test.go +++ b/types/validator_set_test.go @@ -4,6 +4,7 @@ import ( "bytes" "strings" "testing" + "testing/quick" "github.com/stretchr/testify/assert" crypto "github.com/tendermint/go-crypto" @@ -191,6 +192,16 @@ func TestProposerSelection3(t *testing.T) { } } +func TestValidatorSetTotalVotingPowerOverflows(t *testing.T) { + vset := NewValidatorSet([]*Validator{ + {Address: []byte("a"), VotingPower: mostPositive, Accum: 0}, + {Address: []byte("b"), VotingPower: mostPositive, Accum: 0}, + {Address: []byte("c"), VotingPower: mostPositive, Accum: 0}, + }) + + assert.Equal(t, mostPositive, vset.TotalVotingPower()) +} + func TestValidatorSetIncrementAccumOverflows(t *testing.T) { // NewValidatorSet calls IncrementAccum(1) vset := NewValidatorSet([]*Validator{ @@ -220,14 +231,24 @@ func TestValidatorSetIncrementAccumUnderflows(t *testing.T) { assert.Equal(t, mostNegative, vset.Validators[1].Accum, "1") } -func TestValidatorSetTotalVotingPowerOverflows(t *testing.T) { - vset := NewValidatorSet([]*Validator{ - {Address: []byte("a"), VotingPower: mostPositive, Accum: 0}, - {Address: []byte("b"), VotingPower: mostPositive, Accum: 0}, - {Address: []byte("c"), VotingPower: mostPositive, Accum: 0}, - }) +func TestSafeMul(t *testing.T) { + f := func(a, b int64) bool { + c, overflow := safeMul(a, b) + return overflow || (!overflow && c == a*b) + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } +} - assert.Equal(t, mostPositive, vset.TotalVotingPower()) +func TestSafeAdd(t *testing.T) { + f := func(a, b int64) bool { + c, overflow := safeAdd(a, b) + return overflow || (!overflow && c == a+b) + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } } func BenchmarkValidatorSetCopy(b *testing.B) {