Panic when must-use senders are dropped before use

Add a MustUseOneshotSender, which panics if its inner sender is unused.
Callers must call `send()` on the MustUseOneshotSender, or ensure that
the sender is canceled.

Replaces an unreliable panic in `Client::call()` with a reliable panic
when a must-use sender is dropped.
This commit is contained in:
teor 2021-01-04 17:43:33 +10:00 committed by Jane Lusby
parent b03809ebe3
commit fa29fca917
4 changed files with 81 additions and 6 deletions

View File

@ -12,6 +12,7 @@ mod error;
mod handshake;
use client::ClientRequest;
use client::MustUseOneshotSender;
use error::ErrorSlot;
pub use client::Client;

View File

@ -33,12 +33,82 @@ pub(super) struct ClientRequest {
/// future that may be moved around before it resolves.
///
/// INVARIANT: `tx.send()` must be called before dropping `tx`.
pub tx: oneshot::Sender<Result<Response, SharedPeerError>>,
pub tx: MustUseOneshotSender<Result<Response, SharedPeerError>>,
/// The tracing context for the request, so that work the connection task does
/// processing messages in the context of this request will have correct context.
pub span: tracing::Span,
}
/// A oneshot::Sender that must be used by calling `send()`.
///
/// Panics on drop if `tx` has not been used or canceled.
/// Panics if `tx.send()` is used more than once.
#[derive(Debug)]
#[must_use = "tx.send() must be called before drop"]
pub(super) struct MustUseOneshotSender<T: std::fmt::Debug> {
/// The sender for the oneshot channel.
///
/// `None` if `tx.send()` has been used.
pub tx: Option<oneshot::Sender<T>>,
}
impl<T: std::fmt::Debug> MustUseOneshotSender<T> {
/// Forwards `t` to `tx.send()`, and marks this sender as used.
///
/// Panics if `tx.send()` is used more than once.
pub fn send(mut self, t: T) -> Result<(), T> {
self.tx
.take()
.unwrap_or_else(|| {
panic!(
"multiple uses of oneshot sender: oneshot must be used exactly once: {:?}",
self
)
})
.send(t)
}
/// Returns `tx.cancellation()`.
///
/// Panics if `tx.send()` has previously been used.
pub fn cancellation(&mut self) -> oneshot::Cancellation<'_, T> {
self.tx
.as_mut()
.map(|tx| tx.cancellation())
.unwrap_or_else( || {
panic!("called cancellation() after using oneshot sender: oneshot must be used exactly once")
})
}
/// Returns `tx.is_canceled()`.
///
/// Panics if `tx.send()` has previously been used.
pub fn is_canceled(&self) -> bool {
self.tx
.as_ref()
.map(|tx| tx.is_canceled())
.unwrap_or_else(
|| panic!("called is_canceled() after using oneshot sender: oneshot must be used exactly once: {:?}", self))
}
}
impl<T: std::fmt::Debug> From<oneshot::Sender<T>> for MustUseOneshotSender<T> {
fn from(sender: oneshot::Sender<T>) -> Self {
MustUseOneshotSender { tx: Some(sender) }
}
}
impl<T: std::fmt::Debug> Drop for MustUseOneshotSender<T> {
fn drop(&mut self) {
// is_canceled() will not panic, because we check is_none() first
assert!(
self.tx.is_none() || self.is_canceled(),
"unused oneshot sender: oneshot must be used or canceled: {:?}",
self
);
}
}
impl Service<Request> for Client {
type Response = Response;
type Error = SharedPeerError;
@ -66,7 +136,11 @@ impl Service<Request> for Client {
// request.
let span = tracing::Span::current();
match self.server_tx.try_send(ClientRequest { request, span, tx }) {
match self.server_tx.try_send(ClientRequest {
request,
span,
tx: tx.into(),
}) {
Err(e) => {
if e.is_disconnected() {
future::ready(Err(self

View File

@ -10,7 +10,7 @@
use std::{collections::HashSet, sync::Arc};
use futures::{
channel::{mpsc, oneshot},
channel::mpsc,
future::{self, Either},
prelude::*,
stream::Stream,
@ -34,7 +34,7 @@ use crate::{
BoxError,
};
use super::{ClientRequest, ErrorSlot, PeerError, SharedPeerError};
use super::{ClientRequest, ErrorSlot, MustUseOneshotSender, PeerError, SharedPeerError};
#[derive(Debug)]
pub(super) enum Handler {
@ -312,7 +312,7 @@ pub(super) enum State {
/// Awaiting a peer message we can interpret as a client request.
AwaitingResponse {
handler: Handler,
tx: oneshot::Sender<Result<Response, SharedPeerError>>,
tx: MustUseOneshotSender<Result<Response, SharedPeerError>>,
span: tracing::Span,
},
/// A failure has occurred and we are shutting down the connection.

View File

@ -466,7 +466,7 @@ where
if server_tx
.send(ClientRequest {
request,
tx,
tx: tx.into(),
span: tracing::Span::current(),
})
.await