parity-zcash/p2p/src/io/read_message.rs

105 lines
3.0 KiB
Rust
Raw Normal View History

2016-10-10 02:56:01 -07:00
use std::io;
use std::marker::PhantomData;
use futures::{Poll, Future, Async};
use network::Magic;
use message::{MessageResult, Error, Payload};
2016-10-12 02:24:56 -07:00
use io::{read_header, ReadHeader, read_payload, ReadPayload};
2016-10-10 02:56:01 -07:00
2016-10-12 02:24:56 -07:00
pub fn read_message<M, A>(a: A, magic: Magic, version: u32) -> ReadMessage<M, A>
where A: io::Read, M: Payload {
2016-10-12 02:24:56 -07:00
ReadMessage {
2016-10-10 02:56:01 -07:00
state: ReadMessageState::ReadHeader {
version: version,
future: read_header(a, magic),
},
message_type: PhantomData
}
}
enum ReadMessageState<M, A> {
ReadHeader {
version: u32,
future: ReadHeader<A>,
},
ReadPayload {
2016-10-12 02:24:56 -07:00
future: ReadPayload<M, A>,
2016-10-10 02:56:01 -07:00
},
Finished,
}
2016-10-12 02:24:56 -07:00
pub struct ReadMessage<M, A> {
2016-10-10 02:56:01 -07:00
state: ReadMessageState<M, A>,
message_type: PhantomData<M>,
}
impl<M, A> Future for ReadMessage<M, A> where A: io::Read, M: Payload {
2016-10-10 02:56:01 -07:00
type Item = (A, MessageResult<M>);
type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let (next, result) = match self.state {
ReadMessageState::ReadHeader { version, ref mut future } => {
let (read, header) = try_ready!(future.poll());
let header = match header {
Ok(header) => header,
2016-10-13 06:24:37 -07:00
Err(err) => return Ok((read, Err(err)).into()),
2016-10-10 02:56:01 -07:00
};
if header.command != M::command() {
2016-10-10 02:56:01 -07:00
return Ok((read, Err(Error::InvalidCommand)).into());
}
2016-10-12 02:24:56 -07:00
let future = read_payload(
2016-10-10 02:56:01 -07:00
read, version, header.len as usize, header.checksum,
);
let next = ReadMessageState::ReadPayload {
future: future,
};
(next, Async::NotReady)
},
ReadMessageState::ReadPayload { ref mut future } => {
let (read, payload) = try_ready!(future.poll());
(ReadMessageState::Finished, Async::Ready((read, payload)))
},
2016-10-12 02:24:56 -07:00
ReadMessageState::Finished => panic!("poll ReadMessage after it's done"),
2016-10-10 02:56:01 -07:00
};
self.state = next;
match result {
// by polling again, we register new future
Async::NotReady => self.poll(),
result => Ok(result)
}
}
}
#[cfg(test)]
mod tests {
	use futures::Future;
	use bytes::Bytes;
	use network::Magic;
	use message::Error;
	use message::types::{Ping, Pong};
	use super::read_message;

	// Fixtures are hex strings turned into `Bytes`. Layout of the well-formed one:
	// magic `f9beb4d9` | command "ping" (`70696e67`) zero-padded to 12 bytes |
	// payload length `08000000` (8, little-endian) | checksum `83c00c76` | 8-byte payload.

	#[test]
	fn test_read_message() {
		let raw: Bytes = "f9beb4d970696e6700000000000000000800000083c00c765845303b6da97786".into();
		let input: &[u8] = raw.as_ref();
		let nonce = u64::from_str_radix("8677a96d3b304558", 16).unwrap();

		// Matching magic and message type: the ping decodes.
		let decoded = read_message::<Ping, _>(input, Magic::Mainnet, 0).wait().unwrap().1;
		assert_eq!(decoded, Ok(Ping::new(nonce)));

		// Same bytes, but a different network magic is expected.
		let wrong_magic = read_message::<Ping, _>(input, Magic::Testnet, 0).wait().unwrap().1;
		assert_eq!(wrong_magic, Err(Error::InvalidMagic));

		// Valid header, but its command is not the one `Pong` expects.
		let wrong_command = read_message::<Pong, _>(input, Magic::Mainnet, 0).wait().unwrap().1;
		assert_eq!(wrong_command, Err(Error::InvalidCommand));
	}

	#[test]
	fn test_read_too_short_message() {
		// Final payload byte removed: only 7 of the announced 8 bytes are present.
		let raw: Bytes = "f9beb4d970696e6700000000000000000800000083c00c765845303b6da977".into();
		let outcome = read_message::<Ping, _>(raw.as_ref(), Magic::Mainnet, 0).wait();
		assert!(outcome.is_err());
	}

	#[test]
	fn test_read_message_with_invalid_checksum() {
		// One digit of the checksum field altered (`83c00c76` -> `83c01c76`).
		let raw: Bytes = "f9beb4d970696e6700000000000000000800000083c01c765845303b6da97786".into();
		let message = read_message::<Ping, _>(raw.as_ref(), Magic::Mainnet, 0).wait().unwrap().1;
		assert_eq!(message, Err(Error::InvalidChecksum));
	}
}