use zcash_primitives::merkle_tree::Hashable; use zcash_primitives::sapling::Node; use std::io::Write; use zcash_primitives::serialize::{Optional, Vector}; use byteorder::WriteBytesExt; use rayon::prelude::*; #[derive(Clone)] pub struct CTree { left: Option, right: Option, parents: Vec>, } #[derive(Clone)] pub struct Witness { tree: CTree, // commitment tree at the moment the witness is created: immutable filled: Vec, // as more nodes are added, levels get filled up: won't change anymore cursor: CTree, // partial tree which still updates when nodes are added } impl Witness { pub fn new() -> Witness { Witness { tree: CTree::new(), filled: vec![], cursor: CTree::new(), } } pub fn write(&self, mut writer: W) -> std::io::Result<()> { self.tree.write(&mut writer)?; Vector::write(&mut writer, &self.filled, |w, n| n.write(w))?; if self.cursor.left == None && self.cursor.right == None { writer.write_u8(0)?; } else { writer.write_u8(1)?; self.cursor.write(writer)?; }; Ok(()) } } pub struct NotePosition { p: usize, p2: Option, c: usize, pub witness: Witness, is_last: bool, } fn collect(tree: &mut CTree, mut p: usize, depth: usize, commitments: &[Node]) -> usize { if depth == 0 { if p % 2 == 0 { tree.left = Some(commitments[p]); } else { tree.left = Some(commitments[p - 1]); tree.right = Some(commitments[p]); p -= 1; } } else { // the rest gets combined as a binary tree if p % 2 != 0 { tree.parents.push(Some(commitments[p - 1])); } else if p != 0 { tree.parents.push(None); } } p } impl NotePosition { fn new(position: usize, count: usize) -> NotePosition { let is_last = position == count - 1; let c = if !is_last { cursor_start_position(position, count) } else { 0 }; let cursor_length = count - c; NotePosition { p: position, p2: if cursor_length > 0 { Some(cursor_length - 1) } else { None }, c, witness: Witness::new(), is_last, } } fn collect(&mut self, depth: usize, commitments: &[Node]) { let count = commitments.len(); let p = self.p; self.p = collect(&mut self.witness.tree, p, depth, commitments); if !self.is_last { if p % 2 == 0 && p + 1 < commitments.len() { let filler = commitments[p + 1]; self.witness.filled.push(filler); } } if let Some(ref mut p2) = self.p2 { if !self.is_last { let cursor_commitments = &commitments[self.c..count]; *p2 = collect(&mut self.witness.cursor, *p2, depth, cursor_commitments); } *p2 /= 2; } self.p /= 2; self.c /= 2; } } fn cursor_start_position(mut position: usize, mut count: usize) -> usize { assert!(position < count); // same logic as filler let mut depth = 0; loop { if position % 2 == 0 { if position + 1 < count { position += 1; } else { break; } } position /= 2; count /= 2; depth += 1; } (position + 1) << depth } impl CTree { pub fn calc_state(commitments: &mut [Node], positions: &[usize]) -> (CTree, Vec) { let mut n = commitments.len(); let mut positions: Vec<_> = positions.iter().map(|&p| NotePosition::new(p, n)).collect(); assert_ne!(n, 0); let mut depth = 0usize; let mut frontier = NotePosition::new(n - 1, n); while n > 0 { let commitment_slice = &commitments[0..n]; frontier.collect(depth, commitment_slice); for p in positions.iter_mut() { p.collect(depth, commitment_slice); } let nn = n / 2; let next_level: Vec<_> = (0..nn).into_par_iter().map(|i| { Node::combine(depth, &commitments[2 * i], &commitments[2 * i + 1]) }).collect(); commitments[0..nn].copy_from_slice(&next_level); depth += 1; n = nn; } (frontier.witness.tree, positions) } fn new() -> CTree { CTree { left: None, right: None, parents: vec![], } } fn write(&self, mut writer: W) -> std::io::Result<()> { Optional::write(&mut writer, &self.left, |w, n| n.write(w))?; Optional::write(&mut writer, &self.right, |w, n| n.write(w))?; Vector::write(&mut writer, &self.parents, |w, e| { Optional::write(w, e, |w, n| n.write(w)) }) } } #[cfg(test)] mod tests { use crate::commitment::{cursor_start_position, CTree}; #[allow(unused_imports)] use crate::print::{print_tree, print_witness}; use std::time::Instant; use zcash_primitives::merkle_tree::{CommitmentTree, IncrementalWitness}; use zcash_primitives::sapling::Node; /* Build incremental witnesses with both methods and compare their binary serialization */ #[test] fn test_calc_witnesses() { const NUM_NODES: u32 = 100000; // number of notes const WITNESS_PERCENT: u32 = 1; // percentage of notes that are ours const DEBUG_PRINT: bool = false; let witness_freq = 100 / WITNESS_PERCENT; let mut tree1: CommitmentTree = CommitmentTree::empty(); let mut nodes: Vec = vec![]; let mut witnesses: Vec> = vec![]; let mut positions: Vec = vec![]; for i in 1..=NUM_NODES { let mut bb = [0u8; 32]; bb[0..4].copy_from_slice(&i.to_be_bytes()); let node = Node::new(bb); tree1.append(node).unwrap(); for w in witnesses.iter_mut() { w.append(node).unwrap(); } if i % witness_freq == 0 { let w = IncrementalWitness::::from_tree(&tree1); witnesses.push(w); positions.push((i - 1) as usize); } nodes.push(node); } let start = Instant::now(); let (tree2, positions) = CTree::calc_state(&mut nodes, &positions); eprintln!( "Update State & Witnesses: {} ms", start.elapsed().as_millis() ); println!("# witnesses = {}", positions.len()); for (w, p) in witnesses.iter().zip(&positions) { let mut bb1: Vec = vec![]; w.write(&mut bb1).unwrap(); let mut bb2: Vec = vec![]; p.witness.write(&mut bb2).unwrap(); assert_eq!(bb1.as_slice(), bb2.as_slice()); } if DEBUG_PRINT { print_witness(&witnesses[0]); println!("Tree"); let t = &positions[0].witness.tree; println!("{:?}", t.left.map(|n| hex::encode(n.repr))); println!("{:?}", t.right.map(|n| hex::encode(n.repr))); for p in t.parents.iter() { println!("{:?}", p.map(|n| hex::encode(n.repr))); } println!("Filled"); for f in positions[0].witness.filled.iter() { println!("{:?}", hex::encode(f.repr)); } println!("Cursor"); let t = &positions[0].witness.cursor; println!("{:?}", t.left.map(|n| hex::encode(n.repr))); println!("{:?}", t.right.map(|n| hex::encode(n.repr))); for p in t.parents.iter() { println!("{:?}", p.map(|n| hex::encode(n.repr))); } println!("{:?}", tree1.left.map(|n| hex::encode(n.repr))); println!("{:?}", tree1.right.map(|n| hex::encode(n.repr))); for p in tree1.parents.iter() { println!("{:?}", p.map(|n| hex::encode(n.repr))); } println!("-----"); println!("{:?}", tree2.left.map(|n| hex::encode(n.repr))); println!("{:?}", tree2.right.map(|n| hex::encode(n.repr))); for p in tree2.parents.iter() { println!("{:?}", p.map(|n| hex::encode(n.repr))); } } } #[test] fn test_cursor() { // println!("{}", cursor_start_position(8, 14)); println!("{}", cursor_start_position(9, 14)); // println!("{}", cursor_start_position(10, 14)); } }