From da00b39f4f92fb16417bd2d8bd218a04a34527b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Thu, 17 Feb 2022 10:16:28 +0100 Subject: [PATCH] Cleanup: get_program_key() and get_loader_key() in TransactionContext (#23191) * Moves TransactionContext::get_program_key() to InstructionContext::get_program_key(). * Removes TransactionContext::get_loader_key(). * Test full program and loader executable account chain in BPF loader. --- program-runtime/src/invoke_context.rs | 13 +- program-runtime/src/native_loader.rs | 4 +- program-test/src/lib.rs | 26 +- programs/bpf/tests/programs.rs | 11 +- programs/bpf_loader/src/lib.rs | 32 +- programs/bpf_loader/src/syscalls.rs | 486 ++++++++++---------------- runtime/src/bank.rs | 4 +- runtime/src/builtins.rs | 8 +- sdk/src/transaction_context.rs | 24 +- 9 files changed, 247 insertions(+), 361 deletions(-) diff --git a/program-runtime/src/invoke_context.rs b/program-runtime/src/invoke_context.rs index 3e0a3bf99..6e0bebbe5 100644 --- a/program-runtime/src/invoke_context.rs +++ b/program-runtime/src/invoke_context.rs @@ -443,9 +443,12 @@ impl<'a> InvokeContext<'a> { ) -> Result<(), InstructionError> { let do_support_realloc = self.feature_set.is_active(&do_support_realloc::id()); let cap_accounts_data_len = self.feature_set.is_active(&cap_accounts_data_len::id()); - let program_id = self + let instruction_context = self .transaction_context - .get_program_key() + .get_current_instruction_context() + .map_err(|_| InstructionError::CallDepth)?; + let program_id = instruction_context + .get_program_key(self.transaction_context) .map_err(|_| InstructionError::CallDepth)?; // Verify all executable accounts have zero outstanding refs @@ -530,8 +533,8 @@ impl<'a> InvokeContext<'a> { let cap_accounts_data_len = self.feature_set.is_active(&cap_accounts_data_len::id()); let transaction_context = &self.transaction_context; let instruction_context = transaction_context.get_current_instruction_context()?; - let program_id = transaction_context - .get_program_key() + let program_id = instruction_context + .get_program_key(transaction_context) .map_err(|_| InstructionError::CallDepth)?; // Verify the per-account instruction results @@ -1268,7 +1271,7 @@ mod tests { ) -> Result<(), InstructionError> { let transaction_context = &invoke_context.transaction_context; let instruction_context = transaction_context.get_current_instruction_context()?; - let program_id = transaction_context.get_program_key()?; + let program_id = instruction_context.get_program_key(transaction_context)?; assert_eq!( program_id, instruction_context diff --git a/program-runtime/src/native_loader.rs b/program-runtime/src/native_loader.rs index 181387393..2b3d7d0a3 100644 --- a/program-runtime/src/native_loader.rs +++ b/program-runtime/src/native_loader.rs @@ -170,7 +170,9 @@ impl NativeLoader { invoke_context: &mut InvokeContext, ) -> Result<(), InstructionError> { let (program_id, name_vec) = { - let program_id = invoke_context.transaction_context.get_program_key()?; + let transaction_context = &invoke_context.transaction_context; + let instruction_context = transaction_context.get_current_instruction_context()?; + let program_id = instruction_context.get_program_key(transaction_context)?; let keyed_accounts = invoke_context.get_keyed_accounts()?; let program = keyed_account_at_index(keyed_accounts, first_instruction_account)?; if native_loader::id() != *program_id { diff --git a/program-test/src/lib.rs b/program-test/src/lib.rs index 1bfff1c9d..bd5662237 100644 --- a/program-test/src/lib.rs +++ b/program-test/src/lib.rs @@ -105,7 +105,7 @@ pub fn builtin_process_instruction( ..instruction_context.get_number_of_accounts(); let log_collector = invoke_context.get_log_collector(); - let program_id = transaction_context.get_program_key()?; + let program_id = instruction_context.get_program_key(transaction_context)?; stable_log::program_invoke( &log_collector, program_id, @@ -250,10 +250,12 @@ impl solana_sdk::program_stubs::SyscallStubs for SyscallStubs { ) -> ProgramResult { let invoke_context = get_invoke_context(); let log_collector = invoke_context.get_log_collector(); - - let caller = *invoke_context - .transaction_context - .get_program_key() + let transaction_context = &invoke_context.transaction_context; + let instruction_context = transaction_context + .get_current_instruction_context() + .unwrap(); + let caller = instruction_context + .get_program_key(transaction_context) .unwrap(); stable_log::program_invoke( @@ -264,7 +266,7 @@ impl solana_sdk::program_stubs::SyscallStubs for SyscallStubs { let signers = signers_seeds .iter() - .map(|seeds| Pubkey::create_program_address(seeds, &caller).unwrap()) + .map(|seeds| Pubkey::create_program_address(seeds, caller).unwrap()) .collect::>(); let (instruction_accounts, program_indices) = invoke_context .prepare_instruction(instruction, &signers) @@ -374,12 +376,14 @@ impl solana_sdk::program_stubs::SyscallStubs for SyscallStubs { fn sol_set_return_data(&self, data: &[u8]) { let invoke_context = get_invoke_context(); - let caller = *invoke_context - .transaction_context - .get_program_key() + let transaction_context = &mut invoke_context.transaction_context; + let instruction_context = transaction_context + .get_current_instruction_context() .unwrap(); - invoke_context - .transaction_context + let caller = *instruction_context + .get_program_key(transaction_context) + .unwrap(); + transaction_context .set_return_data(caller, data.to_vec()) .unwrap(); } diff --git a/programs/bpf/tests/programs.rs b/programs/bpf/tests/programs.rs index 66b10ec7c..2543ea603 100644 --- a/programs/bpf/tests/programs.rs +++ b/programs/bpf/tests/programs.rs @@ -229,13 +229,12 @@ fn run_program(name: &str) -> u64 { let mut instruction_count = 0; let mut tracer = None; for i in 0..2 { - invoke_context - .transaction_context + let transaction_context = &mut invoke_context.transaction_context; + let instruction_context = transaction_context.get_current_instruction_context().unwrap(); + let caller = *instruction_context.get_program_key(transaction_context).unwrap(); + transaction_context .set_return_data( - *invoke_context - .transaction_context - .get_program_key() - .unwrap(), + caller, Vec::new(), ) .unwrap(); diff --git a/programs/bpf_loader/src/lib.rs b/programs/bpf_loader/src/lib.rs index 5dcfd80d9..d5f3e5264 100644 --- a/programs/bpf_loader/src/lib.rs +++ b/programs/bpf_loader/src/lib.rs @@ -273,7 +273,9 @@ fn process_instruction_common( use_jit: bool, ) -> Result<(), InstructionError> { let log_collector = invoke_context.get_log_collector(); - let program_id = invoke_context.transaction_context.get_program_key()?; + let transaction_context = &invoke_context.transaction_context; + let instruction_context = transaction_context.get_current_instruction_context()?; + let program_id = instruction_context.get_program_key(transaction_context)?; let keyed_accounts = invoke_context.get_keyed_accounts()?; let first_account = keyed_account_at_index(keyed_accounts, first_instruction_account)?; @@ -350,7 +352,9 @@ fn process_instruction_common( use_jit, false, )?; - let program_id = invoke_context.transaction_context.get_program_key()?; + let transaction_context = &invoke_context.transaction_context; + let instruction_context = transaction_context.get_current_instruction_context()?; + let program_id = instruction_context.get_program_key(transaction_context)?; invoke_context.add_executor(program_id, executor.clone()); executor } @@ -397,7 +401,9 @@ fn process_loader_upgradeable_instruction( use_jit: bool, ) -> Result<(), InstructionError> { let log_collector = invoke_context.get_log_collector(); - let program_id = invoke_context.transaction_context.get_program_key()?; + let transaction_context = &invoke_context.transaction_context; + let instruction_context = transaction_context.get_current_instruction_context()?; + let program_id = instruction_context.get_program_key(transaction_context)?; let keyed_accounts = invoke_context.get_keyed_accounts()?; match limited_deserialize(instruction_data)? { @@ -550,7 +556,9 @@ fn process_loader_upgradeable_instruction( .accounts .push(AccountMeta::new(*buffer.unsigned_key(), false)); - let caller_program_id = invoke_context.transaction_context.get_program_key()?; + let transaction_context = &invoke_context.transaction_context; + let instruction_context = transaction_context.get_current_instruction_context()?; + let caller_program_id = instruction_context.get_program_key(transaction_context)?; let signers = [&[new_program_id.as_ref(), &[bump_seed]]] .iter() .map(|seeds| Pubkey::create_program_address(*seeds, caller_program_id)) @@ -947,7 +955,9 @@ fn process_loader_instruction( invoke_context: &mut InvokeContext, use_jit: bool, ) -> Result<(), InstructionError> { - let program_id = invoke_context.transaction_context.get_program_key()?; + let transaction_context = &invoke_context.transaction_context; + let instruction_context = transaction_context.get_current_instruction_context()?; + let program_id = instruction_context.get_program_key(transaction_context)?; let keyed_accounts = invoke_context.get_keyed_accounts()?; let program = keyed_account_at_index(keyed_accounts, first_instruction_account)?; if program.owner()? != *program_id { @@ -1036,15 +1046,13 @@ impl Executor for BpfExecutor { let log_collector = invoke_context.get_log_collector(); let compute_meter = invoke_context.get_compute_meter(); let stack_height = invoke_context.get_stack_height(); + let transaction_context = &invoke_context.transaction_context; + let instruction_context = transaction_context.get_current_instruction_context()?; + let program_id = *instruction_context.get_program_key(transaction_context)?; let mut serialize_time = Measure::start("serialize"); - let program_id = *invoke_context.transaction_context.get_program_key()?; - let (mut parameter_bytes, account_lengths) = serialize_parameters( - invoke_context.transaction_context, - invoke_context - .transaction_context - .get_current_instruction_context()?, - )?; + let (mut parameter_bytes, account_lengths) = + serialize_parameters(invoke_context.transaction_context, instruction_context)?; serialize_time.stop(); let mut create_vm_time = Measure::start("create_vm"); let mut execute_time; diff --git a/programs/bpf_loader/src/syscalls.rs b/programs/bpf_loader/src/syscalls.rs index 975e0c253..7f584fdc3 100644 --- a/programs/bpf_loader/src/syscalls.rs +++ b/programs/bpf_loader/src/syscalls.rs @@ -289,7 +289,11 @@ pub fn bind_syscall_context_objects<'a, 'b>( let loader_id = invoke_context .transaction_context - .get_loader_key() + .get_current_instruction_context() + .and_then(|instruction_context| { + instruction_context.try_borrow_program_account(invoke_context.transaction_context) + }) + .map(|program_account| *program_account.get_owner()) .map_err(SyscallError::InstructionError)?; let invoke_context = Rc::new(RefCell::new(invoke_context)); @@ -617,6 +621,18 @@ fn translate_string_and_do( } } +/// Returns the owner of the program account in the current InstructionContext +fn get_current_loader_key(invoke_context: &InvokeContext) -> Result { + invoke_context + .transaction_context + .get_current_instruction_context() + .and_then(|instruction_context| { + instruction_context.try_borrow_program_account(invoke_context.transaction_context) + }) + .map(|program_account| *program_account.get_owner()) + .map_err(SyscallError::InstructionError) +} + /// Abort syscall functions, called when the BPF program calls `abort()` /// LLVM will insert calls to `abort()` if it detects an untenable situation, /// `abort()` is not intended to be called explicitly by the program. @@ -666,20 +682,11 @@ impl<'a, 'b> SyscallObject for SyscallPanic<'a, 'b> { { question_mark!(invoke_context.get_compute_meter().consume(len), result); } - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); - *result = translate_string_and_do( - memory_mapping, - file, - len, - &loader_id, - &mut |string: &str| Err(SyscallError::Panic(string.to_string(), line, column).into()), - ); + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); + *result = + translate_string_and_do(memory_mapping, file, len, loader_id, &mut |string: &str| { + Err(SyscallError::Panic(string.to_string(), line, column).into()) + }); } } @@ -717,24 +724,12 @@ impl<'a, 'b> SyscallObject for SyscallLog<'a, 'b> { }; question_mark!(invoke_context.get_compute_meter().consume(cost), result); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); question_mark!( - translate_string_and_do( - memory_mapping, - addr, - len, - &loader_id, - &mut |string: &str| { - stable_log::program_log(&invoke_context.get_log_collector(), string); - Ok(0) - }, - ), + translate_string_and_do(memory_mapping, addr, len, loader_id, &mut |string: &str| { + stable_log::program_log(&invoke_context.get_log_collector(), string); + Ok(0) + }), result ); *result = Ok(0); @@ -840,15 +835,9 @@ impl<'a, 'b> SyscallObject for SyscallLogPubkey<'a, 'b> { let cost = invoke_context.get_compute_budget().log_pubkey_units; question_mark!(invoke_context.get_compute_meter().consume(cost), result); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); let pubkey = question_mark!( - translate_type::(memory_mapping, pubkey_addr, &loader_id), + translate_type::(memory_mapping, pubkey_addr, loader_id), result ); stable_log::program_log(&invoke_context.get_log_collector(), &pubkey.to_string()); @@ -957,20 +946,14 @@ impl<'a, 'b> SyscallObject for SyscallCreateProgramAddress<'a, 'b> { .create_program_address_units; question_mark!(invoke_context.get_compute_meter().consume(cost), result); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); let (seeds, program_id) = question_mark!( translate_and_check_program_address_inputs( seeds_addr, seeds_len, program_id_addr, memory_mapping, - &loader_id, + loader_id, ), result ); @@ -983,7 +966,7 @@ impl<'a, 'b> SyscallObject for SyscallCreateProgramAddress<'a, 'b> { } }; let address = question_mark!( - translate_slice_mut::(memory_mapping, address_addr, 32, &loader_id), + translate_slice_mut::(memory_mapping, address_addr, 32, loader_id), result ); address.copy_from_slice(new_address.as_ref()); @@ -1017,20 +1000,14 @@ impl<'a, 'b> SyscallObject for SyscallTryFindProgramAddress<'a, 'b> { .create_program_address_units; question_mark!(invoke_context.get_compute_meter().consume(cost), result); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); let (seeds, program_id) = question_mark!( translate_and_check_program_address_inputs( seeds_addr, seeds_len, program_id_addr, memory_mapping, - &loader_id, + loader_id, ), result ); @@ -1045,11 +1022,11 @@ impl<'a, 'b> SyscallObject for SyscallTryFindProgramAddress<'a, 'b> { Pubkey::create_program_address(&seeds_with_bump, program_id) { let bump_seed_ref = question_mark!( - translate_type_mut::(memory_mapping, bump_seed_addr, &loader_id), + translate_type_mut::(memory_mapping, bump_seed_addr, loader_id), result ); let address = question_mark!( - translate_slice_mut::(memory_mapping, address_addr, 32, &loader_id), + translate_slice_mut::(memory_mapping, address_addr, 32, loader_id), result ); *bump_seed_ref = bump_seed[0]; @@ -1094,21 +1071,15 @@ impl<'a, 'b> SyscallObject for SyscallSha256<'a, 'b> { result ); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); let hash_result = question_mark!( - translate_slice_mut::(memory_mapping, result_addr, HASH_BYTES as u64, &loader_id), + translate_slice_mut::(memory_mapping, result_addr, HASH_BYTES as u64, loader_id), result ); let mut hasher = Hasher::default(); if vals_len > 0 { let vals = question_mark!( - translate_slice::<&[u8]>(memory_mapping, vals_addr, vals_len, &loader_id), + translate_slice::<&[u8]>(memory_mapping, vals_addr, vals_len, loader_id), result ); for val in vals.iter() { @@ -1117,7 +1088,7 @@ impl<'a, 'b> SyscallObject for SyscallSha256<'a, 'b> { memory_mapping, val.as_ptr() as u64, val.len() as u64, - &loader_id, + loader_id, ), result ); @@ -1170,17 +1141,11 @@ impl<'a, 'b> SyscallObject for SyscallGetClockSysvar<'a, 'b> { .map_err(|_| SyscallError::InvokeContextBorrowFailed), result ); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); *result = get_sysvar( invoke_context.get_sysvar_cache().get_clock(), var_addr, - &loader_id, + loader_id, memory_mapping, &mut invoke_context, ); @@ -1207,17 +1172,11 @@ impl<'a, 'b> SyscallObject for SyscallGetEpochScheduleSysvar<'a, 'b> { .map_err(|_| SyscallError::InvokeContextBorrowFailed), result ); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); *result = get_sysvar( invoke_context.get_sysvar_cache().get_epoch_schedule(), var_addr, - &loader_id, + loader_id, memory_mapping, &mut invoke_context, ); @@ -1245,17 +1204,11 @@ impl<'a, 'b> SyscallObject for SyscallGetFeesSysvar<'a, 'b> { .map_err(|_| SyscallError::InvokeContextBorrowFailed), result ); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); *result = get_sysvar( invoke_context.get_sysvar_cache().get_fees(), var_addr, - &loader_id, + loader_id, memory_mapping, &mut invoke_context, ); @@ -1282,17 +1235,11 @@ impl<'a, 'b> SyscallObject for SyscallGetRentSysvar<'a, 'b> { .map_err(|_| SyscallError::InvokeContextBorrowFailed), result ); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); *result = get_sysvar( invoke_context.get_sysvar_cache().get_rent(), var_addr, - &loader_id, + loader_id, memory_mapping, &mut invoke_context, ); @@ -1328,26 +1275,20 @@ impl<'a, 'b> SyscallObject for SyscallKeccak256<'a, 'b> { result ); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); let hash_result = question_mark!( translate_slice_mut::( memory_mapping, result_addr, keccak::HASH_BYTES as u64, - &loader_id, + loader_id, ), result ); let mut hasher = keccak::Hasher::default(); if vals_len > 0 { let vals = question_mark!( - translate_slice::<&[u8]>(memory_mapping, vals_addr, vals_len, &loader_id), + translate_slice::<&[u8]>(memory_mapping, vals_addr, vals_len, loader_id), result ); for val in vals.iter() { @@ -1356,7 +1297,7 @@ impl<'a, 'b> SyscallObject for SyscallKeccak256<'a, 'b> { memory_mapping, val.as_ptr() as u64, val.len() as u64, - &loader_id, + loader_id, ), result ); @@ -1451,19 +1392,13 @@ impl<'a, 'b> SyscallObject for SyscallMemcpy<'a, 'b> { question_mark!(invoke_context.get_compute_meter().consume(cost), result); }; - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); let dst = question_mark!( - translate_slice_mut::(memory_mapping, dst_addr, n, &loader_id), + translate_slice_mut::(memory_mapping, dst_addr, n, loader_id), result ); let src = question_mark!( - translate_slice::(memory_mapping, src_addr, n, &loader_id), + translate_slice::(memory_mapping, src_addr, n, loader_id), result ); unsafe { @@ -1495,19 +1430,13 @@ impl<'a, 'b> SyscallObject for SyscallMemmove<'a, 'b> { ); question_mark!(mem_op_consume(&invoke_context, n), result); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); let dst = question_mark!( - translate_slice_mut::(memory_mapping, dst_addr, n, &loader_id), + translate_slice_mut::(memory_mapping, dst_addr, n, loader_id), result ); let src = question_mark!( - translate_slice::(memory_mapping, src_addr, n, &loader_id), + translate_slice::(memory_mapping, src_addr, n, loader_id), result ); unsafe { @@ -1539,23 +1468,17 @@ impl<'a, 'b> SyscallObject for SyscallMemcmp<'a, 'b> { ); question_mark!(mem_op_consume(&invoke_context, n), result); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); let s1 = question_mark!( - translate_slice::(memory_mapping, s1_addr, n, &loader_id), + translate_slice::(memory_mapping, s1_addr, n, loader_id), result ); let s2 = question_mark!( - translate_slice::(memory_mapping, s2_addr, n, &loader_id), + translate_slice::(memory_mapping, s2_addr, n, loader_id), result ); let cmp_result = question_mark!( - translate_type_mut::(memory_mapping, cmp_result_addr, &loader_id), + translate_type_mut::(memory_mapping, cmp_result_addr, loader_id), result ); let mut i = 0; @@ -1596,15 +1519,9 @@ impl<'a, 'b> SyscallObject for SyscallMemset<'a, 'b> { ); question_mark!(mem_op_consume(&invoke_context, n), result); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); let s = question_mark!( - translate_slice_mut::(memory_mapping, s_addr, n, &loader_id), + translate_slice_mut::(memory_mapping, s_addr, n, loader_id), result ); for val in s.iter_mut().take(n as usize) { @@ -1639,19 +1556,13 @@ impl<'a, 'b> SyscallObject for SyscallSecp256k1Recover<'a, 'b> { let cost = invoke_context.get_compute_budget().secp256k1_recover_cost; question_mark!(invoke_context.get_compute_meter().consume(cost), result); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); let hash = question_mark!( translate_slice::( memory_mapping, hash_addr, keccak::HASH_BYTES as u64, - &loader_id, + loader_id, ), result ); @@ -1660,7 +1571,7 @@ impl<'a, 'b> SyscallObject for SyscallSecp256k1Recover<'a, 'b> { memory_mapping, signature_addr, SECP256K1_SIGNATURE_LENGTH as u64, - &loader_id, + loader_id, ), result ); @@ -1669,7 +1580,7 @@ impl<'a, 'b> SyscallObject for SyscallSecp256k1Recover<'a, 'b> { memory_mapping, result_addr, SECP256K1_PUBLIC_KEY_LENGTH as u64, - &loader_id, + loader_id, ), result ); @@ -1744,20 +1655,13 @@ impl<'a, 'b> SyscallObject for SyscallZkTokenElgamalOp<'a, 'b> { let cost = invoke_context.get_compute_budget().zk_token_elgamal_op_cost; question_mark!(invoke_context.get_compute_meter().consume(cost), result); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); - + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); let ct_0 = question_mark!( - translate_type::(memory_mapping, ct_0_addr, &loader_id), + translate_type::(memory_mapping, ct_0_addr, loader_id), result ); let ct_1 = question_mark!( - translate_type::(memory_mapping, ct_1_addr, &loader_id), + translate_type::(memory_mapping, ct_1_addr, loader_id), result ); @@ -1770,7 +1674,7 @@ impl<'a, 'b> SyscallObject for SyscallZkTokenElgamalOp<'a, 'b> { translate_type_mut::( memory_mapping, ct_result_addr, - &loader_id, + loader_id, ), result ) = ct_result; @@ -1807,24 +1711,17 @@ impl<'a, 'b> SyscallObject for SyscallZkTokenElgamalOpWithLoHi<'a, 'b> let cost = invoke_context.get_compute_budget().zk_token_elgamal_op_cost; question_mark!(invoke_context.get_compute_meter().consume(cost), result); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); - + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); let ct_0 = question_mark!( - translate_type::(memory_mapping, ct_0_addr, &loader_id), + translate_type::(memory_mapping, ct_0_addr, loader_id), result ); let ct_1_lo = question_mark!( - translate_type::(memory_mapping, ct_1_lo_addr, &loader_id), + translate_type::(memory_mapping, ct_1_lo_addr, loader_id), result ); let ct_1_hi = question_mark!( - translate_type::(memory_mapping, ct_1_hi_addr, &loader_id), + translate_type::(memory_mapping, ct_1_hi_addr, loader_id), result ); @@ -1837,7 +1734,7 @@ impl<'a, 'b> SyscallObject for SyscallZkTokenElgamalOpWithLoHi<'a, 'b> translate_type_mut::( memory_mapping, ct_result_addr, - &loader_id, + loader_id, ), result ) = ct_result; @@ -1874,16 +1771,9 @@ impl<'a, 'b> SyscallObject for SyscallZkTokenElgamalOpWithScalar<'a, ' let cost = invoke_context.get_compute_budget().zk_token_elgamal_op_cost; question_mark!(invoke_context.get_compute_meter().consume(cost), result); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); - + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); let ct = question_mark!( - translate_type::(memory_mapping, ct_addr, &loader_id), + translate_type::(memory_mapping, ct_addr, loader_id), result ); @@ -1896,7 +1786,7 @@ impl<'a, 'b> SyscallObject for SyscallZkTokenElgamalOpWithScalar<'a, ' translate_type_mut::( memory_mapping, ct_result_addr, - &loader_id, + loader_id, ), result ) = ct_result; @@ -1936,26 +1826,20 @@ impl<'a, 'b> SyscallObject for SyscallBlake3<'a, 'b> { result ); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); let hash_result = question_mark!( translate_slice_mut::( memory_mapping, result_addr, blake3::HASH_BYTES as u64, - &loader_id, + loader_id, ), result ); let mut hasher = blake3::Hasher::default(); if vals_len > 0 { let vals = question_mark!( - translate_slice::<&[u8]>(memory_mapping, vals_addr, vals_len, &loader_id), + translate_slice::<&[u8]>(memory_mapping, vals_addr, vals_len, loader_id), result ); for val in vals.iter() { @@ -1964,7 +1848,7 @@ impl<'a, 'b> SyscallObject for SyscallBlake3<'a, 'b> { memory_mapping, val.as_ptr() as u64, val.len() as u64, - &loader_id, + loader_id, ), result ); @@ -2676,7 +2560,11 @@ fn call<'a, 'b: 'a>( // Translate and verify caller's data let loader_id = invoke_context .transaction_context - .get_loader_key() + .get_current_instruction_context() + .and_then(|instruction_context| { + instruction_context.try_borrow_program_account(invoke_context.transaction_context) + }) + .map(|program_account| *program_account.get_owner()) .map_err(SyscallError::InstructionError)?; let instruction = syscall.translate_instruction( &loader_id, @@ -2684,9 +2572,12 @@ fn call<'a, 'b: 'a>( memory_mapping, *invoke_context, )?; - let caller_program_id = invoke_context - .transaction_context - .get_program_key() + let transaction_context = &invoke_context.transaction_context; + let instruction_context = transaction_context + .get_current_instruction_context() + .map_err(SyscallError::InstructionError)?; + let caller_program_id = instruction_context + .get_program_key(transaction_context) .map_err(SyscallError::InstructionError)?; let signers = syscall.translate_signers( &loader_id, @@ -2801,14 +2692,7 @@ impl<'a, 'b> SyscallObject for SyscallSetReturnData<'a, 'b> { .map_err(|_| SyscallError::InvokeContextBorrowFailed), result ); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); - + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); let budget = invoke_context.get_compute_budget(); question_mark!( @@ -2827,21 +2711,23 @@ impl<'a, 'b> SyscallObject for SyscallSetReturnData<'a, 'b> { Vec::new() } else { question_mark!( - translate_slice::(memory_mapping, addr, len, &loader_id), + translate_slice::(memory_mapping, addr, len, loader_id), result ) .to_vec() }; + let transaction_context = &mut invoke_context.transaction_context; let program_id = *question_mark!( - invoke_context - .transaction_context - .get_program_key() + transaction_context + .get_current_instruction_context() + .and_then( + |instruction_context| instruction_context.get_program_key(transaction_context) + ) .map_err(SyscallError::InstructionError), result ); question_mark!( - invoke_context - .transaction_context + transaction_context .set_return_data(program_id, return_data) .map_err(SyscallError::InstructionError), result @@ -2871,14 +2757,7 @@ impl<'a, 'b> SyscallObject for SyscallGetReturnData<'a, 'b> { .map_err(|_| SyscallError::InvokeContextBorrowFailed), result ); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); - + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); let budget = invoke_context.get_compute_budget(); question_mark!( @@ -2899,14 +2778,14 @@ impl<'a, 'b> SyscallObject for SyscallGetReturnData<'a, 'b> { ); let return_data_result = question_mark!( - translate_slice_mut::(memory_mapping, return_data_addr, length, &loader_id), + translate_slice_mut::(memory_mapping, return_data_addr, length, loader_id), result ); return_data_result.copy_from_slice(&return_data[..length as usize]); let program_id_result = question_mark!( - translate_slice_mut::(memory_mapping, program_id_addr, 1, &loader_id), + translate_slice_mut::(memory_mapping, program_id_addr, 1, loader_id), result ); @@ -2939,14 +2818,7 @@ impl<'a, 'b> SyscallObject for SyscallLogData<'a, 'b> { .map_err(|_| SyscallError::InvokeContextBorrowFailed), result ); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); - + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); let budget = invoke_context.get_compute_budget(); question_mark!( @@ -2957,7 +2829,7 @@ impl<'a, 'b> SyscallObject for SyscallLogData<'a, 'b> { ); let untranslated_fields = question_mark!( - translate_slice::<&[u8]>(memory_mapping, addr, len, &loader_id), + translate_slice::<&[u8]>(memory_mapping, addr, len, loader_id), result ); @@ -2986,7 +2858,7 @@ impl<'a, 'b> SyscallObject for SyscallLogData<'a, 'b> { memory_mapping, untranslated_field.as_ptr() as *const _ as u64, untranslated_field.len() as u64, - &loader_id, + loader_id, ), result )); @@ -3020,14 +2892,7 @@ impl<'a, 'b> SyscallObject for SyscallGetProcessedSiblingInstruction<' .map_err(|_| SyscallError::InvokeContextBorrowFailed), result ); - let loader_id = question_mark!( - invoke_context - .transaction_context - .get_loader_key() - .map_err(SyscallError::InstructionError), - result - ); - + let loader_id = &question_mark!(get_current_loader_key(&invoke_context), result); let budget = invoke_context.get_compute_budget(); question_mark!( invoke_context @@ -3071,7 +2936,7 @@ impl<'a, 'b> SyscallObject for SyscallGetProcessedSiblingInstruction<' translate_type_mut::( memory_mapping, meta_addr, - &loader_id + loader_id, ), result ); @@ -3080,7 +2945,7 @@ impl<'a, 'b> SyscallObject for SyscallGetProcessedSiblingInstruction<' && *accounts_len == instruction_context.get_number_of_instruction_accounts() { let program_id = question_mark!( - translate_type_mut::(memory_mapping, program_id_addr, &loader_id), + translate_type_mut::(memory_mapping, program_id_addr, loader_id), result ); let data = question_mark!( @@ -3088,7 +2953,7 @@ impl<'a, 'b> SyscallObject for SyscallGetProcessedSiblingInstruction<' memory_mapping, data_addr, *data_len as u64, - &loader_id, + loader_id, ), result ); @@ -3097,7 +2962,7 @@ impl<'a, 'b> SyscallObject for SyscallGetProcessedSiblingInstruction<' memory_mapping, accounts_addr, *accounts_len as u64, - &loader_id, + loader_id, ), result ); @@ -3196,6 +3061,28 @@ mod tests { }; } + macro_rules! prepare_mockup { + ($invoke_context:ident, + $transaction_context:ident, + $program_key:ident, + $loader_key:expr $(,)?) => { + let $program_key = Pubkey::new_unique(); + let mut $transaction_context = TransactionContext::new( + vec![ + ( + $loader_key, + AccountSharedData::new(0, 0, &native_loader::id()), + ), + ($program_key, AccountSharedData::new(0, 0, &$loader_key)), + ], + 1, + 1, + ); + let mut $invoke_context = InvokeContext::new_mock(&mut $transaction_context, &[]); + $invoke_context.push(&[], &[0, 1], &[]).unwrap(); + }; + } + #[allow(dead_code)] struct MockSlice { pub vm_addr: u64, @@ -3490,14 +3377,12 @@ mod tests { #[test] #[should_panic(expected = "UserError(SyscallError(Panic(\"Gaggablaghblagh!\", 42, 84)))")] fn test_syscall_sol_panic() { - let program_id = Pubkey::new_unique(); - let mut transaction_context = TransactionContext::new( - vec![(program_id, AccountSharedData::new(0, 0, &bpf_loader::id()))], - 1, - 1, + prepare_mockup!( + invoke_context, + transaction_context, + program_id, + bpf_loader::id(), ); - let mut invoke_context = InvokeContext::new_mock(&mut transaction_context, &[]); - invoke_context.push(&[], &[0], &[]).unwrap(); let mut syscall_panic = SyscallPanic { invoke_context: Rc::new(RefCell::new(&mut invoke_context)), }; @@ -3564,14 +3449,12 @@ mod tests { #[test] fn test_syscall_sol_log() { - let program_id = Pubkey::new_unique(); - let mut transaction_context = TransactionContext::new( - vec![(program_id, AccountSharedData::new(0, 0, &bpf_loader::id()))], - 1, - 1, + prepare_mockup!( + invoke_context, + transaction_context, + program_id, + bpf_loader::id(), ); - let mut invoke_context = InvokeContext::new_mock(&mut transaction_context, &[]); - invoke_context.push(&[], &[0], &[]).unwrap(); let mut syscall_sol_log = SyscallLog { invoke_context: Rc::new(RefCell::new(&mut invoke_context)), }; @@ -3665,14 +3548,12 @@ mod tests { #[test] fn test_syscall_sol_log_u64() { - let program_id = Pubkey::new_unique(); - let mut transaction_context = TransactionContext::new( - vec![(program_id, AccountSharedData::new(0, 0, &bpf_loader::id()))], - 1, - 1, + prepare_mockup!( + invoke_context, + transaction_context, + program_id, + bpf_loader::id(), ); - let mut invoke_context = InvokeContext::new_mock(&mut transaction_context, &[]); - invoke_context.push(&[], &[0], &[]).unwrap(); let cost = invoke_context.get_compute_budget().log_64_units; let mut syscall_sol_log_u64 = SyscallLogU64 { invoke_context: Rc::new(RefCell::new(&mut invoke_context)), @@ -3704,14 +3585,12 @@ mod tests { #[test] fn test_syscall_sol_pubkey() { - let program_id = Pubkey::new_unique(); - let mut transaction_context = TransactionContext::new( - vec![(program_id, AccountSharedData::new(0, 0, &bpf_loader::id()))], - 1, - 1, + prepare_mockup!( + invoke_context, + transaction_context, + program_id, + bpf_loader::id(), ); - let mut invoke_context = InvokeContext::new_mock(&mut transaction_context, &[]); - invoke_context.push(&[], &[0], &[]).unwrap(); let cost = invoke_context.get_compute_budget().log_pubkey_units; let mut syscall_sol_pubkey = SyscallLogPubkey { invoke_context: Rc::new(RefCell::new(&mut invoke_context)), @@ -3913,17 +3792,12 @@ mod tests { #[test] fn test_syscall_sha256() { let config = Config::default(); - let program_id = Pubkey::new_unique(); - let mut transaction_context = TransactionContext::new( - vec![( - program_id, - AccountSharedData::new(0, 0, &bpf_loader_deprecated::id()), - )], - 1, - 1, + prepare_mockup!( + invoke_context, + transaction_context, + program_id, + bpf_loader_deprecated::id(), ); - let mut invoke_context = InvokeContext::new_mock(&mut transaction_context, &[]); - invoke_context.push(&[], &[0], &[]).unwrap(); let bytes1 = "Gaggablaghblagh!"; let bytes2 = "flurbos"; @@ -4074,15 +3948,13 @@ mod tests { sysvar_cache.set_fees(src_fees.clone()); sysvar_cache.set_rent(src_rent); - let program_id = Pubkey::new_unique(); - let mut transaction_context = TransactionContext::new( - vec![(program_id, AccountSharedData::new(0, 0, &bpf_loader::id()))], - 1, - 1, + prepare_mockup!( + invoke_context, + transaction_context, + program_id, + bpf_loader::id(), ); - let mut invoke_context = InvokeContext::new_mock(&mut transaction_context, &[]); invoke_context.sysvar_cache = Cow::Owned(sysvar_cache); - invoke_context.push(&[], &[0], &[]).unwrap(); // Test clock sysvar { @@ -4325,14 +4197,12 @@ mod tests { fn test_create_program_address() { // These tests duplicate the direct tests in solana_program::pubkey - let program_id = Pubkey::new_unique(); - let mut transaction_context = TransactionContext::new( - vec![(program_id, AccountSharedData::new(0, 0, &bpf_loader::id()))], - 1, - 1, + prepare_mockup!( + invoke_context, + transaction_context, + program_id, + bpf_loader::id(), ); - let mut invoke_context = InvokeContext::new_mock(&mut transaction_context, &[]); - invoke_context.push(&[], &[0], &[]).unwrap(); let address = bpf_loader_upgradeable::id(); let exceeded_seed = &[127; MAX_SEED_LEN + 1]; @@ -4438,14 +4308,12 @@ mod tests { #[test] fn test_find_program_address() { - let program_id = Pubkey::new_unique(); - let mut transaction_context = TransactionContext::new( - vec![(program_id, AccountSharedData::new(0, 0, &bpf_loader::id()))], - 1, - 1, + prepare_mockup!( + invoke_context, + transaction_context, + program_id, + bpf_loader::id(), ); - let mut invoke_context = InvokeContext::new_mock(&mut transaction_context, &[]); - invoke_context.push(&[], &[0], &[]).unwrap(); let cost = invoke_context .get_compute_budget() .create_program_address_units; diff --git a/runtime/src/bank.rs b/runtime/src/bank.rs index 5eb6dab10..22b6aeffc 100644 --- a/runtime/src/bank.rs +++ b/runtime/src/bank.rs @@ -11212,7 +11212,9 @@ pub(crate) mod tests { _instruction_data: &[u8], invoke_context: &mut InvokeContext, ) -> std::result::Result<(), InstructionError> { - let program_id = invoke_context.transaction_context.get_program_key()?; + let transaction_context = &invoke_context.transaction_context; + let instruction_context = transaction_context.get_current_instruction_context()?; + let program_id = instruction_context.get_program_key(transaction_context)?; if mock_vote_program_id() != *program_id { return Err(InstructionError::IncorrectProgramId); } diff --git a/runtime/src/builtins.rs b/runtime/src/builtins.rs index 8bb618dbb..b02163413 100644 --- a/runtime/src/builtins.rs +++ b/runtime/src/builtins.rs @@ -19,12 +19,16 @@ fn process_instruction_with_program_logging( invoke_context: &mut InvokeContext, ) -> Result<(), InstructionError> { let logger = invoke_context.get_log_collector(); - let program_id = invoke_context.transaction_context.get_program_key()?; + let transaction_context = &invoke_context.transaction_context; + let instruction_context = transaction_context.get_current_instruction_context()?; + let program_id = instruction_context.get_program_key(transaction_context)?; stable_log::program_invoke(&logger, program_id, invoke_context.get_stack_height()); let result = process_instruction(first_instruction_account, instruction_data, invoke_context); - let program_id = invoke_context.transaction_context.get_program_key()?; + let transaction_context = &invoke_context.transaction_context; + let instruction_context = transaction_context.get_current_instruction_context()?; + let program_id = instruction_context.get_program_key(transaction_context)?; match &result { Ok(()) => stable_log::program_success(&logger, program_id), Err(err) => stable_log::program_failure(&logger, program_id, err), diff --git a/sdk/src/transaction_context.rs b/sdk/src/transaction_context.rs index 60934ad9a..cf87b7622 100644 --- a/sdk/src/transaction_context.rs +++ b/sdk/src/transaction_context.rs @@ -228,20 +228,6 @@ impl TransactionContext { Ok(()) } - /// Returns the key of the current InstructionContexts program account - pub fn get_program_key(&self) -> Result<&Pubkey, InstructionError> { - let instruction_context = self.get_current_instruction_context()?; - let program_account = instruction_context.try_borrow_program_account(self)?; - Ok(&self.account_keys[program_account.index_in_transaction]) - } - - /// Returns the owner of the current InstructionContexts program account - pub fn get_loader_key(&self) -> Result { - let instruction_context = self.get_current_instruction_context()?; - let program_account = instruction_context.try_borrow_program_account(self)?; - Ok(*program_account.get_owner()) - } - /// Gets the return data of the current InstructionContext or any above pub fn get_return_data(&self) -> (&Pubkey, &[u8]) { (&self.return_data.0, &self.return_data.1) @@ -408,6 +394,16 @@ impl InstructionContext { }) } + /// Gets the key of the last program account of this Instruction + pub fn get_program_key<'a, 'b: 'a>( + &'a self, + transaction_context: &'b TransactionContext, + ) -> Result<&'b Pubkey, InstructionError> { + let index_in_transaction = + self.get_index_in_transaction(self.program_accounts.len().saturating_sub(1))?; + transaction_context.get_key_of_account_at_index(index_in_transaction) + } + /// Gets the last program account of this Instruction pub fn try_borrow_program_account<'a, 'b: 'a>( &'a self,