simplify overflow checks for multiplication and division

This commit is contained in:
Trevor Spiteri 2019-12-07 22:40:09 +01:00
parent 950f9051f7
commit be35c826da
3 changed files with 69 additions and 105 deletions

View File

@ -22,7 +22,6 @@ use crate::{
FixedU8, FixedU8,
}; };
use core::{ use core::{
cmp::Ordering,
iter::{Product, Sum}, iter::{Product, Sum},
ops::{ ops::{
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div,
@ -197,8 +196,8 @@ macro_rules! fixed_arith {
type Output = $Fixed<Frac>; type Output = $Fixed<Frac>;
#[inline] #[inline]
fn mul(self, rhs: $Fixed<Frac>) -> $Fixed<Frac> { fn mul(self, rhs: $Fixed<Frac>) -> $Fixed<Frac> {
let (ans, dir) = self.to_bits().mul_dir(rhs.to_bits(), Frac::U32); let (ans, overflow) = self.to_bits().mul_overflow(rhs.to_bits(), Frac::U32);
debug_assert!(dir == Ordering::Equal, "overflow"); debug_assert!(!overflow, "overflow");
Self::from_bits(ans) Self::from_bits(ans)
} }
} }
@ -218,8 +217,8 @@ macro_rules! fixed_arith {
type Output = $Fixed<Frac>; type Output = $Fixed<Frac>;
#[inline] #[inline]
fn div(self, rhs: $Fixed<Frac>) -> $Fixed<Frac> { fn div(self, rhs: $Fixed<Frac>) -> $Fixed<Frac> {
let (ans, dir) = self.to_bits().div_dir(rhs.to_bits(), Frac::U32); let (ans, overflow) = self.to_bits().div_overflow(rhs.to_bits(), Frac::U32);
debug_assert!(dir == Ordering::Equal, "overflow"); debug_assert!(!overflow, "overflow");
Self::from_bits(ans) Self::from_bits(ans)
} }
} }
@ -482,51 +481,37 @@ fixed_arith! { FixedI32(i32, LeEqU32, 32), Signed }
fixed_arith! { FixedI64(i64, LeEqU64, 64), Signed } fixed_arith! { FixedI64(i64, LeEqU64, 64), Signed }
fixed_arith! { FixedI128(i128, LeEqU128, 128), Signed } fixed_arith! { FixedI128(i128, LeEqU128, 128), Signed }
pub(crate) trait MulDivDir: Sized { pub(crate) trait MulDivOverflow: Sized {
fn mul_dir(self, rhs: Self, frac_nbits: u32) -> (Self, Ordering); fn mul_overflow(self, rhs: Self, frac_nbits: u32) -> (Self, bool);
fn div_dir(self, rhs: Self, frac_nbits: u32) -> (Self, Ordering); fn div_overflow(self, rhs: Self, frac_nbits: u32) -> (Self, bool);
} }
macro_rules! mul_div_widen { macro_rules! mul_div_widen {
($Single:ty, $Double:ty, $Signedness:tt) => { ($Single:ty, $Double:ty, $Signedness:tt) => {
impl MulDivDir for $Single { impl MulDivOverflow for $Single {
#[inline] #[inline]
fn mul_dir(self, rhs: $Single, frac_nbits: u32) -> ($Single, Ordering) { fn mul_overflow(self, rhs: $Single, frac_nbits: u32) -> ($Single, bool) {
const NBITS: u32 = <$Single>::NBITS; const NBITS: u32 = <$Single>::NBITS;
let int_nbits: u32 = NBITS - frac_nbits; let int_nbits: u32 = NBITS - frac_nbits;
let lhs2 = <$Double>::from(self); let lhs2 = <$Double>::from(self);
let rhs2 = <$Double>::from(rhs) << int_nbits; let rhs2 = <$Double>::from(rhs) << int_nbits;
let (prod2, overflow) = lhs2.overflowing_mul(rhs2); let (prod2, overflow) = lhs2.overflowing_mul(rhs2);
let dir; ((prod2 >> NBITS) as $Single, overflow)
if_unsigned! {
$Signedness;
dir = if !overflow {
Ordering::Equal
} else {
Ordering::Less
};
}
if_signed! {
$Signedness;
dir = if !overflow {
Ordering::Equal
} else if (self < 0) == (rhs < 0) {
Ordering::Less
} else {
Ordering::Greater
};
}
((prod2 >> NBITS) as $Single, dir)
} }
#[inline] #[inline]
fn div_dir(self, rhs: $Single, frac_nbits: u32) -> ($Single, Ordering) { fn div_overflow(self, rhs: $Single, frac_nbits: u32) -> ($Single, bool) {
const NBITS: u32 = <$Single>::NBITS;
let lhs2 = <$Double>::from(self) << frac_nbits; let lhs2 = <$Double>::from(self) << frac_nbits;
let rhs2 = <$Double>::from(rhs); let rhs2 = <$Double>::from(rhs);
let quot2 = lhs2 / rhs2; let quot2 = lhs2 / rhs2;
let quot = quot2 as $Single; let quot = quot2 as $Single;
let dir = <$Double>::from(quot).cmp(&quot2); let overflow = if_signed_unsigned! {
(quot, dir) $Signedness,
quot2 >> NBITS != if quot < 0 { -1 } else { 0 },
quot2 >> NBITS != 0
};
(quot, overflow)
} }
} }
}; };
@ -537,7 +522,7 @@ trait FallbackHelper: Sized {
fn hi_lo(self) -> (Self, Self); fn hi_lo(self) -> (Self, Self);
fn shift_lo_up(self) -> Self; fn shift_lo_up(self) -> Self;
fn shift_lo_up_unsigned(self) -> Self::Unsigned; fn shift_lo_up_unsigned(self) -> Self::Unsigned;
fn combine_lo_then_shl(self, lo: Self::Unsigned, shift: u32) -> (Self, Ordering); fn combine_lo_then_shl(self, lo: Self::Unsigned, shift: u32) -> (Self, bool);
fn carrying_add(self, other: Self) -> (Self, Self); fn carrying_add(self, other: Self) -> (Self, Self);
} }
@ -561,15 +546,15 @@ impl FallbackHelper for u128 {
} }
#[inline] #[inline]
fn combine_lo_then_shl(self, lo: u128, shift: u32) -> (u128, Ordering) { fn combine_lo_then_shl(self, lo: u128, shift: u32) -> (u128, bool) {
if shift == 128 { if shift == 128 {
(self, Ordering::Equal) (self, false)
} else if shift == 0 { } else if shift == 0 {
(lo, 0.cmp(&self)) (lo, self != 0)
} else { } else {
let lo = lo >> shift; let lo = lo >> shift;
let hi = self << (128 - shift); let hi = self << (128 - shift);
(lo | hi, 0.cmp(&(self >> shift))) (lo | hi, self >> shift != 0)
} }
} }
@ -601,17 +586,17 @@ impl FallbackHelper for i128 {
} }
#[inline] #[inline]
fn combine_lo_then_shl(self, lo: u128, shift: u32) -> (i128, Ordering) { fn combine_lo_then_shl(self, lo: u128, shift: u32) -> (i128, bool) {
if shift == 128 { if shift == 128 {
(self, Ordering::Equal) (self, false)
} else if shift == 0 { } else if shift == 0 {
let ans = lo as i128; let ans = lo as i128;
(ans, (ans >> 64 >> 64).cmp(&self)) (ans, self != if ans < 0 { -1 } else { 0 })
} else { } else {
let lo = (lo >> shift) as i128; let lo = (lo >> shift) as i128;
let hi = self << (128 - shift); let hi = self << (128 - shift);
let ans = lo | hi; let ans = lo | hi;
(ans, (ans >> 64 >> 64).cmp(&(self >> shift))) (ans, self >> shift != if ans < 0 { -1 } else { 0 })
} }
} }
@ -633,25 +618,11 @@ impl FallbackHelper for i128 {
macro_rules! mul_div_fallback { macro_rules! mul_div_fallback {
($Single:ty, $Uns:ty, $Signedness:tt) => { ($Single:ty, $Uns:ty, $Signedness:tt) => {
impl MulDivDir for $Single { impl MulDivOverflow for $Single {
#[inline] #[inline]
fn mul_dir(self, rhs: $Single, frac_nbits: u32) -> ($Single, Ordering) { fn mul_overflow(self, rhs: $Single, frac_nbits: u32) -> ($Single, bool) {
if frac_nbits == 0 { if frac_nbits == 0 {
let (ans, overflow) = self.overflowing_mul(rhs); self.overflowing_mul(rhs)
let dir = if !overflow {
Ordering::Equal
} else {
if_signed_unsigned! {
$Signedness,
if (self < 0) == (rhs < 0) {
Ordering::Less
} else {
Ordering::Greater
},
Ordering::Less,
}
};
(ans, dir)
} else { } else {
let (lh, ll) = self.hi_lo(); let (lh, ll) = self.hi_lo();
let (rh, rl) = rhs.hi_lo(); let (rh, rl) = rhs.hi_lo();
@ -672,33 +643,20 @@ macro_rules! mul_div_fallback {
} }
#[inline] #[inline]
fn div_dir(self, rhs: $Single, frac_nbits: u32) -> ($Single, Ordering) { fn div_overflow(self, rhs: $Single, frac_nbits: u32) -> ($Single, bool) {
if frac_nbits == 0 { if frac_nbits == 0 {
let (ans, overflow) = self.overflowing_div(rhs); self.overflowing_div(rhs)
let dir = if !overflow {
Ordering::Equal
} else {
if_signed_unsigned! {
$Signedness,
if (self < 0) == (rhs < 0) {
Ordering::Less
} else {
Ordering::Greater
},
Ordering::Less,
}
};
(ans, dir)
} else { } else {
const NBITS: u32 = <$Single>::NBITS; const NBITS: u32 = <$Single>::NBITS;
let lhs2 = (self >> (NBITS - frac_nbits), (self << frac_nbits) as $Uns); let lhs2 = (self >> (NBITS - frac_nbits), (self << frac_nbits) as $Uns);
let (quot2, _) = rhs.div_rem_from(lhs2); let (quot2, _) = rhs.div_rem_from(lhs2);
let quot = quot2.1 as $Single; let quot = quot2.1 as $Single;
let quot2_ret = (quot >> (NBITS / 2) >> (NBITS - NBITS / 2), quot2.1); let overflow = if_signed_unsigned! {
let dir = (quot2_ret.0) $Signedness,
.cmp(&quot2.0) quot2.0 != if quot < 0 { -1 } else { 0 },
.then((quot2_ret.1).cmp(&quot2.1)); quot2.0 != 0
(quot, dir) };
(quot, overflow)
} }
} }
} }

View File

@ -256,7 +256,7 @@ pub mod types;
mod wide_div; mod wide_div;
mod wrapping; mod wrapping;
use crate::{ use crate::{
arith::MulDivDir, arith::MulDivOverflow,
from_str::FromStrRadix, from_str::FromStrRadix,
traits::{FromFixed, ToFixed}, traits::{FromFixed, ToFixed},
types::extra::{LeEqU128, LeEqU16, LeEqU32, LeEqU64, LeEqU8}, types::extra::{LeEqU128, LeEqU16, LeEqU32, LeEqU64, LeEqU8},

View File

@ -153,10 +153,9 @@ assert_eq!(Fix::max_value().checked_mul(Fix::from_num(2)), None);
"; ";
#[inline] #[inline]
pub fn checked_mul(self, rhs: $Fixed<Frac>) -> Option<$Fixed<Frac>> { pub fn checked_mul(self, rhs: $Fixed<Frac>) -> Option<$Fixed<Frac>> {
let (ans, dir) = self.to_bits().mul_dir(rhs.to_bits(), Frac::U32); match self.to_bits().mul_overflow(rhs.to_bits(), Frac::U32) {
match dir { (ans, false) => Some(Self::from_bits(ans)),
Ordering::Equal => Some(Self::from_bits(ans)), (_, true) => None,
_ => None,
} }
} }
} }
@ -181,10 +180,9 @@ assert_eq!(Fix::max_value().checked_div(Fix::from_num(1) / 2), None);
if rhs.to_bits() == 0 { if rhs.to_bits() == 0 {
return None; return None;
} }
let (ans, dir) = self.to_bits().div_dir(rhs.to_bits(), Frac::U32); match self.to_bits().div_overflow(rhs.to_bits(), Frac::U32) {
match dir { (ans, false) => Some(Self::from_bits(ans)),
Ordering::Equal => Some(Self::from_bits(ans)), (_, true) => None,
_ => None,
} }
} }
} }
@ -203,11 +201,15 @@ assert_eq!(Fix::max_value().saturating_mul(Fix::from_num(2)), Fix::max_value());
"; ";
#[inline] #[inline]
pub fn saturating_mul(self, rhs: $Fixed<Frac>) -> $Fixed<Frac> { pub fn saturating_mul(self, rhs: $Fixed<Frac>) -> $Fixed<Frac> {
let (ans, dir) = self.to_bits().mul_dir(rhs.to_bits(), Frac::U32); match self.to_bits().mul_overflow(rhs.to_bits(), Frac::U32) {
match dir { (ans, false) => Self::from_bits(ans),
Ordering::Equal => Self::from_bits(ans), (_, true) => {
Ordering::Less => Self::max_value(), if (self < 0) != (rhs < 0) {
Ordering::Greater => Self::min_value(), Self::min_value()
} else {
Self::max_value()
}
}
} }
} }
} }
@ -231,11 +233,15 @@ assert_eq!(Fix::max_value().saturating_div(one_half), Fix::max_value());
"; ";
#[inline] #[inline]
pub fn saturating_div(self, rhs: $Fixed<Frac>) -> $Fixed<Frac> { pub fn saturating_div(self, rhs: $Fixed<Frac>) -> $Fixed<Frac> {
let (ans, dir) = self.to_bits().div_dir(rhs.to_bits(), Frac::U32); match self.to_bits().div_overflow(rhs.to_bits(), Frac::U32) {
match dir { (ans, false) => Self::from_bits(ans),
Ordering::Equal => Self::from_bits(ans), (_, true) => {
Ordering::Less => Self::max_value(), if (self < 0) != (rhs < 0) {
Ordering::Greater => Self::min_value(), Self::min_value()
} else {
Self::max_value()
}
}
} }
} }
} }
@ -255,7 +261,7 @@ assert_eq!(Fix::max_value().wrapping_mul(Fix::from_num(4)), wrapped);
"; ";
#[inline] #[inline]
pub fn wrapping_mul(self, rhs: $Fixed<Frac>) -> $Fixed<Frac> { pub fn wrapping_mul(self, rhs: $Fixed<Frac>) -> $Fixed<Frac> {
let (ans, _) = self.to_bits().mul_dir(rhs.to_bits(), Frac::U32); let (ans, _) = self.to_bits().mul_overflow(rhs.to_bits(), Frac::U32);
Self::from_bits(ans) Self::from_bits(ans)
} }
} }
@ -281,7 +287,7 @@ assert_eq!(Fix::max_value().wrapping_div(quarter), wrapped);
"; ";
#[inline] #[inline]
pub fn wrapping_div(self, rhs: $Fixed<Frac>) -> $Fixed<Frac> { pub fn wrapping_div(self, rhs: $Fixed<Frac>) -> $Fixed<Frac> {
let (ans, _) = self.to_bits().div_dir(rhs.to_bits(), Frac::U32); let (ans, _) = self.to_bits().div_overflow(rhs.to_bits(), Frac::U32);
Self::from_bits(ans) Self::from_bits(ans)
} }
} }
@ -307,8 +313,8 @@ assert_eq!(Fix::max_value().overflowing_mul(Fix::from_num(4)), (wrapped, true));
"; ";
#[inline] #[inline]
pub fn overflowing_mul(self, rhs: $Fixed<Frac>) -> ($Fixed<Frac>, bool) { pub fn overflowing_mul(self, rhs: $Fixed<Frac>) -> ($Fixed<Frac>, bool) {
let (ans, dir) = self.to_bits().mul_dir(rhs.to_bits(), Frac::U32); let (ans, overflow) = self.to_bits().mul_overflow(rhs.to_bits(), Frac::U32);
(Self::from_bits(ans), dir != Ordering::Equal) (Self::from_bits(ans), overflow)
} }
} }
@ -336,8 +342,8 @@ assert_eq!(Fix::max_value().overflowing_div(quarter), (wrapped, true));
"; ";
#[inline] #[inline]
pub fn overflowing_div(self, rhs: $Fixed<Frac>) -> ($Fixed<Frac>, bool) { pub fn overflowing_div(self, rhs: $Fixed<Frac>) -> ($Fixed<Frac>, bool) {
let (ans, dir) = self.to_bits().div_dir(rhs.to_bits(), Frac::U32); let (ans, overflow) = self.to_bits().div_overflow(rhs.to_bits(), Frac::U32);
(Self::from_bits(ans), dir != Ordering::Equal) (Self::from_bits(ans), overflow)
} }
} }
} }