// Source: librustzcash/components/equihash/src/verify.rs
//! Verification functions for the [Equihash] proof-of-work algorithm.
//!
//! [Equihash]: https://zips.z.cash/protocol/protocol.pdf#equihash
use blake2b_simd::{Hash as Blake2bHash, Params as Blake2bParams, State as Blake2bState};
use byteorder::{BigEndian, LittleEndian, ReadBytesExt, WriteBytesExt};
use std::fmt;
use std::io::Cursor;
use std::mem::size_of;
/// The `(n, k)` parameter pair selecting an Equihash instance.
#[derive(Clone, Copy)]
pub(crate) struct Params {
    /// Bit length of the hash segment generated for each index; each leaf
    /// slices `n / 8` bytes out of a hash output.
    pub(crate) n: u32,
    /// Tree depth parameter: hashes are split into `k + 1` collision
    /// segments, and a solution encodes `2^k` indices.
    pub(crate) k: u32,
}
/// A node in the solution tree: the (expanded, possibly trimmed) hash of the
/// subtree together with the indices of every leaf below it.
#[derive(Clone)]
struct Node {
    /// Hash bytes of this node; for inner nodes this is the XOR of the two
    /// children with the already-collided prefix removed.
    hash: Vec<u8>,
    /// Leaf indices covered by this node, in tree order.
    indices: Vec<u32>,
}
2017-12-30 14:57:51 -08:00
impl Params {
fn new(n: u32, k: u32) -> Result<Self, Error> {
// We place the following requirements on the parameters:
// - n is a multiple of 8, so the hash output has an exact byte length.
// - k >= 3 so the encoded solutions have an exact byte length.
// - k < n, so the collision bit length is at least 1.
// - n is a multiple of k + 1, so we have an integer collision bit length.
if (n % 8 == 0) && (k >= 3) && (k < n) && (n % (k + 1) == 0) {
Ok(Params { n, k })
} else {
Err(Error(Kind::InvalidParams))
}
}
2017-12-30 14:57:51 -08:00
fn indices_per_hash_output(&self) -> u32 {
512 / self.n
}
fn hash_output(&self) -> u8 {
(self.indices_per_hash_output() * self.n / 8) as u8
}
fn collision_bit_length(&self) -> usize {
(self.n / (self.k + 1)) as usize
}
fn collision_byte_length(&self) -> usize {
(self.collision_bit_length() + 7) / 8
}
#[cfg(test)]
2017-12-30 14:57:51 -08:00
fn hash_length(&self) -> usize {
((self.k as usize) + 1) * self.collision_byte_length()
}
}
impl Node {
fn new(p: &Params, state: &Blake2bState, i: u32) -> Self {
2018-05-23 21:19:10 -07:00
let hash = generate_hash(state, i / p.indices_per_hash_output());
let start = ((i % p.indices_per_hash_output()) * p.n / 8) as usize;
let end = start + (p.n as usize) / 8;
Node {
hash: expand_array(&hash.as_bytes()[start..end], p.collision_bit_length(), 0),
indices: vec![i],
}
}
fn from_children(a: Node, b: Node, trim: usize) -> Self {
2018-06-12 14:32:57 -07:00
let hash: Vec<_> = a
.hash
.iter()
.zip(b.hash.iter())
.skip(trim)
.map(|(a, b)| a ^ b)
.collect();
let indices = if a.indices_before(&b) {
let mut indices = a.indices;
indices.extend(b.indices.iter());
indices
} else {
let mut indices = b.indices;
indices.extend(a.indices.iter());
indices
};
Node { hash, indices }
}
#[cfg(test)]
2017-12-30 14:57:51 -08:00
fn from_children_ref(a: &Node, b: &Node, trim: usize) -> Self {
2018-06-12 14:32:57 -07:00
let hash: Vec<_> = a
.hash
2017-12-30 14:57:51 -08:00
.iter()
.zip(b.hash.iter())
.skip(trim)
.map(|(a, b)| a ^ b)
.collect();
let mut indices = Vec::with_capacity(a.indices.len() + b.indices.len());
if a.indices_before(b) {
2017-12-30 14:57:51 -08:00
indices.extend(a.indices.iter());
indices.extend(b.indices.iter());
} else {
2017-12-30 14:57:51 -08:00
indices.extend(b.indices.iter());
indices.extend(a.indices.iter());
}
Node { hash, indices }
}
fn indices_before(&self, other: &Node) -> bool {
// Indices are serialized in big-endian so that integer
// comparison is equivalent to array comparison
self.indices[0] < other.indices[0]
}
fn is_zero(&self, len: usize) -> bool {
2017-12-30 14:57:51 -08:00
self.hash.iter().take(len).all(|v| *v == 0)
}
}
2020-07-07 22:34:52 -07:00
/// An Equihash solution failed to verify.
#[derive(Debug)]
pub struct Error(Kind);
impl fmt::Display for Error {
    /// Writes the fixed prefix `Invalid solution: ` followed by the wrapped
    /// kind's own description.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let kind = &self.0;
        f.write_fmt(format_args!("Invalid solution: {}", kind))
    }
}
// Empty impl: `Error` uses the trait's default methods; the required `Debug`
// and `Display` come from the derive and the `fmt::Display` impl on `Error`.
impl std::error::Error for Error {}
/// The specific reason a solution (or parameter set) was rejected.
#[derive(Debug, PartialEq)]
pub(crate) enum Kind {
    /// The `(n, k)` parameters were rejected, or the encoded solution does
    /// not have the length those parameters require.
    InvalidParams,
    /// Two sibling subtrees do not agree on the required hash prefix.
    Collision,
    /// The right sibling's first index sorts before the left sibling's.
    OutOfOrder,
    /// An index occurs in both sibling subtrees.
    DuplicateIdxs,
    /// The hash remaining at the root of the tree is not all zeros.
    NonZeroRootHash,
}
impl fmt::Display for Kind {
    /// Writes a short, fixed description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Kind::InvalidParams => "invalid parameters",
            Kind::Collision => "invalid collision length between StepRows",
            Kind::OutOfOrder => "Index tree incorrectly ordered",
            Kind::DuplicateIdxs => "duplicate indices",
            Kind::NonZeroRootHash => "root hash of tree is non-zero",
        };
        f.write_str(message)
    }
}
/// Creates the base hash state for parameters `(n, k)`.
///
/// The personalization is the ASCII tag `ZcashPoW` followed by `n` and `k`,
/// each as a little-endian `u32`; the digest length is `digest_len` bytes.
fn initialise_state(n: u32, k: u32, digest_len: u8) -> Blake2bState {
    let mut personalization: Vec<u8> = b"ZcashPoW".to_vec();
    for word in &[n, k] {
        personalization.write_u32::<LittleEndian>(*word).unwrap();
    }

    let mut params = Blake2bParams::new();
    params
        .hash_length(usize::from(digest_len))
        .personal(&personalization);
    params.to_state()
}
/// Hashes the little-endian encoding of `i` on top of a copy of
/// `base_state`, leaving `base_state` itself untouched.
fn generate_hash(base_state: &Blake2bState, i: u32) -> Blake2bHash {
    let mut index_bytes = [0u8; 4];
    {
        let mut writer = &mut index_bytes[..];
        writer.write_u32::<LittleEndian>(i).unwrap();
    }
    base_state.clone().update(&index_bytes).finalize()
}
/// Expands a tightly packed big-endian bit string into fixed-width elements.
///
/// `vin` is read as a sequence of `bit_len`-bit values. Each value is written
/// big-endian into `byte_pad + ceil(bit_len / 8)` output bytes, with the
/// leading `byte_pad` bytes (and any unused high bits) set to zero.
///
/// # Panics
///
/// Panics if `bit_len < 8` or if `bit_len + 7` does not fit in the 32-bit
/// accumulator (i.e. `bit_len > 25`).
fn expand_array(vin: &[u8], bit_len: usize, byte_pad: usize) -> Vec<u8> {
    assert!(bit_len >= 8);
    assert!(8 * size_of::<u32>() >= 7 + bit_len);

    let out_width = (bit_len + 7) / 8 + byte_pad;
    let out_len = 8 * out_width * vin.len() / bit_len;

    // Shortcut for parameters where expansion is a no-op
    if out_len == vin.len() {
        return vin.to_vec();
    }

    let mut vout: Vec<u8> = vec![0; out_len];
    let bit_len_mask: u32 = (1 << bit_len) - 1;

    // The acc_bits least-significant bits of acc_value represent a bit sequence
    // in big-endian order.
    let mut acc_bits = 0;
    let mut acc_value: u32 = 0;

    // Offset of the output element currently being written.
    let mut j = 0;
    for b in vin {
        acc_value = (acc_value << 8) | u32::from(*b);
        acc_bits += 8;

        // When we have bit_len or more bits in the accumulator, write the next
        // output element.
        if acc_bits >= bit_len {
            acc_bits -= bit_len;
            for x in byte_pad..out_width {
                // Big-endian: the first non-padding byte takes the highest bits.
                let shift = 8 * (out_width - x - 1);
                let byte = acc_value >> (acc_bits + shift);
                // Apply bit_len_mask across byte boundaries
                let mask = (bit_len_mask >> shift) & 0xFF;
                vout[j + x] = (byte & mask) as u8;
            }
            j += out_width;
        }
    }

    vout
}
/// Decodes a minimally-encoded solution into its list of indices.
///
/// The encoding packs `2^k` indices of `collision_bit_length + 1` bits each.
///
/// # Errors
///
/// Returns `Kind::InvalidParams` when `minimal` does not have exactly the
/// length implied by `p`, including when that length is not representable
/// in a `usize`.
///
/// # Panics
///
/// Panics if an index does not fit in a `u32`, or if the index bit length is
/// outside the range accepted by `expand_array`.
fn indices_from_minimal(p: Params, minimal: &[u8]) -> Result<Vec<u32>, Error> {
    let c_bit_len = p.collision_bit_length();
    let index_bit_len = c_bit_len + 1;

    // Division is exact because k >= 3.
    //
    // Checked arithmetic: for large k, `(1 << k) * index_bit_len` would
    // overflow. No slice can have such a length, so overflow is a mismatch.
    let expected_len = 1usize
        .checked_shl(p.k)
        .and_then(|count| count.checked_mul(index_bit_len))
        .map(|bits| bits / 8);
    if expected_len != Some(minimal.len()) {
        return Err(Error(Kind::InvalidParams));
    }

    assert!((index_bit_len + 7) / 8 <= size_of::<u32>());
    // Left-pad every index to a full u32 so it can be read back directly.
    let byte_pad = size_of::<u32>() - (index_bit_len + 7) / 8;
    let expanded = expand_array(minimal, index_bit_len, byte_pad);

    // One index per 4 expanded bytes.
    let mut ret = Vec::with_capacity(expanded.len() / size_of::<u32>());
    let mut csr = Cursor::new(expanded);

    // Big-endian so that lexicographic array comparison is equivalent to integer
    // comparison
    while let Ok(i) = csr.read_u32::<BigEndian>() {
        ret.push(i);
    }

    Ok(ret)
}
/// Returns `true` when the hashes of `a` and `b` agree on their first `len`
/// bytes.
///
/// Only bytes present in both hashes are compared, so a hash shorter than
/// `len` limits the comparison to the common prefix.
fn has_collision(a: &Node, b: &Node, len: usize) -> bool {
    a.hash
        .iter()
        .zip(b.hash.iter())
        .take(len)
        .all(|(x, y)| x == y)
}
/// Returns `true` when no index appears in both `a` and `b`.
fn distinct_indices(a: &Node, b: &Node) -> bool {
    let shared = a
        .indices
        .iter()
        .any(|candidate| b.indices.contains(candidate));
    !shared
}
fn validate_subtrees(p: &Params, a: &Node, b: &Node) -> Result<(), Kind> {
2018-05-23 21:19:10 -07:00
if !has_collision(a, b, p.collision_byte_length()) {
Err(Kind::Collision)
2018-05-23 21:19:10 -07:00
} else if b.indices_before(a) {
Err(Kind::OutOfOrder)
2018-05-23 21:19:10 -07:00
} else if !distinct_indices(a, b) {
Err(Kind::DuplicateIdxs)
2018-05-23 21:19:10 -07:00
} else {
Ok(())
2018-05-23 21:19:10 -07:00
}
}
/// Validates a solution bottom-up, one tree level at a time.
///
/// Test-only reference implementation for `is_valid_solution_recursive`.
/// Adjacent nodes are paired, validated and merged until one root remains,
/// whose remaining hash must be all zeros.
///
/// # Panics
///
/// Panics when `indices` is empty or a level has an odd number of nodes
/// (i.e. the number of indices is not a power of two).
#[cfg(test)]
fn is_valid_solution_iterative(
    p: Params,
    input: &[u8],
    nonce: &[u8],
    indices: &[u32],
) -> Result<(), Error> {
    let mut state = initialise_state(p.n, p.k, p.hash_output());
    state.update(input);
    state.update(nonce);

    // Leaves of the tree, in solution order.
    let mut rows: Vec<Node> = indices
        .iter()
        .map(|i| Node::new(&p, &state, *i))
        .collect();

    // Number of hash bytes still present in each node of the current level.
    let mut hash_len = p.hash_length();
    while rows.len() > 1 {
        let mut cur_rows = Vec::with_capacity(rows.len() / 2);
        for pair in rows.chunks(2) {
            let a = &pair[0];
            let b = &pair[1];
            validate_subtrees(&p, a, b).map_err(Error)?;
            cur_rows.push(Node::from_children_ref(a, b, p.collision_byte_length()));
        }
        rows = cur_rows;
        // Merging trimmed the collided prefix off every hash.
        hash_len -= p.collision_byte_length();
    }

    assert!(rows.len() == 1);
    if rows[0].is_zero(hash_len) {
        Ok(())
    } else {
        Err(Error(Kind::NonZeroRootHash))
    }
}
/// Recursively validates the subtree spanning `indices` and returns its root.
///
/// A single index yields a leaf node. Otherwise the slice is split in half,
/// both halves are validated (left first), the resulting siblings are checked
/// with `validate_subtrees`, and they are merged into their parent.
///
/// Panics if `indices` is empty.
fn tree_validator(p: &Params, state: &Blake2bState, indices: &[u32]) -> Result<Node, Error> {
    if indices.len() <= 1 {
        return Ok(Node::new(p, state, indices[0]));
    }

    let (left, right) = indices.split_at(indices.len() / 2);
    let a = tree_validator(p, state, left)?;
    let b = tree_validator(p, state, right)?;
    validate_subtrees(p, &a, &b).map_err(Error)?;
    Ok(Node::from_children(a, b, p.collision_byte_length()))
}
fn is_valid_solution_recursive(
p: Params,
input: &[u8],
nonce: &[u8],
indices: &[u32],
) -> Result<(), Error> {
let mut state = initialise_state(p.n, p.k, p.hash_output());
state.update(input);
state.update(nonce);
let root = tree_validator(&p, &state, indices)?;
// Hashes were trimmed, so only need to check remaining length
if root.is_zero(p.collision_byte_length()) {
Ok(())
} else {
Err(Error(Kind::NonZeroRootHash))
}
}
2020-07-07 22:34:52 -07:00
/// Checks whether `soln` is a valid solution for `(input, nonce)` with the
/// parameters `(n, k)`.
pub fn is_valid_solution(
n: u32,
k: u32,
input: &[u8],
nonce: &[u8],
soln: &[u8],
) -> Result<(), Error> {
let p = Params::new(n, k)?;
let indices = indices_from_minimal(p, soln)?;
// Recursive validation is faster
is_valid_solution_recursive(p, input, nonce, &indices)
}
#[cfg(test)]
mod tests {
    use super::{is_valid_solution_iterative, is_valid_solution_recursive};
    use crate::test_vectors::{INVALID_TEST_VECTORS, VALID_TEST_VECTORS};

    /// Every known-good solution must pass both validators.
    #[test]
    fn valid_test_vectors() {
        for tv in VALID_TEST_VECTORS {
            for soln in tv.solutions {
                let iterative = is_valid_solution_iterative(tv.params, tv.input, &tv.nonce, soln);
                iterative.unwrap();

                let recursive = is_valid_solution_recursive(tv.params, tv.input, &tv.nonce, soln);
                recursive.unwrap();
            }
        }
    }

    /// Every known-bad solution must fail both validators with the expected
    /// error kind.
    #[test]
    fn invalid_test_vectors() {
        for tv in INVALID_TEST_VECTORS {
            let iterative =
                is_valid_solution_iterative(tv.params, tv.input, &tv.nonce, &tv.solution)
                    .unwrap_err();
            assert_eq!(iterative.0, tv.error);

            let recursive =
                is_valid_solution_recursive(tv.params, tv.input, &tv.nonce, &tv.solution)
                    .unwrap_err();
            assert_eq!(recursive.0, tv.error);
        }
    }
}