diff --git a/zebra-network/src/peer.rs b/zebra-network/src/peer.rs index b98824c3e..fab29d706 100644 --- a/zebra-network/src/peer.rs +++ b/zebra-network/src/peer.rs @@ -12,6 +12,7 @@ mod error; mod handshake; use client::ClientRequest; +use client::MustUseOneshotSender; use error::ErrorSlot; pub use client::Client; diff --git a/zebra-network/src/peer/client.rs b/zebra-network/src/peer/client.rs index 939db8cf8..912067660 100644 --- a/zebra-network/src/peer/client.rs +++ b/zebra-network/src/peer/client.rs @@ -33,12 +33,82 @@ pub(super) struct ClientRequest { /// future that may be moved around before it resolves. /// /// INVARIANT: `tx.send()` must be called before dropping `tx`. - pub tx: oneshot::Sender>, + pub tx: MustUseOneshotSender>, /// The tracing context for the request, so that work the connection task does /// processing messages in the context of this request will have correct context. pub span: tracing::Span, } +/// A oneshot::Sender that must be used by calling `send()`. +/// +/// Panics on drop if `tx` has not been used or canceled. +/// Panics if `tx.send()` is used more than once. +#[derive(Debug)] +#[must_use = "tx.send() must be called before drop"] +pub(super) struct MustUseOneshotSender { + /// The sender for the oneshot channel. + /// + /// `None` if `tx.send()` has been used. + pub tx: Option>, +} + +impl MustUseOneshotSender { + /// Forwards `t` to `tx.send()`, and marks this sender as used. + /// + /// Panics if `tx.send()` is used more than once. + pub fn send(mut self, t: T) -> Result<(), T> { + self.tx + .take() + .unwrap_or_else(|| { + panic!( + "multiple uses of oneshot sender: oneshot must be used exactly once: {:?}", + self + ) + }) + .send(t) + } + + /// Returns `tx.cancellation()`. + /// + /// Panics if `tx.send()` has previously been used. + pub fn cancellation(&mut self) -> oneshot::Cancellation<'_, T> { + self.tx + .as_mut() + .map(|tx| tx.cancellation()) + .unwrap_or_else( || { + panic!("called cancellation() after using oneshot sender: oneshot must be used exactly once") + }) + } + + /// Returns `tx.is_canceled()`. + /// + /// Panics if `tx.send()` has previously been used. + pub fn is_canceled(&self) -> bool { + self.tx + .as_ref() + .map(|tx| tx.is_canceled()) + .unwrap_or_else( + || panic!("called is_canceled() after using oneshot sender: oneshot must be used exactly once: {:?}", self)) + } +} + +impl From> for MustUseOneshotSender { + fn from(sender: oneshot::Sender) -> Self { + MustUseOneshotSender { tx: Some(sender) } + } +} + +impl Drop for MustUseOneshotSender { + fn drop(&mut self) { + // is_canceled() will not panic, because we check is_none() first + assert!( + self.tx.is_none() || self.is_canceled(), + "unused oneshot sender: oneshot must be used or canceled: {:?}", + self + ); + } +} + impl Service for Client { type Response = Response; type Error = SharedPeerError; @@ -66,7 +136,11 @@ impl Service for Client { // request. let span = tracing::Span::current(); - match self.server_tx.try_send(ClientRequest { request, span, tx }) { + match self.server_tx.try_send(ClientRequest { + request, + span, + tx: tx.into(), + }) { Err(e) => { if e.is_disconnected() { future::ready(Err(self diff --git a/zebra-network/src/peer/connection.rs b/zebra-network/src/peer/connection.rs index 904cdd976..980af0a25 100644 --- a/zebra-network/src/peer/connection.rs +++ b/zebra-network/src/peer/connection.rs @@ -10,7 +10,7 @@ use std::{collections::HashSet, sync::Arc}; use futures::{ - channel::{mpsc, oneshot}, + channel::mpsc, future::{self, Either}, prelude::*, stream::Stream, @@ -34,7 +34,7 @@ use crate::{ BoxError, }; -use super::{ClientRequest, ErrorSlot, PeerError, SharedPeerError}; +use super::{ClientRequest, ErrorSlot, MustUseOneshotSender, PeerError, SharedPeerError}; #[derive(Debug)] pub(super) enum Handler { @@ -312,7 +312,7 @@ pub(super) enum State { /// Awaiting a peer message we can interpret as a client request. AwaitingResponse { handler: Handler, - tx: oneshot::Sender>, + tx: MustUseOneshotSender>, span: tracing::Span, }, /// A failure has occurred and we are shutting down the connection. diff --git a/zebra-network/src/peer/handshake.rs b/zebra-network/src/peer/handshake.rs index fe2f2b76c..18424585d 100644 --- a/zebra-network/src/peer/handshake.rs +++ b/zebra-network/src/peer/handshake.rs @@ -466,7 +466,7 @@ where if server_tx .send(ClientRequest { request, - tx, + tx: tx.into(), span: tracing::Span::current(), }) .await