lang, ts: Fallback functions (#457)

This commit is contained in:
Armani Ferrante 2021-07-02 17:33:48 -07:00 committed by GitHub
parent 6ad68ed368
commit 915e6dd398
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 156 additions and 35 deletions

View File

@ -11,6 +11,11 @@ incremented for features.
## [Unreleased]
### Features
* lang: Add fallback functions ([#457](https://github.com/project-serum/anchor/pull/457)).
* lang: Add feature flag for using the old state account discriminator. This is a temporary flag for those with programs built prior to v0.7.0 but want to use the latest Anchor version. Expect this to be removed in a future version ([#446](https://github.com/project-serum/anchor/pull/446)).
### Breaking Changes
* cli: Remove `.spec` suffix on TypeScript tests files ([#441](https://github.com/project-serum/anchor/pull/441)).

View File

@ -128,6 +128,14 @@ pub mod misc {
pub fn test_token_seeds_init(_ctx: Context<TestTokenSeedsInit>, _nonce: u8) -> ProgramResult {
Ok(())
}
pub fn default<'info>(
_program_id: &Pubkey,
_accounts: &[AccountInfo<'info>],
_data: &[u8],
) -> ProgramResult {
Err(ProgramError::Custom(1234))
}
}
#[derive(Accounts)]

View File

@ -141,17 +141,19 @@ describe("misc", () => {
// Manual associated address calculation for test only. Clients should use
// the generated methods.
const [associatedAccount, nonce] =
await anchor.web3.PublicKey.findProgramAddress(
[
anchor.utils.bytes.utf8.encode("anchor"),
program.provider.wallet.publicKey.toBuffer(),
state.toBuffer(),
data.publicKey.toBuffer(),
anchor.utils.bytes.utf8.encode("my-seed"),
],
program.programId
);
const [
associatedAccount,
nonce,
] = await anchor.web3.PublicKey.findProgramAddress(
[
anchor.utils.bytes.utf8.encode("anchor"),
program.provider.wallet.publicKey.toBuffer(),
state.toBuffer(),
data.publicKey.toBuffer(),
anchor.utils.bytes.utf8.encode("my-seed"),
],
program.programId
);
await assert.rejects(
async () => {
await program.account.testData.fetch(associatedAccount);
@ -186,17 +188,19 @@ describe("misc", () => {
it("Can use an associated program account", async () => {
const state = await program.state.address();
const [associatedAccount, nonce] =
await anchor.web3.PublicKey.findProgramAddress(
[
anchor.utils.bytes.utf8.encode("anchor"),
program.provider.wallet.publicKey.toBuffer(),
state.toBuffer(),
data.publicKey.toBuffer(),
anchor.utils.bytes.utf8.encode("my-seed"),
],
program.programId
);
const [
associatedAccount,
nonce,
] = await anchor.web3.PublicKey.findProgramAddress(
[
anchor.utils.bytes.utf8.encode("anchor"),
program.provider.wallet.publicKey.toBuffer(),
state.toBuffer(),
data.publicKey.toBuffer(),
anchor.utils.bytes.utf8.encode("my-seed"),
],
program.programId
);
await program.rpc.testAssociatedAccount(new anchor.BN(5), {
accounts: {
myAccount: associatedAccount,
@ -435,4 +439,16 @@ describe("misc", () => {
assert.ok(account.owner.equals(program.provider.wallet.publicKey));
assert.ok(account.mint.equals(mint.publicKey));
});
it("Can execute a fallback function", async () => {
await assert.rejects(
async () => {
await anchor.utils.rpc.invoke(program.programId);
},
(err) => {
assert.ok(err.toString().includes("custom program error: 0x4d2"));
return true;
}
);
});
});

View File

@ -113,7 +113,9 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
}
})
.collect();
let fallback_fn = gen_fallback(program).unwrap_or(quote! {
Err(anchor_lang::__private::ErrorCode::InstructionFallbackNotFound.into())
});
quote! {
/// Performs method dispatch.
///
@ -152,10 +154,20 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
#(#trait_dispatch_arms)*
#(#global_dispatch_arms)*
_ => {
msg!("Fallback functions are not supported. If you have a use case, please file an issue.");
Err(anchor_lang::__private::ErrorCode::InstructionFallbackNotFound.into())
#fallback_fn
}
}
}
}
}
pub fn gen_fallback(program: &Program) -> Option<proc_macro2::TokenStream> {
program.fallback_fn.as_ref().map(|fallback_fn| {
let program_name = &program.name;
let method = &fallback_fn.raw_method;
let fn_name = &method.sig.ident;
quote! {
#program_name::#fn_name(program_id, accounts, ix_data)
}
})
}

View File

@ -1,7 +1,11 @@
use crate::program_codegen::dispatch;
use crate::Program;
use quote::quote;
pub fn generate(_program: &Program) -> proc_macro2::TokenStream {
pub fn generate(program: &Program) -> proc_macro2::TokenStream {
let fallback_maybe = dispatch::gen_fallback(program).unwrap_or(quote! {
Err(anchor_lang::__private::ErrorCode::InstructionMissing.into());
});
quote! {
#[cfg(not(feature = "no-entrypoint"))]
anchor_lang::solana_program::entrypoint!(entry);
@ -52,7 +56,7 @@ pub fn generate(_program: &Program) -> proc_macro2::TokenStream {
msg!("anchor-debug is active");
}
if ix_data.len() < 8 {
return Err(anchor_lang::__private::ErrorCode::InstructionMissing.into());
return #fallback_maybe
}
// Split the instruction data into the first 8 byte method

View File

@ -30,6 +30,7 @@ pub struct Program {
pub ixs: Vec<Ix>,
pub name: Ident,
pub program_mod: ItemMod,
pub fallback_fn: Option<FallbackFn>,
}
impl Parse for Program {
@ -92,6 +93,11 @@ pub struct IxArg {
pub raw_arg: PatType,
}
#[derive(Debug)]
pub struct FallbackFn {
raw_method: ItemFn,
}
#[derive(Debug)]
pub struct AccountsStruct {
// Name of the accounts struct.

View File

@ -1,20 +1,24 @@
use crate::parser::program::ctx_accounts_ident;
use crate::{Ix, IxArg};
use crate::{FallbackFn, Ix, IxArg};
use syn::parse::{Error as ParseError, Result as ParseResult};
use syn::spanned::Spanned;
// Parse all non-state ix handlers from the program mod definition.
pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<Vec<Ix>> {
pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<(Vec<Ix>, Option<FallbackFn>)> {
let mod_content = &program_mod
.content
.as_ref()
.ok_or_else(|| ParseError::new(program_mod.span(), "program content not provided"))?
.1;
mod_content
let ixs = mod_content
.iter()
.filter_map(|item| match item {
syn::Item::Fn(item_fn) => Some(item_fn),
syn::Item::Fn(item_fn) => {
let (ctx, _) = parse_args(item_fn).ok()?;
ctx_accounts_ident(&ctx.raw_arg).ok()?;
Some(item_fn)
}
_ => None,
})
.map(|method: &syn::ItemFn| {
@ -27,7 +31,36 @@ pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<Vec<Ix>> {
anchor_ident,
})
})
.collect::<ParseResult<Vec<Ix>>>()
.collect::<ParseResult<Vec<Ix>>>()?;
let fallback_fn = {
let fallback_fns = mod_content
.iter()
.filter_map(|item| match item {
syn::Item::Fn(item_fn) => {
let (ctx, _args) = parse_args(item_fn).ok()?;
if ctx_accounts_ident(&ctx.raw_arg).is_ok() {
return None;
}
Some(item_fn)
}
_ => None,
})
.collect::<Vec<_>>();
if fallback_fns.len() > 1 {
return Err(ParseError::new(
fallback_fns[0].span(),
"More than one fallback function found",
));
}
fallback_fns
.first()
.map(|method: &&syn::ItemFn| FallbackFn {
raw_method: (*method).clone(),
})
};
Ok((ixs, fallback_fn))
}
pub fn parse_args(method: &syn::ItemFn) -> ParseResult<(IxArg, Vec<IxArg>)> {

View File

@ -7,13 +7,13 @@ mod state;
pub fn parse(program_mod: syn::ItemMod) -> ParseResult<Program> {
let state = state::parse(&program_mod)?;
let ixs = instructions::parse(&program_mod)?;
let (ixs, fallback_fn) = instructions::parse(&program_mod)?;
Ok(Program {
state,
ixs,
name: program_mod.ident.clone(),
program_mod,
fallback_fn,
})
}

View File

@ -1,5 +1,42 @@
import assert from "assert";
import { PublicKey, AccountInfo, Connection } from "@solana/web3.js";
import {
AccountInfo,
AccountMeta,
Connection,
PublicKey,
TransactionSignature,
Transaction,
TransactionInstruction,
} from "@solana/web3.js";
import { Address, translateAddress } from "../program/common";
import Provider, { getProvider } from "../provider";
/**
* Sends a transaction to a program with the given accounts and instruction
* data.
*/
export async function invoke(
programId: Address,
accounts?: Array<AccountMeta>,
data?: Buffer,
provider?: Provider
): Promise<TransactionSignature> {
programId = translateAddress(programId);
if (!provider) {
provider = getProvider();
}
const tx = new Transaction();
tx.add(
new TransactionInstruction({
programId,
keys: accounts ?? [],
data,
})
);
return await provider.send(tx);
}
export async function getMultipleAccounts(
connection: Connection,