diff --git a/impl/src/expand.rs b/impl/src/expand.rs index c933f3e..410e7eb 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -60,6 +60,85 @@ fn struct_error(input: &DeriveInput, data: &DataStruct) -> Result { }) } +fn enum_error(input: &DeriveInput, data: &DataEnum) -> Result { + let ident = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + + let sources = data + .variants + .iter() + .map(|variant| match &variant.fields { + Fields::Named(fields) => source_member(&fields.named), + Fields::Unnamed(fields) => source_member(&fields.unnamed), + Fields::Unit => Ok(None), + }) + .collect::>>()?; + + let backtraces = data + .variants + .iter() + .map(|variant| match &variant.fields { + Fields::Named(fields) => backtrace_member(&fields.named), + Fields::Unnamed(fields) => backtrace_member(&fields.unnamed), + Fields::Unit => Ok(None), + }) + .collect::>>()?; + + let source_method = if sources.iter().any(Option::is_some) { + let arms = data.variants.iter().zip(sources).map(|(variant, source)| { + let ident = &variant.ident; + match source { + Some(source) => quote! { + Self::#ident {#source: source, ..} => std::option::Option::Some(source.as_dyn_error()), + }, + None => quote! { + Self::#ident {..} => std::option::Option::None, + }, + } + }); + Some(quote! { + fn source(&self) -> std::option::Option<&(dyn std::error::Error + 'static)> { + use thiserror::private::AsDynError; + match self { + #(#arms)* + } + } + }) + } else { + None + }; + + let backtrace_method = if backtraces.iter().any(Option::is_some) { + let arms = data.variants.iter().zip(backtraces).map(|(variant, backtrace)| { + let ident = &variant.ident; + match backtrace { + Some(backtrace) => quote! { + Self::#ident {#backtrace: backtrace, ..} => std::option::Option::Some(backtrace), + }, + None => quote! { + Self::#ident {..} => std::option::Option::None, + }, + } + }); + Some(quote! { + fn backtrace(&self) -> std::option::Option<&std::backtrace::Backtrace> { + match self { + #(#arms)* + } + } + }) + } else { + None + }; + + Ok(quote! { + impl #impl_generics std::error::Error for #ident #ty_generics #where_clause { + #source_method + #backtrace_method + } + }) +} + fn source_member<'a>(fields: impl IntoIterator) -> Result> { for (i, field) in fields.into_iter().enumerate() { if attr::is_source(field)? { @@ -94,9 +173,3 @@ fn member(i: usize, ident: &Option) -> Member { None => Member::Unnamed(Index::from(i)), } } - -fn enum_error(input: &DeriveInput, data: &DataEnum) -> Result { - let _ = input; - let _ = data; - unimplemented!() -} diff --git a/tests/test.rs b/tests/test.rs index 3bd6ab9..fab934d 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use std::fmt::{self, Display}; use std::io; use thiserror::Error; @@ -36,8 +38,19 @@ struct WithAnyhow { cause: anyhow::Error, } +#[derive(Error, Debug)] +enum EnumError { + Braced { + #[source] + cause: io::Error, + }, + Tuple(#[source] io::Error), + Unit, +} + unimplemented_display!(BracedError); unimplemented_display!(TupleError); unimplemented_display!(UnitError); unimplemented_display!(WithSource); unimplemented_display!(WithAnyhow); +unimplemented_display!(EnumError);