From 0ff8fe1546ea889b95c50ea3bdf50dc331d45b67 Mon Sep 17 00:00:00 2001 From: Jon Cinque Date: Wed, 26 Oct 2022 11:15:08 -0400 Subject: [PATCH] token-2022: Error sooner when searching for an extension in TLV data (#3761) --- token/program-2022/src/error.rs | 6 ++++++ token/program-2022/src/extension/mod.rs | 21 ++++++++++++++------- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/token/program-2022/src/error.rs b/token/program-2022/src/error.rs index 041b50eb..e10b5b20 100644 --- a/token/program-2022/src/error.rs +++ b/token/program-2022/src/error.rs @@ -187,6 +187,9 @@ pub enum TokenError { /// Account ownership cannot be changed while CPI Guard is enabled #[error("Account ownership cannot be changed while CPI Guard is enabled")] CpiGuardOwnerChangeBlocked, + /// Extension not found in account data + #[error("Extension not found in account data")] + ExtensionNotFound, } impl From for ProgramError { fn from(e: TokenError) -> Self { @@ -323,6 +326,9 @@ impl PrintProgramError for TokenError { TokenError::CpiGuardOwnerChangeBlocked => { msg!("Account ownership cannot be changed while CPI Guard is enabled") } + TokenError::ExtensionNotFound => { + msg!("Extension not found in account data") + } } } } diff --git a/token/program-2022/src/extension/mod.rs b/token/program-2022/src/extension/mod.rs index 3e83aae5..82c8d332 100644 --- a/token/program-2022/src/extension/mod.rs +++ b/token/program-2022/src/extension/mod.rs @@ -104,16 +104,17 @@ fn get_extension_indices( let extension_type = ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?; let account_type = extension_type.get_account_type(); - // got to an empty spot, can init here, or move forward if not initing - if extension_type == ExtensionType::Uninitialized { + if extension_type == V::TYPE { + // found an instance of the extension that we're initializing, return! + return Ok(tlv_indices); + // got to an empty spot, init here, or error if we're searching, since + // nothing is written after an Uninitialized spot + } else if extension_type == ExtensionType::Uninitialized { if init { return Ok(tlv_indices); } else { - start_index = tlv_indices.length_start; + return Err(TokenError::ExtensionNotFound.into()); } - } else if extension_type == V::TYPE { - // found an instance of the extension that we're initializing, return! - return Ok(tlv_indices); } else if v_account_type != account_type { return Err(TokenError::ExtensionTypeMismatch.into()); } else { @@ -1603,8 +1604,14 @@ mod test { state.base = TEST_ACCOUNT; state.pack_base(); state.init_account_type().unwrap(); - state.init_extension::(true).unwrap(); + let err = state.get_extension::().unwrap_err(); + assert_eq!( + err, + ProgramError::Custom(TokenError::ExtensionNotFound as u32) + ); + + state.init_extension::(true).unwrap(); assert_eq!( get_first_extension_type(state.tlv_data).unwrap(), Some(ExtensionType::ImmutableOwner)