diff --git a/zebra-network/src/peer/client.rs b/zebra-network/src/peer/client.rs index 7f7852246..f925a55a8 100644 --- a/zebra-network/src/peer/client.rs +++ b/zebra-network/src/peer/client.rs @@ -16,6 +16,9 @@ use super::{ErrorSlot, SharedPeerError}; /// The "client" duplex half of a peer connection. pub struct Client { + // Used to shut down the corresponding heartbeat. + // This is always Some except when we take it on drop. + pub(super) shutdown_tx: Option>, pub(super) server_tx: mpsc::Sender, pub(super) error_slot: ErrorSlot, } @@ -85,3 +88,13 @@ impl Service for Client { } } } + +impl Drop for Client { + fn drop(&mut self) { + let _ = self + .shutdown_tx + .take() + .expect("must not drop twice") + .send(()); + } +} diff --git a/zebra-network/src/peer/handshake.rs b/zebra-network/src/peer/handshake.rs index 17e74dbbd..5e73f2eb8 100644 --- a/zebra-network/src/peer/handshake.rs +++ b/zebra-network/src/peer/handshake.rs @@ -8,7 +8,10 @@ use std::{ }; use chrono::Utc; -use futures::{channel::mpsc, prelude::*}; +use futures::{ + channel::{mpsc, oneshot}, + prelude::*, +}; use tokio::net::TcpStream; use tokio_util::codec::Framed; use tower::Service; @@ -215,9 +218,11 @@ where // These channels should not be cloned more than they are // in this block, see constants.rs for more. let (server_tx, server_rx) = mpsc::channel(0); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); let slot = ErrorSlot::default(); let client = Client { + shutdown_tx: Some(shutdown_tx), server_tx: server_tx.clone(), error_slot: slot.clone(), }; @@ -283,35 +288,38 @@ where let heartbeat_span = tracing::debug_span!(parent: connection_span, "heartbeat"); tokio::spawn( async move { - use futures::channel::oneshot; - use super::client::ClientRequest; + use futures::future::Either; + let mut shutdown_rx = shutdown_rx; let mut server_tx = server_tx; - let mut interval_stream = tokio::time::interval(constants::HEARTBEAT_INTERVAL); - loop { - interval_stream.tick().await; - - // We discard the server handle because our - // heartbeat `Ping`s are a special case, and we - // don't actually care about the response here. - let (request_tx, _) = oneshot::channel(); - if server_tx - .send(ClientRequest { - request: Request::Ping(Nonce::default()), - tx: request_tx, - span: tracing::Span::current(), - }) - .await - .is_err() - { - return; + let shutdown_rx_ref = Pin::new(&mut shutdown_rx); + match future::select(interval_stream.next(), shutdown_rx_ref).await { + Either::Left(_) => { + // We don't wait on a response because heartbeats are checked + // internally to the connection logic, we just need a separate + // task (this one) to generate them. + let (request_tx, _) = oneshot::channel(); + if server_tx + .send(ClientRequest { + request: Request::Ping(Nonce::default()), + tx: request_tx, + span: tracing::Span::current(), + }) + .await + .is_err() + { + return; + } + } + Either::Right(_) => return, // got shutdown signal } } } - .instrument(heartbeat_span), + .instrument(heartbeat_span) + .boxed(), ); Ok(client)