diff --git a/src/chain.rs b/src/chain.rs index cf50a85..9650ad8 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -218,7 +218,7 @@ pub fn calculate_tree_state_v2(cbs: &[CompactBlock], blocks: &[DecryptedBlock]) let start = Instant::now(); let n = nodes.len(); let mut positions: Vec<_> = positions.iter().map(|&p| NotePosition::new(p, n)).collect(); - let _frontier = CTree::calc_state(nodes, &mut positions, &CTree::new()); + let _frontier = CTree::calc_state(nodes, &mut positions, None); let witnesses: Vec<_> = positions.iter().map(|p| p.witness.clone()).collect(); info!("Tree State & Witnesses: {} ms", start.elapsed().as_millis()); witnesses diff --git a/src/commitment.rs b/src/commitment.rs index 6576705..924dcfe 100644 --- a/src/commitment.rs +++ b/src/commitment.rs @@ -105,6 +105,7 @@ impl Witness { } pub struct NotePosition { + p0: usize, p: usize, p2: usize, c: usize, @@ -119,6 +120,7 @@ fn collect( offset: usize, ) -> usize { // println!("--> {} {} {}", depth, p, offset); + if p < offset { return p } if depth == 0 { if p % 2 == 0 { tree.left = Some(commitments[p - offset]); @@ -131,7 +133,7 @@ fn collect( // the rest gets combined as a binary tree if p % 2 != 0 { tree.parents.push(Some(commitments[p - 1 - offset])); - } else if p != offset { + } else if !(p == 0 && offset == 0) { tree.parents.push(None); } } @@ -142,6 +144,7 @@ impl NotePosition { pub fn new(position: usize, count: usize) -> NotePosition { let c = cursor_start_position(position, count); NotePosition { + p0: position, p: position, p2: count - 1, c, @@ -149,13 +152,20 @@ impl NotePosition { } } + pub fn reset(&mut self, count: usize) { + let c = cursor_start_position(self.p0, count); + self.p = self.p0; + self.p2 = count - 1; + self.c = c; + } + fn collect(&mut self, depth: usize, commitments: &[Node], offset: usize) { let count = commitments.len(); let p = self.p; self.p = collect(&mut self.witness.tree, p, depth, commitments, offset); - if p % 2 == 0 && p + 1 - offset < commitments.len() { + if p % 2 == 0 && p + 1 >= offset && p + 1 - offset < commitments.len() { let filler = commitments[p + 1 - offset]; self.witness.filled.push(filler); } @@ -205,21 +215,31 @@ impl CTree { pub fn calc_state( mut commitments: Vec, positions: &mut [NotePosition], - prev_frontier: &CTree, + prev_frontier: Option, ) -> CTree { let mut n = commitments.len(); assert_ne!(n, 0); - let prev_count = prev_frontier.get_position(); - let last_path = MerklePath::new(prev_frontier.left, prev_frontier.right); - let mut frontier = NotePosition::new(prev_count + n - 1, prev_count + n); + let prev_count = prev_frontier.as_ref().map(|f| f.get_position()).unwrap_or(0); + let count = prev_count + n; + let mut last_path = prev_frontier.as_ref().map(|f| MerklePath::new(f.left, f.right)); + let mut frontier = NotePosition::new(count - 1, count); let mut offset = prev_count; + for p in positions.iter_mut() { + p.reset(count); + } + let mut depth = 0usize; - while n > 0 { + while n + offset > 0 { if offset % 2 == 1 { - // start is not aligned, prepend the last node from the previous run - let node = last_path.get(); + // start is not aligned + let mut lp = last_path.take().unwrap(); + let node = lp.get(); // prepend the last node from the previous run + if n > 0 { + lp.set(commitments[0]); // put the right node into the path + } + last_path = Some(lp); commitments.insert(0, node); n += 1; offset -= 1; @@ -238,8 +258,14 @@ impl CTree { .collect(); commitments[0..nn].copy_from_slice(&next_level); + if let Some(mut lp) = last_path.take() { + lp.up(depth, prev_frontier.as_ref().unwrap().parents.get(depth).unwrap_or(&None)); + last_path = Some(lp); + } + depth += 1; n = nn; + offset /= 2; } frontier.witness.tree @@ -263,13 +289,12 @@ impl CTree { fn get_position(&self) -> usize { let mut p = 0usize; - for parent in self.parents.iter() { + for parent in self.parents.iter().rev() { if parent.is_some() { p += 1; } p *= 2; } - p *= 2; if self.left.is_some() { p += 1; } @@ -294,48 +319,55 @@ mod tests { */ #[test] fn test_calc_witnesses() { - const NUM_NODES: u32 = 100; // number of notes - const WITNESS_PERCENT: u32 = 100; // percentage of notes that are ours - const DEBUG_PRINT: bool = false; + const NUM_CHUNKS: usize = 20; + const NUM_NODES: usize = 200; // number of notes + const WITNESS_PERCENT: usize = 1; // percentage of notes that are ours + const DEBUG_PRINT: bool = true; - let witness_freq = 100 / WITNESS_PERCENT; + let witness_freq = 100000 / WITNESS_PERCENT; let mut tree1: CommitmentTree = CommitmentTree::empty(); - let mut nodes: Vec = vec![]; + let mut tree2: Option = None; let mut witnesses: Vec> = vec![]; - let mut positions: Vec = vec![]; - for i in 1..=NUM_NODES { - let mut bb = [0u8; 32]; - bb[0..4].copy_from_slice(&i.to_be_bytes()); - let node = Node::new(bb); + let mut all_positions: Vec = vec![]; - tree1.append(node).unwrap(); + for c in 0..NUM_CHUNKS { + let mut positions: Vec = vec![]; + let mut nodes: Vec = vec![]; + for i in 1..=NUM_NODES { + let mut bb = [0u8; 32]; + bb[0..8].copy_from_slice(&i.to_be_bytes()); + let node = Node::new(bb); - for w in witnesses.iter_mut() { - w.append(node).unwrap(); + tree1.append(node).unwrap(); + + for w in witnesses.iter_mut() { + w.append(node).unwrap(); + } + + if i % witness_freq == 0 { + let w = IncrementalWitness::::from_tree(&tree1); + witnesses.push(w); + positions.push((i - 1 + c * NUM_NODES) as usize); + } + + nodes.push(node); } - // if i % witness_freq == 0 { - if i == 65 { - let w = IncrementalWitness::::from_tree(&tree1); - witnesses.push(w); - positions.push((i - 1) as usize); - } - - nodes.push(node); + let start = Instant::now(); + let n = nodes.len(); + let mut positions: Vec<_> = positions.iter().map(|&p| NotePosition::new(p, n + c*NUM_NODES)).collect(); + all_positions.append(&mut positions); + tree2 = Some(CTree::calc_state(nodes, &mut all_positions, tree2)); + eprintln!( + "Update State & Witnesses: {} ms", + start.elapsed().as_millis() + ); } + let tree2 = tree2.unwrap(); - let start = Instant::now(); - let n = nodes.len(); - let mut positions: Vec<_> = positions.iter().map(|&p| NotePosition::new(p, n)).collect(); - let tree2 = CTree::calc_state(nodes, &mut positions, &CTree::new()); - eprintln!( - "Update State & Witnesses: {} ms", - start.elapsed().as_millis() - ); + println!("# witnesses = {}", all_positions.len()); - println!("# witnesses = {}", positions.len()); - - for (i, (w, p)) in witnesses.iter().zip(&positions).enumerate() { + for (i, (w, p)) in witnesses.iter().zip(&all_positions).enumerate() { let mut bb1: Vec = vec![]; w.write(&mut bb1).unwrap(); @@ -345,29 +377,37 @@ mod tests { assert_eq!(bb1.as_slice(), bb2.as_slice(), "failed at {}", i); } - if DEBUG_PRINT { - let slot = 0usize; - print_witness(&witnesses[slot]); + let mut bb1: Vec = vec![]; + tree1.write(&mut bb1).unwrap(); - println!("Tree"); - let t = &positions[slot].witness.tree; - println!("{:?}", t.left.map(|n| hex::encode(n.repr))); - println!("{:?}", t.right.map(|n| hex::encode(n.repr))); - for p in t.parents.iter() { - println!("{:?}", p.map(|n| hex::encode(n.repr))); - } - println!("Filled"); - for f in positions[slot].witness.filled.iter() { - println!("{:?}", hex::encode(f.repr)); - } - println!("Cursor"); - let t = &positions[slot].witness.cursor; - println!("{:?}", t.left.map(|n| hex::encode(n.repr))); - println!("{:?}", t.right.map(|n| hex::encode(n.repr))); - for p in t.parents.iter() { - println!("{:?}", p.map(|n| hex::encode(n.repr))); - } - println!("===="); + let mut bb2: Vec = vec![]; + tree2.write(&mut bb2).unwrap(); + + assert_eq!(bb1.as_slice(), bb2.as_slice(), "tree states not equal"); + + if DEBUG_PRINT { + // let slot = 0usize; + // print_witness(&witnesses[slot]); + // + // println!("Tree"); + // let t = &all_positions[slot].witness.tree; + // println!("{:?}", t.left.map(|n| hex::encode(n.repr))); + // println!("{:?}", t.right.map(|n| hex::encode(n.repr))); + // for p in t.parents.iter() { + // println!("{:?}", p.map(|n| hex::encode(n.repr))); + // } + // println!("Filled"); + // for f in all_positions[slot].witness.filled.iter() { + // println!("{:?}", hex::encode(f.repr)); + // } + // println!("Cursor"); + // let t = &all_positions[slot].witness.cursor; + // println!("{:?}", t.left.map(|n| hex::encode(n.repr))); + // println!("{:?}", t.right.map(|n| hex::encode(n.repr))); + // for p in t.parents.iter() { + // println!("{:?}", p.map(|n| hex::encode(n.repr))); + // } + // println!("===="); println!("{:?}", tree1.left.map(|n| hex::encode(n.repr))); println!("{:?}", tree1.right.map(|n| hex::encode(n.repr)));