diff --git a/impl/src/ast.rs b/impl/src/ast.rs new file mode 100644 index 0000000..0d15388 --- /dev/null +++ b/impl/src/ast.rs @@ -0,0 +1,111 @@ +use crate::attr::{self, Attrs}; +use syn::{ + Data, DataEnum, DataStruct, DeriveInput, Error, Fields, Generics, Ident, Index, Member, Result, + Type, +}; + +pub enum Input<'a> { + Struct(Struct<'a>), + Enum(Enum<'a>), +} + +pub struct Struct<'a> { + pub attrs: Attrs, + pub ident: Ident, + pub generics: &'a Generics, + pub fields: Vec>, +} + +pub struct Enum<'a> { + pub attrs: Attrs, + pub ident: Ident, + pub generics: &'a Generics, + pub variants: Vec>, +} + +pub struct Variant<'a> { + pub original: &'a syn::Variant, + pub attrs: Attrs, + pub ident: Ident, + pub fields: Vec>, +} + +pub struct Field<'a> { + pub original: &'a syn::Field, + pub attrs: Attrs, + pub member: Member, + pub ty: &'a Type, +} + +impl<'a> Input<'a> { + pub fn from_syn(node: &'a DeriveInput) -> Result { + match &node.data { + Data::Struct(data) => Struct::from_syn(node, data).map(Input::Struct), + Data::Enum(data) => Enum::from_syn(node, data).map(Input::Enum), + Data::Union(_) => Err(Error::new_spanned( + node, + "union as errors are not supported", + )), + } + } +} + +impl<'a> Struct<'a> { + fn from_syn(node: &'a DeriveInput, data: &'a DataStruct) -> Result { + Ok(Struct { + attrs: attr::get(&node.attrs)?, + ident: node.ident.clone(), + generics: &node.generics, + fields: Field::multiple_from_syn(&data.fields)?, + }) + } +} + +impl<'a> Enum<'a> { + fn from_syn(node: &'a DeriveInput, data: &'a DataEnum) -> Result { + Ok(Enum { + attrs: attr::get(&node.attrs)?, + ident: node.ident.clone(), + generics: &node.generics, + variants: data + .variants + .iter() + .map(Variant::from_syn) + .collect::>()?, + }) + } +} + +impl<'a> Variant<'a> { + fn from_syn(node: &'a syn::Variant) -> Result { + Ok(Variant { + original: node, + attrs: attr::get(&node.attrs)?, + ident: node.ident.clone(), + fields: Field::multiple_from_syn(&node.fields)?, + }) + } +} + +impl<'a> Field<'a> { + fn multiple_from_syn(fields: &'a Fields) -> Result> { + fields + .iter() + .enumerate() + .map(|(i, field)| Field::from_syn(i, field)) + .collect() + } + + fn from_syn(i: usize, node: &'a syn::Field) -> Result { + Ok(Field { + original: node, + attrs: attr::get(&node.attrs)?, + member: node + .ident + .clone() + .map(Member::Named) + .unwrap_or_else(|| Member::Unnamed(Index::from(i))), + ty: &node.ty, + }) + } +} diff --git a/impl/src/expand.rs b/impl/src/expand.rs index 2708a34..cbd01b3 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -1,28 +1,23 @@ -use crate::attr; +use crate::ast::{Enum, Field, Input, Struct}; +use crate::attr::Attrs; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned}; use syn::spanned::Spanned; -use syn::{ - Data, DataEnum, DataStruct, DeriveInput, Error, Field, Fields, Ident, Index, Member, Result, - Type, -}; +use syn::{DeriveInput, Error, Member, Result, Type}; -pub fn derive(input: &DeriveInput) -> Result { - match &input.data { - Data::Struct(data) => impl_struct(input, data), - Data::Enum(data) => impl_enum(input, data), - Data::Union(_) => Err(Error::new_spanned( - input, - "union as errors are not supported", - )), +pub fn derive(node: &DeriveInput) -> Result { + let input = Input::from_syn(node)?; + match input { + Input::Struct(input) => impl_struct(input), + Input::Enum(input) => impl_enum(input), } } -fn impl_struct(input: &DeriveInput, data: &DataStruct) -> Result { +fn impl_struct(input: Struct) -> Result { let ty = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let source = source_member(&data.fields)?; + let source = source_member(&input.fields); let source_method = source.map(|source| { let member = quote_spanned!(source.span()=> self.#source); quote! { @@ -33,7 +28,7 @@ fn impl_struct(input: &DeriveInput, data: &DataStruct) -> Result { } }); - let backtrace = backtrace_member(&data.fields)?; + let backtrace = backtrace_member(&input.fields); let backtrace_method = backtrace.map(|backtrace| { quote! { fn backtrace(&self) -> std::option::Option<&std::backtrace::Backtrace> { @@ -42,24 +37,13 @@ fn impl_struct(input: &DeriveInput, data: &DataStruct) -> Result { } }); - let struct_attrs = attr::get(&input.attrs)?; - let display = struct_attrs.display.map(|display| { - let pat = match &data.fields { - Fields::Named(fields) => { - let var = fields.named.iter().map(|field| &field.ident); - quote!(Self { #(#var),* }) - } - Fields::Unnamed(fields) => { - let var = (0..fields.unnamed.len()).map(|i| format_ident!("_{}", i)); - quote!(Self(#(#var),*)) - } - Fields::Unit => quote!(_), - }; + let display = input.attrs.display.as_ref().map(|display| { + let pat = fields_pat(&input.fields); quote! { impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause { fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { #[allow(unused_variables)] - let #pat = self; + let Self #pat = self; #display } } @@ -75,30 +59,24 @@ fn impl_struct(input: &DeriveInput, data: &DataStruct) -> Result { }) } -fn impl_enum(input: &DeriveInput, data: &DataEnum) -> Result { +fn impl_enum(input: Enum) -> Result { let ty = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let variant_fields: Vec<&Fields> = data + let sources: Vec> = input .variants .iter() - .map(|variant| &variant.fields) + .map(|variant| source_member(&variant.fields)) .collect(); - let sources: Vec> = variant_fields + let backtraces: Vec> = input + .variants .iter() - .cloned() - .map(source_member) - .collect::>()?; - - let backtraces: Vec> = variant_fields - .iter() - .cloned() - .map(backtrace_member) - .collect::>()?; + .map(|variant| backtrace_member(&variant.fields)) + .collect(); let source_method = if sources.iter().any(Option::is_some) { - let arms = data.variants.iter().zip(sources).map(|(variant, source)| { + let arms = input.variants.iter().zip(sources).map(|(variant, source)| { let ident = &variant.ident; match source { Some(source) => quote! { @@ -122,7 +100,7 @@ fn impl_enum(input: &DeriveInput, data: &DataEnum) -> Result { }; let backtrace_method = if backtraces.iter().any(Option::is_some) { - let arms = data.variants.iter().zip(backtraces).map(|(variant, backtrace)| { + let arms = input.variants.iter().zip(backtraces).map(|(variant, backtrace)| { let ident = &variant.ident; match backtrace { Some(backtrace) => quote! { @@ -144,31 +122,27 @@ fn impl_enum(input: &DeriveInput, data: &DataEnum) -> Result { None }; - let variant_attrs = data + let variant_attrs: Vec<&Attrs> = input .variants .iter() - .map(|variant| attr::get(&variant.attrs)) - .collect::>>()?; + .map(|variant| &variant.attrs) + .collect(); let display = if variant_attrs.iter().any(|attrs| attrs.display.is_some()) { - let arms = data + let arms = input .variants .iter() .zip(variant_attrs) .map(|(variant, attrs)| { - let display = attrs.display.ok_or_else(|| { - Error::new_spanned(variant, "missing #[error(\"...\")] display attribute") + let display = attrs.display.as_ref().ok_or_else(|| { + Error::new_spanned( + variant.original, + "missing #[error(\"...\")] display attribute", + ) })?; let ident = &variant.ident; - Ok(match &variant.fields { - Fields::Named(fields) => { - let var = fields.named.iter().map(|field| &field.ident); - quote!(#ty::#ident { #(#var),* } => #display) - } - Fields::Unnamed(fields) => { - let var = (0..fields.unnamed.len()).map(|i| format_ident!("_{}", i)); - quote!(#ty::#ident(#(#var),*) => #display) - } - Fields::Unit => quote!(#ty::#ident => #display), + let pat = fields_pat(&variant.fields); + Ok(quote! { + #ty::#ident #pat => #display }) }) .collect::>>()?; @@ -195,23 +169,22 @@ fn impl_enum(input: &DeriveInput, data: &DataEnum) -> Result { }) } -fn source_member<'a>(fields: impl IntoIterator) -> Result> { - for (i, field) in fields.into_iter().enumerate() { - let attrs = attr::get(&field.attrs)?; - if attrs.source { - return Ok(Some(member(i, &field.ident))); +fn source_member<'a>(fields: &'a [Field]) -> Option<&'a Member> { + for field in fields { + if field.attrs.source { + return Some(&field.member); } } - Ok(None) + None } -fn backtrace_member<'a>(fields: impl IntoIterator) -> Result> { - for (i, field) in fields.into_iter().enumerate() { +fn backtrace_member<'a>(fields: &'a [Field]) -> Option<&'a Member> { + for field in fields { if type_is_backtrace(&field.ty) { - return Ok(Some(member(i, &field.ident))); + return Some(&field.member); } } - Ok(None) + None } fn type_is_backtrace(ty: &Type) -> bool { @@ -224,9 +197,17 @@ fn type_is_backtrace(ty: &Type) -> bool { last.ident == "Backtrace" && last.arguments.is_empty() } -fn member(i: usize, ident: &Option) -> Member { - match ident { - Some(ident) => Member::Named(ident.clone()), - None => Member::Unnamed(Index::from(i)), +fn fields_pat(fields: &[Field]) -> TokenStream { + let mut members = fields.iter().map(|field| &field.member).peekable(); + match members.peek() { + Some(Member::Named(_)) => quote!({ #(#members),* }), + Some(Member::Unnamed(_)) => { + let vars = members.map(|member| match member { + Member::Unnamed(member) => format_ident!("_{}", member.index), + Member::Named(_) => unreachable!(), + }); + quote!((#(#vars),*)) + } + None => quote!({}), } } diff --git a/impl/src/lib.rs b/impl/src/lib.rs index 2051fe7..e819e0e 100644 --- a/impl/src/lib.rs +++ b/impl/src/lib.rs @@ -1,5 +1,6 @@ extern crate proc_macro; +mod ast; mod attr; mod expand; mod fmt;