diff --git a/core/src/cost_update_service.rs b/core/src/cost_update_service.rs index 0ebee0c09c..9430c920af 100644 --- a/core/src/cost_update_service.rs +++ b/core/src/cost_update_service.rs @@ -127,8 +127,10 @@ impl CostUpdateService { CostUpdate::FrozenBank { bank } => { bank.read_cost_tracker().unwrap().report_stats(bank.slot()); } - CostUpdate::ExecuteTiming { execute_timings } => { - dirty |= Self::update_cost_model(&cost_model, &execute_timings); + CostUpdate::ExecuteTiming { + mut execute_timings, + } => { + dirty |= Self::update_cost_model(&cost_model, &mut execute_timings); update_count += 1; } } @@ -151,16 +153,27 @@ impl CostUpdateService { } } - fn update_cost_model(cost_model: &RwLock, execute_timings: &ExecuteTimings) -> bool { + fn update_cost_model( + cost_model: &RwLock, + execute_timings: &mut ExecuteTimings, + ) -> bool { let mut dirty = false; { - let mut cost_model_mutable = cost_model.write().unwrap(); - for (program_id, timing) in &execute_timings.details.per_program_timings { - if timing.count < 1 { + for (program_id, program_timings) in &mut execute_timings.details.per_program_timings { + let current_estimated_program_cost = + cost_model.read().unwrap().find_instruction_cost(program_id); + program_timings.coalesce_error_timings(current_estimated_program_cost); + + if program_timings.count < 1 { continue; } - let units = timing.accumulated_units / timing.count as u64; - match cost_model_mutable.upsert_instruction_cost(program_id, units) { + + let units = program_timings.accumulated_units / program_timings.count as u64; + match cost_model + .write() + .unwrap() + .upsert_instruction_cost(program_id, units) + { Ok(c) => { debug!( "after replayed into bank, instruction {:?} has averaged cost {}", @@ -213,8 +226,8 @@ mod tests { #[test] fn test_update_cost_model_with_empty_execute_timings() { let cost_model = Arc::new(RwLock::new(CostModel::default())); - let empty_execute_timings = ExecuteTimings::default(); - CostUpdateService::update_cost_model(&cost_model, &empty_execute_timings); + let mut empty_execute_timings = ExecuteTimings::default(); + CostUpdateService::update_cost_model(&cost_model, &mut empty_execute_timings); assert_eq!( 0, @@ -247,9 +260,10 @@ mod tests { accumulated_us, accumulated_units, count, + errored_txs_compute_consumed: vec![], }, ); - CostUpdateService::update_cost_model(&cost_model, &execute_timings); + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); assert_eq!( 1, cost_model @@ -282,9 +296,10 @@ mod tests { accumulated_us, accumulated_units, count, + errored_txs_compute_consumed: vec![], }, ); - CostUpdateService::update_cost_model(&cost_model, &execute_timings); + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); assert_eq!( 1, cost_model @@ -303,4 +318,99 @@ mod tests { ); } } + + #[test] + fn test_update_cost_model_with_error_execute_timings() { + let cost_model = Arc::new(RwLock::new(CostModel::default())); + let mut execute_timings = ExecuteTimings::default(); + let program_key_1 = Pubkey::new_unique(); + + // Test updating cost model with a `ProgramTiming` with no compute units accumulated, i.e. + // `accumulated_units` == 0 + { + execute_timings.details.per_program_timings.insert( + program_key_1, + ProgramTiming { + accumulated_us: 1000, + accumulated_units: 0, + count: 0, + errored_txs_compute_consumed: vec![], + }, + ); + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + // If both the `errored_txs_compute_consumed` is empty and `count == 0`, then + // nothing should be inserted into the cost model + assert!(cost_model + .read() + .unwrap() + .get_instruction_cost_table() + .is_empty()); + } + + // Test updating cost model with only erroring compute costs where the `cost_per_error` is + // greater than the current instruction cost for the program. Should update with the + // new erroring compute costs + let cost_per_error = 1000; + { + execute_timings.details.per_program_timings.insert( + program_key_1, + ProgramTiming { + accumulated_us: 1000, + accumulated_units: 0, + count: 0, + errored_txs_compute_consumed: vec![cost_per_error; 3], + }, + ); + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + assert_eq!( + 1, + cost_model + .read() + .unwrap() + .get_instruction_cost_table() + .len() + ); + assert_eq!( + Some(&cost_per_error), + cost_model + .read() + .unwrap() + .get_instruction_cost_table() + .get(&program_key_1) + ); + } + + // Test updating cost model with only erroring compute costs where the error cost is + // `smaller_cost_per_error`, less than the current instruction cost for the program. + // The cost should not decrease for these new lesser errors + let smaller_cost_per_error = cost_per_error - 10; + { + execute_timings.details.per_program_timings.insert( + program_key_1, + ProgramTiming { + accumulated_us: 1000, + accumulated_units: 0, + count: 0, + errored_txs_compute_consumed: vec![smaller_cost_per_error; 3], + }, + ); + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + assert_eq!( + 1, + cost_model + .read() + .unwrap() + .get_instruction_cost_table() + .len() + ); + assert_eq!( + Some(&cost_per_error), + cost_model + .read() + .unwrap() + .get_instruction_cost_table() + .get(&program_key_1) + ); + } + } } diff --git a/program-runtime/src/invoke_context.rs b/program-runtime/src/invoke_context.rs index 026d3496e4..3d87cbe746 100644 --- a/program-runtime/src/invoke_context.rs +++ b/program-runtime/src/invoke_context.rs @@ -28,6 +28,12 @@ use { pub type ProcessInstructionWithContext = fn(usize, &[u8], &mut InvokeContext) -> Result<(), InstructionError>; +#[derive(Debug, PartialEq)] +pub struct ProcessInstructionResult { + pub compute_units_consumed: u64, + pub result: Result<(), InstructionError>, +} + #[derive(Clone)] pub struct BuiltinProgram { pub program_id: Pubkey, @@ -520,7 +526,8 @@ impl<'a> InvokeContext<'a> { &instruction_accounts, Some(&caller_write_privileges), &program_indices, - )?; + ) + .result?; // Verify the called program has not misbehaved let do_support_realloc = self.feature_set.is_active(&do_support_realloc::id()); @@ -709,7 +716,7 @@ impl<'a> InvokeContext<'a> { instruction_accounts: &[InstructionAccount], caller_write_privileges: Option<&[bool]>, program_indices: &[usize], - ) -> Result { + ) -> ProcessInstructionResult { let program_id = program_indices .last() .map(|index| *self.transaction_context.get_key_of_account_at_index(*index)) @@ -722,8 +729,13 @@ impl<'a> InvokeContext<'a> { } } else { // Verify the calling program hasn't misbehaved - self.verify_and_update(instruction_accounts, caller_write_privileges)?; - + let result = self.verify_and_update(instruction_accounts, caller_write_privileges); + if result.is_err() { + return ProcessInstructionResult { + compute_units_consumed: 0, + result, + }; + } // Record instruction if let Some(instruction_recorder) = &self.instruction_recorder { let compiled_instruction = CompiledInstruction { @@ -743,27 +755,31 @@ impl<'a> InvokeContext<'a> { } } + let mut compute_units_consumed = 0; let result = self .push(instruction_accounts, program_indices) .and_then(|_| { self.return_data = (program_id, Vec::new()); let pre_remaining_units = self.compute_meter.borrow().get_remaining(); - self.process_executable_chain(instruction_data)?; + let execution_result = self.process_executable_chain(instruction_data); let post_remaining_units = self.compute_meter.borrow().get_remaining(); + compute_units_consumed = pre_remaining_units.saturating_sub(post_remaining_units); + execution_result?; // Verify the called program has not misbehaved if is_lowest_invocation_level { - self.verify(instruction_accounts, program_indices)?; + self.verify(instruction_accounts, program_indices) } else { - self.verify_and_update(instruction_accounts, None)?; + self.verify_and_update(instruction_accounts, None) } - - Ok(pre_remaining_units.saturating_sub(post_remaining_units)) }); // Pop the invoke_stack to restore previous state self.pop(); - result + ProcessInstructionResult { + compute_units_consumed, + result, + } } /// Calls the instruction's program entrypoint method @@ -1068,6 +1084,10 @@ mod tests { ModifyOwned, ModifyNotOwned, ModifyReadonly, + ConsumeComputeUnits { + compute_units_consumed: u64, + desired_result: Result<(), InstructionError>, + }, } #[test] @@ -1185,6 +1205,17 @@ mod tests { .try_account_ref_mut()? .data_as_mut_slice()[0] = 1 } + MockInstruction::ConsumeComputeUnits { + compute_units_consumed, + desired_result, + } => { + invoke_context + .get_compute_meter() + .borrow_mut() + .consume(compute_units_consumed) + .unwrap(); + return desired_result; + } } } else { return Err(InstructionError::InvalidInstructionData); @@ -1371,12 +1402,14 @@ mod tests { .borrow_mut() .data_as_mut_slice()[0] = 1; assert_eq!( - invoke_context.process_instruction( - &instruction.data, - &instruction_accounts, - None, - &program_indices[1..], - ), + invoke_context + .process_instruction( + &instruction.data, + &instruction_accounts, + None, + &program_indices[1..], + ) + .result, Err(InstructionError::ExternalAccountDataModified) ); transaction_context @@ -1390,12 +1423,14 @@ mod tests { .borrow_mut() .data_as_mut_slice()[0] = 1; assert_eq!( - invoke_context.process_instruction( - &instruction.data, - &instruction_accounts, - None, - &program_indices[1..], - ), + invoke_context + .process_instruction( + &instruction.data, + &instruction_accounts, + None, + &program_indices[1..], + ) + .result, Err(InstructionError::ReadonlyDataModified) ); transaction_context @@ -1406,15 +1441,33 @@ mod tests { invoke_context.pop(); let cases = vec![ - (MockInstruction::NoopSuccess, Ok(0)), + ( + MockInstruction::NoopSuccess, + ProcessInstructionResult { + result: Ok(()), + compute_units_consumed: 0, + }, + ), ( MockInstruction::NoopFail, - Err(InstructionError::GenericError), + ProcessInstructionResult { + result: Err(InstructionError::GenericError), + compute_units_consumed: 0, + }, + ), + ( + MockInstruction::ModifyOwned, + ProcessInstructionResult { + result: Ok(()), + compute_units_consumed: 0, + }, ), - (MockInstruction::ModifyOwned, Ok(0)), ( MockInstruction::ModifyNotOwned, - Err(InstructionError::ExternalAccountDataModified), + ProcessInstructionResult { + result: Err(InstructionError::ExternalAccountDataModified), + compute_units_consumed: 0, + }, ), ]; for case in cases { @@ -1586,4 +1639,83 @@ mod tests { ); invoke_context.pop(); } + + #[test] + fn test_process_instruction_compute_budget() { + let caller_program_id = solana_sdk::pubkey::new_rand(); + let callee_program_id = solana_sdk::pubkey::new_rand(); + let builtin_programs = &[BuiltinProgram { + program_id: callee_program_id, + process_instruction: mock_process_instruction, + }]; + + let owned_account = AccountSharedData::new(42, 1, &callee_program_id); + let not_owned_account = AccountSharedData::new(84, 1, &solana_sdk::pubkey::new_rand()); + let readonly_account = AccountSharedData::new(168, 1, &solana_sdk::pubkey::new_rand()); + let loader_account = AccountSharedData::new(0, 0, &native_loader::id()); + let mut program_account = AccountSharedData::new(1, 0, &native_loader::id()); + program_account.set_executable(true); + + let accounts = vec![ + (solana_sdk::pubkey::new_rand(), owned_account), + (solana_sdk::pubkey::new_rand(), not_owned_account), + (solana_sdk::pubkey::new_rand(), readonly_account), + (caller_program_id, loader_account), + (callee_program_id, program_account), + ]; + let program_indices = [3, 4]; + + let metas = vec![ + AccountMeta::new(accounts[0].0, false), + AccountMeta::new(accounts[1].0, false), + AccountMeta::new_readonly(accounts[2].0, false), + ]; + let instruction_accounts = metas + .iter() + .enumerate() + .map(|(account_index, account_meta)| InstructionAccount { + index: account_index, + is_signer: account_meta.is_signer, + is_writable: account_meta.is_writable, + }) + .collect::>(); + + let transaction_context = TransactionContext::new(accounts, 1); + let mut invoke_context = InvokeContext::new_mock(&transaction_context, builtin_programs); + let compute_units_consumed = 10; + let desired_results = vec![Ok(()), Err(InstructionError::GenericError)]; + + for desired_result in desired_results { + let instruction = Instruction::new_with_bincode( + callee_program_id, + &MockInstruction::ConsumeComputeUnits { + compute_units_consumed, + desired_result: desired_result.clone(), + }, + metas.clone(), + ); + invoke_context + .push(&instruction_accounts, &program_indices[..1]) + .unwrap(); + + let result = invoke_context.process_instruction( + &instruction.data, + &instruction_accounts, + None, + &program_indices[1..], + ); + + // Because the instruction had compute cost > 0, then regardless of the execution result, + // the number of compute units consumed should be a non-default which is something greater + // than zero. + assert!(result.compute_units_consumed > 0); + assert_eq!( + result, + ProcessInstructionResult { + compute_units_consumed, + result: desired_result, + } + ); + } + } } diff --git a/program-runtime/src/timings.rs b/program-runtime/src/timings.rs index a61b621e1a..c5f71e177f 100644 --- a/program-runtime/src/timings.rs +++ b/program-runtime/src/timings.rs @@ -5,6 +5,18 @@ pub struct ProgramTiming { pub accumulated_us: u64, pub accumulated_units: u64, pub count: u32, + pub errored_txs_compute_consumed: Vec, +} + +impl ProgramTiming { + pub fn coalesce_error_timings(&mut self, current_estimated_program_cost: u64) { + for tx_error_compute_consumed in self.errored_txs_compute_consumed.drain(..) { + let compute_units_update = + std::cmp::max(current_estimated_program_cost, tx_error_compute_consumed); + self.accumulated_units = self.accumulated_units.saturating_add(compute_units_update); + self.count = self.count.saturating_add(1); + } + } } #[derive(Default, Debug)] @@ -46,10 +58,24 @@ impl ExecuteDetailsTimings { program_timing.count = program_timing.count.saturating_add(other.count); } } - pub fn accumulate_program(&mut self, program_id: &Pubkey, us: u64, units: u64) { + pub fn accumulate_program( + &mut self, + program_id: &Pubkey, + us: u64, + compute_units_consumed: u64, + is_error: bool, + ) { let program_timing = self.per_program_timings.entry(*program_id).or_default(); program_timing.accumulated_us = program_timing.accumulated_us.saturating_add(us); - program_timing.accumulated_units = program_timing.accumulated_units.saturating_add(units); - program_timing.count = program_timing.count.saturating_add(1); + if is_error { + program_timing + .errored_txs_compute_consumed + .push(compute_units_consumed); + } else { + program_timing.accumulated_units = program_timing + .accumulated_units + .saturating_add(compute_units_consumed); + program_timing.count = program_timing.count.saturating_add(1); + }; } } diff --git a/program-test/src/lib.rs b/program-test/src/lib.rs index f179d1a679..1a193c5d7c 100644 --- a/program-test/src/lib.rs +++ b/program-test/src/lib.rs @@ -283,6 +283,7 @@ impl solana_sdk::program_stubs::SyscallStubs for SyscallStubs { Some(&caller_write_privileges), &program_indices, ) + .result .map_err(|err| ProgramError::try_from(err).unwrap_or_else(|err| panic!("{}", err)))?; // Copy invoke_context accounts modifications into caller's account_info diff --git a/programs/bpf_loader/src/syscalls.rs b/programs/bpf_loader/src/syscalls.rs index 3e382666ab..8f6e6fe78d 100644 --- a/programs/bpf_loader/src/syscalls.rs +++ b/programs/bpf_loader/src/syscalls.rs @@ -2391,6 +2391,7 @@ fn call<'a, 'b: 'a>( Some(&caller_write_privileges), &program_indices, ) + .result .map_err(SyscallError::InstructionError)?; // Copy results back to caller diff --git a/runtime/src/cost_model.rs b/runtime/src/cost_model.rs index 9c29da5952..c9f982ec6b 100644 --- a/runtime/src/cost_model.rs +++ b/runtime/src/cost_model.rs @@ -142,6 +142,20 @@ impl CostModel { self.instruction_execution_cost_table.get_cost_table() } + pub fn find_instruction_cost(&self, program_key: &Pubkey) -> u64 { + match self.instruction_execution_cost_table.get_cost(program_key) { + Some(cost) => *cost, + None => { + let default_value = self.instruction_execution_cost_table.get_mode(); + debug!( + "Program key {:?} does not have assigned cost, using mode {}", + program_key, default_value + ); + default_value + } + } + } + fn get_signature_cost(&self, transaction: &SanitizedTransaction) -> u64 { transaction.signatures().len() as u64 * SIGNATURE_COST } @@ -188,20 +202,6 @@ impl CostModel { cost } - fn find_instruction_cost(&self, program_key: &Pubkey) -> u64 { - match self.instruction_execution_cost_table.get_cost(program_key) { - Some(cost) => *cost, - None => { - let default_value = self.instruction_execution_cost_table.get_mode(); - debug!( - "Program key {:?} does not have assigned cost, using mode {}", - program_key, default_value - ); - default_value - } - } - } - fn calculate_account_data_size_on_deserialized_system_instruction( instruction: SystemInstruction, ) -> u64 { diff --git a/runtime/src/message_processor.rs b/runtime/src/message_processor.rs index 2796e1fc23..7bbd9faeb8 100644 --- a/runtime/src/message_processor.rs +++ b/runtime/src/message_processor.rs @@ -3,7 +3,7 @@ use { solana_measure::measure::Measure, solana_program_runtime::{ instruction_recorder::InstructionRecorder, - invoke_context::{BuiltinProgram, Executors, InvokeContext}, + invoke_context::{BuiltinProgram, Executors, InvokeContext, ProcessInstructionResult}, log_collector::LogCollector, timings::ExecuteDetailsTimings, }, @@ -128,21 +128,25 @@ impl MessageProcessor { }) .collect::>(); let mut time = Measure::start("execute_instruction"); - let compute_meter_consumption = invoke_context - .process_instruction( - &instruction.data, - &instruction_accounts, - None, - program_indices, - ) - .map_err(|err| TransactionError::InstructionError(instruction_index as u8, err))?; + let ProcessInstructionResult { + compute_units_consumed, + result, + } = invoke_context.process_instruction( + &instruction.data, + &instruction_accounts, + None, + program_indices, + ); time.stop(); timings.accumulate_program( instruction.program_id(&message.account_keys), time.as_us(), - compute_meter_consumption, + compute_units_consumed, + result.is_err(), ); timings.accumulate(&invoke_context.timings); + result + .map_err(|err| TransactionError::InstructionError(instruction_index as u8, err))?; } Ok(ProcessedMessageInfo { accounts_data_len: invoke_context.get_accounts_data_meter().current(),