token-cli: support cpi guard enable/disable

This commit is contained in:
hanako mumei 2022-10-28 13:31:22 -07:00 committed by hana
parent 9610bed534
commit 0bc2a02f00
1 changed files with 204 additions and 11 deletions

View File

@ -37,8 +37,9 @@ use solana_sdk::{
use spl_associated_token_account::get_associated_token_address_with_program_id;
use spl_token_2022::{
extension::{
interest_bearing_mint::InterestBearingConfig, memo_transfer::MemoTransfer,
mint_close_authority::MintCloseAuthority, ExtensionType, StateWithExtensionsOwned,
cpi_guard::CpiGuard, interest_bearing_mint::InterestBearingConfig,
memo_transfer::MemoTransfer, mint_close_authority::MintCloseAuthority, ExtensionType,
StateWithExtensionsOwned,
},
instruction::*,
state::{Account, Mint},
@ -131,6 +132,8 @@ pub enum CommandName {
SyncNative,
EnableRequiredTransferMemos,
DisableRequiredTransferMemos,
EnableCpiGuard,
DisableCpiGuard,
}
impl fmt::Display for CommandName {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
@ -1842,8 +1845,7 @@ async fn command_sync_native(config: &Config<'_>, native_account_address: Pubkey
})
}
// Both enable_required_transfer_mesos and disable_required_transfer_mesos
// Switches with enable_memos bool
// both enables and disables required transfer memos, via enable_memos bool
async fn command_required_transfer_memos(
config: &Config<'_>,
token_account_address: Pubkey,
@ -1852,7 +1854,7 @@ async fn command_required_transfer_memos(
enable_memos: bool,
) -> CommandResult {
if config.sign_only {
panic!("Config can not be sign only for enabling/disabling required transfer memos.");
panic!("Config can not be sign-only for enabling/disabling required transfer memos.");
}
let account = config.get_account_checked(&token_account_address).await?;
@ -1864,14 +1866,15 @@ async fn command_required_transfer_memos(
// Reallocation (if needed)
let mut existing_extensions: Vec<ExtensionType> = state_with_extension.get_extension_types()?;
if existing_extensions.contains(&ExtensionType::MemoTransfer) {
let extension_data: bool = state_with_extension
let extension_state = state_with_extension
.get_extension::<MemoTransfer>()?
.require_incoming_transfer_memos
.into();
if extension_data == enable_memos {
if extension_state == enable_memos {
return Ok(format!(
"Required memo transfer was already {}",
if extension_data {
"Required transfer memos were already {}",
if extension_state {
"enabled"
} else {
"disabled"
@ -1914,6 +1917,78 @@ async fn command_required_transfer_memos(
})
}
// both enables and disables cpi guard, via enable_guard bool
async fn command_cpi_guard(
config: &Config<'_>,
token_account_address: Pubkey,
owner: Pubkey,
bulk_signers: BulkSigners,
enable_guard: bool,
) -> CommandResult {
if config.sign_only {
panic!("Config can not be sign-only for enabling/disabling required transfer memos.");
}
let account = config.get_account_checked(&token_account_address).await?;
let current_account_len = account.data.len();
let state_with_extension = StateWithExtensionsOwned::<Account>::unpack(account.data)?;
let token = token_client_from_config(config, &state_with_extension.base.mint, None)?;
// reallocation (if needed)
let mut existing_extensions: Vec<ExtensionType> = state_with_extension.get_extension_types()?;
if existing_extensions.contains(&ExtensionType::CpiGuard) {
let extension_state = state_with_extension
.get_extension::<CpiGuard>()?
.lock_cpi
.into();
if extension_state == enable_guard {
return Ok(format!(
"CPI Guard was already {}",
if extension_state {
"enabled"
} else {
"disabled"
}
));
}
} else {
existing_extensions.push(ExtensionType::CpiGuard);
let required_account_len = ExtensionType::get_account_len::<Account>(&existing_extensions);
if required_account_len > current_account_len {
token
.reallocate(
&token_account_address,
&owner,
&[ExtensionType::CpiGuard],
&bulk_signers,
)
.await?;
}
}
let res = if enable_guard {
token
.enable_cpi_guard(&token_account_address, &owner, &bulk_signers)
.await
} else {
token
.disable_cpi_guard(&token_account_address, &owner, &bulk_signers)
.await
}?;
let tx_return = finish_tx(config, &res, false).await?;
Ok(match tx_return {
TransactionReturnData::CliSignature(signature) => {
config.output_format.formatted_string(&signature)
}
TransactionReturnData::CliSignOnlyData(sign_only_data) => {
config.output_format.formatted_string(&sign_only_data)
}
})
}
struct SignOnlyNeedsFullMintSpec {}
impl offline::ArgsConfig for SignOnlyNeedsFullMintSpec {
fn sign_only_arg<'a, 'b>(&self, arg: Arg<'a, 'b>) -> Arg<'a, 'b> {
@ -2902,7 +2977,7 @@ fn app<'a, 'b>(
.takes_value(true)
.index(1)
.required(true)
.help("The address of the token account to enable required transfer memos")
.help("The address of the token account to require transfer memos for")
)
.arg(
owner_address_arg()
@ -2920,7 +2995,43 @@ fn app<'a, 'b>(
.takes_value(true)
.index(1)
.required(true)
.help("The address of the token account to disable required transfer memos"),
.help("The address of the token account to stop requiring transfer memos for"),
)
.arg(
owner_address_arg()
)
.arg(multisig_signer_arg())
.nonce_args(true)
)
.subcommand(
SubCommand::with_name(CommandName::EnableCpiGuard.into())
.about("Enable CPI Guard for token account")
.arg(
Arg::with_name("account")
.validator(is_valid_pubkey)
.value_name("TOKEN_ACCOUNT_ADDRESS")
.takes_value(true)
.index(1)
.required(true)
.help("The address of the token account to enable CPI Guard for")
)
.arg(
owner_address_arg()
)
.arg(multisig_signer_arg())
.nonce_args(true)
)
.subcommand(
SubCommand::with_name(CommandName::DisableCpiGuard.into())
.about("Disable CPI Guard for token account")
.arg(
Arg::with_name("account")
.validator(is_valid_pubkey)
.value_name("TOKEN_ACCOUNT_ADDRESS")
.takes_value(true)
.index(1)
.required(true)
.help("The address of the token account to disable CPI Guard for"),
)
.arg(
owner_address_arg()
@ -3509,6 +3620,28 @@ async fn process_command<'a>(
config.pubkey_or_default(arg_matches, "account", &mut wallet_manager)?;
command_required_transfer_memos(config, token_account, owner, bulk_signers, false).await
}
(CommandName::EnableCpiGuard, arg_matches) => {
let (owner_signer, owner) =
config.signer_or_default(arg_matches, "owner", &mut wallet_manager);
if !bulk_signers.contains(&owner_signer) {
bulk_signers.push(owner_signer);
}
// Since account is required argument it will always be present
let token_account =
config.pubkey_or_default(arg_matches, "account", &mut wallet_manager)?;
command_cpi_guard(config, token_account, owner, bulk_signers, true).await
}
(CommandName::DisableCpiGuard, arg_matches) => {
let (owner_signer, owner) =
config.signer_or_default(arg_matches, "owner", &mut wallet_manager);
if !bulk_signers.contains(&owner_signer) {
bulk_signers.push(owner_signer);
}
// Since account is required argument it will always be present
let token_account =
config.pubkey_or_default(arg_matches, "account", &mut wallet_manager)?;
command_cpi_guard(config, token_account, owner, bulk_signers, false).await
}
}
}
@ -5094,6 +5227,66 @@ mod tests {
assert!(!enabled);
}
#[tokio::test]
#[serial]
async fn cpi_guard() {
let (test_validator, payer) = new_validator_for_test().await;
let program_id = spl_token_2022::id();
let config = test_config_with_default_signer(&test_validator, &payer, &program_id);
let token = create_token(&config, &payer).await;
let token_account = create_associated_account(&config, &payer, token).await;
// enable works
process_test_command(
&config,
&payer,
&[
"spl-token",
CommandName::EnableCpiGuard.into(),
&token_account.to_string(),
],
)
.await
.unwrap();
let extensions = StateWithExtensionsOwned::<Account>::unpack(
config
.rpc_client
.get_account(&token_account)
.await
.unwrap()
.data,
)
.unwrap();
let cpi_guard = extensions.get_extension::<CpiGuard>().unwrap();
let enabled: bool = cpi_guard.lock_cpi.into();
assert!(enabled);
// disable works
process_test_command(
&config,
&payer,
&[
"spl-token",
CommandName::DisableCpiGuard.into(),
&token_account.to_string(),
],
)
.await
.unwrap();
let extensions = StateWithExtensionsOwned::<Account>::unpack(
config
.rpc_client
.get_account(&token_account)
.await
.unwrap()
.data,
)
.unwrap();
let cpi_guard = extensions.get_extension::<CpiGuard>().unwrap();
let enabled: bool = cpi_guard.lock_cpi.into();
assert!(!enabled);
}
#[tokio::test]
#[serial]
async fn immutable_accounts() {