From 5a9281a7a83aede263f9922c58d211c7f5b5ce22 Mon Sep 17 00:00:00 2001 From: Arya Date: Wed, 21 Feb 2024 18:29:13 -0500 Subject: [PATCH] fix(scan): Fix minor concurrency bug in the `scan` gRPC method (#8303) * update scan task to notify subscribe caller once the new subscribed_keys are sent on the watch channel * Fixes timing bug in scan gRPC method: Joins register/subscribe scan service calls, sends SubscribeResults request first, and filters out duplicate results from channel * Removes outdated TODO * wraps subscribed_keys in an Arc before sending to watch channel, fixes typo * Remove result senders for keys that have been removed --- zebra-grpc/src/server.rs | 68 ++++++++++++------- zebra-grpc/src/tests/snapshot.rs | 10 +-- zebra-scan/src/service.rs | 2 +- zebra-scan/src/service/scan_task/commands.rs | 46 +++++++------ zebra-scan/src/service/scan_task/executor.rs | 6 +- zebra-scan/src/service/scan_task/scan.rs | 40 +++++++---- .../src/service/scan_task/scan/scan_range.rs | 7 +- .../src/service/scan_task/tests/vectors.rs | 41 ++++++++--- zebra-scan/src/service/tests.rs | 9 +-- .../common/shielded_scan/subscribe_results.rs | 7 +- 10 files changed, 150 insertions(+), 86 deletions(-) diff --git a/zebra-grpc/src/server.rs b/zebra-grpc/src/server.rs index 71e7491af..8c8df4c3e 100644 --- a/zebra-grpc/src/server.rs +++ b/zebra-grpc/src/server.rs @@ -70,13 +70,24 @@ where .into_iter() .map(|KeyWithHeight { key, height }| (key, height)) .collect(); - - let ScanServiceResponse::RegisteredKeys(_) = self + let register_keys_response_fut = self .scan_service .clone() - .ready() - .and_then(|service| service.call(ScanServiceRequest::RegisterKeys(keys.clone()))) - .await + .oneshot(ScanServiceRequest::RegisterKeys(keys.clone())); + + let keys: Vec<_> = keys.into_iter().map(|(key, _start_at)| key).collect(); + + let subscribe_results_response_fut = + self.scan_service + .clone() + .oneshot(ScanServiceRequest::SubscribeResults( + keys.iter().cloned().collect(), + )); + + let (register_keys_response, subscribe_results_response) = + tokio::join!(register_keys_response_fut, subscribe_results_response_fut); + + let ScanServiceResponse::RegisteredKeys(_) = register_keys_response .map_err(|err| Status::unknown(format!("scan service returned error: {err}")))? else { return Err(Status::unknown( @@ -84,7 +95,14 @@ where )); }; - let keys: Vec<_> = keys.into_iter().map(|(key, _start_at)| key).collect(); + let ScanServiceResponse::SubscribeResults(mut results_receiver) = + subscribe_results_response + .map_err(|err| Status::unknown(format!("scan service returned error: {err}")))? + else { + return Err(Status::unknown( + "scan service returned an unexpected response", + )); + }; let ScanServiceResponse::Results(results) = self .scan_service @@ -99,29 +117,31 @@ where )); }; - let ScanServiceResponse::SubscribeResults(mut results_receiver) = self - .scan_service - .clone() - .ready() - .and_then(|service| { - service.call(ScanServiceRequest::SubscribeResults( - keys.iter().cloned().collect(), - )) - }) - .await - .map_err(|err| Status::unknown(format!("scan service returned error: {err}")))? - else { - return Err(Status::unknown( - "scan service returned an unexpected response", - )); - }; - let (response_sender, response_receiver) = tokio::sync::mpsc::channel(SCAN_RESPONDER_BUFFER_SIZE); let response_stream = ReceiverStream::new(response_receiver); tokio::spawn(async move { - let initial_results = process_results(keys, results); + let mut initial_results = process_results(keys, results); + + // Empty results receiver channel to filter out duplicate results between the channel and cache + while let Ok(ScanResult { key, height, tx_id }) = results_receiver.try_recv() { + let entry = initial_results + .entry(key) + .or_default() + .by_height + .entry(height.0) + .or_default(); + + let tx_id = Transaction { + hash: tx_id.to_string(), + }; + + // Add the scan result to the initial results if it's not already present. + if !entry.transactions.contains(&tx_id) { + entry.transactions.push(tx_id); + } + } let send_result = response_sender .send(Ok(ScanResponse { diff --git a/zebra-grpc/src/tests/snapshot.rs b/zebra-grpc/src/tests/snapshot.rs index 63f355aa9..37acb59fe 100644 --- a/zebra-grpc/src/tests/snapshot.rs +++ b/zebra-grpc/src/tests/snapshot.rs @@ -176,15 +176,15 @@ async fn test_mocked_rpc_response_data_for_network(network: Network, random_port .await .respond(ScanResponse::RegisteredKeys(vec![])); - mock_scan_service - .expect_request_that(|req| matches!(req, ScanRequest::Results(_))) - .await - .respond(ScanResponse::Results(fake_results_response)); - mock_scan_service .expect_request_that(|req| matches!(req, ScanRequest::SubscribeResults(_))) .await .respond(ScanResponse::SubscribeResults(fake_results_receiver)); + + mock_scan_service + .expect_request_that(|req| matches!(req, ScanRequest::Results(_))) + .await + .respond(ScanResponse::Results(fake_results_response)); }); } diff --git a/zebra-scan/src/service.rs b/zebra-scan/src/service.rs index fd6b75c9b..5c5aeeae1 100644 --- a/zebra-scan/src/service.rs +++ b/zebra-scan/src/service.rs @@ -169,7 +169,7 @@ impl Service for ScanService { let mut scan_task = self.scan_task.clone(); return async move { - let results_receiver = scan_task.subscribe(keys)?; + let results_receiver = scan_task.subscribe(keys).await?; Ok(Response::SubscribeResults(results_receiver)) } diff --git a/zebra-scan/src/service/scan_task/commands.rs b/zebra-scan/src/service/scan_task/commands.rs index 868200219..5a210c1ec 100644 --- a/zebra-scan/src/service/scan_task/commands.rs +++ b/zebra-scan/src/service/scan_task/commands.rs @@ -8,6 +8,7 @@ use tokio::sync::{ oneshot, }; +use tower::BoxError; use zcash_primitives::{sapling::SaplingIvk, zip32::DiversifiableFullViewingKey}; use zebra_chain::{block::Height, parameters::Network}; use zebra_node_services::scan_service::response::ScanResult; @@ -41,11 +42,11 @@ pub enum ScanTaskCommand { /// Start sending results for key hashes to `result_sender` SubscribeResults { - /// Sender for results - result_sender: Sender, - /// Key hashes to send the results of to result channel keys: HashSet, + + /// Returns the result receiver once the subscribed keys have been added. + rsp_tx: oneshot::Sender>, }, } @@ -69,6 +70,7 @@ impl ScanTask { (Vec, Vec, Height), >, HashMap>, + Vec<(Receiver, oneshot::Sender>)>, ), Report, > { @@ -76,6 +78,7 @@ impl ScanTask { let mut new_keys = HashMap::new(); let mut new_result_senders = HashMap::new(); + let mut new_result_receivers = Vec::new(); let sapling_activation_height = network.sapling_activation_height(); loop { @@ -142,13 +145,20 @@ impl ScanTask { let _ = done_tx.send(()); } - ScanTaskCommand::SubscribeResults { - result_sender, - keys, - } => { - let keys = keys + ScanTaskCommand::SubscribeResults { rsp_tx, keys } => { + let keys: Vec<_> = keys .into_iter() - .filter(|key| registered_keys.contains_key(key)); + .filter(|key| registered_keys.contains_key(key)) + .collect(); + + if keys.is_empty() { + continue; + } + + let (result_sender, result_receiver) = + tokio::sync::mpsc::channel(RESULTS_SENDER_BUFFER_SIZE); + + new_result_receivers.push((result_receiver, rsp_tx)); for key in keys { new_result_senders.insert(key, result_sender.clone()); @@ -157,7 +167,7 @@ impl ScanTask { } } - Ok((new_keys, new_result_senders)) + Ok((new_keys, new_result_senders, new_result_receivers)) } /// Sends a command to the scan task @@ -200,18 +210,14 @@ impl ScanTask { /// Sends a message to the scan task to start sending the results for the provided viewing keys to a channel. /// /// Returns the channel receiver. - pub fn subscribe( + pub async fn subscribe( &mut self, keys: HashSet, - ) -> Result, TrySendError> { - // TODO: Use a bounded channel - let (result_sender, result_receiver) = - tokio::sync::mpsc::channel(RESULTS_SENDER_BUFFER_SIZE); + ) -> Result, BoxError> { + let (rsp_tx, rsp_rx) = oneshot::channel(); - self.send(ScanTaskCommand::SubscribeResults { - result_sender, - keys, - }) - .map(|_| result_receiver) + self.send(ScanTaskCommand::SubscribeResults { keys, rsp_tx })?; + + Ok(rsp_rx.await?) } } diff --git a/zebra-scan/src/service/scan_task/executor.rs b/zebra-scan/src/service/scan_task/executor.rs index 00e3fa72c..eb6ecca46 100644 --- a/zebra-scan/src/service/scan_task/executor.rs +++ b/zebra-scan/src/service/scan_task/executor.rs @@ -1,6 +1,6 @@ //! The scan task executor -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use color_eyre::eyre::Report; use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; @@ -19,7 +19,7 @@ use super::scan::ScanRangeTaskBuilder; const EXECUTOR_BUFFER_SIZE: usize = 100; pub fn spawn_init( - subscribed_keys_receiver: tokio::sync::watch::Receiver>>, + subscribed_keys_receiver: watch::Receiver>>>, ) -> (Sender, JoinHandle>) { let (scan_task_sender, scan_task_receiver) = tokio::sync::mpsc::channel(EXECUTOR_BUFFER_SIZE); @@ -33,7 +33,7 @@ pub fn spawn_init( pub async fn scan_task_executor( mut scan_task_receiver: Receiver, - subscribed_keys_receiver: watch::Receiver>>, + subscribed_keys_receiver: watch::Receiver>>>, ) -> Result<(), Report> { let mut scan_range_tasks = FuturesUnordered::new(); diff --git a/zebra-scan/src/service/scan_task/scan.rs b/zebra-scan/src/service/scan_task/scan.rs index 7d0cbc16f..898183fe1 100644 --- a/zebra-scan/src/service/scan_task/scan.rs +++ b/zebra-scan/src/service/scan_task/scan.rs @@ -8,7 +8,10 @@ use std::{ use color_eyre::{eyre::eyre, Report}; use itertools::Itertools; -use tokio::{sync::mpsc::Sender, task::JoinHandle}; +use tokio::{ + sync::{mpsc::Sender, watch}, + task::JoinHandle, +}; use tower::{buffer::Buffer, util::BoxService, Service, ServiceExt}; use tracing::Instrument; @@ -116,10 +119,10 @@ pub async fn start( let mut subscribed_keys: HashMap> = HashMap::new(); let (subscribed_keys_sender, subscribed_keys_receiver) = - tokio::sync::watch::channel(subscribed_keys.clone()); + tokio::sync::watch::channel(Arc::new(subscribed_keys.clone())); let (scan_task_sender, scan_task_executor_handle) = - executor::spawn_init(subscribed_keys_receiver); + executor::spawn_init(subscribed_keys_receiver.clone()); let mut scan_task_executor_handle = Some(scan_task_executor_handle); // Give empty states time to verify some blocks before we start scanning. @@ -139,17 +142,23 @@ pub async fn start( let was_parsed_keys_empty = parsed_keys.is_empty(); - let (new_keys, new_result_senders) = + let (new_keys, new_result_senders, new_result_receivers) = ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?; + subscribed_keys.extend(new_result_senders); + // Drop any results senders that are closed from subscribed_keys + subscribed_keys.retain(|key, sender| !sender.is_closed() && parsed_keys.contains_key(key)); + // Send the latest version of `subscribed_keys` before spawning the scan range task - if !new_result_senders.is_empty() { - subscribed_keys.extend(new_result_senders); - // Ignore send errors, it's okay if there aren't any receivers. - let _ = subscribed_keys_sender.send(subscribed_keys.clone()); + subscribed_keys_sender + .send(Arc::new(subscribed_keys.clone())) + .expect("last receiver should not be dropped while this task is running"); + + for (result_receiver, rsp_tx) in new_result_receivers { + // Ignore send errors, we drop any closed results channels above. + let _ = rsp_tx.send(result_receiver); } - // TODO: Check if the `start_height` is at or above the current height if !new_keys.is_empty() { let state = state.clone(); let storage = storage.clone(); @@ -163,7 +172,9 @@ pub async fn start( if was_parsed_keys_empty { info!(?start_height, "setting new start height"); height = start_height; - } else if start_height < height { + } + // Skip spawning ScanRange task if `start_height` is at or above the current height + else if start_height < height { scan_task_sender .send(ScanRangeTaskBuilder::new(height, new_keys, state, storage)) .await @@ -179,7 +190,7 @@ pub async fn start( storage.clone(), key_heights.clone(), parsed_keys.clone(), - subscribed_keys.clone(), + subscribed_keys_receiver.clone(), ) .await?; @@ -241,7 +252,7 @@ pub async fn scan_height_and_store_results( storage: Storage, key_last_scanned_heights: Arc>, parsed_keys: HashMap, Vec)>, - subscribed_keys: HashMap>, + subscribed_keys_receiver: watch::Receiver>>>, ) -> Result, Report> { let network = storage.network(); @@ -295,7 +306,7 @@ pub async fn scan_height_and_store_results( _other => {} }; - let results_sender = subscribed_keys.get(&sapling_key).cloned(); + let subscribed_keys_receiver = subscribed_keys_receiver.clone(); let sapling_key = sapling_key.clone(); let block = block.clone(); @@ -322,7 +333,8 @@ pub async fn scan_height_and_store_results( let dfvk_res = scanned_block_to_db_result(dfvk_res); let ivk_res = scanned_block_to_db_result(ivk_res); - if let Some(results_sender) = results_sender { + let latest_subscribed_keys = subscribed_keys_receiver.borrow().clone(); + if let Some(results_sender) = latest_subscribed_keys.get(&sapling_key).cloned() { let results = dfvk_res.iter().chain(ivk_res.iter()); for (_tx_index, &tx_id) in results { diff --git a/zebra-scan/src/service/scan_task/scan/scan_range.rs b/zebra-scan/src/service/scan_task/scan/scan_range.rs index 7d36b917d..b4b75cfe5 100644 --- a/zebra-scan/src/service/scan_task/scan/scan_range.rs +++ b/zebra-scan/src/service/scan_task/scan/scan_range.rs @@ -56,7 +56,7 @@ impl ScanRangeTaskBuilder { // TODO: return a tuple with a shutdown sender pub fn spawn( self, - subscribed_keys_receiver: watch::Receiver>>, + subscribed_keys_receiver: watch::Receiver>>>, ) -> JoinHandle> { let Self { height_range, @@ -86,7 +86,7 @@ pub async fn scan_range( keys: HashMap, Vec, Height)>, state: State, storage: Storage, - subscribed_keys_receiver: watch::Receiver>>, + subscribed_keys_receiver: watch::Receiver>>>, ) -> Result<(), Report> { let sapling_activation_height = storage.network().sapling_activation_height(); // Do not scan and notify if we are below sapling activation height. @@ -116,7 +116,6 @@ pub async fn scan_range( .collect(); while height < stop_before_height { - let subscribed_keys = subscribed_keys_receiver.borrow().clone(); let scanned_height = scan_height_and_store_results( height, state.clone(), @@ -124,7 +123,7 @@ pub async fn scan_range( storage.clone(), key_heights.clone(), parsed_keys.clone(), - subscribed_keys, + subscribed_keys_receiver.clone(), ) .await?; diff --git a/zebra-scan/src/service/scan_task/tests/vectors.rs b/zebra-scan/src/service/scan_task/tests/vectors.rs index 5ad5ee589..9022da794 100644 --- a/zebra-scan/src/service/scan_task/tests/vectors.rs +++ b/zebra-scan/src/service/scan_task/tests/vectors.rs @@ -1,6 +1,9 @@ //! Fixed test vectors for the scan task. -use std::collections::{HashMap, HashSet}; +use std::{ + collections::{HashMap, HashSet}, + time::Duration, +}; use color_eyre::Report; @@ -24,7 +27,7 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> { sapling_keys.into_iter().zip((0..).map(Some)).collect(); mock_scan_task.register_keys(sapling_keys_with_birth_heights.clone())?; - let (new_keys, _new_results_senders) = + let (new_keys, _new_results_senders, _new_results_receivers) = ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?; // Check that it updated parsed_keys correctly and returned the right new keys when starting with an empty state @@ -45,7 +48,7 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> { // Check that no key should be added if they are all already known and the heights are the same - let (new_keys, _new_results_senders) = + let (new_keys, _new_results_senders, _new_results_receivers) = ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?; assert_eq!( @@ -71,7 +74,7 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> { mock_scan_task.register_keys(sapling_keys_with_birth_heights[10..20].to_vec())?; mock_scan_task.register_keys(sapling_keys_with_birth_heights[10..15].to_vec())?; - let (new_keys, _new_results_senders) = + let (new_keys, _new_results_senders, _new_results_receivers) = ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?; assert_eq!( @@ -91,7 +94,7 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> { let sapling_keys = mock_sapling_scanning_keys(30, network); let done_rx = mock_scan_task.remove_keys(&sapling_keys)?; - let (new_keys, _new_results_senders) = + let (new_keys, _new_results_senders, _new_results_receivers) = ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?; // Check that it sends the done notification successfully before returning and dropping `done_tx` @@ -110,7 +113,7 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> { mock_scan_task.remove_keys(&sapling_keys)?; - let (new_keys, _new_results_senders) = + let (new_keys, _new_results_senders, _new_results_receivers) = ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?; assert!( @@ -126,7 +129,7 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> { mock_scan_task.register_keys(sapling_keys_with_birth_heights[..2].to_vec())?; - let (new_keys, _new_results_senders) = + let (new_keys, _new_results_senders, _new_results_receivers) = ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?; assert_eq!( @@ -142,11 +145,31 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> { ); let subscribe_keys: HashSet = sapling_keys[..5].iter().cloned().collect(); - let mut result_receiver = mock_scan_task.subscribe(subscribe_keys.clone())?; + let result_receiver_fut = { + let mut mock_scan_task = mock_scan_task.clone(); + tokio::spawn(async move { mock_scan_task.subscribe(subscribe_keys.clone()).await }) + }; - let (_new_keys, new_results_senders) = + // Wait for spawned task to send subscribe message + tokio::time::sleep(Duration::from_secs(1)).await; + + let (_new_keys, new_results_senders, new_results_receivers) = ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?; + let (result_receiver, rsp_tx) = new_results_receivers + .into_iter() + .next() + .expect("there should be a new result receiver"); + + rsp_tx + .send(result_receiver) + .expect("should send response successfully"); + + let mut result_receiver = result_receiver_fut + .await + .expect("tokio task should join successfully") + .expect("should send and receive message successfully"); + let processed_subscribe_keys: HashSet = new_results_senders.keys().cloned().collect(); let expected_new_subscribe_keys: HashSet = sapling_keys[..2].iter().cloned().collect(); diff --git a/zebra-scan/src/service/tests.rs b/zebra-scan/src/service/tests.rs index 0173094e3..0ae172830 100644 --- a/zebra-scan/src/service/tests.rs +++ b/zebra-scan/src/service/tests.rs @@ -96,14 +96,15 @@ pub async fn scan_service_subscribes_to_results_correctly() -> Result<()> { let expected_keys = keys.iter().cloned().collect(); let cmd_handler_fut = tokio::spawn(async move { - let Some(ScanTaskCommand::SubscribeResults { - result_sender: _, - keys, - }) = cmd_receiver.recv().await + let Some(ScanTaskCommand::SubscribeResults { rsp_tx, keys }) = cmd_receiver.recv().await else { panic!("should successfully receive SubscribeResults message"); }; + let (_results_sender, results_receiver) = tokio::sync::mpsc::channel(1); + rsp_tx + .send(results_receiver) + .expect("should send response successfully"); assert_eq!(keys, expected_keys, "keys should match the request keys"); }); diff --git a/zebrad/tests/common/shielded_scan/subscribe_results.rs b/zebrad/tests/common/shielded_scan/subscribe_results.rs index add0368f6..9806619ad 100644 --- a/zebrad/tests/common/shielded_scan/subscribe_results.rs +++ b/zebrad/tests/common/shielded_scan/subscribe_results.rs @@ -89,11 +89,14 @@ pub(crate) async fn run() -> Result<()> { scan_task.register_keys( keys.iter() .cloned() - .map(|key| (key, Some(736000))) + .map(|key| (key, Some(780_000))) .collect(), )?; - let mut result_receiver = scan_task.subscribe(keys.into_iter().collect())?; + let mut result_receiver = scan_task + .subscribe(keys.into_iter().collect()) + .await + .expect("should send and receive message successfully"); // Wait for the scanner to send a result in the channel let result = tokio::time::timeout(WAIT_FOR_RESULTS_DURATION, result_receiver.recv()).await?;