Use futures:🔒:Mutex for the nonce set

This commit is contained in:
teor 2021-04-19 15:32:31 +10:00 committed by Deirdre Connolly
parent 2ed8bb00cf
commit 3f45735f3f
1 changed files with 18 additions and 10 deletions

View File

@ -3,7 +3,7 @@ use std::{
future::Future, future::Future,
net::SocketAddr, net::SocketAddr,
pin::Pin, pin::Pin,
sync::{Arc, Mutex}, sync::Arc,
task::{Context, Poll}, task::{Context, Poll},
}; };
@ -46,7 +46,7 @@ pub struct Handshake<S> {
inbound_service: S, inbound_service: S,
timestamp_collector: mpsc::Sender<MetaAddr>, timestamp_collector: mpsc::Sender<MetaAddr>,
inv_collector: broadcast::Sender<(InventoryHash, SocketAddr)>, inv_collector: broadcast::Sender<(InventoryHash, SocketAddr)>,
nonces: Arc<Mutex<HashSet<Nonce>>>, nonces: Arc<futures::lock::Mutex<HashSet<Nonce>>>,
user_agent: String, user_agent: String,
our_services: PeerServices, our_services: PeerServices,
relay: bool, relay: bool,
@ -139,7 +139,7 @@ where
let (tx, _rx) = mpsc::channel(1); let (tx, _rx) = mpsc::channel(1);
tx tx
}); });
let nonces = Arc::new(Mutex::new(HashSet::new())); let nonces = Arc::new(futures::lock::Mutex::new(HashSet::new()));
let user_agent = self.user_agent.unwrap_or_else(|| "".to_string()); let user_agent = self.user_agent.unwrap_or_else(|| "".to_string());
let our_services = self.our_services.unwrap_or_else(PeerServices::empty); let our_services = self.our_services.unwrap_or_else(PeerServices::empty);
let relay = self.relay.unwrap_or(false); let relay = self.relay.unwrap_or(false);
@ -189,17 +189,19 @@ pub async fn negotiate_version(
peer_conn: &mut Framed<TcpStream, Codec>, peer_conn: &mut Framed<TcpStream, Codec>,
addr: &SocketAddr, addr: &SocketAddr,
config: Config, config: Config,
nonces: Arc<Mutex<HashSet<Nonce>>>, nonces: Arc<futures::lock::Mutex<HashSet<Nonce>>>,
user_agent: String, user_agent: String,
our_services: PeerServices, our_services: PeerServices,
relay: bool, relay: bool,
) -> Result<(Version, PeerServices), HandshakeError> { ) -> Result<(Version, PeerServices), HandshakeError> {
// Create a random nonce for this connection // Create a random nonce for this connection
let local_nonce = Nonce::default(); let local_nonce = Nonce::default();
nonces // # Correctness
.lock() //
.expect("mutex should be unpoisoned") // It is ok to wait for the lock here, because handshakes have a short
.insert(local_nonce); // timeout, and the async mutex will be released when the task times
// out.
nonces.lock().await.insert(local_nonce);
// Don't leak our exact clock skew to our peers. On the other hand, // Don't leak our exact clock skew to our peers. On the other hand,
// we can't deviate too much, or zcashd will get confused. // we can't deviate too much, or zcashd will get confused.
@ -258,9 +260,15 @@ pub async fn negotiate_version(
Err(HandshakeError::UnexpectedMessage(Box::new(remote_msg)))? Err(HandshakeError::UnexpectedMessage(Box::new(remote_msg)))?
}; };
// Check for nonce reuse, indicating self-connection. // Check for nonce reuse, indicating self-connection
//
// # Correctness
//
// We must wait for the lock before we continue with the connection, to avoid
// self-connection. If the connection times out, the async lock will be
// released.
let nonce_reuse = { let nonce_reuse = {
let mut locked_nonces = nonces.lock().expect("mutex should be unpoisoned"); let mut locked_nonces = nonces.lock().await;
let nonce_reuse = locked_nonces.contains(&remote_nonce); let nonce_reuse = locked_nonces.contains(&remote_nonce);
// Regardless of whether we observed nonce reuse, clean up the nonce set. // Regardless of whether we observed nonce reuse, clean up the nonce set.
locked_nonces.remove(&local_nonce); locked_nonces.remove(&local_nonce);