diff --git a/src/coins.cpp b/src/coins.cpp index 250354614..f291cea2d 100644 --- a/src/coins.cpp +++ b/src/coins.cpp @@ -43,7 +43,7 @@ bool CCoins::Spend(uint32_t nPos) return true; } bool CCoinsView::GetAnchorAt(const uint256 &rt, ZCIncrementalMerkleTree &tree) const { return false; } -bool CCoinsView::GetNullifier(const uint256 &nullifier) const { return false; } +bool CCoinsView::GetNullifier(const uint256 &nullifier, NullifierType type) const { return false; } bool CCoinsView::GetCoins(const uint256 &txid, CCoins &coins) const { return false; } bool CCoinsView::HaveCoins(const uint256 &txid) const { return false; } uint256 CCoinsView::GetBestBlock() const { return uint256(); } @@ -52,14 +52,15 @@ bool CCoinsView::BatchWrite(CCoinsMap &mapCoins, const uint256 &hashBlock, const uint256 &hashAnchor, CAnchorsMap &mapAnchors, - CNullifiersMap &mapNullifiers) { return false; } + CNullifiersMap &mapSproutNullifiers, + CNullifiersMap &mapSaplingNullifiers) { return false; } bool CCoinsView::GetStats(CCoinsStats &stats) const { return false; } CCoinsViewBacked::CCoinsViewBacked(CCoinsView *viewIn) : base(viewIn) { } bool CCoinsViewBacked::GetAnchorAt(const uint256 &rt, ZCIncrementalMerkleTree &tree) const { return base->GetAnchorAt(rt, tree); } -bool CCoinsViewBacked::GetNullifier(const uint256 &nullifier) const { return base->GetNullifier(nullifier); } +bool CCoinsViewBacked::GetNullifier(const uint256 &nullifier, NullifierType type) const { return base->GetNullifier(nullifier, type); } bool CCoinsViewBacked::GetCoins(const uint256 &txid, CCoins &coins) const { return base->GetCoins(txid, coins); } bool CCoinsViewBacked::HaveCoins(const uint256 &txid) const { return base->HaveCoins(txid); } uint256 CCoinsViewBacked::GetBestBlock() const { return base->GetBestBlock(); } @@ -69,7 +70,8 @@ bool CCoinsViewBacked::BatchWrite(CCoinsMap &mapCoins, const uint256 &hashBlock, const uint256 &hashAnchor, CAnchorsMap &mapAnchors, - CNullifiersMap &mapNullifiers) { return base->BatchWrite(mapCoins, hashBlock, hashAnchor, mapAnchors, mapNullifiers); } + CNullifiersMap &mapSproutNullifiers, + CNullifiersMap &mapSaplingNullifiers) { return base->BatchWrite(mapCoins, hashBlock, hashAnchor, mapAnchors, mapSproutNullifiers, mapSaplingNullifiers); } bool CCoinsViewBacked::GetStats(CCoinsStats &stats) const { return base->GetStats(stats); } CCoinsKeyHasher::CCoinsKeyHasher() : salt(GetRandHash()) {} @@ -84,7 +86,8 @@ CCoinsViewCache::~CCoinsViewCache() size_t CCoinsViewCache::DynamicMemoryUsage() const { return memusage::DynamicUsage(cacheCoins) + memusage::DynamicUsage(cacheAnchors) + - memusage::DynamicUsage(cacheNullifiers) + + memusage::DynamicUsage(cacheSproutNullifiers) + + memusage::DynamicUsage(cacheSaplingNullifiers) + cachedCoinsUsage; } @@ -130,16 +133,27 @@ bool CCoinsViewCache::GetAnchorAt(const uint256 &rt, ZCIncrementalMerkleTree &tr return true; } -bool CCoinsViewCache::GetNullifier(const uint256 &nullifier) const { - CNullifiersMap::iterator it = cacheNullifiers.find(nullifier); - if (it != cacheNullifiers.end()) +bool CCoinsViewCache::GetNullifier(const uint256 &nullifier, NullifierType type) const { + CNullifiersMap* cacheToUse; + switch (type) { + case SPROUT_NULLIFIER: + cacheToUse = &cacheSproutNullifiers; + break; + case SAPLING_NULLIFIER: + cacheToUse = &cacheSaplingNullifiers; + break; + default: + throw std::runtime_error("Unknown nullifier type " + type); + } + CNullifiersMap::iterator it = cacheToUse->find(nullifier); + if (it != cacheToUse->end()) return it->second.entered; CNullifiersCacheEntry entry; - bool tmp = base->GetNullifier(nullifier); + bool tmp = base->GetNullifier(nullifier, type); entry.entered = tmp; - cacheNullifiers.insert(std::make_pair(nullifier, entry)); + cacheToUse->insert(std::make_pair(nullifier, entry)); return tmp; } @@ -196,10 +210,19 @@ void CCoinsViewCache::PopAnchor(const uint256 &newrt) { } } -void CCoinsViewCache::SetNullifier(const uint256 &nullifier, bool spent) { - std::pair ret = cacheNullifiers.insert(std::make_pair(nullifier, CNullifiersCacheEntry())); - ret.first->second.entered = spent; - ret.first->second.flags |= CNullifiersCacheEntry::DIRTY; +void CCoinsViewCache::SetNullifiers(const CTransaction& tx, bool spent) { + for (const JSDescription &joinsplit : tx.vjoinsplit) { + for (const uint256 &nullifier : joinsplit.nullifiers) { + std::pair ret = cacheSproutNullifiers.insert(std::make_pair(nullifier, CNullifiersCacheEntry())); + ret.first->second.entered = spent; + ret.first->second.flags |= CNullifiersCacheEntry::DIRTY; + } + } + for (const SpendDescription &spendDescription : tx.vShieldedSpend) { + std::pair ret = cacheSaplingNullifiers.insert(std::make_pair(spendDescription.nullifier, CNullifiersCacheEntry())); + ret.first->second.entered = spent; + ret.first->second.flags |= CNullifiersCacheEntry::DIRTY; + } } bool CCoinsViewCache::GetCoins(const uint256 &txid, CCoins &coins) const { @@ -267,11 +290,34 @@ void CCoinsViewCache::SetBestBlock(const uint256 &hashBlockIn) { hashBlock = hashBlockIn; } +void BatchWriteNullifiers(CNullifiersMap &mapNullifiers, CNullifiersMap &cacheNullifiers) +{ + for (CNullifiersMap::iterator child_it = mapNullifiers.begin(); child_it != mapNullifiers.end();) { + if (child_it->second.flags & CNullifiersCacheEntry::DIRTY) { // Ignore non-dirty entries (optimization). + CNullifiersMap::iterator parent_it = cacheNullifiers.find(child_it->first); + + if (parent_it == cacheNullifiers.end()) { + CNullifiersCacheEntry& entry = cacheNullifiers[child_it->first]; + entry.entered = child_it->second.entered; + entry.flags = CNullifiersCacheEntry::DIRTY; + } else { + if (parent_it->second.entered != child_it->second.entered) { + parent_it->second.entered = child_it->second.entered; + parent_it->second.flags |= CNullifiersCacheEntry::DIRTY; + } + } + } + CNullifiersMap::iterator itOld = child_it++; + mapNullifiers.erase(itOld); + } +} + bool CCoinsViewCache::BatchWrite(CCoinsMap &mapCoins, const uint256 &hashBlockIn, const uint256 &hashAnchorIn, CAnchorsMap &mapAnchors, - CNullifiersMap &mapNullifiers) { + CNullifiersMap &mapSproutNullifiers, + CNullifiersMap &mapSaplingNullifiers) { assert(!hasModifier); for (CCoinsMap::iterator it = mapCoins.begin(); it != mapCoins.end();) { if (it->second.flags & CCoinsCacheEntry::DIRTY) { // Ignore non-dirty entries (optimization). @@ -333,25 +379,8 @@ bool CCoinsViewCache::BatchWrite(CCoinsMap &mapCoins, mapAnchors.erase(itOld); } - for (CNullifiersMap::iterator child_it = mapNullifiers.begin(); child_it != mapNullifiers.end();) - { - if (child_it->second.flags & CNullifiersCacheEntry::DIRTY) { // Ignore non-dirty entries (optimization). - CNullifiersMap::iterator parent_it = cacheNullifiers.find(child_it->first); - - if (parent_it == cacheNullifiers.end()) { - CNullifiersCacheEntry& entry = cacheNullifiers[child_it->first]; - entry.entered = child_it->second.entered; - entry.flags = CNullifiersCacheEntry::DIRTY; - } else { - if (parent_it->second.entered != child_it->second.entered) { - parent_it->second.entered = child_it->second.entered; - parent_it->second.flags |= CNullifiersCacheEntry::DIRTY; - } - } - } - CNullifiersMap::iterator itOld = child_it++; - mapNullifiers.erase(itOld); - } + ::BatchWriteNullifiers(mapSproutNullifiers, cacheSproutNullifiers); + ::BatchWriteNullifiers(mapSaplingNullifiers, cacheSaplingNullifiers); hashAnchor = hashAnchorIn; hashBlock = hashBlockIn; @@ -359,10 +388,11 @@ bool CCoinsViewCache::BatchWrite(CCoinsMap &mapCoins, } bool CCoinsViewCache::Flush() { - bool fOk = base->BatchWrite(cacheCoins, hashBlock, hashAnchor, cacheAnchors, cacheNullifiers); + bool fOk = base->BatchWrite(cacheCoins, hashBlock, hashAnchor, cacheAnchors, cacheSproutNullifiers, cacheSaplingNullifiers); cacheCoins.clear(); cacheAnchors.clear(); - cacheNullifiers.clear(); + cacheSproutNullifiers.clear(); + cacheSaplingNullifiers.clear(); cachedCoinsUsage = 0; return fOk; } @@ -400,7 +430,7 @@ bool CCoinsViewCache::HaveJoinSplitRequirements(const CTransaction& tx) const { BOOST_FOREACH(const uint256& nullifier, joinsplit.nullifiers) { - if (GetNullifier(nullifier)) { + if (GetNullifier(nullifier, SPROUT_NULLIFIER)) { // If the nullifier is set, this transaction // double-spends! return false; @@ -423,6 +453,11 @@ bool CCoinsViewCache::HaveJoinSplitRequirements(const CTransaction& tx) const intermediates.insert(std::make_pair(tree.root(), tree)); } + for (const SpendDescription &spendDescription : tx.vShieldedSpend) { + if (GetNullifier(spendDescription.nullifier, SAPLING_NULLIFIER)) // Prevent double spends + return false; + } + return true; } diff --git a/src/coins.h b/src/coins.h index 4e1cac438..e2b454649 100644 --- a/src/coins.h +++ b/src/coins.h @@ -298,6 +298,12 @@ struct CNullifiersCacheEntry CNullifiersCacheEntry() : entered(false), flags(0) {} }; +enum NullifierType +{ + SPROUT_NULLIFIER, + SAPLING_NULLIFIER, +}; + typedef boost::unordered_map CCoinsMap; typedef boost::unordered_map CAnchorsMap; typedef boost::unordered_map CNullifiersMap; @@ -324,7 +330,7 @@ public: virtual bool GetAnchorAt(const uint256 &rt, ZCIncrementalMerkleTree &tree) const; //! Determine whether a nullifier is spent or not - virtual bool GetNullifier(const uint256 &nullifier) const; + virtual bool GetNullifier(const uint256 &nullifier, NullifierType type) const; //! Retrieve the CCoins (unspent transaction outputs) for a given txid virtual bool GetCoins(const uint256 &txid, CCoins &coins) const; @@ -345,7 +351,8 @@ public: const uint256 &hashBlock, const uint256 &hashAnchor, CAnchorsMap &mapAnchors, - CNullifiersMap &mapNullifiers); + CNullifiersMap &mapSproutNullifiers, + CNullifiersMap &mapSaplingNullifiers); //! Calculate statistics about the unspent transaction output set virtual bool GetStats(CCoinsStats &stats) const; @@ -364,7 +371,7 @@ protected: public: CCoinsViewBacked(CCoinsView *viewIn); bool GetAnchorAt(const uint256 &rt, ZCIncrementalMerkleTree &tree) const; - bool GetNullifier(const uint256 &nullifier) const; + bool GetNullifier(const uint256 &nullifier, NullifierType type) const; bool GetCoins(const uint256 &txid, CCoins &coins) const; bool HaveCoins(const uint256 &txid) const; uint256 GetBestBlock() const; @@ -374,7 +381,8 @@ public: const uint256 &hashBlock, const uint256 &hashAnchor, CAnchorsMap &mapAnchors, - CNullifiersMap &mapNullifiers); + CNullifiersMap &mapSproutNullifiers, + CNullifiersMap &mapSaplingNullifiers); bool GetStats(CCoinsStats &stats) const; }; @@ -417,7 +425,8 @@ protected: mutable CCoinsMap cacheCoins; mutable uint256 hashAnchor; mutable CAnchorsMap cacheAnchors; - mutable CNullifiersMap cacheNullifiers; + mutable CNullifiersMap cacheSproutNullifiers; + mutable CNullifiersMap cacheSaplingNullifiers; /* Cached dynamic memory usage for the inner CCoins objects. */ mutable size_t cachedCoinsUsage; @@ -428,7 +437,7 @@ public: // Standard CCoinsView methods bool GetAnchorAt(const uint256 &rt, ZCIncrementalMerkleTree &tree) const; - bool GetNullifier(const uint256 &nullifier) const; + bool GetNullifier(const uint256 &nullifier, NullifierType type) const; bool GetCoins(const uint256 &txid, CCoins &coins) const; bool HaveCoins(const uint256 &txid) const; uint256 GetBestBlock() const; @@ -438,7 +447,8 @@ public: const uint256 &hashBlock, const uint256 &hashAnchor, CAnchorsMap &mapAnchors, - CNullifiersMap &mapNullifiers); + CNullifiersMap &mapSproutNullifiers, + CNullifiersMap &mapSaplingNullifiers); // Adds the tree to mapAnchors and sets the current commitment @@ -449,8 +459,8 @@ public: // the new current root. void PopAnchor(const uint256 &rt); - // Marks a nullifier as spent or not. - void SetNullifier(const uint256 &nullifier, bool spent); + // Marks nullifiers for a given transaction as spent or not. + void SetNullifiers(const CTransaction& tx, bool spent); /** * Return a pointer to CCoins in the cache, or NULL if not found. This is diff --git a/src/gtest/test_mempool.cpp b/src/gtest/test_mempool.cpp index 6ee3eb1b1..981c4eb08 100644 --- a/src/gtest/test_mempool.cpp +++ b/src/gtest/test_mempool.cpp @@ -23,7 +23,7 @@ public: return false; } - bool GetNullifier(const uint256 &nf) const { + bool GetNullifier(const uint256 &nf, NullifierType type) const { return false; } @@ -56,7 +56,8 @@ public: const uint256 &hashBlock, const uint256 &hashAnchor, CAnchorsMap &mapAnchors, - CNullifiersMap &mapNullifiers) { + CNullifiersMap &mapSproutNullifiers, + CNullifiersMap &mapSaplingNullifiers) { return false; } diff --git a/src/gtest/test_validation.cpp b/src/gtest/test_validation.cpp index 4cd3eacca..710e3c600 100644 --- a/src/gtest/test_validation.cpp +++ b/src/gtest/test_validation.cpp @@ -25,7 +25,7 @@ public: return false; } - bool GetNullifier(const uint256 &nf) const { + bool GetNullifier(const uint256 &nf, NullifierType type) const { return false; } @@ -51,7 +51,8 @@ public: const uint256 &hashBlock, const uint256 &hashAnchor, CAnchorsMap &mapAnchors, - CNullifiersMap &mapNullifiers) { + CNullifiersMap &mapSproutNullifiers, + CNullifiersMap saplingNullifiersMap) { return false; } diff --git a/src/main.cpp b/src/main.cpp index fb6196c70..634dc177b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1040,11 +1040,11 @@ bool CheckTransactionWithoutProofVerification(const CTransaction& tx, CValidatio } // Transactions can contain empty `vin` and `vout` so long as - // `vjoinsplit` is non-empty. - if (tx.vin.empty() && tx.vjoinsplit.empty()) + // either `vjoinsplit` or `vShieldedSpend` are non-empty. + if (tx.vin.empty() && tx.vjoinsplit.empty() && tx.vShieldedSpend.empty()) return state.DoS(10, error("CheckTransaction(): vin empty"), REJECT_INVALID, "bad-txns-vin-empty"); - if (tx.vout.empty() && tx.vjoinsplit.empty()) + if (tx.vout.empty() && tx.vjoinsplit.empty() && tx.vShieldedSpend.empty()) return state.DoS(10, error("CheckTransaction(): vout empty"), REJECT_INVALID, "bad-txns-vout-empty"); @@ -1134,16 +1134,31 @@ bool CheckTransactionWithoutProofVerification(const CTransaction& tx, CValidatio } // Check for duplicate joinsplit nullifiers in this transaction - set vJoinSplitNullifiers; - BOOST_FOREACH(const JSDescription& joinsplit, tx.vjoinsplit) { - BOOST_FOREACH(const uint256& nf, joinsplit.nullifiers) + set vJoinSplitNullifiers; + BOOST_FOREACH(const JSDescription& joinsplit, tx.vjoinsplit) { - if (vJoinSplitNullifiers.count(nf)) - return state.DoS(100, error("CheckTransaction(): duplicate nullifiers"), - REJECT_INVALID, "bad-joinsplits-nullifiers-duplicate"); + BOOST_FOREACH(const uint256& nf, joinsplit.nullifiers) + { + if (vJoinSplitNullifiers.count(nf)) + return state.DoS(100, error("CheckTransaction(): duplicate nullifiers"), + REJECT_INVALID, "bad-joinsplits-nullifiers-duplicate"); - vJoinSplitNullifiers.insert(nf); + vJoinSplitNullifiers.insert(nf); + } + } + } + + // Check for duplicate sapling nullifiers in this transaction + { + set vSaplingNullifiers; + BOOST_FOREACH(const SpendDescription& spend_desc, tx.vShieldedSpend) + { + if (vSaplingNullifiers.count(spend_desc.nullifier)) + return state.DoS(100, error("CheckTransaction(): duplicate nullifiers"), + REJECT_INVALID, "bad-spend-description-nullifiers-duplicate"); + + vSaplingNullifiers.insert(spend_desc.nullifier); } } @@ -1270,12 +1285,16 @@ bool AcceptToMemoryPool(CTxMemPool& pool, CValidationState &state, const CTransa } BOOST_FOREACH(const JSDescription &joinsplit, tx.vjoinsplit) { BOOST_FOREACH(const uint256 &nf, joinsplit.nullifiers) { - if (pool.mapNullifiers.count(nf)) - { + if (pool.nullifierExists(nf, SPROUT_NULLIFIER)) { return false; } } } + for (const SpendDescription &spendDescription : tx.vShieldedSpend) { + if (pool.nullifierExists(spendDescription.nullifier, SAPLING_NULLIFIER)) { + return false; + } + } } { @@ -1775,11 +1794,7 @@ void UpdateCoins(const CTransaction& tx, CCoinsViewCache& inputs, CTxUndo &txund } // spend nullifiers - BOOST_FOREACH(const JSDescription &joinsplit, tx.vjoinsplit) { - BOOST_FOREACH(const uint256 &nf, joinsplit.nullifiers) { - inputs.SetNullifier(nf, true); - } - } + inputs.SetNullifiers(tx, true); // add outputs inputs.ModifyCoins(tx.GetHash())->FromTx(tx, nHeight); @@ -2097,11 +2112,7 @@ bool DisconnectBlock(CBlock& block, CValidationState& state, CBlockIndex* pindex } // unspend nullifiers - BOOST_FOREACH(const JSDescription &joinsplit, tx.vjoinsplit) { - BOOST_FOREACH(const uint256 &nf, joinsplit.nullifiers) { - view.SetNullifier(nf, false); - } - } + view.SetNullifiers(tx, false); // restore inputs if (i > 0) { // not coinbases diff --git a/src/test/coins_tests.cpp b/src/test/coins_tests.cpp index 383616ae0..8225fce5d 100644 --- a/src/test/coins_tests.cpp +++ b/src/test/coins_tests.cpp @@ -11,6 +11,7 @@ #include "consensus/validation.h" #include "main.h" #include "undo.h" +#include "primitives/transaction.h" #include "pubkey.h" #include @@ -27,7 +28,8 @@ class CCoinsViewTest : public CCoinsView uint256 hashBestAnchor_; std::map map_; std::map mapAnchors_; - std::map mapNullifiers_; + std::map mapSproutNullifiers_; + std::map mapSaplingNullifiers_; public: CCoinsViewTest() { @@ -50,11 +52,21 @@ public: } } - bool GetNullifier(const uint256 &nf) const + bool GetNullifier(const uint256 &nf, NullifierType type) const { - std::map::const_iterator it = mapNullifiers_.find(nf); - - if (it == mapNullifiers_.end()) { + const std::map* mapToUse; + switch (type) { + case SPROUT_NULLIFIER: + mapToUse = &mapSproutNullifiers_; + break; + case SAPLING_NULLIFIER: + mapToUse = &mapSaplingNullifiers_; + break; + default: + throw std::runtime_error("Unknown nullifier type " + type); + } + std::map::const_iterator it = mapToUse->find(nf); + if (it == mapToUse->end()) { return false; } else { // The map shouldn't contain any false entries. @@ -87,11 +99,25 @@ public: uint256 GetBestBlock() const { return hashBestBlock_; } + void BatchWriteNullifiers(CNullifiersMap& mapNullifiers, std::map& cacheNullifiers) + { + for (CNullifiersMap::iterator it = mapNullifiers.begin(); it != mapNullifiers.end(); ) { + if (it->second.entered) { + cacheNullifiers[it->first] = true; + } else { + cacheNullifiers.erase(it->first); + } + mapNullifiers.erase(it++); + } + mapNullifiers.clear(); + } + bool BatchWrite(CCoinsMap& mapCoins, const uint256& hashBlock, const uint256& hashAnchor, CAnchorsMap& mapAnchors, - CNullifiersMap& mapNullifiers) + CNullifiersMap& mapSproutNullifiers, + CNullifiersMap& mapSaplingNullifiers) { for (CCoinsMap::iterator it = mapCoins.begin(); it != mapCoins.end(); ) { map_[it->first] = it->second.coins; @@ -112,17 +138,12 @@ public: } mapAnchors.erase(it++); } - for (CNullifiersMap::iterator it = mapNullifiers.begin(); it != mapNullifiers.end(); ) { - if (it->second.entered) { - mapNullifiers_[it->first] = true; - } else { - mapNullifiers_.erase(it->first); - } - mapNullifiers.erase(it++); - } + + BatchWriteNullifiers(mapSproutNullifiers, mapSproutNullifiers_); + BatchWriteNullifiers(mapSaplingNullifiers, mapSaplingNullifiers_); + mapCoins.clear(); mapAnchors.clear(); - mapNullifiers.clear(); hashBestBlock_ = hashBlock; hashBestAnchor_ = hashAnchor; return true; @@ -141,7 +162,8 @@ public: // Manually recompute the dynamic usage of the whole data, and compare it. size_t ret = memusage::DynamicUsage(cacheCoins) + memusage::DynamicUsage(cacheAnchors) + - memusage::DynamicUsage(cacheNullifiers); + memusage::DynamicUsage(cacheSproutNullifiers) + + memusage::DynamicUsage(cacheSaplingNullifiers); for (CCoinsMap::iterator it = cacheCoins.begin(); it != cacheCoins.end(); it++) { ret += it->second.coins.DynamicMemoryUsage(); } @@ -150,6 +172,31 @@ public: }; +class TxWithNullifiers +{ +public: + CTransaction tx; + uint256 sproutNullifier; + uint256 saplingNullifier; + + TxWithNullifiers() + { + CMutableTransaction mutableTx; + + sproutNullifier = GetRandHash(); + JSDescription jsd; + jsd.nullifiers[0] = sproutNullifier; + mutableTx.vjoinsplit.emplace_back(jsd); + + saplingNullifier = GetRandHash(); + SpendDescription sd; + sd.nullifier = saplingNullifier; + mutableTx.vShieldedSpend.push_back(sd); + + tx = CTransaction(mutableTx); + } +}; + } uint256 appendRandomCommitment(ZCIncrementalMerkleTree &tree) @@ -166,6 +213,17 @@ uint256 appendRandomCommitment(ZCIncrementalMerkleTree &tree) BOOST_FIXTURE_TEST_SUITE(coins_tests, BasicTestingSetup) +void checkNullifierCache(const CCoinsViewCacheTest &cache, const TxWithNullifiers &txWithNullifiers, bool shouldBeInCache) { + // Make sure the nullifiers have not gotten mixed up + BOOST_CHECK(!cache.GetNullifier(txWithNullifiers.sproutNullifier, SAPLING_NULLIFIER)); + BOOST_CHECK(!cache.GetNullifier(txWithNullifiers.saplingNullifier, SPROUT_NULLIFIER)); + // Check if the nullifiers either are or are not in the cache + bool containsSproutNullifier = cache.GetNullifier(txWithNullifiers.sproutNullifier, SPROUT_NULLIFIER); + bool containsSaplingNullifier = cache.GetNullifier(txWithNullifiers.saplingNullifier, SAPLING_NULLIFIER); + BOOST_CHECK(containsSproutNullifier == shouldBeInCache); + BOOST_CHECK(containsSaplingNullifier == shouldBeInCache); +} + BOOST_AUTO_TEST_CASE(nullifier_regression_test) { // Correct behavior: @@ -173,16 +231,18 @@ BOOST_AUTO_TEST_CASE(nullifier_regression_test) CCoinsViewTest base; CCoinsViewCacheTest cache1(&base); + TxWithNullifiers txWithNullifiers; + // Insert a nullifier into the base. - uint256 nf = GetRandHash(); - cache1.SetNullifier(nf, true); + cache1.SetNullifiers(txWithNullifiers.tx, true); + checkNullifierCache(cache1, txWithNullifiers, true); cache1.Flush(); // Flush to base. // Remove the nullifier from cache - cache1.SetNullifier(nf, false); + cache1.SetNullifiers(txWithNullifiers.tx, false); // The nullifier now should be `false`. - BOOST_CHECK(!cache1.GetNullifier(nf)); + checkNullifierCache(cache1, txWithNullifiers, false); } // Also correct behavior: @@ -190,17 +250,19 @@ BOOST_AUTO_TEST_CASE(nullifier_regression_test) CCoinsViewTest base; CCoinsViewCacheTest cache1(&base); + TxWithNullifiers txWithNullifiers; + // Insert a nullifier into the base. - uint256 nf = GetRandHash(); - cache1.SetNullifier(nf, true); + cache1.SetNullifiers(txWithNullifiers.tx, true); + checkNullifierCache(cache1, txWithNullifiers, true); cache1.Flush(); // Flush to base. // Remove the nullifier from cache - cache1.SetNullifier(nf, false); + cache1.SetNullifiers(txWithNullifiers.tx, false); cache1.Flush(); // Flush to base. // The nullifier now should be `false`. - BOOST_CHECK(!cache1.GetNullifier(nf)); + checkNullifierCache(cache1, txWithNullifiers, false); } // Works because we bring it from the parent cache: @@ -209,21 +271,22 @@ BOOST_AUTO_TEST_CASE(nullifier_regression_test) CCoinsViewCacheTest cache1(&base); // Insert a nullifier into the base. - uint256 nf = GetRandHash(); - cache1.SetNullifier(nf, true); + TxWithNullifiers txWithNullifiers; + cache1.SetNullifiers(txWithNullifiers.tx, true); + checkNullifierCache(cache1, txWithNullifiers, true); cache1.Flush(); // Empties cache. // Create cache on top. { // Remove the nullifier. CCoinsViewCacheTest cache2(&cache1); - BOOST_CHECK(cache2.GetNullifier(nf)); - cache2.SetNullifier(nf, false); + checkNullifierCache(cache2, txWithNullifiers, true); + cache1.SetNullifiers(txWithNullifiers.tx, false); cache2.Flush(); // Empties cache, flushes to cache1. } // The nullifier now should be `false`. - BOOST_CHECK(!cache1.GetNullifier(nf)); + checkNullifierCache(cache1, txWithNullifiers, false); } // Was broken: @@ -232,20 +295,20 @@ BOOST_AUTO_TEST_CASE(nullifier_regression_test) CCoinsViewCacheTest cache1(&base); // Insert a nullifier into the base. - uint256 nf = GetRandHash(); - cache1.SetNullifier(nf, true); + TxWithNullifiers txWithNullifiers; + cache1.SetNullifiers(txWithNullifiers.tx, true); cache1.Flush(); // Empties cache. // Create cache on top. { // Remove the nullifier. CCoinsViewCacheTest cache2(&cache1); - cache2.SetNullifier(nf, false); + cache2.SetNullifiers(txWithNullifiers.tx, false); cache2.Flush(); // Empties cache, flushes to cache1. } // The nullifier now should be `false`. - BOOST_CHECK(!cache1.GetNullifier(nf)); + checkNullifierCache(cache1, txWithNullifiers, false); } } @@ -414,23 +477,22 @@ BOOST_AUTO_TEST_CASE(nullifiers_test) CCoinsViewTest base; CCoinsViewCacheTest cache(&base); - uint256 nf = GetRandHash(); - - BOOST_CHECK(!cache.GetNullifier(nf)); - cache.SetNullifier(nf, true); - BOOST_CHECK(cache.GetNullifier(nf)); + TxWithNullifiers txWithNullifiers; + checkNullifierCache(cache, txWithNullifiers, false); + cache.SetNullifiers(txWithNullifiers.tx, true); + checkNullifierCache(cache, txWithNullifiers, true); cache.Flush(); CCoinsViewCacheTest cache2(&base); - BOOST_CHECK(cache2.GetNullifier(nf)); - cache2.SetNullifier(nf, false); - BOOST_CHECK(!cache2.GetNullifier(nf)); + checkNullifierCache(cache2, txWithNullifiers, true); + cache2.SetNullifiers(txWithNullifiers.tx, false); + checkNullifierCache(cache2, txWithNullifiers, false); cache2.Flush(); CCoinsViewCacheTest cache3(&base); - BOOST_CHECK(!cache3.GetNullifier(nf)); + checkNullifierCache(cache3, txWithNullifiers, false); } BOOST_AUTO_TEST_CASE(anchors_flush_test) diff --git a/src/test/transaction_tests.cpp b/src/test/transaction_tests.cpp index 1b0b26ee7..8524e1bc2 100644 --- a/src/test/transaction_tests.cpp +++ b/src/test/transaction_tests.cpp @@ -399,6 +399,34 @@ BOOST_AUTO_TEST_CASE(test_basic_joinsplit_verification) } } +void test_simple_sapling_invalidity(uint32_t consensusBranchId, CMutableTransaction tx) +{ + { + CMutableTransaction newTx(tx); + CValidationState state; + + BOOST_CHECK(!CheckTransactionWithoutProofVerification(newTx, state)); + BOOST_CHECK(state.GetRejectReason() == "bad-txns-vin-empty"); + } + { + // Ensure that nullifiers are never duplicated within a transaction. + CMutableTransaction newTx(tx); + CValidationState state; + + newTx.vShieldedSpend.push_back(SpendDescription()); + newTx.vShieldedSpend[0].nullifier = GetRandHash(); + newTx.vShieldedSpend.push_back(SpendDescription()); + newTx.vShieldedSpend[1].nullifier = newTx.vShieldedSpend[0].nullifier; + + BOOST_CHECK(!CheckTransactionWithoutProofVerification(newTx, state)); + BOOST_CHECK(state.GetRejectReason() == "bad-spend-description-nullifiers-duplicate"); + + newTx.vShieldedSpend[1].nullifier = GetRandHash(); + + BOOST_CHECK(CheckTransactionWithoutProofVerification(newTx, state)); + } +} + void test_simple_joinsplit_invalidity(uint32_t consensusBranchId, CMutableTransaction tx) { auto verifier = libzcash::ProofVerifier::Strict(); @@ -548,6 +576,14 @@ BOOST_AUTO_TEST_CASE(test_simple_joinsplit_invalidity_driver) { test_simple_joinsplit_invalidity(NetworkUpgradeInfo[Consensus::UPGRADE_OVERWINTER].nBranchId, mtx); UpdateNetworkUpgradeParameters(Consensus::UPGRADE_OVERWINTER, Consensus::NetworkUpgrade::NO_ACTIVATION_HEIGHT); + // Test Sapling things + mtx.nVersionGroupId = SAPLING_VERSION_GROUP_ID; + mtx.nVersion = SAPLING_TX_VERSION; + + UpdateNetworkUpgradeParameters(Consensus::UPGRADE_SAPLING, Consensus::NetworkUpgrade::ALWAYS_ACTIVE); + test_simple_sapling_invalidity(NetworkUpgradeInfo[Consensus::UPGRADE_SAPLING].nBranchId, mtx); + UpdateNetworkUpgradeParameters(Consensus::UPGRADE_SAPLING, Consensus::NetworkUpgrade::NO_ACTIVATION_HEIGHT); + // Switch back to mainnet parameters as originally selected in test fixture SelectParams(CBaseChainParams::MAIN); } diff --git a/src/txdb.cpp b/src/txdb.cpp index 24ab0b4b5..4759ff106 100644 --- a/src/txdb.cpp +++ b/src/txdb.cpp @@ -19,6 +19,7 @@ using namespace std; static const char DB_ANCHOR = 'A'; static const char DB_NULLIFIER = 's'; +static const char DB_SAPLING_NULLIFIER = 'S'; static const char DB_COINS = 'c'; static const char DB_BLOCK_FILES = 'f'; static const char DB_TXINDEX = 't'; @@ -51,11 +52,20 @@ bool CCoinsViewDB::GetAnchorAt(const uint256 &rt, ZCIncrementalMerkleTree &tree) return read; } -bool CCoinsViewDB::GetNullifier(const uint256 &nf) const { +bool CCoinsViewDB::GetNullifier(const uint256 &nf, NullifierType type) const { bool spent = false; - bool read = db.Read(make_pair(DB_NULLIFIER, nf), spent); - - return read; + char dbChar; + switch (type) { + case SPROUT_NULLIFIER: + dbChar = DB_NULLIFIER; + break; + case SAPLING_NULLIFIER: + dbChar = DB_SAPLING_NULLIFIER; + break; + default: + throw runtime_error("Unknown nullifier type " + type); + } + return db.Read(make_pair(dbChar, nf), spent); } bool CCoinsViewDB::GetCoins(const uint256 &txid, CCoins &coins) const { @@ -80,11 +90,27 @@ uint256 CCoinsViewDB::GetBestAnchor() const { return hashBestAnchor; } +void BatchWriteNullifiers(CDBBatch& batch, CNullifiersMap& mapToUse, const char& dbChar) +{ + for (CNullifiersMap::iterator it = mapToUse.begin(); it != mapToUse.end();) { + if (it->second.flags & CNullifiersCacheEntry::DIRTY) { + if (!it->second.entered) + batch.Erase(make_pair(dbChar, it->first)); + else + batch.Write(make_pair(dbChar, it->first), true); + // TODO: changed++? ... See comment in CCoinsViewDB::BatchWrite. If this is needed we could return an int + } + CNullifiersMap::iterator itOld = it++; + mapToUse.erase(itOld); + } +} + bool CCoinsViewDB::BatchWrite(CCoinsMap &mapCoins, const uint256 &hashBlock, const uint256 &hashAnchor, CAnchorsMap &mapAnchors, - CNullifiersMap &mapNullifiers) { + CNullifiersMap &mapSproutNullifiers, + CNullifiersMap &mapSaplingNullifiers) { CDBBatch batch(db); size_t count = 0; size_t changed = 0; @@ -114,17 +140,8 @@ bool CCoinsViewDB::BatchWrite(CCoinsMap &mapCoins, mapAnchors.erase(itOld); } - for (CNullifiersMap::iterator it = mapNullifiers.begin(); it != mapNullifiers.end();) { - if (it->second.flags & CNullifiersCacheEntry::DIRTY) { - if (!it->second.entered) - batch.Erase(make_pair(DB_NULLIFIER, it->first)); - else - batch.Write(make_pair(DB_NULLIFIER, it->first), true); - // TODO: changed++? - } - CNullifiersMap::iterator itOld = it++; - mapNullifiers.erase(itOld); - } + ::BatchWriteNullifiers(batch, mapSproutNullifiers, DB_NULLIFIER); + ::BatchWriteNullifiers(batch, mapSaplingNullifiers, DB_SAPLING_NULLIFIER); if (!hashBlock.IsNull()) batch.Write(DB_BEST_BLOCK, hashBlock); diff --git a/src/txdb.h b/src/txdb.h index f96b07676..53f3c31c6 100644 --- a/src/txdb.h +++ b/src/txdb.h @@ -36,7 +36,7 @@ public: CCoinsViewDB(size_t nCacheSize, bool fMemory = false, bool fWipe = false); bool GetAnchorAt(const uint256 &rt, ZCIncrementalMerkleTree &tree) const; - bool GetNullifier(const uint256 &nf) const; + bool GetNullifier(const uint256 &nf, NullifierType type) const; bool GetCoins(const uint256 &txid, CCoins &coins) const; bool HaveCoins(const uint256 &txid) const; uint256 GetBestBlock() const; @@ -45,7 +45,8 @@ public: const uint256 &hashBlock, const uint256 &hashAnchor, CAnchorsMap &mapAnchors, - CNullifiersMap &mapNullifiers); + CNullifiersMap &mapSproutNullifiers, + CNullifiersMap &mapSaplingNullifiers); bool GetStats(CCoinsStats &stats) const; }; diff --git a/src/txmempool.cpp b/src/txmempool.cpp index 650239e89..9480d81cd 100644 --- a/src/txmempool.cpp +++ b/src/txmempool.cpp @@ -107,9 +107,12 @@ bool CTxMemPool::addUnchecked(const uint256& hash, const CTxMemPoolEntry &entry, mapNextTx[tx.vin[i].prevout] = CInPoint(&tx, i); BOOST_FOREACH(const JSDescription &joinsplit, tx.vjoinsplit) { BOOST_FOREACH(const uint256 &nf, joinsplit.nullifiers) { - mapNullifiers[nf] = &tx; + mapSproutNullifiers[nf] = &tx; } } + for (const SpendDescription &spendDescription : tx.vShieldedSpend) { + mapSaplingNullifiers[spendDescription.nullifier] = &tx; + } nTransactionsUpdated++; totalTxSize += entry.GetTxSize(); cachedInnerUsage += entry.DynamicMemoryUsage(); @@ -157,10 +160,12 @@ void CTxMemPool::remove(const CTransaction &origTx, std::list& rem mapNextTx.erase(txin.prevout); BOOST_FOREACH(const JSDescription& joinsplit, tx.vjoinsplit) { BOOST_FOREACH(const uint256& nf, joinsplit.nullifiers) { - mapNullifiers.erase(nf); + mapSproutNullifiers.erase(nf); } } - + for (const SpendDescription &spendDescription : tx.vShieldedSpend) { + mapSaplingNullifiers.erase(spendDescription.nullifier); + } removed.push_back(tx); totalTxSize -= mapTx.find(hash)->GetTxSize(); cachedInnerUsage -= mapTx.find(hash)->DynamicMemoryUsage(); @@ -244,16 +249,24 @@ void CTxMemPool::removeConflicts(const CTransaction &tx, std::list BOOST_FOREACH(const JSDescription &joinsplit, tx.vjoinsplit) { BOOST_FOREACH(const uint256 &nf, joinsplit.nullifiers) { - std::map::iterator it = mapNullifiers.find(nf); - if (it != mapNullifiers.end()) { + std::map::iterator it = mapSproutNullifiers.find(nf); + if (it != mapSproutNullifiers.end()) { const CTransaction &txConflict = *it->second; - if (txConflict != tx) - { + if (txConflict != tx) { remove(txConflict, removed, true); } } } } + for (const SpendDescription &spendDescription : tx.vShieldedSpend) { + std::map::iterator it = mapSaplingNullifiers.find(spendDescription.nullifier); + if (it != mapSaplingNullifiers.end()) { + const CTransaction &txConflict = *it->second; + if (txConflict != tx) { + remove(txConflict, removed, true); + } + } + } } void CTxMemPool::removeExpired(unsigned int nBlockHeight) @@ -381,7 +394,7 @@ void CTxMemPool::check(const CCoinsViewCache *pcoins) const BOOST_FOREACH(const JSDescription &joinsplit, tx.vjoinsplit) { BOOST_FOREACH(const uint256 &nf, joinsplit.nullifiers) { - assert(!pcoins->GetNullifier(nf)); + assert(!pcoins->GetNullifier(nf, SPROUT_NULLIFIER)); } ZCIncrementalMerkleTree tree; @@ -399,6 +412,9 @@ void CTxMemPool::check(const CCoinsViewCache *pcoins) const intermediates.insert(std::make_pair(tree.root(), tree)); } + for (const SpendDescription &spendDescription : tx.vShieldedSpend) { + assert(!pcoins->GetNullifier(spendDescription.nullifier, SAPLING_NULLIFIER)); + } if (fDependsWait) waitingOnDependants.push_back(&(*it)); else { @@ -436,18 +452,35 @@ void CTxMemPool::check(const CCoinsViewCache *pcoins) const assert(it->first == it->second.ptx->vin[it->second.n].prevout); } - for (std::map::const_iterator it = mapNullifiers.begin(); it != mapNullifiers.end(); it++) { - uint256 hash = it->second->GetHash(); - indexed_transaction_set::const_iterator it2 = mapTx.find(hash); - const CTransaction& tx = it2->GetTx(); - assert(it2 != mapTx.end()); - assert(&tx == it->second); - } + checkNullifiers(SPROUT_NULLIFIER); + checkNullifiers(SAPLING_NULLIFIER); assert(totalTxSize == checkTotal); assert(innerUsage == cachedInnerUsage); } +void CTxMemPool::checkNullifiers(NullifierType type) const +{ + const std::map* mapToUse; + switch (type) { + case SPROUT_NULLIFIER: + mapToUse = &mapSproutNullifiers; + break; + case SAPLING_NULLIFIER: + mapToUse = &mapSaplingNullifiers; + break; + default: + throw runtime_error("Unknown nullifier type " + type); + } + for (const auto& entry : *mapToUse) { + uint256 hash = entry.second->GetHash(); + CTxMemPool::indexed_transaction_set::const_iterator findTx = mapTx.find(hash); + const CTransaction& tx = findTx->GetTx(); + assert(findTx != mapTx.end()); + assert(&tx == entry.second); + } +} + void CTxMemPool::queryHashes(vector& vtxid) { vtxid.clear(); @@ -549,13 +582,23 @@ bool CTxMemPool::HasNoInputsOf(const CTransaction &tx) const return true; } +bool CTxMemPool::nullifierExists(const uint256& nullifier, NullifierType type) const +{ + switch (type) { + case SPROUT_NULLIFIER: + return mapSproutNullifiers.count(nullifier); + case SAPLING_NULLIFIER: + return mapSaplingNullifiers.count(nullifier); + default: + throw runtime_error("Unknown nullifier type " + type); + } +} + CCoinsViewMemPool::CCoinsViewMemPool(CCoinsView *baseIn, CTxMemPool &mempoolIn) : CCoinsViewBacked(baseIn), mempool(mempoolIn) { } -bool CCoinsViewMemPool::GetNullifier(const uint256 &nf) const { - if (mempool.mapNullifiers.count(nf)) - return true; - - return base->GetNullifier(nf); +bool CCoinsViewMemPool::GetNullifier(const uint256 &nf, NullifierType type) const +{ + return mempool.nullifierExists(nf, type) || base->GetNullifier(nf, type); } bool CCoinsViewMemPool::GetCoins(const uint256 &txid, CCoins &coins) const { diff --git a/src/txmempool.h b/src/txmempool.h index fd8758741..0397c6d7f 100644 --- a/src/txmempool.h +++ b/src/txmempool.h @@ -131,6 +131,11 @@ private: uint64_t totalTxSize = 0; //! sum of all mempool tx' byte sizes uint64_t cachedInnerUsage; //! sum of dynamic memory usage of all the map elements (NOT the maps themselves) + std::map mapSproutNullifiers; + std::map mapSaplingNullifiers; + + void checkNullifiers(NullifierType type) const; + public: typedef boost::multi_index_container< CTxMemPoolEntry, @@ -148,7 +153,6 @@ public: mutable CCriticalSection cs; indexed_transaction_set mapTx; std::map mapNextTx; - std::map mapNullifiers; std::map > mapDeltas; CTxMemPool(const CFeeRate& _minRelayFee); @@ -188,6 +192,8 @@ public: void ApplyDeltas(const uint256 hash, double &dPriorityDelta, CAmount &nFeeDelta); void ClearPrioritisation(const uint256 hash); + bool nullifierExists(const uint256& nullifier, NullifierType type) const; + unsigned long size() { LOCK(cs); @@ -237,7 +243,7 @@ protected: public: CCoinsViewMemPool(CCoinsView *baseIn, CTxMemPool &mempoolIn); - bool GetNullifier(const uint256 &txid) const; + bool GetNullifier(const uint256 &txid, NullifierType type) const; bool GetCoins(const uint256 &txid, CCoins &coins) const; bool HaveCoins(const uint256 &txid) const; }; diff --git a/src/zcbenchmarks.cpp b/src/zcbenchmarks.cpp index c268e6e12..1d4ad78c8 100644 --- a/src/zcbenchmarks.cpp +++ b/src/zcbenchmarks.cpp @@ -366,7 +366,7 @@ public: return false; } - bool GetNullifier(const uint256 &nf) const { + bool GetNullifier(const uint256 &nf, NullifierType type) const { return false; } @@ -382,7 +382,8 @@ public: const uint256 &hashBlock, const uint256 &hashAnchor, CAnchorsMap &mapAnchors, - CNullifiersMap &mapNullifiers) { + CNullifiersMap &mapSproutNullifiers, + CNullifiersMap& mapSaplingNullifiers) { return false; }