token-2022: Error sooner when searching for an extension in TLV data (#3761)

This commit is contained in:
Jon Cinque 2022-10-26 11:15:08 -04:00 committed by GitHub
parent 836c9e67a6
commit 0ff8fe1546
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 7 deletions

View File

@ -187,6 +187,9 @@ pub enum TokenError {
/// Account ownership cannot be changed while CPI Guard is enabled /// Account ownership cannot be changed while CPI Guard is enabled
#[error("Account ownership cannot be changed while CPI Guard is enabled")] #[error("Account ownership cannot be changed while CPI Guard is enabled")]
CpiGuardOwnerChangeBlocked, CpiGuardOwnerChangeBlocked,
/// Extension not found in account data
#[error("Extension not found in account data")]
ExtensionNotFound,
} }
impl From<TokenError> for ProgramError { impl From<TokenError> for ProgramError {
fn from(e: TokenError) -> Self { fn from(e: TokenError) -> Self {
@ -323,6 +326,9 @@ impl PrintProgramError for TokenError {
TokenError::CpiGuardOwnerChangeBlocked => { TokenError::CpiGuardOwnerChangeBlocked => {
msg!("Account ownership cannot be changed while CPI Guard is enabled") msg!("Account ownership cannot be changed while CPI Guard is enabled")
} }
TokenError::ExtensionNotFound => {
msg!("Extension not found in account data")
}
} }
} }
} }

View File

@ -104,16 +104,17 @@ fn get_extension_indices<V: Extension>(
let extension_type = let extension_type =
ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?; ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
let account_type = extension_type.get_account_type(); let account_type = extension_type.get_account_type();
// got to an empty spot, can init here, or move forward if not initing if extension_type == V::TYPE {
if extension_type == ExtensionType::Uninitialized { // found an instance of the extension that we're initializing, return!
return Ok(tlv_indices);
// got to an empty spot, init here, or error if we're searching, since
// nothing is written after an Uninitialized spot
} else if extension_type == ExtensionType::Uninitialized {
if init { if init {
return Ok(tlv_indices); return Ok(tlv_indices);
} else { } else {
start_index = tlv_indices.length_start; return Err(TokenError::ExtensionNotFound.into());
} }
} else if extension_type == V::TYPE {
// found an instance of the extension that we're initializing, return!
return Ok(tlv_indices);
} else if v_account_type != account_type { } else if v_account_type != account_type {
return Err(TokenError::ExtensionTypeMismatch.into()); return Err(TokenError::ExtensionTypeMismatch.into());
} else { } else {
@ -1603,8 +1604,14 @@ mod test {
state.base = TEST_ACCOUNT; state.base = TEST_ACCOUNT;
state.pack_base(); state.pack_base();
state.init_account_type().unwrap(); state.init_account_type().unwrap();
state.init_extension::<ImmutableOwner>(true).unwrap();
let err = state.get_extension::<ImmutableOwner>().unwrap_err();
assert_eq!(
err,
ProgramError::Custom(TokenError::ExtensionNotFound as u32)
);
state.init_extension::<ImmutableOwner>(true).unwrap();
assert_eq!( assert_eq!(
get_first_extension_type(state.tlv_data).unwrap(), get_first_extension_type(state.tlv_data).unwrap(),
Some(ExtensionType::ImmutableOwner) Some(ExtensionType::ImmutableOwner)