diff --git a/zebra-network/src/peer/connection.rs b/zebra-network/src/peer/connection.rs index 3c14a59bd..20c02a66b 100644 --- a/zebra-network/src/peer/connection.rs +++ b/zebra-network/src/peer/connection.rs @@ -411,6 +411,29 @@ impl State { } } +/// The outcome of mapping an inbound [`Message`] to a [`Request`]. +#[derive(Clone, Debug, Eq, PartialEq)] +#[must_use = "inbound messages must be handled"] +pub enum InboundMessage { + /// The message was mapped to an inbound [`Request`]. + AsRequest(Request), + + /// The message was consumed by the mapping method. + /// + /// For example, it could be cached, treated as an error, + /// or an internally handled [`Message::Ping`]. + Consumed, + + /// The message was not used by the inbound message handler. + Unused, +} + +impl From for InboundMessage { + fn from(request: Request) -> Self { + InboundMessage::AsRequest(request) + } +} + /// The state associated with a peer connection. pub struct Connection { /// The state of this connection's current request or response. @@ -777,8 +800,7 @@ where ); self.update_state_metrics(format!("Out::Req::{}", request.command())); - // These matches return a Result with (new_state, Option) or an (error, Sender) - let new_state_result = match (&self.state, request) { + let new_handler = match (&self.state, request) { (Failed, request) => panic!( "failed connection cannot handle new request: {:?}, client_receiver: {:?}", request, @@ -792,7 +814,7 @@ where ), // Consume the cached addresses from the peer, - // to work-around a `zcashd` response rate-limit + // to work-around a `zcashd` response rate-limit. (AwaitingRequest, Peers) if !self.cached_addrs.is_empty() => { let cached_addrs = std::mem::take(&mut self.cached_addrs); debug!( @@ -800,184 +822,117 @@ where "responding to Peers request using cached addresses", ); - Ok(( - AwaitingResponse { - handler: Handler::Finished(Ok(Response::Peers(cached_addrs))), - tx, - span, - }, - None, - ))} - , - (AwaitingRequest, Peers) => match self.peer_tx.send(Message::GetAddr).await { - Ok(()) => Ok(( - AwaitingResponse { - handler: Handler::Peers, - tx, - span, - }, - None, - )), - Err(e) => Err((e, tx)), - }, + Ok(Handler::Finished(Ok(Response::Peers(cached_addrs)))) + } + (AwaitingRequest, Peers) => self + .peer_tx + .send(Message::GetAddr) + .await + .map(|()| Handler::Peers), + + (AwaitingRequest, Ping(nonce)) => self + .peer_tx + .send(Message::Ping(nonce)) + .await + .map(|()| Handler::Ping(nonce)), - (AwaitingRequest, Ping(nonce)) => match self.peer_tx.send(Message::Ping(nonce)).await { - Ok(()) => Ok(( - AwaitingResponse { - handler: Handler::Ping(nonce), - tx, - span, - }, - None, - )), - Err(e) => Err((e, tx)), - }, (AwaitingRequest, BlocksByHash(hashes)) => { - match self + self .peer_tx .send(Message::GetData( hashes.iter().map(|h| (*h).into()).collect(), )) .await - { - Ok(()) => Ok(( - AwaitingResponse { - handler: Handler::BlocksByHash { - blocks: Vec::with_capacity(hashes.len()), - pending_hashes: hashes, - }, - tx, - span, - }, - None, - )), - Err(e) => Err((e, tx)), - } + .map(|()| + Handler::BlocksByHash { + blocks: Vec::with_capacity(hashes.len()), + pending_hashes: hashes, + } + ) } (AwaitingRequest, TransactionsById(ids)) => { - match self + self .peer_tx .send(Message::GetData( ids.iter().map(Into::into).collect(), )) .await - { - Ok(()) => Ok(( - AwaitingResponse { - handler: Handler::TransactionsById { - transactions: Vec::with_capacity(ids.len()), - pending_ids: ids, - }, - tx, - span, - }, - None, - )), - Err(e) => Err((e, tx)), - } + .map(|()| + Handler::TransactionsById { + transactions: Vec::with_capacity(ids.len()), + pending_ids: ids, + }) } + (AwaitingRequest, FindBlocks { known_blocks, stop }) => { - match self + self .peer_tx .send(Message::GetBlocks { known_blocks, stop }) .await - { - Ok(()) => Ok(( - AwaitingResponse { - handler: Handler::FindBlocks, - tx, - span, - }, - None, - )), - Err(e) => Err((e, tx)), - } + .map(|()| + Handler::FindBlocks + ) } (AwaitingRequest, FindHeaders { known_blocks, stop }) => { - match self + self .peer_tx .send(Message::GetHeaders { known_blocks, stop }) .await - { - Ok(()) => Ok(( - AwaitingResponse { - handler: Handler::FindHeaders, - tx, - span, - }, - None, - )), - Err(e) => Err((e, tx)), - } + .map(|()| + Handler::FindHeaders + ) } + (AwaitingRequest, MempoolTransactionIds) => { - match self.peer_tx.send(Message::Mempool).await { - Ok(()) => Ok(( - AwaitingResponse { - handler: Handler::MempoolTransactionIds, - tx, - span, - }, - None, - )), - Err(e) => Err((e, tx)), - } + self + .peer_tx + .send(Message::Mempool) + .await + .map(|()| + Handler::MempoolTransactionIds + ) } + (AwaitingRequest, PushTransaction(transaction)) => { - match self.peer_tx.send(Message::Tx(transaction)).await { - Ok(()) => Ok((AwaitingRequest, Some(tx))), - Err(e) => Err((e, tx)), - } + self + .peer_tx + .send(Message::Tx(transaction)) + .await + .map(|()| + Handler::Finished(Ok(Response::Nil)) + ) } (AwaitingRequest, AdvertiseTransactionIds(hashes)) => { - match self + self .peer_tx .send(Message::Inv(hashes.iter().map(|h| (*h).into()).collect())) .await - { - Ok(()) => Ok((AwaitingRequest, Some(tx))), - Err(e) => Err((e, tx)), - } + .map(|()| + Handler::Finished(Ok(Response::Nil)) + ) } (AwaitingRequest, AdvertiseBlock(hash)) => { - match self.peer_tx.send(Message::Inv(vec![hash.into()])).await { - Ok(()) => Ok((AwaitingRequest, Some(tx))), - Err(e) => Err((e, tx)), - } + self + .peer_tx + .send(Message::Inv(vec![hash.into()])) + .await + .map(|()| + Handler::Finished(Ok(Response::Nil)) + ) } }; - // Updates state or fails. Sends the error on the Sender if it is Some. - match new_state_result { - Ok((AwaitingRequest, Some(tx))) => { - // Since we're not waiting for further messages, we need to - // send a response before dropping tx. - let _ = tx.send(Ok(Response::Nil)); - self.state = AwaitingRequest; - // We only need a timer when we're waiting for a response. - // (And we don't want to accidentally re-use old timers.) - self.request_timer = None; - } - Ok((new_state @ AwaitingResponse { .. }, None)) => { - self.state = new_state; + + // Update the connection state with a new handler, or fail with an error. + match new_handler { + Ok(handler) => { + self.state = AwaitingResponse { handler, span, tx }; self.request_timer = Some(Box::pin(sleep(constants::REQUEST_TIMEOUT))); } - Err((e, tx)) => { - let e = SharedPeerError::from(e); - let _ = tx.send(Err(e.clone())); - self.fail_with(e); + Err(error) => { + let error = SharedPeerError::from(error); + let _ = tx.send(Err(error.clone())); + self.fail_with(error); } - // unreachable states - Ok((Failed, tx)) => unreachable!( - "failed client requests must use fail_with(error) to reach a Failed state. tx: {:?}", - tx - ), - Ok((AwaitingRequest, None)) => unreachable!( - "successful AwaitingRequest states must send a response on tx, but tx is None", - ), - Ok((new_state @ AwaitingResponse { .. }, Some(tx))) => unreachable!( - "successful AwaitingResponse states must keep tx, but tx is Some: {:?} for: {:?}", - tx, new_state, - ), }; } @@ -993,22 +948,24 @@ where self.update_state_metrics(format!("In::Msg::{}", msg.command())); + use InboundMessage::*; + let req = match msg { Message::Ping(nonce) => { trace!(?nonce, "responding to heartbeat"); if let Err(e) = self.peer_tx.send(Message::Pong(nonce)).await { self.fail_with(e); } - None + Consumed } // These messages shouldn't be sent outside of a handshake. Message::Version { .. } => { self.fail_with(PeerError::DuplicateHandshake); - None + Consumed } Message::Verack { .. } => { self.fail_with(PeerError::DuplicateHandshake); - None + Consumed } // These messages should already be handled as a response if they // could be a response, so if we see them here, they were either @@ -1016,23 +973,23 @@ where // that we've already forgotten about. Message::Reject { .. } => { debug!(%msg, "got reject message unsolicited or from canceled request"); - None + Unused } Message::NotFound { .. } => { debug!(%msg, "got notfound message unsolicited or from canceled request"); - None + Unused } Message::Pong(_) => { debug!(%msg, "got pong message unsolicited or from canceled request"); - None + Unused } Message::Block(_) => { debug!(%msg, "got block message unsolicited or from canceled request"); - None + Unused } Message::Headers(_) => { debug!(%msg, "got headers message unsolicited or from canceled request"); - None + Unused } // These messages should never be sent by peers. Message::FilterLoad { .. } @@ -1046,7 +1003,9 @@ where // Since we can't verify their source, Zebra needs to ignore unexpected messages, // because closing the connection could cause a denial of service or eclipse attack. debug!(%msg, "got BIP111 message without advertising NODE_BLOOM"); - None + + // Ignored, but consumed because it is technically a protocol error. + Consumed } // Zebra crawls the network proactively, to prevent // peers from inserting data into our address book. @@ -1056,46 +1015,50 @@ where // Always refresh the cache with multi-addr messages. debug!(%msg, "caching unsolicited multi-addr message"); self.cached_addrs = addrs.clone(); + Consumed } else if addrs.len() == 1 && self.cached_addrs.len() <= 1 { // Only refresh a cached single addr message with another single addr. // (`zcashd` regularly advertises its own address.) debug!(%msg, "caching unsolicited single addr message"); self.cached_addrs = addrs.clone(); + Consumed } else { debug!( %msg, "ignoring unsolicited single addr message: already cached a multi-addr message" ); + Consumed } - None } - Message::Tx(ref transaction) => Some(Request::PushTransaction(transaction.clone())), + Message::Tx(ref transaction) => Request::PushTransaction(transaction.clone()).into(), Message::Inv(ref items) => match &items[..] { // We don't expect to be advertised multiple blocks at a time, // so we ignore any advertisements of multiple blocks. - [InventoryHash::Block(hash)] => Some(Request::AdvertiseBlock(*hash)), + [InventoryHash::Block(hash)] => Request::AdvertiseBlock(*hash).into(), // Some peers advertise invs with mixed item types. // But we're just interested in the transaction invs. // // TODO: split mixed invs into multiple requests, // but skip runs of multiple blocks. - tx_ids if tx_ids.iter().any(|item| item.unmined_tx_id().is_some()) => Some( - Request::AdvertiseTransactionIds(transaction_ids(items).collect()), - ), + tx_ids if tx_ids.iter().any(|item| item.unmined_tx_id().is_some()) => { + Request::AdvertiseTransactionIds(transaction_ids(items).collect()).into() + } // Log detailed messages for ignored inv advertisement messages. [] => { debug!(%msg, "ignoring empty inv"); - None + + // This might be a minor protocol error, or it might mean "not found". + Unused } [InventoryHash::Block(_), InventoryHash::Block(_), ..] => { debug!(%msg, "ignoring inv with multiple blocks"); - None + Unused } _ => { debug!(%msg, "ignoring inv with no transactions"); - None + Unused } }, Message::GetData(ref items) => match &items[..] { @@ -1112,46 +1075,52 @@ where .iter() .any(|item| matches!(item, InventoryHash::Block(_))) => { - Some(Request::BlocksByHash(block_hashes(items).collect())) + Request::BlocksByHash(block_hashes(items).collect()).into() } tx_ids if tx_ids.iter().any(|item| item.unmined_tx_id().is_some()) => { - Some(Request::TransactionsById(transaction_ids(items).collect())) + Request::TransactionsById(transaction_ids(items).collect()).into() } // Log detailed messages for ignored getdata request messages. [] => { debug!(%msg, "ignoring empty getdata"); - None + + // This might be a minor protocol error, or it might mean "not found". + Unused } _ => { debug!(%msg, "ignoring getdata with no blocks or transactions"); - None + Unused } }, - Message::GetAddr => Some(Request::Peers), + Message::GetAddr => Request::Peers.into(), Message::GetBlocks { ref known_blocks, stop, - } => Some(Request::FindBlocks { + } => Request::FindBlocks { known_blocks: known_blocks.clone(), stop, - }), + } + .into(), Message::GetHeaders { ref known_blocks, stop, - } => Some(Request::FindHeaders { + } => Request::FindHeaders { known_blocks: known_blocks.clone(), stop, - }), - Message::Mempool => Some(Request::MempoolTransactionIds), + } + .into(), + Message::Mempool => Request::MempoolTransactionIds.into(), }; - if let Some(req) = req { - self.drive_peer_request(req).await; - None - } else { - // return the unused message - Some(msg) + // Handle the request, and return unused messages. + match req { + AsRequest(req) => { + self.drive_peer_request(req).await; + None + } + Consumed => None, + Unused => Some(msg), } }