From 845b1cc72ffd38b27eafdea9dd36c6425ecdd135 Mon Sep 17 00:00:00 2001 From: Hanh Date: Sun, 20 Jun 2021 08:32:15 +0800 Subject: [PATCH] Don't special case the frontier --- src/chain.rs | 6 +- src/commitment.rs | 165 ++++++++++++++++++++++++++++------------------ src/lib.rs | 2 + 3 files changed, 108 insertions(+), 65 deletions(-) diff --git a/src/chain.rs b/src/chain.rs index 1f81fb2..cf50a85 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -11,7 +11,7 @@ use zcash_primitives::merkle_tree::{CommitmentTree, IncrementalWitness}; use zcash_primitives::sapling::note_encryption::try_sapling_compact_note_decryption; use zcash_primitives::sapling::{Node, Note, SaplingIvk}; use zcash_primitives::transaction::components::sapling::CompactOutputDescription; -use crate::commitment::{CTree, Witness}; +use crate::commitment::{CTree, Witness, NotePosition}; use std::time::Instant; use log::info; @@ -216,7 +216,9 @@ pub fn calculate_tree_state_v2(cbs: &[CompactBlock], blocks: &[DecryptedBlock]) info!("Build CMU list: {} ms - {} nodes", start.elapsed().as_millis(), nodes.len()); let start = Instant::now(); - let (_tree, positions) = CTree::calc_state(&mut nodes, &positions); + let n = nodes.len(); + let mut positions: Vec<_> = positions.iter().map(|&p| NotePosition::new(p, n)).collect(); + let _frontier = CTree::calc_state(nodes, &mut positions, &CTree::new()); let witnesses: Vec<_> = positions.iter().map(|p| p.witness.clone()).collect(); info!("Tree State & Witnesses: {} ms", start.elapsed().as_millis()); witnesses diff --git a/src/commitment.rs b/src/commitment.rs index 9e1ec2a..6576705 100644 --- a/src/commitment.rs +++ b/src/commitment.rs @@ -1,9 +1,10 @@ -use zcash_primitives::merkle_tree::Hashable; -use zcash_primitives::sapling::Node; -use std::io::Write; -use zcash_primitives::serialize::{Optional, Vector}; +use crate::path::MerklePath; use byteorder::WriteBytesExt; use rayon::prelude::*; +use std::io::Write; +use zcash_primitives::merkle_tree::Hashable; +use zcash_primitives::sapling::Node; +use zcash_primitives::serialize::{Optional, Vector}; /* Same behavior and structure as CommitmentTree from librustzcash @@ -78,7 +79,7 @@ gets pushed into `filled`. pub struct Witness { tree: CTree, // commitment tree at the moment the witness is created: immutable filled: Vec, // as more nodes are added, levels get filled up: won't change anymore - cursor: CTree, // partial tree which still updates when nodes are added + cursor: CTree, // partial tree which still updates when nodes are added } impl Witness { @@ -95,8 +96,7 @@ impl Witness { Vector::write(&mut writer, &self.filled, |w, n| n.write(w))?; if self.cursor.left == None && self.cursor.right == None { writer.write_u8(0)?; - } - else { + } else { writer.write_u8(1)?; self.cursor.write(writer)?; }; @@ -106,26 +106,32 @@ impl Witness { pub struct NotePosition { p: usize, - p2: Option, + p2: usize, c: usize, pub witness: Witness, - is_last: bool, } -fn collect(tree: &mut CTree, mut p: usize, depth: usize, commitments: &[Node]) -> usize { +fn collect( + tree: &mut CTree, + mut p: usize, + depth: usize, + commitments: &[Node], + offset: usize, +) -> usize { + // println!("--> {} {} {}", depth, p, offset); if depth == 0 { if p % 2 == 0 { - tree.left = Some(commitments[p]); + tree.left = Some(commitments[p - offset]); } else { - tree.left = Some(commitments[p - 1]); - tree.right = Some(commitments[p]); + tree.left = Some(commitments[p - 1 - offset]); + tree.right = Some(commitments[p - offset]); p -= 1; } } else { // the rest gets combined as a binary tree if p % 2 != 0 { - tree.parents.push(Some(commitments[p - 1])); - } else if p != 0 { + tree.parents.push(Some(commitments[p - 1 - offset])); + } else if p != offset { tree.parents.push(None); } } @@ -133,47 +139,43 @@ fn collect(tree: &mut CTree, mut p: usize, depth: usize, commitments: &[Node]) - } impl NotePosition { - fn new(position: usize, count: usize) -> NotePosition { - let is_last = position == count - 1; - let c = if !is_last { - cursor_start_position(position, count) - } else { - 0 - }; - let cursor_length = count - c; + pub fn new(position: usize, count: usize) -> NotePosition { + let c = cursor_start_position(position, count); NotePosition { p: position, - p2: if cursor_length > 0 { - Some(cursor_length - 1) - } else { - None - }, + p2: count - 1, c, witness: Witness::new(), - is_last, } } - fn collect(&mut self, depth: usize, commitments: &[Node]) { + fn collect(&mut self, depth: usize, commitments: &[Node], offset: usize) { let count = commitments.len(); let p = self.p; - self.p = collect(&mut self.witness.tree, p, depth, commitments); + self.p = collect(&mut self.witness.tree, p, depth, commitments, offset); - if !self.is_last { - if p % 2 == 0 && p + 1 < commitments.len() { - let filler = commitments[p + 1]; - self.witness.filled.push(filler); - } + if p % 2 == 0 && p + 1 - offset < commitments.len() { + let filler = commitments[p + 1 - offset]; + self.witness.filled.push(filler); } - if let Some(ref mut p2) = self.p2 { - if !self.is_last { - let cursor_commitments = &commitments[self.c..count]; - *p2 = collect(&mut self.witness.cursor, *p2, depth, cursor_commitments); - } - *p2 /= 2; + let c = self.c - offset; + let cursor_commitments = &commitments[c..count]; + // println!("c> {} {} {}", c, count, depth); + // println!("> {} {}", self.p2, self.c); + if !cursor_commitments.is_empty() { + let p2 = collect( + &mut self.witness.cursor, + self.p2, + depth, + cursor_commitments, + offset + c, + ); + self.p2 = (p2 - self.c) / 2 + self.c / 2; + // println!("+ {} {}", self.p2, self.c); } + self.p /= 2; self.c /= 2; } @@ -200,35 +202,50 @@ fn cursor_start_position(mut position: usize, mut count: usize) -> usize { } impl CTree { - pub fn calc_state(commitments: &mut [Node], positions: &[usize]) -> (CTree, Vec) { + pub fn calc_state( + mut commitments: Vec, + positions: &mut [NotePosition], + prev_frontier: &CTree, + ) -> CTree { let mut n = commitments.len(); - let mut positions: Vec<_> = positions.iter().map(|&p| NotePosition::new(p, n)).collect(); assert_ne!(n, 0); + let prev_count = prev_frontier.get_position(); + let last_path = MerklePath::new(prev_frontier.left, prev_frontier.right); + let mut frontier = NotePosition::new(prev_count + n - 1, prev_count + n); + let mut offset = prev_count; + let mut depth = 0usize; - let mut frontier = NotePosition::new(n - 1, n); while n > 0 { + if offset % 2 == 1 { + // start is not aligned, prepend the last node from the previous run + let node = last_path.get(); + commitments.insert(0, node); + n += 1; + offset -= 1; + } let commitment_slice = &commitments[0..n]; - frontier.collect(depth, commitment_slice); + frontier.collect(depth, commitment_slice, offset); for p in positions.iter_mut() { - p.collect(depth, commitment_slice); + p.collect(depth, commitment_slice, offset); } let nn = n / 2; - let next_level: Vec<_> = (0..nn).into_par_iter().map(|i| { - Node::combine(depth, &commitments[2 * i], &commitments[2 * i + 1]) - }).collect(); + let next_level: Vec<_> = (0..nn) + .into_par_iter() + .map(|i| Node::combine(depth, &commitments[2 * i], &commitments[2 * i + 1])) + .collect(); commitments[0..nn].copy_from_slice(&next_level); depth += 1; n = nn; } - (frontier.witness.tree, positions) + frontier.witness.tree } - fn new() -> CTree { + pub fn new() -> CTree { CTree { left: None, right: None, @@ -243,11 +260,29 @@ impl CTree { Optional::write(w, e, |w, n| n.write(w)) }) } + + fn get_position(&self) -> usize { + let mut p = 0usize; + for parent in self.parents.iter() { + if parent.is_some() { + p += 1; + } + p *= 2; + } + p *= 2; + if self.left.is_some() { + p += 1; + } + if self.right.is_some() { + p += 1; + } + p + } } #[cfg(test)] mod tests { - use crate::commitment::{cursor_start_position, CTree}; + use crate::commitment::{cursor_start_position, CTree, NotePosition}; #[allow(unused_imports)] use crate::print::{print_tree, print_witness}; use std::time::Instant; @@ -259,8 +294,8 @@ mod tests { */ #[test] fn test_calc_witnesses() { - const NUM_NODES: u32 = 100000; // number of notes - const WITNESS_PERCENT: u32 = 1; // percentage of notes that are ours + const NUM_NODES: u32 = 100; // number of notes + const WITNESS_PERCENT: u32 = 100; // percentage of notes that are ours const DEBUG_PRINT: bool = false; let witness_freq = 100 / WITNESS_PERCENT; @@ -279,7 +314,8 @@ mod tests { w.append(node).unwrap(); } - if i % witness_freq == 0 { + // if i % witness_freq == 0 { + if i == 65 { let w = IncrementalWitness::::from_tree(&tree1); witnesses.push(w); positions.push((i - 1) as usize); @@ -289,7 +325,9 @@ mod tests { } let start = Instant::now(); - let (tree2, positions) = CTree::calc_state(&mut nodes, &positions); + let n = nodes.len(); + let mut positions: Vec<_> = positions.iter().map(|&p| NotePosition::new(p, n)).collect(); + let tree2 = CTree::calc_state(nodes, &mut positions, &CTree::new()); eprintln!( "Update State & Witnesses: {} ms", start.elapsed().as_millis() @@ -297,32 +335,33 @@ mod tests { println!("# witnesses = {}", positions.len()); - for (w, p) in witnesses.iter().zip(&positions) { + for (i, (w, p)) in witnesses.iter().zip(&positions).enumerate() { let mut bb1: Vec = vec![]; w.write(&mut bb1).unwrap(); let mut bb2: Vec = vec![]; p.witness.write(&mut bb2).unwrap(); - assert_eq!(bb1.as_slice(), bb2.as_slice()); + assert_eq!(bb1.as_slice(), bb2.as_slice(), "failed at {}", i); } if DEBUG_PRINT { - print_witness(&witnesses[0]); + let slot = 0usize; + print_witness(&witnesses[slot]); println!("Tree"); - let t = &positions[0].witness.tree; + let t = &positions[slot].witness.tree; println!("{:?}", t.left.map(|n| hex::encode(n.repr))); println!("{:?}", t.right.map(|n| hex::encode(n.repr))); for p in t.parents.iter() { println!("{:?}", p.map(|n| hex::encode(n.repr))); } println!("Filled"); - for f in positions[0].witness.filled.iter() { + for f in positions[slot].witness.filled.iter() { println!("{:?}", hex::encode(f.repr)); } println!("Cursor"); - let t = &positions[0].witness.cursor; + let t = &positions[slot].witness.cursor; println!("{:?}", t.left.map(|n| hex::encode(n.repr))); println!("{:?}", t.right.map(|n| hex::encode(n.repr))); for p in t.parents.iter() { diff --git a/src/lib.rs b/src/lib.rs index b5cdf5b..7cd5441 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,10 +7,12 @@ pub const NETWORK: Network = Network::MainNetwork; mod print; mod chain; +mod path; mod commitment; mod scan; pub use crate::chain::{LWD_URL, get_latest_height, download_chain, calculate_tree_state_v2, DecryptNode}; +pub use crate::commitment::NotePosition; pub use crate::lw_rpc::compact_tx_streamer_client::CompactTxStreamerClient; pub use crate::lw_rpc::*; pub use crate::scan::scan_all;