Don't special case the frontier

This commit is contained in:
Hanh 2021-06-20 08:32:15 +08:00
parent 9115304a2e
commit 845b1cc72f
3 changed files with 108 additions and 65 deletions

View File

@ -11,7 +11,7 @@ use zcash_primitives::merkle_tree::{CommitmentTree, IncrementalWitness};
use zcash_primitives::sapling::note_encryption::try_sapling_compact_note_decryption; use zcash_primitives::sapling::note_encryption::try_sapling_compact_note_decryption;
use zcash_primitives::sapling::{Node, Note, SaplingIvk}; use zcash_primitives::sapling::{Node, Note, SaplingIvk};
use zcash_primitives::transaction::components::sapling::CompactOutputDescription; use zcash_primitives::transaction::components::sapling::CompactOutputDescription;
use crate::commitment::{CTree, Witness}; use crate::commitment::{CTree, Witness, NotePosition};
use std::time::Instant; use std::time::Instant;
use log::info; use log::info;
@ -216,7 +216,9 @@ pub fn calculate_tree_state_v2(cbs: &[CompactBlock], blocks: &[DecryptedBlock])
info!("Build CMU list: {} ms - {} nodes", start.elapsed().as_millis(), nodes.len()); info!("Build CMU list: {} ms - {} nodes", start.elapsed().as_millis(), nodes.len());
let start = Instant::now(); let start = Instant::now();
let (_tree, positions) = CTree::calc_state(&mut nodes, &positions); let n = nodes.len();
let mut positions: Vec<_> = positions.iter().map(|&p| NotePosition::new(p, n)).collect();
let _frontier = CTree::calc_state(nodes, &mut positions, &CTree::new());
let witnesses: Vec<_> = positions.iter().map(|p| p.witness.clone()).collect(); let witnesses: Vec<_> = positions.iter().map(|p| p.witness.clone()).collect();
info!("Tree State & Witnesses: {} ms", start.elapsed().as_millis()); info!("Tree State & Witnesses: {} ms", start.elapsed().as_millis());
witnesses witnesses

View File

@ -1,9 +1,10 @@
use zcash_primitives::merkle_tree::Hashable; use crate::path::MerklePath;
use zcash_primitives::sapling::Node;
use std::io::Write;
use zcash_primitives::serialize::{Optional, Vector};
use byteorder::WriteBytesExt; use byteorder::WriteBytesExt;
use rayon::prelude::*; use rayon::prelude::*;
use std::io::Write;
use zcash_primitives::merkle_tree::Hashable;
use zcash_primitives::sapling::Node;
use zcash_primitives::serialize::{Optional, Vector};
/* /*
Same behavior and structure as CommitmentTree<Node> from librustzcash Same behavior and structure as CommitmentTree<Node> from librustzcash
@ -95,8 +96,7 @@ impl Witness {
Vector::write(&mut writer, &self.filled, |w, n| n.write(w))?; Vector::write(&mut writer, &self.filled, |w, n| n.write(w))?;
if self.cursor.left == None && self.cursor.right == None { if self.cursor.left == None && self.cursor.right == None {
writer.write_u8(0)?; writer.write_u8(0)?;
} } else {
else {
writer.write_u8(1)?; writer.write_u8(1)?;
self.cursor.write(writer)?; self.cursor.write(writer)?;
}; };
@ -106,26 +106,32 @@ impl Witness {
pub struct NotePosition { pub struct NotePosition {
p: usize, p: usize,
p2: Option<usize>, p2: usize,
c: usize, c: usize,
pub witness: Witness, pub witness: Witness,
is_last: bool,
} }
fn collect(tree: &mut CTree, mut p: usize, depth: usize, commitments: &[Node]) -> usize { fn collect(
tree: &mut CTree,
mut p: usize,
depth: usize,
commitments: &[Node],
offset: usize,
) -> usize {
// println!("--> {} {} {}", depth, p, offset);
if depth == 0 { if depth == 0 {
if p % 2 == 0 { if p % 2 == 0 {
tree.left = Some(commitments[p]); tree.left = Some(commitments[p - offset]);
} else { } else {
tree.left = Some(commitments[p - 1]); tree.left = Some(commitments[p - 1 - offset]);
tree.right = Some(commitments[p]); tree.right = Some(commitments[p - offset]);
p -= 1; p -= 1;
} }
} else { } else {
// the rest gets combined as a binary tree // the rest gets combined as a binary tree
if p % 2 != 0 { if p % 2 != 0 {
tree.parents.push(Some(commitments[p - 1])); tree.parents.push(Some(commitments[p - 1 - offset]));
} else if p != 0 { } else if p != offset {
tree.parents.push(None); tree.parents.push(None);
} }
} }
@ -133,47 +139,43 @@ fn collect(tree: &mut CTree, mut p: usize, depth: usize, commitments: &[Node]) -
} }
impl NotePosition { impl NotePosition {
fn new(position: usize, count: usize) -> NotePosition { pub fn new(position: usize, count: usize) -> NotePosition {
let is_last = position == count - 1; let c = cursor_start_position(position, count);
let c = if !is_last {
cursor_start_position(position, count)
} else {
0
};
let cursor_length = count - c;
NotePosition { NotePosition {
p: position, p: position,
p2: if cursor_length > 0 { p2: count - 1,
Some(cursor_length - 1)
} else {
None
},
c, c,
witness: Witness::new(), witness: Witness::new(),
is_last,
} }
} }
fn collect(&mut self, depth: usize, commitments: &[Node]) { fn collect(&mut self, depth: usize, commitments: &[Node], offset: usize) {
let count = commitments.len(); let count = commitments.len();
let p = self.p; let p = self.p;
self.p = collect(&mut self.witness.tree, p, depth, commitments); self.p = collect(&mut self.witness.tree, p, depth, commitments, offset);
if !self.is_last { if p % 2 == 0 && p + 1 - offset < commitments.len() {
if p % 2 == 0 && p + 1 < commitments.len() { let filler = commitments[p + 1 - offset];
let filler = commitments[p + 1];
self.witness.filled.push(filler); self.witness.filled.push(filler);
} }
let c = self.c - offset;
let cursor_commitments = &commitments[c..count];
// println!("c> {} {} {}", c, count, depth);
// println!("> {} {}", self.p2, self.c);
if !cursor_commitments.is_empty() {
let p2 = collect(
&mut self.witness.cursor,
self.p2,
depth,
cursor_commitments,
offset + c,
);
self.p2 = (p2 - self.c) / 2 + self.c / 2;
// println!("+ {} {}", self.p2, self.c);
} }
if let Some(ref mut p2) = self.p2 {
if !self.is_last {
let cursor_commitments = &commitments[self.c..count];
*p2 = collect(&mut self.witness.cursor, *p2, depth, cursor_commitments);
}
*p2 /= 2;
}
self.p /= 2; self.p /= 2;
self.c /= 2; self.c /= 2;
} }
@ -200,35 +202,50 @@ fn cursor_start_position(mut position: usize, mut count: usize) -> usize {
} }
impl CTree { impl CTree {
pub fn calc_state(commitments: &mut [Node], positions: &[usize]) -> (CTree, Vec<NotePosition>) { pub fn calc_state(
mut commitments: Vec<Node>,
positions: &mut [NotePosition],
prev_frontier: &CTree,
) -> CTree {
let mut n = commitments.len(); let mut n = commitments.len();
let mut positions: Vec<_> = positions.iter().map(|&p| NotePosition::new(p, n)).collect();
assert_ne!(n, 0); assert_ne!(n, 0);
let prev_count = prev_frontier.get_position();
let last_path = MerklePath::new(prev_frontier.left, prev_frontier.right);
let mut frontier = NotePosition::new(prev_count + n - 1, prev_count + n);
let mut offset = prev_count;
let mut depth = 0usize; let mut depth = 0usize;
let mut frontier = NotePosition::new(n - 1, n);
while n > 0 { while n > 0 {
if offset % 2 == 1 {
// start is not aligned, prepend the last node from the previous run
let node = last_path.get();
commitments.insert(0, node);
n += 1;
offset -= 1;
}
let commitment_slice = &commitments[0..n]; let commitment_slice = &commitments[0..n];
frontier.collect(depth, commitment_slice); frontier.collect(depth, commitment_slice, offset);
for p in positions.iter_mut() { for p in positions.iter_mut() {
p.collect(depth, commitment_slice); p.collect(depth, commitment_slice, offset);
} }
let nn = n / 2; let nn = n / 2;
let next_level: Vec<_> = (0..nn).into_par_iter().map(|i| { let next_level: Vec<_> = (0..nn)
Node::combine(depth, &commitments[2 * i], &commitments[2 * i + 1]) .into_par_iter()
}).collect(); .map(|i| Node::combine(depth, &commitments[2 * i], &commitments[2 * i + 1]))
.collect();
commitments[0..nn].copy_from_slice(&next_level); commitments[0..nn].copy_from_slice(&next_level);
depth += 1; depth += 1;
n = nn; n = nn;
} }
(frontier.witness.tree, positions) frontier.witness.tree
} }
fn new() -> CTree { pub fn new() -> CTree {
CTree { CTree {
left: None, left: None,
right: None, right: None,
@ -243,11 +260,29 @@ impl CTree {
Optional::write(w, e, |w, n| n.write(w)) Optional::write(w, e, |w, n| n.write(w))
}) })
} }
fn get_position(&self) -> usize {
let mut p = 0usize;
for parent in self.parents.iter() {
if parent.is_some() {
p += 1;
}
p *= 2;
}
p *= 2;
if self.left.is_some() {
p += 1;
}
if self.right.is_some() {
p += 1;
}
p
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::commitment::{cursor_start_position, CTree}; use crate::commitment::{cursor_start_position, CTree, NotePosition};
#[allow(unused_imports)] #[allow(unused_imports)]
use crate::print::{print_tree, print_witness}; use crate::print::{print_tree, print_witness};
use std::time::Instant; use std::time::Instant;
@ -259,8 +294,8 @@ mod tests {
*/ */
#[test] #[test]
fn test_calc_witnesses() { fn test_calc_witnesses() {
const NUM_NODES: u32 = 100000; // number of notes const NUM_NODES: u32 = 100; // number of notes
const WITNESS_PERCENT: u32 = 1; // percentage of notes that are ours const WITNESS_PERCENT: u32 = 100; // percentage of notes that are ours
const DEBUG_PRINT: bool = false; const DEBUG_PRINT: bool = false;
let witness_freq = 100 / WITNESS_PERCENT; let witness_freq = 100 / WITNESS_PERCENT;
@ -279,7 +314,8 @@ mod tests {
w.append(node).unwrap(); w.append(node).unwrap();
} }
if i % witness_freq == 0 { // if i % witness_freq == 0 {
if i == 65 {
let w = IncrementalWitness::<Node>::from_tree(&tree1); let w = IncrementalWitness::<Node>::from_tree(&tree1);
witnesses.push(w); witnesses.push(w);
positions.push((i - 1) as usize); positions.push((i - 1) as usize);
@ -289,7 +325,9 @@ mod tests {
} }
let start = Instant::now(); let start = Instant::now();
let (tree2, positions) = CTree::calc_state(&mut nodes, &positions); let n = nodes.len();
let mut positions: Vec<_> = positions.iter().map(|&p| NotePosition::new(p, n)).collect();
let tree2 = CTree::calc_state(nodes, &mut positions, &CTree::new());
eprintln!( eprintln!(
"Update State & Witnesses: {} ms", "Update State & Witnesses: {} ms",
start.elapsed().as_millis() start.elapsed().as_millis()
@ -297,32 +335,33 @@ mod tests {
println!("# witnesses = {}", positions.len()); println!("# witnesses = {}", positions.len());
for (w, p) in witnesses.iter().zip(&positions) { for (i, (w, p)) in witnesses.iter().zip(&positions).enumerate() {
let mut bb1: Vec<u8> = vec![]; let mut bb1: Vec<u8> = vec![];
w.write(&mut bb1).unwrap(); w.write(&mut bb1).unwrap();
let mut bb2: Vec<u8> = vec![]; let mut bb2: Vec<u8> = vec![];
p.witness.write(&mut bb2).unwrap(); p.witness.write(&mut bb2).unwrap();
assert_eq!(bb1.as_slice(), bb2.as_slice()); assert_eq!(bb1.as_slice(), bb2.as_slice(), "failed at {}", i);
} }
if DEBUG_PRINT { if DEBUG_PRINT {
print_witness(&witnesses[0]); let slot = 0usize;
print_witness(&witnesses[slot]);
println!("Tree"); println!("Tree");
let t = &positions[0].witness.tree; let t = &positions[slot].witness.tree;
println!("{:?}", t.left.map(|n| hex::encode(n.repr))); println!("{:?}", t.left.map(|n| hex::encode(n.repr)));
println!("{:?}", t.right.map(|n| hex::encode(n.repr))); println!("{:?}", t.right.map(|n| hex::encode(n.repr)));
for p in t.parents.iter() { for p in t.parents.iter() {
println!("{:?}", p.map(|n| hex::encode(n.repr))); println!("{:?}", p.map(|n| hex::encode(n.repr)));
} }
println!("Filled"); println!("Filled");
for f in positions[0].witness.filled.iter() { for f in positions[slot].witness.filled.iter() {
println!("{:?}", hex::encode(f.repr)); println!("{:?}", hex::encode(f.repr));
} }
println!("Cursor"); println!("Cursor");
let t = &positions[0].witness.cursor; let t = &positions[slot].witness.cursor;
println!("{:?}", t.left.map(|n| hex::encode(n.repr))); println!("{:?}", t.left.map(|n| hex::encode(n.repr)));
println!("{:?}", t.right.map(|n| hex::encode(n.repr))); println!("{:?}", t.right.map(|n| hex::encode(n.repr)));
for p in t.parents.iter() { for p in t.parents.iter() {

View File

@ -7,10 +7,12 @@ pub const NETWORK: Network = Network::MainNetwork;
mod print; mod print;
mod chain; mod chain;
mod path;
mod commitment; mod commitment;
mod scan; mod scan;
pub use crate::chain::{LWD_URL, get_latest_height, download_chain, calculate_tree_state_v2, DecryptNode}; pub use crate::chain::{LWD_URL, get_latest_height, download_chain, calculate_tree_state_v2, DecryptNode};
pub use crate::commitment::NotePosition;
pub use crate::lw_rpc::compact_tx_streamer_client::CompactTxStreamerClient; pub use crate::lw_rpc::compact_tx_streamer_client::CompactTxStreamerClient;
pub use crate::lw_rpc::*; pub use crate::lw_rpc::*;
pub use crate::scan::scan_all; pub use crate::scan::scan_all;