diff --git a/binary/binary.go b/binary/binary.go index f13e38b1..0ad636bf 100644 --- a/binary/binary.go +++ b/binary/binary.go @@ -3,7 +3,5 @@ package binary import "io" type Binary interface { - ByteSize() int WriteTo(io.Writer) (int64, error) - Equals(Binary) bool } diff --git a/binary/codec.go b/binary/codec.go index 96e73e24..b729440f 100644 --- a/binary/codec.go +++ b/binary/codec.go @@ -1,22 +1,26 @@ package binary -const ( - TYPE_NIL = byte(0x00) - TYPE_BYTE = byte(0x01) - TYPE_INT8 = byte(0x02) - TYPE_UINT8 = byte(0x03) - TYPE_INT16 = byte(0x04) - TYPE_UINT16 = byte(0x05) - TYPE_INT32 = byte(0x06) - TYPE_UINT32 = byte(0x07) - TYPE_INT64 = byte(0x08) - TYPE_UINT64 = byte(0x09) - - TYPE_STRING = byte(0x10) - TYPE_BYTESLICE = byte(0x11) +import ( + "io" ) -func GetBinaryType(o Binary) byte { +const ( + TYPE_NIL = Byte(0x00) + TYPE_BYTE = Byte(0x01) + TYPE_INT8 = Byte(0x02) + TYPE_UINT8 = Byte(0x03) + TYPE_INT16 = Byte(0x04) + TYPE_UINT16 = Byte(0x05) + TYPE_INT32 = Byte(0x06) + TYPE_UINT32 = Byte(0x07) + TYPE_INT64 = Byte(0x08) + TYPE_UINT64 = Byte(0x09) + + TYPE_STRING = Byte(0x10) + TYPE_BYTESLICE = Byte(0x11) +) + +func GetBinaryType(o Binary) Byte { switch o.(type) { case nil: return TYPE_NIL case Byte: return TYPE_BYTE @@ -38,22 +42,22 @@ func GetBinaryType(o Binary) byte { } } -func ReadBinary(buf []byte, start int) (Binary, int) { - typeByte := buf[start] - switch typeByte { - case TYPE_NIL: return nil, start+1 - case TYPE_BYTE: return ReadByte(buf[start+1:]), start+2 - case TYPE_INT8: return ReadInt8(buf[start+1:]), start+2 - case TYPE_UINT8: return ReadUInt8(buf[start+1:]), start+2 - case TYPE_INT16: return ReadInt16(buf[start+1:]), start+3 - case TYPE_UINT16: return ReadUInt16(buf[start+1:]), start+3 - case TYPE_INT32: return ReadInt32(buf[start+1:]), start+5 - case TYPE_UINT32: return ReadUInt32(buf[start+1:]), start+5 - case TYPE_INT64: return ReadInt64(buf[start+1:]), start+9 - case TYPE_UINT64: return ReadUInt64(buf[start+1:]), start+9 +func ReadBinary(r io.Reader) Binary { + type_ := ReadByte(r) + switch type_ { + case TYPE_NIL: return nil + case TYPE_BYTE: return ReadByte(r) + case TYPE_INT8: return ReadInt8(r) + case TYPE_UINT8: return ReadUInt8(r) + case TYPE_INT16: return ReadInt16(r) + case TYPE_UINT16: return ReadUInt16(r) + case TYPE_INT32: return ReadInt32(r) + case TYPE_UINT32: return ReadUInt32(r) + case TYPE_INT64: return ReadInt64(r) + case TYPE_UINT64: return ReadUInt64(r) - case TYPE_STRING: return ReadString(buf, start+1) - case TYPE_BYTESLICE:return ReadByteSlice(buf, start+1) + case TYPE_STRING: return ReadString(r) + case TYPE_BYTESLICE:return ReadByteSlice(r) default: panic("Unsupported type") } diff --git a/binary/int.go b/binary/int.go index f7de9cd4..9aeac271 100644 --- a/binary/int.go +++ b/binary/int.go @@ -41,8 +41,11 @@ func (self Byte) WriteTo(w io.Writer) (int64, error) { return int64(n), err } -func ReadByte(bytes []byte) Byte { - return Byte(bytes[0]) +func ReadByte(r io.Reader) Byte { + buf := [1]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { panic(err) } + return Byte(buf[0]) } @@ -69,8 +72,11 @@ func (self Int8) WriteTo(w io.Writer) (int64, error) { return int64(n), err } -func ReadInt8(bytes []byte) Int8 { - return Int8(bytes[0]) +func ReadInt8(r io.Reader) Int8 { + buf := [1]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { panic(err) } + return Int8(buf[0]) } @@ -97,8 +103,11 @@ func (self UInt8) WriteTo(w io.Writer) (int64, error) { return int64(n), err } -func ReadUInt8(bytes []byte) UInt8 { - return UInt8(bytes[0]) +func ReadUInt8(r io.Reader) UInt8 { + buf := [1]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { panic(err) } + return UInt8(buf[0]) } @@ -125,8 +134,11 @@ func (self Int16) WriteTo(w io.Writer) (int64, error) { return 2, err } -func ReadInt16(bytes []byte) Int16 { - return Int16(binary.LittleEndian.Uint16(bytes)) +func ReadInt16(r io.Reader) Int16 { + buf := [2]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { panic(err) } + return Int16(binary.LittleEndian.Uint16(buf[:])) } @@ -153,8 +165,11 @@ func (self UInt16) WriteTo(w io.Writer) (int64, error) { return 2, err } -func ReadUInt16(bytes []byte) UInt16 { - return UInt16(binary.LittleEndian.Uint16(bytes)) +func ReadUInt16(r io.Reader) UInt16 { + buf := [2]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { panic(err) } + return UInt16(binary.LittleEndian.Uint16(buf[:])) } @@ -181,8 +196,11 @@ func (self Int32) WriteTo(w io.Writer) (int64, error) { return 4, err } -func ReadInt32(bytes []byte) Int32 { - return Int32(binary.LittleEndian.Uint32(bytes)) +func ReadInt32(r io.Reader) Int32 { + buf := [4]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { panic(err) } + return Int32(binary.LittleEndian.Uint32(buf[:])) } @@ -209,8 +227,11 @@ func (self UInt32) WriteTo(w io.Writer) (int64, error) { return 4, err } -func ReadUInt32(bytes []byte) UInt32 { - return UInt32(binary.LittleEndian.Uint32(bytes)) +func ReadUInt32(r io.Reader) UInt32 { + buf := [4]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { panic(err) } + return UInt32(binary.LittleEndian.Uint32(buf[:])) } @@ -237,8 +258,11 @@ func (self Int64) WriteTo(w io.Writer) (int64, error) { return 8, err } -func ReadInt64(bytes []byte) Int64 { - return Int64(binary.LittleEndian.Uint64(bytes)) +func ReadInt64(r io.Reader) Int64 { + buf := [8]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { panic(err) } + return Int64(binary.LittleEndian.Uint64(buf[:])) } @@ -265,8 +289,11 @@ func (self UInt64) WriteTo(w io.Writer) (int64, error) { return 8, err } -func ReadUInt64(bytes []byte) UInt64 { - return UInt64(binary.LittleEndian.Uint64(bytes)) +func ReadUInt64(r io.Reader) UInt64 { + buf := [8]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { panic(err) } + return UInt64(binary.LittleEndian.Uint64(buf[:])) } @@ -293,10 +320,14 @@ func (self Int) WriteTo(w io.Writer) (int64, error) { return 8, err } -func ReadInt(bytes []byte) Int { - return Int(binary.LittleEndian.Uint64(bytes)) +func ReadInt(r io.Reader) Int { + buf := [8]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { panic(err) } + return Int(binary.LittleEndian.Uint64(buf[:])) } + // UInt func (self UInt) Equals(other Binary) bool { @@ -320,6 +351,9 @@ func (self UInt) WriteTo(w io.Writer) (int64, error) { return 8, err } -func ReadUInt(bytes []byte) UInt { - return UInt(binary.LittleEndian.Uint64(bytes)) +func ReadUInt(r io.Reader) UInt { + buf := [8]byte{0} + _, err := io.ReadFull(r, buf[:]) + if err != nil { panic(err) } + return UInt(binary.LittleEndian.Uint64(buf[:])) } diff --git a/binary/string.go b/binary/string.go index 3dd91cbe..98a94cfc 100644 --- a/binary/string.go +++ b/binary/string.go @@ -32,10 +32,12 @@ func (self String) WriteTo(w io.Writer) (n int64, err error) { return int64(n_+4), err } -// NOTE: keeps a reference to the original byte slice -func ReadString(bytes []byte, start int) (String, int) { - length := int(ReadUInt32(bytes[start:])) - return String(bytes[start+4:start+4+length]), start+4+length +func ReadString(r io.Reader) String { + length := int(ReadUInt32(r)) + bytes := make([]byte, length) + _, err := io.ReadFull(r, bytes) + if err != nil { panic(err) } + return String(bytes) } @@ -69,8 +71,10 @@ func (self ByteSlice) WriteTo(w io.Writer) (n int64, err error) { return int64(n_+4), err } -// NOTE: keeps a reference to the original byte slice -func ReadByteSlice(bytes []byte, start int) (ByteSlice, int) { - length := int(ReadUInt32(bytes[start:])) - return ByteSlice(bytes[start+4:start+4+length]), start+4+length +func ReadByteSlice(r io.Reader) ByteSlice { + length := int(ReadUInt32(r)) + bytes := make([]byte, length) + _, err := io.ReadFull(r, bytes) + if err != nil { panic(err) } + return ByteSlice(bytes) } diff --git a/binary/util.go b/binary/util.go new file mode 100644 index 00000000..4087378c --- /dev/null +++ b/binary/util.go @@ -0,0 +1,33 @@ +package binary + +import ( + "crypto/sha256" + "bytes" +) + +func BinaryBytes(b Binary) ByteSlice { + buf := bytes.NewBuffer(nil) + b.WriteTo(buf) + return ByteSlice(buf.Bytes()) +} + +// NOTE: does not care about the type, only the binary representation. +func BinaryEqual(a, b Binary) bool { + aBytes := BinaryBytes(a) + bBytes := BinaryBytes(b) + return bytes.Equal(aBytes, bBytes) +} + +// NOTE: does not care about the type, only the binary representation. +func BinaryCompare(a, b Binary) int { + aBytes := BinaryBytes(a) + bBytes := BinaryBytes(b) + return bytes.Compare(aBytes, bBytes) +} + +func BinaryHash(b Binary) ByteSlice { + hasher := sha256.New() + _, err := b.WriteTo(hasher) + if err != nil { panic(err) } + return ByteSlice(hasher.Sum(nil)) +} diff --git a/merkle/iavl.go b/merkle/iavl.go index 44680df0..a33dd2ce 100644 --- a/merkle/iavl.go +++ b/merkle/iavl.go @@ -134,14 +134,6 @@ func (self *IAVLNode) Copy() *IAVLNode { } } -func (self *IAVLNode) Equals(other Binary) bool { - if o, ok := other.(*IAVLNode); ok { - return self.hash.Equals(o.hash) - } else { - return false - } -} - func (self *IAVLNode) Key() Key { return self.key } @@ -218,11 +210,10 @@ func (self *IAVLNode) Save(db Db) { } // save self - buf := make([]byte, 0, self.ByteSize()) - n, err := self.WriteTo(bytes.NewBuffer(buf)) + buf := bytes.NewBuffer(nil) + _, err := self.WriteTo(buf) if err != nil { panic(err) } - if n != int64(cap(buf)) { panic("unexpected write length") } - db.Put([]byte(self.hash), buf[0:cap(buf)]) + db.Put([]byte(self.hash), buf.Bytes()) self.flags |= IAVLNODE_FLAG_PERSISTED } @@ -303,27 +294,6 @@ func (self *IAVLNode) remove(db Db, key Key) (newSelf *IAVLNode, newKey Key, val } } -func (self *IAVLNode) ByteSize() int { - // 1 byte node height - // 8 bytes node size - size := 9 - // key - size += 1 // type info - size += self.key.ByteSize() - if self.height == 0 { - // value - size += 1 // type info - if self.value != nil { - size += self.value.ByteSize() - } - } else { - // children - size += HASH_BYTE_SIZE - size += HASH_BYTE_SIZE - } - return size -} - func (self *IAVLNode) WriteTo(w io.Writer) (n int64, err error) { n, _, err = self.saveToCountHashes(w, true) return @@ -378,33 +348,33 @@ func (self *IAVLNode) fill(db Db) { panic("placeholder.hash can't be nil") } buf := db.Get(self.hash) - cur := 0 + r := bytes.NewReader(buf) // node header - self.height = uint8(ReadUInt8(buf[0:])) - self.size = uint64(ReadUInt64(buf[1:])) + self.height = uint8(ReadUInt8(r)) + self.size = uint64(ReadUInt64(r)) // key - key, cur := ReadBinary(buf, 9) + key := ReadBinary(r) self.key = key.(Key) if self.height == 0 { // value - self.value, cur = ReadBinary(buf, cur) + self.value = ReadBinary(r) } else { // left var leftHash ByteSlice - leftHash, cur = ReadByteSlice(buf, cur) + leftHash = ReadByteSlice(r) self.left = &IAVLNode{ hash: leftHash, flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, } // right var rightHash ByteSlice - rightHash, cur = ReadByteSlice(buf, cur) + rightHash = ReadByteSlice(r) self.right = &IAVLNode{ hash: rightHash, flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, } - if cur != len(buf) { + if r.Len() != 0 { panic("buf not all consumed") } } diff --git a/merkle/iavl_test.go b/merkle/iavl_test.go index c81f7298..7fdf5887 100644 --- a/merkle/iavl_test.go +++ b/merkle/iavl_test.go @@ -233,7 +233,7 @@ func TestPersistence(t *testing.T) { t2 := NewIAVLTreeFromHash(db, hash) for key, value := range records { t2value := t2.Get(key) - if !t2value.Equals(value) { + if !BinaryEqual(t2value, value) { t.Fatalf("Invalid value. Expected %v, got %v", value, t2value) } } diff --git a/merkle/types.go b/merkle/types.go index b379adba..83ecb193 100644 --- a/merkle/types.go +++ b/merkle/types.go @@ -11,6 +11,7 @@ type Value interface { type Key interface { Binary + Equals(Binary) bool Less(b Binary) bool }