1use alloc::vec::Vec;
6use blake2b_simd::{Hash as Blake2bHash, Params as Blake2bParams, State as Blake2bState};
7use core::fmt;
8use core2::io::Write;
9
10use crate::{
11 minimal::{expand_array, indices_from_minimal},
12 params::Params,
13};
14
15#[derive(Clone)]
16struct Node {
17 hash: Vec<u8>,
18 indices: Vec<u32>,
19}
20
21impl Node {
22 fn new(p: &Params, state: &Blake2bState, i: u32) -> Self {
23 let hash = generate_hash(state, i / p.indices_per_hash_output());
24 let start = ((i % p.indices_per_hash_output()) * p.n / 8) as usize;
25 let end = start + (p.n as usize) / 8;
26 Node {
27 hash: expand_array(&hash.as_bytes()[start..end], p.collision_bit_length(), 0),
28 indices: vec![i],
29 }
30 }
31
32 #[allow(clippy::wrong_self_convention)]
34 fn from_children(a: Node, b: Node, trim: usize) -> Self {
35 let hash: Vec<_> = a
36 .hash
37 .iter()
38 .zip(b.hash.iter())
39 .skip(trim)
40 .map(|(a, b)| a ^ b)
41 .collect();
42 let indices = if a.indices_before(&b) {
43 let mut indices = a.indices;
44 indices.extend(b.indices.iter());
45 indices
46 } else {
47 let mut indices = b.indices;
48 indices.extend(a.indices.iter());
49 indices
50 };
51 Node { hash, indices }
52 }
53
54 #[cfg(test)]
55 fn from_children_ref(a: &Node, b: &Node, trim: usize) -> Self {
56 let hash: Vec<_> = a
57 .hash
58 .iter()
59 .zip(b.hash.iter())
60 .skip(trim)
61 .map(|(a, b)| a ^ b)
62 .collect();
63 let mut indices = Vec::with_capacity(a.indices.len() + b.indices.len());
64 if a.indices_before(b) {
65 indices.extend(a.indices.iter());
66 indices.extend(b.indices.iter());
67 } else {
68 indices.extend(b.indices.iter());
69 indices.extend(a.indices.iter());
70 }
71 Node { hash, indices }
72 }
73
74 fn indices_before(&self, other: &Node) -> bool {
75 self.indices[0] < other.indices[0]
78 }
79
80 fn is_zero(&self, len: usize) -> bool {
81 self.hash.iter().take(len).all(|v| *v == 0)
82 }
83}
84
85#[derive(Debug)]
87pub struct Error(Kind);
88
89impl fmt::Display for Error {
90 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91 write!(f, "Invalid solution: {}", self.0)
92 }
93}
94
95#[cfg(feature = "std")]
96impl std::error::Error for Error {}
97
/// The reason a solution was rejected.
#[derive(Debug, PartialEq)]
pub(crate) enum Kind {
    /// `(n, k)` were rejected by `Params::new`, or the solution bytes could not
    /// be decoded into indices.
    InvalidParams,
    /// Two sibling subtrees disagree on their leading collision bytes.
    Collision,
    /// The right subtree's leading index is smaller than the left subtree's.
    OutOfOrder,
    /// Two sibling subtrees share an index.
    DuplicateIdxs,
    /// The bytes remaining in the root hash are not all zero.
    NonZeroRootHash,
}

impl fmt::Display for Kind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Kind::InvalidParams => "invalid parameters",
            Kind::Collision => "invalid collision length between StepRows",
            Kind::OutOfOrder => "Index tree incorrectly ordered",
            Kind::DuplicateIdxs => "duplicate indices",
            Kind::NonZeroRootHash => "root hash of tree is non-zero",
        };
        f.write_str(message)
    }
}
118
119pub(crate) fn initialise_state(n: u32, k: u32, digest_len: u8) -> Blake2bState {
120 let mut personalization: Vec<u8> = Vec::from("ZcashPoW");
121 personalization.write_all(&n.to_le_bytes()).unwrap();
122 personalization.write_all(&k.to_le_bytes()).unwrap();
123
124 Blake2bParams::new()
125 .hash_length(digest_len as usize)
126 .personal(&personalization)
127 .to_state()
128}
129
130fn generate_hash(base_state: &Blake2bState, i: u32) -> Blake2bHash {
131 let mut lei = [0u8; 4];
132 (&mut lei[..]).write_all(&i.to_le_bytes()).unwrap();
133
134 let mut state = base_state.clone();
135 state.update(&lei);
136 state.finalize()
137}
138
139fn has_collision(a: &Node, b: &Node, len: usize) -> bool {
140 a.hash
141 .iter()
142 .zip(b.hash.iter())
143 .take(len)
144 .all(|(a, b)| a == b)
145}
146
147fn distinct_indices(a: &Node, b: &Node) -> bool {
148 for i in &(a.indices) {
149 for j in &(b.indices) {
150 if i == j {
151 return false;
152 }
153 }
154 }
155 true
156}
157
158fn validate_subtrees(p: &Params, a: &Node, b: &Node) -> Result<(), Kind> {
159 if !has_collision(a, b, p.collision_byte_length()) {
160 Err(Kind::Collision)
161 } else if b.indices_before(a) {
162 Err(Kind::OutOfOrder)
163 } else if !distinct_indices(a, b) {
164 Err(Kind::DuplicateIdxs)
165 } else {
166 Ok(())
167 }
168}
169
/// Test-only validator that builds the index tree bottom-up, one level at a
/// time, as a cross-check for the recursive validator.
///
/// Panics if a level has an odd number of nodes or if `indices` is empty.
#[cfg(test)]
fn is_valid_solution_iterative(
    p: Params,
    input: &[u8],
    nonce: &[u8],
    indices: &[u32],
) -> Result<(), Error> {
    let mut state = initialise_state(p.n, p.k, p.hash_output());
    state.update(input);
    state.update(nonce);

    let mut level: Vec<Node> = indices
        .iter()
        .map(|&idx| Node::new(&p, &state, idx))
        .collect();

    // Tracks how many hash bytes are still present after each level's trim.
    let mut remaining = p.hash_length();
    while level.len() > 1 {
        level = level
            .chunks(2)
            .map(|pair| -> Result<Node, Error> {
                let (left, right) = (&pair[0], &pair[1]);
                validate_subtrees(&p, left, right).map_err(Error)?;
                Ok(Node::from_children_ref(left, right, p.collision_byte_length()))
            })
            .collect::<Result<Vec<Node>, Error>>()?;
        remaining -= p.collision_byte_length();
    }

    assert!(level.len() == 1);

    match level[0].is_zero(remaining) {
        true => Ok(()),
        false => Err(Error(Kind::NonZeroRootHash)),
    }
}
207
208fn tree_validator(p: &Params, state: &Blake2bState, indices: &[u32]) -> Result<Node, Error> {
209 if indices.len() > 1 {
210 let end = indices.len();
211 let mid = end / 2;
212 let a = tree_validator(p, state, &indices[0..mid])?;
213 let b = tree_validator(p, state, &indices[mid..end])?;
214 validate_subtrees(p, &a, &b).map_err(Error)?;
215 Ok(Node::from_children(a, b, p.collision_byte_length()))
216 } else {
217 Ok(Node::new(p, state, indices[0]))
218 }
219}
220
221fn is_valid_solution_recursive(
222 p: Params,
223 input: &[u8],
224 nonce: &[u8],
225 indices: &[u32],
226) -> Result<(), Error> {
227 let mut state = initialise_state(p.n, p.k, p.hash_output());
228 state.update(input);
229 state.update(nonce);
230
231 let root = tree_validator(&p, &state, indices)?;
232
233 if root.is_zero(p.collision_byte_length()) {
235 Ok(())
236 } else {
237 Err(Error(Kind::NonZeroRootHash))
238 }
239}
240
241pub fn is_valid_solution(
244 n: u32,
245 k: u32,
246 input: &[u8],
247 nonce: &[u8],
248 soln: &[u8],
249) -> Result<(), Error> {
250 let p = Params::new(n, k).ok_or(Error(Kind::InvalidParams))?;
251 let indices = indices_from_minimal(p, soln).ok_or(Error(Kind::InvalidParams))?;
252
253 is_valid_solution_recursive(p, input, nonce, &indices)
255}
256
#[cfg(test)]
mod tests {
    use super::{is_valid_solution, is_valid_solution_iterative, is_valid_solution_recursive};
    use crate::test_vectors::{INVALID_TEST_VECTORS, VALID_TEST_VECTORS};

    /// Every known-good solution must be accepted by both validators.
    #[test]
    fn valid_test_vectors() {
        for vector in VALID_TEST_VECTORS {
            for solution in vector.solutions {
                is_valid_solution_iterative(vector.params, vector.input, &vector.nonce, solution)
                    .unwrap();
                is_valid_solution_recursive(vector.params, vector.input, &vector.nonce, solution)
                    .unwrap();
            }
        }
    }

    /// Every known-bad solution must be rejected by both validators with the
    /// expected error kind.
    #[test]
    fn invalid_test_vectors() {
        for vector in INVALID_TEST_VECTORS {
            let iterative = is_valid_solution_iterative(
                vector.params,
                vector.input,
                &vector.nonce,
                vector.solution,
            );
            assert_eq!(iterative.unwrap_err().0, vector.error);

            let recursive = is_valid_solution_recursive(
                vector.params,
                vector.input,
                &vector.nonce,
                vector.solution,
            );
            assert_eq!(recursive.unwrap_err().0, vector.error);
        }
    }

    /// Flipping any single bit of a valid minimal-encoded solution must make it invalid.
    #[test]
    fn all_bits_matter() {
        let n = 96;
        let k = 5;
        let input = b"Equihash is an asymmetric PoW based on the Generalised Birthday problem.";
        let nonce = [
            1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0,
        ];
        let soln = &[
            0x04, 0x6a, 0x8e, 0xd4, 0x51, 0xa2, 0x19, 0x73, 0x32, 0xe7, 0x1f, 0x39, 0xdb, 0x9c,
            0x79, 0xfb, 0xf9, 0x3f, 0xc1, 0x44, 0x3d, 0xa5, 0x8f, 0xb3, 0x8d, 0x05, 0x99, 0x17,
            0x21, 0x16, 0xd5, 0x55, 0xb1, 0xb2, 0x1f, 0x32, 0x70, 0x5c, 0xe9, 0x98, 0xf6, 0x0d,
            0xa8, 0x52, 0xf7, 0x7f, 0x0e, 0x7f, 0x4d, 0x63, 0xfc, 0x2d, 0xd2, 0x30, 0xa3, 0xd9,
            0x99, 0x53, 0xa0, 0x78, 0x7d, 0xfe, 0xfc, 0xab, 0x34, 0x1b, 0xde, 0xc8,
        ];

        // Sanity check: the unmodified solution is valid.
        is_valid_solution(n, k, input, &nonce, soln).unwrap();

        // Visit bits byte-major, lowest bit first within each byte.
        for byte in 0..soln.len() {
            for bit in 0..8 {
                let mut mutated = soln.to_vec();
                mutated[byte] ^= 1 << bit;
                is_valid_solution(n, k, input, &nonce, &mutated).unwrap_err();
            }
        }
    }
}