268 lines
7.0 KiB
Go
268 lines
7.0 KiB
Go
// Copyright 2020 dfuse Platform Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package ws
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"reflect"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/dfuse-io/solana-go/rpc"
|
|
"github.com/gorilla/rpc/v2/json2"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/tidwall/gjson"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
type result interface{}
|
|
|
|
type Client struct {
|
|
rpcURL string
|
|
conn *websocket.Conn
|
|
lock sync.RWMutex
|
|
subscriptionByRequestID map[uint64]*Subscription
|
|
subscriptionByWSSubID map[uint64]*Subscription
|
|
reconnectOnErr bool
|
|
}
|
|
|
|
func Dial(ctx context.Context, rpcURL string) (c *Client, err error) {
|
|
c = &Client{
|
|
rpcURL: rpcURL,
|
|
subscriptionByRequestID: map[uint64]*Subscription{},
|
|
subscriptionByWSSubID: map[uint64]*Subscription{},
|
|
}
|
|
|
|
c.conn, _, err = websocket.DefaultDialer.DialContext(ctx, rpcURL, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("new ws client: dial: %w", err)
|
|
}
|
|
|
|
go func() {
|
|
for {
|
|
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
return
|
|
}
|
|
time.Sleep(20 * time.Second)
|
|
}
|
|
}()
|
|
go c.receiveMessages()
|
|
return c, nil
|
|
}
|
|
|
|
func (c *Client) Close() {
|
|
c.conn.Close()
|
|
}
|
|
|
|
func (c *Client) receiveMessages() {
|
|
for {
|
|
_, message, err := c.conn.ReadMessage()
|
|
if err != nil {
|
|
c.closeAllSubscription(err)
|
|
return
|
|
}
|
|
c.handleMessage(message)
|
|
}
|
|
}
|
|
|
|
func (c *Client) handleMessage(message []byte) {
|
|
// when receiving message with id. the result will be a subscription number.
|
|
// that number will be associated to all future message destine to this request
|
|
if gjson.GetBytes(message, "id").Exists() {
|
|
requestID := uint64(gjson.GetBytes(message, "id").Int())
|
|
subID := uint64(gjson.GetBytes(message, "result").Int())
|
|
c.handleNewSubscriptionMessage(requestID, subID)
|
|
return
|
|
}
|
|
|
|
c.handleSubscriptionMessage(uint64(gjson.GetBytes(message, "params.subscription").Int()), message)
|
|
|
|
}
|
|
|
|
func (c *Client) handleNewSubscriptionMessage(requestID, subID uint64) {
|
|
c.lock.Lock()
|
|
defer c.lock.Unlock()
|
|
|
|
if traceEnabled {
|
|
zlog.Debug("received new subscription message",
|
|
zap.Uint64("message_id", requestID),
|
|
zap.Uint64("subscription_id", subID),
|
|
)
|
|
}
|
|
|
|
callBack, found := c.subscriptionByRequestID[requestID]
|
|
if !found {
|
|
zlog.Error("cannot find websocket message handler for a new stream.... this should not happen",
|
|
zap.Uint64("request_id", requestID),
|
|
zap.Uint64("subscription_id", subID),
|
|
)
|
|
return
|
|
}
|
|
callBack.subID = subID
|
|
c.subscriptionByWSSubID[subID] = callBack
|
|
|
|
zlog.Debug("registered ws subscription",
|
|
zap.Uint64("subscription_id", subID),
|
|
zap.Uint64("request_id", requestID),
|
|
zap.Int("subscription_count", len(c.subscriptionByWSSubID)),
|
|
)
|
|
return
|
|
}
|
|
|
|
func (c *Client) handleSubscriptionMessage(subID uint64, message []byte) {
|
|
if traceEnabled {
|
|
zlog.Debug("received subscription message",
|
|
zap.Uint64("subscription_id", subID),
|
|
)
|
|
}
|
|
|
|
c.lock.RLock()
|
|
sub, found := c.subscriptionByWSSubID[subID]
|
|
c.lock.RUnlock()
|
|
if !found {
|
|
zlog.Warn("unable to find subscription for ws message", zap.Uint64("subscription_id", subID))
|
|
return
|
|
}
|
|
|
|
//getting and instantiate the return type for the call back.
|
|
resultType := reflect.New(sub.reflectType)
|
|
result := resultType.Interface()
|
|
err := decodeResponse(bytes.NewReader(message), &result)
|
|
if err != nil {
|
|
fmt.Println("*****************************")
|
|
c.closeSubscription(sub.req.ID, fmt.Errorf("unable to decode client response: %w", err))
|
|
return
|
|
}
|
|
|
|
// this cannot be blocking or else
|
|
// we will no read any other message
|
|
if len(sub.stream) >= cap(sub.stream) {
|
|
zlog.Warn("closing ws client subscription... not consuming fast en ought",
|
|
zap.Uint64("request_id", sub.req.ID),
|
|
)
|
|
c.closeSubscription(sub.req.ID, fmt.Errorf("reached channel max capacity %d", len(sub.stream)))
|
|
return
|
|
}
|
|
|
|
sub.stream <- result
|
|
return
|
|
}
|
|
|
|
func (c *Client) closeAllSubscription(err error) {
|
|
c.lock.Lock()
|
|
defer c.lock.Unlock()
|
|
|
|
for _, sub := range c.subscriptionByRequestID {
|
|
sub.err <- err
|
|
}
|
|
|
|
c.subscriptionByRequestID = map[uint64]*Subscription{}
|
|
c.subscriptionByWSSubID = map[uint64]*Subscription{}
|
|
}
|
|
|
|
func (c *Client) closeSubscription(reqID uint64, err error) {
|
|
c.lock.Lock()
|
|
defer c.lock.Unlock()
|
|
|
|
sub, found := c.subscriptionByRequestID[reqID]
|
|
if !found {
|
|
return
|
|
}
|
|
|
|
sub.err <- err
|
|
|
|
err = c.unsubscribe(sub.subID, sub.unsubscribeMethod)
|
|
if err != nil {
|
|
zlog.Warn("unable to send rpc unsubscribe call",
|
|
zap.Error(err),
|
|
)
|
|
}
|
|
|
|
delete(c.subscriptionByRequestID, sub.req.ID)
|
|
delete(c.subscriptionByWSSubID, sub.subID)
|
|
}
|
|
|
|
func (c *Client) unsubscribe(subID uint64, method string) error {
|
|
req := newRequest([]interface{}{subID}, method, map[string]interface{}{})
|
|
data, err := req.encode()
|
|
if err != nil {
|
|
return fmt.Errorf("unable to encode unsubscription message for subID %d and method %s", subID, method)
|
|
}
|
|
|
|
err = c.conn.WriteMessage(websocket.TextMessage, data)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to send unsubscription message for subID %d and method %s", subID, method)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) subscribe(params []interface{}, conf map[string]interface{}, subscriptionMethod, unsubscribeMethod string, commitment rpc.CommitmentType, resultType interface{}) (*Subscription, error) {
|
|
c.lock.Lock()
|
|
defer c.lock.Unlock()
|
|
if commitment != "" {
|
|
conf["commitment"] = string(commitment)
|
|
}
|
|
|
|
req := newRequest(params, subscriptionMethod, conf)
|
|
data, err := req.encode()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("subscribe: unable to encode subsciption request: %w", err)
|
|
}
|
|
|
|
sub := newSubscription(req, reflect.TypeOf(resultType), func(err error) {
|
|
c.closeSubscription(req.ID, err)
|
|
}, unsubscribeMethod)
|
|
|
|
c.subscriptionByRequestID[req.ID] = sub
|
|
zlog.Info("added new subscription to websocket client", zap.Int("count", len(c.subscriptionByRequestID)))
|
|
|
|
zlog.Debug("writing data to conn", zap.String("data", string(data)))
|
|
err = c.conn.WriteMessage(websocket.TextMessage, data)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to write request: %w", err)
|
|
}
|
|
|
|
return sub, nil
|
|
}
|
|
|
|
func decodeResponse(r io.Reader, reply interface{}) (err error) {
|
|
var c *response
|
|
if err := json.NewDecoder(r).Decode(&c); err != nil {
|
|
return err
|
|
}
|
|
|
|
if c.Error != nil {
|
|
jsonErr := &json2.Error{}
|
|
if err := json.Unmarshal(*c.Error, jsonErr); err != nil {
|
|
return &json2.Error{
|
|
Code: json2.E_SERVER,
|
|
Message: string(*c.Error),
|
|
}
|
|
}
|
|
return jsonErr
|
|
}
|
|
|
|
if c.Params == nil {
|
|
return json2.ErrNullResult
|
|
}
|
|
|
|
return json.Unmarshal(*c.Params.Result, &reply)
|
|
}
|