make return error type for Batch generic

Jane Lusby 2020-06-17 14:39:10 -07:00 committed by Deirdre Connolly
parent 6cc1627a5d
commit 63ae085945
9 changed files with 129 additions and 118 deletions
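The commit threads a caller-chosen error type `E` through `Batch`, `BatchLayer`, `ResponseFuture`, the message channels, and the `Worker`, keeping `crate::BoxError` as the default so existing callers compile unchanged; the new bounds `T::Error: Into<E>` and `crate::error::Closed: Into<E>` let both the inner service's error and the worker-closed error convert into the chosen type. The following is a minimal, self-contained sketch of that pattern, not part of the commit; `Wrapper`, `Closed`, and `CallerError` are illustrative stand-ins, not the crate's API.

use std::error::Error;
use std::fmt;

type BoxError = Box<dyn Error + Send + Sync>;

// Stand-in for `error::Closed`: reported when the batch worker has gone away.
#[derive(Debug)]
struct Closed;

impl fmt::Display for Closed {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "batch worker closed the channel")
    }
}

impl Error for Closed {}

// Before the change the wrapper always produced `BoxError`; afterwards the
// caller picks `E`, with `BoxError` kept as the default type parameter.
struct Wrapper<E = BoxError> {
    failed: Option<E>,
}

impl<E> Wrapper<E>
where
    Closed: Into<E>, // same shape as the bound the diff adds to `Handle<E>`
{
    fn get_error_on_closed(&mut self) -> E {
        self.failed.take().unwrap_or_else(|| Closed.into())
    }
}

// A caller-defined error type, playing the role `color_eyre::Report` plays in
// the tests at the bottom of this diff.
#[derive(Debug)]
enum CallerError {
    WorkerClosed,
}

impl From<Closed> for CallerError {
    fn from(_: Closed) -> Self {
        CallerError::WorkerClosed
    }
}

fn main() {
    // Default parameter: behaves like the old, hard-coded `BoxError` API.
    let mut boxed: Wrapper = Wrapper { failed: None };
    let _: BoxError = boxed.get_error_on_closed();

    // Caller-chosen error type, which is what this commit makes possible.
    let mut custom: Wrapper<CallerError> = Wrapper { failed: None };
    let _: CallerError = custom.get_error_on_closed();
}

In the diff itself the generic parameter also forces `T::Error: Clone`, because the worker now hands the original error to every pending request directly instead of wrapping it in the `Arc`-based `ServiceError` that this commit deletes.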

Cargo.lock (generated)

@@ -1927,6 +1927,7 @@ dependencies = [
name = "tower-batch"
version = "0.1.0"
dependencies = [
"color-eyre",
"ed25519-zebra",
"futures",
"futures-core",

tower-batch/Cargo.toml

@@ -21,3 +21,4 @@ tokio = { version = "0.2", features = ["full"]}
tracing-error = "0.1.2"
tracing-subscriber = "0.2.5"
tracing = "0.1.15"
color-eyre = "0.3.4"

tower-batch/src/error.rs

@@ -1,47 +1,12 @@
//! Error types for the `Batch` middleware.
use crate::BoxError;
use std::{fmt, sync::Arc};
/// An error produced by a `Service` wrapped by a `Batch`.
#[derive(Debug)]
pub struct ServiceError {
inner: Arc<BoxError>,
}
use std::fmt;
/// An error produced when the batch worker closes unexpectedly.
pub struct Closed {
_p: (),
}
// ===== impl ServiceError =====
impl ServiceError {
pub(crate) fn new(inner: BoxError) -> ServiceError {
let inner = Arc::new(inner);
ServiceError { inner }
}
// Private to avoid exposing `Clone` trait as part of the public API
pub(crate) fn clone(&self) -> ServiceError {
ServiceError {
inner: self.inner.clone(),
}
}
}
impl fmt::Display for ServiceError {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, "batching service failed: {}", self.inner)
}
}
impl std::error::Error for ServiceError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(&**self.inner)
}
}
// ===== impl Closed =====
impl Closed {

tower-batch/src/future.rs

@@ -4,47 +4,68 @@ use super::{error::Closed, message};
use futures_core::ready;
use pin_project::pin_project;
use std::{
fmt::Debug,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::Service;
/// Future that completes when the batch processing is complete.
#[pin_project]
#[derive(Debug)]
pub struct ResponseFuture<T> {
pub struct ResponseFuture<T, E, R>
where
T: Service<crate::BatchControl<R>>,
{
#[pin]
state: ResponseState<T>,
state: ResponseState<T, E, R>,
}
#[pin_project(project = ResponseStateProj)]
#[derive(Debug)]
enum ResponseState<T> {
Failed(Option<crate::BoxError>),
Rx(#[pin] message::Rx<T>),
Poll(#[pin] T),
enum ResponseState<T, E, R>
where
T: Service<crate::BatchControl<R>>,
{
Failed(Option<E>),
Rx(#[pin] message::Rx<T::Future, T::Error>),
Poll(#[pin] T::Future),
}
impl<T> ResponseFuture<T> {
pub(crate) fn new(rx: message::Rx<T>) -> Self {
impl<T, E, R> Debug for ResponseState<T, E, R>
where
T: Service<crate::BatchControl<R>>,
{
fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
todo!()
}
}
impl<T, E, R> ResponseFuture<T, E, R>
where
T: Service<crate::BatchControl<R>>,
{
pub(crate) fn new(rx: message::Rx<T::Future, T::Error>) -> Self {
ResponseFuture {
state: ResponseState::Rx(rx),
}
}
pub(crate) fn failed(err: crate::BoxError) -> Self {
pub(crate) fn failed(err: E) -> Self {
ResponseFuture {
state: ResponseState::Failed(Some(err)),
}
}
}
impl<F, T, E> Future for ResponseFuture<F>
impl<S, E2, R> Future for ResponseFuture<S, E2, R>
where
F: Future<Output = Result<T, E>>,
E: Into<crate::BoxError>,
S: Service<crate::BatchControl<R>>,
S::Future: Future<Output = Result<S::Response, S::Error>>,
S::Error: Into<E2>,
crate::error::Closed: Into<E2>,
{
type Output = Result<T, crate::BoxError>;
type Output = Result<S::Response, E2>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();

tower-batch/src/layer.rs

@@ -9,13 +9,14 @@ use tower::Service;
/// which means that this layer can only be used on the Tokio runtime.
///
/// See the module documentation for more details.
pub struct BatchLayer<Request> {
pub struct BatchLayer<Request, E> {
max_items: usize,
max_latency: std::time::Duration,
_p: PhantomData<fn(Request)>,
_e: PhantomData<E>,
}
impl<Request> BatchLayer<Request> {
impl<Request, E> BatchLayer<Request, E> {
/// Creates a new `BatchLayer`.
///
/// The wrapper is responsible for telling the inner service when to flush a
@@ -28,25 +29,28 @@ impl<Request> BatchLayer<Request> {
max_items,
max_latency,
_p: PhantomData,
_e: PhantomData,
}
}
}
impl<S, Request> Layer<S> for BatchLayer<Request>
impl<S, Request, E> Layer<S> for BatchLayer<Request, E>
where
S: Service<BatchControl<Request>> + Send + 'static,
S::Future: Send,
S::Error: Into<crate::BoxError> + Send + Sync,
S::Error: Clone + Into<E> + Send + Sync,
Request: Send + 'static,
E: Clone + Send + 'static,
crate::error::Closed: Into<E>,
{
type Service = Batch<S, Request>;
type Service = Batch<S, Request, E>;
fn layer(&self, service: S) -> Self::Service {
Batch::new(service, self.max_items, self.max_latency)
}
}
impl<Request> fmt::Debug for BatchLayer<Request> {
impl<Request, E> fmt::Debug for BatchLayer<Request, E> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("BufferLayer")
.field("max_items", &self.max_items)

tower-batch/src/message.rs

@@ -1,16 +1,15 @@
use super::error::ServiceError;
use tokio::sync::oneshot;
/// Message sent to the batch worker
#[derive(Debug)]
pub(crate) struct Message<Request, Fut> {
pub(crate) struct Message<Request, Fut, E> {
pub(crate) request: Request,
pub(crate) tx: Tx<Fut>,
pub(crate) tx: Tx<Fut, E>,
pub(crate) span: tracing::Span,
}
/// Response sender
pub(crate) type Tx<Fut> = oneshot::Sender<Result<Fut, ServiceError>>;
pub(crate) type Tx<Fut, E> = oneshot::Sender<Result<Fut, E>>;
/// Response receiver
pub(crate) type Rx<Fut> = oneshot::Receiver<Result<Fut, ServiceError>>;
pub(crate) type Rx<Fut, E> = oneshot::Receiver<Result<Fut, E>>;

tower-batch/src/service.rs

@@ -6,7 +6,10 @@ use super::{
};
use futures_core::ready;
use std::task::{Context, Poll};
use std::{
marker::PhantomData,
task::{Context, Poll},
};
use tokio::sync::{mpsc, oneshot};
use tower::Service;
@@ -14,18 +17,23 @@ use tower::Service;
///
/// See the module documentation for more details.
#[derive(Debug)]
pub struct Batch<T, Request>
pub struct Batch<T, Request, E = crate::BoxError>
where
T: Service<BatchControl<Request>>,
{
tx: mpsc::Sender<Message<Request, T::Future>>,
handle: Handle,
tx: mpsc::Sender<Message<Request, T::Future, T::Error>>,
handle: Handle<E>,
_error_type: PhantomData<E>,
}
impl<T, Request> Batch<T, Request>
impl<T, Request, E> Batch<T, Request, E>
where
T: Service<BatchControl<Request>>,
T::Error: Into<crate::BoxError>,
T::Error: Into<E>,
E: Send + 'static,
crate::error::Closed: Into<E>,
// crate::error::Closed: Into<<Self as Service<Request>>::Error> + Send + Sync + 'static,
// crate::error::ServiceError: Into<<Self as Service<Request>>::Error> + Send + Sync + 'static,
{
/// Creates a new `Batch` wrapping `service`.
///
@@ -41,29 +49,35 @@ where
where
T: Send + 'static,
T::Future: Send,
T::Error: Send + Sync,
T::Error: Send + Sync + Clone,
Request: Send + 'static,
{
// XXX(hdevalence): is this bound good
let (tx, rx) = mpsc::channel(1);
let (handle, worker) = Worker::new(service, rx, max_items, max_latency);
tokio::spawn(worker.run());
Batch { tx, handle }
Batch {
tx,
handle,
_error_type: PhantomData,
}
}
fn get_worker_error(&self) -> crate::BoxError {
fn get_worker_error(&self) -> E {
self.handle.get_error_on_closed()
}
}
impl<T, Request> Service<Request> for Batch<T, Request>
impl<T, Request, E> Service<Request> for Batch<T, Request, E>
where
T: Service<BatchControl<Request>>,
T::Error: Into<crate::BoxError>,
crate::error::Closed: Into<E>,
T::Error: Into<E>,
E: Send + 'static,
{
type Response = T::Response;
type Error = crate::BoxError;
type Future = ResponseFuture<T::Future>;
type Error = E;
type Future = ResponseFuture<T, E, Request>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// If the inner service has errored, then we error here.
@@ -113,6 +127,7 @@ where
Self {
tx: self.tx.clone(),
handle: self.handle.clone(),
_error_type: PhantomData,
}
}
}

tower-batch/src/worker.rs

@@ -1,11 +1,14 @@
use super::{
error::{Closed, ServiceError},
error::Closed,
message::{self, Message},
BatchControl,
};
use futures::future::TryFutureExt;
use pin_project::pin_project;
use std::sync::{Arc, Mutex};
use std::{
marker::PhantomData,
sync::{Arc, Mutex},
};
use tokio::{
stream::StreamExt,
sync::mpsc,
@@ -23,36 +26,37 @@ use tracing_futures::Instrument;
/// implement (only call).
#[pin_project]
#[derive(Debug)]
pub struct Worker<T, Request>
pub struct Worker<T, Request, E>
where
T: Service<BatchControl<Request>>,
T::Error: Into<crate::BoxError>,
T::Error: Into<E>,
{
rx: mpsc::Receiver<Message<Request, T::Future>>,
rx: mpsc::Receiver<Message<Request, T::Future, T::Error>>,
service: T,
failed: Option<ServiceError>,
handle: Handle,
failed: Option<T::Error>,
handle: Handle<E>,
max_items: usize,
max_latency: std::time::Duration,
_error_type: PhantomData<E>,
}
/// Get the error out
#[derive(Debug)]
pub(crate) struct Handle {
inner: Arc<Mutex<Option<ServiceError>>>,
pub(crate) struct Handle<E> {
inner: Arc<Mutex<Option<E>>>,
}
impl<T, Request> Worker<T, Request>
impl<T, Request, E> Worker<T, Request, E>
where
T: Service<BatchControl<Request>>,
T::Error: Into<crate::BoxError>,
T::Error: Into<E> + Clone,
{
pub(crate) fn new(
service: T,
rx: mpsc::Receiver<Message<Request, T::Future>>,
rx: mpsc::Receiver<Message<Request, T::Future, T::Error>>,
max_items: usize,
max_latency: std::time::Duration,
) -> (Handle, Worker<T, Request>) {
) -> (Handle<E>, Worker<T, Request, E>) {
let handle = Handle {
inner: Arc::new(Mutex::new(None)),
};
@@ -64,15 +68,16 @@ where
failed: None,
max_items,
max_latency,
_error_type: PhantomData,
};
(handle, worker)
}
async fn process_req(&mut self, req: Request, tx: message::Tx<T::Future>) {
if let Some(ref failed) = self.failed {
async fn process_req(&mut self, req: Request, tx: message::Tx<T::Future, T::Error>) {
if let Some(failed) = self.failed.clone() {
tracing::trace!("notifying caller about worker failure");
let _ = tx.send(Err(failed.clone()));
let _ = tx.send(Err(failed));
} else {
match self.service.ready_and().await {
Ok(svc) => {
@@ -80,12 +85,11 @@ where
let _ = tx.send(Ok(rsp));
}
Err(e) => {
self.failed(e.into());
self.failed(e);
let _ = tx.send(Err(self
.failed
.as_ref()
.expect("Worker::failed did not set self.failed?")
.clone()));
.clone()
.expect("Worker::failed did not set self.failed?")));
}
}
}
@@ -98,7 +102,7 @@ where
.and_then(|svc| svc.call(BatchControl::Flush))
.await
{
self.failed(e.into());
self.failed(e);
}
}
@@ -165,7 +169,7 @@ where
}
}
fn failed(&mut self, error: crate::BoxError) {
fn failed(&mut self, error: T::Error) {
// The underlying service failed when we called `poll_ready` on it with the given `error`. We
// need to communicate this to all the `Buffer` handles. To do so, we wrap up the error in
// an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
@@ -178,7 +182,6 @@ where
// request. We do this by *first* exposing the error, *then* closing the channel used to
// send more requests (so the client will see the error when the send fails), and *then*
// sending the error to all outstanding requests.
let error = ServiceError::new(error);
let mut inner = self.handle.inner.lock().unwrap();
@@ -187,7 +190,7 @@ where
return;
}
*inner = Some(error.clone());
*inner = Some(error.clone().into());
drop(inner);
self.rx.close();
@@ -200,19 +203,21 @@ where
}
}
impl Handle {
pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
impl<E> Handle<E>
where
crate::error::Closed: Into<E>,
{
pub(crate) fn get_error_on_closed(&self) -> E {
self.inner
.lock()
.unwrap()
.as_ref()
.map(|svc_err| svc_err.clone().into())
.take()
.unwrap_or_else(|| Closed::new().into())
}
}
impl Clone for Handle {
fn clone(&self) -> Handle {
impl<E> Clone for Handle<E> {
fn clone(&self) -> Handle<E> {
Handle {
inner: self.inner.clone(),
}

tower-batch/tests/ed25519.rs

@@ -131,31 +131,31 @@ where
}
#[tokio::test]
async fn batch_flushes_on_max_items() {
async fn batch_flushes_on_max_items() -> color_eyre::Result<()> {
use tokio::time::timeout;
install_tracing();
// Use a very long max_latency and a short timeout to check that
// flushing is happening based on hitting max_items.
let verifier = Batch::new(Ed25519Verifier::new(), 10, Duration::from_secs(1000));
assert!(
timeout(Duration::from_secs(1), sign_and_verify(verifier, 100))
.await
.is_ok()
)
let verifier = Batch::<_, _, color_eyre::Report>::new(
Ed25519Verifier::new(),
10,
Duration::from_secs(1000),
);
Ok(timeout(Duration::from_secs(1), sign_and_verify(verifier, 100)).await?)
}
#[tokio::test]
async fn batch_flushes_on_max_latency() {
async fn batch_flushes_on_max_latency() -> color_eyre::Result<()> {
use tokio::time::timeout;
install_tracing();
// Use a very high max_items and a short timeout to check that
// flushing is happening based on hitting max_latency.
let verifier = Batch::new(Ed25519Verifier::new(), 100, Duration::from_millis(500));
assert!(
timeout(Duration::from_secs(1), sign_and_verify(verifier, 10))
.await
.is_ok()
)
let verifier = Batch::<_, _, color_eyre::Report>::new(
Ed25519Verifier::new(),
100,
Duration::from_millis(500),
);
Ok(timeout(Duration::from_secs(1), sign_and_verify(verifier, 10)).await?)
}
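
Both tests above now return `color_eyre::Result<()>` and propagate failure with `?`, so a timed-out run reports the underlying error rather than panicking through a bare `assert!`. The following is a rough, generic sketch of that test shape, not part of the commit, with `Box<dyn Error>` standing in for `color_eyre::Report` and a hypothetical `do_work` future standing in for `sign_and_verify`:

use std::time::Duration;
use tokio::time::timeout;

// Hypothetical stand-in for the sign_and_verify helper used by the real tests.
async fn do_work() {}

#[tokio::test]
async fn finishes_within_a_second() -> Result<(), Box<dyn std::error::Error>> {
    // On timeout, `?` converts tokio's Elapsed error into the boxed error.
    timeout(Duration::from_secs(1), do_work()).await?;
    Ok(())
}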