Merge pull request #6460 from daira/generalize-weighted-map
Refactoring to split the weighted tx tree out of mempool_limit.{cpp,h} and make it more reusable
This commit is contained in:
commit
659030aa46
|
@ -338,6 +338,7 @@ BITCOIN_CORE_H = \
|
|||
wallet/walletdb.h \
|
||||
wallet/wallet_tx_builder.h \
|
||||
warnings.h \
|
||||
weighted_map.h \
|
||||
zmq/zmqabstractnotifier.h \
|
||||
zmq/zmqconfig.h\
|
||||
zmq/zmqnotificationinterface.h \
|
||||
|
|
|
@ -51,6 +51,7 @@ zcash_gtest_SOURCES = \
|
|||
gtest/test_txid.cpp \
|
||||
gtest/test_upgrades.cpp \
|
||||
gtest/test_validation.cpp \
|
||||
gtest/test_weightedmap.cpp \
|
||||
gtest/test_zip32.cpp \
|
||||
gtest/test_coins.cpp
|
||||
if ENABLE_WALLET
|
||||
|
|
|
@ -80,36 +80,31 @@ TEST(MempoolLimitTests, RecentlyEvictedDropOneAtATime)
|
|||
EXPECT_FALSE(recentlyEvicted.contains(TX_ID3));
|
||||
}
|
||||
|
||||
TEST(MempoolLimitTests, WeightedTxTreeCheckSizeAfterDropping)
|
||||
TEST(MempoolLimitTests, MempoolLimitTxSetCheckSizeAfterDropping)
|
||||
{
|
||||
std::set<uint256> testedDropping;
|
||||
// Run the test until we have tested dropping each of the elements
|
||||
int trialNum = 0;
|
||||
while (testedDropping.size() < 3) {
|
||||
WeightedTxTree tree(MIN_TX_COST * 2);
|
||||
EXPECT_EQ(0, tree.getTotalWeight().cost);
|
||||
EXPECT_EQ(0, tree.getTotalWeight().evictionWeight);
|
||||
tree.add(WeightedTxInfo(TX_ID1, TxWeight(MIN_TX_COST, MIN_TX_COST)));
|
||||
EXPECT_EQ(4000, tree.getTotalWeight().cost);
|
||||
EXPECT_EQ(4000, tree.getTotalWeight().evictionWeight);
|
||||
tree.add(WeightedTxInfo(TX_ID2, TxWeight(MIN_TX_COST, MIN_TX_COST)));
|
||||
EXPECT_EQ(8000, tree.getTotalWeight().cost);
|
||||
EXPECT_EQ(8000, tree.getTotalWeight().evictionWeight);
|
||||
EXPECT_FALSE(tree.maybeDropRandom().has_value());
|
||||
tree.add(WeightedTxInfo(TX_ID3, TxWeight(MIN_TX_COST, MIN_TX_COST + LOW_FEE_PENALTY)));
|
||||
EXPECT_EQ(12000, tree.getTotalWeight().cost);
|
||||
EXPECT_EQ(12000 + LOW_FEE_PENALTY, tree.getTotalWeight().evictionWeight);
|
||||
std::optional<uint256> drop = tree.maybeDropRandom();
|
||||
MempoolLimitTxSet limitSet(MIN_TX_COST * 2);
|
||||
EXPECT_EQ(0, limitSet.getTotalWeight());
|
||||
limitSet.add(TX_ID1, MIN_TX_COST, MIN_TX_COST);
|
||||
EXPECT_EQ(4000, limitSet.getTotalWeight());
|
||||
limitSet.add(TX_ID2, MIN_TX_COST, MIN_TX_COST);
|
||||
EXPECT_EQ(8000, limitSet.getTotalWeight());
|
||||
EXPECT_FALSE(limitSet.maybeDropRandom().has_value());
|
||||
limitSet.add(TX_ID3, MIN_TX_COST, MIN_TX_COST + LOW_FEE_PENALTY);
|
||||
EXPECT_EQ(12000 + LOW_FEE_PENALTY, limitSet.getTotalWeight());
|
||||
std::optional<uint256> drop = limitSet.maybeDropRandom();
|
||||
ASSERT_TRUE(drop.has_value());
|
||||
uint256 txid = drop.value();
|
||||
testedDropping.insert(txid);
|
||||
// Do not continue to test if a particular trial fails
|
||||
ASSERT_EQ(8000, tree.getTotalWeight().cost);
|
||||
ASSERT_EQ(txid == TX_ID3 ? 8000 : 8000 + LOW_FEE_PENALTY, tree.getTotalWeight().evictionWeight);
|
||||
ASSERT_EQ(txid == TX_ID3 ? 8000 : 8000 + LOW_FEE_PENALTY, limitSet.getTotalWeight());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(MempoolLimitTests, WeightedTxInfoFromTx)
|
||||
TEST(MempoolLimitTests, MempoolCostAndEvictionWeight)
|
||||
{
|
||||
LoadProofParameters();
|
||||
|
||||
|
@ -126,9 +121,9 @@ TEST(MempoolLimitTests, WeightedTxInfoFromTx)
|
|||
builder.AddSaplingSpend(sk.expanded_spending_key(), testNote.note, testNote.tree.root(), testNote.tree.witness());
|
||||
builder.AddSaplingOutput(sk.full_viewing_key().ovk, sk.default_address(), 25000, {});
|
||||
|
||||
WeightedTxInfo info = WeightedTxInfo::from(builder.Build().GetTxOrThrow(), DEFAULT_FEE);
|
||||
EXPECT_EQ(MIN_TX_COST, info.txWeight.cost);
|
||||
EXPECT_EQ(MIN_TX_COST, info.txWeight.evictionWeight);
|
||||
auto [cost, evictionWeight] = MempoolCostAndEvictionWeight(builder.Build().GetTxOrThrow(), DEFAULT_FEE);
|
||||
EXPECT_EQ(MIN_TX_COST, cost);
|
||||
EXPECT_EQ(MIN_TX_COST, evictionWeight);
|
||||
}
|
||||
|
||||
// Lower than standard fee
|
||||
|
@ -139,9 +134,9 @@ TEST(MempoolLimitTests, WeightedTxInfoFromTx)
|
|||
static_assert(DEFAULT_FEE == 1000);
|
||||
builder.SetFee(DEFAULT_FEE-1);
|
||||
|
||||
WeightedTxInfo info = WeightedTxInfo::from(builder.Build().GetTxOrThrow(), DEFAULT_FEE-1);
|
||||
EXPECT_EQ(MIN_TX_COST, info.txWeight.cost);
|
||||
EXPECT_EQ(MIN_TX_COST + LOW_FEE_PENALTY, info.txWeight.evictionWeight);
|
||||
auto [cost, evictionWeight] = MempoolCostAndEvictionWeight(builder.Build().GetTxOrThrow(), DEFAULT_FEE-1);
|
||||
EXPECT_EQ(MIN_TX_COST, cost);
|
||||
EXPECT_EQ(MIN_TX_COST + LOW_FEE_PENALTY, evictionWeight);
|
||||
}
|
||||
|
||||
// Larger Tx
|
||||
|
@ -157,9 +152,9 @@ TEST(MempoolLimitTests, WeightedTxInfoFromTx)
|
|||
if (result.IsError()) {
|
||||
std::cerr << result.GetError() << std::endl;
|
||||
}
|
||||
WeightedTxInfo info = WeightedTxInfo::from(result.GetTxOrThrow(), DEFAULT_FEE);
|
||||
EXPECT_EQ(5168, info.txWeight.cost);
|
||||
EXPECT_EQ(5168, info.txWeight.evictionWeight);
|
||||
auto [cost, evictionWeight] = MempoolCostAndEvictionWeight(result.GetTxOrThrow(), DEFAULT_FEE);
|
||||
EXPECT_EQ(5168, cost);
|
||||
EXPECT_EQ(5168, evictionWeight);
|
||||
}
|
||||
|
||||
RegtestDeactivateSapling();
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
// Copyright (c) 2019-2023 The Zcash developers
|
||||
// Distributed under the MIT software license, see the accompanying
|
||||
// file COPYING or https://www.opensource.org/licenses/mit-license.php .
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "weighted_map.h"
|
||||
#include "gtest/utils.h"
|
||||
#include "util/test.h"
|
||||
|
||||
|
||||
TEST(WeightedMapTests, WeightedMap)
|
||||
{
|
||||
WeightedMap<int, int, int, GetRandInt> m;
|
||||
|
||||
EXPECT_EQ(0, m.size());
|
||||
EXPECT_TRUE(m.empty());
|
||||
EXPECT_EQ(0, m.getTotalWeight());
|
||||
m.checkInvariants();
|
||||
|
||||
EXPECT_TRUE(m.add(3, 30, 3));
|
||||
EXPECT_EQ(1, m.size());
|
||||
EXPECT_FALSE(m.empty());
|
||||
EXPECT_EQ(3, m.getTotalWeight());
|
||||
m.checkInvariants();
|
||||
|
||||
EXPECT_TRUE(m.add(1, 10, 2));
|
||||
EXPECT_EQ(2, m.size());
|
||||
EXPECT_FALSE(m.empty());
|
||||
EXPECT_EQ(5, m.getTotalWeight());
|
||||
m.checkInvariants();
|
||||
|
||||
// adding a duplicate element should be ignored
|
||||
EXPECT_FALSE(m.add(1, 15, 64));
|
||||
EXPECT_EQ(2, m.size());
|
||||
EXPECT_FALSE(m.empty());
|
||||
EXPECT_EQ(5, m.getTotalWeight());
|
||||
m.checkInvariants();
|
||||
|
||||
EXPECT_TRUE(m.add(2, 20, 1));
|
||||
EXPECT_EQ(3, m.size());
|
||||
EXPECT_FALSE(m.empty());
|
||||
EXPECT_EQ(6, m.getTotalWeight());
|
||||
m.checkInvariants();
|
||||
|
||||
// regression test: adding three elements and deleting the first caused an invariant violation (not in committed code)
|
||||
EXPECT_EQ(30, m.remove(3).value());
|
||||
EXPECT_EQ(2, m.size());
|
||||
EXPECT_FALSE(m.empty());
|
||||
EXPECT_EQ(3, m.getTotalWeight());
|
||||
m.checkInvariants();
|
||||
|
||||
// try to remove a non-existent element
|
||||
EXPECT_FALSE(m.remove(42).has_value());
|
||||
EXPECT_EQ(2, m.size());
|
||||
EXPECT_FALSE(m.empty());
|
||||
EXPECT_EQ(3, m.getTotalWeight());
|
||||
m.checkInvariants();
|
||||
|
||||
EXPECT_EQ(20, m.remove(2).value());
|
||||
EXPECT_EQ(1, m.size());
|
||||
EXPECT_FALSE(m.empty());
|
||||
EXPECT_EQ(2, m.getTotalWeight());
|
||||
m.checkInvariants();
|
||||
|
||||
EXPECT_TRUE(m.add(2, 20, 1));
|
||||
EXPECT_EQ(2, m.size());
|
||||
EXPECT_FALSE(m.empty());
|
||||
EXPECT_EQ(3, m.getTotalWeight());
|
||||
m.checkInvariants();
|
||||
|
||||
// at this point the map should contain 1->10 (weight 2) and 2->20 (weight 1)
|
||||
auto [e1, c1, w1] = m.takeRandom().value();
|
||||
EXPECT_TRUE(e1 == 1 || e1 == 2);
|
||||
EXPECT_EQ(c1, e1*10);
|
||||
EXPECT_EQ(w1, 3-e1);
|
||||
EXPECT_EQ(1, m.size());
|
||||
EXPECT_FALSE(m.empty());
|
||||
EXPECT_EQ(3-w1, m.getTotalWeight());
|
||||
m.checkInvariants();
|
||||
|
||||
auto [e2, c2, w2] = m.takeRandom().value();
|
||||
EXPECT_EQ(3, e1 + e2);
|
||||
EXPECT_EQ(c2, e2*10);
|
||||
EXPECT_EQ(w2, 3-e2);
|
||||
EXPECT_EQ(0, m.size());
|
||||
EXPECT_TRUE(m.empty());
|
||||
EXPECT_EQ(0, m.getTotalWeight());
|
||||
m.checkInvariants();
|
||||
|
||||
EXPECT_FALSE(m.takeRandom().has_value());
|
||||
EXPECT_EQ(0, m.size());
|
||||
EXPECT_TRUE(m.empty());
|
||||
EXPECT_EQ(0, m.getTotalWeight());
|
||||
m.checkInvariants();
|
||||
}
|
||||
|
||||
TEST(WeightedMapTests, WeightedMapRandomOps)
|
||||
{
|
||||
WeightedMap<int, int, int, GetRandInt> m;
|
||||
std::map<int, int> expected; // element -> weight
|
||||
int total_weight = 0;
|
||||
const int iterations = 1000;
|
||||
const int element_range = 20;
|
||||
const int max_weight = 10;
|
||||
static_assert(iterations <= std::numeric_limits<decltype(total_weight)>::max() / max_weight); // ensure total_weight cannot overflow
|
||||
|
||||
EXPECT_EQ(0, m.size());
|
||||
EXPECT_TRUE(m.empty());
|
||||
EXPECT_EQ(0, m.getTotalWeight());
|
||||
m.checkInvariants();
|
||||
for (int i = 0; i < iterations; i++) {
|
||||
switch (GetRandInt(4)) {
|
||||
// probability of add should be balanced with (remove or takeRandom)
|
||||
case 0: case 1: {
|
||||
int e = GetRandInt(element_range);
|
||||
int w = GetRandInt(max_weight) + 1;
|
||||
bool added = m.add(e, e*10, w);
|
||||
EXPECT_EQ(added, expected.count(e) == 0);
|
||||
if (added) {
|
||||
total_weight += w;
|
||||
expected[e] = w;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
int e = GetRandInt(element_range);
|
||||
auto c = m.remove(e);
|
||||
if (expected.count(e) == 0) {
|
||||
EXPECT_FALSE(c.has_value());
|
||||
} else {
|
||||
ASSERT_TRUE(c.has_value());
|
||||
EXPECT_EQ(c.value(), e*10);
|
||||
total_weight -= expected[e];
|
||||
expected.erase(e);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 3: {
|
||||
auto r = m.takeRandom();
|
||||
if (expected.empty()) {
|
||||
EXPECT_FALSE(r.has_value());
|
||||
} else {
|
||||
ASSERT_TRUE(r.has_value());
|
||||
auto [e, c, w] = r.value();
|
||||
EXPECT_EQ(1, expected.count(e));
|
||||
EXPECT_EQ(c, e*10);
|
||||
EXPECT_EQ(w, expected[e]);
|
||||
total_weight -= expected[e];
|
||||
expected.erase(e);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(expected.size(), m.size());
|
||||
EXPECT_EQ(expected.empty(), m.empty());
|
||||
EXPECT_EQ(total_weight, m.getTotalWeight());
|
||||
m.checkInvariants();
|
||||
}
|
||||
}
|
|
@ -5,15 +5,11 @@
|
|||
#include "mempool_limit.h"
|
||||
|
||||
#include "core_memusage.h"
|
||||
#include "logging.h"
|
||||
#include "random.h"
|
||||
#include "serialize.h"
|
||||
#include "timedata.h"
|
||||
#include "util/time.h"
|
||||
#include "version.h"
|
||||
|
||||
const TxWeight ZERO_WEIGHT = TxWeight(0, 0);
|
||||
|
||||
void RecentlyEvictedList::pruneList()
|
||||
{
|
||||
if (txIdSet.empty()) {
|
||||
|
@ -43,111 +39,7 @@ bool RecentlyEvictedList::contains(const uint256& txId)
|
|||
return txIdSet.count(txId) > 0;
|
||||
}
|
||||
|
||||
|
||||
TxWeight WeightedTxTree::getWeightAt(size_t index) const
|
||||
{
|
||||
return index < size ? txIdAndWeights[index].txWeight.add(childWeights[index]) : ZERO_WEIGHT;
|
||||
}
|
||||
|
||||
void WeightedTxTree::backPropagate(size_t fromIndex, const TxWeight& weightDelta)
|
||||
{
|
||||
while (fromIndex > 0) {
|
||||
fromIndex = (fromIndex - 1) / 2;
|
||||
childWeights[fromIndex] = childWeights[fromIndex].add(weightDelta);
|
||||
}
|
||||
}
|
||||
|
||||
size_t WeightedTxTree::findByEvictionWeight(size_t fromIndex, int64_t weightToFind) const
|
||||
{
|
||||
int leftWeight = getWeightAt(fromIndex * 2 + 1).evictionWeight;
|
||||
int rightWeight = getWeightAt(fromIndex).evictionWeight - getWeightAt(fromIndex * 2 + 2).evictionWeight;
|
||||
// On Left
|
||||
if (weightToFind < leftWeight) {
|
||||
return findByEvictionWeight(fromIndex * 2 + 1, weightToFind);
|
||||
}
|
||||
// Found
|
||||
if (weightToFind < rightWeight) {
|
||||
return fromIndex;
|
||||
}
|
||||
// On Right
|
||||
return findByEvictionWeight(fromIndex * 2 + 2, weightToFind - rightWeight);
|
||||
}
|
||||
|
||||
TxWeight WeightedTxTree::getTotalWeight() const
|
||||
{
|
||||
return getWeightAt(0);
|
||||
}
|
||||
|
||||
|
||||
void WeightedTxTree::add(const WeightedTxInfo& weightedTxInfo)
|
||||
{
|
||||
if (txIdToIndexMap.count(weightedTxInfo.txId) > 0) {
|
||||
// This should not happen, but should be prevented nonetheless
|
||||
return;
|
||||
}
|
||||
txIdAndWeights.push_back(weightedTxInfo);
|
||||
childWeights.push_back(ZERO_WEIGHT);
|
||||
txIdToIndexMap[weightedTxInfo.txId] = size;
|
||||
backPropagate(size, weightedTxInfo.txWeight);
|
||||
size += 1;
|
||||
}
|
||||
|
||||
void WeightedTxTree::remove(const uint256& txId)
|
||||
{
|
||||
if (txIdToIndexMap.find(txId) == txIdToIndexMap.end()) {
|
||||
// Remove may be called multiple times for a given tx, so this is necessary
|
||||
return;
|
||||
}
|
||||
|
||||
size_t removeIndex = txIdToIndexMap[txId];
|
||||
|
||||
// We reduce the size at the start of this method to avoid saying size - 1
|
||||
// when referring to the last element of the array below
|
||||
size -= 1;
|
||||
|
||||
TxWeight lastChildWeight = txIdAndWeights[size].txWeight;
|
||||
backPropagate(size, lastChildWeight.negate());
|
||||
|
||||
if (removeIndex < size) {
|
||||
TxWeight weightDelta = lastChildWeight.add(txIdAndWeights[removeIndex].txWeight.negate());
|
||||
txIdAndWeights[removeIndex] = txIdAndWeights[size];
|
||||
txIdToIndexMap[txIdAndWeights[removeIndex].txId] = removeIndex;
|
||||
backPropagate(removeIndex, weightDelta);
|
||||
}
|
||||
|
||||
txIdToIndexMap.erase(txId);
|
||||
txIdAndWeights.pop_back();
|
||||
childWeights.pop_back();
|
||||
}
|
||||
|
||||
std::optional<uint256> WeightedTxTree::maybeDropRandom()
|
||||
{
|
||||
TxWeight totalTxWeight = getTotalWeight();
|
||||
if (totalTxWeight.cost <= capacity) {
|
||||
return std::nullopt;
|
||||
}
|
||||
LogPrint("mempool", "Mempool cost limit exceeded (cost=%d, limit=%d)\n", totalTxWeight.cost, capacity);
|
||||
int randomWeight = GetRand(totalTxWeight.evictionWeight);
|
||||
WeightedTxInfo drop = txIdAndWeights[findByEvictionWeight(0, randomWeight)];
|
||||
LogPrint("mempool", "Evicting transaction (txid=%s, cost=%d, evictionWeight=%d)\n",
|
||||
drop.txId.ToString(), drop.txWeight.cost, drop.txWeight.evictionWeight);
|
||||
remove(drop.txId);
|
||||
return drop.txId;
|
||||
}
|
||||
|
||||
|
||||
TxWeight TxWeight::add(const TxWeight& other) const
|
||||
{
|
||||
return TxWeight(cost + other.cost, evictionWeight + other.evictionWeight);
|
||||
}
|
||||
|
||||
TxWeight TxWeight::negate() const
|
||||
{
|
||||
return TxWeight(-cost, -evictionWeight);
|
||||
}
|
||||
|
||||
|
||||
WeightedTxInfo WeightedTxInfo::from(const CTransaction& tx, const CAmount& fee)
|
||||
std::pair<int64_t, int64_t> MempoolCostAndEvictionWeight(const CTransaction& tx, const CAmount& fee)
|
||||
{
|
||||
size_t memUsage = RecursiveDynamicUsage(tx);
|
||||
int64_t cost = std::max((int64_t) memUsage, (int64_t) MIN_TX_COST);
|
||||
|
@ -155,5 +47,5 @@ WeightedTxInfo WeightedTxInfo::from(const CTransaction& tx, const CAmount& fee)
|
|||
if (fee < DEFAULT_FEE) {
|
||||
evictionWeight += LOW_FEE_PENALTY;
|
||||
}
|
||||
return WeightedTxInfo(tx.GetHash(), TxWeight(cost, evictionWeight));
|
||||
return std::make_pair(cost, evictionWeight);
|
||||
}
|
||||
|
|
|
@ -9,19 +9,21 @@
|
|||
#include <map>
|
||||
#include <optional>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
#include "logging.h"
|
||||
#include "random.h"
|
||||
#include "primitives/transaction.h"
|
||||
#include "policy/fees.h"
|
||||
#include "uint256.h"
|
||||
#include "util/time.h"
|
||||
#include "weighted_map.h"
|
||||
|
||||
const size_t DEFAULT_MEMPOOL_TOTAL_COST_LIMIT = 80000000;
|
||||
const int64_t DEFAULT_MEMPOOL_EVICTION_MEMORY_MINUTES = 60;
|
||||
|
||||
const size_t EVICTION_MEMORY_ENTRIES = 40000;
|
||||
const uint64_t MIN_TX_COST = 4000;
|
||||
const uint64_t LOW_FEE_PENALTY = 16000;
|
||||
const int64_t MIN_TX_COST = 4000;
|
||||
const int64_t LOW_FEE_PENALTY = 16000;
|
||||
|
||||
|
||||
// This class keeps track of transactions which have been recently evicted from the mempool
|
||||
|
@ -57,80 +59,62 @@ public:
|
|||
|
||||
// The mempool of a node holds a set of transactions. Each transaction has a *cost*,
|
||||
// which is an integer defined as:
|
||||
// max(serialized transaction size in bytes, 4000)
|
||||
// max(memory usage in bytes, 4000)
|
||||
//
|
||||
// Each transaction also has an *eviction weight*, which is *cost* + *fee_penalty*,
|
||||
// where *fee_penalty* is 16000 if the transaction pays a fee less than 10000 zatoshi,
|
||||
// otherwise 0.
|
||||
struct TxWeight {
|
||||
int64_t cost;
|
||||
int64_t evictionWeight; // *cost* + *fee_penalty*
|
||||
// where *fee_penalty* is 16000 if the transaction pays a fee less than the
|
||||
// conventional fee, otherwise 0.
|
||||
|
||||
TxWeight(int64_t cost_, int64_t evictionWeight_)
|
||||
: cost(cost_), evictionWeight(evictionWeight_) {}
|
||||
|
||||
TxWeight add(const TxWeight& other) const;
|
||||
TxWeight negate() const;
|
||||
};
|
||||
// Calculate cost and eviction weight based on the memory usage and fee.
|
||||
std::pair<int64_t, int64_t> MempoolCostAndEvictionWeight(const CTransaction& tx, const CAmount& fee);
|
||||
|
||||
|
||||
// This struct is a pair of txid, cost.
|
||||
struct WeightedTxInfo {
|
||||
uint256 txId;
|
||||
TxWeight txWeight;
|
||||
|
||||
WeightedTxInfo(uint256 txId_, TxWeight txWeight_) : txId(txId_), txWeight(txWeight_) {}
|
||||
|
||||
// Factory method which calculates cost based on size in bytes and fee.
|
||||
static WeightedTxInfo from(const CTransaction& tx, const CAmount& fee);
|
||||
};
|
||||
|
||||
|
||||
// The following class is a collection of transaction ids and their costs.
|
||||
// In order to be able to remove transactions randomly weighted by their cost,
|
||||
// we keep track of the total cost of all transactions in this collection.
|
||||
// For performance reasons, the collection is represented as a complete binary
|
||||
// tree where each node knows the sum of the weights of the children. This
|
||||
// allows for addition, removal, and random selection/dropping in logarithmic time.
|
||||
class WeightedTxTree
|
||||
class MempoolLimitTxSet
|
||||
{
|
||||
const int64_t capacity;
|
||||
size_t size = 0;
|
||||
|
||||
// The following two vectors are the tree representation of this collection.
|
||||
// We keep track of 3 data points for each node: A transaction's txid, its cost,
|
||||
// and the sum of the weights of all children and descendant of that node.
|
||||
std::vector<WeightedTxInfo> txIdAndWeights;
|
||||
std::vector<TxWeight> childWeights;
|
||||
|
||||
// The following map is to simplify removal. When removing a tx, we do so by txid.
|
||||
// This map allows looking up the transaction's index in the tree.
|
||||
std::map<uint256, size_t> txIdToIndexMap;
|
||||
|
||||
// Returns the sum of a node and all of its children's TxWeights for a given index.
|
||||
TxWeight getWeightAt(size_t index) const;
|
||||
|
||||
// When adding and removing a node we need to update its parent and all of its
|
||||
// ancestors to reflect its cost.
|
||||
void backPropagate(size_t fromIndex, const TxWeight& weightDelta);
|
||||
|
||||
// For a given random cost + fee penalty, this method recursively finds the
|
||||
// correct transaction. This is used by WeightedTxTree::maybeDropRandom().
|
||||
size_t findByEvictionWeight(size_t fromIndex, int64_t weightToFind) const;
|
||||
WeightedMap<uint256, int64_t, int64_t, GetRandInt64> txmap;
|
||||
int64_t capacity;
|
||||
int64_t cost;
|
||||
|
||||
public:
|
||||
WeightedTxTree(int64_t capacity_) : capacity(capacity_) {
|
||||
MempoolLimitTxSet(int64_t capacity_) : capacity(capacity_), cost(0) {
|
||||
assert(capacity >= 0);
|
||||
}
|
||||
|
||||
TxWeight getTotalWeight() const;
|
||||
int64_t getTotalWeight() const
|
||||
{
|
||||
return txmap.getTotalWeight();
|
||||
}
|
||||
bool empty() const
|
||||
{
|
||||
return txmap.empty();
|
||||
}
|
||||
void add(const uint256& txId, int64_t txCost, int64_t txWeight)
|
||||
{
|
||||
if (txmap.add(txId, txCost, txWeight)) {
|
||||
cost += txCost;
|
||||
}
|
||||
}
|
||||
void remove(const uint256& txId)
|
||||
{
|
||||
cost -= txmap.remove(txId).value_or(0);
|
||||
}
|
||||
|
||||
void add(const WeightedTxInfo& weightedTxInfo);
|
||||
void remove(const uint256& txId);
|
||||
|
||||
// If the total cost limit is exceeded, pick a random number based on the total cost
|
||||
// of the collection and remove the associated transaction.
|
||||
std::optional<uint256> maybeDropRandom();
|
||||
// If the total cost limit has not been exceeded, return std::nullopt. Otherwise,
|
||||
// pick a transaction at random with probability proportional to its eviction weight;
|
||||
// remove and return that transaction's txid.
|
||||
std::optional<uint256> maybeDropRandom()
|
||||
{
|
||||
if (cost <= capacity) {
|
||||
return std::nullopt;
|
||||
}
|
||||
LogPrint("mempool", "Mempool cost limit exceeded (cost=%d, limit=%d)\n", cost, capacity);
|
||||
assert(!txmap.empty());
|
||||
auto [txId, txCost, txWeight] = txmap.takeRandom().value();
|
||||
cost -= txCost;
|
||||
LogPrint("mempool", "Evicting transaction (txid=%s, cost=%d, weight=%d)\n",
|
||||
txId.ToString(), txCost, txWeight);
|
||||
return txId;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
#endif // ZCASH_MEMPOOL_LIMIT_H
|
||||
|
|
|
@ -55,6 +55,11 @@ uint64_t GetRand(uint64_t nMax)
|
|||
return (nRand % nMax);
|
||||
}
|
||||
|
||||
int64_t GetRandInt64(int64_t nMax)
|
||||
{
|
||||
return GetRand(nMax);
|
||||
}
|
||||
|
||||
int GetRandInt(int nMax)
|
||||
{
|
||||
return GetRand(nMax);
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
*/
|
||||
void GetRandBytes(unsigned char* buf, size_t num);
|
||||
uint64_t GetRand(uint64_t nMax);
|
||||
int64_t GetRandInt64(int64_t nMax);
|
||||
int GetRandInt(int nMax);
|
||||
uint256 GetRandHash();
|
||||
|
||||
|
|
|
@ -337,7 +337,7 @@ CTxMemPool::~CTxMemPool()
|
|||
{
|
||||
delete minerPolicyEstimator;
|
||||
delete recentlyEvicted;
|
||||
delete weightedTxTree;
|
||||
delete limitSet;
|
||||
}
|
||||
|
||||
void CTxMemPool::pruneSpent(const uint256 &hashTx, CCoins &coins)
|
||||
|
@ -371,7 +371,8 @@ bool CTxMemPool::addUnchecked(const uint256& hash, const CTxMemPoolEntry &entry,
|
|||
// Used by main.cpp AcceptToMemoryPool(), which DOES do
|
||||
// all the appropriate checks.
|
||||
LOCK(cs);
|
||||
weightedTxTree->add(WeightedTxInfo::from(entry.GetTx(), entry.GetFee()));
|
||||
auto [cost, evictionWeight] = MempoolCostAndEvictionWeight(entry.GetTx(), entry.GetFee());
|
||||
limitSet->add(entry.GetTx().GetHash(), cost, evictionWeight);
|
||||
indexed_transaction_set::iterator newit = mapTx.insert(entry).first;
|
||||
mapLinks.insert(make_pair(newit, TxLinks()));
|
||||
|
||||
|
@ -637,7 +638,7 @@ void CTxMemPool::remove(const CTransaction &origTx, std::list<CTransaction>& rem
|
|||
}
|
||||
RemoveStaged(setAllRemoves);
|
||||
for (CTransaction tx : removed) {
|
||||
weightedTxTree->remove(tx.GetHash());
|
||||
limitSet->remove(tx.GetHash());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1265,7 +1266,7 @@ size_t CTxMemPool::DynamicMemoryUsage() const {
|
|||
memusage::DynamicUsage(mapOrchardNullifiers);
|
||||
|
||||
// DoS mitigation
|
||||
total += memusage::DynamicUsage(recentlyEvicted) + memusage::DynamicUsage(weightedTxTree);
|
||||
total += memusage::DynamicUsage(recentlyEvicted) + memusage::DynamicUsage(limitSet);
|
||||
|
||||
// Insight-related structures
|
||||
size_t insight = 0;
|
||||
|
@ -1282,9 +1283,9 @@ void CTxMemPool::SetMempoolCostLimit(int64_t totalCostLimit, int64_t evictionMem
|
|||
LOCK(cs);
|
||||
LogPrint("mempool", "Setting mempool cost limit: (limit=%d, time=%d)\n", totalCostLimit, evictionMemorySeconds);
|
||||
delete recentlyEvicted;
|
||||
delete weightedTxTree;
|
||||
delete limitSet;
|
||||
recentlyEvicted = new RecentlyEvictedList(GetNodeClock(), evictionMemorySeconds);
|
||||
weightedTxTree = new WeightedTxTree(totalCostLimit);
|
||||
limitSet = new MempoolLimitTxSet(totalCostLimit);
|
||||
}
|
||||
|
||||
bool CTxMemPool::IsRecentlyEvicted(const uint256& txId) {
|
||||
|
@ -1295,7 +1296,7 @@ bool CTxMemPool::IsRecentlyEvicted(const uint256& txId) {
|
|||
void CTxMemPool::EnsureSizeLimit() {
|
||||
AssertLockHeld(cs);
|
||||
std::optional<uint256> maybeDropTxId;
|
||||
while ((maybeDropTxId = weightedTxTree->maybeDropRandom()).has_value()) {
|
||||
while ((maybeDropTxId = limitSet->maybeDropRandom()).has_value()) {
|
||||
uint256 txId = maybeDropTxId.value();
|
||||
recentlyEvicted->add(txId);
|
||||
std::list<CTransaction> removed;
|
||||
|
|
|
@ -366,7 +366,7 @@ private:
|
|||
std::map<uint256, const CTransaction*> mapSaplingNullifiers;
|
||||
std::map<uint256, const CTransaction*> mapOrchardNullifiers;
|
||||
RecentlyEvictedList* recentlyEvicted = new RecentlyEvictedList(GetNodeClock(), DEFAULT_MEMPOOL_EVICTION_MEMORY_MINUTES * 60);
|
||||
WeightedTxTree* weightedTxTree = new WeightedTxTree(DEFAULT_MEMPOOL_TOTAL_COST_LIMIT);
|
||||
MempoolLimitTxSet* limitSet = new MempoolLimitTxSet(DEFAULT_MEMPOOL_TOTAL_COST_LIMIT);
|
||||
|
||||
void checkNullifiers(ShieldedType type) const;
|
||||
|
||||
|
|
|
@ -0,0 +1,193 @@
|
|||
// Copyright (c) 2019-2023 The Zcash developers
|
||||
// Distributed under the MIT software license, see the accompanying
|
||||
// file COPYING or https://www.opensource.org/licenses/mit-license.php .
|
||||
|
||||
#ifndef ZCASH_WEIGHTED_MAP_H
|
||||
#define ZCASH_WEIGHTED_MAP_H
|
||||
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
// A WeightedMap represents a map from keys (of type K) to values (of type V),
|
||||
// each entry having a weight (of type W). Elements can be randomly selected and
|
||||
// removed from the map with probability in proportion to their weight. This is
|
||||
// used to implement mempool limiting specified in ZIP 401.
|
||||
//
|
||||
// In order to efficiently implement random selection by weight, we keep track
|
||||
// of the total weight of all keys in the map. For performance reasons, the
|
||||
// map is represented as a binary tree where each node knows the sum of the
|
||||
// weights of the children. This allows for addition, removal, and random
|
||||
// selection/dropping in logarithmic time.
|
||||
//
|
||||
// random(w) must be defined to return a uniform random value between zero
|
||||
// inclusive and w exclusive. The type W must support addition, binary and
|
||||
// unary -, and < comparisons, and W() must construct the zero value (these
|
||||
// constraints are met for primitive signed integer types).
|
||||
template <typename K, typename V, typename W, W random(W)>
|
||||
class WeightedMap
|
||||
{
|
||||
struct Node {
|
||||
K key;
|
||||
V value;
|
||||
W weight;
|
||||
W sumOfDescendantWeights;
|
||||
};
|
||||
|
||||
// The following vector is the tree representation of this collection.
|
||||
// For each node, we keep track of the key, its associated value,
|
||||
// its weight, and the sum of the weights of all its descendants.
|
||||
std::vector<Node> nodes;
|
||||
|
||||
// The following map is to simplify removal.
|
||||
std::map<K, size_t> indexMap;
|
||||
|
||||
static inline size_t leftChild(size_t i) { return i*2 + 1; }
|
||||
static inline size_t rightChild(size_t i) { return i*2 + 2; }
|
||||
static inline size_t parent(size_t i) { return (i-1)/2; }
|
||||
|
||||
public:
|
||||
// Check internal invariants (for tests).
|
||||
void checkInvariants() const
|
||||
{
|
||||
assert(indexMap.size() == nodes.size());
|
||||
for (size_t i = 0; i < nodes.size(); i++) {
|
||||
assert(indexMap.at(nodes.at(i).key) == i);
|
||||
assert(nodes.at(i).sumOfDescendantWeights == getWeightAt(leftChild(i)) + getWeightAt(rightChild(i)));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// Return the sum of weights of the node at a given index and all of its descendants.
|
||||
W getWeightAt(size_t index) const
|
||||
{
|
||||
if (index >= nodes.size()) {
|
||||
return W();
|
||||
}
|
||||
auto& node = nodes.at(index);
|
||||
return node.weight + node.sumOfDescendantWeights;
|
||||
}
|
||||
|
||||
// When adding and removing a node we need to update its parent and all of its
|
||||
// ancestors to reflect its weight.
|
||||
void backPropagate(size_t fromIndex, W weightDelta)
|
||||
{
|
||||
while (fromIndex > 0) {
|
||||
fromIndex = parent(fromIndex);
|
||||
nodes[fromIndex].sumOfDescendantWeights += weightDelta;
|
||||
}
|
||||
}
|
||||
|
||||
// For a given random weight, this method recursively finds the index of the
|
||||
// correct entry. This is used by WeightedMap::takeRandom().
|
||||
size_t findByWeight(size_t fromIndex, W weightToFind) const
|
||||
{
|
||||
W leftWeight = getWeightAt(leftChild(fromIndex));
|
||||
// On Left
|
||||
if (weightToFind < leftWeight) {
|
||||
return findByWeight(leftChild(fromIndex), weightToFind);
|
||||
}
|
||||
W rightWeight = getWeightAt(fromIndex) - getWeightAt(rightChild(fromIndex));
|
||||
// Found
|
||||
if (weightToFind < rightWeight) {
|
||||
return fromIndex;
|
||||
}
|
||||
// On Right
|
||||
return findByWeight(rightChild(fromIndex), weightToFind - rightWeight);
|
||||
}
|
||||
|
||||
public:
|
||||
WeightedMap() {}
|
||||
|
||||
// Return the total weight of all entries in the map.
|
||||
W getTotalWeight() const
|
||||
{
|
||||
return getWeightAt(0);
|
||||
}
|
||||
|
||||
// Return true when the map has no entries.
|
||||
bool empty() const
|
||||
{
|
||||
return nodes.empty();
|
||||
}
|
||||
|
||||
// Return the number of entries.
|
||||
size_t size() const
|
||||
{
|
||||
return nodes.size();
|
||||
}
|
||||
|
||||
// Return false if the key already exists in the map.
|
||||
// Otherwise, add an entry mapping `key` to `value` with the given weight,
|
||||
// and return true. The weight must be positive.
|
||||
bool add(K key, V value, W weight)
|
||||
{
|
||||
assert(W() < weight);
|
||||
if (indexMap.count(key) > 0) {
|
||||
return false;
|
||||
}
|
||||
size_t index = nodes.size();
|
||||
nodes.push_back(Node {
|
||||
.key = key,
|
||||
.value = value,
|
||||
.weight = weight,
|
||||
.sumOfDescendantWeights = W(),
|
||||
});
|
||||
indexMap[key] = index;
|
||||
backPropagate(index, weight);
|
||||
return true;
|
||||
}
|
||||
|
||||
// If the given key is not present in the map, return std::nullopt.
|
||||
// Otherwise, remove that key's entry and return its associated value.
|
||||
std::optional<V> remove(K key)
|
||||
{
|
||||
auto it = indexMap.find(key);
|
||||
if (it == indexMap.end()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
size_t removeIndex = it->second;
|
||||
V removeValue = nodes.at(removeIndex).value;
|
||||
|
||||
size_t lastIndex = nodes.size()-1;
|
||||
Node lastNode = nodes.at(lastIndex);
|
||||
W weightDelta = lastNode.weight - nodes.at(removeIndex).weight;
|
||||
backPropagate(lastIndex, -lastNode.weight);
|
||||
|
||||
if (removeIndex < lastIndex) {
|
||||
nodes[removeIndex].key = lastNode.key;
|
||||
nodes[removeIndex].value = lastNode.value;
|
||||
nodes[removeIndex].weight = lastNode.weight;
|
||||
// nodes[removeIndex].sumOfDescendantWeights should not change here.
|
||||
indexMap[lastNode.key] = removeIndex;
|
||||
backPropagate(removeIndex, weightDelta);
|
||||
}
|
||||
|
||||
indexMap.erase(it);
|
||||
nodes.pop_back();
|
||||
return removeValue;
|
||||
}
|
||||
|
||||
// If the map is empty, return std::nullopt. Otherwise, pick a random entry
|
||||
// with probability proportional to its weight; remove it and return a tuple of
|
||||
// the key, its associated value, and its weight.
|
||||
std::optional<std::tuple<K, V, W>> takeRandom()
|
||||
{
|
||||
if (empty()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
W totalWeight = getTotalWeight();
|
||||
W randomWeight = random(totalWeight);
|
||||
assert(W() <= randomWeight && randomWeight < totalWeight);
|
||||
size_t index = findByWeight(0, randomWeight);
|
||||
assert(index < nodes.size());
|
||||
const Node& drop = nodes.at(index);
|
||||
auto res = std::make_tuple(drop.key, drop.value, drop.weight); // copy values
|
||||
remove(drop.key);
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
#endif // ZCASH_WEIGHTED_MAP_H
|
Loading…
Reference in New Issue