add(scan): Implement SubscribeResults request for scan service (#8253)

* processes SubscribeResults messages

* send tx ids of results to the subscribe channel

* replaces BoxError with Report in scan_range

* adds a watch channel for using subscribed_keys in scan_range

* updates args to process_messages in test

* adds a `subscribe` method to ScanTask for sending a SubscribeResults cmd

* updates test for process_messages to cover subscribe cmds

* impls SubscribeResult service request and updates sender type

* adds test for SubscribeResults scan service request

* adds acceptance test

* updates tests and imports

* fixes acceptance test by using spawn_blocking to avoid blocking async executor and setting an appropriate start height

* fixes test

* Applies suggestions from code review.

* use tokio mpsc channel in scan task instead of std/blocking mpsc

* use tokio mpsc channel for results sender

* adds `was_parsed_keys_empty` instead of checking that all the parsed keys are new keys

* fixes test failures related to send errors in scan task

* returns height and key for scan results from subcribe_results results receiver

* hide scan_service mod in zebra-node-service behind feature
This commit is contained in:
Arya 2024-02-12 19:42:40 -05:00 committed by GitHub
parent c69befda2f
commit 3929a526e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 495 additions and 136 deletions

View File

@ -5829,6 +5829,7 @@ dependencies = [
"reqwest",
"serde",
"serde_json",
"tokio",
"zebra-chain",
]

View File

@ -25,7 +25,7 @@ color-eyre = "0.6.2"
zcash_primitives = { version = "0.13.0-rc.1" }
zebra-node-services = { path = "../zebra-node-services", version = "1.0.0-beta.34" }
zebra-node-services = { path = "../zebra-node-services", version = "1.0.0-beta.34", features = ["shielded-scan"] }
[build-dependencies]
tonic-build = "0.10.2"

View File

@ -34,6 +34,8 @@ rpc-client = [
"serde_json",
]
shielded-scan = ["tokio"]
[dependencies]
zebra-chain = { path = "../zebra-chain" , version = "1.0.0-beta.34" }
@ -46,6 +48,7 @@ jsonrpc-core = { version = "18.0.0", optional = true }
reqwest = { version = "0.11.24", default-features = false, features = ["rustls-tls"], optional = true }
serde = { version = "1.0.196", optional = true }
serde_json = { version = "1.0.113", optional = true }
tokio = { version = "1.36.0", features = ["time"], optional = true }
[dev-dependencies]

View File

@ -13,4 +13,5 @@ pub mod rpc_client;
/// parameterized by 'a), *not* that the object itself has 'static lifetime.
pub type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
#[cfg(feature = "shielded-scan")]
pub mod scan_service;

View File

@ -1,5 +1,7 @@
//! `zebra_scan::service::ScanService` request types.
use std::collections::HashSet;
use crate::BoxError;
/// The maximum number of keys that may be included in a request to the scan service
@ -23,8 +25,8 @@ pub enum Request {
/// Accept keys and return transaction data
Results(Vec<String>),
/// TODO: Accept `KeyHash`es and return a channel receiver
SubscribeResults(Vec<()>),
/// Accept keys and return a channel receiver for transaction data
SubscribeResults(HashSet<String>),
/// Clear the results for a set of viewing keys
ClearResults(Vec<String>),

View File

@ -1,11 +1,22 @@
//! `zebra_scan::service::ScanService` response types.
use std::{
collections::BTreeMap,
sync::{mpsc, Arc},
};
use std::collections::BTreeMap;
use zebra_chain::{block::Height, transaction::Hash};
use zebra_chain::{block::Height, transaction};
/// A relevant transaction for a key and the block height where it was found.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScanResult {
/// The key that successfully decrypts the transaction
pub key: String,
/// The height of the block with the transaction
pub height: Height,
/// A transaction ID, which uniquely identifies mined v5 transactions,
/// and all v1-v4 transactions.
pub tx_id: transaction::Hash,
}
#[derive(Debug)]
/// Response types for `zebra_scan::service::ScanService`
@ -24,7 +35,7 @@ pub enum Response {
/// Response to [`Results`](super::request::Request::Results) request
///
/// We use the nested `BTreeMap` so we don't repeat any piece of response data.
Results(BTreeMap<String, BTreeMap<Height, Vec<Hash>>>),
Results(BTreeMap<String, BTreeMap<Height, Vec<transaction::Hash>>>),
/// Response to [`DeleteKeys`](super::request::Request::DeleteKeys) request
DeletedKeys,
@ -32,6 +43,6 @@ pub enum Response {
/// Response to [`ClearResults`](super::request::Request::ClearResults) request
ClearedResults,
/// Response to `SubscribeResults` request
SubscribeResults(mpsc::Receiver<Arc<Hash>>),
/// Response to [`SubscribeResults`](super::request::Request::SubscribeResults) request
SubscribeResults(tokio::sync::mpsc::Receiver<ScanResult>),
}

View File

@ -56,7 +56,7 @@ zcash_primitives = "0.13.0-rc.1"
zebra-chain = { path = "../zebra-chain", version = "1.0.0-beta.34" }
zebra-state = { path = "../zebra-state", version = "1.0.0-beta.34", features = ["shielded-scan"] }
zebra-node-services = { path = "../zebra-node-services", version = "1.0.0-beta.33" }
zebra-node-services = { path = "../zebra-node-services", version = "1.0.0-beta.34", features = ["shielded-scan"] }
zebra-grpc = { path = "../zebra-grpc", version = "0.1.0-alpha.1" }
chrono = { version = "0.4.33", default-features = false, features = ["clock", "std", "serde"] }

View File

@ -57,7 +57,7 @@ pub fn spawn_init(
tokio::task::spawn_blocking(move || Storage::new(&config, network, false))
.wait_for_panics()
.await;
let (_cmd_sender, cmd_receiver) = std::sync::mpsc::channel();
let (_cmd_sender, cmd_receiver) = tokio::sync::mpsc::channel(1);
scan::start(state, chain_tip_change, storage, cmd_receiver).await
}
.in_current_span(),

View File

@ -19,7 +19,7 @@ pub mod scan_task;
pub use scan_task::{ScanTask, ScanTaskCommand};
#[cfg(any(test, feature = "proptest-impl"))]
use std::sync::mpsc::Receiver;
use tokio::sync::mpsc::Receiver;
/// Zebra-scan [`tower::Service`]
#[derive(Debug)]
@ -165,8 +165,15 @@ impl Service<Request> for ScanService {
.boxed();
}
Request::SubscribeResults(_key_hashes) => {
// TODO: send key_hashes and mpsc::Sender to scanner task, return mpsc::Receiver to caller
Request::SubscribeResults(keys) => {
let mut scan_task = self.scan_task.clone();
return async move {
let results_receiver = scan_task.subscribe(keys)?;
Ok(Response::SubscribeResults(results_receiver))
}
.boxed();
}
Request::ClearResults(keys) => {

View File

@ -1,6 +1,6 @@
//! Types and method implementations for [`ScanTask`]
use std::sync::{mpsc, Arc};
use std::sync::Arc;
use color_eyre::Report;
use tokio::task::JoinHandle;
@ -25,14 +25,17 @@ pub struct ScanTask {
pub handle: Arc<JoinHandle<Result<(), Report>>>,
/// Task command channel sender
pub cmd_sender: mpsc::Sender<ScanTaskCommand>,
pub cmd_sender: tokio::sync::mpsc::Sender<ScanTaskCommand>,
}
/// The size of the command channel buffer
const SCAN_TASK_BUFFER_SIZE: usize = 100;
impl ScanTask {
/// Spawns a new [`ScanTask`].
pub fn spawn(db: Storage, state: scan::State, chain_tip_change: ChainTipChange) -> Self {
// TODO: Use a bounded channel or move this logic to the scan service or another service.
let (cmd_sender, cmd_receiver) = mpsc::channel();
let (cmd_sender, cmd_receiver) = tokio::sync::mpsc::channel(SCAN_TASK_BUFFER_SIZE);
Self {
handle: Arc::new(scan::spawn_init(db, state, chain_tip_change, cmd_receiver)),

View File

@ -1,24 +1,24 @@
//! Types and method implementations for [`ScanTaskCommand`]
use std::{
collections::HashMap,
sync::{
mpsc::{self, Receiver, TryRecvError},
Arc,
},
};
use std::collections::{HashMap, HashSet};
use color_eyre::{eyre::eyre, Report};
use tokio::sync::oneshot;
use tokio::sync::{
mpsc::{error::TrySendError, Receiver, Sender},
oneshot,
};
use zcash_primitives::{sapling::SaplingIvk, zip32::DiversifiableFullViewingKey};
use zebra_chain::{block::Height, parameters::Network, transaction::Transaction};
use zebra_chain::{block::Height, parameters::Network};
use zebra_node_services::scan_service::response::ScanResult;
use zebra_state::SaplingScanningKey;
use crate::scan::sapling_key_to_scan_block_keys;
use super::ScanTask;
const RESULTS_SENDER_BUFFER_SIZE: usize = 100;
#[derive(Debug)]
/// Commands that can be sent to [`ScanTask`]
pub enum ScanTaskCommand {
@ -40,13 +40,12 @@ pub enum ScanTaskCommand {
},
/// Start sending results for key hashes to `result_sender`
// TODO: Implement this command (#8206)
SubscribeResults {
/// Sender for results
result_sender: mpsc::Sender<Arc<Transaction>>,
result_sender: Sender<ScanResult>,
/// Key hashes to send the results of to result channel
keys: Vec<String>,
keys: HashSet<String>,
},
}
@ -57,17 +56,26 @@ impl ScanTask {
///
/// Returns newly registered keys for scanning.
pub fn process_messages(
cmd_receiver: &Receiver<ScanTaskCommand>,
parsed_keys: &mut HashMap<
cmd_receiver: &mut tokio::sync::mpsc::Receiver<ScanTaskCommand>,
registered_keys: &mut HashMap<
SaplingScanningKey,
(Vec<DiversifiableFullViewingKey>, Vec<SaplingIvk>),
>,
network: Network,
) -> Result<
HashMap<SaplingScanningKey, (Vec<DiversifiableFullViewingKey>, Vec<SaplingIvk>, Height)>,
(
HashMap<
SaplingScanningKey,
(Vec<DiversifiableFullViewingKey>, Vec<SaplingIvk>, Height),
>,
HashMap<SaplingScanningKey, Sender<ScanResult>>,
),
Report,
> {
use tokio::sync::mpsc::error::TryRecvError;
let mut new_keys = HashMap::new();
let mut new_result_senders = HashMap::new();
let sapling_activation_height = network.sapling_activation_height();
loop {
@ -90,7 +98,9 @@ impl ScanTask {
// Don't accept keys that:
// 1. the scanner already has, and
// 2. were already submitted.
if parsed_keys.contains_key(&key.0) && !new_keys.contains_key(&key.0) {
if registered_keys.contains_key(&key.0)
&& !new_keys.contains_key(&key.0)
{
return None;
}
@ -116,7 +126,7 @@ impl ScanTask {
new_keys.extend(keys.clone());
parsed_keys.extend(
registered_keys.extend(
keys.into_iter()
.map(|(key, (dfvks, ivks, _))| (key, (dfvks, ivks))),
);
@ -124,7 +134,7 @@ impl ScanTask {
ScanTaskCommand::RemoveKeys { done_tx, keys } => {
for key in keys {
parsed_keys.remove(&key);
registered_keys.remove(&key);
new_keys.remove(&key);
}
@ -132,26 +142,39 @@ impl ScanTask {
let _ = done_tx.send(());
}
_ => continue,
ScanTaskCommand::SubscribeResults {
result_sender,
keys,
} => {
let keys = keys
.into_iter()
.filter(|key| registered_keys.contains_key(key));
for key in keys {
new_result_senders.insert(key, result_sender.clone());
}
}
}
}
Ok(new_keys)
Ok((new_keys, new_result_senders))
}
/// Sends a command to the scan task
pub fn send(
&mut self,
command: ScanTaskCommand,
) -> Result<(), mpsc::SendError<ScanTaskCommand>> {
self.cmd_sender.send(command)
) -> Result<(), tokio::sync::mpsc::error::TrySendError<ScanTaskCommand>> {
self.cmd_sender.try_send(command)
}
/// Sends a message to the scan task to remove the provided viewing keys.
///
/// Returns a oneshot channel receiver to notify the caller when the keys have been removed.
pub fn remove_keys(
&mut self,
keys: &[String],
) -> Result<oneshot::Receiver<()>, mpsc::SendError<ScanTaskCommand>> {
) -> Result<oneshot::Receiver<()>, TrySendError<ScanTaskCommand>> {
let (done_tx, done_rx) = oneshot::channel();
self.send(ScanTaskCommand::RemoveKeys {
@ -166,11 +189,29 @@ impl ScanTask {
pub fn register_keys(
&mut self,
keys: Vec<(String, Option<u32>)>,
) -> Result<oneshot::Receiver<Vec<String>>, mpsc::SendError<ScanTaskCommand>> {
) -> Result<oneshot::Receiver<Vec<String>>, TrySendError<ScanTaskCommand>> {
let (rsp_tx, rsp_rx) = oneshot::channel();
self.send(ScanTaskCommand::RegisterKeys { keys, rsp_tx })?;
Ok(rsp_rx)
}
/// Sends a message to the scan task to start sending the results for the provided viewing keys to a channel.
///
/// Returns the channel receiver.
pub fn subscribe(
&mut self,
keys: HashSet<SaplingScanningKey>,
) -> Result<Receiver<ScanResult>, TrySendError<ScanTaskCommand>> {
// TODO: Use a bounded channel
let (result_sender, result_receiver) =
tokio::sync::mpsc::channel(RESULTS_SENDER_BUFFER_SIZE);
self.send(ScanTaskCommand::SubscribeResults {
result_sender,
keys,
})
.map(|_| result_receiver)
}
}

View File

@ -1,45 +1,52 @@
//! The scan task executor
use std::collections::HashMap;
use color_eyre::eyre::Report;
use futures::{stream::FuturesUnordered, FutureExt, StreamExt};
use tokio::{
sync::mpsc::{Receiver, Sender},
sync::{
mpsc::{Receiver, Sender},
watch,
},
task::JoinHandle,
};
use tracing::Instrument;
use zebra_chain::BoxError;
use zebra_node_services::scan_service::response::ScanResult;
use super::scan::ScanRangeTaskBuilder;
const EXECUTOR_BUFFER_SIZE: usize = 100;
pub fn spawn_init() -> (
Sender<ScanRangeTaskBuilder>,
JoinHandle<Result<(), BoxError>>,
) {
// TODO: Use a bounded channel.
pub fn spawn_init(
subscribed_keys_receiver: tokio::sync::watch::Receiver<HashMap<String, Sender<ScanResult>>>,
) -> (Sender<ScanRangeTaskBuilder>, JoinHandle<Result<(), Report>>) {
let (scan_task_sender, scan_task_receiver) = tokio::sync::mpsc::channel(EXECUTOR_BUFFER_SIZE);
(
scan_task_sender,
tokio::spawn(scan_task_executor(scan_task_receiver).in_current_span()),
tokio::spawn(
scan_task_executor(scan_task_receiver, subscribed_keys_receiver).in_current_span(),
),
)
}
pub async fn scan_task_executor(
mut scan_task_receiver: Receiver<ScanRangeTaskBuilder>,
) -> Result<(), BoxError> {
subscribed_keys_receiver: watch::Receiver<HashMap<String, Sender<ScanResult>>>,
) -> Result<(), Report> {
let mut scan_range_tasks = FuturesUnordered::new();
// Push a pending future so that `.next()` will always return `Some`
scan_range_tasks.push(tokio::spawn(
std::future::pending::<Result<(), BoxError>>().boxed(),
std::future::pending::<Result<(), Report>>().boxed(),
));
loop {
tokio::select! {
Some(scan_range_task) = scan_task_receiver.recv() => {
// TODO: Add a long timeout?
scan_range_tasks.push(scan_range_task.spawn());
scan_range_tasks.push(scan_range_task.spawn(subscribed_keys_receiver.clone()));
}
Some(finished_task) = scan_range_tasks.next() => {

View File

@ -2,13 +2,13 @@
use std::{
collections::{BTreeMap, HashMap},
sync::{mpsc::Receiver, Arc},
sync::Arc,
time::Duration,
};
use color_eyre::{eyre::eyre, Report};
use itertools::Itertools;
use tokio::task::JoinHandle;
use tokio::{sync::mpsc::Sender, task::JoinHandle};
use tower::{buffer::Buffer, util::BoxService, Service, ServiceExt};
use tracing::Instrument;
@ -34,6 +34,7 @@ use zebra_chain::{
serialization::ZcashSerialize,
transaction::Transaction,
};
use zebra_node_services::scan_service::response::ScanResult;
use zebra_state::{ChainTipChange, SaplingScannedResult, TransactionIndex};
use crate::{
@ -72,11 +73,13 @@ pub async fn start(
state: State,
chain_tip_change: ChainTipChange,
storage: Storage,
cmd_receiver: Receiver<ScanTaskCommand>,
mut cmd_receiver: tokio::sync::mpsc::Receiver<ScanTaskCommand>,
) -> Result<(), Report> {
let network = storage.network();
let sapling_activation_height = network.sapling_activation_height();
info!(?network, "starting scan task");
// Do not scan and notify if we are below sapling activation height.
wait_for_height(
sapling_activation_height,
@ -94,6 +97,8 @@ pub async fn start(
let mut height = get_min_height(&key_heights).unwrap_or(sapling_activation_height);
info!(start_height = ?height, "got min scan height");
// Parse and convert keys once, then use them to scan all blocks.
// There is some cryptography here, but it should be fast even with thousands of keys.
let mut parsed_keys: HashMap<
@ -107,7 +112,13 @@ pub async fn start(
})
.try_collect()?;
let (scan_task_sender, scan_task_executor_handle) = executor::spawn_init();
let mut subscribed_keys: HashMap<SaplingScanningKey, Sender<ScanResult>> = HashMap::new();
let (subscribed_keys_sender, subscribed_keys_receiver) =
tokio::sync::watch::channel(subscribed_keys.clone());
let (scan_task_sender, scan_task_executor_handle) =
executor::spawn_init(subscribed_keys_receiver);
let mut scan_task_executor_handle = Some(scan_task_executor_handle);
// Give empty states time to verify some blocks before we start scanning.
@ -125,31 +136,58 @@ pub async fn start(
}
}
let new_keys = ScanTask::process_messages(&cmd_receiver, &mut parsed_keys, network)?;
let was_parsed_keys_empty = parsed_keys.is_empty();
let (new_keys, new_result_senders) =
ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?;
// Send the latest version of `subscribed_keys` before spawning the scan range task
if !new_result_senders.is_empty() {
subscribed_keys.extend(new_result_senders);
// Ignore send errors, it's okay if there aren't any receivers.
let _ = subscribed_keys_sender.send(subscribed_keys.clone());
}
// TODO: Check if the `start_height` is at or above the current height
if !new_keys.is_empty() {
let state = state.clone();
let storage = storage.clone();
scan_task_sender
.send(ScanRangeTaskBuilder::new(height, new_keys, state, storage))
.await
.expect("scan_until_task channel should not be closed");
let start_height = new_keys
.iter()
.map(|(_, (_, _, height))| *height)
.min()
.unwrap_or(sapling_activation_height);
if was_parsed_keys_empty {
info!(?start_height, "setting new start height");
height = start_height;
} else if start_height < height {
scan_task_sender
.send(ScanRangeTaskBuilder::new(height, new_keys, state, storage))
.await
.expect("scan_until_task channel should not be closed");
}
}
let scanned_height = scan_height_and_store_results(
height,
state.clone(),
Some(chain_tip_change.clone()),
storage.clone(),
key_heights.clone(),
parsed_keys.clone(),
)
.await?;
if !parsed_keys.is_empty() {
let scanned_height = scan_height_and_store_results(
height,
state.clone(),
Some(chain_tip_change.clone()),
storage.clone(),
key_heights.clone(),
parsed_keys.clone(),
subscribed_keys.clone(),
)
.await?;
// If we've reached the tip, sleep for a while then try and get the same block.
if scanned_height.is_none() {
// If we've reached the tip, sleep for a while then try and get the same block.
if scanned_height.is_none() {
tokio::time::sleep(CHECK_INTERVAL).await;
continue;
}
} else {
tokio::time::sleep(CHECK_INTERVAL).await;
continue;
}
@ -173,10 +211,16 @@ pub async fn wait_for_height(
"scanner is waiting for {height_name}. Current tip: {}, {height_name}: {}",
tip_height.0, height.0
);
tokio::time::sleep(CHECK_INTERVAL).await;
continue;
} else {
info!(
"scanner finished waiting for {height_name}. Current tip: {}, {height_name}: {}",
tip_height.0, height.0
);
break;
}
break;
}
Ok(())
@ -196,6 +240,7 @@ pub async fn scan_height_and_store_results(
storage: Storage,
key_last_scanned_heights: Arc<HashMap<SaplingScanningKey, Height>>,
parsed_keys: HashMap<SaplingScanningKey, (Vec<DiversifiableFullViewingKey>, Vec<SaplingIvk>)>,
subscribed_keys: HashMap<SaplingScanningKey, Sender<ScanResult>>,
) -> Result<Option<Height>, Report> {
let network = storage.network();
@ -237,12 +282,20 @@ pub async fn scan_height_and_store_results(
height.as_usize(),
chain_tip_change.latest_chain_tip().best_tip_height().expect("we should have a tip to scan").as_usize(),
);
} else {
info!(
"Scanning the blockchain for key {}, started at block {:?}, now at block {:?}",
key_index_in_task, last_scanned_height.next().expect("height is not maximum").as_usize(),
height.as_usize(),
);
}
}
_other => {}
};
let results_sender = subscribed_keys.get(&sapling_key).cloned();
let sapling_key = sapling_key.clone();
let block = block.clone();
let mut storage = storage.clone();
@ -268,6 +321,19 @@ pub async fn scan_height_and_store_results(
let dfvk_res = scanned_block_to_db_result(dfvk_res);
let ivk_res = scanned_block_to_db_result(ivk_res);
if let Some(results_sender) = results_sender {
let results = dfvk_res.iter().chain(ivk_res.iter());
for (_tx_index, &tx_id) in results {
// TODO: Handle `SendErrors` by dropping sender from `subscribed_keys`
let _ = results_sender.try_send(ScanResult {
key: sapling_key.clone(),
height,
tx_id: tx_id.into(),
});
}
}
storage.add_sapling_results(&sapling_key, height, dfvk_res);
storage.add_sapling_results(&sapling_key, height, ivk_res);
@ -491,7 +557,7 @@ pub fn spawn_init(
storage: Storage,
state: State,
chain_tip_change: ChainTipChange,
cmd_receiver: Receiver<ScanTaskCommand>,
cmd_receiver: tokio::sync::mpsc::Receiver<ScanTaskCommand>,
) -> JoinHandle<Result<(), Report>> {
tokio::spawn(start(state, chain_tip_change, storage, cmd_receiver).in_current_span())
}

View File

@ -2,16 +2,20 @@
use std::{collections::HashMap, sync::Arc};
use tokio::task::JoinHandle;
use tracing::Instrument;
use zcash_primitives::{sapling::SaplingIvk, zip32::DiversifiableFullViewingKey};
use zebra_chain::{block::Height, BoxError};
use zebra_state::SaplingScanningKey;
use crate::{
scan::{scan_height_and_store_results, wait_for_height, State, CHECK_INTERVAL},
scan::{get_min_height, scan_height_and_store_results, wait_for_height, State, CHECK_INTERVAL},
storage::Storage,
};
use color_eyre::eyre::Report;
use tokio::{
sync::{mpsc::Sender, watch},
task::JoinHandle,
};
use tracing::Instrument;
use zcash_primitives::{sapling::SaplingIvk, zip32::DiversifiableFullViewingKey};
use zebra_chain::block::Height;
use zebra_node_services::scan_service::response::ScanResult;
use zebra_state::SaplingScanningKey;
/// A builder for a scan until task
pub struct ScanRangeTaskBuilder {
@ -50,7 +54,10 @@ impl ScanRangeTaskBuilder {
/// Spawns a `scan_range()` task and returns its [`JoinHandle`]
// TODO: return a tuple with a shutdown sender
pub fn spawn(self) -> JoinHandle<Result<(), BoxError>> {
pub fn spawn(
self,
subscribed_keys_receiver: watch::Receiver<HashMap<String, Sender<ScanResult>>>,
) -> JoinHandle<Result<(), Report>> {
let Self {
height_range,
keys,
@ -58,7 +65,16 @@ impl ScanRangeTaskBuilder {
storage,
} = self;
tokio::spawn(scan_range(height_range.end, keys, state, storage).in_current_span())
tokio::spawn(
scan_range(
height_range.end,
keys,
state,
storage,
subscribed_keys_receiver,
)
.in_current_span(),
)
}
}
@ -70,7 +86,8 @@ pub async fn scan_range(
keys: HashMap<SaplingScanningKey, (Vec<DiversifiableFullViewingKey>, Vec<SaplingIvk>, Height)>,
state: State,
storage: Storage,
) -> Result<(), BoxError> {
subscribed_keys_receiver: watch::Receiver<HashMap<String, Sender<ScanResult>>>,
) -> Result<(), Report> {
let sapling_activation_height = storage.network().sapling_activation_height();
// Do not scan and notify if we are below sapling activation height.
wait_for_height(
@ -84,13 +101,10 @@ pub async fn scan_range(
.iter()
.map(|(key, (_, _, height))| (key.clone(), *height))
.collect();
let key_heights = Arc::new(key_heights);
let mut height = key_heights
.values()
.cloned()
.min()
.unwrap_or(sapling_activation_height);
let mut height = get_min_height(&key_heights).unwrap_or(sapling_activation_height);
let key_heights = Arc::new(key_heights);
// Parse and convert keys once, then use them to scan all blocks.
let parsed_keys: HashMap<
@ -102,6 +116,7 @@ pub async fn scan_range(
.collect();
while height < stop_before_height {
let subscribed_keys = subscribed_keys_receiver.borrow().clone();
let scanned_height = scan_height_and_store_results(
height,
state.clone(),
@ -109,6 +124,7 @@ pub async fn scan_range(
storage.clone(),
key_heights.clone(),
parsed_keys.clone(),
subscribed_keys,
)
.await?;
@ -123,5 +139,11 @@ pub async fn scan_range(
.expect("a valid blockchain never reaches the max height");
}
info!(
start_height = ?height,
?stop_before_height,
"finished scanning range"
);
Ok(())
}

View File

@ -1,19 +1,16 @@
//! Tests for the scan task.
use std::sync::{
mpsc::{self, Receiver},
Arc,
};
use std::sync::Arc;
use super::{ScanTask, ScanTaskCommand};
use super::{ScanTask, ScanTaskCommand, SCAN_TASK_BUFFER_SIZE};
#[cfg(test)]
mod vectors;
impl ScanTask {
/// Spawns a new [`ScanTask`] for tests.
pub fn mock() -> (Self, Receiver<ScanTaskCommand>) {
let (cmd_sender, cmd_receiver) = mpsc::channel();
pub fn mock() -> (Self, tokio::sync::mpsc::Receiver<ScanTaskCommand>) {
let (cmd_sender, cmd_receiver) = tokio::sync::mpsc::channel(SCAN_TASK_BUFFER_SIZE);
(
Self {

View File

@ -1,15 +1,18 @@
//! Fixed test vectors for the scan task.
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use color_eyre::Report;
use zebra_chain::{block::Height, transaction};
use zebra_node_services::scan_service::response::ScanResult;
use crate::{service::ScanTask, tests::mock_sapling_scanning_keys};
/// Test that [`ScanTask::process_messages`] adds and removes keys as expected for `RegisterKeys` and `DeleteKeys` command
#[tokio::test]
async fn scan_task_processes_messages_correctly() -> Result<(), Report> {
let (mut mock_scan_task, cmd_receiver) = ScanTask::mock();
let (mut mock_scan_task, mut cmd_receiver) = ScanTask::mock();
let mut parsed_keys = HashMap::new();
let network = Default::default();
@ -20,7 +23,8 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> {
sapling_keys.into_iter().zip((0..).map(Some)).collect();
mock_scan_task.register_keys(sapling_keys_with_birth_heights.clone())?;
let new_keys = ScanTask::process_messages(&cmd_receiver, &mut parsed_keys, network)?;
let (new_keys, _new_results_senders) =
ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?;
// Check that it updated parsed_keys correctly and returned the right new keys when starting with an empty state
@ -40,7 +44,8 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> {
// Check that no key should be added if they are all already known and the heights are the same
let new_keys = ScanTask::process_messages(&cmd_receiver, &mut parsed_keys, network)?;
let (new_keys, _new_results_senders) =
ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?;
assert_eq!(
parsed_keys.len(),
@ -65,7 +70,8 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> {
mock_scan_task.register_keys(sapling_keys_with_birth_heights[10..20].to_vec())?;
mock_scan_task.register_keys(sapling_keys_with_birth_heights[10..15].to_vec())?;
let new_keys = ScanTask::process_messages(&cmd_receiver, &mut parsed_keys, network)?;
let (new_keys, _new_results_senders) =
ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?;
assert_eq!(
parsed_keys.len(),
@ -82,10 +88,10 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> {
// Check that it removes keys correctly
let sapling_keys = mock_sapling_scanning_keys(30);
let done_rx = mock_scan_task.remove_keys(&sapling_keys)?;
let new_keys = ScanTask::process_messages(&cmd_receiver, &mut parsed_keys, network)?;
let (new_keys, _new_results_senders) =
ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?;
// Check that it sends the done notification successfully before returning and dropping `done_tx`
done_rx.await?;
@ -103,7 +109,8 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> {
mock_scan_task.remove_keys(&sapling_keys)?;
let new_keys = ScanTask::process_messages(&cmd_receiver, &mut parsed_keys, network)?;
let (new_keys, _new_results_senders) =
ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?;
assert!(
new_keys.is_empty(),
@ -118,7 +125,8 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> {
mock_scan_task.register_keys(sapling_keys_with_birth_heights[..2].to_vec())?;
let new_keys = ScanTask::process_messages(&cmd_receiver, &mut parsed_keys, network)?;
let (new_keys, _new_results_senders) =
ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?;
assert_eq!(
new_keys.len(),
@ -132,5 +140,42 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> {
"should add 2 keys to parsed_keys after removals"
);
let subscribe_keys: HashSet<String> = sapling_keys[..5].iter().cloned().collect();
let mut result_receiver = mock_scan_task.subscribe(subscribe_keys.clone())?;
let (_new_keys, new_results_senders) =
ScanTask::process_messages(&mut cmd_receiver, &mut parsed_keys, network)?;
let processed_subscribe_keys: HashSet<String> = new_results_senders.keys().cloned().collect();
let expected_new_subscribe_keys: HashSet<String> = sapling_keys[..2].iter().cloned().collect();
assert_eq!(
processed_subscribe_keys, expected_new_subscribe_keys,
"should return new result senders for registered keys"
);
for sender in new_results_senders.values() {
// send a fake tx id for each key
sender
.send(ScanResult {
key: String::new(),
height: Height::MIN,
tx_id: transaction::Hash([0; 32]),
})
.await?;
}
let mut num_results = 0;
while result_receiver.try_recv().is_ok() {
num_results += 1;
}
assert_eq!(
num_results,
expected_new_subscribe_keys.len(),
"there should be a fake result sent for each subscribed key"
);
Ok(())
}

View File

@ -1,5 +1,6 @@
//! Tests for ScanService.
use tokio::sync::mpsc::error::TryRecvError;
use tower::{Service, ServiceExt};
use color_eyre::{eyre::eyre, Result};
@ -38,7 +39,7 @@ pub async fn scan_service_deletes_keys_correctly() -> Result<()> {
"there should be some results for this key in the db"
);
let (mut scan_service, cmd_receiver) = ScanService::new_with_mock_scanner(db);
let (mut scan_service, mut cmd_receiver) = ScanService::new_with_mock_scanner(db);
let response_fut = scan_service
.ready()
@ -47,8 +48,8 @@ pub async fn scan_service_deletes_keys_correctly() -> Result<()> {
.call(Request::DeleteKeys(vec![zec_pages_sapling_efvk.clone()]));
let expected_keys = vec![zec_pages_sapling_efvk.clone()];
let cmd_handler_fut = tokio::task::spawn_blocking(move || {
let Ok(ScanTaskCommand::RemoveKeys { done_tx, keys }) = cmd_receiver.recv() else {
let cmd_handler_fut = tokio::spawn(async move {
let Some(ScanTaskCommand::RemoveKeys { done_tx, keys }) = cmd_receiver.recv().await else {
panic!("should successfully receive RemoveKeys message");
};
@ -77,6 +78,52 @@ pub async fn scan_service_deletes_keys_correctly() -> Result<()> {
Ok(())
}
/// Tests that keys are deleted correctly
#[tokio::test]
pub async fn scan_service_subscribes_to_results_correctly() -> Result<()> {
let db = new_test_storage(Network::Mainnet);
let (mut scan_service, mut cmd_receiver) = ScanService::new_with_mock_scanner(db);
let keys = [String::from("fake key")];
let response_fut = scan_service
.ready()
.await
.map_err(|err| eyre!(err))?
.call(Request::SubscribeResults(keys.iter().cloned().collect()));
let expected_keys = keys.iter().cloned().collect();
let cmd_handler_fut = tokio::spawn(async move {
let Some(ScanTaskCommand::SubscribeResults {
result_sender: _,
keys,
}) = cmd_receiver.recv().await
else {
panic!("should successfully receive SubscribeResults message");
};
assert_eq!(keys, expected_keys, "keys should match the request keys");
});
// Poll futures
let (response, join_result) = tokio::join!(response_fut, cmd_handler_fut);
join_result?;
let mut results_receiver = match response.map_err(|err| eyre!(err))? {
Response::SubscribeResults(results_receiver) => results_receiver,
_ => panic!("scan service returned unexpected response variant"),
};
assert_eq!(
results_receiver.try_recv(),
Err(TryRecvError::Disconnected),
"channel with no items and dropped sender should be closed"
);
Ok(())
}
/// Tests that results are cleared are deleted correctly
#[tokio::test]
pub async fn scan_service_clears_results_correctly() -> Result<()> {

View File

@ -128,6 +128,12 @@
//! ZEBRA_CACHED_STATE_DIR=/path/to/zebra/state cargo test scans_for_new_key --features shielded-scan --release -- --ignored --nocapture
//! ```
//!
//! Example of how to run the scan_subscribe_results test:
//!
//! ```console
//! ZEBRA_CACHED_STATE_DIR=/path/to/zebra/state cargo test scan_subscribe_results --features shielded-scan -- --ignored --nocapture
//! ```
//!
//! ## Checkpoint Generation Tests
//!
//! Generate checkpoints on mainnet and testnet using a cached state:
@ -3011,7 +3017,7 @@ fn scan_start_where_left() -> Result<()> {
/// Test successful registration of a new key in the scan task.
///
/// See [`common::shielded_scan::register_key`] for more information.
/// See [`common::shielded_scan::scans_for_new_key`] for more information.
// TODO: Add this test to CI (#8236)
#[tokio::test]
#[ignore]
@ -3019,3 +3025,14 @@ fn scan_start_where_left() -> Result<()> {
async fn scans_for_new_key() -> Result<()> {
common::shielded_scan::scans_for_new_key::run().await
}
/// Tests SubscribeResults ScanService request.
///
/// See [`common::shielded_scan::subscribe_results`] for more information.
// TODO: Add this test to CI (#8236)
#[tokio::test]
#[ignore]
#[cfg(feature = "shielded-scan")]
async fn scan_subscribe_results() -> Result<()> {
common::shielded_scan::subscribe_results::run().await
}

View File

@ -1,3 +1,4 @@
//! Acceptance tests for `shielded-scan`` feature in zebrad.
pub(crate) mod scans_for_new_key;
pub(crate) mod subscribe_results;

View File

@ -4,9 +4,9 @@
//! Sapling activation height and [`REQUIRED_MIN_TIP_HEIGHT`]
//!
//! export ZEBRA_CACHED_STATE_DIR="/path/to/zebra/state"
//! cargo test scans_for_new_key --features="shielded-scan" -- --ignored --nocapture
//! cargo test scans_for_new_key --release --features="shielded-scan" -- --ignored --nocapture
use std::{collections::HashMap, time::Duration};
use std::time::Duration;
use color_eyre::{eyre::eyre, Result};
@ -16,11 +16,7 @@ use zebra_chain::{
chain_tip::ChainTip,
parameters::{Network, NetworkUpgrade},
};
use zebra_scan::{
scan::sapling_key_to_scan_block_keys, service::ScanTask, storage::Storage,
tests::ZECPAGES_SAPLING_VIEWING_KEY, DiversifiableFullViewingKey, SaplingIvk,
};
use zebra_state::SaplingScanningKey;
use zebra_scan::{service::ScanTask, storage::Storage, tests::ZECPAGES_SAPLING_VIEWING_KEY};
use crate::common::{
cached_state::start_state_service_with_cache_dir, launch::can_spawn_zebrad_for_test_type,
@ -31,7 +27,7 @@ use crate::common::{
const REQUIRED_MIN_TIP_HEIGHT: Height = Height(1_000_000);
/// How long this test waits after registering keys to check if there are any results.
const WAIT_FOR_RESULTS_DURATION: Duration = Duration::from_secs(10 * 60);
const WAIT_FOR_RESULTS_DURATION: Duration = Duration::from_secs(60);
/// Initialize Zebra's state service with a cached state, add a new key to the scan task, and
/// check that it stores results for the new key without errors.
@ -90,25 +86,11 @@ pub(crate) async fn run() -> Result<()> {
let mut scan_task = ScanTask::spawn(storage, state, chain_tip_change);
let (zecpages_dfvks, zecpages_ivks) =
sapling_key_to_scan_block_keys(&ZECPAGES_SAPLING_VIEWING_KEY.to_string(), network)?;
let mut parsed_keys: HashMap<
SaplingScanningKey,
(Vec<DiversifiableFullViewingKey>, Vec<SaplingIvk>, Height),
> = HashMap::new();
parsed_keys.insert(
ZECPAGES_SAPLING_VIEWING_KEY.to_string(),
(zecpages_dfvks, zecpages_ivks, Height::MIN),
);
tracing::info!("started scan task, sending register keys message with zecpages key to start scanning for a new key",);
scan_task.register_keys(
parsed_keys
[(ZECPAGES_SAPLING_VIEWING_KEY.to_string(), None)]
.into_iter()
.map(|(key, (_, _, Height(h)))| (key, Some(h)))
.collect(),
)?;
@ -126,6 +108,8 @@ pub(crate) async fn run() -> Result<()> {
let results = storage.sapling_results(&ZECPAGES_SAPLING_VIEWING_KEY.to_string());
tracing::info!(?results, "got the results");
// Check that some results were added for the zecpages key that was not in the config or the db when ScanTask started.
assert!(
!results.is_empty(),

View File

@ -0,0 +1,104 @@
//! Test registering and subscribing to the results for a new key in the scan task while zebrad is running.
//!
//! This test requires a cached chain state that is partially synchronized past the
//! Sapling activation height and [`REQUIRED_MIN_TIP_HEIGHT`]
//!
//! export ZEBRA_CACHED_STATE_DIR="/path/to/zebra/state"
//! cargo test scan_subscribe_results --features="shielded-scan" -- --ignored --nocapture
use std::time::Duration;
use color_eyre::{eyre::eyre, Result};
use tower::ServiceBuilder;
use zebra_chain::{
block::Height,
chain_tip::ChainTip,
parameters::{Network, NetworkUpgrade},
};
use zebra_scan::{service::ScanTask, storage::Storage, tests::ZECPAGES_SAPLING_VIEWING_KEY};
use crate::common::{
cached_state::start_state_service_with_cache_dir, launch::can_spawn_zebrad_for_test_type,
test_type::TestType,
};
/// The minimum required tip height for the cached state in this test.
const REQUIRED_MIN_TIP_HEIGHT: Height = Height(1_000_000);
/// How long this test waits for a result before failing.
const WAIT_FOR_RESULTS_DURATION: Duration = Duration::from_secs(30 * 60);
/// Initialize Zebra's state service with a cached state, add a new key to the scan task, and
/// check that it stores results for the new key without errors.
pub(crate) async fn run() -> Result<()> {
let _init_guard = zebra_test::init();
let test_type = TestType::UpdateZebraCachedStateNoRpc;
let test_name = "scan_subscribe_results";
let network = Network::Mainnet;
// Skip the test unless the user specifically asked for it and there is a zebrad_state_path
if !can_spawn_zebrad_for_test_type(test_name, test_type, true) {
return Ok(());
}
tracing::info!(
?network,
?test_type,
"running scan_subscribe_results test using zebra state service",
);
let zebrad_state_path = test_type
.zebrad_state_path(test_name)
.expect("already checked that there is a cached state path");
let (state_service, _read_state_service, latest_chain_tip, chain_tip_change) =
start_state_service_with_cache_dir(network, zebrad_state_path).await?;
let chain_tip_height = latest_chain_tip
.best_tip_height()
.ok_or_else(|| eyre!("State directory doesn't have a chain tip block"))?;
let sapling_activation_height = NetworkUpgrade::Sapling
.activation_height(network)
.expect("there should be an activation height for Mainnet");
assert!(
sapling_activation_height < REQUIRED_MIN_TIP_HEIGHT,
"minimum tip height should be above sapling activation height"
);
assert!(
REQUIRED_MIN_TIP_HEIGHT < chain_tip_height,
"chain tip height must be above required minimum tip height"
);
tracing::info!("opened state service with valid chain tip height, starting scan task",);
let state = ServiceBuilder::new().buffer(10).service(state_service);
// Create an ephemeral `Storage` instance
let storage = Storage::new(&zebra_scan::Config::ephemeral(), network, false);
let mut scan_task = ScanTask::spawn(storage, state, chain_tip_change);
tracing::info!("started scan task, sending register/subscribe keys messages with zecpages key to start scanning for a new key",);
let keys = [ZECPAGES_SAPLING_VIEWING_KEY.to_string()];
scan_task.register_keys(
keys.iter()
.cloned()
.map(|key| (key, Some(736000)))
.collect(),
)?;
let mut result_receiver = scan_task.subscribe(keys.into_iter().collect())?;
// Wait for the scanner to send a result in the channel
let result = tokio::time::timeout(WAIT_FOR_RESULTS_DURATION, result_receiver.recv()).await?;
tracing::info!(?result, "received a result from the channel");
Ok(())
}