diff --git a/programs/bpf/rust/invoke/src/lib.rs b/programs/bpf/rust/invoke/src/lib.rs index fdb66f2b3e..fa179eff26 100644 --- a/programs/bpf/rust/invoke/src/lib.rs +++ b/programs/bpf/rust/invoke/src/lib.rs @@ -10,7 +10,7 @@ use solana_sdk::{ entrypoint, entrypoint::ProgramResult, info, - program::{create_program_address, invoke, invoke_signed}, + program::{invoke, invoke_signed}, program_error::ProgramError, pubkey::Pubkey, system_instruction, @@ -97,7 +97,8 @@ fn process_instruction( info!("Test create_program_address"); { - let address = create_program_address(&[b"You pass butter", &[nonce1]], program_id)?; + let address = + Pubkey::create_program_address(&[b"You pass butter", &[nonce1]], program_id)?; assert_eq!(&address, accounts[DERIVED_KEY1_INDEX].key); } diff --git a/sdk/src/program.rs b/sdk/src/program.rs index 94ba27a0ee..d2a2fbd3e9 100644 --- a/sdk/src/program.rs +++ b/sdk/src/program.rs @@ -2,39 +2,9 @@ use crate::{ account_info::AccountInfo, entrypoint::ProgramResult, entrypoint::SUCCESS, - instruction::Instruction, program_error::ProgramError, pubkey::Pubkey, + instruction::Instruction, }; -pub fn create_program_address( - seeds: &[&[u8]], - program_id: &Pubkey, -) -> Result { - let bytes = [ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, - ]; - let result = unsafe { - sol_create_program_address( - seeds as *const _ as *const u8, - seeds.len() as u64, - program_id as *const _ as *const u8, - &bytes as *const _ as *const u8, - ) - }; - match result { - SUCCESS => Ok(Pubkey::new(&bytes)), - _ => Err(result.into()), - } -} -extern "C" { - fn sol_create_program_address( - seeds_addr: *const u8, - seeds_len: u64, - program_id_addr: *const u8, - address_bytes_addr: *const u8, - ) -> u64; -} - /// Invoke a cross-program instruction pub fn invoke(instruction: &Instruction, account_infos: &[AccountInfo]) -> ProgramResult { invoke_signed(instruction, account_infos, &[]) diff --git a/sdk/src/pubkey.rs b/sdk/src/pubkey.rs index ad5d6248b9..c6a69c4f8e 100644 --- a/sdk/src/pubkey.rs +++ b/sdk/src/pubkey.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "program")] +use crate::entrypoint::SUCCESS; #[cfg(not(feature = "program"))] use crate::hash::Hasher; use crate::{decode_error::DecodeError, hash::hashv}; @@ -25,6 +27,15 @@ impl DecodeError for PubkeyError { "PubkeyError" } } +impl From for PubkeyError { + fn from(error: u64) -> Self { + match error { + 0 => PubkeyError::MaxSeedLengthExceeded, + 1 => PubkeyError::InvalidSeeds, + _ => panic!("Unsupported PubkeyError"), + } + } +} #[repr(transparent)] #[derive( @@ -90,34 +101,66 @@ impl Pubkey { /// Create a program address, valid program address must not be on the /// ed25519 curve - #[cfg(not(feature = "program"))] pub fn create_program_address( seeds: &[&[u8]], program_id: &Pubkey, ) -> Result { - let mut hasher = Hasher::default(); - for seed in seeds.iter() { - if seed.len() > MAX_SEED_LEN { - return Err(PubkeyError::MaxSeedLengthExceeded); - } - hasher.hash(seed); - } - hasher.hashv(&[program_id.as_ref(), "ProgramDerivedAddress".as_ref()]); - let hash = hasher.result(); - - if curve25519_dalek::edwards::CompressedEdwardsY::from_slice(hash.as_ref()) - .decompress() - .is_some() + // Perform the calculation inline, calling this from within a program is + // not supported + #[cfg(not(feature = "program"))] { - return Err(PubkeyError::InvalidSeeds); - } + let mut hasher = Hasher::default(); + for seed in seeds.iter() { + if seed.len() > MAX_SEED_LEN { + return Err(PubkeyError::MaxSeedLengthExceeded); + } + hasher.hash(seed); + } + hasher.hashv(&[program_id.as_ref(), "ProgramDerivedAddress".as_ref()]); + let hash = hasher.result(); - Ok(Pubkey::new(hash.as_ref())) + if curve25519_dalek::edwards::CompressedEdwardsY::from_slice(hash.as_ref()) + .decompress() + .is_some() + { + return Err(PubkeyError::InvalidSeeds); + } + + Ok(Pubkey::new(hash.as_ref())) + } + // Call via a system call to perform the calculation + #[cfg(feature = "program")] + { + extern "C" { + fn sol_create_program_address( + seeds_addr: *const u8, + seeds_len: u64, + program_id_addr: *const u8, + address_bytes_addr: *const u8, + ) -> u64; + }; + let bytes = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ]; + let result = unsafe { + sol_create_program_address( + seeds as *const _ as *const u8, + seeds.len() as u64, + program_id as *const _ as *const u8, + &bytes as *const _ as *const u8, + ) + }; + match result { + SUCCESS => Ok(Pubkey::new(&bytes)), + _ => Err(result.into()), + } + } } /// Find a valid program address and its corresponding nonce which must be passed /// as an additional seed when calling `create_program_address` - #[cfg(not(feature = "program"))] + // #[cfg(not(feature = "program"))] pub fn find_program_address(seeds: &[&[u8]], program_id: &Pubkey) -> (Pubkey, u8) { let mut nonce = [255]; for _ in 0..std::u8::MAX { @@ -143,6 +186,8 @@ impl Pubkey { } } +// TODO localalize this + impl AsRef<[u8]> for Pubkey { fn as_ref(&self) -> &[u8] { &self.0[..]