diff --git a/shardtree/src/lib.rs b/shardtree/src/lib.rs index f4811bf..ab7e73f 100644 --- a/shardtree/src/lib.rs +++ b/shardtree/src/lib.rs @@ -1713,6 +1713,98 @@ impl< Ok(()) } + /// Adds a checkpoint at the rightmost leaf state of the tree. + pub fn checkpoint(&mut self, checkpoint_id: C) -> bool { + fn go( + root_addr: Address, + root: &PrunableTree, + ) -> Option<(PrunableTree, Position)> { + match root { + Tree(Node::Parent { ann, left, right }) => { + let (l_addr, r_addr) = root_addr.children().unwrap(); + go(r_addr, right).map_or_else( + || { + go(l_addr, left).map(|(new_left, pos)| { + ( + Tree::unite( + l_addr.level(), + ann.clone(), + new_left, + right.as_ref().clone(), + ), + pos, + ) + }) + }, + |(new_right, pos)| { + Some(( + Tree::unite( + l_addr.level(), + ann.clone(), + left.as_ref().clone(), + new_right, + ), + pos, + )) + }, + ) + } + Tree(Node::Leaf { value: (h, r) }) => Some(( + Tree(Node::Leaf { + value: (h.clone(), *r | CHECKPOINT), + }), + root_addr.max_position(), + )), + Tree(Node::Nil) => None, + } + } + + // checkpoint identifiers at the tip must be in increasing order + if self.checkpoints.keys().last() >= Some(&checkpoint_id) { + return false; + } + + // Search backward from the end of the subtrees iter to find a non-empty subtree. + // When we find one, update the subtree to add the `CHECKPOINT` flag to the + // right-most leaf (which need not be a level-0 leaf; it's fine to rewind to a + // pruned state). + for subtree_addr in self.store.get_shard_roots().iter().rev() { + let subtree = self.store.get_shard(*subtree_addr).expect( + "The store should not return root addresses for subtrees it cannot provide.", + ); + if let Some((replacement, checkpoint_position)) = go(*subtree_addr, &subtree.root) { + if self + .store + .put_shard(LocatedTree { + root_addr: *subtree_addr, + root: replacement, + }) + .is_err() + { + return false; + } + self.checkpoints + .insert(checkpoint_id, Checkpoint::at_position(checkpoint_position)); + + // early return once we've updated the tree state + return self + .prune_excess_checkpoints() + .map_err(InsertionError::Storage) + .is_ok(); + } + } + + self.checkpoints + .insert(checkpoint_id, Checkpoint::tree_empty()); + + // TODO: it should not be necessary to do this on every checkpoint, + // but currently that's how the reference tree behaves so we're maintaining + // those semantics for test compatibility. + self.prune_excess_checkpoints() + .map_err(InsertionError::Storage) + .is_ok() + } + fn prune_excess_checkpoints(&mut self) -> Result<(), S::Error> { if self.checkpoints.len() > self.max_checkpoints { // Batch removals by subtree & create a list of the checkpoint identifiers that @@ -1776,6 +1868,115 @@ impl< Ok(()) } + /// Computes the root of any subtree of this tree rooted at the given address, with the overall + /// tree truncated to the specified position. + /// + /// The specified address is not required to be at any particular level, though it cannot + /// exceed the level corresponding to the maximum depth of the tree. Nodes to the right of the + /// given position, and parents of such nodes, will be replaced by the empty root for the + /// associated level. + /// + /// Use [`Self::root_at_checkpoint`] to obtain the root of the overall tree. + pub fn root(&self, address: Address, truncate_at: Position) -> Result { + match address.context(Self::subtree_level()) { + Either::Left(subtree_addr) => { + // The requested root address is fully contained within one of the subtrees. + if truncate_at <= address.position_range_start() { + Ok(H::empty_root(address.level())) + } else { + // get the child of the subtree with its root at `address` + self.store + .get_shard(subtree_addr) + .ok_or_else(|| vec![subtree_addr]) + .and_then(|subtree| { + subtree.subtree(address).map_or_else( + || Err(vec![address]), + |child| child.root_hash(truncate_at), + ) + }) + .map_err(QueryError::TreeIncomplete) + } + } + Either::Right(subtree_range) => { + // The requested root requires hashing together the roots of several subtrees. + let mut root_stack = vec![]; + let mut incomplete = vec![]; + + for subtree_idx in subtree_range { + let subtree_addr = Address::from_parts(Self::subtree_level(), subtree_idx); + if truncate_at <= subtree_addr.position_range_start() { + break; + } + + let subtree_root = self + .store + .get_shard(subtree_addr) + .ok_or_else(|| vec![subtree_addr]) + .and_then(|s| s.root_hash(truncate_at)); + + match subtree_root { + Ok(mut cur_hash) => { + if subtree_addr.index() % 2 == 0 { + root_stack.push((subtree_addr, cur_hash)) + } else { + let mut cur_addr = subtree_addr; + while let Some((addr, hash)) = root_stack.pop() { + if addr.parent() == cur_addr.parent() { + cur_hash = H::combine(cur_addr.level(), &hash, &cur_hash); + cur_addr = cur_addr.parent(); + } else { + root_stack.push((addr, hash)); + break; + } + } + root_stack.push((cur_addr, cur_hash)); + } + } + Err(mut new_incomplete) => { + // Accumulate incomplete root information and continue, so that we can + // return the complete set of incomplete results. + incomplete.append(&mut new_incomplete); + } + } + } + + if !incomplete.is_empty() { + return Err(QueryError::TreeIncomplete(incomplete)); + } + + // Now hash with empty roots to obtain the root at maximum height + if let Some((mut cur_addr, mut cur_hash)) = root_stack.pop() { + while let Some((addr, hash)) = root_stack.pop() { + while addr.level() > cur_addr.level() { + cur_hash = H::combine( + cur_addr.level(), + &cur_hash, + &H::empty_root(cur_addr.level()), + ); + cur_addr = cur_addr.parent(); + } + cur_hash = H::combine(cur_addr.level(), &hash, &cur_hash); + cur_addr = cur_addr.parent(); + } + + while cur_addr.level() < address.level() { + cur_hash = H::combine( + cur_addr.level(), + &cur_hash, + &H::empty_root(cur_addr.level()), + ); + cur_addr = cur_addr.parent(); + } + + Ok(cur_hash) + } else { + // if the stack is empty, we just return the default root at max height + Ok(H::empty_root(address.level())) + } + } + } + } + /// Returns the position of the checkpoint, if any, along with the number of subsequent /// checkpoints at the same position. Returns `None` if `checkpoint_depth == 0` or if /// insufficient checkpoints exist to seek back to the requested depth. @@ -1813,6 +2014,18 @@ impl< } } } + + /// Computes the root of the tree as of the checkpointed position at the specified depth. + /// + /// Returns the root as of the most recently appended leaf if `checkpoint_depth == 0`. Note + /// that if the most recently appended leaf is also a checkpoint, this will return the same + /// result as `checkpoint_depth == 1`. + pub fn root_at_checkpoint(&self, checkpoint_depth: usize) -> Result { + self.max_leaf_position(checkpoint_depth)?.map_or_else( + || Ok(H::empty_root(Self::root_addr().level())), + |pos| self.root(Self::root_addr(), pos + 1), + ) + } } // We need an applicative functor for Result for this function so that we can correctly @@ -1883,7 +2096,9 @@ mod tests { }; use core::convert::Infallible; use incrementalmerkletree::{ - testing::{self, check_append, complete_tree::CompleteTree, CombinedTree}, + testing::{ + self, check_append, check_root_hashes, complete_tree::CompleteTree, CombinedTree, + }, Address, Hashable, Level, Position, Retention, }; use std::collections::BTreeSet; @@ -2217,16 +2432,16 @@ mod tests { ShardTree::max_leaf_position(self, 0).ok().flatten() } - fn get_marked_leaf(&self, _position: Position) -> Option<&H> { - todo!() + fn get_marked_leaf(&self, position: Position) -> Option<&H> { + ShardTree::get_marked_leaf(self, position) } fn marked_positions(&self) -> BTreeSet { - todo!() + ShardTree::marked_positions(self) } - fn root(&self, _checkpoint_depth: usize) -> Option { - todo!() + fn root(&self, checkpoint_depth: usize) -> Option { + ShardTree::root_at_checkpoint(self, checkpoint_depth).ok() } fn witness(&self, _position: Position, _checkpoint_depth: usize) -> Option> { @@ -2237,8 +2452,8 @@ mod tests { todo!() } - fn checkpoint(&mut self, _checkpoint_id: C) -> bool { - todo!() + fn checkpoint(&mut self, checkpoint_id: C) -> bool { + ShardTree::checkpoint(self, checkpoint_id) } fn rewind(&mut self) -> bool { @@ -2253,6 +2468,12 @@ mod tests { }); } + #[test] + fn root_hashes() { + check_root_hashes(|m| { + ShardTree::, 4, 3>::new(vec![], m, 0) + }); + } // Combined tree tests #[allow(clippy::type_complexity)] fn new_combined_tree(