diff --git a/src/net.cpp b/src/net.cpp index df88b12c7..db914096f 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -689,6 +689,33 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete return true; } +void CNode::SetSendVersion(int nVersionIn) +{ + // Send version may only be changed in the version message, and + // only one version message is allowed per session. We can therefore + // treat this value as const and even atomic as long as it's only used + // once a version message has been successfully processed. Any attempt to + // set this twice is an error. + if (nSendVersion != 0) { + error("Send version already set for node: %i. Refusing to change from %i to %i", id, nSendVersion, nVersionIn); + } else { + nSendVersion = nVersionIn; + } +} + +int CNode::GetSendVersion() const +{ + // The send version should always be explicitly set to + // INIT_PROTO_VERSION rather than using this value until SetSendVersion + // has been called. + if (nSendVersion == 0) { + error("Requesting unset send version for node: %i. Using %i", id, INIT_PROTO_VERSION); + return INIT_PROTO_VERSION; + } + return nSendVersion; +} + + int CNetMessage::readHeader(const char *pch, unsigned int nBytes) { // copy data to temporary parsing buffer @@ -2630,6 +2657,11 @@ void CNode::AskFor(const CInv& inv) mapAskFor.insert(std::make_pair(nRequestTime, inv)); } +bool CConnman::NodeFullyConnected(const CNode* pnode) +{ + return pnode && pnode->fSuccessfullyConnected && !pnode->fDisconnect; +} + void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg) { size_t nMessageSize = msg.data.size(); @@ -2680,7 +2712,7 @@ bool CConnman::ForNode(NodeId id, std::function func) break; } } - return found != nullptr && func(found); + return found != nullptr && NodeFullyConnected(found) && func(found); } int64_t PoissonNextSend(int64_t nNow, int average_interval_seconds) { diff --git a/src/net.h b/src/net.h index 0b8efcc88..1e3033785 100644 --- a/src/net.h +++ b/src/net.h @@ -161,76 +161,34 @@ public: void PushMessage(CNode* pnode, CSerializedNetMsg&& msg); - template - bool ForEachNodeContinueIf(Callable&& func) - { - LOCK(cs_vNodes); - for (auto&& node : vNodes) - if(!func(node)) - return false; - return true; - }; - - template - bool ForEachNodeContinueIf(Callable&& func) const - { - LOCK(cs_vNodes); - for (const auto& node : vNodes) - if(!func(node)) - return false; - return true; - }; - - template - bool ForEachNodeContinueIfThen(Callable&& pre, CallableAfter&& post) - { - bool ret = true; - LOCK(cs_vNodes); - for (auto&& node : vNodes) - if(!pre(node)) { - ret = false; - break; - } - post(); - return ret; - }; - - template - bool ForEachNodeContinueIfThen(Callable&& pre, CallableAfter&& post) const - { - bool ret = true; - LOCK(cs_vNodes); - for (const auto& node : vNodes) - if(!pre(node)) { - ret = false; - break; - } - post(); - return ret; - }; - template void ForEachNode(Callable&& func) { LOCK(cs_vNodes); - for (auto&& node : vNodes) - func(node); + for (auto&& node : vNodes) { + if (NodeFullyConnected(node)) + func(node); + } }; template void ForEachNode(Callable&& func) const { LOCK(cs_vNodes); - for (const auto& node : vNodes) - func(node); + for (auto&& node : vNodes) { + if (NodeFullyConnected(node)) + func(node); + } }; template void ForEachNodeThen(Callable&& pre, CallableAfter&& post) { LOCK(cs_vNodes); - for (auto&& node : vNodes) - pre(node); + for (auto&& node : vNodes) { + if (NodeFullyConnected(node)) + pre(node); + } post(); }; @@ -238,8 +196,10 @@ public: void ForEachNodeThen(Callable&& pre, CallableAfter&& post) const { LOCK(cs_vNodes); - for (const auto& node : vNodes) - pre(node); + for (auto&& node : vNodes) { + if (NodeFullyConnected(node)) + pre(node); + } post(); }; @@ -372,6 +332,9 @@ private: void RecordBytesRecv(uint64_t bytes); void RecordBytesSent(uint64_t bytes); + // Whether the node should be passed out in ForEach* callbacks + static bool NodeFullyConnected(const CNode* pnode); + // Network usage totals CCriticalSection cs_totalBytesRecv; CCriticalSection cs_totalBytesSent; @@ -627,7 +590,7 @@ public: const CAddress addr; std::string addrName; CService addrLocal; - int nVersion; + std::atomic nVersion; // strSubVer is whatever byte array we read from the wire. However, this field is intended // to be printed out, displayed to humans in various forms and so on. So we sanitize it and // store the sanitized version in cleanSubVer. The original should be used when dealing with @@ -639,7 +602,7 @@ public: bool fAddnode; bool fClient; const bool fInbound; - bool fSuccessfullyConnected; + std::atomic_bool fSuccessfullyConnected; std::atomic_bool fDisconnect; // We use fRelayTxes for two purposes - // a) it allows us to not relay tx invs before receiving the peer's version message @@ -760,25 +723,8 @@ public: { return nRecvVersion; } - void SetSendVersion(int nVersionIn) - { - // Send version may only be changed in the version message, and - // only one version message is allowed per session. We can therefore - // treat this value as const and even atomic as long as it's only used - // once the handshake is complete. Any attempt to set this twice is an - // error. - assert(nSendVersion == 0); - nSendVersion = nVersionIn; - } - - int GetSendVersion() const - { - // The send version should always be explicitly set to - // INIT_PROTO_VERSION rather than using this value until the handshake - // is complete. - assert(nSendVersion != 0); - return nSendVersion; - } + void SetSendVersion(int nVersionIn); + int GetSendVersion() const; CNode* AddRef() { diff --git a/src/net_processing.cpp b/src/net_processing.cpp index b9667eb6c..a7acd6edf 100644 --- a/src/net_processing.cpp +++ b/src/net_processing.cpp @@ -1199,50 +1199,51 @@ bool static ProcessMessage(CNode* pfrom, std::string strCommand, CDataStream& vR CAddress addrFrom; uint64_t nNonce = 1; uint64_t nServiceInt; - vRecv >> pfrom->nVersion >> nServiceInt >> nTime >> addrMe; - pfrom->nServices = ServiceFlags(nServiceInt); + ServiceFlags nServices; + int nVersion; + int nSendVersion; + std::string strSubVer; + int nStartingHeight = -1; + bool fRelay = true; + + vRecv >> nVersion >> nServiceInt >> nTime >> addrMe; + nSendVersion = std::min(nVersion, PROTOCOL_VERSION); + nServices = ServiceFlags(nServiceInt); if (!pfrom->fInbound) { - connman.SetServices(pfrom->addr, pfrom->nServices); + connman.SetServices(pfrom->addr, nServices); } - if (pfrom->nServicesExpected & ~pfrom->nServices) + if (pfrom->nServicesExpected & ~nServices) { - LogPrint("net", "peer=%d does not offer the expected services (%08x offered, %08x expected); disconnecting\n", pfrom->id, pfrom->nServices, pfrom->nServicesExpected); + LogPrint("net", "peer=%d does not offer the expected services (%08x offered, %08x expected); disconnecting\n", pfrom->id, nServices, pfrom->nServicesExpected); connman.PushMessage(pfrom, CNetMsgMaker(INIT_PROTO_VERSION).Make(NetMsgType::REJECT, strCommand, REJECT_NONSTANDARD, strprintf("Expected to offer services %08x", pfrom->nServicesExpected))); pfrom->fDisconnect = true; return false; } - if (pfrom->nVersion < MIN_PEER_PROTO_VERSION) + if (nVersion < MIN_PEER_PROTO_VERSION) { // disconnect from peers older than this proto version - LogPrintf("peer=%d using obsolete version %i; disconnecting\n", pfrom->id, pfrom->nVersion); + LogPrintf("peer=%d using obsolete version %i; disconnecting\n", pfrom->id, nVersion); connman.PushMessage(pfrom, CNetMsgMaker(INIT_PROTO_VERSION).Make(NetMsgType::REJECT, strCommand, REJECT_OBSOLETE, strprintf("Version must be %d or greater", MIN_PEER_PROTO_VERSION))); pfrom->fDisconnect = true; return false; } - if (pfrom->nVersion == 10300) - pfrom->nVersion = 300; + if (nVersion == 10300) + nVersion = 300; if (!vRecv.empty()) vRecv >> addrFrom >> nNonce; if (!vRecv.empty()) { - vRecv >> LIMITED_STRING(pfrom->strSubVer, MAX_SUBVERSION_LENGTH); - pfrom->cleanSubVer = SanitizeString(pfrom->strSubVer); + vRecv >> LIMITED_STRING(strSubVer, MAX_SUBVERSION_LENGTH); } if (!vRecv.empty()) { - vRecv >> pfrom->nStartingHeight; + vRecv >> nStartingHeight; } - { - LOCK(pfrom->cs_filter); - if (!vRecv.empty()) - vRecv >> pfrom->fRelayTxes; // set to true after we get the first filter* message - else - pfrom->fRelayTxes = true; - } - + if (!vRecv.empty()) + vRecv >> fRelay; // Disconnect if we connected to ourself if (pfrom->fInbound && !connman.CheckIncomingNonce(nNonce)) { @@ -1251,7 +1252,6 @@ bool static ProcessMessage(CNode* pfrom, std::string strCommand, CDataStream& vR return true; } - pfrom->addrLocal = addrMe; if (pfrom->fInbound && addrMe.IsRoutable()) { SeenLocal(addrMe); @@ -1261,9 +1261,24 @@ bool static ProcessMessage(CNode* pfrom, std::string strCommand, CDataStream& vR if (pfrom->fInbound) PushNodeVersion(pfrom, connman, GetAdjustedTime()); - pfrom->fClient = !(pfrom->nServices & NODE_NETWORK); + connman.PushMessage(pfrom, CNetMsgMaker(INIT_PROTO_VERSION).Make(NetMsgType::VERACK)); - if((pfrom->nServices & NODE_WITNESS)) + pfrom->nServices = nServices; + pfrom->addrLocal = addrMe; + pfrom->strSubVer = strSubVer; + pfrom->cleanSubVer = SanitizeString(strSubVer); + pfrom->nStartingHeight = nStartingHeight; + pfrom->fClient = !(nServices & NODE_NETWORK); + { + LOCK(pfrom->cs_filter); + pfrom->fRelayTxes = fRelay; // set to true after we get the first filter* message + } + + // Change version + pfrom->SetSendVersion(nSendVersion); + pfrom->nVersion = nVersion; + + if((nServices & NODE_WITNESS)) { LOCK(cs_main); State(pfrom->GetId())->fHaveWitness = true; @@ -1275,11 +1290,6 @@ bool static ProcessMessage(CNode* pfrom, std::string strCommand, CDataStream& vR UpdatePreferredDownload(pfrom, State(pfrom->GetId())); } - // Change version - connman.PushMessage(pfrom, CNetMsgMaker(INIT_PROTO_VERSION).Make(NetMsgType::VERACK)); - int nSendVersion = std::min(pfrom->nVersion, PROTOCOL_VERSION); - pfrom->SetSendVersion(nSendVersion); - if (!pfrom->fInbound) { // Advertise our address @@ -1307,8 +1317,6 @@ bool static ProcessMessage(CNode* pfrom, std::string strCommand, CDataStream& vR connman.MarkAddressGood(pfrom->addr); } - pfrom->fSuccessfullyConnected = true; - std::string remoteAddr; if (fLogIPs) remoteAddr = ", peeraddr=" + pfrom->addr.ToString(); @@ -1350,7 +1358,7 @@ bool static ProcessMessage(CNode* pfrom, std::string strCommand, CDataStream& vR if (strCommand == NetMsgType::VERACK) { - pfrom->SetRecvVersion(std::min(pfrom->nVersion, PROTOCOL_VERSION)); + pfrom->SetRecvVersion(std::min(pfrom->nVersion.load(), PROTOCOL_VERSION)); if (!pfrom->fInbound) { // Mark this node as currently connected, so we update its timestamp later. @@ -1378,6 +1386,7 @@ bool static ProcessMessage(CNode* pfrom, std::string strCommand, CDataStream& vR nCMPCTBLOCKVersion = 1; connman.PushMessage(pfrom, msgMaker.Make(NetMsgType::SENDCMPCT, fAnnounceUsingCMPCTBLOCK, nCMPCTBLOCKVersion)); } + pfrom->fSuccessfullyConnected = true; } @@ -2716,8 +2725,8 @@ bool SendMessages(CNode* pto, CConnman& connman, std::atomic& interruptMsg { const Consensus::Params& consensusParams = Params().GetConsensus(); { - // Don't send anything until we get its version message - if (pto->nVersion == 0 || pto->fDisconnect) + // Don't send anything until the version handshake is complete + if (!pto->fSuccessfullyConnected || pto->fDisconnect) return true; // If we get here, the outgoing message serialization version is set and can't change. diff --git a/src/test/DoS_tests.cpp b/src/test/DoS_tests.cpp index 9345a44fb..a8f09ba6a 100644 --- a/src/test/DoS_tests.cpp +++ b/src/test/DoS_tests.cpp @@ -55,6 +55,7 @@ BOOST_AUTO_TEST_CASE(DoS_banning) dummyNode1.SetSendVersion(PROTOCOL_VERSION); GetNodeSignals().InitializeNode(&dummyNode1, *connman); dummyNode1.nVersion = 1; + dummyNode1.fSuccessfullyConnected = true; Misbehaving(dummyNode1.GetId(), 100); // Should get banned SendMessages(&dummyNode1, *connman, interruptDummy); BOOST_CHECK(connman->IsBanned(addr1)); @@ -65,6 +66,7 @@ BOOST_AUTO_TEST_CASE(DoS_banning) dummyNode2.SetSendVersion(PROTOCOL_VERSION); GetNodeSignals().InitializeNode(&dummyNode2, *connman); dummyNode2.nVersion = 1; + dummyNode2.fSuccessfullyConnected = true; Misbehaving(dummyNode2.GetId(), 50); SendMessages(&dummyNode2, *connman, interruptDummy); BOOST_CHECK(!connman->IsBanned(addr2)); // 2 not banned yet... @@ -85,6 +87,7 @@ BOOST_AUTO_TEST_CASE(DoS_banscore) dummyNode1.SetSendVersion(PROTOCOL_VERSION); GetNodeSignals().InitializeNode(&dummyNode1, *connman); dummyNode1.nVersion = 1; + dummyNode1.fSuccessfullyConnected = true; Misbehaving(dummyNode1.GetId(), 100); SendMessages(&dummyNode1, *connman, interruptDummy); BOOST_CHECK(!connman->IsBanned(addr1)); @@ -110,6 +113,7 @@ BOOST_AUTO_TEST_CASE(DoS_bantime) dummyNode.SetSendVersion(PROTOCOL_VERSION); GetNodeSignals().InitializeNode(&dummyNode, *connman); dummyNode.nVersion = 1; + dummyNode.fSuccessfullyConnected = true; Misbehaving(dummyNode.GetId(), 100); SendMessages(&dummyNode, *connman, interruptDummy);