Change Message serialization to async send/recv.

Because we want to be able to read messages from async sources (like a
tcp socket), we need to have at least async header parsing logic, so
that we can correctly determine how many bytes to await to parse each
message, so it makes sense for the entire message parsing functions
to be async.

Because we perform message serialization into async readers and writers
in the context of sending messages over the network, code using these
functions is more clear with these names.
This commit is contained in:
Henry de Valence 2019-09-19 10:08:35 -07:00 committed by Deirdre Connolly
parent fa4ba442eb
commit 3b51056857
2 changed files with 108 additions and 87 deletions

View File

@ -11,4 +11,6 @@ rand = "0.7"
byteorder = "1.3"
chrono = "0.4"
failure = "0.1"
tokio = "=0.2.0-alpha.4"
zebra-chain = { path = "../zebra-chain" }

View File

@ -1,10 +1,11 @@
//! Definitions of network messages.
use std::io;
use std::io::{self, Cursor, Read, Write};
use std::net;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use chrono::{DateTime, TimeZone, Utc};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use zebra_chain::{
serialization::{
@ -280,12 +281,8 @@ pub enum RejectReason {
// Maybe just write some functions and refactor later?
impl Message {
/// Serialize `self` into the given writer, similarly to `ZcashSerialize`.
///
/// This is similar to [`ZcashSerialize::zcash_serialize`], but not part of
/// that trait, because message serialization requires additional parameters
/// (the network magic and the network version).
pub fn zcash_serialize<W: io::Write>(
/// Send `self` to the given async writer (e.g., a network stream).
pub async fn send<W: Unpin + AsyncWrite>(
&self,
mut writer: W,
magic: Magic,
@ -325,15 +322,92 @@ impl Message {
MerkleBlock { .. } => b"merkleblock\0",
};
// Write the header and then the body.
writer.write_all(&magic.0)?;
writer.write_all(command)?;
writer.write_u32::<LittleEndian>(body.len() as u32)?;
writer.write_all(&Sha256dChecksum::from(&body[..]).0)?;
writer.write_all(&body)?;
// Write the header into a stack buffer first before feeding that stack
// buffer into the async writer. This allows using the WriteBytesExt
// extension trait, which is only defined for sync Writers.
// The header is 4+12+4+4=24 bytes long.
let mut header = [0u8; 24];
let mut header_writer = Cursor::new(&mut header[..]);
header_writer.write_all(&magic.0)?;
header_writer.write_all(command)?;
header_writer.write_u32::<LittleEndian>(body.len() as u32)?;
header_writer.write_all(&Sha256dChecksum::from(&body[..]).0)?;
writer.write_all(&header).await?;
writer.write_all(&body).await?;
Ok(())
}
/// Receive a message from the given async reader (e.g., a network stream).
pub async fn recv<R: Unpin + AsyncRead>(
mut reader: R,
magic: Magic,
version: Version,
) -> Result<Self, SerializationError> {
use SerializationError::ParseError;
// Read the header into a stack buffer before trying to parse it. This
// allows using the ReadBytesExt extension trait, which is only defined
// for sync Readers. Then we can determine the expected message length,
// await that many bytes, and finally parse the message.
// The header is 4+12+4+4=24 bytes long.
let header = {
let mut bytes = [0u8; 24];
reader.read_exact(&mut bytes).await?;
bytes
};
let mut header_reader = Cursor::new(&header[..]);
// Read header data
let message_magic = Magic(header_reader.read_4_bytes()?);
let command = header_reader.read_12_bytes()?;
let body_len = header_reader.read_u32::<LittleEndian>()? as usize;
let checksum = Sha256dChecksum(header_reader.read_4_bytes()?);
if magic != message_magic {
return Err(ParseError("Message has incorrect magic value"));
}
// XXX bound the body_len value to avoid large attacker-controlled allocs
// XXX add a ChecksumReader<R: Read>(R) wrapper and avoid this
let body = {
let mut bytes = vec![0; body_len];
reader.read_exact(&mut bytes).await?;
bytes
};
if checksum != Sha256dChecksum::from(&body[..]) {
return Err(SerializationError::ParseError("checksum does not match"));
}
let body_reader = Cursor::new(&body);
match &command {
b"version\0\0\0\0\0" => try_read_version(body_reader, version),
b"verack\0\0\0\0\0\0" => try_read_verack(body_reader, version),
b"ping\0\0\0\0\0\0\0\0" => try_read_ping(body_reader, version),
b"pong\0\0\0\0\0\0\0\0" => try_read_pong(body_reader, version),
b"reject\0\0\0\0\0\0" => try_read_reject(body_reader, version),
b"addr\0\0\0\0\0\0\0\0" => try_read_addr(body_reader, version),
b"getaddr\0\0\0\0\0" => try_read_getaddr(body_reader, version),
b"block\0\0\0\0\0\0\0" => try_read_block(body_reader, version),
b"getblocks\0\0\0" => try_read_getblocks(body_reader, version),
b"headers\0\0\0\0\0" => try_read_headers(body_reader, version),
b"getheaders\0\0" => try_read_getheaders(body_reader, version),
b"inv\0\0\0\0\0\0\0\0\0" => try_read_inv(body_reader, version),
b"getdata\0\0\0\0\0" => try_read_getdata(body_reader, version),
b"notfound\0\0\0\0" => try_read_notfound(body_reader, version),
b"tx\0\0\0\0\0\0\0\0\0\0" => try_read_tx(body_reader, version),
b"mempool\0\0\0\0\0" => try_read_mempool(body_reader, version),
b"filterload\0\0" => try_read_filterload(body_reader, version),
b"filteradd\0\0\0" => try_read_filteradd(body_reader, version),
b"filterclear\0" => try_read_filterclear(body_reader, version),
b"merkleblock\0" => try_read_merkleblock(body_reader, version),
_ => Err(ParseError("Unknown command")),
}
}
}
impl Message {
@ -380,66 +454,6 @@ impl Message {
}
Ok(())
}
/// Try to deserialize a [`Message`] from the given reader, similarly to `ZcashDeserialize`.
///
/// This is similar to [`ZcashSerialize::zcash_serialize`], but not part of
/// that trait, because message serialization requires additional parameters
/// (the network magic and the network version).
pub fn zcash_deserialize<R: io::Read>(
mut reader: R,
magic: Magic,
version: Version,
) -> Result<Self, SerializationError> {
use SerializationError::ParseError;
// Read header data
let message_magic = Magic(reader.read_4_bytes()?);
let command = reader.read_12_bytes()?;
let body_len = reader.read_u32::<LittleEndian>()? as usize;
let checksum = Sha256dChecksum(reader.read_4_bytes()?);
if magic != message_magic {
return Err(ParseError("Message has incorrect magic value"));
}
// XXX bound the body_len value to avoid large attacker-controlled allocs
// XXX add a ChecksumReader<R: Read>(R) wrapper and avoid this
let body = {
let mut bytes = vec![0; body_len];
reader.read_exact(&mut bytes)?;
bytes
};
if checksum != Sha256dChecksum::from(&body[..]) {
return Err(SerializationError::ParseError("checksum does not match"));
}
let body_reader = io::Cursor::new(&body);
match &command {
b"version\0\0\0\0\0" => try_read_version(body_reader, version),
b"verack\0\0\0\0\0\0" => try_read_verack(body_reader, version),
b"ping\0\0\0\0\0\0\0\0" => try_read_ping(body_reader, version),
b"pong\0\0\0\0\0\0\0\0" => try_read_pong(body_reader, version),
b"reject\0\0\0\0\0\0" => try_read_reject(body_reader, version),
b"addr\0\0\0\0\0\0\0\0" => try_read_addr(body_reader, version),
b"getaddr\0\0\0\0\0" => try_read_getaddr(body_reader, version),
b"block\0\0\0\0\0\0\0" => try_read_block(body_reader, version),
b"getblocks\0\0\0" => try_read_getblocks(body_reader, version),
b"headers\0\0\0\0\0" => try_read_headers(body_reader, version),
b"getheaders\0\0" => try_read_getheaders(body_reader, version),
b"inv\0\0\0\0\0\0\0\0\0" => try_read_inv(body_reader, version),
b"getdata\0\0\0\0\0" => try_read_getdata(body_reader, version),
b"notfound\0\0\0\0" => try_read_notfound(body_reader, version),
b"tx\0\0\0\0\0\0\0\0\0\0" => try_read_tx(body_reader, version),
b"mempool\0\0\0\0\0" => try_read_mempool(body_reader, version),
b"filterload\0\0" => try_read_filterload(body_reader, version),
b"filteradd\0\0\0" => try_read_filteradd(body_reader, version),
b"filterclear\0" => try_read_filterclear(body_reader, version),
b"merkleblock\0" => try_read_merkleblock(body_reader, version),
_ => Err(ParseError("Unknown command")),
}
}
}
fn try_read_version<R: io::Read>(
@ -605,6 +619,7 @@ fn try_read_merkleblock<R: io::Read>(
#[cfg(test)]
mod tests {
use super::*;
use tokio::runtime::Runtime;
#[test]
fn version_message_round_trip() {
@ -612,11 +627,12 @@ mod tests {
let services = Services(0x1);
let timestamp = Utc.timestamp(1568000000, 0);
let rt = Runtime::new().unwrap();
let v = Message::Version {
version: crate::constants::CURRENT_VERSION,
services,
timestamp,
// XXX maybe better to have Version keep only (Services, SocketAddr)
address_recv: (
services,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 6)), 8233),
@ -631,24 +647,27 @@ mod tests {
relay: true,
};
use std::io::Cursor;
let v_bytes = {
let v_bytes = rt.block_on(async {
let mut bytes = Vec::new();
let _ = v.zcash_serialize(
Cursor::new(&mut bytes),
v.send(
&mut bytes,
crate::constants::magics::MAINNET,
crate::constants::CURRENT_VERSION,
);
)
.await
.unwrap();
bytes
};
});
let v_parsed = Message::zcash_deserialize(
Cursor::new(&v_bytes),
crate::constants::magics::MAINNET,
crate::constants::CURRENT_VERSION,
)
.expect("message should parse successfully");
let v_parsed = rt.block_on(async {
Message::recv(
Cursor::new(&v_bytes),
crate::constants::magics::MAINNET,
crate::constants::CURRENT_VERSION,
)
.await
.unwrap()
});
assert_eq!(v, v_parsed);
}