lang: Add init_with_needed keyword (#906)

This commit is contained in:
Armani Ferrante 2021-10-21 18:05:16 -05:00 committed by GitHub
parent d41fb4feb5
commit 95bb9b3183
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 147 additions and 68 deletions

View File

@ -18,6 +18,7 @@ incremented for features.
* ts: `Program<T>` can now be typed with an IDL type ([#795](https://github.com/project-serum/anchor/pull/795)).
* lang: Add `mint::freeze_authority` keyword for mint initialization within `#[derive(Accounts)]` ([#835](https://github.com/project-serum/anchor/pull/835)).
* lang: Add `AccountLoader` type for `zero_copy` accounts with support for CPI ([#792](https://github.com/project-serum/anchor/pull/792)).
* lang: Add `#[account(init_if_needed)]` keyword for allowing one to invoke the same instruction even if the account was created already ([#906](https://github.com/project-serum/anchor/pull/906)).
* lang: Add custom errors support for raw constraints ([#905](https://github.com/project-serum/anchor/pull/905)).
### Breaking

View File

@ -316,7 +316,7 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma
}
}
};
generate_init(f, seeds_with_nonce, payer, &c.space, &c.kind)
generate_init(f, c.if_needed, seeds_with_nonce, payer, &c.space, &c.kind)
}
fn generate_constraint_seeds(f: &Field, c: &ConstraintSeedsGroup) -> proc_macro2::TokenStream {
@ -397,8 +397,10 @@ fn generate_constraint_associated_token(
}
}
// `if_needed` is set if account allocation and initialization is optional.
pub fn generate_init(
f: &Field,
if_needed: bool,
seeds_with_nonce: proc_macro2::TokenStream,
payer: proc_macro2::TokenStream,
space: &Option<Expr>,
@ -407,6 +409,11 @@ pub fn generate_init(
let field = &f.ident;
let ty_decl = f.ty_decl();
let from_account_info = f.from_account_info_unchecked(Some(kind));
let if_needed = if if_needed {
quote! {true}
} else {
quote! {false}
};
match kind {
InitKind::Token { owner, mint } => {
let create_account = generate_create_account(
@ -417,22 +424,25 @@ pub fn generate_init(
);
quote! {
let #field: #ty_decl = {
// Define payer variable.
#payer
if !#if_needed || #field.to_account_info().owner == &anchor_lang::solana_program::system_program::ID {
// Define payer variable.
#payer
// Create the account with the system program.
#create_account
// Create the account with the system program.
#create_account
// Initialize the token account.
let cpi_program = token_program.to_account_info();
let accounts = anchor_spl::token::InitializeAccount {
account: #field.to_account_info(),
mint: #mint.to_account_info(),
authority: #owner.to_account_info(),
rent: rent.to_account_info(),
};
let cpi_ctx = CpiContext::new(cpi_program, accounts);
anchor_spl::token::initialize_account(cpi_ctx)?;
}
// Initialize the token account.
let cpi_program = token_program.to_account_info();
let accounts = anchor_spl::token::InitializeAccount {
account: #field.to_account_info(),
mint: #mint.to_account_info(),
authority: #owner.to_account_info(),
rent: rent.to_account_info(),
};
let cpi_ctx = CpiContext::new(cpi_program, accounts);
anchor_spl::token::initialize_account(cpi_ctx)?;
let pa: #ty_decl = #from_account_info;
pa
};
@ -441,20 +451,22 @@ pub fn generate_init(
InitKind::AssociatedToken { owner, mint } => {
quote! {
let #field: #ty_decl = {
#payer
if !#if_needed || #field.to_account_info().owner == &anchor_lang::solana_program::system_program::ID {
#payer
let cpi_program = associated_token_program.to_account_info();
let cpi_accounts = anchor_spl::associated_token::Create {
payer: payer.to_account_info(),
associated_token: #field.to_account_info(),
authority: #owner.to_account_info(),
mint: #mint.to_account_info(),
system_program: system_program.to_account_info(),
token_program: token_program.to_account_info(),
rent: rent.to_account_info(),
};
let cpi_ctx = CpiContext::new(cpi_program, cpi_accounts);
anchor_spl::associated_token::create(cpi_ctx)?;
let cpi_program = associated_token_program.to_account_info();
let cpi_accounts = anchor_spl::associated_token::Create {
payer: payer.to_account_info(),
associated_token: #field.to_account_info(),
authority: #owner.to_account_info(),
mint: #mint.to_account_info(),
system_program: system_program.to_account_info(),
token_program: token_program.to_account_info(),
rent: rent.to_account_info(),
};
let cpi_ctx = CpiContext::new(cpi_program, cpi_accounts);
anchor_spl::associated_token::create(cpi_ctx)?;
}
let pa: #ty_decl = #from_account_info;
pa
};
@ -477,20 +489,22 @@ pub fn generate_init(
};
quote! {
let #field: #ty_decl = {
// Define payer variable.
#payer
if !#if_needed || #field.to_account_info().owner == &anchor_lang::solana_program::system_program::ID {
// Define payer variable.
#payer
// Create the account with the system program.
#create_account
// Create the account with the system program.
#create_account
// Initialize the mint account.
let cpi_program = token_program.to_account_info();
let accounts = anchor_spl::token::InitializeMint {
mint: #field.to_account_info(),
rent: rent.to_account_info(),
};
let cpi_ctx = CpiContext::new(cpi_program, accounts);
anchor_spl::token::initialize_mint(cpi_ctx, #decimals, &#owner.to_account_info().key, #freeze_authority)?;
// Initialize the mint account.
let cpi_program = token_program.to_account_info();
let accounts = anchor_spl::token::InitializeMint {
mint: #field.to_account_info(),
rent: rent.to_account_info(),
};
let cpi_ctx = CpiContext::new(cpi_program, accounts);
anchor_spl::token::initialize_mint(cpi_ctx, #decimals, &#owner.to_account_info().key, #freeze_authority)?;
}
let pa: #ty_decl = #from_account_info;
pa
};
@ -535,9 +549,11 @@ pub fn generate_init(
generate_create_account(field, quote! {space}, owner, seeds_with_nonce);
quote! {
let #field = {
#space
#payer
#create_account
if !#if_needed || #field.to_account_info().owner == &anchor_lang::solana_program::system_program::ID {
#space
#payer
#create_account
}
let pa: #ty_decl = #from_account_info;
pa
};

View File

@ -594,7 +594,12 @@ impl Parse for ConstraintToken {
}
#[derive(Debug, Clone)]
pub struct ConstraintInit {}
pub struct ConstraintInit {
pub if_needed: bool,
}
#[derive(Debug, Clone)]
pub struct ConstraintInitIfNeeded {}
#[derive(Debug, Clone)]
pub struct ConstraintZeroed {}
@ -639,6 +644,7 @@ pub enum ConstraintRentExempt {
#[derive(Debug, Clone)]
pub struct ConstraintInitGroup {
pub if_needed: bool,
pub seeds: Option<ConstraintSeedsGroup>,
pub payer: Option<Expr>,
pub space: Option<Expr>,

View File

@ -60,7 +60,14 @@ pub fn parse_token(stream: ParseStream) -> ParseResult<ConstraintToken> {
let kw = ident.to_string();
let c = match kw.as_str() {
"init" => ConstraintToken::Init(Context::new(ident.span(), ConstraintInit {})),
"init" => ConstraintToken::Init(Context::new(
ident.span(),
ConstraintInit { if_needed: false },
)),
"init_if_needed" => ConstraintToken::Init(Context::new(
ident.span(),
ConstraintInit { if_needed: true },
)),
"zero" => ConstraintToken::Zeroed(Context::new(ident.span(), ConstraintZeroed {})),
"mut" => ConstraintToken::Mut(Context::new(ident.span(), ConstraintMut {})),
"signer" => ConstraintToken::Signer(Context::new(ident.span(), ConstraintSigner {})),
@ -518,7 +525,8 @@ impl<'ty> ConstraintGroupBuilder<'ty> {
_ => None,
};
Ok(ConstraintGroup {
init: init.as_ref().map(|_| Ok(ConstraintInitGroup {
init: init.as_ref().map(|i| Ok(ConstraintInitGroup {
if_needed: i.if_needed,
seeds: seeds.clone(),
payer: into_inner!(payer.clone()).map(|a| a.target),
space: space.clone().map(|s| s.space.clone()),

View File

@ -244,3 +244,11 @@ pub struct TestEmptySeedsConstraint<'info> {
#[account(seeds = [], bump)]
pub pda: AccountInfo<'info>,
}
#[derive(Accounts)]
pub struct TestInitIfNeeded<'info> {
#[account(init_if_needed, payer = payer)]
pub data: Account<'info, DataU16>,
pub payer: Signer<'info>,
pub system_program: Program<'info, System>,
}

View File

@ -184,4 +184,9 @@ pub mod misc {
pub fn test_empty_seeds_constraint(ctx: Context<TestEmptySeedsConstraint>) -> ProgramResult {
Ok(())
}
pub fn test_init_if_needed(ctx: Context<TestInitIfNeeded>, data: u16) -> ProgramResult {
ctx.accounts.data.data = data;
Ok(())
}
}

View File

@ -652,8 +652,8 @@ describe("misc", () => {
accounts: {
token: associatedToken,
mint: mint.publicKey,
wallet: program.provider.wallet.publicKey
}
wallet: program.provider.wallet.publicKey,
},
});
await assert.rejects(
@ -662,8 +662,8 @@ describe("misc", () => {
accounts: {
token: associatedToken,
mint: mint.publicKey,
wallet: anchor.web3.Keypair.generate().publicKey
}
wallet: anchor.web3.Keypair.generate().publicKey,
},
});
},
(err) => {
@ -735,12 +735,11 @@ describe("misc", () => {
]);
// Call for multiple kinds of .all.
const allAccounts = await program.account.dataWithFilter.all();
const allAccountsFilteredByBuffer =
await program.account.dataWithFilter.all(
program.provider.wallet.publicKey.toBuffer()
);
const allAccountsFilteredByProgramFilters1 =
await program.account.dataWithFilter.all([
const allAccountsFilteredByBuffer = await program.account.dataWithFilter.all(
program.provider.wallet.publicKey.toBuffer()
);
const allAccountsFilteredByProgramFilters1 = await program.account.dataWithFilter.all(
[
{
memcmp: {
offset: 8,
@ -748,9 +747,10 @@ describe("misc", () => {
},
},
{ memcmp: { offset: 40, bytes: filterable1.toBase58() } },
]);
const allAccountsFilteredByProgramFilters2 =
await program.account.dataWithFilter.all([
]
);
const allAccountsFilteredByProgramFilters2 = await program.account.dataWithFilter.all(
[
{
memcmp: {
offset: 8,
@ -758,7 +758,8 @@ describe("misc", () => {
},
},
{ memcmp: { offset: 40, bytes: filterable2.toBase58() } },
]);
]
);
// Without filters there should be 4 accounts.
assert.equal(allAccounts.length, 4);
// Filtering by main wallet there should be 3 accounts.
@ -772,27 +773,33 @@ describe("misc", () => {
});
it("Can use pdas with empty seeds", async () => {
const [pda, bump] = await PublicKey.findProgramAddress([], program.programId);
const [pda, bump] = await PublicKey.findProgramAddress(
[],
program.programId
);
await program.rpc.testInitWithEmptySeeds({
accounts: {
pda: pda,
authority: program.provider.wallet.publicKey,
systemProgram: anchor.web3.SystemProgram.programId
}
systemProgram: anchor.web3.SystemProgram.programId,
},
});
await program.rpc.testEmptySeedsConstraint({
accounts: {
pda: pda
}
pda: pda,
},
});
const [pda2, bump2] = await PublicKey.findProgramAddress(["non-empty"], program.programId);
const [pda2, bump2] = await PublicKey.findProgramAddress(
["non-empty"],
program.programId
);
await assert.rejects(
program.rpc.testEmptySeedsConstraint({
accounts: {
pda: pda2
}
pda: pda2,
},
}),
(err) => {
assert.equal(err.code, 146);
@ -800,4 +807,32 @@ describe("misc", () => {
}
);
});
const ifNeededAcc = anchor.web3.Keypair.generate();
it("Can init if needed a new account", async () => {
await program.rpc.testInitIfNeeded(1, {
accounts: {
data: ifNeededAcc.publicKey,
systemProgram: anchor.web3.SystemProgram.programId,
payer: program.provider.wallet.publicKey,
},
signers: [ifNeededAcc],
});
const account = await program.account.dataU16.fetch(ifNeededAcc.publicKey);
assert.ok(account.data, 1);
});
it("Can init if needed a previously created account", async () => {
await program.rpc.testInitIfNeeded(3, {
accounts: {
data: ifNeededAcc.publicKey,
systemProgram: anchor.web3.SystemProgram.programId,
payer: program.provider.wallet.publicKey,
},
signers: [ifNeededAcc],
});
const account = await program.account.dataU16.fetch(ifNeededAcc.publicKey);
assert.ok(account.data, 3);
});
});