From 5b8bf322b0c6dc99f062a0f5382a82fb7549273d Mon Sep 17 00:00:00 2001 From: debris Date: Fri, 4 Aug 2017 14:05:58 +0200 Subject: [PATCH] remove redundant recursion in p2p::io --- p2p/src/io/handshake.rs | 180 +++++++++++++++------------------ p2p/src/io/read_any_message.rs | 53 +++++----- p2p/src/io/read_message.rs | 55 +++++----- 3 files changed, 127 insertions(+), 161 deletions(-) diff --git a/p2p/src/io/handshake.rs b/p2p/src/io/handshake.rs index be91a835..b429b317 100644 --- a/p2p/src/io/handshake.rs +++ b/p2p/src/io/handshake.rs @@ -58,7 +58,6 @@ enum HandshakeState { version: Option, future: ReadMessage, }, - Finished, } enum AcceptHandshakeState { @@ -74,7 +73,6 @@ enum AcceptHandshakeState { version: Option, future: WriteMessage, }, - Finished, } pub struct Handshake { @@ -98,65 +96,56 @@ impl Future for Handshake where A: AsyncRead + AsyncWrite { type Error = io::Error; fn poll(&mut self) -> Poll { - let (next, result) = match self.state { - HandshakeState::SendVersion(ref mut future) => { - let (stream, _) = try_ready!(future.poll()); - (HandshakeState::ReceiveVersion(read_message(stream, self.magic, 0)), Async::NotReady) - }, - HandshakeState::ReceiveVersion(ref mut future) => { - let (stream, version) = try_ready!(future.poll()); - let version = match version { - Ok(version) => version, - Err(err) => return Ok((stream, Err(err.into())).into()), - }; + loop { + let next_state = match self.state { + HandshakeState::SendVersion(ref mut future) => { + let (stream, _) = try_ready!(future.poll()); + HandshakeState::ReceiveVersion(read_message(stream, self.magic, 0)) + }, + HandshakeState::ReceiveVersion(ref mut future) => { + let (stream, version) = try_ready!(future.poll()); + let version = match version { + Ok(version) => version, + Err(err) => return Ok((stream, Err(err.into())).into()), + }; - if version.version() < self.min_version { - return Ok((stream, Err(Error::InvalidVersion)).into()); - } - if let (Some(self_nonce), Some(nonce)) = (self.nonce, version.nonce()) { - if self_nonce == nonce { + if version.version() < self.min_version { return Ok((stream, Err(Error::InvalidVersion)).into()); } - } + if let (Some(self_nonce), Some(nonce)) = (self.nonce, version.nonce()) { + if self_nonce == nonce { + return Ok((stream, Err(Error::InvalidVersion)).into()); + } + } - let next = HandshakeState::SendVerack { - version: Some(version), - future: write_message(stream, verack_message(self.magic)), - }; + HandshakeState::SendVerack { + version: Some(version), + future: write_message(stream, verack_message(self.magic)), + } + }, + HandshakeState::SendVerack { ref mut version, ref mut future } => { + let (stream, _) = try_ready!(future.poll()); - (next, Async::NotReady) - }, - HandshakeState::SendVerack { ref mut version, ref mut future } => { - let (stream, _) = try_ready!(future.poll()); + let version = version.take().expect("verack must be preceded by version"); - let version = version.take().expect("verack must be preceded by version"); + HandshakeState::ReceiveVerack { + version: Some(version), + future: read_message(stream, self.magic, 0), + } + }, + HandshakeState::ReceiveVerack { ref mut version, ref mut future } => { + let (stream, _verack) = try_ready!(future.poll()); + let version = version.take().expect("verack must be preceded by version"); - let next = HandshakeState::ReceiveVerack { - version: Some(version), - future: read_message(stream, self.magic, 0), - }; + let result = HandshakeResult { + negotiated_version: negotiate_version(self.version, version.version()), + version: version, + }; - (next, Async::NotReady) - }, - HandshakeState::ReceiveVerack { ref mut version, ref mut future } => { - let (stream, _verack) = try_ready!(future.poll()); - let version = version.take().expect("verack must be preceded by version"); - - let result = HandshakeResult { - negotiated_version: negotiate_version(self.version, version.version()), - version: version, - }; - - (HandshakeState::Finished, Async::Ready((stream, Ok(result)))) - }, - HandshakeState::Finished => panic!("poll Handshake after it's done"), - }; - - self.state = next; - match result { - // by polling again, we register new future - Async::NotReady => self.poll(), - result => Ok(result) + return Ok(Async::Ready((stream, Ok(result)))); + }, + }; + self.state = next_state; } } } @@ -166,60 +155,51 @@ impl Future for AcceptHandshake where A: AsyncRead + AsyncWrite { type Error = io::Error; fn poll(&mut self) -> Poll { - let (next, result) = match self.state { - AcceptHandshakeState::ReceiveVersion { ref mut local_version, ref mut future } => { - let (stream, version) = try_ready!(future.poll()); - let version = match version { - Ok(version) => version, - Err(err) => return Ok((stream, Err(err.into())).into()), - }; + loop { + let next_state = match self.state { + AcceptHandshakeState::ReceiveVersion { ref mut local_version, ref mut future } => { + let (stream, version) = try_ready!(future.poll()); + let version = match version { + Ok(version) => version, + Err(err) => return Ok((stream, Err(err.into())).into()), + }; - if version.version() < self.min_version { - return Ok((stream, Err(Error::InvalidVersion)).into()); - } - if let (Some(self_nonce), Some(nonce)) = (self.nonce, version.nonce()) { - if self_nonce == nonce { + if version.version() < self.min_version { return Ok((stream, Err(Error::InvalidVersion)).into()); } - } + if let (Some(self_nonce), Some(nonce)) = (self.nonce, version.nonce()) { + if self_nonce == nonce { + return Ok((stream, Err(Error::InvalidVersion)).into()); + } + } - let local_version = local_version.take().expect("local version must be set"); - let next = AcceptHandshakeState::SendVersion { - version: Some(version), - future: write_message(stream, version_message(self.magic, local_version)), - }; + let local_version = local_version.take().expect("local version must be set"); + AcceptHandshakeState::SendVersion { + version: Some(version), + future: write_message(stream, version_message(self.magic, local_version)), + } + }, + AcceptHandshakeState::SendVersion { ref mut version, ref mut future } => { + let (stream, _) = try_ready!(future.poll()); + AcceptHandshakeState::SendVerack { + version: version.take(), + future: write_message(stream, verack_message(self.magic)), + } + }, + AcceptHandshakeState::SendVerack { ref mut version, ref mut future } => { + let (stream, _) = try_ready!(future.poll()); - (next, Async::NotReady) - }, - AcceptHandshakeState::SendVersion { ref mut version, ref mut future } => { - let (stream, _) = try_ready!(future.poll()); - let next = AcceptHandshakeState::SendVerack { - version: version.take(), - future: write_message(stream, verack_message(self.magic)), - }; + let version = version.take().expect("verack must be preceded by version"); - (next, Async::NotReady) - }, - AcceptHandshakeState::SendVerack { ref mut version, ref mut future } => { - let (stream, _) = try_ready!(future.poll()); + let result = HandshakeResult { + negotiated_version: negotiate_version(self.version, version.version()), + version: version, + }; - let version = version.take().expect("verack must be preceded by version"); - - let result = HandshakeResult { - negotiated_version: negotiate_version(self.version, version.version()), - version: version, - }; - - (AcceptHandshakeState::Finished, Async::Ready((stream, Ok(result)))) - }, - AcceptHandshakeState::Finished => panic!("poll AcceptHandshake after it's done"), - }; - - self.state = next; - match result { - // by polling again, we register new future - Async::NotReady => self.poll(), - result => Ok(result) + return Ok(Async::Ready((stream, Ok(result)))); + }, + }; + self.state = next_state; } } } diff --git a/p2p/src/io/read_any_message.rs b/p2p/src/io/read_any_message.rs index 63e43eff..254e2a39 100644 --- a/p2p/src/io/read_any_message.rs +++ b/p2p/src/io/read_any_message.rs @@ -20,7 +20,6 @@ pub enum ReadAnyMessageState { header: MessageHeader, future: ReadExact }, - Finished, } pub struct ReadAnyMessage { @@ -32,36 +31,30 @@ impl Future for ReadAnyMessage where A: AsyncRead { type Error = io::Error; fn poll(&mut self) -> Poll { - let (next, result) = match self.state { - ReadAnyMessageState::ReadHeader(ref mut header) => { - let (stream, header) = try_ready!(header.poll()); - let header = match header { - Ok(header) => header, - Err(err) => return Ok(Err(err).into()), - }; - let future = read_exact(stream, Bytes::new_with_len(header.len as usize)); - let next = ReadAnyMessageState::ReadPayload { - header: header, - future: future, - }; - (next, Async::NotReady) - }, - ReadAnyMessageState::ReadPayload { ref mut header, ref mut future } => { - let (_stream, bytes) = try_ready!(future.poll()); - if checksum(&bytes) != header.checksum { - return Ok(Err(Error::InvalidChecksum).into()); - } - let next = ReadAnyMessageState::Finished; - (next, Ok((header.command.clone(), bytes)).into()) - }, - ReadAnyMessageState::Finished => panic!("poll ReadAnyMessage after it's done"), - }; + loop { + let next_state = match self.state { + ReadAnyMessageState::ReadHeader(ref mut header) => { + let (stream, header) = try_ready!(header.poll()); + let header = match header { + Ok(header) => header, + Err(err) => return Ok(Err(err).into()), + }; + ReadAnyMessageState::ReadPayload { + future: read_exact(stream, Bytes::new_with_len(header.len as usize)), + header: header, + } + }, + ReadAnyMessageState::ReadPayload { ref mut header, ref mut future } => { + let (_stream, bytes) = try_ready!(future.poll()); + if checksum(&bytes) != header.checksum { + return Ok(Err(Error::InvalidChecksum).into()); + } - self.state = next; - match result { - // by polling again, we register new future - Async::NotReady => self.poll(), - result => Ok(result) + return Ok(Async::Ready(Ok((header.command.clone(), bytes)))); + }, + }; + + self.state = next_state; } } } diff --git a/p2p/src/io/read_message.rs b/p2p/src/io/read_message.rs index a4935afc..ad5420a7 100644 --- a/p2p/src/io/read_message.rs +++ b/p2p/src/io/read_message.rs @@ -25,7 +25,6 @@ enum ReadMessageState { ReadPayload { future: ReadPayload, }, - Finished, } pub struct ReadMessage { @@ -38,36 +37,30 @@ impl Future for ReadMessage where A: AsyncRead, M: Payload { type Error = io::Error; fn poll(&mut self) -> Poll { - let (next, result) = match self.state { - ReadMessageState::ReadHeader { version, ref mut future } => { - let (read, header) = try_ready!(future.poll()); - let header = match header { - Ok(header) => header, - Err(err) => return Ok((read, Err(err)).into()), - }; - if header.command != M::command() { - return Ok((read, Err(Error::InvalidCommand)).into()); - } - let future = read_payload( - read, version, header.len as usize, header.checksum, - ); - let next = ReadMessageState::ReadPayload { - future: future, - }; - (next, Async::NotReady) - }, - ReadMessageState::ReadPayload { ref mut future } => { - let (read, payload) = try_ready!(future.poll()); - (ReadMessageState::Finished, Async::Ready((read, payload))) - }, - ReadMessageState::Finished => panic!("poll ReadMessage after it's done"), - }; - - self.state = next; - match result { - // by polling again, we register new future - Async::NotReady => self.poll(), - result => Ok(result) + loop { + let next_state = match self.state { + ReadMessageState::ReadHeader { version, ref mut future } => { + let (read, header) = try_ready!(future.poll()); + let header = match header { + Ok(header) => header, + Err(err) => return Ok((read, Err(err)).into()), + }; + if header.command != M::command() { + return Ok((read, Err(Error::InvalidCommand)).into()); + } + let future = read_payload( + read, version, header.len as usize, header.checksum, + ); + ReadMessageState::ReadPayload { + future: future, + } + }, + ReadMessageState::ReadPayload { ref mut future } => { + let (read, payload) = try_ready!(future.poll()); + return Ok(Async::Ready((read, payload))); + }, + }; + self.state = next_state; } } }