zcash: implement address requests and improve test reliability

This commit is contained in:
George Tankersley 2019-10-12 22:00:20 -04:00
parent c6c3f2ca53
commit 7716305c89
2 changed files with 120 additions and 46 deletions

View File

@ -17,6 +17,7 @@ import (
var ( var (
ErrRepeatConnection = errors.New("attempted repeat connection to existing peer") ErrRepeatConnection = errors.New("attempted repeat connection to existing peer")
ErrNoSuchPeer = errors.New("no record of requested peer")
) )
var defaultPeerConfig = &peer.Config{ var defaultPeerConfig = &peer.Config{
@ -37,6 +38,8 @@ type Seeder struct {
pendingPeers *sync.Map pendingPeers *sync.Map
livePeers *sync.Map livePeers *sync.Map
addrRecvChan chan *wire.NetAddress
// For mutating the above // For mutating the above
peerState sync.RWMutex peerState sync.RWMutex
} }
@ -55,9 +58,11 @@ func NewSeeder(network network.Network) (*Seeder, error) {
handshakeSignals: new(sync.Map), handshakeSignals: new(sync.Map),
pendingPeers: new(sync.Map), pendingPeers: new(sync.Map),
livePeers: new(sync.Map), livePeers: new(sync.Map),
addrRecvChan: make(chan *wire.NetAddress, 100),
} }
newSeeder.config.Listeners.OnVerAck = newSeeder.onVerAck newSeeder.config.Listeners.OnVerAck = newSeeder.onVerAck
newSeeder.config.Listeners.OnAddr = newSeeder.onAddr
return &newSeeder, nil return &newSeeder, nil
} }
@ -80,9 +85,11 @@ func newTestSeeder(network network.Network) (*Seeder, error) {
handshakeSignals: new(sync.Map), handshakeSignals: new(sync.Map),
pendingPeers: new(sync.Map), pendingPeers: new(sync.Map),
livePeers: new(sync.Map), livePeers: new(sync.Map),
addrRecvChan: make(chan *wire.NetAddress, 100),
} }
newSeeder.config.Listeners.OnVerAck = newSeeder.onVerAck newSeeder.config.Listeners.OnVerAck = newSeeder.onVerAck
newSeeder.config.Listeners.OnAddr = newSeeder.onAddr
return &newSeeder, nil return &newSeeder, nil
} }
@ -156,6 +163,7 @@ func (s *Seeder) ConnectToPeer(addr string) error {
s.logger.Printf("Handshake initated with new peer %s", p.Addr()) s.logger.Printf("Handshake initated with new peer %s", p.Addr())
p.AssociateConnection(conn) p.AssociateConnection(conn)
// TODO: handle disconnect during this
handshakeChan, _ := s.handshakeSignals.Load(p.Addr()) handshakeChan, _ := s.handshakeSignals.Load(p.Addr())
select { select {
@ -178,7 +186,25 @@ func (s *Seeder) GetPeer(addr string) (*peer.Peer, error) {
return p.(*peer.Peer), nil return p.(*peer.Peer), nil
} }
return nil, errors.New("no such active peer") return nil, ErrNoSuchPeer
}
func (s *Seeder) DisconnectPeer(addr string) error {
lookupKey := net.JoinHostPort(addr, s.config.ChainParams.DefaultPort)
p, ok := s.livePeers.Load(lookupKey)
if !ok {
return ErrNoSuchPeer
}
// TODO: type safety and error handling
v := p.(*peer.Peer)
v.Disconnect()
v.WaitForDisconnect()
s.livePeers.Delete(lookupKey)
return nil
} }
func (s *Seeder) DisconnectAllPeers() { func (s *Seeder) DisconnectAllPeers() {
@ -202,11 +228,37 @@ func (s *Seeder) DisconnectAllPeers() {
return false return false
} }
s.logger.Printf("Disconnecting from live peer %s", p.Addr()) s.logger.Printf("Disconnecting from live peer %s", p.Addr())
p.Disconnect() s.DisconnectPeer(p.Addr())
p.WaitForDisconnect()
s.livePeers.Delete(key)
return true return true
}) })
} }
func (s *Seeder) RequestAddresses() {} func (s *Seeder) RequestAddresses() {
s.livePeers.Range(func(key, value interface{}) bool {
p, ok := value.(*peer.Peer)
if !ok {
s.logger.Printf("Invalid peer in livePeers")
return false
}
s.logger.Printf("Requesting addresses from peer %s", p.Addr())
p.QueueMessage(wire.NewMsgGetAddr(), nil)
return true
})
}
func (s *Seeder) WaitForMoreAddresses() {
<-s.addrRecvChan
}
func (s *Seeder) onAddr(p *peer.Peer, msg *wire.MsgAddr) {
if len(msg.AddrList) == 0 {
s.logger.Printf("Got empty addr message from peer %s. Disconnecting.", p.Addr())
s.DisconnectPeer(p.Addr())
return
}
s.logger.Printf("Got %d addrs from peer %s", len(msg.AddrList), p.Addr())
for _, addr := range msg.AddrList {
s.addrRecvChan <- addr
}
}

View File

@ -1,101 +1,106 @@
package zcash package zcash
import ( import (
"context"
"net" "net"
"os"
"sync" "sync"
"testing" "testing"
"time"
"github.com/btcsuite/btcd/peer" "github.com/btcsuite/btcd/peer"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btclog"
"github.com/gtank/coredns-zcash/zcash/network" "github.com/gtank/coredns-zcash/zcash/network"
) )
func mockLocalPeer(ctx context.Context) error { func TestMain(m *testing.M) {
startMockLoop()
exitCode := m.Run()
os.Exit(exitCode)
}
func startMockLoop() {
// Configure peer to act as a regtest node that offers no services. // Configure peer to act as a regtest node that offers no services.
config, err := newSeederPeerConfig(network.Regtest, defaultPeerConfig) config, err := newSeederPeerConfig(network.Regtest, defaultPeerConfig)
if err != nil { if err != nil {
return err return
} }
config.AllowSelfConns = true config.AllowSelfConns = true
// backendLogger := btclog.NewBackend(os.Stdout) backendLogger := btclog.NewBackend(os.Stdout)
// mockPeerLogger := backendLogger.Logger("mockPeer") mockPeerLogger := backendLogger.Logger("mockPeer")
// mockPeerLogger.SetLevel(btclog.LevelTrace) //mockPeerLogger.SetLevel(btclog.LevelTrace)
// peer.UseLogger(mockPeerLogger) peer.UseLogger(mockPeerLogger)
mockPeer := peer.NewInboundPeer(config) config.Listeners.OnGetAddr = func(p *peer.Peer, msg *wire.MsgGetAddr) {
cache := make([]*wire.NetAddress, 0, 1)
addr := wire.NewNetAddressTimestamp(
time.Now(),
0,
net.ParseIP("127.0.0.1"),
uint16(8233),
)
cache = append(cache, addr)
_, err := p.PushAddrMsg(cache)
if err != nil {
mockPeerLogger.Error(err)
}
}
listenAddr := net.JoinHostPort("127.0.0.1", config.ChainParams.DefaultPort) listenAddr := net.JoinHostPort("127.0.0.1", config.ChainParams.DefaultPort)
listener, err := net.Listen("tcp", listenAddr) listener, err := net.Listen("tcp", listenAddr)
if err != nil { if err != nil {
return err return
} }
go func() { go func() {
conn, err := listener.Accept() for {
if err != nil { conn, err := listener.Accept()
return if err != nil {
} return
}
mockPeer.AssociateConnection(conn) mockPeer := peer.NewInboundPeer(config)
mockPeer.AssociateConnection(conn)
select {
case <-ctx.Done():
mockPeer.Disconnect()
mockPeer.WaitForDisconnect()
return
} }
}() }()
return nil
} }
func TestOutboundPeerSync(t *testing.T) { func TestOutboundPeerSync(t *testing.T) {
testContext, cancel := context.WithCancel(context.Background())
defer cancel()
if err := mockLocalPeer(testContext); err != nil {
t.Logf("error starting mock peer (%v).", err)
}
regSeeder, err := newTestSeeder(network.Regtest) regSeeder, err := newTestSeeder(network.Regtest)
if err != nil { if err != nil {
t.Fatal(err) t.Error(err)
return
} }
err = regSeeder.ConnectToPeer("127.0.0.1") err = regSeeder.ConnectToPeer("127.0.0.1")
if err != nil { if err != nil {
t.Fatal(err) t.Error(err)
return
} }
// Can we address that peer if we want to? // Can we address that peer if we want to?
p, err := regSeeder.GetPeer("127.0.0.1") p, err := regSeeder.GetPeer("127.0.0.1")
if err != nil { if err != nil {
t.Fatal(err) t.Error(err)
return
} }
if p.Connected() { if p.Connected() {
regSeeder.DisconnectAllPeers() regSeeder.DisconnectPeer("127.0.0.1")
} else { } else {
t.Error("Peer never connected") t.Error("Peer never connected")
} }
// Can we STILL address a flushed peer? // Can we STILL address a flushed peer?
p, err = regSeeder.GetPeer("127.0.0.1") p, err = regSeeder.GetPeer("127.0.0.1")
if err == nil { if err != ErrNoSuchPeer {
t.Error("Peer should have been cleared on disconnect") t.Error("Peer should have been cleared on disconnect")
} }
} }
func TestOutboundPeerAsync(t *testing.T) { func TestOutboundPeerAsync(t *testing.T) {
testContext, cancel := context.WithCancel(context.Background())
defer cancel()
if err := mockLocalPeer(testContext); err != nil {
t.Logf("error starting mock peer (%v).", err)
}
regSeeder, err := newTestSeeder(network.Regtest) regSeeder, err := newTestSeeder(network.Regtest)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -134,3 +139,20 @@ func TestOutboundPeerAsync(t *testing.T) {
regSeeder.DisconnectAllPeers() regSeeder.DisconnectAllPeers()
} }
func TestRequestAddresses(t *testing.T) {
regSeeder, err := newTestSeeder(network.Regtest)
if err != nil {
t.Error(err)
return
}
err = regSeeder.ConnectToPeer("127.0.0.1")
if err != nil {
t.Error(err)
return
}
regSeeder.RequestAddresses()
regSeeder.WaitForMoreAddresses()
}