rate-limit: Refresh layout (#189)

- Switch to `Box<Error>`
- Break up lib.rs into multiple files.
- Use `tokio::clock::now` instead of `Instant::now`.
This commit is contained in:
Carl Lerche 2019-03-08 22:44:48 -08:00 committed by GitHub
parent 20102a647b
commit 720d31c65f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 128 additions and 121 deletions

View File

@ -0,0 +1,18 @@
use std::error;
pub(crate) type Error = Box<error::Error + Send + Sync>;
pub(crate) mod never {
use std::{error, fmt};
#[derive(Debug)]
pub enum Never {}
impl fmt::Display for Never {
fn fmt(&self, _: &mut fmt::Formatter) -> fmt::Result {
unreachable!();
}
}
impl error::Error for Never {}
}

View File

@ -0,0 +1,26 @@
use crate::Error;
use futures::{Future, Poll};
#[derive(Debug)]
pub struct ResponseFuture<T> {
inner: T,
}
impl<T> ResponseFuture<T> {
pub(crate) fn new(inner: T) -> ResponseFuture<T> {
ResponseFuture { inner }
}
}
impl<T> Future for ResponseFuture<T>
where
T: Future,
Error: From<T::Error>,
{
type Item = T::Item;
type Error = Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
self.inner.poll().map_err(Into::into)
}
}

View File

@ -0,0 +1,32 @@
use crate::error::{never::Never, Error};
use crate::{Rate, RateLimit};
use std::time::Duration;
use tower_layer::Layer;
use tower_service::Service;
#[derive(Debug)]
pub struct RateLimitLayer {
rate: Rate,
}
impl RateLimitLayer {
pub fn new(num: u64, per: Duration) -> Self {
let rate = Rate::new(num, per);
RateLimitLayer { rate }
}
}
impl<S, Request> Layer<S, Request> for RateLimitLayer
where
S: Service<Request>,
Error: From<S::Error>,
{
type Response = S::Response;
type Error = Error;
type LayerError = Never;
type Service = RateLimit<S>;
fn layer(&self, service: S) -> Result<Self::Service, Self::LayerError> {
Ok(RateLimit::new(service, self.rate))
}
}

View File

@ -7,13 +7,21 @@ extern crate tokio_timer;
extern crate tower_layer; extern crate tower_layer;
extern crate tower_service; extern crate tower_service;
pub mod error;
pub mod future;
mod layer;
mod rate;
pub use crate::layer::RateLimitLayer;
pub use crate::rate::Rate;
use crate::error::Error;
use crate::future::ResponseFuture;
use futures::{Future, Poll}; use futures::{Future, Poll};
use tokio_timer::Delay; use tokio_timer::{clock, Delay};
use tower_layer::Layer;
use tower_service::Service; use tower_service::Service;
use std::time::{Duration, Instant}; use std::time::Instant;
use std::{error, fmt};
#[derive(Debug)] #[derive(Debug)]
pub struct RateLimit<T> { pub struct RateLimit<T> {
@ -22,30 +30,6 @@ pub struct RateLimit<T> {
state: State, state: State,
} }
#[derive(Debug)]
pub struct RateLimitLayer {
rate: Rate,
}
#[derive(Debug, Copy, Clone)]
pub struct Rate {
num: u64,
per: Duration,
}
/// The request has been rate limited
///
/// TODO: Consider returning the original request
#[derive(Debug)]
pub enum Error<T> {
RateLimit,
Upstream(T),
}
pub struct ResponseFuture<T> {
inner: T,
}
#[derive(Debug)] #[derive(Debug)]
enum State { enum State {
// The service has hit its limit // The service has hit its limit
@ -53,27 +37,6 @@ enum State {
Ready { until: Instant, rem: u64 }, Ready { until: Instant, rem: u64 },
} }
impl RateLimitLayer {
pub fn new(num: u64, per: Duration) -> Self {
let rate = Rate { num, per };
RateLimitLayer { rate }
}
}
impl<S, Request> Layer<S, Request> for RateLimitLayer
where
S: Service<Request>,
{
type Response = S::Response;
type Error = Error<S::Error>;
type LayerError = ();
type Service = RateLimit<S>;
fn layer(&self, service: S) -> Result<Self::Service, Self::LayerError> {
Ok(RateLimit::new(service, self.rate))
}
}
impl<T> RateLimit<T> { impl<T> RateLimit<T> {
/// Create a new rate limiter /// Create a new rate limiter
pub fn new<Request>(inner: T, rate: Rate) -> Self pub fn new<Request>(inner: T, rate: Rate) -> Self
@ -81,8 +44,8 @@ impl<T> RateLimit<T> {
T: Service<Request>, T: Service<Request>,
{ {
let state = State::Ready { let state = State::Ready {
until: Instant::now(), until: clock::now(),
rem: rate.num, rem: rate.num(),
}; };
RateLimit { RateLimit {
@ -108,41 +71,26 @@ impl<T> RateLimit<T> {
} }
} }
impl Rate {
/// Create a new rate
///
/// # Panics
///
/// This function panics if `num` or `per` is 0.
pub fn new(num: u64, per: Duration) -> Self {
assert!(num > 0);
assert!(per > Duration::from_millis(0));
Rate { num, per }
}
}
impl<S, Request> Service<Request> for RateLimit<S> impl<S, Request> Service<Request> for RateLimit<S>
where where
S: Service<Request>, S: Service<Request>,
Error: From<S::Error>,
{ {
type Response = S::Response; type Response = S::Response;
type Error = Error<S::Error>; type Error = Error;
type Future = ResponseFuture<S::Future>; type Future = ResponseFuture<S::Future>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> { fn poll_ready(&mut self) -> Poll<(), Self::Error> {
match self.state { match self.state {
State::Ready { .. } => return Ok(().into()), State::Ready { .. } => return Ok(().into()),
State::Limited(ref mut sleep) => { State::Limited(ref mut sleep) => {
let res = sleep.poll().map_err(|_| Error::RateLimit); try_ready!(sleep.poll());
try_ready!(res);
} }
} }
self.state = State::Ready { self.state = State::Ready {
until: Instant::now() + self.rate.per, until: clock::now() + self.rate.per(),
rem: self.rate.num, rem: self.rate.num(),
}; };
Ok(().into()) Ok(().into())
@ -151,12 +99,12 @@ where
fn call(&mut self, request: Request) -> Self::Future { fn call(&mut self, request: Request) -> Self::Future {
match self.state { match self.state {
State::Ready { mut until, mut rem } => { State::Ready { mut until, mut rem } => {
let now = Instant::now(); let now = clock::now();
// If the period has elapsed, reset it. // If the period has elapsed, reset it.
if now >= until { if now >= until {
until = now + self.rate.per; until = now + self.rate.per();
let rem = self.rate.num; let rem = self.rate.num();
self.state = State::Ready { until, rem } self.state = State::Ready { until, rem }
} }
@ -172,55 +120,9 @@ where
// Call the inner future // Call the inner future
let inner = self.inner.call(request); let inner = self.inner.call(request);
ResponseFuture { inner } ResponseFuture::new(inner)
} }
State::Limited(..) => panic!("service not ready; poll_ready must be called first"), State::Limited(..) => panic!("service not ready; poll_ready must be called first"),
} }
} }
} }
impl<T> Future for ResponseFuture<T>
where
T: Future,
{
type Item = T::Item;
type Error = Error<T::Error>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
self.inner.poll().map_err(Error::Upstream)
}
}
// ===== impl Error =====
impl<T> fmt::Display for Error<T>
where
T: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Error::Upstream(ref why) => fmt::Display::fmt(why, f),
Error::RateLimit => f.pad("rate limit exceeded"),
}
}
}
impl<T> error::Error for Error<T>
where
T: error::Error,
{
fn cause(&self) -> Option<&error::Error> {
if let Error::Upstream(ref why) = *self {
Some(why)
} else {
None
}
}
fn description(&self) -> &str {
match *self {
Error::Upstream(_) => "upstream service error",
Error::RateLimit => "rate limit exceeded",
}
}
}

View File

@ -0,0 +1,29 @@
use std::time::Duration;
#[derive(Debug, Copy, Clone)]
pub struct Rate {
num: u64,
per: Duration,
}
impl Rate {
/// Create a new rate
///
/// # Panics
///
/// This function panics if `num` or `per` is 0.
pub fn new(num: u64, per: Duration) -> Self {
assert!(num > 0);
assert!(per > Duration::from_millis(0));
Rate { num, per }
}
pub(crate) fn num(&self) -> u64 {
self.num
}
pub(crate) fn per(&self) -> Duration {
self.per
}
}