diff --git a/zcash_primitives/src/merkle_tree.rs b/zcash_primitives/src/merkle_tree.rs index 8b2b370e1..016f2a31d 100644 --- a/zcash_primitives/src/merkle_tree.rs +++ b/zcash_primitives/src/merkle_tree.rs @@ -494,6 +494,20 @@ impl CommitmentTreeWitness { Err(()) } } + + /// Returns the root of the tree corresponding to the witness. + pub fn root(&self, leaf: Node) -> Node { + self.auth_path + .iter() + .enumerate() + .fold( + leaf, + |root, (i, (p, leaf_is_on_right))| match leaf_is_on_right { + false => Node::combine(i, &root, p), + true => Node::combine(i, p, &root), + }, + ) + } } #[cfg(test)] @@ -1002,6 +1016,7 @@ mod tests { assert_eq!(tree.size(), 0); let mut witnesses = vec![]; + let mut last_cm = None; let mut paths_i = 0; let mut witness_ser_i = 0; for i in 0..16 { @@ -1012,7 +1027,7 @@ mod tests { let cm = Node::new(cm); // Witness here - witnesses.push(TestIncrementalWitness::from_tree(&tree)); + witnesses.push((TestIncrementalWitness::from_tree(&tree), last_cm)); // Now append a commitment to the tree assert!(tree.append(cm).is_ok()); @@ -1026,14 +1041,11 @@ mod tests { // Check serialization of tree assert_tree_ser_eq(&tree, tree_ser[i]); - let mut first = true; // The first witness can never form a path - for witness in witnesses.as_mut_slice() { + for (witness, leaf) in witnesses.as_mut_slice() { // Append the same commitment to all the witnesses assert!(witness.append(cm).is_ok()); - if first { - assert!(witness.path().is_none()); - } else { + if let Some(leaf) = leaf { let path = witness.path().expect("should be able to create a path"); let expected = CommitmentTreeWitness::from_slice_with_depth( &mut hex::decode(paths[paths_i]).unwrap(), @@ -1041,7 +1053,11 @@ mod tests { ) .unwrap(); assert_eq!(path, expected); + assert_eq!(path.root(*leaf), witness.root()); paths_i += 1; + } else { + // The first witness can never form a path + assert!(witness.path().is_none()); } // Check witness serialization @@ -1049,15 +1065,15 @@ mod tests { witness_ser_i += 1; assert_eq!(witness.root(), tree.root()); - - first = false; } + + last_cm = Some(cm); } // Tree should be full now let node = Node::blank(); assert!(tree.append(node).is_err()); - for witness in witnesses.as_mut_slice() { + for (witness, _) in witnesses.as_mut_slice() { assert!(witness.append(node).is_err()); } }