diff --git a/Cargo.lock b/Cargo.lock index 7fbc638d..4d61bf32 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -65,6 +65,16 @@ dependencies = [ "syn 1.0.57", ] +[[package]] +name = "anchor-attribute-error" +version = "0.0.0-alpha.0" +dependencies = [ + "anchor-syn", + "proc-macro2 1.0.24", + "quote 1.0.8", + "syn 1.0.57", +] + [[package]] name = "anchor-attribute-program" version = "0.0.0-alpha.0" @@ -113,6 +123,7 @@ version = "0.0.0-alpha.0" dependencies = [ "anchor-attribute-access-control", "anchor-attribute-account", + "anchor-attribute-error", "anchor-attribute-program", "anchor-derive-accounts", "serum-borsh", diff --git a/Cargo.toml b/Cargo.toml index 6882d115..62e0a91f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ default = [] [dependencies] anchor-attribute-access-control = { path = "./attribute/access-control", version = "0.0.0-alpha.0" } anchor-attribute-account = { path = "./attribute/account", version = "0.0.0-alpha.0" } +anchor-attribute-error = { path = "./attribute/error" } anchor-attribute-program = { path = "./attribute/program", version = "0.0.0-alpha.0" } anchor-derive-accounts = { path = "./derive/accounts", version = "0.0.0-alpha.0" } serum-borsh = { version = "0.8.0-serum.1", features = ["serum-program"] } diff --git a/attribute/error/Cargo.toml b/attribute/error/Cargo.toml new file mode 100644 index 00000000..a4371aee --- /dev/null +++ b/attribute/error/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "anchor-attribute-error" +version = "0.0.0-alpha.0" +authors = ["Serum Foundation "] +repository = "https://github.com/project-serum/anchor" +license = "Apache-2.0" +description = "Anchor attribute macro for creating error types" +edition = "2018" + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0" +quote = "1.0" +syn = { version = "=1.0.57", features = ["full"] } +anchor-syn = { path = "../../syn" } \ No newline at end of file diff --git a/attribute/error/src/lib.rs b/attribute/error/src/lib.rs new file mode 100644 index 00000000..8549a918 --- /dev/null +++ b/attribute/error/src/lib.rs @@ -0,0 +1,16 @@ +extern crate proc_macro; + +use anchor_syn::codegen::error as error_codegen; +use anchor_syn::parser::error as error_parser; +use syn::parse_macro_input; + +/// Generates an error type from an error code enum. +#[proc_macro_attribute] +pub fn error( + _args: proc_macro::TokenStream, + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let mut error_enum = parse_macro_input!(input as syn::ItemEnum); + let error = error_codegen::generate(error_parser::parse(&mut error_enum)); + proc_macro::TokenStream::from(error) +} diff --git a/examples/errors/Anchor.toml b/examples/errors/Anchor.toml new file mode 100644 index 00000000..6e90c3dd --- /dev/null +++ b/examples/errors/Anchor.toml @@ -0,0 +1,2 @@ +cluster = "localnet" +wallet = "/home/armaniferrante/.config/solana/id.json" diff --git a/examples/errors/Cargo.toml b/examples/errors/Cargo.toml new file mode 100644 index 00000000..a60de986 --- /dev/null +++ b/examples/errors/Cargo.toml @@ -0,0 +1,4 @@ +[workspace] +members = [ + "programs/*" +] diff --git a/examples/errors/programs/errors/Cargo.toml b/examples/errors/programs/errors/Cargo.toml new file mode 100644 index 00000000..583d28cb --- /dev/null +++ b/examples/errors/programs/errors/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "errors" +version = "0.1.0" +description = "Created with Anchor" +edition = "2018" + +[lib] +crate-type = ["cdylib", "lib"] +name = "errors" + +[features] +no-entrypoint = [] +cpi = ["no-entrypoint"] + +[dependencies] +# anchor-lang = { git = "https://github.com/project-serum/anchor", features = ["derive"] } +anchor-lang = { path = "/home/armaniferrante/Documents/code/src/github.com/project-serum/anchor", features = ["derive"] } diff --git a/examples/errors/programs/errors/Xargo.toml b/examples/errors/programs/errors/Xargo.toml new file mode 100644 index 00000000..1744f098 --- /dev/null +++ b/examples/errors/programs/errors/Xargo.toml @@ -0,0 +1,2 @@ +[target.bpfel-unknown-unknown.dependencies.std] +features = [] \ No newline at end of file diff --git a/examples/errors/programs/errors/src/lib.rs b/examples/errors/programs/errors/src/lib.rs new file mode 100644 index 00000000..a7c560ac --- /dev/null +++ b/examples/errors/programs/errors/src/lib.rs @@ -0,0 +1,36 @@ +#![feature(proc_macro_hygiene)] + +use anchor_lang::prelude::*; + +#[program] +mod errors { + use super::*; + pub fn hello(ctx: Context) -> Result<(), Error> { + Err(MyError::Hello.into()) + } + + pub fn hello_no_msg(ctx: Context) -> Result<(), Error> { + Err(MyError::HelloNoMsg.into()) + } + + pub fn hello_next(ctx: Context) -> Result<(), Error> { + Err(MyError::HelloNext.into()) + } +} + +#[derive(Accounts)] +pub struct Hello {} + +#[derive(Accounts)] +pub struct HelloNoMsg {} + +#[derive(Accounts)] +pub struct HelloNext {} + +#[error] +pub enum MyError { + #[msg("This is an error message clients will automatically display")] + Hello, + HelloNoMsg = 123, + HelloNext, +} diff --git a/examples/errors/tests/errors.js b/examples/errors/tests/errors.js new file mode 100644 index 00000000..35c323ce --- /dev/null +++ b/examples/errors/tests/errors.js @@ -0,0 +1,47 @@ +const assert = require("assert"); +//const anchor = require('@project-serum/anchor'); +const anchor = require("/home/armaniferrante/Documents/code/src/github.com/project-serum/anchor/ts"); + +describe("errors", () => { + // Configure the client to use the local cluster. + anchor.setProvider(anchor.Provider.local()); + + const program = anchor.workspace.Errors; + + it("Emits a Hello error", async () => { + try { + const tx = await program.rpc.hello(); + assert.ok(false); + } catch (err) { + const errMsg = + "This is an error message clients will automatically display"; + assert.equal(err.toString(), errMsg); + assert.equal(err.msg, errMsg); + assert.equal(err.code, 100); + } + }); + + it("Emits a HelloNoMsg error", async () => { + try { + const tx = await program.rpc.helloNoMsg(); + assert.ok(false); + } catch (err) { + const errMsg = "HelloNoMsg"; + assert.equal(err.toString(), errMsg); + assert.equal(err.msg, errMsg); + assert.equal(err.code, 100 + 123); + } + }); + + it("Emits a HelloNext error", async () => { + try { + const tx = await program.rpc.helloNext(); + assert.ok(false); + } catch (err) { + const errMsg = "HelloNext"; + assert.equal(err.toString(), errMsg); + assert.equal(err.msg, errMsg); + assert.equal(err.code, 100 + 124); + } + }); +}); diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 00000000..7c94c761 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,32 @@ +use solana_program::program_error::ProgramError; + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error(transparent)] + ProgramError(#[from] ProgramError), + #[error("{0:?}")] + ErrorCode(#[from] ErrorCode), +} + +#[derive(Debug, Clone, Copy)] +#[repr(u32)] +pub enum ErrorCode { + WrongSerialization = 1, +} + +impl std::fmt::Display for ErrorCode { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + ::fmt(self, fmt) + } +} + +impl std::error::Error for ErrorCode {} + +impl std::convert::From for ProgramError { + fn from(e: Error) -> ProgramError { + match e { + Error::ProgramError(e) => e, + Error::ErrorCode(c) => ProgramError::Custom(c as u32), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 66d9eed9..c75d0fc9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,6 +30,7 @@ use std::io::Write; mod account_info; mod context; mod cpi_account; +mod error; mod program_account; mod sysvar; @@ -39,10 +40,12 @@ pub use crate::program_account::ProgramAccount; pub use crate::sysvar::Sysvar; pub use anchor_attribute_access_control::access_control; pub use anchor_attribute_account::account; +pub use anchor_attribute_error::error; pub use anchor_attribute_program::program; pub use anchor_derive_accounts::Accounts; /// Default serialization format for anchor instructions and accounts. pub use borsh::{BorshDeserialize as AnchorDeserialize, BorshSerialize as AnchorSerialize}; +pub use error::Error; pub use solana_program; /// A data structure of accounts that can be deserialized from the input @@ -115,7 +118,7 @@ pub trait AccountDeserialize: Sized { /// All programs should include it via `anchor_lang::prelude::*;`. pub mod prelude { pub use super::{ - access_control, account, program, AccountDeserialize, AccountSerialize, Accounts, + access_control, account, error, program, AccountDeserialize, AccountSerialize, Accounts, AccountsInit, AnchorDeserialize, AnchorSerialize, Context, CpiAccount, CpiContext, ProgramAccount, Sysvar, ToAccountInfo, ToAccountInfos, ToAccountMetas, }; @@ -138,4 +141,5 @@ pub mod prelude { pub use solana_program::sysvar::slot_history::SlotHistory; pub use solana_program::sysvar::stake_history::StakeHistory; pub use solana_program::sysvar::Sysvar as SolanaSysvar; + pub use thiserror; } diff --git a/syn/src/codegen/error.rs b/syn/src/codegen/error.rs new file mode 100644 index 00000000..bcc62de6 --- /dev/null +++ b/syn/src/codegen/error.rs @@ -0,0 +1,39 @@ +use crate::Error; +use quote::quote; + +pub fn generate(error: Error) -> proc_macro2::TokenStream { + let error_enum = error.raw_enum; + let enum_name = &error.ident; + quote! { + #[derive(thiserror::Error, Debug)] + pub enum Error { + #[error(transparent)] + ProgramError(#[from] ProgramError), + #[error("{0:?}")] + ErrorCode(#[from] #enum_name), + } + + #[derive(Debug, Clone, Copy)] + #[repr(u32)] + #error_enum + + impl std::fmt::Display for #enum_name { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + ::fmt(self, fmt) + } + } + + impl std::error::Error for #enum_name {} + + impl std::convert::From for ProgramError { + fn from(e: Error) -> ProgramError { + // Errors 0-100 are reserved for the framework. + let error_offset = 100u32; + match e { + Error::ProgramError(e) => e, + Error::ErrorCode(c) => ProgramError::Custom(c as u32 + error_offset), + } + } + } + } +} diff --git a/syn/src/codegen/mod.rs b/syn/src/codegen/mod.rs index 64824441..a8a20d9f 100644 --- a/syn/src/codegen/mod.rs +++ b/syn/src/codegen/mod.rs @@ -1,2 +1,3 @@ pub mod accounts; +pub mod error; pub mod program; diff --git a/syn/src/idl.rs b/syn/src/idl.rs index 0e069445..2a29e5aa 100644 --- a/syn/src/idl.rs +++ b/syn/src/idl.rs @@ -10,6 +10,8 @@ pub struct Idl { #[serde(skip_serializing_if = "Vec::is_empty", default)] pub types: Vec, #[serde(skip_serializing_if = "Option::is_none", default)] + pub errors: Option>, + #[serde(skip_serializing_if = "Option::is_none", default)] pub metadata: Option, } @@ -132,3 +134,11 @@ impl std::str::FromStr for IdlType { Ok(r) } } + +#[derive(Debug, Serialize, Deserialize)] +pub struct IdlErrorCode { + pub code: u32, + pub name: String, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub msg: Option, +} diff --git a/syn/src/lib.rs b/syn/src/lib.rs index 8a3c1502..e230fa28 100644 --- a/syn/src/lib.rs +++ b/syn/src/lib.rs @@ -245,3 +245,17 @@ pub enum ConstraintRentExempt { Enforce, Skip, } + +#[derive(Debug)] +pub struct Error { + pub raw_enum: syn::ItemEnum, + pub ident: syn::Ident, + pub codes: Vec, +} + +#[derive(Debug)] +pub struct ErrorCode { + pub id: u32, + pub ident: syn::Ident, + pub msg: Option, +} diff --git a/syn/src/parser/error.rs b/syn/src/parser/error.rs new file mode 100644 index 00000000..41d632a0 --- /dev/null +++ b/syn/src/parser/error.rs @@ -0,0 +1,69 @@ +use crate::{Error, ErrorCode}; + +// Removes any internal #[msg] attributes, as they are inert. +pub fn parse(error_enum: &mut syn::ItemEnum) -> Error { + let ident = error_enum.ident.clone(); + let mut last_discriminant = 0; + let codes: Vec = error_enum + .variants + .iter_mut() + .map(|variant: &mut syn::Variant| { + let msg = parse_error_attribute(variant); + let ident = variant.ident.clone(); + let id = match &variant.discriminant { + None => last_discriminant, + Some((_, disc)) => match disc { + syn::Expr::Lit(expr_lit) => match &expr_lit.lit { + syn::Lit::Int(int) => { + int.base10_parse::().expect("Must be a base 10 number") + } + _ => panic!("Invalid error discriminant"), + }, + _ => panic!("Invalid error discriminant"), + }, + }; + last_discriminant = id + 1; + + // Remove any attributes on the error variant. + variant.attrs = vec![]; + + ErrorCode { id, ident, msg } + }) + .collect(); + + Error { + raw_enum: error_enum.clone(), + ident, + codes, + } +} + +fn parse_error_attribute(variant: &syn::Variant) -> Option { + let attrs = &variant.attrs; + match attrs.len() { + 0 => None, + 1 => { + let attr = &attrs[0]; + let attr_str = attr.path.segments[0].ident.to_string(); + if &attr_str != "msg" { + panic!("Use msg to specify error strings"); + } + + let mut tts = attr.tokens.clone().into_iter(); + let g_stream = match tts.next().expect("Must have a token group") { + proc_macro2::TokenTree::Group(g) => g.stream(), + _ => panic!("Invalid syntax"), + }; + + let msg = match g_stream.into_iter().next() { + None => panic!("Must specify a message string"), + Some(msg) => msg.to_string().replace("\"", ""), + }; + + Some(msg) + } + _ => { + panic!("Too many attributes found. Use `msg` to specify error strings"); + } + } +} diff --git a/syn/src/parser/file.rs b/syn/src/parser/file.rs index ee6bc949..4a5b164e 100644 --- a/syn/src/parser/file.rs +++ b/syn/src/parser/file.rs @@ -1,6 +1,5 @@ use crate::idl::*; -use crate::parser::accounts; -use crate::parser::program; +use crate::parser::{accounts, error, program}; use crate::AccountsStruct; use anyhow::Result; use heck::MixedCase; @@ -22,6 +21,17 @@ pub fn parse(filename: impl AsRef) -> Result { let f = syn::parse_file(&src).expect("Unable to parse file"); let p = program::parse(parse_program_mod(&f)); + let errors = parse_error_enum(&f).map(|mut e| { + error::parse(&mut e) + .codes + .iter() + .map(|code| IdlErrorCode { + code: 100 + code.id, + name: code.ident.to_string(), + msg: code.msg.clone(), + }) + .collect::>() + }); let accs = parse_accounts(&f); @@ -83,6 +93,7 @@ pub fn parse(filename: impl AsRef) -> Result { instructions, types, accounts, + errors, metadata: None, }) } @@ -117,6 +128,33 @@ fn parse_program_mod(f: &syn::File) -> syn::ItemMod { mods[0].clone() } +fn parse_error_enum(f: &syn::File) -> Option { + f.items + .iter() + .filter_map(|i| match i { + syn::Item::Enum(item_enum) => { + let attrs = item_enum + .attrs + .iter() + .filter_map(|attr| { + let segment = attr.path.segments.last().unwrap(); + if segment.ident.to_string() == "error" { + return Some(attr); + } + None + }) + .collect::>(); + match attrs.len() { + 0 => None, + 1 => Some(item_enum), + _ => panic!("Invalid syntax: one error attribute allowed"), + } + } + _ => None, + }) + .next() + .cloned() +} // Parse all structs implementing the `Accounts` trait. fn parse_accounts(f: &syn::File) -> HashMap { f.items diff --git a/syn/src/parser/mod.rs b/syn/src/parser/mod.rs index dacea0c4..a611a2fc 100644 --- a/syn/src/parser/mod.rs +++ b/syn/src/parser/mod.rs @@ -1,4 +1,5 @@ pub mod accounts; +pub mod error; #[cfg(feature = "idl")] pub mod file; pub mod program; diff --git a/ts/src/error.ts b/ts/src/error.ts index 23d427d5..c06ac4c0 100644 --- a/ts/src/error.ts +++ b/ts/src/error.ts @@ -1 +1,12 @@ export class IdlError extends Error {} + +// An error from a user defined program. +export class ProgramError extends Error { + constructor(readonly code: number, readonly msg: string, ...params: any[]) { + super(...params); + } + + public toString(): string { + return this.msg; + } +} diff --git a/ts/src/idl.ts b/ts/src/idl.ts index 96b71dbc..1331c0ac 100644 --- a/ts/src/idl.ts +++ b/ts/src/idl.ts @@ -4,6 +4,7 @@ export type Idl = { instructions: IdlInstruction[]; accounts?: IdlTypeDef[]; types?: IdlTypeDef[]; + errors?: IdlErrorCode[]; }; export type IdlInstruction = { @@ -74,3 +75,9 @@ export type IdlTypeDefined = { type IdlEnumVariant = { // todo }; + +type IdlErrorCode = { + code: number; + name: string; + msg?: string; +}; diff --git a/ts/src/rpc.ts b/ts/src/rpc.ts index 86cead99..e4690eaf 100644 --- a/ts/src/rpc.ts +++ b/ts/src/rpc.ts @@ -9,7 +9,7 @@ import { } from "@solana/web3.js"; import { sha256 } from "crypto-hash"; import { Idl, IdlInstruction } from "./idl"; -import { IdlError } from "./error"; +import { IdlError, ProgramError } from "./error"; import Coder from "./coder"; import { getProvider } from "./"; @@ -95,11 +95,12 @@ export class RpcFactory { const rpcs: Rpcs = {}; const ixFns: Ixs = {}; const accountFns: Accounts = {}; + const idlErrors = parseIdlErrors(idl); idl.instructions.forEach((idlIx) => { // Function to create a raw `TransactionInstruction`. const ix = RpcFactory.buildIx(idlIx, coder, programId); // Function to invoke an RPC against a cluster. - const rpc = RpcFactory.buildRpc(idlIx, ix); + const rpc = RpcFactory.buildRpc(idlIx, ix, idlErrors); const name = camelCase(idlIx.name); rpcs[name] = rpc; @@ -175,7 +176,11 @@ export class RpcFactory { return ix; } - private static buildRpc(idlIx: IdlInstruction, ixFn: IxFn): RpcFn { + private static buildRpc( + idlIx: IdlInstruction, + ixFn: IxFn, + idlErrors: Map + ): RpcFn { const rpc = async (...args: any[]): Promise => { const [_, ctx] = splitArgsAndCtx(idlIx, [...args]); const tx = new Transaction(); @@ -187,15 +192,56 @@ export class RpcFactory { if (provider === null) { throw new Error("Provider not found"); } - - const txSig = await provider.send(tx, ctx.signers, ctx.options); - return txSig; + try { + const txSig = await provider.send(tx, ctx.signers, ctx.options); + return txSig; + } catch (err) { + let translatedErr = translateError(idlErrors, err); + if (err === null) { + throw err; + } + throw translatedErr; + } }; return rpc; } } +function translateError( + idlErrors: Map, + err: any +): Error | null { + // TODO: don't rely on the error string. web3.js should preserve the error + // code information instead of giving us an untyped string. + let components = err.toString().split("custom program error: "); + if (components.length === 2) { + try { + const errorCode = parseInt(components[1]); + let errorMsg = idlErrors.get(errorCode); + if (errorMsg === undefined) { + // Unexpected error code so just throw the untranslated error. + throw err; + } + return new ProgramError(errorCode, errorMsg); + } catch (parseErr) { + // Unable to parse the error. Just return the untranslated error. + return null; + } + } +} + +function parseIdlErrors(idl: Idl): Map { + const errors = new Map(); + if (idl.errors) { + idl.errors.forEach((e) => { + let msg = e.msg ?? e.name; + errors.set(e.code, msg); + }); + } + return errors; +} + function splitArgsAndCtx( idlIx: IdlInstruction, args: any[]