Refactor errors package, so we can do type-checking IsXXErr()

This commit is contained in:
Ethan Frey 2017-07-03 14:22:46 +02:00
parent 2fc4da1076
commit 5fa77bf647
4 changed files with 200 additions and 49 deletions

View File

@ -5,102 +5,148 @@ package errors
**/
import (
rawerr "errors"
"fmt"
"github.com/pkg/errors"
abci "github.com/tendermint/abci/types"
"github.com/tendermint/basecoin"
)
const (
msgDecoding = "Error decoding input"
msgUnauthorized = "Unauthorized"
msgInvalidAddress = "Invalid Address"
msgInvalidCoins = "Invalid Coins"
msgInvalidFormat = "Invalid Format"
msgInvalidSequence = "Invalid Sequence"
msgInvalidSignature = "Invalid Signature"
msgInsufficientFees = "Insufficient Fees"
msgInsufficientFunds = "Insufficient Funds"
msgNoInputs = "No Input Coins"
msgNoOutputs = "No Output Coins"
msgTooLarge = "Input size too large"
msgMissingSignature = "Signature missing"
msgTooManySignatures = "Too many signatures"
msgNoChain = "No chain id provided"
msgWrongChain = "Tx belongs to different chain - %s"
msgUnknownTxType = "We cannot handle this tx - %v"
var (
errDecoding = rawerr.New("Error decoding input")
errUnauthorized = rawerr.New("Unauthorized")
errInvalidAddress = rawerr.New("Invalid Address")
errInvalidCoins = rawerr.New("Invalid Coins")
errInvalidFormat = rawerr.New("Invalid Format")
errInvalidSequence = rawerr.New("Invalid Sequence")
errInvalidSignature = rawerr.New("Invalid Signature")
errInsufficientFees = rawerr.New("Insufficient Fees")
errInsufficientFunds = rawerr.New("Insufficient Funds")
errNoInputs = rawerr.New("No Input Coins")
errNoOutputs = rawerr.New("No Output Coins")
errTooLarge = rawerr.New("Input size too large")
errMissingSignature = rawerr.New("Signature missing")
errTooManySignatures = rawerr.New("Too many signatures")
errNoChain = rawerr.New("No chain id provided")
errWrongChain = rawerr.New("Wrong chain for tx")
errUnknownTxType = rawerr.New("Tx type unknown")
)
func UnknownTxType(tx basecoin.Tx) TMError {
msg := fmt.Sprintf(msgUnknownTxType, tx)
return New(msg, abci.CodeType_UnknownRequest)
func ErrUnknownTxType(tx basecoin.Tx) TMError {
msg := fmt.Sprintf("%T", tx.Unwrap())
w := errors.Wrap(errUnknownTxType, msg)
return WithCode(w, abci.CodeType_UnknownRequest)
}
func InternalError(msg string) TMError {
func IsUnknownTxTypeErr(err error) bool {
return IsSameError(errUnknownTxType, err)
}
func ErrInternal(msg string) TMError {
return New(msg, abci.CodeType_InternalError)
}
func DecodingError() TMError {
return New(msgDecoding, abci.CodeType_EncodingError)
// IsInternalErr matches any error that is not classified
func IsInternalErr(err error) bool {
return HasErrorCode(err, abci.CodeType_InternalError)
}
func Unauthorized() TMError {
return New(msgUnauthorized, abci.CodeType_Unauthorized)
func ErrDecoding() TMError {
return WithCode(errDecoding, abci.CodeType_EncodingError)
}
func MissingSignature() TMError {
return New(msgMissingSignature, abci.CodeType_Unauthorized)
func IsDecodingErr(err error) bool {
return IsSameError(errDecoding, err)
}
func TooManySignatures() TMError {
return New(msgTooManySignatures, abci.CodeType_Unauthorized)
func ErrUnauthorized() TMError {
return WithCode(errUnauthorized, abci.CodeType_Unauthorized)
}
func InvalidSignature() TMError {
return New(msgInvalidSignature, abci.CodeType_Unauthorized)
// IsUnauthorizedErr is generic helper for any unauthorized errors,
// also specific sub-types
func IsUnauthorizedErr(err error) bool {
return HasErrorCode(err, abci.CodeType_Unauthorized)
}
func NoChain() TMError {
return New(msgNoChain, abci.CodeType_Unauthorized)
func ErrMissingSignature() TMError {
return WithCode(errMissingSignature, abci.CodeType_Unauthorized)
}
func WrongChain(chain string) TMError {
msg := fmt.Sprintf(msgWrongChain, chain)
return New(msg, abci.CodeType_Unauthorized)
func IsMissingSignatureErr(err error) bool {
return IsSameError(errMissingSignature, err)
}
func ErrTooManySignatures() TMError {
return WithCode(errTooManySignatures, abci.CodeType_Unauthorized)
}
func IsTooManySignaturesErr(err error) bool {
return IsSameError(errTooManySignatures, err)
}
func ErrInvalidSignature() TMError {
return WithCode(errInvalidSignature, abci.CodeType_Unauthorized)
}
func IsInvalidSignatureErr(err error) bool {
return IsSameError(errInvalidSignature, err)
}
func ErrNoChain() TMError {
return WithCode(errNoChain, abci.CodeType_Unauthorized)
}
func IsNoChainErr(err error) bool {
return IsSameError(errNoChain, err)
}
func ErrWrongChain(chain string) TMError {
msg := errors.Wrap(errWrongChain, chain)
return WithCode(msg, abci.CodeType_Unauthorized)
}
func IsWrongChainErr(err error) bool {
return IsSameError(errWrongChain, err)
}
func InvalidAddress() TMError {
return New(msgInvalidAddress, abci.CodeType_BaseInvalidInput)
return WithCode(errInvalidAddress, abci.CodeType_BaseInvalidInput)
}
func InvalidCoins() TMError {
return New(msgInvalidCoins, abci.CodeType_BaseInvalidInput)
return WithCode(errInvalidCoins, abci.CodeType_BaseInvalidInput)
}
func InvalidFormat() TMError {
return New(msgInvalidFormat, abci.CodeType_BaseInvalidInput)
return WithCode(errInvalidFormat, abci.CodeType_BaseInvalidInput)
}
func InvalidSequence() TMError {
return New(msgInvalidSequence, abci.CodeType_BaseInvalidInput)
return WithCode(errInvalidSequence, abci.CodeType_BaseInvalidInput)
}
func InsufficientFees() TMError {
return New(msgInsufficientFees, abci.CodeType_BaseInvalidInput)
return WithCode(errInsufficientFees, abci.CodeType_BaseInvalidInput)
}
func InsufficientFunds() TMError {
return New(msgInsufficientFunds, abci.CodeType_BaseInvalidInput)
return WithCode(errInsufficientFunds, abci.CodeType_BaseInvalidInput)
}
func NoInputs() TMError {
return New(msgNoInputs, abci.CodeType_BaseInvalidInput)
return WithCode(errNoInputs, abci.CodeType_BaseInvalidInput)
}
func NoOutputs() TMError {
return New(msgNoOutputs, abci.CodeType_BaseInvalidOutput)
return WithCode(errNoOutputs, abci.CodeType_BaseInvalidOutput)
}
func TooLarge() TMError {
return New(msgTooLarge, abci.CodeType_EncodingError)
func ErrTooLarge() TMError {
return WithCode(errTooLarge, abci.CodeType_EncodingError)
}
func IsTooLargeErr(err error) bool {
return IsSameError(errTooLarge, err)
}

74
errors/common_test.go Normal file
View File

@ -0,0 +1,74 @@
package errors
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tendermint/basecoin"
)
type DemoTx struct {
Age int
}
func (t DemoTx) Wrap() basecoin.Tx {
return basecoin.Tx{t}
}
func (t DemoTx) ValidateBasic() error {
return nil
}
func TestErrorMatches(t *testing.T) {
assert := assert.New(t)
cases := []struct {
pattern, err error
match bool
}{
{errDecoding, ErrDecoding(), true},
{errUnauthorized, ErrUnauthorized(), true},
{errMissingSignature, ErrUnauthorized(), false},
{errMissingSignature, ErrMissingSignature(), true},
{errWrongChain, ErrWrongChain("hakz"), true},
{errUnknownTxType, ErrUnknownTxType(basecoin.Tx{}), true},
{errUnknownTxType, ErrUnknownTxType(DemoTx{5}.Wrap()), true},
}
for i, tc := range cases {
same := IsSameError(tc.pattern, tc.err)
assert.Equal(tc.match, same, "%d: %#v / %#v", i, tc.pattern, tc.err)
}
}
func TestChecks(t *testing.T) {
// TODO: make sure the Is and Err methods match
assert := assert.New(t)
cases := []struct {
err error
check func(error) bool
match bool
}{
{ErrDecoding(), IsDecodingErr, true},
{ErrUnauthorized(), IsDecodingErr, false},
{ErrUnauthorized(), IsUnauthorizedErr, true},
{ErrInvalidSignature(), IsInvalidSignatureErr, true},
// unauthorized includes InvalidSignature, but not visa versa
{ErrInvalidSignature(), IsUnauthorizedErr, true},
{ErrUnauthorized(), IsInvalidSignatureErr, false},
// make sure WrongChain works properly
{ErrWrongChain("fooz"), IsUnauthorizedErr, true},
{ErrWrongChain("barz"), IsWrongChainErr, true},
// make sure lots of things match InternalErr, but not everything
{ErrInternal("bad db connection"), IsInternalErr, true},
{Wrap(errors.New("wrapped")), IsInternalErr, true},
{ErrUnauthorized(), IsInternalErr, false},
}
for i, tc := range cases {
match := tc.check(tc.err)
assert.Equal(tc.match, match, "%d", i)
}
}

View File

@ -19,6 +19,10 @@ type stackTracer interface {
StackTrace() errors.StackTrace
}
type causer interface {
Cause() error
}
type TMError interface {
stackTracer
ErrorCode() abci.CodeType
@ -31,6 +35,11 @@ type tmerror struct {
msg string
}
var (
_ causer = tmerror{}
_ error = tmerror{}
)
func (t tmerror) ErrorCode() abci.CodeType {
return t.code
}
@ -39,6 +48,13 @@ func (t tmerror) Message() string {
return t.msg
}
func (t tmerror) Cause() error {
if c, ok := t.stackTracer.(causer); ok {
return c.Cause()
}
return t.stackTracer
}
// Format handles "%+v" to expose the full stack trace
// concept from pkg/errors
func (t tmerror) Format(s fmt.State, verb rune) {
@ -102,3 +118,18 @@ func New(msg string, code abci.CodeType) TMError {
msg: msg,
}
}
// IsSameError returns true if these errors have the same root cause.
// pattern is the expected error type and should always be non-nil
// err may be anything and returns true if it is a wrapped version of pattern
func IsSameError(pattern error, err error) bool {
return err != nil && (errors.Cause(err) == errors.Cause(pattern))
}
// HasErrorCode checks if this error would return the named error code
func HasErrorCode(err error, code abci.CodeType) bool {
if tm, ok := err.(TMError); ok {
return tm.ErrorCode() == code
}
return code == defaultErrCode
}

View File

@ -24,8 +24,8 @@ func TestCreateResult(t *testing.T) {
{New("nonce", abci.CodeType_BadNonce), "nonce", abci.CodeType_BadNonce},
{Wrap(stderr.New("wrap")), "wrap", defaultErrCode},
{WithCode(stderr.New("coded"), abci.CodeType_BaseInvalidInput), "coded", abci.CodeType_BaseInvalidInput},
{DecodingError(), msgDecoding, abci.CodeType_EncodingError},
{Unauthorized(), msgUnauthorized, abci.CodeType_Unauthorized},
{ErrDecoding(), errDecoding.Error(), abci.CodeType_EncodingError},
{ErrUnauthorized(), errUnauthorized.Error(), abci.CodeType_Unauthorized},
}
for idx, tc := range cases {