Refactor and document correctness for std::sync::Mutex<AddressBook>

teor 2021-04-19 16:04:24 +10:00 committed by Deirdre Connolly
parent 905b90d6a1
commit 0203d1475a
6 changed files with 73 additions and 41 deletions

View File

@@ -15,7 +15,7 @@ use crate::{constants, types::MetaAddr, PeerAddrState};
 /// A database of peers, their advertised services, and information on when they
 /// were last seen.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct AddressBook {
     /// Each known peer address has a matching `MetaAddr`
     by_addr: HashMap<SocketAddr, MetaAddr>,

View File

@@ -1,8 +1,4 @@
-use std::{
-    mem,
-    sync::{Arc, Mutex},
-    time::Duration,
-};
+use std::{mem, sync::Arc, time::Duration};
 
 use futures::stream::{FuturesUnordered, StreamExt};
 use tokio::time::{sleep, sleep_until, timeout, Sleep};
@@ -105,7 +101,7 @@ use crate::{constants, types::MetaAddr, AddressBook, BoxError, Request, Response
 // * draw arrow from the "peer message" box into the `Responded` state box
 // * make the "disjoint states" box include `AttemptPending`
 pub(super) struct CandidateSet<S> {
-    pub(super) peer_set: Arc<Mutex<AddressBook>>,
+    pub(super) address_book: Arc<std::sync::Mutex<AddressBook>>,
     pub(super) peer_service: S,
     next_peer_min_wait: Sleep,
 }
@@ -123,10 +119,13 @@ where
     /// are initiated at least `MIN_PEER_CONNECTION_INTERVAL` apart.
     const MIN_PEER_CONNECTION_INTERVAL: Duration = Duration::from_millis(100);
 
-    /// Uses `peer_set` and `peer_service` to manage a [`CandidateSet`] of peers.
-    pub fn new(peer_set: Arc<Mutex<AddressBook>>, peer_service: S) -> CandidateSet<S> {
+    /// Uses `address_book` and `peer_service` to manage a [`CandidateSet`] of peers.
+    pub fn new(
+        address_book: Arc<std::sync::Mutex<AddressBook>>,
+        peer_service: S,
+    ) -> CandidateSet<S> {
         CandidateSet {
-            peer_set,
+            address_book,
             peer_service,
             next_peer_min_wait: sleep(Duration::from_secs(0)),
         }
@@ -163,9 +162,11 @@ where
         for _ in 0..constants::GET_ADDR_FANOUT {
             // CORRECTNESS
             //
-            // avoid deadlocks when there are no connected peers, and:
+            // Use a timeout to avoid deadlocks when there are no connected
+            // peers, and:
             // - we're waiting on a handshake to complete so there are peers, or
-            // - another task that handles or adds peers is waiting on this task to complete.
+            // - another task that handles or adds peers is waiting on this task
+            //   to complete.
             let peer_service =
                 match timeout(constants::REQUEST_TIMEOUT, self.peer_service.ready_and()).await {
                     // update must only return an error for permanent failures
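The timeout discipline described in the comment above can be shown on its own. A minimal runnable sketch, where the 20-second constant is an assumed value standing in for `constants::REQUEST_TIMEOUT`, and the hypothetical `ready_peer` future stands in for `self.peer_service.ready_and()`:

use std::time::Duration;
use tokio::time::timeout;

// Assumed value, standing in for `constants::REQUEST_TIMEOUT`.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(20);

// Hypothetical stand-in for `peer_service.ready_and()`; in the real
// crawler, this future can stall indefinitely when there are no peers.
async fn ready_peer() {}

#[tokio::main]
async fn main() {
    // Bounding the wait means this task cannot deadlock against a
    // handshake, or against another task that is waiting on it.
    match timeout(REQUEST_TIMEOUT, ready_peer()).await {
        Ok(()) => println!("peer service ready"),
        Err(elapsed) => println!("skipping fanout request: {}", elapsed),
    }
}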
@@ -185,20 +186,31 @@ where
             match rsp {
                 Ok(Response::Peers(rsp_addrs)) => {
                     // Filter new addresses to ensure that gossiped addresses are actually new
-                    let peer_set = &self.peer_set;
+                    let address_book = &self.address_book;
+
+                    // # Correctness
+                    //
+                    // Briefly hold the address book threaded mutex, each time we
+                    // check an address.
+                    //
                     // TODO: reduce mutex contention by moving the filtering into
-                    // the address book itself
+                    // the address book itself (#1976)
                     let new_addrs = rsp_addrs
                         .iter()
-                        .filter(|meta| !peer_set.lock().unwrap().contains_addr(&meta.addr))
+                        .filter(|meta| !address_book.lock().unwrap().contains_addr(&meta.addr))
                         .collect::<Vec<_>>();
                     trace!(
                         ?rsp_addrs,
                         new_addr_count = ?new_addrs.len(),
                         "got response to GetPeers"
                     );
                     // New addresses are deserialized in the `NeverAttempted` state
-                    peer_set
+                    //
+                    // # Correctness
+                    //
+                    // Briefly hold the address book threaded mutex, to extend
+                    // the address list.
+                    address_book
                         .lock()
                         .unwrap()
                         .extend(new_addrs.into_iter().cloned());
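The "briefly hold the mutex" rule in this hunk can be isolated into a sketch. Here a `HashSet<SocketAddr>` stands in for the `AddressBook`; each `lock()` guard is a temporary that is dropped at the end of the closure body, so the mutex is held for a single `contains` call at a time:

use std::{
    collections::HashSet,
    net::SocketAddr,
    sync::{Arc, Mutex},
};

fn filter_new_addrs(
    book: &Arc<Mutex<HashSet<SocketAddr>>>,
    gossiped: &[SocketAddr],
) -> Vec<SocketAddr> {
    gossiped
        .iter()
        // The guard only lives for this one `contains` call, and the
        // surrounding function is synchronous, so the lock can never
        // be held across an `.await`.
        .filter(|addr| !book.lock().unwrap().contains(*addr))
        .cloned()
        .collect()
}

The trade-off is one lock acquisition per gossiped address, which is why the TODO above suggests moving the filtering into the address book itself, so the lock is taken once per response.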
@@ -242,9 +254,10 @@ where
         let mut sleep = sleep_until(current_deadline + Self::MIN_PEER_CONNECTION_INTERVAL);
         mem::swap(&mut self.next_peer_min_wait, &mut sleep);
 
-        // CORRECTNESS
+        // # Correctness
         //
-        // In this critical section, we hold the address mutex.
+        // In this critical section, we hold the address mutex, blocking the
+        // current thread, and all async tasks scheduled on that thread.
         //
         // To avoid deadlocks, the critical section:
         // - must not acquire any other locks
@@ -253,17 +266,17 @@
         // To avoid hangs, any computation in the critical section should
         // be kept to a minimum.
         let reconnect = {
-            let mut peer_set_guard = self.peer_set.lock().unwrap();
-            // It's okay to early return here because we're returning None
-            // instead of yielding the next connection.
-            let reconnect = peer_set_guard.reconnection_peers().next()?;
+            let mut guard = self.address_book.lock().unwrap();
+            // It's okay to return without sleeping here, because we're returning
+            // `None`. We only need to sleep before yielding an address.
+            let reconnect = guard.reconnection_peers().next()?;
 
             let reconnect = MetaAddr::new_reconnect(&reconnect.addr, &reconnect.services);
-            peer_set_guard.update(reconnect);
+            guard.update(reconnect);
             reconnect
         };
 
-        // This is the line that is most relevant to the above ## Security section
+        // SECURITY: rate-limit new candidate connections
         sleep.await;
 
         Some(reconnect)
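The critical-section rules above translate into a small, self-contained pattern: keep the guard in an inner block so it is dropped before any `.await`. A sketch with hypothetical types (a `Vec<String>` stands in for the address book, and the 100 ms sleep mirrors `MIN_PEER_CONNECTION_INTERVAL`):

use std::sync::{Arc, Mutex};
use tokio::time::{sleep, Duration};

async fn next_candidate(book: Arc<Mutex<Vec<String>>>) -> Option<String> {
    let candidate = {
        // Critical section: no other locks, no `.await`s, minimal work.
        let mut guard = book.lock().unwrap();
        guard.pop()?
    }; // `guard` is dropped here, before the await point below

    // Rate-limit while holding no locks, so other threads and tasks
    // can use the address book during the sleep.
    sleep(Duration::from_millis(100)).await;
    Some(candidate)
}

#[tokio::main]
async fn main() {
    let book = Arc::new(Mutex::new(vec!["203.0.113.1:8233".to_string()]));
    println!("{:?}", next_candidate(book).await);
}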
@@ -272,6 +285,10 @@ where
     /// Mark `addr` as a failed peer.
     pub fn report_failed(&mut self, addr: &MetaAddr) {
         let addr = MetaAddr::new_errored(&addr.addr, &addr.services);
-        self.peer_set.lock().unwrap().update(addr);
+        // # Correctness
+        //
+        // Briefly hold the address book threaded mutex, to update the state for
+        // a single address.
+        self.address_book.lock().unwrap().update(addr);
     }
 }

View File

@@ -3,10 +3,7 @@
 // Portions of this submodule were adapted from tower-balance,
 // which is (c) 2019 Tower Contributors (MIT licensed).
 
-use std::{
-    net::SocketAddr,
-    sync::{Arc, Mutex},
-};
+use std::{net::SocketAddr, sync::Arc};
 
 use futures::{
     channel::mpsc,
@@ -65,7 +62,7 @@ pub async fn init<S>(
     inbound_service: S,
 ) -> (
     Buffer<BoxService<Request, Response, BoxError>, Request>,
-    Arc<Mutex<AddressBook>>,
+    Arc<std::sync::Mutex<AddressBook>>,
 )
 where
     S: Service<Request, Response = Response, Error = BoxError> + Clone + Send + 'static,

View File

@@ -5,7 +5,7 @@ use std::{
     future::Future,
     marker::PhantomData,
     pin::Pin,
-    sync::{Arc, Mutex},
+    sync::Arc,
     task::{Context, Poll},
     time::Instant,
 };
@@ -106,7 +106,7 @@ where
     /// A shared list of peer addresses.
     ///
     /// Used for logging diagnostics.
-    address_book: Arc<Mutex<AddressBook>>,
+    address_book: Arc<std::sync::Mutex<AddressBook>>,
 }
 
 impl<D> PeerSet<D>
@@ -124,7 +124,7 @@ where
         demand_signal: mpsc::Sender<()>,
         handle_rx: tokio::sync::oneshot::Receiver<Vec<JoinHandle<Result<(), BoxError>>>>,
         inv_stream: broadcast::Receiver<(InventoryHash, SocketAddr)>,
-        address_book: Arc<Mutex<AddressBook>>,
+        address_book: Arc<std::sync::Mutex<AddressBook>>,
     ) -> Self {
         Self {
             discover,
@@ -379,8 +379,13 @@ where
         }
         self.last_peer_log = Some(Instant::now());
 
+        // # Correctness
+        //
         // Only log address metrics in exceptional circumstances, to avoid lock contention.
-        // TODO: replace with a watch channel that is updated in `AddressBook::update_metrics()`.
+        //
+        // TODO: replace with a watch channel that is updated in `AddressBook::update_metrics()`,
+        // or turn the address book into a service (#1976)
         let address_metrics = self.address_book.lock().unwrap().address_metrics();
+
         if unready_services_len == 0 {
             warn!(
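The TODO above points at a watch channel; a minimal sketch of that idea, with an assumed `AddressMetrics` shape. The address book side would publish a snapshot whenever its metrics change, and loggers would read the latest value without ever touching the address book mutex:

use tokio::sync::watch;

// Assumed fields, standing in for the real `AddressMetrics`.
#[derive(Clone, Copy, Debug, Default)]
struct AddressMetrics {
    responded: usize,
    failed: usize,
}

#[tokio::main]
async fn main() {
    let (metrics_tx, metrics_rx) = watch::channel(AddressMetrics::default());

    // Writer side: this send would happen inside
    // `AddressBook::update_metrics()`, under the lock it already holds.
    metrics_tx
        .send(AddressMetrics { responded: 10, failed: 2 })
        .unwrap();

    // Reader side: `borrow()` reads the channel's own slot, so logging
    // never contends with address book updates.
    println!("address metrics: {:?}", *metrics_rx.borrow());
}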

View File

@@ -1,6 +1,6 @@
 //! The timestamp collector collects liveness information from peers.
 
-use std::sync::{Arc, Mutex};
+use std::sync::Arc;
 
 use futures::{channel::mpsc, prelude::*};
 
@@ -14,11 +14,11 @@ impl TimestampCollector {
     /// Spawn a new [`TimestampCollector`] task, and return handles for the
     /// transmission channel for timestamp events and for the [`AddressBook`] it
     /// updates.
-    pub fn spawn() -> (Arc<Mutex<AddressBook>>, mpsc::Sender<MetaAddr>) {
+    pub fn spawn() -> (Arc<std::sync::Mutex<AddressBook>>, mpsc::Sender<MetaAddr>) {
         use tracing::Level;
         const TIMESTAMP_WORKER_BUFFER_SIZE: usize = 100;
         let (worker_tx, mut worker_rx) = mpsc::channel(TIMESTAMP_WORKER_BUFFER_SIZE);
-        let address_book = Arc::new(Mutex::new(AddressBook::new(span!(
+        let address_book = Arc::new(std::sync::Mutex::new(AddressBook::new(span!(
             Level::TRACE,
             "timestamp collector"
         ))));
@@ -26,6 +26,10 @@ impl TimestampCollector {
         let worker = async move {
             while let Some(event) = worker_rx.next().await {
+                // # Correctness
+                //
+                // Briefly hold the address book threaded mutex, to update the
+                // state for a single address.
                 worker_address_book
                     .lock()
                     .expect("mutex should be unpoisoned")

View File

@@ -1,7 +1,7 @@
 use std::{
     future::Future,
     pin::Pin,
-    sync::{Arc, Mutex},
+    sync::Arc,
     task::{Context, Poll},
 };
 
@@ -31,7 +31,7 @@ type State = Buffer<BoxService<zs::Request, zs::Response, zs::BoxError>, zs::Req
 type Verifier = Buffer<BoxService<Arc<Block>, block::Hash, VerifyChainError>, Arc<Block>>;
 type InboundDownloads = Downloads<Timeout<Outbound>, Timeout<Verifier>, State>;
 
-pub type NetworkSetupData = (Outbound, Arc<Mutex<AddressBook>>);
+pub type NetworkSetupData = (Outbound, Arc<std::sync::Mutex<AddressBook>>);
 
 /// Tracks the internal state of the [`Inbound`] service during network setup.
 pub enum Setup {
@@ -54,7 +54,7 @@ pub enum Setup {
     /// All requests are answered.
     Initialized {
         /// A shared list of peer addresses.
-        address_book: Arc<Mutex<zn::AddressBook>>,
+        address_book: Arc<std::sync::Mutex<zn::AddressBook>>,
 
         /// A `futures::Stream` that downloads and verifies gossipped blocks.
         downloads: Pin<Box<InboundDownloads>>,
@@ -228,11 +228,20 @@ impl Service<zn::Request> for Inbound {
         match req {
             zn::Request::Peers => {
                 if let Setup::Initialized { address_book, .. } = &self.network_setup {
+                    // # Security
+                    //
                     // We could truncate the list to try to not reveal our entire
                     // peer set. But because we don't monitor repeated requests,
                     // this wouldn't actually achieve anything, because a crawler
                     // could just repeatedly query it.
-                    let mut peers = address_book.lock().unwrap().sanitized();
+                    //
+                    // # Correctness
+                    //
+                    // Briefly hold the address book threaded mutex while
+                    // cloning the address book. Then sanitize after releasing
+                    // the lock.
+                    let peers = address_book.lock().unwrap().clone();
+                    let mut peers = peers.sanitized();
                     const MAX_ADDR: usize = 1000; // bitcoin protocol constant
                     peers.truncate(MAX_ADDR);
                     async { Ok(zn::Response::Peers(peers)) }.boxed()
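The clone-then-release pattern in this hunk, reduced to a sketch (this is also why the first hunk of this commit derives `Clone` for `AddressBook`): the lock covers only the clone, and the per-address work runs with no lock held. A `Vec<String>` stands in for the address book, and uppercasing stands in for `sanitized()`:

use std::sync::{Arc, Mutex};

fn peers_response(address_book: &Arc<Mutex<Vec<String>>>) -> Vec<String> {
    // Critical section: just a clone. Method resolution derefs through
    // the temporary guard, so this clones the Vec, not the guard, and
    // the guard is dropped at the end of the statement.
    let book_copy = address_book.lock().unwrap().clone();

    // Slower per-address work happens with the lock released.
    let mut peers: Vec<String> = book_copy
        .into_iter()
        .map(|addr| addr.to_uppercase()) // stand-in for sanitization
        .collect();

    const MAX_ADDR: usize = 1000; // bitcoin protocol constant
    peers.truncate(MAX_ADDR);
    peers
}

fn main() {
    let book = Arc::new(Mutex::new(vec!["203.0.113.1:8233".to_string()]));
    println!("{:?}", peers_response(&book));
}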