diff --git a/shardtree/src/lib.rs b/shardtree/src/lib.rs index 14b56a7..390d76c 100644 --- a/shardtree/src/lib.rs +++ b/shardtree/src/lib.rs @@ -1393,7 +1393,7 @@ impl Checkpoint { /// A capability for storage of fragment subtrees of the `ShardTree` type. /// /// All fragment subtrees must have roots at level `SHARD_HEIGHT - 1` -pub trait ShardStore { +pub trait ShardStore { type Error; /// Returns the subtree at the given root address, if any such subtree exists. @@ -1418,20 +1418,149 @@ pub trait ShardStore { /// Implementations of this method MUST enforce the constraint that the root address /// provided has level `SHARD_HEIGHT - 1`. fn truncate(&mut self, from: Address) -> Result<(), Self::Error>; + + // /// TODO: Add a tree that is used to cache the known roots of subtrees in the "cap" of nodes between + // /// `SHARD_HEIGHT` and `DEPTH` that are otherwise not directly represented in the tree. This + // /// cache will be automatically updated when computing roots and witnesses. Leaf nodes are empty + // /// because the annotation slot is consistently used to store the subtree hashes at each node. + // cap_cache: Tree>, ()> + + /// Returns the identifier for the checkpoint with the lowest associated position value. + fn min_checkpoint_id(&self) -> Option<&C>; + + /// Returns the identifier for the checkpoint with the highest associated position value. + fn max_checkpoint_id(&self) -> Option<&C>; + + /// Adds a checkpoint to the data store. + fn add_checkpoint( + &mut self, + checkpoint_id: C, + checkpoint: Checkpoint, + ) -> Result<(), Self::Error>; + + /// Returns the number of checkpoints maintained by the data store + fn checkpoint_count(&self) -> Result; + + /// Returns the position of the checkpoint, if any, along with the number of subsequent + /// checkpoints at the same position. Returns `None` if `checkpoint_depth == 0` or if + /// insufficient checkpoints exist to seek back to the requested depth. + fn get_checkpoint_at_depth(&self, checkpoint_depth: usize) -> Option<(&C, &Checkpoint)>; + + /// Iterates in checkpoint ID order over the first `limit` checkpoints, applying the + /// given callback to each. + fn with_checkpoints(&mut self, limit: usize, callback: F) -> Result<(), Self::Error> + where + F: FnMut(&C, &Checkpoint) -> Result<(), Self::Error>; + + /// Update the checkpoint having the given identifier by mutating it with the provided + /// function, and persist the updated checkpoint to the data store. + /// + /// Returns `Ok(true)` if the checkpoint was found, `Ok(false)` if no checkpoint with the + /// provided identifier exists in the data store, or an error if a storage error occurred. + fn update_checkpoint_with( + &mut self, + checkpoint_id: &C, + update: F, + ) -> Result + where + F: Fn(&mut Checkpoint) -> Result<(), Self::Error>; + + /// Removes a checkpoint from the data store. + fn remove_checkpoint(&mut self, checkpoint_id: &C) -> Result<(), Self::Error>; + + /// Removes checkpoints with identifiers greater than or equal to the given identifier + fn truncate_checkpoints(&mut self, checkpoint_id: &C) -> Result<(), Self::Error>; } -#[derive(Debug)] -pub struct MemoryShardStore { - shards: Vec>, -} +impl> ShardStore for &mut S { + type Error = S::Error; + fn get_shard(&self, shard_root: Address) -> Option<&LocatedPrunableTree> { + S::get_shard(*self, shard_root) + } -impl MemoryShardStore { - pub fn empty() -> Self { - Self { shards: vec![] } + fn last_shard(&self) -> Option<&LocatedPrunableTree> { + S::last_shard(*self) + } + + fn put_shard(&mut self, subtree: LocatedPrunableTree) -> Result<(), Self::Error> { + S::put_shard(*self, subtree) + } + + fn get_shard_roots(&self) -> Vec
{ + S::get_shard_roots(*self) + } + + fn truncate(&mut self, from: Address) -> Result<(), Self::Error> { + S::truncate(*self, from) + } + + fn min_checkpoint_id(&self) -> Option<&C> { + S::min_checkpoint_id(self) + } + + fn max_checkpoint_id(&self) -> Option<&C> { + S::max_checkpoint_id(self) + } + + fn add_checkpoint( + &mut self, + checkpoint_id: C, + checkpoint: Checkpoint, + ) -> Result<(), Self::Error> { + S::add_checkpoint(self, checkpoint_id, checkpoint) + } + + fn checkpoint_count(&self) -> Result { + S::checkpoint_count(self) + } + + fn get_checkpoint_at_depth(&self, checkpoint_depth: usize) -> Option<(&C, &Checkpoint)> { + S::get_checkpoint_at_depth(self, checkpoint_depth) + } + + fn with_checkpoints(&mut self, limit: usize, callback: F) -> Result<(), Self::Error> + where + F: FnMut(&C, &Checkpoint) -> Result<(), Self::Error>, + { + S::with_checkpoints(self, limit, callback) + } + + fn update_checkpoint_with( + &mut self, + checkpoint_id: &C, + update: F, + ) -> Result + where + F: Fn(&mut Checkpoint) -> Result<(), Self::Error>, + { + S::update_checkpoint_with(self, checkpoint_id, update) + } + + fn remove_checkpoint(&mut self, checkpoint_id: &C) -> Result<(), Self::Error> { + S::remove_checkpoint(self, checkpoint_id) + } + + fn truncate_checkpoints(&mut self, checkpoint_id: &C) -> Result<(), Self::Error> { + S::truncate_checkpoints(self, checkpoint_id) } } -impl ShardStore for MemoryShardStore { +#[derive(Debug)] +pub struct MemoryShardStore { + shards: Vec>, + checkpoints: BTreeMap, +} + +impl MemoryShardStore { + pub fn empty() -> Self { + Self { + shards: vec![], + checkpoints: BTreeMap::new(), + } + } +} + +impl ShardStore for MemoryShardStore { type Error = Infallible; fn get_shard(&self, shard_root: Address) -> Option<&LocatedPrunableTree> { @@ -1470,6 +1599,72 @@ impl ShardStore for MemoryShardStore { self.shards.truncate(shard_idx); Ok(()) } + + fn add_checkpoint( + &mut self, + checkpoint_id: C, + checkpoint: Checkpoint, + ) -> Result<(), Self::Error> { + self.checkpoints.insert(checkpoint_id, checkpoint); + Ok(()) + } + + fn checkpoint_count(&self) -> Result { + Ok(self.checkpoints.len()) + } + + fn get_checkpoint_at_depth(&self, checkpoint_depth: usize) -> Option<(&C, &Checkpoint)> { + if checkpoint_depth == 0 { + None + } else { + self.checkpoints.iter().rev().nth(checkpoint_depth - 1) + } + } + + fn min_checkpoint_id(&self) -> Option<&C> { + self.checkpoints.keys().next() + } + + fn max_checkpoint_id(&self) -> Option<&C> { + self.checkpoints.keys().last() + } + + fn with_checkpoints(&mut self, limit: usize, mut callback: F) -> Result<(), Self::Error> + where + F: FnMut(&C, &Checkpoint) -> Result<(), Self::Error>, + { + for (cid, checkpoint) in self.checkpoints.iter().take(limit) { + callback(cid, checkpoint)? + } + + Ok(()) + } + + fn update_checkpoint_with( + &mut self, + checkpoint_id: &C, + update: F, + ) -> Result + where + F: Fn(&mut Checkpoint) -> Result<(), Self::Error>, + { + if let Some(c) = self.checkpoints.get_mut(checkpoint_id) { + update(c)?; + return Ok(true); + } + + return Ok(false); + } + + fn remove_checkpoint(&mut self, checkpoint_id: &C) -> Result<(), Self::Error> { + self.checkpoints.remove(checkpoint_id); + Ok(()) + } + + fn truncate_checkpoints(&mut self, checkpoint_id: &C) -> Result<(), Self::Error> { + self.checkpoints.split_off(checkpoint_id); + Ok(()) + } } /// A sparse binary Merkle tree of the specified depth, represented as an ordered collection of @@ -1479,61 +1674,39 @@ impl ShardStore for MemoryShardStore { /// front of the tree, that are maintained such that it's possible to truncate nodes to the right /// of the specified position. #[derive(Debug)] -pub struct ShardTree, const DEPTH: u8, const SHARD_HEIGHT: u8> { +pub struct ShardTree, const DEPTH: u8, const SHARD_HEIGHT: u8> { /// The vector of tree shards. store: S, /// The maximum number of checkpoints to retain before pruning. max_checkpoints: usize, - /// An ordered map from checkpoint identifier to checkpoint. - checkpoints: BTreeMap, - // /// TODO: Add a tree that is used to cache the known roots of subtrees in the "cap" of nodes between - // /// `SHARD_HEIGHT` and `DEPTH` that are otherwise not directly represented in the tree. This - // /// cache will be automatically updated when computing roots and witnesses. Leaf nodes are empty - // /// because the annotation slot is consistently used to store the subtree hashes at each node. - // cap_cache: Tree>, ()> _hash_type: PhantomData, -} - -impl> ShardStore for &mut S { - type Error = S::Error; - fn get_shard(&self, shard_root: Address) -> Option<&LocatedPrunableTree> { - S::get_shard(*self, shard_root) - } - - fn last_shard(&self) -> Option<&LocatedPrunableTree> { - S::last_shard(*self) - } - - fn put_shard(&mut self, subtree: LocatedPrunableTree) -> Result<(), Self::Error> { - S::put_shard(*self, subtree) - } - - fn get_shard_roots(&self) -> Vec
{ - S::get_shard_roots(*self) - } - - fn truncate(&mut self, from: Address) -> Result<(), Self::Error> { - S::truncate(*self, from) - } + _checkpoint_id_type: PhantomData, } impl< H: Hashable + Clone + PartialEq, C: Clone + Ord, - S: ShardStore, + S: ShardStore, const DEPTH: u8, const SHARD_HEIGHT: u8, > ShardTree { /// Creates a new empty tree. - pub fn new(store: S, max_checkpoints: usize, initial_checkpoint_id: C) -> Self { - Self { + pub fn new( + store: S, + max_checkpoints: usize, + initial_checkpoint_id: C, + ) -> Result { + let mut result = Self { store, max_checkpoints, - checkpoints: BTreeMap::from([(initial_checkpoint_id, Checkpoint::tree_empty())]), - //cap_cache: Tree(None, ()) _hash_type: PhantomData, - } + _checkpoint_id_type: PhantomData, + }; + result + .store + .add_checkpoint(initial_checkpoint_id, Checkpoint::tree_empty())?; + Ok(result) } /// Returns the root address of the tree. @@ -1547,11 +1720,6 @@ impl< Level::from(SHARD_HEIGHT - 1) } - /// Returns the position and checkpoint count for each checkpointed position in the tree. - pub fn checkpoints(&self) -> &BTreeMap { - &self.checkpoints - } - /// Returns the leaf value at the specified position, if it is a marked leaf. pub fn get_marked_leaf(&self, position: Position) -> Option<&H> { self.store @@ -1624,9 +1792,12 @@ impl< &mut self, value: H, retention: Retention, - ) -> Result<(), InsertionError> { + ) -> Result<(), InsertionError> + where + S::Error: Debug, + { if let Retention::Checkpoint { id, .. } = &retention { - if self.checkpoints.keys().last() >= Some(id) { + if self.store.max_checkpoint_id() >= Some(id) { return Err(InsertionError::CheckpointOutOfOrder); } } @@ -1653,8 +1824,9 @@ impl< .put_shard(append_result) .map_err(InsertionError::Storage)?; if let Some(c) = checkpoint_id { - self.checkpoints - .insert(c, Checkpoint::at_position(position)); + self.store + .add_checkpoint(c, Checkpoint::at_position(position)) + .map_err(InsertionError::Storage)?; } self.prune_excess_checkpoints() @@ -1682,7 +1854,10 @@ impl< &mut self, mut start: Position, values: I, - ) -> Result)>, InsertionError> { + ) -> Result)>, InsertionError> + where + S::Error: Debug, + { let mut values = values.peekable(); let mut subtree_root_addr = Address::above_position(Self::subtree_level(), start); let mut max_insert_position = None; @@ -1702,8 +1877,9 @@ impl< .put_shard(res.subtree) .map_err(InsertionError::Storage)?; for (id, position) in res.checkpoints.into_iter() { - self.checkpoints - .insert(id, Checkpoint::at_position(position)); + self.store + .add_checkpoint(id, Checkpoint::at_position(position)) + .map_err(InsertionError::Storage)?; } values = res.remainder; @@ -1718,7 +1894,6 @@ impl< self.prune_excess_checkpoints() .map_err(InsertionError::Storage)?; - Ok(max_insert_position.map(|p| (p, all_incomplete))) } @@ -1747,7 +1922,10 @@ impl< } /// Adds a checkpoint at the rightmost leaf state of the tree. - pub fn checkpoint(&mut self, checkpoint_id: C) -> bool { + pub fn checkpoint(&mut self, checkpoint_id: C) -> Result> + where + S::Error: Debug, + { fn go( root_addr: Address, root: &PrunableTree, @@ -1793,8 +1971,8 @@ impl< } // checkpoint identifiers at the tip must be in increasing order - if self.checkpoints.keys().last() >= Some(&checkpoint_id) { - return false; + if self.store.max_checkpoint_id() >= Some(&checkpoint_id) { + return Ok(false); } // Search backward from the end of the subtrees iter to find a non-empty subtree. @@ -1814,67 +1992,75 @@ impl< }) .is_err() { - return false; + return Ok(false); } - self.checkpoints - .insert(checkpoint_id, Checkpoint::at_position(checkpoint_position)); + self.store + .add_checkpoint(checkpoint_id, Checkpoint::at_position(checkpoint_position)) + .map_err(InsertionError::Storage)?; // early return once we've updated the tree state - return self - .prune_excess_checkpoints() - .map_err(InsertionError::Storage) - .is_ok(); + self.prune_excess_checkpoints() + .map_err(InsertionError::Storage)?; + return Ok(true); } } - self.checkpoints - .insert(checkpoint_id, Checkpoint::tree_empty()); + self.store + .add_checkpoint(checkpoint_id, Checkpoint::tree_empty()) + .map_err(InsertionError::Storage)?; // TODO: it should not be necessary to do this on every checkpoint, // but currently that's how the reference tree behaves so we're maintaining // those semantics for test compatibility. self.prune_excess_checkpoints() - .map_err(InsertionError::Storage) - .is_ok() + .map_err(InsertionError::Storage)?; + Ok(true) } - fn prune_excess_checkpoints(&mut self) -> Result<(), S::Error> { - if self.checkpoints.len() > self.max_checkpoints { + fn prune_excess_checkpoints(&mut self) -> Result<(), S::Error> + where + S::Error: Debug, + { + let checkpoint_count = self.store.checkpoint_count()?; + if checkpoint_count > self.max_checkpoints { // Batch removals by subtree & create a list of the checkpoint identifiers that // will be removed from the checkpoints map. let mut checkpoints_to_delete = vec![]; let mut clear_positions: BTreeMap> = BTreeMap::new(); - for (cid, checkpoint) in self - .checkpoints - .iter() - .take(self.checkpoints.len() - self.max_checkpoints) - { - checkpoints_to_delete.push(cid.clone()); + self.store + .with_checkpoints( + checkpoint_count - self.max_checkpoints, + |cid, checkpoint| { + checkpoints_to_delete.push(cid.clone()); - let mut clear_at = |pos, flags_to_clear| { - let subtree_addr = Address::above_position(Self::subtree_level(), pos); - clear_positions - .entry(subtree_addr) - .and_modify(|to_clear| { - to_clear - .entry(pos) - .and_modify(|flags| *flags |= flags_to_clear) - .or_insert(flags_to_clear); - }) - .or_insert_with(|| BTreeMap::from([(pos, flags_to_clear)])); - }; + let mut clear_at = |pos, flags_to_clear| { + let subtree_addr = Address::above_position(Self::subtree_level(), pos); + clear_positions + .entry(subtree_addr) + .and_modify(|to_clear| { + to_clear + .entry(pos) + .and_modify(|flags| *flags |= flags_to_clear) + .or_insert(flags_to_clear); + }) + .or_insert_with(|| BTreeMap::from([(pos, flags_to_clear)])); + }; - // clear the checkpoint leaf - if let TreeState::AtPosition(pos) = checkpoint.tree_state { - clear_at(pos, RetentionFlags::CHECKPOINT) - } + // clear the checkpoint leaf + if let TreeState::AtPosition(pos) = checkpoint.tree_state { + clear_at(pos, RetentionFlags::CHECKPOINT) + } - // clear the leaves that have been marked for removal - for unmark_pos in checkpoint.marks_removed.iter() { - clear_at(*unmark_pos, RetentionFlags::MARKED) - } - } + // clear the leaves that have been marked for removal + for unmark_pos in checkpoint.marks_removed.iter() { + clear_at(*unmark_pos, RetentionFlags::MARKED) + } + + Ok(()) + }, + ) + .expect("The provided function is infallible."); // Prune each affected subtree for (subtree_addr, positions) in clear_positions.into_iter() { @@ -1889,7 +2075,7 @@ impl< // Now that the leaves have been pruned, actually remove the checkpoints for c in checkpoints_to_delete { - self.checkpoints.remove(&c); + self.store.remove_checkpoint(&c)?; } } @@ -1901,23 +2087,21 @@ impl< /// This will also discard all checkpoints with depth <= the specified depth. Returns `true` /// if the truncation succeeds or has no effect, or `false` if no checkpoint exists at the /// specified depth. - pub fn truncate_removing_checkpoint(&mut self, checkpoint_depth: usize) -> bool { + pub fn truncate_removing_checkpoint( + &mut self, + checkpoint_depth: usize, + ) -> Result { if checkpoint_depth == 0 { - true - } else if self.checkpoints.len() > 1 { - match self.checkpoint_at_depth(checkpoint_depth) { + Ok(true) + } else if self.store.checkpoint_count()? > 1 { + Ok(match self.store.get_checkpoint_at_depth(checkpoint_depth) { Some((checkpoint_id, c)) => { let checkpoint_id = checkpoint_id.clone(); match c.tree_state { TreeState::Empty => { - if (self - .store - .truncate(Address::from_parts(Self::subtree_level(), 0))) - .is_err() - { - return false; - } - self.checkpoints.split_off(&checkpoint_id); + self.store + .truncate(Address::from_parts(Self::subtree_level(), 0))?; + self.store.truncate_checkpoints(&checkpoint_id)?; true } TreeState::AtPosition(position) => { @@ -1927,16 +2111,13 @@ impl< .store .get_shard(subtree_addr) .and_then(|s| s.truncate_to_position(position)); + match replacement { Some(truncated) => { - if self.store.truncate(subtree_addr).is_err() - || self.store.put_shard(truncated).is_err() - { - false - } else { - self.checkpoints.split_off(&checkpoint_id); - true - } + self.store.truncate(subtree_addr)?; + self.store.put_shard(truncated)?; + self.store.truncate_checkpoints(&checkpoint_id)?; + true } None => false, } @@ -1944,10 +2125,9 @@ impl< } } None => false, - } + }) } else { - // do not remove the first checkpoint. - false + Ok(false) } } @@ -2060,17 +2240,6 @@ impl< } } - /// Returns the position of the checkpoint, if any, along with the number of subsequent - /// checkpoints at the same position. Returns `None` if `checkpoint_depth == 0` or if - /// insufficient checkpoints exist to seek back to the requested depth. - pub fn checkpoint_at_depth(&self, checkpoint_depth: usize) -> Option<(&C, &Checkpoint)> { - if checkpoint_depth == 0 { - None - } else { - self.checkpoints.iter().rev().nth(checkpoint_depth - 1) - } - } - /// Returns the position of the rightmost leaf inserted as of the given checkpoint. /// /// Returns the maximum leaf position if `checkpoint_depth == 0` (or `Ok(None)` in this @@ -2088,7 +2257,7 @@ impl< // just store it directly. Ok(self.store.last_shard().and_then(|t| t.max_position())) } else { - match self.checkpoint_at_depth(checkpoint_depth) { + match self.store.get_checkpoint_at_depth(checkpoint_depth) { Some((_, c)) => Ok(c.position()), None => { // There is no checkpoint at the specified depth, so we report it as pruned. @@ -2156,20 +2325,33 @@ impl< /// corresponding checkpoint would have been more than `max_checkpoints` deep, the removal /// is recorded as of the first existing checkpoint and the associated leaves will be pruned /// when that checkpoint is subsequently removed. - pub fn remove_mark(&mut self, position: Position, as_of_checkpoint: &C) -> bool { + pub fn remove_mark( + &mut self, + position: Position, + as_of_checkpoint: &C, + ) -> Result { if self.get_marked_leaf(position).is_some() { - if let Some(checkpoint) = self.checkpoints.get_mut(as_of_checkpoint) { - checkpoint.marks_removed.insert(position); - return true; + if self + .store + .update_checkpoint_with(as_of_checkpoint, |checkpoint| { + checkpoint.marks_removed.insert(position); + Ok(()) + })? + { + return Ok(true); } - if let Some((_, checkpoint)) = self.checkpoints.iter_mut().next() { - checkpoint.marks_removed.insert(position); - return true; + if let Some(cid) = self.store.min_checkpoint_id().cloned() { + if self.store.update_checkpoint_with(&cid, |checkpoint| { + checkpoint.marks_removed.insert(position); + Ok(()) + })? { + return Ok(true); + } } } - false + Ok(false) } } @@ -2569,8 +2751,8 @@ mod tests { #[test] fn shardtree_insertion() { - let mut tree: ShardTree, 4, 3> = - ShardTree::new(MemoryShardStore::empty(), 100, 0); + let mut tree: ShardTree, 4, 3> = + ShardTree::new(MemoryShardStore::empty(), 100, 0).unwrap(); assert_matches!( tree.batch_insert( Position::from(1), @@ -2655,16 +2837,18 @@ mod tests { ] ); - assert!(tree.truncate_removing_checkpoint(1)); + assert_matches!(tree.truncate_removing_checkpoint(1), Ok(true)); } impl< H: Hashable + Ord + Clone, C: Clone + Ord + core::fmt::Debug, - S: ShardStore, + S: ShardStore, const DEPTH: u8, const SHARD_HEIGHT: u8, > testing::Tree for ShardTree + where + S::Error: core::fmt::Debug, { fn depth(&self) -> u8 { DEPTH @@ -2697,74 +2881,79 @@ mod tests { } fn remove_mark(&mut self, position: Position) -> bool { - if let Some(c) = self.checkpoints.iter().rev().map(|(c, _)| c.clone()).next() { - ShardTree::remove_mark(self, position, &c) + if let Some(c) = self.store.max_checkpoint_id().cloned() { + ShardTree::remove_mark(self, position, &c).unwrap() } else { false } } fn checkpoint(&mut self, checkpoint_id: C) -> bool { - ShardTree::checkpoint(self, checkpoint_id) + ShardTree::checkpoint(self, checkpoint_id).unwrap() } fn rewind(&mut self) -> bool { - ShardTree::truncate_removing_checkpoint(self, 1) + ShardTree::truncate_removing_checkpoint(self, 1).unwrap() } } #[test] fn append() { check_append(|m| { - ShardTree::, 4, 3>::new( + ShardTree::, 4, 3>::new( MemoryShardStore::empty(), m, 0, ) + .unwrap() }); } #[test] fn root_hashes() { check_root_hashes(|m| { - ShardTree::, 4, 3>::new( + ShardTree::, 4, 3>::new( MemoryShardStore::empty(), m, 0, ) + .unwrap() }); } #[test] fn witnesses() { check_witnesses(|m| { - ShardTree::, 4, 3>::new( + ShardTree::, 4, 3>::new( MemoryShardStore::empty(), m, 0, ) + .unwrap() }); } #[test] fn checkpoint_rewind() { check_checkpoint_rewind(|m| { - ShardTree::, 4, 3>::new( + ShardTree::, 4, 3>::new( MemoryShardStore::empty(), m, 0, ) + .unwrap() }); } #[test] fn rewind_remove_mark() { check_rewind_remove_mark(|m| { - ShardTree::, 4, 3>::new( + ShardTree::, 4, 3>::new( MemoryShardStore::empty(), m, 0, ) + .unwrap() }); } @@ -2776,11 +2965,11 @@ mod tests { H, usize, CompleteTree, - ShardTree, 4, 3>, + ShardTree, 4, 3>, > { CombinedTree::new( CompleteTree::new(max_checkpoints, 0), - ShardTree::new(MemoryShardStore::empty(), max_checkpoints, 0), + ShardTree::new(MemoryShardStore::empty(), max_checkpoints, 0).unwrap(), ) }