2019-06-01 07:55:43 -07:00
|
|
|
//! The `weighted_shuffle` module provides an iterator over shuffled weights.
|
|
|
|
|
2021-05-26 08:15:46 -07:00
|
|
|
use {
|
|
|
|
itertools::Itertools,
|
2021-07-07 07:14:43 -07:00
|
|
|
num_traits::{CheckedAdd, FromPrimitive, ToPrimitive},
|
|
|
|
rand::{
|
|
|
|
distributions::uniform::{SampleUniform, UniformSampler},
|
|
|
|
Rng, SeedableRng,
|
|
|
|
},
|
2021-05-26 08:15:46 -07:00
|
|
|
rand_chacha::ChaChaRng,
|
2021-07-07 07:14:43 -07:00
|
|
|
std::{
|
2021-07-21 11:15:08 -07:00
|
|
|
borrow::Borrow,
|
2021-07-07 07:14:43 -07:00
|
|
|
iter,
|
|
|
|
ops::{AddAssign, Div, Sub, SubAssign},
|
|
|
|
},
|
2021-05-26 08:15:46 -07:00
|
|
|
};
|
2019-06-01 07:55:43 -07:00
|
|
|
|
2021-07-07 07:14:43 -07:00
|
|
|
/// Errors returned by [`WeightedShuffle::new`].
#[derive(Debug)]
pub enum WeightedShuffleError<T> {
    /// A weight failed the `weight >= zero` check; carries the offending
    /// weight. This also covers values that do not compare with zero at all,
    /// such as floating point NaNs.
    NegativeWeight(T),
    /// Adding up the weights overflowed the weight type.
    SumOverflow,
}

impl<T: std::fmt::Debug> std::fmt::Display for WeightedShuffleError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NegativeWeight(weight) => write!(f, "negative weight: {:?}", weight),
            Self::SumOverflow => f.write_str("sum of weights overflows"),
        }
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and friends.
impl<T: std::fmt::Debug> std::error::Error for WeightedShuffleError<T> {}
|
|
|
|
|
|
|
|
/// Implements an iterator where indices are shuffled according to their
/// weights:
///   - Returned indices are unique in the range [0, weights.len()).
///   - Higher weighted indices tend to appear earlier proportional to their
///     weight.
///   - Zero weighted indices are shuffled and appear only at the end, after
///     non-zero weighted indices.
///
/// The type is `Clone` so that a tree built once can be shuffled several
/// times, since shuffling consumes the value.
#[derive(Clone, Debug)]
pub struct WeightedShuffle<T> {
    /// Underlying array implementing binary indexed tree.
    arr: Vec<T>,
    /// Current sum of weights, excluding already selected indices.
    sum: T,
    /// Indices of zero weighted entries.
    zeros: Vec<usize>,
}
|
|
|
|
|
|
|
|
// The implementation uses binary indexed tree:
|
|
|
|
// https://en.wikipedia.org/wiki/Fenwick_tree
|
|
|
|
// to maintain cumulative sum of weights excluding already selected indices
|
|
|
|
// over self.arr.
|
2022-02-01 07:27:23 -08:00
|
|
|
impl<T> WeightedShuffle<T>
|
2021-07-07 07:14:43 -07:00
|
|
|
where
|
|
|
|
T: Copy + Default + PartialOrd + AddAssign + CheckedAdd,
|
|
|
|
{
|
|
|
|
/// Returns error if:
|
|
|
|
/// - any of the weights are negative.
|
|
|
|
/// - sum of weights overflows.
|
2022-02-01 07:27:23 -08:00
|
|
|
pub fn new(weights: &[T]) -> Result<Self, WeightedShuffleError<T>> {
|
2021-07-07 07:14:43 -07:00
|
|
|
let size = weights.len() + 1;
|
|
|
|
let zero = <T as Default>::default();
|
|
|
|
let mut arr = vec![zero; size];
|
|
|
|
let mut sum = zero;
|
2022-01-31 08:23:50 -08:00
|
|
|
let mut zeros = Vec::default();
|
2021-07-07 07:14:43 -07:00
|
|
|
for (mut k, &weight) in (1usize..).zip(weights) {
|
|
|
|
#[allow(clippy::neg_cmp_op_on_partial_ord)]
|
|
|
|
// weight < zero does not work for NaNs.
|
|
|
|
if !(weight >= zero) {
|
|
|
|
return Err(WeightedShuffleError::NegativeWeight(weight));
|
|
|
|
}
|
2022-01-31 08:23:50 -08:00
|
|
|
if weight == zero {
|
|
|
|
zeros.push(k - 1);
|
|
|
|
continue;
|
|
|
|
}
|
2021-07-07 07:14:43 -07:00
|
|
|
sum = sum
|
|
|
|
.checked_add(&weight)
|
|
|
|
.ok_or(WeightedShuffleError::SumOverflow)?;
|
|
|
|
while k < size {
|
|
|
|
arr[k] += weight;
|
|
|
|
k += k & k.wrapping_neg();
|
|
|
|
}
|
|
|
|
}
|
2022-02-01 07:27:23 -08:00
|
|
|
Ok(Self { arr, sum, zeros })
|
2021-07-07 07:14:43 -07:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-02-01 07:27:23 -08:00
|
|
|
impl<T> WeightedShuffle<T>
|
2021-07-07 07:14:43 -07:00
|
|
|
where
|
|
|
|
T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub<Output = T>,
|
|
|
|
{
|
|
|
|
// Returns cumulative sum of current weights upto index k (inclusive).
|
|
|
|
fn cumsum(&self, mut k: usize) -> T {
|
|
|
|
let mut out = <T as Default>::default();
|
|
|
|
while k != 0 {
|
|
|
|
out += self.arr[k];
|
|
|
|
k ^= k & k.wrapping_neg();
|
|
|
|
}
|
|
|
|
out
|
|
|
|
}
|
|
|
|
|
|
|
|
// Removes given weight at index k.
|
|
|
|
fn remove(&mut self, mut k: usize, weight: T) {
|
|
|
|
self.sum -= weight;
|
|
|
|
let size = self.arr.len();
|
|
|
|
while k < size {
|
|
|
|
self.arr[k] -= weight;
|
|
|
|
k += k & k.wrapping_neg();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Returns smallest index such that self.cumsum(k) > val,
|
|
|
|
// along with its respective weight.
|
|
|
|
fn search(&self, val: T) -> (/*index:*/ usize, /*weight:*/ T) {
|
|
|
|
let zero = <T as Default>::default();
|
|
|
|
debug_assert!(val >= zero);
|
|
|
|
debug_assert!(val < self.sum);
|
|
|
|
let mut lo = (/*index:*/ 0, /*cumsum:*/ zero);
|
|
|
|
let mut hi = (self.arr.len() - 1, self.sum);
|
|
|
|
while lo.0 + 1 < hi.0 {
|
|
|
|
let k = lo.0 + (hi.0 - lo.0) / 2;
|
|
|
|
let sum = self.cumsum(k);
|
|
|
|
if sum <= val {
|
|
|
|
lo = (k, sum);
|
|
|
|
} else {
|
|
|
|
hi = (k, sum);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
debug_assert!(lo.1 <= val);
|
|
|
|
debug_assert!(hi.1 > val);
|
|
|
|
(hi.0, hi.1 - lo.1)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-02-01 07:27:23 -08:00
|
|
|
impl<'a, T: 'a> WeightedShuffle<T>
|
2021-07-07 07:14:43 -07:00
|
|
|
where
|
|
|
|
T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub<Output = T>,
|
|
|
|
{
|
2022-02-01 07:27:23 -08:00
|
|
|
pub fn shuffle<R: Rng>(mut self, rng: &'a mut R) -> impl Iterator<Item = usize> + 'a {
|
|
|
|
std::iter::from_fn(move || {
|
|
|
|
let zero = <T as Default>::default();
|
|
|
|
if self.sum > zero {
|
|
|
|
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, rng);
|
|
|
|
let (index, weight) = WeightedShuffle::search(&self, sample);
|
|
|
|
self.remove(index, weight);
|
|
|
|
return Some(index - 1);
|
|
|
|
}
|
|
|
|
if self.zeros.is_empty() {
|
|
|
|
return None;
|
|
|
|
}
|
|
|
|
let index =
|
|
|
|
<usize as SampleUniform>::Sampler::sample_single(0usize, self.zeros.len(), rng);
|
|
|
|
Some(self.zeros.swap_remove(index))
|
|
|
|
})
|
2021-07-07 07:14:43 -07:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-10-14 08:09:36 -07:00
|
|
|
// Equivalent to WeightedShuffle(rng, weights).unwrap().next().
|
|
|
|
pub fn weighted_sample_single<R: Rng, T>(rng: &mut R, cumulative_weights: &[T]) -> Option<usize>
|
|
|
|
where
|
|
|
|
T: Copy + Default + PartialOrd + SampleUniform,
|
|
|
|
{
|
|
|
|
let zero = <T as Default>::default();
|
|
|
|
let high = cumulative_weights.last().copied().unwrap_or_default();
|
2022-01-31 08:23:50 -08:00
|
|
|
if high == zero {
|
|
|
|
if cumulative_weights.is_empty() {
|
|
|
|
return None;
|
|
|
|
}
|
|
|
|
let index =
|
|
|
|
<usize as SampleUniform>::Sampler::sample_single(0usize, cumulative_weights.len(), rng);
|
|
|
|
return Some(index);
|
2021-10-14 08:09:36 -07:00
|
|
|
}
|
|
|
|
let sample = <T as SampleUniform>::Sampler::sample_single(zero, high, rng);
|
|
|
|
let mut lo = 0usize;
|
|
|
|
let mut hi = cumulative_weights.len() - 1;
|
|
|
|
while lo + 1 < hi {
|
|
|
|
let k = lo + (hi - lo) / 2;
|
|
|
|
if cumulative_weights[k] <= sample {
|
|
|
|
lo = k;
|
|
|
|
} else {
|
|
|
|
hi = k;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if cumulative_weights[lo] > sample {
|
|
|
|
Some(lo)
|
|
|
|
} else {
|
|
|
|
Some(hi)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2019-07-17 12:44:28 -07:00
|
|
|
/// Returns a list of indexes shuffled based on the input weights
|
|
|
|
/// Note - The sum of all weights must not exceed `u64::MAX`
|
2021-07-21 11:15:08 -07:00
|
|
|
pub fn weighted_shuffle<T, B, F>(weights: F, seed: [u8; 32]) -> Vec<usize>
|
2019-06-01 07:55:43 -07:00
|
|
|
where
|
|
|
|
T: Copy + PartialOrd + iter::Sum + Div<T, Output = T> + FromPrimitive + ToPrimitive,
|
2021-07-21 11:15:08 -07:00
|
|
|
B: Borrow<T>,
|
|
|
|
F: Iterator<Item = B> + Clone,
|
2019-06-01 07:55:43 -07:00
|
|
|
{
|
2021-07-21 11:15:08 -07:00
|
|
|
let total_weight: T = weights.clone().map(|x| *x.borrow()).sum();
|
2019-10-29 21:02:11 -07:00
|
|
|
let mut rng = ChaChaRng::from_seed(seed);
|
2019-06-01 07:55:43 -07:00
|
|
|
weights
|
|
|
|
.enumerate()
|
2021-07-21 11:15:08 -07:00
|
|
|
.map(|(i, weight)| {
|
|
|
|
let weight = weight.borrow();
|
2019-10-29 17:04:11 -07:00
|
|
|
// This generates an "inverse" weight but it avoids floating point math
|
2021-07-21 11:15:08 -07:00
|
|
|
let x = (total_weight / *weight)
|
2019-07-17 12:44:28 -07:00
|
|
|
.to_u64()
|
|
|
|
.expect("values > u64::max are not supported");
|
2019-06-01 07:55:43 -07:00
|
|
|
(
|
|
|
|
i,
|
2019-07-17 12:44:28 -07:00
|
|
|
// capture the u64 into u128s to prevent overflow
|
2019-10-29 17:04:11 -07:00
|
|
|
rng.gen_range(1, u128::from(std::u16::MAX)) * u128::from(x),
|
2019-06-01 07:55:43 -07:00
|
|
|
)
|
|
|
|
})
|
2019-10-29 17:04:11 -07:00
|
|
|
// sort in ascending order
|
2019-06-01 07:55:43 -07:00
|
|
|
.sorted_by(|(_, l_val), (_, r_val)| l_val.cmp(r_val))
|
|
|
|
.map(|x| x.0)
|
|
|
|
.collect()
|
|
|
|
}
|
|
|
|
|
2019-10-01 09:38:29 -07:00
|
|
|
/// Returns the highest index after computing a weighted shuffle.
|
|
|
|
/// Saves doing any sorting for O(n) max calculation.
|
2020-12-03 06:26:07 -08:00
|
|
|
// TODO: Remove in favor of rand::distributions::WeightedIndex.
|
2019-10-29 21:02:11 -07:00
|
|
|
pub fn weighted_best(weights_and_indexes: &[(u64, usize)], seed: [u8; 32]) -> usize {
|
2019-10-29 17:04:11 -07:00
|
|
|
if weights_and_indexes.is_empty() {
|
2019-10-01 09:38:29 -07:00
|
|
|
return 0;
|
|
|
|
}
|
2019-10-29 21:02:11 -07:00
|
|
|
let mut rng = ChaChaRng::from_seed(seed);
|
2019-10-29 17:04:11 -07:00
|
|
|
let total_weight: u64 = weights_and_indexes.iter().map(|x| x.0).sum();
|
|
|
|
let mut lowest_weight = std::u128::MAX;
|
2019-10-01 09:38:29 -07:00
|
|
|
let mut best_index = 0;
|
2019-10-29 17:04:11 -07:00
|
|
|
for v in weights_and_indexes {
|
|
|
|
// This generates an "inverse" weight but it avoids floating point math
|
2019-10-01 09:38:29 -07:00
|
|
|
let x = (total_weight / v.0)
|
|
|
|
.to_u64()
|
|
|
|
.expect("values > u64::max are not supported");
|
|
|
|
// capture the u64 into u128s to prevent overflow
|
2019-10-29 17:04:11 -07:00
|
|
|
let computed_weight = rng.gen_range(1, u128::from(std::u16::MAX)) * u128::from(x);
|
|
|
|
// The highest input weight maps to the lowest computed weight
|
|
|
|
if computed_weight < lowest_weight {
|
|
|
|
lowest_weight = computed_weight;
|
2019-10-01 09:38:29 -07:00
|
|
|
best_index = v.1;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
best_index
|
|
|
|
}
|
|
|
|
|
2019-06-01 07:55:43 -07:00
|
|
|
#[cfg(test)]
mod tests {
    use {
        super::*,
        std::{convert::TryInto, iter::repeat_with},
    };

    // Straightforward reference implementation used to cross-check
    // WeightedShuffle; it consumes the rng in the same order.
    fn weighted_shuffle_slow<R>(rng: &mut R, mut weights: Vec<u64>) -> Vec<usize>
    where
        R: Rng,
    {
        let mut shuffle = Vec::with_capacity(weights.len());
        let mut high: u64 = weights.iter().sum();
        let mut zeros: Vec<_> = weights
            .iter()
            .enumerate()
            .filter_map(|(i, &w)| if w == 0 { Some(i) } else { None })
            .collect();
        while high != 0 {
            let sample = rng.gen_range(0, high);
            // First index whose running total exceeds the sample.
            let mut running = 0u64;
            let index = weights
                .iter()
                .position(|&w| {
                    running += w;
                    sample < running
                })
                .unwrap();
            shuffle.push(index);
            high -= weights[index];
            weights[index] = 0;
        }
        while !zeros.is_empty() {
            let index = <usize as SampleUniform>::Sampler::sample_single(0usize, zeros.len(), rng);
            shuffle.push(zeros.swap_remove(index));
        }
        shuffle
    }

    // Every index shows up exactly once.
    #[test]
    fn test_weighted_shuffle_iterator() {
        let mut test_set = [0; 6];
        let mut count = 0;
        let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1].into_iter(), [0x5a; 32]);
        for x in shuffle {
            assert_eq!(test_set[x], 0);
            test_set[x] = 1;
            count += 1;
        }
        assert_eq!(count, 6);
    }

    // Same as above with one hundred distinct weights.
    #[test]
    fn test_weighted_shuffle_iterator_large() {
        let mut test_set = [0; 100];
        let test_weights: Vec<u64> = (1..=100).collect();
        let mut count = 0;
        let shuffle = weighted_shuffle(test_weights.into_iter(), [0xa5; 32]);
        for x in shuffle {
            assert_eq!(test_set[x], 0);
            test_set[x] = 1;
            count += 1;
        }
        assert_eq!(count, 100);
    }

    // The same seed and weights give the same order.
    #[test]
    fn test_weighted_shuffle_compare() {
        let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1].into_iter(), [0x5a; 32]);

        let shuffle1 = weighted_shuffle(vec![50, 10, 2, 1, 1, 1].into_iter(), [0x5a; 32]);
        for (x, y) in shuffle1.into_iter().zip(shuffle) {
            assert_eq!(x, y);
        }
    }

    #[test]
    fn test_weighted_shuffle_imbalanced() {
        let mut weights = vec![std::u32::MAX as u64; 3];
        weights.push(1);
        let shuffle = weighted_shuffle(weights.iter().cloned(), [0x5a; 32]);
        let last = weights.len() - 1;
        for x in shuffle {
            let expected = if x == last { 1 } else { std::u32::MAX as u64 };
            assert_eq!(weights[x], expected);
        }
    }

    #[test]
    fn test_weighted_best() {
        let weights_and_indexes: Vec<_> = [100u64, 1000, 10_000, 10]
            .iter()
            .enumerate()
            .map(|(i, &weight)| (weight, i))
            .collect();
        let best_index = weighted_best(&weights_and_indexes, [0x5b; 32]);
        assert_eq!(best_index, 2);
    }

    // Asserts that empty weights will return empty shuffle.
    #[test]
    fn test_weighted_shuffle_empty_weights() {
        let weights = Vec::<u64>::new();
        let mut rng = rand::thread_rng();
        let shuffle = WeightedShuffle::new(&weights).unwrap();
        assert!(shuffle.shuffle(&mut rng).next().is_none());
        assert_eq!(weighted_sample_single(&mut rng, &weights), None);
    }

    // Asserts that zero weights will be shuffled.
    #[test]
    fn test_weighted_shuffle_zero_weights() {
        let weights = vec![0u64; 5];
        let mut rng = ChaChaRng::from_seed([37u8; 32]);
        let shuffle: Vec<_> = WeightedShuffle::new(&weights)
            .unwrap()
            .shuffle(&mut rng)
            .collect();
        assert_eq!(shuffle, [1, 4, 2, 3, 0]);
        assert_eq!(weighted_sample_single(&mut rng, &weights), Some(1));
    }

    // Asserts that each index is selected proportional to its weight.
    #[test]
    fn test_weighted_shuffle_sanity() {
        let seed: Vec<_> = (1..).step_by(3).take(32).collect();
        let seed: [u8; 32] = seed.try_into().unwrap();
        let mut rng = ChaChaRng::from_seed(seed);
        let weights = [1, 0, 1000, 0, 0, 10, 100, 0];
        let mut counts = [0; 8];
        for _ in 0..100000 {
            let mut shuffle = WeightedShuffle::new(&weights).unwrap().shuffle(&mut rng);
            let first = shuffle.next().unwrap();
            counts[first] += 1;
            shuffle.for_each(drop); // consume the rest.
        }
        assert_eq!(counts, [95, 0, 90069, 0, 0, 908, 8928, 0]);
    }

    // Pins the exact output for two fixed seeds.
    #[test]
    fn test_weighted_shuffle_hard_coded() {
        let weights = [
            78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 0, 17, 4, 50, 96, 0, 83, 33, 16, 72,
        ];
        let cumulative_weights: Vec<_> = weights
            .iter()
            .scan(0, |acc, &w| {
                *acc += w;
                Some(*acc)
            })
            .collect();
        // (seed byte, expected full shuffle, expected single sample)
        let cases: [(u8, [usize; 21], usize); 2] = [
            (
                48,
                [2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5],
                2,
            ),
            (
                37,
                [19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11],
                19,
            ),
        ];
        for &(seed_byte, expected_shuffle, expected_single) in &cases {
            let seed = [seed_byte; 32];
            let mut rng = ChaChaRng::from_seed(seed);
            let shuffle: Vec<_> = WeightedShuffle::new(&weights)
                .unwrap()
                .shuffle(&mut rng)
                .collect();
            assert_eq!(shuffle, expected_shuffle);
            let mut rng = ChaChaRng::from_seed(seed);
            assert_eq!(
                weighted_sample_single(&mut rng, &cumulative_weights),
                Some(expected_single),
            );
        }
    }

    // Cross-checks the tree based shuffle against the slow reference.
    #[test]
    fn test_weighted_shuffle_match_slow() {
        let mut rng = rand::thread_rng();
        let weights: Vec<u64> = repeat_with(|| rng.gen_range(0, 1000)).take(997).collect();
        let cumulative_weights: Vec<_> = weights
            .iter()
            .scan(0, |acc, &w| {
                *acc += w;
                Some(*acc)
            })
            .collect();
        for _ in 0..10 {
            let mut seed = [0u8; 32];
            rng.fill(&mut seed[..]);
            let mut rng = ChaChaRng::from_seed(seed);
            let shuffle: Vec<_> = WeightedShuffle::new(&weights)
                .unwrap()
                .shuffle(&mut rng)
                .collect();
            let mut rng = ChaChaRng::from_seed(seed);
            let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone());
            assert_eq!(shuffle, shuffle_slow);
            let mut rng = ChaChaRng::from_seed(seed);
            assert_eq!(
                weighted_sample_single(&mut rng, &cumulative_weights),
                Some(shuffle[0]),
            );
        }
    }
}
|