lang: Add non-8-byte discriminator support in `declare_program!` (#3103)

This commit is contained in:
acheron 2024-07-22 15:27:28 +02:00 committed by GitHub
parent ba33d5e974
commit e5bed20736
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 32 additions and 43 deletions

View File

@ -19,6 +19,7 @@ The minor version will be incremented upon a breaking change and the patch versi
- lang: Add `Account` utility type to get accounts from bytes ([#3091](https://github.com/coral-xyz/anchor/pull/3091)). - lang: Add `Account` utility type to get accounts from bytes ([#3091](https://github.com/coral-xyz/anchor/pull/3091)).
- client: Add option to pass in mock rpc client when using anchor_client ([#3053](https://github.com/coral-xyz/anchor/pull/3053)). - client: Add option to pass in mock rpc client when using anchor_client ([#3053](https://github.com/coral-xyz/anchor/pull/3053)).
- lang: Get discriminator length dynamically ([#3101](https://github.com/coral-xyz/anchor/pull/3101)). - lang: Get discriminator length dynamically ([#3101](https://github.com/coral-xyz/anchor/pull/3101)).
- lang: Add non-8-byte discriminator support in `declare_program!` ([#3103](https://github.com/coral-xyz/anchor/pull/3103)).
### Fixes ### Fixes

View File

@ -21,7 +21,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into()); return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
} }
let given_disc = &buf[..8]; let given_disc = &buf[..#discriminator.len()];
if &#discriminator != given_disc { if &#discriminator != given_disc {
return Err( return Err(
anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch) anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch)
@ -51,7 +51,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
#try_deserialize #try_deserialize
fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> { fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
let mut data: &[u8] = &buf[8..]; let mut data: &[u8] = &buf[#discriminator.len()..];
AnchorDeserialize::deserialize(&mut data) AnchorDeserialize::deserialize(&mut data)
.map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into()) .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
} }
@ -75,7 +75,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream {
#try_deserialize #try_deserialize
fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> { fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
let data: &[u8] = &buf[8..]; let data: &[u8] = &buf[#discriminator.len()..];
let account = anchor_lang::__private::bytemuck::from_bytes(data); let account = anchor_lang::__private::bytemuck::from_bytes(data);
Ok(*account) Ok(*account)
} }

View File

@ -46,15 +46,11 @@ fn gen_internal_args_mod(idl: &Idl) -> proc_macro2::TokenStream {
} }
}; };
let impl_discriminator = if ix.discriminator.len() == 8 { let discriminator = gen_discriminator(&ix.discriminator);
let discriminator = gen_discriminator(&ix.discriminator); let impl_discriminator = quote! {
quote! { impl anchor_lang::Discriminator for #ix_struct_name {
impl anchor_lang::Discriminator for #ix_struct_name { const DISCRIMINATOR: &'static [u8] = &#discriminator;
const DISCRIMINATOR: &'static [u8] = &#discriminator;
}
} }
} else {
quote! {}
}; };
let impl_ix_data = quote! { let impl_ix_data = quote! {

View File

@ -24,15 +24,17 @@ fn gen_account(idl: &Idl) -> proc_macro2::TokenStream {
.iter() .iter()
.map(|acc| format_ident!("{}", acc.name)) .map(|acc| format_ident!("{}", acc.name))
.map(|name| quote! { #name(#name) }); .map(|name| quote! { #name(#name) });
let match_arms = idl.accounts.iter().map(|acc| { let if_statements = idl.accounts.iter().map(|acc| {
let disc = gen_discriminator(&acc.discriminator);
let name = format_ident!("{}", acc.name); let name = format_ident!("{}", acc.name);
let account = quote! { let disc = gen_discriminator(&acc.discriminator);
#name::try_from_slice(&value[8..]) let disc_len = acc.discriminator.len();
.map(Self::#name) quote! {
.map_err(Into::into) if value.starts_with(&#disc) {
}; return #name::try_from_slice(&value[#disc_len..])
quote! { #disc => #account } .map(Self::#name)
.map_err(Into::into)
}
}
}); });
quote! { quote! {
@ -57,14 +59,8 @@ fn gen_account(idl: &Idl) -> proc_macro2::TokenStream {
type Error = anchor_lang::error::Error; type Error = anchor_lang::error::Error;
fn try_from(value: &[u8]) -> Result<Self> { fn try_from(value: &[u8]) -> Result<Self> {
if value.len() < 8 { #(#if_statements)*
return Err(ProgramError::InvalidArgument.into()); Err(ProgramError::InvalidArgument.into())
}
match &value[..8] {
#(#match_arms,)*
_ => Err(ProgramError::InvalidArgument.into()),
}
} }
} }
} }
@ -76,15 +72,17 @@ fn gen_event(idl: &Idl) -> proc_macro2::TokenStream {
.iter() .iter()
.map(|ev| format_ident!("{}", ev.name)) .map(|ev| format_ident!("{}", ev.name))
.map(|name| quote! { #name(#name) }); .map(|name| quote! { #name(#name) });
let match_arms = idl.events.iter().map(|ev| { let if_statements = idl.events.iter().map(|ev| {
let disc = gen_discriminator(&ev.discriminator);
let name = format_ident!("{}", ev.name); let name = format_ident!("{}", ev.name);
let event = quote! { let disc = gen_discriminator(&ev.discriminator);
#name::try_from_slice(&value[8..]) let disc_len = ev.discriminator.len();
.map(Self::#name) quote! {
.map_err(Into::into) if value.starts_with(&#disc) {
}; return #name::try_from_slice(&value[#disc_len..])
quote! { #disc => #event } .map(Self::#name)
.map_err(Into::into)
}
}
}); });
quote! { quote! {
@ -109,14 +107,8 @@ fn gen_event(idl: &Idl) -> proc_macro2::TokenStream {
type Error = anchor_lang::error::Error; type Error = anchor_lang::error::Error;
fn try_from(value: &[u8]) -> Result<Self> { fn try_from(value: &[u8]) -> Result<Self> {
if value.len() < 8 { #(#if_statements)*
return Err(ProgramError::InvalidArgument.into()); Err(ProgramError::InvalidArgument.into())
}
match &value[..8] {
#(#match_arms,)*
_ => Err(ProgramError::InvalidArgument.into()),
}
} }
} }
} }