//! Sample implementation of the Tree interface. use std::collections::BTreeSet; use super::{Frontier, Tree}; use crate::{ hashing::Hashable, position::{Level, Position}, }; #[derive(Clone, Debug)] pub struct TreeState { leaves: Vec, current_offset: usize, marks: BTreeSet, depth: usize, } impl TreeState { /// Creates a new, empty binary tree of specified depth. #[cfg(test)] pub fn new(depth: usize) -> Self { Self { leaves: vec![H::empty_leaf(); 1 << depth], current_offset: 0, marks: BTreeSet::new(), depth, } } } impl Frontier for TreeState { fn append(&mut self, value: &H) -> bool { if self.current_offset == (1 << self.depth) { false } else { self.leaves[self.current_offset] = value.clone(); self.current_offset += 1; true } } /// Obtains the current root of this Merkle tree. fn root(&self) -> H { lazy_root(self.leaves.clone()) } } impl TreeState { fn current_position(&self) -> Option { if self.current_offset == 0 { None } else { Some((self.current_offset - 1).into()) } } /// Returns the leaf most recently appended to the tree fn current_leaf(&self) -> Option<&H> { self.current_position() .map(|p| &self.leaves[::from(p)]) } /// Returns the leaf at the specified position if the tree can produce /// a witness for it. fn get_marked_leaf(&self, position: Position) -> Option<&H> { if self.marks.contains(&position) { self.leaves.get(::from(position)) } else { None } } /// Marks the current tree state leaf as a value that we're interested in /// marking. Returns the current position if the tree is non-empty. fn mark(&mut self) -> Option { self.current_position().map(|pos| { self.marks.insert(pos); pos }) } /// Obtains a witness to the value at the specified position. /// Returns `None` if there is no available witness to that /// value. fn witness(&self, position: Position) -> Option> { if Some(position) <= self.current_position() { let mut path = vec![]; let mut leaf_idx: usize = position.into(); for bit in 0..self.depth { leaf_idx ^= 1 << bit; path.push(lazy_root::( self.leaves[leaf_idx..][0..(1 << bit)].to_vec(), )); leaf_idx &= usize::MAX << (bit + 1); } Some(path) } else { None } } /// Marks the value at the specified position as a value we're no longer /// interested in maintaining a mark for. Returns true if successful and /// false if we were already not maintaining a mark at this position. fn remove_mark(&mut self, position: Position) -> bool { self.marks.remove(&position) } } #[derive(Clone, Debug)] pub struct CompleteTree { tree_state: TreeState, checkpoints: Vec>, max_checkpoints: usize, } impl CompleteTree { /// Creates a new, empty binary tree of specified depth. #[cfg(test)] pub fn new(depth: usize, max_checkpoints: usize) -> Self { Self { tree_state: TreeState::new(depth), checkpoints: vec![], max_checkpoints, } } } impl CompleteTree { /// Removes the oldest checkpoint. Returns true if successful and false if /// there are no checkpoints. fn drop_oldest_checkpoint(&mut self) -> bool { if self.checkpoints.is_empty() { false } else { self.checkpoints.remove(0); true } } /// Retrieve the tree state at the specified checkpoint depth. This /// is the current tree state if the depth is 0, and this will return /// None if not enough checkpoints exist to obtain the requested depth. fn tree_state_at_checkpoint_depth(&self, checkpoint_depth: usize) -> Option<&TreeState> { if self.checkpoints.len() < checkpoint_depth { None } else if checkpoint_depth == 0 { Some(&self.tree_state) } else { self.checkpoints .get(self.checkpoints.len() - checkpoint_depth) } } } impl Tree for CompleteTree { /// Appends a new value to the tree at the next available slot. Returns true /// if successful and false if the tree is full. fn append(&mut self, value: &H) -> bool { self.tree_state.append(value) } /// Returns the most recently appended leaf value. fn current_position(&self) -> Option { self.tree_state.current_position() } fn current_leaf(&self) -> Option<&H> { self.tree_state.current_leaf() } fn get_marked_leaf(&self, position: Position) -> Option<&H> { self.tree_state.get_marked_leaf(position) } fn mark(&mut self) -> Option { self.tree_state.mark() } fn marked_positions(&self) -> BTreeSet { self.tree_state.marks.clone() } fn root(&self, checkpoint_depth: usize) -> Option { self.tree_state_at_checkpoint_depth(checkpoint_depth) .map(|s| s.root()) } fn witness(&self, position: Position, root: &H) -> Option> { // Search for the checkpointed state corresponding to the provided root, and if one is // found, compute the witness as of that root. self.checkpoints .iter() .chain(Some(&self.tree_state)) .rev() .skip_while(|c| !c.marks.contains(&position)) .find_map(|c| { if &c.root() == root { c.witness(position) } else { None } }) } fn remove_mark(&mut self, position: Position) -> bool { self.tree_state.remove_mark(position) } fn checkpoint(&mut self) { self.checkpoints.push(self.tree_state.clone()); if self.checkpoints.len() > self.max_checkpoints { self.drop_oldest_checkpoint(); } } fn rewind(&mut self) -> bool { if let Some(checkpointed_state) = self.checkpoints.pop() { self.tree_state = checkpointed_state; true } else { false } } } pub(crate) fn lazy_root(mut leaves: Vec) -> H { //leaves are always at level zero, so we start there. let mut level = Level::from(0); while leaves.len() != 1 { leaves = leaves .iter() .enumerate() .filter(|(i, _)| (i % 2) == 0) .map(|(_, a)| a) .zip( leaves .iter() .enumerate() .filter(|(i, _)| (i % 2) == 1) .map(|(_, b)| b), ) .map(|(a, b)| H::combine(level, a, b)) .collect(); level = level + 1; } leaves[0].clone() } #[cfg(test)] mod tests { use std::convert::TryFrom; use super::CompleteTree; use crate::{ hashing::Hashable, position::{Level, Position}, testing::{ tests::{self, compute_root_from_witness}, SipHashable, Tree, }, }; #[test] fn correct_empty_root() { const DEPTH: u8 = 5; let mut expected = SipHashable(0u64); for lvl in 0u8..DEPTH { expected = SipHashable::combine(lvl.into(), &expected, &expected); } let tree = CompleteTree::::new(DEPTH as usize, 100); assert_eq!(tree.root(0).unwrap(), expected); } #[test] fn correct_root() { const DEPTH: usize = 3; let values = (0..(1 << DEPTH)).into_iter().map(SipHashable); let mut tree = CompleteTree::::new(DEPTH, 100); for value in values { assert!(tree.append(&value)); } assert!(!tree.append(&SipHashable(0))); let expected = SipHashable::combine( Level::from(2), &SipHashable::combine( Level::from(1), &SipHashable::combine(Level::from(1), &SipHashable(0), &SipHashable(1)), &SipHashable::combine(Level::from(1), &SipHashable(2), &SipHashable(3)), ), &SipHashable::combine( Level::from(1), &SipHashable::combine(Level::from(1), &SipHashable(4), &SipHashable(5)), &SipHashable::combine(Level::from(1), &SipHashable(6), &SipHashable(7)), ), ); assert_eq!(tree.root(0).unwrap(), expected); } #[test] fn root_hashes() { tests::check_root_hashes(|max_c| CompleteTree::::new(4, max_c)); } #[test] fn witnesss() { tests::check_witnesss(|max_c| CompleteTree::::new(4, max_c)); } #[test] fn correct_witness() { const DEPTH: usize = 3; let values = (0..(1 << DEPTH)).into_iter().map(SipHashable); let mut tree = CompleteTree::::new(DEPTH, 100); for value in values { assert!(tree.append(&value)); tree.mark(); } assert!(!tree.append(&SipHashable(0))); let expected = SipHashable::combine( ::from(2), &SipHashable::combine( Level::from(1), &SipHashable::combine(Level::from(1), &SipHashable(0), &SipHashable(1)), &SipHashable::combine(Level::from(1), &SipHashable(2), &SipHashable(3)), ), &SipHashable::combine( Level::from(1), &SipHashable::combine(Level::from(1), &SipHashable(4), &SipHashable(5)), &SipHashable::combine(Level::from(1), &SipHashable(6), &SipHashable(7)), ), ); assert_eq!(tree.root(0).unwrap(), expected); for i in 0u64..(1 << DEPTH) { let position = Position::try_from(i).unwrap(); let path = tree.witness(position, &tree.root(0).unwrap()).unwrap(); assert_eq!( compute_root_from_witness(SipHashable(i), position, &path), expected ); } } #[test] fn checkpoint_rewind() { tests::check_checkpoint_rewind(|max_c| CompleteTree::::new(4, max_c)); } #[test] fn rewind_remove_mark() { tests::check_rewind_remove_mark(|max_c| CompleteTree::::new(4, max_c)); } }