Ensure received client request oneshots are used via the type system

The `peer::Client` translates `Request`s into `ClientRequest`s, which
it sends to a background task. If the send is `Ok(())`, it will assume
that it is safe to unconditionally poll the `Receiver` tied to the
`Sender` used to create the `ClientRequest`.

We enforce this invariant via the type system, by converting
`ClientRequest`s to `InProgressClientRequest`s when they are received by
the background task. These conversions are implemented by
`ClientRequestReceiver`.

Changes:
* Revert `ClientRequest` so it uses a `oneshot::Sender`
* Add `InProgressClientRequest`, which is the same as `ClientRequest`,
  but has a `MustUseOneshotSender`
* `impl From<ClientRequest> for InProgressClientRequest`

* Add a new `ClientRequestReceiver` type that wraps a
  `mpsc::Receiver<ClientRequest>`
* `impl Stream<InProgressClientRequest> for ClientRequestReceiver`,
  converting the successful result of `inner.poll_next_unpin` into an
  `InProgressClientRequest`

* Replace `client_rx: mpsc::Receiver<ClientRequest>` in `Connection`
  with the new `ClientRequestReceiver` type
* `impl From<mpsc::Receiver<ClientRequest>> for ClientRequestReceiver`
This commit is contained in:
teor 2021-01-06 17:58:20 +10:00 committed by Jane Lusby
parent df1b0c8d58
commit 6d3aa0002c
4 changed files with 104 additions and 19 deletions

View File

@ -12,6 +12,8 @@ mod error;
mod handshake;
use client::ClientRequest;
use client::ClientRequestReceiver;
use client::InProgressClientRequest;
use client::MustUseOneshotSender;
use error::ErrorSlot;

View File

@ -7,6 +7,7 @@ use std::{
use futures::{
channel::{mpsc, oneshot},
future, ready,
stream::{Stream, StreamExt},
};
use tower::Service;
@ -25,8 +26,32 @@ pub struct Client {
/// A message from the `peer::Client` to the `peer::Server`.
#[derive(Debug)]
#[must_use = "tx.send() must be called before drop"]
pub(super) struct ClientRequest {
/// The actual request.
pub request: Request,
/// The return message channel, included because `peer::Client::call` returns a
/// future that may be moved around before it resolves.
pub tx: oneshot::Sender<Result<Response, SharedPeerError>>,
/// The tracing context for the request, so that work the connection task does
/// processing messages in the context of this request will have correct context.
pub span: tracing::Span,
}
/// A receiver for the `peer::Server`, which wraps a `mpsc::Receiver`,
/// converting `ClientRequest`s into `InProgressClientRequest`s.
#[derive(Debug)]
pub(super) struct ClientRequestReceiver {
/// The inner receiver
inner: mpsc::Receiver<ClientRequest>,
}
/// A message from the `peer::Client` to the `peer::Server`,
/// after it has been received by the `peer::Server`.
///
///
#[derive(Debug)]
#[must_use = "tx.send() must be called before drop"]
pub(super) struct InProgressClientRequest {
/// The actual request.
pub request: Request,
/// The return message channel, included because `peer::Client::call` returns a
@ -34,7 +59,15 @@ pub(super) struct ClientRequest {
///
/// INVARIANT: `tx.send()` must be called before dropping `tx`.
///
/// JUSTIFICATION: the `peer::Client` will translate all `Request`s into a `ClientRequest` which it sends to a background task, and if the send replies with `Ok(())` it will assume that it is safe to unconditionally poll the `Receiver` tied to the `Sender` used to create the `ClientRequest`.
/// JUSTIFICATION: the `peer::Client` translates `Request`s into
/// `ClientRequest`s, which it sends to a background task. If the send is
/// `Ok(())`, it will assume that it is safe to unconditionally poll the
/// `Receiver` tied to the `Sender` used to create the `ClientRequest`.
///
/// We enforce this invariant via the type system, by converting
/// `ClientRequest`s to `InProgressClientRequest`s when they are received by
/// the background task. These conversions are implemented by
/// `ClientRequestReceiver`.
pub tx: MustUseOneshotSender<Result<Response, SharedPeerError>>,
/// The tracing context for the request, so that work the connection task does
/// processing messages in the context of this request will have correct context.
@ -54,6 +87,49 @@ pub(super) struct MustUseOneshotSender<T: std::fmt::Debug> {
pub tx: Option<oneshot::Sender<T>>,
}
impl From<ClientRequest> for InProgressClientRequest {
fn from(client_request: ClientRequest) -> Self {
let ClientRequest { request, tx, span } = client_request;
InProgressClientRequest {
request,
tx: tx.into(),
span,
}
}
}
impl ClientRequestReceiver {
/// Forwards to `inner.close()`
pub fn close(&mut self) {
self.inner.close()
}
}
impl Stream for ClientRequestReceiver {
type Item = InProgressClientRequest;
/// Converts the successful result of `inner.poll_next()` to an
/// `InProgressClientRequest`.
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.inner.poll_next_unpin(cx) {
Poll::Ready(client_request) => Poll::Ready(client_request.map(Into::into)),
// `inner.poll_next_unpin` parks the task for this future
Poll::Pending => Poll::Pending,
}
}
/// Returns `inner.size_hint()`
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}
impl From<mpsc::Receiver<ClientRequest>> for ClientRequestReceiver {
fn from(rx: mpsc::Receiver<ClientRequest>) -> Self {
ClientRequestReceiver { inner: rx }
}
}
impl<T: std::fmt::Debug> MustUseOneshotSender<T> {
/// Forwards `t` to `tx.send()`, and marks this sender as used.
///
@ -143,11 +219,7 @@ impl Service<Request> for Client {
// request.
let span = tracing::Span::current();
match self.server_tx.try_send(ClientRequest {
request,
span,
tx: tx.into(),
}) {
match self.server_tx.try_send(ClientRequest { request, span, tx }) {
Err(e) => {
if e.is_disconnected() {
let ClientRequest { tx, .. } = e.into_inner();

View File

@ -10,7 +10,6 @@
use std::{collections::HashSet, sync::Arc};
use futures::{
channel::mpsc,
future::{self, Either},
prelude::*,
stream::Stream,
@ -34,7 +33,10 @@ use crate::{
BoxError,
};
use super::{ClientRequest, ErrorSlot, MustUseOneshotSender, PeerError, SharedPeerError};
use super::{
ClientRequestReceiver, ErrorSlot, InProgressClientRequest, MustUseOneshotSender, PeerError,
SharedPeerError,
};
#[derive(Debug)]
pub(super) enum Handler {
@ -327,7 +329,9 @@ pub struct Connection<S, Tx> {
/// other state handling.
pub(super) request_timer: Option<Sleep>,
pub(super) svc: S,
pub(super) client_rx: mpsc::Receiver<ClientRequest>,
/// A `mpsc::Receiver<ClientRequest>` that converts its results to
/// `InProgressClientRequest`
pub(super) client_rx: ClientRequestReceiver,
/// A slot for an error shared between the Connection and the Client that uses it.
pub(super) error_slot: ErrorSlot,
//pub(super) peer_rx: Rx,
@ -475,7 +479,7 @@ where
// requests before we can return and complete the future.
State::Failed => {
match self.client_rx.next().await {
Some(ClientRequest { tx, span, .. }) => {
Some(InProgressClientRequest { tx, span, .. }) => {
trace!(
parent: &span,
"erroring pending request to failed connection"
@ -535,11 +539,11 @@ where
///
/// NOTE: the caller should use .instrument(msg.span) to instrument the function.
#[instrument(skip(self))]
async fn handle_client_request(&mut self, req: ClientRequest) {
async fn handle_client_request(&mut self, req: InProgressClientRequest) {
trace!(?req.request);
use Request::*;
use State::*;
let ClientRequest { request, tx, span } = req;
let InProgressClientRequest { request, tx, span } = req;
if tx.is_canceled() {
metrics::counter!("peer.canceled", 1);

View File

@ -435,7 +435,7 @@ where
let server = Connection {
state: connection::State::AwaitingRequest,
svc: inbound_service,
client_rx: server_rx,
client_rx: server_rx.into(),
error_slot: slot,
peer_tx,
request_timer: None,
@ -451,7 +451,7 @@ where
let heartbeat_span = tracing::debug_span!(parent: connection_span, "heartbeat");
tokio::spawn(
async move {
use super::client::ClientRequest;
use super::ClientRequest;
use futures::future::Either;
let mut shutdown_rx = shutdown_rx;
@ -466,16 +466,23 @@ where
tracing::trace!(?request, "queueing heartbeat request");
match server_tx.try_send(ClientRequest {
request,
tx: tx.into(),
tx,
span: tracing::Span::current(),
}) {
Ok(()) => {
match server_tx.flush().await {
Ok(()) => {}
Err(e) => {
// TODO: we can't get the client request for this failure,
// so we can't ensure the invariant holds
panic!("flushing client request failed: {:?}", e);
// We can't get the client request for this failure,
// so we can't send an error back here. But that's ok,
// because:
// - this error never happens (or it's very rare)
// - if the flush() fails, the server hasn't
// received the request
tracing::warn!(
"flushing client request failed: {:?}",
e
);
}
}
}