fix race by sending signal instead of stopping pongTimer

This commit is contained in:
Anton Kaliaev 2018-02-09 15:16:22 +04:00
parent 26419fba28
commit 45750e1b29
No known key found for this signature in database
GPG Key ID: 7B6881D965918214
1 changed files with 19 additions and 9 deletions

View File

@ -89,7 +89,7 @@ type MConnection struct {
// close conn if pong is not received in pongTimeout
pongTimer *time.Timer
pongTimeoutCh chan struct{}
pongTimeoutCh chan bool // true - timeout, false - peer sent pong
chStatsTimer *cmn.RepeatTimer // update channel stats periodically
@ -191,7 +191,7 @@ func (c *MConnection) OnStart() error {
c.quit = make(chan struct{})
c.flushTimer = cmn.NewThrottleTimer("flush", c.config.FlushThrottle)
c.pingTimer = cmn.NewRepeatTimer("ping", c.config.PingInterval)
c.pongTimeoutCh = make(chan struct{})
c.pongTimeoutCh = make(chan bool)
c.chStatsTimer = cmn.NewRepeatTimer("chStats", updateStats)
go c.sendRoutine()
go c.recvRoutine()
@ -339,19 +339,22 @@ FOR_LOOP:
c.sendMonitor.Update(int(n))
c.Logger.Debug("Starting pong timer", "dur", c.config.PongTimeout)
c.pongTimer = time.AfterFunc(c.config.PongTimeout, func() {
c.pongTimeoutCh <- struct{}{}
c.pongTimeoutCh <- true
})
c.flush()
case <-c.pongTimeoutCh:
c.Logger.Debug("Pong timeout")
err = errors.New("pong timeout")
case timeout := <-c.pongTimeoutCh:
if timeout {
c.Logger.Debug("Pong timeout")
err = errors.New("pong timeout")
} else {
c.stopPongTimer()
}
case <-c.pong:
c.Logger.Debug("Send Pong")
wire.WriteByte(packetTypePong, c.bufWriter, &n, &err)
c.sendMonitor.Update(int(n))
c.flush()
case <-c.quit:
c.stopPongTimer()
break FOR_LOOP
case <-c.send:
// Send some msgPackets
@ -376,6 +379,7 @@ FOR_LOOP:
}
// Cleanup
c.stopPongTimer()
}
// Returns true if messages from channels were exhausted.
@ -486,7 +490,11 @@ FOR_LOOP:
}
case packetTypePong:
c.Logger.Debug("Receive Pong")
c.stopPongTimer()
select {
case c.pongTimeoutCh <- false:
case <-c.quit:
break FOR_LOOP
}
case packetTypeMsg:
pkt, n, err := msgPacket{}, int(0), error(nil)
wire.ReadBinaryPtr(&pkt, c.bufReader, c.config.maxMsgPacketTotalSize(), &n, &err)
@ -534,12 +542,14 @@ FOR_LOOP:
}
}
// not goroutine-safe
func (c *MConnection) stopPongTimer() {
if c.pongTimer != nil {
if !c.pongTimer.Stop() {
<-c.pongTimer.C
}
drain(c.pongTimeoutCh)
c.pongTimer = nil
}
}
@ -771,7 +781,7 @@ func (p msgPacket) String() string {
return fmt.Sprintf("MsgPacket{%X:%X T:%X}", p.ChannelID, p.Bytes, p.EOF)
}
func drain(ch <-chan struct{}) {
func drain(ch <-chan bool) {
for {
select {
case <-ch: