diff --git a/zebrad/src/components/mempool.rs b/zebrad/src/components/mempool.rs index d1d55a214..64b05f725 100644 --- a/zebrad/src/components/mempool.rs +++ b/zebrad/src/components/mempool.rs @@ -83,7 +83,7 @@ enum ActiveState { /// ##: Correctness: only components internal to the [`Mempool`] struct are allowed to /// inject transactions into `storage`, as transactions must be verified beforehand. storage: storage::Storage, - /// The transaction dowload and verify stream. + /// The transaction download and verify stream. tx_downloads: Pin>, }, } @@ -246,9 +246,19 @@ impl Service for Mempool { storage, tx_downloads, } => { - // Clear the mempool if there has been a chain tip reset. - if let Some(TipAction::Reset { .. }) = self.chain_tip_change.last_tip_change() { - storage.clear(); + if let Some(tip_action) = self.chain_tip_change.last_tip_change() { + match tip_action { + // Clear the mempool if there has been a chain tip reset. + TipAction::Reset { .. } => { + storage.clear(); + } + // Cancel downloads/verifications of transactions with the same + // IDs as recently mined transactions. + TipAction::Grow { block } => { + let txid_set = block.transaction_hashes.iter().collect(); + tx_downloads.cancel(txid_set); + } + } } // Clean up completed download tasks and add to mempool if successful. @@ -268,7 +278,7 @@ impl Service for Mempool { ActiveState::Disabled => { // When the mempool is disabled we still return that the service is ready. // Otherwise, callers could block waiting for the mempool to be enabled, - // which may not be the desired behaviour. + // which may not be the desired behavior. } } diff --git a/zebrad/src/components/mempool/downloads.rs b/zebrad/src/components/mempool/downloads.rs index 446369885..05524a6bd 100644 --- a/zebrad/src/components/mempool/downloads.rs +++ b/zebrad/src/components/mempool/downloads.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, pin::Pin, task::{Context, Poll}, time::Duration, @@ -16,7 +16,7 @@ use tokio::{sync::oneshot, task::JoinHandle}; use tower::{Service, ServiceExt}; use tracing_futures::Instrument; -use zebra_chain::transaction::{UnminedTx, UnminedTxId}; +use zebra_chain::transaction::{self, UnminedTx, UnminedTxId}; use zebra_consensus::transaction as tx; use zebra_network as zn; use zebra_state as zs; @@ -315,6 +315,25 @@ where Ok(()) } + /// Cancel download/verification tasks of transactions with the + /// given transaction hash (see [`UnminedTxId::mined_id`]). + pub fn cancel(&mut self, mined_ids: HashSet<&transaction::Hash>) { + // TODO: this can be simplified with [`HashMap::drain_filter`] which + // is currently nightly-only experimental API. + let removed_txids: Vec = self + .cancel_handles + .keys() + .filter(|txid| mined_ids.contains(&txid.mined_id())) + .cloned() + .collect(); + + for txid in removed_txids { + if let Some(handle) = self.cancel_handles.remove(&txid) { + let _ = handle.send(()); + } + } + } + /// Get the number of currently in-flight download tasks. // Note: copied from zebrad/src/components/sync/downloads.rs #[allow(dead_code)] diff --git a/zebrad/src/components/mempool/tests.rs b/zebrad/src/components/mempool/tests.rs index dfef11a99..6e7c4b872 100644 --- a/zebrad/src/components/mempool/tests.rs +++ b/zebrad/src/components/mempool/tests.rs @@ -1,9 +1,12 @@ use super::*; use color_eyre::Report; -use std::collections::HashSet; +use std::{collections::HashSet, sync::Arc}; use storage::tests::unmined_transactions_in_blocks; +use tokio::time; use tower::{ServiceBuilder, ServiceExt}; +use zebra_chain::block::Block; +use zebra_chain::serialization::ZcashDeserializeInto; use zebra_consensus::Config as ConsensusConfig; use zebra_state::Config as StateConfig; use zebra_test::mock_service::MockService; @@ -348,3 +351,133 @@ async fn mempool_service_disabled() -> Result<(), Report> { Ok(()) } + +#[tokio::test] +async fn mempool_cancel_mined() -> Result<(), Report> { + let block1: Arc = zebra_test::vectors::BLOCK_MAINNET_1_BYTES + .zcash_deserialize_into() + .unwrap(); + let block2: Arc = zebra_test::vectors::BLOCK_MAINNET_2_BYTES + .zcash_deserialize_into() + .unwrap(); + + // Using the mainnet for now + let network = Network::Mainnet; + let consensus_config = ConsensusConfig::default(); + let state_config = StateConfig::ephemeral(); + let peer_set = MockService::build().for_unit_tests(); + let (sync_status, mut recent_syncs) = SyncStatus::new(); + let (state, latest_chain_tip, chain_tip_change) = + zebra_state::init(state_config.clone(), network); + + let mut state_service = ServiceBuilder::new().buffer(1).service(state); + let (_chain_verifier, tx_verifier) = + zebra_consensus::chain::init(consensus_config.clone(), network, state_service.clone()) + .await; + + time::pause(); + + // Start the mempool service + let mut mempool = Mempool::new( + network, + Buffer::new(BoxService::new(peer_set), 1), + state_service.clone(), + tx_verifier, + sync_status, + latest_chain_tip, + chain_tip_change, + ); + + // Enable the mempool + let _ = mempool.enable(&mut recent_syncs).await; + assert!(mempool.is_enabled()); + + // Push the genesis block to the state + let genesis_block: Arc = zebra_test::vectors::BLOCK_MAINNET_GENESIS_BYTES + .zcash_deserialize_into() + .unwrap(); + state_service + .ready_and() + .await + .unwrap() + .call(zebra_state::Request::CommitFinalizedBlock( + genesis_block.clone().into(), + )) + .await + .unwrap(); + + // Queue transaction from block 2 for download + let txid = block2.transactions[0].unmined_id(); + let response = mempool + .ready_and() + .await + .unwrap() + .call(Request::Queue(vec![txid.into()])) + .await + .unwrap(); + let queued_responses = match response { + Response::Queued(queue_responses) => queue_responses, + _ => unreachable!("will never happen in this test"), + }; + assert_eq!(queued_responses.len(), 1); + assert!(queued_responses[0].is_ok()); + assert_eq!(mempool.tx_downloads().in_flight(), 1); + + // Query the mempool to make it poll chain_tip_change + let _response = mempool + .ready_and() + .await + .unwrap() + .call(Request::TransactionIds) + .await + .unwrap(); + + // Push block 1 to the state + state_service + .ready_and() + .await + .unwrap() + .call(zebra_state::Request::CommitFinalizedBlock( + block1.clone().into(), + )) + .await + .unwrap(); + + // Query the mempool to make it poll chain_tip_change + let _response = mempool + .ready_and() + .await + .unwrap() + .call(Request::TransactionIds) + .await + .unwrap(); + + // Push block 2 to the state + state_service + .oneshot(zebra_state::Request::CommitFinalizedBlock( + block2.clone().into(), + )) + .await + .unwrap(); + + // This is done twice because after the first query the cancellation + // is picked up by select!, and after the second the mempool gets the + // result and the download future is removed. + for _ in 0..2 { + // Query the mempool just to poll it and make it cancel the download. + let _response = mempool + .ready_and() + .await + .unwrap() + .call(Request::TransactionIds) + .await + .unwrap(); + // Sleep to avoid starvation and make sure the cancellation is picked up. + time::sleep(time::Duration::from_millis(100)).await; + } + + // Check if download was cancelled. + assert_eq!(mempool.tx_downloads().in_flight(), 0); + + Ok(()) +}