diff --git a/daemon/daemon.go b/daemon/daemon.go index a9928dcc..b6497b13 100644 --- a/daemon/daemon.go +++ b/daemon/daemon.go @@ -65,6 +65,7 @@ func NewNode() *Node { } sw := p2p.NewSwitch([]p2p.Reactor{pexReactor, mempoolReactor, consensusReactor}) + sw.SetChainId(state.Hash(), config.App().GetString("Network")) return &Node{ sw: sw, diff --git a/p2p/pex_reactor.go b/p2p/pex_reactor.go index 3ede6dc0..166c3c5f 100644 --- a/p2p/pex_reactor.go +++ b/p2p/pex_reactor.go @@ -96,6 +96,12 @@ func (pexR *PEXReactor) Receive(chId byte, src *Peer, msgBytes []byte) { log.Info("Received message", "msg", msg) switch msg.(type) { + case *pexHandshakeMessage: + chainId := msg.(*pexHandshakeMessage).ChainId + if chainId != pexR.sw.chainId { + err := fmt.Sprintf("Peer is on a different chain/network. Got %s, expected %s", chainId, pexR.sw.chainId) + pexR.sw.StopPeerForError(src, err) + } case *pexRequestMessage: // src requested some peers. // TODO: prevent abuse. @@ -201,9 +207,10 @@ func (pexR *PEXReactor) ensurePeers() { // Messages const ( - msgTypeUnknown = byte(0x00) - msgTypeRequest = byte(0x01) - msgTypeAddrs = byte(0x02) + msgTypeUnknown = byte(0x00) + msgTypeRequest = byte(0x01) + msgTypeAddrs = byte(0x02) + msgTypeHandshake = byte(0x03) ) // TODO: check for unnecessary extra bytes at the end. @@ -213,6 +220,8 @@ func DecodeMessage(bz []byte) (msg interface{}, err error) { r := bytes.NewReader(bz) // log.Debug(Fmt("decoding msg bytes: %X", bz)) switch msgType { + case msgTypeHandshake: + msg = binary.ReadBinary(&pexHandshakeMessage{}, r, n, &err) case msgTypeRequest: msg = &pexRequestMessage{} case msgTypeAddrs: @@ -223,6 +232,19 @@ func DecodeMessage(bz []byte) (msg interface{}, err error) { return } +/* +A pexHandshakeMessage contains the peer's chainId +*/ +type pexHandshakeMessage struct { + ChainId string +} + +func (m *pexHandshakeMessage) TypeByte() byte { return msgTypeHandshake } + +func (m *pexHandshakeMessage) String() string { + return "[pexHandshake]" +} + /* A pexRequestMessage requests additional peer addresses. */ diff --git a/p2p/switch.go b/p2p/switch.go index b9696d24..0b348003 100644 --- a/p2p/switch.go +++ b/p2p/switch.go @@ -1,6 +1,7 @@ package p2p import ( + "encoding/hex" "errors" "fmt" "net" @@ -37,6 +38,7 @@ type Switch struct { quit chan struct{} started uint32 stopped uint32 + chainId string } var ( @@ -129,6 +131,10 @@ func (sw *Switch) AddPeerWithConnection(conn net.Conn, outbound bool) (*Peer, er // Notify listeners. sw.doAddPeer(peer) + // Send handshake + msg := &pexHandshakeMessage{ChainId: sw.chainId} + peer.Send(PexCh, msg) + return peer, nil } @@ -216,6 +222,10 @@ func (sw *Switch) StopPeerGracefully(peer *Peer) { sw.doRemovePeer(peer, nil) } +func (sw *Switch) SetChainId(hash []byte, network string) { + sw.chainId = hex.EncodeToString(hash) + "-" + network +} + func (sw *Switch) IsListening() bool { return sw.listeners.Size() > 0 }