fix(security): Randomly drop connections when inbound service is overloaded (#6790)
* fix(security): Randomly drop connections when inbound service is overloaded * Uses progressively higher drop probabilities * Replaces Error::Overloaded with Fatal when internal services shutdown * Applies suggestions from code review. * Quickens initial drop probability decay and updates comment * Applies suggestions from code review. * Fixes drop connection probablity calc * Update connection state metrics for different overload/error outcomes * Split overload handler into separate methods * Add unit test for drop probability function properties * Add respond_error methods to zebra-test to help with type resolution * Initial test that Overloaded errors cause some continues and some closes * Tune the number of test runs and test timing * Fix doctests and replace some confusing example requests --------- Co-authored-by: arya2 <aryasolhi@gmail.com>
This commit is contained in:
parent
af4d53122f
commit
6eaf83b4bf
|
@ -316,6 +316,25 @@ pub const EWMA_DECAY_TIME_NANOS: f64 = 200.0 * NANOS_PER_SECOND;
|
|||
/// The number of nanoseconds in one second.
|
||||
const NANOS_PER_SECOND: f64 = 1_000_000_000.0;
|
||||
|
||||
/// The duration it takes for the drop probability of an overloaded connection to
|
||||
/// reach [`MIN_OVERLOAD_DROP_PROBABILITY`].
|
||||
///
|
||||
/// Peer connections that receive multiple overloads have a higher probability of being dropped.
|
||||
///
|
||||
/// The probability of a connection being dropped gradually decreases during this interval
|
||||
/// until it reaches the default drop probability ([`MIN_OVERLOAD_DROP_PROBABILITY`]).
|
||||
///
|
||||
/// Increasing this number increases the rate at which connections are dropped.
|
||||
pub const OVERLOAD_PROTECTION_INTERVAL: Duration = MIN_INBOUND_PEER_CONNECTION_INTERVAL;
|
||||
|
||||
/// The minimum probability of dropping a peer connection when it receives an
|
||||
/// [`Overloaded`](crate::PeerError::Overloaded) error.
|
||||
pub const MIN_OVERLOAD_DROP_PROBABILITY: f32 = 0.05;
|
||||
|
||||
/// The maximum probability of dropping a peer connection when it receives an
|
||||
/// [`Overloaded`](crate::PeerError::Overloaded) error.
|
||||
pub const MAX_OVERLOAD_DROP_PROBABILITY: f32 = 0.95;
|
||||
|
||||
lazy_static! {
|
||||
/// The minimum network protocol version accepted by this crate for each network,
|
||||
/// represented as a network upgrade.
|
||||
|
|
|
@ -7,15 +7,16 @@
|
|||
//! And it's unclear if these assumptions match the `zcashd` implementation.
|
||||
//! It should be refactored into a cleaner set of request/response pairs (#1515).
|
||||
|
||||
use std::{borrow::Cow, collections::HashSet, fmt, pin::Pin, sync::Arc};
|
||||
use std::{borrow::Cow, collections::HashSet, fmt, pin::Pin, sync::Arc, time::Instant};
|
||||
|
||||
use futures::{
|
||||
future::{self, Either},
|
||||
prelude::*,
|
||||
stream::Stream,
|
||||
};
|
||||
use rand::{thread_rng, Rng};
|
||||
use tokio::time::{sleep, Sleep};
|
||||
use tower::Service;
|
||||
use tower::{load_shed::error::Overloaded, Service, ServiceExt};
|
||||
use tracing_futures::Instrument;
|
||||
|
||||
use zebra_chain::{
|
||||
|
@ -25,7 +26,10 @@ use zebra_chain::{
|
|||
};
|
||||
|
||||
use crate::{
|
||||
constants,
|
||||
constants::{
|
||||
self, MAX_OVERLOAD_DROP_PROBABILITY, MIN_OVERLOAD_DROP_PROBABILITY,
|
||||
OVERLOAD_PROTECTION_INTERVAL,
|
||||
},
|
||||
meta_addr::MetaAddr,
|
||||
peer::{
|
||||
connection::peer_tx::PeerTx, error::AlreadyErrored, ClientRequest, ClientRequestReceiver,
|
||||
|
@ -508,6 +512,11 @@ pub struct Connection<S, Tx> {
|
|||
|
||||
/// The state for this peer, when the metrics were last updated.
|
||||
pub(super) last_metrics_state: Option<Cow<'static, str>>,
|
||||
|
||||
/// The time of the last overload error response from the inbound
|
||||
/// service to a request from this connection,
|
||||
/// or None if this connection hasn't yet received an overload error.
|
||||
last_overload_time: Option<Instant>,
|
||||
}
|
||||
|
||||
impl<S, Tx> fmt::Debug for Connection<S, Tx> {
|
||||
|
@ -549,6 +558,7 @@ impl<S, Tx> Connection<S, Tx> {
|
|||
connection_tracker,
|
||||
metrics_label,
|
||||
last_metrics_state: None,
|
||||
last_overload_time: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1242,7 +1252,6 @@ where
|
|||
/// of connected peers.
|
||||
async fn drive_peer_request(&mut self, req: Request) {
|
||||
trace!(?req);
|
||||
use tower::{load_shed::error::Overloaded, ServiceExt};
|
||||
|
||||
// Add a metric for inbound requests
|
||||
metrics::counter!(
|
||||
|
@ -1258,29 +1267,18 @@ where
|
|||
tokio::task::yield_now().await;
|
||||
|
||||
if self.svc.ready().await.is_err() {
|
||||
// Treat all service readiness errors as Overloaded
|
||||
// TODO: treat `TryRecvError::Closed` in `Inbound::poll_ready` as a fatal error (#1655)
|
||||
self.fail_with(PeerError::Overloaded);
|
||||
self.fail_with(PeerError::ServiceShutdown);
|
||||
return;
|
||||
}
|
||||
|
||||
let rsp = match self.svc.call(req.clone()).await {
|
||||
Err(e) => {
|
||||
if e.is::<Overloaded>() {
|
||||
tracing::info!(
|
||||
remote_user_agent = ?self.connection_info.remote.user_agent,
|
||||
negotiated_version = ?self.connection_info.negotiated_version,
|
||||
peer = ?self.metrics_label,
|
||||
last_peer_state = ?self.last_metrics_state,
|
||||
// TODO: remove this detailed debug info once #6506 is fixed
|
||||
remote_height = ?self.connection_info.remote.start_height,
|
||||
cached_addrs = ?self.cached_addrs.len(),
|
||||
connection_state = ?self.state,
|
||||
"inbound service is overloaded, closing connection",
|
||||
);
|
||||
tracing::debug!("inbound service is overloaded, may close connection");
|
||||
|
||||
metrics::counter!("pool.closed.loadshed", 1);
|
||||
self.fail_with(PeerError::Overloaded);
|
||||
let now = Instant::now();
|
||||
|
||||
self.handle_inbound_overload(req, now).await;
|
||||
} else {
|
||||
// We could send a reject to the remote peer, but that might cause
|
||||
// them to disconnect, and we might be using them to sync blocks.
|
||||
|
@ -1292,7 +1290,9 @@ where
|
|||
client_receiver = ?self.client_rx,
|
||||
"error processing peer request",
|
||||
);
|
||||
self.update_state_metrics(format!("In::Req::{}/Rsp::Error", req.command()));
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
Ok(rsp) => rsp,
|
||||
|
@ -1307,6 +1307,7 @@ where
|
|||
);
|
||||
self.update_state_metrics(format!("In::Rsp::{}", rsp.command()));
|
||||
|
||||
// TODO: split response handler into its own method
|
||||
match rsp.clone() {
|
||||
Response::Nil => { /* generic success, do nothing */ }
|
||||
Response::Peers(addrs) => {
|
||||
|
@ -1412,6 +1413,90 @@ where
|
|||
// before checking the connection for the next inbound or outbound request.
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
|
||||
/// Handle inbound service overload error responses by randomly terminating some connections.
|
||||
///
|
||||
/// # Security
|
||||
///
|
||||
/// When the inbound service is overloaded with requests, Zebra needs to drop some connections,
|
||||
/// to reduce the load on the application. But dropping every connection that receives an
|
||||
/// `Overloaded` error from the inbound service could cause Zebra to drop too many peer
|
||||
/// connections, and stop itself downloading blocks or transactions.
|
||||
///
|
||||
/// Malicious or misbehaving peers can also overload the inbound service, and make Zebra drop
|
||||
/// its connections to other peers.
|
||||
///
|
||||
/// So instead, Zebra drops some overloaded connections at random. If a connection has recently
|
||||
/// overloaded the inbound service, it is more likely to be dropped. This makes it harder for a
|
||||
/// single peer (or multiple peers) to perform a denial of service attack.
|
||||
///
|
||||
/// The inbound connection rate-limit also makes it hard for multiple peers to perform this
|
||||
/// attack, because each inbound connection can only send one inbound request before its
|
||||
/// probability of being disconnected increases.
|
||||
async fn handle_inbound_overload(&mut self, req: Request, now: Instant) {
|
||||
let prev = self.last_overload_time.replace(now);
|
||||
let drop_connection_probability = overload_drop_connection_probability(now, prev);
|
||||
|
||||
if thread_rng().gen::<f32>() < drop_connection_probability {
|
||||
metrics::counter!("pool.closed.loadshed", 1);
|
||||
|
||||
tracing::info!(
|
||||
drop_connection_probability,
|
||||
remote_user_agent = ?self.connection_info.remote.user_agent,
|
||||
negotiated_version = ?self.connection_info.negotiated_version,
|
||||
peer = ?self.metrics_label,
|
||||
last_peer_state = ?self.last_metrics_state,
|
||||
// TODO: remove this detailed debug info once #6506 is fixed
|
||||
remote_height = ?self.connection_info.remote.start_height,
|
||||
cached_addrs = ?self.cached_addrs.len(),
|
||||
connection_state = ?self.state,
|
||||
"inbound service is overloaded, closing connection",
|
||||
);
|
||||
|
||||
self.update_state_metrics(format!("In::Req::{}/Rsp::Overload::Error", req.command()));
|
||||
self.fail_with(PeerError::Overloaded);
|
||||
} else {
|
||||
self.update_state_metrics(format!("In::Req::{}/Rsp::Overload::Ignored", req.command()));
|
||||
metrics::counter!("pool.ignored.loadshed", 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the probability of dropping a connection where the last overload was at `prev`,
|
||||
/// and the current overload is `now`.
|
||||
///
|
||||
/// # Security
|
||||
///
|
||||
/// Connections that haven't seen an overload error in the past OVERLOAD_PROTECTION_INTERVAL
|
||||
/// have a small chance of being closed (MIN_OVERLOAD_DROP_PROBABILITY).
|
||||
///
|
||||
/// Connections that have seen a previous overload error in that time
|
||||
/// have a higher chance of being dropped up to MAX_OVERLOAD_DROP_PROBABILITY.
|
||||
/// This probability increases quadratically, so peers that send lots of inbound
|
||||
/// requests are more likely to be dropped.
|
||||
///
|
||||
/// ## Examples
|
||||
///
|
||||
/// If a connection sends multiple overloads close together, it is very likely to be
|
||||
/// disconnected. If a connection has two overloads multiple seconds apart, it is unlikely
|
||||
/// to be disconnected.
|
||||
fn overload_drop_connection_probability(now: Instant, prev: Option<Instant>) -> f32 {
|
||||
let Some(prev) = prev else {
|
||||
return MIN_OVERLOAD_DROP_PROBABILITY;
|
||||
};
|
||||
|
||||
let protection_fraction_since_last_overload =
|
||||
(now - prev).as_secs_f32() / OVERLOAD_PROTECTION_INTERVAL.as_secs_f32();
|
||||
|
||||
// Quadratically increase the disconnection probability for very recent overloads.
|
||||
// Negative values are ignored by clamping to MIN_OVERLOAD_DROP_PROBABILITY.
|
||||
let overload_fraction = protection_fraction_since_last_overload.powi(2);
|
||||
|
||||
let probability_range = MAX_OVERLOAD_DROP_PROBABILITY - MIN_OVERLOAD_DROP_PROBABILITY;
|
||||
let raw_drop_probability =
|
||||
MAX_OVERLOAD_DROP_PROBABILITY - (overload_fraction * probability_range);
|
||||
|
||||
raw_drop_probability.clamp(MIN_OVERLOAD_DROP_PROBABILITY, MAX_OVERLOAD_DROP_PROBABILITY)
|
||||
}
|
||||
|
||||
impl<S, Tx> Connection<S, Tx> {
|
||||
|
|
|
@ -4,22 +4,27 @@
|
|||
//! - inbound message as request
|
||||
//! - inbound message, but not a request (or a response)
|
||||
|
||||
use std::{collections::HashSet, task::Poll, time::Duration};
|
||||
use std::{
|
||||
collections::HashSet,
|
||||
task::Poll,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
sink::SinkMapErr,
|
||||
FutureExt, StreamExt,
|
||||
FutureExt, SinkExt, StreamExt,
|
||||
};
|
||||
|
||||
use tower::load_shed::error::Overloaded;
|
||||
use tracing::Span;
|
||||
|
||||
use zebra_chain::serialization::SerializationError;
|
||||
use zebra_test::mock_service::{MockService, PanicAssertion};
|
||||
|
||||
use crate::{
|
||||
constants::REQUEST_TIMEOUT,
|
||||
constants::{MAX_OVERLOAD_DROP_PROBABILITY, MIN_OVERLOAD_DROP_PROBABILITY, REQUEST_TIMEOUT},
|
||||
peer::{
|
||||
connection::{Connection, State},
|
||||
connection::{overload_drop_connection_probability, Connection, State},
|
||||
ClientRequest, ErrorSlot,
|
||||
},
|
||||
protocol::external::Message,
|
||||
|
@ -656,6 +661,230 @@ async fn connection_run_loop_receive_timeout() {
|
|||
assert_eq!(outbound_message, None);
|
||||
}
|
||||
|
||||
/// Check basic properties of overload probabilities
|
||||
#[test]
|
||||
fn overload_probability_reduces_over_time() {
|
||||
let now = Instant::now();
|
||||
|
||||
// Edge case: previous is in the future due to OS monotonic clock bugs
|
||||
let prev = now + Duration::from_secs(1);
|
||||
assert_eq!(
|
||||
overload_drop_connection_probability(now, Some(prev)),
|
||||
MAX_OVERLOAD_DROP_PROBABILITY,
|
||||
"if the overload time is in the future (OS bugs?), it should have maximum drop probability",
|
||||
);
|
||||
|
||||
// Overload/DoS case/edge case: rapidly repeated overloads
|
||||
let prev = now;
|
||||
assert_eq!(
|
||||
overload_drop_connection_probability(now, Some(prev)),
|
||||
MAX_OVERLOAD_DROP_PROBABILITY,
|
||||
"if the overload times are the same, overloads should have maximum drop probability",
|
||||
);
|
||||
|
||||
// Overload/DoS case: rapidly repeated overloads
|
||||
let prev = now - Duration::from_micros(1);
|
||||
let drop_probability = overload_drop_connection_probability(now, Some(prev));
|
||||
assert!(
|
||||
drop_probability <= MAX_OVERLOAD_DROP_PROBABILITY,
|
||||
"if the overloads are very close together, drops can optionally decrease",
|
||||
);
|
||||
assert!(
|
||||
MAX_OVERLOAD_DROP_PROBABILITY - drop_probability < 0.001,
|
||||
"if the overloads are very close together, drops can only decrease slightly",
|
||||
);
|
||||
let last_probability = drop_probability;
|
||||
|
||||
// Overload/DoS case: rapidly repeated overloads
|
||||
let prev = now - Duration::from_millis(1);
|
||||
let drop_probability = overload_drop_connection_probability(now, Some(prev));
|
||||
assert!(
|
||||
drop_probability < last_probability,
|
||||
"if the overloads decrease, drops should decrease",
|
||||
);
|
||||
assert!(
|
||||
MAX_OVERLOAD_DROP_PROBABILITY - drop_probability < 0.001,
|
||||
"if the overloads are very close together, drops can only decrease slightly",
|
||||
);
|
||||
let last_probability = drop_probability;
|
||||
|
||||
// Overload/DoS case: rapidly repeated overloads
|
||||
let prev = now - Duration::from_millis(10);
|
||||
let drop_probability = overload_drop_connection_probability(now, Some(prev));
|
||||
assert!(
|
||||
drop_probability < last_probability,
|
||||
"if the overloads decrease, drops should decrease",
|
||||
);
|
||||
assert!(
|
||||
MAX_OVERLOAD_DROP_PROBABILITY - drop_probability < 0.001,
|
||||
"if the overloads are very close together, drops can only decrease slightly",
|
||||
);
|
||||
let last_probability = drop_probability;
|
||||
|
||||
// Overload case: frequent overloads
|
||||
let prev = now - Duration::from_millis(100);
|
||||
let drop_probability = overload_drop_connection_probability(now, Some(prev));
|
||||
assert!(
|
||||
drop_probability < last_probability,
|
||||
"if the overloads decrease, drops should decrease",
|
||||
);
|
||||
assert!(
|
||||
MAX_OVERLOAD_DROP_PROBABILITY - drop_probability < 0.01,
|
||||
"if the overloads are very close together, drops can only decrease slightly",
|
||||
);
|
||||
let last_probability = drop_probability;
|
||||
|
||||
// Overload case: occasional but repeated overloads
|
||||
let prev = now - Duration::from_secs(1);
|
||||
let drop_probability = overload_drop_connection_probability(now, Some(prev));
|
||||
assert!(
|
||||
drop_probability < last_probability,
|
||||
"if the overloads decrease, drops should decrease",
|
||||
);
|
||||
assert!(
|
||||
MAX_OVERLOAD_DROP_PROBABILITY - drop_probability > 0.5,
|
||||
"if the overloads are distant, drops should decrease a lot",
|
||||
);
|
||||
let last_probability = drop_probability;
|
||||
|
||||
// Overload case: occasional overloads
|
||||
let prev = now - Duration::from_secs(5);
|
||||
let drop_probability = overload_drop_connection_probability(now, Some(prev));
|
||||
assert!(
|
||||
drop_probability < last_probability,
|
||||
"if the overloads decrease, drops should decrease",
|
||||
);
|
||||
assert!(
|
||||
MAX_OVERLOAD_DROP_PROBABILITY - drop_probability > 0.7,
|
||||
"if the overloads are distant, drops should decrease a lot",
|
||||
);
|
||||
let _last_probability = drop_probability;
|
||||
|
||||
// Base case: infrequent overloads
|
||||
let prev = now - Duration::from_secs(10);
|
||||
let drop_probability = overload_drop_connection_probability(now, Some(prev));
|
||||
assert_eq!(
|
||||
drop_probability, MIN_OVERLOAD_DROP_PROBABILITY,
|
||||
"if overloads are far apart, drops should have minimum drop probability",
|
||||
);
|
||||
|
||||
// Base case: no previous overload
|
||||
let drop_probability = overload_drop_connection_probability(now, None);
|
||||
assert_eq!(
|
||||
drop_probability, MIN_OVERLOAD_DROP_PROBABILITY,
|
||||
"if there is no previous overload time, overloads should have minimum drop probability",
|
||||
);
|
||||
}
|
||||
|
||||
/// Test that connections are randomly terminated in response to `Overloaded` errors.
|
||||
///
|
||||
/// TODO: do a similar test on the real service stack created in the `start` command.
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn connection_is_randomly_disconnected_on_overload() {
|
||||
let _init_guard = zebra_test::init();
|
||||
|
||||
// The number of times we repeat the test
|
||||
const TEST_RUNS: usize = 220;
|
||||
// The expected number of tests before a test failure due to random chance.
|
||||
// Based on 10 tests per PR, 100 PR pushes per week, 50 weeks per year.
|
||||
const TESTS_BEFORE_FAILURE: f32 = 50_000.0;
|
||||
|
||||
let test_runs = TEST_RUNS.try_into().expect("constant fits in i32");
|
||||
// The probability of random test failure is:
|
||||
// MIN_OVERLOAD_DROP_PROBABILITY^TEST_RUNS + MAX_OVERLOAD_DROP_PROBABILITY^TEST_RUNS
|
||||
assert!(
|
||||
1.0 / MIN_OVERLOAD_DROP_PROBABILITY.powi(test_runs) > TESTS_BEFORE_FAILURE,
|
||||
"not enough test runs: failures must be frequent enough to happen in almost all tests"
|
||||
);
|
||||
assert!(
|
||||
1.0 / MAX_OVERLOAD_DROP_PROBABILITY.powi(test_runs) > TESTS_BEFORE_FAILURE,
|
||||
"not enough test runs: successes must be frequent enough to happen in almost all tests"
|
||||
);
|
||||
|
||||
let mut connection_continues = 0;
|
||||
let mut connection_closes = 0;
|
||||
|
||||
for _ in 0..TEST_RUNS {
|
||||
// The real stream and sink are from a split TCP connection,
|
||||
// but that doesn't change how the state machine behaves.
|
||||
let (mut peer_tx, peer_rx) = mpsc::channel(1);
|
||||
|
||||
let (
|
||||
connection,
|
||||
_client_tx,
|
||||
mut inbound_service,
|
||||
mut peer_outbound_messages,
|
||||
shared_error_slot,
|
||||
) = new_test_connection();
|
||||
|
||||
// The connection hasn't run so it must not have errors
|
||||
let error = shared_error_slot.try_get_error();
|
||||
assert!(
|
||||
error.is_none(),
|
||||
"unexpected error before starting the connection event loop: {error:?}",
|
||||
);
|
||||
|
||||
// Start the connection run loop future in a spawned task
|
||||
let connection_handle = tokio::spawn(connection.run(peer_rx));
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
|
||||
// The connection hasn't received any messages, so it must not have errors
|
||||
let error = shared_error_slot.try_get_error();
|
||||
assert!(
|
||||
error.is_none(),
|
||||
"unexpected error before sending messages to the connection event loop: {error:?}",
|
||||
);
|
||||
|
||||
// Simulate an overloaded connection error in response to an inbound request.
|
||||
let inbound_req = Message::GetAddr;
|
||||
peer_tx
|
||||
.send(Ok(inbound_req))
|
||||
.await
|
||||
.expect("send to channel always succeeds");
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
|
||||
// The connection hasn't got a response, so it must not have errors
|
||||
let error = shared_error_slot.try_get_error();
|
||||
assert!(
|
||||
error.is_none(),
|
||||
"unexpected error before sending responses to the connection event loop: {error:?}",
|
||||
);
|
||||
|
||||
inbound_service
|
||||
.expect_request(Request::Peers)
|
||||
.await
|
||||
.respond_error(Overloaded::new().into());
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
|
||||
let outbound_result = peer_outbound_messages.try_next();
|
||||
assert!(
|
||||
!matches!(outbound_result, Ok(Some(_))),
|
||||
"unexpected outbound message after Overloaded error:\n\
|
||||
{outbound_result:?}\n\
|
||||
note: TryRecvErr means there are no messages, Ok(None) means the channel is closed"
|
||||
);
|
||||
|
||||
let error = shared_error_slot.try_get_error();
|
||||
if error.is_some() {
|
||||
connection_closes += 1;
|
||||
} else {
|
||||
connection_continues += 1;
|
||||
}
|
||||
|
||||
// We need to terminate the spawned task
|
||||
connection_handle.abort();
|
||||
}
|
||||
|
||||
assert!(
|
||||
connection_closes > 0,
|
||||
"some overloaded connections must be closed at random"
|
||||
);
|
||||
assert!(
|
||||
connection_continues > 0,
|
||||
"some overloaded errors must be ignored at random"
|
||||
);
|
||||
}
|
||||
|
||||
/// Creates a new [`Connection`] instance for unit tests.
|
||||
fn new_test_connection() -> (
|
||||
Connection<
|
||||
|
|
|
@ -82,6 +82,10 @@ pub enum PeerError {
|
|||
#[error("Internal services over capacity")]
|
||||
Overloaded,
|
||||
|
||||
/// This node's internal services are no longer able to service requests.
|
||||
#[error("Internal services have failed or shutdown")]
|
||||
ServiceShutdown,
|
||||
|
||||
/// We requested data, but the peer replied with a `notfound` message.
|
||||
/// (Or it didn't respond before the request finished.)
|
||||
///
|
||||
|
@ -138,6 +142,7 @@ impl PeerError {
|
|||
PeerError::Serialization(inner) => format!("Serialization({inner})").into(),
|
||||
PeerError::DuplicateHandshake => "DuplicateHandshake".into(),
|
||||
PeerError::Overloaded => "Overloaded".into(),
|
||||
PeerError::ServiceShutdown => "ServiceShutdown".into(),
|
||||
PeerError::NotFoundResponse(_) => "NotFoundResponse".into(),
|
||||
PeerError::NotFoundRegistry(_) => "NotFoundRegistry".into(),
|
||||
}
|
||||
|
|
|
@ -740,7 +740,10 @@ impl<Request, Response, Error> ResponseSender<Request, Response, Error> {
|
|||
/// This method takes ownership of the [`ResponseSender`] so that only one response can be
|
||||
/// sent.
|
||||
///
|
||||
/// If `respond` or `respond_with` are not called, the caller will panic.
|
||||
/// # Panics
|
||||
///
|
||||
/// If one of the `respond*` methods isn't called, the [`MockService`] might panic with a
|
||||
/// timeout error.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
|
@ -748,6 +751,9 @@ impl<Request, Response, Error> ResponseSender<Request, Response, Error> {
|
|||
/// # use zebra_test::mock_service::MockService;
|
||||
/// # use tower::{Service, ServiceExt};
|
||||
/// #
|
||||
/// # #[derive(Debug, PartialEq, Eq)]
|
||||
/// # struct Request;
|
||||
/// #
|
||||
/// # let reactor = tokio::runtime::Builder::new_current_thread()
|
||||
/// # .enable_all()
|
||||
/// # .build()
|
||||
|
@ -760,19 +766,19 @@ impl<Request, Response, Error> ResponseSender<Request, Response, Error> {
|
|||
///
|
||||
/// # let mut service = mock_service.clone();
|
||||
/// # let task = tokio::spawn(async move {
|
||||
/// # let first_call_result = (&mut service).oneshot(1).await;
|
||||
/// # let second_call_result = service.oneshot(1).await;
|
||||
/// # let first_call_result = (&mut service).oneshot(Request).await;
|
||||
/// # let second_call_result = service.oneshot(Request).await;
|
||||
/// #
|
||||
/// # (first_call_result, second_call_result)
|
||||
/// # });
|
||||
/// #
|
||||
/// mock_service
|
||||
/// .expect_request(1)
|
||||
/// .expect_request(Request)
|
||||
/// .await
|
||||
/// .respond("Received one".to_owned());
|
||||
/// .respond("Received Request".to_owned());
|
||||
///
|
||||
/// mock_service
|
||||
/// .expect_request(1)
|
||||
/// .expect_request(Request)
|
||||
/// .await
|
||||
/// .respond(Err("Duplicate request"));
|
||||
/// # });
|
||||
|
@ -789,7 +795,10 @@ impl<Request, Response, Error> ResponseSender<Request, Response, Error> {
|
|||
/// This method takes ownership of the [`ResponseSender`] so that only one response can be
|
||||
/// sent.
|
||||
///
|
||||
/// If `respond` or `respond_with` are not called, the caller will panic.
|
||||
/// # Panics
|
||||
///
|
||||
/// If one of the `respond*` methods isn't called, the [`MockService`] might panic with a
|
||||
/// timeout error.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
|
@ -797,6 +806,9 @@ impl<Request, Response, Error> ResponseSender<Request, Response, Error> {
|
|||
/// # use zebra_test::mock_service::MockService;
|
||||
/// # use tower::{Service, ServiceExt};
|
||||
/// #
|
||||
/// # #[derive(Debug, PartialEq, Eq)]
|
||||
/// # struct Request;
|
||||
/// #
|
||||
/// # let reactor = tokio::runtime::Builder::new_current_thread()
|
||||
/// # .enable_all()
|
||||
/// # .build()
|
||||
|
@ -809,21 +821,21 @@ impl<Request, Response, Error> ResponseSender<Request, Response, Error> {
|
|||
///
|
||||
/// # let mut service = mock_service.clone();
|
||||
/// # let task = tokio::spawn(async move {
|
||||
/// # let first_call_result = (&mut service).oneshot(1).await;
|
||||
/// # let second_call_result = service.oneshot(1).await;
|
||||
/// # let first_call_result = (&mut service).oneshot(Request).await;
|
||||
/// # let second_call_result = service.oneshot(Request).await;
|
||||
/// #
|
||||
/// # (first_call_result, second_call_result)
|
||||
/// # });
|
||||
/// #
|
||||
/// mock_service
|
||||
/// .expect_request(1)
|
||||
/// .expect_request(Request)
|
||||
/// .await
|
||||
/// .respond_with(|req| format!("Received: {}", req));
|
||||
/// .respond_with(|req| format!("Received: {req:?}"));
|
||||
///
|
||||
/// mock_service
|
||||
/// .expect_request(1)
|
||||
/// .expect_request(Request)
|
||||
/// .await
|
||||
/// .respond_with(|req| Err(format!("Duplicate request: {}", req)));
|
||||
/// .respond_with(|req| Err(format!("Duplicate request: {req:?}")));
|
||||
/// # });
|
||||
/// ```
|
||||
pub fn respond_with<F, R>(self, response_fn: F)
|
||||
|
@ -834,6 +846,116 @@ impl<Request, Response, Error> ResponseSender<Request, Response, Error> {
|
|||
let response_result = response_fn(self.request()).into_result();
|
||||
let _ = self.response_sender.send(response_result);
|
||||
}
|
||||
|
||||
/// Respond to the request using a fixed error value.
|
||||
///
|
||||
/// The `error` must be the `Error` type. This helps avoid type resolution issues in the
|
||||
/// compiler.
|
||||
///
|
||||
/// This method takes ownership of the [`ResponseSender`] so that only one response can be
|
||||
/// sent.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If one of the `respond*` methods isn't called, the [`MockService`] might panic with a
|
||||
/// timeout error.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// # use zebra_test::mock_service::MockService;
|
||||
/// # use tower::{Service, ServiceExt};
|
||||
/// #
|
||||
/// # #[derive(Debug, PartialEq, Eq)]
|
||||
/// # struct Request;
|
||||
/// # struct Response;
|
||||
/// #
|
||||
/// # let reactor = tokio::runtime::Builder::new_current_thread()
|
||||
/// # .enable_all()
|
||||
/// # .build()
|
||||
/// # .expect("Failed to build Tokio runtime");
|
||||
/// #
|
||||
/// # reactor.block_on(async {
|
||||
/// // Mock a service with a `String` as the service `Error` type.
|
||||
/// let mut mock_service: MockService<Request, Response, _, String> =
|
||||
/// MockService::build().for_unit_tests();
|
||||
///
|
||||
/// # let mut service = mock_service.clone();
|
||||
/// # let task = tokio::spawn(async move {
|
||||
/// # let first_call_result = (&mut service).oneshot(Request).await;
|
||||
/// # let second_call_result = service.oneshot(Request).await;
|
||||
/// #
|
||||
/// # (first_call_result, second_call_result)
|
||||
/// # });
|
||||
/// #
|
||||
/// mock_service
|
||||
/// .expect_request(Request)
|
||||
/// .await
|
||||
/// .respond_error("Duplicate request".to_string());
|
||||
/// # });
|
||||
/// ```
|
||||
pub fn respond_error(self, error: Error) {
|
||||
// TODO: impl ResponseResult for BoxError/Error trait when overlapping impls are
|
||||
// better supported by the compiler
|
||||
let _ = self.response_sender.send(Err(error));
|
||||
}
|
||||
|
||||
/// Respond to the request by calculating an error from the request.
|
||||
///
|
||||
/// The `error` must be the `Error` type. This helps avoid type resolution issues in the
|
||||
/// compiler.
|
||||
///
|
||||
/// This method takes ownership of the [`ResponseSender`] so that only one response can be
|
||||
/// sent.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If one of the `respond*` methods isn't called, the [`MockService`] might panic with a
|
||||
/// timeout error.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// # use zebra_test::mock_service::MockService;
|
||||
/// # use tower::{Service, ServiceExt};
|
||||
/// #
|
||||
/// # #[derive(Debug, PartialEq, Eq)]
|
||||
/// # struct Request;
|
||||
/// # struct Response;
|
||||
/// #
|
||||
/// # let reactor = tokio::runtime::Builder::new_current_thread()
|
||||
/// # .enable_all()
|
||||
/// # .build()
|
||||
/// # .expect("Failed to build Tokio runtime");
|
||||
/// #
|
||||
/// # reactor.block_on(async {
|
||||
/// // Mock a service with a `String` as the service `Error` type.
|
||||
/// let mut mock_service: MockService<Request, Response, _, String> =
|
||||
/// MockService::build().for_unit_tests();
|
||||
///
|
||||
/// # let mut service = mock_service.clone();
|
||||
/// # let task = tokio::spawn(async move {
|
||||
/// # let first_call_result = (&mut service).oneshot(Request).await;
|
||||
/// # let second_call_result = service.oneshot(Request).await;
|
||||
/// #
|
||||
/// # (first_call_result, second_call_result)
|
||||
/// # });
|
||||
/// #
|
||||
/// mock_service
|
||||
/// .expect_request(Request)
|
||||
/// .await
|
||||
/// .respond_with_error(|req| format!("Duplicate request: {req:?}"));
|
||||
/// # });
|
||||
/// ```
|
||||
pub fn respond_with_error<F>(self, response_fn: F)
|
||||
where
|
||||
F: FnOnce(&Request) -> Error,
|
||||
{
|
||||
// TODO: impl ResponseResult for BoxError/Error trait when overlapping impls are
|
||||
// better supported by the compiler
|
||||
let response_result = Err(response_fn(self.request()));
|
||||
let _ = self.response_sender.send(response_result);
|
||||
}
|
||||
}
|
||||
|
||||
/// A representation of an assertion type.
|
||||
|
|
Loading…
Reference in New Issue