diff --git a/rlp/decode.go b/rlp/decode.go new file mode 100644 index 000000000..96d912f56 --- /dev/null +++ b/rlp/decode.go @@ -0,0 +1,667 @@ +package rlp + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "math/big" + "reflect" +) + +var ( + errNoPointer = errors.New("rlp: interface given to Decode must be a pointer") + errDecodeIntoNil = errors.New("rlp: pointer given to Decode must not be nil") +) + +// Decoder is implemented by types that require custom RLP +// decoding rules or need to decode into private fields. +// +// The DecodeRLP method should read one value from the given +// Stream. It is not forbidden to read less or more, but it might +// be confusing. +type Decoder interface { + DecodeRLP(*Stream) error +} + +// Decode parses RLP-encoded data from r and stores the result +// in the value pointed to by val. Val must be a non-nil pointer. +// +// Decode uses the following type-dependent decoding rules: +// +// If the type implements the Decoder interface, decode calls +// DecodeRLP. +// +// To decode into a pointer, Decode will set the pointer to nil if the +// input has size zero or the input is a single byte with value zero. +// If the input has nonzero size, Decode will allocate a new value of +// the type being pointed to. +// +// To decode into a struct, Decode expects the input to be an RLP +// list. The decoded elements of the list are assigned to each public +// field in the order given by the struct's definition. If the input +// list has too few elements, no error is returned and the remaining +// fields will have the zero value. +// Recursive struct types are supported. +// +// To decode into a slice, the input must be a list and the resulting +// slice will contain the input elements in order. +// As a special case, if the slice has a byte-size element type, the input +// can also be an RLP string. +// +// To decode into a Go string, the input must be an RLP string. The +// bytes are taken as-is and will not necessarily be valid UTF-8. +// +// To decode into an integer type, the input must also be an RLP +// string. The bytes are interpreted as a big endian representation of +// the integer. If the RLP string is larger than the bit size of the +// type, Decode will return an error. Decode also supports *big.Int. +// There is no size limit for big integers. +// +// To decode into an interface value, Decode stores one of these +// in the value: +// +// []interface{}, for RLP lists +// []byte, for RLP strings +// +// Non-empty interface types are not supported, nor are bool, float32, +// float64, maps, channel types and functions. +func Decode(r ByteReader, val interface{}) error { + return NewStream(r).Decode(val) +} + +func makeNumDecoder(typ reflect.Type) decoder { + kind := typ.Kind() + switch { + case kind <= reflect.Int64: + return decodeInt + case kind <= reflect.Uint64: + return decodeUint + default: + panic("fallthrough") + } +} + +func decodeInt(s *Stream, val reflect.Value) error { + num, err := s.uint(val.Type().Bits()) + if err != nil { + return err + } + val.SetInt(int64(num)) + return nil +} + +func decodeUint(s *Stream, val reflect.Value) error { + num, err := s.uint(val.Type().Bits()) + if err != nil { + return err + } + val.SetUint(num) + return nil +} + +func decodeString(s *Stream, val reflect.Value) error { + b, err := s.Bytes() + if err != nil { + return err + } + val.SetString(string(b)) + return nil +} + +func decodeBigIntNoPtr(s *Stream, val reflect.Value) error { + return decodeBigInt(s, val.Addr()) +} + +func decodeBigInt(s *Stream, val reflect.Value) error { + b, err := s.Bytes() + if err != nil { + return err + } + i := val.Interface().(*big.Int) + if i == nil { + i = new(big.Int) + val.Set(reflect.ValueOf(i)) + } + i.SetBytes(b) + return nil +} + +const maxInt = int(^uint(0) >> 1) + +func makeListDecoder(typ reflect.Type) (decoder, error) { + etype := typ.Elem() + if etype.Kind() == reflect.Uint8 && !reflect.PtrTo(etype).Implements(decoderInterface) { + if typ.Kind() == reflect.Array { + return decodeByteArray, nil + } else { + return decodeByteSlice, nil + } + } + etypeinfo, err := cachedTypeInfo1(etype) + if err != nil { + return nil, err + } + var maxLen = maxInt + if typ.Kind() == reflect.Array { + maxLen = typ.Len() + } + dec := func(s *Stream, val reflect.Value) error { + return decodeList(s, val, etypeinfo.decoder, maxLen) + } + return dec, nil +} + +// decodeList decodes RLP list elements into slices and arrays. +// +// The approach here is stolen from package json, although we differ +// in the semantics for arrays. package json discards remaining +// elements that would not fit into the array. We generate an error in +// this case because we'd be losing information. +func decodeList(s *Stream, val reflect.Value, elemdec decoder, maxelem int) error { + size, err := s.List() + if err != nil { + return err + } + if size == 0 { + if val.Kind() == reflect.Slice { + val.Set(reflect.MakeSlice(val.Type(), 0, 0)) + } else { + zero(val, 0) + } + return s.ListEnd() + } + + i := 0 + for { + if i > maxelem { + return fmt.Errorf("rlp: input List has more than %d elements", maxelem) + } + if val.Kind() == reflect.Slice { + // grow slice if necessary + if i >= val.Cap() { + newcap := val.Cap() + val.Cap()/2 + if newcap < 4 { + newcap = 4 + } + newv := reflect.MakeSlice(val.Type(), val.Len(), newcap) + reflect.Copy(newv, val) + val.Set(newv) + } + if i >= val.Len() { + val.SetLen(i + 1) + } + } + // decode into element + if err := elemdec(s, val.Index(i)); err == EOL { + break + } else if err != nil { + return err + } + i++ + } + if i < val.Len() { + if val.Kind() == reflect.Array { + // zero the rest of the array. + zero(val, i) + } else { + val.SetLen(i) + } + } + return s.ListEnd() +} + +func decodeByteSlice(s *Stream, val reflect.Value) error { + kind, _, err := s.Kind() + if err != nil { + return err + } + if kind == List { + return decodeList(s, val, decodeUint, maxInt) + } + b, err := s.Bytes() + if err == nil { + val.SetBytes(b) + } + return err +} + +var errStringDoesntFitArray = errors.New("rlp: string value doesn't fit into target array") + +func decodeByteArray(s *Stream, val reflect.Value) error { + kind, size, err := s.Kind() + if err != nil { + return err + } + switch kind { + case Byte: + if val.Len() == 0 { + return errStringDoesntFitArray + } + bv, _ := s.Uint() + val.Index(0).SetUint(bv) + zero(val, 1) + case String: + if uint64(val.Len()) < size { + return errStringDoesntFitArray + } + slice := val.Slice(0, int(size)).Interface().([]byte) + if err := s.readFull(slice); err != nil { + return err + } + zero(val, int(size)) + case List: + return decodeList(s, val, decodeUint, val.Len()) + } + return nil +} + +func zero(val reflect.Value, start int) { + z := reflect.Zero(val.Type().Elem()) + for i := start; i < val.Len(); i++ { + val.Index(i).Set(z) + } +} + +type field struct { + index int + info *typeinfo +} + +func makeStructDecoder(typ reflect.Type) (decoder, error) { + var fields []field + for i := 0; i < typ.NumField(); i++ { + if f := typ.Field(i); f.PkgPath == "" { // exported + info, err := cachedTypeInfo1(f.Type) + if err != nil { + return nil, err + } + fields = append(fields, field{i, info}) + } + } + dec := func(s *Stream, val reflect.Value) (err error) { + if _, err = s.List(); err != nil { + return err + } + for _, f := range fields { + err = f.info.decoder(s, val.Field(f.index)) + if err == EOL { + // too few elements. leave the rest at their zero value. + break + } else if err != nil { + return err + } + } + if err = s.ListEnd(); err == errNotAtEOL { + err = errors.New("rlp: input List has too many elements") + } + return err + } + return dec, nil +} + +func makePtrDecoder(typ reflect.Type) (decoder, error) { + etype := typ.Elem() + etypeinfo, err := cachedTypeInfo1(etype) + if err != nil { + return nil, err + } + dec := func(s *Stream, val reflect.Value) (err error) { + _, size, err := s.Kind() + if err != nil || size == 0 && s.byteval == 0 { + val.Set(reflect.Zero(typ)) // set to nil + return err + } + newval := val + if val.IsNil() { + newval = reflect.New(etype) + } + if err = etypeinfo.decoder(s, newval.Elem()); err == nil { + val.Set(newval) + } + return err + } + return dec, nil +} + +var ifsliceType = reflect.TypeOf([]interface{}{}) + +func decodeInterface(s *Stream, val reflect.Value) error { + kind, _, err := s.Kind() + if err != nil { + return err + } + if kind == List { + slice := reflect.New(ifsliceType).Elem() + if err := decodeList(s, slice, decodeInterface, maxInt); err != nil { + return err + } + val.Set(slice) + } else { + b, err := s.Bytes() + if err != nil { + return err + } + val.Set(reflect.ValueOf(b)) + } + return nil +} + +// This decoder is used for non-pointer values of types +// that implement the Decoder interface using a pointer receiver. +func decodeDecoderNoPtr(s *Stream, val reflect.Value) error { + return val.Addr().Interface().(Decoder).DecodeRLP(s) +} + +func decodeDecoder(s *Stream, val reflect.Value) error { + // Decoder instances are not handled using the pointer rule if the type + // implements Decoder with pointer receiver (i.e. always) + // because it might handle empty values specially. + // We need to allocate one here in this case, like makePtrDecoder does. + if val.Kind() == reflect.Ptr && val.IsNil() { + val.Set(reflect.New(val.Type().Elem())) + } + return val.Interface().(Decoder).DecodeRLP(s) +} + +// Kind represents the kind of value contained in an RLP stream. +type Kind int + +const ( + Byte Kind = iota + String + List +) + +func (k Kind) String() string { + switch k { + case Byte: + return "Byte" + case String: + return "String" + case List: + return "List" + default: + return fmt.Sprintf("Unknown(%d)", k) + } +} + +var ( + // EOL is returned when the end of the current list + // has been reached during streaming. + EOL = errors.New("rlp: end of list") + + // Other errors + ErrExpectedString = errors.New("rlp: expected String or Byte") + ErrExpectedList = errors.New("rlp: expected List") + ErrElemTooLarge = errors.New("rlp: element is larger than containing list") + + // internal errors + errNotInList = errors.New("rlp: call of ListEnd outside of any list") + errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL") +) + +// ByteReader must be implemented by any input reader for a Stream. It +// is implemented by e.g. bufio.Reader and bytes.Reader. +type ByteReader interface { + io.Reader + io.ByteReader +} + +// Stream can be used for piecemeal decoding of an input stream. This +// is useful if the input is very large or if the decoding rules for a +// type depend on the input structure. Stream does not keep an +// internal buffer. After decoding a value, the input reader will be +// positioned just before the type information for the next value. +// +// When decoding a list and the input position reaches the declared +// length of the list, all operations will return error EOL. +// The end of the list must be acknowledged using ListEnd to continue +// reading the enclosing list. +// +// Stream is not safe for concurrent use. +type Stream struct { + r ByteReader + uintbuf []byte + + kind Kind // kind of value ahead + size uint64 // size of value ahead + byteval byte // value of single byte in type tag + stack []listpos +} + +type listpos struct{ pos, size uint64 } + +func NewStream(r ByteReader) *Stream { + return &Stream{r: r, uintbuf: make([]byte, 8), kind: -1} +} + +// Bytes reads an RLP string and returns its contents as a byte slice. +// If the input does not contain an RLP string, the returned +// error will be ErrExpectedString. +func (s *Stream) Bytes() ([]byte, error) { + kind, size, err := s.Kind() + if err != nil { + return nil, err + } + switch kind { + case Byte: + s.kind = -1 // rearm Kind + return []byte{s.byteval}, nil + case String: + b := make([]byte, size) + if err = s.readFull(b); err != nil { + return nil, err + } + return b, nil + default: + return nil, ErrExpectedString + } +} + +// Uint reads an RLP string of up to 8 bytes and returns its contents +// as an unsigned integer. If the input does not contain an RLP string, the +// returned error will be ErrExpectedString. +func (s *Stream) Uint() (uint64, error) { + return s.uint(64) +} + +func (s *Stream) uint(maxbits int) (uint64, error) { + kind, size, err := s.Kind() + if err != nil { + return 0, err + } + switch kind { + case Byte: + s.kind = -1 // rearm Kind + return uint64(s.byteval), nil + case String: + if size > uint64(maxbits/8) { + return 0, fmt.Errorf("rlp: string is larger than %d bits", maxbits) + } + return s.readUint(byte(size)) + default: + return 0, ErrExpectedString + } +} + +// List starts decoding an RLP list. If the input does not contain a +// list, the returned error will be ErrExpectedList. When the list's +// end has been reached, any Stream operation will return EOL. +func (s *Stream) List() (size uint64, err error) { + kind, size, err := s.Kind() + if err != nil { + return 0, err + } + if kind != List { + return 0, ErrExpectedList + } + s.stack = append(s.stack, listpos{0, size}) + s.kind = -1 + s.size = 0 + return size, nil +} + +// ListEnd returns to the enclosing list. +// The input reader must be positioned at the end of a list. +func (s *Stream) ListEnd() error { + if len(s.stack) == 0 { + return errNotInList + } + tos := s.stack[len(s.stack)-1] + if tos.pos != tos.size { + return errNotAtEOL + } + s.stack = s.stack[:len(s.stack)-1] // pop + if len(s.stack) > 0 { + s.stack[len(s.stack)-1].pos += tos.size + } + s.kind = -1 + s.size = 0 + return nil +} + +// Decode decodes a value and stores the result in the value pointed +// to by val. Please see the documentation for the Decode function +// to learn about the decoding rules. +func (s *Stream) Decode(val interface{}) error { + if val == nil { + return errDecodeIntoNil + } + rval := reflect.ValueOf(val) + rtyp := rval.Type() + if rtyp.Kind() != reflect.Ptr { + return errNoPointer + } + if rval.IsNil() { + return errDecodeIntoNil + } + info, err := cachedTypeInfo(rtyp.Elem()) + if err != nil { + return err + } + return info.decoder(s, rval.Elem()) +} + +// Kind returns the kind and size of the next value in the +// input stream. +// +// The returned size is the number of bytes that make up the value. +// For kind == Byte, the size is zero because the value is +// contained in the type tag. +// +// The first call to Kind will read size information from the input +// reader and leave it positioned at the start of the actual bytes of +// the value. Subsequent calls to Kind (until the value is decoded) +// will not advance the input reader and return cached information. +func (s *Stream) Kind() (kind Kind, size uint64, err error) { + var tos *listpos + if len(s.stack) > 0 { + tos = &s.stack[len(s.stack)-1] + } + if s.kind < 0 { + if tos != nil && tos.pos == tos.size { + return 0, 0, EOL + } + kind, size, err = s.readKind() + if err != nil { + return 0, 0, err + } + s.kind, s.size = kind, size + } + if tos != nil && tos.pos+s.size > tos.size { + return 0, 0, ErrElemTooLarge + } + return s.kind, s.size, nil +} + +func (s *Stream) readKind() (kind Kind, size uint64, err error) { + b, err := s.readByte() + if err != nil { + return 0, 0, err + } + s.byteval = 0 + switch { + case b < 0x80: + // For a single byte whose value is in the [0x00, 0x7F] range, that byte + // is its own RLP encoding. + s.byteval = b + return Byte, 0, nil + case b < 0xB8: + // Otherwise, if a string is 0-55 bytes long, + // the RLP encoding consists of a single byte with value 0x80 plus the + // length of the string followed by the string. The range of the first + // byte is thus [0x80, 0xB7]. + return String, uint64(b - 0x80), nil + case b < 0xC0: + // If a string is more than 55 bytes long, the + // RLP encoding consists of a single byte with value 0xB7 plus the length + // of the length of the string in binary form, followed by the length of + // the string, followed by the string. For example, a length-1024 string + // would be encoded as 0xB90400 followed by the string. The range of + // the first byte is thus [0xB8, 0xBF]. + size, err = s.readUint(b - 0xB7) + return String, size, err + case b < 0xF8: + // If the total payload of a list + // (i.e. the combined length of all its items) is 0-55 bytes long, the + // RLP encoding consists of a single byte with value 0xC0 plus the length + // of the list followed by the concatenation of the RLP encodings of the + // items. The range of the first byte is thus [0xC0, 0xF7]. + return List, uint64(b - 0xC0), nil + default: + // If the total payload of a list is more than 55 bytes long, + // the RLP encoding consists of a single byte with value 0xF7 + // plus the length of the length of the payload in binary + // form, followed by the length of the payload, followed by + // the concatenation of the RLP encodings of the items. The + // range of the first byte is thus [0xF8, 0xFF]. + size, err = s.readUint(b - 0xF7) + return List, size, err + } +} + +func (s *Stream) readUint(size byte) (uint64, error) { + if size == 1 { + b, err := s.readByte() + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return uint64(b), err + } + start := int(8 - size) + for i := 0; i < start; i++ { + s.uintbuf[i] = 0 + } + err := s.readFull(s.uintbuf[start:]) + return binary.BigEndian.Uint64(s.uintbuf), err +} + +func (s *Stream) readFull(buf []byte) (err error) { + s.willRead(uint64(len(buf))) + var nn, n int + for n < len(buf) && err == nil { + nn, err = s.r.Read(buf[n:]) + n += nn + } + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return err +} + +func (s *Stream) readByte() (byte, error) { + s.willRead(1) + b, err := s.r.ReadByte() + if len(s.stack) > 0 && err == io.EOF { + err = io.ErrUnexpectedEOF + } + return b, err +} + +func (s *Stream) willRead(n uint64) { + s.kind = -1 // rearm Kind + if len(s.stack) > 0 { + s.stack[len(s.stack)-1].pos += n + } +} diff --git a/rlp/decode_test.go b/rlp/decode_test.go new file mode 100644 index 000000000..eb1618299 --- /dev/null +++ b/rlp/decode_test.go @@ -0,0 +1,476 @@ +package rlp + +import ( + "bytes" + "encoding/hex" + "errors" + "fmt" + "io" + "math/big" + "reflect" + "testing" + + "github.com/ethereum/go-ethereum/ethutil" +) + +func TestStreamKind(t *testing.T) { + tests := []struct { + input string + wantKind Kind + wantLen uint64 + }{ + {"00", Byte, 0}, + {"01", Byte, 0}, + {"7F", Byte, 0}, + {"80", String, 0}, + {"B7", String, 55}, + {"B800", String, 0}, + {"B90400", String, 1024}, + {"BA000400", String, 1024}, + {"BB00000400", String, 1024}, + {"BFFFFFFFFFFFFFFFFF", String, ^uint64(0)}, + {"C0", List, 0}, + {"C8", List, 8}, + {"F7", List, 55}, + {"F800", List, 0}, + {"F804", List, 4}, + {"F90400", List, 1024}, + {"FFFFFFFFFFFFFFFFFF", List, ^uint64(0)}, + } + + for i, test := range tests { + s := NewStream(bytes.NewReader(unhex(test.input))) + kind, len, err := s.Kind() + if err != nil { + t.Errorf("test %d: Type returned error: %v", i, err) + continue + } + if kind != test.wantKind { + t.Errorf("test %d: kind mismatch: got %d, want %d", i, kind, test.wantKind) + } + if len != test.wantLen { + t.Errorf("test %d: len mismatch: got %d, want %d", i, len, test.wantLen) + } + } +} + +func TestStreamErrors(t *testing.T) { + type calls []string + tests := []struct { + string + calls + error + }{ + {"", calls{"Kind"}, io.EOF}, + {"", calls{"List"}, io.EOF}, + {"", calls{"Uint"}, io.EOF}, + {"C0", calls{"Bytes"}, ErrExpectedString}, + {"C0", calls{"Uint"}, ErrExpectedString}, + {"81", calls{"Bytes"}, io.ErrUnexpectedEOF}, + {"81", calls{"Uint"}, io.ErrUnexpectedEOF}, + {"BFFFFFFFFFFFFFFF", calls{"Bytes"}, io.ErrUnexpectedEOF}, + {"89000000000000000001", calls{"Uint"}, errors.New("rlp: string is larger than 64 bits")}, + {"00", calls{"List"}, ErrExpectedList}, + {"80", calls{"List"}, ErrExpectedList}, + {"C0", calls{"List", "Uint"}, EOL}, + {"C801", calls{"List", "Uint", "Uint"}, io.ErrUnexpectedEOF}, + {"C8C9", calls{"List", "Kind"}, ErrElemTooLarge}, + {"C3C2010201", calls{"List", "List", "Uint", "Uint", "ListEnd", "Uint"}, EOL}, + {"00", calls{"ListEnd"}, errNotInList}, + {"C40102", calls{"List", "Uint", "ListEnd"}, errNotAtEOL}, + } + +testfor: + for i, test := range tests { + s := NewStream(bytes.NewReader(unhex(test.string))) + rs := reflect.ValueOf(s) + for j, call := range test.calls { + fval := rs.MethodByName(call) + ret := fval.Call(nil) + err := "" + if lastret := ret[len(ret)-1].Interface(); lastret != nil { + err = lastret.(error).Error() + } + if j == len(test.calls)-1 { + if err != test.error.Error() { + t.Errorf("test %d: last call (%s) error mismatch\ngot: %s\nwant: %v", + i, call, err, test.error) + } + } else if err != "" { + t.Errorf("test %d: call %d (%s) unexpected error: %q", i, j, call, err) + continue testfor + } + } + } +} + +func TestStreamList(t *testing.T) { + s := NewStream(bytes.NewReader(unhex("C80102030405060708"))) + + len, err := s.List() + if err != nil { + t.Fatalf("List error: %v", err) + } + if len != 8 { + t.Fatalf("List returned invalid length, got %d, want 8", len) + } + + for i := uint64(1); i <= 8; i++ { + v, err := s.Uint() + if err != nil { + t.Fatalf("Uint error: %v", err) + } + if i != v { + t.Errorf("Uint returned wrong value, got %d, want %d", v, i) + } + } + + if _, err := s.Uint(); err != EOL { + t.Errorf("Uint error mismatch, got %v, want %v", err, EOL) + } + if err = s.ListEnd(); err != nil { + t.Fatalf("ListEnd error: %v", err) + } +} + +func TestDecodeErrors(t *testing.T) { + r := bytes.NewReader(nil) + + if err := Decode(r, nil); err != errDecodeIntoNil { + t.Errorf("Decode(r, nil) error mismatch, got %q, want %q", err, errDecodeIntoNil) + } + + var nilptr *struct{} + if err := Decode(r, nilptr); err != errDecodeIntoNil { + t.Errorf("Decode(r, nilptr) error mismatch, got %q, want %q", err, errDecodeIntoNil) + } + + if err := Decode(r, struct{}{}); err != errNoPointer { + t.Errorf("Decode(r, struct{}{}) error mismatch, got %q, want %q", err, errNoPointer) + } + + expectErr := "rlp: type chan bool is not RLP-serializable" + if err := Decode(r, new(chan bool)); err == nil || err.Error() != expectErr { + t.Errorf("Decode(r, new(chan bool)) error mismatch, got %q, want %q", err, expectErr) + } + + if err := Decode(r, new(int)); err != io.EOF { + t.Errorf("Decode(r, new(int)) error mismatch, got %q, want %q", err, io.EOF) + } +} + +type decodeTest struct { + input string + ptr interface{} + value interface{} + error error +} + +type simplestruct struct { + A int + B string +} + +type recstruct struct { + I int + Child *recstruct +} + +var ( + veryBigInt = big.NewInt(0).Add( + big.NewInt(0).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16), + big.NewInt(0xFFFF), + ) +) + +var ( + sharedByteArray [5]byte + sharedPtr = new(*int) +) + +var decodeTests = []decodeTest{ + // integers + {input: "05", ptr: new(uint32), value: uint32(5)}, + {input: "80", ptr: new(uint32), value: uint32(0)}, + {input: "8105", ptr: new(uint32), value: uint32(5)}, + {input: "820505", ptr: new(uint32), value: uint32(0x0505)}, + {input: "83050505", ptr: new(uint32), value: uint32(0x050505)}, + {input: "8405050505", ptr: new(uint32), value: uint32(0x05050505)}, + {input: "850505050505", ptr: new(uint32), error: errors.New("rlp: string is larger than 32 bits")}, + {input: "C0", ptr: new(uint32), error: ErrExpectedString}, + + // slices + {input: "C0", ptr: new([]int), value: []int{}}, + {input: "C80102030405060708", ptr: new([]int), value: []int{1, 2, 3, 4, 5, 6, 7, 8}}, + + // arrays + {input: "C0", ptr: new([5]int), value: [5]int{}}, + {input: "C50102030405", ptr: new([5]int), value: [5]int{1, 2, 3, 4, 5}}, + {input: "C6010203040506", ptr: new([5]int), error: errors.New("rlp: input List has more than 5 elements")}, + + // byte slices + {input: "01", ptr: new([]byte), value: []byte{1}}, + {input: "80", ptr: new([]byte), value: []byte{}}, + {input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")}, + {input: "C0", ptr: new([]byte), value: []byte{}}, + {input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}}, + {input: "C3820102", ptr: new([]byte), error: errors.New("rlp: string is larger than 8 bits")}, + + // byte arrays + {input: "01", ptr: new([5]byte), value: [5]byte{1}}, + {input: "80", ptr: new([5]byte), value: [5]byte{}}, + {input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}}, + {input: "C0", ptr: new([5]byte), value: [5]byte{}}, + {input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}}, + {input: "C3820102", ptr: new([5]byte), error: errors.New("rlp: string is larger than 8 bits")}, + {input: "86010203040506", ptr: new([5]byte), error: errStringDoesntFitArray}, + {input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF}, + + // byte array reuse (should be zeroed) + {input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}}, + {input: "8101", ptr: &sharedByteArray, value: [5]byte{1}}, // kind: String + {input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}}, + {input: "01", ptr: &sharedByteArray, value: [5]byte{1}}, // kind: Byte + {input: "C3010203", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 0, 0}}, + {input: "C101", ptr: &sharedByteArray, value: [5]byte{1}}, // kind: List + + // zero sized byte arrays + {input: "80", ptr: new([0]byte), value: [0]byte{}}, + {input: "C0", ptr: new([0]byte), value: [0]byte{}}, + {input: "01", ptr: new([0]byte), error: errStringDoesntFitArray}, + {input: "8101", ptr: new([0]byte), error: errStringDoesntFitArray}, + + // strings + {input: "00", ptr: new(string), value: "\000"}, + {input: "8D6162636465666768696A6B6C6D", ptr: new(string), value: "abcdefghijklm"}, + {input: "C0", ptr: new(string), error: ErrExpectedString}, + + // big ints + {input: "01", ptr: new(*big.Int), value: big.NewInt(1)}, + {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt}, + {input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works + {input: "C0", ptr: new(*big.Int), error: ErrExpectedString}, + + // structs + {input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}}, + {input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}}, + {input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}}, + {input: "C3010101", ptr: new(simplestruct), error: errors.New("rlp: input List has too many elements")}, + { + input: "C501C302C103", + ptr: new(recstruct), + value: recstruct{1, &recstruct{2, &recstruct{3, nil}}}, + }, + + // pointers + {input: "00", ptr: new(*int), value: (*int)(nil)}, + {input: "80", ptr: new(*int), value: (*int)(nil)}, + {input: "C0", ptr: new(*int), value: (*int)(nil)}, + {input: "07", ptr: new(*int), value: intp(7)}, + {input: "8108", ptr: new(*int), value: intp(8)}, + {input: "C109", ptr: new(*[]int), value: &[]int{9}}, + {input: "C58403030303", ptr: new(*[][]byte), value: &[][]byte{{3, 3, 3, 3}}}, + + // pointer should be reset to nil + {input: "05", ptr: sharedPtr, value: intp(5)}, + {input: "80", ptr: sharedPtr, value: (*int)(nil)}, + + // interface{} + {input: "00", ptr: new(interface{}), value: []byte{0}}, + {input: "01", ptr: new(interface{}), value: []byte{1}}, + {input: "80", ptr: new(interface{}), value: []byte{}}, + {input: "850505050505", ptr: new(interface{}), value: []byte{5, 5, 5, 5, 5}}, + {input: "C0", ptr: new(interface{}), value: []interface{}{}}, + {input: "C50183040404", ptr: new(interface{}), value: []interface{}{[]byte{1}, []byte{4, 4, 4}}}, +} + +func intp(i int) *int { return &i } + +func TestDecode(t *testing.T) { + for i, test := range decodeTests { + input, err := hex.DecodeString(test.input) + if err != nil { + t.Errorf("test %d: invalid hex input %q", i, test.input) + continue + } + err = Decode(bytes.NewReader(input), test.ptr) + if err != nil && test.error == nil { + t.Errorf("test %d: unexpected Decode error: %v\ndecoding into %T\ninput %q", + i, err, test.ptr, test.input) + continue + } + if test.error != nil && fmt.Sprint(err) != fmt.Sprint(test.error) { + t.Errorf("test %d: Decode error mismatch\ngot %v\nwant %v\ndecoding into %T\ninput %q", + i, err, test.error, test.ptr, test.input) + continue + } + deref := reflect.ValueOf(test.ptr).Elem().Interface() + if err == nil && !reflect.DeepEqual(deref, test.value) { + t.Errorf("test %d: value mismatch\ngot %#v\nwant %#v\ndecoding into %T\ninput %q", + i, deref, test.value, test.ptr, test.input) + } + } +} + +type testDecoder struct{ called bool } + +func (t *testDecoder) DecodeRLP(s *Stream) error { + if _, err := s.Uint(); err != nil { + return err + } + t.called = true + return nil +} + +func TestDecodeDecoder(t *testing.T) { + var s struct { + T1 testDecoder + T2 *testDecoder + T3 **testDecoder + } + if err := Decode(bytes.NewReader(unhex("C3010203")), &s); err != nil { + t.Fatalf("Decode error: %v", err) + } + + if !s.T1.called { + t.Errorf("DecodeRLP was not called for (non-pointer) testDecoder") + } + + if s.T2 == nil { + t.Errorf("*testDecoder has not been allocated") + } else if !s.T2.called { + t.Errorf("DecodeRLP was not called for *testDecoder") + } + + if s.T3 == nil || *s.T3 == nil { + t.Errorf("**testDecoder has not been allocated") + } else if !(*s.T3).called { + t.Errorf("DecodeRLP was not called for **testDecoder") + } +} + +type byteDecoder byte + +func (bd *byteDecoder) DecodeRLP(s *Stream) error { + _, err := s.Uint() + *bd = 255 + return err +} + +func (bd byteDecoder) called() bool { + return bd == 255 +} + +// This test verifies that the byte slice/byte array logic +// does not kick in for element types implementing Decoder. +func TestDecoderInByteSlice(t *testing.T) { + var slice []byteDecoder + if err := Decode(bytes.NewReader(unhex("C101")), &slice); err != nil { + t.Errorf("unexpected Decode error %v", err) + } else if !slice[0].called() { + t.Errorf("DecodeRLP not called for slice element") + } + + var array [1]byteDecoder + if err := Decode(bytes.NewReader(unhex("C101")), &array); err != nil { + t.Errorf("unexpected Decode error %v", err) + } else if !array[0].called() { + t.Errorf("DecodeRLP not called for array element") + } +} + +func ExampleDecode() { + input, _ := hex.DecodeString("C90A1486666F6F626172") + + type example struct { + A, B int + private int // private fields are ignored + String string + } + + var s example + err := Decode(bytes.NewReader(input), &s) + if err != nil { + fmt.Printf("Error: %v\n", err) + } else { + fmt.Printf("Decoded value: %#v\n", s) + } + // Output: + // Decoded value: rlp.example{A:10, B:20, private:0, String:"foobar"} +} + +func ExampleStream() { + input, _ := hex.DecodeString("C90A1486666F6F626172") + s := NewStream(bytes.NewReader(input)) + + // Check what kind of value lies ahead + kind, size, _ := s.Kind() + fmt.Printf("Kind: %v size:%d\n", kind, size) + + // Enter the list + if _, err := s.List(); err != nil { + fmt.Printf("List error: %v\n", err) + return + } + + // Decode elements + fmt.Println(s.Uint()) + fmt.Println(s.Uint()) + fmt.Println(s.Bytes()) + + // Acknowledge end of list + if err := s.ListEnd(); err != nil { + fmt.Printf("ListEnd error: %v\n", err) + } + // Output: + // Kind: List size:9 + // 10 + // 20 + // [102 111 111 98 97 114] +} + +func BenchmarkDecode(b *testing.B) { + enc := encTest(90000) + b.SetBytes(int64(len(enc))) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var s []int + r := bytes.NewReader(enc) + if err := Decode(r, &s); err != nil { + b.Fatalf("Decode error: %v", err) + } + } +} + +func BenchmarkDecodeIntSliceReuse(b *testing.B) { + enc := encTest(100000) + b.SetBytes(int64(len(enc))) + b.ReportAllocs() + b.ResetTimer() + + var s []int + for i := 0; i < b.N; i++ { + r := bytes.NewReader(enc) + if err := Decode(r, &s); err != nil { + b.Fatalf("Decode error: %v", err) + } + } +} + +func encTest(n int) []byte { + s := make([]interface{}, n) + for i := 0; i < n; i++ { + s[i] = i + } + return ethutil.Encode(s) +} + +func unhex(str string) []byte { + b, err := hex.DecodeString(str) + if err != nil { + panic(fmt.Sprintf("invalid hex string: %q", str)) + } + return b +} diff --git a/rlp/doc.go b/rlp/doc.go new file mode 100644 index 000000000..aab98ea43 --- /dev/null +++ b/rlp/doc.go @@ -0,0 +1,17 @@ +/* +Package rlp implements the RLP serialization format. + +The purpose of RLP (Recursive Linear Prefix) qis to encode arbitrarily +nested arrays of binary data, and RLP is the main encoding method used +to serialize objects in Ethereum. The only purpose of RLP is to encode +structure; encoding specific atomic data types (eg. strings, ints, +floats) is left up to higher-order protocols; in Ethereum integers +must be represented in big endian binary form with no leading zeroes +(thus making the integer value zero be equivalent to the empty byte +array). + +RLP values are distinguished by a type tag. The type tag precedes the +value in the input stream and defines the size and kind of the bytes +that follow. +*/ +package rlp diff --git a/rlp/typecache.go b/rlp/typecache.go new file mode 100644 index 000000000..75dbb43c2 --- /dev/null +++ b/rlp/typecache.go @@ -0,0 +1,91 @@ +package rlp + +import ( + "fmt" + "math/big" + "reflect" + "sync" +) + +type decoder func(*Stream, reflect.Value) error + +type typeinfo struct { + decoder +} + +var ( + typeCacheMutex sync.RWMutex + typeCache = make(map[reflect.Type]*typeinfo) +) + +func cachedTypeInfo(typ reflect.Type) (*typeinfo, error) { + typeCacheMutex.RLock() + info := typeCache[typ] + typeCacheMutex.RUnlock() + if info != nil { + return info, nil + } + // not in the cache, need to generate info for this type. + typeCacheMutex.Lock() + defer typeCacheMutex.Unlock() + return cachedTypeInfo1(typ) +} + +func cachedTypeInfo1(typ reflect.Type) (*typeinfo, error) { + info := typeCache[typ] + if info != nil { + // another goroutine got the write lock first + return info, nil + } + // put a dummmy value into the cache before generating. + // if the generator tries to lookup itself, it will get + // the dummy value and won't call itself recursively. + typeCache[typ] = new(typeinfo) + info, err := genTypeInfo(typ) + if err != nil { + // remove the dummy value if the generator fails + delete(typeCache, typ) + return nil, err + } + *typeCache[typ] = *info + return typeCache[typ], err +} + +var ( + decoderInterface = reflect.TypeOf(new(Decoder)).Elem() + bigInt = reflect.TypeOf(big.Int{}) +) + +func genTypeInfo(typ reflect.Type) (info *typeinfo, err error) { + info = new(typeinfo) + kind := typ.Kind() + switch { + case typ.Implements(decoderInterface): + info.decoder = decodeDecoder + case kind != reflect.Ptr && reflect.PtrTo(typ).Implements(decoderInterface): + info.decoder = decodeDecoderNoPtr + case typ.AssignableTo(reflect.PtrTo(bigInt)): + info.decoder = decodeBigInt + case typ.AssignableTo(bigInt): + info.decoder = decodeBigIntNoPtr + case isInteger(kind): + info.decoder = makeNumDecoder(typ) + case kind == reflect.String: + info.decoder = decodeString + case kind == reflect.Slice || kind == reflect.Array: + info.decoder, err = makeListDecoder(typ) + case kind == reflect.Struct: + info.decoder, err = makeStructDecoder(typ) + case kind == reflect.Ptr: + info.decoder, err = makePtrDecoder(typ) + case kind == reflect.Interface && typ.NumMethod() == 0: + info.decoder = decodeInterface + default: + err = fmt.Errorf("rlp: type %v is not RLP-serializable", typ) + } + return info, err +} + +func isInteger(k reflect.Kind) bool { + return k >= reflect.Int && k <= reflect.Uintptr +}