diff --git a/shardtree/src/lib.rs b/shardtree/src/lib.rs index 4986033..7b898b6 100644 --- a/shardtree/src/lib.rs +++ b/shardtree/src/lib.rs @@ -1,9 +1,94 @@ use core::fmt::Debug; -use core::ops::Deref; +use core::ops::{BitAnd, BitOr, Deref, Not}; use either::Either; +use std::collections::BTreeSet; use std::rc::Rc; -use incrementalmerkletree::Address; +use incrementalmerkletree::{Address, Hashable, Level, Position, Retention}; + +/// A type for flags that determine when and how leaves can be pruned from a tree. +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct RetentionFlags(u8); + +impl BitOr for RetentionFlags { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self { + RetentionFlags(self.0 | rhs.0) + } +} + +impl BitAnd for RetentionFlags { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self { + RetentionFlags(self.0 & rhs.0) + } +} + +/// A leaf with `EPHEMERAL` retention can be pruned as soon as we are certain that it is not part +/// of the witness for a leaf with `CHECKPOINT` or `MARKED` retention. +pub static EPHEMERAL: RetentionFlags = RetentionFlags(0b00000000); + +/// A leaf with `CHECKPOINT` retention can be pruned when there are more than `max_checkpoints` +/// additional checkpoint leaves, if it is not also a marked leaf. +pub static CHECKPOINT: RetentionFlags = RetentionFlags(0b00000001); + +/// A leaf with `MARKED` retention can be pruned only as a consequence of an explicit deletion +/// action. +pub static MARKED: RetentionFlags = RetentionFlags(0b00000010); + +impl RetentionFlags { + pub fn is_checkpoint(&self) -> bool { + (*self & CHECKPOINT) == CHECKPOINT + } + + pub fn is_marked(&self) -> bool { + (*self & MARKED) == MARKED + } +} + +impl<'a, C> From<&'a Retention> for RetentionFlags { + fn from(retention: &'a Retention) -> Self { + match retention { + Retention::Ephemeral => EPHEMERAL, + Retention::Checkpoint { is_marked, ..
} => { + if *is_marked { + CHECKPOINT | MARKED + } else { + CHECKPOINT + } + } + Retention::Marked => MARKED, + } + } +} + +impl From> for RetentionFlags { + fn from(retention: Retention) -> Self { + RetentionFlags::from(&retention) + } +} + +/// A mask that may be used to unset one or more retention flags. +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub struct RetentionMask(u8); + +impl Not for RetentionFlags { + type Output = RetentionMask; + + fn not(self) -> Self::Output { + RetentionMask(!self.0) + } +} + +impl BitAnd for RetentionFlags { + type Output = Self; + + fn bitand(self, rhs: RetentionMask) -> Self { + RetentionFlags(self.0 & rhs.0) + } +} /// A "pattern functor" for a single layer of a binary tree. #[derive(Clone, Debug, PartialEq, Eq)] @@ -68,6 +153,25 @@ pub fn is_complete(node: Node) -> bool { } } +/// An F-algebra for use with [`Tree::try_reduce`] for determining whether a tree has any `MARKED` nodes. +/// +/// `Tree::try_reduce` is preferred for this operation because it allows us to short-circuit as +/// soon as we find a marked node. Returns [`Either::Left(())`] if a marked node exists, +/// [`Either::Right(())`] otherwise. +pub fn contains_marked(node: Node<(), A, (V, RetentionFlags)>) -> Either<(), ()> { + match node { + Node::Parent { .. } => Either::Right(()), + Node::Leaf { value: (_, r) } => { + if r.is_marked() { + Either::Left(()) + } else { + Either::Right(()) + } + } + Node::Nil { .. } => Either::Right(()), + } +} + /// An immutable binary tree with each of its nodes tagged with an annotation value. #[derive(Clone, Debug, PartialEq, Eq)] pub struct Tree(Node>, A, V>); @@ -161,11 +265,282 @@ impl Tree { } } +type PrunableTree = Tree>, (H, RetentionFlags)>; + +impl PrunableTree { + /// Returns the value if this is a leaf.
+ pub fn leaf_value(&self) -> Option<&H> { + self.0.leaf_value().map(|(h, _)| h) + } + + /// Returns the cached root value with which the tree has been annotated for this node if it is + /// available, otherwise return the value if this is a leaf. + pub fn node_value(&self) -> Option<&H> { + self.0.annotation().map_or_else( + || self.leaf_value(), + |rc_opt| rc_opt.as_ref().map(|rc| rc.as_ref()), + ) + } + + /// Returns whether or not this tree is a leaf with `Marked` retention. + pub fn is_marked_leaf(&self) -> bool { + self.0 + .leaf_value() + .map_or(false, |(_, retention)| retention.is_marked()) + } + + /// Returns the Merkle root of this tree, given the address of the root node, or + /// a vector of the addresses of the nodes (`Nil` nodes, or leaves that truncation would + /// have to split) that inhibited the computation of such a root. + /// + /// ### Parameters: + /// * `truncate_at` An inclusive lower bound on positions in the tree at and beyond which all leaf + /// values will be treated as empty (subtrees starting there hash to `H::empty_root`). + pub fn root_hash(&self, root_addr: Address, truncate_at: Position) -> Result> { + if truncate_at <= root_addr.position_range_start() { + // we are in the part of the tree where we're generating empty roots, + // so no need to inspect the tree + Ok(H::empty_root(root_addr.level())) + } else { + match self { + Tree(Node::Parent { ann, left, right }) => ann + .as_ref() + .filter(|_| truncate_at >= root_addr.position_range_end()) + .map_or_else( + || { + // Compute the roots of the left and right children and hash them + // together. + let (l_addr, r_addr) = root_addr.children().unwrap(); + accumulate_result_with( + left.root_hash(l_addr, truncate_at), + right.root_hash(r_addr, truncate_at), + |left_root, right_root| { + H::combine(l_addr.level(), &left_root, &right_root) + }, + ) + }, + |rc| { + // Since we have an annotation on the root, and we are not truncating + // within this subtree, we can just use the cached value.
+ Ok(rc.as_ref().clone()) + }, + ), + Tree(Node::Leaf { value }) => { + if truncate_at >= root_addr.position_range_end() { + // no truncation of this leaf is necessary, just use it + Ok(value.0.clone()) + } else { + // we have a leaf value that is a subtree root created by hashing together + // the roots of child subtrees, but truncation would require that that leaf + // value be "split" into its constituent parts, which we can't do so we + // return an error + Err(vec![root_addr]) + } + } + Tree(Node::Nil) => Err(vec![root_addr]), + } + } + } + + /// Returns the set of positions of [`Node::Leaf`] values in the tree having [`MARKED`] + /// retention. + /// + /// Computing the set of marked positions requires a full traversal of the tree, and so should + /// be considered to be a somewhat expensive operation. Only leaves at level 0 are reported. + pub fn marked_positions(&self, root_addr: Address) -> BTreeSet { + match &self.0 { + Node::Parent { left, right, .. } => { + // We should never construct parent nodes where both children are Nil. + // While we could handle that here, if we encountered that case it would + // be indicative of a programming error elsewhere and so we assert instead. + assert!(!(left.0.is_nil() && right.0.is_nil())); + let (left_root, right_root) = root_addr + .children() + .expect("A parent node cannot appear at level 0"); + + let mut left_incomplete = left.marked_positions(left_root); + let mut right_incomplete = right.marked_positions(right_root); + left_incomplete.append(&mut right_incomplete); + left_incomplete + } + Node::Leaf { + value: (_, retention), + } => { + let mut result = BTreeSet::new(); + if root_addr.level() == 0.into() && retention.is_marked() { + result.insert(Position::from(root_addr.index())); + } + result + } + Node::Nil => BTreeSet::new(), + } + } + + /// Prunes the tree by hashing together ephemeral sibling nodes. + /// + /// `level` must be the level of the root of the node being pruned.
+ pub fn prune(self, level: Level) -> Self { + match self { + Tree(Node::Parent { ann, left, right }) => Tree::unite( + level, + ann, + left.as_ref().clone().prune(level - 1), + right.as_ref().clone().prune(level - 1), + ), + other => other, + } + } + + /// Merge two subtrees having the same root address. + /// + /// The merge operation is checked to be strictly additive and returns an error if merging + /// would cause information loss or if a conflict between root hashes occurs at a node. The + /// returned error contains the address of the node where such a conflict occurred. + pub fn merge_checked(self, root_addr: Address, other: Self) -> Result { + #[allow(clippy::type_complexity)] + fn go( + addr: Address, + t0: PrunableTree, + t1: PrunableTree, + ) -> Result, Address> { + // Require that any roots the we compute will not be default-filled by picking + // a starting valid fill point that is outside the range of leaf positions. + let no_default_fill = addr.position_range_end(); + match (t0, t1) { + (Tree(Node::Nil), other) => Ok(other), + (other, Tree(Node::Nil)) => Ok(other), + (Tree(Node::Leaf { value: vl }), Tree(Node::Leaf { value: vr })) => { + if vl == vr { + Ok(Tree(Node::Leaf { value: vl })) + } else { + Err(addr) + } + } + (Tree(Node::Leaf { value }), parent) => { + // `parent` is statically known to be a `Node::Parent` + if parent + .root_hash(addr, no_default_fill) + .iter() + .all(|r| r == &value.0) + { + Ok(parent.reannotate_root(Some(Rc::new(value.0)))) + } else { + Err(addr) + } + } + (parent, Tree(Node::Leaf { value })) => { + // `parent` is statically known to be a `Node::Parent` + if parent + .root_hash(addr, no_default_fill) + .iter() + .all(|r| r == &value.0) + { + Ok(parent.reannotate_root(Some(Rc::new(value.0)))) + } else { + Err(addr) + } + } + (lparent, rparent) => { + let lroot = lparent.root_hash(addr, no_default_fill).ok(); + let rroot = rparent.root_hash(addr, no_default_fill).ok(); + // If both parents share the same root hash (or 
if one of them is absent), + // they can be merged + if lroot.zip(rroot).iter().all(|(l, r)| l == r) { + // using `if let` here to bind variables; we need to borrow the trees for + // root hash calculation but binding the children of the parent node + // interferes with binding a reference to the parent. + if let ( + Tree(Node::Parent { + ann: lann, + left: ll, + right: lr, + }), + Tree(Node::Parent { + ann: rann, + left: rl, + right: rr, + }), + ) = (lparent, rparent) + { + let (l_addr, r_addr) = addr.children().unwrap(); + Ok(Tree::unite( + addr.level() - 1, + lann.or(rann), + go(l_addr, ll.as_ref().clone(), rl.as_ref().clone())?, + go(r_addr, lr.as_ref().clone(), rr.as_ref().clone())?, + )) + } else { + unreachable!() + } + } else { + Err(addr) + } + } + } + } + + go(root_addr, self, other) + } + + /// Unite two nodes by either constructing a new parent node, or, if both nodes are ephemeral + /// leaves or Nil, constructing a replacement root by hashing leaf values together (or a + /// replacement `Nil` value). + /// + /// `level` must be the level of the two nodes that are being joined. + fn unite(level: Level, ann: Option>, left: Self, right: Self) -> Self { + match (left, right) { + (Tree(Node::Nil), Tree(Node::Nil)) => Tree(Node::Nil), + (Tree(Node::Leaf { value: lv }), Tree(Node::Leaf { value: rv })) + // we can prune right-hand leaves that are not marked; if a leaf + // is a checkpoint then that information will be propagated to + // the replacement leaf + if lv.1 == EPHEMERAL && (rv.1 & MARKED) == EPHEMERAL => + { + Tree( + Node::Leaf { + value: (H::combine(level, &lv.0, &rv.0), rv.1), + }, + ) + } + (left, right) => Tree( + Node::Parent { + ann, + left: Rc::new(left), + right: Rc::new(right), + }, + ), + } + } +} + +// We need an applicative functor for Result for this function so that we can correctly +// accumulate errors, but we don't have one so we just write a special- cased version here. 
+fn accumulate_result_with( + left: Result>, + right: Result>, + combine_success: impl FnOnce(A, B) -> C, +) -> Result> { + match (left, right) { + (Ok(a), Ok(b)) => Ok(combine_success(a, b)), + (Err(mut xs), Err(mut ys)) => { + xs.append(&mut ys); + Err(xs) + } + (Ok(_), Err(xs)) => Err(xs), + (Err(xs), Ok(_)) => Err(xs), + } +} + #[cfg(any(bench, test, feature = "test-dependencies"))] pub mod testing { use super::*; use incrementalmerkletree::Hashable; use proptest::prelude::*; + use proptest::sample::select; + + pub fn arb_retention_flags() -> impl Strategy { + select(vec![EPHEMERAL, CHECKPOINT, MARKED, MARKED | CHECKPOINT]) + } pub fn arb_tree( arb_annotation: A, @@ -200,8 +575,9 @@ pub mod testing { #[cfg(test)] mod tests { - use crate::{Node, Tree}; - use incrementalmerkletree::{Address, Level}; + use crate::{Node, PrunableTree, Tree, EPHEMERAL, MARKED}; + use incrementalmerkletree::{Address, Level, Position}; + use std::collections::BTreeSet; use std::rc::Rc; #[test] @@ -239,4 +615,171 @@ mod tests { ] ); } + + #[test] + fn tree_root() { + let t: PrunableTree = Tree(Node::Parent { + ann: None, + left: Rc::new(Tree(Node::Leaf { + value: ("a".to_string(), EPHEMERAL), + })), + right: Rc::new(Tree(Node::Leaf { + value: ("b".to_string(), EPHEMERAL), + })), + }); + assert_eq!( + t.root_hash(Address::from_parts(Level::from(1), 0), Position::from(2)), + Ok("ab".to_string()) + ); + + let t0: PrunableTree = Tree(Node::Parent { + ann: None, + left: Rc::new(Tree(Node::Nil)), + right: Rc::new(t.clone()), + }); + assert_eq!( + t0.root_hash(Address::from_parts(Level::from(2), 0), Position::from(4)), + Err(vec![Address::from_parts(Level::from(1), 0)]) + ); + + // Check root computation with truncation: subtrees starting at or beyond `truncate_at` hash as empty + let t1: PrunableTree = Tree(Node::Parent { + ann: None, + left: Rc::new(t), + right: Rc::new(Tree(Node::Nil)), + }); + assert_eq!( + t1.root_hash(Address::from_parts(Level::from(2), 0), Position::from(2)), + Ok("ab__".to_string()) + ); + assert_eq!( +
t1.root_hash(Address::from_parts(Level::from(2), 0), Position::from(3)), + Err(vec![Address::from_parts(Level::from(1), 1)]) + ); + } + + #[test] + fn tree_marked_positions() { + let t: PrunableTree = Tree(Node::Parent { + ann: None, + left: Rc::new(Tree(Node::Leaf { + value: ("a".to_string(), EPHEMERAL), + })), + right: Rc::new(Tree(Node::Leaf { + value: ("b".to_string(), MARKED), + })), + }); + assert_eq!( + t.marked_positions(Address::from_parts(Level::from(1), 0)), + BTreeSet::from([Position::from(1)]) + ); + + let t0: PrunableTree = Tree(Node::Parent { + ann: None, + left: Rc::new(t.clone()), + right: Rc::new(t), + }); + assert_eq!( + t0.marked_positions(Address::from_parts(Level::from(2), 1)), + BTreeSet::from([Position::from(5), Position::from(7)]) + ); + } + + #[test] + fn tree_prune() { + let t: PrunableTree = Tree(Node::Parent { + ann: None, + left: Rc::new(Tree(Node::Leaf { + value: ("a".to_string(), EPHEMERAL), + })), + right: Rc::new(Tree(Node::Leaf { + value: ("b".to_string(), EPHEMERAL), + })), + }); + + assert_eq!( + t.clone().prune(Level::from(1)), + Tree(Node::Leaf { + value: ("ab".to_string(), EPHEMERAL) + }) + ); + + let t0: PrunableTree = Tree(Node::Parent { + ann: None, + left: Rc::new(Tree(Node::Leaf { + value: ("c".to_string(), MARKED), + })), + right: Rc::new(t), + }); + assert_eq!( + t0.prune(Level::from(2)), + Tree(Node::Parent { + ann: None, + left: Rc::new(Tree(Node::Leaf { + value: ("c".to_string(), MARKED), + },)), + right: Rc::new(Tree(Node::Leaf { + value: ("ab".to_string(), EPHEMERAL) + })) + },) + ); + } + + #[test] + fn tree_merge_checked() { + let t0: PrunableTree = Tree(Node::Parent { + ann: None, + left: Rc::new(Tree(Node::Leaf { + value: ("a".to_string(), EPHEMERAL), + })), + right: Rc::new(Tree(Node::Nil)), + }); + + let t1: PrunableTree = Tree(Node::Parent { + ann: None, + left: Rc::new(Tree(Node::Nil)), + right: Rc::new(Tree(Node::Leaf { + value: ("b".to_string(), EPHEMERAL), + })), + }); + + assert_eq!( + t0.clone() +
.merge_checked(Address::from_parts(1.into(), 0), t1.clone()), + Ok(Tree(Node::Leaf { + value: ("ab".to_string(), EPHEMERAL) + })) + ); + + let t2: PrunableTree = Tree(Node::Parent { + ann: None, + left: Rc::new(Tree(Node::Leaf { + value: ("c".to_string(), EPHEMERAL), + })), + right: Rc::new(Tree(Node::Nil)), + }); + assert_eq!( + t0.clone() + .merge_checked(Address::from_parts(1.into(), 0), t2.clone()), + Err(Address::from_parts(0.into(), 0)) + ); + + let t3: PrunableTree = Tree(Node::Parent { + ann: None, + left: Rc::new(t0), + right: Rc::new(t2), + }); + let t4: PrunableTree = Tree(Node::Parent { + ann: None, + left: Rc::new(t1.clone()), + right: Rc::new(t1), + }); + + assert_eq!( + t3.merge_checked(Address::from_parts(2.into(), 0), t4), + Ok(Tree(Node::Leaf { + value: ("abcb".to_string(), EPHEMERAL) + })) + ); + } }