diff --git a/zebra-network/src/peer/connection.rs b/zebra-network/src/peer/connection.rs index 1656066f6..736c4f3fd 100644 --- a/zebra-network/src/peer/connection.rs +++ b/zebra-network/src/peer/connection.rs @@ -543,26 +543,32 @@ where return; } - // These matches return a Result with a new state or an (error, Option) + // These matches return a Result with (new_state, Option) or an (error, Sender) let new_state_result = match (&self.state, request) { (Failed, _) => panic!("failed connection cannot handle requests"), (AwaitingResponse { .. }, _) => panic!("tried to update pending request"), (AwaitingRequest, Peers) => match self.peer_tx.send(Message::GetAddr).await { - Ok(()) => Ok(AwaitingResponse { - handler: Handler::Peers, - tx, - span, - }), - Err(e) => Err((e, Some(tx))), + Ok(()) => Ok(( + AwaitingResponse { + handler: Handler::Peers, + tx, + span, + }, + None, + )), + Err(e) => Err((e, tx)), }, (AwaitingRequest, Ping(nonce)) => match self.peer_tx.send(Message::Ping(nonce)).await { - Ok(()) => Ok(AwaitingResponse { - handler: Handler::Ping(nonce), - tx, - span, - }), - Err(e) => Err((e, Some(tx))), + Ok(()) => Ok(( + AwaitingResponse { + handler: Handler::Ping(nonce), + tx, + span, + }, + None, + )), + Err(e) => Err((e, tx)), }, (AwaitingRequest, BlocksByHash(hashes)) => { match self @@ -572,15 +578,18 @@ where )) .await { - Ok(()) => Ok(AwaitingResponse { - handler: Handler::BlocksByHash { - blocks: Vec::with_capacity(hashes.len()), - hashes, + Ok(()) => Ok(( + AwaitingResponse { + handler: Handler::BlocksByHash { + blocks: Vec::with_capacity(hashes.len()), + hashes, + }, + tx, + span, }, - tx, - span, - }), - Err(e) => Err((e, Some(tx))), + None, + )), + Err(e) => Err((e, tx)), } } (AwaitingRequest, TransactionsByHash(hashes)) => { @@ -591,15 +600,18 @@ where )) .await { - Ok(()) => Ok(AwaitingResponse { - handler: Handler::TransactionsByHash { - transactions: Vec::with_capacity(hashes.len()), - hashes, + Ok(()) => Ok(( + AwaitingResponse { + handler: Handler::TransactionsByHash { + transactions: Vec::with_capacity(hashes.len()), + hashes, + }, + tx, + span, }, - tx, - span, - }), - Err(e) => Err((e, Some(tx))), + None, + )), + Err(e) => Err((e, tx)), } } (AwaitingRequest, FindBlocks { known_blocks, stop }) => { @@ -608,12 +620,15 @@ where .send(Message::GetBlocks { known_blocks, stop }) .await { - Ok(()) => Ok(AwaitingResponse { - handler: Handler::FindBlocks, - tx, - span, - }), - Err(e) => Err((e, Some(tx))), + Ok(()) => Ok(( + AwaitingResponse { + handler: Handler::FindBlocks, + tx, + span, + }, + None, + )), + Err(e) => Err((e, tx)), } } (AwaitingRequest, FindHeaders { known_blocks, stop }) => { @@ -622,64 +637,83 @@ where .send(Message::GetHeaders { known_blocks, stop }) .await { - Ok(()) => Ok(AwaitingResponse { - handler: Handler::FindHeaders, - tx, - span, - }), - Err(e) => Err((e, Some(tx))), + Ok(()) => Ok(( + AwaitingResponse { + handler: Handler::FindHeaders, + tx, + span, + }, + None, + )), + Err(e) => Err((e, tx)), } } (AwaitingRequest, MempoolTransactions) => { match self.peer_tx.send(Message::Mempool).await { - Ok(()) => Ok(AwaitingResponse { - handler: Handler::MempoolTransactions, - tx, - span, - }), - Err(e) => Err((e, Some(tx))), + Ok(()) => Ok(( + AwaitingResponse { + handler: Handler::MempoolTransactions, + tx, + span, + }, + None, + )), + Err(e) => Err((e, tx)), } } (AwaitingRequest, PushTransaction(transaction)) => { - // Since we're not waiting for further messages, we need to - // send a response before dropping tx. - let _ = tx.send(Ok(Response::Nil)); match self.peer_tx.send(Message::Tx(transaction)).await { - Ok(()) => Ok(AwaitingRequest), - Err(e) => Err((e, None)), + Ok(()) => Ok((AwaitingRequest, Some(tx))), + Err(e) => Err((e, tx)), } } (AwaitingRequest, AdvertiseTransactions(hashes)) => { - let _ = tx.send(Ok(Response::Nil)); match self .peer_tx .send(Message::Inv(hashes.iter().map(|h| (*h).into()).collect())) .await { - Ok(()) => Ok(AwaitingRequest), - Err(e) => Err((e, None)), + Ok(()) => Ok((AwaitingRequest, Some(tx))), + Err(e) => Err((e, tx)), } } (AwaitingRequest, AdvertiseBlock(hash)) => { - let _ = tx.send(Ok(Response::Nil)); match self.peer_tx.send(Message::Inv(vec![hash.into()])).await { - Ok(()) => Ok(AwaitingRequest), - Err(e) => Err((e, None)), + Ok(()) => Ok((AwaitingRequest, Some(tx))), + Err(e) => Err((e, tx)), } } }; // Updates state or fails. Sends the error on the Sender if it is Some. match new_state_result { - Ok(new_state) => { + Ok((AwaitingRequest, Some(tx))) => { + // Since we're not waiting for further messages, we need to + // send a response before dropping tx. + let _ = tx.send(Ok(Response::Nil)); + self.state = AwaitingRequest; + self.request_timer = Some(sleep(constants::REQUEST_TIMEOUT)); + } + Ok((new_state @ AwaitingResponse { .. }, None)) => { self.state = new_state; self.request_timer = Some(sleep(constants::REQUEST_TIMEOUT)); } - Err((e, Some(tx))) => { + Err((e, tx)) => { let e = SharedPeerError::from(e); let _ = tx.send(Err(e.clone())); self.fail_with(e); } - Err((e, None)) => self.fail_with(e), + // unreachable states + Ok((Failed, tx)) => unreachable!( + "failed client requests must use fail_with(error) to reach a Failed state. tx: {:?}", + tx + ), + Ok((AwaitingRequest, None)) => unreachable!( + "successful AwaitingRequest states must send a response on tx, but tx is None", + ), + Ok((AwaitingResponse { .. }, Some(tx))) => unreachable!( + "successful AwaitingResponse states must keep tx, but tx is Some: {:?}", + tx + ), }; }