Merge pull request #82 from zcash/shardtree-bugfixes

`shardtree` bugfixes
This commit is contained in:
str4d 2023-07-06 17:34:18 +01:00 committed by GitHub
commit 67111e2940
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 393 additions and 95 deletions

View File

@ -5,6 +5,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to Rust's notion of
[Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## Unreleased
### Fixed
- `incrementalmerkletree::Address::common_ancestor` no longer produces incorrect
results for some pairs of addresses. It was previously using an arithmetic
distance between indices within a level, instead of a bitwise distance.
### Changed
- `incrementalmerkletree::Hashable` trait now has a `Debug` bound.
## [0.4.0] - 2023-06-05
Release 0.4.0 represents a substantial rewrite of the `incrementalmerkletree`

View File

@ -46,6 +46,7 @@
use either::Either;
use std::cmp::Ordering;
use std::convert::{TryFrom, TryInto};
use std::fmt;
use std::num::TryFromIntError;
use std::ops::{Add, AddAssign, Range, Sub};
@ -383,22 +384,23 @@ impl Address {
/// Returns the common ancestor of `self` and `other` having the smallest level value.
pub fn common_ancestor(&self, other: &Self) -> Self {
if self.level >= other.level {
let other_ancestor_idx = other.index >> (self.level.0 - other.level.0);
let index_delta = self.index.abs_diff(other_ancestor_idx);
let level_delta = (u64::BITS - index_delta.leading_zeros()) as u8;
Address {
level: self.level + level_delta,
index: std::cmp::max(self.index, other_ancestor_idx) >> level_delta,
}
// We can leverage the symmetry of a binary tree to share the calculation logic,
// by ordering the nodes.
let (higher, lower) = if self.level >= other.level {
(self, other)
} else {
let self_ancestor_idx = self.index >> (other.level.0 - self.level.0);
let index_delta = other.index.abs_diff(self_ancestor_idx);
let level_delta = (u64::BITS - index_delta.leading_zeros()) as u8;
Address {
level: other.level + level_delta,
index: std::cmp::max(other.index, self_ancestor_idx) >> level_delta,
}
(other, self)
};
// We follow the lower node's subtree up to the same level as the higher node, and
// then use their XOR distance to determine how many levels of the tree their
// Merkle paths differ on.
let lower_ancestor_idx = lower.index >> (higher.level.0 - lower.level.0);
let index_delta = higher.index ^ lower_ancestor_idx;
let level_delta = (u64::BITS - index_delta.leading_zeros()) as u8;
Address {
level: higher.level + level_delta,
index: std::cmp::max(higher.index, lower_ancestor_idx) >> level_delta,
}
}
@ -602,7 +604,7 @@ impl<H: Hashable, const DEPTH: u8> MerklePath<H, DEPTH> {
/// A trait describing the operations that make a type suitable for use as
/// a leaf or node value in a merkle tree.
pub trait Hashable {
pub trait Hashable: fmt::Debug {
fn empty_leaf() -> Self;
fn combine(level: Level, a: &Self, b: &Self) -> Self;
@ -826,25 +828,115 @@ pub(crate) mod tests {
#[test]
fn addr_common_ancestor() {
// rt
// --------------- ----------------
// ------- ------- right -------
// ----- left ----- ----- ----- ----- ----- -----
// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
assert_eq!(
Address::from_parts(Level(2), 1).common_ancestor(&Address::from_parts(Level(3), 2)),
Address::from_parts(Level(5), 0)
);
// --------------------------
// --------------- ----------------
// ------- rt ------- -------
// ----- ----- left ----- ----- ----- ----- -----
// -- -- -- -- -- -- -- rg -- -- -- -- -- -- -- --
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
assert_eq!(
Address::from_parts(Level(2), 2).common_ancestor(&Address::from_parts(Level(1), 7)),
Address::from_parts(Level(3), 1)
);
// --------------------------
// --------------- ----------------
// ------- rt ------- -------
// ----- ----- left ----- ----- ----- ----- -----
// -- -- -- -- -- -- rg -- -- -- -- -- -- -- -- --
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
assert_eq!(
Address::from_parts(Level(2), 2).common_ancestor(&Address::from_parts(Level(1), 6)),
Address::from_parts(Level(3), 1)
);
// --------------------------
// --------------- ----------------
// ------- ------- ------- -------
// ----- ----- all ----- ----- ----- ----- -----
// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
assert_eq!(
Address::from_parts(Level(2), 2).common_ancestor(&Address::from_parts(Level(2), 2)),
Address::from_parts(Level(2), 2)
);
// --------------------------
// --------------- ----------------
// ------- ------- ------- -------
// ----- ----- lf,rt ----- ----- ----- ----- -----
// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
// - - - - - - - - - rg - - - - - - - - - - - - - - - - - - - - - -
assert_eq!(
Address::from_parts(Level(2), 2).common_ancestor(&Address::from_parts(Level(0), 9)),
Address::from_parts(Level(2), 2)
);
// --------------------------
// --------------- ----------------
// ------- ------- ------- -------
// ----- ----- rg,rt ----- ----- ----- ----- -----
// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
// - - - - - - - - - lf - - - - - - - - - - - - - - - - - - - - - -
assert_eq!(
Address::from_parts(Level(0), 9).common_ancestor(&Address::from_parts(Level(2), 2)),
Address::from_parts(Level(2), 2)
);
// --------------------------
// --------------- ----------------
// ------- ------- ------- -------
// ----- ----- ----- rt ----- ----- ----- -----
// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
// - - - - - - - - - - - - lf - - rg - - - - - - - - - - - - - - - -
assert_eq!(
Address::from_parts(Level(0), 12).common_ancestor(&Address::from_parts(Level(0), 15)),
Address::from_parts(Level(2), 3)
);
// --------------------------
// --------------- ----------------
// ------- ------- ------- -------
// ----- ----- ----- rt ----- ----- ----- -----
// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
// - - - - - - - - - - - - - lf - rg - - - - - - - - - - - - - - - -
assert_eq!(
Address::from_parts(Level(0), 13).common_ancestor(&Address::from_parts(Level(0), 15)),
Address::from_parts(Level(2), 3)
);
// --------------------------
// --------------- ----------------
// ------- ------- ------- -------
// ----- ----- ----- rt ----- ----- ----- -----
// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
// - - - - - - - - - - - - - lf rg - - - - - - - - - - - - - - - - -
assert_eq!(
Address::from_parts(Level(0), 13).common_ancestor(&Address::from_parts(Level(0), 14)),
Address::from_parts(Level(2), 3)
);
// --------------------------
// --------------- ----------------
// ------- ------- ------- -------
// ----- ----- ----- ----- ----- ----- ----- -----
// -- -- -- -- -- -- -- rt -- -- -- -- -- -- -- --
// - - - - - - - - - - - - - - lf rg - - - - - - - - - - - - - - - -
assert_eq!(
Address::from_parts(Level(0), 14).common_ancestor(&Address::from_parts(Level(0), 15)),
Address::from_parts(Level(1), 7)
);
// rt
// --------------- ----------------
// ------- ------- ------- -------
// ----- ----- ----- ----- ----- ----- ----- -----
// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
// - - - - - - - - - - - - - - - lf rg - - - - - - - - - - - - - - -
assert_eq!(
Address::from_parts(Level(0), 15).common_ancestor(&Address::from_parts(Level(0), 16)),
Address::from_parts(Level(5), 0)
);
}
}

View File

@ -17,6 +17,7 @@ bitflags = "1.3"
either = "1.8"
incrementalmerkletree = { version = "0.4", path = "../incrementalmerkletree" }
proptest = { version = "1.0.0", optional = true }
tracing = "0.1"
[dev-dependencies]
assert_matches = "1.5"

View File

@ -2,6 +2,7 @@ use core::fmt::{self, Debug, Display};
use either::Either;
use std::collections::{BTreeMap, BTreeSet};
use std::rc::Rc;
use tracing::trace;
use incrementalmerkletree::{
frontier::NonEmptyFrontier, Address, Hashable, Level, MerklePath, Position, Retention,
@ -349,7 +350,7 @@ where
impl<
H: Hashable + Clone + PartialEq,
C: Clone + Ord,
C: Clone + Debug + Ord,
S: ShardStore<H = H, CheckpointId = C>,
const DEPTH: u8,
const SHARD_HEIGHT: u8,
@ -534,6 +535,7 @@ impl<
) -> Result<(), ShardTreeError<S::Error>> {
let leaf_position = frontier.position();
let subtree_root_addr = Address::above_position(Self::subtree_level(), leaf_position);
trace!("Subtree containing nodes: {:?}", subtree_root_addr);
let (updated_subtree, supertree) = self
.store
@ -559,6 +561,7 @@ impl<
}
if let Retention::Checkpoint { id, is_marked: _ } = leaf_retention {
trace!("Adding checkpoint {:?} at {:?}", id, leaf_position);
self.store
.add_checkpoint(id, Checkpoint::at_position(leaf_position))
.map_err(ShardTreeError::Storage)?;
@ -646,6 +649,7 @@ impl<
mut start: Position,
values: I,
) -> Result<Option<(Position, Vec<IncompleteAt>)>, ShardTreeError<S::Error>> {
trace!("Batch inserting from {:?}", start);
let mut values = values.peekable();
let mut subtree_root_addr = Self::subtree_addr(start);
let mut max_insert_position = None;
@ -801,20 +805,34 @@ impl<
.store
.checkpoint_count()
.map_err(ShardTreeError::Storage)?;
trace!(
"Tree has {} checkpoints, max is {}",
checkpoint_count,
self.max_checkpoints,
);
if checkpoint_count > self.max_checkpoints {
// Batch removals by subtree & create a list of the checkpoint identifiers that
// will be removed from the checkpoints map.
let remove_count = checkpoint_count - self.max_checkpoints;
let mut checkpoints_to_delete = vec![];
let mut clear_positions: BTreeMap<Address, BTreeMap<Position, RetentionFlags>> =
BTreeMap::new();
self.store
.with_checkpoints(
checkpoint_count - self.max_checkpoints,
|cid, checkpoint| {
checkpoints_to_delete.push(cid.clone());
.with_checkpoints(checkpoint_count, |cid, checkpoint| {
// When removing is true, we are iterating through the range of
// checkpoints being removed. When remove is false, we are
// iterating through the range of checkpoints that are being
// retained.
let removing = checkpoints_to_delete.len() < remove_count;
let mut clear_at = |pos, flags_to_clear| {
let subtree_addr = Self::subtree_addr(pos);
if removing {
checkpoints_to_delete.push(cid.clone());
}
let mut clear_at = |pos, flags_to_clear| {
let subtree_addr = Self::subtree_addr(pos);
if removing {
// Mark flags to be cleared from the given position.
clear_positions
.entry(subtree_addr)
.and_modify(|to_clear| {
@ -824,23 +842,43 @@ impl<
.or_insert(flags_to_clear);
})
.or_insert_with(|| BTreeMap::from([(pos, flags_to_clear)]));
};
// clear the checkpoint leaf
if let TreeState::AtPosition(pos) = checkpoint.tree_state {
clear_at(pos, RetentionFlags::CHECKPOINT)
} else {
// Unmark flags that might have been marked for clearing above
// but which we now know we need to preserve.
if let Some(to_clear) = clear_positions.get_mut(&subtree_addr) {
if let Some(flags) = to_clear.get_mut(&pos) {
*flags &= !flags_to_clear;
}
}
}
};
// clear the leaves that have been marked for removal
for unmark_pos in checkpoint.marks_removed.iter() {
clear_at(*unmark_pos, RetentionFlags::MARKED)
}
// Clear or preserve the checkpoint leaf.
if let TreeState::AtPosition(pos) = checkpoint.tree_state {
clear_at(pos, RetentionFlags::CHECKPOINT)
}
Ok(())
},
)
// Clear or preserve the leaves that have been marked for removal.
for unmark_pos in checkpoint.marks_removed.iter() {
clear_at(*unmark_pos, RetentionFlags::MARKED)
}
Ok(())
})
.map_err(ShardTreeError::Storage)?;
// Remove any nodes that are fully preserved by later checkpoints.
clear_positions.retain(|_, to_clear| {
to_clear.retain(|_, flags| !flags.is_empty());
!to_clear.is_empty()
});
trace!(
"Removing checkpoints {:?}, pruning subtrees {:?}",
checkpoints_to_delete,
clear_positions,
);
// Prune each affected subtree
for (subtree_addr, positions) in clear_positions.into_iter() {
let cleared = self
@ -1569,6 +1607,39 @@ mod tests {
check_rewind_remove_mark(new_tree);
}
#[test]
fn checkpoint_pruning_repeated() {
// Create a tree with some leaves.
let mut tree = new_tree(10);
for c in 'a'..='c' {
tree.append(c.to_string(), Retention::Ephemeral).unwrap();
}
// Repeatedly checkpoint the tree at the same position until the checkpoint cache
// is full (creating a sequence of checkpoints in between which no new leaves were
// appended to the tree).
for i in 0..10 {
assert_eq!(tree.checkpoint(i), Ok(true));
}
// Create one more checkpoint at the same position, causing the oldest in the
// cache to be pruned.
assert_eq!(tree.checkpoint(10), Ok(true));
// Append a leaf to the tree and checkpoint it, causing the next oldest in the
// cache to be pruned.
assert_eq!(
tree.append(
'd'.to_string(),
Retention::Checkpoint {
id: 11,
is_marked: false
},
),
Ok(()),
);
}
// Combined tree tests
#[allow(clippy::type_complexity)]
fn new_combined_tree<H: Hashable + Ord + Clone + core::fmt::Debug>(

View File

@ -7,6 +7,7 @@ use bitflags::bitflags;
use incrementalmerkletree::{
frontier::NonEmptyFrontier, Address, Hashable, Level, Position, Retention,
};
use tracing::trace;
use crate::{LocatedTree, Node, Tree};
@ -217,21 +218,23 @@ impl<H: Hashable + Clone + PartialEq> PrunableTree<H> {
match (t0, t1) {
(Tree(Node::Nil), other) | (other, Tree(Node::Nil)) => Ok(other),
(Tree(Node::Leaf { value: vl }), Tree(Node::Leaf { value: vr })) => {
if vl == vr {
Ok(Tree(Node::Leaf { value: vl }))
if vl.0 == vr.0 {
// Merge the flags together.
Ok(Tree(Node::Leaf {
value: (vl.0, vl.1 | vr.1),
}))
} else {
trace!(left = ?vl.0, right = ?vr.0, "Merge conflict for leaves");
Err(addr)
}
}
(Tree(Node::Leaf { value }), parent @ Tree(Node::Parent { .. }))
| (parent @ Tree(Node::Parent { .. }), Tree(Node::Leaf { value })) => {
if parent
.root_hash(addr, no_default_fill)
.iter()
.all(|r| r == &value.0)
{
let parent_hash = parent.root_hash(addr, no_default_fill);
if parent_hash.iter().all(|r| r == &value.0) {
Ok(parent.reannotate_root(Some(Rc::new(value.0))))
} else {
trace!(leaf = ?value, node = ?parent_hash, "Merge conflict for leaf into node");
Err(addr)
}
}
@ -240,7 +243,7 @@ impl<H: Hashable + Clone + PartialEq> PrunableTree<H> {
let rroot = rparent.root_hash(addr, no_default_fill).ok();
// If both parents share the same root hash (or if one of them is absent),
// they can be merged
if lroot.zip(rroot).iter().all(|(l, r)| l == r) {
if lroot.iter().zip(&rroot).all(|(l, r)| l == r) {
// using `if let` here to bind variables; we need to borrow the trees for
// root hash calculation but binding the children of the parent node
// interferes with binding a reference to the parent.
@ -268,12 +271,14 @@ impl<H: Hashable + Clone + PartialEq> PrunableTree<H> {
unreachable!()
}
} else {
trace!(left = ?lroot, right = ?rroot, "Merge conflict for nodes");
Err(addr)
}
}
}
}
trace!(this = ?self, other = ?other, "Merging subtrees");
go(root_addr, self, other)
}
@ -688,6 +693,11 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
(node.root.reannotate_root(ann), incomplete)
};
trace!(
"Node at {:?} contains subtree at {:?}",
root_addr,
subtree.root_addr(),
);
match into {
Tree(Node::Nil) => Ok(replacement(None, subtree)),
Tree(Node::Leaf { value: (value, _) }) => {
@ -696,20 +706,18 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
// It is safe to replace the existing root unannotated, because we
// can always recompute the root from a complete subtree.
Ok((subtree.root, vec![]))
} else if subtree
.root
.0
.annotation()
.and_then(|ann| ann.as_ref())
.iter()
.all(|v| v.as_ref() == value)
{
} else if subtree.root.node_value().iter().all(|v| v == &value) {
Ok((
// at this point we statically know the root to be a parent
subtree.root.reannotate_root(Some(Rc::new(value.clone()))),
vec![],
))
} else {
trace!(
cur_root = ?value,
new_root = ?subtree.root.node_value(),
"Insertion conflict",
);
Err(InsertionError::Conflict(root_addr))
}
} else {
@ -866,6 +874,12 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
prune_below: Level,
mut values: I,
) -> Option<BatchInsertionResult<H, C, I>> {
trace!(
position_range = ?position_range,
prune_below = ?prune_below,
"Creating minimal tree for insertion"
);
// Unite two subtrees by either adding a parent node, or a leaf containing the Merkle root
// of such a parent if both nodes are ephemeral leaves.
//
@ -876,6 +890,7 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
rroot: LocatedPrunableTree<H>,
prune_below: Level,
) -> LocatedTree<Option<Rc<H>>, (H, RetentionFlags)> {
assert_eq!(lroot.root_addr.parent(), rroot.root_addr.parent());
LocatedTree {
root_addr: lroot.root_addr.parent(),
root: if lroot.root_addr.level() < prune_below {
@ -890,12 +905,43 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
}
}
/// Combines the given subtree with an empty sibling node to obtain the next level
/// subtree.
///
/// `expect_left_child` is set to a constant at each callsite, to ensure that this
/// function is only called on either the left-most or right-most subtree.
fn combine_with_empty<H: Hashable + Clone + PartialEq>(
root: LocatedPrunableTree<H>,
expect_left_child: bool,
incomplete: &mut Vec<IncompleteAt>,
contains_marked: bool,
prune_below: Level,
) -> LocatedPrunableTree<H> {
assert_eq!(expect_left_child, root.root_addr.is_left_child());
let sibling_addr = root.root_addr.sibling();
incomplete.push(IncompleteAt {
address: sibling_addr,
required_for_witness: contains_marked,
});
let sibling = LocatedTree {
root_addr: sibling_addr,
root: Tree(Node::Nil),
};
let (lroot, rroot) = if root.root_addr.is_left_child() {
(root, sibling)
} else {
(sibling, root)
};
unite(lroot, rroot, prune_below)
}
// Builds a single tree from the provided stack of subtrees, which must be non-overlapping
// and in position order. Returns the resulting tree, a flag indicating whether the
// resulting tree contains a `MARKED` node, and the vector of [`IncompleteAt`] values for
// [`Node::Nil`] nodes that were introduced in the process of constructing the tree.
fn build_minimal_tree<H: Hashable + Clone + PartialEq>(
mut xs: Vec<(LocatedPrunableTree<H>, bool)>,
root_addr: Address,
prune_below: Level,
) -> Option<(LocatedPrunableTree<H>, bool, Vec<IncompleteAt>)> {
// First, consume the stack from the right, building up a single tree
@ -904,19 +950,8 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
let mut incomplete = vec![];
while let Some((top, top_marked)) = xs.pop() {
while cur.root_addr.level() < top.root_addr.level() {
let sibling_addr = cur.root_addr.sibling();
incomplete.push(IncompleteAt {
address: sibling_addr,
required_for_witness: top_marked,
});
cur = unite(
cur,
LocatedTree {
root_addr: sibling_addr,
root: Tree(Node::Nil),
},
prune_below,
);
cur =
combine_with_empty(cur, true, &mut incomplete, top_marked, prune_below);
}
if cur.root_addr.level() == top.root_addr.level() {
@ -929,17 +964,11 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
// we've merged as much as we can from the right and need to work from
// the left
xs.push((top, top_marked));
let sibling_addr = cur.root_addr.sibling();
incomplete.push(IncompleteAt {
address: sibling_addr,
required_for_witness: top_marked,
});
cur = unite(
cur = combine_with_empty(
cur,
LocatedTree {
root_addr: sibling_addr,
root: Tree(Node::Nil),
},
true,
&mut incomplete,
top_marked,
prune_below,
);
break;
@ -952,6 +981,17 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
}
}
// Ensure we can work from the left in a single pass by making this right-most subtree
while cur.root_addr.level() + 1 < root_addr.level() {
cur = combine_with_empty(
cur,
true,
&mut incomplete,
contains_marked,
prune_below,
);
}
// push our accumulated max-height right hand node back on to the stack.
xs.push((cur, contains_marked));
@ -964,23 +1004,16 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
// add nil branches to build up the left tree until we can merge it
// with the right
while prev_tree.root_addr.level() < next_tree.root_addr.level() {
let sibling_addr = prev_tree.root_addr.sibling();
contains_marked = contains_marked || next_marked;
incomplete.push(IncompleteAt {
address: sibling_addr,
required_for_witness: next_marked,
});
prev_tree = unite(
LocatedTree {
root_addr: sibling_addr,
root: Tree(Node::Nil),
},
prev_tree = combine_with_empty(
prev_tree,
false,
&mut incomplete,
next_marked,
prune_below,
);
}
// at this point, prev_tree.level == next_tree.level
Some(unite(prev_tree, next_tree, prune_below))
} else {
Some(next_tree)
@ -1034,17 +1067,26 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
break;
}
}
trace!("Initial fragments: {:?}", fragments);
build_minimal_tree(fragments, prune_below).map(
|(to_insert, contains_marked, incomplete)| BatchInsertionResult {
subtree: to_insert,
contains_marked,
incomplete,
max_insert_position: Some(position - 1),
checkpoints,
remainder: values,
},
)
if position > position_range.start {
let last_position = position - 1;
let minimal_tree_addr =
Address::from(position_range.start).common_ancestor(&last_position.into());
trace!("Building minimal tree at {:?}", minimal_tree_addr);
build_minimal_tree(fragments, minimal_tree_addr, prune_below).map(
|(to_insert, contains_marked, incomplete)| BatchInsertionResult {
subtree: to_insert,
contains_marked,
incomplete,
max_insert_position: Some(last_position),
checkpoints,
remainder: values,
},
)
} else {
None
}
}
/// Put a range of values into the subtree by consuming the given iterator, starting at the
@ -1062,6 +1104,7 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
start: Position,
values: I,
) -> Result<Option<BatchInsertionResult<H, C, I>>, InsertionError> {
trace!("Batch inserting into {:?} from {:?}", self.root_addr, start);
let subtree_range = self.root_addr.position_range();
let contains_start = subtree_range.contains(&start);
if contains_start {
@ -1358,6 +1401,12 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
let (l_addr, r_addr) = root_addr.children().unwrap();
let p = to_clear.partition_point(|(p, _)| p < &l_addr.position_range_end());
trace!(
"In {:?}, partitioned: {:?} {:?}",
root_addr,
&to_clear[0..p],
&to_clear[p..],
);
Tree::unite(
l_addr.level(),
ann.clone(),
@ -1366,6 +1415,7 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
)
}
Tree(Node::Leaf { value: (h, r) }) => {
trace!("In {:?}, clearing {:?}", root_addr, to_clear);
// When we reach a leaf, we should be down to just a single position
// which should correspond to the last level-0 child of the address's
// subtree range; if it's a checkpoint this will always be the case for
@ -1527,6 +1577,26 @@ mod tests {
);
}
#[test]
fn merge_checked_flags() {
let t0: PrunableTree<String> = leaf(("a".to_string(), RetentionFlags::EPHEMERAL));
let t1: PrunableTree<String> = leaf(("a".to_string(), RetentionFlags::MARKED));
let t2: PrunableTree<String> = leaf(("a".to_string(), RetentionFlags::CHECKPOINT));
assert_eq!(
t0.merge_checked(Address::from_parts(1.into(), 0), t1.clone()),
Ok(t1.clone()),
);
assert_eq!(
t1.merge_checked(Address::from_parts(1.into(), 0), t2),
Ok(leaf((
"a".to_string(),
RetentionFlags::MARKED | RetentionFlags::CHECKPOINT,
))),
);
}
#[test]
fn located_insert_subtree() {
let t: LocatedPrunableTree<String> = LocatedTree {
@ -1561,6 +1631,31 @@ mod tests {
);
}
#[test]
fn located_insert_subtree_leaf_overwrites() {
let t: LocatedPrunableTree<String> = LocatedTree {
root_addr: Address::from_parts(2.into(), 1),
root: parent(leaf(("a".to_string(), RetentionFlags::MARKED)), nil()),
};
assert_eq!(
t.insert_subtree(
LocatedTree {
root_addr: Address::from_parts(1.into(), 2),
root: leaf(("b".to_string(), RetentionFlags::EPHEMERAL)),
},
false,
),
Ok((
LocatedTree {
root_addr: Address::from_parts(2.into(), 1),
root: parent(leaf(("b".to_string(), RetentionFlags::EPHEMERAL)), nil()),
},
vec![],
)),
);
}
#[test]
fn located_witness() {
let t: LocatedPrunableTree<String> = LocatedTree {
@ -1600,6 +1695,36 @@ mod tests {
);
}
#[test]
fn located_from_iter_non_sibling_adjacent() {
let res = LocatedPrunableTree::from_iter::<(), _>(
Position::from(3)..Position::from(5),
Level::new(0),
vec![
("d".to_string(), Retention::Ephemeral),
("e".to_string(), Retention::Ephemeral),
]
.into_iter(),
)
.unwrap();
assert_eq!(
res.subtree,
LocatedPrunableTree {
root_addr: Address::from_parts(3.into(), 0),
root: parent(
parent(
nil(),
parent(nil(), leaf(("d".to_string(), RetentionFlags::EPHEMERAL)))
),
parent(
parent(leaf(("e".to_string(), RetentionFlags::EPHEMERAL)), nil()),
nil()
)
)
},
);
}
#[test]
fn located_insert() {
let tree = LocatedPrunableTree::empty(Address::from_parts(Level::from(2), 0));