bit_array: Simplify subtraction

also, fix potential bug in Or function
This commit is contained in:
ValarDragon 2018-10-02 16:03:59 -07:00 committed by Anton Kaliaev
parent c94133ed1b
commit 0755a5203d
3 changed files with 47 additions and 103 deletions

View File

@ -11,7 +11,7 @@ BREAKING CHANGES:
* [rpc] \#2298 `/abci_query` takes `prove` argument instead of `trusted` and switches the default * [rpc] \#2298 `/abci_query` takes `prove` argument instead of `trusted` and switches the default
behaviour to `prove=false` behaviour to `prove=false`
* [privval] \#2459 Split `SocketPVMsg`s implementations into Request and Response, where the Response may contain a error message (returned by the remote signer) * [privval] \#2459 Split `SocketPVMsg`s implementations into Request and Response, where the Response may contain a error message (returned by the remote signer)
* Apps * Apps
* [abci] \#2298 ResponseQuery.Proof is now a structured merkle.Proof, not just * [abci] \#2298 ResponseQuery.Proof is now a structured merkle.Proof, not just
arbitrary bytes arbitrary bytes
@ -40,3 +40,4 @@ BUG FIXES:
- [autofile] \#2428 Group.RotateFile need call Flush() before rename (@goolAdapter) - [autofile] \#2428 Group.RotateFile need call Flush() before rename (@goolAdapter)
- [node] \#2434 Make node respond to signal interrupts while sleeping for genesis time - [node] \#2434 Make node respond to signal interrupts while sleeping for genesis time
- [evidence] \#2515 fix db iter leak (@goolAdapter) - [evidence] \#2515 fix db iter leak (@goolAdapter)
- [common/bit_array] Fixed a bug in the `Or` function

View File

@ -119,14 +119,13 @@ func (bA *BitArray) Or(o *BitArray) *BitArray {
} }
bA.mtx.Lock() bA.mtx.Lock()
o.mtx.Lock() o.mtx.Lock()
defer func() {
bA.mtx.Unlock()
o.mtx.Unlock()
}()
c := bA.copyBits(MaxInt(bA.Bits, o.Bits)) c := bA.copyBits(MaxInt(bA.Bits, o.Bits))
for i := 0; i < len(c.Elems); i++ { smaller := MinInt(len(bA.Elems), len(o.Elems))
for i := 0; i < smaller; i++ {
c.Elems[i] |= o.Elems[i] c.Elems[i] |= o.Elems[i]
} }
bA.mtx.Unlock()
o.mtx.Unlock()
return c return c
} }
@ -173,8 +172,9 @@ func (bA *BitArray) not() *BitArray {
} }
// Sub subtracts the two bit-arrays bitwise, without carrying the bits. // Sub subtracts the two bit-arrays bitwise, without carrying the bits.
// This is essentially bA.And(o.Not()). // Note that carryless subtraction of a - b is (a and not b).
// If bA is longer than o, o is right padded with zeroes. // The output is the same as bA, regardless of o's size.
// If bA is longer than o, o is right padded with zeroes
func (bA *BitArray) Sub(o *BitArray) *BitArray { func (bA *BitArray) Sub(o *BitArray) *BitArray {
if bA == nil || o == nil { if bA == nil || o == nil {
// TODO: Decide if we should do 1's complement here? // TODO: Decide if we should do 1's complement here?
@ -182,24 +182,20 @@ func (bA *BitArray) Sub(o *BitArray) *BitArray {
} }
bA.mtx.Lock() bA.mtx.Lock()
o.mtx.Lock() o.mtx.Lock()
defer func() { // output is the same size as bA
bA.mtx.Unlock() c := bA.copyBits(bA.Bits)
o.mtx.Unlock() // Only iterate to the minimum size between the two.
}() // If o is longer, those bits are ignored.
if bA.Bits > o.Bits { // If bA is longer, then skipping those iterations is equivalent
c := bA.copy() // to right padding with 0's
for i := 0; i < len(o.Elems)-1; i++ { smaller := MinInt(len(bA.Elems), len(o.Elems))
c.Elems[i] &= ^o.Elems[i] for i := 0; i < smaller; i++ {
} // &^ is and not in golang
i := len(o.Elems) - 1 c.Elems[i] &^= o.Elems[i]
if i >= 0 {
for idx := i * 64; idx < o.Bits; idx++ {
c.setIndex(idx, c.getIndex(idx) && !o.getIndex(idx))
}
}
return c
} }
return bA.and(o.not()) // Note degenerate case where o == nil bA.mtx.Unlock()
o.mtx.Unlock()
return c
} }
// IsEmpty returns true iff all bits in the bit array are 0 // IsEmpty returns true iff all bits in the bit array are 0

View File

@ -75,87 +75,34 @@ func TestOr(t *testing.T) {
} }
} }
func TestSub1(t *testing.T) { func TestSub(t *testing.T) {
testCases := []struct {
bA1, _ := randBitArray(31) initBA string
bA2, _ := randBitArray(51) subtractingBA string
bA3 := bA1.Sub(bA2) expectedBA string
}{
bNil := (*BitArray)(nil) {`null`, `null`, `null`},
require.Equal(t, bNil.Sub(bA1), (*BitArray)(nil)) {`"x"`, `null`, `null`},
require.Equal(t, bA1.Sub(nil), (*BitArray)(nil)) {`null`, `"x"`, `null`},
require.Equal(t, bNil.Sub(nil), (*BitArray)(nil)) {`"x"`, `"x"`, `"_"`},
{`"xxxxxx"`, `"x_x_x_"`, `"_x_x_x"`},
if bA3.Bits != bA1.Bits { {`"x_x_x_"`, `"xxxxxx"`, `"______"`},
t.Error("Expected bA1 bits") {`"xxxxxx"`, `"x_x_x_xxxx"`, `"_x_x_x"`},
{`"x_x_x_xxxx"`, `"xxxxxx"`, `"______xxxx"`},
{`"xxxxxxxxxx"`, `"x_x_x_"`, `"_x_x_xxxxx"`},
{`"x_x_x_"`, `"xxxxxxxxxx"`, `"______"`},
} }
if len(bA3.Elems) != len(bA1.Elems) { for _, tc := range testCases {
t.Error("Expected bA1 elems length") var bA *BitArray
} err := json.Unmarshal([]byte(tc.initBA), &bA)
for i := 0; i < bA3.Bits; i++ { require.Nil(t, err)
expected := bA1.GetIndex(i)
if bA2.GetIndex(i) {
expected = false
}
if bA3.GetIndex(i) != expected {
t.Error("Wrong bit from bA3", i, bA1.GetIndex(i), bA2.GetIndex(i), bA3.GetIndex(i))
}
}
}
func TestSub2(t *testing.T) { var o *BitArray
err = json.Unmarshal([]byte(tc.subtractingBA), &o)
require.Nil(t, err)
bA1, _ := randBitArray(51) got, _ := json.Marshal(bA.Sub(o))
bA2, _ := randBitArray(31) require.Equal(t, tc.expectedBA, string(got), "%s minus %s doesn't equal %s", tc.initBA, tc.subtractingBA, tc.expectedBA)
bA3 := bA1.Sub(bA2)
bNil := (*BitArray)(nil)
require.Equal(t, bNil.Sub(bA1), (*BitArray)(nil))
require.Equal(t, bA1.Sub(nil), (*BitArray)(nil))
require.Equal(t, bNil.Sub(nil), (*BitArray)(nil))
if bA3.Bits != bA1.Bits {
t.Error("Expected bA1 bits")
}
if len(bA3.Elems) != len(bA1.Elems) {
t.Error("Expected bA1 elems length")
}
for i := 0; i < bA3.Bits; i++ {
expected := bA1.GetIndex(i)
if i < bA2.Bits && bA2.GetIndex(i) {
expected = false
}
if bA3.GetIndex(i) != expected {
t.Error("Wrong bit from bA3")
}
}
}
func TestSub3(t *testing.T) {
bA1, _ := randBitArray(231)
bA2, _ := randBitArray(81)
bA3 := bA1.Sub(bA2)
bNil := (*BitArray)(nil)
require.Equal(t, bNil.Sub(bA1), (*BitArray)(nil))
require.Equal(t, bA1.Sub(nil), (*BitArray)(nil))
require.Equal(t, bNil.Sub(nil), (*BitArray)(nil))
if bA3.Bits != bA1.Bits {
t.Error("Expected bA1 bits")
}
if len(bA3.Elems) != len(bA1.Elems) {
t.Error("Expected bA1 elems length")
}
for i := 0; i < bA3.Bits; i++ {
expected := bA1.GetIndex(i)
if i < bA2.Bits && bA2.GetIndex(i){
expected = false
}
if bA3.GetIndex(i) != expected {
t.Error("Wrong bit from bA3")
}
} }
} }