From 3892894ffae9838824c5764ab9b8b47d4dff067f Mon Sep 17 00:00:00 2001 From: teor Date: Wed, 16 Dec 2020 16:43:19 +1000 Subject: [PATCH] Call ClientRequest.tx.send() even if there is an error Previously, tx would be dropped before send if: - the success case would have used tx to wait for further messages, - but the response was actually an error. Instead, send the error on `tx` and call `fail_with()` using the same error. To support this change, allow `fail_with()` to take a `PeerError` or a `SharedPeerError`. --- zebra-network/src/peer/connection.rs | 222 +++++++++++++++------------ 1 file changed, 122 insertions(+), 100 deletions(-) diff --git a/zebra-network/src/peer/connection.rs b/zebra-network/src/peer/connection.rs index 3628ceade..3d9887046 100644 --- a/zebra-network/src/peer/connection.rs +++ b/zebra-network/src/peer/connection.rs @@ -7,8 +7,7 @@ //! And it's unclear if these assumptions match the `zcashd` implementation. //! It should be refactored into a cleaner set of request/response pairs (#1515). -use std::collections::HashSet; -use std::sync::Arc; +use std::{collections::HashSet, fmt, sync::Arc}; use futures::{ channel::{mpsc, oneshot}, @@ -371,7 +370,7 @@ where Either::Left((None, _)) => { self.fail_with(PeerError::ConnectionClosed); } - Either::Left((Some(Err(e)), _)) => self.fail_with(e.into()), + Either::Left((Some(Err(e)), _)) => self.fail_with(e), Either::Left((Some(Ok(msg)), _)) => { self.handle_message_as_request(msg).await } @@ -405,7 +404,7 @@ where .await { Either::Left((None, _)) => self.fail_with(PeerError::ConnectionClosed), - Either::Left((Some(Err(e)), _)) => self.fail_with(e.into()), + Either::Left((Some(Err(e)), _)) => self.fail_with(e), Either::Left((Some(Ok(peer_msg)), _cancel)) => { // Try to process the message using the handler. // This extremely awkward construction avoids @@ -494,7 +493,10 @@ where } /// Marks the peer as having failed with error `e`. - fn fail_with(&mut self, e: PeerError) { + fn fail_with(&mut self, e: E) + where + E: Into + fmt::Display, + { debug!(%e, "failing peer service with error"); // Update the shared error slot let mut guard = self @@ -542,123 +544,143 @@ where // XXX(hdevalence) this is truly horrible, but let's fix it later - // Inner match returns Result with the new state or an error. - // Outer match updates state or fails. + // Inner matches return a Result with a new state or an (error, Option) + // Middle match returns Result with the new state or the (error, Option) + // Outer match updates state or fails, and sends the error on the Sender if it is Some match match (&self.state, request) { (Failed, _) => panic!("failed connection cannot handle requests"), (AwaitingResponse { .. }, _) => panic!("tried to update pending request"), - (AwaitingRequest, Peers) => self - .peer_tx - .send(Message::GetAddr) - .await - .map_err(|e| e.into()) - .map(|()| AwaitingResponse { + + (AwaitingRequest, Peers) => match self.peer_tx.send(Message::GetAddr).await { + Ok(()) => Ok(AwaitingResponse { handler: Handler::Peers, tx, span, }), - (AwaitingRequest, Ping(nonce)) => self - .peer_tx - .send(Message::Ping(nonce)) - .await - .map_err(|e| e.into()) - .map(|()| AwaitingResponse { + Err(e) => Err((e, Some(tx))), + }, + (AwaitingRequest, Ping(nonce)) => match self.peer_tx.send(Message::Ping(nonce)).await { + Ok(()) => Ok(AwaitingResponse { handler: Handler::Ping(nonce), tx, span, }), - (AwaitingRequest, BlocksByHash(hashes)) => self - .peer_tx - .send(Message::GetData( - hashes.iter().map(|h| (*h).into()).collect(), - )) - .await - .map_err(|e| e.into()) - .map(|()| AwaitingResponse { - handler: Handler::BlocksByHash { - blocks: Vec::with_capacity(hashes.len()), - hashes, - }, - tx, - span, - }), - (AwaitingRequest, TransactionsByHash(hashes)) => self - .peer_tx - .send(Message::GetData( - hashes.iter().map(|h| (*h).into()).collect(), - )) - .await - .map_err(|e| e.into()) - .map(|()| AwaitingResponse { - handler: Handler::TransactionsByHash { - transactions: Vec::with_capacity(hashes.len()), - hashes, - }, - tx, - span, - }), - (AwaitingRequest, FindBlocks { known_blocks, stop }) => self - .peer_tx - .send(Message::GetBlocks { known_blocks, stop }) - .await - .map_err(|e| e.into()) - .map(|()| AwaitingResponse { - handler: Handler::FindBlocks, - tx, - span, - }), - (AwaitingRequest, FindHeaders { known_blocks, stop }) => self - .peer_tx - .send(Message::GetHeaders { known_blocks, stop }) - .await - .map_err(|e| e.into()) - .map(|()| AwaitingResponse { - handler: Handler::FindHeaders, - tx, - span, - }), - (AwaitingRequest, MempoolTransactions) => self - .peer_tx - .send(Message::Mempool) - .await - .map_err(|e| e.into()) - .map(|()| AwaitingResponse { - handler: Handler::MempoolTransactions, - tx, - span, - }), + Err(e) => Err((e, Some(tx))), + }, + (AwaitingRequest, BlocksByHash(hashes)) => { + match self + .peer_tx + .send(Message::GetData( + hashes.iter().map(|h| (*h).into()).collect(), + )) + .await + { + Ok(()) => Ok(AwaitingResponse { + handler: Handler::BlocksByHash { + blocks: Vec::with_capacity(hashes.len()), + hashes, + }, + tx, + span, + }), + Err(e) => Err((e, Some(tx))), + } + } + (AwaitingRequest, TransactionsByHash(hashes)) => { + match self + .peer_tx + .send(Message::GetData( + hashes.iter().map(|h| (*h).into()).collect(), + )) + .await + { + Ok(()) => Ok(AwaitingResponse { + handler: Handler::TransactionsByHash { + transactions: Vec::with_capacity(hashes.len()), + hashes, + }, + tx, + span, + }), + Err(e) => Err((e, Some(tx))), + } + } + (AwaitingRequest, FindBlocks { known_blocks, stop }) => { + match self + .peer_tx + .send(Message::GetBlocks { known_blocks, stop }) + .await + { + Ok(()) => Ok(AwaitingResponse { + handler: Handler::FindBlocks, + tx, + span, + }), + Err(e) => Err((e, Some(tx))), + } + } + (AwaitingRequest, FindHeaders { known_blocks, stop }) => { + match self + .peer_tx + .send(Message::GetHeaders { known_blocks, stop }) + .await + { + Ok(()) => Ok(AwaitingResponse { + handler: Handler::FindHeaders, + tx, + span, + }), + Err(e) => Err((e, Some(tx))), + } + } + (AwaitingRequest, MempoolTransactions) => { + match self.peer_tx.send(Message::Mempool).await { + Ok(()) => Ok(AwaitingResponse { + handler: Handler::MempoolTransactions, + tx, + span, + }), + Err(e) => Err((e, Some(tx))), + } + } (AwaitingRequest, PushTransaction(transaction)) => { // Since we're not waiting for further messages, we need to // send a response before dropping tx. let _ = tx.send(Ok(Response::Nil)); - self.peer_tx - .send(Message::Tx(transaction)) - .await - .map_err(|e| e.into()) - .map(|()| AwaitingRequest) + match self.peer_tx.send(Message::Tx(transaction)).await { + Ok(()) => Ok(AwaitingRequest), + Err(e) => Err((e, None)), + } } (AwaitingRequest, AdvertiseTransactions(hashes)) => { let _ = tx.send(Ok(Response::Nil)); - self.peer_tx + match self + .peer_tx .send(Message::Inv(hashes.iter().map(|h| (*h).into()).collect())) .await - .map_err(|e| e.into()) - .map(|()| AwaitingRequest) + { + Ok(()) => Ok(AwaitingRequest), + Err(e) => Err((e, None)), + } } (AwaitingRequest, AdvertiseBlock(hash)) => { let _ = tx.send(Ok(Response::Nil)); - self.peer_tx - .send(Message::Inv(vec![hash.into()])) - .await - .map_err(|e| e.into()) - .map(|()| AwaitingRequest) + match self.peer_tx.send(Message::Inv(vec![hash.into()])).await { + Ok(()) => Ok(AwaitingRequest), + Err(e) => Err((e, None)), + } } } { Ok(new_state) => { self.state = new_state; self.request_timer = Some(sleep(constants::REQUEST_TIMEOUT)); } - Err(e) => self.fail_with(e), + Err((e, Some(tx))) => { + let e = SharedPeerError::from(e); + let _ = tx.send(Err(e.clone())); + self.fail_with(e); + } + Err((e, None)) => self.fail_with(e), } } @@ -671,7 +693,7 @@ where Message::Ping(nonce) => { trace!(?nonce, "responding to heartbeat"); if let Err(e) = self.peer_tx.send(Message::Pong(nonce)).await { - self.fail_with(e.into()); + self.fail_with(e); } return; } @@ -800,14 +822,14 @@ where Response::Nil => { /* generic success, do nothing */ } Response::Peers(addrs) => { if let Err(e) = self.peer_tx.send(Message::Addr(addrs)).await { - self.fail_with(e.into()); + self.fail_with(e); } } Response::Transactions(transactions) => { // Generate one tx message per transaction. for transaction in transactions.into_iter() { if let Err(e) = self.peer_tx.send(Message::Tx(transaction)).await { - self.fail_with(e.into()); + self.fail_with(e); } } } @@ -815,7 +837,7 @@ where // Generate one block message per block. for block in blocks.into_iter() { if let Err(e) = self.peer_tx.send(Message::Block(block)).await { - self.fail_with(e.into()); + self.fail_with(e); } } } @@ -825,12 +847,12 @@ where .send(Message::Inv(hashes.into_iter().map(Into::into).collect())) .await { - self.fail_with(e.into()) + self.fail_with(e) } } Response::BlockHeaders(headers) => { if let Err(e) = self.peer_tx.send(Message::Headers(headers)).await { - self.fail_with(e.into()) + self.fail_with(e) } } Response::TransactionHashes(hashes) => { @@ -839,7 +861,7 @@ where .send(Message::Inv(hashes.into_iter().map(Into::into).collect())) .await { - self.fail_with(e.into()) + self.fail_with(e) } } }