diff --git a/common/repeat_timer.go b/common/repeat_timer.go index 23faf74a..d7d9154d 100644 --- a/common/repeat_timer.go +++ b/common/repeat_timer.go @@ -1,6 +1,7 @@ package common import ( + "sync" "time" ) @@ -10,93 +11,76 @@ It's good for keeping connections alive. A RepeatTimer must be Stop()'d or it will keep a goroutine alive. */ type RepeatTimer struct { - Name string - Ch <-chan time.Time - output chan<- time.Time - input chan repeatCommand + Ch chan time.Time - dur time.Duration - ticker *time.Ticker - stopped bool + mtx sync.Mutex + name string + ticker *time.Ticker + quit chan struct{} + wg *sync.WaitGroup + dur time.Duration } -type repeatCommand int8 - -const ( - Reset repeatCommand = iota - RQuit -) - func NewRepeatTimer(name string, dur time.Duration) *RepeatTimer { - c := make(chan time.Time) var t = &RepeatTimer{ - Name: name, - Ch: c, - output: c, - input: make(chan repeatCommand), - - dur: dur, + Ch: make(chan time.Time), ticker: time.NewTicker(dur), + quit: make(chan struct{}), + wg: new(sync.WaitGroup), + name: name, + dur: dur, } - go t.run() + t.wg.Add(1) + go t.fireRoutine(t.ticker) return t } +func (t *RepeatTimer) fireRoutine(ticker *time.Ticker) { + for { + select { + case t_ := <-ticker.C: + t.Ch <- t_ + case <-t.quit: + // needed so we know when we can reset t.quit + t.wg.Done() + return + } + } +} + // Wait the duration again before firing. func (t *RepeatTimer) Reset() { - t.input <- Reset + t.Stop() + + t.mtx.Lock() // Lock + defer t.mtx.Unlock() + + t.ticker = time.NewTicker(t.dur) + t.quit = make(chan struct{}) + t.wg.Add(1) + go t.fireRoutine(t.ticker) } // For ease of .Stop()'ing services before .Start()'ing them, // we ignore .Stop()'s on nil RepeatTimers. func (t *RepeatTimer) Stop() bool { - // use t.stopped to gracefully handle many Stop() without blocking - if t == nil || t.stopped { + if t == nil { return false } - t.input <- RQuit - t.stopped = true - return true -} + t.mtx.Lock() // Lock + defer t.mtx.Unlock() -func (t *RepeatTimer) run() { - done := false - for !done { + exists := t.ticker != nil + if exists { + t.ticker.Stop() // does not close the channel select { - case cmd := <-t.input: - // stop goroutine if the input says so - // don't close channels, as closed channels mess up select reads - done = t.processInput(cmd) - case tick := <-t.ticker.C: - t.send(tick) + case <-t.Ch: + // read off channel if there's anything there + default: } + close(t.quit) + t.wg.Wait() // must wait for quit to close else we race Reset + t.ticker = nil } -} - -// send performs blocking send on t.Ch -func (t *RepeatTimer) send(tick time.Time) { - // XXX: possibly it is better to not block: - // https://golang.org/src/time/sleep.go#L132 - // select { - // case t.output <- tick: - // default: - // } - t.output <- tick -} - -// all modifications of the internal state of ThrottleTimer -// happen in this method. It is only called from the run goroutine -// so we avoid any race conditions -func (t *RepeatTimer) processInput(cmd repeatCommand) (shutdown bool) { - switch cmd { - case Reset: - t.ticker.Stop() - t.ticker = time.NewTicker(t.dur) - case RQuit: - t.ticker.Stop() - shutdown = true - default: - panic("unknown command!") - } - return shutdown + return exists } diff --git a/common/repeat_timer_test.go b/common/repeat_timer_test.go index db53aa61..87f34b95 100644 --- a/common/repeat_timer_test.go +++ b/common/repeat_timer_test.go @@ -10,7 +10,7 @@ import ( ) type rCounter struct { - input <-chan time.Time + input chan time.Time mtx sync.Mutex count int } @@ -39,11 +39,11 @@ func (c *rCounter) Read() { func TestRepeat(test *testing.T) { assert := asrt.New(test) - dur := time.Duration(100) * time.Millisecond + dur := time.Duration(50) * time.Millisecond short := time.Duration(20) * time.Millisecond // delay waits for cnt durations, an a little extra delay := func(cnt int) time.Duration { - return time.Duration(cnt)*dur + time.Duration(10)*time.Millisecond + return time.Duration(cnt)*dur + time.Duration(5)*time.Millisecond } t := NewRepeatTimer("bar", dur) @@ -70,9 +70,9 @@ func TestRepeat(test *testing.T) { // after a stop, nothing more is sent stopped := t.Stop() assert.True(stopped) - time.Sleep(delay(2)) + time.Sleep(delay(7)) assert.Equal(6, c.Count()) - // extra calls to stop don't block - t.Stop() + // close channel to stop counter + close(t.Ch) } diff --git a/common/throttle_timer.go b/common/throttle_timer.go index a5bd6ded..ab2ad2e6 100644 --- a/common/throttle_timer.go +++ b/common/throttle_timer.go @@ -13,21 +13,20 @@ at most once every "dur". type ThrottleTimer struct { Name string Ch <-chan struct{} - input chan throttleCommand + input chan command output chan<- struct{} dur time.Duration - timer *time.Timer - isSet bool - stopped bool + timer *time.Timer + isSet bool } -type throttleCommand int8 +type command int32 const ( - Set throttleCommand = iota + Set command = iota Unset - TQuit + Quit ) // NewThrottleTimer creates a new ThrottleTimer. @@ -37,7 +36,7 @@ func NewThrottleTimer(name string, dur time.Duration) *ThrottleTimer { Name: name, Ch: c, dur: dur, - input: make(chan throttleCommand), + input: make(chan command), output: c, timer: time.NewTimer(dur), } @@ -75,14 +74,14 @@ func (t *ThrottleTimer) trySend() { // all modifications of the internal state of ThrottleTimer // happen in this method. It is only called from the run goroutine // so we avoid any race conditions -func (t *ThrottleTimer) processInput(cmd throttleCommand) (shutdown bool) { +func (t *ThrottleTimer) processInput(cmd command) (shutdown bool) { switch cmd { case Set: if !t.isSet { t.isSet = true t.timer.Reset(t.dur) } - case TQuit: + case Quit: shutdown = true fallthrough case Unset: @@ -120,10 +119,9 @@ func (t *ThrottleTimer) Unset() { // For ease of stopping services before starting them, we ignore Stop on nil // ThrottleTimers. func (t *ThrottleTimer) Stop() bool { - if t == nil || t.stopped { + if t == nil { return false } - t.input <- TQuit - t.stopped = true + t.input <- Quit return true } diff --git a/common/throttle_timer_test.go b/common/throttle_timer_test.go index 94ec1b43..a1b6606f 100644 --- a/common/throttle_timer_test.go +++ b/common/throttle_timer_test.go @@ -95,6 +95,4 @@ func TestThrottle(test *testing.T) { stopped := t.Stop() assert.True(stopped) - // extra calls to stop don't block - t.Stop() }