shards crds values based on their hash prefix (#12187)

filter_crds_values checks every crds filter against every hash value:
https://github.com/solana-labs/solana/blob/ee646aa7/core/src/crds_gossip_pull.rs#L432
which can be inefficient if the filter's bit-mask only matches small
portion of the entire crds table.

This commit shards crds values into separate tables based on shard_bits
first bits of their hash prefix. Given a (mask, mask_bits) filter,
filtering crds can be done by inspecting only relevant shards.

If CrdsFilter.mask_bits <= shard_bits, then precisely only the crds
values which match (mask, mask_bits) bit pattern are traversed.
If CrdsFilter.mask_bits > shard_bits, then approximately only
1/2^shard_bits of crds values are inspected.

Benchmarking on a gce cluster of 20 nodes, I see ~10% improvement in
generate_pull_responses metric, but with larger clusters, crds table and
2^mask_bits are both larger, so the impact should be more significant.
This commit is contained in:
behzad nouri 2020-09-17 14:05:16 +00:00 committed by GitHub
parent 5546b6b676
commit 9b866d79fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 414 additions and 60 deletions

View File

@ -97,6 +97,9 @@ name = "blockstore"
[[bench]]
name = "crds_gossip_pull"
[[bench]]
name = "crds_shards"
[[bench]]
name = "gen_keys"

View File

@ -0,0 +1,69 @@
#![feature(test)]
extern crate test;
use rand::{thread_rng, Rng};
use solana_core::contact_info::ContactInfo;
use solana_core::crds::VersionedCrdsValue;
use solana_core::crds_shards::CrdsShards;
use solana_core::crds_value::{CrdsData, CrdsValue};
use solana_sdk::pubkey::Pubkey;
use solana_sdk::timing::timestamp;
use test::Bencher;
const CRDS_SHARDS_BITS: u32 = 8;
fn new_test_crds_value() -> VersionedCrdsValue {
let data = CrdsData::ContactInfo(ContactInfo::new_localhost(&Pubkey::new_rand(), timestamp()));
VersionedCrdsValue::new(timestamp(), CrdsValue::new_unsigned(data))
}
fn bench_crds_shards_find(bencher: &mut Bencher, num_values: usize, mask_bits: u32) {
let values: Vec<VersionedCrdsValue> = std::iter::repeat_with(new_test_crds_value)
.take(num_values)
.collect();
let mut shards = CrdsShards::new(CRDS_SHARDS_BITS);
for (index, value) in values.iter().enumerate() {
assert!(shards.insert(index, value));
}
let mut rng = thread_rng();
bencher.iter(|| {
let mask = rng.gen();
let _hits = shards.find(mask, mask_bits).count();
});
}
#[bench]
fn bench_crds_shards_find_0(bencher: &mut Bencher) {
bench_crds_shards_find(bencher, 100_000, 0);
}
#[bench]
fn bench_crds_shards_find_1(bencher: &mut Bencher) {
bench_crds_shards_find(bencher, 100_000, 1);
}
#[bench]
fn bench_crds_shards_find_3(bencher: &mut Bencher) {
bench_crds_shards_find(bencher, 100_000, 3);
}
#[bench]
fn bench_crds_shards_find_5(bencher: &mut Bencher) {
bench_crds_shards_find(bencher, 100_000, 5);
}
#[bench]
fn bench_crds_shards_find_7(bencher: &mut Bencher) {
bench_crds_shards_find(bencher, 100_000, 7);
}
#[bench]
fn bench_crds_shards_find_8(bencher: &mut Bencher) {
bench_crds_shards_find(bencher, 100_000, 8);
}
#[bench]
fn bench_crds_shards_find_9(bencher: &mut Bencher) {
bench_crds_shards_find(bencher, 100_000, 9);
}

View File

@ -24,22 +24,24 @@
//! A value is updated to a new version if the labels match, and the value
//! wallclock is later, or the value hash is greater.
use crate::crds_gossip_pull::CrdsFilter;
use crate::crds_shards::CrdsShards;
use crate::crds_value::{CrdsValue, CrdsValueLabel};
use bincode::serialize;
use indexmap::map::IndexMap;
use indexmap::map::{Entry, IndexMap};
use solana_sdk::hash::{hash, Hash};
use solana_sdk::pubkey::Pubkey;
use std::cmp;
use std::collections::HashMap;
use std::ops::Index;
const CRDS_SHARDS_BITS: u32 = 8;
#[derive(Clone)]
pub struct Crds {
/// Stores the map of labels and values
pub table: IndexMap<CrdsValueLabel, VersionedCrdsValue>,
pub num_inserts: usize,
pub masks: IndexMap<CrdsValueLabel, u64>,
pub shards: CrdsShards,
}
#[derive(PartialEq, Debug)]
@ -89,7 +91,7 @@ impl Default for Crds {
Crds {
table: IndexMap::new(),
num_inserts: 0,
masks: IndexMap::new(),
shards: CrdsShards::new(CRDS_SHARDS_BITS),
}
}
}
@ -123,23 +125,28 @@ impl Crds {
new_value: VersionedCrdsValue,
) -> Result<Option<VersionedCrdsValue>, CrdsError> {
let label = new_value.value.label();
let wallclock = new_value.value.wallclock();
let do_insert = self
.table
.get(&label)
.map(|current| new_value > *current)
.unwrap_or(true);
if do_insert {
self.masks.insert(
label.clone(),
CrdsFilter::hash_as_u64(&new_value.value_hash),
);
let old = self.table.insert(label, new_value);
self.num_inserts += 1;
Ok(old)
} else {
trace!("INSERT FAILED data: {} new.wallclock: {}", label, wallclock,);
Err(CrdsError::InsertFailed)
match self.table.entry(label) {
Entry::Vacant(entry) => {
assert!(self.shards.insert(entry.index(), &new_value));
entry.insert(new_value);
self.num_inserts += 1;
Ok(None)
}
Entry::Occupied(mut entry) if *entry.get() < new_value => {
let index = entry.index();
assert!(self.shards.remove(index, entry.get()));
assert!(self.shards.insert(index, &new_value));
self.num_inserts += 1;
Ok(Some(entry.insert(new_value)))
}
_ => {
trace!(
"INSERT FAILED data: {} new.wallclock: {}",
new_value.value.label(),
new_value.value.wallclock(),
);
Err(CrdsError::InsertFailed)
}
}
}
pub fn insert(
@ -200,8 +207,16 @@ impl Crds {
}
pub fn remove(&mut self, key: &CrdsValueLabel) {
self.table.swap_remove(key);
self.masks.swap_remove(key);
if let Some((index, _, value)) = self.table.swap_remove_full(key) {
assert!(self.shards.remove(index, &value));
// The previously last element in the table is now moved to the
// 'index' position. Shards need to be updated accordingly.
if index < self.table.len() {
let value = self.table.index(index);
assert!(self.shards.remove(self.table.len(), value));
assert!(self.shards.insert(index, value));
}
}
}
}
@ -210,6 +225,7 @@ mod test {
use super::*;
use crate::contact_info::ContactInfo;
use crate::crds_value::CrdsData;
use rand::{thread_rng, Rng};
#[test]
fn test_insert() {
@ -329,6 +345,45 @@ mod test {
assert_eq!(crds.find_old_labels(3, &set), vec![val.label()]);
}
#[test]
fn test_crds_shards() {
fn check_crds_shards(crds: &Crds) {
crds.shards
.check(&crds.table.values().cloned().collect::<Vec<_>>());
}
let mut crds = Crds::default();
let pubkeys: Vec<_> = std::iter::repeat_with(Pubkey::new_rand).take(256).collect();
let mut rng = thread_rng();
let mut num_inserts = 0;
for _ in 0..4096 {
let pubkey = pubkeys[rng.gen_range(0, pubkeys.len())];
let value = VersionedCrdsValue::new(
rng.gen(), // local_timestamp
CrdsValue::new_unsigned(CrdsData::ContactInfo(ContactInfo::new_localhost(
&pubkey,
rng.gen(), // now
))),
);
if crds.insert_versioned(value).is_ok() {
check_crds_shards(&crds);
num_inserts += 1;
}
}
assert_eq!(num_inserts, crds.num_inserts);
assert!(num_inserts > 700);
assert!(crds.table.len() > 200);
assert!(num_inserts > crds.table.len());
check_crds_shards(&crds);
// Remove values one by one and assert that shards stay valid.
while !crds.table.is_empty() {
let index = rng.gen_range(0, crds.table.len());
let key = crds.table.get_index(index).unwrap().0.clone();
crds.remove(&key);
check_crds_shards(&crds);
}
}
#[test]
fn test_remove_staked() {
let mut crds = Crds::default();

View File

@ -23,6 +23,7 @@ use std::cmp;
use std::collections::VecDeque;
use std::collections::{HashMap, HashSet};
use std::convert::TryInto;
use std::ops::Index;
pub const CRDS_GOSSIP_PULL_CRDS_TIMEOUT_MS: u64 = 15000;
// The maximum age of a value received over pull responses
@ -418,52 +419,44 @@ impl CrdsGossipPull {
filters: &[(CrdsValue, CrdsFilter)],
now: u64,
) -> Vec<Vec<CrdsValue>> {
let mut ret = vec![vec![]; filters.len()];
let msg_timeout = CRDS_GOSSIP_PULL_CRDS_TIMEOUT_MS;
let jitter = rand::thread_rng().gen_range(0, msg_timeout / 4);
let start = filters.len();
//skip filters from callers that are too old
let future = now.saturating_add(msg_timeout);
let past = now.saturating_sub(msg_timeout);
let recent: Vec<_> = filters
let mut dropped_requests = 0;
let mut total_skipped = 0;
let ret = filters
.iter()
.enumerate()
.filter(|(_, (caller, _))| caller.wallclock() < future && caller.wallclock() >= past)
.map(|(caller, filter)| {
let caller_wallclock = caller.wallclock();
if caller_wallclock >= future || caller_wallclock < past {
dropped_requests += 1;
return vec![];
}
let caller_wallclock = caller_wallclock.checked_add(jitter).unwrap_or(0);
crds.shards
.find(filter.mask, filter.mask_bits)
.filter_map(|index| {
let item = crds.table.index(index);
debug_assert!(filter.test_mask(&item.value_hash));
//skip values that are too new
if item.value.wallclock() > caller_wallclock {
total_skipped += 1;
None
} else if filter.filter_contains(&item.value_hash) {
None
} else {
Some(item.value.clone())
}
})
.collect()
})
.collect();
inc_new_counter_info!(
"gossip_filter_crds_values-dropped_requests",
start - recent.len()
dropped_requests
);
if recent.is_empty() {
return ret;
}
let mut total_skipped = 0;
let mask_ones: Vec<_> = recent
.iter()
.map(|(_i, (_caller, filter))| (!0u64).checked_shr(filter.mask_bits).unwrap_or(!0u64))
.collect();
for (label, mask) in crds.masks.iter() {
recent
.iter()
.zip(mask_ones.iter())
.for_each(|((i, (caller, filter)), mask_ones)| {
if filter.test_mask_u64(*mask, *mask_ones) {
let item = crds.table.get(label).unwrap();
//skip values that are too new
if item.value.wallclock()
> caller.wallclock().checked_add(jitter).unwrap_or_else(|| 0)
{
total_skipped += 1;
return;
}
if !filter.filter_contains(&item.value_hash) {
ret[*i].push(item.value.clone());
}
}
});
}
inc_new_counter_info!("gossip_filter_crds_values-dropped_values", total_skipped);
ret
}

233
core/src/crds_shards.rs Normal file
View File

@ -0,0 +1,233 @@
use crate::crds::VersionedCrdsValue;
use crate::crds_gossip_pull::CrdsFilter;
use indexmap::map::IndexMap;
use std::cmp::Ordering;
use std::ops::{Index, IndexMut};
#[derive(Clone)]
pub struct CrdsShards {
// shards[k] includes crds values which the first shard_bits of their hash
// value is equal to k. Each shard is a mapping from crds values indices to
// their hash value.
shards: Vec<IndexMap<usize, u64>>,
shard_bits: u32,
}
impl CrdsShards {
pub fn new(shard_bits: u32) -> Self {
CrdsShards {
shards: vec![IndexMap::new(); 1 << shard_bits],
shard_bits,
}
}
#[must_use]
pub fn insert(&mut self, index: usize, value: &VersionedCrdsValue) -> bool {
let hash = CrdsFilter::hash_as_u64(&value.value_hash);
self.shard_mut(hash).insert(index, hash).is_none()
}
#[must_use]
pub fn remove(&mut self, index: usize, value: &VersionedCrdsValue) -> bool {
let hash = CrdsFilter::hash_as_u64(&value.value_hash);
self.shard_mut(hash).swap_remove(&index).is_some()
}
/// Returns indices of all crds values which the first 'mask_bits' of their
/// hash value is equal to 'mask'.
pub fn find(&self, mask: u64, mask_bits: u32) -> impl Iterator<Item = usize> + '_ {
let ones = (!0u64).checked_shr(mask_bits).unwrap_or(0);
let mask = mask | ones;
match self.shard_bits.cmp(&mask_bits) {
Ordering::Less => {
let pred = move |(&index, hash)| {
if hash | ones == mask {
Some(index)
} else {
None
}
};
Iter::Less(self.shard(mask).iter().filter_map(pred))
}
Ordering::Equal => Iter::Equal(self.shard(mask).keys().cloned()),
Ordering::Greater => {
let count = 1 << (self.shard_bits - mask_bits);
let end = self.shard_index(mask) + 1;
Iter::Greater(
self.shards[end - count..end]
.iter()
.flat_map(IndexMap::keys)
.cloned(),
)
}
}
}
#[inline]
fn shard_index(&self, hash: u64) -> usize {
hash.checked_shr(64 - self.shard_bits).unwrap_or(0) as usize
}
#[inline]
fn shard(&self, hash: u64) -> &IndexMap<usize, u64> {
let shard_index = self.shard_index(hash);
self.shards.index(shard_index)
}
#[inline]
fn shard_mut(&mut self, hash: u64) -> &mut IndexMap<usize, u64> {
let shard_index = self.shard_index(hash);
self.shards.index_mut(shard_index)
}
// Checks invariants in the shards tables against the crds table.
#[cfg(test)]
pub fn check(&self, crds: &[VersionedCrdsValue]) {
let mut indices: Vec<_> = self
.shards
.iter()
.flat_map(IndexMap::keys)
.cloned()
.collect();
indices.sort_unstable();
assert_eq!(indices, (0..crds.len()).collect::<Vec<_>>());
for (shard_index, shard) in self.shards.iter().enumerate() {
for (&index, &hash) in shard {
assert_eq!(hash, CrdsFilter::hash_as_u64(&crds[index].value_hash));
assert_eq!(
shard_index as u64,
hash.checked_shr(64 - self.shard_bits).unwrap_or(0)
);
}
}
}
}
// Wrapper for 3 types of iterators we get when comparing shard_bits and
// mask_bits in find method. This is to avoid Box<dyn Iterator<Item =...>>
// which involves dynamic dispatch and is relatively slow.
enum Iter<R, S, T> {
Less(R),
Equal(S),
Greater(T),
}
impl<R, S, T> Iterator for Iter<R, S, T>
where
R: Iterator<Item = usize>,
S: Iterator<Item = usize>,
T: Iterator<Item = usize>,
{
type Item = usize;
fn next(&mut self) -> Option<Self::Item> {
match self {
Self::Greater(iter) => iter.next(),
Self::Less(iter) => iter.next(),
Self::Equal(iter) => iter.next(),
}
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::contact_info::ContactInfo;
use crate::crds_value::{CrdsData, CrdsValue};
use rand::{thread_rng, Rng};
use solana_sdk::pubkey::Pubkey;
use solana_sdk::timing::timestamp;
use std::collections::HashSet;
use std::ops::Index;
fn new_test_crds_value() -> VersionedCrdsValue {
let data =
CrdsData::ContactInfo(ContactInfo::new_localhost(&Pubkey::new_rand(), timestamp()));
VersionedCrdsValue::new(timestamp(), CrdsValue::new_unsigned(data))
}
// Returns true if the first mask_bits most significant bits of hash is the
// same as the given bit mask.
fn check_mask(value: &VersionedCrdsValue, mask: u64, mask_bits: u32) -> bool {
let hash = CrdsFilter::hash_as_u64(&value.value_hash);
let ones = (!0u64).checked_shr(mask_bits).unwrap_or(0u64);
(hash | ones) == (mask | ones)
}
// Manual filtering by scanning all the values.
fn filter_crds_values(
values: &[VersionedCrdsValue],
mask: u64,
mask_bits: u32,
) -> HashSet<usize> {
values
.iter()
.enumerate()
.filter_map(|(index, value)| {
if check_mask(value, mask, mask_bits) {
Some(index)
} else {
None
}
})
.collect()
}
#[test]
fn test_crds_shards_round_trip() {
let mut rng = thread_rng();
// Generate some random hash and crds value labels.
let mut values: Vec<_> = std::iter::repeat_with(new_test_crds_value)
.take(4096)
.collect();
// Insert everything into the crds shards.
let mut shards = CrdsShards::new(5);
for (index, value) in values.iter().enumerate() {
assert!(shards.insert(index, value));
}
shards.check(&values);
// Remove some of the values.
for _ in 0..512 {
let index = rng.gen_range(0, values.len());
let value = values.swap_remove(index);
assert!(shards.remove(index, &value));
if index < values.len() {
let value = values.index(index);
assert!(shards.remove(values.len(), value));
assert!(shards.insert(index, value));
}
shards.check(&values);
}
// Random masks.
for _ in 0..10 {
let mask = rng.gen();
for mask_bits in 0..12 {
let mut set = filter_crds_values(&values, mask, mask_bits);
for index in shards.find(mask, mask_bits) {
assert!(set.remove(&index));
}
assert!(set.is_empty());
}
}
// Existing hash values.
for (index, value) in values.iter().enumerate() {
let mask = CrdsFilter::hash_as_u64(&value.value_hash);
let hits: Vec<_> = shards.find(mask, 64).collect();
assert_eq!(hits, vec![index]);
}
// Remove everything.
while !values.is_empty() {
let index = rng.gen_range(0, values.len());
let value = values.swap_remove(index);
assert!(shards.remove(index, &value));
if index < values.len() {
let value = values.index(index);
assert!(shards.remove(values.len(), value));
assert!(shards.insert(index, value));
}
if index % 5 == 0 {
shards.check(&values);
}
}
}
}

View File

@ -29,6 +29,7 @@ pub mod crds_gossip;
pub mod crds_gossip_error;
pub mod crds_gossip_pull;
pub mod crds_gossip_push;
pub mod crds_shards;
pub mod crds_value;
pub mod epoch_slots;
pub mod fetch_stage;