diff --git a/rpcserver.go b/rpcserver.go index 40b72bb4..31338cba 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -240,12 +240,47 @@ func (r *rpcServer) ConnectPeer(ctx context.Context, return &lnrpc.ConnectPeerResponse{}, nil } -// DisconnectPeer attempts to disconnect one peer from another identified by a given pubKey. -func (r *rpcServer) DisconnectPeer(ctx context.Context, in *lnrpc.DisconnectPeerRequest) (*lnrpc.DisconnectPeerResponse, error) { - if err := r.server.DisconnectFromPeer(in.PubKey); err != nil { +// DisconnectPeer attempts to disconnect one peer from another identified by a +// given pubKey. In the case that we currently ahve a pending or active channel +// with the target peer, then +func (r *rpcServer) DisconnectPeer(ctx context.Context, + in *lnrpc.DisconnectPeerRequest) (*lnrpc.DisconnectPeerResponse, error) { + + rpcsLog.Debugf("[disconnectpeer] from peer(%s)", in.PubKey) + + // First we'll validate the string passed in within the request to + // ensure that it's a valid hex-string, and also a valid compressed + // public key. + pubKeyBytes, err := hex.DecodeString(in.PubKey) + if err != nil { + return nil, fmt.Errorf("unable to decode pubkey bytes: %v", err) + } + peerPubKey, err := btcec.ParsePubKey(pubKeyBytes, btcec.S256()) + if err != nil { + return nil, fmt.Errorf("unable to parse pubkey: %v", err) + } + + // Next, we'll fetch the pending/active channels we have with a + // particular peer. + nodeChannels, err := r.server.chanDB.FetchOpenChannels(peerPubKey) + if err != nil { + return nil, fmt.Errorf("unable to fetch channels for peer: %v", err) + } + + // In order to avoid erroneously disconnecting from a peer that we have + // an active channel with, if we have any channels active with this + // peer, then we'll disallow disconnecting from them. + if len(nodeChannels) > 0 { + return nil, fmt.Errorf("cannot disconnect from peer(%x), "+ + "all active channels with the peer need to be closed "+ + "first", pubKeyBytes) + } + + // With all initial validation complete, we'll now request that the + // sever disconnects from the per. + if err := r.server.DisconnectPeer(peerPubKey); err != nil { return nil, fmt.Errorf("unable to disconnect peer: %v", err) } - rpcsLog.Debugf("[disconnectpeer] from peer(%s)", in.PubKey) return &lnrpc.DisconnectPeerResponse{}, nil } diff --git a/server.go b/server.go index 029937a1..e1e8326b 100644 --- a/server.go +++ b/server.go @@ -648,8 +648,8 @@ func (s *server) peerTerminationWatcher(p *peer) { srvrLog.Debugf("Peer %v has been disconnected", p) // Tell the switch to unregister all links associated with this peer. - // Passing nil as the target link indicates that all - // links associated with this interface should be closed. + // Passing nil as the target link indicates that all links associated + // with this interface should be closed. p.server.htlcSwitch.UnregisterLink(p.addr.IdentityKey, nil) // Send the peer to be garbage collected by the server. @@ -664,7 +664,10 @@ func (s *server) peerTerminationWatcher(p *peer) { // Next, check to see if this is a persistent peer or not. pubStr := string(p.addr.IdentityKey.SerializeCompressed()) - if _, ok := s.persistentPeers[pubStr]; ok { + s.pendingConnMtx.RLock() + _, ok := s.persistentPeers[pubStr] + s.pendingConnMtx.RUnlock() + if ok { srvrLog.Debugf("Attempting to re-establish persistent "+ "connection to peer %v", p) @@ -677,7 +680,6 @@ func (s *server) peerTerminationWatcher(p *peer) { } s.pendingConnMtx.Lock() - // We'll only need to re-launch a connection requests if one // isn't already currently pending. if _, ok := s.persistentConnReqs[pubStr]; ok { @@ -948,8 +950,10 @@ type connectPeerMsg struct { err chan error } +// disconnectPeerMsg is a message requesting the server to disconnect from an +// active peer. type disconnectPeerMsg struct { - pubKey string + pubKey *btcec.PublicKey err chan error } @@ -1169,52 +1173,34 @@ func (s *server) handleConnectPeer(msg *connectPeerMsg) { // handleDisconnectPeer attempts to disconnect one peer from another func (s *server) handleDisconnectPeer(msg *disconnectPeerMsg) { - pubKey, err := hex.DecodeString(msg.pubKey) - if err != nil { - msg.err <- fmt.Errorf("unable to DecodeString public key: %v", err) - return - } + pubBytes := msg.pubKey.SerializeCompressed() + pubStr := string(pubBytes) - // Ensure we're already connected to this peer. + // Check that were actually connected to this peer. If not, then we'll + // exit in an error as we can't disconnect from a peer that we're not + // currently connected to. s.peersMtx.RLock() - peer, ok := s.peersByPub[string(pubKey)] + peer, ok := s.peersByPub[pubStr] s.peersMtx.RUnlock() if !ok { - msg.err <- fmt.Errorf("unable to find peer(%v) by public key(%v)", peer, msg.pubKey) + msg.err <- fmt.Errorf("unable to find peer %x", pubBytes) return } - // Get all pending and active channels corresponding with current node. - allChannels, err := s.chanDB.FetchAllChannels() - if err != nil { - msg.err <- fmt.Errorf("unable to get opened channels: %v", err) - return - } - - // Filter by public key all channels corresponding with the detached node. - var nodeChannels []*channeldb.OpenChannel - - for _, channel := range allChannels { - if hex.EncodeToString(channel.IdentityPub.SerializeCompressed()) == msg.pubKey { - nodeChannels = append(nodeChannels, channel) - } - } - - // Send server info logs containing channels id's and raise error about - // primary closing channels before start disconnecting peer. - if len(nodeChannels) > 0 { - for _, channel := range nodeChannels { - srvrLog.Infof("Before disconnect peer(%v) close channel: %v", - msg.pubKey, channel.ChanID) - } - msg.err <- fmt.Errorf("before disconnect peer(%v) you have to close "+ - "active and pending channels corresponding to that peer; %v", - msg.pubKey, nodeChannels) - return + // If this peer was formerly a persistent connection, then we'll remove + // them from this map so we don't attempt to re-connect after we + // disconnect. + s.pendingConnMtx.Lock() + if _, ok := s.persistentPeers[pubStr]; ok { + delete(s.persistentPeers, pubStr) } + s.pendingConnMtx.Unlock() + // Now that we know the peer is actually connected, we'll disconnect + // from the peer. srvrLog.Infof("Disconnecting from %v", peer) peer.Disconnect() + msg.err <- nil } @@ -1250,11 +1236,12 @@ func (s *server) handleOpenChanReq(req *openChanReq) { return } - // Spawn a goroutine to send the funding workflow request to the funding - // manager. This allows the server to continue handling queries instead - // of blocking on this request which is exported as a synchronous - // request to the outside world. - // TODO(roasbeef): server semaphore to restrict num goroutines + // Spawn a goroutine to send the funding workflow request to the + // funding manager. This allows the server to continue handling queries + // instead of blocking on this request which is exported as a + // synchronous request to the outside world. + // TODO(roasbeef): pass in chan that's closed if/when funding succeeds + // so can track as persistent peer? go s.fundingMgr.initFundingWorkflow(targetPeer.addr, req) } @@ -1275,14 +1262,14 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, return <-errChan } -// DisconnectFromPeer sends the request to server to close the connection -// with peer identified by public key. -func (s *server) DisconnectFromPeer(pubkey string) error { +// DisconnectPeer sends the request to server to close the connection with peer +// identified by public key. +func (s *server) DisconnectPeer(pubKey *btcec.PublicKey) error { errChan := make(chan error, 1) s.queries <- &disconnectPeerMsg{ - pubKey: pubkey, + pubKey: pubKey, err: errChan, }