Fix in flight limit (#184)
* Move `InFlightLimitLayer` into `layer` mod * Remove error type
This commit is contained in:
parent
bdcce9677b
commit
1e38ee6e1f
|
@ -5,10 +5,11 @@ authors = ["Carl Lerche <me@carllerche.com>"]
|
|||
publish = false
|
||||
|
||||
[dependencies]
|
||||
futures = "0.1"
|
||||
futures = "0.1.25"
|
||||
tokio-sync = "0.1.3"
|
||||
tower-service = "0.2.0"
|
||||
tower-layer = { version = "0.1", path = "../tower-layer" }
|
||||
tower-layer = { version = "0.1.0", path = "../tower-layer" }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = { git = "https://github.com/carllerche/tokio-test" }
|
||||
tokio-mock-task = "0.1.1"
|
||||
tower-mock = { version = "0.1", path = "../tower-mock" }
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
use futures::{Future, Poll};
|
||||
use std::sync::Arc;
|
||||
use tokio_sync::semaphore::Semaphore;
|
||||
use Error;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ResponseFuture<T> {
|
||||
inner: T,
|
||||
semaphore: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
impl<T> ResponseFuture<T> {
|
||||
pub(crate) fn new(inner: T, semaphore: Arc<Semaphore>) -> ResponseFuture<T> {
|
||||
ResponseFuture { inner, semaphore }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Future for ResponseFuture<T>
|
||||
where
|
||||
T: Future,
|
||||
T::Error: Into<Error>,
|
||||
{
|
||||
type Item = T::Item;
|
||||
type Error = Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
self.inner.poll().map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for ResponseFuture<T> {
|
||||
fn drop(&mut self) {
|
||||
self.semaphore.add_permits(1);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
use tower_layer::Layer;
|
||||
use tower_service::Service;
|
||||
use {Error, InFlightLimit};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InFlightLimitLayer {
|
||||
max: usize,
|
||||
}
|
||||
|
||||
impl InFlightLimitLayer {
|
||||
pub fn new(max: usize) -> Self {
|
||||
InFlightLimitLayer { max }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, Request> Layer<S, Request> for InFlightLimitLayer
|
||||
where
|
||||
S: Service<Request>,
|
||||
S::Error: Into<Error>,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = Error;
|
||||
type LayerError = ();
|
||||
type Service = InFlightLimit<S>;
|
||||
|
||||
fn layer(&self, service: S) -> Result<Self::Service, Self::LayerError> {
|
||||
Ok(InFlightLimit::new(service, self.max))
|
||||
}
|
||||
}
|
|
@ -1,55 +1,37 @@
|
|||
//! Tower middleware that limits the maximum number of in-flight requests for a
|
||||
//! service.
|
||||
|
||||
#[macro_use]
|
||||
extern crate futures;
|
||||
extern crate tokio_sync;
|
||||
extern crate tower_layer;
|
||||
extern crate tower_service;
|
||||
|
||||
use tower_layer::Layer;
|
||||
pub mod future;
|
||||
mod layer;
|
||||
|
||||
use future::ResponseFuture;
|
||||
pub use layer::InFlightLimitLayer;
|
||||
|
||||
use tower_service::Service;
|
||||
|
||||
use futures::task::AtomicTask;
|
||||
use futures::{Async, Future, Poll};
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering::SeqCst;
|
||||
use futures::Poll;
|
||||
use std::sync::Arc;
|
||||
use std::{error, fmt};
|
||||
use tokio_sync::semaphore::{self, Semaphore};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug)]
|
||||
pub struct InFlightLimit<T> {
|
||||
inner: T,
|
||||
state: State,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InFlightLimitLayer {
|
||||
max: usize,
|
||||
}
|
||||
|
||||
/// Error returned when the service has reached its limit.
|
||||
#[derive(Debug)]
|
||||
pub enum Error<T> {
|
||||
Upstream(T),
|
||||
limit: Limit,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ResponseFuture<T> {
|
||||
inner: T,
|
||||
shared: Arc<Shared>,
|
||||
struct Limit {
|
||||
semaphore: Arc<Semaphore>,
|
||||
permit: semaphore::Permit,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct State {
|
||||
shared: Arc<Shared>,
|
||||
reserved: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Shared {
|
||||
max: usize,
|
||||
curr: AtomicUsize,
|
||||
task: AtomicTask,
|
||||
}
|
||||
type Error = Box<::std::error::Error + Send + Sync>;
|
||||
|
||||
// ===== impl InFlightLimit =====
|
||||
|
||||
|
@ -61,13 +43,9 @@ impl<T> InFlightLimit<T> {
|
|||
{
|
||||
InFlightLimit {
|
||||
inner,
|
||||
state: State {
|
||||
shared: Arc::new(Shared {
|
||||
max,
|
||||
curr: AtomicUsize::new(0),
|
||||
task: AtomicTask::new(),
|
||||
}),
|
||||
reserved: false,
|
||||
limit: Limit {
|
||||
semaphore: Arc::new(Semaphore::new(max)),
|
||||
permit: semaphore::Permit::new(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -91,179 +69,61 @@ impl<T> InFlightLimit<T> {
|
|||
impl<S, Request> Service<Request> for InFlightLimit<S>
|
||||
where
|
||||
S: Service<Request>,
|
||||
S::Error: Into<Error>,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = Error<S::Error>;
|
||||
type Future = ResponseFuture<S::Future>;
|
||||
type Error = Error;
|
||||
type Future = future::ResponseFuture<S::Future>;
|
||||
|
||||
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
|
||||
if self.state.reserved {
|
||||
return self.inner.poll_ready().map_err(Error::Upstream);
|
||||
}
|
||||
try_ready!(self
|
||||
.limit
|
||||
.permit
|
||||
.poll_acquire(&self.limit.semaphore)
|
||||
.map_err(Error::from));
|
||||
|
||||
self.state.shared.task.register();
|
||||
|
||||
if !self.state.shared.reserve() {
|
||||
return Ok(Async::NotReady);
|
||||
}
|
||||
|
||||
self.state.reserved = true;
|
||||
|
||||
self.inner.poll_ready().map_err(Error::Upstream)
|
||||
self.inner.poll_ready().map_err(Into::into)
|
||||
}
|
||||
|
||||
fn call(&mut self, request: Request) -> Self::Future {
|
||||
// In this implementation, `poll_ready` is not expected to be called
|
||||
// first (though, it might have been).
|
||||
if self.state.reserved {
|
||||
self.state.reserved = false;
|
||||
} else {
|
||||
// Try to reserve
|
||||
if !self.state.shared.reserve() {
|
||||
panic!("service not ready; call poll_ready first");
|
||||
}
|
||||
// Make sure a permit has been acquired
|
||||
if self
|
||||
.limit
|
||||
.permit
|
||||
.try_acquire(&self.limit.semaphore)
|
||||
.is_err()
|
||||
{
|
||||
panic!("max requests in-flight; poll_ready must be called first");
|
||||
}
|
||||
|
||||
ResponseFuture {
|
||||
inner: self.inner.call(request),
|
||||
shared: self.state.shared.clone(),
|
||||
}
|
||||
// Call the inner service
|
||||
let future = self.inner.call(request);
|
||||
|
||||
// Forget the permit, the permit will be returned when
|
||||
// `future::ResponseFuture` is dropped.
|
||||
self.limit.permit.forget();
|
||||
|
||||
ResponseFuture::new(future, self.limit.semaphore.clone())
|
||||
}
|
||||
}
|
||||
|
||||
// ===== impl InFlightLimitLayer =====
|
||||
|
||||
impl InFlightLimitLayer {
|
||||
pub fn new(max: usize) -> Self {
|
||||
InFlightLimitLayer { max }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, Request> Layer<S, Request> for InFlightLimitLayer
|
||||
impl<S> Clone for InFlightLimit<S>
|
||||
where
|
||||
S: Service<Request>,
|
||||
S: Clone,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = Error<S::Error>;
|
||||
type LayerError = ();
|
||||
type Service = InFlightLimit<S>;
|
||||
|
||||
fn layer(&self, service: S) -> Result<Self::Service, Self::LayerError> {
|
||||
Ok(InFlightLimit::new(service, self.max))
|
||||
}
|
||||
}
|
||||
|
||||
// ===== impl ResponseFuture =====
|
||||
|
||||
impl<T> Future for ResponseFuture<T>
|
||||
where
|
||||
T: Future,
|
||||
{
|
||||
type Item = T::Item;
|
||||
type Error = Error<T::Error>;
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
use futures::Async::*;
|
||||
|
||||
match self.inner.poll() {
|
||||
Ok(Ready(v)) => Ok(Ready(v)),
|
||||
Ok(NotReady) => {
|
||||
return Ok(NotReady);
|
||||
}
|
||||
Err(e) => Err(Error::Upstream(e)),
|
||||
fn clone(&self) -> InFlightLimit<S> {
|
||||
InFlightLimit {
|
||||
inner: self.inner.clone(),
|
||||
limit: Limit {
|
||||
semaphore: self.limit.semaphore.clone(),
|
||||
permit: semaphore::Permit::new(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for ResponseFuture<T> {
|
||||
impl Drop for Limit {
|
||||
fn drop(&mut self) {
|
||||
self.shared.release();
|
||||
}
|
||||
}
|
||||
|
||||
// ===== impl State =====
|
||||
|
||||
impl Clone for State {
|
||||
fn clone(&self) -> Self {
|
||||
State {
|
||||
shared: self.shared.clone(),
|
||||
reserved: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for State {
|
||||
fn drop(&mut self) {
|
||||
if self.reserved {
|
||||
self.shared.release();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===== impl Shared =====
|
||||
|
||||
impl Shared {
|
||||
/// Attempts to reserve capacity for a request. Returns `true` if the
|
||||
/// reservation is successful.
|
||||
fn reserve(&self) -> bool {
|
||||
let mut curr = self.curr.load(SeqCst);
|
||||
|
||||
loop {
|
||||
if curr == self.max {
|
||||
return false;
|
||||
}
|
||||
|
||||
let actual = self.curr.compare_and_swap(curr, curr + 1, SeqCst);
|
||||
|
||||
if actual == curr {
|
||||
return true;
|
||||
}
|
||||
|
||||
curr = actual;
|
||||
}
|
||||
}
|
||||
|
||||
/// Release a reserved in-flight request. This is called when either the
|
||||
/// request has completed OR the service that made the reservation has
|
||||
/// dropped.
|
||||
pub fn release(&self) {
|
||||
let prev = self.curr.fetch_sub(1, SeqCst);
|
||||
|
||||
// Cannot go above the max number of in-flight
|
||||
debug_assert!(prev <= self.max);
|
||||
|
||||
if prev == self.max {
|
||||
self.task.notify();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===== impl Error =====
|
||||
|
||||
impl<T> fmt::Display for Error<T>
|
||||
where
|
||||
T: fmt::Display,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
Error::Upstream(ref why) => fmt::Display::fmt(why, f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> error::Error for Error<T>
|
||||
where
|
||||
T: error::Error,
|
||||
{
|
||||
fn cause(&self) -> Option<&error::Error> {
|
||||
match *self {
|
||||
Error::Upstream(ref why) => Some(why),
|
||||
}
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
match *self {
|
||||
Error::Upstream(_) => "upstream service error",
|
||||
}
|
||||
self.permit.release(&self.semaphore);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
extern crate futures;
|
||||
extern crate tokio_test;
|
||||
extern crate tokio_mock_task;
|
||||
extern crate tower_in_flight_limit;
|
||||
extern crate tower_mock;
|
||||
extern crate tower_service;
|
||||
|
@ -8,7 +8,27 @@ use tower_in_flight_limit::InFlightLimit;
|
|||
use tower_service::Service;
|
||||
|
||||
use futures::future::{poll_fn, Future};
|
||||
use tokio_test::MockTask;
|
||||
use tokio_mock_task::MockTask;
|
||||
|
||||
macro_rules! assert_ready {
|
||||
($e:expr) => {{
|
||||
match $e {
|
||||
Ok(futures::Async::Ready(v)) => v,
|
||||
Ok(_) => panic!("not ready"),
|
||||
Err(e) => panic!("error = {:?}", e),
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! assert_not_ready {
|
||||
($e:expr) => {{
|
||||
match $e {
|
||||
Ok(futures::Async::NotReady) => {}
|
||||
Ok(futures::Async::Ready(v)) => panic!("ready; value = {:?}", v),
|
||||
Err(e) => panic!("error = {:?}", e),
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn basic_service_limit_functionality_with_poll_ready() {
|
||||
|
@ -244,6 +264,33 @@ fn response_future_drop_releases_capacity() {
|
|||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multi_waiters() {
|
||||
let mut task1 = MockTask::new();
|
||||
let mut task2 = MockTask::new();
|
||||
let mut task3 = MockTask::new();
|
||||
|
||||
let (mut s1, _handle) = new_service(1);
|
||||
let mut s2 = s1.clone();
|
||||
let mut s3 = s1.clone();
|
||||
|
||||
// Reserve capacity in s1
|
||||
task1.enter(|| assert_ready!(s1.poll_ready()));
|
||||
|
||||
// s2 and s3 are not ready
|
||||
task2.enter(|| assert_not_ready!(s2.poll_ready()));
|
||||
task3.enter(|| assert_not_ready!(s3.poll_ready()));
|
||||
|
||||
drop(s1);
|
||||
|
||||
assert!(task2.is_notified());
|
||||
assert!(!task3.is_notified());
|
||||
|
||||
drop(s2);
|
||||
|
||||
assert!(task3.is_notified());
|
||||
}
|
||||
|
||||
type Mock = tower_mock::Mock<&'static str, &'static str>;
|
||||
type Handle = tower_mock::Handle<&'static str, &'static str>;
|
||||
|
||||
|
|
Loading…
Reference in New Issue