Refactor Weighted Shuffle (#6614)

automerge
This commit is contained in:
Sagar Dhawan 2019-10-29 21:02:11 -07:00 committed by Grimes
parent 4ec95043d7
commit 801337a422
4 changed files with 20 additions and 24 deletions

View File

@ -26,9 +26,7 @@ use crate::weighted_shuffle::{weighted_best, weighted_shuffle};
use bincode::{deserialize, serialize, serialized_size};
use core::cmp;
use itertools::Itertools;
use rand::SeedableRng;
use rand::{thread_rng, Rng};
use rand_chacha::ChaChaRng;
use solana_ledger::bank_forks::BankForks;
use solana_ledger::blocktree::Blocktree;
use solana_ledger::staking_utils;
@ -510,11 +508,11 @@ impl ClusterInfo {
fn stake_weighted_shuffle(
stakes_and_index: &[(u64, usize)],
rng: ChaChaRng,
seed: [u8; 32],
) -> Vec<(u64, usize)> {
let stake_weights = stakes_and_index.iter().map(|(w, _)| *w).collect();
let shuffle = weighted_shuffle(stake_weights, rng);
let shuffle = weighted_shuffle(stake_weights, seed);
shuffle.iter().map(|x| stakes_and_index[*x]).collect()
}
@ -536,9 +534,9 @@ impl ClusterInfo {
id: &Pubkey,
peers: &[ContactInfo],
stakes_and_index: &[(u64, usize)],
rng: ChaChaRng,
seed: [u8; 32],
) -> (usize, Vec<(u64, usize)>) {
let shuffled_stakes_and_index = ClusterInfo::stake_weighted_shuffle(stakes_and_index, rng);
let shuffled_stakes_and_index = ClusterInfo::stake_weighted_shuffle(stakes_and_index, seed);
let mut self_index = 0;
shuffled_stakes_and_index
.iter()
@ -723,7 +721,7 @@ impl ClusterInfo {
.into_iter()
.zip(seeds)
.map(|(shred, seed)| {
let broadcast_index = weighted_best(&peers_and_stakes, ChaChaRng::from_seed(*seed));
let broadcast_index = weighted_best(&peers_and_stakes, *seed);
(shred, &peers[broadcast_index].tvu)
})

View File

@ -20,8 +20,7 @@ use indexmap::map::IndexMap;
use itertools::Itertools;
use rand;
use rand::seq::SliceRandom;
use rand::{thread_rng, RngCore, SeedableRng};
use rand_chacha::ChaChaRng;
use rand::{thread_rng, RngCore};
use solana_runtime::bloom::Bloom;
use solana_sdk::hash::Hash;
use solana_sdk::pubkey::Pubkey;
@ -106,7 +105,7 @@ impl CrdsGossipPush {
seed[0..8].copy_from_slice(&thread_rng().next_u64().to_le_bytes());
let shuffle = weighted_shuffle(
staked_peers.iter().map(|(_, stake)| *stake).collect_vec(),
ChaChaRng::from_seed(seed),
seed,
);
let mut keep = HashSet::new();
@ -244,7 +243,7 @@ impl CrdsGossipPush {
seed[0..8].copy_from_slice(&thread_rng().next_u64().to_le_bytes());
let mut shuffle = weighted_shuffle(
options.iter().map(|weighted| weighted.0).collect_vec(),
ChaChaRng::from_seed(seed),
seed,
)
.into_iter();

View File

@ -10,8 +10,6 @@ use crate::{
window_service::{should_retransmit_and_persist, WindowService},
};
use crossbeam_channel::Receiver as CrossbeamReceiver;
use rand::SeedableRng;
use rand_chacha::ChaChaRng;
use solana_ledger::{
bank_forks::BankForks,
blocktree::{Blocktree, CompletedSlotsReceiver},
@ -92,7 +90,7 @@ fn retransmit(
&me.id,
&peers,
&stakes_and_index,
ChaChaRng::from_seed(packet.meta.seed),
packet.meta.seed,
);
peers_len = cmp::max(peers_len, shuffled_stakes_and_index.len());
shuffled_stakes_and_index.remove(my_index);

View File

@ -2,18 +2,19 @@
use itertools::Itertools;
use num_traits::{FromPrimitive, ToPrimitive};
use rand::Rng;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaChaRng;
use std::iter;
use std::ops::Div;
/// Returns a list of indexes shuffled based on the input weights
/// Note - The sum of all weights must not exceed `u64::MAX`
pub fn weighted_shuffle<T>(weights: Vec<T>, mut rng: ChaChaRng) -> Vec<usize>
pub fn weighted_shuffle<T>(weights: Vec<T>, seed: [u8; 32]) -> Vec<usize>
where
T: Copy + PartialOrd + iter::Sum + Div<T, Output = T> + FromPrimitive + ToPrimitive,
{
let total_weight: T = weights.clone().into_iter().sum();
let mut rng = ChaChaRng::from_seed(seed);
weights
.into_iter()
.enumerate()
@ -36,10 +37,11 @@ where
/// Returns the highest index after computing a weighted shuffle.
/// Saves doing any sorting for O(n) max calculation.
pub fn weighted_best(weights_and_indexes: &[(u64, usize)], mut rng: ChaChaRng) -> usize {
pub fn weighted_best(weights_and_indexes: &[(u64, usize)], seed: [u8; 32]) -> usize {
if weights_and_indexes.is_empty() {
return 0;
}
let mut rng = ChaChaRng::from_seed(seed);
let total_weight: u64 = weights_and_indexes.iter().map(|x| x.0).sum();
let mut lowest_weight = std::u128::MAX;
let mut best_index = 0;
@ -63,13 +65,12 @@ pub fn weighted_best(weights_and_indexes: &[(u64, usize)], mut rng: ChaChaRng) -
#[cfg(test)]
mod tests {
use super::*;
use rand::SeedableRng;
#[test]
fn test_weighted_shuffle_iterator() {
let mut test_set = [0; 6];
let mut count = 0;
let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1], ChaChaRng::from_seed([0x5a; 32]));
let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1], [0x5a; 32]);
shuffle.into_iter().for_each(|x| {
assert_eq!(test_set[x], 0);
test_set[x] = 1;
@ -84,7 +85,7 @@ mod tests {
let mut test_weights = vec![0; 100];
(0..100).for_each(|i| test_weights[i] = (i + 1) as u64);
let mut count = 0;
let shuffle = weighted_shuffle(test_weights, ChaChaRng::from_seed([0xa5; 32]));
let shuffle = weighted_shuffle(test_weights, [0xa5; 32]);
shuffle.into_iter().for_each(|x| {
assert_eq!(test_set[x], 0);
test_set[x] = 1;
@ -95,9 +96,9 @@ mod tests {
#[test]
fn test_weighted_shuffle_compare() {
let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1], ChaChaRng::from_seed([0x5a; 32]));
let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1], [0x5a; 32]);
let shuffle1 = weighted_shuffle(vec![50, 10, 2, 1, 1, 1], ChaChaRng::from_seed([0x5a; 32]));
let shuffle1 = weighted_shuffle(vec![50, 10, 2, 1, 1, 1], [0x5a; 32]);
shuffle1
.into_iter()
.zip(shuffle.into_iter())
@ -110,7 +111,7 @@ mod tests {
fn test_weighted_shuffle_imbalanced() {
let mut weights = vec![std::u32::MAX as u64; 3];
weights.push(1);
let shuffle = weighted_shuffle(weights.clone(), ChaChaRng::from_seed([0x5a; 32]));
let shuffle = weighted_shuffle(weights.clone(), [0x5a; 32]);
shuffle.into_iter().for_each(|x| {
if x == weights.len() - 1 {
assert_eq!(weights[x], 1);
@ -127,7 +128,7 @@ mod tests {
.enumerate()
.map(|(i, weight)| (weight, i))
.collect();
let best_index = weighted_best(&weights_and_indexes, ChaChaRng::from_seed([0x5b; 32]));
let best_index = weighted_best(&weights_and_indexes, [0x5b; 32]);
assert_eq!(best_index, 2);
}
}