From e5b644e83095dcb4435e1823090fe2541f006722 Mon Sep 17 00:00:00 2001 From: Tyera Eulberg Date: Thu, 11 Mar 2021 23:22:40 -0700 Subject: [PATCH] Add trait for saturating arithmetic (#15812) * Add SaturatingArithmetic trait * Use Duration saturating arithmetic * Use new macro to fix poh_config --- sdk/src/arithmetic.rs | 45 +++++++++++++++++++++++++++++ sdk/src/lib.rs | 1 + sdk/src/poh_config.rs | 10 +++---- sdk/src/stake_weighted_timestamp.rs | 16 ++++++---- 4 files changed, 61 insertions(+), 11 deletions(-) create mode 100644 sdk/src/arithmetic.rs diff --git a/sdk/src/arithmetic.rs b/sdk/src/arithmetic.rs new file mode 100644 index 0000000000..8f0be2df46 --- /dev/null +++ b/sdk/src/arithmetic.rs @@ -0,0 +1,45 @@ +use std::time::Duration; + +/// A helper trait for primitive types that do not yet implement saturating arithmetic methods +pub trait SaturatingArithmetic { + fn sol_saturating_add(&self, rhs: Self) -> Self; + fn sol_saturating_sub(&self, rhs: Self) -> Self; + fn sol_saturating_mul(&self, rhs: T) -> Self; +} + +/// Saturating arithmetic for Duration, until Rust support moves from nightly to stable +/// Duration::MAX is constructed manually, as Duration consts are not yet stable either. +impl SaturatingArithmetic for Duration { + fn sol_saturating_add(&self, rhs: Self) -> Self { + self.checked_add(rhs) + .unwrap_or_else(|| Self::new(u64::MAX, 1_000_000_000u32.saturating_sub(1))) + } + fn sol_saturating_sub(&self, rhs: Self) -> Self { + self.checked_sub(rhs).unwrap_or_else(|| Self::new(0, 0)) + } + fn sol_saturating_mul(&self, rhs: u32) -> Self { + self.checked_mul(rhs) + .unwrap_or_else(|| Self::new(u64::MAX, 1_000_000_000u32.saturating_sub(1))) + } +} + +#[cfg(test)] +pub mod tests { + use super::*; + + #[test] + fn test_duration() { + let empty_duration = Duration::new(0, 0); + let max_duration = Duration::new(u64::MAX, 1_000_000_000 - 1); + let duration = Duration::new(u64::MAX, 0); + + let add = duration.sol_saturating_add(duration); + assert_eq!(add, max_duration); + + let sub = duration.sol_saturating_sub(max_duration); + assert_eq!(sub, empty_duration); + + let mult = duration.sol_saturating_mul(u32::MAX); + assert_eq!(mult, max_duration); + } +} diff --git a/sdk/src/lib.rs b/sdk/src/lib.rs index f0b6db4af8..2cc72463e6 100644 --- a/sdk/src/lib.rs +++ b/sdk/src/lib.rs @@ -8,6 +8,7 @@ pub use solana_program::*; pub mod account; pub mod account_utils; +pub mod arithmetic; pub mod builtins; pub mod client; pub mod commitment_config; diff --git a/sdk/src/poh_config.rs b/sdk/src/poh_config.rs index 12e7a583dc..a393365750 100644 --- a/sdk/src/poh_config.rs +++ b/sdk/src/poh_config.rs @@ -1,5 +1,4 @@ -#![allow(clippy::integer_arithmetic)] -use crate::clock::DEFAULT_TICKS_PER_SECOND; +use crate::{clock::DEFAULT_TICKS_PER_SECOND, unchecked_div_by_const}; use std::time::Duration; #[derive(Serialize, Deserialize, Clone, Debug, AbiExample)] @@ -29,8 +28,9 @@ impl PohConfig { impl Default for PohConfig { fn default() -> Self { - Self::new_sleep(Duration::from_micros( - 1000 * 1000 / DEFAULT_TICKS_PER_SECOND, - )) + Self::new_sleep(Duration::from_micros(unchecked_div_by_const!( + 1000 * 1000, + DEFAULT_TICKS_PER_SECOND + ))) } } diff --git a/sdk/src/stake_weighted_timestamp.rs b/sdk/src/stake_weighted_timestamp.rs index 2ef8bf608b..44478e2856 100644 --- a/sdk/src/stake_weighted_timestamp.rs +++ b/sdk/src/stake_weighted_timestamp.rs @@ -1,6 +1,7 @@ /// A helper for calculating a stake-weighted timestamp estimate from a set of timestamps and epoch /// stake. use solana_sdk::{ + arithmetic::SaturatingArithmetic, clock::{Slot, UnixTimestamp}, pubkey::Pubkey, }; @@ -43,7 +44,7 @@ where let mut total_stake: u128 = 0; for (vote_pubkey, slot_timestamp) in unique_timestamps { let (timestamp_slot, timestamp) = slot_timestamp.borrow(); - let offset = slot.saturating_sub(*timestamp_slot) as u32 * slot_duration; + let offset = slot_duration.sol_saturating_mul(slot.saturating_sub(*timestamp_slot) as u32); let estimate = timestamp.saturating_add(offset.as_secs() as i64); let stake = stakes .get(vote_pubkey.borrow()) @@ -70,16 +71,19 @@ where } // Bound estimate by `max_allowable_drift` since the start of the epoch if let Some((epoch_start_slot, epoch_start_timestamp)) = epoch_start_timestamp { - let poh_estimate_offset = slot.saturating_sub(epoch_start_slot) as u32 * slot_duration; + let poh_estimate_offset = + slot_duration.sol_saturating_mul(slot.saturating_sub(epoch_start_slot) as u32); let estimate_offset = Duration::from_secs(if fix_estimate_into_u64 { (estimate as u64).saturating_sub(epoch_start_timestamp as u64) } else { estimate.saturating_sub(epoch_start_timestamp) as u64 }); - let max_allowable_drift_fast = poh_estimate_offset * max_allowable_drift.fast / 100; - let max_allowable_drift_slow = poh_estimate_offset * max_allowable_drift.slow / 100; + let max_allowable_drift_fast = + poh_estimate_offset.sol_saturating_mul(max_allowable_drift.fast) / 100; + let max_allowable_drift_slow = + poh_estimate_offset.sol_saturating_mul(max_allowable_drift.slow) / 100; if estimate_offset > poh_estimate_offset - && estimate_offset - poh_estimate_offset > max_allowable_drift_slow + && estimate_offset.sol_saturating_sub(poh_estimate_offset) > max_allowable_drift_slow { // estimate offset since the start of the epoch is higher than // `MAX_ALLOWABLE_DRIFT_PERCENTAGE_SLOW` @@ -87,7 +91,7 @@ where .saturating_add(poh_estimate_offset.as_secs() as i64) .saturating_add(max_allowable_drift_slow.as_secs() as i64); } else if estimate_offset < poh_estimate_offset - && poh_estimate_offset - estimate_offset > max_allowable_drift_fast + && poh_estimate_offset.sol_saturating_sub(estimate_offset) > max_allowable_drift_fast { // estimate offset since the start of the epoch is lower than // `MAX_ALLOWABLE_DRIFT_PERCENTAGE_FAST`