diff --git a/consensus/state.go b/consensus/state.go index 7dcb6eb0..6a0c54bf 100644 --- a/consensus/state.go +++ b/consensus/state.go @@ -655,7 +655,7 @@ func (cs *ConsensusState) handleMsg(mi msgInfo, rs RoundState) { err = cs.setProposal(msg.Proposal) case *BlockPartMessage: // if the proposal is complete, we'll enterPrevote or tryFinalizeCommit - _, err = cs.addProposalBlockPart(msg.Height, msg.Part) + _, err = cs.addProposalBlockPart(msg.Height, msg.Part, peerKey != "") if err != nil && msg.Round != cs.Round { err = nil } @@ -1291,7 +1291,7 @@ func (cs *ConsensusState) setProposal(proposal *types.Proposal) error { // NOTE: block is not necessarily valid. // Asynchronously triggers either enterPrevote (before we timeout of propose) or tryFinalizeCommit, once we have the full block. -func (cs *ConsensusState) addProposalBlockPart(height int, part *types.Part) (added bool, err error) { +func (cs *ConsensusState) addProposalBlockPart(height int, part *types.Part, verify bool) (added bool, err error) { // Blocks might be reused, so round mismatch is OK if cs.Height != height { return false, nil @@ -1302,7 +1302,7 @@ func (cs *ConsensusState) addProposalBlockPart(height int, part *types.Part) (ad return false, nil // TODO: bad peer? Return error? } - added, err = cs.ProposalBlockParts.AddPart(part) + added, err = cs.ProposalBlockParts.AddPart(part, verify) if err != nil { return added, err } diff --git a/types/part_set.go b/types/part_set.go index bb7d11b1..06314fe1 100644 --- a/types/part_set.go +++ b/types/part_set.go @@ -188,7 +188,7 @@ func (ps *PartSet) Total() int { return ps.total } -func (ps *PartSet) AddPart(part *Part) (bool, error) { +func (ps *PartSet) AddPart(part *Part, verify bool) (bool, error) { ps.mtx.Lock() defer ps.mtx.Unlock() @@ -203,9 +203,10 @@ func (ps *PartSet) AddPart(part *Part) (bool, error) { } // Check hash proof - // TODO: minor gains for not checking part sets we made - if !part.Proof.Verify(part.Index, ps.total, part.Hash(), ps.Hash()) { - return false, ErrPartSetInvalidProof + if verify { + if !part.Proof.Verify(part.Index, ps.total, part.Hash(), ps.Hash()) { + return false, ErrPartSetInvalidProof + } } // Add part