diff --git a/src/main.cpp b/src/main.cpp index 22baf0f3e..340614459 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -3168,7 +3168,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) else if (strCommand == "verack") { - pfrom->vRecv.SetVersion(min(pfrom->nVersion, PROTOCOL_VERSION)); + pfrom->SetRecvVersion(min(pfrom->nVersion, PROTOCOL_VERSION)); } @@ -3705,13 +3705,13 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv) return true; } +// requires LOCK(cs_vRecvMsg) bool ProcessMessages(CNode* pfrom) { - CDataStream& vRecv = pfrom->vRecv; - if (vRecv.empty()) + if (pfrom->vRecvMsg.empty()) return true; //if (fDebug) - // printf("ProcessMessages(%u bytes)\n", vRecv.size()); + // printf("ProcessMessages(%zu messages)\n", pfrom->vRecvMsg.size()); // // Message format @@ -3722,32 +3722,32 @@ bool ProcessMessages(CNode* pfrom) // (x) data // - loop + unsigned int nMsgPos = 0; + for (; nMsgPos < pfrom->vRecvMsg.size(); nMsgPos++) { // Don't bother if send buffer is too full to respond anyway if (pfrom->vSend.size() >= SendBufferSize()) break; - // Scan for message start - CDataStream::iterator pstart = search(vRecv.begin(), vRecv.end(), BEGIN(pchMessageStart), END(pchMessageStart)); - int nHeaderSize = vRecv.GetSerializeSize(CMessageHeader()); - if (vRecv.end() - pstart < nHeaderSize) - { - if ((int)vRecv.size() > nHeaderSize) - { - printf("\n\nPROCESSMESSAGE MESSAGESTART NOT FOUND\n\n"); - vRecv.erase(vRecv.begin(), vRecv.end() - nHeaderSize); - } + // get next message; end, if an incomplete message is found + CNetMessage& msg = pfrom->vRecvMsg[nMsgPos]; + + //if (fDebug) + // printf("ProcessMessages(message %u msgsz, %zu bytes, complete:%s)\n", + // msg.hdr.nMessageSize, msg.vRecv.size(), + // msg.complete() ? "Y" : "N"); + + if (!msg.complete()) break; + + // Scan for message start + if (memcmp(msg.hdr.pchMessageStart, pchMessageStart, sizeof(pchMessageStart)) != 0) { + printf("\n\nPROCESSMESSAGE: INVALID MESSAGESTART\n\n"); + return false; } - if (pstart - vRecv.begin() > 0) - printf("\n\nPROCESSMESSAGE SKIPPED %"PRIpdd" BYTES\n\n", pstart - vRecv.begin()); - vRecv.erase(vRecv.begin(), pstart); // Read header - vector vHeaderSave(vRecv.begin(), vRecv.begin() + nHeaderSize); - CMessageHeader hdr; - vRecv >> hdr; + CMessageHeader& hdr = msg.hdr; if (!hdr.IsValid()) { printf("\n\nPROCESSMESSAGE: ERRORS IN HEADER %s\n\n\n", hdr.GetCommand().c_str()); @@ -3757,19 +3757,9 @@ bool ProcessMessages(CNode* pfrom) // Message size unsigned int nMessageSize = hdr.nMessageSize; - if (nMessageSize > MAX_SIZE) - { - printf("ProcessMessages(%s, %u bytes) : nMessageSize > MAX_SIZE\n", strCommand.c_str(), nMessageSize); - continue; - } - if (nMessageSize > vRecv.size()) - { - // Rewind and wait for rest of message - vRecv.insert(vRecv.begin(), vHeaderSave.begin(), vHeaderSave.end()); - break; - } // Checksum + CDataStream& vRecv = msg.vRecv; uint256 hash = Hash(vRecv.begin(), vRecv.begin() + nMessageSize); unsigned int nChecksum = 0; memcpy(&nChecksum, &hash, sizeof(nChecksum)); @@ -3780,17 +3770,13 @@ bool ProcessMessages(CNode* pfrom) continue; } - // Copy message to its own buffer - CDataStream vMsg(vRecv.begin(), vRecv.begin() + nMessageSize, vRecv.nType, vRecv.nVersion); - vRecv.ignore(nMessageSize); - // Process message bool fRet = false; try { { LOCK(cs_main); - fRet = ProcessMessage(pfrom, strCommand, vMsg); + fRet = ProcessMessage(pfrom, strCommand, vRecv); } if (fShutdown) return true; @@ -3822,7 +3808,10 @@ bool ProcessMessages(CNode* pfrom) printf("ProcessMessage(%s, %u bytes) FAILED\n", strCommand.c_str(), nMessageSize); } - vRecv.Compact(); + // remove processed messages; one incomplete message may remain + if (nMsgPos > 0) + pfrom->vRecvMsg.erase(pfrom->vRecvMsg.begin(), + pfrom->vRecvMsg.begin() + nMsgPos); return true; } diff --git a/src/net.cpp b/src/net.cpp index 6c8fe3ffc..0e558228d 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -536,7 +536,7 @@ void CNode::CloseSocketDisconnect() printf("disconnecting node %s\n", addrName.c_str()); closesocket(hSocket); hSocket = INVALID_SOCKET; - vRecv.clear(); + vRecvMsg.clear(); } } @@ -628,6 +628,78 @@ void CNode::copyStats(CNodeStats &stats) } #undef X +// requires LOCK(cs_vRecvMsg) +bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes) +{ + while (nBytes > 0) { + + // get current incomplete message, or create a new one + if (vRecvMsg.size() == 0 || + vRecvMsg.back().complete()) + vRecvMsg.push_back(CNetMessage(SER_NETWORK, nRecvVersion)); + + CNetMessage& msg = vRecvMsg.back(); + + // absorb network data + int handled; + if (!msg.in_data) + handled = msg.readHeader(pch, nBytes); + else + handled = msg.readData(pch, nBytes); + + if (handled < 0) + return false; + + pch += handled; + nBytes -= handled; + } + + return true; +} + +int CNetMessage::readHeader(const char *pch, unsigned int nBytes) +{ + // copy data to temporary parsing buffer + unsigned int nRemaining = 24 - nHdrPos; + unsigned int nCopy = std::min(nRemaining, nBytes); + + memcpy(&hdrbuf[nHdrPos], pch, nCopy); + nHdrPos += nCopy; + + // if header incomplete, exit + if (nHdrPos < 24) + return nCopy; + + // deserialize to CMessageHeader + try { + hdrbuf >> hdr; + } + catch (std::exception &e) { + return -1; + } + + // reject messages larger than MAX_SIZE + if (hdr.nMessageSize > MAX_SIZE) + return -1; + + // switch state to reading message data + in_data = true; + vRecv.resize(hdr.nMessageSize); + + return nCopy; +} + +int CNetMessage::readData(const char *pch, unsigned int nBytes) +{ + unsigned int nRemaining = hdr.nMessageSize - nDataPos; + unsigned int nCopy = std::min(nRemaining, nBytes); + + memcpy(&vRecv[nDataPos], pch, nCopy); + nDataPos += nCopy; + + return nCopy; +} + @@ -676,7 +748,7 @@ void ThreadSocketHandler2(void* parg) BOOST_FOREACH(CNode* pnode, vNodesCopy) { if (pnode->fDisconnect || - (pnode->GetRefCount() <= 0 && pnode->vRecv.empty() && pnode->vSend.empty())) + (pnode->GetRefCount() <= 0 && pnode->vRecvMsg.empty() && pnode->vSend.empty())) { // remove from vNodes vNodes.erase(remove(vNodes.begin(), vNodes.end(), pnode), vNodes.end()); @@ -708,7 +780,7 @@ void ThreadSocketHandler2(void* parg) TRY_LOCK(pnode->cs_vSend, lockSend); if (lockSend) { - TRY_LOCK(pnode->cs_vRecv, lockRecv); + TRY_LOCK(pnode->cs_vRecvMsg, lockRecv); if (lockRecv) { TRY_LOCK(pnode->cs_inventory, lockInv); @@ -873,15 +945,12 @@ void ThreadSocketHandler2(void* parg) continue; if (FD_ISSET(pnode->hSocket, &fdsetRecv) || FD_ISSET(pnode->hSocket, &fdsetError)) { - TRY_LOCK(pnode->cs_vRecv, lockRecv); + TRY_LOCK(pnode->cs_vRecvMsg, lockRecv); if (lockRecv) { - CDataStream& vRecv = pnode->vRecv; - unsigned int nPos = vRecv.size(); - - if (nPos > ReceiveBufferSize()) { + if (pnode->GetTotalRecvSize() > ReceiveFloodSize()) { if (!pnode->fDisconnect) - printf("socket recv flood control disconnect (%"PRIszu" bytes)\n", vRecv.size()); + printf("socket recv flood control disconnect (%u bytes)\n", pnode->GetTotalRecvSize()); pnode->CloseSocketDisconnect(); } else { @@ -890,8 +959,8 @@ void ThreadSocketHandler2(void* parg) int nBytes = recv(pnode->hSocket, pchBuf, sizeof(pchBuf), MSG_DONTWAIT); if (nBytes > 0) { - vRecv.resize(nPos + nBytes); - memcpy(&vRecv[nPos], pchBuf, nBytes); + if (!pnode->ReceiveMsgBytes(pchBuf, nBytes)) + pnode->CloseSocketDisconnect(); pnode->nLastRecv = GetTime(); } else if (nBytes == 0) @@ -1693,9 +1762,10 @@ void ThreadMessageHandler2(void* parg) { // Receive messages { - TRY_LOCK(pnode->cs_vRecv, lockRecv); + TRY_LOCK(pnode->cs_vRecvMsg, lockRecv); if (lockRecv) - ProcessMessages(pnode); + if (!ProcessMessages(pnode)) + pnode->CloseSocketDisconnect(); } if (fShutdown) return; diff --git a/src/net.h b/src/net.h index 3b46523cd..78f8e72fb 100644 --- a/src/net.h +++ b/src/net.h @@ -27,7 +27,7 @@ extern int nBestHeight; -inline unsigned int ReceiveBufferSize() { return 1000*GetArg("-maxreceivebuffer", 5*1000); } +inline unsigned int ReceiveFloodSize() { return 1000*GetArg("-maxreceivebuffer", 5*1000); } inline unsigned int SendBufferSize() { return 1000*GetArg("-maxsendbuffer", 1*1000); } void AddOneShot(std::string strDest); @@ -126,6 +126,44 @@ public: +class CNetMessage { +public: + bool in_data; // parsing header (false) or data (true) + + CDataStream hdrbuf; // partially received header + CMessageHeader hdr; // complete header + unsigned int nHdrPos; + + CDataStream vRecv; // received message data + unsigned int nDataPos; + + CNetMessage(int nTypeIn, int nVersionIn) : hdrbuf(nTypeIn, nVersionIn), vRecv(nTypeIn, nVersionIn) { + hdrbuf.resize(24); + in_data = false; + nHdrPos = 0; + nDataPos = 0; + } + + bool complete() const + { + if (!in_data) + return false; + return (hdr.nMessageSize == nDataPos); + } + + void SetVersion(int nVersionIn) + { + hdrbuf.SetVersion(nVersionIn); + vRecv.SetVersion(nVersionIn); + } + + int readHeader(const char *pch, unsigned int nBytes); + int readData(const char *pch, unsigned int nBytes); +}; + + + + /** Information about a peer */ class CNode @@ -135,9 +173,12 @@ public: uint64 nServices; SOCKET hSocket; CDataStream vSend; - CDataStream vRecv; CCriticalSection cs_vSend; - CCriticalSection cs_vRecv; + + std::vector vRecvMsg; + CCriticalSection cs_vRecvMsg; + int nRecvVersion; + int64 nLastSend; int64 nLastRecv; int64 nLastSendEmpty; @@ -191,10 +232,11 @@ public: CCriticalSection cs_inventory; std::multimap mapAskFor; - CNode(SOCKET hSocketIn, CAddress addrIn, std::string addrNameIn = "", bool fInboundIn=false) : vSend(SER_NETWORK, MIN_PROTO_VERSION), vRecv(SER_NETWORK, MIN_PROTO_VERSION) + CNode(SOCKET hSocketIn, CAddress addrIn, std::string addrNameIn = "", bool fInboundIn=false) : vSend(SER_NETWORK, MIN_PROTO_VERSION) { nServices = 0; hSocket = hSocketIn; + nRecvVersion = MIN_PROTO_VERSION; nLastSend = 0; nLastRecv = 0; nLastSendEmpty = GetTime(); @@ -250,6 +292,26 @@ public: return std::max(nRefCount, 0) + (GetTime() < nReleaseTime ? 1 : 0); } + // requires LOCK(cs_vRecvMsg) + unsigned int GetTotalRecvSize() + { + unsigned int total = 0; + for (unsigned int i = 0; i < vRecvMsg.size(); i++) + total += vRecvMsg[i].vRecv.size(); + return total; + } + + // requires LOCK(cs_vRecvMsg) + bool ReceiveMsgBytes(const char *pch, unsigned int nBytes); + + // requires LOCK(cs_vRecvMsg) + void SetRecvVersion(int nVersionIn) + { + nRecvVersion = nVersionIn; + for (unsigned int i = 0; i < vRecvMsg.size(); i++) + vRecvMsg[i].SetVersion(nVersionIn); + } + CNode* AddRef(int64 nTimeout=0) { if (nTimeout != 0)