Revert #500 (generic errors in tower-batch).

Unfortunately, since the Batch wrapper was changed to have a generic error
type, when wrapping it in another Service, nothing constrains the error type,
so we have to specify it explicitly to avoid an inference hole.  This is pretty
unergonomic -- from the compiler error message it's very unintuitive that the
right fix is to change `Batch::new` to `Batch::<_, _, SomeError>::new`.

The options are:

1. roll back the changes that make the error type generic, so that the error
   type is a concrete type;

2. keep the error type generic but hardcode the error in the default
   constructor and add an additional code path that allows overriding the
   error.

However, there's a further issue with generic errors: the error type must be
Clone.  This problem comes from the fact that there can be multiple Batch
handles that have to share access to errors generated by the inner Batch
worker, so there's not a way to work around this.  However, almost all error
types aren't Clone, so there are fairly few error types that we would be
swapping in.

This suggests that in case (2) we would be maintaining extra code to allow
generic errors, but with restrictive enough generic bounds to make it
impractical to use generic error types.  For this reason I think that (1) is a
better option.
This commit is contained in:
Henry de Valence 2020-07-15 21:42:57 -07:00 committed by Deirdre Connolly
parent 7067ac6e0d
commit 0586da7167
9 changed files with 133 additions and 163 deletions

1
Cargo.lock generated
View File

@ -2169,7 +2169,6 @@ dependencies = [
name = "tower-batch" name = "tower-batch"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"color-eyre",
"ed25519-zebra", "ed25519-zebra",
"futures", "futures",
"futures-core", "futures-core",

View File

@ -19,5 +19,4 @@ ed25519-zebra = "1.0"
rand = "0.7" rand = "0.7"
tokio = { version = "0.2", features = ["full"]} tokio = { version = "0.2", features = ["full"]}
tracing = "0.1.16" tracing = "0.1.16"
color-eyre = "0.5"
zebra-test = { path = "../zebra-test/" } zebra-test = { path = "../zebra-test/" }

View File

@ -1,12 +1,47 @@
//! Error types for the `Batch` middleware. //! Error types for the `Batch` middleware.
use std::fmt; use crate::BoxError;
use std::{fmt, sync::Arc};
/// An error produced by a `Service` wrapped by a `Batch`.
#[derive(Debug)]
pub struct ServiceError {
inner: Arc<BoxError>,
}
/// An error produced when the batch worker closes unexpectedly. /// An error produced when the batch worker closes unexpectedly.
pub struct Closed { pub struct Closed {
_p: (), _p: (),
} }
// ===== impl ServiceError =====
impl ServiceError {
pub(crate) fn new(inner: BoxError) -> ServiceError {
let inner = Arc::new(inner);
ServiceError { inner }
}
// Private to avoid exposing `Clone` trait as part of the public API
pub(crate) fn clone(&self) -> ServiceError {
ServiceError {
inner: self.inner.clone(),
}
}
}
impl fmt::Display for ServiceError {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, "batching service failed: {}", self.inner)
}
}
impl std::error::Error for ServiceError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(&**self.inner)
}
}
// ===== impl Closed ===== // ===== impl Closed =====
impl Closed { impl Closed {

View File

@ -4,88 +4,47 @@ use super::{error::Closed, message};
use futures_core::ready; use futures_core::ready;
use pin_project::pin_project; use pin_project::pin_project;
use std::{ use std::{
fmt::Debug,
future::Future, future::Future,
pin::Pin, pin::Pin,
task::{Context, Poll}, task::{Context, Poll},
}; };
use tower::Service;
/// Future that completes when the batch processing is complete. /// Future that completes when the batch processing is complete.
#[pin_project] #[pin_project]
pub struct ResponseFuture<S, E2, Response> #[derive(Debug)]
where pub struct ResponseFuture<T> {
S: Service<crate::BatchControl<Response>>,
{
#[pin] #[pin]
state: ResponseState<S, E2, Response>, state: ResponseState<T>,
}
impl<S, E2, Response> Debug for ResponseFuture<S, E2, Response>
where
S: Service<crate::BatchControl<Response>>,
S::Future: Debug,
S::Error: Debug,
E2: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ResponseFuture")
.field("state", &self.state)
.finish()
}
} }
#[pin_project(project = ResponseStateProj)] #[pin_project(project = ResponseStateProj)]
enum ResponseState<S, E2, Response> #[derive(Debug)]
where enum ResponseState<T> {
S: Service<crate::BatchControl<Response>>, Failed(Option<crate::BoxError>),
{ Rx(#[pin] message::Rx<T>),
Failed(Option<E2>), Poll(#[pin] T),
Rx(#[pin] message::Rx<S::Future, S::Error>),
Poll(#[pin] S::Future),
} }
impl<S, E2, Response> Debug for ResponseState<S, E2, Response> impl<T> ResponseFuture<T> {
where pub(crate) fn new(rx: message::Rx<T>) -> Self {
S: Service<crate::BatchControl<Response>>,
S::Future: Debug,
S::Error: Debug,
E2: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ResponseState::Failed(e) => f.debug_tuple("ResponseState::Failed").field(e).finish(),
ResponseState::Rx(rx) => f.debug_tuple("ResponseState::Rx").field(rx).finish(),
ResponseState::Poll(fut) => f.debug_tuple("ResponseState::Pool").field(fut).finish(),
}
}
}
impl<S, E2, Response> ResponseFuture<S, E2, Response>
where
S: Service<crate::BatchControl<Response>>,
{
pub(crate) fn new(rx: message::Rx<S::Future, S::Error>) -> Self {
ResponseFuture { ResponseFuture {
state: ResponseState::Rx(rx), state: ResponseState::Rx(rx),
} }
} }
pub(crate) fn failed(err: E2) -> Self { pub(crate) fn failed(err: crate::BoxError) -> Self {
ResponseFuture { ResponseFuture {
state: ResponseState::Failed(Some(err)), state: ResponseState::Failed(Some(err)),
} }
} }
} }
impl<S, E2, Response> Future for ResponseFuture<S, E2, Response> impl<F, T, E> Future for ResponseFuture<F>
where where
S: Service<crate::BatchControl<Response>>, F: Future<Output = Result<T, E>>,
S::Future: Future<Output = Result<S::Response, S::Error>>, E: Into<crate::BoxError>,
S::Error: Into<E2>,
crate::error::Closed: Into<E2>,
{ {
type Output = Result<S::Response, E2>; type Output = Result<T, crate::BoxError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project(); let mut this = self.project();

View File

@ -9,14 +9,13 @@ use tower::Service;
/// which means that this layer can only be used on the Tokio runtime. /// which means that this layer can only be used on the Tokio runtime.
/// ///
/// See the module documentation for more details. /// See the module documentation for more details.
pub struct BatchLayer<Request, E2> { pub struct BatchLayer<Request> {
max_items: usize, max_items: usize,
max_latency: std::time::Duration, max_latency: std::time::Duration,
_p: PhantomData<fn(Request)>, _p: PhantomData<fn(Request)>,
_e: PhantomData<E2>,
} }
impl<Request, E2> BatchLayer<Request, E2> { impl<Request> BatchLayer<Request> {
/// Creates a new `BatchLayer`. /// Creates a new `BatchLayer`.
/// ///
/// The wrapper is responsible for telling the inner service when to flush a /// The wrapper is responsible for telling the inner service when to flush a
@ -29,28 +28,25 @@ impl<Request, E2> BatchLayer<Request, E2> {
max_items, max_items,
max_latency, max_latency,
_p: PhantomData, _p: PhantomData,
_e: PhantomData,
} }
} }
} }
impl<S, Request, E2> Layer<S> for BatchLayer<Request, E2> impl<S, Request> Layer<S> for BatchLayer<Request>
where where
S: Service<BatchControl<Request>> + Send + 'static, S: Service<BatchControl<Request>> + Send + 'static,
S::Future: Send, S::Future: Send,
S::Error: Clone + Into<E2> + Send + Sync, S::Error: Into<crate::BoxError> + Send + Sync,
Request: Send + 'static, Request: Send + 'static,
E2: Send + 'static,
crate::error::Closed: Into<E2>,
{ {
type Service = Batch<S, Request, E2>; type Service = Batch<S, Request>;
fn layer(&self, service: S) -> Self::Service { fn layer(&self, service: S) -> Self::Service {
Batch::new(service, self.max_items, self.max_latency) Batch::new(service, self.max_items, self.max_latency)
} }
} }
impl<Request, E2> fmt::Debug for BatchLayer<Request, E2> { impl<Request> fmt::Debug for BatchLayer<Request> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("BufferLayer") f.debug_struct("BufferLayer")
.field("max_items", &self.max_items) .field("max_items", &self.max_items)

View File

@ -1,15 +1,16 @@
use super::error::ServiceError;
use tokio::sync::oneshot; use tokio::sync::oneshot;
/// Message sent to the batch worker /// Message sent to the batch worker
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct Message<Request, Fut, E> { pub(crate) struct Message<Request, Fut> {
pub(crate) request: Request, pub(crate) request: Request,
pub(crate) tx: Tx<Fut, E>, pub(crate) tx: Tx<Fut>,
pub(crate) span: tracing::Span, pub(crate) span: tracing::Span,
} }
/// Response sender /// Response sender
pub(crate) type Tx<Fut, E> = oneshot::Sender<Result<Fut, E>>; pub(crate) type Tx<Fut> = oneshot::Sender<Result<Fut, ServiceError>>;
/// Response receiver /// Response receiver
pub(crate) type Rx<Fut, E> = oneshot::Receiver<Result<Fut, E>>; pub(crate) type Rx<Fut> = oneshot::Receiver<Result<Fut, ServiceError>>;

View File

@ -6,10 +6,7 @@ use super::{
}; };
use futures_core::ready; use futures_core::ready;
use std::{ use std::task::{Context, Poll};
marker::PhantomData,
task::{Context, Poll},
};
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use tower::Service; use tower::Service;
@ -17,23 +14,18 @@ use tower::Service;
/// ///
/// See the module documentation for more details. /// See the module documentation for more details.
#[derive(Debug)] #[derive(Debug)]
pub struct Batch<S, Request, E2 = crate::BoxError> pub struct Batch<T, Request>
where where
S: Service<BatchControl<Request>>, T: Service<BatchControl<Request>>,
{ {
tx: mpsc::Sender<Message<Request, S::Future, S::Error>>, tx: mpsc::Sender<Message<Request, T::Future>>,
handle: Handle<S::Error, E2>, handle: Handle,
_e: PhantomData<E2>,
} }
impl<S, Request, E2> Batch<S, Request, E2> impl<T, Request> Batch<T, Request>
where where
S: Service<BatchControl<Request>>, T: Service<BatchControl<Request>>,
S::Error: Into<E2> + Clone, T::Error: Into<crate::BoxError>,
E2: Send + 'static,
crate::error::Closed: Into<E2>,
// crate::error::Closed: Into<<Self as Service<Request>>::Error> + Send + Sync + 'static,
// crate::error::ServiceError: Into<<Self as Service<Request>>::Error> + Send + Sync + 'static,
{ {
/// Creates a new `Batch` wrapping `service`. /// Creates a new `Batch` wrapping `service`.
/// ///
@ -45,39 +37,33 @@ where
/// ///
/// The default Tokio executor is used to run the given service, which means /// The default Tokio executor is used to run the given service, which means
/// that this method must be called while on the Tokio runtime. /// that this method must be called while on the Tokio runtime.
pub fn new(service: S, max_items: usize, max_latency: std::time::Duration) -> Self pub fn new(service: T, max_items: usize, max_latency: std::time::Duration) -> Self
where where
S: Send + 'static, T: Send + 'static,
S::Future: Send, T::Future: Send,
S::Error: Send + Sync + Clone, T::Error: Send + Sync,
Request: Send + 'static, Request: Send + 'static,
{ {
// XXX(hdevalence): is this bound good // XXX(hdevalence): is this bound good
let (tx, rx) = mpsc::channel(1); let (tx, rx) = mpsc::channel(1);
let (handle, worker) = Worker::new(service, rx, max_items, max_latency); let (handle, worker) = Worker::new(service, rx, max_items, max_latency);
tokio::spawn(worker.run()); tokio::spawn(worker.run());
Batch { Batch { tx, handle }
tx,
handle,
_e: PhantomData,
}
} }
fn get_worker_error(&self) -> E2 { fn get_worker_error(&self) -> crate::BoxError {
self.handle.get_error_on_closed() self.handle.get_error_on_closed()
} }
} }
impl<S, Request, E2> Service<Request> for Batch<S, Request, E2> impl<T, Request> Service<Request> for Batch<T, Request>
where where
S: Service<BatchControl<Request>>, T: Service<BatchControl<Request>>,
crate::error::Closed: Into<E2>, T::Error: Into<crate::BoxError>,
S::Error: Into<E2> + Clone,
E2: Send + 'static,
{ {
type Response = S::Response; type Response = T::Response;
type Error = E2; type Error = crate::BoxError;
type Future = ResponseFuture<S, E2, Request>; type Future = ResponseFuture<T::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// If the inner service has errored, then we error here. // If the inner service has errored, then we error here.
@ -119,15 +105,14 @@ where
} }
} }
impl<S, Request> Clone for Batch<S, Request> impl<T, Request> Clone for Batch<T, Request>
where where
S: Service<BatchControl<Request>>, T: Service<BatchControl<Request>>,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
tx: self.tx.clone(), tx: self.tx.clone(),
handle: self.handle.clone(), handle: self.handle.clone(),
_e: PhantomData,
} }
} }
} }

View File

@ -1,14 +1,11 @@
use super::{ use super::{
error::Closed, error::{Closed, ServiceError},
message::{self, Message}, message::{self, Message},
BatchControl, BatchControl,
}; };
use futures::future::TryFutureExt; use futures::future::TryFutureExt;
use pin_project::pin_project; use pin_project::pin_project;
use std::{ use std::sync::{Arc, Mutex};
marker::PhantomData,
sync::{Arc, Mutex},
};
use tokio::{ use tokio::{
stream::StreamExt, stream::StreamExt,
sync::mpsc, sync::mpsc,
@ -26,41 +23,38 @@ use tracing_futures::Instrument;
/// implement (only call). /// implement (only call).
#[pin_project] #[pin_project]
#[derive(Debug)] #[derive(Debug)]
pub struct Worker<S, Request, E2> pub struct Worker<T, Request>
where where
S: Service<BatchControl<Request>>, T: Service<BatchControl<Request>>,
S::Error: Into<E2>, T::Error: Into<crate::BoxError>,
{ {
rx: mpsc::Receiver<Message<Request, S::Future, S::Error>>, rx: mpsc::Receiver<Message<Request, T::Future>>,
service: S, service: T,
failed: Option<S::Error>, failed: Option<ServiceError>,
handle: Handle<S::Error, E2>, handle: Handle,
max_items: usize, max_items: usize,
max_latency: std::time::Duration, max_latency: std::time::Duration,
_e: PhantomData<E2>,
} }
/// Get the error out /// Get the error out
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct Handle<E, E2> { pub(crate) struct Handle {
inner: Arc<Mutex<Option<E>>>, inner: Arc<Mutex<Option<ServiceError>>>,
_e: PhantomData<E2>,
} }
impl<S, Request, E2> Worker<S, Request, E2> impl<T, Request> Worker<T, Request>
where where
S: Service<BatchControl<Request>>, T: Service<BatchControl<Request>>,
S::Error: Into<E2> + Clone, T::Error: Into<crate::BoxError>,
{ {
pub(crate) fn new( pub(crate) fn new(
service: S, service: T,
rx: mpsc::Receiver<Message<Request, S::Future, S::Error>>, rx: mpsc::Receiver<Message<Request, T::Future>>,
max_items: usize, max_items: usize,
max_latency: std::time::Duration, max_latency: std::time::Duration,
) -> (Handle<S::Error, E2>, Worker<S, Request, E2>) { ) -> (Handle, Worker<T, Request>) {
let handle = Handle { let handle = Handle {
inner: Arc::new(Mutex::new(None)), inner: Arc::new(Mutex::new(None)),
_e: PhantomData,
}; };
let worker = Worker { let worker = Worker {
@ -70,16 +64,15 @@ where
failed: None, failed: None,
max_items, max_items,
max_latency, max_latency,
_e: PhantomData,
}; };
(handle, worker) (handle, worker)
} }
async fn process_req(&mut self, req: Request, tx: message::Tx<S::Future, S::Error>) { async fn process_req(&mut self, req: Request, tx: message::Tx<T::Future>) {
if let Some(failed) = self.failed.clone() { if let Some(ref failed) = self.failed {
tracing::trace!("notifying caller about worker failure"); tracing::trace!("notifying caller about worker failure");
let _ = tx.send(Err(failed)); let _ = tx.send(Err(failed.clone()));
} else { } else {
match self.service.ready_and().await { match self.service.ready_and().await {
Ok(svc) => { Ok(svc) => {
@ -87,11 +80,12 @@ where
let _ = tx.send(Ok(rsp)); let _ = tx.send(Ok(rsp));
} }
Err(e) => { Err(e) => {
self.failed(e); self.failed(e.into());
let _ = tx.send(Err(self let _ = tx.send(Err(self
.failed .failed
.clone() .as_ref()
.expect("Worker::failed did not set self.failed?"))); .expect("Worker::failed did not set self.failed?")
.clone()));
} }
} }
} }
@ -104,7 +98,7 @@ where
.and_then(|svc| svc.call(BatchControl::Flush)) .and_then(|svc| svc.call(BatchControl::Flush))
.await .await
{ {
self.failed(e); self.failed(e.into());
} }
} }
@ -171,12 +165,11 @@ where
} }
} }
fn failed(&mut self, error: S::Error) { fn failed(&mut self, error: crate::BoxError) {
// The underlying service failed when we called `poll_ready` on it with // The underlying service failed when we called `poll_ready` on it with the given `error`. We
// the given `error`. We need to communicate this to all the `Buffer` // need to communicate this to all the `Buffer` handles. To do so, we wrap up the error in
// handles. To do so, we require that `S::Error` implements `Clone`, // an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
// clone the error to send to all pending requests, and store it so that // requests will also fail with the same error.
// subsequent requests will also fail with the same error.
// Note that we need to handle the case where some handle is concurrently trying to send us // Note that we need to handle the case where some handle is concurrently trying to send us
// a request. We need to make sure that *either* the send of the request fails *or* it // a request. We need to make sure that *either* the send of the request fails *or* it
@ -185,6 +178,7 @@ where
// request. We do this by *first* exposing the error, *then* closing the channel used to // request. We do this by *first* exposing the error, *then* closing the channel used to
// send more requests (so the client will see the error when the send fails), and *then* // send more requests (so the client will see the error when the send fails), and *then*
// sending the error to all outstanding requests. // sending the error to all outstanding requests.
let error = ServiceError::new(error);
let mut inner = self.handle.inner.lock().unwrap(); let mut inner = self.handle.inner.lock().unwrap();
@ -206,26 +200,21 @@ where
} }
} }
impl<E, E2> Handle<E, E2> impl Handle {
where pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
E: Clone + Into<E2>,
crate::error::Closed: Into<E2>,
{
pub(crate) fn get_error_on_closed(&self) -> E2 {
self.inner self.inner
.lock() .lock()
.unwrap() .unwrap()
.clone() .as_ref()
.map(Into::into) .map(|svc_err| svc_err.clone().into())
.unwrap_or_else(|| Closed::new().into()) .unwrap_or_else(|| Closed::new().into())
} }
} }
impl<E, E2> Clone for Handle<E, E2> { impl Clone for Handle {
fn clone(&self) -> Handle<E, E2> { fn clone(&self) -> Handle {
Handle { Handle {
inner: self.inner.clone(), inner: self.inner.clone(),
_e: PhantomData,
} }
} }
} }

View File

@ -6,7 +6,6 @@ use std::{
time::Duration, time::Duration,
}; };
use color_eyre::eyre::Result;
use ed25519_zebra::*; use ed25519_zebra::*;
use futures::stream::{FuturesUnordered, StreamExt}; use futures::stream::{FuturesUnordered, StreamExt};
use rand::thread_rng; use rand::thread_rng;
@ -109,23 +108,31 @@ where
} }
#[tokio::test] #[tokio::test]
async fn batch_flushes_on_max_items() -> Result<()> { async fn batch_flushes_on_max_items() {
use tokio::time::timeout; use tokio::time::timeout;
zebra_test::init(); zebra_test::init();
// Use a very long max_latency and a short timeout to check that // Use a very long max_latency and a short timeout to check that
// flushing is happening based on hitting max_items. // flushing is happening based on hitting max_items.
let verifier = Batch::new(Ed25519Verifier::new(), 10, Duration::from_secs(1000)); let verifier = Batch::new(Ed25519Verifier::new(), 10, Duration::from_secs(1000));
timeout(Duration::from_secs(1), sign_and_verify(verifier, 100)).await? assert!(
timeout(Duration::from_secs(1), sign_and_verify(verifier, 100))
.await
.is_ok()
);
} }
#[tokio::test] #[tokio::test]
async fn batch_flushes_on_max_latency() -> Result<()> { async fn batch_flushes_on_max_latency() {
use tokio::time::timeout; use tokio::time::timeout;
zebra_test::init(); zebra_test::init();
// Use a very high max_items and a short timeout to check that // Use a very high max_items and a short timeout to check that
// flushing is happening based on hitting max_latency. // flushing is happening based on hitting max_latency.
let verifier = Batch::new(Ed25519Verifier::new(), 100, Duration::from_millis(500)); let verifier = Batch::new(Ed25519Verifier::new(), 100, Duration::from_millis(500));
timeout(Duration::from_secs(1), sign_and_verify(verifier, 10)).await? assert!(
timeout(Duration::from_secs(1), sign_and_verify(verifier, 10))
.await
.is_ok()
);
} }