chia-blockchain/lib/chiavdf/fast_vdf/asm_gcd_base_continued_frac...

375 lines
14 KiB
C++

namespace asm_code {
//2^53-1 as a double; loaded into the "range_check_range" register as the bound used by the range checks
const double range_check_range=static_cast<double>((1ull<<53)-1);
//only bit 63 set; XORing a double's bit pattern with this negates it
const uint64 double_sign_mask=1ull<<63;
//every bit except bit 63; ANDing a double's bit pattern with this gives its absolute value
const uint64 double_abs_mask=~double_sign_mask;
//Emits a range check on the two double lanes of v.
//The generated code jumps to out_of_range_label unless abs(v[0])<=range[0] and abs(v[1])<=range[1].
//v: vector register to test (bound as "tmp"); it is overwritten by the check
//range: vector register holding the per-lane bound
//c_double_abs_mask: vector register that is ANDed into v to take the absolute value
//out_of_range_label: target of the emitted JNZ
//clobbers v
void range_check(
reg_vector v, reg_vector range, reg_vector c_double_abs_mask,
string out_of_range_label
) {
EXPAND_MACROS_SCOPE;
m.bind(range, "range");
m.bind(c_double_abs_mask, "double_abs_mask");
m.bind(v, "tmp");
//tmp=abs(tmp)
APPEND_M(str( "ANDPD `tmp, `double_abs_mask" ));
//tmp all 0s if (abs(tmp0)<=range && abs(tmp1)<=range)
APPEND_M(str( "CMPNLEPD `tmp, `range" ));
//todo //can replace this with POR into an accumulator then use a single PTEST
//todo //can compile the code twice for is_lehmer being true and false, then branch to the appropriate version
//todo //can probably get rid of the uv range checks if is_lehmer is true
//todo //can get rid of the ab range checks if the table is used and each table uv value has a magnitude less than a certain amount
//PTEST of tmp against itself leaves ZF clear when any bit of tmp is set, so the JNZ is taken when either lane failed the comparison
APPEND_M(str( "PTEST `tmp, `tmp" ));
APPEND_M(str( "JNZ #", out_of_range_label ));
}
//Emits code that computes v = a_0*b_0 + a_1*b_1 in each lane, with range checks that jump to out_of_range_label.
//a: the two multiplier registers; only read
//b: the two multiplicand registers; b_0 always receives the first product and is then overwritten by range_check,
// b_1 receives the second product and is range checked only when enable_all_instructions is false
//v: destination register for the sum
//range, c_double_abs_mask: forwarded unchanged to range_check
//out_of_range_label: jump target for every emitted range check
//result_always_in_range: when true the final check of the sum is not emitted
//clobbers b
//this calculates the dot product of each lane separately and puts the result in that lane
void dot_product_exact(
array<reg_vector, 2> a, array<reg_vector, 2> b, reg_vector v, reg_vector range, reg_vector c_double_abs_mask,
string out_of_range_label, bool result_always_in_range=false
) {
EXPAND_MACROS_SCOPE;
m.bind(a, "a");
m.bind(b, "b");
m.bind(v, "v");
//b_0=a_0*b_0, and v starts out as a copy of that first product
APPEND_M(str( "MULPD `b_0, `a_0" ));
APPEND_M(str( "MOVAPD `v, `b_0" ));
//todo //for avx, can get rid of a lot of the MOVs by using the 3-operand versions of the instructions
//the first product is checked in b_0 (which range_check destroys); v keeps the unmodified copy
range_check(b[0], range, c_double_abs_mask, out_of_range_label);
if (enable_all_instructions) {
//v+=b_1*a_1 in one fused instruction; the second product is never stored separately, so it gets no range check of its own
APPEND_M(str( "VFMADD231PD `v, `b_1, `a_1" ));
} else {
//b_1=a_1*b_1; v+=b_1; then the second product is range checked in b_1 (which range_check destroys)
APPEND_M(str( "MULPD `b_1, `a_1" ));
APPEND_M(str( "ADDPD `v, `b_1" ));
range_check(b[1], range, c_double_abs_mask, out_of_range_label);
}
if (!result_always_in_range) {
//check the sum through a copy in b_0 so that v itself is left intact
APPEND_M(str( "MOVAPD `b_0, `v" ));
range_check(b[0], range, c_double_abs_mask, out_of_range_label);
}
}
//Emits the assembly for the base-case gcd loop, which works on doubles held in vector registers.
//Each iteration builds a 2x2 matrix m (loaded from gcd_base_table when the table lookup is usable, otherwise built from floor(q)),
// applies it to ab, u and v with dot_product_exact, and commits the results only if the lehmer test passes or is masked off by is_lehmer.
//regs: allocator that supplies the temporary vector and scalar registers bound below
//ab: vector register <ab[0], ab[1]>; replaced by new_ab after every committed iteration
//u, v: output vector registers; set to <1,0> and <0,1> before the loop and replaced by new_u/new_v after every committed iteration
//is_lehmer: vector mask that is ANDed with the failure bits of the lehmer test
//ab_threshold: vector register compared against ab; the loop exits once a lane of ab is <= its lane of ab_threshold
//no_progress_label: jump target used when the loop exits with loop_counter still equal to gcd_base_max_iter
//The first call also emits the lookup table gcd_base_table into the assembly output.
//ab_threshold is the same for both lanes
//is_lehmer is all 1s if lehmer, else all 0s
//will assign u and v
void gcd_base_continued_fraction(
reg_alloc regs,
reg_vector ab, reg_vector u, reg_vector v, reg_vector is_lehmer, reg_vector ab_threshold,
string no_progress_label
) {
EXPAND_MACROS_SCOPE;
track_asm( "gcd_base" );
//function-local statics: the table is generated once, and outputted_table makes sure its data is emitted only once
static double_table<continued_fraction> c_table=generate_table(gcd_table_num_exponent_bits, gcd_table_num_fraction_bits);
static bool outputted_table=false;
if (!outputted_table) {
//NOTE(review): ".text 1" looks like the GNU as subsection syntax, with the OSX build staying in plain .text - confirm
#ifdef CHIAOSX
APPEND_M(str( ".text " ));
#else
APPEND_M(str( ".text 1" ));
#endif
APPEND_M(str( ".balign 64" ));
APPEND_M(str( "gcd_base_table:" ));
string table_data;
//appends the 64-bit pattern of a double, as hex, to the comma separated operand list of the current .quad line
//the lambda parameter v shadows the v register parameter of the enclosing function
//NOTE(review): the bits are read through a uint64* cast; std::memcpy would avoid the strict-aliasing concern
auto out_double=[&](double v) {
if (!table_data.empty()) {
table_data += ", ";
}
table_data+=to_hex(*(uint64*)&v);
};
//each entry is 32 bytes, 32-aligned
//one .quad line per entry: the first 16 bytes are column 0 of the matrix and the next 16 bytes are column 1,
// matching the two 16-byte loads into m_0 and m_1 inside the loop
for (continued_fraction c : c_table.data) {
matrix2 mat=c.get_matrix();
out_double(mat[0][0]); //lane 0
out_double(mat[1][0]); //lane 1
out_double(mat[0][1]); //lane 0
out_double(mat[1][1]); //lane 1
APPEND_M(str( ".quad #", table_data ));
table_data.clear();
}
//switch back to plain .text for the code that follows
APPEND_M(str( ".text" ));
outputted_table=true;
}
//5x vector
m.bind(ab, "ab");
m.bind(u, "u");
m.bind(v, "v");
m.bind(is_lehmer, "is_lehmer");
m.bind(ab_threshold, "ab_threshold");
//11x vector
reg_vector m_0=regs.bind_vector(m, "m_0");
reg_vector m_1=regs.bind_vector(m, "m_1");
reg_vector new_ab=regs.bind_vector(m, "new_ab");
reg_vector new_ab_1=regs.bind_vector(m, "new_ab_1");
reg_vector tmp=regs.bind_vector(m, "tmp");
reg_vector tmp2=regs.bind_vector(m, "tmp2");
reg_vector new_u=regs.bind_vector(m, "new_u");
reg_vector new_v=regs.bind_vector(m, "new_v");
reg_vector q=regs.bind_vector(m, "q");
reg_vector c_range_check_range=regs.bind_vector(m, "range_check_range");
reg_vector c_double_abs_mask=regs.bind_vector(m, "double_abs_mask");
//5x scalar
reg_scalar q_scalar=regs.bind_scalar(m, "q_scalar");
reg_scalar q_scalar_2=regs.bind_scalar(m, "q_scalar_2");
reg_scalar q_scalar_3=regs.bind_scalar(m, "q_scalar_3");
reg_scalar loop_counter=regs.bind_scalar(m, "loop_counter");
reg_scalar c_table_delta_minus_1=regs.bind_scalar(m, "c_table_delta_minus_1");
//c_table_delta_minus_1 is used below as the mask for the slot boundary test
//NOTE(review): presumably constant_address_uint64 yields a memory operand for the constant pair - confirm
APPEND_M(str( "MOV `c_table_delta_minus_1, #", constant_address_uint64(c_table.delta-1, c_table.delta-1) ));
string exit_label=m.alloc_label();
string loop_label=m.alloc_label();
//loop_counter counts down from gcd_base_max_iter; the exit code compares it with the same value to detect "no progress"
APPEND_M(str( "MOV `loop_counter, #", to_hex(gcd_base_max_iter) ));
//u=<1,0>, v=<0,1>
APPEND_M(str( "MOVAPD `u, #", constant_address_double(1.0, 0.0) ));
APPEND_M(str( "MOVAPD `v, #", constant_address_double(0.0, 1.0) ));
//constants that are passed to every dot_product_exact call
APPEND_M(str( "MOVAPD `range_check_range, #", constant_address_double(range_check_range, range_check_range) ));
APPEND_M(str( "MOVAPD `double_abs_mask, #", constant_address_uint64(double_abs_mask, double_abs_mask) ));
// q[0]=ab[0]/ab[1]
APPEND_M(str( "MOVAPD `tmp, `ab" ));
APPEND_M(str( "SHUFPD `tmp, `tmp, 3" )); // tmp=<ab[1], ab[1]>
APPEND_M(str( "MOVAPD `q, `ab" ));
APPEND_M(str( "DIVSD `q, `tmp" ));
{
APPEND_M(str( "#:", loop_label ));
track_asm( "gcd_base iter" );
string no_table_label=m.alloc_label();
APPEND_M( "#gcd_base loop start" );
//q_scalar=q_scalar_2=to_uint64(ab[0]/ab[1])
//three copies of the bit pattern of q[0]: q_scalar becomes the table offset, q_scalar_2 is used for the slot boundary test,
// and q_scalar_3 is kept untouched so that the no-table path can restore q from it
APPEND_M(str( "MOVQ `q_scalar, `q" ));
APPEND_M(str( "MOV `q_scalar_2, `q_scalar" ));
APPEND_M(str( "MOV `q_scalar_3, `q_scalar" ));
//q_scalar=(to_uint64(ab_0/ab_1)>>c_table.right_shift_amount)<<5
//shifting right by 5 bits less and then clearing the low 5 bits has the same effect; 32 is the size of one table entry in bytes
assert(c_table.right_shift_amount>5);
APPEND_M(str( "SHR `q_scalar, #", to_hex(c_table.right_shift_amount-5) ));
APPEND_M(str( "AND `q_scalar, -32" ));
// q_scalar-=c_table.range_start_shifted<<5
// if (q_scalar<0 || q_scalar>=(c_table.range_end_shifted-c_table.range_start_shifted)<<5) goto no_table_label
//this bypasses the "ab[1]<=ab_threshold" check so we need to do it again in no_table_label
APPEND_M(str( "SUB `q_scalar, #", to_hex(c_table.range_start_shifted<<5) ));
APPEND_M(str( "JB #", track_asm( "gcd_base below table start", no_table_label ) ));
APPEND_M(str( "CMP `q_scalar, #", to_hex((c_table.range_end_shifted-c_table.range_start_shifted)<<5) ));
APPEND_M(str( "JAE #", track_asm( "gcd_base after table end", no_table_label ) ));
//m_0: column 0
//m_1: column 1
//NOTE(review): the OSX variant uses RSI directly instead of a register obtained from regs - confirm RSI is free at this point
#ifdef CHIAOSX
APPEND_M(str( "LEA RSI,[RIP+gcd_base_table]"));
APPEND_M(str( "MOVAPD `m_0, [`q_scalar+RSI]" ));
APPEND_M(str( "MOVAPD `m_1, [16+`q_scalar+RSI]" ));
#else
APPEND_M(str( "MOVAPD `m_0, [gcd_base_table+`q_scalar]" ));
APPEND_M(str( "MOVAPD `m_1, [gcd_base_table+16+`q_scalar]" ));
#endif
//if (ab[1]<=ab_threshold) goto exit_label
//this also tests ab[0], which is >= ab[1] so this does nothing
APPEND_M(str( "MOVAPD `tmp, `ab" ));
APPEND_M(str( "CMPLEPD `tmp, `ab_threshold" )); // tmp all 0s if (ab[0]>ab_threshold[0] && ab[1]>ab_threshold[1])
APPEND_M(str( "PTEST `tmp, `tmp" ));
APPEND_M(str( "JNZ #", track_asm( "gcd_base ab[1]<=ab_threshold", exit_label ) ));
//if ( (q_scalar_2&(c_table.delta-1))==0 || (q_scalar_2&(c_table.delta-1))==c_table.delta-1 ) goto no_table_label
APPEND_M(str( "AND `q_scalar_2, `c_table_delta_minus_1" ));
APPEND_M(str( "JZ #", track_asm( "gcd_base on slot boundary", no_table_label ) ));
APPEND_M(str( "CMP `q_scalar_2, `c_table_delta_minus_1" ));
APPEND_M(str( "JE #", track_asm( "gcd_base on slot boundary", no_table_label ) ));
//emits the code that applies the matrix held in m_0/m_1 to ab, u and v; it is expanded twice, once for each path
//fail_label: jump target of every range check emitted inside
//assigns: new_ab, new_ab_1, q, new_u, new_v
//reads: m, ab, u, v
//clobbers: tmp, tmp2
auto calculate_using_m=[&](string fail_label) {
//tmp=<ab[0], ab[0]>, tmp2=<ab[1], ab[1]>
APPEND_M(str( "MOVAPD `tmp, `ab" ));
APPEND_M(str( "SHUFPD `tmp, `tmp, 0" ));
APPEND_M(str( "MOVAPD `tmp2, `ab" ));
APPEND_M(str( "SHUFPD `tmp2, `tmp2, 3" ));
//new_ab=m_0*ab[0]+m_1*ab[1]; result_always_in_range is passed as true, so the sum gets no final range check
dot_product_exact(
{m_0, m_1}, {tmp, tmp2}, new_ab, c_range_check_range, c_double_abs_mask,
track_asm( "gcd_base ab range check failed", fail_label),
true
);
APPEND_M(str( "MOVAPD `new_ab_1, `new_ab" ));
APPEND_M(str( "SHUFPD `new_ab_1, `new_ab_1, 3" )); // new_ab_1=<new_ab[1], new_ab[1]>
// q[0]=new_ab[0]/new_ab[1]
// this clobbers q if the table is not used
APPEND_M(str( "MOVAPD `q, `new_ab" ));
APPEND_M(str( "DIVSD `q, `new_ab_1" ));
//new_u=m_0*u[0]+m_1*u[1], including the final range check of the sum
APPEND_M(str( "MOVAPD `tmp, `u" ));
APPEND_M(str( "SHUFPD `tmp, `tmp, 0" ));
APPEND_M(str( "MOVAPD `tmp2, `u" ));
APPEND_M(str( "SHUFPD `tmp2, `tmp2, 3" ));
dot_product_exact(
{m_0, m_1}, {tmp, tmp2}, new_u, c_range_check_range, c_double_abs_mask,
track_asm( "gcd_base uv range check failed", fail_label)
);
//todo //for avx, can replace some shuffles with broadcasts. can make a macro that expands to the proper instructions
//new_v=m_0*v[0]+m_1*v[1], including the final range check of the sum
APPEND_M(str( "MOVAPD `tmp, `v" ));
APPEND_M(str( "SHUFPD `tmp, `tmp, 0" ));
APPEND_M(str( "MOVAPD `tmp2, `v" ));
APPEND_M(str( "SHUFPD `tmp2, `tmp2, 3" ));
dot_product_exact(
{m_0, m_1}, {tmp, tmp2}, new_v, c_range_check_range, c_double_abs_mask,
track_asm( "gcd_base uv range check failed", fail_label)
);
};
//table path: a failed range check falls back to the no-table path instead of leaving the loop
calculate_using_m(no_table_label);
//if (new_ab[0]<=ab_threshold) goto no_table_label
APPEND_M(str( "UCOMISD `new_ab, `ab_threshold" ));
APPEND_M(str( "JBE #", track_asm( "gcd_base new_ab[0]<=ab_threshold for table", no_table_label ) ));
//the table result is usable: jump over the no-table code to the lehmer test
string lehmer_label=m.alloc_label();
APPEND_M(str( "JMP #", lehmer_label ));
APPEND_M(str( "#:", no_table_label ));
APPEND_M( "#gcd_base no table" );
{
track_asm( "gcd_base iter no table" );
//have to do this check here because it might have been skipped: if (ab[1]<=ab_threshold) goto exit_label
APPEND_M(str( "MOVAPD `tmp, `ab" ));
APPEND_M(str( "CMPLEPD `tmp, `ab_threshold" )); // tmp all 0s if (ab[0]>ab_threshold[0] && ab[1]>ab_threshold[1])
APPEND_M(str( "PTEST `tmp, `tmp" ));
APPEND_M(str( "JNZ #", track_asm( "gcd_base ab[1]<=ab_threshold", exit_label ) ));
//q is clobbered, so need to restore it
APPEND_M(str( "MOVQ `q, `q_scalar_3" ));
// q=floor(q);
//this requires SSE4. if not present, can also add and subtract a magic number
APPEND_M(str( "ROUNDSD `q, `q, 1" )); //floor
// m=[0  1]
//   [1 -q]
// m_0=<0,1> [column 0]
// m_1=<1,-q> [column 1]
APPEND_M(str( "MOVAPD `m_0, #", constant_address_double(0.0, 1.0) ));
APPEND_M(str( "MOVAPD `m_1, `m_0" )); // m_1=<0,1>
APPEND_M(str( "SUBSD `m_1, `q" )); //m_1=<-q,1>
APPEND_M(str( "SHUFPD `m_1, `m_1, 1" )); //m_1=<1,-q>
//no-table path: there is nothing left to fall back to, so a failed range check leaves the loop
calculate_using_m(exit_label);
}
APPEND_M(str( "#:", lehmer_label ));
APPEND_M( "#gcd_base end no table" );
// new_ab_0=<new_ab[0], new_ab[0]>
// new_ab_1=<new_ab[1], new_ab[1]>
// ab_delta=new_ab_0-new_ab_1
// new_uv_0=<new_u[0], new_v[0]>
// new_uv_1=<new_u[1], new_v[1]>
//bool passed=
// new_ab_1[0]>=-new_uv_1[0] && ab_delta[0]+new_uv_0[0]>=new_uv_1[0] &&
// new_ab_1[1]>=-new_uv_1[1] && ab_delta[1]+new_uv_0[1]>=new_uv_1[1]
//;
//bool passed=
// new_ab_1[0]>=-new_uv_1[0] && ab_delta[0]+new_vu_0[0]>=new_vu_1[0] &&
// new_ab_1[1]>=-new_uv_1[1] && ab_delta[1]+new_vu_0[1]>=new_vu_1[1]
//;
//bool passed=
// new_ab[1]>=-new_u[1] && ab_delta[0]+new_v[0]>=new_v[1] &&
// new_ab[1]>=-new_v[1] && ab_delta[0]+new_u[0]>=new_u[1]
//;
//m_0 and m_1 are no longer needed as the matrix here, so they are reused as scratch registers
//m_0=new_uv_0=<new_u[0], new_v[0]>
APPEND_M(str( "MOVAPD `m_0, `new_u" ));
APPEND_M(str( "SHUFPD `m_0, `new_v, 0" ));
//m_1=new_uv_1=<new_u[1], new_v[1]>
APPEND_M(str( "MOVAPD `m_1, `new_u" ));
APPEND_M(str( "SHUFPD `m_1, `new_v, 3" ));
//tmp=new_ab_0=<new_ab[0], new_ab[0]>
APPEND_M(str( "MOVAPD `tmp, `new_ab" ));
APPEND_M(str( "SHUFPD `tmp, `tmp, 0" ));
//tmp=ab_delta=new_ab_0-new_ab_1
APPEND_M(str( "SUBPD `tmp, `new_ab_1" ));
//tmp=ab_delta+new_uv_0
APPEND_M(str( "ADDPD `tmp, `m_0" ));
//tmp all 0s if (ab_delta[0]+new_uv_0[0]>=new_uv_1[0] && ab_delta[1]+new_uv_0[1]>=new_uv_1[1])
APPEND_M(str( "CMPLTPD `tmp, `m_1" ));
//m_1=-new_uv_1
APPEND_M(str( "XORPD `m_1, #", constant_address_uint64(double_sign_mask, double_sign_mask) ));
//new_ab_1 all 0s if (new_ab_1[0]>=-new_uv_1[0] && new_ab_1[1]>=-new_uv_1[1])
APPEND_M(str( "CMPLTPD `new_ab_1, `m_1" ));
//if (is_lehmer && !(ab_delta[0]+new_uv_0[0]>=new_uv_1[0] && ab_delta[1]+new_uv_0[1]>=new_uv_1[1])) goto exit_label
//if (is_lehmer && !(new_ab_1[0]>=-new_uv_1[0] && new_ab_1[1]>=-new_uv_1[1])) goto exit_label
APPEND_M(str( "ORPD `tmp, `new_ab_1" )); //tmp all 0s if passed is true
APPEND_M(str( "ANDPD `tmp, `is_lehmer" )); //tmp all 0s if passed||(!is_lehmer) is true
APPEND_M(str( "PTEST `tmp, `tmp" ));
APPEND_M(str( "JNZ #", track_asm( "gcd_base lehmer failed", exit_label ) ));
//the iteration is accepted: commit the new values
APPEND_M(str( "MOVAPD `ab, `new_ab" ));
APPEND_M(str( "MOVAPD `u, `new_u" ));
APPEND_M(str( "MOVAPD `v, `new_v" ));
track_asm( "gcd_base good iter" );
//run another iteration unless loop_counter has reached 0
APPEND_M(str( "DEC `loop_counter" ));
APPEND_M(str( "JNZ #", loop_label ));
APPEND_M( "#gcd_base loop end" );
}
//this track_asm is placed before exit_label, so it is only reached by falling out of the loop when loop_counter hits 0
track_asm( "gcd_base good exit" );
APPEND_M(str( "#:", exit_label ));
//loop_counter is only decremented after a committed iteration, so an unchanged counter means ab, u and v were never updated
APPEND_M(str( "CMP `loop_counter, #", to_hex(gcd_base_max_iter) ));
APPEND_M(str( "JE #", track_asm( "gcd_base no progress", no_progress_label ) ));
}
}