From 10d85f8366bd4c0c5c22f02d91d9be76e9ad7249 Mon Sep 17 00:00:00 2001 From: Sagar Dhawan Date: Wed, 17 Jul 2019 12:44:28 -0700 Subject: [PATCH] Add weighted shuffle support for values upto u64::MAX (#5151) automerge --- core/src/weighted_shuffle.rs | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/core/src/weighted_shuffle.rs b/core/src/weighted_shuffle.rs index 0de0795276..d0d5a249cd 100644 --- a/core/src/weighted_shuffle.rs +++ b/core/src/weighted_shuffle.rs @@ -7,6 +7,8 @@ use rand_chacha::ChaChaRng; use std::iter; use std::ops::Div; +/// Returns a list of indexes shuffled based on the input weights +/// Note - The sum of all weights must not exceed `u64::MAX` pub fn weighted_shuffle(weights: Vec, rng: ChaChaRng) -> Vec where T: Copy + PartialOrd + iter::Sum + Div + FromPrimitive + ToPrimitive, @@ -17,10 +19,13 @@ where .into_iter() .enumerate() .map(|(i, v)| { - let x = (total_weight / v).to_u32().unwrap(); + let x = (total_weight / v) + .to_u64() + .expect("values > u64::max are not supported"); ( i, - (&mut rng).gen_range(1, u64::from(std::u16::MAX)) * u64::from(x), + // capture the u64 into u128s to prevent overflow + (&mut rng).gen_range(1, u128::from(std::u16::MAX)) * u128::from(x), ) }) .sorted_by(|(_, l_val), (_, r_val)| l_val.cmp(r_val)) @@ -73,4 +78,18 @@ mod tests { assert_eq!(x, y); }); } + + #[test] + fn test_weighted_shuffle_imbalanced() { + let mut weights = vec![std::u32::MAX as u64; 3]; + weights.push(1); + let shuffle = weighted_shuffle(weights.clone(), ChaChaRng::from_seed([0x5a; 32])); + shuffle.into_iter().for_each(|x| { + if x == weights.len() - 1 { + assert_eq!(weights[x], 1); + } else { + assert_eq!(weights[x], std::u32::MAX as u64); + } + }); + } }