//! Verification functions for the [Equihash] proof-of-work algorithm. //! //! [Equihash]: https://zips.z.cash/protocol/protocol.pdf#equihash use blake2b_simd::{Hash as Blake2bHash, Params as Blake2bParams, State as Blake2bState}; use byteorder::{BigEndian, LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; use std::io::Cursor; use std::mem::size_of; #[derive(Clone, Copy)] pub(crate) struct Params { pub(crate) n: u32, pub(crate) k: u32, } #[derive(Clone)] struct Node { hash: Vec, indices: Vec, } impl Params { fn new(n: u32, k: u32) -> Result { // We place the following requirements on the parameters: // - n is a multiple of 8, so the hash output has an exact byte length. // - k >= 3 so the encoded solutions have an exact byte length. // - k < n, so the collision bit length is at least 1. // - n is a multiple of k + 1, so we have an integer collision bit length. if (n % 8 == 0) && (k >= 3) && (k < n) && (n % (k + 1) == 0) { Ok(Params { n, k }) } else { Err(Error(Kind::InvalidParams)) } } fn indices_per_hash_output(&self) -> u32 { 512 / self.n } fn hash_output(&self) -> u8 { (self.indices_per_hash_output() * self.n / 8) as u8 } fn collision_bit_length(&self) -> usize { (self.n / (self.k + 1)) as usize } fn collision_byte_length(&self) -> usize { (self.collision_bit_length() + 7) / 8 } #[cfg(test)] fn hash_length(&self) -> usize { ((self.k as usize) + 1) * self.collision_byte_length() } } impl Node { fn new(p: &Params, state: &Blake2bState, i: u32) -> Self { let hash = generate_hash(state, i / p.indices_per_hash_output()); let start = ((i % p.indices_per_hash_output()) * p.n / 8) as usize; let end = start + (p.n as usize) / 8; Node { hash: expand_array(&hash.as_bytes()[start..end], p.collision_bit_length(), 0), indices: vec![i], } } fn from_children(a: Node, b: Node, trim: usize) -> Self { let hash: Vec<_> = a .hash .iter() .zip(b.hash.iter()) .skip(trim) .map(|(a, b)| a ^ b) .collect(); let indices = if a.indices_before(&b) { let mut indices = a.indices; indices.extend(b.indices.iter()); indices } else { let mut indices = b.indices; indices.extend(a.indices.iter()); indices }; Node { hash, indices } } #[cfg(test)] fn from_children_ref(a: &Node, b: &Node, trim: usize) -> Self { let hash: Vec<_> = a .hash .iter() .zip(b.hash.iter()) .skip(trim) .map(|(a, b)| a ^ b) .collect(); let mut indices = Vec::with_capacity(a.indices.len() + b.indices.len()); if a.indices_before(b) { indices.extend(a.indices.iter()); indices.extend(b.indices.iter()); } else { indices.extend(b.indices.iter()); indices.extend(a.indices.iter()); } Node { hash, indices } } fn indices_before(&self, other: &Node) -> bool { // Indices are serialized in big-endian so that integer // comparison is equivalent to array comparison self.indices[0] < other.indices[0] } fn is_zero(&self, len: usize) -> bool { self.hash.iter().take(len).all(|v| *v == 0) } } /// An Equihash solution failed to verify. #[derive(Debug)] pub struct Error(Kind); impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Invalid solution: {}", self.0) } } impl std::error::Error for Error {} #[derive(Debug, PartialEq)] pub(crate) enum Kind { InvalidParams, Collision, OutOfOrder, DuplicateIdxs, NonZeroRootHash, } impl fmt::Display for Kind { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Kind::InvalidParams => f.write_str("invalid parameters"), Kind::Collision => f.write_str("invalid collision length between StepRows"), Kind::OutOfOrder => f.write_str("Index tree incorrectly ordered"), Kind::DuplicateIdxs => f.write_str("duplicate indices"), Kind::NonZeroRootHash => f.write_str("root hash of tree is non-zero"), } } } fn initialise_state(n: u32, k: u32, digest_len: u8) -> Blake2bState { let mut personalization: Vec = Vec::from("ZcashPoW"); personalization.write_u32::(n).unwrap(); personalization.write_u32::(k).unwrap(); Blake2bParams::new() .hash_length(digest_len as usize) .personal(&personalization) .to_state() } fn generate_hash(base_state: &Blake2bState, i: u32) -> Blake2bHash { let mut lei = [0u8; 4]; (&mut lei[..]).write_u32::(i).unwrap(); let mut state = base_state.clone(); state.update(&lei); state.finalize() } fn expand_array(vin: &[u8], bit_len: usize, byte_pad: usize) -> Vec { assert!(bit_len >= 8); assert!(8 * size_of::() >= 7 + bit_len); let out_width = (bit_len + 7) / 8 + byte_pad; let out_len = 8 * out_width * vin.len() / bit_len; // Shortcut for parameters where expansion is a no-op if out_len == vin.len() { return vin.to_vec(); } let mut vout: Vec = vec![0; out_len]; let bit_len_mask: u32 = (1 << bit_len) - 1; // The acc_bits least-significant bits of acc_value represent a bit sequence // in big-endian order. let mut acc_bits = 0; let mut acc_value: u32 = 0; let mut j = 0; for b in vin { acc_value = (acc_value << 8) | u32::from(*b); acc_bits += 8; // When we have bit_len or more bits in the accumulator, write the next // output element. if acc_bits >= bit_len { acc_bits -= bit_len; for x in byte_pad..out_width { vout[j + x] = (( // Big-endian acc_value >> (acc_bits + (8 * (out_width - x - 1))) ) & ( // Apply bit_len_mask across byte boundaries (bit_len_mask >> (8 * (out_width - x - 1))) & 0xFF )) as u8; } j += out_width; } } vout } fn indices_from_minimal(p: Params, minimal: &[u8]) -> Result, Error> { let c_bit_len = p.collision_bit_length(); // Division is exact because k >= 3. if minimal.len() != ((1 << p.k) * (c_bit_len + 1)) / 8 { return Err(Error(Kind::InvalidParams)); } assert!(((c_bit_len + 1) + 7) / 8 <= size_of::()); let len_indices = 8 * size_of::() * minimal.len() / (c_bit_len + 1); let byte_pad = size_of::() - ((c_bit_len + 1) + 7) / 8; let mut csr = Cursor::new(expand_array(minimal, c_bit_len + 1, byte_pad)); let mut ret = Vec::with_capacity(len_indices); // Big-endian so that lexicographic array comparison is equivalent to integer // comparison while let Ok(i) = csr.read_u32::() { ret.push(i); } Ok(ret) } fn has_collision(a: &Node, b: &Node, len: usize) -> bool { a.hash .iter() .zip(b.hash.iter()) .take(len) .all(|(a, b)| a == b) } fn distinct_indices(a: &Node, b: &Node) -> bool { for i in &(a.indices) { for j in &(b.indices) { if i == j { return false; } } } true } fn validate_subtrees(p: &Params, a: &Node, b: &Node) -> Result<(), Kind> { if !has_collision(a, b, p.collision_byte_length()) { Err(Kind::Collision) } else if b.indices_before(a) { Err(Kind::OutOfOrder) } else if !distinct_indices(a, b) { Err(Kind::DuplicateIdxs) } else { Ok(()) } } #[cfg(test)] fn is_valid_solution_iterative( p: Params, input: &[u8], nonce: &[u8], indices: &[u32], ) -> Result<(), Error> { let mut state = initialise_state(p.n, p.k, p.hash_output()); state.update(input); state.update(nonce); let mut rows = Vec::new(); for i in indices { rows.push(Node::new(&p, &state, *i)); } let mut hash_len = p.hash_length(); while rows.len() > 1 { let mut cur_rows = Vec::new(); for pair in rows.chunks(2) { let a = &pair[0]; let b = &pair[1]; validate_subtrees(&p, a, b).map_err(Error)?; cur_rows.push(Node::from_children_ref(a, b, p.collision_byte_length())); } rows = cur_rows; hash_len -= p.collision_byte_length(); } assert!(rows.len() == 1); if rows[0].is_zero(hash_len) { Ok(()) } else { Err(Error(Kind::NonZeroRootHash)) } } fn tree_validator(p: &Params, state: &Blake2bState, indices: &[u32]) -> Result { if indices.len() > 1 { let end = indices.len(); let mid = end / 2; let a = tree_validator(p, state, &indices[0..mid])?; let b = tree_validator(p, state, &indices[mid..end])?; validate_subtrees(p, &a, &b).map_err(Error)?; Ok(Node::from_children(a, b, p.collision_byte_length())) } else { Ok(Node::new(&p, &state, indices[0])) } } fn is_valid_solution_recursive( p: Params, input: &[u8], nonce: &[u8], indices: &[u32], ) -> Result<(), Error> { let mut state = initialise_state(p.n, p.k, p.hash_output()); state.update(input); state.update(nonce); let root = tree_validator(&p, &state, indices)?; // Hashes were trimmed, so only need to check remaining length if root.is_zero(p.collision_byte_length()) { Ok(()) } else { Err(Error(Kind::NonZeroRootHash)) } } /// Checks whether `soln` is a valid solution for `(input, nonce)` with the /// parameters `(n, k)`. pub fn is_valid_solution( n: u32, k: u32, input: &[u8], nonce: &[u8], soln: &[u8], ) -> Result<(), Error> { let p = Params::new(n, k)?; let indices = indices_from_minimal(p, soln)?; // Recursive validation is faster is_valid_solution_recursive(p, input, nonce, &indices) } #[cfg(test)] mod tests { use super::is_valid_solution_iterative; use super::is_valid_solution_recursive; use crate::test_vectors::{INVALID_TEST_VECTORS, VALID_TEST_VECTORS}; #[test] fn valid_test_vectors() { for tv in VALID_TEST_VECTORS { for soln in tv.solutions { is_valid_solution_iterative(tv.params, tv.input, &tv.nonce, soln).unwrap(); is_valid_solution_recursive(tv.params, tv.input, &tv.nonce, soln).unwrap(); } } } #[test] fn invalid_test_vectors() { for tv in INVALID_TEST_VECTORS { assert_eq!( is_valid_solution_iterative(tv.params, tv.input, &tv.nonce, &tv.solution) .unwrap_err() .0, tv.error ); assert_eq!( is_valid_solution_recursive(tv.params, tv.input, &tv.nonce, &tv.solution) .unwrap_err() .0, tv.error ); } } }