equihash: Add parameter validity checks

This commit is contained in:
Jack Grigg 2020-07-07 22:09:24 +12:00
parent 997657f256
commit 8759684fad
1 changed files with 27 additions and 17 deletions

View File

@ -8,7 +8,8 @@ use std::fmt;
use std::io::Cursor;
use std::mem::size_of;
struct Params {
#[derive(Clone, Copy)]
pub struct Params {
n: u32,
k: u32,
}
@ -20,6 +21,13 @@ struct Node {
}
impl Params {
pub fn new(n: u32, k: u32) -> Result<Self, Error> {
if (k < n) && (n % 8 == 0) {
Ok(Params { n, k })
} else {
Err(Error(Kind::InvalidParams))
}
}
fn indices_per_hash_output(&self) -> u32 {
512 / self.n
}
@ -111,6 +119,7 @@ impl std::error::Error for Error {}
#[derive(Debug)]
enum Kind {
InvalidParams,
Collision,
OutOfOrder,
DuplicateIdxs,
@ -120,6 +129,7 @@ enum Kind {
impl fmt::Display for Kind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Kind::InvalidParams => f.write_str("invalid parameters"),
Kind::Collision => f.write_str("invalid collision length between StepRows"),
Kind::OutOfOrder => f.write_str("Index tree incorrectly ordered"),
Kind::DuplicateIdxs => f.write_str("duplicate indices"),
@ -193,7 +203,12 @@ fn expand_array(vin: &[u8], bit_len: usize, byte_pad: usize) -> Vec<u8> {
vout
}
fn indices_from_minimal(minimal: &[u8], c_bit_len: usize) -> Vec<u32> {
fn indices_from_minimal(p: Params, minimal: &[u8]) -> Result<Vec<u32>, Error> {
let c_bit_len = p.collision_bit_length();
if minimal.len() != (1 << p.k) * (c_bit_len + 1) / 8 {
return Err(Error(Kind::InvalidParams));
}
assert!(((c_bit_len + 1) + 7) / 8 <= size_of::<u32>());
let len_indices = 8 * size_of::<u32>() * minimal.len() / (c_bit_len + 1);
let byte_pad = size_of::<u32>() - ((c_bit_len + 1) + 7) / 8;
@ -207,7 +222,7 @@ fn indices_from_minimal(minimal: &[u8], c_bit_len: usize) -> Vec<u32> {
ret.push(i);
}
ret
Ok(ret)
}
fn has_collision(a: &Node, b: &Node, len: usize) -> bool {
@ -242,14 +257,11 @@ fn validate_subtrees(p: &Params, a: &Node, b: &Node) -> Result<(), Kind> {
}
pub fn is_valid_solution_iterative(
n: u32,
k: u32,
p: Params,
input: &[u8],
nonce: &[u8],
indices: &[u32],
) -> Result<(), Error> {
let p = Params { n, k };
let mut state = initialise_state(p.n, p.k, p.hash_output());
state.update(input);
state.update(nonce);
@ -295,14 +307,11 @@ fn tree_validator(p: &Params, state: &Blake2bState, indices: &[u32]) -> Result<N
}
pub fn is_valid_solution_recursive(
n: u32,
k: u32,
p: Params,
input: &[u8],
nonce: &[u8],
indices: &[u32],
) -> Result<(), Error> {
let p = Params { n, k };
let mut state = initialise_state(p.n, p.k, p.hash_output());
state.update(input);
state.update(nonce);
@ -324,18 +333,18 @@ pub fn is_valid_solution(
nonce: &[u8],
soln: &[u8],
) -> Result<(), Error> {
let p = Params { n, k };
let indices = indices_from_minimal(soln, p.collision_bit_length());
let p = Params::new(n, k)?;
let indices = indices_from_minimal(p, soln)?;
// Recursive validation is faster
is_valid_solution_recursive(n, k, input, nonce, &indices)
is_valid_solution_recursive(p, input, nonce, &indices)
}
#[cfg(test)]
mod tests {
use super::is_valid_solution_iterative;
use super::is_valid_solution_recursive;
use super::Error;
use super::{Error, Params};
fn is_valid_solution(
n: u32,
@ -344,8 +353,9 @@ mod tests {
nonce: &[u8],
indices: &[u32],
) -> Result<(), Error> {
is_valid_solution_iterative(n, k, input, nonce, indices)?;
is_valid_solution_recursive(n, k, input, nonce, indices)?;
let p = Params::new(n, k).unwrap();
is_valid_solution_iterative(p, input, nonce, indices)?;
is_valid_solution_recursive(p, input, nonce, indices)?;
Ok(())
}