From 54f0fc9f0f9bd7e4c9261ed5327e7a81e9ad8265 Mon Sep 17 00:00:00 2001 From: Lijun Wang <83639177+lijunwangs@users.noreply.github.com> Date: Wed, 26 May 2021 13:26:07 -0700 Subject: [PATCH] Use type alias for DownloadProgress callback (#17518) Convert to use type alias for the callback and cascade the changes to callers. Thanks @jeffwashington for the help making it possible. Changed the closure for the progress update in the validator main to FnMut and modify the abort count in the closure which is more reliable. --- download-utils/src/lib.rs | 43 +++++++++++----------------- local-cluster/tests/local_cluster.rs | 4 +-- sdk/cargo-build-bpf/src/main.rs | 9 ++---- validator/src/main.rs | 8 ++---- 4 files changed, 24 insertions(+), 40 deletions(-) diff --git a/download-utils/src/lib.rs b/download-utils/src/lib.rs index 8decda47b..da31b0d34 100644 --- a/download-utils/src/lib.rs +++ b/download-utils/src/lib.rs @@ -46,18 +46,18 @@ pub struct DownloadProgressRecord { pub notification_count: u64, } +type DownloadProgressCallback<'a> = Box bool + 'a>; +type DownloadProgressCallbackOption<'a> = Option>; + /// This callback allows the caller to get notified of the download progress modelled by DownloadProgressRecord /// Return "true" to continue the download /// Return "false" to abort the download -pub fn download_file( +pub fn download_file<'a, 'b>( url: &str, destination_file: &Path, use_progress_bar: bool, - progress_notify_callback: &Option, -) -> Result<(), String> -where - F: Fn(&DownloadProgressRecord) -> bool, -{ + progress_notify_callback: &'a mut DownloadProgressCallbackOption<'b>, +) -> Result<(), String> { if destination_file.is_file() { return Err(format!("{:?} already exists", destination_file)); } @@ -113,10 +113,7 @@ where info!("Downloading {} bytes from {}", download_size, url); } - struct DownloadProgress - where - F: Fn(&DownloadProgressRecord) -> bool, - { + struct DownloadProgress<'e, 'f, R> { progress_bar: ProgressBar, response: R, last_print: Instant, @@ -125,14 +122,11 @@ where download_size: f32, use_progress_bar: bool, start_time: Instant, - callback: Option, + callback: &'f mut DownloadProgressCallbackOption<'e>, notification_count: u64, } - impl Read for DownloadProgress - where - F: Fn(&DownloadProgressRecord) -> bool, - { + impl<'e, 'f, R: Read> Read for DownloadProgress<'e, 'f, R> { fn read(&mut self, buf: &mut [u8]) -> io::Result { let n = self.response.read(buf)?; @@ -178,7 +172,7 @@ where ); } - if let Some(callback) = &self.callback { + if let Some(callback) = self.callback { if to_update_progress && !callback(&progress_record) { info!("Download is aborted by the caller"); return Err(io::Error::new( @@ -192,7 +186,7 @@ where } } - let mut source = DownloadProgress { + let mut source = DownloadProgress::<'b, 'a> { progress_bar, response, last_print: Instant::now(), @@ -201,7 +195,7 @@ where download_size: (download_size as f32).max(1f32), use_progress_bar, start_time: Instant::now(), - callback: progress_notify_callback.as_ref(), + callback: progress_notify_callback, notification_count: 0, }; @@ -241,7 +235,7 @@ pub fn download_genesis_if_missing( &format!("http://{}/{}", rpc_addr, DEFAULT_GENESIS_ARCHIVE), &tmp_genesis_package, use_progress_bar, - &None:: bool>, + &mut None, )?; Ok(tmp_genesis_package) @@ -250,17 +244,14 @@ pub fn download_genesis_if_missing( } } -pub fn download_snapshot( +pub fn download_snapshot<'a, 'b>( rpc_addr: &SocketAddr, snapshot_output_dir: &Path, desired_snapshot_hash: (Slot, Hash), use_progress_bar: bool, maximum_snapshots_to_retain: usize, - progress_notify_callback: &Option, -) -> Result<(), String> -where - F: Fn(&DownloadProgressRecord) -> bool, -{ + progress_notify_callback: &'a mut DownloadProgressCallbackOption<'b>, +) -> Result<(), String> { snapshot_utils::purge_old_snapshot_archives(snapshot_output_dir, maximum_snapshots_to_retain); for compression in &[ @@ -290,7 +281,7 @@ where ), &desired_snapshot_package, use_progress_bar, - &progress_notify_callback, + progress_notify_callback, ) .is_ok() { diff --git a/local-cluster/tests/local_cluster.rs b/local-cluster/tests/local_cluster.rs index 8ae01e6a1..7011e1230 100644 --- a/local-cluster/tests/local_cluster.rs +++ b/local-cluster/tests/local_cluster.rs @@ -17,7 +17,7 @@ use solana_core::{ optimistic_confirmation_verifier::OptimisticConfirmationVerifier, validator::ValidatorConfig, }; -use solana_download_utils::{download_snapshot, DownloadProgressRecord}; +use solana_download_utils::download_snapshot; use solana_gossip::{ cluster_info::{self, VALIDATOR_PORT_RANGE}, crds_value::{self, CrdsData, CrdsValue}, @@ -1687,7 +1687,7 @@ fn test_snapshot_download() { archive_snapshot_hash, false, snapshot_utils::DEFAULT_MAX_SNAPSHOTS_TO_RETAIN, - &None:: bool>, + &mut None, ) .unwrap(); diff --git a/sdk/cargo-build-bpf/src/main.rs b/sdk/cargo-build-bpf/src/main.rs index cb0790b0d..ba0bf06d9 100644 --- a/sdk/cargo-build-bpf/src/main.rs +++ b/sdk/cargo-build-bpf/src/main.rs @@ -4,7 +4,7 @@ use { crate_description, crate_name, crate_version, value_t, value_t_or_exit, values_t, App, Arg, }, regex::Regex, - solana_download_utils::{download_file, DownloadProgressRecord}, + solana_download_utils::download_file, solana_sdk::signature::{write_keypair_file, Keypair}, std::{ collections::HashMap, @@ -112,12 +112,7 @@ fn install_if_missing( url.push_str(version); url.push('/'); url.push_str(file.to_str().unwrap()); - download_file( - &url.as_str(), - &file, - true, - &None:: bool>, - )?; + download_file(&url.as_str(), &file, true, &mut None)?; fs::create_dir_all(&target_path).map_err(|err| err.to_string())?; let zip = File::open(&file).map_err(|err| err.to_string())?; let tar = BzDecoder::new(BufReader::new(zip)); diff --git a/validator/src/main.rs b/validator/src/main.rs index f3e69f2e8..5cd3cbada 100644 --- a/validator/src/main.rs +++ b/validator/src/main.rs @@ -898,7 +898,7 @@ fn rpc_bootstrap( snapshot_hash, use_progress_bar, maximum_snapshots_to_retain, - &Some(|download_progress: &DownloadProgressRecord| { + &mut Some(Box::new(|download_progress: &DownloadProgressRecord| { debug!("Download progress: {:?}", download_progress); if download_progress.last_throughput < minimal_snapshot_download_speed @@ -922,16 +922,14 @@ fn rpc_bootstrap( and try a different node. Abort count: {}, Progress detail: {:?}", download_progress.last_throughput, minimal_snapshot_download_speed, download_abort_count, download_progress); + download_abort_count += 1; false } else { true } - }), + })), ); - if ret.is_err() { - download_abort_count += 1; - } gossip_service.join().unwrap(); ret })