diff --git a/download-utils/src/lib.rs b/download-utils/src/lib.rs index 8decda47b..da31b0d34 100644 --- a/download-utils/src/lib.rs +++ b/download-utils/src/lib.rs @@ -46,18 +46,18 @@ pub struct DownloadProgressRecord { pub notification_count: u64, } +type DownloadProgressCallback<'a> = Box bool + 'a>; +type DownloadProgressCallbackOption<'a> = Option>; + /// This callback allows the caller to get notified of the download progress modelled by DownloadProgressRecord /// Return "true" to continue the download /// Return "false" to abort the download -pub fn download_file( +pub fn download_file<'a, 'b>( url: &str, destination_file: &Path, use_progress_bar: bool, - progress_notify_callback: &Option, -) -> Result<(), String> -where - F: Fn(&DownloadProgressRecord) -> bool, -{ + progress_notify_callback: &'a mut DownloadProgressCallbackOption<'b>, +) -> Result<(), String> { if destination_file.is_file() { return Err(format!("{:?} already exists", destination_file)); } @@ -113,10 +113,7 @@ where info!("Downloading {} bytes from {}", download_size, url); } - struct DownloadProgress - where - F: Fn(&DownloadProgressRecord) -> bool, - { + struct DownloadProgress<'e, 'f, R> { progress_bar: ProgressBar, response: R, last_print: Instant, @@ -125,14 +122,11 @@ where download_size: f32, use_progress_bar: bool, start_time: Instant, - callback: Option, + callback: &'f mut DownloadProgressCallbackOption<'e>, notification_count: u64, } - impl Read for DownloadProgress - where - F: Fn(&DownloadProgressRecord) -> bool, - { + impl<'e, 'f, R: Read> Read for DownloadProgress<'e, 'f, R> { fn read(&mut self, buf: &mut [u8]) -> io::Result { let n = self.response.read(buf)?; @@ -178,7 +172,7 @@ where ); } - if let Some(callback) = &self.callback { + if let Some(callback) = self.callback { if to_update_progress && !callback(&progress_record) { info!("Download is aborted by the caller"); return Err(io::Error::new( @@ -192,7 +186,7 @@ where } } - let mut source = DownloadProgress { + let mut source = DownloadProgress::<'b, 'a> { progress_bar, response, last_print: Instant::now(), @@ -201,7 +195,7 @@ where download_size: (download_size as f32).max(1f32), use_progress_bar, start_time: Instant::now(), - callback: progress_notify_callback.as_ref(), + callback: progress_notify_callback, notification_count: 0, }; @@ -241,7 +235,7 @@ pub fn download_genesis_if_missing( &format!("http://{}/{}", rpc_addr, DEFAULT_GENESIS_ARCHIVE), &tmp_genesis_package, use_progress_bar, - &None:: bool>, + &mut None, )?; Ok(tmp_genesis_package) @@ -250,17 +244,14 @@ pub fn download_genesis_if_missing( } } -pub fn download_snapshot( +pub fn download_snapshot<'a, 'b>( rpc_addr: &SocketAddr, snapshot_output_dir: &Path, desired_snapshot_hash: (Slot, Hash), use_progress_bar: bool, maximum_snapshots_to_retain: usize, - progress_notify_callback: &Option, -) -> Result<(), String> -where - F: Fn(&DownloadProgressRecord) -> bool, -{ + progress_notify_callback: &'a mut DownloadProgressCallbackOption<'b>, +) -> Result<(), String> { snapshot_utils::purge_old_snapshot_archives(snapshot_output_dir, maximum_snapshots_to_retain); for compression in &[ @@ -290,7 +281,7 @@ where ), &desired_snapshot_package, use_progress_bar, - &progress_notify_callback, + progress_notify_callback, ) .is_ok() { diff --git a/local-cluster/tests/local_cluster.rs b/local-cluster/tests/local_cluster.rs index 8ae01e6a1..7011e1230 100644 --- a/local-cluster/tests/local_cluster.rs +++ b/local-cluster/tests/local_cluster.rs @@ -17,7 +17,7 @@ use solana_core::{ optimistic_confirmation_verifier::OptimisticConfirmationVerifier, validator::ValidatorConfig, }; -use solana_download_utils::{download_snapshot, DownloadProgressRecord}; +use solana_download_utils::download_snapshot; use solana_gossip::{ cluster_info::{self, VALIDATOR_PORT_RANGE}, crds_value::{self, CrdsData, CrdsValue}, @@ -1687,7 +1687,7 @@ fn test_snapshot_download() { archive_snapshot_hash, false, snapshot_utils::DEFAULT_MAX_SNAPSHOTS_TO_RETAIN, - &None:: bool>, + &mut None, ) .unwrap(); diff --git a/sdk/cargo-build-bpf/src/main.rs b/sdk/cargo-build-bpf/src/main.rs index cb0790b0d..ba0bf06d9 100644 --- a/sdk/cargo-build-bpf/src/main.rs +++ b/sdk/cargo-build-bpf/src/main.rs @@ -4,7 +4,7 @@ use { crate_description, crate_name, crate_version, value_t, value_t_or_exit, values_t, App, Arg, }, regex::Regex, - solana_download_utils::{download_file, DownloadProgressRecord}, + solana_download_utils::download_file, solana_sdk::signature::{write_keypair_file, Keypair}, std::{ collections::HashMap, @@ -112,12 +112,7 @@ fn install_if_missing( url.push_str(version); url.push('/'); url.push_str(file.to_str().unwrap()); - download_file( - &url.as_str(), - &file, - true, - &None:: bool>, - )?; + download_file(&url.as_str(), &file, true, &mut None)?; fs::create_dir_all(&target_path).map_err(|err| err.to_string())?; let zip = File::open(&file).map_err(|err| err.to_string())?; let tar = BzDecoder::new(BufReader::new(zip)); diff --git a/validator/src/main.rs b/validator/src/main.rs index f3e69f2e8..5cd3cbada 100644 --- a/validator/src/main.rs +++ b/validator/src/main.rs @@ -898,7 +898,7 @@ fn rpc_bootstrap( snapshot_hash, use_progress_bar, maximum_snapshots_to_retain, - &Some(|download_progress: &DownloadProgressRecord| { + &mut Some(Box::new(|download_progress: &DownloadProgressRecord| { debug!("Download progress: {:?}", download_progress); if download_progress.last_throughput < minimal_snapshot_download_speed @@ -922,16 +922,14 @@ fn rpc_bootstrap( and try a different node. Abort count: {}, Progress detail: {:?}", download_progress.last_throughput, minimal_snapshot_download_speed, download_abort_count, download_progress); + download_abort_count += 1; false } else { true } - }), + })), ); - if ret.is_err() { - download_abort_count += 1; - } gossip_service.join().unwrap(); ret })