diff --git a/src/Makefile.am b/src/Makefile.am
index 0e21d2d44..b71639c7e 100644
--- a/src/Makefile.am
+++ b/src/Makefile.am
@@ -338,6 +338,7 @@ BITCOIN_CORE_H = \
   wallet/walletdb.h \
   wallet/wallet_tx_builder.h \
   warnings.h \
+  weighted_map.h \
   zmq/zmqabstractnotifier.h \
   zmq/zmqconfig.h\
   zmq/zmqnotificationinterface.h \
diff --git a/src/Makefile.gtest.include b/src/Makefile.gtest.include
index ec80559d7..c51a3ff04 100644
--- a/src/Makefile.gtest.include
+++ b/src/Makefile.gtest.include
@@ -51,6 +51,7 @@ zcash_gtest_SOURCES = \
   gtest/test_txid.cpp \
   gtest/test_upgrades.cpp \
   gtest/test_validation.cpp \
+  gtest/test_weightedmap.cpp \
   gtest/test_zip32.cpp \
   gtest/test_coins.cpp
 if ENABLE_WALLET
diff --git a/src/gtest/test_mempoollimit.cpp b/src/gtest/test_mempoollimit.cpp
index f13047759..3b47497e0 100644
--- a/src/gtest/test_mempoollimit.cpp
+++ b/src/gtest/test_mempoollimit.cpp
@@ -80,36 +80,31 @@ TEST(MempoolLimitTests, RecentlyEvictedDropOneAtATime)
     EXPECT_FALSE(recentlyEvicted.contains(TX_ID3));
 }
 
-TEST(MempoolLimitTests, WeightedTxTreeCheckSizeAfterDropping)
+TEST(MempoolLimitTests, MempoolLimitTxSetCheckSizeAfterDropping)
 {
     std::set<uint256> testedDropping;
     // Run the test until we have tested dropping each of the elements
     int trialNum = 0;
     while (testedDropping.size() < 3) {
-        WeightedTxTree tree(MIN_TX_COST * 2);
-        EXPECT_EQ(0, tree.getTotalWeight().cost);
-        EXPECT_EQ(0, tree.getTotalWeight().evictionWeight);
-        tree.add(WeightedTxInfo(TX_ID1, TxWeight(MIN_TX_COST, MIN_TX_COST)));
-        EXPECT_EQ(4000, tree.getTotalWeight().cost);
-        EXPECT_EQ(4000, tree.getTotalWeight().evictionWeight);
-        tree.add(WeightedTxInfo(TX_ID2, TxWeight(MIN_TX_COST, MIN_TX_COST)));
-        EXPECT_EQ(8000, tree.getTotalWeight().cost);
-        EXPECT_EQ(8000, tree.getTotalWeight().evictionWeight);
-        EXPECT_FALSE(tree.maybeDropRandom().has_value());
-        tree.add(WeightedTxInfo(TX_ID3, TxWeight(MIN_TX_COST, MIN_TX_COST + LOW_FEE_PENALTY)));
-        EXPECT_EQ(12000, tree.getTotalWeight().cost);
-        EXPECT_EQ(12000 + LOW_FEE_PENALTY, tree.getTotalWeight().evictionWeight);
-        std::optional<uint256> drop = tree.maybeDropRandom();
+        MempoolLimitTxSet limitSet(MIN_TX_COST * 2);
+        EXPECT_EQ(0, limitSet.getTotalWeight());
+        limitSet.add(TX_ID1, MIN_TX_COST, MIN_TX_COST);
+        EXPECT_EQ(4000, limitSet.getTotalWeight());
+        limitSet.add(TX_ID2, MIN_TX_COST, MIN_TX_COST);
+        EXPECT_EQ(8000, limitSet.getTotalWeight());
+        EXPECT_FALSE(limitSet.maybeDropRandom().has_value());
+        limitSet.add(TX_ID3, MIN_TX_COST, MIN_TX_COST + LOW_FEE_PENALTY);
+        EXPECT_EQ(12000 + LOW_FEE_PENALTY, limitSet.getTotalWeight());
+        std::optional<uint256> drop = limitSet.maybeDropRandom();
         ASSERT_TRUE(drop.has_value());
         uint256 txid = drop.value();
         testedDropping.insert(txid);
         // Do not continue to test if a particular trial fails
-        ASSERT_EQ(8000, tree.getTotalWeight().cost);
-        ASSERT_EQ(txid == TX_ID3 ? 8000 : 8000 + LOW_FEE_PENALTY, tree.getTotalWeight().evictionWeight);
+        ASSERT_EQ(txid == TX_ID3 ? 8000 : 8000 + LOW_FEE_PENALTY, limitSet.getTotalWeight());
     }
 }
 
-TEST(MempoolLimitTests, WeightedTxInfoFromTx)
+TEST(MempoolLimitTests, MempoolCostAndEvictionWeight)
 {
     LoadProofParameters();
@@ -126,9 +121,9 @@ TEST(MempoolLimitTests, WeightedTxInfoFromTx)
         builder.AddSaplingSpend(sk.expanded_spending_key(), testNote.note, testNote.tree.root(), testNote.tree.witness());
         builder.AddSaplingOutput(sk.full_viewing_key().ovk, sk.default_address(), 25000, {});
 
-        WeightedTxInfo info = WeightedTxInfo::from(builder.Build().GetTxOrThrow(), DEFAULT_FEE);
-        EXPECT_EQ(MIN_TX_COST, info.txWeight.cost);
-        EXPECT_EQ(MIN_TX_COST, info.txWeight.evictionWeight);
+        auto [cost, evictionWeight] = MempoolCostAndEvictionWeight(builder.Build().GetTxOrThrow(), DEFAULT_FEE);
+        EXPECT_EQ(MIN_TX_COST, cost);
+        EXPECT_EQ(MIN_TX_COST, evictionWeight);
     }
 
     // Lower than standard fee
@@ -139,9 +134,9 @@
         static_assert(DEFAULT_FEE == 1000);
         builder.SetFee(DEFAULT_FEE-1);
 
-        WeightedTxInfo info = WeightedTxInfo::from(builder.Build().GetTxOrThrow(), DEFAULT_FEE-1);
-        EXPECT_EQ(MIN_TX_COST, info.txWeight.cost);
-        EXPECT_EQ(MIN_TX_COST + LOW_FEE_PENALTY, info.txWeight.evictionWeight);
+        auto [cost, evictionWeight] = MempoolCostAndEvictionWeight(builder.Build().GetTxOrThrow(), DEFAULT_FEE-1);
+        EXPECT_EQ(MIN_TX_COST, cost);
+        EXPECT_EQ(MIN_TX_COST + LOW_FEE_PENALTY, evictionWeight);
     }
 
     // Larger Tx
@@ -157,9 +152,9 @@
         if (result.IsError()) {
             std::cerr << result.GetError() << std::endl;
         }
-        WeightedTxInfo info = WeightedTxInfo::from(result.GetTxOrThrow(), DEFAULT_FEE);
-        EXPECT_EQ(5168, info.txWeight.cost);
-        EXPECT_EQ(5168, info.txWeight.evictionWeight);
+        auto [cost, evictionWeight] = MempoolCostAndEvictionWeight(result.GetTxOrThrow(), DEFAULT_FEE);
+        EXPECT_EQ(5168, cost);
+        EXPECT_EQ(5168, evictionWeight);
     }
 
     RegtestDeactivateSapling();
diff --git a/src/gtest/test_weightedmap.cpp b/src/gtest/test_weightedmap.cpp
new file mode 100644
index 000000000..0d484d08f
--- /dev/null
+++ b/src/gtest/test_weightedmap.cpp
@@ -0,0 +1,160 @@
+// Copyright (c) 2019-2023 The Zcash developers
+// Distributed under the MIT software license, see the accompanying
+// file COPYING or https://www.opensource.org/licenses/mit-license.php .
+
+#include <gtest/gtest.h>
+
+#include "weighted_map.h"
+#include "gtest/utils.h"
+#include "util/test.h"
+
+
+TEST(WeightedMapTests, WeightedMap)
+{
+    WeightedMap<int, int, int, GetRandInt> m;
+
+    EXPECT_EQ(0, m.size());
+    EXPECT_TRUE(m.empty());
+    EXPECT_EQ(0, m.getTotalWeight());
+    m.checkInvariants();
+
+    EXPECT_TRUE(m.add(3, 30, 3));
+    EXPECT_EQ(1, m.size());
+    EXPECT_FALSE(m.empty());
+    EXPECT_EQ(3, m.getTotalWeight());
+    m.checkInvariants();
+
+    EXPECT_TRUE(m.add(1, 10, 2));
+    EXPECT_EQ(2, m.size());
+    EXPECT_FALSE(m.empty());
+    EXPECT_EQ(5, m.getTotalWeight());
+    m.checkInvariants();
+
+    // adding a duplicate element should be ignored
+    EXPECT_FALSE(m.add(1, 15, 64));
+    EXPECT_EQ(2, m.size());
+    EXPECT_FALSE(m.empty());
+    EXPECT_EQ(5, m.getTotalWeight());
+    m.checkInvariants();
+
+    EXPECT_TRUE(m.add(2, 20, 1));
+    EXPECT_EQ(3, m.size());
+    EXPECT_FALSE(m.empty());
+    EXPECT_EQ(6, m.getTotalWeight());
+    m.checkInvariants();
+
+    // regression test: adding three elements and deleting the first caused an invariant violation (not in committed code)
+    EXPECT_EQ(30, m.remove(3).value());
+    EXPECT_EQ(2, m.size());
+    EXPECT_FALSE(m.empty());
+    EXPECT_EQ(3, m.getTotalWeight());
+    m.checkInvariants();
+
+    // try to remove a non-existent element
+    EXPECT_FALSE(m.remove(42).has_value());
+    EXPECT_EQ(2, m.size());
+    EXPECT_FALSE(m.empty());
+    EXPECT_EQ(3, m.getTotalWeight());
+    m.checkInvariants();
+
+    EXPECT_EQ(20, m.remove(2).value());
+    EXPECT_EQ(1, m.size());
+    EXPECT_FALSE(m.empty());
+    EXPECT_EQ(2, m.getTotalWeight());
+    m.checkInvariants();
+
+    EXPECT_TRUE(m.add(2, 20, 1));
+    EXPECT_EQ(2, m.size());
+    EXPECT_FALSE(m.empty());
+    EXPECT_EQ(3, m.getTotalWeight());
+    m.checkInvariants();
+
+    // at this point the map should contain 1->10 (weight 2) and 2->20 (weight 1)
+    auto [e1, c1, w1] = m.takeRandom().value();
+    EXPECT_TRUE(e1 == 1 || e1 == 2);
+    EXPECT_EQ(c1, e1*10);
+    EXPECT_EQ(w1, 3-e1);
+    EXPECT_EQ(1, m.size());
+    EXPECT_FALSE(m.empty());
+    EXPECT_EQ(3-w1, m.getTotalWeight());
+    m.checkInvariants();
+
+    auto [e2, c2, w2] = m.takeRandom().value();
+    EXPECT_EQ(3, e1 + e2);
+    EXPECT_EQ(c2, e2*10);
+    EXPECT_EQ(w2, 3-e2);
+    EXPECT_EQ(0, m.size());
+    EXPECT_TRUE(m.empty());
+    EXPECT_EQ(0, m.getTotalWeight());
+    m.checkInvariants();
+
+    EXPECT_FALSE(m.takeRandom().has_value());
+    EXPECT_EQ(0, m.size());
+    EXPECT_TRUE(m.empty());
+    EXPECT_EQ(0, m.getTotalWeight());
+    m.checkInvariants();
+}
+
+TEST(WeightedMapTests, WeightedMapRandomOps)
+{
+    WeightedMap<int, int, int, GetRandInt> m;
+    std::map<int, int> expected; // element -> weight
+    int total_weight = 0;
+    const int iterations = 1000;
+    const int element_range = 20;
+    const int max_weight = 10;
+    static_assert(iterations <= std::numeric_limits<int>::max() / max_weight); // ensure total_weight cannot overflow
+
+    EXPECT_EQ(0, m.size());
+    EXPECT_TRUE(m.empty());
+    EXPECT_EQ(0, m.getTotalWeight());
+    m.checkInvariants();
+    for (int i = 0; i < iterations; i++) {
+        switch (GetRandInt(4)) {
+            // probability of add should be balanced with (remove or takeRandom)
+            case 0: case 1: {
+                int e = GetRandInt(element_range);
+                int w = GetRandInt(max_weight) + 1;
+                bool added = m.add(e, e*10, w);
+                EXPECT_EQ(added, expected.count(e) == 0);
+                if (added) {
+                    total_weight += w;
+                    expected[e] = w;
+                }
+                break;
+            }
+            case 2: {
+                int e = GetRandInt(element_range);
+                auto c = m.remove(e);
+                if (expected.count(e) == 0) {
+                    EXPECT_FALSE(c.has_value());
+                } else {
+                    ASSERT_TRUE(c.has_value());
+                    EXPECT_EQ(c.value(), e*10);
+                    total_weight -= expected[e];
+                    expected.erase(e);
+                }
+                break;
+            }
+            case 3: {
+                auto r = m.takeRandom();
+                if (expected.empty()) {
+                    EXPECT_FALSE(r.has_value());
+                } else {
+                    ASSERT_TRUE(r.has_value());
+                    auto [e, c, w] = r.value();
+                    EXPECT_EQ(1, expected.count(e));
+                    EXPECT_EQ(c, e*10);
+                    EXPECT_EQ(w, expected[e]);
+                    total_weight -= expected[e];
+                    expected.erase(e);
+                }
+                break;
+            }
+        }
+        EXPECT_EQ(expected.size(), m.size());
+        EXPECT_EQ(expected.empty(), m.empty());
+        EXPECT_EQ(total_weight, m.getTotalWeight());
+        m.checkInvariants();
+    }
+}
diff --git a/src/mempool_limit.cpp b/src/mempool_limit.cpp
index 97c679956..743a10e7b 100644
--- a/src/mempool_limit.cpp
+++ b/src/mempool_limit.cpp
@@ -5,15 +5,11 @@
 #include "mempool_limit.h"
 
 #include "core_memusage.h"
-#include "logging.h"
-#include "random.h"
 #include "serialize.h"
 #include "timedata.h"
 #include "util/time.h"
 #include "version.h"
 
-const TxWeight ZERO_WEIGHT = TxWeight(0, 0);
-
 void RecentlyEvictedList::pruneList()
 {
     if (txIdSet.empty()) {
@@ -43,111 +39,7 @@ bool RecentlyEvictedList::contains(const uint256& txId)
     return txIdSet.count(txId) > 0;
 }
 
-
-TxWeight WeightedTxTree::getWeightAt(size_t index) const
-{
-    return index < size ? txIdAndWeights[index].txWeight.add(childWeights[index]) : ZERO_WEIGHT;
-}
-
-void WeightedTxTree::backPropagate(size_t fromIndex, const TxWeight& weightDelta)
-{
-    while (fromIndex > 0) {
-        fromIndex = (fromIndex - 1) / 2;
-        childWeights[fromIndex] = childWeights[fromIndex].add(weightDelta);
-    }
-}
-
-size_t WeightedTxTree::findByEvictionWeight(size_t fromIndex, int64_t weightToFind) const
-{
-    int leftWeight = getWeightAt(fromIndex * 2 + 1).evictionWeight;
-    int rightWeight = getWeightAt(fromIndex).evictionWeight - getWeightAt(fromIndex * 2 + 2).evictionWeight;
-    // On Left
-    if (weightToFind < leftWeight) {
-        return findByEvictionWeight(fromIndex * 2 + 1, weightToFind);
-    }
-    // Found
-    if (weightToFind < rightWeight) {
-        return fromIndex;
-    }
-    // On Right
-    return findByEvictionWeight(fromIndex * 2 + 2, weightToFind - rightWeight);
-}
-
-TxWeight WeightedTxTree::getTotalWeight() const
-{
-    return getWeightAt(0);
-}
-
-
-void WeightedTxTree::add(const WeightedTxInfo& weightedTxInfo)
-{
-    if (txIdToIndexMap.count(weightedTxInfo.txId) > 0) {
-        // This should not happen, but should be prevented nonetheless
-        return;
-    }
-    txIdAndWeights.push_back(weightedTxInfo);
-    childWeights.push_back(ZERO_WEIGHT);
-    txIdToIndexMap[weightedTxInfo.txId] = size;
-    backPropagate(size, weightedTxInfo.txWeight);
-    size += 1;
-}
-
-void WeightedTxTree::remove(const uint256& txId)
-{
-    if (txIdToIndexMap.find(txId) == txIdToIndexMap.end()) {
-        // Remove may be called multiple times for a given tx, so this is necessary
-        return;
-    }
-
-    size_t removeIndex = txIdToIndexMap[txId];
-
-    // We reduce the size at the start of this method to avoid saying size - 1
-    // when referring to the last element of the array below
-    size -= 1;
-
-    TxWeight lastChildWeight = txIdAndWeights[size].txWeight;
-    backPropagate(size, lastChildWeight.negate());
-
-    if (removeIndex < size) {
-        TxWeight weightDelta = lastChildWeight.add(txIdAndWeights[removeIndex].txWeight.negate());
-        txIdAndWeights[removeIndex] = txIdAndWeights[size];
-        txIdToIndexMap[txIdAndWeights[removeIndex].txId] = removeIndex;
-        backPropagate(removeIndex, weightDelta);
-    }
-
-    txIdToIndexMap.erase(txId);
-    txIdAndWeights.pop_back();
-    childWeights.pop_back();
-}
-
-std::optional<uint256> WeightedTxTree::maybeDropRandom()
-{
-    TxWeight totalTxWeight = getTotalWeight();
-    if (totalTxWeight.cost <= capacity) {
-        return std::nullopt;
-    }
-    LogPrint("mempool", "Mempool cost limit exceeded (cost=%d, limit=%d)\n", totalTxWeight.cost, capacity);
-    int randomWeight = GetRand(totalTxWeight.evictionWeight);
-    WeightedTxInfo drop = txIdAndWeights[findByEvictionWeight(0, randomWeight)];
-    LogPrint("mempool", "Evicting transaction (txid=%s, cost=%d, evictionWeight=%d)\n",
-        drop.txId.ToString(), drop.txWeight.cost, drop.txWeight.evictionWeight);
-    remove(drop.txId);
-    return drop.txId;
-}
-
-
-TxWeight TxWeight::add(const TxWeight& other) const
-{
-    return TxWeight(cost + other.cost, evictionWeight + other.evictionWeight);
-}
-
-TxWeight TxWeight::negate() const
-{
-    return TxWeight(-cost, -evictionWeight);
-}
-
-
-WeightedTxInfo WeightedTxInfo::from(const CTransaction& tx, const CAmount& fee)
+std::pair<int64_t, int64_t> MempoolCostAndEvictionWeight(const CTransaction& tx, const CAmount& fee)
 {
     size_t memUsage = RecursiveDynamicUsage(tx);
     int64_t cost = std::max((int64_t) memUsage, (int64_t) MIN_TX_COST);
@@ -155,5 +47,5 @@ WeightedTxInfo WeightedTxInfo::from(const CTransaction& tx, const CAmount& fee)
     if (fee < DEFAULT_FEE) {
         evictionWeight += LOW_FEE_PENALTY;
     }
-    return WeightedTxInfo(tx.GetHash(), TxWeight(cost, evictionWeight));
+    return std::make_pair(cost, evictionWeight);
 }
diff --git a/src/mempool_limit.h b/src/mempool_limit.h
index 0075ce803..e29ca5d24 100644
--- a/src/mempool_limit.h
+++ b/src/mempool_limit.h
@@ -9,19 +9,21 @@
 #include <map>
 #include <optional>
 #include <set>
-#include <vector>
+#include "logging.h"
+#include "random.h"
 #include "primitives/transaction.h"
 #include "policy/fees.h"
 #include "uint256.h"
 #include "util/time.h"
+#include "weighted_map.h"
 
 const size_t DEFAULT_MEMPOOL_TOTAL_COST_LIMIT = 80000000;
 const int64_t DEFAULT_MEMPOOL_EVICTION_MEMORY_MINUTES = 60;
 
 const size_t EVICTION_MEMORY_ENTRIES = 40000;
-const uint64_t MIN_TX_COST = 4000;
-const uint64_t LOW_FEE_PENALTY = 16000;
+const int64_t MIN_TX_COST = 4000;
+const int64_t LOW_FEE_PENALTY = 16000;
 
 
 // This class keeps track of transactions which have been recently evicted from the mempool
@@ -57,80 +59,62 @@ public:
 
 // The mempool of a node holds a set of transactions. Each transaction has a *cost*,
 // which is an integer defined as:
-// max(serialized transaction size in bytes, 4000)
+// max(memory usage in bytes, 4000)
+//
 // Each transaction also has an *eviction weight*, which is *cost* + *fee_penalty*,
-// where *fee_penalty* is 16000 if the transaction pays a fee less than 10000 zatoshi,
-// otherwise 0.
-struct TxWeight {
-    int64_t cost;
-    int64_t evictionWeight; // *cost* + *fee_penalty*
+// where *fee_penalty* is 16000 if the transaction pays a fee less than the
+// ZIP 317 conventional fee, otherwise 0.
 
-    TxWeight(int64_t cost_, int64_t evictionWeight_)
-        : cost(cost_), evictionWeight(evictionWeight_) {}
-
-    TxWeight add(const TxWeight& other) const;
-    TxWeight negate() const;
-};
+// Calculate cost and eviction weight based on the memory usage and fee.
+std::pair<int64_t, int64_t> MempoolCostAndEvictionWeight(const CTransaction& tx, const CAmount& fee);
 
-// This struct is a pair of txid, cost.
-struct WeightedTxInfo {
-    uint256 txId;
-    TxWeight txWeight;
-
-    WeightedTxInfo(uint256 txId_, TxWeight txWeight_) : txId(txId_), txWeight(txWeight_) {}
-
-    // Factory method which calculates cost based on size in bytes and fee.
-    static WeightedTxInfo from(const CTransaction& tx, const CAmount& fee);
-};
-
-
-// The following class is a collection of transaction ids and their costs.
-// In order to be able to remove transactions randomly weighted by their cost,
-// we keep track of the total cost of all transactions in this collection.
-// For performance reasons, the collection is represented as a complete binary
-// tree where each node knows the sum of the weights of the children. This
-// allows for addition, removal, and random selection/dropping in logarithmic time.
-class WeightedTxTree
+class MempoolLimitTxSet
 {
-    const int64_t capacity;
-    size_t size = 0;
-
-    // The following two vectors are the tree representation of this collection.
-    // We keep track of 3 data points for each node: A transaction's txid, its cost,
-    // and the sum of the weights of all children and descendant of that node.
-    std::vector<WeightedTxInfo> txIdAndWeights;
-    std::vector<TxWeight> childWeights;
-
-    // The following map is to simplify removal. When removing a tx, we do so by txid.
-    // This map allows looking up the transaction's index in the tree.
-    std::map<uint256, size_t> txIdToIndexMap;
-
-    // Returns the sum of a node and all of its children's TxWeights for a given index.
-    TxWeight getWeightAt(size_t index) const;
-
-    // When adding and removing a node we need to update its parent and all of its
-    // ancestors to reflect its cost.
-    void backPropagate(size_t fromIndex, const TxWeight& weightDelta);
-
-    // For a given random cost + fee penalty, this method recursively finds the
-    // correct transaction. This is used by WeightedTxTree::maybeDropRandom().
-    size_t findByEvictionWeight(size_t fromIndex, int64_t weightToFind) const;
+    WeightedMap<uint256, int64_t, int64_t, GetRandInt64> txmap;
+    int64_t capacity;
+    int64_t cost;
 
 public:
-    WeightedTxTree(int64_t capacity_) : capacity(capacity_) {
+    MempoolLimitTxSet(int64_t capacity_) : capacity(capacity_), cost(0) {
         assert(capacity >= 0);
     }
 
-    TxWeight getTotalWeight() const;
+    int64_t getTotalWeight() const
+    {
+        return txmap.getTotalWeight();
+    }
+    bool empty() const
+    {
+        return txmap.empty();
+    }
+    void add(const uint256& txId, int64_t txCost, int64_t txWeight)
+    {
+        if (txmap.add(txId, txCost, txWeight)) {
+            cost += txCost;
+        }
+    }
+    void remove(const uint256& txId)
+    {
+        cost -= txmap.remove(txId).value_or(0);
+    }
 
-    void add(const WeightedTxInfo& weightedTxInfo);
-    void remove(const uint256& txId);
-
-    // If the total cost limit is exceeded, pick a random number based on the total cost
-    // of the collection and remove the associated transaction.
-    std::optional<uint256> maybeDropRandom();
+    // If the total cost limit has not been exceeded, return std::nullopt. Otherwise,
+    // pick a transaction at random with probability proportional to its eviction weight;
+    // remove and return that transaction's txid.
+    std::optional<uint256> maybeDropRandom()
+    {
+        if (cost <= capacity) {
+            return std::nullopt;
+        }
+        LogPrint("mempool", "Mempool cost limit exceeded (cost=%d, limit=%d)\n", cost, capacity);
+        assert(!txmap.empty());
+        auto [txId, txCost, txWeight] = txmap.takeRandom().value();
+        cost -= txCost;
+        LogPrint("mempool", "Evicting transaction (txid=%s, cost=%d, weight=%d)\n",
+            txId.ToString(), txCost, txWeight);
+        return txId;
+    }
 };
 
-
 #endif // ZCASH_MEMPOOL_LIMIT_H
diff --git a/src/random.cpp b/src/random.cpp
index f92d39279..4d94b66db 100644
--- a/src/random.cpp
+++ b/src/random.cpp
@@ -55,6 +55,11 @@ uint64_t GetRand(uint64_t nMax)
     return (nRand % nMax);
 }
 
+int64_t GetRandInt64(int64_t nMax)
+{
+    return GetRand(nMax);
+}
+
 int GetRandInt(int nMax)
 {
     return GetRand(nMax);
diff --git a/src/random.h b/src/random.h
index 171f0be41..c99a2601f 100644
--- a/src/random.h
+++ b/src/random.h
@@ -20,6 +20,7 @@
  */
 void GetRandBytes(unsigned char* buf, size_t num);
 uint64_t GetRand(uint64_t nMax);
+int64_t GetRandInt64(int64_t nMax);
 int GetRandInt(int nMax);
 uint256 GetRandHash();
 
diff --git a/src/txmempool.cpp b/src/txmempool.cpp
index e68fc2cd0..6cb6d8ec9 100644
--- a/src/txmempool.cpp
+++ b/src/txmempool.cpp
@@ -337,7 +337,7 @@ CTxMemPool::~CTxMemPool()
 {
     delete minerPolicyEstimator;
     delete recentlyEvicted;
-    delete weightedTxTree;
+    delete limitSet;
 }
 
 void CTxMemPool::pruneSpent(const uint256 &hashTx, CCoins &coins)
@@ -371,7 +371,8 @@ bool CTxMemPool::addUnchecked(const uint256& hash, const CTxMemPoolEntry &entry,
     // Used by main.cpp AcceptToMemoryPool(), which DOES do
     // all the appropriate checks.
     LOCK(cs);
-    weightedTxTree->add(WeightedTxInfo::from(entry.GetTx(), entry.GetFee()));
+    auto [cost, evictionWeight] = MempoolCostAndEvictionWeight(entry.GetTx(), entry.GetFee());
+    limitSet->add(entry.GetTx().GetHash(), cost, evictionWeight);
     indexed_transaction_set::iterator newit = mapTx.insert(entry).first;
     mapLinks.insert(make_pair(newit, TxLinks()));
 
@@ -637,7 +638,7 @@ void CTxMemPool::remove(const CTransaction &origTx, std::list<CTransaction>& rem
         }
         RemoveStaged(setAllRemoves);
         for (CTransaction tx : removed) {
-            weightedTxTree->remove(tx.GetHash());
+            limitSet->remove(tx.GetHash());
         }
     }
 }
@@ -1265,7 +1266,7 @@ size_t CTxMemPool::DynamicMemoryUsage() const {
         memusage::DynamicUsage(mapOrchardNullifiers);
 
     // DoS mitigation
-    total += memusage::DynamicUsage(recentlyEvicted) + memusage::DynamicUsage(weightedTxTree);
+    total += memusage::DynamicUsage(recentlyEvicted) + memusage::DynamicUsage(limitSet);
 
     // Insight-related structures
     size_t insight = 0;
@@ -1282,9 +1283,9 @@ void CTxMemPool::SetMempoolCostLimit(int64_t totalCostLimit, int64_t evictionMem
     LOCK(cs);
     LogPrint("mempool", "Setting mempool cost limit: (limit=%d, time=%d)\n", totalCostLimit, evictionMemorySeconds);
     delete recentlyEvicted;
-    delete weightedTxTree;
+    delete limitSet;
     recentlyEvicted = new RecentlyEvictedList(GetNodeClock(), evictionMemorySeconds);
-    weightedTxTree = new WeightedTxTree(totalCostLimit);
+    limitSet = new MempoolLimitTxSet(totalCostLimit);
 }
 
 bool CTxMemPool::IsRecentlyEvicted(const uint256& txId) {
@@ -1295,7 +1296,7 @@ bool CTxMemPool::IsRecentlyEvicted(const uint256& txId) {
 void CTxMemPool::EnsureSizeLimit() {
     AssertLockHeld(cs);
     std::optional<uint256> maybeDropTxId;
-    while ((maybeDropTxId = weightedTxTree->maybeDropRandom()).has_value()) {
+    while ((maybeDropTxId = limitSet->maybeDropRandom()).has_value()) {
         uint256 txId = maybeDropTxId.value();
         recentlyEvicted->add(txId);
         std::list<CTransaction> removed;
diff --git a/src/txmempool.h b/src/txmempool.h
index 8cb085db5..ecbed2834 100644
--- a/src/txmempool.h
+++ b/src/txmempool.h
@@ -366,7 +366,7 @@ private:
     std::map<uint256, const CTransaction*> mapSaplingNullifiers;
     std::map<uint256, const CTransaction*> mapOrchardNullifiers;
     RecentlyEvictedList* recentlyEvicted = new RecentlyEvictedList(GetNodeClock(), DEFAULT_MEMPOOL_EVICTION_MEMORY_MINUTES * 60);
-    WeightedTxTree* weightedTxTree = new WeightedTxTree(DEFAULT_MEMPOOL_TOTAL_COST_LIMIT);
+    MempoolLimitTxSet* limitSet = new MempoolLimitTxSet(DEFAULT_MEMPOOL_TOTAL_COST_LIMIT);
 
     void checkNullifiers(ShieldedType type) const;
 
diff --git a/src/weighted_map.h b/src/weighted_map.h
new file mode 100644
index 000000000..4ab4cb9e4
--- /dev/null
+++ b/src/weighted_map.h
@@ -0,0 +1,193 @@
+// Copyright (c) 2019-2023 The Zcash developers
+// Distributed under the MIT software license, see the accompanying
+// file COPYING or https://www.opensource.org/licenses/mit-license.php .
+
+#ifndef ZCASH_WEIGHTED_MAP_H
+#define ZCASH_WEIGHTED_MAP_H
+
+#include <cassert>
+#include <map>
+#include <optional>
+#include <tuple>
+#include <vector>
+
+// A WeightedMap represents a map from keys (of type K) to values (of type V),
+// each entry having a weight (of type W). Elements can be randomly selected and
+// removed from the map with probability in proportion to their weight. This is
+// used to implement mempool limiting specified in ZIP 401.
+//
+// In order to efficiently implement random selection by weight, we keep track
+// of the total weight of all keys in the map. For performance reasons, the
+// map is represented as a binary tree where each node knows the sum of the
+// weights of the children. This allows for addition, removal, and random
+// selection/dropping in logarithmic time.
+//
+// random(w) must be defined to return a uniform random value between zero
+// inclusive and w exclusive. The type W must support addition, binary and
+// unary -, and < comparisons, and W() must construct the zero value (these
+// constraints are met for primitive signed integer types).
+template <typename K, typename V, typename W, W random(W)>
+class WeightedMap
+{
+    struct Node {
+        K key;
+        V value;
+        W weight;
+        W sumOfDescendantWeights;
+    };
+
+    // The following vector is the tree representation of this collection.
+    // For each node, we keep track of the key, its associated value,
+    // its weight, and the sum of the weights of all its descendants.
+    std::vector<Node> nodes;
+
+    // The following map is to simplify removal.
+    std::map<K, size_t> indexMap;
+
+    static inline size_t leftChild(size_t i) { return i*2 + 1; }
+    static inline size_t rightChild(size_t i) { return i*2 + 2; }
+    static inline size_t parent(size_t i) { return (i-1)/2; }
+
+public:
+    // Check internal invariants (for tests).
+    void checkInvariants() const
+    {
+        assert(indexMap.size() == nodes.size());
+        for (size_t i = 0; i < nodes.size(); i++) {
+            assert(indexMap.at(nodes.at(i).key) == i);
+            assert(nodes.at(i).sumOfDescendantWeights == getWeightAt(leftChild(i)) + getWeightAt(rightChild(i)));
+        }
+    }
+
+private:
+    // Return the sum of weights of the node at a given index and all of its descendants.
+    W getWeightAt(size_t index) const
+    {
+        if (index >= nodes.size()) {
+            return W();
+        }
+        auto& node = nodes.at(index);
+        return node.weight + node.sumOfDescendantWeights;
+    }
+
+    // When adding and removing a node we need to update its parent and all of its
+    // ancestors to reflect its weight.
+    void backPropagate(size_t fromIndex, W weightDelta)
+    {
+        while (fromIndex > 0) {
+            fromIndex = parent(fromIndex);
+            nodes[fromIndex].sumOfDescendantWeights += weightDelta;
+        }
+    }
+
+    // For a given random weight, this method recursively finds the index of the
+    // correct entry. This is used by WeightedMap::takeRandom().
+    size_t findByWeight(size_t fromIndex, W weightToFind) const
+    {
+        W leftWeight = getWeightAt(leftChild(fromIndex));
+        W rightWeight = getWeightAt(fromIndex) - getWeightAt(rightChild(fromIndex));
+        // On Left
+        if (weightToFind < leftWeight) {
+            return findByWeight(leftChild(fromIndex), weightToFind);
+        }
+        // Found
+        if (weightToFind < rightWeight) {
+            return fromIndex;
+        }
+        // On Right
+        return findByWeight(rightChild(fromIndex), weightToFind - rightWeight);
+    }
+
+public:
+    WeightedMap() {}
+
+    // Return the total weight of all entries in the map.
+    W getTotalWeight() const
+    {
+        return getWeightAt(0);
+    }
+
+    // Return true when the map has no entries.
+    bool empty() const
+    {
+        return nodes.empty();
+    }
+
+    // Return the number of entries.
+    size_t size() const
+    {
+        return nodes.size();
+    }
+
+    // Return false if the key already exists in the map.
+    // Otherwise, add an entry mapping `key` to `value` with the given weight,
+    // and return true. The weight must be positive.
+    bool add(K key, V value, W weight)
+    {
+        assert(W() < weight);
+        if (indexMap.count(key) > 0) {
+            return false;
+        }
+        size_t index = nodes.size();
+        nodes.push_back(Node {
+            .key = key,
+            .value = value,
+            .weight = weight,
+            .sumOfDescendantWeights = W(),
+        });
+        indexMap[key] = index;
+        backPropagate(index, weight);
+        return true;
+    }
+
+    // If the given key is not present in the map, return std::nullopt.
+    // Otherwise, remove that key's entry and return its associated value.
+    std::optional<V> remove(K key)
+    {
+        if (indexMap.count(key) == 0) {
+            return std::nullopt;
+        }
+
+        size_t removeIndex = indexMap.at(key);
+        V removeValue = nodes.at(removeIndex).value;
+
+        size_t lastIndex = nodes.size()-1;
+        Node lastNode = nodes.at(lastIndex);
+        W weightDelta = lastNode.weight - nodes.at(removeIndex).weight;
+        backPropagate(lastIndex, -lastNode.weight);
+
+        if (removeIndex < lastIndex) {
+            nodes[removeIndex].key = lastNode.key;
+            nodes[removeIndex].value = lastNode.value;
+            nodes[removeIndex].weight = lastNode.weight;
+            // nodes[removeIndex].sumOfDescendantWeights should not change here.
+            indexMap[lastNode.key] = removeIndex;
+            backPropagate(removeIndex, weightDelta);
+        }
+
+        indexMap.erase(key);
+        nodes.pop_back();
+        return removeValue;
+    }
+
+    // If the map is empty, return std::nullopt. Otherwise, pick a random entry
+    // with probability proportional to its weight; remove it and return a tuple of
+    // the key, its associated value, and its weight.
+    std::optional<std::tuple<K, V, W>> takeRandom()
+    {
+        if (empty()) {
+            return std::nullopt;
+        }
+        W totalWeight = getTotalWeight();
+        W randomWeight = random(totalWeight);
+        assert(W() <= randomWeight && randomWeight < totalWeight);
+        size_t index = findByWeight(0, randomWeight);
+        assert(index < nodes.size());
+        const Node& drop = nodes.at(index);
+        auto res = std::make_tuple(drop.key, drop.value, drop.weight); // copy values
+        remove(drop.key);
+        return res;
+    }
+};
+
+#endif // ZCASH_WEIGHTED_MAP_H