buffer: wake tasks waiting for channel capacity when terminating (#480)

This commit is contained in:
Eliza Weisman 2020-10-28 08:41:11 -07:00 committed by GitHub
parent 069c9085b1
commit 43c44922af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 183 additions and 4 deletions

View File

@ -80,8 +80,8 @@ where
Request: Send + 'static,
{
let (tx, rx) = mpsc::unbounded_channel();
let (handle, worker) = Worker::new(service, rx);
let semaphore = Semaphore::new(bound);
let (semaphore, wake_waiters) = Semaphore::new_with_close(bound);
let (handle, worker) = Worker::new(service, rx, wake_waiters);
(
Buffer {
tx,

View File

@ -20,7 +20,7 @@ use tower_service::Service;
/// as part of the public API. This is the "sealed" pattern to include "private"
/// types in public traits that are not meant for consumers of the library to
/// implement (only call).
#[pin_project]
#[pin_project(PinnedDrop)]
#[derive(Debug)]
pub struct Worker<T, Request>
where
@ -33,6 +33,7 @@ where
finish: bool,
failed: Option<ServiceError>,
handle: Handle,
close: Option<crate::semaphore::Close>,
}
/// Get the error out
@ -49,6 +50,7 @@ where
pub(crate) fn new(
service: T,
rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
close: crate::semaphore::Close,
) -> (Handle, Worker<T, Request>) {
let handle = Handle {
inner: Arc::new(Mutex::new(None)),
@ -61,6 +63,7 @@ where
rx,
service,
handle: handle.clone(),
close: Some(close),
};
(handle, worker)
@ -195,6 +198,11 @@ where
.as_ref()
.expect("Worker::failed did not set self.failed?")
.clone()));
// Wake any tasks waiting on channel capacity.
if let Some(close) = self.close.take() {
tracing::debug!("waking pending tasks");
close.close();
}
}
}
}
@ -208,6 +216,19 @@ where
}
}
#[pin_project::pinned_drop]
impl<T, Request> PinnedDrop for Worker<T, Request>
where
T: Service<Request>,
T::Error: Into<crate::BoxError>,
{
fn drop(mut self: Pin<&mut Self>) {
if let Some(close) = self.as_mut().close.take() {
close.close();
}
}
}
impl Handle {
pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
self.inner

View File

@ -5,7 +5,7 @@ use std::{
future::Future,
mem,
pin::Pin,
sync::Arc,
sync::{Arc, Weak},
task::{Context, Poll},
};
use tokio::sync;
@ -16,6 +16,12 @@ pub(crate) struct Semaphore {
state: State,
}
#[derive(Debug)]
pub(crate) struct Close {
semaphore: Weak<sync::Semaphore>,
permits: usize,
}
enum State {
Waiting(Pin<Box<dyn Future<Output = Permit> + Send + 'static>>),
Ready(Permit),
@ -23,6 +29,19 @@ enum State {
}
impl Semaphore {
pub(crate) fn new_with_close(permits: usize) -> (Self, Close) {
let semaphore = Arc::new(sync::Semaphore::new(permits));
let close = Close {
semaphore: Arc::downgrade(&semaphore),
permits,
};
let semaphore = Self {
semaphore,
state: State::Empty,
};
(semaphore, close)
}
pub(crate) fn new(permits: usize) -> Self {
Self {
semaphore: Arc::new(sync::Semaphore::new(permits)),
@ -72,3 +91,23 @@ impl fmt::Debug for State {
}
}
}
impl Close {
/// Close the semaphore, waking any remaining tasks currently awaiting a permit.
pub(crate) fn close(self) {
// The maximum number of permits that a `tokio::sync::Semaphore`
// can hold is usize::MAX >> 3. If we attempt to add more than that
// number of permits, the semaphore will panic.
// XXX(eliza): another shift is kinda janky but if we add (usize::MAX
// > 3 - initial permits) the semaphore impl panics (I think due to a
// bug in tokio?).
// TODO(eliza): Tokio should _really_ just expose `Semaphore::close`
// publicly so we don't have to do this nonsense...
const MAX: usize = std::usize::MAX >> 4;
if let Some(semaphore) = self.semaphore.upgrade() {
// If we added `MAX - available_permits`, any tasks that are
// currently holding permits could drop them, overflowing the max.
semaphore.add_permits(MAX - self.permits);
}
}
}

View File

@ -4,6 +4,7 @@ mod support;
use std::thread;
use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok, task};
use tower::buffer::{error, Buffer};
use tower::{util::ServiceExt, Service};
use tower_test::{assert_request_eq, mock};
fn let_worker_work() {
@ -227,6 +228,124 @@ async fn waits_for_channel_capacity() {
assert_ready_ok!(response4.poll());
}
#[tokio::test(flavor = "current_thread")]
async fn wakes_pending_waiters_on_close() {
let _t = support::trace_init();
let (service, mut handle) = mock::pair::<_, ()>();
let (mut service, worker) = Buffer::pair(service, 1);
let mut worker = task::spawn(worker);
// keep the request in the worker
handle.allow(0);
let service1 = service.ready_and().await.unwrap();
assert_pending!(worker.poll());
let mut response = task::spawn(service1.call("hello"));
let mut service1 = service.clone();
let mut ready_and1 = task::spawn(service1.ready_and());
assert_pending!(worker.poll());
assert_pending!(ready_and1.poll(), "no capacity");
let mut service1 = service.clone();
let mut ready_and2 = task::spawn(service1.ready_and());
assert_pending!(worker.poll());
assert_pending!(ready_and2.poll(), "no capacity");
// kill the worker task
drop(worker);
let err = assert_ready_err!(response.poll());
assert!(
err.is::<error::Closed>(),
"response should fail with a Closed, got: {:?}",
err
);
assert!(
ready_and1.is_woken(),
"dropping worker should wake ready_and task 1"
);
let err = assert_ready_err!(ready_and1.poll());
assert!(
err.is::<error::Closed>(),
"ready_and 1 should fail with a Closed, got: {:?}",
err
);
assert!(
ready_and2.is_woken(),
"dropping worker should wake ready_and task 2"
);
let err = assert_ready_err!(ready_and1.poll());
assert!(
err.is::<error::Closed>(),
"ready_and 2 should fail with a Closed, got: {:?}",
err
);
}
#[tokio::test(flavor = "current_thread")]
async fn wakes_pending_waiters_on_failure() {
let _t = support::trace_init();
let (service, mut handle) = mock::pair::<_, ()>();
let (mut service, worker) = Buffer::pair(service, 1);
let mut worker = task::spawn(worker);
// keep the request in the worker
handle.allow(0);
let service1 = service.ready_and().await.unwrap();
assert_pending!(worker.poll());
let mut response = task::spawn(service1.call("hello"));
let mut service1 = service.clone();
let mut ready_and1 = task::spawn(service1.ready_and());
assert_pending!(worker.poll());
assert_pending!(ready_and1.poll(), "no capacity");
let mut service1 = service.clone();
let mut ready_and2 = task::spawn(service1.ready_and());
assert_pending!(worker.poll());
assert_pending!(ready_and2.poll(), "no capacity");
// fail the inner service
handle.send_error("foobar");
// worker task terminates
assert_ready!(worker.poll());
let err = assert_ready_err!(response.poll());
assert!(
err.is::<error::ServiceError>(),
"response should fail with a ServiceError, got: {:?}",
err
);
assert!(
ready_and1.is_woken(),
"dropping worker should wake ready_and task 1"
);
let err = assert_ready_err!(ready_and1.poll());
assert!(
err.is::<error::ServiceError>(),
"ready_and 1 should fail with a ServiceError, got: {:?}",
err
);
assert!(
ready_and2.is_woken(),
"dropping worker should wake ready_and task 2"
);
let err = assert_ready_err!(ready_and1.poll());
assert!(
err.is::<error::ServiceError>(),
"ready_and 2 should fail with a ServiceError, got: {:?}",
err
);
}
type Mock = mock::Mock<&'static str, &'static str>;
type Handle = mock::Handle<&'static str, &'static str>;