diff --git a/server.go b/server.go index 4a4ed899..c6b74e10 100644 --- a/server.go +++ b/server.go @@ -64,7 +64,7 @@ type server struct { // long-term identity private key. lightningID [32]byte - mu sync.Mutex + mu sync.RWMutex peersByID map[int32]*peer peersByPub map[string]*peer @@ -678,9 +678,9 @@ func (s *server) peerBootstrapper(numTargetPeers uint32, case <-sampleTicker.C: // Obtain the current number of peers, so we can gauge // if we need to sample more peers or not. - s.mu.Lock() + s.mu.RLock() numActivePeers := uint32(len(s.peersByPub)) - s.mu.Unlock() + s.mu.RUnlock() // If we have enough peers, then we can loop back // around to the next round as we're done here. @@ -896,6 +896,11 @@ func (s *server) establishPersistentConnections() error { return err } + // Acquire and hold server lock until all persistent connection requests + // have been recorded and sent to the connection manager. + s.mu.Lock() + defer s.mu.Unlock() + // Iterate through the combined list of addresses from prior links and // node announcements and attempt to reconnect to each node. for pubStr, nodeAddr := range nodeAddrsMap { @@ -939,8 +944,8 @@ func (s *server) establishPersistentConnections() error { func (s *server) BroadcastMessage(skip map[routing.Vertex]struct{}, msgs ...lnwire.Message) error { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() return s.broadcastMessages(skip, msgs) } @@ -990,10 +995,30 @@ func (s *server) broadcastMessages( func (s *server) SendToPeer(target *btcec.PublicKey, msgs ...lnwire.Message) error { - s.mu.Lock() - defer s.mu.Unlock() + // Queue the incoming messages in the peer's outgoing message buffer. + // We acquire the shared lock here to ensure the peer map doesn't change + // from underneath us. + s.mu.RLock() + targetPeer, errChans, err := s.sendToPeer(target, msgs) + s.mu.RUnlock() + if err != nil { + return err + } - return s.sendToPeer(target, msgs) + // With the server's shared lock released, we now handle all of the + // errors being returned from the target peer's write handler. + for _, errChan := range errChans { + select { + case err := <-errChan: + return err + case <-targetPeer.quit: + return fmt.Errorf("peer shutting down") + case <-s.quit: + return ErrServerShuttingDown + } + } + + return nil } // NotifyWhenOnline can be called by other subsystems to get notified when a @@ -1023,10 +1048,12 @@ func (s *server) NotifyWhenOnline(peer *btcec.PublicKey, s.peerConnectedListeners[pubStr], connectedChan) } -// sendToPeer is an internal method that delivers messages to the specified -// `target` peer. +// sendToPeer is an internal method that queues the given messages in the +// outgoing buffer of the specified `target` peer. Upon success, this method +// returns the peer instance and a slice of error chans that will contain +// responses from the write handler. func (s *server) sendToPeer(target *btcec.PublicKey, - msgs []lnwire.Message) error { + msgs []lnwire.Message) (*peer, []chan error, error) { // Compute the target peer's identifier. targetPubBytes := target.SerializeCompressed() @@ -1042,23 +1069,14 @@ func (s *server) sendToPeer(target *btcec.PublicKey, if err == ErrPeerNotFound { srvrLog.Errorf("unable to send message to %x, "+ "peer not found", targetPubBytes) - return err + return nil, nil, err } - // Send messages to the peer and return any error from - // sending a message. + // Send messages to the peer and return the error channels that will be + // signaled by the peer's write handler. errChans := s.sendPeerMessages(targetPeer, msgs, nil) - for _, errChan := range errChans { - select { - case err := <-errChan: - return err - case <-targetPeer.quit: - return fmt.Errorf("peer shutting down") - case <-s.quit: - return ErrServerShuttingDown - } - } - return nil + + return targetPeer, errChans, nil } // sendPeerMessages enqueues a list of messages into the outgoingQueue of the @@ -1109,8 +1127,8 @@ func (s *server) sendPeerMessages( // // NOTE: This function is safe for concurrent access. func (s *server) FindPeer(peerKey *btcec.PublicKey) (*peer, error) { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() pubStr := string(peerKey.SerializeCompressed()) @@ -1123,8 +1141,8 @@ func (s *server) FindPeer(peerKey *btcec.PublicKey) (*peer, error) { // // NOTE: This function is safe for concurrent access. func (s *server) FindPeerByPubStr(pubStr string) (*peer, error) { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() return s.findPeerByPubStr(pubStr) } @@ -1713,13 +1731,13 @@ func (s *server) OpenChannel(peerID int32, nodeKey *btcec.PublicKey, // First attempt to locate the target peer to open a channel with, if // we're unable to locate the peer then this request will fail. - s.mu.Lock() + s.mu.RLock() if peer, ok := s.peersByID[peerID]; ok { targetPeer = peer } else if peer, ok := s.peersByPub[string(pubKeyBytes)]; ok { targetPeer = peer } - s.mu.Unlock() + s.mu.RUnlock() if targetPeer == nil { errChan <- fmt.Errorf("unable to find peer nodeID(%x), "+ @@ -1770,8 +1788,8 @@ func (s *server) OpenChannel(peerID int32, nodeKey *btcec.PublicKey, // // NOTE: This function is safe for concurrent access. func (s *server) Peers() []*peer { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() peers := make([]*peer, 0, len(s.peersByID)) for _, peer := range s.peersByID {