Add caching of the "cap" to root & witness computation.

This commit is contained in:
Kris Nuttycombe 2023-05-16 12:32:36 -06:00
parent f8c13d17de
commit 290b66d5c8
2 changed files with 342 additions and 35 deletions

View File

@ -358,7 +358,7 @@ impl<H, const DEPTH: u8> CommitmentTree<H, DEPTH> {
}
pub fn leaf(&self) -> Option<&H> {
self.right.as_ref().or_else(|| self.left.as_ref())
self.right.as_ref().or(self.left.as_ref())
}
pub fn ommers_iter(&self) -> Box<dyn Iterator<Item = &'_ H> + '_> {

View File

@ -166,7 +166,7 @@ impl<A, V> Tree<A, V> {
/// Replaces the annotation at the root of the tree, if the root is a `Node::Parent`; otherwise
/// returns this tree unaltered.
pub fn reannotate_root(self, ann: A) -> Tree<A, V> {
pub fn reannotate_root(self, ann: A) -> Self {
Tree(self.0.reannotate(ann))
}
@ -1407,15 +1407,13 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
// the current address is a left child, so create a parent with
// an empty right-hand tree
subtree = Tree::parent(None, subtree, Tree::empty());
} else if let Some(left) = ommers.next() {
// the current address corresponds to a right child, so create a parent that
// takes the left sibling to that child from the ommers
subtree =
Tree::parent(None, Tree::leaf((left, RetentionFlags::EPHEMERAL)), subtree);
} else {
if let Some(left) = ommers.next() {
// the current address corresponds to a right child, so create a parent that
// takes the left sibling to that child from the ommers
subtree =
Tree::parent(None, Tree::leaf((left, RetentionFlags::EPHEMERAL)), subtree);
} else {
break;
}
break;
}
addr = addr.parent();
@ -1429,7 +1427,7 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
let located_supertree = if located_subtree.root_addr().level() == split_at {
let mut addr = located_subtree.root_addr();
let mut supertree = None;
while let Some(left) = ommers.next() {
for left in ommers {
// build up the left-biased tree until we get a right-hand node
while addr.index() & 0x1 == 0 {
supertree = supertree.map(|t| Tree::parent(None, t, Tree::empty()));
@ -1500,7 +1498,7 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
// add filled nodes to the supertree
let supertree = if addr.level() == split_at {
let mut supertree = None;
while let Some(right) = filled.next() {
for right in filled {
// build up the right-biased tree until we get a left-hand node
while addr.index() & 0x1 == 1 {
supertree = supertree.map(|t| Tree::parent(None, Tree::empty(), t));
@ -1936,7 +1934,7 @@ impl<H, C: Ord> MemoryShardStore<H, C> {
}
}
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq)]
pub enum MemoryShardStoreError {
Insertion(InsertionError),
Query(QueryError),
@ -2250,10 +2248,7 @@ where
&mut self,
frontier: NonEmptyFrontier<H>,
leaf_retention: Retention<C>,
) -> Result<(), S::Error>
where
S::Error: From<InsertionError>,
{
) -> Result<(), S::Error> {
let leaf_position = frontier.position();
let subtree_root_addr = Address::above_position(Self::subtree_level(), leaf_position);
@ -2293,16 +2288,13 @@ where
&mut self,
witness: IncrementalWitness<H, DEPTH>,
checkpoint_id: S::CheckpointId,
) -> Result<(), S::Error>
where
S::Error: From<InsertionError>,
{
) -> Result<(), S::Error> {
let leaf_position = witness.witnessed_position();
let subtree_root_addr = Address::above_position(Self::subtree_level(), leaf_position);
let shard = self
.store
.get_shard(subtree_root_addr)
.get_shard(subtree_root_addr)?
.unwrap_or_else(|| LocatedTree::empty(subtree_root_addr));
let (updated_subtree, supertree, tip_subtree) =
@ -2328,7 +2320,7 @@ where
let tip_shard = self
.store
.get_shard(tip_subtree_addr)
.get_shard(tip_subtree_addr)?
.unwrap_or_else(|| LocatedTree::empty(tip_subtree_addr));
self.store
@ -2570,6 +2562,7 @@ where
self.store
.truncate(Address::from_parts(Self::subtree_level(), 0))?;
self.store.truncate_checkpoints(&checkpoint_id)?;
self.store.put_cap(Tree::empty())?;
true
}
TreeState::AtPosition(position) => {
@ -2579,13 +2572,23 @@ where
.get_shard(subtree_addr)?
.and_then(|s| s.truncate_to_position(position));
if let Some(truncated) = replacement {
self.store.truncate(subtree_addr)?;
self.store.put_shard(truncated)?;
self.store.truncate_checkpoints(&checkpoint_id)?;
true
} else {
false
let cap_tree = LocatedTree {
root_addr: Self::root_addr(),
root: self.store.get_cap()?,
};
if let Some(truncated) = cap_tree.truncate_to_position(position) {
self.store.put_cap(truncated.root)?;
};
match replacement {
Some(truncated) => {
self.store.truncate(subtree_addr)?;
self.store.put_shard(truncated)?;
self.store.truncate_checkpoints(&checkpoint_id)?;
true
}
None => false,
}
}
}
@ -2604,6 +2607,181 @@ where
///
/// Use [`Self::root_at_checkpoint`] to obtain the root of the overall tree.
pub fn root(&self, address: Address, truncate_at: Position) -> Result<H, S::Error> {
    // The requested address must lie within the overall tree.
    assert!(Self::root_addr().contains(&address));

    // Wrap the stored cap in a located tree rooted at the top of the overall tree, so that
    // the depth-first traversal can reuse any values already cached in it.
    let cap = LocatedPrunableTree {
        root_addr: Self::root_addr(),
        root: self.store.get_cap()?,
    };

    // This method only borrows `self` immutably, so the updated cap produced by the
    // traversal (the second tuple element) is dropped rather than written back.
    self.root_internal(&cap, address, truncate_at)
        .map(|(root, _)| root)
}
/// Computes the root hash at `address`, treating leaves at positions `truncate_at` and
/// beyond as empty, and persists any cap update produced by the computation.
///
/// Any updated cap returned by the internal traversal is written back to the store via
/// `put_cap`.
///
/// NOTE(review): unlike `root`, this method does not assert that `address` lies within the
/// root address of the tree; confirm whether that difference is intentional.
pub fn root_caching(&mut self, address: Address, truncate_at: Position) -> Result<H, S::Error> {
    let cap = LocatedPrunableTree {
        root_addr: Self::root_addr(),
        root: self.store.get_cap()?,
    };
    let (root, updated_cap) = self.root_internal(&cap, address, truncate_at)?;

    // Write the cap back only when the traversal actually produced a replacement.
    updated_cap
        .map(|new_cap| self.store.put_cap(new_cap))
        .transpose()?;

    Ok(root)
}
/// Computes the root hash at `target_addr`, along with an optional update to the cap.
///
/// `cap` is the (sub)tree of the cap currently being traversed, located at `cap.root_addr`.
/// The traversal is depth-first and dispatches on the kind of node at the root of `cap`:
///
/// * `Parent`: a cached annotation is returned directly when `target_addr` contains the
///   node's address; otherwise the relevant children are visited recursively and a rebuilt
///   parent node is returned as the cap update.
/// * `Leaf`: the stored value is returned directly when no truncation applies within the
///   node's position range and `target_addr` contains the node's address; otherwise the
///   traversal continues as though the node were an empty parent.
/// * `Nil`: the root is obtained from `root_from_shards` when the node is at the target
///   address or at the subtree level; otherwise the traversal continues as though the node
///   were an empty parent.
///
/// Returns the computed root hash, paired with `Some(tree)` when a replacement for `cap.root`
/// was produced and `None` when there is nothing to update. Errors from the recursive calls
/// and from `root_from_shards` are propagated.
fn root_internal(
&self,
cap: &LocatedPrunableTree<S::H>,
// The address at which we want to compute the root hash
target_addr: Address,
// An inclusive lower bound for positions whose leaf values will be replaced by empty
// roots.
truncate_at: Position,
) -> Result<(H, Option<PrunableTree<H>>), S::Error> {
match &cap.root {
Tree(Node::Parent { ann, left, right }) => {
match ann {
// A cached value exists for this node and the node lies within the target
// address: use the cached value, with no cap update.
Some(cached_root) if target_addr.contains(&cap.root_addr) => {
Ok((cached_root.as_ref().clone(), None))
}
_ => {
// Compute the roots of the left and right children and hash them together.
// We skip computation in any subtrees that will not have data included in
// the final result.
// NOTE(review): the `unwrap` panics if `children()` returns `None`;
// presumably that only happens for an address with no children (level 0),
// which is not expected for a parent node in the cap - confirm.
let (l_addr, r_addr) = cap.root_addr.children().unwrap();
// The left child is skipped when the target lies entirely within the right
// child. When the target lies within the left child it is passed through
// unchanged; otherwise the whole left child becomes the target.
let l_result = if r_addr.contains(&target_addr) {
None
} else {
Some(self.root_internal(
&LocatedPrunableTree {
root_addr: l_addr,
root: left.as_ref().clone(),
},
if l_addr.contains(&target_addr) {
target_addr
} else {
l_addr
},
truncate_at,
)?)
};
// Mirror image of the above for the right child.
let r_result = if l_addr.contains(&target_addr) {
None
} else {
Some(self.root_internal(
&LocatedPrunableTree {
root_addr: r_addr,
root: right.as_ref().clone(),
},
if r_addr.contains(&target_addr) {
target_addr
} else {
r_addr
},
truncate_at,
)?)
};
// Compute the root value based on the child roots; these may contain the
// hashes of empty/truncated nodes.
let (root, new_left, new_right) = match (l_result, r_result) {
(Some((l_root, new_left)), Some((r_root, new_right))) => (
S::H::combine(l_addr.level(), &l_root, &r_root),
new_left,
new_right,
),
// Only one side was visited: the target lies within that child, so the
// child's result is the result for this call.
(Some((l_root, new_left)), None) => (l_root, new_left, None),
(None, Some((r_root, new_right))) => (r_root, None, new_right),
// Each side is skipped only when the target is contained in the other
// side, so both cannot be skipped at once.
(None, None) => unreachable!(),
};
// Rebuild this parent node. The annotation is populated only when both
// children produced replacements that carry a node value; children without
// a replacement keep their existing shared pointers.
let new_parent = Tree(Node::Parent {
ann: new_left
.as_ref()
.and_then(|l| l.node_value())
.zip(new_right.as_ref().and_then(|r| r.node_value()))
.map(|(l, r)| {
// the node values of child nodes cannot contain the hashes of
// empty nodes or nodes with positions greater than the
// truncation position (NOTE(review): the original comment was cut
// off mid-sentence; this completion is presumed - confirm)
Rc::new(S::H::combine(l_addr.level(), l, r))
}),
left: new_left.map_or_else(|| left.clone(), Rc::new),
right: new_right.map_or_else(|| right.clone(), Rc::new),
});
Ok((root, Some(new_parent)))
}
}
}
Tree(Node::Leaf { value }) => {
if truncate_at >= cap.root_addr.position_range_end()
&& target_addr.contains(&cap.root_addr)
{
// no truncation or computation of child subtrees of this leaf is necessary, just use
// the cached leaf value
Ok((value.0.clone(), None))
} else {
// since the tree was truncated below this level, recursively call with an
// empty parent node to trigger the continued traversal
let (root, replacement) = self.root_internal(
&LocatedPrunableTree {
root_addr: cap.root_addr(),
root: Tree::parent(None, Tree::empty(), Tree::empty()),
},
target_addr,
truncate_at,
)?;
// Keep the existing leaf value as the annotation at the root of the
// replacement, so that it is still available from the updated cap.
Ok((
root,
replacement.map(|r| r.reannotate_root(Some(Rc::new(value.0.clone())))),
))
}
}
Tree(Node::Nil) => {
if cap.root_addr == target_addr
|| cap.root_addr.level() == ShardTree::<S, DEPTH, SHARD_HEIGHT>::subtree_level()
{
// We are at the leaf level or the target address; compute the root hash and
// return it as cacheable if it is not truncated.
let root = self.root_from_shards(target_addr, truncate_at)?;
Ok((
root.clone(),
if truncate_at >= cap.root_addr.position_range_end() {
// return the computed root as a new leaf to be cached if it contains no
// empty hashes due to truncation
Some(Tree::leaf((root, RetentionFlags::EPHEMERAL)))
} else {
None
},
))
} else {
// Compute the result by recursively walking down the tree. By replacing
// the current node with a parent node, the `Parent` handler will take care
// of the branching recursive calls.
self.root_internal(
&LocatedPrunableTree {
root_addr: cap.root_addr,
root: Tree::parent(None, Tree::empty(), Tree::empty()),
},
target_addr,
truncate_at,
)
}
}
}
}
fn root_from_shards(&self, address: Address, truncate_at: Position) -> Result<H, S::Error> {
match address.context(Self::subtree_level()) {
Either::Left(subtree_addr) => {
// The requested root address is fully contained within one of the subtrees.
@ -2739,6 +2917,13 @@ where
)
}
/// Computes the root of the overall tree as of the given checkpoint depth, persisting any
/// cap update produced along the way.
///
/// When `max_leaf_position` reports no leaf position for `checkpoint_depth`, the empty root
/// at the level of the tree's root address is returned. Otherwise the root is computed via
/// [`Self::root_caching`], truncating at the position just past the maximum leaf position.
pub fn root_at_checkpoint_caching(&mut self, checkpoint_depth: usize) -> Result<H, S::Error> {
    match self.max_leaf_position(checkpoint_depth)? {
        Some(max_pos) => self.root_caching(Self::root_addr(), max_pos + 1),
        None => Ok(H::empty_root(Self::root_addr().level())),
    }
}
/// Computes the witness for the leaf at the specified position.
///
/// Returns the witness as of the most recently appended leaf if `checkpoint_depth == 0`. Note
@ -2750,13 +2935,14 @@ where
checkpoint_depth: usize,
) -> Result<MerklePath<H, DEPTH>, S::Error> {
let max_leaf_position = self.max_leaf_position(checkpoint_depth).and_then(|v| {
v.ok_or_else(|| S::Error::from(QueryError::TreeIncomplete(vec![Self::root_addr()])))
v.ok_or_else(|| QueryError::TreeIncomplete(vec![Self::root_addr()]).into())
})?;
if position > max_leaf_position {
Err(S::Error::from(QueryError::NotContained(
Address::from_parts(Level::from(0), position.into()),
)))
Err(
QueryError::NotContained(Address::from_parts(Level::from(0), position.into()))
.into(),
)
} else {
let subtree_addr = Self::subtree_addr(position);
@ -2778,6 +2964,41 @@ where
}
}
/// Computes the witness for the leaf at `position` as of the given checkpoint depth,
/// persisting cap updates produced while computing sibling roots.
///
/// # Errors
///
/// * `QueryError::TreeIncomplete` (converted into `S::Error`) when no maximum leaf position
///   is available for `checkpoint_depth`, or when the shard containing `position` is absent
///   from the store.
/// * `QueryError::NotContained` (converted into `S::Error`) when `position` is greater than
///   the maximum leaf position.
/// * Any error returned by the store, by the shard's own witness computation, or by
///   [`Self::root_caching`].
///
/// # Panics
///
/// Panics if `MerklePath::from_parts` rejects the assembled witness.
pub fn witness_caching(
    &mut self,
    position: Position,
    checkpoint_depth: usize,
) -> Result<MerklePath<H, DEPTH>, S::Error> {
    let max_leaf_position = match self.max_leaf_position(checkpoint_depth)? {
        Some(pos) => pos,
        None => return Err(QueryError::TreeIncomplete(vec![Self::root_addr()]).into()),
    };

    if position > max_leaf_position {
        return Err(
            QueryError::NotContained(Address::from_parts(Level::from(0), position.into()))
                .into(),
        );
    }

    // Everything past the maximum leaf position is treated as empty.
    let truncate_at = max_leaf_position + 1;
    let subtree_addr = Address::above_position(Self::subtree_level(), position);

    // compute the witness for the specified position up to the subtree root
    let mut witness = match self.store.get_shard(subtree_addr)? {
        Some(subtree) => subtree.witness(position, truncate_at)?,
        None => return Err(QueryError::TreeIncomplete(vec![subtree_addr]).into()),
    };

    // compute the remaining parts of the witness up to the root: at each step, the root of
    // the current address's sibling is appended before moving to the parent
    let root_addr = Self::root_addr();
    let mut cur_addr = subtree_addr;
    while cur_addr != root_addr {
        let sibling_root = self.root_caching(cur_addr.sibling(), truncate_at)?;
        witness.push(sibling_root);
        cur_addr = cur_addr.parent();
    }

    Ok(MerklePath::from_parts(witness, position).unwrap())
}
/// Make a marked leaf at a position eligible to be pruned.
///
/// If the checkpoint associated with the specified identifier does not exist because the
@ -2844,6 +3065,8 @@ fn accumulate_result_with<A, B, C>(
pub mod testing {
use super::*;
use incrementalmerkletree::Hashable;
use proptest::bool::weighted;
use proptest::collection::vec;
use proptest::prelude::*;
use proptest::sample::select;
@ -2885,11 +3108,70 @@ pub mod testing {
})
})
}
/// Produces a strategy for random shardtrees of depth 6 with shards of height 3, holding at
/// most 2^6 leaves.
///
/// Each generated leaf is paired with two weighted booleans (weights 0.1 and 0.2) that
/// decide whether it is marked and whether it is a checkpoint, respectively. The leaves are
/// inserted in a single batch starting at position 0. The strategy yields the tree together
/// with the checkpoint positions and the marked positions, in that order.
pub fn arb_shardtree<H: Strategy>(
    arb_leaf: H,
) -> impl Strategy<
    Value = (
        ShardTree<MemoryShardStore<H::Value, usize>, 6, 3>,
        Vec<Position>,
        Vec<Position>,
    ),
>
where
    H::Value: Hashable + Clone + PartialEq,
{
    let arb_entries = vec(
        (arb_leaf, weighted(0.1), weighted(0.2)),
        0..=(2usize.pow(6)),
    );

    arb_entries.prop_map(|leaves| {
        let mut tree = ShardTree::empty(MemoryShardStore::empty(), 10);
        let mut checkpoint_positions = vec![];
        let mut marked_positions = vec![];

        // Pair every leaf with its retention, recording checkpoint and marked positions as
        // a side effect while the batch insert consumes the iterator.
        let insertions =
            leaves
                .into_iter()
                .enumerate()
                .map(|(id, (leaf, is_marked, is_checkpoint))| {
                    let retention = if is_checkpoint {
                        let pos = Position::try_from(id).unwrap();
                        checkpoint_positions.push(pos);
                        if is_marked {
                            marked_positions.push(pos);
                        }
                        Retention::Checkpoint { id, is_marked }
                    } else if is_marked {
                        marked_positions.push(Position::try_from(id).unwrap());
                        Retention::Marked
                    } else {
                        Retention::Ephemeral
                    };
                    (leaf, retention)
                });

        tree.batch_insert(Position::from(0), insertions).unwrap();
        (tree, checkpoint_positions, marked_positions)
    })
}
/// Produces a strategy for single-character strings drawn from the ASCII letters `a`-`z`
/// and `A`-`Z`.
pub fn arb_char_str() -> impl Strategy<Value = String> {
    const ALPHABET: &str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
    // Every character is one byte, so a one-byte inclusive slice at any in-range index is a
    // valid one-character string.
    (0usize..ALPHABET.len()).prop_map(|idx| ALPHABET[idx..=idx].to_owned())
}
}
#[cfg(test)]
mod tests {
use crate::{
testing::{arb_char_str, arb_shardtree},
IncompleteAt, InsertionError, LocatedPrunableTree, LocatedTree, MemoryShardStore,
MemoryShardStoreError, Node, PrunableTree, QueryError, RetentionFlags, ShardStore,
ShardTree, Tree,
@ -3460,6 +3742,30 @@ mod tests {
}
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(100))]

    // Checks that the caching root and witness computations return the same results as
    // their non-caching counterparts on randomly generated trees.
    #[test]
    fn check_shardtree_caching(
        (mut tree, _, marked_positions) in arb_shardtree(arb_char_str())
    ) {
        // An empty tree has no maximum leaf position, and so nothing to compare.
        if let Some(max_leaf_pos) = tree.max_leaf_position(0).unwrap() {
            let truncate_at = max_leaf_pos + 1;
            let max_complete_addr =
                Address::above_position(max_leaf_pos.root_level(), max_leaf_pos);

            let uncached_root = tree.root(max_complete_addr, truncate_at);
            let cached_root = tree.root_caching(max_complete_addr, truncate_at);
            assert_matches!(uncached_root, Ok(_));
            assert_eq!(uncached_root, cached_root);

            for marked_pos in marked_positions {
                let uncached_witness = tree.witness(marked_pos, 0);
                let cached_witness = tree.witness_caching(marked_pos, 0);
                assert_matches!(uncached_witness, Ok(_));
                assert_eq!(uncached_witness, cached_witness);
            }
        }
    }
}
#[test]
fn insert_frontier_nodes() {
let mut frontier = NonEmptyFrontier::new("a".to_string());
@ -3539,7 +3845,8 @@ mod tests {
);
assert_eq!(
r.root.root_hash(Address::from_parts(Level::from(3), 3), Position::from(25)),
r.root
.root_hash(Address::from_parts(Level::from(3), 3), Position::from(25)),
Ok("y_______".to_string())
);
}