diff --git a/token/program-2022-test/tests/confidential_transfer.rs b/token/program-2022-test/tests/confidential_transfer.rs index 626b09e4..95bb41a0 100644 --- a/token/program-2022-test/tests/confidential_transfer.rs +++ b/token/program-2022-test/tests/confidential_transfer.rs @@ -10,6 +10,7 @@ use { signer::keypair::Keypair, transaction::TransactionError, transport::TransportError, }, spl_token_2022::{ + error::TokenError, extension::{ confidential_transfer::{ ConfidentialTransferAccount, ConfidentialTransferMint, EncryptedWithheldAmount, @@ -331,6 +332,27 @@ async fn ct_configure_token_account() { .get_extension::() .unwrap(); assert!(bool::from(&extension.approved)); + + // Configuring an already initialized account should produce an error + let err = token + .confidential_transfer_configure_token_account( + &alice_meta.token_account, + &alice, + alice_meta.elgamal_keypair.public, + alice_meta.ae_key.encrypt(0_u64), + ) + .await + .unwrap_err(); + + assert_eq!( + err, + TokenClientError::Client(Box::new(TransportError::TransactionError( + TransactionError::InstructionError( + 0, + InstructionError::Custom(TokenError::ExtensionAlreadyInitialized as u32), + ) + ))) + ); } #[tokio::test] diff --git a/token/program-2022/src/extension/confidential_transfer/processor.rs b/token/program-2022/src/extension/confidential_transfer/processor.rs index 3b8862e7..6c474c38 100644 --- a/token/program-2022/src/extension/confidential_transfer/processor.rs +++ b/token/program-2022/src/extension/confidential_transfer/processor.rs @@ -54,7 +54,7 @@ fn process_initialize_mint( check_program_account(mint_info.owner)?; let mint_data = &mut mint_info.data.borrow_mut(); let mut mint = StateWithExtensionsMut::::unpack_uninitialized(mint_data)?; - *mint.init_extension::()? = *confidential_transfer_mint; + *mint.init_extension::(true)? = *confidential_transfer_mint; Ok(()) } @@ -125,7 +125,7 @@ fn process_configure_account( // Note: The caller is expected to use the `Reallocate` instruction to ensure there is // sufficient room in their token account for the new `ConfidentialTransferAccount` extension let mut confidential_transfer_account = - token_account.init_extension::()?; + token_account.init_extension::(false)?; confidential_transfer_account.approved = confidential_transfer_mint.auto_approve_new_accounts; confidential_transfer_account.encryption_pubkey = *encryption_pubkey; diff --git a/token/program-2022/src/extension/default_account_state/processor.rs b/token/program-2022/src/extension/default_account_state/processor.rs index 6367fa43..becc9fa8 100644 --- a/token/program-2022/src/extension/default_account_state/processor.rs +++ b/token/program-2022/src/extension/default_account_state/processor.rs @@ -36,7 +36,7 @@ fn process_initialize_default_account_state( let mint_account_info = next_account_info(account_info_iter)?; let mut mint_data = mint_account_info.data.borrow_mut(); let mut mint = StateWithExtensionsMut::::unpack_uninitialized(&mut mint_data)?; - let extension = mint.init_extension::()?; + let extension = mint.init_extension::(true)?; extension.state = state.into(); Ok(()) } diff --git a/token/program-2022/src/extension/interest_bearing_mint/processor.rs b/token/program-2022/src/extension/interest_bearing_mint/processor.rs index 24a2bc2f..3c05084a 100644 --- a/token/program-2022/src/extension/interest_bearing_mint/processor.rs +++ b/token/program-2022/src/extension/interest_bearing_mint/processor.rs @@ -36,7 +36,7 @@ fn process_initialize( let mut mint = StateWithExtensionsMut::::unpack_uninitialized(&mut mint_data)?; let clock = Clock::get()?; - let extension = mint.init_extension::()?; + let extension = mint.init_extension::(true)?; extension.rate_authority = *rate_authority; extension.initialization_timestamp = clock.unix_timestamp.into(); extension.last_update_timestamp = clock.unix_timestamp.into(); diff --git a/token/program-2022/src/extension/memo_transfer/processor.rs b/token/program-2022/src/extension/memo_transfer/processor.rs index 5d53aaa9..38c987ca 100644 --- a/token/program-2022/src/extension/memo_transfer/processor.rs +++ b/token/program-2022/src/extension/memo_transfer/processor.rs @@ -40,7 +40,7 @@ fn process_enable_required_memo_transfers( let extension = if let Ok(extension) = account.get_extension_mut::() { extension } else { - account.init_extension::()? + account.init_extension::(true)? }; extension.require_incoming_transfer_memos = true.into(); Ok(()) @@ -69,7 +69,7 @@ fn process_diasble_required_memo_transfers( let extension = if let Ok(extension) = account.get_extension_mut::() { extension } else { - account.init_extension::()? + account.init_extension::(true)? }; extension.require_incoming_transfer_memos = false.into(); Ok(()) diff --git a/token/program-2022/src/extension/mod.rs b/token/program-2022/src/extension/mod.rs index 4902f523..387fc5b1 100644 --- a/token/program-2022/src/extension/mod.rs +++ b/token/program-2022/src/extension/mod.rs @@ -413,7 +413,8 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> { } } - fn init_or_get_extension(&mut self, init: bool) -> Result<&mut V, ProgramError> { + /// Unpack a portion of the TLV data as the desired type that allows modifying the type + pub fn get_extension_mut(&mut self) -> Result<&mut V, ProgramError> { if V::TYPE.get_account_type() != S::ACCOUNT_TYPE { return Err(ProgramError::InvalidAccountData); } @@ -421,38 +422,14 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> { type_start, length_start, value_start, - } = get_extension_indices::(self.tlv_data, init)?; + } = get_extension_indices::(self.tlv_data, false)?; if self.tlv_data[type_start..].len() < V::TYPE.get_tlv_len() { return Err(ProgramError::InvalidAccountData); } - if init { - // write extension type - let extension_type_array: [u8; 2] = V::TYPE.into(); - let extension_type_ref = &mut self.tlv_data[type_start..length_start]; - extension_type_ref.copy_from_slice(&extension_type_array); - // write length - let length_ref = - pod_from_bytes_mut::(&mut self.tlv_data[length_start..value_start])?; - // maybe this becomes smarter later for dynamically sized extensions - let length = pod_get_packed_len::(); - *length_ref = Length::try_from(length).unwrap(); - - let value_end = value_start.saturating_add(length); - let extension_ref = - pod_from_bytes_mut::(&mut self.tlv_data[value_start..value_end])?; - *extension_ref = V::default(); - Ok(extension_ref) - } else { - let length = pod_from_bytes::(&self.tlv_data[length_start..value_start])?; - let value_end = value_start.saturating_add(usize::from(*length)); - pod_from_bytes_mut::(&mut self.tlv_data[value_start..value_end]) - } - } - - /// Unpack a portion of the TLV data as the desired type that allows modifying the type - pub fn get_extension_mut(&mut self) -> Result<&mut V, ProgramError> { - self.init_or_get_extension(false) + let length = pod_from_bytes::(&self.tlv_data[length_start..value_start])?; + let value_end = value_start.saturating_add(usize::from(*length)); + pod_from_bytes_mut::(&mut self.tlv_data[value_start..value_end]) } /// Unpack a portion of the TLV data as the desired type @@ -480,9 +457,48 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> { } /// Packs the default extension data into an open slot if not already found in the - /// data buffer, otherwise overwrites the existing extension with the default state - pub fn init_extension(&mut self) -> Result<&mut V, ProgramError> { - self.init_or_get_extension(true) + /// data buffer. If extension is already found in the buffer, it overwrites the existing + /// extension with the default state if `overwrite` is set. If extension found, but + /// `overwrite` is not set, it returns error. + pub fn init_extension( + &mut self, + overwrite: bool, + ) -> Result<&mut V, ProgramError> { + if V::TYPE.get_account_type() != S::ACCOUNT_TYPE { + return Err(ProgramError::InvalidAccountData); + } + let TlvIndices { + type_start, + length_start, + value_start, + } = get_extension_indices::(self.tlv_data, true)?; + + if self.tlv_data[type_start..].len() < V::TYPE.get_tlv_len() { + return Err(ProgramError::InvalidAccountData); + } + let extension_type = ExtensionType::try_from(&self.tlv_data[type_start..length_start])?; + + if extension_type == ExtensionType::Uninitialized || overwrite { + // write extension type + let extension_type_array: [u8; 2] = V::TYPE.into(); + let extension_type_ref = &mut self.tlv_data[type_start..length_start]; + extension_type_ref.copy_from_slice(&extension_type_array); + // write length + let length_ref = + pod_from_bytes_mut::(&mut self.tlv_data[length_start..value_start])?; + // maybe this becomes smarter later for dynamically sized extensions + let length = pod_get_packed_len::(); + *length_ref = Length::try_from(length).unwrap(); + + let value_end = value_start.saturating_add(length); + let extension_ref = + pod_from_bytes_mut::(&mut self.tlv_data[value_start..value_end])?; + *extension_ref = V::default(); + Ok(extension_ref) + } else { + // extension is already initialized, but no overwrite permission + Err(TokenError::ExtensionAlreadyInitialized.into()) + } } /// If `extension_type` is an Account-associated ExtensionType that requires initialization on @@ -498,14 +514,14 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> { } match extension_type { ExtensionType::TransferFeeAmount => { - self.init_extension::().map(|_| ()) + self.init_extension::(true).map(|_| ()) } // ConfidentialTransfers are currently opt-in only, so this is a no-op for extra safety // on InitializeAccount ExtensionType::ConfidentialTransferAccount => Ok(()), #[cfg(test)] ExtensionType::AccountPaddingTest => { - self.init_extension::().map(|_| ()) + self.init_extension::(true).map(|_| ()) } _ => unreachable!(), } @@ -932,19 +948,27 @@ mod test { let mut state = StateWithExtensionsMut::::unpack_uninitialized(&mut buffer).unwrap(); // fail init account extension assert_eq!( - state.init_extension::(), + state.init_extension::(true), Err(ProgramError::InvalidAccountData), ); // success write extension let close_authority = OptionalNonZeroPubkey::try_from(Some(Pubkey::new(&[1; 32]))).unwrap(); - let extension = state.init_extension::().unwrap(); + let extension = state.init_extension::(true).unwrap(); extension.close_authority = close_authority; assert_eq!( &state.get_extension_types().unwrap(), &[ExtensionType::MintCloseAuthority] ); + // fail init extension when already initialized + assert_eq!( + state.init_extension::(false), + Err(ProgramError::Custom( + TokenError::ExtensionAlreadyInitialized as u32 + )) + ); + // fail unpack as account, a mint extension was written assert_eq!( StateWithExtensionsMut::::unpack_uninitialized(&mut buffer), @@ -1030,7 +1054,7 @@ mod test { let mut state = StateWithExtensionsMut::::unpack(&mut buffer).unwrap(); // init one more extension let mint_transfer_fee = test_transfer_fee_config(); - let new_extension = state.init_extension::().unwrap(); + let new_extension = state.init_extension::(true).unwrap(); new_extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority; new_extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority; @@ -1063,7 +1087,7 @@ mod test { // fail to init one more extension that does not fit let mut state = StateWithExtensionsMut::::unpack(&mut buffer).unwrap(); assert_eq!( - state.init_extension::(), + state.init_extension::(true), Err(ProgramError::InvalidAccountData), ); } @@ -1079,11 +1103,11 @@ mod test { let mut state = StateWithExtensionsMut::::unpack_uninitialized(&mut buffer).unwrap(); // write extensions let close_authority = OptionalNonZeroPubkey::try_from(Some(Pubkey::new(&[1; 32]))).unwrap(); - let extension = state.init_extension::().unwrap(); + let extension = state.init_extension::(true).unwrap(); extension.close_authority = close_authority; let mint_transfer_fee = test_transfer_fee_config(); - let extension = state.init_extension::().unwrap(); + let extension = state.init_extension::(true).unwrap(); extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority; extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority; extension.withheld_amount = mint_transfer_fee.withheld_amount; @@ -1115,7 +1139,7 @@ mod test { // write extensions in a different order let mint_transfer_fee = test_transfer_fee_config(); - let extension = state.init_extension::().unwrap(); + let extension = state.init_extension::(true).unwrap(); extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority; extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority; extension.withheld_amount = mint_transfer_fee.withheld_amount; @@ -1123,7 +1147,7 @@ mod test { extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee; let close_authority = OptionalNonZeroPubkey::try_from(Some(Pubkey::new(&[1; 32]))).unwrap(); - let extension = state.init_extension::().unwrap(); + let extension = state.init_extension::(true).unwrap(); extension.close_authority = close_authority; assert_eq!( @@ -1169,7 +1193,7 @@ mod test { state.init_account_type().unwrap(); // write padding - let extension = state.init_extension::().unwrap(); + let extension = state.init_extension::(true).unwrap(); extension.padding1 = [1; 128]; extension.padding2 = [1; 48]; extension.padding3 = [1; 9]; @@ -1206,12 +1230,12 @@ mod test { StateWithExtensionsMut::::unpack_uninitialized(&mut buffer).unwrap(); // fail init mint extension assert_eq!( - state.init_extension::(), + state.init_extension::(true), Err(ProgramError::InvalidAccountData), ); // success write extension let withheld_amount = PodU64::from(u64::MAX); - let extension = state.init_extension::().unwrap(); + let extension = state.init_extension::(true).unwrap(); extension.withheld_amount = withheld_amount; assert_eq!( @@ -1305,7 +1329,7 @@ mod test { state.init_account_type().unwrap(); // write padding - let extension = state.init_extension::().unwrap(); + let extension = state.init_extension::(true).unwrap(); extension.0.padding1 = [2; 128]; extension.0.padding2 = [2; 48]; extension.0.padding3 = [2; 9]; @@ -1341,7 +1365,7 @@ mod test { let mut state = StateWithExtensionsMut::::unpack(&mut buffer).unwrap(); assert_eq!(state.base, TEST_ACCOUNT); assert_eq!(state.account_type[0], AccountType::Account as u8); - state.init_extension::().unwrap(); // just confirming initialization works + state.init_extension::(true).unwrap(); // just confirming initialization works // account with buffer big enough for AccountType only let mut buffer = TEST_ACCOUNT_SLICE.to_vec(); @@ -1384,7 +1408,7 @@ mod test { let mut state = StateWithExtensionsMut::::unpack(&mut buffer).unwrap(); assert_eq!(state.base, TEST_MINT); assert_eq!(state.account_type[0], AccountType::Mint as u8); - state.init_extension::().unwrap(); + state.init_extension::(true).unwrap(); // mint with buffer big enough for AccountType only let mut buffer = TEST_MINT_SLICE.to_vec(); @@ -1499,7 +1523,7 @@ mod test { // fail init extension assert_eq!( - state.init_extension::(), + state.init_extension::(true), Err(ProgramError::InvalidAccountData), ); @@ -1514,7 +1538,7 @@ mod test { state.base = TEST_MINT; state.pack_base(); state.init_account_type().unwrap(); - let extension = state.init_extension::().unwrap(); + let extension = state.init_extension::(true).unwrap(); assert_eq!(extension.padding1, [1; 128]); assert_eq!(extension.padding2, [2; 48]); assert_eq!(extension.padding3, [3; 9]); @@ -1526,7 +1550,9 @@ mod test { ExtensionType::get_account_len::(&[ExtensionType::MintCloseAuthority]); let mut buffer = vec![0; mint_size - 1]; let mut state = StateWithExtensionsMut::::unpack_uninitialized(&mut buffer).unwrap(); - let err = state.init_extension::().unwrap_err(); + let err = state + .init_extension::(true) + .unwrap_err(); assert_eq!(err, ProgramError::InvalidAccountData); state.tlv_data[0] = 3; @@ -1556,7 +1582,7 @@ mod test { state.base = TEST_ACCOUNT; state.pack_base(); state.init_account_type().unwrap(); - state.init_extension::().unwrap(); + state.init_extension::(true).unwrap(); assert_eq!( get_first_extension_type(state.tlv_data).unwrap(), diff --git a/token/program-2022/src/extension/transfer_fee/processor.rs b/token/program-2022/src/extension/transfer_fee/processor.rs index 60ef5944..5d18b22f 100644 --- a/token/program-2022/src/extension/transfer_fee/processor.rs +++ b/token/program-2022/src/extension/transfer_fee/processor.rs @@ -36,7 +36,7 @@ fn process_initialize_transfer_fee_config( let mut mint_data = mint_account_info.data.borrow_mut(); let mut mint = StateWithExtensionsMut::::unpack_uninitialized(&mut mint_data)?; - let extension = mint.init_extension::()?; + let extension = mint.init_extension::(true)?; extension.transfer_fee_config_authority = transfer_fee_config_authority.try_into()?; extension.withdraw_withheld_authority = withdraw_withheld_authority.try_into()?; extension.withheld_amount = 0u64.into(); diff --git a/token/program-2022/src/processor.rs b/token/program-2022/src/processor.rs index 642aeab5..da0deeef 100644 --- a/token/program-2022/src/processor.rs +++ b/token/program-2022/src/processor.rs @@ -1031,7 +1031,7 @@ impl Processor { let mut mint_data = mint_account_info.data.borrow_mut(); let mut mint = StateWithExtensionsMut::::unpack_uninitialized(&mut mint_data)?; - let extension = mint.init_extension::()?; + let extension = mint.init_extension::(true)?; extension.close_authority = close_authority.try_into()?; Ok(()) @@ -1063,7 +1063,9 @@ impl Processor { let token_account_data = &mut token_account_info.data.borrow_mut(); let mut token_account = StateWithExtensionsMut::::unpack_uninitialized(token_account_data)?; - token_account.init_extension::().map(|_| ()) + token_account + .init_extension::(true) + .map(|_| ()) } /// Processes an [AmountToUiAmount](enum.TokenInstruction.html) instruction @@ -1160,7 +1162,7 @@ impl Processor { let mut mint_data = mint_account_info.data.borrow_mut(); let mut mint = StateWithExtensionsMut::::unpack_uninitialized(&mut mint_data)?; - mint.init_extension::()?; + mint.init_extension::(true)?; Ok(()) }