gecko/utils/json/pubsub_server.go

261 lines
5.7 KiB
Go

// (c) 2019-2020, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.
package json
import (
"errors"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/ava-labs/gecko/snow"
)
const (
	// Buffer sizes, in bytes, passed to the websocket upgrader.
	readBufferSize  = 1024
	writeBufferSize = 1024

	// writeWait bounds how long a single write to the peer may take.
	writeWait = 10 * time.Second
	// pongWait is how long the connection may go without a pong from the peer
	// before its read deadline expires.
	pongWait = 60 * time.Second
	// pingPeriod is how often pings are sent. It must stay below pongWait so
	// the peer has time to answer before the read deadline passes.
	pingPeriod = pongWait * 9 / 10

	// maxMessageSize caps the size, in bytes, of a message read from the peer.
	maxMessageSize = 512
	// maxPendingMessages is the number of outbound messages that may be queued
	// for a single peer.
	maxPendingMessages = 256
)
// upgrader converts incoming HTTP requests into websocket connections using
// the package's read/write buffer sizes.
//
// NOTE(review): CheckOrigin accepts every request regardless of its Origin
// header, so pages served from any site may open a connection here. This is
// presumably intentional for a public API; confirm before relying on it.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  readBufferSize,
	WriteBufferSize: writeBufferSize,
	CheckOrigin: func(_ *http.Request) bool {
		return true
	},
}
// errDuplicateChannel is returned by Register when the channel name is
// already in use.
var errDuplicateChannel = errors.New("duplicate channel")
// PubSubServer is a websocket publish/subscribe hub. Clients connect through
// ServeHTTP and subscribe to channels created with Register. Publish fans a
// message out to every current subscriber of a channel.
type PubSubServer struct {
	// ctx provides the logger used by the server.
	ctx *snow.Context

	// lock guards conns and channels.
	lock sync.Mutex
	// conns maps each live connection to the set of channel names it is
	// subscribed to.
	conns map[*Connection]map[string]struct{}
	// channels maps each registered channel name to the set of connections
	// subscribed to it. It is the inverse index of conns.
	channels map[string]map[*Connection]struct{}
}
// NewPubSubServer returns a PubSubServer with no registered channels and no
// connections. ctx supplies the logger the server writes to.
func NewPubSubServer(ctx *snow.Context) *PubSubServer {
	srv := &PubSubServer{ctx: ctx}
	srv.conns = make(map[*Connection]map[string]struct{})
	srv.channels = make(map[string]map[*Connection]struct{})
	return srv
}
// ServeHTTP upgrades the request to a websocket and registers the resulting
// connection with the server. The connection starts with no subscriptions and
// an outbound queue of maxPendingMessages. If the upgrade fails, the error is
// logged at debug level and nothing else is done.
func (s *PubSubServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	wsConn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		s.ctx.Log.Debug("Failed to upgrade %s", err)
		return
	}
	s.addConnection(&Connection{
		s:    s,
		conn: wsConn,
		send: make(chan interface{}, maxPendingMessages),
	})
}
// Publish sends msg to every connection currently subscribed to channel. Each
// subscriber receives it wrapped in a publish envelope, which is encoded as
// {"channel": channel, "value": msg}.
//
// channel must already have been created with Register. Publishing to an
// unknown channel logs a warning and does nothing. Delivery is best-effort:
// if a subscriber's outbound queue is full, the message is dropped for that
// subscriber rather than blocking the publisher.
func (s *PubSubServer) Publish(channel string, msg interface{}) {
	s.lock.Lock()
	defer s.lock.Unlock()

	conns, exists := s.channels[channel]
	if !exists {
		s.ctx.Log.Warn("attempted to publish to an unknown channel %s", channel)
		return
	}

	pubMsg := &publish{
		Channel: channel,
		Value:   msg,
	}
	for conn := range conns {
		// Use a non-blocking send so one slow subscriber cannot stall Publish,
		// or any other method waiting on s.lock.
		select {
		case conn.send <- pubMsg:
		default:
			s.ctx.Log.Verbo("dropping message to subscribed connection due to too many pending messages")
		}
	}
}
// Register creates a new channel with no subscribers. Clients may then
// subscribe to it, and Publish may send on it. Register returns
// errDuplicateChannel if a channel with the same name already exists.
func (s *PubSubServer) Register(channel string) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	if _, taken := s.channels[channel]; !taken {
		s.channels[channel] = make(map[*Connection]struct{})
		return nil
	}
	return errDuplicateChannel
}
// addConnection records conn as a live connection with no subscriptions, then
// starts its writer and reader goroutines. The connection is registered before
// either goroutine is launched.
func (s *PubSubServer) addConnection(conn *Connection) {
	s.lock.Lock()
	defer s.lock.Unlock()

	subscriptions := make(map[string]struct{})
	s.conns[conn] = subscriptions

	go conn.writePump()
	go conn.readPump()
}
// removeConnection unregisters conn from the server. It unsubscribes conn from
// every channel it joined, deletes the connection's entry, and closes
// conn.send so the connection's writePump stops.
//
// readPump calls it once, when its read loop exits. Calling it for a
// connection that is not registered (including a second call for the same
// connection) only logs a warning.
func (s *PubSubServer) removeConnection(conn *Connection) {
	s.lock.Lock()
	defer s.lock.Unlock()

	channels, exists := s.conns[conn]
	if !exists {
		s.ctx.Log.Warn("attempted to remove an unknown connection")
		return
	}
	for channel := range channels {
		delete(s.channels[channel], conn)
	}
	// Forget the connection itself. Without this, every disconnected client
	// left its entry (and its subscription set) in s.conns forever.
	delete(s.conns, conn)
	// Publish only sends to connections found in s.channels, and only while
	// holding s.lock. After the deletes above, nothing can send on conn.send
	// anymore, so closing it is safe. Closing it lets writePump exit right
	// away instead of waiting for its next failed ping.
	close(conn.send)
}
// addChannel subscribes conn to channel, updating both indexes. The request is
// silently ignored if conn is not a registered connection or channel was never
// Registered.
func (s *PubSubServer) addChannel(conn *Connection, channel string) {
	s.lock.Lock()
	defer s.lock.Unlock()

	subs, connKnown := s.conns[conn]
	members, channelKnown := s.channels[channel]
	if !connKnown || !channelKnown {
		return
	}
	subs[channel] = struct{}{}
	members[conn] = struct{}{}
}
// removeChannel unsubscribes conn from channel, updating both indexes. The
// request is silently ignored if conn is not a registered connection or
// channel was never Registered.
func (s *PubSubServer) removeChannel(conn *Connection, channel string) {
	s.lock.Lock()
	defer s.lock.Unlock()

	subs, connKnown := s.conns[conn]
	if !connKnown {
		return
	}
	if members, channelKnown := s.channels[channel]; channelKnown {
		delete(subs, channel)
		delete(members, conn)
	}
}
// publish is the envelope queued to subscribers by Publish. Its JSON form is
// {"channel": ..., "value": ...}.
type publish struct {
	Channel string      `json:"channel"`
	Value   interface{} `json:"value"`
}
// subscribe is a request decoded from a client's message. It asks to join
// Channel, or to leave it when Unsubscribe is true.
type subscribe struct {
	Channel     string `json:"channel"`
	Unsubscribe bool   `json:"unsubscribe"`
}
// Connection is a single client's websocket session with a PubSubServer.
type Connection struct {
	// s is the server this connection is registered with.
	s *PubSubServer
	// conn is the underlying websocket.
	conn *websocket.Conn
	// send is the queue of outbound messages drained by writePump. ServeHTTP
	// creates it with capacity maxPendingMessages.
	send chan interface{}
}
// readPump reads subscribe requests from the peer and applies them to the
// server's subscription tables.
//
// addConnection starts exactly one readPump per connection, and all reads from
// c.conn happen here. Any read error ends the loop, including an error caused
// by the read deadline passing without a pong. On exit, the connection is
// removed from the server and the websocket is closed.
func (c *Connection) readPump() {
	defer func() {
		c.s.removeConnection(c)
		c.conn.Close()
	}()

	extendDeadline := func() {
		c.conn.SetReadDeadline(time.Now().Add(pongWait))
	}

	c.conn.SetReadLimit(maxMessageSize)
	extendDeadline()
	// Each pong from the peer pushes the read deadline forward, so a live
	// peer that sends nothing else is not disconnected.
	c.conn.SetPongHandler(func(string) error {
		extendDeadline()
		return nil
	})

	for {
		var req subscribe
		if err := c.conn.ReadJSON(&req); err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				c.s.ctx.Log.Debug("Unexpected close in websockets: %s", err)
			}
			return
		}
		switch {
		case req.Unsubscribe:
			c.s.removeChannel(c, req.Channel)
		default:
			c.s.addChannel(c, req.Channel)
		}
	}
}
// writePump writes queued messages from c.send to the peer as JSON. It also
// sends a ping every pingPeriod so the peer's pongs keep extending readPump's
// read deadline.
//
// addConnection starts exactly one writePump per connection, and it is the
// only goroutine that writes data frames to c.conn. It returns, closing the
// websocket, when any write fails or when c.send is closed. In the latter
// case it first tries to send a close frame.
func (c *Connection) writePump() {
	pinger := time.NewTicker(pingPeriod)
	defer func() {
		pinger.Stop()
		c.conn.Close()
	}()

	for {
		var err error
		select {
		case msg, open := <-c.send:
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if !open {
				// The queue was closed: tell the peer we're done and stop.
				c.conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}
			err = c.conn.WriteJSON(msg)
		case <-pinger.C:
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			err = c.conn.WriteMessage(websocket.PingMessage, nil)
		}
		if err != nil {
			return
		}
	}
}