diff --git a/shardtree/src/lib.rs b/shardtree/src/lib.rs index e5cdc8d..14b56a7 100644 --- a/shardtree/src/lib.rs +++ b/shardtree/src/lib.rs @@ -1420,37 +1420,54 @@ pub trait ShardStore { fn truncate(&mut self, from: Address) -> Result<(), Self::Error>; } -impl ShardStore for Vec> { +#[derive(Debug)] +pub struct MemoryShardStore { + shards: Vec>, +} + +impl MemoryShardStore { + pub fn empty() -> Self { + Self { shards: vec![] } + } +} + +impl ShardStore for MemoryShardStore { type Error = Infallible; fn get_shard(&self, shard_root: Address) -> Option<&LocatedPrunableTree> { - self.get(usize::try_from(shard_root.index()).expect("SHARD_HEIGHT > 64 is unsupported")) + let shard_idx = + usize::try_from(shard_root.index()).expect("SHARD_HEIGHT > 64 is unsupported"); + self.shards.get(shard_idx) } fn last_shard(&self) -> Option<&LocatedPrunableTree> { - self.last() + self.shards.last() } fn put_shard(&mut self, subtree: LocatedPrunableTree) -> Result<(), Self::Error> { let subtree_addr = subtree.root_addr; - for subtree_idx in self.last().map_or(0, |s| s.root_addr.index() + 1)..=subtree_addr.index() + for subtree_idx in + self.shards.last().map_or(0, |s| s.root_addr.index() + 1)..=subtree_addr.index() { - self.push(LocatedTree { + self.shards.push(LocatedTree { root_addr: Address::from_parts(subtree_addr.level(), subtree_idx), root: Tree(Node::Nil), }) } - self[usize::try_from(subtree_addr.index()).expect("SHARD_HEIGHT > 64 is unsupported")] = - subtree; + + let shard_idx = + usize::try_from(subtree_addr.index()).expect("SHARD_HEIGHT > 64 is unsupported"); + self.shards[shard_idx] = subtree; Ok(()) } fn get_shard_roots(&self) -> Vec
{ - self.iter().map(|s| s.root_addr).collect() + self.shards.iter().map(|s| s.root_addr).collect() } fn truncate(&mut self, from: Address) -> Result<(), Self::Error> { - self.truncate(usize::try_from(from.index()).expect("SHARD_HEIGHT > 64 is unsupported")); + let shard_idx = usize::try_from(from.index()).expect("SHARD_HEIGHT > 64 is unsupported"); + self.shards.truncate(shard_idx); Ok(()) } } @@ -2224,8 +2241,8 @@ pub mod testing { #[cfg(test)] mod tests { use crate::{ - IncompleteAt, LocatedPrunableTree, LocatedTree, Node, PrunableTree, QueryError, - RetentionFlags, ShardStore, ShardTree, Tree, + IncompleteAt, LocatedPrunableTree, LocatedTree, MemoryShardStore, Node, PrunableTree, + QueryError, RetentionFlags, ShardStore, ShardTree, Tree, }; use assert_matches::assert_matches; use core::convert::Infallible; @@ -2390,8 +2407,6 @@ mod tests { ); } - type VecShardStore = Vec>; - #[test] fn tree_marked_positions() { let t: PrunableTree = parent( @@ -2554,8 +2569,8 @@ mod tests { #[test] fn shardtree_insertion() { - let mut tree: ShardTree, 4, 3> = - ShardTree::new(vec![], 100, 0); + let mut tree: ShardTree, 4, 3> = + ShardTree::new(MemoryShardStore::empty(), 100, 0); assert_matches!( tree.batch_insert( Position::from(1), @@ -2701,35 +2716,55 @@ mod tests { #[test] fn append() { check_append(|m| { - ShardTree::, 4, 3>::new(vec![], m, 0) + ShardTree::, 4, 3>::new( + MemoryShardStore::empty(), + m, + 0, + ) }); } #[test] fn root_hashes() { check_root_hashes(|m| { - ShardTree::, 4, 3>::new(vec![], m, 0) + ShardTree::, 4, 3>::new( + MemoryShardStore::empty(), + m, + 0, + ) }); } #[test] fn witnesses() { check_witnesses(|m| { - ShardTree::, 4, 3>::new(vec![], m, 0) + ShardTree::, 4, 3>::new( + MemoryShardStore::empty(), + m, + 0, + ) }); } #[test] fn checkpoint_rewind() { check_checkpoint_rewind(|m| { - ShardTree::, 4, 3>::new(vec![], m, 0) + ShardTree::, 4, 3>::new( + MemoryShardStore::empty(), + m, + 0, + ) }); } #[test] fn rewind_remove_mark() { check_rewind_remove_mark(|m| { - ShardTree::, 4, 3>::new(vec![], m, 0) + ShardTree::, 4, 3>::new( + MemoryShardStore::empty(), + m, + 0, + ) }); } @@ -2741,11 +2776,11 @@ mod tests { H, usize, CompleteTree, - ShardTree, 4, 3>, + ShardTree, 4, 3>, > { CombinedTree::new( CompleteTree::new(max_checkpoints, 0), - ShardTree::new(vec![], max_checkpoints, 0), + ShardTree::new(MemoryShardStore::empty(), max_checkpoints, 0), ) }