From 5f8cea77c0c0acfff99ed93f9421afb900633999 Mon Sep 17 00:00:00 2001 From: George Tankersley Date: Thu, 17 Oct 2019 22:47:40 -0400 Subject: [PATCH] WIP type safety + factored address book --- zcash/address.go | 101 ++++++++++++++++-- zcash/client.go | 255 ++++++++++++++++------------------------------ zcash/peer_map.go | 81 +++++++++++++++ 3 files changed, 263 insertions(+), 174 deletions(-) create mode 100644 zcash/peer_map.go diff --git a/zcash/address.go b/zcash/address.go index aa77afb..7eea87c 100644 --- a/zcash/address.go +++ b/zcash/address.go @@ -3,6 +3,7 @@ package zcash import ( "net" "strconv" + "sync" "time" "github.com/btcsuite/btcd/wire" @@ -15,12 +16,21 @@ type Address struct { lastTried time.Time } +func NewAddress(na *wire.NetAddress) *Address { + return &Address{ + netaddr: na, + valid: true, + blacklist: false, + lastTried: time.Now(), + } +} + func (a *Address) IsGood() bool { - return a.valid == true + return a.valid && !a.blacklist } func (a *Address) IsBad() bool { - return a.blacklist == true + return a.blacklist } func (a *Address) String() string { @@ -28,13 +38,88 @@ func (a *Address) String() string { return net.JoinHostPort(a.netaddr.IP.String(), portString) } -// Addresses should be sortable by least-recently-tried +func (a *Address) asPeerKey() PeerKey { + return PeerKey(a.String()) +} -type AddrList []*Address +func (a *Address) MarshalText() (text []byte, err error) { + return []byte(a.String()), nil +} -func (list AddrList) Len() int { return len(list) } -func (list AddrList) Swap(i, j int) { list[i], list[j] = list[j], list[i] } -func (list AddrList) Less(i, j int) bool { return list[i].lastTried.Before(list[j].lastTried) } +type AddressBook struct { + addrList []*Address + addrState sync.RWMutex + addrRecvCond *sync.Cond +} + +func (bk *AddressBook) Add(newAddr *Address) { + bk.addrState.Lock() + bk.addrList = append(bk.addrList, newAddr) + bk.addrState.Unlock() +} + +func (bk *AddressBook) Blacklist(addr PeerKey) { + bk.addrState.Lock() + for i := 0; i < len(bk.addrList); i++ { + address := bk.addrList[i] + if address.asPeerKey() == addr { + address.valid = false + address.blacklist = true + } + } + bk.addrState.Unlock() +} + +func (bk *AddressBook) AlreadyKnowsAddress(na *wire.NetAddress) bool { + bk.addrState.RLock() + defer bk.addrState.RUnlock() + + addr := NewAddress(na) + + for i := 0; i < len(bk.addrList); i++ { + if bk.addrList[i].String() == addr.String() { + return true + } + } + return false +} + +func (bk *AddressBook) IsBlacklistedAddress(na *wire.NetAddress) bool { + bk.addrState.RLock() + defer bk.addrState.RUnlock() + + ref := NewAddress(na) + + for i := 0; i < len(bk.addrList); i++ { + if bk.addrList[i].String() == ref.String() { + return bk.addrList[i].IsBad() + } + } + + return false +} + +func (bk *AddressBook) UpdateAddressState(update *Address) { + bk.addrState.Lock() + defer bk.addrState.Unlock() + + for i := 0; i < len(bk.addrList); i++ { + if bk.addrList[i].String() == update.String() { + bk.addrList[i].valid = update.valid + bk.addrList[i].blacklist = update.blacklist + bk.addrList[i].lastTried = update.lastTried + return + } + } +} + +func NewAddressBook(capacity int) *AddressBook { + addrBook := &AddressBook{ + addrList: make([]*Address, 0, capacity), + } + addrBook.addrRecvCond = sync.NewCond(&addrBook.addrState) + return addrBook +} // GetShuffledAddressList returns a slice of n valid addresses in random order. -func GetShuffledAddressList(addrList []*Address, n int) []*Address { return nil } +func (ab *AddressBook) GetShuffledAddressList(n int) []*Address { return nil } diff --git a/zcash/client.go b/zcash/client.go index 3004ec2..1d9d640 100644 --- a/zcash/client.go +++ b/zcash/client.go @@ -41,13 +41,11 @@ type Seeder struct { // Peer list handling peerState sync.RWMutex handshakeSignals *sync.Map - pendingPeers *sync.Map - livePeers *sync.Map + pendingPeers *PeerMap + livePeers *PeerMap - // Address list handling - addrState sync.RWMutex - addrRecvCond *sync.Cond - addrList []*Address + // The set of known addresses + addrBook *AddressBook } func NewSeeder(network network.Network) (*Seeder, error) { @@ -62,13 +60,11 @@ func NewSeeder(network network.Network) (*Seeder, error) { config: config, logger: logger, handshakeSignals: new(sync.Map), - pendingPeers: new(sync.Map), - livePeers: new(sync.Map), - addrList: make([]*Address, 0, 1000), + pendingPeers: NewPeerMap(), + livePeers: NewPeerMap(), + addrBook: NewAddressBook(1000), } - newSeeder.addrRecvCond = sync.NewCond(&newSeeder.addrState) - newSeeder.config.Listeners.OnVerAck = newSeeder.onVerAck newSeeder.config.Listeners.OnAddr = newSeeder.onAddr @@ -91,13 +87,11 @@ func newTestSeeder(network network.Network) (*Seeder, error) { config: config, logger: logger, handshakeSignals: new(sync.Map), - pendingPeers: new(sync.Map), - livePeers: new(sync.Map), - addrList: make([]*Address, 0, 1000), + pendingPeers: NewPeerMap(), + livePeers: NewPeerMap(), + addrBook: NewAddressBook(1000), } - newSeeder.addrRecvCond = sync.NewCond(&newSeeder.addrState) - newSeeder.config.Listeners.OnVerAck = newSeeder.onVerAck newSeeder.config.Listeners.OnAddr = newSeeder.onAddr @@ -128,7 +122,7 @@ func (s *Seeder) GetNetworkDefaultPort() string { func (s *Seeder) onVerAck(p *peer.Peer, msg *wire.MsgVerAck) { // Check if we're expecting to hear from this peer - _, ok := s.pendingPeers.Load(p.Addr()) + _, ok := s.pendingPeers.Load(peerKeyFromPeer(p)) if !ok { s.logger.Printf("Got verack from unexpected peer %s", p.Addr()) @@ -136,41 +130,32 @@ func (s *Seeder) onVerAck(p *peer.Peer, msg *wire.MsgVerAck) { } // Add to set of live peers - s.livePeers.Store(p.Addr(), p) + s.livePeers.Store(peerKeyFromPeer(p), p) // Remove from set of pending peers - s.pendingPeers.Delete(p.Addr()) + s.pendingPeers.Delete(peerKeyFromPeer(p)) // Signal successful connection if signal, ok := s.handshakeSignals.Load(p.Addr()); ok { signal.(chan struct{}) <- struct{}{} } else { s.logger.Printf("Got verack from peer without a callback channel: %s", p.Addr()) - s.DisconnectPeer(p.Addr()) + s.DisconnectPeer(peerKeyFromPeer(p)) return } // Add to list of known good addresses if we don't already have it. // Otherwise, update the last-valid time. - newAddr := &Address{ - netaddr: p.NA(), - valid: true, - blacklist: false, - lastTried: time.Now(), - } - - if s.alreadyKnowsAddress(p.NA()) { + if s.addrBook.AlreadyKnowsAddress(p.NA()) { + newAddr := NewAddress(p.NA()) s.updateAddressState(newAddr) return } s.logger.Printf("Adding %s to address list", p.Addr()) - s.addrState.Lock() - s.addrList = append(s.addrList, newAddr) - s.addrState.Unlock() - + s.addrBook.Add(newAddr) return } @@ -189,19 +174,19 @@ func (s *Seeder) Connect(addr, port string) error { return errors.Wrap(err, "constructing outbound peer") } - if s.isBlacklistedAddress(p.NA()) { + if s.addrBook.IsBlacklistedAddress(p.NA()) { return ErrBlacklistedPeer } - _, alreadyPending := s.pendingPeers.Load(p.Addr()) - _, alreadyHandshaking := s.handshakeSignals.Load(p.Addr()) - _, alreadyLive := s.livePeers.Load(p.Addr()) + _, alreadyPending := s.pendingPeers.Load(peerKeyFromPeer(p)) + _, alreadyHandshaking := s.handshakeSignals.Load(peerKeyFromPeer(p)) + _, alreadyLive := s.livePeers.Load(peerKeyFromPeer(p)) if alreadyPending { s.logger.Printf("Peer is already pending: %s", p.Addr()) return ErrRepeatConnection } else { - s.pendingPeers.Store(p.Addr(), p) + s.pendingPeers.Store(peerKeyFromPeer(p), p) } if alreadyHandshaking { @@ -242,7 +227,7 @@ func (s *Seeder) Connect(addr, port string) error { // GetPeer returns a live peer identified by "host:port" string, or an error if // we aren't connected to that peer. -func (s *Seeder) GetPeer(addr string) (*peer.Peer, error) { +func (s *Seeder) GetPeer(addr PeerKey) (*peer.Peer, error) { p, ok := s.livePeers.Load(addr) if ok { @@ -254,21 +239,17 @@ func (s *Seeder) GetPeer(addr string) (*peer.Peer, error) { // DisconnectPeer disconnects from a live peer identified by "host:port" // string. It returns an error if we aren't connected to that peer. -func (s *Seeder) DisconnectPeer(addr string) error { +func (s *Seeder) DisconnectPeer(addr PeerKey) error { p, ok := s.livePeers.Load(addr) if !ok { return ErrNoSuchPeer } - // TODO: type safety and error handling - - v := p.(*peer.Peer) - s.logger.Printf("Disconnecting from peer %s", v.Addr()) - v.Disconnect() - v.WaitForDisconnect() + s.logger.Printf("Disconnecting from peer %s", p.Addr()) + p.Disconnect() + p.WaitForDisconnect() s.livePeers.Delete(addr) - return nil } @@ -276,67 +257,43 @@ func (s *Seeder) DisconnectPeer(addr string) error { // "host:port" string. It returns an error if we aren't connected to that peer. // "Dishonorably" furthermore removes this peer from the list of known good // addresses and adds them to a blacklist. -func (s *Seeder) DisconnectPeerDishonorably(addr string) error { +func (s *Seeder) DisconnectPeerDishonorably(addr PeerKey) error { p, ok := s.livePeers.Load(addr) if !ok { return ErrNoSuchPeer } - // TODO: type safety and error handling + s.logger.Printf("Disconnecting from peer %s", addr) + p.Disconnect() + p.WaitForDisconnect() - v := p.(*peer.Peer) - s.logger.Printf("Disconnecting from peer %s", v.Addr()) - v.Disconnect() - v.WaitForDisconnect() + // Remove from live peer set s.livePeers.Delete(addr) - s.addrState.Lock() - for i := 0; i < len(s.addrList); i++ { - address := s.addrList[i] - if address.String() == addr { - s.logger.Printf("Blacklisting peer %s", v.Addr()) - address.valid = false - address.blacklist = true - } - } - s.addrState.Unlock() - + // Never connect to them again + s.logger.Printf("Blacklisting peer %s", addr) + s.addrBook.Blacklist(addr) return nil } // DisconnectAllPeers terminates the connections to all live and pending peers. func (s *Seeder) DisconnectAllPeers() { - s.pendingPeers.Range(func(key, value interface{}) bool { - p, ok := value.(*peer.Peer) - if !ok { - s.logger.Printf("Invalid peer in pendingPeers") - return false - } + s.pendingPeers.Range(func(key PeerKey, p *peer.Peer) bool { p.Disconnect() p.WaitForDisconnect() s.pendingPeers.Delete(key) return true }) - s.livePeers.Range(func(key, value interface{}) bool { - p, ok := value.(*peer.Peer) - if !ok { - s.logger.Printf("Invalid peer in livePeers") - return false - } + s.livePeers.Range(func(key PeerKey, p *peer.Peer) bool { s.DisconnectPeer(p.Addr()) return true }) } func (s *Seeder) RequestAddresses() { - s.livePeers.Range(func(key, value interface{}) bool { - p, ok := value.(*peer.Peer) - if !ok { - s.logger.Printf("Invalid peer in livePeers") - return false - } + s.livePeers.Range(func(key PeerKey, p *peer.Peer) bool { s.logger.Printf("Requesting addresses from peer %s", p.Addr()) p.QueueMessage(wire.NewMsgGetAddr(), nil) return true @@ -360,54 +317,6 @@ func (s *Seeder) WaitForAddresses(n int) error { return nil } -func (s *Seeder) alreadyKnowsAddress(na *wire.NetAddress) bool { - s.addrState.RLock() - defer s.addrState.RUnlock() - - ref := &Address{ - netaddr: na, - } - - for i := 0; i < len(s.addrList); i++ { - if s.addrList[i].String() == ref.String() { - return true - } - } - - return false -} - -func (s *Seeder) isBlacklistedAddress(na *wire.NetAddress) bool { - s.addrState.RLock() - defer s.addrState.RUnlock() - - ref := &Address{ - netaddr: na, - } - - for i := 0; i < len(s.addrList); i++ { - if s.addrList[i].String() == ref.String() { - return s.addrList[i].IsBad() - } - } - - return false -} - -func (s *Seeder) updateAddressState(update *Address) { - s.addrState.Lock() - defer s.addrState.Unlock() - - for i := 0; i < len(s.addrList); i++ { - if s.addrList[i].String() == update.String() { - s.addrList[i].valid = update.valid - s.addrList[i].blacklist = update.blacklist - s.addrList[i].lastTried = update.lastTried - return - } - } -} - func (s *Seeder) onAddr(p *peer.Peer, msg *wire.MsgAddr) { if len(msg.AddrList) == 0 { s.logger.Printf("Got empty addr message from peer %s. Disconnecting.", p.Addr()) @@ -417,49 +326,63 @@ func (s *Seeder) onAddr(p *peer.Peer, msg *wire.MsgAddr) { s.logger.Printf("Got %d addrs from peer %s", len(msg.AddrList), p.Addr()) + queue := make(chan *wire.NetAddress, len(msg.AddrList)) + for _, na := range msg.AddrList { - s.logger.Printf("Trying %s:%d from peer %s", na.IP, na.Port, p.Addr()) - go func(na *wire.NetAddress) { - if !addrmgr.IsRoutable(na) && !s.config.AllowSelfConns { - s.logger.Printf("Got bad addr %s:%d from peer %s", na.IP, na.Port, p.Addr()) - s.DisconnectPeerDishonorably(p.Addr()) - return - } + queue <- na + } - if s.alreadyKnowsAddress(na) { - s.logger.Printf("Already knew about address %s:%d", na.IP, na.Port) - return - } - - if s.isBlacklistedAddress(na) { - s.logger.Printf("Address %s:%d is blacklisted", na.IP, na.Port) - return - } - - portString := strconv.Itoa(int(na.Port)) - err := s.Connect(na.IP.String(), portString) - - if err != nil { - s.logger.Printf("Got unusable peer %s:%d from peer %s. Error: %s", na.IP, na.Port, p.Addr(), err) - - // Mark previously-known peers as invalid - newAddr := &Address{ - netaddr: p.NA(), - valid: false, - lastTried: time.Now(), + for i := 0; i < 32; i++ { + go func() { + var na *wire.NetAddress + for { + select { + case next := <-queue: + na = next + case <-time.After(1 * time.Second): + return } - if s.alreadyKnowsAddress(p.NA()) { - s.updateAddressState(newAddr) + if !addrmgr.IsRoutable(na) && !s.config.AllowSelfConns { + s.logger.Printf("Got bad addr %s:%d from peer %s", na.IP, na.Port, p.Addr()) + s.DisconnectPeerDishonorably(p.Addr()) + continue } - return + if s.addrBook.AlreadyKnowsAddress(na) { + s.logger.Printf("Already knew about address %s:%d", na.IP, na.Port) + continue + } + + if s.addrBook.IsBlacklistedAddress(na) { + s.logger.Printf("Address %s:%d is blacklisted", na.IP, na.Port) + continue + } + + portString := strconv.Itoa(int(na.Port)) + err := s.Connect(na.IP.String(), portString) + + if err != nil { + s.logger.Printf("Got unusable peer %s:%d from peer %s. Error: %s", na.IP, na.Port, p.Addr(), err) + + // Mark previously-known peers as invalid + newAddr := &Address{ + netaddr: p.NA(), + valid: false, + lastTried: time.Now(), + } + + if s.alreadyKnowsAddress(p.NA()) { + s.updateAddressState(newAddr) + } + continue + } + + peerString := net.JoinHostPort(na.IP.String(), portString) + s.DisconnectPeer(peerString) + + s.addrRecvCond.Broadcast() } - - peerString := net.JoinHostPort(na.IP.String(), portString) - s.DisconnectPeer(peerString) - - s.addrRecvCond.Broadcast() - }(na) + }() } } diff --git a/zcash/peer_map.go b/zcash/peer_map.go new file mode 100644 index 0000000..f6c8437 --- /dev/null +++ b/zcash/peer_map.go @@ -0,0 +1,81 @@ +package zcash + +import ( + "net" + "strconv" + "sync" + + "github.com/btcsuite/btcd/peer" + "github.com/btcsuite/btcd/wire" +) + +// The "host:port" format used throughout our maps and lists. +type PeerKey string + +func peerKeyFromPeer(p *peer.Peer) PeerKey { + return PeerKey(p.Addr()) +} + +func peerKeyFromNA(na *wire.NetAddress) PeerKey { + portString := strconv.Itoa(int(na.Port)) + return PeerKey(net.JoinHostPort(na.IP.String(), portString)) +} + +// PeerMap is a typed wrapper for a sync.Map. Its keys are PeerKeys (host:port +// format strings) and it stores a pointer to a btcsuite peer.Peer. +type PeerMap struct { + m *sync.Map +} + +// NewPeerMap returns a fresh, empty PeerMap. +func NewPeerMap() *PeerMap { + return &PeerMap{ + m: new(sync.Map), + } +} + +// Load returns the value stored in the map for a key, or nil if no value is +// present. The ok result indicates whether value was found in the map. +func (pm *PeerMap) Load(key PeerKey) (*peer.Peer, bool) { + v, mapOk := pm.m.Load(key) + if mapOk { + p, typeOk := v.(*peer.Peer) + if typeOK { + return p, true + } + } + return nil, false +} + +// LoadOrStore returns the existing value for the key if present. Otherwise, it +// stores and returns the given value. The loaded result is true if the value +// was loaded, false if stored. +func (pm *PeerMap) LoadOrStore(key PeerKey, value *peer.Peer) (*peer.Peer, bool) { + v, loaded := pm.m.LoadOrStore(key, value) + p, _ := v.(*peer.Peer) + return p, loaded +} + +// Store sets the value for a key. +func (pm *PeerMap) Store(key PeerKey, value *peer.Peer) { + pm.m.Store(key, value) +} + +// Delete deletes the value for a key. +func (pm *PeerMap) Delete(key PeerKey) { + pm.m.Delete(key) +} + +// Range calls f sequentially for each key and value present in the map. If f +// returns false, range stops the iteration. +// +// Range does not necessarily correspond to any consistent snapshot of the +// Map's contents: no key will be visited more than once, but if the value for +// any key is stored or deleted concurrently, Range may reflect any mapping for +// that key from any point during the Range call. +// +// Range may be O(N) with the number of elements in the map even if f returns +// false after a constant number of calls. +func (pm *PeerMap) Range(f func(key PeerKey, value *peer.Peer) bool) { + pm.m.Range(f) +}