fix race by sending signal instead of stopping pongTimer
This commit is contained in:
parent
26419fba28
commit
45750e1b29
|
@ -89,7 +89,7 @@ type MConnection struct {
|
|||
|
||||
// close conn if pong is not received in pongTimeout
|
||||
pongTimer *time.Timer
|
||||
pongTimeoutCh chan struct{}
|
||||
pongTimeoutCh chan bool // true - timeout, false - peer sent pong
|
||||
|
||||
chStatsTimer *cmn.RepeatTimer // update channel stats periodically
|
||||
|
||||
|
@ -191,7 +191,7 @@ func (c *MConnection) OnStart() error {
|
|||
c.quit = make(chan struct{})
|
||||
c.flushTimer = cmn.NewThrottleTimer("flush", c.config.FlushThrottle)
|
||||
c.pingTimer = cmn.NewRepeatTimer("ping", c.config.PingInterval)
|
||||
c.pongTimeoutCh = make(chan struct{})
|
||||
c.pongTimeoutCh = make(chan bool)
|
||||
c.chStatsTimer = cmn.NewRepeatTimer("chStats", updateStats)
|
||||
go c.sendRoutine()
|
||||
go c.recvRoutine()
|
||||
|
@ -339,19 +339,22 @@ FOR_LOOP:
|
|||
c.sendMonitor.Update(int(n))
|
||||
c.Logger.Debug("Starting pong timer", "dur", c.config.PongTimeout)
|
||||
c.pongTimer = time.AfterFunc(c.config.PongTimeout, func() {
|
||||
c.pongTimeoutCh <- struct{}{}
|
||||
c.pongTimeoutCh <- true
|
||||
})
|
||||
c.flush()
|
||||
case <-c.pongTimeoutCh:
|
||||
case timeout := <-c.pongTimeoutCh:
|
||||
if timeout {
|
||||
c.Logger.Debug("Pong timeout")
|
||||
err = errors.New("pong timeout")
|
||||
} else {
|
||||
c.stopPongTimer()
|
||||
}
|
||||
case <-c.pong:
|
||||
c.Logger.Debug("Send Pong")
|
||||
wire.WriteByte(packetTypePong, c.bufWriter, &n, &err)
|
||||
c.sendMonitor.Update(int(n))
|
||||
c.flush()
|
||||
case <-c.quit:
|
||||
c.stopPongTimer()
|
||||
break FOR_LOOP
|
||||
case <-c.send:
|
||||
// Send some msgPackets
|
||||
|
@ -376,6 +379,7 @@ FOR_LOOP:
|
|||
}
|
||||
|
||||
// Cleanup
|
||||
c.stopPongTimer()
|
||||
}
|
||||
|
||||
// Returns true if messages from channels were exhausted.
|
||||
|
@ -486,7 +490,11 @@ FOR_LOOP:
|
|||
}
|
||||
case packetTypePong:
|
||||
c.Logger.Debug("Receive Pong")
|
||||
c.stopPongTimer()
|
||||
select {
|
||||
case c.pongTimeoutCh <- false:
|
||||
case <-c.quit:
|
||||
break FOR_LOOP
|
||||
}
|
||||
case packetTypeMsg:
|
||||
pkt, n, err := msgPacket{}, int(0), error(nil)
|
||||
wire.ReadBinaryPtr(&pkt, c.bufReader, c.config.maxMsgPacketTotalSize(), &n, &err)
|
||||
|
@ -534,12 +542,14 @@ FOR_LOOP:
|
|||
}
|
||||
}
|
||||
|
||||
// not goroutine-safe
|
||||
func (c *MConnection) stopPongTimer() {
|
||||
if c.pongTimer != nil {
|
||||
if !c.pongTimer.Stop() {
|
||||
<-c.pongTimer.C
|
||||
}
|
||||
drain(c.pongTimeoutCh)
|
||||
c.pongTimer = nil
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -771,7 +781,7 @@ func (p msgPacket) String() string {
|
|||
return fmt.Sprintf("MsgPacket{%X:%X T:%X}", p.ChannelID, p.Bytes, p.EOF)
|
||||
}
|
||||
|
||||
func drain(ch <-chan struct{}) {
|
||||
func drain(ch <-chan bool) {
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
|
|
Loading…
Reference in New Issue