diff --git a/CHANGELOG.md b/CHANGELOG.md index aba6fd34..d83440ab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The minor version will be incremented upon a breaking change and the patch versi * client: Add `transaction` functions to RequestBuilder ([#1958](https://github.com/coral-xyz/anchor/pull/1958)). * spl: Add `create_metadata_accounts_v3` and `set_collection_size` wrappers ([#2119](https://github.com/coral-xyz/anchor/pull/2119)) * spl: Add `MetadataAccount` account deserialization. ([#2014](https://github.com/coral-xyz/anchor/pull/2014)). +* lang: Add parsing for consts from impl blocks for IDL PDA seeds generation ([#2128](https://github.com/coral-xyz/anchor/pull/2014)) ### Fixes diff --git a/lang/syn/src/idl/mod.rs b/lang/syn/src/idl/mod.rs index 3d3dd12b..b9220ef2 100644 --- a/lang/syn/src/idl/mod.rs +++ b/lang/syn/src/idl/mod.rs @@ -234,7 +234,7 @@ impl std::str::FromStr for IdlType { "u128" => IdlType::U128, "i128" => IdlType::I128, "Vec" => IdlType::Bytes, - "String" | "&str" => IdlType::String, + "String" | "&str" | "&'staticstr" => IdlType::String, "Pubkey" => IdlType::PublicKey, _ => match s.to_string().strip_prefix("Option<") { None => match s.to_string().strip_prefix("Vec<") { diff --git a/lang/syn/src/idl/pda.rs b/lang/syn/src/idl/pda.rs index 81879c82..1a34e377 100644 --- a/lang/syn/src/idl/pda.rs +++ b/lang/syn/src/idl/pda.rs @@ -56,6 +56,8 @@ struct PdaParser<'a> { ix_args: HashMap, // Constants available in the crate. const_names: Vec, + // Constants declared in impl blocks available in the crate + impl_const_names: Vec, // All field names of the accounts in the accounts context. account_field_names: Vec, } @@ -65,6 +67,12 @@ impl<'a> PdaParser<'a> { // All the available sources of seeds. let ix_args = accounts.instruction_args().unwrap_or_default(); let const_names: Vec = ctx.consts().map(|c| c.ident.to_string()).collect(); + + let impl_const_names: Vec = ctx + .impl_consts() + .map(|(ident, item)| format!("{} :: {}", ident, item.ident)) + .collect(); + let account_field_names = accounts.field_names(); Self { @@ -72,6 +80,7 @@ impl<'a> PdaParser<'a> { accounts, ix_args, const_names, + impl_const_names, account_field_names, } } @@ -83,7 +92,6 @@ impl<'a> PdaParser<'a> { .iter() .map(|s| self.parse_seed(s)) .collect::>>()?; - // Parse the program id from the constraints. let program_id = seeds_grp .program_seed @@ -104,6 +112,8 @@ impl<'a> PdaParser<'a> { self.parse_instruction(&seed_path) } else if self.is_const(&seed_path) { self.parse_const(&seed_path) + } else if self.is_impl_const(&seed_path) { + self.parse_impl_const(&seed_path) } else if self.is_account(&seed_path) { self.parse_account(&seed_path) } else if self.is_str_literal(&seed_path) { @@ -150,18 +160,30 @@ impl<'a> PdaParser<'a> { .find(|c| c.ident == seed_path.name()) .unwrap(); let idl_ty = IdlType::from_str(&parser::tts_to_string(&const_item.ty)).ok()?; - let mut idl_ty_value = parser::tts_to_string(&const_item.expr); - if let IdlType::Array(_ty, _size) = &idl_ty { - // Convert str literal to array. - if idl_ty_value.contains("b\"") { - let components: Vec<&str> = idl_ty_value.split('b').collect(); - assert!(components.len() == 2); - let mut str_lit = components[1].to_string(); - str_lit.retain(|c| c != '"'); - idl_ty_value = format!("{:?}", str_lit.as_bytes()); - } - } + let idl_ty_value = parser::tts_to_string(&const_item.expr); + let idl_ty_value = str_lit_to_array(&idl_ty, &idl_ty_value); + + Some(IdlSeed::Const(IdlSeedConst { + ty: idl_ty, + value: serde_json::from_str(&idl_ty_value).unwrap(), + })) + } + + fn parse_impl_const(&self, seed_path: &SeedPath) -> Option { + // Pull in the constant value directly into the IDL. + assert!(seed_path.components().is_empty()); + let static_item = self + .ctx + .impl_consts() + .find(|(ident, item)| format!("{} :: {}", ident, item.ident) == seed_path.name()) + .unwrap() + .1; + + let idl_ty = IdlType::from_str(&parser::tts_to_string(&static_item.ty)).ok()?; + + let idl_ty_value = parser::tts_to_string(&static_item.expr); + let idl_ty_value = str_lit_to_array(&idl_ty, &idl_ty_value); Some(IdlSeed::Const(IdlSeedConst { ty: idl_ty, @@ -236,6 +258,10 @@ impl<'a> PdaParser<'a> { self.const_names.contains(&seed_path.name()) } + fn is_impl_const(&self, seed_path: &SeedPath) -> bool { + self.impl_const_names.contains(&seed_path.name()) + } + fn is_account(&self, seed_path: &SeedPath) -> bool { self.account_field_names.contains(&seed_path.name()) } @@ -327,3 +353,17 @@ fn parse_field_path(ctx: &CrateContext, strct: &syn::ItemStruct, path: &mut &[St parse_field_path(ctx, strct, path) } + +fn str_lit_to_array(idl_ty: &IdlType, idl_ty_value: &String) -> String { + if let IdlType::Array(_ty, _size) = &idl_ty { + // Convert str literal to array. + if idl_ty_value.contains("b\"") { + let components: Vec<&str> = idl_ty_value.split('b').collect(); + assert_eq!(components.len(), 2); + let mut str_lit = components[1].to_string(); + str_lit.retain(|c| c != '"'); + return format!("{:?}", str_lit.as_bytes()); + } + } + idl_ty_value.to_string() +} diff --git a/lang/syn/src/parser/context.rs b/lang/syn/src/parser/context.rs index 28cfdee1..8e0fb2a3 100644 --- a/lang/syn/src/parser/context.rs +++ b/lang/syn/src/parser/context.rs @@ -2,6 +2,7 @@ use anyhow::anyhow; use std::collections::BTreeMap; use std::path::{Path, PathBuf}; use syn::parse::{Error as ParseError, Result as ParseResult}; +use syn::{Ident, ImplItem, ImplItemConst, Type, TypePath}; /// Crate parse context /// @@ -15,6 +16,10 @@ impl CrateContext { self.modules.iter().flat_map(|(_, ctx)| ctx.consts()) } + pub fn impl_consts(&self) -> impl Iterator { + self.modules.iter().flat_map(|(_, ctx)| ctx.impl_consts()) + } + pub fn structs(&self) -> impl Iterator { self.modules.iter().flat_map(|(_, ctx)| ctx.structs()) } @@ -244,4 +249,36 @@ impl ParsedModule { _ => None, }) } + + fn impl_consts(&self) -> impl Iterator { + self.items + .iter() + .filter_map(|i| match i { + syn::Item::Impl(syn::ItemImpl { + self_ty: ty, items, .. + }) => { + if let Type::Path(TypePath { + qself: None, + path: p, + }) = ty.as_ref() + { + if let Some(ident) = p.get_ident() { + let mut to_return = Vec::new(); + items.iter().for_each(|item| { + if let ImplItem::Const(item) = item { + to_return.push((ident, item)); + } + }); + Some(to_return) + } else { + None + } + } else { + None + } + } + _ => None, + }) + .flatten() + } }