includes zero weighted entries in WeightedShuffle (#22829)

The current WeightedShuffle implementation excludes zero-weighted entries
from the shuffle:
https://github.com/solana-labs/solana/blob/13e631dcf/gossip/src/weighted_shuffle.rs#L29-L30

Though this might make more sense mathematically, for our use cases
(turbine specifically) it results in less efficient code, because callers
must shuffle the zero-weighted entries separately and append them:
https://github.com/solana-labs/solana/blob/13e631dcf/core/src/cluster_nodes.rs#L409-L430

This commit changes the implementation so that zero-weighted indices are
also included in the shuffle, but appear only at the end, after all
non-zero-weighted indices.
This commit is contained in:
behzad nouri 2022-01-31 16:23:50 +00:00 committed by GitHub
parent 17b4563a6f
commit 604ca9316c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 46 deletions

View File

@ -410,23 +410,11 @@ fn enable_turbine_peers_shuffle_patch(shred_slot: Slot, root_bank: &Bank) -> boo
// Unstaked nodes will always appear at the very end.
fn shuffle_nodes<'a, R: Rng>(rng: &mut R, nodes: &[&'a Node]) -> Vec<&'a Node> {
// Nodes are sorted by (stake, pubkey) in descending order.
let stakes: Vec<u64> = nodes
.iter()
.map(|node| node.stake)
.take_while(|stake| *stake > 0)
.collect();
let num_staked = stakes.len();
let mut out: Vec<_> = WeightedShuffle::new(rng, &stakes)
let stakes: Vec<u64> = nodes.iter().map(|node| node.stake).collect();
WeightedShuffle::new(rng, &stakes)
.unwrap()
.map(|i| nodes[i])
.collect();
let weights = vec![1; nodes.len() - num_staked];
out.extend(
WeightedShuffle::new(rng, &weights)
.unwrap()
.map(|i| nodes[i + num_staked]),
);
out
.collect()
}
impl<T> ClusterNodesCache<T> {

View File

@ -26,12 +26,13 @@ pub enum WeightedShuffleError<T> {
/// - Returned indices are unique in the range [0, weights.len()).
/// - Higher weighted indices tend to appear earlier proportional to their
/// weight.
/// - Zero weighted indices are excluded. Therefore the iterator may have
/// count less than weights.len().
/// - Zero weighted indices are shuffled and appear only at the end, after
/// non-zero weighted indices.
pub struct WeightedShuffle<'a, R, T> {
arr: Vec<T>, // Underlying array implementing binary indexed tree.
sum: T, // Current sum of weights, excluding already selected indices.
rng: &'a mut R, // Random number generator.
arr: Vec<T>, // Underlying array implementing binary indexed tree.
sum: T, // Current sum of weights, excluding already selected indices.
zeros: Vec<usize>, // Indices of zero weighted entries.
rng: &'a mut R, // Random number generator.
}
// The implementation uses binary indexed tree:
@ -50,12 +51,17 @@ where
let zero = <T as Default>::default();
let mut arr = vec![zero; size];
let mut sum = zero;
let mut zeros = Vec::default();
for (mut k, &weight) in (1usize..).zip(weights) {
#[allow(clippy::neg_cmp_op_on_partial_ord)]
// weight < zero does not work for NaNs.
if !(weight >= zero) {
return Err(WeightedShuffleError::NegativeWeight(weight));
}
if weight == zero {
zeros.push(k - 1);
continue;
}
sum = sum
.checked_add(&weight)
.ok_or(WeightedShuffleError::SumOverflow)?;
@ -64,7 +70,12 @@ where
k += k & k.wrapping_neg();
}
}
Ok(Self { arr, sum, rng })
Ok(Self {
arr,
sum,
rng,
zeros,
})
}
}
@ -123,15 +134,22 @@ where
fn next(&mut self) -> Option<Self::Item> {
let zero = <T as Default>::default();
#[allow(clippy::neg_cmp_op_on_partial_ord)]
// self.sum <= zero does not work for NaNs.
if !(self.sum > zero) {
if self.sum > zero {
let sample =
<T as SampleUniform>::Sampler::sample_single(zero, self.sum, &mut self.rng);
let (index, weight) = WeightedShuffle::search(self, sample);
self.remove(index, weight);
return Some(index - 1);
}
if self.zeros.is_empty() {
return None;
}
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, &mut self.rng);
let (index, weight) = WeightedShuffle::search(self, sample);
self.remove(index, weight);
Some(index - 1)
let index = <usize as SampleUniform>::Sampler::sample_single(
0usize,
self.zeros.len(),
&mut self.rng,
);
Some(self.zeros.swap_remove(index))
}
}
@ -142,9 +160,13 @@ where
{
let zero = <T as Default>::default();
let high = cumulative_weights.last().copied().unwrap_or_default();
#[allow(clippy::neg_cmp_op_on_partial_ord)]
if !(high > zero) {
return None;
if high == zero {
if cumulative_weights.is_empty() {
return None;
}
let index =
<usize as SampleUniform>::Sampler::sample_single(0usize, cumulative_weights.len(), rng);
return Some(index);
}
let sample = <T as SampleUniform>::Sampler::sample_single(zero, high, rng);
let mut lo = 0usize;
@ -234,11 +256,14 @@ mod tests {
R: Rng,
{
let mut shuffle = Vec::with_capacity(weights.len());
loop {
let high: u64 = weights.iter().sum();
if high == 0 {
break shuffle;
}
let mut high: u64 = weights.iter().sum();
let mut zeros: Vec<_> = weights
.iter()
.enumerate()
.filter(|(_, w)| **w == 0)
.map(|(i, _)| i)
.collect();
while high != 0 {
let sample = rng.gen_range(0, high);
let index = weights
.iter()
@ -249,8 +274,14 @@ mod tests {
.position(|acc| sample < acc)
.unwrap();
shuffle.push(index);
high -= weights[index];
weights[index] = 0;
}
while !zeros.is_empty() {
let index = <usize as SampleUniform>::Sampler::sample_single(0usize, zeros.len(), rng);
shuffle.push(zeros.swap_remove(index));
}
shuffle
}
#[test]
@ -329,14 +360,15 @@ mod tests {
assert_eq!(weighted_sample_single(&mut rng, &weights), None);
}
// Asserts that zero weights will return empty shuffle.
// Asserts that zero weights will be shuffled.
#[test]
fn test_weighted_shuffle_zero_weights() {
let weights = vec![0u64; 5];
let mut rng = rand::thread_rng();
let shuffle = WeightedShuffle::new(&mut rng, &weights);
assert!(shuffle.unwrap().next().is_none());
assert_eq!(weighted_sample_single(&mut rng, &weights), None);
let seed = [37u8; 32];
let mut rng = ChaChaRng::from_seed(seed);
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
assert_eq!(shuffle, [1, 4, 2, 3, 0]);
assert_eq!(weighted_sample_single(&mut rng, &weights), Some(1));
}
// Asserts that each index is selected proportional to its weight.
@ -352,13 +384,13 @@ mod tests {
counts[shuffle.next().unwrap()] += 1;
let _ = shuffle.count(); // consume the rest.
}
assert_eq!(counts, [101, 0, 90113, 0, 0, 891, 8895, 0]);
assert_eq!(counts, [95, 0, 90069, 0, 0, 908, 8928, 0]);
}
#[test]
fn test_weighted_shuffle_hard_coded() {
let weights = [
78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 17, 4, 50, 96, 83, 33, 16, 72,
78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 0, 17, 4, 50, 96, 0, 83, 33, 16, 72,
];
let cumulative_weights: Vec<_> = weights
.iter()
@ -372,7 +404,7 @@ mod tests {
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
assert_eq!(
shuffle,
[2, 11, 16, 0, 13, 14, 15, 10, 1, 9, 7, 6, 12, 18, 4, 17, 3, 8]
[2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5]
);
let mut rng = ChaChaRng::from_seed(seed);
assert_eq!(
@ -384,12 +416,12 @@ mod tests {
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
assert_eq!(
shuffle,
[17, 3, 14, 13, 6, 10, 15, 16, 9, 2, 4, 1, 0, 7, 8, 18, 11, 12]
[19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11]
);
let mut rng = ChaChaRng::from_seed(seed);
assert_eq!(
weighted_sample_single(&mut rng, &cumulative_weights),
Some(17),
Some(19),
);
}