Error handling

This commit is contained in:
Armani Ferrante 2021-01-15 23:05:26 -08:00
parent 41b25afed0
commit e636cf9721
22 changed files with 434 additions and 9 deletions

11
Cargo.lock generated
View File

@ -65,6 +65,16 @@ dependencies = [
"syn 1.0.57",
]
[[package]]
name = "anchor-attribute-error"
version = "0.0.0-alpha.0"
dependencies = [
"anchor-syn",
"proc-macro2 1.0.24",
"quote 1.0.8",
"syn 1.0.57",
]
[[package]]
name = "anchor-attribute-program"
version = "0.0.0-alpha.0"
@ -113,6 +123,7 @@ version = "0.0.0-alpha.0"
dependencies = [
"anchor-attribute-access-control",
"anchor-attribute-account",
"anchor-attribute-error",
"anchor-attribute-program",
"anchor-derive-accounts",
"serum-borsh",

View File

@ -14,6 +14,7 @@ default = []
[dependencies]
anchor-attribute-access-control = { path = "./attribute/access-control", version = "0.0.0-alpha.0" }
anchor-attribute-account = { path = "./attribute/account", version = "0.0.0-alpha.0" }
anchor-attribute-error = { path = "./attribute/error" }
anchor-attribute-program = { path = "./attribute/program", version = "0.0.0-alpha.0" }
anchor-derive-accounts = { path = "./derive/accounts", version = "0.0.0-alpha.0" }
serum-borsh = { version = "0.8.0-serum.1", features = ["serum-program"] }

View File

@ -0,0 +1,17 @@
[package]
name = "anchor-attribute-error"
version = "0.0.0-alpha.0"
authors = ["Serum Foundation <foundation@projectserum.com>"]
repository = "https://github.com/project-serum/anchor"
license = "Apache-2.0"
description = "Anchor attribute macro for creating error types"
edition = "2018"
[lib]
proc-macro = true
[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
syn = { version = "=1.0.57", features = ["full"] }
anchor-syn = { path = "../../syn" }

View File

@ -0,0 +1,16 @@
extern crate proc_macro;
use anchor_syn::codegen::error as error_codegen;
use anchor_syn::parser::error as error_parser;
use syn::parse_macro_input;
/// Generates an error type from an error code enum.
#[proc_macro_attribute]
pub fn error(
_args: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let mut error_enum = parse_macro_input!(input as syn::ItemEnum);
let error = error_codegen::generate(error_parser::parse(&mut error_enum));
proc_macro::TokenStream::from(error)
}

View File

@ -0,0 +1,2 @@
cluster = "localnet"
wallet = "/home/armaniferrante/.config/solana/id.json"

View File

@ -0,0 +1,4 @@
[workspace]
members = [
"programs/*"
]

View File

@ -0,0 +1,17 @@
[package]
name = "errors"
version = "0.1.0"
description = "Created with Anchor"
edition = "2018"
[lib]
crate-type = ["cdylib", "lib"]
name = "errors"
[features]
no-entrypoint = []
cpi = ["no-entrypoint"]
[dependencies]
# anchor-lang = { git = "https://github.com/project-serum/anchor", features = ["derive"] }
anchor-lang = { path = "/home/armaniferrante/Documents/code/src/github.com/project-serum/anchor", features = ["derive"] }

View File

@ -0,0 +1,2 @@
[target.bpfel-unknown-unknown.dependencies.std]
features = []

View File

@ -0,0 +1,36 @@
#![feature(proc_macro_hygiene)]
use anchor_lang::prelude::*;
#[program]
mod errors {
use super::*;
pub fn hello(ctx: Context<Hello>) -> Result<(), Error> {
Err(MyError::Hello.into())
}
pub fn hello_no_msg(ctx: Context<HelloNoMsg>) -> Result<(), Error> {
Err(MyError::HelloNoMsg.into())
}
pub fn hello_next(ctx: Context<HelloNext>) -> Result<(), Error> {
Err(MyError::HelloNext.into())
}
}
#[derive(Accounts)]
pub struct Hello {}
#[derive(Accounts)]
pub struct HelloNoMsg {}
#[derive(Accounts)]
pub struct HelloNext {}
#[error]
pub enum MyError {
#[msg("This is an error message clients will automatically display")]
Hello,
HelloNoMsg = 123,
HelloNext,
}

View File

@ -0,0 +1,47 @@
const assert = require("assert");
//const anchor = require('@project-serum/anchor');
const anchor = require("/home/armaniferrante/Documents/code/src/github.com/project-serum/anchor/ts");
describe("errors", () => {
// Configure the client to use the local cluster.
anchor.setProvider(anchor.Provider.local());
const program = anchor.workspace.Errors;
it("Emits a Hello error", async () => {
try {
const tx = await program.rpc.hello();
assert.ok(false);
} catch (err) {
const errMsg =
"This is an error message clients will automatically display";
assert.equal(err.toString(), errMsg);
assert.equal(err.msg, errMsg);
assert.equal(err.code, 100);
}
});
it("Emits a HelloNoMsg error", async () => {
try {
const tx = await program.rpc.helloNoMsg();
assert.ok(false);
} catch (err) {
const errMsg = "HelloNoMsg";
assert.equal(err.toString(), errMsg);
assert.equal(err.msg, errMsg);
assert.equal(err.code, 100 + 123);
}
});
it("Emits a HelloNext error", async () => {
try {
const tx = await program.rpc.helloNext();
assert.ok(false);
} catch (err) {
const errMsg = "HelloNext";
assert.equal(err.toString(), errMsg);
assert.equal(err.msg, errMsg);
assert.equal(err.code, 100 + 124);
}
});
});

32
src/error.rs Normal file
View File

@ -0,0 +1,32 @@
use solana_program::program_error::ProgramError;
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error(transparent)]
ProgramError(#[from] ProgramError),
#[error("{0:?}")]
ErrorCode(#[from] ErrorCode),
}
#[derive(Debug, Clone, Copy)]
#[repr(u32)]
pub enum ErrorCode {
WrongSerialization = 1,
}
impl std::fmt::Display for ErrorCode {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
<Self as std::fmt::Debug>::fmt(self, fmt)
}
}
impl std::error::Error for ErrorCode {}
impl std::convert::From<Error> for ProgramError {
fn from(e: Error) -> ProgramError {
match e {
Error::ProgramError(e) => e,
Error::ErrorCode(c) => ProgramError::Custom(c as u32),
}
}
}

View File

@ -30,6 +30,7 @@ use std::io::Write;
mod account_info;
mod context;
mod cpi_account;
mod error;
mod program_account;
mod sysvar;
@ -39,10 +40,12 @@ pub use crate::program_account::ProgramAccount;
pub use crate::sysvar::Sysvar;
pub use anchor_attribute_access_control::access_control;
pub use anchor_attribute_account::account;
pub use anchor_attribute_error::error;
pub use anchor_attribute_program::program;
pub use anchor_derive_accounts::Accounts;
/// Default serialization format for anchor instructions and accounts.
pub use borsh::{BorshDeserialize as AnchorDeserialize, BorshSerialize as AnchorSerialize};
pub use error::Error;
pub use solana_program;
/// A data structure of accounts that can be deserialized from the input
@ -115,7 +118,7 @@ pub trait AccountDeserialize: Sized {
/// All programs should include it via `anchor_lang::prelude::*;`.
pub mod prelude {
pub use super::{
access_control, account, program, AccountDeserialize, AccountSerialize, Accounts,
access_control, account, error, program, AccountDeserialize, AccountSerialize, Accounts,
AccountsInit, AnchorDeserialize, AnchorSerialize, Context, CpiAccount, CpiContext,
ProgramAccount, Sysvar, ToAccountInfo, ToAccountInfos, ToAccountMetas,
};
@ -138,4 +141,5 @@ pub mod prelude {
pub use solana_program::sysvar::slot_history::SlotHistory;
pub use solana_program::sysvar::stake_history::StakeHistory;
pub use solana_program::sysvar::Sysvar as SolanaSysvar;
pub use thiserror;
}

39
syn/src/codegen/error.rs Normal file
View File

@ -0,0 +1,39 @@
use crate::Error;
use quote::quote;
pub fn generate(error: Error) -> proc_macro2::TokenStream {
let error_enum = error.raw_enum;
let enum_name = &error.ident;
quote! {
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error(transparent)]
ProgramError(#[from] ProgramError),
#[error("{0:?}")]
ErrorCode(#[from] #enum_name),
}
#[derive(Debug, Clone, Copy)]
#[repr(u32)]
#error_enum
impl std::fmt::Display for #enum_name {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
<Self as std::fmt::Debug>::fmt(self, fmt)
}
}
impl std::error::Error for #enum_name {}
impl std::convert::From<Error> for ProgramError {
fn from(e: Error) -> ProgramError {
// Errors 0-100 are reserved for the framework.
let error_offset = 100u32;
match e {
Error::ProgramError(e) => e,
Error::ErrorCode(c) => ProgramError::Custom(c as u32 + error_offset),
}
}
}
}
}

View File

@ -1,2 +1,3 @@
pub mod accounts;
pub mod error;
pub mod program;

View File

@ -10,6 +10,8 @@ pub struct Idl {
#[serde(skip_serializing_if = "Vec::is_empty", default)]
pub types: Vec<IdlTypeDef>,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub errors: Option<Vec<IdlErrorCode>>,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub metadata: Option<serde_json::Value>,
}
@ -132,3 +134,11 @@ impl std::str::FromStr for IdlType {
Ok(r)
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct IdlErrorCode {
pub code: u32,
pub name: String,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub msg: Option<String>,
}

View File

@ -245,3 +245,17 @@ pub enum ConstraintRentExempt {
Enforce,
Skip,
}
#[derive(Debug)]
pub struct Error {
pub raw_enum: syn::ItemEnum,
pub ident: syn::Ident,
pub codes: Vec<ErrorCode>,
}
#[derive(Debug)]
pub struct ErrorCode {
pub id: u32,
pub ident: syn::Ident,
pub msg: Option<String>,
}

69
syn/src/parser/error.rs Normal file
View File

@ -0,0 +1,69 @@
use crate::{Error, ErrorCode};
// Removes any internal #[msg] attributes, as they are inert.
pub fn parse(error_enum: &mut syn::ItemEnum) -> Error {
let ident = error_enum.ident.clone();
let mut last_discriminant = 0;
let codes: Vec<ErrorCode> = error_enum
.variants
.iter_mut()
.map(|variant: &mut syn::Variant| {
let msg = parse_error_attribute(variant);
let ident = variant.ident.clone();
let id = match &variant.discriminant {
None => last_discriminant,
Some((_, disc)) => match disc {
syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
syn::Lit::Int(int) => {
int.base10_parse::<u32>().expect("Must be a base 10 number")
}
_ => panic!("Invalid error discriminant"),
},
_ => panic!("Invalid error discriminant"),
},
};
last_discriminant = id + 1;
// Remove any attributes on the error variant.
variant.attrs = vec![];
ErrorCode { id, ident, msg }
})
.collect();
Error {
raw_enum: error_enum.clone(),
ident,
codes,
}
}
fn parse_error_attribute(variant: &syn::Variant) -> Option<String> {
let attrs = &variant.attrs;
match attrs.len() {
0 => None,
1 => {
let attr = &attrs[0];
let attr_str = attr.path.segments[0].ident.to_string();
if &attr_str != "msg" {
panic!("Use msg to specify error strings");
}
let mut tts = attr.tokens.clone().into_iter();
let g_stream = match tts.next().expect("Must have a token group") {
proc_macro2::TokenTree::Group(g) => g.stream(),
_ => panic!("Invalid syntax"),
};
let msg = match g_stream.into_iter().next() {
None => panic!("Must specify a message string"),
Some(msg) => msg.to_string().replace("\"", ""),
};
Some(msg)
}
_ => {
panic!("Too many attributes found. Use `msg` to specify error strings");
}
}
}

View File

@ -1,6 +1,5 @@
use crate::idl::*;
use crate::parser::accounts;
use crate::parser::program;
use crate::parser::{accounts, error, program};
use crate::AccountsStruct;
use anyhow::Result;
use heck::MixedCase;
@ -22,6 +21,17 @@ pub fn parse(filename: impl AsRef<Path>) -> Result<Idl> {
let f = syn::parse_file(&src).expect("Unable to parse file");
let p = program::parse(parse_program_mod(&f));
let errors = parse_error_enum(&f).map(|mut e| {
error::parse(&mut e)
.codes
.iter()
.map(|code| IdlErrorCode {
code: 100 + code.id,
name: code.ident.to_string(),
msg: code.msg.clone(),
})
.collect::<Vec<IdlErrorCode>>()
});
let accs = parse_accounts(&f);
@ -83,6 +93,7 @@ pub fn parse(filename: impl AsRef<Path>) -> Result<Idl> {
instructions,
types,
accounts,
errors,
metadata: None,
})
}
@ -117,6 +128,33 @@ fn parse_program_mod(f: &syn::File) -> syn::ItemMod {
mods[0].clone()
}
fn parse_error_enum(f: &syn::File) -> Option<syn::ItemEnum> {
f.items
.iter()
.filter_map(|i| match i {
syn::Item::Enum(item_enum) => {
let attrs = item_enum
.attrs
.iter()
.filter_map(|attr| {
let segment = attr.path.segments.last().unwrap();
if segment.ident.to_string() == "error" {
return Some(attr);
}
None
})
.collect::<Vec<_>>();
match attrs.len() {
0 => None,
1 => Some(item_enum),
_ => panic!("Invalid syntax: one error attribute allowed"),
}
}
_ => None,
})
.next()
.cloned()
}
// Parse all structs implementing the `Accounts` trait.
fn parse_accounts(f: &syn::File) -> HashMap<String, AccountsStruct> {
f.items

View File

@ -1,4 +1,5 @@
pub mod accounts;
pub mod error;
#[cfg(feature = "idl")]
pub mod file;
pub mod program;

View File

@ -1 +1,12 @@
export class IdlError extends Error {}
// An error from a user defined program.
export class ProgramError extends Error {
constructor(readonly code: number, readonly msg: string, ...params: any[]) {
super(...params);
}
public toString(): string {
return this.msg;
}
}

View File

@ -4,6 +4,7 @@ export type Idl = {
instructions: IdlInstruction[];
accounts?: IdlTypeDef[];
types?: IdlTypeDef[];
errors?: IdlErrorCode[];
};
export type IdlInstruction = {
@ -74,3 +75,9 @@ export type IdlTypeDefined = {
type IdlEnumVariant = {
// todo
};
type IdlErrorCode = {
code: number;
name: string;
msg?: string;
};

View File

@ -9,7 +9,7 @@ import {
} from "@solana/web3.js";
import { sha256 } from "crypto-hash";
import { Idl, IdlInstruction } from "./idl";
import { IdlError } from "./error";
import { IdlError, ProgramError } from "./error";
import Coder from "./coder";
import { getProvider } from "./";
@ -95,11 +95,12 @@ export class RpcFactory {
const rpcs: Rpcs = {};
const ixFns: Ixs = {};
const accountFns: Accounts = {};
const idlErrors = parseIdlErrors(idl);
idl.instructions.forEach((idlIx) => {
// Function to create a raw `TransactionInstruction`.
const ix = RpcFactory.buildIx(idlIx, coder, programId);
// Function to invoke an RPC against a cluster.
const rpc = RpcFactory.buildRpc(idlIx, ix);
const rpc = RpcFactory.buildRpc(idlIx, ix, idlErrors);
const name = camelCase(idlIx.name);
rpcs[name] = rpc;
@ -175,7 +176,11 @@ export class RpcFactory {
return ix;
}
private static buildRpc(idlIx: IdlInstruction, ixFn: IxFn): RpcFn {
private static buildRpc(
idlIx: IdlInstruction,
ixFn: IxFn,
idlErrors: Map<number, string>
): RpcFn {
const rpc = async (...args: any[]): Promise<TransactionSignature> => {
const [_, ctx] = splitArgsAndCtx(idlIx, [...args]);
const tx = new Transaction();
@ -187,15 +192,56 @@ export class RpcFactory {
if (provider === null) {
throw new Error("Provider not found");
}
const txSig = await provider.send(tx, ctx.signers, ctx.options);
return txSig;
try {
const txSig = await provider.send(tx, ctx.signers, ctx.options);
return txSig;
} catch (err) {
let translatedErr = translateError(idlErrors, err);
if (err === null) {
throw err;
}
throw translatedErr;
}
};
return rpc;
}
}
function translateError(
idlErrors: Map<number, string>,
err: any
): Error | null {
// TODO: don't rely on the error string. web3.js should preserve the error
// code information instead of giving us an untyped string.
let components = err.toString().split("custom program error: ");
if (components.length === 2) {
try {
const errorCode = parseInt(components[1]);
let errorMsg = idlErrors.get(errorCode);
if (errorMsg === undefined) {
// Unexpected error code so just throw the untranslated error.
throw err;
}
return new ProgramError(errorCode, errorMsg);
} catch (parseErr) {
// Unable to parse the error. Just return the untranslated error.
return null;
}
}
}
function parseIdlErrors(idl: Idl): Map<number, string> {
const errors = new Map();
if (idl.errors) {
idl.errors.forEach((e) => {
let msg = e.msg ?? e.name;
errors.set(e.code, msg);
});
}
return errors;
}
function splitArgsAndCtx(
idlIx: IdlInstruction,
args: any[]