1use std::fmt;
13
14use futures_util::TryStreamExt;
15use shardtree::error::ShardTreeError;
16use subtle::ConditionallySelectable;
17use tonic::{
18 body::Body as TonicBody,
19 client::GrpcService,
20 codegen::{Body, Bytes, StdError},
21};
22use tracing::{debug, info};
23
24use zcash_keys::encoding::AddressCodec as _;
25use zcash_primitives::merkle_tree::HashSer;
26use zcash_protocol::consensus::{BlockHeight, Parameters};
27
28use crate::{
29 data_api::{
30 chain::{
31 error::Error as ChainError, scan_cached_blocks, BlockCache, ChainState,
32 CommitmentTreeRoot,
33 },
34 scanning::{ScanPriority, ScanRange},
35 WalletCommitmentTrees, WalletRead, WalletWrite,
36 },
37 proto::service::{self, compact_tx_streamer_client::CompactTxStreamerClient, BlockId},
38 scanning::ScanError,
39};
40
41#[cfg(feature = "orchard")]
42use orchard::tree::MerkleHashOrchard;
43
44#[cfg(feature = "transparent-inputs")]
45use {
46 crate::wallet::WalletTransparentOutput,
47 ::transparent::{
48 address::Script,
49 bundle::{OutPoint, TxOut},
50 },
51 zcash_protocol::value::Zatoshis,
52};
53
54pub async fn run<P, ChT, CaT, DbT>(
56 client: &mut CompactTxStreamerClient<ChT>,
57 params: &P,
58 db_cache: &CaT,
59 db_data: &mut DbT,
60 batch_size: u32,
61) -> Result<(), Error<CaT::Error, <DbT as WalletRead>::Error, <DbT as WalletCommitmentTrees>::Error>>
62where
63 P: Parameters + Send + 'static,
64 ChT: GrpcService<TonicBody>,
65 ChT::Error: Into<StdError>,
66 ChT::ResponseBody: Body<Data = Bytes> + Send + 'static,
67 <ChT::ResponseBody as Body>::Error: Into<StdError> + Send,
68 CaT: BlockCache,
69 CaT::Error: std::error::Error + Send + Sync + 'static,
70 DbT: WalletWrite + WalletCommitmentTrees,
71 DbT::AccountId: ConditionallySelectable + Default + Send + 'static,
72 <DbT as WalletRead>::Error: std::error::Error + Send + Sync + 'static,
73 <DbT as WalletCommitmentTrees>::Error: std::error::Error + Send + Sync + 'static,
74{
75 update_subtree_roots(client, db_data).await?;
78
79 while running(client, params, db_cache, db_data, batch_size).await? {}
80
81 Ok(())
82}
83
84async fn running<P, ChT, CaT, DbT, TrErr>(
85 client: &mut CompactTxStreamerClient<ChT>,
86 params: &P,
87 db_cache: &CaT,
88 db_data: &mut DbT,
89 batch_size: u32,
90) -> Result<bool, Error<CaT::Error, <DbT as WalletRead>::Error, TrErr>>
91where
92 P: Parameters + Send + 'static,
93 ChT: GrpcService<TonicBody>,
94 ChT::Error: Into<StdError>,
95 ChT::ResponseBody: Body<Data = Bytes> + Send + 'static,
96 <ChT::ResponseBody as Body>::Error: Into<StdError> + Send,
97 CaT: BlockCache,
98 CaT::Error: std::error::Error + Send + Sync + 'static,
99 DbT: WalletWrite,
100 DbT::AccountId: ConditionallySelectable + Default + Send + 'static,
101 DbT::Error: std::error::Error + Send + Sync + 'static,
102{
103 update_chain_tip(client, db_data).await?;
106
107 #[cfg(feature = "transparent-inputs")]
111 for account_id in db_data.get_account_ids().map_err(Error::Wallet)? {
112 let start_height = db_data
113 .utxo_query_height(account_id)
114 .map_err(Error::Wallet)?;
115 info!(
116 "Refreshing UTXOs for {:?} from height {}",
117 account_id, start_height,
118 );
119 refresh_utxos(params, client, db_data, account_id, start_height).await?;
120 }
121
122 let mut scan_ranges = db_data.suggest_scan_ranges().map_err(Error::Wallet)?;
124
125 let mut block_deletions = vec![];
128
129 loop {
132 match scan_ranges.first() {
135 Some(scan_range) if scan_range.priority() == ScanPriority::Verify => {
136 download_blocks(client, db_cache, scan_range).await?;
139
140 let chain_state =
141 download_chain_state(client, scan_range.block_range().start - 1).await?;
142
143 let scan_ranges_updated =
147 scan_blocks(params, db_cache, db_data, &chain_state, scan_range).await?;
148
149 block_deletions.push(db_cache.delete(scan_range.clone()));
152
153 if scan_ranges_updated {
154 scan_ranges = db_data.suggest_scan_ranges().map_err(Error::Wallet)?;
156 } else {
157 break;
163 }
164 }
165 _ => {
166 break;
168 }
169 }
170 }
171
172 let scan_ranges = db_data.suggest_scan_ranges().map_err(Error::Wallet)?;
175 debug!("Suggested ranges: {:?}", scan_ranges);
176 for scan_range in scan_ranges.into_iter().flat_map(|r| {
177 (0..).scan(r, |acc, _| {
179 if acc.is_empty() {
180 None
181 } else if let Some((cur, next)) = acc.split_at(acc.block_range().start + batch_size) {
182 *acc = next;
183 Some(cur)
184 } else {
185 let cur = acc.clone();
186 let end = acc.block_range().end;
187 *acc = ScanRange::from_parts(end..end, acc.priority());
188 Some(cur)
189 }
190 })
191 }) {
192 download_blocks(client, db_cache, &scan_range).await?;
194
195 let chain_state = download_chain_state(client, scan_range.block_range().start - 1).await?;
196
197 let scan_ranges_updated =
199 scan_blocks(params, db_cache, db_data, &chain_state, &scan_range).await?;
200
201 block_deletions.push(db_cache.delete(scan_range));
203
204 if scan_ranges_updated {
205 info!("Waiting for cached blocks to be deleted...");
208 for deletion in block_deletions {
209 deletion.await.map_err(Error::Cache)?;
210 }
211 return Ok(true);
212 }
213 }
214
215 info!("Waiting for cached blocks to be deleted...");
216 for deletion in block_deletions {
217 deletion.await.map_err(Error::Cache)?;
218 }
219 Ok(false)
220}
221
222async fn update_subtree_roots<ChT, DbT, CaErr, DbErr>(
223 client: &mut CompactTxStreamerClient<ChT>,
224 db_data: &mut DbT,
225) -> Result<(), Error<CaErr, DbErr, <DbT as WalletCommitmentTrees>::Error>>
226where
227 ChT: GrpcService<TonicBody>,
228 ChT::Error: Into<StdError>,
229 ChT::ResponseBody: Body<Data = Bytes> + Send + 'static,
230 <ChT::ResponseBody as Body>::Error: Into<StdError> + Send,
231 DbT: WalletCommitmentTrees,
232 <DbT as WalletCommitmentTrees>::Error: std::error::Error + Send + Sync + 'static,
233{
234 let mut request = service::GetSubtreeRootsArg::default();
235 request.set_shielded_protocol(service::ShieldedProtocol::Sapling);
236
237 let sapling_roots: Vec<CommitmentTreeRoot<sapling::Node>> = client
238 .get_subtree_roots(request)
239 .await?
240 .into_inner()
241 .and_then(|root| async move {
242 let root_hash = sapling::Node::read(&root.root_hash[..])?;
243 Ok(CommitmentTreeRoot::from_parts(
244 BlockHeight::from_u32(root.completing_block_height as u32),
245 root_hash,
246 ))
247 })
248 .try_collect()
249 .await?;
250
251 info!("Sapling tree has {} subtrees", sapling_roots.len());
252 db_data
253 .put_sapling_subtree_roots(0, &sapling_roots)
254 .map_err(Error::WalletTrees)?;
255
256 #[cfg(feature = "orchard")]
257 {
258 let mut request = service::GetSubtreeRootsArg::default();
259 request.set_shielded_protocol(service::ShieldedProtocol::Orchard);
260
261 let orchard_roots: Vec<CommitmentTreeRoot<MerkleHashOrchard>> = client
262 .get_subtree_roots(request)
263 .await?
264 .into_inner()
265 .and_then(|root| async move {
266 let root_hash = MerkleHashOrchard::read(&root.root_hash[..])?;
267 Ok(CommitmentTreeRoot::from_parts(
268 BlockHeight::from_u32(root.completing_block_height as u32),
269 root_hash,
270 ))
271 })
272 .try_collect()
273 .await?;
274
275 info!("Orchard tree has {} subtrees", orchard_roots.len());
276 db_data
277 .put_orchard_subtree_roots(0, &orchard_roots)
278 .map_err(Error::WalletTrees)?;
279 }
280
281 Ok(())
282}
283
284async fn update_chain_tip<ChT, DbT, CaErr, TrErr>(
285 client: &mut CompactTxStreamerClient<ChT>,
286 db_data: &mut DbT,
287) -> Result<(), Error<CaErr, <DbT as WalletRead>::Error, TrErr>>
288where
289 ChT: GrpcService<TonicBody>,
290 ChT::Error: Into<StdError>,
291 ChT::ResponseBody: Body<Data = Bytes> + Send + 'static,
292 <ChT::ResponseBody as Body>::Error: Into<StdError> + Send,
293 DbT: WalletWrite,
294 DbT::Error: std::error::Error + Send + Sync + 'static,
295{
296 let tip_height: BlockHeight = client
297 .get_latest_block(service::ChainSpec::default())
298 .await?
299 .get_ref()
300 .height
301 .try_into()
302 .map_err(|_| Error::MisbehavingServer)?;
303
304 info!("Latest block height is {}", tip_height);
305 db_data
306 .update_chain_tip(tip_height)
307 .map_err(Error::Wallet)?;
308
309 Ok(())
310}
311
312async fn download_blocks<ChT, CaT, DbErr, TrErr>(
313 client: &mut CompactTxStreamerClient<ChT>,
314 db_cache: &CaT,
315 scan_range: &ScanRange,
316) -> Result<(), Error<CaT::Error, DbErr, TrErr>>
317where
318 ChT: GrpcService<TonicBody>,
319 ChT::Error: Into<StdError>,
320 ChT::ResponseBody: Body<Data = Bytes> + Send + 'static,
321 <ChT::ResponseBody as Body>::Error: Into<StdError> + Send,
322 CaT: BlockCache,
323 CaT::Error: std::error::Error + Send + Sync + 'static,
324{
325 info!("Fetching {}", scan_range);
326 let mut start = service::BlockId::default();
327 start.height = scan_range.block_range().start.into();
328 let mut end = service::BlockId::default();
329 end.height = (scan_range.block_range().end - 1).into();
330 let range = service::BlockRange {
331 start: Some(start),
332 end: Some(end),
333 };
334 let compact_blocks = client
335 .get_block_range(range)
336 .await?
337 .into_inner()
338 .try_collect::<Vec<_>>()
339 .await?;
340
341 db_cache
342 .insert(compact_blocks)
343 .await
344 .map_err(Error::Cache)?;
345
346 Ok(())
347}
348
349async fn download_chain_state<ChT, CaErr, DbErr, TrErr>(
350 client: &mut CompactTxStreamerClient<ChT>,
351 block_height: BlockHeight,
352) -> Result<ChainState, Error<CaErr, DbErr, TrErr>>
353where
354 ChT: GrpcService<TonicBody>,
355 ChT::Error: Into<StdError>,
356 ChT::ResponseBody: Body<Data = Bytes> + Send + 'static,
357 <ChT::ResponseBody as Body>::Error: Into<StdError> + Send,
358{
359 let tree_state = client
360 .get_tree_state(BlockId {
361 height: block_height.into(),
362 hash: vec![],
363 })
364 .await?;
365
366 tree_state
367 .into_inner()
368 .to_chain_state()
369 .map_err(|_| Error::MisbehavingServer)
370}
371
372async fn scan_blocks<P, CaT, DbT, TrErr>(
377 params: &P,
378 db_cache: &CaT,
379 db_data: &mut DbT,
380 initial_chain_state: &ChainState,
381 scan_range: &ScanRange,
382) -> Result<bool, Error<CaT::Error, <DbT as WalletRead>::Error, TrErr>>
383where
384 P: Parameters + Send + 'static,
385 CaT: BlockCache,
386 CaT::Error: std::error::Error + Send + Sync + 'static,
387 DbT: WalletWrite,
388 DbT::AccountId: ConditionallySelectable + Default + Send + 'static,
389 DbT::Error: std::error::Error + Send + Sync + 'static,
390{
391 info!("Scanning {}", scan_range);
392 let scan_result = scan_cached_blocks(
393 params,
394 db_cache,
395 db_data,
396 scan_range.block_range().start,
397 initial_chain_state,
398 scan_range.len(),
399 );
400
401 match scan_result {
402 Err(ChainError::Scan(err)) if err.is_continuity_error() => {
403 let rewind_height = err.at_height().saturating_sub(10);
408 info!(
409 "Chain reorg detected at {}, rewinding to {}",
410 err.at_height(),
411 rewind_height,
412 );
413
414 db_data
416 .truncate_to_height(rewind_height)
417 .map_err(Error::Wallet)?;
418
419 db_cache
425 .truncate(rewind_height)
426 .await
427 .map_err(Error::Cache)?;
428
429 Ok(true)
431 }
432 Ok(_) => {
433 let latest_ranges = db_data.suggest_scan_ranges().map_err(Error::Wallet)?;
436
437 Ok(if let Some(range) = latest_ranges.first() {
438 range.priority() > scan_range.priority()
439 } else {
440 false
441 })
442 }
443 Err(e) => Err(e.into()),
444 }
445}
446
/// Fetches the UTXOs of `account_id`'s transparent receivers from the server, starting at
/// `start_height`, and stores each one in the wallet database.
///
/// If the account has no transparent receivers, only a log line is emitted.
///
/// # Errors
///
/// - [`Error::Wallet`] if the receivers cannot be listed or a UTXO cannot be stored.
/// - [`Error::Server`] if the request or the response stream fails.
/// - [`Error::MisbehavingServer`] if a reply contains an invalid txid, output index, value
///   or height, or cannot be turned into a `WalletTransparentOutput`.
#[cfg(feature = "transparent-inputs")]
async fn refresh_utxos<P, ChT, DbT, CaErr, TrErr>(
    params: &P,
    client: &mut CompactTxStreamerClient<ChT>,
    db_data: &mut DbT,
    account_id: DbT::AccountId,
    start_height: BlockHeight,
) -> Result<(), Error<CaErr, <DbT as WalletRead>::Error, TrErr>>
where
    P: Parameters + Send + 'static,
    ChT: GrpcService<TonicBody>,
    ChT::Error: Into<StdError>,
    ChT::ResponseBody: Body<Data = Bytes> + Send + 'static,
    <ChT::ResponseBody as Body>::Error: Into<StdError> + Send,
    DbT: WalletWrite,
    DbT::Error: std::error::Error + Send + Sync + 'static,
{
    let addresses: Vec<_> = db_data
        .get_transparent_receivers(account_id, true)
        .map_err(Error::Wallet)?
        .into_keys()
        .map(|addr| addr.encode(params))
        .collect();

    if addresses.is_empty() {
        info!("{:?} has no transparent receivers", account_id);
        return Ok(());
    }

    let request = service::GetAddressUtxosArg {
        addresses,
        start_height: start_height.into(),
        max_entries: 0,
    };
    let mut replies = client
        .get_address_utxos_stream(request)
        .await?
        .into_inner();

    // Each reply is validated and stored before the next one is read.
    while let Some(reply) = replies.message().await? {
        let outpoint = OutPoint::new(
            reply.txid[..]
                .try_into()
                .map_err(|_| Error::MisbehavingServer)?,
            reply
                .index
                .try_into()
                .map_err(|_| Error::MisbehavingServer)?,
        );
        let txout = TxOut {
            value: Zatoshis::from_nonnegative_i64(reply.value_zat)
                .map_err(|_| Error::MisbehavingServer)?,
            script_pubkey: Script(reply.script),
        };
        let mined_height =
            BlockHeight::try_from(reply.height).map_err(|_| Error::MisbehavingServer)?;

        let output = WalletTransparentOutput::from_parts(outpoint, txout, Some(mined_height))
            .ok_or(Error::MisbehavingServer)?;

        db_data
            .put_received_transparent_utxo(&output)
            .map_err(Error::Wallet)?;
    }

    Ok(())
}
540
/// Errors that can be returned while synchronizing a wallet with a lightwalletd server.
///
/// The type parameters are the error types of the block cache (`CaErr`), the wallet
/// database (`DbErr`), and the wallet's commitment trees (`TrErr`).
#[derive(Debug)]
pub enum Error<CaErr, DbErr, TrErr> {
    /// An error occurred while interacting with the block cache.
    Cache(CaErr),
    /// The lightwalletd server returned data that failed validation or conversion.
    MisbehavingServer,
    /// An error occurred while scanning blocks.
    Scan(ScanError),
    /// An error occurred while communicating with the lightwalletd server.
    Server(tonic::Status),
    /// An error occurred while interacting with the wallet database.
    Wallet(DbErr),
    /// An error occurred while interacting with the wallet's commitment trees.
    WalletTrees(ShardTreeError<TrErr>),
}
558
559impl<CaErr, DbErr, TrErr> fmt::Display for Error<CaErr, DbErr, TrErr>
560where
561 CaErr: fmt::Display,
562 DbErr: fmt::Display,
563 TrErr: fmt::Display,
564{
565 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
566 match self {
567 Error::Cache(e) => write!(f, "Error while interacting with block cache: {}", e),
568 Error::MisbehavingServer => write!(f, "lightwalletd server is misbehaving"),
569 Error::Scan(e) => write!(f, "Error while scanning blocks: {}", e),
570 Error::Server(e) => write!(
571 f,
572 "Error while communicating with lightwalletd server: {}",
573 e
574 ),
575 Error::Wallet(e) => write!(f, "Error while interacting with wallet database: {}", e),
576 Error::WalletTrees(e) => write!(
577 f,
578 "Error while interacting with wallet commitment trees: {}",
579 e
580 ),
581 }
582 }
583}
584
/// Makes [`Error`] a standard error type whenever all three wrapped error types are.
///
/// NOTE(review): `source()` is not overridden. Wrapped errors are only exposed through the
/// `Display`/`Debug` output or by matching on the variants.
impl<CaErr, DbErr, TrErr> std::error::Error for Error<CaErr, DbErr, TrErr>
where
    CaErr: std::error::Error,
    DbErr: std::error::Error,
    TrErr: std::error::Error,
{
}
592
593impl<CaErr, DbErr, TrErr> From<ChainError<DbErr, CaErr>> for Error<CaErr, DbErr, TrErr> {
594 fn from(e: ChainError<DbErr, CaErr>) -> Self {
595 match e {
596 ChainError::Wallet(e) => Error::Wallet(e),
597 ChainError::BlockSource(e) => Error::Cache(e),
598 ChainError::Scan(e) => Error::Scan(e),
599 }
600 }
601}
602
603impl<CaErr, DbErr, TrErr> From<tonic::Status> for Error<CaErr, DbErr, TrErr> {
604 fn from(status: tonic::Status) -> Self {
605 Error::Server(status)
606 }
607}