use futures_util::ready; use pin_project::{pin_project, project}; use std::time::Duration; use std::{ future::Future, pin::Pin, task::{Context, Poll}, }; use tower_service::Service; /// A policy which specifies how long each request should be delayed for. pub trait Policy { fn delay(&self, req: &Request) -> Duration; } /// A middleware which delays sending the request to the underlying service /// for an amount of time specified by the policy. #[derive(Debug)] pub struct Delay { policy: P, service: S, } #[pin_project] #[derive(Debug)] pub struct ResponseFuture { service: S, #[pin] state: State, } #[pin_project] #[derive(Debug)] enum State { Delaying(#[pin] tokio::time::Delay, Option), Called(#[pin] F), } impl Delay { pub fn new(policy: P, service: S) -> Self where P: Policy, S: Service + Clone, S::Error: Into, { Delay { policy, service } } } impl Service for Delay where P: Policy, S: Service + Clone, S::Error: Into, { type Response = S::Response; type Error = super::Error; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.service.poll_ready(cx).map_err(|e| e.into()) } fn call(&mut self, request: Request) -> Self::Future { let deadline = tokio::time::Instant::now() + self.policy.delay(&request); let mut cloned = self.service.clone(); // Pass the original service to the ResponseFuture and keep the cloned service on self. let orig = { std::mem::swap(&mut cloned, &mut self.service); cloned }; ResponseFuture { service: orig, state: State::Delaying(tokio::time::delay_until(deadline), Some(request)), } } } impl Future for ResponseFuture where F: Future>, E: Into, S: Service, { type Output = Result; #[project] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut this = self.project(); loop { #[project] match this.state.as_mut().project() { State::Delaying(delay, req) => { ready!(delay.poll(cx)); let req = req.take().expect("Missing request in delay"); let fut = this.service.call(req); this.state.set(State::Called(fut)); } State::Called(fut) => { return fut.poll(cx).map_err(Into::into); } }; } } }