tower-batch: copy tower-buffer source code.

There's a lot of functional overlap between the batch design and tower-buffer's
existing internals, so we'll just vendor its source code and modify it.
If/when we upstream it, we can deduplicate common components.
Author: Henry de Valence
Date: 2020-06-12 10:22:08 -07:00
Parent: dab3eeca3c
Commit: dcd3f7bb2d
10 changed files with 612 additions and 0 deletions

Cargo.lock (generated)

@@ -1923,6 +1923,17 @@ dependencies = [
"tower-util",
]
[[package]]
name = "tower-batch"
version = "0.1.0"
dependencies = [
"futures-core",
"pin-project",
"tokio",
"tower",
"tracing",
]
[[package]]
name = "tower-buffer"
version = "0.3.0"

Cargo.toml (workspace)

@@ -1,5 +1,6 @@
[workspace]
members = [
"tower-batch",
"zebra-chain",
"zebra-network",
"zebra-state",

tower-batch/Cargo.toml (new file)

@@ -0,0 +1,13 @@
[package]
name = "tower-batch"
version = "0.1.0"
authors = ["Zcash Foundation <zebra@zfnd.org>"]
license = "MIT"
edition = "2018"
[dependencies]
tokio = { version = "0.2", features = ["time"] }
tower = "0.3"
futures-core = "0.3.5"
pin-project = "0.4.20"
tracing = "0.1.15"

tower-batch/src/error.rs (new file)

@@ -0,0 +1,65 @@
//! Error types for the `Buffer` middleware.
use crate::BoxError;
use std::{fmt, sync::Arc};
/// An error produced by a `Service` wrapped by a `Buffer`
#[derive(Debug)]
pub struct ServiceError {
inner: Arc<BoxError>,
}
/// An error produced when a buffer's worker closes unexpectedly.
pub struct Closed {
_p: (),
}
// ===== impl ServiceError =====
impl ServiceError {
pub(crate) fn new(inner: BoxError) -> ServiceError {
let inner = Arc::new(inner);
ServiceError { inner }
}
// Private to avoid exposing `Clone` trait as part of the public API
pub(crate) fn clone(&self) -> ServiceError {
ServiceError {
inner: self.inner.clone(),
}
}
}
impl fmt::Display for ServiceError {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, "buffered service failed: {}", self.inner)
}
}
impl std::error::Error for ServiceError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(&**self.inner)
}
}
// ===== impl Closed =====
impl Closed {
pub(crate) fn new() -> Self {
Closed { _p: () }
}
}
impl fmt::Debug for Closed {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_tuple("Closed").finish()
}
}
impl fmt::Display for Closed {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.write_str("buffer's worker closed unexpectedly")
}
}
impl std::error::Error for Closed {}
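
Both of these error types reach callers only behind the type-erased `Box<dyn Error + Send + Sync>` that `Buffer` returns, so code that wants to tell them apart has to downcast. A small editorial sketch of that pattern (not part of this commit; the `describe` helper is hypothetical):

use std::error::Error;
use tower_batch::error::{Closed, ServiceError};

// The boxed error type `Buffer` hands back. The crate keeps a private
// `BoxError` alias for this; it is spelled out here.
type BoxError = Box<dyn Error + Send + Sync + 'static>;

fn describe(err: &BoxError) -> String {
    if err.is::<Closed>() {
        // The worker task was dropped or panicked before answering.
        "the buffer worker shut down".to_string()
    } else if let Some(svc_err) = err.downcast_ref::<ServiceError>() {
        // The wrapped service failed; every queued and future caller gets a
        // clone of this error, all sharing one Arc'd cause.
        format!("worker-reported failure: {}", svc_err)
    } else {
        // Otherwise it is the inner service's own error, forwarded through `call`.
        format!("inner service error: {}", err)
    }
}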

tower-batch/src/future.rs (new file)

@@ -0,0 +1,68 @@
//! Future types for the `Buffer` middleware.
use super::{error::Closed, message};
use futures_core::ready;
use pin_project::{pin_project, project};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
/// Future that completes when the buffered service eventually services the submitted request.
#[pin_project]
#[derive(Debug)]
pub struct ResponseFuture<T> {
#[pin]
state: ResponseState<T>,
}
#[pin_project]
#[derive(Debug)]
enum ResponseState<T> {
Failed(Option<crate::BoxError>),
Rx(#[pin] message::Rx<T>),
Poll(#[pin] T),
}
impl<T> ResponseFuture<T> {
pub(crate) fn new(rx: message::Rx<T>) -> Self {
ResponseFuture {
state: ResponseState::Rx(rx),
}
}
pub(crate) fn failed(err: crate::BoxError) -> Self {
ResponseFuture {
state: ResponseState::Failed(Some(err)),
}
}
}
impl<F, T, E> Future for ResponseFuture<F>
where
F: Future<Output = Result<T, E>>,
E: Into<crate::BoxError>,
{
type Output = Result<T, crate::BoxError>;
#[project]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
loop {
#[project]
match this.state.as_mut().project() {
ResponseState::Failed(e) => {
return Poll::Ready(Err(e.take().expect("polled after error")));
}
ResponseState::Rx(rx) => match ready!(rx.poll(cx)) {
Ok(Ok(f)) => this.state.set(ResponseState::Poll(f)),
Ok(Err(e)) => return Poll::Ready(Err(e.into())),
Err(_) => return Poll::Ready(Err(Closed::new().into())),
},
ResponseState::Poll(fut) => return fut.poll(cx).map_err(Into::into),
}
}
}
}
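
The three states above mirror a simple two-stage await: the oneshot channel first delivers the inner service's response future (the `Rx` state), and that future is then polled to completion (the `Poll` state), with `Failed` short-circuiting both. A stripped-down editorial illustration of the pattern, not taken from this commit, assuming tokio's rt and macros features:

use tokio::sync::oneshot;

#[tokio::main]
async fn main() {
    let (tx, rx) = oneshot::channel();

    // Stand-in for the worker: run the service and send back the service's
    // response future (here just an async block) over the oneshot.
    tokio::spawn(async move {
        let response_future = async { 21u64 * 2 };
        let _ = tx.send(response_future);
    });

    // Stand-in for `ResponseFuture`: first the `Rx` state...
    let inner_future = rx.await.expect("worker dropped the sender");
    // ...then the `Poll` state.
    let response = inner_future.await;
    assert_eq!(response, 42);
}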

tower-batch/src/layer.rs (new file)

@@ -0,0 +1,60 @@
use super::service::Buffer;
use std::{fmt, marker::PhantomData};
// `tower` re-exports these traits; this crate does not depend on
// tower-layer or tower-service directly (see Cargo.toml above).
use tower::{layer::Layer, Service};
/// Adds an mpsc buffer in front of an inner service.
///
/// The default Tokio executor is used to run the given service,
/// which means that this layer can only be used on the Tokio runtime.
///
/// See the module documentation for more details.
pub struct BufferLayer<Request> {
bound: usize,
_p: PhantomData<fn(Request)>,
}
impl<Request> BufferLayer<Request> {
/// Creates a new `BufferLayer` with the provided `bound`.
///
/// `bound` gives the maximal number of requests that can be queued for the service before
/// backpressure is applied to callers.
///
/// # A note on choosing a `bound`
///
/// When `Buffer`'s implementation of `poll_ready` returns `Poll::Ready`, it reserves a
/// slot in the channel for the forthcoming `call()`. However, if this call doesn't arrive,
/// this reserved slot may be held up for a long time. As a result, it's advisable to set
/// `bound` to be at least the maximum number of concurrent requests the `Buffer` will see.
/// If you do not, all the slots in the buffer may be held up by futures that have just called
/// `poll_ready` but will not issue a `call`, which prevents other senders from issuing new
/// requests.
pub fn new(bound: usize) -> Self {
BufferLayer {
bound,
_p: PhantomData,
}
}
}
impl<S, Request> Layer<S> for BufferLayer<Request>
where
S: Service<Request> + Send + 'static,
S::Future: Send,
S::Error: Into<crate::BoxError> + Send + Sync,
Request: Send + 'static,
{
type Service = Buffer<S, Request>;
fn layer(&self, service: S) -> Self::Service {
Buffer::new(service, self.bound)
}
}
impl<Request> fmt::Debug for BufferLayer<Request> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("BufferLayer")
.field("bound", &self.bound)
.finish()
}
}
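
Not part of this commit, but for orientation: a hedged sketch of how `BufferLayer` slots into `tower::ServiceBuilder`. It assumes tokio's rt and macros features and tower's `util` module (for `service_fn` and `ServiceExt::oneshot`), none of which this crate enables itself; the greeting service is a placeholder.

use tower::util::{service_fn, ServiceExt};
use tower::ServiceBuilder;
use tower_batch::BufferLayer;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let svc = ServiceBuilder::new()
        // Queue up to 64 requests; a larger bound tolerates more callers
        // that call `poll_ready` but never follow up with `call`.
        .layer(BufferLayer::new(64))
        .service(service_fn(|name: &'static str| async move {
            Ok::<_, std::convert::Infallible>(format!("hello, {}", name))
        }));

    // `oneshot` drives `poll_ready` and then `call` for us.
    let greeting = svc.oneshot("batch").await?;
    assert_eq!(greeting, "hello, batch");
    Ok(())
}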

tower-batch/src/lib.rs (new file)

@@ -0,0 +1,11 @@
pub mod error;
pub mod future;
mod layer;
mod message;
mod service;
mod worker;
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
pub use self::layer::BufferLayer;
pub use self::service::Buffer;

tower-batch/src/message.rs (new file)

@@ -0,0 +1,16 @@
use super::error::ServiceError;
use tokio::sync::oneshot;
/// Message sent over buffer
#[derive(Debug)]
pub(crate) struct Message<Request, Fut> {
pub(crate) request: Request,
pub(crate) tx: Tx<Fut>,
pub(crate) span: tracing::Span,
}
/// Response sender
pub(crate) type Tx<Fut> = oneshot::Sender<Result<Fut, ServiceError>>;
/// Response receiver
pub(crate) type Rx<Fut> = oneshot::Receiver<Result<Fut, ServiceError>>;

tower-batch/src/service.rs (new file)

@@ -0,0 +1,139 @@
use super::{
future::ResponseFuture,
message::Message,
worker::{Handle, Worker},
};
use futures_core::ready;
use std::task::{Context, Poll};
use tokio::sync::{mpsc, oneshot};
use tower::Service;
/// Adds an mpsc buffer in front of an inner service.
///
/// See the module documentation for more details.
#[derive(Debug)]
pub struct Buffer<T, Request>
where
T: Service<Request>,
{
tx: mpsc::Sender<Message<Request, T::Future>>,
handle: Handle,
}
impl<T, Request> Buffer<T, Request>
where
T: Service<Request>,
T::Error: Into<crate::BoxError>,
{
/// Creates a new `Buffer` wrapping `service`.
///
/// `bound` gives the maximal number of requests that can be queued for the service before
/// backpressure is applied to callers.
///
/// The default Tokio executor is used to run the given service, which means that this method
/// must be called while on the Tokio runtime.
///
/// # A note on choosing a `bound`
///
/// When `Buffer`'s implementation of `poll_ready` returns `Poll::Ready`, it reserves a
/// slot in the channel for the forthcoming `call()`. However, if this call doesn't arrive,
/// this reserved slot may be held up for a long time. As a result, it's advisable to set
/// `bound` to be at least the maximum number of concurrent requests the `Buffer` will see.
/// If you do not, all the slots in the buffer may be held up by futures that have just called
/// `poll_ready` but will not issue a `call`, which prevents other senders from issuing new
/// requests.
pub fn new(service: T, bound: usize) -> Self
where
T: Send + 'static,
T::Future: Send,
T::Error: Send + Sync,
Request: Send + 'static,
{
let (tx, rx) = mpsc::channel(bound);
let (handle, worker) = Worker::new(service, rx);
tokio::spawn(worker);
Buffer { tx, handle }
}
/// Creates a new `Buffer` wrapping `service`, but returns the background worker.
///
/// This is useful if you do not want to spawn directly onto the `tokio` runtime
/// but instead want to use your own executor. This will return the `Buffer` and
/// the background `Worker` that you can then spawn.
pub fn pair(service: T, bound: usize) -> (Buffer<T, Request>, Worker<T, Request>)
where
T: Send + 'static,
T::Error: Send + Sync,
Request: Send + 'static,
{
let (tx, rx) = mpsc::channel(bound);
let (handle, worker) = Worker::new(service, rx);
(Buffer { tx, handle }, worker)
}
fn get_worker_error(&self) -> crate::BoxError {
self.handle.get_error_on_closed()
}
}
impl<T, Request> Service<Request> for Buffer<T, Request>
where
T: Service<Request>,
T::Error: Into<crate::BoxError>,
{
type Response = T::Response;
type Error = crate::BoxError;
type Future = ResponseFuture<T::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// If the inner service has errored, then we error here.
if let Err(_) = ready!(self.tx.poll_ready(cx)) {
Poll::Ready(Err(self.get_worker_error()))
} else {
Poll::Ready(Ok(()))
}
}
fn call(&mut self, request: Request) -> Self::Future {
// TODO:
// ideally we'd poll_ready again here so we don't allocate the oneshot
// if the try_send is about to fail, but sadly we can't call poll_ready
// outside of task context.
let (tx, rx) = oneshot::channel();
// get the current Span so that we can explicitly propagate it to the worker
// if we didn't do this, events on the worker related to this span wouldn't be counted
// towards that span since the worker would have no way of entering it.
let span = tracing::Span::current();
tracing::trace!(parent: &span, "sending request to buffer worker");
match self.tx.try_send(Message { request, span, tx }) {
Err(mpsc::error::TrySendError::Closed(_)) => {
ResponseFuture::failed(self.get_worker_error())
}
Err(mpsc::error::TrySendError::Full(_)) => {
// When `mpsc::Sender::poll_ready` returns `Ready`, a slot
// in the channel is reserved for the handle. Other `Sender`
// handles may not send a message using that slot. This
// guarantees capacity for `request`.
//
// Given this, the only way to hit this code path is if
// `call` was invoked without `poll_ready` first returning `Poll::Ready`.
panic!("buffer full; poll_ready must be called first");
}
Ok(_) => ResponseFuture::new(rx),
}
}
}
impl<T, Request> Clone for Buffer<T, Request>
where
T: Service<Request>,
{
fn clone(&self) -> Self {
Self {
tx: self.tx.clone(),
handle: self.handle.clone(),
}
}
}
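
Again not part of this commit: a minimal usage sketch of `Buffer::pair`, which hands back the background `Worker` instead of spawning it. It assumes tokio's rt and macros features and tower's `util` module for `service_fn` and `ServiceExt::oneshot`; the doubling service is a placeholder.

use tower::util::{service_fn, ServiceExt};
use tower_batch::Buffer;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let double = service_fn(|x: u64| async move { Ok::<_, std::convert::Infallible>(x * 2) });

    // `pair` does not spawn anything; the caller picks the executor.
    let (buffer, worker) = Buffer::pair(double, 32);
    tokio::spawn(worker);

    // `oneshot` calls `poll_ready` and then `call`; the `ResponseFuture`
    // resolves once the worker has run the request.
    let rsp = buffer.oneshot(21u64).await?;
    assert_eq!(rsp, 42);
    Ok(())
}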

tower-batch/src/worker.rs (new file)

@@ -0,0 +1,228 @@
use super::{
error::{Closed, ServiceError},
message::Message,
};
use futures_core::ready;
use pin_project::pin_project;
use std::sync::{Arc, Mutex};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::sync::mpsc;
use tower::Service;
/// Task that handles processing the buffer. This type should not be used
/// directly; instead, `Buffer` either spawns it onto the Tokio runtime or
/// returns it from `Buffer::pair` so the caller can spawn it themselves.
///
/// The struct is `pub` in the private module and the type is *not* re-exported
/// as part of the public API. This is the "sealed" pattern to include "private"
/// types in public traits that are not meant for consumers of the library to
/// implement (only call).
#[pin_project]
#[derive(Debug)]
pub struct Worker<T, Request>
where
T: Service<Request>,
T::Error: Into<crate::BoxError>,
{
current_message: Option<Message<Request, T::Future>>,
rx: mpsc::Receiver<Message<Request, T::Future>>,
service: T,
finish: bool,
failed: Option<ServiceError>,
handle: Handle,
}
/// Shared handle used by `Buffer` to retrieve the error once the worker has failed.
#[derive(Debug)]
pub(crate) struct Handle {
inner: Arc<Mutex<Option<ServiceError>>>,
}
impl<T, Request> Worker<T, Request>
where
T: Service<Request>,
T::Error: Into<crate::BoxError>,
{
pub(crate) fn new(
service: T,
rx: mpsc::Receiver<Message<Request, T::Future>>,
) -> (Handle, Worker<T, Request>) {
let handle = Handle {
inner: Arc::new(Mutex::new(None)),
};
let worker = Worker {
current_message: None,
finish: false,
failed: None,
rx,
service,
handle: handle.clone(),
};
(handle, worker)
}
/// Return the next queued Message that hasn't been canceled.
///
/// If a `Message` is returned, the `bool` is true if this is the first time we received this
/// message, and false otherwise (i.e., we tried to forward it to the backing service before).
fn poll_next_msg(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<(Message<Request, T::Future>, bool)>> {
if self.finish {
// We've already received None and are shutting down
return Poll::Ready(None);
}
tracing::trace!("worker polling for next message");
if let Some(mut msg) = self.current_message.take() {
// poll_closed returns Poll::Ready if the receiver is dropped.
// Returning Pending means it is still alive, so we should still
// use it.
if msg.tx.poll_closed(cx).is_pending() {
tracing::trace!("resuming buffered request");
return Poll::Ready(Some((msg, false)));
}
tracing::trace!("dropping cancelled buffered request");
}
// Get the next request
while let Some(mut msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
if msg.tx.poll_closed(cx).is_pending() {
tracing::trace!("processing new request");
return Poll::Ready(Some((msg, true)));
}
// Otherwise, request is canceled, so pop the next one.
tracing::trace!("dropping cancelled request");
}
Poll::Ready(None)
}
fn failed(&mut self, error: crate::BoxError) {
// The underlying service failed when we called `poll_ready` on it with the given `error`. We
// need to communicate this to all the `Buffer` handles. To do so, we wrap up the error in
// an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
// requests will also fail with the same error.
// Note that we need to handle the case where some handle is concurrently trying to send us
// a request. We need to make sure that *either* the send of the request fails *or* it
// receives an error on the `oneshot` it constructed. Specifically, we want to avoid the
// case where we send errors to all outstanding requests, and *then* the caller sends its
// request. We do this by *first* exposing the error, *then* closing the channel used to
// send more requests (so the client will see the error when the send fails), and *then*
// sending the error to all outstanding requests.
let error = ServiceError::new(error);
let mut inner = self.handle.inner.lock().unwrap();
if inner.is_some() {
// Future::poll was called after we've already errored out!
return;
}
*inner = Some(error.clone());
drop(inner);
self.rx.close();
// By closing the mpsc::Receiver, we know that poll_next_msg will soon return Ready(None),
// which will trigger the `self.finish == true` phase. We just need to make sure that any
// requests that we receive before we've exhausted the receiver receive the error:
self.failed = Some(error);
}
}
impl<T, Request> Future for Worker<T, Request>
where
T: Service<Request>,
T::Error: Into<crate::BoxError>,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.finish {
return Poll::Ready(());
}
loop {
match ready!(self.poll_next_msg(cx)) {
Some((msg, first)) => {
let _guard = msg.span.enter();
if let Some(ref failed) = self.failed {
tracing::trace!("notifying caller about worker failure");
let _ = msg.tx.send(Err(failed.clone()));
continue;
}
// Wait for the service to be ready
tracing::trace!(
resumed = !first,
message = "worker received request; waiting for service readiness"
);
match self.service.poll_ready(cx) {
Poll::Ready(Ok(())) => {
tracing::debug!(service.ready = true, message = "processing request");
let response = self.service.call(msg.request);
// Send the response future back to the sender.
//
// An error means the request was canceled in between
// our calls; the response future will just be dropped.
tracing::trace!("returning response future");
let _ = msg.tx.send(Ok(response));
}
Poll::Pending => {
tracing::trace!(service.ready = false, message = "delay");
// Put our current message back in its slot.
drop(_guard);
self.current_message = Some(msg);
return Poll::Pending;
}
Poll::Ready(Err(e)) => {
let error = e.into();
tracing::debug!({ %error }, "service failed");
drop(_guard);
self.failed(error);
let _ = msg.tx.send(Err(self
.failed
.as_ref()
.expect("Worker::failed did not set self.failed?")
.clone()));
}
}
}
None => {
// No more requests _ever_.
self.finish = true;
return Poll::Ready(());
}
}
}
}
}
impl Handle {
pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
self.inner
.lock()
.unwrap()
.as_ref()
.map(|svc_err| svc_err.clone().into())
.unwrap_or_else(|| Closed::new().into())
}
}
impl Clone for Handle {
fn clone(&self) -> Handle {
Handle {
inner: self.inner.clone(),
}
}
}
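
To close, a hedged end-to-end sketch (editorial, not in this commit) of the failure protocol `Worker::failed` implements: once the wrapped service errors in `poll_ready`, the first caller gets a `ServiceError` carrying the cause, the channel is closed, and every later caller is answered from the stored error without touching the service again. It assumes tokio's rt and macros features and tower's `util` module; `AlwaysFails` is a made-up service.

use std::task::{Context, Poll};
use tower::util::ServiceExt;
use tower::Service;
use tower_batch::{error::ServiceError, Buffer};

/// A toy service whose readiness check always fails.
struct AlwaysFails;

impl Service<()> for AlwaysFails {
    type Response = ();
    type Error = Box<dyn std::error::Error + Send + Sync>;
    type Future = std::future::Ready<Result<(), Self::Error>>;

    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Err("inner service broke".into()))
    }

    fn call(&mut self, _: ()) -> Self::Future {
        unreachable!("poll_ready never succeeds")
    }
}

#[tokio::main]
async fn main() {
    let buffer = Buffer::new(AlwaysFails, 8);

    // The first request reaches the worker, which sees the failed
    // `poll_ready`, records the error in the shared `Handle`, closes the
    // channel, and replies to the caller with a `ServiceError`.
    let first = buffer.clone().oneshot(()).await.unwrap_err();
    assert!(first.downcast_ref::<ServiceError>().is_some());

    // Later requests never reach `AlwaysFails`; they are answered straight
    // from the stored error.
    let second = buffer.oneshot(()).await.unwrap_err();
    assert!(second.downcast_ref::<ServiceError>().is_some());
}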