diff --git a/tower/src/buffer/service.rs b/tower/src/buffer/service.rs index 882a35e..d106e76 100644 --- a/tower/src/buffer/service.rs +++ b/tower/src/buffer/service.rs @@ -80,8 +80,8 @@ where Request: Send + 'static, { let (tx, rx) = mpsc::unbounded_channel(); - let (handle, worker) = Worker::new(service, rx); - let semaphore = Semaphore::new(bound); + let (semaphore, wake_waiters) = Semaphore::new_with_close(bound); + let (handle, worker) = Worker::new(service, rx, wake_waiters); ( Buffer { tx, diff --git a/tower/src/buffer/worker.rs b/tower/src/buffer/worker.rs index ca5640b..1f70c8e 100644 --- a/tower/src/buffer/worker.rs +++ b/tower/src/buffer/worker.rs @@ -20,7 +20,7 @@ use tower_service::Service; /// as part of the public API. This is the "sealed" pattern to include "private" /// types in public traits that are not meant for consumers of the library to /// implement (only call). -#[pin_project] +#[pin_project(PinnedDrop)] #[derive(Debug)] pub struct Worker where @@ -33,6 +33,7 @@ where finish: bool, failed: Option, handle: Handle, + close: Option, } /// Get the error out @@ -49,6 +50,7 @@ where pub(crate) fn new( service: T, rx: mpsc::UnboundedReceiver>, + close: crate::semaphore::Close, ) -> (Handle, Worker) { let handle = Handle { inner: Arc::new(Mutex::new(None)), @@ -61,6 +63,7 @@ where rx, service, handle: handle.clone(), + close: Some(close), }; (handle, worker) @@ -195,6 +198,11 @@ where .as_ref() .expect("Worker::failed did not set self.failed?") .clone())); + // Wake any tasks waiting on channel capacity. + if let Some(close) = self.close.take() { + tracing::debug!("waking pending tasks"); + close.close(); + } } } } @@ -208,6 +216,19 @@ where } } +#[pin_project::pinned_drop] +impl PinnedDrop for Worker +where + T: Service, + T::Error: Into, +{ + fn drop(mut self: Pin<&mut Self>) { + if let Some(close) = self.as_mut().close.take() { + close.close(); + } + } +} + impl Handle { pub(crate) fn get_error_on_closed(&self) -> crate::BoxError { self.inner diff --git a/tower/src/semaphore.rs b/tower/src/semaphore.rs index e15b923..ea1c005 100644 --- a/tower/src/semaphore.rs +++ b/tower/src/semaphore.rs @@ -5,7 +5,7 @@ use std::{ future::Future, mem, pin::Pin, - sync::Arc, + sync::{Arc, Weak}, task::{Context, Poll}, }; use tokio::sync; @@ -16,6 +16,12 @@ pub(crate) struct Semaphore { state: State, } +#[derive(Debug)] +pub(crate) struct Close { + semaphore: Weak, + permits: usize, +} + enum State { Waiting(Pin + Send + 'static>>), Ready(Permit), @@ -23,6 +29,19 @@ enum State { } impl Semaphore { + pub(crate) fn new_with_close(permits: usize) -> (Self, Close) { + let semaphore = Arc::new(sync::Semaphore::new(permits)); + let close = Close { + semaphore: Arc::downgrade(&semaphore), + permits, + }; + let semaphore = Self { + semaphore, + state: State::Empty, + }; + (semaphore, close) + } + pub(crate) fn new(permits: usize) -> Self { Self { semaphore: Arc::new(sync::Semaphore::new(permits)), @@ -72,3 +91,23 @@ impl fmt::Debug for State { } } } + +impl Close { + /// Close the semaphore, waking any remaining tasks currently awaiting a permit. + pub(crate) fn close(self) { + // The maximum number of permits that a `tokio::sync::Semaphore` + // can hold is usize::MAX >> 3. If we attempt to add more than that + // number of permits, the semaphore will panic. + // XXX(eliza): another shift is kinda janky but if we add (usize::MAX + // > 3 - initial permits) the semaphore impl panics (I think due to a + // bug in tokio?). + // TODO(eliza): Tokio should _really_ just expose `Semaphore::close` + // publicly so we don't have to do this nonsense... + const MAX: usize = std::usize::MAX >> 4; + if let Some(semaphore) = self.semaphore.upgrade() { + // If we added `MAX - available_permits`, any tasks that are + // currently holding permits could drop them, overflowing the max. + semaphore.add_permits(MAX - self.permits); + } + } +} diff --git a/tower/tests/buffer/main.rs b/tower/tests/buffer/main.rs index 66a768d..41b0336 100644 --- a/tower/tests/buffer/main.rs +++ b/tower/tests/buffer/main.rs @@ -4,6 +4,7 @@ mod support; use std::thread; use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok, task}; use tower::buffer::{error, Buffer}; +use tower::{util::ServiceExt, Service}; use tower_test::{assert_request_eq, mock}; fn let_worker_work() { @@ -227,6 +228,124 @@ async fn waits_for_channel_capacity() { assert_ready_ok!(response4.poll()); } +#[tokio::test(flavor = "current_thread")] +async fn wakes_pending_waiters_on_close() { + let _t = support::trace_init(); + + let (service, mut handle) = mock::pair::<_, ()>(); + + let (mut service, worker) = Buffer::pair(service, 1); + let mut worker = task::spawn(worker); + + // keep the request in the worker + handle.allow(0); + let service1 = service.ready_and().await.unwrap(); + assert_pending!(worker.poll()); + let mut response = task::spawn(service1.call("hello")); + + let mut service1 = service.clone(); + let mut ready_and1 = task::spawn(service1.ready_and()); + assert_pending!(worker.poll()); + assert_pending!(ready_and1.poll(), "no capacity"); + + let mut service1 = service.clone(); + let mut ready_and2 = task::spawn(service1.ready_and()); + assert_pending!(worker.poll()); + assert_pending!(ready_and2.poll(), "no capacity"); + + // kill the worker task + drop(worker); + + let err = assert_ready_err!(response.poll()); + assert!( + err.is::(), + "response should fail with a Closed, got: {:?}", + err + ); + + assert!( + ready_and1.is_woken(), + "dropping worker should wake ready_and task 1" + ); + let err = assert_ready_err!(ready_and1.poll()); + assert!( + err.is::(), + "ready_and 1 should fail with a Closed, got: {:?}", + err + ); + + assert!( + ready_and2.is_woken(), + "dropping worker should wake ready_and task 2" + ); + let err = assert_ready_err!(ready_and1.poll()); + assert!( + err.is::(), + "ready_and 2 should fail with a Closed, got: {:?}", + err + ); +} + +#[tokio::test(flavor = "current_thread")] +async fn wakes_pending_waiters_on_failure() { + let _t = support::trace_init(); + + let (service, mut handle) = mock::pair::<_, ()>(); + + let (mut service, worker) = Buffer::pair(service, 1); + let mut worker = task::spawn(worker); + + // keep the request in the worker + handle.allow(0); + let service1 = service.ready_and().await.unwrap(); + assert_pending!(worker.poll()); + let mut response = task::spawn(service1.call("hello")); + + let mut service1 = service.clone(); + let mut ready_and1 = task::spawn(service1.ready_and()); + assert_pending!(worker.poll()); + assert_pending!(ready_and1.poll(), "no capacity"); + + let mut service1 = service.clone(); + let mut ready_and2 = task::spawn(service1.ready_and()); + assert_pending!(worker.poll()); + assert_pending!(ready_and2.poll(), "no capacity"); + + // fail the inner service + handle.send_error("foobar"); + // worker task terminates + assert_ready!(worker.poll()); + + let err = assert_ready_err!(response.poll()); + assert!( + err.is::(), + "response should fail with a ServiceError, got: {:?}", + err + ); + + assert!( + ready_and1.is_woken(), + "dropping worker should wake ready_and task 1" + ); + let err = assert_ready_err!(ready_and1.poll()); + assert!( + err.is::(), + "ready_and 1 should fail with a ServiceError, got: {:?}", + err + ); + + assert!( + ready_and2.is_woken(), + "dropping worker should wake ready_and task 2" + ); + let err = assert_ready_err!(ready_and1.poll()); + assert!( + err.is::(), + "ready_and 2 should fail with a ServiceError, got: {:?}", + err + ); +} + type Mock = mock::Mock<&'static str, &'static str>; type Handle = mock::Handle<&'static str, &'static str>;