equihash/
verify.rs

1//! Verification functions for the [Equihash] proof-of-work algorithm.
2//!
3//! [Equihash]: https://zips.z.cash/protocol/protocol.pdf#equihash
4
5use alloc::vec::Vec;
6use blake2b_simd::{Hash as Blake2bHash, Params as Blake2bParams, State as Blake2bState};
7use core::fmt;
8use core2::io::Write;
9
10use crate::{
11    minimal::{expand_array, indices_from_minimal},
12    params::Params,
13};
14
/// One node of the Equihash index tree: either a leaf for a single solution
/// index, or the merge of two sibling subtrees.
#[derive(Clone, Debug)]
struct Node {
    /// Remaining hash bytes. Leaves hold the `expand_array` output for their
    /// index; each merge XORs the two children and drops the leading
    /// `trim` bytes.
    hash: Vec<u8>,
    /// Solution indices covered by this subtree, with the child whose first
    /// index is smaller placed first.
    indices: Vec<u32>,
}
20
21impl Node {
22    fn new(p: &Params, state: &Blake2bState, i: u32) -> Self {
23        let hash = generate_hash(state, i / p.indices_per_hash_output());
24        let start = ((i % p.indices_per_hash_output()) * p.n / 8) as usize;
25        let end = start + (p.n as usize) / 8;
26        Node {
27            hash: expand_array(&hash.as_bytes()[start..end], p.collision_bit_length(), 0),
28            indices: vec![i],
29        }
30    }
31
32    // Clippy incorrectly interprets the first argument as `self`.
33    #[allow(clippy::wrong_self_convention)]
34    fn from_children(a: Node, b: Node, trim: usize) -> Self {
35        let hash: Vec<_> = a
36            .hash
37            .iter()
38            .zip(b.hash.iter())
39            .skip(trim)
40            .map(|(a, b)| a ^ b)
41            .collect();
42        let indices = if a.indices_before(&b) {
43            let mut indices = a.indices;
44            indices.extend(b.indices.iter());
45            indices
46        } else {
47            let mut indices = b.indices;
48            indices.extend(a.indices.iter());
49            indices
50        };
51        Node { hash, indices }
52    }
53
54    #[cfg(test)]
55    fn from_children_ref(a: &Node, b: &Node, trim: usize) -> Self {
56        let hash: Vec<_> = a
57            .hash
58            .iter()
59            .zip(b.hash.iter())
60            .skip(trim)
61            .map(|(a, b)| a ^ b)
62            .collect();
63        let mut indices = Vec::with_capacity(a.indices.len() + b.indices.len());
64        if a.indices_before(b) {
65            indices.extend(a.indices.iter());
66            indices.extend(b.indices.iter());
67        } else {
68            indices.extend(b.indices.iter());
69            indices.extend(a.indices.iter());
70        }
71        Node { hash, indices }
72    }
73
74    fn indices_before(&self, other: &Node) -> bool {
75        // Indices are serialized in big-endian so that integer
76        // comparison is equivalent to array comparison
77        self.indices[0] < other.indices[0]
78    }
79
80    fn is_zero(&self, len: usize) -> bool {
81        self.hash.iter().take(len).all(|v| *v == 0)
82    }
83}
84
85/// An Equihash solution failed to verify.
86#[derive(Debug)]
87pub struct Error(Kind);
88
89impl fmt::Display for Error {
90    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91        write!(f, "Invalid solution: {}", self.0)
92    }
93}
94
95#[cfg(feature = "std")]
96impl std::error::Error for Error {}
97
/// The specific reason an Equihash solution was rejected.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum Kind {
    /// The `(n, k)` parameters or the shape of the encoded solution were
    /// rejected before tree validation.
    InvalidParams,
    /// Two sibling subtrees do not agree on the leading
    /// `collision_byte_length()` bytes of their hashes.
    Collision,
    /// Two sibling subtrees are out of canonical order: the right subtree's
    /// first index is smaller than the left subtree's.
    OutOfOrder,
    /// Two sibling subtrees share at least one index.
    DuplicateIdxs,
    /// The bytes left in the root hash after trimming are not all zero.
    NonZeroRootHash,
}
106
107impl fmt::Display for Kind {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        match self {
110            Kind::InvalidParams => f.write_str("invalid parameters"),
111            Kind::Collision => f.write_str("invalid collision length between StepRows"),
112            Kind::OutOfOrder => f.write_str("Index tree incorrectly ordered"),
113            Kind::DuplicateIdxs => f.write_str("duplicate indices"),
114            Kind::NonZeroRootHash => f.write_str("root hash of tree is non-zero"),
115        }
116    }
117}
118
119pub(crate) fn initialise_state(n: u32, k: u32, digest_len: u8) -> Blake2bState {
120    let mut personalization: Vec<u8> = Vec::from("ZcashPoW");
121    personalization.write_all(&n.to_le_bytes()).unwrap();
122    personalization.write_all(&k.to_le_bytes()).unwrap();
123
124    Blake2bParams::new()
125        .hash_length(digest_len as usize)
126        .personal(&personalization)
127        .to_state()
128}
129
130fn generate_hash(base_state: &Blake2bState, i: u32) -> Blake2bHash {
131    let mut lei = [0u8; 4];
132    (&mut lei[..]).write_all(&i.to_le_bytes()).unwrap();
133
134    let mut state = base_state.clone();
135    state.update(&lei);
136    state.finalize()
137}
138
139fn has_collision(a: &Node, b: &Node, len: usize) -> bool {
140    a.hash
141        .iter()
142        .zip(b.hash.iter())
143        .take(len)
144        .all(|(a, b)| a == b)
145}
146
147fn distinct_indices(a: &Node, b: &Node) -> bool {
148    for i in &(a.indices) {
149        for j in &(b.indices) {
150            if i == j {
151                return false;
152            }
153        }
154    }
155    true
156}
157
158fn validate_subtrees(p: &Params, a: &Node, b: &Node) -> Result<(), Kind> {
159    if !has_collision(a, b, p.collision_byte_length()) {
160        Err(Kind::Collision)
161    } else if b.indices_before(a) {
162        Err(Kind::OutOfOrder)
163    } else if !distinct_indices(a, b) {
164        Err(Kind::DuplicateIdxs)
165    } else {
166        Ok(())
167    }
168}
169
/// Iterative reference verifier, used by the tests to cross-check
/// `is_valid_solution_recursive`.
///
/// Builds every leaf up front, then repeatedly merges adjacent pairs of rows
/// (validating each pair first) until one root remains. It tracks how many
/// hash bytes remain after each round of trimming.
///
/// # Panics
///
/// Panics if `indices` is empty, or if some level has an odd number of rows
/// greater than one.
#[cfg(test)]
fn is_valid_solution_iterative(
    p: Params,
    input: &[u8],
    nonce: &[u8],
    indices: &[u32],
) -> Result<(), Error> {
    let mut base_state = initialise_state(p.n, p.k, p.hash_output());
    base_state.update(input);
    base_state.update(nonce);

    let trim = p.collision_byte_length();
    let mut rows: Vec<Node> = indices
        .iter()
        .map(|&index| Node::new(&p, &base_state, index))
        .collect();

    let mut remaining = p.hash_length();
    while rows.len() > 1 {
        rows = rows
            .chunks(2)
            .map(|pair| -> Result<Node, Error> {
                let (left, right) = (&pair[0], &pair[1]);
                validate_subtrees(&p, left, right).map_err(Error)?;
                Ok(Node::from_children_ref(left, right, trim))
            })
            .collect::<Result<Vec<_>, Error>>()?;
        remaining -= trim;
    }

    assert!(rows.len() == 1);

    if !rows[0].is_zero(remaining) {
        return Err(Error(Kind::NonZeroRootHash));
    }
    Ok(())
}
207
208fn tree_validator(p: &Params, state: &Blake2bState, indices: &[u32]) -> Result<Node, Error> {
209    if indices.len() > 1 {
210        let end = indices.len();
211        let mid = end / 2;
212        let a = tree_validator(p, state, &indices[0..mid])?;
213        let b = tree_validator(p, state, &indices[mid..end])?;
214        validate_subtrees(p, &a, &b).map_err(Error)?;
215        Ok(Node::from_children(a, b, p.collision_byte_length()))
216    } else {
217        Ok(Node::new(p, state, indices[0]))
218    }
219}
220
221fn is_valid_solution_recursive(
222    p: Params,
223    input: &[u8],
224    nonce: &[u8],
225    indices: &[u32],
226) -> Result<(), Error> {
227    let mut state = initialise_state(p.n, p.k, p.hash_output());
228    state.update(input);
229    state.update(nonce);
230
231    let root = tree_validator(&p, &state, indices)?;
232
233    // Hashes were trimmed, so only need to check remaining length
234    if root.is_zero(p.collision_byte_length()) {
235        Ok(())
236    } else {
237        Err(Error(Kind::NonZeroRootHash))
238    }
239}
240
241/// Checks whether `soln` is a valid solution for `(input, nonce)` with the
242/// parameters `(n, k)`.
243pub fn is_valid_solution(
244    n: u32,
245    k: u32,
246    input: &[u8],
247    nonce: &[u8],
248    soln: &[u8],
249) -> Result<(), Error> {
250    let p = Params::new(n, k).ok_or(Error(Kind::InvalidParams))?;
251    let indices = indices_from_minimal(p, soln).ok_or(Error(Kind::InvalidParams))?;
252
253    // Recursive validation is faster
254    is_valid_solution_recursive(p, input, nonce, &indices)
255}
256
#[cfg(test)]
mod tests {
    use super::{is_valid_solution, is_valid_solution_iterative, is_valid_solution_recursive};
    use crate::test_vectors::{INVALID_TEST_VECTORS, VALID_TEST_VECTORS};

    /// Every known-good solution must pass both verifiers.
    #[test]
    fn valid_test_vectors() {
        for tv in VALID_TEST_VECTORS {
            for soln in tv.solutions {
                let iterative = is_valid_solution_iterative(tv.params, tv.input, &tv.nonce, soln);
                iterative.unwrap();

                let recursive = is_valid_solution_recursive(tv.params, tv.input, &tv.nonce, soln);
                recursive.unwrap();
            }
        }
    }

    /// Every known-bad solution must be rejected by both verifiers with the
    /// expected reason.
    #[test]
    fn invalid_test_vectors() {
        for tv in INVALID_TEST_VECTORS {
            let iterative_err =
                is_valid_solution_iterative(tv.params, tv.input, &tv.nonce, tv.solution)
                    .unwrap_err();
            assert_eq!(iterative_err.0, tv.error);

            let recursive_err =
                is_valid_solution_recursive(tv.params, tv.input, &tv.nonce, tv.solution)
                    .unwrap_err();
            assert_eq!(recursive_err.0, tv.error);
        }
    }

    /// Flipping any single bit of a valid encoded solution must invalidate it.
    #[test]
    fn all_bits_matter() {
        // Initialize the state according to one of the valid test vectors.
        let n = 96;
        let k = 5;
        let input = b"Equihash is an asymmetric PoW based on the Generalised Birthday problem.";
        let mut nonce = [0u8; 32];
        nonce[0] = 1;
        let soln = &[
            0x04, 0x6a, 0x8e, 0xd4, 0x51, 0xa2, 0x19, 0x73, 0x32, 0xe7, 0x1f, 0x39, 0xdb, 0x9c,
            0x79, 0xfb, 0xf9, 0x3f, 0xc1, 0x44, 0x3d, 0xa5, 0x8f, 0xb3, 0x8d, 0x05, 0x99, 0x17,
            0x21, 0x16, 0xd5, 0x55, 0xb1, 0xb2, 0x1f, 0x32, 0x70, 0x5c, 0xe9, 0x98, 0xf6, 0x0d,
            0xa8, 0x52, 0xf7, 0x7f, 0x0e, 0x7f, 0x4d, 0x63, 0xfc, 0x2d, 0xd2, 0x30, 0xa3, 0xd9,
            0x99, 0x53, 0xa0, 0x78, 0x7d, 0xfe, 0xfc, 0xab, 0x34, 0x1b, 0xde, 0xc8,
        ];

        // Prove that the solution is valid.
        is_valid_solution(n, k, input, &nonce, soln).unwrap();

        // Changing any single bit of the encoded solution should make it invalid.
        for byte in 0..soln.len() {
            for bit in 0..8usize {
                let mut mutated = soln.to_vec();
                mutated[byte] ^= 1 << bit;
                is_valid_solution(n, k, input, &nonce, &mutated).unwrap_err();
            }
        }
    }
}