Implement index-truncation Equihash optimisation
This commit is contained in:
parent
6afef0dd6d
commit
c92c1f6050
|
@ -48,6 +48,19 @@ int Equihash::InitialiseState(eh_HashState& base_state)
|
|||
personalization);
|
||||
}
|
||||
|
||||
eh_trunc TruncateIndex(eh_index i, unsigned int ilen)
|
||||
{
|
||||
// Truncate to 8 bits
|
||||
assert(sizeof(eh_trunc) == 1);
|
||||
return (i >> (ilen - 8)) & 0xff;
|
||||
}
|
||||
|
||||
eh_index UntruncateIndex(eh_trunc t, eh_index r, unsigned int ilen)
|
||||
{
|
||||
eh_index i{t};
|
||||
return (i << (ilen - 8)) | r;
|
||||
}
|
||||
|
||||
StepRow::StepRow(unsigned int n, const eh_HashState& base_state, eh_index i) :
|
||||
hash {new unsigned char[n/8]},
|
||||
len {n/8}
|
||||
|
@ -152,6 +165,47 @@ bool DistinctIndices(const FullStepRow& a, const FullStepRow& b)
|
|||
return true;
|
||||
}
|
||||
|
||||
bool IsValidBranch(const FullStepRow& a, const unsigned int ilen, const eh_trunc t)
|
||||
{
|
||||
return TruncateIndex(a.indices[0], ilen) == t;
|
||||
}
|
||||
|
||||
TruncatedStepRow::TruncatedStepRow(unsigned int n, const eh_HashState& base_state, eh_index i, unsigned int ilen) :
|
||||
StepRow {n, base_state, i},
|
||||
indices {TruncateIndex(i, ilen)}
|
||||
{
|
||||
assert(indices.size() == 1);
|
||||
}
|
||||
|
||||
TruncatedStepRow& TruncatedStepRow::operator=(const TruncatedStepRow& a)
|
||||
{
|
||||
unsigned char* p = new unsigned char[a.len];
|
||||
std::copy(a.hash, a.hash+a.len, p);
|
||||
delete[] hash;
|
||||
hash = p;
|
||||
len = a.len;
|
||||
indices = a.indices;
|
||||
return *this;
|
||||
}
|
||||
|
||||
TruncatedStepRow& TruncatedStepRow::operator^=(const TruncatedStepRow& a)
|
||||
{
|
||||
if (a.len != len) {
|
||||
throw std::invalid_argument("Hash length differs");
|
||||
}
|
||||
if (a.indices.size() != indices.size()) {
|
||||
throw std::invalid_argument("Number of indices differs");
|
||||
}
|
||||
unsigned char* p = new unsigned char[len];
|
||||
for (int i = 0; i < len; i++)
|
||||
p[i] = hash[i] ^ a.hash[i];
|
||||
delete[] hash;
|
||||
hash = p;
|
||||
indices.reserve(indices.size() + a.indices.size());
|
||||
indices.insert(indices.end(), a.indices.begin(), a.indices.end());
|
||||
return *this;
|
||||
}
|
||||
|
||||
Equihash::Equihash(unsigned int n, unsigned int k) :
|
||||
n(n), k(k)
|
||||
{
|
||||
|
@ -244,6 +298,207 @@ std::set<std::vector<eh_index>> Equihash::BasicSolve(const eh_HashState& base_st
|
|||
return solns;
|
||||
}
|
||||
|
||||
void CollideBranches(std::vector<FullStepRow>& X, const unsigned int clen, const unsigned int ilen, const eh_trunc lt, const eh_trunc rt)
|
||||
{
|
||||
int i = 0;
|
||||
int posFree = 0;
|
||||
std::vector<FullStepRow> Xc;
|
||||
while (i < X.size() - 1) {
|
||||
// 2b) Find next set of unordered pairs with collisions on the next n/(k+1) bits
|
||||
int j = 1;
|
||||
while (i+j < X.size() &&
|
||||
HasCollision(X[i], X[i+j], clen)) {
|
||||
j++;
|
||||
}
|
||||
|
||||
// 2c) Calculate tuples (X_i ^ X_j, (i, j))
|
||||
for (int l = 0; l < j - 1; l++) {
|
||||
for (int m = l + 1; m < j; m++) {
|
||||
if (DistinctIndices(X[i+l], X[i+m])) {
|
||||
if (IsValidBranch(X[i+l], ilen, lt) && IsValidBranch(X[i+m], ilen, rt)) {
|
||||
Xc.push_back(X[i+l] ^ X[i+m]);
|
||||
Xc.back().TrimHash(clen);
|
||||
} else if (IsValidBranch(X[i+m], ilen, lt) && IsValidBranch(X[i+l], ilen, rt)) {
|
||||
Xc.push_back(X[i+m] ^ X[i+l]);
|
||||
Xc.back().TrimHash(clen);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2d) Store tuples on the table in-place if possible
|
||||
while (posFree < i+j && Xc.size() > 0) {
|
||||
X[posFree++] = Xc.back();
|
||||
Xc.pop_back();
|
||||
}
|
||||
|
||||
i += j;
|
||||
}
|
||||
|
||||
// 2e) Handle edge case where final table entry has no collision
|
||||
while (posFree < X.size() && Xc.size() > 0) {
|
||||
X[posFree++] = Xc.back();
|
||||
Xc.pop_back();
|
||||
}
|
||||
|
||||
if (Xc.size() > 0) {
|
||||
// 2f) Add overflow to end of table
|
||||
X.insert(X.end(), Xc.begin(), Xc.end());
|
||||
} else if (posFree < X.size()) {
|
||||
// 2g) Remove empty space at the end
|
||||
X.erase(X.begin()+posFree, X.end());
|
||||
X.shrink_to_fit();
|
||||
}
|
||||
}
|
||||
|
||||
std::set<std::vector<eh_index>> Equihash::OptimisedSolve(const eh_HashState& base_state)
|
||||
{
|
||||
assert(CollisionBitLength() + 1 < 8*sizeof(eh_index));
|
||||
eh_index init_size { 1 << (CollisionBitLength() + 1) };
|
||||
|
||||
// First run the algorithm with truncated indices
|
||||
|
||||
std::vector<std::vector<eh_trunc>> partialSolns;
|
||||
{
|
||||
|
||||
// 1) Generate first list
|
||||
LogPrint("pow", "Generating first list\n");
|
||||
std::vector<TruncatedStepRow> Xt;
|
||||
Xt.reserve(init_size);
|
||||
for (eh_index i = 0; i < init_size; i++) {
|
||||
Xt.emplace_back(n, base_state, i, CollisionBitLength() + 1);
|
||||
}
|
||||
|
||||
// 3) Repeat step 2 until 2n/(k+1) bits remain
|
||||
for (int r = 1; r < k && Xt.size() > 0; r++) {
|
||||
LogPrint("pow", "Round %d:\n", r);
|
||||
// 2a) Sort the list
|
||||
LogPrint("pow", "- Sorting list\n");
|
||||
std::sort(Xt.begin(), Xt.end());
|
||||
|
||||
LogPrint("pow", "- Finding collisions\n");
|
||||
int i = 0;
|
||||
int posFree = 0;
|
||||
std::vector<TruncatedStepRow> Xc;
|
||||
while (i < Xt.size() - 1) {
|
||||
// 2b) Find next set of unordered pairs with collisions on the next n/(k+1) bits
|
||||
int j = 1;
|
||||
while (i+j < Xt.size() &&
|
||||
HasCollision(Xt[i], Xt[i+j], CollisionByteLength())) {
|
||||
j++;
|
||||
}
|
||||
|
||||
// 2c) Calculate tuples (X_i ^ X_j, (i, j))
|
||||
for (int l = 0; l < j - 1; l++) {
|
||||
for (int m = l + 1; m < j; m++) {
|
||||
// We truncated, so don't check for distinct indices here
|
||||
Xc.push_back(Xt[i+l] ^ Xt[i+m]);
|
||||
Xc.back().TrimHash(CollisionByteLength());
|
||||
}
|
||||
}
|
||||
|
||||
// 2d) Store tuples on the table in-place if possible
|
||||
while (posFree < i+j && Xc.size() > 0) {
|
||||
Xt[posFree++] = Xc.back();
|
||||
Xc.pop_back();
|
||||
}
|
||||
|
||||
i += j;
|
||||
}
|
||||
|
||||
// 2e) Handle edge case where final table entry has no collision
|
||||
while (posFree < Xt.size() && Xc.size() > 0) {
|
||||
Xt[posFree++] = Xc.back();
|
||||
Xc.pop_back();
|
||||
}
|
||||
|
||||
if (Xc.size() > 0) {
|
||||
// 2f) Add overflow to end of table
|
||||
Xt.insert(Xt.end(), Xc.begin(), Xc.end());
|
||||
} else if (posFree < Xt.size()) {
|
||||
// 2g) Remove empty space at the end
|
||||
Xt.erase(Xt.begin()+posFree, Xt.end());
|
||||
Xt.shrink_to_fit();
|
||||
}
|
||||
}
|
||||
|
||||
// k+1) Find a collision on last 2n(k+1) bits
|
||||
LogPrint("pow", "Final round:\n");
|
||||
if (Xt.size() > 1) {
|
||||
LogPrint("pow", "- Sorting list\n");
|
||||
std::sort(Xt.begin(), Xt.end());
|
||||
LogPrint("pow", "- Finding collisions\n");
|
||||
for (int i = 0; i < Xt.size() - 1; i++) {
|
||||
TruncatedStepRow res = Xt[i] ^ Xt[i+1];
|
||||
if (res.IsZero()) {
|
||||
partialSolns.push_back(res.GetPartialSolution());
|
||||
}
|
||||
}
|
||||
} else
|
||||
LogPrint("pow", "- List is empty\n");
|
||||
|
||||
} // Ensure Xt goes out of scope and is destroyed
|
||||
|
||||
LogPrint("pow", "Found %d partial solutions\n", partialSolns.size());
|
||||
|
||||
// Now for each solution run the algorithm again to recreate the indices
|
||||
LogPrint("pow", "Culling solutions\n");
|
||||
std::set<std::vector<eh_index>> solns;
|
||||
eh_index recreate_size { UntruncateIndex(1, 0, CollisionBitLength() + 1) };
|
||||
int invalidCount = 0;
|
||||
for (std::vector<eh_trunc> partialSoln : partialSolns) {
|
||||
// 1) Generate first list of possibilities
|
||||
std::vector<std::vector<FullStepRow>> X;
|
||||
X.reserve(partialSoln.size());
|
||||
for (int i = 0; i < partialSoln.size(); i++) {
|
||||
std::vector<FullStepRow> ic;
|
||||
ic.reserve(recreate_size);
|
||||
for (eh_index j = 0; j < recreate_size; j++) {
|
||||
eh_index newIndex { UntruncateIndex(partialSoln[i], j, CollisionBitLength() + 1) };
|
||||
ic.emplace_back(n, base_state, newIndex);
|
||||
}
|
||||
X.push_back(ic);
|
||||
}
|
||||
|
||||
// 3) Repeat step 2 for each level of the tree
|
||||
for (int r = 0; X.size() > 1; r++) {
|
||||
std::vector<std::vector<FullStepRow>> Xc;
|
||||
Xc.reserve(X.size()/2);
|
||||
|
||||
// 2a) For each pair of lists:
|
||||
for (int v = 0; v < X.size(); v += 2) {
|
||||
// 2b) Merge the lists
|
||||
std::vector<FullStepRow> ic(X[v]);
|
||||
ic.reserve(X[v].size() + X[v+1].size());
|
||||
ic.insert(ic.end(), X[v+1].begin(), X[v+1].end());
|
||||
std::sort(ic.begin(), ic.end());
|
||||
CollideBranches(ic, CollisionByteLength(), CollisionBitLength() + 1, partialSoln[(1<<r)*v], partialSoln[(1<<r)*(v+1)]);
|
||||
|
||||
// 2v) Check if this has become an invalid solution
|
||||
if (ic.size() == 0)
|
||||
goto invalidsolution;
|
||||
|
||||
Xc.push_back(ic);
|
||||
}
|
||||
|
||||
X = Xc;
|
||||
}
|
||||
|
||||
// We are at the top of the tree
|
||||
assert(X.size() == 1);
|
||||
for (FullStepRow row : X[0]) {
|
||||
solns.insert(row.GetSolution());
|
||||
}
|
||||
continue;
|
||||
|
||||
invalidsolution:
|
||||
invalidCount++;
|
||||
}
|
||||
LogPrint("pow", "- Number of invalid solutions found: %d\n", invalidCount);
|
||||
|
||||
return solns;
|
||||
}
|
||||
|
||||
bool Equihash::IsValidSolution(const eh_HashState& base_state, std::vector<eh_index> soln)
|
||||
{
|
||||
eh_index soln_size { 1u << k };
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
typedef crypto_generichash_blake2b_state eh_HashState;
|
||||
typedef uint32_t eh_index;
|
||||
typedef uint8_t eh_trunc;
|
||||
|
||||
struct invalid_params { };
|
||||
|
||||
|
@ -66,9 +67,33 @@ public:
|
|||
}
|
||||
|
||||
friend bool DistinctIndices(const FullStepRow& a, const FullStepRow& b);
|
||||
friend bool IsValidBranch(const FullStepRow& a, const unsigned int ilen, const eh_trunc t);
|
||||
};
|
||||
|
||||
bool DistinctIndices(const FullStepRow& a, const FullStepRow& b);
|
||||
bool IsValidBranch(const FullStepRow& a, const unsigned int ilen, const eh_trunc t);
|
||||
|
||||
class TruncatedStepRow : public StepRow
|
||||
{
|
||||
private:
|
||||
std::vector<eh_trunc> indices;
|
||||
|
||||
public:
|
||||
TruncatedStepRow(unsigned int n, const eh_HashState& base_state, eh_index i, unsigned int ilen);
|
||||
~TruncatedStepRow() { }
|
||||
|
||||
TruncatedStepRow(const TruncatedStepRow& a) : StepRow {a}, indices(a.indices) { }
|
||||
TruncatedStepRow& operator=(const TruncatedStepRow& a);
|
||||
TruncatedStepRow& operator^=(const TruncatedStepRow& a);
|
||||
|
||||
bool IndicesBefore(const TruncatedStepRow& a) { return indices[0] < a.indices[0]; }
|
||||
std::vector<eh_trunc> GetPartialSolution() { return std::vector<eh_trunc>(indices); }
|
||||
|
||||
friend inline const TruncatedStepRow operator^(const TruncatedStepRow& a, const TruncatedStepRow& b) {
|
||||
if (a.indices[0] < b.indices[0]) { return TruncatedStepRow(a) ^= b; }
|
||||
else { return TruncatedStepRow(b) ^= a; }
|
||||
}
|
||||
};
|
||||
|
||||
class Equihash
|
||||
{
|
||||
|
@ -84,6 +109,7 @@ public:
|
|||
|
||||
int InitialiseState(eh_HashState& base_state);
|
||||
std::set<std::vector<eh_index>> BasicSolve(const eh_HashState& base_state);
|
||||
std::set<std::vector<eh_index>> OptimisedSolve(const eh_HashState& base_state);
|
||||
bool IsValidSolution(const eh_HashState& base_state, std::vector<eh_index> soln);
|
||||
};
|
||||
|
||||
|
|
|
@ -56,6 +56,15 @@ void TestEquihashSolvers(unsigned int n, unsigned int k, const std::string &I, c
|
|||
PrintSolutions(strm, ret);
|
||||
BOOST_TEST_MESSAGE(strm.str());
|
||||
BOOST_CHECK(ret == solns);
|
||||
|
||||
// The optimised solver should have the exact same result
|
||||
std::set<std::vector<uint32_t>> retOpt = eh.OptimisedSolve(state);
|
||||
BOOST_TEST_MESSAGE("[Optimised] Number of solutions: " << retOpt.size());
|
||||
strm.str("");
|
||||
PrintSolutions(strm, retOpt);
|
||||
BOOST_TEST_MESSAGE(strm.str());
|
||||
BOOST_CHECK(retOpt == solns);
|
||||
BOOST_CHECK(retOpt == ret);
|
||||
}
|
||||
|
||||
void TestEquihashValidator(unsigned int n, unsigned int k, const std::string &I, const arith_uint256 &nonce, std::vector<uint32_t> soln, bool expected) {
|
||||
|
|
Loading…
Reference in New Issue