diff --git a/Cargo.toml b/Cargo.toml index 9cc0da9..acb0d37 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,17 @@ homepage = "https://tower.rs" repository = "https://github.com/tower-rs/tower" readme = "README.md" +[workspace] + +members = [ + "./", + "tower-filter", + "tower-mock", + "tower-rate-limit", + "tower-route", + "tower-timeout", +] + [dependencies] futures = "0.1" @@ -22,3 +33,6 @@ log = "0.3" env_logger = "0.4" tokio-timer = "0.1" futures-cpupool = "0.1" + +[replace] +"futures:0.1.16" = { git = "https://github.com/carllerche/futures-rs", branch = "test-harness" } diff --git a/tower-filter/Cargo.toml b/tower-filter/Cargo.toml new file mode 100644 index 0000000..b71ef2c --- /dev/null +++ b/tower-filter/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "tower-filter" +version = "0.1.0" +authors = ["Carl Lerche "] + +[dependencies] +futures = "0.1" +tower = { version = "0.1", path = "../" } + +[dev-dependencies] +tower-mock = { version = "0.1", path = "../tower-mock" } diff --git a/tower-filter/README.md b/tower-filter/README.md new file mode 100644 index 0000000..3e71106 --- /dev/null +++ b/tower-filter/README.md @@ -0,0 +1,12 @@ +Tower Filter + +A Tower middleware that conditionally allows requests to be dispatched to the +inner service based on the result of a predicate. + +# License + +`tower-filter` is primarily distributed under the terms of both the MIT license +and the Apache License (Version 2.0), with portions covered by various BSD-like +licenses. + +See LICENSE-APACHE, and LICENSE-MIT for details. diff --git a/tower-filter/src/lib.rs b/tower-filter/src/lib.rs new file mode 100644 index 0000000..315a240 --- /dev/null +++ b/tower-filter/src/lib.rs @@ -0,0 +1,263 @@ +//! Conditionally dispatch requests to the inner service based on the result of +//! a predicate. + +extern crate futures; +extern crate tower; + +use futures::{Future, IntoFuture, Poll, Async}; +use futures::task::AtomicTask; +use tower::Service; + +use std::{fmt, mem}; +use std::sync::Arc; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering::SeqCst; + +#[derive(Debug)] +pub struct Filter { + inner: T, + predicate: U, + // Tracks the number of in-flight requests + counts: Arc, +} + +pub struct ResponseFuture +where S: Service, +{ + inner: Option>, +} + +#[derive(Debug)] +struct ResponseInner +where S: Service, +{ + state: State, + check: T, + service: S, + counts: Arc, +} + +/// Errors produced by `Filter` +#[derive(Debug)] +pub enum Error { + /// The predicate rejected the request. + Rejected(T), + + /// The inner service produced an error. + Inner(U), + + /// The service is out of capacity. + NoCapacity, +} + +/// Checks a request +pub trait Predicate { + type Error; + type Future: Future; + + fn check(&mut self, request: &T) -> Self::Future; +} + +#[derive(Debug)] +struct Counts { + /// Filter::poll_ready task + task: AtomicTask, + + /// Remaining capacity + rem: AtomicUsize, +} + +#[derive(Debug)] +enum State { + Check(T), + WaitReady(T), + WaitResponse(U), + NoCapacity, +} + +// ===== impl Filter ===== + +impl Filter +where T: Service + Clone, + U: Predicate, +{ + pub fn new(inner: T, predicate: U, buffer: usize) -> Self { + let counts = Counts { + task: AtomicTask::new(), + rem: AtomicUsize::new(buffer), + }; + + Filter { + inner, + predicate, + counts: Arc::new(counts), + } + } +} + +impl Service for Filter +where T: Service + Clone, + U: Predicate, +{ + type Request = T::Request; + type Response = T::Response; + type Error = Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.counts.task.register(); + + let rem = self.counts.rem.load(SeqCst); + + // TODO: Handle catching upstream closing + + if rem == 0 { + Ok(Async::NotReady) + } else { + Ok(().into()) + } + } + + fn call(&mut self, request: T::Request) -> Self::Future { + let rem = self.counts.rem.load(SeqCst); + + if rem == 0 { + return ResponseFuture { + inner: None, + }; + } + + // Decrement + self.counts.rem.fetch_sub(1, SeqCst); + + // Check the request + let check = self.predicate.check(&request); + + // Clone the service + let service = self.inner.clone(); + + // Clone counts + let counts = self.counts.clone(); + + ResponseFuture { + inner: Some(ResponseInner { + state: State::Check(request), + check, + service, + counts, + }), + } + } +} + +// ===== impl Predicate ===== + +impl Predicate for F + where F: Fn(&T) -> U, + U: IntoFuture, +{ + type Error = U::Error; + type Future = U::Future; + + fn check(&mut self, request: &T) -> Self::Future { + self(request).into_future() + } +} + +// ===== impl ResponseFuture ===== + +impl Future for ResponseFuture +where T: Future, + U: Service, +{ + type Item = U::Response; + type Error = Error; + + fn poll(&mut self) -> Poll { + match self.inner { + Some(ref mut inner) => inner.poll(), + None => Err(Error::NoCapacity), + } + } +} + +impl fmt::Debug for ResponseFuture +where T: fmt::Debug, + S: Service + fmt::Debug, + S::Request: fmt::Debug, + S::Future: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.debug_struct("ResponseFuture") + .field("inner", &self.inner) + .finish() + } +} + +// ===== impl ResponseInner ===== + +impl ResponseInner +where T: Future, + U: Service, +{ + fn inc_rem(&self) { + if 0 == self.counts.rem.fetch_add(1, SeqCst) { + self.counts.task.notify(); + } + } + + fn poll(&mut self) -> Poll> { + use self::State::*; + + loop { + match mem::replace(&mut self.state, NoCapacity) { + Check(request) => { + // Poll predicate + match self.check.poll() { + Ok(Async::Ready(_)) => { + self.state = WaitReady(request); + } + Ok(Async::NotReady) => { + self.state = Check(request); + return Ok(Async::NotReady); + } + Err(e) => { + return Err(Error::Rejected(e)); + } + } + } + WaitReady(request) => { + // Poll service for readiness + match self.service.poll_ready() { + Ok(Async::Ready(_)) => { + self.inc_rem(); + + let response = self.service.call(request); + self.state = WaitResponse(response); + } + Ok(Async::NotReady) => { + self.state = WaitReady(request); + return Ok(Async::NotReady); + } + Err(e) => { + self.inc_rem(); + + return Err(Error::Inner(e)); + } + } + } + WaitResponse(mut response) => { + let ret = response.poll() + .map_err(Error::Inner); + + self.state = WaitResponse(response); + + return ret; + } + NoCapacity => { + return Err(Error::NoCapacity); + } + } + } + } +} diff --git a/tower-filter/tests/filter.rs b/tower-filter/tests/filter.rs new file mode 100644 index 0000000..184c397 --- /dev/null +++ b/tower-filter/tests/filter.rs @@ -0,0 +1,139 @@ +extern crate futures; +extern crate tower; +extern crate tower_mock; +extern crate tower_filter; + +use futures::*; +use tower::*; +use tower_filter::*; + +use std::thread; +use std::sync::mpsc; + +#[test] +fn passthrough_sync() { + let (mut service, mut handle) = + new_service(10, |_| Ok::<_, ()>(())); + + let th = thread::spawn(move || { + // Receive the requests and respond + for i in 0..10 { + let expect = format!("ping-{}", i); + let actual = handle.next_request().unwrap(); + + assert_eq!(actual.as_str(), expect.as_str()); + + actual.respond(format!("pong-{}", i)); + } + }); + + let mut responses = vec![]; + + for i in 0..10 { + let request = format!("ping-{}", i); + let exchange = service.call(request) + .and_then(move |response| { + let expect = format!("pong-{}", i); + assert_eq!(response.as_str(), expect.as_str()); + + Ok(()) + }); + + responses.push(exchange); + } + + future::join_all(responses).wait().unwrap(); + th.join().unwrap(); +} + +#[test] +fn rejected_sync() { + let (mut service, _handle) = + new_service(10, |_| Err::<(), _>(())); + + let response = service.call("hello".into()).wait(); + assert!(response.is_err()); +} + +#[test] +fn saturate() { + use futures::stream::FuturesUnordered; + + let (mut service, mut handle) = + new_service(1, |_| Ok::<_, ()>(())); + + with_task(|| { + // First request is ready + assert!(service.poll_ready().unwrap().is_ready()); + }); + + let mut r1 = service.call("one".into()); + + with_task(|| { + // Second request is not ready + assert!(service.poll_ready().unwrap().is_not_ready()); + }); + + let mut futs = FuturesUnordered::new(); + futs.push(service.ready()); + + let (tx, rx) = mpsc::channel(); + + // Complete the request in another thread + let th1 = thread::spawn(move || { + with_task(|| { + assert!(r1.poll().unwrap().is_not_ready()); + + tx.send(()).unwrap(); + + let response = r1.wait().unwrap(); + assert_eq!(response.as_str(), "resp-one"); + }); + }); + + rx.recv().unwrap(); + + // The service should be ready + let mut service = with_task(|| { + match futs.poll().unwrap() { + Async::Ready(Some(s)) => s, + Async::Ready(None) => panic!("None"), + Async::NotReady => panic!("NotReady"), + } + }); + + let r2 = service.call("two".into()); + + let th2 = thread::spawn(move || { + let response = r2.wait().unwrap(); + assert_eq!(response.as_str(), "resp-two"); + }); + + let request = handle.next_request().unwrap(); + assert_eq!("one", request.as_str()); + request.respond("resp-one".into()); + + let request = handle.next_request().unwrap(); + assert_eq!("two", request.as_str()); + request.respond("resp-two".into()); + + th1.join().unwrap(); + th2.join().unwrap(); +} + +type Mock = tower_mock::Mock; +type Handle = tower_mock::Handle; + +fn new_service(max: usize, f: F) -> (Filter, Handle) +where F: Fn(&String) -> U, + U: IntoFuture +{ + let (service, handle) = Mock::new(); + let service = Filter::new(service, f, max); + (service, handle) +} + +fn with_task U, U>(f: F) -> U { + use futures::future::{Future, lazy}; + lazy(|| Ok::<_, ()>(f())).wait().unwrap() +} diff --git a/tower-mock/Cargo.toml b/tower-mock/Cargo.toml new file mode 100644 index 0000000..25d8103 --- /dev/null +++ b/tower-mock/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "tower-mock" +version = "0.1.0" +authors = ["Carl Lerche "] + +[dependencies] +futures = "0.1" +tower = { version = "0.1", path = "../" } diff --git a/tower-mock/README.md b/tower-mock/README.md new file mode 100644 index 0000000..11ce833 --- /dev/null +++ b/tower-mock/README.md @@ -0,0 +1,11 @@ +Tower Mock + +A mock `Service` that can be used to test middleware or clients. + +# License + +`tower-mock` is primarily distributed under the terms of both the MIT license +and the Apache License (Version 2.0), with portions covered by various BSD-like +licenses. + +See LICENSE-APACHE, and LICENSE-MIT for details. diff --git a/tower-mock/src/lib.rs b/tower-mock/src/lib.rs new file mode 100644 index 0000000..303f214 --- /dev/null +++ b/tower-mock/src/lib.rs @@ -0,0 +1,325 @@ +//! Mock `Service` that can be used in tests. + +extern crate tower; +extern crate futures; + +use tower::Service; + +use futures::{Future, Stream, Poll, Async}; +use futures::sync::{oneshot, mpsc}; +use futures::task::{self, Task}; + +use std::{ops, u64}; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +/// A mock service +#[derive(Debug)] +pub struct Mock { + id: u64, + tx: Mutex>, + state: Arc>, + can_send: bool, +} + +/// Handle to the `Mock`. +#[derive(Debug)] +pub struct Handle { + rx: Rx, + state: Arc>, +} + +#[derive(Debug)] +pub struct Request { + request: T, + respond: Respond, +} + +/// Respond to a request received by `Mock`. +#[derive(Debug)] +pub struct Respond { + tx: oneshot::Sender>, +} + +/// Future of the `Mock` response. +#[derive(Debug)] +pub struct ResponseFuture { + // Slight abuse of the error enum... + rx: Error>>, +} + +/// Enumeration of errors that can be returned by `Mock`. +#[derive(Debug)] +pub enum Error { + Closed, + NoCapacity, + Other(T), +} + +#[derive(Debug)] +struct State { + // Tracks the number of requests that can be sent through + rem: u64, + + // Tasks that are blocked + tasks: HashMap, + + // Tracks if the `Handle` dropped + is_closed: bool, + + // Tracks the ID for the next mock clone + next_clone_id: u64, +} + +type Tx = mpsc::UnboundedSender>; +type Rx = mpsc::UnboundedReceiver>; + +// ===== impl Mock ===== + +impl Mock { + /// Create a new `Mock` and `Handle` pair. + pub fn new() -> (Self, Handle) { + let (tx, rx) = mpsc::unbounded(); + let tx = Mutex::new(tx); + + let state = Arc::new(Mutex::new(State::new())); + + let mock = Mock { + id: 0, + tx, + state: state.clone(), + can_send: false, + }; + + let handle = Handle { + rx, + state, + }; + + (mock, handle) + } +} + +impl Service for Mock { + type Request = T; + type Response = U; + type Error = Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + let mut state = self.state.lock().unwrap(); + + if state.is_closed { + return Err(Error::Closed); + } + + if self.can_send { + return Ok(().into()); + } + + if state.rem > 0 { + assert!(!state.tasks.contains_key(&self.id)); + + // Returning `Ready` means the next call to `call` must succeed. + self.can_send = true; + + Ok(Async::Ready(())) + } else { + // Bit weird... but whatevz + *state.tasks.entry(self.id) + .or_insert_with(|| task::current()) = task::current(); + + Ok(Async::NotReady) + } + } + + fn call(&mut self, request: Self::Request) -> Self::Future { + // Make sure that the service has capacity + let mut state = self.state.lock().unwrap(); + + if state.is_closed { + return ResponseFuture { + rx: Error::Closed, + }; + } + + if !self.can_send { + if state.rem == 0 { + return ResponseFuture { + rx: Error::NoCapacity, + } + } + } + + self.can_send = false; + + // Decrement the number of remaining requests that can be sent + if state.rem > 0 { + state.rem -= 1; + } + + let (tx, rx) = oneshot::channel(); + + let request = Request { + request, + respond: Respond { tx }, + }; + + match self.tx.lock().unwrap().unbounded_send((request)) { + Ok(_) => {} + Err(_) => { + // TODO: Can this be reached + return ResponseFuture { + rx: Error::Closed, + }; + } + } + + ResponseFuture { rx: Error::Other(rx) } + } +} + +impl Clone for Mock { + fn clone(&self) -> Self { + let id = { + let mut state = self.state.lock().unwrap(); + let id = state.next_clone_id; + + state.next_clone_id += 1; + + id + }; + + let tx = Mutex::new(self.tx.lock().unwrap().clone()); + + Mock { + id, + tx, + state: self.state.clone(), + can_send: false, + } + } +} + +impl Drop for Mock { + fn drop(&mut self) { + let mut state = self.state.lock().unwrap(); + state.tasks.remove(&self.id); + } +} + +// ===== impl Handle ===== + +impl Handle { + /// Asynchronously gets the next request + pub fn poll_request(&mut self) + -> Poll>, ()> + { + self.rx.poll() + } + + /// Synchronously gets the next request. + /// + /// This function blocks the current thread until a request is received. + pub fn next_request(&mut self) -> Option> { + use futures::future::poll_fn; + poll_fn(|| self.poll_request()).wait().unwrap() + } + + /// Allow a certain number of requests + pub fn allow(&mut self, num: u64) { + let mut state = self.state.lock().unwrap(); + state.rem = num; + + if num > 0 { + for (_, task) in state.tasks.drain() { + task.notify(); + } + } + } +} + +impl Drop for Handle { + fn drop(&mut self) { + let mut state = self.state.lock().unwrap(); + state.is_closed = true; + + for (_, task) in state.tasks.drain() { + task.notify(); + } + } +} + +// ===== impl Request ===== + +impl Request { + /// Split the request and respond handle + pub fn into_parts(self) -> (T, Respond) { + (self.request, self.respond) + } + + pub fn respond(self, response: U) { + self.respond.respond(response) + } + + pub fn error(self, err: E) { + self.respond.error(err) + } +} + +impl ops::Deref for Request { + type Target = T; + + fn deref(&self) -> &T { + &self.request + } +} + +// ===== impl Respond ===== + +impl Respond { + pub fn respond(self, response: T) { + // TODO: Should the result be dropped? + let _ = self.tx.send(Ok(response)); + } + + pub fn error(self, err: E) { + // TODO: Should the result be dropped? + let _ = self.tx.send(Err(err)); + } +} + +// ===== impl ResponseFuture ===== + +impl Future for ResponseFuture { + type Item = T; + type Error = Error; + + fn poll(&mut self) -> Poll { + match self.rx { + Error::Other(ref mut rx) => { + match rx.poll() { + Ok(Async::Ready(Ok(v))) => Ok(v.into()), + Ok(Async::Ready(Err(e))) => Err(Error::Other(e)), + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(_) => Err(Error::Closed), + } + } + Error::NoCapacity => Err(Error::NoCapacity), + Error::Closed => Err(Error::Closed), + } + } +} + +// ===== impl State ===== + +impl State { + fn new() -> State { + State { + rem: u64::MAX, + tasks: HashMap::new(), + is_closed: false, + next_clone_id: 1, + } + } +} diff --git a/tower-mock/tests/mock.rs b/tower-mock/tests/mock.rs new file mode 100644 index 0000000..1e546cd --- /dev/null +++ b/tower-mock/tests/mock.rs @@ -0,0 +1,70 @@ +extern crate futures; +extern crate tower; +extern crate tower_mock; + +use futures::Future; +use tower::Service; + +#[test] +fn single_request_ready() { + let (mut mock, mut handle) = new_mock(); + + // No pending requests + with_task(|| { + assert!(handle.poll_request().unwrap().is_not_ready()); + }); + + // Issue a request + let mut response = mock.call("hello?".into()); + + // Get the request from the handle + let request = handle.next_request().unwrap(); + + assert_eq!(request.as_str(), "hello?"); + + // Response is not ready + with_task(|| { + assert!(response.poll().unwrap().is_not_ready()); + }); + + // Send the response + request.respond("yes?".into()); + + assert_eq!(response.wait().unwrap().as_str(), "yes?"); +} + +#[test] +fn backpressure() { + let (mut mock, mut handle) = new_mock(); + + handle.allow(0); + + // Make sure the mock cannot accept more requests + with_task(|| { + assert!(mock.poll_ready().unwrap().is_not_ready()); + }); + + // Try to send a request + let response = mock.call("hello?".into()); + + // Did not send + with_task(|| { + assert!(handle.poll_request().unwrap().is_not_ready()); + }); + + // Response is an error + assert!(response.wait().is_err()); +} + +type Mock = tower_mock::Mock; +type Handle = tower_mock::Handle; + +fn new_mock() -> (Mock, Handle) { + Mock::new() +} + +// Helper to run some code within context of a task +fn with_task U, U>(f: F) -> U { + use futures::future::{Future, lazy}; + lazy(|| Ok::<_, ()>(f())).wait().unwrap() +} diff --git a/tower-rate-limit/Cargo.toml b/tower-rate-limit/Cargo.toml new file mode 100644 index 0000000..35a9112 --- /dev/null +++ b/tower-rate-limit/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "tower-rate-limit" +version = "0.1.0" +authors = ["Carl Lerche "] + +[dependencies] +futures = "0.1" +tower = { version = "0.1", path = "../" } +tokio-timer = "0.1" + +[dev-dependencies] +tower-mock = { version = "0.1", path = "../tower-mock" } diff --git a/tower-rate-limit/README.md b/tower-rate-limit/README.md new file mode 100644 index 0000000..2a85f5f --- /dev/null +++ b/tower-rate-limit/README.md @@ -0,0 +1,13 @@ +Tower Rate Limit + +A Tower middleware that rate limits the requests that are passed to the inner +service. + +# License + +`tower-rate-limit` is primarily distributed under the terms of both the MIT +license and the Apache License (Version 2.0), with portions covered by various +BSD-like licenses. + +See LICENSE-APACHE, and LICENSE-MIT for details. + diff --git a/tower-rate-limit/src/lib.rs b/tower-rate-limit/src/lib.rs new file mode 100644 index 0000000..f5febcf --- /dev/null +++ b/tower-rate-limit/src/lib.rs @@ -0,0 +1,174 @@ +//! Tower middleware that applies a timeout to requests. +//! +//! If the response does not complete within the specified timeout, the response +//! will be aborted. + +#[macro_use] +extern crate futures; +extern crate tower; +extern crate tokio_timer; + +use futures::{Future, Poll}; +use tower::Service; +use tokio_timer::{Timer, Sleep}; + +use std::time::{Duration, Instant}; + +#[derive(Debug)] +pub struct RateLimit { + inner: T, + timer: Timer, + rate: Rate, + state: State, +} + +#[derive(Debug, Copy, Clone)] +pub struct Rate { + num: u64, + per: Duration, +} + +/// The request has been rate limited +/// +/// TODO: Consider returning the original request +#[derive(Debug)] +pub enum Error { + RateLimit, + Upstream(T), +} + +pub struct ResponseFuture { + inner: Option, +} + +#[derive(Debug)] +enum State { + // The service has hit its limit + Limited(Sleep), + Ready { + until: Instant, + rem: u64, + }, +} + +impl RateLimit { + /// Create a new rate limiter + pub fn new(inner: T, rate: Rate, timer: Timer) -> Self { + let state = State::Ready { + until: Instant::now(), + rem: rate.num, + }; + + RateLimit { + inner, + rate, + timer, + state: state, + } + } + + /// Get a reference to the inner service + pub fn get_ref(&self) -> &T { + &self.inner + } + + /// Get a mutable reference to the inner service + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner + } + + /// Consume `self`, returning the inner service + pub fn into_inner(self) -> T { + self.inner + } +} + +impl Rate { + /// Create a new rate + /// + /// # Panics + /// + /// This function panics if `num` or `per` is 0. + pub fn new(num: u64, per: Duration) -> Self { + assert!(num > 0); + assert!(per > Duration::from_millis(0)); + + Rate { num, per } + } +} + +impl Service for RateLimit +where S: Service +{ + type Request = S::Request; + type Response = S::Response; + type Error = Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + match self.state { + State::Ready { .. } => return Ok(().into()), + State::Limited(ref mut sleep) => { + let res = sleep.poll() + .map_err(|_| Error::RateLimit); + + try_ready!(res); + } + } + + self.state = State::Ready { + until: Instant::now() + self.rate.per, + rem: self.rate.num, + }; + + Ok(().into()) + } + + fn call(&mut self, request: Self::Request) -> Self::Future { + match self.state { + State::Ready { mut until, mut rem } => { + let now = Instant::now(); + + // If the period has elapsed, reset it. + if now >= until { + until = now + self.rate.per; + let rem = self.rate.num; + + self.state = State::Ready { until, rem } + } + + if rem > 1 { + rem -= 1; + self.state = State::Ready { until, rem }; + } else { + // The service is disabled until further notice + let sleep = self.timer.sleep(until - now); + self.state = State::Limited(sleep); + } + + // Call the inner future + let inner = Some(self.inner.call(request)); + ResponseFuture { inner } + } + State::Limited(..) => { + ResponseFuture { inner: None } + } + } + } +} + +impl Future for ResponseFuture +where T: Future, +{ + type Item = T::Item; + type Error = Error; + + fn poll(&mut self) -> Poll { + match self.inner { + Some(ref mut f) => { + f.poll().map_err(Error::Upstream) + } + None => Err(Error::RateLimit), + } + } +} diff --git a/tower-rate-limit/tests/rate_limit.rs b/tower-rate-limit/tests/rate_limit.rs new file mode 100644 index 0000000..bdec491 --- /dev/null +++ b/tower-rate-limit/tests/rate_limit.rs @@ -0,0 +1,75 @@ +extern crate futures; +extern crate tower; +extern crate tower_mock; +extern crate tower_rate_limit; +extern crate tokio_timer; + +use futures::prelude::*; +use tower::*; +use tower_rate_limit::*; + +use std::time::Duration; +use std::thread; + +#[test] +fn reaching_capacity() { + let (mut service, mut handle) = + new_service(Rate::new(1, from_millis(100))); + + let response = service.call("hello"); + + let request = handle.next_request().unwrap(); + assert_eq!(*request, "hello"); + request.respond("world"); + + assert_eq!(response.wait().unwrap(), "world"); + + // Sending another request is rejected + let response = service.call("no"); + with_task(|| { + assert!(handle.poll_request().unwrap().is_not_ready()); + }); + + assert!(response.wait().is_err()); + + with_task(|| { + assert!(service.poll_ready().unwrap().is_not_ready()); + }); + + thread::sleep(Duration::from_millis(100)); + + with_task(|| { + assert!(service.poll_ready().unwrap().is_ready()); + }); + + // Send a second request + let response = service.call("two"); + + let request = handle.next_request().unwrap(); + assert_eq!(*request, "two"); + request.respond("done"); + + assert_eq!(response.wait().unwrap(), "done"); +} + +type Mock = tower_mock::Mock<&'static str, &'static str, ()>; +type Handle = tower_mock::Handle<&'static str, &'static str, ()>; + +fn new_service(rate: Rate) -> (RateLimit, Handle) { + let timer = tokio_timer::wheel() + .tick_duration(Duration::from_millis(1)) + .build(); + + let (service, handle) = Mock::new(); + let service = RateLimit::new(service, rate, timer); + (service, handle) +} + +fn with_task U, U>(f: F) -> U { + use futures::future::{Future, lazy}; + lazy(|| Ok::<_, ()>(f())).wait().unwrap() +} + +fn from_millis(n: u64) -> Duration { + Duration::from_millis(n) +} diff --git a/tower-route/Cargo.toml b/tower-route/Cargo.toml new file mode 100644 index 0000000..482ea45 --- /dev/null +++ b/tower-route/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "tower-route" +version = "0.1.0" +authors = ["Carl Lerche "] + +[dependencies] +futures = "0.1" +tower = { version = "0.1", path = "../" } +futures-borrow = { git = "https://github.com/carllerche/futures-borrow" } diff --git a/tower-route/README.md b/tower-route/README.md new file mode 100644 index 0000000..0901ab3 --- /dev/null +++ b/tower-route/README.md @@ -0,0 +1,12 @@ +Tower Route + +A Tower middleware that routes requests to one of a set of inner services using +a request predicate. + +# License + +`tower-route` is primarily distributed under the terms of both the MIT license +and the Apache License (Version 2.0), with portions covered by various BSD-like +licenses. + +See LICENSE-APACHE, and LICENSE-MIT for details. diff --git a/tower-route/src/lib.rs b/tower-route/src/lib.rs new file mode 100644 index 0000000..77fd899 --- /dev/null +++ b/tower-route/src/lib.rs @@ -0,0 +1,194 @@ +//! Routes requests to one of many inner inner services based on the request. + +extern crate tower; + +#[macro_use] +extern crate futures; +extern crate futures_borrow; + +use tower::Service; + +use futures::{Future, Poll}; +use futures_borrow::{Borrow, BorrowGuard}; + +use std::mem; + +use self::ResponseState::*; + +/// Routes requests to an inner service based on the request. +pub struct Route { + recognize: Borrow, +} + +/// Matches the request with a route +pub trait Recognize: 'static { + /// Request being matched + type Request; + + /// Inner service's response + type Response; + + /// Error produced by a failed inner service request + type Error; + + /// Error produced by failed route recognition + type RouteError; + + /// The destination service + type Service: Service; + + /// Recognize a route + /// + /// Takes a request, returns the route matching the request. + /// + /// The returned value is a mutable reference to the destination `Service`. + /// However, it may be that some asynchronous initialization must be + /// performed before the service is able to process requests (for example, + /// a TCP connection might need to be established). In this case, the inner + /// service should determine the buffering strategy used to handle the + /// request until the request can be processed. This behavior enables + /// punting all buffering decisions to the inner service. + fn recognize(&mut self, request: &Self::Request) + -> Result<&mut Self::Service, Self::RouteError>; +} + +pub struct ResponseFuture +where T: Recognize, +{ + state: ResponseState, +} + +/// Error produced by the `Route` service +/// +/// TODO: Make variants priv +#[derive(Debug)] +pub enum Error { + /// Error produced by inner service. + Inner(T), + + /// Error produced during route recognition. + Route(U), + + /// Request sent when not ready. + NotReady, +} + +enum ResponseState +where T: Recognize +{ + Dispatched(::Future), + RouteError(T::RouteError), + Queued { + service: BorrowGuard, + request: T::Request, + }, + NotReady, + Invalid, +} + +// ===== impl Route ===== + +impl Route +where T: Recognize +{ + /// Create a new router + pub fn new(recognize: T) -> Self { + Route { recognize: Borrow::new(recognize) } + } +} + +impl Service for Route +where T: Recognize, +{ + type Request = T::Request; + type Response = T::Response; + type Error = Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + // Checks if there is an outstanding borrow (i.e. there is an in-flight + // request that is blocked on an inner service). + // + // Borrow::poll_ready returning an error means the borrow was poisoned. + // A panic is fine. + self.recognize.poll_ready().map_err(|_| panic!()) + } + + fn call(&mut self, request: Self::Request) -> Self::Future { + let borrow = match self.recognize.try_borrow() { + Ok(borrow) => borrow, + Err(_) => { + return ResponseFuture { + state: NotReady, + } + } + }; + + let recognize = Borrow::try_map(borrow, |recognize| { + // Match the service + recognize.recognize(&request) + }); + + match recognize { + Ok(service) => { + ResponseFuture { + state: Queued { + service, + request, + }, + } + } + Err((_, err)) => { + ResponseFuture { + state: RouteError(err), + } + } + } + } +} + +// ===== impl ResponseFuture ===== + +impl Future for ResponseFuture +where T: Recognize, +{ + type Item = T::Response; + type Error = Error; + + fn poll(&mut self) -> Poll { + loop { + match self.state { + Dispatched(ref mut inner) => { + return inner.poll() + .map_err(Error::Inner); + } + Queued { ref mut service, .. } => { + let res = service.poll_ready() + .map_err(Error::Inner); + + try_ready!(res); + + // Fall through to transition state + } + _ => {} + } + + match mem::replace(&mut self.state, Invalid) { + Dispatched(..) => unreachable!(), + Queued { mut service, request } => { + let response = service.call(request); + self.state = Dispatched(response); + } + RouteError(err) => { + return Err(Error::Route(err)); + } + NotReady => { + return Err(Error::NotReady); + } + Invalid => panic!(), + } + } + } +} diff --git a/tower-route/tests/route.rs b/tower-route/tests/route.rs new file mode 100644 index 0000000..349ed74 --- /dev/null +++ b/tower-route/tests/route.rs @@ -0,0 +1,199 @@ +extern crate futures; +extern crate tower; +extern crate tower_route; + +use tower::Service; +use tower_route::*; + +use futures::*; +use futures::future::FutureResult; +use futures::executor::TestHarness; + +use std::collections::HashMap; + +macro_rules! assert_ready { + ($service:expr) => {{ + let s = $service; + let mut t = TestHarness::new(future::poll_fn(|| s.poll_ready())); + assert!(t.poll().unwrap().is_ready()); + }}; +} + +macro_rules! assert_not_ready { + ($service:expr) => {{ + let s = $service; + let mut t = TestHarness::new(future::poll_fn(|| s.poll_ready())); + assert!(!t.poll().unwrap().is_ready()); + }}; +} + +#[test] +fn basic_routing() { + let mut recognize = MapRecognize::new(); + recognize.map.insert("one".into(), StringService::ok("hello")); + recognize.map.insert("two".into(), StringService::ok("world")); + + let mut service = Route::new(recognize); + + // Router is ready by default + assert_ready!(&mut service); + + let resp = service.call("one".into()); + + assert_not_ready!(&mut service); + assert_eq!(resp.wait().unwrap(), "hello"); + + // Router ready again + assert_ready!(&mut service); + + let resp = service.call("two".into()); + assert_eq!(resp.wait().unwrap(), "world"); + + // Try invalid routing + let resp = service.call("three".into()); + assert!(resp.wait().is_err()); +} + +#[test] +fn inner_service_err() { + let mut recognize = MapRecognize::new(); + recognize.map.insert("one".into(), StringService::ok("hello")); + recognize.map.insert("two".into(), StringService::err()); + + let mut service = Route::new(recognize); + + let resp = service.call("two".into()); + assert!(resp.wait().is_err()); + + assert_ready!(&mut service); + + let resp = service.call("one".into()); + assert_eq!(resp.wait().unwrap(), "hello"); +} + +#[test] +fn inner_service_not_ready() { + let mut recognize = MapRecognize::new(); + recognize.map.insert("one".into(), MaybeService::new("hello")); + recognize.map.insert("two".into(), MaybeService::none()); + + let mut service = Route::new(recognize); + + let resp = service.call("two".into()); + let mut resp = TestHarness::new(resp); + assert!(!resp.poll().unwrap().is_ready()); + + assert_not_ready!(&mut service); + + let resp = service.call("one".into()); + assert!(resp.wait().is_err()); +} + +// ===== impl MapRecognize ===== + +#[derive(Debug)] +struct MapRecognize { + map: HashMap, +} + +impl MapRecognize { + fn new() -> Self { + MapRecognize { map: HashMap::new() } + } +} + +impl Recognize for MapRecognize +where T: Service + 'static, +{ + type Request = String; + type Response = String; + type Error = (); + type RouteError = (); + type Service = T; + + fn recognize(&mut self, request: &Self::Request) + -> Result<&mut Self::Service, Self::RouteError> + { + match self.map.get_mut(request) { + Some(service) => Ok(service), + None => Err(()), + } + } +} + +// ===== impl services ===== + +#[derive(Debug)] +struct StringService { + string: Result, +} + +impl StringService { + pub fn ok(string: &str) -> Self { + StringService { + string: Ok(string.into()), + } + } + + pub fn err() -> Self { + StringService { + string: Err(()), + } + } +} + +impl Service for StringService { + type Request = String; + type Response = String; + type Error = (); + type Future = FutureResult; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + Ok(Async::Ready(())) + } + + fn call(&mut self, _: Self::Request) -> Self::Future { + future::result(self.string.clone()) + } +} + +#[derive(Debug)] +struct MaybeService { + string: Option, +} + +impl MaybeService { + pub fn new(string: &str) -> Self { + MaybeService { + string: Some(string.into()), + } + } + + pub fn none() -> Self { + MaybeService { + string: None, + } + } +} + +impl Service for MaybeService { + type Request = String; + type Response = String; + type Error = (); + type Future = FutureResult; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + if self.string.is_some() { + Ok(Async::Ready(())) + } else { + Ok(Async::NotReady) + } + } + + fn call(&mut self, _: Self::Request) -> Self::Future { + match self.string.clone() { + Some(string) => future::ok(string), + None => future::err(()), + } + } +} diff --git a/tower-timeout/Cargo.toml b/tower-timeout/Cargo.toml new file mode 100644 index 0000000..a06472f --- /dev/null +++ b/tower-timeout/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "tower-timeout" +version = "0.1.0" +authors = ["Carl Lerche "] + +[dependencies] +futures = "0.1" +tower = { version = "0.1", path = "../" } +tokio-timer = "0.1" diff --git a/tower-timeout/README.md b/tower-timeout/README.md new file mode 100644 index 0000000..916b10a --- /dev/null +++ b/tower-timeout/README.md @@ -0,0 +1,11 @@ +Tower Timeout + +A Tower middleware that applies a timeout to requests. + +# License + +`tower-timeout` is primarily distributed under the terms of both the MIT license +and the Apache License (Version 2.0), with portions covered by various BSD-like +licenses. + +See LICENSE-APACHE, and LICENSE-MIT for details. diff --git a/tower-timeout/src/lib.rs b/tower-timeout/src/lib.rs new file mode 100644 index 0000000..c3b18dd --- /dev/null +++ b/tower-timeout/src/lib.rs @@ -0,0 +1,96 @@ +//! Tower middleware that applies a timeout to requests. +//! +//! If the response does not complete within the specified timeout, the response +//! will be aborted. + +extern crate futures; +extern crate tower; +extern crate tokio_timer; + +use futures::{Future, Poll, Async}; +use tower::Service; +use tokio_timer::{Timer, Sleep}; +use std::time::Duration; + +/// Applies a timeout to requests. +#[derive(Debug)] +pub struct Timeout { + upstream: T, + timer: Timer, + timeout: Duration, +} + +/// Errors produced by `Timeout`. +#[derive(Debug)] +pub enum Error { + /// The inner service produced an error + Inner(T), + + /// The request did not complete within the specified timeout. + Timeout, +} + +/// `Timeout` response future +#[derive(Debug)] +pub struct ResponseFuture { + response: T, + sleep: Sleep, +} + +// ===== impl Timeout ===== + +impl Timeout { + pub fn new(upstream: T, timer: Timer, timeout: Duration) -> Self { + Timeout { + upstream, + timer, + timeout, + } + } +} + +impl Service for Timeout +where S: Service, +{ + type Request = S::Request; + type Response = S::Response; + type Error = Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.upstream.poll_ready() + .map_err(Error::Inner) + } + + fn call(&mut self, request: Self::Request) -> Self::Future { + ResponseFuture { + response: self.upstream.call(request), + sleep: self.timer.sleep(self.timeout), + } + } +} + +// ===== impl ResponseFuture ===== + +impl Future for ResponseFuture +where T: Future, +{ + type Item = T::Item; + type Error = Error; + + fn poll(&mut self) -> Poll { + // First, try polling the future + match self.response.poll() { + Ok(Async::Ready(v)) => return Ok(Async::Ready(v)), + Ok(Async::NotReady) => {} + Err(e) => return Err(Error::Inner(e)), + } + + // Now check the sleep + match self.sleep.poll() { + Ok(Async::NotReady) => Ok(Async::NotReady), + Ok(Async::Ready(_)) => Err(Error::Timeout), + Err(_) => Err(Error::Timeout), + } + } +}