zebra/tower-batch/src/service.rs

//! Wrapper service for batching items to an underlying service.

use std::{
    cmp::max,
    fmt,
    future::Future,
    pin::Pin,
    sync::{Arc, Mutex},
    task::{Context, Poll},
};

use futures_core::ready;
use tokio::{
    pin,
    sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore},
    task::JoinHandle,
};
use tokio_util::sync::PollSemaphore;
use tower::Service;
use tracing::{info_span, Instrument};

use super::{
    future::ResponseFuture,
    message::Message,
    worker::{ErrorHandle, Worker},
    BatchControl,
};

/// The maximum number of batches in the queue.
///
/// This avoids having very large queues on machines with hundreds or thousands of cores.
pub const QUEUE_BATCH_LIMIT: usize = 64;

/// Allows batch processing of requests.
///
/// See the crate documentation for more details.
pub struct Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    // Batch management
    //
    /// A custom-bounded channel for sending requests to the batch worker.
    ///
    /// Note: this actually _is_ bounded, but rather than using Tokio's bounded
    /// channel, we use Tokio's unbounded channel together with a semaphore to
    /// implement the bound ourselves.
    tx: mpsc::UnboundedSender<Message<Request, T::Future>>,

    /// A semaphore used to bound the channel.
    ///
    /// When the buffer's channel is full, we want to exert backpressure in
    /// `poll_ready`, so that callers such as load balancers can choose to call
    /// another service rather than waiting for buffer capacity.
    ///
    /// Unfortunately, this can't be done easily using Tokio's bounded MPSC
    /// channel, because it doesn't wake pending tasks on close. Therefore, we
    /// implement our own bounded MPSC on top of the unbounded channel, using a
    /// semaphore to limit how many items are in the channel.
    semaphore: PollSemaphore,

    /// A semaphore permit that allows this service to send one message on `tx`.
    permit: Option<OwnedSemaphorePermit>,

    // Errors
    //
    /// An error handle shared between all service clones for the same worker.
    error_handle: ErrorHandle,

    /// A worker task handle shared between all service clones for the same worker.
    ///
    /// Only used when the worker is spawned on the tokio runtime.
    worker_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
}
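
// A minimal standalone sketch (an assumption, not part of this file) of the
// bounded-MPSC pattern described in the field docs above: an unbounded channel
// whose capacity is enforced by a semaphore, with each queued item carrying the
// permit that reserved its slot. `BoundedSender` is a hypothetical name:
//
// struct BoundedSender<T> {
//     tx: tokio::sync::mpsc::UnboundedSender<(T, tokio::sync::OwnedSemaphorePermit)>,
//     semaphore: std::sync::Arc<tokio::sync::Semaphore>,
// }
//
// impl<T> BoundedSender<T> {
//     async fn send(&self, item: T) -> Result<(), T> {
//         // Wait for a free slot; dropping the permit on the receiver side
//         // releases the slot back to the semaphore.
//         match self.semaphore.clone().acquire_owned().await {
//             Ok(permit) => self.tx.send((item, permit)).map_err(|e| (e.0).0),
//             Err(_closed) => Err(item),
//         }
//     }
// }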

impl<T, Request> fmt::Debug for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = std::any::type_name::<Self>();
        f.debug_struct(name)
            .field("tx", &self.tx)
            .field("semaphore", &self.semaphore)
            .field("permit", &self.permit)
            .field("error_handle", &self.error_handle)
            .field("worker_handle", &self.worker_handle)
            .finish()
    }
}

impl<T, Request> Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    /// Creates a new `Batch` wrapping `service`.
    ///
    /// The wrapper is responsible for telling the inner service when to flush a
    /// batch of requests. These parameters control this policy:
    ///
    /// * `max_items_in_batch` gives the maximum number of items per batch.
    /// * `max_batches` is an upper bound on the number of batches in the queue,
    ///   and the number of concurrently executing batches.
    ///   If this is `None`, we use the current number of [`rayon`] threads.
    ///   The number of batches in the queue is also limited by [`QUEUE_BATCH_LIMIT`].
    /// * `max_latency` gives the maximum latency for a batch item to start verifying.
    ///
    /// The default Tokio executor is used to run the given service, which means
    /// that this method must be called while on the Tokio runtime.
    pub fn new(
        service: T,
        max_items_in_batch: usize,
        max_batches: impl Into<Option<usize>>,
        max_latency: std::time::Duration,
    ) -> Self
    where
        T: Send + 'static,
        T::Future: Send,
        T::Response: Send,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        let (mut batch, worker) = Self::pair(service, max_items_in_batch, max_batches, max_latency);

        let span = info_span!("batch worker", kind = std::any::type_name::<T>());

        #[cfg(tokio_unstable)]
        let worker_handle = {
            let batch_kind = std::any::type_name::<T>();

            // TODO: identify the unique part of the type name generically,
            //       or make it an argument to this method
            let batch_kind = batch_kind.trim_start_matches("zebra_consensus::primitives::");
            let batch_kind = batch_kind.trim_end_matches("::Verifier");

            tokio::task::Builder::new()
                .name(&format!("{} batch", batch_kind))
                .spawn(worker.run().instrument(span))
        };
        #[cfg(not(tokio_unstable))]
        let worker_handle = tokio::spawn(worker.run().instrument(span));

        batch.register_worker(worker_handle);

        batch
    }
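
    // A minimal usage sketch (an assumption, not part of this file): wrapping a
    // hypothetical `verifier` service so that up to 20 items are flushed as one
    // batch, with the batch queue bound defaulting to the rayon thread count and
    // at most 100ms of added latency before a partial batch is flushed:
    //
    // let batch = Batch::new(verifier, 20, None, std::time::Duration::from_millis(100));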

    /// Creates a new `Batch` wrapping `service`, but returns the background worker.
    ///
    /// This is useful if you do not want to spawn directly onto the `tokio`
    /// runtime but instead want to use your own executor. This will return the
    /// `Batch` and the background `Worker` that you can then spawn.
    pub fn pair(
        service: T,
        max_items_in_batch: usize,
        max_batches: impl Into<Option<usize>>,
        max_latency: std::time::Duration,
    ) -> (Self, Worker<T, Request>)
    where
        T: Send + 'static,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        let (tx, rx) = mpsc::unbounded_channel();

        // Clamp config to sensible values.
        let max_items_in_batch = max(max_items_in_batch, 1);
        let max_batches = max_batches
            .into()
            .unwrap_or_else(rayon::current_num_threads);
        let max_batches_in_queue = max_batches.clamp(1, QUEUE_BATCH_LIMIT);

        // The semaphore bound limits the maximum number of concurrent requests
        // (specifically, requests which got a `Ready` from `poll_ready`, but haven't
        // used their semaphore reservation in a `call` yet).
        //
        // We choose a bound that allows callers to check readiness for one batch per
        // rayon CPU thread. This helps keep all CPUs filled with work: there is one
        // batch executing, and another ready to go. Often there is only one verifier
        // running; when that happens, we want it to take all the cores.
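        //
        // For example (illustrative values, not from the original comment): with
        // `max_items_in_batch = 20` and 8 rayon threads, `max_batches_in_queue`
        // is clamped to 8, so the semaphore below starts with 20 * 8 = 160
        // permits: one per queued or in-flight request.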
        let semaphore = Semaphore::new(max_items_in_batch * max_batches_in_queue);
        let semaphore = PollSemaphore::new(Arc::new(semaphore));

        let (error_handle, worker) = Worker::new(
            service,
            rx,
            max_items_in_batch,
            max_batches,
            max_latency,
            semaphore.clone(),
        );

        let batch = Batch {
            tx,
            semaphore,
            permit: None,
            error_handle,
            worker_handle: Arc::new(Mutex::new(None)),
        };

        (batch, worker)
    }
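
    // A minimal sketch (an assumption, not part of this file) of `pair` with a
    // caller-managed spawn onto a specific tokio runtime; `runtime` and
    // `verifier` are hypothetical:
    //
    // let (mut batch, worker) = Batch::pair(verifier, 20, None, Duration::from_millis(100));
    // let handle = runtime.handle().spawn(worker.run());
    // batch.register_worker(handle);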

    /// Ask the `Batch` to monitor the spawned worker task's [`JoinHandle`](tokio::task::JoinHandle).
    ///
    /// Only used when the task is spawned on the tokio runtime.
    pub fn register_worker(&mut self, worker_handle: JoinHandle<()>) {
        *self
            .worker_handle
            .lock()
            .expect("previous task panicked while holding the worker handle mutex") =
            Some(worker_handle);
    }

    /// Returns the error from the batch worker's `error_handle`.
    fn get_worker_error(&self) -> crate::BoxError {
        self.error_handle.get_error_on_closed()
    }
}

impl<T, Request> Service<Request> for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    type Response = T::Response;
    type Error = crate::BoxError;
    type Future = ResponseFuture<T::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Check to see if the worker has returned or panicked.
        //
        // Correctness: Registers this task for wakeup when the worker finishes.
        if let Some(worker_handle) = self
            .worker_handle
            .lock()
            .expect("previous task panicked while holding the worker handle mutex")
            .as_mut()
        {
            match Pin::new(worker_handle).poll(cx) {
                Poll::Ready(Ok(())) => return Poll::Ready(Err(self.get_worker_error())),
                Poll::Ready(Err(task_cancelled)) if task_cancelled.is_cancelled() => {
                    tracing::warn!(
                        "batch task cancelled: {task_cancelled}\n\
                         Is Zebra shutting down?"
                    );

                    return Poll::Ready(Err(task_cancelled.into()));
                }
                Poll::Ready(Err(task_panic)) => {
                    std::panic::resume_unwind(task_panic.into_panic());
                }
                Poll::Pending => {}
            }
        }

        // Check if the worker has set an error and closed its channels.
        //
        // Correctness: Registers this task for wakeup when the channel is closed.
        let tx = self.tx.clone();
        let closed = tx.closed();
        pin!(closed);
        if closed.poll(cx).is_ready() {
            return Poll::Ready(Err(self.get_worker_error()));
        }

        // Poll to acquire a semaphore permit.
        //
        // CORRECTNESS
        //
        // If we acquire a permit, then there's enough buffer capacity to send a new request.
        // Otherwise, we need to wait for capacity. When that happens, `poll_acquire()` registers
        // this task for wakeup when the next permit is available, or when the semaphore is closed.
        //
        // When `poll_ready()` is called multiple times, and channel capacity is 1,
        // avoid deadlocks by dropping any previous permit before acquiring another one.
        // This also stops tasks from holding a permit after an error.
        //
        // Calling `poll_ready()` multiple times can make tasks lose their previous permit
        // to another concurrent task.
        self.permit = None;

        let permit = ready!(self.semaphore.poll_acquire(cx));
        if let Some(permit) = permit {
            // Calling `poll_ready()` more than once will drop any previous permit,
            // releasing its capacity back to the semaphore.
            self.permit = Some(permit);
        } else {
            // The semaphore has been closed.
            return Poll::Ready(Err(self.get_worker_error()));
        }

        Poll::Ready(Ok(()))
    }

    fn call(&mut self, request: Request) -> Self::Future {
        tracing::trace!("sending request to buffer worker");
        let _permit = self
            .permit
            .take()
            .expect("poll_ready must be called before a batch request");

        // Get the current Span so that we can explicitly propagate it to the worker.
        // If we didn't do this, events on the worker related to this span wouldn't be
        // counted towards that span, since the worker would have no way of entering it.
        let span = tracing::Span::current();

        // If we've made it here, then a semaphore permit has already been
        // acquired, so we can freely allocate a oneshot.
        let (tx, rx) = oneshot::channel();

        match self.tx.send(Message {
            request,
            tx,
            span,
            _permit,
        }) {
            Err(_) => ResponseFuture::failed(self.get_worker_error()),
            Ok(_) => ResponseFuture::new(rx),
        }
    }
}
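
// A minimal sketch (an assumption, not part of this file) of driving the
// service: `poll_ready` must return `Ready(Ok(()))` before each `call`, which
// is what `tower::ServiceExt::ready` arranges. `batch` and `request` are
// hypothetical:
//
// use tower::ServiceExt;
// let response = batch.ready().await?.call(request).await?;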

impl<T, Request> Clone for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    fn clone(&self) -> Self {
        Self {
            tx: self.tx.clone(),
            semaphore: self.semaphore.clone(),
            // Semaphore permits are not cloned: each clone must acquire its own
            // permit via `poll_ready` before calling the service.
            permit: None,
            error_handle: self.error_handle.clone(),
            worker_handle: self.worker_handle.clone(),
        }
    }
}