diff --git a/crypto/multisig/compact_bit_array.go b/crypto/multisig/compact_bit_array.go new file mode 100644 index 00000000..d14dd8e6 --- /dev/null +++ b/crypto/multisig/compact_bit_array.go @@ -0,0 +1,218 @@ +package multisig + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "regexp" + "strings" +) + +// CompactBitArray is an implementation of a space efficient bit array. +// This is used to ensure that the encoded data takes up a minimal amount of +// space after amino encoding. +// This is not thread safe, and is not intended for concurrent usage. +type CompactBitArray struct { + ExtraBitsStored byte `json:"extra_bits"` // The number of extra bits in elems. + Elems []byte `json:"bits"` +} + +// NewCompactBitArray returns a new compact bit array. +// It returns nil if the number of bits is zero. +func NewCompactBitArray(bits int) *CompactBitArray { + if bits <= 0 { + return nil + } + return &CompactBitArray{ + ExtraBitsStored: byte(bits % 8), + Elems: make([]byte, (bits+7)/8), + } +} + +// Size returns the number of bits in the bitarray +func (bA *CompactBitArray) Size() int { + if bA == nil { + return 0 + } else if bA.ExtraBitsStored == byte(0) { + return len(bA.Elems) * 8 + } + return (len(bA.Elems)-1)*8 + int(bA.ExtraBitsStored) +} + +// GetIndex returns the bit at index i within the bit array. +// The behavior is undefined if i >= bA.Size() +func (bA *CompactBitArray) GetIndex(i int) bool { + if bA == nil { + return false + } + if i >= bA.Size() { + return false + } + return bA.Elems[i>>3]&(uint8(1)< 0 +} + +// SetIndex sets the bit at index i within the bit array. +// The behavior is undefined if i >= bA.Size() +func (bA *CompactBitArray) SetIndex(i int, v bool) bool { + if bA == nil { + return false + } + if i >= bA.Size() { + return false + } + if v { + bA.Elems[i>>3] |= (uint8(1) << uint8(7-(i%8))) + } else { + bA.Elems[i>>3] &= ^(uint8(1) << uint8(7-(i%8))) + } + return true +} + +// Copy returns a copy of the provided bit array. +func (bA *CompactBitArray) Copy() *CompactBitArray { + if bA == nil { + return nil + } + c := make([]byte, len(bA.Elems)) + copy(c, bA.Elems) + return &CompactBitArray{ + ExtraBitsStored: bA.ExtraBitsStored, + Elems: c, + } +} + +// String returns a string representation of CompactBitArray: BA{}, +// where is a sequence of 'x' (1) and '_' (0). +// The includes spaces and newlines to help people. +// For a simple sequence of 'x' and '_' characters with no spaces or newlines, +// see the MarshalJSON() method. +// Example: "BA{_x_}" or "nil-BitArray" for nil. +func (bA *CompactBitArray) String() string { + return bA.StringIndented("") +} + +// StringIndented returns the same thing as String(), but applies the indent +// at every 10th bit, and twice at every 50th bit. +func (bA *CompactBitArray) StringIndented(indent string) string { + if bA == nil { + return "nil-BitArray" + } + lines := []string{} + bits := "" + size := bA.Size() + for i := 0; i < size; i++ { + if bA.GetIndex(i) { + bits += "x" + } else { + bits += "_" + } + if i%100 == 99 { + lines = append(lines, bits) + bits = "" + } + if i%10 == 9 { + bits += indent + } + if i%50 == 49 { + bits += indent + } + } + if len(bits) > 0 { + lines = append(lines, bits) + } + return fmt.Sprintf("BA{%v:%v}", size, strings.Join(lines, indent)) +} + +// MarshalJSON implements json.Marshaler interface by marshaling bit array +// using a custom format: a string of '-' or 'x' where 'x' denotes the 1 bit. +func (bA *CompactBitArray) MarshalJSON() ([]byte, error) { + if bA == nil { + return []byte("null"), nil + } + + bits := `"` + size := bA.Size() + for i := 0; i < size; i++ { + if bA.GetIndex(i) { + bits += `x` + } else { + bits += `_` + } + } + bits += `"` + return []byte(bits), nil +} + +var bitArrayJSONRegexp = regexp.MustCompile(`\A"([_x]*)"\z`) + +// UnmarshalJSON implements json.Unmarshaler interface by unmarshaling a custom +// JSON description. +func (bA *CompactBitArray) UnmarshalJSON(bz []byte) error { + b := string(bz) + if b == "null" { + // This is required e.g. for encoding/json when decoding + // into a pointer with pre-allocated BitArray. + bA.ExtraBitsStored = 0 + bA.Elems = nil + return nil + } + + // Validate 'b'. + match := bitArrayJSONRegexp.FindStringSubmatch(b) + if match == nil { + return fmt.Errorf("BitArray in JSON should be a string of format %q but got %s", bitArrayJSONRegexp.String(), b) + } + bits := match[1] + + // Construct new CompactBitArray and copy over. + numBits := len(bits) + bA2 := NewCompactBitArray(numBits) + for i := 0; i < numBits; i++ { + if bits[i] == 'x' { + bA2.SetIndex(i, true) + } + } + *bA = *bA2 + return nil +} + +// CompactMarshal is a space efficient encoding for CompactBitArray. +// It is not amino compatible. +func (bA *CompactBitArray) CompactMarshal() []byte { + size := bA.Size() + if size <= 0 { + return []byte("null") + } + bz := make([]byte, 0, size/8) + // length prefix number of bits, not number of bytes. This difference + // takes 3-4 bits in encoding, as opposed to instead encoding the number of + // bytes (saving 3-4 bits) and including the offset as a full byte. + bz = appendUvarint(bz, uint64(size)) + bz = append(bz, bA.Elems...) + return bz +} + +// CompactUnmarshal is a space efficient decoding for CompactBitArray. +// It is not amino compatible. +func CompactUnmarshal(bz []byte) (*CompactBitArray, error) { + if len(bz) < 2 { + return nil, errors.New("compact bit array: invalid compact unmarshal size") + } else if bytes.Equal(bz, []byte("null")) { + return NewCompactBitArray(0), nil + } + size, n := binary.Uvarint(bz) + bz = bz[n:] + if len(bz) != int(size+7)/8 { + return nil, errors.New("compact bit array: invalid compact unmarshal size") + } + + bA := &CompactBitArray{byte(int(size % 8)), bz} + return bA, nil +} + +func appendUvarint(b []byte, x uint64) []byte { + var a [binary.MaxVarintLen64]byte + n := binary.PutUvarint(a[:], x) + return append(b, a[:n]...) +} diff --git a/crypto/multisig/compact_bit_array_test.go b/crypto/multisig/compact_bit_array_test.go new file mode 100644 index 00000000..91a82192 --- /dev/null +++ b/crypto/multisig/compact_bit_array_test.go @@ -0,0 +1,169 @@ +package multisig + +import ( + "encoding/json" + "math/rand" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cmn "github.com/tendermint/tendermint/libs/common" +) + +func randCompactBitArray(bits int) (*CompactBitArray, []byte) { + numBytes := (bits + 7) / 8 + src := cmn.RandBytes((bits + 7) / 8) + bA := NewCompactBitArray(bits) + + for i := 0; i < numBytes-1; i++ { + for j := uint8(0); j < 8; j++ { + bA.SetIndex(i*8+int(j), src[i]&(uint8(1)<<(8-j)) > 0) + } + } + // Set remaining bits + for i := uint8(0); i < 8-uint8(bA.ExtraBitsStored); i++ { + bA.SetIndex(numBytes*8+int(i), src[numBytes-1]&(uint8(1)<<(8-i)) > 0) + } + return bA, src +} + +func TestNewBitArrayNeverCrashesOnNegatives(t *testing.T) { + bitList := []int{-127, -128, -1 << 31} + for _, bits := range bitList { + _ = NewCompactBitArray(bits) + } +} + +func TestJSONMarshalUnmarshal(t *testing.T) { + + bA1 := NewCompactBitArray(0) + bA2 := NewCompactBitArray(1) + + bA3 := NewCompactBitArray(1) + bA3.SetIndex(0, true) + + bA4 := NewCompactBitArray(5) + bA4.SetIndex(0, true) + bA4.SetIndex(1, true) + + bA5 := NewCompactBitArray(9) + bA5.SetIndex(0, true) + bA5.SetIndex(1, true) + bA5.SetIndex(8, true) + + bA6 := NewCompactBitArray(16) + bA6.SetIndex(0, true) + bA6.SetIndex(1, true) + bA6.SetIndex(8, false) + bA6.SetIndex(15, true) + + testCases := []struct { + bA *CompactBitArray + marshalledBA string + }{ + {nil, `null`}, + {bA1, `null`}, + {bA2, `"_"`}, + {bA3, `"x"`}, + {bA4, `"xx___"`}, + {bA5, `"xx______x"`}, + {bA6, `"xx_____________x"`}, + } + + for _, tc := range testCases { + t.Run(tc.bA.String(), func(t *testing.T) { + bz, err := json.Marshal(tc.bA) + require.NoError(t, err) + + assert.Equal(t, tc.marshalledBA, string(bz)) + + var unmarshalledBA *CompactBitArray + err = json.Unmarshal(bz, &unmarshalledBA) + require.NoError(t, err) + + if tc.bA == nil { + require.Nil(t, unmarshalledBA) + } else { + require.NotNil(t, unmarshalledBA) + assert.EqualValues(t, tc.bA.Elems, unmarshalledBA.Elems) + if assert.EqualValues(t, tc.bA.String(), unmarshalledBA.String()) { + assert.EqualValues(t, tc.bA.Elems, unmarshalledBA.Elems) + } + } + }) + } +} + +func TestCompactMarshalUnmarshal(t *testing.T) { + bA1 := NewCompactBitArray(0) + bA2 := NewCompactBitArray(1) + + bA3 := NewCompactBitArray(1) + bA3.SetIndex(0, true) + + bA4 := NewCompactBitArray(5) + bA4.SetIndex(0, true) + bA4.SetIndex(1, true) + + bA5 := NewCompactBitArray(9) + bA5.SetIndex(0, true) + bA5.SetIndex(1, true) + bA5.SetIndex(8, true) + + bA6 := NewCompactBitArray(16) + bA6.SetIndex(0, true) + bA6.SetIndex(1, true) + bA6.SetIndex(8, false) + bA6.SetIndex(15, true) + + testCases := []struct { + bA *CompactBitArray + marshalledBA []byte + }{ + {nil, []byte("null")}, + {bA1, []byte("null")}, + {bA2, []byte{byte(1), byte(0)}}, + {bA3, []byte{byte(1), byte(128)}}, + {bA4, []byte{byte(5), byte(192)}}, + {bA5, []byte{byte(9), byte(192), byte(128)}}, + {bA6, []byte{byte(16), byte(192), byte(1)}}, + } + + for _, tc := range testCases { + t.Run(tc.bA.String(), func(t *testing.T) { + bz := tc.bA.CompactMarshal() + + assert.Equal(t, tc.marshalledBA, bz) + + unmarshalledBA, err := CompactUnmarshal(bz) + require.NoError(t, err) + if tc.bA == nil { + require.Nil(t, unmarshalledBA) + } else { + require.NotNil(t, unmarshalledBA) + assert.EqualValues(t, tc.bA.Elems, unmarshalledBA.Elems) + if assert.EqualValues(t, tc.bA.String(), unmarshalledBA.String()) { + assert.EqualValues(t, tc.bA.Elems, unmarshalledBA.Elems) + } + } + }) + } +} + +func TestCompactBitArrayGetSetIndex(t *testing.T) { + r := rand.New(rand.NewSource(100)) + numTests := 10 + numBitsPerArr := 100 + for i := 0; i < numTests; i++ { + bits := r.Intn(1000) + bA, _ := randCompactBitArray(bits) + + for j := 0; j < numBitsPerArr; j++ { + copy := bA.Copy() + index := r.Intn(bits) + val := (r.Int63() % 2) == 0 + bA.SetIndex(index, val) + require.Equal(t, val, bA.GetIndex(index), "bA.SetIndex(%d, %v) failed on bit array: %s", index, val, copy) + } + } +}