Provide impl Zcash[De]Serialize for Vec<T: Zcash[De]Serialize>.

This replaces the read_list function and makes the code significantly cleaner.

The only downside is that it loses exact preallocation, but this is probably not a big deal.
This commit is contained in:
Henry de Valence 2019-12-20 16:43:19 -08:00 committed by Deirdre Connolly
parent 1199cfa23e
commit 92ddf0542f
2 changed files with 48 additions and 133 deletions

View File

@ -55,6 +55,31 @@ pub trait ZcashDeserialize: Sized {
fn zcash_deserialize<R: io::Read>(reader: R) -> Result<Self, SerializationError>;
}
impl<T: ZcashSerialize> ZcashSerialize for Vec<T> {
fn zcash_serialize<W: io::Write>(&self, mut writer: W) -> Result<(), SerializationError> {
writer.write_compactsize(self.len() as u64)?;
for x in self {
x.zcash_serialize(&mut writer)?;
}
Ok(())
}
}
impl<T: ZcashDeserialize> ZcashDeserialize for Vec<T> {
fn zcash_deserialize<R: io::Read>(mut reader: R) -> Result<Self, SerializationError> {
let len = reader.read_compactsize()?;
// We're given len, so we could preallocate. But blindly preallocating
// without a size bound can allow DOS attacks, and there's no way to
// pass a size bound in a ZcashDeserialize impl, so instead we allocate
// as we read from the reader.
let mut vec = Vec::new();
for _ in 0..len {
vec.push(T::zcash_deserialize(&mut reader)?);
}
Ok(vec)
}
}
/// Extends [`Write`] with methods for writing Zcash/Bitcoin types.
///
/// [`Write`]: https://doc.rust-lang.org/std/io/trait.Write.html
@ -248,37 +273,6 @@ pub trait ReadZcashExt: io::Read {
self.read_exact(&mut bytes)?;
Ok(bytes)
}
/// Convenience method to read a `Vec<T>` with a leading count in
/// a safer manner.
///
/// This method preallocates a buffer, performing a single
/// allocation in the honest case. It's possible for someone to
/// send a short message with a large count field, so if we
/// naively trust the count field we could be tricked into
/// preallocating a large buffer. Instead, we rely on the passed
/// maximum count for a valid message and select the min of the
/// two values.
#[inline]
fn read_list<T: ZcashDeserialize>(
&mut self,
max_count: usize,
) -> Result<Vec<T>, SerializationError> {
// This prevents the inferred type for zcash_deserialize from
// taking ownership of &mut self. This wouldn't really be an
// issue if the target impl's `Copy`, but we need to own it.
let mut self2 = self;
let count = self2.read_compactsize()? as usize;
let mut items = Vec::with_capacity(std::cmp::min(count, max_count));
for _ in 0..count {
items.push(T::zcash_deserialize(&mut self2)?);
}
return Ok(items);
}
}
/// Mark all types implementing `Read` as implementing the extension.

View File

@ -9,7 +9,7 @@ use chrono::{TimeZone, Utc};
use tokio_util::codec::{Decoder, Encoder};
use zebra_chain::{
block::{Block, BlockHeader, BlockHeaderHash},
block::{Block, BlockHeaderHash},
serialization::{
ReadZcashExt, SerializationError as Error, WriteZcashExt, ZcashDeserialize, ZcashSerialize,
},
@ -20,7 +20,6 @@ use zebra_chain::{
use crate::{constants, types::Network};
use super::{
inv::InventoryHash,
message::{Message, RejectReason},
types::*,
};
@ -202,12 +201,7 @@ impl Codec {
writer.write_string(&reason)?;
writer.write_all(&data.unwrap())?;
}
Addr(ref addrs) => {
writer.write_compactsize(addrs.len() as u64)?;
for addr in addrs {
addr.zcash_serialize(&mut writer)?;
}
}
Addr(ref addrs) => addrs.zcash_serialize(&mut writer)?,
GetAddr => { /* Empty payload -- no-op */ }
Block {
ref version,
@ -222,10 +216,7 @@ impl Codec {
ref hash_stop,
} => {
writer.write_u32::<LittleEndian>(version.0)?;
writer.write_compactsize(block_locator_hashes.len() as u64)?;
for hash in block_locator_hashes {
hash.zcash_serialize(&mut writer)?;
}
block_locator_hashes.zcash_serialize(&mut writer)?;
hash_stop.zcash_serialize(&mut writer)?;
}
GetHeaders {
@ -234,36 +225,13 @@ impl Codec {
ref hash_stop,
} => {
writer.write_u32::<LittleEndian>(version.0)?;
writer.write_compactsize(block_locator_hashes.len() as u64)?;
for hash in block_locator_hashes {
hash.zcash_serialize(&mut writer)?;
}
block_locator_hashes.zcash_serialize(&mut writer)?;
hash_stop.zcash_serialize(&mut writer)?;
}
Headers(ref headers) => {
writer.write_compactsize(headers.len() as u64)?;
for header in headers {
header.zcash_serialize(&mut writer)?;
}
}
Inv(ref hashes) => {
writer.write_compactsize(hashes.len() as u64)?;
for hash in hashes {
hash.zcash_serialize(&mut writer)?;
}
}
GetData(ref hashes) => {
writer.write_compactsize(hashes.len() as u64)?;
for hash in hashes {
hash.zcash_serialize(&mut writer)?;
}
}
NotFound(ref hashes) => {
writer.write_compactsize(hashes.len() as u64)?;
for hash in hashes {
hash.zcash_serialize(&mut writer)?;
}
}
Headers(ref headers) => headers.zcash_serialize(&mut writer)?,
Inv(ref hashes) => hashes.zcash_serialize(&mut writer)?,
GetData(ref hashes) => hashes.zcash_serialize(&mut writer)?,
NotFound(ref hashes) => hashes.zcash_serialize(&mut writer)?,
Tx {
ref version,
ref transaction,
@ -488,16 +456,8 @@ impl Codec {
})
}
fn read_addr<R: Read>(&self, mut reader: R) -> Result<Message, Error> {
use crate::meta_addr::MetaAddr;
// addrs are encoded as: timestamp + services + ipv6 + port
const ENCODED_ADDR_SIZE: usize = 4 + 8 + 16 + 2;
let max_count = self.builder.max_len / ENCODED_ADDR_SIZE;
let addrs: Vec<MetaAddr> = reader.read_list(max_count)?;
Ok(Message::Addr(addrs))
fn read_addr<R: Read>(&self, reader: R) -> Result<Message, Error> {
Ok(Message::Addr(Vec::zcash_deserialize(reader)?))
}
fn read_getaddr<R: Read>(&self, mut _reader: R) -> Result<Message, Error> {
@ -507,24 +467,15 @@ impl Codec {
fn read_block<R: Read>(&self, mut reader: R) -> Result<Message, Error> {
Ok(Message::Block {
version: Version(reader.read_u32::<LittleEndian>()?),
block: Block::zcash_deserialize(&mut reader)?,
})
}
fn read_getblocks<R: Read>(&self, mut reader: R) -> Result<Message, Error> {
let version = Version(reader.read_u32::<LittleEndian>()?);
let max_count = self.builder.max_len / 32;
let block_locator_hashes: Vec<BlockHeaderHash> = reader.read_list(max_count)?;
let hash_stop = BlockHeaderHash(reader.read_32_bytes()?);
Ok(Message::GetBlocks {
version,
block_locator_hashes,
hash_stop,
version: Version(reader.read_u32::<LittleEndian>()?),
block_locator_hashes: Vec::zcash_deserialize(&mut reader)?,
hash_stop: BlockHeaderHash(reader.read_32_bytes()?),
})
}
@ -534,62 +485,32 @@ impl Codec {
///
/// [Zcash block header](https://zips.z.cash/protocol/protocol.pdf#page=84)
fn read_headers<R: Read>(&self, mut reader: R) -> Result<Message, Error> {
const ENCODED_HEADER_SIZE: usize = 4 + 32 + 32 + 32 + 4 + 4 + 32 + 3 + 1344;
let max_count = self.builder.max_len / ENCODED_HEADER_SIZE;
let headers: Vec<BlockHeader> = reader.read_list(max_count)?;
Ok(Message::Headers(headers))
Ok(Message::Headers(Vec::zcash_deserialize(&mut reader)?))
}
fn read_getheaders<R: Read>(&self, mut reader: R) -> Result<Message, Error> {
let version = Version(reader.read_u32::<LittleEndian>()?);
let max_count = self.builder.max_len / 32;
let block_locator_hashes: Vec<BlockHeaderHash> = reader.read_list(max_count)?;
let hash_stop = BlockHeaderHash(reader.read_32_bytes()?);
Ok(Message::GetHeaders {
version,
block_locator_hashes,
hash_stop,
version: Version(reader.read_u32::<LittleEndian>()?),
block_locator_hashes: Vec::zcash_deserialize(&mut reader)?,
hash_stop: BlockHeaderHash(reader.read_32_bytes()?),
})
}
fn read_inv<R: Read>(&self, mut reader: R) -> Result<Message, Error> {
// encoding: 4 byte type tag + 32 byte hash
const ENCODED_INVHASH_SIZE: usize = 4 + 32;
let max_count = self.builder.max_len / ENCODED_INVHASH_SIZE;
let hashes: Vec<InventoryHash> = reader.read_list(max_count)?;
Ok(Message::Inv(hashes))
fn read_inv<R: Read>(&self, reader: R) -> Result<Message, Error> {
Ok(Message::Inv(Vec::zcash_deserialize(reader)?))
}
fn read_getdata<R: Read>(&self, mut reader: R) -> Result<Message, Error> {
// encoding: 4 byte type tag + 32 byte hash
const ENCODED_INVHASH_SIZE: usize = 4 + 32;
let max_count = self.builder.max_len / ENCODED_INVHASH_SIZE;
let hashes: Vec<InventoryHash> = reader.read_list(max_count)?;
Ok(Message::GetData(hashes))
fn read_getdata<R: Read>(&self, reader: R) -> Result<Message, Error> {
Ok(Message::GetData(Vec::zcash_deserialize(reader)?))
}
fn read_notfound<R: Read>(&self, mut reader: R) -> Result<Message, Error> {
// encoding: 4 byte type tag + 32 byte hash
const ENCODED_INVHASH_SIZE: usize = 4 + 32;
let max_count = self.builder.max_len / ENCODED_INVHASH_SIZE;
let hashes: Vec<InventoryHash> = reader.read_list(max_count)?;
Ok(Message::GetData(hashes))
fn read_notfound<R: Read>(&self, reader: R) -> Result<Message, Error> {
Ok(Message::GetData(Vec::zcash_deserialize(reader)?))
}
fn read_tx<R: Read>(&self, mut reader: R) -> Result<Message, Error> {
Ok(Message::Tx {
version: Version(reader.read_u32::<LittleEndian>()?),
transaction: Transaction::zcash_deserialize(&mut reader)?,
})
}