lang: Added generic support to Accounts (#496)

This commit is contained in:
Brett Etter 2021-07-08 15:14:39 -05:00 committed by GitHub
parent dd5f18271f
commit d612ffddc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 203 additions and 69 deletions

View File

@ -17,6 +17,7 @@ incremented for features.
* lang: Add fallback functions ([#457](https://github.com/project-serum/anchor/pull/457)).
* lang: Add feature flag for using the old state account discriminator. This is a temporary flag for those with programs built prior to v0.7.0 but want to use the latest Anchor version. Expect this to be removed in a future version ([#446](https://github.com/project-serum/anchor/pull/446)).
* lang: Add generic support to Accounts ([#496](https://github.com/project-serum/anchor/pull/496)).
### Breaking Changes

View File

@ -78,6 +78,7 @@ pub fn account(
let account_strct = parse_macro_input!(input as syn::ItemStruct);
let account_name = &account_strct.ident;
let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();
let discriminator: proc_macro2::TokenStream = {
// Namespace the discriminator to prevent collisions.
@ -103,12 +104,16 @@ pub fn account(
#[zero_copy]
#account_strct
unsafe impl anchor_lang::__private::bytemuck::Pod for #account_name {}
unsafe impl anchor_lang::__private::bytemuck::Zeroable for #account_name {}
#[automatically_derived]
unsafe impl #impl_gen anchor_lang::__private::bytemuck::Pod for #account_name #type_gen #where_clause {}
#[automatically_derived]
unsafe impl #impl_gen anchor_lang::__private::bytemuck::Zeroable for #account_name #type_gen #where_clause {}
impl anchor_lang::ZeroCopy for #account_name {}
#[automatically_derived]
impl #impl_gen anchor_lang::ZeroCopy for #account_name #type_gen #where_clause {}
impl anchor_lang::Discriminator for #account_name {
#[automatically_derived]
impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
fn discriminator() -> [u8; 8] {
#discriminator
}
@ -116,7 +121,8 @@ pub fn account(
// This trait is useful for clients deserializing accounts.
// It's expected on-chain programs deserialize via zero-copy.
impl anchor_lang::AccountDeserialize for #account_name {
#[automatically_derived]
impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
fn try_deserialize(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
if buf.len() < #discriminator.len() {
return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorNotFound.into());
@ -142,7 +148,8 @@ pub fn account(
#[derive(AnchorSerialize, AnchorDeserialize, Clone)]
#account_strct
impl anchor_lang::AccountSerialize for #account_name {
#[automatically_derived]
impl #impl_gen anchor_lang::AccountSerialize for #account_name #type_gen #where_clause {
fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> std::result::Result<(), ProgramError> {
writer.write_all(&#discriminator).map_err(|_| anchor_lang::__private::ErrorCode::AccountDidNotSerialize)?;
AnchorSerialize::serialize(
@ -154,7 +161,8 @@ pub fn account(
}
}
impl anchor_lang::AccountDeserialize for #account_name {
#[automatically_derived]
impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
fn try_deserialize(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
if buf.len() < #discriminator.len() {
return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorNotFound.into());
@ -173,7 +181,8 @@ pub fn account(
}
}
impl anchor_lang::Discriminator for #account_name {
#[automatically_derived]
impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
fn discriminator() -> [u8; 8] {
#discriminator
}
@ -206,6 +215,7 @@ pub fn associated(
) -> proc_macro::TokenStream {
let mut account_strct = parse_macro_input!(input as syn::ItemStruct);
let account_name = &account_strct.ident;
let (impl_gen, ty_gen, where_clause) = account_strct.generics.split_for_impl();
// Add a `__nonce: u8` field to the struct to hold the bump seed for
// the program dervied address.
@ -245,7 +255,8 @@ pub fn associated(
#[anchor_lang::account(#args)]
#account_strct
impl anchor_lang::Bump for #account_name {
#[automatically_derived]
impl #impl_gen anchor_lang::Bump for #account_name #ty_gen #where_clause {
fn seed(&self) -> u8 {
self.__nonce
}
@ -257,6 +268,7 @@ pub fn associated(
pub fn derive_zero_copy_accessor(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
let account_strct = parse_macro_input!(item as syn::ItemStruct);
let account_name = &account_strct.ident;
let (impl_gen, ty_gen, where_clause) = account_strct.generics.split_for_impl();
let fields = match &account_strct.fields {
syn::Fields::Named(n) => n,
@ -300,7 +312,8 @@ pub fn derive_zero_copy_accessor(item: proc_macro::TokenStream) -> proc_macro::T
})
.collect();
proc_macro::TokenStream::from(quote! {
impl #account_name {
#[automatically_derived]
impl #impl_gen #account_name #ty_gen #where_clause {
#(#methods)*
}
})

View File

@ -115,6 +115,7 @@ pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
#(#account_struct_fields),*
}
#[automatically_derived]
impl anchor_lang::ToAccountMetas for #name {
fn to_account_metas(&self, is_signer: Option<bool>) -> Vec<anchor_lang::solana_program::instruction::AccountMeta> {
let mut account_metas = vec![];

View File

@ -355,24 +355,24 @@ pub fn generate_constraint_associated_init(
)
}
fn parse_ty(f: &Field) -> (&syn::Ident, proc_macro2::TokenStream, bool) {
fn parse_ty(f: &Field) -> (&syn::TypePath, proc_macro2::TokenStream, bool) {
match &f.ty {
Ty::ProgramAccount(ty) => (
&ty.account_ident,
&ty.account_type_path,
quote! {
anchor_lang::ProgramAccount
},
false,
),
Ty::Loader(ty) => (
&ty.account_ident,
&ty.account_type_path,
quote! {
anchor_lang::Loader
},
true,
),
Ty::CpiAccount(ty) => (
&ty.account_ident,
&ty.account_type_path,
quote! {
anchor_lang::CpiAccount
},
@ -617,7 +617,7 @@ pub fn generate_constraint_state(f: &Field, c: &ConstraintState) -> proc_macro2:
let program_target = c.program_target.clone();
let ident = &f.ident;
let account_ty = match &f.ty {
Ty::CpiState(ty) => &ty.account_ident,
Ty::CpiState(ty) => &ty.account_type_path,
_ => panic!("Invalid state constraint"),
};
quote! {

View File

@ -1,11 +1,16 @@
use crate::codegen::accounts::generics;
use crate::codegen::accounts::{generics, ParsedGenerics};
use crate::{AccountField, AccountsStruct};
use quote::quote;
// Generates the `Exit` trait implementation.
pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
let name = &accs.ident;
let (combined_generics, trait_generics, strct_generics) = generics(accs);
let ParsedGenerics {
combined_generics,
trait_generics,
struct_generics,
where_clause,
} = generics(accs);
let on_save: Vec<proc_macro2::TokenStream> = accs
.fields
@ -39,7 +44,8 @@ pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
})
.collect();
quote! {
impl#combined_generics anchor_lang::AccountsExit#trait_generics for #name#strct_generics {
#[automatically_derived]
impl<#combined_generics> anchor_lang::AccountsExit<#trait_generics> for #name<#struct_generics> #where_clause{
fn exit(&self, program_id: &anchor_lang::solana_program::pubkey::Pubkey) -> anchor_lang::solana_program::entrypoint::ProgramResult {
#(#on_save)*
Ok(())

View File

@ -1,5 +1,9 @@
use crate::AccountsStruct;
use quote::quote;
use std::iter;
use syn::punctuated::Punctuated;
use syn::{ConstParam, LifetimeDef, Token, TypeParam};
use syn::{GenericParam, PredicateLifetime, WhereClause, WherePredicate};
mod __client_accounts;
mod constraints;
@ -26,18 +30,70 @@ pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
}
}
fn generics(
accs: &AccountsStruct,
) -> (
proc_macro2::TokenStream,
proc_macro2::TokenStream,
proc_macro2::TokenStream,
) {
match accs.generics.lt_token {
None => (quote! {<'info>}, quote! {<'info>}, quote! {}),
Some(_) => {
let g = &accs.generics;
(quote! {#g}, quote! {#g}, quote! {#g})
}
fn generics(accs: &AccountsStruct) -> ParsedGenerics {
let trait_lifetime = accs
.generics
.lifetimes()
.next()
.cloned()
.unwrap_or_else(|| syn::parse_str("'info").expect("Could not parse lifetime"));
let mut where_clause = accs.generics.where_clause.clone().unwrap_or(WhereClause {
where_token: Default::default(),
predicates: Default::default(),
});
for lifetime in accs.generics.lifetimes().map(|def| &def.lifetime) {
where_clause
.predicates
.push(WherePredicate::Lifetime(PredicateLifetime {
lifetime: lifetime.clone(),
colon_token: Default::default(),
bounds: iter::once(trait_lifetime.lifetime.clone()).collect(),
}))
}
let trait_lifetime = GenericParam::Lifetime(trait_lifetime);
ParsedGenerics {
combined_generics: if accs.generics.lifetimes().next().is_some() {
accs.generics.params.clone()
} else {
iter::once(trait_lifetime.clone())
.chain(accs.generics.params.clone())
.collect()
},
trait_generics: iter::once(trait_lifetime).collect(),
struct_generics: accs
.generics
.params
.clone()
.into_iter()
.map(|param: GenericParam| match param {
GenericParam::Const(ConstParam { ident, .. })
| GenericParam::Type(TypeParam { ident, .. }) => GenericParam::Type(TypeParam {
attrs: vec![],
ident,
colon_token: None,
bounds: Default::default(),
eq_token: None,
default: None,
}),
GenericParam::Lifetime(LifetimeDef { lifetime, .. }) => {
GenericParam::Lifetime(LifetimeDef {
attrs: vec![],
lifetime,
colon_token: None,
bounds: Default::default(),
})
}
})
.collect(),
where_clause,
}
}
struct ParsedGenerics {
pub combined_generics: Punctuated<GenericParam, Token![,]>,
pub trait_generics: Punctuated<GenericParam, Token![,]>,
pub struct_generics: Punctuated<GenericParam, Token![,]>,
pub where_clause: WhereClause,
}

View File

@ -1,11 +1,16 @@
use crate::codegen::accounts::generics;
use crate::codegen::accounts::{generics, ParsedGenerics};
use crate::{AccountField, AccountsStruct};
use quote::quote;
// Generates the `ToAccountInfos` trait implementation.
pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
let name = &accs.ident;
let (combined_generics, trait_generics, strct_generics) = generics(accs);
let ParsedGenerics {
combined_generics,
trait_generics,
struct_generics,
where_clause,
} = generics(accs);
let to_acc_infos: Vec<proc_macro2::TokenStream> = accs
.fields
@ -21,7 +26,8 @@ pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
})
.collect();
quote! {
impl#combined_generics anchor_lang::ToAccountInfos#trait_generics for #name#strct_generics {
#[automatically_derived]
impl<#combined_generics> anchor_lang::ToAccountInfos<#trait_generics> for #name <#struct_generics> #where_clause{
fn to_account_infos(&self) -> Vec<anchor_lang::solana_program::account_info::AccountInfo<'info>> {
let mut account_infos = vec![];

View File

@ -1,11 +1,9 @@
use crate::codegen::accounts::generics;
use crate::{AccountField, AccountsStruct};
use quote::quote;
// Generates the `ToAccountMetas` trait implementation.
pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
let name = &accs.ident;
let (combined_generics, _trait_generics, strct_generics) = generics(accs);
let to_acc_metas: Vec<proc_macro2::TokenStream> = accs
.fields
@ -26,8 +24,12 @@ pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
}
})
.collect();
let (impl_gen, ty_gen, where_clause) = accs.generics.split_for_impl();
quote! {
impl#combined_generics anchor_lang::ToAccountMetas for #name#strct_generics {
#[automatically_derived]
impl#impl_gen anchor_lang::ToAccountMetas for #name #ty_gen #where_clause{
fn to_account_metas(&self, is_signer: Option<bool>) -> Vec<anchor_lang::solana_program::instruction::AccountMeta> {
let mut account_metas = vec![];

View File

@ -1,4 +1,4 @@
use crate::codegen::accounts::{constraints, generics};
use crate::codegen::accounts::{constraints, generics, ParsedGenerics};
use crate::{AccountField, AccountsStruct, Field, SysvarTy, Ty};
use proc_macro2::TokenStream;
use quote::quote;
@ -7,7 +7,12 @@ use syn::Expr;
// Generates the `Accounts` trait implementation.
pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
let name = &accs.ident;
let (combined_generics, trait_generics, strct_generics) = generics(accs);
let ParsedGenerics {
combined_generics,
trait_generics,
struct_generics,
where_clause,
} = generics(accs);
// Deserialization for each field
let deser_fields: Vec<proc_macro2::TokenStream> = accs
@ -88,7 +93,8 @@ pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
};
quote! {
impl#combined_generics anchor_lang::Accounts#trait_generics for #name#strct_generics {
#[automatically_derived]
impl<#combined_generics> anchor_lang::Accounts<#trait_generics> for #name<#struct_generics> #where_clause {
#[inline(never)]
fn try_accounts(
program_id: &anchor_lang::solana_program::pubkey::Pubkey,
@ -133,31 +139,31 @@ fn typed_ident(field: &Field) -> TokenStream {
let ty = match &field.ty {
Ty::AccountInfo => quote! { AccountInfo },
Ty::ProgramState(ty) => {
let account = &ty.account_ident;
let account = &ty.account_type_path;
quote! {
ProgramState<#account>
}
}
Ty::CpiState(ty) => {
let account = &ty.account_ident;
let account = &ty.account_type_path;
quote! {
CpiState<#account>
}
}
Ty::ProgramAccount(ty) => {
let account = &ty.account_ident;
let account = &ty.account_type_path;
quote! {
ProgramAccount<#account>
}
}
Ty::Loader(ty) => {
let account = &ty.account_ident;
let account = &ty.account_type_path;
quote! {
Loader<#account>
}
}
Ty::CpiAccount(ty) => {
let account = &ty.account_ident;
let account = &ty.account_type_path;
quote! {
CpiAccount<#account>
}

View File

@ -12,7 +12,7 @@ use syn::spanned::Spanned;
use syn::token::Comma;
use syn::{
Expr, Generics, Ident, ImplItemMethod, ItemEnum, ItemFn, ItemImpl, ItemMod, ItemStruct, LitInt,
LitStr, PatType, Token,
LitStr, PatType, Token, TypePath,
};
pub mod codegen;
@ -198,30 +198,30 @@ pub enum SysvarTy {
#[derive(Debug, PartialEq)]
pub struct ProgramStateTy {
pub account_ident: Ident,
pub account_type_path: TypePath,
}
#[derive(Debug, PartialEq)]
pub struct CpiStateTy {
pub account_ident: Ident,
pub account_type_path: TypePath,
}
#[derive(Debug, PartialEq)]
pub struct ProgramAccountTy {
// The struct type of the account.
pub account_ident: Ident,
pub account_type_path: TypePath,
}
#[derive(Debug, PartialEq)]
pub struct CpiAccountTy {
// The struct type of the account.
pub account_ident: Ident,
pub account_type_path: TypePath,
}
#[derive(Debug, PartialEq)]
pub struct LoaderTy {
// The struct type of the account.
pub account_ident: Ident,
pub account_type_path: TypePath,
}
#[derive(Debug)]

View File

@ -118,30 +118,40 @@ fn ident_string(f: &syn::Field) -> ParseResult<String> {
fn parse_program_state(path: &syn::Path) -> ParseResult<ProgramStateTy> {
let account_ident = parse_account(path)?;
Ok(ProgramStateTy { account_ident })
Ok(ProgramStateTy {
account_type_path: account_ident,
})
}
fn parse_cpi_state(path: &syn::Path) -> ParseResult<CpiStateTy> {
let account_ident = parse_account(path)?;
Ok(CpiStateTy { account_ident })
Ok(CpiStateTy {
account_type_path: account_ident,
})
}
fn parse_cpi_account(path: &syn::Path) -> ParseResult<CpiAccountTy> {
let account_ident = parse_account(path)?;
Ok(CpiAccountTy { account_ident })
Ok(CpiAccountTy {
account_type_path: account_ident,
})
}
fn parse_program_account(path: &syn::Path) -> ParseResult<ProgramAccountTy> {
let account_ident = parse_account(path)?;
Ok(ProgramAccountTy { account_ident })
Ok(ProgramAccountTy {
account_type_path: account_ident,
})
}
fn parse_program_account_zero_copy(path: &syn::Path) -> ParseResult<LoaderTy> {
let account_ident = parse_account(path)?;
Ok(LoaderTy { account_ident })
Ok(LoaderTy {
account_type_path: account_ident,
})
}
fn parse_account(path: &syn::Path) -> ParseResult<syn::Ident> {
fn parse_account(path: &syn::Path) -> ParseResult<syn::TypePath> {
let segments = &path.segments[0];
match &segments.arguments {
syn::PathArguments::AngleBracketed(args) => {
@ -153,18 +163,7 @@ fn parse_account(path: &syn::Path) -> ParseResult<syn::Ident> {
));
}
match &args.args[1] {
syn::GenericArgument::Type(syn::Type::Path(ty_path)) => {
// TODO: allow segmented paths.
if ty_path.path.segments.len() != 1 {
return Err(ParseError::new(
ty_path.path.span(),
"segmented paths are not currently allowed",
));
}
let path_segment = &ty_path.path.segments[0];
Ok(path_segment.ident.clone())
}
syn::GenericArgument::Type(syn::Type::Path(ty_path)) => Ok(ty_path.clone()),
_ => Err(ParseError::new(
args.args[1].span(),
"first bracket argument must be a lifetime",

View File

@ -0,0 +1,44 @@
#![allow(dead_code)]
use anchor_lang::prelude::borsh::maybestd::io::Write;
use anchor_lang::prelude::*;
use borsh::{BorshDeserialize, BorshSerialize};
#[derive(Accounts)]
pub struct GenericsTest<'info, T, U, const N: usize>
where
T: AccountSerialize + AccountDeserialize + Clone,
U: BorshSerialize + BorshDeserialize + Default + Clone,
{
pub non_generic: AccountInfo<'info>,
pub generic: ProgramAccount<'info, T>,
pub const_generic: Loader<'info, Account<N>>,
pub associated: CpiAccount<'info, Associated<U>>,
}
#[account(zero_copy)]
pub struct Account<const N: usize> {
pub data: WrappedU8Array<N>,
}
#[associated]
#[derive(Default)]
pub struct Associated<T>
where
T: BorshDeserialize + BorshSerialize + Default,
{
pub data: T,
}
#[derive(Copy, Clone)]
pub struct WrappedU8Array<const N: usize>(u8);
impl<const N: usize> BorshSerialize for WrappedU8Array<N> {
fn serialize<W: Write>(&self, _writer: &mut W) -> borsh::maybestd::io::Result<()> {
todo!()
}
}
impl<const N: usize> BorshDeserialize for WrappedU8Array<N> {
fn deserialize(_buf: &mut &[u8]) -> borsh::maybestd::io::Result<Self> {
todo!()
}
}