From b9849179c94c1805ddbef35c6e3265ed4a80f4bd Mon Sep 17 00:00:00 2001 From: behzad nouri Date: Mon, 26 Sep 2022 21:28:15 +0000 Subject: [PATCH] bypasses merkle proof verification for recovered merkle shreds (#28076) Merkle proof for shreds recovered from erasure codes are generated locally, and it is superfluous to verify them when sanitizing recovered shreds: https://github.com/solana-labs/solana/blob/a0f49c2e4/ledger/src/shred/merkle.rs#L727-L760 --- ledger/src/shred/merkle.rs | 85 ++++++++++++++++++++++---------------- 1 file changed, 49 insertions(+), 36 deletions(-) diff --git a/ledger/src/shred/merkle.rs b/ledger/src/shred/merkle.rs index 725854766b..107967896e 100644 --- a/ledger/src/shred/merkle.rs +++ b/ledger/src/shred/merkle.rs @@ -91,7 +91,7 @@ impl Shred { dispatch!(fn erasure_shard_index(&self) -> Result); dispatch!(fn merkle_tree_node(&self) -> Result); dispatch!(fn payload(&self) -> &Vec); - dispatch!(fn sanitize(&self) -> Result<(), Error>); + dispatch!(fn sanitize(&self, verify_merkle_proof: bool) -> Result<(), Error>); dispatch!(fn set_merkle_branch(&mut self, merkle_branch: MerkleBranch) -> Result<(), Error>); dispatch!(fn set_signature(&mut self, signature: Signature)); dispatch!(fn signed_message(&self) -> &[u8]); @@ -219,6 +219,23 @@ impl ShredData { self.merkle_branch = merkle_branch; Ok(()) } + + fn sanitize(&self, verify_merkle_proof: bool) -> Result<(), Error> { + match self.common_header.shred_variant { + ShredVariant::MerkleData(proof_size) => { + if self.merkle_branch.proof.len() != usize::from(proof_size) { + return Err(Error::InvalidProofSize(proof_size)); + } + } + _ => return Err(Error::InvalidShredVariant), + } + if !verify_merkle_proof { + debug_assert_matches!(self.verify_merkle_proof(), Ok(true)); + } else if !self.verify_merkle_proof()? { + return Err(Error::InvalidMerkleProof); + } + shred_data::sanitize(self) + } } impl ShredCode { @@ -317,6 +334,23 @@ impl ShredCode { self.merkle_branch = merkle_branch; Ok(()) } + + fn sanitize(&self, verify_merkle_proof: bool) -> Result<(), Error> { + match self.common_header.shred_variant { + ShredVariant::MerkleCode(proof_size) => { + if self.merkle_branch.proof.len() != usize::from(proof_size) { + return Err(Error::InvalidProofSize(proof_size)); + } + } + _ => return Err(Error::InvalidShredVariant), + } + if !verify_merkle_proof { + debug_assert_matches!(self.verify_merkle_proof(), Ok(true)); + } else if !self.verify_merkle_proof()? { + return Err(Error::InvalidMerkleProof); + } + shred_code::sanitize(self) + } } impl MerkleBranch { @@ -368,7 +402,8 @@ impl ShredTrait for ShredData { merkle_branch, payload, }; - shred.sanitize().map(|_| shred) + shred.sanitize(/*verify_merkle_proof:*/ true)?; + Ok(shred) } fn erasure_shard_index(&self) -> Result { @@ -402,18 +437,7 @@ impl ShredTrait for ShredData { } fn sanitize(&self) -> Result<(), Error> { - match self.common_header.shred_variant { - ShredVariant::MerkleData(proof_size) => { - if self.merkle_branch.proof.len() != usize::from(proof_size) { - return Err(Error::InvalidProofSize(proof_size)); - } - } - _ => return Err(Error::InvalidShredVariant), - } - if !self.verify_merkle_proof()? { - return Err(Error::InvalidMerkleProof); - } - shred_data::sanitize(self) + self.sanitize(/*verify_merkle_proof:*/ true) } fn signed_message(&self) -> &[u8] { @@ -452,7 +476,8 @@ impl ShredTrait for ShredCode { merkle_branch, payload, }; - shred.sanitize().map(|_| shred) + shred.sanitize(/*verify_merkle_proof:*/ true)?; + Ok(shred) } fn erasure_shard_index(&self) -> Result { @@ -486,18 +511,7 @@ impl ShredTrait for ShredCode { } fn sanitize(&self) -> Result<(), Error> { - match self.common_header.shred_variant { - ShredVariant::MerkleCode(proof_size) => { - if self.merkle_branch.proof.len() != usize::from(proof_size) { - return Err(Error::InvalidProofSize(proof_size)); - } - } - _ => return Err(Error::InvalidShredVariant), - } - if !self.verify_merkle_proof()? { - return Err(Error::InvalidMerkleProof); - } - shred_code::sanitize(self) + self.sanitize(/*verify_merkle_proof:*/ true) } fn signed_message(&self) -> &[u8] { @@ -619,10 +633,7 @@ pub(super) fn recover( Some((common_header, coding_header)) }); let (common_header, coding_header) = headers.ok_or(TooFewParityShards)?; - debug_assert!(matches!( - common_header.shred_variant, - ShredVariant::MerkleCode(_) - )); + debug_assert_matches!(common_header.shred_variant, ShredVariant::MerkleCode(_)); let proof_size = match common_header.shred_variant { ShredVariant::MerkleCode(proof_size) => proof_size, ShredVariant::MerkleData(_) | ShredVariant::LegacyCode | ShredVariant::LegacyData => { @@ -751,12 +762,14 @@ pub(super) fn recover( }); } } - // TODO: No need to verify merkle proof in sanitize here. shreds .into_iter() .zip(mask) .filter(|(_, mask)| !mask) - .map(|(shred, _)| shred.sanitize().map(|_| shred)) + .map(|(shred, _)| { + shred.sanitize(/*verify_merkle_proof:*/ false)?; + Ok(shred) + }) .collect() } @@ -1008,7 +1021,7 @@ fn make_erasure_batch( shred.set_merkle_branch(merkle_branch)?; shred.set_signature(signature); debug_assert!(shred.verify(&keypair.pubkey())); - debug_assert_matches!(shred.sanitize(), Ok(())); + debug_assert_matches!(shred.sanitize(/*verify_merkle_proof:*/ true), Ok(())); // Assert that shred payload is fully populated. debug_assert_eq!(shred, { let shred = shred.payload().clone(); @@ -1237,7 +1250,7 @@ mod test { let signature = keypair.sign_message(shred.signed_message()); shred.set_signature(signature); assert!(shred.verify(&keypair.pubkey())); - assert_matches!(shred.sanitize(), Ok(())); + assert_matches!(shred.sanitize(/*verify_merkle_proof:*/ true), Ok(())); } assert_eq!(shreds.iter().map(Shred::signature).dedup().count(), 1); for size in num_data_shreds..num_shreds { @@ -1372,7 +1385,7 @@ mod test { // Assert that shreds sanitize and verify. for shred in &shreds { assert!(shred.verify(&keypair.pubkey())); - assert_matches!(shred.sanitize(), Ok(())); + assert_matches!(shred.sanitize(/*verify_merkle_proof:*/ true), Ok(())); let ShredCommonHeader { signature, shred_variant,