diff --git a/vms/components/codec/codec.go b/vms/components/codec/codec.go index 29cfaef..5000a6a 100644 --- a/vms/components/codec/codec.go +++ b/vms/components/codec/codec.go @@ -42,7 +42,7 @@ type codec struct { typeIDToType map[uint32]reflect.Type typeToTypeID map[reflect.Type]uint32 - serializedFields map[reflect.Type][]int + serializedFieldIndices map[reflect.Type][]int } // Codec marshals and unmarshals @@ -55,10 +55,11 @@ type Codec interface { // New returns a new codec func New(maxSize, maxSliceLen int) Codec { return &codec{ - maxSize: maxSize, - maxSliceLen: maxSliceLen, - typeIDToType: map[uint32]reflect.Type{}, - typeToTypeID: map[reflect.Type]uint32{}, + maxSize: maxSize, + maxSliceLen: maxSliceLen, + typeIDToType: map[uint32]reflect.Type{}, + typeToTypeID: map[reflect.Type]uint32{}, + serializedFieldIndices: map[reflect.Type][]int{}, } } @@ -226,13 +227,14 @@ func (c *codec) marshal(value reflect.Value, index int, funcs *[]func(*wrappers. case reflect.Uintptr, reflect.Ptr: return c.marshal(value.Elem(), index, funcs) case reflect.Interface: - typeID, ok := c.typeToTypeID[reflect.TypeOf(value.Interface())] // Get the type ID of the value being marshaled + underlyingValue := value.Interface() + typeID, ok := c.typeToTypeID[reflect.TypeOf(underlyingValue)] // Get the type ID of the value being marshaled if !ok { - return 0, 0, fmt.Errorf("can't marshal unregistered type '%v'", reflect.TypeOf(value.Interface()).String()) + return 0, 0, fmt.Errorf("can't marshal unregistered type '%v'", reflect.TypeOf(underlyingValue).String()) } (*funcs)[index] = nil - subsize, subFuncsWritten, subErr := c.marshal(reflect.ValueOf(value.Interface()), index+1, funcs) + subsize, subFuncsWritten, subErr := c.marshal(value.Elem(), index+1, funcs) if subErr != nil { return 0, 0, subErr } @@ -408,8 +410,7 @@ func (c *codec) unmarshal(p *wrappers.Packer, value reflect.Value) error { return p.Err } // set [value] to be a slice of the appropriate type/capacity (right now [value] is nil) - slice := reflect.MakeSlice(value.Type(), numElts, numElts) - value.Set(slice) + value.Set(reflect.MakeSlice(value.Type(), numElts, numElts)) // Unmarshal each element into the appropriate index of the slice for i := 0; i < numElts; i++ { if err := c.unmarshal(p, value.Index(i)); err != nil { @@ -425,12 +426,8 @@ func (c *codec) unmarshal(p *wrappers.Packer, value reflect.Value) error { } return nil case reflect.String: - str := p.UnpackStr() - if p.Err != nil { - return p.Err - } - value.SetString(str) - return nil + value.SetString(p.UnpackStr()) + return p.Err case reflect.Interface: typeID := p.UnpackInt() // Get the type ID if p.Err != nil { @@ -456,15 +453,14 @@ func (c *codec) unmarshal(p *wrappers.Packer, value reflect.Value) error { return nil case reflect.Struct: // Type of this struct - t := reflect.TypeOf(value.Interface()) + t := value.Type() serializedFieldIndices, err := c.getSerializedFieldIndices(t) if err != nil { return err } // Go through all the fields and umarshal into each for _, index := range serializedFieldIndices { - field := value.Field(index) // Get the field - if err := c.unmarshal(p, field); err != nil { // Unmarshal into the field + if err := c.unmarshal(p, value.Field(index)); err != nil { // Unmarshal into the field return err } } @@ -491,10 +487,10 @@ func (c *codec) unmarshal(p *wrappers.Packer, value reflect.Value) error { // Returns the indices of the serializable fields of [t], which is a struct type // Returns an error if a field has tag "serialize: true" but the field is unexported func (c *codec) getSerializedFieldIndices(t reflect.Type) ([]int, error) { - if c.serializedFields == nil { - c.serializedFields = make(map[reflect.Type][]int) + if c.serializedFieldIndices == nil { + c.serializedFieldIndices = make(map[reflect.Type][]int) } - if serializedFields, ok := c.serializedFields[t]; ok { + if serializedFields, ok := c.serializedFieldIndices[t]; ok { return serializedFields, nil } numFields := t.NumField() @@ -509,6 +505,6 @@ func (c *codec) getSerializedFieldIndices(t reflect.Type) ([]int, error) { } serializedFields = append(serializedFields, i) } - c.serializedFields[t] = serializedFields + c.serializedFieldIndices[t] = serializedFields return serializedFields, nil }