token-2022: prevent an already configured confidential account to be configured again (#3216)

* token-2022: prevent an already configured confidential account to be configured again

* token-2022: add overwrite flag to init extension

* token-2022: clippy

* token-2022: update initialize mint for interest bearing mint

* token-2022: confidential transfer mint init allow overwrite
This commit is contained in:
samkim-crypto 2022-06-07 09:42:23 +09:00 committed by GitHub
parent 08d0592bea
commit c2a3ecd970
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 112 additions and 62 deletions

View File

@ -10,6 +10,7 @@ use {
signer::keypair::Keypair, transaction::TransactionError, transport::TransportError,
},
spl_token_2022::{
error::TokenError,
extension::{
confidential_transfer::{
ConfidentialTransferAccount, ConfidentialTransferMint, EncryptedWithheldAmount,
@ -331,6 +332,27 @@ async fn ct_configure_token_account() {
.get_extension::<ConfidentialTransferAccount>()
.unwrap();
assert!(bool::from(&extension.approved));
// Configuring an already initialized account should produce an error
let err = token
.confidential_transfer_configure_token_account(
&alice_meta.token_account,
&alice,
alice_meta.elgamal_keypair.public,
alice_meta.ae_key.encrypt(0_u64),
)
.await
.unwrap_err();
assert_eq!(
err,
TokenClientError::Client(Box::new(TransportError::TransactionError(
TransactionError::InstructionError(
0,
InstructionError::Custom(TokenError::ExtensionAlreadyInitialized as u32),
)
)))
);
}
#[tokio::test]

View File

@ -54,7 +54,7 @@ fn process_initialize_mint(
check_program_account(mint_info.owner)?;
let mint_data = &mut mint_info.data.borrow_mut();
let mut mint = StateWithExtensionsMut::<Mint>::unpack_uninitialized(mint_data)?;
*mint.init_extension::<ConfidentialTransferMint>()? = *confidential_transfer_mint;
*mint.init_extension::<ConfidentialTransferMint>(true)? = *confidential_transfer_mint;
Ok(())
}
@ -125,7 +125,7 @@ fn process_configure_account(
// Note: The caller is expected to use the `Reallocate` instruction to ensure there is
// sufficient room in their token account for the new `ConfidentialTransferAccount` extension
let mut confidential_transfer_account =
token_account.init_extension::<ConfidentialTransferAccount>()?;
token_account.init_extension::<ConfidentialTransferAccount>(false)?;
confidential_transfer_account.approved = confidential_transfer_mint.auto_approve_new_accounts;
confidential_transfer_account.encryption_pubkey = *encryption_pubkey;

View File

@ -36,7 +36,7 @@ fn process_initialize_default_account_state(
let mint_account_info = next_account_info(account_info_iter)?;
let mut mint_data = mint_account_info.data.borrow_mut();
let mut mint = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut mint_data)?;
let extension = mint.init_extension::<DefaultAccountState>()?;
let extension = mint.init_extension::<DefaultAccountState>(true)?;
extension.state = state.into();
Ok(())
}

View File

@ -36,7 +36,7 @@ fn process_initialize(
let mut mint = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut mint_data)?;
let clock = Clock::get()?;
let extension = mint.init_extension::<InterestBearingConfig>()?;
let extension = mint.init_extension::<InterestBearingConfig>(true)?;
extension.rate_authority = *rate_authority;
extension.initialization_timestamp = clock.unix_timestamp.into();
extension.last_update_timestamp = clock.unix_timestamp.into();

View File

@ -40,7 +40,7 @@ fn process_enable_required_memo_transfers(
let extension = if let Ok(extension) = account.get_extension_mut::<MemoTransfer>() {
extension
} else {
account.init_extension::<MemoTransfer>()?
account.init_extension::<MemoTransfer>(true)?
};
extension.require_incoming_transfer_memos = true.into();
Ok(())
@ -69,7 +69,7 @@ fn process_diasble_required_memo_transfers(
let extension = if let Ok(extension) = account.get_extension_mut::<MemoTransfer>() {
extension
} else {
account.init_extension::<MemoTransfer>()?
account.init_extension::<MemoTransfer>(true)?
};
extension.require_incoming_transfer_memos = false.into();
Ok(())

View File

@ -413,7 +413,8 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
}
}
fn init_or_get_extension<V: Extension>(&mut self, init: bool) -> Result<&mut V, ProgramError> {
/// Unpack a portion of the TLV data as the desired type that allows modifying the type
pub fn get_extension_mut<V: Extension>(&mut self) -> Result<&mut V, ProgramError> {
if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
return Err(ProgramError::InvalidAccountData);
}
@ -421,39 +422,15 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
type_start,
length_start,
value_start,
} = get_extension_indices::<V>(self.tlv_data, init)?;
} = get_extension_indices::<V>(self.tlv_data, false)?;
if self.tlv_data[type_start..].len() < V::TYPE.get_tlv_len() {
return Err(ProgramError::InvalidAccountData);
}
if init {
// write extension type
let extension_type_array: [u8; 2] = V::TYPE.into();
let extension_type_ref = &mut self.tlv_data[type_start..length_start];
extension_type_ref.copy_from_slice(&extension_type_array);
// write length
let length_ref =
pod_from_bytes_mut::<Length>(&mut self.tlv_data[length_start..value_start])?;
// maybe this becomes smarter later for dynamically sized extensions
let length = pod_get_packed_len::<V>();
*length_ref = Length::try_from(length).unwrap();
let value_end = value_start.saturating_add(length);
let extension_ref =
pod_from_bytes_mut::<V>(&mut self.tlv_data[value_start..value_end])?;
*extension_ref = V::default();
Ok(extension_ref)
} else {
let length = pod_from_bytes::<Length>(&self.tlv_data[length_start..value_start])?;
let value_end = value_start.saturating_add(usize::from(*length));
pod_from_bytes_mut::<V>(&mut self.tlv_data[value_start..value_end])
}
}
/// Unpack a portion of the TLV data as the desired type that allows modifying the type
pub fn get_extension_mut<V: Extension>(&mut self) -> Result<&mut V, ProgramError> {
self.init_or_get_extension(false)
}
/// Unpack a portion of the TLV data as the desired type
pub fn get_extension<V: Extension>(&self) -> Result<&V, ProgramError> {
@ -480,9 +457,48 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
}
/// Packs the default extension data into an open slot if not already found in the
/// data buffer, otherwise overwrites the existing extension with the default state
pub fn init_extension<V: Extension>(&mut self) -> Result<&mut V, ProgramError> {
self.init_or_get_extension(true)
/// data buffer. If extension is already found in the buffer, it overwrites the existing
/// extension with the default state if `overwrite` is set. If extension found, but
/// `overwrite` is not set, it returns error.
pub fn init_extension<V: Extension>(
&mut self,
overwrite: bool,
) -> Result<&mut V, ProgramError> {
if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
return Err(ProgramError::InvalidAccountData);
}
let TlvIndices {
type_start,
length_start,
value_start,
} = get_extension_indices::<V>(self.tlv_data, true)?;
if self.tlv_data[type_start..].len() < V::TYPE.get_tlv_len() {
return Err(ProgramError::InvalidAccountData);
}
let extension_type = ExtensionType::try_from(&self.tlv_data[type_start..length_start])?;
if extension_type == ExtensionType::Uninitialized || overwrite {
// write extension type
let extension_type_array: [u8; 2] = V::TYPE.into();
let extension_type_ref = &mut self.tlv_data[type_start..length_start];
extension_type_ref.copy_from_slice(&extension_type_array);
// write length
let length_ref =
pod_from_bytes_mut::<Length>(&mut self.tlv_data[length_start..value_start])?;
// maybe this becomes smarter later for dynamically sized extensions
let length = pod_get_packed_len::<V>();
*length_ref = Length::try_from(length).unwrap();
let value_end = value_start.saturating_add(length);
let extension_ref =
pod_from_bytes_mut::<V>(&mut self.tlv_data[value_start..value_end])?;
*extension_ref = V::default();
Ok(extension_ref)
} else {
// extension is already initialized, but no overwrite permission
Err(TokenError::ExtensionAlreadyInitialized.into())
}
}
/// If `extension_type` is an Account-associated ExtensionType that requires initialization on
@ -498,14 +514,14 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
}
match extension_type {
ExtensionType::TransferFeeAmount => {
self.init_extension::<TransferFeeAmount>().map(|_| ())
self.init_extension::<TransferFeeAmount>(true).map(|_| ())
}
// ConfidentialTransfers are currently opt-in only, so this is a no-op for extra safety
// on InitializeAccount
ExtensionType::ConfidentialTransferAccount => Ok(()),
#[cfg(test)]
ExtensionType::AccountPaddingTest => {
self.init_extension::<AccountPaddingTest>().map(|_| ())
self.init_extension::<AccountPaddingTest>(true).map(|_| ())
}
_ => unreachable!(),
}
@ -932,19 +948,27 @@ mod test {
let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
// fail init account extension
assert_eq!(
state.init_extension::<TransferFeeAmount>(),
state.init_extension::<TransferFeeAmount>(true),
Err(ProgramError::InvalidAccountData),
);
// success write extension
let close_authority = OptionalNonZeroPubkey::try_from(Some(Pubkey::new(&[1; 32]))).unwrap();
let extension = state.init_extension::<MintCloseAuthority>().unwrap();
let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
extension.close_authority = close_authority;
assert_eq!(
&state.get_extension_types().unwrap(),
&[ExtensionType::MintCloseAuthority]
);
// fail init extension when already initialized
assert_eq!(
state.init_extension::<MintCloseAuthority>(false),
Err(ProgramError::Custom(
TokenError::ExtensionAlreadyInitialized as u32
))
);
// fail unpack as account, a mint extension was written
assert_eq!(
StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer),
@ -1030,7 +1054,7 @@ mod test {
let mut state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
// init one more extension
let mint_transfer_fee = test_transfer_fee_config();
let new_extension = state.init_extension::<TransferFeeConfig>().unwrap();
let new_extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
new_extension.transfer_fee_config_authority =
mint_transfer_fee.transfer_fee_config_authority;
new_extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
@ -1063,7 +1087,7 @@ mod test {
// fail to init one more extension that does not fit
let mut state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
assert_eq!(
state.init_extension::<MintPaddingTest>(),
state.init_extension::<MintPaddingTest>(true),
Err(ProgramError::InvalidAccountData),
);
}
@ -1079,11 +1103,11 @@ mod test {
let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
// write extensions
let close_authority = OptionalNonZeroPubkey::try_from(Some(Pubkey::new(&[1; 32]))).unwrap();
let extension = state.init_extension::<MintCloseAuthority>().unwrap();
let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
extension.close_authority = close_authority;
let mint_transfer_fee = test_transfer_fee_config();
let extension = state.init_extension::<TransferFeeConfig>().unwrap();
let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
extension.withheld_amount = mint_transfer_fee.withheld_amount;
@ -1115,7 +1139,7 @@ mod test {
// write extensions in a different order
let mint_transfer_fee = test_transfer_fee_config();
let extension = state.init_extension::<TransferFeeConfig>().unwrap();
let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
extension.withheld_amount = mint_transfer_fee.withheld_amount;
@ -1123,7 +1147,7 @@ mod test {
extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
let close_authority = OptionalNonZeroPubkey::try_from(Some(Pubkey::new(&[1; 32]))).unwrap();
let extension = state.init_extension::<MintCloseAuthority>().unwrap();
let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
extension.close_authority = close_authority;
assert_eq!(
@ -1169,7 +1193,7 @@ mod test {
state.init_account_type().unwrap();
// write padding
let extension = state.init_extension::<MintPaddingTest>().unwrap();
let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
extension.padding1 = [1; 128];
extension.padding2 = [1; 48];
extension.padding3 = [1; 9];
@ -1206,12 +1230,12 @@ mod test {
StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer).unwrap();
// fail init mint extension
assert_eq!(
state.init_extension::<TransferFeeConfig>(),
state.init_extension::<TransferFeeConfig>(true),
Err(ProgramError::InvalidAccountData),
);
// success write extension
let withheld_amount = PodU64::from(u64::MAX);
let extension = state.init_extension::<TransferFeeAmount>().unwrap();
let extension = state.init_extension::<TransferFeeAmount>(true).unwrap();
extension.withheld_amount = withheld_amount;
assert_eq!(
@ -1305,7 +1329,7 @@ mod test {
state.init_account_type().unwrap();
// write padding
let extension = state.init_extension::<AccountPaddingTest>().unwrap();
let extension = state.init_extension::<AccountPaddingTest>(true).unwrap();
extension.0.padding1 = [2; 128];
extension.0.padding2 = [2; 48];
extension.0.padding3 = [2; 9];
@ -1341,7 +1365,7 @@ mod test {
let mut state = StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap();
assert_eq!(state.base, TEST_ACCOUNT);
assert_eq!(state.account_type[0], AccountType::Account as u8);
state.init_extension::<ImmutableOwner>().unwrap(); // just confirming initialization works
state.init_extension::<ImmutableOwner>(true).unwrap(); // just confirming initialization works
// account with buffer big enough for AccountType only
let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
@ -1384,7 +1408,7 @@ mod test {
let mut state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
assert_eq!(state.base, TEST_MINT);
assert_eq!(state.account_type[0], AccountType::Mint as u8);
state.init_extension::<MintCloseAuthority>().unwrap();
state.init_extension::<MintCloseAuthority>(true).unwrap();
// mint with buffer big enough for AccountType only
let mut buffer = TEST_MINT_SLICE.to_vec();
@ -1499,7 +1523,7 @@ mod test {
// fail init extension
assert_eq!(
state.init_extension::<TransferFeeConfig>(),
state.init_extension::<TransferFeeConfig>(true),
Err(ProgramError::InvalidAccountData),
);
@ -1514,7 +1538,7 @@ mod test {
state.base = TEST_MINT;
state.pack_base();
state.init_account_type().unwrap();
let extension = state.init_extension::<MintPaddingTest>().unwrap();
let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
assert_eq!(extension.padding1, [1; 128]);
assert_eq!(extension.padding2, [2; 48]);
assert_eq!(extension.padding3, [3; 9]);
@ -1526,7 +1550,9 @@ mod test {
ExtensionType::get_account_len::<Mint>(&[ExtensionType::MintCloseAuthority]);
let mut buffer = vec![0; mint_size - 1];
let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
let err = state.init_extension::<MintCloseAuthority>().unwrap_err();
let err = state
.init_extension::<MintCloseAuthority>(true)
.unwrap_err();
assert_eq!(err, ProgramError::InvalidAccountData);
state.tlv_data[0] = 3;
@ -1556,7 +1582,7 @@ mod test {
state.base = TEST_ACCOUNT;
state.pack_base();
state.init_account_type().unwrap();
state.init_extension::<ImmutableOwner>().unwrap();
state.init_extension::<ImmutableOwner>(true).unwrap();
assert_eq!(
get_first_extension_type(state.tlv_data).unwrap(),

View File

@ -36,7 +36,7 @@ fn process_initialize_transfer_fee_config(
let mut mint_data = mint_account_info.data.borrow_mut();
let mut mint = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut mint_data)?;
let extension = mint.init_extension::<TransferFeeConfig>()?;
let extension = mint.init_extension::<TransferFeeConfig>(true)?;
extension.transfer_fee_config_authority = transfer_fee_config_authority.try_into()?;
extension.withdraw_withheld_authority = withdraw_withheld_authority.try_into()?;
extension.withheld_amount = 0u64.into();

View File

@ -1031,7 +1031,7 @@ impl Processor {
let mut mint_data = mint_account_info.data.borrow_mut();
let mut mint = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut mint_data)?;
let extension = mint.init_extension::<MintCloseAuthority>()?;
let extension = mint.init_extension::<MintCloseAuthority>(true)?;
extension.close_authority = close_authority.try_into()?;
Ok(())
@ -1063,7 +1063,9 @@ impl Processor {
let token_account_data = &mut token_account_info.data.borrow_mut();
let mut token_account =
StateWithExtensionsMut::<Account>::unpack_uninitialized(token_account_data)?;
token_account.init_extension::<ImmutableOwner>().map(|_| ())
token_account
.init_extension::<ImmutableOwner>(true)
.map(|_| ())
}
/// Processes an [AmountToUiAmount](enum.TokenInstruction.html) instruction
@ -1160,7 +1162,7 @@ impl Processor {
let mut mint_data = mint_account_info.data.borrow_mut();
let mut mint = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut mint_data)?;
mint.init_extension::<NonTransferable>()?;
mint.init_extension::<NonTransferable>(true)?;
Ok(())
}