diff --git a/zebra-network/src/peer/client.rs b/zebra-network/src/peer/client.rs index 93ac3b187..6e55775b9 100644 --- a/zebra-network/src/peer/client.rs +++ b/zebra-network/src/peer/client.rs @@ -11,13 +11,19 @@ use futures::{ }; use tower::Service; -use crate::protocol::{ - external::types::Version, - internal::{Request, Response}, +use crate::{ + peer::error::AlreadyErrored, + protocol::{ + external::types::Version, + internal::{Request, Response}, + }, }; use super::{ErrorSlot, PeerError, SharedPeerError}; +#[cfg(test)] +mod tests; + /// The "client" duplex half of a peer connection. pub struct Client { /// Used to shut down the corresponding heartbeat. @@ -68,8 +74,6 @@ pub(super) struct ClientRequestReceiver { /// A message from the `peer::Client` to the `peer::Server`, /// after it has been received by the `peer::Server`. -/// -/// #[derive(Debug)] #[must_use = "tx.send() must be called before drop"] pub(super) struct InProgressClientRequest { @@ -129,10 +133,29 @@ impl From for InProgressClientRequest { } impl ClientRequestReceiver { - /// Forwards to `inner.close()` + /// Forwards to `inner.close()`. pub fn close(&mut self) { self.inner.close() } + + /// Closes `inner`, then gets the next pending [`Request`]. + /// + /// Closing the channel ensures that: + /// - the request stream terminates, and + /// - task notifications are not required. + pub fn close_and_flush_next(&mut self) -> Option { + self.inner.close(); + + // # Correctness + // + // The request stream terminates, because the sender is closed, + // and the channel has a limited capacity. + // Task notifications are not required, because the sender is closed. + self.inner + .try_next() + .expect("channel is closed") + .map(Into::into) + } } impl Stream for ClientRequestReceiver { @@ -227,6 +250,62 @@ impl Drop for MustUseOneshotSender { } } +impl Client { + /// Check if this connection's heartbeat task has exited. + fn check_heartbeat(&mut self, cx: &mut Context<'_>) -> Result<(), SharedPeerError> { + if let Poll::Ready(()) = self + .shutdown_tx + .as_mut() + .expect("only taken on drop") + .poll_canceled(cx) + { + // Make sure there is an error in the slot + let heartbeat_error: SharedPeerError = PeerError::HeartbeatTaskExited.into(); + let original_error = self.error_slot.try_update_error(heartbeat_error.clone()); + debug!( + ?original_error, + latest_error = ?heartbeat_error, + "client heartbeat task exited" + ); + + if let Err(AlreadyErrored { original_error }) = original_error { + Err(original_error) + } else { + Err(heartbeat_error) + } + } else { + Ok(()) + } + } + + /// Poll for space in the shared request sender channel. + fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll> { + if ready!(self.server_tx.poll_ready(cx)).is_err() { + Poll::Ready(Err(self + .error_slot + .try_get_error() + .expect("failed servers must set their error slot"))) + } else if let Some(error) = self.error_slot.try_get_error() { + Poll::Ready(Err(error)) + } else { + Poll::Ready(Ok(())) + } + } + + /// Shut down the resources held by the client half of this peer connection. + /// + /// Stops further requests to the remote peer, and stops the heartbeat task. + fn shutdown(&mut self) { + // Prevent any senders from sending more messages to this peer. 
+        self.server_tx.close_channel();
+
+        // Stop the heartbeat task
+        if let Some(shutdown_tx) = self.shutdown_tx.take() {
+            let _ = shutdown_tx.send(CancelHeartbeatTask);
+        }
+    }
+}
+
 impl Service<Request> for Client {
     type Response = Response;
     type Error = SharedPeerError;
@@ -234,24 +313,27 @@ impl Service<Request> for Client {
         Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
-        // CORRECTNESS
+        // # Correctness
         //
         // The current task must be scheduled for wakeup every time we return
         // `Poll::Pending`.
         //
+        // `poll_canceled` schedules the client task for wakeup
+        // if the heartbeat task exits and drops the cancel handle.
+        //
         //`ready!` returns `Poll::Pending` when `server_tx` is unready, and
         // schedules this task for wakeup.
-        //
-        // Since `shutdown_tx` is used for oneshot communication to the heartbeat
-        // task, it will never be `Pending`.
-        //
-        // TODO: should the Client exit if the heartbeat task exits and drops
-        // `shutdown_tx`?
-        if ready!(self.server_tx.poll_ready(cx)).is_err() {
-            Poll::Ready(Err(self
-                .error_slot
-                .try_get_error()
-                .expect("failed servers must set their error slot")))
+
+        let mut result = self.check_heartbeat(cx);
+
+        if result.is_ok() {
+            result = ready!(self.poll_request(cx));
+        }
+
+        if let Err(error) = result {
+            self.shutdown();
+
+            Poll::Ready(Err(error))
         } else {
             Poll::Ready(Ok(()))
         }
@@ -297,10 +379,15 @@ impl Service<Request> for Client {

 impl Drop for Client {
     fn drop(&mut self) {
-        let _ = self
-            .shutdown_tx
-            .take()
-            .expect("must not drop twice")
-            .send(CancelHeartbeatTask);
+        // Make sure there is an error in the slot
+        let drop_error: SharedPeerError = PeerError::ClientDropped.into();
+        let original_error = self.error_slot.try_update_error(drop_error.clone());
+        debug!(
+            ?original_error,
+            latest_error = ?drop_error,
+            "client struct dropped"
+        );
+
+        self.shutdown();
     }
 }
diff --git a/zebra-network/src/peer/client/tests.rs b/zebra-network/src/peer/client/tests.rs
new file mode 100644
index 000000000..78babc731
--- /dev/null
+++ b/zebra-network/src/peer/client/tests.rs
@@ -0,0 +1,3 @@
+//! Tests for the [`Client`] part of peer connections
+
+mod vectors;
diff --git a/zebra-network/src/peer/client/tests/vectors.rs b/zebra-network/src/peer/client/tests/vectors.rs
new file mode 100644
index 000000000..3d18d790e
--- /dev/null
+++ b/zebra-network/src/peer/client/tests/vectors.rs
@@ -0,0 +1,233 @@
+//! Fixed peer [`Client`] test vectors.
+
+use futures::{
+    channel::{mpsc, oneshot},
+    FutureExt,
+};
+use tower::ServiceExt;
+
+use crate::{
+    peer::{CancelHeartbeatTask, Client, ErrorSlot},
+    protocol::external::types::Version,
+    PeerError,
+};
+
+#[tokio::test]
+async fn client_service_ready_ok() {
+    zebra_test::init();
+
+    let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
+    let (server_tx, mut server_rx) = mpsc::channel(1);
+
+    let shared_error_slot = ErrorSlot::default();
+
+    let mut client = Client {
+        shutdown_tx: Some(shutdown_tx),
+        server_tx,
+        error_slot: shared_error_slot.clone(),
+        version: Version(0),
+    };
+
+    let result = client.ready().now_or_never();
+    assert!(matches!(result, Some(Ok(Client { ..
})))); + + let error = shared_error_slot.try_get_error(); + assert!(matches!(error, None)); + + let result = shutdown_rx.try_recv(); + assert!(matches!(result, Ok(None))); + + // Unlike oneshots, open futures::mpsc channels return Err when empty + let result = server_rx.try_next(); + assert!(matches!(result, Err(_))); +} + +#[tokio::test] +async fn client_service_ready_heartbeat_exit() { + zebra_test::init(); + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let (server_tx, mut server_rx) = mpsc::channel(1); + + let shared_error_slot = ErrorSlot::default(); + + let mut client = Client { + shutdown_tx: Some(shutdown_tx), + server_tx, + error_slot: shared_error_slot.clone(), + version: Version(0), + }; + + shared_error_slot + .try_update_error(PeerError::HeartbeatTaskExited.into()) + .expect("unexpected earlier error in tests"); + std::mem::drop(shutdown_rx); + + let result = client.ready().now_or_never(); + assert!(matches!(result, Some(Err(_)))); + + let error = shared_error_slot.try_get_error(); + assert!(matches!(error, Some(_))); + + // Unlike oneshots, closed futures::mpsc channels return None + let result = server_rx.try_next(); + assert!(matches!(result, Ok(None))); +} + +#[tokio::test] +async fn client_service_ready_request_drop() { + zebra_test::init(); + + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); + let (server_tx, server_rx) = mpsc::channel(1); + + let shared_error_slot = ErrorSlot::default(); + + let mut client = Client { + shutdown_tx: Some(shutdown_tx), + server_tx, + error_slot: shared_error_slot.clone(), + version: Version(0), + }; + + shared_error_slot + .try_update_error(PeerError::ConnectionDropped.into()) + .expect("unexpected earlier error in tests"); + std::mem::drop(server_rx); + + let result = client.ready().now_or_never(); + assert!(matches!(result, Some(Err(_)))); + + let error = shared_error_slot.try_get_error(); + assert!(matches!(error, Some(_))); + + let result = shutdown_rx.try_recv(); + assert!(matches!(result, Ok(Some(CancelHeartbeatTask)))); +} + +#[tokio::test] +async fn client_service_ready_request_close() { + zebra_test::init(); + + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); + let (server_tx, mut server_rx) = mpsc::channel(1); + + let shared_error_slot = ErrorSlot::default(); + + let mut client = Client { + shutdown_tx: Some(shutdown_tx), + server_tx, + error_slot: shared_error_slot.clone(), + version: Version(0), + }; + + shared_error_slot + .try_update_error(PeerError::ConnectionClosed.into()) + .expect("unexpected earlier error in tests"); + server_rx.close(); + + let result = client.ready().now_or_never(); + assert!(matches!(result, Some(Err(_)))); + + let error = shared_error_slot.try_get_error(); + assert!(matches!(error, Some(_))); + + let result = shutdown_rx.try_recv(); + assert!(matches!(result, Ok(Some(CancelHeartbeatTask)))); + + let result = server_rx.try_next(); + assert!(matches!(result, Ok(None))); +} + +#[tokio::test] +async fn client_service_ready_error_in_slot() { + zebra_test::init(); + + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); + let (server_tx, mut server_rx) = mpsc::channel(1); + + let shared_error_slot = ErrorSlot::default(); + + let mut client = Client { + shutdown_tx: Some(shutdown_tx), + server_tx, + error_slot: shared_error_slot.clone(), + version: Version(0), + }; + + shared_error_slot + .try_update_error(PeerError::Overloaded.into()) + .expect("unexpected earlier error in tests"); + + let result = client.ready().now_or_never(); + assert!(matches!(result, Some(Err(_)))); + 
+ let error = shared_error_slot.try_get_error(); + assert!(matches!(error, Some(_))); + + let result = shutdown_rx.try_recv(); + assert!(matches!(result, Ok(Some(CancelHeartbeatTask)))); + + let result = server_rx.try_next(); + assert!(matches!(result, Ok(None))); +} + +#[tokio::test] +async fn client_service_ready_multiple_errors() { + zebra_test::init(); + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let (server_tx, mut server_rx) = mpsc::channel(1); + + let shared_error_slot = ErrorSlot::default(); + + let mut client = Client { + shutdown_tx: Some(shutdown_tx), + server_tx, + error_slot: shared_error_slot.clone(), + version: Version(0), + }; + + shared_error_slot + .try_update_error(PeerError::DuplicateHandshake.into()) + .expect("unexpected earlier error in tests"); + std::mem::drop(shutdown_rx); + server_rx.close(); + + let result = client.ready().now_or_never(); + assert!(matches!(result, Some(Err(_)))); + + let error = shared_error_slot.try_get_error(); + assert!(matches!(error, Some(_))); + + let result = server_rx.try_next(); + assert!(matches!(result, Ok(None))); +} + +#[tokio::test] +async fn client_service_drop_cleanup() { + zebra_test::init(); + + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); + let (server_tx, mut server_rx) = mpsc::channel(1); + + let shared_error_slot = ErrorSlot::default(); + + let client = Client { + shutdown_tx: Some(shutdown_tx), + server_tx, + error_slot: shared_error_slot.clone(), + version: Version(0), + }; + + std::mem::drop(client); + + let error = shared_error_slot.try_get_error(); + assert!(matches!(error, Some(_))); + + let result = shutdown_rx.try_recv(); + assert!(matches!(result, Ok(Some(CancelHeartbeatTask)))); + + let result = server_rx.try_next(); + assert!(matches!(result, Ok(None))); +} diff --git a/zebra-network/src/peer/connection.rs b/zebra-network/src/peer/connection.rs index f5e29ad15..26305ceff 100644 --- a/zebra-network/src/peer/connection.rs +++ b/zebra-network/src/peer/connection.rs @@ -38,6 +38,9 @@ use crate::{ BoxError, }; +#[cfg(test)] +mod tests; + #[derive(Debug)] pub(super) enum Handler { /// Indicates that the handler has finished processing the request. @@ -420,9 +423,9 @@ pub struct Connection { /// The `inbound` service, used to answer requests from this connection's peer. pub(super) svc: S, - /// A channel that receives network requests from the rest of Zebra. + /// A channel for requests that Zebra's internal services want to send to remote peers. /// - /// This channel produces `InProgressClientRequest`s. + /// This channel accepts [`Request`]s, and produces [`InProgressClientRequest`]s. pub(super) client_rx: ClientRequestReceiver, /// A slot for an error shared between the Connection and the Client that uses it. @@ -430,7 +433,13 @@ pub struct Connection { /// `None` unless the connection or client have errored. pub(super) error_slot: ErrorSlot, - /// A channel for sending requests to the connected peer. + /// A channel for sending Zcash messages to the connected peer. + /// + /// This channel accepts [`Message`]s. + /// + /// The corresponding peer message receiver is passed to [`Connection::run`]. + /// + /// TODO: add a timeout when sending messages to the remote peer (#3234) pub(super) peer_tx: Tx, /// A connection tracker that reduces the open connection count when dropped. @@ -442,8 +451,7 @@ pub struct Connection { /// /// If this connection tracker or `Connection`s are leaked, /// the number of active connections will appear higher than it actually is. 
- /// - /// Eventually, Zebra could stop making connections entirely. + /// If enough connections leak, Zebra will stop making new connections. #[allow(dead_code)] pub(super) connection_tracker: ConnectionTracker, @@ -461,6 +469,9 @@ where Tx: Sink + Unpin, { /// Consume this `Connection` to form a spawnable future containing its event loop. + /// + /// `peer_rx` is a channel for receiving Zcash [`Message`]s from the connected peer. + /// The corresponding peer message receiver is [`Connection.peer_tx`]. pub async fn run(mut self, mut peer_rx: Rx) where Rx: Stream> + Unpin, @@ -484,6 +495,8 @@ where // // If there is a pending request, we wait only on an incoming peer message, and // check whether it can be interpreted as a response to the pending request. + // + // TODO: turn this comment into a module-level comment, after splitting the module. loop { self.update_state_metrics(None); @@ -516,7 +529,11 @@ where } Either::Right((None, _)) => { trace!("client_rx closed, ending connection"); - return; + + // There are no requests to be flushed, + // but we need to set an error and update metrics. + self.shutdown(PeerError::ClientDropped); + break; } Either::Right((Some(req), _)) => { let span = req.span.clone(); @@ -646,6 +663,8 @@ where tx, .. } => { + // We replaced the original state, which means `fail_with` won't see it. + // So we do the state request cleanup manually. let e = SharedPeerError::from(e); let _ = tx.send(Err(e.clone())); self.fail_with(e); @@ -663,107 +682,35 @@ where }; } Either::Left((Either::Left(_), _peer_fut)) => { + // The client receiver was dropped, so we don't need to send on `tx` here. trace!(parent: &span, "client request was cancelled"); self.state = State::AwaitingRequest; } } } - // We've failed, but we need to flush all pending client - // requests before we can return and complete the future. - State::Failed => { - match self.client_rx.next().await { - Some(InProgressClientRequest { tx, span, .. }) => { - trace!( - parent: &span, - "sending an error response to a pending request on a failed connection" - ); - // Correctness - // - // Error slots use a threaded `std::sync::Mutex`, so - // accessing the slot can block the async task's - // current thread. So we only hold the lock for long - // enough to get a reference to the error. - let e = self - .error_slot - .try_get_error() - .expect("cannot enter failed state without setting error slot"); - let _ = tx.send(Err(e)); - // Continue until we've errored all queued reqs - continue; - } - None => return, - } - } + // This connection has failed: stop the event loop, and complete the future. + State::Failed => break, } } + + assert!( + self.error_slot.try_get_error().is_some(), + "closing connections must call fail_with() or shutdown() to set the error slot" + ); } - /// Marks the peer as having failed with error `e`. + /// Fail this connection. /// - /// # Panics - /// - /// If `self` has already failed with a previous error. - fn fail_with(&mut self, e: E) - where - E: Into, - { - let e = e.into(); - debug!(%e, - connection_state = ?self.state, + /// If the connection has errored already, re-use the original error. + /// Otherwise, fail the connection with `error`. + fn fail_with(&mut self, error: impl Into) { + let error = error.into(); + + debug!(%error, client_receiver = ?self.client_rx, "failing peer service with error"); - // Update the shared error slot - // - // # Correctness - // - // Error slots use a threaded `std::sync::Mutex`, so accessing the slot - // can block the async task's current thread. 
We only perform a single - // slot update per `Client`, and panic to enforce this constraint. - // - // This assertion typically fails due to these bugs: - // * we mark a connection as failed without using fail_with - // * we call fail_with without checking for a failed connection - // state - // * we continue processing messages after calling fail_with - // - // See the original bug #1510 and PR #1531, and the later bug #1599 - // and PR #1600. - let error_result = self.error_slot.try_update_error(e.clone()); - - if let Err(AlreadyErrored { original_error }) = error_result { - panic!( - "multiple failures for connection: \n\ - failed connections should stop processing pending requests and responses, \n\ - then close the connection. \n\ - state: {:?} \n\ - client receiver: {:?} \n\ - original error: {:?} \n\ - new error: {:?}", - self.state, self.client_rx, original_error, e, - ); - } - - // We want to close the client channel and set State::Failed so - // that we can flush any pending client requests. However, we may have - // an outstanding client request in State::AwaitingResponse, so - // we need to deal with it first if it exists. - self.client_rx.close(); - let old_state = std::mem::replace(&mut self.state, State::Failed); - self.update_state_metrics(None); - - if let State::AwaitingResponse { tx, .. } = old_state { - // # Correctness - // - // We know the slot has Some(e) because we just set it above, - // and the error slot is never unset. - // - // Accessing the error slot locks a threaded std::sync::Mutex, which - // can block the current async task thread. We briefly lock the mutex - // to get a reference to the error. - let e = self.error_slot.try_get_error().unwrap(); - let _ = tx.send(Err(e)); - } + self.shutdown(error); } /// Handle an incoming client request, possibly generating outgoing messages to the @@ -1273,19 +1220,89 @@ impl Connection { ); } } + + /// Marks the peer as having failed with `error`, and performs connection cleanup. + /// + /// If the connection has errored already, re-use the original error. + /// Otherwise, fail the connection with `error`. + fn shutdown(&mut self, error: impl Into) { + let mut error = error.into(); + + // Close channels first, so other tasks can start shutting down. + // + // TODO: close peer_tx and peer_rx, after: + // - adapting them using a struct with a Stream impl, rather than closures + // - making the struct forward `close` to the inner channel + self.client_rx.close(); + + // Update the shared error slot + // + // # Correctness + // + // Error slots use a threaded `std::sync::Mutex`, so accessing the slot + // can block the async task's current thread. We only perform a single + // slot update per `Client`. We ignore subsequent error slot updates. + let slot_result = self.error_slot.try_update_error(error.clone()); + + if let Err(AlreadyErrored { original_error }) = slot_result { + debug!( + new_error = %error, + %original_error, + connection_state = ?self.state, + "multiple errors on connection: \ + failed connections should stop processing pending requests and responses, \ + then close the connection" + ); + + error = original_error; + } else { + debug!(%error, + connection_state = ?self.state, + "shutting down peer service with error"); + } + + // Prepare to flush any pending client requests. + // + // We've already closed the client channel, so setting State::Failed + // will make the main loop flush any pending requests. 
+ // + // However, we may have an outstanding client request in State::AwaitingResponse, + // so we need to deal with it first. + if let State::AwaitingResponse { tx, .. } = + std::mem::replace(&mut self.state, State::Failed) + { + // # Correctness + // + // We know the slot has Some(error), because we just set it above, + // and the error slot is never unset. + // + // Accessing the error slot locks a threaded std::sync::Mutex, which + // can block the current async task thread. We briefly lock the mutex + // to clone the error. + let _ = tx.send(Err(error.clone())); + } + + // Make the timer and metrics consistent with the Failed state. + self.request_timer = None; + self.update_state_metrics(None); + + // Finally, flush pending client requests. + while let Some(InProgressClientRequest { tx, span, .. }) = + self.client_rx.close_and_flush_next() + { + trace!( + parent: &span, + %error, + "sending an error response to a pending request on a failed connection" + ); + let _ = tx.send(Err(error.clone())); + } + } } impl Drop for Connection { fn drop(&mut self) { - if let State::AwaitingResponse { tx, .. } = - std::mem::replace(&mut self.state, State::Failed) - { - if let Some(error) = self.error_slot.try_get_error() { - let _ = tx.send(Err(error)); - } else { - let _ = tx.send(Err(PeerError::ConnectionDropped.into())); - } - } + self.shutdown(PeerError::ConnectionDropped); self.erase_state_metrics(); } diff --git a/zebra-network/src/peer/connection/tests.rs b/zebra-network/src/peer/connection/tests.rs new file mode 100644 index 000000000..d82199e8e --- /dev/null +++ b/zebra-network/src/peer/connection/tests.rs @@ -0,0 +1,3 @@ +//! Tests for peer connections + +mod vectors; diff --git a/zebra-network/src/peer/connection/tests/vectors.rs b/zebra-network/src/peer/connection/tests/vectors.rs new file mode 100644 index 000000000..b98b3e45c --- /dev/null +++ b/zebra-network/src/peer/connection/tests/vectors.rs @@ -0,0 +1,410 @@ +//! Fixed test vectors for peer connections. +//! +//! TODO: +//! - connection tests when awaiting requests (#3232) +//! - connection tests with closed/dropped peer_outbound_tx (#3233) + +use futures::{channel::mpsc, FutureExt}; +use tokio_util::codec::FramedWrite; +use tower::service_fn; +use zebra_chain::parameters::Network; + +use crate::{ + peer::{client::ClientRequestReceiver, connection::State, Connection, ErrorSlot}, + peer_set::ActiveConnectionCounter, + protocol::external::Codec, + PeerError, +}; + +#[tokio::test] +async fn connection_run_loop_ok() { + zebra_test::init(); + + let (client_tx, client_rx) = mpsc::channel(1); + + // The real stream and sink are from a split TCP connection, + // but that doesn't change how the state machine behaves. 
+ let (peer_inbound_tx, peer_inbound_rx) = mpsc::channel(1); + + let mut peer_outbound_bytes = Vec::::new(); + let peer_outbound_tx = FramedWrite::new( + &mut peer_outbound_bytes, + Codec::builder() + .for_network(Network::Mainnet) + .with_metrics_addr_label("test".into()) + .finish(), + ); + + let unused_inbound_service = + service_fn(|_| async { unreachable!("inbound service should never be called") }); + + let shared_error_slot = ErrorSlot::default(); + + let connection = Connection { + state: State::AwaitingRequest, + request_timer: None, + svc: unused_inbound_service, + client_rx: ClientRequestReceiver::from(client_rx), + error_slot: shared_error_slot.clone(), + peer_tx: peer_outbound_tx, + connection_tracker: ActiveConnectionCounter::new_counter().track_connection(), + metrics_label: "test".to_string(), + last_metrics_state: None, + }; + + let connection = connection.run(peer_inbound_rx); + + // The run loop will wait forever for a request from Zebra or the peer, + // without any errors, channel closes, or bytes written. + // + // But the connection closes if we drop the future, so we avoid the drop by cloning it. + let connection = connection.shared(); + let connection_guard = connection.clone(); + let result = connection.now_or_never(); + assert_eq!(result, None); + + let error = shared_error_slot.try_get_error(); + assert!( + matches!(error, None), + "unexpected connection error: {:?}", + error + ); + + assert!(!client_tx.is_closed()); + assert!(!peer_inbound_tx.is_closed()); + + // We need to drop the future, because it holds a mutable reference to the bytes. + std::mem::drop(connection_guard); + assert_eq!(peer_outbound_bytes, Vec::::new()); +} + +#[tokio::test] +async fn connection_run_loop_future_drop() { + zebra_test::init(); + + let (client_tx, client_rx) = mpsc::channel(1); + + let (peer_inbound_tx, peer_inbound_rx) = mpsc::channel(1); + + let mut peer_outbound_bytes = Vec::::new(); + let peer_outbound_tx = FramedWrite::new( + &mut peer_outbound_bytes, + Codec::builder() + .for_network(Network::Mainnet) + .with_metrics_addr_label("test".into()) + .finish(), + ); + + let unused_inbound_service = + service_fn(|_| async { unreachable!("inbound service should never be called") }); + + let shared_error_slot = ErrorSlot::default(); + + let connection = Connection { + state: State::AwaitingRequest, + request_timer: None, + svc: unused_inbound_service, + client_rx: ClientRequestReceiver::from(client_rx), + error_slot: shared_error_slot.clone(), + peer_tx: peer_outbound_tx, + connection_tracker: ActiveConnectionCounter::new_counter().track_connection(), + metrics_label: "test".to_string(), + last_metrics_state: None, + }; + + let connection = connection.run(peer_inbound_rx); + + // now_or_never implicitly drops the connection future. 
+ let result = connection.now_or_never(); + assert_eq!(result, None); + + let error = shared_error_slot.try_get_error(); + assert!(matches!(error, Some(_))); + + assert!(client_tx.is_closed()); + assert!(peer_inbound_tx.is_closed()); + + assert_eq!(peer_outbound_bytes, Vec::::new()); +} + +#[tokio::test] +async fn connection_run_loop_client_close() { + zebra_test::init(); + + let (mut client_tx, client_rx) = mpsc::channel(1); + + let (peer_inbound_tx, peer_inbound_rx) = mpsc::channel(1); + + let mut peer_outbound_bytes = Vec::::new(); + let peer_outbound_tx = FramedWrite::new( + &mut peer_outbound_bytes, + Codec::builder() + .for_network(Network::Mainnet) + .with_metrics_addr_label("test".into()) + .finish(), + ); + + let unused_inbound_service = + service_fn(|_| async { unreachable!("inbound service should never be called") }); + + let shared_error_slot = ErrorSlot::default(); + + let connection = Connection { + state: State::AwaitingRequest, + request_timer: None, + svc: unused_inbound_service, + client_rx: ClientRequestReceiver::from(client_rx), + error_slot: shared_error_slot.clone(), + peer_tx: peer_outbound_tx, + connection_tracker: ActiveConnectionCounter::new_counter().track_connection(), + metrics_label: "test".to_string(), + last_metrics_state: None, + }; + + let connection = connection.run(peer_inbound_rx); + + // Explicitly close the client channel. + client_tx.close_channel(); + + // If we drop the future, the connection will close anyway, so we avoid the drop by cloning it. + let connection = connection.shared(); + let connection_guard = connection.clone(); + let result = connection.now_or_never(); + assert_eq!(result, Some(())); + + let error = shared_error_slot.try_get_error(); + assert!(matches!(error, Some(_))); + + assert!(client_tx.is_closed()); + assert!(peer_inbound_tx.is_closed()); + + // We need to drop the future, because it holds a mutable reference to the bytes. + std::mem::drop(connection_guard); + assert_eq!(peer_outbound_bytes, Vec::::new()); +} + +#[tokio::test] +async fn connection_run_loop_client_drop() { + zebra_test::init(); + + let (client_tx, client_rx) = mpsc::channel(1); + + let (peer_inbound_tx, peer_inbound_rx) = mpsc::channel(1); + + let mut peer_outbound_bytes = Vec::::new(); + let peer_outbound_tx = FramedWrite::new( + &mut peer_outbound_bytes, + Codec::builder() + .for_network(Network::Mainnet) + .with_metrics_addr_label("test".into()) + .finish(), + ); + + let unused_inbound_service = + service_fn(|_| async { unreachable!("inbound service should never be called") }); + + let shared_error_slot = ErrorSlot::default(); + + let connection = Connection { + state: State::AwaitingRequest, + request_timer: None, + svc: unused_inbound_service, + client_rx: ClientRequestReceiver::from(client_rx), + error_slot: shared_error_slot.clone(), + peer_tx: peer_outbound_tx, + connection_tracker: ActiveConnectionCounter::new_counter().track_connection(), + metrics_label: "test".to_string(), + last_metrics_state: None, + }; + + let connection = connection.run(peer_inbound_rx); + + // Drop the client channel. + std::mem::drop(client_tx); + + // If we drop the future, the connection will close anyway, so we avoid the drop by cloning it. 
+ let connection = connection.shared(); + let connection_guard = connection.clone(); + let result = connection.now_or_never(); + assert_eq!(result, Some(())); + + let error = shared_error_slot.try_get_error(); + assert!(matches!(error, Some(_))); + + assert!(peer_inbound_tx.is_closed()); + + // We need to drop the future, because it holds a mutable reference to the bytes. + std::mem::drop(connection_guard); + assert_eq!(peer_outbound_bytes, Vec::::new()); +} + +#[tokio::test] +async fn connection_run_loop_inbound_close() { + zebra_test::init(); + + let (client_tx, client_rx) = mpsc::channel(1); + + let (mut peer_inbound_tx, peer_inbound_rx) = mpsc::channel(1); + + let mut peer_outbound_bytes = Vec::::new(); + let peer_outbound_tx = FramedWrite::new( + &mut peer_outbound_bytes, + Codec::builder() + .for_network(Network::Mainnet) + .with_metrics_addr_label("test".into()) + .finish(), + ); + + let unused_inbound_service = + service_fn(|_| async { unreachable!("inbound service should never be called") }); + + let shared_error_slot = ErrorSlot::default(); + + let connection = Connection { + state: State::AwaitingRequest, + request_timer: None, + svc: unused_inbound_service, + client_rx: ClientRequestReceiver::from(client_rx), + error_slot: shared_error_slot.clone(), + peer_tx: peer_outbound_tx, + connection_tracker: ActiveConnectionCounter::new_counter().track_connection(), + metrics_label: "test".to_string(), + last_metrics_state: None, + }; + + let connection = connection.run(peer_inbound_rx); + + // Explicitly close the inbound peer channel. + peer_inbound_tx.close_channel(); + + // If we drop the future, the connection will close anyway, so we avoid the drop by cloning it. + let connection = connection.shared(); + let connection_guard = connection.clone(); + let result = connection.now_or_never(); + assert_eq!(result, Some(())); + + let error = shared_error_slot.try_get_error(); + assert!(matches!(error, Some(_))); + + assert!(client_tx.is_closed()); + assert!(peer_inbound_tx.is_closed()); + + // We need to drop the future, because it holds a mutable reference to the bytes. + std::mem::drop(connection_guard); + assert_eq!(peer_outbound_bytes, Vec::::new()); +} + +#[tokio::test] +async fn connection_run_loop_inbound_drop() { + zebra_test::init(); + + let (client_tx, client_rx) = mpsc::channel(1); + + let (peer_inbound_tx, peer_inbound_rx) = mpsc::channel(1); + + let mut peer_outbound_bytes = Vec::::new(); + let peer_outbound_tx = FramedWrite::new( + &mut peer_outbound_bytes, + Codec::builder() + .for_network(Network::Mainnet) + .with_metrics_addr_label("test".into()) + .finish(), + ); + + let unused_inbound_service = + service_fn(|_| async { unreachable!("inbound service should never be called") }); + + let shared_error_slot = ErrorSlot::default(); + + let connection = Connection { + state: State::AwaitingRequest, + request_timer: None, + svc: unused_inbound_service, + client_rx: ClientRequestReceiver::from(client_rx), + error_slot: shared_error_slot.clone(), + peer_tx: peer_outbound_tx, + connection_tracker: ActiveConnectionCounter::new_counter().track_connection(), + metrics_label: "test".to_string(), + last_metrics_state: None, + }; + + let connection = connection.run(peer_inbound_rx); + + // Drop the inbound peer channel. + std::mem::drop(peer_inbound_tx); + + // If we drop the future, the connection will close anyway, so we avoid the drop by cloning it. 
+ let connection = connection.shared(); + let connection_guard = connection.clone(); + let result = connection.now_or_never(); + assert_eq!(result, Some(())); + + let error = shared_error_slot.try_get_error(); + assert!(matches!(error, Some(_))); + + assert!(client_tx.is_closed()); + + // We need to drop the future, because it holds a mutable reference to the bytes. + std::mem::drop(connection_guard); + assert_eq!(peer_outbound_bytes, Vec::::new()); +} + +#[tokio::test] +async fn connection_run_loop_failed() { + zebra_test::init(); + + let (client_tx, client_rx) = mpsc::channel(1); + + let (peer_inbound_tx, peer_inbound_rx) = mpsc::channel(1); + + let mut peer_outbound_bytes = Vec::::new(); + let peer_outbound_tx = FramedWrite::new( + &mut peer_outbound_bytes, + Codec::builder() + .for_network(Network::Mainnet) + .with_metrics_addr_label("test".into()) + .finish(), + ); + + let unused_inbound_service = + service_fn(|_| async { unreachable!("inbound service should never be called") }); + + let shared_error_slot = ErrorSlot::default(); + + // Simulate an internal connection error. + shared_error_slot + .try_update_error(PeerError::ClientRequestTimeout.into()) + .expect("unexpected previous error in tests"); + + let connection = Connection { + state: State::Failed, + request_timer: None, + svc: unused_inbound_service, + client_rx: ClientRequestReceiver::from(client_rx), + error_slot: shared_error_slot.clone(), + peer_tx: peer_outbound_tx, + connection_tracker: ActiveConnectionCounter::new_counter().track_connection(), + metrics_label: "test".to_string(), + last_metrics_state: None, + }; + + let connection = connection.run(peer_inbound_rx); + + // If we drop the future, the connection will close anyway, so we avoid the drop by cloning it. + let connection = connection.shared(); + let connection_guard = connection.clone(); + let result = connection.now_or_never(); + // Because the peer error mutex is a sync mutex, + // the connection can't exit until it reaches the outer async loop. + assert_eq!(result, Some(())); + + let error = shared_error_slot.try_get_error(); + assert!(matches!(error, Some(_))); + + assert!(client_tx.is_closed()); + assert!(peer_inbound_tx.is_closed()); + + // We need to drop the future, because it holds a mutable reference to the bytes. + std::mem::drop(connection_guard); + assert_eq!(peer_outbound_bytes, Vec::::new()); +} diff --git a/zebra-network/src/peer/error.rs b/zebra-network/src/peer/error.rs index 2f4743b1a..21053660e 100644 --- a/zebra-network/src/peer/error.rs +++ b/zebra-network/src/peer/error.rs @@ -33,6 +33,14 @@ pub enum PeerError { #[error("Internal connection dropped")] ConnectionDropped, + /// Zebra dropped the [`Client`]. + #[error("Internal client dropped")] + ClientDropped, + + /// Zebra's internal heartbeat task exited. + #[error("Internal heartbeat task exited")] + HeartbeatTaskExited, + /// The remote peer did not respond to a [`peer::Client`] request in time. 
#[error("Client request timed out")] ClientRequestTimeout, @@ -62,6 +70,8 @@ impl PeerError { match self { PeerError::ConnectionClosed => "ConnectionClosed".into(), PeerError::ConnectionDropped => "ConnectionDropped".into(), + PeerError::ClientDropped => "ClientDropped".into(), + PeerError::HeartbeatTaskExited => "HeartbeatTaskExited".into(), PeerError::ClientRequestTimeout => "ClientRequestTimeout".into(), // TODO: add error kinds or summaries to `SerializationError` PeerError::Serialization(inner) => format!("Serialization({})", inner).into(), @@ -129,7 +139,7 @@ impl ErrorSlot { } } -/// Error used when the `ErrorSlot` already contains an error. +/// Error returned when the `ErrorSlot` already contains an error. #[derive(Clone, Debug)] pub struct AlreadyErrored { /// The original error in the error slot. diff --git a/zebra-network/src/peer/handshake.rs b/zebra-network/src/peer/handshake.rs index 17d784547..337b918aa 100644 --- a/zebra-network/src/peer/handshake.rs +++ b/zebra-network/src/peer/handshake.rs @@ -789,12 +789,12 @@ where // in this block, see constants.rs for more. let (server_tx, server_rx) = mpsc::channel(0); let (shutdown_tx, shutdown_rx) = oneshot::channel(); - let slot = ErrorSlot::default(); + let error_slot = ErrorSlot::default(); let client = Client { shutdown_tx: Some(shutdown_tx), server_tx: server_tx.clone(), - error_slot: slot.clone(), + error_slot: error_slot.clone(), version: remote_version, }; @@ -921,7 +921,7 @@ where request_timer: None, svc: inbound_service, client_rx: server_rx.into(), - error_slot: slot, + error_slot, peer_tx, connection_tracker, metrics_label: connected_addr.get_transient_addr_label(), diff --git a/zebra-network/src/peer_set/set.rs b/zebra-network/src/peer_set/set.rs index 9735be21a..6f6641a28 100644 --- a/zebra-network/src/peer_set/set.rs +++ b/zebra-network/src/peer_set/set.rs @@ -433,6 +433,14 @@ where "service was canceled, dropping service" ); } + Poll::Ready(Some(Err((key, UnreadyError::CancelHandleDropped(_))))) => { + // Similarly, services with dropped cancel handes can have duplicates. + trace!( + ?key, + duplicate_connection = self.cancel_handles.contains_key(&key), + "cancel handle was dropped, dropping service" + ); + } // Unready -> Errored Poll::Ready(Some(Err((key, UnreadyError::Inner(error))))) => { diff --git a/zebra-network/src/peer_set/unready_service.rs b/zebra-network/src/peer_set/unready_service.rs index 4881dcd84..97fc8d7e0 100644 --- a/zebra-network/src/peer_set/unready_service.rs +++ b/zebra-network/src/peer_set/unready_service.rs @@ -12,6 +12,9 @@ use tower::Service; use crate::peer_set::set::CancelClientWork; +#[cfg(test)] +mod tests; + /// A Future that becomes satisfied when an `S`-typed service is ready. /// /// May fail due to cancellation, i.e. if the service is removed from discovery. 
@@ -26,9 +29,11 @@ pub(super) struct UnreadyService<K, S, Req> {
     pub(super) _req: PhantomData<Req>,
 }

+#[derive(Debug, Eq, PartialEq)]
 pub(super) enum Error<E> {
     Inner(E),
     Canceled,
+    CancelHandleDropped(oneshot::Canceled),
 }

 impl<K, S: Service<Req>, Req> Future for UnreadyService<K, S, Req> {
@@ -37,12 +42,22 @@ impl<K, S: Service<Req>, Req> Future for UnreadyService<K, S, Req> {
     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
         let this = self.project();

-        if let Poll::Ready(Ok(CancelClientWork)) = this.cancel.poll(cx) {
+        if let Poll::Ready(oneshot_result) = this.cancel.poll(cx) {
             let key = this.key.take().expect("polled after ready");
-            return Poll::Ready(Err((key, Error::Canceled)));
+
+            // # Correctness
+            //
+            // Return an error if the service is explicitly canceled,
+            // or its cancel handle is dropped, implicitly cancelling it.
+            match oneshot_result {
+                Ok(CancelClientWork) => return Poll::Ready(Err((key, Error::Canceled))),
+                Err(canceled_error) => {
+                    return Poll::Ready(Err((key, Error::CancelHandleDropped(canceled_error))))
+                }
+            }
         }

-        // CORRECTNESS
+        // # Correctness
         //
         // The current task must be scheduled for wakeup every time we return
         // `Poll::Pending`.
@@ -54,7 +69,7 @@ impl<K, S: Service<Req>, Req> Future for UnreadyService<K, S, Req> {
         let res = ready!(this
             .service
             .as_mut()
-            .expect("poll after ready")
+            .expect("polled after ready")
             .poll_ready(cx));

         let key = this.key.take().expect("polled after ready");
diff --git a/zebra-network/src/peer_set/unready_service/tests.rs b/zebra-network/src/peer_set/unready_service/tests.rs
new file mode 100644
index 000000000..19e513250
--- /dev/null
+++ b/zebra-network/src/peer_set/unready_service/tests.rs
@@ -0,0 +1,3 @@
+//! Tests for unready services.
+
+mod vectors;
diff --git a/zebra-network/src/peer_set/unready_service/tests/vectors.rs b/zebra-network/src/peer_set/unready_service/tests/vectors.rs
new file mode 100644
index 000000000..7d89882d7
--- /dev/null
+++ b/zebra-network/src/peer_set/unready_service/tests/vectors.rs
@@ -0,0 +1,86 @@
+//! Fixed test vectors for unready services.
+//!
+//! TODO: test that inner service errors are handled correctly (#3204)
+
+use std::marker::PhantomData;
+
+use futures::channel::oneshot;
+
+use zebra_test::mock_service::MockService;
+
+use crate::{
+    peer_set::{
+        set::CancelClientWork,
+        unready_service::{Error, UnreadyService},
+    },
+    Request, Response, SharedPeerError,
+};
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
+struct MockKey;
+
+#[tokio::test]
+async fn unready_service_result_ok() {
+    zebra_test::init();
+
+    let (_cancel_sender, cancel) = oneshot::channel();
+
+    let mock_client: MockService> =
+        MockService::build().for_unit_tests();
+    let unready_service = UnreadyService {
+        key: Some(MockKey),
+        cancel,
+        service: Some(mock_client),
+        _req: PhantomData::default(),
+    };
+
+    let result = unready_service.await;
+    assert!(matches!(result, Ok((MockKey, MockService { .. }))));
+}
+
+#[tokio::test]
+async fn unready_service_result_canceled() {
+    zebra_test::init();
+
+    let (cancel_sender, cancel) = oneshot::channel();
+
+    let mock_client: MockService> =
+        MockService::build().for_unit_tests();
+    let unready_service = UnreadyService {
+        key: Some(MockKey),
+        cancel,
+        service: Some(mock_client),
+        _req: PhantomData::default(),
+    };
+
+    cancel_sender
+        .send(CancelClientWork)
+        .expect("unexpected oneshot send failure in tests");
+
+    let result = unready_service.await;
+    assert!(matches!(result, Err((MockKey, Error::Canceled))));
+}
+
+#[tokio::test]
+async fn unready_service_result_cancel_handle_dropped() {
+    zebra_test::init();
+
+    let (cancel_sender, cancel) = oneshot::channel();
+
+    let mock_client: MockService> =
+        MockService::build().for_unit_tests();
+    let unready_service = UnreadyService {
+        key: Some(MockKey),
+        cancel,
+        service: Some(mock_client),
+        _req: PhantomData::default(),
+    };
+
+    std::mem::drop(cancel_sender);
+
+    let result = unready_service.await;
+    assert!(matches!(
+        result,
+        Err((MockKey, Error::CancelHandleDropped(_)))
+    ));
+}