From ef3aa2731cfe40b3298c957adfe80e454e4835ab Mon Sep 17 00:00:00 2001 From: Sagar Dhawan Date: Tue, 29 Oct 2019 17:04:11 -0700 Subject: [PATCH] Fix Weighted Best calculation (#6606) automerge --- core/src/weighted_shuffle.rs | 37 ++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/core/src/weighted_shuffle.rs b/core/src/weighted_shuffle.rs index d56cb49ef..28f40660f 100644 --- a/core/src/weighted_shuffle.rs +++ b/core/src/weighted_shuffle.rs @@ -9,25 +9,26 @@ use std::ops::Div; /// Returns a list of indexes shuffled based on the input weights /// Note - The sum of all weights must not exceed `u64::MAX` -pub fn weighted_shuffle<T>(weights: Vec<T>, rng: ChaChaRng) -> Vec<usize> +pub fn weighted_shuffle<T>(weights: Vec<T>, mut rng: ChaChaRng) -> Vec<usize> where T: Copy + PartialOrd + iter::Sum + Div<T, Output = T> + FromPrimitive + ToPrimitive, { - let mut rng = rng; let total_weight: T = weights.clone().into_iter().sum(); weights .into_iter() .enumerate() .map(|(i, v)| { + // This generates an "inverse" weight but it avoids floating point math let x = (total_weight / v) .to_u64() .expect("values > u64::max are not supported"); ( i, // capture the u64 into u128s to prevent overflow - (&mut rng).gen_range(1, u128::from(std::u16::MAX)) * u128::from(x), + rng.gen_range(1, u128::from(std::u16::MAX)) * u128::from(x), ) }) + // sort in ascending order .sorted_by(|(_, l_val), (_, r_val)| l_val.cmp(r_val)) .map(|x| x.0) .collect() @@ -35,22 +36,23 @@ where /// Returns the highest index after computing a weighted shuffle. /// Saves doing any sorting for O(n) max calculation. 
-pub fn weighted_best(weights_and_indicies: &[(u64, usize)], rng: ChaChaRng) -> usize { - let mut rng = rng; - if weights_and_indicies.is_empty() { +pub fn weighted_best(weights_and_indexes: &[(u64, usize)], mut rng: ChaChaRng) -> usize { + if weights_and_indexes.is_empty() { return 0; } - let total_weight: u64 = weights_and_indicies.iter().map(|x| x.0).sum(); - let mut best_weight = 0; + let total_weight: u64 = weights_and_indexes.iter().map(|x| x.0).sum(); + let mut lowest_weight = std::u128::MAX; let mut best_index = 0; - for v in weights_and_indicies { + for v in weights_and_indexes { + // This generates an "inverse" weight but it avoids floating point math let x = (total_weight / v.0) .to_u64() .expect("values > u64::max are not supported"); // capture the u64 into u128s to prevent overflow - let weight = (&mut rng).gen_range(1, u128::from(std::u16::MAX)) * u128::from(x); - if weight > best_weight { - best_weight = weight; + let computed_weight = rng.gen_range(1, u128::from(std::u16::MAX)) * u128::from(x); + // The highest input weight maps to the lowest computed weight + if computed_weight < lowest_weight { + lowest_weight = computed_weight; best_index = v.1; } } @@ -120,9 +122,12 @@ mod tests { #[test] fn test_weighted_best() { - let mut weights = vec![(std::u32::MAX as u64, 0); 3]; - weights.push((1, 5)); - let best = weighted_best(&weights, ChaChaRng::from_seed([0x5b; 32])); - assert_eq!(best, 5); + let weights_and_indexes: Vec<_> = vec![100u64, 1000, 10_000, 10] + .into_iter() + .enumerate() + .map(|(i, weight)| (weight, i)) + .collect(); + let best_index = weighted_best(&weights_and_indexes, ChaChaRng::from_seed([0x5b; 32])); + assert_eq!(best_index, 2); } }