//! The `weighted_shuffle` module provides an iterator over shuffled weights. use { num_traits::CheckedAdd, rand::{ distributions::uniform::{SampleUniform, UniformSampler}, Rng, }, std::ops::{AddAssign, Sub, SubAssign}, }; /// Implements an iterator where indices are shuffled according to their /// weights: /// - Returned indices are unique in the range [0, weights.len()). /// - Higher weighted indices tend to appear earlier proportional to their /// weight. /// - Zero weighted indices are shuffled and appear only at the end, after /// non-zero weighted indices. #[derive(Clone)] pub struct WeightedShuffle { arr: Vec, // Underlying array implementing binary indexed tree. sum: T, // Current sum of weights, excluding already selected indices. zeros: Vec, // Indices of zero weighted entries. } // The implementation uses binary indexed tree: // https://en.wikipedia.org/wiki/Fenwick_tree // to maintain cumulative sum of weights excluding already selected indices // over self.arr. impl WeightedShuffle where T: Copy + Default + PartialOrd + AddAssign + CheckedAdd, { /// If weights are negative or overflow the total sum /// they are treated as zero. pub fn new(name: &'static str, weights: &[T]) -> Self { let size = weights.len() + 1; let zero = ::default(); let mut arr = vec![zero; size]; let mut sum = zero; let mut zeros = Vec::default(); let mut num_negative = 0; let mut num_overflow = 0; for (mut k, &weight) in (1usize..).zip(weights) { #[allow(clippy::neg_cmp_op_on_partial_ord)] // weight < zero does not work for NaNs. if !(weight >= zero) { zeros.push(k - 1); num_negative += 1; continue; } if weight == zero { zeros.push(k - 1); continue; } sum = match sum.checked_add(&weight) { Some(val) => val, None => { zeros.push(k - 1); num_overflow += 1; continue; } }; while k < size { arr[k] += weight; k += k & k.wrapping_neg(); } } if num_negative > 0 { datapoint_error!("weighted-shuffle-negative", (name, num_negative, i64)); } if num_overflow > 0 { datapoint_error!("weighted-shuffle-overflow", (name, num_overflow, i64)); } Self { arr, sum, zeros } } } impl WeightedShuffle where T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub, { // Returns cumulative sum of current weights upto index k (inclusive). fn cumsum(&self, mut k: usize) -> T { let mut out = ::default(); while k != 0 { out += self.arr[k]; k ^= k & k.wrapping_neg(); } out } // Removes given weight at index k. fn remove(&mut self, mut k: usize, weight: T) { self.sum -= weight; let size = self.arr.len(); while k < size { self.arr[k] -= weight; k += k & k.wrapping_neg(); } } // Returns smallest index such that self.cumsum(k) > val, // along with its respective weight. fn search(&self, val: T) -> (/*index:*/ usize, /*weight:*/ T) { let zero = ::default(); debug_assert!(val >= zero); debug_assert!(val < self.sum); let mut lo = (/*index:*/ 0, /*cumsum:*/ zero); let mut hi = (self.arr.len() - 1, self.sum); while lo.0 + 1 < hi.0 { let k = lo.0 + (hi.0 - lo.0) / 2; let sum = self.cumsum(k); if sum <= val { lo = (k, sum); } else { hi = (k, sum); } } debug_assert!(lo.1 <= val); debug_assert!(hi.1 > val); (hi.0, hi.1 - lo.1) } pub fn remove_index(&mut self, index: usize) { let zero = ::default(); let weight = self.cumsum(index + 1) - self.cumsum(index); if weight != zero { self.remove(index + 1, weight); } else if let Some(index) = self.zeros.iter().position(|ix| *ix == index) { self.zeros.remove(index); } } } impl WeightedShuffle where T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub, { // Equivalent to weighted_shuffle.shuffle(&mut rng).next() pub fn first(&self, rng: &mut R) -> Option { let zero = ::default(); if self.sum > zero { let sample = ::Sampler::sample_single(zero, self.sum, rng); let (index, _weight) = WeightedShuffle::search(self, sample); return Some(index - 1); } if self.zeros.is_empty() { return None; } let index = ::Sampler::sample_single(0usize, self.zeros.len(), rng); self.zeros.get(index).copied() } } impl<'a, T: 'a> WeightedShuffle where T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub, { pub fn shuffle(mut self, rng: &'a mut R) -> impl Iterator + 'a { std::iter::from_fn(move || { let zero = ::default(); if self.sum > zero { let sample = ::Sampler::sample_single(zero, self.sum, rng); let (index, weight) = WeightedShuffle::search(&self, sample); self.remove(index, weight); return Some(index - 1); } if self.zeros.is_empty() { return None; } let index = ::Sampler::sample_single(0usize, self.zeros.len(), rng); Some(self.zeros.swap_remove(index)) }) } } #[cfg(test)] mod tests { use { super::*, rand::SeedableRng, rand_chacha::ChaChaRng, std::{convert::TryInto, iter::repeat_with}, }; fn weighted_shuffle_slow(rng: &mut R, mut weights: Vec) -> Vec where R: Rng, { let mut shuffle = Vec::with_capacity(weights.len()); let mut high: u64 = weights.iter().sum(); let mut zeros: Vec<_> = weights .iter() .enumerate() .filter(|(_, w)| **w == 0) .map(|(i, _)| i) .collect(); while high != 0 { let sample = rng.gen_range(0..high); let index = weights .iter() .scan(0, |acc, &w| { *acc += w; Some(*acc) }) .position(|acc| sample < acc) .unwrap(); shuffle.push(index); high -= weights[index]; weights[index] = 0; } while !zeros.is_empty() { let index = ::Sampler::sample_single(0usize, zeros.len(), rng); shuffle.push(zeros.swap_remove(index)); } shuffle } // Asserts that empty weights will return empty shuffle. #[test] fn test_weighted_shuffle_empty_weights() { let weights = Vec::::new(); let mut rng = rand::thread_rng(); let shuffle = WeightedShuffle::new("", &weights); assert!(shuffle.clone().shuffle(&mut rng).next().is_none()); assert!(shuffle.first(&mut rng).is_none()); } // Asserts that zero weights will be shuffled. #[test] fn test_weighted_shuffle_zero_weights() { let weights = vec![0u64; 5]; let seed = [37u8; 32]; let mut rng = ChaChaRng::from_seed(seed); let shuffle = WeightedShuffle::new("", &weights); assert_eq!( shuffle.clone().shuffle(&mut rng).collect::>(), [1, 4, 2, 3, 0] ); let mut rng = ChaChaRng::from_seed(seed); assert_eq!(shuffle.first(&mut rng), Some(1)); } // Asserts that each index is selected proportional to its weight. #[test] fn test_weighted_shuffle_sanity() { let seed: Vec<_> = (1..).step_by(3).take(32).collect(); let seed: [u8; 32] = seed.try_into().unwrap(); let mut rng = ChaChaRng::from_seed(seed); let weights = [1, 0, 1000, 0, 0, 10, 100, 0]; let mut counts = [0; 8]; for _ in 0..100000 { let mut shuffle = WeightedShuffle::new("", &weights).shuffle(&mut rng); counts[shuffle.next().unwrap()] += 1; let _ = shuffle.count(); // consume the rest. } assert_eq!(counts, [95, 0, 90069, 0, 0, 908, 8928, 0]); let mut counts = [0; 8]; for _ in 0..100000 { let mut shuffle = WeightedShuffle::new("", &weights); shuffle.remove_index(5); shuffle.remove_index(3); shuffle.remove_index(1); let mut shuffle = shuffle.shuffle(&mut rng); counts[shuffle.next().unwrap()] += 1; let _ = shuffle.count(); // consume the rest. } assert_eq!(counts, [97, 0, 90862, 0, 0, 0, 9041, 0]); } #[test] fn test_weighted_shuffle_negative_overflow() { const SEED: [u8; 32] = [48u8; 32]; let weights = [19i64, 23, 7, 0, 0, 23, 3, 0, 5, 0, 19, 29]; let mut rng = ChaChaRng::from_seed(SEED); let shuffle = WeightedShuffle::new("", &weights); assert_eq!( shuffle.shuffle(&mut rng).collect::>(), [8, 1, 5, 10, 11, 0, 2, 6, 9, 4, 3, 7] ); // Negative weights and overflowing ones are treated as zero. let weights = [19, 23, 7, -57, i64::MAX, 23, 3, i64::MAX, 5, -79, 19, 29]; let mut rng = ChaChaRng::from_seed(SEED); let shuffle = WeightedShuffle::new("", &weights); assert_eq!( shuffle.shuffle(&mut rng).collect::>(), [8, 1, 5, 10, 11, 0, 2, 6, 9, 4, 3, 7] ); } #[test] fn test_weighted_shuffle_hard_coded() { let weights = [ 78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 0, 17, 4, 50, 96, 0, 83, 33, 16, 72, ]; let seed = [48u8; 32]; let mut rng = ChaChaRng::from_seed(seed); let mut shuffle = WeightedShuffle::new("", &weights); assert_eq!( shuffle.clone().shuffle(&mut rng).collect::>(), [2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5] ); let mut rng = ChaChaRng::from_seed(seed); assert_eq!(shuffle.first(&mut rng), Some(2)); let mut rng = ChaChaRng::from_seed(seed); shuffle.remove_index(11); shuffle.remove_index(3); shuffle.remove_index(15); shuffle.remove_index(0); assert_eq!( shuffle.clone().shuffle(&mut rng).collect::>(), [4, 6, 1, 12, 19, 14, 17, 20, 2, 9, 10, 8, 7, 18, 13, 5, 16] ); let mut rng = ChaChaRng::from_seed(seed); assert_eq!(shuffle.first(&mut rng), Some(4)); let seed = [37u8; 32]; let mut rng = ChaChaRng::from_seed(seed); let mut shuffle = WeightedShuffle::new("", &weights); assert_eq!( shuffle.clone().shuffle(&mut rng).collect::>(), [19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11] ); let mut rng = ChaChaRng::from_seed(seed); assert_eq!(shuffle.first(&mut rng), Some(19)); shuffle.remove_index(16); shuffle.remove_index(8); shuffle.remove_index(20); shuffle.remove_index(5); shuffle.remove_index(19); shuffle.remove_index(4); let mut rng = ChaChaRng::from_seed(seed); assert_eq!( shuffle.clone().shuffle(&mut rng).collect::>(), [17, 2, 9, 14, 6, 10, 12, 1, 15, 13, 7, 0, 18, 3, 11] ); let mut rng = ChaChaRng::from_seed(seed); assert_eq!(shuffle.first(&mut rng), Some(17)); } #[test] fn test_weighted_shuffle_match_slow() { let mut rng = rand::thread_rng(); let weights: Vec = repeat_with(|| rng.gen_range(0..1000)).take(997).collect(); for _ in 0..10 { let mut seed = [0u8; 32]; rng.fill(&mut seed[..]); let mut rng = ChaChaRng::from_seed(seed); let shuffle = WeightedShuffle::new("", &weights); let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect(); let mut rng = ChaChaRng::from_seed(seed); let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone()); assert_eq!(shuffle, shuffle_slow); let mut rng = ChaChaRng::from_seed(seed); let shuffle = WeightedShuffle::new("", &weights); assert_eq!(shuffle.first(&mut rng), Some(shuffle_slow[0])); } } }