Add caching of the "cap" to root & witness computation.
This commit is contained in:
parent
f8c13d17de
commit
290b66d5c8
|
@ -358,7 +358,7 @@ impl<H, const DEPTH: u8> CommitmentTree<H, DEPTH> {
|
|||
}
|
||||
|
||||
pub fn leaf(&self) -> Option<&H> {
|
||||
self.right.as_ref().or_else(|| self.left.as_ref())
|
||||
self.right.as_ref().or(self.left.as_ref())
|
||||
}
|
||||
|
||||
pub fn ommers_iter(&self) -> Box<dyn Iterator<Item = &'_ H> + '_> {
|
||||
|
|
|
@ -166,7 +166,7 @@ impl<A, V> Tree<A, V> {
|
|||
|
||||
/// Replaces the annotation at the root of the tree, if the root is a `Node::Parent`; otherwise
|
||||
/// returns this tree unaltered.
|
||||
pub fn reannotate_root(self, ann: A) -> Tree<A, V> {
|
||||
pub fn reannotate_root(self, ann: A) -> Self {
|
||||
Tree(self.0.reannotate(ann))
|
||||
}
|
||||
|
||||
|
@ -1407,15 +1407,13 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
|
|||
// the current address is a left child, so create a parent with
|
||||
// an empty right-hand tree
|
||||
subtree = Tree::parent(None, subtree, Tree::empty());
|
||||
} else if let Some(left) = ommers.next() {
|
||||
// the current address corresponds to a right child, so create a parent that
|
||||
// takes the left sibling to that child from the ommers
|
||||
subtree =
|
||||
Tree::parent(None, Tree::leaf((left, RetentionFlags::EPHEMERAL)), subtree);
|
||||
} else {
|
||||
if let Some(left) = ommers.next() {
|
||||
// the current address corresponds to a right child, so create a parent that
|
||||
// takes the left sibling to that child from the ommers
|
||||
subtree =
|
||||
Tree::parent(None, Tree::leaf((left, RetentionFlags::EPHEMERAL)), subtree);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
addr = addr.parent();
|
||||
|
@ -1429,7 +1427,7 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
|
|||
let located_supertree = if located_subtree.root_addr().level() == split_at {
|
||||
let mut addr = located_subtree.root_addr();
|
||||
let mut supertree = None;
|
||||
while let Some(left) = ommers.next() {
|
||||
for left in ommers {
|
||||
// build up the left-biased tree until we get a right-hand node
|
||||
while addr.index() & 0x1 == 0 {
|
||||
supertree = supertree.map(|t| Tree::parent(None, t, Tree::empty()));
|
||||
|
@ -1500,7 +1498,7 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
|
|||
// add filled nodes to the supertree
|
||||
let supertree = if addr.level() == split_at {
|
||||
let mut supertree = None;
|
||||
while let Some(right) = filled.next() {
|
||||
for right in filled {
|
||||
// build up the right-biased tree until we get a left-hand node
|
||||
while addr.index() & 0x1 == 1 {
|
||||
supertree = supertree.map(|t| Tree::parent(None, Tree::empty(), t));
|
||||
|
@ -1936,7 +1934,7 @@ impl<H, C: Ord> MemoryShardStore<H, C> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum MemoryShardStoreError {
|
||||
Insertion(InsertionError),
|
||||
Query(QueryError),
|
||||
|
@ -2250,10 +2248,7 @@ where
|
|||
&mut self,
|
||||
frontier: NonEmptyFrontier<H>,
|
||||
leaf_retention: Retention<C>,
|
||||
) -> Result<(), S::Error>
|
||||
where
|
||||
S::Error: From<InsertionError>,
|
||||
{
|
||||
) -> Result<(), S::Error> {
|
||||
let leaf_position = frontier.position();
|
||||
let subtree_root_addr = Address::above_position(Self::subtree_level(), leaf_position);
|
||||
|
||||
|
@ -2293,16 +2288,13 @@ where
|
|||
&mut self,
|
||||
witness: IncrementalWitness<H, DEPTH>,
|
||||
checkpoint_id: S::CheckpointId,
|
||||
) -> Result<(), S::Error>
|
||||
where
|
||||
S::Error: From<InsertionError>,
|
||||
{
|
||||
) -> Result<(), S::Error> {
|
||||
let leaf_position = witness.witnessed_position();
|
||||
let subtree_root_addr = Address::above_position(Self::subtree_level(), leaf_position);
|
||||
|
||||
let shard = self
|
||||
.store
|
||||
.get_shard(subtree_root_addr)
|
||||
.get_shard(subtree_root_addr)?
|
||||
.unwrap_or_else(|| LocatedTree::empty(subtree_root_addr));
|
||||
|
||||
let (updated_subtree, supertree, tip_subtree) =
|
||||
|
@ -2328,7 +2320,7 @@ where
|
|||
|
||||
let tip_shard = self
|
||||
.store
|
||||
.get_shard(tip_subtree_addr)
|
||||
.get_shard(tip_subtree_addr)?
|
||||
.unwrap_or_else(|| LocatedTree::empty(tip_subtree_addr));
|
||||
|
||||
self.store
|
||||
|
@ -2570,6 +2562,7 @@ where
|
|||
self.store
|
||||
.truncate(Address::from_parts(Self::subtree_level(), 0))?;
|
||||
self.store.truncate_checkpoints(&checkpoint_id)?;
|
||||
self.store.put_cap(Tree::empty())?;
|
||||
true
|
||||
}
|
||||
TreeState::AtPosition(position) => {
|
||||
|
@ -2579,13 +2572,23 @@ where
|
|||
.get_shard(subtree_addr)?
|
||||
.and_then(|s| s.truncate_to_position(position));
|
||||
|
||||
if let Some(truncated) = replacement {
|
||||
self.store.truncate(subtree_addr)?;
|
||||
self.store.put_shard(truncated)?;
|
||||
self.store.truncate_checkpoints(&checkpoint_id)?;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
let cap_tree = LocatedTree {
|
||||
root_addr: Self::root_addr(),
|
||||
root: self.store.get_cap()?,
|
||||
};
|
||||
|
||||
if let Some(truncated) = cap_tree.truncate_to_position(position) {
|
||||
self.store.put_cap(truncated.root)?;
|
||||
};
|
||||
|
||||
match replacement {
|
||||
Some(truncated) => {
|
||||
self.store.truncate(subtree_addr)?;
|
||||
self.store.put_shard(truncated)?;
|
||||
self.store.truncate_checkpoints(&checkpoint_id)?;
|
||||
true
|
||||
}
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2604,6 +2607,181 @@ where
|
|||
///
|
||||
/// Use [`Self::root_at_checkpoint`] to obtain the root of the overall tree.
|
||||
pub fn root(&self, address: Address, truncate_at: Position) -> Result<H, S::Error> {
|
||||
assert!(Self::root_addr().contains(&address));
|
||||
|
||||
// traverse the cap from root to leaf depth-first, either returning an existing
|
||||
// cached value for the node or inserting the computed value into the cache
|
||||
let (root, _) = self.root_internal(
|
||||
&LocatedPrunableTree {
|
||||
root_addr: Self::root_addr(),
|
||||
root: self.store.get_cap()?,
|
||||
},
|
||||
address,
|
||||
truncate_at,
|
||||
)?;
|
||||
Ok(root)
|
||||
}
|
||||
|
||||
pub fn root_caching(&mut self, address: Address, truncate_at: Position) -> Result<H, S::Error> {
|
||||
let (root, updated_cap) = self.root_internal(
|
||||
&LocatedPrunableTree {
|
||||
root_addr: Self::root_addr(),
|
||||
root: self.store.get_cap()?,
|
||||
},
|
||||
address,
|
||||
truncate_at,
|
||||
)?;
|
||||
if let Some(updated_cap) = updated_cap {
|
||||
self.store.put_cap(updated_cap)?;
|
||||
}
|
||||
Ok(root)
|
||||
}
|
||||
|
||||
// compute the root, along with an optional update to the cap
|
||||
fn root_internal(
|
||||
&self,
|
||||
cap: &LocatedPrunableTree<S::H>,
|
||||
// The address at which we want to compute the root hash
|
||||
target_addr: Address,
|
||||
// An inclusive lower bound for positions whose leaf values will be replaced by empty
|
||||
// roots.
|
||||
truncate_at: Position,
|
||||
) -> Result<(H, Option<PrunableTree<H>>), S::Error> {
|
||||
match &cap.root {
|
||||
Tree(Node::Parent { ann, left, right }) => {
|
||||
match ann {
|
||||
Some(cached_root) if target_addr.contains(&cap.root_addr) => {
|
||||
Ok((cached_root.as_ref().clone(), None))
|
||||
}
|
||||
_ => {
|
||||
// Compute the roots of the left and right children and hash them together.
|
||||
// We skip computation in any subtrees that will not have data included in
|
||||
// the final result.
|
||||
let (l_addr, r_addr) = cap.root_addr.children().unwrap();
|
||||
let l_result = if r_addr.contains(&target_addr) {
|
||||
None
|
||||
} else {
|
||||
Some(self.root_internal(
|
||||
&LocatedPrunableTree {
|
||||
root_addr: l_addr,
|
||||
root: left.as_ref().clone(),
|
||||
},
|
||||
if l_addr.contains(&target_addr) {
|
||||
target_addr
|
||||
} else {
|
||||
l_addr
|
||||
},
|
||||
truncate_at,
|
||||
)?)
|
||||
};
|
||||
let r_result = if l_addr.contains(&target_addr) {
|
||||
None
|
||||
} else {
|
||||
Some(self.root_internal(
|
||||
&LocatedPrunableTree {
|
||||
root_addr: r_addr,
|
||||
root: right.as_ref().clone(),
|
||||
},
|
||||
if r_addr.contains(&target_addr) {
|
||||
target_addr
|
||||
} else {
|
||||
r_addr
|
||||
},
|
||||
truncate_at,
|
||||
)?)
|
||||
};
|
||||
|
||||
// Compute the root value based on the child roots; these may contain the
|
||||
// hashes of empty/truncated nodes.
|
||||
let (root, new_left, new_right) = match (l_result, r_result) {
|
||||
(Some((l_root, new_left)), Some((r_root, new_right))) => (
|
||||
S::H::combine(l_addr.level(), &l_root, &r_root),
|
||||
new_left,
|
||||
new_right,
|
||||
),
|
||||
(Some((l_root, new_left)), None) => (l_root, new_left, None),
|
||||
(None, Some((r_root, new_right))) => (r_root, None, new_right),
|
||||
(None, None) => unreachable!(),
|
||||
};
|
||||
|
||||
let new_parent = Tree(Node::Parent {
|
||||
ann: new_left
|
||||
.as_ref()
|
||||
.and_then(|l| l.node_value())
|
||||
.zip(new_right.as_ref().and_then(|r| r.node_value()))
|
||||
.map(|(l, r)| {
|
||||
// the node values of child nodes cannot contain the hashes of
|
||||
// empty nodes or nodes with positions greater than the
|
||||
Rc::new(S::H::combine(l_addr.level(), l, r))
|
||||
}),
|
||||
left: new_left.map_or_else(|| left.clone(), Rc::new),
|
||||
right: new_right.map_or_else(|| right.clone(), Rc::new),
|
||||
});
|
||||
|
||||
Ok((root, Some(new_parent)))
|
||||
}
|
||||
}
|
||||
}
|
||||
Tree(Node::Leaf { value }) => {
|
||||
if truncate_at >= cap.root_addr.position_range_end()
|
||||
&& target_addr.contains(&cap.root_addr)
|
||||
{
|
||||
// no truncation or computation of child subtrees of this leaf is necessary, just use
|
||||
// the cached leaf value
|
||||
Ok((value.0.clone(), None))
|
||||
} else {
|
||||
// since the tree was truncated below this level, recursively call with an
|
||||
// empty parent node to trigger the continued traversal
|
||||
let (root, replacement) = self.root_internal(
|
||||
&LocatedPrunableTree {
|
||||
root_addr: cap.root_addr(),
|
||||
root: Tree::parent(None, Tree::empty(), Tree::empty()),
|
||||
},
|
||||
target_addr,
|
||||
truncate_at,
|
||||
)?;
|
||||
|
||||
Ok((
|
||||
root,
|
||||
replacement.map(|r| r.reannotate_root(Some(Rc::new(value.0.clone())))),
|
||||
))
|
||||
}
|
||||
}
|
||||
Tree(Node::Nil) => {
|
||||
if cap.root_addr == target_addr
|
||||
|| cap.root_addr.level() == ShardTree::<S, DEPTH, SHARD_HEIGHT>::subtree_level()
|
||||
{
|
||||
// We are at the leaf level or the target address; compute the root hash and
|
||||
// return it as cacheable if it is not truncated.
|
||||
let root = self.root_from_shards(target_addr, truncate_at)?;
|
||||
Ok((
|
||||
root.clone(),
|
||||
if truncate_at >= cap.root_addr.position_range_end() {
|
||||
// return the compute root as a new leaf to be cached if it contains no
|
||||
// empty hashes due to truncation
|
||||
Some(Tree::leaf((root, RetentionFlags::EPHEMERAL)))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
))
|
||||
} else {
|
||||
// Compute the result by recursively walking down the tree. By replacing
|
||||
// the current node with a parent node, the `Parent` handler will take care
|
||||
// of the branching recursive calls.
|
||||
self.root_internal(
|
||||
&LocatedPrunableTree {
|
||||
root_addr: cap.root_addr,
|
||||
root: Tree::parent(None, Tree::empty(), Tree::empty()),
|
||||
},
|
||||
target_addr,
|
||||
truncate_at,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn root_from_shards(&self, address: Address, truncate_at: Position) -> Result<H, S::Error> {
|
||||
match address.context(Self::subtree_level()) {
|
||||
Either::Left(subtree_addr) => {
|
||||
// The requested root address is fully contained within one of the subtrees.
|
||||
|
@ -2739,6 +2917,13 @@ where
|
|||
)
|
||||
}
|
||||
|
||||
pub fn root_at_checkpoint_caching(&mut self, checkpoint_depth: usize) -> Result<H, S::Error> {
|
||||
self.max_leaf_position(checkpoint_depth)?.map_or_else(
|
||||
|| Ok(H::empty_root(Self::root_addr().level())),
|
||||
|pos| self.root_caching(Self::root_addr(), pos + 1),
|
||||
)
|
||||
}
|
||||
|
||||
/// Computes the witness for the leaf at the specified position.
|
||||
///
|
||||
/// Returns the witness as of the most recently appended leaf if `checkpoint_depth == 0`. Note
|
||||
|
@ -2750,13 +2935,14 @@ where
|
|||
checkpoint_depth: usize,
|
||||
) -> Result<MerklePath<H, DEPTH>, S::Error> {
|
||||
let max_leaf_position = self.max_leaf_position(checkpoint_depth).and_then(|v| {
|
||||
v.ok_or_else(|| S::Error::from(QueryError::TreeIncomplete(vec![Self::root_addr()])))
|
||||
v.ok_or_else(|| QueryError::TreeIncomplete(vec![Self::root_addr()]).into())
|
||||
})?;
|
||||
|
||||
if position > max_leaf_position {
|
||||
Err(S::Error::from(QueryError::NotContained(
|
||||
Address::from_parts(Level::from(0), position.into()),
|
||||
)))
|
||||
Err(
|
||||
QueryError::NotContained(Address::from_parts(Level::from(0), position.into()))
|
||||
.into(),
|
||||
)
|
||||
} else {
|
||||
let subtree_addr = Self::subtree_addr(position);
|
||||
|
||||
|
@ -2778,6 +2964,41 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub fn witness_caching(
|
||||
&mut self,
|
||||
position: Position,
|
||||
checkpoint_depth: usize,
|
||||
) -> Result<MerklePath<H, DEPTH>, S::Error> {
|
||||
let max_leaf_position = self.max_leaf_position(checkpoint_depth).and_then(|v| {
|
||||
v.ok_or_else(|| QueryError::TreeIncomplete(vec![Self::root_addr()]).into())
|
||||
})?;
|
||||
|
||||
if position > max_leaf_position {
|
||||
Err(
|
||||
QueryError::NotContained(Address::from_parts(Level::from(0), position.into()))
|
||||
.into(),
|
||||
)
|
||||
} else {
|
||||
let subtree_addr = Address::above_position(Self::subtree_level(), position);
|
||||
|
||||
// compute the witness for the specified position up to the subtree root
|
||||
let mut witness = self.store.get_shard(subtree_addr)?.map_or_else(
|
||||
|| Err(QueryError::TreeIncomplete(vec![subtree_addr])),
|
||||
|subtree| subtree.witness(position, max_leaf_position + 1),
|
||||
)?;
|
||||
|
||||
// compute the remaining parts of the witness up to the root
|
||||
let root_addr = Self::root_addr();
|
||||
let mut cur_addr = subtree_addr;
|
||||
while cur_addr != root_addr {
|
||||
witness.push(self.root_caching(cur_addr.sibling(), max_leaf_position + 1)?);
|
||||
cur_addr = cur_addr.parent();
|
||||
}
|
||||
|
||||
Ok(MerklePath::from_parts(witness, position).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
/// Make a marked leaf at a position eligible to be pruned.
|
||||
///
|
||||
/// If the checkpoint associated with the specified identifier does not exist because the
|
||||
|
@ -2844,6 +3065,8 @@ fn accumulate_result_with<A, B, C>(
|
|||
pub mod testing {
|
||||
use super::*;
|
||||
use incrementalmerkletree::Hashable;
|
||||
use proptest::bool::weighted;
|
||||
use proptest::collection::vec;
|
||||
use proptest::prelude::*;
|
||||
use proptest::sample::select;
|
||||
|
||||
|
@ -2885,11 +3108,70 @@ pub mod testing {
|
|||
})
|
||||
})
|
||||
}
|
||||
|
||||
/// Constructs a random shardtree of size up to 2^6 with shards of size 2^3. Returns the tree,
|
||||
/// along with vectors of the checkpoint and mark positions.
|
||||
pub fn arb_shardtree<H: Strategy>(
|
||||
arb_leaf: H,
|
||||
) -> impl Strategy<
|
||||
Value = (
|
||||
ShardTree<MemoryShardStore<H::Value, usize>, 6, 3>,
|
||||
Vec<Position>,
|
||||
Vec<Position>,
|
||||
),
|
||||
>
|
||||
where
|
||||
H::Value: Hashable + Clone + PartialEq,
|
||||
{
|
||||
vec(
|
||||
(arb_leaf, weighted(0.1), weighted(0.2)),
|
||||
0..=(2usize.pow(6)),
|
||||
)
|
||||
.prop_map(|leaves| {
|
||||
let mut tree = ShardTree::empty(MemoryShardStore::empty(), 10);
|
||||
let mut checkpoint_positions = vec![];
|
||||
let mut marked_positions = vec![];
|
||||
tree.batch_insert(
|
||||
Position::from(0),
|
||||
leaves
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(id, (leaf, is_marked, is_checkpoint))| {
|
||||
(
|
||||
leaf,
|
||||
match (is_checkpoint, is_marked) {
|
||||
(false, false) => Retention::Ephemeral,
|
||||
(true, is_marked) => {
|
||||
let pos = Position::try_from(id).unwrap();
|
||||
checkpoint_positions.push(pos);
|
||||
if is_marked {
|
||||
marked_positions.push(pos);
|
||||
}
|
||||
Retention::Checkpoint { id, is_marked }
|
||||
}
|
||||
(false, true) => {
|
||||
marked_positions.push(Position::try_from(id).unwrap());
|
||||
Retention::Marked
|
||||
}
|
||||
},
|
||||
)
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
(tree, checkpoint_positions, marked_positions)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn arb_char_str() -> impl Strategy<Value = String> {
|
||||
let chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
|
||||
(0usize..chars.len()).prop_map(move |i| chars.get(i..=i).unwrap().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
testing::{arb_char_str, arb_shardtree},
|
||||
IncompleteAt, InsertionError, LocatedPrunableTree, LocatedTree, MemoryShardStore,
|
||||
MemoryShardStoreError, Node, PrunableTree, QueryError, RetentionFlags, ShardStore,
|
||||
ShardTree, Tree,
|
||||
|
@ -3460,6 +3742,30 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#![proptest_config(ProptestConfig::with_cases(100))]
|
||||
|
||||
#[test]
|
||||
fn check_shardtree_caching(
|
||||
(mut tree, _, marked_positions) in arb_shardtree(arb_char_str())
|
||||
) {
|
||||
if let Some(max_leaf_pos) = tree.max_leaf_position(0).unwrap() {
|
||||
let max_complete_addr = Address::above_position(max_leaf_pos.root_level(), max_leaf_pos);
|
||||
let root = tree.root(max_complete_addr, max_leaf_pos + 1);
|
||||
let caching_root = tree.root_caching(max_complete_addr, max_leaf_pos + 1);
|
||||
assert_matches!(root, Ok(_));
|
||||
assert_eq!(root, caching_root);
|
||||
|
||||
for pos in marked_positions {
|
||||
let witness = tree.witness(pos, 0);
|
||||
let caching_witness = tree.witness_caching(pos, 0);
|
||||
assert_matches!(witness, Ok(_));
|
||||
assert_eq!(witness, caching_witness);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn insert_frontier_nodes() {
|
||||
let mut frontier = NonEmptyFrontier::new("a".to_string());
|
||||
|
@ -3539,7 +3845,8 @@ mod tests {
|
|||
);
|
||||
|
||||
assert_eq!(
|
||||
r.root.root_hash(Address::from_parts(Level::from(3), 3), Position::from(25)),
|
||||
r.root
|
||||
.root_hash(Address::from_parts(Level::from(3), 3), Position::from(25)),
|
||||
Ok("y_______".to_string())
|
||||
);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue