From de612c25b3202cce0ceaf481fe6fce48e0aeb80d Mon Sep 17 00:00:00 2001 From: behzad nouri Date: Thu, 26 May 2022 13:06:27 +0000 Subject: [PATCH] removes shred wire layout specs from sigverify (#25520) sigverify_shreds relies on wire layout specs of shreds: https://github.com/solana-labs/solana/blob/0376ab41a/ledger/src/sigverify_shreds.rs#L39-L46 https://github.com/solana-labs/solana/blob/0376ab41a/ledger/src/sigverify_shreds.rs#L298-L305 In preparation of https://github.com/solana-labs/solana/pull/25237 which adds a new shred variant with different layout and signed message, this commit removes shred layout specification from sigverify and instead encapsulate that in shred module. --- core/src/repair_response.rs | 9 +- core/src/sigverify_shreds.rs | 7 +- ledger/src/blockstore.rs | 2 +- ledger/src/shred.rs | 171 ++++++++++++++----------- ledger/src/shredder.rs | 7 +- ledger/src/sigverify_shreds.rs | 220 ++++++++++++++------------------- 6 files changed, 203 insertions(+), 213 deletions(-) diff --git a/core/src/repair_response.rs b/core/src/repair_response.rs index 5d36831fb..600a58208 100644 --- a/core/src/repair_response.rs +++ b/core/src/repair_response.rs @@ -92,20 +92,17 @@ mod test { .iter() .cloned() .collect(); - let rv = verify_shred_cpu(&packet, &leader_slots); - assert_eq!(rv, Some(1)); + assert!(verify_shred_cpu(&packet, &leader_slots)); let wrong_keypair = Keypair::new(); let leader_slots = [(slot, wrong_keypair.pubkey().to_bytes())] .iter() .cloned() .collect(); - let rv = verify_shred_cpu(&packet, &leader_slots); - assert_eq!(rv, Some(0)); + assert!(!verify_shred_cpu(&packet, &leader_slots)); let leader_slots = HashMap::new(); - let rv = verify_shred_cpu(&packet, &leader_slots); - assert_eq!(rv, None); + assert!(!verify_shred_cpu(&packet, &leader_slots)); } #[test] diff --git a/core/src/sigverify_shreds.rs b/core/src/sigverify_shreds.rs index 261eec390..5600d5092 100644 --- a/core/src/sigverify_shreds.rs +++ b/core/src/sigverify_shreds.rs @@ -7,8 +7,7 @@ use { }, crossbeam_channel::Sender, solana_ledger::{ - leader_schedule_cache::LeaderScheduleCache, shred::Shred, - sigverify_shreds::verify_shreds_gpu, + leader_schedule_cache::LeaderScheduleCache, shred, sigverify_shreds::verify_shreds_gpu, }, solana_perf::{self, packet::PacketBatch, recycler_cache::RecyclerCache}, solana_runtime::bank_forks::BankForks, @@ -43,7 +42,9 @@ impl ShredSigVerifier { fn read_slots(batches: &[PacketBatch]) -> HashSet { batches .iter() - .flat_map(|batch| batch.iter().filter_map(Shred::get_slot_from_packet)) + .flat_map(PacketBatch::iter) + .map(shred::layout::get_shred) + .filter_map(shred::layout::get_slot) .collect() } } diff --git a/ledger/src/blockstore.rs b/ledger/src/blockstore.rs index 09fbfdce3..39489bb9a 100644 --- a/ledger/src/blockstore.rs +++ b/ledger/src/blockstore.rs @@ -1859,7 +1859,7 @@ impl Blockstore { let upper_index = cmp::min(current_index, end_index); // the tick that will be used to figure out the timeout for this hole let data = db_iterator.value().expect("couldn't read value"); - let reference_tick = u64::from(Shred::reference_tick_from_data(data).unwrap()); + let reference_tick = u64::from(shred::layout::get_reference_tick(data).unwrap()); if ticks_since_first_insert < reference_tick + MAX_TURBINE_DELAY_IN_TICKS { // The higher index holes have not timed out yet break 'outer; diff --git a/ledger/src/shred.rs b/ledger/src/shred.rs index 0bfecf3a5..481b1e00a 100644 --- a/ledger/src/shred.rs +++ b/ledger/src/shred.rs @@ -56,7 +56,7 @@ use { num_enum::{IntoPrimitive, TryFromPrimitive}, serde::{Deserialize, Serialize}, solana_entry::entry::{create_ticks, Entry}, - solana_perf::packet::Packet, + solana_perf::packet::{deserialize_from_with_limit, Packet}, solana_runtime::bank::Bank, solana_sdk::{ clock::Slot, @@ -317,7 +317,7 @@ impl Shred { } pub fn new_from_serialized_shred(shred: Vec) -> Result { - Ok(match Self::shred_type_from_payload(&shred)? { + Ok(match layout::get_shred_type(&shred)? { ShredType::Code => Self::from(ShredCode::from_payload(shred)?), ShredType::Data => Self::from(ShredData::from_payload(shred)?), }) @@ -383,7 +383,7 @@ impl Shred { // Possibly zero pads bytes stored in blockstore. pub(crate) fn resize_stored_shred(shred: Vec) -> Result, Error> { - match Self::shred_type_from_payload(&shred)? { + match layout::get_shred_type(&shred)? { ShredType::Code => ShredCode::resize_stored_shred(shred), ShredType::Data => ShredData::resize_stored_shred(shred), } @@ -441,16 +441,6 @@ impl Shred { self.common_header().shred_type } - fn shred_type_from_payload(shred: &[u8]) -> Result { - match shred.get(OFFSET_OF_SHRED_TYPE) { - None => Err(Error::InvalidPayloadSize(shred.len())), - Some(shred_type) => match ShredType::try_from(*shred_type) { - Err(_) => Err(Error::InvalidShredType), - Ok(shred_type) => Ok(shred_type), - }, - } - } - pub fn is_data(&self) -> bool { self.shred_type() == ShredType::Data } @@ -488,25 +478,6 @@ impl Shred { } } - // Get slot from a shred packet with partial deserialize - pub fn get_slot_from_packet(p: &Packet) -> Option { - let slot_start = OFFSET_OF_SHRED_SLOT; - let slot_end = slot_start + SIZE_OF_SHRED_SLOT; - p.deserialize_slice(slot_start..slot_end).ok() - } - - pub(crate) fn reference_tick_from_data(data: &[u8]) -> Result { - const SHRED_FLAGS_OFFSET: usize = SIZE_OF_COMMON_SHRED_HEADER + std::mem::size_of::(); - if Self::shred_type_from_payload(data)? != ShredType::Data { - return Err(Error::InvalidShredType); - } - let flags = match data.get(SHRED_FLAGS_OFFSET) { - None => return Err(Error::InvalidPayloadSize(data.len())), - Some(flags) => flags, - }; - Ok(flags & ShredFlags::SHRED_TICK_REFERENCE_MASK.bits()) - } - pub fn verify(&self, pubkey: &Pubkey) -> bool { let message = self.signed_payload(); self.signature().verify(pubkey.as_ref(), message) @@ -535,6 +506,73 @@ impl Shred { } } +// Helper methods to extract pieces of the shred from the payload +// without deserializing the entire payload. +pub mod layout { + use {super::*, std::ops::Range}; + + fn get_shred_size(packet: &Packet) -> usize { + if packet.meta.repair() { + packet.meta.size.saturating_sub(SIZE_OF_NONCE) + } else { + packet.meta.size + } + } + + pub fn get_shred(packet: &Packet) -> &[u8] { + &packet.data()[..get_shred_size(packet)] + } + + pub(crate) fn get_signature(shred: &[u8]) -> Option { + Some(Signature::new(shred.get(..SIZE_OF_SIGNATURE)?)) + } + + pub(crate) const fn get_signature_range() -> Range { + 0..SIZE_OF_SIGNATURE + } + + pub(super) fn get_shred_type(shred: &[u8]) -> Result { + match shred.get(OFFSET_OF_SHRED_TYPE) { + None => Err(Error::InvalidPayloadSize(shred.len())), + Some(shred_type) => match ShredType::try_from(*shred_type) { + Err(_) => Err(Error::InvalidShredType), + Ok(shred_type) => Ok(shred_type), + }, + } + } + + pub fn get_slot(shred: &[u8]) -> Option { + deserialize_from_with_limit(shred.get(OFFSET_OF_SHRED_SLOT..)?).ok() + } + + pub(super) fn get_index(shred: &[u8]) -> Option { + deserialize_from_with_limit(shred.get(OFFSET_OF_SHRED_INDEX..)?).ok() + } + + // Returns chunk of the payload which is signed. + pub(crate) fn get_signed_message(shred: &[u8]) -> Option<&[u8]> { + shred.get(SIZE_OF_SIGNATURE..) + } + + // Returns slice range of the packet payload which is signed. + pub(crate) fn get_signed_message_range(packet: &Packet) -> Range { + SIZE_OF_SIGNATURE..get_shred_size(packet) + } + + pub(crate) fn get_reference_tick(shred: &[u8]) -> Result { + const SIZE_OF_PARENT_OFFSET: usize = std::mem::size_of::(); + const OFFSET_OF_SHRED_FLAGS: usize = SIZE_OF_COMMON_SHRED_HEADER + SIZE_OF_PARENT_OFFSET; + if get_shred_type(shred)? != ShredType::Data { + return Err(Error::InvalidShredType); + } + let flags = match shred.get(OFFSET_OF_SHRED_FLAGS) { + None => return Err(Error::InvalidPayloadSize(shred.len())), + Some(flags) => flags, + }; + Ok(flags & ShredFlags::SHRED_TICK_REFERENCE_MASK.bits()) + } +} + impl From for Shred { fn from(shred: ShredCode) -> Self { Self::ShredCode(shred) @@ -549,50 +587,39 @@ impl From for Shred { // Get slot, index, and type from a packet with partial deserialize pub fn get_shred_slot_index_type( - p: &Packet, + packet: &Packet, stats: &mut ShredFetchStats, ) -> Option<(Slot, u32, ShredType)> { - let index_start = OFFSET_OF_SHRED_INDEX; - let index_end = index_start + SIZE_OF_SHRED_INDEX; - let slot_start = OFFSET_OF_SHRED_SLOT; - let slot_end = slot_start + SIZE_OF_SHRED_SLOT; - - debug_assert!(index_end > slot_end); - debug_assert!(index_end > OFFSET_OF_SHRED_TYPE); - - if index_end > p.meta.size { + let shred = layout::get_shred(packet); + if OFFSET_OF_SHRED_INDEX + SIZE_OF_SHRED_INDEX > shred.len() { stats.index_overrun += 1; return None; } - - let index = match p.deserialize_slice(index_start..index_end) { - Ok(x) => x, - Err(_e) => { - stats.index_bad_deserialize += 1; - return None; - } - }; - - if index >= MAX_DATA_SHREDS_PER_SLOT as u32 { - stats.index_out_of_bounds += 1; - return None; - } - - let slot = match p.deserialize_slice(slot_start..slot_end) { - Ok(x) => x, - Err(_e) => { - stats.slot_bad_deserialize += 1; - return None; - } - }; - - let shred_type = match ShredType::try_from(p.data()[OFFSET_OF_SHRED_TYPE]) { + let shred_type = match layout::get_shred_type(shred) { + Ok(shred_type) => shred_type, Err(_) => { stats.bad_shred_type += 1; return None; } - Ok(shred_type) => shred_type, }; + let slot = match layout::get_slot(shred) { + Some(slot) => slot, + None => { + stats.slot_bad_deserialize += 1; + return None; + } + }; + let index = match layout::get_index(shred) { + Some(index) => index, + None => { + stats.index_bad_deserialize += 1; + return None; + } + }; + if index >= MAX_DATA_SHREDS_PER_SLOT as u32 { + stats.index_out_of_bounds += 1; + return None; + } Some((slot, index, shred_type)) } @@ -924,9 +951,9 @@ mod tests { assert_eq!(shred, Shred::new_from_serialized_shred(payload).unwrap()); assert_eq!( shred.reference_tick(), - Shred::reference_tick_from_data(packet.data()).unwrap() + layout::get_reference_tick(packet.data()).unwrap() ); - assert_eq!(Shred::get_slot_from_packet(&packet), Some(shred.slot())); + assert_eq!(layout::get_slot(packet.data()), Some(shred.slot())); assert_eq!( get_shred_slot_index_type(&packet, &mut ShredFetchStats::default()), Some((shred.slot(), shred.index(), shred.shred_type())) @@ -965,9 +992,9 @@ mod tests { assert_eq!(shred, Shred::new_from_serialized_shred(payload).unwrap()); assert_eq!( shred.reference_tick(), - Shred::reference_tick_from_data(packet.data()).unwrap() + layout::get_reference_tick(packet.data()).unwrap() ); - assert_eq!(Shred::get_slot_from_packet(&packet), Some(shred.slot())); + assert_eq!(layout::get_slot(packet.data()), Some(shred.slot())); assert_eq!( get_shred_slot_index_type(&packet, &mut ShredFetchStats::default()), Some((shred.slot(), shred.index(), shred.shred_type())) @@ -1011,7 +1038,7 @@ mod tests { packet.meta.size = payload.len(); assert_eq!(shred.bytes_to_store(), payload); assert_eq!(shred, Shred::new_from_serialized_shred(payload).unwrap()); - assert_eq!(Shred::get_slot_from_packet(&packet), Some(shred.slot())); + assert_eq!(layout::get_slot(packet.data()), Some(shred.slot())); assert_eq!( get_shred_slot_index_type(&packet, &mut ShredFetchStats::default()), Some((shred.slot(), shred.index(), shred.shred_type())) @@ -1050,7 +1077,7 @@ mod tests { assert_eq!(shred.last_in_slot(), is_last_in_slot); assert_eq!(shred.reference_tick(), reference_tick.min(63u8)); assert_eq!( - Shred::reference_tick_from_data(shred.payload()).unwrap(), + layout::get_reference_tick(shred.payload()).unwrap(), reference_tick.min(63u8), ); } diff --git a/ledger/src/shredder.rs b/ledger/src/shredder.rs index 31dfc7212..ffa1e7cef 100644 --- a/ledger/src/shredder.rs +++ b/ledger/src/shredder.rs @@ -353,7 +353,8 @@ mod tests { use { super::*, crate::shred::{ - max_entries_per_n_shred, max_ticks_per_n_shreds, verify_test_data_shred, ShredType, + self, max_entries_per_n_shred, max_ticks_per_n_shreds, verify_test_data_shred, + ShredType, }, bincode::serialized_size, matches::assert_matches, @@ -519,7 +520,7 @@ mod tests { ); data_shreds.iter().for_each(|s| { assert_eq!(s.reference_tick(), 5); - assert_eq!(Shred::reference_tick_from_data(s.payload()).unwrap(), 5); + assert_eq!(shred::layout::get_reference_tick(s.payload()).unwrap(), 5); }); let deserialized_shred = @@ -555,7 +556,7 @@ mod tests { ShredFlags::SHRED_TICK_REFERENCE_MASK.bits() ); assert_eq!( - Shred::reference_tick_from_data(s.payload()).unwrap(), + shred::layout::get_reference_tick(s.payload()).unwrap(), ShredFlags::SHRED_TICK_REFERENCE_MASK.bits() ); }); diff --git a/ledger/src/sigverify_shreds.rs b/ledger/src/sigverify_shreds.rs index 4f769680f..f85a696c7 100644 --- a/ledger/src/sigverify_shreds.rs +++ b/ledger/src/sigverify_shreds.rs @@ -1,13 +1,8 @@ #![allow(clippy::implicit_hasher)] use { - crate::shred::{ShredType, SIZE_OF_NONCE}, - rayon::{ - iter::{ - IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, - ParallelIterator, - }, - ThreadPool, - }, + crate::shred, + itertools::Itertools, + rayon::{prelude::*, ThreadPool}, sha2::{Digest, Sha512}, solana_metrics::inc_new_counter_debug, solana_perf::{ @@ -23,58 +18,52 @@ use { pubkey::Pubkey, signature::{Keypair, Signature, Signer}, }, - std::{collections::HashMap, mem::size_of, sync::Arc}, + std::{collections::HashMap, fmt::Debug, iter::repeat, mem::size_of, ops::Range, sync::Arc}, }; -pub const SIGN_SHRED_GPU_MIN: usize = 256; +const SIGN_SHRED_GPU_MIN: usize = 256; lazy_static! { - pub static ref SIGVERIFY_THREAD_POOL: ThreadPool = rayon::ThreadPoolBuilder::new() + static ref SIGVERIFY_THREAD_POOL: ThreadPool = rayon::ThreadPoolBuilder::new() .num_threads(get_thread_count()) .thread_name(|ix| format!("sigverify_shreds_{}", ix)) .build() .unwrap(); } -/// Assuming layout is -/// signature: Signature -/// signed_msg: { -/// type: ShredType -/// slot: u64, -/// ... -/// } -/// Signature is the first thing in the packet, and slot is the first thing in the signed message. -pub fn verify_shred_cpu(packet: &Packet, slot_leaders: &HashMap) -> Option { - let sig_start = 0; - let sig_end = size_of::(); - let slot_start = sig_end + size_of::(); - let slot_end = slot_start + size_of::(); - let msg_start = sig_end; +pub fn verify_shred_cpu( + packet: &Packet, + slot_leaders: &HashMap, +) -> bool { if packet.meta.discard() { - return Some(0); + return false; } - trace!("slot start and end {} {}", slot_start, slot_end); - let slot: u64 = packet.deserialize_slice(slot_start..slot_end).ok()?; - let msg_end = if packet.meta.repair() { - packet.meta.size.saturating_sub(SIZE_OF_NONCE) - } else { - packet.meta.size + let shred = shred::layout::get_shred(packet); + let slot = match shred::layout::get_slot(shred) { + None => return false, + Some(slot) => slot, }; trace!("slot {}", slot); - let pubkey = slot_leaders.get(&slot)?; - let signature = Signature::new(packet.data().get(sig_start..sig_end)?); + let pubkey = match slot_leaders.get(&slot) { + None => return false, + Some(pubkey) => pubkey, + }; + let signature = match shred::layout::get_signature(shred) { + None => return false, + Some(signature) => signature, + }; trace!("signature {}", signature); - if !signature.verify(pubkey, packet.data().get(msg_start..msg_end)?) { - return Some(0); - } - Some(1) + let message = match shred::layout::get_signed_message(shred) { + None => return false, + Some(message) => message, + }; + signature.verify(pubkey, message) } fn verify_shreds_cpu( batches: &[PacketBatch], - slot_leaders: &HashMap, + slot_leaders: &HashMap, ) -> Vec> { - use rayon::prelude::*; let packet_count = count_packets_in_batches(batches); debug!("CPU SHRED ECDSA for {}", packet_count); let rv = SIGVERIFY_THREAD_POOL.install(|| { @@ -83,7 +72,7 @@ fn verify_shreds_cpu( .map(|batch| { batch .par_iter() - .map(|p| verify_shred_cpu(p, slot_leaders).unwrap_or(0)) + .map(|packet| u8::from(verify_shred_cpu(packet, slot_leaders))) .collect() }) .collect() @@ -92,73 +81,61 @@ fn verify_shreds_cpu( rv } -fn slot_key_data_for_gpu< - T: Sync + Sized + Default + std::fmt::Debug + Eq + std::hash::Hash + Clone + Copy + AsRef<[u8]>, ->( +fn slot_key_data_for_gpu( offset_start: usize, batches: &[PacketBatch], - slot_keys: &HashMap, + slot_keys: &HashMap, recycler_cache: &RecyclerCache, -) -> (PinnedVec, TxOffset, usize) { +) -> (PinnedVec, TxOffset, usize) +where + T: AsRef<[u8]> + Copy + Debug + Default + Eq + std::hash::Hash + Sync, +{ //TODO: mark Pubkey::default shreds as failed after the GPU returns - assert_eq!(slot_keys.get(&std::u64::MAX), Some(&T::default())); - let slots: Vec> = SIGVERIFY_THREAD_POOL.install(|| { + assert_eq!(slot_keys.get(&Slot::MAX), Some(&T::default())); + let slots: Vec = SIGVERIFY_THREAD_POOL.install(|| { batches .into_par_iter() - .map(|batch| { - batch - .iter() - .map(|packet| { - if packet.meta.discard() { - return Slot::MAX; - } - - let slot_start = size_of::() + size_of::(); - let slot_end = slot_start + size_of::(); - let slot: Option = - packet.deserialize_slice(slot_start..slot_end).ok(); - match slot { - Some(slot) if slot_keys.get(&slot).is_some() => slot, - _ => Slot::MAX, - } - }) - .collect() + .flat_map_iter(|batch| { + batch.iter().map(|packet| { + if packet.meta.discard() { + return Slot::MAX; + } + let shred = shred::layout::get_shred(packet); + match shred::layout::get_slot(shred) { + Some(slot) if slot_keys.contains_key(&slot) => slot, + _ => Slot::MAX, + } + }) }) .collect() }); - let mut keys_to_slots: HashMap> = HashMap::new(); - for batch in slots.iter() { - for slot in batch.iter() { - let key = slot_keys.get(slot).unwrap(); - keys_to_slots - .entry(*key) - .or_insert_with(Vec::new) - .push(*slot); - } - } + let keys_to_slots: HashMap> = slots + .iter() + .map(|slot| (*slot_keys.get(slot).unwrap(), *slot)) + .into_group_map(); let mut keyvec = recycler_cache.buffer().allocate("shred_gpu_pubkeys"); keyvec.set_pinnable(); - let mut slot_to_key_ix = HashMap::new(); let keyvec_size = keys_to_slots.len() * size_of::(); keyvec.resize(keyvec_size, 0); - for (i, (k, slots)) in keys_to_slots.iter().enumerate() { - let start = i * size_of::(); - let end = start + size_of::(); - keyvec[start..end].copy_from_slice(k.as_ref()); - for s in slots { - slot_to_key_ix.insert(s, i); - } - } + let slot_to_key_ix: HashMap = keys_to_slots + .into_iter() + .enumerate() + .flat_map(|(i, (k, slots))| { + let start = i * size_of::(); + let end = start + size_of::(); + keyvec[start..end].copy_from_slice(k.as_ref()); + slots.into_iter().zip(repeat(i)) + }) + .collect(); + let mut offsets = recycler_cache.offsets().allocate("shred_offsets"); offsets.set_pinnable(); - slots.iter().for_each(|packet_slots| { - packet_slots.iter().for_each(|slot| { - offsets - .push((offset_start + (slot_to_key_ix.get(slot).unwrap() * size_of::())) as u32); - }); - }); + for slot in slots { + let key_offset = slot_to_key_ix.get(&slot).unwrap() * size_of::(); + offsets.push((offset_start + key_offset) as u32); + } let num_in_packets = resize_vec(&mut keyvec); trace!("keyvec.len: {}", keyvec.len()); trace!("keyvec: {:?}", keyvec); @@ -184,6 +161,9 @@ fn shred_gpu_offsets( batches: &[PacketBatch], recycler_cache: &RecyclerCache, ) -> (TxOffset, TxOffset, TxOffset, Vec>) { + fn add_offset(range: Range, offset: usize) -> Range { + range.start + offset..range.end + offset + } let mut signature_offsets = recycler_cache.offsets().allocate("shred_signatures"); signature_offsets.set_pinnable(); let mut msg_start_offsets = recycler_cache.offsets().allocate("shred_msg_starts"); @@ -194,21 +174,14 @@ fn shred_gpu_offsets( for batch in batches.iter() { let mut sig_lens = Vec::new(); for packet in batch.iter() { - let sig_start = pubkeys_end; - let sig_end = sig_start + size_of::(); - let msg_start = sig_end; - let msg_end = if packet.meta.repair() { - sig_start + packet.meta.size.saturating_sub(SIZE_OF_NONCE) - } else { - sig_start + packet.meta.size - }; - signature_offsets.push(sig_start as u32); - msg_start_offsets.push(msg_start as u32); - let msg_size = if msg_end < msg_start { - 0 - } else { - msg_end - msg_start - }; + let sig = shred::layout::get_signature_range(); + let sig = add_offset(sig, pubkeys_end); + debug_assert_eq!(sig.end - sig.start, std::mem::size_of::()); + let msg = shred::layout::get_signed_message_range(packet); + let msg = add_offset(msg, pubkeys_end); + signature_offsets.push(sig.start as u32); + msg_start_offsets.push(msg.start as u32); + let msg_size = msg.end.saturating_sub(msg.start); msg_sizes.push(msg_size as u32); sig_lens.push(1); pubkeys_end += size_of::(); @@ -220,7 +193,7 @@ fn shred_gpu_offsets( pub fn verify_shreds_gpu( batches: &[PacketBatch], - slot_leaders: &HashMap, + slot_leaders: &HashMap, recycler_cache: &RecyclerCache, ) -> Vec> { let api = perf_libs::api(); @@ -292,25 +265,19 @@ pub fn verify_shreds_gpu( rvs } -/// Assuming layout is -/// signature: Signature -/// signed_msg: { -/// type: ShredType -/// slot: u64, -/// ... -/// } -/// Signature is the first thing in the packet, and slot is the first thing in the signed message. fn sign_shred_cpu(keypair: &Keypair, packet: &mut Packet) { - let sig_start = 0; - let sig_end = sig_start + size_of::(); - let msg_start = sig_end; - let signature = keypair.sign_message(&packet.data()[msg_start..]); + let sig = shred::layout::get_signature_range(); + let msg = shred::layout::get_signed_message_range(packet); + assert!( + packet.meta.size >= sig.end, + "packet is not large enough for a signature" + ); + let signature = keypair.sign_message(&packet.data()[msg]); trace!("signature {:?}", signature); - packet.buffer_mut()[..sig_end].copy_from_slice(signature.as_ref()); + packet.buffer_mut()[sig].copy_from_slice(signature.as_ref()); } pub fn sign_shreds_cpu(keypair: &Keypair, batches: &mut [PacketBatch]) { - use rayon::prelude::*; let packet_count = count_packets_in_batches(batches); debug!("CPU SHRED ECDSA for {}", packet_count); SIGVERIFY_THREAD_POOL.install(|| { @@ -444,7 +411,7 @@ pub fn sign_shreds_gpu( } #[cfg(test)] -pub mod tests { +mod tests { use { super::*, crate::shred::{Shred, ShredFlags, SIZE_OF_DATA_SHRED_PAYLOAD}, @@ -475,20 +442,17 @@ pub mod tests { .iter() .cloned() .collect(); - let rv = verify_shred_cpu(&packet, &leader_slots); - assert_eq!(rv, Some(1)); + assert!(verify_shred_cpu(&packet, &leader_slots)); let wrong_keypair = Keypair::new(); let leader_slots = [(slot, wrong_keypair.pubkey().to_bytes())] .iter() .cloned() .collect(); - let rv = verify_shred_cpu(&packet, &leader_slots); - assert_eq!(rv, Some(0)); + assert!(!verify_shred_cpu(&packet, &leader_slots)); let leader_slots = HashMap::new(); - let rv = verify_shred_cpu(&packet, &leader_slots); - assert_eq!(rv, None); + assert!(!verify_shred_cpu(&packet, &leader_slots)); } #[test]