implement multiplication for FixedU128, FixedI128
This commit is contained in:
parent
05f82047b4
commit
c847c9e74f
169
src/lib.rs
169
src/lib.rs
|
@ -343,15 +343,132 @@ macro_rules! mul_div_widen {
|
|||
};
|
||||
}
|
||||
|
||||
trait FallbackHelper: Sized {
|
||||
type Unsigned;
|
||||
fn hi_lo(self) -> (Self, Self);
|
||||
fn shift_lo_up(self) -> Self;
|
||||
fn shift_lo_up_unsigned(self) -> Self::Unsigned;
|
||||
fn combine_lo_then_shl(self, lo: Self::Unsigned, shift: u32) -> Self;
|
||||
fn carrying_add(self, other: Self) -> (Self, Self);
|
||||
}
|
||||
|
||||
impl FallbackHelper for u128 {
|
||||
type Unsigned = u128;
|
||||
#[inline]
|
||||
fn hi_lo(self) -> (u128, u128) {
|
||||
(self >> 64, self & !(!0 << 64))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn shift_lo_up(self) -> u128 {
|
||||
debug_assert!(self >> 64 == 0);
|
||||
self << 64
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn shift_lo_up_unsigned(self) -> u128 {
|
||||
debug_assert!(self >> 64 == 0);
|
||||
self << 64
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn combine_lo_then_shl(self, lo: u128, shift: u32) -> u128 {
|
||||
if shift == 128 {
|
||||
return self;
|
||||
}
|
||||
if shift == 0 {
|
||||
assert!(self == 0, "overflow");
|
||||
return lo;
|
||||
}
|
||||
let lo = lo >> shift;
|
||||
let hi = self << (128 - shift);
|
||||
assert!(self >> shift == 0, "overflow");
|
||||
lo | hi
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn carrying_add(self, rhs: u128) -> (u128, u128) {
|
||||
let (sum, overflow) = self.overflowing_add(rhs);
|
||||
let carry = if overflow { 1 } else { 0 };
|
||||
(sum, carry)
|
||||
}
|
||||
}
|
||||
|
||||
impl FallbackHelper for i128 {
|
||||
type Unsigned = u128;
|
||||
#[inline]
|
||||
fn hi_lo(self) -> (i128, i128) {
|
||||
(self >> 64, self & !(!0 << 64))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn shift_lo_up(self) -> i128 {
|
||||
debug_assert!(self >> 64 == 0);
|
||||
self << 64
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn shift_lo_up_unsigned(self) -> u128 {
|
||||
debug_assert!(self >> 64 == 0);
|
||||
(self << 64) as u128
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn combine_lo_then_shl(self, lo: u128, shift: u32) -> i128 {
|
||||
if shift == 128 {
|
||||
return self;
|
||||
}
|
||||
if shift == 0 {
|
||||
let ans = lo as i128;
|
||||
assert!(ans >> 64 >> 64 == self, "overflow");
|
||||
return ans;
|
||||
}
|
||||
let lo = (lo >> shift) as i128;
|
||||
let hi = self << (128 - shift);
|
||||
let ans = lo | hi;
|
||||
assert!(ans >> 64 >> 64 == self >> shift, "overflow");
|
||||
ans
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn carrying_add(self, rhs: i128) -> (i128, i128) {
|
||||
let (sum, overflow) = self.overflowing_add(rhs);
|
||||
let carry = if overflow {
|
||||
if sum < 0 {
|
||||
1
|
||||
} else {
|
||||
-1
|
||||
}
|
||||
} else {
|
||||
0
|
||||
};
|
||||
(sum, carry)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! mul_div_fallback {
|
||||
($Single:ty) => {
|
||||
impl MulDiv for $Single {
|
||||
#[inline]
|
||||
fn mul(self, rhs: $Single) -> $Single {
|
||||
if F == 0 {
|
||||
self * rhs
|
||||
} else {
|
||||
unimplemented!()
|
||||
let (lh, ll) = self.hi_lo();
|
||||
let (rh, rl) = rhs.hi_lo();
|
||||
let ll_rl = ll.wrapping_mul(rl);
|
||||
let lh_rl = lh.wrapping_mul(rl);
|
||||
let ll_rh = ll.wrapping_mul(rh);
|
||||
let lh_rh = lh.wrapping_mul(rh);
|
||||
let col01 = ll_rl as <$Single as FallbackHelper>::Unsigned;
|
||||
let (col12, carry_col3) = lh_rl.carrying_add(ll_rh);
|
||||
let col23 = lh_rh;
|
||||
let (col12_hi, col12_lo) = col12.hi_lo();
|
||||
let col12_lo_up = col12_lo.shift_lo_up_unsigned();
|
||||
let (ans01, carry_col2) = col01.carrying_add(col12_lo_up);
|
||||
let carries = carry_col2 as $Single + carry_col3.shift_lo_up();
|
||||
let ans23 = col23.wrapping_add(carries).wrapping_add(col12_hi);
|
||||
|
||||
ans23.combine_lo_then_shl(ans01, F)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -401,20 +518,58 @@ mod tests {
|
|||
#[test]
|
||||
fn fixed_i16() {
|
||||
let a = 12;
|
||||
let b = -4;
|
||||
let af = FixedI16::from_bits(a << F);
|
||||
let bf = FixedI16::from_bits(b << F);
|
||||
let b = 4;
|
||||
for &pair in &[(a, b), (a, -b), (-a, b), (-a, -b)] {
|
||||
let (a, b) = pair;
|
||||
let af = FixedI16::from_bits(a << F);
|
||||
let bf = FixedI16::from_bits(b << F);
|
||||
assert_eq!((af + bf).to_bits(), (a << F) + (b << F));
|
||||
assert_eq!((af - bf).to_bits(), (a << F) - (b << F));
|
||||
assert_eq!((af * bf).to_bits(), (a << F) * b);
|
||||
assert_eq!((af / bf).to_bits(), (a << F) / b);
|
||||
assert_eq!((af & bf).to_bits(), (a << F) & (b << F));
|
||||
assert_eq!((af | bf).to_bits(), (a << F) | (b << F));
|
||||
assert_eq!((af ^ bf).to_bits(), (a << F) ^ (b << F));
|
||||
assert_eq!((-af).to_bits(), -(a << F));
|
||||
assert_eq!((!af).to_bits(), !(a << F));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fixed_u128() {
|
||||
let a = 0x0003456789abcdef_0123456789abcdef_u128;
|
||||
let b = 5;
|
||||
let af = FixedU128::from_bits(a << F);
|
||||
let bf = FixedU128::from_bits(b << F);
|
||||
assert_eq!((af + bf).to_bits(), (a << F) + (b << F));
|
||||
assert_eq!((af - bf).to_bits(), (a << F) - (b << F));
|
||||
assert_eq!((af * bf).to_bits(), (a << F) * b);
|
||||
assert_eq!((af / bf).to_bits(), (a << F) / b);
|
||||
// assert_eq!((af / bf).to_bits(), (a << F) / b);
|
||||
assert_eq!((af & bf).to_bits(), (a << F) & (b << F));
|
||||
assert_eq!((af | bf).to_bits(), (a << F) | (b << F));
|
||||
assert_eq!((af ^ bf).to_bits(), (a << F) ^ (b << F));
|
||||
assert_eq!((-af).to_bits(), -(a << F));
|
||||
assert_eq!((!af).to_bits(), !(a << F));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fixed_i128() {
|
||||
let a = 0x0003456789abcdef_0123456789abcdef_i128;
|
||||
let b = 5;
|
||||
for &pair in &[(a, b), (a, -b), (-a, b), (-a, -b)] {
|
||||
let (a, b) = pair;
|
||||
let af = FixedI128::from_bits(a << F);
|
||||
let bf = FixedI128::from_bits(b << F);
|
||||
assert_eq!((af + bf).to_bits(), (a << F) + (b << F));
|
||||
assert_eq!((af - bf).to_bits(), (a << F) - (b << F));
|
||||
assert_eq!((af * bf).to_bits(), (a << F) * b);
|
||||
// assert_eq!((af / bf).to_bits(), (a << F) / b);
|
||||
assert_eq!((af & bf).to_bits(), (a << F) & (b << F));
|
||||
assert_eq!((af | bf).to_bits(), (a << F) | (b << F));
|
||||
assert_eq!((af ^ bf).to_bits(), (a << F) ^ (b << F));
|
||||
assert_eq!((!af).to_bits(), !(a << F));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_f32() {
|
||||
for u in 0x00..=0xff {
|
||||
|
|
Loading…
Reference in New Issue