diff --git a/zebra-grpc/src/server.rs b/zebra-grpc/src/server.rs index 383a71fdc..b62cdb1c9 100644 --- a/zebra-grpc/src/server.rs +++ b/zebra-grpc/src/server.rs @@ -5,7 +5,7 @@ use std::{collections::BTreeMap, net::SocketAddr, pin::Pin}; use futures_util::future::TryFutureExt; use tokio_stream::{wrappers::ReceiverStream, Stream}; use tonic::{transport::Server, Request, Response, Status}; -use tower::ServiceExt; +use tower::{timeout::error::Elapsed, ServiceExt}; use zebra_chain::{block::Height, transaction}; use zebra_node_services::scan_service::{ @@ -70,34 +70,57 @@ where .into_iter() .map(|KeyWithHeight { key, height }| (key, height)) .collect(); + let register_keys_response_fut = self .scan_service .clone() - .oneshot(ScanServiceRequest::RegisterKeys(keys.clone())); + .ready() + .await + .map_err(|_| Status::unknown("service poll_ready() method returned an error"))? + .call(ScanServiceRequest::RegisterKeys(keys.clone())); let keys: Vec<_> = keys.into_iter().map(|(key, _start_at)| key).collect(); - let subscribe_results_response_fut = - self.scan_service - .clone() - .oneshot(ScanServiceRequest::SubscribeResults( - keys.iter().cloned().collect(), - )); + let subscribe_results_response_fut = self + .scan_service + .clone() + .ready() + .await + .map_err(|_| Status::unknown("service poll_ready() method returned an error"))? + .call(ScanServiceRequest::SubscribeResults( + keys.iter().cloned().collect(), + )); let (register_keys_response, subscribe_results_response) = tokio::join!(register_keys_response_fut, subscribe_results_response_fut); - let ScanServiceResponse::RegisteredKeys(_) = register_keys_response - .map_err(|err| Status::unknown(format!("scan service returned error: {err}")))? - else { - return Err(Status::unknown( - "scan service returned an unexpected response", - )); + // Ignores errors from the register keys request, we expect there to be a timeout if the keys + // are already registered, or an empty response if no new keys could be parsed as Sapling efvks. + // + // This method will still return an error if every key in the `scan` request is invalid, since + // the SubscribeResults request will return an error once the `rsp_tx` is dropped in `ScanTask::process_messages` + // when it finds that none of the keys in the request are registered. + let register_keys_err = match register_keys_response { + Ok(ScanServiceResponse::RegisteredKeys(_)) => None, + Ok(response) => { + return Err(Status::internal(format!( + "unexpected response from scan service: {response:?}" + ))) + } + Err(err) if err.downcast_ref::().is_some() => { + return Err(Status::deadline_exceeded( + "scan service requests timed out, is Zebra synced past Sapling activation height?") + ) + } + Err(err) => Some(err), }; let ScanServiceResponse::SubscribeResults(mut results_receiver) = - subscribe_results_response - .map_err(|err| Status::unknown(format!("scan service returned error: {err}")))? + subscribe_results_response.map_err(|err| { + register_keys_err + .map(|err| Status::invalid_argument(err.to_string())) + .unwrap_or(Status::internal(err.to_string())) + })? else { return Err(Status::unknown( "scan service returned an unexpected response", @@ -179,7 +202,7 @@ where .ready() .and_then(|service| service.call(ScanServiceRequest::Info)) .await - .map_err(|_| Status::unknown("scan service was unavailable"))? + .map_err(|err| Status::unknown(format!("scan service returned error: {err}")))? else { return Err(Status::unknown( "scan service returned an unexpected response", @@ -217,20 +240,30 @@ where return Err(Status::invalid_argument(msg)); } - let ScanServiceResponse::RegisteredKeys(keys) = self + match self .scan_service .clone() .ready() .and_then(|service| service.call(ScanServiceRequest::RegisterKeys(keys))) .await - .map_err(|_| Status::unknown("scan service was unavailable"))? - else { - return Err(Status::unknown( - "scan service returned an unexpected response", - )); - }; + { + Ok(ScanServiceResponse::RegisteredKeys(keys)) => { + Ok(Response::new(RegisterKeysResponse { keys })) + } - Ok(Response::new(RegisterKeysResponse { keys })) + Ok(response) => { + return Err(Status::internal(format!( + "unexpected response from scan service: {response:?}" + ))) + } + + Err(err) if err.downcast_ref::().is_some() => Err(Status::deadline_exceeded( + "RegisterKeys scan service request timed out, \ + is Zebra synced past Sapling activation height?", + )), + + Err(err) => Err(Status::unknown(err.to_string())), + } } async fn clear_results( diff --git a/zebra-node-services/src/scan_service/request.rs b/zebra-node-services/src/scan_service/request.rs index 1490501d2..e6a8e9c49 100644 --- a/zebra-node-services/src/scan_service/request.rs +++ b/zebra-node-services/src/scan_service/request.rs @@ -13,9 +13,6 @@ pub enum Request { /// Requests general info about the scanner Info, - /// TODO: Accept `KeyHash`es and return key hashes that are registered - CheckKeyHashes(Vec<()>), - /// Submits viewing keys with their optional birth-heights for scanning. RegisterKeys(Vec<(String, Option)>), diff --git a/zebra-scan/src/init.rs b/zebra-scan/src/init.rs index fe0016844..15ab2a9a9 100644 --- a/zebra-scan/src/init.rs +++ b/zebra-scan/src/init.rs @@ -1,6 +1,6 @@ //! Initializing the scanner and gRPC server. -use std::net::SocketAddr; +use std::{net::SocketAddr, time::Duration}; use color_eyre::Report; use tokio::task::JoinHandle; @@ -12,6 +12,9 @@ use zebra_state::ChainTipChange; use crate::{scan, service::ScanService, storage::Storage, Config}; +/// The timeout applied to scan service calls +pub const SCAN_SERVICE_TIMEOUT: Duration = Duration::from_secs(30); + /// Initialize [`ScanService`] based on its config. /// /// TODO: add a test for this function. @@ -25,6 +28,7 @@ pub async fn init_with_server( info!(?config, "starting scan service"); let scan_service = ServiceBuilder::new() .buffer(10) + .timeout(SCAN_SERVICE_TIMEOUT) .service(ScanService::new(&config, network, state, chain_tip_change).await); // TODO: move this to zebra-grpc init() function and include addr diff --git a/zebra-scan/src/service.rs b/zebra-scan/src/service.rs index 426ec94ba..0b9758b91 100644 --- a/zebra-scan/src/service.rs +++ b/zebra-scan/src/service.rs @@ -3,7 +3,7 @@ use std::{collections::BTreeMap, future::Future, pin::Pin, task::Poll, time::Duration}; use futures::future::FutureExt; -use tower::Service; +use tower::{BoxError, Service}; use zebra_chain::{diagnostic::task::WaitForPanics, parameters::Network, transaction::Hash}; @@ -32,6 +32,9 @@ pub struct ScanService { } /// A timeout applied to `DeleteKeys` requests. +/// +/// This should be shorter than [`SCAN_SERVICE_TIMEOUT`](crate::init::SCAN_SERVICE_TIMEOUT) so the +/// request can try to delete entries from storage after the timeout before the future is dropped. const DELETE_KEY_TIMEOUT: Duration = Duration::from_secs(15); impl ScanService { @@ -64,7 +67,7 @@ impl ScanService { impl Service for ScanService { type Response = Response; - type Error = Box; + type Error = BoxError; type Future = Pin> + Send + 'static>>; @@ -97,17 +100,17 @@ impl Service for ScanService { .boxed(); } - Request::CheckKeyHashes(_key_hashes) => { - // TODO: check that these entries exist in db - } - Request::RegisterKeys(keys) => { let mut scan_task = self.scan_task.clone(); return async move { - Ok(Response::RegisteredKeys( - scan_task.register_keys(keys)?.await?, - )) + let newly_registered_keys = scan_task.register_keys(keys)?.await?; + if !newly_registered_keys.is_empty() { + Ok(Response::RegisteredKeys(newly_registered_keys)) + } else { + Err("no keys were registered, check that keys are not already registered and \ + are valid Sapling extended full viewing keys".into()) + } } .boxed(); } @@ -123,7 +126,7 @@ impl Service for ScanService { scan_task.remove_keys(keys.clone())?, ) .await - .map_err(|_| "timeout waiting for delete keys done notification"); + .map_err(|_| "request timed out removing keys from scan task".to_string()); // Delete the key from the database after either confirmation that it's been removed from the scan task, or // waiting `DELETE_KEY_TIMEOUT`. @@ -171,7 +174,9 @@ impl Service for ScanService { let mut scan_task = self.scan_task.clone(); return async move { - let results_receiver = scan_task.subscribe(keys).await?; + let results_receiver = scan_task.subscribe(keys)?.await.map_err(|_| { + "scan task dropped responder, check that keys are registered" + })?; Ok(Response::SubscribeResults(results_receiver)) } @@ -193,7 +198,5 @@ impl Service for ScanService { .boxed(); } } - - async move { Ok(Response::Results(BTreeMap::new())) }.boxed() } } diff --git a/zebra-scan/src/service/scan_task/commands.rs b/zebra-scan/src/service/scan_task/commands.rs index b7a5fe0d6..2e7964bf7 100644 --- a/zebra-scan/src/service/scan_task/commands.rs +++ b/zebra-scan/src/service/scan_task/commands.rs @@ -8,7 +8,6 @@ use tokio::sync::{ oneshot, }; -use tower::BoxError; use zcash_primitives::{sapling::SaplingIvk, zip32::DiversifiableFullViewingKey}; use zebra_chain::{block::Height, parameters::Network}; use zebra_node_services::scan_service::response::ScanResult; @@ -207,14 +206,14 @@ impl ScanTask { /// Sends a message to the scan task to start sending the results for the provided viewing keys to a channel. /// /// Returns the channel receiver. - pub async fn subscribe( + pub fn subscribe( &mut self, keys: HashSet, - ) -> Result, BoxError> { + ) -> Result>, TrySendError> { let (rsp_tx, rsp_rx) = oneshot::channel(); self.send(ScanTaskCommand::SubscribeResults { keys, rsp_tx })?; - Ok(rsp_rx.await?) + Ok(rsp_rx) } } diff --git a/zebra-scan/src/service/scan_task/tests/vectors.rs b/zebra-scan/src/service/scan_task/tests/vectors.rs index b3c1a0959..fca8b9a5b 100644 --- a/zebra-scan/src/service/scan_task/tests/vectors.rs +++ b/zebra-scan/src/service/scan_task/tests/vectors.rs @@ -147,7 +147,12 @@ async fn scan_task_processes_messages_correctly() -> Result<(), Report> { let subscribe_keys: HashSet = sapling_keys[..5].iter().cloned().collect(); let result_receiver_fut = { let mut mock_scan_task = mock_scan_task.clone(); - tokio::spawn(async move { mock_scan_task.subscribe(subscribe_keys.clone()).await }) + tokio::spawn(async move { + mock_scan_task + .subscribe(subscribe_keys.clone()) + .expect("should send subscribe msg successfully") + .await + }) }; // Wait for spawned task to send subscribe message diff --git a/zebra-scan/src/service/tests.rs b/zebra-scan/src/service/tests.rs index 0ae172830..fa5bd530a 100644 --- a/zebra-scan/src/service/tests.rs +++ b/zebra-scan/src/service/tests.rs @@ -1,7 +1,10 @@ //! Tests for ScanService. +use std::time::Duration; + +use futures::{stream::FuturesOrdered, StreamExt}; use tokio::sync::mpsc::error::TryRecvError; -use tower::{Service, ServiceBuilder, ServiceExt}; +use tower::{timeout::error::Elapsed, Service, ServiceBuilder, ServiceExt}; use color_eyre::{eyre::eyre, Result}; @@ -10,6 +13,7 @@ use zebra_node_services::scan_service::{request::Request, response::Response}; use zebra_state::TransactionIndex; use crate::{ + init::SCAN_SERVICE_TIMEOUT, service::{scan_task::ScanTaskCommand, ScanService}, storage::db::tests::{fake_sapling_results, new_test_storage}, tests::{mock_sapling_scanning_keys, ZECPAGES_SAPLING_VIEWING_KEY}, @@ -329,5 +333,82 @@ async fn scan_service_registers_keys_correctly_for(network: Network) -> Result<( _ => panic!("scan service should have responded with the `RegisteredKeys` response"), } + // Try registering invalid keys. + let register_keys_error_message = scan_service + .ready() + .await + .map_err(|err| eyre!(err))? + .call(Request::RegisterKeys(vec![( + "invalid key".to_string(), + None, + )])) + .await + .expect_err("response should be an error when there are no valid keys to be added") + .to_string(); + + assert!( + register_keys_error_message.starts_with("no keys were registered"), + "error message should say that no keys were registered" + ); + + Ok(()) +} + +/// Test that the scan service with a timeout layer returns timeout errors after expected timeout +#[tokio::test] +async fn scan_service_timeout() -> Result<()> { + let db = new_test_storage(Network::Mainnet); + + let (scan_service, _cmd_receiver) = ScanService::new_with_mock_scanner(db); + let mut scan_service = ServiceBuilder::new() + .buffer(10) + .timeout(SCAN_SERVICE_TIMEOUT) + .service(scan_service); + + let keys = vec![String::from("fake key")]; + let mut response_futs = FuturesOrdered::new(); + + for request in [ + Request::RegisterKeys(keys.iter().cloned().map(|key| (key, None)).collect()), + Request::SubscribeResults(keys.iter().cloned().collect()), + Request::DeleteKeys(keys), + ] { + let response_fut = scan_service + .ready() + .await + .expect("service should be ready") + .call(request); + + response_futs.push_back(tokio::time::timeout( + SCAN_SERVICE_TIMEOUT + .checked_add(Duration::from_secs(1)) + .expect("should not overflow"), + response_fut, + )); + } + + let expect_timeout_err = |response: Option, _>>| { + response + .expect("response_futs should not be empty") + .expect("service should respond with timeout error before outer timeout") + .expect_err("service response should be a timeout error") + }; + + // RegisterKeys and SubscribeResults should return `Elapsed` errors from `Timeout` layer + for _ in 0..2 { + let response = response_futs.next().await; + expect_timeout_err(response) + .downcast::() + .expect("service should return Elapsed error from Timeout layer"); + } + + let response = response_futs.next().await; + let response_error_msg = expect_timeout_err(response).to_string(); + + assert!( + response_error_msg.starts_with("request timed out"), + "error message should say the request timed out" + ); + Ok(()) } diff --git a/zebrad/tests/common/shielded_scan/scan_task_commands.rs b/zebrad/tests/common/shielded_scan/scan_task_commands.rs index e1d2fd630..44021fe83 100644 --- a/zebrad/tests/common/shielded_scan/scan_task_commands.rs +++ b/zebrad/tests/common/shielded_scan/scan_task_commands.rs @@ -120,8 +120,9 @@ pub(crate) async fn run() -> Result<()> { let mut result_receiver = scan_task .subscribe(keys.iter().cloned().collect()) + .expect("should send subscribe message successfully") .await - .expect("should send and receive message successfully"); + .expect("should receive response successfully"); // Wait for the scanner to send a result in the channel let result = tokio::time::timeout(WAIT_FOR_RESULTS_DURATION, result_receiver.recv()).await?;