includes zero weighted entries in WeightedShuffle (#22829)
The current WeightedShuffle implementation excludes zero-weighted entries from the shuffle: https://github.com/solana-labs/solana/blob/13e631dcf/gossip/src/weighted_shuffle.rs#L29-L30 . Though mathematically this might make more sense, for our use cases (turbine specifically) it results in less efficient code: https://github.com/solana-labs/solana/blob/13e631dcf/core/src/cluster_nodes.rs#L409-L430 . This commit changes the implementation so that zero-weighted indices are also included in the shuffle, but appear only at the end, after the non-zero-weighted indices.
This commit is contained in:
parent
17b4563a6f
commit
604ca9316c
|
@ -410,23 +410,11 @@ fn enable_turbine_peers_shuffle_patch(shred_slot: Slot, root_bank: &Bank) -> boo
|
|||
// Unstaked nodes will always appear at the very end.
|
||||
fn shuffle_nodes<'a, R: Rng>(rng: &mut R, nodes: &[&'a Node]) -> Vec<&'a Node> {
|
||||
// Nodes are sorted by (stake, pubkey) in descending order.
|
||||
let stakes: Vec<u64> = nodes
|
||||
.iter()
|
||||
.map(|node| node.stake)
|
||||
.take_while(|stake| *stake > 0)
|
||||
.collect();
|
||||
let num_staked = stakes.len();
|
||||
let mut out: Vec<_> = WeightedShuffle::new(rng, &stakes)
|
||||
let stakes: Vec<u64> = nodes.iter().map(|node| node.stake).collect();
|
||||
WeightedShuffle::new(rng, &stakes)
|
||||
.unwrap()
|
||||
.map(|i| nodes[i])
|
||||
.collect();
|
||||
let weights = vec![1; nodes.len() - num_staked];
|
||||
out.extend(
|
||||
WeightedShuffle::new(rng, &weights)
|
||||
.unwrap()
|
||||
.map(|i| nodes[i + num_staked]),
|
||||
);
|
||||
out
|
||||
.collect()
|
||||
}
|
||||
|
||||
impl<T> ClusterNodesCache<T> {
|
||||
|
|
|
@ -26,12 +26,13 @@ pub enum WeightedShuffleError<T> {
|
|||
/// - Returned indices are unique in the range [0, weights.len()).
|
||||
/// - Higher weighted indices tend to appear earlier proportional to their
|
||||
/// weight.
|
||||
/// - Zero weighted indices are excluded. Therefore the iterator may have
|
||||
/// count less than weights.len().
|
||||
/// - Zero weighted indices are shuffled and appear only at the end, after
|
||||
/// non-zero weighted indices.
|
||||
pub struct WeightedShuffle<'a, R, T> {
|
||||
arr: Vec<T>, // Underlying array implementing binary indexed tree.
|
||||
sum: T, // Current sum of weights, excluding already selected indices.
|
||||
rng: &'a mut R, // Random number generator.
|
||||
arr: Vec<T>, // Underlying array implementing binary indexed tree.
|
||||
sum: T, // Current sum of weights, excluding already selected indices.
|
||||
zeros: Vec<usize>, // Indices of zero weighted entries.
|
||||
rng: &'a mut R, // Random number generator.
|
||||
}
|
||||
|
||||
// The implementation uses binary indexed tree:
|
||||
|
@ -50,12 +51,17 @@ where
|
|||
let zero = <T as Default>::default();
|
||||
let mut arr = vec![zero; size];
|
||||
let mut sum = zero;
|
||||
let mut zeros = Vec::default();
|
||||
for (mut k, &weight) in (1usize..).zip(weights) {
|
||||
#[allow(clippy::neg_cmp_op_on_partial_ord)]
|
||||
// weight < zero does not work for NaNs.
|
||||
if !(weight >= zero) {
|
||||
return Err(WeightedShuffleError::NegativeWeight(weight));
|
||||
}
|
||||
if weight == zero {
|
||||
zeros.push(k - 1);
|
||||
continue;
|
||||
}
|
||||
sum = sum
|
||||
.checked_add(&weight)
|
||||
.ok_or(WeightedShuffleError::SumOverflow)?;
|
||||
|
@ -64,7 +70,12 @@ where
|
|||
k += k & k.wrapping_neg();
|
||||
}
|
||||
}
|
||||
Ok(Self { arr, sum, rng })
|
||||
Ok(Self {
|
||||
arr,
|
||||
sum,
|
||||
rng,
|
||||
zeros,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -123,15 +134,22 @@ where
|
|||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let zero = <T as Default>::default();
|
||||
#[allow(clippy::neg_cmp_op_on_partial_ord)]
|
||||
// self.sum <= zero does not work for NaNs.
|
||||
if !(self.sum > zero) {
|
||||
if self.sum > zero {
|
||||
let sample =
|
||||
<T as SampleUniform>::Sampler::sample_single(zero, self.sum, &mut self.rng);
|
||||
let (index, weight) = WeightedShuffle::search(self, sample);
|
||||
self.remove(index, weight);
|
||||
return Some(index - 1);
|
||||
}
|
||||
if self.zeros.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, &mut self.rng);
|
||||
let (index, weight) = WeightedShuffle::search(self, sample);
|
||||
self.remove(index, weight);
|
||||
Some(index - 1)
|
||||
let index = <usize as SampleUniform>::Sampler::sample_single(
|
||||
0usize,
|
||||
self.zeros.len(),
|
||||
&mut self.rng,
|
||||
);
|
||||
Some(self.zeros.swap_remove(index))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -142,9 +160,13 @@ where
|
|||
{
|
||||
let zero = <T as Default>::default();
|
||||
let high = cumulative_weights.last().copied().unwrap_or_default();
|
||||
#[allow(clippy::neg_cmp_op_on_partial_ord)]
|
||||
if !(high > zero) {
|
||||
return None;
|
||||
if high == zero {
|
||||
if cumulative_weights.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let index =
|
||||
<usize as SampleUniform>::Sampler::sample_single(0usize, cumulative_weights.len(), rng);
|
||||
return Some(index);
|
||||
}
|
||||
let sample = <T as SampleUniform>::Sampler::sample_single(zero, high, rng);
|
||||
let mut lo = 0usize;
|
||||
|
@ -234,11 +256,14 @@ mod tests {
|
|||
R: Rng,
|
||||
{
|
||||
let mut shuffle = Vec::with_capacity(weights.len());
|
||||
loop {
|
||||
let high: u64 = weights.iter().sum();
|
||||
if high == 0 {
|
||||
break shuffle;
|
||||
}
|
||||
let mut high: u64 = weights.iter().sum();
|
||||
let mut zeros: Vec<_> = weights
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, w)| **w == 0)
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
while high != 0 {
|
||||
let sample = rng.gen_range(0, high);
|
||||
let index = weights
|
||||
.iter()
|
||||
|
@ -249,8 +274,14 @@ mod tests {
|
|||
.position(|acc| sample < acc)
|
||||
.unwrap();
|
||||
shuffle.push(index);
|
||||
high -= weights[index];
|
||||
weights[index] = 0;
|
||||
}
|
||||
while !zeros.is_empty() {
|
||||
let index = <usize as SampleUniform>::Sampler::sample_single(0usize, zeros.len(), rng);
|
||||
shuffle.push(zeros.swap_remove(index));
|
||||
}
|
||||
shuffle
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -329,14 +360,15 @@ mod tests {
|
|||
assert_eq!(weighted_sample_single(&mut rng, &weights), None);
|
||||
}
|
||||
|
||||
// Asserts that zero weights will return empty shuffle.
|
||||
// Asserts that zero weights will be shuffled.
|
||||
#[test]
|
||||
fn test_weighted_shuffle_zero_weights() {
|
||||
let weights = vec![0u64; 5];
|
||||
let mut rng = rand::thread_rng();
|
||||
let shuffle = WeightedShuffle::new(&mut rng, &weights);
|
||||
assert!(shuffle.unwrap().next().is_none());
|
||||
assert_eq!(weighted_sample_single(&mut rng, &weights), None);
|
||||
let seed = [37u8; 32];
|
||||
let mut rng = ChaChaRng::from_seed(seed);
|
||||
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
|
||||
assert_eq!(shuffle, [1, 4, 2, 3, 0]);
|
||||
assert_eq!(weighted_sample_single(&mut rng, &weights), Some(1));
|
||||
}
|
||||
|
||||
// Asserts that each index is selected proportional to its weight.
|
||||
|
@ -352,13 +384,13 @@ mod tests {
|
|||
counts[shuffle.next().unwrap()] += 1;
|
||||
let _ = shuffle.count(); // consume the rest.
|
||||
}
|
||||
assert_eq!(counts, [101, 0, 90113, 0, 0, 891, 8895, 0]);
|
||||
assert_eq!(counts, [95, 0, 90069, 0, 0, 908, 8928, 0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_weighted_shuffle_hard_coded() {
|
||||
let weights = [
|
||||
78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 17, 4, 50, 96, 83, 33, 16, 72,
|
||||
78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 0, 17, 4, 50, 96, 0, 83, 33, 16, 72,
|
||||
];
|
||||
let cumulative_weights: Vec<_> = weights
|
||||
.iter()
|
||||
|
@ -372,7 +404,7 @@ mod tests {
|
|||
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
|
||||
assert_eq!(
|
||||
shuffle,
|
||||
[2, 11, 16, 0, 13, 14, 15, 10, 1, 9, 7, 6, 12, 18, 4, 17, 3, 8]
|
||||
[2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5]
|
||||
);
|
||||
let mut rng = ChaChaRng::from_seed(seed);
|
||||
assert_eq!(
|
||||
|
@ -384,12 +416,12 @@ mod tests {
|
|||
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
|
||||
assert_eq!(
|
||||
shuffle,
|
||||
[17, 3, 14, 13, 6, 10, 15, 16, 9, 2, 4, 1, 0, 7, 8, 18, 11, 12]
|
||||
[19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11]
|
||||
);
|
||||
let mut rng = ChaChaRng::from_seed(seed);
|
||||
assert_eq!(
|
||||
weighted_sample_single(&mut rng, &cumulative_weights),
|
||||
Some(17),
|
||||
Some(19),
|
||||
);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue