multi-pass bin scanning (#15377)

* multi-pass bin scanning

* pr feedback

* format

* fix typo

* adjust metrics for code changes

* merge errors
Jeff Washington (jwash) 2021-03-18 10:32:07 -05:00 committed by GitHub
parent 0988c2f1d6
commit 4beb39f7a1
2 changed files with 799 additions and 127 deletions

runtime/src/accounts_db.rs

@ -20,7 +20,7 @@
use crate::{
accounts_cache::{AccountsCache, CachedAccount, SlotCache},
accounts_hash::{AccountsHash, CalculateHashIntermediate, HashStats},
accounts_hash::{AccountsHash, CalculateHashIntermediate, HashStats, PreviousPass},
accounts_index::{
AccountIndex, AccountsIndex, AccountsIndexRootsStats, Ancestors, IndexKey, IsCached,
SlotList, SlotSlice, ZeroLamport,
@ -54,7 +54,7 @@ use std::{
collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet},
convert::TryFrom,
io::{Error as IoError, Result as IoResult},
ops::RangeBounds,
ops::{Range, RangeBounds},
path::{Path, PathBuf},
sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
sync::{Arc, Mutex, MutexGuard, RwLock},
@ -3705,9 +3705,11 @@ impl AccountsDb {
storage: &[SnapshotStorage],
mut stats: &mut crate::accounts_hash::HashStats,
bins: usize,
bin_range: &Range<usize>,
) -> Vec<Vec<Vec<CalculateHashIntermediate>>> {
let max_plus_1 = std::u8::MAX as usize + 1;
assert!(bins <= max_plus_1 && bins > 0);
assert!(bin_range.start < bins && bin_range.end <= bins && bin_range.start < bin_range.end);
let mut time = Measure::start("scan all accounts");
stats.num_snapshot_storage = storage.len();
let result: Vec<Vec<Vec<CalculateHashIntermediate>>> = Self::scan_account_storage_no_bank(
@ -3716,6 +3718,12 @@ impl AccountsDb {
|loaded_account: LoadedAccount,
accum: &mut Vec<Vec<CalculateHashIntermediate>>,
slot: Slot| {
let pubkey = *loaded_account.pubkey();
let pubkey_to_bin_index = pubkey.as_ref()[0] as usize * bins / max_plus_1;
if !bin_range.contains(&pubkey_to_bin_index) {
return;
}
let version = loaded_account.write_version();
let raw_lamports = loaded_account.lamports();
let zero_raw_lamports = raw_lamports == 0;
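The new early-exit relies on a stable pubkey-to-bin mapping: only the first byte of the pubkey is consulted, scaled into `bins`. A standalone sketch of that arithmetic (hypothetical helper, not part of the diff):

fn pubkey_to_bin_index(first_byte: u8, bins: usize) -> usize {
    // bins is asserted to be in 1..=256, so the result is always < bins
    const MAX_PLUS_1: usize = u8::MAX as usize + 1; // 256
    first_byte as usize * bins / MAX_PLUS_1
}

fn main() {
    // with bins = 2, bytes 0..=127 map to bin 0 and 128..=255 to bin 1,
    // matching the bin_locations used by the binning tests below
    assert_eq!(pubkey_to_bin_index(127, 2), 0);
    assert_eq!(pubkey_to_bin_index(128, 2), 1);
    assert_eq!(pubkey_to_bin_index(255, 256), 255);
}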
@ -3729,7 +3737,6 @@ impl AccountsDb {
)
};
let pubkey = *loaded_account.pubkey();
let source_item = CalculateHashIntermediate::new(
version,
*loaded_account.loaded_hash(),
@ -3737,12 +3744,11 @@ impl AccountsDb {
slot,
pubkey,
);
let rng_index = pubkey.as_ref()[0] as usize * bins / max_plus_1;
let max = accum.len();
if max == 0 {
accum.extend(vec![Vec::new(); bins]);
}
accum[rng_index].push(source_item);
accum[pubkey_to_bin_index].push(source_item);
},
);
time.stop();
@ -3759,14 +3765,46 @@ impl AccountsDb {
let scan_and_hash = || {
let mut stats = HashStats::default();
// When calculating hashes, it is helpful to break the pubkeys found into bins based on the pubkey value.
// More bins means smaller vectors to sort, copy, etc.
const PUBKEY_BINS_FOR_CALCULATING_HASHES: usize = 64;
let result = Self::scan_snapshot_stores(
storages,
&mut stats,
PUBKEY_BINS_FOR_CALCULATING_HASHES,
);
AccountsHash::rest_of_hash_calculation(result, &mut stats)
// # of passes should be a function of the total # of accounts that are active.
// higher passes = slower total time, lower dynamic memory usage
// lower passes = faster total time, higher dynamic memory usage
// passes=2 cuts dynamic memory usage in approximately half.
let num_scan_passes: usize = 2;
let bins_per_pass = PUBKEY_BINS_FOR_CALCULATING_HASHES / num_scan_passes;
assert_eq!(
bins_per_pass * num_scan_passes,
PUBKEY_BINS_FOR_CALCULATING_HASHES
); // evenly divisible
let mut previous_pass = PreviousPass::default();
let mut final_result = (Hash::default(), 0);
for pass in 0..num_scan_passes {
let bounds = Range {
start: pass * bins_per_pass,
end: (pass + 1) * bins_per_pass,
};
let result = Self::scan_snapshot_stores(
storages,
&mut stats,
PUBKEY_BINS_FOR_CALCULATING_HASHES,
&bounds,
);
let (hash, lamports, for_next_pass) = AccountsHash::rest_of_hash_calculation(
result,
&mut stats,
pass == num_scan_passes - 1,
previous_pass,
);
previous_pass = for_next_pass;
final_result = (hash, lamports);
}
final_result
};
if let Some(thread_pool) = thread_pool {
thread_pool.install(scan_and_hash)
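Since each pass scans only its own slice of the 64 bins, peak memory scales with bins_per_pass rather than with the full bin count. A hedged sketch of the range arithmetic the loop above performs (standalone; the helper name is invented):

use std::ops::Range;

fn pass_ranges(total_bins: usize, num_passes: usize) -> Vec<Range<usize>> {
    let bins_per_pass = total_bins / num_passes;
    assert_eq!(bins_per_pass * num_passes, total_bins); // must divide evenly
    (0..num_passes)
        .map(|pass| pass * bins_per_pass..(pass + 1) * bins_per_pass)
        .collect()
}

fn main() {
    // 64 bins in 2 passes: pass 0 keeps bins 0..32, pass 1 keeps bins 32..64,
    // so each pass holds roughly half of the CalculateHashIntermediate data
    assert_eq!(pass_ranges(64, 2), vec![0..32, 32..64]);
}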
@ -5015,13 +5053,49 @@ pub mod tests {
#[should_panic(expected = "assertion failed: bins <= max_plus_1 && bins > 0")]
fn test_accountsdb_scan_snapshot_stores_illegal_bins2() {
let mut stats = HashStats::default();
AccountsDb::scan_snapshot_stores(&[], &mut stats, 257);
let bounds = Range { start: 0, end: 0 };
AccountsDb::scan_snapshot_stores(&[], &mut stats, 257, &bounds);
}
#[test]
#[should_panic(expected = "assertion failed: bins <= max_plus_1 && bins > 0")]
fn test_accountsdb_scan_snapshot_stores_illegal_bins() {
let mut stats = HashStats::default();
AccountsDb::scan_snapshot_stores(&[], &mut stats, 0);
let bounds = Range { start: 0, end: 0 };
AccountsDb::scan_snapshot_stores(&[], &mut stats, 0, &bounds);
}
#[test]
#[should_panic(
expected = "bin_range.start < bins && bin_range.end <= bins &&\\n bin_range.start < bin_range.end"
)]
fn test_accountsdb_scan_snapshot_stores_illegal_range_start() {
let mut stats = HashStats::default();
let bounds = Range { start: 2, end: 2 };
AccountsDb::scan_snapshot_stores(&[], &mut stats, 2, &bounds);
}
#[test]
#[should_panic(
expected = "bin_range.start < bins && bin_range.end <= bins &&\\n bin_range.start < bin_range.end"
)]
fn test_accountsdb_scan_snapshot_stores_illegal_range_end() {
let mut stats = HashStats::default();
let bounds = Range { start: 1, end: 3 };
AccountsDb::scan_snapshot_stores(&[], &mut stats, 2, &bounds);
}
#[test]
#[should_panic(
expected = "bin_range.start < bins && bin_range.end <= bins &&\\n bin_range.start < bin_range.end"
)]
fn test_accountsdb_scan_snapshot_stores_illegal_range_inverse() {
let mut stats = HashStats::default();
let bounds = Range { start: 1, end: 0 };
AccountsDb::scan_snapshot_stores(&[], &mut stats, 2, &bounds);
}
fn sample_storages_and_accounts() -> (SnapshotStorages, Vec<CalculateHashIntermediate>) {
@ -5108,11 +5182,28 @@ pub mod tests {
let bins = 1;
let mut stats = HashStats::default();
let result = AccountsDb::scan_snapshot_stores(&storages, &mut stats, bins);
let result = AccountsDb::scan_snapshot_stores(
&storages,
&mut stats,
bins,
&Range {
start: 0,
end: bins,
},
);
assert_eq!(result, vec![vec![raw_expected.clone()]]);
let bins = 2;
let result = AccountsDb::scan_snapshot_stores(&storages, &mut stats, bins);
let result = AccountsDb::scan_snapshot_stores(
&storages,
&mut stats,
bins,
&Range {
start: 0,
end: bins,
},
);
let mut expected = vec![Vec::new(); bins];
expected[0].push(raw_expected[0].clone());
expected[0].push(raw_expected[1].clone());
@ -5121,7 +5212,15 @@ pub mod tests {
assert_eq!(result, vec![expected]);
let bins = 4;
let result = AccountsDb::scan_snapshot_stores(&storages, &mut stats, bins);
let result = AccountsDb::scan_snapshot_stores(
&storages,
&mut stats,
bins,
&Range {
start: 0,
end: bins,
},
);
let mut expected = vec![Vec::new(); bins];
expected[0].push(raw_expected[0].clone());
expected[1].push(raw_expected[1].clone());
@ -5130,7 +5229,15 @@ pub mod tests {
assert_eq!(result, vec![expected]);
let bins = 256;
let result = AccountsDb::scan_snapshot_stores(&storages, &mut stats, bins);
let result = AccountsDb::scan_snapshot_stores(
&storages,
&mut stats,
bins,
&Range {
start: 0,
end: bins,
},
);
let mut expected = vec![Vec::new(); bins];
expected[0].push(raw_expected[0].clone());
expected[127].push(raw_expected[1].clone());
@ -5151,13 +5258,126 @@ pub mod tests {
storages[0].splice(0..0, vec![arc; MAX_ITEMS_PER_CHUNK]);
let mut stats = HashStats::default();
let result = AccountsDb::scan_snapshot_stores(&storages, &mut stats, bins);
let result = AccountsDb::scan_snapshot_stores(
&storages,
&mut stats,
bins,
&Range {
start: 0,
end: bins,
},
);
assert_eq!(result.len(), 2); // 2 chunks
assert_eq!(result[0].len(), 0); // nothing found in first slots
assert_eq!(result[1].len(), bins);
assert_eq!(result[1], vec![raw_expected]);
}
#[test]
fn test_accountsdb_scan_snapshot_stores_binning() {
let mut stats = HashStats::default();
let (mut storages, raw_expected) = sample_storages_and_accounts();
// just the first bin of 2
let bins = 2;
let result = AccountsDb::scan_snapshot_stores(
&storages,
&mut stats,
bins,
&Range {
start: 0,
end: bins / 2,
},
);
let mut expected = vec![Vec::new(); bins];
expected[0].push(raw_expected[0].clone());
expected[0].push(raw_expected[1].clone());
assert_eq!(result, vec![expected]);
// just the second bin of 2
let result = AccountsDb::scan_snapshot_stores(
&storages,
&mut stats,
bins,
&Range {
start: 1,
end: bins,
},
);
let mut expected = vec![Vec::new(); bins];
expected[bins - 1].push(raw_expected[2].clone());
expected[bins - 1].push(raw_expected[3].clone());
assert_eq!(result, vec![expected]);
// 1 bin at a time of 4
let bins = 4;
for bin in 0..bins {
let result = AccountsDb::scan_snapshot_stores(
&storages,
&mut stats,
bins,
&Range {
start: bin,
end: bin + 1,
},
);
let mut expected = vec![Vec::new(); bins];
expected[bin].push(raw_expected[bin].clone());
assert_eq!(result, vec![expected]);
}
let bins = 256;
let bin_locations = vec![0, 127, 128, 255];
for bin in 0..bins {
let result = AccountsDb::scan_snapshot_stores(
&storages,
&mut stats,
bins,
&Range {
start: bin,
end: bin + 1,
},
);
let mut expected = vec![];
if let Some(index) = bin_locations.iter().position(|&r| r == bin) {
expected = vec![Vec::new(); bins];
expected[bin].push(raw_expected[index].clone());
}
assert_eq!(result, vec![expected]);
}
// enough stores to get to 2nd chunk
// range is for only 1 bin out of 256.
let bins = 256;
let (_temp_dirs, paths) = get_temp_accounts_paths(1).unwrap();
let slot_expected: Slot = 0;
let size: usize = 123;
let data = AccountStorageEntry::new(&paths[0], slot_expected, 0, size as u64);
let arc = Arc::new(data);
const MAX_ITEMS_PER_CHUNK: usize = 5_000;
storages[0].splice(0..0, vec![arc; MAX_ITEMS_PER_CHUNK]);
let mut stats = HashStats::default();
let result = AccountsDb::scan_snapshot_stores(
&storages,
&mut stats,
bins,
&Range {
start: 127,
end: 128,
},
);
assert_eq!(result.len(), 2); // 2 chunks
assert_eq!(result[0].len(), 0); // nothing found in first slots
let mut expected = vec![Vec::new(); bins];
expected[127].push(raw_expected[1].clone());
assert_eq!(result[1].len(), bins);
assert_eq!(result[1], expected);
}
#[test]
fn test_accountsdb_calculate_accounts_hash_without_index_simple() {
solana_logger::setup();

runtime/src/accounts_hash.rs

@ -11,6 +11,13 @@ use std::{convert::TryInto, sync::Mutex};
pub const ZERO_RAW_LAMPORTS_SENTINEL: u64 = std::u64::MAX;
pub const MERKLE_FANOUT: usize = 16;
#[derive(Default, Debug)]
pub struct PreviousPass {
pub reduced_hashes: Vec<Vec<Hash>>,
pub remaining_unhashed: Vec<Hash>,
pub lamports: u64,
}
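Judging from how rest_of_hash_calculation uses it below, PreviousPass is the state threaded from one scan pass to the next. A mirrored sketch with field comments (the Hash alias is a stand-in for solana_sdk::hash::Hash):

type Hash = [u8; 32]; // stand-in for solana_sdk::hash::Hash

#[derive(Default, Debug)]
struct PreviousPass {
    // hashes already reduced to the fanout^3 level in earlier passes
    reduced_hashes: Vec<Vec<Hash>>,
    // tail hashes that did not fill a whole fanout^3 chunk last pass
    remaining_unhashed: Vec<Hash>,
    // lamport total accumulated across the non-final passes
    lamports: u64,
}

fn main() {
    let state = PreviousPass::default();
    assert!(state.reduced_hashes.is_empty() && state.lamports == 0);
}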
#[derive(Debug, Default)]
pub struct HashStats {
pub scan_time_total_us: u64,
@ -197,8 +204,9 @@ impl AccountsHash {
MERKLE_FANOUT,
None,
|start: usize| cumulative_offsets.get_slice(&hashes, start),
None,
);
(result, hash_total)
(result.0, hash_total)
}
pub fn compute_merkle_root(hashes: Vec<(Pubkey, Hash)>, fanout: usize) -> Hash {
@ -259,48 +267,85 @@ impl AccountsHash {
}
}
// This function is designed to allow hashes to be located in multiple, perhaps multiply deep vecs.
// The caller provides a function to return a slice from the source data.
pub fn compute_merkle_root_from_slices<'a, F>(
fn calculate_three_level_chunks(
total_hashes: usize,
fanout: usize,
max_levels_per_pass: Option<usize>,
get_hashes: F,
) -> Hash
where
F: Fn(usize) -> &'a [Hash] + std::marker::Sync,
{
if total_hashes == 0 {
return Hasher::default().result();
}
let mut time = Measure::start("time");
specific_level_count: Option<usize>,
) -> (usize, usize, bool) {
const THREE_LEVEL_OPTIMIZATION: usize = 3; // this '3' is dependent on the code structure below where we manually unroll
let target = fanout.pow(THREE_LEVEL_OPTIMIZATION as u32);
// Only use the 3 level optimization if we have at least 4 levels of data.
// Otherwise, we'll be serializing a parallel operation.
let threshold = target * fanout;
let three_level = max_levels_per_pass.unwrap_or(usize::MAX) >= THREE_LEVEL_OPTIMIZATION
let mut three_level = max_levels_per_pass.unwrap_or(usize::MAX) >= THREE_LEVEL_OPTIMIZATION
&& total_hashes >= threshold;
let num_hashes_per_chunk = if three_level { target } else { fanout };
if three_level {
if let Some(specific_level_count_value) = specific_level_count {
three_level = specific_level_count_value >= THREE_LEVEL_OPTIMIZATION;
}
}
let (num_hashes_per_chunk, levels_hashed) = if three_level {
(target, THREE_LEVEL_OPTIMIZATION)
} else {
(fanout, 1)
};
(num_hashes_per_chunk, levels_hashed, three_level)
}
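To make the thresholds concrete: with MERKLE_FANOUT = 16, target is 16^3 = 4096 and threshold is 16^4 = 65536. A hedged numeric check of the chunking decision:

fn main() {
    let fanout = 16usize;
    let target = fanout.pow(3); // 4096: hashes consumed per chunk in 3-level mode
    let threshold = target * fanout; // 65536: need at least 4 levels of data
    assert_eq!((target, threshold), (4096, 65536));

    // e.g. 1M input hashes qualifies for the 3-level path and yields
    // ceil(1_000_000 / 4096) = 245 parallel chunks, each producing one hash
    let total_hashes = 1_000_000usize;
    assert!(total_hashes >= threshold);
    assert_eq!((total_hashes + target - 1) / target, 245); // div_ceil
}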
// This function is designed to allow hashes to be located in multiple, perhaps multiply deep vecs.
// The caller provides a function to return a slice from the source data.
pub fn compute_merkle_root_from_slices<'a, F>(
total_hashes: usize,
fanout: usize,
max_levels_per_pass: Option<usize>,
get_hash_slice_starting_at_index: F,
specific_level_count: Option<usize>,
) -> (Hash, Vec<Hash>)
where
F: Fn(usize) -> &'a [Hash] + std::marker::Sync,
{
if total_hashes == 0 {
return (Hasher::default().result(), vec![]);
}
let mut time = Measure::start("time");
let (num_hashes_per_chunk, levels_hashed, three_level) = Self::calculate_three_level_chunks(
total_hashes,
fanout,
max_levels_per_pass,
specific_level_count,
);
let chunks = Self::div_ceil(total_hashes, num_hashes_per_chunk);
// initial fetch - could return entire slice
let data: &[Hash] = get_hashes(0);
let data: &[Hash] = get_hash_slice_starting_at_index(0);
let data_len = data.len();
let result: Vec<_> = (0..chunks)
.into_par_iter()
.map(|i| {
// summary:
// this closure computes 1 or 3 levels of merkle tree (all chunks will be 1 or all will be 3)
// for a subset (our chunk) of the input data [start_index..end_index]
// index into get_hash_slice_starting_at_index where this chunk's range begins
let start_index = i * num_hashes_per_chunk;
// index into get_hash_slice_starting_at_index where this chunk's range ends
let end_index = std::cmp::min(start_index + num_hashes_per_chunk, total_hashes);
// will compute the final result for this closure
let mut hasher = Hasher::default();
// index into 'data' where we are currently pulling data
// if we exhaust our data, then we will request a new slice, and data_index resets to 0, the beginning of the new slice
let mut data_index = start_index;
// source data, which we may refresh when we exhaust
let mut data = data;
// len of the source data
let mut data_len = data_len;
if !three_level {
@ -308,8 +353,8 @@ impl AccountsHash {
// The result of this loop is a single hash value from fanout input hashes.
for i in start_index..end_index {
if data_index >= data_len {
// fetch next slice
data = get_hashes(i);
// we exhausted our data, fetch next slice starting at i
data = get_hash_slice_starting_at_index(i);
data_len = data.len();
data_index = 0;
}
@ -318,7 +363,43 @@ impl AccountsHash {
}
} else {
// hash 3 levels of fanout simultaneously.
// This codepath produces 1 hash value for between 1..=fanout^3 input hashes.
// It is equivalent to running the normal merkle tree calculation 3 iterations on the input.
//
// big idea:
// merkle trees usually reduce the input vector by a factor of fanout with each iteration
// example with fanout 2:
// start: [0,1,2,3,4,5,6,7] in our case: [...16M...] or really, 1B
// iteration0 [.5, 2.5, 4.5, 6.5] [... 1M...]
// iteration1 [1.5, 5.5] [...65k...]
// iteration2 3.5 [...4k... ]
// So iteration 0 consumes N elements, hashes them in groups of 'fanout' and produces a vector of N/fanout elements
// and the process repeats until there is only 1 hash left.
//
// With the three_level code path, we make each chunk we iterate of size fanout^3 (4096)
// So, the input could be 16M hashes and the output will be 4k hashes, or N/fanout^3
// The goal is to reduce the amount of data that has to be constructed and held in memory.
// When we know we have enough hashes, then, in 1 pass, we hash 3 levels simultaneously, storing far fewer intermediate hashes.
//
// Now, some details:
// The result of this loop is a single hash value from fanout^3 input hashes.
// concepts:
// what we're conceptually hashing: "raw_hashes"[start_index..end_index]
// example: [a,b,c,d,e,f]
// but... hashes[] may really be multiple vectors that are pieced together.
// example: [[a,b],[c],[d,e,f]]
// get_hash_slice_starting_at_index(any_index) abstracts that and returns a slice starting at raw_hashes[any_index..]
// such that the end of get_hash_slice_starting_at_index may be <, >, or = end_index
// example: get_hash_slice_starting_at_index(1) returns [b]
// get_hash_slice_starting_at_index(3) returns [d,e,f]
// This code is basically 3 iterations of merkle tree hashing occurring simultaneously.
// The first fanout raw hashes are hashed in hasher_k. This is iteration0
// Once hasher_k has hashed fanout hashes, hasher_k's result hash is hashed in hasher_j and then discarded
// hasher_k then starts over fresh and hashes the next fanout raw hashes. This is iteration0 again for a new set of data.
// Once hasher_j has hashed fanout hashes (from k), hasher_j's result hash is hashed in hasher and then discarded
// Once hasher has hashed fanout hashes (from j), then the result of hasher is the hash for fanout^3 raw hashes.
// If there are < fanout^3 hashes, then this code stops when it runs out of raw hashes and returns whatever it hashed.
// This is always how the very last elements work in a merkle tree.
let mut i = start_index;
while i < end_index {
let mut hasher_j = Hasher::default();
@ -327,8 +408,8 @@ impl AccountsHash {
let end = std::cmp::min(end_index - i, fanout);
for _k in 0..end {
if data_index >= data_len {
// fetch next slice
data = get_hashes(i);
// we exhausted our data, fetch next slice starting at i
data = get_hash_slice_starting_at_index(i);
data_len = data.len();
data_index = 0;
}
@ -351,13 +432,47 @@ impl AccountsHash {
time.stop();
debug!("hashing {} {}", total_hashes, time);
if result.len() == 1 {
result[0]
if let Some(mut specific_level_count_value) = specific_level_count {
specific_level_count_value -= levels_hashed;
if specific_level_count_value == 0 {
(Hash::default(), result)
} else {
assert!(specific_level_count_value > 0);
// We did not hash the number of levels required by 'specific_level_count', so repeat
Self::compute_merkle_root_from_slices_recurse(
result,
fanout,
max_levels_per_pass,
Some(specific_level_count_value),
)
}
} else {
Self::compute_merkle_root_recurse(result, fanout)
(
if result.len() == 1 {
result[0]
} else {
Self::compute_merkle_root_recurse(result, fanout)
},
vec![], // no intermediate results needed by caller
)
}
}
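The long comment in the three_level branch above claims the unrolled k/j/outer loop equals three ordinary merkle iterations. A toy model of that equivalence (combine() stands in for the real Hasher; fanout 2 over 8 inputs, i.e. exactly one fanout^3 chunk):

fn combine(vals: &[u64]) -> u64 {
    // deterministic stand-in for hashing a group of values
    vals.iter().fold(17u64, |acc, v| acc.wrapping_mul(31).wrapping_add(*v))
}

fn reduce(vals: &[u64], fanout: usize) -> Vec<u64> {
    vals.chunks(fanout).map(combine).collect()
}

fn main() {
    let fanout = 2;
    let input: Vec<u64> = (0..8).collect();
    // three single-level iterations: 8 -> 4 -> 2 -> 1
    let three_iters = reduce(&reduce(&reduce(&input, fanout), fanout), fanout);
    // one unrolled pass: inner (k) groups of fanout feed a middle (j) hash,
    // and middle hashes feed the outer hash, as in the loop above
    let outer: Vec<u64> = input
        .chunks(fanout * fanout)
        .map(|chunk_j| combine(&reduce(chunk_j, fanout)))
        .collect();
    assert_eq!(three_iters, vec![combine(&outer)]);
}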
pub fn compute_merkle_root_from_slices_recurse(
hashes: Vec<Hash>,
fanout: usize,
max_levels_per_pass: Option<usize>,
specific_level_count: Option<usize>,
) -> (Hash, Vec<Hash>) {
Self::compute_merkle_root_from_slices(
hashes.len(),
fanout,
max_levels_per_pass,
|start| &hashes[start..],
specific_level_count,
)
}
pub fn accumulate_account_hashes(mut hashes: Vec<(Pubkey, Hash)>) -> Hash {
Self::sort_hashes_by_pubkey(&mut hashes);
@ -405,7 +520,7 @@ impl AccountsHash {
}
flatten_time.stop();
stats.flatten_time_total_us += flatten_time.as_us();
stats.unreduced_entries = raw_len;
stats.unreduced_entries += raw_len;
data_by_pubkey
}
@ -569,50 +684,120 @@ impl AccountsHash {
(result, sum)
}
fn flatten_hashes_and_hash(
hashes: Vec<Vec<Vec<Hash>>>,
fanout: usize,
stats: &mut HashStats,
) -> Hash {
let mut hash_time = Measure::start("flat2");
let offsets = CumulativeOffsets::from_raw_2d(&hashes);
let get_slice = |start: usize| -> &[Hash] { offsets.get_slice_2d(&hashes, start) };
let hash = AccountsHash::compute_merkle_root_from_slices(
offsets.total_count,
fanout,
None,
get_slice,
);
hash_time.stop();
stats.hash_time_total_us += hash_time.as_us();
stats.hash_total = offsets.total_count;
hash
}
// input:
// vec: unordered, created by parallelism
// vec: [0..bins] - where bins are pubkey ranges
// vec: [..] - items which fin in the containing bin, unordered within this vec
// vec: [..] - items which fit in the containing bin, unordered within this vec
// so, assumption is middle vec is bins sorted by pubkey
pub fn rest_of_hash_calculation(
data_sections_by_pubkey: Vec<Vec<Vec<CalculateHashIntermediate>>>,
mut stats: &mut HashStats,
) -> (Hash, u64) {
is_last_pass: bool,
mut previous_state: PreviousPass,
) -> (Hash, u64, PreviousPass) {
let outer = Self::flatten_hash_intermediate(data_sections_by_pubkey, &mut stats);
let sorted_data_by_pubkey = Self::sort_hash_intermediate(outer, &mut stats);
let (hashes, total_lamports) =
let (mut hashes, mut total_lamports) =
Self::de_dup_and_eliminate_zeros(sorted_data_by_pubkey, &mut stats);
let hash = Self::flatten_hashes_and_hash(hashes, MERKLE_FANOUT, &mut stats);
total_lamports += previous_state.lamports;
stats.log();
if !previous_state.remaining_unhashed.is_empty() {
// these items were not hashed last iteration because they didn't divide evenly
hashes.insert(0, vec![previous_state.remaining_unhashed]);
previous_state.remaining_unhashed = Vec::new();
}
(hash, total_lamports)
let mut next_pass = PreviousPass::default();
let cumulative = CumulativeOffsets::from_raw_2d(&hashes);
let mut hash_total = cumulative.total_count;
stats.hash_total += hash_total;
next_pass.reduced_hashes = previous_state.reduced_hashes;
const TARGET_FANOUT_LEVEL: usize = 3;
let target_fanout = MERKLE_FANOUT.pow(TARGET_FANOUT_LEVEL as u32);
if !is_last_pass {
next_pass.lamports = total_lamports;
total_lamports = 0;
// Save hashes that don't fill a whole fanout^3 chunk. They will be combined with hashes from the next pass.
let left_over_hashes = hash_total % target_fanout;
// move the tail hashes that didn't divide evenly into a 1d vector for next time
let mut i = hash_total - left_over_hashes;
while i < hash_total {
let data = cumulative.get_slice_2d(&hashes, i);
next_pass.remaining_unhashed.extend(data);
i += data.len();
}
hash_total -= left_over_hashes; // this is enough to cause the hashes at the end of the data set to be ignored
}
// if we have raw hashes to process and
// we are not the last pass (we already modded against target_fanout) OR
// we have previously surpassed target_fanout and hashed some already to the target_fanout level. In that case, we know
// we need to hash whatever is left here to the target_fanout level.
if hash_total != 0 && (!is_last_pass || !next_pass.reduced_hashes.is_empty()) {
let mut hash_time = Measure::start("hash");
let partial_hashes = Self::compute_merkle_root_from_slices(
hash_total, // note this does not include the ones that didn't divide evenly, unless we're in the last iteration
MERKLE_FANOUT,
Some(TARGET_FANOUT_LEVEL),
|start| cumulative.get_slice_2d(&hashes, start),
Some(TARGET_FANOUT_LEVEL),
)
.1;
hash_time.stop();
stats.hash_time_total_us += hash_time.as_us();
next_pass.reduced_hashes.push(partial_hashes);
}
let no_progress = is_last_pass && next_pass.reduced_hashes.is_empty() && !hashes.is_empty();
if no_progress {
// we never made partial progress, so hash everything now
hashes.into_iter().for_each(|v| {
v.into_iter().for_each(|v| {
if !v.is_empty() {
next_pass.reduced_hashes.push(v);
}
});
});
}
let hash = if is_last_pass {
let cumulative = CumulativeOffsets::from_raw(&next_pass.reduced_hashes);
let hash = if cumulative.total_count == 1 && !no_progress {
// all the passes resulted in a single hash, that means we're done, so we had <= MERKLE_FANOUT^3 total hashes
cumulative.get_slice(&next_pass.reduced_hashes, 0)[0]
} else {
let mut hash_time = Measure::start("hash");
// hash all the rest and combine and hash until we have only 1 hash left
let (hash, _) = Self::compute_merkle_root_from_slices(
cumulative.total_count,
MERKLE_FANOUT,
None,
|start| cumulative.get_slice(&next_pass.reduced_hashes, start),
None,
);
hash_time.stop();
stats.hash_time_total_us += hash_time.as_us();
hash
};
next_pass.reduced_hashes = Vec::new();
hash
} else {
Hash::default()
};
if is_last_pass {
stats.log();
}
(hash, total_lamports, next_pass)
}
}
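The carry logic above is easiest to see with the numbers the partial_hashes test below uses: 4097 hashes arrive in a non-final pass, 4096 of them reduce to one fanout^3-level hash, and the 1 leftover rides along in remaining_unhashed. A hedged worked example:

fn main() {
    let target_fanout = 16usize.pow(3); // MERKLE_FANOUT^3 = 4096
    let hash_total = 4097usize;
    let left_over = hash_total % target_fanout;
    assert_eq!(left_over, 1); // goes to next_pass.remaining_unhashed
    assert_eq!(hash_total - left_over, 4096); // reduced to 1 hash this pass
    // the next pass inserts the leftover at the front of its hashes,
    // so ordering across passes is preserved
}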
@ -662,6 +847,8 @@ pub mod tests {
let result = AccountsHash::rest_of_hash_calculation(
vec![vec![account_maps.clone()]],
&mut HashStats::default(),
true,
PreviousPass::default(),
);
let expected_hash = Hash::from_str("8j9ARGFv4W2GfML7d3sVJK2MePwrikqYnu6yqer28cCa").unwrap();
assert_eq!((result.0, result.1), (expected_hash, 88));
@ -675,6 +862,8 @@ pub mod tests {
let result = AccountsHash::rest_of_hash_calculation(
vec![vec![account_maps.clone()]],
&mut HashStats::default(),
true,
PreviousPass::default(),
);
let expected_hash = Hash::from_str("EHv9C5vX7xQjjMpsJMzudnDTzoTSRwYkqLzY8tVMihGj").unwrap();
assert_eq!((result.0, result.1), (expected_hash, 108));
@ -688,11 +877,302 @@ pub mod tests {
let result = AccountsHash::rest_of_hash_calculation(
vec![vec![account_maps]],
&mut HashStats::default(),
true,
PreviousPass::default(),
);
let expected_hash = Hash::from_str("7NNPg5A8Xsg1uv4UFm6KZNwsipyyUnmgCrznP6MBWoBZ").unwrap();
assert_eq!((result.0, result.1), (expected_hash, 118));
}
#[test]
fn test_accountsdb_multi_pass_rest_of_hash_calculation() {
solana_logger::setup();
// passes:
// 0: empty, NON-empty, empty, empty final
// 1: NON-empty, empty final
// 2: NON-empty, empty, empty final
for pass in 0..3 {
let mut account_maps: Vec<CalculateHashIntermediate> = Vec::new();
let key = Pubkey::new(&[11u8; 32]);
let hash = Hash::new(&[1u8; 32]);
let val = CalculateHashIntermediate::new(0, hash, 88, Slot::default(), key);
account_maps.push(val);
// 2nd key - zero lamports, so will be removed
let key = Pubkey::new(&[12u8; 32]);
let hash = Hash::new(&[2u8; 32]);
let val = CalculateHashIntermediate::new(
0,
hash,
ZERO_RAW_LAMPORTS_SENTINEL,
Slot::default(),
key,
);
account_maps.push(val);
let mut previous_pass = PreviousPass::default();
if pass == 0 {
// first pass that is not last and is empty
let result = AccountsHash::rest_of_hash_calculation(
vec![vec![vec![]]],
&mut HashStats::default(),
false, // not last pass
previous_pass,
);
assert_eq!(result.0, Hash::default());
assert_eq!(result.1, 0);
previous_pass = result.2;
assert_eq!(previous_pass.remaining_unhashed.len(), 0);
assert_eq!(previous_pass.reduced_hashes.len(), 0);
assert_eq!(previous_pass.lamports, 0);
}
let result = AccountsHash::rest_of_hash_calculation(
vec![vec![account_maps.clone()]],
&mut HashStats::default(),
false, // not last pass
previous_pass,
);
assert_eq!(result.0, Hash::default());
assert_eq!(result.1, 0);
let mut previous_pass = result.2;
assert_eq!(previous_pass.remaining_unhashed, vec![account_maps[0].hash]);
assert_eq!(previous_pass.reduced_hashes.len(), 0);
assert_eq!(previous_pass.lamports, account_maps[0].lamports);
let expected_hash =
Hash::from_str("8j9ARGFv4W2GfML7d3sVJK2MePwrikqYnu6yqer28cCa").unwrap();
if pass == 2 {
let result = AccountsHash::rest_of_hash_calculation(
vec![vec![vec![]]],
&mut HashStats::default(),
false,
previous_pass,
);
previous_pass = result.2;
assert_eq!(previous_pass.remaining_unhashed, vec![account_maps[0].hash]);
assert_eq!(previous_pass.reduced_hashes.len(), 0);
assert_eq!(previous_pass.lamports, account_maps[0].lamports);
}
let result = AccountsHash::rest_of_hash_calculation(
vec![vec![vec![]]],
&mut HashStats::default(),
true, // finally, last pass
previous_pass,
);
let previous_pass = result.2;
assert_eq!(previous_pass.remaining_unhashed.len(), 0);
assert_eq!(previous_pass.reduced_hashes.len(), 0);
assert_eq!(previous_pass.lamports, 0);
assert_eq!((result.0, result.1), (expected_hash, 88));
}
}
#[test]
fn test_accountsdb_multi_pass_rest_of_hash_calculation_partial() {
solana_logger::setup();
let mut account_maps: Vec<CalculateHashIntermediate> = Vec::new();
let key = Pubkey::new(&[11u8; 32]);
let hash = Hash::new(&[1u8; 32]);
let val = CalculateHashIntermediate::new(0, hash, 88, Slot::default(), key);
account_maps.push(val);
let key = Pubkey::new(&[12u8; 32]);
let hash = Hash::new(&[2u8; 32]);
let val = CalculateHashIntermediate::new(0, hash, 20, Slot::default(), key);
account_maps.push(val);
let result = AccountsHash::rest_of_hash_calculation(
vec![vec![vec![account_maps[0].clone()]]],
&mut HashStats::default(),
false, // not last pass
PreviousPass::default(),
);
assert_eq!(result.0, Hash::default());
assert_eq!(result.1, 0);
let previous_pass = result.2;
assert_eq!(previous_pass.remaining_unhashed, vec![account_maps[0].hash]);
assert_eq!(previous_pass.reduced_hashes.len(), 0);
assert_eq!(previous_pass.lamports, account_maps[0].lamports);
let result = AccountsHash::rest_of_hash_calculation(
vec![vec![vec![account_maps[1].clone()]]],
&mut HashStats::default(),
false, // not last pass
previous_pass,
);
assert_eq!(result.0, Hash::default());
assert_eq!(result.1, 0);
let previous_pass = result.2;
assert_eq!(
previous_pass.remaining_unhashed,
vec![account_maps[0].hash, account_maps[1].hash]
);
assert_eq!(previous_pass.reduced_hashes.len(), 0);
let total_lamports_expected = account_maps[0].lamports + account_maps[1].lamports;
assert_eq!(previous_pass.lamports, total_lamports_expected);
let result = AccountsHash::rest_of_hash_calculation(
vec![vec![vec![]]],
&mut HashStats::default(),
true,
previous_pass,
);
let previous_pass = result.2;
assert_eq!(previous_pass.remaining_unhashed.len(), 0);
assert_eq!(previous_pass.reduced_hashes.len(), 0);
assert_eq!(previous_pass.lamports, 0);
let expected_hash = AccountsHash::compute_merkle_root(
account_maps
.iter()
.map(|a| (a.pubkey, a.hash))
.collect::<Vec<_>>(),
MERKLE_FANOUT,
);
assert_eq!(
(result.0, result.1),
(expected_hash, total_lamports_expected)
);
}
#[test]
fn test_accountsdb_multi_pass_rest_of_hash_calculation_partial_hashes() {
solana_logger::setup();
let mut account_maps: Vec<CalculateHashIntermediate> = Vec::new();
const TARGET_FANOUT_LEVEL: usize = 3;
let target_fanout = MERKLE_FANOUT.pow(TARGET_FANOUT_LEVEL as u32);
let mut total_lamports_expected = 0;
let plus1 = target_fanout + 1;
for i in 0..plus1 * 2 {
let lamports = (i + 1) as u64;
total_lamports_expected += lamports;
let key = Pubkey::new_unique();
let hash = Hash::new_unique();
let val = CalculateHashIntermediate::new(0, hash, lamports, Slot::default(), key);
account_maps.push(val);
}
let chunk = account_maps[0..plus1].to_vec();
let mut sorted = chunk.clone();
sorted.sort_by(AccountsHash::compare_two_hash_entries);
// first 4097 hashes (1 left over)
let result = AccountsHash::rest_of_hash_calculation(
vec![vec![chunk]],
&mut HashStats::default(),
false, // not last pass
PreviousPass::default(),
);
assert_eq!(result.0, Hash::default());
assert_eq!(result.1, 0);
let previous_pass = result.2;
let left_over_1 = sorted[plus1 - 1].hash;
assert_eq!(previous_pass.remaining_unhashed, vec![left_over_1]);
assert_eq!(previous_pass.reduced_hashes.len(), 1);
let expected_hash = AccountsHash::compute_merkle_root(
sorted[0..target_fanout]
.iter()
.map(|a| (a.pubkey, a.hash))
.collect::<Vec<_>>(),
MERKLE_FANOUT,
);
assert_eq!(previous_pass.reduced_hashes[0], vec![expected_hash]);
assert_eq!(
previous_pass.lamports,
account_maps[0..plus1]
.iter()
.map(|i| i.lamports)
.sum::<u64>()
);
let chunk = account_maps[plus1..plus1 * 2].to_vec();
let mut sorted2 = chunk.clone();
sorted2.sort_by(AccountsHash::compare_two_hash_entries);
let mut with_left_over = vec![left_over_1];
with_left_over.extend(sorted2[0..plus1 - 2].to_vec().into_iter().map(|i| i.hash));
let expected_hash2 = AccountsHash::compute_merkle_root(
with_left_over[0..target_fanout]
.iter()
.map(|a| (Pubkey::default(), *a))
.collect::<Vec<_>>(),
MERKLE_FANOUT,
);
// second 4097 hashes (2 left over)
let result = AccountsHash::rest_of_hash_calculation(
vec![vec![chunk]],
&mut HashStats::default(),
false, // not last pass
previous_pass,
);
assert_eq!(result.0, Hash::default());
assert_eq!(result.1, 0);
let previous_pass = result.2;
assert_eq!(
previous_pass.remaining_unhashed,
vec![sorted2[plus1 - 2].hash, sorted2[plus1 - 1].hash]
);
assert_eq!(previous_pass.reduced_hashes.len(), 2);
assert_eq!(
previous_pass.reduced_hashes,
vec![vec![expected_hash], vec![expected_hash2]]
);
assert_eq!(
previous_pass.lamports,
account_maps[0..plus1 * 2]
.iter()
.map(|i| i.lamports)
.sum::<u64>()
);
let result = AccountsHash::rest_of_hash_calculation(
vec![vec![vec![]]],
&mut HashStats::default(),
true,
previous_pass,
);
let previous_pass = result.2;
assert_eq!(previous_pass.remaining_unhashed.len(), 0);
assert_eq!(previous_pass.reduced_hashes.len(), 0);
assert_eq!(previous_pass.lamports, 0);
let mut combined = sorted;
combined.extend(sorted2);
let expected_hash = AccountsHash::compute_merkle_root(
combined
.iter()
.map(|a| (a.pubkey, a.hash))
.collect::<Vec<_>>(),
MERKLE_FANOUT,
);
assert_eq!(
(result.0, result.1),
(expected_hash, total_lamports_expected)
);
}
#[test]
fn test_accountsdb_de_dup_accounts_zero_chunks() {
let (hashes, lamports) =
@ -923,40 +1403,6 @@ pub mod tests {
}
}
#[test]
fn test_accountsdb_flatten_hashes_and_hash() {
solana_logger::setup();
const COUNT: usize = 4;
let hashes: Vec<_> = (0..COUNT)
.into_iter()
.map(|i| Hash::new(&[(i) as u8; 32]))
.collect();
let expected =
AccountsHash::compute_merkle_root_loop(hashes.clone(), MERKLE_FANOUT, |i| *i);
assert_eq!(
AccountsHash::flatten_hashes_and_hash(
vec![vec![hashes.clone()]],
MERKLE_FANOUT,
&mut HashStats::default()
),
expected,
);
for in_first in 1..COUNT - 1 {
assert_eq!(
AccountsHash::flatten_hashes_and_hash(
vec![vec![
hashes.clone()[0..in_first].to_vec(),
hashes.clone()[in_first..COUNT].to_vec()
]],
MERKLE_FANOUT,
&mut HashStats::default()
),
expected
);
}
}
#[test]
fn test_sort_hash_intermediate() {
solana_logger::setup();
@ -1367,25 +1813,23 @@ pub mod tests {
let temp: Vec<_> = hashes.iter().map(|h| (Pubkey::default(), *h)).collect();
let result = AccountsHash::compute_merkle_root(temp, fanout);
let reduced: Vec<_> = hashes.clone();
let result2 =
AccountsHash::compute_merkle_root_from_slices(hashes.len(), fanout, None, |start| {
&reduced[start..]
});
assert_eq!(result, result2, "len: {}", hashes.len());
let result2 =
AccountsHash::compute_merkle_root_from_slices(hashes.len(), fanout, Some(1), |start| {
&reduced[start..]
});
assert_eq!(result, result2, "len: {}", hashes.len());
let reduced2: Vec<_> = hashes.iter().map(|x| vec![*x]).collect();
let result2 = AccountsHash::flatten_hashes_and_hash(
vec![reduced2],
let result2 = AccountsHash::compute_merkle_root_from_slices(
hashes.len(),
fanout,
&mut HashStats::default(),
None,
|start| &reduced[start..],
None,
);
assert_eq!(result, result2, "len: {}", hashes.len());
assert_eq!(result, result2.0, "len: {}", hashes.len());
let result2 = AccountsHash::compute_merkle_root_from_slices(
hashes.len(),
fanout,
Some(1),
|start| &reduced[start..],
None,
);
assert_eq!(result, result2.0, "len: {}", hashes.len());
let max = std::cmp::min(reduced.len(), fanout * 2);
for left in 0..max {
@ -1394,9 +1838,17 @@ pub mod tests {
vec![reduced[0..left].to_vec(), reduced[left..right].to_vec()],
vec![reduced[right..].to_vec()],
];
let result2 =
AccountsHash::flatten_hashes_and_hash(src, fanout, &mut HashStats::default());
assert_eq!(result, result2);
let offsets = CumulativeOffsets::from_raw_2d(&src);
let get_slice = |start: usize| -> &[Hash] { offsets.get_slice_2d(&src, start) };
let result2 = AccountsHash::compute_merkle_root_from_slices(
offsets.total_count,
fanout,
None,
get_slice,
None,
);
assert_eq!(result, result2.0);
}
}
result