// anchor/lang/syn/src/idl/pda.rs (370 lines, 13 KiB, Rust)

use crate::idl::*;
use crate::parser;
use crate::parser::context::CrateContext;
use crate::ConstraintSeedsGroup;
use crate::{AccountsStruct, Field};
use std::collections::HashMap;
use std::str::FromStr;
use syn::{Expr, ExprLit, Lit};
// Parses a seeds constraint, extracting the IdlSeed types.
//
// Note: This implementation makes assumptions about the types that can be used
// (e.g., no program-defined function calls in seeds).
//
// This probably doesn't cover all cases. If you see a warning log, you
// can add a new case here. In the worst case, we miss a seed and
// the parser will treat the given seeds as empty and so clients will
// simply fail to automatically populate the PDA accounts.
//
// Seed Assumptions: Seeds must be of one of the following forms:
//
// - instruction argument.
// - account context field pubkey.
// - account data, where the account is defined in the current program.
// We make an exception for the SPL token program, since it is so common
// and sometimes convenient to use fields as a seed (e.g. Auction house
// program). In the case of nested structs/account data, all nested structs
// must be defined in the current program as well.
// - byte string literal (e.g. b"MY_SEED").
// - byte string literal constant (e.g. `pub const MY_SEED: [u8; 2] = *b"hi";`).
// - array constants.
//
pub fn parse(
    ctx: &CrateContext,
    accounts: &AccountsStruct,
    acc: &Field,
    seeds_feature: bool,
) -> Option<IdlPda> {
    // PDA metadata is only produced when the seeds feature is switched on.
    if !seeds_feature {
        return None;
    }
    let pda_parser = PdaParser::new(ctx, accounts);
    // A field without a `seeds` constraint has no PDA description.
    let seeds_grp = acc.constraints.seeds.as_ref()?;
    pda_parser.parse(seeds_grp)
}
// Resolves the seed expressions of a single accounts context into IDL seeds.
struct PdaParser<'a> {
    // Crate-wide parse context (source of constants and struct definitions).
    ctx: &'a CrateContext,
    // Accounts context.
    accounts: &'a AccountsStruct,
    // Maps var name to var type. These are the instruction arguments in a
    // given accounts context.
    ix_args: HashMap<String, String>,
    // Constants available in the crate.
    const_names: Vec<String>,
    // Constants declared in impl blocks available in the crate, spelled
    // `Type :: CONST` so they compare equal to a stringified seed expression.
    impl_const_names: Vec<String>,
    // All field names of the accounts in the accounts context.
    account_field_names: Vec<String>,
}

impl<'a> PdaParser<'a> {
    // Collects every name a seed expression may refer to.
    fn new(ctx: &'a CrateContext, accounts: &'a AccountsStruct) -> Self {
        // All the available sources of seeds.
        let ix_args = accounts.instruction_args().unwrap_or_default();
        let const_names: Vec<String> = ctx.consts().map(|c| c.ident.to_string()).collect();
        let impl_const_names: Vec<String> = ctx
            .impl_consts()
            .map(|(ident, item)| format!("{} :: {}", ident, item.ident))
            .collect();
        let account_field_names = accounts.field_names();
        Self {
            ctx,
            accounts,
            ix_args,
            const_names,
            impl_const_names,
            account_field_names,
        }
    }

    // Converts a seeds constraint group into its IDL description.
    //
    // Returns `None` as soon as any seed, or the `seeds::program` override,
    // cannot be resolved.
    fn parse(&self, seeds_grp: &ConstraintSeedsGroup) -> Option<IdlPda> {
        // Extract the idl seed types from the constraints.
        let seeds = seeds_grp
            .seeds
            .iter()
            .map(|s| self.parse_seed(s))
            .collect::<Option<Vec<_>>>()?;
        // Parse the program id from the constraints. A program seed that is
        // present but unparseable must fail the whole PDA: emitting the PDA
        // without it would describe an address derived from the wrong program.
        let program_id = match seeds_grp.program_seed.as_ref() {
            Some(pid) => Some(self.parse_seed(pid)?),
            None => None,
        };
        Some(IdlPda { seeds, program_id })
    }

    // Classifies one seed expression and dispatches to the matching parser.
    // Unsupported expressions print a warning and yield `None`.
    fn parse_seed(&self, seed: &Expr) -> Option<IdlSeed> {
        match seed {
            Expr::MethodCall(_) => {
                let seed_path = parse_seed_path(seed)?;
                if self.is_instruction(&seed_path) {
                    self.parse_instruction(&seed_path)
                } else if self.is_const(&seed_path) {
                    self.parse_const(&seed_path)
                } else if self.is_impl_const(&seed_path) {
                    self.parse_impl_const(&seed_path)
                } else if self.is_account(&seed_path) {
                    self.parse_account(&seed_path)
                } else if self.is_str_literal(&seed_path) {
                    self.parse_str_literal(&seed_path)
                } else {
                    println!("WARNING: unexpected seed category for var: {:?}", seed_path);
                    None
                }
            }
            // `&expr` is parsed as `expr`.
            Expr::Reference(expr_reference) => self.parse_seed(&expr_reference.expr),
            Expr::Index(_) => {
                println!("WARNING: auto pda derivation not currently supported for slice literals");
                None
            }
            Expr::Lit(ExprLit {
                lit: Lit::ByteStr(lit_byte_str),
                ..
            }) => {
                let seed_path: SeedPath = SeedPath(lit_byte_str.token().to_string(), Vec::new());
                self.parse_str_literal(&seed_path)
            }
            // Unknown type. Please file an issue.
            _ => {
                println!("WARNING: unexpected seed: {:?}", seed);
                None
            }
        }
    }

    // Seed taken from an instruction argument.
    fn parse_instruction(&self, seed_path: &SeedPath) -> Option<IdlSeed> {
        let arg_ty = self.ix_args.get(&seed_path.name())?;
        let idl_ty = IdlType::from_str(arg_ty).ok()?;
        Some(IdlSeed::Arg(IdlSeedArg {
            ty: idl_ty,
            path: seed_path.path(),
        }))
    }

    // Seed taken from a crate-level constant; its value is inlined in the IDL.
    fn parse_const(&self, seed_path: &SeedPath) -> Option<IdlSeed> {
        if !seed_path.components().is_empty() {
            println!(
                "WARNING: field access on constant seeds isn't supported: {}",
                seed_path.path()
            );
            return None;
        }
        let name = seed_path.name();
        let const_item = self.ctx.consts().find(|c| c.ident == name)?;
        Self::const_seed(
            parser::tts_to_string(&const_item.ty),
            parser::tts_to_string(&const_item.expr),
        )
    }

    // Seed taken from a constant declared in an impl block; its value is
    // inlined in the IDL.
    fn parse_impl_const(&self, seed_path: &SeedPath) -> Option<IdlSeed> {
        if !seed_path.components().is_empty() {
            println!(
                "WARNING: field access on constant seeds isn't supported: {}",
                seed_path.path()
            );
            return None;
        }
        let name = seed_path.name();
        let const_item = self
            .ctx
            .impl_consts()
            .find(|(ident, item)| format!("{} :: {}", ident, item.ident) == name)?
            .1;
        Self::const_seed(
            parser::tts_to_string(&const_item.ty),
            parser::tts_to_string(&const_item.expr),
        )
    }

    // Builds a constant seed from the stringified type and initializer of a
    // constant. Warns and returns `None` (instead of panicking) when the
    // initializer is not a literal that can be read as JSON.
    fn const_seed(ty_tokens: String, value_tokens: String) -> Option<IdlSeed> {
        let idl_ty = IdlType::from_str(&ty_tokens).ok()?;
        let idl_ty_value = str_lit_to_array(&idl_ty, &value_tokens);
        let value: serde_json::Value = match serde_json::from_str(&idl_ty_value) {
            Ok(value) => value,
            Err(err) => {
                println!(
                    "WARNING: unable to evaluate constant seed value `{}`: {}",
                    idl_ty_value, err
                );
                return None;
            }
        };
        Some(IdlSeed::Const(IdlSeedConst { ty: idl_ty, value }))
    }

    // Seed taken from an account in the accounts context: either the account
    // itself (its pubkey) or a field, possibly nested, of its data.
    fn parse_account(&self, seed_path: &SeedPath) -> Option<IdlSeed> {
        // Get the anchor account field from the derive accounts struct.
        let name = seed_path.name();
        let account_field = self
            .accounts
            .fields
            .iter()
            .find(|field| *field.ident() == name)?;
        // Follow the path to find the seed type.
        let mut path = seed_path.components();
        let ty = if path.is_empty() {
            // No sub field accessed: the seed is the account's pubkey.
            IdlType::PublicKey
        } else {
            // Name of the account struct.
            let account = account_field.ty_name()?;
            if account == "TokenAccount" {
                match path {
                    [field] => match field.as_str() {
                        "mint" => IdlType::PublicKey,
                        "amount" => IdlType::U64,
                        "authority" => IdlType::PublicKey,
                        "delegated_amount" => IdlType::U64,
                        _ => {
                            println!("WARNING: token field isn't supported: {}", field);
                            return None;
                        }
                    },
                    _ => {
                        println!(
                            "WARNING: nested token fields aren't supported: {}",
                            seed_path.path()
                        );
                        return None;
                    }
                }
            } else {
                // Get the rust representation of the field's struct. It is
                // only available when the account is defined in this crate.
                let strct = match self.ctx.structs().find(|s| s.ident == account) {
                    Some(strct) => strct,
                    None => {
                        println!(
                            "WARNING: account type isn't defined in the current program: {}",
                            account
                        );
                        return None;
                    }
                };
                // Walks every remaining path component, so nested account
                // data is resolved as well.
                parse_field_path(self.ctx, strct, &mut path)
            }
        };
        Some(IdlSeed::Account(IdlSeedAccount {
            ty,
            account: account_field.ty_name(),
            path: seed_path.path(),
        }))
    }

    // Seed given as a (byte) string literal; emitted as a string constant.
    fn parse_str_literal(&self, seed_path: &SeedPath) -> Option<IdlSeed> {
        let literal = seed_path.name();
        // Skip the byte `b` prefix if the string is of the form `b"seed"`.
        let unprefixed = if literal.starts_with("b\"") {
            &literal[1..]
        } else {
            literal.as_str()
        };
        let value_string: String = unprefixed.chars().filter(|c| *c != '"').collect();
        Some(IdlSeed::Const(IdlSeedConst {
            value: serde_json::Value::String(value_string),
            ty: IdlType::String,
        }))
    }

    // True if the seed's root name is an instruction argument.
    fn is_instruction(&self, seed_path: &SeedPath) -> bool {
        self.ix_args.contains_key(&seed_path.name())
    }

    // True if the seed's root name is a crate-level constant.
    fn is_const(&self, seed_path: &SeedPath) -> bool {
        self.const_names.contains(&seed_path.name())
    }

    // True if the seed's root name is a constant declared in an impl block.
    fn is_impl_const(&self, seed_path: &SeedPath) -> bool {
        self.impl_const_names.contains(&seed_path.name())
    }

    // True if the seed's root name is a field of the accounts context.
    fn is_account(&self, seed_path: &SeedPath) -> bool {
        self.account_field_names.contains(&seed_path.name())
    }

    // True if the seed looks like a quoted literal with no sub fields.
    fn is_str_literal(&self, seed_path: &SeedPath) -> bool {
        seed_path.components().is_empty() && seed_path.name().contains('"')
    }
}
// A single pda seed expression split into its root variable name (field 0)
// and the chain of sub fields accessed on it (field 1). For the seed
// `my_field.my_data.as_ref()` the name is `my_field` and the sub fields are
// `[my_data]`.
#[derive(Debug)]
struct SeedPath(String, Vec<String>);

impl SeedPath {
    // Owned copy of the root variable (or field) name.
    fn name(&self) -> String {
        self.0.to_owned()
    }

    // Dot-separated path to the data this seed represents: the name followed
    // by every sub field, or just the name when there are no sub fields.
    fn path(&self) -> String {
        std::iter::once(self.0.as_str())
            .chain(self.1.iter().map(String::as_str))
            .collect::<Vec<_>>()
            .join(".")
    }

    // The sub fields accessed on the root name, outermost first.
    fn components(&self) -> &[String] {
        self.1.as_slice()
    }
}
// Splits a single seed expression into its root name and the sub fields
// accessed on it. Returns `None` (with a warning) when the expression has no
// ` . ` separated accessor at all.
fn parse_seed_path(seed: &Expr) -> Option<SeedPath> {
    // Work on the raw token-string form of the seed, where accessors are
    // separated by " . ".
    let seed_str = parser::tts_to_string(&seed);
    let segments: Vec<&str> = seed_str.split(" . ").collect();
    // First segment is the variable (or field) name; at least one accessor
    // must follow it.
    let (name, accessors) = match segments.split_first() {
        Some((name, accessors)) if !accessors.is_empty() => (name, accessors),
        _ => {
            println!("WARNING: seeds are in an unexpected format: {:?}", seed);
            return None;
        }
    };
    // Field accesses up to (not including) the first method call form the
    // path to the seed (only meaningful if the `name` type is a struct).
    let mut path: Vec<String> = accessors
        .iter()
        .copied()
        .take_while(|segment| !segment.contains("()"))
        .map(str::to_string)
        .collect();
    // A lone `key` accessor refers to the account itself, not to its data.
    if path.len() == 1 && (path[0] == "key" || path[0] == "key()") {
        path.clear();
    }
    Some(SeedPath(name.to_string(), path))
}
// Resolves the type of the field reached by following `path` from `strct`.
//
// Consumes the first component of `path`, looks up the field of that name on
// `strct`, and either parses the field's type (last component) or recurses
// into the struct of that type as found via `ctx.structs()`.
//
// Panics if `path` is empty, if a field or an intermediate struct cannot be
// found, or if the final field type does not parse as an `IdlType`.
fn parse_field_path(ctx: &CrateContext, strct: &syn::ItemStruct, path: &mut &[String]) -> IdlType {
    let field_name = &path[0];
    *path = &path[1..];
    // Get the type name for the field. Unnamed (tuple) fields never match.
    let next_field = strct
        .fields
        .iter()
        .find(|f| f.ident.as_ref().map_or(false, |ident| ident == field_name))
        .unwrap_or_else(|| {
            panic!(
                "seed field `{}` not found on struct `{}`",
                field_name, strct.ident
            )
        });
    let next_field_ty_str = parser::tts_to_string(&next_field.ty);
    // The path is empty so this must be a primitive type.
    if path.is_empty() {
        return next_field_ty_str.parse().unwrap_or_else(|_| {
            panic!(
                "unsupported type `{}` for seed field `{}`",
                next_field_ty_str, field_name
            )
        });
    }
    // Get the rust representation of the field's struct.
    let strct = ctx
        .structs()
        .find(|s| s.ident == next_field_ty_str)
        .unwrap_or_else(|| {
            panic!(
                "struct `{}` (type of seed field `{}`) is not defined in the current program",
                next_field_ty_str, field_name
            )
        });
    parse_field_path(ctx, strct, path)
}
// Rewrites the stringified initializer of a constant so it can be read as
// JSON: when `idl_ty` is an array and the initializer contains a byte string
// literal (e.g. `* b"hi"`), the result is the byte array in `[104, 105]`
// form. Every other input is returned unchanged.
//
// Everything after the `b` prefix is taken as the literal with all `"`
// characters dropped; escape sequences are not decoded.
fn str_lit_to_array(idl_ty: &IdlType, idl_ty_value: &str) -> String {
    if let IdlType::Array(..) = idl_ty {
        // Locate the literal by its `b"` prefix rather than splitting on
        // every `b`, so literals that themselves contain the letter `b`
        // (e.g. `b"bob"`) are converted instead of tripping an assertion.
        if let Some(prefix_at) = idl_ty_value.find("b\"") {
            let str_lit: String = idl_ty_value[prefix_at + 1..]
                .chars()
                .filter(|c| *c != '"')
                .collect();
            return format!("{:?}", str_lit.as_bytes());
        }
    }
    idl_ty_value.to_string()
}