shardtree: Ensure `LocatedTree` is correct by construction

This ensures that the various existing `unwrap`s are correctly
justified.
This commit is contained in:
Jack Grigg 2024-11-23 03:47:35 +00:00
parent 1e81f3ca45
commit 166872e49b
4 changed files with 84 additions and 19 deletions

View File

@ -7,6 +7,10 @@ and this project adheres to Rust's notion of
## Unreleased ## Unreleased
### Changed
- `shardtree::LocatedTree::from_parts` now returns `Option<Self>` (returning
`None` if the provided `Address` and `Tree` are inconsistent).
## [0.5.0] - 2024-10-04 ## [0.5.0] - 2024-10-04
This release includes a significant refactoring and rework of several methods This release includes a significant refactoring and rework of several methods

View File

@ -406,13 +406,16 @@ impl<
/// Adds a checkpoint at the rightmost leaf state of the tree. /// Adds a checkpoint at the rightmost leaf state of the tree.
pub fn checkpoint(&mut self, checkpoint_id: C) -> Result<bool, ShardTreeError<S::Error>> { pub fn checkpoint(&mut self, checkpoint_id: C) -> Result<bool, ShardTreeError<S::Error>> {
/// Pre-condition: `root_addr` must be the address of `root`.
fn go<H: Hashable + Clone + PartialEq>( fn go<H: Hashable + Clone + PartialEq>(
root_addr: Address, root_addr: Address,
root: &PrunableTree<H>, root: &PrunableTree<H>,
) -> Option<(PrunableTree<H>, Position)> { ) -> Option<(PrunableTree<H>, Position)> {
match &root.0 { match &root.0 {
Node::Parent { ann, left, right } => { Node::Parent { ann, left, right } => {
let (l_addr, r_addr) = root_addr.children().unwrap(); let (l_addr, r_addr) = root_addr
.children()
.expect("has children because we checked `root` is a parent");
go(r_addr, right).map_or_else( go(r_addr, right).map_or_else(
|| { || {
go(l_addr, left).map(|(new_left, pos)| { go(l_addr, left).map(|(new_left, pos)| {
@ -765,7 +768,10 @@ impl<
// Compute the roots of the left and right children and hash them together. // Compute the roots of the left and right children and hash them together.
// We skip computation in any subtrees that will not have data included in // We skip computation in any subtrees that will not have data included in
// the final result. // the final result.
let (l_addr, r_addr) = cap.root_addr.children().unwrap(); let (l_addr, r_addr) = cap
.root_addr
.children()
.expect("has children because we checked `cap.root` is a parent");
let l_result = if r_addr.contains(&target_addr) { let l_result = if r_addr.contains(&target_addr) {
None None
} else { } else {

View File

@ -358,6 +358,7 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
/// Note that no actual leaf value may exist at this position, as it may have previously been /// Note that no actual leaf value may exist at this position, as it may have previously been
/// pruned. /// pruned.
pub fn max_position(&self) -> Option<Position> { pub fn max_position(&self) -> Option<Position> {
/// Pre-condition: `addr` must be the address of `root`.
fn go<H>( fn go<H>(
addr: Address, addr: Address,
root: &Tree<Option<Arc<H>>, (H, RetentionFlags)>, root: &Tree<Option<Arc<H>>, (H, RetentionFlags)>,
@ -369,7 +370,9 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
if ann.is_some() { if ann.is_some() {
Some(addr.max_position()) Some(addr.max_position())
} else { } else {
let (l_addr, r_addr) = addr.children().unwrap(); let (l_addr, r_addr) = addr
.children()
.expect("has children because we checked `root` is a parent");
go(r_addr, right.as_ref()).or_else(|| go(l_addr, left.as_ref())) go(r_addr, right.as_ref()).or_else(|| go(l_addr, left.as_ref()))
} }
} }
@ -406,6 +409,7 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
/// Returns the positions of marked leaves in the tree. /// Returns the positions of marked leaves in the tree.
pub fn marked_positions(&self) -> BTreeSet<Position> { pub fn marked_positions(&self) -> BTreeSet<Position> {
/// Pre-condition: `root_addr` must be the address of `root`.
fn go<H: Hashable + Clone + PartialEq>( fn go<H: Hashable + Clone + PartialEq>(
root_addr: Address, root_addr: Address,
root: &PrunableTree<H>, root: &PrunableTree<H>,
@ -413,7 +417,9 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
) { ) {
match &root.0 { match &root.0 {
Node::Parent { left, right, .. } => { Node::Parent { left, right, .. } => {
let (l_addr, r_addr) = root_addr.children().unwrap(); let (l_addr, r_addr) = root_addr
.children()
.expect("has children because we checked `root` is a parent");
go(l_addr, left.as_ref(), acc); go(l_addr, left.as_ref(), acc);
go(r_addr, right.as_ref(), acc); go(r_addr, right.as_ref(), acc);
} }
@ -440,8 +446,10 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
/// Returns either the witness for the leaf at the specified position, or an error that /// Returns either the witness for the leaf at the specified position, or an error that
/// describes the causes of failure. /// describes the causes of failure.
pub fn witness(&self, position: Position, truncate_at: Position) -> Result<Vec<H>, QueryError> { pub fn witness(&self, position: Position, truncate_at: Position) -> Result<Vec<H>, QueryError> {
// traverse down to the desired leaf position, and then construct /// Traverse down to the desired leaf position, and then construct
// the authentication path on the way back up. /// the authentication path on the way back up.
//
/// Pre-condition: `root_addr` must be the address of `root`.
fn go<H: Hashable + Clone + PartialEq>( fn go<H: Hashable + Clone + PartialEq>(
root: &PrunableTree<H>, root: &PrunableTree<H>,
root_addr: Address, root_addr: Address,
@ -450,7 +458,9 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
) -> Result<Vec<H>, Vec<Address>> { ) -> Result<Vec<H>, Vec<Address>> {
match &root.0 { match &root.0 {
Node::Parent { left, right, .. } => { Node::Parent { left, right, .. } => {
let (l_addr, r_addr) = root_addr.children().unwrap(); let (l_addr, r_addr) = root_addr
.children()
.expect("has children because we checked `root` is a parent");
if root_addr.level() > 1.into() { if root_addr.level() > 1.into() {
let r_start = r_addr.position_range_start(); let r_start = r_addr.position_range_start();
if position < r_start { if position < r_start {
@ -525,6 +535,7 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
/// subtree root with the specified position as its maximum position exists, or `None` /// subtree root with the specified position as its maximum position exists, or `None`
/// otherwise. /// otherwise.
pub fn truncate_to_position(&self, position: Position) -> Option<Self> { pub fn truncate_to_position(&self, position: Position) -> Option<Self> {
/// Pre-condition: `root_addr` must be the address of `root`.
fn go<H: Hashable + Clone + PartialEq>( fn go<H: Hashable + Clone + PartialEq>(
position: Position, position: Position,
root_addr: Address, root_addr: Address,
@ -532,7 +543,9 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
) -> Option<PrunableTree<H>> { ) -> Option<PrunableTree<H>> {
match &root.0 { match &root.0 {
Node::Parent { ann, left, right } => { Node::Parent { ann, left, right } => {
let (l_child, r_child) = root_addr.children().unwrap(); let (l_child, r_child) = root_addr
.children()
.expect("has children because we checked `root` is a parent");
if position < r_child.position_range_start() { if position < r_child.position_range_start() {
// we are truncating within the range of the left node, so recurse // we are truncating within the range of the left node, so recurse
// to the left to truncate the left child and then reconstruct the // to the left to truncate the left child and then reconstruct the
@ -586,8 +599,10 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
subtree: Self, subtree: Self,
contains_marked: bool, contains_marked: bool,
) -> Result<(Self, Vec<IncompleteAt>), InsertionError> { ) -> Result<(Self, Vec<IncompleteAt>), InsertionError> {
// A function to recursively dig into the tree, creating a path downward and introducing /// A function to recursively dig into the tree, creating a path downward and introducing
// empty nodes as necessary until we can insert the provided subtree. /// empty nodes as necessary until we can insert the provided subtree.
///
/// Pre-condition: `root_addr` must be the address of `into`.
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
fn go<H: Hashable + Clone + PartialEq>( fn go<H: Hashable + Clone + PartialEq>(
root_addr: Address, root_addr: Address,
@ -694,7 +709,9 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
Tree(Node::Parent { ann, left, right }) => { Tree(Node::Parent { ann, left, right }) => {
// In this case, we have an existing parent but we need to dig down farther // In this case, we have an existing parent but we need to dig down farther
// before we can insert the subtree that we're carrying for insertion. // before we can insert the subtree that we're carrying for insertion.
let (l_addr, r_addr) = root_addr.children().unwrap(); let (l_addr, r_addr) = root_addr
.children()
.expect("has children because we checked `into` is a parent");
if l_addr.contains(&subtree.root_addr) { if l_addr.contains(&subtree.root_addr) {
let (new_left, incomplete) = let (new_left, incomplete) =
go(l_addr, left.as_ref(), subtree, contains_marked)?; go(l_addr, left.as_ref(), subtree, contains_marked)?;
@ -892,6 +909,7 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
/// Clears the specified retention flags at all positions specified, pruning any branches /// Clears the specified retention flags at all positions specified, pruning any branches
/// that no longer need to be retained. /// that no longer need to be retained.
pub fn clear_flags(&self, to_clear: BTreeMap<Position, RetentionFlags>) -> Self { pub fn clear_flags(&self, to_clear: BTreeMap<Position, RetentionFlags>) -> Self {
/// Pre-condition: `root_addr` must be the address of `root`.
fn go<H: Hashable + Clone + PartialEq>( fn go<H: Hashable + Clone + PartialEq>(
to_clear: &[(Position, RetentionFlags)], to_clear: &[(Position, RetentionFlags)],
root_addr: Address, root_addr: Address,
@ -903,7 +921,9 @@ impl<H: Hashable + Clone + PartialEq> LocatedPrunableTree<H> {
} else { } else {
match &root.0 { match &root.0 {
Node::Parent { ann, left, right } => { Node::Parent { ann, left, right } => {
let (l_addr, r_addr) = root_addr.children().unwrap(); let (l_addr, r_addr) = root_addr
.children()
.expect("has children because we checked `root` is a parent");
let p = to_clear.partition_point(|(p, _)| p < &l_addr.position_range_end()); let p = to_clear.partition_point(|(p, _)| p < &l_addr.position_range_end());
trace!( trace!(
@ -1228,7 +1248,7 @@ mod tests {
root in arb_prunable_tree(arb_char_str(), 8, 2^6) root in arb_prunable_tree(arb_char_str(), 8, 2^6)
) { ) {
let root_addr = Address::from_parts(Level::from(7), 0); let root_addr = Address::from_parts(Level::from(7), 0);
let tree = LocatedTree::from_parts(root_addr, root); let tree = LocatedTree::from_parts(root_addr, root).unwrap();
let (to_clear, to_retain) = tree.flag_positions().into_iter().enumerate().fold( let (to_clear, to_retain) = tree.flag_positions().into_iter().enumerate().fold(
(BTreeMap::new(), BTreeMap::new()), (BTreeMap::new(), BTreeMap::new()),

View File

@ -197,9 +197,35 @@ pub struct LocatedTree<A, V> {
} }
impl<A, V> LocatedTree<A, V> { impl<A, V> LocatedTree<A, V> {
/// Constructs a new LocatedTree from its constituent parts /// Constructs a new LocatedTree from its constituent parts.
pub fn from_parts(root_addr: Address, root: Tree<A, V>) -> Self { ///
LocatedTree { root_addr, root } /// Returns `None` if `root_addr` is inconsistent with `root` (in particular, if the
/// level of `root_addr` is too small to contain `tree`).
pub fn from_parts(root_addr: Address, root: Tree<A, V>) -> Option<Self> {
// In order to meet various pre-conditions throughout the crate, we require that
// no `Node::Parent` in `root` has a level of 0 relative to `root_addr`.
fn is_consistent<A, V>(addr: Address, root: &Tree<A, V>) -> bool {
match (&root.0, addr.children()) {
// Found an inconsistency!
(Node::Parent { .. }, None) => false,
// Check consistency of children recursively.
(Node::Parent { left, right, .. }, Some((l_addr, r_addr))) => {
is_consistent(l_addr, left) && is_consistent(r_addr, right)
}
// Leaves are technically allowed to occur at any level, so we do not
// require `addr` to have no children.
(Node::Leaf { .. }, _) => true,
// Nil nodes have no information, so we cannot verify that the data it
// represents is consistent with `root_addr`. Instead we rely on methods
// that mutate `LocatedTree` to verify that the insertion address is not
// inconsistent with `root_addr`.
(Node::Nil, _) => true,
}
}
is_consistent(root_addr, &root).then_some(LocatedTree { root_addr, root })
} }
/// Returns the root address of this tree. /// Returns the root address of this tree.
@ -234,10 +260,13 @@ impl<A, V> LocatedTree<A, V> {
/// Returns the value at the specified position, if any. /// Returns the value at the specified position, if any.
pub fn value_at_position(&self, position: Position) -> Option<&V> { pub fn value_at_position(&self, position: Position) -> Option<&V> {
/// Pre-condition: `addr` must be the address of `root`.
fn go<A, V>(pos: Position, addr: Address, root: &Tree<A, V>) -> Option<&V> { fn go<A, V>(pos: Position, addr: Address, root: &Tree<A, V>) -> Option<&V> {
match &root.0 { match &root.0 {
Node::Parent { left, right, .. } => { Node::Parent { left, right, .. } => {
let (l_addr, r_addr) = addr.children().unwrap(); let (l_addr, r_addr) = addr
.children()
.expect("has children because we checked `root` is a parent");
if l_addr.position_range().contains(&pos) { if l_addr.position_range().contains(&pos) {
go(pos, l_addr, left) go(pos, l_addr, left)
} else { } else {
@ -306,6 +335,7 @@ impl<A: Default + Clone, V: Clone> LocatedTree<A, V> {
/// if the tree is terminated by a [`Node::Nil`] or leaf node before the specified address can /// if the tree is terminated by a [`Node::Nil`] or leaf node before the specified address can
/// be reached. /// be reached.
pub fn subtree(&self, addr: Address) -> Option<Self> { pub fn subtree(&self, addr: Address) -> Option<Self> {
/// Pre-condition: `root_addr` must be the address of `root`.
fn go<A: Clone, V: Clone>( fn go<A: Clone, V: Clone>(
root_addr: Address, root_addr: Address,
root: &Tree<A, V>, root: &Tree<A, V>,
@ -319,7 +349,9 @@ impl<A: Default + Clone, V: Clone> LocatedTree<A, V> {
} else { } else {
match &root.0 { match &root.0 {
Node::Parent { left, right, .. } => { Node::Parent { left, right, .. } => {
let (l_addr, r_addr) = root_addr.children().unwrap(); let (l_addr, r_addr) = root_addr
.children()
.expect("has children because we checked `root` is a parent");
if l_addr.contains(&addr) { if l_addr.contains(&addr) {
go(l_addr, left.as_ref(), addr) go(l_addr, left.as_ref(), addr)
} else { } else {
@ -343,6 +375,7 @@ impl<A: Default + Clone, V: Clone> LocatedTree<A, V> {
/// If this root address of this tree is lower down in the tree than the level specified, /// If this root address of this tree is lower down in the tree than the level specified,
/// the entire tree is returned as the sole element of the result vector. /// the entire tree is returned as the sole element of the result vector.
pub fn decompose_to_level(self, level: Level) -> Vec<Self> { pub fn decompose_to_level(self, level: Level) -> Vec<Self> {
/// Pre-condition: `root_addr` must be the address of `root`.
fn go<A: Clone, V: Clone>( fn go<A: Clone, V: Clone>(
level: Level, level: Level,
root_addr: Address, root_addr: Address,
@ -353,7 +386,9 @@ impl<A: Default + Clone, V: Clone> LocatedTree<A, V> {
} else { } else {
match root.0 { match root.0 {
Node::Parent { left, right, .. } => { Node::Parent { left, right, .. } => {
let (l_addr, r_addr) = root_addr.children().unwrap(); let (l_addr, r_addr) = root_addr
.children()
.expect("has children because we checked `root` is a parent");
let mut l_decomposed = go( let mut l_decomposed = go(
level, level,
l_addr, l_addr,