diff --git a/zebra-network/src/message.rs b/zebra-network/src/message.rs index 617076cb2..d7185ac1c 100644 --- a/zebra-network/src/message.rs +++ b/zebra-network/src/message.rs @@ -5,12 +5,11 @@ use std::net; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use chrono::{DateTime, TimeZone, Utc}; +use failure::Error; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use zebra_chain::{ - serialization::{ - ReadZcashExt, SerializationError, WriteZcashExt, ZcashDeserialize, ZcashSerialize, - }, + serialization::{ReadZcashExt, WriteZcashExt, ZcashDeserialize, ZcashSerialize}, transaction::Transaction, types::{BlockHeight, Sha256dChecksum}, }; @@ -290,7 +289,7 @@ impl Message { mut writer: W, magic: Magic, version: Version, - ) -> Result<(), SerializationError> { + ) -> Result<(), Error> { // Because the header contains a checksum of // the body data, it must be written first. let mut body = Vec::new(); @@ -348,9 +347,7 @@ impl Message { mut reader: R, magic: Magic, version: Version, - ) -> Result { - use SerializationError::ParseError; - + ) -> Result { // Read the header into a stack buffer before trying to parse it. This // allows using the ReadBytesExt extension trait, which is only defined // for sync Readers. Then we can determine the expected message length, @@ -370,9 +367,10 @@ impl Message { let body_len = header_reader.read_u32::()? as usize; let checksum = Sha256dChecksum(header_reader.read_4_bytes()?); - if magic != message_magic { - return Err(ParseError("Message has incorrect magic value")); - } + ensure!( + magic == message_magic, + "supplied magic did not meet expectations", + ); // XXX bound the body_len value to avoid large attacker-controlled allocs // XXX add a ChecksumReader(R) wrapper and avoid this @@ -382,9 +380,10 @@ impl Message { bytes }; - if checksum != Sha256dChecksum::from(&body[..]) { - return Err(SerializationError::ParseError("checksum does not match")); - } + ensure!( + checksum == Sha256dChecksum::from(&body[..]), + "supplied message checksum does not match computed checksum" + ); let body_reader = Cursor::new(&body); match &command { @@ -408,7 +407,7 @@ impl Message { b"filteradd\0\0\0" => try_read_filteradd(body_reader, version), b"filterclear\0" => try_read_filterclear(body_reader, version), b"merkleblock\0" => try_read_merkleblock(body_reader, version), - _ => Err(ParseError("Unknown command")), + _ => bail!("unknown command"), } } } @@ -417,12 +416,7 @@ impl Message { /// Write the body of the message into the given writer. This allows writing /// the message body prior to writing the header, so that the header can /// contain a checksum of the message body. - fn write_body( - &self, - mut writer: W, - _m: Magic, - _v: Version, - ) -> Result<(), SerializationError> { + fn write_body(&self, mut writer: W, _m: Magic, _v: Version) -> Result<(), Error> { use Message::*; match *self { Version { @@ -460,7 +454,7 @@ impl Message { Pong(nonce) => { writer.write_u64::(nonce.0)?; } - _ => unimplemented!(), + _ => bail!("unimplemented message type"), } Ok(()) } @@ -469,7 +463,7 @@ impl Message { fn try_read_version( mut reader: R, _parsing_version: Version, -) -> Result { +) -> Result { Ok(Message::Version { version: Version(reader.read_u32::()?), services: Services(reader.read_u64::()?), @@ -488,142 +482,85 @@ fn try_read_version( relay: match reader.read_u8()? { 0 => false, 1 => true, - _ => return Err(SerializationError::ParseError("non-bool value")), + _ => bail!("non-bool value supplied in relay field"), }, }) } -fn try_read_verack( - mut _reader: R, - _version: Version, -) -> Result { +fn try_read_verack(mut _reader: R, _version: Version) -> Result { Ok(Message::Verack) } -fn try_read_ping( - mut reader: R, - _version: Version, -) -> Result { +fn try_read_ping(mut reader: R, _version: Version) -> Result { Ok(Message::Ping(Nonce(reader.read_u64::()?))) } -fn try_read_pong( - mut reader: R, - _version: Version, -) -> Result { +fn try_read_pong(mut reader: R, _version: Version) -> Result { Ok(Message::Pong(Nonce(reader.read_u64::()?))) } -fn try_read_reject( - mut _reader: R, - _version: Version, -) -> Result { - unimplemented!() +fn try_read_reject(mut _reader: R, _version: Version) -> Result { + bail!("unimplemented message type") } -fn try_read_addr( - mut _reader: R, - _version: Version, -) -> Result { - unimplemented!() +fn try_read_addr(mut _reader: R, _version: Version) -> Result { + bail!("unimplemented message type") } -fn try_read_getaddr( - mut _reader: R, - _version: Version, -) -> Result { - unimplemented!() +fn try_read_getaddr(mut _reader: R, _version: Version) -> Result { + bail!("unimplemented message type") } -fn try_read_block( - mut _reader: R, - _version: Version, -) -> Result { - unimplemented!() +fn try_read_block(mut _reader: R, _version: Version) -> Result { + bail!("unimplemented message type") } -fn try_read_getblocks( - mut _reader: R, - _version: Version, -) -> Result { - unimplemented!() +fn try_read_getblocks(mut _reader: R, _version: Version) -> Result { + bail!("unimplemented message type") } -fn try_read_headers( - mut _reader: R, - _version: Version, -) -> Result { - unimplemented!() +fn try_read_headers(mut _reader: R, _version: Version) -> Result { + bail!("unimplemented message type") } -fn try_read_getheaders( - mut _reader: R, - _version: Version, -) -> Result { - unimplemented!() +fn try_read_getheaders(mut _reader: R, _version: Version) -> Result { + bail!("unimplemented message type") } -fn try_read_inv( - mut _reader: R, - _version: Version, -) -> Result { - unimplemented!() +fn try_read_inv(mut _reader: R, _version: Version) -> Result { + bail!("unimplemented message type") } -fn try_read_getdata( - mut _reader: R, - _version: Version, -) -> Result { - unimplemented!() +fn try_read_getdata(mut _reader: R, _version: Version) -> Result { + bail!("unimplemented message type") } -fn try_read_notfound( - mut _reader: R, - _version: Version, -) -> Result { - unimplemented!() +fn try_read_notfound(mut _reader: R, _version: Version) -> Result { + bail!("unimplemented message type") } -fn try_read_tx( - mut _reader: R, - _version: Version, -) -> Result { - unimplemented!() +fn try_read_tx(mut _reader: R, _version: Version) -> Result { + bail!("unimplemented message type") } -fn try_read_mempool( - mut _reader: R, - _version: Version, -) -> Result { - unimplemented!() +fn try_read_mempool(mut _reader: R, _version: Version) -> Result { + bail!("unimplemented message type") } -fn try_read_filterload( - mut _reader: R, - _version: Version, -) -> Result { - unimplemented!() +fn try_read_filterload(mut _reader: R, _version: Version) -> Result { + bail!("unimplemented message type") } -fn try_read_filteradd( - mut _reader: R, - _version: Version, -) -> Result { - unimplemented!() +fn try_read_filteradd(mut _reader: R, _version: Version) -> Result { + bail!("unimplemented message type") } -fn try_read_filterclear( - mut _reader: R, - _version: Version, -) -> Result { - unimplemented!() +fn try_read_filterclear(mut _reader: R, _version: Version) -> Result { + bail!("unimplemented message type") } -fn try_read_merkleblock( - mut _reader: R, - _version: Version, -) -> Result { - unimplemented!() +fn try_read_merkleblock(mut _reader: R, _version: Version) -> Result { + bail!("unimplemented message type") } #[cfg(test)]