implement multiplication for FixedU128, FixedI128

This commit is contained in:
Trevor Spiteri 2018-08-09 23:39:44 +02:00
parent 05f82047b4
commit c847c9e74f
1 changed files with 162 additions and 7 deletions

View File

@ -343,15 +343,132 @@ macro_rules! mul_div_widen {
};
}
trait FallbackHelper: Sized {
type Unsigned;
fn hi_lo(self) -> (Self, Self);
fn shift_lo_up(self) -> Self;
fn shift_lo_up_unsigned(self) -> Self::Unsigned;
fn combine_lo_then_shl(self, lo: Self::Unsigned, shift: u32) -> Self;
fn carrying_add(self, other: Self) -> (Self, Self);
}
impl FallbackHelper for u128 {
type Unsigned = u128;
#[inline]
fn hi_lo(self) -> (u128, u128) {
(self >> 64, self & !(!0 << 64))
}
#[inline]
fn shift_lo_up(self) -> u128 {
debug_assert!(self >> 64 == 0);
self << 64
}
#[inline]
fn shift_lo_up_unsigned(self) -> u128 {
debug_assert!(self >> 64 == 0);
self << 64
}
#[inline]
fn combine_lo_then_shl(self, lo: u128, shift: u32) -> u128 {
if shift == 128 {
return self;
}
if shift == 0 {
assert!(self == 0, "overflow");
return lo;
}
let lo = lo >> shift;
let hi = self << (128 - shift);
assert!(self >> shift == 0, "overflow");
lo | hi
}
#[inline]
fn carrying_add(self, rhs: u128) -> (u128, u128) {
let (sum, overflow) = self.overflowing_add(rhs);
let carry = if overflow { 1 } else { 0 };
(sum, carry)
}
}
impl FallbackHelper for i128 {
type Unsigned = u128;
#[inline]
fn hi_lo(self) -> (i128, i128) {
(self >> 64, self & !(!0 << 64))
}
#[inline]
fn shift_lo_up(self) -> i128 {
debug_assert!(self >> 64 == 0);
self << 64
}
#[inline]
fn shift_lo_up_unsigned(self) -> u128 {
debug_assert!(self >> 64 == 0);
(self << 64) as u128
}
#[inline]
fn combine_lo_then_shl(self, lo: u128, shift: u32) -> i128 {
if shift == 128 {
return self;
}
if shift == 0 {
let ans = lo as i128;
assert!(ans >> 64 >> 64 == self, "overflow");
return ans;
}
let lo = (lo >> shift) as i128;
let hi = self << (128 - shift);
let ans = lo | hi;
assert!(ans >> 64 >> 64 == self >> shift, "overflow");
ans
}
#[inline]
fn carrying_add(self, rhs: i128) -> (i128, i128) {
let (sum, overflow) = self.overflowing_add(rhs);
let carry = if overflow {
if sum < 0 {
1
} else {
-1
}
} else {
0
};
(sum, carry)
}
}
macro_rules! mul_div_fallback {
($Single:ty) => {
impl MulDiv for $Single {
#[inline]
fn mul(self, rhs: $Single) -> $Single {
if F == 0 {
self * rhs
} else {
unimplemented!()
let (lh, ll) = self.hi_lo();
let (rh, rl) = rhs.hi_lo();
let ll_rl = ll.wrapping_mul(rl);
let lh_rl = lh.wrapping_mul(rl);
let ll_rh = ll.wrapping_mul(rh);
let lh_rh = lh.wrapping_mul(rh);
let col01 = ll_rl as <$Single as FallbackHelper>::Unsigned;
let (col12, carry_col3) = lh_rl.carrying_add(ll_rh);
let col23 = lh_rh;
let (col12_hi, col12_lo) = col12.hi_lo();
let col12_lo_up = col12_lo.shift_lo_up_unsigned();
let (ans01, carry_col2) = col01.carrying_add(col12_lo_up);
let carries = carry_col2 as $Single + carry_col3.shift_lo_up();
let ans23 = col23.wrapping_add(carries).wrapping_add(col12_hi);
ans23.combine_lo_then_shl(ans01, F)
}
}
@ -401,20 +518,58 @@ mod tests {
#[test]
fn fixed_i16() {
let a = 12;
let b = -4;
let af = FixedI16::from_bits(a << F);
let bf = FixedI16::from_bits(b << F);
let b = 4;
for &pair in &[(a, b), (a, -b), (-a, b), (-a, -b)] {
let (a, b) = pair;
let af = FixedI16::from_bits(a << F);
let bf = FixedI16::from_bits(b << F);
assert_eq!((af + bf).to_bits(), (a << F) + (b << F));
assert_eq!((af - bf).to_bits(), (a << F) - (b << F));
assert_eq!((af * bf).to_bits(), (a << F) * b);
assert_eq!((af / bf).to_bits(), (a << F) / b);
assert_eq!((af & bf).to_bits(), (a << F) & (b << F));
assert_eq!((af | bf).to_bits(), (a << F) | (b << F));
assert_eq!((af ^ bf).to_bits(), (a << F) ^ (b << F));
assert_eq!((-af).to_bits(), -(a << F));
assert_eq!((!af).to_bits(), !(a << F));
}
}
#[test]
fn fixed_u128() {
let a = 0x0003456789abcdef_0123456789abcdef_u128;
let b = 5;
let af = FixedU128::from_bits(a << F);
let bf = FixedU128::from_bits(b << F);
assert_eq!((af + bf).to_bits(), (a << F) + (b << F));
assert_eq!((af - bf).to_bits(), (a << F) - (b << F));
assert_eq!((af * bf).to_bits(), (a << F) * b);
assert_eq!((af / bf).to_bits(), (a << F) / b);
// assert_eq!((af / bf).to_bits(), (a << F) / b);
assert_eq!((af & bf).to_bits(), (a << F) & (b << F));
assert_eq!((af | bf).to_bits(), (a << F) | (b << F));
assert_eq!((af ^ bf).to_bits(), (a << F) ^ (b << F));
assert_eq!((-af).to_bits(), -(a << F));
assert_eq!((!af).to_bits(), !(a << F));
}
#[test]
fn fixed_i128() {
let a = 0x0003456789abcdef_0123456789abcdef_i128;
let b = 5;
for &pair in &[(a, b), (a, -b), (-a, b), (-a, -b)] {
let (a, b) = pair;
let af = FixedI128::from_bits(a << F);
let bf = FixedI128::from_bits(b << F);
assert_eq!((af + bf).to_bits(), (a << F) + (b << F));
assert_eq!((af - bf).to_bits(), (a << F) - (b << F));
assert_eq!((af * bf).to_bits(), (a << F) * b);
// assert_eq!((af / bf).to_bits(), (a << F) / b);
assert_eq!((af & bf).to_bits(), (a << F) & (b << F));
assert_eq!((af | bf).to_bits(), (a << F) | (b << F));
assert_eq!((af ^ bf).to_bits(), (a << F) ^ (b << F));
assert_eq!((!af).to_bits(), !(a << F));
}
}
#[test]
fn to_f32() {
for u in 0x00..=0xff {