diff --git a/errors/common.go b/errors/common.go index 2f74c8c13..73a0f6a50 100644 --- a/errors/common.go +++ b/errors/common.go @@ -5,102 +5,148 @@ package errors **/ import ( + rawerr "errors" "fmt" + "github.com/pkg/errors" abci "github.com/tendermint/abci/types" "github.com/tendermint/basecoin" ) -const ( - msgDecoding = "Error decoding input" - msgUnauthorized = "Unauthorized" - msgInvalidAddress = "Invalid Address" - msgInvalidCoins = "Invalid Coins" - msgInvalidFormat = "Invalid Format" - msgInvalidSequence = "Invalid Sequence" - msgInvalidSignature = "Invalid Signature" - msgInsufficientFees = "Insufficient Fees" - msgInsufficientFunds = "Insufficient Funds" - msgNoInputs = "No Input Coins" - msgNoOutputs = "No Output Coins" - msgTooLarge = "Input size too large" - msgMissingSignature = "Signature missing" - msgTooManySignatures = "Too many signatures" - msgNoChain = "No chain id provided" - msgWrongChain = "Tx belongs to different chain - %s" - msgUnknownTxType = "We cannot handle this tx - %v" +var ( + errDecoding = rawerr.New("Error decoding input") + errUnauthorized = rawerr.New("Unauthorized") + errInvalidAddress = rawerr.New("Invalid Address") + errInvalidCoins = rawerr.New("Invalid Coins") + errInvalidFormat = rawerr.New("Invalid Format") + errInvalidSequence = rawerr.New("Invalid Sequence") + errInvalidSignature = rawerr.New("Invalid Signature") + errInsufficientFees = rawerr.New("Insufficient Fees") + errInsufficientFunds = rawerr.New("Insufficient Funds") + errNoInputs = rawerr.New("No Input Coins") + errNoOutputs = rawerr.New("No Output Coins") + errTooLarge = rawerr.New("Input size too large") + errMissingSignature = rawerr.New("Signature missing") + errTooManySignatures = rawerr.New("Too many signatures") + errNoChain = rawerr.New("No chain id provided") + errWrongChain = rawerr.New("Wrong chain for tx") + errUnknownTxType = rawerr.New("Tx type unknown") ) -func UnknownTxType(tx basecoin.Tx) TMError { - msg := fmt.Sprintf(msgUnknownTxType, tx) - return New(msg, abci.CodeType_UnknownRequest) +func ErrUnknownTxType(tx basecoin.Tx) TMError { + msg := fmt.Sprintf("%T", tx.Unwrap()) + w := errors.Wrap(errUnknownTxType, msg) + return WithCode(w, abci.CodeType_UnknownRequest) } -func InternalError(msg string) TMError { +func IsUnknownTxTypeErr(err error) bool { + return IsSameError(errUnknownTxType, err) +} + +func ErrInternal(msg string) TMError { return New(msg, abci.CodeType_InternalError) } -func DecodingError() TMError { - return New(msgDecoding, abci.CodeType_EncodingError) +// IsInternalErr matches any error that is not classified +func IsInternalErr(err error) bool { + return HasErrorCode(err, abci.CodeType_InternalError) } -func Unauthorized() TMError { - return New(msgUnauthorized, abci.CodeType_Unauthorized) +func ErrDecoding() TMError { + return WithCode(errDecoding, abci.CodeType_EncodingError) } -func MissingSignature() TMError { - return New(msgMissingSignature, abci.CodeType_Unauthorized) +func IsDecodingErr(err error) bool { + return IsSameError(errDecoding, err) } -func TooManySignatures() TMError { - return New(msgTooManySignatures, abci.CodeType_Unauthorized) +func ErrUnauthorized() TMError { + return WithCode(errUnauthorized, abci.CodeType_Unauthorized) } -func InvalidSignature() TMError { - return New(msgInvalidSignature, abci.CodeType_Unauthorized) +// IsUnauthorizedErr is generic helper for any unauthorized errors, +// also specific sub-types +func IsUnauthorizedErr(err error) bool { + return HasErrorCode(err, abci.CodeType_Unauthorized) } -func NoChain() TMError { - return New(msgNoChain, abci.CodeType_Unauthorized) +func ErrMissingSignature() TMError { + return WithCode(errMissingSignature, abci.CodeType_Unauthorized) } -func WrongChain(chain string) TMError { - msg := fmt.Sprintf(msgWrongChain, chain) - return New(msg, abci.CodeType_Unauthorized) +func IsMissingSignatureErr(err error) bool { + return IsSameError(errMissingSignature, err) +} + +func ErrTooManySignatures() TMError { + return WithCode(errTooManySignatures, abci.CodeType_Unauthorized) +} + +func IsTooManySignaturesErr(err error) bool { + return IsSameError(errTooManySignatures, err) +} + +func ErrInvalidSignature() TMError { + return WithCode(errInvalidSignature, abci.CodeType_Unauthorized) +} + +func IsInvalidSignatureErr(err error) bool { + return IsSameError(errInvalidSignature, err) +} + +func ErrNoChain() TMError { + return WithCode(errNoChain, abci.CodeType_Unauthorized) +} + +func IsNoChainErr(err error) bool { + return IsSameError(errNoChain, err) +} + +func ErrWrongChain(chain string) TMError { + msg := errors.Wrap(errWrongChain, chain) + return WithCode(msg, abci.CodeType_Unauthorized) +} + +func IsWrongChainErr(err error) bool { + return IsSameError(errWrongChain, err) } func InvalidAddress() TMError { - return New(msgInvalidAddress, abci.CodeType_BaseInvalidInput) + return WithCode(errInvalidAddress, abci.CodeType_BaseInvalidInput) } func InvalidCoins() TMError { - return New(msgInvalidCoins, abci.CodeType_BaseInvalidInput) + return WithCode(errInvalidCoins, abci.CodeType_BaseInvalidInput) } func InvalidFormat() TMError { - return New(msgInvalidFormat, abci.CodeType_BaseInvalidInput) + return WithCode(errInvalidFormat, abci.CodeType_BaseInvalidInput) } func InvalidSequence() TMError { - return New(msgInvalidSequence, abci.CodeType_BaseInvalidInput) + return WithCode(errInvalidSequence, abci.CodeType_BaseInvalidInput) } func InsufficientFees() TMError { - return New(msgInsufficientFees, abci.CodeType_BaseInvalidInput) + return WithCode(errInsufficientFees, abci.CodeType_BaseInvalidInput) } func InsufficientFunds() TMError { - return New(msgInsufficientFunds, abci.CodeType_BaseInvalidInput) + return WithCode(errInsufficientFunds, abci.CodeType_BaseInvalidInput) } func NoInputs() TMError { - return New(msgNoInputs, abci.CodeType_BaseInvalidInput) + return WithCode(errNoInputs, abci.CodeType_BaseInvalidInput) } func NoOutputs() TMError { - return New(msgNoOutputs, abci.CodeType_BaseInvalidOutput) + return WithCode(errNoOutputs, abci.CodeType_BaseInvalidOutput) } -func TooLarge() TMError { - return New(msgTooLarge, abci.CodeType_EncodingError) +func ErrTooLarge() TMError { + return WithCode(errTooLarge, abci.CodeType_EncodingError) +} + +func IsTooLargeErr(err error) bool { + return IsSameError(errTooLarge, err) } diff --git a/errors/common_test.go b/errors/common_test.go new file mode 100644 index 000000000..e863ac786 --- /dev/null +++ b/errors/common_test.go @@ -0,0 +1,74 @@ +package errors + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tendermint/basecoin" +) + +type DemoTx struct { + Age int +} + +func (t DemoTx) Wrap() basecoin.Tx { + return basecoin.Tx{t} +} + +func (t DemoTx) ValidateBasic() error { + return nil +} + +func TestErrorMatches(t *testing.T) { + assert := assert.New(t) + + cases := []struct { + pattern, err error + match bool + }{ + {errDecoding, ErrDecoding(), true}, + {errUnauthorized, ErrUnauthorized(), true}, + {errMissingSignature, ErrUnauthorized(), false}, + {errMissingSignature, ErrMissingSignature(), true}, + {errWrongChain, ErrWrongChain("hakz"), true}, + {errUnknownTxType, ErrUnknownTxType(basecoin.Tx{}), true}, + {errUnknownTxType, ErrUnknownTxType(DemoTx{5}.Wrap()), true}, + } + + for i, tc := range cases { + same := IsSameError(tc.pattern, tc.err) + assert.Equal(tc.match, same, "%d: %#v / %#v", i, tc.pattern, tc.err) + } +} + +func TestChecks(t *testing.T) { + // TODO: make sure the Is and Err methods match + assert := assert.New(t) + + cases := []struct { + err error + check func(error) bool + match bool + }{ + {ErrDecoding(), IsDecodingErr, true}, + {ErrUnauthorized(), IsDecodingErr, false}, + {ErrUnauthorized(), IsUnauthorizedErr, true}, + {ErrInvalidSignature(), IsInvalidSignatureErr, true}, + // unauthorized includes InvalidSignature, but not visa versa + {ErrInvalidSignature(), IsUnauthorizedErr, true}, + {ErrUnauthorized(), IsInvalidSignatureErr, false}, + // make sure WrongChain works properly + {ErrWrongChain("fooz"), IsUnauthorizedErr, true}, + {ErrWrongChain("barz"), IsWrongChainErr, true}, + // make sure lots of things match InternalErr, but not everything + {ErrInternal("bad db connection"), IsInternalErr, true}, + {Wrap(errors.New("wrapped")), IsInternalErr, true}, + {ErrUnauthorized(), IsInternalErr, false}, + } + + for i, tc := range cases { + match := tc.check(tc.err) + assert.Equal(tc.match, match, "%d", i) + } +} diff --git a/errors/main.go b/errors/main.go index a4cb8a483..f19a7f692 100644 --- a/errors/main.go +++ b/errors/main.go @@ -19,6 +19,10 @@ type stackTracer interface { StackTrace() errors.StackTrace } +type causer interface { + Cause() error +} + type TMError interface { stackTracer ErrorCode() abci.CodeType @@ -31,6 +35,11 @@ type tmerror struct { msg string } +var ( + _ causer = tmerror{} + _ error = tmerror{} +) + func (t tmerror) ErrorCode() abci.CodeType { return t.code } @@ -39,6 +48,13 @@ func (t tmerror) Message() string { return t.msg } +func (t tmerror) Cause() error { + if c, ok := t.stackTracer.(causer); ok { + return c.Cause() + } + return t.stackTracer +} + // Format handles "%+v" to expose the full stack trace // concept from pkg/errors func (t tmerror) Format(s fmt.State, verb rune) { @@ -102,3 +118,18 @@ func New(msg string, code abci.CodeType) TMError { msg: msg, } } + +// IsSameError returns true if these errors have the same root cause. +// pattern is the expected error type and should always be non-nil +// err may be anything and returns true if it is a wrapped version of pattern +func IsSameError(pattern error, err error) bool { + return err != nil && (errors.Cause(err) == errors.Cause(pattern)) +} + +// HasErrorCode checks if this error would return the named error code +func HasErrorCode(err error, code abci.CodeType) bool { + if tm, ok := err.(TMError); ok { + return tm.ErrorCode() == code + } + return code == defaultErrCode +} diff --git a/errors/main_test.go b/errors/main_test.go index 6209a547e..e9a80f8bb 100644 --- a/errors/main_test.go +++ b/errors/main_test.go @@ -24,8 +24,8 @@ func TestCreateResult(t *testing.T) { {New("nonce", abci.CodeType_BadNonce), "nonce", abci.CodeType_BadNonce}, {Wrap(stderr.New("wrap")), "wrap", defaultErrCode}, {WithCode(stderr.New("coded"), abci.CodeType_BaseInvalidInput), "coded", abci.CodeType_BaseInvalidInput}, - {DecodingError(), msgDecoding, abci.CodeType_EncodingError}, - {Unauthorized(), msgUnauthorized, abci.CodeType_Unauthorized}, + {ErrDecoding(), errDecoding.Error(), abci.CodeType_EncodingError}, + {ErrUnauthorized(), errUnauthorized.Error(), abci.CodeType_Unauthorized}, } for idx, tc := range cases {