diff --git a/tokens/src/commands.rs b/tokens/src/commands.rs index d09b618c3..7ea2bde37 100644 --- a/tokens/src/commands.rs +++ b/tokens/src/commands.rs @@ -21,10 +21,10 @@ use { solana_rpc_client_api::{ client_error::{Error as ClientError, Result as ClientResult}, config::RpcSendTransactionConfig, - request::MAX_GET_SIGNATURE_STATUSES_QUERY_ITEMS, + request::{MAX_GET_SIGNATURE_STATUSES_QUERY_ITEMS, MAX_MULTIPLE_ACCOUNTS}, }, solana_sdk::{ - clock::{Slot, DEFAULT_MS_PER_SLOT}, + clock::Slot, commitment_config::CommitmentConfig, hash::Hash, instruction::Instruction, @@ -49,7 +49,7 @@ use { Arc, }, thread::sleep, - time::{Duration, Instant}, + time::Duration, }, }; @@ -108,6 +108,10 @@ pub enum Error { ClientError(#[from] ClientError), #[error("Missing lockup authority")] MissingLockupAuthority, + #[error("Missing messages")] + MissingMessages, + #[error("Error estimating message fees")] + FeeEstimationError, #[error("insufficient funds in {0:?}, requires {1}")] InsufficientFunds(FundingSources, String), #[error("Program error")] @@ -297,7 +301,27 @@ fn build_messages( stake_extras: &mut StakeExtras, created_accounts: &mut u64, ) -> Result<(), Error> { - for allocation in allocations.iter() { + let mut existing_associated_token_accounts = vec![]; + if let Some(spl_token_args) = &args.spl_token_args { + let allocation_chunks = allocations.chunks(MAX_MULTIPLE_ACCOUNTS); + for allocation_chunk in allocation_chunks { + let associated_token_addresses = allocation_chunk + .iter() + .map(|x| { + let wallet_address = x.recipient.parse().unwrap(); + let associated_token_address = get_associated_token_address( + &wallet_address, + &spl_token_pubkey(&spl_token_args.mint), + ); + pubkey_from_spl_token(&associated_token_address) + }) + .collect::>(); + let mut maybe_accounts = client.get_multiple_accounts(&associated_token_addresses)?; + existing_associated_token_accounts.append(&mut maybe_accounts); + } + } + + for (i, allocation) in allocations.iter().enumerate() { if exit.load(Ordering::SeqCst) { db.dump()?; return Err(Error::ExitSignal); @@ -311,14 +335,8 @@ fn build_messages( let do_create_associated_token_account = if let Some(spl_token_args) = &args.spl_token_args { - let wallet_address = allocation.recipient.parse().unwrap(); - let associated_token_address = get_associated_token_address( - &wallet_address, - &spl_token_pubkey(&spl_token_args.mint), - ); - let do_create_associated_token_account = client - .get_multiple_accounts(&[pubkey_from_spl_token(&associated_token_address)])?[0] - .is_none(); + let do_create_associated_token_account = + existing_associated_token_accounts[i].is_none(); if do_create_associated_token_account { *created_accounts += 1; } @@ -733,23 +751,18 @@ fn log_transaction_confirmations( Ok(()) } -pub fn get_fees_for_messages(messages: &[Message], client: &RpcClient) -> Result { - // This is an arbitrary value to get regular blockhash updates for balance checks without - // hitting the RPC node with too many requests - const BLOCKHASH_REFRESH_MILLIS: u64 = DEFAULT_MS_PER_SLOT * 32; - - let mut latest_blockhash = client.get_latest_blockhash()?; - let mut now = Instant::now(); - let mut fees = 0; - for mut message in messages.iter().cloned() { - if now.elapsed() > Duration::from_millis(BLOCKHASH_REFRESH_MILLIS) { - latest_blockhash = client.get_latest_blockhash()?; - now = Instant::now(); - } - message.recent_blockhash = latest_blockhash; - fees += client.get_fee_for_message(&message)?; - } - Ok(fees) +pub fn get_fee_estimate_for_messages( + messages: &[Message], + client: &RpcClient, +) -> Result { + let mut message = messages.first().ok_or(Error::MissingMessages)?.clone(); + let latest_blockhash = client.get_latest_blockhash()?; + message.recent_blockhash = latest_blockhash; + let fee = client.get_fee_for_message(&message)?; + let fee_estimate = fee + .checked_mul(messages.len() as u64) + .ok_or(Error::FeeEstimationError)?; + Ok(fee_estimate) } fn check_payer_balances( @@ -759,7 +772,7 @@ fn check_payer_balances( args: &DistributeTokensArgs, ) -> Result<(), Error> { let mut undistributed_tokens: u64 = allocations.iter().map(|x| x.amount).sum(); - let fees = get_fees_for_messages(messages, client)?; + let fees = get_fee_estimate_for_messages(messages, client)?; let (distribution_source, unlocked_sol_source) = if let Some(stake_args) = &args.stake_args { let total_unlocked_sol = allocations.len() as u64 * stake_args.unlocked_sol; diff --git a/tokens/src/spl_token.rs b/tokens/src/spl_token.rs index 3850e7467..c897172fa 100644 --- a/tokens/src/spl_token.rs +++ b/tokens/src/spl_token.rs @@ -1,7 +1,7 @@ use { crate::{ args::{DistributeTokensArgs, SplTokenArgs}, - commands::{get_fees_for_messages, Allocation, Error, FundingSource}, + commands::{get_fee_estimate_for_messages, Allocation, Error, FundingSource}, }, console::style, solana_account_decoder::parse_token::{ @@ -96,7 +96,7 @@ pub fn check_spl_token_balances( .as_ref() .expect("spl_token_args must be some"); let allocation_amount: u64 = allocations.iter().map(|x| x.amount).sum(); - let fees = get_fees_for_messages(messages, client)?; + let fees = get_fee_estimate_for_messages(messages, client)?; let token_account_rent_exempt_balance = client.get_minimum_balance_for_rent_exemption(SplTokenAccount::LEN)?;