uses typed Pubkey instead of opaque [u8; 32] (#32842)

This commit is contained in:
behzad nouri 2023-08-18 23:28:08 +00:00 committed by GitHub
parent e28c819819
commit 7bd7410592
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 97 deletions

View File

@ -92,17 +92,11 @@ mod test {
.unwrap();
packet.meta_mut().flags |= PacketFlags::REPAIR;
let leader_slots = [(slot, keypair.pubkey().to_bytes())]
.iter()
.cloned()
.collect();
let leader_slots = HashMap::from([(slot, keypair.pubkey())]);
assert!(verify_shred_cpu(&packet, &leader_slots));
let wrong_keypair = Keypair::new();
let leader_slots = [(slot, wrong_keypair.pubkey().to_bytes())]
.iter()
.cloned()
.collect();
let leader_slots = HashMap::from([(slot, wrong_keypair.pubkey())]);
assert!(!verify_shred_cpu(&packet, &leader_slots));
let leader_slots = HashMap::new();

View File

@ -19,7 +19,7 @@ use {
signature::{Keypair, Signature, Signer},
},
static_assertions::const_assert_eq,
std::{collections::HashMap, fmt::Debug, iter::repeat, mem::size_of, ops::Range, sync::Arc},
std::{collections::HashMap, iter::repeat, mem::size_of, ops::Range, sync::Arc},
};
const SIGN_SHRED_GPU_MIN: usize = 256;
@ -27,10 +27,7 @@ const_assert_eq!(SIZE_OF_MERKLE_ROOT, 32);
const SIZE_OF_MERKLE_ROOT: usize = std::mem::size_of::<Hash>();
#[must_use]
pub fn verify_shred_cpu(
packet: &Packet,
slot_leaders: &HashMap<Slot, /*pubkey:*/ [u8; 32]>,
) -> bool {
pub fn verify_shred_cpu(packet: &Packet, slot_leaders: &HashMap<Slot, Pubkey>) -> bool {
if packet.meta().discard() {
return false;
}
@ -51,13 +48,13 @@ pub fn verify_shred_cpu(
let Some(data) = shred::layout::get_signed_data(shred) else {
return false;
};
signature.verify(pubkey, data.as_ref())
signature.verify(pubkey.as_ref(), data.as_ref())
}
fn verify_shreds_cpu(
thread_pool: &ThreadPool,
batches: &[PacketBatch],
slot_leaders: &HashMap<Slot, /*pubkey:*/ [u8; 32]>,
slot_leaders: &HashMap<Slot, Pubkey>,
) -> Vec<Vec<u8>> {
let packet_count = count_packets_in_batches(batches);
debug!("CPU SHRED ECDSA for {}", packet_count);
@ -76,17 +73,14 @@ fn verify_shreds_cpu(
rv
}
fn slot_key_data_for_gpu<T>(
fn slot_key_data_for_gpu(
thread_pool: &ThreadPool,
batches: &[PacketBatch],
slot_keys: &HashMap<Slot, /*pubkey:*/ T>,
slot_keys: &HashMap<Slot, Pubkey>,
recycler_cache: &RecyclerCache,
) -> (/*pubkeys:*/ PinnedVec<u8>, TxOffset)
where
T: AsRef<[u8]> + Copy + Debug + Default + Eq + std::hash::Hash + Sync,
{
) -> (/*pubkeys:*/ PinnedVec<u8>, TxOffset) {
//TODO: mark Pubkey::default shreds as failed after the GPU returns
assert_eq!(slot_keys.get(&Slot::MAX), Some(&T::default()));
assert_eq!(slot_keys.get(&Slot::MAX), Some(&Pubkey::default()));
let slots: Vec<Slot> = thread_pool.install(|| {
batches
.into_par_iter()
@ -104,14 +98,14 @@ where
})
.collect()
});
let keys_to_slots: HashMap<T, Vec<Slot>> = slots
let keys_to_slots: HashMap<Pubkey, Vec<Slot>> = slots
.iter()
.map(|slot| (slot_keys[slot], *slot))
.into_group_map();
let mut keyvec = recycler_cache.buffer().allocate("shred_gpu_pubkeys");
keyvec.set_pinnable();
let keyvec_size = keys_to_slots.len() * size_of::<T>();
let keyvec_size = keys_to_slots.len() * size_of::<Pubkey>();
resize_buffer(&mut keyvec, keyvec_size);
let key_offsets: HashMap<Slot, /*key offset:*/ usize> = {
@ -120,7 +114,7 @@ where
.into_iter()
.flat_map(|(key, slots)| {
let offset = next_offset;
next_offset += std::mem::size_of::<T>();
next_offset += std::mem::size_of::<Pubkey>();
keyvec[offset..next_offset].copy_from_slice(key.as_ref());
slots.into_iter().zip(repeat(offset))
})
@ -247,7 +241,7 @@ fn shred_gpu_offsets(
pub fn verify_shreds_gpu(
thread_pool: &ThreadPool,
batches: &[PacketBatch],
slot_leaders: &HashMap<Slot, /*pubkey:*/ [u8; 32]>,
slot_leaders: &HashMap<Slot, Pubkey>,
recycler_cache: &RecyclerCache,
) -> Vec<Vec<u8>> {
let Some(api) = perf_libs::api() else {
@ -505,17 +499,11 @@ mod tests {
packet.buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload());
packet.meta_mut().size = shred.payload().len();
let leader_slots = [(slot, keypair.pubkey().to_bytes())]
.iter()
.cloned()
.collect();
let leader_slots = HashMap::from([(slot, keypair.pubkey())]);
assert!(verify_shred_cpu(&packet, &leader_slots));
let wrong_keypair = Keypair::new();
let leader_slots = [(slot, wrong_keypair.pubkey().to_bytes())]
.iter()
.cloned()
.collect();
let leader_slots = HashMap::from([(slot, wrong_keypair.pubkey())]);
assert!(!verify_shred_cpu(&packet, &leader_slots));
let leader_slots = HashMap::new();
@ -546,18 +534,12 @@ mod tests {
batches[0][0].buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload());
batches[0][0].meta_mut().size = shred.payload().len();
let leader_slots = [(slot, keypair.pubkey().to_bytes())]
.iter()
.cloned()
.collect();
let leader_slots = HashMap::from([(slot, keypair.pubkey())]);
let rv = verify_shreds_cpu(thread_pool, &batches, &leader_slots);
assert_eq!(rv, vec![vec![1]]);
let wrong_keypair = Keypair::new();
let leader_slots = [(slot, wrong_keypair.pubkey().to_bytes())]
.iter()
.cloned()
.collect();
let leader_slots = HashMap::from([(slot, wrong_keypair.pubkey())]);
let rv = verify_shreds_cpu(thread_pool, &batches, &leader_slots);
assert_eq!(rv, vec![vec![0]]);
@ -565,10 +547,7 @@ mod tests {
let rv = verify_shreds_cpu(thread_pool, &batches, &leader_slots);
assert_eq!(rv, vec![vec![0]]);
let leader_slots = [(slot, keypair.pubkey().to_bytes())]
.iter()
.cloned()
.collect();
let leader_slots = HashMap::from([(slot, keypair.pubkey())]);
batches[0][0].meta_mut().size = 0;
let rv = verify_shreds_cpu(thread_pool, &batches, &leader_slots);
assert_eq!(rv, vec![vec![0]]);
@ -601,39 +580,26 @@ mod tests {
batches[0][0].buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload());
batches[0][0].meta_mut().size = shred.payload().len();
let leader_slots = [
(std::u64::MAX, Pubkey::default().to_bytes()),
(slot, keypair.pubkey().to_bytes()),
]
.iter()
.cloned()
.collect();
let leader_slots =
HashMap::from([(std::u64::MAX, Pubkey::default()), (slot, keypair.pubkey())]);
let rv = verify_shreds_gpu(thread_pool, &batches, &leader_slots, &recycler_cache);
assert_eq!(rv, vec![vec![1]]);
let wrong_keypair = Keypair::new();
let leader_slots = [
(std::u64::MAX, Pubkey::default().to_bytes()),
(slot, wrong_keypair.pubkey().to_bytes()),
]
.iter()
.cloned()
.collect();
let leader_slots = HashMap::from([
(std::u64::MAX, Pubkey::default()),
(slot, wrong_keypair.pubkey()),
]);
let rv = verify_shreds_gpu(thread_pool, &batches, &leader_slots, &recycler_cache);
assert_eq!(rv, vec![vec![0]]);
let leader_slots = [(std::u64::MAX, [0u8; 32])].iter().cloned().collect();
let leader_slots = HashMap::from([(std::u64::MAX, Pubkey::default())]);
let rv = verify_shreds_gpu(thread_pool, &batches, &leader_slots, &recycler_cache);
assert_eq!(rv, vec![vec![0]]);
batches[0][0].meta_mut().size = 0;
let leader_slots = [
(std::u64::MAX, Pubkey::default().to_bytes()),
(slot, keypair.pubkey().to_bytes()),
]
.iter()
.cloned()
.collect();
let leader_slots =
HashMap::from([(std::u64::MAX, Pubkey::default()), (slot, keypair.pubkey())]);
let rv = verify_shreds_gpu(thread_pool, &batches, &leader_slots, &recycler_cache);
assert_eq!(rv, vec![vec![0]]);
}
@ -670,13 +636,7 @@ mod tests {
let keypair = Keypair::new();
let pinned_keypair = sign_shreds_gpu_pinned_keypair(&keypair, &recycler_cache);
let pinned_keypair = Some(Arc::new(pinned_keypair));
let pubkeys = [
(std::u64::MAX, Pubkey::default().to_bytes()),
(slot, keypair.pubkey().to_bytes()),
]
.iter()
.cloned()
.collect();
let pubkeys = HashMap::from([(std::u64::MAX, Pubkey::default()), (slot, keypair.pubkey())]);
//unsigned
let rv = verify_shreds_gpu(thread_pool, &batches, &pubkeys, &recycler_cache);
assert_eq!(rv, vec![vec![0; num_packets]; num_batches]);
@ -720,13 +680,7 @@ mod tests {
batches[0][0].buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload());
batches[0][0].meta_mut().size = shred.payload().len();
let pubkeys = [
(slot, keypair.pubkey().to_bytes()),
(std::u64::MAX, Pubkey::default().to_bytes()),
]
.iter()
.cloned()
.collect();
let pubkeys = HashMap::from([(slot, keypair.pubkey()), (std::u64::MAX, Pubkey::default())]);
//unsigned
let rv = verify_shreds_cpu(thread_pool, &batches, &pubkeys);
assert_eq!(rv, vec![vec![0]]);
@ -853,10 +807,10 @@ mod tests {
.take(3)
.collect();
let shreds = make_shreds(&mut rng, &keypairs);
let pubkeys: HashMap<Slot, [u8; 32]> = keypairs
let pubkeys: HashMap<Slot, Pubkey> = keypairs
.iter()
.map(|(&slot, keypair)| (slot, keypair.pubkey().to_bytes()))
.chain(once((Slot::MAX, Pubkey::default().to_bytes())))
.map(|(&slot, keypair)| (slot, keypair.pubkey()))
.chain(once((Slot::MAX, Pubkey::default())))
.collect();
let mut packets = make_packets(&mut rng, &shreds);
assert_eq!(
@ -901,13 +855,13 @@ mod tests {
make_shreds(&mut rng, &keypairs)
};
let keypair = Keypair::new();
let pubkeys: HashMap<Slot, [u8; 32]> = {
let pubkey = keypair.pubkey().to_bytes();
let pubkeys: HashMap<Slot, Pubkey> = {
let pubkey = keypair.pubkey();
shreds
.iter()
.map(Shred::slot)
.map(|slot| (slot, pubkey))
.chain(once((Slot::MAX, Pubkey::default().to_bytes())))
.chain(once((Slot::MAX, Pubkey::default())))
.collect()
};
let mut packets = make_packets(&mut rng, &shreds);

View File

@ -147,11 +147,11 @@ fn verify_packets(
packets: &mut [PacketBatch],
) {
let working_bank = bank_forks.read().unwrap().working_bank();
let leader_slots: HashMap<Slot, [u8; 32]> =
let leader_slots: HashMap<Slot, Pubkey> =
get_slot_leaders(self_pubkey, packets, leader_schedule_cache, &working_bank)
.into_iter()
.filter_map(|(slot, pubkey)| Some((slot, pubkey?.to_bytes())))
.chain(std::iter::once((Slot::MAX, [0u8; 32])))
.filter_map(|(slot, pubkey)| Some((slot, pubkey?)))
.chain(std::iter::once((Slot::MAX, Pubkey::default())))
.collect();
let out = verify_shreds_gpu(thread_pool, packets, &leader_slots, recycler_cache);
solana_perf::sigverify::mark_disabled(packets, &out);
@ -175,12 +175,9 @@ fn get_slot_leaders(
continue;
}
let shred = shred::layout::get_shred(packet);
let slot = match shred.and_then(shred::layout::get_slot) {
None => {
packet.meta_mut().set_discard(true);
continue;
}
Some(slot) => slot,
let Some(slot) = shred.and_then(shred::layout::get_slot) else {
packet.meta_mut().set_discard(true);
continue;
};
let leader = leaders.entry(slot).or_insert_with(|| {
let leader = leader_schedule_cache.slot_leader_at(slot, Some(bank))?;