diff --git a/CHANGELOG.md b/CHANGELOG.md index 7cc555036..e60fb48c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The minor version will be incremented upon a breaking change and the patch versi - lang: Add `Account` utility type to get accounts from bytes ([#3091](https://github.com/coral-xyz/anchor/pull/3091)). - client: Add option to pass in mock rpc client when using anchor_client ([#3053](https://github.com/coral-xyz/anchor/pull/3053)). - lang: Get discriminator length dynamically ([#3101](https://github.com/coral-xyz/anchor/pull/3101)). +- lang: Add non-8-byte discriminator support in `declare_program!` ([#3103](https://github.com/coral-xyz/anchor/pull/3103)). ### Fixes diff --git a/lang/attribute/program/src/declare_program/mods/accounts.rs b/lang/attribute/program/src/declare_program/mods/accounts.rs index a73d34f03..87fdb02fc 100644 --- a/lang/attribute/program/src/declare_program/mods/accounts.rs +++ b/lang/attribute/program/src/declare_program/mods/accounts.rs @@ -21,7 +21,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream { return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into()); } - let given_disc = &buf[..8]; + let given_disc = &buf[..#discriminator.len()]; if &#discriminator != given_disc { return Err( anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch) @@ -51,7 +51,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream { #try_deserialize fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result { - let mut data: &[u8] = &buf[8..]; + let mut data: &[u8] = &buf[#discriminator.len()..]; AnchorDeserialize::deserialize(&mut data) .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into()) } @@ -75,7 +75,7 @@ pub fn gen_accounts_mod(idl: &Idl) -> proc_macro2::TokenStream { #try_deserialize fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result { - let data: &[u8] = &buf[8..]; + let data: &[u8] = &buf[#discriminator.len()..]; let account = anchor_lang::__private::bytemuck::from_bytes(data); Ok(*account) } diff --git a/lang/attribute/program/src/declare_program/mods/internal.rs b/lang/attribute/program/src/declare_program/mods/internal.rs index cfa59a2f5..7152f2670 100644 --- a/lang/attribute/program/src/declare_program/mods/internal.rs +++ b/lang/attribute/program/src/declare_program/mods/internal.rs @@ -46,15 +46,11 @@ fn gen_internal_args_mod(idl: &Idl) -> proc_macro2::TokenStream { } }; - let impl_discriminator = if ix.discriminator.len() == 8 { - let discriminator = gen_discriminator(&ix.discriminator); - quote! { - impl anchor_lang::Discriminator for #ix_struct_name { - const DISCRIMINATOR: &'static [u8] = &#discriminator; - } + let discriminator = gen_discriminator(&ix.discriminator); + let impl_discriminator = quote! { + impl anchor_lang::Discriminator for #ix_struct_name { + const DISCRIMINATOR: &'static [u8] = &#discriminator; } - } else { - quote! {} }; let impl_ix_data = quote! { diff --git a/lang/attribute/program/src/declare_program/mods/utils.rs b/lang/attribute/program/src/declare_program/mods/utils.rs index 9ba325c29..4629ddf10 100644 --- a/lang/attribute/program/src/declare_program/mods/utils.rs +++ b/lang/attribute/program/src/declare_program/mods/utils.rs @@ -24,15 +24,17 @@ fn gen_account(idl: &Idl) -> proc_macro2::TokenStream { .iter() .map(|acc| format_ident!("{}", acc.name)) .map(|name| quote! { #name(#name) }); - let match_arms = idl.accounts.iter().map(|acc| { - let disc = gen_discriminator(&acc.discriminator); + let if_statements = idl.accounts.iter().map(|acc| { let name = format_ident!("{}", acc.name); - let account = quote! { - #name::try_from_slice(&value[8..]) - .map(Self::#name) - .map_err(Into::into) - }; - quote! { #disc => #account } + let disc = gen_discriminator(&acc.discriminator); + let disc_len = acc.discriminator.len(); + quote! { + if value.starts_with(&#disc) { + return #name::try_from_slice(&value[#disc_len..]) + .map(Self::#name) + .map_err(Into::into) + } + } }); quote! { @@ -57,14 +59,8 @@ fn gen_account(idl: &Idl) -> proc_macro2::TokenStream { type Error = anchor_lang::error::Error; fn try_from(value: &[u8]) -> Result { - if value.len() < 8 { - return Err(ProgramError::InvalidArgument.into()); - } - - match &value[..8] { - #(#match_arms,)* - _ => Err(ProgramError::InvalidArgument.into()), - } + #(#if_statements)* + Err(ProgramError::InvalidArgument.into()) } } } @@ -76,15 +72,17 @@ fn gen_event(idl: &Idl) -> proc_macro2::TokenStream { .iter() .map(|ev| format_ident!("{}", ev.name)) .map(|name| quote! { #name(#name) }); - let match_arms = idl.events.iter().map(|ev| { - let disc = gen_discriminator(&ev.discriminator); + let if_statements = idl.events.iter().map(|ev| { let name = format_ident!("{}", ev.name); - let event = quote! { - #name::try_from_slice(&value[8..]) - .map(Self::#name) - .map_err(Into::into) - }; - quote! { #disc => #event } + let disc = gen_discriminator(&ev.discriminator); + let disc_len = ev.discriminator.len(); + quote! { + if value.starts_with(&#disc) { + return #name::try_from_slice(&value[#disc_len..]) + .map(Self::#name) + .map_err(Into::into) + } + } }); quote! { @@ -109,14 +107,8 @@ fn gen_event(idl: &Idl) -> proc_macro2::TokenStream { type Error = anchor_lang::error::Error; fn try_from(value: &[u8]) -> Result { - if value.len() < 8 { - return Err(ProgramError::InvalidArgument.into()); - } - - match &value[..8] { - #(#match_arms,)* - _ => Err(ProgramError::InvalidArgument.into()), - } + #(#if_statements)* + Err(ProgramError::InvalidArgument.into()) } } }