removes legacy weighted_shuffle and weighted_best methods (#24125)

Older weighted_shuffle is based on a heuristic which results in biased
samples as shown in:
https://github.com/solana-labs/solana/pull/18343
and can be replaced with WeightedShuffle.

Also, as described in:
https://github.com/solana-labs/solana/pull/13919
weighted_best can be replaced with rand::distributions::WeightedIndex,
or WeightdShuffle::first.
This commit is contained in:
behzad nouri 2022-04-05 19:19:22 +00:00 committed by GitHub
parent 4ea59d8cb4
commit db23295e1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 33 additions and 163 deletions

View File

@ -17,7 +17,7 @@ use {
solana_gossip::{ solana_gossip::{
cluster_info::{ClusterInfo, ClusterInfoError}, cluster_info::{ClusterInfo, ClusterInfoError},
contact_info::ContactInfo, contact_info::ContactInfo,
weighted_shuffle::{weighted_best, weighted_shuffle}, weighted_shuffle::WeightedShuffle,
}, },
solana_ledger::{ solana_ledger::{
ancestor_iterator::{AncestorIterator, AncestorIteratorWithHash}, ancestor_iterator::{AncestorIterator, AncestorIteratorWithHash},
@ -525,16 +525,17 @@ impl ServeRepair {
if repair_peers.is_empty() { if repair_peers.is_empty() {
return Err(ClusterInfoError::NoPeers.into()); return Err(ClusterInfoError::NoPeers.into());
} }
let weights = cluster_slots.compute_weights_exclude_nonfrozen(slot, &repair_peers); let (weights, index): (Vec<_>, Vec<_>) = cluster_slots
let mut sampled_validators = weighted_shuffle( .compute_weights_exclude_nonfrozen(slot, &repair_peers)
weights.into_iter().map(|(stake, _i)| stake),
solana_sdk::pubkey::new_rand().to_bytes(),
);
sampled_validators.truncate(ANCESTOR_HASH_REPAIR_SAMPLE_SIZE);
Ok(sampled_validators
.into_iter() .into_iter()
.unzip();
let peers = WeightedShuffle::new("repair_request_ancestor_hashes", &weights)
.shuffle(&mut rand::thread_rng())
.take(ANCESTOR_HASH_REPAIR_SAMPLE_SIZE)
.map(|i| index[i])
.map(|i| (repair_peers[i].id, repair_peers[i].serve_repair)) .map(|i| (repair_peers[i].id, repair_peers[i].serve_repair))
.collect()) .collect();
Ok(peers)
} }
pub fn repair_request_duplicate_compute_best_peer( pub fn repair_request_duplicate_compute_best_peer(
@ -547,8 +548,12 @@ impl ServeRepair {
if repair_peers.is_empty() { if repair_peers.is_empty() {
return Err(ClusterInfoError::NoPeers.into()); return Err(ClusterInfoError::NoPeers.into());
} }
let weights = cluster_slots.compute_weights_exclude_nonfrozen(slot, &repair_peers); let (weights, index): (Vec<_>, Vec<_>) = cluster_slots
let n = weighted_best(&weights, solana_sdk::pubkey::new_rand().to_bytes()); .compute_weights_exclude_nonfrozen(slot, &repair_peers)
.into_iter()
.unzip();
let k = WeightedIndex::new(weights)?.sample(&mut rand::thread_rng());
let n = index[k];
Ok((repair_peers[n].id, repair_peers[n].serve_repair)) Ok((repair_peers[n].id, repair_peers[n].serve_repair))
} }

View File

@ -5,7 +5,7 @@ extern crate test;
use { use {
rand::{Rng, SeedableRng}, rand::{Rng, SeedableRng},
rand_chacha::ChaChaRng, rand_chacha::ChaChaRng,
solana_gossip::weighted_shuffle::{weighted_shuffle, WeightedShuffle}, solana_gossip::weighted_shuffle::WeightedShuffle,
std::iter::repeat_with, std::iter::repeat_with,
test::Bencher, test::Bencher,
}; };
@ -15,18 +15,7 @@ fn make_weights<R: Rng>(rng: &mut R) -> Vec<u64> {
} }
#[bench] #[bench]
fn bench_weighted_shuffle_old(bencher: &mut Bencher) { fn bench_weighted_shuffle(bencher: &mut Bencher) {
let mut seed = [0u8; 32];
let mut rng = rand::thread_rng();
let weights = make_weights(&mut rng);
bencher.iter(|| {
rng.fill(&mut seed[..]);
weighted_shuffle::<u64, &u64, std::slice::Iter<'_, u64>>(weights.iter(), seed);
});
}
#[bench]
fn bench_weighted_shuffle_new(bencher: &mut Bencher) {
let mut seed = [0u8; 32]; let mut seed = [0u8; 32];
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let weights = make_weights(&mut rng); let weights = make_weights(&mut rng);

View File

@ -1,18 +1,12 @@
//! The `weighted_shuffle` module provides an iterator over shuffled weights. //! The `weighted_shuffle` module provides an iterator over shuffled weights.
use { use {
itertools::Itertools, num_traits::CheckedAdd,
num_traits::{CheckedAdd, FromPrimitive, ToPrimitive},
rand::{ rand::{
distributions::uniform::{SampleUniform, UniformSampler}, distributions::uniform::{SampleUniform, UniformSampler},
Rng, SeedableRng, Rng,
},
rand_chacha::ChaChaRng,
std::{
borrow::Borrow,
iter,
ops::{AddAssign, Div, Sub, SubAssign},
}, },
std::ops::{AddAssign, Sub, SubAssign},
}; };
/// Implements an iterator where indices are shuffled according to their /// Implements an iterator where indices are shuffled according to their
@ -182,68 +176,12 @@ where
} }
} }
/// Returns a list of indexes shuffled based on the input weights
/// Note - The sum of all weights must not exceed `u64::MAX`
pub fn weighted_shuffle<T, B, F>(weights: F, seed: [u8; 32]) -> Vec<usize>
where
T: Copy + PartialOrd + iter::Sum + Div<T, Output = T> + FromPrimitive + ToPrimitive,
B: Borrow<T>,
F: Iterator<Item = B> + Clone,
{
let total_weight: T = weights.clone().map(|x| *x.borrow()).sum();
let mut rng = ChaChaRng::from_seed(seed);
weights
.enumerate()
.map(|(i, weight)| {
let weight = weight.borrow();
// This generates an "inverse" weight but it avoids floating point math
let x = (total_weight / *weight)
.to_u64()
.expect("values > u64::max are not supported");
(
i,
// capture the u64 into u128s to prevent overflow
rng.gen_range(1, u128::from(std::u16::MAX)) * u128::from(x),
)
})
// sort in ascending order
.sorted_by(|(_, l_val), (_, r_val)| l_val.cmp(r_val))
.map(|x| x.0)
.collect()
}
/// Returns the highest index after computing a weighted shuffle.
/// Saves doing any sorting for O(n) max calculation.
// TODO: Remove in favor of rand::distributions::WeightedIndex.
pub fn weighted_best(weights_and_indexes: &[(u64, usize)], seed: [u8; 32]) -> usize {
if weights_and_indexes.is_empty() {
return 0;
}
let mut rng = ChaChaRng::from_seed(seed);
let total_weight: u64 = weights_and_indexes.iter().map(|x| x.0).sum();
let mut lowest_weight = std::u128::MAX;
let mut best_index = 0;
for v in weights_and_indexes {
// This generates an "inverse" weight but it avoids floating point math
let x = (total_weight / v.0)
.to_u64()
.expect("values > u64::max are not supported");
// capture the u64 into u128s to prevent overflow
let computed_weight = rng.gen_range(1, u128::from(std::u16::MAX)) * u128::from(x);
// The highest input weight maps to the lowest computed weight
if computed_weight < lowest_weight {
lowest_weight = computed_weight;
best_index = v.1;
}
}
best_index
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use { use {
super::*, super::*,
rand::SeedableRng,
rand_chacha::ChaChaRng,
std::{convert::TryInto, iter::repeat_with}, std::{convert::TryInto, iter::repeat_with},
}; };
@ -280,72 +218,6 @@ mod tests {
shuffle shuffle
} }
#[test]
fn test_weighted_shuffle_iterator() {
let mut test_set = [0; 6];
let mut count = 0;
let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1].into_iter(), [0x5a; 32]);
shuffle.into_iter().for_each(|x| {
assert_eq!(test_set[x], 0);
test_set[x] = 1;
count += 1;
});
assert_eq!(count, 6);
}
#[test]
fn test_weighted_shuffle_iterator_large() {
let mut test_set = [0; 100];
let mut test_weights = vec![0; 100];
(0..100).for_each(|i| test_weights[i] = (i + 1) as u64);
let mut count = 0;
let shuffle = weighted_shuffle(test_weights.into_iter(), [0xa5; 32]);
shuffle.into_iter().for_each(|x| {
assert_eq!(test_set[x], 0);
test_set[x] = 1;
count += 1;
});
assert_eq!(count, 100);
}
#[test]
fn test_weighted_shuffle_compare() {
let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1].into_iter(), [0x5a; 32]);
let shuffle1 = weighted_shuffle(vec![50, 10, 2, 1, 1, 1].into_iter(), [0x5a; 32]);
shuffle1
.into_iter()
.zip(shuffle.into_iter())
.for_each(|(x, y)| {
assert_eq!(x, y);
});
}
#[test]
fn test_weighted_shuffle_imbalanced() {
let mut weights = vec![std::u32::MAX as u64; 3];
weights.push(1);
let shuffle = weighted_shuffle(weights.iter().cloned(), [0x5a; 32]);
shuffle.into_iter().for_each(|x| {
if x == weights.len() - 1 {
assert_eq!(weights[x], 1);
} else {
assert_eq!(weights[x], std::u32::MAX as u64);
}
});
}
#[test]
fn test_weighted_best() {
let weights_and_indexes: Vec<_> = vec![100u64, 1000, 10_000, 10]
.into_iter()
.enumerate()
.map(|(i, weight)| (weight, i))
.collect();
let best_index = weighted_best(&weights_and_indexes, [0x5b; 32]);
assert_eq!(best_index, 2);
}
// Asserts that empty weights will return empty shuffle. // Asserts that empty weights will return empty shuffle.
#[test] #[test]
fn test_weighted_shuffle_empty_weights() { fn test_weighted_shuffle_empty_weights() {

View File

@ -2,12 +2,14 @@
use { use {
crossbeam_channel::{unbounded, Receiver, Sender, TryRecvError}, crossbeam_channel::{unbounded, Receiver, Sender, TryRecvError},
itertools::Itertools, itertools::Itertools,
rand::SeedableRng,
rand_chacha::ChaChaRng,
rayon::{iter::ParallelIterator, prelude::*}, rayon::{iter::ParallelIterator, prelude::*},
serial_test::serial, serial_test::serial,
solana_gossip::{ solana_gossip::{
cluster_info::{compute_retransmit_peers, ClusterInfo}, cluster_info::{compute_retransmit_peers, ClusterInfo},
contact_info::ContactInfo, contact_info::ContactInfo,
weighted_shuffle::weighted_shuffle, weighted_shuffle::WeightedShuffle,
}, },
solana_sdk::{pubkey::Pubkey, signer::keypair::Keypair}, solana_sdk::{pubkey::Pubkey, signer::keypair::Keypair},
solana_streamer::socket::SocketAddrSpace, solana_streamer::socket::SocketAddrSpace,
@ -95,11 +97,13 @@ fn shuffle_peers_and_index(
} }
fn stake_weighted_shuffle(stakes_and_index: &[(u64, usize)], seed: [u8; 32]) -> Vec<(u64, usize)> { fn stake_weighted_shuffle(stakes_and_index: &[(u64, usize)], seed: [u8; 32]) -> Vec<(u64, usize)> {
let stake_weights = stakes_and_index.iter().map(|(w, _)| *w); let mut rng = ChaChaRng::from_seed(seed);
let stake_weights: Vec<_> = stakes_and_index.iter().map(|(w, _)| *w).collect();
let shuffle = weighted_shuffle(stake_weights, seed); let shuffle = WeightedShuffle::new("stake_weighted_shuffle", &stake_weights);
shuffle
shuffle.iter().map(|x| stakes_and_index[*x]).collect() .shuffle(&mut rng)
.map(|i| stakes_and_index[i])
.collect()
} }
fn retransmit( fn retransmit(