Uniformly parse all attributes in all positions

This commit is contained in:
David Tolnay 2019-10-11 15:16:18 -07:00
parent 31ecbd2410
commit 3d43d39ed1
No known key found for this signature in database
GPG Key ID: F9BA143B95FF6D82
2 changed files with 59 additions and 69 deletions

View File

@ -1,20 +1,43 @@
use proc_macro2::{Delimiter, Group, TokenStream, TokenTree};
use quote::{format_ident, quote, ToTokens};
use std::iter::once;
use syn::parse::{Nothing, Parse, ParseStream};
use syn::parse::{Nothing, ParseStream};
use syn::{
braced, bracketed, parenthesized, token, Attribute, Error, Field, Ident, Index, LitInt, LitStr,
Result, Token,
braced, bracketed, parenthesized, token, Attribute, Ident, Index, LitInt, LitStr, Result, Token,
};
pub struct Attrs {
pub display: Option<Display>,
pub source: bool,
}
pub struct Display {
pub fmt: LitStr,
pub args: TokenStream,
pub was_shorthand: bool,
}
impl Parse for Display {
fn parse(input: ParseStream) -> Result<Self> {
pub fn get(input: &[Attribute]) -> Result<Attrs> {
let mut attrs = Attrs {
display: None,
source: false,
};
for attr in input {
if attr.path.is_ident("error") {
let display = parse_display(attr)?;
attrs.display = Some(display);
} else if attr.path.is_ident("source") {
parse_source(attr)?;
attrs.source = true;
}
}
Ok(attrs)
}
fn parse_display(attr: &Attribute) -> Result<Display> {
attr.parse_args_with(|input: ParseStream| {
let mut display = Display {
fmt: input.parse()?,
args: parse_token_expr(input, false)?,
@ -22,7 +45,7 @@ impl Parse for Display {
};
display.expand_shorthand();
Ok(display)
}
})
}
fn parse_token_expr(input: ParseStream, mut last_is_comma: bool) -> Result<TokenStream> {
@ -73,6 +96,11 @@ fn parse_token_expr(input: ParseStream, mut last_is_comma: bool) -> Result<Token
Ok(tokens)
}
fn parse_source(attr: &Attribute) -> Result<()> {
syn::parse2::<Nothing>(attr.tokens.clone())?;
Ok(())
}
impl ToTokens for Display {
fn to_tokens(&self, tokens: &mut TokenStream) {
let fmt = &self.fmt;
@ -94,31 +122,3 @@ impl ToTokens for Display {
}
}
}
pub fn is_source(field: &Field) -> Result<bool> {
for attr in &field.attrs {
if attr.path.is_ident("source") {
syn::parse2::<Nothing>(attr.tokens.clone())?;
return Ok(true);
}
}
Ok(false)
}
pub fn display(attrs: &[Attribute]) -> Result<Option<Display>> {
let mut display = None;
for attr in attrs {
if attr.path.is_ident("error") {
if display.is_some() {
return Err(Error::new_spanned(
attr,
"only one #[error(...)] attribute is allowed",
));
}
display = Some(attr.parse_args()?);
}
}
Ok(display)
}

View File

@ -22,18 +22,7 @@ fn impl_struct(input: &DeriveInput, data: &DataStruct) -> Result<TokenStream> {
let ty = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let source = match &data.fields {
Fields::Named(fields) => source_member(&fields.named)?,
Fields::Unnamed(fields) => source_member(&fields.unnamed)?,
Fields::Unit => None,
};
let backtrace = match &data.fields {
Fields::Named(fields) => backtrace_member(&fields.named)?,
Fields::Unnamed(fields) => backtrace_member(&fields.unnamed)?,
Fields::Unit => None,
};
let source = source_member(&data.fields)?;
let source_method = source.map(|source| {
let member = quote_spanned!(source.span()=> self.#source);
quote! {
@ -44,6 +33,7 @@ fn impl_struct(input: &DeriveInput, data: &DataStruct) -> Result<TokenStream> {
}
});
let backtrace = backtrace_member(&data.fields)?;
let backtrace_method = backtrace.map(|backtrace| {
quote! {
fn backtrace(&self) -> std::option::Option<&std::backtrace::Backtrace> {
@ -52,7 +42,8 @@ fn impl_struct(input: &DeriveInput, data: &DataStruct) -> Result<TokenStream> {
}
});
let display = attr::display(&input.attrs)?.map(|display| {
let struct_attrs = attr::get(&input.attrs)?;
let display = struct_attrs.display.map(|display| {
let pat = match &data.fields {
Fields::Named(fields) => {
let var = fields.named.iter().map(|field| &field.ident);
@ -88,25 +79,23 @@ fn impl_enum(input: &DeriveInput, data: &DataEnum) -> Result<TokenStream> {
let ty = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let sources = data
let variant_fields: Vec<&Fields> = data
.variants
.iter()
.map(|variant| match &variant.fields {
Fields::Named(fields) => source_member(&fields.named),
Fields::Unnamed(fields) => source_member(&fields.unnamed),
Fields::Unit => Ok(None),
})
.collect::<Result<Vec<_>>>()?;
.map(|variant| &variant.fields)
.collect();
let backtraces = data
.variants
let sources: Vec<Option<Member>> = variant_fields
.iter()
.map(|variant| match &variant.fields {
Fields::Named(fields) => backtrace_member(&fields.named),
Fields::Unnamed(fields) => backtrace_member(&fields.unnamed),
Fields::Unit => Ok(None),
})
.collect::<Result<Vec<_>>>()?;
.copied()
.map(source_member)
.collect::<Result<_>>()?;
let backtraces: Vec<Option<Member>> = variant_fields
.iter()
.copied()
.map(backtrace_member)
.collect::<Result<_>>()?;
let source_method = if sources.iter().any(Option::is_some) {
let arms = data.variants.iter().zip(sources).map(|(variant, source)| {
@ -155,18 +144,18 @@ fn impl_enum(input: &DeriveInput, data: &DataEnum) -> Result<TokenStream> {
None
};
let displays = data
let variant_attrs = data
.variants
.iter()
.map(|variant| attr::display(&variant.attrs))
.map(|variant| attr::get(&variant.attrs))
.collect::<Result<Vec<_>>>()?;
let display = if displays.iter().any(Option::is_some) {
let display = if variant_attrs.iter().any(|attrs| attrs.display.is_some()) {
let arms = data
.variants
.iter()
.zip(displays)
.map(|(variant, display)| {
let display = display.ok_or_else(|| {
.zip(variant_attrs)
.map(|(variant, attrs)| {
let display = attrs.display.ok_or_else(|| {
Error::new_spanned(variant, "missing #[error(\"...\")] display attribute")
})?;
let ident = &variant.ident;
@ -208,7 +197,8 @@ fn impl_enum(input: &DeriveInput, data: &DataEnum) -> Result<TokenStream> {
fn source_member<'a>(fields: impl IntoIterator<Item = &'a Field>) -> Result<Option<Member>> {
for (i, field) in fields.into_iter().enumerate() {
if attr::is_source(field)? {
let attrs = attr::get(&field.attrs)?;
if attrs.source {
return Ok(Some(member(i, &field.ident)));
}
}