cosmos-sdk/x/auth/middleware/tx.go

78 lines
2.2 KiB
Go

package middleware
import (
"context"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/types/tx"
)
// txDecoderHandler provides CheckTx, DeliverTx and SimulateTx methods that
// fill in the Tx field of each incoming request (decoding TxBytes with
// txDecoder when bytes are present) before forwarding it to next.
type txDecoderHandler struct {
	// next receives the request once its Tx field has been populated.
	next tx.Handler

	// txDecoder decodes a request's raw TxBytes into an sdk.Tx.
	txDecoder sdk.TxDecoder
}
// NewTxDecoderMiddleware returns a tx.Middleware that wraps the next handler
// with one which decodes transaction bytes into an sdk.Tx using txDecoder.
//
// Every incoming request must carry at least one of Tx or TxBytes; a request
// with neither is rejected. When TxBytes is non-empty it is decoded and the
// result is forwarded as Tx (replacing any Tx the caller supplied). When only
// Tx is set, it is forwarded as-is and TxBytes stays empty, which may cause
// middlewares that need the raw bytes, such as signature verification, to fail.
func NewTxDecoderMiddleware(txDecoder sdk.TxDecoder) tx.Middleware {
	wrap := func(next tx.Handler) tx.Handler {
		decoderHandler := txDecoderHandler{
			next:      next,
			txDecoder: txDecoder,
		}
		return decoderHandler
	}
	return wrap
}
// Compile-time check that txDecoderHandler implements tx.Handler.
var _ tx.Handler = txDecoderHandler{}
// CheckTx implements tx.Handler.CheckTx.
func (h txDecoderHandler) CheckTx(ctx context.Context, req tx.Request, checkReq tx.RequestCheckTx) (tx.Response, tx.ResponseCheckTx, error) {
newReq, err := h.populateReq(req)
if err != nil {
return tx.Response{}, tx.ResponseCheckTx{}, err
}
return h.next.CheckTx(ctx, newReq, checkReq)
}
// DeliverTx implements tx.Handler.DeliverTx.
func (h txDecoderHandler) DeliverTx(ctx context.Context, req tx.Request) (tx.Response, error) {
newReq, err := h.populateReq(req)
if err != nil {
return tx.Response{}, err
}
return h.next.DeliverTx(ctx, newReq)
}
// SimulateTx implements tx.Handler.SimulateTx method.
func (h txDecoderHandler) SimulateTx(ctx context.Context, req tx.Request) (tx.Response, error) {
newReq, err := h.populateReq(req)
if err != nil {
return tx.Response{}, err
}
return h.next.SimulateTx(ctx, newReq)
}
// populateReq validates req and returns the request to forward to the next
// handler, with its Tx field populated.
//
// At least one of req.Tx or req.TxBytes must be set; otherwise an
// ErrInvalidRequest error is returned. Whenever req.TxBytes is non-empty it
// is decoded with h.txDecoder, and the decoded value replaces req.Tx even if
// req.Tx was already set. As a result, the forwarded Tx always corresponds to
// the forwarded TxBytes. When req.TxBytes is empty, req.Tx is forwarded
// unchanged. req.TxBytes itself is never modified, and decoder errors are
// returned without wrapping.
func (h txDecoderHandler) populateReq(req tx.Request) (tx.Request, error) {
	if len(req.TxBytes) == 0 {
		if req.Tx == nil {
			return tx.Request{}, sdkerrors.ErrInvalidRequest.Wrap("got empty tx request")
		}
		// No raw bytes to decode: forward the caller-supplied Tx as-is.
		return tx.Request{Tx: req.Tx, TxBytes: req.TxBytes}, nil
	}

	// Raw bytes are present, so they are treated as the source of truth and a
	// caller-supplied Tx is overwritten by their decoding.
	sdkTx, err := h.txDecoder(req.TxBytes)
	if err != nil {
		return tx.Request{}, err
	}
	return tx.Request{Tx: sdkTx, TxBytes: req.TxBytes}, nil
}