diff --git a/p2p/src/p2p.rs b/p2p/src/p2p.rs index a294e24c..755062ea 100644 --- a/p2p/src/p2p.rs +++ b/p2p/src/p2p.rs @@ -114,6 +114,10 @@ impl Context { info!("Inbound connections: ({}/{})", ic.0, ic.1); info!("Outbound connections: ({}/{})", oc.0, oc.1); + for channel in context.connections.channels().values() { + channel.session().maintain(); + } + let used_addresses = context.connections.addresses(); let max = (ic.1 + oc.1) as usize; let needed = context.connection_counter.outbound_connections_needed() as usize; diff --git a/p2p/src/protocol/mod.rs b/p2p/src/protocol/mod.rs index 54588f36..560ca4c9 100644 --- a/p2p/src/protocol/mod.rs +++ b/p2p/src/protocol/mod.rs @@ -14,6 +14,9 @@ pub trait Protocol: Send { /// Initialize the protocol. fn initialize(&mut self) {} + /// Maintain the protocol. + fn maintain(&mut self) {} + /// Handle the message. fn on_message(&mut self, command: &Command, payload: &Bytes) -> Result<(), Error>; diff --git a/p2p/src/protocol/ping.rs b/p2p/src/protocol/ping.rs index 76022bdb..adea6b97 100644 --- a/p2p/src/protocol/ping.rs +++ b/p2p/src/protocol/ping.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use time; use bytes::Bytes; use message::{Error, Payload, deserialize_payload}; use message::types::{Ping, Pong}; @@ -7,11 +8,27 @@ use protocol::Protocol; use net::PeerContext; use util::nonce::{NonceGenerator, RandomNonce}; +/// Time that must pass since last message from this peer, before we send ping request +const PING_INTERVAL_S: f64 = 60f64; +/// If peer has not responded to our ping request with pong during this interval => close connection +const MAX_PING_RESPONSE_TIME_S: f64 = 60f64; + +/// Ping state +#[derive(Debug, Copy, Clone, PartialEq)] +enum State { + /// Peer is sending us messages && we wait for `PING_INTERVAL_S` to pass before sending ping request + WaitingTimeout(f64), + /// Ping message is sent to the peer && we are waiting for pong response for `MAX_PING_RESPONSE_TIME_S` + WaitingPong(f64), +} + pub struct PingProtocol { /// Context context: Arc, /// Nonce generator. nonce_generator: T, + /// Ping state + state: State, /// Last nonce sent in the ping message. last_ping_nonce: Option, } @@ -21,6 +38,7 @@ impl PingProtocol { PingProtocol { context: context, nonce_generator: RandomNonce::default(), + state: State::WaitingTimeout(time::precise_time_s()), last_ping_nonce: None, } } @@ -29,13 +47,36 @@ impl PingProtocol { impl Protocol for PingProtocol { fn initialize(&mut self) { // bitcoind always sends ping, let's do the same - let nonce = self.nonce_generator.get(); - self.last_ping_nonce = Some(nonce); - let ping = Ping::new(nonce); - self.context.send_request(&ping); + self.maintain(); + } + + fn maintain(&mut self) { + let now = time::precise_time_s(); + match self.state { + State::WaitingTimeout(time) => { + // send ping request if enough time has passed since last message + if now - time > PING_INTERVAL_S { + let nonce = self.nonce_generator.get(); + self.state = State::WaitingPong(now); + self.last_ping_nonce = Some(nonce); + let ping = Ping::new(nonce); + self.context.send_request(&ping); + } + }, + State::WaitingPong(time) => { + // if no new messages from peer for last MAX_PING_RESPONSE_TIME_S => disconnect + if now - time > MAX_PING_RESPONSE_TIME_S { + trace!("closing connection to peer {}: no messages for last {} seconds", self.context.info().id, now - time); + self.context.close(); + } + }, + } } fn on_message(&mut self, command: &Command, payload: &Bytes) -> Result<(), Error> { + // we have received new message => do not close connection because of timeout + self.state = State::WaitingTimeout(time::precise_time_s()); + if command == &Ping::command() { let ping: Ping = try!(deserialize_payload(payload, self.context.info().version)); let pong = Pong::new(ping.nonce); diff --git a/p2p/src/session.rs b/p2p/src/session.rs index 81a3c61b..1d2e9c4a 100644 --- a/p2p/src/session.rs +++ b/p2p/src/session.rs @@ -54,6 +54,12 @@ impl Session { } } + pub fn maintain(&self) { + for protocol in self.protocols.lock().iter_mut() { + protocol.maintain(); + } + } + pub fn on_message(&self, command: Command, payload: Bytes) -> Result<(), Error> { self.stats().lock().report_recv(command.clone(), payload.len());