diff --git a/CHANGELOG.md b/CHANGELOG.md index 7419ee675..1601cbaa9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ incremented for features. * ts: `Program` can now be typed with an IDL type ([#795](https://github.com/project-serum/anchor/pull/795)). * lang: Add `mint::freeze_authority` keyword for mint initialization within `#[derive(Accounts)]` ([#835](https://github.com/project-serum/anchor/pull/835)). * lang: Add `AccountLoader` type for `zero_copy` accounts with support for CPI ([#792](https://github.com/project-serum/anchor/pull/792)). +* lang: Add `#[account(init_if_needed)]` keyword for allowing one to invoke the same instruction even if the account was created already ([#906](https://github.com/project-serum/anchor/pull/906)). * lang: Add custom errors support for raw constraints ([#905](https://github.com/project-serum/anchor/pull/905)). ### Breaking diff --git a/lang/syn/src/codegen/accounts/constraints.rs b/lang/syn/src/codegen/accounts/constraints.rs index 0b14a223b..dd414b09e 100644 --- a/lang/syn/src/codegen/accounts/constraints.rs +++ b/lang/syn/src/codegen/accounts/constraints.rs @@ -316,7 +316,7 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma } } }; - generate_init(f, seeds_with_nonce, payer, &c.space, &c.kind) + generate_init(f, c.if_needed, seeds_with_nonce, payer, &c.space, &c.kind) } fn generate_constraint_seeds(f: &Field, c: &ConstraintSeedsGroup) -> proc_macro2::TokenStream { @@ -397,8 +397,10 @@ fn generate_constraint_associated_token( } } +// `if_needed` is set if account allocation and initialization is optional. pub fn generate_init( f: &Field, + if_needed: bool, seeds_with_nonce: proc_macro2::TokenStream, payer: proc_macro2::TokenStream, space: &Option, @@ -407,6 +409,11 @@ pub fn generate_init( let field = &f.ident; let ty_decl = f.ty_decl(); let from_account_info = f.from_account_info_unchecked(Some(kind)); + let if_needed = if if_needed { + quote! {true} + } else { + quote! {false} + }; match kind { InitKind::Token { owner, mint } => { let create_account = generate_create_account( @@ -417,22 +424,25 @@ pub fn generate_init( ); quote! { let #field: #ty_decl = { - // Define payer variable. - #payer + if !#if_needed || #field.to_account_info().owner == &anchor_lang::solana_program::system_program::ID { + // Define payer variable. + #payer - // Create the account with the system program. - #create_account + // Create the account with the system program. + #create_account + + // Initialize the token account. + let cpi_program = token_program.to_account_info(); + let accounts = anchor_spl::token::InitializeAccount { + account: #field.to_account_info(), + mint: #mint.to_account_info(), + authority: #owner.to_account_info(), + rent: rent.to_account_info(), + }; + let cpi_ctx = CpiContext::new(cpi_program, accounts); + anchor_spl::token::initialize_account(cpi_ctx)?; + } - // Initialize the token account. - let cpi_program = token_program.to_account_info(); - let accounts = anchor_spl::token::InitializeAccount { - account: #field.to_account_info(), - mint: #mint.to_account_info(), - authority: #owner.to_account_info(), - rent: rent.to_account_info(), - }; - let cpi_ctx = CpiContext::new(cpi_program, accounts); - anchor_spl::token::initialize_account(cpi_ctx)?; let pa: #ty_decl = #from_account_info; pa }; @@ -441,20 +451,22 @@ pub fn generate_init( InitKind::AssociatedToken { owner, mint } => { quote! { let #field: #ty_decl = { - #payer + if !#if_needed || #field.to_account_info().owner == &anchor_lang::solana_program::system_program::ID { + #payer - let cpi_program = associated_token_program.to_account_info(); - let cpi_accounts = anchor_spl::associated_token::Create { - payer: payer.to_account_info(), - associated_token: #field.to_account_info(), - authority: #owner.to_account_info(), - mint: #mint.to_account_info(), - system_program: system_program.to_account_info(), - token_program: token_program.to_account_info(), - rent: rent.to_account_info(), - }; - let cpi_ctx = CpiContext::new(cpi_program, cpi_accounts); - anchor_spl::associated_token::create(cpi_ctx)?; + let cpi_program = associated_token_program.to_account_info(); + let cpi_accounts = anchor_spl::associated_token::Create { + payer: payer.to_account_info(), + associated_token: #field.to_account_info(), + authority: #owner.to_account_info(), + mint: #mint.to_account_info(), + system_program: system_program.to_account_info(), + token_program: token_program.to_account_info(), + rent: rent.to_account_info(), + }; + let cpi_ctx = CpiContext::new(cpi_program, cpi_accounts); + anchor_spl::associated_token::create(cpi_ctx)?; + } let pa: #ty_decl = #from_account_info; pa }; @@ -477,20 +489,22 @@ pub fn generate_init( }; quote! { let #field: #ty_decl = { - // Define payer variable. - #payer + if !#if_needed || #field.to_account_info().owner == &anchor_lang::solana_program::system_program::ID { + // Define payer variable. + #payer - // Create the account with the system program. - #create_account + // Create the account with the system program. + #create_account - // Initialize the mint account. - let cpi_program = token_program.to_account_info(); - let accounts = anchor_spl::token::InitializeMint { - mint: #field.to_account_info(), - rent: rent.to_account_info(), - }; - let cpi_ctx = CpiContext::new(cpi_program, accounts); - anchor_spl::token::initialize_mint(cpi_ctx, #decimals, &#owner.to_account_info().key, #freeze_authority)?; + // Initialize the mint account. + let cpi_program = token_program.to_account_info(); + let accounts = anchor_spl::token::InitializeMint { + mint: #field.to_account_info(), + rent: rent.to_account_info(), + }; + let cpi_ctx = CpiContext::new(cpi_program, accounts); + anchor_spl::token::initialize_mint(cpi_ctx, #decimals, &#owner.to_account_info().key, #freeze_authority)?; + } let pa: #ty_decl = #from_account_info; pa }; @@ -535,9 +549,11 @@ pub fn generate_init( generate_create_account(field, quote! {space}, owner, seeds_with_nonce); quote! { let #field = { - #space - #payer - #create_account + if !#if_needed || #field.to_account_info().owner == &anchor_lang::solana_program::system_program::ID { + #space + #payer + #create_account + } let pa: #ty_decl = #from_account_info; pa }; diff --git a/lang/syn/src/lib.rs b/lang/syn/src/lib.rs index ce52d38a9..cd7d7dcc4 100644 --- a/lang/syn/src/lib.rs +++ b/lang/syn/src/lib.rs @@ -594,7 +594,12 @@ impl Parse for ConstraintToken { } #[derive(Debug, Clone)] -pub struct ConstraintInit {} +pub struct ConstraintInit { + pub if_needed: bool, +} + +#[derive(Debug, Clone)] +pub struct ConstraintInitIfNeeded {} #[derive(Debug, Clone)] pub struct ConstraintZeroed {} @@ -639,6 +644,7 @@ pub enum ConstraintRentExempt { #[derive(Debug, Clone)] pub struct ConstraintInitGroup { + pub if_needed: bool, pub seeds: Option, pub payer: Option, pub space: Option, diff --git a/lang/syn/src/parser/accounts/constraints.rs b/lang/syn/src/parser/accounts/constraints.rs index b6e4ca588..d63a28445 100644 --- a/lang/syn/src/parser/accounts/constraints.rs +++ b/lang/syn/src/parser/accounts/constraints.rs @@ -60,7 +60,14 @@ pub fn parse_token(stream: ParseStream) -> ParseResult { let kw = ident.to_string(); let c = match kw.as_str() { - "init" => ConstraintToken::Init(Context::new(ident.span(), ConstraintInit {})), + "init" => ConstraintToken::Init(Context::new( + ident.span(), + ConstraintInit { if_needed: false }, + )), + "init_if_needed" => ConstraintToken::Init(Context::new( + ident.span(), + ConstraintInit { if_needed: true }, + )), "zero" => ConstraintToken::Zeroed(Context::new(ident.span(), ConstraintZeroed {})), "mut" => ConstraintToken::Mut(Context::new(ident.span(), ConstraintMut {})), "signer" => ConstraintToken::Signer(Context::new(ident.span(), ConstraintSigner {})), @@ -518,7 +525,8 @@ impl<'ty> ConstraintGroupBuilder<'ty> { _ => None, }; Ok(ConstraintGroup { - init: init.as_ref().map(|_| Ok(ConstraintInitGroup { + init: init.as_ref().map(|i| Ok(ConstraintInitGroup { + if_needed: i.if_needed, seeds: seeds.clone(), payer: into_inner!(payer.clone()).map(|a| a.target), space: space.clone().map(|s| s.space.clone()), diff --git a/tests/misc/programs/misc/src/context.rs b/tests/misc/programs/misc/src/context.rs index a3fb1d976..5089b565a 100644 --- a/tests/misc/programs/misc/src/context.rs +++ b/tests/misc/programs/misc/src/context.rs @@ -244,3 +244,11 @@ pub struct TestEmptySeedsConstraint<'info> { #[account(seeds = [], bump)] pub pda: AccountInfo<'info>, } + +#[derive(Accounts)] +pub struct TestInitIfNeeded<'info> { + #[account(init_if_needed, payer = payer)] + pub data: Account<'info, DataU16>, + pub payer: Signer<'info>, + pub system_program: Program<'info, System>, +} diff --git a/tests/misc/programs/misc/src/lib.rs b/tests/misc/programs/misc/src/lib.rs index 940f8eee6..4b012ae67 100644 --- a/tests/misc/programs/misc/src/lib.rs +++ b/tests/misc/programs/misc/src/lib.rs @@ -184,4 +184,9 @@ pub mod misc { pub fn test_empty_seeds_constraint(ctx: Context) -> ProgramResult { Ok(()) } + + pub fn test_init_if_needed(ctx: Context, data: u16) -> ProgramResult { + ctx.accounts.data.data = data; + Ok(()) + } } diff --git a/tests/misc/tests/misc.js b/tests/misc/tests/misc.js index 4873f1da4..41fe271bc 100644 --- a/tests/misc/tests/misc.js +++ b/tests/misc/tests/misc.js @@ -652,8 +652,8 @@ describe("misc", () => { accounts: { token: associatedToken, mint: mint.publicKey, - wallet: program.provider.wallet.publicKey - } + wallet: program.provider.wallet.publicKey, + }, }); await assert.rejects( @@ -662,8 +662,8 @@ describe("misc", () => { accounts: { token: associatedToken, mint: mint.publicKey, - wallet: anchor.web3.Keypair.generate().publicKey - } + wallet: anchor.web3.Keypair.generate().publicKey, + }, }); }, (err) => { @@ -735,12 +735,11 @@ describe("misc", () => { ]); // Call for multiple kinds of .all. const allAccounts = await program.account.dataWithFilter.all(); - const allAccountsFilteredByBuffer = - await program.account.dataWithFilter.all( - program.provider.wallet.publicKey.toBuffer() - ); - const allAccountsFilteredByProgramFilters1 = - await program.account.dataWithFilter.all([ + const allAccountsFilteredByBuffer = await program.account.dataWithFilter.all( + program.provider.wallet.publicKey.toBuffer() + ); + const allAccountsFilteredByProgramFilters1 = await program.account.dataWithFilter.all( + [ { memcmp: { offset: 8, @@ -748,9 +747,10 @@ describe("misc", () => { }, }, { memcmp: { offset: 40, bytes: filterable1.toBase58() } }, - ]); - const allAccountsFilteredByProgramFilters2 = - await program.account.dataWithFilter.all([ + ] + ); + const allAccountsFilteredByProgramFilters2 = await program.account.dataWithFilter.all( + [ { memcmp: { offset: 8, @@ -758,7 +758,8 @@ describe("misc", () => { }, }, { memcmp: { offset: 40, bytes: filterable2.toBase58() } }, - ]); + ] + ); // Without filters there should be 4 accounts. assert.equal(allAccounts.length, 4); // Filtering by main wallet there should be 3 accounts. @@ -772,27 +773,33 @@ describe("misc", () => { }); it("Can use pdas with empty seeds", async () => { - const [pda, bump] = await PublicKey.findProgramAddress([], program.programId); + const [pda, bump] = await PublicKey.findProgramAddress( + [], + program.programId + ); await program.rpc.testInitWithEmptySeeds({ accounts: { pda: pda, authority: program.provider.wallet.publicKey, - systemProgram: anchor.web3.SystemProgram.programId - } + systemProgram: anchor.web3.SystemProgram.programId, + }, }); await program.rpc.testEmptySeedsConstraint({ accounts: { - pda: pda - } + pda: pda, + }, }); - const [pda2, bump2] = await PublicKey.findProgramAddress(["non-empty"], program.programId); + const [pda2, bump2] = await PublicKey.findProgramAddress( + ["non-empty"], + program.programId + ); await assert.rejects( program.rpc.testEmptySeedsConstraint({ accounts: { - pda: pda2 - } + pda: pda2, + }, }), (err) => { assert.equal(err.code, 146); @@ -800,4 +807,32 @@ describe("misc", () => { } ); }); + + const ifNeededAcc = anchor.web3.Keypair.generate(); + + it("Can init if needed a new account", async () => { + await program.rpc.testInitIfNeeded(1, { + accounts: { + data: ifNeededAcc.publicKey, + systemProgram: anchor.web3.SystemProgram.programId, + payer: program.provider.wallet.publicKey, + }, + signers: [ifNeededAcc], + }); + const account = await program.account.dataU16.fetch(ifNeededAcc.publicKey); + assert.ok(account.data, 1); + }); + + it("Can init if needed a previously created account", async () => { + await program.rpc.testInitIfNeeded(3, { + accounts: { + data: ifNeededAcc.publicKey, + systemProgram: anchor.web3.SystemProgram.programId, + payer: program.provider.wallet.publicKey, + }, + signers: [ifNeededAcc], + }); + const account = await program.account.dataU16.fetch(ifNeededAcc.publicKey); + assert.ok(account.data, 3); + }); });