diff --git a/handlers/base.go b/handlers/base.go index 26ca282a7..3df527d5b 100644 --- a/handlers/base.go +++ b/handlers/base.go @@ -23,22 +23,17 @@ type AccountChecker interface { type SimpleFeeHandler struct { AccountChecker MinFee types.Coins - Inner basecoin.Handler -} - -func (h SimpleFeeHandler) Next() basecoin.Handler { - return h.Inner } func (_ SimpleFeeHandler) Name() string { return NameFee } -var _ basecoin.Handler = SimpleFeeHandler{} +var _ basecoin.Middleware = SimpleFeeHandler{} // Yes, I know refactor a bit... really too late already -func (h SimpleFeeHandler) CheckTx(ctx basecoin.Context, store types.KVStore, tx basecoin.Tx) (res basecoin.Result, err error) { +func (h SimpleFeeHandler) CheckTx(ctx basecoin.Context, store types.KVStore, tx basecoin.Tx, next basecoin.Checker) (res basecoin.Result, err error) { feeTx, ok := tx.Unwrap().(*txs.Fee) if !ok { return res, errors.InvalidFormat() @@ -61,7 +56,7 @@ func (h SimpleFeeHandler) CheckTx(ctx basecoin.Context, store types.KVStore, tx return basecoin.Result{Log: "Valid tx"}, nil } -func (h SimpleFeeHandler) DeliverTx(ctx basecoin.Context, store types.KVStore, tx basecoin.Tx) (res basecoin.Result, err error) { +func (h SimpleFeeHandler) DeliverTx(ctx basecoin.Context, store types.KVStore, tx basecoin.Tx, next basecoin.Deliver) (res basecoin.Result, err error) { feeTx, ok := tx.Unwrap().(*txs.Fee) if !ok { return res, errors.InvalidFormat() @@ -81,5 +76,5 @@ func (h SimpleFeeHandler) DeliverTx(ctx basecoin.Context, store types.KVStore, t return res, err } - return h.Next().DeliverTx(ctx, store, feeTx.Next()) + return next.DeliverTx(ctx, store, feeTx.Next()) } diff --git a/handlers/sigs.go b/handlers/sigs.go index 50d80636e..05d323319 100644 --- a/handlers/sigs.go +++ b/handlers/sigs.go @@ -15,57 +15,36 @@ const ( type SignedHandler struct { AllowMultiSig bool - Inner basecoin.Handler } func (_ SignedHandler) Name() string { return NameSigs } -func (h SignedHandler) Next() basecoin.Handler { - return h.Inner -} - -var _ basecoin.Handler = SignedHandler{} +var _ basecoin.Middleware = SignedHandler{} +// Signed allows us to use txs.OneSig and txs.MultiSig (and others??) type Signed interface { basecoin.TxLayer Signers() ([]crypto.PubKey, error) } -func (h SignedHandler) CheckTx(ctx basecoin.Context, store types.KVStore, tx basecoin.Tx) (res basecoin.Result, err error) { - var sigs []crypto.PubKey - - stx, ok := tx.Unwrap().(Signed) - if !ok { - return res, errors.Unauthorized() - } - - sigs, err = stx.Signers() +func (h SignedHandler) CheckTx(ctx basecoin.Context, store types.KVStore, tx basecoin.Tx, next basecoin.Checker) (res basecoin.Result, err error) { + sigs, tnext, err := getSigners(tx) if err != nil { return res, err } - ctx2 := addSigners(ctx, sigs) - return h.Next().CheckTx(ctx2, store, stx.Next()) + return next.CheckTx(ctx2, store, tnext) } -func (h SignedHandler) DeliverTx(ctx basecoin.Context, store types.KVStore, tx basecoin.Tx) (res basecoin.Result, err error) { - var sigs []crypto.PubKey - - stx, ok := tx.Unwrap().(Signed) - if !ok { - return res, errors.Unauthorized() - } - - sigs, err = stx.Signers() +func (h SignedHandler) DeliverTx(ctx basecoin.Context, store types.KVStore, tx basecoin.Tx, next basecoin.Deliver) (res basecoin.Result, err error) { + sigs, tnext, err := getSigners(tx) if err != nil { return res, err } - - // add the signers to the context and continue ctx2 := addSigners(ctx, sigs) - return h.Next().DeliverTx(ctx2, store, stx.Next()) + return next.DeliverTx(ctx2, store, tnext) } func addSigners(ctx basecoin.Context, sigs []crypto.PubKey) basecoin.Context { @@ -73,7 +52,15 @@ func addSigners(ctx basecoin.Context, sigs []crypto.PubKey) basecoin.Context { for i, s := range sigs { perms[i] = basecoin.Permission{App: NameSigs, Address: s.Address()} } - // add the signers to the context and continue return ctx.AddPermissions(perms...) } + +func getSigners(tx basecoin.Tx) ([]crypto.PubKey, basecoin.Tx, error) { + stx, ok := tx.Unwrap().(Signed) + if !ok { + return nil, basecoin.Tx{}, errors.Unauthorized() + } + sig, err := stx.Signers() + return sig, stx.Next(), err +}