rlp: add Stream.Reset and accept any reader (for p2p)

This commit is contained in:
Felix Lange 2014-11-24 19:01:25 +01:00
parent 59b63caf5e
commit 5a5560f105
2 changed files with 66 additions and 7 deletions

View File

@ -1,6 +1,7 @@
package rlp package rlp
import ( import (
"bufio"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@ -24,8 +25,9 @@ type Decoder interface {
DecodeRLP(*Stream) error DecodeRLP(*Stream) error
} }
// Decode parses RLP-encoded data from r and stores the result // Decode parses RLP-encoded data from r and stores the result in the
// in the value pointed to by val. Val must be a non-nil pointer. // value pointed to by val. Val must be a non-nil pointer. If r does
// not implement ByteReader, Decode will do its own buffering.
// //
// Decode uses the following type-dependent decoding rules: // Decode uses the following type-dependent decoding rules:
// //
@ -66,7 +68,7 @@ type Decoder interface {
// //
// Non-empty interface types are not supported, nor are bool, float32, // Non-empty interface types are not supported, nor are bool, float32,
// float64, maps, channel types and functions. // float64, maps, channel types and functions.
func Decode(r ByteReader, val interface{}) error { func Decode(r io.Reader, val interface{}) error {
return NewStream(r).Decode(val) return NewStream(r).Decode(val)
} }
@ -432,8 +434,14 @@ type Stream struct {
type listpos struct{ pos, size uint64 } type listpos struct{ pos, size uint64 }
func NewStream(r ByteReader) *Stream { // NewStream creates a new stream reading from r.
return &Stream{r: r, uintbuf: make([]byte, 8), kind: -1} // If r does not implement ByteReader, the Stream will
// introduce its own buffering.
func NewStream(r io.Reader) *Stream {
s := new(Stream)
s.Reset(r)
return s
}
} }
// Bytes reads an RLP string and returns its contents as a byte slice. // Bytes reads an RLP string and returns its contents as a byte slice.
@ -543,6 +551,23 @@ func (s *Stream) Decode(val interface{}) error {
return info.decoder(s, rval.Elem()) return info.decoder(s, rval.Elem())
} }
// Reset discards any information about the current decoding context
// and starts reading from r. If r does not also implement ByteReader,
// Stream will do its own buffering.
func (s *Stream) Reset(r io.Reader) {
bufr, ok := r.(ByteReader)
if !ok {
bufr = bufio.NewReader(r)
}
s.r = bufr
s.stack = s.stack[:0]
s.size = 0
s.kind = -1
if s.uintbuf == nil {
s.uintbuf = make([]byte, 8)
}
}
// Kind returns the kind and size of the next value in the // Kind returns the kind and size of the next value in the
// input stream. // input stream.
// //

View File

@ -286,14 +286,14 @@ var decodeTests = []decodeTest{
func intp(i int) *int { return &i } func intp(i int) *int { return &i }
func TestDecode(t *testing.T) { func runTests(t *testing.T, decode func([]byte, interface{}) error) {
for i, test := range decodeTests { for i, test := range decodeTests {
input, err := hex.DecodeString(test.input) input, err := hex.DecodeString(test.input)
if err != nil { if err != nil {
t.Errorf("test %d: invalid hex input %q", i, test.input) t.Errorf("test %d: invalid hex input %q", i, test.input)
continue continue
} }
err = Decode(bytes.NewReader(input), test.ptr) err = decode(input, test.ptr)
if err != nil && test.error == nil { if err != nil && test.error == nil {
t.Errorf("test %d: unexpected Decode error: %v\ndecoding into %T\ninput %q", t.Errorf("test %d: unexpected Decode error: %v\ndecoding into %T\ninput %q",
i, err, test.ptr, test.input) i, err, test.ptr, test.input)
@ -312,6 +312,40 @@ func TestDecode(t *testing.T) {
} }
} }
func TestDecodeWithByteReader(t *testing.T) {
runTests(t, func(input []byte, into interface{}) error {
return Decode(bytes.NewReader(input), into)
})
}
// dumbReader reads from a byte slice but does not
// implement ReadByte.
type dumbReader []byte
func (r *dumbReader) Read(buf []byte) (n int, err error) {
if len(*r) == 0 {
return 0, io.EOF
}
n = copy(buf, *r)
*r = (*r)[n:]
return n, nil
}
func TestDecodeWithNonByteReader(t *testing.T) {
runTests(t, func(input []byte, into interface{}) error {
r := dumbReader(input)
return Decode(&r, into)
})
}
func TestDecodeStreamReset(t *testing.T) {
s := NewStream(nil)
runTests(t, func(input []byte, into interface{}) error {
s.Reset(bytes.NewReader(input))
return s.Decode(into)
})
}
type testDecoder struct{ called bool } type testDecoder struct{ called bool }
func (t *testDecoder) DecodeRLP(s *Stream) error { func (t *testDecoder) DecodeRLP(s *Stream) error {