diff --git a/impl/src/attr.rs b/impl/src/attr.rs index 7db26b9..b74c834 100644 --- a/impl/src/attr.rs +++ b/impl/src/attr.rs @@ -1,20 +1,43 @@ use proc_macro2::{Delimiter, Group, TokenStream, TokenTree}; use quote::{format_ident, quote, ToTokens}; use std::iter::once; -use syn::parse::{Nothing, Parse, ParseStream}; +use syn::parse::{Nothing, ParseStream}; use syn::{ - braced, bracketed, parenthesized, token, Attribute, Error, Field, Ident, Index, LitInt, LitStr, - Result, Token, + braced, bracketed, parenthesized, token, Attribute, Ident, Index, LitInt, LitStr, Result, Token, }; +pub struct Attrs { + pub display: Option, + pub source: bool, +} + pub struct Display { pub fmt: LitStr, pub args: TokenStream, pub was_shorthand: bool, } -impl Parse for Display { - fn parse(input: ParseStream) -> Result { +pub fn get(input: &[Attribute]) -> Result { + let mut attrs = Attrs { + display: None, + source: false, + }; + + for attr in input { + if attr.path.is_ident("error") { + let display = parse_display(attr)?; + attrs.display = Some(display); + } else if attr.path.is_ident("source") { + parse_source(attr)?; + attrs.source = true; + } + } + + Ok(attrs) +} + +fn parse_display(attr: &Attribute) -> Result { + attr.parse_args_with(|input: ParseStream| { let mut display = Display { fmt: input.parse()?, args: parse_token_expr(input, false)?, @@ -22,7 +45,7 @@ impl Parse for Display { }; display.expand_shorthand(); Ok(display) - } + }) } fn parse_token_expr(input: ParseStream, mut last_is_comma: bool) -> Result { @@ -73,6 +96,11 @@ fn parse_token_expr(input: ParseStream, mut last_is_comma: bool) -> Result Result<()> { + syn::parse2::(attr.tokens.clone())?; + Ok(()) +} + impl ToTokens for Display { fn to_tokens(&self, tokens: &mut TokenStream) { let fmt = &self.fmt; @@ -94,31 +122,3 @@ impl ToTokens for Display { } } } - -pub fn is_source(field: &Field) -> Result { - for attr in &field.attrs { - if attr.path.is_ident("source") { - syn::parse2::(attr.tokens.clone())?; - return Ok(true); - } - } - Ok(false) -} - -pub fn display(attrs: &[Attribute]) -> Result> { - let mut display = None; - - for attr in attrs { - if attr.path.is_ident("error") { - if display.is_some() { - return Err(Error::new_spanned( - attr, - "only one #[error(...)] attribute is allowed", - )); - } - display = Some(attr.parse_args()?); - } - } - - Ok(display) -} diff --git a/impl/src/expand.rs b/impl/src/expand.rs index 2506d9b..5a775c8 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -22,18 +22,7 @@ fn impl_struct(input: &DeriveInput, data: &DataStruct) -> Result { let ty = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let source = match &data.fields { - Fields::Named(fields) => source_member(&fields.named)?, - Fields::Unnamed(fields) => source_member(&fields.unnamed)?, - Fields::Unit => None, - }; - - let backtrace = match &data.fields { - Fields::Named(fields) => backtrace_member(&fields.named)?, - Fields::Unnamed(fields) => backtrace_member(&fields.unnamed)?, - Fields::Unit => None, - }; - + let source = source_member(&data.fields)?; let source_method = source.map(|source| { let member = quote_spanned!(source.span()=> self.#source); quote! { @@ -44,6 +33,7 @@ fn impl_struct(input: &DeriveInput, data: &DataStruct) -> Result { } }); + let backtrace = backtrace_member(&data.fields)?; let backtrace_method = backtrace.map(|backtrace| { quote! { fn backtrace(&self) -> std::option::Option<&std::backtrace::Backtrace> { @@ -52,7 +42,8 @@ fn impl_struct(input: &DeriveInput, data: &DataStruct) -> Result { } }); - let display = attr::display(&input.attrs)?.map(|display| { + let struct_attrs = attr::get(&input.attrs)?; + let display = struct_attrs.display.map(|display| { let pat = match &data.fields { Fields::Named(fields) => { let var = fields.named.iter().map(|field| &field.ident); @@ -88,25 +79,23 @@ fn impl_enum(input: &DeriveInput, data: &DataEnum) -> Result { let ty = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let sources = data + let variant_fields: Vec<&Fields> = data .variants .iter() - .map(|variant| match &variant.fields { - Fields::Named(fields) => source_member(&fields.named), - Fields::Unnamed(fields) => source_member(&fields.unnamed), - Fields::Unit => Ok(None), - }) - .collect::>>()?; + .map(|variant| &variant.fields) + .collect(); - let backtraces = data - .variants + let sources: Vec> = variant_fields .iter() - .map(|variant| match &variant.fields { - Fields::Named(fields) => backtrace_member(&fields.named), - Fields::Unnamed(fields) => backtrace_member(&fields.unnamed), - Fields::Unit => Ok(None), - }) - .collect::>>()?; + .copied() + .map(source_member) + .collect::>()?; + + let backtraces: Vec> = variant_fields + .iter() + .copied() + .map(backtrace_member) + .collect::>()?; let source_method = if sources.iter().any(Option::is_some) { let arms = data.variants.iter().zip(sources).map(|(variant, source)| { @@ -155,18 +144,18 @@ fn impl_enum(input: &DeriveInput, data: &DataEnum) -> Result { None }; - let displays = data + let variant_attrs = data .variants .iter() - .map(|variant| attr::display(&variant.attrs)) + .map(|variant| attr::get(&variant.attrs)) .collect::>>()?; - let display = if displays.iter().any(Option::is_some) { + let display = if variant_attrs.iter().any(|attrs| attrs.display.is_some()) { let arms = data .variants .iter() - .zip(displays) - .map(|(variant, display)| { - let display = display.ok_or_else(|| { + .zip(variant_attrs) + .map(|(variant, attrs)| { + let display = attrs.display.ok_or_else(|| { Error::new_spanned(variant, "missing #[error(\"...\")] display attribute") })?; let ident = &variant.ident; @@ -208,7 +197,8 @@ fn impl_enum(input: &DeriveInput, data: &DataEnum) -> Result { fn source_member<'a>(fields: impl IntoIterator) -> Result> { for (i, field) in fields.into_iter().enumerate() { - if attr::is_source(field)? { + let attrs = attr::get(&field.attrs)?; + if attrs.source { return Ok(Some(member(i, &field.ident))); } }