diff --git a/qa/rpc-tests/nodehandling.py b/qa/rpc-tests/nodehandling.py index 9a77bd97e..d89cfcf59 100755 --- a/qa/rpc-tests/nodehandling.py +++ b/qa/rpc-tests/nodehandling.py @@ -48,7 +48,25 @@ class NodeHandlingTest (BitcoinTestFramework): assert_equal(len(self.nodes[2].listbanned()), 0) self.nodes[2].clearbanned() assert_equal(len(self.nodes[2].listbanned()), 0) - + + ##test persisted banlist + self.nodes[2].setban("127.0.0.0/32", "add") + self.nodes[2].setban("127.0.0.0/24", "add") + self.nodes[2].setban("192.168.0.1", "add", 1) #ban for 1 seconds + self.nodes[2].setban("2001:4d48:ac57:400:cacf:e9ff:fe1d:9c63/19", "add", 1000) #ban for 1000 seconds + listBeforeShutdown = self.nodes[2].listbanned(); + assert_equal("192.168.0.1/255.255.255.255", listBeforeShutdown[2]['address']) #must be here + time.sleep(2) #make 100% sure we expired 192.168.0.1 node time + + #stop node + stop_node(self.nodes[2], 2) + + self.nodes[2] = start_node(2, self.options.tmpdir) + listAfterShutdown = self.nodes[2].listbanned(); + assert_equal("127.0.0.0/255.255.255.0", listAfterShutdown[0]['address']) + assert_equal("127.0.0.0/255.255.255.255", listAfterShutdown[1]['address']) + assert_equal("2001:4000::/ffff:e000:0:0:0:0:0:0", listAfterShutdown[2]['address']) + ########################### # RPC disconnectnode test # ########################### diff --git a/src/main.cpp b/src/main.cpp index 6c4cfe75a..a000a81fd 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -4959,7 +4959,7 @@ bool SendMessages(CNode* pto, bool fSendTrickle) LogPrintf("Warning: not banning local peer %s!\n", pto->addr.ToString()); else { - CNode::Ban(pto->addr); + CNode::Ban(pto->addr, BanReasonNodeMisbehaving); } } state.fShouldBan = false; diff --git a/src/net.cpp b/src/net.cpp index 0511256e5..950311ee3 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -443,13 +443,15 @@ void CNode::PushVersion() -std::map CNode::setBanned; +banmap_t CNode::setBanned; CCriticalSection CNode::cs_setBanned; +bool CNode::setBannedIsDirty; void CNode::ClearBanned() { LOCK(cs_setBanned); setBanned.clear(); + setBannedIsDirty = true; } bool CNode::IsBanned(CNetAddr ip) @@ -457,12 +459,12 @@ bool CNode::IsBanned(CNetAddr ip) bool fResult = false; { LOCK(cs_setBanned); - for (std::map::iterator it = setBanned.begin(); it != setBanned.end(); it++) + for (banmap_t::iterator it = setBanned.begin(); it != setBanned.end(); it++) { CSubNet subNet = (*it).first; - int64_t t = (*it).second; + CBanEntry banEntry = (*it).second; - if(subNet.Match(ip) && GetTime() < t) + if(subNet.Match(ip) && GetTime() < banEntry.nBanUntil) fResult = true; } } @@ -474,50 +476,99 @@ bool CNode::IsBanned(CSubNet subnet) bool fResult = false; { LOCK(cs_setBanned); - std::map::iterator i = setBanned.find(subnet); + banmap_t::iterator i = setBanned.find(subnet); if (i != setBanned.end()) { - int64_t t = (*i).second; - if (GetTime() < t) + CBanEntry banEntry = (*i).second; + if (GetTime() < banEntry.nBanUntil) fResult = true; } } return fResult; } -void CNode::Ban(const CNetAddr& addr, int64_t bantimeoffset, bool sinceUnixEpoch) { - CSubNet subNet(addr.ToString()+(addr.IsIPv4() ? "/32" : "/128")); - Ban(subNet, bantimeoffset, sinceUnixEpoch); +void CNode::Ban(const CNetAddr& addr, const BanReason &banReason, int64_t bantimeoffset, bool sinceUnixEpoch) { + CSubNet subNet(addr); + Ban(subNet, banReason, bantimeoffset, sinceUnixEpoch); } -void CNode::Ban(const CSubNet& subNet, int64_t bantimeoffset, bool sinceUnixEpoch) { - int64_t banTime = GetTime()+GetArg("-bantime", 60*60*24); // Default 24-hour ban - if (bantimeoffset > 0) - banTime = (sinceUnixEpoch ? 0 : GetTime() )+bantimeoffset; +void CNode::Ban(const CSubNet& subNet, const BanReason &banReason, int64_t bantimeoffset, bool sinceUnixEpoch) { + CBanEntry banEntry(GetTime()); + banEntry.banReason = banReason; + if (bantimeoffset <= 0) + { + bantimeoffset = GetArg("-bantime", 60*60*24); // Default 24-hour ban + sinceUnixEpoch = false; + } + banEntry.nBanUntil = (sinceUnixEpoch ? 0 : GetTime() )+bantimeoffset; + LOCK(cs_setBanned); - if (setBanned[subNet] < banTime) - setBanned[subNet] = banTime; + if (setBanned[subNet].nBanUntil < banEntry.nBanUntil) + setBanned[subNet] = banEntry; + + setBannedIsDirty = true; } bool CNode::Unban(const CNetAddr &addr) { - CSubNet subNet(addr.ToString()+(addr.IsIPv4() ? "/32" : "/128")); + CSubNet subNet(addr); return Unban(subNet); } bool CNode::Unban(const CSubNet &subNet) { LOCK(cs_setBanned); if (setBanned.erase(subNet)) + { + setBannedIsDirty = true; return true; + } return false; } -void CNode::GetBanned(std::map &banMap) +void CNode::GetBanned(banmap_t &banMap) { LOCK(cs_setBanned); banMap = setBanned; //create a thread safe copy } +void CNode::SetBanned(const banmap_t &banMap) +{ + LOCK(cs_setBanned); + setBanned = banMap; + setBannedIsDirty = true; +} + +void CNode::SweepBanned() +{ + int64_t now = GetTime(); + + LOCK(cs_setBanned); + banmap_t::iterator it = setBanned.begin(); + while(it != setBanned.end()) + { + CBanEntry banEntry = (*it).second; + if(now > banEntry.nBanUntil) + { + setBanned.erase(it++); + setBannedIsDirty = true; + } + else + ++it; + } +} + +bool CNode::BannedSetIsDirty() +{ + LOCK(cs_setBanned); + return setBannedIsDirty; +} + +void CNode::SetBannedSetDirty(bool dirty) +{ + LOCK(cs_setBanned); //reuse setBanned lock for the isDirty flag + setBannedIsDirty = dirty; +} + std::vector CNode::vWhitelistedRange; CCriticalSection CNode::cs_vWhitelistedRange; @@ -1212,6 +1263,17 @@ void DumpAddresses() addrman.size(), GetTimeMillis() - nStart); } +void DumpData() +{ + DumpAddresses(); + + if (CNode::BannedSetIsDirty()) + { + DumpBanlist(); + CNode::SetBannedSetDirty(false); + } +} + void static ProcessOneShot() { string strDest; @@ -1650,6 +1712,17 @@ void StartNode(boost::thread_group& threadGroup, CScheduler& scheduler) if (!adb.Read(addrman)) LogPrintf("Invalid or missing peers.dat; recreating\n"); } + + //try to read stored banlist + CBanDB bandb; + banmap_t banmap; + if (!bandb.Read(banmap)) + LogPrintf("Invalid or missing banlist.dat; recreating\n"); + + CNode::SetBanned(banmap); //thread save setter + CNode::SetBannedSetDirty(false); //no need to write down just read or nonexistent data + CNode::SweepBanned(); //sweap out unused entries + LogPrintf("Loaded %i addresses from peers.dat %dms\n", addrman.size(), GetTimeMillis() - nStart); fAddressesInitialized = true; @@ -1690,7 +1763,7 @@ void StartNode(boost::thread_group& threadGroup, CScheduler& scheduler) threadGroup.create_thread(boost::bind(&TraceThread, "msghand", &ThreadMessageHandler)); // Dump network addresses - scheduler.scheduleEvery(&DumpAddresses, DUMP_ADDRESSES_INTERVAL); + scheduler.scheduleEvery(&DumpData, DUMP_ADDRESSES_INTERVAL); } bool StopNode() @@ -1703,7 +1776,7 @@ bool StopNode() if (fAddressesInitialized) { - DumpAddresses(); + DumpData(); fAddressesInitialized = false; } @@ -1907,11 +1980,11 @@ bool CAddrDB::Read(CAddrMan& addr) return error("%s: Failed to open file %s", __func__, pathAddr.string()); // use file size to size memory buffer - int fileSize = boost::filesystem::file_size(pathAddr); - int dataSize = fileSize - sizeof(uint256); + uint64_t fileSize = boost::filesystem::file_size(pathAddr); + uint64_t dataSize = 0; // Don't try to resize to a negative number if file is small - if (dataSize < 0) - dataSize = 0; + if (fileSize >= sizeof(uint256)) + dataSize = fileSize - sizeof(uint256); vector vchData; vchData.resize(dataSize); uint256 hashIn; @@ -2107,3 +2180,119 @@ void CNode::EndMessage() UNLOCK_FUNCTION(cs_vSend) LEAVE_CRITICAL_SECTION(cs_vSend); } + +// +// CBanDB +// + +CBanDB::CBanDB() +{ + pathBanlist = GetDataDir() / "banlist.dat"; +} + +bool CBanDB::Write(const banmap_t& banSet) +{ + // Generate random temporary filename + unsigned short randv = 0; + GetRandBytes((unsigned char*)&randv, sizeof(randv)); + std::string tmpfn = strprintf("banlist.dat.%04x", randv); + + // serialize banlist, checksum data up to that point, then append csum + CDataStream ssBanlist(SER_DISK, CLIENT_VERSION); + ssBanlist << FLATDATA(Params().MessageStart()); + ssBanlist << banSet; + uint256 hash = Hash(ssBanlist.begin(), ssBanlist.end()); + ssBanlist << hash; + + // open temp output file, and associate with CAutoFile + boost::filesystem::path pathTmp = GetDataDir() / tmpfn; + FILE *file = fopen(pathTmp.string().c_str(), "wb"); + CAutoFile fileout(file, SER_DISK, CLIENT_VERSION); + if (fileout.IsNull()) + return error("%s: Failed to open file %s", __func__, pathTmp.string()); + + // Write and commit header, data + try { + fileout << ssBanlist; + } + catch (const std::exception& e) { + return error("%s: Serialize or I/O error - %s", __func__, e.what()); + } + FileCommit(fileout.Get()); + fileout.fclose(); + + // replace existing banlist.dat, if any, with new banlist.dat.XXXX + if (!RenameOver(pathTmp, pathBanlist)) + return error("%s: Rename-into-place failed", __func__); + + return true; +} + +bool CBanDB::Read(banmap_t& banSet) +{ + // open input file, and associate with CAutoFile + FILE *file = fopen(pathBanlist.string().c_str(), "rb"); + CAutoFile filein(file, SER_DISK, CLIENT_VERSION); + if (filein.IsNull()) + return error("%s: Failed to open file %s", __func__, pathBanlist.string()); + + // use file size to size memory buffer + uint64_t fileSize = boost::filesystem::file_size(pathBanlist); + uint64_t dataSize = 0; + // Don't try to resize to a negative number if file is small + if (fileSize >= sizeof(uint256)) + dataSize = fileSize - sizeof(uint256); + vector vchData; + vchData.resize(dataSize); + uint256 hashIn; + + // read data and checksum from file + try { + filein.read((char *)&vchData[0], dataSize); + filein >> hashIn; + } + catch (const std::exception& e) { + return error("%s: Deserialize or I/O error - %s", __func__, e.what()); + } + filein.fclose(); + + CDataStream ssBanlist(vchData, SER_DISK, CLIENT_VERSION); + + // verify stored checksum matches input data + uint256 hashTmp = Hash(ssBanlist.begin(), ssBanlist.end()); + if (hashIn != hashTmp) + return error("%s: Checksum mismatch, data corrupted", __func__); + + unsigned char pchMsgTmp[4]; + try { + // de-serialize file header (network specific magic number) and .. + ssBanlist >> FLATDATA(pchMsgTmp); + + // ... verify the network matches ours + if (memcmp(pchMsgTmp, Params().MessageStart(), sizeof(pchMsgTmp))) + return error("%s: Invalid network magic number", __func__); + + // de-serialize address data into one CAddrMan object + ssBanlist >> banSet; + } + catch (const std::exception& e) { + return error("%s: Deserialize or I/O error - %s", __func__, e.what()); + } + + return true; +} + +void DumpBanlist() +{ + int64_t nStart = GetTimeMillis(); + + CNode::SweepBanned(); //clean unused entires (if bantime has expired) + + CBanDB bandb; + banmap_t banmap; + CNode::GetBanned(banmap); + bandb.Write(banmap); + + LogPrint("net", "Flushed %d banned node ips/subnets to banlist.dat %dms\n", + banmap.size(), GetTimeMillis() - nStart); +} \ No newline at end of file diff --git a/src/net.h b/src/net.h index 69e4c592a..f15b85474 100644 --- a/src/net.h +++ b/src/net.h @@ -228,8 +228,66 @@ public: }; +typedef enum BanReason +{ + BanReasonUnknown = 0, + BanReasonNodeMisbehaving = 1, + BanReasonManuallyAdded = 2 +} BanReason; +class CBanEntry +{ +public: + static const int CURRENT_VERSION=1; + int nVersion; + int64_t nCreateTime; + int64_t nBanUntil; + uint8_t banReason; + CBanEntry() + { + SetNull(); + } + + CBanEntry(int64_t nCreateTimeIn) + { + SetNull(); + nCreateTime = nCreateTimeIn; + } + + ADD_SERIALIZE_METHODS; + + template + inline void SerializationOp(Stream& s, Operation ser_action, int nType, int nVersion) { + READWRITE(this->nVersion); + nVersion = this->nVersion; + READWRITE(nCreateTime); + READWRITE(nBanUntil); + READWRITE(banReason); + } + + void SetNull() + { + nVersion = CBanEntry::CURRENT_VERSION; + nCreateTime = 0; + nBanUntil = 0; + banReason = BanReasonUnknown; + } + + std::string banReasonToString() + { + switch (banReason) { + case BanReasonNodeMisbehaving: + return "node misbehabing"; + case BanReasonManuallyAdded: + return "manually added"; + default: + return "unknown"; + } + } +}; + +typedef std::map banmap_t; /** Information about a peer */ class CNode @@ -285,8 +343,9 @@ protected: // Denial-of-service detection/prevention // Key is IP address, value is banned-until-time - static std::map setBanned; + static banmap_t setBanned; static CCriticalSection cs_setBanned; + static bool setBannedIsDirty; // Whitelisted ranges. Any node connecting from these is automatically // whitelisted (as well as those connecting to whitelisted binds). @@ -608,11 +667,19 @@ public: static void ClearBanned(); // needed for unit testing static bool IsBanned(CNetAddr ip); static bool IsBanned(CSubNet subnet); - static void Ban(const CNetAddr &ip, int64_t bantimeoffset = 0, bool sinceUnixEpoch = false); - static void Ban(const CSubNet &subNet, int64_t bantimeoffset = 0, bool sinceUnixEpoch = false); + static void Ban(const CNetAddr &ip, const BanReason &banReason, int64_t bantimeoffset = 0, bool sinceUnixEpoch = false); + static void Ban(const CSubNet &subNet, const BanReason &banReason, int64_t bantimeoffset = 0, bool sinceUnixEpoch = false); static bool Unban(const CNetAddr &ip); static bool Unban(const CSubNet &ip); - static void GetBanned(std::map &banmap); + static void GetBanned(banmap_t &banmap); + static void SetBanned(const banmap_t &banmap); + + //!check is the banlist has unwritten changes + static bool BannedSetIsDirty(); + //!set the "dirty" flag for the banlist + static void SetBannedSetDirty(bool dirty=true); + //!clean unused entires (if bantime has expired) + static void SweepBanned(); void copyStats(CNodeStats &stats); @@ -644,4 +711,17 @@ public: bool Read(CAddrMan& addr); }; +/** Access to the banlist database (banlist.dat) */ +class CBanDB +{ +private: + boost::filesystem::path pathBanlist; +public: + CBanDB(); + bool Write(const banmap_t& banSet); + bool Read(banmap_t& banSet); +}; + +void DumpBanlist(); + #endif // BITCOIN_NET_H diff --git a/src/netbase.cpp b/src/netbase.cpp index adac5c2d0..c9fc7d67f 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -1291,6 +1291,13 @@ CSubNet::CSubNet(const std::string &strSubnet, bool fAllowLookup) network.ip[x] &= netmask[x]; } +CSubNet::CSubNet(const CNetAddr &addr): + valid(addr.IsValid()) +{ + memset(netmask, 255, sizeof(netmask)); + network = addr; +} + bool CSubNet::Match(const CNetAddr &addr) const { if (!valid || !addr.IsValid()) diff --git a/src/netbase.h b/src/netbase.h index 27f0eac2a..6f8882b85 100644 --- a/src/netbase.h +++ b/src/netbase.h @@ -118,6 +118,9 @@ class CSubNet CSubNet(); explicit CSubNet(const std::string &strSubnet, bool fAllowLookup = false); + //constructor for single ip subnet (/32 or /128) + explicit CSubNet(const CNetAddr &addr); + bool Match(const CNetAddr &addr) const; std::string ToString() const; @@ -126,6 +129,15 @@ class CSubNet friend bool operator==(const CSubNet& a, const CSubNet& b); friend bool operator!=(const CSubNet& a, const CSubNet& b); friend bool operator<(const CSubNet& a, const CSubNet& b); + + ADD_SERIALIZE_METHODS; + + template + inline void SerializationOp(Stream& s, Operation ser_action, int nType, int nVersion) { + READWRITE(network); + READWRITE(FLATDATA(netmask)); + READWRITE(FLATDATA(valid)); + } }; /** A combination of a network address (CNetAddr) and a (TCP) port */ diff --git a/src/rpcnet.cpp b/src/rpcnet.cpp index 1572b1668..dd631905f 100644 --- a/src/rpcnet.cpp +++ b/src/rpcnet.cpp @@ -515,7 +515,7 @@ UniValue setban(const UniValue& params, bool fHelp) if (params.size() == 4 && params[3].isTrue()) absolute = true; - isSubnet ? CNode::Ban(subNet, banTime, absolute) : CNode::Ban(netAddr, banTime, absolute); + isSubnet ? CNode::Ban(subNet, BanReasonManuallyAdded, banTime, absolute) : CNode::Ban(netAddr, BanReasonManuallyAdded, banTime, absolute); //disconnect possible nodes while(CNode *bannedNode = (isSubnet ? FindNode(subNet) : FindNode(netAddr))) @@ -527,6 +527,7 @@ UniValue setban(const UniValue& params, bool fHelp) throw JSONRPCError(RPC_MISC_ERROR, "Error: Unban failed"); } + DumpBanlist(); //store banlist to disk return NullUniValue; } @@ -541,15 +542,19 @@ UniValue listbanned(const UniValue& params, bool fHelp) + HelpExampleRpc("listbanned", "") ); - std::map banMap; + banmap_t banMap; CNode::GetBanned(banMap); UniValue bannedAddresses(UniValue::VARR); - for (std::map::iterator it = banMap.begin(); it != banMap.end(); it++) + for (banmap_t::iterator it = banMap.begin(); it != banMap.end(); it++) { + CBanEntry banEntry = (*it).second; UniValue rec(UniValue::VOBJ); rec.push_back(Pair("address", (*it).first.ToString())); - rec.push_back(Pair("banned_untill", (*it).second)); + rec.push_back(Pair("banned_until", banEntry.nBanUntil)); + rec.push_back(Pair("ban_created", banEntry.nCreateTime)); + rec.push_back(Pair("ban_reason", banEntry.banReasonToString())); + bannedAddresses.push_back(rec); } @@ -568,6 +573,7 @@ UniValue clearbanned(const UniValue& params, bool fHelp) ); CNode::ClearBanned(); + DumpBanlist(); //store banlist to disk return NullUniValue; } diff --git a/src/test/netbase_tests.cpp b/src/test/netbase_tests.cpp index 0f5e1615c..7154476c7 100644 --- a/src/test/netbase_tests.cpp +++ b/src/test/netbase_tests.cpp @@ -143,6 +143,17 @@ BOOST_AUTO_TEST_CASE(subnet_test) BOOST_CHECK(CSubNet("1:2:3:4:5:6:7:8/128").IsValid()); BOOST_CHECK(!CSubNet("1:2:3:4:5:6:7:8/129").IsValid()); BOOST_CHECK(!CSubNet("fuzzy").IsValid()); + + //CNetAddr constructor test + BOOST_CHECK(CSubNet(CNetAddr("127.0.0.1")).IsValid()); + BOOST_CHECK(CSubNet(CNetAddr("127.0.0.1")).Match(CNetAddr("127.0.0.1"))); + BOOST_CHECK(!CSubNet(CNetAddr("127.0.0.1")).Match(CNetAddr("127.0.0.2"))); + BOOST_CHECK(CSubNet(CNetAddr("127.0.0.1")).ToString() == "127.0.0.1/255.255.255.255"); + + BOOST_CHECK(CSubNet(CNetAddr("1:2:3:4:5:6:7:8")).IsValid()); + BOOST_CHECK(CSubNet(CNetAddr("1:2:3:4:5:6:7:8")).Match(CNetAddr("1:2:3:4:5:6:7:8"))); + BOOST_CHECK(!CSubNet(CNetAddr("1:2:3:4:5:6:7:8")).Match(CNetAddr("1:2:3:4:5:6:7:9"))); + BOOST_CHECK(CSubNet(CNetAddr("1:2:3:4:5:6:7:8")).ToString() == "1:2:3:4:5:6:7:8/ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"); } BOOST_AUTO_TEST_SUITE_END() diff --git a/src/test/rpc_tests.cpp b/src/test/rpc_tests.cpp index c38df0ecf..9e99ff628 100644 --- a/src/test/rpc_tests.cpp +++ b/src/test/rpc_tests.cpp @@ -199,7 +199,7 @@ BOOST_AUTO_TEST_CASE(rpc_ban) ar = r.get_array(); o1 = ar[0].get_obj(); adr = find_value(o1, "address"); - UniValue banned_until = find_value(o1, "banned_untill"); + UniValue banned_until = find_value(o1, "banned_until"); BOOST_CHECK_EQUAL(adr.get_str(), "127.0.0.0/255.255.255.0"); BOOST_CHECK_EQUAL(banned_until.get_int64(), 1607731200); // absolute time check @@ -210,7 +210,7 @@ BOOST_AUTO_TEST_CASE(rpc_ban) ar = r.get_array(); o1 = ar[0].get_obj(); adr = find_value(o1, "address"); - banned_until = find_value(o1, "banned_untill"); + banned_until = find_value(o1, "banned_until"); BOOST_CHECK_EQUAL(adr.get_str(), "127.0.0.0/255.255.255.0"); int64_t now = GetTime(); BOOST_CHECK(banned_until.get_int64() > now);