diff --git a/deny.toml b/deny.toml index 4f0a839..6ec664e 100644 --- a/deny.toml +++ b/deny.toml @@ -15,8 +15,12 @@ confidence-threshold = 0.8 [bans] multiple-versions = "deny" highlight = "all" -skip-tree = [ - { name = "tower", version = "=0.3"} +skip-tree = [{ name = "tower", version = ">=0.3, <=0.4" }] +skip = [ + # `quickcheck` and `tracing-subscriber` depend on incompatible versions of + # `wasi` via their dependencies on `rand` and `chrono`, respectively; we + # can't really fix this. + { name = "wasi" }, ] [sources] diff --git a/examples/Cargo.toml b/examples/Cargo.toml new file mode 100644 index 0000000..0cca426 --- /dev/null +++ b/examples/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "examples" +version = "0.0.0" +publish = false +edition = "2018" + +# If you copy one of the examples into a new project, you should be using +# [dependencies] instead. +[dev-dependencies] +tower = { version = "0.4", path = "../tower", features = ["full"] } +tower-service = "0.3" +tokio = { version = "0.3", features = ["full"] } +rand = "0.7" +pin-project = "1.0" +futures = "0.3" +tracing = "0.1" +tracing-subscriber = "0.2" +hdrhistogram = "7" + +[[example]] +name = "balance" +path = "balance.rs" \ No newline at end of file diff --git a/tower-test/Cargo.toml b/tower-test/Cargo.toml index 1e1db2f..de8c97e 100644 --- a/tower-test/Cargo.toml +++ b/tower-test/Cargo.toml @@ -8,7 +8,7 @@ name = "tower-test" # - README.md # - Update CHANGELOG.md. # - Create "v0.1.x" git tag. -version = "0.3.0" +version = "0.4.0" authors = ["Tower Maintainers "] license = "MIT" readme = "README.md" @@ -23,11 +23,11 @@ edition = "2018" [dependencies] futures-util = { version = "0.3", default-features = false } -tokio = { version = "0.2", features = ["sync"]} +tokio = { version = "0.3", features = ["sync"] } +tokio-test = { version = "0.3" } tower-layer = { version = "0.3", path = "../tower-layer" } -tokio-test = "0.2" tower-service = { version = "0.3" } -pin-project = "0.4.17" +pin-project = "1" [dev-dependencies] -tokio = { version = "0.2", features = ["macros"] } +tokio = { version = "0.3", features = ["macros"] } diff --git a/tower/Cargo.toml b/tower/Cargo.toml index 50a0698..ea7d443 100644 --- a/tower/Cargo.toml +++ b/tower/Cargo.toml @@ -8,7 +8,7 @@ name = "tower" # - README.md # - Update CHANGELOG.md. # - Create "vX.X.X" git tag. -version = "0.3.1" +version = "0.4.0" authors = ["Tower Maintainers "] license = "MIT" readme = "README.md" @@ -26,26 +26,26 @@ edition = "2018" [features] default = ["log"] log = ["tracing/log"] -balance = ["discover", "load", "ready-cache", "make", "rand", "slab"] -buffer = ["tokio/sync", "tokio/rt-core"] +balance = ["discover", "load", "ready-cache", "make", "rand", "slab", "tokio/stream"] +buffer = ["tokio/sync", "tokio/rt", "tokio/stream"] discover = [] filter = [] hedge = ["util", "filter", "futures-util", "hdrhistogram", "tokio/time"] -limit = ["tokio/time"] +limit = ["tokio/time", "tokio/sync"] load = ["tokio/time"] load-shed = [] make = ["tokio/io-std"] ready-cache = ["futures-util", "indexmap", "tokio/sync"] reconnect = ["make", "tokio/io-std"] retry = ["tokio/time"] -spawn-ready = ["futures-util", "tokio/sync", "tokio/rt-core"] +spawn-ready = ["futures-util", "tokio/sync", "tokio/rt"] steer = ["futures-util"] timeout = ["tokio/time"] util = ["futures-util"] [dependencies] futures-core = "0.3" -pin-project = "0.4.17" +pin-project = "1" tower-layer = { version = "0.3", path = "../tower-layer" } tower-service = { version = "0.3" } tracing = "0.1.2" @@ -55,16 +55,16 @@ hdrhistogram = { version = "6.0", optional = true } indexmap = { version = "1.0.2", optional = true } rand = { version = "0.7", features = ["small_rng"], optional = true } slab = { version = "0.4", optional = true } -tokio = { version = "0.2", optional = true, features = ["sync"] } +tokio = { version = "0.3", optional = true, features = ["sync"] } [dev-dependencies] futures-util = { version = "0.3", default-features = false, features = ["alloc", "async-await"] } hdrhistogram = "6.0" quickcheck = { version = "0.9", default-features = false } -tokio = { version = "0.2", features = ["macros", "stream", "sync", "test-util" ] } -tokio-test = "0.2" -tower-test = { version = "0.3", path = "../tower-test" } -tracing-subscriber = "0.1.1" +tokio = { version = "0.3", features = ["macros", "stream", "sync", "test-util", "rt-multi-thread"] } +tokio-test = "0.3" +tower-test = { version = "0.4", path = "../tower-test" } +tracing-subscriber = "0.2.14" # env_logger = { version = "0.5.3", default-features = false } # log = "0.4.1" diff --git a/tower/examples/tower-balance.rs b/tower/examples/tower-balance.rs index 649cbeb..131d51f 100644 --- a/tower/examples/tower-balance.rs +++ b/tower/examples/tower-balance.rs @@ -118,7 +118,7 @@ fn gen_disco() -> impl Discover< let latency = Duration::from_millis(rand::thread_rng().gen_range(0, maxms)); async move { - time::delay_until(start + latency).await; + time::sleep_until(start + latency).await; let latency = start.elapsed(); Ok(Rsp { latency, instance }) } diff --git a/tower/src/balance/pool/mod.rs b/tower/src/balance/pool/mod.rs index 799aaf8..705ae91 100644 --- a/tower/src/balance/pool/mod.rs +++ b/tower/src/balance/pool/mod.rs @@ -90,7 +90,7 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); - while let Poll::Ready(Some(sid)) = this.died_rx.as_mut().poll_recv(cx) { + while let Poll::Ready(Some(sid)) = this.died_rx.as_mut().poll_next(cx) { this.services.remove(sid); tracing::trace!( pool.services = this.services.len(), diff --git a/tower/src/buffer/message.rs b/tower/src/buffer/message.rs index 6d13aa1..069828e 100644 --- a/tower/src/buffer/message.rs +++ b/tower/src/buffer/message.rs @@ -7,6 +7,7 @@ pub(crate) struct Message { pub(crate) request: Request, pub(crate) tx: Tx, pub(crate) span: tracing::Span, + pub(super) _permit: crate::semaphore::Permit, } /// Response sender diff --git a/tower/src/buffer/service.rs b/tower/src/buffer/service.rs index ce315ff..882a35e 100644 --- a/tower/src/buffer/service.rs +++ b/tower/src/buffer/service.rs @@ -4,6 +4,7 @@ use super::{ worker::{Handle, Worker}, }; +use crate::semaphore::Semaphore; use futures_core::ready; use std::task::{Context, Poll}; use tokio::sync::{mpsc, oneshot}; @@ -17,7 +18,19 @@ pub struct Buffer where T: Service, { - tx: mpsc::Sender>, + // Note: this actually _is_ bounded, but rather than using Tokio's unbounded + // channel, we use tokio's semaphore separately to implement the bound. + tx: mpsc::UnboundedSender>, + // When the buffer's channel is full, we want to exert backpressure in + // `poll_ready`, so that callers such as load balancers could choose to call + // another service rather than waiting for buffer capacity. + // + // Unfortunately, this can't be done easily using Tokio's bounded MPSC + // channel, because it doesn't expose a polling-based interface, only an + // `async fn ready`, which borrows the sender. Therefore, we implement our + // own bounded MPSC on top of the unbounded channel, using a semaphore to + // limit how many items are in the channel. + semaphore: Semaphore, handle: Handle, } @@ -50,10 +63,9 @@ where T::Error: Send + Sync, Request: Send + 'static, { - let (tx, rx) = mpsc::channel(bound); - let (handle, worker) = Worker::new(service, rx); + let (service, worker) = Self::pair(service, bound); tokio::spawn(worker); - Buffer { tx, handle } + service } /// Creates a new `Buffer` wrapping `service`, but returns the background worker. @@ -67,9 +79,17 @@ where T::Error: Send + Sync, Request: Send + 'static, { - let (tx, rx) = mpsc::channel(bound); + let (tx, rx) = mpsc::unbounded_channel(); let (handle, worker) = Worker::new(service, rx); - (Buffer { tx, handle }, worker) + let semaphore = Semaphore::new(bound); + ( + Buffer { + tx, + handle, + semaphore, + }, + worker, + ) } fn get_worker_error(&self) -> crate::BoxError { @@ -87,40 +107,43 @@ where type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - // If the inner service has errored, then we error here. - if let Err(_) = ready!(self.tx.poll_ready(cx)) { - Poll::Ready(Err(self.get_worker_error())) - } else { - Poll::Ready(Ok(())) + // First, check if the worker is still alive. + if self.tx.is_closed() { + // If the inner service has errored, then we error here. + return Poll::Ready(Err(self.get_worker_error())); } + + // Then, poll to acquire a semaphore permit. If we acquire a permit, + // then there's enough buffer capacity to send a new request. Otherwise, + // we need to wait for capacity. + ready!(self.semaphore.poll_acquire(cx)); + + Poll::Ready(Ok(())) } fn call(&mut self, request: Request) -> Self::Future { - // TODO: - // ideally we'd poll_ready again here so we don't allocate the oneshot - // if the try_send is about to fail, but sadly we can't call poll_ready - // outside of task context. - let (tx, rx) = oneshot::channel(); + tracing::trace!("sending request to buffer worker"); + let _permit = self + .semaphore + .take_permit() + .expect("buffer full; poll_ready must be called first"); // get the current Span so that we can explicitly propagate it to the worker // if we didn't do this, events on the worker related to this span wouldn't be counted // towards that span since the worker would have no way of entering it. let span = tracing::Span::current(); - tracing::trace!(parent: &span, "sending request to buffer worker"); - match self.tx.try_send(Message { request, span, tx }) { - Err(mpsc::error::TrySendError::Closed(_)) => { - ResponseFuture::failed(self.get_worker_error()) - } - Err(mpsc::error::TrySendError::Full(_)) => { - // When `mpsc::Sender::poll_ready` returns `Ready`, a slot - // in the channel is reserved for the handle. Other `Sender` - // handles may not send a message using that slot. This - // guarantees capacity for `request`. - // - // Given this, the only way to hit this code path is if - // `poll_ready` has not been called & `Ready` returned. - panic!("buffer full; poll_ready must be called first"); - } + + // If we've made it here, then a semaphore permit has already been + // acquired, so we can freely allocate a oneshot. + let (tx, rx) = oneshot::channel(); + + match self.tx.send(Message { + request, + span, + tx, + _permit, + }) { + Err(_) => ResponseFuture::failed(self.get_worker_error()), Ok(_) => ResponseFuture::new(rx), } } @@ -134,6 +157,7 @@ where Self { tx: self.tx.clone(), handle: self.handle.clone(), + semaphore: self.semaphore.clone(), } } } diff --git a/tower/src/buffer/worker.rs b/tower/src/buffer/worker.rs index 2c2ae10..ca5640b 100644 --- a/tower/src/buffer/worker.rs +++ b/tower/src/buffer/worker.rs @@ -10,7 +10,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tokio::sync::mpsc; +use tokio::{stream::Stream, sync::mpsc}; use tower_service::Service; /// Task that handles processing the buffer. This type should not be used @@ -28,7 +28,7 @@ where T::Error: Into, { current_message: Option>, - rx: mpsc::Receiver>, + rx: mpsc::UnboundedReceiver>, service: T, finish: bool, failed: Option, @@ -48,7 +48,7 @@ where { pub(crate) fn new( service: T, - rx: mpsc::Receiver>, + rx: mpsc::UnboundedReceiver>, ) -> (Handle, Worker) { let handle = Handle { inner: Arc::new(Mutex::new(None)), @@ -80,11 +80,11 @@ where } tracing::trace!("worker polling for next message"); - if let Some(mut msg) = self.current_message.take() { - // poll_closed returns Poll::Ready is the receiver is dropped. - // Returning Pending means it is still alive, so we should still - // use it. - if msg.tx.poll_closed(cx).is_pending() { + if let Some(msg) = self.current_message.take() { + // If the oneshot sender is closed, then the receiver is dropped, + // and nobody cares about the response. If this is the case, we + // should continue to the next request. + if !msg.tx.is_closed() { tracing::trace!("resuming buffered request"); return Poll::Ready(Some((msg, false))); } @@ -93,8 +93,8 @@ where } // Get the next request - while let Some(mut msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) { - if msg.tx.poll_closed(cx).is_pending() { + while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_next(cx)) { + if !msg.tx.is_closed() { tracing::trace!("processing new request"); return Poll::Ready(Some((msg, true))); } diff --git a/tower/src/hedge/delay.rs b/tower/src/hedge/delay.rs index 21c6fd3..7f35a70 100644 --- a/tower/src/hedge/delay.rs +++ b/tower/src/hedge/delay.rs @@ -37,7 +37,7 @@ where #[pin_project(project = StateProj)] #[derive(Debug)] enum State { - Delaying(#[pin] tokio::time::Delay, Option), + Delaying(#[pin] tokio::time::Sleep, Option), Called(#[pin] F), } @@ -70,10 +70,10 @@ where } fn call(&mut self, request: Request) -> Self::Future { - let deadline = tokio::time::Instant::now() + self.policy.delay(&request); + let delay = self.policy.delay(&request); ResponseFuture { service: Some(self.service.clone()), - state: State::Delaying(tokio::time::delay_until(deadline), Some(request)), + state: State::Delaying(tokio::time::sleep(delay), Some(request)), } } } diff --git a/tower/src/hedge/mod.rs b/tower/src/hedge/mod.rs index e21415a..3552240 100644 --- a/tower/src/hedge/mod.rs +++ b/tower/src/hedge/mod.rs @@ -1,12 +1,7 @@ //! Pre-emptively retry requests which have been outstanding for longer //! than a given latency percentile. -#![warn( - missing_debug_implementations, - missing_docs, - rust_2018_idioms, - unreachable_pub -)] +#![warn(missing_debug_implementations, missing_docs, unreachable_pub)] use crate::filter::Filter; use futures_util::future; diff --git a/tower/src/lib.rs b/tower/src/lib.rs index 3049301..9af0760 100644 --- a/tower/src/lib.rs +++ b/tower/src/lib.rs @@ -77,6 +77,9 @@ pub use tower_layer::Layer; #[doc(inline)] pub use tower_service::Service; +#[cfg(any(feature = "buffer", feature = "limit"))] +mod semaphore; + #[allow(unreachable_pub)] mod sealed { pub trait Sealed {} diff --git a/tower/src/limit/concurrency/future.rs b/tower/src/limit/concurrency/future.rs index ff82c21..eda2b3f 100644 --- a/tower/src/limit/concurrency/future.rs +++ b/tower/src/limit/concurrency/future.rs @@ -1,5 +1,6 @@ //! Future types //! +use crate::semaphore::Permit; use futures_core::ready; use pin_project::pin_project; use std::{ @@ -7,7 +8,6 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tokio::sync::OwnedSemaphorePermit; /// Future for the `ConcurrencyLimit` service. #[pin_project] @@ -16,11 +16,11 @@ pub struct ResponseFuture { #[pin] inner: T, // Keep this around so that it is dropped when the future completes - _permit: OwnedSemaphorePermit, + _permit: Permit, } impl ResponseFuture { - pub(crate) fn new(inner: T, _permit: OwnedSemaphorePermit) -> ResponseFuture { + pub(crate) fn new(inner: T, _permit: Permit) -> ResponseFuture { ResponseFuture { inner, _permit } } } diff --git a/tower/src/limit/concurrency/service.rs b/tower/src/limit/concurrency/service.rs index 1a544a7..790b46a 100644 --- a/tower/src/limit/concurrency/service.rs +++ b/tower/src/limit/concurrency/service.rs @@ -1,39 +1,24 @@ use super::future::ResponseFuture; - +use crate::semaphore::Semaphore; use tower_service::Service; use futures_core::ready; -use std::fmt; -use std::future::Future; -use std::mem; -use std::pin::Pin; -use std::sync::Arc; use std::task::{Context, Poll}; -use tokio::sync::{OwnedSemaphorePermit, Semaphore}; /// Enforces a limit on the concurrent number of requests the underlying /// service can handle. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ConcurrencyLimit { inner: T, - semaphore: Arc, - state: State, -} - -enum State { - Waiting(Pin + Send + 'static>>), - Ready(OwnedSemaphorePermit), - Empty, + semaphore: Semaphore, } impl ConcurrencyLimit { /// Create a new concurrency limiter. pub fn new(inner: T, max: usize) -> Self { - let semaphore = Arc::new(Semaphore::new(max)); ConcurrencyLimit { inner, - semaphore, - state: State::Empty, + semaphore: Semaphore::new(max), } } @@ -62,27 +47,18 @@ where type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - self.state = match self.state { - State::Ready(_) => return self.inner.poll_ready(cx), - State::Waiting(ref mut fut) => { - tokio::pin!(fut); - let permit = ready!(fut.poll(cx)); - State::Ready(permit) - } - State::Empty => State::Waiting(Box::pin(self.semaphore.clone().acquire_owned())), - }; - } + // First, poll the semaphore... + ready!(self.semaphore.poll_acquire(cx)); + // ...and if it's ready, poll the inner service. + self.inner.poll_ready(cx) } fn call(&mut self, request: Request) -> Self::Future { - // Make sure a permit has been acquired - let permit = match mem::replace(&mut self.state, State::Empty) { - // Take the permit. - State::Ready(permit) => permit, - // whoopsie! - _ => panic!("max requests in-flight; poll_ready must be called first"), - }; + // Take the permit + let permit = self + .semaphore + .take_permit() + .expect("max requests in-flight; poll_ready must be called first"); // Call the inner service let future = self.inner.call(request); @@ -101,29 +77,3 @@ where self.inner.load() } } - -impl Clone for ConcurrencyLimit -where - S: Clone, -{ - fn clone(&self) -> ConcurrencyLimit { - ConcurrencyLimit { - inner: self.inner.clone(), - semaphore: self.semaphore.clone(), - state: State::Empty, - } - } -} - -impl fmt::Debug for State { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - State::Waiting(_) => f - .debug_tuple("State::Waiting") - .field(&format_args!("...")) - .finish(), - State::Ready(ref r) => f.debug_tuple("State::Ready").field(&r).finish(), - State::Empty => f.debug_tuple("State::Empty").finish(), - } - } -} diff --git a/tower/src/limit/rate/service.rs b/tower/src/limit/rate/service.rs index 7e426c6..e7332ea 100644 --- a/tower/src/limit/rate/service.rs +++ b/tower/src/limit/rate/service.rs @@ -5,7 +5,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tokio::time::{Delay, Instant}; +use tokio::time::{Instant, Sleep}; use tower_service::Service; /// Enforces a rate limit on the number of requests the underlying @@ -20,7 +20,7 @@ pub struct RateLimit { #[derive(Debug)] enum State { // The service has hit its limit - Limited(Delay), + Limited(Sleep), Ready { until: Instant, rem: u64 }, } @@ -98,7 +98,7 @@ where self.state = State::Ready { until, rem }; } else { // The service is disabled until further notice - let sleep = tokio::time::delay_until(until); + let sleep = tokio::time::sleep_until(until); self.state = State::Limited(sleep); } diff --git a/tower/src/semaphore.rs b/tower/src/semaphore.rs new file mode 100644 index 0000000..e15b923 --- /dev/null +++ b/tower/src/semaphore.rs @@ -0,0 +1,74 @@ +pub(crate) use self::sync::OwnedSemaphorePermit as Permit; +use futures_core::ready; +use std::{ + fmt, + future::Future, + mem, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +use tokio::sync; + +#[derive(Debug)] +pub(crate) struct Semaphore { + semaphore: Arc, + state: State, +} + +enum State { + Waiting(Pin + Send + 'static>>), + Ready(Permit), + Empty, +} + +impl Semaphore { + pub(crate) fn new(permits: usize) -> Self { + Self { + semaphore: Arc::new(sync::Semaphore::new(permits)), + state: State::Empty, + } + } + + pub(crate) fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<()> { + loop { + self.state = match self.state { + State::Ready(_) => return Poll::Ready(()), + State::Waiting(ref mut fut) => { + let permit = ready!(Pin::new(fut).poll(cx)); + State::Ready(permit) + } + State::Empty => State::Waiting(Box::pin(self.semaphore.clone().acquire_owned())), + }; + } + } + + pub(crate) fn take_permit(&mut self) -> Option { + if let State::Ready(permit) = mem::replace(&mut self.state, State::Empty) { + return Some(permit); + } + None + } +} + +impl Clone for Semaphore { + fn clone(&self) -> Self { + Self { + semaphore: self.semaphore.clone(), + state: State::Empty, + } + } +} + +impl fmt::Debug for State { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + State::Waiting(_) => f + .debug_tuple("State::Waiting") + .field(&format_args!("...")) + .finish(), + State::Ready(ref r) => f.debug_tuple("State::Ready").field(&r).finish(), + State::Empty => f.debug_tuple("State::Empty").finish(), + } + } +} diff --git a/tower/src/spawn_ready/future.rs b/tower/src/spawn_ready/future.rs index 17478c1..892c540 100644 --- a/tower/src/spawn_ready/future.rs +++ b/tower/src/spawn_ready/future.rs @@ -49,7 +49,18 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - if let Poll::Ready(_) = Pin::new(this.tx.as_mut().expect("illegal state")).poll_closed(cx) { + // Is the channel sender closed? + // Note that we must actually poll the sender's closed future here, + // rather than just calling `is_closed` on it, since we want to be + // notified if the receiver is dropped. + let closed = { + // TODO(eliza): once `tokio` 0.3.2 is released, we can change this back + // to just using `Sender::poll_closed`, which is being re-added. + let closed = this.tx.as_mut().expect("illegal state").closed(); + tokio::pin!(closed); + closed.poll(cx) + }; + if let Poll::Ready(_) = closed { return Poll::Ready(()); } diff --git a/tower/src/timeout/future.rs b/tower/src/timeout/future.rs index 5e61cb6..661c259 100644 --- a/tower/src/timeout/future.rs +++ b/tower/src/timeout/future.rs @@ -7,7 +7,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tokio::time::Delay; +use tokio::time::Sleep; /// `Timeout` response future #[pin_project] @@ -16,11 +16,11 @@ pub struct ResponseFuture { #[pin] response: T, #[pin] - sleep: Delay, + sleep: Sleep, } impl ResponseFuture { - pub(crate) fn new(response: T, sleep: Delay) -> Self { + pub(crate) fn new(response: T, sleep: Sleep) -> Self { ResponseFuture { response, sleep } } } diff --git a/tower/src/timeout/mod.rs b/tower/src/timeout/mod.rs index ac5ae93..2407078 100644 --- a/tower/src/timeout/mod.rs +++ b/tower/src/timeout/mod.rs @@ -63,7 +63,7 @@ where fn call(&mut self, request: Request) -> Self::Future { let response = self.inner.call(request); - let sleep = tokio::time::delay_for(self.timeout); + let sleep = tokio::time::sleep(self.timeout); ResponseFuture::new(response, sleep) } diff --git a/tower/tests/buffer/main.rs b/tower/tests/buffer/main.rs index 5fa32c5..67c75ae 100644 --- a/tower/tests/buffer/main.rs +++ b/tower/tests/buffer/main.rs @@ -14,6 +14,7 @@ fn let_worker_work() { async fn req_and_res() { let (mut service, mut handle) = new_service(); + assert_ready_ok!(service.poll_ready()); let mut response = task::spawn(service.call("hello")); assert_request_eq!(handle, "hello").send_response("world"); @@ -28,16 +29,18 @@ async fn clears_canceled_requests() { handle.allow(1); + assert_ready_ok!(service.poll_ready()); let mut res1 = task::spawn(service.call("hello")); let send_response1 = assert_request_eq!(handle, "hello"); // don't respond yet, new requests will get buffered - + assert_ready_ok!(service.poll_ready()); let res2 = task::spawn(service.call("hello2")); assert_pending!(handle.poll_request()); + assert_ready_ok!(service.poll_ready()); let mut res3 = task::spawn(service.call("hello3")); drop(res2); @@ -63,6 +66,7 @@ async fn when_inner_is_not_ready() { // Make the service NotReady handle.allow(0); + assert_ready_ok!(service.poll_ready()); let mut res1 = task::spawn(service.call("hello")); let_worker_work(); @@ -87,6 +91,7 @@ async fn when_inner_fails() { handle.allow(0); handle.send_error("foobar"); + assert_ready_ok!(service.poll_ready()); let mut res1 = task::spawn(service.call("hello")); let_worker_work(); @@ -125,6 +130,7 @@ async fn response_future_when_worker_is_dropped_early() { // keep the request in the worker handle.allow(0); + assert_ready_ok!(service.poll_ready()); let mut response = task::spawn(service.call("hello")); drop(worker); @@ -134,13 +140,90 @@ async fn response_future_when_worker_is_dropped_early() { assert!(err.is::(), "should be a Closed: {:?}", err); } +#[tokio::test] +async fn waits_for_channel_capacity() { + let (service, mut handle) = mock::pair::<&'static str, &'static str>(); + + let (service, worker) = Buffer::pair(service, 3); + + let mut service = mock::Spawn::new(service); + let mut worker = task::spawn(worker); + + // keep requests in the worker + handle.allow(0); + assert_ready_ok!(service.poll_ready()); + let mut response1 = task::spawn(service.call("hello")); + assert_pending!(worker.poll()); + + assert_ready_ok!(service.poll_ready()); + let mut response2 = task::spawn(service.call("hello")); + assert_pending!(worker.poll()); + + assert_ready_ok!(service.poll_ready()); + let mut response3 = task::spawn(service.call("hello")); + assert_pending!(service.poll_ready()); + assert_pending!(worker.poll()); + + handle.allow(1); + assert_pending!(worker.poll()); + + handle + .next_request() + .await + .unwrap() + .1 + .send_response("world"); + assert_pending!(worker.poll()); + assert_ready_ok!(response1.poll()); + + assert_ready_ok!(service.poll_ready()); + let mut response4 = task::spawn(service.call("hello")); + assert_pending!(worker.poll()); + + handle.allow(3); + assert_pending!(worker.poll()); + + handle + .next_request() + .await + .unwrap() + .1 + .send_response("world"); + assert_pending!(worker.poll()); + assert_ready_ok!(response2.poll()); + + assert_pending!(worker.poll()); + handle + .next_request() + .await + .unwrap() + .1 + .send_response("world"); + assert_pending!(worker.poll()); + assert_ready_ok!(response3.poll()); + + assert_pending!(worker.poll()); + handle + .next_request() + .await + .unwrap() + .1 + .send_response("world"); + assert_pending!(worker.poll()); + assert_ready_ok!(response4.poll()); +} + type Mock = mock::Mock<&'static str, &'static str>; type Handle = mock::Handle<&'static str, &'static str>; fn new_service() -> (mock::Spawn>, Handle) { // bound is >0 here because clears_canceled_requests needs multiple outstanding requests + new_service_with_bound(10) +} + +fn new_service_with_bound(bound: usize) -> (mock::Spawn>, Handle) { mock::spawn_with(|s| { - let (svc, worker) = Buffer::pair(s, 10); + let (svc, worker) = Buffer::pair(s, bound); thread::spawn(move || { let mut fut = tokio_test::task::spawn(worker); diff --git a/tower/tests/steer/main.rs b/tower/tests/steer/main.rs index ab6482d..092b094 100644 --- a/tower/tests/steer/main.rs +++ b/tower/tests/steer/main.rs @@ -29,7 +29,7 @@ impl Service for MyService { #[test] fn pick_correctly() { - let mut rt = tokio::runtime::Runtime::new().unwrap(); + let rt = tokio::runtime::Runtime::new().unwrap(); rt.block_on(async move { let srvs = vec![MyService(42, true), MyService(57, true)]; let mut st = Steer::new(srvs, |_: &_, _: &[_]| 1); @@ -44,7 +44,7 @@ fn pick_correctly() { #[test] fn pending_all_ready() { - let mut rt = tokio::runtime::Runtime::new().unwrap(); + let rt = tokio::runtime::Runtime::new().unwrap(); rt.block_on(async move { let srvs = vec![MyService(42, true), MyService(57, false)]; let mut st = Steer::new(srvs, |_: &_, _: &[_]| 0); diff --git a/tower/tests/util/oneshot.rs b/tower/tests/util/oneshot.rs index 63ba004..c2b95e1 100644 --- a/tower/tests/util/oneshot.rs +++ b/tower/tests/util/oneshot.rs @@ -35,5 +35,5 @@ async fn service_driven_to_readiness() { } let svc = PollMeTwice { ready: false }; - svc.oneshot(()).await; + svc.oneshot(()).await.unwrap(); }