#[cfg(test)]
use {crate::shred::ShredType, solana_sdk::pubkey::Pubkey};
use {
    crate::{
        shred::{
            common::impl_shred_common,
            dispatch, shred_code, shred_data,
            traits::{
                Shred as ShredTrait, ShredCode as ShredCodeTrait, ShredData as ShredDataTrait,
            },
            CodingShredHeader, DataShredHeader, Error, ShredCommonHeader, ShredVariant,
            SIZE_OF_CODING_SHRED_HEADERS, SIZE_OF_DATA_SHRED_HEADERS, SIZE_OF_SIGNATURE,
        },
        shredder::ReedSolomon,
    },
    reed_solomon_erasure::Error::{InvalidIndex, TooFewParityShards, TooFewShards},
    solana_perf::packet::deserialize_from_with_limit,
    solana_sdk::{
        clock::Slot,
        hash::{hashv, Hash},
        signature::Signature,
    },
    static_assertions::const_assert_eq,
    std::{
        io::{Cursor, Seek, SeekFrom},
        iter::repeat_with,
        ops::Range,
    },
};

const_assert_eq!(SIZE_OF_MERKLE_ROOT, 20);
const SIZE_OF_MERKLE_ROOT: usize = std::mem::size_of::<MerkleRoot>();
const_assert_eq!(SIZE_OF_MERKLE_PROOF_ENTRY, 20);
const SIZE_OF_MERKLE_PROOF_ENTRY: usize = std::mem::size_of::<MerkleProofEntry>();
const_assert_eq!(ShredData::SIZE_OF_PAYLOAD, 1203);

// Defense against second preimage attack:
// https://en.wikipedia.org/wiki/Merkle_tree#Second_preimage_attack
const MERKLE_HASH_PREFIX_LEAF: &[u8] = &[0x00];
const MERKLE_HASH_PREFIX_NODE: &[u8] = &[0x01];

type MerkleRoot = MerkleProofEntry;
type MerkleProofEntry = [u8; 20];

// Layout: {common, data} headers | data buffer | merkle branch
// The slice past the signature and before the merkle branch is erasure coded.
// The same slice is hashed to generate the merkle tree.
// The root of the merkle tree is signed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ShredData {
    common_header: ShredCommonHeader,
    data_header: DataShredHeader,
    merkle_branch: MerkleBranch,
    payload: Vec<u8>,
}

// Layout: {common, coding} headers | erasure coded shard | merkle branch
// The slice past the signature and before the merkle branch is hashed to
// generate the merkle tree. The root of the merkle tree is signed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ShredCode {
    common_header: ShredCommonHeader,
    coding_header: CodingShredHeader,
    merkle_branch: MerkleBranch,
    payload: Vec<u8>,
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub(super) enum Shred {
    ShredCode(ShredCode),
    ShredData(ShredData),
}

#[derive(Clone, Debug, Eq, PartialEq)]
struct MerkleBranch {
    root: MerkleRoot,
    proof: Vec<MerkleProofEntry>,
}

impl Shred {
    dispatch!(fn common_header(&self) -> &ShredCommonHeader);
    dispatch!(fn erasure_shard_as_slice(&self) -> Result<&[u8], Error>);
    dispatch!(fn erasure_shard_index(&self) -> Result<usize, Error>);
    dispatch!(fn merkle_tree_node(&self) -> Result<Hash, Error>);
    dispatch!(fn sanitize(&self) -> Result<(), Error>);
    dispatch!(fn set_merkle_branch(&mut self, merkle_branch: MerkleBranch) -> Result<(), Error>);

    fn merkle_root(&self) -> &MerkleRoot {
        match self {
            Self::ShredCode(shred) => &shred.merkle_branch.root,
            Self::ShredData(shred) => &shred.merkle_branch.root,
        }
    }
}

#[cfg(test)]
impl Shred {
    dispatch!(fn set_signature(&mut self, signature: Signature));
    dispatch!(fn signed_message(&self) -> &[u8]);

    fn index(&self) -> u32 {
        self.common_header().index
    }

    fn shred_type(&self) -> ShredType {
        ShredType::from(self.common_header().shred_variant)
    }

    fn signature(&self) -> Signature {
        self.common_header().signature
    }

    #[must_use]
    fn verify(&self, pubkey: &Pubkey) -> bool {
        let message = self.signed_message();
        self.signature().verify(pubkey.as_ref(), message)
    }
}

impl ShredData {
    // proof_size is the number of proof entries in the merkle tree branch.
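    // For an erasure batch of N shreds the merkle tree has N leaves and each
    // branch carries ceil(log2(N)) proof entries, e.g. 64 shreds -> 6 entries,
    // 73 shreds -> 7 entries (the same arithmetic is used by
    // run_recover_merkle_shreds in the tests below); the figures here are
    // only an illustration.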
    fn proof_size(&self) -> Result<u8, Error> {
        match self.common_header.shred_variant {
            ShredVariant::MerkleData(proof_size) => Ok(proof_size),
            _ => Err(Error::InvalidShredVariant),
        }
    }

    // Maximum size of ledger data that can be embedded in a data-shred.
    // Also equal to:
    //   ShredCode::capacity(proof_size).unwrap()
    //       - ShredData::SIZE_OF_HEADERS
    //       + SIZE_OF_SIGNATURE
    pub(super) fn capacity(proof_size: u8) -> Result<usize, Error> {
        Self::SIZE_OF_PAYLOAD
            .checked_sub(
                Self::SIZE_OF_HEADERS
                    + SIZE_OF_MERKLE_ROOT
                    + usize::from(proof_size) * SIZE_OF_MERKLE_PROOF_ENTRY,
            )
            .ok_or(Error::InvalidProofSize(proof_size))
    }

    pub(super) fn get_signed_message_range(proof_size: u8) -> Option<Range<usize>> {
        let data_buffer_size = Self::capacity(proof_size).ok()?;
        let offset = Self::SIZE_OF_HEADERS + data_buffer_size;
        Some(offset..offset + SIZE_OF_MERKLE_ROOT)
    }

    fn merkle_tree_node(&self) -> Result<Hash, Error> {
        let chunk = self.erasure_shard_as_slice()?;
        Ok(hashv(&[MERKLE_HASH_PREFIX_LEAF, chunk]))
    }

    fn verify_merkle_proof(&self) -> Result<bool, Error> {
        let node = self.merkle_tree_node()?;
        let index = self.erasure_shard_index()?;
        Ok(verify_merkle_proof(index, node, &self.merkle_branch))
    }

    fn from_recovered_shard(signature: &Signature, mut shard: Vec<u8>) -> Result<Self, Error> {
        let shard_size = shard.len();
        if shard_size + SIZE_OF_SIGNATURE > Self::SIZE_OF_PAYLOAD {
            return Err(Error::InvalidShardSize(shard_size));
        }
        shard.resize(Self::SIZE_OF_PAYLOAD, 0u8);
        shard.copy_within(0..shard_size, SIZE_OF_SIGNATURE);
        shard[0..SIZE_OF_SIGNATURE].copy_from_slice(signature.as_ref());
        // Deserialize headers.
        let mut cursor = Cursor::new(&shard[..]);
        let common_header: ShredCommonHeader = deserialize_from_with_limit(&mut cursor)?;
        let proof_size = match common_header.shred_variant {
            ShredVariant::MerkleData(proof_size) => proof_size,
            _ => return Err(Error::InvalidShredVariant),
        };
        if ShredCode::capacity(proof_size)? != shard_size {
            return Err(Error::InvalidShardSize(shard_size));
        }
        let data_header = deserialize_from_with_limit(&mut cursor)?;
        Ok(Self {
            common_header,
            data_header,
            merkle_branch: MerkleBranch::new_zeroed(proof_size),
            payload: shard,
        })
    }

    fn set_merkle_branch(&mut self, merkle_branch: MerkleBranch) -> Result<(), Error> {
        let proof_size = self.proof_size()?;
        if merkle_branch.proof.len() != usize::from(proof_size) {
            return Err(Error::InvalidMerkleProof);
        }
        let offset = Self::SIZE_OF_HEADERS + Self::capacity(proof_size)?;
        let mut cursor = Cursor::new(
            self.payload
                .get_mut(offset..)
                .ok_or(Error::InvalidProofSize(proof_size))?,
        );
        bincode::serialize_into(&mut cursor, &merkle_branch.root)?;
        for entry in &merkle_branch.proof {
            bincode::serialize_into(&mut cursor, entry)?;
        }
        self.merkle_branch = merkle_branch;
        Ok(())
    }
}

impl ShredCode {
    // proof_size is the number of proof entries in the merkle tree branch.
    fn proof_size(&self) -> Result<u8, Error> {
        match self.common_header.shred_variant {
            ShredVariant::MerkleCode(proof_size) => Ok(proof_size),
            _ => Err(Error::InvalidShredVariant),
        }
    }

    // Size of the buffer embedding erasure codes.
    fn capacity(proof_size: u8) -> Result<usize, Error> {
        // Merkle branch is generated and signed after coding shreds are
        // generated. Coding shred headers cannot be erasure coded either.
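        // Sketch of the arithmetic, assuming SIZE_OF_PAYLOAD = 1228 and
        // SIZE_OF_HEADERS = 89 (not restated in this file, but consistent with
        // the const_assert of ShredData::SIZE_OF_PAYLOAD == 1203 above):
        // with proof_size = 6, the erasure coded shard gets
        // 1228 - (89 + 20 + 6 * 20) = 999 bytes.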
        Self::SIZE_OF_PAYLOAD
            .checked_sub(
                Self::SIZE_OF_HEADERS
                    + SIZE_OF_MERKLE_ROOT
                    + SIZE_OF_MERKLE_PROOF_ENTRY * usize::from(proof_size),
            )
            .ok_or(Error::InvalidProofSize(proof_size))
    }

    fn merkle_tree_node(&self) -> Result<Hash, Error> {
        let proof_size = self.proof_size()?;
        let shard_size = Self::capacity(proof_size)?;
        let chunk = self
            .payload
            .get(SIZE_OF_SIGNATURE..Self::SIZE_OF_HEADERS + shard_size)
            .ok_or(Error::InvalidPayloadSize(self.payload.len()))?;
        Ok(hashv(&[MERKLE_HASH_PREFIX_LEAF, chunk]))
    }

    fn verify_merkle_proof(&self) -> Result<bool, Error> {
        let node = self.merkle_tree_node()?;
        let index = self.erasure_shard_index()?;
        Ok(verify_merkle_proof(index, node, &self.merkle_branch))
    }

    pub(super) fn get_signed_message_range(proof_size: u8) -> Option<Range<usize>> {
        let offset = Self::SIZE_OF_HEADERS + Self::capacity(proof_size).ok()?;
        Some(offset..offset + SIZE_OF_MERKLE_ROOT)
    }

    pub(super) fn erasure_mismatch(&self, other: &ShredCode) -> bool {
        shred_code::erasure_mismatch(self, other)
            || self.merkle_branch.root != other.merkle_branch.root
            || self.common_header.signature != other.common_header.signature
    }

    fn from_recovered_shard(
        common_header: ShredCommonHeader,
        coding_header: CodingShredHeader,
        mut shard: Vec<u8>,
    ) -> Result<Self, Error> {
        let proof_size = match common_header.shred_variant {
            ShredVariant::MerkleCode(proof_size) => proof_size,
            _ => return Err(Error::InvalidShredVariant),
        };
        let shard_size = shard.len();
        if Self::capacity(proof_size)? != shard_size {
            return Err(Error::InvalidShardSize(shard_size));
        }
        if shard_size + Self::SIZE_OF_HEADERS > Self::SIZE_OF_PAYLOAD {
            return Err(Error::InvalidShardSize(shard_size));
        }
        shard.resize(Self::SIZE_OF_PAYLOAD, 0u8);
        shard.copy_within(0..shard_size, Self::SIZE_OF_HEADERS);
        let mut cursor = Cursor::new(&mut shard[..]);
        bincode::serialize_into(&mut cursor, &common_header)?;
        bincode::serialize_into(&mut cursor, &coding_header)?;
        Ok(Self {
            common_header,
            coding_header,
            merkle_branch: MerkleBranch::new_zeroed(proof_size),
            payload: shard,
        })
    }

    fn set_merkle_branch(&mut self, merkle_branch: MerkleBranch) -> Result<(), Error> {
        let proof_size = self.proof_size()?;
        if merkle_branch.proof.len() != usize::from(proof_size) {
            return Err(Error::InvalidMerkleProof);
        }
        let offset = Self::SIZE_OF_HEADERS + Self::capacity(proof_size)?;
        let mut cursor = Cursor::new(
            self.payload
                .get_mut(offset..)
                .ok_or(Error::InvalidProofSize(proof_size))?,
        );
        bincode::serialize_into(&mut cursor, &merkle_branch.root)?;
        for entry in &merkle_branch.proof {
            bincode::serialize_into(&mut cursor, entry)?;
        }
        self.merkle_branch = merkle_branch;
        Ok(())
    }
}

impl MerkleBranch {
    fn new_zeroed(proof_size: u8) -> Self {
        Self {
            root: MerkleRoot::default(),
            proof: vec![MerkleProofEntry::default(); usize::from(proof_size)],
        }
    }
}

impl ShredTrait for ShredData {
    impl_shred_common!();

    // Also equal to:
    //   ShredData::SIZE_OF_HEADERS
    //       + ShredData::capacity(proof_size).unwrap()
    //       + SIZE_OF_MERKLE_ROOT
    //       + usize::from(proof_size) * SIZE_OF_MERKLE_PROOF_ENTRY
    const SIZE_OF_PAYLOAD: usize =
        ShredCode::SIZE_OF_PAYLOAD - ShredCode::SIZE_OF_HEADERS + SIZE_OF_SIGNATURE;
    const SIZE_OF_HEADERS: usize = SIZE_OF_DATA_SHRED_HEADERS;

    fn from_payload(mut payload: Vec<u8>) -> Result<Self, Error> {
        if payload.len() < Self::SIZE_OF_PAYLOAD {
            return Err(Error::InvalidPayloadSize(payload.len()));
        }
        payload.truncate(Self::SIZE_OF_PAYLOAD);
        let mut cursor = Cursor::new(&payload[..]);
        let common_header: ShredCommonHeader = deserialize_from_with_limit(&mut cursor)?;
        let proof_size = match common_header.shred_variant {
            ShredVariant::MerkleData(proof_size) => proof_size,
            _ => return Err(Error::InvalidShredVariant),
        };
        let data_header = deserialize_from_with_limit(&mut cursor)?;
        // Skip data buffer.
        let data_buffer_size = Self::capacity(proof_size)?;
        let data_buffer_size = i64::try_from(data_buffer_size).unwrap();
        cursor.seek(SeekFrom::Current(data_buffer_size))?;
        // Deserialize merkle branch.
        let root = deserialize_from_with_limit(&mut cursor)?;
        let proof = repeat_with(|| deserialize_from_with_limit(&mut cursor))
            .take(usize::from(proof_size))
            .collect::<Result<_, _>>()?;
        let merkle_branch = MerkleBranch { root, proof };
        let shred = Self {
            common_header,
            data_header,
            merkle_branch,
            payload,
        };
        shred.sanitize().map(|_| shred)
    }

    fn erasure_shard_index(&self) -> Result<usize, Error> {
        shred_data::erasure_shard_index(self).ok_or_else(|| {
            let headers = Box::new((self.common_header, self.data_header));
            Error::InvalidErasureShardIndex(headers)
        })
    }

    fn erasure_shard(self) -> Result<Vec<u8>, Error> {
        if self.payload.len() != Self::SIZE_OF_PAYLOAD {
            return Err(Error::InvalidPayloadSize(self.payload.len()));
        }
        let proof_size = self.proof_size()?;
        let data_buffer_size = Self::capacity(proof_size)?;
        let mut shard = self.payload;
        shard.truncate(Self::SIZE_OF_HEADERS + data_buffer_size);
        shard.drain(0..SIZE_OF_SIGNATURE);
        Ok(shard)
    }

    fn erasure_shard_as_slice(&self) -> Result<&[u8], Error> {
        if self.payload.len() != Self::SIZE_OF_PAYLOAD {
            return Err(Error::InvalidPayloadSize(self.payload.len()));
        }
        let proof_size = self.proof_size()?;
        let data_buffer_size = Self::capacity(proof_size)?;
        self.payload
            .get(SIZE_OF_SIGNATURE..Self::SIZE_OF_HEADERS + data_buffer_size)
            .ok_or(Error::InvalidPayloadSize(self.payload.len()))
    }

    fn sanitize(&self) -> Result<(), Error> {
        match self.common_header.shred_variant {
            ShredVariant::MerkleData(proof_size) => {
                if self.merkle_branch.proof.len() != usize::from(proof_size) {
                    return Err(Error::InvalidProofSize(proof_size));
                }
            }
            _ => return Err(Error::InvalidShredVariant),
        }
        if !self.verify_merkle_proof()?
        {
            return Err(Error::InvalidMerkleProof);
        }
        shred_data::sanitize(self)
    }

    fn signed_message(&self) -> &[u8] {
        self.merkle_branch.root.as_ref()
    }
}

impl ShredTrait for ShredCode {
    impl_shred_common!();
    const SIZE_OF_PAYLOAD: usize = shred_code::ShredCode::SIZE_OF_PAYLOAD;
    const SIZE_OF_HEADERS: usize = SIZE_OF_CODING_SHRED_HEADERS;

    fn from_payload(mut payload: Vec<u8>) -> Result<Self, Error> {
        let mut cursor = Cursor::new(&payload[..]);
        let common_header: ShredCommonHeader = deserialize_from_with_limit(&mut cursor)?;
        let proof_size = match common_header.shred_variant {
            ShredVariant::MerkleCode(proof_size) => proof_size,
            _ => return Err(Error::InvalidShredVariant),
        };
        let coding_header = deserialize_from_with_limit(&mut cursor)?;
        // Skip erasure code shard.
        let shard_size = Self::capacity(proof_size)?;
        let shard_size = i64::try_from(shard_size).unwrap();
        cursor.seek(SeekFrom::Current(shard_size))?;
        // Deserialize merkle branch.
        let root = deserialize_from_with_limit(&mut cursor)?;
        let proof = repeat_with(|| deserialize_from_with_limit(&mut cursor))
            .take(usize::from(proof_size))
            .collect::<Result<_, _>>()?;
        let merkle_branch = MerkleBranch { root, proof };
        // see: https://github.com/solana-labs/solana/pull/10109
        payload.truncate(Self::SIZE_OF_PAYLOAD);
        let shred = Self {
            common_header,
            coding_header,
            merkle_branch,
            payload,
        };
        shred.sanitize().map(|_| shred)
    }

    fn erasure_shard_index(&self) -> Result<usize, Error> {
        shred_code::erasure_shard_index(self).ok_or_else(|| {
            let headers = Box::new((self.common_header, self.coding_header));
            Error::InvalidErasureShardIndex(headers)
        })
    }

    fn erasure_shard(self) -> Result<Vec<u8>, Error> {
        if self.payload.len() != Self::SIZE_OF_PAYLOAD {
            return Err(Error::InvalidPayloadSize(self.payload.len()));
        }
        let proof_size = self.proof_size()?;
        let shard_size = Self::capacity(proof_size)?;
        let mut shard = self.payload;
        shard.drain(..Self::SIZE_OF_HEADERS);
        shard.truncate(shard_size);
        Ok(shard)
    }

    fn erasure_shard_as_slice(&self) -> Result<&[u8], Error> {
        if self.payload.len() != Self::SIZE_OF_PAYLOAD {
            return Err(Error::InvalidPayloadSize(self.payload.len()));
        }
        let proof_size = self.proof_size()?;
        let shard_size = Self::capacity(proof_size)?;
        self.payload
            .get(Self::SIZE_OF_HEADERS..Self::SIZE_OF_HEADERS + shard_size)
            .ok_or(Error::InvalidPayloadSize(self.payload.len()))
    }

    fn sanitize(&self) -> Result<(), Error> {
        match self.common_header.shred_variant {
            ShredVariant::MerkleCode(proof_size) => {
                if self.merkle_branch.proof.len() != usize::from(proof_size) {
                    return Err(Error::InvalidProofSize(proof_size));
                }
            }
            _ => return Err(Error::InvalidShredVariant),
        }
        if !self.verify_merkle_proof()? {
            return Err(Error::InvalidMerkleProof);
        }
        shred_code::sanitize(self)
    }

    fn signed_message(&self) -> &[u8] {
        self.merkle_branch.root.as_ref()
    }
}

impl ShredDataTrait for ShredData {
    #[inline]
    fn data_header(&self) -> &DataShredHeader {
        &self.data_header
    }

    fn data(&self) -> Result<&[u8], Error> {
        let proof_size = self.proof_size()?;
        let data_buffer_size = Self::capacity(proof_size)?;
        let size = usize::from(self.data_header.size);
        if size > self.payload.len()
            || size < Self::SIZE_OF_HEADERS
            || size > Self::SIZE_OF_HEADERS + data_buffer_size
        {
            return Err(Error::InvalidDataSize {
                size: self.data_header.size,
                payload: self.payload.len(),
            });
        }
        Ok(&self.payload[Self::SIZE_OF_HEADERS..size])
    }
}

impl ShredCodeTrait for ShredCode {
    #[inline]
    fn coding_header(&self) -> &CodingShredHeader {
        &self.coding_header
    }
}

// Obtains the parent's hash by joining two sibling nodes in the merkle tree.
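// Both children are truncated to SIZE_OF_MERKLE_PROOF_ENTRY bytes and hashed
// under the 0x01 node prefix, while leaves (merkle_tree_node above) hash the
// full erasure coded slice under the 0x00 leaf prefix; the distinct prefixes
// are the defense against the second preimage attack referenced at the top of
// the file.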
fn join_nodes<S: AsRef<[u8]>, T: AsRef<[u8]>>(node: S, other: T) -> Hash {
    let node = &node.as_ref()[..SIZE_OF_MERKLE_PROOF_ENTRY];
    let other = &other.as_ref()[..SIZE_OF_MERKLE_PROOF_ENTRY];
    hashv(&[MERKLE_HASH_PREFIX_NODE, node, other])
}

fn verify_merkle_proof(index: usize, node: Hash, merkle_branch: &MerkleBranch) -> bool {
    let proof = merkle_branch.proof.iter();
    let (index, root) = proof.fold((index, node), |(index, node), other| {
        let parent = if index % 2 == 0 {
            join_nodes(node, other)
        } else {
            join_nodes(other, node)
        };
        (index >> 1, parent)
    });
    let root = &root.as_ref()[..SIZE_OF_MERKLE_ROOT];
    (index, root) == (0usize, &merkle_branch.root[..])
}

fn make_merkle_tree(mut nodes: Vec<Hash>) -> Vec<Hash> {
    let mut size = nodes.len();
    while size > 1 {
        let offset = nodes.len() - size;
        for index in (offset..offset + size).step_by(2) {
            let node = &nodes[index];
            let other = &nodes[(index + 1).min(offset + size - 1)];
            let parent = join_nodes(node, other);
            nodes.push(parent);
        }
        size = nodes.len() - offset - size;
    }
    nodes
}

fn make_merkle_branch(
    mut index: usize, // leaf index ~ shred's erasure shard index.
    mut size: usize,  // number of leaves ~ erasure batch size.
    tree: &[Hash],
) -> Option<MerkleBranch> {
    if index >= size {
        return None;
    }
    let mut offset = 0;
    let mut proof = Vec::<MerkleProofEntry>::new();
    while size > 1 {
        let node = tree.get(offset + (index ^ 1).min(size - 1))?;
        let entry = &node.as_ref()[..SIZE_OF_MERKLE_PROOF_ENTRY];
        proof.push(MerkleProofEntry::try_from(entry).unwrap());
        offset += size;
        size = (size + 1) >> 1;
        index >>= 1;
    }
    if offset + 1 != tree.len() {
        return None;
    }
    let root = &tree.last()?.as_ref()[..SIZE_OF_MERKLE_ROOT];
    let root = MerkleRoot::try_from(root).unwrap();
    Some(MerkleBranch { root, proof })
}

pub(super) fn recover(mut shreds: Vec<Shred>) -> Result<Vec<Shred>, Error> {
    // Grab {common, coding} headers from the first coding shred.
    let headers = shreds.iter().find_map(|shred| {
        let shred = match shred {
            Shred::ShredCode(shred) => shred,
            Shred::ShredData(_) => return None,
        };
        let position = u32::from(shred.coding_header.position);
        let common_header = ShredCommonHeader {
            index: shred.common_header.index.checked_sub(position)?,
            ..shred.common_header
        };
        let coding_header = CodingShredHeader {
            position: 0u16,
            ..shred.coding_header
        };
        Some((common_header, coding_header))
    });
    let (common_header, coding_header) = headers.ok_or(TooFewParityShards)?;
    debug_assert!(matches!(
        common_header.shred_variant,
        ShredVariant::MerkleCode(_)
    ));
    let proof_size = match common_header.shred_variant {
        ShredVariant::MerkleCode(proof_size) => proof_size,
        ShredVariant::MerkleData(_) | ShredVariant::LegacyCode | ShredVariant::LegacyData => {
            return Err(Error::InvalidShredVariant);
        }
    };
    // Verify that shreds belong to the same erasure batch
    // and have consistent headers.
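    // This is a debug-only assertion: every shred must agree with the first
    // coding shred on signature, slot, version, fec_set_index and the
    // MerkleCode/MerkleData proof_size, and coding shreds must also agree on
    // the erasure configuration (num_data_shreds, num_coding_shreds).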
    debug_assert!(shreds.iter().all(|shred| {
        let ShredCommonHeader {
            signature,
            shred_variant,
            slot,
            index: _,
            version,
            fec_set_index,
        } = shred.common_header();
        signature == &common_header.signature
            && slot == &common_header.slot
            && version == &common_header.version
            && fec_set_index == &common_header.fec_set_index
            && match shred {
                Shred::ShredData(_) => shred_variant == &ShredVariant::MerkleData(proof_size),
                Shred::ShredCode(shred) => {
                    let CodingShredHeader {
                        num_data_shreds,
                        num_coding_shreds,
                        position: _,
                    } = shred.coding_header;
                    shred_variant == &ShredVariant::MerkleCode(proof_size)
                        && num_data_shreds == coding_header.num_data_shreds
                        && num_coding_shreds == coding_header.num_coding_shreds
                }
            }
    }));
    let num_data_shreds = usize::from(coding_header.num_data_shreds);
    let num_coding_shreds = usize::from(coding_header.num_coding_shreds);
    let num_shards = num_data_shreds + num_coding_shreds;
    // Obtain erasure encoded shards from shreds.
    let shreds = {
        let mut batch = vec![None; num_shards];
        while let Some(shred) = shreds.pop() {
            let index = match shred.erasure_shard_index() {
                Ok(index) if index < batch.len() => index,
                _ => return Err(Error::from(InvalidIndex)),
            };
            batch[index] = Some(shred);
        }
        batch
    };
    let mut shards: Vec<Option<Vec<u8>>> = shreds
        .iter()
        .map(|shred| Some(shred.as_ref()?.erasure_shard_as_slice().ok()?.to_vec()))
        .collect();
    ReedSolomon::new(num_data_shreds, num_coding_shreds)?.reconstruct(&mut shards)?;
    let mask: Vec<_> = shreds.iter().map(Option::is_some).collect();
    // Reconstruct code and data shreds from erasure encoded shards.
    let mut shreds: Vec<_> = shreds
        .into_iter()
        .zip(shards)
        .enumerate()
        .map(|(index, (shred, shard))| {
            if let Some(shred) = shred {
                return Ok(shred);
            }
            let shard = shard.ok_or(TooFewShards)?;
            if index < num_data_shreds {
                let shred = ShredData::from_recovered_shard(&common_header.signature, shard)?;
                let ShredCommonHeader {
                    signature: _,
                    shred_variant,
                    slot,
                    index: _,
                    version,
                    fec_set_index,
                } = shred.common_header;
                if shred_variant != ShredVariant::MerkleData(proof_size)
                    || common_header.slot != slot
                    || common_header.version != version
                    || common_header.fec_set_index != fec_set_index
                {
                    return Err(Error::InvalidRecoveredShred);
                }
                Ok(Shred::ShredData(shred))
            } else {
                let offset = index - num_data_shreds;
                let coding_header = CodingShredHeader {
                    position: offset as u16,
                    ..coding_header
                };
                let common_header = ShredCommonHeader {
                    index: common_header.index + offset as u32,
                    ..common_header
                };
                let shred = ShredCode::from_recovered_shard(common_header, coding_header, shard)?;
                Ok(Shred::ShredCode(shred))
            }
        })
        .collect::<Result<_, Error>>()?;
    // Compute merkle tree and set the merkle branch on the recovered shreds.
    let nodes: Vec<_> = shreds
        .iter()
        .map(Shred::merkle_tree_node)
        .collect::<Result<_, Error>>()?;
    let tree = make_merkle_tree(nodes);
    let merkle_root = &tree.last().unwrap().as_ref()[..SIZE_OF_MERKLE_ROOT];
    let merkle_root = MerkleRoot::try_from(merkle_root).unwrap();
    for (index, (shred, mask)) in shreds.iter_mut().zip(&mask).enumerate() {
        if *mask {
            if shred.merkle_root() != &merkle_root {
                return Err(Error::InvalidMerkleProof);
            }
        } else {
            let merkle_branch =
                make_merkle_branch(index, num_shards, &tree).ok_or(Error::InvalidMerkleProof)?;
            if merkle_branch.proof.len() != usize::from(proof_size) {
                return Err(Error::InvalidMerkleProof);
            }
            shred.set_merkle_branch(merkle_branch)?;
        }
    }
    // TODO: No need to verify merkle proof in sanitize here.
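    // (sanitize re-runs verify_merkle_proof even though set_merkle_branch has
    // just written a branch derived from the freshly computed tree, so the
    // check below is redundant work rather than a correctness requirement.)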
    shreds
        .into_iter()
        .zip(mask)
        .filter(|(_, mask)| !mask)
        .map(|(shred, _)| shred.sanitize().map(|_| shred))
        .collect()
}

#[cfg(test)]
mod test {
    use {
        super::*,
        crate::shred::ShredFlags,
        itertools::Itertools,
        matches::assert_matches,
        rand::{seq::SliceRandom, CryptoRng, Rng},
        solana_sdk::signature::{Keypair, Signer},
        std::{cmp::Ordering, iter::repeat_with},
        test_case::test_case,
    };

    // Total size of a data shred including headers and merkle branch.
    fn shred_data_size_of_payload(proof_size: u8) -> usize {
        ShredData::SIZE_OF_HEADERS
            + ShredData::capacity(proof_size).unwrap()
            + SIZE_OF_MERKLE_ROOT
            + usize::from(proof_size) * SIZE_OF_MERKLE_PROOF_ENTRY
    }

    // Merkle branch is generated and signed after coding shreds are generated.
    // All of the payload excluding the merkle branch and the signature is
    // erasure coded. Therefore the data capacity is equal to the erasure
    // encoded shard size minus the size of the erasure encoded header.
    fn shred_data_capacity(proof_size: u8) -> usize {
        const SIZE_OF_ERASURE_ENCODED_HEADER: usize =
            ShredData::SIZE_OF_HEADERS - SIZE_OF_SIGNATURE;
        ShredCode::capacity(proof_size).unwrap() - SIZE_OF_ERASURE_ENCODED_HEADER
    }

    fn shred_data_size_of_erasure_encoded_slice(proof_size: u8) -> usize {
        ShredData::SIZE_OF_PAYLOAD
            - SIZE_OF_SIGNATURE
            - SIZE_OF_MERKLE_ROOT
            - usize::from(proof_size) * SIZE_OF_MERKLE_PROOF_ENTRY
    }

    #[test]
    fn test_shred_data_size_of_payload() {
        for proof_size in 0..0x15 {
            assert_eq!(
                ShredData::SIZE_OF_PAYLOAD,
                shred_data_size_of_payload(proof_size)
            );
        }
    }

    #[test]
    fn test_shred_data_capacity() {
        for proof_size in 0..0x15 {
            assert_eq!(
                ShredData::capacity(proof_size).unwrap(),
                shred_data_capacity(proof_size)
            );
        }
    }

    #[test]
    fn test_shred_code_capacity() {
        for proof_size in 0..0x15 {
            assert_eq!(
                ShredCode::capacity(proof_size).unwrap(),
                shred_data_size_of_erasure_encoded_slice(proof_size),
            );
        }
    }

    #[test]
    fn test_merkle_proof_entry_from_hash() {
        let mut rng = rand::thread_rng();
        let bytes: [u8; 32] = rng.gen();
        let hash = Hash::from(bytes);
        let entry = &hash.as_ref()[..SIZE_OF_MERKLE_PROOF_ENTRY];
        let entry = MerkleProofEntry::try_from(entry).unwrap();
        assert_eq!(entry, &bytes[..SIZE_OF_MERKLE_PROOF_ENTRY]);
    }

    fn run_merkle_tree_round_trip(size: usize) {
        let mut rng = rand::thread_rng();
        let nodes = repeat_with(|| rng.gen::<[u8; 32]>()).map(Hash::from);
        let nodes: Vec<_> = nodes.take(size).collect();
        let tree = make_merkle_tree(nodes.clone());
        for index in 0..size {
            let branch = make_merkle_branch(index, size, &tree).unwrap();
            let root = &tree.last().unwrap().as_ref()[..SIZE_OF_MERKLE_ROOT];
            assert_eq!(&branch.root, root);
            assert!(verify_merkle_proof(index, nodes[index], &branch));
            for i in (0..size).filter(|&i| i != index) {
                assert!(!verify_merkle_proof(i, nodes[i], &branch));
            }
        }
    }

    #[test]
    fn test_merkle_tree_round_trip() {
        for size in [1, 2, 3, 4, 5, 6, 7, 8, 9, 19, 37, 64, 79] {
            run_merkle_tree_round_trip(size);
        }
    }

    #[test_case(37)]
    #[test_case(64)]
    #[test_case(73)]
    fn test_recover_merkle_shreds(num_shreds: usize) {
        let mut rng = rand::thread_rng();
        for num_data_shreds in 1..num_shreds {
            let num_coding_shreds = num_shreds - num_data_shreds;
            run_recover_merkle_shreds(&mut rng, num_data_shreds, num_coding_shreds);
        }
    }

    fn run_recover_merkle_shreds<R: Rng + CryptoRng>(
        rng: &mut R,
        num_data_shreds: usize,
        num_coding_shreds: usize,
    ) {
        let keypair = Keypair::generate(rng);
        let num_shreds = num_data_shreds + num_coding_shreds;
        let proof_size = (num_shreds as f64).log2().ceil() as u8;
        let capacity = ShredData::capacity(proof_size).unwrap();
        let common_header = ShredCommonHeader {
            signature: Signature::default(),
            shred_variant: ShredVariant::MerkleData(proof_size),
            slot: 145865705,
            index: 1835,
            version: 4978,
            fec_set_index: 1835,
        };
        let data_header = DataShredHeader {
            parent_offset: 25,
            flags: unsafe { ShredFlags::from_bits_unchecked(0b0010_1010) },
            size: 0,
        };
        let coding_header = CodingShredHeader {
            num_data_shreds: num_data_shreds as u16,
            num_coding_shreds: num_coding_shreds as u16,
            position: 0,
        };
        let mut shreds = Vec::with_capacity(num_shreds);
        for i in 0..num_data_shreds {
            let common_header = ShredCommonHeader {
                index: common_header.index + i as u32,
                ..common_header
            };
            let size = ShredData::SIZE_OF_HEADERS + rng.gen_range(0, capacity);
            let data_header = DataShredHeader {
                size: size as u16,
                ..data_header
            };
            let mut payload = vec![0u8; ShredData::SIZE_OF_PAYLOAD];
            let mut cursor = Cursor::new(&mut payload[..]);
            bincode::serialize_into(&mut cursor, &common_header).unwrap();
            bincode::serialize_into(&mut cursor, &data_header).unwrap();
            rng.fill(&mut payload[ShredData::SIZE_OF_HEADERS..size]);
            let shred = ShredData {
                common_header,
                data_header,
                merkle_branch: MerkleBranch::new_zeroed(proof_size),
                payload,
            };
            shreds.push(Shred::ShredData(shred));
        }
        let data: Vec<_> = shreds
            .iter()
            .map(Shred::erasure_shard_as_slice)
            .collect::<Result<_, Error>>()
            .unwrap();
        let mut parity = vec![vec![0u8; data[0].len()]; num_coding_shreds];
        ReedSolomon::new(num_data_shreds, num_coding_shreds)
            .unwrap()
            .encode_sep(&data, &mut parity[..])
            .unwrap();
        for (i, code) in parity.into_iter().enumerate() {
            let common_header = ShredCommonHeader {
                shred_variant: ShredVariant::MerkleCode(proof_size),
                index: common_header.index + i as u32 + 7,
                ..common_header
            };
            let coding_header = CodingShredHeader {
                position: i as u16,
                ..coding_header
            };
            let mut payload = vec![0u8; ShredCode::SIZE_OF_PAYLOAD];
            let mut cursor = Cursor::new(&mut payload[..]);
            bincode::serialize_into(&mut cursor, &common_header).unwrap();
            bincode::serialize_into(&mut cursor, &coding_header).unwrap();
            payload[ShredCode::SIZE_OF_HEADERS..ShredCode::SIZE_OF_HEADERS + code.len()]
                .copy_from_slice(&code);
            let shred = ShredCode {
                common_header,
                coding_header,
                merkle_branch: MerkleBranch::new_zeroed(proof_size),
                payload,
            };
            shreds.push(Shred::ShredCode(shred));
        }
        let nodes: Vec<_> = shreds
            .iter()
            .map(Shred::merkle_tree_node)
            .collect::<Result<_, Error>>()
            .unwrap();
        let tree = make_merkle_tree(nodes);
        for (index, shred) in shreds.iter_mut().enumerate() {
            let merkle_branch = make_merkle_branch(index, num_shreds, &tree).unwrap();
            assert_eq!(merkle_branch.proof.len(), usize::from(proof_size));
            shred.set_merkle_branch(merkle_branch).unwrap();
            let signature = keypair.sign_message(shred.signed_message());
            shred.set_signature(signature);
            assert!(shred.verify(&keypair.pubkey()));
            assert_matches!(shred.sanitize(), Ok(()));
        }
        assert_eq!(shreds.iter().map(Shred::signature).dedup().count(), 1);
        for size in num_data_shreds..num_shreds {
            let mut shreds = shreds.clone();
            let mut removed_shreds = Vec::new();
            while shreds.len() > size {
                let index = rng.gen_range(0, shreds.len());
                removed_shreds.push(shreds.swap_remove(index));
            }
            shreds.shuffle(rng);
            // Should contain at least one coding shred.
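            // recover() reads num_data_shreds/num_coding_shreds from a coding
            // shred's headers, so a subset with only data shreds must fail
            // with TooFewParityShards, which is what the branch below asserts.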
            if shreds.iter().all(|shred| {
                matches!(
                    shred.common_header().shred_variant,
                    ShredVariant::MerkleData(_)
                )
            }) {
                assert_matches!(
                    recover(shreds),
                    Err(Error::ErasureError(TooFewParityShards))
                );
                continue;
            }
            let recovered_shreds = recover(shreds).unwrap();
            assert_eq!(size + recovered_shreds.len(), num_shreds);
            assert_eq!(recovered_shreds.len(), removed_shreds.len());
            removed_shreds.sort_by(|a, b| {
                if a.shred_type() == b.shred_type() {
                    a.index().cmp(&b.index())
                } else if a.shred_type() == ShredType::Data {
                    Ordering::Less
                } else {
                    Ordering::Greater
                }
            });
            assert_eq!(recovered_shreds, removed_shreds);
        }
    }
}
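
// A minimal sketch (not part of the original file) illustrating the shape of
// the vector returned by make_merkle_tree: each level's nodes are appended
// after the leaves, so the total node count is the sum of the level sizes
// ceil(n / 2^k) until a single root remains. The module name and the expected
// counts below are illustrative only.
#[cfg(test)]
mod merkle_tree_size_sketch {
    use {super::*, rand::Rng, std::iter::repeat_with};

    #[test]
    fn test_merkle_tree_node_count() {
        let mut rng = rand::thread_rng();
        // (number of leaves, expected total number of nodes in the tree),
        // e.g. 5 leaves -> levels of 5 + 3 + 2 + 1 = 11 nodes.
        for &(size, expected) in &[(1, 1), (2, 3), (3, 6), (5, 11), (8, 15)] {
            let nodes: Vec<_> = repeat_with(|| rng.gen::<[u8; 32]>())
                .map(Hash::from)
                .take(size)
                .collect();
            assert_eq!(make_merkle_tree(nodes).len(), expected);
        }
    }
}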