From 3b510568576f5d4ae797b9cdbf66f7988f46002b Mon Sep 17 00:00:00 2001 From: Henry de Valence Date: Thu, 19 Sep 2019 10:08:35 -0700 Subject: [PATCH] Change Message serialization to async send/recv. Because we want to be able to read messages from async sources (like a tcp socket), we need to have at least async header parsing logic, so that we can correctly determine how many bytes to await to parse each message, so it makes sense for the entire message parsing functions to be async. Because we perform message serialization into async readers and writers in the context of sending messages over the network, code using these functions is more clear with these names. --- zebra-network/Cargo.toml | 2 + zebra-network/src/message.rs | 193 +++++++++++++++++++---------------- 2 files changed, 108 insertions(+), 87 deletions(-) diff --git a/zebra-network/Cargo.toml b/zebra-network/Cargo.toml index d5d68a8aa..8e70c9650 100644 --- a/zebra-network/Cargo.toml +++ b/zebra-network/Cargo.toml @@ -11,4 +11,6 @@ rand = "0.7" byteorder = "1.3" chrono = "0.4" failure = "0.1" +tokio = "=0.2.0-alpha.4" + zebra-chain = { path = "../zebra-chain" } diff --git a/zebra-network/src/message.rs b/zebra-network/src/message.rs index 0ce4b09d1..7509a5ab5 100644 --- a/zebra-network/src/message.rs +++ b/zebra-network/src/message.rs @@ -1,10 +1,11 @@ //! Definitions of network messages. -use std::io; +use std::io::{self, Cursor, Read, Write}; use std::net; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use chrono::{DateTime, TimeZone, Utc}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use zebra_chain::{ serialization::{ @@ -280,12 +281,8 @@ pub enum RejectReason { // Maybe just write some functions and refactor later? impl Message { - /// Serialize `self` into the given writer, similarly to `ZcashSerialize`. 
- /// - /// This is similar to [`ZcashSerialize::zcash_serialize`], but not part of - /// that trait, because message serialization requires additional parameters - /// (the network magic and the network version). - pub fn zcash_serialize( + /// Send `self` to the given async writer (e.g., a network stream). + pub async fn send( &self, mut writer: W, magic: Magic, @@ -325,15 +322,92 @@ impl Message { MerkleBlock { .. } => b"merkleblock\0", }; - // Write the header and then the body. - writer.write_all(&magic.0)?; - writer.write_all(command)?; - writer.write_u32::<LittleEndian>(body.len() as u32)?; - writer.write_all(&Sha256dChecksum::from(&body[..]).0)?; - writer.write_all(&body)?; + // Write the header into a stack buffer first before feeding that stack + // buffer into the async writer. This allows using the WriteBytesExt + // extension trait, which is only defined for sync Writers. + + // The header is 4+12+4+4=24 bytes long. + let mut header = [0u8; 24]; + let mut header_writer = Cursor::new(&mut header[..]); + header_writer.write_all(&magic.0)?; + header_writer.write_all(command)?; + header_writer.write_u32::<LittleEndian>(body.len() as u32)?; + header_writer.write_all(&Sha256dChecksum::from(&body[..]).0)?; + + writer.write_all(&header).await?; + writer.write_all(&body).await?; Ok(()) } + + /// Receive a message from the given async reader (e.g., a network stream). + pub async fn recv( + mut reader: R, + magic: Magic, + version: Version, + ) -> Result { + use SerializationError::ParseError; + + // Read the header into a stack buffer before trying to parse it. This + // allows using the ReadBytesExt extension trait, which is only defined + // for sync Readers. Then we can determine the expected message length, + // await that many bytes, and finally parse the message. + + // The header is 4+12+4+4=24 bytes long. 
+ let header = { + let mut bytes = [0u8; 24]; + reader.read_exact(&mut bytes).await?; + bytes + }; + let mut header_reader = Cursor::new(&header[..]); + + // Read header data + let message_magic = Magic(header_reader.read_4_bytes()?); + let command = header_reader.read_12_bytes()?; + let body_len = header_reader.read_u32::<LittleEndian>()? as usize; + let checksum = Sha256dChecksum(header_reader.read_4_bytes()?); + + if magic != message_magic { + return Err(ParseError("Message has incorrect magic value")); + } + + // XXX bound the body_len value to avoid large attacker-controlled allocs + // XXX add a ChecksumReader(R) wrapper and avoid this + let body = { + let mut bytes = vec![0; body_len]; + reader.read_exact(&mut bytes).await?; + bytes + }; + + if checksum != Sha256dChecksum::from(&body[..]) { + return Err(SerializationError::ParseError("checksum does not match")); + } + + let body_reader = Cursor::new(&body); + match &command { + b"version\0\0\0\0\0" => try_read_version(body_reader, version), + b"verack\0\0\0\0\0\0" => try_read_verack(body_reader, version), + b"ping\0\0\0\0\0\0\0\0" => try_read_ping(body_reader, version), + b"pong\0\0\0\0\0\0\0\0" => try_read_pong(body_reader, version), + b"reject\0\0\0\0\0\0" => try_read_reject(body_reader, version), + b"addr\0\0\0\0\0\0\0\0" => try_read_addr(body_reader, version), + b"getaddr\0\0\0\0\0" => try_read_getaddr(body_reader, version), + b"block\0\0\0\0\0\0\0" => try_read_block(body_reader, version), + b"getblocks\0\0\0" => try_read_getblocks(body_reader, version), + b"headers\0\0\0\0\0" => try_read_headers(body_reader, version), + b"getheaders\0\0" => try_read_getheaders(body_reader, version), + b"inv\0\0\0\0\0\0\0\0\0" => try_read_inv(body_reader, version), + b"getdata\0\0\0\0\0" => try_read_getdata(body_reader, version), + b"notfound\0\0\0\0" => try_read_notfound(body_reader, version), + b"tx\0\0\0\0\0\0\0\0\0\0" => try_read_tx(body_reader, version), + b"mempool\0\0\0\0\0" => try_read_mempool(body_reader, version), + 
b"filterload\0\0" => try_read_filterload(body_reader, version), + b"filteradd\0\0\0" => try_read_filteradd(body_reader, version), + b"filterclear\0" => try_read_filterclear(body_reader, version), + b"merkleblock\0" => try_read_merkleblock(body_reader, version), + _ => Err(ParseError("Unknown command")), + } + } } impl Message { @@ -380,66 +454,6 @@ impl Message { } Ok(()) } - - /// Try to deserialize a [`Message`] from the given reader, similarly to `ZcashDeserialize`. - /// - /// This is similar to [`ZcashSerialize::zcash_serialize`], but not part of - /// that trait, because message serialization requires additional parameters - /// (the network magic and the network version). - pub fn zcash_deserialize( - mut reader: R, - magic: Magic, - version: Version, - ) -> Result { - use SerializationError::ParseError; - - // Read header data - let message_magic = Magic(reader.read_4_bytes()?); - let command = reader.read_12_bytes()?; - let body_len = reader.read_u32::<LittleEndian>()? as usize; - let checksum = Sha256dChecksum(reader.read_4_bytes()?); - - if magic != message_magic { - return Err(ParseError("Message has incorrect magic value")); - } - - // XXX bound the body_len value to avoid large attacker-controlled allocs - // XXX add a ChecksumReader(R) wrapper and avoid this - let body = { - let mut bytes = vec![0; body_len]; - reader.read_exact(&mut bytes)?; - bytes - }; - - if checksum != Sha256dChecksum::from(&body[..]) { - return Err(SerializationError::ParseError("checksum does not match")); - } - - let body_reader = io::Cursor::new(&body); - match &command { - b"version\0\0\0\0\0" => try_read_version(body_reader, version), - b"verack\0\0\0\0\0\0" => try_read_verack(body_reader, version), - b"ping\0\0\0\0\0\0\0\0" => try_read_ping(body_reader, version), - b"pong\0\0\0\0\0\0\0\0" => try_read_pong(body_reader, version), - b"reject\0\0\0\0\0\0" => try_read_reject(body_reader, version), - b"addr\0\0\0\0\0\0\0\0" => try_read_addr(body_reader, version), - b"getaddr\0\0\0\0\0" => 
try_read_getaddr(body_reader, version), - b"block\0\0\0\0\0\0\0" => try_read_block(body_reader, version), - b"getblocks\0\0\0" => try_read_getblocks(body_reader, version), - b"headers\0\0\0\0\0" => try_read_headers(body_reader, version), - b"getheaders\0\0" => try_read_getheaders(body_reader, version), - b"inv\0\0\0\0\0\0\0\0\0" => try_read_inv(body_reader, version), - b"getdata\0\0\0\0\0" => try_read_getdata(body_reader, version), - b"notfound\0\0\0\0" => try_read_notfound(body_reader, version), - b"tx\0\0\0\0\0\0\0\0\0\0" => try_read_tx(body_reader, version), - b"mempool\0\0\0\0\0" => try_read_mempool(body_reader, version), - b"filterload\0\0" => try_read_filterload(body_reader, version), - b"filteradd\0\0\0" => try_read_filteradd(body_reader, version), - b"filterclear\0" => try_read_filterclear(body_reader, version), - b"merkleblock\0" => try_read_merkleblock(body_reader, version), - _ => Err(ParseError("Unknown command")), - } - } } fn try_read_version( @@ -605,6 +619,7 @@ fn try_read_merkleblock( #[cfg(test)] mod tests { use super::*; + use tokio::runtime::Runtime; #[test] fn version_message_round_trip() { @@ -612,11 +627,12 @@ mod tests { let services = Services(0x1); let timestamp = Utc.timestamp(1568000000, 0); + let rt = Runtime::new().unwrap(); + let v = Message::Version { version: crate::constants::CURRENT_VERSION, services, timestamp, - // XXX maybe better to have Version keep only (Services, SocketAddr) address_recv: ( services, SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 6)), 8233), @@ -631,24 +647,27 @@ mod tests { relay: true, }; - use std::io::Cursor; - - let v_bytes = { + let v_bytes = rt.block_on(async { let mut bytes = Vec::new(); - let _ = v.zcash_serialize( - Cursor::new(&mut bytes), + v.send( + &mut bytes, crate::constants::magics::MAINNET, crate::constants::CURRENT_VERSION, - ); + ) + .await + .unwrap(); bytes - }; + }); - let v_parsed = Message::zcash_deserialize( - Cursor::new(&v_bytes), - crate::constants::magics::MAINNET, - 
crate::constants::CURRENT_VERSION, - ) - .expect("message should parse successfully"); + let v_parsed = rt.block_on(async { + Message::recv( + Cursor::new(&v_bytes), + crate::constants::magics::MAINNET, + crate::constants::CURRENT_VERSION, + ) + .await + .unwrap() + }); assert_eq!(v, v_parsed); }