implements an unbiased weighted shuffle using binary indexed tree (#18343)

Current implementation of weighted_shuffle:
https://github.com/solana-labs/solana/blob/b08f8bd1b/gossip/src/weighted_shuffle.rs#L11-L37
uses a heuristic which results in biased samples.

For example, if the weights are [1, 10, 100], then the 3rd index should
come first 100 times more often than the 1st index. However,
weighted_shuffle is picking the 3rd index 200+ times more often than the
1st index, showing a disproportional bias in favor of higher weights.

This commit implements weighted shuffle using binary indexed tree to
maintain cumulative sum of weights while sampling. The resulting samples
are demonstrably unbiased and precisely proportional to the weights.

Additionally the iterator interface allows to skip computations when
not all indices are processed.

Of the use cases of weighted_shuffle, changing turbine code requires
feature-gating to keep the cluster in sync. That is not updated in
this commit, but can be done together with future updates to turbine.
This commit is contained in:
behzad nouri 2021-07-07 14:14:43 +00:00 committed by GitHub
parent 72da25e9d2
commit dba42c57b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 266 additions and 31 deletions

View File

@ -0,0 +1,39 @@
#![feature(test)]
extern crate test;
use {
rand::{Rng, SeedableRng},
rand_chacha::ChaChaRng,
solana_gossip::weighted_shuffle::{weighted_shuffle, WeightedShuffle},
std::iter::repeat_with,
test::Bencher,
};
fn make_weights<R: Rng>(rng: &mut R) -> Vec<u64> {
repeat_with(|| rng.gen_range(1, 100)).take(1000).collect()
}
#[bench]
fn bench_weighted_shuffle_old(bencher: &mut Bencher) {
let mut seed = [0u8; 32];
let mut rng = rand::thread_rng();
let weights = make_weights(&mut rng);
bencher.iter(|| {
rng.fill(&mut seed[..]);
weighted_shuffle(&weights, seed);
});
}
#[bench]
fn bench_weighted_shuffle_new(bencher: &mut Bencher) {
let mut seed = [0u8; 32];
let mut rng = rand::thread_rng();
let weights = make_weights(&mut rng);
bencher.iter(|| {
rng.fill(&mut seed[..]);
WeightedShuffle::new(&mut ChaChaRng::from_seed(seed), &weights)
.unwrap()
.collect::<Vec<_>>()
});
}

View File

@ -29,7 +29,7 @@ use {
gossip_error::GossipError,
ping_pong::{self, PingCache, Pong},
socketaddr, socketaddr_any,
weighted_shuffle::weighted_shuffle,
weighted_shuffle::{weighted_shuffle, WeightedShuffle},
},
bincode::{serialize, serialized_size},
itertools::Itertools,
@ -2043,11 +2043,8 @@ impl ClusterInfo {
if responses.is_empty() {
return packets;
}
let shuffle = {
let mut seed = [0; 32];
rand::thread_rng().fill(&mut seed[..]);
weighted_shuffle(&scores, seed).into_iter()
};
let mut rng = rand::thread_rng();
let shuffle = WeightedShuffle::new(&mut rng, &scores).unwrap();
let mut total_bytes = 0;
let mut sent = 0;
for (addr, response) in shuffle.map(|i| &responses[i]) {

View File

@ -18,7 +18,7 @@ use {
crds_gossip_error::CrdsGossipError,
crds_value::CrdsValue,
ping_pong::PingCache,
weighted_shuffle::weighted_shuffle,
weighted_shuffle::WeightedShuffle,
},
itertools::Itertools,
lru::LruCache,
@ -235,13 +235,10 @@ impl CrdsGossipPull {
if peers.is_empty() {
return Err(CrdsGossipError::NoPeers);
}
let mut peers = {
let mut rng = rand::thread_rng();
let mut seed = [0u8; 32];
rng.fill(&mut seed[..]);
let index = weighted_shuffle(&weights, seed);
index.into_iter().map(|i| peers[i])
};
let mut rng = rand::thread_rng();
let mut peers = WeightedShuffle::new(&mut rng, &weights)
.unwrap()
.map(|i| peers[i]);
let peer = {
let mut rng = rand::thread_rng();
let mut ping_cache = ping_cache.lock().unwrap();
@ -273,7 +270,7 @@ impl CrdsGossipPull {
now: u64,
gossip_validators: Option<&HashSet<Pubkey>>,
stakes: &HashMap<Pubkey, u64>,
) -> Vec<(f32, &'a ContactInfo)> {
) -> Vec<(u64, &'a ContactInfo)> {
let mut rng = rand::thread_rng();
let active_cutoff = now.saturating_sub(PULL_ACTIVE_TIMEOUT_MS);
crds.get_nodes()
@ -307,7 +304,9 @@ impl CrdsGossipPull {
let since = (now.saturating_sub(req_time).min(3600 * 1000) / 1024) as u32;
let stake = get_stake(&item.id, stakes);
let weight = get_weight(max_weight, since, stake);
(weight, item)
// Weights are bounded by max_weight defined above.
// So this type-cast should be safe.
((weight * 100.0) as u64, item)
})
.collect()
}

View File

@ -16,7 +16,7 @@ use {
crds_gossip::{get_stake, get_weight},
crds_gossip_error::CrdsGossipError,
crds_value::CrdsValue,
weighted_shuffle::weighted_shuffle,
weighted_shuffle::WeightedShuffle,
},
bincode::serialized_size,
indexmap::map::IndexMap,
@ -119,6 +119,7 @@ impl CrdsGossipPush {
if peer_stake_total < prune_stake_threshold {
return Vec::new();
}
let mut rng = rand::thread_rng();
let shuffled_staked_peers = {
let peers: Vec<_> = peers
.iter()
@ -126,11 +127,9 @@ impl CrdsGossipPush {
.filter_map(|(peer, _)| Some((*peer, *stakes.get(peer)?)))
.filter(|(_, stake)| *stake > 0)
.collect();
let mut seed = [0; 32];
rand::thread_rng().fill(&mut seed[..]);
let weights: Vec<_> = peers.iter().map(|(_, stake)| *stake).collect();
weighted_shuffle(&weights, seed)
.into_iter()
WeightedShuffle::new(&mut rng, &weights)
.unwrap()
.map(move |i| peers[i])
};
let mut keep = HashSet::new();
@ -282,11 +281,7 @@ impl CrdsGossipPush {
return;
}
let num_bloom_items = MIN_NUM_BLOOM_ITEMS.max(network_size);
let shuffle = {
let mut seed = [0; 32];
rng.fill(&mut seed[..]);
weighted_shuffle(&weights, seed).into_iter()
};
let shuffle = WeightedShuffle::new(&mut rng, &weights).unwrap();
for peer in shuffle.map(|i| peers[i].id) {
if new_items.len() >= need {
break;
@ -320,7 +315,7 @@ impl CrdsGossipPush {
self_shred_version: u16,
stakes: &HashMap<Pubkey, u64>,
gossip_validators: Option<&HashSet<Pubkey>>,
) -> Vec<(f32, &'a ContactInfo)> {
) -> Vec<(u64, &'a ContactInfo)> {
let now = timestamp();
let mut rng = rand::thread_rng();
let max_weight = u16::MAX as f32 - 1.0;
@ -356,7 +351,9 @@ impl CrdsGossipPush {
let since = (now.saturating_sub(last_pushed_to).min(3600 * 1000) / 1024) as u32;
let stake = get_stake(&info.id, stakes);
let weight = get_weight(max_weight, since, stake);
(weight, info)
// Weights are bounded by max_weight defined above.
// So this type-cast should be safe.
((weight * 100.0) as u64, info)
})
.collect()
}

View File

@ -2,12 +2,138 @@
use {
itertools::Itertools,
num_traits::{FromPrimitive, ToPrimitive},
rand::{Rng, SeedableRng},
num_traits::{CheckedAdd, FromPrimitive, ToPrimitive},
rand::{
distributions::uniform::{SampleUniform, UniformSampler},
Rng, SeedableRng,
},
rand_chacha::ChaChaRng,
std::{iter, ops::Div},
std::{
iter,
ops::{AddAssign, Div, Sub, SubAssign},
},
};
#[derive(Debug)]
pub enum WeightedShuffleError<T> {
NegativeWeight(T),
SumOverflow,
}
/// Implements an iterator where indices are shuffled according to their
/// weights:
/// - Returned indices are unique in the range [0, weights.len()).
/// - Higher weighted indices tend to appear earlier proportional to their
/// weight.
/// - Zero weighted indices are excluded. Therefore the iterator may have
/// count less than weights.len().
pub struct WeightedShuffle<'a, R, T> {
arr: Vec<T>, // Underlying array implementing binary indexed tree.
sum: T, // Current sum of weights, excluding already selected indices.
rng: &'a mut R, // Random number generator.
}
// The implementation uses binary indexed tree:
// https://en.wikipedia.org/wiki/Fenwick_tree
// to maintain cumulative sum of weights excluding already selected indices
// over self.arr.
impl<'a, R: Rng, T> WeightedShuffle<'a, R, T>
where
T: Copy + Default + PartialOrd + AddAssign + CheckedAdd,
{
/// Returns error if:
/// - any of the weights are negative.
/// - sum of weights overflows.
pub fn new(rng: &'a mut R, weights: &[T]) -> Result<Self, WeightedShuffleError<T>> {
let size = weights.len() + 1;
let zero = <T as Default>::default();
let mut arr = vec![zero; size];
let mut sum = zero;
for (mut k, &weight) in (1usize..).zip(weights) {
#[allow(clippy::neg_cmp_op_on_partial_ord)]
// weight < zero does not work for NaNs.
if !(weight >= zero) {
return Err(WeightedShuffleError::NegativeWeight(weight));
}
sum = sum
.checked_add(&weight)
.ok_or(WeightedShuffleError::SumOverflow)?;
while k < size {
arr[k] += weight;
k += k & k.wrapping_neg();
}
}
Ok(Self { arr, sum, rng })
}
}
impl<'a, R, T> WeightedShuffle<'a, R, T>
where
T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub<Output = T>,
{
// Returns cumulative sum of current weights upto index k (inclusive).
fn cumsum(&self, mut k: usize) -> T {
let mut out = <T as Default>::default();
while k != 0 {
out += self.arr[k];
k ^= k & k.wrapping_neg();
}
out
}
// Removes given weight at index k.
fn remove(&mut self, mut k: usize, weight: T) {
self.sum -= weight;
let size = self.arr.len();
while k < size {
self.arr[k] -= weight;
k += k & k.wrapping_neg();
}
}
// Returns smallest index such that self.cumsum(k) > val,
// along with its respective weight.
fn search(&self, val: T) -> (/*index:*/ usize, /*weight:*/ T) {
let zero = <T as Default>::default();
debug_assert!(val >= zero);
debug_assert!(val < self.sum);
let mut lo = (/*index:*/ 0, /*cumsum:*/ zero);
let mut hi = (self.arr.len() - 1, self.sum);
while lo.0 + 1 < hi.0 {
let k = lo.0 + (hi.0 - lo.0) / 2;
let sum = self.cumsum(k);
if sum <= val {
lo = (k, sum);
} else {
hi = (k, sum);
}
}
debug_assert!(lo.1 <= val);
debug_assert!(hi.1 > val);
(hi.0, hi.1 - lo.1)
}
}
impl<'a, R: Rng, T> Iterator for WeightedShuffle<'a, R, T>
where
T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub<Output = T>,
{
type Item = usize;
fn next(&mut self) -> Option<Self::Item> {
let zero = <T as Default>::default();
#[allow(clippy::neg_cmp_op_on_partial_ord)]
// self.sum <= zero does not work for NaNs.
if !(self.sum > zero) {
return None;
}
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.sum, &mut self.rng);
let (index, weight) = WeightedShuffle::search(self, sample);
self.remove(index, weight);
Some(index - 1)
}
}
/// Returns a list of indexes shuffled based on the input weights
/// Note - The sum of all weights must not exceed `u64::MAX`
pub fn weighted_shuffle<T>(weights: &[T], seed: [u8; 32]) -> Vec<usize>
@ -67,6 +193,31 @@ pub fn weighted_best(weights_and_indexes: &[(u64, usize)], seed: [u8; 32]) -> us
#[cfg(test)]
mod tests {
use super::*;
use std::{convert::TryInto, iter::repeat_with};
fn weighted_shuffle_slow<R>(rng: &mut R, mut weights: Vec<u64>) -> Vec<usize>
where
R: Rng,
{
let mut shuffle = Vec::with_capacity(weights.len());
loop {
let high: u64 = weights.iter().sum();
if high == 0 {
break shuffle;
}
let sample = rng.gen_range(0, high);
let index = weights
.iter()
.scan(0, |acc, &w| {
*acc += w;
Some(*acc)
})
.position(|acc| sample < acc)
.unwrap();
shuffle.push(index);
weights[index] = 0;
}
}
#[test]
fn test_weighted_shuffle_iterator() {
@ -133,4 +284,56 @@ mod tests {
let best_index = weighted_best(&weights_and_indexes, [0x5b; 32]);
assert_eq!(best_index, 2);
}
// Asserts that each index is selected proportional to its weight.
#[test]
fn test_weighted_shuffle_sanity() {
let seed: Vec<_> = (1..).step_by(3).take(32).collect();
let seed: [u8; 32] = seed.try_into().unwrap();
let mut rng = ChaChaRng::from_seed(seed);
let weights = [1, 1000, 10, 100];
let mut counts = [0; 4];
for _ in 0..100000 {
let mut shuffle = WeightedShuffle::new(&mut rng, &weights).unwrap();
counts[shuffle.next().unwrap()] += 1;
let _ = shuffle.count(); // consume the rest.
}
assert_eq!(counts, [101, 90113, 891, 8895]);
}
#[test]
fn test_weighted_shuffle_hard_coded() {
let weights = [
78, 70, 38, 27, 21, 0, 82, 42, 21, 77, 77, 17, 4, 50, 96, 83, 33, 16, 72,
];
let seed = [48u8; 32];
let mut rng = ChaChaRng::from_seed(seed);
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
assert_eq!(
shuffle,
[2, 11, 16, 0, 13, 14, 15, 10, 1, 9, 7, 6, 12, 18, 4, 17, 3, 8]
);
let seed = [37u8; 32];
let mut rng = ChaChaRng::from_seed(seed);
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
assert_eq!(
shuffle,
[17, 3, 14, 13, 6, 10, 15, 16, 9, 2, 4, 1, 0, 7, 8, 18, 11, 12]
);
}
#[test]
fn test_weighted_shuffle_match_slow() {
let mut rng = rand::thread_rng();
let weights: Vec<u64> = repeat_with(|| rng.gen_range(0, 1000)).take(997).collect();
for _ in 0..10 {
let mut seed = [0u8; 32];
rng.fill(&mut seed[..]);
let mut rng = ChaChaRng::from_seed(seed);
let shuffle: Vec<_> = WeightedShuffle::new(&mut rng, &weights).unwrap().collect();
let mut rng = ChaChaRng::from_seed(seed);
let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone());
assert_eq!(shuffle, shuffle_slow,);
}
}
}