From c14b39da5f3ce599d047bc6b42b2045c9630251d Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Mon, 7 Aug 2017 18:29:55 -0400 Subject: [PATCH] make RPC server's ping period and pong wait configurable via options --- rpc/lib/rpc_test.go | 18 +++++------ rpc/lib/server/handlers.go | 65 +++++++++++++++++++++++++++----------- 2 files changed, 54 insertions(+), 29 deletions(-) diff --git a/rpc/lib/rpc_test.go b/rpc/lib/rpc_test.go index 341aea46..8bac7aa3 100644 --- a/rpc/lib/rpc_test.go +++ b/rpc/lib/rpc_test.go @@ -31,6 +31,8 @@ const ( unixAddr = "unix://" + unixSocket websocketEndpoint = "/websocket/endpoint" + + testPongWait = 2 * time.Second ) type ResultEcho struct { @@ -113,7 +115,7 @@ func setup() { tcpLogger := logger.With("socket", "tcp") mux := http.NewServeMux() server.RegisterRPCFuncs(mux, Routes, tcpLogger) - wm := server.NewWebsocketManager(Routes, nil) + wm := server.NewWebsocketManager(Routes, nil, server.PingPong((testPongWait*9)/10, testPongWait)) wm.SetLogger(tcpLogger) mux.HandleFunc(websocketEndpoint, wm.WebsocketHandler) go func() { @@ -276,7 +278,7 @@ func TestServersAndClientsBasic(t *testing.T) { testWithHTTPClient(t, cl2) cl3 := client.NewWSClient(addr, websocketEndpoint) - cl3.SetLogger(log.TestingLogger()) + cl3.SetLogger(log.TestingLogger()) _, err := cl3.Start() require.Nil(t, err) fmt.Printf("=== testing server on %s using %v client", addr, cl3) @@ -305,7 +307,7 @@ func TestQuotedStringArg(t *testing.T) { func TestWSNewWSRPCFunc(t *testing.T) { cl := client.NewWSClient(tcpAddr, websocketEndpoint) - cl.SetLogger(log.TestingLogger()) + cl.SetLogger(log.TestingLogger()) _, err := cl.Start() require.Nil(t, err) defer cl.Stop() @@ -331,7 +333,7 @@ func TestWSNewWSRPCFunc(t *testing.T) { func TestWSHandlesArrayParams(t *testing.T) { cl := client.NewWSClient(tcpAddr, websocketEndpoint) - cl.SetLogger(log.TestingLogger()) + cl.SetLogger(log.TestingLogger()) _, err := cl.Start() require.Nil(t, err) defer cl.Stop() @@ -356,17 +358,13 @@ func TestWSHandlesArrayParams(t *testing.T) { // TestWSClientPingPong checks that a client & server exchange pings // & pongs so connection stays alive. func TestWSClientPingPong(t *testing.T) { - if testing.Short() { - t.Skip("skipping ping pong in short mode") - } - cl := client.NewWSClient(tcpAddr, websocketEndpoint) - cl.SetLogger(log.TestingLogger()) + cl.SetLogger(log.TestingLogger()) _, err := cl.Start() require.Nil(t, err) defer cl.Stop() - time.Sleep(35 * time.Second) + time.Sleep((testPongWait * 11) / 10) } func randBytes(t *testing.T) []byte { diff --git a/rpc/lib/server/handlers.go b/rpc/lib/server/handlers.go index b6431a1e..b95f606c 100644 --- a/rpc/lib/server/handlers.go +++ b/rpc/lib/server/handlers.go @@ -337,10 +337,10 @@ func nonJsonToArg(ty reflect.Type, arg string) (reflect.Value, error, bool) { // rpc.websocket const ( - writeChanCapacity = 1000 - wsWriteTimeoutSeconds = 30 // each write times out after this. - wsReadTimeoutSeconds = 30 // connection times out if we haven't received *anything* in this long, not even pings. - wsPingTickerSeconds = 10 // send a ping every PingTickerSeconds. + writeChanCapacity = 1000 + wsWriteWait = 30 * time.Second // each write times out after this. + defaultWSPongWait = 30 * time.Second + defaultWSPingPeriod = 10 * time.Second ) // a single websocket connection @@ -357,29 +357,54 @@ type wsConnection struct { funcMap map[string]*RPCFunc evsw events.EventSwitch + + // Connection times out if we haven't received *anything* in this long, not even pings. + pongWait time.Duration + + // Send pings to server with this period. Must be less than pongWait. + pingPeriod time.Duration } -// new websocket connection wrapper -func NewWSConnection(baseConn *websocket.Conn, funcMap map[string]*RPCFunc, evsw events.EventSwitch) *wsConnection { +// NewWSConnection wraps websocket.Conn. See the commentary on the +// func(*wsConnection) functions for a detailed description of how to configure +// ping period and pong wait time. +func NewWSConnection(baseConn *websocket.Conn, funcMap map[string]*RPCFunc, evsw events.EventSwitch, options ...func(*wsConnection)) *wsConnection { wsc := &wsConnection{ remoteAddr: baseConn.RemoteAddr().String(), baseConn: baseConn, writeChan: make(chan types.RPCResponse, writeChanCapacity), // error when full. funcMap: funcMap, evsw: evsw, + pongWait: defaultWSPongWait, + pingPeriod: defaultWSPingPeriod, + } + for _, option := range options { + option(wsc) } wsc.BaseService = *cmn.NewBaseService(nil, "wsConnection", wsc) return wsc } +// PingPong allows changing ping period and pong wait time. If ping period +// greater or equal to pong wait time, panic will be thrown. +func PingPong(pingPeriod, pongWait time.Duration) func(*wsConnection) { + return func(wsc *wsConnection) { + if pingPeriod >= pongWait { + panic(fmt.Sprintf("ping period (%v) must be less than pong wait time (%v)", pingPeriod, pongWait)) + } + wsc.pingPeriod = pingPeriod + wsc.pongWait = pongWait + } +} + // wsc.Start() blocks until the connection closes. func (wsc *wsConnection) OnStart() error { wsc.BaseService.OnStart() // these must be set before the readRoutine is created, as it may // call wsc.Stop(), which accesses these timers - wsc.readTimeout = time.NewTimer(time.Second * wsReadTimeoutSeconds) - wsc.pingTicker = time.NewTicker(time.Second * wsPingTickerSeconds) + wsc.readTimeout = time.NewTimer(wsc.pongWait) + wsc.pingTicker = time.NewTicker(wsc.pingPeriod) // Read subscriptions/unsubscriptions to events go wsc.readRoutine() @@ -387,13 +412,13 @@ func (wsc *wsConnection) OnStart() error { // Custom Ping handler to touch readTimeout wsc.baseConn.SetPingHandler(func(m string) error { // NOTE: https://github.com/gorilla/websocket/issues/97 - go wsc.baseConn.WriteControl(websocket.PongMessage, []byte(m), time.Now().Add(time.Second*wsWriteTimeoutSeconds)) - wsc.readTimeout.Reset(time.Second * wsReadTimeoutSeconds) + go wsc.baseConn.WriteControl(websocket.PongMessage, []byte(m), time.Now().Add(wsWriteWait)) + wsc.readTimeout.Reset(wsc.pongWait) return nil }) wsc.baseConn.SetPongHandler(func(m string) error { // NOTE: https://github.com/gorilla/websocket/issues/97 - wsc.readTimeout.Reset(time.Second * wsReadTimeoutSeconds) + wsc.readTimeout.Reset(wsc.pongWait) return nil }) go wsc.readTimeoutRoutine() @@ -472,7 +497,7 @@ func (wsc *wsConnection) readRoutine() { default: var in []byte // Do not set a deadline here like below: - // wsc.baseConn.SetReadDeadline(time.Now().Add(time.Second * wsReadTimeoutSeconds)) + // wsc.baseConn.SetReadDeadline(time.Now().Add(wsc.pongWait)) // The client may not send anything for a while. // We use `readTimeout` to handle read timeouts. _, in, err := wsc.baseConn.ReadMessage() @@ -559,7 +584,7 @@ func (wsc *wsConnection) writeRoutine() { // All writes to the websocket must (re)set the write deadline. // If some writes don't set it while others do, they may timeout incorrectly (https://github.com/tendermint/tendermint/issues/553) func (wsc *wsConnection) writeMessageWithDeadline(msgType int, msg []byte) error { - wsc.baseConn.SetWriteDeadline(time.Now().Add(time.Second * wsWriteTimeoutSeconds)) + wsc.baseConn.SetWriteDeadline(time.Now().Add(wsWriteWait)) return wsc.baseConn.WriteMessage(msgType, msg) } @@ -570,12 +595,13 @@ func (wsc *wsConnection) writeMessageWithDeadline(msgType int, msg []byte) error // NOTE: The websocket path is defined externally, e.g. in node/node.go type WebsocketManager struct { websocket.Upgrader - funcMap map[string]*RPCFunc - evsw events.EventSwitch - logger log.Logger + funcMap map[string]*RPCFunc + evsw events.EventSwitch + logger log.Logger + wsConnOptions []func(*wsConnection) } -func NewWebsocketManager(funcMap map[string]*RPCFunc, evsw events.EventSwitch) *WebsocketManager { +func NewWebsocketManager(funcMap map[string]*RPCFunc, evsw events.EventSwitch, wsConnOptions ...func(*wsConnection)) *WebsocketManager { return &WebsocketManager{ funcMap: funcMap, evsw: evsw, @@ -587,7 +613,8 @@ func NewWebsocketManager(funcMap map[string]*RPCFunc, evsw events.EventSwitch) * return true }, }, - logger: log.NewNopLogger(), + logger: log.NewNopLogger(), + wsConnOptions: wsConnOptions, } } @@ -605,7 +632,7 @@ func (wm *WebsocketManager) WebsocketHandler(w http.ResponseWriter, r *http.Requ } // register connection - con := NewWSConnection(wsConn, wm.funcMap, wm.evsw) + con := NewWSConnection(wsConn, wm.funcMap, wm.evsw, wm.wsConnOptions...) con.SetLogger(wm.logger) wm.logger.Info("New websocket connection", "remote", con.remoteAddr) con.Start() // Blocking