token-2022: Error sooner when searching for an extension in TLV data (#3761)

This commit is contained in:
Jon Cinque 2022-10-26 11:15:08 -04:00 committed by GitHub
parent 836c9e67a6
commit 0ff8fe1546
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 7 deletions

View File

@ -187,6 +187,9 @@ pub enum TokenError {
/// Account ownership cannot be changed while CPI Guard is enabled
#[error("Account ownership cannot be changed while CPI Guard is enabled")]
CpiGuardOwnerChangeBlocked,
/// Extension not found in account data
#[error("Extension not found in account data")]
ExtensionNotFound,
}
impl From<TokenError> for ProgramError {
fn from(e: TokenError) -> Self {
@ -323,6 +326,9 @@ impl PrintProgramError for TokenError {
TokenError::CpiGuardOwnerChangeBlocked => {
msg!("Account ownership cannot be changed while CPI Guard is enabled")
}
TokenError::ExtensionNotFound => {
msg!("Extension not found in account data")
}
}
}
}

View File

@ -104,16 +104,17 @@ fn get_extension_indices<V: Extension>(
let extension_type =
ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
let account_type = extension_type.get_account_type();
// got to an empty spot, can init here, or move forward if not initing
if extension_type == ExtensionType::Uninitialized {
if extension_type == V::TYPE {
// found an instance of the extension that we're initializing, return!
return Ok(tlv_indices);
// got to an empty spot, init here, or error if we're searching, since
// nothing is written after an Uninitialized spot
} else if extension_type == ExtensionType::Uninitialized {
if init {
return Ok(tlv_indices);
} else {
start_index = tlv_indices.length_start;
return Err(TokenError::ExtensionNotFound.into());
}
} else if extension_type == V::TYPE {
// found an instance of the extension that we're initializing, return!
return Ok(tlv_indices);
} else if v_account_type != account_type {
return Err(TokenError::ExtensionTypeMismatch.into());
} else {
@ -1603,8 +1604,14 @@ mod test {
state.base = TEST_ACCOUNT;
state.pack_base();
state.init_account_type().unwrap();
state.init_extension::<ImmutableOwner>(true).unwrap();
let err = state.get_extension::<ImmutableOwner>().unwrap_err();
assert_eq!(
err,
ProgramError::Custom(TokenError::ExtensionNotFound as u32)
);
state.init_extension::<ImmutableOwner>(true).unwrap();
assert_eq!(
get_first_extension_type(state.tlv_data).unwrap(),
Some(ExtensionType::ImmutableOwner)