From 290b66d5c8f458c12935e82f195d41e6e48bcbcb Mon Sep 17 00:00:00 2001 From: Kris Nuttycombe Date: Tue, 16 May 2023 12:32:36 -0600 Subject: [PATCH] Add caching of the "cap" to root & witness computation. --- incrementalmerkletree/src/frontier.rs | 2 +- shardtree/src/lib.rs | 375 +++++++++++++++++++++++--- 2 files changed, 342 insertions(+), 35 deletions(-) diff --git a/incrementalmerkletree/src/frontier.rs b/incrementalmerkletree/src/frontier.rs index 973d0a6..94db76b 100644 --- a/incrementalmerkletree/src/frontier.rs +++ b/incrementalmerkletree/src/frontier.rs @@ -358,7 +358,7 @@ impl CommitmentTree { } pub fn leaf(&self) -> Option<&H> { - self.right.as_ref().or_else(|| self.left.as_ref()) + self.right.as_ref().or(self.left.as_ref()) } pub fn ommers_iter(&self) -> Box + '_> { diff --git a/shardtree/src/lib.rs b/shardtree/src/lib.rs index ca063ea..c2fbf8c 100644 --- a/shardtree/src/lib.rs +++ b/shardtree/src/lib.rs @@ -166,7 +166,7 @@ impl Tree { /// Replaces the annotation at the root of the tree, if the root is a `Node::Parent`; otherwise /// returns this tree unaltered. - pub fn reannotate_root(self, ann: A) -> Tree { + pub fn reannotate_root(self, ann: A) -> Self { Tree(self.0.reannotate(ann)) } @@ -1407,15 +1407,13 @@ impl LocatedPrunableTree { // the current address is a left child, so create a parent with // an empty right-hand tree subtree = Tree::parent(None, subtree, Tree::empty()); + } else if let Some(left) = ommers.next() { + // the current address corresponds to a right child, so create a parent that + // takes the left sibling to that child from the ommers + subtree = + Tree::parent(None, Tree::leaf((left, RetentionFlags::EPHEMERAL)), subtree); } else { - if let Some(left) = ommers.next() { - // the current address corresponds to a right child, so create a parent that - // takes the left sibling to that child from the ommers - subtree = - Tree::parent(None, Tree::leaf((left, RetentionFlags::EPHEMERAL)), subtree); - } else { - break; - } + break; } addr = addr.parent(); @@ -1429,7 +1427,7 @@ impl LocatedPrunableTree { let located_supertree = if located_subtree.root_addr().level() == split_at { let mut addr = located_subtree.root_addr(); let mut supertree = None; - while let Some(left) = ommers.next() { + for left in ommers { // build up the left-biased tree until we get a right-hand node while addr.index() & 0x1 == 0 { supertree = supertree.map(|t| Tree::parent(None, t, Tree::empty())); @@ -1500,7 +1498,7 @@ impl LocatedPrunableTree { // add filled nodes to the supertree let supertree = if addr.level() == split_at { let mut supertree = None; - while let Some(right) = filled.next() { + for right in filled { // build up the right-biased tree until we get a left-hand node while addr.index() & 0x1 == 1 { supertree = supertree.map(|t| Tree::parent(None, Tree::empty(), t)); @@ -1936,7 +1934,7 @@ impl MemoryShardStore { } } -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub enum MemoryShardStoreError { Insertion(InsertionError), Query(QueryError), @@ -2250,10 +2248,7 @@ where &mut self, frontier: NonEmptyFrontier, leaf_retention: Retention, - ) -> Result<(), S::Error> - where - S::Error: From, - { + ) -> Result<(), S::Error> { let leaf_position = frontier.position(); let subtree_root_addr = Address::above_position(Self::subtree_level(), leaf_position); @@ -2293,16 +2288,13 @@ where &mut self, witness: IncrementalWitness, checkpoint_id: S::CheckpointId, - ) -> Result<(), S::Error> - where - S::Error: From, - { + ) -> Result<(), S::Error> { let leaf_position = witness.witnessed_position(); let subtree_root_addr = Address::above_position(Self::subtree_level(), leaf_position); let shard = self .store - .get_shard(subtree_root_addr) + .get_shard(subtree_root_addr)? .unwrap_or_else(|| LocatedTree::empty(subtree_root_addr)); let (updated_subtree, supertree, tip_subtree) = @@ -2328,7 +2320,7 @@ where let tip_shard = self .store - .get_shard(tip_subtree_addr) + .get_shard(tip_subtree_addr)? .unwrap_or_else(|| LocatedTree::empty(tip_subtree_addr)); self.store @@ -2570,6 +2562,7 @@ where self.store .truncate(Address::from_parts(Self::subtree_level(), 0))?; self.store.truncate_checkpoints(&checkpoint_id)?; + self.store.put_cap(Tree::empty())?; true } TreeState::AtPosition(position) => { @@ -2579,13 +2572,23 @@ where .get_shard(subtree_addr)? .and_then(|s| s.truncate_to_position(position)); - if let Some(truncated) = replacement { - self.store.truncate(subtree_addr)?; - self.store.put_shard(truncated)?; - self.store.truncate_checkpoints(&checkpoint_id)?; - true - } else { - false + let cap_tree = LocatedTree { + root_addr: Self::root_addr(), + root: self.store.get_cap()?, + }; + + if let Some(truncated) = cap_tree.truncate_to_position(position) { + self.store.put_cap(truncated.root)?; + }; + + match replacement { + Some(truncated) => { + self.store.truncate(subtree_addr)?; + self.store.put_shard(truncated)?; + self.store.truncate_checkpoints(&checkpoint_id)?; + true + } + None => false, } } } @@ -2604,6 +2607,181 @@ where /// /// Use [`Self::root_at_checkpoint`] to obtain the root of the overall tree. pub fn root(&self, address: Address, truncate_at: Position) -> Result { + assert!(Self::root_addr().contains(&address)); + + // traverse the cap from root to leaf depth-first, either returning an existing + // cached value for the node or inserting the computed value into the cache + let (root, _) = self.root_internal( + &LocatedPrunableTree { + root_addr: Self::root_addr(), + root: self.store.get_cap()?, + }, + address, + truncate_at, + )?; + Ok(root) + } + + pub fn root_caching(&mut self, address: Address, truncate_at: Position) -> Result { + let (root, updated_cap) = self.root_internal( + &LocatedPrunableTree { + root_addr: Self::root_addr(), + root: self.store.get_cap()?, + }, + address, + truncate_at, + )?; + if let Some(updated_cap) = updated_cap { + self.store.put_cap(updated_cap)?; + } + Ok(root) + } + + // compute the root, along with an optional update to the cap + fn root_internal( + &self, + cap: &LocatedPrunableTree, + // The address at which we want to compute the root hash + target_addr: Address, + // An inclusive lower bound for positions whose leaf values will be replaced by empty + // roots. + truncate_at: Position, + ) -> Result<(H, Option>), S::Error> { + match &cap.root { + Tree(Node::Parent { ann, left, right }) => { + match ann { + Some(cached_root) if target_addr.contains(&cap.root_addr) => { + Ok((cached_root.as_ref().clone(), None)) + } + _ => { + // Compute the roots of the left and right children and hash them together. + // We skip computation in any subtrees that will not have data included in + // the final result. + let (l_addr, r_addr) = cap.root_addr.children().unwrap(); + let l_result = if r_addr.contains(&target_addr) { + None + } else { + Some(self.root_internal( + &LocatedPrunableTree { + root_addr: l_addr, + root: left.as_ref().clone(), + }, + if l_addr.contains(&target_addr) { + target_addr + } else { + l_addr + }, + truncate_at, + )?) + }; + let r_result = if l_addr.contains(&target_addr) { + None + } else { + Some(self.root_internal( + &LocatedPrunableTree { + root_addr: r_addr, + root: right.as_ref().clone(), + }, + if r_addr.contains(&target_addr) { + target_addr + } else { + r_addr + }, + truncate_at, + )?) + }; + + // Compute the root value based on the child roots; these may contain the + // hashes of empty/truncated nodes. + let (root, new_left, new_right) = match (l_result, r_result) { + (Some((l_root, new_left)), Some((r_root, new_right))) => ( + S::H::combine(l_addr.level(), &l_root, &r_root), + new_left, + new_right, + ), + (Some((l_root, new_left)), None) => (l_root, new_left, None), + (None, Some((r_root, new_right))) => (r_root, None, new_right), + (None, None) => unreachable!(), + }; + + let new_parent = Tree(Node::Parent { + ann: new_left + .as_ref() + .and_then(|l| l.node_value()) + .zip(new_right.as_ref().and_then(|r| r.node_value())) + .map(|(l, r)| { + // the node values of child nodes cannot contain the hashes of + // empty nodes or nodes with positions greater than the + Rc::new(S::H::combine(l_addr.level(), l, r)) + }), + left: new_left.map_or_else(|| left.clone(), Rc::new), + right: new_right.map_or_else(|| right.clone(), Rc::new), + }); + + Ok((root, Some(new_parent))) + } + } + } + Tree(Node::Leaf { value }) => { + if truncate_at >= cap.root_addr.position_range_end() + && target_addr.contains(&cap.root_addr) + { + // no truncation or computation of child subtrees of this leaf is necessary, just use + // the cached leaf value + Ok((value.0.clone(), None)) + } else { + // since the tree was truncated below this level, recursively call with an + // empty parent node to trigger the continued traversal + let (root, replacement) = self.root_internal( + &LocatedPrunableTree { + root_addr: cap.root_addr(), + root: Tree::parent(None, Tree::empty(), Tree::empty()), + }, + target_addr, + truncate_at, + )?; + + Ok(( + root, + replacement.map(|r| r.reannotate_root(Some(Rc::new(value.0.clone())))), + )) + } + } + Tree(Node::Nil) => { + if cap.root_addr == target_addr + || cap.root_addr.level() == ShardTree::::subtree_level() + { + // We are at the leaf level or the target address; compute the root hash and + // return it as cacheable if it is not truncated. + let root = self.root_from_shards(target_addr, truncate_at)?; + Ok(( + root.clone(), + if truncate_at >= cap.root_addr.position_range_end() { + // return the compute root as a new leaf to be cached if it contains no + // empty hashes due to truncation + Some(Tree::leaf((root, RetentionFlags::EPHEMERAL))) + } else { + None + }, + )) + } else { + // Compute the result by recursively walking down the tree. By replacing + // the current node with a parent node, the `Parent` handler will take care + // of the branching recursive calls. + self.root_internal( + &LocatedPrunableTree { + root_addr: cap.root_addr, + root: Tree::parent(None, Tree::empty(), Tree::empty()), + }, + target_addr, + truncate_at, + ) + } + } + } + } + + fn root_from_shards(&self, address: Address, truncate_at: Position) -> Result { match address.context(Self::subtree_level()) { Either::Left(subtree_addr) => { // The requested root address is fully contained within one of the subtrees. @@ -2739,6 +2917,13 @@ where ) } + pub fn root_at_checkpoint_caching(&mut self, checkpoint_depth: usize) -> Result { + self.max_leaf_position(checkpoint_depth)?.map_or_else( + || Ok(H::empty_root(Self::root_addr().level())), + |pos| self.root_caching(Self::root_addr(), pos + 1), + ) + } + /// Computes the witness for the leaf at the specified position. /// /// Returns the witness as of the most recently appended leaf if `checkpoint_depth == 0`. Note @@ -2750,13 +2935,14 @@ where checkpoint_depth: usize, ) -> Result, S::Error> { let max_leaf_position = self.max_leaf_position(checkpoint_depth).and_then(|v| { - v.ok_or_else(|| S::Error::from(QueryError::TreeIncomplete(vec![Self::root_addr()]))) + v.ok_or_else(|| QueryError::TreeIncomplete(vec![Self::root_addr()]).into()) })?; if position > max_leaf_position { - Err(S::Error::from(QueryError::NotContained( - Address::from_parts(Level::from(0), position.into()), - ))) + Err( + QueryError::NotContained(Address::from_parts(Level::from(0), position.into())) + .into(), + ) } else { let subtree_addr = Self::subtree_addr(position); @@ -2778,6 +2964,41 @@ where } } + pub fn witness_caching( + &mut self, + position: Position, + checkpoint_depth: usize, + ) -> Result, S::Error> { + let max_leaf_position = self.max_leaf_position(checkpoint_depth).and_then(|v| { + v.ok_or_else(|| QueryError::TreeIncomplete(vec![Self::root_addr()]).into()) + })?; + + if position > max_leaf_position { + Err( + QueryError::NotContained(Address::from_parts(Level::from(0), position.into())) + .into(), + ) + } else { + let subtree_addr = Address::above_position(Self::subtree_level(), position); + + // compute the witness for the specified position up to the subtree root + let mut witness = self.store.get_shard(subtree_addr)?.map_or_else( + || Err(QueryError::TreeIncomplete(vec![subtree_addr])), + |subtree| subtree.witness(position, max_leaf_position + 1), + )?; + + // compute the remaining parts of the witness up to the root + let root_addr = Self::root_addr(); + let mut cur_addr = subtree_addr; + while cur_addr != root_addr { + witness.push(self.root_caching(cur_addr.sibling(), max_leaf_position + 1)?); + cur_addr = cur_addr.parent(); + } + + Ok(MerklePath::from_parts(witness, position).unwrap()) + } + } + /// Make a marked leaf at a position eligible to be pruned. /// /// If the checkpoint associated with the specified identifier does not exist because the @@ -2844,6 +3065,8 @@ fn accumulate_result_with( pub mod testing { use super::*; use incrementalmerkletree::Hashable; + use proptest::bool::weighted; + use proptest::collection::vec; use proptest::prelude::*; use proptest::sample::select; @@ -2885,11 +3108,70 @@ pub mod testing { }) }) } + + /// Constructs a random shardtree of size up to 2^6 with shards of size 2^3. Returns the tree, + /// along with vectors of the checkpoint and mark positions. + pub fn arb_shardtree( + arb_leaf: H, + ) -> impl Strategy< + Value = ( + ShardTree, 6, 3>, + Vec, + Vec, + ), + > + where + H::Value: Hashable + Clone + PartialEq, + { + vec( + (arb_leaf, weighted(0.1), weighted(0.2)), + 0..=(2usize.pow(6)), + ) + .prop_map(|leaves| { + let mut tree = ShardTree::empty(MemoryShardStore::empty(), 10); + let mut checkpoint_positions = vec![]; + let mut marked_positions = vec![]; + tree.batch_insert( + Position::from(0), + leaves + .into_iter() + .enumerate() + .map(|(id, (leaf, is_marked, is_checkpoint))| { + ( + leaf, + match (is_checkpoint, is_marked) { + (false, false) => Retention::Ephemeral, + (true, is_marked) => { + let pos = Position::try_from(id).unwrap(); + checkpoint_positions.push(pos); + if is_marked { + marked_positions.push(pos); + } + Retention::Checkpoint { id, is_marked } + } + (false, true) => { + marked_positions.push(Position::try_from(id).unwrap()); + Retention::Marked + } + }, + ) + }), + ) + .unwrap(); + (tree, checkpoint_positions, marked_positions) + }) + } + + pub fn arb_char_str() -> impl Strategy { + let chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + (0usize..chars.len()).prop_map(move |i| chars.get(i..=i).unwrap().to_string()) + } } #[cfg(test)] mod tests { use crate::{ + testing::{arb_char_str, arb_shardtree}, IncompleteAt, InsertionError, LocatedPrunableTree, LocatedTree, MemoryShardStore, MemoryShardStoreError, Node, PrunableTree, QueryError, RetentionFlags, ShardStore, ShardTree, Tree, @@ -3460,6 +3742,30 @@ mod tests { } } + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + #[test] + fn check_shardtree_caching( + (mut tree, _, marked_positions) in arb_shardtree(arb_char_str()) + ) { + if let Some(max_leaf_pos) = tree.max_leaf_position(0).unwrap() { + let max_complete_addr = Address::above_position(max_leaf_pos.root_level(), max_leaf_pos); + let root = tree.root(max_complete_addr, max_leaf_pos + 1); + let caching_root = tree.root_caching(max_complete_addr, max_leaf_pos + 1); + assert_matches!(root, Ok(_)); + assert_eq!(root, caching_root); + + for pos in marked_positions { + let witness = tree.witness(pos, 0); + let caching_witness = tree.witness_caching(pos, 0); + assert_matches!(witness, Ok(_)); + assert_eq!(witness, caching_witness); + } + } + } + } + #[test] fn insert_frontier_nodes() { let mut frontier = NonEmptyFrontier::new("a".to_string()); @@ -3539,7 +3845,8 @@ mod tests { ); assert_eq!( - r.root.root_hash(Address::from_parts(Level::from(3), 3), Position::from(25)), + r.root + .root_hash(Address::from_parts(Level::from(3), 3), Position::from(25)), Ok("y_______".to_string()) ); }