Use type alias for DownloadProgress callback (#17518)

Convert to use type alias for the callback and cascade the changes to callers. Thanks @jeffwashington for the help making it possible.
Changed the closure for the progress update in the validator main to FnMut and modify the abort count in the closure which is more reliable.
This commit is contained in:
Lijun Wang 2021-05-26 13:26:07 -07:00 committed by GitHub
parent 6abe089740
commit 54f0fc9f0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 24 additions and 40 deletions

View File

@ -46,18 +46,18 @@ pub struct DownloadProgressRecord {
pub notification_count: u64,
}
type DownloadProgressCallback<'a> = Box<dyn FnMut(&DownloadProgressRecord) -> bool + 'a>;
type DownloadProgressCallbackOption<'a> = Option<DownloadProgressCallback<'a>>;
/// This callback allows the caller to get notified of the download progress modelled by DownloadProgressRecord
/// Return "true" to continue the download
/// Return "false" to abort the download
pub fn download_file<F>(
pub fn download_file<'a, 'b>(
url: &str,
destination_file: &Path,
use_progress_bar: bool,
progress_notify_callback: &Option<F>,
) -> Result<(), String>
where
F: Fn(&DownloadProgressRecord) -> bool,
{
progress_notify_callback: &'a mut DownloadProgressCallbackOption<'b>,
) -> Result<(), String> {
if destination_file.is_file() {
return Err(format!("{:?} already exists", destination_file));
}
@ -113,10 +113,7 @@ where
info!("Downloading {} bytes from {}", download_size, url);
}
struct DownloadProgress<R, F>
where
F: Fn(&DownloadProgressRecord) -> bool,
{
struct DownloadProgress<'e, 'f, R> {
progress_bar: ProgressBar,
response: R,
last_print: Instant,
@ -125,14 +122,11 @@ where
download_size: f32,
use_progress_bar: bool,
start_time: Instant,
callback: Option<F>,
callback: &'f mut DownloadProgressCallbackOption<'e>,
notification_count: u64,
}
impl<R: Read, F> Read for DownloadProgress<R, F>
where
F: Fn(&DownloadProgressRecord) -> bool,
{
impl<'e, 'f, R: Read> Read for DownloadProgress<'e, 'f, R> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let n = self.response.read(buf)?;
@ -178,7 +172,7 @@ where
);
}
if let Some(callback) = &self.callback {
if let Some(callback) = self.callback {
if to_update_progress && !callback(&progress_record) {
info!("Download is aborted by the caller");
return Err(io::Error::new(
@ -192,7 +186,7 @@ where
}
}
let mut source = DownloadProgress {
let mut source = DownloadProgress::<'b, 'a> {
progress_bar,
response,
last_print: Instant::now(),
@ -201,7 +195,7 @@ where
download_size: (download_size as f32).max(1f32),
use_progress_bar,
start_time: Instant::now(),
callback: progress_notify_callback.as_ref(),
callback: progress_notify_callback,
notification_count: 0,
};
@ -241,7 +235,7 @@ pub fn download_genesis_if_missing(
&format!("http://{}/{}", rpc_addr, DEFAULT_GENESIS_ARCHIVE),
&tmp_genesis_package,
use_progress_bar,
&None::<fn(&DownloadProgressRecord) -> bool>,
&mut None,
)?;
Ok(tmp_genesis_package)
@ -250,17 +244,14 @@ pub fn download_genesis_if_missing(
}
}
pub fn download_snapshot<F>(
pub fn download_snapshot<'a, 'b>(
rpc_addr: &SocketAddr,
snapshot_output_dir: &Path,
desired_snapshot_hash: (Slot, Hash),
use_progress_bar: bool,
maximum_snapshots_to_retain: usize,
progress_notify_callback: &Option<F>,
) -> Result<(), String>
where
F: Fn(&DownloadProgressRecord) -> bool,
{
progress_notify_callback: &'a mut DownloadProgressCallbackOption<'b>,
) -> Result<(), String> {
snapshot_utils::purge_old_snapshot_archives(snapshot_output_dir, maximum_snapshots_to_retain);
for compression in &[
@ -290,7 +281,7 @@ where
),
&desired_snapshot_package,
use_progress_bar,
&progress_notify_callback,
progress_notify_callback,
)
.is_ok()
{

View File

@ -17,7 +17,7 @@ use solana_core::{
optimistic_confirmation_verifier::OptimisticConfirmationVerifier,
validator::ValidatorConfig,
};
use solana_download_utils::{download_snapshot, DownloadProgressRecord};
use solana_download_utils::download_snapshot;
use solana_gossip::{
cluster_info::{self, VALIDATOR_PORT_RANGE},
crds_value::{self, CrdsData, CrdsValue},
@ -1687,7 +1687,7 @@ fn test_snapshot_download() {
archive_snapshot_hash,
false,
snapshot_utils::DEFAULT_MAX_SNAPSHOTS_TO_RETAIN,
&None::<fn(&DownloadProgressRecord) -> bool>,
&mut None,
)
.unwrap();

View File

@ -4,7 +4,7 @@ use {
crate_description, crate_name, crate_version, value_t, value_t_or_exit, values_t, App, Arg,
},
regex::Regex,
solana_download_utils::{download_file, DownloadProgressRecord},
solana_download_utils::download_file,
solana_sdk::signature::{write_keypair_file, Keypair},
std::{
collections::HashMap,
@ -112,12 +112,7 @@ fn install_if_missing(
url.push_str(version);
url.push('/');
url.push_str(file.to_str().unwrap());
download_file(
&url.as_str(),
&file,
true,
&None::<fn(&DownloadProgressRecord) -> bool>,
)?;
download_file(&url.as_str(), &file, true, &mut None)?;
fs::create_dir_all(&target_path).map_err(|err| err.to_string())?;
let zip = File::open(&file).map_err(|err| err.to_string())?;
let tar = BzDecoder::new(BufReader::new(zip));

View File

@ -898,7 +898,7 @@ fn rpc_bootstrap(
snapshot_hash,
use_progress_bar,
maximum_snapshots_to_retain,
&Some(|download_progress: &DownloadProgressRecord| {
&mut Some(Box::new(|download_progress: &DownloadProgressRecord| {
debug!("Download progress: {:?}", download_progress);
if download_progress.last_throughput < minimal_snapshot_download_speed
@ -922,16 +922,14 @@ fn rpc_bootstrap(
and try a different node. Abort count: {}, Progress detail: {:?}",
download_progress.last_throughput, minimal_snapshot_download_speed,
download_abort_count, download_progress);
download_abort_count += 1;
false
} else {
true
}
}),
})),
);
if ret.is_err() {
download_abort_count += 1;
}
gossip_service.join().unwrap();
ret
})