buffer: wake tasks waiting for channel capacity when terminating (#480)
This commit is contained in:
parent
069c9085b1
commit
43c44922af
|
@ -80,8 +80,8 @@ where
|
|||
Request: Send + 'static,
|
||||
{
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
let (handle, worker) = Worker::new(service, rx);
|
||||
let semaphore = Semaphore::new(bound);
|
||||
let (semaphore, wake_waiters) = Semaphore::new_with_close(bound);
|
||||
let (handle, worker) = Worker::new(service, rx, wake_waiters);
|
||||
(
|
||||
Buffer {
|
||||
tx,
|
||||
|
|
|
@ -20,7 +20,7 @@ use tower_service::Service;
|
|||
/// as part of the public API. This is the "sealed" pattern to include "private"
|
||||
/// types in public traits that are not meant for consumers of the library to
|
||||
/// implement (only call).
|
||||
#[pin_project]
|
||||
#[pin_project(PinnedDrop)]
|
||||
#[derive(Debug)]
|
||||
pub struct Worker<T, Request>
|
||||
where
|
||||
|
@ -33,6 +33,7 @@ where
|
|||
finish: bool,
|
||||
failed: Option<ServiceError>,
|
||||
handle: Handle,
|
||||
close: Option<crate::semaphore::Close>,
|
||||
}
|
||||
|
||||
/// Get the error out
|
||||
|
@ -49,6 +50,7 @@ where
|
|||
pub(crate) fn new(
|
||||
service: T,
|
||||
rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
|
||||
close: crate::semaphore::Close,
|
||||
) -> (Handle, Worker<T, Request>) {
|
||||
let handle = Handle {
|
||||
inner: Arc::new(Mutex::new(None)),
|
||||
|
@ -61,6 +63,7 @@ where
|
|||
rx,
|
||||
service,
|
||||
handle: handle.clone(),
|
||||
close: Some(close),
|
||||
};
|
||||
|
||||
(handle, worker)
|
||||
|
@ -195,6 +198,11 @@ where
|
|||
.as_ref()
|
||||
.expect("Worker::failed did not set self.failed?")
|
||||
.clone()));
|
||||
// Wake any tasks waiting on channel capacity.
|
||||
if let Some(close) = self.close.take() {
|
||||
tracing::debug!("waking pending tasks");
|
||||
close.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -208,6 +216,19 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[pin_project::pinned_drop]
|
||||
impl<T, Request> PinnedDrop for Worker<T, Request>
|
||||
where
|
||||
T: Service<Request>,
|
||||
T::Error: Into<crate::BoxError>,
|
||||
{
|
||||
fn drop(mut self: Pin<&mut Self>) {
|
||||
if let Some(close) = self.as_mut().close.take() {
|
||||
close.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Handle {
|
||||
pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
|
||||
self.inner
|
||||
|
|
|
@ -5,7 +5,7 @@ use std::{
|
|||
future::Future,
|
||||
mem,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
sync::{Arc, Weak},
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tokio::sync;
|
||||
|
@ -16,6 +16,12 @@ pub(crate) struct Semaphore {
|
|||
state: State,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Close {
|
||||
semaphore: Weak<sync::Semaphore>,
|
||||
permits: usize,
|
||||
}
|
||||
|
||||
enum State {
|
||||
Waiting(Pin<Box<dyn Future<Output = Permit> + Send + 'static>>),
|
||||
Ready(Permit),
|
||||
|
@ -23,6 +29,19 @@ enum State {
|
|||
}
|
||||
|
||||
impl Semaphore {
|
||||
pub(crate) fn new_with_close(permits: usize) -> (Self, Close) {
|
||||
let semaphore = Arc::new(sync::Semaphore::new(permits));
|
||||
let close = Close {
|
||||
semaphore: Arc::downgrade(&semaphore),
|
||||
permits,
|
||||
};
|
||||
let semaphore = Self {
|
||||
semaphore,
|
||||
state: State::Empty,
|
||||
};
|
||||
(semaphore, close)
|
||||
}
|
||||
|
||||
pub(crate) fn new(permits: usize) -> Self {
|
||||
Self {
|
||||
semaphore: Arc::new(sync::Semaphore::new(permits)),
|
||||
|
@ -72,3 +91,23 @@ impl fmt::Debug for State {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Close {
|
||||
/// Close the semaphore, waking any remaining tasks currently awaiting a permit.
|
||||
pub(crate) fn close(self) {
|
||||
// The maximum number of permits that a `tokio::sync::Semaphore`
|
||||
// can hold is usize::MAX >> 3. If we attempt to add more than that
|
||||
// number of permits, the semaphore will panic.
|
||||
// XXX(eliza): another shift is kinda janky but if we add (usize::MAX
|
||||
// > 3 - initial permits) the semaphore impl panics (I think due to a
|
||||
// bug in tokio?).
|
||||
// TODO(eliza): Tokio should _really_ just expose `Semaphore::close`
|
||||
// publicly so we don't have to do this nonsense...
|
||||
const MAX: usize = std::usize::MAX >> 4;
|
||||
if let Some(semaphore) = self.semaphore.upgrade() {
|
||||
// If we added `MAX - available_permits`, any tasks that are
|
||||
// currently holding permits could drop them, overflowing the max.
|
||||
semaphore.add_permits(MAX - self.permits);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ mod support;
|
|||
use std::thread;
|
||||
use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok, task};
|
||||
use tower::buffer::{error, Buffer};
|
||||
use tower::{util::ServiceExt, Service};
|
||||
use tower_test::{assert_request_eq, mock};
|
||||
|
||||
fn let_worker_work() {
|
||||
|
@ -227,6 +228,124 @@ async fn waits_for_channel_capacity() {
|
|||
assert_ready_ok!(response4.poll());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn wakes_pending_waiters_on_close() {
|
||||
let _t = support::trace_init();
|
||||
|
||||
let (service, mut handle) = mock::pair::<_, ()>();
|
||||
|
||||
let (mut service, worker) = Buffer::pair(service, 1);
|
||||
let mut worker = task::spawn(worker);
|
||||
|
||||
// keep the request in the worker
|
||||
handle.allow(0);
|
||||
let service1 = service.ready_and().await.unwrap();
|
||||
assert_pending!(worker.poll());
|
||||
let mut response = task::spawn(service1.call("hello"));
|
||||
|
||||
let mut service1 = service.clone();
|
||||
let mut ready_and1 = task::spawn(service1.ready_and());
|
||||
assert_pending!(worker.poll());
|
||||
assert_pending!(ready_and1.poll(), "no capacity");
|
||||
|
||||
let mut service1 = service.clone();
|
||||
let mut ready_and2 = task::spawn(service1.ready_and());
|
||||
assert_pending!(worker.poll());
|
||||
assert_pending!(ready_and2.poll(), "no capacity");
|
||||
|
||||
// kill the worker task
|
||||
drop(worker);
|
||||
|
||||
let err = assert_ready_err!(response.poll());
|
||||
assert!(
|
||||
err.is::<error::Closed>(),
|
||||
"response should fail with a Closed, got: {:?}",
|
||||
err
|
||||
);
|
||||
|
||||
assert!(
|
||||
ready_and1.is_woken(),
|
||||
"dropping worker should wake ready_and task 1"
|
||||
);
|
||||
let err = assert_ready_err!(ready_and1.poll());
|
||||
assert!(
|
||||
err.is::<error::Closed>(),
|
||||
"ready_and 1 should fail with a Closed, got: {:?}",
|
||||
err
|
||||
);
|
||||
|
||||
assert!(
|
||||
ready_and2.is_woken(),
|
||||
"dropping worker should wake ready_and task 2"
|
||||
);
|
||||
let err = assert_ready_err!(ready_and1.poll());
|
||||
assert!(
|
||||
err.is::<error::Closed>(),
|
||||
"ready_and 2 should fail with a Closed, got: {:?}",
|
||||
err
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn wakes_pending_waiters_on_failure() {
|
||||
let _t = support::trace_init();
|
||||
|
||||
let (service, mut handle) = mock::pair::<_, ()>();
|
||||
|
||||
let (mut service, worker) = Buffer::pair(service, 1);
|
||||
let mut worker = task::spawn(worker);
|
||||
|
||||
// keep the request in the worker
|
||||
handle.allow(0);
|
||||
let service1 = service.ready_and().await.unwrap();
|
||||
assert_pending!(worker.poll());
|
||||
let mut response = task::spawn(service1.call("hello"));
|
||||
|
||||
let mut service1 = service.clone();
|
||||
let mut ready_and1 = task::spawn(service1.ready_and());
|
||||
assert_pending!(worker.poll());
|
||||
assert_pending!(ready_and1.poll(), "no capacity");
|
||||
|
||||
let mut service1 = service.clone();
|
||||
let mut ready_and2 = task::spawn(service1.ready_and());
|
||||
assert_pending!(worker.poll());
|
||||
assert_pending!(ready_and2.poll(), "no capacity");
|
||||
|
||||
// fail the inner service
|
||||
handle.send_error("foobar");
|
||||
// worker task terminates
|
||||
assert_ready!(worker.poll());
|
||||
|
||||
let err = assert_ready_err!(response.poll());
|
||||
assert!(
|
||||
err.is::<error::ServiceError>(),
|
||||
"response should fail with a ServiceError, got: {:?}",
|
||||
err
|
||||
);
|
||||
|
||||
assert!(
|
||||
ready_and1.is_woken(),
|
||||
"dropping worker should wake ready_and task 1"
|
||||
);
|
||||
let err = assert_ready_err!(ready_and1.poll());
|
||||
assert!(
|
||||
err.is::<error::ServiceError>(),
|
||||
"ready_and 1 should fail with a ServiceError, got: {:?}",
|
||||
err
|
||||
);
|
||||
|
||||
assert!(
|
||||
ready_and2.is_woken(),
|
||||
"dropping worker should wake ready_and task 2"
|
||||
);
|
||||
let err = assert_ready_err!(ready_and1.poll());
|
||||
assert!(
|
||||
err.is::<error::ServiceError>(),
|
||||
"ready_and 2 should fail with a ServiceError, got: {:?}",
|
||||
err
|
||||
);
|
||||
}
|
||||
|
||||
type Mock = mock::Mock<&'static str, &'static str>;
|
||||
type Handle = mock::Handle<&'static str, &'static str>;
|
||||
|
||||
|
|
Loading…
Reference in New Issue