diff --git a/crypto/types/compact_bit_array.go b/crypto/types/compact_bit_array.go index 2d81636f4..9ac23212a 100644 --- a/crypto/types/compact_bit_array.go +++ b/crypto/types/compact_bit_array.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "math" "regexp" "strings" ) @@ -15,14 +16,24 @@ import ( // This is not thread safe, and is not intended for concurrent usage. // NewCompactBitArray returns a new compact bit array. -// It returns nil if the number of bits is zero. +// It returns nil if the number of bits is zero, or if there is any overflow +// in the arithmetic to encounter for the number of its elements: (bits+7)/8, +// or if the number of elements will be an unreasonably large number like +// > maxint32 aka >2**31. func NewCompactBitArray(bits int) *CompactBitArray { if bits <= 0 { return nil } + nElems := (bits + 7) / 8 + if nElems <= 0 || nElems > math.MaxInt32 { + // We encountered an overflow here, and shouldn't pass negatives + // to make, nor should we allow unreasonable limits > maxint32. + // See https://github.com/cosmos/cosmos-sdk/issues/9162 + return nil + } return &CompactBitArray{ ExtraBitsStored: uint32(bits % 8), - Elems: make([]byte, (bits+7)/8), + Elems: make([]byte, nElems), } } diff --git a/crypto/types/compact_bit_array_test.go b/crypto/types/compact_bit_array_test.go index 44f97d6f4..05bc73eb3 100644 --- a/crypto/types/compact_bit_array_test.go +++ b/crypto/types/compact_bit_array_test.go @@ -2,6 +2,8 @@ package types import ( "encoding/json" + "fmt" + "math" "math/rand" "testing" @@ -239,3 +241,35 @@ func BenchmarkNumTrueBitsBefore(b *testing.B) { } }) } + +func TestNewCompactBitArrayCrashWithLimits(t *testing.T) { + if testing.Short() { + t.Skip("This test can be expensive in memory") + } + tests := []struct { + in int + mustPass bool + }{ + {int(^uint(0) >> 30), false}, + {int(^uint(0) >> 1), false}, + {int(^uint(0) >> 2), false}, + {int(math.MaxInt32), true}, + {int(math.MaxInt32) + 1, true}, + {int(math.MaxInt32) + 2, true}, + {int(math.MaxInt32) - 7, true}, + {int(math.MaxInt32) + 24, true}, + {int(math.MaxInt32) * 9, false}, // results in >=maxint after (bits+7)/8 + {1, true}, + {0, false}, + } + + for _, tt := range tests { + tt := tt + t.Run(fmt.Sprintf("%d", tt.in), func(t *testing.T) { + got := NewCompactBitArray(tt.in) + if g := got != nil; g != tt.mustPass { + t.Fatalf("got!=nil=%t, want=%t", g, tt.mustPass) + } + }) + } +}