removes Rng field from WeightedShuffle struct (#22850)

This commit is contained in:
behzad nouri 2022-02-01 15:27:23 +00:00 committed by GitHub
parent 93789ca5e5
commit 45e09664b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 44 additions and 46 deletions

View File

@ -411,8 +411,9 @@ fn enable_turbine_peers_shuffle_patch(shred_slot: Slot, root_bank: &Bank) -> boo
fn shuffle_nodes<'a, R: Rng>(rng: &mut R, nodes: &[&'a Node]) -> Vec<&'a Node> {
// Nodes are sorted by (stake, pubkey) in descending order.
let stakes: Vec<u64> = nodes.iter().map(|node| node.stake).collect();
WeightedShuffle::new(rng, &stakes)
WeightedShuffle::new(&stakes)
.unwrap()
.shuffle(rng)
.map(|i| nodes[i])
.collect()
}

View File

@ -32,8 +32,9 @@ fn bench_weighted_shuffle_new(bencher: &mut Bencher) {
let weights = make_weights(&mut rng);
bencher.iter(|| {
rng.fill(&mut seed[..]);
WeightedShuffle::new(&mut ChaChaRng::from_seed(seed), &weights)
.unwrap()
let shuffle = WeightedShuffle::new(&weights).unwrap();
shuffle
.shuffle(&mut ChaChaRng::from_seed(seed))
.collect::<Vec<_>>()
});
}

View File

@ -2010,7 +2010,7 @@ impl ClusterInfo {
return packet_batch;
}
let mut rng = rand::thread_rng();
let shuffle = WeightedShuffle::new(&mut rng, &scores).unwrap();
let shuffle = WeightedShuffle::new(&scores).unwrap().shuffle(&mut rng);
let mut total_bytes = 0;
let mut sent = 0;
for (addr, response) in shuffle.map(|i| &responses[i]) {

View File

@ -246,8 +246,9 @@ impl CrdsGossipPull {
return Err(CrdsGossipError::NoPeers);
}
let mut rng = rand::thread_rng();
let mut peers = WeightedShuffle::new(&mut rng, &weights)
let mut peers = WeightedShuffle::new(&weights)
.unwrap()
.shuffle(&mut rng)
.map(|i| peers[i]);
let peer = {
let mut rng = rand::thread_rng();

View File

@ -169,8 +169,9 @@ impl CrdsGossipPush {
.filter(|(_, stake)| *stake > 0)
.collect();
let weights: Vec<_> = peers.iter().map(|(_, stake)| *stake).collect();
WeightedShuffle::new(&mut rng, &weights)
WeightedShuffle::new(&weights)
.unwrap()
.shuffle(&mut rng)
.map(move |i| peers[i])
};
let mut keep = HashSet::new();
@ -369,7 +370,7 @@ impl CrdsGossipPush {
return;
}
let num_bloom_items = MIN_NUM_BLOOM_ITEMS.max(network_size);
let shuffle = WeightedShuffle::new(&mut rng, &weights).unwrap();
let shuffle = WeightedShuffle::new(&weights).unwrap().shuffle(&mut rng);
let mut active_set = self.active_set.write().unwrap();
let need = Self::compute_need(self.num_active, active_set.len(), ratio);
for peer in shuffle.map(|i| peers[i]) {

View File

@ -28,25 +28,24 @@ pub enum WeightedShuffleError<T> {
/// weight.
/// - Zero weighted indices are shuffled and appear only at the end, after
/// non-zero weighted indices.
pub struct WeightedShuffle<'a, R, T> {
pub struct WeightedShuffle<T> {
arr: Vec<T>, // Underlying array implementing binary indexed tree.
sum: T, // Current sum of weights, excluding already selected indices.
zeros: Vec<usize>, // Indices of zero weighted entries.
rng: &'a mut R, // Random number generator.
}
// The implementation uses binary indexed tree:
// https://en.wikipedia.org/wiki/Fenwick_tree
// to maintain cumulative sum of weights excluding already selected indices
// over self.arr.
impl<'a, R: Rng, T> WeightedShuffle<'a, R, T>
impl<T> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + CheckedAdd,
{
/// Returns error if:
/// - any of the weights are negative.
/// - sum of weights overflows.
pub fn new(rng: &'a mut R, weights: &[T]) -> Result<Self, WeightedShuffleError<T>> {
pub fn new(weights: &[T]) -> Result<Self, WeightedShuffleError<T>> {
let size = weights.len() + 1;
let zero = <T as Default>::default();
let mut arr = vec![zero; size];
@ -70,16 +69,11 @@ where
k += k & k.wrapping_neg();
}
}
Ok(Self {
arr,
sum,
rng,
zeros,
})
Ok(Self { arr, sum, zeros })
}
}
impl<'a, R, T> WeightedShuffle<'a, R, T>
impl<T> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub<Output = T>,
{
@ -126,30 +120,26 @@ where
}
}
impl<'a, R: Rng, T> Iterator for WeightedShuffle<'a, R, T>
impl<'a, T: 'a> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub<Output = T>,
{
type Item = usize;
fn next(&mut self) -> Option<Self::Item> {
let zero = <T as Default>::default();
if self.sum > zero {
let sample =
<T as SampleUniform>::Sampler::sample_single(zero, self.sum, &mut self.rng);
let (index, weight) = WeightedShuffle::search(self, sample);
self.remove(index, weight);
return Some(index - 1);
}
if self.zeros.is_empty() {
return None;
}
let index = <usize as SampleUniform>::Sampler::sample_single(
0usize,
self.zeros.len(),
&mut self.rng,
);
Some(self.zeros.swap_remove(index))
pub fn shuffle<R: Rng>(mut self, rng: &'a mut R) -> impl Iterator<Item = usize> + 'a {
std::iter::from_fn(move || {
let zero = <T as Default>::default();
if self.sum > zero {
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, rng);
let (index, weight) = WeightedShuffle::search(&self, sample);
self.remove(index, weight);
return Some(index - 1);
}
if self.zeros.is_empty() {
return None;
}
let index =
<usize as SampleUniform>::Sampler::sample_single(0usize, self.zeros.len(), rng);
Some(self.zeros.swap_remove(index))
})
}
}
@ -355,8 +345,8 @@ mod tests {
fn test_weighted_shuffle_empty_weights() {
let weights = Vec::<u64>::new();
let mut rng = rand::thread_rng();
let shuffle = WeightedShuffle::new(&mut rng, &weights);
assert!(shuffle.unwrap().next().is_none());
let shuffle = WeightedShuffle::new(&weights).unwrap();
assert!(shuffle.shuffle(&mut rng).next().is_none());
assert_eq!(weighted_sample_single(&mut rng, &weights), None);
}
@ -366,7 +356,8 @@ mod tests {
let weights = vec![0u64; 5];
let seed = [37u8; 32];
let mut rng = ChaChaRng::from_seed(seed);
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
let shuffle = WeightedShuffle::new(&weights).unwrap();
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
assert_eq!(shuffle, [1, 4, 2, 3, 0]);
assert_eq!(weighted_sample_single(&mut rng, &weights), Some(1));
}
@ -380,7 +371,7 @@ mod tests {
let weights = [1, 0, 1000, 0, 0, 10, 100, 0];
let mut counts = [0; 8];
for _ in 0..100000 {
let mut shuffle = WeightedShuffle::new(&mut rng, &weights).unwrap();
let mut shuffle = WeightedShuffle::new(&weights).unwrap().shuffle(&mut rng);
counts[shuffle.next().unwrap()] += 1;
let _ = shuffle.count(); // consume the rest.
}
@ -401,7 +392,8 @@ mod tests {
.collect();
let seed = [48u8; 32];
let mut rng = ChaChaRng::from_seed(seed);
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
let shuffle = WeightedShuffle::new(&weights).unwrap();
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
assert_eq!(
shuffle,
[2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5]
@ -413,7 +405,8 @@ mod tests {
);
let seed = [37u8; 32];
let mut rng = ChaChaRng::from_seed(seed);
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
let shuffle = WeightedShuffle::new(&weights).unwrap();
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
assert_eq!(
shuffle,
[19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11]
@ -440,7 +433,8 @@ mod tests {
let mut seed = [0u8; 32];
rng.fill(&mut seed[..]);
let mut rng = ChaChaRng::from_seed(seed);
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
let shuffle = WeightedShuffle::new(&weights).unwrap();
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
let mut rng = ChaChaRng::from_seed(seed);
let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone());
assert_eq!(shuffle, shuffle_slow);