simplify overflow checks for multiplication and division

This commit is contained in:
Trevor Spiteri 2019-12-07 22:40:09 +01:00
parent 950f9051f7
commit be35c826da
3 changed files with 69 additions and 105 deletions

View File

@ -22,7 +22,6 @@ use crate::{
FixedU8,
};
use core::{
cmp::Ordering,
iter::{Product, Sum},
ops::{
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div,
@ -197,8 +196,8 @@ macro_rules! fixed_arith {
type Output = $Fixed<Frac>;
#[inline]
fn mul(self, rhs: $Fixed<Frac>) -> $Fixed<Frac> {
let (ans, dir) = self.to_bits().mul_dir(rhs.to_bits(), Frac::U32);
debug_assert!(dir == Ordering::Equal, "overflow");
let (ans, overflow) = self.to_bits().mul_overflow(rhs.to_bits(), Frac::U32);
debug_assert!(!overflow, "overflow");
Self::from_bits(ans)
}
}
@ -218,8 +217,8 @@ macro_rules! fixed_arith {
type Output = $Fixed<Frac>;
#[inline]
fn div(self, rhs: $Fixed<Frac>) -> $Fixed<Frac> {
let (ans, dir) = self.to_bits().div_dir(rhs.to_bits(), Frac::U32);
debug_assert!(dir == Ordering::Equal, "overflow");
let (ans, overflow) = self.to_bits().div_overflow(rhs.to_bits(), Frac::U32);
debug_assert!(!overflow, "overflow");
Self::from_bits(ans)
}
}
@ -482,51 +481,37 @@ fixed_arith! { FixedI32(i32, LeEqU32, 32), Signed }
fixed_arith! { FixedI64(i64, LeEqU64, 64), Signed }
fixed_arith! { FixedI128(i128, LeEqU128, 128), Signed }
pub(crate) trait MulDivDir: Sized {
fn mul_dir(self, rhs: Self, frac_nbits: u32) -> (Self, Ordering);
fn div_dir(self, rhs: Self, frac_nbits: u32) -> (Self, Ordering);
pub(crate) trait MulDivOverflow: Sized {
fn mul_overflow(self, rhs: Self, frac_nbits: u32) -> (Self, bool);
fn div_overflow(self, rhs: Self, frac_nbits: u32) -> (Self, bool);
}
macro_rules! mul_div_widen {
($Single:ty, $Double:ty, $Signedness:tt) => {
impl MulDivDir for $Single {
impl MulDivOverflow for $Single {
#[inline]
fn mul_dir(self, rhs: $Single, frac_nbits: u32) -> ($Single, Ordering) {
fn mul_overflow(self, rhs: $Single, frac_nbits: u32) -> ($Single, bool) {
const NBITS: u32 = <$Single>::NBITS;
let int_nbits: u32 = NBITS - frac_nbits;
let lhs2 = <$Double>::from(self);
let rhs2 = <$Double>::from(rhs) << int_nbits;
let (prod2, overflow) = lhs2.overflowing_mul(rhs2);
let dir;
if_unsigned! {
$Signedness;
dir = if !overflow {
Ordering::Equal
} else {
Ordering::Less
};
}
if_signed! {
$Signedness;
dir = if !overflow {
Ordering::Equal
} else if (self < 0) == (rhs < 0) {
Ordering::Less
} else {
Ordering::Greater
};
}
((prod2 >> NBITS) as $Single, dir)
((prod2 >> NBITS) as $Single, overflow)
}
#[inline]
fn div_dir(self, rhs: $Single, frac_nbits: u32) -> ($Single, Ordering) {
fn div_overflow(self, rhs: $Single, frac_nbits: u32) -> ($Single, bool) {
const NBITS: u32 = <$Single>::NBITS;
let lhs2 = <$Double>::from(self) << frac_nbits;
let rhs2 = <$Double>::from(rhs);
let quot2 = lhs2 / rhs2;
let quot = quot2 as $Single;
let dir = <$Double>::from(quot).cmp(&quot2);
(quot, dir)
let overflow = if_signed_unsigned! {
$Signedness,
quot2 >> NBITS != if quot < 0 { -1 } else { 0 },
quot2 >> NBITS != 0
};
(quot, overflow)
}
}
};
@ -537,7 +522,7 @@ trait FallbackHelper: Sized {
fn hi_lo(self) -> (Self, Self);
fn shift_lo_up(self) -> Self;
fn shift_lo_up_unsigned(self) -> Self::Unsigned;
fn combine_lo_then_shl(self, lo: Self::Unsigned, shift: u32) -> (Self, Ordering);
fn combine_lo_then_shl(self, lo: Self::Unsigned, shift: u32) -> (Self, bool);
fn carrying_add(self, other: Self) -> (Self, Self);
}
@ -561,15 +546,15 @@ impl FallbackHelper for u128 {
}
#[inline]
fn combine_lo_then_shl(self, lo: u128, shift: u32) -> (u128, Ordering) {
fn combine_lo_then_shl(self, lo: u128, shift: u32) -> (u128, bool) {
if shift == 128 {
(self, Ordering::Equal)
(self, false)
} else if shift == 0 {
(lo, 0.cmp(&self))
(lo, self != 0)
} else {
let lo = lo >> shift;
let hi = self << (128 - shift);
(lo | hi, 0.cmp(&(self >> shift)))
(lo | hi, self >> shift != 0)
}
}
@ -601,17 +586,17 @@ impl FallbackHelper for i128 {
}
#[inline]
fn combine_lo_then_shl(self, lo: u128, shift: u32) -> (i128, Ordering) {
fn combine_lo_then_shl(self, lo: u128, shift: u32) -> (i128, bool) {
if shift == 128 {
(self, Ordering::Equal)
(self, false)
} else if shift == 0 {
let ans = lo as i128;
(ans, (ans >> 64 >> 64).cmp(&self))
(ans, self != if ans < 0 { -1 } else { 0 })
} else {
let lo = (lo >> shift) as i128;
let hi = self << (128 - shift);
let ans = lo | hi;
(ans, (ans >> 64 >> 64).cmp(&(self >> shift)))
(ans, self >> shift != if ans < 0 { -1 } else { 0 })
}
}
@ -633,25 +618,11 @@ impl FallbackHelper for i128 {
macro_rules! mul_div_fallback {
($Single:ty, $Uns:ty, $Signedness:tt) => {
impl MulDivDir for $Single {
impl MulDivOverflow for $Single {
#[inline]
fn mul_dir(self, rhs: $Single, frac_nbits: u32) -> ($Single, Ordering) {
fn mul_overflow(self, rhs: $Single, frac_nbits: u32) -> ($Single, bool) {
if frac_nbits == 0 {
let (ans, overflow) = self.overflowing_mul(rhs);
let dir = if !overflow {
Ordering::Equal
} else {
if_signed_unsigned! {
$Signedness,
if (self < 0) == (rhs < 0) {
Ordering::Less
} else {
Ordering::Greater
},
Ordering::Less,
}
};
(ans, dir)
self.overflowing_mul(rhs)
} else {
let (lh, ll) = self.hi_lo();
let (rh, rl) = rhs.hi_lo();
@ -672,33 +643,20 @@ macro_rules! mul_div_fallback {
}
#[inline]
fn div_dir(self, rhs: $Single, frac_nbits: u32) -> ($Single, Ordering) {
fn div_overflow(self, rhs: $Single, frac_nbits: u32) -> ($Single, bool) {
if frac_nbits == 0 {
let (ans, overflow) = self.overflowing_div(rhs);
let dir = if !overflow {
Ordering::Equal
} else {
if_signed_unsigned! {
$Signedness,
if (self < 0) == (rhs < 0) {
Ordering::Less
} else {
Ordering::Greater
},
Ordering::Less,
}
};
(ans, dir)
self.overflowing_div(rhs)
} else {
const NBITS: u32 = <$Single>::NBITS;
let lhs2 = (self >> (NBITS - frac_nbits), (self << frac_nbits) as $Uns);
let (quot2, _) = rhs.div_rem_from(lhs2);
let quot = quot2.1 as $Single;
let quot2_ret = (quot >> (NBITS / 2) >> (NBITS - NBITS / 2), quot2.1);
let dir = (quot2_ret.0)
.cmp(&quot2.0)
.then((quot2_ret.1).cmp(&quot2.1));
(quot, dir)
let overflow = if_signed_unsigned! {
$Signedness,
quot2.0 != if quot < 0 { -1 } else { 0 },
quot2.0 != 0
};
(quot, overflow)
}
}
}

View File

@ -256,7 +256,7 @@ pub mod types;
mod wide_div;
mod wrapping;
use crate::{
arith::MulDivDir,
arith::MulDivOverflow,
from_str::FromStrRadix,
traits::{FromFixed, ToFixed},
types::extra::{LeEqU128, LeEqU16, LeEqU32, LeEqU64, LeEqU8},

View File

@ -153,10 +153,9 @@ assert_eq!(Fix::max_value().checked_mul(Fix::from_num(2)), None);
";
#[inline]
pub fn checked_mul(self, rhs: $Fixed<Frac>) -> Option<$Fixed<Frac>> {
let (ans, dir) = self.to_bits().mul_dir(rhs.to_bits(), Frac::U32);
match dir {
Ordering::Equal => Some(Self::from_bits(ans)),
_ => None,
match self.to_bits().mul_overflow(rhs.to_bits(), Frac::U32) {
(ans, false) => Some(Self::from_bits(ans)),
(_, true) => None,
}
}
}
@ -181,10 +180,9 @@ assert_eq!(Fix::max_value().checked_div(Fix::from_num(1) / 2), None);
if rhs.to_bits() == 0 {
return None;
}
let (ans, dir) = self.to_bits().div_dir(rhs.to_bits(), Frac::U32);
match dir {
Ordering::Equal => Some(Self::from_bits(ans)),
_ => None,
match self.to_bits().div_overflow(rhs.to_bits(), Frac::U32) {
(ans, false) => Some(Self::from_bits(ans)),
(_, true) => None,
}
}
}
@ -203,11 +201,15 @@ assert_eq!(Fix::max_value().saturating_mul(Fix::from_num(2)), Fix::max_value());
";
#[inline]
pub fn saturating_mul(self, rhs: $Fixed<Frac>) -> $Fixed<Frac> {
let (ans, dir) = self.to_bits().mul_dir(rhs.to_bits(), Frac::U32);
match dir {
Ordering::Equal => Self::from_bits(ans),
Ordering::Less => Self::max_value(),
Ordering::Greater => Self::min_value(),
match self.to_bits().mul_overflow(rhs.to_bits(), Frac::U32) {
(ans, false) => Self::from_bits(ans),
(_, true) => {
if (self < 0) != (rhs < 0) {
Self::min_value()
} else {
Self::max_value()
}
}
}
}
}
@ -231,11 +233,15 @@ assert_eq!(Fix::max_value().saturating_div(one_half), Fix::max_value());
";
#[inline]
pub fn saturating_div(self, rhs: $Fixed<Frac>) -> $Fixed<Frac> {
let (ans, dir) = self.to_bits().div_dir(rhs.to_bits(), Frac::U32);
match dir {
Ordering::Equal => Self::from_bits(ans),
Ordering::Less => Self::max_value(),
Ordering::Greater => Self::min_value(),
match self.to_bits().div_overflow(rhs.to_bits(), Frac::U32) {
(ans, false) => Self::from_bits(ans),
(_, true) => {
if (self < 0) != (rhs < 0) {
Self::min_value()
} else {
Self::max_value()
}
}
}
}
}
@ -255,7 +261,7 @@ assert_eq!(Fix::max_value().wrapping_mul(Fix::from_num(4)), wrapped);
";
#[inline]
pub fn wrapping_mul(self, rhs: $Fixed<Frac>) -> $Fixed<Frac> {
let (ans, _) = self.to_bits().mul_dir(rhs.to_bits(), Frac::U32);
let (ans, _) = self.to_bits().mul_overflow(rhs.to_bits(), Frac::U32);
Self::from_bits(ans)
}
}
@ -281,7 +287,7 @@ assert_eq!(Fix::max_value().wrapping_div(quarter), wrapped);
";
#[inline]
pub fn wrapping_div(self, rhs: $Fixed<Frac>) -> $Fixed<Frac> {
let (ans, _) = self.to_bits().div_dir(rhs.to_bits(), Frac::U32);
let (ans, _) = self.to_bits().div_overflow(rhs.to_bits(), Frac::U32);
Self::from_bits(ans)
}
}
@ -307,8 +313,8 @@ assert_eq!(Fix::max_value().overflowing_mul(Fix::from_num(4)), (wrapped, true));
";
#[inline]
pub fn overflowing_mul(self, rhs: $Fixed<Frac>) -> ($Fixed<Frac>, bool) {
let (ans, dir) = self.to_bits().mul_dir(rhs.to_bits(), Frac::U32);
(Self::from_bits(ans), dir != Ordering::Equal)
let (ans, overflow) = self.to_bits().mul_overflow(rhs.to_bits(), Frac::U32);
(Self::from_bits(ans), overflow)
}
}
@ -336,8 +342,8 @@ assert_eq!(Fix::max_value().overflowing_div(quarter), (wrapped, true));
";
#[inline]
pub fn overflowing_div(self, rhs: $Fixed<Frac>) -> ($Fixed<Frac>, bool) {
let (ans, dir) = self.to_bits().div_dir(rhs.to_bits(), Frac::U32);
(Self::from_bits(ans), dir != Ordering::Equal)
let (ans, overflow) = self.to_bits().div_overflow(rhs.to_bits(), Frac::U32);
(Self::from_bits(ans), overflow)
}
}
}