diff --git a/consensus/istanbul/core/errors.go b/consensus/istanbul/core/errors.go index 62a5ce226..a97353204 100644 --- a/consensus/istanbul/core/errors.go +++ b/consensus/istanbul/core/errors.go @@ -43,4 +43,6 @@ var ( errFailedDecodeCommit = errors.New("failed to decode COMMIT") // errFailedDecodeMessageSet is returned when the message set is malformed. errFailedDecodeMessageSet = errors.New("failed to decode message set") + // errInvalidSigner is returned when the message is signed by a validator different than message sender + errInvalidSigner = errors.New("message not signed by the sender") ) diff --git a/consensus/istanbul/core/testbackend_test.go b/consensus/istanbul/core/testbackend_test.go index 812056f1f..bde93b1ca 100644 --- a/consensus/istanbul/core/testbackend_test.go +++ b/consensus/istanbul/core/testbackend_test.go @@ -109,8 +109,8 @@ func (self *testSystemBackend) Verify(proposal istanbul.Proposal) (time.Duration } func (self *testSystemBackend) Sign(data []byte) ([]byte, error) { - testLogger.Warn("not sign any data") - return data, nil + testLogger.Info("returning current backend address so that CheckValidatorSignature returns the same value") + return self.address.Bytes(), nil } func (self *testSystemBackend) CheckSignature([]byte, common.Address, []byte) error { @@ -118,7 +118,7 @@ func (self *testSystemBackend) CheckSignature([]byte, common.Address, []byte) er } func (self *testSystemBackend) CheckValidatorSignature(data []byte, sig []byte) (common.Address, error) { - return common.Address{}, nil + return common.BytesToAddress(sig), nil } func (self *testSystemBackend) Hash(b interface{}) common.Hash { diff --git a/consensus/istanbul/core/types.go b/consensus/istanbul/core/types.go index da202cb2c..fca60eb89 100644 --- a/consensus/istanbul/core/types.go +++ b/consensus/istanbul/core/types.go @@ -17,6 +17,7 @@ package core import ( + "bytes" "fmt" "io" @@ -137,10 +138,15 @@ func (m *message) FromPayload(b []byte, validateFn func([]byte, []byte) (common. return err } - _, err = validateFn(payload, m.Signature) + signerAdd, err := validateFn(payload, m.Signature) + if err != nil { + return err + } + if bytes.Compare(signerAdd.Bytes(), m.Address.Bytes()) != 0 { + return errInvalidSigner + } } - // Still return the message even the err is not nil - return err + return nil } func (m *message) Payload() ([]byte, error) { diff --git a/consensus/istanbul/core/types_test.go b/consensus/istanbul/core/types_test.go index a280fba54..10dcaa8f1 100644 --- a/consensus/istanbul/core/types_test.go +++ b/consensus/istanbul/core/types_test.go @@ -124,10 +124,11 @@ func testSubjectWithSignature(t *testing.T) { subjectPayload, _ := Encode(s) // 1. Encode test + address := common.HexToAddress("0x1234567890") m := &message{ Code: msgPreprepare, Msg: subjectPayload, - Address: common.HexToAddress("0x1234567890"), + Address: address, Signature: expectedSig, CommittedSeal: []byte{}, } @@ -141,7 +142,7 @@ func testSubjectWithSignature(t *testing.T) { // 2.1 Test normal validate func decodedMsg := new(message) err = decodedMsg.FromPayload(msgPayload, func(data []byte, sig []byte) (common.Address, error) { - return common.Address{}, nil + return address, nil }) if err != nil { t.Errorf("error mismatch: have %v, want nil", err)