compute merkle root on chunks of fanout^3 (#15344)
* compute merkle root on chunks of fanout^3 * improve test_accountsdb_compute_merkle_root_large
This commit is contained in:
parent
ba02452d75
commit
8367740ff9
|
@ -3624,6 +3624,7 @@ impl AccountsDB {
|
|||
fn compute_merkle_root_from_slices<'a, F>(
|
||||
total_hashes: usize,
|
||||
fanout: usize,
|
||||
max_levels_per_pass: Option<usize>,
|
||||
get_hashes: F,
|
||||
) -> Hash
|
||||
where
|
||||
|
@ -3635,7 +3636,17 @@ impl AccountsDB {
|
|||
|
||||
let mut time = Measure::start("time");
|
||||
|
||||
let chunks = Self::div_ceil(total_hashes, fanout);
|
||||
const THREE_LEVEL_OPTIMIZATION: usize = 3; // this '3' is dependent on the code structure below where we manually unroll
|
||||
let target = fanout.pow(THREE_LEVEL_OPTIMIZATION as u32);
|
||||
|
||||
// Only use the 3-level optimization when there are at least 4 levels of data (total_hashes >= fanout^4).
|
||||
// Otherwise, the larger chunks would leave too few parallel tasks and we'd be serializing a parallel operation.
|
||||
let threshold = target * fanout;
|
||||
let three_level = max_levels_per_pass.unwrap_or(usize::MAX) >= THREE_LEVEL_OPTIMIZATION
|
||||
&& total_hashes >= threshold;
|
||||
let num_hashes_per_chunk = if three_level { target } else { fanout };
|
||||
|
||||
let chunks = Self::div_ceil(total_hashes, num_hashes_per_chunk);
|
||||
|
||||
// initial fetch - could return entire slice
|
||||
let data: &[Hash] = get_hashes(0);
|
||||
|
@ -3644,24 +3655,54 @@ impl AccountsDB {
|
|||
let result: Vec<_> = (0..chunks)
|
||||
.into_par_iter()
|
||||
.map(|i| {
|
||||
let start_index = i * fanout;
|
||||
let end_index = std::cmp::min(start_index + fanout, total_hashes);
|
||||
let start_index = i * num_hashes_per_chunk;
|
||||
let end_index = std::cmp::min(start_index + num_hashes_per_chunk, total_hashes);
|
||||
|
||||
let mut hasher = Hasher::default();
|
||||
let mut data_index = start_index;
|
||||
let mut data = data;
|
||||
let mut data_len = data_len;
|
||||
|
||||
for i in start_index..end_index {
|
||||
if data_index >= data_len {
|
||||
// fetch next slice
|
||||
data = get_hashes(i);
|
||||
data_len = data.len();
|
||||
data_index = 0;
|
||||
if !three_level {
|
||||
// 1 group of fanout
|
||||
// The result of this loop is a single hash value from fanout input hashes.
|
||||
for i in start_index..end_index {
|
||||
if data_index >= data_len {
|
||||
// fetch next slice
|
||||
data = get_hashes(i);
|
||||
data_len = data.len();
|
||||
data_index = 0;
|
||||
}
|
||||
hasher.hash(data[data_index].as_ref());
|
||||
data_index += 1;
|
||||
}
|
||||
} else {
|
||||
// hash 3 levels of fanout simultaneously.
|
||||
// The result of this loop is a single hash value from fanout^3 input hashes.
|
||||
let mut i = start_index;
|
||||
while i < end_index {
|
||||
let mut hasher_j = Hasher::default();
|
||||
for _j in 0..fanout {
|
||||
let mut hasher_k = Hasher::default();
|
||||
let end = std::cmp::min(end_index - i, fanout);
|
||||
for _k in 0..end {
|
||||
if data_index >= data_len {
|
||||
// fetch next slice
|
||||
data = get_hashes(i);
|
||||
data_len = data.len();
|
||||
data_index = 0;
|
||||
}
|
||||
hasher_k.hash(data[data_index].as_ref());
|
||||
data_index += 1;
|
||||
i += 1;
|
||||
}
|
||||
hasher_j.hash(hasher_k.result().as_ref());
|
||||
if i >= end_index {
|
||||
break;
|
||||
}
|
||||
}
|
||||
hasher.hash(hasher_j.result().as_ref());
|
||||
}
|
||||
|
||||
hasher.hash(data[data_index].as_ref());
|
||||
data_index += 1;
|
||||
}
|
||||
|
||||
hasher.result()
|
||||
|
@ -3823,10 +3864,12 @@ impl AccountsDB {
|
|||
let hash_total = cumulative_offsets.total_count;
|
||||
let total_lamports = *total_lamports.lock().unwrap();
|
||||
let mut hash_time = Measure::start("hash");
|
||||
let accumulated_hash =
|
||||
Self::compute_merkle_root_from_slices(hash_total, MERKLE_FANOUT, |start: usize| {
|
||||
cumulative_offsets.get_slice(&hashes, start)
|
||||
});
|
||||
let accumulated_hash = Self::compute_merkle_root_from_slices(
|
||||
hash_total,
|
||||
MERKLE_FANOUT,
|
||||
None,
|
||||
|start: usize| cumulative_offsets.get_slice(&hashes, start),
|
||||
);
|
||||
hash_time.stop();
|
||||
datapoint_info!(
|
||||
"update_accounts_hash",
|
||||
|
@ -4118,7 +4161,8 @@ impl AccountsDB {
|
|||
let offsets = CumulativeOffsets::from_raw_2d(&hashes);
|
||||
|
||||
let get_slice = |start: usize| -> &[Hash] { offsets.get_slice_2d(&hashes, start) };
|
||||
let hash = Self::compute_merkle_root_from_slices(offsets.total_count, fanout, get_slice);
|
||||
let hash =
|
||||
Self::compute_merkle_root_from_slices(offsets.total_count, fanout, None, get_slice);
|
||||
hash_time.stop();
|
||||
stats.hash_time_total_us += hash_time.as_us();
|
||||
stats.hash_total = offsets.total_count;
|
||||
|
@ -6459,55 +6503,44 @@ pub mod tests {
|
|||
|
||||
fn test_hashing_larger(hashes: Vec<(Pubkey, Hash)>, fanout: usize) -> Hash {
|
||||
let result = AccountsDB::compute_merkle_root(hashes.clone(), fanout);
|
||||
if hashes.len() >= fanout * fanout * fanout {
|
||||
let reduced: Vec<_> = hashes.iter().map(|x| x.1).collect();
|
||||
let result2 =
|
||||
AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, |start| {
|
||||
&reduced[start..]
|
||||
});
|
||||
assert_eq!(result, result2);
|
||||
|
||||
let reduced2: Vec<_> = hashes.iter().map(|x| vec![x.1]).collect();
|
||||
let result2 = AccountsDB::flatten_hashes_and_hash(
|
||||
vec![reduced2],
|
||||
fanout,
|
||||
&mut HashStats::default(),
|
||||
);
|
||||
assert_eq!(result, result2);
|
||||
|
||||
for left in 0..reduced.len() {
|
||||
for right in left + 1..reduced.len() {
|
||||
let src = vec![
|
||||
vec![reduced[0..left].to_vec(), reduced[left..right].to_vec()],
|
||||
vec![reduced[right..].to_vec()],
|
||||
];
|
||||
let result2 =
|
||||
AccountsDB::flatten_hashes_and_hash(src, fanout, &mut HashStats::default());
|
||||
assert_eq!(result, result2);
|
||||
}
|
||||
}
|
||||
}
|
||||
let reduced: Vec<_> = hashes.iter().map(|x| x.1).collect();
|
||||
let result2 = test_hashing(reduced, fanout);
|
||||
assert_eq!(result, result2, "len: {}", hashes.len());
|
||||
result
|
||||
}
|
||||
|
||||
fn test_hashing(hashes: Vec<Hash>, fanout: usize) -> Hash {
|
||||
let temp: Vec<_> = hashes.iter().map(|h| (Pubkey::default(), *h)).collect();
|
||||
let result = AccountsDB::compute_merkle_root(temp, fanout);
|
||||
if hashes.len() >= fanout * fanout * fanout {
|
||||
let reduced: Vec<_> = hashes.clone();
|
||||
let result2 =
|
||||
AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, |start| {
|
||||
&reduced[start..]
|
||||
});
|
||||
assert_eq!(result, result2, "len: {}", hashes.len());
|
||||
let reduced: Vec<_> = hashes.clone();
|
||||
let result2 =
|
||||
AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, None, |start| {
|
||||
&reduced[start..]
|
||||
});
|
||||
assert_eq!(result, result2, "len: {}", hashes.len());
|
||||
|
||||
let reduced2: Vec<_> = hashes.iter().map(|x| vec![*x]).collect();
|
||||
let result2 = AccountsDB::flatten_hashes_and_hash(
|
||||
vec![reduced2],
|
||||
fanout,
|
||||
&mut HashStats::default(),
|
||||
);
|
||||
assert_eq!(result, result2, "len: {}", hashes.len());
|
||||
let result2 =
|
||||
AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, Some(1), |start| {
|
||||
&reduced[start..]
|
||||
});
|
||||
assert_eq!(result, result2, "len: {}", hashes.len());
|
||||
|
||||
let reduced2: Vec<_> = hashes.iter().map(|x| vec![*x]).collect();
|
||||
let result2 =
|
||||
AccountsDB::flatten_hashes_and_hash(vec![reduced2], fanout, &mut HashStats::default());
|
||||
assert_eq!(result, result2, "len: {}", hashes.len());
|
||||
|
||||
let max = std::cmp::min(reduced.len(), fanout * 2);
|
||||
for left in 0..max {
|
||||
for right in left + 1..max {
|
||||
let src = vec![
|
||||
vec![reduced[0..left].to_vec(), reduced[left..right].to_vec()],
|
||||
vec![reduced[right..].to_vec()],
|
||||
];
|
||||
let result2 =
|
||||
AccountsDB::flatten_hashes_and_hash(src, fanout, &mut HashStats::default());
|
||||
assert_eq!(result, result2);
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
@ -6516,12 +6549,29 @@ pub mod tests {
|
|||
fn test_accountsdb_compute_merkle_root_large() {
|
||||
solana_logger::setup();
|
||||
|
||||
let mut num = 100;
|
||||
for _pass in 0..2 {
|
||||
num *= 10;
|
||||
let hashes: Vec<_> = (0..num).into_iter().map(|_| Hash::new_unique()).collect();
|
||||
// handle fanout^x -1, +0, +1 for a few 'x's
|
||||
const FANOUT: usize = 3;
|
||||
let mut hash_counts: Vec<_> = (1..6)
|
||||
.map(|x| {
|
||||
let mark = FANOUT.pow(x);
|
||||
vec![mark - 1, mark, mark + 1]
|
||||
})
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
test_hashing(hashes, MERKLE_FANOUT);
|
||||
// saturate the test space: every hash count from threshold - 1 through threshold + target
|
||||
// this covers the count just below where the 3-level optimization kicks in, and continues through every possible partial last chunk
|
||||
let target = FANOUT.pow(3);
|
||||
let threshold = target * FANOUT;
|
||||
hash_counts.extend(threshold - 1..=threshold + target);
|
||||
|
||||
for hash_count in hash_counts {
|
||||
let hashes: Vec<_> = (0..hash_count)
|
||||
.into_iter()
|
||||
.map(|_| Hash::new_unique())
|
||||
.collect();
|
||||
|
||||
test_hashing(hashes, FANOUT);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue