cache serializable fields of struct types; change codec methods to be on pointer type; change variable names; change benchmark toinclude both marshaling and unmarshaling

This commit is contained in:
Dan Laine 2020-06-14 10:56:43 -04:00
parent 954074abcc
commit f6cabee51b
2 changed files with 82 additions and 65 deletions

View File

@ -41,6 +41,8 @@ type codec struct {
typeIDToType map[uint32]reflect.Type typeIDToType map[uint32]reflect.Type
typeToTypeID map[reflect.Type]uint32 typeToTypeID map[reflect.Type]uint32
serializedFields map[reflect.Type][]int
} }
// Codec marshals and unmarshals // Codec marshals and unmarshals
@ -52,7 +54,7 @@ type Codec interface {
// New returns a new codec // New returns a new codec
func New(maxSize, maxSliceLen int) Codec { func New(maxSize, maxSliceLen int) Codec {
return codec{ return &codec{
maxSize: maxSize, maxSize: maxSize,
maxSliceLen: maxSliceLen, maxSliceLen: maxSliceLen,
typeIDToType: map[uint32]reflect.Type{}, typeIDToType: map[uint32]reflect.Type{},
@ -65,7 +67,7 @@ func NewDefault() Codec { return New(defaultMaxSize, defaultMaxSliceLength) }
// RegisterType is used to register types that may be unmarshaled into an interface // RegisterType is used to register types that may be unmarshaled into an interface
// [val] is a value of the type being registered // [val] is a value of the type being registered
func (c codec) RegisterType(val interface{}) error { func (c *codec) RegisterType(val interface{}) error {
valType := reflect.TypeOf(val) valType := reflect.TypeOf(val)
if _, exists := c.typeToTypeID[valType]; exists { if _, exists := c.typeToTypeID[valType]; exists {
return fmt.Errorf("type %v has already been registered", valType) return fmt.Errorf("type %v has already been registered", valType)
@ -89,7 +91,7 @@ func (c codec) RegisterType(val interface{}) error {
// 8) nil slices are marshaled as empty slices // 8) nil slices are marshaled as empty slices
// To marshal an interface, [value] must be a pointer to the interface // To marshal an interface, [value] must be a pointer to the interface
func (c codec) Marshal(value interface{}) ([]byte, error) { func (c *codec) Marshal(value interface{}) ([]byte, error) {
if value == nil { if value == nil {
return nil, errNil return nil, errNil
} }
@ -118,7 +120,7 @@ func (c codec) Marshal(value interface{}) ([]byte, error) {
// and returns the number of bytes it wrote. // and returns the number of bytes it wrote.
// When these functions are called in order, they write [value] to a byte slice. // When these functions are called in order, they write [value] to a byte slice.
// 3) An error // 3) An error
func (c codec) marshal(value reflect.Value, index int, funcs *[]func(*wrappers.Packer) error) (size int, funcsWritten int, err error) { func (c *codec) marshal(value reflect.Value, index int, funcs *[]func(*wrappers.Packer) error) (size int, funcsWritten int, err error) {
valueKind := value.Kind() valueKind := value.Kind()
// Case: Value can't be marshalled // Case: Value can't be marshalled
@ -238,10 +240,7 @@ func (c codec) marshal(value reflect.Value, index int, funcs *[]func(*wrappers.P
size = 4 + subsize // 4 because we pack the type ID, a uint32 size = 4 + subsize // 4 because we pack the type ID, a uint32
(*funcs)[index] = func(p *wrappers.Packer) error { (*funcs)[index] = func(p *wrappers.Packer) error {
p.PackInt(typeID) p.PackInt(typeID)
if p.Err != nil { return p.Err
return p.Err
}
return nil
} }
funcsWritten = 1 + subFuncsWritten funcsWritten = 1 + subFuncsWritten
return return
@ -265,10 +264,7 @@ func (c codec) marshal(value reflect.Value, index int, funcs *[]func(*wrappers.P
numEltsAsUint32 := uint32(numElts) numEltsAsUint32 := uint32(numElts)
(*funcs)[index] = func(p *wrappers.Packer) error { (*funcs)[index] = func(p *wrappers.Packer) error {
p.PackInt(numEltsAsUint32) // pack # elements p.PackInt(numEltsAsUint32) // pack # elements
if p.Err != nil { return p.Err
return p.Err
}
return nil
} }
funcsWritten = subFuncsWritten + 1 funcsWritten = subFuncsWritten + 1
return return
@ -291,20 +287,17 @@ func (c codec) marshal(value reflect.Value, index int, funcs *[]func(*wrappers.P
return return
case reflect.Struct: case reflect.Struct:
t := value.Type() t := value.Type()
numFields := t.NumField()
size = 0 size = 0
fieldsMarshalled := 0 fieldsMarshalled := 0
funcsWritten = 0 funcsWritten = 0
for i := 0; i < numFields; i++ { // Go through all fields of this struct serializedFields, subErr := c.getSerializedFieldIndices(t)
field := t.Field(i) if subErr != nil {
if !shouldSerialize(field) { // Skip fields we don't need to serialize return 0, 0, subErr
continue }
}
if unicode.IsLower(rune(field.Name[0])) { // Can only marshal exported fields for _, f := range serializedFields { // Go through all fields of this struct
return 0, 0, fmt.Errorf("can't marshal unexported field %s", field.Name) fieldVal := value.Field(f) // The field we're serializing
}
fieldVal := value.Field(i) // The field we're serializing
subSize, n, err := c.marshal(fieldVal, index+funcsWritten, funcs) // Serialize the field subSize, n, err := c.marshal(fieldVal, index+funcsWritten, funcs) // Serialize the field
if err != nil { if err != nil {
return 0, 0, err return 0, 0, err
@ -321,7 +314,7 @@ func (c codec) marshal(value reflect.Value, index int, funcs *[]func(*wrappers.P
// Unmarshal unmarshals [bytes] into [dest], where // Unmarshal unmarshals [bytes] into [dest], where
// [dest] must be a pointer or interface // [dest] must be a pointer or interface
func (c codec) Unmarshal(bytes []byte, dest interface{}) error { func (c *codec) Unmarshal(bytes []byte, dest interface{}) error {
switch { switch {
case len(bytes) > c.maxSize: case len(bytes) > c.maxSize:
return errSliceTooLarge return errSliceTooLarge
@ -343,92 +336,90 @@ func (c codec) Unmarshal(bytes []byte, dest interface{}) error {
return nil return nil
} }
// Unmarshal bytes from [bytes] into [field] // Unmarshal from [bytes] into [value]. [value] must be addressable
// [field] must be addressable func (c *codec) unmarshal(p *wrappers.Packer, value reflect.Value) error {
func (c codec) unmarshal(p *wrappers.Packer, field reflect.Value) error { switch value.Kind() {
kind := field.Kind()
switch kind {
case reflect.Uint8: case reflect.Uint8:
b := p.UnpackByte() b := p.UnpackByte()
if p.Err != nil { if p.Err != nil {
return p.Err return p.Err
} }
field.SetUint(uint64(b)) value.SetUint(uint64(b))
return nil return nil
case reflect.Int8: case reflect.Int8:
b := p.UnpackByte() b := p.UnpackByte()
if p.Err != nil { if p.Err != nil {
return p.Err return p.Err
} }
field.SetInt(int64(b)) value.SetInt(int64(b))
return nil return nil
case reflect.Uint16: case reflect.Uint16:
b := p.UnpackShort() b := p.UnpackShort()
if p.Err != nil { if p.Err != nil {
return p.Err return p.Err
} }
field.SetUint(uint64(b)) value.SetUint(uint64(b))
return nil return nil
case reflect.Int16: case reflect.Int16:
b := p.UnpackShort() b := p.UnpackShort()
if p.Err != nil { if p.Err != nil {
return p.Err return p.Err
} }
field.SetInt(int64(b)) value.SetInt(int64(b))
return nil return nil
case reflect.Uint32: case reflect.Uint32:
b := p.UnpackInt() b := p.UnpackInt()
if p.Err != nil { if p.Err != nil {
return p.Err return p.Err
} }
field.SetUint(uint64(b)) value.SetUint(uint64(b))
return nil return nil
case reflect.Int32: case reflect.Int32:
b := p.UnpackInt() b := p.UnpackInt()
if p.Err != nil { if p.Err != nil {
return p.Err return p.Err
} }
field.SetInt(int64(b)) value.SetInt(int64(b))
return nil return nil
case reflect.Uint64: case reflect.Uint64:
b := p.UnpackLong() b := p.UnpackLong()
if p.Err != nil { if p.Err != nil {
return p.Err return p.Err
} }
field.SetUint(uint64(b)) value.SetUint(uint64(b))
return nil return nil
case reflect.Int64: case reflect.Int64:
b := p.UnpackLong() b := p.UnpackLong()
if p.Err != nil { if p.Err != nil {
return p.Err return p.Err
} }
field.SetInt(int64(b)) value.SetInt(int64(b))
return nil return nil
case reflect.Bool: case reflect.Bool:
b := p.UnpackBool() b := p.UnpackBool()
if p.Err != nil { if p.Err != nil {
return p.Err return p.Err
} }
field.SetBool(b) value.SetBool(b)
return nil return nil
case reflect.Slice: case reflect.Slice:
numElts := int(p.UnpackInt()) numElts := int(p.UnpackInt())
if p.Err != nil { if p.Err != nil {
return p.Err return p.Err
} }
// set [field] to be a slice of the appropriate type/capacity (right now [field] is nil) // set [value] to be a slice of the appropriate type/capacity (right now [value] is nil)
slice := reflect.MakeSlice(field.Type(), numElts, numElts) slice := reflect.MakeSlice(value.Type(), numElts, numElts)
field.Set(slice) value.Set(slice)
// Unmarshal each element into the appropriate index of the slice // Unmarshal each element into the appropriate index of the slice
for i := 0; i < numElts; i++ { for i := 0; i < numElts; i++ {
if err := c.unmarshal(p, field.Index(i)); err != nil { if err := c.unmarshal(p, value.Index(i)); err != nil {
return err return err
} }
} }
return nil return nil
case reflect.Array: case reflect.Array:
for i := 0; i < field.Len(); i++ { for i := 0; i < value.Len(); i++ {
if err := c.unmarshal(p, field.Index(i)); err != nil { if err := c.unmarshal(p, value.Index(i)); err != nil {
return err return err
} }
} }
@ -438,7 +429,7 @@ func (c codec) unmarshal(p *wrappers.Packer, field reflect.Value) error {
if p.Err != nil { if p.Err != nil {
return p.Err return p.Err
} }
field.SetString(str) value.SetString(str)
return nil return nil
case reflect.Interface: case reflect.Interface:
typeID := p.UnpackInt() // Get the type ID typeID := p.UnpackInt() // Get the type ID
@ -451,31 +442,28 @@ func (c codec) unmarshal(p *wrappers.Packer, field reflect.Value) error {
return errUnmarshalUnregisteredType return errUnmarshalUnregisteredType
} }
// Ensure struct actually does implement the interface // Ensure struct actually does implement the interface
fieldType := field.Type() valueType := value.Type()
if !typ.Implements(fieldType) { if !typ.Implements(valueType) {
return fmt.Errorf("%s does not implement interface %s", typ, fieldType) return fmt.Errorf("%s does not implement interface %s", typ, valueType)
} }
concreteInstancePtr := reflect.New(typ) // instance of the proper type concreteInstancePtr := reflect.New(typ) // instance of the proper type
// Unmarshal into the struct // Unmarshal into the struct
if err := c.unmarshal(p, concreteInstancePtr.Elem()); err != nil { if err := c.unmarshal(p, concreteInstancePtr.Elem()); err != nil {
return err return err
} }
// And assign the filled struct to the field // And assign the filled struct to the value
field.Set(concreteInstancePtr.Elem()) value.Set(concreteInstancePtr.Elem())
return nil return nil
case reflect.Struct: case reflect.Struct:
// Type of this struct // Type of this struct
structType := reflect.TypeOf(field.Interface()) t := reflect.TypeOf(value.Interface())
serializedFieldIndices, err := c.getSerializedFieldIndices(t)
if err != nil {
return err
}
// Go through all the fields and umarshal into each // Go through all the fields and umarshal into each
for i := 0; i < structType.NumField(); i++ { for _, index := range serializedFieldIndices {
structField := structType.Field(i) field := value.Field(index) // Get the field
if !shouldSerialize(structField) { // Skip fields we don't need to unmarshal
continue
}
if unicode.IsLower(rune(structField.Name[0])) { // Only unmarshal into exported field
return errUnmarshalUnexportedField
}
field := field.Field(i) // Get the field
if err := c.unmarshal(p, field); err != nil { // Unmarshal into the field if err := c.unmarshal(p, field); err != nil { // Unmarshal into the field
return err return err
} }
@ -483,7 +471,7 @@ func (c codec) unmarshal(p *wrappers.Packer, field reflect.Value) error {
return nil return nil
case reflect.Ptr: case reflect.Ptr:
// Get the type this pointer points to // Get the type this pointer points to
underlyingType := field.Type().Elem() underlyingType := value.Type().Elem()
// Create a new pointer to a new value of the underlying type // Create a new pointer to a new value of the underlying type
underlyingValue := reflect.New(underlyingType) underlyingValue := reflect.New(underlyingType)
// Fill the value // Fill the value
@ -491,7 +479,7 @@ func (c codec) unmarshal(p *wrappers.Packer, field reflect.Value) error {
return err return err
} }
// Assign to the top-level struct's member // Assign to the top-level struct's member
field.Set(underlyingValue) value.Set(underlyingValue)
return nil return nil
case reflect.Invalid: case reflect.Invalid:
return errNil return errNil
@ -500,7 +488,27 @@ func (c codec) unmarshal(p *wrappers.Packer, field reflect.Value) error {
} }
} }
// Returns true iff [field] should be serialized // Returns the indices of the serializable fields of [t], which is a struct type
func shouldSerialize(field reflect.StructField) bool { // Returns an error if a field has tag "serialize: true" but the field is unexported
return field.Tag.Get("serialize") == "true" func (c *codec) getSerializedFieldIndices(t reflect.Type) ([]int, error) {
if c.serializedFields == nil {
c.serializedFields = make(map[reflect.Type][]int)
}
if serializedFields, ok := c.serializedFields[t]; ok {
return serializedFields, nil
}
numFields := t.NumField()
serializedFields := make([]int, 0, numFields)
for i := 0; i < numFields; i++ { // Go through all fields of this struct
field := t.Field(i)
if field.Tag.Get("serialize") != "true" { // Skip fields we don't need to serialize
continue
}
if unicode.IsLower(rune(field.Name[0])) { // Can only marshal exported fields
return []int{}, fmt.Errorf("can't marshal unexported field %s", field.Name)
}
serializedFields = append(serializedFields, i)
}
c.serializedFields[t] = serializedFields
return serializedFields, nil
} }

View File

@ -35,13 +35,22 @@ func BenchmarkMarshal(b *testing.B) {
}, },
MyPointer: &temp, MyPointer: &temp,
} }
var unmarshaledMyStructInstance myStruct
codec := NewDefault() codec := NewDefault()
codec.RegisterType(&MyInnerStruct{}) // Register the types that may be unmarshaled into interfaces codec.RegisterType(&MyInnerStruct{}) // Register the types that may be unmarshaled into interfaces
codec.RegisterType(&MyInnerStruct2{}) codec.RegisterType(&MyInnerStruct2{})
codec.Marshal(myStructInstance) // warm up serializedFields cache
b.ResetTimer() b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
codec.Marshal(myStructInstance) bytes, err := codec.Marshal(myStructInstance)
if err != nil {
b.Fatal(err)
}
if err := codec.Unmarshal(bytes, &unmarshaledMyStructInstance); err != nil {
b.Fatal(err)
}
} }
} }