package binary

import (
	"errors"
	"fmt"
	"io"
	"reflect"
)

// TypeInfo describes how a registered Go type is encoded and decoded.
//
// Encoder and Decoder are optional; when set they replace the
// reflection-based logic in writeReflect and readReflect respectively.
// HasTypeByte and TypeByte are filled in by RegisterType when Type
// implements the HasTypeByte interface.
type TypeInfo struct {
	Type    reflect.Type // The type
	Encoder Encoder      // Optional custom encoder function
	Decoder Decoder      // Optional custom decoder function

	HasTypeByte bool
	TypeByte    byte
}

// HasTypeByte is implemented by types that carry a one-byte type tag.
//
// If a registered type implements HasTypeByte, writeReflect emits the byte
// before the encoded value, and readReflect reads one byte and verifies that
// it matches before decoding the value. This is used to encode
// interfaces/union types. readReflect has no case for interface kinds, so
// decoding into an interface-typed value requires a custom Decoder
// registered for that interface type (typically a switch on the type byte).
// See the reactor implementations for use-cases.
type HasTypeByte interface {
	TypeByte() byte
}

// typeInfos maps both a type and its pointer (or element) counterpart to the
// TypeInfo registered for it. The map is written by RegisterType, which
// readReflect and writeReflect also call lazily; no synchronization is
// performed here.
var typeInfos = map[reflect.Type]*TypeInfo{}

// RegisterType records info under info.Type and under the matching
// pointer/element type, detects an implementation of HasTypeByte, and
// returns info.
//
// It panics if info.HasTypeByte was already set and info.TypeByte disagrees
// with the byte reported by the type's TypeByte method.
func RegisterType(info *TypeInfo) *TypeInfo {
	// Register the type info.
	typeInfos[info.Type] = info

	// Also register the underlying struct's info, if info.Type is a pointer.
	// Or, if info.Type is not a pointer, register the pointer.
	if info.Type.Kind() == reflect.Ptr {
		typeInfos[info.Type.Elem()] = info
	} else {
		typeInfos[reflect.PtrTo(info.Type)] = info
	}

	// See if the type implements HasTypeByte.
	//
	// Interface types are skipped: their zero value is a nil interface, so
	// there is no concrete value to ask for a type byte (the type assertion
	// below would panic). writeReflect resolves the concrete type's info
	// for interface values instead.
	if info.Type.Kind() != reflect.Interface &&
		info.Type.Implements(reflect.TypeOf((*HasTypeByte)(nil)).Elem()) {
		// For pointer types probe a pointer to a fresh zero value rather
		// than a nil pointer, so that value-receiver implementations of
		// TypeByte do not panic on a nil dereference.
		var probe reflect.Value
		if info.Type.Kind() == reflect.Ptr {
			probe = reflect.New(info.Type.Elem())
		} else {
			probe = reflect.Zero(info.Type)
		}
		typeByte := probe.Interface().(HasTypeByte).TypeByte()
		if info.HasTypeByte && info.TypeByte != typeByte {
			panic(fmt.Sprintf("Type %v expected TypeByte of %X", info.Type, typeByte))
		}
		info.HasTypeByte = true
		info.TypeByte = typeByte
	}

	return info
}

// readReflect decodes a value of type rt from r into rv.
//
// rv must be settable (or a non-nil pointer). n and err are passed through
// to the Read* helpers; on failure *err is set and decoding of the current
// value stops. It panics if rt's kind has no decoding rule.
func readReflect(rv reflect.Value, rt reflect.Type, r io.Reader, n *int64, err *error) {
	// First, create a new struct if rv is nil pointer.
	if rt.Kind() == reflect.Ptr && rv.IsNil() {
		newRv := reflect.New(rt.Elem())
		rv.Set(newRv)
		rv = newRv
	}

	// Dereference pointer.
	// Still addressable, thus settable!
	if rv.Kind() == reflect.Ptr {
		rv, rt = rv.Elem(), rt.Elem()
	}

	// Get typeInfo.
	typeInfo := typeInfos[rt]
	if typeInfo == nil {
		typeInfo = RegisterType(&TypeInfo{Type: rt})
	}

	// Custom decoder.
	if typeInfo.Decoder != nil {
		decoded := typeInfo.Decoder(r, n, err)
		if *err != nil {
			// The decoded value is meaningless on error (and may be nil,
			// which rv.Set would reject with a panic).
			return
		}
		rv.Set(reflect.Indirect(reflect.ValueOf(decoded)))
		return
	}

	// Read TypeByte prefix.
	if typeInfo.HasTypeByte {
		typeByte := ReadByte(r, n, err)
		if *err != nil {
			// Keep the underlying read error instead of masking it with a
			// type-byte mismatch.
			return
		}
		if typeByte != typeInfo.TypeByte {
			*err = errors.New(fmt.Sprintf("Expected TypeByte of %X but got %X", typeInfo.TypeByte, typeByte))
			return
		}
	}

	switch rt.Kind() {
	case reflect.Slice:
		elemRt := rt.Elem()
		if elemRt.Kind() == reflect.Uint8 {
			// Special case: Byteslices
			byteslice := ReadByteSlice(r, n, err)
			rv.Set(reflect.ValueOf(byteslice))
		} else {
			// Read length.
			length := int(ReadUvarint(r, n, err))
			if *err != nil {
				return
			}
			if length < 0 {
				// The unsigned length overflowed int; MakeSlice would panic.
				*err = errors.New("negative slice length")
				return
			}
			sliceRv := reflect.MakeSlice(rt, length, length)
			// Read elems.
			for i := 0; i < length; i++ {
				readReflect(sliceRv.Index(i), elemRt, r, n, err)
				if *err != nil {
					return
				}
			}
			rv.Set(sliceRv)
		}

	case reflect.Struct:
		numFields := rt.NumField()
		for i := 0; i < numFields; i++ {
			field := rt.Field(i)
			if field.PkgPath != "" {
				// Unexported field: not settable via reflection, skip.
				continue
			}
			readReflect(rv.Field(i), field.Type, r, n, err)
			if *err != nil {
				return
			}
		}

	case reflect.String:
		rv.SetString(ReadString(r, n, err))

	case reflect.Int64:
		rv.SetInt(int64(ReadUint64(r, n, err)))

	case reflect.Int32:
		rv.SetInt(int64(ReadUint32(r, n, err)))

	case reflect.Int16:
		rv.SetInt(int64(ReadUint16(r, n, err)))

	case reflect.Int8:
		rv.SetInt(int64(ReadUint8(r, n, err)))

	case reflect.Int:
		// NOTE(review): writeReflect encodes Int with WriteVarint while this
		// reads with ReadUvarint; confirm the two wire formats agree (in
		// particular for negative values) or switch to the signed reader.
		rv.SetInt(int64(ReadUvarint(r, n, err)))

	case reflect.Uint64:
		rv.SetUint(uint64(ReadUint64(r, n, err)))

	case reflect.Uint32:
		rv.SetUint(uint64(ReadUint32(r, n, err)))

	case reflect.Uint16:
		rv.SetUint(uint64(ReadUint16(r, n, err)))

	case reflect.Uint8:
		rv.SetUint(uint64(ReadUint8(r, n, err)))

	case reflect.Uint:
		rv.SetUint(uint64(ReadUvarint(r, n, err)))

	default:
		panic(fmt.Sprintf("Unknown field type %v", rt.Kind()))
	}
}

// writeReflect encodes rv, whose static type is rt, to w.
//
// n and err are passed through to the Write* helpers. A nil pointer or nil
// interface without a custom Encoder cannot be encoded and is reported via
// *err. It panics if the (dereferenced) kind has no encoding rule.
func writeReflect(rv reflect.Value, rt reflect.Type, w io.Writer, n *int64, err *error) {
	// Get typeInfo.
	typeInfo := typeInfos[rt]
	if typeInfo == nil {
		typeInfo = RegisterType(&TypeInfo{Type: rt})
	}

	// Custom encoder, say for an interface type rt.
	if typeInfo.Encoder != nil {
		typeInfo.Encoder(rv.Interface(), w, n, err)
		return
	}

	// Dereference pointer or interface.
	switch rt.Kind() {
	case reflect.Ptr:
		if rv.IsNil() {
			// Dereferencing would yield an invalid Value and panic below.
			if *err == nil {
				*err = errors.New("cannot encode nil pointer")
			}
			return
		}
		rt = rt.Elem()
		rv = rv.Elem()
		// RegisterType registers the ptr type,
		// so typeInfo is already for the ptr.
	case reflect.Interface:
		if rv.IsNil() {
			// There is no concrete type to look up or encode.
			if *err == nil {
				*err = errors.New("cannot encode nil interface")
			}
			return
		}
		rv = rv.Elem()
		rt = rv.Type()
		// If interface type, get typeInfo of underlying type.
		typeInfo = typeInfos[rt]
		if typeInfo == nil {
			typeInfo = RegisterType(&TypeInfo{Type: rt})
		}
	}

	// Write TypeByte prefix.
	if typeInfo.HasTypeByte {
		WriteByte(typeInfo.TypeByte, w, n, err)
	}

	switch rt.Kind() {
	case reflect.Slice:
		elemRt := rt.Elem()
		if elemRt.Kind() == reflect.Uint8 {
			// Special case: Byteslices
			byteslice := rv.Interface().([]byte)
			WriteByteSlice(byteslice, w, n, err)
		} else {
			// Write length.
			length := rv.Len()
			WriteUvarint(uint(length), w, n, err)
			// Write elems.
			for i := 0; i < length; i++ {
				writeReflect(rv.Index(i), elemRt, w, n, err)
			}
		}

	case reflect.Struct:
		numFields := rt.NumField()
		for i := 0; i < numFields; i++ {
			field := rt.Field(i)
			if field.PkgPath != "" {
				// Unexported field: skipped, mirroring readReflect.
				continue
			}
			writeReflect(rv.Field(i), field.Type, w, n, err)
		}

	case reflect.String:
		WriteString(rv.String(), w, n, err)

	case reflect.Int64:
		WriteInt64(rv.Int(), w, n, err)

	case reflect.Int32:
		WriteInt32(int32(rv.Int()), w, n, err)

	case reflect.Int16:
		WriteInt16(int16(rv.Int()), w, n, err)

	case reflect.Int8:
		WriteInt8(int8(rv.Int()), w, n, err)

	case reflect.Int:
		WriteVarint(int(rv.Int()), w, n, err)

	case reflect.Uint64:
		WriteUint64(rv.Uint(), w, n, err)

	case reflect.Uint32:
		WriteUint32(uint32(rv.Uint()), w, n, err)

	case reflect.Uint16:
		WriteUint16(uint16(rv.Uint()), w, n, err)

	case reflect.Uint8:
		WriteUint8(uint8(rv.Uint()), w, n, err)

	case reflect.Uint:
		WriteUvarint(uint(rv.Uint()), w, n, err)

	default:
		panic(fmt.Sprintf("Unknown field type %v", rt.Kind()))
	}
}