tower-fallback: add implementation.
This commit is contained in:
parent
a19fdd9f25
commit
4be0a8edc3
|
@ -489,6 +489,12 @@ dependencies = [
|
|||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3"
|
||||
|
||||
[[package]]
|
||||
name = "equihash"
|
||||
version = "0.1.0"
|
||||
|
@ -2144,6 +2150,19 @@ dependencies = [
|
|||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-fallback"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"either",
|
||||
"futures-core",
|
||||
"pin-project",
|
||||
"tokio",
|
||||
"tower",
|
||||
"tracing",
|
||||
"zebra-test",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-layer"
|
||||
version = "0.3.0"
|
||||
|
|
|
@ -11,6 +11,7 @@ members = [
|
|||
"zebra-test",
|
||||
"zebra-utils",
|
||||
"tower-batch",
|
||||
"tower-fallback",
|
||||
]
|
||||
|
||||
[profile.dev]
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
[package]
|
||||
name = "tower-fallback"
|
||||
version = "0.1.0"
|
||||
authors = ["Zcash Foundation <zebra@zfnd.org>"]
|
||||
license = "MIT"
|
||||
edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
either = "1.5"
|
||||
tower = "0.3"
|
||||
futures-core = "0.3.5"
|
||||
pin-project = "0.4.20"
|
||||
tracing = "0.1"
|
||||
|
||||
[dev-dependencies]
|
||||
zebra-test = { path = "../zebra-test/" }
|
||||
tokio = { version = "0.2", features = ["full"]}
|
|
@ -0,0 +1,158 @@
|
|||
//! Future types for the `Fallback` middleware.
|
||||
|
||||
use std::{
|
||||
fmt::Debug,
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use either::Either;
|
||||
use futures_core::ready;
|
||||
use pin_project::pin_project;
|
||||
use tower::Service;
|
||||
|
||||
/// Future that completes either with the first service's successful response, or
|
||||
/// with the second service's response.
|
||||
#[pin_project]
|
||||
pub struct ResponseFuture<S1, S2, Request>
|
||||
where
|
||||
S1: Service<Request>,
|
||||
S2: Service<Request, Response = <S1 as Service<Request>>::Response>,
|
||||
{
|
||||
#[pin]
|
||||
state: ResponseState<S1, S2, Request>,
|
||||
}
|
||||
|
||||
#[pin_project(project_replace, project = ResponseStateProj)]
|
||||
enum ResponseState<S1, S2, Request>
|
||||
where
|
||||
S1: Service<Request>,
|
||||
S2: Service<Request>,
|
||||
{
|
||||
PollResponse1 {
|
||||
#[pin]
|
||||
fut: S1::Future,
|
||||
req: Request,
|
||||
svc2: S2,
|
||||
},
|
||||
PollReady2 {
|
||||
req: Request,
|
||||
svc2: S2,
|
||||
},
|
||||
PollResponse2 {
|
||||
#[pin]
|
||||
fut: S2::Future,
|
||||
},
|
||||
// Placeholder value to swap into the pin projection of the enum so we can take ownership of the fields.
|
||||
Tmp,
|
||||
}
|
||||
|
||||
impl<S1, S2, Request> ResponseFuture<S1, S2, Request>
|
||||
where
|
||||
S1: Service<Request>,
|
||||
S2: Service<Request, Response = <S1 as Service<Request>>::Response>,
|
||||
{
|
||||
pub(crate) fn new(fut: S1::Future, req: Request, svc2: S2) -> Self {
|
||||
ResponseFuture {
|
||||
state: ResponseState::PollResponse1 { fut, req, svc2 },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S1, S2, Request> Future for ResponseFuture<S1, S2, Request>
|
||||
where
|
||||
S1: Service<Request>,
|
||||
S2: Service<Request, Response = <S1 as Service<Request>>::Response>,
|
||||
{
|
||||
type Output = Result<
|
||||
<S1 as Service<Request>>::Response,
|
||||
Either<<S1 as Service<Request>>::Error, <S2 as Service<Request>>::Error>,
|
||||
>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut this = self.project();
|
||||
loop {
|
||||
match this.state.as_mut().project() {
|
||||
ResponseStateProj::PollResponse1 { fut, .. } => match ready!(fut.poll(cx)) {
|
||||
Ok(rsp) => return Poll::Ready(Ok(rsp)),
|
||||
Err(_) => {
|
||||
tracing::debug!("got error from svc1, retrying on svc2");
|
||||
if let __ResponseStateProjectionOwned::PollResponse1 { req, svc2, .. } =
|
||||
this.state.as_mut().project_replace(ResponseState::Tmp)
|
||||
{
|
||||
this.state.set(ResponseState::PollReady2 { req, svc2 });
|
||||
} else {
|
||||
unreachable!();
|
||||
}
|
||||
}
|
||||
},
|
||||
ResponseStateProj::PollReady2 { svc2, .. } => match ready!(svc2.poll_ready(cx)) {
|
||||
Err(e) => return Poll::Ready(Err(Either::Right(e))),
|
||||
Ok(()) => {
|
||||
if let __ResponseStateProjectionOwned::PollReady2 { mut svc2, req } =
|
||||
this.state.as_mut().project_replace(ResponseState::Tmp)
|
||||
{
|
||||
this.state.set(ResponseState::PollResponse2 {
|
||||
fut: svc2.call(req),
|
||||
});
|
||||
} else {
|
||||
unreachable!();
|
||||
}
|
||||
}
|
||||
},
|
||||
ResponseStateProj::PollResponse2 { fut } => {
|
||||
return fut.poll(cx).map_err(Either::Right)
|
||||
}
|
||||
ResponseStateProj::Tmp => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S1, S2, Request> Debug for ResponseFuture<S1, S2, Request>
|
||||
where
|
||||
S1: Service<Request>,
|
||||
S2: Service<Request, Response = <S1 as Service<Request>>::Response>,
|
||||
Request: Debug,
|
||||
S1::Future: Debug,
|
||||
S2: Debug,
|
||||
S2::Future: Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ResponseFuture")
|
||||
.field("state", &self.state)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S1, S2, Request> Debug for ResponseState<S1, S2, Request>
|
||||
where
|
||||
S1: Service<Request>,
|
||||
S2: Service<Request, Response = <S1 as Service<Request>>::Response>,
|
||||
Request: Debug,
|
||||
S1::Future: Debug,
|
||||
S2: Debug,
|
||||
S2::Future: Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ResponseState::PollResponse1 { fut, req, svc2 } => f
|
||||
.debug_struct("ResponseState::PollResponse1")
|
||||
.field("fut", fut)
|
||||
.field("req", req)
|
||||
.field("svc2", svc2)
|
||||
.finish(),
|
||||
ResponseState::PollReady2 { req, svc2 } => f
|
||||
.debug_struct("ResponseState::PollReady2")
|
||||
.field("req", req)
|
||||
.field("svc2", svc2)
|
||||
.finish(),
|
||||
ResponseState::PollResponse2 { fut } => f
|
||||
.debug_struct("ResponseState::PollResponse2")
|
||||
.field("fut", fut)
|
||||
.finish(),
|
||||
ResponseState::Tmp => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
/// A service combinator that sends requests to a first service, then retries
|
||||
/// processing on a second fallback service if the first service errors.
|
||||
///
|
||||
/// TODO: similar code exists in linkerd and could be upstreamed into tower
|
||||
pub mod future;
|
||||
mod service;
|
||||
|
||||
pub use self::service::Fallback;
|
||||
pub use either::Either;
|
|
@ -0,0 +1,54 @@
|
|||
use super::future::ResponseFuture;
|
||||
|
||||
use either::Either;
|
||||
use std::task::{Context, Poll};
|
||||
use tower::Service;
|
||||
|
||||
/// Provides fallback processing on a second service if the first service returned an error.
|
||||
#[derive(Debug)]
|
||||
pub struct Fallback<S1, S2>
|
||||
where
|
||||
S2: Clone,
|
||||
{
|
||||
svc1: S1,
|
||||
svc2: S2,
|
||||
}
|
||||
|
||||
impl<S1: Clone, S2: Clone> Clone for Fallback<S1, S2> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
svc1: self.svc1.clone(),
|
||||
svc2: self.svc2.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S1, S2: Clone> Fallback<S1, S2> {
|
||||
/// Creates a new `Fallback` wrapping a pair of services.
|
||||
///
|
||||
/// Requests are processed on `svc1`, and retried on `svc2` if `svc1` errored.
|
||||
pub fn new(svc1: S1, svc2: S2) -> Self {
|
||||
Self { svc1, svc2 }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S1, S2, Request> Service<Request> for Fallback<S1, S2>
|
||||
where
|
||||
S1: Service<Request>,
|
||||
S2: Service<Request, Response = <S1 as Service<Request>>::Response>,
|
||||
S2: Clone,
|
||||
Request: Clone,
|
||||
{
|
||||
type Response = <S1 as Service<Request>>::Response;
|
||||
type Error = Either<<S1 as Service<Request>>::Error, <S2 as Service<Request>>::Error>;
|
||||
type Future = ResponseFuture<S1, S2, Request>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.svc1.poll_ready(cx).map_err(Either::Left)
|
||||
}
|
||||
|
||||
fn call(&mut self, request: Request) -> Self::Future {
|
||||
let request2 = request.clone();
|
||||
ResponseFuture::new(self.svc1.call(request), request2, self.svc2.clone())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
use tower::{service_fn, Service, ServiceExt};
|
||||
use tower_fallback::{Either, Fallback};
|
||||
|
||||
#[tokio::test]
|
||||
async fn fallback() {
|
||||
zebra_test::init();
|
||||
|
||||
// we'd like to use Transcript here but it can't handle errors :(
|
||||
|
||||
let svc1 = service_fn(|val: u64| async move {
|
||||
if val < 10 {
|
||||
Ok(val)
|
||||
} else {
|
||||
Err("too big value on svc1")
|
||||
}
|
||||
});
|
||||
let svc2 = service_fn(|val: u64| async move {
|
||||
if val < 20 {
|
||||
Ok(100 + val)
|
||||
} else {
|
||||
Err("too big value on svc2")
|
||||
}
|
||||
});
|
||||
|
||||
let mut svc = Fallback::new(svc1, svc2);
|
||||
|
||||
assert_eq!(svc.ready_and().await.unwrap().call(1).await, Ok(1));
|
||||
assert_eq!(svc.ready_and().await.unwrap().call(11).await, Ok(111));
|
||||
assert_eq!(
|
||||
svc.ready_and().await.unwrap().call(21).await,
|
||||
Err(Either::Right("too big value on svc2"))
|
||||
);
|
||||
}
|
Loading…
Reference in New Issue