//! The `weighted_shuffle` module provides an iterator over shuffled weights. use { itertools::Itertools, num_traits::{CheckedAdd, FromPrimitive, ToPrimitive}, rand::{ distributions::uniform::{SampleUniform, UniformSampler}, Rng, SeedableRng, }, rand_chacha::ChaChaRng, std::{ borrow::Borrow, iter, ops::{AddAssign, Div, Sub, SubAssign}, }, }; #[derive(Debug)] pub enum WeightedShuffleError { NegativeWeight(T), SumOverflow, } /// Implements an iterator where indices are shuffled according to their /// weights: /// - Returned indices are unique in the range [0, weights.len()). /// - Higher weighted indices tend to appear earlier proportional to their /// weight. /// - Zero weighted indices are shuffled and appear only at the end, after /// non-zero weighted indices. #[derive(Clone)] pub struct WeightedShuffle { arr: Vec, // Underlying array implementing binary indexed tree. sum: T, // Current sum of weights, excluding already selected indices. zeros: Vec, // Indices of zero weighted entries. } // The implementation uses binary indexed tree: // https://en.wikipedia.org/wiki/Fenwick_tree // to maintain cumulative sum of weights excluding already selected indices // over self.arr. impl WeightedShuffle where T: Copy + Default + PartialOrd + AddAssign + CheckedAdd, { /// Returns error if: /// - any of the weights are negative. /// - sum of weights overflows. pub fn new(weights: &[T]) -> Result> { let size = weights.len() + 1; let zero = ::default(); let mut arr = vec![zero; size]; let mut sum = zero; let mut zeros = Vec::default(); for (mut k, &weight) in (1usize..).zip(weights) { #[allow(clippy::neg_cmp_op_on_partial_ord)] // weight < zero does not work for NaNs. if !(weight >= zero) { return Err(WeightedShuffleError::NegativeWeight(weight)); } if weight == zero { zeros.push(k - 1); continue; } sum = sum .checked_add(&weight) .ok_or(WeightedShuffleError::SumOverflow)?; while k < size { arr[k] += weight; k += k & k.wrapping_neg(); } } Ok(Self { arr, sum, zeros }) } } impl WeightedShuffle where T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub, { // Returns cumulative sum of current weights upto index k (inclusive). fn cumsum(&self, mut k: usize) -> T { let mut out = ::default(); while k != 0 { out += self.arr[k]; k ^= k & k.wrapping_neg(); } out } // Removes given weight at index k. fn remove(&mut self, mut k: usize, weight: T) { self.sum -= weight; let size = self.arr.len(); while k < size { self.arr[k] -= weight; k += k & k.wrapping_neg(); } } // Returns smallest index such that self.cumsum(k) > val, // along with its respective weight. fn search(&self, val: T) -> (/*index:*/ usize, /*weight:*/ T) { let zero = ::default(); debug_assert!(val >= zero); debug_assert!(val < self.sum); let mut lo = (/*index:*/ 0, /*cumsum:*/ zero); let mut hi = (self.arr.len() - 1, self.sum); while lo.0 + 1 < hi.0 { let k = lo.0 + (hi.0 - lo.0) / 2; let sum = self.cumsum(k); if sum <= val { lo = (k, sum); } else { hi = (k, sum); } } debug_assert!(lo.1 <= val); debug_assert!(hi.1 > val); (hi.0, hi.1 - lo.1) } pub fn remove_index(&mut self, index: usize) { let zero = ::default(); let weight = self.cumsum(index + 1) - self.cumsum(index); if weight != zero { self.remove(index + 1, weight); } else if let Some(index) = self.zeros.iter().position(|ix| *ix == index) { self.zeros.remove(index); } } } impl WeightedShuffle where T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub, { // Equivalent to weighted_shuffle.shuffle(&mut rng).next() pub fn first(&self, rng: &mut R) -> Option { let zero = ::default(); if self.sum > zero { let sample = ::Sampler::sample_single(zero, self.sum, rng); let (index, _weight) = WeightedShuffle::search(self, sample); return Some(index - 1); } if self.zeros.is_empty() { return None; } let index = ::Sampler::sample_single(0usize, self.zeros.len(), rng); self.zeros.get(index).copied() } } impl<'a, T: 'a> WeightedShuffle where T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub, { pub fn shuffle(mut self, rng: &'a mut R) -> impl Iterator + 'a { std::iter::from_fn(move || { let zero = ::default(); if self.sum > zero { let sample = ::Sampler::sample_single(zero, self.sum, rng); let (index, weight) = WeightedShuffle::search(&self, sample); self.remove(index, weight); return Some(index - 1); } if self.zeros.is_empty() { return None; } let index = ::Sampler::sample_single(0usize, self.zeros.len(), rng); Some(self.zeros.swap_remove(index)) }) } } /// Returns a list of indexes shuffled based on the input weights /// Note - The sum of all weights must not exceed `u64::MAX` pub fn weighted_shuffle(weights: F, seed: [u8; 32]) -> Vec where T: Copy + PartialOrd + iter::Sum + Div + FromPrimitive + ToPrimitive, B: Borrow, F: Iterator + Clone, { let total_weight: T = weights.clone().map(|x| *x.borrow()).sum(); let mut rng = ChaChaRng::from_seed(seed); weights .enumerate() .map(|(i, weight)| { let weight = weight.borrow(); // This generates an "inverse" weight but it avoids floating point math let x = (total_weight / *weight) .to_u64() .expect("values > u64::max are not supported"); ( i, // capture the u64 into u128s to prevent overflow rng.gen_range(1, u128::from(std::u16::MAX)) * u128::from(x), ) }) // sort in ascending order .sorted_by(|(_, l_val), (_, r_val)| l_val.cmp(r_val)) .map(|x| x.0) .collect() } /// Returns the highest index after computing a weighted shuffle. /// Saves doing any sorting for O(n) max calculation. // TODO: Remove in favor of rand::distributions::WeightedIndex. pub fn weighted_best(weights_and_indexes: &[(u64, usize)], seed: [u8; 32]) -> usize { if weights_and_indexes.is_empty() { return 0; } let mut rng = ChaChaRng::from_seed(seed); let total_weight: u64 = weights_and_indexes.iter().map(|x| x.0).sum(); let mut lowest_weight = std::u128::MAX; let mut best_index = 0; for v in weights_and_indexes { // This generates an "inverse" weight but it avoids floating point math let x = (total_weight / v.0) .to_u64() .expect("values > u64::max are not supported"); // capture the u64 into u128s to prevent overflow let computed_weight = rng.gen_range(1, u128::from(std::u16::MAX)) * u128::from(x); // The highest input weight maps to the lowest computed weight if computed_weight < lowest_weight { lowest_weight = computed_weight; best_index = v.1; } } best_index } #[cfg(test)] mod tests { use { super::*, std::{convert::TryInto, iter::repeat_with}, }; fn weighted_shuffle_slow(rng: &mut R, mut weights: Vec) -> Vec where R: Rng, { let mut shuffle = Vec::with_capacity(weights.len()); let mut high: u64 = weights.iter().sum(); let mut zeros: Vec<_> = weights .iter() .enumerate() .filter(|(_, w)| **w == 0) .map(|(i, _)| i) .collect(); while high != 0 { let sample = rng.gen_range(0, high); let index = weights .iter() .scan(0, |acc, &w| { *acc += w; Some(*acc) }) .position(|acc| sample < acc) .unwrap(); shuffle.push(index); high -= weights[index]; weights[index] = 0; } while !zeros.is_empty() { let index = ::Sampler::sample_single(0usize, zeros.len(), rng); shuffle.push(zeros.swap_remove(index)); } shuffle } #[test] fn test_weighted_shuffle_iterator() { let mut test_set = [0; 6]; let mut count = 0; let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1].into_iter(), [0x5a; 32]); shuffle.into_iter().for_each(|x| { assert_eq!(test_set[x], 0); test_set[x] = 1; count += 1; }); assert_eq!(count, 6); } #[test] fn test_weighted_shuffle_iterator_large() { let mut test_set = [0; 100]; let mut test_weights = vec![0; 100]; (0..100).for_each(|i| test_weights[i] = (i + 1) as u64); let mut count = 0; let shuffle = weighted_shuffle(test_weights.into_iter(), [0xa5; 32]); shuffle.into_iter().for_each(|x| { assert_eq!(test_set[x], 0); test_set[x] = 1; count += 1; }); assert_eq!(count, 100); } #[test] fn test_weighted_shuffle_compare() { let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1].into_iter(), [0x5a; 32]); let shuffle1 = weighted_shuffle(vec![50, 10, 2, 1, 1, 1].into_iter(), [0x5a; 32]); shuffle1 .into_iter() .zip(shuffle.into_iter()) .for_each(|(x, y)| { assert_eq!(x, y); }); } #[test] fn test_weighted_shuffle_imbalanced() { let mut weights = vec![std::u32::MAX as u64; 3]; weights.push(1); let shuffle = weighted_shuffle(weights.iter().cloned(), [0x5a; 32]); shuffle.into_iter().for_each(|x| { if x == weights.len() - 1 { assert_eq!(weights[x], 1); } else { assert_eq!(weights[x], std::u32::MAX as u64); } }); } #[test] fn test_weighted_best() { let weights_and_indexes: Vec<_> = vec![100u64, 1000, 10_000, 10] .into_iter() .enumerate() .map(|(i, weight)| (weight, i)) .collect(); let best_index = weighted_best(&weights_and_indexes, [0x5b; 32]); assert_eq!(best_index, 2); } // Asserts that empty weights will return empty shuffle. #[test] fn test_weighted_shuffle_empty_weights() { let weights = Vec::::new(); let mut rng = rand::thread_rng(); let shuffle = WeightedShuffle::new(&weights).unwrap(); assert!(shuffle.clone().shuffle(&mut rng).next().is_none()); assert!(shuffle.first(&mut rng).is_none()); } // Asserts that zero weights will be shuffled. #[test] fn test_weighted_shuffle_zero_weights() { let weights = vec![0u64; 5]; let seed = [37u8; 32]; let mut rng = ChaChaRng::from_seed(seed); let shuffle = WeightedShuffle::new(&weights).unwrap(); assert_eq!( shuffle.clone().shuffle(&mut rng).collect::>(), [1, 4, 2, 3, 0] ); let mut rng = ChaChaRng::from_seed(seed); assert_eq!(shuffle.first(&mut rng), Some(1)); } // Asserts that each index is selected proportional to its weight. #[test] fn test_weighted_shuffle_sanity() { let seed: Vec<_> = (1..).step_by(3).take(32).collect(); let seed: [u8; 32] = seed.try_into().unwrap(); let mut rng = ChaChaRng::from_seed(seed); let weights = [1, 0, 1000, 0, 0, 10, 100, 0]; let mut counts = [0; 8]; for _ in 0..100000 { let mut shuffle = WeightedShuffle::new(&weights).unwrap().shuffle(&mut rng); counts[shuffle.next().unwrap()] += 1; let _ = shuffle.count(); // consume the rest. } assert_eq!(counts, [95, 0, 90069, 0, 0, 908, 8928, 0]); let mut counts = [0; 8]; for _ in 0..100000 { let mut shuffle = WeightedShuffle::new(&weights).unwrap(); shuffle.remove_index(5); shuffle.remove_index(3); shuffle.remove_index(1); let mut shuffle = shuffle.shuffle(&mut rng); counts[shuffle.next().unwrap()] += 1; let _ = shuffle.count(); // consume the rest. } assert_eq!(counts, [97, 0, 90862, 0, 0, 0, 9041, 0]); } #[test] fn test_weighted_shuffle_hard_coded() { let weights = [ 78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 0, 17, 4, 50, 96, 0, 83, 33, 16, 72, ]; let seed = [48u8; 32]; let mut rng = ChaChaRng::from_seed(seed); let mut shuffle = WeightedShuffle::new(&weights).unwrap(); assert_eq!( shuffle.clone().shuffle(&mut rng).collect::>(), [2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5] ); let mut rng = ChaChaRng::from_seed(seed); assert_eq!(shuffle.first(&mut rng), Some(2)); let mut rng = ChaChaRng::from_seed(seed); shuffle.remove_index(11); shuffle.remove_index(3); shuffle.remove_index(15); shuffle.remove_index(0); assert_eq!( shuffle.clone().shuffle(&mut rng).collect::>(), [4, 6, 1, 12, 19, 14, 17, 20, 2, 9, 10, 8, 7, 18, 13, 5, 16] ); let mut rng = ChaChaRng::from_seed(seed); assert_eq!(shuffle.first(&mut rng), Some(4)); let seed = [37u8; 32]; let mut rng = ChaChaRng::from_seed(seed); let mut shuffle = WeightedShuffle::new(&weights).unwrap(); assert_eq!( shuffle.clone().shuffle(&mut rng).collect::>(), [19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11] ); let mut rng = ChaChaRng::from_seed(seed); assert_eq!(shuffle.first(&mut rng), Some(19)); shuffle.remove_index(16); shuffle.remove_index(8); shuffle.remove_index(20); shuffle.remove_index(5); shuffle.remove_index(19); shuffle.remove_index(4); let mut rng = ChaChaRng::from_seed(seed); assert_eq!( shuffle.clone().shuffle(&mut rng).collect::>(), [17, 2, 9, 14, 6, 10, 12, 1, 15, 13, 7, 0, 18, 3, 11] ); let mut rng = ChaChaRng::from_seed(seed); assert_eq!(shuffle.first(&mut rng), Some(17)); } #[test] fn test_weighted_shuffle_match_slow() { let mut rng = rand::thread_rng(); let weights: Vec = repeat_with(|| rng.gen_range(0, 1000)).take(997).collect(); for _ in 0..10 { let mut seed = [0u8; 32]; rng.fill(&mut seed[..]); let mut rng = ChaChaRng::from_seed(seed); let shuffle = WeightedShuffle::new(&weights).unwrap(); let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect(); let mut rng = ChaChaRng::from_seed(seed); let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone()); assert_eq!(shuffle, shuffle_slow); let mut rng = ChaChaRng::from_seed(seed); let shuffle = WeightedShuffle::new(&weights).unwrap(); assert_eq!(shuffle.first(&mut rng), Some(shuffle_slow[0])); } } }