diff --git a/Cargo.lock b/Cargo.lock
index 33b64d56d..65022d904 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2169,7 +2169,6 @@ dependencies = [
 name = "tower-batch"
 version = "0.1.0"
 dependencies = [
- "color-eyre",
  "ed25519-zebra",
  "futures",
  "futures-core",
diff --git a/tower-batch/Cargo.toml b/tower-batch/Cargo.toml
index 7109d2677..3cd70545d 100644
--- a/tower-batch/Cargo.toml
+++ b/tower-batch/Cargo.toml
@@ -19,5 +19,4 @@ ed25519-zebra = "1.0"
 rand = "0.7"
 tokio = { version = "0.2", features = ["full"]}
 tracing = "0.1.16"
-color-eyre = "0.5"
 zebra-test = { path = "../zebra-test/" }
diff --git a/tower-batch/src/error.rs b/tower-batch/src/error.rs
index 7d35a8a1b..418957fc3 100644
--- a/tower-batch/src/error.rs
+++ b/tower-batch/src/error.rs
@@ -1,12 +1,47 @@
 //! Error types for the `Batch` middleware.
 
-use std::fmt;
+use crate::BoxError;
+use std::{fmt, sync::Arc};
+
+/// An error produced by a `Service` wrapped by a `Batch`.
+#[derive(Debug)]
+pub struct ServiceError {
+    inner: Arc<BoxError>,
+}
 
 /// An error produced when the batch worker closes unexpectedly.
 pub struct Closed {
     _p: (),
 }
 
+// ===== impl ServiceError =====
+
+impl ServiceError {
+    pub(crate) fn new(inner: BoxError) -> ServiceError {
+        let inner = Arc::new(inner);
+        ServiceError { inner }
+    }
+
+    // Private to avoid exposing `Clone` trait as part of the public API
+    pub(crate) fn clone(&self) -> ServiceError {
+        ServiceError {
+            inner: self.inner.clone(),
+        }
+    }
+}
+
+impl fmt::Display for ServiceError {
+    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+        write!(fmt, "batching service failed: {}", self.inner)
+    }
+}
+
+impl std::error::Error for ServiceError {
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        Some(&**self.inner)
+    }
+}
+
 // ===== impl Closed =====
 
 impl Closed {
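The new `ServiceError` mirrors the error-sharing pattern of tower-buffer, which this crate is adapted from: the inner service's `BoxError` is wrapped in an `Arc` exactly once, so a single failure can be handed to every pending and future caller without requiring the underlying error to implement `Clone`. A standalone sketch of the pattern, with illustrative names (`SharedError` is hypothetical, not part of the diff):

    use std::{error::Error, fmt, sync::Arc};

    type BoxError = Box<dyn Error + Send + Sync>;

    #[derive(Debug)]
    struct SharedError {
        inner: Arc<BoxError>,
    }

    impl SharedError {
        fn new(inner: BoxError) -> Self {
            SharedError {
                inner: Arc::new(inner),
            }
        }

        // Kept inherent (not `impl Clone`) so cloning stays crate-private.
        fn clone(&self) -> Self {
            SharedError {
                inner: self.inner.clone(),
            }
        }
    }

    impl fmt::Display for SharedError {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "service failed: {}", self.inner)
        }
    }

    impl Error for SharedError {
        fn source(&self) -> Option<&(dyn Error + 'static)> {
            Some(&**self.inner)
        }
    }

    fn main() {
        let failure = SharedError::new("batch worker poll_ready failed".into());
        // Each pending caller receives a clone backed by the same allocation...
        let a = failure.clone();
        let b = failure.clone();
        assert_eq!(a.to_string(), b.to_string());
        // ...and the original error stays reachable through the source chain.
        assert!(failure.source().is_some());
    }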
diff --git a/tower-batch/src/future.rs b/tower-batch/src/future.rs
index f41301fa8..ed96ce3fc 100644
--- a/tower-batch/src/future.rs
+++ b/tower-batch/src/future.rs
@@ -4,88 +4,47 @@
 use super::{error::Closed, message};
 use futures_core::ready;
 use pin_project::pin_project;
 use std::{
-    fmt::Debug,
     future::Future,
     pin::Pin,
     task::{Context, Poll},
 };
-use tower::Service;
 
 /// Future that completes when the batch processing is complete.
 #[pin_project]
-pub struct ResponseFuture<S, E2, Request>
-where
-    S: Service<crate::BatchControl<Request>>,
-{
+#[derive(Debug)]
+pub struct ResponseFuture<T> {
     #[pin]
-    state: ResponseState<S, E2, Request>,
-}
-
-impl<S, E2, Request> Debug for ResponseFuture<S, E2, Request>
-where
-    S: Service<crate::BatchControl<Request>>,
-    S::Future: Debug,
-    S::Error: Debug,
-    E2: Debug,
-{
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("ResponseFuture")
-            .field("state", &self.state)
-            .finish()
-    }
+    state: ResponseState<T>,
 }
 
 #[pin_project(project = ResponseStateProj)]
-enum ResponseState<S, E2, Request>
-where
-    S: Service<crate::BatchControl<Request>>,
-{
-    Failed(Option<E2>),
-    Rx(#[pin] message::Rx<S::Response, E2>),
-    Poll(#[pin] S::Future),
+#[derive(Debug)]
+enum ResponseState<T> {
+    Failed(Option<crate::BoxError>),
+    Rx(#[pin] message::Rx<T>),
+    Poll(#[pin] T),
 }
 
-impl<S, E2, Request> Debug for ResponseState<S, E2, Request>
-where
-    S: Service<crate::BatchControl<Request>>,
-    S::Future: Debug,
-    S::Error: Debug,
-    E2: Debug,
-{
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            ResponseState::Failed(e) => f.debug_tuple("ResponseState::Failed").field(e).finish(),
-            ResponseState::Rx(rx) => f.debug_tuple("ResponseState::Rx").field(rx).finish(),
-            ResponseState::Poll(fut) => f.debug_tuple("ResponseState::Pool").field(fut).finish(),
-        }
-    }
-}
-
-impl<S, E2, Request> ResponseFuture<S, E2, Request>
-where
-    S: Service<crate::BatchControl<Request>>,
-{
-    pub(crate) fn new(rx: message::Rx<S::Response, E2>) -> Self {
+impl<T> ResponseFuture<T> {
+    pub(crate) fn new(rx: message::Rx<T>) -> Self {
         ResponseFuture {
             state: ResponseState::Rx(rx),
         }
     }
 
-    pub(crate) fn failed(err: E2) -> Self {
+    pub(crate) fn failed(err: crate::BoxError) -> Self {
         ResponseFuture {
             state: ResponseState::Failed(Some(err)),
         }
     }
 }
 
-impl<S, E2, Request> Future for ResponseFuture<S, E2, Request>
+impl<F, T, E> Future for ResponseFuture<F>
 where
-    S: Service<crate::BatchControl<Request>>,
-    S::Future: Future<Output = Result<S::Response, S::Error>>,
-    S::Error: Into<E2>,
-    crate::error::Closed: Into<E2>,
+    F: Future<Output = Result<T, E>>,
+    E: Into<crate::BoxError>,
 {
-    type Output = Result<S::Response, E2>;
+    type Output = Result<T, crate::BoxError>;
 
     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
         let mut this = self.project();
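`ResponseFuture` is a hand-written state machine, but its control flow reads naturally as async code: first await the worker's reply (which hands back the inner service's response future), then await that future itself. A rough async/await equivalent (`respond` is a hypothetical helper, and a plain `BoxError` stands in for the real `ServiceError`/`Closed` types):

    use std::future::Future;
    use tokio::sync::oneshot;

    type BoxError = Box<dyn std::error::Error + Send + Sync>;

    // Stage `Rx`: wait for the worker to hand back the inner service's
    // response future. Stage `Poll`: drive that future, boxing its error.
    // A dropped sender plays the role of the `Closed` error.
    async fn respond<F, T, E>(rx: oneshot::Receiver<Result<F, BoxError>>) -> Result<T, BoxError>
    where
        F: Future<Output = Result<T, E>>,
        E: Into<BoxError>,
    {
        match rx.await {
            Ok(Ok(response_future)) => response_future.await.map_err(Into::into),
            Ok(Err(worker_error)) => Err(worker_error),
            Err(_recv_error) => Err("batch worker closed unexpectedly".into()),
        }
    }

    #[tokio::main]
    async fn main() -> Result<(), BoxError> {
        let (tx, rx) = oneshot::channel();
        // The worker side sends back a future; here, an immediately-ready one.
        let _ = tx.send(Ok(async { Ok::<_, BoxError>(42u8) }));
        assert_eq!(respond(rx).await?, 42);
        Ok(())
    }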
diff --git a/tower-batch/src/layer.rs b/tower-batch/src/layer.rs
index b137d19f8..8fda7ba30 100644
--- a/tower-batch/src/layer.rs
+++ b/tower-batch/src/layer.rs
@@ -9,14 +9,13 @@ use tower::Service;
 /// which means that this layer can only be used on the Tokio runtime.
 ///
 /// See the module documentation for more details.
-pub struct BatchLayer<Request, E2> {
+pub struct BatchLayer<Request> {
     max_items: usize,
     max_latency: std::time::Duration,
     _p: PhantomData<fn(Request)>,
-    _e: PhantomData<E2>,
 }
 
-impl<Request, E2> BatchLayer<Request, E2> {
+impl<Request> BatchLayer<Request> {
     /// Creates a new `BatchLayer`.
     ///
     /// The wrapper is responsible for telling the inner service when to flush a
@@ -29,28 +28,25 @@ impl<Request, E2> BatchLayer<Request, E2> {
             max_items,
             max_latency,
             _p: PhantomData,
-            _e: PhantomData,
         }
     }
 }
 
-impl<S, Request, E2> Layer<S> for BatchLayer<Request, E2>
+impl<S, Request> Layer<S> for BatchLayer<Request>
 where
     S: Service<BatchControl<Request>> + Send + 'static,
     S::Future: Send,
-    S::Error: Clone + Into<E2> + Send + Sync,
+    S::Error: Into<crate::BoxError> + Send + Sync,
     Request: Send + 'static,
-    E2: Send + 'static,
-    crate::error::Closed: Into<E2>,
 {
-    type Service = Batch<S, Request, E2>;
+    type Service = Batch<S, Request>;
 
     fn layer(&self, service: S) -> Self::Service {
         Batch::new(service, self.max_items, self.max_latency)
     }
 }
 
-impl<Request, E2> fmt::Debug for BatchLayer<Request, E2> {
+impl<Request> fmt::Debug for BatchLayer<Request> {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         f.debug_struct("BufferLayer")
             .field("max_items", &self.max_items)
diff --git a/tower-batch/src/message.rs b/tower-batch/src/message.rs
index 7e433592b..dc73a6adb 100644
--- a/tower-batch/src/message.rs
+++ b/tower-batch/src/message.rs
@@ -1,15 +1,16 @@
+use super::error::ServiceError;
 use tokio::sync::oneshot;
 
 /// Message sent to the batch worker
 #[derive(Debug)]
-pub(crate) struct Message<Request, Fut, E2> {
+pub(crate) struct Message<Request, Fut> {
     pub(crate) request: Request,
-    pub(crate) tx: Tx<Fut, E2>,
+    pub(crate) tx: Tx<Fut>,
     pub(crate) span: tracing::Span,
 }
 
 /// Response sender
-pub(crate) type Tx<Fut, E2> = oneshot::Sender<Result<Fut, E2>>;
+pub(crate) type Tx<Fut> = oneshot::Sender<Result<Fut, ServiceError>>;
 
 /// Response receiver
-pub(crate) type Rx<Fut, E2> = oneshot::Receiver<Result<Fut, E2>>;
+pub(crate) type Rx<Fut> = oneshot::Receiver<Result<Fut, ServiceError>>;
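With `E2` gone, `BatchLayer` is generic only over the request type, so constructing one no longer requires naming an error type. A hypothetical construction sketch, assuming `BatchLayer` is re-exported at the crate root and that `new` takes `(max_items, max_latency)` as the constructor body above suggests:

    use std::time::Duration;
    use tower_batch::BatchLayer;

    // `Item` stands in for whatever request type the wrapped service batches.
    type Item = u64;

    fn verification_layer() -> BatchLayer<Item> {
        // Flush a batch at 32 pending requests, or 10ms after the first
        // request arrives, whichever happens first.
        BatchLayer::new(32, Duration::from_millis(10))
    }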
diff --git a/tower-batch/src/service.rs b/tower-batch/src/service.rs
index ceddcc2f8..28c4e7132 100644
--- a/tower-batch/src/service.rs
+++ b/tower-batch/src/service.rs
@@ -6,10 +6,7 @@ use super::{
 };
 
 use futures_core::ready;
-use std::{
-    marker::PhantomData,
-    task::{Context, Poll},
-};
+use std::task::{Context, Poll};
 use tokio::sync::{mpsc, oneshot};
 use tower::Service;
 
@@ -17,23 +14,18 @@ use tower::Service;
 ///
 /// See the module documentation for more details.
 #[derive(Debug)]
-pub struct Batch<S, Request, E2>
+pub struct Batch<T, Request>
 where
-    S: Service<BatchControl<Request>>,
+    T: Service<BatchControl<Request>>,
 {
-    tx: mpsc::Sender<Message<Request, S::Future, E2>>,
-    handle: Handle<S::Error, E2>,
-    _e: PhantomData<E2>,
+    tx: mpsc::Sender<Message<Request, T::Future>>,
+    handle: Handle,
 }
-impl<S, Request, E2> Batch<S, Request, E2>
+impl<T, Request> Batch<T, Request>
 where
-    S: Service<BatchControl<Request>>,
-    S::Error: Into<E2> + Clone,
-    E2: Send + 'static,
-    crate::error::Closed: Into<E2>,
-    // crate::error::Closed: Into<<S as Service<BatchControl<Request>>>::Error> + Send + Sync + 'static,
-    // crate::error::ServiceError: Into<<S as Service<BatchControl<Request>>>::Error> + Send + Sync + 'static,
+    T: Service<BatchControl<Request>>,
+    T::Error: Into<crate::BoxError>,
 {
     /// Creates a new `Batch` wrapping `service`.
     ///
     /// The wrapper is responsible for telling the inner service when to flush a
@@ -45,39 +37,33 @@ where
     ///
     /// The default Tokio executor is used to run the given service, which means
    /// that this method must be called while on the Tokio runtime.
-    pub fn new(service: S, max_items: usize, max_latency: std::time::Duration) -> Self
+    pub fn new(service: T, max_items: usize, max_latency: std::time::Duration) -> Self
     where
-        S: Send + 'static,
-        S::Future: Send,
-        S::Error: Send + Sync + Clone,
+        T: Send + 'static,
+        T::Future: Send,
+        T::Error: Send + Sync,
         Request: Send + 'static,
     {
         // XXX(hdevalence): is this bound good
         let (tx, rx) = mpsc::channel(1);
         let (handle, worker) = Worker::new(service, rx, max_items, max_latency);
         tokio::spawn(worker.run());
-        Batch {
-            tx,
-            handle,
-            _e: PhantomData,
-        }
+        Batch { tx, handle }
     }
 
-    fn get_worker_error(&self) -> E2 {
+    fn get_worker_error(&self) -> crate::BoxError {
         self.handle.get_error_on_closed()
     }
 }
 
-impl<S, Request, E2> Service<Request> for Batch<S, Request, E2>
+impl<T, Request> Service<Request> for Batch<T, Request>
 where
-    S: Service<BatchControl<Request>>,
-    crate::error::Closed: Into<E2>,
-    S::Error: Into<E2> + Clone,
-    E2: Send + 'static,
+    T: Service<BatchControl<Request>>,
+    T::Error: Into<crate::BoxError>,
 {
-    type Response = S::Response;
-    type Error = E2;
-    type Future = ResponseFuture<S, E2, Request>;
+    type Response = T::Response;
+    type Error = crate::BoxError;
+    type Future = ResponseFuture<T::Future>;
 
     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
         // If the inner service has errored, then we error here.
@@ -119,15 +105,14 @@ where
     }
 }
 
-impl<S, Request, E2> Clone for Batch<S, Request, E2>
+impl<T, Request> Clone for Batch<T, Request>
 where
-    S: Service<BatchControl<Request>>,
+    T: Service<BatchControl<Request>>,
 {
     fn clone(&self) -> Self {
         Self {
             tx: self.tx.clone(),
             handle: self.handle.clone(),
-            _e: PhantomData,
         }
     }
 }
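Erasing the error type to `crate::BoxError` trades type-level precision for uniform ergonomics: any caller that already works in terms of boxed errors can drive a `Batch` with a plain `?`. A caller-side sketch (`call_one` is hypothetical; `ready_and` is the tower 0.3 `ServiceExt` method that the worker itself uses):

    use tower::{Service, ServiceExt};

    type BoxError = Box<dyn std::error::Error + Send + Sync>;

    // Works for any service whose error type is already BoxError, which is
    // what `Batch` now exposes regardless of the inner service's error type.
    async fn call_one<S, R>(mut svc: S, req: R) -> Result<S::Response, BoxError>
    where
        S: Service<R, Error = BoxError>,
    {
        // Wait for capacity, then dispatch; both steps yield BoxError on failure.
        svc.ready_and().await?.call(req).await
    }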
diff --git a/tower-batch/src/worker.rs b/tower-batch/src/worker.rs
index 3d4121a12..8d4ab367e 100644
--- a/tower-batch/src/worker.rs
+++ b/tower-batch/src/worker.rs
@@ -1,14 +1,11 @@
 use super::{
-    error::Closed,
+    error::{Closed, ServiceError},
     message::{self, Message},
     BatchControl,
 };
 use futures::future::TryFutureExt;
 use pin_project::pin_project;
-use std::{
-    marker::PhantomData,
-    sync::{Arc, Mutex},
-};
+use std::sync::{Arc, Mutex};
 use tokio::{
     stream::StreamExt,
     sync::mpsc,
@@ -26,41 +23,38 @@ use tracing_futures::Instrument;
 /// implement (only call).
 #[pin_project]
 #[derive(Debug)]
-pub struct Worker<S, Request, E2>
+pub struct Worker<T, Request>
 where
-    S: Service<BatchControl<Request>>,
-    S::Error: Into<E2>,
+    T: Service<BatchControl<Request>>,
+    T::Error: Into<crate::BoxError>,
 {
-    rx: mpsc::Receiver<Message<Request, S::Future, E2>>,
-    service: S,
-    failed: Option<E2>,
-    handle: Handle<S::Error, E2>,
+    rx: mpsc::Receiver<Message<Request, T::Future>>,
+    service: T,
+    failed: Option<ServiceError>,
+    handle: Handle,
     max_items: usize,
     max_latency: std::time::Duration,
-    _e: PhantomData<E2>,
 }
 
 /// Get the error out
 #[derive(Debug)]
-pub(crate) struct Handle<E, E2> {
-    inner: Arc<Mutex<Option<E>>>,
-    _e: PhantomData<E2>,
+pub(crate) struct Handle {
+    inner: Arc<Mutex<Option<ServiceError>>>,
 }
 
-impl<S, Request, E2> Worker<S, Request, E2>
+impl<T, Request> Worker<T, Request>
 where
-    S: Service<BatchControl<Request>>,
-    S::Error: Into<E2> + Clone,
+    T: Service<BatchControl<Request>>,
+    T::Error: Into<crate::BoxError>,
 {
     pub(crate) fn new(
-        service: S,
-        rx: mpsc::Receiver<Message<Request, S::Future, E2>>,
+        service: T,
+        rx: mpsc::Receiver<Message<Request, T::Future>>,
         max_items: usize,
         max_latency: std::time::Duration,
-    ) -> (Handle<S::Error, E2>, Worker<S, Request, E2>) {
+    ) -> (Handle, Worker<T, Request>) {
         let handle = Handle {
             inner: Arc::new(Mutex::new(None)),
-            _e: PhantomData,
         };
 
         let worker = Worker {
@@ -70,16 +64,15 @@ where
             failed: None,
             max_items,
             max_latency,
-            _e: PhantomData,
         };
 
         (handle, worker)
     }
 
-    async fn process_req(&mut self, req: Request, tx: message::Tx<S::Future, E2>) {
-        if let Some(failed) = self.failed.clone() {
+    async fn process_req(&mut self, req: Request, tx: message::Tx<T::Future>) {
+        if let Some(ref failed) = self.failed {
             tracing::trace!("notifying caller about worker failure");
-            let _ = tx.send(Err(failed));
+            let _ = tx.send(Err(failed.clone()));
         } else {
             match self.service.ready_and().await {
                 Ok(svc) => {
@@ -87,11 +80,12 @@ where
                     let _ = tx.send(Ok(rsp));
                 }
                 Err(e) => {
-                    self.failed(e);
+                    self.failed(e.into());
                     let _ = tx.send(Err(self
                         .failed
-                        .clone()
-                        .expect("Worker::failed did not set self.failed?")));
+                        .as_ref()
+                        .expect("Worker::failed did not set self.failed?")
+                        .clone()));
                 }
             }
         }
@@ -104,7 +98,7 @@ where
             .and_then(|svc| svc.call(BatchControl::Flush))
             .await
         {
-            self.failed(e);
+            self.failed(e.into());
         }
     }
@@ -171,12 +165,11 @@ where
         }
     }
 
-    fn failed(&mut self, error: S::Error) {
-        // The underlying service failed when we called `poll_ready` on it with
-        // the given `error`. We need to communicate this to all the `Buffer`
-        // handles. To do so, we require that `S::Error` implements `Clone`,
-        // clone the error to send to all pending requests, and store it so that
-        // subsequent requests will also fail with the same error.
+    fn failed(&mut self, error: crate::BoxError) {
+        // The underlying service failed when we called `poll_ready` on it with the given `error`. We
+        // need to communicate this to all the `Buffer` handles. To do so, we wrap up the error in
+        // an `Arc`, send that `Arc` to all pending requests, and store it so that subsequent
+        // requests will also fail with the same error.
 
         // Note that we need to handle the case where some handle is concurrently trying to send us
         // a request. We need to make sure that *either* the send of the request fails *or* it
@@ -185,6 +178,7 @@ where
         // request. We do this by *first* exposing the error, *then* closing the channel used to
        // send more requests (so the client will see the error when the send fails), and *then*
        // sending the error to all outstanding requests.
+        let error = ServiceError::new(error);
 
         let mut inner = self.handle.inner.lock().unwrap();
 
@@ -206,26 +200,21 @@ where
     }
 }
 
-impl<E, E2> Handle<E, E2>
-where
-    E: Clone + Into<E2>,
-    crate::error::Closed: Into<E2>,
-{
-    pub(crate) fn get_error_on_closed(&self) -> E2 {
+impl Handle {
+    pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
         self.inner
             .lock()
             .unwrap()
-            .clone()
-            .map(Into::into)
+            .as_ref()
+            .map(|svc_err| svc_err.clone().into())
             .unwrap_or_else(|| Closed::new().into())
     }
 }
 
-impl<E, E2> Clone for Handle<E, E2> {
-    fn clone(&self) -> Handle<E, E2> {
+impl Clone for Handle {
+    fn clone(&self) -> Handle {
         Handle {
             inner: self.inner.clone(),
-            _e: PhantomData,
         }
     }
 }
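The ordering documented in `failed` is the load-bearing part of the protocol: publish the error, then close the channel, then drain the queue, so a sender either enqueues before the close and is drained, or fails to send and reads the published error. A self-contained sketch of that protocol with stand-in types (`Arc<String>` playing the role of `ServiceError`):

    use std::sync::{Arc, Mutex};
    use tokio::sync::{mpsc, oneshot};

    type SharedError = Arc<String>; // stand-in for the ServiceError above

    async fn fail_pending(
        error: String,
        slot: &Mutex<Option<SharedError>>,
        rx: &mut mpsc::Receiver<oneshot::Sender<Result<(), SharedError>>>,
    ) {
        let error = Arc::new(error);
        // 1. Expose the error first, so a handle whose send fails can find it.
        *slot.lock().unwrap() = Some(error.clone());
        // 2. Close the channel: later sends fail, steering callers to the slot.
        rx.close();
        // 3. Drain requests that were already queued; each pending caller
        //    gets a cheap clone of the same error.
        while let Some(tx) = rx.recv().await {
            let _ = tx.send(Err(error.clone()));
        }
    }

    #[tokio::main]
    async fn main() {
        let slot = Mutex::new(None);
        let (mut tx, mut rx) = mpsc::channel(8);
        let (reply_tx, reply_rx) = oneshot::channel();
        tx.send(reply_tx).await.unwrap(); // one request already in flight
        fail_pending("poll_ready failed".to_string(), &slot, &mut rx).await;
        assert!(reply_rx.await.unwrap().is_err()); // pending caller was notified
        assert!(slot.lock().unwrap().is_some()); // later callers read the slot
    }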
diff --git a/tower-batch/tests/ed25519.rs b/tower-batch/tests/ed25519.rs
index f1b44ff9d..71446759a 100644
--- a/tower-batch/tests/ed25519.rs
+++ b/tower-batch/tests/ed25519.rs
@@ -6,7 +6,6 @@ use std::{
     time::Duration,
 };
 
-use color_eyre::eyre::Result;
 use ed25519_zebra::*;
 use futures::stream::{FuturesUnordered, StreamExt};
 use rand::thread_rng;
@@ -109,23 +108,31 @@ where
 }
 
 #[tokio::test]
-async fn batch_flushes_on_max_items() -> Result<()> {
+async fn batch_flushes_on_max_items() {
     use tokio::time::timeout;
     zebra_test::init();
 
     // Use a very long max_latency and a short timeout to check that
     // flushing is happening based on hitting max_items.
     let verifier = Batch::new(Ed25519Verifier::new(), 10, Duration::from_secs(1000));
-    timeout(Duration::from_secs(1), sign_and_verify(verifier, 100)).await?
+    assert!(
+        timeout(Duration::from_secs(1), sign_and_verify(verifier, 100))
+            .await
+            .is_ok()
+    );
 }
 
 #[tokio::test]
-async fn batch_flushes_on_max_latency() -> Result<()> {
+async fn batch_flushes_on_max_latency() {
     use tokio::time::timeout;
     zebra_test::init();
 
     // Use a very high max_items and a short timeout to check that
     // flushing is happening based on hitting max_latency.
     let verifier = Batch::new(Ed25519Verifier::new(), 100, Duration::from_millis(500));
-    timeout(Duration::from_secs(1), sign_and_verify(verifier, 10)).await?
+    assert!(
+        timeout(Duration::from_secs(1), sign_and_verify(verifier, 10))
+            .await
+            .is_ok()
+    );
 }