// Copyright (c) 2019-2023 The Zcash developers // Distributed under the MIT software license, see the accompanying // file COPYING or https://www.opensource.org/licenses/mit-license.php . #ifndef ZCASH_WEIGHTED_MAP_H #define ZCASH_WEIGHTED_MAP_H #include #include #include #include // A WeightedMap represents a map from keys (of type K) to values (of type V), // each entry having a weight (of type W). Elements can be randomly selected and // removed from the map with probability in proportion to their weight. This is // used to implement mempool limiting specified in ZIP 401, and the block template // construction algorithm specified in ZIP 317. // // In order to efficiently implement random selection by weight, we keep track // of the total weight of all keys in the map. For performance reasons, the // map is represented as a binary tree where each node knows the sum of the // weights of the children. This allows for addition, removal, and random // selection/dropping in logarithmic time. // // random(w) (which will only be called with positive w) must be defined to // return a uniform random value between zero inclusive and w exclusive. // The type W must be a signed numeric type that supports addition, binary // and unary -, and < and <= comparisons, and W() must construct the zero value // (these constraints are met for primitive signed integer types). template class WeightedMap { // W must be a signed numeric type. static_assert(std::numeric_limits::min() < W()); struct Node { K key; V value; W weight; W sumOfDescendantWeights; }; // The following vector is the tree representation of this collection. // For each node, we keep track of the key, its associated value, // its weight, and the sum of the weights of all its descendants. std::vector nodes; // The following map is to simplify removal. std::map indexMap; static inline size_t leftChild(size_t i) { return i*2 + 1; } static inline size_t rightChild(size_t i) { return i*2 + 2; } static inline size_t parent(size_t i) { return (i-1)/2; } public: // Check internal invariants (for tests). void checkInvariants() const { assert(indexMap.size() == nodes.size()); for (size_t i = 0; i < nodes.size(); i++) { assert(indexMap.at(nodes.at(i).key) == i); assert(nodes.at(i).sumOfDescendantWeights == getWeightAt(leftChild(i)) + getWeightAt(rightChild(i))); } } private: // Return the sum of weights of the node at a given index and all of its descendants. W getWeightAt(size_t index) const { if (index >= nodes.size()) { return W(); } auto& node = nodes.at(index); return node.weight + node.sumOfDescendantWeights; } // When adding and removing a node we need to update its parent and all of its // ancestors to reflect its weight. void backPropagate(size_t fromIndex, W weightDelta) { while (fromIndex > 0) { fromIndex = parent(fromIndex); nodes[fromIndex].sumOfDescendantWeights += weightDelta; } } // For a given random weight, this method recursively finds the index of the // correct entry. This is used by WeightedMap::takeRandom(). size_t findByWeight(size_t fromIndex, W weightToFind) const { W leftWeight = getWeightAt(leftChild(fromIndex)); // On Left if (weightToFind < leftWeight) { return findByWeight(leftChild(fromIndex), weightToFind); } W rightWeight = getWeightAt(fromIndex) - getWeightAt(rightChild(fromIndex)); // Found if (weightToFind < rightWeight) { return fromIndex; } // On Right return findByWeight(rightChild(fromIndex), weightToFind - rightWeight); } public: WeightedMap() {} // Return the total weight of all entries in the map. W getTotalWeight() const { return getWeightAt(0); } // Return true when the map has no entries. bool empty() const { return nodes.empty(); } // Return the number of entries. size_t size() const { return nodes.size(); } // Return false if the key already exists in the map. // Otherwise, add an entry mapping `key` to `value` with the given weight, // and return true. The weight must be positive. bool add(K key, V value, W weight) { assert(W() < weight); if (indexMap.count(key) > 0) { return false; } size_t index = nodes.size(); nodes.push_back(Node { .key = key, .value = value, .weight = weight, .sumOfDescendantWeights = W(), }); indexMap[key] = index; backPropagate(index, weight); return true; } // If the given key is not present in the map, return std::nullopt. // Otherwise, remove that key's entry and return its associated value. std::optional remove(K key) { auto it = indexMap.find(key); if (it == indexMap.end()) { return std::nullopt; } size_t removeIndex = it->second; V removeValue = nodes.at(removeIndex).value; size_t lastIndex = nodes.size()-1; Node lastNode = nodes.at(lastIndex); W weightDelta = lastNode.weight - nodes.at(removeIndex).weight; backPropagate(lastIndex, -lastNode.weight); if (removeIndex < lastIndex) { nodes[removeIndex].key = lastNode.key; nodes[removeIndex].value = lastNode.value; nodes[removeIndex].weight = lastNode.weight; // nodes[removeIndex].sumOfDescendantWeights should not change here. indexMap[lastNode.key] = removeIndex; backPropagate(removeIndex, weightDelta); } indexMap.erase(it); nodes.pop_back(); return removeValue; } // If the map is empty, return std::nullopt. Otherwise, pick a random entry // with probability proportional to its weight; remove it and return a tuple of // the key, its associated value, and its weight. std::optional> takeRandom() { if (empty()) { return std::nullopt; } W totalWeight = getTotalWeight(); assert(W() < totalWeight); W randomWeight = random(totalWeight); assert(W() <= randomWeight && randomWeight < totalWeight); size_t index = findByWeight(0, randomWeight); assert(index < nodes.size()); const Node& drop = nodes.at(index); auto res = std::make_tuple(drop.key, drop.value, drop.weight); // copy values remove(drop.key); return res; } }; #endif // ZCASH_WEIGHTED_MAP_H