fix(scan): Fix minor concurrency bug in the `scan` gRPC method (#8303)

* update scan task to notify subscribe caller once the new subscribed_keys are sent on the watch channel

* Fixes timing bug in scan gRPC method:

Joins register/subscribe scan service calls, sends SubscribeResults request first, and filters out duplicate results from channel

* Removes outdated TODO

* wraps subscribed_keys in an Arc before sending to watch channel, fixes typo

* Remove result senders for keys that have been removed
This commit is contained in:
Arya 2024-02-21 18:29:13 -05:00 committed by GitHub
parent c6b56b492a
commit 5a9281a7a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 150 additions and 86 deletions

View File

@ -70,13 +70,24 @@ where
.into_iter()
.map(|KeyWithHeight { key, height }| (key, height))
.collect();
let ScanServiceResponse::RegisteredKeys(_) = self
let register_keys_response_fut = self
.scan_service
.clone()
.ready()
.and_then(|service| service.call(ScanServiceRequest::RegisterKeys(keys.clone())))
.await
.oneshot(ScanServiceRequest::RegisterKeys(keys.clone()));
let keys: Vec<_> = keys.into_iter().map(|(key, _start_at)| key).collect();
let subscribe_results_response_fut =
self.scan_service
.clone()
.oneshot(ScanServiceRequest::SubscribeResults(
keys.iter().cloned().collect(),
));
let (register_keys_response, subscribe_results_response) =
tokio::join!(register_keys_response_fut, subscribe_results_response_fut);
let ScanServiceResponse::RegisteredKeys(_) = register_keys_response
.map_err(|err| Status::unknown(format!("scan service returned error: {err}")))?
else {
return Err(Status::unknown(
@ -84,7 +95,14 @@ where
));
};
let keys: Vec<_> = keys.into_iter().map(|(key, _start_at)| key).collect();
let ScanServiceResponse::SubscribeResults(mut results_receiver) =
subscribe_results_response
.map_err(|err| Status::unknown(format!("scan service returned error: {err}")))?
else {
return Err(Status::unknown(
"scan service returned an unexpected response",
));
};
let ScanServiceResponse::Results(results) = self
.scan_service
@ -99,29 +117,31 @@ where
));
};
let ScanServiceResponse::SubscribeResults(mut results_receiver) = self
.scan_service
.clone()
.ready()
.and_then(|service| {
service.call(ScanServiceRequest::SubscribeResults(
keys.iter().cloned().collect(),
))
})
.await
.map_err(|err| Status::unknown(format!("scan service returned error: {err}")))?
else {
return Err(Status::unknown(
"scan service returned an unexpected response",
));
};
let (response_sender, response_receiver) =
tokio::sync::mpsc::channel(SCAN_RESPONDER_BUFFER_SIZE);
let response_stream = ReceiverStream::new(response_receiver);
tokio::spawn(async move {
let initial_results = process_results(keys, results);
let mut initial_results = process_results(keys, results);
// Empty results receiver channel to filter out duplicate results between the channel and cache
while let Ok(ScanResult { key, height, tx_id }) = results_receiver.try_recv() {
let entry = initial_results
.entry(key)
.or_default()
.by_height
.entry(height.0)
.or_default();
let tx_id = Transaction {
hash: tx_id.to_string(),
};
// Add the scan result to the initial results if it's not already present.
if !entry.transactions.contains(&tx_id) {
entry.transactions.push(tx_id);
}
}
let send_result = response_sender
.send(Ok(ScanResponse {

View File

@ -176,15 +176,15 @@ async fn test_mocked_rpc_response_data_for_network(network: Network, random_port
.await
.respond(ScanResponse::RegisteredKeys(vec![]));
mock_scan_service
.expect_request_that(|req| matches!(req, ScanRequest::Results(_)))
.await
.respond(ScanResponse::Results(fake_results_response));
mock_scan_service
.expect_request_that(|req| matches!(req, ScanRequest::SubscribeResults(_)))
.await
.respond(ScanResponse::SubscribeResults(fake_results_receiver));
mock_scan_service
.expect_request_that(|req| matches!(req, ScanRequest::Results(_)))
.await
.respond(ScanResponse::Results(fake_results_response));
});
}

View File

@ -169,7 +169,7 @@ impl Service<Request> for ScanService {
let mut scan_task = self.scan_task.clone();
return async move {
let results_receiver = scan_task.subscribe(keys)?;
let results_receiver = scan_task.subscribe(keys).await?;
Ok(Response::SubscribeResults(results_receiver))
}

View File

@ -8,6 +8,7 @@ use tokio::sync::{
oneshot,
};
use tower::BoxError;
use zcash_primitives::{sapling::SaplingIvk, zip32::DiversifiableFullViewingKey};
use zebra_chain::{block::Height, parameters::Network};
use zebra_node_services::scan_service::response::ScanResult;
@ -41,11 +42,11 @@ pub enum ScanTaskCommand {
/// Start sending results for key hashes to `result_sender`
SubscribeResults {
/// Sender for results
result_sender: Sender<ScanResult>,
/// Key hashes to send the results of to result channel
keys: HashSet<String>,
/// Returns the result receiver once the subscribed keys have been added.
rsp_tx: oneshot::Sender<Receiver<ScanResult>>,
},
}
@ -69,6 +70,7 @@ impl ScanTask {
(Vec<DiversifiableFullViewingKey>, Vec<SaplingIvk>, Height),
>,
HashMap<SaplingScanningKey, Sender<ScanResult>>,
Vec<(Receiver<ScanResult>, oneshot::Sender<Receiver<ScanResult>>)>,
),
Report,
> {
@ -76,6 +78,7 @@ impl ScanTask {
let mut new_keys = HashMap::new();
let mut new_result_senders = HashMap::new();
let mut new_result_receivers = Vec::new();
let sapling_activation_height = network.sapling_activation_height();
loop {
@ -142,13 +145,20 @@ impl ScanTask {
let _ = done_tx.send(());
}
ScanTaskCommand::SubscribeResults {
result_sender,
keys,
} => {
let keys = keys
ScanTaskCommand::SubscribeResults { rsp_tx, keys } => {
let keys: Vec<_> = keys
.into_iter()
.filter(|key| registered_keys.contains_key(key));
.filter(|key| registered_keys.contains_key(key))
.collect();
if keys.is_empty() {
continue;
}
let (result_sender, result_receiver) =
tokio::sync::mpsc::channel(RESULTS_SENDER_BUFFER_SIZE);
new_result_receivers.push((result_receiver, rsp_tx));
for key in keys {
new_result_senders.insert(key, result_sender.clone());
@ -157,7 +167,7 @@ impl ScanTask {
}
}
Ok((new_keys, new_result_senders))
Ok((new_keys, new_result_senders, new_result_receivers))
}
/// Sends a command to the scan task
@ -200,18 +210,14 @@ impl ScanTask {
/// Sends a message to the scan task to start sending the results for the provided viewing keys to a channel.
///
/// Returns the channel receiver.
pub fn subscribe(
pub async fn subscribe(
&mut self,
keys: HashSet<SaplingScanningKey>,
) -> Result<Receiver<ScanResult>, TrySendError<ScanTaskCommand>> {
// TODO: Use a bounded channel
let (result_sender, result_receiver) =
tokio::sync::mpsc::channel(RESULTS_SENDER_BUFFER_SIZE);
) -> Result<Receiver<ScanResult>, BoxError> {
let (rsp_tx, rsp_rx) = oneshot::channel();
self.send(ScanTaskCommand::SubscribeResults {
result_sender,
keys,
})
.map(|_| result_receiver)
self.send(ScanTaskCommand::SubscribeResults { keys, rsp_tx })?;
Ok(rsp_rx.await?)
}
}

View File

@ -1,6 +1,6 @@
//! The scan task executor
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};
use color_eyre::eyre::Report;
use futures::{stream::FuturesUnordered, FutureExt, StreamExt};
@ -19,7 +19,7 @@ use super::scan::ScanRangeTaskBuilder;
const EXECUTOR_BUFFER_SIZE: usize = 100;
pub fn spawn_init(
subscribed_keys_receiver: tokio::sync::watch::Receiver<HashMap<String, Sender<ScanResult>>>,
subscribed_keys_receiver: watch::Receiver<Arc<HashMap<String, Sender<ScanResult>>>>,
) -> (Sender<ScanRangeTaskBuilder>, JoinHandle<Result<(), Report>>) {
let (scan_task_sender, scan_task_receiver) = tokio::sync::mpsc::channel(EXECUTOR_BUFFER_SIZE);
@ -33,7 +33,7 @@ pub fn spawn_init(
pub async fn scan_task_executor(
mut scan_task_receiver: Receiver<ScanRangeTaskBuilder>,
subscribed_keys_receiver: watch::Receiver<HashMap<String, Sender<ScanResult>>>,
subscribed_keys_receiver: watch::Receiver<Arc<HashMap<String, Sender<ScanResult>>>>,
) -> Result<(), Report> {
let mut scan_range_tasks = FuturesUnordered::new();

View File

@ -8,7 +8,10 @@ use std::{
use color_eyre::{eyre::eyre, Report};
use itertools::Itertools;
use tokio::{sync::mpsc::Sender, task::JoinHandle};
use tokio::{
sync::{mpsc::Sender, watch},
task::JoinHandle,
};
use tower::{buffer::Buffer, util::BoxService, Service, ServiceExt};
use tracing::Instrument;
@ -116,10 +119,10 @@ pub async fn start(
let mut subscribed_keys: HashMap<SaplingScanningKey, Sender<ScanResult>> = HashMap::new();
let (subscribed_keys_sender, subscribed_keys_receiver) =
tokio::sync::watch::channel(subscribed_keys.clone());
tokio::sync::watch::channel(Arc::new(subscribed_keys.clone()));
let (scan_task_sender, scan_task_executor_handle) =
executor::spawn_init(subscribed_keys_receiver);
executor::spawn_init(subscribed_keys_receiver.clone());
let mut scan_task_executor_handle = Some(scan_task_executor_handle);
// Give empty states time to verify some blocks before we start scanning.
@ -139,17 +142,23 @@ pub async fn start(
let was_parsed_keys_empty = parsed_keys.is_empty();
let (new_keys, new_result_senders) =
let (new_keys, new_result_senders, new_result_receivers) =
ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?;
subscribed_keys.extend(new_result_senders);
// Drop any results senders that are closed from subscribed_keys
subscribed_keys.retain(|key, sender| !sender.is_closed() && parsed_keys.contains_key(key));
// Send the latest version of `subscribed_keys` before spawning the scan range task
if !new_result_senders.is_empty() {
subscribed_keys.extend(new_result_senders);
// Ignore send errors, it's okay if there aren't any receivers.
let _ = subscribed_keys_sender.send(subscribed_keys.clone());
subscribed_keys_sender
.send(Arc::new(subscribed_keys.clone()))
.expect("last receiver should not be dropped while this task is running");
for (result_receiver, rsp_tx) in new_result_receivers {
// Ignore send errors, we drop any closed results channels above.
let _ = rsp_tx.send(result_receiver);
}
// TODO: Check if the `start_height` is at or above the current height
if !new_keys.is_empty() {
let state = state.clone();
let storage = storage.clone();
@ -163,7 +172,9 @@ pub async fn start(
if was_parsed_keys_empty {
info!(?start_height, "setting new start height");
height = start_height;
} else if start_height < height {
}
// Skip spawning ScanRange task if `start_height` is at or above the current height
else if start_height < height {
scan_task_sender
.send(ScanRangeTaskBuilder::new(height, new_keys, state, storage))
.await
@ -179,7 +190,7 @@ pub async fn start(
storage.clone(),
key_heights.clone(),
parsed_keys.clone(),
subscribed_keys.clone(),
subscribed_keys_receiver.clone(),
)
.await?;
@ -241,7 +252,7 @@ pub async fn scan_height_and_store_results(
storage: Storage,
key_last_scanned_heights: Arc<HashMap<SaplingScanningKey, Height>>,
parsed_keys: HashMap<SaplingScanningKey, (Vec<DiversifiableFullViewingKey>, Vec<SaplingIvk>)>,
subscribed_keys: HashMap<SaplingScanningKey, Sender<ScanResult>>,
subscribed_keys_receiver: watch::Receiver<Arc<HashMap<String, Sender<ScanResult>>>>,
) -> Result<Option<Height>, Report> {
let network = storage.network();
@ -295,7 +306,7 @@ pub async fn scan_height_and_store_results(
_other => {}
};
let results_sender = subscribed_keys.get(&sapling_key).cloned();
let subscribed_keys_receiver = subscribed_keys_receiver.clone();
let sapling_key = sapling_key.clone();
let block = block.clone();
@ -322,7 +333,8 @@ pub async fn scan_height_and_store_results(
let dfvk_res = scanned_block_to_db_result(dfvk_res);
let ivk_res = scanned_block_to_db_result(ivk_res);
if let Some(results_sender) = results_sender {
let latest_subscribed_keys = subscribed_keys_receiver.borrow().clone();
if let Some(results_sender) = latest_subscribed_keys.get(&sapling_key).cloned() {
let results = dfvk_res.iter().chain(ivk_res.iter());
for (_tx_index, &tx_id) in results {

View File

@ -56,7 +56,7 @@ impl ScanRangeTaskBuilder {
// TODO: return a tuple with a shutdown sender
pub fn spawn(
self,
subscribed_keys_receiver: watch::Receiver<HashMap<String, Sender<ScanResult>>>,
subscribed_keys_receiver: watch::Receiver<Arc<HashMap<String, Sender<ScanResult>>>>,
) -> JoinHandle<Result<(), Report>> {
let Self {
height_range,
@ -86,7 +86,7 @@ pub async fn scan_range(
keys: HashMap<SaplingScanningKey, (Vec<DiversifiableFullViewingKey>, Vec<SaplingIvk>, Height)>,
state: State,
storage: Storage,
subscribed_keys_receiver: watch::Receiver<HashMap<String, Sender<ScanResult>>>,
subscribed_keys_receiver: watch::Receiver<Arc<HashMap<String, Sender<ScanResult>>>>,
) -> Result<(), Report> {
let sapling_activation_height = storage.network().sapling_activation_height();
// Do not scan and notify if we are below sapling activation height.
@ -116,7 +116,6 @@ pub async fn scan_range(
.collect();
while height < stop_before_height {
let subscribed_keys = subscribed_keys_receiver.borrow().clone();
let scanned_height = scan_height_and_store_results(
height,
state.clone(),
@ -124,7 +123,7 @@ pub async fn scan_range(
storage.clone(),
key_heights.clone(),
parsed_keys.clone(),
subscribed_keys,
subscribed_keys_receiver.clone(),
)
.await?;

View File

@ -1,6 +1,9 @@
//! Fixed test vectors for the scan task.
use std::collections::{HashMap, HashSet};
use std::{
collections::{HashMap, HashSet},
time::Duration,
};
use color_eyre::Report;
@ -24,7 +27,7 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> {
sapling_keys.into_iter().zip((0..).map(Some)).collect();
mock_scan_task.register_keys(sapling_keys_with_birth_heights.clone())?;
let (new_keys, _new_results_senders) =
let (new_keys, _new_results_senders, _new_results_receivers) =
ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?;
// Check that it updated parsed_keys correctly and returned the right new keys when starting with an empty state
@ -45,7 +48,7 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> {
// Check that no key should be added if they are all already known and the heights are the same
let (new_keys, _new_results_senders) =
let (new_keys, _new_results_senders, _new_results_receivers) =
ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?;
assert_eq!(
@ -71,7 +74,7 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> {
mock_scan_task.register_keys(sapling_keys_with_birth_heights[10..20].to_vec())?;
mock_scan_task.register_keys(sapling_keys_with_birth_heights[10..15].to_vec())?;
let (new_keys, _new_results_senders) =
let (new_keys, _new_results_senders, _new_results_receivers) =
ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?;
assert_eq!(
@ -91,7 +94,7 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> {
let sapling_keys = mock_sapling_scanning_keys(30, network);
let done_rx = mock_scan_task.remove_keys(&sapling_keys)?;
let (new_keys, _new_results_senders) =
let (new_keys, _new_results_senders, _new_results_receivers) =
ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?;
// Check that it sends the done notification successfully before returning and dropping `done_tx`
@ -110,7 +113,7 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> {
mock_scan_task.remove_keys(&sapling_keys)?;
let (new_keys, _new_results_senders) =
let (new_keys, _new_results_senders, _new_results_receivers) =
ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?;
assert!(
@ -126,7 +129,7 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> {
mock_scan_task.register_keys(sapling_keys_with_birth_heights[..2].to_vec())?;
let (new_keys, _new_results_senders) =
let (new_keys, _new_results_senders, _new_results_receivers) =
ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?;
assert_eq!(
@ -142,11 +145,31 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> {
);
let subscribe_keys: HashSet<String> = sapling_keys[..5].iter().cloned().collect();
let mut result_receiver = mock_scan_task.subscribe(subscribe_keys.clone())?;
let result_receiver_fut = {
let mut mock_scan_task = mock_scan_task.clone();
tokio::spawn(async move { mock_scan_task.subscribe(subscribe_keys.clone()).await })
};
let (_new_keys, new_results_senders) =
// Wait for spawned task to send subscribe message
tokio::time::sleep(Duration::from_secs(1)).await;
let (_new_keys, new_results_senders, new_results_receivers) =
ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?;
let (result_receiver, rsp_tx) = new_results_receivers
.into_iter()
.next()
.expect("there should be a new result receiver");
rsp_tx
.send(result_receiver)
.expect("should send response successfully");
let mut result_receiver = result_receiver_fut
.await
.expect("tokio task should join successfully")
.expect("should send and receive message successfully");
let processed_subscribe_keys: HashSet<String> = new_results_senders.keys().cloned().collect();
let expected_new_subscribe_keys: HashSet<String> = sapling_keys[..2].iter().cloned().collect();

View File

@ -96,14 +96,15 @@ pub async fn scan_service_subscribes_to_results_correctly() -> Result<()> {
let expected_keys = keys.iter().cloned().collect();
let cmd_handler_fut = tokio::spawn(async move {
let Some(ScanTaskCommand::SubscribeResults {
result_sender: _,
keys,
}) = cmd_receiver.recv().await
let Some(ScanTaskCommand::SubscribeResults { rsp_tx, keys }) = cmd_receiver.recv().await
else {
panic!("should successfully receive SubscribeResults message");
};
let (_results_sender, results_receiver) = tokio::sync::mpsc::channel(1);
rsp_tx
.send(results_receiver)
.expect("should send response successfully");
assert_eq!(keys, expected_keys, "keys should match the request keys");
});

View File

@ -89,11 +89,14 @@ pub(crate) async fn run() -> Result<()> {
scan_task.register_keys(
keys.iter()
.cloned()
.map(|key| (key, Some(736000)))
.map(|key| (key, Some(780_000)))
.collect(),
)?;
let mut result_receiver = scan_task.subscribe(keys.into_iter().collect())?;
let mut result_receiver = scan_task
.subscribe(keys.into_iter().collect())
.await
.expect("should send and receive message successfully");
// Wait for the scanner to send a result in the channel
let result = tokio::time::timeout(WAIT_FOR_RESULTS_DURATION, result_receiver.recv()).await?;