use pin_project::pin_project; use std::{ future::Future, pin::Pin, task::{Context, Poll}, }; use tower_service::Service; /// A policy which decides which requests can be cloned and sent to the B /// service. pub trait Policy { fn clone_request(&self, req: &Request) -> Option; } /// Select is a middleware which attempts to clone the request and sends the /// original request to the A service and, if the request was able to be cloned, /// the cloned request to the B service. Both resulting futures will be polled /// and whichever future completes first will be used as the result. #[derive(Debug)] pub struct Select { policy: P, a: A, b: B, } #[pin_project] #[derive(Debug)] pub struct ResponseFuture { #[pin] a_fut: AF, #[pin] b_fut: Option, } impl Select { pub fn new(policy: P, a: A, b: B) -> Self where P: Policy, A: Service, A::Error: Into, B: Service, B::Error: Into, { Select { policy, a, b } } } impl Service for Select where P: Policy, A: Service, A::Error: Into, B: Service, B::Error: Into, { type Response = A::Response; type Error = super::Error; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { match (self.a.poll_ready(cx), self.b.poll_ready(cx)) { (Poll::Ready(Ok(())), Poll::Ready(Ok(()))) => Poll::Ready(Ok(())), (Poll::Ready(Err(e)), _) => Poll::Ready(Err(e.into())), (_, Poll::Ready(Err(e))) => Poll::Ready(Err(e.into())), _ => Poll::Pending, } } fn call(&mut self, request: Request) -> Self::Future { let b_fut = if let Some(cloned_req) = self.policy.clone_request(&request) { Some(self.b.call(cloned_req)) } else { None }; ResponseFuture { a_fut: self.a.call(request), b_fut, } } } impl Future for ResponseFuture where AF: Future>, AE: Into, BF: Future>, BE: Into, { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); if let Poll::Ready(r) = this.a_fut.poll(cx) { return Poll::Ready(Ok(r.map_err(Into::into)?)); } if let Some(b_fut) = this.b_fut.as_pin_mut() { if let Poll::Ready(r) = b_fut.poll(cx) { return Poll::Ready(Ok(r.map_err(Into::into)?)); } } return Poll::Pending; } }