Add weighted shuffle support for values upto u64::MAX (#5151)
automerge
This commit is contained in:
parent
7aad427511
commit
10d85f8366
|
@ -7,6 +7,8 @@ use rand_chacha::ChaChaRng;
|
||||||
use std::iter;
|
use std::iter;
|
||||||
use std::ops::Div;
|
use std::ops::Div;
|
||||||
|
|
||||||
|
/// Returns a list of indexes shuffled based on the input weights
|
||||||
|
/// Note - The sum of all weights must not exceed `u64::MAX`
|
||||||
pub fn weighted_shuffle<T>(weights: Vec<T>, rng: ChaChaRng) -> Vec<usize>
|
pub fn weighted_shuffle<T>(weights: Vec<T>, rng: ChaChaRng) -> Vec<usize>
|
||||||
where
|
where
|
||||||
T: Copy + PartialOrd + iter::Sum + Div<T, Output = T> + FromPrimitive + ToPrimitive,
|
T: Copy + PartialOrd + iter::Sum + Div<T, Output = T> + FromPrimitive + ToPrimitive,
|
||||||
|
@ -17,10 +19,13 @@ where
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(i, v)| {
|
.map(|(i, v)| {
|
||||||
let x = (total_weight / v).to_u32().unwrap();
|
let x = (total_weight / v)
|
||||||
|
.to_u64()
|
||||||
|
.expect("values > u64::max are not supported");
|
||||||
(
|
(
|
||||||
i,
|
i,
|
||||||
(&mut rng).gen_range(1, u64::from(std::u16::MAX)) * u64::from(x),
|
// capture the u64 into u128s to prevent overflow
|
||||||
|
(&mut rng).gen_range(1, u128::from(std::u16::MAX)) * u128::from(x),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
.sorted_by(|(_, l_val), (_, r_val)| l_val.cmp(r_val))
|
.sorted_by(|(_, l_val), (_, r_val)| l_val.cmp(r_val))
|
||||||
|
@ -73,4 +78,18 @@ mod tests {
|
||||||
assert_eq!(x, y);
|
assert_eq!(x, y);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_weighted_shuffle_imbalanced() {
|
||||||
|
let mut weights = vec![std::u32::MAX as u64; 3];
|
||||||
|
weights.push(1);
|
||||||
|
let shuffle = weighted_shuffle(weights.clone(), ChaChaRng::from_seed([0x5a; 32]));
|
||||||
|
shuffle.into_iter().for_each(|x| {
|
||||||
|
if x == weights.len() - 1 {
|
||||||
|
assert_eq!(weights[x], 1);
|
||||||
|
} else {
|
||||||
|
assert_eq!(weights[x], std::u32::MAX as u64);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue