Add extension realloc helper (#2821)

This commit is contained in:
Tyera Eulberg 2022-01-28 15:41:22 -07:00 committed by GitHub
parent eaaed0d3c4
commit f9e6f66758
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 83 additions and 0 deletions

View File

@ -491,6 +491,24 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
fn get_first_extension_type(&self) -> Result<Option<ExtensionType>, ProgramError> {
get_first_extension_type(self.tlv_data)
}
/// Compares the length of an extension with the currently used TLV buffer to determine if
/// reallocation is needed. If so, returns Some(v), where v is the difference between current
/// space and needed.
#[allow(dead_code)]
pub(crate) fn realloc_needed(
&self,
new_extension: ExtensionType,
) -> Result<Option<usize>, ProgramError> {
let current_extensions = self.get_extension_types()?;
let needed_tlv_len = ExtensionType::get_total_tlv_len(&current_extensions);
let new_needed_tlv_len = needed_tlv_len.saturating_add(new_extension.get_type_len());
if self.tlv_data.len() >= new_needed_tlv_len {
Ok(None)
} else {
Ok(Some(new_needed_tlv_len - self.tlv_data.len())) // arithmetic safe because of if clause
}
}
}
/// Different kinds of accounts. Note that `Mint`, `Account`, and `Multisig` types
@ -1344,4 +1362,69 @@ mod test {
assert_eq!(state.get_extension_types().unwrap(), vec![]);
}
#[test]
fn test_realloc_needed() {
// buffer exact size of existing extension
let mint_size = ExtensionType::get_account_len::<Mint>(&[ExtensionType::TransferFeeConfig]);
let mut buffer = vec![0; mint_size];
let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
state.base = TEST_MINT;
state.pack_base();
state.init_account_type().unwrap();
assert_eq!(
state
.realloc_needed(ExtensionType::TransferFeeConfig)
.unwrap(),
None
);
state.init_extension::<TransferFeeConfig>().unwrap();
assert_eq!(
state
.realloc_needed(ExtensionType::MintCloseAuthority)
.unwrap(),
Some(ExtensionType::MintCloseAuthority.get_type_len())
);
// buffer with multisig len
let mint_size = ExtensionType::get_account_len::<Mint>(&[ExtensionType::MintPaddingTest]);
let mut buffer = vec![0; mint_size];
let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
state.base = TEST_MINT;
state.pack_base();
state.init_account_type().unwrap();
assert_eq!(
state
.realloc_needed(ExtensionType::MintPaddingTest)
.unwrap(),
None
);
state.init_extension::<MintPaddingTest>().unwrap();
assert_eq!(
state
.realloc_needed(ExtensionType::MintCloseAuthority)
.unwrap(),
Some(ExtensionType::MintCloseAuthority.get_type_len() - size_of::<ExtensionType>())
);
// huge buffer
let mut buffer = vec![0; u16::MAX.into()];
let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
state.base = TEST_MINT;
state.pack_base();
state.init_account_type().unwrap();
assert_eq!(
state
.realloc_needed(ExtensionType::TransferFeeConfig)
.unwrap(),
None
);
state.init_extension::<TransferFeeConfig>().unwrap();
assert_eq!(
state
.realloc_needed(ExtensionType::MintCloseAuthority)
.unwrap(),
None
);
}
}