diff --git a/tower-in-flight-limit/Cargo.toml b/tower-in-flight-limit/Cargo.toml index c2b072a..2551b55 100644 --- a/tower-in-flight-limit/Cargo.toml +++ b/tower-in-flight-limit/Cargo.toml @@ -5,10 +5,11 @@ authors = ["Carl Lerche "] publish = false [dependencies] -futures = "0.1" +futures = "0.1.25" +tokio-sync = "0.1.3" tower-service = "0.2.0" -tower-layer = { version = "0.1", path = "../tower-layer" } +tower-layer = { version = "0.1.0", path = "../tower-layer" } [dev-dependencies] -tokio-test = { git = "https://github.com/carllerche/tokio-test" } +tokio-mock-task = "0.1.1" tower-mock = { version = "0.1", path = "../tower-mock" } diff --git a/tower-in-flight-limit/src/future.rs b/tower-in-flight-limit/src/future.rs new file mode 100644 index 0000000..1f70417 --- /dev/null +++ b/tower-in-flight-limit/src/future.rs @@ -0,0 +1,35 @@ +use futures::{Future, Poll}; +use std::sync::Arc; +use tokio_sync::semaphore::Semaphore; +use Error; + +#[derive(Debug)] +pub struct ResponseFuture { + inner: T, + semaphore: Arc, +} + +impl ResponseFuture { + pub(crate) fn new(inner: T, semaphore: Arc) -> ResponseFuture { + ResponseFuture { inner, semaphore } + } +} + +impl Future for ResponseFuture +where + T: Future, + T::Error: Into, +{ + type Item = T::Item; + type Error = Error; + + fn poll(&mut self) -> Poll { + self.inner.poll().map_err(Into::into) + } +} + +impl Drop for ResponseFuture { + fn drop(&mut self) { + self.semaphore.add_permits(1); + } +} diff --git a/tower-in-flight-limit/src/layer.rs b/tower-in-flight-limit/src/layer.rs new file mode 100644 index 0000000..4e5e945 --- /dev/null +++ b/tower-in-flight-limit/src/layer.rs @@ -0,0 +1,29 @@ +use tower_layer::Layer; +use tower_service::Service; +use {Error, InFlightLimit}; + +#[derive(Debug, Clone)] +pub struct InFlightLimitLayer { + max: usize, +} + +impl InFlightLimitLayer { + pub fn new(max: usize) -> Self { + InFlightLimitLayer { max } + } +} + +impl Layer for InFlightLimitLayer +where + S: Service, + S::Error: Into, +{ + type Response = S::Response; + type Error = Error; + type LayerError = (); + type Service = InFlightLimit; + + fn layer(&self, service: S) -> Result { + Ok(InFlightLimit::new(service, self.max)) + } +} diff --git a/tower-in-flight-limit/src/lib.rs b/tower-in-flight-limit/src/lib.rs index 0e5c250..8e3f701 100644 --- a/tower-in-flight-limit/src/lib.rs +++ b/tower-in-flight-limit/src/lib.rs @@ -1,55 +1,37 @@ //! Tower middleware that limits the maximum number of in-flight requests for a //! service. +#[macro_use] extern crate futures; +extern crate tokio_sync; extern crate tower_layer; extern crate tower_service; -use tower_layer::Layer; +pub mod future; +mod layer; + +use future::ResponseFuture; +pub use layer::InFlightLimitLayer; + use tower_service::Service; -use futures::task::AtomicTask; -use futures::{Async, Future, Poll}; -use std::sync::atomic::AtomicUsize; -use std::sync::atomic::Ordering::SeqCst; +use futures::Poll; use std::sync::Arc; -use std::{error, fmt}; +use tokio_sync::semaphore::{self, Semaphore}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct InFlightLimit { inner: T, - state: State, -} - -#[derive(Debug, Clone)] -pub struct InFlightLimitLayer { - max: usize, -} - -/// Error returned when the service has reached its limit. -#[derive(Debug)] -pub enum Error { - Upstream(T), + limit: Limit, } #[derive(Debug)] -pub struct ResponseFuture { - inner: T, - shared: Arc, +struct Limit { + semaphore: Arc, + permit: semaphore::Permit, } -#[derive(Debug)] -struct State { - shared: Arc, - reserved: bool, -} - -#[derive(Debug)] -struct Shared { - max: usize, - curr: AtomicUsize, - task: AtomicTask, -} +type Error = Box<::std::error::Error + Send + Sync>; // ===== impl InFlightLimit ===== @@ -61,13 +43,9 @@ impl InFlightLimit { { InFlightLimit { inner, - state: State { - shared: Arc::new(Shared { - max, - curr: AtomicUsize::new(0), - task: AtomicTask::new(), - }), - reserved: false, + limit: Limit { + semaphore: Arc::new(Semaphore::new(max)), + permit: semaphore::Permit::new(), }, } } @@ -91,179 +69,61 @@ impl InFlightLimit { impl Service for InFlightLimit where S: Service, + S::Error: Into, { type Response = S::Response; - type Error = Error; - type Future = ResponseFuture; + type Error = Error; + type Future = future::ResponseFuture; fn poll_ready(&mut self) -> Poll<(), Self::Error> { - if self.state.reserved { - return self.inner.poll_ready().map_err(Error::Upstream); - } + try_ready!(self + .limit + .permit + .poll_acquire(&self.limit.semaphore) + .map_err(Error::from)); - self.state.shared.task.register(); - - if !self.state.shared.reserve() { - return Ok(Async::NotReady); - } - - self.state.reserved = true; - - self.inner.poll_ready().map_err(Error::Upstream) + self.inner.poll_ready().map_err(Into::into) } fn call(&mut self, request: Request) -> Self::Future { - // In this implementation, `poll_ready` is not expected to be called - // first (though, it might have been). - if self.state.reserved { - self.state.reserved = false; - } else { - // Try to reserve - if !self.state.shared.reserve() { - panic!("service not ready; call poll_ready first"); - } + // Make sure a permit has been acquired + if self + .limit + .permit + .try_acquire(&self.limit.semaphore) + .is_err() + { + panic!("max requests in-flight; poll_ready must be called first"); } - ResponseFuture { - inner: self.inner.call(request), - shared: self.state.shared.clone(), - } + // Call the inner service + let future = self.inner.call(request); + + // Forget the permit, the permit will be returned when + // `future::ResponseFuture` is dropped. + self.limit.permit.forget(); + + ResponseFuture::new(future, self.limit.semaphore.clone()) } } -// ===== impl InFlightLimitLayer ===== - -impl InFlightLimitLayer { - pub fn new(max: usize) -> Self { - InFlightLimitLayer { max } - } -} - -impl Layer for InFlightLimitLayer +impl Clone for InFlightLimit where - S: Service, + S: Clone, { - type Response = S::Response; - type Error = Error; - type LayerError = (); - type Service = InFlightLimit; - - fn layer(&self, service: S) -> Result { - Ok(InFlightLimit::new(service, self.max)) - } -} - -// ===== impl ResponseFuture ===== - -impl Future for ResponseFuture -where - T: Future, -{ - type Item = T::Item; - type Error = Error; - - fn poll(&mut self) -> Poll { - use futures::Async::*; - - match self.inner.poll() { - Ok(Ready(v)) => Ok(Ready(v)), - Ok(NotReady) => { - return Ok(NotReady); - } - Err(e) => Err(Error::Upstream(e)), + fn clone(&self) -> InFlightLimit { + InFlightLimit { + inner: self.inner.clone(), + limit: Limit { + semaphore: self.limit.semaphore.clone(), + permit: semaphore::Permit::new(), + }, } } } -impl Drop for ResponseFuture { +impl Drop for Limit { fn drop(&mut self) { - self.shared.release(); - } -} - -// ===== impl State ===== - -impl Clone for State { - fn clone(&self) -> Self { - State { - shared: self.shared.clone(), - reserved: false, - } - } -} - -impl Drop for State { - fn drop(&mut self) { - if self.reserved { - self.shared.release(); - } - } -} - -// ===== impl Shared ===== - -impl Shared { - /// Attempts to reserve capacity for a request. Returns `true` if the - /// reservation is successful. - fn reserve(&self) -> bool { - let mut curr = self.curr.load(SeqCst); - - loop { - if curr == self.max { - return false; - } - - let actual = self.curr.compare_and_swap(curr, curr + 1, SeqCst); - - if actual == curr { - return true; - } - - curr = actual; - } - } - - /// Release a reserved in-flight request. This is called when either the - /// request has completed OR the service that made the reservation has - /// dropped. - pub fn release(&self) { - let prev = self.curr.fetch_sub(1, SeqCst); - - // Cannot go above the max number of in-flight - debug_assert!(prev <= self.max); - - if prev == self.max { - self.task.notify(); - } - } -} - -// ===== impl Error ===== - -impl fmt::Display for Error -where - T: fmt::Display, -{ - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Error::Upstream(ref why) => fmt::Display::fmt(why, f), - } - } -} - -impl error::Error for Error -where - T: error::Error, -{ - fn cause(&self) -> Option<&error::Error> { - match *self { - Error::Upstream(ref why) => Some(why), - } - } - - fn description(&self) -> &str { - match *self { - Error::Upstream(_) => "upstream service error", - } + self.permit.release(&self.semaphore); } } diff --git a/tower-in-flight-limit/tests/in_flight_limit.rs b/tower-in-flight-limit/tests/in_flight_limit.rs index 312b4d5..f3950ba 100644 --- a/tower-in-flight-limit/tests/in_flight_limit.rs +++ b/tower-in-flight-limit/tests/in_flight_limit.rs @@ -1,5 +1,5 @@ extern crate futures; -extern crate tokio_test; +extern crate tokio_mock_task; extern crate tower_in_flight_limit; extern crate tower_mock; extern crate tower_service; @@ -8,7 +8,27 @@ use tower_in_flight_limit::InFlightLimit; use tower_service::Service; use futures::future::{poll_fn, Future}; -use tokio_test::MockTask; +use tokio_mock_task::MockTask; + +macro_rules! assert_ready { + ($e:expr) => {{ + match $e { + Ok(futures::Async::Ready(v)) => v, + Ok(_) => panic!("not ready"), + Err(e) => panic!("error = {:?}", e), + } + }}; +} + +macro_rules! assert_not_ready { + ($e:expr) => {{ + match $e { + Ok(futures::Async::NotReady) => {} + Ok(futures::Async::Ready(v)) => panic!("ready; value = {:?}", v), + Err(e) => panic!("error = {:?}", e), + } + }}; +} #[test] fn basic_service_limit_functionality_with_poll_ready() { @@ -244,6 +264,33 @@ fn response_future_drop_releases_capacity() { }); } +#[test] +fn multi_waiters() { + let mut task1 = MockTask::new(); + let mut task2 = MockTask::new(); + let mut task3 = MockTask::new(); + + let (mut s1, _handle) = new_service(1); + let mut s2 = s1.clone(); + let mut s3 = s1.clone(); + + // Reserve capacity in s1 + task1.enter(|| assert_ready!(s1.poll_ready())); + + // s2 and s3 are not ready + task2.enter(|| assert_not_ready!(s2.poll_ready())); + task3.enter(|| assert_not_ready!(s3.poll_ready())); + + drop(s1); + + assert!(task2.is_notified()); + assert!(!task3.is_notified()); + + drop(s2); + + assert!(task3.is_notified()); +} + type Mock = tower_mock::Mock<&'static str, &'static str>; type Handle = tower_mock::Handle<&'static str, &'static str>;