use wrappers.packer instead of byte array

This commit is contained in:
Dan Laine 2020-06-12 10:41:02 -04:00
parent 7879dd1768
commit 617a158097
2 changed files with 154 additions and 251 deletions

View File

@ -4,18 +4,19 @@
package codec
import (
"encoding/binary"
"errors"
"fmt"
"math"
"reflect"
"unicode"
"github.com/ava-labs/gecko/utils/wrappers"
)
const (
defaultMaxSize = 1 << 18 // default max size, in bytes, of something being marshalled by Marshal()
defaultMaxSliceLength = 1 << 18 // default max length of a slice being marshalled by Marshal()
maxStringLen = math.MaxInt16
maxStringLen = math.MaxUint16
)
// ErrBadCodec is returned when one tries to perform an operation
@ -84,8 +85,7 @@ func (c codec) RegisterType(val interface{}) error {
// structs, slices and arrays can only be serialized if their constituent values can be.
// 5) To marshal an interface, you must pass a pointer to the value
// 6) To unmarshal an interface, you must call codec.RegisterType([instance of the type that fulfills the interface]).
// 7) nil slices will be unmarshaled as an empty slice of the appropriate type
// 8) Serialized fields must be exported
// 7) Serialized fields must be exported
// To marshal an interface, [value] must be a pointer to the interface
func (c codec) Marshal(value interface{}) ([]byte, error) {
@ -97,11 +97,11 @@ func (c codec) Marshal(value interface{}) ([]byte, error) {
return nil, err
}
bytes := make([]byte, size, size)
if err := f(bytes); err != nil {
p := &wrappers.Packer{MaxSize: size, Bytes: make([]byte, 0, size)}
if err := f(p); err != nil {
return nil, err
}
return bytes, nil
return p.Bytes, nil
}
// marshal returns:
@ -110,7 +110,7 @@ func (c codec) Marshal(value interface{}) ([]byte, error) {
// and returns the number of bytes it wrote.
// When these functions are called in order, they write [value] to a byte slice.
// 3) An error
func (c codec) marshal(value reflect.Value) (size int, f func([]byte) error, err error) {
func (c codec) marshal(value reflect.Value) (size int, f func(*wrappers.Packer) error, err error) {
valueKind := value.Kind()
// Case: Value can't be marshalled
@ -125,116 +125,73 @@ func (c codec) marshal(value reflect.Value) (size int, f func([]byte) error, err
switch valueKind {
case reflect.Uint8:
size = 1
f = func(b []byte) error {
if bytesLen := len(b); bytesLen < 1 {
return fmt.Errorf("expected len(bytes) to be at least 1 but is %d", bytesLen)
}
copy(b, []byte{byte(value.Uint())})
return nil
f = func(p *wrappers.Packer) error {
p.PackByte(byte(value.Uint()))
return p.Err
}
return
case reflect.Int8:
size = 1
f = func(b []byte) error {
if bytesLen := len(b); bytesLen < 1 {
return fmt.Errorf("expected len(bytes) to be at least 1 but is %d", bytesLen)
}
copy(b, []byte{byte(value.Int())})
return nil
f = func(p *wrappers.Packer) error {
p.PackByte(byte(value.Int()))
return p.Err
}
return
case reflect.Uint16:
size = 2
f = func(b []byte) error {
if bytesLen := len(b); bytesLen < 2 {
return fmt.Errorf("expected len(bytes) to be at least 2 but is %d", bytesLen)
}
binary.BigEndian.PutUint16(b, uint16(value.Uint()))
return nil
f = func(p *wrappers.Packer) error {
p.PackShort(uint16(value.Uint()))
return p.Err
}
return
case reflect.Int16:
size = 2
f = func(b []byte) error {
if bytesLen := len(b); bytesLen < 2 {
return fmt.Errorf("expected len(bytes) to be at least 2 but is %d", bytesLen)
}
binary.BigEndian.PutUint16(b, uint16(value.Int()))
return nil
f = func(p *wrappers.Packer) error {
p.PackShort(uint16(value.Int()))
return p.Err
}
return
case reflect.Uint32:
size = 4
f = func(b []byte) error {
if bytesLen := len(b); bytesLen < 4 {
return fmt.Errorf("expected len(bytes) to be at least 4 but is %d", bytesLen)
}
binary.BigEndian.PutUint32(b, uint32(value.Uint()))
return nil
f = func(p *wrappers.Packer) error {
p.PackInt(uint32(value.Uint()))
return p.Err
}
return
case reflect.Int32:
size = 4
f = func(b []byte) error {
if bytesLen := len(b); bytesLen < 4 {
return fmt.Errorf("expected len(bytes) to be at least 4 but is %d", bytesLen)
}
binary.BigEndian.PutUint32(b, uint32(value.Int()))
return nil
f = func(p *wrappers.Packer) error {
p.PackInt(uint32(value.Int()))
return p.Err
}
return
case reflect.Uint64:
size = 8
f = func(b []byte) error {
if bytesLen := len(b); bytesLen < 8 {
return fmt.Errorf("expected len(bytes) to be at least 8 but is %d", bytesLen)
}
binary.BigEndian.PutUint64(b, uint64(value.Uint()))
return nil
f = func(p *wrappers.Packer) error {
p.PackLong(uint64(value.Uint()))
return p.Err
}
return
case reflect.Int64:
size = 8
f = func(b []byte) error {
if bytesLen := len(b); bytesLen < 8 {
return fmt.Errorf("expected len(bytes) to be at least 8 but is %d", bytesLen)
}
binary.BigEndian.PutUint64(b, uint64(value.Int()))
return nil
f = func(p *wrappers.Packer) error {
p.PackLong(uint64(value.Int()))
return p.Err
}
return
case reflect.String:
asStr := value.String()
strSize := len(asStr)
if strSize > maxStringLen {
return 0, nil, errSliceTooLarge
}
size = strSize + 2
f = func(b []byte) error {
if bytesLen := len(b); bytesLen < size {
return fmt.Errorf("expected len(bytes) to be at least %d but is %d", size, bytesLen)
}
binary.BigEndian.PutUint16(b, uint16(strSize))
if strSize == 0 {
return nil
}
copy(b[2:], []byte(asStr))
return nil
size = len(asStr) + 2
f = func(p *wrappers.Packer) error {
p.PackStr(asStr)
return p.Err
}
return
case reflect.Bool:
size = 1
f = func(b []byte) error {
if bytesLen := len(b); bytesLen < 1 {
return fmt.Errorf("expected len(bytes) to be at least 1 but is %d", bytesLen)
}
if value.Bool() {
copy(b, []byte{1})
} else {
copy(b, []byte{0})
}
return nil
f = func(p *wrappers.Packer) error {
p.PackBool(value.Bool())
return p.Err
}
return
case reflect.Uintptr, reflect.Ptr:
@ -251,51 +208,41 @@ func (c codec) marshal(value reflect.Value) (size int, f func([]byte) error, err
}
size = 4 + subsize // 4 because we pack the type ID, a uint32
f = func(b []byte) error {
if bytesLen := len(b); bytesLen < 4+subsize {
return fmt.Errorf("expected len(bytes) to be at least %d but is %d", 4+subsize, bytesLen)
f = func(p *wrappers.Packer) error {
p.PackInt(typeID)
if p.Err != nil {
return p.Err
}
binary.BigEndian.PutUint32(b, uint32(typeID))
if len(b) == 4 {
return nil
}
return subfunc(b[4:])
return subfunc(p)
}
return
case reflect.Slice:
if value.IsNil() {
size = 1
f = func(b []byte) error {
if bytesLen := len(b); bytesLen < 1 {
return fmt.Errorf("expected len(bytes) to be at least 1 but is %d", bytesLen)
}
b[0] = 1 // slice is nil; set isNil flag to 1
return nil
f = func(p *wrappers.Packer) error {
p.PackBool(true) // slice is nil; set isNil flag to 1
return p.Err
}
return
}
numElts := value.Len() // # elements in the slice/array (assumed to be <= 2^31 - 1)
numElts := value.Len() // # elements in the slice/array (assumed to be <= math.MaxUint16)
if numElts > c.maxSliceLen {
return 0, nil, fmt.Errorf("slice length, %d, exceeds maximum length, %d", numElts, math.MaxUint32)
}
size = 5 // 1 for the isNil flag. 0 --> this slice isn't nil. 1--> it is nil.
// 4 for the size of the slice (uint32)
size = 3 // 1 for the isNil flag. 2 for the size of the slice (uint16)
// offsets[i] is the index in the byte array that subFuncs[i] will start writing at
offsets := make([]int, numElts+1, numElts+1)
if numElts != 0 {
offsets[1] = 5 // 1 for nil flag, 4 for slice size
offsets[1] = 3
}
subFuncs := make([]func([]byte) error, numElts+1, numElts+1)
subFuncs[0] = func(b []byte) error { // write the nil flag and number of elements
if bytesLen := len(b); bytesLen < 5 {
return fmt.Errorf("expected len(bytes) to be at least 5 but is %d", bytesLen)
}
b[0] = 0 // slice is non-nil; set isNil flag to 0
binary.BigEndian.PutUint32(b[1:], uint32(numElts))
return nil
subFuncs := make([]func(*wrappers.Packer) error, numElts+1, numElts+1)
subFuncs[0] = func(p *wrappers.Packer) error { // write the nil flag and number of elements
p.PackBool(false)
p.PackShort(uint16(numElts))
return p.Err
}
for i := 1; i < numElts+1; i++ { // Process each element in the slice
subSize, subFunc, subErr := c.marshal(value.Index(i - 1))
@ -313,14 +260,9 @@ func (c codec) marshal(value reflect.Value) (size int, f func([]byte) error, err
return 0, nil, fmt.Errorf("expected len(subFuncs) = %d. len(offsets) = %d. Should be same", subFuncsLen, len(offsets))
}
f = func(b []byte) error {
bytesLen := len(b)
for i, f := range subFuncs {
offset := offsets[i]
if offset > bytesLen {
return fmt.Errorf("attempted out of bounds slice. offset: %d. bytesLen: %d", offset, bytesLen)
}
if err := f(b[offset:]); err != nil {
f = func(p *wrappers.Packer) error {
for _, f := range subFuncs {
if err := f(p); err != nil {
return err
}
}
@ -328,7 +270,7 @@ func (c codec) marshal(value reflect.Value) (size int, f func([]byte) error, err
}
return
case reflect.Array:
numElts := value.Len() // # elements in the slice/array (assumed to be <= 2^31 - 1)
numElts := value.Len() // # elements in the slice/array (assumed to be <= math.MaxUint16)
if numElts > math.MaxUint32 {
return 0, nil, fmt.Errorf("array length, %d, exceeds maximum length, %d", numElts, math.MaxUint32)
}
@ -337,7 +279,7 @@ func (c codec) marshal(value reflect.Value) (size int, f func([]byte) error, err
// offsets[i] is the index in the byte array that subFuncs[i] will start writing at
offsets := make([]int, numElts, numElts)
offsets[1] = 4 // 4 for slice size
subFuncs := make([]func([]byte) error, numElts, numElts)
subFuncs := make([]func(*wrappers.Packer) error, numElts, numElts)
for i := 0; i < numElts; i++ { // Process each element in the array
subSize, subFunc, subErr := c.marshal(value.Index(i))
if subErr != nil {
@ -354,14 +296,9 @@ func (c codec) marshal(value reflect.Value) (size int, f func([]byte) error, err
return 0, nil, fmt.Errorf("expected len(subFuncs) = %d. len(offsets) = %d. Should be same", subFuncsLen, len(offsets))
}
f = func(b []byte) error {
bytesLen := len(b)
for i, f := range subFuncs {
offset := offsets[i]
if offset > bytesLen {
return fmt.Errorf("attempted out of bounds slice")
}
if err := f(b[offset:]); err != nil {
f = func(p *wrappers.Packer) error {
for _, f := range subFuncs {
if err := f(p); err != nil {
return err
}
}
@ -375,7 +312,7 @@ func (c codec) marshal(value reflect.Value) (size int, f func([]byte) error, err
// offsets[i] is the index in the byte array that subFuncs[i] will start writing at
offsets := make([]int, 0, numFields)
offsets = append(offsets, 0)
subFuncs := make([]func([]byte) error, 0, numFields)
subFuncs := make([]func(*wrappers.Packer) error, 0, numFields)
for i := 0; i < numFields; i++ { // Go through all fields of this struct
field := t.Field(i)
if !shouldSerialize(field) { // Skip fields we don't need to serialize
@ -396,14 +333,9 @@ func (c codec) marshal(value reflect.Value) (size int, f func([]byte) error, err
}
}
f = func(b []byte) error {
bytesLen := len(b)
for i, f := range subFuncs {
offset := offsets[i]
if offset > bytesLen {
return fmt.Errorf("attempted out of bounds slice")
}
if err := f(b[offset:]); err != nil {
f = func(p *wrappers.Packer) error {
for _, f := range subFuncs {
if err := f(p); err != nil {
return err
}
}
@ -430,94 +362,95 @@ func (c codec) Unmarshal(bytes []byte, dest interface{}) error {
return errNeedPointer
}
p := &wrappers.Packer{MaxSize: c.maxSize, Bytes: bytes}
destVal := destPtr.Elem()
bytesRead, err := c.unmarshal(bytes, destVal)
if err != nil {
if err := c.unmarshal(p, destVal); err != nil {
return err
}
if l := len(bytes); l != bytesRead {
return fmt.Errorf("%d leftover bytes after unmarshalling", l-bytesRead)
}
return nil
}
// Unmarshal bytes from [bytes] into [field]
// [field] must be addressable
// Returns the number of bytes read from [bytes]
func (c codec) unmarshal(bytes []byte, field reflect.Value) (int, error) {
bytesLen := len(bytes)
func (c codec) unmarshal(p *wrappers.Packer, field reflect.Value) error {
kind := field.Kind()
switch kind {
case reflect.Uint8:
if bytesLen < 1 {
return 0, fmt.Errorf("expected len(bytes) to be at least 1 but is %d", bytesLen)
b := p.UnpackByte()
if p.Err != nil {
return p.Err
}
field.SetUint(uint64(bytes[0]))
return 1, nil
field.SetUint(uint64(b))
return nil
case reflect.Int8:
if bytesLen < 1 {
return 0, fmt.Errorf("expected len(bytes) to be at least 1 but is %d", bytesLen)
b := p.UnpackByte()
if p.Err != nil {
return p.Err
}
field.SetInt(int64(bytes[0]))
return 1, nil
field.SetInt(int64(b))
return nil
case reflect.Uint16:
if bytesLen < 2 {
return 0, fmt.Errorf("expected len(bytes) to be at least 2 but is %d", bytesLen)
b := p.UnpackShort()
if p.Err != nil {
return p.Err
}
field.SetUint(uint64(binary.BigEndian.Uint16(bytes)))
return 2, nil
field.SetUint(uint64(b))
return nil
case reflect.Int16:
if bytesLen < 2 {
return 0, fmt.Errorf("expected len(bytes) to be at least 2 but is %d", bytesLen)
b := p.UnpackShort()
if p.Err != nil {
return p.Err
}
field.SetInt(int64(binary.BigEndian.Uint16(bytes)))
return 2, nil
field.SetInt(int64(b))
return nil
case reflect.Uint32:
if bytesLen < 4 {
return 0, fmt.Errorf("expected len(bytes) to be at least 4 but is %d", bytesLen)
b := p.UnpackInt()
if p.Err != nil {
return p.Err
}
field.SetUint(uint64(binary.BigEndian.Uint32(bytes)))
return 4, nil
field.SetUint(uint64(b))
return nil
case reflect.Int32:
if bytesLen < 4 {
return 0, fmt.Errorf("expected len(bytes) to be at least 4 but is %d", bytesLen)
b := p.UnpackInt()
if p.Err != nil {
return p.Err
}
field.SetInt(int64(binary.BigEndian.Uint32(bytes)))
return 4, nil
field.SetInt(int64(b))
return nil
case reflect.Uint64:
if bytesLen < 4 {
return 0, fmt.Errorf("expected len(bytes) to be at least 8 but is %d", bytesLen)
b := p.UnpackLong()
if p.Err != nil {
return p.Err
}
field.SetUint(uint64(binary.BigEndian.Uint64(bytes)))
return 8, nil
field.SetUint(uint64(b))
return nil
case reflect.Int64:
if bytesLen < 4 {
return 0, fmt.Errorf("expected len(bytes) to be at least 8 but is %d", bytesLen)
b := p.UnpackLong()
if p.Err != nil {
return p.Err
}
field.SetInt(int64(binary.BigEndian.Uint64(bytes)))
return 8, nil
field.SetInt(int64(b))
return nil
case reflect.Bool:
if bytesLen < 1 {
return 0, fmt.Errorf("expected len(bytes) to be at least 1 but is %d", bytesLen)
b := p.UnpackBool()
if p.Err != nil {
return p.Err
}
if bytes[0] == 0 {
field.SetBool(false)
} else {
field.SetBool(true)
}
return 1, nil
field.SetBool(b)
return nil
case reflect.Slice:
if bytesLen < 1 {
return 0, fmt.Errorf("expected len(bytes) to be at least 1 but is %d", bytesLen)
isNil := p.UnpackBool()
if p.Err != nil {
return p.Err
}
if bytes[0] == 1 { // isNil flag is 1 --> this slice is nil
return 1, nil
if isNil { // slice is nil
return nil
}
numElts := int(binary.BigEndian.Uint32(bytes[1:])) // number of elements in the slice
if numElts > c.maxSliceLen {
return 0, fmt.Errorf("slice length, %d, exceeds maximum, %d", numElts, c.maxSliceLen)
numElts := int(p.UnpackShort())
if p.Err != nil {
return p.Err
}
// set [field] to be a slice of the appropriate type/capacity (right now [field] is nil)
@ -525,113 +458,83 @@ func (c codec) unmarshal(bytes []byte, field reflect.Value) (int, error) {
field.Set(slice)
// Unmarshal each element into the appropriate index of the slice
bytesRead := 5 // 1 for isNil flag, 4 for numElts
for i := 0; i < numElts; i++ {
if bytesRead > bytesLen {
return 0, fmt.Errorf("attempted out of bounds slice")
if err := c.unmarshal(p, field.Index(i)); err != nil {
return err
}
n, err := c.unmarshal(bytes[bytesRead:], field.Index(i))
if err != nil {
return 0, err
}
bytesRead += n
}
return bytesRead, nil
return nil
case reflect.Array:
bytesRead := 0
for i := 0; i < field.Len(); i++ {
if bytesRead > bytesLen {
return 0, fmt.Errorf("attempted out of bounds slice")
if err := c.unmarshal(p, field.Index(i)); err != nil {
return err
}
n, err := c.unmarshal(bytes[bytesRead:], field.Index(i))
if err != nil {
return 0, err
}
bytesRead += n
}
return bytesRead, nil
return nil
case reflect.String:
if bytesLen < 2 {
return 0, fmt.Errorf("expected len(bytes) to be at least 2 but is %d", bytesLen)
str := p.UnpackStr()
if p.Err != nil {
return p.Err
}
strLen := int(binary.BigEndian.Uint16(bytes))
if bytesLen < 2+strLen {
return 0, fmt.Errorf("expected len(bytes) to be at least %d but is %d", 2+strLen, bytesLen)
}
if strLen > 0 {
field.SetString(string(bytes[2 : 2+strLen]))
} else {
field.SetString("")
}
return strLen + 2, nil
field.SetString(str)
return nil
case reflect.Interface:
if bytesLen < 4 {
return 0, fmt.Errorf("expected len(bytes) to be at least 4 but is %d", bytesLen)
typeID := p.UnpackInt() // Get the type ID
if p.Err != nil {
return p.Err
}
// Get the type ID
typeID := binary.BigEndian.Uint32(bytes)
// Get a struct that implements the interface
typ, ok := c.typeIDToType[typeID]
if !ok {
return 0, errUnmarshalUnregisteredType
return errUnmarshalUnregisteredType
}
// Ensure struct actually does implement the interface
fieldType := field.Type()
if !typ.Implements(fieldType) {
return 0, fmt.Errorf("%s does not implement interface %s", typ, fieldType)
return fmt.Errorf("%s does not implement interface %s", typ, fieldType)
}
concreteInstancePtr := reflect.New(typ) // instance of the proper type
// Unmarshal into the struct
n, err := c.unmarshal(bytes[4:], concreteInstancePtr.Elem())
if err != nil {
return 0, err
if err := c.unmarshal(p, concreteInstancePtr.Elem()); err != nil {
return err
}
// And assign the filled struct to the field
field.Set(concreteInstancePtr.Elem())
return n + 4, nil
return nil
case reflect.Struct:
// Type of this struct
structType := reflect.TypeOf(field.Interface())
// Go through all the fields and umarshal into each
bytesRead := 0
for i := 0; i < structType.NumField(); i++ {
structField := structType.Field(i)
if !shouldSerialize(structField) { // Skip fields we don't need to unmarshal
continue
}
if unicode.IsLower(rune(structField.Name[0])) { // Only unmarshal into exported field
return 0, errUnmarshalUnexportedField
return errUnmarshalUnexportedField
}
field := field.Field(i) // Get the field
if bytesRead > bytesLen {
return 0, fmt.Errorf("attempted out of bounds slice")
field := field.Field(i) // Get the field
if err := c.unmarshal(p, field); err != nil { // Unmarshal into the field
return err
}
n, err := c.unmarshal(bytes[bytesRead:], field) // Unmarshal into the field
if err != nil {
return 0, err
}
bytesRead += n
}
return bytesRead, nil
return nil
case reflect.Ptr:
// Get the type this pointer points to
underlyingType := field.Type().Elem()
// Create a new pointer to a new value of the underlying type
underlyingValue := reflect.New(underlyingType)
// Fill the value
n, err := c.unmarshal(bytes, underlyingValue.Elem())
if err != nil {
return 0, err
if err := c.unmarshal(p, underlyingValue.Elem()); err != nil {
return err
}
// Assign to the top-level struct's member
field.Set(underlyingValue)
return n, nil
return nil
case reflect.Invalid:
return 0, errNil
return errNil
default:
return 0, errUnknownType
return errUnknownType
}
}

View File

@ -474,7 +474,7 @@ func TestEmptySliceSerialization(t *testing.T) {
codec := NewDefault()
val := &simpleSliceStruct{Arr: make([]uint32, 0, 1)}
expected := []byte{0, 0, 0, 0, 0} // 0 for isNil flag, 0 for size
expected := []byte{0, 0, 0} // 0 for isNil flag, 0 for size
result, err := codec.Marshal(val)
if err != nil {
t.Fatal(err)
@ -507,7 +507,7 @@ func TestSliceWithEmptySerialization(t *testing.T) {
val := &nestedSliceStruct{
Arr: make([]emptyStruct, 1000),
}
expected := []byte{0x00, 0x00, 0x00, 0x03, 0xE8} // 0 for isNil flag, then 1000 for numElts
expected := []byte{0x00, 0x03, 0xE8} // 0 for isNil flag, then 1000 for numElts
result, err := codec.Marshal(val)
if err != nil {
t.Fatal(err)