diff --git a/p2p/conn/connection.go b/p2p/conn/connection.go index 7727ee32..46e36301 100644 --- a/p2p/conn/connection.go +++ b/p2p/conn/connection.go @@ -2,6 +2,7 @@ package conn import ( "bufio" + "errors" "fmt" "io" "math" @@ -21,7 +22,6 @@ const ( minReadBufferSize = 1024 minWriteBufferSize = 65536 updateStats = 2 * time.Second - pingTimeout = 40 * time.Second // some of these defaults are written in the user config // flushThrottle, sendRate, recvRate @@ -34,6 +34,8 @@ const ( defaultSendRate = int64(512000) // 500KB/s defaultRecvRate = int64(512000) // 500KB/s defaultSendTimeout = 10 * time.Second + defaultPingInterval = 60 * time.Second + defaultPongTimeout = 45 * time.Second ) type receiveCbFunc func(chID byte, msgBytes []byte) @@ -81,10 +83,15 @@ type MConnection struct { errored uint32 config *MConnConfig - quit chan struct{} - flushTimer *cmn.ThrottleTimer // flush writes as necessary but throttled. - pingTimer *cmn.RepeatTimer // send pings periodically - chStatsTimer *cmn.RepeatTimer // update channel stats periodically + quit chan struct{} + flushTimer *cmn.ThrottleTimer // flush writes as necessary but throttled. + pingTimer *cmn.RepeatTimer // send pings periodically + + // close conn if pong is not received in pongTimeout + pongTimer *time.Timer + pongTimeoutCh chan bool // true - timeout, false - peer sent pong + + chStatsTimer *cmn.RepeatTimer // update channel stats periodically created time.Time // time of creation } @@ -94,9 +101,17 @@ type MConnConfig struct { SendRate int64 `mapstructure:"send_rate"` RecvRate int64 `mapstructure:"recv_rate"` - MaxMsgPacketPayloadSize int + // Maximum payload size + MaxMsgPacketPayloadSize int `mapstructure:"max_msg_packet_payload_size"` - FlushThrottle time.Duration + // Interval to flush writes (throttled) + FlushThrottle time.Duration `mapstructure:"flush_throttle"` + + // Interval to send pings + PingInterval time.Duration `mapstructure:"ping_interval"` + + // Maximum wait time for pongs + PongTimeout time.Duration `mapstructure:"pong_timeout"` } func (cfg *MConnConfig) maxMsgPacketTotalSize() int { @@ -110,6 +125,8 @@ func DefaultMConnConfig() *MConnConfig { RecvRate: defaultRecvRate, MaxMsgPacketPayloadSize: defaultMaxMsgPacketPayloadSize, FlushThrottle: defaultFlushThrottle, + PingInterval: defaultPingInterval, + PongTimeout: defaultPongTimeout, } } @@ -125,6 +142,10 @@ func NewMConnection(conn net.Conn, chDescs []*ChannelDescriptor, onReceive recei // NewMConnectionWithConfig wraps net.Conn and creates multiplex connection with a config func NewMConnectionWithConfig(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc, config *MConnConfig) *MConnection { + if config.PongTimeout >= config.PingInterval { + panic("pongTimeout must be less than pingInterval (otherwise, next ping will reset pong timer)") + } + mconn := &MConnection{ conn: conn, bufReader: bufio.NewReaderSize(conn, minReadBufferSize), @@ -132,7 +153,7 @@ func NewMConnectionWithConfig(conn net.Conn, chDescs []*ChannelDescriptor, onRec sendMonitor: flow.New(0, 0), recvMonitor: flow.New(0, 0), send: make(chan struct{}, 1), - pong: make(chan struct{}), + pong: make(chan struct{}, 1), onReceive: onReceive, onError: onError, config: config, @@ -169,7 +190,8 @@ func (c *MConnection) OnStart() error { } c.quit = make(chan struct{}) c.flushTimer = cmn.NewThrottleTimer("flush", c.config.FlushThrottle) - c.pingTimer = cmn.NewRepeatTimer("ping", pingTimeout) + c.pingTimer = cmn.NewRepeatTimer("ping", c.config.PingInterval) + c.pongTimeoutCh = make(chan bool, 1) c.chStatsTimer = cmn.NewRepeatTimer("chStats", updateStats) go c.sendRoutine() go c.recvRoutine() @@ -179,12 +201,12 @@ func (c *MConnection) OnStart() error { // OnStop implements BaseService func (c *MConnection) OnStop() { c.BaseService.OnStop() - c.flushTimer.Stop() - c.pingTimer.Stop() - c.chStatsTimer.Stop() if c.quit != nil { close(c.quit) } + c.flushTimer.Stop() + c.pingTimer.Stop() + c.chStatsTimer.Stop() c.conn.Close() // nolint: errcheck // We can't close pong safely here because @@ -315,7 +337,18 @@ FOR_LOOP: c.Logger.Debug("Send Ping") wire.WriteByte(packetTypePing, c.bufWriter, &n, &err) c.sendMonitor.Update(int(n)) + c.Logger.Debug("Starting pong timer", "dur", c.config.PongTimeout) + c.pongTimer = time.AfterFunc(c.config.PongTimeout, func() { + c.pongTimeoutCh <- true + }) c.flush() + case timeout := <-c.pongTimeoutCh: + if timeout { + c.Logger.Debug("Pong timeout") + err = errors.New("pong timeout") + } else { + c.stopPongTimer() + } case <-c.pong: c.Logger.Debug("Send Pong") wire.WriteByte(packetTypePong, c.bufWriter, &n, &err) @@ -346,6 +379,7 @@ FOR_LOOP: } // Cleanup + c.stopPongTimer() } // Returns true if messages from channels were exhausted. @@ -447,6 +481,7 @@ FOR_LOOP: switch pktType { case packetTypePing: // TODO: prevent abuse, as they cause flush()'s. + // https://github.com/tendermint/tendermint/issues/1190 c.Logger.Debug("Receive Ping") select { case c.pong <- struct{}{}: @@ -454,8 +489,12 @@ FOR_LOOP: // never block } case packetTypePong: - // do nothing c.Logger.Debug("Receive Pong") + select { + case c.pongTimeoutCh <- false: + default: + // never block + } case packetTypeMsg: pkt, n, err := msgPacket{}, int(0), error(nil) wire.ReadBinaryPtr(&pkt, c.bufReader, c.config.maxMsgPacketTotalSize(), &n, &err) @@ -503,6 +542,17 @@ FOR_LOOP: } } +// not goroutine-safe +func (c *MConnection) stopPongTimer() { + if c.pongTimer != nil { + if !c.pongTimer.Stop() { + <-c.pongTimer.C + } + drain(c.pongTimeoutCh) + c.pongTimer = nil + } +} + type ConnectionStatus struct { Duration time.Duration SendMonitor flow.Status @@ -730,3 +780,13 @@ type msgPacket struct { func (p msgPacket) String() string { return fmt.Sprintf("MsgPacket{%X:%X T:%X}", p.ChannelID, p.Bytes, p.EOF) } + +func drain(ch <-chan bool) { + for { + select { + case <-ch: + default: + return + } + } +} diff --git a/p2p/conn/connection_test.go b/p2p/conn/connection_test.go index 9c8eccbe..d308ea61 100644 --- a/p2p/conn/connection_test.go +++ b/p2p/conn/connection_test.go @@ -22,8 +22,11 @@ func createTestMConnection(conn net.Conn) *MConnection { } func createMConnectionWithCallbacks(conn net.Conn, onReceive func(chID byte, msgBytes []byte), onError func(r interface{})) *MConnection { + cfg := DefaultMConnConfig() + cfg.PingInterval = 90 * time.Millisecond + cfg.PongTimeout = 45 * time.Millisecond chDescs := []*ChannelDescriptor{&ChannelDescriptor{ID: 0x01, Priority: 1, SendQueueCapacity: 1}} - c := NewMConnection(conn, chDescs, onReceive, onError) + c := NewMConnectionWithConfig(conn, chDescs, onReceive, onError, cfg) c.SetLogger(log.TestingLogger()) return c } @@ -116,6 +119,176 @@ func TestMConnectionStatus(t *testing.T) { assert.Zero(status.Channels[0].SendQueueSize) } +func TestMConnectionPongTimeoutResultsInError(t *testing.T) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + receivedCh := make(chan []byte) + errorsCh := make(chan interface{}) + onReceive := func(chID byte, msgBytes []byte) { + receivedCh <- msgBytes + } + onError := func(r interface{}) { + errorsCh <- r + } + mconn := createMConnectionWithCallbacks(client, onReceive, onError) + err := mconn.Start() + require.Nil(t, err) + defer mconn.Stop() + + serverGotPing := make(chan struct{}) + go func() { + // read ping + server.Read(make([]byte, 1)) + serverGotPing <- struct{}{} + }() + <-serverGotPing + + pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond + select { + case msgBytes := <-receivedCh: + t.Fatalf("Expected error, but got %v", msgBytes) + case err := <-errorsCh: + assert.NotNil(t, err) + case <-time.After(pongTimerExpired): + t.Fatalf("Expected to receive error after %v", pongTimerExpired) + } +} + +func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + receivedCh := make(chan []byte) + errorsCh := make(chan interface{}) + onReceive := func(chID byte, msgBytes []byte) { + receivedCh <- msgBytes + } + onError := func(r interface{}) { + errorsCh <- r + } + mconn := createMConnectionWithCallbacks(client, onReceive, onError) + err := mconn.Start() + require.Nil(t, err) + defer mconn.Stop() + + // sending 3 pongs in a row (abuse) + _, err = server.Write([]byte{packetTypePong}) + require.Nil(t, err) + _, err = server.Write([]byte{packetTypePong}) + require.Nil(t, err) + _, err = server.Write([]byte{packetTypePong}) + require.Nil(t, err) + + serverGotPing := make(chan struct{}) + go func() { + // read ping (one byte) + _, err = server.Read(make([]byte, 1)) + require.Nil(t, err) + serverGotPing <- struct{}{} + // respond with pong + _, err = server.Write([]byte{packetTypePong}) + require.Nil(t, err) + }() + <-serverGotPing + + pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond + select { + case msgBytes := <-receivedCh: + t.Fatalf("Expected no data, but got %v", msgBytes) + case err := <-errorsCh: + t.Fatalf("Expected no error, but got %v", err) + case <-time.After(pongTimerExpired): + assert.True(t, mconn.IsRunning()) + } +} + +func TestMConnectionMultiplePings(t *testing.T) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + receivedCh := make(chan []byte) + errorsCh := make(chan interface{}) + onReceive := func(chID byte, msgBytes []byte) { + receivedCh <- msgBytes + } + onError := func(r interface{}) { + errorsCh <- r + } + mconn := createMConnectionWithCallbacks(client, onReceive, onError) + err := mconn.Start() + require.Nil(t, err) + defer mconn.Stop() + + // sending 3 pings in a row (abuse) + // see https://github.com/tendermint/tendermint/issues/1190 + _, err = server.Write([]byte{packetTypePing}) + require.Nil(t, err) + _, err = server.Read(make([]byte, 1)) + require.Nil(t, err) + _, err = server.Write([]byte{packetTypePing}) + require.Nil(t, err) + _, err = server.Read(make([]byte, 1)) + require.Nil(t, err) + _, err = server.Write([]byte{packetTypePing}) + require.Nil(t, err) + _, err = server.Read(make([]byte, 1)) + require.Nil(t, err) + + assert.True(t, mconn.IsRunning()) +} + +func TestMConnectionPingPongs(t *testing.T) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + receivedCh := make(chan []byte) + errorsCh := make(chan interface{}) + onReceive := func(chID byte, msgBytes []byte) { + receivedCh <- msgBytes + } + onError := func(r interface{}) { + errorsCh <- r + } + mconn := createMConnectionWithCallbacks(client, onReceive, onError) + err := mconn.Start() + require.Nil(t, err) + defer mconn.Stop() + + serverGotPing := make(chan struct{}) + go func() { + // read ping + server.Read(make([]byte, 1)) + serverGotPing <- struct{}{} + // respond with pong + _, err = server.Write([]byte{packetTypePong}) + require.Nil(t, err) + + time.Sleep(mconn.config.PingInterval) + + // read ping + server.Read(make([]byte, 1)) + // respond with pong + _, err = server.Write([]byte{packetTypePong}) + require.Nil(t, err) + }() + <-serverGotPing + + pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 2 + select { + case msgBytes := <-receivedCh: + t.Fatalf("Expected no data, but got %v", msgBytes) + case err := <-errorsCh: + t.Fatalf("Expected no error, but got %v", err) + case <-time.After(2 * pongTimerExpired): + assert.True(t, mconn.IsRunning()) + } +} + func TestMConnectionStopsAndReturnsError(t *testing.T) { assert, require := assert.New(t), require.New(t) @@ -303,13 +476,7 @@ func TestMConnectionTrySend(t *testing.T) { mconn.TrySend(0x01, msg) resultCh <- "TrySend" }() - go func() { - mconn.Send(0x01, msg) - resultCh <- "Send" - }() assert.False(mconn.CanSend(0x01)) assert.False(mconn.TrySend(0x01, msg)) assert.Equal("TrySend", <-resultCh) - server.Read(make([]byte, len(msg))) - assert.Equal("Send", <-resultCh) // Order constrained by parallel blocking above } diff --git a/p2p/conn/secret_connection_test.go b/p2p/conn/secret_connection_test.go index 8af9cdeb..4cf715dd 100644 --- a/p2p/conn/secret_connection_test.go +++ b/p2p/conn/secret_connection_test.go @@ -4,7 +4,7 @@ import ( "io" "testing" - "github.com/tendermint/go-crypto" + crypto "github.com/tendermint/go-crypto" cmn "github.com/tendermint/tmlibs/common" ) diff --git a/p2p/switch.go b/p2p/switch.go index 9502359d..f1d02dcf 100644 --- a/p2p/switch.go +++ b/p2p/switch.go @@ -5,6 +5,7 @@ import ( "math" "math/rand" "net" + "sync" "time" "github.com/pkg/errors" @@ -200,9 +201,29 @@ func (sw *Switch) OnStop() { //--------------------------------------------------------------------- // Peers -// Peers returns the set of peers that are connected to the switch. -func (sw *Switch) Peers() IPeerSet { - return sw.peers +// Broadcast runs a go routine for each attempted send, which will block trying +// to send for defaultSendTimeoutSeconds. Returns a channel which receives +// success values for each attempted send (false if times out). Channel will be +// closed once msg send to all peers. +// +// NOTE: Broadcast uses goroutines, so order of broadcast may not be preserved. +func (sw *Switch) Broadcast(chID byte, msg interface{}) chan bool { + successChan := make(chan bool, len(sw.peers.List())) + sw.Logger.Debug("Broadcast", "channel", chID, "msg", msg) + var wg sync.WaitGroup + for _, peer := range sw.peers.List() { + wg.Add(1) + go func(peer Peer) { + defer wg.Done() + success := peer.Send(chID, msg) + successChan <- success + }(peer) + } + go func() { + wg.Wait() + close(successChan) + }() + return successChan } // NumPeers returns the count of outbound/inbound and outbound-dialing peers. @@ -219,21 +240,9 @@ func (sw *Switch) NumPeers() (outbound, inbound, dialing int) { return } -// Broadcast runs a go routine for each attempted send, which will block -// trying to send for defaultSendTimeoutSeconds. Returns a channel -// which receives success values for each attempted send (false if times out). -// NOTE: Broadcast uses goroutines, so order of broadcast may not be preserved. -// TODO: Something more intelligent. -func (sw *Switch) Broadcast(chID byte, msg interface{}) chan bool { - successChan := make(chan bool, len(sw.peers.List())) - sw.Logger.Debug("Broadcast", "channel", chID, "msg", msg) - for _, peer := range sw.peers.List() { - go func(peer Peer) { - success := peer.Send(chID, msg) - successChan <- success - }(peer) - } - return successChan +// Peers returns the set of peers that are connected to the switch. +func (sw *Switch) Peers() IPeerSet { + return sw.peers } // StopPeerForError disconnects from a peer due to external error. diff --git a/p2p/switch_test.go b/p2p/switch_test.go index 75f9640b..745eb44e 100644 --- a/p2p/switch_test.go +++ b/p2p/switch_test.go @@ -300,9 +300,7 @@ func TestSwitchFullConnectivity(t *testing.T) { } } -func BenchmarkSwitches(b *testing.B) { - b.StopTimer() - +func BenchmarkSwitchBroadcast(b *testing.B) { s1, s2 := MakeSwitchPair(b, func(i int, sw *Switch) *Switch { // Make bar reactors of bar channels each sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{ @@ -320,7 +318,8 @@ func BenchmarkSwitches(b *testing.B) { // Allow time for goroutines to boot up time.Sleep(1 * time.Second) - b.StartTimer() + + b.ResetTimer() numSuccess, numFailure := 0, 0 @@ -338,7 +337,4 @@ func BenchmarkSwitches(b *testing.B) { } b.Logf("success: %v, failure: %v", numSuccess, numFailure) - - // Allow everything to flush before stopping switches & closing connections. - b.StopTimer() }