compute merkle root on chunks of fanout^3 (#15344)

* compute merkle root on chunks of fanout^3

* improve test_accountsdb_compute_merkle_root_large
This commit is contained in:
Jeff Washington (jwash) 2021-02-16 17:03:35 -06:00 committed by GitHub
parent ba02452d75
commit 8367740ff9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 114 additions and 64 deletions

View File

@ -3624,6 +3624,7 @@ impl AccountsDB {
fn compute_merkle_root_from_slices<'a, F>(
total_hashes: usize,
fanout: usize,
max_levels_per_pass: Option<usize>,
get_hashes: F,
) -> Hash
where
@ -3635,7 +3636,17 @@ impl AccountsDB {
let mut time = Measure::start("time");
let chunks = Self::div_ceil(total_hashes, fanout);
const THREE_LEVEL_OPTIMIZATION: usize = 3; // this '3' is dependent on the code structure below where we manually unroll
let target = fanout.pow(THREE_LEVEL_OPTIMIZATION as u32);
// Only use the 3 level optimization if we have at least 4 levels of data.
// Otherwise, we'll be serializing a parallel operation.
let threshold = target * fanout;
let three_level = max_levels_per_pass.unwrap_or(usize::MAX) >= THREE_LEVEL_OPTIMIZATION
&& total_hashes >= threshold;
let num_hashes_per_chunk = if three_level { target } else { fanout };
let chunks = Self::div_ceil(total_hashes, num_hashes_per_chunk);
// initial fetch - could return entire slice
let data: &[Hash] = get_hashes(0);
@ -3644,24 +3655,54 @@ impl AccountsDB {
let result: Vec<_> = (0..chunks)
.into_par_iter()
.map(|i| {
let start_index = i * fanout;
let end_index = std::cmp::min(start_index + fanout, total_hashes);
let start_index = i * num_hashes_per_chunk;
let end_index = std::cmp::min(start_index + num_hashes_per_chunk, total_hashes);
let mut hasher = Hasher::default();
let mut data_index = start_index;
let mut data = data;
let mut data_len = data_len;
for i in start_index..end_index {
if data_index >= data_len {
// fetch next slice
data = get_hashes(i);
data_len = data.len();
data_index = 0;
if !three_level {
// 1 group of fanout
// The result of this loop is a single hash value from fanout input hashes.
for i in start_index..end_index {
if data_index >= data_len {
// fetch next slice
data = get_hashes(i);
data_len = data.len();
data_index = 0;
}
hasher.hash(data[data_index].as_ref());
data_index += 1;
}
} else {
// hash 3 levels of fanout simultaneously.
// The result of this loop is a single hash value from fanout^3 input hashes.
let mut i = start_index;
while i < end_index {
let mut hasher_j = Hasher::default();
for _j in 0..fanout {
let mut hasher_k = Hasher::default();
let end = std::cmp::min(end_index - i, fanout);
for _k in 0..end {
if data_index >= data_len {
// fetch next slice
data = get_hashes(i);
data_len = data.len();
data_index = 0;
}
hasher_k.hash(data[data_index].as_ref());
data_index += 1;
i += 1;
}
hasher_j.hash(hasher_k.result().as_ref());
if i >= end_index {
break;
}
}
hasher.hash(hasher_j.result().as_ref());
}
hasher.hash(data[data_index].as_ref());
data_index += 1;
}
hasher.result()
@ -3823,10 +3864,12 @@ impl AccountsDB {
let hash_total = cumulative_offsets.total_count;
let total_lamports = *total_lamports.lock().unwrap();
let mut hash_time = Measure::start("hash");
let accumulated_hash =
Self::compute_merkle_root_from_slices(hash_total, MERKLE_FANOUT, |start: usize| {
cumulative_offsets.get_slice(&hashes, start)
});
let accumulated_hash = Self::compute_merkle_root_from_slices(
hash_total,
MERKLE_FANOUT,
None,
|start: usize| cumulative_offsets.get_slice(&hashes, start),
);
hash_time.stop();
datapoint_info!(
"update_accounts_hash",
@ -4118,7 +4161,8 @@ impl AccountsDB {
let offsets = CumulativeOffsets::from_raw_2d(&hashes);
let get_slice = |start: usize| -> &[Hash] { offsets.get_slice_2d(&hashes, start) };
let hash = Self::compute_merkle_root_from_slices(offsets.total_count, fanout, get_slice);
let hash =
Self::compute_merkle_root_from_slices(offsets.total_count, fanout, None, get_slice);
hash_time.stop();
stats.hash_time_total_us += hash_time.as_us();
stats.hash_total = offsets.total_count;
@ -6459,55 +6503,44 @@ pub mod tests {
fn test_hashing_larger(hashes: Vec<(Pubkey, Hash)>, fanout: usize) -> Hash {
let result = AccountsDB::compute_merkle_root(hashes.clone(), fanout);
if hashes.len() >= fanout * fanout * fanout {
let reduced: Vec<_> = hashes.iter().map(|x| x.1).collect();
let result2 =
AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, |start| {
&reduced[start..]
});
assert_eq!(result, result2);
let reduced2: Vec<_> = hashes.iter().map(|x| vec![x.1]).collect();
let result2 = AccountsDB::flatten_hashes_and_hash(
vec![reduced2],
fanout,
&mut HashStats::default(),
);
assert_eq!(result, result2);
for left in 0..reduced.len() {
for right in left + 1..reduced.len() {
let src = vec![
vec![reduced[0..left].to_vec(), reduced[left..right].to_vec()],
vec![reduced[right..].to_vec()],
];
let result2 =
AccountsDB::flatten_hashes_and_hash(src, fanout, &mut HashStats::default());
assert_eq!(result, result2);
}
}
}
let reduced: Vec<_> = hashes.iter().map(|x| x.1).collect();
let result2 = test_hashing(reduced, fanout);
assert_eq!(result, result2, "len: {}", hashes.len());
result
}
fn test_hashing(hashes: Vec<Hash>, fanout: usize) -> Hash {
let temp: Vec<_> = hashes.iter().map(|h| (Pubkey::default(), *h)).collect();
let result = AccountsDB::compute_merkle_root(temp, fanout);
if hashes.len() >= fanout * fanout * fanout {
let reduced: Vec<_> = hashes.clone();
let result2 =
AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, |start| {
&reduced[start..]
});
assert_eq!(result, result2, "len: {}", hashes.len());
let reduced: Vec<_> = hashes.clone();
let result2 =
AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, None, |start| {
&reduced[start..]
});
assert_eq!(result, result2, "len: {}", hashes.len());
let reduced2: Vec<_> = hashes.iter().map(|x| vec![*x]).collect();
let result2 = AccountsDB::flatten_hashes_and_hash(
vec![reduced2],
fanout,
&mut HashStats::default(),
);
assert_eq!(result, result2, "len: {}", hashes.len());
let result2 =
AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, Some(1), |start| {
&reduced[start..]
});
assert_eq!(result, result2, "len: {}", hashes.len());
let reduced2: Vec<_> = hashes.iter().map(|x| vec![*x]).collect();
let result2 =
AccountsDB::flatten_hashes_and_hash(vec![reduced2], fanout, &mut HashStats::default());
assert_eq!(result, result2, "len: {}", hashes.len());
let max = std::cmp::min(reduced.len(), fanout * 2);
for left in 0..max {
for right in left + 1..max {
let src = vec![
vec![reduced[0..left].to_vec(), reduced[left..right].to_vec()],
vec![reduced[right..].to_vec()],
];
let result2 =
AccountsDB::flatten_hashes_and_hash(src, fanout, &mut HashStats::default());
assert_eq!(result, result2);
}
}
result
}
@ -6516,12 +6549,29 @@ pub mod tests {
fn test_accountsdb_compute_merkle_root_large() {
solana_logger::setup();
let mut num = 100;
for _pass in 0..2 {
num *= 10;
let hashes: Vec<_> = (0..num).into_iter().map(|_| Hash::new_unique()).collect();
// handle fanout^x -1, +0, +1 for a few 'x's
const FANOUT: usize = 3;
let mut hash_counts: Vec<_> = (1..6)
.map(|x| {
let mark = FANOUT.pow(x);
vec![mark - 1, mark, mark + 1]
})
.flatten()
.collect();
test_hashing(hashes, MERKLE_FANOUT);
// exhaustively test every hash count from threshold - 1 through threshold + target
// this starts at the last count that does not use the 3-level optimization and continues through every possible partial last chunk
let target = FANOUT.pow(3);
let threshold = target * FANOUT;
hash_counts.extend(threshold - 1..=threshold + target);
for hash_count in hash_counts {
let hashes: Vec<_> = (0..hash_count)
.into_iter()
.map(|_| Hash::new_unique())
.collect();
test_hashing(hashes, FANOUT);
}
}