Use `bitflags` crate instead of hand-rolled `RetentionFlags` bit flags.

This commit is contained in:
Kris Nuttycombe 2023-03-09 13:51:56 -07:00
parent 37be939c0f
commit 664cead68b
2 changed files with 67 additions and 93 deletions

View File

@ -12,6 +12,7 @@ repository = "https://github.com/zcash/incrementalmerkletree"
categories = ["algorithms", "data-structures"] categories = ["algorithms", "data-structures"]
[dependencies] [dependencies]
bitflags = "1.3"
either = "1.8" either = "1.8"
incrementalmerkletree = { version = "0.3", path = "../incrementalmerkletree" } incrementalmerkletree = { version = "0.3", path = "../incrementalmerkletree" }
proptest = { version = "1.0.0", optional = true } proptest = { version = "1.0.0", optional = true }

View File

@ -1,67 +1,52 @@
use bitflags::bitflags;
use core::convert::Infallible; use core::convert::Infallible;
use core::fmt::Debug; use core::fmt::Debug;
use core::marker::PhantomData; use core::marker::PhantomData;
use core::ops::{BitAnd, BitOr, Deref, Not, Range}; use core::ops::{Deref, Range};
use either::Either; use either::Either;
use std::collections::{BTreeMap, BTreeSet}; use std::collections::{BTreeMap, BTreeSet};
use std::rc::Rc; use std::rc::Rc;
use incrementalmerkletree::{Address, Hashable, Level, Position, Retention}; use incrementalmerkletree::{Address, Hashable, Level, Position, Retention};
/// A type for flags that determine when and how leaves can be pruned from a tree. bitflags! {
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct RetentionFlags: u8 {
pub struct RetentionFlags(u8); /// An leaf with `EPHEMERAL` retention can be pruned as soon as we are certain that it is not part
/// of the witness for a leaf with `CHECKPOINT` or `MARKED` retention.
const EPHEMERAL = 0b00000000;
impl BitOr for RetentionFlags { /// A leaf with `CHECKPOINT` retention can be pruned when there are more than `max_checkpoints`
type Output = Self; /// additional checkpoint leaves, if it is not also a marked leaf.
const CHECKPOINT = 0b00000001;
fn bitor(self, rhs: Self) -> Self { /// A leaf with `MARKED` retention can be pruned only as a consequence of an explicit deletion
RetentionFlags(self.0 | rhs.0) /// action.
const MARKED = 0b00000010;
} }
} }
impl BitAnd for RetentionFlags {
type Output = Self;
fn bitand(self, rhs: Self) -> Self {
RetentionFlags(self.0 & rhs.0)
}
}
/// An leaf with `EPHEMERAL` retention can be pruned as soon as we are certain that it is not part
/// of the witness for a leaf with `CHECKPOINT` or `MARKED` retention.
pub static EPHEMERAL: RetentionFlags = RetentionFlags(0b00000000);
/// A leaf with `CHECKPOINT` retention can be pruned when there are more than `max_checkpoints`
/// additional checkpoint leaves, if it is not also a marked leaf.
pub static CHECKPOINT: RetentionFlags = RetentionFlags(0b00000001);
/// A leaf with `MARKED` retention can be pruned only as a consequence of an explicit deletion
/// action.
pub static MARKED: RetentionFlags = RetentionFlags(0b00000010);
impl RetentionFlags { impl RetentionFlags {
pub fn is_checkpoint(&self) -> bool { pub fn is_checkpoint(&self) -> bool {
(*self & CHECKPOINT) == CHECKPOINT (*self & RetentionFlags::CHECKPOINT) == RetentionFlags::CHECKPOINT
} }
pub fn is_marked(&self) -> bool { pub fn is_marked(&self) -> bool {
(*self & MARKED) == MARKED (*self & RetentionFlags::MARKED) == RetentionFlags::MARKED
} }
} }
impl<'a, C> From<&'a Retention<C>> for RetentionFlags { impl<'a, C> From<&'a Retention<C>> for RetentionFlags {
fn from(retention: &'a Retention<C>) -> Self { fn from(retention: &'a Retention<C>) -> Self {
match retention { match retention {
Retention::Ephemeral => EPHEMERAL, Retention::Ephemeral => RetentionFlags::EPHEMERAL,
Retention::Checkpoint { is_marked, .. } => { Retention::Checkpoint { is_marked, .. } => {
if *is_marked { if *is_marked {
CHECKPOINT | MARKED RetentionFlags::CHECKPOINT | RetentionFlags::MARKED
} else { } else {
CHECKPOINT RetentionFlags::CHECKPOINT
} }
} }
Retention::Marked => MARKED, Retention::Marked => RetentionFlags::MARKED,
} }
} }
} }
@ -72,26 +57,6 @@ impl<C> From<Retention<C>> for RetentionFlags {
} }
} }
/// A mask that may be used to unset one or more retention flags.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct RetentionMask(u8);
impl Not for RetentionFlags {
type Output = RetentionMask;
fn not(self) -> Self::Output {
RetentionMask(!self.0)
}
}
impl BitAnd<RetentionMask> for RetentionFlags {
type Output = Self;
fn bitand(self, rhs: RetentionMask) -> Self {
RetentionFlags(self.0 & rhs.0)
}
}
/// A "pattern functor" for a single layer of a binary tree. /// A "pattern functor" for a single layer of a binary tree.
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub enum Node<C, A, V> { pub enum Node<C, A, V> {
@ -483,7 +448,7 @@ impl<H: Hashable + Clone + PartialEq> PrunableTree<H> {
// we can prune right-hand leaves that are not marked; if a leaf // we can prune right-hand leaves that are not marked; if a leaf
// is a checkpoint then that information will be propagated to // is a checkpoint then that information will be propagated to
// the replacement leaf // the replacement leaf
if lv.1 == EPHEMERAL && (rv.1 & MARKED) == EPHEMERAL => if lv.1 == RetentionFlags::EPHEMERAL && (rv.1 & RetentionFlags::MARKED) == RetentionFlags::EPHEMERAL =>
{ {
Tree( Tree(
Node::Leaf { Node::Leaf {
@ -1639,7 +1604,7 @@ impl<
Ok(Some(LocatedTree { Ok(Some(LocatedTree {
root_addr: addr, root_addr: addr,
root: Tree(Node::Leaf { root: Tree(Node::Leaf {
value: (value, EPHEMERAL), value: (value, RetentionFlags::EPHEMERAL),
}), }),
})) }))
} }
@ -1821,7 +1786,7 @@ impl<
} }
Tree(Node::Leaf { value: (h, r) }) => Some(( Tree(Node::Leaf { value: (h, r) }) => Some((
Tree(Node::Leaf { Tree(Node::Leaf {
value: (h.clone(), *r | CHECKPOINT), value: (h.clone(), *r | RetentionFlags::CHECKPOINT),
}), }),
root_addr.max_position(), root_addr.max_position(),
)), )),
@ -1897,10 +1862,10 @@ impl<
.and_modify(|to_clear| { .and_modify(|to_clear| {
to_clear to_clear
.entry(pos) .entry(pos)
.and_modify(|flags| *flags = *flags | CHECKPOINT) .and_modify(|flags| *flags = *flags | RetentionFlags::CHECKPOINT)
.or_insert(CHECKPOINT); .or_insert(RetentionFlags::CHECKPOINT);
}) })
.or_insert_with(|| BTreeMap::from([(pos, CHECKPOINT)])); .or_insert_with(|| BTreeMap::from([(pos, RetentionFlags::CHECKPOINT)]));
} }
// clear the leaves that have been marked for removal // clear the leaves that have been marked for removal
@ -1911,10 +1876,10 @@ impl<
.and_modify(|to_clear| { .and_modify(|to_clear| {
to_clear to_clear
.entry(*unmark_pos) .entry(*unmark_pos)
.and_modify(|flags| *flags = *flags | MARKED) .and_modify(|flags| *flags = *flags | RetentionFlags::MARKED)
.or_insert(MARKED); .or_insert(RetentionFlags::MARKED);
}) })
.or_insert_with(|| BTreeMap::from([(*unmark_pos, MARKED)])); .or_insert_with(|| BTreeMap::from([(*unmark_pos, RetentionFlags::MARKED)]));
} }
} }
@ -2241,7 +2206,12 @@ pub mod testing {
use proptest::sample::select; use proptest::sample::select;
pub fn arb_retention_flags() -> impl Strategy<Value = RetentionFlags> { pub fn arb_retention_flags() -> impl Strategy<Value = RetentionFlags> {
select(vec![EPHEMERAL, CHECKPOINT, MARKED, MARKED | CHECKPOINT]) select(vec![
RetentionFlags::EPHEMERAL,
RetentionFlags::CHECKPOINT,
RetentionFlags::MARKED,
RetentionFlags::MARKED | RetentionFlags::CHECKPOINT,
])
} }
pub fn arb_tree<A: Strategy + Clone + 'static, V: Strategy + Clone + 'static>( pub fn arb_tree<A: Strategy + Clone + 'static, V: Strategy + Clone + 'static>(
@ -2278,8 +2248,8 @@ pub mod testing {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{ use crate::{
IncompleteAt, LocatedPrunableTree, LocatedTree, Node, PrunableTree, QueryError, ShardStore, IncompleteAt, LocatedPrunableTree, LocatedTree, Node, PrunableTree, QueryError,
ShardTree, Tree, EPHEMERAL, MARKED, RetentionFlags, ShardStore, ShardTree, Tree,
}; };
use assert_matches::assert_matches; use assert_matches::assert_matches;
use core::convert::Infallible; use core::convert::Infallible;
@ -2344,8 +2314,8 @@ mod tests {
#[test] #[test]
fn tree_root() { fn tree_root() {
let t: PrunableTree<String> = parent( let t: PrunableTree<String> = parent(
leaf(("a".to_string(), EPHEMERAL)), leaf(("a".to_string(), RetentionFlags::EPHEMERAL)),
leaf(("b".to_string(), EPHEMERAL)), leaf(("b".to_string(), RetentionFlags::EPHEMERAL)),
); );
assert_eq!( assert_eq!(
@ -2376,8 +2346,8 @@ mod tests {
let t: LocatedPrunableTree<String> = LocatedTree { let t: LocatedPrunableTree<String> = LocatedTree {
root_addr: Address::from_parts(3.into(), 1), root_addr: Address::from_parts(3.into(), 1),
root: parent( root: parent(
leaf(("abcd".to_string(), EPHEMERAL)), leaf(("abcd".to_string(), RetentionFlags::EPHEMERAL)),
parent(nil(), leaf(("gh".to_string(), EPHEMERAL))), parent(nil(), leaf(("gh".to_string(), RetentionFlags::EPHEMERAL))),
), ),
}; };
@ -2385,7 +2355,7 @@ mod tests {
t.insert_subtree::<Infallible>( t.insert_subtree::<Infallible>(
LocatedTree { LocatedTree {
root_addr: Address::from_parts(1.into(), 6), root_addr: Address::from_parts(1.into(), 6),
root: parent(leaf(("e".to_string(), MARKED)), nil()) root: parent(leaf(("e".to_string(), RetentionFlags::MARKED)), nil())
}, },
true true
), ),
@ -2393,10 +2363,10 @@ mod tests {
LocatedTree { LocatedTree {
root_addr: Address::from_parts(3.into(), 1), root_addr: Address::from_parts(3.into(), 1),
root: parent( root: parent(
leaf(("abcd".to_string(), EPHEMERAL)), leaf(("abcd".to_string(), RetentionFlags::EPHEMERAL)),
parent( parent(
parent(leaf(("e".to_string(), MARKED)), nil()), parent(leaf(("e".to_string(), RetentionFlags::MARKED)), nil()),
leaf(("gh".to_string(), EPHEMERAL)) leaf(("gh".to_string(), RetentionFlags::EPHEMERAL))
) )
) )
}, },
@ -2410,13 +2380,13 @@ mod tests {
let t: LocatedPrunableTree<String> = LocatedTree { let t: LocatedPrunableTree<String> = LocatedTree {
root_addr: Address::from_parts(3.into(), 0), root_addr: Address::from_parts(3.into(), 0),
root: parent( root: parent(
leaf(("abcd".to_string(), EPHEMERAL)), leaf(("abcd".to_string(), RetentionFlags::EPHEMERAL)),
parent( parent(
parent( parent(
leaf(("e".to_string(), MARKED)), leaf(("e".to_string(), RetentionFlags::MARKED)),
leaf(("f".to_string(), EPHEMERAL)), leaf(("f".to_string(), RetentionFlags::EPHEMERAL)),
), ),
leaf(("gh".to_string(), EPHEMERAL)), leaf(("gh".to_string(), RetentionFlags::EPHEMERAL)),
), ),
), ),
}; };
@ -2449,8 +2419,8 @@ mod tests {
#[test] #[test]
fn tree_marked_positions() { fn tree_marked_positions() {
let t: PrunableTree<String> = parent( let t: PrunableTree<String> = parent(
leaf(("a".to_string(), EPHEMERAL)), leaf(("a".to_string(), RetentionFlags::EPHEMERAL)),
leaf(("b".to_string(), MARKED)), leaf(("b".to_string(), RetentionFlags::MARKED)),
); );
assert_eq!( assert_eq!(
t.marked_positions(Address::from_parts(Level::from(1), 0)), t.marked_positions(Address::from_parts(Level::from(1), 0)),
@ -2467,38 +2437,41 @@ mod tests {
#[test] #[test]
fn tree_prune() { fn tree_prune() {
let t: PrunableTree<String> = parent( let t: PrunableTree<String> = parent(
leaf(("a".to_string(), EPHEMERAL)), leaf(("a".to_string(), RetentionFlags::EPHEMERAL)),
leaf(("b".to_string(), EPHEMERAL)), leaf(("b".to_string(), RetentionFlags::EPHEMERAL)),
); );
assert_eq!( assert_eq!(
t.clone().prune(Level::from(1)), t.clone().prune(Level::from(1)),
leaf(("ab".to_string(), EPHEMERAL)) leaf(("ab".to_string(), RetentionFlags::EPHEMERAL))
); );
let t0 = parent(leaf(("c".to_string(), MARKED)), t); let t0 = parent(leaf(("c".to_string(), RetentionFlags::MARKED)), t);
assert_eq!( assert_eq!(
t0.prune(Level::from(2)), t0.prune(Level::from(2)),
parent( parent(
leaf(("c".to_string(), MARKED)), leaf(("c".to_string(), RetentionFlags::MARKED)),
leaf(("ab".to_string(), EPHEMERAL)) leaf(("ab".to_string(), RetentionFlags::EPHEMERAL))
) )
); );
} }
#[test] #[test]
fn tree_merge_checked() { fn tree_merge_checked() {
let t0: PrunableTree<String> = parent(leaf(("a".to_string(), EPHEMERAL)), nil()); let t0: PrunableTree<String> =
parent(leaf(("a".to_string(), RetentionFlags::EPHEMERAL)), nil());
let t1: PrunableTree<String> = parent(nil(), leaf(("b".to_string(), EPHEMERAL))); let t1: PrunableTree<String> =
parent(nil(), leaf(("b".to_string(), RetentionFlags::EPHEMERAL)));
assert_eq!( assert_eq!(
t0.clone() t0.clone()
.merge_checked(Address::from_parts(1.into(), 0), t1.clone()), .merge_checked(Address::from_parts(1.into(), 0), t1.clone()),
Ok(leaf(("ab".to_string(), EPHEMERAL))) Ok(leaf(("ab".to_string(), RetentionFlags::EPHEMERAL)))
); );
let t2: PrunableTree<String> = parent(leaf(("c".to_string(), EPHEMERAL)), nil()); let t2: PrunableTree<String> =
parent(leaf(("c".to_string(), RetentionFlags::EPHEMERAL)), nil());
assert_eq!( assert_eq!(
t0.clone() t0.clone()
.merge_checked(Address::from_parts(1.into(), 0), t2.clone()), .merge_checked(Address::from_parts(1.into(), 0), t2.clone()),
@ -2510,7 +2483,7 @@ mod tests {
assert_eq!( assert_eq!(
t3.merge_checked(Address::from_parts(2.into(), 0), t4), t3.merge_checked(Address::from_parts(2.into(), 0), t4),
Ok(leaf(("abcb".to_string(), EPHEMERAL))) Ok(leaf(("abcb".to_string(), RetentionFlags::EPHEMERAL)))
); );
} }
@ -2582,8 +2555,8 @@ mod tests {
LocatedPrunableTree { LocatedPrunableTree {
root_addr: Address::from_parts(2.into(), 0), root_addr: Address::from_parts(2.into(), 0),
root: parent( root: parent(
parent(leaf(("a".to_string(), EPHEMERAL)), nil()), parent(leaf(("a".to_string(), RetentionFlags::EPHEMERAL)), nil()),
parent(nil(), leaf(("d".to_string(), EPHEMERAL))) parent(nil(), leaf(("d".to_string(), RetentionFlags::EPHEMERAL)))
) )
} }
); );