Ensure received client request oneshots are used via the type system
The `peer::Client` translates `Request`s into `ClientRequest`s, which it sends to a background task. If the send is `Ok(())`, it will assume that it is safe to unconditionally poll the `Receiver` tied to the `Sender` used to create the `ClientRequest`. We enforce this invariant via the type system, by converting `ClientRequest`s to `InProgressClientRequest`s when they are received by the background task. These conversions are implemented by `ClientRequestReceiver`. Changes: * Revert `ClientRequest` so it uses a `oneshot::Sender` * Add `InProgressClientRequest`, which is the same as `ClientRequest`, but has a `MustUseOneshotSender` * `impl From<ClientRequest> for InProgressClientRequest` * Add a new `ClientRequestReceiver` type that wraps a `mpsc::Receiver<ClientRequest>` * `impl Stream<InProgressClientRequest> for ClientRequestReceiver`, converting the successful result of `inner.poll_next_unpin` into an `InProgressClientRequest` * Replace `client_rx: mpsc::Receiver<ClientRequest>` in `Connection` with the new `ClientRequestReceiver` type * `impl From<mpsc::Receiver<ClientRequest>> for ClientRequestReceiver`
This commit is contained in:
parent
df1b0c8d58
commit
6d3aa0002c
|
@ -12,6 +12,8 @@ mod error;
|
||||||
mod handshake;
|
mod handshake;
|
||||||
|
|
||||||
use client::ClientRequest;
|
use client::ClientRequest;
|
||||||
|
use client::ClientRequestReceiver;
|
||||||
|
use client::InProgressClientRequest;
|
||||||
use client::MustUseOneshotSender;
|
use client::MustUseOneshotSender;
|
||||||
use error::ErrorSlot;
|
use error::ErrorSlot;
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ use std::{
|
||||||
use futures::{
|
use futures::{
|
||||||
channel::{mpsc, oneshot},
|
channel::{mpsc, oneshot},
|
||||||
future, ready,
|
future, ready,
|
||||||
|
stream::{Stream, StreamExt},
|
||||||
};
|
};
|
||||||
use tower::Service;
|
use tower::Service;
|
||||||
|
|
||||||
|
@ -25,8 +26,32 @@ pub struct Client {
|
||||||
|
|
||||||
/// A message from the `peer::Client` to the `peer::Server`.
|
/// A message from the `peer::Client` to the `peer::Server`.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
#[must_use = "tx.send() must be called before drop"]
|
|
||||||
pub(super) struct ClientRequest {
|
pub(super) struct ClientRequest {
|
||||||
|
/// The actual request.
|
||||||
|
pub request: Request,
|
||||||
|
/// The return message channel, included because `peer::Client::call` returns a
|
||||||
|
/// future that may be moved around before it resolves.
|
||||||
|
pub tx: oneshot::Sender<Result<Response, SharedPeerError>>,
|
||||||
|
/// The tracing context for the request, so that work the connection task does
|
||||||
|
/// processing messages in the context of this request will have correct context.
|
||||||
|
pub span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A receiver for the `peer::Server`, which wraps a `mpsc::Receiver`,
|
||||||
|
/// converting `ClientRequest`s into `InProgressClientRequest`s.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(super) struct ClientRequestReceiver {
|
||||||
|
/// The inner receiver
|
||||||
|
inner: mpsc::Receiver<ClientRequest>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A message from the `peer::Client` to the `peer::Server`,
|
||||||
|
/// after it has been received by the `peer::Server`.
|
||||||
|
///
|
||||||
|
///
|
||||||
|
#[derive(Debug)]
|
||||||
|
#[must_use = "tx.send() must be called before drop"]
|
||||||
|
pub(super) struct InProgressClientRequest {
|
||||||
/// The actual request.
|
/// The actual request.
|
||||||
pub request: Request,
|
pub request: Request,
|
||||||
/// The return message channel, included because `peer::Client::call` returns a
|
/// The return message channel, included because `peer::Client::call` returns a
|
||||||
|
@ -34,7 +59,15 @@ pub(super) struct ClientRequest {
|
||||||
///
|
///
|
||||||
/// INVARIANT: `tx.send()` must be called before dropping `tx`.
|
/// INVARIANT: `tx.send()` must be called before dropping `tx`.
|
||||||
///
|
///
|
||||||
/// JUSTIFICATION: the `peer::Client` will translate all `Request`s into a `ClientRequest` which it sends to a background task, and if the send replies with `Ok(())` it will assume that it is safe to unconditionally poll the `Receiver` tied to the `Sender` used to create the `ClientRequest`.
|
/// JUSTIFICATION: the `peer::Client` translates `Request`s into
|
||||||
|
/// `ClientRequest`s, which it sends to a background task. If the send is
|
||||||
|
/// `Ok(())`, it will assume that it is safe to unconditionally poll the
|
||||||
|
/// `Receiver` tied to the `Sender` used to create the `ClientRequest`.
|
||||||
|
///
|
||||||
|
/// We enforce this invariant via the type system, by converting
|
||||||
|
/// `ClientRequest`s to `InProgressClientRequest`s when they are received by
|
||||||
|
/// the background task. These conversions are implemented by
|
||||||
|
/// `ClientRequestReceiver`.
|
||||||
pub tx: MustUseOneshotSender<Result<Response, SharedPeerError>>,
|
pub tx: MustUseOneshotSender<Result<Response, SharedPeerError>>,
|
||||||
/// The tracing context for the request, so that work the connection task does
|
/// The tracing context for the request, so that work the connection task does
|
||||||
/// processing messages in the context of this request will have correct context.
|
/// processing messages in the context of this request will have correct context.
|
||||||
|
@ -54,6 +87,49 @@ pub(super) struct MustUseOneshotSender<T: std::fmt::Debug> {
|
||||||
pub tx: Option<oneshot::Sender<T>>,
|
pub tx: Option<oneshot::Sender<T>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<ClientRequest> for InProgressClientRequest {
|
||||||
|
fn from(client_request: ClientRequest) -> Self {
|
||||||
|
let ClientRequest { request, tx, span } = client_request;
|
||||||
|
InProgressClientRequest {
|
||||||
|
request,
|
||||||
|
tx: tx.into(),
|
||||||
|
span,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClientRequestReceiver {
|
||||||
|
/// Forwards to `inner.close()`
|
||||||
|
pub fn close(&mut self) {
|
||||||
|
self.inner.close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Stream for ClientRequestReceiver {
|
||||||
|
type Item = InProgressClientRequest;
|
||||||
|
|
||||||
|
/// Converts the successful result of `inner.poll_next()` to an
|
||||||
|
/// `InProgressClientRequest`.
|
||||||
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
match self.inner.poll_next_unpin(cx) {
|
||||||
|
Poll::Ready(client_request) => Poll::Ready(client_request.map(Into::into)),
|
||||||
|
// `inner.poll_next_unpin` parks the task for this future
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns `inner.size_hint()`
|
||||||
|
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||||
|
self.inner.size_hint()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<mpsc::Receiver<ClientRequest>> for ClientRequestReceiver {
|
||||||
|
fn from(rx: mpsc::Receiver<ClientRequest>) -> Self {
|
||||||
|
ClientRequestReceiver { inner: rx }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: std::fmt::Debug> MustUseOneshotSender<T> {
|
impl<T: std::fmt::Debug> MustUseOneshotSender<T> {
|
||||||
/// Forwards `t` to `tx.send()`, and marks this sender as used.
|
/// Forwards `t` to `tx.send()`, and marks this sender as used.
|
||||||
///
|
///
|
||||||
|
@ -143,11 +219,7 @@ impl Service<Request> for Client {
|
||||||
// request.
|
// request.
|
||||||
let span = tracing::Span::current();
|
let span = tracing::Span::current();
|
||||||
|
|
||||||
match self.server_tx.try_send(ClientRequest {
|
match self.server_tx.try_send(ClientRequest { request, span, tx }) {
|
||||||
request,
|
|
||||||
span,
|
|
||||||
tx: tx.into(),
|
|
||||||
}) {
|
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
if e.is_disconnected() {
|
if e.is_disconnected() {
|
||||||
let ClientRequest { tx, .. } = e.into_inner();
|
let ClientRequest { tx, .. } = e.into_inner();
|
||||||
|
|
|
@ -10,7 +10,6 @@
|
||||||
use std::{collections::HashSet, sync::Arc};
|
use std::{collections::HashSet, sync::Arc};
|
||||||
|
|
||||||
use futures::{
|
use futures::{
|
||||||
channel::mpsc,
|
|
||||||
future::{self, Either},
|
future::{self, Either},
|
||||||
prelude::*,
|
prelude::*,
|
||||||
stream::Stream,
|
stream::Stream,
|
||||||
|
@ -34,7 +33,10 @@ use crate::{
|
||||||
BoxError,
|
BoxError,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{ClientRequest, ErrorSlot, MustUseOneshotSender, PeerError, SharedPeerError};
|
use super::{
|
||||||
|
ClientRequestReceiver, ErrorSlot, InProgressClientRequest, MustUseOneshotSender, PeerError,
|
||||||
|
SharedPeerError,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(super) enum Handler {
|
pub(super) enum Handler {
|
||||||
|
@ -327,7 +329,9 @@ pub struct Connection<S, Tx> {
|
||||||
/// other state handling.
|
/// other state handling.
|
||||||
pub(super) request_timer: Option<Sleep>,
|
pub(super) request_timer: Option<Sleep>,
|
||||||
pub(super) svc: S,
|
pub(super) svc: S,
|
||||||
pub(super) client_rx: mpsc::Receiver<ClientRequest>,
|
/// A `mpsc::Receiver<ClientRequest>` that converts its results to
|
||||||
|
/// `InProgressClientRequest`
|
||||||
|
pub(super) client_rx: ClientRequestReceiver,
|
||||||
/// A slot for an error shared between the Connection and the Client that uses it.
|
/// A slot for an error shared between the Connection and the Client that uses it.
|
||||||
pub(super) error_slot: ErrorSlot,
|
pub(super) error_slot: ErrorSlot,
|
||||||
//pub(super) peer_rx: Rx,
|
//pub(super) peer_rx: Rx,
|
||||||
|
@ -475,7 +479,7 @@ where
|
||||||
// requests before we can return and complete the future.
|
// requests before we can return and complete the future.
|
||||||
State::Failed => {
|
State::Failed => {
|
||||||
match self.client_rx.next().await {
|
match self.client_rx.next().await {
|
||||||
Some(ClientRequest { tx, span, .. }) => {
|
Some(InProgressClientRequest { tx, span, .. }) => {
|
||||||
trace!(
|
trace!(
|
||||||
parent: &span,
|
parent: &span,
|
||||||
"erroring pending request to failed connection"
|
"erroring pending request to failed connection"
|
||||||
|
@ -535,11 +539,11 @@ where
|
||||||
///
|
///
|
||||||
/// NOTE: the caller should use .instrument(msg.span) to instrument the function.
|
/// NOTE: the caller should use .instrument(msg.span) to instrument the function.
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
async fn handle_client_request(&mut self, req: ClientRequest) {
|
async fn handle_client_request(&mut self, req: InProgressClientRequest) {
|
||||||
trace!(?req.request);
|
trace!(?req.request);
|
||||||
use Request::*;
|
use Request::*;
|
||||||
use State::*;
|
use State::*;
|
||||||
let ClientRequest { request, tx, span } = req;
|
let InProgressClientRequest { request, tx, span } = req;
|
||||||
|
|
||||||
if tx.is_canceled() {
|
if tx.is_canceled() {
|
||||||
metrics::counter!("peer.canceled", 1);
|
metrics::counter!("peer.canceled", 1);
|
||||||
|
|
|
@ -435,7 +435,7 @@ where
|
||||||
let server = Connection {
|
let server = Connection {
|
||||||
state: connection::State::AwaitingRequest,
|
state: connection::State::AwaitingRequest,
|
||||||
svc: inbound_service,
|
svc: inbound_service,
|
||||||
client_rx: server_rx,
|
client_rx: server_rx.into(),
|
||||||
error_slot: slot,
|
error_slot: slot,
|
||||||
peer_tx,
|
peer_tx,
|
||||||
request_timer: None,
|
request_timer: None,
|
||||||
|
@ -451,7 +451,7 @@ where
|
||||||
let heartbeat_span = tracing::debug_span!(parent: connection_span, "heartbeat");
|
let heartbeat_span = tracing::debug_span!(parent: connection_span, "heartbeat");
|
||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
async move {
|
async move {
|
||||||
use super::client::ClientRequest;
|
use super::ClientRequest;
|
||||||
use futures::future::Either;
|
use futures::future::Either;
|
||||||
|
|
||||||
let mut shutdown_rx = shutdown_rx;
|
let mut shutdown_rx = shutdown_rx;
|
||||||
|
@ -466,16 +466,23 @@ where
|
||||||
tracing::trace!(?request, "queueing heartbeat request");
|
tracing::trace!(?request, "queueing heartbeat request");
|
||||||
match server_tx.try_send(ClientRequest {
|
match server_tx.try_send(ClientRequest {
|
||||||
request,
|
request,
|
||||||
tx: tx.into(),
|
tx,
|
||||||
span: tracing::Span::current(),
|
span: tracing::Span::current(),
|
||||||
}) {
|
}) {
|
||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
match server_tx.flush().await {
|
match server_tx.flush().await {
|
||||||
Ok(()) => {}
|
Ok(()) => {}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// TODO: we can't get the client request for this failure,
|
// We can't get the client request for this failure,
|
||||||
// so we can't ensure the invariant holds
|
// so we can't send an error back here. But that's ok,
|
||||||
panic!("flushing client request failed: {:?}", e);
|
// because:
|
||||||
|
// - this error never happens (or it's very rare)
|
||||||
|
// - if the flush() fails, the server hasn't
|
||||||
|
// received the request
|
||||||
|
tracing::warn!(
|
||||||
|
"flushing client request failed: {:?}",
|
||||||
|
e
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue