buffer: wake tasks waiting for channel capacity when terminating (#480)
This commit is contained in:
parent
069c9085b1
commit
43c44922af
|
@ -80,8 +80,8 @@ where
|
||||||
Request: Send + 'static,
|
Request: Send + 'static,
|
||||||
{
|
{
|
||||||
let (tx, rx) = mpsc::unbounded_channel();
|
let (tx, rx) = mpsc::unbounded_channel();
|
||||||
let (handle, worker) = Worker::new(service, rx);
|
let (semaphore, wake_waiters) = Semaphore::new_with_close(bound);
|
||||||
let semaphore = Semaphore::new(bound);
|
let (handle, worker) = Worker::new(service, rx, wake_waiters);
|
||||||
(
|
(
|
||||||
Buffer {
|
Buffer {
|
||||||
tx,
|
tx,
|
||||||
|
|
|
@ -20,7 +20,7 @@ use tower_service::Service;
|
||||||
/// as part of the public API. This is the "sealed" pattern to include "private"
|
/// as part of the public API. This is the "sealed" pattern to include "private"
|
||||||
/// types in public traits that are not meant for consumers of the library to
|
/// types in public traits that are not meant for consumers of the library to
|
||||||
/// implement (only call).
|
/// implement (only call).
|
||||||
#[pin_project]
|
#[pin_project(PinnedDrop)]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Worker<T, Request>
|
pub struct Worker<T, Request>
|
||||||
where
|
where
|
||||||
|
@ -33,6 +33,7 @@ where
|
||||||
finish: bool,
|
finish: bool,
|
||||||
failed: Option<ServiceError>,
|
failed: Option<ServiceError>,
|
||||||
handle: Handle,
|
handle: Handle,
|
||||||
|
close: Option<crate::semaphore::Close>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the error out
|
/// Get the error out
|
||||||
|
@ -49,6 +50,7 @@ where
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
service: T,
|
service: T,
|
||||||
rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
|
rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
|
||||||
|
close: crate::semaphore::Close,
|
||||||
) -> (Handle, Worker<T, Request>) {
|
) -> (Handle, Worker<T, Request>) {
|
||||||
let handle = Handle {
|
let handle = Handle {
|
||||||
inner: Arc::new(Mutex::new(None)),
|
inner: Arc::new(Mutex::new(None)),
|
||||||
|
@ -61,6 +63,7 @@ where
|
||||||
rx,
|
rx,
|
||||||
service,
|
service,
|
||||||
handle: handle.clone(),
|
handle: handle.clone(),
|
||||||
|
close: Some(close),
|
||||||
};
|
};
|
||||||
|
|
||||||
(handle, worker)
|
(handle, worker)
|
||||||
|
@ -195,6 +198,11 @@ where
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.expect("Worker::failed did not set self.failed?")
|
.expect("Worker::failed did not set self.failed?")
|
||||||
.clone()));
|
.clone()));
|
||||||
|
// Wake any tasks waiting on channel capacity.
|
||||||
|
if let Some(close) = self.close.take() {
|
||||||
|
tracing::debug!("waking pending tasks");
|
||||||
|
close.close();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -208,6 +216,19 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[pin_project::pinned_drop]
|
||||||
|
impl<T, Request> PinnedDrop for Worker<T, Request>
|
||||||
|
where
|
||||||
|
T: Service<Request>,
|
||||||
|
T::Error: Into<crate::BoxError>,
|
||||||
|
{
|
||||||
|
fn drop(mut self: Pin<&mut Self>) {
|
||||||
|
if let Some(close) = self.as_mut().close.take() {
|
||||||
|
close.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Handle {
|
impl Handle {
|
||||||
pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
|
pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
|
||||||
self.inner
|
self.inner
|
||||||
|
|
|
@ -5,7 +5,7 @@ use std::{
|
||||||
future::Future,
|
future::Future,
|
||||||
mem,
|
mem,
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
sync::Arc,
|
sync::{Arc, Weak},
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
};
|
};
|
||||||
use tokio::sync;
|
use tokio::sync;
|
||||||
|
@ -16,6 +16,12 @@ pub(crate) struct Semaphore {
|
||||||
state: State,
|
state: State,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct Close {
|
||||||
|
semaphore: Weak<sync::Semaphore>,
|
||||||
|
permits: usize,
|
||||||
|
}
|
||||||
|
|
||||||
enum State {
|
enum State {
|
||||||
Waiting(Pin<Box<dyn Future<Output = Permit> + Send + 'static>>),
|
Waiting(Pin<Box<dyn Future<Output = Permit> + Send + 'static>>),
|
||||||
Ready(Permit),
|
Ready(Permit),
|
||||||
|
@ -23,6 +29,19 @@ enum State {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Semaphore {
|
impl Semaphore {
|
||||||
|
pub(crate) fn new_with_close(permits: usize) -> (Self, Close) {
|
||||||
|
let semaphore = Arc::new(sync::Semaphore::new(permits));
|
||||||
|
let close = Close {
|
||||||
|
semaphore: Arc::downgrade(&semaphore),
|
||||||
|
permits,
|
||||||
|
};
|
||||||
|
let semaphore = Self {
|
||||||
|
semaphore,
|
||||||
|
state: State::Empty,
|
||||||
|
};
|
||||||
|
(semaphore, close)
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn new(permits: usize) -> Self {
|
pub(crate) fn new(permits: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
semaphore: Arc::new(sync::Semaphore::new(permits)),
|
semaphore: Arc::new(sync::Semaphore::new(permits)),
|
||||||
|
@ -72,3 +91,23 @@ impl fmt::Debug for State {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Close {
|
||||||
|
/// Close the semaphore, waking any remaining tasks currently awaiting a permit.
|
||||||
|
pub(crate) fn close(self) {
|
||||||
|
// The maximum number of permits that a `tokio::sync::Semaphore`
|
||||||
|
// can hold is usize::MAX >> 3. If we attempt to add more than that
|
||||||
|
// number of permits, the semaphore will panic.
|
||||||
|
// XXX(eliza): another shift is kinda janky but if we add (usize::MAX
|
||||||
|
// > 3 - initial permits) the semaphore impl panics (I think due to a
|
||||||
|
// bug in tokio?).
|
||||||
|
// TODO(eliza): Tokio should _really_ just expose `Semaphore::close`
|
||||||
|
// publicly so we don't have to do this nonsense...
|
||||||
|
const MAX: usize = std::usize::MAX >> 4;
|
||||||
|
if let Some(semaphore) = self.semaphore.upgrade() {
|
||||||
|
// If we added `MAX - available_permits`, any tasks that are
|
||||||
|
// currently holding permits could drop them, overflowing the max.
|
||||||
|
semaphore.add_permits(MAX - self.permits);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ mod support;
|
||||||
use std::thread;
|
use std::thread;
|
||||||
use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok, task};
|
use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok, task};
|
||||||
use tower::buffer::{error, Buffer};
|
use tower::buffer::{error, Buffer};
|
||||||
|
use tower::{util::ServiceExt, Service};
|
||||||
use tower_test::{assert_request_eq, mock};
|
use tower_test::{assert_request_eq, mock};
|
||||||
|
|
||||||
fn let_worker_work() {
|
fn let_worker_work() {
|
||||||
|
@ -227,6 +228,124 @@ async fn waits_for_channel_capacity() {
|
||||||
assert_ready_ok!(response4.poll());
|
assert_ready_ok!(response4.poll());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "current_thread")]
|
||||||
|
async fn wakes_pending_waiters_on_close() {
|
||||||
|
let _t = support::trace_init();
|
||||||
|
|
||||||
|
let (service, mut handle) = mock::pair::<_, ()>();
|
||||||
|
|
||||||
|
let (mut service, worker) = Buffer::pair(service, 1);
|
||||||
|
let mut worker = task::spawn(worker);
|
||||||
|
|
||||||
|
// keep the request in the worker
|
||||||
|
handle.allow(0);
|
||||||
|
let service1 = service.ready_and().await.unwrap();
|
||||||
|
assert_pending!(worker.poll());
|
||||||
|
let mut response = task::spawn(service1.call("hello"));
|
||||||
|
|
||||||
|
let mut service1 = service.clone();
|
||||||
|
let mut ready_and1 = task::spawn(service1.ready_and());
|
||||||
|
assert_pending!(worker.poll());
|
||||||
|
assert_pending!(ready_and1.poll(), "no capacity");
|
||||||
|
|
||||||
|
let mut service1 = service.clone();
|
||||||
|
let mut ready_and2 = task::spawn(service1.ready_and());
|
||||||
|
assert_pending!(worker.poll());
|
||||||
|
assert_pending!(ready_and2.poll(), "no capacity");
|
||||||
|
|
||||||
|
// kill the worker task
|
||||||
|
drop(worker);
|
||||||
|
|
||||||
|
let err = assert_ready_err!(response.poll());
|
||||||
|
assert!(
|
||||||
|
err.is::<error::Closed>(),
|
||||||
|
"response should fail with a Closed, got: {:?}",
|
||||||
|
err
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
ready_and1.is_woken(),
|
||||||
|
"dropping worker should wake ready_and task 1"
|
||||||
|
);
|
||||||
|
let err = assert_ready_err!(ready_and1.poll());
|
||||||
|
assert!(
|
||||||
|
err.is::<error::Closed>(),
|
||||||
|
"ready_and 1 should fail with a Closed, got: {:?}",
|
||||||
|
err
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
ready_and2.is_woken(),
|
||||||
|
"dropping worker should wake ready_and task 2"
|
||||||
|
);
|
||||||
|
let err = assert_ready_err!(ready_and1.poll());
|
||||||
|
assert!(
|
||||||
|
err.is::<error::Closed>(),
|
||||||
|
"ready_and 2 should fail with a Closed, got: {:?}",
|
||||||
|
err
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "current_thread")]
|
||||||
|
async fn wakes_pending_waiters_on_failure() {
|
||||||
|
let _t = support::trace_init();
|
||||||
|
|
||||||
|
let (service, mut handle) = mock::pair::<_, ()>();
|
||||||
|
|
||||||
|
let (mut service, worker) = Buffer::pair(service, 1);
|
||||||
|
let mut worker = task::spawn(worker);
|
||||||
|
|
||||||
|
// keep the request in the worker
|
||||||
|
handle.allow(0);
|
||||||
|
let service1 = service.ready_and().await.unwrap();
|
||||||
|
assert_pending!(worker.poll());
|
||||||
|
let mut response = task::spawn(service1.call("hello"));
|
||||||
|
|
||||||
|
let mut service1 = service.clone();
|
||||||
|
let mut ready_and1 = task::spawn(service1.ready_and());
|
||||||
|
assert_pending!(worker.poll());
|
||||||
|
assert_pending!(ready_and1.poll(), "no capacity");
|
||||||
|
|
||||||
|
let mut service1 = service.clone();
|
||||||
|
let mut ready_and2 = task::spawn(service1.ready_and());
|
||||||
|
assert_pending!(worker.poll());
|
||||||
|
assert_pending!(ready_and2.poll(), "no capacity");
|
||||||
|
|
||||||
|
// fail the inner service
|
||||||
|
handle.send_error("foobar");
|
||||||
|
// worker task terminates
|
||||||
|
assert_ready!(worker.poll());
|
||||||
|
|
||||||
|
let err = assert_ready_err!(response.poll());
|
||||||
|
assert!(
|
||||||
|
err.is::<error::ServiceError>(),
|
||||||
|
"response should fail with a ServiceError, got: {:?}",
|
||||||
|
err
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
ready_and1.is_woken(),
|
||||||
|
"dropping worker should wake ready_and task 1"
|
||||||
|
);
|
||||||
|
let err = assert_ready_err!(ready_and1.poll());
|
||||||
|
assert!(
|
||||||
|
err.is::<error::ServiceError>(),
|
||||||
|
"ready_and 1 should fail with a ServiceError, got: {:?}",
|
||||||
|
err
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
ready_and2.is_woken(),
|
||||||
|
"dropping worker should wake ready_and task 2"
|
||||||
|
);
|
||||||
|
let err = assert_ready_err!(ready_and1.poll());
|
||||||
|
assert!(
|
||||||
|
err.is::<error::ServiceError>(),
|
||||||
|
"ready_and 2 should fail with a ServiceError, got: {:?}",
|
||||||
|
err
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
type Mock = mock::Mock<&'static str, &'static str>;
|
type Mock = mock::Mock<&'static str, &'static str>;
|
||||||
type Handle = mock::Handle<&'static str, &'static str>;
|
type Handle = mock::Handle<&'static str, &'static str>;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue