Ensure received client request oneshots are used via the type system
The `peer::Client` translates `Request`s into `ClientRequest`s, which it sends to a background task. If the send is `Ok(())`, it will assume that it is safe to unconditionally poll the `Receiver` tied to the `Sender` used to create the `ClientRequest`. We enforce this invariant via the type system, by converting `ClientRequest`s to `InProgressClientRequest`s when they are received by the background task. These conversions are implemented by `ClientRequestReceiver`. Changes: * Revert `ClientRequest` so it uses a `oneshot::Sender` * Add `InProgressClientRequest`, which is the same as `ClientRequest`, but has a `MustUseOneshotSender` * `impl From<ClientRequest> for InProgressClientRequest` * Add a new `ClientRequestReceiver` type that wraps a `mpsc::Receiver<ClientRequest>` * `impl Stream<InProgressClientRequest> for ClientRequestReceiver`, converting the successful result of `inner.poll_next_unpin` into an `InProgressClientRequest` * Replace `client_rx: mpsc::Receiver<ClientRequest>` in `Connection` with the new `ClientRequestReceiver` type * `impl From<mpsc::Receiver<ClientRequest>> for ClientRequestReceiver`
This commit is contained in:
parent
df1b0c8d58
commit
6d3aa0002c
|
@ -12,6 +12,8 @@ mod error;
|
|||
mod handshake;
|
||||
|
||||
use client::ClientRequest;
|
||||
use client::ClientRequestReceiver;
|
||||
use client::InProgressClientRequest;
|
||||
use client::MustUseOneshotSender;
|
||||
use error::ErrorSlot;
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ use std::{
|
|||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
future, ready,
|
||||
stream::{Stream, StreamExt},
|
||||
};
|
||||
use tower::Service;
|
||||
|
||||
|
@ -25,8 +26,32 @@ pub struct Client {
|
|||
|
||||
/// A message from the `peer::Client` to the `peer::Server`.
|
||||
#[derive(Debug)]
|
||||
#[must_use = "tx.send() must be called before drop"]
|
||||
pub(super) struct ClientRequest {
|
||||
/// The actual request.
|
||||
pub request: Request,
|
||||
/// The return message channel, included because `peer::Client::call` returns a
|
||||
/// future that may be moved around before it resolves.
|
||||
pub tx: oneshot::Sender<Result<Response, SharedPeerError>>,
|
||||
/// The tracing context for the request, so that work the connection task does
|
||||
/// processing messages in the context of this request will have correct context.
|
||||
pub span: tracing::Span,
|
||||
}
|
||||
|
||||
/// A receiver for the `peer::Server`, which wraps a `mpsc::Receiver`,
|
||||
/// converting `ClientRequest`s into `InProgressClientRequest`s.
|
||||
#[derive(Debug)]
|
||||
pub(super) struct ClientRequestReceiver {
|
||||
/// The inner receiver
|
||||
inner: mpsc::Receiver<ClientRequest>,
|
||||
}
|
||||
|
||||
/// A message from the `peer::Client` to the `peer::Server`,
|
||||
/// after it has been received by the `peer::Server`.
|
||||
///
|
||||
///
|
||||
#[derive(Debug)]
|
||||
#[must_use = "tx.send() must be called before drop"]
|
||||
pub(super) struct InProgressClientRequest {
|
||||
/// The actual request.
|
||||
pub request: Request,
|
||||
/// The return message channel, included because `peer::Client::call` returns a
|
||||
|
@ -34,7 +59,15 @@ pub(super) struct ClientRequest {
|
|||
///
|
||||
/// INVARIANT: `tx.send()` must be called before dropping `tx`.
|
||||
///
|
||||
/// JUSTIFICATION: the `peer::Client` will translate all `Request`s into a `ClientRequest` which it sends to a background task, and if the send replies with `Ok(())` it will assume that it is safe to unconditionally poll the `Receiver` tied to the `Sender` used to create the `ClientRequest`.
|
||||
/// JUSTIFICATION: the `peer::Client` translates `Request`s into
|
||||
/// `ClientRequest`s, which it sends to a background task. If the send is
|
||||
/// `Ok(())`, it will assume that it is safe to unconditionally poll the
|
||||
/// `Receiver` tied to the `Sender` used to create the `ClientRequest`.
|
||||
///
|
||||
/// We enforce this invariant via the type system, by converting
|
||||
/// `ClientRequest`s to `InProgressClientRequest`s when they are received by
|
||||
/// the background task. These conversions are implemented by
|
||||
/// `ClientRequestReceiver`.
|
||||
pub tx: MustUseOneshotSender<Result<Response, SharedPeerError>>,
|
||||
/// The tracing context for the request, so that work the connection task does
|
||||
/// processing messages in the context of this request will have correct context.
|
||||
|
@ -54,6 +87,49 @@ pub(super) struct MustUseOneshotSender<T: std::fmt::Debug> {
|
|||
pub tx: Option<oneshot::Sender<T>>,
|
||||
}
|
||||
|
||||
impl From<ClientRequest> for InProgressClientRequest {
|
||||
fn from(client_request: ClientRequest) -> Self {
|
||||
let ClientRequest { request, tx, span } = client_request;
|
||||
InProgressClientRequest {
|
||||
request,
|
||||
tx: tx.into(),
|
||||
span,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientRequestReceiver {
|
||||
/// Forwards to `inner.close()`
|
||||
pub fn close(&mut self) {
|
||||
self.inner.close()
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for ClientRequestReceiver {
|
||||
type Item = InProgressClientRequest;
|
||||
|
||||
/// Converts the successful result of `inner.poll_next()` to an
|
||||
/// `InProgressClientRequest`.
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
match self.inner.poll_next_unpin(cx) {
|
||||
Poll::Ready(client_request) => Poll::Ready(client_request.map(Into::into)),
|
||||
// `inner.poll_next_unpin` parks the task for this future
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `inner.size_hint()`
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
self.inner.size_hint()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mpsc::Receiver<ClientRequest>> for ClientRequestReceiver {
|
||||
fn from(rx: mpsc::Receiver<ClientRequest>) -> Self {
|
||||
ClientRequestReceiver { inner: rx }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: std::fmt::Debug> MustUseOneshotSender<T> {
|
||||
/// Forwards `t` to `tx.send()`, and marks this sender as used.
|
||||
///
|
||||
|
@ -143,11 +219,7 @@ impl Service<Request> for Client {
|
|||
// request.
|
||||
let span = tracing::Span::current();
|
||||
|
||||
match self.server_tx.try_send(ClientRequest {
|
||||
request,
|
||||
span,
|
||||
tx: tx.into(),
|
||||
}) {
|
||||
match self.server_tx.try_send(ClientRequest { request, span, tx }) {
|
||||
Err(e) => {
|
||||
if e.is_disconnected() {
|
||||
let ClientRequest { tx, .. } = e.into_inner();
|
||||
|
|
|
@ -10,7 +10,6 @@
|
|||
use std::{collections::HashSet, sync::Arc};
|
||||
|
||||
use futures::{
|
||||
channel::mpsc,
|
||||
future::{self, Either},
|
||||
prelude::*,
|
||||
stream::Stream,
|
||||
|
@ -34,7 +33,10 @@ use crate::{
|
|||
BoxError,
|
||||
};
|
||||
|
||||
use super::{ClientRequest, ErrorSlot, MustUseOneshotSender, PeerError, SharedPeerError};
|
||||
use super::{
|
||||
ClientRequestReceiver, ErrorSlot, InProgressClientRequest, MustUseOneshotSender, PeerError,
|
||||
SharedPeerError,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) enum Handler {
|
||||
|
@ -327,7 +329,9 @@ pub struct Connection<S, Tx> {
|
|||
/// other state handling.
|
||||
pub(super) request_timer: Option<Sleep>,
|
||||
pub(super) svc: S,
|
||||
pub(super) client_rx: mpsc::Receiver<ClientRequest>,
|
||||
/// A `mpsc::Receiver<ClientRequest>` that converts its results to
|
||||
/// `InProgressClientRequest`
|
||||
pub(super) client_rx: ClientRequestReceiver,
|
||||
/// A slot for an error shared between the Connection and the Client that uses it.
|
||||
pub(super) error_slot: ErrorSlot,
|
||||
//pub(super) peer_rx: Rx,
|
||||
|
@ -475,7 +479,7 @@ where
|
|||
// requests before we can return and complete the future.
|
||||
State::Failed => {
|
||||
match self.client_rx.next().await {
|
||||
Some(ClientRequest { tx, span, .. }) => {
|
||||
Some(InProgressClientRequest { tx, span, .. }) => {
|
||||
trace!(
|
||||
parent: &span,
|
||||
"erroring pending request to failed connection"
|
||||
|
@ -535,11 +539,11 @@ where
|
|||
///
|
||||
/// NOTE: the caller should use .instrument(msg.span) to instrument the function.
|
||||
#[instrument(skip(self))]
|
||||
async fn handle_client_request(&mut self, req: ClientRequest) {
|
||||
async fn handle_client_request(&mut self, req: InProgressClientRequest) {
|
||||
trace!(?req.request);
|
||||
use Request::*;
|
||||
use State::*;
|
||||
let ClientRequest { request, tx, span } = req;
|
||||
let InProgressClientRequest { request, tx, span } = req;
|
||||
|
||||
if tx.is_canceled() {
|
||||
metrics::counter!("peer.canceled", 1);
|
||||
|
|
|
@ -435,7 +435,7 @@ where
|
|||
let server = Connection {
|
||||
state: connection::State::AwaitingRequest,
|
||||
svc: inbound_service,
|
||||
client_rx: server_rx,
|
||||
client_rx: server_rx.into(),
|
||||
error_slot: slot,
|
||||
peer_tx,
|
||||
request_timer: None,
|
||||
|
@ -451,7 +451,7 @@ where
|
|||
let heartbeat_span = tracing::debug_span!(parent: connection_span, "heartbeat");
|
||||
tokio::spawn(
|
||||
async move {
|
||||
use super::client::ClientRequest;
|
||||
use super::ClientRequest;
|
||||
use futures::future::Either;
|
||||
|
||||
let mut shutdown_rx = shutdown_rx;
|
||||
|
@ -466,16 +466,23 @@ where
|
|||
tracing::trace!(?request, "queueing heartbeat request");
|
||||
match server_tx.try_send(ClientRequest {
|
||||
request,
|
||||
tx: tx.into(),
|
||||
tx,
|
||||
span: tracing::Span::current(),
|
||||
}) {
|
||||
Ok(()) => {
|
||||
match server_tx.flush().await {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
// TODO: we can't get the client request for this failure,
|
||||
// so we can't ensure the invariant holds
|
||||
panic!("flushing client request failed: {:?}", e);
|
||||
// We can't get the client request for this failure,
|
||||
// so we can't send an error back here. But that's ok,
|
||||
// because:
|
||||
// - this error never happens (or it's very rare)
|
||||
// - if the flush() fails, the server hasn't
|
||||
// received the request
|
||||
tracing::warn!(
|
||||
"flushing client request failed: {:?}",
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue