From 9f3ea48b7b00bddc51260f442e20a7ddf972b979 Mon Sep 17 00:00:00 2001 From: Armani Ferrante Date: Tue, 8 Feb 2022 13:32:23 -0500 Subject: [PATCH] lang, ts: account versioning --- CHANGELOG.md | 5 + Cargo.lock | 2 +- client/src/lib.rs | 10 ++ lang/attribute/account/src/lib.rs | 88 ++++++++++----- lang/attribute/event/src/lib.rs | 41 +++++-- lang/src/accounts/account.rs | 12 +- lang/src/accounts/account_loader.rs | 62 ++++------- lang/src/accounts/loader.rs | 55 ++++----- lang/src/accounts/program_account.rs | 17 ++- lang/src/lib.rs | 3 + lang/syn/src/codegen/accounts/constraints.rs | 90 ++++++++++++--- lang/syn/src/codegen/program/handlers.rs | 44 +++++++- lang/syn/src/idl/file.rs | 1 + lang/syn/src/idl/mod.rs | 4 + lang/syn/src/lib.rs | 24 +++- tests/chat/programs/chat/src/lib.rs | 2 +- tests/misc/programs/misc/src/lib.rs | 4 +- tests/zero-copy/programs/zero-copy/src/lib.rs | 4 +- .../zero-copy/tests/compute_unit_test.rs | 5 +- ts/src/coder/borsh/accounts.ts | 104 ++++++++++++++---- ts/src/coder/borsh/event.ts | 35 +++++- ts/src/coder/borsh/index.ts | 2 +- ts/src/coder/borsh/state.ts | 17 ++- ts/src/coder/index.ts | 4 +- ts/src/coder/spl-token/accounts.ts | 8 +- ts/src/idl.ts | 1 + ts/src/program/namespace/account.ts | 25 +++-- ts/src/program/namespace/state.ts | 14 ++- ts/src/spl/token.ts | 2 + ts/src/utils/features.ts | 6 +- ts/tests/events.spec.ts | 1 + ts/tests/transaction.spec.ts | 1 + 32 files changed, 506 insertions(+), 187 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a9ea947..7af3915d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,11 @@ incremented for features. ## [Unreleased] +### Breaking + +* ts: `BorshAccountsCoder.accountDiscriminator` method has been replaced with `BorshAccountHeader.discriminator` ([#]()). +* lang, ts: 8 byte account discriminator has been replaced with a versioned account header ([#]()). + ## [0.21.0] - 2022-02-07 ### Fixes diff --git a/Cargo.lock b/Cargo.lock index a9a38fab..edb865ee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -303,7 +303,7 @@ checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" [[package]] name = "avm" -version = "0.20.1" +version = "0.21.0" dependencies = [ "anyhow", "cfg-if 1.0.0", diff --git a/client/src/lib.rs b/client/src/lib.rs index 65da3c5f..8ea55f6c 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -291,12 +291,22 @@ fn handle_program_log( }; let mut slice: &[u8] = &borsh_bytes[..]; + + #[cfg(feature = "deprecated-layout")] let disc: [u8; 8] = { let mut disc = [0; 8]; disc.copy_from_slice(&borsh_bytes[..8]); slice = &slice[8..]; disc }; + #[cfg(not(feature = "deprecated-layout"))] + let disc: [u8; 4] = { + let mut disc = [0; 4]; + disc.copy_from_slice(&borsh_bytes[2..6]); + slice = &slice[8..]; + disc + }; + let mut event = None; if disc == T::discriminator() { let e: T = anchor_lang::AnchorDeserialize::deserialize(&mut slice) diff --git a/lang/attribute/account/src/lib.rs b/lang/attribute/account/src/lib.rs index 4ea13811..ff169eca 100644 --- a/lang/attribute/account/src/lib.rs +++ b/lang/attribute/account/src/lib.rs @@ -88,24 +88,6 @@ pub fn account( let account_name = &account_strct.ident; let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl(); - let discriminator: proc_macro2::TokenStream = { - // Namespace the discriminator to prevent collisions. - let discriminator_preimage = { - // For now, zero copy accounts can't be namespaced. - if namespace.is_empty() { - format!("account:{}", account_name) - } else { - format!("{}:{}", namespace, account_name) - } - }; - - let mut discriminator = [0u8; 8]; - discriminator.copy_from_slice( - &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8], - ); - format!("{:?}", discriminator).parse().unwrap() - }; - let owner_impl = { if namespace.is_empty() { quote! { @@ -121,6 +103,60 @@ pub fn account( } }; + let discriminator: proc_macro2::TokenStream = { + // Namespace the discriminator to prevent collisions. + let discriminator_preimage = { + // For now, zero copy accounts can't be namespaced. + if namespace.is_empty() { + format!("account:{}", account_name) + } else { + format!("{}:{}", namespace, account_name) + } + }; + + if cfg!(feature = "deprecated-layout") { + let mut discriminator = [0u8; 8]; + discriminator.copy_from_slice( + &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8], + ); + format!("{:?}", discriminator).parse().unwrap() + } else { + let mut discriminator = [0u8; 4]; + discriminator.copy_from_slice( + &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..4], + ); + format!("{:?}", discriminator).parse().unwrap() + } + }; + + let disc_bytes = { + if cfg!(feature = "deprecated-layout") { + quote! { + let given_disc = &buf[..8]; + } + } else { + quote! { + let given_disc = &buf[2..6]; + } + } + }; + + let disc_fn = { + if cfg!(feature = "deprecated-layout") { + quote! { + fn discriminator() -> [u8; 8] { + #discriminator + } + } + } else { + quote! { + fn discriminator() -> [u8; 4] { + #discriminator + } + } + } + }; + proc_macro::TokenStream::from({ if is_zero_copy { quote! { @@ -137,9 +173,7 @@ pub fn account( #[automatically_derived] impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause { - fn discriminator() -> [u8; 8] { - #discriminator - } + #disc_fn } // This trait is useful for clients deserializing accounts. @@ -147,10 +181,11 @@ pub fn account( #[automatically_derived] impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause { fn try_deserialize(buf: &mut &[u8]) -> std::result::Result { - if buf.len() < #discriminator.len() { + // Header is always 8 bytes. + if buf.len() < 8 { return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorNotFound.into()); } - let given_disc = &buf[..8]; + #disc_bytes if &#discriminator != given_disc { return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorMismatch.into()); } @@ -176,7 +211,6 @@ pub fn account( #[automatically_derived] impl #impl_gen anchor_lang::AccountSerialize for #account_name #type_gen #where_clause { fn try_serialize(&self, writer: &mut W) -> std::result::Result<(), ProgramError> { - writer.write_all(&#discriminator).map_err(|_| anchor_lang::__private::ErrorCode::AccountDidNotSerialize)?; AnchorSerialize::serialize( self, writer @@ -192,7 +226,7 @@ pub fn account( if buf.len() < #discriminator.len() { return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorNotFound.into()); } - let given_disc = &buf[..8]; + #disc_bytes if &#discriminator != given_disc { return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorMismatch.into()); } @@ -208,9 +242,7 @@ pub fn account( #[automatically_derived] impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause { - fn discriminator() -> [u8; 8] { - #discriminator - } + #disc_fn } #owner_impl diff --git a/lang/attribute/event/src/lib.rs b/lang/attribute/event/src/lib.rs index 960b7623..8035404b 100644 --- a/lang/attribute/event/src/lib.rs +++ b/lang/attribute/event/src/lib.rs @@ -18,13 +18,42 @@ pub fn event( let discriminator: proc_macro2::TokenStream = { let discriminator_preimage = format!("event:{}", event_name); - let mut discriminator = [0u8; 8]; - discriminator.copy_from_slice( - &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8], - ); + + #[cfg(feature = "deprecated-layout")] + let discriminator = { + let mut discriminator = [0u8; 8]; + discriminator.copy_from_slice( + &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8], + ); + discriminator + }; + #[cfg(not(feature = "deprecated-layout"))] + let discriminator = { + let mut discriminator = [0u8; 4]; + discriminator.copy_from_slice( + &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..4], + ); + discriminator + }; format!("{:?}", discriminator).parse().unwrap() }; + let discriminator_trait_impl = { + if cfg!(feature = "deprecated_layout") { + quote! { + fn discriminator() -> [u8; 8] { + #discriminator + } + } + } else { + quote! { + fn discriminator() -> [u8; 4] { + #discriminator + } + } + } + }; + proc_macro::TokenStream::from(quote! { #[derive(anchor_lang::__private::EventIndex, AnchorSerialize, AnchorDeserialize)] #event_strct @@ -38,9 +67,7 @@ pub fn event( } impl anchor_lang::Discriminator for #event_name { - fn discriminator() -> [u8; 8] { - #discriminator - } + #discriminator_trait_impl } }) } diff --git a/lang/src/accounts/account.rs b/lang/src/accounts/account.rs index cb2e3703..e589edc0 100644 --- a/lang/src/accounts/account.rs +++ b/lang/src/accounts/account.rs @@ -334,7 +334,10 @@ impl<'info, T: AccountSerialize + AccountDeserialize + Owner + Clone> AccountsEx if &T::owner() == program_id { let info = self.to_account_info(); let mut data = info.try_borrow_mut_data()?; - let dst: &mut [u8] = &mut data; + + // Chop off the header. + let dst: &mut [u8] = &mut data[8..]; + let mut cursor = std::io::Cursor::new(dst); self.account.try_serialize(&mut cursor)?; } @@ -405,3 +408,10 @@ impl<'a, T: AccountSerialize + AccountDeserialize + Owner + Clone> DerefMut for &mut self.account } } + +#[cfg(not(feature = "deprecated-layout"))] +impl<'a, T: AccountSerialize + AccountDeserialize + Owner + Clone> Bump for Account<'a, T> { + fn seed(&self) -> u8 { + self.info.data.borrow()[1] + } +} diff --git a/lang/src/accounts/account_loader.rs b/lang/src/accounts/account_loader.rs index 4e2e81c4..6d5621cd 100644 --- a/lang/src/accounts/account_loader.rs +++ b/lang/src/accounts/account_loader.rs @@ -1,10 +1,7 @@ //! Type facilitating on demand zero copy deserialization. use crate::error::ErrorCode; -use crate::{ - Accounts, AccountsClose, AccountsExit, Owner, ToAccountInfo, ToAccountInfos, ToAccountMetas, - ZeroCopy, -}; +use crate::*; use arrayref::array_ref; use solana_program::account_info::AccountInfo; use solana_program::entrypoint::ProgramResult; @@ -14,7 +11,6 @@ use solana_program::pubkey::Pubkey; use std::cell::{Ref, RefMut}; use std::collections::BTreeMap; use std::fmt; -use std::io::Write; use std::marker::PhantomData; use std::mem; use std::ops::DerefMut; @@ -24,8 +20,6 @@ use std::ops::DerefMut; /// Note that using accounts in this way is distinctly different from using, /// for example, the [`Account`](./struct.Account.html). Namely, /// one must call -/// - `load_init` after initializing an account (this will ignore the missing -/// account discriminator that gets added only after the user's instruction code) /// - `load` when the account is not mutable /// - `load_mut` when the account is mutable /// @@ -117,7 +111,7 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> { } } - /// Constructs a new `Loader` from a previously initialized account. + /// Constructs a new `AccountLoader` from a previously initialized account. #[inline(never)] pub fn try_from( acc_info: &AccountInfo<'info>, @@ -127,7 +121,11 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> { } let data: &[u8] = &acc_info.try_borrow_data()?; // Discriminator must match. + #[cfg(feature = "deprecated-layout")] let disc_bytes = array_ref![data, 0, 8]; + #[cfg(not(feature = "deprecated-layout"))] + let disc_bytes = array_ref![data, 2, 4]; + if disc_bytes != &T::discriminator() { return Err(ErrorCode::AccountDiscriminatorMismatch.into()); } @@ -135,7 +133,7 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> { Ok(AccountLoader::new(acc_info.clone())) } - /// Constructs a new `Loader` from an uninitialized account. + /// Constructs a new `AccountLoader` from an uninitialized account. #[inline(never)] pub fn try_from_unchecked( _program_id: &Pubkey, @@ -146,16 +144,18 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> { } Ok(AccountLoader::new(acc_info.clone())) } - /// Returns a Ref to the account data structure for reading. pub fn load(&self) -> Result, ProgramError> { let data = self.acc_info.try_borrow_data()?; + #[cfg(feature = "deprecated-layout")] let disc_bytes = array_ref![data, 0, 8]; + #[cfg(not(feature = "deprecated-layout"))] + let disc_bytes = array_ref![data, 2, 4]; + if disc_bytes != &T::discriminator() { return Err(ErrorCode::AccountDiscriminatorMismatch.into()); } - Ok(Ref::map(data, |data| { bytemuck::from_bytes(&data[8..mem::size_of::() + 8]) })) @@ -171,7 +171,11 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> { let data = self.acc_info.try_borrow_mut_data()?; + #[cfg(feature = "deprecated-layout")] let disc_bytes = array_ref![data, 0, 8]; + #[cfg(not(feature = "deprecated-layout"))] + let disc_bytes = array_ref![data, 2, 4]; + if disc_bytes != &T::discriminator() { return Err(ErrorCode::AccountDiscriminatorMismatch.into()); } @@ -180,30 +184,6 @@ impl<'info, T: ZeroCopy + Owner> AccountLoader<'info, T> { bytemuck::from_bytes_mut(&mut data.deref_mut()[8..mem::size_of::() + 8]) })) } - - /// Returns a `RefMut` to the account data structure for reading or writing. - /// Should only be called once, when the account is being initialized. - pub fn load_init(&self) -> Result, ProgramError> { - // AccountInfo api allows you to borrow mut even if the account isn't - // writable, so add this check for a better dev experience. - if !self.acc_info.is_writable { - return Err(ErrorCode::AccountNotMutable.into()); - } - - let data = self.acc_info.try_borrow_mut_data()?; - - // The discriminator should be zero, since we're initializing. - let mut disc_bytes = [0u8; 8]; - disc_bytes.copy_from_slice(&data[..8]); - let discriminator = u64::from_le_bytes(disc_bytes); - if discriminator != 0 { - return Err(ErrorCode::AccountDiscriminatorAlreadySet.into()); - } - - Ok(RefMut::map(data, |data| { - bytemuck::from_bytes_mut(&mut data.deref_mut()[8..mem::size_of::() + 8]) - })) - } } impl<'info, T: ZeroCopy + Owner> Accounts<'info> for AccountLoader<'info, T> { @@ -227,10 +207,7 @@ impl<'info, T: ZeroCopy + Owner> Accounts<'info> for AccountLoader<'info, T> { impl<'info, T: ZeroCopy + Owner> AccountsExit<'info> for AccountLoader<'info, T> { // The account *cannot* be loaded when this is called. fn exit(&self, _program_id: &Pubkey) -> ProgramResult { - let mut data = self.acc_info.try_borrow_mut_data()?; - let dst: &mut [u8] = &mut data; - let mut cursor = std::io::Cursor::new(dst); - cursor.write_all(&T::discriminator()).unwrap(); + // No-op. Ok(()) } } @@ -263,3 +240,10 @@ impl<'info, T: ZeroCopy + Owner> ToAccountInfos<'info> for AccountLoader<'info, vec![self.acc_info.clone()] } } + +#[cfg(not(feature = "deprecated-layout"))] +impl<'a, T: ZeroCopy + Owner> Bump for AccountLoader<'a, T> { + fn seed(&self) -> u8 { + self.acc_info.data.borrow()[1] + } +} diff --git a/lang/src/accounts/loader.rs b/lang/src/accounts/loader.rs index f60bbd71..5d372210 100644 --- a/lang/src/accounts/loader.rs +++ b/lang/src/accounts/loader.rs @@ -1,7 +1,5 @@ use crate::error::ErrorCode; -use crate::{ - Accounts, AccountsClose, AccountsExit, ToAccountInfo, ToAccountInfos, ToAccountMetas, ZeroCopy, -}; +use crate::*; use arrayref::array_ref; use solana_program::account_info::AccountInfo; use solana_program::entrypoint::ProgramResult; @@ -11,7 +9,6 @@ use solana_program::pubkey::Pubkey; use std::cell::{Ref, RefMut}; use std::collections::BTreeMap; use std::fmt; -use std::io::Write; use std::marker::PhantomData; use std::ops::DerefMut; @@ -62,8 +59,13 @@ impl<'info, T: ZeroCopy> Loader<'info, T> { return Err(ErrorCode::AccountOwnedByWrongProgram.into()); } let data: &[u8] = &acc_info.try_borrow_data()?; + // Discriminator must match. + #[cfg(feature = "deprecated-layout")] let disc_bytes = array_ref![data, 0, 8]; + #[cfg(not(feature = "deprecated-layout"))] + let disc_bytes = array_ref![data, 2, 4]; + if disc_bytes != &T::discriminator() { return Err(ErrorCode::AccountDiscriminatorMismatch.into()); } @@ -89,7 +91,11 @@ impl<'info, T: ZeroCopy> Loader<'info, T> { pub fn load(&self) -> Result, ProgramError> { let data = self.acc_info.try_borrow_data()?; + #[cfg(feature = "deprecated-layout")] let disc_bytes = array_ref![data, 0, 8]; + #[cfg(not(feature = "deprecated-layout"))] + let disc_bytes = array_ref![data, 2, 4]; + if disc_bytes != &T::discriminator() { return Err(ErrorCode::AccountDiscriminatorMismatch.into()); } @@ -108,7 +114,11 @@ impl<'info, T: ZeroCopy> Loader<'info, T> { let data = self.acc_info.try_borrow_mut_data()?; + #[cfg(feature = "deprecated-layout")] let disc_bytes = array_ref![data, 0, 8]; + #[cfg(not(feature = "deprecated-layout"))] + let disc_bytes = array_ref![data, 2, 4]; + if disc_bytes != &T::discriminator() { return Err(ErrorCode::AccountDiscriminatorMismatch.into()); } @@ -117,31 +127,6 @@ impl<'info, T: ZeroCopy> Loader<'info, T> { bytemuck::from_bytes_mut(&mut data.deref_mut()[8..]) })) } - - /// Returns a `RefMut` to the account data structure for reading or writing. - /// Should only be called once, when the account is being initialized. - #[allow(deprecated)] - pub fn load_init(&self) -> Result, ProgramError> { - // AccountInfo api allows you to borrow mut even if the account isn't - // writable, so add this check for a better dev experience. - if !self.acc_info.is_writable { - return Err(ErrorCode::AccountNotMutable.into()); - } - - let data = self.acc_info.try_borrow_mut_data()?; - - // The discriminator should be zero, since we're initializing. - let mut disc_bytes = [0u8; 8]; - disc_bytes.copy_from_slice(&data[..8]); - let discriminator = u64::from_le_bytes(disc_bytes); - if discriminator != 0 { - return Err(ErrorCode::AccountDiscriminatorAlreadySet.into()); - } - - Ok(RefMut::map(data, |data| { - bytemuck::from_bytes_mut(&mut data.deref_mut()[8..]) - })) - } } #[allow(deprecated)] @@ -167,10 +152,7 @@ impl<'info, T: ZeroCopy> Accounts<'info> for Loader<'info, T> { impl<'info, T: ZeroCopy> AccountsExit<'info> for Loader<'info, T> { // The account *cannot* be loaded when this is called. fn exit(&self, _program_id: &Pubkey) -> ProgramResult { - let mut data = self.acc_info.try_borrow_mut_data()?; - let dst: &mut [u8] = &mut data; - let mut cursor = std::io::Cursor::new(dst); - cursor.write_all(&T::discriminator()).unwrap(); + // No-op. Ok(()) } } @@ -207,3 +189,10 @@ impl<'info, T: ZeroCopy> ToAccountInfos<'info> for Loader<'info, T> { vec![self.acc_info.clone()] } } + +#[cfg(not(feature = "deprecated-layout"))] +impl<'a, T: ZeroCopy> Bump for Loader<'a, T> { + fn seed(&self) -> u8 { + self.acc_info.data.borrow()[1] + } +} diff --git a/lang/src/accounts/program_account.rs b/lang/src/accounts/program_account.rs index 7a76770f..62afb483 100644 --- a/lang/src/accounts/program_account.rs +++ b/lang/src/accounts/program_account.rs @@ -1,10 +1,7 @@ #[allow(deprecated)] use crate::accounts::cpi_account::CpiAccount; use crate::error::ErrorCode; -use crate::{ - AccountDeserialize, AccountSerialize, Accounts, AccountsClose, AccountsExit, ToAccountInfo, - ToAccountInfos, ToAccountMetas, -}; +use crate::*; use solana_program::account_info::AccountInfo; use solana_program::entrypoint::ProgramResult; use solana_program::instruction::AccountMeta; @@ -102,7 +99,10 @@ impl<'info, T: AccountSerialize + AccountDeserialize + Clone> AccountsExit<'info fn exit(&self, _program_id: &Pubkey) -> ProgramResult { let info = self.to_account_info(); let mut data = info.try_borrow_mut_data()?; - let dst: &mut [u8] = &mut data; + + // Chop off the header. + let dst: &mut [u8] = &mut data[8..]; + let mut cursor = std::io::Cursor::new(dst); self.inner.account.try_serialize(&mut cursor)?; Ok(()) @@ -181,3 +181,10 @@ where Self::new(a.to_account_info(), Deref::deref(&a).clone()) } } + +#[cfg(not(feature = "deprecated-layout"))] +impl<'a, T: AccountSerialize + AccountDeserialize + Clone> Bump for ProgramAccount<'a, T> { + fn seed(&self) -> u8 { + self.inner.info.data.borrow()[1] + } +} diff --git a/lang/src/lib.rs b/lang/src/lib.rs index 9746c96f..84493ea3 100644 --- a/lang/src/lib.rs +++ b/lang/src/lib.rs @@ -200,7 +200,10 @@ pub trait EventData: AnchorSerialize + Discriminator { /// 8 byte unique identifier for a type. pub trait Discriminator { + #[cfg(feature = "deprecated-layout")] fn discriminator() -> [u8; 8]; + #[cfg(not(feature = "deprecated-layout"))] + fn discriminator() -> [u8; 4]; } /// Bump seed for program derived addresses. diff --git a/lang/syn/src/codegen/accounts/constraints.rs b/lang/syn/src/codegen/accounts/constraints.rs index 1754b88e..ccda5edc 100644 --- a/lang/syn/src/codegen/accounts/constraints.rs +++ b/lang/syn/src/codegen/accounts/constraints.rs @@ -142,22 +142,43 @@ fn generate_constraint_address(f: &Field, c: &ConstraintAddress) -> proc_macro2: } } -pub fn generate_constraint_init(f: &Field, c: &ConstraintInitGroup) -> proc_macro2::TokenStream { - generate_constraint_init_group(f, c) -} - pub fn generate_constraint_zeroed(f: &Field, _c: &ConstraintZeroed) -> proc_macro2::TokenStream { let field = &f.ident; let ty_decl = f.ty_decl(); - let from_account_info = f.from_account_info_unchecked(None); + let account_ty = f.account_ty(); + let from_account_info = f.from_account_info(None); + let header_write = { + if cfg!(feature = "deprecated-layout") { + quote! { + use std::io::{Write, Cursor}; + use anchor_lang::Discriminator; + let __dst: &mut [u8] = &mut __data; + let mut __cursor = Cursor::new(__dst); + Write::write_all(&mut __cursor, &#account_ty::discriminator()).unwrap(); + } + } else { + quote! { + use std::io::{Write, Cursor}; + use anchor_lang::Discriminator; + let __dst: &mut [u8] = &mut __data[2..]; + let mut __cursor = Cursor::new(__dst); + Write::write_all(&mut __cursor, &#account_ty::discriminator()).unwrap(); + } + } + }; + // Check the *entire* account header is zero. quote! { let #field: #ty_decl = { - let mut __data: &[u8] = &#field.try_borrow_data()?; - let mut __disc_bytes = [0u8; 8]; - __disc_bytes.copy_from_slice(&__data[..8]); - let __discriminator = u64::from_le_bytes(__disc_bytes); - if __discriminator != 0 { - return Err(anchor_lang::__private::ErrorCode::ConstraintZero.into()); + { + let mut __data: &mut [u8] = &mut #field.try_borrow_mut_data()?; + let mut __header_bytes = [0u8; 8]; + __header_bytes.copy_from_slice(&__data[..8]); + let __header = u64::from_le_bytes(__header_bytes); + if __header != 0 { + return Err(anchor_lang::__private::ErrorCode::ConstraintZero.into()); + } + + #header_write } #from_account_info }; @@ -276,7 +297,7 @@ pub fn generate_constraint_rent_exempt( } } -fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_macro2::TokenStream { +fn generate_constraint_init(f: &Field, c: &ConstraintInitGroup) -> proc_macro2::TokenStream { let field = &f.ident; let ty_decl = f.ty_decl(); let if_needed = if c.if_needed { @@ -295,7 +316,7 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma }; // Convert from account info to account context wrapper type. - let from_account_info = f.from_account_info_unchecked(Some(&c.kind)); + let from_account_info = f.from_account_info(Some(&c.kind)); // PDA bump seeds. let (find_pda, seeds_with_bump) = match &c.seeds { @@ -512,6 +533,45 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma let create_account = generate_create_account(field, quote! {space}, owner.clone(), seeds_with_bump); + // Write the 8 byte header. + let header_write = { + match &f.ty { + Ty::Account(_) + | Ty::ProgramAccount(_) + | Ty::Loader(_) + | Ty::AccountLoader(_) => { + let account_ty = f.account_ty(); + if cfg!(feature = "deprecated-layout") { + quote! { + { + use std::io::{Write, Cursor}; + use anchor_lang::Discriminator; + + let mut __data = actual_field.try_borrow_mut_data()?; + let __dst: &mut [u8] = &mut __data; + let mut __cursor = Cursor::new(__dst); + Write::write_all(&mut __cursor, &#account_ty::discriminator()).unwrap(); + } + } + } else { + quote! { + { + use std::io::{Write, Seek, SeekFrom, Cursor}; + use anchor_lang::Discriminator; + + let mut __data = actual_field.try_borrow_mut_data()?; + let __dst: &mut [u8] = &mut __data; + let mut __cursor = Cursor::new(__dst); + Seek::seek(&mut __cursor, SeekFrom::Start(2)).unwrap(); + Write::write_all(&mut __cursor, &#account_ty::discriminator()).unwrap(); + } + } + } + } + _ => quote! {}, + } + }; + // Put it all together. quote! { // Define the bump variable. @@ -534,6 +594,10 @@ fn generate_constraint_init_group(f: &Field, c: &ConstraintInitGroup) -> proc_ma #create_account } + // Write the account header into the account data before + // deserializing. + #header_write + // Convert from account info to account context wrapper type. let pa: #ty_decl = #from_account_info; diff --git a/lang/syn/src/codegen/program/handlers.rs b/lang/syn/src/codegen/program/handlers.rs index 7a23fcf8..ba725e96 100644 --- a/lang/syn/src/codegen/program/handlers.rs +++ b/lang/syn/src/codegen/program/handlers.rs @@ -89,7 +89,7 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream { let seed = anchor_lang::idl::IdlAccount::seed(); let owner = accounts.program.key; let to = Pubkey::create_with_seed(&base, seed, owner).unwrap(); - // Space: account discriminator || authority pubkey || vec len || vec data + // Space: account header || authority pubkey || vec len || vec data let space = 8 + 32 + 4 + data_len as usize; let rent = Rent::get()?; let lamports = rent.minimum_balance(space); @@ -130,6 +130,7 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream { let mut data = accounts.to.try_borrow_mut_data()?; let dst: &mut [u8] = &mut data; let mut cursor = std::io::Cursor::new(dst); + std::io::Seek::seek(&mut cursor, std::io::SeekFrom::Start(8)).unwrap(); idl_account.try_serialize(&mut cursor)?; Ok(()) @@ -201,6 +202,34 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream { let ix_name: proc_macro2::TokenStream = generate_ctor_variant_name().parse().unwrap(); let ix_name_log = format!("Instruction: {}", ix_name); + let header_write = { + if cfg!(feature = "deprecated-layout") { + quote! { + { + use std::io::{Write, Cursor}; + use anchor_lang::Discriminator; + + let mut __data = ctor_accounts.to.try_borrow_mut_data()?; + let __dst: &mut [u8] = &mut __data; + let mut __cursor = Cursor::new(__dst); + Write::write_all(&mut __cursor, &#name::discriminator()).unwrap(); + } + } + } else { + quote! { + { + use std::io::{Write, Cursor, SeekFrom, Seek}; + use anchor_lang::Discriminator; + + let mut __data = ctor_accounts.to.try_borrow_mut_data()?; + let __dst: &mut [u8] = &mut __data; + let mut __cursor = Cursor::new(__dst); + Seek::seek(&mut __cursor, SeekFrom::Start(2)).unwrap(); + Write::write_all(&mut __cursor, &#name::discriminator()).unwrap(); + } + } + } + }; if state.is_zero_copy { quote! { // One time state account initializer. Will faill on subsequent @@ -254,14 +283,17 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream { &[seeds], )?; + // Initialize the header. + #header_write + // Zero copy deserialize. - let loader: anchor_lang::accounts::loader::Loader<#mod_name::#name> = anchor_lang::accounts::loader::Loader::try_from_unchecked(program_id, &ctor_accounts.to)?; + let loader: anchor_lang::accounts::loader::Loader<#mod_name::#name> = anchor_lang::accounts::loader::Loader::try_from(program_id, &ctor_accounts.to)?; // Invoke the ctor in a new lexical scope so that // the zero-copy RefMut gets dropped. Required // so that we can subsequently run the exit routine. { - let mut instance = loader.load_init()?; + let mut instance = loader.load_mut()?; instance.new( anchor_lang::context::Context::new( program_id, @@ -344,11 +376,15 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream { &[seeds], )?; + // Initialize the account header. + #header_write + // Serialize the state and save it to storage. ctor_user_def_accounts.exit(program_id)?; let mut data = ctor_accounts.to.try_borrow_mut_data()?; let dst: &mut [u8] = &mut data; let mut cursor = std::io::Cursor::new(dst); + std::io::Seek::seek(&mut cursor, std::io::SeekFrom::Start(8)).unwrap(); instance.try_serialize(&mut cursor)?; Ok(()) @@ -500,6 +536,7 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream { let mut data = acc_info.try_borrow_mut_data()?; let dst: &mut [u8] = &mut data; let mut cursor = std::io::Cursor::new(dst); + std::io::Seek::seek(&mut cursor, std::io::SeekFrom::Start(8)).unwrap(); state.try_serialize(&mut cursor)?; Ok(()) @@ -628,6 +665,7 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream { let mut data = acc_info.try_borrow_mut_data()?; let dst: &mut [u8] = &mut data; let mut cursor = std::io::Cursor::new(dst); + std::io::Seek::seek(&mut cursor, std::io::SeekFrom::Start(8)).unwrap(); state.try_serialize(&mut cursor)?; Ok(()) diff --git a/lang/syn/src/idl/file.rs b/lang/syn/src/idl/file.rs index 119f173a..682a4f65 100644 --- a/lang/syn/src/idl/file.rs +++ b/lang/syn/src/idl/file.rs @@ -238,6 +238,7 @@ pub fn parse( .collect::>(); Ok(Some(Idl { + layout_version: "0.1.0".to_string(), version, name: p.name.to_string(), state, diff --git a/lang/syn/src/idl/mod.rs b/lang/syn/src/idl/mod.rs index 9482d367..01407a79 100644 --- a/lang/syn/src/idl/mod.rs +++ b/lang/syn/src/idl/mod.rs @@ -5,7 +5,11 @@ pub mod file; pub mod pda; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "camelCase")] pub struct Idl { + // Version of the idl protocol. + pub layout_version: String, + // Version of the program. pub version: String, pub name: String, #[serde(skip_serializing_if = "Vec::is_empty", default)] diff --git a/lang/syn/src/lib.rs b/lang/syn/src/lib.rs index 7d29cc51..bb58e29e 100644 --- a/lang/syn/src/lib.rs +++ b/lang/syn/src/lib.rs @@ -26,6 +26,13 @@ pub(crate) mod hash; pub mod idl; pub mod parser; +// Layout indices. +pub const LAYOUT_VERSION: u8 = 0; +pub const LAYOUT_VERSION_INDEX: u8 = 0; +pub const LAYOUT_BUMP_INDEX: u8 = 1; +pub const LAYOUT_DISCRIMINATOR_INDEX: u8 = 2; +pub const LAYOUT_UNUSED_INDEX: u8 = 6; + #[derive(Debug)] pub struct Program { pub state: Option, @@ -273,7 +280,7 @@ impl Field { // TODO: remove the option once `CpiAccount` is completely removed (not // just deprecated). - pub fn from_account_info_unchecked(&self, kind: Option<&InitKind>) -> proc_macro2::TokenStream { + pub fn from_account_info(&self, kind: Option<&InitKind>) -> proc_macro2::TokenStream { let field = &self.ident; let container_ty = self.container_ty(); match &self.ty { @@ -284,13 +291,13 @@ impl Field { Ty::Account(AccountTy { boxed, .. }) => { if *boxed { quote! { - Box::new(#container_ty::try_from_unchecked( + Box::new(#container_ty::try_from( &#field, )?) } } else { quote! { - #container_ty::try_from_unchecked( + #container_ty::try_from( &#field, )? } @@ -298,7 +305,14 @@ impl Field { } Ty::CpiAccount(_) => { quote! { - #container_ty::try_from_unchecked( + #container_ty::try_from( + &#field, + )? + } + } + Ty::AccountLoader(_) => { + quote! { + #container_ty::try_from( &#field, )? } @@ -314,7 +328,7 @@ impl Field { }, }; quote! { - #container_ty::try_from_unchecked( + #container_ty::try_from( #owner_addr, &#field, )? diff --git a/tests/chat/programs/chat/src/lib.rs b/tests/chat/programs/chat/src/lib.rs index 40fac7f8..effcf1af 100644 --- a/tests/chat/programs/chat/src/lib.rs +++ b/tests/chat/programs/chat/src/lib.rs @@ -19,7 +19,7 @@ pub mod chat { let given_name = name.as_bytes(); let mut name = [0u8; 280]; name[..given_name.len()].copy_from_slice(given_name); - let mut chat = ctx.accounts.chat_room.load_init()?; + let mut chat = ctx.accounts.chat_room.load_mut()?; chat.name = name; Ok(()) } diff --git a/tests/misc/programs/misc/src/lib.rs b/tests/misc/programs/misc/src/lib.rs index 0f5e31f6..0a0a0611 100644 --- a/tests/misc/programs/misc/src/lib.rs +++ b/tests/misc/programs/misc/src/lib.rs @@ -122,7 +122,7 @@ pub mod misc { } pub fn test_pda_init_zero_copy(ctx: Context) -> ProgramResult { - let mut acc = ctx.accounts.my_pda.load_init()?; + let mut acc = ctx.accounts.my_pda.load_mut()?; acc.data = 9; acc.bump = *ctx.bumps.get("my_pda").unwrap(); Ok(()) @@ -152,7 +152,7 @@ pub mod misc { } pub fn test_init_zero_copy(ctx: Context) -> ProgramResult { - let mut data = ctx.accounts.data.load_init()?; + let mut data = ctx.accounts.data.load_mut()?; data.data = 10; data.bump = 2; Ok(()) diff --git a/tests/zero-copy/programs/zero-copy/src/lib.rs b/tests/zero-copy/programs/zero-copy/src/lib.rs index 05927688..800f93b8 100644 --- a/tests/zero-copy/programs/zero-copy/src/lib.rs +++ b/tests/zero-copy/programs/zero-copy/src/lib.rs @@ -14,7 +14,7 @@ pub mod zero_copy { use super::*; pub fn create_foo(ctx: Context) -> ProgramResult { - let foo = &mut ctx.accounts.foo.load_init()?; + let foo = &mut ctx.accounts.foo.load_mut()?; foo.authority = *ctx.accounts.authority.key; foo.set_second_authority(ctx.accounts.authority.key); Ok(()) @@ -33,7 +33,7 @@ pub mod zero_copy { } pub fn create_bar(ctx: Context) -> ProgramResult { - let bar = &mut ctx.accounts.bar.load_init()?; + let bar = &mut ctx.accounts.bar.load_mut()?; bar.authority = *ctx.accounts.authority.key; Ok(()) } diff --git a/tests/zero-copy/programs/zero-copy/tests/compute_unit_test.rs b/tests/zero-copy/programs/zero-copy/tests/compute_unit_test.rs index 485782ff..77001ca9 100644 --- a/tests/zero-copy/programs/zero-copy/tests/compute_unit_test.rs +++ b/tests/zero-copy/programs/zero-copy/tests/compute_unit_test.rs @@ -21,8 +21,11 @@ async fn update_foo() { let authority = Keypair::new(); let foo_pubkey = Pubkey::new_unique(); let foo_account = { - let mut foo_data = Vec::new(); + // Write header. + let mut foo_data = vec![0, 0]; foo_data.extend_from_slice(&zero_copy::Foo::discriminator()); + foo_data.extend_from_slice(&[0, 0]); + // Write data. foo_data.extend_from_slice(bytemuck::bytes_of(&zero_copy::Foo { authority: authority.pubkey(), ..zero_copy::Foo::default() diff --git a/ts/src/coder/borsh/accounts.ts b/ts/src/coder/borsh/accounts.ts index b12a13dd..1b709f37 100644 --- a/ts/src/coder/borsh/accounts.ts +++ b/ts/src/coder/borsh/accounts.ts @@ -1,3 +1,4 @@ +import { GetProgramAccountsFilter } from "@solana/web3.js"; import bs58 from "bs58"; import { Buffer } from "buffer"; import { Layout } from "buffer-layout"; @@ -7,11 +8,18 @@ import { Idl, IdlTypeDef } from "../../idl.js"; import { IdlCoder } from "./idl.js"; import { AccountsCoder } from "../index.js"; import { accountSize } from "../common.js"; +import * as features from "../../utils/features"; + +/** + * Number of bytes of the account header. + */ +const ACCOUNT_HEADER_SIZE = 8; /** * Number of bytes of the account discriminator. */ -export const ACCOUNT_DISCRIMINATOR_SIZE = 8; +const ACCOUNT_DISCRIMINATOR_SIZE = 4; +const DEPRECATED_ACCOUNT_DISCRIMINATOR_SIZE = 4; /** * Encodes and decodes account objects. @@ -49,22 +57,21 @@ export class BorshAccountsCoder } const len = layout.encode(account, buffer); let accountData = buffer.slice(0, len); - let discriminator = BorshAccountsCoder.accountDiscriminator(accountName); - return Buffer.concat([discriminator, accountData]); + let header = BorshAccountHeader.encode(accountName); + return Buffer.concat([header, accountData]); } public decode(accountName: A, data: Buffer): T { - // Assert the account discriminator is correct. - const discriminator = BorshAccountsCoder.accountDiscriminator(accountName); - if (discriminator.compare(data.slice(0, 8))) { + const expectedDiscriminator = BorshAccountHeader.discriminator(accountName); + const givenDisc = BorshAccountHeader.parseDiscriminator(data); + if (expectedDiscriminator.compare(givenDisc)) { throw new Error("Invalid account discriminator"); } return this.decodeUnchecked(accountName, data); } public decodeUnchecked(accountName: A, ix: Buffer): T { - // Chop off the discriminator before decoding. - const data = ix.slice(ACCOUNT_DISCRIMINATOR_SIZE); + const data = ix.slice(BorshAccountHeader.size()); // Chop off the header. const layout = this.accountLayouts.get(accountName); if (!layout) { throw new Error(`Unknown account: ${accountName}`); @@ -72,20 +79,40 @@ export class BorshAccountsCoder return layout.decode(data); } - public memcmp(accountName: A, appendData?: Buffer): any { - const discriminator = BorshAccountsCoder.accountDiscriminator(accountName); + public memcmp(accountName: A): GetProgramAccountsFilter { + const discriminator = BorshAccountHeader.discriminator(accountName); return { - offset: 0, - bytes: bs58.encode( - appendData ? Buffer.concat([discriminator, appendData]) : discriminator - ), + memcmp: { + offset: BorshAccountHeader.discriminatorOffset(), + bytes: bs58.encode(discriminator), + }, }; } + public memcmpDataOffset(): number { + return BorshAccountHeader.size(); + } + public size(idlAccount: IdlTypeDef): number { - return ( - ACCOUNT_DISCRIMINATOR_SIZE + (accountSize(this.idl, idlAccount) ?? 0) - ); + return BorshAccountHeader.size() + (accountSize(this.idl, idlAccount) ?? 0); + } +} + +export class BorshAccountHeader { + /** + * Returns the default account header for an account with the given name. + */ + public static encode(accountName: string, nameSpace?: string): Buffer { + if (features.isSet("deprecated-layout")) { + return BorshAccountHeader.discriminator(accountName, nameSpace); + } else { + return Buffer.concat([ + Buffer.from([0]), // Version. + Buffer.from([0]), // Bump. + BorshAccountHeader.discriminator(accountName), // Disc. + Buffer.from([0, 0]), // Unused. + ]); + } } /** @@ -93,9 +120,46 @@ export class BorshAccountsCoder * * @param name The name of the account to calculate the discriminator. */ - public static accountDiscriminator(name: string): Buffer { + public static discriminator(name: string, nameSpace?: string): Buffer { return Buffer.from( - sha256.digest(`account:${camelcase(name, { pascalCase: true })}`) - ).slice(0, ACCOUNT_DISCRIMINATOR_SIZE); + sha256.digest( + `${nameSpace ?? "account"}:${camelcase(name, { pascalCase: true })}` + ) + ).slice(0, BorshAccountHeader.discriminatorSize()); + } + + public static discriminatorSize(): number { + return features.isSet("deprecated-layout") + ? DEPRECATED_ACCOUNT_DISCRIMINATOR_SIZE + : ACCOUNT_DISCRIMINATOR_SIZE; + } + + /** + * Returns the account data index at which the discriminator starts. + */ + public static discriminatorOffset(): number { + if (features.isSet("deprecated-layout")) { + return 0; + } else { + return 2; + } + } + + /** + * Returns the byte size of the account header. + */ + public static size(): number { + return ACCOUNT_HEADER_SIZE; + } + + /** + * Returns the discriminator from the given account data. + */ + public static parseDiscriminator(data: Buffer): Buffer { + if (features.isSet("deprecated-layout")) { + return data.slice(0, 8); + } else { + return data.slice(2, 6); + } } } diff --git a/ts/src/coder/borsh/event.ts b/ts/src/coder/borsh/event.ts index 6581c623..2cdb0c6e 100644 --- a/ts/src/coder/borsh/event.ts +++ b/ts/src/coder/borsh/event.ts @@ -6,6 +6,7 @@ import { Idl, IdlEvent, IdlTypeDef } from "../../idl.js"; import { Event, EventData } from "../../program/event.js"; import { IdlCoder } from "./idl.js"; import { EventCoder } from "../index.js"; +import * as features from "../../utils/features"; export class BorshEventCoder implements EventCoder { /** @@ -41,7 +42,7 @@ export class BorshEventCoder implements EventCoder { idl.events === undefined ? [] : idl.events.map((e) => [ - base64.fromByteArray(eventDiscriminator(e.name)), + base64.fromByteArray(EventHeader.discriminator(e.name)), e.name, ]) ); @@ -57,7 +58,7 @@ export class BorshEventCoder implements EventCoder { } catch (e) { return null; } - const disc = base64.fromByteArray(logArr.slice(0, 8)); + const disc = base64.fromByteArray(EventHeader.parseDiscriminator(logArr)); // Only deserialize if the discriminator implies a proper event. const eventName = this.discriminators.get(disc); @@ -69,7 +70,7 @@ export class BorshEventCoder implements EventCoder { if (!layout) { throw new Error(`Unknown event: ${eventName}`); } - const data = layout.decode(logArr.slice(8)) as EventData< + const data = layout.decode(logArr.slice(EventHeader.size())) as EventData< E["fields"][number], T >; @@ -78,5 +79,31 @@ export class BorshEventCoder implements EventCoder { } export function eventDiscriminator(name: string): Buffer { - return Buffer.from(sha256.digest(`event:${name}`)).slice(0, 8); + return EventHeader.discriminator(name); +} + +class EventHeader { + public static parseDiscriminator(data: Buffer): Buffer { + if (features.isSet("deprecated-layout")) { + return data.slice(0, 8); + } else { + return data.slice(0, 4); + } + } + + public static size(): number { + if (features.isSet("deprecated-layout")) { + return 8; + } else { + return 4; + } + } + + public static discriminator(name: string): Buffer { + if (features.isSet("deprecated-layout")) { + return Buffer.from(sha256.digest(`event:${name}`)).slice(0, 8); + } else { + return Buffer.from(sha256.digest(`event:${name}`)).slice(0, 4); + } + } } diff --git a/ts/src/coder/borsh/index.ts b/ts/src/coder/borsh/index.ts index 4b84fd2e..a16fc649 100644 --- a/ts/src/coder/borsh/index.ts +++ b/ts/src/coder/borsh/index.ts @@ -6,7 +6,7 @@ import { BorshStateCoder } from "./state.js"; import { Coder } from "../index.js"; export { BorshInstructionCoder } from "./instruction.js"; -export { BorshAccountsCoder, ACCOUNT_DISCRIMINATOR_SIZE } from "./accounts.js"; +export { BorshAccountsCoder, BorshAccountHeader } from "./accounts.js"; export { BorshEventCoder, eventDiscriminator } from "./event.js"; export { BorshStateCoder, stateDiscriminator } from "./state.js"; diff --git a/ts/src/coder/borsh/state.ts b/ts/src/coder/borsh/state.ts index b33f7b20..3c58b2eb 100644 --- a/ts/src/coder/borsh/state.ts +++ b/ts/src/coder/borsh/state.ts @@ -4,6 +4,7 @@ import { sha256 } from "js-sha256"; import { Idl } from "../../idl.js"; import { IdlCoder } from "./idl.js"; import * as features from "../../utils/features.js"; +import { BorshAccountHeader } from "./accounts"; export class BorshStateCoder { private layout: Layout; @@ -19,15 +20,16 @@ export class BorshStateCoder { const buffer = Buffer.alloc(1000); // TODO: use a tighter buffer. const len = this.layout.encode(account, buffer); - const disc = await stateDiscriminator(name); + let ns = features.isSet("anchor-deprecated-state") ? "account" : "state"; + const header = BorshAccountHeader.encode(name, ns); const accData = buffer.slice(0, len); - return Buffer.concat([disc, accData]); + return Buffer.concat([header, accData]); } - public decode(ix: Buffer): T { - // Chop off discriminator. - const data = ix.slice(8); + public decode(data: Buffer): T { + // Chop off header. + data = data.slice(BorshAccountHeader.size()); return this.layout.decode(data); } } @@ -35,5 +37,8 @@ export class BorshStateCoder { // Calculates unique 8 byte discriminator prepended to all anchor state accounts. export async function stateDiscriminator(name: string): Promise { let ns = features.isSet("anchor-deprecated-state") ? "account" : "state"; - return Buffer.from(sha256.digest(`${ns}:${name}`)).slice(0, 8); + return Buffer.from(sha256.digest(`${ns}:${name}`)).slice( + 0, + BorshAccountHeader.discriminatorSize() + ); } diff --git a/ts/src/coder/index.ts b/ts/src/coder/index.ts index 6256103a..02d0c8b5 100644 --- a/ts/src/coder/index.ts +++ b/ts/src/coder/index.ts @@ -1,3 +1,4 @@ +import { GetProgramAccountsFilter } from "@solana/web3.js"; import { IdlEvent, IdlTypeDef } from "../idl.js"; import { Event } from "../program/event.js"; @@ -38,7 +39,8 @@ export interface AccountsCoder { encode(accountName: A, account: T): Promise; decode(accountName: A, ix: Buffer): T; decodeUnchecked(accountName: A, ix: Buffer): T; - memcmp(accountName: A, appendData?: Buffer): any; + memcmp(accountName: A): GetProgramAccountsFilter; + memcmpDataOffset(): number; size(idlAccount: IdlTypeDef): number; } diff --git a/ts/src/coder/spl-token/accounts.ts b/ts/src/coder/spl-token/accounts.ts index a73311f2..65394263 100644 --- a/ts/src/coder/spl-token/accounts.ts +++ b/ts/src/coder/spl-token/accounts.ts @@ -1,3 +1,4 @@ +import { GetProgramAccountsFilter } from "@solana/web3.js"; import * as BufferLayout from "buffer-layout"; import { publicKey, uint64, coption, bool } from "./buffer-layout.js"; import { AccountsCoder } from "../index.js"; @@ -44,8 +45,7 @@ export class SplTokenAccountsCoder } } - // TODO: this won't use the appendData. - public memcmp(accountName: A, _appendData?: Buffer): any { + public memcmp(accountName: A): GetProgramAccountsFilter { switch (accountName) { case "Token": { return { @@ -63,6 +63,10 @@ export class SplTokenAccountsCoder } } + public memcmpDataOffset(): number { + return 0; + } + public size(idlAccount: IdlTypeDef): number { return accountSize(this.idl, idlAccount) ?? 0; } diff --git a/ts/src/idl.ts b/ts/src/idl.ts index afe07ad4..989ee620 100644 --- a/ts/src/idl.ts +++ b/ts/src/idl.ts @@ -3,6 +3,7 @@ import { PublicKey } from "@solana/web3.js"; import * as borsh from "@project-serum/borsh"; export type Idl = { + layoutVersion: string; version: string; name: string; instructions: IdlInstruction[]; diff --git a/ts/src/program/namespace/account.ts b/ts/src/program/namespace/account.ts index 69341c58..cfb8282d 100644 --- a/ts/src/program/namespace/account.ts +++ b/ts/src/program/namespace/account.ts @@ -1,3 +1,4 @@ +import * as bs58 from 'bs58'; import camelCase from "camelcase"; import EventEmitter from "eventemitter3"; import { @@ -201,19 +202,25 @@ export class AccountClient< async all( filters?: Buffer | GetProgramAccountsFilter[] ): Promise[]> { + const typeFilter = [this.coder.accounts.memcmp(this._idlAccount.name)]; + const dataFilter = + filters instanceof Buffer + ? [ + { + memcmp: { + offset: this.coder.accounts.memcmpDataOffset(), + bytes: bs58.encode(filters), + }, + }, + ] + : []; + const miscFilters = Array.isArray(filters) ? filters : []; + let resp = await this._provider.connection.getProgramAccounts( this._programId, { commitment: this._provider.connection.commitment, - filters: [ - { - memcmp: this.coder.accounts.memcmp( - this._idlAccount.name, - filters instanceof Buffer ? filters : undefined - ), - }, - ...(Array.isArray(filters) ? filters : []), - ], + filters: typeFilter.concat(dataFilter).concat(miscFilters), } ); return resp.map(({ pubkey, account }) => { diff --git a/ts/src/program/namespace/state.ts b/ts/src/program/namespace/state.ts index 2e3bc6f4..9795f43a 100644 --- a/ts/src/program/namespace/state.ts +++ b/ts/src/program/namespace/state.ts @@ -24,6 +24,7 @@ import InstructionNamespaceFactory from "./instruction.js"; import RpcNamespaceFactory from "./rpc.js"; import TransactionNamespaceFactory from "./transaction.js"; import { IdlTypes, TypeDef } from "./types.js"; +import * as features from "../../utils/features.js"; export default class StateFactory { public static build( @@ -172,10 +173,19 @@ export class StateClient { if (!state) { throw new Error("State is not specified in IDL."); } + const expectedDiscriminator = await stateDiscriminator(state.struct.name); - if (expectedDiscriminator.compare(accountInfo.data.slice(0, 8))) { - throw new Error("Invalid account discriminator"); + + if (features.isSet("deprecated-layout")) { + if (expectedDiscriminator.compare(accountInfo.data.slice(0, 8))) { + throw new Error("Invalid state discriminator"); + } + } else { + if (expectedDiscriminator.compare(accountInfo.data.slice(2, 6))) { + throw new Error("Invalid state discriminator"); + } } + return this.coder.state.decode(accountInfo.data); } diff --git a/ts/src/spl/token.ts b/ts/src/spl/token.ts index cb661251..886b06fb 100644 --- a/ts/src/spl/token.ts +++ b/ts/src/spl/token.ts @@ -19,6 +19,7 @@ export function coder(): SplTokenCoder { * SplToken IDL. */ export type SplToken = { + layoutVersion: "custom"; version: "0.1.0"; name: "spl_token"; instructions: [ @@ -624,6 +625,7 @@ export type SplToken = { }; export const IDL: SplToken = { + layoutVersion: "custom", version: "0.1.0", name: "spl_token", instructions: [ diff --git a/ts/src/utils/features.ts b/ts/src/utils/features.ts index 246b2b99..32e03d05 100644 --- a/ts/src/utils/features.ts +++ b/ts/src/utils/features.ts @@ -1,4 +1,8 @@ -const _AVAILABLE_FEATURES = new Set(["anchor-deprecated-state", "debug-logs"]); +const _AVAILABLE_FEATURES = new Set([ + "anchor-deprecated-state", + "debug-logs", + "deprecated-layout", +]); const _FEATURES = new Map(); diff --git a/ts/tests/events.spec.ts b/ts/tests/events.spec.ts index 5e8456ad..4b583539 100644 --- a/ts/tests/events.spec.ts +++ b/ts/tests/events.spec.ts @@ -12,6 +12,7 @@ describe("Events", () => { "Program J2XMGdW2qQLx7rAdwWtSZpTXDgAQ988BLP9QTgUZvm54 success", ]; const idl = { + layoutVersion: "0.1.0", version: "0.0.0", name: "basic_0", instructions: [ diff --git a/ts/tests/transaction.spec.ts b/ts/tests/transaction.spec.ts index 4a3e69ae..4a080c90 100644 --- a/ts/tests/transaction.spec.ts +++ b/ts/tests/transaction.spec.ts @@ -15,6 +15,7 @@ describe("Transaction", () => { data: Buffer.from("post"), }); const idl = { + layoutVersion: "0.1.0", version: "0.0.0", name: "basic_0", instructions: [