diff --git a/zebra-network/Cargo.toml b/zebra-network/Cargo.toml index a6bea3397..f14881b61 100644 --- a/zebra-network/Cargo.toml +++ b/zebra-network/Cargo.toml @@ -7,6 +7,7 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +bytes = "0.4" rand = "0.7" byteorder = "1.3" chrono = "0.4" diff --git a/zebra-network/src/lib.rs b/zebra-network/src/lib.rs index b02238884..efd2ab455 100644 --- a/zebra-network/src/lib.rs +++ b/zebra-network/src/lib.rs @@ -7,6 +7,9 @@ extern crate failure; #[macro_use] extern crate tracing; +mod network; +pub use network::Network; + pub mod protocol; pub mod types; diff --git a/zebra-network/src/network.rs b/zebra-network/src/network.rs new file mode 100644 index 000000000..c211e2a1e --- /dev/null +++ b/zebra-network/src/network.rs @@ -0,0 +1,20 @@ +use crate::{constants::magics, types::Magic}; + +/// An enum describing the possible network choices. +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub enum Network { + /// The production mainnet. + Mainnet, + /// The testnet. + Testnet, +} + +impl Network { + /// Get the magic value associated to this `Network`. + pub fn magic(&self) -> Magic { + match self { + Network::Mainnet => magics::MAINNET, + Network::Testnet => magics::TESTNET, + } + } +} diff --git a/zebra-network/src/protocol.rs b/zebra-network/src/protocol.rs index 5baeafa72..21dabde18 100644 --- a/zebra-network/src/protocol.rs +++ b/zebra-network/src/protocol.rs @@ -1,5 +1,4 @@ //! Zcash network protocol handling. -pub mod message; pub mod codec; - +pub mod message; diff --git a/zebra-network/src/protocol/codec.rs b/zebra-network/src/protocol/codec.rs new file mode 100644 index 000000000..8fcfcfb31 --- /dev/null +++ b/zebra-network/src/protocol/codec.rs @@ -0,0 +1,505 @@ +//! A Tokio codec mapping byte streams to Bitcoin message streams. + +use std::io::{Cursor, Read, Write}; + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use bytes::BytesMut; +use chrono::{TimeZone, Utc}; +use failure::Error; +use tokio::codec::{Decoder, Encoder}; + +use zebra_chain::{ + serialization::{ReadZcashExt, WriteZcashExt}, + types::{BlockHeight, Sha256dChecksum}, +}; + +use crate::{constants, types::*, Network}; + +use super::message::Message; + +/// A codec which produces Bitcoin messages from byte streams and vice versa. +pub struct Codec { + builder: Builder, + state: DecodeState, +} + +/// A builder for specifying [`Codec`] options. +pub struct Builder { + /// The network magic to use in encoding. + network: Network, + /// The protocol version to speak when encoding/decoding. + version: Version, + /// The maximum allowable message length. + max_len: usize, +} + +impl Codec { + /// Return a builder for constructing a [`Codec`]. + /// + /// # Example + /// ``` + /// # use zebra_network::protocol::codec::Codec; + /// use zebra_network::{constants, Network}; + /// + /// let codec = Codec::builder() + /// .for_network(Network::Mainnet) + /// .for_version(constants::CURRENT_VERSION) + /// .with_max_body_len(4_000_000) + /// .finish(); + /// ``` + pub fn builder() -> Builder { + Builder { + network: Network::Mainnet, + version: constants::CURRENT_VERSION, + max_len: 4_000_000, + } + } + + /// Reconfigure the version used by the codec, e.g., after completing a handshake. + pub fn reconfigure_version(&mut self, version: Version) { + self.builder.version = version; + } +} + +impl Builder { + /// Finalize the builder and return a [`Codec`]. + pub fn finish(self) -> Codec { + Codec { + builder: self, + state: DecodeState::Head, + } + } + + /// Configure the codec for the given [`Network`]. + pub fn for_network(mut self, network: Network) -> Self { + self.network = network; + self + } + + /// Configure the codec for the given [`Version`]. + pub fn for_version(mut self, version: Version) -> Self { + self.version = version; + self + } + + /// Configure the codec's maximum accepted payload size, in bytes. + pub fn with_max_body_len(mut self, len: usize) -> Self { + self.max_len = len; + self + } +} + +/// The length of a Bitcoin message header. +const HEADER_LEN: usize = 24usize; + +// ======== Encoding ========= + +impl Codec { + /// Write the body of the message into the given writer. This allows writing + /// the message body prior to writing the header, so that the header can + /// contain a checksum of the message body. + fn write_body(&self, msg: &Message, mut writer: W) -> Result<(), Error> { + use Message::*; + match *msg { + Version { + ref version, + ref services, + ref timestamp, + ref address_recv, + ref address_from, + ref nonce, + ref user_agent, + ref start_height, + ref relay, + } => { + writer.write_u32::(version.0)?; + writer.write_u64::(services.0)?; + writer.write_i64::(timestamp.timestamp())?; + + let (recv_services, recv_addr) = address_recv; + writer.write_u64::(recv_services.0)?; + writer.write_socket_addr(*recv_addr)?; + + let (from_services, from_addr) = address_from; + writer.write_u64::(from_services.0)?; + writer.write_socket_addr(*from_addr)?; + + writer.write_u64::(nonce.0)?; + writer.write_string(&user_agent)?; + writer.write_u32::(start_height.0)?; + writer.write_u8(*relay as u8)?; + } + Verack => { /* Empty payload -- no-op */ } + Ping(nonce) => { + writer.write_u64::(nonce.0)?; + } + Pong(nonce) => { + writer.write_u64::(nonce.0)?; + } + _ => bail!("unimplemented message type"), + } + Ok(()) + } +} + +impl Encoder for Codec { + type Item = Message; + type Error = Error; + + #[instrument(skip(src))] + fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> { + // XXX(HACK): this is inefficient and does an extra allocation. + // instead, we should have a size estimator for the message, reserve + // that much space, write the header (with zeroed checksum), then the body, + // then write the computed checksum in-place. for now, just do an extra alloc. + + let mut body = Vec::new(); + self.write_body(&item, &mut body)?; + + use Message::*; + // Note: because all match arms must have + // the same type, and the array length is + // part of the type, having at least one + // of length 12 checks that they are all + // of length 12, as they must be &[u8; 12]. + let command = match item { + Version { .. } => b"version\0\0\0\0\0", + Verack { .. } => b"verack\0\0\0\0\0\0", + Ping { .. } => b"ping\0\0\0\0\0\0\0\0", + Pong { .. } => b"pong\0\0\0\0\0\0\0\0", + Reject { .. } => b"reject\0\0\0\0\0\0", + Addr { .. } => b"addr\0\0\0\0\0\0\0\0", + GetAddr { .. } => b"getaddr\0\0\0\0\0", + Block { .. } => b"block\0\0\0\0\0\0\0", + GetBlocks { .. } => b"getblocks\0\0\0", + Headers { .. } => b"headers\0\0\0\0\0", + GetHeaders { .. } => b"getheaders\0\0", + Inventory { .. } => b"inv\0\0\0\0\0\0\0\0\0", // XXX Inventory -> Inv ? + GetData { .. } => b"getdata\0\0\0\0\0", + NotFound { .. } => b"notfound\0\0\0\0", + Tx { .. } => b"tx\0\0\0\0\0\0\0\0\0\0", + Mempool { .. } => b"mempool\0\0\0\0\0", + FilterLoad { .. } => b"filterload\0\0", + FilterAdd { .. } => b"filteradd\0\0\0", + FilterClear { .. } => b"filterclear\0", + MerkleBlock { .. } => b"merkleblock\0", + }; + trace!(?command, len = body.len()); + + // XXX this should write directly into the buffer, + // but leave it for now until we fix the issue above. + let mut header = [0u8; HEADER_LEN]; + let mut header_writer = Cursor::new(&mut header[..]); + header_writer.write_all(&self.builder.network.magic().0)?; + header_writer.write_all(command)?; + header_writer.write_u32::(body.len() as u32)?; + header_writer.write_all(&Sha256dChecksum::from(&body[..]).0)?; + + dst.reserve(HEADER_LEN + body.len()); + dst.extend_from_slice(&header); + dst.extend_from_slice(&body); + + Ok(()) + } +} + +// ======== Decoding ========= + +#[derive(Debug)] +enum DecodeState { + Head, + Body { + body_len: usize, + command: [u8; 12], + checksum: Sha256dChecksum, + }, +} + +impl Decoder for Codec { + type Item = Message; + type Error = Error; + + #[instrument(skip(src))] + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + match self.state { + DecodeState::Head => { + // First check that the src buffer contains an entire header. + if src.len() < HEADER_LEN { + trace!(?self.state, "src buffer does not have an entire header, waiting"); + // Signal that decoding requires more data. + return Ok(None); + } + + // Now that we know that src contains a header, split off the header section. + let header = src.split_to(HEADER_LEN); + + // Create a cursor over the header and parse its fields. + let mut header_reader = Cursor::new(&header); + let magic = Magic(header_reader.read_4_bytes()?); + let command = header_reader.read_12_bytes()?; + let body_len = header_reader.read_u32::()? as usize; + let checksum = Sha256dChecksum(header_reader.read_4_bytes()?); + trace!(?self.state, ?magic, ?command, body_len, ?checksum, "read header from src buffer"); + + ensure!( + magic == self.builder.network.magic(), + "supplied magic did not meet expectations" + ); + ensure!( + body_len < self.builder.max_len, + "body length exceeded maximum size", + ); + + // Reserve buffer space for the expected body and the following header. + src.reserve(body_len + HEADER_LEN); + + self.state = DecodeState::Body { + body_len, + command, + checksum, + }; + + // Now that the state is updated, recurse to attempt body decoding. + self.decode(src) + } + DecodeState::Body { + body_len, + command, + checksum, + } => { + if src.len() < body_len { + // Need to wait for the full body + trace!(?self.state, len = src.len(), "src buffer does not have an entire body, waiting"); + return Ok(None); + } + + // Now that we know we have the full body, split off the body, + // and reset the decoder state for the next message. + let body = src.split_to(body_len); + self.state = DecodeState::Head; + + ensure!( + checksum == Sha256dChecksum::from(&body[..]), + "supplied message checksum does not match computed checksum" + ); + + let body_reader = Cursor::new(&body); + let v = self.builder.version; + match &command { + b"version\0\0\0\0\0" => try_read_version(body_reader, v), + b"verack\0\0\0\0\0\0" => try_read_verack(body_reader, v), + b"ping\0\0\0\0\0\0\0\0" => try_read_ping(body_reader, v), + b"pong\0\0\0\0\0\0\0\0" => try_read_pong(body_reader, v), + b"reject\0\0\0\0\0\0" => try_read_reject(body_reader, v), + b"addr\0\0\0\0\0\0\0\0" => try_read_addr(body_reader, v), + b"getaddr\0\0\0\0\0" => try_read_getaddr(body_reader, v), + b"block\0\0\0\0\0\0\0" => try_read_block(body_reader, v), + b"getblocks\0\0\0" => try_read_getblocks(body_reader, v), + b"headers\0\0\0\0\0" => try_read_headers(body_reader, v), + b"getheaders\0\0" => try_read_getheaders(body_reader, v), + b"inv\0\0\0\0\0\0\0\0\0" => try_read_inv(body_reader, v), + b"getdata\0\0\0\0\0" => try_read_getdata(body_reader, v), + b"notfound\0\0\0\0" => try_read_notfound(body_reader, v), + b"tx\0\0\0\0\0\0\0\0\0\0" => try_read_tx(body_reader, v), + b"mempool\0\0\0\0\0" => try_read_mempool(body_reader, v), + b"filterload\0\0" => try_read_filterload(body_reader, v), + b"filteradd\0\0\0" => try_read_filteradd(body_reader, v), + b"filterclear\0" => try_read_filterclear(body_reader, v), + b"merkleblock\0" => try_read_merkleblock(body_reader, v), + _ => bail!("unknown command"), + } + // We need Ok(Some(msg)) to signal that we're done decoding + .map(|msg| Some(msg)) + } + } + } +} + +fn try_read_version(mut reader: R, _parsing_version: Version) -> Result { + Ok(Message::Version { + version: Version(reader.read_u32::()?), + services: Services(reader.read_u64::()?), + timestamp: Utc.timestamp(reader.read_i64::()?, 0), + address_recv: ( + Services(reader.read_u64::()?), + reader.read_socket_addr()?, + ), + address_from: ( + Services(reader.read_u64::()?), + reader.read_socket_addr()?, + ), + nonce: Nonce(reader.read_u64::()?), + user_agent: reader.read_string()?, + start_height: BlockHeight(reader.read_u32::()?), + relay: match reader.read_u8()? { + 0 => false, + 1 => true, + _ => bail!("non-bool value supplied in relay field"), + }, + }) +} + +fn try_read_verack(mut _reader: R, _version: Version) -> Result { + Ok(Message::Verack) +} + +fn try_read_ping(mut reader: R, _version: Version) -> Result { + Ok(Message::Ping(Nonce(reader.read_u64::()?))) +} + +fn try_read_pong(mut reader: R, _version: Version) -> Result { + Ok(Message::Pong(Nonce(reader.read_u64::()?))) +} + +#[instrument(level = "trace", skip(_reader, _version))] +fn try_read_reject(mut _reader: R, _version: Version) -> Result { + trace!("reject"); + bail!("unimplemented message type") +} + +#[instrument(level = "trace", skip(_reader, _version))] +fn try_read_addr(mut _reader: R, _version: Version) -> Result { + trace!("addr"); + bail!("unimplemented message type") +} + +#[instrument(level = "trace", skip(_reader, _version))] +fn try_read_getaddr(mut _reader: R, _version: Version) -> Result { + trace!("getaddr"); + bail!("unimplemented message type") +} + +#[instrument(level = "trace", skip(_reader, _version))] +fn try_read_block(mut _reader: R, _version: Version) -> Result { + trace!("block"); + bail!("unimplemented message type") +} + +#[instrument(level = "trace", skip(_reader, _version))] +fn try_read_getblocks(mut _reader: R, _version: Version) -> Result { + trace!("getblocks"); + bail!("unimplemented message type") +} + +#[instrument(level = "trace", skip(_reader, _version))] +fn try_read_headers(mut _reader: R, _version: Version) -> Result { + trace!("headers"); + bail!("unimplemented message type") +} + +#[instrument(level = "trace", skip(_reader, _version))] +fn try_read_getheaders(mut _reader: R, _version: Version) -> Result { + trace!("getheaders"); + bail!("unimplemented message type") +} + +#[instrument(level = "trace", skip(_reader, _version))] +fn try_read_inv(mut _reader: R, _version: Version) -> Result { + trace!("inv"); + bail!("unimplemented message type") +} + +#[instrument(level = "trace", skip(_reader, _version))] +fn try_read_getdata(mut _reader: R, _version: Version) -> Result { + trace!("getdata"); + bail!("unimplemented message type") +} + +#[instrument(level = "trace", skip(_reader, _version))] +fn try_read_notfound(mut _reader: R, _version: Version) -> Result { + trace!("notfound"); + bail!("unimplemented message type") +} + +#[instrument(level = "trace", skip(_reader, _version))] +fn try_read_tx(mut _reader: R, _version: Version) -> Result { + trace!("tx"); + bail!("unimplemented message type") +} + +#[instrument(level = "trace", skip(_reader, _version))] +fn try_read_mempool(mut _reader: R, _version: Version) -> Result { + trace!("mempool"); + bail!("unimplemented message type") +} + +#[instrument(level = "trace", skip(_reader, _version))] +fn try_read_filterload(mut _reader: R, _version: Version) -> Result { + trace!("filterload"); + bail!("unimplemented message type") +} + +#[instrument(level = "trace", skip(_reader, _version))] +fn try_read_filteradd(mut _reader: R, _version: Version) -> Result { + trace!("filteradd"); + bail!("unimplemented message type") +} + +#[instrument(level = "trace", skip(_reader, _version))] +fn try_read_filterclear(mut _reader: R, _version: Version) -> Result { + trace!("filterclear"); + bail!("unimplemented message type") +} + +#[instrument(level = "trace", skip(_reader, _version))] +fn try_read_merkleblock(mut _reader: R, _version: Version) -> Result { + trace!("merkleblock"); + bail!("unimplemented message type") +} + +// XXX replace these interior unit tests with exterior integration tests + proptest +#[cfg(test)] +mod tests { + use super::*; + use tokio::runtime::Runtime; + + #[test] + fn version_message_round_trip() { + use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + let services = Services(0x1); + let timestamp = Utc.timestamp(1568000000, 0); + + let rt = Runtime::new().unwrap(); + + let v = Message::Version { + version: crate::constants::CURRENT_VERSION, + services, + timestamp, + address_recv: ( + services, + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 6)), 8233), + ), + address_from: ( + services, + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 6)), 8233), + ), + nonce: Nonce(0x9082_4908_8927_9238), + user_agent: "Zebra".to_owned(), + start_height: BlockHeight(540_000), + relay: true, + }; + + use tokio::codec::{FramedRead, FramedWrite}; + use tokio::prelude::*; + let v_bytes = rt.block_on(async { + let mut bytes = Vec::new(); + { + let mut fw = FramedWrite::new(&mut bytes, Codec::builder().finish()); + fw.send(v.clone()) + .await + .expect("message should be serialized"); + } + bytes + }); + + let v_parsed = rt.block_on(async { + let mut fr = FramedRead::new(Cursor::new(&v_bytes), Codec::builder().finish()); + fr.next() + .await + .expect("a next message should be available") + .expect("that message should deserialize") + }); + + assert_eq!(v, v_parsed); + } +} diff --git a/zebra-network/src/protocol/message.rs b/zebra-network/src/protocol/message.rs index af1183b29..220e78081 100644 --- a/zebra-network/src/protocol/message.rs +++ b/zebra-network/src/protocol/message.rs @@ -1,18 +1,10 @@ //! Definitions of network messages. -use std::io::{self, Cursor, Read, Write}; use std::net; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; -use chrono::{DateTime, TimeZone, Utc}; -use failure::Error; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use chrono::{DateTime, Utc}; -use zebra_chain::{ - serialization::{ReadZcashExt, WriteZcashExt, ZcashDeserialize, ZcashSerialize}, - transaction::Transaction, - types::{BlockHeight, Sha256dChecksum}, -}; +use zebra_chain::{transaction::Transaction, types::BlockHeight}; use crate::meta_addr::MetaAddr; use crate::types::*; @@ -271,388 +263,3 @@ pub enum RejectReason { InsufficientFee = 0x42, Checkpoint = 0x43, } - -// Q: how do we want to implement serialization, exactly? do we want to have -// something generic over stdlib Read and Write traits, or over async versions -// of those traits? -// -// Note: because of the way the message structure is defined (checksum comes -// first) we can't write the message headers before collecting the whole body -// into a buffer -// -// Maybe just write some functions and refactor later? - -impl Message { - /// Send `self` to the given async writer (e.g., a network stream). - #[instrument(level = "debug", skip(writer))] - pub async fn send( - &self, - mut writer: W, - magic: Magic, - version: Version, - ) -> Result<(), Error> { - // Because the header contains a checksum of - // the body data, it must be written first. - let mut body = Vec::new(); - self.write_body(&mut body, magic, version)?; - - use Message::*; - // Note: because all match arms must have - // the same type, and the array length is - // part of the type, having at least one - // of length 12 checks that they are all - // of length 12, as they must be &[u8; 12]. - let command = match *self { - Version { .. } => b"version\0\0\0\0\0", - Verack { .. } => b"verack\0\0\0\0\0\0", - Ping { .. } => b"ping\0\0\0\0\0\0\0\0", - Pong { .. } => b"pong\0\0\0\0\0\0\0\0", - Reject { .. } => b"reject\0\0\0\0\0\0", - Addr { .. } => b"addr\0\0\0\0\0\0\0\0", - GetAddr { .. } => b"getaddr\0\0\0\0\0", - Block { .. } => b"block\0\0\0\0\0\0\0", - GetBlocks { .. } => b"getblocks\0\0\0", - Headers { .. } => b"headers\0\0\0\0\0", - GetHeaders { .. } => b"getheaders\0\0", - Inventory { .. } => b"inv\0\0\0\0\0\0\0\0\0", - GetData { .. } => b"getdata\0\0\0\0\0", - NotFound { .. } => b"notfound\0\0\0\0", - Tx { .. } => b"tx\0\0\0\0\0\0\0\0\0\0", - Mempool { .. } => b"mempool\0\0\0\0\0", - FilterLoad { .. } => b"filterload\0\0", - FilterAdd { .. } => b"filteradd\0\0\0", - FilterClear { .. } => b"filterclear\0", - MerkleBlock { .. } => b"merkleblock\0", - }; - - // Write the header into a stack buffer first before feeding that stack - // buffer into the async writer. This allows using the WriteBytesExt - // extension trait, which is only defined for sync Writers. - - // The header is 4+12+4+4=24 bytes long. - trace!(?command, body_len = body.len()); - let mut header = [0u8; 24]; - let mut header_writer = Cursor::new(&mut header[..]); - header_writer.write_all(&magic.0)?; - header_writer.write_all(command)?; - header_writer.write_u32::(body.len() as u32)?; - header_writer.write_all(&Sha256dChecksum::from(&body[..]).0)?; - - writer.write_all(&header).await?; - writer.write_all(&body).await?; - - Ok(()) - } - - /// Receive a message from the given async reader (e.g., a network stream). - #[instrument(level = "debug", skip(reader))] - pub async fn recv( - mut reader: R, - magic: Magic, - version: Version, - ) -> Result { - // Read the header into a stack buffer before trying to parse it. This - // allows using the ReadBytesExt extension trait, which is only defined - // for sync Readers. Then we can determine the expected message length, - // await that many bytes, and finally parse the message. - - // The header is 4+12+4+4=24 bytes long. - let header = { - let mut bytes = [0u8; 24]; - reader.read_exact(&mut bytes).await?; - bytes - }; - let mut header_reader = Cursor::new(&header[..]); - - // Read header data - let message_magic = Magic(header_reader.read_4_bytes()?); - let command = header_reader.read_12_bytes()?; - let body_len = header_reader.read_u32::()? as usize; - let checksum = Sha256dChecksum(header_reader.read_4_bytes()?); - trace!(?message_magic, ?command, body_len, ?checksum); - - ensure!( - magic == message_magic, - "supplied magic did not meet expectations", - ); - - // XXX bound the body_len value to avoid large attacker-controlled allocs - // XXX add a ChecksumReader(R) wrapper and avoid this - let body = { - let mut bytes = vec![0; body_len]; - reader.read_exact(&mut bytes).await?; - bytes - }; - - ensure!( - checksum == Sha256dChecksum::from(&body[..]), - "supplied message checksum does not match computed checksum" - ); - - let body_reader = Cursor::new(&body); - match &command { - b"version\0\0\0\0\0" => try_read_version(body_reader, version), - b"verack\0\0\0\0\0\0" => try_read_verack(body_reader, version), - b"ping\0\0\0\0\0\0\0\0" => try_read_ping(body_reader, version), - b"pong\0\0\0\0\0\0\0\0" => try_read_pong(body_reader, version), - b"reject\0\0\0\0\0\0" => try_read_reject(body_reader, version), - b"addr\0\0\0\0\0\0\0\0" => try_read_addr(body_reader, version), - b"getaddr\0\0\0\0\0" => try_read_getaddr(body_reader, version), - b"block\0\0\0\0\0\0\0" => try_read_block(body_reader, version), - b"getblocks\0\0\0" => try_read_getblocks(body_reader, version), - b"headers\0\0\0\0\0" => try_read_headers(body_reader, version), - b"getheaders\0\0" => try_read_getheaders(body_reader, version), - b"inv\0\0\0\0\0\0\0\0\0" => try_read_inv(body_reader, version), - b"getdata\0\0\0\0\0" => try_read_getdata(body_reader, version), - b"notfound\0\0\0\0" => try_read_notfound(body_reader, version), - b"tx\0\0\0\0\0\0\0\0\0\0" => try_read_tx(body_reader, version), - b"mempool\0\0\0\0\0" => try_read_mempool(body_reader, version), - b"filterload\0\0" => try_read_filterload(body_reader, version), - b"filteradd\0\0\0" => try_read_filteradd(body_reader, version), - b"filterclear\0" => try_read_filterclear(body_reader, version), - b"merkleblock\0" => try_read_merkleblock(body_reader, version), - _ => bail!("unknown command"), - } - } -} - -impl Message { - /// Write the body of the message into the given writer. This allows writing - /// the message body prior to writing the header, so that the header can - /// contain a checksum of the message body. - fn write_body(&self, mut writer: W, _m: Magic, _v: Version) -> Result<(), Error> { - use Message::*; - trace!(?self); - match *self { - Version { - ref version, - ref services, - ref timestamp, - ref address_recv, - ref address_from, - ref nonce, - ref user_agent, - ref start_height, - ref relay, - } => { - writer.write_u32::(version.0)?; - writer.write_u64::(services.0)?; - writer.write_i64::(timestamp.timestamp())?; - - let (recv_services, recv_addr) = address_recv; - writer.write_u64::(recv_services.0)?; - writer.write_socket_addr(*recv_addr)?; - - let (from_services, from_addr) = address_from; - writer.write_u64::(from_services.0)?; - writer.write_socket_addr(*from_addr)?; - - writer.write_u64::(nonce.0)?; - writer.write_string(&user_agent)?; - writer.write_u32::(start_height.0)?; - writer.write_u8(*relay as u8)?; - } - Verack => { /* Empty payload -- no-op */ } - Ping(nonce) => { - writer.write_u64::(nonce.0)?; - } - Pong(nonce) => { - writer.write_u64::(nonce.0)?; - } - _ => bail!("unimplemented message type"), - } - Ok(()) - } -} - -fn try_read_version( - mut reader: R, - _parsing_version: Version, -) -> Result { - Ok(Message::Version { - version: Version(reader.read_u32::()?), - services: Services(reader.read_u64::()?), - timestamp: Utc.timestamp(reader.read_i64::()?, 0), - address_recv: ( - Services(reader.read_u64::()?), - reader.read_socket_addr()?, - ), - address_from: ( - Services(reader.read_u64::()?), - reader.read_socket_addr()?, - ), - nonce: Nonce(reader.read_u64::()?), - user_agent: reader.read_string()?, - start_height: BlockHeight(reader.read_u32::()?), - relay: match reader.read_u8()? { - 0 => false, - 1 => true, - _ => bail!("non-bool value supplied in relay field"), - }, - }) -} - -fn try_read_verack(mut _reader: R, _version: Version) -> Result { - Ok(Message::Verack) -} - -fn try_read_ping(mut reader: R, _version: Version) -> Result { - Ok(Message::Ping(Nonce(reader.read_u64::()?))) -} - -fn try_read_pong(mut reader: R, _version: Version) -> Result { - Ok(Message::Pong(Nonce(reader.read_u64::()?))) -} - -#[instrument(level = "trace", skip(_reader, _version))] -fn try_read_reject(mut _reader: R, _version: Version) -> Result { - trace!("reject"); - bail!("unimplemented message type") -} - -#[instrument(level = "trace", skip(_reader, _version))] -fn try_read_addr(mut _reader: R, _version: Version) -> Result { - trace!("addr"); - bail!("unimplemented message type") -} - -#[instrument(level = "trace", skip(_reader, _version))] -fn try_read_getaddr(mut _reader: R, _version: Version) -> Result { - trace!("getaddr"); - bail!("unimplemented message type") -} - -#[instrument(level = "trace", skip(_reader, _version))] -fn try_read_block(mut _reader: R, _version: Version) -> Result { - trace!("block"); - bail!("unimplemented message type") -} - -#[instrument(level = "trace", skip(_reader, _version))] -fn try_read_getblocks(mut _reader: R, _version: Version) -> Result { - trace!("getblocks"); - bail!("unimplemented message type") -} - -#[instrument(level = "trace", skip(_reader, _version))] -fn try_read_headers(mut _reader: R, _version: Version) -> Result { - trace!("headers"); - bail!("unimplemented message type") -} - -#[instrument(level = "trace", skip(_reader, _version))] -fn try_read_getheaders(mut _reader: R, _version: Version) -> Result { - trace!("getheaders"); - bail!("unimplemented message type") -} - -#[instrument(level = "trace", skip(_reader, _version))] -fn try_read_inv(mut _reader: R, _version: Version) -> Result { - trace!("inv"); - bail!("unimplemented message type") -} - -#[instrument(level = "trace", skip(_reader, _version))] -fn try_read_getdata(mut _reader: R, _version: Version) -> Result { - trace!("getdata"); - bail!("unimplemented message type") -} - -#[instrument(level = "trace", skip(_reader, _version))] -fn try_read_notfound(mut _reader: R, _version: Version) -> Result { - trace!("notfound"); - bail!("unimplemented message type") -} - -#[instrument(level = "trace", skip(_reader, _version))] -fn try_read_tx(mut _reader: R, _version: Version) -> Result { - trace!("tx"); - bail!("unimplemented message type") -} - -#[instrument(level = "trace", skip(_reader, _version))] -fn try_read_mempool(mut _reader: R, _version: Version) -> Result { - trace!("mempool"); - bail!("unimplemented message type") -} - -#[instrument(level = "trace", skip(_reader, _version))] -fn try_read_filterload(mut _reader: R, _version: Version) -> Result { - trace!("filterload"); - bail!("unimplemented message type") -} - -#[instrument(level = "trace", skip(_reader, _version))] -fn try_read_filteradd(mut _reader: R, _version: Version) -> Result { - trace!("filteradd"); - bail!("unimplemented message type") -} - -#[instrument(level = "trace", skip(_reader, _version))] -fn try_read_filterclear(mut _reader: R, _version: Version) -> Result { - trace!("filterclear"); - bail!("unimplemented message type") -} - -#[instrument(level = "trace", skip(_reader, _version))] -fn try_read_merkleblock(mut _reader: R, _version: Version) -> Result { - trace!("merkleblock"); - bail!("unimplemented message type") -} - -#[cfg(test)] -mod tests { - use super::*; - use tokio::runtime::Runtime; - - #[test] - fn version_message_round_trip() { - use std::net::{IpAddr, Ipv4Addr, SocketAddr}; - let services = Services(0x1); - let timestamp = Utc.timestamp(1568000000, 0); - - let rt = Runtime::new().unwrap(); - - let v = Message::Version { - version: crate::constants::CURRENT_VERSION, - services, - timestamp, - address_recv: ( - services, - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 6)), 8233), - ), - address_from: ( - services, - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 6)), 8233), - ), - nonce: Nonce(0x9082_4908_8927_9238), - user_agent: "Zebra".to_owned(), - start_height: BlockHeight(540_000), - relay: true, - }; - - let v_bytes = rt.block_on(async { - let mut bytes = Vec::new(); - v.send( - &mut bytes, - crate::constants::magics::MAINNET, - crate::constants::CURRENT_VERSION, - ) - .await - .unwrap(); - bytes - }); - - let v_parsed = rt.block_on(async { - Message::recv( - Cursor::new(&v_bytes), - crate::constants::magics::MAINNET, - crate::constants::CURRENT_VERSION, - ) - .await - .unwrap() - }); - - assert_eq!(v, v_parsed); - } -} diff --git a/zebrad/src/commands/connect.rs b/zebrad/src/commands/connect.rs index a5e643c1c..583663f0c 100644 --- a/zebrad/src/commands/connect.rs +++ b/zebrad/src/commands/connect.rs @@ -26,7 +26,18 @@ impl Runnable for ConnectCmd { // Combine the connect future with an infinite wait // so that the program has to be explicitly killed and // won't die before all tracing messages are written. - let fut = futures_util::future::join(self.connect(), wait); + let fut = futures_util::future::join( + async { + match self.connect().await { + Ok(()) => {} + Err(e) => { + // Print any error that occurs. + error!(?e); + } + } + }, + wait, + ); let _ = app_reader() .state() @@ -43,14 +54,22 @@ impl ConnectCmd { use std::net::Shutdown; use chrono::Utc; - use tokio::net::TcpStream; + use tokio::{codec::Framed, net::TcpStream, prelude::*}; use zebra_chain::types::BlockHeight; - use zebra_network::{constants, protocol::message::*, types::*}; + use zebra_network::{ + constants, + protocol::{codec::*, message::*}, + types::*, + Network, + }; info!("connecting"); - let mut stream = TcpStream::connect(self.addr).await?; + let mut stream = Framed::new( + TcpStream::connect(self.addr).await?, + Codec::builder().for_network(Network::Mainnet).finish(), + ); let version = Message::Version { version: constants::CURRENT_VERSION, @@ -69,55 +88,22 @@ impl ConnectCmd { info!(version = ?version); - version - .send( - &mut stream, - constants::magics::MAINNET, - constants::CURRENT_VERSION, - ) - .await?; + stream.send(version).await?; + + let resp_version: Message = stream.next().await.expect("expected data")?; - let resp_version = Message::recv( - &mut stream, - constants::magics::MAINNET, - constants::CURRENT_VERSION, - ) - .await?; info!(resp_version = ?resp_version); - Message::Verack - .send( - &mut stream, - constants::magics::MAINNET, - constants::CURRENT_VERSION, - ) - .await?; + stream.send(Message::Verack).await?; - let resp_verack = Message::recv( - &mut stream, - constants::magics::MAINNET, - constants::CURRENT_VERSION, - ) - .await?; + let resp_verack = stream.next().await.expect("expected data")?; info!(resp_verack = ?resp_verack); - loop { - match Message::recv( - &mut stream, - constants::magics::MAINNET, - constants::CURRENT_VERSION, - ) - .await - { + while let Some(maybe_msg) = stream.next().await { + match maybe_msg { Ok(msg) => match msg { Message::Ping(nonce) => { - let pong = Message::Pong(nonce); - pong.send( - &mut stream, - constants::magics::MAINNET, - constants::CURRENT_VERSION, - ) - .await?; + stream.send(Message::Pong(nonce)).await?; } _ => warn!("Unknown message"), }, @@ -125,8 +111,6 @@ impl ConnectCmd { }; } - stream.shutdown(Shutdown::Both)?; - Ok(()) } } diff --git a/zebrad/src/lib.rs b/zebrad/src/lib.rs index fe8da8a2f..d13de42ff 100644 --- a/zebrad/src/lib.rs +++ b/zebrad/src/lib.rs @@ -10,6 +10,8 @@ #[macro_use] extern crate tracing; +#[macro_use] +extern crate failure; mod components;