Refactoring to split the weighted tx tree out of mempool_limit.{cpp,h}

and make it more reusable.

Signed-off-by: Daira Hopwood <daira@jacaranda.org>
This commit is contained in:
Daira Hopwood 2023-02-03 10:38:25 +00:00 committed by Daira Emma Hopwood
parent fcdfb5e780
commit 16099d66b6
11 changed files with 444 additions and 211 deletions

View File

@ -338,6 +338,7 @@ BITCOIN_CORE_H = \
wallet/walletdb.h \
wallet/wallet_tx_builder.h \
warnings.h \
weighted_map.h \
zmq/zmqabstractnotifier.h \
zmq/zmqconfig.h\
zmq/zmqnotificationinterface.h \

View File

@ -51,6 +51,7 @@ zcash_gtest_SOURCES = \
gtest/test_txid.cpp \
gtest/test_upgrades.cpp \
gtest/test_validation.cpp \
gtest/test_weightedmap.cpp \
gtest/test_zip32.cpp \
gtest/test_coins.cpp
if ENABLE_WALLET

View File

@ -80,36 +80,31 @@ TEST(MempoolLimitTests, RecentlyEvictedDropOneAtATime)
EXPECT_FALSE(recentlyEvicted.contains(TX_ID3));
}
TEST(MempoolLimitTests, WeightedTxTreeCheckSizeAfterDropping)
TEST(MempoolLimitTests, MempoolLimitTxSetCheckSizeAfterDropping)
{
std::set<uint256> testedDropping;
// Run the test until we have tested dropping each of the elements
int trialNum = 0;
while (testedDropping.size() < 3) {
WeightedTxTree tree(MIN_TX_COST * 2);
EXPECT_EQ(0, tree.getTotalWeight().cost);
EXPECT_EQ(0, tree.getTotalWeight().evictionWeight);
tree.add(WeightedTxInfo(TX_ID1, TxWeight(MIN_TX_COST, MIN_TX_COST)));
EXPECT_EQ(4000, tree.getTotalWeight().cost);
EXPECT_EQ(4000, tree.getTotalWeight().evictionWeight);
tree.add(WeightedTxInfo(TX_ID2, TxWeight(MIN_TX_COST, MIN_TX_COST)));
EXPECT_EQ(8000, tree.getTotalWeight().cost);
EXPECT_EQ(8000, tree.getTotalWeight().evictionWeight);
EXPECT_FALSE(tree.maybeDropRandom().has_value());
tree.add(WeightedTxInfo(TX_ID3, TxWeight(MIN_TX_COST, MIN_TX_COST + LOW_FEE_PENALTY)));
EXPECT_EQ(12000, tree.getTotalWeight().cost);
EXPECT_EQ(12000 + LOW_FEE_PENALTY, tree.getTotalWeight().evictionWeight);
std::optional<uint256> drop = tree.maybeDropRandom();
MempoolLimitTxSet limitSet(MIN_TX_COST * 2);
EXPECT_EQ(0, limitSet.getTotalWeight());
limitSet.add(TX_ID1, MIN_TX_COST, MIN_TX_COST);
EXPECT_EQ(4000, limitSet.getTotalWeight());
limitSet.add(TX_ID2, MIN_TX_COST, MIN_TX_COST);
EXPECT_EQ(8000, limitSet.getTotalWeight());
EXPECT_FALSE(limitSet.maybeDropRandom().has_value());
limitSet.add(TX_ID3, MIN_TX_COST, MIN_TX_COST + LOW_FEE_PENALTY);
EXPECT_EQ(12000 + LOW_FEE_PENALTY, limitSet.getTotalWeight());
std::optional<uint256> drop = limitSet.maybeDropRandom();
ASSERT_TRUE(drop.has_value());
uint256 txid = drop.value();
testedDropping.insert(txid);
// Do not continue to test if a particular trial fails
ASSERT_EQ(8000, tree.getTotalWeight().cost);
ASSERT_EQ(txid == TX_ID3 ? 8000 : 8000 + LOW_FEE_PENALTY, tree.getTotalWeight().evictionWeight);
ASSERT_EQ(txid == TX_ID3 ? 8000 : 8000 + LOW_FEE_PENALTY, limitSet.getTotalWeight());
}
}
TEST(MempoolLimitTests, WeightedTxInfoFromTx)
TEST(MempoolLimitTests, MempoolCostAndEvictionWeight)
{
LoadProofParameters();
@ -126,9 +121,9 @@ TEST(MempoolLimitTests, WeightedTxInfoFromTx)
builder.AddSaplingSpend(sk.expanded_spending_key(), testNote.note, testNote.tree.root(), testNote.tree.witness());
builder.AddSaplingOutput(sk.full_viewing_key().ovk, sk.default_address(), 25000, {});
WeightedTxInfo info = WeightedTxInfo::from(builder.Build().GetTxOrThrow(), DEFAULT_FEE);
EXPECT_EQ(MIN_TX_COST, info.txWeight.cost);
EXPECT_EQ(MIN_TX_COST, info.txWeight.evictionWeight);
auto [cost, evictionWeight] = MempoolCostAndEvictionWeight(builder.Build().GetTxOrThrow(), DEFAULT_FEE);
EXPECT_EQ(MIN_TX_COST, cost);
EXPECT_EQ(MIN_TX_COST, evictionWeight);
}
// Lower than standard fee
@ -139,9 +134,9 @@ TEST(MempoolLimitTests, WeightedTxInfoFromTx)
static_assert(DEFAULT_FEE == 1000);
builder.SetFee(DEFAULT_FEE-1);
WeightedTxInfo info = WeightedTxInfo::from(builder.Build().GetTxOrThrow(), DEFAULT_FEE-1);
EXPECT_EQ(MIN_TX_COST, info.txWeight.cost);
EXPECT_EQ(MIN_TX_COST + LOW_FEE_PENALTY, info.txWeight.evictionWeight);
auto [cost, evictionWeight] = MempoolCostAndEvictionWeight(builder.Build().GetTxOrThrow(), DEFAULT_FEE-1);
EXPECT_EQ(MIN_TX_COST, cost);
EXPECT_EQ(MIN_TX_COST + LOW_FEE_PENALTY, evictionWeight);
}
// Larger Tx
@ -157,9 +152,9 @@ TEST(MempoolLimitTests, WeightedTxInfoFromTx)
if (result.IsError()) {
std::cerr << result.GetError() << std::endl;
}
WeightedTxInfo info = WeightedTxInfo::from(result.GetTxOrThrow(), DEFAULT_FEE);
EXPECT_EQ(5168, info.txWeight.cost);
EXPECT_EQ(5168, info.txWeight.evictionWeight);
auto [cost, evictionWeight] = MempoolCostAndEvictionWeight(result.GetTxOrThrow(), DEFAULT_FEE);
EXPECT_EQ(5168, cost);
EXPECT_EQ(5168, evictionWeight);
}
RegtestDeactivateSapling();

View File

@ -0,0 +1,160 @@
// Copyright (c) 2019-2023 The Zcash developers
// Distributed under the MIT software license, see the accompanying
// file COPYING or https://www.opensource.org/licenses/mit-license.php .
#include <gtest/gtest.h>
#include "weighted_map.h"
#include "gtest/utils.h"
#include "util/test.h"
TEST(WeightedMapTests, WeightedMap)
{
    WeightedMap<int, int, int, GetRandInt> m;

    // Helper: check the map's size, emptiness, and total weight, then
    // verify its internal invariants.
    auto expectState = [&](size_t expectedSize, int expectedWeight) {
        EXPECT_EQ(expectedSize, m.size());
        EXPECT_EQ(expectedSize == 0, m.empty());
        EXPECT_EQ(expectedWeight, m.getTotalWeight());
        m.checkInvariants();
    };

    expectState(0, 0);

    EXPECT_TRUE(m.add(3, 30, 3));
    expectState(1, 3);

    EXPECT_TRUE(m.add(1, 10, 2));
    expectState(2, 5);

    // adding a duplicate element should be ignored
    EXPECT_FALSE(m.add(1, 15, 64));
    expectState(2, 5);

    EXPECT_TRUE(m.add(2, 20, 1));
    expectState(3, 6);

    // regression test: adding three elements and deleting the first caused an invariant violation (not in committed code)
    EXPECT_EQ(30, m.remove(3).value());
    expectState(2, 3);

    // try to remove a non-existent element
    EXPECT_FALSE(m.remove(42).has_value());
    expectState(2, 3);

    EXPECT_EQ(20, m.remove(2).value());
    expectState(1, 2);

    EXPECT_TRUE(m.add(2, 20, 1));
    expectState(2, 3);

    // at this point the map should contain 1->10 (weight 2) and 2->20 (weight 1)
    auto [e1, c1, w1] = m.takeRandom().value();
    EXPECT_TRUE(e1 == 1 || e1 == 2);
    EXPECT_EQ(c1, e1*10);
    EXPECT_EQ(w1, 3-e1);
    expectState(1, 3-w1);

    // the remaining entry must be the one takeRandom() did not pick
    auto [e2, c2, w2] = m.takeRandom().value();
    EXPECT_EQ(3, e1 + e2);
    EXPECT_EQ(c2, e2*10);
    EXPECT_EQ(w2, 3-e2);
    expectState(0, 0);

    // takeRandom() on an empty map yields nothing and changes nothing
    EXPECT_FALSE(m.takeRandom().has_value());
    expectState(0, 0);
}
TEST(WeightedMapTests, WeightedMapRandomOps)
{
    const int kIterations = 1000;
    const int kElementRange = 20;
    const int kMaxWeight = 10;

    WeightedMap<int, int, int, GetRandInt> wmap;
    // Reference model kept in sync with wmap: element -> weight.
    std::map<int, int> model;
    int modelWeight = 0;
    // ensure modelWeight cannot overflow
    static_assert(kIterations <= std::numeric_limits<decltype(modelWeight)>::max() / kMaxWeight);

    // Compare the map under test against the reference model and check
    // the map's internal invariants.
    auto checkAgainstModel = [&]() {
        EXPECT_EQ(model.size(), wmap.size());
        EXPECT_EQ(model.empty(), wmap.empty());
        EXPECT_EQ(modelWeight, wmap.getTotalWeight());
        wmap.checkInvariants();
    };

    checkAgainstModel();

    for (int i = 0; i < kIterations; i++) {
        int op = GetRandInt(4);
        if (op <= 1) {
            // probability of add should be balanced with (remove or takeRandom)
            int e = GetRandInt(kElementRange);
            int w = GetRandInt(kMaxWeight) + 1;
            bool added = wmap.add(e, e*10, w);
            EXPECT_EQ(added, model.count(e) == 0);
            if (added) {
                modelWeight += w;
                model[e] = w;
            }
        } else if (op == 2) {
            int e = GetRandInt(kElementRange);
            auto c = wmap.remove(e);
            if (model.count(e) == 0) {
                EXPECT_FALSE(c.has_value());
            } else {
                ASSERT_TRUE(c.has_value());
                EXPECT_EQ(c.value(), e*10);
                modelWeight -= model[e];
                model.erase(e);
            }
        } else {
            auto r = wmap.takeRandom();
            if (model.empty()) {
                EXPECT_FALSE(r.has_value());
            } else {
                ASSERT_TRUE(r.has_value());
                auto [e, c, w] = r.value();
                EXPECT_EQ(1, model.count(e));
                EXPECT_EQ(c, e*10);
                EXPECT_EQ(w, model[e]);
                modelWeight -= model[e];
                model.erase(e);
            }
        }
        checkAgainstModel();
    }
}

View File

@ -5,15 +5,11 @@
#include "mempool_limit.h"
#include "core_memusage.h"
#include "logging.h"
#include "random.h"
#include "serialize.h"
#include "timedata.h"
#include "util/time.h"
#include "version.h"
const TxWeight ZERO_WEIGHT = TxWeight(0, 0);
void RecentlyEvictedList::pruneList()
{
if (txIdSet.empty()) {
@ -43,111 +39,7 @@ bool RecentlyEvictedList::contains(const uint256& txId)
return txIdSet.count(txId) > 0;
}
TxWeight WeightedTxTree::getWeightAt(size_t index) const
{
return index < size ? txIdAndWeights[index].txWeight.add(childWeights[index]) : ZERO_WEIGHT;
}
void WeightedTxTree::backPropagate(size_t fromIndex, const TxWeight& weightDelta)
{
while (fromIndex > 0) {
fromIndex = (fromIndex - 1) / 2;
childWeights[fromIndex] = childWeights[fromIndex].add(weightDelta);
}
}
size_t WeightedTxTree::findByEvictionWeight(size_t fromIndex, int64_t weightToFind) const
{
int leftWeight = getWeightAt(fromIndex * 2 + 1).evictionWeight;
int rightWeight = getWeightAt(fromIndex).evictionWeight - getWeightAt(fromIndex * 2 + 2).evictionWeight;
// On Left
if (weightToFind < leftWeight) {
return findByEvictionWeight(fromIndex * 2 + 1, weightToFind);
}
// Found
if (weightToFind < rightWeight) {
return fromIndex;
}
// On Right
return findByEvictionWeight(fromIndex * 2 + 2, weightToFind - rightWeight);
}
TxWeight WeightedTxTree::getTotalWeight() const
{
return getWeightAt(0);
}
void WeightedTxTree::add(const WeightedTxInfo& weightedTxInfo)
{
if (txIdToIndexMap.count(weightedTxInfo.txId) > 0) {
// This should not happen, but should be prevented nonetheless
return;
}
txIdAndWeights.push_back(weightedTxInfo);
childWeights.push_back(ZERO_WEIGHT);
txIdToIndexMap[weightedTxInfo.txId] = size;
backPropagate(size, weightedTxInfo.txWeight);
size += 1;
}
void WeightedTxTree::remove(const uint256& txId)
{
if (txIdToIndexMap.find(txId) == txIdToIndexMap.end()) {
// Remove may be called multiple times for a given tx, so this is necessary
return;
}
size_t removeIndex = txIdToIndexMap[txId];
// We reduce the size at the start of this method to avoid saying size - 1
// when referring to the last element of the array below
size -= 1;
TxWeight lastChildWeight = txIdAndWeights[size].txWeight;
backPropagate(size, lastChildWeight.negate());
if (removeIndex < size) {
TxWeight weightDelta = lastChildWeight.add(txIdAndWeights[removeIndex].txWeight.negate());
txIdAndWeights[removeIndex] = txIdAndWeights[size];
txIdToIndexMap[txIdAndWeights[removeIndex].txId] = removeIndex;
backPropagate(removeIndex, weightDelta);
}
txIdToIndexMap.erase(txId);
txIdAndWeights.pop_back();
childWeights.pop_back();
}
std::optional<uint256> WeightedTxTree::maybeDropRandom()
{
TxWeight totalTxWeight = getTotalWeight();
if (totalTxWeight.cost <= capacity) {
return std::nullopt;
}
LogPrint("mempool", "Mempool cost limit exceeded (cost=%d, limit=%d)\n", totalTxWeight.cost, capacity);
int randomWeight = GetRand(totalTxWeight.evictionWeight);
WeightedTxInfo drop = txIdAndWeights[findByEvictionWeight(0, randomWeight)];
LogPrint("mempool", "Evicting transaction (txid=%s, cost=%d, evictionWeight=%d)\n",
drop.txId.ToString(), drop.txWeight.cost, drop.txWeight.evictionWeight);
remove(drop.txId);
return drop.txId;
}
TxWeight TxWeight::add(const TxWeight& other) const
{
return TxWeight(cost + other.cost, evictionWeight + other.evictionWeight);
}
TxWeight TxWeight::negate() const
{
return TxWeight(-cost, -evictionWeight);
}
WeightedTxInfo WeightedTxInfo::from(const CTransaction& tx, const CAmount& fee)
std::pair<int64_t, int64_t> MempoolCostAndEvictionWeight(const CTransaction& tx, const CAmount& fee)
{
size_t memUsage = RecursiveDynamicUsage(tx);
int64_t cost = std::max((int64_t) memUsage, (int64_t) MIN_TX_COST);
@ -155,5 +47,5 @@ WeightedTxInfo WeightedTxInfo::from(const CTransaction& tx, const CAmount& fee)
if (fee < DEFAULT_FEE) {
evictionWeight += LOW_FEE_PENALTY;
}
return WeightedTxInfo(tx.GetHash(), TxWeight(cost, evictionWeight));
return std::make_pair(cost, evictionWeight);
}

View File

@ -9,19 +9,21 @@
#include <map>
#include <optional>
#include <set>
#include <vector>
#include "logging.h"
#include "random.h"
#include "primitives/transaction.h"
#include "policy/fees.h"
#include "uint256.h"
#include "util/time.h"
#include "weighted_map.h"
const size_t DEFAULT_MEMPOOL_TOTAL_COST_LIMIT = 80000000;
const int64_t DEFAULT_MEMPOOL_EVICTION_MEMORY_MINUTES = 60;
const size_t EVICTION_MEMORY_ENTRIES = 40000;
const uint64_t MIN_TX_COST = 4000;
const uint64_t LOW_FEE_PENALTY = 16000;
const int64_t MIN_TX_COST = 4000;
const int64_t LOW_FEE_PENALTY = 16000;
// This class keeps track of transactions which have been recently evicted from the mempool
@ -57,80 +59,62 @@ public:
// The mempool of a node holds a set of transactions. Each transaction has a *cost*,
// which is an integer defined as:
// max(serialized transaction size in bytes, 4000)
// max(memory usage in bytes, 4000)
//
// Each transaction also has an *eviction weight*, which is *cost* + *fee_penalty*,
// where *fee_penalty* is 16000 if the transaction pays a fee less than 10000 zatoshi,
// otherwise 0.
struct TxWeight {
int64_t cost;
int64_t evictionWeight; // *cost* + *fee_penalty*
// where *fee_penalty* is 16000 if the transaction pays a fee less than the
// ZIP 317 conventional fee, otherwise 0.
TxWeight(int64_t cost_, int64_t evictionWeight_)
: cost(cost_), evictionWeight(evictionWeight_) {}
TxWeight add(const TxWeight& other) const;
TxWeight negate() const;
};
// Calculate cost and eviction weight based on the memory usage and fee.
std::pair<int64_t, int64_t> MempoolCostAndEvictionWeight(const CTransaction& tx, const CAmount& fee);
// This struct is a pair of txid, cost.
struct WeightedTxInfo {
uint256 txId;
TxWeight txWeight;
WeightedTxInfo(uint256 txId_, TxWeight txWeight_) : txId(txId_), txWeight(txWeight_) {}
// Factory method which calculates cost based on size in bytes and fee.
static WeightedTxInfo from(const CTransaction& tx, const CAmount& fee);
};
// The following class is a collection of transaction ids and their costs.
// In order to be able to remove transactions randomly weighted by their cost,
// we keep track of the total cost of all transactions in this collection.
// For performance reasons, the collection is represented as a complete binary
// tree where each node knows the sum of the weights of the children. This
// allows for addition, removal, and random selection/dropping in logarithmic time.
class WeightedTxTree
class MempoolLimitTxSet
{
const int64_t capacity;
size_t size = 0;
// The following two vectors are the tree representation of this collection.
// We keep track of 3 data points for each node: A transaction's txid, its cost,
// and the sum of the weights of all children and descendant of that node.
std::vector<WeightedTxInfo> txIdAndWeights;
std::vector<TxWeight> childWeights;
// The following map is to simplify removal. When removing a tx, we do so by txid.
// This map allows looking up the transaction's index in the tree.
std::map<uint256, size_t> txIdToIndexMap;
// Returns the sum of a node and all of its children's TxWeights for a given index.
TxWeight getWeightAt(size_t index) const;
// When adding and removing a node we need to update its parent and all of its
// ancestors to reflect its cost.
void backPropagate(size_t fromIndex, const TxWeight& weightDelta);
// For a given random cost + fee penalty, this method recursively finds the
// correct transaction. This is used by WeightedTxTree::maybeDropRandom().
size_t findByEvictionWeight(size_t fromIndex, int64_t weightToFind) const;
WeightedMap<uint256, int64_t, int64_t, GetRandInt64> txmap;
int64_t capacity;
int64_t cost;
public:
WeightedTxTree(int64_t capacity_) : capacity(capacity_) {
MempoolLimitTxSet(int64_t capacity_) : capacity(capacity_), cost(0) {
assert(capacity >= 0);
}
TxWeight getTotalWeight() const;
int64_t getTotalWeight() const
{
return txmap.getTotalWeight();
}
bool empty() const
{
return txmap.empty();
}
// Add a transaction with the given cost and eviction weight.
// If the txid is already present this is a no-op (WeightedMap::add
// returns false for a duplicate key), so `cost` stays consistent.
void add(const uint256& txId, int64_t txCost, int64_t txWeight)
{
    if (txmap.add(txId, txCost, txWeight)) {
        cost += txCost;
    }
}
// Remove a transaction by txid, subtracting its cost from the running total.
// Safe to call for a txid that is not present: WeightedMap::remove then
// returns std::nullopt and value_or(0) leaves `cost` unchanged.
void remove(const uint256& txId)
{
    cost -= txmap.remove(txId).value_or(0);
}
void add(const WeightedTxInfo& weightedTxInfo);
void remove(const uint256& txId);
// If the total cost limit is exceeded, pick a random number based on the total cost
// of the collection and remove the associated transaction.
std::optional<uint256> maybeDropRandom();
// If the total cost limit has not been exceeded, return std::nullopt. Otherwise,
// pick a transaction at random with probability proportional to its eviction weight;
// remove and return that transaction's txid.
std::optional<uint256> maybeDropRandom()
{
    if (cost <= capacity) {
        return std::nullopt;
    }
    LogPrint("mempool", "Mempool cost limit exceeded (cost=%d, limit=%d)\n", cost, capacity);
    // cost > capacity >= 0 (the constructor asserts capacity >= 0), and
    // cost is only nonzero while entries exist, so the map is non-empty.
    assert(!txmap.empty());
    auto [txId, txCost, txWeight] = txmap.takeRandom().value();
    cost -= txCost;
    LogPrint("mempool", "Evicting transaction (txid=%s, cost=%d, weight=%d)\n",
        txId.ToString(), txCost, txWeight);
    return txId;
}
};
#endif // ZCASH_MEMPOOL_LIMIT_H

View File

@ -55,6 +55,11 @@ uint64_t GetRand(uint64_t nMax)
return (nRand % nMax);
}
// Return a uniformly random int64_t in [0, nMax), by delegating to GetRand().
// NOTE(review): assumes nMax is non-negative — a negative nMax would wrap to
// a huge uint64_t bound in the conversion below; confirm callers guarantee this.
int64_t GetRandInt64(int64_t nMax)
{
    return static_cast<int64_t>(GetRand(static_cast<uint64_t>(nMax)));
}
int GetRandInt(int nMax)
{
return GetRand(nMax);

View File

@ -20,6 +20,7 @@
*/
void GetRandBytes(unsigned char* buf, size_t num);
uint64_t GetRand(uint64_t nMax);
int64_t GetRandInt64(int64_t nMax);
int GetRandInt(int nMax);
uint256 GetRandHash();

View File

@ -337,7 +337,7 @@ CTxMemPool::~CTxMemPool()
{
delete minerPolicyEstimator;
delete recentlyEvicted;
delete weightedTxTree;
delete limitSet;
}
void CTxMemPool::pruneSpent(const uint256 &hashTx, CCoins &coins)
@ -371,7 +371,8 @@ bool CTxMemPool::addUnchecked(const uint256& hash, const CTxMemPoolEntry &entry,
// Used by main.cpp AcceptToMemoryPool(), which DOES do
// all the appropriate checks.
LOCK(cs);
weightedTxTree->add(WeightedTxInfo::from(entry.GetTx(), entry.GetFee()));
auto [cost, evictionWeight] = MempoolCostAndEvictionWeight(entry.GetTx(), entry.GetFee());
limitSet->add(entry.GetTx().GetHash(), cost, evictionWeight);
indexed_transaction_set::iterator newit = mapTx.insert(entry).first;
mapLinks.insert(make_pair(newit, TxLinks()));
@ -637,7 +638,7 @@ void CTxMemPool::remove(const CTransaction &origTx, std::list<CTransaction>& rem
}
RemoveStaged(setAllRemoves);
for (CTransaction tx : removed) {
weightedTxTree->remove(tx.GetHash());
limitSet->remove(tx.GetHash());
}
}
}
@ -1265,7 +1266,7 @@ size_t CTxMemPool::DynamicMemoryUsage() const {
memusage::DynamicUsage(mapOrchardNullifiers);
// DoS mitigation
total += memusage::DynamicUsage(recentlyEvicted) + memusage::DynamicUsage(weightedTxTree);
total += memusage::DynamicUsage(recentlyEvicted) + memusage::DynamicUsage(limitSet);
// Insight-related structures
size_t insight = 0;
@ -1282,9 +1283,9 @@ void CTxMemPool::SetMempoolCostLimit(int64_t totalCostLimit, int64_t evictionMem
LOCK(cs);
LogPrint("mempool", "Setting mempool cost limit: (limit=%d, time=%d)\n", totalCostLimit, evictionMemorySeconds);
delete recentlyEvicted;
delete weightedTxTree;
delete limitSet;
recentlyEvicted = new RecentlyEvictedList(GetNodeClock(), evictionMemorySeconds);
weightedTxTree = new WeightedTxTree(totalCostLimit);
limitSet = new MempoolLimitTxSet(totalCostLimit);
}
bool CTxMemPool::IsRecentlyEvicted(const uint256& txId) {
@ -1295,7 +1296,7 @@ bool CTxMemPool::IsRecentlyEvicted(const uint256& txId) {
void CTxMemPool::EnsureSizeLimit() {
AssertLockHeld(cs);
std::optional<uint256> maybeDropTxId;
while ((maybeDropTxId = weightedTxTree->maybeDropRandom()).has_value()) {
while ((maybeDropTxId = limitSet->maybeDropRandom()).has_value()) {
uint256 txId = maybeDropTxId.value();
recentlyEvicted->add(txId);
std::list<CTransaction> removed;

View File

@ -366,7 +366,7 @@ private:
std::map<uint256, const CTransaction*> mapSaplingNullifiers;
std::map<uint256, const CTransaction*> mapOrchardNullifiers;
RecentlyEvictedList* recentlyEvicted = new RecentlyEvictedList(GetNodeClock(), DEFAULT_MEMPOOL_EVICTION_MEMORY_MINUTES * 60);
WeightedTxTree* weightedTxTree = new WeightedTxTree(DEFAULT_MEMPOOL_TOTAL_COST_LIMIT);
MempoolLimitTxSet* limitSet = new MempoolLimitTxSet(DEFAULT_MEMPOOL_TOTAL_COST_LIMIT);
void checkNullifiers(ShieldedType type) const;

193
src/weighted_map.h Normal file
View File

@ -0,0 +1,193 @@
// Copyright (c) 2019-2023 The Zcash developers
// Distributed under the MIT software license, see the accompanying
// file COPYING or https://www.opensource.org/licenses/mit-license.php .
#ifndef ZCASH_WEIGHTED_MAP_H
#define ZCASH_WEIGHTED_MAP_H
#include <iostream>
#include <map>
#include <optional>
#include <set>
#include <vector>
// A WeightedMap represents a map from keys (of type K) to values (of type V),
// each entry having a weight (of type W). Elements can be randomly selected and
// removed from the map with probability in proportion to their weight. This is
// used to implement mempool limiting specified in ZIP 401.
//
// In order to efficiently implement random selection by weight, we keep track
// of the total weight of all keys in the map. For performance reasons, the
// map is represented as a binary tree where each node knows the sum of the
// weights of the children. This allows for addition, removal, and random
// selection/dropping in logarithmic time.
//
// random(w) must be defined to return a uniform random value between zero
// inclusive and w exclusive. The type W must support addition, binary and
// unary -, and < comparisons, and W() must construct the zero value (these
// constraints are met for primitive signed integer types).
template <typename K, typename V, typename W, W random(W)>
class WeightedMap
{
    // One entry of the map. Nodes are stored in a vector-backed complete
    // binary tree (heap layout: children of node i are at 2i+1 and 2i+2).
    struct Node {
        K key;
        V value;
        W weight;
        // Sum of the weights of all strict descendants of this node
        // (not including the node's own weight); maintained by backPropagate().
        W sumOfDescendantWeights;
    };

    // The following vector is the tree representation of this collection.
    // For each node, we keep track of the key, its associated value,
    // its weight, and the sum of the weights of all its descendants.
    std::vector<Node> nodes;

    // The following map is to simplify removal: it maps each key to the
    // index of its node in `nodes`.
    std::map<K, size_t> indexMap;

    static inline size_t leftChild(size_t i) { return i*2 + 1; }
    static inline size_t rightChild(size_t i) { return i*2 + 2; }
    static inline size_t parent(size_t i) { return (i-1)/2; }

public:
    // Check internal invariants (for tests): the index map and node vector
    // agree, and every cached descendant-weight sum matches its subtrees.
    void checkInvariants() const
    {
        assert(indexMap.size() == nodes.size());
        for (size_t i = 0; i < nodes.size(); i++) {
            assert(indexMap.at(nodes.at(i).key) == i);
            assert(nodes.at(i).sumOfDescendantWeights == getWeightAt(leftChild(i)) + getWeightAt(rightChild(i)));
        }
    }

private:
    // Return the sum of weights of the node at a given index and all of its
    // descendants. An out-of-range index (absent child) contributes the
    // zero weight W().
    W getWeightAt(size_t index) const
    {
        if (index >= nodes.size()) {
            return W();
        }
        auto& node = nodes.at(index);
        return node.weight + node.sumOfDescendantWeights;
    }

    // When adding and removing a node we need to update its parent and all of its
    // ancestors to reflect its weight.
    void backPropagate(size_t fromIndex, W weightDelta)
    {
        while (fromIndex > 0) {
            fromIndex = parent(fromIndex);
            nodes[fromIndex].sumOfDescendantWeights += weightDelta;
        }
    }

    // For a given random weight, this method recursively finds the index of the
    // correct entry. This is used by WeightedMap::takeRandom().
    // Precondition: W() <= weightToFind < getWeightAt(fromIndex).
    size_t findByWeight(size_t fromIndex, W weightToFind) const
    {
        W leftWeight = getWeightAt(leftChild(fromIndex));
        // rightWeight is the left subtree's weight plus this node's own
        // weight; values in [leftWeight, rightWeight) select this node.
        W rightWeight = getWeightAt(fromIndex) - getWeightAt(rightChild(fromIndex));
        // On Left
        if (weightToFind < leftWeight) {
            return findByWeight(leftChild(fromIndex), weightToFind);
        }
        // Found
        if (weightToFind < rightWeight) {
            return fromIndex;
        }
        // On Right
        return findByWeight(rightChild(fromIndex), weightToFind - rightWeight);
    }

public:
    WeightedMap() {}

    // Return the total weight of all entries in the map.
    W getTotalWeight() const
    {
        return getWeightAt(0);
    }

    // Return true when the map has no entries.
    bool empty() const
    {
        return nodes.empty();
    }

    // Return the number of entries.
    size_t size() const
    {
        return nodes.size();
    }

    // Return false if the key already exists in the map.
    // Otherwise, add an entry mapping `key` to `value` with the given weight,
    // and return true. The weight must be positive.
    bool add(K key, V value, W weight)
    {
        assert(W() < weight);
        if (indexMap.count(key) > 0) {
            return false;
        }
        // Append as the last leaf of the complete tree, then add the new
        // weight to every ancestor's descendant sum.
        size_t index = nodes.size();
        nodes.push_back(Node {
            .key = key,
            .value = value,
            .weight = weight,
            .sumOfDescendantWeights = W(),
        });
        indexMap[key] = index;
        backPropagate(index, weight);
        return true;
    }

    // If the given key is not present in the map, return std::nullopt.
    // Otherwise, remove that key's entry and return its associated value.
    std::optional<V> remove(K key)
    {
        if (indexMap.count(key) == 0) {
            return std::nullopt;
        }
        size_t removeIndex = indexMap.at(key);
        V removeValue = nodes.at(removeIndex).value;
        // Removal keeps the tree complete by moving the last leaf into the
        // removed node's position.
        size_t lastIndex = nodes.size()-1;
        Node lastNode = nodes.at(lastIndex);
        W weightDelta = lastNode.weight - nodes.at(removeIndex).weight;
        // First subtract the last leaf's weight from its ancestors...
        backPropagate(lastIndex, -lastNode.weight);
        if (removeIndex < lastIndex) {
            // ...then overwrite the removed node with the last leaf's data and
            // adjust removeIndex's ancestors by the net change in weight.
            nodes[removeIndex].key = lastNode.key;
            nodes[removeIndex].value = lastNode.value;
            nodes[removeIndex].weight = lastNode.weight;
            // nodes[removeIndex].sumOfDescendantWeights should not change here.
            indexMap[lastNode.key] = removeIndex;
            backPropagate(removeIndex, weightDelta);
        }
        indexMap.erase(key);
        nodes.pop_back();
        return removeValue;
    }

    // If the map is empty, return std::nullopt. Otherwise, pick a random entry
    // with probability proportional to its weight; remove it and return a tuple of
    // the key, its associated value, and its weight.
    std::optional<std::tuple<K, V, W>> takeRandom()
    {
        if (empty()) {
            return std::nullopt;
        }
        W totalWeight = getTotalWeight();
        W randomWeight = random(totalWeight);
        assert(W() <= randomWeight && randomWeight < totalWeight);
        size_t index = findByWeight(0, randomWeight);
        assert(index < nodes.size());
        const Node& drop = nodes.at(index);
        auto res = std::make_tuple(drop.key, drop.value, drop.weight); // copy values
        // remove() takes the key by value, so the key survives the mutation
        // of `nodes` that invalidates `drop`.
        remove(drop.key);
        return res;
    }
};
#endif // ZCASH_WEIGHTED_MAP_H