//! Batch worker item handling and run loop implementation.

use std::{
    pin::Pin,
    sync::{Arc, Mutex},
};

use futures::{
    future::{BoxFuture, OptionFuture},
    stream::FuturesUnordered,
    FutureExt, StreamExt,
};
use pin_project::pin_project;
use tokio::{
    sync::mpsc,
    time::{sleep, Sleep},
};
use tokio_util::sync::PollSemaphore;
use tower::{Service, ServiceExt};
use tracing_futures::Instrument;

use super::{
    error::{Closed, ServiceError},
    message::{self, Message},
    BatchControl,
};

/// Task that handles processing the batched requests. This type should not be
/// used directly; instead, `Batch` requires an executor that can accept this
/// task.
///
/// The struct is `pub` in the private module and the type is *not* re-exported
/// as part of the public API. This is the "sealed" pattern to include "private"
/// types in public traits that are not meant for consumers of the library to
/// implement (only call).
#[pin_project(PinnedDrop)]
#[derive(Debug)]
pub struct Worker<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    // Batch management
    //
    /// A semaphore-bounded channel for receiving requests from the batch wrapper service.
    rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,

    /// The wrapped service that processes batches.
    service: T,

    /// The number of pending items sent to `service` since the last batch flush.
    pending_items: usize,

    /// The timer for the pending batch, if it has any items.
    ///
    /// The timer is started when the first entry of a new batch is
    /// submitted, so that the batch latency of all entries is at most
    /// `self.max_latency`. However, we don't keep the timer running unless
    /// there is a pending request, to prevent wakeups on idle services.
    pending_batch_timer: Option<Pin<Box<Sleep>>>,

    /// The batches that the worker is concurrently executing.
    concurrent_batches: FuturesUnordered<BoxFuture<'static, Result<T::Response, T::Error>>>,

    // Errors and termination
    //
    /// An error that's populated on permanent service failure.
    failed: Option<ServiceError>,

    /// A shared error handle that's populated on permanent service failure.
    error_handle: ErrorHandle,

    /// A cloned copy of the wrapper service's semaphore, used to close the semaphore.
    close: PollSemaphore,

    // Config
    //
    /// The maximum number of items allowed in a batch.
    max_items_in_batch: usize,

    /// The maximum number of batches that are allowed to run concurrently.
    max_concurrent_batches: usize,

    /// The maximum delay before processing a batch with fewer than `max_items_in_batch`.
    max_latency: std::time::Duration,
}
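
// How these fields interact (a summary of the policies implemented below;
// the concrete values here are chosen purely for illustration): with
// `max_items_in_batch = 3`, `max_latency = 100ms`, and
// `max_concurrent_batches = 2`, the worker:
//
//   1. receives an item, sends it to `service`, and starts
//      `pending_batch_timer` if this is the first pending item;
//   2. flushes (`BatchControl::Flush`) when a 3rd item arrives or 100ms
//      elapses, pushing the flush future onto `concurrent_batches`;
//   3. stops accepting new items while 2 flushed batches are still running.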

/// A shared handle that allows callers to retrieve the error that caused the
/// worker to permanently fail.
#[derive(Debug)]
pub(crate) struct ErrorHandle {
    inner: Arc<Mutex<Option<ServiceError>>>,
}

impl<T, Request> Worker<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    /// Creates a new batch worker.
    ///
    /// See [`Batch::new()`](crate::Batch::new) for details.
    pub(crate) fn new(
        service: T,
        rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
        max_items_in_batch: usize,
        max_concurrent_batches: usize,
        max_latency: std::time::Duration,
        close: PollSemaphore,
    ) -> (ErrorHandle, Worker<T, Request>) {
        let error_handle = ErrorHandle {
            inner: Arc::new(Mutex::new(None)),
        };

        let worker = Worker {
            rx,
            service,
            pending_items: 0,
            pending_batch_timer: None,
            concurrent_batches: FuturesUnordered::new(),
            failed: None,
            error_handle: error_handle.clone(),
            close,
            max_items_in_batch,
            max_concurrent_batches,
            max_latency,
        };

        (error_handle, worker)
    }
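
    // A minimal sketch of intended usage (an assumption based on the
    // constructor and `run()` signatures; the real wiring lives in the
    // `Batch` wrapper, not here):
    //
    //     let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
    //     let close = PollSemaphore::new(semaphore.clone());
    //     let (error_handle, worker) =
    //         Worker::new(service, rx, 10, 2, Duration::from_millis(100), close);
    //     tokio::spawn(worker.run());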

    /// Process a single worker request.
    async fn process_req(&mut self, req: Request, tx: message::Tx<T::Future>) {
        if let Some(ref error) = self.failed {
            tracing::trace!(
                ?error,
                "notifying batch request caller about worker failure",
            );
            let _ = tx.send(Err(error.clone()));
            return;
        }

        match self.service.ready().await {
            Ok(svc) => {
                // Hand the request to the service as a batch item, and send
                // the service's response future back to the caller. A send
                // error means the caller dropped its receiver, so ignore it.
                let rsp = svc.call(req.into());
                let _ = tx.send(Ok(rsp));

                self.pending_items += 1;
            }
            Err(e) => {
                self.failed(e.into());
                let _ = tx.send(Err(self
                    .failed
                    .as_ref()
                    .expect("Worker::failed did not set self.failed?")
                    .clone()));
            }
        }
    }
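
    // Note that `process_req()` does not await the item's response future:
    // it forwards that future to the caller through `tx`, so a slow batch
    // item never blocks this worker's run loop.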

    /// Tell the inner service to flush the current batch.
    ///
    /// Waits until the inner service is ready,
    /// then stores a future which resolves when the batch finishes.
    async fn flush_service(&mut self) {
        if self.failed.is_some() {
            tracing::trace!("worker failure: skipping flush");
            return;
        }

        match self.service.ready().await {
            Ok(ready_service) => {
                let flush_future = ready_service.call(BatchControl::Flush);
                self.concurrent_batches.push(flush_future.boxed());

                // Now we have an empty batch.
                self.pending_items = 0;
                self.pending_batch_timer = None;
            }
            Err(error) => {
                self.failed(error.into());
            }
        }
    }
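
    // The future pushed above completes in `run()`'s `concurrent_batches`
    // branch, which is also where a failed batch execution is reported via
    // `self.failed()`.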

    /// Is the current number of concurrent batches below the configured limit?
    fn can_spawn_new_batches(&self) -> bool {
        self.concurrent_batches.len() < self.max_concurrent_batches
    }

    /// Run loop for batch requests, which implements the batch policies.
    ///
    /// See [`Batch::new()`](crate::Batch::new) for details.
    pub async fn run(mut self) {
        loop {
            // Wait on either a new message or the batch timer.
            //
            // If both are ready, end the batch now, because the timer has elapsed.
            // If the timer elapses, any pending messages are preserved:
            // https://docs.rs/tokio/latest/tokio/sync/mpsc/struct.UnboundedReceiver.html#cancel-safety
            tokio::select! {
                // `biased` polls the branches in written order: finished
                // batches are drained first, then the batch timer, then new
                // requests.
                biased;

                batch_result = self.concurrent_batches.next(), if !self.concurrent_batches.is_empty() => match batch_result.expect("only returns None when empty") {
                    Ok(_response) => {
                        tracing::trace!(
                            pending_items = self.pending_items,
                            batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                            running_batches = self.concurrent_batches.len(),
                            "batch finished executing",
                        );
                    }
                    Err(error) => {
                        let error = error.into();
                        tracing::trace!(?error, "batch execution failed");
                        self.failed(error);
                    }
                },

                Some(()) = OptionFuture::from(self.pending_batch_timer.as_mut()), if self.pending_batch_timer.is_some() => {
                    tracing::trace!(
                        pending_items = self.pending_items,
                        batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                        running_batches = self.concurrent_batches.len(),
                        "batch timer expired",
                    );

                    // TODO: use a batch-specific span to instrument this future.
                    self.flush_service().await;
                },

                maybe_msg = self.rx.recv(), if self.can_spawn_new_batches() => match maybe_msg {
                    Some(msg) => {
                        tracing::trace!(
                            pending_items = self.pending_items,
                            batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                            running_batches = self.concurrent_batches.len(),
                            "batch message received",
                        );

                        let span = msg.span;

                        self.process_req(msg.request, msg.tx)
                            // Apply the provided span to request processing.
                            .instrument(span)
                            .await;

                        // Check whether we have too many pending items.
                        if self.pending_items >= self.max_items_in_batch {
                            tracing::trace!(
                                pending_items = self.pending_items,
                                batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                                running_batches = self.concurrent_batches.len(),
                                "batch is full",
                            );

                            // TODO: use a batch-specific span to instrument this future.
                            self.flush_service().await;
                        } else if self.pending_items == 1 {
                            tracing::trace!(
                                pending_items = self.pending_items,
                                batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                                running_batches = self.concurrent_batches.len(),
                                "batch is new, starting timer",
                            );

                            // The first message in a new batch.
                            self.pending_batch_timer = Some(Box::pin(sleep(self.max_latency)));
                        } else {
                            tracing::trace!(
                                pending_items = self.pending_items,
                                batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
                                running_batches = self.concurrent_batches.len(),
                                "waiting for full batch or batch timer",
                            );
                        }
                    }
                    None => {
                        tracing::trace!("batch channel closed and emptied, exiting worker task");

                        return;
                    }
                },
            }
        }
    }

    /// Register an inner service failure.
    ///
    /// The underlying service failed when we called `poll_ready` on it with the given `error`. We
    /// need to communicate this to all the `Batch` handles. To do so, we wrap up the error in
    /// an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
    /// requests will also fail with the same error.
    fn failed(&mut self, error: crate::BoxError) {
        tracing::debug!(?error, "batch worker error");

        // Note that we need to handle the case where some handle is concurrently trying to send us
        // a request. We need to make sure that *either* the send of the request fails *or* it
        // receives an error on the `oneshot` it constructed. Specifically, we want to avoid the
        // case where we send errors to all outstanding requests, and *then* the caller sends its
        // request. We do this by *first* exposing the error, *then* closing the channel used to
        // send more requests (so the client will see the error when the send fails), and *then*
        // sending the error to all outstanding requests.
        let error = ServiceError::new(error);

        let mut inner = self.error_handle.inner.lock().unwrap();

        // Ignore duplicate failures.
        if inner.is_some() {
            return;
        }

        *inner = Some(error.clone());
        drop(inner);

        tracing::trace!(
            ?error,
            "worker failure: waking pending requests so they can be failed",
        );
        self.rx.close();
        self.close.close();

        // We don't schedule any batches on an errored service.
        self.pending_batch_timer = None;

        // By closing the mpsc::Receiver, we know that the run() loop will
        // drain all pending requests. We just need to make sure that any
        // requests that we receive before we've exhausted the receiver receive
        // the error:
        self.failed = Some(error);
    }
}

impl ErrorHandle {
    pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
        self.inner
            .lock()
            .expect("previous task panicked while holding the error handle mutex")
            .as_ref()
            .map(|svc_err| svc_err.clone().into())
            .unwrap_or_else(|| Closed::new().into())
    }
}
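
// A sketch of how the wrapper service is expected to use this handle (an
// assumption based on the API, not code from this module): when sending a
// request into the batch channel fails, the wrapper reports the worker's
// stored error instead of a generic channel-closed error:
//
//     if tx.send(message).is_err() {
//         return Err(error_handle.get_error_on_closed());
//     }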

impl Clone for ErrorHandle {
    fn clone(&self) -> ErrorHandle {
        ErrorHandle {
            inner: self.inner.clone(),
        }
    }
}

#[pin_project::pinned_drop]
impl<T, Request> PinnedDrop for Worker<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Future: Send + 'static,
    T::Error: Into<crate::BoxError>,
{
    fn drop(mut self: Pin<&mut Self>) {
        tracing::trace!(
            pending_items = self.pending_items,
            batch_deadline = ?self.pending_batch_timer.as_ref().map(|sleep| sleep.deadline()),
            running_batches = self.concurrent_batches.len(),
            error = ?self.failed,
            "dropping batch worker",
        );

        // Fail pending tasks
        self.failed(Closed::new().into());

        // Fail queued requests
        while let Ok(msg) = self.rx.try_recv() {
            let _ = msg
                .tx
                .send(Err(self.failed.as_ref().expect("just set failed").clone()));
        }

        // Clear any finished batches, ignoring any errors.
        // Ignore any batches that are still executing, because we can't cancel them.
        //
        // now_or_never() can stop futures waking up, but that's ok here,
        // because we're manually polling, then dropping the stream.
        while let Some(Some(_)) = self
            .as_mut()
            .project()
            .concurrent_batches
            .next()
            .now_or_never()
        {}
    }
}