From 6eaf83b4bf122b14b48a8d3501d316f40da09535 Mon Sep 17 00:00:00 2001 From: teor Date: Thu, 1 Jun 2023 05:04:15 +1000 Subject: [PATCH] fix(security): Randomly drop connections when inbound service is overloaded (#6790) * fix(security): Randomly drop connections when inbound service is overloaded * Uses progressively higher drop probabilities * Replaces Error::Overloaded with Fatal when internal services shutdown * Applies suggestions from code review. * Quickens initial drop probability decay and updates comment * Applies suggestions from code review. * Fixes drop connection probablity calc * Update connection state metrics for different overload/error outcomes * Split overload handler into separate methods * Add unit test for drop probability function properties * Add respond_error methods to zebra-test to help with type resolution * Initial test that Overloaded errors cause some continues and some closes * Tune the number of test runs and test timing * Fix doctests and replace some confusing example requests --------- Co-authored-by: arya2 --- zebra-network/src/constants.rs | 19 ++ zebra-network/src/peer/connection.rs | 125 +++++++-- .../src/peer/connection/tests/vectors.rs | 239 +++++++++++++++++- zebra-network/src/peer/error.rs | 5 + zebra-test/src/mock_service.rs | 148 ++++++++++- 5 files changed, 498 insertions(+), 38 deletions(-) diff --git a/zebra-network/src/constants.rs b/zebra-network/src/constants.rs index c6dfc0a3d..7b7f51b5f 100644 --- a/zebra-network/src/constants.rs +++ b/zebra-network/src/constants.rs @@ -316,6 +316,25 @@ pub const EWMA_DECAY_TIME_NANOS: f64 = 200.0 * NANOS_PER_SECOND; /// The number of nanoseconds in one second. const NANOS_PER_SECOND: f64 = 1_000_000_000.0; +/// The duration it takes for the drop probability of an overloaded connection to +/// reach [`MIN_OVERLOAD_DROP_PROBABILITY`]. +/// +/// Peer connections that receive multiple overloads have a higher probability of being dropped. +/// +/// The probability of a connection being dropped gradually decreases during this interval +/// until it reaches the default drop probability ([`MIN_OVERLOAD_DROP_PROBABILITY`]). +/// +/// Increasing this number increases the rate at which connections are dropped. +pub const OVERLOAD_PROTECTION_INTERVAL: Duration = MIN_INBOUND_PEER_CONNECTION_INTERVAL; + +/// The minimum probability of dropping a peer connection when it receives an +/// [`Overloaded`](crate::PeerError::Overloaded) error. +pub const MIN_OVERLOAD_DROP_PROBABILITY: f32 = 0.05; + +/// The maximum probability of dropping a peer connection when it receives an +/// [`Overloaded`](crate::PeerError::Overloaded) error. +pub const MAX_OVERLOAD_DROP_PROBABILITY: f32 = 0.95; + lazy_static! { /// The minimum network protocol version accepted by this crate for each network, /// represented as a network upgrade. diff --git a/zebra-network/src/peer/connection.rs b/zebra-network/src/peer/connection.rs index 20429cc73..568076b0a 100644 --- a/zebra-network/src/peer/connection.rs +++ b/zebra-network/src/peer/connection.rs @@ -7,15 +7,16 @@ //! And it's unclear if these assumptions match the `zcashd` implementation. //! It should be refactored into a cleaner set of request/response pairs (#1515). 
-use std::{borrow::Cow, collections::HashSet, fmt, pin::Pin, sync::Arc}; +use std::{borrow::Cow, collections::HashSet, fmt, pin::Pin, sync::Arc, time::Instant}; use futures::{ future::{self, Either}, prelude::*, stream::Stream, }; +use rand::{thread_rng, Rng}; use tokio::time::{sleep, Sleep}; -use tower::Service; +use tower::{load_shed::error::Overloaded, Service, ServiceExt}; use tracing_futures::Instrument; use zebra_chain::{ @@ -25,7 +26,10 @@ use zebra_chain::{ }; use crate::{ - constants, + constants::{ + self, MAX_OVERLOAD_DROP_PROBABILITY, MIN_OVERLOAD_DROP_PROBABILITY, + OVERLOAD_PROTECTION_INTERVAL, + }, meta_addr::MetaAddr, peer::{ connection::peer_tx::PeerTx, error::AlreadyErrored, ClientRequest, ClientRequestReceiver, @@ -508,6 +512,11 @@ pub struct Connection { /// The state for this peer, when the metrics were last updated. pub(super) last_metrics_state: Option>, + + /// The time of the last overload error response from the inbound + /// service to a request from this connection, + /// or None if this connection hasn't yet received an overload error. + last_overload_time: Option, } impl fmt::Debug for Connection { @@ -549,6 +558,7 @@ impl Connection { connection_tracker, metrics_label, last_metrics_state: None, + last_overload_time: None, } } } @@ -1242,7 +1252,6 @@ where /// of connected peers. async fn drive_peer_request(&mut self, req: Request) { trace!(?req); - use tower::{load_shed::error::Overloaded, ServiceExt}; // Add a metric for inbound requests metrics::counter!( @@ -1258,29 +1267,18 @@ where tokio::task::yield_now().await; if self.svc.ready().await.is_err() { - // Treat all service readiness errors as Overloaded - // TODO: treat `TryRecvError::Closed` in `Inbound::poll_ready` as a fatal error (#1655) - self.fail_with(PeerError::Overloaded); + self.fail_with(PeerError::ServiceShutdown); return; } let rsp = match self.svc.call(req.clone()).await { Err(e) => { if e.is::() { - tracing::info!( - remote_user_agent = ?self.connection_info.remote.user_agent, - negotiated_version = ?self.connection_info.negotiated_version, - peer = ?self.metrics_label, - last_peer_state = ?self.last_metrics_state, - // TODO: remove this detailed debug info once #6506 is fixed - remote_height = ?self.connection_info.remote.start_height, - cached_addrs = ?self.cached_addrs.len(), - connection_state = ?self.state, - "inbound service is overloaded, closing connection", - ); + tracing::debug!("inbound service is overloaded, may close connection"); - metrics::counter!("pool.closed.loadshed", 1); - self.fail_with(PeerError::Overloaded); + let now = Instant::now(); + + self.handle_inbound_overload(req, now).await; } else { // We could send a reject to the remote peer, but that might cause // them to disconnect, and we might be using them to sync blocks. @@ -1292,7 +1290,9 @@ where client_receiver = ?self.client_rx, "error processing peer request", ); + self.update_state_metrics(format!("In::Req::{}/Rsp::Error", req.command())); } + return; } Ok(rsp) => rsp, @@ -1307,6 +1307,7 @@ where ); self.update_state_metrics(format!("In::Rsp::{}", rsp.command())); + // TODO: split response handler into its own method match rsp.clone() { Response::Nil => { /* generic success, do nothing */ } Response::Peers(addrs) => { @@ -1412,6 +1413,90 @@ where // before checking the connection for the next inbound or outbound request. tokio::task::yield_now().await; } + + /// Handle inbound service overload error responses by randomly terminating some connections. 
+ /// + /// # Security + /// + /// When the inbound service is overloaded with requests, Zebra needs to drop some connections, + /// to reduce the load on the application. But dropping every connection that receives an + /// `Overloaded` error from the inbound service could cause Zebra to drop too many peer + /// connections, and stop itself downloading blocks or transactions. + /// + /// Malicious or misbehaving peers can also overload the inbound service, and make Zebra drop + /// its connections to other peers. + /// + /// So instead, Zebra drops some overloaded connections at random. If a connection has recently + /// overloaded the inbound service, it is more likely to be dropped. This makes it harder for a + /// single peer (or multiple peers) to perform a denial of service attack. + /// + /// The inbound connection rate-limit also makes it hard for multiple peers to perform this + /// attack, because each inbound connection can only send one inbound request before its + /// probability of being disconnected increases. + async fn handle_inbound_overload(&mut self, req: Request, now: Instant) { + let prev = self.last_overload_time.replace(now); + let drop_connection_probability = overload_drop_connection_probability(now, prev); + + if thread_rng().gen::() < drop_connection_probability { + metrics::counter!("pool.closed.loadshed", 1); + + tracing::info!( + drop_connection_probability, + remote_user_agent = ?self.connection_info.remote.user_agent, + negotiated_version = ?self.connection_info.negotiated_version, + peer = ?self.metrics_label, + last_peer_state = ?self.last_metrics_state, + // TODO: remove this detailed debug info once #6506 is fixed + remote_height = ?self.connection_info.remote.start_height, + cached_addrs = ?self.cached_addrs.len(), + connection_state = ?self.state, + "inbound service is overloaded, closing connection", + ); + + self.update_state_metrics(format!("In::Req::{}/Rsp::Overload::Error", req.command())); + self.fail_with(PeerError::Overloaded); + } else { + self.update_state_metrics(format!("In::Req::{}/Rsp::Overload::Ignored", req.command())); + metrics::counter!("pool.ignored.loadshed", 1); + } + } +} + +/// Returns the probability of dropping a connection where the last overload was at `prev`, +/// and the current overload is `now`. +/// +/// # Security +/// +/// Connections that haven't seen an overload error in the past OVERLOAD_PROTECTION_INTERVAL +/// have a small chance of being closed (MIN_OVERLOAD_DROP_PROBABILITY). +/// +/// Connections that have seen a previous overload error in that time +/// have a higher chance of being dropped up to MAX_OVERLOAD_DROP_PROBABILITY. +/// This probability increases quadratically, so peers that send lots of inbound +/// requests are more likely to be dropped. +/// +/// ## Examples +/// +/// If a connection sends multiple overloads close together, it is very likely to be +/// disconnected. If a connection has two overloads multiple seconds apart, it is unlikely +/// to be disconnected. +fn overload_drop_connection_probability(now: Instant, prev: Option) -> f32 { + let Some(prev) = prev else { + return MIN_OVERLOAD_DROP_PROBABILITY; + }; + + let protection_fraction_since_last_overload = + (now - prev).as_secs_f32() / OVERLOAD_PROTECTION_INTERVAL.as_secs_f32(); + + // Quadratically increase the disconnection probability for very recent overloads. + // Negative values are ignored by clamping to MIN_OVERLOAD_DROP_PROBABILITY. 
+ let overload_fraction = protection_fraction_since_last_overload.powi(2); + + let probability_range = MAX_OVERLOAD_DROP_PROBABILITY - MIN_OVERLOAD_DROP_PROBABILITY; + let raw_drop_probability = + MAX_OVERLOAD_DROP_PROBABILITY - (overload_fraction * probability_range); + + raw_drop_probability.clamp(MIN_OVERLOAD_DROP_PROBABILITY, MAX_OVERLOAD_DROP_PROBABILITY) } impl Connection { diff --git a/zebra-network/src/peer/connection/tests/vectors.rs b/zebra-network/src/peer/connection/tests/vectors.rs index 85ac7c854..cca8c8b20 100644 --- a/zebra-network/src/peer/connection/tests/vectors.rs +++ b/zebra-network/src/peer/connection/tests/vectors.rs @@ -4,22 +4,27 @@ //! - inbound message as request //! - inbound message, but not a request (or a response) -use std::{collections::HashSet, task::Poll, time::Duration}; +use std::{ + collections::HashSet, + task::Poll, + time::{Duration, Instant}, +}; use futures::{ channel::{mpsc, oneshot}, sink::SinkMapErr, - FutureExt, StreamExt, + FutureExt, SinkExt, StreamExt, }; - +use tower::load_shed::error::Overloaded; use tracing::Span; + use zebra_chain::serialization::SerializationError; use zebra_test::mock_service::{MockService, PanicAssertion}; use crate::{ - constants::REQUEST_TIMEOUT, + constants::{MAX_OVERLOAD_DROP_PROBABILITY, MIN_OVERLOAD_DROP_PROBABILITY, REQUEST_TIMEOUT}, peer::{ - connection::{Connection, State}, + connection::{overload_drop_connection_probability, Connection, State}, ClientRequest, ErrorSlot, }, protocol::external::Message, @@ -656,6 +661,230 @@ async fn connection_run_loop_receive_timeout() { assert_eq!(outbound_message, None); } +/// Check basic properties of overload probabilities +#[test] +fn overload_probability_reduces_over_time() { + let now = Instant::now(); + + // Edge case: previous is in the future due to OS monotonic clock bugs + let prev = now + Duration::from_secs(1); + assert_eq!( + overload_drop_connection_probability(now, Some(prev)), + MAX_OVERLOAD_DROP_PROBABILITY, + "if the overload time is in the future (OS bugs?), it should have maximum drop probability", + ); + + // Overload/DoS case/edge case: rapidly repeated overloads + let prev = now; + assert_eq!( + overload_drop_connection_probability(now, Some(prev)), + MAX_OVERLOAD_DROP_PROBABILITY, + "if the overload times are the same, overloads should have maximum drop probability", + ); + + // Overload/DoS case: rapidly repeated overloads + let prev = now - Duration::from_micros(1); + let drop_probability = overload_drop_connection_probability(now, Some(prev)); + assert!( + drop_probability <= MAX_OVERLOAD_DROP_PROBABILITY, + "if the overloads are very close together, drops can optionally decrease", + ); + assert!( + MAX_OVERLOAD_DROP_PROBABILITY - drop_probability < 0.001, + "if the overloads are very close together, drops can only decrease slightly", + ); + let last_probability = drop_probability; + + // Overload/DoS case: rapidly repeated overloads + let prev = now - Duration::from_millis(1); + let drop_probability = overload_drop_connection_probability(now, Some(prev)); + assert!( + drop_probability < last_probability, + "if the overloads decrease, drops should decrease", + ); + assert!( + MAX_OVERLOAD_DROP_PROBABILITY - drop_probability < 0.001, + "if the overloads are very close together, drops can only decrease slightly", + ); + let last_probability = drop_probability; + + // Overload/DoS case: rapidly repeated overloads + let prev = now - Duration::from_millis(10); + let drop_probability = overload_drop_connection_probability(now, Some(prev)); + assert!( 
+ drop_probability < last_probability, + "if the overloads decrease, drops should decrease", + ); + assert!( + MAX_OVERLOAD_DROP_PROBABILITY - drop_probability < 0.001, + "if the overloads are very close together, drops can only decrease slightly", + ); + let last_probability = drop_probability; + + // Overload case: frequent overloads + let prev = now - Duration::from_millis(100); + let drop_probability = overload_drop_connection_probability(now, Some(prev)); + assert!( + drop_probability < last_probability, + "if the overloads decrease, drops should decrease", + ); + assert!( + MAX_OVERLOAD_DROP_PROBABILITY - drop_probability < 0.01, + "if the overloads are very close together, drops can only decrease slightly", + ); + let last_probability = drop_probability; + + // Overload case: occasional but repeated overloads + let prev = now - Duration::from_secs(1); + let drop_probability = overload_drop_connection_probability(now, Some(prev)); + assert!( + drop_probability < last_probability, + "if the overloads decrease, drops should decrease", + ); + assert!( + MAX_OVERLOAD_DROP_PROBABILITY - drop_probability > 0.5, + "if the overloads are distant, drops should decrease a lot", + ); + let last_probability = drop_probability; + + // Overload case: occasional overloads + let prev = now - Duration::from_secs(5); + let drop_probability = overload_drop_connection_probability(now, Some(prev)); + assert!( + drop_probability < last_probability, + "if the overloads decrease, drops should decrease", + ); + assert!( + MAX_OVERLOAD_DROP_PROBABILITY - drop_probability > 0.7, + "if the overloads are distant, drops should decrease a lot", + ); + let _last_probability = drop_probability; + + // Base case: infrequent overloads + let prev = now - Duration::from_secs(10); + let drop_probability = overload_drop_connection_probability(now, Some(prev)); + assert_eq!( + drop_probability, MIN_OVERLOAD_DROP_PROBABILITY, + "if overloads are far apart, drops should have minimum drop probability", + ); + + // Base case: no previous overload + let drop_probability = overload_drop_connection_probability(now, None); + assert_eq!( + drop_probability, MIN_OVERLOAD_DROP_PROBABILITY, + "if there is no previous overload time, overloads should have minimum drop probability", + ); +} + +/// Test that connections are randomly terminated in response to `Overloaded` errors. +/// +/// TODO: do a similar test on the real service stack created in the `start` command. +#[tokio::test(flavor = "multi_thread")] +async fn connection_is_randomly_disconnected_on_overload() { + let _init_guard = zebra_test::init(); + + // The number of times we repeat the test + const TEST_RUNS: usize = 220; + // The expected number of tests before a test failure due to random chance. + // Based on 10 tests per PR, 100 PR pushes per week, 50 weeks per year. 
+ const TESTS_BEFORE_FAILURE: f32 = 50_000.0; + + let test_runs = TEST_RUNS.try_into().expect("constant fits in i32"); + // The probability of random test failure is: + // MIN_OVERLOAD_DROP_PROBABILITY^TEST_RUNS + MAX_OVERLOAD_DROP_PROBABILITY^TEST_RUNS + assert!( + 1.0 / MIN_OVERLOAD_DROP_PROBABILITY.powi(test_runs) > TESTS_BEFORE_FAILURE, + "not enough test runs: failures must be frequent enough to happen in almost all tests" + ); + assert!( + 1.0 / MAX_OVERLOAD_DROP_PROBABILITY.powi(test_runs) > TESTS_BEFORE_FAILURE, + "not enough test runs: successes must be frequent enough to happen in almost all tests" + ); + + let mut connection_continues = 0; + let mut connection_closes = 0; + + for _ in 0..TEST_RUNS { + // The real stream and sink are from a split TCP connection, + // but that doesn't change how the state machine behaves. + let (mut peer_tx, peer_rx) = mpsc::channel(1); + + let ( + connection, + _client_tx, + mut inbound_service, + mut peer_outbound_messages, + shared_error_slot, + ) = new_test_connection(); + + // The connection hasn't run so it must not have errors + let error = shared_error_slot.try_get_error(); + assert!( + error.is_none(), + "unexpected error before starting the connection event loop: {error:?}", + ); + + // Start the connection run loop future in a spawned task + let connection_handle = tokio::spawn(connection.run(peer_rx)); + tokio::time::sleep(Duration::from_millis(1)).await; + + // The connection hasn't received any messages, so it must not have errors + let error = shared_error_slot.try_get_error(); + assert!( + error.is_none(), + "unexpected error before sending messages to the connection event loop: {error:?}", + ); + + // Simulate an overloaded connection error in response to an inbound request. + let inbound_req = Message::GetAddr; + peer_tx + .send(Ok(inbound_req)) + .await + .expect("send to channel always succeeds"); + tokio::time::sleep(Duration::from_millis(1)).await; + + // The connection hasn't got a response, so it must not have errors + let error = shared_error_slot.try_get_error(); + assert!( + error.is_none(), + "unexpected error before sending responses to the connection event loop: {error:?}", + ); + + inbound_service + .expect_request(Request::Peers) + .await + .respond_error(Overloaded::new().into()); + tokio::time::sleep(Duration::from_millis(1)).await; + + let outbound_result = peer_outbound_messages.try_next(); + assert!( + !matches!(outbound_result, Ok(Some(_))), + "unexpected outbound message after Overloaded error:\n\ + {outbound_result:?}\n\ + note: TryRecvErr means there are no messages, Ok(None) means the channel is closed" + ); + + let error = shared_error_slot.try_get_error(); + if error.is_some() { + connection_closes += 1; + } else { + connection_continues += 1; + } + + // We need to terminate the spawned task + connection_handle.abort(); + } + + assert!( + connection_closes > 0, + "some overloaded connections must be closed at random" + ); + assert!( + connection_continues > 0, + "some overloaded errors must be ignored at random" + ); +} + /// Creates a new [`Connection`] instance for unit tests. fn new_test_connection() -> ( Connection< diff --git a/zebra-network/src/peer/error.rs b/zebra-network/src/peer/error.rs index 0180c377d..4d842ba5c 100644 --- a/zebra-network/src/peer/error.rs +++ b/zebra-network/src/peer/error.rs @@ -82,6 +82,10 @@ pub enum PeerError { #[error("Internal services over capacity")] Overloaded, + /// This node's internal services are no longer able to service requests. 
+ #[error("Internal services have failed or shutdown")] + ServiceShutdown, + /// We requested data, but the peer replied with a `notfound` message. /// (Or it didn't respond before the request finished.) /// @@ -138,6 +142,7 @@ impl PeerError { PeerError::Serialization(inner) => format!("Serialization({inner})").into(), PeerError::DuplicateHandshake => "DuplicateHandshake".into(), PeerError::Overloaded => "Overloaded".into(), + PeerError::ServiceShutdown => "ServiceShutdown".into(), PeerError::NotFoundResponse(_) => "NotFoundResponse".into(), PeerError::NotFoundRegistry(_) => "NotFoundRegistry".into(), } diff --git a/zebra-test/src/mock_service.rs b/zebra-test/src/mock_service.rs index 21debf97c..d92e6f8b4 100644 --- a/zebra-test/src/mock_service.rs +++ b/zebra-test/src/mock_service.rs @@ -740,7 +740,10 @@ impl ResponseSender { /// This method takes ownership of the [`ResponseSender`] so that only one response can be /// sent. /// - /// If `respond` or `respond_with` are not called, the caller will panic. + /// # Panics + /// + /// If one of the `respond*` methods isn't called, the [`MockService`] might panic with a + /// timeout error. /// /// # Example /// @@ -748,6 +751,9 @@ impl ResponseSender { /// # use zebra_test::mock_service::MockService; /// # use tower::{Service, ServiceExt}; /// # + /// # #[derive(Debug, PartialEq, Eq)] + /// # struct Request; + /// # /// # let reactor = tokio::runtime::Builder::new_current_thread() /// # .enable_all() /// # .build() @@ -760,19 +766,19 @@ impl ResponseSender { /// /// # let mut service = mock_service.clone(); /// # let task = tokio::spawn(async move { - /// # let first_call_result = (&mut service).oneshot(1).await; - /// # let second_call_result = service.oneshot(1).await; + /// # let first_call_result = (&mut service).oneshot(Request).await; + /// # let second_call_result = service.oneshot(Request).await; /// # /// # (first_call_result, second_call_result) /// # }); /// # /// mock_service - /// .expect_request(1) + /// .expect_request(Request) /// .await - /// .respond("Received one".to_owned()); + /// .respond("Received Request".to_owned()); /// /// mock_service - /// .expect_request(1) + /// .expect_request(Request) /// .await /// .respond(Err("Duplicate request")); /// # }); @@ -789,7 +795,10 @@ impl ResponseSender { /// This method takes ownership of the [`ResponseSender`] so that only one response can be /// sent. /// - /// If `respond` or `respond_with` are not called, the caller will panic. + /// # Panics + /// + /// If one of the `respond*` methods isn't called, the [`MockService`] might panic with a + /// timeout error. 
/// /// # Example /// @@ -797,6 +806,9 @@ impl ResponseSender { /// # use zebra_test::mock_service::MockService; /// # use tower::{Service, ServiceExt}; /// # + /// # #[derive(Debug, PartialEq, Eq)] + /// # struct Request; + /// # /// # let reactor = tokio::runtime::Builder::new_current_thread() /// # .enable_all() /// # .build() @@ -809,21 +821,21 @@ impl ResponseSender { /// /// # let mut service = mock_service.clone(); /// # let task = tokio::spawn(async move { - /// # let first_call_result = (&mut service).oneshot(1).await; - /// # let second_call_result = service.oneshot(1).await; + /// # let first_call_result = (&mut service).oneshot(Request).await; + /// # let second_call_result = service.oneshot(Request).await; /// # /// # (first_call_result, second_call_result) /// # }); /// # /// mock_service - /// .expect_request(1) + /// .expect_request(Request) /// .await - /// .respond_with(|req| format!("Received: {}", req)); + /// .respond_with(|req| format!("Received: {req:?}")); /// /// mock_service - /// .expect_request(1) + /// .expect_request(Request) /// .await - /// .respond_with(|req| Err(format!("Duplicate request: {}", req))); + /// .respond_with(|req| Err(format!("Duplicate request: {req:?}"))); /// # }); /// ``` pub fn respond_with(self, response_fn: F) @@ -834,6 +846,116 @@ impl ResponseSender { let response_result = response_fn(self.request()).into_result(); let _ = self.response_sender.send(response_result); } + + /// Respond to the request using a fixed error value. + /// + /// The `error` must be the `Error` type. This helps avoid type resolution issues in the + /// compiler. + /// + /// This method takes ownership of the [`ResponseSender`] so that only one response can be + /// sent. + /// + /// # Panics + /// + /// If one of the `respond*` methods isn't called, the [`MockService`] might panic with a + /// timeout error. + /// + /// # Example + /// + /// ``` + /// # use zebra_test::mock_service::MockService; + /// # use tower::{Service, ServiceExt}; + /// # + /// # #[derive(Debug, PartialEq, Eq)] + /// # struct Request; + /// # struct Response; + /// # + /// # let reactor = tokio::runtime::Builder::new_current_thread() + /// # .enable_all() + /// # .build() + /// # .expect("Failed to build Tokio runtime"); + /// # + /// # reactor.block_on(async { + /// // Mock a service with a `String` as the service `Error` type. + /// let mut mock_service: MockService = + /// MockService::build().for_unit_tests(); + /// + /// # let mut service = mock_service.clone(); + /// # let task = tokio::spawn(async move { + /// # let first_call_result = (&mut service).oneshot(Request).await; + /// # let second_call_result = service.oneshot(Request).await; + /// # + /// # (first_call_result, second_call_result) + /// # }); + /// # + /// mock_service + /// .expect_request(Request) + /// .await + /// .respond_error("Duplicate request".to_string()); + /// # }); + /// ``` + pub fn respond_error(self, error: Error) { + // TODO: impl ResponseResult for BoxError/Error trait when overlapping impls are + // better supported by the compiler + let _ = self.response_sender.send(Err(error)); + } + + /// Respond to the request by calculating an error from the request. + /// + /// The `error` must be the `Error` type. This helps avoid type resolution issues in the + /// compiler. + /// + /// This method takes ownership of the [`ResponseSender`] so that only one response can be + /// sent. 
+ /// + /// # Panics + /// + /// If one of the `respond*` methods isn't called, the [`MockService`] might panic with a + /// timeout error. + /// + /// # Example + /// + /// ``` + /// # use zebra_test::mock_service::MockService; + /// # use tower::{Service, ServiceExt}; + /// # + /// # #[derive(Debug, PartialEq, Eq)] + /// # struct Request; + /// # struct Response; + /// # + /// # let reactor = tokio::runtime::Builder::new_current_thread() + /// # .enable_all() + /// # .build() + /// # .expect("Failed to build Tokio runtime"); + /// # + /// # reactor.block_on(async { + /// // Mock a service with a `String` as the service `Error` type. + /// let mut mock_service: MockService = + /// MockService::build().for_unit_tests(); + /// + /// # let mut service = mock_service.clone(); + /// # let task = tokio::spawn(async move { + /// # let first_call_result = (&mut service).oneshot(Request).await; + /// # let second_call_result = service.oneshot(Request).await; + /// # + /// # (first_call_result, second_call_result) + /// # }); + /// # + /// mock_service + /// .expect_request(Request) + /// .await + /// .respond_with_error(|req| format!("Duplicate request: {req:?}")); + /// # }); + /// ``` + pub fn respond_with_error(self, response_fn: F) + where + F: FnOnce(&Request) -> Error, + { + // TODO: impl ResponseResult for BoxError/Error trait when overlapping impls are + // better supported by the compiler + let response_result = Err(response_fn(self.request())); + let _ = self.response_sender.send(response_result); + } } /// A representation of an assertion type.
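
Note: a quick way to see the quadratic decay implemented by `overload_drop_connection_probability` in `zebra-network/src/peer/connection.rs` above is the standalone sketch below. It copies the formula and the two drop-probability constants from this patch, but assumes a 1 second `OVERLOAD_PROTECTION_INTERVAL` purely for illustration; the real value is `MIN_INBOUND_PEER_CONNECTION_INTERVAL`, which is defined outside this diff.

use std::time::Duration;

// Constants copied from this patch. The interval value is an assumption for
// illustration only (the patch sets it to MIN_INBOUND_PEER_CONNECTION_INTERVAL).
const MIN_OVERLOAD_DROP_PROBABILITY: f32 = 0.05;
const MAX_OVERLOAD_DROP_PROBABILITY: f32 = 0.95;
const OVERLOAD_PROTECTION_INTERVAL: Duration = Duration::from_secs(1);

/// Same quadratic decay as the patch, but taking the elapsed time since the
/// last overload directly instead of two `Instant`s, so it runs without a
/// live connection.
fn drop_probability(since_last_overload: Duration) -> f32 {
    let protection_fraction =
        since_last_overload.as_secs_f32() / OVERLOAD_PROTECTION_INTERVAL.as_secs_f32();

    // Quadratic decay: very recent overloads stay near the maximum probability.
    let overload_fraction = protection_fraction.powi(2);

    let range = MAX_OVERLOAD_DROP_PROBABILITY - MIN_OVERLOAD_DROP_PROBABILITY;
    let raw = MAX_OVERLOAD_DROP_PROBABILITY - overload_fraction * range;

    raw.clamp(MIN_OVERLOAD_DROP_PROBABILITY, MAX_OVERLOAD_DROP_PROBABILITY)
}

fn main() {
    for millis in [0u64, 100, 500, 1_000, 2_000] {
        let p = drop_probability(Duration::from_millis(millis));
        // With a 1 second interval this prints roughly:
        // 0 ms -> 0.95, 100 ms -> 0.94, 500 ms -> 0.73, 1000 ms -> 0.05, 2000 ms -> 0.05
        println!("{millis} ms since last overload -> drop probability {p:.2}");
    }
}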
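Note: the `TEST_RUNS` tuning in `connection_is_randomly_disconnected_on_overload` can also be sanity-checked in isolation. Each test run starts a fresh connection with no previous overload, so it closes with probability `MIN_OVERLOAD_DROP_PROBABILITY` and stays open with probability `1 - MIN_OVERLOAD_DROP_PROBABILITY` (which equals `MAX_OVERLOAD_DROP_PROBABILITY` only because the two constants happen to sum to 1). The sketch below estimates the spurious-failure rates; it uses `f64` to show the magnitudes, whereas the test itself uses `f32`.

fn main() {
    const TEST_RUNS: i32 = 220;
    const MIN_DROP: f64 = 0.05; // MIN_OVERLOAD_DROP_PROBABILITY
    const MAX_DROP: f64 = 0.95; // MAX_OVERLOAD_DROP_PROBABILITY

    // The test fails spuriously only if every run keeps its connection open,
    // or every run closes it.
    let all_continue = MAX_DROP.powi(TEST_RUNS); // ~1.3e-5: about 1 in 80,000 runs
    let all_close = MIN_DROP.powi(TEST_RUNS); // ~1e-286: effectively never

    println!("P(no run closes a connection) ~ {all_continue:e}");
    println!("P(every run closes)           ~ {all_close:e}");
    // ~80,000, comfortably above the TESTS_BEFORE_FAILURE budget of 50,000.
    println!(
        "expected test runs before a spurious failure ~ {:.0}",
        1.0 / (all_continue + all_close)
    );
}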