zcash: replace mutex tangle with too many sync.Maps
This commit is contained in:
parent
352b865775
commit
c6c3f2ca53
151
zcash/client.go
151
zcash/client.go
|
@ -15,6 +15,10 @@ import (
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrRepeatConnection = errors.New("attempted repeat connection to existing peer")
|
||||||
|
)
|
||||||
|
|
||||||
var defaultPeerConfig = &peer.Config{
|
var defaultPeerConfig = &peer.Config{
|
||||||
UserAgentName: "MagicBean",
|
UserAgentName: "MagicBean",
|
||||||
UserAgentVersion: "2.0.7",
|
UserAgentVersion: "2.0.7",
|
||||||
|
@ -29,9 +33,9 @@ type Seeder struct {
|
||||||
config *peer.Config
|
config *peer.Config
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
|
|
||||||
handshakeSignals map[string]chan *peer.Peer
|
handshakeSignals *sync.Map
|
||||||
pendingPeers map[string]*peer.Peer
|
pendingPeers *sync.Map
|
||||||
livePeers map[string]*peer.Peer
|
livePeers *sync.Map
|
||||||
|
|
||||||
// For mutating the above
|
// For mutating the above
|
||||||
peerState sync.RWMutex
|
peerState sync.RWMutex
|
||||||
|
@ -48,9 +52,9 @@ func NewSeeder(network network.Network) (*Seeder, error) {
|
||||||
newSeeder := Seeder{
|
newSeeder := Seeder{
|
||||||
config: config,
|
config: config,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
handshakeSignals: make(map[string]chan *peer.Peer),
|
handshakeSignals: new(sync.Map),
|
||||||
pendingPeers: make(map[string]*peer.Peer),
|
pendingPeers: new(sync.Map),
|
||||||
livePeers: make(map[string]*peer.Peer),
|
livePeers: new(sync.Map),
|
||||||
}
|
}
|
||||||
|
|
||||||
newSeeder.config.Listeners.OnVerAck = newSeeder.onVerAck
|
newSeeder.config.Listeners.OnVerAck = newSeeder.onVerAck
|
||||||
|
@ -73,9 +77,9 @@ func newTestSeeder(network network.Network) (*Seeder, error) {
|
||||||
newSeeder := Seeder{
|
newSeeder := Seeder{
|
||||||
config: config,
|
config: config,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
handshakeSignals: make(map[string]chan *peer.Peer),
|
handshakeSignals: new(sync.Map),
|
||||||
pendingPeers: make(map[string]*peer.Peer),
|
pendingPeers: new(sync.Map),
|
||||||
livePeers: make(map[string]*peer.Peer),
|
livePeers: new(sync.Map),
|
||||||
}
|
}
|
||||||
|
|
||||||
newSeeder.config.Listeners.OnVerAck = newSeeder.onVerAck
|
newSeeder.config.Listeners.OnVerAck = newSeeder.onVerAck
|
||||||
|
@ -101,33 +105,27 @@ func newSeederPeerConfig(magic network.Network, template *peer.Config) (*peer.Co
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Seeder) onVerAck(p *peer.Peer, msg *wire.MsgVerAck) {
|
func (s *Seeder) onVerAck(p *peer.Peer, msg *wire.MsgVerAck) {
|
||||||
// lock peers for read
|
// Check if we're expecting to hear from this peer
|
||||||
s.peerState.RLock()
|
_, ok := s.pendingPeers.Load(p.Addr())
|
||||||
_, expectingPeer := s.pendingPeers[p.Addr()]
|
|
||||||
s.peerState.RUnlock()
|
|
||||||
|
|
||||||
if !expectingPeer {
|
if !ok {
|
||||||
s.logger.Printf("Got verack from unexpected peer %s", p.Addr())
|
s.logger.Printf("Got verack from unexpected peer %s", p.Addr())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.peerState.Lock()
|
// Add to set of live peers
|
||||||
{
|
s.livePeers.Store(p.Addr(), p)
|
||||||
// Add to set of live peers
|
|
||||||
s.livePeers[p.Addr()] = p
|
|
||||||
|
|
||||||
// Remove from set of pending peers
|
// Remove from set of pending peers
|
||||||
delete(s.pendingPeers, p.Addr())
|
s.pendingPeers.Delete(p.Addr())
|
||||||
|
|
||||||
|
// Signal successful connection
|
||||||
|
if signal, ok := s.handshakeSignals.Load(p.Addr()); ok {
|
||||||
|
signal.(chan struct{}) <- struct{}{}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
s.peerState.Unlock()
|
|
||||||
|
|
||||||
s.peerState.RLock()
|
|
||||||
{
|
|
||||||
// Signal successful connection
|
|
||||||
s.handshakeSignals[p.Addr()] <- p
|
|
||||||
}
|
|
||||||
s.peerState.RUnlock()
|
|
||||||
|
|
||||||
|
s.logger.Printf("Got verack from peer without a callback channel: %s", p.Addr())
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConnectToPeer attempts to connect to a peer on the default port at the
|
// ConnectToPeer attempts to connect to a peer on the default port at the
|
||||||
|
@ -140,44 +138,33 @@ func (s *Seeder) ConnectToPeer(addr string) error {
|
||||||
return errors.Wrap(err, "constructing outbound peer")
|
return errors.Wrap(err, "constructing outbound peer")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, alreadyPending := s.pendingPeers.LoadOrStore(p.Addr(), p)
|
||||||
|
_, alreadyHandshaking := s.handshakeSignals.LoadOrStore(p.Addr(), make(chan struct{}, 1))
|
||||||
|
_, alreadyLive := s.livePeers.Load(p.Addr())
|
||||||
|
|
||||||
|
if alreadyPending || alreadyHandshaking || alreadyLive {
|
||||||
|
s.logger.Printf("Attempted repeat connection to peer %s", p.Addr())
|
||||||
|
return ErrRepeatConnection
|
||||||
|
}
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", p.Addr())
|
conn, err := net.Dial("tcp", p.Addr())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "dialing new peer address")
|
return errors.Wrap(err, "dialing new peer address")
|
||||||
}
|
}
|
||||||
|
|
||||||
s.peerState.Lock()
|
|
||||||
{
|
|
||||||
// Record that we're expecting a verack from this peer.
|
|
||||||
s.pendingPeers[p.Addr()] = p
|
|
||||||
|
|
||||||
// Make a channel for us to wait on.
|
|
||||||
s.handshakeSignals[p.Addr()] = make(chan *peer.Peer, 1)
|
|
||||||
}
|
|
||||||
s.peerState.Unlock()
|
|
||||||
|
|
||||||
// Begin connection negotiation.
|
// Begin connection negotiation.
|
||||||
s.logger.Printf("Handshake initated with new peer %s", p.Addr())
|
s.logger.Printf("Handshake initated with new peer %s", p.Addr())
|
||||||
p.AssociateConnection(conn)
|
p.AssociateConnection(conn)
|
||||||
|
|
||||||
for {
|
handshakeChan, _ := s.handshakeSignals.Load(p.Addr())
|
||||||
// lock signals map for select
|
|
||||||
s.peerState.RLock()
|
|
||||||
handshakeChan := s.handshakeSignals[p.Addr()]
|
|
||||||
s.peerState.RUnlock()
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case verackPeer := <-handshakeChan:
|
case <-handshakeChan.(chan struct{}):
|
||||||
s.peerState.Lock()
|
s.logger.Printf("Handshake completed with new peer %s", p.Addr())
|
||||||
{
|
s.handshakeSignals.Delete(p.Addr())
|
||||||
close(s.handshakeSignals[p.Addr()])
|
return nil
|
||||||
delete(s.handshakeSignals, p.Addr())
|
case <-time.After(time.Second * 1):
|
||||||
}
|
return errors.New("peer handshake timed out")
|
||||||
s.peerState.Unlock()
|
|
||||||
s.logger.Printf("Handshake completed with new peer %s", verackPeer.Addr())
|
|
||||||
return nil
|
|
||||||
case <-time.After(time.Second * 1):
|
|
||||||
return errors.New("peer handshake timed out")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
panic("This should be unreachable")
|
panic("This should be unreachable")
|
||||||
|
@ -185,37 +172,41 @@ func (s *Seeder) ConnectToPeer(addr string) error {
|
||||||
|
|
||||||
func (s *Seeder) GetPeer(addr string) (*peer.Peer, error) {
|
func (s *Seeder) GetPeer(addr string) (*peer.Peer, error) {
|
||||||
lookupKey := net.JoinHostPort(addr, s.config.ChainParams.DefaultPort)
|
lookupKey := net.JoinHostPort(addr, s.config.ChainParams.DefaultPort)
|
||||||
s.peerState.RLock()
|
p, ok := s.livePeers.Load(lookupKey)
|
||||||
p, ok := s.livePeers[lookupKey]
|
|
||||||
s.peerState.RUnlock()
|
|
||||||
|
|
||||||
if !ok {
|
if ok {
|
||||||
return nil, errors.New("no such active peer")
|
return p.(*peer.Peer), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return p, nil
|
return nil, errors.New("no such active peer")
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Seeder) WaitForPeers() {
|
|
||||||
panic("not yet implemented")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Seeder) DisconnectAllPeers() {
|
func (s *Seeder) DisconnectAllPeers() {
|
||||||
s.peerState.Lock()
|
s.pendingPeers.Range(func(key, value interface{}) bool {
|
||||||
{
|
p, ok := value.(*peer.Peer)
|
||||||
for _, v := range s.pendingPeers {
|
if !ok {
|
||||||
s.logger.Printf("Disconnecting from peer %s", v.Addr())
|
s.logger.Printf("Invalid peer in pendingPeers")
|
||||||
v.Disconnect()
|
return false
|
||||||
v.WaitForDisconnect()
|
|
||||||
}
|
}
|
||||||
s.pendingPeers = make(map[string]*peer.Peer)
|
s.logger.Printf("Disconnecting from pending peer %s", p.Addr())
|
||||||
|
p.Disconnect()
|
||||||
|
p.WaitForDisconnect()
|
||||||
|
s.pendingPeers.Delete(key)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
for _, v := range s.livePeers {
|
s.livePeers.Range(func(key, value interface{}) bool {
|
||||||
s.logger.Printf("Disconnecting from peer %s", v.Addr())
|
p, ok := value.(*peer.Peer)
|
||||||
v.Disconnect()
|
if !ok {
|
||||||
v.WaitForDisconnect()
|
s.logger.Printf("Invalid peer in livePeers")
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
s.livePeers = make(map[string]*peer.Peer)
|
s.logger.Printf("Disconnecting from live peer %s", p.Addr())
|
||||||
}
|
p.Disconnect()
|
||||||
s.peerState.Unlock()
|
p.WaitForDisconnect()
|
||||||
|
s.livePeers.Delete(key)
|
||||||
|
return true
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Seeder) RequestAddresses() {}
|
||||||
|
|
|
@ -3,12 +3,10 @@ package zcash
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/peer"
|
"github.com/btcsuite/btcd/peer"
|
||||||
"github.com/btcsuite/btclog"
|
|
||||||
"github.com/gtank/coredns-zcash/zcash/network"
|
"github.com/gtank/coredns-zcash/zcash/network"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -21,10 +19,10 @@ func mockLocalPeer(ctx context.Context) error {
|
||||||
|
|
||||||
config.AllowSelfConns = true
|
config.AllowSelfConns = true
|
||||||
|
|
||||||
backendLogger := btclog.NewBackend(os.Stdout)
|
// backendLogger := btclog.NewBackend(os.Stdout)
|
||||||
mockPeerLogger := backendLogger.Logger("mockPeer")
|
// mockPeerLogger := backendLogger.Logger("mockPeer")
|
||||||
//mockPeerLogger.SetLevel(btclog.LevelTrace)
|
// mockPeerLogger.SetLevel(btclog.LevelTrace)
|
||||||
peer.UseLogger(mockPeerLogger)
|
// peer.UseLogger(mockPeerLogger)
|
||||||
|
|
||||||
mockPeer := peer.NewInboundPeer(config)
|
mockPeer := peer.NewInboundPeer(config)
|
||||||
|
|
||||||
|
@ -103,19 +101,36 @@ func TestOutboundPeerAsync(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
done := make(chan struct{})
|
var wg sync.WaitGroup
|
||||||
go func() {
|
for i := 0; i < 4; i++ {
|
||||||
err := regSeeder.ConnectToPeer("127.0.0.1")
|
wg.Add(1)
|
||||||
if err != nil {
|
go func() {
|
||||||
t.Fatal(err)
|
err := regSeeder.ConnectToPeer("127.0.0.1")
|
||||||
}
|
if err != nil && err != ErrRepeatConnection {
|
||||||
regSeeder.DisconnectAllPeers()
|
t.Error(err)
|
||||||
done <- struct{}{}
|
}
|
||||||
}()
|
wg.Done()
|
||||||
|
}()
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case <-time.After(time.Second * 1):
|
|
||||||
t.Error("timed out")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Can we address that peer if we want to?
|
||||||
|
p, err := regSeeder.GetPeer("127.0.0.1")
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.Connected() {
|
||||||
|
// Shouldn't try to connect to a live peer again
|
||||||
|
err := regSeeder.ConnectToPeer("127.0.0.1")
|
||||||
|
if err != ErrRepeatConnection {
|
||||||
|
t.Error("should have caught repeat connection attempt")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Error("Peer never connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
regSeeder.DisconnectAllPeers()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue