Fix Weighted Best calculation (#6606)

automerge
This commit is contained in:
Sagar Dhawan 2019-10-29 17:04:11 -07:00 committed by Grimes
parent e738019c48
commit ef3aa2731c
1 changed files with 21 additions and 16 deletions

View File

@ -9,25 +9,26 @@ use std::ops::Div;
/// Returns a list of indexes shuffled based on the input weights /// Returns a list of indexes shuffled based on the input weights
/// Note - The sum of all weights must not exceed `u64::MAX` /// Note - The sum of all weights must not exceed `u64::MAX`
pub fn weighted_shuffle<T>(weights: Vec<T>, rng: ChaChaRng) -> Vec<usize> pub fn weighted_shuffle<T>(weights: Vec<T>, mut rng: ChaChaRng) -> Vec<usize>
where where
T: Copy + PartialOrd + iter::Sum + Div<T, Output = T> + FromPrimitive + ToPrimitive, T: Copy + PartialOrd + iter::Sum + Div<T, Output = T> + FromPrimitive + ToPrimitive,
{ {
let mut rng = rng;
let total_weight: T = weights.clone().into_iter().sum(); let total_weight: T = weights.clone().into_iter().sum();
weights weights
.into_iter() .into_iter()
.enumerate() .enumerate()
.map(|(i, v)| { .map(|(i, v)| {
// This generates an "inverse" weight but it avoids floating point math
let x = (total_weight / v) let x = (total_weight / v)
.to_u64() .to_u64()
.expect("values > u64::max are not supported"); .expect("values > u64::max are not supported");
( (
i, i,
// capture the u64 into u128s to prevent overflow // capture the u64 into u128s to prevent overflow
(&mut rng).gen_range(1, u128::from(std::u16::MAX)) * u128::from(x), rng.gen_range(1, u128::from(std::u16::MAX)) * u128::from(x),
) )
}) })
// sort in ascending order
.sorted_by(|(_, l_val), (_, r_val)| l_val.cmp(r_val)) .sorted_by(|(_, l_val), (_, r_val)| l_val.cmp(r_val))
.map(|x| x.0) .map(|x| x.0)
.collect() .collect()
@ -35,22 +36,23 @@ where
/// Returns the highest index after computing a weighted shuffle. /// Returns the highest index after computing a weighted shuffle.
/// Saves doing any sorting for O(n) max calculation. /// Saves doing any sorting for O(n) max calculation.
pub fn weighted_best(weights_and_indicies: &[(u64, usize)], rng: ChaChaRng) -> usize { pub fn weighted_best(weights_and_indexes: &[(u64, usize)], mut rng: ChaChaRng) -> usize {
let mut rng = rng; if weights_and_indexes.is_empty() {
if weights_and_indicies.is_empty() {
return 0; return 0;
} }
let total_weight: u64 = weights_and_indicies.iter().map(|x| x.0).sum(); let total_weight: u64 = weights_and_indexes.iter().map(|x| x.0).sum();
let mut best_weight = 0; let mut lowest_weight = std::u128::MAX;
let mut best_index = 0; let mut best_index = 0;
for v in weights_and_indicies { for v in weights_and_indexes {
// This generates an "inverse" weight but it avoids floating point math
let x = (total_weight / v.0) let x = (total_weight / v.0)
.to_u64() .to_u64()
.expect("values > u64::max are not supported"); .expect("values > u64::max are not supported");
// capture the u64 into u128s to prevent overflow // capture the u64 into u128s to prevent overflow
let weight = (&mut rng).gen_range(1, u128::from(std::u16::MAX)) * u128::from(x); let computed_weight = rng.gen_range(1, u128::from(std::u16::MAX)) * u128::from(x);
if weight > best_weight { // The highest input weight maps to the lowest computed weight
best_weight = weight; if computed_weight < lowest_weight {
lowest_weight = computed_weight;
best_index = v.1; best_index = v.1;
} }
} }
@ -120,9 +122,12 @@ mod tests {
#[test] #[test]
fn test_weighted_best() { fn test_weighted_best() {
let mut weights = vec![(std::u32::MAX as u64, 0); 3]; let weights_and_indexes: Vec<_> = vec![100u64, 1000, 10_000, 10]
weights.push((1, 5)); .into_iter()
let best = weighted_best(&weights, ChaChaRng::from_seed([0x5b; 32])); .enumerate()
assert_eq!(best, 5); .map(|(i, weight)| (weight, i))
.collect();
let best_index = weighted_best(&weights_and_indexes, ChaChaRng::from_seed([0x5b; 32]));
assert_eq!(best_index, 2);
} }
} }