diff --git a/zebra-network/src/peer.rs b/zebra-network/src/peer.rs
index 1a10a96e2..70ed31dbd 100644
--- a/zebra-network/src/peer.rs
+++ b/zebra-network/src/peer.rs
@@ -20,10 +20,12 @@ pub use client::tests::ClientTestHarness;
 #[cfg(not(test))]
 use client::ClientRequest;
 #[cfg(test)]
-pub(crate) use client::{CancelHeartbeatTask, ClientRequest};
+pub(crate) use client::ClientRequest;
 
 use client::{ClientRequestReceiver, InProgressClientRequest, MustUseOneshotSender};
 
+pub(crate) use client::CancelHeartbeatTask;
+
 pub use client::Client;
 pub use connection::Connection;
 pub use connector::{Connector, OutboundConnectorRequest};
diff --git a/zebra-network/src/peer/client.rs b/zebra-network/src/peer/client.rs
index da0e32b58..5c1f3e581 100644
--- a/zebra-network/src/peer/client.rs
+++ b/zebra-network/src/peer/client.rs
@@ -8,7 +8,9 @@ use futures::{
 channel::{mpsc, oneshot},
 future, ready,
 stream::{Stream, StreamExt},
+ FutureExt,
 };
+use tokio::task::JoinHandle;
 use tower::Service;
 
 use crate::{
@@ -40,6 +42,12 @@ pub struct Client {
 /// The peer connection's protocol version.
 pub(crate) version: Version,
+
+ /// A handle to the task responsible for running the peer connection.
+ pub(crate) connection_task: JoinHandle<()>,
+
+ /// A handle to the task responsible for sending periodic heartbeats.
+ pub(crate) heartbeat_task: JoinHandle<()>,
 }
 
 /// A signal sent by the [`Client`] half of a peer connection,
@@ -253,28 +261,70 @@ impl<T: std::fmt::Debug> Drop for MustUseOneshotSender<T> {
 impl Client {
 /// Check if this connection's heartbeat task has exited.
 fn check_heartbeat(&mut self, cx: &mut Context<'_>) -> Result<(), SharedPeerError> {
- if let Poll::Ready(()) = self
+ let is_canceled = self
 .shutdown_tx
 .as_mut()
 .expect("only taken on drop")
 .poll_canceled(cx)
- {
- // Make sure there is an error in the slot
- let heartbeat_error: SharedPeerError = PeerError::HeartbeatTaskExited.into();
- let original_error = self.error_slot.try_update_error(heartbeat_error.clone());
- debug!(
- ?original_error,
- latest_error = ?heartbeat_error,
- "client heartbeat task exited"
- );
+ .is_ready();
 
- if let Err(AlreadyErrored { original_error }) = original_error {
- Err(original_error)
- } else {
- Err(heartbeat_error)
+ if is_canceled {
+ return self.set_task_exited_error("heartbeat", PeerError::HeartbeatTaskExited);
+ }
+
+ match self.heartbeat_task.poll_unpin(cx) {
+ Poll::Pending => {
+ // Heartbeat task is still running.
+ Ok(())
 }
+ Poll::Ready(Ok(())) => {
+ // Heartbeat task stopped unexpectedly, without panicking.
+ self.set_task_exited_error("heartbeat", PeerError::HeartbeatTaskExited)
+ }
+ Poll::Ready(Err(error)) => {
+ // Heartbeat task stopped unexpectedly with a panic.
+ panic!("heartbeat task has panicked: {}", error);
+ }
+ }
+ }
+
+ /// Check if the connection's task has exited.
+ fn check_connection(&mut self, context: &mut Context<'_>) -> Result<(), SharedPeerError> {
+ match self.connection_task.poll_unpin(context) {
+ Poll::Pending => {
+ // Connection task is still running.
+ Ok(())
+ }
+ Poll::Ready(Ok(())) => {
+ // Connection task stopped unexpectedly, without panicking.
+ self.set_task_exited_error("connection", PeerError::ConnectionTaskExited)
+ }
+ Poll::Ready(Err(error)) => {
+ // Connection task stopped unexpectedly with a panic.
+ panic!("connection task has panicked: {}", error);
+ }
+ }
+ }
+
+ /// Properly update the error slot after a background task has unexpectedly stopped.
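+ ///
+ /// Returns an error: either the error that was already in the slot, or the new `error`
+ /// recorded for the exited `task_name` task.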
+ fn set_task_exited_error(
+ &mut self,
+ task_name: &str,
+ error: PeerError,
+ ) -> Result<(), SharedPeerError> {
+ // Make sure there is an error in the slot
+ let task_error = SharedPeerError::from(error);
+ let original_error = self.error_slot.try_update_error(task_error.clone());
+ debug!(
+ ?original_error,
+ latest_error = ?task_error,
+ "client {} task exited", task_name
+ );
+
+ if let Err(AlreadyErrored { original_error }) = original_error {
+ Err(original_error)
 } else {
- Ok(())
+ Err(task_error)
 }
 }
 
@@ -318,13 +368,15 @@ impl Service<Request> for Client {
 // The current task must be scheduled for wakeup every time we return
 // `Poll::Pending`.
 //
- // `poll_canceled` schedules the client task for wakeup
- // if the heartbeat task exits and drops the cancel handle.
+ // `check_heartbeat` and `check_connection` schedule the client task for wakeup
+ // if either task exits, or if the heartbeat task drops the cancel handle.
+ //
 // `ready!` returns `Poll::Pending` when `server_tx` is unready, and
 // schedules this task for wakeup.
 
- let mut result = self.check_heartbeat(cx);
+ let mut result = self
+ .check_heartbeat(cx)
+ .and_then(|()| self.check_connection(cx));
 
 if result.is_ok() {
 result = ready!(self.poll_request(cx));
@@ -340,8 +392,6 @@ impl Service<Request> for Client {
 }
 
 fn call(&mut self, request: Request) -> Self::Future {
- use futures::future::FutureExt;
-
 let (tx, rx) = oneshot::channel();
 // get the current Span to propagate it to the peer connection task.
 // this allows the peer connection to enter the correct tracing context
diff --git a/zebra-network/src/peer/client/tests.rs b/zebra-network/src/peer/client/tests.rs
index e99807283..b26553b49 100644
--- a/zebra-network/src/peer/client/tests.rs
+++ b/zebra-network/src/peer/client/tests.rs
@@ -3,26 +3,41 @@
 mod vectors;
 
-use futures::channel::{mpsc, oneshot};
+use std::time::Duration;
+
+use futures::{
+ channel::{mpsc, oneshot},
+ future::{self, AbortHandle, Future, FutureExt},
+};
+use tokio::task::JoinHandle;
 
 use crate::{
 peer::{error::SharedPeerError, CancelHeartbeatTask, Client, ClientRequest, ErrorSlot},
 protocol::external::types::Version,
 };
 
+/// The maximum time a mocked peer connection should be alive during a test.
+const MAX_PEER_CONNECTION_TIME: Duration = Duration::from_secs(10);
+
 /// A harness with mocked channels for testing a [`Client`] instance.
 pub struct ClientTestHarness {
 client_request_receiver: Option<mpsc::Receiver<ClientRequest>>,
 shutdown_receiver: Option<oneshot::Receiver<CancelHeartbeatTask>>,
 error_slot: ErrorSlot,
 version: Version,
+ connection_aborter: AbortHandle,
+ heartbeat_aborter: AbortHandle,
 }
 
 impl ClientTestHarness {
 /// Create a [`ClientTestHarnessBuilder`] instance to help create a new [`Client`] instance
 /// and a [`ClientTestHarness`] to track it.
 pub fn build() -> ClientTestHarnessBuilder {
- ClientTestHarnessBuilder { version: None }
+ ClientTestHarnessBuilder {
+ version: None,
+ connection_task: None,
+ heartbeat_task: None,
+ }
 }
 
 /// Gets the peer protocol version associated to the [`Client`].
@@ -109,6 +124,22 @@ impl ClientTestHarness {
 .try_update_error(error.into())
 .expect("unexpected earlier error in error slot")
 }
+
+ /// Stops the mock background task that handles incoming remote requests and replies.
+ pub async fn stop_connection_task(&self) {
+ self.connection_aborter.abort();
+
+ // Allow the task to detect that it was aborted.
+ tokio::task::yield_now().await;
+ }
+
+ /// Stops the mock background task that sends periodic heartbeats.
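+ ///
+ /// The `Client` under test should then detect the stopped task and fail with a
+ /// `HeartbeatTaskExited` error.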
+ pub async fn stop_heartbeat_task(&self) {
+ self.heartbeat_aborter.abort();
+
+ // Allow the task to detect that it was aborted.
+ tokio::task::yield_now().await;
+ }
 }
 
 /// The result of an attempt to receive a [`ClientRequest`] sent by the [`Client`] instance.
@@ -152,17 +183,47 @@ impl ReceiveRequestAttempt {
 /// Mocked data is used to construct a real [`Client`] instance. The mocked data is initialized by
 /// the [`ClientTestHarnessBuilder`], and can be accessed and changed through the
 /// [`ClientTestHarness`].
-pub struct ClientTestHarnessBuilder {
+pub struct ClientTestHarnessBuilder<C = future::Ready<()>, H = future::Ready<()>> {
+ connection_task: Option<C>,
+ heartbeat_task: Option<H>,
 version: Option<Version>,
 }
 
-impl ClientTestHarnessBuilder {
+impl<C, H> ClientTestHarnessBuilder<C, H>
+where
+ C: Future<Output = ()> + Send + 'static,
+ H: Future<Output = ()> + Send + 'static,
+{
 /// Configure the mocked version for the peer.
 pub fn with_version(mut self, version: Version) -> Self {
 self.version = Some(version);
 self
 }
 
+ /// Configure the mock connection task future to use.
+ pub fn with_connection_task<NewC>(
+ self,
+ connection_task: NewC,
+ ) -> ClientTestHarnessBuilder<NewC, H> {
+ ClientTestHarnessBuilder {
+ connection_task: Some(connection_task),
+ heartbeat_task: self.heartbeat_task,
+ version: self.version,
+ }
+ }
+
+ /// Configure the mock heartbeat task future to use.
+ pub fn with_heartbeat_task<NewH>(
+ self,
+ heartbeat_task: NewH,
+ ) -> ClientTestHarnessBuilder<C, NewH> {
+ ClientTestHarnessBuilder {
+ connection_task: self.connection_task,
+ heartbeat_task: Some(heartbeat_task),
+ version: self.version,
+ }
+ }
+
 /// Build a [`Client`] instance with the mocked data and a [`ClientTestHarness`] to track it.
 pub fn finish(self) -> (Client, ClientTestHarness) {
 let (shutdown_sender, shutdown_receiver) = oneshot::channel();
@@ -170,11 +231,18 @@ impl ClientTestHarnessBuilder {
 let error_slot = ErrorSlot::default();
 let version = self.version.unwrap_or(Version(0));
 
+ let (connection_task, connection_aborter) =
+ Self::spawn_background_task_or_fallback(self.connection_task);
+ let (heartbeat_task, heartbeat_aborter) =
+ Self::spawn_background_task_or_fallback(self.heartbeat_task);
+
 let client = Client {
 shutdown_tx: Some(shutdown_sender),
 server_tx: client_request_sender,
 error_slot: error_slot.clone(),
 version,
+ connection_task,
+ heartbeat_task,
 };
 
 let harness = ClientTestHarness {
@@ -182,8 +250,35 @@ impl ClientTestHarnessBuilder {
 shutdown_receiver: Some(shutdown_receiver),
 error_slot,
 version,
+ connection_aborter,
+ heartbeat_aborter,
 };
 
 (client, harness)
 }
+
+ /// Spawn a mock background abortable task `task_future` if provided, or a fallback task
+ /// otherwise.
+ ///
+ /// The fallback task lives as long as [`MAX_PEER_CONNECTION_TIME`].
+ fn spawn_background_task_or_fallback<T>(task_future: Option<T>) -> (JoinHandle<()>, AbortHandle)
+ where
+ T: Future<Output = ()> + Send + 'static,
+ {
+ match task_future {
+ Some(future) => Self::spawn_background_task(future),
+ None => Self::spawn_background_task(tokio::time::sleep(MAX_PEER_CONNECTION_TIME)),
+ }
+ }
+
+ /// Spawn a mock background abortable task to run `task_future`.
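+ ///
+ /// Returns a [`JoinHandle`] for the spawned task, and an [`AbortHandle`] that the harness
+ /// can use to stop it.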
+ fn spawn_background_task<T>(task_future: T) -> (JoinHandle<()>, AbortHandle)
+ where
+ T: Future<Output = ()> + Send + 'static,
+ {
+ let (task, abort_handle) = future::abortable(task_future);
+ let task_handle = tokio::spawn(task.map(|_result| ()));
+
+ (task_handle, abort_handle)
+ }
 }
diff --git a/zebra-network/src/peer/client/tests/vectors.rs b/zebra-network/src/peer/client/tests/vectors.rs
index 5a331973e..5130699b0 100644
--- a/zebra-network/src/peer/client/tests/vectors.rs
+++ b/zebra-network/src/peer/client/tests/vectors.rs
@@ -1,5 +1,6 @@
 //! Fixed peer [`Client`] test vectors.
 
+use futures::poll;
 use tower::ServiceExt;
 
 use zebra_test::service_extensions::IsReady;
@@ -150,3 +151,69 @@ async fn client_service_drop_cleanup() {
 assert!(!harness.wants_connection_heartbeats());
 assert!(harness.try_to_receive_outbound_client_request().is_closed());
 }
+
+/// Force the connection background task to stop, and check if the `Client` properly handles it.
+#[tokio::test]
+async fn client_service_handles_exited_connection_task() {
+ zebra_test::init();
+
+ let (mut client, mut harness) = ClientTestHarness::build().finish();
+
+ harness.stop_connection_task().await;
+
+ assert!(client.is_failed().await);
+ assert!(harness.current_error().is_some());
+ assert!(!harness.wants_connection_heartbeats());
+ assert!(harness.try_to_receive_outbound_client_request().is_closed());
+}
+
+/// Force the connection background task to panic, and check if the `Client` propagates it.
+#[tokio::test]
+#[should_panic]
+async fn client_service_propagates_panic_from_connection_task() {
+ zebra_test::init();
+
+ let (mut client, _harness) = ClientTestHarness::build()
+ .with_connection_task(async move {
+ panic!("connection task failure");
+ })
+ .finish();
+
+ // Allow the custom connection task to run.
+ tokio::task::yield_now().await;
+
+ let _ = poll!(client.ready());
+}
+
+/// Force the heartbeat background task to stop, and check if the `Client` properly handles it.
+#[tokio::test]
+async fn client_service_handles_exited_heartbeat_task() {
+ zebra_test::init();
+
+ let (mut client, mut harness) = ClientTestHarness::build().finish();
+
+ harness.stop_heartbeat_task().await;
+
+ assert!(client.is_failed().await);
+ assert!(harness.current_error().is_some());
+ assert!(!harness.wants_connection_heartbeats());
+ assert!(harness.try_to_receive_outbound_client_request().is_closed());
+}
+
+/// Force the heartbeat background task to panic, and check if the `Client` propagates it.
+#[tokio::test]
+#[should_panic]
+async fn client_service_propagates_panic_from_heartbeat_task() {
+ zebra_test::init();
+
+ let (mut client, _harness) = ClientTestHarness::build()
+ .with_heartbeat_task(async move {
+ panic!("heartbeat task failure");
+ })
+ .finish();
+
+ // Allow the custom heartbeat task to run.
+ tokio::task::yield_now().await;
+
+ let _ = poll!(client.ready());
+}
diff --git a/zebra-network/src/peer/error.rs b/zebra-network/src/peer/error.rs
index 968b4a4c0..68b9a8011 100644
--- a/zebra-network/src/peer/error.rs
+++ b/zebra-network/src/peer/error.rs
@@ -37,6 +37,10 @@ pub enum PeerError {
 #[error("Internal client dropped")]
 ClientDropped,
 
+ /// A [`Client`]'s internal connection task exited.
+ #[error("Internal peer connection task exited")]
+ ConnectionTaskExited,
+
 /// Zebra's internal heartbeat task exited.
#[error("Internal heartbeat task exited")] HeartbeatTaskExited, @@ -72,6 +76,7 @@ impl PeerError { PeerError::ConnectionDropped => "ConnectionDropped".into(), PeerError::ClientDropped => "ClientDropped".into(), PeerError::HeartbeatTaskExited => "HeartbeatTaskExited".into(), + PeerError::ConnectionTaskExited => "ConnectionTaskExited".into(), PeerError::ClientRequestTimeout => "ClientRequestTimeout".into(), // TODO: add error kinds or summaries to `SerializationError` PeerError::Serialization(inner) => format!("Serialization({})", inner).into(), diff --git a/zebra-network/src/peer/handshake.rs b/zebra-network/src/peer/handshake.rs index f633314e2..30c398973 100644 --- a/zebra-network/src/peer/handshake.rs +++ b/zebra-network/src/peer/handshake.rs @@ -28,7 +28,8 @@ use crate::{ constants, meta_addr::MetaAddrChange, peer::{ - Client, ClientRequest, Connection, ErrorSlot, HandshakeError, MinimumPeerVersion, PeerError, + CancelHeartbeatTask, Client, ClientRequest, Connection, ErrorSlot, HandshakeError, + MinimumPeerVersion, PeerError, }, peer_set::ConnectionTracker, protocol::{ @@ -788,13 +789,6 @@ where let (shutdown_tx, shutdown_rx) = oneshot::channel(); let error_slot = ErrorSlot::default(); - let client = Client { - shutdown_tx: Some(shutdown_tx), - server_tx: server_tx.clone(), - error_slot: error_slot.clone(), - version: remote_version, - }; - let (peer_tx, peer_rx) = peer_conn.split(); // Instrument the peer's rx and tx streams. @@ -918,92 +912,40 @@ where request_timer: None, svc: inbound_service, client_rx: server_rx.into(), - error_slot, + error_slot: error_slot.clone(), peer_tx, connection_tracker, metrics_label: connected_addr.get_transient_addr_label(), last_metrics_state: None, }; - tokio::spawn( + let connection_task = tokio::spawn( server .run(peer_rx) .instrument(connection_span.clone()) .boxed(), ); - // CORRECTNESS - // - // To prevent hangs: - // - every await that depends on the network must have a timeout (or interval) - // - every error/shutdown must update the address book state and return - // - // The address book state can be updated via `ClientRequest.tx`, or the - // heartbeat_ts_collector. - // - // Returning from the spawned closure terminates the connection's heartbeat task. - let heartbeat_span = tracing::debug_span!(parent: connection_span, "heartbeat"); - let heartbeat_ts_collector = address_book_updater.clone(); - tokio::spawn( - async move { - use futures::future::Either; - - let mut shutdown_rx = shutdown_rx; - let mut server_tx = server_tx; - let mut heartbeat_ts_collector = heartbeat_ts_collector.clone(); - let mut interval_stream = - IntervalStream::new(tokio::time::interval(constants::HEARTBEAT_INTERVAL)); - - loop { - let shutdown_rx_ref = Pin::new(&mut shutdown_rx); - - // CORRECTNESS - // - // Currently, select prefers the first future if multiple - // futures are ready. - // - // Starvation is impossible here, because interval has a - // slow rate, and shutdown is a oneshot. If both futures - // are ready, we want the shutdown to take priority over - // sending a useless heartbeat. 
- if matches!(
- future::select(shutdown_rx_ref, interval_stream.next()).await,
- Either::Left(_)
- ) {
- tracing::trace!("shutting down due to Client shut down");
- if let Some(book_addr) = connected_addr.get_address_book_addr() {
- // awaiting a local task won't hang
- let _ = heartbeat_ts_collector
- .send(MetaAddr::new_shutdown(&book_addr, remote_services))
- .await;
- }
- return;
- }
-
- // We've reached another heartbeat interval without
- // shutting down, so do a heartbeat request.
- //
- // TODO: await heartbeat and shutdown. The select
- // function needs pinned types, but pinned generics
- // are hard (#1678)
- let heartbeat = send_one_heartbeat(&mut server_tx);
- if heartbeat_timeout(
- heartbeat,
- &mut heartbeat_ts_collector,
- &connected_addr,
- &remote_services,
- )
- .await
- .is_err()
- {
- return;
- }
- }
- }
- .instrument(heartbeat_span)
- .boxed(),
+ let heartbeat_task = tokio::spawn(
+ send_periodic_heartbeats(
+ connected_addr,
+ remote_services,
+ shutdown_rx,
+ server_tx.clone(),
+ address_book_updater.clone(),
+ )
+ .instrument(tracing::debug_span!(parent: connection_span, "heartbeat")),
 );
 
+ let client = Client {
+ shutdown_tx: Some(shutdown_tx),
+ server_tx,
+ error_slot,
+ version: remote_version,
+ connection_task,
+ heartbeat_task,
+ };
+
 Ok(client)
 };
 
@@ -1014,6 +956,76 @@ where
 }
 }
 
+/// Send periodic heartbeats to `server_tx`, and update the peer status through
+/// `heartbeat_ts_collector`.
+///
+/// # Correctness
+///
+/// To prevent hangs:
+/// - every await that depends on the network must have a timeout (or interval)
+/// - every error/shutdown must update the address book state and return
+///
+/// The address book state can be updated via `ClientRequest.tx`, or the
+/// `heartbeat_ts_collector`.
+///
+/// Returning from this function terminates the connection's heartbeat task.
+async fn send_periodic_heartbeats(
+ connected_addr: ConnectedAddr,
+ remote_services: PeerServices,
+ mut shutdown_rx: oneshot::Receiver<CancelHeartbeatTask>,
+ mut server_tx: futures::channel::mpsc::Sender<ClientRequest>,
+ mut heartbeat_ts_collector: tokio::sync::mpsc::Sender<MetaAddrChange>,
+) {
+ use futures::future::Either;
+
+ let mut interval_stream =
+ IntervalStream::new(tokio::time::interval(constants::HEARTBEAT_INTERVAL));
+
+ loop {
+ let shutdown_rx_ref = Pin::new(&mut shutdown_rx);
+
+ // CORRECTNESS
+ //
+ // Currently, select prefers the first future if multiple
+ // futures are ready.
+ //
+ // Starvation is impossible here, because interval has a
+ // slow rate, and shutdown is a oneshot. If both futures
+ // are ready, we want the shutdown to take priority over
+ // sending a useless heartbeat.
+ if matches!(
+ future::select(shutdown_rx_ref, interval_stream.next()).await,
+ Either::Left(_)
+ ) {
+ tracing::trace!("shutting down due to Client shut down");
+ if let Some(book_addr) = connected_addr.get_address_book_addr() {
+ // awaiting a local task won't hang
+ let _ = heartbeat_ts_collector
+ .send(MetaAddr::new_shutdown(&book_addr, remote_services))
+ .await;
+ }
+ return;
+ }
+
+ // We've reached another heartbeat interval without
+ // shutting down, so do a heartbeat request.
+ //
+ // TODO: await heartbeat and shutdown (#3254)
+ let heartbeat = send_one_heartbeat(&mut server_tx);
+ if heartbeat_timeout(
+ heartbeat,
+ &mut heartbeat_ts_collector,
+ &connected_addr,
+ &remote_services,
+ )
+ .await
+ .is_err()
+ {
+ return;
+ }
+ }
+}
+
 /// Send one heartbeat using `server_tx`.
 async fn send_one_heartbeat(
 server_tx: &mut futures::channel::mpsc::Sender<ClientRequest>,
diff --git a/zebra-network/src/peer_set/set/tests/prop.rs b/zebra-network/src/peer_set/set/tests/prop.rs
index 949d5858d..9b820c925 100644
--- a/zebra-network/src/peer_set/set/tests/prop.rs
+++ b/zebra-network/src/peer_set/set/tests/prop.rs
@@ -23,7 +23,6 @@ proptest! {
 ) {
 let runtime = zebra_test::init_async();
 
- let (discovered_peers, mut harnesses) = peer_versions.mock_peer_discovery();
 let (mut minimum_peer_version, best_tip_height) =
 MinimumPeerVersion::with_mock_chain_tip(network);
@@ -34,6 +33,7 @@
 let current_minimum_version = minimum_peer_version.current();
 
 runtime.block_on(async move {
+ let (discovered_peers, mut harnesses) = peer_versions.mock_peer_discovery();
 let (mut peer_set, _peer_set_guard) = PeerSetBuilder::new()
 .with_discover(discovered_peers)
 .with_minimum_peer_version(minimum_peer_version)
@@ -57,7 +57,6 @@
 ) {
 let runtime = zebra_test::init_async();
 
- let (discovered_peers, mut harnesses) = peer_versions.mock_peer_discovery();
 let (mut minimum_peer_version, best_tip_height) =
 MinimumPeerVersion::with_mock_chain_tip(block_heights.network);
@@ -66,6 +65,7 @@
 .expect("receiving endpoint lives as long as `minimum_peer_version`");
 
 runtime.block_on(async move {
+ let (discovered_peers, mut harnesses) = peer_versions.mock_peer_discovery();
 let (mut peer_set, _peer_set_guard) = PeerSetBuilder::new()
 .with_discover(discovered_peers)
 .with_minimum_peer_version(minimum_peer_version.clone())
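
A minimal usage sketch of the new builder hooks (not from the diff; the test name is illustrative). It assumes a test in the same `tests` module as `ClientTestHarness`, with `zebra_test::service_extensions::IsReady` in scope, and uses `future::pending()` and `future::ready(())` as stand-in connection and heartbeat tasks:

    /// Check that a heartbeat task that exits cleanly fails the `Client`.
    #[tokio::test]
    async fn heartbeat_task_exit_sketch() {
        zebra_test::init();

        // A connection task that never finishes, and a heartbeat task that exits immediately.
        let (mut client, harness) = ClientTestHarness::build()
            .with_connection_task(future::pending::<()>())
            .with_heartbeat_task(future::ready(()))
            .finish();

        // Let the mock heartbeat task run to completion on the test runtime.
        tokio::task::yield_now().await;

        // The clean exit should surface as a `HeartbeatTaskExited` client error.
        assert!(client.is_failed().await);
        assert!(harness.current_error().is_some());
    }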