refactor code plus add one more test

* extract stopPongTimer method
* TestMConnectionMultiplePings
This commit is contained in:
Anton Kaliaev 2018-02-06 12:31:34 +04:00
parent ac0123d249
commit 26419fba28
No known key found for this signature in database
GPG Key ID: 7B6881D965918214
2 changed files with 54 additions and 18 deletions

View File

@ -351,12 +351,7 @@ FOR_LOOP:
c.sendMonitor.Update(int(n))
c.flush()
case <-c.quit:
if c.pongTimer != nil {
if !c.pongTimer.Stop() {
<-c.pongTimer.C
}
drain(c.pongTimeoutCh)
}
c.stopPongTimer()
break FOR_LOOP
case <-c.send:
// Send some msgPackets
@ -482,6 +477,7 @@ FOR_LOOP:
switch pktType {
case packetTypePing:
// TODO: prevent abuse, as they cause flush()'s.
// https://github.com/tendermint/tendermint/issues/1190
c.Logger.Debug("Receive Ping")
select {
case c.pong <- struct{}{}:
@ -490,12 +486,7 @@ FOR_LOOP:
}
case packetTypePong:
c.Logger.Debug("Receive Pong")
if c.pongTimer != nil {
if !c.pongTimer.Stop() {
<-c.pongTimer.C
}
drain(c.pongTimeoutCh)
}
c.stopPongTimer()
case packetTypeMsg:
pkt, n, err := msgPacket{}, int(0), error(nil)
wire.ReadBinaryPtr(&pkt, c.bufReader, c.config.maxMsgPacketTotalSize(), &n, &err)
@ -543,6 +534,15 @@ FOR_LOOP:
}
}
func (c *MConnection) stopPongTimer() {
if c.pongTimer != nil {
if !c.pongTimer.Stop() {
<-c.pongTimer.C
}
drain(c.pongTimeoutCh)
}
}
type ConnectionStatus struct {
Duration time.Duration
SendMonitor flow.Status

View File

@ -145,7 +145,7 @@ func TestMConnectionPongTimeoutResultsInError(t *testing.T) {
}()
<-serverGotPing
pongTimerExpired := mconn.config.PongTimeout + 10*time.Millisecond
pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond
select {
case msgBytes := <-receivedCh:
t.Fatalf("Expected error, but got %v", msgBytes)
@ -174,7 +174,7 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
require.Nil(t, err)
defer mconn.Stop()
// sending 3 pongs in a row
// sending 3 pongs in a row (abuse)
_, err = server.Write([]byte{packetTypePong})
require.Nil(t, err)
_, err = server.Write([]byte{packetTypePong})
@ -184,8 +184,9 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
serverGotPing := make(chan struct{})
go func() {
// read ping
server.Read(make([]byte, 1))
// read ping (one byte)
_, err = server.Read(make([]byte, 1))
require.Nil(t, err)
serverGotPing <- struct{}{}
// respond with pong
_, err = server.Write([]byte{packetTypePong})
@ -193,7 +194,7 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
}()
<-serverGotPing
pongTimerExpired := mconn.config.PongTimeout + 10*time.Millisecond
pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond
select {
case msgBytes := <-receivedCh:
t.Fatalf("Expected no data, but got %v", msgBytes)
@ -204,6 +205,41 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
}
}
func TestMConnectionMultiplePings(t *testing.T) {
server, client := net.Pipe()
defer server.Close()
defer client.Close()
receivedCh := make(chan []byte)
errorsCh := make(chan interface{})
onReceive := func(chID byte, msgBytes []byte) {
receivedCh <- msgBytes
}
onError := func(r interface{}) {
errorsCh <- r
}
mconn := createMConnectionWithCallbacks(client, onReceive, onError)
err := mconn.Start()
require.Nil(t, err)
defer mconn.Stop()
// sending 3 pings in a row (abuse)
_, err = server.Write([]byte{packetTypePing})
require.Nil(t, err)
_, err = server.Read(make([]byte, 1))
require.Nil(t, err)
_, err = server.Write([]byte{packetTypePing})
require.Nil(t, err)
_, err = server.Read(make([]byte, 1))
require.Nil(t, err)
_, err = server.Write([]byte{packetTypePing})
require.Nil(t, err)
_, err = server.Read(make([]byte, 1))
require.Nil(t, err)
assert.True(t, mconn.IsRunning())
}
func TestMConnectionPingPongs(t *testing.T) {
server, client := net.Pipe()
defer server.Close()
@ -241,7 +277,7 @@ func TestMConnectionPingPongs(t *testing.T) {
}()
<-serverGotPing
pongTimerExpired := (mconn.config.PongTimeout + 10*time.Millisecond) * 2
pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 2
select {
case msgBytes := <-receivedCh:
t.Fatalf("Expected no data, but got %v", msgBytes)