zcash: replace mutex tangle with too many sync.Maps

This commit is contained in:
George Tankersley 2019-10-12 19:18:25 -04:00
parent 352b865775
commit c6c3f2ca53
2 changed files with 107 additions and 101 deletions

View File

@ -15,6 +15,10 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
var (
ErrRepeatConnection = errors.New("attempted repeat connection to existing peer")
)
var defaultPeerConfig = &peer.Config{ var defaultPeerConfig = &peer.Config{
UserAgentName: "MagicBean", UserAgentName: "MagicBean",
UserAgentVersion: "2.0.7", UserAgentVersion: "2.0.7",
@ -29,9 +33,9 @@ type Seeder struct {
config *peer.Config config *peer.Config
logger *log.Logger logger *log.Logger
handshakeSignals map[string]chan *peer.Peer handshakeSignals *sync.Map
pendingPeers map[string]*peer.Peer pendingPeers *sync.Map
livePeers map[string]*peer.Peer livePeers *sync.Map
// For mutating the above // For mutating the above
peerState sync.RWMutex peerState sync.RWMutex
@ -48,9 +52,9 @@ func NewSeeder(network network.Network) (*Seeder, error) {
newSeeder := Seeder{ newSeeder := Seeder{
config: config, config: config,
logger: logger, logger: logger,
handshakeSignals: make(map[string]chan *peer.Peer), handshakeSignals: new(sync.Map),
pendingPeers: make(map[string]*peer.Peer), pendingPeers: new(sync.Map),
livePeers: make(map[string]*peer.Peer), livePeers: new(sync.Map),
} }
newSeeder.config.Listeners.OnVerAck = newSeeder.onVerAck newSeeder.config.Listeners.OnVerAck = newSeeder.onVerAck
@ -73,9 +77,9 @@ func newTestSeeder(network network.Network) (*Seeder, error) {
newSeeder := Seeder{ newSeeder := Seeder{
config: config, config: config,
logger: logger, logger: logger,
handshakeSignals: make(map[string]chan *peer.Peer), handshakeSignals: new(sync.Map),
pendingPeers: make(map[string]*peer.Peer), pendingPeers: new(sync.Map),
livePeers: make(map[string]*peer.Peer), livePeers: new(sync.Map),
} }
newSeeder.config.Listeners.OnVerAck = newSeeder.onVerAck newSeeder.config.Listeners.OnVerAck = newSeeder.onVerAck
@ -101,33 +105,27 @@ func newSeederPeerConfig(magic network.Network, template *peer.Config) (*peer.Co
} }
func (s *Seeder) onVerAck(p *peer.Peer, msg *wire.MsgVerAck) { func (s *Seeder) onVerAck(p *peer.Peer, msg *wire.MsgVerAck) {
// lock peers for read // Check if we're expecting to hear from this peer
s.peerState.RLock() _, ok := s.pendingPeers.Load(p.Addr())
_, expectingPeer := s.pendingPeers[p.Addr()]
s.peerState.RUnlock()
if !expectingPeer { if !ok {
s.logger.Printf("Got verack from unexpected peer %s", p.Addr()) s.logger.Printf("Got verack from unexpected peer %s", p.Addr())
return return
} }
s.peerState.Lock()
{
// Add to set of live peers // Add to set of live peers
s.livePeers[p.Addr()] = p s.livePeers.Store(p.Addr(), p)
// Remove from set of pending peers // Remove from set of pending peers
delete(s.pendingPeers, p.Addr()) s.pendingPeers.Delete(p.Addr())
}
s.peerState.Unlock()
s.peerState.RLock()
{
// Signal successful connection // Signal successful connection
s.handshakeSignals[p.Addr()] <- p if signal, ok := s.handshakeSignals.Load(p.Addr()); ok {
signal.(chan struct{}) <- struct{}{}
return
} }
s.peerState.RUnlock()
s.logger.Printf("Got verack from peer without a callback channel: %s", p.Addr())
} }
// ConnectToPeer attempts to connect to a peer on the default port at the // ConnectToPeer attempts to connect to a peer on the default port at the
@ -140,82 +138,75 @@ func (s *Seeder) ConnectToPeer(addr string) error {
return errors.Wrap(err, "constructing outbound peer") return errors.Wrap(err, "constructing outbound peer")
} }
_, alreadyPending := s.pendingPeers.LoadOrStore(p.Addr(), p)
_, alreadyHandshaking := s.handshakeSignals.LoadOrStore(p.Addr(), make(chan struct{}, 1))
_, alreadyLive := s.livePeers.Load(p.Addr())
if alreadyPending || alreadyHandshaking || alreadyLive {
s.logger.Printf("Attempted repeat connection to peer %s", p.Addr())
return ErrRepeatConnection
}
conn, err := net.Dial("tcp", p.Addr()) conn, err := net.Dial("tcp", p.Addr())
if err != nil { if err != nil {
return errors.Wrap(err, "dialing new peer address") return errors.Wrap(err, "dialing new peer address")
} }
s.peerState.Lock()
{
// Record that we're expecting a verack from this peer.
s.pendingPeers[p.Addr()] = p
// Make a channel for us to wait on.
s.handshakeSignals[p.Addr()] = make(chan *peer.Peer, 1)
}
s.peerState.Unlock()
// Begin connection negotiation. // Begin connection negotiation.
s.logger.Printf("Handshake initated with new peer %s", p.Addr()) s.logger.Printf("Handshake initated with new peer %s", p.Addr())
p.AssociateConnection(conn) p.AssociateConnection(conn)
for { handshakeChan, _ := s.handshakeSignals.Load(p.Addr())
// lock signals map for select
s.peerState.RLock()
handshakeChan := s.handshakeSignals[p.Addr()]
s.peerState.RUnlock()
select { select {
case verackPeer := <-handshakeChan: case <-handshakeChan.(chan struct{}):
s.peerState.Lock() s.logger.Printf("Handshake completed with new peer %s", p.Addr())
{ s.handshakeSignals.Delete(p.Addr())
close(s.handshakeSignals[p.Addr()])
delete(s.handshakeSignals, p.Addr())
}
s.peerState.Unlock()
s.logger.Printf("Handshake completed with new peer %s", verackPeer.Addr())
return nil return nil
case <-time.After(time.Second * 1): case <-time.After(time.Second * 1):
return errors.New("peer handshake timed out") return errors.New("peer handshake timed out")
} }
}
panic("This should be unreachable") panic("This should be unreachable")
} }
func (s *Seeder) GetPeer(addr string) (*peer.Peer, error) { func (s *Seeder) GetPeer(addr string) (*peer.Peer, error) {
lookupKey := net.JoinHostPort(addr, s.config.ChainParams.DefaultPort) lookupKey := net.JoinHostPort(addr, s.config.ChainParams.DefaultPort)
s.peerState.RLock() p, ok := s.livePeers.Load(lookupKey)
p, ok := s.livePeers[lookupKey]
s.peerState.RUnlock()
if !ok { if ok {
return nil, errors.New("no such active peer") return p.(*peer.Peer), nil
} }
return p, nil return nil, errors.New("no such active peer")
}
func (s *Seeder) WaitForPeers() {
panic("not yet implemented")
} }
func (s *Seeder) DisconnectAllPeers() { func (s *Seeder) DisconnectAllPeers() {
s.peerState.Lock() s.pendingPeers.Range(func(key, value interface{}) bool {
{ p, ok := value.(*peer.Peer)
for _, v := range s.pendingPeers { if !ok {
s.logger.Printf("Disconnecting from peer %s", v.Addr()) s.logger.Printf("Invalid peer in pendingPeers")
v.Disconnect() return false
v.WaitForDisconnect()
} }
s.pendingPeers = make(map[string]*peer.Peer) s.logger.Printf("Disconnecting from pending peer %s", p.Addr())
p.Disconnect()
p.WaitForDisconnect()
s.pendingPeers.Delete(key)
return true
})
for _, v := range s.livePeers { s.livePeers.Range(func(key, value interface{}) bool {
s.logger.Printf("Disconnecting from peer %s", v.Addr()) p, ok := value.(*peer.Peer)
v.Disconnect() if !ok {
v.WaitForDisconnect() s.logger.Printf("Invalid peer in livePeers")
return false
} }
s.livePeers = make(map[string]*peer.Peer) s.logger.Printf("Disconnecting from live peer %s", p.Addr())
} p.Disconnect()
s.peerState.Unlock() p.WaitForDisconnect()
s.livePeers.Delete(key)
return true
})
} }
func (s *Seeder) RequestAddresses() {}

View File

@ -3,12 +3,10 @@ package zcash
import ( import (
"context" "context"
"net" "net"
"os" "sync"
"testing" "testing"
"time"
"github.com/btcsuite/btcd/peer" "github.com/btcsuite/btcd/peer"
"github.com/btcsuite/btclog"
"github.com/gtank/coredns-zcash/zcash/network" "github.com/gtank/coredns-zcash/zcash/network"
) )
@ -21,10 +19,10 @@ func mockLocalPeer(ctx context.Context) error {
config.AllowSelfConns = true config.AllowSelfConns = true
backendLogger := btclog.NewBackend(os.Stdout) // backendLogger := btclog.NewBackend(os.Stdout)
mockPeerLogger := backendLogger.Logger("mockPeer") // mockPeerLogger := backendLogger.Logger("mockPeer")
//mockPeerLogger.SetLevel(btclog.LevelTrace) // mockPeerLogger.SetLevel(btclog.LevelTrace)
peer.UseLogger(mockPeerLogger) // peer.UseLogger(mockPeerLogger)
mockPeer := peer.NewInboundPeer(config) mockPeer := peer.NewInboundPeer(config)
@ -103,19 +101,36 @@ func TestOutboundPeerAsync(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
done := make(chan struct{}) var wg sync.WaitGroup
for i := 0; i < 4; i++ {
wg.Add(1)
go func() { go func() {
err := regSeeder.ConnectToPeer("127.0.0.1") err := regSeeder.ConnectToPeer("127.0.0.1")
if err != nil { if err != nil && err != ErrRepeatConnection {
t.Fatal(err) t.Error(err)
} }
regSeeder.DisconnectAllPeers() wg.Done()
done <- struct{}{}
}() }()
select {
case <-done:
case <-time.After(time.Second * 1):
t.Error("timed out")
} }
wg.Wait()
// Can we address that peer if we want to?
p, err := regSeeder.GetPeer("127.0.0.1")
if err != nil {
t.Error(err)
return
}
if p.Connected() {
// Shouldn't try to connect to a live peer again
err := regSeeder.ConnectToPeer("127.0.0.1")
if err != ErrRepeatConnection {
t.Error("should have caught repeat connection attempt")
}
} else {
t.Error("Peer never connected")
}
regSeeder.DisconnectAllPeers()
} }