Fix tests

This commit is contained in:
Mariano Sorgente 2019-10-16 17:38:49 +09:00
commit b4baff9f4e
68 changed files with 18572 additions and 220 deletions

8
.gitignore vendored
View File

@ -5,6 +5,14 @@ __pycache__/
# C extensions
*.so
**/*.o
**/*.DS_Store
# VDF executables
lib/chiavdf/fast_vdf/compile_asm
lib/chiavdf/fast_vdf/vdf
# Flint dependency
lib/chiavdf/fast_vdf/flint
# PyInstaller
# Usually these files are written by a python script from a template

View File

@ -13,11 +13,18 @@ python3 -m venv .venv
pip install wheel
pip install .
pip install lib/chiapos
cd lib/chiavdf/fast_vdf
# Install libgmp, libboost, and libflint, and then run the following
sh install.sh
```
### Run servers
When running the servers on Mac OS, allow the application to accept incoming connections.
Run the servers in the following order (you can also use ipython):
```bash
./lib/chiavdf/fast_vdf/vdf 8889
./lib/chiavdf/fast_vdf/vdf 8890
python -m src.server.start_plotter
python -m src.server.start_timelord
python -m src.server.start_farmer

@ -1 +1 @@
Subproject commit b69ce7166f28e73a193b6f694ecf441c99240145
Subproject commit 6a5570ba4d1b71d8e0e8e3f7e19acb898d601ff5

@ -1 +1 @@
Subproject commit c9d32a81f40ad540015814edf13b29980c63e39c
Subproject commit 34c2281e315c51f5270321101dc733c1cf26214f

View File

@ -0,0 +1,85 @@
/**
Copyright (C) 2018 Markku Pulkkinen
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
**/
#ifndef CLASSGROUP_H
#define CLASSGROUP_H
#include <cstdint>
#include "gmp.h"
/**
* @brief The ClassGroup data struct for VDF variables a, b, c and discriminant.
* Optimal size because it fits into single entry of 64 byte wide cache line.
*/
struct alignas(64) ClassGroup {
mpz_t a;
mpz_t b;
mpz_t c;
mpz_t d;
};
/**
* @brief ClassGroupContext struct - placeholder for variables
* in classgroup arithmetic operations. Uses four cache
* line entries, 256 bytes.
*/
struct alignas(64) ClassGroupContext {
mpz_t a;
mpz_t b;
mpz_t c;
mpz_t mu;
mpz_t m;
mpz_t r;
mpz_t s;
mpz_t faa;
mpz_t fab;
mpz_t fac;
mpz_t fba;
mpz_t fbb;
mpz_t fbc;
mpz_t fca;
mpz_t fcb;
mpz_t fcc;
ClassGroupContext(uint32_t numBits = 4096) {
mpz_init2(a, numBits);
mpz_init2(b, numBits);
mpz_init2(c, numBits);
mpz_init2(mu, numBits);
mpz_init2(m, numBits);
mpz_init2(r, numBits);
mpz_init2(s, numBits);
mpz_init2(faa, numBits);
mpz_init2(fab, numBits);
mpz_init2(fac, numBits);
mpz_init2(fba, numBits);
mpz_init2(fbb, numBits);
mpz_init2(fbc, numBits);
mpz_init2(fca, numBits);
mpz_init2(fcb, numBits);
mpz_init2(fcc, numBits);
}
~ClassGroupContext() {
mpz_clears(a, b, c, mu, m, r, s, faa, fab, fac, fba, fbb, fbc, fca, fcb,
fcc, NULL);
}
};
#endif // CLASSGROUP_H

View File

@ -0,0 +1,68 @@
Copyright 2018 Ilya Gorodetskov
generic@sundersoft.com
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=== Summary ===
The NUDUPL algorithm is used. The equations are based on cryptoslava's equations from the previous contest. They were modified slightly to increase the level of parallelism.
The GCD is a custom implementation with scalar integers. There are two base cases: one uses a lookup table with continued fractions and the other uses the euclidean algorithm with a division table. The division table algorithm is slightly faster even though it has about 2x as many iterations.
After the base case, there is a 128 bit GCD that generates 64 bit cofactor matricies with Lehmer's algorithm. This is required to make the long integer multiplications efficient (Flint's implementation doesn't do this).
The GCD also implements Flint's partial xgcd function, but the output is slightly different. This implementation will always return an A value which is > the threshold and a B value which is <= the threshold. For a normal GCD, the threshold is 0, B is 0, and A is the GCD. Also the interfaces are slightly different.
Scalar integers are used for the GCD. I don't expect any speedup for the SIMD integers that were used in the last implementation since the GCD only uses 64x1024 multiplications, which are too small and have too high of a carry overhead for the SIMD version to be faster. In either case, most of the time seems to be spent in the base case so it shouldn't matter too much.
If SIMD integers are used with AVX-512, doubles have to be used because the multiplier sizes for doubles are significantly larger than for integers. There is an AVX-512 extension to support larger integer multiplications but no processor implements it yet. It should be possible to do a 50 bit multiply-add into a 100 bit accumulator with 4 fused multiply-adds if the accumulators have a special nonzero initial value and the inputs are scaled before the multiplication. This would make AVX-512 about 2.5x faster than scalar code for 1024x1024 integer multiplications (assuming the scalar code is unrolled and uses ADOX/ADCX/MULX properly, and the CPU can execute this at 1 cycle per iteration which it probably can't).
The GCD is parallelized by calculating the cofactors in a separate slave thread. The master thread will calculate the cofactor matricies and send them to the slave thread. Other calculations are also parallelized.
The VDF implementation from the first contest is still used as a fallback and is called about once every 5000 iterations. The GCD will encounter large quotients about this often and these are not implemented. This has a negligble effect on performance. Also, the NUDUPL case where A<=L is not implemented; it will fall back to the old implementation in this case (this never happens outside of the first 20 or so iterations).
There is also corruption detection by calculating C with a non-exact division and making sure the remainder is 0. This detected all injected random corruptions that I tested. No corruptions caused by bugs were observed during testing. This cannot correct for the sign of B being wrong.
=== GCD continued fraction lookup table ===
The is implemented in gcd_base_continued_fractions.h and asm_gcd_base_continued_fractions.h. The division table implementation is the same as the previous entry and was discussed there. Currently the division table is only used if AVX2 is enabled but it could be ported to SSE or scalar code easily. Both implementations have about the same performance.
The initial quotient sequence of gcd(a,b) is the same as the initial quotient sequence of gcd(a*2^n/b, 2^n) for any n. This is because the GCD quotients are the same as the continued fraction quotients of a/b, and the initial continued fraction quotients only depend on the initial bits of a/b. This makes it feasible to have a lookup table since it now only has one input.
a*2^n/b is calculated by doing a double precision division of a/b, and then truncating the lower bits. Some of the exponent bits are used in the table in addition to the fraction bits; this makes each slot of the table vary in size depending on what the exponent is. If the result is outside the table bounds, then the division result is floored to fall back to the euclidean algorithm (this is very rare).
The table is calculated by iterating all of the possible continued fractions that have a certain initial quotient sequence. Iteration ends when all of these fractions are either outside the table or they don't fully contain at least one slot of the table. Each slot that is fully contained by such a fraction is updated so that its quotient sequence equals the fraction's initial quotient sequence. Once this is complete, the cofactor matricies are calculated from the quotient sequences. Each cofactor matrix is 4 doubles.
The resulting code seems to have too many instructions so it doesn't perform very well. There might be some way to optimize it. It was written for SSE so that it would run on both processors.
This might work better on an FPGA possibly with low latency DRAM or SRAM (compared to the euclidean algorithm with a division table). There is no limit to the size of the table but doubling the latency would require the number of bits in the table to also be doubled to have the same performance.
=== Other GCD code ===
The gcd_128 function calculates a 128 bit GCD using Lehmer's algorithm. It is pretty straightforward and uses only unsigned arithmetic. Each cofactor matrix can only have two possible signs: [+ -; - +] or [- +; + -]. The gcd_unsigned function uses unsigned arithmetic and a jump table to apply the 64-bit cofactor matricies to the A and B values. It uses ADOX/ADCX/MULX if they are available and falls back to ADC/MUL otherwise. It will track the last known size of A to speed up the bit shifts required to get the top 128 bits of A.
No attempt was made to try to do the A and B long integer multiplications on a separate thread; I wouldn't expect any performance improvement from this.
=== Threads ===
There is a master thread and a slave thread. The slave thread only exists for each batch of 5000 or so squarings and is then destroyed and recreated for the next batch (this has no measurable overhead). If the original VDF is used as a fallback, the batch ends and the slave thread is destroyed.
Each thread has a 64-bit counter that only it can write to. Also, during a squaring iteration, it will not overwrite any value that it has previously written and transmitted to the other thread. Each squaring is split up into phases. Each thread will update its counter at the start of the phase (the counter can only be increased, not decreased). It can then wait on the other thread's counter to reach a certain value as part of a spin loop. If the spin loop takes too long, an error condition is raised and the batch ends; this should prevent any deadlocks from happening.
No CPU fences or atomics are required since each value can only be written to by one thread and since x86 enforces acquire/release ordering on all memory operations. Compiler memory fences are still required to prevent the compiler from caching or reordering memory operations.
The GCD master thread will increment the counter when a new cofactor matrix has been outputted. The slave thread will spin on this counter and then apply the cofactor matrix to the U or V vector to get a new U or V vector.
It was attempted to use modular arithmetic to calculate k directly but this slowed down the program due to GMP's modulo or integer multiply operations not having enough performance. This also makes the integer multiplications bigger.
The speedup isn't very high since most of the time is spent in the GCD base case and these can't be parallelized.

View File

@ -0,0 +1,214 @@
/**
Copyright (C) 2019 Markku Pulkkinen
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
**/
#ifndef REDUCER_H
#define REDUCER_H
#include <algorithm>
#include <array>
#include <cmath>
#include "ClassGroup.h"
/** constants utilized in reduction algorithm */
namespace {
const int_fast64_t THRESH{1ul << 31};
const int_fast64_t EXP_THRESH{31};
}
/**
* @brief The Reducer class that does custom reduce operation for VDF
* repeated squaring algorithm. The implementation is based on
* Akashnil VDF competition entry and further optimized for speed.
*/
class alignas(64) Reducer {
public:
/**
* @brief Reducer - constructs by using reference into cg context.
*/
Reducer(ClassGroupContext &ctx_) : ctx(ctx_) {}
~Reducer() {}
/**
* @brief run - runs reduction algorithm for cg context params
*/
inline void run() {
while (!isReduced()) {
int_fast64_t a, b, c;
{
int_fast64_t a_exp, b_exp, c_exp;
mpz_get_si_2exp(a, a_exp, ctx.a);
mpz_get_si_2exp(b, b_exp, ctx.b);
mpz_get_si_2exp(c, c_exp, ctx.c);
auto mm = std::minmax({a_exp, b_exp, c_exp});
if (mm.second - mm.first > EXP_THRESH) {
reducer();
continue;
}
// Ensure a, b, c are shifted so that a : b : c ratios are same as
// f.a : f.b : f.c. a, b, c will be used as approximations to f.a,
// f.b, f.c
int_fast64_t max_exp(mm.second++); // for safety vs overflow
a >>= (max_exp - a_exp);
b >>= (max_exp - b_exp);
c >>= (max_exp - c_exp);
}
{
int_fast64_t u, v, w, x;
calc_uvwx(u, v, w, x, a, b, c);
mpz_mul_si(ctx.faa, ctx.a, u * u);
mpz_mul_si(ctx.fab, ctx.b, u * w);
mpz_mul_si(ctx.fac, ctx.c, w * w);
mpz_mul_si(ctx.fba, ctx.a, u * v << 1);
mpz_mul_si(ctx.fbb, ctx.b, u * x + v * w);
mpz_mul_si(ctx.fbc, ctx.c, w * x << 1);
mpz_mul_si(ctx.fca, ctx.a, v * v);
mpz_mul_si(ctx.fcb, ctx.b, v * x);
mpz_mul_si(ctx.fcc, ctx.c, x * x);
mpz_add(ctx.a, ctx.faa, ctx.fab);
mpz_add(ctx.a, ctx.a, ctx.fac);
mpz_add(ctx.b, ctx.fba, ctx.fbb);
mpz_add(ctx.b, ctx.b, ctx.fbc);
mpz_add(ctx.c, ctx.fca, ctx.fcb);
mpz_add(ctx.c, ctx.c, ctx.fcc);
}
}
}
private:
inline void signed_shift(uint64_t op, int64_t shift, int_fast64_t &r) {
if (shift > 0)
r = static_cast<int64_t>(op << shift);
else if (shift <= -64)
r = 0;
else
r = static_cast<int64_t>(op >> (-shift));
}
inline void mpz_get_si_2exp(int_fast64_t &r, int_fast64_t &exp,
const mpz_t op) {
// Return an approximation x of the large mpz_t op by an int64_t and the
// exponent e adjustment. We must have (x * 2^e) / op = constant
// approximately.
int_fast64_t size(static_cast<long>(mpz_size(op)));
uint_fast64_t last(mpz_getlimbn(op, (size - 1)));
int_fast64_t lg2 = exp = ((63 - __builtin_clzll(last)) + 1);
signed_shift(last, (63 - exp), r);
if (size > 1) {
exp += (size - 1) * 64;
uint_fast64_t prev(mpz_getlimbn(op, (size - 2)));
int_fast64_t t;
signed_shift(prev, -1 - lg2, t);
r += t;
}
if (mpz_sgn(op) < 0)
r = -r;
}
inline bool isReduced() {
int a_b(mpz_cmpabs(ctx.a, ctx.b));
int c_b(mpz_cmpabs(ctx.c, ctx.b));
if (a_b < 0 || c_b < 0)
return false;
int a_c(mpz_cmp(ctx.a, ctx.c));
if (a_c > 0) {
mpz_swap(ctx.a, ctx.c);
mpz_neg(ctx.b, ctx.b);
} else if (a_c == 0 && mpz_sgn(ctx.b) < 0) {
mpz_neg(ctx.b, ctx.b);
}
return true;
}
inline void reducer() {
// (c + b)/2c == (1 + (b/c))/2 -> s
mpz_mdiv(ctx.r, ctx.b, ctx.c);
mpz_add_ui(ctx.r, ctx.r, 1);
mpz_div_2exp(ctx.s, ctx.r, 1);
// cs -> m
mpz_mul(ctx.m, ctx.c, ctx.s);
// 2cs -> r
mpz_mul_2exp(ctx.r, ctx.m, 1);
// (cs - b) -> m
mpz_sub(ctx.m, ctx.m, ctx.b);
// new b = -b + 2cs
mpz_sub(ctx.b, ctx.r, ctx.b);
// new a = c, c = a
mpz_swap(ctx.a, ctx.c);
// new c = c + cs^2 - bs ( == c + (s * ( cs - b)))
mpz_addmul(ctx.c, ctx.s, ctx.m);
}
inline void calc_uvwx(int_fast64_t &u, int_fast64_t &v, int_fast64_t &w,
int_fast64_t &x, int_fast64_t &a, int_fast64_t &b,
int_fast64_t &c) {
// We must be very careful about overflow in the following steps
int below_threshold;
int_fast64_t u_{1}, v_{0}, w_{0}, x_{1};
int_fast64_t a_, b_, s;
do {
u = u_;
v = v_;
w = w_;
x = x_;
s = static_cast<int_fast64_t>(
(floorf(b / (static_cast<float>(c))) + 1)) >>
1;
a_ = a;
b_ = b;
// cs = c * s;
// a = c
a = c;
// b = -b + 2cs
b = -b + (c * s << 1);
// c = a + cs^2 - bs
c = a_ - s * (b_ - c * s);
u_ = v;
v_ = -u + s * v;
w_ = x;
x_ = -w + s * x;
// The condition (abs(v_) | abs(x_)) <= THRESH protects against
// overflow
below_threshold = (abs(v_) | abs(x_)) <= THRESH ? 1 : 0;
} while (below_threshold && a > c && c > 0);
if (below_threshold) {
u = u_;
v = v_;
w = w_;
x = x_;
}
}
ClassGroupContext &ctx;
};
#endif // REDUCER_H

View File

@ -0,0 +1,133 @@
#ifdef GENERATE_ASM_TRACKING_DATA
const bool generate_asm_tracking_data=true;
#else
const bool generate_asm_tracking_data=false;
#endif
namespace asm_code {
string track_asm(string comment, string jump_to = "") {
if (!generate_asm_tracking_data) {
return jump_to;
}
mark_vdf_test();
static map<string, int> id_map;
static int next_id=1;
int& id=id_map[comment];
if (id==0) {
id=next_id;
++next_id;
}
assert(id>=1 && id<=num_asm_tracking_data);
//
//
static bool init=false;
if (!init) {
APPEND_M(str( ".data" ));
APPEND_M(str( ".balign 8" ));
APPEND_M(str( "track_asm_rax: .quad 0" ));
//APPEND_M(str( ".global asm_tracking_data" ));
//APPEND_M(str( "asm_tracking_data:" ));
//for (int x=0;x<num_asm_tracking_data;++x) {
//APPEND_M(str( ".quad 0" ));
//}
//APPEND_M(str( ".global asm_tracking_data_comments" ));
//APPEND_M(str( "asm_tracking_data_comments:" ));
//for (int x=0;x<num_asm_tracking_data;++x) {
//APPEND_M(str( ".quad 0" ));
//}
APPEND_M(str( ".text" ));
init=true;
}
string comment_label=m.alloc_label();
#ifdef CHIAOSX
APPEND_M(str( ".text " ));
#else
APPEND_M(str( ".text 1" ));
#endif
APPEND_M(str( "#:", comment_label ));
APPEND_M(str( ".string \"#\"", comment ));
APPEND_M(str( ".text" ));
string skip_label;
if (!jump_to.empty()) {
skip_label=m.alloc_label();
APPEND_M(str( "JMP #", skip_label ));
}
string c_label;
if (!jump_to.empty()) {
c_label=m.alloc_label();
APPEND_M(str( "#:", c_label ));
}
assert(!enable_threads); //this code isn't atomic
APPEND_M(str( "MOV [track_asm_rax], RAX" ));
APPEND_M(str( "MOV RAX, [asm_tracking_data+#]", to_hex(8*(id-1)) ));
APPEND_M(str( "LEA RAX, [RAX+1]" ));
APPEND_M(str( "MOV [asm_tracking_data+#], RAX", to_hex(8*(id-1)) ));
#ifdef CHIAOSX
APPEND_M(str( "LEA RAX, [RIP+comment_label] " ));
#else
APPEND_M(str( "MOV RAX, OFFSET FLAT:#", comment_label ));
#endif
APPEND_M(str( "MOV [asm_tracking_data_comments+#], RAX", to_hex(8*(id-1)) ));
APPEND_M(str( "MOV RAX, [track_asm_rax]" ));
if (!jump_to.empty()) {
APPEND_M(str( "JMP #", jump_to ));
APPEND_M(str( "#:", skip_label ));
}
return c_label;
}
//16-byte aligned; value is in both lanes
string constant_address_uint64(uint64 value_bits_0, uint64 value_bits_1, bool use_brackets=true) {
static map<pair<uint64, uint64>, string> constant_map;
string& name=constant_map[make_pair(value_bits_0, value_bits_1)];
if (name.empty()) {
name=m.alloc_label();
#ifdef CHIAOSX
APPEND_M(str( ".text " ));
#else
APPEND_M(str( ".text 1" ));
#endif
APPEND_M(str( ".balign 16" ));
APPEND_M(str( "#:", name ));
APPEND_M(str( ".quad #", to_hex(value_bits_0) )); //lane 0
APPEND_M(str( ".quad #", to_hex(value_bits_1) )); //lane 1
APPEND_M(str( ".text" ));
}
#ifdef CHIAOSX
return (use_brackets)? str( "[RIP+#]", name ) : name;
#else
return (use_brackets)? str( "[#]", name ) : name;
#endif
}
string constant_address_double(double value_0, double value_1, bool use_brackets=true) {
uint64 value_bits_0=*(uint64*)&value_0;
uint64 value_bits_1=*(uint64*)&value_1;
return constant_address_uint64(value_bits_0, value_bits_1, use_brackets);
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,670 @@
namespace asm_code {
typedef array<reg_scalar, 2> reg_scalar_128;
//v[0] is low, v[1] is high. amount is >=0 and <128. res can't alias with v
//preserves inputs. returns low part of result
//regs: RCX, 1x scalar
void shift_right(
reg_alloc regs, array<reg_scalar, 2> v, reg_scalar amount, reg_scalar res,
reg_scalar tmp_rcx, reg_scalar tmp_res_2
) {
EXPAND_MACROS_SCOPE;
m.bind(v, "v");
m.bind(amount, "amount");
m.bind(res, "res");
assert(tmp_rcx.value==reg_rcx.value);
m.bind(tmp_res_2, "res_2");
//res=uint64([v[1]:v[0]] >> amount) ; undefined if amount>=64
APPEND_M(str( "MOV RCX, `amount" ));
APPEND_M(str( "MOV `res, `v_0" ));
APPEND_M(str( "SHRD `res, `v_1, CL" ));
//res_2=0
APPEND_M(str( "XOR `res_2, `res_2" ));
//RCX=amount-64
APPEND_M(str( "SUB RCX, 64" ));
//res=(amount>=64)? 0 : res
//res_2=(amount>=64)? v[1] : 0
APPEND_M(str( "CMOVAE `res, `res_2" ));
APPEND_M(str( "CMOVAE `res_2, `v_1" ));
//res_2=(amount>=64)? 0 : v[1]>>(amount-64)
APPEND_M(str( "SHR `res_2, CL" ));
//res=(amount>=64)? res_2 : res
APPEND_M(str( "OR `res, `res_2" ));
}
//all inputs are unsigned
void dot_product_exact(reg_alloc regs, array<reg_scalar, 2> a, array<reg_scalar, 2> b, reg_scalar out, string overflow_label) {
EXPAND_MACROS_SCOPE;
m.bind(a, "a");
m.bind(b, "b");
m.bind(out, "out");
reg_scalar rax=regs.bind_scalar(m, "rax", reg_rax);
reg_scalar rdx=regs.bind_scalar(m, "rdx", reg_rdx);
//out=a0*b0
APPEND_M(str( "MOV RAX, `a_0" ));
APPEND_M(str( "MUL `b_0" ));
APPEND_M(str( "JC #", overflow_label ));
APPEND_M(str( "MOV `out, RAX" ));
//RAX=a1*b1
APPEND_M(str( "MOV RAX, `a_1" ));
APPEND_M(str( "MUL `b_1" ));
APPEND_M(str( "JC #", overflow_label ));
//out=a0*b0+a1*b1
APPEND_M(str( "ADD `out, RAX" ));
APPEND_M(str( "JC #", overflow_label ));
}
//ab and ab_threshold reg_spill are 16 bytes (lsb first), 8 byte aligned. all others are 8 bytes
//parity is 1 if odd, else 0
//is_lehmer is 1 if true, else 0
//u, v, and parity are outputs
//regs: 15x scalar, 16x vector (i.e. all of the registers except RSP)
void gcd_128(
reg_alloc regs_parent,
array<reg_spill, 2> spill_ab_start, array<reg_spill, 2> spill_u, array<reg_spill, 2> spill_v,
reg_spill spill_parity, reg_spill spill_is_lehmer, reg_spill spill_ab_threshold,
string no_progress_label
) {
EXPAND_MACROS_SCOPE_PUBLIC;
track_asm( "gcd_128" );
m.bind(spill_ab_start[0], "spill_ab_start_0_0");
m.bind(spill_ab_start[0]+8, "spill_ab_start_0_1");
m.bind(spill_ab_start[1], "spill_ab_start_1_0");
m.bind(spill_ab_start[1]+8, "spill_ab_start_1_1");
m.bind(spill_u, "spill_u");
m.bind(spill_v, "spill_v");
m.bind(spill_parity, "spill_parity");
m.bind(spill_is_lehmer, "spill_is_lehmer");
m.bind(spill_ab_threshold, "spill_ab_threshold_0");
m.bind(spill_ab_threshold+8, "spill_ab_threshold_1");
reg_vector vector_ab=regs_parent.bind_vector(m, "vector_ab");
reg_vector vector_u=regs_parent.bind_vector(m, "vector_u");
reg_vector vector_v=regs_parent.bind_vector(m, "vector_v");
reg_vector vector_is_lehmer=regs_parent.bind_vector(m, "vector_is_lehmer");
reg_vector vector_ab_threshold=regs_parent.bind_vector(m, "vector_ab_threshold");
reg_spill spill_iter=regs_parent.bind_spill(m, "spill_iter");
APPEND_M(str( "MOV QWORD PTR `spill_u_0, 1" ));
APPEND_M(str( "MOV QWORD PTR `spill_u_1, 0" ));
APPEND_M(str( "MOV QWORD PTR `spill_v_0, 0" ));
APPEND_M(str( "MOV QWORD PTR `spill_v_1, 1" ));
APPEND_M(str( "MOV QWORD PTR `spill_parity, 0" ));
APPEND_M(str( "MOV QWORD PTR `spill_iter, #", to_hex(gcd_128_max_iter) ));
string start_label=m.alloc_label();
string loop_label=m.alloc_label();
string exit_label=m.alloc_label();
string exit_iter_0_label=m.alloc_label();
string start_assign_label=m.alloc_label();
APPEND_M(str( "JMP #", start_assign_label ));
APPEND_M(str( "#:", loop_label ));
track_asm( "gcd_128 iter" );
//4x scalar
reg_scalar new_u_0=regs_parent.bind_scalar(m, "new_u_0"); //a
reg_scalar new_u_1=regs_parent.bind_scalar(m, "new_u_1"); //b
reg_scalar new_v_0=regs_parent.bind_scalar(m, "new_v_0"); //ab_threshold
reg_scalar new_v_1=regs_parent.bind_scalar(m, "new_v_1"); //base iter
if (use_divide_table) {
string base_exit_label=m.alloc_label();
string base_loop_label=m.alloc_label();
APPEND_M(str( "MOV `new_v_1, #", to_hex(gcd_base_max_iter_divide_table) ));
APPEND_M(str( "MOVDQA `vector_u, #", constant_address_uint64(1ull, 0ull) ));
APPEND_M(str( "MOVDQA `vector_v, #", constant_address_uint64(0ull, 1ull) ));
APPEND_M(str( "#:", base_loop_label ));
gcd_64_iteration(regs_parent, vector_is_lehmer, {new_u_0, new_u_1}, {vector_u, vector_v}, new_v_0, base_exit_label);
APPEND_M(str( "DEC `new_v_1" ));
APPEND_M(str( "JNZ #", base_loop_label ));
APPEND_M(str( "#:", base_exit_label ));
APPEND_M(str( "CMP `new_v_1, #", to_hex(gcd_base_max_iter_divide_table) ));
APPEND_M(str( "JE #", track_asm( "gcd_128 base no progress", exit_label ) ));
} else {
gcd_base_continued_fraction(
regs_parent, vector_ab, vector_u, vector_v, vector_is_lehmer, vector_ab_threshold,
track_asm( "gcd_128 base no progress", exit_label )
);
}
{
EXPAND_MACROS_SCOPE;
reg_alloc regs=regs_parent;
//12x scalar (including dot product exact which is 2x scalar)
reg_scalar m_0_0=regs.bind_scalar(m, "m_0_0");
reg_scalar m_0_1=regs.bind_scalar(m, "m_0_1");
reg_scalar m_1_0=regs.bind_scalar(m, "m_1_0");
reg_scalar m_1_1=regs.bind_scalar(m, "m_1_1");
reg_scalar tmp_0=regs.bind_scalar(m, "tmp_0");
reg_scalar tmp_1=regs.bind_scalar(m, "tmp_1");
reg_vector tmp_a=regs.bind_vector(m, "tmp_a");
reg_vector tmp_b=regs.bind_vector(m, "tmp_b");
reg_vector tmp_c=regs.bind_vector(m, "tmp_c");
reg_vector c_double_abs_mask=regs.bind_vector(m, "double_abs_mask");
if (!use_divide_table) {
APPEND_M(str( "MOVAPD `double_abs_mask, #", constant_address_uint64(double_abs_mask, double_abs_mask) ));
}
auto abs_tmp_a=[&]() {
if (use_divide_table) {
//tmp_b = int64 mask = int64(v)>>63;
APPEND_M(str( "MOVDQA `tmp_b, `tmp_a" ));
APPEND_M(str( "PSRAD `tmp_b, 32" )); //high 32 bits = sign bit ; low 32 bits = undefined
APPEND_M(str( "PSHUFD `tmp_b, `tmp_b, #", to_hex( 0b11110101 ) )); //move high 32 bits to low 32 bits
//abs_v=(v + mask) ^ mask;
APPEND_M(str( "PADDQ `tmp_a, `tmp_b" ));
APPEND_M(str( "PXOR `tmp_a, `tmp_b" ));
} else {
APPEND_M(str( "PAND `tmp_a, `double_abs_mask" ));
}
};
auto mov_low_tmp_a=[&](string target) {
if (use_divide_table) {
APPEND_M(str( "MOVQ `#, `tmp_a", target ));
} else {
APPEND_M(str( "CVTTSD2SI `#, `tmp_a", target ));
}
};
//<m_0_0, m_1_0>=<abs(vector_u[0]), abs(vector_u[1])>
//for the divide table, this is u[0] and v[0]
APPEND_M(str( "MOVAPD `tmp_a, `vector_u" ));
abs_tmp_a();
mov_low_tmp_a( (use_divide_table)? "m_0_0" : "m_0_0" );
APPEND_M(str( "SHUFPD `tmp_a, `tmp_a, 3" ));
mov_low_tmp_a( (use_divide_table)? "m_0_1" : "m_1_0" );
//<m_1_0, m_1_1>=<abs(vector_v[0]), abs(vector_v[1])>
//for the divide table, this is u[1] and v[1]
APPEND_M(str( "MOVAPD `tmp_a, `vector_v" ));
abs_tmp_a();
mov_low_tmp_a( (use_divide_table)? "m_1_0" : "m_0_1" );
APPEND_M(str( "SHUFPD `tmp_a, `tmp_a, 3" ));
mov_low_tmp_a( (use_divide_table)? "m_1_1" : "m_1_1" );
APPEND_M(str( "MOV `tmp_0, `spill_u_0" ));
APPEND_M(str( "MOV `tmp_1, `spill_u_1" ));
dot_product_exact(regs, {m_0_0, m_0_1}, {tmp_0, tmp_1}, new_u_0, track_asm( "gcd_128 uv overflow", exit_label ));
dot_product_exact(regs, {m_1_0, m_1_1}, {tmp_0, tmp_1}, new_u_1, track_asm( "gcd_128 uv overflow", exit_label ));
APPEND_M(str( "MOV `tmp_0, `spill_v_0" ));
APPEND_M(str( "MOV `tmp_1, `spill_v_1" ));
dot_product_exact(regs, {m_0_0, m_0_1}, {tmp_0, tmp_1}, new_v_0, track_asm( "gcd_128 uv overflow", exit_label ));
dot_product_exact(regs, {m_1_0, m_1_1}, {tmp_0, tmp_1}, new_v_1, track_asm( "gcd_128 uv overflow", exit_label ));
}
//9x scalar
reg_scalar new_ab_0_0=regs_parent.bind_scalar(m, "new_ab_0_0");
reg_scalar new_ab_0_1=regs_parent.bind_scalar(m, "new_ab_0_1");
reg_scalar new_ab_1_0=regs_parent.bind_scalar(m, "new_ab_1_0");
reg_scalar new_ab_1_1=regs_parent.bind_scalar(m, "new_ab_1_1");
reg_scalar new_parity=regs_parent.bind_scalar(m, "new_parity");
{
EXPAND_MACROS_SCOPE;
reg_alloc regs=regs_parent;
//15x scalar
reg_scalar rax=regs.bind_scalar(m, "rax", reg_rax);
reg_scalar rdx=regs.bind_scalar(m, "rdx", reg_rdx);
reg_vector tmp_a=regs.bind_vector(m, "tmp_a");
reg_scalar ab_start_0_0=regs.bind_scalar(m, "ab_start_0_0");
reg_scalar ab_start_0_1=regs.bind_scalar(m, "ab_start_0_1");
reg_scalar ab_start_1_0=regs.bind_scalar(m, "ab_start_1_0");
reg_scalar ab_start_1_1=regs.bind_scalar(m, "ab_start_1_1");
APPEND_M(str( "MOV `ab_start_0_0, `spill_ab_start_0_0" ));
APPEND_M(str( "MOV `ab_start_0_1, `spill_ab_start_0_1" ));
APPEND_M(str( "MOV `ab_start_1_0, `spill_ab_start_1_0" ));
APPEND_M(str( "MOV `ab_start_1_1, `spill_ab_start_1_1" ));
//RAX=(uv_double[1][1]<0)? 1 : 0=uv_double_parity
//(this also works for integers with the divide table)
APPEND_M(str( "MOVAPD `tmp_a, `vector_v" ));
APPEND_M(str( "SHUFPD `tmp_a, `tmp_a, 3" ));
APPEND_M(str( "MOVQ RAX, `tmp_a" ));
APPEND_M(str( "SHR RAX, 63" ));
//new_parity=spill_parity^uv_double_parity
APPEND_M(str( "MOV `new_parity, `spill_parity" ));
APPEND_M(str( "XOR `new_parity, RAX" ));
//[out1:out0]=[a1:a0]*u - [b1:b0]*v
auto dot_product_subtract=[&](string a0, string a1, string b0, string b1, string u, string v, string out0, string out1) {
//[RDX:RAX]=a0*u
APPEND_M(str( "MOV RAX, `#", a0 ));
APPEND_M(str( "MUL `#", u ));
//[out1:out0]=a0*u
APPEND_M(str( "MOV `#, RAX", out0 ));
APPEND_M(str( "MOV `#, RDX", out1 ));
//[RDX:RAX]=a1*u
APPEND_M(str( "MOV RAX, `#", a1 ));
APPEND_M(str( "MUL `#", u ));
//[out1:out0]=a0*u + (a1*u)<<64=a*u
APPEND_M(str( "ADD `#, RAX", out1 ));
//[RDX:RAX]=b0*v
APPEND_M(str( "MOV RAX, `#", b0 ));
APPEND_M(str( "MUL `#", v ));
//[out1:out0]=a*u - b0*v
APPEND_M(str( "SUB `#, RAX", out0 ));
APPEND_M(str( "SBB `#, RDX", out1 ));
//[RDX:RAX]=b1*v
APPEND_M(str( "MOV RAX, `#", b1 ));
APPEND_M(str( "MUL `#", v ));
//[out1:out0]=a*u - b0*v - (b1*v)<<64=a*u - b*v
APPEND_M(str( "SUB `#, RAX", out1 ));
};
// uint64 uv_00=uv_uint64_new[0][0];
// uint64 uv_01=uv_uint64_new[0][1];
// int128 a_new_1=ab_start[0]; a_new_1*=uv_00;
// int128 a_new_2=ab_start[1]; a_new_2*=uv_01;
// if (uv_uint64_parity_new!=0) swap(a_new_1, a_new_2);
// int128 a_new_s=a_new_1-a_new_2;
// uint128 a_new(a_new_s);
dot_product_subtract(
"ab_start_0_0", "ab_start_0_1",
"ab_start_1_0", "ab_start_1_1",
"new_u_0", "new_v_0",
"new_ab_0_0", "new_ab_0_1"
);
// uint64 uv_10=uv_uint64_new[1][0];
// uint64 uv_11=uv_uint64_new[1][1];
// int128 b_new_1=ab_start[1]; b_new_1*=uv_11;
// int128 b_new_2=ab_start[0]; b_new_2*=uv_10;
// if (uv_uint64_parity_new!=0) swap(b_new_1, b_new_2);
// int128 b_new_s=b_new_1-b_new_2;
// uint128 b_new(b_new_s);
dot_product_subtract(
"ab_start_1_0", "ab_start_1_1",
"ab_start_0_0", "ab_start_0_1",
"new_v_1", "new_u_1",
"new_ab_1_0", "new_ab_1_1"
);
APPEND_M(str( "MOV RAX, -1" ));
APPEND_M(str( "ADD RAX, `new_parity" )); //rax=(new_parity==1)? 0 : ~0
APPEND_M(str( "NOT RAX" )); //rax=(new_parity==1)? ~0 : 0
//if (new_parity!=0) { [out1:out0]=-[out1:out0]; }
auto conditional_negate=[&](string out0, string out1) {
//flip all bits if new_parity==1
APPEND_M(str( "XOR `#, RAX", out0 ));
APPEND_M(str( "XOR `#, RAX", out1 ));
//add 1 if new_parity==1
APPEND_M(str( "ADD `#, `new_parity", out0 ));
APPEND_M(str( "ADC `#, 0", out1 ));
};
conditional_negate( "new_ab_0_0", "new_ab_0_1" );
conditional_negate( "new_ab_1_0", "new_ab_1_1" );
}
//11x scalar: new_ab, new_u, new_v, new_parity, ab_threshold
reg_scalar ab_threshold_0=regs_parent.bind_scalar(m, "ab_threshold_0");
reg_scalar ab_threshold_1=regs_parent.bind_scalar(m, "ab_threshold_1");
//flags for [a1:a0]-[b1:b0]:
//CMP a0,b0 ; sets CF if b0>a0. clears CF if b0==a0
//SBB a1,b1 ; sets CF if b>a. sets ZF if b==a. may set ZF if b<a (e.g. a1==0; b1==0; b0<a0)
//CF set: a<b
//CF cleared: a>=b
//need to swap the order for <= and >
{
EXPAND_MACROS_SCOPE;
reg_alloc regs=regs_parent;
//15x scalar
reg_scalar ab_delta_0=regs.bind_scalar(m, "ab_delta_0");
reg_scalar ab_delta_1=regs.bind_scalar(m, "ab_delta_1");
reg_scalar b_new_min=regs.bind_scalar(m, "b_new_min");
reg_scalar is_lehmer=regs.bind_scalar(m, "is_lehmer");
APPEND_M(str( "MOV `is_lehmer, `spill_is_lehmer" ));
//uint128 ab_delta=new_ab[0]-new_ab[1]
APPEND_M(str( "MOV `ab_delta_0, `new_ab_0_0" ));
APPEND_M(str( "MOV `ab_delta_1, `new_ab_0_1" ));
APPEND_M(str( "SUB `ab_delta_0, `new_ab_1_0" ));
APPEND_M(str( "SBB `ab_delta_1, `new_ab_1_1" ));
// assert(a_new>=b_new);
// uint128 ab_delta=a_new-b_new;
//
// even:
// +uv_00 -uv_01
// -uv_10 +uv_11
//
// uint128 v_delta=uint128(v_1)+uint128(v_0); //even: positive. odd: negative
// uint128 u_delta=uint128(u_1)+uint128(u_0); //even: negative. odd: positive
//
// uv_10 is negative if even, positive if odd
// uv_11 is positive if even, negative if odd
// bool passed_even=(b_new>=uint128(u_1) && ab_delta>=v_delta);
// bool passed_odd=(b_new>=uint128(v_1) && ab_delta>=u_delta);
//uint64 uv_delta_0=(even)? new_v_1 : new_u_1;
//uv_delta_0 stored in ab_threshold_0
APPEND_M(str( "CMP `new_parity, 0" ));
APPEND_M(str( "MOV `ab_threshold_0, `new_u_1" ));
APPEND_M(str( "CMOVE `ab_threshold_0, `new_v_1" ));
//uint64 uv_delta_1=(even)? new_v_0 : new_u_0;
//uv_delta_1 stored in ab_threshold_1
APPEND_M(str( "MOV `ab_threshold_1, `new_u_0" ));
APPEND_M(str( "CMOVE `ab_threshold_1, `new_v_0" ));
//uint64 b_new_min=(even)? new_u_1 : new_v_1;
APPEND_M(str( "MOV `b_new_min, `new_v_1" ));
APPEND_M(str( "CMOVE `b_new_min, `new_u_1" ));
//if (!is_lehmer) uv_delta=0
APPEND_M(str( "CMP `is_lehmer, 0" ));
APPEND_M(str( "CMOVE `ab_threshold_0, `is_lehmer" )); //if moved, is_lehmer==0
APPEND_M(str( "CMOVE `ab_threshold_1, `is_lehmer" ));
//if (!is_lehmer) b_new_min=0
APPEND_M(str( "CMOVE `b_new_min, `is_lehmer" ));
//[uv_delta_1:uv_delta_0]=uv_delta_0 + uv_delta_1 //v_delta if even, else u_delta
APPEND_M(str( "ADD `ab_threshold_0, `ab_threshold_1" ));
APPEND_M(str( "MOV `ab_threshold_1, 0" ));
APPEND_M(str( "ADC `ab_threshold_1, 0" ));
//if (ab_delta<uv_delta) goto exit
//clobbers ab_delta
//uv_delta (ab_threshold) not needed anymore
APPEND_M(str( "SUB `ab_delta_0, `ab_threshold_0" ));
APPEND_M(str( "SBB `ab_delta_1, `ab_threshold_1" ));
APPEND_M(str( "JC #", track_asm( "gcd_128 lehmer fail ab_delta<uv_delta", exit_label ) ));
//if (new_ab[1]<b_new_min) goto exit
//clobbers b_new_min
APPEND_M(str( "CMP `new_ab_1_0, `b_new_min" ));
APPEND_M(str( "MOV `b_new_min, `new_ab_1_1" ));
APPEND_M(str( "SBB `b_new_min, 0" ));
APPEND_M(str( "JC #", track_asm( "gcd_128 lehmer fail new_ab[1]<b_new_min", exit_label ) ));
//
//
APPEND_M(str( "MOV `ab_threshold_0, `spill_ab_threshold_0" ));
APPEND_M(str( "MOV `ab_threshold_1, `spill_ab_threshold_1" ));
//if (ab_threshold>=new_ab[0]) goto exit;
APPEND_M(str( "MOV `ab_delta_0, `ab_threshold_0" ));
APPEND_M(str( "MOV `ab_delta_1, `ab_threshold_1" ));
APPEND_M(str( "SUB `ab_delta_0, `new_ab_0_0" ));
APPEND_M(str( "SBB `ab_delta_1, `new_ab_0_1" ));
APPEND_M(str( "JNC #", track_asm( "gcd_128 went too far ab_threshold>=new_ab[0]", exit_label ) ));
//u=new_u;
APPEND_M(str( "MOV `spill_u_0, `new_u_0" ));
APPEND_M(str( "MOV `spill_u_1, `new_u_1" ));
//v=new_v;
APPEND_M(str( "MOV `spill_v_0, `new_v_0" ));
APPEND_M(str( "MOV `spill_v_1, `new_v_1" ));
//parity=new_parity;
APPEND_M(str( "MOV `spill_parity, `new_parity" ));
track_asm( "gcd_128 good iter" );
//--iter;
//if (iter==0) goto exit;
APPEND_M(str( "MOV `ab_delta_0, `spill_iter" ));
APPEND_M(str( "DEC `ab_delta_0" ));
APPEND_M(str( "MOV `spill_iter, `ab_delta_0" ));
APPEND_M(str( "JZ #", track_asm( "gcd_128 good exit", exit_iter_0_label ) ));
}
APPEND_M(str( "#:", start_label ));
//11x scalar: new_ab, new_u, new_v, new_parity, ab_threshold
{
EXPAND_MACROS_SCOPE;
reg_alloc regs=regs_parent;
//4x scalar
reg_scalar tmp_0=regs.bind_scalar(m, "tmp_0", reg_rax);
reg_scalar tmp_1=regs.bind_scalar(m, "tmp_1", reg_rdx);
reg_scalar tmp_2=regs.bind_scalar(m, "tmp_2");
reg_scalar tmp_3=regs.bind_scalar(m, "tmp_3", reg_rcx);
reg_scalar ab_0_0=new_ab_0_0;
reg_scalar ab_0_1=new_ab_0_1;
reg_scalar ab_1_0=new_ab_1_0;
reg_scalar ab_1_1=new_ab_1_1;
m.bind(new_ab_0_0, "ab_0_0");
m.bind(new_ab_0_1, "ab_0_1");
m.bind(new_ab_1_0, "ab_1_0");
m.bind(new_ab_1_1, "ab_1_1");
m.bind(ab_threshold_0, "ab_threshold_0");
m.bind(ab_threshold_1, "ab_threshold_1");
//tmp_3=0
APPEND_M(str( "XOR `tmp_3, `tmp_3" ));
//tmp=ab_1-ab_threshold
APPEND_M(str( "MOV `tmp_0, `ab_1_0" ));
APPEND_M(str( "MOV `tmp_1, `ab_1_1" ));
APPEND_M(str( "SUB `tmp_0, `ab_threshold_0" ));
APPEND_M(str( "SBB `tmp_1, `ab_threshold_1" ));
//if (ab[1]<ab_threshold) goto exit
APPEND_M(str( "JC #", track_asm( "gcd_128 ab[1]<ab_threshold", exit_label ) ));
//if (ab[1]==ab_threshold) goto exit
APPEND_M(str( "MOV `tmp_2, `tmp_0" ));
APPEND_M(str( "OR `tmp_2, `tmp_1" )); //ZF set if tmp_0==0 and tmp_1==0
APPEND_M(str( "JZ #", track_asm( "gcd_128 ab[1]==ab_threshold", exit_label ) ));
//tmp_0=(ab[0][1]==0)? ab[0][0] : ab[0][1]
//tmp_1=(ab[0][1]==0)? 0 : 64
//tmp_0 can't be 0
APPEND_M(str( "MOV `tmp_0, `ab_0_1" ));
APPEND_M(str( "MOV `tmp_1, 64" ));
APPEND_M(str( "CMP `ab_0_1, 0" ));
#ifdef CHIAOSX
string cmoveq_label1=m.alloc_label();
APPEND_M(str( "JNE #", cmoveq_label1));
APPEND_M(str( "MOV `tmp_0, `ab_0_0" ));
APPEND_M(str("#:", cmoveq_label1));
string cmoveq_label2=m.alloc_label();
APPEND_M(str( "JNE #", cmoveq_label2));
APPEND_M(str( "MOV `tmp_1, `tmp_3" ));
APPEND_M(str("#:", cmoveq_label2));
#else
APPEND_M(str( "CMOVEQ `tmp_0, `ab_0_0" ));
APPEND_M(str( "CMOVEQ `tmp_1, `tmp_3" ));
#endif
//tmp_0=[first set bit index in tmp_0]
APPEND_M(str( "BSR `tmp_0, `tmp_0" ));
//tmp_0=[number of bits in ab[0]]=a_num_bits
APPEND_M(str( "ADD `tmp_1, `tmp_0" ));
APPEND_M(str( "INC `tmp_1" ));
//if (is_lehmer) {
// const int min_bits=96;
// if (a_num_bits<min_bits) {
// a_num_bits=min_bits;
// }
//}
//tmp_2=spill_is_lehmer
//tmp_0=((spill_is_lehmer)? 96 : 0)=min_bits
APPEND_M(str( "XOR `tmp_0, `tmp_0" ));
APPEND_M(str( "MOV `tmp_2, `spill_is_lehmer" ));
APPEND_M(str( "CMP `tmp_2, 0" ));
APPEND_M(str( "MOV `tmp_3, 96" ));
APPEND_M(str( "CMOVNE `tmp_0, `tmp_3" ));
APPEND_M(str( "XOR `tmp_3, `tmp_3" ));
//if (a_num_bits<min_bits) a_num_bits=min_bits;
APPEND_M(str( "CMP `tmp_1, `tmp_0" ));
APPEND_M(str( "CMOVB `tmp_1, `tmp_0" ));
//int shift_amount=a_num_bits-gcd_base_bits; [shift amount can't exceed 128-gcd_base_bits]
//if (shift_amount<0) {
// shift_amount=0;
//}
//tmp_1=a_num_bits-gcd_base_bits
APPEND_M(str( "SUB `tmp_1, #", to_hex(gcd_base_bits) ));
//if (a_num_bits<gcd_base_bits) tmp_1=0
//tmp_1=shift_amount
APPEND_M(str( "CMOVB `tmp_1, `tmp_3" ));
//vector_is_lehmer=((spill_is_lehmer | shift_amount)!=0)? <~0, ~0> : <0, 0>
APPEND_M(str( "OR `tmp_2, `tmp_1" ));
if (!use_divide_table) {
#ifdef CHIAOSX
APPEND_M(str( "LEA `tmp_3, [RIP+#]", constant_address_uint64(0ull, 0ull, false) ));
APPEND_M(str( "LEA `tmp_0, [RIP+#]", constant_address_uint64(~(0ull), ~(0ull), false) ));
#else
APPEND_M(str( "MOV `tmp_3, OFFSET FLAT:#", constant_address_uint64(0ull, 0ull, false) ));
APPEND_M(str( "MOV `tmp_0, OFFSET FLAT:#", constant_address_uint64(~(0ull), ~(0ull), false) ));
#endif
} else {
#ifdef CHIAOSX
APPEND_M(str( "LEA `tmp_3, [RIP+#]", constant_address_uint64(gcd_mask_exact[0], gcd_mask_exact[1], false) ));
APPEND_M(str( "LEA `tmp_0, [RIP+#]", constant_address_uint64(gcd_mask_approximate[0], gcd_mask_approximate[1], false) ));
#else
APPEND_M(str( "MOV `tmp_3, OFFSET FLAT:#", constant_address_uint64(gcd_mask_exact[0], gcd_mask_exact[1], false) ));
APPEND_M(str( "MOV `tmp_0, OFFSET FLAT:#", constant_address_uint64(gcd_mask_approximate[0], gcd_mask_approximate[1], false) ));
#endif
}
APPEND_M(str( "CMOVZ `tmp_0, `tmp_3" ));
APPEND_M(str( "MOVAPD `vector_is_lehmer, [`tmp_0]" ));
//vector2 ab_double{
// double(uint64(ab[0]>>shift_amount)),
// double(uint64(ab[1]>>shift_amount))
//};
//double ab_threshold_double(uint64(ab_threshold>>shift_amount));
//if (shift_amount!=0) {
// ++ab_threshold_double; [can do this with integers because the shifted ab_threshold has to fit in a double exactly]
// a is larger than ab_threshold
//}
//vector_ab=<ab_1>>shift_amount, undefined>
//also store integer in new_u_1
shift_right(regs, {ab_1_0, ab_1_1}, tmp_1, new_u_1, tmp_3, tmp_2);
if (!use_divide_table) {
APPEND_M(str( "CVTSI2SD `vector_ab, `new_u_1" ));
}
//vector_ab=<ab_1>>shift_amount, ab_1>>shift_amount>
if (!use_divide_table) {
APPEND_M(str( "SHUFPD `vector_ab, `vector_ab, 0" ));
}
//vector_ab=<ab_0>>shift_amount, ab_1>>shift_amount>
//also store integer in new_u_1
shift_right(regs, {ab_0_0, ab_0_1}, tmp_1, new_u_0, tmp_3, tmp_2);
if (!use_divide_table) {
APPEND_M(str( "CVTSI2SD `vector_ab, `new_u_0" ));
}
//tmp_0=(ab_threshold>>shift_amount)
//also store integer in new_v_0
shift_right(regs, {ab_threshold_0, ab_threshold_1}, tmp_1, new_v_0, tmp_3, tmp_2);
//vector_ab_threshold=<ab_threshold_double, ab_threshold_double>
if (!use_divide_table) {
APPEND_M(str( "CVTSI2SD `vector_ab_threshold, `new_v_0" ));
APPEND_M(str( "SHUFPD `vector_ab_threshold, `vector_ab_threshold, 0" ));
}
}
APPEND_M(str( "JMP #", loop_label ));
//
//
APPEND_M(str( "#:", exit_label ));
{
EXPAND_MACROS_SCOPE;
reg_alloc regs=regs_parent;
reg_scalar tmp=regs.bind_scalar(m, "tmp");
//if (iter==gcd_128_max_iter) goto no_progress
APPEND_M(str( "MOV `tmp, `spill_iter" ));
APPEND_M(str( "CMP `tmp, #", to_hex(gcd_128_max_iter) ));
APPEND_M(str( "JE #", track_asm( "gcd_128 no progress", no_progress_label ) ));
}
APPEND_M(str( "JMP #", track_asm( "gcd_128 premature exit", exit_iter_0_label ) ));
//
//
APPEND_M(str( "#:", start_assign_label ));
APPEND_M(str( "MOV `new_ab_0_0, `spill_ab_start_0_0" ));
APPEND_M(str( "MOV `new_ab_0_1, `spill_ab_start_0_1" ));
APPEND_M(str( "MOV `new_ab_1_0, `spill_ab_start_1_0" ));
APPEND_M(str( "MOV `new_ab_1_1, `spill_ab_start_1_1" ));
APPEND_M(str( "MOV `ab_threshold_0, `spill_ab_threshold_0" ));
APPEND_M(str( "MOV `ab_threshold_1, `spill_ab_threshold_1" ));
APPEND_M(str( "JMP #", start_label ));
//
//
APPEND_M(str( "#:", exit_iter_0_label ));
}
}

View File

@ -0,0 +1,375 @@
namespace asm_code {
const double range_check_range=double((1ull<<53)-1);
const uint64 double_sign_mask=(1ull<<63);
const uint64 double_abs_mask=~double_sign_mask;
//clobbers v
void range_check(
reg_vector v, reg_vector range, reg_vector c_double_abs_mask,
string out_of_range_label
) {
EXPAND_MACROS_SCOPE;
m.bind(range, "range");
m.bind(c_double_abs_mask, "double_abs_mask");
m.bind(v, "tmp");
//tmp=abs(tmp)
APPEND_M(str( "ANDPD `tmp, `double_abs_mask" ));
//tmp all 0s if (abs(tmp0)<=range && abs(tmp1)<=range)
APPEND_M(str( "CMPNLEPD `tmp, `range" ));
//todo //can replace this with POR into an accumulator then use a single PTEST
//todo //can compile the code twice for is_leher being true and false, then branch to the appropriate version
//todo //can probably get rid of the uv range checks if is_lehmer is true
//todo //can get rid of the ab range checks if the table is used and each table uv value has a magnitude less than a certain amount
APPEND_M(str( "PTEST `tmp, `tmp" ));
APPEND_M(str( "JNZ #", out_of_range_label ));
}
//clobbers b
//this calculates the dot product of each lane separately and puts the result in that lane
void dot_product_exact(
array<reg_vector, 2> a, array<reg_vector, 2> b, reg_vector v, reg_vector range, reg_vector c_double_abs_mask,
string out_of_range_label, bool result_always_in_range=false
) {
EXPAND_MACROS_SCOPE;
m.bind(a, "a");
m.bind(b, "b");
m.bind(v, "v");
APPEND_M(str( "MULPD `b_0, `a_0" ));
APPEND_M(str( "MOVAPD `v, `b_0" ));
//todo //for avx, can get rid of a lot of the MOVs by using the 3-operand versions of the instructions
range_check(b[0], range, c_double_abs_mask, out_of_range_label);
if (enable_all_instructions) {
APPEND_M(str( "VFMADD231PD `v, `b_1, `a_1" ));
} else {
APPEND_M(str( "MULPD `b_1, `a_1" ));
APPEND_M(str( "ADDPD `v, `b_1" ));
range_check(b[1], range, c_double_abs_mask, out_of_range_label);
}
if (!result_always_in_range) {
APPEND_M(str( "MOVAPD `b_0, `v" ));
range_check(b[0], range, c_double_abs_mask, out_of_range_label);
}
}
//ab_threshold is the same for both lanes
//is_lehmer is all 1s if lehmer, else all 0s
//will assign u and v
void gcd_base_continued_fraction(
reg_alloc regs,
reg_vector ab, reg_vector u, reg_vector v, reg_vector is_lehmer, reg_vector ab_threshold,
string no_progress_label
) {
EXPAND_MACROS_SCOPE;
track_asm( "gcd_base" );
static double_table<continued_fraction> c_table=generate_table(gcd_table_num_exponent_bits, gcd_table_num_fraction_bits);
static bool outputted_table=false;
if (!outputted_table) {
#ifdef CHIAOSX
APPEND_M(str( ".text " ));
#else
APPEND_M(str( ".text 1" ));
#endif
APPEND_M(str( ".balign 64" ));
APPEND_M(str( "gcd_base_table:" ));
string table_data;
auto out_double=[&](double v) {
if (!table_data.empty()) {
table_data += ", ";
}
table_data+=to_hex(*(uint64*)&v);
};
//each entry is 32 bytes, 32-aligned
for (continued_fraction c : c_table.data) {
matrix2 mat=c.get_matrix();
out_double(mat[0][0]); //lane 0
out_double(mat[1][0]); //lane 1
out_double(mat[0][1]); //lane 0
out_double(mat[1][1]); //lane 1
APPEND_M(str( ".quad #", table_data ));
table_data.clear();
}
APPEND_M(str( ".text" ));
outputted_table=true;
}
//5x vector
m.bind(ab, "ab");
m.bind(u, "u");
m.bind(v, "v");
m.bind(is_lehmer, "is_lehmer");
m.bind(ab_threshold, "ab_threshold");
//11x vector
reg_vector m_0=regs.bind_vector(m, "m_0");
reg_vector m_1=regs.bind_vector(m, "m_1");
reg_vector new_ab=regs.bind_vector(m, "new_ab");
reg_vector new_ab_1=regs.bind_vector(m, "new_ab_1");
reg_vector tmp=regs.bind_vector(m, "tmp");
reg_vector tmp2=regs.bind_vector(m, "tmp2");
reg_vector new_u=regs.bind_vector(m, "new_u");
reg_vector new_v=regs.bind_vector(m, "new_v");
reg_vector q=regs.bind_vector(m, "q");
reg_vector c_range_check_range=regs.bind_vector(m, "range_check_range");
reg_vector c_double_abs_mask=regs.bind_vector(m, "double_abs_mask");
reg_scalar q_scalar=regs.bind_scalar(m, "q_scalar");
reg_scalar q_scalar_2=regs.bind_scalar(m, "q_scalar_2");
reg_scalar q_scalar_3=regs.bind_scalar(m, "q_scalar_3");
reg_scalar loop_counter=regs.bind_scalar(m, "loop_counter");
reg_scalar c_table_delta_minus_1=regs.bind_scalar(m, "c_table_delta_minus_1");
APPEND_M(str( "MOV `c_table_delta_minus_1, #", constant_address_uint64(c_table.delta-1, c_table.delta-1) ));
string exit_label=m.alloc_label();
string loop_label=m.alloc_label();
APPEND_M(str( "MOV `loop_counter, #", to_hex(gcd_base_max_iter) ));
APPEND_M(str( "MOVAPD `u, #", constant_address_double(1.0, 0.0) ));
APPEND_M(str( "MOVAPD `v, #", constant_address_double(0.0, 1.0) ));
APPEND_M(str( "MOVAPD `range_check_range, #", constant_address_double(range_check_range, range_check_range) ));
APPEND_M(str( "MOVAPD `double_abs_mask, #", constant_address_uint64(double_abs_mask, double_abs_mask) ));
// q[0]=ab[0]/ab[1]
APPEND_M(str( "MOVAPD `tmp, `ab" ));
APPEND_M(str( "SHUFPD `tmp, `tmp, 3" )); // tmp=<ab[1], ab[1]>
APPEND_M(str( "MOVAPD `q, `ab" ));
APPEND_M(str( "DIVSD `q, `tmp" ));
{
APPEND_M(str( "#:", loop_label ));
track_asm( "gcd_base iter" );
string no_table_label=m.alloc_label();
APPEND_M( "#gcd_base loop start" );
//q_scalar=q_scalar_2=to_uint64(ab[0]/ab[1])
APPEND_M(str( "MOVQ `q_scalar, `q" ));
APPEND_M(str( "MOV `q_scalar_2, `q_scalar" ));
APPEND_M(str( "MOV `q_scalar_3, `q_scalar" ));
//q_scalar=(to_uint64(ab_0/ab_1)>>c_table.right_shift_amount)<<5
assert(c_table.right_shift_amount>5);
APPEND_M(str( "SHR `q_scalar, #", to_hex(c_table.right_shift_amount-5) ));
APPEND_M(str( "AND `q_scalar, -32" ));
// q_scalar-=c_table.range_start_shifted<<5
// if (q_scalar<0 || q_scalar>=(c_table.range_end_shifted-c_table.range_start_shifted)<<5) goto no_table_label
//this bypasses the "ab[1]<=ab_threshold" check so we need to do it again in no_table_label
APPEND_M(str( "SUB `q_scalar, #", to_hex(c_table.range_start_shifted<<5) ));
APPEND_M(str( "JB #", track_asm( "gcd_base below table start", no_table_label ) ));
APPEND_M(str( "CMP `q_scalar, #", to_hex((c_table.range_end_shifted-c_table.range_start_shifted)<<5) ));
APPEND_M(str( "JAE #", track_asm( "gcd_base after table end", no_table_label ) ));
//m_0: column 0
//m_1: column 1
#ifdef CHIAOSX
APPEND_M(str( "LEA RSI,[RIP+gcd_base_table]"));
APPEND_M(str( "MOVAPD `m_0, [`q_scalar+RSI]" ));
APPEND_M(str( "MOVAPD `m_1, [16+`q_scalar+RSI]" ));
#else
APPEND_M(str( "MOVAPD `m_0, [gcd_base_table+`q_scalar]" ));
APPEND_M(str( "MOVAPD `m_1, [gcd_base_table+16+`q_scalar]" ));
#endif
//if (ab[1]<=ab_threshold) goto exit_label
//this also tests ab[0], which is >= ab[1] so this does nothing
APPEND_M(str( "MOVAPD `tmp, `ab" ));
APPEND_M(str( "CMPLEPD `tmp, `ab_threshold" )); // tmp all 0s if (ab[0]>ab_threshold[0] && ab[1]>ab_threshold[1])
APPEND_M(str( "PTEST `tmp, `tmp" ));
APPEND_M(str( "JNZ #", track_asm( "gcd_base ab[1]<=ab_threshold", exit_label ) ));
//if ( (q_scalar_2&(c_table.delta-1))==0 || (q_scalar_2&(c_table.delta-1))==c_table.delta-1 ) goto no_table_label
APPEND_M(str( "AND `q_scalar_2, `c_table_delta_minus_1" ));
APPEND_M(str( "JZ #", track_asm( "gcd_base on slot boundary", no_table_label ) ));
APPEND_M(str( "CMP `q_scalar_2, `c_table_delta_minus_1" ));
APPEND_M(str( "JE #", track_asm( "gcd_base on slot boundary", no_table_label ) ));
//assigns: new_ab, new_ab_1, q, new_u, new_v
//reads: m, ab, u, v
//clobbers: tmp
auto calculate_using_m=[&](string fail_label) {
APPEND_M(str( "MOVAPD `tmp, `ab" ));
APPEND_M(str( "SHUFPD `tmp, `tmp, 0" ));
APPEND_M(str( "MOVAPD `tmp2, `ab" ));
APPEND_M(str( "SHUFPD `tmp2, `tmp2, 3" ));
dot_product_exact(
{m_0, m_1}, {tmp, tmp2}, new_ab, c_range_check_range, c_double_abs_mask,
track_asm( "gcd_base ab range check failed", fail_label),
true
);
APPEND_M(str( "MOVAPD `new_ab_1, `new_ab" ));
APPEND_M(str( "SHUFPD `new_ab_1, `new_ab_1, 3" )); // new_ab_1=<new_ab[1], new_ab[1]>
// q[0]=new_ab[0]/new_ab[1]
// this clobbers q if the table is not used
APPEND_M(str( "MOVAPD `q, `new_ab" ));
APPEND_M(str( "DIVSD `q, `new_ab_1" ));
APPEND_M(str( "MOVAPD `tmp, `u" ));
APPEND_M(str( "SHUFPD `tmp, `tmp, 0" ));
APPEND_M(str( "MOVAPD `tmp2, `u" ));
APPEND_M(str( "SHUFPD `tmp2, `tmp2, 3" ));
dot_product_exact(
{m_0, m_1}, {tmp, tmp2}, new_u, c_range_check_range, c_double_abs_mask,
track_asm( "gcd_base uv range check failed", fail_label)
);
//todo //for avx, can replace some shuffles with broadcasts. can make a macro that expands to the proper instructions
APPEND_M(str( "MOVAPD `tmp, `v" ));
APPEND_M(str( "SHUFPD `tmp, `tmp, 0" ));
APPEND_M(str( "MOVAPD `tmp2, `v" ));
APPEND_M(str( "SHUFPD `tmp2, `tmp2, 3" ));
dot_product_exact(
{m_0, m_1}, {tmp, tmp2}, new_v, c_range_check_range, c_double_abs_mask,
track_asm( "gcd_base uv range check failed", fail_label)
);
};
calculate_using_m(no_table_label);
//if (new_ab[0]<=ab_threshold) goto no_table_label
APPEND_M(str( "UCOMISD `new_ab, `ab_threshold" ));
APPEND_M(str( "JBE #", track_asm( "gcd_base new_ab[0]<=ab_threshold for table", no_table_label ) ));
string lehmer_label=m.alloc_label();
APPEND_M(str( "JMP #", lehmer_label ));
APPEND_M(str( "#:", no_table_label ));
APPEND_M( "#gcd_base no table" );
{
track_asm( "gcd_base iter no table" );
//have to do this check here because it might have been skipped: if (ab[1]<=ab_threshold) goto exit_label
APPEND_M(str( "MOVAPD `tmp, `ab" ));
APPEND_M(str( "CMPLEPD `tmp, `ab_threshold" )); // tmp all 0s if (ab[0]>ab_threshold[0] && ab[1]>ab_threshold[1])
APPEND_M(str( "PTEST `tmp, `tmp" ));
APPEND_M(str( "JNZ #", track_asm( "gcd_base ab[1]<=ab_threshold", exit_label ) ));
//q is clobbered, so need to restore it
APPEND_M(str( "MOVQ `q, `q_scalar_3" ));
// q=floor(q);
//this requires SSE4. if not present, can also add and subtract a magic number
APPEND_M(str( "ROUNDSD `q, `q, 1" )); //floor
// m=[0 1]
// 1 -q]
// m_0=<0,1> [column 0]
// m_1=<1,-q> [column 1]
APPEND_M(str( "MOVAPD `m_0, #", constant_address_double(0.0, 1.0) ));
APPEND_M(str( "MOVAPD `m_1, `m_0" )); // m_1=<0,1>
APPEND_M(str( "SUBSD `m_1, `q" )); //m_1=<-q,1>
APPEND_M(str( "SHUFPD `m_1, `m_1, 1" )); //m_1=<1,-q>
calculate_using_m(exit_label);
}
APPEND_M(str( "#:", lehmer_label ));
APPEND_M( "#gcd_base end no table" );
// new_ab_0=<new_ab[0], new_ab[0]>
// new_ab_1=<new_ab[1], new_ab[1]>
// ab_delta=new_ab_0-new_ab_1
// new_uv_0=<new_u[0], new_v[0]>
// new_uv_1=<new_u[1], new_v[1]>
//bool passed=
// new_ab_1[0]>=-new_uv_1[0] && ab_delta[0]+new_uv_0[0]>=new_uv_1[0] &&
// new_ab_1[1]>=-new_uv_1[1] && ab_delta[1]+new_uv_0[1]>=new_uv_1[1]
//;
//bool passed=
// new_ab_1[0]>=-new_uv_1[0] && ab_delta[0]+new_vu_0[0]>=new_vu_1[0] &&
// new_ab_1[1]>=-new_uv_1[1] && ab_delta[1]+new_vu_0[1]>=new_vu_1[1]
//;
//bool passed=
// new_ab[1]>=-new_u[1] && ab_delta[0]+new_v[0]>=new_v[1] &&
// new_ab[1]>=-new_v[1] && ab_delta[0]+new_u[0]>=new_u[1]
//;
//m_0=new_uv_0=<new_u[0], new_v[0]>
APPEND_M(str( "MOVAPD `m_0, `new_u" ));
APPEND_M(str( "SHUFPD `m_0, `new_v, 0" ));
//m_1=new_uv_1=<new_u[1], new_v[1]>
APPEND_M(str( "MOVAPD `m_1, `new_u" ));
APPEND_M(str( "SHUFPD `m_1, `new_v, 3" ));
//tmp=new_ab_0=<new_ab[0], new_ab[0]>
APPEND_M(str( "MOVAPD `tmp, `new_ab" ));
APPEND_M(str( "SHUFPD `tmp, `tmp, 0" ));
//tmp=ab_delta=new_ab_0-new_ab_1
APPEND_M(str( "SUBPD `tmp, `new_ab_1" ));
//tmp=ab_delta+new_uv_0
APPEND_M(str( "ADDPD `tmp, `m_0" ));
//tmp all 0s if (ab_delta[0]+new_uv_0[0]>=new_uv_1[0] && ab_delta[1]+new_uv_0[1]>=new_uv_1[1])
APPEND_M(str( "CMPLTPD `tmp, `m_1" ));
//m_1=-new_uv_1
APPEND_M(str( "XORPD `m_1, #", constant_address_uint64(double_sign_mask, double_sign_mask) ));
//new_ab_1 all 0s if (new_ab_1[0]>=-new_uv_1[0] && new_ab_1[1]>=-new_uv_1[1])
APPEND_M(str( "CMPLTPD `new_ab_1, `m_1" ));
//if (is_lehmer && !(ab_delta[0]+new_uv_0[0]>=new_uv_1[0] && ab_delta[1]+new_uv_0[1]>=new_uv_1[1])) goto exit_label
//if (is_lehmer && !(new_ab_1[0]>=-new_uv_1[0] && new_ab_1[1]>=-new_uv_1[1])) goto exit_label
APPEND_M(str( "ORPD `tmp, `new_ab_1" )); //tmp all 0s if passed is true
APPEND_M(str( "ANDPD `tmp, `is_lehmer" )); //tmp all 0s if passed||(!is_lehmer) is true
APPEND_M(str( "PTEST `tmp, `tmp" ));
APPEND_M(str( "JNZ #", track_asm( "gcd_base lehmer failed", exit_label ) ));
APPEND_M(str( "MOVAPD `ab, `new_ab" ));
APPEND_M(str( "MOVAPD `u, `new_u" ));
APPEND_M(str( "MOVAPD `v, `new_v" ));
track_asm( "gcd_base good iter" );
APPEND_M(str( "DEC `loop_counter" ));
APPEND_M(str( "JNZ #", loop_label ));
APPEND_M( "#gcd_base loop end" );
}
track_asm( "gcd_base good exit" );
APPEND_M(str( "#:", exit_label ));
APPEND_M(str( "CMP `loop_counter, #", to_hex(gcd_base_max_iter) ));
APPEND_M(str( "JE #", track_asm( "gcd_base no progress", no_progress_label ) ));
}
}

View File

@ -0,0 +1,185 @@
namespace asm_code {
//regs: 1x scalar (RAX) + 4x scalar arguments (r==RDX)
//todo //test hit rate
void divide_table(reg_alloc regs, reg_scalar a, reg_scalar b, reg_scalar q, reg_scalar r) {
EXPAND_MACROS_SCOPE;
regs.get_scalar(reg_rax);
m.bind(a, "a");
m.bind(b, "b");
m.bind(q, "q");
assert(r.value==reg_rdx.value);
static bool outputted_table=false;
if (!outputted_table) {
#ifdef CHIAOSX
APPEND_M(str( ".text " ));
#else
APPEND_M(str( ".text 1" ));
#endif
APPEND_M(str( ".balign 64" ));
APPEND_M(str( "divide_table:" ));
const int expected_size=1<<divide_table_index_bits;
const int max_index=bit_sequence(0, divide_table_index_bits);
assert(max_index>=1);
int num=0;
auto add=[&](uint64 v) {
APPEND_M(str( ".quad #", to_hex(v) ));
++num;
};
add(0);
for (int index=1;index<=max_index;++index) {
uint128 v = (~uint128(0)) / uint128(index);
v>>=64;
add(v);
}
assert(num==expected_size);
APPEND_M(str( ".text" ));
outputted_table=true;
}
string b_shift_label=m.alloc_label();
APPEND_M(str( "BSR `q, `b" )); // b_shift = bsr(b)
APPEND_M(str( "SUB `q, #", to_hex(divide_table_index_bits-1) )); // b_shift = bsr(b)-(divide_table_index_bits-1)
APPEND_M(str( "JNB #", b_shift_label ));
APPEND_M(str( "XOR `q, `q" )); // if (b_shift<0) b_shift=0
APPEND_M(str( "#:", b_shift_label ));
APPEND_M(str( "SARX RAX, `b, `q" )); // b_approx = b>>b_shift
APPEND_M(str( "MOV RAX, [divide_table+RAX*8]" )); // b_approx_inverse = divide_table[b_approx]
APPEND_M(str( "IMUL `a" )); // q = (b_approx_inverse*a)>>64
APPEND_M(str( "SARX `q, RDX, `q" )); // q = q>>b_shift
string wrong_remainder_label=m.alloc_label();
APPEND_M(str( "MOV RAX, `q" ));
APPEND_M(str( "IMUL RAX, `b" )); // r = q*b
APPEND_M(str( "JO #", wrong_remainder_label )); // overflow
APPEND_M(str( "MOV RDX, `a" ));
APPEND_M(str( "SUB RDX, RAX" )); // r = a-q*b
APPEND_M(str( "JO #", wrong_remainder_label )); // overflow
APPEND_M(str( "CMP RDX, `b" ));
APPEND_M(str( "JAE #", wrong_remainder_label )); // !(r>=0 && r<b)
string end_label=m.alloc_label();
APPEND_M(str( "JMP #", end_label ));
const bool asm_output_common_case_only=false;
if (!asm_output_common_case_only) {
APPEND_M(str( "#:", wrong_remainder_label ));
APPEND_M(str( "MOV RDX, `a" ));
APPEND_M(str( "SAR RDX, #", to_hex(63) )); //all 1s if negative, all 0s if nonnegative
APPEND_M(str( "MOV RAX, `a" ));
APPEND_M(str( "IDIV `b" )); // RAX=a/b ; RDX=r=a%b
APPEND_M(str( "MOV `q, RAX" ));
APPEND_M(str( "CMP RDX, 0" ));
APPEND_M(str( "JGE #", end_label )); // r>=0
APPEND_M(str( "ADD RDX, `b" )); // r+=b
APPEND_M(str( "DEC `q" ));
}
APPEND_M(str( "#:", end_label ));
}
const array<uint64, 2> gcd_mask_approximate={1ull<<63, 1ull<<63};
const array<uint64, 2> gcd_mask_exact={0, 0};
//regs: 3x scalar, 3x vector, 2x scalar argument, 2x vector argument
//uv[0] is: u[0], v[0]. int64
//uv[1] is: u[1], v[1]
//c_gcd_mask is gcd_mask_approximate or gcd_mask_exact
//a is int64
void gcd_64_iteration(
reg_alloc regs, reg_vector c_gcd_mask, array<reg_scalar, 2> a, array<reg_vector, 2> uv, reg_scalar ab_threshold,
string early_exit_label
) {
EXPAND_MACROS_SCOPE;
m.bind(c_gcd_mask, "c_gcd_mask");
m.bind(a, "a");
m.bind(uv, "uv");
m.bind(ab_threshold, "ab_threshold");
reg_scalar q=regs.bind_scalar(m, "q");
reg_scalar r=regs.bind_scalar(m, "r", reg_rdx);
reg_scalar tmp_a=regs.bind_scalar(m, "tmp_a");
//new_uv_0 = uv[1]
reg_vector new_uv_1=regs.bind_vector(m, "new_uv_1");
reg_vector tmp_1=regs.bind_vector(m, "tmp_1");
reg_vector tmp_2=regs.bind_vector(m, "tmp_2");
APPEND_M(str( "CMP `a_1, `ab_threshold" ));
APPEND_M(str( "JBE #", early_exit_label ));
divide_table(regs, a[0], a[1], q, r);
APPEND_M(str( "MOV `tmp_a, `q" ));
APPEND_M(str( "SHL `tmp_a, #", to_hex(63-gcd_num_quotient_bits) ));
APPEND_M(str( "SAR `tmp_a, #", to_hex(63-gcd_num_quotient_bits) ));
APPEND_M(str( "CMP `tmp_a, `q" ));
APPEND_M(str( "JNE #", early_exit_label )); //quotient is too big
APPEND_M(str( "MOV `a_0, `a_1" ));
APPEND_M(str( "MOV `a_1, `r" ));
APPEND_M(str( "VMOVQ `new_uv_1_128, `q" ));
APPEND_M(str( "VPBROADCASTQ `new_uv_1, `new_uv_1_128" )); // new_uv_1 = q
APPEND_M(str( "VPMULDQ `new_uv_1, `new_uv_1, `uv_1" )); // new_uv_1 = q*uv[1]
APPEND_M(str( "VPSUBQ `new_uv_1, `uv_0, `new_uv_1" )); // new_uv_1 = uv[0] - q*uv[1]
//overflow checking:
//-the carry_mask bits must be all 0s or all 1s for each 64-bit entry
//-if 1<<data_size is added, the carry_mask bits must be all 0s (negative) or 1<<data_size with the rest 0 (nonnegative)
//-can add 1<<data_size, then check the carry_mask except the last bit
APPEND_M(str( "VPADDQ `tmp_1, `new_uv_1, #", constant_address_uint64(1ull<<data_size, 1ull<<data_size) ));
APPEND_M(str( "VPTEST `tmp_1, #", constant_address_uint64(carry_mask & (~(1ull<<data_size)), carry_mask & (~(1ull<<data_size))) ));
APPEND_M(str( "JNZ #", early_exit_label ));
{
APPEND_M(str( "VMOVQ `tmp_1_128, `a_0" ));
APPEND_M(str( "VPBROADCASTQ `tmp_1, `tmp_1_128" )); // tmp_1 = a[0]
APPEND_M(str( "VPADDQ `tmp_1, `tmp_1, `uv_1" )); // tmp_1 = a[0]+new_uv[0]
APPEND_M(str( "VMOVQ `tmp_2_128, `a_1" ));
APPEND_M(str( "VPBROADCASTQ `tmp_2, `tmp_2_128" )); // tmp_2 = a[1]
APPEND_M(str( "VPADDQ `tmp_2, `tmp_2, `new_uv_1" )); // tmp_2 = a[1]+new_uv[1]
APPEND_M(str( "VPSUBQ `tmp_1, `tmp_1, `tmp_2" )); // tmp_1 = a[0]+new_uv[0]-(a[1]+new_uv[1])
APPEND_M(str( "VPOR `tmp_1, `tmp_1, `tmp_2" )); // sign is 1 if tmp_1<0 or tmp_2<0
//approximate: ZF set if both signs of tmp_1 are 0 (i.e tmp_1>=0 and tmp_2>=0 for both lanes)
//exact: ZF set always
APPEND_M(str( "VPTEST `tmp_1, `c_gcd_mask" ));
APPEND_M(str( "JNZ #", early_exit_label )); //taken if ZF==0
//int64 delta=new_a[0]-new_a[1];
//if (new_a[1]<-new_uv[1]) goto early_exit_label
//if (delta<new_uv[1]-new_uv[0]) goto early_exit_label
//if (new_a[1]+new_uv[1]<0) goto early_exit_label
//if (new_a[0]+new_uv[0]-(new_a[1]+new_uv[1])<0) goto early_exit_label
}
APPEND_M(str( "VMOVDQU `uv_0, `uv_1" )); //>= ab_threshold
APPEND_M(str( "VMOVDQU `uv_1, `new_uv_1" ));
}
}

View File

@ -0,0 +1,796 @@
namespace asm_code {
struct asm_integer {
//if a sign limb exists, it is one qword before this address. the data limbs are after this address
reg_scalar addr_base;
//the asm_integer functions only use addr_base. this is used to assign addr_base if it needs to be allocated
reg_spill addr_base_spill;
int addr_offset=0;
bool is_signed=false;
int size=0; //limbs. lsb limb is first. this is a multiple of 4
asm_integer() {}
asm_integer(reg_spill t_spill, int t_size) {
addr_base_spill=t_spill;
size=t_size;
}
string operator[](int pos) {
assert(pos>=0 && pos<size);
return str( "[#+#]", addr_base.name(), to_hex(addr_offset+pos*8) );
}
bool is_null() {
return size==0;
}
//end_index will return the number of nonzero limbs minus 1
//end_index should initially be >= the number nonzero of limbs minus 1, but not more than size-1
//if the integer is 0, end_index should initially be at least 0 and the returned end_index is 0
//regs: 3x scalar
void update_end_index(reg_alloc regs, reg_scalar end_index) {
EXPAND_MACROS_SCOPE;
assert(size%4==0);
assert(addr_offset==0); //can temporarily modify addr_base if this is false
m.bind(end_index, "end_index");
m.bind(addr_base, "addr_base");
reg_scalar tmp_value=regs.bind_scalar(m, "tmp_value");
reg_scalar tmp_0=regs.bind_scalar(m, "tmp_0");
reg_scalar tmp_8=regs.bind_scalar(m, "tmp_8");
//convert index to address
APPEND_M(str( "LEA `end_index, [`addr_base+`end_index*8]" ));
APPEND_M(str( "XOR `tmp_0, `tmp_0" ));
APPEND_M(str( "MOV `tmp_8, 8" ));
string loop_label=m.alloc_label();
const int num_unroll=2;
assert(num_unroll>=1);
for (int x=0;x<num_unroll;++x) {
if (x==num_unroll-1) {
APPEND_M(str( "#:", loop_label ));
}
APPEND_M(str( "MOV `tmp_value, [`end_index]" ));
//tmp_value=(tmp_value==0)? 8 : 0
//(8 if the last limb is 0, else 0)
APPEND_M(str( "CMP `tmp_value, `tmp_0" ));
APPEND_M(str( "MOV `tmp_value, `tmp_0" ));
APPEND_M(str( "CMOVE `tmp_value, `tmp_8" ));
//if (end_index==end_addr) tmp_value=0
//(sets tmp_value to 0 if there is only 1 limb left)
APPEND_M(str( "CMP `end_index, `addr_base" ));
APPEND_M(str( "CMOVE `tmp_value, `tmp_0" ));
//if tmp_value==8, go to the next lowest limb
//if tmp_value==0, do nothing
APPEND_M(str( "SUB `end_index, `tmp_value" ));
if (x==1) {
//keep looping until end_index stops changing
APPEND_M(str( "CMP `tmp_value, `tmp_0" ));
APPEND_M(str( "JNE #", track_asm( "update_end_index loop", loop_label ) ));
}
}
//convert address to index
APPEND_M(str( "SUB `end_index, `addr_base" ));
APPEND_M(str( "SHR `end_index, 3" ));
}
//end_index=(end_index<2)? 0 : end_index-2
//regs: 1x scalar
void calculate_head_start(reg_alloc regs, reg_scalar end_index) {
EXPAND_MACROS_SCOPE;
assert(size%4==0);
m.bind(end_index, "end_index");
reg_scalar tmp=regs.bind_scalar(m, "tmp");
APPEND_M(str( "XOR `tmp, `tmp" ));
APPEND_M(str( "SUB `end_index, 2" ));
APPEND_M(str( "CMOVB `end_index, `tmp" ));
}
//this is the same as extract_head, except that extracts at nonzero_size
//nonzero_size should be >= the actual nonzero size to avoid truncation
//regs: 1x scalar
void extract_head_at(reg_alloc regs, reg_scalar head_start, array<reg_scalar, 3> res) {
EXPAND_MACROS_SCOPE;
assert(size%4==0);
m.bind(addr_base, "addr_base");
m.bind(head_start, "head_start");
m.bind(res, "res");
reg_scalar tmp_addr=regs.bind_scalar(m, "tmp_addr");
APPEND_M(str( "LEA `tmp_addr, [`addr_base+`head_start*8+#]", to_hex(addr_offset) ));
APPEND_M(str( "MOV `res_0, [`tmp_addr]" ));
APPEND_M(str( "MOV `res_1, [`tmp_addr+8]" ));
APPEND_M(str( "MOV `res_2, [`tmp_addr+16]" ));
}
void mul_add_bmi(
reg_alloc regs, asm_integer a, reg_scalar b, asm_integer c, bool invert_output, bool carry_in_is_1
) {
EXPAND_MACROS_SCOPE;
m.bind(b, "b");
//5x scalar
reg_scalar mul_low_0=regs.bind_scalar(m, "mul_low_0");
reg_scalar mul_low_1=regs.bind_scalar(m, "mul_low_1");
reg_scalar mul_high_0=regs.bind_scalar(m, "mul_high_0");
reg_scalar mul_high_1=regs.bind_scalar(m, "mul_high_1");
reg_scalar rdx=regs.bind_scalar(m, "rdx", reg_rdx);
//clears OF and CF
APPEND_M(str( "XOR RDX, RDX" ));
if (carry_in_is_1) {
APPEND_M(str( "STC" ));
}
APPEND_M(str( "MOV RDX, `b" ));
for (int pos=0;pos<size;pos+=2) {
bool first=(pos==0);
//mul_low=mul_low+mul_high>>64
APPEND_M(str( "MULX `mul_high_0, `mul_low_0, #", a[pos] ));
if (!first) {
APPEND_M(str( "ADOX `mul_low_0, `mul_high_1" ));
}
APPEND_M(str( "MULX `mul_high_1, `mul_low_1, #", a[pos+1] ));
APPEND_M(str( "ADOX `mul_low_1, `mul_high_0" ));
if (!c.is_null()) {
APPEND_M(str( "ADCX `mul_low_0, #", c[pos] ));
APPEND_M(str( "ADCX `mul_low_1, #", c[pos+1] ));
}
if (invert_output) {
APPEND_M(str( "NOT `mul_low_0" ));
APPEND_M(str( "NOT `mul_low_1" ));
}
APPEND_M(str( "MOV #, `mul_low_0", (*this)[pos] ));
APPEND_M(str( "MOV #, `mul_low_1", (*this)[pos+1] ));
}
}
void mul_add_slow(
reg_alloc regs, asm_integer a, reg_scalar b, asm_integer c, bool invert_output, bool carry_in_is_1
) {
EXPAND_MACROS_SCOPE;
m.bind(b, "b");
//11x scalar
reg_scalar mul_carry=regs.bind_scalar(m, "mul_carry");
reg_scalar add_carry=regs.bind_scalar(m, "add_carry");
reg_scalar mul_high_4_previous=regs.bind_scalar(m, "mul_high_4_previous");
reg_scalar mul_low_0=regs.bind_scalar(m, "mul_low_0");
reg_scalar mul_low_1=regs.bind_scalar(m, "mul_low_1");
reg_scalar mul_low_2=regs.bind_scalar(m, "mul_low_2");
reg_scalar mul_low_3=regs.bind_scalar(m, "mul_low_3", reg_rax);
reg_scalar mul_high_0=regs.bind_scalar(m, "mul_high_0");
reg_scalar mul_high_1=regs.bind_scalar(m, "mul_high_1");
reg_scalar mul_high_2=regs.bind_scalar(m, "mul_high_2");
reg_scalar mul_high_3=regs.bind_scalar(m, "mul_high_3", reg_rdx);
for (int pos=0;pos<size;pos+=4) {
bool first=(pos==0);
bool last=(pos==size-4);
//multiply 4 values of a by b
for (int x=0;x<4;++x) {
//mul_low_3=RAX
//mul_high_3=RDX
APPEND_M(str( "MOV RAX, `b" ));
APPEND_M(str( "MUL QWORD PTR #", a[pos+x] ));
if (x==3) {
assert(mul_low_3.value==reg_rax.value);
assert(mul_high_3.value==reg_rdx.value);
} else {
APPEND_M(str( "MOV `mul_low_#, RAX", x ));
APPEND_M(str( "MOV `mul_high_#, RDX", x ));
}
}
//mul_low=mul_low+mul_high>>64
if (first) {
//mul_carry==0 ; mul_high_4_previous==0
APPEND_M(str( "ADD `mul_low_1, `mul_high_0" ));
} else {
APPEND_M(str( "ADD `mul_carry, 1" )); // CF=(mul_carry==-1)? 1 : 0
APPEND_M(str( "ADC `mul_low_0, `mul_high_4_previous" ));
APPEND_M(str( "ADC `mul_low_1, `mul_high_0" ));
}
APPEND_M(str( "ADC `mul_low_2, `mul_high_1" ));
APPEND_M(str( "ADC `mul_low_3, `mul_high_2" ));
if (!last) {
APPEND_M(str( "MOV `mul_high_4_previous, `mul_high_3" ));
APPEND_M(str( "SBB `mul_carry, `mul_carry" )); // mul_carry=(CF)? -1 : 0
}
if (!c.is_null()) {
//mul_low=mul_low+c
//output mul_low
if (first) {
if (carry_in_is_1) {
APPEND_M(str( "STC" ));
APPEND_M(str( "ADC `mul_low_0, #", c[pos] ));
} else {
APPEND_M(str( "ADD `mul_low_0, #", c[pos] ));
}
} else {
APPEND_M(str( "ADD `add_carry, 1" )); // CF=(add_carry==-1)? 1 : 0
APPEND_M(str( "ADC `mul_low_0, #", c[pos] ));
}
for (int x=1;x<4;++x) {
APPEND_M(str( "ADC `mul_low_#, #", x, c[pos+x] ));
}
if (!last) {
APPEND_M(str( "SBB `add_carry, `add_carry" )); // add_carry=(CF)? -1 : 0
}
}
for (int x=0;x<4;++x) {
if (invert_output) {
APPEND_M(str( "NOT `mul_low_#", x ));
}
APPEND_M(str( "MOV #, `mul_low_#", (*this)[pos+x], x ));
}
}
}
// (*this)=a*b+c+(carry_in_is_1? 1 : 0)
// if (invert_output) (*this)=~(*this)
//all of the integers must have the same size (which is a multiple of 4)
//a or c can alias with *this (as long as the aliasing is not partial)
//regs: 11x scalar
//
//to calculate a*b-c*d:
//-first calculate ~(c*d)
//-then calculate a*b+(~(c*d))+1
void mul_add(
reg_alloc regs, asm_integer a, reg_scalar b, asm_integer c, bool invert_output, bool carry_in_is_1
) {
EXPAND_MACROS_SCOPE;
assert(!carry_in_is_1 || !c.is_null());
assert(size%4==0);
assert(size==a.size && (c.is_null() || size==c.size));
if (enable_all_instructions) {
mul_add_bmi(regs, a, b, c, invert_output, carry_in_is_1);
} else {
mul_add_slow(regs, a, b, c, invert_output, carry_in_is_1);
}
}
};
//sets res to the right shift amount required for the uppermost limb to be 0. this is between 0 and 64 inclusive
//regs: 1x scalar
void calculate_shift_amount(reg_alloc regs, array<reg_scalar, 3> limbs, reg_scalar res) {
EXPAND_MACROS_SCOPE;
m.bind(limbs, "limbs");
m.bind(res, "res");
reg_scalar tmp=regs.bind_scalar(m, "tmp");
//res=[first set bit index in limbs_2]+1
APPEND_M(str( "BSR `res, `limbs_2" ));
APPEND_M(str( "INC `res" ));
//res=num bits of limbs_2 [which is also the right shift amount]
//(this is 0 if limbs_2 is 0)
APPEND_M(str( "XOR `tmp, `tmp" ));
APPEND_M(str( "CMP `limbs_2, `tmp" ));
APPEND_M(str( "CMOVE `res, `tmp" ));
}
//amount must be >=0 and <=64
//this only calculates the lower 2 limbs of the result
//regs: 1x scalar
//in-place
void shift_right(reg_alloc regs, array<reg_scalar, 3> limbs, reg_scalar amount) {
EXPAND_MACROS_SCOPE;
m.bind(limbs, "limbs");
m.bind(amount, "amount");
regs.get_scalar(reg_rcx);
APPEND_M(str( "MOV RCX, `amount" ));
// if (amount<64) res[0]=[limbs[1]:limbs[0]]>>amount
// if (amount==64) no-op
APPEND_M(str( "SHRD `limbs_0, `limbs_1, CL" ));
// if (amount<64) res[1]=[limbs[2]:limbs[1]]>>amount
// if (amount==64) no-op
APPEND_M(str( "SHRD `limbs_1, `limbs_2, CL" ));
APPEND_M(str( "CMP `amount, 64" ));
APPEND_M(str( "CMOVE `limbs_0, `limbs_1" ));
APPEND_M(str( "CMOVE `limbs_1, `limbs_2" ));
}
//this must be true: a>=b; a>=threshold
//
//all of the integers should have spilled addresses with offsets of 0. all of their sizes should be the same
//the input a and b values should go into spill_a and spill_b. spill_a_2 and spill_b_2 should be uninitialized
//spill_iter will be between -1 and max_iterations
//the final a value is in spill_a if spill_iter is odd, otherwise is is in a_2. same with b
//
//for each iteration, including iteration -1, the following will happen:
//-64 bytes of data is written to *(spill_out_uv_addr + iter*64)
//-then, *spill_uv_counter_addr is set to spill_uv_counter_start+iter
//
//the data has the following format: [u0] [u1] [v0] [v1] [parity] [exit_flag]
//-each entry is 8 bytes
//-if iter is -1, only exit_flag is initialized and the rest have undefined values
//-if exit_flag is 1, this is the final result
//
//no more than max_iterations+1 results will be outputted. there will be an error if there are more results than this
//(this includes iteration -1)
//
//spill_a_end_index must be < a's size and >= 0. any limbs past this must be 0 for a, b, and threshold, but only up to the next
// multiple of 4 limbs. (e.g. if spill_a_end_index is 6, there are 7 limbs so the 8th limb must be 0 and the rest can be uninitialized)
//
//the return value of iter is the total number of iterations performed, which is at least 0. iter-1 is the parity of the last iteration
void gcd_unsigned(
reg_alloc regs_parent,
asm_integer spill_a, asm_integer spill_b, asm_integer spill_a_2, asm_integer spill_b_2, asm_integer spill_threshold,
reg_spill spill_uv_counter_start, reg_spill spill_out_uv_counter_addr, reg_spill spill_out_uv_addr,
reg_spill spill_iter, reg_spill spill_a_end_index, int max_iterations
) {
EXPAND_MACROS_SCOPE_PUBLIC;
track_asm( "gcd_unsigned" );
int int_size=spill_a.size;
assert(spill_a.addr_offset==0 && spill_b.addr_offset==0 && spill_threshold.addr_offset==0);
assert(spill_a.addr_base.value==-1 && spill_b.addr_base.value==-1 && spill_threshold.addr_base.value==-1);
assert(spill_a_2.addr_offset==0 && spill_b_2.addr_offset==0);
assert(spill_a_2.addr_base.value==-1 && spill_b_2.addr_base.value==-1);
assert(spill_a.size==int_size && spill_b.size==int_size && spill_threshold.size==int_size);
assert(spill_a_2.size==int_size && spill_b_2.size==int_size);
m.bind(spill_a.addr_base_spill, "spill_a_addr_base");
m.bind(spill_a_2.addr_base_spill, "spill_a_2_addr_base");
m.bind(spill_b.addr_base_spill, "spill_b_addr_base");
m.bind(spill_b_2.addr_base_spill, "spill_b_2_addr_base");
m.bind(spill_threshold.addr_base_spill, "spill_threshold_addr_base");
m.bind(spill_iter, "spill_iter");
m.bind(spill_uv_counter_start, "spill_uv_counter_start");
m.bind(spill_out_uv_addr, "spill_out_uv_addr");
m.bind(spill_out_uv_counter_addr, "spill_out_uv_counter_addr");
m.bind(spill_a_end_index, "spill_a_end_index");
reg_spill spill_u_0=regs_parent.bind_spill(m, "spill_u_0");
reg_spill spill_u_1=regs_parent.bind_spill(m, "spill_u_1");
reg_spill spill_v_0=regs_parent.bind_spill(m, "spill_v_0");
reg_spill spill_v_1=regs_parent.bind_spill(m, "spill_v_1");
reg_spill spill_parity=regs_parent.bind_spill(m, "spill_parity");
reg_spill spill_is_lehmer=regs_parent.bind_spill(m, "spill_is_lehmer");
reg_spill spill_a_128=regs_parent.bind_spill(m, "spill_a_128", 16, 8);
reg_spill spill_b_128=regs_parent.bind_spill(m, "spill_b_128", 16, 8);
reg_spill spill_threshold_128=regs_parent.bind_spill(m, "spill_threshold_128", 16, 8);
m.bind(spill_a_128+8, "spill_a_128_8");
m.bind(spill_b_128+8, "spill_b_128_8");
m.bind(spill_threshold_128+8, "spill_threshold_128_8");
APPEND_M(str( "MOV QWORD PTR `spill_iter, -1" ));
string loop_start=m.alloc_label();
string loop=m.alloc_label();
string loop_exit=m.alloc_label();
APPEND_M(str( "JMP #", loop_start ));
APPEND_M(str( "#:", loop ));
//iter even: old_a=a , old_b=b ; new_a=a_2, new_b=b_2
//iter odd: old_a=a_2, old_b=b_2 ; new_a=a , new_b=b
gcd_128(
regs_parent,
{spill_a_128, spill_b_128}, {spill_u_0, spill_u_1}, {spill_v_0, spill_v_1},
spill_parity, spill_is_lehmer, spill_threshold_128,
track_asm( "gcd_unsigned error: gcd 128 stuck", m.alloc_error_label() )
);
string exit_multiply_uv=m.alloc_label();
{
EXPAND_MACROS_SCOPE;
reg_alloc regs=regs_parent;
reg_scalar tmp=regs.bind_scalar(m, "tmp");
string jump_table_label=m.alloc_label();
#ifdef CHIAOSX
APPEND_M(str( ".text " ));
#else
APPEND_M(str( ".text 1" ));
#endif
APPEND_M(str( ".balign 8" ));
APPEND_M(str( "#:", jump_table_label ));
#ifdef CHIAOSX
APPEND_M(str( ".text" ));
APPEND_M(str( "MOV `tmp, `spill_a_end_index" ));
for (int end_index=0;end_index<int_size;++end_index) {
int size=end_index+1;
int mapped_size=size;
while (mapped_size==0 || mapped_size%4!=0) {
++mapped_size;
}
APPEND_M(str( "CMP `tmp, #", size ));
APPEND_M(str( "JE multiply_uv_size_#", mapped_size ));
}
#else
for (int end_index=0;end_index<int_size;++end_index) {
int size=end_index+1;
int mapped_size=size;
while (mapped_size==0 || mapped_size%4!=0) {
++mapped_size;
}
APPEND_M(str( ".quad multiply_uv_size_#", mapped_size ));
}
APPEND_M(str( ".text" ));
APPEND_M(str( "MOV `tmp, `spill_a_end_index" ));
APPEND_M(str( "JMP QWORD PTR [#+`tmp*8]", jump_table_label ));
#endif
}
for (int size=4;size<=int_size;size+=4) {
EXPAND_MACROS_SCOPE;
reg_alloc regs=regs_parent;
APPEND_M(str( "multiply_uv_size_#:", size ));
track_asm(str( "gcd_unsigned multiply uv size #", size ));
//reg_scalar t=regs.bind_scalar(m, "t");
// even:
// new_a=a*u_0 - b*v_0;
// new_a=b*v_1 - a*u_1;
//
// tmp0=b*v_0
// tmp1=a*u_1
// new_a=a*u_0 - tmp0
// new_b=b*v_1 - tmp1
//
// odd:
// new_a=b*v_0 - a*u_0;
// new_b=a*u_1 - b*v_1;
//
// tmp0=a*u_0
// tmp1=b*v_1
// new_a=b*v_0 - tmp0
// new_b=a*u_1 - tmp1
//
// in general:
// tmp0=(even?b:a)*(even?v_0:u_0)
// tmp1=(even?a:b)*(even?u_1:v_1)
// new_a=(even?a:b)*(even?u_0:v_0) - tmp0
// new_b=(even?b:a)*(even?v_1:u_1) - tmp1
reg_scalar addr_a=regs.bind_scalar(m, "addr_a");
reg_scalar addr_b=regs.bind_scalar(m, "addr_b");
reg_scalar addr_new=regs.bind_scalar(m, "addr_new");
reg_scalar tmp=regs.bind_scalar(m, "tmp");
reg_spill spill_mod_u_0=regs.bind_spill(m, "spill_mod_u_0");
reg_spill spill_mod_u_1=regs.bind_spill(m, "spill_mod_u_1");
reg_spill spill_mod_v_0=regs.bind_spill(m, "spill_mod_v_0");
reg_spill spill_mod_v_1=regs.bind_spill(m, "spill_mod_v_1");
reg_spill spill_addr_b_new=regs.bind_spill(m, "spill_addr_b_new");
APPEND_M(str( "MOV `tmp, `spill_parity" ));
APPEND_M(str( "CMP `tmp, 0" ));
for (int x=0;x<2;++x) {
APPEND_M(str( "MOV `addr_a, `spill_u_#", x ));
APPEND_M(str( "MOV `addr_b, `spill_v_#", x ));
//if (spill_parity!=0) swap(u[x], v[x])
APPEND_M(str( "MOV `addr_new, `addr_a" ));
APPEND_M(str( "CMOVNE `addr_a, `addr_b" ));
APPEND_M(str( "CMOVNE `addr_b, `addr_new" ));
APPEND_M(str( "MOV `spill_mod_u_#, `addr_a", x ));
APPEND_M(str( "MOV `spill_mod_v_#, `addr_b", x ));
}
APPEND_M(str( "MOV `addr_new, `spill_iter" ));
APPEND_M(str( "TEST `addr_new, 1" )); // ZF=even iteration
//addr_a=(even iteration)? &a : &a_2
APPEND_M(str( "MOV `addr_a, `spill_a_addr_base" ));
APPEND_M(str( "CMOVNZ `addr_a, `spill_a_2_addr_base" ));
//addr_b=(even iteration)? &b : &b_2
APPEND_M(str( "MOV `addr_b, `spill_b_addr_base" ));
APPEND_M(str( "CMOVNZ `addr_b, `spill_b_2_addr_base" ));
//if (spill_parity!=0) swap(addr_a, addr_b)
APPEND_M(str( "CMP `tmp, 0" ));
APPEND_M(str( "MOV `addr_new, `addr_a" ));
APPEND_M(str( "CMOVNE `addr_a, `addr_b" ));
APPEND_M(str( "CMOVNE `addr_b, `addr_new" ));
//done using tmp (spill_parity)
//spill_addr_b_new=(even iteration)? &b_2 : &b
APPEND_M(str( "MOV `addr_new, `spill_iter" ));
APPEND_M(str( "TEST `addr_new, 1" )); // ZF=even iteration
APPEND_M(str( "MOV `addr_new, `spill_b_2_addr_base" ));
APPEND_M(str( "CMOVNZ `addr_new, `spill_b_addr_base" ));
APPEND_M(str( "MOV `spill_addr_b_new, `addr_new" ));
//addr_new=(even iteration)? &a_2 : &a
APPEND_M(str( "MOV `addr_new, `spill_a_2_addr_base" ));
APPEND_M(str( "CMOVNZ `addr_new, `spill_a_addr_base" ));
//this can be a, a_2, b, or b_2 depending on iter and parity
asm_integer a;
a.size=int_size;
a.addr_base=addr_a;
asm_integer b;
b.size=int_size;
b.addr_base=addr_b;
//initially new_a
asm_integer new_ab;
new_ab.size=int_size;
new_ab.addr_base=addr_new;
reg_spill tmp0_spill=regs.get_spill(int_size*8, 8);
asm_integer tmp0;
tmp0.size=int_size;
tmp0.addr_base=reg_rsp;
tmp0.addr_offset=tmp0_spill.get_rsp_offset();
reg_spill tmp1_spill=regs.get_spill(int_size*8, 8);
asm_integer tmp1;
tmp1.size=int_size;
tmp1.addr_base=reg_rsp;
tmp1.addr_offset=tmp1_spill.get_rsp_offset();
// tmp0=(even?b:a)*(even?v_0:u_0)
APPEND_M(str( "MOV `tmp, `spill_mod_v_0" ));
tmp0.mul_add(regs, b, tmp, asm_integer(), true, false);
// tmp1=(even?a:b)*(even?u_1:v_1)
APPEND_M(str( "MOV `tmp, `spill_mod_u_1" ));
tmp1.mul_add(regs, a, tmp, asm_integer(), true, false);
// new_a=(even?a:b)*(even?u_0:v_0) - tmp0
APPEND_M(str( "MOV `tmp, `spill_mod_u_0" ));
new_ab.mul_add(regs, a, tmp, tmp0, false, true);
// new_b=(even?b:a)*(even?v_1:u_1) - tmp1
APPEND_M(str( "MOV `addr_new, `spill_addr_b_new" ));
APPEND_M(str( "MOV `tmp, `spill_mod_v_1" ));
new_ab.mul_add(regs, b, tmp, tmp1, false, true);
APPEND_M(str( "JMP #", exit_multiply_uv ));
}
APPEND_M(str( "#:", exit_multiply_uv ));
//8x
reg_scalar iter=regs_parent.bind_scalar(m, "iter");
reg_scalar is_lehmer=regs_parent.bind_scalar(m, "is_lehmer");
reg_scalar a_head_0=regs_parent.bind_scalar(m, "a_head_0");
reg_scalar a_head_1=regs_parent.bind_scalar(m, "a_head_1");
reg_scalar b_head_0=regs_parent.bind_scalar(m, "b_head_0");
reg_scalar b_head_1=regs_parent.bind_scalar(m, "b_head_1");
reg_scalar a_head_start=regs_parent.bind_scalar(m, "a_head_start");
reg_scalar shift_right_amount=regs_parent.bind_scalar(m, "shift_right_amount");
APPEND_M(str( "#:", loop_start ));
{
EXPAND_MACROS_SCOPE;
reg_alloc regs=regs_parent;
//6x + 3x from called functions
reg_scalar addr_a=regs.bind_scalar(m, "addr_a", reg_rax);
reg_scalar addr_b=regs.bind_scalar(m, "addr_b", reg_rdx);
reg_scalar b_head_2=regs.bind_scalar(m, "b_head_2");
reg_scalar a_head_2=regs.bind_scalar(m, "a_head_2");
APPEND_M(str( "MOV `iter, `spill_iter" ));
//addr_a=(even iteration)? &a_2 : &a
APPEND_M(str( "TEST `iter, 1" )); // ZF=even iteration
APPEND_M(str( "MOV `addr_a, `spill_a_2_addr_base" ));
APPEND_M(str( "CMOVNZ `addr_a, `spill_a_addr_base" ));
//addr_b=(even iteration)? &b_2 : &b
APPEND_M(str( "MOV `addr_b, `spill_b_2_addr_base" ));
APPEND_M(str( "CMOVNZ `addr_b, `spill_b_addr_base" ));
asm_integer a;
a.size=int_size;
a.addr_base=addr_a;
asm_integer b;
b.size=int_size;
b.addr_base=addr_b;
APPEND_M(str( "MOV `a_head_start, `spill_a_end_index" ));
a.update_end_index(regs, a_head_start);
APPEND_M(str( "MOV `spill_a_end_index, `a_head_start" ));
//is_lehmer=(a_end_index>=2)
//(a_end_index is stored in a_head_start)
APPEND_M(str( "XOR `is_lehmer, `is_lehmer" ));
APPEND_M(str( "CMP `a_head_start, 2" ));
APPEND_M(str( "SETAE `is_lehmer_8" ));
APPEND_M(str( "MOV `spill_is_lehmer, `is_lehmer" ));
a.calculate_head_start(regs, a_head_start);
a.extract_head_at(regs, a_head_start, {a_head_0, a_head_1, a_head_2});
calculate_shift_amount(regs, {a_head_0, a_head_1, a_head_2}, shift_right_amount);
shift_right(regs, {a_head_0, a_head_1, a_head_2}, shift_right_amount);
b.extract_head_at(regs, a_head_start, {b_head_0, b_head_1, b_head_2});
shift_right(regs, {b_head_0, b_head_1, b_head_2}, shift_right_amount);
APPEND_M(str( "MOV `spill_a_128, `a_head_0" ));
APPEND_M(str( "MOV `spill_a_128_8, `a_head_1" ));
APPEND_M(str( "MOV `spill_b_128, `b_head_0" ));
APPEND_M(str( "MOV `spill_b_128_8, `b_head_1" ));
}
//9x
//iter, is_lehmer, b_head_0, b_head_1, a_head_start, shift_right_amount
reg_scalar exit_flag=regs_parent.bind_scalar(m, "exit_flag");
//clobbers is_lehmer
{
EXPAND_MACROS_SCOPE;
reg_alloc regs=regs_parent;
//4x + 1x from called functions
reg_scalar addr_threshold=regs.bind_scalar(m, "addr_threshold", reg_rax);
reg_scalar threshold_head_0=regs.bind_scalar(m, "threshold_head_0", reg_rdx);
reg_scalar threshold_head_1=regs.bind_scalar(m, "threshold_head_1");
reg_scalar threshold_head_2=regs.bind_scalar(m, "threshold_head_2");
//addr_threshold=&threshold
APPEND_M(str( "MOV `addr_threshold, `spill_threshold_addr_base" ));
asm_integer threshold;
threshold.size=int_size;
threshold.addr_base=addr_threshold;
threshold.extract_head_at(regs, a_head_start, {threshold_head_0, threshold_head_1, threshold_head_2});
shift_right(regs, {threshold_head_0, threshold_head_1, threshold_head_2}, shift_right_amount);
APPEND_M(str( "MOV `spill_threshold_128, `threshold_head_0" ));
APPEND_M(str( "MOV `spill_threshold_128_8, `threshold_head_1" ));
//if (a_head<=threshold_head) goto error
APPEND_M(str( "MOV `addr_threshold, `threshold_head_0" ));
APPEND_M(str( "MOV `threshold_head_2, `threshold_head_1" ));
APPEND_M(str( "SUB `addr_threshold, `a_head_0" ));
APPEND_M(str( "SBB `threshold_head_2, `a_head_1" ));
APPEND_M(str( "JNC #", track_asm( "gcd_unsigned error: a_head<=threshold_head", m.alloc_error_label() ) ));
//threshold_head' = threshold_head-b_head
APPEND_M(str( "XOR `exit_flag, `exit_flag" ));
APPEND_M(str( "SUB `threshold_head_0, `b_head_0" ));
APPEND_M(str( "SBB `threshold_head_1, `b_head_1" ));
APPEND_M(str( "SETNC `exit_flag_8" )); //exit_flag = (threshold_head>=b_head)
//if (b_head==threshold_head && is_lehmer) goto error
APPEND_M(str( "OR `threshold_head_0, `threshold_head_1" ));
APPEND_M(str( "DEC `is_lehmer" )); // is_lehmer'=(is_lehmer)? 0 : ~0
APPEND_M(str( "OR `threshold_head_0, `is_lehmer" )); //ZF = (threshold_head'==0 && is_lehmer)
APPEND_M(str( "JZ #", track_asm( "gcd_unsigned error: b_head==threshold_head and is_lehmer", m.alloc_error_label() ) ));
}
//9x
{
EXPAND_MACROS_SCOPE;
reg_alloc regs=regs_parent;
//2x
reg_scalar out_uv_addr=regs.bind_scalar(m, "out_uv_addr");
reg_scalar tmp=regs.bind_scalar(m, "tmp");
//out_uv_addr = spill_out_uv_addr + iter*64
//note: iter can be -1
APPEND_M(str( "MOV `out_uv_addr, `iter" ));
APPEND_M(str( "SHL `out_uv_addr, 6" ));
APPEND_M(str( "ADD `out_uv_addr, `spill_out_uv_addr" ));
APPEND_M(str( "MOV `tmp, `spill_u_0" ));
APPEND_M(str( "MOV [`out_uv_addr], `tmp" ));
APPEND_M(str( "MOV `tmp, `spill_u_1" ));
APPEND_M(str( "MOV [`out_uv_addr+8], `tmp" ));
APPEND_M(str( "MOV `tmp, `spill_v_0" ));
APPEND_M(str( "MOV [`out_uv_addr+16], `tmp" ));
APPEND_M(str( "MOV `tmp, `spill_v_1" ));
APPEND_M(str( "MOV [`out_uv_addr+24], `tmp" ));
APPEND_M(str( "MOV `tmp, `spill_parity" ));
APPEND_M(str( "MOV [`out_uv_addr+32], `tmp" ));
APPEND_M(str( "MOV [`out_uv_addr+40], `exit_flag" ));
//done assigning the data; can now increment the counter. this is not atomic because only this thread can write to the counter
//(the counter must be 8-aligned)
//x86 uses acq_rel ordering on all of the loads and stores so no fences are required
APPEND_M(str( "MOV `tmp, `spill_uv_counter_start" ));
APPEND_M(str( "ADD `tmp, `iter" ));
APPEND_M(str( "MOV `out_uv_addr, `spill_out_uv_counter_addr" ));
APPEND_M(str( "MOV [`out_uv_addr], `tmp" ));
APPEND_M(str( "INC `iter" ));
APPEND_M(str( "MOV `spill_iter, `iter" ));
APPEND_M(str( "CMP `exit_flag, 0" ));
APPEND_M(str( "JNE #", loop_exit ));
APPEND_M(str( "CMP `iter, #", to_hex(max_iterations) )); //signed
APPEND_M(str( "JGE #", track_asm( "gcd_unsigned error: max_iterations exceeded", m.alloc_error_label() ) ));
}
APPEND_M(str( "JMP #", loop ));
APPEND_M(str( "#:", loop_exit ));
}
}

View File

@ -0,0 +1,250 @@
#ifdef GENERATE_ASM_TRACKING_DATA
#ifndef COMPILE_ASM
extern "C" uint64 asm_tracking_data[num_asm_tracking_data];
extern "C" char* asm_tracking_data_comments[num_asm_tracking_data];
uint64 asm_tracking_data[num_asm_tracking_data];
char* asm_tracking_data_comments[num_asm_tracking_data];
#endif
#endif
namespace asm_code {
//all doubles are arrays with 2 entries. the high entry is first followed by the low entry
//so: b, a; u1, u0; v1, v0
//is_lehmer is all 1s or all 0s. ab_threshold is duplicated twice
extern "C" int asm_func_gcd_base(double* ab, double* u, double* v, uint64* is_lehmer, double* ab_threshold, uint64* no_progress);
#ifdef COMPILE_ASM
void compile_asm_gcd_base() {
EXPAND_MACROS_SCOPE;
asm_function c_func( "gcd_base", 6 );
reg_alloc regs=c_func.regs;
reg_vector ab=regs.bind_vector(m, "ab");
reg_vector u=regs.bind_vector(m, "u");
reg_vector v=regs.bind_vector(m, "v");
reg_vector is_lehmer=regs.bind_vector(m, "is_lehmer");
reg_vector ab_threshold=regs.bind_vector(m, "ab_threshold");
m.bind(c_func.args.at(0), "ab_addr");
m.bind(c_func.args.at(1), "u_addr");
m.bind(c_func.args.at(2), "v_addr");
m.bind(c_func.args.at(3), "is_lehmer_addr");
m.bind(c_func.args.at(4), "ab_threshold_addr");
m.bind(c_func.args.at(5), "no_progress_addr");
APPEND_M(str( "MOVDQU `ab, [`ab_addr]" ));
APPEND_M(str( "MOVDQU `u, [`u_addr]" ));
APPEND_M(str( "MOVDQU `v, [`v_addr]" ));
APPEND_M(str( "MOVDQU `is_lehmer, [`is_lehmer_addr]" ));
APPEND_M(str( "MOVDQU `ab_threshold, [`ab_threshold_addr]" ));
string no_progress_label=m.alloc_label();
string progress_label=m.alloc_label();
string exit_label=m.alloc_label();
gcd_base_continued_fraction(regs, ab, u, v, is_lehmer, ab_threshold, no_progress_label);
APPEND_M(str( "JMP #", progress_label ));
APPEND_M(str( "#:", no_progress_label ));
APPEND_M(str( "MOV QWORD PTR [`no_progress_addr], 1" ));
APPEND_M(str( "JMP #", exit_label ));
APPEND_M(str( "#:", progress_label ));
APPEND_M(str( "MOV QWORD PTR [`no_progress_addr], 0" ));
APPEND_M(str( "#:", exit_label ));
APPEND_M(str( "MOVDQU [`ab_addr], `ab" ));
APPEND_M(str( "MOVDQU [`u_addr], `u" ));
APPEND_M(str( "MOVDQU [`v_addr], `v" ));
APPEND_M(str( "MOVDQU [`is_lehmer_addr], `is_lehmer" ));
APPEND_M(str( "MOVDQU [`ab_threshold_addr], `ab_threshold" ));
}
#endif
//104 bytes
struct asm_func_gcd_128_data {
//4
uint64 ab_start_0_0;
uint64 ab_start_0_8;
uint64 ab_start_1_0;
uint64 ab_start_1_8;
//4
uint64 u_0;
uint64 u_1;
uint64 v_0;
uint64 v_1;
//5
uint64 parity; //1 if odd, else 0
uint64 is_lehmer; //1 if true, else 0
uint64 ab_threshold_0;
uint64 ab_threshold_8;
uint64 no_progress;
};
extern "C" int asm_func_gcd_128(asm_func_gcd_128_data* data);
#ifdef COMPILE_ASM
void compile_asm_gcd_128() {
EXPAND_MACROS_SCOPE_PUBLIC;
asm_function c_func( "gcd_128", 1 );
reg_alloc regs_parent=c_func.regs;
reg_spill spill_data_addr=regs_parent.bind_spill(m, "spill_data_addr");
reg_spill spill_data=regs_parent.bind_spill(m, "spill_data", sizeof(asm_func_gcd_128_data), 8);
assert(sizeof(asm_func_gcd_128_data)%8==0);
{
EXPAND_MACROS_SCOPE;
reg_alloc regs=regs_parent;
m.bind(c_func.args.at(0), "data_addr");
reg_scalar tmp=regs.bind_scalar(m, "tmp");
APPEND_M(str( "MOV `spill_data_addr, `data_addr" ));
for (int x=0;x<sizeof(asm_func_gcd_128_data)/8;++x) {
APPEND_M(str( "MOV `tmp, [`data_addr+#]", to_hex(x*8) ));
APPEND_M(str( "MOV #, `tmp", (spill_data+8*x).name() ));
}
}
regs_parent.add(c_func.args.at(0));
c_func.args.clear();
string no_progress_label=m.alloc_label();
string progress_label=m.alloc_label();
string exit_label=m.alloc_label();
gcd_128(
regs_parent,
{spill_data, spill_data+16}, {spill_data+32, spill_data+40}, {spill_data+48, spill_data+56},
spill_data+64, spill_data+72, spill_data+80, no_progress_label
);
{
EXPAND_MACROS_SCOPE;
reg_alloc regs=regs_parent;
reg_scalar tmp=regs.bind_scalar(m, "tmp");
reg_scalar data_addr=regs.bind_scalar(m, "data_addr");
APPEND_M(str( "JMP #", progress_label ));
APPEND_M(str( "#:", no_progress_label ));
APPEND_M(str( "MOV `tmp, 1" ));
APPEND_M(str( "JMP #", exit_label ));
APPEND_M(str( "#:", progress_label ));
APPEND_M(str( "MOV `tmp, 0" ));
APPEND_M(str( "#:", exit_label ));
APPEND_M(str( "MOV #, `tmp", (spill_data+96).name() ));
APPEND_M(str( "MOV `data_addr, `spill_data_addr" ));
for (int x=0;x<sizeof(asm_func_gcd_128_data)/8;++x) {
APPEND_M(str( "MOV `tmp, #", (spill_data+8*x).name() ));
APPEND_M(str( "MOV [`data_addr+#], `tmp", to_hex(x*8) ));
}
}
}
#endif
struct asm_func_gcd_unsigned_data {
uint64* a;
uint64* b;
uint64* a_2;
uint64* b_2;
uint64* threshold;
uint64 uv_counter_start;
uint64* out_uv_counter_addr;
uint64* out_uv_addr;
int64 iter;
uint64 a_end_index;
};
extern "C" int asm_func_gcd_unsigned(asm_func_gcd_unsigned_data* data);
#ifdef COMPILE_ASM
void compile_asm_gcd_unsigned() {
EXPAND_MACROS_SCOPE_PUBLIC;
const int int_size=gcd_size;
const int max_iterations=gcd_max_iterations;
asm_function c_func( "gcd_unsigned", 1 );
reg_alloc regs_parent=c_func.regs;
reg_spill spill_data_addr=regs_parent.bind_spill(m, "spill_data_addr");
reg_spill spill_data=regs_parent.bind_spill(m, "spill_data", sizeof(asm_func_gcd_unsigned_data), 8);
assert(sizeof(asm_func_gcd_unsigned_data)%8==0);
{
EXPAND_MACROS_SCOPE;
reg_alloc regs=regs_parent;
m.bind(c_func.args.at(0), "data_addr");
reg_scalar tmp=regs.bind_scalar(m, "tmp");
APPEND_M(str( "MOV `spill_data_addr, `data_addr" ));
for (int x=0;x<sizeof(asm_func_gcd_unsigned_data)/8;++x) {
APPEND_M(str( "MOV `tmp, [`data_addr+#]", to_hex(x*8) ));
APPEND_M(str( "MOV #, `tmp", (spill_data+8*x).name() ));
}
}
regs_parent.add(c_func.args.at(0));
c_func.args.clear();
gcd_unsigned(
regs_parent,
asm_integer(spill_data, int_size), asm_integer(spill_data+8, int_size),
asm_integer(spill_data+16, int_size), asm_integer(spill_data+24, int_size), asm_integer(spill_data+32, int_size),
spill_data+40, spill_data+48, spill_data+56,
spill_data+64, spill_data+72, max_iterations
);
{
EXPAND_MACROS_SCOPE;
reg_alloc regs=regs_parent;
reg_scalar tmp=regs.bind_scalar(m, "tmp");
reg_scalar data_addr=regs.bind_scalar(m, "data_addr");
APPEND_M(str( "MOV `data_addr, `spill_data_addr" ));
for (int x=0;x<sizeof(asm_func_gcd_unsigned_data)/8;++x) {
APPEND_M(str( "MOV `tmp, #", (spill_data+8*x).name() ));
APPEND_M(str( "MOV [`data_addr+#], `tmp", to_hex(x*8) ));
}
}
}
#endif
#ifdef COMPILE_ASM
void compile_asm() {
compile_asm_gcd_base();
compile_asm_gcd_128();
compile_asm_gcd_unsigned();
ofstream out( "asm_compiled.s" );
out << m.format_res_text();
}
#endif
}

View File

@ -0,0 +1,664 @@
/*
0 rax
1 rbx
2 rcx
3 rdx
rax, rcx, rdx, rbx, rsp, rbp, rsi, rdi, r8, r9, r10, r11, r12, r13, r14, r15
rsp - stack pointer (used by stack engine)
rax/rdx - output of multiplication and division; temporaries
notation:
-each name is either a 64 bit scalar register or a 256 bit ymm register
-for ymm registers, a "_128" suffix is used for the xmm register
-for scalar registers: "_32" is used for "eax/r8d/etc", "_16" is used for "ax/r8w/etc", "_8" is used for "al/r8b/etc"
-writing to a 32 bit register zero-extends the result to 64 bits. writing to a 8/16 bit register does not zero extend
***/
const int spill_bytes=1024;
const int comment_asm_line_size=40;
const vector<string> scalar_register_names_64={
"RSP", // 0 - stack pointer; used by stack engine etc. not allocated
"RAX", // 1 - temporary; used for mul/div/etc. this is allocated last
"RDX", // 2 - temporary; used for mul/div/etc. allocated 2nd last
"RCX", // 3 - temporary; used for shr/etc. allocated 3rd last
"RBX", // 4
"RBP", // 5
"RSI", // 6
"RDI", // 7
"R8", // 8
"R9", // 9
"R10", // 10
"R11", // 11
"R12", // 12
"R13", // 13
"R14", // 14
"R15" // 15
};
const vector<string> scalar_register_names_32={
"ESP" , "EAX" , "EDX" , "ECX" ,
"EBX" , "EBP" , "ESI" , "EDI" ,
"R8D" , "R9D" , "R10D", "R11D",
"R12D", "R13D", "R14D", "R15D"
};
const vector<string> scalar_register_names_16={
"SP" , "AX" , "DX" , "CX" ,
"BX" , "BP" , "SI" , "DI" ,
"R8W" , "R9W" , "R10W", "R11W",
"R12W", "R13W", "R14W", "R15W"
};
const vector<string> scalar_register_names_8={
"SPL" , "AL" , "DL" , "CL" ,
"BL" , "BPL" , "SIL" , "DIL" ,
"R8B" , "R9B" , "R10B", "R11B",
"R12B", "R13B", "R14B", "R15B"
};
string to_hex(int128 i) {
int128 i_abs=(i<0)? -i : i;
assert(i_abs>=0);
assert(uint64(i_abs)==i_abs);
ostringstream ss;
ss << ((i<0)? "-" : "") << "0x" << hex << uint64(i_abs);
return ss.str();
}
void str_impl(vector<string>& out) {}
template<class type_a, class... types> void str_impl(
vector<string>& out, const type_a& a, const types&... targs
) {
out.push_back(to_string(a));
str_impl(out, targs...);
}
template<class... types> string str(const string& t, const types&... targs) {
vector<string> data;
str_impl(data, targs...);
string res;
int next=0;
for (char c : t) {
if (c=='#') {
res+=data.at(next);
++next;
} else {
res+=c;
}
}
assert(next==data.size());
return res;
}
struct expand_macros_recording {
int start_pos=-1;
int end_pos=-1;
~expand_macros_recording() {
assert((start_pos==-1 && end_pos==-1) || (start_pos!=-1 && end_pos!=-1));
}
};
struct expand_macros {
struct scope_data {
string scope_name;
map<string, string> name_to_value;
bool is_public=false;
};
vector<scope_data> scopes;
map<string, set<pair<int, string>>> value_to_name; //int is scope
vector<vector<string>> res_text; //first entry is tag
int next_label_id=0;
int next_error_label_id=1; //can't be 0 since the id is used as the return code
int next_output_error_label_id=1;
int num_active_recordings=0;
vector<string> tag_stack;
bool output_tags=false;
void begin_recording(expand_macros_recording& res) {
assert(res.start_pos==-1 && res.end_pos==-1);
res.start_pos=res_text.size();
++num_active_recordings;
}
vector<vector<string>> end_recording(expand_macros_recording& res) {
assert(res.start_pos!=-1 && res.end_pos==-1);
res.end_pos=res_text.size();
--num_active_recordings;
vector<vector<string>> c_text;
for (int x=res.start_pos;x<res.end_pos;++x) {
c_text.push_back(res_text.at(x));
}
return c_text;
}
void append_recording(vector<vector<string>> c_text) {
for (auto& c : c_text) {
res_text.push_back(c);
}
}
string alloc_label() {
assert(num_active_recordings==0);
string res = "_label_" + to_string(next_label_id);
++next_label_id;
return res;
}
string alloc_error_label() {
assert(num_active_recordings==0);
string res = "label_error_" + to_string(next_error_label_id);
++next_error_label_id;
return res;
}
void begin_scope(string name, bool is_public=false) {
scopes.emplace_back(scope_data());
scopes.back().scope_name=name;
scopes.back().is_public=is_public;
}
void end_scope() {
assert(!scopes.empty());
for (pair<const string, string>& n : scopes.back().name_to_value) {
bool erase_res=value_to_name.at(n.second).erase(make_pair(scopes.size()-1, n.first));
assert(erase_res);
}
scopes.pop_back();
}
void bind_impl(string name, string value) {
assert(!scopes.empty());
bool emplace_res_1=scopes.back().name_to_value.emplace(name, value).second;
assert(emplace_res_1);
bool emplace_res_2=value_to_name[value].emplace(scopes.size()-1, name).second;
assert(emplace_res_2);
}
string lookup_value(string name) {
for (int x=scopes.size()-1;x>=0;--x) {
if (x!=scopes.size()-1 && !scopes[x].is_public) {
continue;
}
auto i=scopes[x].name_to_value.find(name);
if (i!=scopes[x].name_to_value.end()) {
return i->second;
}
}
assert(false);
return "";
}
string describe_scope() {
string res;
for (auto& c : scopes) {
if (!res.empty()) {
res+="/";
}
res+=c.scope_name;
}
return res;
}
string describe_name(string name) {
string value=lookup_value(name);
set<pair<int, string>>& names=value_to_name.at(value);
string res;
res+=name;
res+="=";
res+=value;
if (names.size()>=2) {
res+="(";
bool first=true;
for (auto& c : names) {
if (!first) {
res+=",";
}
if (c.second!=name) {
res+=c.second;
first=false;
}
}
res+=")";
}
return res;
}
pair<string, vector<string>> expand(string s) {
string res;
vector<string> res_names;
string buffer;
bool in_name=false;
s+='\0';
for (char c : s) {
if (in_name) {
if ((c>='0' && c<='9') || (c>='A' && c<='Z') || (c>='a' && c<='z') || c=='_') {
buffer+=c;
} else {
in_name=false;
res+=lookup_value(buffer);
res_names.push_back(buffer);
buffer.clear();
}
}
if (!in_name) {
if (c=='`') {
in_name=true;
} else {
if (c!='\0') {
res+=c;
}
}
}
}
return make_pair(res, res_names);
}
void append(string s, int line, string file, string func) {
bool add_comment=true;
assert(!s.empty());
auto r=expand(s);
res_text.emplace_back();
res_text.back().push_back((tag_stack.empty())? "" : tag_stack.back());
res_text.back().push_back(r.first);
if (add_comment) {
res_text.back().push_back( " # " + scopes.back().scope_name + ":" + to_string(line) + " " );
res_text.back().push_back(s);
}
}
template<class type> typename type::bindable bind(const type& a, string n) {
a.bind_impl(*this, n);
}
template<class type> struct void_box {
typedef void value;
};
template<class type> typename void_box<typename type::value_type>::value bind(
const type& a, string n
) {
int x=0;
for (const auto& c : a) {
bind(c, n + "_" + to_string(x));
++x;
}
}
string format_res_text() {
string res;
vector<int> sizes;
int next_line=1;
for (vector<string>& c : res_text) {
string c_tag=c.at(0);
if (output_tags && !c_tag.empty()) {
c_tag = "_" + c_tag;
}
c.at(1)=str( "Xx_##: ", next_line, c_tag ) + c.at(1);
++next_line;
for (int x=1;x<c.size();++x) {
while (sizes.size()<=x) {
sizes.push_back(0);
}
sizes[x]=max(sizes[x], int(c[x].size()));
}
}
sizes.at(1)=comment_asm_line_size;
for (vector<string>& c : res_text) {
for (int x=1;x<c.size();++x) {
res+=c[x];
if (x!=c.size()-1) {
for (int y=c[x].size();y<sizes.at(x);++y) {
res+= " " ;
}
}
}
res+= "\n" ;
}
return res;
}
};
struct expand_macros_tag {
expand_macros& m;
expand_macros_tag(expand_macros& t_m, string name) : m(t_m) {
m.tag_stack.push_back(name);
}
~expand_macros_tag() {
m.tag_stack.pop_back();
}
};
struct expand_macros_scope {
expand_macros& m;
expand_macros_scope(expand_macros& t_m, string name, bool is_public=false) : m(t_m) {
m.begin_scope(name, is_public);
}
~expand_macros_scope() {
m.end_scope();
}
};
#define EXPAND_MACROS_SCOPE expand_macros_scope c_scope(m, __func__)
#define EXPAND_MACROS_SCOPE_PUBLIC expand_macros_scope c_scope(m, __func__, true)
struct reg_scalar {
static const bool is_spill=false;
int value=-1;
reg_scalar() {}
explicit reg_scalar(int i) : value(i) {}
string name(int num_bits=64) const {
assert(value>=0);
const vector<string>* names=nullptr;
if (num_bits==64) {
names=&scalar_register_names_64;
} else
if (num_bits==32) {
names=&scalar_register_names_32;
} else
if (num_bits==16) {
names=&scalar_register_names_16;
} else {
assert(num_bits==8);
names=&scalar_register_names_8;
}
if (value<names->size()) {
return names->at(value);
} else {
return str( "PSEUDO_#_#", value, num_bits );
}
}
typedef void bindable;
void bind_impl(expand_macros& m, string n) const {
m.bind_impl(n, name(64));
m.bind_impl(n + "_32", name(32));
m.bind_impl(n + "_16", name(16));
m.bind_impl(n + "_8", name(8));
}
};
const reg_scalar reg_rsp=reg_scalar(0);
const reg_scalar reg_rax=reg_scalar(1);
const reg_scalar reg_rdx=reg_scalar(2);
const reg_scalar reg_rcx=reg_scalar(3);
const reg_scalar reg_rbx=reg_scalar(4);
const reg_scalar reg_rbp=reg_scalar(5);
const reg_scalar reg_rsi=reg_scalar(6);
const reg_scalar reg_rdi=reg_scalar(7);
const reg_scalar reg_r8=reg_scalar(8);
const reg_scalar reg_r9=reg_scalar(9);
const reg_scalar reg_r10=reg_scalar(10);
const reg_scalar reg_r11=reg_scalar(11);
const reg_scalar reg_r12=reg_scalar(12);
const reg_scalar reg_r13=reg_scalar(13);
const reg_scalar reg_r14=reg_scalar(14);
const reg_scalar reg_r15=reg_scalar(15);
struct reg_vector {
static const bool is_spill=false;
int value=-1;
reg_vector() {}
explicit reg_vector(int i) : value(i) {}
string name(int num_bits=512) const {
assert(value>=0);
string prefix;
if (num_bits==512) {
prefix = "Z";
} else
if (num_bits==256) {
prefix = "Y";
} else {
assert(num_bits==128);
prefix = "X";
}
if (value>=32 || (!enable_all_instructions && (value>=16 || num_bits!=128))) {
prefix = "PSEUDO_" + prefix;
}
return str( "#MM#", prefix, value );
}
typedef void bindable;
void bind_impl(expand_macros& m, string n) const {
m.bind_impl(n, name(128));
m.bind_impl(n + "_512", name(512));
m.bind_impl(n + "_256", name(256));
m.bind_impl(n + "_128", name(128));
}
};
struct reg_spill {
static const bool is_spill=true;
int value=-1; //byte offset
int size=-1;
int alignment=-1; //power of 2, up to 64
reg_spill() {}
reg_spill(int t_value, int t_size, int t_alignment) : value(t_value), size(t_size), alignment(t_alignment) {}
int get_rsp_offset() const {
return value-spill_bytes;
}
//this is negative
uint64 get_rsp_offset_uint64() const {
return uint64(value-spill_bytes);
}
string name() const {
assert(value>=0 && size>=1 && alignment>=1);
assert(value%alignment==0);
assert(value+size<=spill_bytes);
return str( "[RSP+#]", to_hex(value-spill_bytes) );
}
typedef void bindable;
void bind_impl(expand_macros& m, string n) const {
m.bind_impl(n, name());
m.bind_impl(n + "_rsp_offset", to_hex(value-spill_bytes));
}
reg_spill operator+(int byte_offset) const {
reg_spill res=*this;
res.value+=byte_offset;
res.size-=byte_offset;
res.alignment=1;
return res;
}
};
struct reg_alloc {
vector<int> order_to_scalar;
vector<int> scalar_to_order;
set<int> scalars;
set<int> vectors;
vector<bool> spills;
reg_alloc() {}
void add(reg_scalar s) {
bool insert_res=scalars.insert(scalar_to_order.at(s.value)).second;
assert(insert_res);
}
void init() {
const int num=32; //defines how many pseudo-registers to have
order_to_scalar.resize(num, -1);
scalar_to_order.resize(num, -1);
int next_order=0;
auto add_scalar=[&](reg_scalar scalar_reg) {
int scalar=scalar_reg.value;
int order=next_order;
++next_order;
assert(order_to_scalar.at(order)==-1);
order_to_scalar.at(order)=scalar;
assert(scalar_to_order.at(scalar)==-1);
scalar_to_order.at(scalar)=order;
add(reg_scalar(scalar));
};
add_scalar(reg_rbx);
add_scalar(reg_rbp);
add_scalar(reg_rsi);
add_scalar(reg_rdi);
add_scalar(reg_r8);
add_scalar(reg_r9);
add_scalar(reg_r10);
add_scalar(reg_r11);
add_scalar(reg_r12);
add_scalar(reg_r13);
add_scalar(reg_r14);
add_scalar(reg_r15);
add_scalar(reg_rcx);
add_scalar(reg_rdx);
add_scalar(reg_rax);
for (int x=16;x<num;++x) {
reg_scalar r;
r.value=x;
add_scalar(r);
}
for (int x=0;x<num;++x) {
vectors.insert(x);
}
for (int x=0;x<spill_bytes;++x) {
spills.push_back(true);
}
}
reg_scalar get_scalar(reg_scalar t_reg=reg_scalar()) {
assert(!scalars.empty());
int res=(t_reg.value==-1)? *scalars.begin() : scalar_to_order.at(t_reg.value);
bool erase_res=scalars.erase(res);
assert(erase_res);
return reg_scalar(order_to_scalar.at(res));
}
reg_vector get_vector() {
assert(!vectors.empty());
int res=*vectors.begin();
bool erase_res=vectors.erase(res);
assert(erase_res);
return reg_vector(res);
}
reg_spill get_spill(int size=8, int alignment=-1) {
if (alignment==-1) {
alignment=size;
}
assert(alignment==1 || alignment==2 || alignment==4 || alignment==8 || alignment==16 || alignment==32 || alignment==64);
for (int x=0;x<spills.size();++x) {
if (x%alignment!=0) {
continue;
}
bool valid=true;
for (int y=0;y<size;++y) {
if (x+y>=spills.size() || !spills[x+y]) {
valid=false;
break;
}
}
if (valid) {
for (int y=0;y<size;++y) {
spills.at(x+y)=false;
}
reg_spill res;
res.value=x;
res.size=size;
res.alignment=alignment;
return res;
}
}
assert(false);
return reg_spill();
}
reg_scalar bind_scalar(expand_macros& m, string name, reg_scalar t_reg=reg_scalar()) {
reg_scalar res=get_scalar(t_reg);
m.bind(res, name);
return res;
}
reg_vector bind_vector(expand_macros& m, string name) {
reg_vector res=get_vector();
m.bind(res, name);
return res;
}
reg_spill bind_spill(expand_macros& m, string name, int size=8, int alignment=-1) {
reg_spill res=get_spill(size, alignment);
m.bind(res, name);
return res;
}
};
namespace asm_code {
expand_macros m;
#define APPEND_M(data) m.append(data, __LINE__, __FILE__, __func__)
}

View File

@ -0,0 +1,131 @@
namespace asm_code {
string vpermq_mask(array<int, 4> lanes) {
int res=0;
for (int x=0;x<4;++x) {
int lane=lanes[x];
assert(lane>=0 && lane<4);
res|=lane << (2*x);
}
return to_hex(res);
}
string vpblendd_mask_4(array<int, 4> lanes) {
int res=0;
for (int x=0;x<4;++x) {
int lane=lanes[x];
assert(lane>=0 && lane<2);
res|=((lane==1)? 3 : 0) << (2*x);
}
return to_hex(res);
}
string vpblendd_mask_8(array<int, 8> lanes) {
int res=0;
for (int x=0;x<8;++x) {
int lane=lanes[x];
assert(lane>=0 && lane<2);
res|=((lane==1)? 1 : 0) << x;
}
return to_hex(res);
}
struct asm_function {
string name;
//this excludes the argument regs (if any). can add them after they are done being used
reg_alloc regs;
vector<reg_scalar> args;
vector<reg_scalar> pop_regs;
const vector<reg_scalar> all_save_regs={reg_rbp, reg_rbx, reg_r12, reg_r13, reg_r14, reg_r15};
const vector<reg_scalar> all_arg_regs={reg_rdi, reg_rsi, reg_rdx, reg_rcx, reg_r8, reg_r9};
//the scratch area ends at RSP (i.e. the last byte is at address RSP-1)
//RSP is 64-byte aligned
//RSP must be preserved but all other registers can be changed
//
//the arguments are stored in: RDI, RSI, RDX, RCX, R8, R9
//each argument is up to 8 bytes
asm_function(string t_name, int num_args=0, int num_regs=15) {
EXPAND_MACROS_SCOPE;
static bool outputted_header=false;
if (!outputted_header) {
APPEND_M(str( ".intel_syntax noprefix" ));
outputted_header=true;
}
name=t_name;
#ifdef CHIAOSX
APPEND_M(str( ".global _asm_func_#", t_name ));
APPEND_M(str( "_asm_func_#:", t_name ));
#else
APPEND_M(str( ".global asm_func_#", t_name ));
APPEND_M(str( "asm_func_#:", t_name ));
#endif
assert(num_regs<=15);
regs.init();
for (int x=0;x<num_args;++x) {
reg_scalar r=all_arg_regs.at(x);
regs.get_scalar(r);
args.push_back(r);
}
//takes 6 cycles max if nothing else to do
int num_available_regs=15-all_save_regs.size();
for (reg_scalar s : all_save_regs) {
if (num_regs>num_available_regs) {
APPEND_M(str( "PUSH #", s.name() ));
pop_regs.push_back(s);
++num_available_regs;
} else {
regs.get_scalar(s);
}
}
assert(num_available_regs==num_regs);
// RSP'=RSP&(~63) ; this makes it 64-aligned and can only reduce its value
// RSP''=RSP'-64 ; still 64-aligned but now there is at least 64 bytes of unused stuff
// [RSP'']=RSP ; store old value in unused area
APPEND_M(str( "MOV RAX, RSP" ));
APPEND_M(str( "AND RSP, -64" )); //-64 equals ~63
APPEND_M(str( "SUB RSP, 64" ));
APPEND_M(str( "MOV [RSP], RAX" ));
}
//the return value is the error code (0 if no error). it is put in RAX
~asm_function() {
EXPAND_MACROS_SCOPE;
//default return value of 0
APPEND_M(str( "MOV RAX, 0" ));
string end_label=m.alloc_label();
APPEND_M(str( "#:", end_label ));
//this takes 4 cycles including ret, if there is nothing else to do
APPEND_M(str( "MOV RSP, [RSP]" ));
for (int x=pop_regs.size()-1;x>=0;--x) {
APPEND_M(str( "POP #", pop_regs[x].name() ));
}
APPEND_M(str( "RET" ));
while (m.next_output_error_label_id<m.next_error_label_id) {
APPEND_M(str( "label_error_#:", m.next_output_error_label_id ));
assert(m.next_output_error_label_id!=0);
APPEND_M(str( "MOV RAX, #", to_hex(m.next_output_error_label_id) ));
APPEND_M(str( "JMP #", end_label ));
++m.next_output_error_label_id;
}
}
};
}

View File

@ -0,0 +1,52 @@
/*uint64 funnel_shift(uint64 low, uint64 high, int start, int size) {
assert(start>=0 && size>0 && start+size<=128);
uint128 v=(uint128(high)<<64) | uint128(low);
v>>=start;
v&=~(uint128(1)<<size);
return uint64(v);
} */
constexpr uint64 extract_bits(uint64 t, int start, int size) {
assert(start>=0 && start<64);
assert(size>=0 && start+size<=64);
t >>= start;
t &= (1ull<<size)-1;
return t;
}
constexpr uint64 insert_bits(uint64 t, uint64 bits, int start, int size) {
assert(start>=0 && start<64);
assert(size>=0 && start+size<=64);
assert(
( bits & ~((1ull<<size)-1) )
==0
);
bits <<= start;
uint64 mask = ((1ull<<size)-1) << start;
t &= ~mask;
t |= bits;
return t;
}
void output_bits(ostream& out, uint64 bits, int size) {
assert(size>0 && size<64);
assert(
( bits & ~((1ull<<size)-1) )
==0
);
for (int x=size-1;x>=0;--x) {
bool v=bits&(1ull<<x);
out << (v? "1" : "0");
}
}
constexpr uint64 bit_sequence(int start, int size) {
return insert_bits(0, (1ull<<size)-1, start, size);
}

View File

@ -0,0 +1,2 @@
#!/bin/bash
cat /proc/cpuinfo | grep -w cmovf | grep -w -q avx

View File

@ -0,0 +1,38 @@
#include "include.h"
#include "parameters.h"
#define COMPILE_ASM
#ifdef TEST_ASM
#undef TEST_ASM
#endif
#include "bit_manipulation.h"
#include "double_utility.h"
#include "integer.h"
#include "gpu_integer.h"
#include "gpu_integer_divide.h"
#include "gcd_base_continued_fractions.h"
#include "gcd_base_divide_table.h"
#include "gcd_128.h"
#include "gcd_unsigned.h"
#include "asm_types.h"
#include "asm_vm.h"
#include "asm_base.h"
#include "asm_gcd_base_continued_fractions.h"
#include "asm_gcd_base_divide_table.h"
#include "asm_gcd_128.h"
#include "asm_gcd_unsigned.h"
#include "asm_main.h"
int main(int argc, char** argv) {
set_rounding_mode();
asm_code::compile_asm();
}

View File

@ -0,0 +1,2 @@
#!/bin/bash
cp *.c *.cpp *.h *.sh sconstruct ~/projects/chia_vdf_entry/entry/

View File

@ -0,0 +1,115 @@
struct double_bits {
static const int exponent_num_bits=11;
static const int fraction_num_bits=52;
bool sign=false;
int exponent=0; //11 bits; starting value is -1023
uint64 fraction=0; //52 bits
double_bits() {}
double_bits(double v) {
uint64 v_bits=*(uint64*)(&v);
sign=extract_bits(v_bits, 63, 1);
exponent=extract_bits(v_bits, 52, 11);
fraction=extract_bits(v_bits, 0, 52);
}
void set_exponent(int v) {
exponent=v+1023;
}
uint64 to_uint64() const {
uint64 v_bits=0;
v_bits=insert_bits(v_bits, sign, 63, 1);
v_bits=insert_bits(v_bits, exponent, 52, 11);
v_bits=insert_bits(v_bits, fraction, 0, 52);
return v_bits;
}
double to_double() const {
uint64 v_bits=to_uint64();
return *((double*)&v_bits);
}
void output(ostream& out, bool decimal=false) const {
out << (sign? "-" : "+");
if (exponent==0 && fraction==0) {
out << "0";
}
if (exponent==0b11111111111) {
out << ((fraction==0)? "INF" : "NAN");
}
if (decimal) {
uint64 v=fraction | (1ull<<52);
out << v << "*2^" << exponent-1023-52;
} else {
out << ((exponent==0)? "0b0" : "0b1");
output_bits(out, fraction, 52);
out << "*2^" << exponent-1023-52;
}
}
};
void set_rounding_mode() {
assert(fesetround(FE_TOWARDZERO)==0); //truncated rounding
}
double d_exp2(int i) {
double_bits d;
d.sign=false;
d.set_exponent(i); //bit shift and integer add (either order)
return d.to_double();
}
//the cpu has to handle values of i that are above 2^52-1 so the built in instruction is slower than doing it this way
//can make this add a shift easily
double double_from_int(uint64 i) {
assert(i<(1ull<<52));
//b>=1 && b<2
double_bits b;
b.set_exponent(0);
b.fraction=i;
//res_1>=0 && res_1<1
double res_1=b.to_double(); //1 bitwise or (for the exponent)
double res=fma(res_1, d_exp2(52), -d_exp2(52));
//double_bits res_b=res_1-1;
//res_b.exponent+=52; //can't overflow; 1 uint64 add without shifts. can also use a 32/16 bit add or a double multiply
//double res=res_b.to_double();
assert(res==i);
return res;
}
//can make this handle shifted doubles easily
uint64 int_from_double(double v, bool exact=true) {
if (exact) {
uint64 v_test=v;
assert(v_test==v);
assert(v_test<(1ull<<52));
}
double res_1=fma(v, d_exp2(-52), 1); //one fma
double_bits b(res_1);
uint64 res=b.fraction; //1 bitwise and (for exponent)
if (exact) {
assert(res==v);
}
return res;
}
uint64 make_uint64(uint32 high, uint32 low) {
return uint64(high)<<32 | uint32(low);
}
uint128 make_uint128(uint64 high, uint64 low) {
return uint128(high)<<64 | uint128(low);
}

View File

@ -0,0 +1,245 @@
bool gcd_128(
array<uint128, 2>& ab, array<array<uint64, 2>, 2>& uv_uint64, int& uv_uint64_parity, bool is_lehmer, uint128 ab_threshold=0
) {
static int test_asm_counter=0;
++test_asm_counter;
bool test_asm_run=true;
bool test_asm_print=false; //(test_asm_counter%1000==0);
bool debug_output=false;
if (debug_output) {
cerr.setf(ios::fixed, ios::floatfield);
//cerr.setf(ios::showpoint);
}
assert(ab[0]>=ab[1] && ab[1]>=0);
uv_uint64={
array<uint64,2>{1, 0},
array<uint64,2>{0, 1}
};
uv_uint64_parity=0;
array<uint128, 2> ab_start=ab;
bool progress=false;
int iter=0;
while (true) {
if (debug_output) print(
"======== 1:", iter,
uint64(ab[0]), uint64(ab[0]>>64), uint64(ab[1]), uint64(ab[1]>>64),
uint64(ab_threshold), uint64(ab_threshold>>64)
);
if (ab[1]<=ab_threshold) {
break;
}
assert(ab[0]>=ab[1] && ab[1]>=0);
int a_zeros=0;
//this uses CMOV
if ((ab[0]>>64)!=0) {
uint64 a_high(ab[0]>>64);
assert(a_high!=0);
a_zeros=__builtin_clzll(a_high);
} else {
uint64 a_low(ab[0]);
assert(a_low!=0);
a_zeros=64+__builtin_clzll(a_low);
}
int a_num_bits=128-a_zeros;
if (is_lehmer) {
const int min_bits=96;
if (a_num_bits<min_bits) {
a_num_bits=min_bits;
}
}
int shift_amount=a_num_bits-gcd_base_bits;
if (shift_amount<0) {
shift_amount=0;
}
if (debug_output) print( "2:", a_zeros, a_num_bits, shift_amount );
//print( " gcd_128", a_num_bits );
vector2 ab_double{
double(uint64(ab[0]>>shift_amount)),
double(uint64(ab[1]>>shift_amount))
};
double ab_threshold_double(uint64(ab_threshold>>shift_amount));
if (debug_output) print( "3:", ab_double[0], ab_double[1], ab_threshold_double, is_lehmer || (shift_amount!=0) );
vector2 ab_double_2=ab_double;
//this doesn't need to be exact
//all of the comparisons with threshold are >, so this shouldn't be required
//if (shift_amount!=0) {
// ++ab_threshold_double;
//}
//void gcd_64(vector2 start_a, pair<matrix2, vector2>& res, int& num_iterations, bool approximate, int max_iterations) {
//}
matrix2 uv_double;
if (!gcd_base_continued_fraction(ab_double, uv_double, is_lehmer || (shift_amount!=0), ab_threshold_double)) {
print( " gcd_128 break 1" ); //this is fine
break;
}
if (debug_output) print( "4:", uv_double[0][0], uv_double[1][0], uv_double[0][1], uv_double[1][1], ab_double[0], ab_double[1] );
if (0) {
matrix2 uv_double_2;
if (!gcd_base_continued_fraction_2(ab_double_2, uv_double_2, is_lehmer || (shift_amount!=0), ab_threshold_double)) {
print( " gcd_128 break 2" );
break;
}
assert(uv_double==uv_double_2);
assert(ab_double==ab_double_2);
}
array<array<uint64,2>,2> uv_double_int={
array<uint64,2>{uint64(abs(uv_double[0][0])), uint64(abs(uv_double[0][1]))},
array<uint64,2>{uint64(abs(uv_double[1][0])), uint64(abs(uv_double[1][1]))}
};
int uv_double_parity=(uv_double[1][1]<0)? 1 : 0; //sign bit
array<array<uint64, 2>, 2> uv_uint64_new;
if (iter==0) {
uv_uint64_new=uv_double_int;
} else {
if (!multiply_exact(uv_double_int, uv_uint64, uv_uint64_new)) {
print( " gcd_128 slow 1" ); //calculated a bunch of quotients and threw all of them away, which is bad
break;
}
}
int uv_uint64_parity_new=uv_uint64_parity^uv_double_parity;
bool even=(uv_uint64_parity_new==0);
if (debug_output) print(
"5:", uv_uint64_new[0][0], uv_uint64_new[1][0], uv_uint64_new[0][1], uv_uint64_new[1][1], uv_uint64_parity_new
);
uint64 uv_00=uv_uint64_new[0][0];
uint64 uv_01=uv_uint64_new[0][1];
uint64 uv_10=uv_uint64_new[1][0];
uint64 uv_11=uv_uint64_new[1][1];
uint128 a_new_1=ab_start[0]; a_new_1*=uv_00; //a_new_1.set_negative(!even);
uint128 a_new_2=ab_start[1]; a_new_2*=uv_01; //a_new_2.set_negative(even);
uint128 b_new_1=ab_start[1]; b_new_1*=uv_11; //b_new_1.set_negative(!even);
uint128 b_new_2=ab_start[0]; b_new_2*=uv_10; //b_new_2.set_negative(even);
//CMOV
//print( " gcd_128 even", even );
if (!even) {
swap(a_new_1, a_new_2);
swap(b_new_1, b_new_2);
}
uint128 a_new_s=a_new_1-a_new_2;
uint128 b_new_s=b_new_1-b_new_2;
//if this assert hit, one of the quotients is wrong. the base case is not supposed to return incorrect quotients
//assert(a_new_s>=b_new_s && b_new_s>=0);
//commenting this out because a and b can be 128 bits now
//if (!(a_new_s>=b_new_s && b_new_s>=0)) {
//print( " gcd_128 slow 2" );
//break;
//}
uint128 a_new(a_new_s);
uint128 b_new(b_new_s);
if (debug_output) print( "6:", uint64(a_new), uint64(a_new>>64), uint64(b_new), uint64(b_new>>64) );
if (is_lehmer) {
assert(a_new>=b_new);
uint128 ab_delta=a_new-b_new;
// even:
// +uv_00 -uv_01
// -uv_10 +uv_11
uint128 u_delta=uint128(uv_10)+uint128(uv_00); //even: negative. odd: positive
uint128 v_delta=uint128(uv_11)+uint128(uv_01); //even: positive. odd: negative
// uv_10 is negative if even, positive if odd
// uv_11 is positive if even, negative if odd
bool passed_even=(b_new>=uint128(uv_10) && ab_delta>=v_delta);
bool passed_odd=(b_new>=uint128(uv_11) && ab_delta>=u_delta);
if (debug_output) print( "7:", passed_even, passed_odd );
//CMOV
if (!(even? passed_even : passed_odd)) {
print( " gcd_128 slow 5" ); //throwing away a bunch of quotients because the last one is bad
break;
}
}
if (a_new<=ab_threshold) {
if (debug_output) print( "8:" );
print( " gcd_128 slow 6" ); //still throwing away quotients
break;
}
ab={a_new, b_new};
uv_uint64=uv_uint64_new;
uv_uint64_parity=uv_uint64_parity_new;
progress=true;
++iter;
if (iter>=gcd_128_max_iter) {
if (debug_output) print( "9:" );
break; //this is the only way to exit the loop without wasting quotients
}
//todo break;
}
#ifdef TEST_ASM
#ifndef GENERATE_ASM_TRACKING_DATA
if (test_asm_run) {
if (test_asm_print) {
print( "test asm gcd_128", test_asm_counter );
}
asm_code::asm_func_gcd_128_data asm_data;
asm_data.ab_start_0_0=uint64(ab_start[0]);
asm_data.ab_start_0_8=uint64(ab_start[0]>>64);
asm_data.ab_start_1_0=uint64(ab_start[1]);
asm_data.ab_start_1_8=uint64(ab_start[1]>>64);
asm_data.is_lehmer=uint64(is_lehmer);
asm_data.ab_threshold_0=uint64(ab_threshold);
asm_data.ab_threshold_8=uint64(ab_threshold>>64);
int error_code=asm_code::asm_func_gcd_128(&asm_data);
assert(error_code==0);
assert(asm_data.u_0==uv_uint64[0][0]);
assert(asm_data.u_1==uv_uint64[1][0]);
assert(asm_data.v_0==uv_uint64[0][1]);
assert(asm_data.v_1==uv_uint64[1][1]);
assert(asm_data.parity==uv_uint64_parity);
assert(asm_data.no_progress==int(!progress));
}
#endif
#endif
return progress;
}

View File

@ -0,0 +1,757 @@
typedef array<double, 2> vector2;
typedef array<vector2, 2> matrix2;
matrix2 identity_matrix() {
return {
vector2{1, 0},
vector2{0, 1}
};
}
matrix2 quotient_matrix(double q) {
assert(int64(q)==q);
return {
vector2{0, 1},
vector2{1, -q}
};
}
bool range_check(double v) {
//this is the smallest value where you can add 1 exactly
//if you add 2, you get the same value as if you added 1
//if two floats are added/subtracted and there is a loss of precision, the absolute value of the result will be greater than this
//same with multiplication and fma
//(all of the doubles are integers whether they are exact or not)
return abs(v)<=double((1ull<<53)-1);
}
bool dot_product_exact(vector2 a, vector2 b, double& v, bool result_always_in_range=false) {
v=a[0]*b[0];
if (!range_check(v)) {
return false;
}
if (enable_fma_in_c_code) {
v=fma(a[1], b[1], v);
} else {
double v2=a[1]*b[1];
if (!range_check(v2)) {
return false;
}
v+=v2;
}
if (result_always_in_range) {
//still need the first range_check since the intermediate value might not be in range
assert(range_check(v));
}
return range_check(v);
}
//result_always_in_range ignored
bool dot_product_exact(array<uint64,2> a, array<uint64,2> b, uint64& v, bool result_always_in_range=false) {
uint64 t1;
if (__builtin_mul_overflow(a[0], b[0], &t1)) {
return false;
}
uint64 t2;
if (__builtin_mul_overflow(a[1], b[1], &t2)) {
return false;
}
return !__builtin_add_overflow(t1, t2, &v);
}
template<class type> bool multiply_exact(
array<array<type,2>,2> a, array<type,2> b, array<type,2>& v, bool result_always_in_range=false) {
return
dot_product_exact(a[0], b, v[0], result_always_in_range) &&
dot_product_exact(a[1], b, v[1], result_always_in_range)
;
}
template<class type> bool multiply_exact(
array<array<type,2>,2> a, array<array<type,2>,2> b, array<array<type,2>,2>& v, bool result_always_in_range=false
) {
return
dot_product_exact(a[0], array<type,2>{b[0][0], b[1][0]}, v[0][0], result_always_in_range) &&
dot_product_exact(a[0], array<type,2>{b[0][1], b[1][1]}, v[0][1], result_always_in_range) &&
dot_product_exact(a[1], array<type,2>{b[0][0], b[1][0]}, v[1][0], result_always_in_range) &&
dot_product_exact(a[1], array<type,2>{b[0][1], b[1][1]}, v[1][1], result_always_in_range)
;
}
struct continued_fraction {
vector<int> values;
matrix2 get_matrix() {
matrix2 res=identity_matrix();
for (int i : values) {
bool is_exact=multiply_exact(quotient_matrix(i), res, res);
assert(is_exact);
}
return res;
}
bool truncate(double max_matrix_value) {
bool res=false;
while (true) {
matrix2 m=get_matrix();
double max_value=max(
max(abs(m[0][0]), abs(m[0][1])),
max(abs(m[1][0]), abs(m[1][1]))
);
if (max_value>max_matrix_value) {
assert(!values.empty());
values.pop_back();
res=true;
} else {
break;
}
}
return res;
}
bool is_superset_of(continued_fraction& targ) {
if (values.size()>targ.values.size()) {
return false;
}
for (int x=0;x<values.size();++x) {
if (values[x]!=targ.values[x]) {
return false;
}
}
return true;
}
//rounds to 0; need to add 1 ulp to the fraction to get the possible range
//if is_exact is true then the result is inside the continued fraction
double get_bound(bool parity, bool& is_exact) {
assert(!values.empty());
bool first=true;
mpq_class res=0;
mpq_class one=1;
for (int x=values.size()-1;x>=0;--x) {
assert(values[x]>=1);
if (first) {
//the denominator of each fraction is between 1 and infinity
//this is already canonicalized
res=values[x] + (parity? 1 : 0);
} else {
//mpq_class(values[x]) is already canonicalized
res=mpq_class(values[x]) + one/res;
}
first=false;
}
double res_double=res.get_d();
{
mpq_class res_double_mpq(res_double);
res_double_mpq.canonicalize();
is_exact=(res_double_mpq==res);
}
return res_double;
}
//everything inside the bound starts with this continued fraction
//something outside the bound might also start with this continued fraction
//>= first, < second
pair<double, double> get_bound() {
bool a_exact=false;
double a=get_bound(false, a_exact);
bool b_exact=false;
double b=get_bound(true, b_exact);
if (a>b) {
swap(a, b);
swap(a_exact, b_exact);
}
if (!a_exact) {
//if a isn't exact, the next double value after a is inside the continued fraction (since it got rounded down). this assumes
// the bound isn't so small that it is close to the double machine epsilon; this is checked later by the double_table code
//if a is exact then it is inside the continued fraction
a=nextafter(a, HUGE_VAL);
}
//if b isn't exact, then it got rounded down and the b value is inside the continued fraction. the next value after b will
// be outside the continued fraction
//if b is exact then it is also inside the continued fraction and the next value is outside
b=nextafter(b, HUGE_VAL);
return make_pair(a, b);
}
};
//if you add 1 to the integer representation of a positive double, it will increase the value by 1 machine epsilon (assuming no overflow)
template<class type> struct double_table {
vector<type> data; //data[x] is >= range_start+x*delta and < range_start+(x+1)*delta
int exponent_bits;
int fraction_bits;
int64 range_start=0;
int64 range_end=0;
int64 delta=0;
double range_start_double=0;
double range_end_double=0;
int right_shift_amount=0;
uint64 range_start_shifted=0;
uint64 range_end_shifted=0;
//min value is 1
double_table(int t_exponent_bits, int t_fraction_bits) {
exponent_bits=t_exponent_bits;
fraction_bits=t_fraction_bits;
assert(exponent_bits>=0);
assert(fraction_bits>=1);
double_bits range_start_bits;
range_start_bits.sign=false;
range_start_bits.set_exponent(0);
range_start_bits.fraction=0;
range_start=range_start_bits.to_uint64();
range_start_double=range_start_bits.to_double();
double_bits range_end_bits;
range_end_bits.sign=false;
range_end_bits.set_exponent(1<<exponent_bits);
range_end_bits.fraction=0;
range_end=range_end_bits.to_uint64();
range_end_double=range_end_bits.to_double();
double_bits delta_bits;
delta_bits.sign=false;
delta_bits.exponent=0;
delta_bits.fraction=1ull<<(double_bits::fraction_num_bits-fraction_bits);
delta=delta_bits.to_uint64();
assert(range_end>range_start);
assert(range_start%delta==0);
assert(range_end%delta==0);
assert((range_end-range_start)/delta==1ull<<(exponent_bits+fraction_bits));
data.resize(1ull<<(exponent_bits+fraction_bits));
right_shift_amount=double_bits::fraction_num_bits-fraction_bits;
range_start_shifted=uint64(range_start)>>right_shift_amount;
range_end_shifted=uint64(range_end)>>right_shift_amount;
}
pair<double, double> index_range(int x) {
int64 res_low=range_start+x*delta;
int64 res_high=range_start+(x+1)*delta;
return make_pair(*(double*)&res_low, *(double*)&res_high);
}
bool lookup(double v, type& res) {
assert(v>=1);
res=type();
uint64 v_bits=*(uint64*)&v;
uint64 v_bits_shifted=v_bits>>right_shift_amount;
assert(v_bits_shifted>=range_start_shifted); //since v>=1
if (v_bits_shifted<range_start_shifted || v_bits_shifted>=range_end_shifted) {
return false;
}
//the table doesn't work if v is exactly between two slots
//happens if the remainder is 0 for one of the quotients
if (
(v_bits & (delta-1)) == 0 ||
(v_bits & (delta-1)) == delta-1
) {
return false;
}
res=data.at(v_bits_shifted-range_start_shifted);
return true;
}
//will assign all entries >= range.first and < range.second
//returns true if the range is at least 0.5 entries wide (for that area of the table) and is within the table bounds
bool assign(pair<double, double> range, type value, vector<type>& old_values) {
old_values.clear();
double start_double=range.first;
double end_double=range.second;
assert(start_double>0 && end_double>0 && end_double>=start_double && isfinite(start_double) && isfinite(end_double));
if (end_double<range_start_double || start_double>range_end_double) {
return false;
}
int64 start_bits=*(int64*)&start_double;
int64 end_bits=*(int64*)&end_double;
if (end_bits<=start_bits || 2*(end_bits-start_bits)<delta) {
return false;
}
int64 start_pos=(start_bits-range_start)/delta;
int64 end_pos=(end_bits-range_start)/delta + 1;
assert(end_pos>=start_pos);
if (start_pos<0) {
start_pos=0;
}
if (end_pos>data.size()) {
end_pos=data.size();
}
for (uint64 pos=start_pos;pos<end_pos;++pos) {
pair<double, double> slot_range=index_range(pos);
//if start_double==slot_range.first, then both ranges have the same starting double so that's fine
//if end_double==slot_range.second, then both ranges have the same ending double which is also fine
if (start_double<=slot_range.first && end_double>=slot_range.second) {
old_values.push_back(data[pos]);
data[pos]=value;
}
}
return true;
}
};
bool add_to_table(double_table<continued_fraction>& c_table, continued_fraction f) {
vector<continued_fraction> old_values;
if (!c_table.assign(f.get_bound(), f, old_values)) {
return false;
}
for (continued_fraction& c : old_values) {
assert(c.is_superset_of(f));
}
return true;
}
void add_children_to_table(double_table<continued_fraction>& c_table, continued_fraction f) {
f.values.push_back(1);
while (true) {
if (!add_to_table(c_table, f)) {
break;
}
add_children_to_table(c_table, f);
assert(f.values.back()<INT_MAX);
++f.values.back();
}
}
double_table<continued_fraction> generate_table(
int exponent_bits, int fraction_bits, uint64 truncate_max_value=1ull<<53, bool output_stats=false, bool dump=false
) {
double_table<continued_fraction> c_table(exponent_bits, fraction_bits);
add_children_to_table(c_table, continued_fraction());
bool any_truncated=false;
for (continued_fraction& c : c_table.data) {
assert(double(truncate_max_value)==truncate_max_value);
any_truncated |= c.truncate(truncate_max_value);
}
//if the exponent has too many bits, some of the table entries will span multiple integers and won't have any entries
//all of the full entries are at the start of the table, and all of the empty entires are at the end. they aren't interleaved
//when setting up the table range checks, should truncate off all of the empty values so they won't affect cache coherency
int num_empty=0;
for (int x=0;x<c_table.data.size();++x) {
if (dump) {
cerr << c_table.index_range(x).first << ", " << c_table.index_range(x).second << " : ";
for (int i : c_table.data[x].values) {
cerr << i << ", ";
}
cerr << "\n";
}
bool is_empty=(c_table.data[x].values.empty());
if (is_empty) {
++num_empty;
} else {
//all of the empty values are supposed to be before the non-empty values
assert(num_empty==0);
}
}
assert(num_empty==0); //gcd algorithm won't check for this
if (output_stats) {
print( "non-empty:", c_table.data.size()-num_empty, "; empty:", num_empty );
if (any_truncated) {
print( "truncated" );
}
}
return c_table;
}
//initial uv is the identity matrix
//parity is the number of quotients mod 2
//
//if uv is unsigned:
//-the parity is the sign of uv[1][1] (1 if negative)
//-to calculate the next uv, just multiply the unsigned matricies together. also add the parities modulo 2
//-to calculate ab from the starting ab, do a subtraction in the dot product instead of adding, then take the absolute value of the
// result. can also use the parity to decide what way to do the subtraction.
// - odd parity: b-a, a-b
// -even parity: a-b, b-a
// -can calculate assuming even parity. then sign extend the parity to 64 bits (from 1 bit) and use the parity as the carry in,
// then xor the result by the sign extended parity and add the carry. this can also determine the parity if it is unknown
//
// odd parity uv: { <=0 > 0
// > 0 < 0}
// even parity uv: { >=0 <=0s
// <=0 > 0}
//if this returns false then the new values are invalid and the old values are valid
//this works if u/v are unsigned, if v[1]-v[0] is replaced with |v[1]|+|v[0]| and -u[1] is replaced with |u1| etc
bool check_lehmer(array<int64, 2> a, array<int64, 2> u, array<int64, 2> v) {
// a[0]-a[1] is always >= 0 ; also a[1]>=0
// odd parity ; u[0]<=0 ; u[1]> 0 ; v[0]> 0 ; v[1]< 0
// even parity ; u[0]>=0 ; u[1]<=0 ; v[0]<=0 ; v[1]> 0
return
a[1]>=-u[1] && int128(a[0])-int128(a[1]) >= int128(v[1])-int128(v[0]) && // even parity
a[1]>=-v[1] && int128(a[0])-int128(a[1]) >= int128(u[1])-int128(u[0]) // odd parity
;
}
bool gcd_base_continued_fraction(vector2& ab, matrix2& uv, bool is_lehmer, double ab_threshold=0) {
static double_table<continued_fraction> c_table=generate_table(gcd_table_num_exponent_bits, gcd_table_num_fraction_bits);
static int test_asm_counter=0;
++test_asm_counter;
bool test_asm_run=true;
bool test_asm_print=false; //(test_asm_counter%1000==0);
bool debug_output=false;
assert(ab[0]>=ab[1] && ab[1]>=0);
uv=identity_matrix();
auto ab_start=ab;
bool progress=false;
bool enable_table=true;
int iter=0;
int iter_table=0;
int iter_slow=0;
if (debug_output) {
cerr.setf(ios::fixed, ios::floatfield);
//cerr.setf(ios::showpoint);
}
while (true) {
if (debug_output) print( "======== 1:", iter, ab[1], ab_threshold);
if (ab[1]<=ab_threshold) {
if (debug_output) print( "1.5:" );
break;
}
//print( " gcd_base", uint64(ab[0]) );
assert(ab[0]>=ab[1] && ab[1]>=0);
double q=ab[0]/ab[1];
if (debug_output) print( "2:", q );
vector2 new_ab;
matrix2 new_uv;
bool used_table=false;
continued_fraction f;
if (enable_table && c_table.lookup(q, f)) {
assert(!f.values.empty()); //table should be set up not to have empty values
if (debug_output) print( "3:", f.get_matrix()[0][0], f.get_matrix()[1][0], f.get_matrix()[0][1], f.get_matrix()[1][1] );
bool new_ab_valid=multiply_exact(f.get_matrix(), ab, new_ab, true); //a and b can only be reduced in magnitude
bool new_uv_valid=multiply_exact(f.get_matrix(), uv, new_uv);
bool new_a_valid=(new_ab[0]>ab_threshold);
if (debug_output) print( "4:", new_ab_valid, new_uv_valid, new_a_valid );
if (debug_output) print( "5:", new_ab[0], new_ab[1], new_uv[0][0], new_uv[1][0], new_uv[0][1], new_uv[1][1] );
if (new_ab_valid && new_uv_valid && new_a_valid) {
used_table=true;
++iter_table;
} else {
//this should be disabled to make the output the same as the non-table version
//this is disabled in the asm version
//if (is_lehmer && ab_threshold==0) {
//can also bypass the table but it is probably slower
//if ab_threshold is not 0, need to keep going since the partial gcd is about to terminate
//break;
//}
}
}
if (!used_table) {
//the native instruction is as fast as adding then subtracting some magic number
q=floor(q);
++iter_slow;
if (debug_output) print( "6:", q );
matrix2 m=quotient_matrix(q);
bool new_ab_valid=multiply_exact(m, ab, new_ab, true);
bool new_uv_valid=multiply_exact(m, uv, new_uv);
if (debug_output) print( "6.5:", new_ab[0], new_ab[1], new_uv[0][0], new_uv[1][0], new_uv[0][1], new_uv[1][1] );
if (!new_ab_valid || !new_uv_valid) {
if (debug_output) print( "7:" );
break;
}
//double new_b=fma(-q, ab[1], ab[0]);
//double new_u;
//double new_v;
//iter 0 is unrolled separately
//can probably just unroll all 6 iterations
//if (iter==0) {
//new_u=1;
//new_v=-q;
//} else {//}
//new_u=fma(-q, uv[1][0], uv[0][0]);
//new_v=fma(-q, uv[1][1], uv[0][1]);
//if (debug_output) print( "6:", q, new_b, new_u, new_v );
//if (!range_check(new_u) || !range_check(new_v)) {
//if (debug_output) print( "7" );
//break;
//}
//assert(range_check(new_b)); //a and b can only be reduced in magnitude
//new_ab={ab[1], new_b};
//new_uv={
//vector2{uv[1][0], uv[1][1]},
//vector2{ new_u, new_v}
//};
}
//this has to be checked on the first iteration if the table is not used (since there could be a giant quotient e.g. a=b)
//will check it even if the table is used. shouldn't affect performance
if (is_lehmer) {
double ab_delta=new_ab[0]-new_ab[1];
assert(range_check(ab_delta)); //both are nonnegative so the subtraction can't increase the magnitude
assert(ab_delta>=0); //ab[0] has to be greater
//the magnitudes add for these
//however, the comparison is ab_delta >= u_delta or v_delta, and ab_delta>=0, so the values of u_delta and v_delta can
// be increased. if the calculation is not exact, the values will be ceil'ed so they are exact or increased; never reduced
//double u_delta=uv[1][0]-uv[0][0];
//double v_delta=uv[1][1]-uv[0][1];
//even parity:
//don't care what the result of the odd comparison is as far as correctness goes. for performance, it has to be true most
// of the time
// uv[0][1]<=0 ; uv[1][1]>=0
//ab_delta+uv[0][1] is exact because the signs are opposite
//ab_delta+uv[0][0] is <= the true value so the comparison might return false wrongly. should be fine
bool even=(new_uv[1][1]>=0);
if (even) {
assert(range_check(ab_delta+new_uv[0][1]));
} else {
assert(range_check(ab_delta+new_uv[0][0]));
}
bool passed=
new_ab[1]>=-new_uv[1][0] && ab_delta+new_uv[0][1]>=new_uv[1][1] && // even parity. for odd parity this is always true
new_ab[1]>=-new_uv[1][1] && ab_delta+new_uv[0][0]>=new_uv[1][0] // odd parity. for even parity this is always true
;
if (debug_output) print( "8:", new_ab[1], new_uv[1][0], ab_delta, new_uv[0][1], new_uv[1][1] );
if (debug_output) print( "9:", new_ab[1], new_uv[1][1], ab_delta, new_uv[0][0], new_uv[1][0] );
if (debug_output) print( "10:", passed );
if (!passed) {
if (debug_output) print( "11:" );
if (enable_table) {
//this will make the table not change the output of the algorithm
//can just do a break in the actual code
//enable_table=false; continue;
break;
} else {
break;
}
}
}
ab=new_ab;
uv=new_uv;
progress=true;
++iter;
//print( " gcd_base quotient", q );
//print( "foo" );
{
//this would overflow a double; it works with modular arithmetic
int64 a_expected=int64(uv[0][0])*int64(ab_start[0]) + int64(uv[0][1])*int64(ab_start[1]);
int64 b_expected=int64(uv[1][0])*int64(ab_start[0]) + int64(uv[1][1])*int64(ab_start[1]);
assert(int64(ab[0])==a_expected);
assert(int64(ab[1])==b_expected);
}
if (iter>=gcd_base_max_iter) {
break;
}
//todo break;
}
//print( " gcd_base", iter_table+iter_slow, iter_table, iter_slow );
#ifdef TEST_ASM
#ifndef GENERATE_ASM_TRACKING_DATA
if (test_asm_run) {
if (test_asm_print) {
print( "test asm gcd_base", test_asm_counter );
}
double asm_ab[]={ab_start[0], ab_start[1]};
double asm_u[2];
double asm_v[2];
uint64 asm_is_lehmer[2]={(is_lehmer)? ~0ull : 0ull, (is_lehmer)? ~0ull : 0ull};
double asm_ab_threshold[2]={ab_threshold, ab_threshold};
uint64 asm_no_progress;
int error_code=asm_code::asm_func_gcd_base(asm_ab, asm_u, asm_v, asm_is_lehmer, asm_ab_threshold, &asm_no_progress);
assert(error_code==0);
assert(asm_ab[0]==ab[0]);
assert(asm_ab[1]==ab[1]);
assert(asm_u[0]==uv[0][0]);
assert(asm_u[1]==uv[1][0]);
assert(asm_v[0]==uv[0][1]);
assert(asm_v[1]==uv[1][1]);
assert(asm_no_progress==int(!progress));
}
#endif
#endif
return progress;
}
bool gcd_base_continued_fraction_2(vector2& ab_double, matrix2& uv_double, bool is_lehmer, double ab_threshold_double=0) {
int64 a_int=int64(ab_double[0]);
int64 b_int=int64(ab_double[1]);
int64 threshold_int=int64(ab_threshold_double);
assert(a_int>b_int && b_int>0);
array<int64, 2> ab={a_int, b_int};
array<int64, 2> u={1, 0};
array<int64, 2> v={0, 1};
auto apply=[&](int64 q, array<int64, 2> x) -> array<int64, 2> {
return {
x[1],
x[0]-q*x[1]
};
};
vector<uint64> res;
int num_iter=0;
int num_quotients=0;
while (ab[1]>threshold_int) {
//print( " gcd_base_2", ab[0] );
int64 q=ab[0]/ab[1];
assert(q>=0);
array<int64, 2> new_ab=apply(q, ab);
array<int64, 2> new_u=apply(q, u);
array<int64, 2> new_v=apply(q, v);
++num_iter;
if (is_lehmer && !check_lehmer(new_ab, new_u, new_v)) {
break;
}
//print(num_iter, u[0], u[1], v[0], v[1]);
auto ab_double_new=ab_double;
auto uv_double_new=uv_double;
ab_double_new[0]=double(new_ab[0]);
ab_double_new[1]=double(new_ab[1]);
uv_double_new[0][0]=double(new_u[0]);
uv_double_new[0][1]=double(new_v[0]);
uv_double_new[1][0]=double(new_u[1]);
uv_double_new[1][1]=double(new_v[1]);
if (
int64(ab_double_new[0])!=new_ab[0] ||
int64(ab_double_new[1])!=new_ab[1] ||
int64(uv_double_new[0][0])!=new_u[0] ||
int64(uv_double_new[0][1])!=new_v[0] ||
int64(uv_double_new[1][0])!=new_u[1] ||
int64(uv_double_new[1][1])!=new_v[1]
) {
break;
}
ab=new_ab;
u=new_u;
v=new_v;
ab_double=ab_double_new;
uv_double=uv_double_new;
//print( " gcd_base_2 quotient", q );
res.push_back(q);
++num_quotients;
//todo break;
}
return num_quotients!=0;
}

View File

@ -0,0 +1,232 @@
const uint64 data_mask=bit_sequence(0, data_size);
const int carry_size=64-data_size;
const uint64 carry_mask=bit_sequence(data_size, carry_size);
namespace simd_integer_namespace {
int64 abs_int(int64 v) {
return (v<0)? -v : v;
}
int divide_table_stats_calls=0;
int divide_table_stats_table=0;
//generic_stats gcd_64_num_iterations;
//used for both gcd and reduce
int64 divide_table_lookup(int64 index) {
assert(index>=0 && index<=bit_sequence(0, divide_table_index_bits));
uint128 res = (~uint128(0)) / uint128(max(uint64(index), uint64(1)));
res>>=64;
return res;
}
int64 divide_table_64(int64 a, int64 b, int64& q) {
assert(b>0);
q=a/b;
int64 r=a%b;
if (r<0) {
r+=b;
--q;
}
assert(r>=0 && r<b && q*b+r==a);
return r;
}
//note: this floors the quotient instead of truncating it like the div instruction
int64 divide_table(int64 a, int64 b, int64& q) {
const bool test_asm_funcs=false;
++divide_table_stats_calls;
assert(b>0);
//b_shift=(64-divide_table_index_bits) - lzcnt(b)
//bsr(b)=63-lzcnt(b)
//63-bsr(b)=lzcnt(b)
//b_shift=(64-divide_table_index_bits) - 63-bsr(b)
//b_shift=64-divide_table_index_bits - 63 + bsr(b)
//b_shift=1-divide_table_index_bits + bsr(b)
//b_shift=bsr(b) - (divide_table_index_bits-1)
int b_shift = (64-divide_table_index_bits) - __builtin_clzll(b);
if (b_shift<0) { //almost never happens
b_shift=0;
}
int64 b_approx = b >> b_shift;
int64 b_approx_inverse = divide_table_lookup(b_approx);
q = (int128(a)*int128(b_approx_inverse)) >> 64; //high part of product
q >>= b_shift;
int128 qb_128=int128(q)*int128(b);
int64 qb_64=int64(qb_128);
int128 r_128=int128(a)-int128(qb_64);
int64 r_64=int64(r_128);
//int128 r=int128(a)-int128(q)*int128(b);
//if (uint128(r)>=b) {
bool invalid_1=(int128(qb_64)!=qb_128 || int128(r_64)!=r_128 || uint64(r_64)>=b);
int128 r_2=int128(a)-int128(q)*int128(b);
bool invalid_2=(uint128(r_2)>=b);
assert(invalid_1==invalid_2);
if (!invalid_2) {
assert(r_64==int64(r_2));
}
int64 r=r_2;
if (invalid_2) {
r=divide_table_64(a, b, q);
} else {
++divide_table_stats_table;
}
int64 q_expected;
int64 r_expected=divide_table_64(a, b, q_expected);
assert(q==q_expected);
assert(r==r_expected);
//if (test_asm_funcs) {
//int64 q_asm;
//int64 r_asm=divide_table_asm(a, b, q_asm);
//assert(q_asm==q_expected);
//assert(r_asm==r_expected);
//}
return r;
}
void gcd_64(
array<int64, 2> start_a, pair<array<int64, 4>, array<int64, 2>>& res, int& num_iterations, bool approximate, int max_iterations
) {
const bool test_asm_funcs=false;
array<int64, 4> uv={1, 0, 0, 1};
array<int64, 2> a=start_a;
num_iterations=0;
if (approximate && (start_a[0]==start_a[1] || start_a[1]==0)) {
res=make_pair(uv, a);
return;
}
int asm_num_iterations=0;
array<int64, 4> uv_asm=uv;
array<int64, 2> a_asm=a;
while (true) {
if (test_asm_funcs) {
//if (gcd_64_iteration_asm(a_asm, uv_asm, approximate)) {
//++asm_num_iterations;
//}
}
if (a[1]==0) {
break;
}
assert(a[0]>a[1] && a[1]>0);
int64 q;
int64 r=divide_table(a[0], a[1], q);
{
int shift_amount=63-gcd_num_quotient_bits;
if ((q<<shift_amount)>>shift_amount!=q) {
break;
}
}
array<int64, 2> new_a={a[1], r};
array<int64, 4> new_uv;
for (int x=0;x<2;++x) {
new_uv[0*2+x]=uv[1*2+x];
new_uv[1*2+x]=uv[0*2+x] - q*uv[1*2+x];
}
bool valid=true;
if (approximate) {
assert(new_uv[1*2+0]!=0);
bool is_even=(new_uv[1*2+0]<0);
bool valid_exact;
if (is_even) {
valid_exact=(new_a[1]>=-new_uv[1*2+0] && new_a[0]-new_a[1]>=new_uv[1*2+1]-new_uv[0*2+1]);
} else {
valid_exact=(new_a[1]>=-new_uv[1*2+1] && new_a[0]-new_a[1]>=new_uv[1*2+0]-new_uv[0*2+0]);
}
//valid=valid_exact;
valid=
(new_a[1]>=-new_uv[1*2+0] && new_a[0]-new_a[1]>=new_uv[1*2+1]-new_uv[0*2+1]) &&
(new_a[1]>=-new_uv[1*2+1] && new_a[0]-new_a[1]>=new_uv[1*2+0]-new_uv[0*2+0])
;
assert(valid==valid_exact);
if (valid) {
assert(valid_exact);
}
}
//have to do this even if approximate is false
for (int x=0;x<4;++x) {
if (abs_int(new_uv[x])>data_mask) {
valid=false;
}
}
if (!valid) {
break;
}
uv=new_uv;
a=new_a;
++num_iterations;
if (test_asm_funcs) {
assert(uv==uv_asm);
assert(a==a_asm);
assert(num_iterations==asm_num_iterations);
}
if (num_iterations>=max_iterations) {
break;
}
}
//gcd_64_num_iterations.add(num_iterations);
for (int x=0;x<4;++x) {
assert(abs_int(uv[x])<=data_mask);
}
if (test_asm_funcs) {
assert(uv==uv_asm);
//assert(a==a_asm); the asm code will update a even if it becomes invalid; fine since it's not used
assert(num_iterations==asm_num_iterations);
}
res=make_pair(uv, a);
}
}

View File

@ -0,0 +1,345 @@
//threshold is 0 to calculate the normal gcd
template<int size> void gcd_unsigned_slow(
array<fixed_integer<uint64, size>, 2>& ab,
array<fixed_integer<uint64, size>, 2>& uv,
int& parity,
fixed_integer<uint64, size> threshold=fixed_integer<uint64, size>(integer(0))
) {
assert(ab[0]>threshold);
while (ab[1]>threshold) {
fixed_integer<uint64, size> q(ab[0]/ab[1]);
fixed_integer<uint64, size> r(ab[0]%ab[1]);
ab[0]=ab[1];
ab[1]=r;
//this is the absolute value of the cofactor matrix
auto u1_new=uv[0] + q*uv[1];
uv[0]=uv[1];
uv[1]=u1_new;
parity=-parity;
}
}
//todo
//test this by making two numbers that have a specified quotient sequence. can add big quotients then
//to generate numbers with a certain quotient sequence:
//euclidean algorithm: q=a/b ; a'=b ; b'=a-q*b ; terminates when b'=0
//initially b'=0 and all qs are known
//first iteration: b'=a-q*b=0 ; a=q*b ; select some b and this will determine a
//next: b'=a-q*b ; a'=b ; b'=a-q*a' ; b'+q*a'=a
//uv is <1,0> to calculate |u| and <0,1> to calculate |v|
//parity is negated for each quotient
template<int size> void gcd_unsigned(
array<fixed_integer<uint64, size>, 2>& ab,
array<fixed_integer<uint64, size>, 2>& uv,
int& parity,
fixed_integer<uint64, size> threshold=fixed_integer<uint64, size>(integer(0))
) {
typedef fixed_integer<uint64, size> int_t;
static int test_asm_counter=0;
++test_asm_counter;
bool test_asm_run=true;
bool test_asm_print=(test_asm_counter%1000==0);
bool debug_output=false;
assert(ab[0]>=ab[1] && !ab[1].is_negative());
assert(!ab[0].is_negative() && !ab[1].is_negative());
assert(!uv[0].is_negative() && !uv[1].is_negative());
auto ab_start=ab;
auto uv_start=uv;
int parity_start=parity;
int a_num_bits_old=-1;
int iter=0;
vector<array<array<uint64, 2>, 2>> matricies;
vector<int> local_parities;
bool valid=true;
while (true) {
assert(ab[0]>=ab[1] && !ab[1].is_negative());
if (debug_output) {
print( "" );
print( "" );
print( "====================================" );
for (int x=0;x<size;++x) print( "a limb", x, ab[0][x] );
print( "" );
for (int x=0;x<size;++x) print( "b limb", x, ab[1][x] );
print( "" );
for (int x=0;x<size;++x) print( "threshold limb", x, threshold[x] );
print( "" );
}
if (ab[0]<=threshold) {
valid=false;
print( " gcd_unsigned slow 1" );
break;
}
if (ab[1]<=threshold) {
if (debug_output) print( "ab[1]<=threshold" );
break;
}
//there is a cached num limbs for a. the num limbs for b and ab_threshold is smaller
//to calculate the new cached num limbs:
//-look at previous value. if limb is 0, go on to the next lowest limb. a cannot be 0 but should still tolerate this without crashing
//-unroll this two times
//-if more than 2 iterations are required, use a loop
//-a can only decrease in size so its true size can't be larger
//-this also calculates the head limb of a. need the top 3 head limbs. they are 0-padded if a is less than 3 nonzero limbs
//-the 3 head limbs are used to do the shift
//-this also truncates threshold and compares a[1] with the truncated value. it will exit if they are equal. this is not
// exactly the same as the c++ code
//-should probably implement this in c++ first then to make the two codes the same
int a_num_bits=ab[0].num_bits();
int shift_amount=a_num_bits-128; //todo //changed this to 128 bits
if (shift_amount<0) {
shift_amount=0;
}
//print( "gcd_unsigned", a_num_bits, a_num_bits_old-a_num_bits );
a_num_bits_old=a_num_bits;
array<uint128, 2> ab_head={
uint128(ab[0].window(shift_amount)) | (uint128(ab[0].window(shift_amount+64))<<64),
uint128(ab[1].window(shift_amount)) | (uint128(ab[1].window(shift_amount+64))<<64)
};
//assert((ab_head[0]>>127)==0);
//assert((ab_head[1]>>127)==0);
uint128 threshold_head=uint128(threshold.window(shift_amount)) | (uint128(threshold.window(shift_amount+64))<<64);
//assert((threshold_head>>127)==0);
//don't actually need to do this
//it will compare threshold_head with > so it will already exit if they are equal
//if (shift_amount!=0) {
// ++threshold_head;
//}
if (debug_output) print( "a_num_bits:", a_num_bits );
if (debug_output) print( "a last index:", (a_num_bits+63/64)-1 );
if (debug_output) print( "shift_amount:", shift_amount );
if (debug_output) print( "ab_head[0]:", uint64(ab_head[0]), uint64(ab_head[0]>>64) );
if (debug_output) print( "ab_head[1]:", uint64(ab_head[1]), uint64(ab_head[1]>>64) );
if (debug_output) print( "threshold_head:", uint64(threshold_head), uint64(threshold_head>>64) );
array<array<uint64, 2>, 2> uv_uint64;
int local_parity; //1 if odd, 0 if even
if (gcd_128(ab_head, uv_uint64, local_parity, shift_amount!=0, threshold_head)) {
//int local_parity=(uv_double[1][1]<0)? 1 : 0; //sign bit
bool even=(local_parity==0);
if (debug_output) print( "u:", uv_uint64[0][0], uv_uint64[1][0] );
if (debug_output) print( "v:", uv_uint64[0][1], uv_uint64[1][1] );
if (debug_output) print( "local parity:", local_parity );
uint64 uv_00=uv_uint64[0][0];
uint64 uv_01=uv_uint64[0][1];
uint64 uv_10=uv_uint64[1][0];
uint64 uv_11=uv_uint64[1][1];
//can use a_num_bits to make these smaller. this is at most a 2x speedup for these mutliplications which probably doesn't matter
//can do this with an unsigned subtraction and just swap the pointers
//
//this is an unsigned subtraction with the input pointers swapped to make the result nonnegative
//
//this uses mulx/adox/adcx if available for the multiplication
//will unroll the multiplication loop but early-exit based on the number of limbs in a (calculated before). this gives each
//branch its own branch predictor entry. each branch is at a multiple of 4 limbs. don't need to pad a
int_t a_new_1=ab[0]; a_new_1*=uv_00; a_new_1.set_negative(!even);
int_t a_new_2=ab[1]; a_new_2*=uv_01; a_new_2.set_negative(even);
int_t b_new_1=ab[0]; b_new_1*=uv_10; b_new_1.set_negative(even);
int_t b_new_2=ab[1]; b_new_2*=uv_11; b_new_2.set_negative(!even);
//both of these are subtractions; the signs determine the direction. the result is nonnegative
int_t a_new;
int_t b_new;
if (!even) {
a_new=int_t(a_new_2 + a_new_1);
b_new=int_t(b_new_1 + b_new_2);
} else {
a_new=int_t(a_new_1 + a_new_2);
b_new=int_t(b_new_2 + b_new_1);
}
//this allows the add function to be optimized
assert(!a_new.is_negative());
assert(!b_new.is_negative());
//do not do any of this stuff; instead return an array of matricies
//the array is processed while it is being generated so it is cache line aligned, has a counter, etc
ab[0]=a_new;
ab[1]=b_new;
//bx and by are nonnegative
auto dot=[&](uint64 ax, uint64 ay, int_t bx, int_t by) -> int_t {
bx*=ax;
by*=ay;
return int_t(bx+by);
};
int_t new_uv_0=dot(uv_00, uv_01, uv[0], uv[1]);
int_t new_uv_1=dot(uv_10, uv_11, uv[0], uv[1]);
uv[0]=new_uv_0;
uv[1]=new_uv_1;
//local_parity is 0 even, 1 odd
//want 1 even, -1 odd
//todo: don't do this; just make it 0 even, 1 odd
parity*=1-local_parity-local_parity;
matricies.push_back(uv_uint64);
local_parities.push_back(local_parity);
} else {
//can just make the gcd fail if this happens in the asm code
print( " gcd_unsigned slow" );
//todo assert(false); //very unlikely to happen if there are no bugs
valid=false;
break;
/*had_slow=true;
fixed_integer<uint64, size> q(ab[0]/ab[1]);
fixed_integer<uint64, size> r(ab[0]%ab[1]);
ab[0]=ab[1];
ab[1]=r;
//this is the absolute value of the cofactor matrix
auto u1_new=uv[0] + q*uv[1];
uv[0]=uv[1];
uv[1]=u1_new;
parity=-parity;*/
}
++iter;
}
{
auto ab2=ab_start;
auto uv2=uv_start;
int parity2=parity_start;
gcd_unsigned_slow(ab2, uv2, parity2, threshold);
if (valid) {
assert(integer(ab[0]) == integer(ab2[0]));
assert(integer(ab[1]) == integer(ab2[1]));
assert(integer(uv[0]) == integer(uv2[0]));
assert(integer(uv[1]) == integer(uv2[1]));
assert(parity==parity2);
} else {
ab=ab2;
uv=uv2;
parity=parity2;
}
}
#ifdef TEST_ASM
if (test_asm_run) {
if (test_asm_print) {
print( "test asm gcd_unsigned", test_asm_counter );
}
asm_code::asm_func_gcd_unsigned_data asm_data;
const int asm_size=gcd_size;
const int asm_max_iter=gcd_max_iterations;
assert(size>=1 && size<=asm_size);
fixed_integer<uint64, asm_size> asm_a(ab_start[0]);
fixed_integer<uint64, asm_size> asm_b(ab_start[1]);
fixed_integer<uint64, asm_size> asm_a_2;
fixed_integer<uint64, asm_size> asm_b_2;
fixed_integer<uint64, asm_size> asm_threshold(threshold);
uint64 asm_uv_counter_start=1234;
uint64 asm_uv_counter=asm_uv_counter_start;
array<array<uint64, 8>, asm_max_iter+1> asm_uv;
asm_data.a=&asm_a[0];
asm_data.b=&asm_b[0];
asm_data.a_2=&asm_a_2[0];
asm_data.b_2=&asm_b_2[0];
asm_data.threshold=&asm_threshold[0];
asm_data.uv_counter_start=asm_uv_counter_start;
asm_data.out_uv_counter_addr=&asm_uv_counter;
asm_data.out_uv_addr=(uint64*)&asm_uv[1];
asm_data.iter=-2; //uninitialized
asm_data.a_end_index=size-1;
int error_code=asm_code::asm_func_gcd_unsigned(&asm_data);
auto asm_get_uv=[&](int i) {
array<array<uint64, 2>, 2> res;
res[0][0]=asm_uv[i+1][0];
res[1][0]=asm_uv[i+1][1];
res[0][1]=asm_uv[i+1][2];
res[1][1]=asm_uv[i+1][3];
return res;
};
auto asm_get_parity=[&](int i) {
uint64 r=asm_uv[i+1][4];
assert(r==0 || r==1);
return bool(r);
};
auto asm_get_exit_flag=[&](int i) {
uint64 r=asm_uv[i+1][5];
assert(r==0 || r==1);
return bool(r);
};
if (error_code==0) {
assert(valid);
assert(asm_data.iter>=0 && asm_data.iter<=asm_max_iter); //total number of iterations performed
bool is_even=((asm_data.iter-1)&1)==0; //parity of last iteration (can be -1)
fixed_integer<uint64, asm_size>& asm_a_res=(is_even)? asm_a_2 : asm_a;
fixed_integer<uint64, asm_size>& asm_b_res=(is_even)? asm_b_2 : asm_b;
assert(integer(asm_a_res) == integer(ab[0]));
assert(integer(asm_b_res) == integer(ab[1]));
for (int x=0;x<=matricies.size();++x) {
assert( asm_get_exit_flag(x-1) == (x==matricies.size()) );
if (x!=matricies.size()) {
assert(asm_get_parity(x)==local_parities[x]);
assert(asm_get_uv(x)==matricies[x]);
}
}
assert(matricies.size()==asm_data.iter);
assert(asm_uv_counter==asm_uv_counter_start+asm_data.iter-1); //the last iteration that updated the counter is iter-1
} else {
if (!valid) {
print( "test asm gcd_unsigned error", error_code );
}
}
}
#endif
assert(integer(ab[0])>integer(threshold));
assert(integer(ab[1])<=integer(threshold));
}

View File

@ -0,0 +1,252 @@
#include "generic_macros.h"
#include <fstream>
#ifndef ILYA_SHARED_HEADER_GENERIC
#define ILYA_SHARED_HEADER_GENERIC
namespace generic {
using namespace std;
template<class type_a> void print_impl(ostream& out, const type_a& a) {}
template<class type_b> void print_impl(ostream& out, const char* a, const type_b& b) {
out << " " << b;
}
template<class type_a, class type_b> void print_impl(ostream& out, const type_a& a, const type_b& b) {
out << ", " << b;
}
template<class type_a, class type_b, class... types> void print_impl(ostream& out, const type_a& a, const type_b& b, const types&... targs) {
print_impl(out, a, b);
print_impl(out, b, targs...);
}
template<class type_a, class... types> void print_to(ostream& out, const type_a& a, const types&... targs) {
out << a;
print_impl(out, a, targs...);
out << "\n";
}
template<class type_a, class... types> void print(const type_a& a, const types&... targs) {
print_to(cerr, a, targs...);
}
//if buffer is not null, will return an empty string
string getstream(istream& targ, int block_size=10, string* buffer=nullptr) {
string new_buffer;
string& res=(buffer!=nullptr)? *buffer : new_buffer;
res.clear();
while(1) {
res.resize(res.size()+block_size);
targ.read(&(res[res.size()-block_size]), block_size);
int c=targ.gcount();
if (c!=block_size) {
res.resize(res.size()-block_size+c);
assert(targ.eof());
return new_buffer;
}
}
}
string getfile(const string& name, bool binary=0, int block_size=1024) {
ifstream in(name, binary? ios::binary|ios_base::in : ios_base::in);
assert(in.good());
return getstream(in, block_size);
}
struct less_ptr {
template<class ptr_type> bool operator()(ptr_type a, ptr_type b) {
return *a<*b;
}
};
template<class type> type instance_of();
template<class type> std::string to_string(std::ostringstream& s, const type& targ) {
s.clear();
s.str("");
s << targ;
return s.str();
}
template<class type> std::string to_string(const type& targ) {
static std::ostringstream s;
return to_string(s, targ);
}
template<class type> pair<type, bool> checked_from_string(std::istringstream& s, const std::string& targ) {
s.clear();
s.str(targ);
type res;
s >> res;
return make_pair(res, s.eof() && !s.fail());
}
template<class type> type from_string(std::istringstream& s, const std::string& targ) {
return checked_from_string<type>(s, targ).first;
}
template<class type> type from_string(const std::string& targ) {
static std::istringstream s;
return from_string<type>(s, targ);
}
template<class type> pair<type, bool> checked_from_string(const std::string& targ) {
static std::istringstream s;
return checked_from_string<type>(s, targ);
}
template<class type> type assert_from_string(const std::string& targ) {
auto res=checked_from_string<type>(targ);
assert(res.second);
return res.first;
}
template<class type, class... types> unique_ptr<type> make_unique_ptr(types&&... targs) {
return unique_ptr<type>(new type(forward<types>(targs)...));
}
template<class type, int size> int array_size(type(&)[size]) {
return size;
}
template<class type> std::ostream& print_as_number(std::ostream& out, const type& targ) { out << targ; return out; }
template<> std::ostream& print_as_number<unsigned char>(std::ostream& out, const unsigned char& targ) { out << int(targ); return out; }
template<> std::ostream& print_as_number<signed char>(std::ostream& out, const signed char& targ) { out << int(targ); return out; }
template<> std::ostream& print_as_number<char>(std::ostream& out, const char& targ) { out << int(targ); return out; }
//
template<bool n, class type> struct only_if {};
template<class type> struct only_if<1, type> { typedef type good; };
template<bool n> typename only_if<n, void>::good assert_true() {}
template<class a, class b, class type> struct only_if_same_types {};
template<class a, class type> struct only_if_same_types<a, a, type> { typedef type good; };
template<class a, class b> typename only_if_same_types<a, b, void>::good assert_same_types() {}
template<class a, class b, class type> struct only_if_not_same_types { typedef type good; };
template<class a, class type> struct only_if_not_same_types<a, a, type> {};
template<class a, class b> typename only_if_not_same_types<a, b, void>::good assert_not_same_types() {}
template<int n> struct static_abs { static const int res=n<0? -n : n; };
template<int n> struct static_sgn { static const int res=n<0? -1 : (n>0? 1 : 0); };
template<int a, int b> struct static_max { static const int res=a>b? a : b; };
template<int a, int b> struct static_min { static const int res=a<b? a : b; };
template<class type> class wrap_type { typedef type res; };
//
template<class type_a, class type_b> class union_pair {
template<class, class> friend class union_pair;
static const size_t size_bytes=static_max<sizeof(type_a), sizeof(type_b)>::res;
static const size_t alignment_bytes=static_max<alignof(type_a), alignof(type_b)>::res;
typename aligned_storage<size_bytes, alignment_bytes>::type buffer;
bool t_is_first;
public:
union_pair() : t_is_first(1) { new(&buffer) type_a(); }
union_pair(int, int) : t_is_first(0) { new(&buffer) type_b(); }
union_pair(const type_a& targ) : t_is_first(1) { new(&buffer) type_a(targ); }
union_pair(const type_b& targ) : t_is_first(0) { new(&buffer) type_b(targ); }
union_pair(type_a&& targ) : t_is_first(1) { new(&buffer) type_a(move(targ)); }
union_pair(type_b&& targ) : t_is_first(0) { new(&buffer) type_b(move(targ)); }
union_pair(const union_pair& targ) : t_is_first(targ.t_is_first) {
if (t_is_first) new(&buffer) type_a(targ.first()); else new(&buffer) type_b(targ.second());
}
union_pair(const union_pair<type_b, type_a>& targ) : t_is_first(!targ.t_is_first) {
if (t_is_first) new(&buffer) type_a(targ.second()); else new(&buffer) type_b(targ.first());
}
union_pair(union_pair&& targ) : t_is_first(targ.t_is_first) {
if (t_is_first) new(&buffer) type_a(move(targ.first())); else new(&buffer) type_b(move(targ.second()));
}
union_pair(union_pair<type_b, type_a>&& targ) : t_is_first(!targ.t_is_first) {
if (t_is_first) new(&buffer) type_a(move(targ.second())); else new(&buffer) type_b(move(targ.first()));
}
union_pair& operator=(const type_a& targ) {
if (is_first()) first()=targ; else set_first(targ);
return *this;
}
union_pair& operator=(const type_b& targ) {
if (is_second()) second()=targ; else set_second(targ);
return *this;
}
union_pair& operator=(type_a&& targ) {
if (is_first()) first()=move(targ); else set_first(move(targ));
return *this;
}
union_pair& operator=(type_b&& targ) {
if (is_second()) second()=move(targ); else set_second(move(targ));
return *this;
}
union_pair& operator=(const union_pair& targ) {
if (targ.is_first()) {
return *this=targ.first();
} else {
return *this=targ.second();
}
}
union_pair& operator=(const union_pair<type_b, type_a>& targ) {
if (targ.is_first()) {
return *this=targ.first();
} else {
return *this=targ.second();
}
}
union_pair& operator=(union_pair&& targ) {
if (targ.is_first()) {
return *this=move(targ.first());
} else {
return *this=move(targ.second());
}
}
union_pair& operator=(union_pair<type_b, type_a>&& targ) {
if (targ.is_first()) {
return *this=move(targ.first());
} else {
return *this=move(targ.second());
}
}
typedef type_a first_type;
typedef type_b second_type;
bool is_first() const { return t_is_first; }
bool is_second() const { return !t_is_first; }
//
type_a& first() { return *reinterpret_cast<type_a*>(&buffer); }
const type_a& first() const { return *reinterpret_cast<const type_a*>(&buffer); }
type_b& second() { return *reinterpret_cast<type_b*>(&buffer); }
const type_b& second() const { return *reinterpret_cast<const type_b*>(&buffer); }
//
template<class... types> type_a& set_first(types&&... targs) {
if (!t_is_first) {
second().~type_b();
t_is_first=1;
} else {
first().~type_a();
}
return *(new(&buffer) type_a(forward<types>(targs)...));
}
template<class... types> type_b& set_second(types&&... targs) {
if (t_is_first) {
first().~type_a();
t_is_first=0;
} else {
second().~type_b();
}
return *(new(&buffer) type_b(forward<types>(targs)...));
}
~union_pair() {
if (t_is_first) first().~type_a(); else second().~type_b();
}
};
}
#endif

View File

@ -0,0 +1,34 @@
#ifndef ILYA_SHARED_HEADER_GENERIC_MACROS
#define ILYA_SHARED_HEADER_GENERIC_MACROS
/*
#define main(...) \
main_inner(int argc, char** argv); \
int main(int argc, char** argv) { \
try {\
return main_inner(argc, argv);\
} catch(const std::exception& e) {\
std::cerr << "\n\nUncaught exception: " << e.what() << "\n";\
char *f=0; *f=1;\
} catch(const std::string& e) {\
std::cerr << "\n\nUncaught exception: " << e << "\n";\
char *f=0; *f=1;\
} catch(const char* e) {\
std::cerr << "\n\nUncaught exception: " << e << "\n";\
char *f=0; *f=1;\
} catch(...) {\
std::cerr << "\n\nUncaught exception.\n";\
char *f=0; *f=1;\
}\
}\
int main_inner(int argc, char** argv)
#ifndef NO_GENERIC_H_ASSERT
#ifdef assert
#undef assert
#endif
#define assert(v) if (!(v)) { std::cerr << "\n\nAssertion failed: " << __FILE__ << " : " << __LINE__ << "\n"; char* shared_generic_assert_char_123=nullptr; *shared_generic_assert_char_123=1; throw 0; } (void)0
#endif
*/
#endif

View File

@ -0,0 +1,639 @@
template<class int_type> int_type add_carry(int_type a, int_type b, int carry_in, int& carry_out) {
assert(carry_in==0 || carry_in==1);
uint128 res=uint128(a) + uint128(b) + uint128(carry_in);
carry_out=int(res >> (sizeof(int_type)*8));
assert(carry_out==0 || carry_out==1);
return int_type(res);
}
template<class int_type> int_type sub_carry(int_type a, int_type b, int carry_in, int& carry_out) {
assert(carry_in==0 || carry_in==1);
uint128 res=uint128(a) - uint128(b) - uint128(carry_in);
carry_out=int(res >> (sizeof(int_type)*8)) & 1;
assert(carry_out==0 || carry_out==1);
return int_type(res);
}
template<class int_type> int clz(int_type a) {
assert(sizeof(int_type)==4 || sizeof(int_type)==8);
if (a==0) {
return (sizeof(int_type)==4)? 32 : 64;
} else {
return (sizeof(int_type)==4)? __builtin_clz(uint32(a)) : __builtin_clzll(uint64(a));
}
}
uint64 mul_high(uint64 a, uint64 b) {
return uint64((uint128(a)*uint128(b))>>64);
}
uint32 mul_high(uint32 a, uint32 b) {
return uint32((uint64(a)*uint64(b))>>32);
}
constexpr int max_constexpr(int a, int b) {
if (a>b) {
return a;
} else {
return b;
}
}
//all "=" operators truncate ; all operators that return a separate result will pad the result as necessary
template<class type, int size> struct fixed_integer {
static const type positive_sign=0;
static const type negative_sign=~type(0);
type data[size+1]; //little endian; sign is first
fixed_integer() {
for (int x=0;x<size+1;++x) {
data[x]=0;
}
}
fixed_integer(const integer& i) : fixed_integer() {
assert(i.num_bits()<=size*sizeof(type)*8);
if (i<0) {
data[0]=negative_sign;
}
mpz_export(data+1, nullptr, -1, sizeof(type), -1, 0, i.impl);
}
operator integer() const {
integer res;
mpz_import(res.impl, size, -1, sizeof(type), -1, 0, data+1);
if (data[0]==negative_sign) {
res=-res;
}
return res;
}
USED integer to_integer() const {
return integer(*this);
}
//truncation
template<int t_size> explicit fixed_integer(fixed_integer<type, t_size> t) {
for (int x=0;x<size+1;++x) {
data[x]=(x<t_size+1)? t.data[x] : 0;
}
}
fixed_integer& operator=(const integer& v) { return *this=fixed_integer(v); }
template<int t_size> fixed_integer& operator=(fixed_integer<type, t_size> t) { return *this=fixed_integer(t); }
bool is_negative() const {
return !is_zero() && data[0]==negative_sign;
}
void set_negative(bool t_negative) {
data[0]=(t_negative)? negative_sign : positive_sign;
}
type& operator[](int pos) {
assert(pos>=0 && pos<size);
return data[pos+1];
}
const type& operator[](int pos) const {
assert(pos>=0 && pos<size);
return data[pos+1];
}
//the result is -1 if a<b, 0 if a==b, and 1 if a>b
//there is also a fast comparison in the add function, but it has a slow path
static int compare(
const type* a, int size_a, type sign_a,
const type* b, int size_b, type sign_b
) {
int carry=0;
type zero=0;
//this calculates |a|-|b|. all of the resulted are or'ed together in zero
for (int x=0;x<max(size_a, size_b);++x) {
type v_a=(x<size_a)? a[x] : 0;
type v_b=(x<size_b)? b[x] : 0;
zero|=sub_carry(v_a, v_b, carry, carry);
}
//if the final carry is 1, |a|<|b|
//if the final carry is 0 and zero==0, |a|==|b| (|a|-|b| is 0)
//if the final carry is 0 and zero!=0, |a|>|b| (|a|-|b| is positive)
//same sign, positive: use res
//same sign, negative: negate res
//opposite signs: use res if 0, otherwise 1 if sign_a is positive, -1 if sign_a is negative
int res=0;
if (zero!=0) res=1;
if (carry==1) res=-1;
//todo //get rid of branches
//this is used to implement exactly one comparison with a binary result, so that should get rid of all of these branches
if (sign_a==sign_b) {
if (sign_a==negative_sign) {
res=-res;
}
} else {
if (res!=0) {
res=(sign_a==negative_sign)? -1 : 1;
}
}
return res;
}
template<int b_size> int compare(fixed_integer<type, b_size> b) const {
return compare(
data+1, size, data[0],
b.data+1, size, b.data[0]
);
}
//a, b, and res can alias with each other but only if the pointers are equal
//the sign is not present in a/b/res
static void add(
const type* a, int size_a, type sign_a,
const type* b, int size_b, type sign_b,
type* res, int size_res, type& sign_res
) {
if (size_b>size_a) {
swap(a, b);
swap(size_a, size_b);
swap(sign_a, sign_b);
}
assert(size_res>=size_a && size_a>=size_b && size_b>=1);
type mask=sign_a ^ sign_b; //all 1s if opposite signs, else all 0s. this isn't affected by swapping
type swap_mask=positive_sign;
if (size_a==size_b) {
//carry flag
int size_ab=size_a;
bool a_less_than_b=a[size_ab-1]<b[size_ab-1];
if (a[size_ab-1]==b[size_ab-1] && size_ab>=2) {
a_less_than_b=a[size_ab-2]<b[size_ab-2];
}
const type* tmp=b;
if (a_less_than_b) b=a; //CMOVB
if (a_less_than_b) a=tmp; //CMOVB
if (a_less_than_b) sign_a=sign_b; //CMOVB
//sign_b isn't used anymore
sign_b=0;
//if (a_less_than_b) swap_mask=negative_sign; //CMOVB
}
int carry;
add_carry(mask, type(1), 0, carry); //carry set if opposite signs, else cleared
//if the ints were swapped, size_a==size_b
for (int x=0;x<size_res;++x) {
type v_a=(x<size_a)? a[x] : 0;
type v_b=(x<size_b)? b[x] : 0;
//print(x, v_a, v_b, mask, carry);
//this calculates a-b if they had opposite signs, or a+b if they had the same sign
res[x]=add_carry(v_a, v_b^mask, carry, carry);
}
//print(carry, "===");
//the final sign is a's sign since it has a higher magnitude than b
//however, if a subtraction was done and a and b were swapped, then this should be negated
sign_res=sign_a^(swap_mask & mask);
//todo //figure out how often this happens
//a subtraction was done and there was a carry out. since the subtraction is unsigned, this means it was done in the wrong order
//this almost never happens if the numbers are random and don't have excessive padding
//the subtraction was done in the wrong order if the result is negative
//the result is negative if each input were padded with 0, and the result limb was ~0 instead of 0
//the result limb is: add_carry(0, mask, carry, carry);
//carry in is 0: result is all 1s (bad)
//carry in is 1: result is all 0s and carry out is 1 (good)
//need to check for a carry out of 0 then, not 1
if (carry==0 && mask!=0) {
carry=0;
for (int x=0;x<size_res;++x) {
//print(x, ~res[x], type((x==0)? 1 : 0), carry);
//calculate the two's complement of the result
res[x]=add_carry(~res[x], type((x==0)? 1 : 0), carry, carry);
}
//print(carry, "===");
//todo print("slow add");
//assert(false);
//negate the sign since the subtraction order was flipped
sign_res=~sign_res;
}
}
fixed_integer operator-() const {
fixed_integer res=*this;
res.data[0]=~data[0];
return res;
}
void operator+=(fixed_integer b) {
add(
data+1, size, data[0],
b.data+1, size, b.data[0],
data+1, size, data[0]
);
}
void operator-=(fixed_integer b) {
add(
data+1, size, data[0],
b.data+1, size, negative_sign^b.data[0],
data+1, size, data[0]
);
}
template<int b_size>
fixed_integer<type, max_constexpr(size, b_size)+1> operator+(
fixed_integer<type, b_size> b
) const {
const int output_size=max_constexpr(size, b_size)+1;
fixed_integer<type, output_size> res;
add(
data+1, size, data[0],
b.data+1, b_size, b.data[0],
res.data+1, output_size, res.data[0]
);
return res;
}
template<int b_size>
fixed_integer<type, max_constexpr(size, b_size)+1> operator-(
fixed_integer<type, b_size> b
) const {
const int output_size=max_constexpr(size, b_size)+1;
fixed_integer<type, output_size> res;
add(
data+1, size, data[0],
b.data+1, b_size, negative_sign^b.data[0],
res.data+1, output_size, res.data[0]
);
return res;
}
//res=a*b+c
//res can alias with c if the pointers are equal. can't alias with a
//if c is null then it is all 0s
static void mad(
const type* a, int size_a,
type b,
const type* c, int size_c,
type* res, int size_res
) {
assert(size_res>=size_c && size_c>=size_a && size_a>=1);
type previous_high=0;
int carry_mul=0;
int add_mul=0;
for (int x=0;x<size_res;++x) {
type this_a=(x>=size_a)? 0 : a[x];
type this_low=this_a*b;
type this_high=mul_high(this_a, b);
type mul_res=add_carry(this_low, previous_high, carry_mul, carry_mul);
if (x==0) {
assert(mul_res==this_low && carry_mul==0);
} else
if (x==size_a) {
assert(carry_mul==0);
} else
if (x>size_a) {
assert(mul_res==0 && carry_mul==0);
}
type this_c=(x>=size_c || c==nullptr)? 0 : c[x];
type add_res=add_carry(mul_res, this_c, add_mul, add_mul);
res[x]=add_res;
previous_high=this_high;
}
}
//can't overflow
//two of these can implement a 1024x512 mul. for 1024x1024, need to do 2x 1024x512 in separate buffers then add them
static void mad_8x8(array<type, 8> a, array<type, 8> b, array<type, 8> c, array<type, 16>& res) {
for (int x=0;x<8;++x) {
res[x]=c[x];
}
for (int x=8;x<16;++x) {
res[x]=0;
}
for (int x=0;x<8;++x) {
//this uses a sliding window for the 8 res registers (no spilling)
//-the lowest register is finished after the first addition in mad. the this_low,previous_high addition is skipped
//-the highest register does not need to be loaded until the last multiplication in mad. actually this would always load 0
// so it is not done
//-the total number of registers is therefore 7
//there is one register for b
//the 8 a values are in registers but some or all may be spilled
//need 2 registers to store the MULX result
//need 1 register to store the previous high result (this is initially 0)
//the this_low,previous_high add result goes into one of those registers
//the mul_res,this_c result goes into the c register
//total registers is 18 then; 2 are spilled
//address registers:
//-will just use a static 32-bit address space for most of the code. can store the stack pointer there then
//-address registers are only used for b and res if the addresses are not static
//-the addresses are only used at the end of the loop, so there are spare registers to load the address registers from static
// memory. probably the addresses will be static though
mad(&a[0], 8, b[x], &res[x], 8, &res[x], 8);
}
}
void operator*=(type v) {
mad(
data+1, size,
v,
nullptr, size,
data+1, size
);
}
template<int t_size, int this_size>
static fixed_integer<type, t_size> subset(
fixed_integer<type, this_size> this_v, int start
) {
const int end=start+t_size;
fixed_integer<type, t_size> res;
res.data[0]=this_v.data[0];
for (int x=start;x<end;++x) {
int pos=x-start;
res[x]=(pos>=0 && pos<this_size)? this_v[x] : 0;
}
return res;
}
void left_shift_limbs(int amount) {
for (int x=size-1;x>=0;--x) {
int pos=x-amount;
(*this)[x] = (pos>=0 && pos<size)? (*this)[pos] : 0;
}
}
void right_shift_limbs(int amount) {
for (int x=0;x<size;++x) {
int pos=x+amount;
(*this)[x] = (pos>=0 && pos<size)? (*this)[pos] : 0;
}
}
void operator<<=(int amount) {
if (amount==0) {
//not sure if intel works with the "previous>>64" statement. might wrap around
return;
}
const int bits_per_limb=sizeof(type)*8;
assert(amount>0 && amount<bits_per_limb);
for (int x=size-1;x>=0;--x) {
type previous=(x==0)? 0 : (*this)[x-1];
(*this)[x] = ((*this)[x]<<amount) | (previous>>(bits_per_limb-amount));
}
}
void operator>>=(int amount) {
if (amount==0) {
return;
}
const int bits_per_limb=sizeof(type)*8;
assert(amount>0 && amount<bits_per_limb);
for (int x=0;x<size;++x) {
type next=(x==size-1)? 0 : (*this)[x+1];
(*this)[x] = ((*this)[x]>>amount) | (next<<(bits_per_limb-amount));
}
}
template<int b_size>
fixed_integer<type, size+b_size> operator*(
fixed_integer<type, b_size> b
) const {
const int output_size=size+b_size;
fixed_integer<type, output_size> res;
for (int x=0;x<b_size;++x) {
auto r=subset<output_size>(*this, 0);
r.data[0]=positive_sign;
integer b_x_int(vector<uint64>{b[x]});
r*=b[x];
//auto r2=subset<output_size+2>(r, 0);
//r2*=b[x];
//r=r2;
integer r_int(r);
integer this_int(abs(*this));
integer expected_r_int=this_int*b_x_int;
assert(r_int==expected_r_int);
r.left_shift_limbs(x);
r_int<<=x*sizeof(type)*8;
assert(r_int==integer(r));
integer res_old_int(res);
//todo //figure out why this doesn't work. might have something to do with the msb being set?
res+=r; //unsigned
/*auto res3=res;
res3+=r;
auto res2=res+r;
fixed_integer<type, output_size> res4(res2);*/
/*if (integer(res3)!=integer(res4)) {
print( "========" );
res3=res;
res3+=r;
//print( "========" );
auto res2_copy=res+r;
assert(false);
}*/
//res=res4;
integer res_new_int(res);
assert(res_new_int==res_old_int+r_int);
}
res.data[0]=data[0] ^ b.data[0];
return res;
}
fixed_integer<type, size+1> operator<<(int num) const {
auto res=subset<size+1>(*this, 0);
res<<=num;
return res;
}
//this rounds to 0 so it is different from division unless the input is divisible by 2^num
fixed_integer<type, size> operator>>(int num) const {
auto res=subset<size>(*this, 0);
res>>=num;
return res;
}
bool is_zero() const {
for (int x=0;x<size;++x) {
if (data[x+1]!=0) {
return false;
}
}
return true;
}
template<int b_size>
bool operator>=(fixed_integer<type, b_size> b) const {
return compare(b)>=0;
}
template<int b_size>
bool operator==(fixed_integer<type, b_size> b) const {
return compare(b)==0;
}
template<int b_size>
bool operator<(fixed_integer<type, b_size> b) const {
return compare(b)<0;
}
template<int b_size>
bool operator<=(fixed_integer<type, b_size> b) const {
return compare(b)<=0;
}
template<int b_size>
bool operator>(fixed_integer<type, b_size> b) const {
return compare(b)>0;
}
template<int b_size>
bool operator!=(fixed_integer<type, b_size> b) const {
return compare(b)!=0;
}
//"0" has 1 bit
int num_bits() const {
type v=0;
int num_full=0;
for (int x=size-1;x>=0;--x) {
if (v==0) {
v=(*this)[x];
num_full=x;
}
}
int v_bits;
if (v==0) {
v_bits=1;
assert(num_full==0);
} else
if (sizeof(v)==8) {
v_bits=64-__builtin_clzll(v);
} else{
assert(sizeof(v)==4);
v_bits=32-__builtin_clz(v);
}
return num_full*sizeof(type)*8 + v_bits;
}
type window(int start_bit) const {
int bits_per_limb_log2=(sizeof(type)==8)? 6 : 5;
int bits_per_limb=1<<bits_per_limb_log2;
int start_limb=start_bit>>bits_per_limb_log2;
int start_offset=start_bit&(bits_per_limb-1);
auto get_limb=[&](int pos) -> type {
assert(pos>=0);
return (pos>=size)? type(0) : (*this)[pos];
};
type start=get_limb(start_limb)>>(start_offset);
//the shift is undefined for start_offset==0
type end=get_limb(start_limb+1)<<(bits_per_limb-start_offset);
return (start_offset==0)? start : (start | end);
}
};
template<class type, int size> fixed_integer<type, size> abs(fixed_integer<type, size> v) {
v.set_negative(false);
return v;
}
template<int size> fixed_integer<uint64, (size+1)/2> to_uint64(fixed_integer<uint32, size> v) {
fixed_integer<uint64, (size+1)/2> res;
res.set_negative(v.is_negative()); //sign extend data[0]. can just make data[0] 64 bits if i actually have to do this
//this just copies the bytes over
for (int x=0;x<size;x+=2) {
uint32 low=v[x];
uint32 high=(x==size-1)? 0 : v[x+1];
res[x>>1]=uint64(high)<<32 | uint64(low);
}
return res;
}
template<int size> fixed_integer<uint32, size*2> to_uint32(fixed_integer<uint64, size> v) {
fixed_integer<uint32, size*2> res;
res.set_negative(v.is_negative()); //lower 32 bits of data[0]
for (int x=0;x<size;++x) {
res[2*x]=uint32(v[x]);
res[2*x+1]=uint32(v[x]>>32);
}
return res;
}

View File

@ -0,0 +1,378 @@
//unsigned
template<class type, int size> void normalize_divisor(fixed_integer<type, size>& b, int& shift_limbs, int& shift_bits) {
shift_limbs=0;
//todo //make this a variable shift (could have done it on the gpu through shared memory; oh well)
for (int x=0;x<size;++x) {
if (b[size-1]==0) {
++shift_limbs;
b.left_shift_limbs(1);
} else {
break;
}
}
shift_bits=clz(b[size-1]);
b<<=shift_bits;
}
//result is >= the actual reciprocal; max result is 2^63
uint64 calculate_reciprocal(uint32 high, uint32 low) {
assert((high>>31)!=0); //should be normalized
//bit 63 set
uint64 both_source=uint64(low) | (uint64(high)<<32);
uint64 both=both_source;
//bit 52 set
both>>=2*32-53;
//clears bit 52
both&=~(1ull<<52);
uint64 res;
if (both<=1) {
res=1ull<<63;
} else {
--both;
uint64 bits=both;
bits|=1023ull<<52;
double bits_double=*(double*)&bits;
bits_double=1.0/(bits_double);
bits=*(uint64*)&bits_double;
bits&=(1ull<<52)-1;
res=bits;
++res;
res|=1ull<<52;
res<<=(62-52);
}
return res;
}
//result is >= the actual quotient
uint32 calculate_quotient(uint32 high, uint32 low, uint64 reciprocal, uint32 b) {
uint64 both=uint64(low) | (uint64(high)<<32);
uint64 product_high=(uint128(both)*uint128(reciprocal))>>64;
++product_high;
uint64 res=product_high>>(32-2);
if (res>=1ull<<32) {
res=(1ull<<32)-1;
}
return uint32(res);
}
fixed_integer<uint64, 2> calculate_reciprocal(uint64 high, uint64 low);
uint64 calculate_quotient(uint64 high, uint64 low, fixed_integer<uint64, 2> reciprocal, uint64 b);
//should pad a by 1 limb then left shift it by num_bits
//all integers are unsigned
template<class type, int size_a, int size_b>
void divide_integers_impl(
fixed_integer<type, size_a> a, fixed_integer<type, size_b> b, int b_shift_limbs,
fixed_integer<type, size_a-1>& q, fixed_integer<type, size_b>& r
) {
const int max_quotient_size=size_a-1;
fixed_integer<type, max_quotient_size> res;
auto reciprocal=calculate_reciprocal(b[size_b-1], (size_b>=2)? b[size_b-2] : 0);
fixed_integer<type, size_a> b_shifted;
b_shifted=b;
b_shifted.left_shift_limbs(size_a-size_b-1); //it is already left shifted by b_shift_limbs
int quotient_size=size_a-(size_b-b_shift_limbs);
for (int x=0;x<max_quotient_size;++x) {
//this is more efficient than having an if statement without a break because of the compiler
if (x>=quotient_size) {
break;
}
{
type qj=calculate_quotient(a[size_a-1-x], a[size_a-2-x], reciprocal, b[size_b-1]);
//this is slower than using the doubles even though the doubles waste half the registers
//ptxas generates horrible code which isn't scheduled properly
//uint64 qj_64=((uint64(a[size_a-1-x])<<32) | uint64(a[size_a-2-x])) / uint64(b[size_b-1]);
//uint32 qj=uint32(min( qj_64, uint64(~uint32(0)) ));
auto a_start=a;
type qj_start=qj;
auto b_shifted_qj=b_shifted;
b_shifted_qj*=qj;
a-=b_shifted_qj;
while (a.is_negative()) {
//todo print( "slow division" );
--qj;
a+=b_shifted;
}
b_shifted.right_shift_limbs(1);
res[max_quotient_size-1-x]=qj;
}
}
//todo //get rid of this; use variable shifts
for (int x=0;x<max_quotient_size;++x) {
if (quotient_size>=max_quotient_size) {
break;
}
res.right_shift_limbs(1);
++quotient_size;
}
q=res;
r=a;
//todo print( "====" );
}
//these are signed
//this has a bug if size_a<size_b and the quotient is nonzero. remainder is wrong. dont care
template<class type, int size_a, int size_b>
void divide_integers(
fixed_integer<type, size_a> a, fixed_integer<type, size_b> b,
fixed_integer<type, size_a>& q, fixed_integer<type, size_b>& r
) {
int shift_limbs;
int shift_bits;
auto b_normalized=b;
b_normalized.set_negative(false);
normalize_divisor(b_normalized, shift_limbs, shift_bits);
fixed_integer<type, size_a+1> a_shifted;
a_shifted=a;
a_shifted.set_negative(false);
a_shifted<<=shift_bits;
fixed_integer<type, size_a> q_unsigned;
divide_integers_impl(a_shifted, b_normalized, shift_limbs, q_unsigned, r);
r>>=shift_bits;
if (a.is_negative()!=b.is_negative()) {
if (r==fixed_integer<type, size_b>(integer(0u))) {
q=q_unsigned;
q=-q;
} else {
q=q_unsigned+fixed_integer<type, size_a>(integer(1u));
q=-q; //q'=-q-1
auto abs_b=b;
abs_b.set_negative(false);
r=abs_b-r;
}
} else {
q=q_unsigned;
}
// qb+r=a ; b>0: 0<=r<b ; b<0: b<r<=0
// b<0:
// -qb-r=-a
// R=-r ; 0<=R<-b
// q(-b)+R=-a
r.set_negative(b.is_negative());
{
integer a_int(a);
integer b_int(b);
integer q_expected=a_int/b_int;
integer r_expected=a_int.fdiv_r(b_int);
integer r_expected_2=a_int%b_int;
integer q_actual=q;
integer r_actual=r;
assert(q_expected==q_actual);
assert(r_expected==r_actual);
//todo
//r=r_expected;
}
}
template<class type, int size_a, int size_b>
fixed_integer<type, size_a> operator/(
fixed_integer<type, size_a> a, fixed_integer<type, size_b> b
) {
fixed_integer<type, size_a> q;
fixed_integer<type, size_b> r;
divide_integers(a, b, q, r);
return q;
}
template<class type, int size_a, int size_b>
fixed_integer<type, size_b> operator%(
fixed_integer<type, size_a> a, fixed_integer<type, size_b> b
) {
fixed_integer<type, size_a> q;
fixed_integer<type, size_b> r;
b.set_negative(false);
divide_integers(a, b, q, r);
return r;
}
fixed_integer<uint64, 2> calculate_reciprocal(uint64 high, uint64 low) {
assert((high>>63)!=0); //normalized
fixed_integer<uint32, 6> a;
a[5]=1u<<31; // a=2^191 ; normalized
fixed_integer<uint32, 3> b;
b[0]=uint32(low>>32);
b[1]=uint32(high);
b[2]=uint32(high>>32);
b-=fixed_integer<uint32, 3>(integer(1));
return fixed_integer<uint64, 2>(to_uint64(a/b + fixed_integer<uint32, 6>(integer(1)))<<31);
}
//result is >= the actual reciprocal. it is approximately 2^127/((HIGH | LOW)/2^127)
//the max value is 2^127 + 2^31
/*fixed_integer<uint64, 2> calculate_reciprocal(uint64 high, uint64 low) {
assert((high>>63)!=0); //normalized
//fixed_integer<uint32, 6> a
//a[5]=1u<<31; // a=2^191 ; normalized
uint128 b=(uint128(high)<<32) | uint128(low>>32);
uint64 reciprocal=calculate_reciprocal(uint32(high>>32), uint32(high));
fixed_integer<type, size_a> b_shifted;
b_shifted=b;
b_shifted.left_shift_limbs(2);
int quotient_size=3;
for (int x=0;x<3;++x) {
uint64 qj=calculate_quotient(a[5-x], a[4-x], reciprocal, b[1]);
auto b_shifted_qj=b_shifted;
b_shifted_qj*=qj;
a-=b_shifted_qj;
while (a.is_negative()) {
//todo print( "slow division" );
--qj;
a+=b_shifted;
}
b_shifted.right_shift_limbs(1);
res[5-1-x]=qj;
}
todo //get rid of this; use variable shifts
for (int x=0;x<max_quotient_size;++x) {
if (quotient_size>=max_quotient_size) {
break;
}
res.right_shift_limbs(1);
++quotient_size;
}
q=res;
r=a;
//todo print( "====" );
return fixed_integer<uint64, 2>(to_uint64(a/b + fixed_integer<uint32, 6>(integer(1)))<<31);
} */
//result is >= the actual quotient
uint64 calculate_quotient(uint64 high, uint64 low, fixed_integer<uint64, 2> reciprocal, uint64 b) {
fixed_integer<uint64, 2> both;
both[0]=low;
both[1]=high;
//approximately (high | low) * (2^127/((HIGH | LOW)/2^127))
// = (2^(127*2)*(high | low)/((HIGH | LOW)/2^64)/2^64
// = (2^(127*2-64) * (high | low)/((HIGH | LOW)/2^64)
// = (2^190 * (high | low)/((HIGH | LOW)/2^64)
//need to right shift by 190 then, which is 2*64+62
//
//max value of the product is (2^128-1)*(2^127 + 2^31) = 2^255 + 2^159 - 2^127 - 2^31
integer both_int(both);
integer reciprocal_int(reciprocal);
integer product_both_int(both_int*reciprocal_int);
fixed_integer<uint64, 4> product_both(both*reciprocal);
assert(integer(product_both)==product_both_int);
product_both.right_shift_limbs(2);
product_both_int>>=128;
assert(integer(product_both)==product_both_int);
fixed_integer<uint64, 2> product_high(product_both);
//this can't overflow because the max value of the product has e.g. bit 254 cleared
product_high+=fixed_integer<uint64, 2>(integer(1));
product_high>>=64-2;
uint64 res;
if (product_high[1]!=0) {
res=~uint64(0);
} else {
res=product_high[0];
}
//uint128 qj_128=((uint128(high)<<64) | uint128(low)) / uint128(b);
//uint64 qj=uint64(min( qj_128, uint128(~uint64(0)) ));
//assert(res>=qj); this is wrong. res can be qj-1 sometimes
//assert(res<=qj+1); //optional
return res;
}
/*template<int size_a, int size_b>
fixed_integer<uint64, size_a> operator/(
fixed_integer<uint64, size_a> a, fixed_integer<uint64, size_b> b
) {
auto a_32=to_uint32(a);
auto b_32=to_uint32(b);
fixed_integer<uint32, size_a*2> q_32;
fixed_integer<uint32, size_b*2> r_32;
divide_integers(a_32, b_32, q_32, r_32);
return to_uint64(q_32);
}
template<int size_a, int size_b>
fixed_integer<uint64, size_b> operator%(
fixed_integer<uint64, size_a> a, fixed_integer<uint64, size_b> b
) {
auto a_32=to_uint32(a);
auto b_32=to_uint32(b);
fixed_integer<uint32, size_a*2> q_32;
fixed_integer<uint32, size_b*2> r_32;
b_32.set_negative(false);
divide_integers(a_32, b_32, q_32, r_32);
return to_uint64(r_32);
}**/

View File

@ -0,0 +1,118 @@
template<int size> struct fixed_gcd_res {
fixed_integer<uint64, size> gcd; //unsigned; final value of a
fixed_integer<uint64, size> gcd_2; //unsigned; final value of b. this is 0 for a normal gcd
fixed_integer<uint64, size> s; //signed
fixed_integer<uint64, size> t; //signed
fixed_integer<uint64, size> s_2; //signed
fixed_integer<uint64, size> t_2; //signed
};
//threshold is 0 to calculate the normal gcd
//this calculates either s (u) or t (v)
template<int size> fixed_gcd_res<size> gcd(
fixed_integer<uint64, size> a_signed, fixed_integer<uint64, size> b_signed, fixed_integer<uint64, size> threshold,
bool calculate_u
) {
assert(!threshold.is_negative());
bool a_negative=a_signed.is_negative();
bool b_negative=b_signed.is_negative();
assert(!b_negative);
array<fixed_integer<uint64, size>, 2> ab; //unsigned
ab[0]=a_signed;
ab[0].set_negative(false);
ab[1]=b_signed;
ab[1].set_negative(false);
array<fixed_integer<uint64, size>, 2> uv; //unsigned
int parity;
if (ab[0]<ab[1]) {
//swap components of u and v
//also negate the parity
auto a_copy=ab[0];
ab[0]=ab[1];
ab[1]=a_copy;
if (calculate_u) {
uv[0]=integer(0u);
uv[1]=integer(1u);
} else {
uv[0]=integer(1u);
uv[1]=integer(0u);
}
parity=-1;
} else {
if (calculate_u) {
uv[0]=integer(1u);
uv[1]=integer(0u);
} else {
uv[0]=integer(0u);
uv[1]=integer(1u);
}
parity=1;
}
gcd_unsigned(ab, uv, parity, threshold);
// sa+bt=g ; all nonnegative
// (-s)(-a)+bt=g
// sa+(-b)(-t)=g
// (-s)(-a)+(-b)(-t)=g
// sign of each cofactor is the sign of the input
fixed_gcd_res<size> res;
res.gcd=ab[0];
res.gcd_2=ab[1];
//if a was negative, negate the parity
//if the parity is -1, negate the parity and negate the result u/v values. the parity is now 1
//for u, u0 is positive and u1 is negative
//for v, v0 is negative and u1 is positive
if (calculate_u) {
res.s=uv[0];
res.s.set_negative(a_negative != (parity==-1));
res.s_2=uv[1];
res.s_2.set_negative(a_negative != (parity==1));
} else {
res.t=uv[0];
res.t.set_negative(a_negative != (parity==1));
res.t_2=uv[1];
res.t_2.set_negative(a_negative != (parity==-1));
}
if (threshold.is_zero()) {
auto expected_gcd_res=gcd(integer(a_signed), integer(b_signed));
assert(expected_gcd_res.gcd==integer(res.gcd));
if (calculate_u) {
assert(expected_gcd_res.s==integer(res.s));
} else {
assert(expected_gcd_res.t==integer(res.t));
}
} else {
//integer a_copy(a_signed);
//integer b_copy(a_signed);
//integer u_copy;
//integer v_copy;
//xgcd_partial(u_copy, v_copy, a_copy, b_copy, integer(threshold));
//assert(a_copy==res.gcd);
//assert(b_copy==res.gcd_2);
//if (calculate_t) {
//assert(u_copy==-res.t);
//assert(v_copy==-res.t_2);
//}
}
return res;
}

View File

@ -0,0 +1,58 @@
#ifdef NDEBUG
#undef NDEBUG
#endif
#if VDF_MODE==0
#define NDEBUG
#endif
#include <iostream>
#include <string>
#include <vector>
#include <cstdio>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <array>
#include <sstream>
#include <fstream>
#include <unistd.h>
#include <cassert>
#include <iomanip>
#include <set>
#include <random>
#include <limits>
#include <cstdlib>
#include <map>
#include <functional>
#include <algorithm>
#include <cstdint>
#include <deque>
#include <cfenv>
#include <ctime>
#include <thread>
#include <x86intrin.h>
#include "generic.h"
#include <gmpxx.h>
#include <flint/fmpz.h>
using namespace std;
using namespace generic;
typedef uint8_t uint8;
typedef uint16_t uint16;
typedef uint32_t uint32;
typedef uint64_t uint64;
typedef int8_t int8;
typedef int16_t int16;
typedef int32_t int32;
typedef int64_t int64;
typedef unsigned __int128 uint128;
typedef __int128 int128;
#define todo
#define USED __attribute__((used))

View File

@ -0,0 +1,2 @@
#!/bin/bash
./install_child.sh 2>&1

View File

@ -0,0 +1,25 @@
#!/bin/bash
set -v
cat /proc/cpuinfo | grep -e MHz -e GHz
cat /proc/cpuinfo | grep flags | head -n 1
enable_all_instructions=0
if cat /proc/cpuinfo | grep -w avx2 | grep -w fma | grep -w -q adx; then
enable_all_instructions=1
fi
echo "enable_all_instructions: $enable_all_instructions"
# Use this for linux only
# sudo apt-get install libgmp3-dev -y
# sudo apt-get install libflint-dev -y
# Remove -D CHIAOSX=1 for linux
compile_flags="-std=c++1z -D CHIAOSX=1 -D VDF_MODE=0 -D ENABLE_ALL_INSTRUCTIONS=$enable_all_instructions -no-pie -march=native"
link_flags="-no-pie -lgmpxx -lgmp -lflint -lpthread"
g++ -o compile_asm.o -c compile_asm.cpp $compile_flags -O0
g++ -o compile_asm compile_asm.o $link_flags
./compile_asm
as -o asm_compiled.o asm_compiled.s
g++ -o vdf.o -c vdf.cpp $compile_flags -O3
g++ -o vdf vdf.o asm_compiled.o $link_flags

View File

@ -0,0 +1,476 @@
//note: gmp already has c++ bindings so could have just used those. oh well
//const bool output_stats=false;
/*struct generic_stats {
vector<int> entries;
void add(int i) {
if (!output_stats) {
return;
}
entries.push_back(i);
}
void output(string name) {
if (!output_stats) {
return;
}
sort(entries.begin(), entries.end());
vector<double> percentiles={0, 0.01, 0.1, 1, 10, 25, 50, 75, 90, 99, 99.9, 99.99, 100};
print( "::", name );
print( " num =", entries.size() );
if (entries.empty()) {
return;
}
for (double c : percentiles) {
int i=(c/100)*entries.size();
if (i<0) {
i=0;
}
if (i>=entries.size()) {
i=entries.size()-1;
}
print( " ", c, " -> ", entries.at(i) );
}
double total=0;
for (int c : entries) {
total+=c;
}
print( " ", "avg", " -> ", total/double(entries.size()) );
}
};*/
/*struct track_cycles {
generic_stats& stats;
uint64 start_time;
bool is_aborted=false;
track_cycles(generic_stats& t_stats) : stats(t_stats) {
if (!enable_track_cycles) {
return;
}
start_time=__rdtsc();
}
void abort() {
if (!enable_track_cycles) {
return;
}
is_aborted=true;
}
~track_cycles() {
if (!enable_track_cycles) {
return;
}
if (is_aborted) {
return;
}
uint64 end_time=__rdtsc();
uint64 delta=end_time-start_time;
int delta_int=delta;
if (delta_int==delta) {
stats.add(delta_int);
} else {
stats.add(INT_MAX);
}
}
};*/
struct track_max_type {
map<pair<int, string>, pair<int, bool>> data;
void add(int line, string name, int value, bool negative) {
auto& v=data[make_pair(line, name)];
v.first=max(v.first, value);
v.second|=negative;
}
void output(int basis_bits) {
print( "== track max ==" );
for (auto c : data) {
print(c.first.second, double(c.second.first)/basis_bits, c.second.second);
}
}
};
track_max_type track_max;
//#define TRACK_MAX(data) track_max.add(#data " {" __func__ ":" "__LINE__" ")", (data).num_bits())
#define TRACK_MAX(data) track_max.add(__LINE__, #data, (data).num_bits(), (data)<0)
//typedef __mpz_struct mpz_t[1];
typedef __mpz_struct mpz_struct;
int mpz_num_bits_upper_bound(mpz_struct* v) {
return mpz_size(v)*sizeof(mp_limb_t)*8;
}
static bool allow_integer_constructor=false; //don't want static integers because they use the wrong allocator
struct integer {
mpz_struct impl[1];
~integer() {
mpz_clear(impl);
}
integer() {
assert(allow_integer_constructor);
mpz_init(impl);
}
integer(const integer& t) {
mpz_init(impl);
mpz_set(impl, t.impl);
}
integer(integer&& t) {
mpz_init(impl);
mpz_swap(impl, t.impl);
}
explicit integer(int64 i) {
mpz_init(impl);
mpz_set_si(impl, i);
}
explicit integer(const string& s) {
mpz_init(impl);
int res=mpz_set_str(impl, s.c_str(), 0);
assert(res==0);
}
//lsb first
explicit integer(const vector<uint64>& data) {
mpz_init(impl);
mpz_import(impl, data.size(), -1, 8, 0, 0, &data[0]);
}
//lsb first
vector<uint64> to_vector() const {
vector<uint64> res;
res.resize(mpz_sizeinbase(impl, 2)/64 + 1, 0);
size_t count;
mpz_export(&res[0], &count, -1, 8, 0, 0, impl);
res.resize(count);
return res;
}
integer& operator=(const integer& t) {
mpz_set(impl, t.impl);
return *this;
}
integer& operator=(integer&& t) {
mpz_swap(impl, t.impl);
return *this;
}
integer& operator=(int64 i) {
mpz_set_si(impl, i);
return *this;
}
integer& operator=(const string& s) {
int res=mpz_set_str(impl, s.c_str(), 0);
assert(res==0);
return *this;
}
void set_bit(int index, bool value) {
if (value) {
mpz_setbit(impl, index);
} else {
mpz_clrbit(impl, index);
}
}
bool get_bit(int index) {
return mpz_tstbit(impl, index);
}
USED string to_string() const {
char* res_char=mpz_get_str(nullptr, 16, impl);
string res_string="0x";
res_string+=res_char;
if (res_string.substr(0, 3)=="0x-") {
res_string.at(0)='-';
res_string.at(1)='0';
res_string.at(2)='x';
}
free(res_char);
return res_string;
}
string to_string_dec() const {
char* res_char=mpz_get_str(nullptr, 10, impl);
string res_string=res_char;
free(res_char);
return res_string;
}
integer& operator+=(const integer& t) {
mpz_add(impl, impl, t.impl);
return *this;
}
integer operator+(const integer& t) const {
integer res;
mpz_add(res.impl, impl, t.impl);
return res;
}
integer& operator-=(const integer& t) {
mpz_sub(impl, impl, t.impl);
return *this;
}
integer operator-(const integer& t) const {
integer res;
mpz_sub(res.impl, impl, t.impl);
return res;
}
integer& operator*=(const integer& t) {
mpz_mul(impl, impl, t.impl);
return *this;
}
integer operator*(const integer& t) const {
integer res;
mpz_mul(res.impl, impl, t.impl);
return res;
}
integer& operator<<=(int i) {
assert(i>=0);
mpz_mul_2exp(impl, impl, i);
return *this;
}
integer operator<<(int i) const {
assert(i>=0);
integer res;
mpz_mul_2exp(res.impl, impl, i);
return res;
}
integer operator-() const {
integer res;
mpz_neg(res.impl, impl);
return res;
}
integer& operator/=(const integer& t) {
mpz_fdiv_q(impl, impl, t.impl);
return *this;
}
integer operator/(const integer& t) const {
integer res;
mpz_fdiv_q(res.impl, impl, t.impl);
return res;
}
integer& operator>>=(int i) {
assert(i>=0);
mpz_fdiv_q_2exp(impl, impl, i);
return *this;
}
integer operator>>(int i) const {
assert(i>=0);
integer res;
mpz_fdiv_q_2exp(res.impl, impl, i);
return res;
}
//this is different from mpz_fdiv_r because it ignores the sign of t
integer& operator%=(const integer& t) {
mpz_mod(impl, impl, t.impl);
return *this;
}
integer operator%(const integer& t) const {
integer res;
mpz_mod(res.impl, impl, t.impl);
return res;
}
integer fdiv_r(const integer& t) const {
integer res;
mpz_fdiv_r(res.impl, impl, t.impl);
return res;
}
bool prime() const {
return mpz_probab_prime_p(impl, 50)!=0;
}
bool operator<(const integer& t) const {
return mpz_cmp(impl, t.impl)<0;
}
bool operator<=(const integer& t) const {
return mpz_cmp(impl, t.impl)<=0;
}
bool operator==(const integer& t) const {
return mpz_cmp(impl, t.impl)==0;
}
bool operator>=(const integer& t) const {
return mpz_cmp(impl, t.impl)>=0;
}
bool operator>(const integer& t) const {
return mpz_cmp(impl, t.impl)>0;
}
bool operator!=(const integer& t) const {
return mpz_cmp(impl, t.impl)!=0;
}
bool operator<(int i) const {
return mpz_cmp_si(impl, i)<0;
}
bool operator<=(int i) const {
return mpz_cmp_si(impl, i)<=0;
}
bool operator==(int i) const {
return mpz_cmp_si(impl, i)==0;
}
bool operator>=(int i) const {
return mpz_cmp_si(impl, i)>=0;
}
bool operator>(int i) const {
return mpz_cmp_si(impl, i)>0;
}
bool operator!=(int i) const {
return mpz_cmp_si(impl, i)!=0;
}
int num_bits() const {
return mpz_sizeinbase(impl, 2);
}
};
integer abs(const integer& t) {
integer res;
mpz_abs(res.impl, t.impl);
return res;
}
integer root(const integer& t, int n) {
integer res;
mpz_root(res.impl, t.impl, n);
return res;
}
struct gcd_res {
integer gcd;
integer s;
integer t;
};
//a*s + b*t = gcd ; gcd>=0
// abs(s) < abs(b) / (2 gcd)
// abs(t) < abs(a) / (2 gcd)
//(except if |s|<=1 or |t|<=1)
gcd_res gcd(const integer& a, const integer& b) {
gcd_res res;
mpz_gcdext(res.gcd.impl, res.s.impl, res.t.impl, a.impl, b.impl);
return res;
}
integer rand_integer(int num_bits, int seed=-1) {
thread_local gmp_randstate_t state;
thread_local bool is_init=false;
if (!is_init) {
gmp_randinit_mt(state);
gmp_randseed_ui(state, 0);
is_init=true;
}
if (seed!=-1) {
gmp_randseed_ui(state, seed);
}
integer res;
assert(num_bits>=0);
mpz_urandomb(res.impl, state, num_bits);
return res;
}
//a and b are nonnegative
void xgcd_partial(integer& u, integer& v, integer& a, integer& b, const integer& L) {
fmpz_t f_u; fmpz_init(f_u);
fmpz_t f_v; fmpz_init(f_v);
fmpz_t f_a; fmpz_init(f_a);
fmpz_t f_b; fmpz_init(f_b);
fmpz_t f_L; fmpz_init(f_L);
fmpz_set_mpz(f_a, a.impl);
fmpz_set_mpz(f_b, b.impl);
fmpz_set_mpz(f_L, L.impl);
fmpz_xgcd_partial(f_u, f_v, f_a, f_b, f_L);
fmpz_get_mpz(u.impl, f_u);
fmpz_get_mpz(v.impl, f_v);
fmpz_get_mpz(a.impl, f_a);
fmpz_get_mpz(b.impl, f_b);
fmpz_clear(f_u);
fmpz_clear(f_v);
fmpz_clear(f_a);
fmpz_clear(f_b);
fmpz_clear(f_L);
}
USED string to_string(mpz_struct* t) {
integer t_int;
mpz_set(t_int.impl, t);
return t_int.to_string();
}
void inject_error(mpz_struct* i) {
if (!enable_random_error_injection) {
return;
}
mark_vdf_test();
double v=rand_integer(32).to_vector()[0]/double(1ull<<32);
if (v<random_error_injection_rate) {
print( "injected random error" );
int pos=int(rand_integer(31).to_vector()[0]);
pos%=mpz_sizeinbase(i, 2);
mpz_combit(i, pos);
}
}

View File

@ -0,0 +1,202 @@
/**
Copyright 2018 Chia Network Inc
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
**/
#define LOG2(X) (63 - __builtin_clzll((X)))
//using namespace std;
typedef struct qfb
{
fmpz_t a;
fmpz_t b;
fmpz_t c;
} qfb;
typedef qfb qfb_t[1];
// From Antic using Flint (works!)
void qfb_nucomp(qfb_t r, const qfb_t f, const qfb_t g, fmpz_t D, fmpz_t L)
{
fmpz_t a1, a2, c2, ca, cb, cc, k, s, sp, ss, m, t, u2, v1, v2;
if (fmpz_cmp(f->a, g->a) > 0)
{
qfb_nucomp(r, g, f, D, L);
return;
}
fmpz_init(a1); fmpz_init(a2); fmpz_init(c2);
fmpz_init(ca); fmpz_init(cb); fmpz_init(cc);
fmpz_init(k); fmpz_init(m);
fmpz_init(s); fmpz_init(sp); fmpz_init(ss);
fmpz_init(t); fmpz_init(u2); fmpz_init(v1); fmpz_init(v2);
/* nucomp calculation */
fmpz_set(a1, f->a);
fmpz_set(a2, g->a);
fmpz_set(c2, g->c);
fmpz_add(ss, f->b, g->b);
fmpz_fdiv_q_2exp(ss, ss, 1);
fmpz_sub(m, f->b, g->b);
fmpz_fdiv_q_2exp(m, m, 1);
fmpz_fdiv_r(t, a2, a1);
if (fmpz_is_zero(t))
{
fmpz_set_ui(v1, 0);
fmpz_set(sp, a1);
} else
fmpz_gcdinv(sp, v1, t, a1);
fmpz_mul(k, m, v1);
fmpz_fdiv_r(k, k, a1);
if (!fmpz_is_one(sp))
{
fmpz_xgcd(s, v2, u2, ss, sp);
fmpz_mul(k, k, u2);
fmpz_mul(t, v2, c2);
fmpz_sub(k, k, t);
if (!fmpz_is_one(s))
{
fmpz_fdiv_q(a1, a1, s);
fmpz_fdiv_q(a2, a2, s);
fmpz_mul(c2, c2, s);
}
fmpz_fdiv_r(k, k, a1);
}
if (fmpz_cmp(a1, L) < 0)
{
fmpz_mul(t, a2, k);
fmpz_mul(ca, a2, a1);
fmpz_mul_2exp(cb, t, 1);
fmpz_add(cb, cb, g->b);
fmpz_add(cc, g->b, t);
fmpz_mul(cc, cc, k);
fmpz_add(cc, cc, c2);
fmpz_fdiv_q(cc, cc, a1);
} else
{
fmpz_t m1, m2, r1, r2, co1, co2, temp;
fmpz_init(m1); fmpz_init(m2); fmpz_init(r1); fmpz_init(r2);
fmpz_init(co1); fmpz_init(co2); fmpz_init(temp);
fmpz_set(r2, a1);
fmpz_set(r1, k);
fmpz_xgcd_partial(co2, co1, r2, r1, L);
fmpz_mul(t, a2, r1);
fmpz_mul(m1, m, co1);
fmpz_add(m1, m1, t);
fmpz_tdiv_q(m1, m1, a1);
fmpz_mul(m2, ss, r1);
fmpz_mul(temp, c2, co1);
fmpz_sub(m2, m2, temp);
fmpz_tdiv_q(m2, m2, a1);
fmpz_mul(ca, r1, m1);
fmpz_mul(temp, co1, m2);
if (fmpz_sgn(co1) < 0)
fmpz_sub(ca, ca, temp);
else
fmpz_sub(ca, temp, ca);
fmpz_mul(cb, ca, co2);
fmpz_sub(cb, t, cb);
fmpz_mul_2exp(cb, cb, 1);
fmpz_fdiv_q(cb, cb, co1);
fmpz_sub(cb, cb, g->b);
fmpz_mul_2exp(temp, ca, 1);
fmpz_fdiv_r(cb, cb, temp);
fmpz_mul(cc, cb, cb);
fmpz_sub(cc, cc, D);
fmpz_fdiv_q(cc, cc, ca);
fmpz_fdiv_q_2exp(cc, cc, 2);
if (fmpz_sgn(ca) < 0)
{
fmpz_neg(ca, ca);
fmpz_neg(cc, cc);
}
fmpz_clear(m1); fmpz_clear(m2); fmpz_clear(r1); fmpz_clear(r2);
fmpz_clear(co1); fmpz_clear(co2); fmpz_clear(temp);
}
fmpz_set(r->a, ca);
fmpz_set(r->b, cb);
fmpz_set(r->c, cc);
fmpz_clear(ca); fmpz_clear(cb); fmpz_clear(cc);
fmpz_clear(k); fmpz_clear(m);
fmpz_clear(s); fmpz_clear(sp); fmpz_clear(ss);
fmpz_clear(t); fmpz_clear(u2); fmpz_clear(v1); fmpz_clear(v2);
fmpz_clear(a1); fmpz_clear(a2); fmpz_clear(c2);
}
// a = b * c
void nucomp_form(form &a, form &b, form &c, integer &D, integer &L) {
qfb fr, fr2, fr3;
fmpz_init(fr.a);
fmpz_init(fr.b);
fmpz_init(fr.c);
fmpz_init(fr2.a);
fmpz_init(fr2.b);
fmpz_init(fr2.c);
fmpz_init(fr3.a);
fmpz_init(fr3.b);
fmpz_init(fr3.c);
fmpz_set_mpz(fr2.a, b.a.impl);
fmpz_set_mpz(fr2.b, b.b.impl);
fmpz_set_mpz(fr2.c, b.c.impl);
fmpz_set_mpz(fr3.a, c.a.impl);
fmpz_set_mpz(fr3.b, c.b.impl);
fmpz_set_mpz(fr3.c, c.c.impl);
fmpz_t anticD, anticL;
fmpz_init(anticD);
fmpz_init(anticL);
fmpz_set_mpz(anticD, D.impl);
fmpz_set_mpz(anticL, L.impl);
qfb_nucomp(&fr,&fr2,&fr3,anticD,anticL);
fmpz_get_mpz(a.a.impl,fr.a);
fmpz_get_mpz(a.b.impl,fr.b);
fmpz_get_mpz(a.c.impl,fr.c);
fmpz_clear(fr.a);
fmpz_clear(fr.b);
fmpz_clear(fr.c);
fmpz_clear(fr2.a);
fmpz_clear(fr2.b);
fmpz_clear(fr2.c);
fmpz_clear(fr3.a);
fmpz_clear(fr3.b);
fmpz_clear(fr3.c);
fmpz_clear(anticD);
fmpz_clear(anticL);
}

View File

@ -0,0 +1,207 @@
//have to pass one of these in as a macro
//#define VDF_MODE 0 //used for the final submission and correctness testing
//#define VDF_MODE 1 //used for performance or other testing
//also have to pass in one of these
//#define ENABLE_ALL_INSTRUCTIONS 1
//#define ENABLE_ALL_INSTRUCTIONS 0
//
//
//divide table
const int divide_table_index_bits=11;
const int gcd_num_quotient_bits=31; //excludes sign bit
const int data_size=31;
const int gcd_base_max_iter_divide_table=16;
//continued fraction table
const int gcd_table_num_exponent_bits=3;
const int gcd_table_num_fraction_bits=7;
const int gcd_base_max_iter=5;
#if ENABLE_ALL_INSTRUCTIONS==1
const bool use_divide_table=true;
const int gcd_base_bits=63;
const int gcd_128_max_iter=2;
#else
const bool use_divide_table=false;
const int gcd_base_bits=50;
const int gcd_128_max_iter=3;
#endif
/*
divide_table_index bits
10 - 0m1.269s
11 - 0m1.261s
12 - 0m1.262s
13 - 0m1.341s
**/
/*
gcd_base_max_iter_divide_table
13 - 0m1.290s
14 - 0m1.275s
15 - 0m1.265s
16 - 0m1.261s
17 - 0m1.268s
18 - 0m1.278s
19 - 0m1.283s
**/
/*
100k iterations; median of 3 runs. consistency between runs was very high
effect of scheduler:
taskset 0,1 : 0m1.352s (63% speedup single thread, 37% over 0,2)
taskset 0,2 : 0m1.850s
default : 0m1.348s (fastest)
single threaded : 0m2.212s [this has gone down to 0m1.496s for some reason with the divide table]
exponent fraction base_bits base_iter 128_iter seconds
3 7 50 5 3 0m1.350s [fastest with range checks enabled]
3 7 52 5 3 0m1.318s [range checks disabled; 2.4% faster]
[this block with bmi and fma disabled]
3 7 46 5 3 0m1.426s
3 7 47 5 3 0m1.417s
3 7 48 5 3 0m1.421s
3 7 49 5 3 0m1.413s
3 7 50 5 3 0m1.401s [still fastest; bmi+fma is 3.8% faster]
3 7 51 5 3 0m1.406s
3 7 52 5 3 0m1.460s
3 7 50 6 3 0m1.416s
3 7 49 6 3 0m1.376s
2 8 45 6 3 0m1.590s
2 8 49 6 3 0m1.485s
2 8 51 6 3 0m1.479s
2 8 52 6 3 0m1.501s
2 8 53 6 3 0m1.531s
2 8 54 6 3 0m13.675s
2 8 55 6 3 0m13.648s
3 7 49 2 3 0m14.571s
3 7 49 3 3 0m1.597s
3 7 49 4 3 0m1.430s
3 7 49 5 3 0m1.348s
3 7 49 6 3 0m1.376s
3 7 49 10 3 0m1.485s
3 7 49 1 18 0m2.226s
3 7 49 2 10 0m1.756s
3 7 49 3 6 0m1.557s
3 7 49 4 4 0m1.388s
3 7 49 5 4 0m1.525s
3 7 49 6 3 0m1.377s
3 7 49 7 3 0m1.446s
3 7 49 8 2 0m1.503s
3 6 45 4 3 0m15.176s
3 7 45 4 3 0m1.443s
3 8 45 4 3 0m1.386s
3 9 45 4 3 0m1.355s
3 10 45 4 3 0m1.353s
3 11 45 4 3 0m1.419s
3 12 45 4 3 0m1.451s
3 13 45 4 3 0m1.584s
3 7 40 4 2 0m1.611s
3 8 40 4 2 0m1.570s
3 9 40 4 2 0m1.554s
3 10 40 4 2 0m1.594s
3 11 40 4 2 0m1.622s
3 12 40 4 2 0m1.674s
3 13 40 4 2 0m1.832s
3 7 48 5 3 0m1.358s
3 7 49 5 3 0m1.353s
3 7 50 5 3 0m1.350s
3 8 48 5 3 0m1.366s
3 8 49 5 3 0m1.349s
3 8 50 5 3 0m1.334s
3 9 48 5 3 0m1.370s
3 9 49 5 3 0m1.349s
3 9 50 5 3 0m1.346s
3 10 48 5 3 0m1.404s
3 10 49 5 3 0m1.382s
3 10 50 5 3 0m1.379s
***/
const uint64 max_spin_counter=10000000;
//this value makes square_original not be called in 100k iterations. with every iteration reduced, minimum value is 1
const int num_extra_bits_ab=3;
const bool calculate_k_repeated_mod=false;
const bool calculate_k_repeated_mod_interval=1;
const int validate_interval=1; //power of 2. will check the discriminant in the slave thread at this interval. -1 to disable. no effect on performance
const int checkpoint_interval=10000; //at each checkpoint, the slave thread is restarted and the master thread calculates c
//checkpoint_interval=100000: 39388
//checkpoint_interval=10000: 39249 cycles per fast iteration
//checkpoint_interval=1000: 38939
//checkpoint_interval=100: 39988
//no effect on performance (with track cycles enabled)
// ==== test ====
#if VDF_MODE==1
#define VDF_TEST
const bool is_vdf_test=true;
const bool enable_random_error_injection=false;
const double random_error_injection_rate=0; //0 to 1
//#define GENERATE_ASM_TRACKING_DATA
//#define ENABLE_TRACK_CYCLES
const bool vdf_test_correctness=false;
const bool enable_threads=true;
#endif
// ==== production ====
#if VDF_MODE==0
const bool is_vdf_test=false;
const bool enable_random_error_injection=false;
const double random_error_injection_rate=0; //0 to 1
const bool vdf_test_correctness=false;
const bool enable_threads=true;
//#define ENABLE_TRACK_CYCLES
#endif
//
//
//this doesn't do anything outside of test code
//this doesn't work with the divide table currently
#define TEST_ASM
const int gcd_size=20; //multiple of 4. must be at least half the discriminant size in bits divided by 64
const int gcd_max_iterations=gcd_size*2; //typically 1 iteration per limb
const int max_bits_base=1024; //half the discriminant number of bits, rounded up
const int reduce_max_iterations=10000;
const int num_asm_tracking_data=128;
bool enable_all_instructions=ENABLE_ALL_INSTRUCTIONS;
//if the asm code doesn't use fma, the c code shouldn't either to be the same as the asm code
const bool enable_fma_in_c_code=ENABLE_ALL_INSTRUCTIONS;
const int track_cycles_num_buckets=24; //each bucket is from 2^i to 2^(i+1) cycles
const int track_cycles_max_num=128;
void mark_vdf_test() {
static bool did_warning=false;
if (!is_vdf_test && !did_warning) {
print( "test code enabled in production build" );
did_warning=true;
}
}

View File

@ -0,0 +1,377 @@
/*
The MIT License (MIT)
Copyright (C) 2017 okdshin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
#ifndef PICOSHA2_H
#define PICOSHA2_H
// picosha2:20140213
#ifndef PICOSHA2_BUFFER_SIZE_FOR_INPUT_ITERATOR
#define PICOSHA2_BUFFER_SIZE_FOR_INPUT_ITERATOR \
1048576 //=1024*1024: default is 1MB memory
#endif
#include <algorithm>
#include <cassert>
#include <iterator>
#include <sstream>
#include <vector>
#include <fstream>
namespace picosha2 {
typedef unsigned long word_t;
typedef unsigned char byte_t;
static const size_t k_digest_size = 32;
namespace detail {
inline byte_t mask_8bit(byte_t x) { return x & 0xff; }
inline word_t mask_32bit(word_t x) { return x & 0xffffffff; }
const word_t add_constant[64] = {
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1,
0x923f82a4, 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786,
0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147,
0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a,
0x5b9cca4f, 0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2};
const word_t initial_message_digest[8] = {0x6a09e667, 0xbb67ae85, 0x3c6ef372,
0xa54ff53a, 0x510e527f, 0x9b05688c,
0x1f83d9ab, 0x5be0cd19};
inline word_t ch(word_t x, word_t y, word_t z) { return (x & y) ^ ((~x) & z); }
inline word_t maj(word_t x, word_t y, word_t z) {
return (x & y) ^ (x & z) ^ (y & z);
}
inline word_t rotr(word_t x, std::size_t n) {
assert(n < 32);
return mask_32bit((x >> n) | (x << (32 - n)));
}
inline word_t bsig0(word_t x) { return rotr(x, 2) ^ rotr(x, 13) ^ rotr(x, 22); }
inline word_t bsig1(word_t x) { return rotr(x, 6) ^ rotr(x, 11) ^ rotr(x, 25); }
inline word_t shr(word_t x, std::size_t n) {
assert(n < 32);
return x >> n;
}
inline word_t ssig0(word_t x) { return rotr(x, 7) ^ rotr(x, 18) ^ shr(x, 3); }
inline word_t ssig1(word_t x) { return rotr(x, 17) ^ rotr(x, 19) ^ shr(x, 10); }
template <typename RaIter1, typename RaIter2>
void hash256_block(RaIter1 message_digest, RaIter2 first, RaIter2 last) {
assert(first + 64 == last);
static_cast<void>(last); // for avoiding unused-variable warning
word_t w[64];
std::fill(w, w + 64, 0);
for (std::size_t i = 0; i < 16; ++i) {
w[i] = (static_cast<word_t>(mask_8bit(*(first + i * 4))) << 24) |
(static_cast<word_t>(mask_8bit(*(first + i * 4 + 1))) << 16) |
(static_cast<word_t>(mask_8bit(*(first + i * 4 + 2))) << 8) |
(static_cast<word_t>(mask_8bit(*(first + i * 4 + 3))));
}
for (std::size_t i = 16; i < 64; ++i) {
w[i] = mask_32bit(ssig1(w[i - 2]) + w[i - 7] + ssig0(w[i - 15]) +
w[i - 16]);
}
word_t a = *message_digest;
word_t b = *(message_digest + 1);
word_t c = *(message_digest + 2);
word_t d = *(message_digest + 3);
word_t e = *(message_digest + 4);
word_t f = *(message_digest + 5);
word_t g = *(message_digest + 6);
word_t h = *(message_digest + 7);
for (std::size_t i = 0; i < 64; ++i) {
word_t temp1 = h + bsig1(e) + ch(e, f, g) + add_constant[i] + w[i];
word_t temp2 = bsig0(a) + maj(a, b, c);
h = g;
g = f;
f = e;
e = mask_32bit(d + temp1);
d = c;
c = b;
b = a;
a = mask_32bit(temp1 + temp2);
}
*message_digest += a;
*(message_digest + 1) += b;
*(message_digest + 2) += c;
*(message_digest + 3) += d;
*(message_digest + 4) += e;
*(message_digest + 5) += f;
*(message_digest + 6) += g;
*(message_digest + 7) += h;
for (std::size_t i = 0; i < 8; ++i) {
*(message_digest + i) = mask_32bit(*(message_digest + i));
}
}
} // namespace detail
template <typename InIter>
void output_hex(InIter first, InIter last, std::ostream& os) {
os.setf(std::ios::hex, std::ios::basefield);
while (first != last) {
os.width(2);
os.fill('0');
os << static_cast<unsigned int>(*first);
++first;
}
os.setf(std::ios::dec, std::ios::basefield);
}
template <typename InIter>
void bytes_to_hex_string(InIter first, InIter last, std::string& hex_str) {
std::ostringstream oss;
output_hex(first, last, oss);
hex_str.assign(oss.str());
}
template <typename InContainer>
void bytes_to_hex_string(const InContainer& bytes, std::string& hex_str) {
bytes_to_hex_string(bytes.begin(), bytes.end(), hex_str);
}
template <typename InIter>
std::string bytes_to_hex_string(InIter first, InIter last) {
std::string hex_str;
bytes_to_hex_string(first, last, hex_str);
return hex_str;
}
template <typename InContainer>
std::string bytes_to_hex_string(const InContainer& bytes) {
std::string hex_str;
bytes_to_hex_string(bytes, hex_str);
return hex_str;
}
class hash256_one_by_one {
public:
hash256_one_by_one() { init(); }
void init() {
buffer_.clear();
std::fill(data_length_digits_, data_length_digits_ + 4, 0);
std::copy(detail::initial_message_digest,
detail::initial_message_digest + 8, h_);
}
template <typename RaIter>
void process(RaIter first, RaIter last) {
add_to_data_length(static_cast<word_t>(std::distance(first, last)));
std::copy(first, last, std::back_inserter(buffer_));
std::size_t i = 0;
for (; i + 64 <= buffer_.size(); i += 64) {
detail::hash256_block(h_, buffer_.begin() + i,
buffer_.begin() + i + 64);
}
buffer_.erase(buffer_.begin(), buffer_.begin() + i);
}
void finish() {
byte_t temp[64];
std::fill(temp, temp + 64, 0);
std::size_t remains = buffer_.size();
std::copy(buffer_.begin(), buffer_.end(), temp);
temp[remains] = 0x80;
if (remains > 55) {
std::fill(temp + remains + 1, temp + 64, 0);
detail::hash256_block(h_, temp, temp + 64);
std::fill(temp, temp + 64 - 4, 0);
} else {
std::fill(temp + remains + 1, temp + 64 - 4, 0);
}
write_data_bit_length(&(temp[56]));
detail::hash256_block(h_, temp, temp + 64);
}
template <typename OutIter>
void get_hash_bytes(OutIter first, OutIter last) const {
for (const word_t* iter = h_; iter != h_ + 8; ++iter) {
for (std::size_t i = 0; i < 4 && first != last; ++i) {
*(first++) = detail::mask_8bit(
static_cast<byte_t>((*iter >> (24 - 8 * i))));
}
}
}
private:
void add_to_data_length(word_t n) {
word_t carry = 0;
data_length_digits_[0] += n;
for (std::size_t i = 0; i < 4; ++i) {
data_length_digits_[i] += carry;
if (data_length_digits_[i] >= 65536u) {
carry = data_length_digits_[i] >> 16;
data_length_digits_[i] &= 65535u;
} else {
break;
}
}
}
void write_data_bit_length(byte_t* begin) {
word_t data_bit_length_digits[4];
std::copy(data_length_digits_, data_length_digits_ + 4,
data_bit_length_digits);
// convert byte length to bit length (multiply 8 or shift 3 times left)
word_t carry = 0;
for (std::size_t i = 0; i < 4; ++i) {
word_t before_val = data_bit_length_digits[i];
data_bit_length_digits[i] <<= 3;
data_bit_length_digits[i] |= carry;
data_bit_length_digits[i] &= 65535u;
carry = (before_val >> (16 - 3)) & 65535u;
}
// write data_bit_length
for (int i = 3; i >= 0; --i) {
(*begin++) = static_cast<byte_t>(data_bit_length_digits[i] >> 8);
(*begin++) = static_cast<byte_t>(data_bit_length_digits[i]);
}
}
std::vector<byte_t> buffer_;
word_t data_length_digits_[4]; // as 64bit integer (16bit x 4 integer)
word_t h_[8];
};
inline void get_hash_hex_string(const hash256_one_by_one& hasher,
std::string& hex_str) {
byte_t hash[k_digest_size];
hasher.get_hash_bytes(hash, hash + k_digest_size);
return bytes_to_hex_string(hash, hash + k_digest_size, hex_str);
}
inline std::string get_hash_hex_string(const hash256_one_by_one& hasher) {
std::string hex_str;
get_hash_hex_string(hasher, hex_str);
return hex_str;
}
namespace impl {
template <typename RaIter, typename OutIter>
void hash256_impl(RaIter first, RaIter last, OutIter first2, OutIter last2, int,
std::random_access_iterator_tag) {
hash256_one_by_one hasher;
// hasher.init();
hasher.process(first, last);
hasher.finish();
hasher.get_hash_bytes(first2, last2);
}
template <typename InputIter, typename OutIter>
void hash256_impl(InputIter first, InputIter last, OutIter first2,
OutIter last2, int buffer_size, std::input_iterator_tag) {
std::vector<byte_t> buffer(buffer_size);
hash256_one_by_one hasher;
// hasher.init();
while (first != last) {
int size = buffer_size;
for (int i = 0; i != buffer_size; ++i, ++first) {
if (first == last) {
size = i;
break;
}
buffer[i] = *first;
}
hasher.process(buffer.begin(), buffer.begin() + size);
}
hasher.finish();
hasher.get_hash_bytes(first2, last2);
}
}
template <typename InIter, typename OutIter>
void hash256(InIter first, InIter last, OutIter first2, OutIter last2,
int buffer_size = PICOSHA2_BUFFER_SIZE_FOR_INPUT_ITERATOR) {
picosha2::impl::hash256_impl(
first, last, first2, last2, buffer_size,
typename std::iterator_traits<InIter>::iterator_category());
}
template <typename InIter, typename OutContainer>
void hash256(InIter first, InIter last, OutContainer& dst) {
hash256(first, last, dst.begin(), dst.end());
}
template <typename InContainer, typename OutIter>
void hash256(const InContainer& src, OutIter first, OutIter last) {
hash256(src.begin(), src.end(), first, last);
}
template <typename InContainer, typename OutContainer>
void hash256(const InContainer& src, OutContainer& dst) {
hash256(src.begin(), src.end(), dst.begin(), dst.end());
}
template <typename InIter>
void hash256_hex_string(InIter first, InIter last, std::string& hex_str) {
byte_t hashed[k_digest_size];
hash256(first, last, hashed, hashed + k_digest_size);
std::ostringstream oss;
output_hex(hashed, hashed + k_digest_size, oss);
hex_str.assign(oss.str());
}
template <typename InIter>
std::string hash256_hex_string(InIter first, InIter last) {
std::string hex_str;
hash256_hex_string(first, last, hex_str);
return hex_str;
}
inline void hash256_hex_string(const std::string& src, std::string& hex_str) {
hash256_hex_string(src.begin(), src.end(), hex_str);
}
template <typename InContainer>
void hash256_hex_string(const InContainer& src, std::string& hex_str) {
hash256_hex_string(src.begin(), src.end(), hex_str);
}
template <typename InContainer>
std::string hash256_hex_string(const InContainer& src) {
return hash256_hex_string(src.begin(), src.end());
}
template<typename OutIter>void hash256(std::ifstream& f, OutIter first, OutIter last){
hash256(std::istreambuf_iterator<char>(f), std::istreambuf_iterator<char>(), first,last);
}
}// namespace picosha2
#endif // PICOSHA2_H

2
lib/chiavdf/fast_vdf/run.sh Executable file
View File

@ -0,0 +1,2 @@
#!/bin/bash
./vdf $1 $2

24
lib/chiavdf/fast_vdf/sconstruct Executable file
View File

@ -0,0 +1,24 @@
import gch
ccflags=' -O0'
#ccflags=' -O3'
ccflags = '-D VDF_MODE=1 -D ENABLE_ALL_INSTRUCTIONS=0 -no-pie -march=native' + ccflags
env.Append(
CCFLAGS=ccflags,
LINKFLAGS= '-no-pie',
LIBS=['gmpxx', 'gmp', 'flint', 'pthread']
);
gch.generate(env);
env['precompiled_header']=File('include.h');
env['Gch']=env.Gch(target='include.h.gch', source=env['precompiled_header']);
#env.Program('gcd_test.cpp');
#env.Program('vdf.cpp');
env.Program( 'compile_asm', 'compile_asm.cpp', CCFLAGS = ccflags + ' -O0' );
env.Command( 'asm_compiled.s', 'compile_asm', "./compile_asm" );
env.Program( 'vdf', [ 'vdf.cpp', 'asm_compiled.s' ] );

View File

@ -0,0 +1,898 @@
#include <boost/align/aligned_alloc.hpp>
//mp_limb_t is an unsigned integer
static_assert(sizeof(mp_limb_t)==8, "");
static_assert(sizeof(unsigned long int)==8, "");
static_assert(sizeof(long int)==8, "");
#ifdef ENABLE_TRACK_CYCLES
const int track_cycles_array_size=track_cycles_max_num*track_cycles_num_buckets;
thread_local int track_cycles_next_slot=0;
thread_local array<uint64, track_cycles_array_size> track_cycles_cycle_counters;
thread_local array<uint64, track_cycles_array_size> track_cycles_call_counters;
thread_local array<const char*, track_cycles_max_num> track_cycles_names;
void track_cycles_init() {
thread_local bool is_init=false;
if (!is_init) {
//print( &track_cycles_names );
//track_cycles_cycle_counters=new uint64[];
//track_cycles_call_counters=new uint64[track_cycles_max_num*track_cycles_num_buckets];
//track_cycles_names=new const char*[track_cycles_max_num];
for (int x=0;x<track_cycles_array_size;++x) {
track_cycles_cycle_counters.at(x)=0;
track_cycles_call_counters.at(x)=0;
}
for (int x=0;x<track_cycles_max_num;++x) {
track_cycles_names.at(x)=nullptr;
}
is_init=true;
}
}
void track_cycles_output_stats() {
track_cycles_init();
//print( &track_cycles_names );
for (int x=0;x<track_cycles_next_slot;++x) {
double total_calls=0;
for (int y=0;y<track_cycles_num_buckets;++y) {
total_calls+=track_cycles_call_counters.at(x*track_cycles_num_buckets + y);
}
if (total_calls==0) {
continue;
}
print( "" );
print( track_cycles_names.at(x), ":" );
for (int y=0;y<track_cycles_num_buckets;++y) {
double cycles=track_cycles_cycle_counters.at(x*track_cycles_num_buckets + y);
double calls=track_cycles_call_counters.at(x*track_cycles_num_buckets + y);
if (calls==0) {
continue;
}
print(str( "#%: #", int(calls/total_calls*100), int(cycles/calls) ));
}
}
}
struct track_cycles_impl {
int slot=-1;
uint64 start_time=0;
bool is_aborted=false;
static uint64 get_time() {
// Returns the time in EDX:EAX.
uint64 high;
uint64 low;
asm volatile(
"lfence\n\t"
"sfence\n\t"
"rdtsc\n\t"
"sfence\n\t"
"lfence\n\t"
: "=a"(low), "=d"(high) :: "memory");
return (high<<32) | low;
}
track_cycles_impl(int t_slot) {
slot=t_slot;
assert(slot>=0 && slot<track_cycles_max_num);
start_time=get_time();
}
void abort() {
is_aborted=true;
}
~track_cycles_impl() {
uint64 end_time=get_time();
if (is_aborted) {
return;
}
uint64 delta=end_time-start_time;
if (delta==0) {
return;
}
int num_bits=64-__builtin_clzll(delta);
if (num_bits>=track_cycles_num_buckets) {
return;
}
assert(num_bits>=0 && num_bits<track_cycles_num_buckets);
assert(slot>=0 && slot<track_cycles_max_num);
int index=slot*track_cycles_num_buckets + num_bits;
assert(index>=0 && index<track_cycles_max_num*track_cycles_num_buckets);
track_cycles_cycle_counters.at(index)+=delta;
++track_cycles_call_counters.at(index);
}
};
#define TO_STRING_IMPL(x) #x
#define TO_STRING(x) TO_STRING_IMPL(x)
#define TRACK_CYCLES \
track_cycles_init();\
thread_local int track_cycles_c_slot=-1;\
if (track_cycles_c_slot==-1) {\
track_cycles_c_slot=track_cycles_next_slot;\
++track_cycles_next_slot;\
\
track_cycles_names.at(track_cycles_c_slot)=__FILE__ ":" TO_STRING(__LINE__);\
}\
track_cycles_impl c_track_cycles_impl(track_cycles_c_slot);
//
#define TRACK_CYCLES_ABORT c_track_cycles_impl.abort();
#define TRACK_CYCLES_OUTPUT_STATS track_cycles_output_stats();
#else
#define TRACK_CYCLES
#define TRACK_CYCLES_ABORT
#define TRACK_CYCLES_OUTPUT_STATS
#endif
//use realloc or free to free the memory
void* alloc_cache_line(size_t bytes) {
//round up to the next multiple of 64
size_t aligned_bytes=((bytes+63)>>6)<<6;
void* res=boost::alignment::aligned_alloc(64, aligned_bytes); // aligned_alloc(64, aligned_bytes);
assert((uint64(res)&63)==0); //must be aligned for correctness
return res;
}
void* mp_alloc_func(size_t new_bytes) {
void* res=alloc_cache_line(new_bytes);
assert((uint64(res)&63)==0); //all memory used by gmp must be cache line aligned
return res;
}
void mp_free_func(void* old_ptr, size_t old_bytes) {
//either mp_alloc_func allocated old_ptr and it is 64-aligned, or it points to data in mpz and its address equals 16 modulo 64
assert((uint64(old_ptr)&63)==0 || (uint64(old_ptr)&63)==16);
if ((uint64(old_ptr)&63)==0) {
//mp_alloc_func allocated this, so it can be freed with std::free
boost::alignment::aligned_free(old_ptr); //free(old_ptr);
} else {
//this is part of the mpz struct defined below. it can't be freed, so do nothing
}
}
void* mp_realloc_func(void* old_ptr, size_t old_bytes, size_t new_bytes) {
void* res=mp_alloc_func(new_bytes);
memcpy(res, old_ptr, (old_bytes<new_bytes)? old_bytes : new_bytes);
mp_free_func(old_ptr, old_bytes);
return res;
}
//must call this before calling any gmp functions
//(the mpz class constructor does not call any gmp functions)
void init_gmp() {
mp_set_memory_functions(mp_alloc_func, mp_realloc_func, mp_free_func);
}
struct mpz_base {
//16 bytes
//int mpz._mp_alloc: number of limbs allocated
//int mpz._mp_size: abs(_mp_size) is number of limbs in use; 0 if the integer is zero. it is negated if the integer is negative
//mp_limb_t* mpz._mp_d: pointer to limbs
//do not call mpz_swap on this. mpz_swap can be called on other gmp integers
mpz_struct c_mpz;
operator mpz_struct*() { return &c_mpz; }
operator const mpz_struct*() const { return &c_mpz; }
mpz_struct* _() { return &c_mpz; }
const mpz_struct* _() const { return &c_mpz; }
};
//gmp can dynamically reallocate this
//the number of cache lines used is (padded_size+2)/8 rounded up
//1 cache line : 6 limbs
//2 cache lines: 14 limbs
//3 cache lines: 22 limbs
//4 cache lines: 30 limbs
//5 cache lines: 38 limbs
template<int d_expected_size, int d_padded_size> struct alignas(64) mpz : public mpz_base {
static const int expected_size=d_expected_size;
static const int padded_size=d_padded_size;
static_assert(expected_size>=1 && expected_size<=padded_size, "");
uint64 data[padded_size]; //must not be cache line aligned
bool was_reallocated() const {
return c_mpz._mp_d!=data;
}
//can't call any mpz functions here because it is global
mpz() {
c_mpz._mp_size=0;
c_mpz._mp_d=(mp_limb_t *)data;
c_mpz._mp_alloc=padded_size;
//this is supposed to be cache line aligned so that the next assert works
assert((uint64(this)&63)==0);
//mp_free_func uses this to decide whether to free or not
assert((uint64(c_mpz._mp_d)&63)==16);
}
~mpz() {
if (is_vdf_test) {
//don't want this to happen for performance reasons
assert(!was_reallocated());
}
//if c_mpz.data wasn't reallocated, it has to point to this instance's data and not some other instance's data
//if mpz_swap was used, this might be violated
assert((uint64(c_mpz._mp_d)&63)==0 || c_mpz._mp_d==data);
mpz_clear(&c_mpz);
}
mpz(const mpz& t)=delete;
mpz(mpz&& t)=delete;
mpz& operator=(const mpz_struct* t) {
mpz_set(*this, t);
return *this;
}
mpz& operator=(const mpz& t) {
mpz_set(*this, t);
return *this;
}
mpz& operator=(mpz&& t) {
mpz_set(*this, t); //do not use mpz_swap
return *this;
}
/*mpz& operator=(const mpz_base& t) {
mpz_set(*this, t);
return *this;
}
mpz& operator=(mpz_base&& t) {
mpz_set(*this, t); //do not use mpz_swap
return *this;
}*/
mpz& operator=(uint64 i) {
mpz_set_ui(*this, i);
return *this;
}
mpz& operator=(int64 i) {
mpz_set_si(*this, i);
return *this;
}
mpz& operator=(const string& s) {
int res=mpz_set_str(*this, s.c_str(), 0);
assert(res==0);
return *this;
}
USED string to_string() const {
char* res_char=mpz_get_str(nullptr, 16, *this);
string res_string = "0x";
res_string+=res_char;
if (res_string.substr(0, 3) == "0x-") {
res_string.at(0)='-';
res_string.at(1)='0';
res_string.at(2)='x';
}
free(res_char);
return res_string;
}
USED string to_string_dec() const {
char* res_char=mpz_get_str(nullptr, 10, *this);
string res_string=res_char;
free(res_char);
return res_string;
}
//sets *this to a+b
void set_add(const mpz_struct* a, const mpz_struct* b) {
mpz_add(*this, a, b);
}
void set_add(const mpz_struct* a, uint64 b) {
mpz_add_ui(*this, a, b);
}
mpz& operator+=(const mpz_struct* t) {
set_add(*this, t);
return *this;
}
mpz& operator+=(uint64 t) {
set_add(*this, t);
return *this;
}
void set_sub(const mpz_struct* a, const mpz_struct* b) {
mpz_sub(*this, a, b);
}
void set_sub(const mpz_struct* a, uint64 b) {
mpz_sub_ui(*this, a, b);
}
template<class mpz_b> void set_sub(uint64 a, const mpz_b& b) {
mpz_ui_sub(*this, a, b);
}
mpz& operator-=(const mpz_struct* t) {
set_sub(*this, t);
return *this;
}
void set_mul(const mpz_struct* a, const mpz_struct* b) {
mpz_mul(*this, a, b);
}
void set_mul(const mpz_struct* a, int64 b) {
mpz_mul_si(*this, a, b);
}
void set_mul(const mpz_struct* a, uint64 b) {
mpz_mul_ui(*this, a, b);
}
mpz& operator*=(const mpz_struct* t) {
set_mul(*this, t);
return *this;
}
mpz& operator*=(int64 t) {
set_mul(*this, t);
return *this;
}
mpz& operator*=(uint64 t) {
set_mul(*this, t);
return *this;
}
void set_left_shift(const mpz_struct* a, int i) {
assert(i>=0);
mpz_mul_2exp(*this, a, i);
}
mpz& operator<<=(int i) {
set_left_shift(*this, i);
return *this;
}
//*this+=a*b
void set_add_mul(const mpz_struct* a, const mpz_struct* b) {
mpz_addmul(*this, a, b);
}
void set_add_mul(const mpz_struct* a, uint64 b) {
mpz_addmul_ui(*this, a, b);
}
//*this-=a*b
void set_sub_mul(const mpz_struct* a, const mpz_struct* b) {
mpz_submul(*this, a, b);
}
void set_sub_mul(const mpz_struct* a, uint64 b) {
mpz_submul_ui(*this, a, b);
}
void negate() {
mpz_neg(*this, *this);
}
void abs() {
mpz_abs(*this, *this);
}
void set_divide_floor(const mpz_struct* a, const mpz_struct* b) {
if (mpz_sgn(b)==0) {
assert(false);
return;
}
mpz_fdiv_q(*this, a, b);
}
void set_divide_floor(const mpz_struct* a, const mpz_struct* b, mpz_struct* remainder) {
if (mpz_sgn(b)==0) {
assert(false);
return;
}
mpz_fdiv_qr(*this, remainder, a, b);
}
void set_divide_exact(const mpz_struct* a, const mpz_struct* b) {
if (mpz_sgn(b)==0) {
assert(false);
return;
}
mpz_divexact(*this, a, b);
}
void set_mod(const mpz_struct* a, const mpz_struct* b) {
if (mpz_sgn(b)==0) {
assert(false);
return;
}
mpz_mod(*this, a, b);
}
mpz& operator%=(const mpz_struct* t) {
set_mod(*this, t);
return *this;
}
bool divisible_by(const mpz_struct* a) const {
if (mpz_sgn(a)==0) {
assert(false);
return false;
}
return mpz_divisible_p(*this, a);
}
void set_right_shift(const mpz_struct* a, int i) {
assert(i>=0);
mpz_tdiv_q_2exp(*this, *this, i);
}
//note: this uses truncation rounding
mpz& operator>>=(int i) {
set_right_shift(*this, i);
return *this;
}
bool operator<(const mpz_struct* t) const { return mpz_cmp(*this, t)<0; }
bool operator<=(const mpz_struct* t) const { return mpz_cmp(*this, t)<=0; }
bool operator==(const mpz_struct* t) const { return mpz_cmp(*this, t)==0; }
bool operator>=(const mpz_struct* t) const { return mpz_cmp(*this, t)>=0; }
bool operator>(const mpz_struct* t) const { return mpz_cmp(*this, t)>0; }
bool operator!=(const mpz_struct* t) const { return mpz_cmp(*this, t)!=0; }
bool operator<(int64 i) const { return mpz_cmp_si(*this, i)<0; }
bool operator<=(int64 i) const { return mpz_cmp_si(*this, i)<=0; }
bool operator==(int64 i) const { return mpz_cmp_si(*this, i)==0; }
bool operator>=(int64 i) const { return mpz_cmp_si(*this, i)>=0; }
bool operator>(int64 i) const { return mpz_cmp_si(*this, i)>0; }
bool operator!=(int64 i) const { return mpz_cmp_si(*this, i)!=0; }
bool operator<(uint64 i) const { return mpz_cmp_ui(_(), i)<0; }
bool operator<=(uint64 i) const { return mpz_cmp_ui(_(), i)<=0; }
bool operator==(uint64 i) const { return mpz_cmp_ui(_(), i)==0; }
bool operator>=(uint64 i) const { return mpz_cmp_ui(_(), i)>=0; }
bool operator>(uint64 i) const { return mpz_cmp_ui(_(), i)>0; }
bool operator!=(uint64 i) const { return mpz_cmp_ui(_(), i)!=0; }
int compare_abs(const mpz_struct* t) const {
return mpz_cmpabs(*this, t);
}
int compare_abs(uint64 t) const {
return mpz_cmpabs_ui(*this, t);
}
//returns 0 if *this==0
int sgn() const {
return mpz_sgn(_());
}
int num_bits() const {
return mpz_sizeinbase(*this, 2);
}
//0 if this is 0
int num_limbs() const {
return mpz_size(*this);
}
const uint64* read_limbs() const {
return (uint64*)mpz_limbs_read(*this);
}
//limbs are uninitialized. call finish
uint64* write_limbs(int num) {
return (uint64*)mpz_limbs_write(*this, num);
}
//limbs are zero padded to the specified size. call finish
uint64* modify_limbs(int num) {
int old_size=num_limbs();
uint64* res=(uint64*)mpz_limbs_modify(*this, num);
//gmp doesn't do this
for (int x=old_size;x<num;++x) {
res[x]=0;
}
return res;
}
//num is whatever was passed to write_limbs or modify_limbs
//it can be less than that as long as it is at least the number of nonzero limbs
//it can be 0 if the result is 0
void finish(int num, bool negative=false) {
mpz_limbs_finish(*this, (negative)? -num : num);
}
template<int size> array<uint64, size> to_array() const {
assert(size>=num_limbs());
array<uint64, size> res;
for (int x=0;x<size;++x) {
res[x]=0;
}
for (int x=0;x<num_limbs();++x) {
res[x]=read_limbs()[x];
}
return res;
}
};
template<class type> struct cache_line_ptr {
type* ptr=nullptr;
cache_line_ptr() {}
cache_line_ptr(cache_line_ptr& t)=delete;
cache_line_ptr(cache_line_ptr&& t) { swap(ptr, t.ptr); }
cache_line_ptr& operator=(cache_line_ptr& t)=delete;
cache_line_ptr& operator=(cache_line_ptr&& t) { swap(ptr, t.ptr); }
~cache_line_ptr() {
if (ptr) {
ptr->~type();
boost::alignment::aligned_free(ptr); // wjb free(ptr);
ptr=nullptr;
}
}
type& operator*() const { return *ptr; }
type* operator->() const { return ptr; }
};
template<class type, class... arg_types> cache_line_ptr<type> make_cache_line(arg_types&&... args) {
cache_line_ptr<type> res;
res.ptr=(type*)alloc_cache_line(sizeof(type));
new(res.ptr) type(forward<arg_types>(args)...);
return res;
}
template<bool is_write, class type> void prefetch(const type& p) {
//write prefetching lowers performance but read prefetching increases it
if (is_write) return;
for (int x=0;x<sizeof(p);x+=64) {
__builtin_prefetch(((char*)&p)+x, (is_write)? 1 : 0);
}
}
template<class type> void prefetch_write(const type& p) { prefetch<true>(p); }
template<class type> void prefetch_read(const type& p) { prefetch<false>(p); }
void memory_barrier() {
asm volatile( "" ::: "memory" );
}
struct alignas(64) thread_counter {
uint64 counter_value=0; //updated atomically since only one thread can write to it
uint64 error_flag=0;
void reset() {
memory_barrier();
counter_value=0;
error_flag=0;
memory_barrier();
}
thread_counter() {
assert((uint64(this)&63)==0);
}
};
thread_counter master_counter[100];
thread_counter slave_counter[100];
struct thread_state {
int pairindex;
bool is_slave=false;
uint64 counter_start=0;
uint64 last_fence=0;
void reset() {
is_slave=false;
counter_start=0;
last_fence=0;
}
thread_counter& this_counter() {
return (is_slave)? slave_counter[pairindex] : master_counter[pairindex];
}
thread_counter& other_counter() {
return (is_slave)? master_counter[pairindex] : slave_counter[pairindex];
}
void raise_error() {
//if (is_vdf_test) {
//print( "raise_error", is_slave );
//}
memory_barrier();
this_counter().error_flag=1;
other_counter().error_flag=1;
memory_barrier();
}
const uint64 v() {
return this_counter().counter_value;
}
//waits for the other thread to have at least this counter value
//returns false if an error has been raised
bool fence_absolute(uint64 t_v) {
if (last_fence>=t_v) {
return true;
}
memory_barrier();
uint64 spin_counter=0;
while (other_counter().counter_value < t_v) {
if (this_counter().error_flag || other_counter().error_flag) {
raise_error();
break;
}
if (spin_counter>max_spin_counter) {
if (is_vdf_test) {
print( "spin_counter too high", is_slave );
}
raise_error();
break;
}
++spin_counter;
memory_barrier();
}
memory_barrier();
if (!(this_counter().error_flag)) {
last_fence=t_v;
}
return !(this_counter().error_flag);
}
bool fence(int delta) {
return fence_absolute(counter_start+uint64(delta));
}
//increases this thread's counter value. it can only be increased
//returns false if an error has been raised
bool advance_absolute(uint64 t_v) {
if (t_v==v()) {
return true;
}
memory_barrier(); //wait for all writes to finish (on x86 this doesn't do anything but the compiler still needs it)
assert(t_v>=v());
if (this_counter().error_flag) {
raise_error();
}
this_counter().counter_value=t_v;
memory_barrier(); //want the counter writes to be low latency so prevent the compiler from caching it
return !(this_counter().error_flag);
}
bool advance(int delta) {
return advance_absolute(counter_start+uint64(delta));
}
bool has_error() {
return this_counter().error_flag;
}
/*void wait_for_error_to_be_cleared() {
assert(is_slave && enable_threads);
while (this_counter().error_flag) {
memory_barrier();
}
}
void clear_error() {
assert(!is_slave);
memory_barrier();
this_counter().error_flag=0;
other_counter().error_flag=0;
memory_barrier();
}*/
};
thread_local thread_state c_thread_state;
struct alignas(64) gcd_uv_entry {
//these are uninitialized for the first entry
uint64 u_0;
uint64 u_1;
uint64 v_0;
uint64 v_1;
uint64 parity; //1 if odd, 0 if even
uint64 exit_flag; //1 if last, else 0
uint64 unused_0;
uint64 unused_1;
template<class mpz_type> void matrix_multiply(const mpz_type& in_a, const mpz_type& in_b, mpz_type& out_a, mpz_type& out_b) const {
out_a.set_mul((parity==0)? in_a : in_b, (parity==0)? u_0 : v_0);
out_a.set_sub_mul((parity==0)? in_b : in_a, (parity==0)? v_0 : u_0);
out_b.set_mul((parity==0)? in_b : in_a, (parity==0)? v_1 : u_1);
out_b.set_sub_mul((parity==0)? in_a : in_b, (parity==0)? u_1 : v_1);
}
};
static_assert(sizeof(gcd_uv_entry)==64, "");
template<class mpz_type> struct alignas(64) gcd_results_type {
mpz_type as[2];
mpz_type bs[2];
static const int num_counter=gcd_max_iterations+1; //one per outputted entry
array<gcd_uv_entry, gcd_max_iterations+1> uv_entries;
int end_index=0;
mpz_type& get_a_start() {
return as[0];
}
mpz_type& get_b_start() {
return bs[0];
}
mpz_type& get_a_end() {
assert(end_index>=0 && end_index<2);
return as[end_index];
}
mpz_type& get_b_end() {
assert(end_index>=0 && end_index<2);
return bs[end_index];
}
//this will increase the counter value and wait until the result at index is available
//index 0 only has exit_flag initialized
bool get_entry(int counter_start_delta, int index, const gcd_uv_entry** res) const {
*res=nullptr;
if (index>=gcd_max_iterations+1) {
c_thread_state.raise_error();
return false;
}
assert(index>=0);
if (!c_thread_state.fence(counter_start_delta + index+1)) {
return false;
}
*res=&uv_entries[index];
return true;
}
};
//a and b in c_results should be initialized
//returns false if the gcd failed
//this assumes that all inputs are unsigned, a>=b, and a>=threshold
//this will increase the counter value as results are generated
template<class mpz_type> bool gcd_unsigned(
int counter_start_delta, gcd_results_type<mpz_type>& c_results, const array<uint64, gcd_size>& threshold
) {
if (c_thread_state.has_error()) {
return false;
}
int a_limbs=c_results.get_a_start().num_limbs();
int b_limbs=c_results.get_b_start().num_limbs();
if (a_limbs>gcd_size || b_limbs>gcd_size) {
c_thread_state.raise_error();
return false;
}
asm_code::asm_func_gcd_unsigned_data data;
data.a=c_results.as[0].modify_limbs(gcd_size);
data.b=c_results.bs[0].modify_limbs(gcd_size);
data.a_2=c_results.as[1].write_limbs(gcd_size);
data.b_2=c_results.bs[1].write_limbs(gcd_size);
data.threshold=(uint64*)&threshold[0];
data.uv_counter_start=c_thread_state.counter_start+counter_start_delta+1;
data.out_uv_counter_addr=&(c_thread_state.this_counter().counter_value);
data.out_uv_addr=(uint64*)&(c_results.uv_entries[1]);
data.iter=-1;
data.a_end_index=(a_limbs==0)? 0 : a_limbs-1;
if (is_vdf_test) {
assert((uint64(data.out_uv_addr)&63)==0); //should be cache line aligned
}
memory_barrier();
int error_code=asm_code::asm_func_gcd_unsigned(&data);
memory_barrier();
if (error_code!=0) {
c_thread_state.raise_error();
return false;
}
assert(data.iter>=0 && data.iter<=gcd_max_iterations); //total number of iterations performed
bool is_even=((data.iter-1)&1)==0; //parity of last iteration (can be -1)
c_results.end_index=(is_even)? 1 : 0;
c_results.as[0].finish(gcd_size);
c_results.as[1].finish(gcd_size);
c_results.bs[0].finish(gcd_size);
c_results.bs[1].finish(gcd_size);
inject_error(c_results.as[0]);
inject_error(c_results.as[1]);
inject_error(c_results.bs[0]);
inject_error(c_results.bs[1]);
if (!c_thread_state.advance(counter_start_delta+gcd_results_type<mpz_type>::num_counter)) {
return false;
}
return true;
}

2
lib/chiavdf/fast_vdf/upload.sh Executable file
View File

@ -0,0 +1,2 @@
#!/bin/bash
scp *.c *.cpp *.h *.sh sconstruct $VM:projects/chia_vdf/

View File

@ -0,0 +1,860 @@
#include "include.h"
#include "parameters.h"
#include "bit_manipulation.h"
#include "double_utility.h"
#include "integer.h"
#include "asm_main.h"
#include "vdf_original.h"
#include "vdf_new.h"
#include "picosha2.h"
#include "gpu_integer.h"
#include "gpu_integer_divide.h"
#include "gcd_base_continued_fractions.h"
//#include "gcd_base_divide_table.h"
#include "gcd_128.h"
#include "gcd_unsigned.h"
#include "gpu_integer_gcd.h"
#include "asm_types.h"
#include "threading.h"
#include "nucomp.h"
#include "vdf_fast.h"
#include "vdf_test.h"
#include <map>
#include <algorithm>
#include <thread>
#include <future>
#include <chrono>
#include "ClassGroup.h"
#include "Reducer.h"
#include <boost/asio.hpp>
bool warn_on_corruption_in_production=false;
using boost::asio::ip::tcp;
struct akashnil_form {
// y = ax^2 + bxy + y^2
mpz_t a;
mpz_t b;
mpz_t c;
// mpz_t d; // discriminant
};
const int64_t THRESH = 1UL<<31;
const int64_t EXP_THRESH = 31;
std::vector<form> forms;
//always works
void repeated_square_original(vdf_original &vdfo, form& f, const integer& D, const integer& L, uint64 base, uint64 iterations, INUDUPLListener *nuduplListener) {
vdf_original::form f_in,*f_res;
f_in.a[0]=f.a.impl[0];
f_in.b[0]=f.b.impl[0];
f_in.c[0]=f.c.impl[0];
f_res=&f_in;
for (uint64_t i=0; i < iterations; i++) {
f_res = vdfo.square(*f_res);
if(nuduplListener!=NULL)
nuduplListener->OnIteration(NL_FORM,f_res,base+i);
}
mpz_set(f.a.impl, f_res->a);
mpz_set(f.b.impl, f_res->b);
mpz_set(f.c.impl, f_res->c);
}
class WesolowskiCallback :public INUDUPLListener {
public:
uint64_t kl;
//struct form *forms;
form result;
bool deferred;
int64_t switch_iters = -1;
int64_t switch_index;
int64_t iterations = 0; // This must be intialized to zero at start
integer D;
integer L;
ClassGroupContext *t;
Reducer *reducer;
vdf_original* vdfo;
WesolowskiCallback(uint64_t expected_space) {
vdfo = new vdf_original();
t=new ClassGroupContext(4096);
reducer=new Reducer(*t);
}
~WesolowskiCallback() {
delete(vdfo);
delete(reducer);
delete(t);
}
void reduce(form& inf) {
#if 0
// Old reduce from Sundersoft form
inf.reduce();
#else
// Pulmark reduce based on Akashnil reduce
mpz_set(t->a, inf.a.impl);
mpz_set(t->b, inf.b.impl);
mpz_set(t->c, inf.c.impl);
reducer->run();
mpz_set(inf.a.impl, t->a);
mpz_set(inf.b.impl, t->b);
mpz_set(inf.c.impl, t->c);
#endif
}
void IncreaseConstants(int num_iters) {
kl = 100;
switch_iters = num_iters;
switch_index = num_iters / 10;
}
int GetPosition(int power) {
if (switch_iters == -1 || power < switch_iters) {
return power / 10;
} else {
return (switch_index + (power - switch_iters) / 100);
}
}
form *GetForm(int power) {
return &(forms[GetPosition(power)]);
}
void OnIteration(int type, void *data, uint64 iteration)
{
iteration++;
//cout << iteration << " " << maxiterations << endl;
if(iteration%kl==0)
{
form *mulf=GetForm(iteration);
// Initialize since it is raw memory
// mpz_inits(mulf->a.impl,mulf->b.impl,mulf->c.impl,NULL);
switch(type)
{
case NL_SQUARESTATE:
{
//cout << "NL_SQUARESTATE" << endl;
uint64 res;
square_state_type *square_state=(square_state_type *)data;
if(!square_state->assign(mulf->a, mulf->b, mulf->c, res))
cout << "square_state->assign failed" << endl;
break;
}
case NL_FORM:
{
//cout << "NL_FORM" << endl;
vdf_original::form *f=(vdf_original::form *)data;
mpz_set(mulf->a.impl, f->a);
mpz_set(mulf->b.impl, f->b);
mpz_set(mulf->c.impl, f->c);
break;
}
default:
cout << "Unknown case" << endl;
}
reduce(*mulf);
iterations=iteration; // safe to access now
}
}
};
void ApproximateParameters(uint64_t T, uint64_t& L, uint64_t& k, uint64_t& w) {
double log_memory = 23.25349666;
double log_T = log2(T);
L = 1;
if (log_T - log_memory > 0.000001) {
L = ceil(pow(2, log_memory - 20));
}
double intermediate = T * (double)0.6931471 / (2.0 * L);
k = std::max(std::round(log(intermediate) - log(log(intermediate)) + 0.25), 1.0);
//w = floor((double) T / ((double) T/k + L * (1 << (k+1)))) - 2;
w = 2;
}
// thread safe; but it is only called from the main thread
void repeated_square(form f, const integer& D, const integer& L, WesolowskiCallback &weso, bool& stopped) {
#ifdef VDF_TEST
uint64 num_calls_fast=0;
uint64 num_iterations_fast=0;
uint64 num_iterations_slow=0;
#endif
uint64_t num_iterations = 0;
while (!stopped) {
uint64 c_checkpoint_interval=checkpoint_interval;
// if (weso.iterations >= 5000000) {
// std::cout << "Stopping weso at 5000000 iterations!\n";
// return ;
// }
#ifdef VDF_TEST
form f_copy;
form f_copy_3;
bool f_copy_3_valid=false;
if (vdf_test_correctness) {
f_copy=f;
c_checkpoint_interval=1;
f_copy_3=f;
f_copy_3_valid=square_fast_impl(f_copy_3, D, L, num_iterations);
}
#endif
uint64 batch_size=c_checkpoint_interval;
#ifdef ENABLE_TRACK_CYCLES
print( "track cycles enabled; results will be wrong" );
repeated_square_original(*weso.vdfo, f, D, L, 100); //randomize the a and b values
#endif
// This works single threaded
square_state_type square_state;
square_state.pairindex=0;
uint64 actual_iterations=repeated_square_fast(square_state, f, D, L, num_iterations, batch_size, &weso);
#ifdef VDF_TEST
++num_calls_fast;
if (actual_iterations!=~uint64(0)) num_iterations_fast+=actual_iterations;
#endif
#ifdef ENABLE_TRACK_CYCLES
print( "track cycles actual iterations", actual_iterations );
return; //exit the program
#endif
if (actual_iterations==~uint64(0)) {
//corruption; f is unchanged. do the entire batch with the slow algorithm
repeated_square_original(*weso.vdfo, f, D, L, num_iterations, batch_size, &weso);
actual_iterations=batch_size;
#ifdef VDF_TEST
num_iterations_slow+=batch_size;
#endif
if (warn_on_corruption_in_production) {
print( "!!!! corruption detected and corrected !!!!" );
}
}
if (actual_iterations<batch_size) {
//the fast algorithm terminated prematurely for whatever reason. f is still valid
//it might terminate prematurely again (e.g. gcd quotient too large), so will do one iteration of the slow algorithm
//this will also reduce f if the fast algorithm terminated because it was too big
repeated_square_original(*weso.vdfo, f, D, L, num_iterations+actual_iterations, 1, &weso);
#ifdef VDF_TEST
++num_iterations_slow;
if (vdf_test_correctness) {
assert(actual_iterations==0);
print( "fast vdf terminated prematurely", num_iterations );
}
#endif
++actual_iterations;
}
num_iterations+=actual_iterations;
#ifdef VDF_TEST
if (vdf_test_correctness) {
form f_copy_2=f;
weso.reduce(f_copy_2);
repeated_square_original(&weso.vdfo, f_copy, D, L, actual_iterations);
assert(f_copy==f_copy_2);
}
#endif
}
#ifdef VDF_TEST
print( "fast average batch size", double(num_iterations_fast)/double(num_calls_fast) );
print( "fast iterations per slow iteration", double(num_iterations_fast)/double(num_iterations_slow) );
#endif
}
std::vector<unsigned char> ConvertIntegerToBytes(integer x, uint64_t num_bytes) {
std::vector<unsigned char> bytes;
bool negative = false;
if (x < 0) {
x = abs(x);
x = x - integer(1);
negative = true;
}
for (int iter = 0; iter < num_bytes; iter++) {
auto byte = (x % integer(256)).to_vector();
if (negative)
byte[0] ^= 255;
bytes.push_back(byte[0]);
x = x / integer(256);
}
std::reverse(bytes.begin(), bytes.end());
return bytes;
}
integer HashPrime(std::vector<unsigned char> s) {
std::string prime = "prime";
uint32_t j = 0;
while (true) {
std::vector<unsigned char> input(prime.begin(), prime.end());
std::vector<unsigned char> j_to_bytes = ConvertIntegerToBytes(integer(j), 8);
input.insert(input.end(), j_to_bytes.begin(), j_to_bytes.end());
input.insert(input.end(), s.begin(), s.end());
std::vector<unsigned char> hash(picosha2::k_digest_size);
picosha2::hash256(input.begin(), input.end(), hash.begin(), hash.end());
integer prime_integer;
for (int i = 0; i < 16; i++) {
prime_integer *= integer(256);
prime_integer += integer(hash[i]);
}
if (prime_integer.prime()) {
return prime_integer;
}
j++;
}
}
std::vector<unsigned char> SerializeForm(WesolowskiCallback &weso, form &y, int int_size) {
//weso.reduce(y);
y.reduce();
std::vector<unsigned char> res = ConvertIntegerToBytes(y.a, int_size);
std::vector<unsigned char> b_res = ConvertIntegerToBytes(y.b, int_size);
res.insert(res.end(), b_res.begin(), b_res.end());
return res;
}
integer GetB(WesolowskiCallback &weso, integer& D, form &x, form& y) {
int int_size = (D.num_bits() + 16) >> 4;
std::vector<unsigned char> serialization = SerializeForm(weso, x, int_size);
std::vector<unsigned char> serialization_y = SerializeForm(weso, y, int_size);
serialization.insert(serialization.end(), serialization_y.begin(), serialization_y.end());
return HashPrime(serialization);
}
integer FastPow(uint64_t a, uint64_t b, integer& c) {
if (b == 0)
return integer(1);
integer res = FastPow(a, b / 2, c);
res = res * res;
res = res % c;
if (b % 2) {
res = res * integer(a);
res = res % c;
}
return res;
}
form FastPowForm(form &x, const integer& D, uint64_t num_iterations) {
if (num_iterations == 0)
return form::identity(D);
form res = FastPowForm(x, D, num_iterations / 2);
res = res * res;
if (num_iterations % 2)
res = res * x;
return res;
}
uint64_t GetBlock(uint64_t i, uint64_t k, uint64_t T, integer& B) {
integer res(1 << k);
res *= FastPow(2, T - k * (i + 1), B);
res = res / B;
auto res_vector = res.to_vector();
return res_vector[0];
}
std::string BytesToStr(const std::vector<unsigned char> &in)
{
std::vector<unsigned char>::const_iterator from = in.cbegin();
std::vector<unsigned char>::const_iterator to = in.cend();
std::ostringstream oss;
for (; from != to; ++from)
oss << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(*from);
return oss.str();
}
struct Proof {
Proof() {
}
Proof(std::vector<unsigned char> y, std::vector<unsigned char> proof) {
this->y = y;
this->proof = proof;
}
string hex() {
std::vector<unsigned char> bytes(y);
bytes.insert(bytes.end(), proof.begin(), proof.end());
return BytesToStr(bytes);
}
std::vector<unsigned char> y;
std::vector<unsigned char> proof;
};
#define PULMARK 1
form GenerateProof(form &y, form &x_init, integer &D, uint64_t done_iterations, uint64_t num_iterations, uint64_t k, uint64_t l, WesolowskiCallback& weso, bool& stop_signal) {
auto t1 = std::chrono::high_resolution_clock::now();
#if PULMARK
ClassGroupContext *t;
Reducer *reducer;
t=new ClassGroupContext(4096);
reducer=new Reducer(*t);
#endif
integer B = GetB(weso, D, x_init, y);
integer L=root(-D, 4);
uint64_t k1 = k / 2;
uint64_t k0 = k - k1;
form x = form::identity(D);
for (int64_t j = l - 1; j >= 0; j--) {
x=FastPowForm(x, D, (1 << k));
std::vector<form> ys((1 << k));
for (uint64_t i = 0; i < (1 << k); i++)
ys[i] = form::identity(D);
form *tmp;
for (uint64_t i = 0; !stop_signal && i < ceil(1.0 * num_iterations / (k * l)); i++) {
if (num_iterations >= k * (i * l + j + 1)) {
uint64_t b = GetBlock(i*l + j, k, num_iterations, B);
tmp = weso.GetForm(done_iterations + i * k * l);
nucomp_form(ys[b], ys[b], *tmp, D, L);
#if PULMARK
// Pulmark reduce based on Akashnil reduce
mpz_set(t->a, ys[b].a.impl);
mpz_set(t->b, ys[b].b.impl);
mpz_set(t->c, ys[b].c.impl);
reducer->run();
mpz_set(ys[b].a.impl, t->a);
mpz_set(ys[b].b.impl, t->b);
mpz_set(ys[b].c.impl, t->c);
#else
ys[b].reduce();
#endif
}
}
if (stop_signal)
return form();
for (uint64_t b1 = 0; b1 < (1 << k1) && !stop_signal; b1++) {
form z = form::identity(D);
for (uint64_t b0 = 0; b0 < (1 << k0) && !stop_signal; b0++) {
nucomp_form(z, z, ys[b1 * (1 << k0) + b0], D, L);
#if PULMARK
// Pulmark reduce based on Akashnil reduce
mpz_set(t->a, z.a.impl);
mpz_set(t->b, z.b.impl);
mpz_set(t->c, z.c.impl);
reducer->run();
mpz_set(z.a.impl, t->a);
mpz_set(z.b.impl, t->b);
mpz_set(z.c.impl, t->c);
#else
z.reduce();
#endif
}
z = FastPowForm(z, D, b1 * (1 << k0));
x = x * z;
}
for (uint64_t b0 = 0; b0 < (1 << k0) && !stop_signal; b0++) {
form z = form::identity(D);
for (uint64_t b1 = 0; b1 < (1 << k1) && !stop_signal; b1++) {
nucomp_form(z, z, ys[b1 * (1 << k0) + b0], D, L);
#if PULMARK
// Pulmark reduce based on Akashnil reduce
mpz_set(t->a, z.a.impl);
mpz_set(t->b, z.b.impl);
mpz_set(t->c, z.c.impl);
reducer->run();
mpz_set(z.a.impl, t->a);
mpz_set(z.b.impl, t->b);
mpz_set(z.c.impl, t->c);
#else
z.reduce();
#endif
}
z = FastPowForm(z, D, b0);
x = x * z;
}
if (stop_signal)
return form();
}
#if PULMARK
// Pulmark reduce based on Akashnil reduce
mpz_set(t->a, x.a.impl);
mpz_set(t->b, x.b.impl);
mpz_set(t->c, x.c.impl);
reducer->run();
mpz_set(x.a.impl, t->a);
mpz_set(x.b.impl, t->b);
mpz_set(x.c.impl, t->c);
delete(reducer);
delete(t);
#else
x.reduce();
#endif
auto t2 = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1).count();
return x;
}
void GenerateProofThreaded(std::promise<form> && form_promise, form y, form x_init, integer D, uint64_t done_iterations, uint64_t num_iterations, uint64_t
k, uint64_t l, WesolowskiCallback& weso, bool& stop_signal) {
form proof = GenerateProof(y, x_init, D, done_iterations, num_iterations, k, l, weso, stop_signal);
form_promise.set_value(proof);
}
Proof CreateProofOfTimeWesolowski(integer& D, form x, int64_t num_iterations, uint64_t done_iterations, WesolowskiCallback& weso, bool& stop_signal) {
uint64_t l, k, w;
form x_init = x;
integer L=root(-D, 4);
k = 10;
w = 2;
l = (num_iterations >= 10000000) ? 10 : 1;
while (!stop_signal && weso.iterations < done_iterations + num_iterations) {
std::this_thread::sleep_for (std::chrono::seconds(3));
}
if (stop_signal)
return Proof();
vdf_original vdfo_proof;
uint64 checkpoint = (done_iterations + num_iterations) - (done_iterations + num_iterations) % 100;
//mpz_init(y.a.impl);
//mpz_init(y.b.impl);
//mpz_init(y.c.impl);
form y = forms[weso.GetPosition(checkpoint)];
repeated_square_original(vdfo_proof, y, D, L, 0, (done_iterations + num_iterations) % 100, NULL);
auto proof = GenerateProof(y, x_init, D, done_iterations, num_iterations, k, l, weso, stop_signal);
if (stop_signal)
return Proof();
int int_size = (D.num_bits() + 16) >> 4;
std::vector<unsigned char> y_bytes = SerializeForm(weso, y, 129);
std::vector<unsigned char> proof_bytes = SerializeForm(weso, proof, int_size);
Proof final_proof=Proof(y_bytes, proof_bytes);
return final_proof;
}
Proof CreateProofOfTimeNWesolowski(integer& D, form x, int64_t num_iterations,
uint64_t done_iterations, WesolowskiCallback& weso, int depth_limit, int depth, bool& stop_signal) {
uint64_t l, k, w;
int64_t iterations1, iterations2;
integer L=root(-D, 4);
form x_init = x;
k = 10;
w = 2;
l = (num_iterations >= 10000000) ? 10 : 1;
iterations1 = num_iterations * w / (w + 1);
// NOTE(Florin): This is still suboptimal,
// some work can still be lost if weso iterations is in between iterations1 and num_iterations.
if (weso.iterations >= done_iterations + num_iterations) {
iterations1 = (done_iterations + num_iterations) / 3;
}
iterations1 = iterations1 - iterations1 % 100;
iterations2 = num_iterations - iterations1;
while (!stop_signal && weso.iterations < done_iterations + iterations1) {
std::this_thread::sleep_for (std::chrono::seconds(3));
}
if (stop_signal)
return Proof();
form y1 = *weso.GetForm(done_iterations + iterations1);
std::promise<form> form_promise;
auto form_future = form_promise.get_future();
std::thread t(&GenerateProofThreaded, std::move(form_promise), y1, x_init, D, done_iterations, iterations1, k, l, std::ref(weso), std::ref(stop_signal));
Proof proof2;
if (depth < depth_limit - 1) {
proof2 = CreateProofOfTimeNWesolowski(D, y1, iterations2, done_iterations + iterations1, weso, depth_limit, depth + 1, stop_signal);
} else {
proof2 = CreateProofOfTimeWesolowski(D, y1, iterations2, done_iterations + iterations1, weso, stop_signal);
}
t.join();
if (stop_signal)
return Proof();
form proof = form_future.get();
int int_size = (D.num_bits() + 16) >> 4;
Proof final_proof;
final_proof.y = proof2.y;
std::vector<unsigned char> proof_bytes(proof2.proof);
std::vector<unsigned char> tmp = ConvertIntegerToBytes(integer(iterations1), 8);
proof_bytes.insert(proof_bytes.end(), tmp.begin(), tmp.end());
tmp.clear();
tmp = SerializeForm(weso, y1, int_size);
proof_bytes.insert(proof_bytes.end(), tmp.begin(), tmp.end());
tmp.clear();
tmp = SerializeForm(weso, proof, int_size);
proof_bytes.insert(proof_bytes.end(), tmp.begin(), tmp.end());
final_proof.proof = proof_bytes;
return final_proof;
}
std::mutex socket_mutex;
void NWesolowskiMain(integer D, form x, int64_t num_iterations, WesolowskiCallback& weso, bool& stop_signal, tcp::socket& sock) {
Proof result = CreateProofOfTimeNWesolowski(D, x, num_iterations, 0, weso, 2, 0, stop_signal);
if (stop_signal == true) {
std::cout << "Got stop signal before completing the proof!\n";
return ;
}
std::vector<unsigned char> bytes = ConvertIntegerToBytes(integer(num_iterations), 8);
bytes.insert(bytes.end(), result.y.begin(), result.y.end());
bytes.insert(bytes.end(), result.proof.begin(), result.proof.end());
std::string str_result = BytesToStr(bytes);
std::lock_guard<std::mutex> lock(socket_mutex);
std::cout << "VDF server: Generated proof = " << str_result << "\n";
boost::asio::write(sock, boost::asio::buffer(str_result.c_str(), str_result.size()));
}
void PollTimelord(tcp::socket& sock, bool& got_iters) {
// Wait for 15s, if no iters come, poll each 5 seconds the timelord.
int seconds = 0;
while (!got_iters) {
std::this_thread::sleep_for (std::chrono::seconds(1));
seconds++;
if (seconds >= 15 && (seconds - 15) % 5 == 0) {
socket_mutex.lock();
boost::asio::write(sock, boost::asio::buffer("POLL", 4));
socket_mutex.unlock();
}
}
}
const int max_length = 2048;
void session(tcp::socket sock) {
try {
char disc[350];
char disc_size[5];
boost::system::error_code error;
memset(disc,0x00,sizeof(disc)); // For null termination
memset(disc_size,0x00,sizeof(disc_size)); // For null termination
boost::asio::read(sock, boost::asio::buffer(disc_size, 3), error);
int disc_int_size = atoi(disc_size);
boost::asio::read(sock, boost::asio::buffer(disc, disc_int_size), error);
integer D(disc);
std::cout << "Discriminant = " << D.impl << "\n";
// Init VDF the discriminant...
if (error == boost::asio::error::eof)
return ; // Connection closed cleanly by peer.
else if (error)
throw boost::system::system_error(error); // Some other error.
if (getenv( "warn_on_corruption_in_production" )!=nullptr) {
warn_on_corruption_in_production=true;
}
if (is_vdf_test) {
print( "=== Test mode ===" );
}
if (warn_on_corruption_in_production) {
print( "=== Warn on corruption enabled ===" );
}
assert(is_vdf_test); //assertions should be disabled in VDF_MODE==0
init_gmp();
allow_integer_constructor=true; //make sure the old gmp allocator isn't used
set_rounding_mode();
integer L=root(-D, 4);
form f=form::generator(D);
bool stop_signal = false;
std::set<uint64_t> seen_iterations;
std::vector<std::thread> threads;
WesolowskiCallback weso(1000000);
//mpz_init(weso.forms[0].a.impl);
//mpz_init(weso.forms[0].b.impl);
//mpz_init(weso.forms[0].c.impl);
forms[0]=f;
weso.D = D;
weso.L = L;
weso.kl = 10;
bool stopped = false;
bool got_iters = false;
std::thread vdf_worker(repeated_square, f, D, L, std::ref(weso), std::ref(stopped));
std::thread poll_thread(PollTimelord, std::ref(sock), std::ref(got_iters));
// Tell client that I'm ready to get the challenges.
boost::asio::write(sock, boost::asio::buffer("OK", 2));
char data[10];
while (!stopped) {
memset(data, 0, sizeof(data));
boost::asio::read(sock, boost::asio::buffer(data, 1), error);
int size = data[0] - '0';
memset(data, 0, sizeof(data));
boost::asio::read(sock, boost::asio::buffer(data, size), error);
int iters = atoi(data);
std::cout << "Got iterations " << iters << "\n";
got_iters = true;
if (seen_iterations.size() > 0 && *seen_iterations.begin() <= iters) {
std::cout << "Ignoring..." << iters << "\n";
continue;
}
if (seen_iterations.size() > 2 && iters != 0) {
std::cout << "Ignoring..." << iters << "\n";
continue;
}
if (iters == 0) {
stopped = true;
poll_thread.join();
for (int t = 0; t < threads.size(); t++) {
threads[t].join();
}
vdf_worker.join();
} else {
if (seen_iterations.find(iters) == seen_iterations.end()) {
seen_iterations.insert(iters);
threads.push_back(std::thread(NWesolowskiMain, D, f, iters, std::ref(weso), std::ref(stopped),
std::ref(sock)));
}
}
}
} catch (std::exception& e) {
std::cerr << "Exception in thread: " << e.what() << "\n";
}
try {
// Tell client I've stopped everything, wait for ACK and close.
boost::system::error_code error;
std::cout << "Stopped everything! Ready for the next challenge.\n";
std::lock_guard<std::mutex> lock(socket_mutex);
boost::asio::write(sock, boost::asio::buffer("STOP", 4));
char ack[5];
memset(ack,0x00,sizeof(ack));
boost::asio::read(sock, boost::asio::buffer(ack, 3), error);
assert (strncmp(ack, "ACK", 3) == 0);
} catch (std::exception& e) {
std::cerr << "Exception in thread: " << e.what() << "\n";
}
}
void server(boost::asio::io_context& io_context, unsigned short port)
{
tcp::acceptor a(io_context, tcp::endpoint(tcp::v4(), port));
for (;;)
{
std::thread t(session, a.accept());
t.join();
}
}
int main(int argc, char* argv[])
{
forms.reserve(1000000);
for (int i = 0; i < 1000000; i++) {
mpz_inits(forms[i].a.impl, forms[i].b.impl, forms[i].c.impl, NULL);
}
try
{
if (argc != 2)
{
std::cerr << "Usage: blocking_tcp_echo_server <port>\n";
return 1;
}
boost::asio::io_context io_context;
server(io_context, std::atoi(argv[1]));
}
catch (std::exception& e)
{
std::cerr << "Exception: " << e.what() << "\n";
}
return 0;
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,30 @@
/*#include "include.h"
#include "integer.h"
#include "vdf_new.h"
int main(int argc, char** argv) {
parse_args(argc, argv);
integer a;
integer b;
integer c;
generator_for_discriminant(arg_discriminant, a, b, c);
for (int x=0;x<arg_iterations;++x) {
square(a, b, c);
reduce(a, b, c);
}
print( "" );
print(a.to_string());
print( "" );
print(b.to_string());
print( "" );
print(c.to_string());
print( "" );
}**/

View File

@ -0,0 +1,435 @@
void normalize(integer& a, integer& b, integer& c) {
integer r = (a-b)/(a<<1);
//todo print( "normalize r=", r.to_string() );
integer A = a;
integer B = b + ((r*a)<<1);
integer C = a*r*r + b*r + c;
// r=0:
// A=a
// B=b
// C=c
// r=1:
// A=a
// B=b+2a
// C=a+b+c
a=A;
b=B;
c=C;
}
void reduce_impl(integer& a, integer& b, integer& c) {
integer s = (c+b)/(c<<1);
//todo print( "reduce s=", s.to_string() );
integer A = c;
integer B = ((s*c)<<1) - b;
integer C = c*s*s - b*s+a;
a=A;
b=B;
c=C;
}
void reduce(integer& a, integer& b, integer& c) {
/*TRACK_MAX(a); // 2
TRACK_MAX(b); // 3
TRACK_MAX(c); // 4
*/
normalize(a, b, c);
/*TRACK_MAX(a); // 2
TRACK_MAX(b); // 2
TRACK_MAX(c); // 2
*/
int iter=0;
while (a>c || (a==c && b<0)) {
reduce_impl(a, b, c);
++iter;
/*if (iter==1) {
TRACK_MAX(a); // 2
TRACK_MAX(b); // 2
TRACK_MAX(c); // 2
}*/
}
normalize(a, b, c);
}
void generator_for_discriminant(const integer& d, integer& a, integer& b, integer& c) {
a=2;
b=1;
c = (b*b - d)/(a<<2);
reduce(a, b, c);
}
void square(integer& a, integer& b, integer& c) {
gcd_res r=gcd(b, a);
integer u=(c/r.gcd*r.s)%a;
integer A = a*a;
integer B = b - ((a*u)<<1);
integer C = u*u - (b*u-c)/a;
a=A;
b=B;
c=C;
}
//reduced bounds:
// |b| <= a
// |a|< = sqrt(-d/3) <= sqrt(|d|)
// |c| = |(b^2-d)/(4a)| <= |b^2-d| = |b^2+|d|| <= |a^2+|d|| <= |-d/3+|d|| = ||d|/3+|d|| = |2/3*|d|| <= |d|
// a and b have half as many bits as d (rounded up). c can have as many bits as d (but it is usually half)
// |ac| = |(b^2-d)/(4)| <= |b^2-d| <= |d|
// |bc| <= |ac| <= |d|
// b is odd:
// assume b=2n (even)
// d = b^2-4ac = 4n^2 - 4ac = multiple of 4
// d is prime so it is odd (contradiction)
struct form {
integer a;
integer b;
integer c;
static form from_abd(const integer& t_a, const integer& t_b, const integer& d) {
form res;
res.a=t_a;
res.b=t_b;
res.c=(t_b*t_b - d);
assert(t_a>integer(0));
assert(res.c % (t_a<<2) == integer(0));
res.c/=(t_a<<2);
res.reduce();
return res;
}
static form identity(const integer& d) {
return from_abd(integer(1), integer(1), d);
}
static form generator(const integer& d) {
return from_abd(integer(2), integer(1), d);
}
void reduce() {
::reduce(a, b, c);
}
form inverse() const {
form res=*this;
res.b=-res.b;
res.reduce(); //doesn't do anything unless |a|==|b|
return res;
}
bool check_valid(const integer& d) {
return b*b-integer(4)*a*c==d;
}
void assert_valid(const integer& d) {
assert(check_valid(d));
}
bool operator==(const form& f) const {
return a==f.a && b==f.b && c==f.c;
}
bool operator<(const form& f) const {
return make_tuple(a, b, c)<make_tuple(f.a, f.b, f.c);
}
//assumes this is normalized (c has the highest magnitude)
//the inverse has the same hash
int hash() const {
uint64 res=c.to_vector()[0];
return int((res>>4) & ((1ull<<31)-1)); //ignoring some of the lower bits because they might not be random enough
}
};
integer generate_discriminant(int num_bits, int seed=-1) {
integer res=rand_integer(num_bits, seed);
while (true) {
mpz_nextprime(res.impl, res.impl);
if ((res % integer(8)) == integer(7)) {
break;
}
}
return -res;
}
form square(const form& f) {
form res=f;
square(res.a, res.b, res.c);
res.reduce();
return res;
}
//inputs are: unsigned, unsigned, signed
integer three_gcd(integer a, integer b, integer c) {
auto res1=gcd(a, b);
auto res2=gcd(res1.gcd, c);
return res2.gcd;
}
gcd_res test_gcd(integer a_signed, integer b_signed, int index=0) {
bool a_negative=a_signed<integer(0);
bool b_negative=b_signed<integer(0);
integer a=abs(a_signed);
integer b=abs(b_signed);
integer u0;
integer u1;
int parity;
if (a<b) {
swap(a, b);
u0=0;
u1=1;
parity=-1;
} else {
u0=1;
u1=0;
parity=1;
}
int iter=0;
while (b!=integer(0)) {
/*if (iter==0 && index==0) {
TRACK_MAX(a); // 2
TRACK_MAX(b); // 2
TRACK_MAX(u0); // 0.03
TRACK_MAX(u1); // 0.03
}
if (iter==1 && index==0) {
TRACK_MAX(a); // 2
TRACK_MAX(b); // 2
TRACK_MAX(u0); // 0.25
TRACK_MAX(u1); // 0.55
}
if (iter==2 && index==0) {
TRACK_MAX(a); // 2
TRACK_MAX(b); // 2
TRACK_MAX(u0); // 0.55
TRACK_MAX(u1); // 0.60
}
if (iter==0 && index==1) {
TRACK_MAX(a); // 2
TRACK_MAX(b); // 1
TRACK_MAX(u0); // 0.03
TRACK_MAX(u1); // 0.03
}
if (iter==1 && index==1) {
TRACK_MAX(a); // 1
TRACK_MAX(b); // 1
TRACK_MAX(u0); // 0.03
TRACK_MAX(u1); // 0.25
}
if (iter==2 && index==1) {
TRACK_MAX(a); // 1
TRACK_MAX(b); // 1
TRACK_MAX(u0); // 0.25
TRACK_MAX(u1); // 0.28
}*/
integer q=a/b;
integer r=a%b;
a=b;
b=r;
integer u1_new=u0 + q*u1;
u0=u1;
u1=u1_new;
parity=-parity;
++iter;
}
gcd_res res;
res.gcd=a;
res.s=u0;
if (a_negative != (parity==-1)) {
res.s=-res.s;
}
{
auto expected_gcd_res=gcd(a_signed, b_signed);
assert(expected_gcd_res.gcd==res.gcd);
assert(expected_gcd_res.s==res.s);
}
return res;
}
//a and b are N bits and m is M bits; outputs are M bits
//a and b are signed and m is unsigned
//mu and v are unsigned
void solve_linear_congruence(const integer& a, const integer& b, const integer& m, integer& mu, integer& v, int index=0) {
// g = gcd(a, m), and da + em = g
//one round of the euclidean algorithm will equalize the sizes of the inputs; a is signed and m is unsigned
gcd_res gcd_r;
if (false) {
gcd_r=test_gcd(a, m, index);
} else {
gcd_r=gcd(a, m);
}
integer g=gcd_r.gcd; //min(N,M) bits unsigned
integer d=gcd_r.s; //max(N,M) bits signed
// q = b/g, r = b % g
integer q=b/g; //N bits ; signed/unsigned = signed
integer r=b%g;
assert(r==integer(0));
mu=(q*d)%m; //N+M bits mod M bits => M bits ; signed*signed mod unsigned = unsigned
v=m/g; //M bits unsigned
}
//f1.a,f1.b,f2.a,f2.b are N bits and f1.c,f2.c are 2N bits. a/c are unsigned and b is signed
form multiply(const form& f1, const form& f2) {
form f3;
integer g = (f2.b + f1.b) / integer(2); //N bits signed; sum is odd+odd which is even
integer h = (f2.b - f1.b) / integer(2); //N bits signed; sum is odd-odd which is even
integer w = three_gcd(f1.a, f2.a, g); //N bits unsigned
integer j = w; //N bits unsigned
//integer r = 0;
integer s = f1.a / w; //N bits unsigned
integer t = f2.a / w; //N bits unsigned
integer u = g / w; //N bits signed
integer k_temp;
integer constant_factor;
solve_linear_congruence(
t * u, //2N bits signed
h * u + s * f1.c, // f1.a * f1.c is 2N bits; 2N+1 bits; signed+unsigned = signed
s * t, //2N bits unsigned
k_temp, //2N bits (same as m argument); unsigned
constant_factor, //2N bits (same as m argument); unsigned
0
);
integer n;
integer constant_factor_2;
solve_linear_congruence(
t * constant_factor, //3N bits signed
h - t * k_temp, //3N bits signed - unsigned = signed
s, //N bits unsigned
n, //N bits unsigned
constant_factor_2, //N bits unsigned
1
);
integer k = k_temp + constant_factor * n; //4N bits unsigned
integer l = (t*k - h) / s; //5N bits signed / unsigned = signed
integer m = (t*u*k - h*u - s*f1.c) / (s*t); //6N bits divided by 2N bits => 6N bits ; signed / unsigned = signed
f3.a = s*t; //2N bits unsigned
f3.b = (j*u) - (k*t + l*s); //6N bits signed
f3.c = k*l - j*m; //9N bits unsigned (result must be nonnegative)
//experimental values (multiplies of d/2 bits)
//with 100 bits d:
// 50 bits - 2x 32-bit words
//100 bits - 4x 32-bit words
//150 bits - 5x 32-bit words
//200 bits - 7x 32-bit words
/*
TRACK_MAX(g); // 1
TRACK_MAX(h); // 1
TRACK_MAX(w); // 1
TRACK_MAX(s); // 1
TRACK_MAX(t); // 1
TRACK_MAX(u); // 1
TRACK_MAX(t*u); // 2
TRACK_MAX(s * f1.c); // 2
TRACK_MAX(h * u + s * f1.c); // 2
TRACK_MAX(s*t); // 2
TRACK_MAX(k_temp); // 2
TRACK_MAX(constant_factor); // 1
TRACK_MAX(n); // 1
TRACK_MAX(constant_factor_2); // 1
TRACK_MAX(t * constant_factor); // 2
TRACK_MAX(t * k_temp); // 3
TRACK_MAX(h - t * k_temp); // 3
TRACK_MAX(constant_factor * n); // 2
TRACK_MAX(k_temp + constant_factor * n); // 2
TRACK_MAX(t*k); // 3
TRACK_MAX(t*k - h); // 3
TRACK_MAX((t*k - h) / s); // 2
TRACK_MAX(t*u); // 2
TRACK_MAX(u*k); // 3
TRACK_MAX(t*k); // 3
TRACK_MAX(t*u*k); // 4
TRACK_MAX(h*u); // 2
TRACK_MAX(s*f1.c); // 2
TRACK_MAX(t*u*k - h*u - s*f1.c); // 4
TRACK_MAX(s*t); // 2
TRACK_MAX((t*u*k - h*u - s*f1.c) / (s*t)); // 2
TRACK_MAX(s*t); // 2
TRACK_MAX(j*u); // 1
TRACK_MAX(k*t); // 3
TRACK_MAX(l*s); // 3
TRACK_MAX((j*u) - (k*t + l*s)); // 3
TRACK_MAX(k*l); // 4
TRACK_MAX(j*m); // 2
TRACK_MAX(k*l - j*m); // 4
TRACK_MAX(f3.a); // 2
TRACK_MAX(f3.b); // 3
TRACK_MAX(f3.c); // 4
*/
f3.reduce();
return f3;
}
form operator*(const form& a, const form& b) {
if (&a==&b) {
return square(a);
} else {
return multiply(a, b);
}
}
/*integer arg_discriminant;
int arg_iterations;
void parse_args(int argc, char** argv) {
arg_discriminant=integer(
"-0xdc2a335cd2b355c99d3d8d92850122b3d8fe20d0f5360e7aaaecb448960d57bcddfee12a229bbd8d370feda5a17466fc725158ebb78a2a7d37d0a226d89b54434db9c3be9a9bb6ba2c2cd079221d873a17933ceb81a37b0665b9b7e247e8df66bdd45eb15ada12326db01e26c861adf0233666c01dec92bbb547df7369aed3b1fbdff867cfc670511cc270964fbd98e5c55fbe0947ac2b9803acbfd935f3abb8d9be6f938aa4b4cc6203f53c928a979a2f18a1ff501b2587a93e95a428a107545e451f0ac6c7f520a7e99bf77336b1659a2cb3dd1b60e0c6fcfffc05f74cfa763a1d0af7de9994b6e35a9682c4543ae991b3a39839230ef84dae63e88d90f457"
);
arg_iterations=1000;
if (argc==1) {
} else
if (argc==2) {
arg_iterations=from_string<int>(argv[1]);
} else
if (argc==3) {
arg_discriminant=integer(argv[1]);
arg_iterations=from_string<int>(argv[2]);
} else {
assert(false);
}
}**/

View File

@ -0,0 +1,325 @@
/**
Copyright 2018 Chia Network Inc
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
***/
class vdf_original
{
public:
struct form {
// y = ax^2 + bxy + y^2
mpz_t a;
mpz_t b;
mpz_t c;
//mpz_t d; // discriminant
};
mpz_t negative_a, r, denom, old_b, ra, s, x, old_a, g, d, e, q, w, u, a,
b, m, k, mu, v, sigma, lambda, h, t, l, j;
form f3;
void normalize(form& f) {
mpz_neg(negative_a, f.a);
if (mpz_cmp(f.b, negative_a) > 0 && mpz_cmp(f.b, f.a) <= 0) {
// Already normalized
return;
}
// r = (a - b) / 2a
// a = a
// b = b + 2ra
// c = ar^2 + br + c
mpz_sub(r, f.a, f.b);
mpz_mul_ui(denom, f.a, 2);
// r = (a-b) / 2a
mpz_fdiv_q(r, r, denom);
mpz_set(old_b, f.b);
mpz_mul(ra, r, f.a);
mpz_add(f.b, f.b, ra);
mpz_add(f.b, f.b, ra);
// c += ar^2
mpz_mul(ra, ra, r);
mpz_add(f.c, f.c, ra);
// c += rb
mpz_set(ra, r);
mpz_mul(ra, ra, old_b);
mpz_add(f.c, f.c, ra);
}
void reduce(form& f) {
normalize(f);
while ((mpz_cmp(f.a, f.c) > 0) ||
(mpz_cmp(f.a, f.c) == 0 && mpz_cmp_si(f.b, 0) < 0)) {
mpz_add(s, f.c, f.b);
// x = 2c
mpz_mul_ui(x, f.c, 2);
mpz_fdiv_q(s, s, x);
mpz_set(old_a, f.a);
mpz_set(old_b, f.b);
// b = -b
mpz_set(f.a, f.c);
mpz_neg(f.b, f.b);
// x = 2sc
mpz_mul(x, s, f.c);
mpz_mul_ui(x, x, 2);
// b += 2sc
mpz_add(f.b, f.b, x);
// c = cs^2
mpz_mul(f.c, f.c, s);
mpz_mul(f.c, f.c, s);
// x = bs
mpz_mul(x, old_b, s);
// c -= bs
mpz_sub(f.c, f.c, x);
// c += a
mpz_add(f.c, f.c, old_a);
}
normalize(f);
}
form generator_for_discriminant(mpz_t* d) {
form x;
mpz_init_set_ui(x.a, 2);
mpz_init_set_ui(x.b, 1);
mpz_init(x.c);
//mpz_init_set(x.d, *d);
// c = b*b - d
mpz_mul(x.c, x.b, x.b);
mpz_sub(x.c, x.c, *d);
// denom = 4a
mpz_mul_ui(denom, x.a, 4);
mpz_fdiv_q(x.c, x.c, denom);
reduce(x);
return x;
}
// Returns mu and v, solving for x: ax = b mod m
// such that x = u + vn (n are all integers). Assumes that mu and v are initialized.
// Returns 0 on success, -1 on failure
int solve_linear_congruence(mpz_t& mu, mpz_t& v, mpz_t& a, mpz_t& b, mpz_t& m) {
// g = gcd(a, m), and da + em = g
mpz_gcdext(g, d, e, a, m);
// q = b/g, r = b % g
mpz_fdiv_qr(q, r, b, g);
if (mpz_cmp_ui(r, 0) != 0) {
// No solution, return error. Optimize out for speed..
cout << "No solution to congruence" << endl;
return -1;
}
mpz_mul(mu, q, d);
mpz_mod(mu, mu, m);
mpz_fdiv_q(v, m, g);
return 0;
}
// Faster version without check, and without returning v
int solve_linear_congruence(mpz_t& mu, mpz_t& a, mpz_t& b, mpz_t& m) {
mpz_gcdext(g, d, e, a, m);
mpz_fdiv_q(q, b, g);
mpz_mul(mu, q, d);
mpz_mod(mu, mu, m);
return 0;
}
// Takes the gcd of three numbers
void three_gcd(mpz_t& ret, mpz_t& a, mpz_t& b, mpz_t& c) {
mpz_gcd(ret, a, b);
mpz_gcd(ret, ret, c);
}
form* multiply(form &f1, form &f2) {
//assert(mpz_cmp(f1.d, f2.d) == 0);
// g = (b1 + b2) / 2
mpz_add(g, f1.b, f2.b);
mpz_fdiv_q_ui(g, g, 2);
// h = (b2 - b1) / 2
mpz_sub(h, f2.b, f1.b);
mpz_fdiv_q_ui(h, h, 2);
// w = gcd(a1, a2, g)
three_gcd(w, f1.a, f2.a, g);
// j = w
mpz_set(j, w);
// r = 0
mpz_set_ui(r, 0);
// s = a1/w
mpz_fdiv_q(s, f1.a, w);
// t = a2/w
mpz_fdiv_q(t, f2.a, w);
// u = g/w
mpz_fdiv_q(u, g, w);
// solve (tu)k = (hu + sc1) mod st, of the form k = mu + vn
// a = tu
mpz_mul(a, t, u);
// b = hu + sc1
mpz_mul(b, h, u);
mpz_mul(m, s, f1.c);
mpz_add(b, b, m);
// m = st
mpz_mul(m, s, t);
int ret = solve_linear_congruence(mu, v, a, b, m);
assert(ret == 0);
// solve (tv)n = (h - t * mu) mod s, of the form n = lamda + sigma n'
// a = tv
mpz_mul(a, t, v);
// b = h - t * mu
mpz_mul(m, t, mu); // use m as a temp variable
mpz_sub(b, h, m);
// m = s
mpz_set(m, s);
ret = solve_linear_congruence(lambda, sigma, a, b, m);
assert(ret == 0);
// k = mu + v*lamda
mpz_mul(a, v, lambda); // use a as a temp variable
mpz_add(k, mu, a);
// l = (k*t - h) / s
mpz_mul(l, k, t);
mpz_sub(l, l, h);
mpz_fdiv_q(l, l, s);
// m = (tuk - hu - cs) / st
mpz_mul(m, t, u);
mpz_mul(m, m, k);
mpz_mul(a, h, u); // use a as a temp variable
mpz_sub(m, m, a);
mpz_mul(a, f1.c, s); // use a as a temp variable
mpz_sub(m, m, a);
mpz_mul(a, s, t); // use a as a temp variable
mpz_fdiv_q(m, m, a);
// A = st - ru
mpz_mul(f3.a, s, t);
mpz_mul(a, r, u); // use a as a temp variable
mpz_sub(f3.a, f3.a, a);
// B = ju + mr - (kt + ls)
mpz_mul(f3.b, j, u);
mpz_mul(a, m, r); // use a as a temp variable
mpz_add(f3.b, f3.b, a);
mpz_mul(a, k, t); // use a as a temp variable
mpz_sub(f3.b, f3.b, a);
mpz_mul(a, l, s); // use a as a temp variable
mpz_sub(f3.b, f3.b, a);
// C = kl - jm
mpz_mul(f3.c, k, l);
mpz_mul(a, j, m);
mpz_sub(f3.c, f3.c, a);
//mpz_set(f3.d, f1.d);
reduce(f3);
return &f3;
}
/**
* This algorithm is the same as the composition/multiply algorithm,
* but simplified to where both inputs are equal (squaring). It also
* assumes that the discriminant is a negative prime. Algorithm:
*
* 1. solve for mu: b(mu) = c mod a
* 2. A = a^2
* B = B - 2a * mu
* C = mu^2 - (b * mu - c)/a
* 3. reduce f(A, B, C)
**/
form* square(form &f1) {
int ret = solve_linear_congruence(mu, f1.b, f1.c, f1.a);
assert(ret == 0);
mpz_mul(m, f1.b, mu);
mpz_sub(m, m, f1.c);
mpz_fdiv_q(m, m, f1.a);
// New a
mpz_set(old_a, f1.a);
mpz_mul(f3.a, f1.a, f1.a);
// New b
mpz_mul(a, mu, old_a);
mpz_mul_ui(a, a, 2);
mpz_sub(f3.b, f1.b, a);
// New c
mpz_mul(f3.c, mu, mu);
mpz_sub(f3.c, f3.c, m);
//mpz_set(f3.d, f1.d);
reduce(f3);
return &f3;
}
// Performs the VDF squaring iterations
form repeated_square(form *f, uint64_t iterations) {
for (uint64_t i=0; i < iterations; i++) {
f = square(*f);
}
return *f;
}
vdf_original() {
mpz_inits(negative_a, r, denom, old_a, old_b, ra, s, x, g, d, e, q, w, m,
u, a, b, k, mu, v, sigma, lambda, f3.a, f3.b, f3.c, //f3.d,
NULL);
}
~vdf_original() {
mpz_clears(negative_a, r, denom, old_a, old_b, ra, s, x, g, d, e, q, w, m,
u, a, b, k, mu, v, sigma, lambda, f3.a, f3.b, f3.c, NULL); //,);
}
};

View File

@ -0,0 +1,438 @@
#include "include.h"
#include "parameters.h"
#include "bit_manipulation.h"
#include "double_utility.h"
#include "integer.h"
#include "asm_main.h"
#include "vdf_original.h"
#include "vdf_new.h"
#include "gpu_integer.h"
#include "gpu_integer_divide.h"
#include "gcd_base.h"
#include "gpu_integer_gcd.h"
#include "vdf_test.h"
#if VDF_MODE==0
const bool test_correctness=false;
const bool assert_on_rollback=false;
const bool debug_rollback=false;
const int repeated_square_checkpoint_interval=1<<10; //should be a power of 2
#endif
#if VDF_MODE==1
const bool test_correctness=true;
const bool assert_on_rollback=true;
const bool debug_rollback=false;
const int repeated_square_checkpoint_interval=1<<10;
#endif
using namespace std;
//using simd_integer_namespace::track_cycles_test;
//each thread updates a sequence number. the write is atomic on x86
//it also has an array of outputs that is append-only
//it will generate an output, do a mfence, then increment the sequence number non-atomically (since it is the only writer)
//it can also wait for another thread's outputs by spinning on its sequence number (with a timeout)
//error states:
//-any thread can change its sequence number to "error" which is the highest uint64 value
//-it will do this if any operation fails or if it spins too long waiting for another thread's output
//-also, the spin loop will error out if the other thread's sequence number is "error". this will make the spinning thread's sequence
// number also be "error"
//-once a thread has become "error", it will exit the code. the slave threads will wait on the barrier and the main thread will just
// exit the squaring function with a "false" output
//-the error state is the global sequence number with the msb set. this allows the sequence number to not be reset across calls
//will just make every state have a 48 bit global sequence number (enough for 22 years) plus a 16 bit local sequence number
//-the last local sequence number is the error state
//-there is no finish state since each state update will change the sequence number to a new, unique sequence number
//to start the squaring, the main thread will output A and B then increase its sequence number to the next global sequence number
//-slave threads will wait on this when they are done squaring or have outputted the error state
//-is is assumed that the main thread synchronizes with each slave thread to consume its output
//can probably write the synchronization code in c++ then because of how simple it is
//this is trivial to implement and should be reliable
//if the gcd generates too many matricies (more than 32 or so), it should generate an error
//need to write each output to a separate cache line
//will use the slave core for: cofactors for both gcds, calculate C at the start of the squaring, calculate (-v2*c)%a as the v2
// cofactor is being generated. this will use <0,-c> as the initial state instead of <0,1> and will also reduce everything modulo a
// after each matrix multiplication. it will also calculate C first.
//the slave core will then calculate the partial gcd and the master core will calculate the cofactors
//once the master core has calculated all of the cofactors, it will also know the final values of a_copy and k_copy from the
// slave core. the slave core is done now
//the master core will calculate the new values of A and B on its own. this can't be parallelized
//-have an asm gcd. nothing else is asm. will use gmp for everything else
//-the asm gcd takes unsigned inputs where a>=b. it returns unsigned outputs. its inputs are zero-padded to a fixed size
//-it modifes its inputs and returns a sequence of cofactor matricies
//-gmp has some utility functions to make this work easily. gmp can also calculate the new size. the resulting sign is always +
//--the sequence is outputted to a fixed size array of cache lines. there is also an output counter which should initially be 0
//-- and can be any pointer. the msb of the output counter is used to indicate the last output
//-the slave core is still used
//-gmp is close to optimal for the pentium machine so will just use it. for the fast machine, can use avx-512 if i have time. the gmp
//- division is still used but only to find the approximate inverse. the result quotient should be >= the actual quotient for exact
//- division to still work
//generic_stats track_cycles_total;
void square_original(form& f) {
vdf_original::form f_in;
f_in.a[0]=f.a.impl[0];
f_in.b[0]=f.b.impl[0];
f_in.c[0]=f.c.impl[0];
vdf_original::form& f_res=*vdf_original::square(f_in);
mpz_set(f.a.impl, f_res.a);
mpz_set(f.b.impl, f_res.b);
mpz_set(f.c.impl, f_res.c);
}
bool square_fast(form& f, const integer& d, const integer& L, int current_iteration) {
form f_copy;
if (test_correctness) {
f_copy=f;
}
bool success=false;
const int max_bits_ab=max_bits_base + num_extra_bits_ab;
const int max_bits_c=max_bits_base + num_extra_bits_c;
//sometimes the nudupl code won't reduce the output all the way. if it has too many bits it will get reduced by calling
// square_original
if (f.a.num_bits()<max_bits_ab && f.b.num_bits()<max_bits_ab && f.c.num_bits()<max_bits_c) {
if (square_fast_impl(f, d, L, current_iteration)) {
success=true;
}
}
if (!success) {
//this also reduces it
print( "===square original===" );
square_original(f);
}
if (test_correctness) {
square_original(f_copy);
form f_copy_2=f;
f_copy_2.reduce();
assert(f_copy_2==f_copy);
}
return true;
}
void output_error(form start, int location) {
print( "=== error ===" );
print(start.a.to_string());
print(start.b.to_string());
print(start.c.to_string());
print(location);
assert(false);
}
struct repeated_square {
integer d;
integer L;
int64 checkpoint_iteration=0;
form checkpoint;
int64 current_iteration=0;
form current;
int64 num_iterations=0;
bool error_mode=false;
bool is_checkpoint() {
return
current_iteration==num_iterations ||
(current_iteration & (repeated_square_checkpoint_interval-1)) == 0
;
}
void advance_fast(bool& did_rollback) {
bool is_error=false;
if (!square_fast(current, d, L, int(current_iteration))) {
is_error=true;
}
if (!is_error) {
++current_iteration;
if (is_checkpoint() && !current.check_valid(d)) {
is_error=true;
}
}
if (is_error) {
if (debug_rollback) {
print( "Rollback", current_iteration, " -> ", checkpoint_iteration );
}
current_iteration=checkpoint_iteration;
current=checkpoint;
error_mode=true;
did_rollback=true;
assert(!assert_on_rollback);
}
}
void advance_error() {
square_original(current);
++current_iteration;
}
void advance() {
bool did_rollback=false;
if (error_mode) {
advance_error();
} else {
advance_fast(did_rollback);
}
if (!did_rollback && is_checkpoint()) {
checkpoint_iteration=current_iteration;
checkpoint=current;
error_mode=false;
}
}
repeated_square(integer t_d, form initial, int64 t_num_iterations) {
d=t_d;
L=root(-d, 4);
//L=integer(1)<<512;
checkpoint=initial;
current=initial;
num_iterations=t_num_iterations;
while (current_iteration<num_iterations) {
//todo if (current_iteration%10000==0) print(current_iteration);
advance();
}
//required if reduce isn't done after each iteration
current.reduce();
}
};
int main(int argc, char* argv[]) {
#if VDF_MODE!=0
print( "=== Test mode ===" );
#endif
//integer ab_start_0(
//"0x53098cff6d1cf3723235e44e397d7a7a77d254551ef35649381d0f2d192ab247d042d4d03005d188f0103aae267cc49515ae3d63b7513fb8d02da102ce2ff39c59a1e3ee9d4bbdb6011589d58f8e26a7c63fd342459fabefaa83ee65adbaf94d372ff6bbce71acdafb75aade3f39f5c7896490ff8b42b23ff337d414948adafb"
//);
//integer ab_start_1(
//"0x1e38edea0e0b65dcd83702504bfa6ceb51df1774093a759280932d6f0097fb04f28dd6da814c2eb045621d9666271be86cf2dfbd1d630a3e4ccec0d2aeb5876100e4ca48783a601d65fc628e80b737f130f4f0c83d79a93738402fcd605b3c6f189cd0a99ff08fad6cd2d425d13284d1d121320261e7740aaab0b7a14718eeb7"
//);
//integer threshold(
//"0xf68745a14f96317c568c660f2e4bcc3dbfd677e12911931303fb7afc4c5a6f637476e331f687ffdba09b7d51aa74f1caf416bcfa9532a1b911076302ac8f4ab8"
//);
//array<fixed_integer<uint64, 17>, 2> ab={fixed_integer<uint64, 17>(ab_start_0), fixed_integer<uint64, 17>(ab_start_1)};
//array<fixed_integer<uint64, 17>, 2> uv;
//int parity;
//gcd_unsigned(
//ab,
//uv,
//parity,
//fixed_integer<uint64, 17>(threshold)
//);
//todo assert(false);
//todo //set up thread affinity. make sure they are not hyperthreads on the same core if possible
set_rounding_mode();
vdf_original::init();
integer d(argv[1]);
int64 num_iterations=from_string<int64>(argv[2]);
form d_initial=form::generator(d);
//integer d(
//"-0xaf0806241ecbc630fbbfd0c9d61c257c40a185e8cab313041cf029d6f070d58ecbc6c906df53ecf0dd4497b0753ccdbce2ebd9c80ae0032acce89096af642dd8c008403dd989ee5c1262545004fdcd7acf47908b983bc5fed17889030f0138e10787a8493e95ca86649ae8208e4a70c05772e25f9ac901a399529de12910a7a2c"
//"3376292be9dba600fd89910aeccc14432b6e45c0456f41c177bb736915cad3332a74e25b3993f3e44728dc2bd13180132c5fb88f0490aeb96b2afca655c13dd9ab8874035e26dab16b6aad2d584a2d35ae0eaf00df4e94ab39fe8a3d5837dcab204c46d7a7b97b0c702d8be98c50e1bf8b649b5b6194fc3bae6180d2dd24d9f"
//);
//int64 num_iterations=1000;
//form d_initial=form::from_abd(
//integer(
//"0x6a8f34028dad0dec9e765a5d761b9b041733e86d849b507ba346052f7b768a18d0283597b581e4b9e705dccc3d5197c66186940d5bdbee00784f51dc0f193cedf619e149a7b0fd48b8c4eb6d4bf925a9d634e138254f22007337415cea377655a0c2832592db32ce9b61d4937dcffd13c33bdf1ac5164a974cd9d61b14c81820"
//),
//integer(
//"0x71c24869eed37be508e1751c21f49fcf16a68b42dec10cedf7376a036280f48a2c4b123d5f918ed4affa612a8dbacb4e6b5cdcaad439f3a5f0ab5a35ab6901025307c2ceaf54ab3bae5daae870817527dceb5fef9f7d6766a84bf843d9de74966fbd2bbad0200323876b90a3f4d9d135876a09f51225f126dd180412c658f4f"
//),
//d
//);
repeated_square c_square(d, d_initial, num_iterations);
cout << c_square.current.a.impl << "\n";
cout << c_square.current.b.impl;
//track_max.output(512);
//if (enable_track_cycles) {
//print( "" );
//print( "" );
//for (int x=0;x<track_cycles_test.size();++x) {
//if (track_cycles_test[x].entries.empty()) {
//continue;
//}
//track_cycles_test[x].output(str( "track_cycles_test_#", x ));
//}
//}
#ifdef GENERATE_ASM_TRACKING_DATA
{
using namespace asm_code;
print( "" );
map<string, double> tracking_data;
for (int x=0;x<num_asm_tracking_data;++x) {
if (!asm_tracking_data_comments[x]) {
continue;
}
tracking_data[asm_tracking_data_comments[x]]=asm_tracking_data[x];
}
for (auto c : tracking_data) {
string base_name;
for (int x=0;x<c.first.size();++x) {
if (c.first[x] == ' ') {
break;
}
base_name+=c.first[x];
}
auto base_i=tracking_data.find(base_name);
double base=1;
if (base_i!=tracking_data.end() && base_i->second!=0) {
base=base_i->second;
}
print(c.first, c.second/base, " ", base);
}
}
#endif
}
/*void square_fast_impl(square_state& _) {
const int max_bits_ab=max_bits_base + num_extra_bits_ab;
//all divisions are exact
//sometimes the nudupl code won't reduce the output all the way. if it has too many bits it will get reduced by calling
// square_original
bool too_many_bits;
too_many_bits=(_.a.num_bits()>max_bits_ab || _.b.num_bits()>max_bits_ab);
if (too_many_bits) {
return false;
}
//if a<=L then this will return false; usually a has twice as many limbs as L
bool a_too_small;
a_too_small=(_.a.num_limbs()<=_.L.num_limbs()+1);
if (a_too_small) {
return false;
}
//only b can be negative
//neither a or b can be 0; d=b^2-4ac is prime. if b=0, then d=-4ac=composite. if a=0, then d=b^2; d>=0
//no constraints on which is greater
//the gcd result is 1 because d=b^2-4ac ; assume gcd(a,b)!=1 ; a=A*s ; b=B*s ; s=gcd(a,b)!=1 ; d = (Bs)^2-4Asc
// d = B^2*s^2 - 4sac = s(B^2*s - 4ac) ; d is not prime. d is supposed to be prime so this can't happen
//the quadratic form might not be reduced all the way so it's possible for |b|>a. need to swap the inputs then
// (they are copied anyway)
//
// U0*b + V0*a = 1
// U1*b + V1*a = 0
//
// U0*b === 1 mod a
// U1*b === 0 mod a
U0=gcd(b, a, 0).u0;
c=(b*b-D)/(a<<2);
//start with <0,c> or <c,0> which is padded to 18 limbs so that the multiplications by 64 bits are exact (same with sums)
//once the new values of uv are calculated, need to reduce modulo a, which is 17 limbs and has been normalized already
//-the normalization also left shifted c
//reducing modulo a only looks at the first couple of limbs so it has the same efficiency as doing it at the end
//the modulo result is always nonnegative
//
// k+q*a=-U0*c
k=(-U0*c)%a;
// a>L so at least one input is >L initially
//when this terminates, one input is >L and one is <=L
//k is reduced modulo a, so |k|<|a|
//a is positive
//the result of mpz_mod is always nonnegative so k is nonnegative
//
// u0*a + v0*k = s ; s>L
// u1*a + v1*k = t ; t<=L
// v0*k === s mod a
// v1*k === t mod a
auto gcd2=gcd(a, k, L);
v0=gcd2.v0
v1=gcd2.v1
s=gcd2.a
t=gcd2.b
// b*t + c*v1 === b*v1*k + c*v1 === v1(b*k+c) === v1(-U0*c*b+c) === c*v1*(1-U0*b) === c*v1*(1-1) === 0 mod a
// b*t + c*v1 = b*(u0*a + v1*k) + c*v1 = b*u0*a + v1(b*k + c) = b*u0*a + v1(c - b*(U0*c+q*a))
// = b*u0*a + v1(c - b*U0*c - b*q*a) = b*u0*a + v1(c - (1-V0*a)*c - b*q*a) = b*u0*a + v1(V0*a*c - b*q*a)
// = a*(b*u0 + v1(V0*c - b*q))
// ((b*t+c*v1)/a) = b*u0 + v1(V0*c - b*q) ; this is slower
//
// S = -1 if v1<=0, else 1
// h = S*(b*t+c*v1)/a
// j = t*t*S
//
// A=t*t+v1*((b*t+c*v1)/a)
// A = j + v1*h
A=t*t+v1*((b*t+c*v1)/a);
if (v1<=0) {
A=-A;
}
// e = 2t*(a + S*t*v0)/v1
// e' = b - e
// f = e' - 2*v0*h
//
// (2*a*t + 2*A*v0)/v1
// = (2*a*t + 2*j*v0 + 2*v1*v0*h)/v1
// = (2*a*t + 2*j*v0)/v1 + 2*v0*h
// = (2*a*t + 2*S*t*t*v0)/v1 + 2*v0*h
// = 2t*(a + S*t*v0)/v1 + 2*v0*h
// = e + 2*v0*h
//
// B = ( b - ((a*t+A*v0)*2)/v1 )%(A*2)
// = ( b - e - 2*v0*h )%(A*2)
// = ( e' - 2*v0*h )%(A*2)
// = f % (2A)
B=( b - ((a*t+A*v0)*2)/v1 )%(A*2);
A=abs(A)
return true;
} */

View File

@ -0,0 +1,316 @@
bool square_fast_impl(form& f, const integer& D, const integer& L, int current_iteration) {
const int max_bits_ab=max_bits_base + num_extra_bits_ab;
const int max_bits_c=max_bits_base + num_extra_bits_ab*2;
//sometimes the nudupl code won't reduce the output all the way. if it has too many bits it will get reduced by calling
// square_original
if (!(f.a.num_bits()<max_bits_ab && f.b.num_bits()<max_bits_ab && f.c.num_bits()<max_bits_c)) {
return false;
}
print("f");
integer a=f.a;
integer b=f.b;
integer c=f.c;
fixed_integer<uint64, 17> a_int(a);
fixed_integer<uint64, 17> b_int(b);
fixed_integer<uint64, 17> c_int(c);
fixed_integer<uint64, 17> L_int(L); //actual size is 8 limbs; padded to 17
fixed_integer<uint64, 33> D_int(D); //padded by an extra limb
//2048 bit D, basis is 512; one limb is 0.125; one bit is 0.002
//TRACK_MAX(a); // a, 2.00585 <= bits (multiple of basis), 0 <= is negative
//TRACK_MAX(b); // b, 2.00585, 0
//TRACK_MAX(c); // c, 2.03125, 0
//can just look at the top couple limbs of a for this
assert((a<=L)==(a_int<=L_int));
if (a_int<=L_int) {
return false;
}
integer v2;
fixed_integer<uint64, 17> v2_int;
{
gcd_res g=gcd(b, a);
assert(g.gcd==1);
v2=g.s;
//only b can be negative
//neither a or b can be 0; d=b^2-4ac is prime. if b=0, then d=-4ac=composite. if a=0, then d=b^2; d>=0
//no constraints on which is greater
v2_int=gcd(b_int, a_int, fixed_integer<uint64, 17>(), true).s;
assert(integer(v2_int)==v2);
}
//TRACK_MAX(v2); // v2, 2.00195, 1
//todo
//start with <0,c> or <c,0> which is padded to 18 limbs so that the multiplications by 64 bits are exact (same with sums)
//once the new values of uv are calculated, need to reduce modulo a, which is 17 limbs and has been normalized already
//-the normalization also left shifted c
//reducing modulo a only looks at the first couple of limbs so it has the same efficiency as doing it at the end
//-it does require computing the inverse of a a bunch of times which is slow. this will probably slow it down by 2x-4x
//--can avoid this by only reducing every couple of iterations
integer k=(-v2*c)%a;
fixed_integer<uint64, 17> k_int=fixed_integer<uint64, 33>(-v2_int*c_int)%a_int;
assert(integer(k_int)==k);
//print( "v2", v2.to_string() );
//print( "k", k.to_string() );
//TRACK_MAX(v2*c); // v2*c, 4.0039, 1
//TRACK_MAX(k); // k, 2.0039, 0
integer a_copy=a;
integer k_copy=k;
integer co2;
integer co1;
xgcd_partial(co2, co1, a_copy, k_copy, L); //neither input is negative
const bool same_cofactors=false; //gcd and xgcd_parital can return slightly different results
fixed_integer<uint64, 9> co2_int;
fixed_integer<uint64, 9> co1_int;
fixed_integer<uint64, 9> a_copy_int;
fixed_integer<uint64, 9> k_copy_int;
{
// a>L so at least one input is >L initially
//when this terminates, one input is >L and one is <=L
auto g=gcd(a_int, k_int, L_int, false);
co2_int=-g.t;
co1_int=-g.t_2;
a_copy_int=g.gcd;
k_copy_int=g.gcd_2;
if (same_cofactors) {
assert(integer(co2_int)==co2);
assert(integer(co1_int)==co1);
assert(integer(a_copy_int)==a_copy);
assert(integer(k_copy_int)==k_copy);
}
}
//print( "co2", co2_int.to_integer().to_string() );
//print( "co1", co1_int.to_integer().to_string() );
//print( "a_copy", a_copy_int.to_integer().to_string() );
//print( "k_copy", k_copy_int.to_integer().to_string() );
//todo
//can speed the following operations up with simd (including calculating C but it is done on the slave core)
//division by a can be replaced by multiplication by a inverse. this takes the top N bits of the numerator and denominator inverse
// where N is the number of bits in the result
//if this is done correctly, the calculated result withh be >= the actual result, and it will be == almost all of the time
//to detect if it is >, can calculate the remainder and see if it is too high. this can be done by the slave core during the
// next iteration
//most of the stuff is in registers for avx-512
//the slave core will precalculate a inverse. it is already dividing by a to calculate c
//this would get rid of the 8x8 batched multiply but not the single limb multiply, since that is still needed for gcd
//for the cofactors which are calculated on the slave core, can use a tree matrix multiplication with the avx-512 code
//for the pentium processor, the adox instruction is banned so the single limb multiply needs to be changed
//the slave core can calculate the inverse of co1 while the master core is calculating A
//for the modulo, the quotient has about 15 bits. can probably calculate the inverse on the master core then since the division
// base case already calculates it with enough precision
//this should work for scalar code also
//TRACK_MAX(co2); // co2, 1.00195, 1
//TRACK_MAX(co1); // co1, 1.0039, 1
//TRACK_MAX(a_copy); // a_copy, 1.03906, 0
//TRACK_MAX(k_copy); // k_copy, 1, 0
//TRACK_MAX(k_copy*k_copy); // k_copy*k_copy, 2, 0
//TRACK_MAX(b*k_copy); // b*k_copy, 3.0039, 0
//TRACK_MAX(c*co1); // c*co1, 3.0039, 1
//TRACK_MAX(b*k_copy-c*co1); // b*k_copy-c*co1, 3.00585, 1
//TRACK_MAX((b*k_copy-c*co1)/a); // (b*k_copy-c*co1)/a, 1.02539, 1
//TRACK_MAX(co1*((b*k_copy-c*co1)/a)); // co1*((b*k_copy-c*co1)/a), 2.00585, 1
integer A=k_copy*k_copy-co1*((b*k_copy-c*co1)/a); // [exact]
//TRACK_MAX(A); // A, 2.00585, 0
fixed_integer<uint64, 17> A_int;
{
fixed_integer<uint64, 17> k_copy_k_copy(k_copy_int*k_copy_int);
fixed_integer<uint64, 25> b_k_copy(b_int*k_copy_int);
fixed_integer<uint64, 25> c_co1(c_int*co1_int);
fixed_integer<uint64, 25> b_k_copy_c_co1(b_k_copy-c_co1);
fixed_integer<uint64, 9> t1(b_k_copy_c_co1/a_int);
fixed_integer<uint64, 17> t2(co1_int*t1);
A_int=k_copy_k_copy-t2;
if (same_cofactors) {
assert(integer(A_int)==A);
}
}
if (co1>=0) {
A=-A;
}
if (!co1_int.is_negative()) {
A_int=-A_int;
}
if (same_cofactors) {
assert(integer(A_int)==A);
}
//TRACK_MAX(A); // A, 2.00585, 1
//TRACK_MAX(a*k_copy); // a*k_copy, 3.0039, 0
//TRACK_MAX(A*co2); // A*co2, 3.0039, 0
//TRACK_MAX((a*k_copy-A*co2)*integer(2)); // (a*k_copy-A*co2)*integer(2), 3.00585, 1
//TRACK_MAX(((a*k_copy-A*co2)*integer(2))/co1); // ((a*k_copy-A*co2)*integer(2))/co1, 2.03515, 1
//TRACK_MAX(((a*k_copy-A*co2)*integer(2))/co1 - b); // ((a*k_copy-A*co2)*integer(2))/co1 - b, 2.03515, 1
integer B=( ((a*k_copy-A*co2)*integer(2))/co1 - b )%(A*integer(2)); //[exact]
//TRACK_MAX(B); // B, 2.00585, 0
fixed_integer<uint64, 17> B_int;
{
fixed_integer<uint64, 25> a_k_copy(a_int*k_copy_int);
fixed_integer<uint64, 25> A_co2(A_int*co2_int);
fixed_integer<uint64, 25> t1((a_k_copy-A_co2)<<1);
fixed_integer<uint64, 17> t2(t1/co1_int);
fixed_integer<uint64, 17> t3(t2-b_int);
//assert(integer(a_k_copy) == a*k_copy);
//assert(integer(A_co2) == A*co2);
//assert(integer(a_k_copy-A_co2) == (a*k_copy-A*co2));
//print(integer(a_k_copy-A_co2).to_string());
//print(integer(fixed_integer<uint64, 30>(a_k_copy-A_co2)<<8).to_string());
//assert(integer((a_k_copy-A_co2)<<1) == ((a*k_copy-A*co2)*integer(2)));
//assert(integer(t2) == ((a*k_copy-A*co2)*integer(2))/co1);
//assert(integer(t3) == ( ((a*k_copy-A*co2)*integer(2))/co1 - b ));
//assert(integer(A_int<<1) == (A*integer(2)));
B_int=t3%fixed_integer<uint64, 17>(A_int<<1);
if (same_cofactors) {
assert(integer(B_int)==B);
}
}
//TRACK_MAX(B*B); // B*B, 4.01171, 0
//TRACK_MAX(B*B-D); // B*B-D, 4.01171, 0
integer C=((B*B-D)/A)>>2; //[division is exact; right shift is truncation towards 0; can be negative. right shift is exact]
fixed_integer<uint64, 17> C_int;
{
fixed_integer<uint64, 33> B_B(B_int*B_int);
fixed_integer<uint64, 33> B_B_D(B_B-D_int);
//calculated at the same time as the division
if (!(B_B_D%A_int).is_zero()) {
//todo //test random error injection
print( "discriminant error" );
return false;
}
fixed_integer<uint64, 17> t1(B_B_D/A_int);
//assert(integer(B_B)==B*B);
//assert(integer(B_B_D)==B*B-D);
//print(integer(t1).to_string());
//print(((B*B-D)/A).to_string());
//assert(integer(t1)==((B*B-D)/A));
C_int=t1>>2;
if (same_cofactors) {
assert(integer(C_int)==C);
}
}
//TRACK_MAX(C); // C, 2.03125, 1
if (A<0) {
A=-A;
C=-C;
}
A_int.set_negative(false);
C_int.set_negative(false);
//print( "A", A_int.to_integer().to_string() );
//print( "B", B_int.to_integer().to_string() );
if (same_cofactors) {
assert(integer(A_int)==A);
assert(integer(B_int)==B);
assert(integer(C_int)==C);
}
//TRACK_MAX(A); // A, 2.00585, 0
//TRACK_MAX(C); // C, 2.03125, 0
f.a=A;
f.b=B;
f.c=C;
//print( "" );
//print( "" );
//print( "==========================================" );
//print( "" );
//print( "" );
//
//
integer s=integer(a_copy_int);
integer t=integer(k_copy_int);
integer v0=-integer(co2_int);
integer v1=-integer(co1_int);
bool S_negative=(v1<=0);
integer c_v1=c*v1;
integer b_t=b*t;
integer b_t_c_v1=b_t+c_v1;
integer h=(b*t+c*v1)/a;
if (S_negative) {
h=-h;
}
integer v1_h=v1*h;
integer t_t_S=t*t;
if (S_negative) {
t_t_S=-t_t_S;
}
integer v0_2=v0<<1;
integer A_=t_t_S+v1_h;
integer A_2=A_<<1;
integer S_t_v0=t*v0;
if (S_negative) {
S_t_v0=-S_t_v0;
}
// B=( -((a*t+A*v0)*2)/v1 - b )%(A*2)
// B=( -((a*t+(t*t*S+v1*h)*v0)*2)/v1 - b )%(A*2)
// B=( -((a*t*2 + t*t*S*v0*2 + v1*v0*h*2))/v1 - b )%(A*2)
// B=( -(a*t*2 + t*t*S*v0*2)/v1 - v0*h*2 - b )%(A*2)
// B=( -(t*2(a + t*S*v0))/v1 - v0*h*2 - b )%(A*2)
integer a_S_t_v0=a+S_t_v0;
integer t_2=t<<1;
integer t_2_a_S_t_v0=t_2*a_S_t_v0;
integer t_2_a_S_t_v0_v1=t_2_a_S_t_v0/v1;
//integer t_2_a_S_t_v0_v1=t_2*a_S_t_v0_v1;
integer e=-t_2_a_S_t_v0_v1-b;
integer v0_2_h=v0_2*h;
integer f_=e-v0_2_h; // -(t*2*((a+S*t*v0)/v1)) - v0*h*2 - b
integer B_=f_%A_2;
A_=abs(A_);
//print( "A_", A_.to_string() );
//print( "B_", B_.to_string() );
return true;
}

View File

@ -81,7 +81,7 @@ def create_proof_of_time_nwesolowski(discriminant, x, iterations,
proof = ClassGroup.from_bytes(receive_con.recv_bytes(), discriminant)
p.join()
return y_2, proof_2 + serialize_proof([y_1, proof])
return y_2, proof_2 + iterations_1.to_bytes(8, byteorder="big") + serialize_proof([y_1, proof])
def create_proof_of_time_pietrzak(discriminant, x, iterations, int_size_bits):
@ -115,9 +115,21 @@ def check_proof_of_time_wesolowski(discriminant, x, proof_blob,
except Exception:
return False
def check_proof_of_time_nwesolowski(discriminant, x, proof_blob,
iterations, int_size_bits, recursion):
int_size = (int_size_bits + 16) >> 4
new_proof_blob = proof_blob[:4 * int_size]
iter_list = []
for i in range(4 * int_size, len(proof_blob), 4 * int_size + 8):
iter_list.append(int.from_bytes(proof_blob[i : (i + 8)], byteorder="big"))
new_proof_blob = new_proof_blob + proof_blob[(i + 8): (i + 8 + 4 * int_size)]
return check_proof_of_time_nwesolowski_inner(discriminant, x, new_proof_blob,
iterations, int_size_bits, iter_list, recursion)
def check_proof_of_time_nwesolowski_inner(discriminant, x, proof_blob,
iterations, int_size_bits, iter_list, recursion):
"""
Recursive verification function for nested wesolowski. The proof blob
includes the output of the VDF, along with the proof. The following
@ -145,14 +157,14 @@ def check_proof_of_time_nwesolowski(discriminant, x, proof_blob,
assert(len(proof) % 2 == 1 and len(proof) > 2)
_, _, w = proof_wesolowski.approximate_parameters(iterations)
iterations_1 = (iterations * w) // (w + 1)
iterations_1 = iter_list[-1]
iterations_2 = iterations - iterations_1
ver_outer = proof_wesolowski.verify_proof(x, proof[-2],
proof[-1], iterations_1)
return ver_outer and check_proof_of_time_nwesolowski(discriminant, proof[-2],
return ver_outer and check_proof_of_time_nwesolowski_inner(discriminant, proof[-2],
serialize_proof([y] + proof[:-2]),
iterations_2, int_size_bits, recursion-1)
iterations_2, int_size_bits, iter_list[:-1], recursion-1)
except Exception:
return False
@ -187,4 +199,4 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""

View File

@ -4,7 +4,7 @@ from enum import Enum
import time
import blspy
from typing import List, Dict, Optional, Tuple
from src.util.errors import BlockNotInBlockchain
from src.util.errors import BlockNotInBlockchain, InvalidGenesisBlock
from src.types.sized_bytes import bytes32
from src.util.ints import uint64, uint32
from src.types.trunk_block import TrunkBlock
@ -34,6 +34,7 @@ class ReceiveBlockResult(Enum):
class Blockchain:
def __init__(self, override_constants: Dict = {}):
print(consensus_constants["DIFFICULTY_STARTING"])
# Allow passing in custom overrides for any consesus parameters
self.constants: Dict = consensus_constants
for key, value in override_constants.items():
@ -45,7 +46,8 @@ class Blockchain:
self.height_to_hash: Dict[uint64, bytes32] = {}
self.genesis = FullBlock.from_bytes(self.constants["GENESIS_BLOCK"])
result = self.receive_block(self.genesis)
assert result == ReceiveBlockResult.ADDED_TO_HEAD
if result != ReceiveBlockResult.ADDED_TO_HEAD:
raise InvalidGenesisBlock()
# For blocks with height % constants["DIFFICULTY_DELAY"] == 1, a link to the hash of
# the (constants["DIFFICULTY_DELAY"])-th parent of this block
@ -237,7 +239,9 @@ class Blockchain:
challenge_hash = block.trunk_block.proof_of_time.output.challenge_hash
difficulty = self.get_next_difficulty(header_hash)
iterations = block.trunk_block.challenge.total_iters - prev_block.trunk_block.challenge.total_iters
return calculate_ips_from_iterations(proof_of_space, challenge_hash, difficulty, iterations)
# print(f"secon {self.constants[]}")
return calculate_ips_from_iterations(proof_of_space, challenge_hash, difficulty, iterations,
self.constants["MIN_BLOCK_TIME"])
# ips (along with difficulty) will change in this block, so we need to calculate the new one.
# The calculation is (iters_2 - iters_1) // (timestamp_2 - timestamp_1).
@ -307,7 +311,6 @@ class Blockchain:
(except for proof of time). The same as validate_block, but without proof of time
and challenge validation.
"""
# 1. Check previous pointer(s) / flyclient
if not genesis and block.prev_header_hash not in self.blocks:
return False
@ -408,7 +411,7 @@ class Blockchain:
block.trunk_block.proof_of_time.output.challenge_hash)
number_of_iters: uint64 = calculate_iterations_quality(pos_quality, block.trunk_block.proof_of_space.size,
difficulty, ips)
difficulty, ips, self.constants["MIN_BLOCK_TIME"])
if number_of_iters != block.trunk_block.proof_of_time.output.number_of_iterations:
return False

View File

@ -1,6 +1,3 @@
# How often to update our current estimate of VDF speed, in seconds
update_pot_estimate_interval: 30
# Don't send any more than these number of trunks and blocks, in one message
max_trunks_to_send: 100
max_blocks_to_send: 10

View File

@ -3,4 +3,9 @@ port: 8003
# How much recursion to use for the wesolowski VDF proof. This increases the size
# of the proofs.
n_wesolowski: 3
n_wesolowski: 2
# VDF servers must be started locally on these ports.
vdf_server_ports:
- 8889
- 8890

File diff suppressed because one or more lines are too long

View File

@ -1,7 +1,6 @@
from src.util.ints import uint64, uint8
from src.types.sized_bytes import bytes32
from src.types.proof_of_space import ProofOfSpace
from src.consensus.constants import constants
from decimal import getcontext, Decimal, ROUND_UP
# Sets a high precision so we can convert a 256 bit has to a decimal, and
@ -37,13 +36,13 @@ def _quality_to_decimal(quality: bytes32) -> Decimal:
def calculate_iterations_quality(quality: bytes32, size: uint8, difficulty: uint64,
vdf_ips: uint64) -> uint64:
vdf_ips: uint64, min_block_time: uint64) -> uint64:
"""
Calculates the number of iterations from the quality. The quality is converted to a number
between 0 and 1, then divided by expected plot size, and finally multiplied by the
difficulty.
"""
min_iterations = constants["MIN_BLOCK_TIME"] * vdf_ips
min_iterations = min_block_time * vdf_ips
dec_iters = (Decimal(int(difficulty) << 32) *
(_quality_to_decimal(quality) / _expected_plot_size(size)))
iters_final = uint64(min_iterations + dec_iters.to_integral_exact(rounding=ROUND_UP))
@ -52,17 +51,17 @@ def calculate_iterations_quality(quality: bytes32, size: uint8, difficulty: uint
def calculate_iterations(proof_of_space: ProofOfSpace, challenge_hash: bytes32,
difficulty: uint64, vdf_ips: uint64) -> uint64:
difficulty: uint64, vdf_ips: uint64, min_block_time: uint64) -> uint64:
"""
Convenience function to calculate the number of iterations using the proof instead
of the quality. The quality must be retrieved from the proof.
"""
quality: bytes32 = proof_of_space.verify_and_get_quality(challenge_hash)
return calculate_iterations_quality(quality, proof_of_space.size, difficulty, vdf_ips)
return calculate_iterations_quality(quality, proof_of_space.size, difficulty, vdf_ips, min_block_time)
def calculate_ips_from_iterations(proof_of_space: ProofOfSpace, challenge_hash: bytes32,
difficulty: uint64, iterations: uint64) -> uint64:
difficulty: uint64, iterations: uint64, min_block_time: uint64) -> uint64:
"""
Using the total number of iterations on a block (which is encoded in the block) along with
other details, we can calculate the VDF speed (iterations per second) used to compute the
@ -73,7 +72,7 @@ def calculate_ips_from_iterations(proof_of_space: ProofOfSpace, challenge_hash:
(_quality_to_decimal(quality) / _expected_plot_size(proof_of_space.size)))
iters_rounded = int(dec_iters.to_integral_exact(rounding=ROUND_UP))
min_iterations = uint64(iterations - iters_rounded)
ips = min_iterations / constants["MIN_BLOCK_TIME"]
ips = min_iterations / min_block_time
assert ips >= 1
assert uint64(ips) == ips
return uint64(ips)

View File

@ -13,6 +13,7 @@ from src.types.sized_bytes import bytes32
from src.util.ints import uint32, uint64
from src.consensus.block_rewards import calculate_block_reward
from src.consensus.pot_iterations import calculate_iterations_quality
from src.consensus.constants import constants
from src.server.outbound_message import OutboundMessage, Delivery, Message, NodeType
@ -63,9 +64,12 @@ async def challenge_response(challenge_response: plotter_protocol.ChallengeRespo
number_iters: uint64 = calculate_iterations_quality(challenge_response.quality,
challenge_response.plot_size,
difficulty, db.proof_of_time_estimate_ips)
difficulty,
db.proof_of_time_estimate_ips,
constants["MIN_BLOCK_TIME"])
estimate_secs: float = number_iters / db.proof_of_time_estimate_ips
log.info(f"Estimate: {estimate_secs}, rate: {db.proof_of_time_estimate_ips}")
if estimate_secs < config['pool_share_threshold'] or estimate_secs < config['propagate_threshold']:
async with db.lock:
db.plotter_responses_challenge[challenge_response.quality] = challenge_response.challenge_hash
@ -105,7 +109,8 @@ async def respond_proof_of_space(response: plotter_protocol.RespondProofOfSpace)
number_iters: uint64 = calculate_iterations_quality(computed_quality,
response.proof.size,
difficulty,
db.proof_of_time_estimate_ips)
db.proof_of_time_estimate_ips,
constants["MIN_BLOCK_TIME"])
async with db.lock:
estimate_secs: float = number_iters / db.proof_of_time_estimate_ips
if estimate_secs < config['pool_share_threshold']:

View File

@ -103,7 +103,7 @@ async def send_challenges_to_timelords() -> AsyncGenerator[OutboundMessage, None
async with db.lock:
for head in db.blockchain.get_current_heads():
challenge_hash = head.challenge.get_hash()
requests.append(timelord_protocol.ChallengeStart(challenge_hash))
requests.append(timelord_protocol.ChallengeStart(challenge_hash, head.challenge.height))
for request in requests:
yield OutboundMessage(NodeType.TIMELORD, Message("challenge_start", request), Delivery.BROADCAST)
@ -214,7 +214,7 @@ async def sync():
await asyncio.wait_for(db.potential_blocks_received[uint32(height)].wait(), timeout=2)
found = True
break
except concurrent.futures._base.TimeoutError:
except concurrent.futures.TimeoutError:
log.info("Did not receive desired block")
if not found:
raise PeersDontHaveBlock(f"Did not receive desired block at height {height}")
@ -497,7 +497,8 @@ async def unfinished_block(unfinished_block: peer_protocol.UnfinishedBlock) -> A
unfinished_block.block.trunk_block.prev_header_hash)
iterations_needed: uint64 = calculate_iterations(unfinished_block.block.trunk_block.proof_of_space,
challenge_hash, difficulty, vdf_ips)
challenge_hash, difficulty, vdf_ips,
constants["MIN_BLOCK_TIME"])
if (challenge_hash, iterations_needed) in db.unfinished_blocks:
return
@ -537,6 +538,7 @@ async def block(block: peer_protocol.Block) -> AsyncGenerator[OutboundMessage, N
"""
Receive a full block from a peer full node (or ourselves).
"""
header_hash = block.block.trunk_block.header.get_hash()
async with db.lock:
@ -549,7 +551,6 @@ async def block(block: peer_protocol.Block) -> AsyncGenerator[OutboundMessage, N
added: ReceiveBlockResult = db.blockchain.receive_block(block.block)
if added == ReceiveBlockResult.ALREADY_HAVE_BLOCK:
log.warning(f"ALready have block")
return
elif added == ReceiveBlockResult.INVALID_BLOCK:
log.warning(f"\tBlock {header_hash} at height {block.block.trunk_block.challenge.height} is invalid.")
@ -595,12 +596,12 @@ async def block(block: peer_protocol.Block) -> AsyncGenerator[OutboundMessage, N
async with db.lock:
# Only propagate blocks which extend the blockchain (one of the heads)
difficulty = db.blockchain.get_next_difficulty(block.block.prev_header_hash)
vdf_ips = db.blockchain.get_next_ips(block.block.prev_header_hash)
if vdf_ips != db.proof_of_time_estimate_ips:
db.proof_of_time_estimate_ips = vdf_ips
next_vdf_ips = db.blockchain.get_next_ips(block.block.header_hash)
if next_vdf_ips != db.proof_of_time_estimate_ips:
db.proof_of_time_estimate_ips = next_vdf_ips
ips_changed = True
if ips_changed:
rate_update = farmer_protocol.ProofOfTimeRate(vdf_ips)
rate_update = farmer_protocol.ProofOfTimeRate(next_vdf_ips)
yield OutboundMessage(NodeType.FARMER, Message("proof_of_time_rate", rate_update), Delivery.BROADCAST)
pos_quality = block.block.trunk_block.proof_of_space.verify_and_get_quality(
@ -610,9 +611,10 @@ async def block(block: peer_protocol.Block) -> AsyncGenerator[OutboundMessage, N
block.block.trunk_block.challenge.height,
pos_quality,
difficulty)
timelord_request = timelord_protocol.ChallengeStart(block.block.trunk_block.challenge.get_hash())
timelord_request_end = timelord_protocol.ChallengeStart(block.block.trunk_block.proof_of_time.
output.challenge_hash)
timelord_request = timelord_protocol.ChallengeStart(block.block.trunk_block.challenge.get_hash(),
block.block.trunk_block.challenge.height)
timelord_request_end = timelord_protocol.ChallengeEnd(block.block.trunk_block.proof_of_time.
output.challenge_hash)
# Tell timelord to stop previous challenge and start with new one
yield OutboundMessage(NodeType.TIMELORD, Message("challenge_end", timelord_request_end), Delivery.BROADCAST)
yield OutboundMessage(NodeType.TIMELORD, Message("challenge_start", timelord_request), Delivery.BROADCAST)
@ -622,3 +624,9 @@ async def block(block: peer_protocol.Block) -> AsyncGenerator[OutboundMessage, N
# Tell farmer about the new block
yield OutboundMessage(NodeType.FARMER, Message("proof_of_space_finalized", farmer_request), Delivery.BROADCAST)
elif added == ReceiveBlockResult.ADDED_AS_ORPHAN:
log.info("I've received an orphan, stopping the proof of time challenge.")
log.info(f"Height of the orphan block is {block.block.trunk_block.challenge.height}")
timelord_request_end = timelord_protocol.ChallengeEnd(block.block.trunk_block.proof_of_time.
output.challenge_hash)
yield OutboundMessage(NodeType.TIMELORD, Message("challenge_end", timelord_request_end), Delivery.BROADCAST)

View File

@ -1,6 +1,6 @@
from src.util.cbor_message import cbor_message
from src.types.sized_bytes import bytes32
from src.util.ints import uint64
from src.util.ints import uint32, uint64
from src.types.proof_of_time import ProofOfTime
"""
@ -20,7 +20,7 @@ class ProofOfTimeFinished:
@cbor_message(tag=3001)
class ChallengeStart:
challenge_hash: bytes32
height: uint32
@cbor_message(tag=3002)
class ChallengeEnd:

View File

@ -197,19 +197,19 @@ async def initialize_pipeline(aiter,
handshake_finished_3 = forker.fork(is_active=True)
# Reads messages one at a time from the TCP connection
messages_aiter = join_aiters(parallel_map_aiter(connection_to_message, 100, handshake_finished_1))
messages_aiter = join_aiters(parallel_map_aiter(connection_to_message, handshake_finished_1, 100))
# Handles each message one at a time, and yields responses to send back or broadcast
responses_aiter = join_aiters(parallel_map_aiter(
partial_func.partial_async_gen(handle_message, api),
100, messages_aiter))
messages_aiter, 100))
if on_connect is not None:
# Uses a forked aiter, and calls the on_connect function to send some initial messages
# as soon as the connection is established
on_connect_outbound_aiter = join_aiters(parallel_map_aiter(
partial_func.partial_async_gen(connection_to_outbound, on_connect), 100, handshake_finished_2))
partial_func.partial_async_gen(connection_to_outbound, on_connect), handshake_finished_2, 100))
responses_aiter = join_aiters(iter_to_aiter([responses_aiter, on_connect_outbound_aiter]))
if outbound_aiter is not None:
@ -220,7 +220,7 @@ async def initialize_pipeline(aiter,
# For each outbound message, replicate for each peer that we need to send to
expanded_messages_aiter = join_aiters(parallel_map_aiter(
expand_outbound_messages, 100, responses_aiter))
expand_outbound_messages, responses_aiter, 100))
# This will run forever. Sends each message through the TCP connection, using the
# length encoding and CBOR serialization

View File

@ -9,7 +9,10 @@ from src.protocols.plotter_protocol import PlotterHandshake
from src.server.outbound_message import OutboundMessage, Message, Delivery, NodeType
from src.util.network import parse_host_port
logging.basicConfig(format='Farmer %(name)-25s: %(levelname)-8s %(message)s', level=logging.INFO)
logging.basicConfig(format='Farmer %(name)-25s: %(levelname)-8s %(asctime)s.%(msecs)03d %(message)s',
level=logging.INFO,
datefmt='%H:%M:%S'
)
async def main():

View File

@ -8,7 +8,10 @@ from src.server.outbound_message import NodeType
from src.types.peer_info import PeerInfo
logging.basicConfig(format='FullNode %(name)-23s: %(levelname)-8s %(message)s', level=logging.INFO)
logging.basicConfig(format='FullNode %(name)-23s: %(levelname)-8s %(asctime)s.%(msecs)03d %(message)s',
level=logging.INFO,
datefmt='%H:%M:%S'
)
log = logging.getLogger(__name__)
"""

View File

@ -5,7 +5,10 @@ from src.server.outbound_message import NodeType
from src.util.network import parse_host_port
from src import plotter
logging.basicConfig(format='Plotter %(name)-24s: %(levelname)-8s %(message)s', level=logging.INFO)
logging.basicConfig(format='Plotter %(name)-24s: %(levelname)-8s %(asctime)s.%(msecs)03d %(message)s',
level=logging.INFO,
datefmt='%H:%M:%S'
)
async def main():

View File

@ -5,7 +5,10 @@ from src.server.outbound_message import NodeType
from src.util.network import parse_host_port
from src import timelord
logging.basicConfig(format='Timelord %(name)-25s: %(levelname)-20s %(message)s', level=logging.INFO)
logging.basicConfig(format='Timelord %(name)-25s: %(levelname)-8s %(asctime)s.%(msecs)03d %(message)s',
level=logging.INFO,
datefmt='%H:%M:%S'
)
async def main():

View File

@ -1,16 +1,22 @@
ps -e | grep python | grep "start_" | awk '{print $1}' | xargs -L1 kill -9
python -m src.server.start_plotter &
ps -e | grep "fast_vdf/vdf" | awk '{print $1}' | xargs -L1 kill -9
./lib/chiavdf/fast_vdf/vdf 8889 &
P1=$!
python -m src.server.start_timelord &
./lib/chiavdf/fast_vdf/vdf 8890 &
P2=$!
python -m src.server.start_farmer &
python -m src.server.start_plotter &
P3=$!
python -m src.server.start_full_node "127.0.0.1" 8002 "-f" "-t" &
python -m src.server.start_timelord &
P4=$!
python -m src.server.start_full_node "127.0.0.1" 8004 &
python -m src.server.start_farmer &
P5=$!
python -m src.server.start_full_node "127.0.0.1" 8005 &
python -m src.server.start_full_node "127.0.0.1" 8002 "-f" "-t" &
P6=$!
python -m src.server.start_full_node "127.0.0.1" 8004 &
P7=$!
python -m src.server.start_full_node "127.0.0.1" 8005 &
P8=$!
_term() {
echo "Caught SIGTERM signal, killing all servers."
@ -20,9 +26,11 @@ _term() {
kill -TERM "$P4" 2>/dev/null
kill -TERM "$P5" 2>/dev/null
kill -TERM "$P6" 2>/dev/null
kill -TERM "$P7" 2>/dev/null
kill -TERM "$P8" 2>/dev/null
}
trap _term SIGTERM
trap _term SIGINT
trap _term INT
wait $P1 $P2 $P3 $P4 $P5 $P6
wait $P1 $P2 $P3 $P4 $P5 $P6 $P7 $P8

View File

@ -1,10 +1,10 @@
import logging
import asyncio
import time
import io
import yaml
import time
from asyncio import Lock
from typing import Dict
from typing import Dict, List
from lib.chiavdf.inkfish.create_discriminant import create_discriminant
from lib.chiavdf.inkfish.proof_of_time import check_proof_of_time_nwesolowski
@ -20,13 +20,19 @@ from src.server.outbound_message import OutboundMessage, Delivery, Message, Node
class Database:
lock: Lock = Lock()
challenges: Dict = {}
process_running: bool = False
free_servers: List[int] = []
active_discriminants: Dict = {}
active_discriminants_start_time: Dict = {}
pending_iters: Dict = {}
done_discriminants = []
seen_discriminants = []
active_heights = []
config = yaml.safe_load(open("src/config/timelord.yaml", "r"))
log = logging.getLogger(__name__)
config = yaml.safe_load(open("src/config/timelord.yaml", "r"))
db = Database()
db.free_servers = config["vdf_server_ports"]
@api_request
@ -37,11 +43,114 @@ async def challenge_start(challenge_start: timelord_protocol.ChallengeStart):
a new VDF process here. But we don't know how many iterations to run for, so we run
forever.
"""
# TODO: stop previous processes
disc: int = create_discriminant(challenge_start.challenge_hash, constants["DISCRIMINANT_SIZE_BITS"])
async with db.lock:
disc: int = create_discriminant(challenge_start.challenge_hash, constants["DISCRIMINANT_SIZE_BITS"])
db.challenges[challenge_start.challenge_hash] = (time.time(), disc, None)
# TODO: Start a VDF process
if (challenge_start.challenge_hash in db.seen_discriminants):
log.info("Already seen this one... Ignoring")
return
db.seen_discriminants.append(challenge_start.challenge_hash)
db.active_heights.append(challenge_start.height)
# Wait for a server to become free.
port: int = -1
while port == -1:
async with db.lock:
if (challenge_start.height <= max(db.active_heights) - 3):
db.done_discriminants.append(challenge_start.challenge_hash)
db.active_heights.remove(challenge_start.height)
log.info(f"Will not execute challenge at height {challenge_start.height}, too old")
return
assert(len(db.active_heights) > 0)
if (challenge_start.height == max(db.active_heights)):
if (len(db.free_servers) != 0):
port = db.free_servers[0]
db.free_servers = db.free_servers[1:]
log.info(f"Discriminant {disc} attached to port {port}.")
log.info(f"Height attached is {challenge_start.height}")
db.active_heights.remove(challenge_start.height)
# Poll until a server becomes free.
if port == -1:
await asyncio.sleep(0.1)
# TODO(Florin): Handle connection failure (attempt another server)
try:
reader, writer = await asyncio.open_connection('127.0.0.1', port)
except Exception as e:
e_to_str = str(e)
log.error(f"Connection to VDF server error message: {e_to_str}")
writer.write((str(len(str(disc))) + str(disc)).encode())
await writer.drain()
ok = await reader.readexactly(2)
assert(ok.decode() == "OK")
log.info("Got handshake with VDF server.")
async with db.lock:
db.active_discriminants[challenge_start.challenge_hash] = writer
db.active_discriminants_start_time[challenge_start.challenge_hash] = time.time()
async with db.lock:
if (challenge_start.challenge_hash in db.pending_iters):
for iter in db.pending_iters[challenge_start.challenge_hash]:
writer.write((str(len(str(iter))) + str(iter)).encode())
await writer.drain()
# Listen to the server until "STOP" is received.
while True:
data = await reader.readexactly(4)
if (data.decode() == "STOP"):
# Server is now available.
async with db.lock:
writer.write(b"ACK")
await writer.drain()
db.free_servers.append(port)
break
elif (data.decode() == "POLL"):
async with db.lock:
# If I have a newer discriminant... Free up the VDF server
if (len(db.active_heights) > 0 and challenge_start.height < max(db.active_heights)):
log.info("Got poll, stopping the challenge!")
writer.write(b'10')
await writer.drain()
del db.active_discriminants[challenge_start.challenge_hash]
del db.active_discriminants_start_time[challenge_start.challenge_hash]
db.done_discriminants.append(challenge_start.challenge_hash)
else:
try:
# This must be a proof, read the continuation.
proof = await reader.readexactly(1860)
stdout_bytes_io: io.BytesIO = io.BytesIO(bytes.fromhex(data.decode() + proof.decode()))
except Exception as e:
e_to_str = str(e)
log.error(f"Socket error: {e_to_str}")
iterations_needed = int.from_bytes(stdout_bytes_io.read(8), "big", signed=True)
y = ClassgroupElement.parse(stdout_bytes_io)
proof_bytes: bytes = stdout_bytes_io.read()
# Verifies our own proof just in case
proof_blob = ClassGroup.from_ab_discriminant(y.a, y.b, disc).serialize() + proof_bytes
x = ClassGroup.from_ab_discriminant(2, 1, disc)
assert check_proof_of_time_nwesolowski(disc, x, proof_blob, iterations_needed,
constants["DISCRIMINANT_SIZE_BITS"], config["n_wesolowski"])
output = ProofOfTimeOutput(challenge_start.challenge_hash,
iterations_needed,
ClassgroupElement(y.a, y.b))
proof_of_time = ProofOfTime(output, config['n_wesolowski'], [uint8(b) for b in proof_bytes])
response = timelord_protocol.ProofOfTimeFinished(proof_of_time)
async with db.lock:
time_taken = time.time() - db.active_discriminants_start_time[challenge_start.challenge_hash]
ips = int(iterations_needed / time_taken * 10)/10
log.info(f"Finished PoT, chall:{challenge_start.challenge_hash[:10].hex()}.. {iterations_needed}"
f" iters. {int(time_taken*1000)/1000}s, {ips} ips")
yield OutboundMessage(NodeType.FULL_NODE, Message("proof_of_time_finished", response), Delivery.RESPOND)
@api_request
@ -50,9 +159,17 @@ async def challenge_end(challenge_end: timelord_protocol.ChallengeEnd):
A challenge is no longer active, so stop the process for this challenge, if it
exists.
"""
# TODO: Stop VDF process for this challenge
async with db.lock:
db.process_running = False
if (challenge_end.challenge_hash in db.done_discriminants):
return
if (challenge_end.challenge_hash in db.active_discriminants):
writer = db.active_discriminants[challenge_end.challenge_hash]
writer.write(b'10')
await writer.drain()
del db.active_discriminants[challenge_end.challenge_hash]
del db.active_discriminants_start_time[challenge_end.challenge_hash]
db.done_discriminants.append(challenge_end.challenge_hash)
await asyncio.sleep(0.5)
@api_request
@ -60,60 +177,18 @@ async def proof_of_space_info(proof_of_space_info: timelord_protocol.ProofOfSpac
"""
Notification from full node about a new proof of space for a challenge. If we already
have a process for this challenge, we should communicate to the process to tell it how
many iterations to run for. TODO: process should be started in challenge_start instead.
many iterations to run for.
"""
async with db.lock:
if proof_of_space_info.challenge_hash not in db.challenges:
log.warning(f"Have not seen challenge {proof_of_space_info.challenge_hash} yet.")
return
time_recvd, disc, iters = db.challenges[proof_of_space_info.challenge_hash]
if iters:
if proof_of_space_info.iterations_needed == iters:
log.warning(f"Have already seen this challenge with {proof_of_space_info.iterations_needed}\
iterations. Ignoring.")
return
elif proof_of_space_info.iterations_needed > iters:
# TODO: don't ignore, communicate to process
log.warning(f"Too many iterations required. Already executing {iters} iters")
return
if db.process_running:
# TODO: don't ignore, start a new process
log.warning("Already have a running process. Ignoring.")
return
db.process_running = True
command = (f"python -m lib.chiavdf.inkfish.cmds -t n-wesolowski -l 1024 -d {config['n_wesolowski']} " +
f"{proof_of_space_info.challenge_hash.hex()} {proof_of_space_info.iterations_needed}")
log.info(f"Executing VDF command with new process: {command}")
process_start = time.time()
proc = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE)
stdout, stderr = await proc.communicate()
async with db.lock:
db.process_running = False
log.info(f"Finished executing VDF after {int((time.time() - process_start) * 1000)/1000}s")
if stderr:
log.error(f'[stderr]\n{stderr.decode()}')
stdout_bytes_io: io.BytesIO = io.BytesIO(bytes.fromhex(stdout.decode()))
y = ClassgroupElement.parse(stdout_bytes_io)
proof_bytes: bytes = stdout_bytes_io.read()
# Verifies our own proof just in case
proof_blob = ClassGroup.from_ab_discriminant(y.a, y.b, disc).serialize() + proof_bytes
x = ClassGroup.from_ab_discriminant(2, 1, disc)
assert check_proof_of_time_nwesolowski(disc, x, proof_blob, proof_of_space_info.iterations_needed, 1024, 3)
output = ProofOfTimeOutput(proof_of_space_info.challenge_hash,
proof_of_space_info.iterations_needed,
ClassgroupElement(y.a, y.b))
proof_of_time = ProofOfTime(output, config['n_wesolowski'], [uint8(b) for b in proof_bytes])
response = timelord_protocol.ProofOfTimeFinished(proof_of_time)
yield OutboundMessage(NodeType.FULL_NODE, Message("proof_of_time_finished", response), Delivery.RESPOND)
if (proof_of_space_info.challenge_hash in db.active_discriminants):
writer = db.active_discriminants[proof_of_space_info.challenge_hash]
writer.write(((str(len(str(proof_of_space_info.iterations_needed))) +
str(proof_of_space_info.iterations_needed)).encode()))
await writer.drain()
return
if (proof_of_space_info.challenge_hash in db.done_discriminants):
return
if (proof_of_space_info.challenge_hash not in db.pending_iters):
db.pending_iters[proof_of_space_info.challenge_hash] = []
db.pending_iters[proof_of_space_info.challenge_hash].append(proof_of_space_info.iterations_needed)

View File

@ -47,3 +47,8 @@ class InvalidWeight(Exception):
class InvalidUnfinishedBlock(Exception):
"""The unfinished block we received is invalid"""
pass
class InvalidGenesisBlock(Exception):
"""Genesis block is not valid according to the consensus constants and rules"""
pass

View File

@ -3,7 +3,7 @@ import os
import sys
from hashlib import sha256
from chiapos import DiskPlotter, DiskProver
from typing import List
from typing import List, Dict
from blspy import PublicKey, PrivateKey, PrependSignature
from src.types.sized_bytes import bytes32
from src.types.full_block import FullBlock
@ -65,114 +65,123 @@ class BlockTools:
sys.exit(1)
def get_consecutive_blocks(self,
input_constants: Dict,
num_blocks: int,
difficulty=constants["DIFFICULTY_STARTING"],
discriminant_size=constants["DISCRIMINANT_SIZE_BITS"],
seconds_per_block=constants["BLOCK_TIME_TARGET"],
block_list: List[FullBlock] = [],
seed: int = 0) -> List[FullBlock]:
seconds_per_block=constants["BLOCK_TIME_TARGET"],
seed: uint64 = uint64(0)) -> List[FullBlock]:
test_constants = constants.copy()
for key, value in input_constants.items():
test_constants[key] = value
if len(block_list) == 0:
block_list.append(self.create_genesis_block(bytes([(seed) % 256]*32),
difficulty,
discriminant_size,
seed))
prev_difficulty = difficulty
curr_difficulty = difficulty
curr_ips = constants["VDF_IPS_STARTING"]
elif len(block_list) < (constants["DIFFICULTY_EPOCH"] + constants["DIFFICULTY_DELAY"]):
if "GENESIS_BLOCK" in test_constants:
block_list.append(FullBlock.from_bytes(test_constants["GENESIS_BLOCK"]))
else:
block_list.append(self.create_genesis_block(test_constants, bytes([(seed) % 256]*32), seed))
prev_difficulty = test_constants["DIFFICULTY_STARTING"]
curr_difficulty = prev_difficulty
curr_ips = test_constants["VDF_IPS_STARTING"]
elif len(block_list) < (test_constants["DIFFICULTY_EPOCH"] + test_constants["DIFFICULTY_DELAY"]):
# First epoch (+delay), so just get first difficulty
prev_difficulty = block_list[0].weight
curr_difficulty = block_list[0].weight
assert difficulty == prev_difficulty
curr_ips = constants["VDF_IPS_STARTING"]
assert test_constants["DIFFICULTY_STARTING"] == prev_difficulty
curr_ips = test_constants["VDF_IPS_STARTING"]
else:
curr_difficulty = block_list[-1].weight - block_list[-2].weight
prev_difficulty = (block_list[-1 - constants["DIFFICULTY_EPOCH"]].weight -
block_list[-2 - constants["DIFFICULTY_EPOCH"]].weight)
prev_difficulty = (block_list[-1 - test_constants["DIFFICULTY_EPOCH"]].weight -
block_list[-2 - test_constants["DIFFICULTY_EPOCH"]].weight)
curr_ips = calculate_ips_from_iterations(block_list[-1].trunk_block.proof_of_space,
block_list[-1].trunk_block.proof_of_time.output.challenge_hash,
curr_difficulty,
block_list[-1].trunk_block.proof_of_time.output
.number_of_iterations)
.number_of_iterations,
test_constants["MIN_BLOCK_TIME"])
starting_height = block_list[-1].height + 1
timestamp = block_list[-1].trunk_block.header.data.timestamp
for next_height in range(starting_height, starting_height + num_blocks):
if (next_height > constants["DIFFICULTY_EPOCH"] and
next_height % constants["DIFFICULTY_EPOCH"] == constants["DIFFICULTY_DELAY"]):
if (next_height > test_constants["DIFFICULTY_EPOCH"] and
next_height % test_constants["DIFFICULTY_EPOCH"] == test_constants["DIFFICULTY_DELAY"]):
# Calculates new difficulty
height1 = uint64(next_height - (constants["DIFFICULTY_EPOCH"] +
constants["DIFFICULTY_DELAY"]) - 1)
height2 = uint64(next_height - (constants["DIFFICULTY_EPOCH"]) - 1)
height3 = uint64(next_height - (constants["DIFFICULTY_DELAY"]) - 1)
height1 = uint64(next_height - (test_constants["DIFFICULTY_EPOCH"] +
test_constants["DIFFICULTY_DELAY"]) - 1)
height2 = uint64(next_height - (test_constants["DIFFICULTY_EPOCH"]) - 1)
height3 = uint64(next_height - (test_constants["DIFFICULTY_DELAY"]) - 1)
if height1 >= 0:
timestamp1 = block_list[height1].trunk_block.header.data.timestamp
iters1 = block_list[height1].trunk_block.challenge.total_iters
else:
timestamp1 = (block_list[0].trunk_block.header.data.timestamp -
constants["BLOCK_TIME_TARGET"])
test_constants["BLOCK_TIME_TARGET"])
iters1 = block_list[0].trunk_block.challenge.total_iters
timestamp2 = block_list[height2].trunk_block.header.data.timestamp
timestamp3 = block_list[height3].trunk_block.header.data.timestamp
iters3 = block_list[height3].trunk_block.challenge.total_iters
term1 = (constants["DIFFICULTY_DELAY"] * prev_difficulty *
(timestamp3 - timestamp2) * constants["BLOCK_TIME_TARGET"])
term1 = (test_constants["DIFFICULTY_DELAY"] * prev_difficulty *
(timestamp3 - timestamp2) * test_constants["BLOCK_TIME_TARGET"])
term2 = ((constants["DIFFICULTY_WARP_FACTOR"] - 1) *
(constants["DIFFICULTY_EPOCH"] - constants["DIFFICULTY_DELAY"]) * curr_difficulty
* (timestamp2 - timestamp1) * constants["BLOCK_TIME_TARGET"])
term2 = ((test_constants["DIFFICULTY_WARP_FACTOR"] - 1) *
(test_constants["DIFFICULTY_EPOCH"] - test_constants["DIFFICULTY_DELAY"]) * curr_difficulty
* (timestamp2 - timestamp1) * test_constants["BLOCK_TIME_TARGET"])
# Round down after the division
new_difficulty: uint64 = uint64((term1 + term2) //
(constants["DIFFICULTY_WARP_FACTOR"] *
(test_constants["DIFFICULTY_WARP_FACTOR"] *
(timestamp3 - timestamp2) *
(timestamp2 - timestamp1)))
if new_difficulty >= curr_difficulty:
new_difficulty = min(new_difficulty, uint64(constants["DIFFICULTY_FACTOR"] *
new_difficulty = min(new_difficulty, uint64(test_constants["DIFFICULTY_FACTOR"] *
curr_difficulty))
else:
new_difficulty = max([uint64(1), new_difficulty,
uint64(curr_difficulty // constants["DIFFICULTY_FACTOR"])])
uint64(curr_difficulty // test_constants["DIFFICULTY_FACTOR"])])
prev_difficulty = curr_difficulty
curr_difficulty = new_difficulty
curr_ips = uint64((iters3 - iters1)//(timestamp3 - timestamp1))
print(f"Changing IPS {next_height} to {curr_ips} and diff {new_difficulty}")
print(f"Curr ips: {curr_ips}")
time_taken = seconds_per_block
timestamp += time_taken
block_list.append(self.create_next_block(block_list[-1], timestamp, curr_difficulty,
curr_ips, discriminant_size, seed))
block_list.append(self.create_next_block(test_constants, block_list[-1], timestamp, curr_difficulty,
curr_ips, seed))
return block_list
def create_genesis_block(self, challenge_hash=bytes([0]*32), difficulty=constants["DIFFICULTY_STARTING"],
discriminant_size=constants["DISCRIMINANT_SIZE_BITS"], seed: int = 0) -> FullBlock:
def create_genesis_block(self, input_constants: Dict, challenge_hash=bytes([0]*32),
seed: uint64 = uint64(0)) -> FullBlock:
"""
Creates the genesis block with the specified details.
"""
test_constants = constants.copy()
for key, value in input_constants.items():
test_constants[key] = value
return self._create_block(
test_constants,
challenge_hash,
uint32(0),
bytes([0]*32),
uint64(0),
uint64(0),
uint64(time.time()),
uint64(difficulty),
constants["VDF_IPS_STARTING"],
discriminant_size,
uint64(test_constants["DIFFICULTY_STARTING"]),
uint64(test_constants["VDF_IPS_STARTING"]),
seed
)
def create_next_block(self, prev_block: FullBlock, timestamp: uint64,
difficulty=constants["DIFFICULTY_STARTING"],
ips=constants["VDF_IPS_STARTING"],
discriminant_size=constants["DISCRIMINANT_SIZE_BITS"],
def create_next_block(self, input_constants: Dict, prev_block: FullBlock, timestamp: uint64,
difficulty: uint64, ips: uint64,
seed: int = 0) -> FullBlock:
"""
Creates the next block with the specified details.
"""
test_constants = constants.copy()
for key, value in input_constants.items():
test_constants[key] = value
return self._create_block(
test_constants,
prev_block.trunk_block.challenge.get_hash(),
prev_block.height + 1,
prev_block.header_hash,
@ -181,13 +190,12 @@ class BlockTools:
timestamp,
uint64(difficulty),
ips,
discriminant_size,
seed
)
def _create_block(self, challenge_hash: bytes32, height: uint32, prev_header_hash: bytes32,
def _create_block(self, test_constants: Dict, challenge_hash: bytes32, height: uint32, prev_header_hash: bytes32,
prev_iters: uint64, prev_weight: uint64, timestamp: uint64, difficulty: uint64,
ips: uint64, discriminant_size: uint64, seed: int = 0) -> FullBlock:
ips: uint64, seed: int) -> FullBlock:
"""
Creates a block with the specified details. Uses the stored plots to create a proof of space,
and also evaluates the VDF for the proof of time.
@ -211,10 +219,11 @@ class BlockTools:
proof_xs: bytes = prover.get_full_proof(challenge_hash, 0)
proof_of_space: ProofOfSpace = ProofOfSpace(pool_pk, plot_pk, k, list(proof_xs))
number_iters: uint64 = pot_iterations.calculate_iterations(proof_of_space, challenge_hash,
difficulty, ips)
disc: int = create_discriminant(challenge_hash, discriminant_size)
difficulty, ips,
test_constants["MIN_BLOCK_TIME"])
disc: int = create_discriminant(challenge_hash, test_constants["DISCRIMINANT_SIZE_BITS"])
start_x: ClassGroup = ClassGroup.from_ab_discriminant(2, 1, disc)
y_cl, proof_bytes = create_proof_of_time_nwesolowski(
disc, start_x, number_iters, disc, n_wesolowski)
@ -250,5 +259,7 @@ class BlockTools:
# This code generates a genesis block, uncomment to output genesis block to terminal
# This might take a while, using the python VDF implementation.
# Run by doing python -m tests.block_tools
# bt = BlockTools()
# print(bt.create_genesis_block(bytes([4]*32)).serialize())
# print(bt.create_genesis_block({}, bytes([1]*32), uint64(0)).serialize())

View File

@ -35,7 +35,8 @@ class TestPotIterations():
for b_index in range(total_blocks):
qualities = [sha256(b_index.to_bytes(32, "big") + bytes(farmer_index)).digest()
for farmer_index in range(len(farmer_ks))]
iters = [calculate_iterations_quality(qualities[i], farmer_ks[i], uint64(50000000), uint64(5000))
iters = [calculate_iterations_quality(qualities[i], farmer_ks[i], uint64(50000000),
uint64(5000), uint64(10))
for i in range(len(qualities))]
# print(iters)
wins[iters.index(min(iters))] += 1

View File

@ -1,7 +1,7 @@
from src.consensus.constants import constants
import time
import pytest
from blspy import PrivateKey
from src.consensus.constants import constants
from src.types.coinbase import CoinbaseInfo
from src.types.block_body import BlockBody
from src.types.proof_of_space import ProofOfSpace
@ -16,6 +16,17 @@ from tests.block_tools import BlockTools
bt = BlockTools()
test_constants = {
"DIFFICULTY_STARTING": 5,
"DISCRIMINANT_SIZE_BITS": 16,
"BLOCK_TIME_TARGET": 10,
"MIN_BLOCK_TIME": 2,
"DIFFICULTY_EPOCH": 12, # The number of blocks per epoch
"DIFFICULTY_WARP_FACTOR": 4, # DELAY divides EPOCH in order to warp efficiently.
"DIFFICULTY_DELAY": 3 # EPOCH / WARP_FACTOR
}
test_constants["GENESIS_BLOCK"] = bt.create_genesis_block(test_constants, bytes([0]*32), uint64(0)).serialize()
class TestGenesisBlock():
def test_basic_blockchain(self):
@ -34,12 +45,8 @@ class TestBlockValidation():
"""
Provides a list of 10 valid blocks, as well as a blockchain with 9 blocks added to it.
"""
blocks = bt.get_consecutive_blocks(10, 5, 16)
b: Blockchain = Blockchain({
"GENESIS_BLOCK": blocks[0].serialize(),
"DIFFICULTY_STARTING": 5,
"DISCRIMINANT_SIZE_BITS": 16
})
blocks = bt.get_consecutive_blocks(test_constants, 10, [], 10)
b: Blockchain = Blockchain(test_constants)
for i in range(1, 9):
assert b.receive_block(blocks[i]) == ReceiveBlockResult.ADDED_TO_HEAD
return (blocks, b)
@ -166,26 +173,16 @@ class TestBlockValidation():
def test_difficulty_change(self):
num_blocks = 20
# Make it 5x faster than target time
blocks = bt.get_consecutive_blocks(num_blocks, 5, 16, 1, [])
b: Blockchain = Blockchain({
"GENESIS_BLOCK": blocks[0].serialize(),
"DIFFICULTY_STARTING": 5,
"DISCRIMINANT_SIZE_BITS": 16,
"BLOCK_TIME_TARGET": 10,
"DIFFICULTY_EPOCH": 12, # The number of blocks per epoch
"DIFFICULTY_WARP_FACTOR": 4, # DELAY divides EPOCH in order to warp efficiently.
"DIFFICULTY_DELAY": 3 # EPOCH / WARP_FACTOR
})
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 2)
b: Blockchain = Blockchain(test_constants)
for i in range(1, num_blocks):
# print(f"Adding {i}")
assert b.receive_block(blocks[i]) == ReceiveBlockResult.ADDED_TO_HEAD
assert b.get_next_difficulty(blocks[13].header_hash) == b.get_next_difficulty(blocks[12].header_hash)
assert b.get_next_difficulty(blocks[14].header_hash) > b.get_next_difficulty(blocks[13].header_hash)
assert ((b.get_next_difficulty(blocks[14].header_hash) / b.get_next_difficulty(blocks[13].header_hash)
<= constants["DIFFICULTY_FACTOR"]))
assert blocks[-1].trunk_block.challenge.total_iters == 176091
assert blocks[-1].trunk_block.challenge.total_iters == 142911
assert b.get_next_ips(blocks[1].header_hash) == constants["VDF_IPS_STARTING"]
assert b.get_next_ips(blocks[12].header_hash) == b.get_next_ips(blocks[11].header_hash)
@ -196,22 +193,14 @@ class TestBlockValidation():
class TestReorgs():
def test_basic_reorg(self):
blocks = bt.get_consecutive_blocks(100, 5, 16, 9, [], 0)
b: Blockchain = Blockchain({
"GENESIS_BLOCK": blocks[0].serialize(),
"DIFFICULTY_STARTING": 5,
"DISCRIMINANT_SIZE_BITS": 16,
"BLOCK_TIME_TARGET": 10,
"DIFFICULTY_EPOCH": 12, # The number of blocks per epoch
"DIFFICULTY_WARP_FACTOR": 4, # DELAY divides EPOCH in order to warp efficiently.
"DIFFICULTY_DELAY": 3 # EPOCH / WARP_FACTOR
})
blocks = bt.get_consecutive_blocks(test_constants, 100, [], 9)
b: Blockchain = Blockchain(test_constants)
for block in blocks:
b.receive_block(block)
assert b.get_current_heads()[0].height == 100
blocks_reorg_chain = bt.get_consecutive_blocks(30, 5, 16, 9, blocks[:90], 1)
blocks_reorg_chain = bt.get_consecutive_blocks(test_constants, 30, blocks[:90], 9, uint64(1))
for reorg_block in blocks_reorg_chain:
result = b.receive_block(reorg_block)
if reorg_block.height < 90:
@ -223,23 +212,15 @@ class TestReorgs():
assert b.get_current_heads()[0].height == 119
def test_reorg_from_genesis(self):
blocks = bt.get_consecutive_blocks(20, 5, 16, 9, [], 0)
blocks = bt.get_consecutive_blocks(test_constants, 20, [], 9, uint64(0))
b: Blockchain = Blockchain({
"GENESIS_BLOCK": blocks[0].serialize(),
"DIFFICULTY_STARTING": 5,
"DISCRIMINANT_SIZE_BITS": 16,
"BLOCK_TIME_TARGET": 10,
"DIFFICULTY_EPOCH": 12, # The number of blocks per epoch
"DIFFICULTY_WARP_FACTOR": 4, # DELAY divides EPOCH in order to warp efficiently.
"DIFFICULTY_DELAY": 3 # EPOCH / WARP_FACTOR
})
b: Blockchain = Blockchain(test_constants)
for block in blocks:
b.receive_block(block)
assert b.get_current_heads()[0].height == 20
# Reorg from genesis
blocks_reorg_chain = bt.get_consecutive_blocks(21, 5, 16, 9, [blocks[0]], 1)
blocks_reorg_chain = bt.get_consecutive_blocks(test_constants, 21, [blocks[0]], 9, uint64(1))
for reorg_block in blocks_reorg_chain:
result = b.receive_block(reorg_block)
if reorg_block.height == 0:
@ -251,7 +232,7 @@ class TestReorgs():
assert b.get_current_heads()[0].height == 21
# Reorg back to original branch
blocks_reorg_chain_2 = bt.get_consecutive_blocks(3, 5, 16, 9, blocks, 3)
blocks_reorg_chain_2 = bt.get_consecutive_blocks(test_constants, 3, blocks, 9, uint64(3))
b.receive_block(blocks_reorg_chain_2[20]) == ReceiveBlockResult.ADDED_AS_ORPHAN
assert b.receive_block(blocks_reorg_chain_2[21]) == ReceiveBlockResult.ADDED_TO_HEAD
assert b.receive_block(blocks_reorg_chain_2[22]) == ReceiveBlockResult.ADDED_TO_HEAD