removes Rng field from WeightedShuffle struct (#22850)
This commit is contained in:
parent
93789ca5e5
commit
45e09664b8
|
@ -411,8 +411,9 @@ fn enable_turbine_peers_shuffle_patch(shred_slot: Slot, root_bank: &Bank) -> boo
|
|||
fn shuffle_nodes<'a, R: Rng>(rng: &mut R, nodes: &[&'a Node]) -> Vec<&'a Node> {
|
||||
// Nodes are sorted by (stake, pubkey) in descending order.
|
||||
let stakes: Vec<u64> = nodes.iter().map(|node| node.stake).collect();
|
||||
WeightedShuffle::new(rng, &stakes)
|
||||
WeightedShuffle::new(&stakes)
|
||||
.unwrap()
|
||||
.shuffle(rng)
|
||||
.map(|i| nodes[i])
|
||||
.collect()
|
||||
}
|
||||
|
|
|
@ -32,8 +32,9 @@ fn bench_weighted_shuffle_new(bencher: &mut Bencher) {
|
|||
let weights = make_weights(&mut rng);
|
||||
bencher.iter(|| {
|
||||
rng.fill(&mut seed[..]);
|
||||
WeightedShuffle::new(&mut ChaChaRng::from_seed(seed), &weights)
|
||||
.unwrap()
|
||||
let shuffle = WeightedShuffle::new(&weights).unwrap();
|
||||
shuffle
|
||||
.shuffle(&mut ChaChaRng::from_seed(seed))
|
||||
.collect::<Vec<_>>()
|
||||
});
|
||||
}
|
||||
|
|
|
@ -2010,7 +2010,7 @@ impl ClusterInfo {
|
|||
return packet_batch;
|
||||
}
|
||||
let mut rng = rand::thread_rng();
|
||||
let shuffle = WeightedShuffle::new(&mut rng, &scores).unwrap();
|
||||
let shuffle = WeightedShuffle::new(&scores).unwrap().shuffle(&mut rng);
|
||||
let mut total_bytes = 0;
|
||||
let mut sent = 0;
|
||||
for (addr, response) in shuffle.map(|i| &responses[i]) {
|
||||
|
|
|
@ -246,8 +246,9 @@ impl CrdsGossipPull {
|
|||
return Err(CrdsGossipError::NoPeers);
|
||||
}
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut peers = WeightedShuffle::new(&mut rng, &weights)
|
||||
let mut peers = WeightedShuffle::new(&weights)
|
||||
.unwrap()
|
||||
.shuffle(&mut rng)
|
||||
.map(|i| peers[i]);
|
||||
let peer = {
|
||||
let mut rng = rand::thread_rng();
|
||||
|
|
|
@ -169,8 +169,9 @@ impl CrdsGossipPush {
|
|||
.filter(|(_, stake)| *stake > 0)
|
||||
.collect();
|
||||
let weights: Vec<_> = peers.iter().map(|(_, stake)| *stake).collect();
|
||||
WeightedShuffle::new(&mut rng, &weights)
|
||||
WeightedShuffle::new(&weights)
|
||||
.unwrap()
|
||||
.shuffle(&mut rng)
|
||||
.map(move |i| peers[i])
|
||||
};
|
||||
let mut keep = HashSet::new();
|
||||
|
@ -369,7 +370,7 @@ impl CrdsGossipPush {
|
|||
return;
|
||||
}
|
||||
let num_bloom_items = MIN_NUM_BLOOM_ITEMS.max(network_size);
|
||||
let shuffle = WeightedShuffle::new(&mut rng, &weights).unwrap();
|
||||
let shuffle = WeightedShuffle::new(&weights).unwrap().shuffle(&mut rng);
|
||||
let mut active_set = self.active_set.write().unwrap();
|
||||
let need = Self::compute_need(self.num_active, active_set.len(), ratio);
|
||||
for peer in shuffle.map(|i| peers[i]) {
|
||||
|
|
|
@ -28,25 +28,24 @@ pub enum WeightedShuffleError<T> {
|
|||
/// weight.
|
||||
/// - Zero weighted indices are shuffled and appear only at the end, after
|
||||
/// non-zero weighted indices.
|
||||
pub struct WeightedShuffle<'a, R, T> {
|
||||
pub struct WeightedShuffle<T> {
|
||||
arr: Vec<T>, // Underlying array implementing binary indexed tree.
|
||||
sum: T, // Current sum of weights, excluding already selected indices.
|
||||
zeros: Vec<usize>, // Indices of zero weighted entries.
|
||||
rng: &'a mut R, // Random number generator.
|
||||
}
|
||||
|
||||
// The implementation uses binary indexed tree:
|
||||
// https://en.wikipedia.org/wiki/Fenwick_tree
|
||||
// to maintain cumulative sum of weights excluding already selected indices
|
||||
// over self.arr.
|
||||
impl<'a, R: Rng, T> WeightedShuffle<'a, R, T>
|
||||
impl<T> WeightedShuffle<T>
|
||||
where
|
||||
T: Copy + Default + PartialOrd + AddAssign + CheckedAdd,
|
||||
{
|
||||
/// Returns error if:
|
||||
/// - any of the weights are negative.
|
||||
/// - sum of weights overflows.
|
||||
pub fn new(rng: &'a mut R, weights: &[T]) -> Result<Self, WeightedShuffleError<T>> {
|
||||
pub fn new(weights: &[T]) -> Result<Self, WeightedShuffleError<T>> {
|
||||
let size = weights.len() + 1;
|
||||
let zero = <T as Default>::default();
|
||||
let mut arr = vec![zero; size];
|
||||
|
@ -70,16 +69,11 @@ where
|
|||
k += k & k.wrapping_neg();
|
||||
}
|
||||
}
|
||||
Ok(Self {
|
||||
arr,
|
||||
sum,
|
||||
rng,
|
||||
zeros,
|
||||
})
|
||||
Ok(Self { arr, sum, zeros })
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, R, T> WeightedShuffle<'a, R, T>
|
||||
impl<T> WeightedShuffle<T>
|
||||
where
|
||||
T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub<Output = T>,
|
||||
{
|
||||
|
@ -126,30 +120,26 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a, R: Rng, T> Iterator for WeightedShuffle<'a, R, T>
|
||||
impl<'a, T: 'a> WeightedShuffle<T>
|
||||
where
|
||||
T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub<Output = T>,
|
||||
{
|
||||
type Item = usize;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let zero = <T as Default>::default();
|
||||
if self.sum > zero {
|
||||
let sample =
|
||||
<T as SampleUniform>::Sampler::sample_single(zero, self.sum, &mut self.rng);
|
||||
let (index, weight) = WeightedShuffle::search(self, sample);
|
||||
self.remove(index, weight);
|
||||
return Some(index - 1);
|
||||
}
|
||||
if self.zeros.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let index = <usize as SampleUniform>::Sampler::sample_single(
|
||||
0usize,
|
||||
self.zeros.len(),
|
||||
&mut self.rng,
|
||||
);
|
||||
Some(self.zeros.swap_remove(index))
|
||||
pub fn shuffle<R: Rng>(mut self, rng: &'a mut R) -> impl Iterator<Item = usize> + 'a {
|
||||
std::iter::from_fn(move || {
|
||||
let zero = <T as Default>::default();
|
||||
if self.sum > zero {
|
||||
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, rng);
|
||||
let (index, weight) = WeightedShuffle::search(&self, sample);
|
||||
self.remove(index, weight);
|
||||
return Some(index - 1);
|
||||
}
|
||||
if self.zeros.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let index =
|
||||
<usize as SampleUniform>::Sampler::sample_single(0usize, self.zeros.len(), rng);
|
||||
Some(self.zeros.swap_remove(index))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -355,8 +345,8 @@ mod tests {
|
|||
fn test_weighted_shuffle_empty_weights() {
|
||||
let weights = Vec::<u64>::new();
|
||||
let mut rng = rand::thread_rng();
|
||||
let shuffle = WeightedShuffle::new(&mut rng, &weights);
|
||||
assert!(shuffle.unwrap().next().is_none());
|
||||
let shuffle = WeightedShuffle::new(&weights).unwrap();
|
||||
assert!(shuffle.shuffle(&mut rng).next().is_none());
|
||||
assert_eq!(weighted_sample_single(&mut rng, &weights), None);
|
||||
}
|
||||
|
||||
|
@ -366,7 +356,8 @@ mod tests {
|
|||
let weights = vec![0u64; 5];
|
||||
let seed = [37u8; 32];
|
||||
let mut rng = ChaChaRng::from_seed(seed);
|
||||
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
|
||||
let shuffle = WeightedShuffle::new(&weights).unwrap();
|
||||
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
|
||||
assert_eq!(shuffle, [1, 4, 2, 3, 0]);
|
||||
assert_eq!(weighted_sample_single(&mut rng, &weights), Some(1));
|
||||
}
|
||||
|
@ -380,7 +371,7 @@ mod tests {
|
|||
let weights = [1, 0, 1000, 0, 0, 10, 100, 0];
|
||||
let mut counts = [0; 8];
|
||||
for _ in 0..100000 {
|
||||
let mut shuffle = WeightedShuffle::new(&mut rng, &weights).unwrap();
|
||||
let mut shuffle = WeightedShuffle::new(&weights).unwrap().shuffle(&mut rng);
|
||||
counts[shuffle.next().unwrap()] += 1;
|
||||
let _ = shuffle.count(); // consume the rest.
|
||||
}
|
||||
|
@ -401,7 +392,8 @@ mod tests {
|
|||
.collect();
|
||||
let seed = [48u8; 32];
|
||||
let mut rng = ChaChaRng::from_seed(seed);
|
||||
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
|
||||
let shuffle = WeightedShuffle::new(&weights).unwrap();
|
||||
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
|
||||
assert_eq!(
|
||||
shuffle,
|
||||
[2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5]
|
||||
|
@ -413,7 +405,8 @@ mod tests {
|
|||
);
|
||||
let seed = [37u8; 32];
|
||||
let mut rng = ChaChaRng::from_seed(seed);
|
||||
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
|
||||
let shuffle = WeightedShuffle::new(&weights).unwrap();
|
||||
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
|
||||
assert_eq!(
|
||||
shuffle,
|
||||
[19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11]
|
||||
|
@ -440,7 +433,8 @@ mod tests {
|
|||
let mut seed = [0u8; 32];
|
||||
rng.fill(&mut seed[..]);
|
||||
let mut rng = ChaChaRng::from_seed(seed);
|
||||
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
|
||||
let shuffle = WeightedShuffle::new(&weights).unwrap();
|
||||
let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect();
|
||||
let mut rng = ChaChaRng::from_seed(seed);
|
||||
let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone());
|
||||
assert_eq!(shuffle, shuffle_slow);
|
||||
|
|
Loading…
Reference in New Issue