diff --git a/p2p/message.go b/p2p/message.go index d3b8b74d4..f5418ff47 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -3,9 +3,11 @@ package p2p import ( "bytes" "encoding/binary" + "errors" "io" "io/ioutil" "math/big" + "sync/atomic" "github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/rlp" @@ -153,3 +155,78 @@ func (r *postrack) ReadByte() (byte, error) { } return b, err } + +// MsgPipe creates a message pipe. Reads on one end are matched +// with writes on the other. The pipe is full-duplex, both ends +// implement MsgReadWriter. +func MsgPipe() (*MsgPipeRW, *MsgPipeRW) { + var ( + c1, c2 = make(chan Msg), make(chan Msg) + closing = make(chan struct{}) + closed = new(int32) + rw1 = &MsgPipeRW{c1, c2, closing, closed} + rw2 = &MsgPipeRW{c2, c1, closing, closed} + ) + return rw1, rw2 +} + +// ErrPipeClosed is returned from pipe operations after the +// pipe has been closed. +var ErrPipeClosed = errors.New("p2p: read or write on closed message pipe") + +// MsgPipeRW is an endpoint of a MsgReadWriter pipe. +type MsgPipeRW struct { + w chan<- Msg + r <-chan Msg + closing chan struct{} + closed *int32 +} + +// WriteMsg sends a messsage on the pipe. +// It blocks until the receiver has consumed the message payload. +func (p *MsgPipeRW) WriteMsg(msg Msg) error { + if atomic.LoadInt32(p.closed) == 0 { + consumed := make(chan struct{}, 1) + msg.Payload = &eofSignal{msg.Payload, int64(msg.Size), consumed} + select { + case p.w <- msg: + if msg.Size > 0 { + // wait for payload read or discard + <-consumed + } + return nil + case <-p.closing: + } + } + return ErrPipeClosed +} + +// EncodeMsg is a convenient shorthand for sending an RLP-encoded message. +func (p *MsgPipeRW) EncodeMsg(code uint64, data ...interface{}) error { + return p.WriteMsg(NewMsg(code, data...)) +} + +// ReadMsg returns a message sent on the other end of the pipe. +func (p *MsgPipeRW) ReadMsg() (Msg, error) { + if atomic.LoadInt32(p.closed) == 0 { + select { + case msg := <-p.r: + return msg, nil + case <-p.closing: + } + } + return Msg{}, ErrPipeClosed +} + +// Close unblocks any pending ReadMsg and WriteMsg calls on both ends +// of the pipe. They will return ErrPipeClosed. Note that Close does +// not interrupt any reads from a message payload. +func (p *MsgPipeRW) Close() error { + if atomic.AddInt32(p.closed, 1) != 1 { + // someone else is already closing + atomic.StoreInt32(p.closed, 1) // avoid overflow + return nil + } + close(p.closing) + return nil +} diff --git a/p2p/message_test.go b/p2p/message_test.go index 557bfed26..066d2516d 100644 --- a/p2p/message_test.go +++ b/p2p/message_test.go @@ -2,8 +2,11 @@ package p2p import ( "bytes" + "fmt" "io/ioutil" + "runtime" "testing" + "time" "github.com/ethereum/go-ethereum/ethutil" ) @@ -68,3 +71,63 @@ func TestDecodeRealMsg(t *testing.T) { t.Errorf("incorrect code %d, want %d", msg.Code, 0) } } + +func ExampleMsgPipe() { + rw1, rw2 := MsgPipe() + go func() { + rw1.EncodeMsg(8, []byte{0, 0}) + rw1.EncodeMsg(5, []byte{1, 1}) + rw1.Close() + }() + + for { + msg, err := rw2.ReadMsg() + if err != nil { + break + } + var data [1][]byte + msg.Decode(&data) + fmt.Printf("msg: %d, %x\n", msg.Code, data[0]) + } + // Output: + // msg: 8, 0000 + // msg: 5, 0101 +} + +func TestMsgPipeUnblockWrite(t *testing.T) { +loop: + for i := 0; i < 100; i++ { + rw1, rw2 := MsgPipe() + done := make(chan struct{}) + go func() { + if err := rw1.EncodeMsg(1); err == nil { + t.Error("EncodeMsg returned nil error") + } else if err != ErrPipeClosed { + t.Error("EncodeMsg returned wrong error: got %v, want %v", err, ErrPipeClosed) + } + close(done) + }() + + // this call should ensure that EncodeMsg is waiting to + // deliver sometimes. if this isn't done, Close is likely to + // be executed before EncodeMsg starts and then we won't test + // all the cases. + runtime.Gosched() + + rw2.Close() + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Errorf("write didn't unblock") + break loop + } + } +} + +// This test should panic if concurrent close isn't implemented correctly. +func TestMsgPipeConcurrentClose(t *testing.T) { + rw1, _ := MsgPipe() + for i := 0; i < 10; i++ { + go rw1.Close() + } +} diff --git a/p2p/peer.go b/p2p/peer.go index 893ba86d7..86c4d7ab5 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -300,7 +300,7 @@ func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) proto.in <- msg } else { wait = true - pr := &eofSignal{msg.Payload, protoDone} + pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone} msg.Payload = pr proto.in <- msg } @@ -438,18 +438,25 @@ func (rw *proto) ReadMsg() (Msg, error) { return msg, nil } -// eofSignal wraps a reader with eof signaling. -// the eof channel is closed when the wrapped reader -// reaches EOF. +// eofSignal wraps a reader with eof signaling. the eof channel is +// closed when the wrapped reader returns an error or when count bytes +// have been read. +// type eofSignal struct { wrapped io.Reader + count int64 eof chan<- struct{} } +// note: when using eofSignal to detect whether a message payload +// has been read, Read might not be called for zero sized messages. + func (r *eofSignal) Read(buf []byte) (int, error) { n, err := r.wrapped.Read(buf) - if err != nil { + r.count -= int64(n) + if (err != nil || r.count <= 0) && r.eof != nil { r.eof <- struct{}{} // tell Peer that msg has been consumed + r.eof = nil } return n, err } diff --git a/p2p/peer_error.go b/p2p/peer_error.go index 88b870fbd..0eb7ec838 100644 --- a/p2p/peer_error.go +++ b/p2p/peer_error.go @@ -100,7 +100,16 @@ func (d DiscReason) String() string { return discReasonToString[d] } +type discRequestedError DiscReason + +func (err discRequestedError) Error() string { + return fmt.Sprintf("disconnect requested: %v", DiscReason(err)) +} + func discReasonForError(err error) DiscReason { + if reason, ok := err.(discRequestedError); ok { + return DiscReason(reason) + } peerError, ok := err.(*peerError) if !ok { return DiscSubprotocolError diff --git a/p2p/peer_test.go b/p2p/peer_test.go index d9640292f..f7759786e 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "encoding/hex" + "io" "io/ioutil" "net" "reflect" @@ -237,3 +238,58 @@ func TestNewPeer(t *testing.T) { // Should not hang. p.Disconnect(DiscAlreadyConnected) } + +func TestEOFSignal(t *testing.T) { + rb := make([]byte, 10) + + // empty reader + eof := make(chan struct{}, 1) + sig := &eofSignal{new(bytes.Buffer), 0, eof} + if n, err := sig.Read(rb); n != 0 || err != io.EOF { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + default: + t.Error("EOF chan not signaled") + } + + // count before error + eof = make(chan struct{}, 1) + sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof} + if n, err := sig.Read(rb); n != 8 || err != nil { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + default: + t.Error("EOF chan not signaled") + } + + // error before count + eof = make(chan struct{}, 1) + sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof} + if n, err := sig.Read(rb); n != 4 || err != nil { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + if n, err := sig.Read(rb); n != 0 || err != io.EOF { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + default: + t.Error("EOF chan not signaled") + } + + // no signal if neither occurs + eof = make(chan struct{}, 1) + sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof} + if n, err := sig.Read(rb); n != 10 || err != nil { + t.Errorf("Read returned unexpected values: (%v, %v)", n, err) + } + select { + case <-eof: + t.Error("unexpected EOF signal") + default: + } +} diff --git a/p2p/protocol.go b/p2p/protocol.go index 28eab87cd..3f52205f5 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -154,12 +154,11 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error { return newPeerError(errProtocolBreach, "extra handshake received") case discMsg: - var reason DiscReason + var reason [1]DiscReason if err := msg.Decode(&reason); err != nil { return err } - bp.peer.Disconnect(reason) - return nil + return discRequestedError(reason[0]) case pingMsg: return bp.rw.EncodeMsg(pongMsg) diff --git a/p2p/protocol_test.go b/p2p/protocol_test.go new file mode 100644 index 000000000..65f26fb12 --- /dev/null +++ b/p2p/protocol_test.go @@ -0,0 +1,58 @@ +package p2p + +import ( + "fmt" + "testing" +) + +func TestBaseProtocolDisconnect(t *testing.T) { + peer := NewPeer(NewSimpleClientIdentity("p1", "", "", "foo"), nil) + peer.ourID = NewSimpleClientIdentity("p2", "", "", "bar") + peer.pubkeyHook = func(*peerAddr) error { return nil } + + rw1, rw2 := MsgPipe() + done := make(chan struct{}) + go func() { + if err := expectMsg(rw2, handshakeMsg); err != nil { + t.Error(err) + } + err := rw2.EncodeMsg(handshakeMsg, + baseProtocolVersion, + "", + []interface{}{}, + 0, + make([]byte, 64), + ) + if err != nil { + t.Error(err) + } + if err := expectMsg(rw2, getPeersMsg); err != nil { + t.Error(err) + } + if err := rw2.EncodeMsg(discMsg, DiscQuitting); err != nil { + t.Error(err) + } + close(done) + }() + + if err := runBaseProtocol(peer, rw1); err == nil { + t.Errorf("base protocol returned without error") + } else if reason, ok := err.(discRequestedError); !ok || reason != DiscQuitting { + t.Errorf("base protocol returned wrong error: %v", err) + } + <-done +} + +func expectMsg(r MsgReader, code uint64) error { + msg, err := r.ReadMsg() + if err != nil { + return err + } + if err := msg.Discard(); err != nil { + return err + } + if msg.Code != code { + return fmt.Errorf("wrong message code: got %d, expected %d", msg.Code, code) + } + return nil +}