zebra/zebra-test/src/transcript.rs

195 lines
6.7 KiB
Rust

//! A [`Service`](tower::Service) implementation based on a fixed transcript.
use std::{
fmt::Debug,
sync::Arc,
task::{Context, Poll},
};
use color_eyre::{
eyre::{eyre, Report, WrapErr},
section::Section,
section::SectionExt,
};
use futures::future::{ready, Ready};
use tower::{Service, ServiceExt};
/// A boxed, dynamically typed error that is `Send + Sync + 'static`.
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
/// An error-checking function: is the value an expected error?
///
/// If the checked error is the expected error, the function should return `Ok(())`.
/// Otherwise, it should just return the checked error, wrapped inside `Err`.
///
/// The function is also called with `None` when a mock of the expected error
/// is needed. In that case it must return `Err` holding the mock error:
/// returning `Ok(())` for `None` makes the mock-error construction panic.
pub type ErrorChecker = fn(Option<BoxError>) -> Result<(), BoxError>;
/// An expected error in a transcript.
#[derive(Debug, Clone)]
pub enum ExpectedTranscriptError {
/// Match any error: every error returned by the service is accepted.
Any,
/// Use a validator function to check for matching errors.
///
/// The checker is called with `Some(error)` to validate an actual error,
/// and with `None` to produce a mock of the expected error.
Exact(Arc<ErrorChecker>),
}
impl ExpectedTranscriptError {
/// Convert the `verifier` function into an exact error checker
pub fn exact(verifier: ErrorChecker) -> Self {
ExpectedTranscriptError::Exact(verifier.into())
}
/// Check the actual error `e` against this expected error.
#[track_caller]
fn check(&self, e: BoxError) -> Result<(), Report> {
match self {
ExpectedTranscriptError::Any => Ok(()),
ExpectedTranscriptError::Exact(checker) => checker(Some(e)),
}
.map_err(ErrorCheckerError)
.wrap_err("service returned an error but it didn't match the expected error")
}
fn mock(&self) -> Report {
match self {
ExpectedTranscriptError::Any => eyre!("mock error"),
ExpectedTranscriptError::Exact(checker) => {
checker(None).map_err(|e| eyre!(e)).expect_err(
"transcript should correctly produce the expected mock error when passed None",
)
}
}
}
}
/// Newtype around a [`BoxError`] that gives it a concrete error type, so it
/// can be wrapped in or attached to a [`Report`].
#[derive(Debug, thiserror::Error)]
#[error("ErrorChecker Error: {0}")]
struct ErrorCheckerError(BoxError);
/// A transcript: a list of requests and expected results.
///
/// Each item pairs a request `R` with either the expected response `S` or an
/// [`ExpectedTranscriptError`] describing the expected failure.
#[must_use]
pub struct Transcript<R, S, I>
where
I: Iterator<Item = (R, Result<S, ExpectedTranscriptError>)>,
{
/// The `(request, expected result)` pairs, consumed from the front.
messages: I,
}
impl<R, S, I> From<I> for Transcript<R, S, I::IntoIter>
where
I: IntoIterator<Item = (R, Result<S, ExpectedTranscriptError>)>,
{
fn from(messages: I) -> Self {
Self {
messages: messages.into_iter(),
}
}
}
impl<R, S, I> Transcript<R, S, I>
where
I: Iterator<Item = (R, Result<S, ExpectedTranscriptError>)>,
R: Debug,
S: Debug + Eq,
{
/// Check this transcript against the responses from the `to_check` service
pub async fn check<C>(mut self, mut to_check: C) -> Result<(), Report>
where
C: Service<R, Response = S>,
C::Error: Into<BoxError>,
{
for (req, expected_rsp) in &mut self.messages {
// These unwraps could propagate errors with the correct
// bound on C::Error
let fut = to_check
.ready()
.await
.map_err(Into::into)
.map_err(|e| eyre!(e))
.expect("expected service to not fail during execution of transcript");
let response = fut.call(req).await;
match (response, expected_rsp) {
(Ok(rsp), Ok(expected_rsp)) => {
if rsp != expected_rsp {
Err(eyre!(
"response doesn't match transcript's expected response"
))
.with_section(|| format!("{expected_rsp:?}").header("Expected Response:"))
.with_section(|| format!("{rsp:?}").header("Found Response:"))?;
}
}
(Ok(rsp), Err(error_checker)) => {
let error = Err(eyre!("received a response when an error was expected"))
.with_section(|| format!("{rsp:?}").header("Found Response:"));
let error = match std::panic::catch_unwind(|| error_checker.mock()) {
Ok(expected_err) => error
.with_section(|| format!("{expected_err:?}").header("Expected Error:")),
Err(pi) => {
let payload = pi
.downcast_ref::<String>()
.cloned()
.or_else(|| pi.downcast_ref::<&str>().map(ToString::to_string))
.unwrap_or_else(|| "<non string panic payload>".into());
error
.section(payload.header("Panic:"))
.wrap_err("ErrorChecker panicked when producing expected response")
}
};
error?;
}
(Err(e), Ok(expected_rsp)) => {
Err(eyre!("received an error when a response was expected"))
.with_error(|| ErrorCheckerError(e.into()))
.with_section(|| format!("{expected_rsp:?}").header("Expected Response:"))?
}
(Err(e), Err(error_checker)) => {
error_checker.check(e.into())?;
continue;
}
}
}
Ok(())
}
}
/// Use a [`Transcript`] as a mock service that answers requests with the
/// transcript's recorded results, in order.
impl<R, S, I> Service<R> for Transcript<R, S, I>
where
    R: Debug + Eq,
    I: Iterator<Item = (R, Result<S, ExpectedTranscriptError>)>,
{
    type Response = S;
    type Error = Report;
    type Future = Ready<Result<S, Report>>;

    /// The transcript is always ready.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Consume the next transcript entry and answer `request` with it.
    ///
    /// - If the entry expects a response, `request` must equal the recorded
    ///   request; otherwise an error describing both requests is returned.
    /// - If the entry expects an error, a mock of that error is returned.
    ///   NOTE: in this case `request` is not compared with the recorded request.
    /// - If the transcript is exhausted, an error is returned.
    #[track_caller]
    fn call(&mut self, request: R) -> Self::Future {
        let outcome: Result<S, Report> = match self.messages.next() {
            None => Err(eyre!("Got request after transcript ended")),
            Some((_expected_request, Err(expected_error))) => Err(expected_error.mock()),
            Some((expected_request, Ok(response))) => {
                if request == expected_request {
                    Ok(response)
                } else {
                    Err(eyre!("received unexpected request"))
                        .with_section(|| {
                            format!("{expected_request:?}").header("Expected Request:")
                        })
                        .with_section(|| format!("{request:?}").header("Found Request:"))
                }
            }
        };

        ready(outcome)
    }
}