From c847c9e74fe8d8094b185abe27f0d5895113ae5b Mon Sep 17 00:00:00 2001 From: Trevor Spiteri Date: Thu, 9 Aug 2018 23:39:44 +0200 Subject: [PATCH] implement multiplication for FixedU128, FixedI128 --- src/lib.rs | 169 ++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 162 insertions(+), 7 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index e43ac42..174b9af 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -343,15 +343,132 @@ macro_rules! mul_div_widen { }; } +trait FallbackHelper: Sized { + type Unsigned; + fn hi_lo(self) -> (Self, Self); + fn shift_lo_up(self) -> Self; + fn shift_lo_up_unsigned(self) -> Self::Unsigned; + fn combine_lo_then_shl(self, lo: Self::Unsigned, shift: u32) -> Self; + fn carrying_add(self, other: Self) -> (Self, Self); +} + +impl FallbackHelper for u128 { + type Unsigned = u128; + #[inline] + fn hi_lo(self) -> (u128, u128) { + (self >> 64, self & !(!0 << 64)) + } + + #[inline] + fn shift_lo_up(self) -> u128 { + debug_assert!(self >> 64 == 0); + self << 64 + } + + #[inline] + fn shift_lo_up_unsigned(self) -> u128 { + debug_assert!(self >> 64 == 0); + self << 64 + } + + #[inline] + fn combine_lo_then_shl(self, lo: u128, shift: u32) -> u128 { + if shift == 128 { + return self; + } + if shift == 0 { + assert!(self == 0, "overflow"); + return lo; + } + let lo = lo >> shift; + let hi = self << (128 - shift); + assert!(self >> shift == 0, "overflow"); + lo | hi + } + + #[inline] + fn carrying_add(self, rhs: u128) -> (u128, u128) { + let (sum, overflow) = self.overflowing_add(rhs); + let carry = if overflow { 1 } else { 0 }; + (sum, carry) + } +} + +impl FallbackHelper for i128 { + type Unsigned = u128; + #[inline] + fn hi_lo(self) -> (i128, i128) { + (self >> 64, self & !(!0 << 64)) + } + + #[inline] + fn shift_lo_up(self) -> i128 { + debug_assert!(self >> 64 == 0); + self << 64 + } + + #[inline] + fn shift_lo_up_unsigned(self) -> u128 { + debug_assert!(self >> 64 == 0); + (self << 64) as u128 + } + + #[inline] + fn combine_lo_then_shl(self, lo: u128, shift: u32) -> i128 { + if shift == 128 { + return self; + } + if shift == 0 { + let ans = lo as i128; + assert!(ans >> 64 >> 64 == self, "overflow"); + return ans; + } + let lo = (lo >> shift) as i128; + let hi = self << (128 - shift); + let ans = lo | hi; + assert!(ans >> 64 >> 64 == self >> shift, "overflow"); + ans + } + + #[inline] + fn carrying_add(self, rhs: i128) -> (i128, i128) { + let (sum, overflow) = self.overflowing_add(rhs); + let carry = if overflow { + if sum < 0 { + 1 + } else { + -1 + } + } else { + 0 + }; + (sum, carry) + } +} + macro_rules! mul_div_fallback { ($Single:ty) => { impl MulDiv for $Single { - #[inline] fn mul(self, rhs: $Single) -> $Single { if F == 0 { self * rhs } else { - unimplemented!() + let (lh, ll) = self.hi_lo(); + let (rh, rl) = rhs.hi_lo(); + let ll_rl = ll.wrapping_mul(rl); + let lh_rl = lh.wrapping_mul(rl); + let ll_rh = ll.wrapping_mul(rh); + let lh_rh = lh.wrapping_mul(rh); + let col01 = ll_rl as <$Single as FallbackHelper>::Unsigned; + let (col12, carry_col3) = lh_rl.carrying_add(ll_rh); + let col23 = lh_rh; + let (col12_hi, col12_lo) = col12.hi_lo(); + let col12_lo_up = col12_lo.shift_lo_up_unsigned(); + let (ans01, carry_col2) = col01.carrying_add(col12_lo_up); + let carries = carry_col2 as $Single + carry_col3.shift_lo_up(); + let ans23 = col23.wrapping_add(carries).wrapping_add(col12_hi); + + ans23.combine_lo_then_shl(ans01, F) } } @@ -401,20 +518,58 @@ mod tests { #[test] fn fixed_i16() { let a = 12; - let b = -4; - let af = FixedI16::from_bits(a << F); - let bf = FixedI16::from_bits(b << F); + let b = 4; + for &pair in &[(a, b), (a, -b), (-a, b), (-a, -b)] { + let (a, b) = pair; + let af = FixedI16::from_bits(a << F); + let bf = FixedI16::from_bits(b << F); + assert_eq!((af + bf).to_bits(), (a << F) + (b << F)); + assert_eq!((af - bf).to_bits(), (a << F) - (b << F)); + assert_eq!((af * bf).to_bits(), (a << F) * b); + assert_eq!((af / bf).to_bits(), (a << F) / b); + assert_eq!((af & bf).to_bits(), (a << F) & (b << F)); + assert_eq!((af | bf).to_bits(), (a << F) | (b << F)); + assert_eq!((af ^ bf).to_bits(), (a << F) ^ (b << F)); + assert_eq!((-af).to_bits(), -(a << F)); + assert_eq!((!af).to_bits(), !(a << F)); + } + } + + #[test] + fn fixed_u128() { + let a = 0x0003456789abcdef_0123456789abcdef_u128; + let b = 5; + let af = FixedU128::from_bits(a << F); + let bf = FixedU128::from_bits(b << F); assert_eq!((af + bf).to_bits(), (a << F) + (b << F)); assert_eq!((af - bf).to_bits(), (a << F) - (b << F)); assert_eq!((af * bf).to_bits(), (a << F) * b); - assert_eq!((af / bf).to_bits(), (a << F) / b); + // assert_eq!((af / bf).to_bits(), (a << F) / b); assert_eq!((af & bf).to_bits(), (a << F) & (b << F)); assert_eq!((af | bf).to_bits(), (a << F) | (b << F)); assert_eq!((af ^ bf).to_bits(), (a << F) ^ (b << F)); - assert_eq!((-af).to_bits(), -(a << F)); assert_eq!((!af).to_bits(), !(a << F)); } + #[test] + fn fixed_i128() { + let a = 0x0003456789abcdef_0123456789abcdef_i128; + let b = 5; + for &pair in &[(a, b), (a, -b), (-a, b), (-a, -b)] { + let (a, b) = pair; + let af = FixedI128::from_bits(a << F); + let bf = FixedI128::from_bits(b << F); + assert_eq!((af + bf).to_bits(), (a << F) + (b << F)); + assert_eq!((af - bf).to_bits(), (a << F) - (b << F)); + assert_eq!((af * bf).to_bits(), (a << F) * b); + // assert_eq!((af / bf).to_bits(), (a << F) / b); + assert_eq!((af & bf).to_bits(), (a << F) & (b << F)); + assert_eq!((af | bf).to_bits(), (a << F) | (b << F)); + assert_eq!((af ^ bf).to_bits(), (a << F) ^ (b << F)); + assert_eq!((!af).to_bits(), !(a << F)); + } + } + #[test] fn to_f32() { for u in 0x00..=0xff {