1use crossbeam_channel as channel;
2use std::collections::HashMap;
3use std::fmt;
4use std::mem;
5use std::sync::{
6 atomic::{AtomicUsize, Ordering},
7 Arc,
8};
9
10use memuse::DynamicUsage;
11use zcash_note_encryption::{
12 batch, BatchDomain, Domain, ShieldedOutput, COMPACT_NOTE_SIZE, ENC_CIPHERTEXT_SIZE,
13};
14use zcash_primitives::{block::BlockHash, transaction::TxId};
15
16pub(crate) struct DecryptedOutput<IvkTag, D: Domain, M> {
18 pub(crate) ivk_tag: IvkTag,
20 pub(crate) recipient: D::Recipient,
22 pub(crate) note: D::Note,
24 pub(crate) memo: M,
26}
27
28impl<IvkTag, D: Domain, M> fmt::Debug for DecryptedOutput<IvkTag, D, M>
29where
30 IvkTag: fmt::Debug,
31 D::IncomingViewingKey: fmt::Debug,
32 D::Recipient: fmt::Debug,
33 D::Note: fmt::Debug,
34 M: fmt::Debug,
35{
36 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37 f.debug_struct("DecryptedOutput")
38 .field("ivk_tag", &self.ivk_tag)
39 .field("recipient", &self.recipient)
40 .field("note", &self.note)
41 .field("memo", &self.memo)
42 .finish()
43 }
44}
45
46pub(crate) trait Decryptor<D: BatchDomain, Output> {
48 type Memo;
49
50 fn batch_decrypt<IvkTag: Clone>(
51 tags: &[IvkTag],
52 ivks: &[D::IncomingViewingKey],
53 outputs: &[(D, Output)],
54 ) -> impl Iterator<Item = Option<DecryptedOutput<IvkTag, D, Self::Memo>>>;
55}
56
57#[allow(dead_code)]
59pub(crate) struct FullDecryptor;
60
61impl<D: BatchDomain, Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE>> Decryptor<D, Output>
62 for FullDecryptor
63{
64 type Memo = D::Memo;
65
66 fn batch_decrypt<IvkTag: Clone>(
67 tags: &[IvkTag],
68 ivks: &[D::IncomingViewingKey],
69 outputs: &[(D, Output)],
70 ) -> impl Iterator<Item = Option<DecryptedOutput<IvkTag, D, Self::Memo>>> {
71 batch::try_note_decryption(ivks, outputs)
72 .into_iter()
73 .map(|res| {
74 res.map(|((note, recipient, memo), ivk_idx)| DecryptedOutput {
75 ivk_tag: tags[ivk_idx].clone(),
76 recipient,
77 note,
78 memo,
79 })
80 })
81 }
82}
83
84pub(crate) struct CompactDecryptor;
86
87impl<D: BatchDomain, Output: ShieldedOutput<D, COMPACT_NOTE_SIZE>> Decryptor<D, Output>
88 for CompactDecryptor
89{
90 type Memo = ();
91
92 fn batch_decrypt<IvkTag: Clone>(
93 tags: &[IvkTag],
94 ivks: &[D::IncomingViewingKey],
95 outputs: &[(D, Output)],
96 ) -> impl Iterator<Item = Option<DecryptedOutput<IvkTag, D, Self::Memo>>> {
97 batch::try_compact_note_decryption(ivks, outputs)
98 .into_iter()
99 .map(|res| {
100 res.map(|((note, recipient), ivk_idx)| DecryptedOutput {
101 ivk_tag: tags[ivk_idx].clone(),
102 recipient,
103 note,
104 memo: (),
105 })
106 })
107 }
108}
109
110struct OutputIndex<V> {
112 output_index: usize,
114 value: V,
116}
117
118type OutputItem<IvkTag, D, M> = OutputIndex<DecryptedOutput<IvkTag, D, M>>;
119
120struct OutputReplier<IvkTag, D: Domain, M>(OutputIndex<channel::Sender<OutputItem<IvkTag, D, M>>>);
122
123impl<IvkTag, D: Domain, M> DynamicUsage for OutputReplier<IvkTag, D, M> {
124 #[inline(always)]
125 fn dynamic_usage(&self) -> usize {
126 0
128 }
129
130 #[inline(always)]
131 fn dynamic_usage_bounds(&self) -> (usize, Option<usize>) {
132 (0, Some(0))
133 }
134}
135
136struct BatchReceiver<IvkTag, D: Domain, M>(channel::Receiver<OutputItem<IvkTag, D, M>>);
138
139impl<IvkTag, D: Domain, M> DynamicUsage for BatchReceiver<IvkTag, D, M> {
140 fn dynamic_usage(&self) -> usize {
141 let num_items = self.0.len();
143
144 const ITEMS_PER_BLOCK: usize = 31;
148 let num_blocks = num_items.div_ceil(ITEMS_PER_BLOCK);
149
150 const PTR_SIZE: usize = std::mem::size_of::<usize>();
156 let item_size = std::mem::size_of::<OutputItem<IvkTag, D, M>>();
157 const ATOMIC_USIZE_SIZE: usize = std::mem::size_of::<AtomicUsize>();
158 let block_size = PTR_SIZE + ITEMS_PER_BLOCK * (item_size + ATOMIC_USIZE_SIZE);
159
160 num_blocks * block_size
161 }
162
163 fn dynamic_usage_bounds(&self) -> (usize, Option<usize>) {
164 let usage = self.dynamic_usage();
165 (usage, Some(usage))
166 }
167}
168
169pub(crate) trait Tasks<Item> {
173 type Task: Task;
174 fn new() -> Self;
175 fn add_task(&self, item: Item) -> Self::Task;
176 fn run_task(&self, item: Item) {
177 let task = self.add_task(item);
178 rayon::spawn_fifo(|| task.run());
179 }
180}
181
182pub(crate) trait Task: Send + 'static {
184 fn run(self);
185}
186
187impl<Item: Task> Tasks<Item> for () {
188 type Task = Item;
189 fn new() -> Self {}
190 fn add_task(&self, item: Item) -> Self::Task {
191 item
194 }
195}
196
197#[allow(dead_code)]
202pub(crate) struct WithUsage {
203 running_usage: Arc<AtomicUsize>,
205}
206
207impl DynamicUsage for WithUsage {
208 fn dynamic_usage(&self) -> usize {
209 self.running_usage.load(Ordering::Relaxed)
210 }
211
212 fn dynamic_usage_bounds(&self) -> (usize, Option<usize>) {
213 let usage = self.dynamic_usage();
216 (usage, Some(usage))
217 }
218}
219
220impl<Item: Task + DynamicUsage> Tasks<Item> for WithUsage {
221 type Task = WithUsageTask<Item>;
222
223 fn new() -> Self {
224 Self {
225 running_usage: Arc::new(AtomicUsize::new(0)),
226 }
227 }
228
229 fn add_task(&self, item: Item) -> Self::Task {
230 let mut task = WithUsageTask {
232 item,
233 own_usage: 0,
234 running_usage: self.running_usage.clone(),
235 };
236
237 task.own_usage =
244 mem::size_of::<Arc<()>>() + mem::size_of_val(&task) + task.item.dynamic_usage();
245
246 self.running_usage
250 .fetch_add(task.own_usage, Ordering::SeqCst);
251
252 task
253 }
254}
255
256pub(crate) struct WithUsageTask<Item> {
259 item: Item,
261 own_usage: usize,
265 running_usage: Arc<AtomicUsize>,
267}
268
269impl<Item: Task> Task for WithUsageTask<Item> {
270 fn run(self) {
271 self.item.run();
273
274 self.running_usage
276 .fetch_sub(self.own_usage, Ordering::SeqCst);
277 }
278}
279
280pub(crate) struct Batch<IvkTag, D: BatchDomain, Output, Dec: Decryptor<D, Output>> {
282 tags: Vec<IvkTag>,
283 ivks: Vec<D::IncomingViewingKey>,
284 outputs: Vec<(D, Output)>,
292 repliers: Vec<OutputReplier<IvkTag, D, Dec::Memo>>,
293}
294
295impl<IvkTag, D, Output, Dec> DynamicUsage for Batch<IvkTag, D, Output, Dec>
296where
297 IvkTag: DynamicUsage,
298 D: BatchDomain + DynamicUsage,
299 D::IncomingViewingKey: DynamicUsage,
300 Output: DynamicUsage,
301 Dec: Decryptor<D, Output>,
302{
303 fn dynamic_usage(&self) -> usize {
304 self.tags.dynamic_usage()
305 + self.ivks.dynamic_usage()
306 + self.outputs.dynamic_usage()
307 + self.repliers.dynamic_usage()
308 }
309
310 fn dynamic_usage_bounds(&self) -> (usize, Option<usize>) {
311 let (tags_lower, tags_upper) = self.tags.dynamic_usage_bounds();
312 let (ivks_lower, ivks_upper) = self.ivks.dynamic_usage_bounds();
313 let (outputs_lower, outputs_upper) = self.outputs.dynamic_usage_bounds();
314 let (repliers_lower, repliers_upper) = self.repliers.dynamic_usage_bounds();
315
316 (
317 tags_lower + ivks_lower + outputs_lower + repliers_lower,
318 tags_upper
319 .zip(ivks_upper)
320 .zip(outputs_upper)
321 .zip(repliers_upper)
322 .map(|(((a, b), c), d)| a + b + c + d),
323 )
324 }
325}
326
327impl<IvkTag, D, Output, Dec> Batch<IvkTag, D, Output, Dec>
328where
329 IvkTag: Clone,
330 D: BatchDomain,
331 Dec: Decryptor<D, Output>,
332{
333 fn new(tags: Vec<IvkTag>, ivks: Vec<D::IncomingViewingKey>) -> Self {
335 assert_eq!(tags.len(), ivks.len());
336 Self {
337 tags,
338 ivks,
339 outputs: vec![],
340 repliers: vec![],
341 }
342 }
343
344 fn is_empty(&self) -> bool {
346 self.outputs.is_empty()
347 }
348}
349
350impl<IvkTag, D, Output, Dec> Task for Batch<IvkTag, D, Output, Dec>
351where
352 IvkTag: Clone + Send + 'static,
353 D: BatchDomain + Send + 'static,
354 D::IncomingViewingKey: Send,
355 D::Memo: Send,
356 D::Note: Send,
357 D::Recipient: Send,
358 Output: Send + 'static,
359 Dec: Decryptor<D, Output> + 'static,
360 Dec::Memo: Send,
361{
362 fn run(self) {
364 let Self {
366 tags,
367 ivks,
368 outputs,
369 repliers,
370 } = self;
371
372 assert_eq!(outputs.len(), repliers.len());
373
374 let decryption_results = Dec::batch_decrypt(&tags, &ivks, &outputs);
375 for (decryption_result, OutputReplier(replier)) in decryption_results.zip(repliers) {
376 if let Some(value) = decryption_result {
379 let result = OutputIndex {
380 output_index: replier.output_index,
381 value,
382 };
383
384 if replier.value.send(result).is_err() {
385 tracing::debug!("BatchRunner was dropped before batch finished");
386 break;
387 }
388 }
389 }
390 }
391}
392
393impl<IvkTag, D, Output, Dec> Batch<IvkTag, D, Output, Dec>
394where
395 D: BatchDomain,
396 Output: Clone,
397 Dec: Decryptor<D, Output>,
398{
399 fn add_outputs(
403 &mut self,
404 domain: impl Fn(&Output) -> D,
405 outputs: &[Output],
406 replier: channel::Sender<OutputItem<IvkTag, D, Dec::Memo>>,
407 ) {
408 self.outputs.extend(
409 outputs
410 .iter()
411 .cloned()
412 .map(|output| (domain(&output), output)),
413 );
414 self.repliers.extend((0..outputs.len()).map(|output_index| {
415 OutputReplier(OutputIndex {
416 output_index,
417 value: replier.clone(),
418 })
419 }));
420 }
421}
422
423#[derive(PartialEq, Eq, Hash)]
425struct ResultKey(BlockHash, TxId);
426
427impl DynamicUsage for ResultKey {
428 #[inline(always)]
429 fn dynamic_usage(&self) -> usize {
430 0
431 }
432
433 #[inline(always)]
434 fn dynamic_usage_bounds(&self) -> (usize, Option<usize>) {
435 (0, Some(0))
436 }
437}
438
439pub(crate) struct BatchRunner<IvkTag, D, Output, Dec, T>
441where
442 D: BatchDomain,
443 Dec: Decryptor<D, Output>,
444 T: Tasks<Batch<IvkTag, D, Output, Dec>>,
445{
446 batch_size_threshold: usize,
447 acc: Batch<IvkTag, D, Output, Dec>,
449 running_tasks: T,
451 pending_results: HashMap<ResultKey, BatchReceiver<IvkTag, D, Dec::Memo>>,
453}
454
455impl<IvkTag, D, Output, Dec, T> DynamicUsage for BatchRunner<IvkTag, D, Output, Dec, T>
456where
457 IvkTag: DynamicUsage,
458 D: BatchDomain + DynamicUsage,
459 D::IncomingViewingKey: DynamicUsage,
460 Output: DynamicUsage,
461 Dec: Decryptor<D, Output>,
462 T: Tasks<Batch<IvkTag, D, Output, Dec>> + DynamicUsage,
463{
464 fn dynamic_usage(&self) -> usize {
465 self.acc.dynamic_usage()
466 + self.running_tasks.dynamic_usage()
467 + self.pending_results.dynamic_usage()
468 }
469
470 fn dynamic_usage_bounds(&self) -> (usize, Option<usize>) {
471 let running_usage = self.running_tasks.dynamic_usage();
472
473 let bounds = (
474 self.acc.dynamic_usage_bounds(),
475 self.pending_results.dynamic_usage_bounds(),
476 );
477 (
478 bounds.0 .0 + running_usage + bounds.1 .0,
479 bounds
480 .0
481 .1
482 .zip(bounds.1 .1)
483 .map(|(a, b)| a + running_usage + b),
484 )
485 }
486}
487
488impl<IvkTag, D, Output, Dec, T> BatchRunner<IvkTag, D, Output, Dec, T>
489where
490 IvkTag: Clone,
491 D: BatchDomain,
492 Dec: Decryptor<D, Output>,
493 T: Tasks<Batch<IvkTag, D, Output, Dec>>,
494{
495 pub(crate) fn new(
497 batch_size_threshold: usize,
498 ivks: impl Iterator<Item = (IvkTag, D::IncomingViewingKey)>,
499 ) -> Self {
500 let (tags, ivks) = ivks.unzip();
501 Self {
502 batch_size_threshold,
503 acc: Batch::new(tags, ivks),
504 running_tasks: T::new(),
505 pending_results: HashMap::default(),
506 }
507 }
508}
509
510impl<IvkTag, D, Output, Dec, T> BatchRunner<IvkTag, D, Output, Dec, T>
511where
512 IvkTag: Clone + Send + 'static,
513 D: BatchDomain + Send + 'static,
514 D::IncomingViewingKey: Clone + Send,
515 D::Memo: Send,
516 D::Note: Send,
517 D::Recipient: Send,
518 Output: Clone + Send + 'static,
519 Dec: Decryptor<D, Output>,
520 T: Tasks<Batch<IvkTag, D, Output, Dec>>,
521{
522 pub(crate) fn add_outputs(
532 &mut self,
533 block_tag: BlockHash,
534 txid: TxId,
535 domain: impl Fn(&Output) -> D,
536 outputs: &[Output],
537 ) {
538 let (tx, rx) = channel::unbounded();
539 self.acc.add_outputs(domain, outputs, tx);
540 self.pending_results
541 .insert(ResultKey(block_tag, txid), BatchReceiver(rx));
542
543 if self.acc.outputs.len() >= self.batch_size_threshold {
544 self.flush();
545 }
546 }
547
548 pub(crate) fn flush(&mut self) {
552 if !self.acc.is_empty() {
553 let mut batch = Batch::new(self.acc.tags.clone(), self.acc.ivks.clone());
554 mem::swap(&mut batch, &mut self.acc);
555 self.running_tasks.run_task(batch);
556 }
557 }
558
559 pub(crate) fn collect_results(
565 &mut self,
566 block_tag: BlockHash,
567 txid: TxId,
568 ) -> HashMap<(TxId, usize), DecryptedOutput<IvkTag, D, Dec::Memo>> {
569 self.pending_results
570 .remove(&ResultKey(block_tag, txid))
571 .map(|BatchReceiver(rx)| {
574 rx.into_iter()
581 .map(
582 |OutputIndex {
583 output_index,
584 value,
585 }| { ((txid, output_index), value) },
586 )
587 .collect()
588 })
589 .unwrap_or_default()
590 }
591}