diff --git a/qa/pull-tester/rpc-tests.sh b/qa/pull-tester/rpc-tests.sh index 773bc4cc0..59455129d 100755 --- a/qa/pull-tester/rpc-tests.sh +++ b/qa/pull-tester/rpc-tests.sh @@ -29,6 +29,7 @@ testScripts=( 'mempool_spendcoinbase.py' 'mempool_reorg.py' 'mempool_tx_input_limit.py' + 'mempool_nu_activation.py' 'httpbasics.py' 'zapwallettxes.py' 'proxy_test.py' diff --git a/qa/rpc-tests/mempool_nu_activation.py b/qa/rpc-tests/mempool_nu_activation.py new file mode 100755 index 000000000..f54095660 --- /dev/null +++ b/qa/rpc-tests/mempool_nu_activation.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python2 +# Copyright (c) 2018 The Zcash developers +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. + +from test_framework.test_framework import BitcoinTestFramework +from test_framework.util import assert_equal, initialize_chain_clean, \ + start_node, connect_nodes, wait_and_assert_operationid_status + +from decimal import Decimal + +# Test mempool behaviour around network upgrade activation +class MempoolUpgradeActivationTest(BitcoinTestFramework): + + alert_filename = None # Set by setup_network + + def setup_network(self): + args = ["-checkmempool", "-debug=mempool", "-blockmaxsize=4000", "-nuparams=5ba81b19:200"] + self.nodes = [] + self.nodes.append(start_node(0, self.options.tmpdir, args)) + self.nodes.append(start_node(1, self.options.tmpdir, args)) + connect_nodes(self.nodes[1], 0) + self.is_network_split = False + self.sync_all + + def setup_chain(self): + print "Initializing test directory "+self.options.tmpdir + initialize_chain_clean(self.options.tmpdir, 2) + + def run_test(self): + self.nodes[1].generate(100) + self.sync_all() + + # Mine 97 blocks. After this, nodes[1] blocks + # 1 to 97 are spend-able. + self.nodes[0].generate(97) + self.sync_all() + + # Shield some ZEC + node1_taddr = self.nodes[1].getnewaddress() + node0_zaddr = self.nodes[0].z_getnewaddress() + recipients = [{'address': node0_zaddr, 'amount': Decimal('10')}] + myopid = self.nodes[1].z_sendmany(node1_taddr, recipients, 1, Decimal('0')) + print wait_and_assert_operationid_status(self.nodes[1], myopid) + self.sync_all() + + # Mine block 198. After this, the mempool expects + # block 199, which is the last Sprout block. + self.nodes[0].generate(1) + self.sync_all() + + # Mempool should be empty. + assert_equal(set(self.nodes[0].getrawmempool()), set()) + + # Check node 0 shielded balance + assert_equal(self.nodes[0].z_getbalance(node0_zaddr), Decimal('10')) + + # Fill the mempool with twice as many transactions as can fit into blocks + node0_taddr = self.nodes[0].getnewaddress() + sprout_txids = [] + while self.nodes[1].getmempoolinfo()['bytes'] < 2 * 4000: + sprout_txids.append(self.nodes[1].sendtoaddress(node0_taddr, Decimal('0.001'))) + self.sync_all() + + # Spends should be in the mempool + sprout_mempool = set(self.nodes[0].getrawmempool()) + assert_equal(sprout_mempool, set(sprout_txids)) + + # Mine block 199. After this, the mempool expects + # block 200, which is the first Overwinter block. + self.nodes[0].generate(1) + self.sync_all() + + # mempool should be empty. + assert_equal(set(self.nodes[0].getrawmempool()), set()) + + # Block 199 should contain a subset of the original mempool + # (with all other transactions having been dropped) + block_txids = self.nodes[0].getblock(self.nodes[0].getbestblockhash())['tx'] + assert(len(block_txids) < len(sprout_txids)) + for txid in block_txids[1:]: # Exclude coinbase + assert(txid in sprout_txids) + + # Create some transparent Overwinter transactions + overwinter_txids = [self.nodes[1].sendtoaddress(node0_taddr, Decimal('0.001')) for i in range(10)] + self.sync_all() + + # Create a shielded Overwinter transaction + recipients = [{'address': node0_taddr, 'amount': Decimal('10')}] + myopid = self.nodes[0].z_sendmany(node0_zaddr, recipients, 1, Decimal('0')) + shielded = wait_and_assert_operationid_status(self.nodes[0], myopid) + assert(shielded != None) + overwinter_txids.append(shielded) + self.sync_all() + + # Spends should be in the mempool + assert_equal(set(self.nodes[0].getrawmempool()), set(overwinter_txids)) + + # Node 0 note should be unspendable + assert_equal(self.nodes[0].z_getbalance(node0_zaddr), Decimal('0')) + + # Invalidate block 199. + self.nodes[0].invalidateblock(self.nodes[0].getbestblockhash()) + + # BUG: Ideally, the mempool should now only contain the transactions + # that were in block 199, the Overwinter transactions having been dropped. + # However, because chainActive is not updated until after the transactions + # in the disconnected block have been re-added to the mempool, the height + # seen by AcceptToMemoryPool is one greater than it should be. This causes + # the block 199 transactions to be validated against the Overwinter rules, + # and rejected because they (obviously) fail. + #assert_equal(set(self.nodes[0].getrawmempool()), set(block_txids[1:])) + assert_equal(set(self.nodes[0].getrawmempool()), set()) + + # Node 0 note should be spendable again + assert_equal(self.nodes[0].z_getbalance(node0_zaddr), Decimal('10')) + +if __name__ == '__main__': + MempoolUpgradeActivationTest().main() diff --git a/src/consensus/upgrades.h b/src/consensus/upgrades.h index 0c5462c2e..6a9173264 100644 --- a/src/consensus/upgrades.h +++ b/src/consensus/upgrades.h @@ -26,6 +26,9 @@ struct NUInfo { extern const struct NUInfo NetworkUpgradeInfo[]; +// Consensus branch id to identify pre-overwinter (Sprout) consensus rules. +static const uint32_t SPROUT_BRANCH_ID = NetworkUpgradeInfo[Consensus::BASE_SPROUT].nBranchId; + /** * Checks the state of a given network upgrade based on block height. * Caller must check that the height is >= 0 (and handle unknown heights). diff --git a/src/gtest/test_mempool.cpp b/src/gtest/test_mempool.cpp index 46b0a6d12..c0d2cb874 100644 --- a/src/gtest/test_mempool.cpp +++ b/src/gtest/test_mempool.cpp @@ -1,6 +1,7 @@ #include #include +#include "consensus/upgrades.h" #include "consensus/validation.h" #include "core_io.h" #include "main.h" @@ -83,7 +84,7 @@ TEST(Mempool, PriorityStatsDoNotCrash) { unsigned int nHeight = 92045; double dPriority = view.GetPriority(tx, nHeight); - CTxMemPoolEntry entry(tx, nFees, nTime, dPriority, nHeight, true, false); + CTxMemPoolEntry entry(tx, nFees, nTime, dPriority, nHeight, true, false, SPROUT_BRANCH_ID); // Check it does not crash (ie. the death test fails) EXPECT_NONFATAL_FAILURE(EXPECT_DEATH(testPool.addUnchecked(tx.GetHash(), entry), ""), ""); diff --git a/src/main.cpp b/src/main.cpp index 8cb6b9dd3..ba1edd183 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1199,7 +1199,12 @@ bool AcceptToMemoryPool(CTxMemPool& pool, CValidationState &state, const CTransa } } - CTxMemPoolEntry entry(tx, nFees, GetTime(), dPriority, chainActive.Height(), mempool.HasNoInputsOf(tx), fSpendsCoinbase); + // Grab the branch ID we expect this transaction to commit to. We don't + // yet know if it does, but if the entry gets added to the mempool, then + // it has passed ContextualCheckInputs and therefore this is correct. + auto consensusBranchId = CurrentEpochBranchId(chainActive.Height() + 1, Params().GetConsensus()); + + CTxMemPoolEntry entry(tx, nFees, GetTime(), dPriority, chainActive.Height(), mempool.HasNoInputsOf(tx), fSpendsCoinbase, consensusBranchId); unsigned int nSize = entry.GetTxSize(); // Accept a tx if it contains joinsplits and has at least the default fee specified by z_sendmany. @@ -2430,7 +2435,10 @@ void static UpdateTip(CBlockIndex *pindexNew) { } } -/** Disconnect chainActive's tip. You probably want to call mempool.removeForReorg after this, with cs_main held. */ +/** + * Disconnect chainActive's tip. You probably want to call mempool.removeForReorg and + * mempool.removeWithoutBranchId after this, with cs_main held. + */ bool static DisconnectTip(CValidationState &state, bool fBare = false) { CBlockIndex *pindexDelete = chainActive.Tip(); assert(pindexDelete); @@ -2493,6 +2501,7 @@ static int64_t nTimePostConnect = 0; /** * Connect a new block to chainActive. pblock is either NULL or a pointer to a CBlock * corresponding to pindexNew, to bypass loading it again from disk. + * You probably want to call mempool.removeWithoutBranchId after this, with cs_main held. */ bool static ConnectTip(CValidationState &state, CBlockIndex *pindexNew, CBlock *pblock) { assert(pindexNew->pprev == chainActive.Tip()); @@ -2691,6 +2700,8 @@ static bool ActivateBestChainStep(CValidationState &state, CBlockIndex *pindexMo if (fBlocksDisconnected) { mempool.removeForReorg(pcoinsTip, chainActive.Tip()->nHeight + 1, STANDARD_LOCKTIME_VERIFY_FLAGS); } + mempool.removeWithoutBranchId( + CurrentEpochBranchId(chainActive.Tip()->nHeight + 1, Params().GetConsensus())); mempool.check(pcoinsTip); // Callbacks/notifications for a new best chain. @@ -2778,6 +2789,8 @@ bool InvalidateBlock(CValidationState& state, CBlockIndex *pindex) { // unconditionally valid already, so force disconnect away from it. if (!DisconnectTip(state)) { mempool.removeForReorg(pcoinsTip, chainActive.Tip()->nHeight + 1, STANDARD_LOCKTIME_VERIFY_FLAGS); + mempool.removeWithoutBranchId( + CurrentEpochBranchId(chainActive.Tip()->nHeight + 1, Params().GetConsensus())); return false; } } @@ -2794,6 +2807,8 @@ bool InvalidateBlock(CValidationState& state, CBlockIndex *pindex) { InvalidChainFound(pindex); mempool.removeForReorg(pcoinsTip, chainActive.Tip()->nHeight + 1, STANDARD_LOCKTIME_VERIFY_FLAGS); + mempool.removeWithoutBranchId( + CurrentEpochBranchId(chainActive.Tip()->nHeight + 1, Params().GetConsensus())); return true; } diff --git a/src/test/mempool_tests.cpp b/src/test/mempool_tests.cpp index c72ec4c36..325789572 100644 --- a/src/test/mempool_tests.cpp +++ b/src/test/mempool_tests.cpp @@ -2,6 +2,7 @@ // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. +#include "consensus/upgrades.h" #include "main.h" #include "txmempool.h" #include "util.h" @@ -156,4 +157,59 @@ BOOST_AUTO_TEST_CASE(MempoolIndexingTest) BOOST_CHECK(it == pool.mapTx.get<1>().end()); } +BOOST_AUTO_TEST_CASE(RemoveWithoutBranchId) { + CTxMemPool pool(CFeeRate(0)); + TestMemPoolEntryHelper entry; + entry.nFee = 10000LL; + entry.hadNoDependencies = true; + + // Add some Sprout transactions + for (auto i = 1; i < 11; i++) { + CMutableTransaction tx = CMutableTransaction(); + tx.vout.resize(1); + tx.vout[0].scriptPubKey = CScript() << OP_11 << OP_EQUAL; + tx.vout[0].nValue = i * COIN; + pool.addUnchecked(tx.GetHash(), entry.BranchId(NetworkUpgradeInfo[Consensus::BASE_SPROUT].nBranchId).FromTx(tx)); + } + BOOST_CHECK_EQUAL(pool.size(), 10); + + // Check the pool only contains Sprout transactions + for (CTxMemPool::indexed_transaction_set::const_iterator it = pool.mapTx.begin(); it != pool.mapTx.end(); it++) { + BOOST_CHECK_EQUAL(it->GetValidatedBranchId(), NetworkUpgradeInfo[Consensus::BASE_SPROUT].nBranchId); + } + + // Add some dummy transactions + for (auto i = 1; i < 11; i++) { + CMutableTransaction tx = CMutableTransaction(); + tx.vout.resize(1); + tx.vout[0].scriptPubKey = CScript() << OP_11 << OP_EQUAL; + tx.vout[0].nValue = i * COIN + 100; + pool.addUnchecked(tx.GetHash(), entry.BranchId(NetworkUpgradeInfo[Consensus::UPGRADE_TESTDUMMY].nBranchId).FromTx(tx)); + } + BOOST_CHECK_EQUAL(pool.size(), 20); + + // Add some Overwinter transactions + for (auto i = 1; i < 11; i++) { + CMutableTransaction tx = CMutableTransaction(); + tx.vout.resize(1); + tx.vout[0].scriptPubKey = CScript() << OP_11 << OP_EQUAL; + tx.vout[0].nValue = i * COIN + 200; + pool.addUnchecked(tx.GetHash(), entry.BranchId(NetworkUpgradeInfo[Consensus::UPGRADE_OVERWINTER].nBranchId).FromTx(tx)); + } + BOOST_CHECK_EQUAL(pool.size(), 30); + + // Remove transactions that are not for Overwinter + pool.removeWithoutBranchId(NetworkUpgradeInfo[Consensus::UPGRADE_OVERWINTER].nBranchId); + BOOST_CHECK_EQUAL(pool.size(), 10); + + // Check the pool only contains Overwinter transactions + for (CTxMemPool::indexed_transaction_set::const_iterator it = pool.mapTx.begin(); it != pool.mapTx.end(); it++) { + BOOST_CHECK_EQUAL(it->GetValidatedBranchId(), NetworkUpgradeInfo[Consensus::UPGRADE_OVERWINTER].nBranchId); + } + + // Roll back to Sprout + pool.removeWithoutBranchId(NetworkUpgradeInfo[Consensus::BASE_SPROUT].nBranchId); + BOOST_CHECK_EQUAL(pool.size(), 0); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/test/test_bitcoin.cpp b/src/test/test_bitcoin.cpp index 822474235..02a3a50d2 100644 --- a/src/test/test_bitcoin.cpp +++ b/src/test/test_bitcoin.cpp @@ -106,7 +106,8 @@ TestingSetup::~TestingSetup() CTxMemPoolEntry TestMemPoolEntryHelper::FromTx(CMutableTransaction &tx, CTxMemPool *pool) { return CTxMemPoolEntry(tx, nFee, nTime, dPriority, nHeight, - pool ? pool->HasNoInputsOf(tx) : hadNoDependencies, spendsCoinbase); + pool ? pool->HasNoInputsOf(tx) : hadNoDependencies, + spendsCoinbase, nBranchId); } void Shutdown(void* parg) diff --git a/src/test/test_bitcoin.h b/src/test/test_bitcoin.h index db9245b7f..ae528d682 100644 --- a/src/test/test_bitcoin.h +++ b/src/test/test_bitcoin.h @@ -1,6 +1,7 @@ #ifndef BITCOIN_TEST_TEST_BITCOIN_H #define BITCOIN_TEST_TEST_BITCOIN_H +#include "consensus/upgrades.h" #include "pubkey.h" #include "txdb.h" @@ -48,10 +49,12 @@ struct TestMemPoolEntryHelper unsigned int nHeight; bool hadNoDependencies; bool spendsCoinbase; + uint32_t nBranchId; TestMemPoolEntryHelper() : nFee(0), nTime(0), dPriority(0.0), nHeight(1), - hadNoDependencies(false), spendsCoinbase(false) { } + hadNoDependencies(false), spendsCoinbase(false), + nBranchId(SPROUT_BRANCH_ID) { } CTxMemPoolEntry FromTx(CMutableTransaction &tx, CTxMemPool *pool = NULL); @@ -62,5 +65,6 @@ struct TestMemPoolEntryHelper TestMemPoolEntryHelper &Height(unsigned int _height) { nHeight = _height; return *this; } TestMemPoolEntryHelper &HadNoDependencies(bool _hnd) { hadNoDependencies = _hnd; return *this; } TestMemPoolEntryHelper &SpendsCoinbase(bool _flag) { spendsCoinbase = _flag; return *this; } + TestMemPoolEntryHelper &BranchId(uint32_t _branchId) { nBranchId = _branchId; return *this; } }; #endif diff --git a/src/txmempool.cpp b/src/txmempool.cpp index 570af4945..c69e77d05 100644 --- a/src/txmempool.cpp +++ b/src/txmempool.cpp @@ -28,10 +28,10 @@ CTxMemPoolEntry::CTxMemPoolEntry(): CTxMemPoolEntry::CTxMemPoolEntry(const CTransaction& _tx, const CAmount& _nFee, int64_t _nTime, double _dPriority, unsigned int _nHeight, bool poolHasNoInputsOf, - bool _spendsCoinbase): + bool _spendsCoinbase, uint32_t _nBranchId): tx(_tx), nFee(_nFee), nTime(_nTime), dPriority(_dPriority), nHeight(_nHeight), hadNoDependencies(poolHasNoInputsOf), - spendsCoinbase(_spendsCoinbase) + spendsCoinbase(_spendsCoinbase), nBranchId(_nBranchId) { nTxSize = ::GetSerializeSize(tx, SER_NETWORK, PROTOCOL_VERSION); nModSize = tx.CalculateModifiedSize(nTxSize); @@ -283,6 +283,28 @@ void CTxMemPool::removeForBlock(const std::vector& vtx, unsigned i minerPolicyEstimator->processBlock(nBlockHeight, entries, fCurrentEstimate); } +/** + * Called whenever the tip changes. Removes transactions which don't commit to + * the given branch ID from the mempool. + */ +void CTxMemPool::removeWithoutBranchId(uint32_t nMemPoolBranchId) +{ + LOCK(cs); + std::list transactionsToRemove; + + for (indexed_transaction_set::const_iterator it = mapTx.begin(); it != mapTx.end(); it++) { + const CTransaction& tx = it->GetTx(); + if (it->GetValidatedBranchId() != nMemPoolBranchId) { + transactionsToRemove.push_back(tx); + } + } + + for (const CTransaction& tx : transactionsToRemove) { + std::list removed; + remove(tx, removed, true); + } +} + void CTxMemPool::clear() { LOCK(cs); diff --git a/src/txmempool.h b/src/txmempool.h index 913682eeb..2cb2c8f05 100644 --- a/src/txmempool.h +++ b/src/txmempool.h @@ -51,11 +51,12 @@ private: unsigned int nHeight; //! Chain height when entering the mempool bool hadNoDependencies; //! Not dependent on any other txs when it entered the mempool bool spendsCoinbase; //! keep track of transactions that spend a coinbase + uint32_t nBranchId; //! Branch ID this transaction is known to commit to, cached for efficiency public: CTxMemPoolEntry(const CTransaction& _tx, const CAmount& _nFee, int64_t _nTime, double _dPriority, unsigned int _nHeight, - bool poolHasNoInputsOf, bool spendsCoinbase); + bool poolHasNoInputsOf, bool spendsCoinbase, uint32_t nBranchId); CTxMemPoolEntry(); CTxMemPoolEntry(const CTxMemPoolEntry& other); @@ -70,6 +71,7 @@ public: size_t DynamicMemoryUsage() const { return nUsageSize; } bool GetSpendsCoinbase() const { return spendsCoinbase; } + uint32_t GetValidatedBranchId() const { return nBranchId; } }; // extracts a TxMemPoolEntry's transaction hash @@ -168,6 +170,7 @@ public: void removeConflicts(const CTransaction &tx, std::list& removed); void removeForBlock(const std::vector& vtx, unsigned int nBlockHeight, std::list& conflicts, bool fCurrentEstimate = true); + void removeWithoutBranchId(uint32_t nMemPoolBranchId); void clear(); void queryHashes(std::vector& vtxid); void pruneSpent(const uint256& hash, CCoins &coins);