fix(security): Randomly drop connections when inbound service is overloaded (#6790)

* fix(security): Randomly drop connections when inbound service is overloaded

* Uses progressively higher drop probabilities

* Replaces Error::Overloaded with Fatal when internal services shutdown

* Applies suggestions from code review.

* Quickens initial drop probability decay and updates comment

* Applies suggestions from code review.

* Fixes drop connection probablity calc

* Update connection state metrics for different overload/error outcomes

* Split overload handler into separate methods

* Add unit test for drop probability function properties

* Add respond_error methods to zebra-test to help with type resolution

* Initial test that Overloaded errors cause some continues and some closes

* Tune the number of test runs and test timing

* Fix doctests and replace some confusing example requests

---------

Co-authored-by: arya2 <aryasolhi@gmail.com>
This commit is contained in:
teor 2023-06-01 05:04:15 +10:00 committed by GitHub
parent af4d53122f
commit 6eaf83b4bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 498 additions and 38 deletions

View File

@ -316,6 +316,25 @@ pub const EWMA_DECAY_TIME_NANOS: f64 = 200.0 * NANOS_PER_SECOND;
/// The number of nanoseconds in one second.
const NANOS_PER_SECOND: f64 = 1_000_000_000.0;
/// The duration it takes for the drop probability of an overloaded connection to
/// reach [`MIN_OVERLOAD_DROP_PROBABILITY`].
///
/// Peer connections that receive multiple overloads have a higher probability of being dropped.
///
/// The probability of a connection being dropped gradually decreases during this interval
/// until it reaches the default drop probability ([`MIN_OVERLOAD_DROP_PROBABILITY`]).
///
/// Increasing this number increases the rate at which connections are dropped.
pub const OVERLOAD_PROTECTION_INTERVAL: Duration = MIN_INBOUND_PEER_CONNECTION_INTERVAL;
/// The minimum probability of dropping a peer connection when it receives an
/// [`Overloaded`](crate::PeerError::Overloaded) error.
pub const MIN_OVERLOAD_DROP_PROBABILITY: f32 = 0.05;
/// The maximum probability of dropping a peer connection when it receives an
/// [`Overloaded`](crate::PeerError::Overloaded) error.
pub const MAX_OVERLOAD_DROP_PROBABILITY: f32 = 0.95;
lazy_static! {
/// The minimum network protocol version accepted by this crate for each network,
/// represented as a network upgrade.

View File

@ -7,15 +7,16 @@
//! And it's unclear if these assumptions match the `zcashd` implementation.
//! It should be refactored into a cleaner set of request/response pairs (#1515).
use std::{borrow::Cow, collections::HashSet, fmt, pin::Pin, sync::Arc};
use std::{borrow::Cow, collections::HashSet, fmt, pin::Pin, sync::Arc, time::Instant};
use futures::{
future::{self, Either},
prelude::*,
stream::Stream,
};
use rand::{thread_rng, Rng};
use tokio::time::{sleep, Sleep};
use tower::Service;
use tower::{load_shed::error::Overloaded, Service, ServiceExt};
use tracing_futures::Instrument;
use zebra_chain::{
@ -25,7 +26,10 @@ use zebra_chain::{
};
use crate::{
constants,
constants::{
self, MAX_OVERLOAD_DROP_PROBABILITY, MIN_OVERLOAD_DROP_PROBABILITY,
OVERLOAD_PROTECTION_INTERVAL,
},
meta_addr::MetaAddr,
peer::{
connection::peer_tx::PeerTx, error::AlreadyErrored, ClientRequest, ClientRequestReceiver,
@ -508,6 +512,11 @@ pub struct Connection<S, Tx> {
/// The state for this peer, when the metrics were last updated.
pub(super) last_metrics_state: Option<Cow<'static, str>>,
/// The time of the last overload error response from the inbound
/// service to a request from this connection,
/// or None if this connection hasn't yet received an overload error.
last_overload_time: Option<Instant>,
}
impl<S, Tx> fmt::Debug for Connection<S, Tx> {
@ -549,6 +558,7 @@ impl<S, Tx> Connection<S, Tx> {
connection_tracker,
metrics_label,
last_metrics_state: None,
last_overload_time: None,
}
}
}
@ -1242,7 +1252,6 @@ where
/// of connected peers.
async fn drive_peer_request(&mut self, req: Request) {
trace!(?req);
use tower::{load_shed::error::Overloaded, ServiceExt};
// Add a metric for inbound requests
metrics::counter!(
@ -1258,29 +1267,18 @@ where
tokio::task::yield_now().await;
if self.svc.ready().await.is_err() {
// Treat all service readiness errors as Overloaded
// TODO: treat `TryRecvError::Closed` in `Inbound::poll_ready` as a fatal error (#1655)
self.fail_with(PeerError::Overloaded);
self.fail_with(PeerError::ServiceShutdown);
return;
}
let rsp = match self.svc.call(req.clone()).await {
Err(e) => {
if e.is::<Overloaded>() {
tracing::info!(
remote_user_agent = ?self.connection_info.remote.user_agent,
negotiated_version = ?self.connection_info.negotiated_version,
peer = ?self.metrics_label,
last_peer_state = ?self.last_metrics_state,
// TODO: remove this detailed debug info once #6506 is fixed
remote_height = ?self.connection_info.remote.start_height,
cached_addrs = ?self.cached_addrs.len(),
connection_state = ?self.state,
"inbound service is overloaded, closing connection",
);
tracing::debug!("inbound service is overloaded, may close connection");
metrics::counter!("pool.closed.loadshed", 1);
self.fail_with(PeerError::Overloaded);
let now = Instant::now();
self.handle_inbound_overload(req, now).await;
} else {
// We could send a reject to the remote peer, but that might cause
// them to disconnect, and we might be using them to sync blocks.
@ -1292,7 +1290,9 @@ where
client_receiver = ?self.client_rx,
"error processing peer request",
);
self.update_state_metrics(format!("In::Req::{}/Rsp::Error", req.command()));
}
return;
}
Ok(rsp) => rsp,
@ -1307,6 +1307,7 @@ where
);
self.update_state_metrics(format!("In::Rsp::{}", rsp.command()));
// TODO: split response handler into its own method
match rsp.clone() {
Response::Nil => { /* generic success, do nothing */ }
Response::Peers(addrs) => {
@ -1412,6 +1413,90 @@ where
// before checking the connection for the next inbound or outbound request.
tokio::task::yield_now().await;
}
/// Handle inbound service overload error responses by randomly terminating some connections.
///
/// # Security
///
/// When the inbound service is overloaded with requests, Zebra needs to drop some connections,
/// to reduce the load on the application. But dropping every connection that receives an
/// `Overloaded` error from the inbound service could cause Zebra to drop too many peer
/// connections, and stop itself downloading blocks or transactions.
///
/// Malicious or misbehaving peers can also overload the inbound service, and make Zebra drop
/// its connections to other peers.
///
/// So instead, Zebra drops some overloaded connections at random. If a connection has recently
/// overloaded the inbound service, it is more likely to be dropped. This makes it harder for a
/// single peer (or multiple peers) to perform a denial of service attack.
///
/// The inbound connection rate-limit also makes it hard for multiple peers to perform this
/// attack, because each inbound connection can only send one inbound request before its
/// probability of being disconnected increases.
async fn handle_inbound_overload(&mut self, req: Request, now: Instant) {
let prev = self.last_overload_time.replace(now);
let drop_connection_probability = overload_drop_connection_probability(now, prev);
if thread_rng().gen::<f32>() < drop_connection_probability {
metrics::counter!("pool.closed.loadshed", 1);
tracing::info!(
drop_connection_probability,
remote_user_agent = ?self.connection_info.remote.user_agent,
negotiated_version = ?self.connection_info.negotiated_version,
peer = ?self.metrics_label,
last_peer_state = ?self.last_metrics_state,
// TODO: remove this detailed debug info once #6506 is fixed
remote_height = ?self.connection_info.remote.start_height,
cached_addrs = ?self.cached_addrs.len(),
connection_state = ?self.state,
"inbound service is overloaded, closing connection",
);
self.update_state_metrics(format!("In::Req::{}/Rsp::Overload::Error", req.command()));
self.fail_with(PeerError::Overloaded);
} else {
self.update_state_metrics(format!("In::Req::{}/Rsp::Overload::Ignored", req.command()));
metrics::counter!("pool.ignored.loadshed", 1);
}
}
}
/// Returns the probability of dropping a connection where the last overload was at `prev`,
/// and the current overload is `now`.
///
/// # Security
///
/// Connections that haven't seen an overload error in the past OVERLOAD_PROTECTION_INTERVAL
/// have a small chance of being closed (MIN_OVERLOAD_DROP_PROBABILITY).
///
/// Connections that have seen a previous overload error in that time
/// have a higher chance of being dropped up to MAX_OVERLOAD_DROP_PROBABILITY.
/// This probability increases quadratically, so peers that send lots of inbound
/// requests are more likely to be dropped.
///
/// ## Examples
///
/// If a connection sends multiple overloads close together, it is very likely to be
/// disconnected. If a connection has two overloads multiple seconds apart, it is unlikely
/// to be disconnected.
fn overload_drop_connection_probability(now: Instant, prev: Option<Instant>) -> f32 {
let Some(prev) = prev else {
return MIN_OVERLOAD_DROP_PROBABILITY;
};
let protection_fraction_since_last_overload =
(now - prev).as_secs_f32() / OVERLOAD_PROTECTION_INTERVAL.as_secs_f32();
// Quadratically increase the disconnection probability for very recent overloads.
// Negative values are ignored by clamping to MIN_OVERLOAD_DROP_PROBABILITY.
let overload_fraction = protection_fraction_since_last_overload.powi(2);
let probability_range = MAX_OVERLOAD_DROP_PROBABILITY - MIN_OVERLOAD_DROP_PROBABILITY;
let raw_drop_probability =
MAX_OVERLOAD_DROP_PROBABILITY - (overload_fraction * probability_range);
raw_drop_probability.clamp(MIN_OVERLOAD_DROP_PROBABILITY, MAX_OVERLOAD_DROP_PROBABILITY)
}
impl<S, Tx> Connection<S, Tx> {

View File

@ -4,22 +4,27 @@
//! - inbound message as request
//! - inbound message, but not a request (or a response)
use std::{collections::HashSet, task::Poll, time::Duration};
use std::{
collections::HashSet,
task::Poll,
time::{Duration, Instant},
};
use futures::{
channel::{mpsc, oneshot},
sink::SinkMapErr,
FutureExt, StreamExt,
FutureExt, SinkExt, StreamExt,
};
use tower::load_shed::error::Overloaded;
use tracing::Span;
use zebra_chain::serialization::SerializationError;
use zebra_test::mock_service::{MockService, PanicAssertion};
use crate::{
constants::REQUEST_TIMEOUT,
constants::{MAX_OVERLOAD_DROP_PROBABILITY, MIN_OVERLOAD_DROP_PROBABILITY, REQUEST_TIMEOUT},
peer::{
connection::{Connection, State},
connection::{overload_drop_connection_probability, Connection, State},
ClientRequest, ErrorSlot,
},
protocol::external::Message,
@ -656,6 +661,230 @@ async fn connection_run_loop_receive_timeout() {
assert_eq!(outbound_message, None);
}
/// Check basic properties of overload probabilities
#[test]
fn overload_probability_reduces_over_time() {
let now = Instant::now();
// Edge case: previous is in the future due to OS monotonic clock bugs
let prev = now + Duration::from_secs(1);
assert_eq!(
overload_drop_connection_probability(now, Some(prev)),
MAX_OVERLOAD_DROP_PROBABILITY,
"if the overload time is in the future (OS bugs?), it should have maximum drop probability",
);
// Overload/DoS case/edge case: rapidly repeated overloads
let prev = now;
assert_eq!(
overload_drop_connection_probability(now, Some(prev)),
MAX_OVERLOAD_DROP_PROBABILITY,
"if the overload times are the same, overloads should have maximum drop probability",
);
// Overload/DoS case: rapidly repeated overloads
let prev = now - Duration::from_micros(1);
let drop_probability = overload_drop_connection_probability(now, Some(prev));
assert!(
drop_probability <= MAX_OVERLOAD_DROP_PROBABILITY,
"if the overloads are very close together, drops can optionally decrease",
);
assert!(
MAX_OVERLOAD_DROP_PROBABILITY - drop_probability < 0.001,
"if the overloads are very close together, drops can only decrease slightly",
);
let last_probability = drop_probability;
// Overload/DoS case: rapidly repeated overloads
let prev = now - Duration::from_millis(1);
let drop_probability = overload_drop_connection_probability(now, Some(prev));
assert!(
drop_probability < last_probability,
"if the overloads decrease, drops should decrease",
);
assert!(
MAX_OVERLOAD_DROP_PROBABILITY - drop_probability < 0.001,
"if the overloads are very close together, drops can only decrease slightly",
);
let last_probability = drop_probability;
// Overload/DoS case: rapidly repeated overloads
let prev = now - Duration::from_millis(10);
let drop_probability = overload_drop_connection_probability(now, Some(prev));
assert!(
drop_probability < last_probability,
"if the overloads decrease, drops should decrease",
);
assert!(
MAX_OVERLOAD_DROP_PROBABILITY - drop_probability < 0.001,
"if the overloads are very close together, drops can only decrease slightly",
);
let last_probability = drop_probability;
// Overload case: frequent overloads
let prev = now - Duration::from_millis(100);
let drop_probability = overload_drop_connection_probability(now, Some(prev));
assert!(
drop_probability < last_probability,
"if the overloads decrease, drops should decrease",
);
assert!(
MAX_OVERLOAD_DROP_PROBABILITY - drop_probability < 0.01,
"if the overloads are very close together, drops can only decrease slightly",
);
let last_probability = drop_probability;
// Overload case: occasional but repeated overloads
let prev = now - Duration::from_secs(1);
let drop_probability = overload_drop_connection_probability(now, Some(prev));
assert!(
drop_probability < last_probability,
"if the overloads decrease, drops should decrease",
);
assert!(
MAX_OVERLOAD_DROP_PROBABILITY - drop_probability > 0.5,
"if the overloads are distant, drops should decrease a lot",
);
let last_probability = drop_probability;
// Overload case: occasional overloads
let prev = now - Duration::from_secs(5);
let drop_probability = overload_drop_connection_probability(now, Some(prev));
assert!(
drop_probability < last_probability,
"if the overloads decrease, drops should decrease",
);
assert!(
MAX_OVERLOAD_DROP_PROBABILITY - drop_probability > 0.7,
"if the overloads are distant, drops should decrease a lot",
);
let _last_probability = drop_probability;
// Base case: infrequent overloads
let prev = now - Duration::from_secs(10);
let drop_probability = overload_drop_connection_probability(now, Some(prev));
assert_eq!(
drop_probability, MIN_OVERLOAD_DROP_PROBABILITY,
"if overloads are far apart, drops should have minimum drop probability",
);
// Base case: no previous overload
let drop_probability = overload_drop_connection_probability(now, None);
assert_eq!(
drop_probability, MIN_OVERLOAD_DROP_PROBABILITY,
"if there is no previous overload time, overloads should have minimum drop probability",
);
}
/// Test that connections are randomly terminated in response to `Overloaded` errors.
///
/// TODO: do a similar test on the real service stack created in the `start` command.
#[tokio::test(flavor = "multi_thread")]
async fn connection_is_randomly_disconnected_on_overload() {
let _init_guard = zebra_test::init();
// The number of times we repeat the test
const TEST_RUNS: usize = 220;
// The expected number of tests before a test failure due to random chance.
// Based on 10 tests per PR, 100 PR pushes per week, 50 weeks per year.
const TESTS_BEFORE_FAILURE: f32 = 50_000.0;
let test_runs = TEST_RUNS.try_into().expect("constant fits in i32");
// The probability of random test failure is:
// MIN_OVERLOAD_DROP_PROBABILITY^TEST_RUNS + MAX_OVERLOAD_DROP_PROBABILITY^TEST_RUNS
assert!(
1.0 / MIN_OVERLOAD_DROP_PROBABILITY.powi(test_runs) > TESTS_BEFORE_FAILURE,
"not enough test runs: failures must be frequent enough to happen in almost all tests"
);
assert!(
1.0 / MAX_OVERLOAD_DROP_PROBABILITY.powi(test_runs) > TESTS_BEFORE_FAILURE,
"not enough test runs: successes must be frequent enough to happen in almost all tests"
);
let mut connection_continues = 0;
let mut connection_closes = 0;
for _ in 0..TEST_RUNS {
// The real stream and sink are from a split TCP connection,
// but that doesn't change how the state machine behaves.
let (mut peer_tx, peer_rx) = mpsc::channel(1);
let (
connection,
_client_tx,
mut inbound_service,
mut peer_outbound_messages,
shared_error_slot,
) = new_test_connection();
// The connection hasn't run so it must not have errors
let error = shared_error_slot.try_get_error();
assert!(
error.is_none(),
"unexpected error before starting the connection event loop: {error:?}",
);
// Start the connection run loop future in a spawned task
let connection_handle = tokio::spawn(connection.run(peer_rx));
tokio::time::sleep(Duration::from_millis(1)).await;
// The connection hasn't received any messages, so it must not have errors
let error = shared_error_slot.try_get_error();
assert!(
error.is_none(),
"unexpected error before sending messages to the connection event loop: {error:?}",
);
// Simulate an overloaded connection error in response to an inbound request.
let inbound_req = Message::GetAddr;
peer_tx
.send(Ok(inbound_req))
.await
.expect("send to channel always succeeds");
tokio::time::sleep(Duration::from_millis(1)).await;
// The connection hasn't got a response, so it must not have errors
let error = shared_error_slot.try_get_error();
assert!(
error.is_none(),
"unexpected error before sending responses to the connection event loop: {error:?}",
);
inbound_service
.expect_request(Request::Peers)
.await
.respond_error(Overloaded::new().into());
tokio::time::sleep(Duration::from_millis(1)).await;
let outbound_result = peer_outbound_messages.try_next();
assert!(
!matches!(outbound_result, Ok(Some(_))),
"unexpected outbound message after Overloaded error:\n\
{outbound_result:?}\n\
note: TryRecvErr means there are no messages, Ok(None) means the channel is closed"
);
let error = shared_error_slot.try_get_error();
if error.is_some() {
connection_closes += 1;
} else {
connection_continues += 1;
}
// We need to terminate the spawned task
connection_handle.abort();
}
assert!(
connection_closes > 0,
"some overloaded connections must be closed at random"
);
assert!(
connection_continues > 0,
"some overloaded errors must be ignored at random"
);
}
/// Creates a new [`Connection`] instance for unit tests.
fn new_test_connection() -> (
Connection<

View File

@ -82,6 +82,10 @@ pub enum PeerError {
#[error("Internal services over capacity")]
Overloaded,
/// This node's internal services are no longer able to service requests.
#[error("Internal services have failed or shutdown")]
ServiceShutdown,
/// We requested data, but the peer replied with a `notfound` message.
/// (Or it didn't respond before the request finished.)
///
@ -138,6 +142,7 @@ impl PeerError {
PeerError::Serialization(inner) => format!("Serialization({inner})").into(),
PeerError::DuplicateHandshake => "DuplicateHandshake".into(),
PeerError::Overloaded => "Overloaded".into(),
PeerError::ServiceShutdown => "ServiceShutdown".into(),
PeerError::NotFoundResponse(_) => "NotFoundResponse".into(),
PeerError::NotFoundRegistry(_) => "NotFoundRegistry".into(),
}

View File

@ -740,7 +740,10 @@ impl<Request, Response, Error> ResponseSender<Request, Response, Error> {
/// This method takes ownership of the [`ResponseSender`] so that only one response can be
/// sent.
///
/// If `respond` or `respond_with` are not called, the caller will panic.
/// # Panics
///
/// If one of the `respond*` methods isn't called, the [`MockService`] might panic with a
/// timeout error.
///
/// # Example
///
@ -748,6 +751,9 @@ impl<Request, Response, Error> ResponseSender<Request, Response, Error> {
/// # use zebra_test::mock_service::MockService;
/// # use tower::{Service, ServiceExt};
/// #
/// # #[derive(Debug, PartialEq, Eq)]
/// # struct Request;
/// #
/// # let reactor = tokio::runtime::Builder::new_current_thread()
/// # .enable_all()
/// # .build()
@ -760,19 +766,19 @@ impl<Request, Response, Error> ResponseSender<Request, Response, Error> {
///
/// # let mut service = mock_service.clone();
/// # let task = tokio::spawn(async move {
/// # let first_call_result = (&mut service).oneshot(1).await;
/// # let second_call_result = service.oneshot(1).await;
/// # let first_call_result = (&mut service).oneshot(Request).await;
/// # let second_call_result = service.oneshot(Request).await;
/// #
/// # (first_call_result, second_call_result)
/// # });
/// #
/// mock_service
/// .expect_request(1)
/// .expect_request(Request)
/// .await
/// .respond("Received one".to_owned());
/// .respond("Received Request".to_owned());
///
/// mock_service
/// .expect_request(1)
/// .expect_request(Request)
/// .await
/// .respond(Err("Duplicate request"));
/// # });
@ -789,7 +795,10 @@ impl<Request, Response, Error> ResponseSender<Request, Response, Error> {
/// This method takes ownership of the [`ResponseSender`] so that only one response can be
/// sent.
///
/// If `respond` or `respond_with` are not called, the caller will panic.
/// # Panics
///
/// If one of the `respond*` methods isn't called, the [`MockService`] might panic with a
/// timeout error.
///
/// # Example
///
@ -797,6 +806,9 @@ impl<Request, Response, Error> ResponseSender<Request, Response, Error> {
/// # use zebra_test::mock_service::MockService;
/// # use tower::{Service, ServiceExt};
/// #
/// # #[derive(Debug, PartialEq, Eq)]
/// # struct Request;
/// #
/// # let reactor = tokio::runtime::Builder::new_current_thread()
/// # .enable_all()
/// # .build()
@ -809,21 +821,21 @@ impl<Request, Response, Error> ResponseSender<Request, Response, Error> {
///
/// # let mut service = mock_service.clone();
/// # let task = tokio::spawn(async move {
/// # let first_call_result = (&mut service).oneshot(1).await;
/// # let second_call_result = service.oneshot(1).await;
/// # let first_call_result = (&mut service).oneshot(Request).await;
/// # let second_call_result = service.oneshot(Request).await;
/// #
/// # (first_call_result, second_call_result)
/// # });
/// #
/// mock_service
/// .expect_request(1)
/// .expect_request(Request)
/// .await
/// .respond_with(|req| format!("Received: {}", req));
/// .respond_with(|req| format!("Received: {req:?}"));
///
/// mock_service
/// .expect_request(1)
/// .expect_request(Request)
/// .await
/// .respond_with(|req| Err(format!("Duplicate request: {}", req)));
/// .respond_with(|req| Err(format!("Duplicate request: {req:?}")));
/// # });
/// ```
pub fn respond_with<F, R>(self, response_fn: F)
@ -834,6 +846,116 @@ impl<Request, Response, Error> ResponseSender<Request, Response, Error> {
let response_result = response_fn(self.request()).into_result();
let _ = self.response_sender.send(response_result);
}
/// Respond to the request using a fixed error value.
///
/// The `error` must be the `Error` type. This helps avoid type resolution issues in the
/// compiler.
///
/// This method takes ownership of the [`ResponseSender`] so that only one response can be
/// sent.
///
/// # Panics
///
/// If one of the `respond*` methods isn't called, the [`MockService`] might panic with a
/// timeout error.
///
/// # Example
///
/// ```
/// # use zebra_test::mock_service::MockService;
/// # use tower::{Service, ServiceExt};
/// #
/// # #[derive(Debug, PartialEq, Eq)]
/// # struct Request;
/// # struct Response;
/// #
/// # let reactor = tokio::runtime::Builder::new_current_thread()
/// # .enable_all()
/// # .build()
/// # .expect("Failed to build Tokio runtime");
/// #
/// # reactor.block_on(async {
/// // Mock a service with a `String` as the service `Error` type.
/// let mut mock_service: MockService<Request, Response, _, String> =
/// MockService::build().for_unit_tests();
///
/// # let mut service = mock_service.clone();
/// # let task = tokio::spawn(async move {
/// # let first_call_result = (&mut service).oneshot(Request).await;
/// # let second_call_result = service.oneshot(Request).await;
/// #
/// # (first_call_result, second_call_result)
/// # });
/// #
/// mock_service
/// .expect_request(Request)
/// .await
/// .respond_error("Duplicate request".to_string());
/// # });
/// ```
pub fn respond_error(self, error: Error) {
// TODO: impl ResponseResult for BoxError/Error trait when overlapping impls are
// better supported by the compiler
let _ = self.response_sender.send(Err(error));
}
/// Respond to the request by calculating an error from the request.
///
/// The `error` must be the `Error` type. This helps avoid type resolution issues in the
/// compiler.
///
/// This method takes ownership of the [`ResponseSender`] so that only one response can be
/// sent.
///
/// # Panics
///
/// If one of the `respond*` methods isn't called, the [`MockService`] might panic with a
/// timeout error.
///
/// # Example
///
/// ```
/// # use zebra_test::mock_service::MockService;
/// # use tower::{Service, ServiceExt};
/// #
/// # #[derive(Debug, PartialEq, Eq)]
/// # struct Request;
/// # struct Response;
/// #
/// # let reactor = tokio::runtime::Builder::new_current_thread()
/// # .enable_all()
/// # .build()
/// # .expect("Failed to build Tokio runtime");
/// #
/// # reactor.block_on(async {
/// // Mock a service with a `String` as the service `Error` type.
/// let mut mock_service: MockService<Request, Response, _, String> =
/// MockService::build().for_unit_tests();
///
/// # let mut service = mock_service.clone();
/// # let task = tokio::spawn(async move {
/// # let first_call_result = (&mut service).oneshot(Request).await;
/// # let second_call_result = service.oneshot(Request).await;
/// #
/// # (first_call_result, second_call_result)
/// # });
/// #
/// mock_service
/// .expect_request(Request)
/// .await
/// .respond_with_error(|req| format!("Duplicate request: {req:?}"));
/// # });
/// ```
pub fn respond_with_error<F>(self, response_fn: F)
where
F: FnOnce(&Request) -> Error,
{
// TODO: impl ResponseResult for BoxError/Error trait when overlapping impls are
// better supported by the compiler
let response_result = Err(response_fn(self.request()));
let _ = self.response_sender.send(response_result);
}
}
/// A representation of an assertion type.