diff --git a/zebra-network/Cargo.toml b/zebra-network/Cargo.toml index d5d68a8aa..8e70c9650 100644 --- a/zebra-network/Cargo.toml +++ b/zebra-network/Cargo.toml @@ -11,4 +11,6 @@ rand = "0.7" byteorder = "1.3" chrono = "0.4" failure = "0.1" +tokio = "=0.2.0-alpha.4" + zebra-chain = { path = "../zebra-chain" } diff --git a/zebra-network/src/message.rs b/zebra-network/src/message.rs index 0ce4b09d1..7509a5ab5 100644 --- a/zebra-network/src/message.rs +++ b/zebra-network/src/message.rs @@ -1,10 +1,11 @@ //! Definitions of network messages. -use std::io; +use std::io::{self, Cursor, Read, Write}; use std::net; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use chrono::{DateTime, TimeZone, Utc}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use zebra_chain::{ serialization::{ @@ -280,12 +281,8 @@ pub enum RejectReason { // Maybe just write some functions and refactor later? impl Message { - /// Serialize `self` into the given writer, similarly to `ZcashSerialize`. - /// - /// This is similar to [`ZcashSerialize::zcash_serialize`], but not part of - /// that trait, because message serialization requires additional parameters - /// (the network magic and the network version). - pub fn zcash_serialize( + /// Send `self` to the given async writer (e.g., a network stream). + pub async fn send( &self, mut writer: W, magic: Magic, @@ -325,15 +322,92 @@ impl Message { MerkleBlock { .. } => b"merkleblock\0", }; - // Write the header and then the body. - writer.write_all(&magic.0)?; - writer.write_all(command)?; - writer.write_u32::(body.len() as u32)?; - writer.write_all(&Sha256dChecksum::from(&body[..]).0)?; - writer.write_all(&body)?; + // Write the header into a stack buffer first before feeding that stack + // buffer into the async writer. This allows using the WriteBytesExt + // extension trait, which is only defined for sync Writers. + + // The header is 4+12+4+4=24 bytes long. + let mut header = [0u8; 24]; + let mut header_writer = Cursor::new(&mut header[..]); + header_writer.write_all(&magic.0)?; + header_writer.write_all(command)?; + header_writer.write_u32::(body.len() as u32)?; + header_writer.write_all(&Sha256dChecksum::from(&body[..]).0)?; + + writer.write_all(&header).await?; + writer.write_all(&body).await?; Ok(()) } + + /// Receive a message from the given async reader (e.g., a network stream). + pub async fn recv( + mut reader: R, + magic: Magic, + version: Version, + ) -> Result { + use SerializationError::ParseError; + + // Read the header into a stack buffer before trying to parse it. This + // allows using the ReadBytesExt extension trait, which is only defined + // for sync Readers. Then we can determine the expected message length, + // await that many bytes, and finally parse the message. + + // The header is 4+12+4+4=24 bytes long. + let header = { + let mut bytes = [0u8; 24]; + reader.read_exact(&mut bytes).await?; + bytes + }; + let mut header_reader = Cursor::new(&header[..]); + + // Read header data + let message_magic = Magic(header_reader.read_4_bytes()?); + let command = header_reader.read_12_bytes()?; + let body_len = header_reader.read_u32::()? as usize; + let checksum = Sha256dChecksum(header_reader.read_4_bytes()?); + + if magic != message_magic { + return Err(ParseError("Message has incorrect magic value")); + } + + // XXX bound the body_len value to avoid large attacker-controlled allocs + // XXX add a ChecksumReader(R) wrapper and avoid this + let body = { + let mut bytes = vec![0; body_len]; + reader.read_exact(&mut bytes).await?; + bytes + }; + + if checksum != Sha256dChecksum::from(&body[..]) { + return Err(SerializationError::ParseError("checksum does not match")); + } + + let body_reader = Cursor::new(&body); + match &command { + b"version\0\0\0\0\0" => try_read_version(body_reader, version), + b"verack\0\0\0\0\0\0" => try_read_verack(body_reader, version), + b"ping\0\0\0\0\0\0\0\0" => try_read_ping(body_reader, version), + b"pong\0\0\0\0\0\0\0\0" => try_read_pong(body_reader, version), + b"reject\0\0\0\0\0\0" => try_read_reject(body_reader, version), + b"addr\0\0\0\0\0\0\0\0" => try_read_addr(body_reader, version), + b"getaddr\0\0\0\0\0" => try_read_getaddr(body_reader, version), + b"block\0\0\0\0\0\0\0" => try_read_block(body_reader, version), + b"getblocks\0\0\0" => try_read_getblocks(body_reader, version), + b"headers\0\0\0\0\0" => try_read_headers(body_reader, version), + b"getheaders\0\0" => try_read_getheaders(body_reader, version), + b"inv\0\0\0\0\0\0\0\0\0" => try_read_inv(body_reader, version), + b"getdata\0\0\0\0\0" => try_read_getdata(body_reader, version), + b"notfound\0\0\0\0" => try_read_notfound(body_reader, version), + b"tx\0\0\0\0\0\0\0\0\0\0" => try_read_tx(body_reader, version), + b"mempool\0\0\0\0\0" => try_read_mempool(body_reader, version), + b"filterload\0\0" => try_read_filterload(body_reader, version), + b"filteradd\0\0\0" => try_read_filteradd(body_reader, version), + b"filterclear\0" => try_read_filterclear(body_reader, version), + b"merkleblock\0" => try_read_merkleblock(body_reader, version), + _ => Err(ParseError("Unknown command")), + } + } } impl Message { @@ -380,66 +454,6 @@ impl Message { } Ok(()) } - - /// Try to deserialize a [`Message`] from the given reader, similarly to `ZcashDeserialize`. - /// - /// This is similar to [`ZcashSerialize::zcash_serialize`], but not part of - /// that trait, because message serialization requires additional parameters - /// (the network magic and the network version). - pub fn zcash_deserialize( - mut reader: R, - magic: Magic, - version: Version, - ) -> Result { - use SerializationError::ParseError; - - // Read header data - let message_magic = Magic(reader.read_4_bytes()?); - let command = reader.read_12_bytes()?; - let body_len = reader.read_u32::()? as usize; - let checksum = Sha256dChecksum(reader.read_4_bytes()?); - - if magic != message_magic { - return Err(ParseError("Message has incorrect magic value")); - } - - // XXX bound the body_len value to avoid large attacker-controlled allocs - // XXX add a ChecksumReader(R) wrapper and avoid this - let body = { - let mut bytes = vec![0; body_len]; - reader.read_exact(&mut bytes)?; - bytes - }; - - if checksum != Sha256dChecksum::from(&body[..]) { - return Err(SerializationError::ParseError("checksum does not match")); - } - - let body_reader = io::Cursor::new(&body); - match &command { - b"version\0\0\0\0\0" => try_read_version(body_reader, version), - b"verack\0\0\0\0\0\0" => try_read_verack(body_reader, version), - b"ping\0\0\0\0\0\0\0\0" => try_read_ping(body_reader, version), - b"pong\0\0\0\0\0\0\0\0" => try_read_pong(body_reader, version), - b"reject\0\0\0\0\0\0" => try_read_reject(body_reader, version), - b"addr\0\0\0\0\0\0\0\0" => try_read_addr(body_reader, version), - b"getaddr\0\0\0\0\0" => try_read_getaddr(body_reader, version), - b"block\0\0\0\0\0\0\0" => try_read_block(body_reader, version), - b"getblocks\0\0\0" => try_read_getblocks(body_reader, version), - b"headers\0\0\0\0\0" => try_read_headers(body_reader, version), - b"getheaders\0\0" => try_read_getheaders(body_reader, version), - b"inv\0\0\0\0\0\0\0\0\0" => try_read_inv(body_reader, version), - b"getdata\0\0\0\0\0" => try_read_getdata(body_reader, version), - b"notfound\0\0\0\0" => try_read_notfound(body_reader, version), - b"tx\0\0\0\0\0\0\0\0\0\0" => try_read_tx(body_reader, version), - b"mempool\0\0\0\0\0" => try_read_mempool(body_reader, version), - b"filterload\0\0" => try_read_filterload(body_reader, version), - b"filteradd\0\0\0" => try_read_filteradd(body_reader, version), - b"filterclear\0" => try_read_filterclear(body_reader, version), - b"merkleblock\0" => try_read_merkleblock(body_reader, version), - _ => Err(ParseError("Unknown command")), - } - } } fn try_read_version( @@ -605,6 +619,7 @@ fn try_read_merkleblock( #[cfg(test)] mod tests { use super::*; + use tokio::runtime::Runtime; #[test] fn version_message_round_trip() { @@ -612,11 +627,12 @@ mod tests { let services = Services(0x1); let timestamp = Utc.timestamp(1568000000, 0); + let rt = Runtime::new().unwrap(); + let v = Message::Version { version: crate::constants::CURRENT_VERSION, services, timestamp, - // XXX maybe better to have Version keep only (Services, SocketAddr) address_recv: ( services, SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 6)), 8233), @@ -631,24 +647,27 @@ mod tests { relay: true, }; - use std::io::Cursor; - - let v_bytes = { + let v_bytes = rt.block_on(async { let mut bytes = Vec::new(); - let _ = v.zcash_serialize( - Cursor::new(&mut bytes), + v.send( + &mut bytes, crate::constants::magics::MAINNET, crate::constants::CURRENT_VERSION, - ); + ) + .await + .unwrap(); bytes - }; + }); - let v_parsed = Message::zcash_deserialize( - Cursor::new(&v_bytes), - crate::constants::magics::MAINNET, - crate::constants::CURRENT_VERSION, - ) - .expect("message should parse successfully"); + let v_parsed = rt.block_on(async { + Message::recv( + Cursor::new(&v_bytes), + crate::constants::magics::MAINNET, + crate::constants::CURRENT_VERSION, + ) + .await + .unwrap() + }); assert_eq!(v, v_parsed); }