diff --git a/zcash_primitives/src/merkle_tree.rs b/zcash_primitives/src/merkle_tree.rs index e0b5861fa..0428248fc 100644 --- a/zcash_primitives/src/merkle_tree.rs +++ b/zcash_primitives/src/merkle_tree.rs @@ -194,10 +194,14 @@ impl CommitmentTree { } fn is_complete(&self, depth: usize) -> bool { - self.left.is_some() - && self.right.is_some() - && self.parents.len() == depth - 1 - && self.parents.iter().all(|p| p.is_some()) + if depth == 0 { + self.left.is_some() && self.right.is_none() && self.parents.is_empty() + } else { + self.left.is_some() + && self.right.is_some() + && self.parents.len() == depth - 1 + && self.parents.iter().all(|p| p.is_some()) + } } } @@ -616,8 +620,8 @@ mod tests { use crate::sapling::{testing::arb_node, Node}; use super::{ - testing::arb_commitment_tree, CommitmentTree, Hashable, IncrementalWitness, MerklePath, - PathFiller, + testing::{arb_commitment_tree, TestNode}, + CommitmentTree, Hashable, IncrementalWitness, MerklePath, PathFiller, }; const HEX_EMPTY_ROOTS: [&str; 33] = [ @@ -1186,13 +1190,40 @@ mod tests { assert_eq!(frontier, frontier0); } } + + #[test] + fn test_commitment_tree_complete() { + let mut t: CommitmentTree = CommitmentTree::empty(); + for n in 1u64..=32 { + t.append(TestNode(n)).unwrap(); + // every tree of a power-of-two height is complete + let is_complete = n.count_ones() == 1; + let level = 63 - n.leading_zeros(); //log2 + assert_eq!( + is_complete, + t.is_complete(level.try_into().unwrap()), + "Tree {:?} {} complete at height {}", + t, + if is_complete { + "should be" + } else { + "should not be" + }, + n + ); + } + } } #[cfg(any(test, feature = "test-dependencies"))] pub mod testing { + use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use core::fmt::Debug; use proptest::collection::vec; use proptest::prelude::*; + use std::collections::hash_map::DefaultHasher; + use std::hash::Hasher; + use std::io::{self, Read, Write}; use super::{CommitmentTree, Hashable}; @@ -1211,4 +1242,32 @@ pub mod testing { tree }) } + + #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] + pub(crate) struct TestNode(pub(crate) u64); + + impl Hashable for TestNode { + fn read(mut reader: R) -> io::Result { + reader.read_u64::().map(TestNode) + } + + fn write(&self, mut writer: W) -> io::Result<()> { + writer.write_u64::(self.0) + } + + fn combine(_: usize, a: &TestNode, b: &TestNode) -> TestNode { + let mut hasher = DefaultHasher::new(); + hasher.write_u64(a.0); + hasher.write_u64(b.0); + TestNode(hasher.finish()) + } + + fn blank() -> TestNode { + TestNode(0) + } + + fn empty_root(alt: usize) -> TestNode { + (0..alt).fold(Self::blank(), |v, lvl| Self::combine(lvl, &v, &v)) + } + } }