639 lines
19 KiB
C++
639 lines
19 KiB
C++
template<class int_type> int_type add_carry(int_type a, int_type b, int carry_in, int& carry_out) {
|
|
assert(carry_in==0 || carry_in==1);
|
|
uint128 res=uint128(a) + uint128(b) + uint128(carry_in);
|
|
|
|
carry_out=int(res >> (sizeof(int_type)*8));
|
|
assert(carry_out==0 || carry_out==1);
|
|
|
|
return int_type(res);
|
|
}
|
|
|
|
template<class int_type> int_type sub_carry(int_type a, int_type b, int carry_in, int& carry_out) {
|
|
assert(carry_in==0 || carry_in==1);
|
|
uint128 res=uint128(a) - uint128(b) - uint128(carry_in);
|
|
|
|
carry_out=int(res >> (sizeof(int_type)*8)) & 1;
|
|
assert(carry_out==0 || carry_out==1);
|
|
|
|
return int_type(res);
|
|
}
|
|
|
|
//counts the leading zero bits of a 32-bit or 64-bit value
//the builtins are undefined for 0, so that case returns the full width (32 or 64) directly
template<class int_type> int clz(int_type a) {
    assert(sizeof(int_type)==4 || sizeof(int_type)==8);

    const bool is_32_bit=(sizeof(int_type)==4);

    if (a!=0) {
        return is_32_bit? __builtin_clz(uint32(a)) : __builtin_clzll(uint64(a));
    }

    return is_32_bit? 32 : 64;
}
|
|
|
|
//returns the upper 64 bits of the 128-bit product a*b
uint64 mul_high(uint64 a, uint64 b) {
    const uint128 product=uint128(a)*uint128(b);
    return uint64(product>>64);
}
|
|
|
|
//returns the upper 32 bits of the 64-bit product a*b
uint32 mul_high(uint32 a, uint32 b) {
    const uint64 product=uint64(a)*uint64(b);
    return uint32(product>>32);
}
|
|
|
|
//compile-time maximum of two ints; returns b when the two are equal
constexpr int max_constexpr(int a, int b) {
    return (a>b)? a : b;
}
|
|
|
|
//all "=" operators truncate ; all operators that return a separate result will pad the result as necessary
|
|
template<class type, int size> struct fixed_integer {
|
|
static const type positive_sign=0;
|
|
static const type negative_sign=~type(0);
|
|
|
|
type data[size+1]; //little endian; sign is first
|
|
|
|
fixed_integer() {
|
|
for (int x=0;x<size+1;++x) {
|
|
data[x]=0;
|
|
}
|
|
}
|
|
|
|
fixed_integer(const integer& i) : fixed_integer() {
|
|
assert(i.num_bits()<=size*sizeof(type)*8);
|
|
if (i<0) {
|
|
data[0]=negative_sign;
|
|
}
|
|
|
|
mpz_export(data+1, nullptr, -1, sizeof(type), -1, 0, i.impl);
|
|
}
|
|
|
|
//converts to an arbitrary precision integer
operator integer() const {
    //import the size limbs as a magnitude: least significant word first, little endian, no nail bits
    integer magnitude;
    mpz_import(magnitude.impl, size, -1, sizeof(type), -1, 0, data+1);

    const bool negate=(data[0]==negative_sign);
    if (negate) {
        magnitude=-magnitude;
    }

    return magnitude;
}
|
|
|
|
//named wrapper around the integer conversion
USED integer to_integer() const {
    return static_cast<integer>(*this);
}
|
|
|
|
//truncation
|
|
//copies the sign word and limbs of t, zero padding if t is narrower and dropping t's high limbs if it is wider
template<int t_size> explicit fixed_integer(fixed_integer<type, t_size> t) {
    const int source_words=t_size+1; //sign word plus limbs

    for (int x=0;x<size+1;++x) {
        if (x<source_words) {
            data[x]=t.data[x];
        } else {
            data[x]=0;
        }
    }
}
|
|
|
|
//assigns from an arbitrary precision integer by converting it first
fixed_integer& operator=(const integer& v) {
    fixed_integer converted(v);
    return *this=converted;
}
|
|
//assigns from a fixed_integer of another width, resizing it with the truncating constructor first
template<int t_size> fixed_integer& operator=(fixed_integer<type, t_size> t) {
    fixed_integer resized(t);
    return *this=resized;
}
|
|
|
|
bool is_negative() const {
|
|
return !is_zero() && data[0]==negative_sign;
|
|
}
|
|
|
|
//overwrites the sign word; the limbs are not touched
void set_negative(bool t_negative) {
    if (t_negative) {
        data[0]=negative_sign;
    } else {
        data[0]=positive_sign;
    }
}
|
|
|
|
//mutable access to limb pos, where limb 0 is the least significant. pos must be in [0, size)
type& operator[](int pos) {
    assert(0<=pos && pos<size);

    type* const limbs=data+1; //skip the sign word
    return limbs[pos];
}
|
|
|
|
//read-only access to limb pos, where limb 0 is the least significant. pos must be in [0, size)
const type& operator[](int pos) const {
    assert(0<=pos && pos<size);

    const type* const limbs=data+1; //skip the sign word
    return limbs[pos];
}
|
|
|
|
//the result is -1 if a<b, 0 if a==b, and 1 if a>b
|
|
//there is also a fast comparison in the add function, but it has a slow path
|
|
static int compare(
|
|
const type* a, int size_a, type sign_a,
|
|
const type* b, int size_b, type sign_b
|
|
) {
|
|
int carry=0;
|
|
type zero=0;
|
|
|
|
//this calculates |a|-|b|. all of the resulted are or'ed together in zero
|
|
for (int x=0;x<max(size_a, size_b);++x) {
|
|
type v_a=(x<size_a)? a[x] : 0;
|
|
type v_b=(x<size_b)? b[x] : 0;
|
|
|
|
zero|=sub_carry(v_a, v_b, carry, carry);
|
|
}
|
|
|
|
//if the final carry is 1, |a|<|b|
|
|
//if the final carry is 0 and zero==0, |a|==|b| (|a|-|b| is 0)
|
|
//if the final carry is 0 and zero!=0, |a|>|b| (|a|-|b| is positive)
|
|
|
|
//same sign, positive: use res
|
|
//same sign, negative: negate res
|
|
//opposite signs: use res if 0, otherwise 1 if sign_a is positive, -1 if sign_a is negative
|
|
|
|
int res=0;
|
|
if (zero!=0) res=1;
|
|
if (carry==1) res=-1;
|
|
|
|
//todo //get rid of branches
|
|
//this is used to implement exactly one comparison with a binary result, so that should get rid of all of these branches
|
|
if (sign_a==sign_b) {
|
|
if (sign_a==negative_sign) {
|
|
res=-res;
|
|
}
|
|
} else {
|
|
if (res!=0) {
|
|
res=(sign_a==negative_sign)? -1 : 1;
|
|
}
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
//three-way comparison of this value with b, which may have a different number of limbs
//returns -1 if *this<b, 0 if they are equal, and 1 if *this>b
template<int b_size> int compare(fixed_integer<type, b_size> b) const {
    //b's limb count is b_size, not size: using size would read past b.data when b is narrower and ignore
    //b's high limbs when it is wider
    return compare(
        data+1, size, data[0],
        b.data+1, b_size, b.data[0]
    );
}
|
|
|
|
//a, b, and res can alias with each other but only if the pointers are equal
|
|
//the sign is not present in a/b/res
|
|
static void add(
|
|
const type* a, int size_a, type sign_a,
|
|
const type* b, int size_b, type sign_b,
|
|
type* res, int size_res, type& sign_res
|
|
) {
|
|
if (size_b>size_a) {
|
|
swap(a, b);
|
|
swap(size_a, size_b);
|
|
swap(sign_a, sign_b);
|
|
}
|
|
|
|
assert(size_res>=size_a && size_a>=size_b && size_b>=1);
|
|
|
|
type mask=sign_a ^ sign_b; //all 1s if opposite signs, else all 0s. this isn't affected by swapping
|
|
|
|
type swap_mask=positive_sign;
|
|
|
|
if (size_a==size_b) {
|
|
//carry flag
|
|
int size_ab=size_a;
|
|
bool a_less_than_b=a[size_ab-1]<b[size_ab-1];
|
|
if (a[size_ab-1]==b[size_ab-1] && size_ab>=2) {
|
|
a_less_than_b=a[size_ab-2]<b[size_ab-2];
|
|
}
|
|
|
|
const type* tmp=b;
|
|
if (a_less_than_b) b=a; //CMOVB
|
|
if (a_less_than_b) a=tmp; //CMOVB
|
|
|
|
if (a_less_than_b) sign_a=sign_b; //CMOVB
|
|
//sign_b isn't used anymore
|
|
sign_b=0;
|
|
|
|
//if (a_less_than_b) swap_mask=negative_sign; //CMOVB
|
|
}
|
|
|
|
int carry;
|
|
add_carry(mask, type(1), 0, carry); //carry set if opposite signs, else cleared
|
|
|
|
//if the ints were swapped, size_a==size_b
|
|
for (int x=0;x<size_res;++x) {
|
|
type v_a=(x<size_a)? a[x] : 0;
|
|
type v_b=(x<size_b)? b[x] : 0;
|
|
|
|
//print(x, v_a, v_b, mask, carry);
|
|
|
|
//this calculates a-b if they had opposite signs, or a+b if they had the same sign
|
|
res[x]=add_carry(v_a, v_b^mask, carry, carry);
|
|
}
|
|
|
|
//print(carry, "===");
|
|
|
|
//the final sign is a's sign since it has a higher magnitude than b
|
|
//however, if a subtraction was done and a and b were swapped, then this should be negated
|
|
sign_res=sign_a^(swap_mask & mask);
|
|
|
|
//todo //figure out how often this happens
|
|
//a subtraction was done and there was a carry out. since the subtraction is unsigned, this means it was done in the wrong order
|
|
//this almost never happens if the numbers are random and don't have excessive padding
|
|
//the subtraction was done in the wrong order if the result is negative
|
|
//the result is negative if each input were padded with 0, and the result limb was ~0 instead of 0
|
|
//the result limb is: add_carry(0, mask, carry, carry);
|
|
//carry in is 0: result is all 1s (bad)
|
|
//carry in is 1: result is all 0s and carry out is 1 (good)
|
|
//need to check for a carry out of 0 then, not 1
|
|
|
|
if (carry==0 && mask!=0) {
|
|
carry=0;
|
|
for (int x=0;x<size_res;++x) {
|
|
//print(x, ~res[x], type((x==0)? 1 : 0), carry);
|
|
|
|
//calculate the two's complement of the result
|
|
res[x]=add_carry(~res[x], type((x==0)? 1 : 0), carry, carry);
|
|
}
|
|
|
|
//print(carry, "===");
|
|
|
|
//todo print("slow add");
|
|
//assert(false);
|
|
|
|
//negate the sign since the subtraction order was flipped
|
|
sign_res=~sign_res;
|
|
}
|
|
}
|
|
|
|
//unary minus: returns a copy with every bit of the sign word inverted; the limbs are unchanged
fixed_integer operator-() const {
    fixed_integer negated=*this;
    negated.data[0]=~negated.data[0];
    return negated;
}
|
|
|
|
//in-place addition of a value of the same width; the result is truncated to size limbs
void operator+=(fixed_integer b) {
    type* const limbs=data+1;
    const type* const b_limbs=b.data+1;

    //the destination aliases the first operand, which add allows because the pointers are equal
    add(
        limbs, size, data[0],
        b_limbs, size, b.data[0],
        limbs, size, data[0]
    );
}
|
|
|
|
//in-place subtraction of a value of the same width, done as an addition of b with its sign word inverted
//the result is truncated to size limbs
void operator-=(fixed_integer b) {
    type* const limbs=data+1;
    const type* const b_limbs=b.data+1;
    const type flipped_sign=negative_sign^b.data[0];

    add(
        limbs, size, data[0],
        b_limbs, size, flipped_sign,
        limbs, size, data[0]
    );
}
|
|
|
|
template<int b_size>
|
|
fixed_integer<type, max_constexpr(size, b_size)+1> operator+(
|
|
fixed_integer<type, b_size> b
|
|
) const {
|
|
const int output_size=max_constexpr(size, b_size)+1;
|
|
|
|
fixed_integer<type, output_size> res;
|
|
|
|
add(
|
|
data+1, size, data[0],
|
|
b.data+1, b_size, b.data[0],
|
|
res.data+1, output_size, res.data[0]
|
|
);
|
|
|
|
return res;
|
|
}
|
|
|
|
template<int b_size>
|
|
fixed_integer<type, max_constexpr(size, b_size)+1> operator-(
|
|
fixed_integer<type, b_size> b
|
|
) const {
|
|
const int output_size=max_constexpr(size, b_size)+1;
|
|
|
|
fixed_integer<type, output_size> res;
|
|
|
|
add(
|
|
data+1, size, data[0],
|
|
b.data+1, b_size, negative_sign^b.data[0],
|
|
res.data+1, output_size, res.data[0]
|
|
);
|
|
|
|
return res;
|
|
}
|
|
|
|
//res=a*b+c
|
|
//res can alias with c if the pointers are equal. can't alias with a
|
|
//if c is null then it is all 0s
|
|
static void mad(
|
|
const type* a, int size_a,
|
|
type b,
|
|
const type* c, int size_c,
|
|
type* res, int size_res
|
|
) {
|
|
assert(size_res>=size_c && size_c>=size_a && size_a>=1);
|
|
|
|
type previous_high=0;
|
|
int carry_mul=0;
|
|
int add_mul=0;
|
|
|
|
for (int x=0;x<size_res;++x) {
|
|
type this_a=(x>=size_a)? 0 : a[x];
|
|
|
|
type this_low=this_a*b;
|
|
type this_high=mul_high(this_a, b);
|
|
|
|
type mul_res=add_carry(this_low, previous_high, carry_mul, carry_mul);
|
|
|
|
if (x==0) {
|
|
assert(mul_res==this_low && carry_mul==0);
|
|
} else
|
|
if (x==size_a) {
|
|
assert(carry_mul==0);
|
|
} else
|
|
if (x>size_a) {
|
|
assert(mul_res==0 && carry_mul==0);
|
|
}
|
|
|
|
type this_c=(x>=size_c || c==nullptr)? 0 : c[x];
|
|
type add_res=add_carry(mul_res, this_c, add_mul, add_mul);
|
|
|
|
res[x]=add_res;
|
|
|
|
previous_high=this_high;
|
|
}
|
|
}
|
|
|
|
//can't overflow
|
|
//two of these can implement a 1024x512 mul. for 1024x1024, need to do 2x 1024x512 in separate buffers then add them
|
|
//res=a*b+c for 8-limb unsigned operands, producing the full 16-limb result
//done as 8 rows of mad, one per limb of b, each accumulating into a window of res that slides up by one limb
static void mad_8x8(array<type, 8> a, array<type, 8> b, array<type, 8> c, array<type, 16>& res) {
    //start with res=c, zero extended to 16 limbs
    for (int x=0;x<8;++x) {
        res[x]=c[x];
    }
    for (int x=8;x<16;++x) {
        res[x]=0;
    }

    for (int x=0;x<8;++x) {
        //this uses a sliding window for the 8 res registers (no spilling)
        //-the lowest register is finished after the first addition in mad. the this_low,previous_high addition is skipped
        //-the highest register does not need to be loaded until the last multiplication in mad. actually this would always load 0
        // so it is not done
        //-the total number of registers is therefore 7
        //there is one register for b
        //the 8 a values are in registers but some or all may be spilled
        //need 2 registers to store the MULX result
        //need 1 register to store the previous high result (this is initially 0)
        //the this_low,previous_high add result goes into one of those registers
        //the mul_res,this_c result goes into the c register
        //total registers is 18 then; 2 are spilled
        //address registers:
        //-will just use a static 32-bit address space for most of the code. can store the stack pointer there then
        //-address registers are only used for b and res if the addresses are not static
        //-the addresses are only used at the end of the loop, so there are spare registers to load the address registers from static
        // memory. probably the addresses will be static though

        //the window is 9 limbs wide: a*b[x] plus the 8 accumulated limbs needs one limb more than a has
        //an 8-limb window would drop the top product half and the final carries, so res[8..15] would never get them
        //res[x+8] is still 0 when row x starts, and x+8<=15 keeps the window inside res
        mad(&a[0], 8, b[x], &res[x], 9, &res[x], 9);
    }
}
|
|
|
|
//multiplies the magnitude by a single limb in place, truncating to size limbs. the sign word is not touched
void operator*=(type v) {
    type* const limbs=data+1;

    //no addend (null c); the source and the destination are the same limbs
    mad(
        limbs, size,
        v,
        nullptr, size,
        limbs, size
    );
}
|
|
|
|
template<int t_size, int this_size>
|
|
static fixed_integer<type, t_size> subset(
|
|
fixed_integer<type, this_size> this_v, int start
|
|
) {
|
|
const int end=start+t_size;
|
|
|
|
fixed_integer<type, t_size> res;
|
|
res.data[0]=this_v.data[0];
|
|
|
|
for (int x=start;x<end;++x) {
|
|
int pos=x-start;
|
|
res[x]=(pos>=0 && pos<this_size)? this_v[x] : 0;
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
void left_shift_limbs(int amount) {
|
|
for (int x=size-1;x>=0;--x) {
|
|
int pos=x-amount;
|
|
(*this)[x] = (pos>=0 && pos<size)? (*this)[pos] : 0;
|
|
}
|
|
}
|
|
|
|
void right_shift_limbs(int amount) {
|
|
for (int x=0;x<size;++x) {
|
|
int pos=x+amount;
|
|
(*this)[x] = (pos>=0 && pos<size)? (*this)[pos] : 0;
|
|
}
|
|
}
|
|
|
|
//shifts the magnitude left by amount bits in place, where 0<=amount<bits per limb. bits shifted out of the top limb are lost
void operator<<=(int amount) {
    //a zero shift is handled separately because the complementary shift below would then be by a full limb width
    //(original note: not sure if intel works with the "previous>>64" statement. might wrap around)
    if (amount==0) {
        return;
    }

    const int bits_per_limb=sizeof(type)*8;
    assert(amount>0 && amount<bits_per_limb);

    const int carry_shift=bits_per_limb-amount;

    //top down, so the limb below is still unshifted when its high bits are pulled in
    for (int x=size-1;x>=0;--x) {
        const type below=(x>0)? (*this)[x-1] : 0;

        type& limb=(*this)[x];
        limb=(limb<<amount) | (below>>carry_shift);
    }
}
|
|
|
|
//shifts the magnitude right by amount bits in place, where 0<=amount<bits per limb. bits shifted out of limb 0 are lost
void operator>>=(int amount) {
    //a zero shift is handled separately because the complementary shift below would then be by a full limb width
    if (amount==0) {
        return;
    }

    const int bits_per_limb=sizeof(type)*8;
    assert(amount>0 && amount<bits_per_limb);

    const int carry_shift=bits_per_limb-amount;

    //bottom up, so the limb above is still unshifted when its low bits are pulled in
    for (int x=0;x<size;++x) {
        const type above=(x+1<size)? (*this)[x+1] : 0;

        type& limb=(*this)[x];
        limb=(limb>>amount) | (above<<carry_shift);
    }
}
|
|
|
|
template<int b_size>
|
|
fixed_integer<type, size+b_size> operator*(
|
|
fixed_integer<type, b_size> b
|
|
) const {
|
|
const int output_size=size+b_size;
|
|
fixed_integer<type, output_size> res;
|
|
|
|
for (int x=0;x<b_size;++x) {
|
|
auto r=subset<output_size>(*this, 0);
|
|
r.data[0]=positive_sign;
|
|
|
|
integer b_x_int(vector<uint64>{b[x]});
|
|
|
|
r*=b[x];
|
|
//auto r2=subset<output_size+2>(r, 0);
|
|
//r2*=b[x];
|
|
//r=r2;
|
|
|
|
integer r_int(r);
|
|
integer this_int(abs(*this));
|
|
integer expected_r_int=this_int*b_x_int;
|
|
assert(r_int==expected_r_int);
|
|
|
|
r.left_shift_limbs(x);
|
|
r_int<<=x*sizeof(type)*8;
|
|
assert(r_int==integer(r));
|
|
|
|
integer res_old_int(res);
|
|
|
|
//todo //figure out why this doesn't work. might have something to do with the msb being set?
|
|
res+=r; //unsigned
|
|
/*auto res3=res;
|
|
res3+=r;
|
|
|
|
auto res2=res+r;
|
|
fixed_integer<type, output_size> res4(res2);*/
|
|
|
|
/*if (integer(res3)!=integer(res4)) {
|
|
print( "========" );
|
|
|
|
res3=res;
|
|
res3+=r;
|
|
|
|
//print( "========" );
|
|
|
|
auto res2_copy=res+r;
|
|
|
|
assert(false);
|
|
}*/
|
|
|
|
//res=res4;
|
|
|
|
integer res_new_int(res);
|
|
|
|
assert(res_new_int==res_old_int+r_int);
|
|
}
|
|
|
|
res.data[0]=data[0] ^ b.data[0];
|
|
return res;
|
|
}
|
|
|
|
//returns the value shifted left by num bits, widened by one limb first so no bits are lost
//num is limited to less than one limb by the in-place shift
fixed_integer<type, size+1> operator<<(int num) const {
    fixed_integer<type, size+1> widened=subset<size+1>(*this, 0);
    widened<<=num;
    return widened;
}
|
|
|
|
//this rounds to 0 so it is different from division unless the input is divisible by 2^num
|
|
//returns the value with its magnitude shifted right by num bits; the sign word is carried over unchanged
//num is limited to less than one limb by the in-place shift
fixed_integer<type, size> operator>>(int num) const {
    fixed_integer<type, size> shifted=subset<size>(*this, 0);
    shifted>>=num;
    return shifted;
}
|
|
|
|
bool is_zero() const {
|
|
for (int x=0;x<size;++x) {
|
|
if (data[x+1]!=0) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
template<int b_size>
|
|
bool operator>=(fixed_integer<type, b_size> b) const {
|
|
return compare(b)>=0;
|
|
}
|
|
|
|
template<int b_size>
|
|
bool operator==(fixed_integer<type, b_size> b) const {
|
|
return compare(b)==0;
|
|
}
|
|
|
|
template<int b_size>
|
|
bool operator<(fixed_integer<type, b_size> b) const {
|
|
return compare(b)<0;
|
|
}
|
|
|
|
template<int b_size>
|
|
bool operator<=(fixed_integer<type, b_size> b) const {
|
|
return compare(b)<=0;
|
|
}
|
|
|
|
template<int b_size>
|
|
bool operator>(fixed_integer<type, b_size> b) const {
|
|
return compare(b)>0;
|
|
}
|
|
|
|
template<int b_size>
|
|
bool operator!=(fixed_integer<type, b_size> b) const {
|
|
return compare(b)!=0;
|
|
}
|
|
|
|
//"0" has 1 bit
|
|
int num_bits() const {
|
|
type v=0;
|
|
int num_full=0;
|
|
|
|
for (int x=size-1;x>=0;--x) {
|
|
if (v==0) {
|
|
v=(*this)[x];
|
|
num_full=x;
|
|
}
|
|
}
|
|
|
|
int v_bits;
|
|
if (v==0) {
|
|
v_bits=1;
|
|
assert(num_full==0);
|
|
} else
|
|
if (sizeof(v)==8) {
|
|
v_bits=64-__builtin_clzll(v);
|
|
} else{
|
|
assert(sizeof(v)==4);
|
|
v_bits=32-__builtin_clz(v);
|
|
}
|
|
|
|
return num_full*sizeof(type)*8 + v_bits;
|
|
}
|
|
|
|
//returns one limb's worth of bits of the magnitude beginning at bit start_bit (bit 0 is the least significant)
//bits beyond the top limb read as 0. start_bit must not be negative
type window(int start_bit) const {
    const int bits_per_limb_log2=(sizeof(type)==8)? 6 : 5;
    const int bits_per_limb=1<<bits_per_limb_log2;

    const int start_limb=start_bit>>bits_per_limb_log2;
    const int start_offset=start_bit&(bits_per_limb-1);

    auto get_limb=[&](int pos) -> type {
        assert(pos>=0);
        return (pos>=size)? type(0) : (*this)[pos];
    };

    const type low_part=get_limb(start_limb)>>start_offset;

    //an aligned window is exactly one limb. returning here avoids the shift by bits_per_limb that the
    //high part would need, which is undefined behaviour even if the result is thrown away
    if (start_offset==0) {
        return low_part;
    }

    //otherwise the upper bits come from the bottom of the next limb
    const type high_part=get_limb(start_limb+1)<<(bits_per_limb-start_offset);

    return low_part | high_part;
}
|
|
};
|
|
|
|
//returns a copy of v with the sign word set to positive; the limbs are unchanged
template<class type, int size> fixed_integer<type, size> abs(fixed_integer<type, size> v) {
    const bool make_negative=false;
    v.set_negative(make_negative);
    return v;
}
|
|
|
|
//repacks a value with 32-bit limbs into 64-bit limbs, two source limbs per destination limb
//an odd trailing source limb is paired with 0
template<int size> fixed_integer<uint64, (size+1)/2> to_uint64(fixed_integer<uint32, size> v) {
    fixed_integer<uint64, (size+1)/2> packed;

    //the sign word is rebuilt at the new width rather than copied. can just make data[0] 64 bits if this is ever a problem
    packed.set_negative(v.is_negative());

    for (int pair=0;pair*2<size;++pair) {
        const int low_index=pair*2;
        const int high_index=low_index+1;

        const uint64 low=v[low_index];
        const uint64 high=(high_index<size)? v[high_index] : 0;

        packed[pair]=(high<<32) | low;
    }

    return packed;
}
|
|
|
|
//repacks a value with 64-bit limbs into 32-bit limbs: each source limb becomes a low half followed by a high half
template<int size> fixed_integer<uint32, size*2> to_uint32(fixed_integer<uint64, size> v) {
    fixed_integer<uint32, size*2> split;

    //the sign word is rebuilt at the new width rather than copied
    split.set_negative(v.is_negative());

    for (int x=0;x<size;++x) {
        const uint64 limb=v[x];

        split[2*x]=uint32(limb);
        split[2*x+1]=uint32(limb>>32);
    }

    return split;
}