From 170e6f18d46deacaa721ac558070f63f527abd49 Mon Sep 17 00:00:00 2001
From: Armani Ferrante
Date: Fri, 5 Feb 2021 19:17:40 +0800
Subject: [PATCH] lang: Hash at compile time (#63)

---
 lang/attribute/account/Cargo.toml |   2 +-
 lang/attribute/account/src/lib.rs |  44 ++++------
 lang/syn/Cargo.toml               |   4 +
 lang/syn/src/hash.rs              | 139 ++++++++++++++++++++++++++++++
 lang/syn/src/lib.rs               |   2 +
 5 files changed, 163 insertions(+), 28 deletions(-)
 create mode 100644 lang/syn/src/hash.rs

diff --git a/lang/attribute/account/Cargo.toml b/lang/attribute/account/Cargo.toml
index e245c0d2..3df21250 100644
--- a/lang/attribute/account/Cargo.toml
+++ b/lang/attribute/account/Cargo.toml
@@ -15,4 +15,4 @@ proc-macro2 = "1.0"
 quote = "1.0"
 syn = { version = "=1.0.57", features = ["full"] }
 anyhow = "1.0.32"
-anchor-syn = { path = "../../syn", version = "0.1.0" }
+anchor-syn = { path = "../../syn", version = "0.1.0", features = ["hash"] }
diff --git a/lang/attribute/account/src/lib.rs b/lang/attribute/account/src/lib.rs
index bbb9b08c..99e525c9 100644
--- a/lang/attribute/account/src/lib.rs
+++ b/lang/attribute/account/src/lib.rs
@@ -27,29 +27,26 @@ pub fn account(
     let account_strct = parse_macro_input!(input as syn::ItemStruct);
     let account_name = &account_strct.ident;
 
-    // Namespace the discriminator to prevent collisions.
-    let discriminator_preimage = {
-        if namespace == "" {
-            format!("account:{}", account_name.to_string())
-        } else {
-            format!("{}:{}", namespace, account_name.to_string())
-        }
+    let discriminator: proc_macro2::TokenStream = {
+        // Namespace the discriminator to prevent collisions.
+        let discriminator_preimage = {
+            if namespace == "" {
+                format!("account:{}", account_name.to_string())
+            } else {
+                format!("{}:{}", namespace, account_name.to_string())
+            }
+        };
+        let mut discriminator = [0u8; 8];
+        discriminator.copy_from_slice(
+            &anchor_syn::hash::hash(discriminator_preimage.as_bytes()).to_bytes()[..8],
+        );
+        format!("{:?}", discriminator).parse().unwrap()
     };
 
     let coder = quote! {
         impl anchor_lang::AccountSerialize for #account_name {
             fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> std::result::Result<(), ProgramError> {
-                // TODO: we shouldn't have to hash at runtime. However, rust
-                // is not happy when trying to include solana-sdk from
-                // the proc-macro crate.
-                let mut discriminator = [0u8; 8];
-                discriminator.copy_from_slice(
-                    &anchor_lang::solana_program::hash::hash(
-                        #discriminator_preimage.as_bytes(),
-                    ).to_bytes()[..8],
-                );
-
-                writer.write_all(&discriminator).map_err(|_| ProgramError::InvalidAccountData)?;
+                writer.write_all(&#discriminator).map_err(|_| ProgramError::InvalidAccountData)?;
                 AnchorSerialize::serialize(
                     self,
                     writer
@@ -62,18 +59,11 @@ pub fn account(
 
         impl anchor_lang::AccountDeserialize for #account_name {
             fn try_deserialize(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
-                let mut discriminator = [0u8; 8];
-                discriminator.copy_from_slice(
-                    &anchor_lang::solana_program::hash::hash(
-                        #discriminator_preimage.as_bytes(),
-                    ).to_bytes()[..8],
-                );
-
-                if buf.len() < discriminator.len() {
+                if buf.len() < #discriminator.len() {
                     return Err(ProgramError::AccountDataTooSmall);
                 }
                 let given_disc = &buf[..8];
-                if &discriminator != given_disc {
+                if &#discriminator != given_disc {
                     return Err(ProgramError::InvalidInstructionData);
                 }
                 Self::try_deserialize_unchecked(buf)
diff --git a/lang/syn/Cargo.toml b/lang/syn/Cargo.toml
index ba8b0b13..867afa38 100644
--- a/lang/syn/Cargo.toml
+++ b/lang/syn/Cargo.toml
@@ -9,6 +9,7 @@ edition = "2018"
 
 [features]
 idl = []
+hash = []
 default = []
 
 [dependencies]
@@ -19,3 +20,6 @@ anyhow = "1.0.32"
 heck = "0.3.1"
 serde = { version = "1.0.118", features = ["derive"] }
 serde_json = "1.0"
+sha2 = "0.9.2"
+thiserror = "1.0"
+bs58 = "0.3.1"
diff --git a/lang/syn/src/hash.rs b/lang/syn/src/hash.rs
new file mode 100644
index 00000000..a1d3f2be
--- /dev/null
+++ b/lang/syn/src/hash.rs
@@ -0,0 +1,139 @@
+// Utility hashing module copied from `solana_program::program::hash`, since we
+// can't import solana_program for compile time hashing for some reason.
+
+use serde::{Deserialize, Serialize};
+use sha2::{Digest, Sha256};
+use std::{convert::TryFrom, fmt, mem, str::FromStr};
+use thiserror::Error;
+
+pub const HASH_BYTES: usize = 32;
+#[derive(Serialize, Deserialize, Clone, Copy, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
+#[repr(transparent)]
+pub struct Hash(pub [u8; HASH_BYTES]);
+
+#[derive(Clone, Default)]
+pub struct Hasher {
+    hasher: Sha256,
+}
+
+impl Hasher {
+    pub fn hash(&mut self, val: &[u8]) {
+        self.hasher.update(val);
+    }
+    pub fn hashv(&mut self, vals: &[&[u8]]) {
+        for val in vals {
+            self.hash(val);
+        }
+    }
+    pub fn result(self) -> Hash {
+        // At the time of this writing, the sha2 library is stuck on an old version
+        // of generic_array (0.9.0). Decouple ourselves with a clone to our version.
+        Hash(<[u8; HASH_BYTES]>::try_from(self.hasher.finalize().as_slice()).unwrap())
+    }
+}
+
+impl AsRef<[u8]> for Hash {
+    fn as_ref(&self) -> &[u8] {
+        &self.0[..]
+    }
+}
+
+impl fmt::Debug for Hash {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "{}", bs58::encode(self.0).into_string())
+    }
+}
+
+impl fmt::Display for Hash {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "{}", bs58::encode(self.0).into_string())
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Error)]
+pub enum ParseHashError {
+    #[error("string decoded to wrong size for hash")]
+    WrongSize,
+    #[error("failed to decode string to hash")]
+    Invalid,
+}
+
+impl FromStr for Hash {
+    type Err = ParseHashError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let bytes = bs58::decode(s)
+            .into_vec()
+            .map_err(|_| ParseHashError::Invalid)?;
+        if bytes.len() != mem::size_of::<Hash>() {
+            Err(ParseHashError::WrongSize)
+        } else {
+            Ok(Hash::new(&bytes))
+        }
+    }
+}
+
+impl Hash {
+    pub fn new(hash_slice: &[u8]) -> Self {
+        Hash(<[u8; HASH_BYTES]>::try_from(hash_slice).unwrap())
+    }
+
+    pub const fn new_from_array(hash_array: [u8; HASH_BYTES]) -> Self {
+        Self(hash_array)
+    }
+
+    /// Unique Hash for tests and benchmarks.
+    pub fn new_unique() -> Self {
+        use std::sync::atomic::{AtomicU64, Ordering};
+        static I: AtomicU64 = AtomicU64::new(1);
+
+        let mut b = [0u8; HASH_BYTES];
+        let i = I.fetch_add(1, Ordering::Relaxed);
+        b[0..8].copy_from_slice(&i.to_le_bytes());
+        Self::new(&b)
+    }
+
+    pub fn to_bytes(self) -> [u8; HASH_BYTES] {
+        self.0
+    }
+}
+
+/// Return a Sha256 hash for the given data.
+pub fn hashv(vals: &[&[u8]]) -> Hash {
+    // Perform the calculation inline, calling this from within a program is
+    // not supported
+    #[cfg(not(target_arch = "bpf"))]
+    {
+        let mut hasher = Hasher::default();
+        hasher.hashv(vals);
+        hasher.result()
+    }
+    // Call via a system call to perform the calculation
+    #[cfg(target_arch = "bpf")]
+    {
+        extern "C" {
+            fn sol_sha256(vals: *const u8, val_len: u64, hash_result: *mut u8) -> u64;
+        };
+        let mut hash_result = [0; HASH_BYTES];
+        unsafe {
+            sol_sha256(
+                vals as *const _ as *const u8,
+                vals.len() as u64,
+                &mut hash_result as *mut _ as *mut u8,
+            );
+        }
+        Hash::new_from_array(hash_result)
+    }
+}
+
+/// Return a Sha256 hash for the given data.
+pub fn hash(val: &[u8]) -> Hash {
+    hashv(&[val])
+}
+
+/// Return the hash of the given hash extended with the given value.
+pub fn extend_and_hash(id: &Hash, val: &[u8]) -> Hash {
+    let mut hash_data = id.as_ref().to_vec();
+    hash_data.extend_from_slice(val);
+    hash(&hash_data)
+}
diff --git a/lang/syn/src/lib.rs b/lang/syn/src/lib.rs
index 93dff54f..5f59c725 100644
--- a/lang/syn/src/lib.rs
+++ b/lang/syn/src/lib.rs
@@ -9,6 +9,8 @@ use quote::quote;
 use std::collections::HashMap;
 
 pub mod codegen;
+#[cfg(feature = "hash")]
+pub mod hash;
 #[cfg(feature = "idl")]
 pub mod idl;
 pub mod parser;
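
Note: a quick way to sanity-check the discriminator this macro now bakes in at
compile time is to recompute it with the same sha2 crate the patch adds to
anchor-syn. The snippet below is an illustrative sketch, not part of the patch;
the struct name `MyAccount` is hypothetical.

    // Recompute the 8-byte account discriminator: the first 8 bytes of
    // sha256("account:<StructName>"), or sha256("<namespace>:<StructName>")
    // when a namespace is supplied, mirroring the proc macro above.
    use sha2::{Digest, Sha256};

    fn discriminator(namespace: &str, name: &str) -> [u8; 8] {
        let preimage = if namespace.is_empty() {
            format!("account:{}", name)
        } else {
            format!("{}:{}", namespace, name)
        };
        let digest = Sha256::digest(preimage.as_bytes());
        let mut disc = [0u8; 8];
        disc.copy_from_slice(&digest[..8]);
        disc
    }

    fn main() {
        // For a hypothetical `#[account] pub struct MyAccount { .. }`.
        println!("{:?}", discriminator("", "MyAccount"));
    }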