return error if client already subscribed
This commit is contained in:
parent
cb4ba522ef
commit
e4ef2835f0
|
@ -13,6 +13,8 @@ package pubsub
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
|
||||||
cmn "github.com/tendermint/tmlibs/common"
|
cmn "github.com/tendermint/tmlibs/common"
|
||||||
)
|
)
|
||||||
|
@ -48,6 +50,9 @@ type Server struct {
|
||||||
|
|
||||||
cmds chan cmd
|
cmds chan cmd
|
||||||
cmdsCap int
|
cmdsCap int
|
||||||
|
|
||||||
|
mtx sync.RWMutex
|
||||||
|
subscriptions map[string]map[string]struct{} // subscriber -> query -> struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Option sets a parameter for the server.
|
// Option sets a parameter for the server.
|
||||||
|
@ -57,7 +62,9 @@ type Option func(*Server)
|
||||||
// for a detailed description of how to configure buffering. If no options are
|
// for a detailed description of how to configure buffering. If no options are
|
||||||
// provided, the resulting server's queue is unbuffered.
|
// provided, the resulting server's queue is unbuffered.
|
||||||
func NewServer(options ...Option) *Server {
|
func NewServer(options ...Option) *Server {
|
||||||
s := &Server{}
|
s := &Server{
|
||||||
|
subscriptions: make(map[string]map[string]struct{}),
|
||||||
|
}
|
||||||
s.BaseService = *cmn.NewBaseService(nil, "PubSub", s)
|
s.BaseService = *cmn.NewBaseService(nil, "PubSub", s)
|
||||||
|
|
||||||
for _, option := range options {
|
for _, option := range options {
|
||||||
|
@ -83,17 +90,33 @@ func BufferCapacity(cap int) Option {
|
||||||
}
|
}
|
||||||
|
|
||||||
// BufferCapacity returns capacity of the internal server's queue.
|
// BufferCapacity returns capacity of the internal server's queue.
|
||||||
func (s Server) BufferCapacity() int {
|
func (s *Server) BufferCapacity() int {
|
||||||
return s.cmdsCap
|
return s.cmdsCap
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subscribe creates a subscription for the given client. It accepts a channel
|
// Subscribe creates a subscription for the given client. It accepts a channel
|
||||||
// on which messages matching the given query can be received. If the
|
// on which messages matching the given query can be received. An error will be
|
||||||
// subscription already exists, the old channel will be closed. An error will
|
// returned to the caller if the context is canceled or if subscription already
|
||||||
// be returned to the caller if the context is canceled.
|
// exist for pair clientID and query.
|
||||||
func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, out chan<- interface{}) error {
|
func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, out chan<- interface{}) error {
|
||||||
|
s.mtx.RLock()
|
||||||
|
clientSubscriptions, ok := s.subscriptions[clientID]
|
||||||
|
if ok {
|
||||||
|
_, ok = clientSubscriptions[query.String()]
|
||||||
|
}
|
||||||
|
s.mtx.RUnlock()
|
||||||
|
if ok {
|
||||||
|
return errors.New("already subscribed")
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case s.cmds <- cmd{op: sub, clientID: clientID, query: query, ch: out}:
|
case s.cmds <- cmd{op: sub, clientID: clientID, query: query, ch: out}:
|
||||||
|
s.mtx.Lock()
|
||||||
|
if _, ok = s.subscriptions[clientID]; !ok {
|
||||||
|
s.subscriptions[clientID] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
s.subscriptions[clientID][query.String()] = struct{}{}
|
||||||
|
s.mtx.Unlock()
|
||||||
return nil
|
return nil
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
|
@ -101,10 +124,24 @@ func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, ou
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unsubscribe removes the subscription on the given query. An error will be
|
// Unsubscribe removes the subscription on the given query. An error will be
|
||||||
// returned to the caller if the context is canceled.
|
// returned to the caller if the context is canceled or if subscription does
|
||||||
|
// not exist.
|
||||||
func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query) error {
|
func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query) error {
|
||||||
|
s.mtx.RLock()
|
||||||
|
clientSubscriptions, ok := s.subscriptions[clientID]
|
||||||
|
if ok {
|
||||||
|
_, ok = clientSubscriptions[query.String()]
|
||||||
|
}
|
||||||
|
s.mtx.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return errors.New("subscription not found")
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case s.cmds <- cmd{op: unsub, clientID: clientID, query: query}:
|
case s.cmds <- cmd{op: unsub, clientID: clientID, query: query}:
|
||||||
|
s.mtx.Lock()
|
||||||
|
delete(clientSubscriptions, query.String())
|
||||||
|
s.mtx.Unlock()
|
||||||
return nil
|
return nil
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
|
@ -112,10 +149,20 @@ func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnsubscribeAll removes all client subscriptions. An error will be returned
|
// UnsubscribeAll removes all client subscriptions. An error will be returned
|
||||||
// to the caller if the context is canceled.
|
// to the caller if the context is canceled or if subscription does not exist.
|
||||||
func (s *Server) UnsubscribeAll(ctx context.Context, clientID string) error {
|
func (s *Server) UnsubscribeAll(ctx context.Context, clientID string) error {
|
||||||
|
s.mtx.RLock()
|
||||||
|
_, ok := s.subscriptions[clientID]
|
||||||
|
s.mtx.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return errors.New("subscription not found")
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case s.cmds <- cmd{op: unsub, clientID: clientID}:
|
case s.cmds <- cmd{op: unsub, clientID: clientID}:
|
||||||
|
s.mtx.Lock()
|
||||||
|
delete(s.subscriptions, clientID)
|
||||||
|
s.mtx.Unlock()
|
||||||
return nil
|
return nil
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
|
@ -187,13 +234,8 @@ loop:
|
||||||
|
|
||||||
func (state *state) add(clientID string, q Query, ch chan<- interface{}) {
|
func (state *state) add(clientID string, q Query, ch chan<- interface{}) {
|
||||||
// add query if needed
|
// add query if needed
|
||||||
if clientToChannelMap, ok := state.queries[q]; !ok {
|
if _, ok := state.queries[q]; !ok {
|
||||||
state.queries[q] = make(map[string]chan<- interface{})
|
state.queries[q] = make(map[string]chan<- interface{})
|
||||||
} else {
|
|
||||||
// check if already subscribed
|
|
||||||
if oldCh, ok := clientToChannelMap[clientID]; ok {
|
|
||||||
close(oldCh)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// create subscription
|
// create subscription
|
||||||
|
|
|
@ -86,14 +86,11 @@ func TestClientSubscribesTwice(t *testing.T) {
|
||||||
|
|
||||||
ch2 := make(chan interface{}, 1)
|
ch2 := make(chan interface{}, 1)
|
||||||
err = s.Subscribe(ctx, clientID, q, ch2)
|
err = s.Subscribe(ctx, clientID, q, ch2)
|
||||||
require.NoError(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
_, ok := <-ch1
|
|
||||||
assert.False(t, ok)
|
|
||||||
|
|
||||||
err = s.PublishWithTags(ctx, "Spider-Man", map[string]interface{}{"tm.events.type": "NewBlock"})
|
err = s.PublishWithTags(ctx, "Spider-Man", map[string]interface{}{"tm.events.type": "NewBlock"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assertReceive(t, "Spider-Man", ch2)
|
assertReceive(t, "Spider-Man", ch1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnsubscribe(t *testing.T) {
|
func TestUnsubscribe(t *testing.T) {
|
||||||
|
@ -117,6 +114,27 @@ func TestUnsubscribe(t *testing.T) {
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestResubscribe(t *testing.T) {
|
||||||
|
s := pubsub.NewServer()
|
||||||
|
s.SetLogger(log.TestingLogger())
|
||||||
|
s.Start()
|
||||||
|
defer s.Stop()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
ch := make(chan interface{})
|
||||||
|
err := s.Subscribe(ctx, clientID, query.Empty{}, ch)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = s.Unsubscribe(ctx, clientID, query.Empty{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
ch = make(chan interface{})
|
||||||
|
err = s.Subscribe(ctx, clientID, query.Empty{}, ch)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = s.Publish(ctx, "Cable")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assertReceive(t, "Cable", ch)
|
||||||
|
}
|
||||||
|
|
||||||
func TestUnsubscribeAll(t *testing.T) {
|
func TestUnsubscribeAll(t *testing.T) {
|
||||||
s := pubsub.NewServer()
|
s := pubsub.NewServer()
|
||||||
s.SetLogger(log.TestingLogger())
|
s.SetLogger(log.TestingLogger())
|
||||||
|
@ -125,9 +143,9 @@ func TestUnsubscribeAll(t *testing.T) {
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ch1, ch2 := make(chan interface{}, 1), make(chan interface{}, 1)
|
ch1, ch2 := make(chan interface{}, 1), make(chan interface{}, 1)
|
||||||
err := s.Subscribe(ctx, clientID, query.Empty{}, ch1)
|
err := s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'"), ch1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
err = s.Subscribe(ctx, clientID, query.Empty{}, ch2)
|
err = s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlockHeader'"), ch2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = s.UnsubscribeAll(ctx, clientID)
|
err = s.UnsubscribeAll(ctx, clientID)
|
||||||
|
|
Loading…
Reference in New Issue