From 40c907dd09d082ec0d191a3050b39dcda19319db Mon Sep 17 00:00:00 2001
From: teor
Date: Tue, 19 Oct 2021 01:31:11 +1000
Subject: [PATCH] Remove duplicate IDs in mempool requests and responses
 (#2887)

* Guarantee unique IDs in mempool service responses

* Guarantee unique IDs in crawler task mempool Queue requests

Also update the tests to use unique IDs.

Co-authored-by: Conrado Gouvea
---
 zebrad/src/components/inbound.rs            |  2 +-
 zebrad/src/components/mempool.rs            |  4 +-
 zebrad/src/components/mempool/crawler.rs    | 15 ++--
 .../src/components/mempool/crawler/tests.rs | 83 +++++++++++++------
 4 files changed, 70 insertions(+), 34 deletions(-)

diff --git a/zebrad/src/components/inbound.rs b/zebrad/src/components/inbound.rs
index afc2fa954..d5f2d3930 100644
--- a/zebrad/src/components/inbound.rs
+++ b/zebrad/src/components/inbound.rs
@@ -396,7 +396,7 @@ impl Service<zn::Request> for Inbound {
             if let Setup::Initialized { mempool, .. } = &mut self.network_setup {
                 mempool.clone().oneshot(mempool::Request::TransactionIds).map_ok(|resp| match resp {
                     mempool::Response::TransactionIds(transaction_ids) if transaction_ids.is_empty() => zn::Response::Nil,
-                    mempool::Response::TransactionIds(transaction_ids) => zn::Response::TransactionIds(transaction_ids),
+                    mempool::Response::TransactionIds(transaction_ids) => zn::Response::TransactionIds(transaction_ids.into_iter().collect()),
                     _ => unreachable!("Mempool component should always respond to a `TransactionIds` request with a `TransactionIds` response"),
                 })
                 .boxed()
diff --git a/zebrad/src/components/mempool.rs b/zebrad/src/components/mempool.rs
index 34586c7ea..066ae57bb 100644
--- a/zebrad/src/components/mempool.rs
+++ b/zebrad/src/components/mempool.rs
@@ -71,8 +71,8 @@ pub enum Request {
 #[derive(Debug)]
 pub enum Response {
     Transactions(Vec<UnminedTx>),
-    TransactionIds(Vec<UnminedTxId>),
-    RejectedTransactionIds(Vec<UnminedTxId>),
+    TransactionIds(HashSet<UnminedTxId>),
+    RejectedTransactionIds(HashSet<UnminedTxId>),
     Queued(Vec<Result<(), MempoolError>>),
 }
diff --git a/zebrad/src/components/mempool/crawler.rs b/zebrad/src/components/mempool/crawler.rs
index fa156caf0..ce5d91b2c 100644
--- a/zebrad/src/components/mempool/crawler.rs
+++ b/zebrad/src/components/mempool/crawler.rs
@@ -2,13 +2,13 @@
 //!
 //! The crawler periodically requests transactions from peers in order to populate the mempool.
 
-use std::time::Duration;
+use std::{collections::HashSet, time::Duration};
 
 use futures::{future, pin_mut, stream::FuturesUnordered, StreamExt};
 use tokio::{sync::watch, task::JoinHandle, time::sleep};
 use tower::{timeout::Timeout, BoxError, Service, ServiceExt};
 
-use zebra_chain::block::Height;
+use zebra_chain::{block::Height, transaction::UnminedTxId};
 use zebra_network as zn;
 use zebra_state::ChainTipChange;
@@ -171,8 +171,8 @@ where
     /// Handle a peer's response to the crawler's request for transactions.
     async fn handle_response(&mut self, response: zn::Response) -> Result<(), BoxError> {
-        let transaction_ids: Vec<_> = match response {
-            zn::Response::TransactionIds(ids) => ids.into_iter().map(Gossip::Id).collect(),
+        let transaction_ids: HashSet<_> = match response {
+            zn::Response::TransactionIds(ids) => ids.into_iter().collect(),
             _ => unreachable!("Peer set did not respond with transaction IDs to mempool crawler"),
         };
@@ -189,7 +189,12 @@
     }
 
     /// Forward the crawled transactions IDs to the mempool transaction downloader.
-    async fn queue_transactions(&mut self, transaction_ids: Vec<UnminedTxId>) -> Result<(), BoxError> {
+    async fn queue_transactions(
+        &mut self,
+        transaction_ids: HashSet<UnminedTxId>,
+    ) -> Result<(), BoxError> {
+        let transaction_ids = transaction_ids.into_iter().map(Gossip::Id).collect();
+
         let call_result = self
             .mempool
             .ready_and()
diff --git a/zebrad/src/components/mempool/crawler/tests.rs b/zebrad/src/components/mempool/crawler/tests.rs
index 29279e2f8..819b22ebd 100644
--- a/zebrad/src/components/mempool/crawler/tests.rs
+++ b/zebrad/src/components/mempool/crawler/tests.rs
@@ -1,6 +1,9 @@
-use std::time::Duration;
+use std::{collections::HashSet, time::Duration};
 
-use proptest::{collection::vec, prelude::*};
+use proptest::{
+    collection::{hash_set, vec},
+    prelude::*,
+};
 use tokio::time;
 
 use zebra_chain::{parameters::Network, transaction::UnminedTxId};
@@ -68,7 +71,7 @@ proptest! {
         for _ in 0..CRAWL_ITERATIONS {
             for _ in 0..FANOUT {
                 if mempool_is_enabled {
-                    respond_with_transaction_ids(&mut peer_set, vec![]).await?;
+                    respond_with_transaction_ids(&mut peer_set, HashSet::new()).await?;
                 } else {
                     peer_set.expect_no_requests().await?;
                 }
@@ -96,7 +99,7 @@
    /// the mempool.
    #[test]
    fn crawled_transactions_are_forwarded_to_downloader(
-        transaction_ids in vec(any::<UnminedTxId>(), 1..MAX_CRAWLED_TX),
+        transaction_ids in hash_set(any::<UnminedTxId>(), 1..MAX_CRAWLED_TX),
    ) {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
@@ -136,11 +139,15 @@
    #[test]
    fn transaction_id_forwarding_errors_dont_stop_the_crawler(
        service_call_error in any::<MempoolError>(),
-        transaction_ids_for_call_failure in vec(any::<UnminedTxId>(), 1..MAX_CRAWLED_TX),
+        transaction_ids_for_call_failure in hash_set(any::<UnminedTxId>(), 1..MAX_CRAWLED_TX),
        transaction_ids_and_responses in vec(any::<(UnminedTxId, Result<(), MempoolError>)>(), 1..MAX_CRAWLED_TX),
-        transaction_ids_for_return_to_normal in vec(any::<UnminedTxId>(), 1..MAX_CRAWLED_TX),
+        transaction_ids_for_return_to_normal in hash_set(any::<UnminedTxId>(), 1..MAX_CRAWLED_TX),
    ) {
+        // Make transaction_ids_and_responses unique
+        let unique_transaction_ids_and_responses: HashSet<UnminedTxId> = transaction_ids_and_responses.iter().map(|(id, _result)| id).copied().collect();
+        let transaction_ids_and_responses: Vec<(UnminedTxId, Result<(), MempoolError>)> = unique_transaction_ids_and_responses.iter().map(|unique_id| transaction_ids_and_responses.iter().find(|(id, _result)| id == unique_id).unwrap()).cloned().collect();
+
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
@@ -158,11 +165,11 @@
        // Prepare to simulate download errors.
        let download_result_count = transaction_ids_and_responses.len();
-        let mut transaction_ids_for_download_errors = Vec::with_capacity(download_result_count);
+        let mut transaction_ids_for_download_errors = HashSet::with_capacity(download_result_count);
        let mut download_result_list = Vec::with_capacity(download_result_count);
 
        for (transaction_id, result) in transaction_ids_and_responses {
-            transaction_ids_for_download_errors.push(transaction_id);
+            transaction_ids_for_download_errors.insert(transaction_id);
            download_result_list.push(result);
        }
@@ -257,12 +264,14 @@ fn setup_crawler() -> (
 /// Intercept a request for mempool transaction IDs and respond with the `transaction_ids` list.
 async fn respond_with_transaction_ids(
     peer_set: &mut MockPeerSet,
-    transaction_ids: Vec<UnminedTxId>,
+    transaction_ids: HashSet<UnminedTxId>,
 ) -> Result<(), TestCaseError> {
     peer_set
         .expect_request(zn::Request::MempoolTransactionIds)
         .await?
-        .respond(zn::Response::TransactionIds(transaction_ids));
+        .respond(zn::Response::TransactionIds(
+            transaction_ids.into_iter().collect(),
+        ));
 
     Ok(())
 }
@@ -280,7 +289,7 @@ async fn respond_with_transaction_ids(
 /// If `responses` contains more items than the [`FANOUT`] number.
 async fn crawler_iteration(
     peer_set: &mut MockPeerSet,
-    responses: Vec<Vec<UnminedTxId>>,
+    responses: Vec<HashSet<UnminedTxId>>,
 ) -> Result<(), TestCaseError> {
     let empty_responses = FANOUT
         .checked_sub(responses.len())
@@ -291,7 +300,7 @@
     }
 
     for _ in 0..empty_responses {
-        respond_with_transaction_ids(peer_set, vec![]).await?;
+        respond_with_transaction_ids(peer_set, HashSet::new()).await?;
     }
 
     peer_set.expect_no_requests().await?;
@@ -310,16 +319,27 @@
 /// If `response` and `expected_transaction_ids` have different sizes.
 async fn respond_to_queue_request(
     mempool: &mut MockMempool,
-    expected_transaction_ids: Vec<UnminedTxId>,
+    expected_transaction_ids: HashSet<UnminedTxId>,
     response: Vec<Result<(), MempoolError>>,
 ) -> Result<(), TestCaseError> {
-    let request_parameter = expected_transaction_ids
-        .into_iter()
-        .map(Gossip::Id)
-        .collect();
-
     mempool
-        .expect_request(mempool::Request::Queue(request_parameter))
+        .expect_request_that(|req| {
+            if let mempool::Request::Queue(req) = req {
+                let ids: HashSet<UnminedTxId> = req
+                    .iter()
+                    .filter_map(|gossip| {
+                        if let Gossip::Id(id) = gossip {
+                            Some(*id)
+                        } else {
+                            None
+                        }
+                    })
+                    .collect();
+                ids == expected_transaction_ids
+            } else {
+                false
+            }
+        })
         .await?
         .respond(mempool::Response::Queued(response));
 
     Ok(())
 }
@@ -333,16 +353,27 @@
 /// from queuing the transactions for downloading.
 async fn respond_to_queue_request_with_error(
     mempool: &mut MockMempool,
-    expected_transaction_ids: Vec<UnminedTxId>,
+    expected_transaction_ids: HashSet<UnminedTxId>,
     error: MempoolError,
 ) -> Result<(), TestCaseError> {
-    let request_parameter = expected_transaction_ids
-        .into_iter()
-        .map(Gossip::Id)
-        .collect();
-
     mempool
-        .expect_request(mempool::Request::Queue(request_parameter))
+        .expect_request_that(|req| {
+            if let mempool::Request::Queue(req) = req {
+                let ids: HashSet<UnminedTxId> = req
+                    .iter()
+                    .filter_map(|gossip| {
+                        if let Gossip::Id(id) = gossip {
+                            Some(*id)
+                        } else {
+                            None
+                        }
+                    })
+                    .collect();
+                ids == expected_transaction_ids
+            } else {
+                false
+            }
+        })
        .await?
        .respond(Err(error));
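
The following is a minimal, standalone sketch (not part of the patch) of the deduplication guarantee that motivates the Vec-to-HashSet change above: collecting IDs into a HashSet silently drops duplicates, so any response or Queue request built from the set can never advertise the same transaction twice. `TxId` and `dedup_ids` are hypothetical stand-ins for illustration; in Zebra the element type is `zebra_chain::transaction::UnminedTxId`, but any `Eq + Hash` ID type behaves the same way.

use std::collections::HashSet;

// Hypothetical stand-in for zebra_chain::transaction::UnminedTxId.
type TxId = [u8; 32];

/// Collecting into a HashSet drops duplicates, so anything built from the set
/// (for example a `TransactionIds` response) contains each ID at most once.
fn dedup_ids(ids: Vec<TxId>) -> HashSet<TxId> {
    ids.into_iter().collect()
}

fn main() {
    let id_a = [1u8; 32];
    let id_b = [2u8; 32];

    // A caller (or a remote peer) may hand over the same ID more than once.
    let raw = vec![id_a, id_b, id_a];

    let unique = dedup_ids(raw);
    assert_eq!(unique.len(), 2);

    // Converting back to a Vec, as the inbound service does when it builds
    // `zn::Response::TransactionIds`, preserves that uniqueness.
    let wire_format: Vec<TxId> = unique.into_iter().collect();
    assert_eq!(wire_format.len(), 2);
}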