diff --git a/src/main.cpp b/src/main.cpp index 1f1893654..182eff232 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -2824,15 +2824,17 @@ bool static DisconnectTip(CValidationState &state, bool fBare = false) { // Update chainActive and related variables. UpdateTip(pindexDelete->pprev); // Get the current commitment tree - ZCIncrementalMerkleTree newTree; - assert(pcoinsTip->GetSproutAnchorAt(pcoinsTip->GetBestAnchor(SPROUT), newTree)); + ZCIncrementalMerkleTree newSproutTree; + ZCSaplingIncrementalMerkleTree newSaplingTree; + assert(pcoinsTip->GetSproutAnchorAt(pcoinsTip->GetBestAnchor(SPROUT), newSproutTree)); + assert(pcoinsTip->GetSaplingAnchorAt(pcoinsTip->GetBestAnchor(SAPLING), newSaplingTree)); // Let wallets know transactions went from 1-confirmed to // 0-confirmed or conflicted: BOOST_FOREACH(const CTransaction &tx, block.vtx) { SyncWithWallets(tx, NULL); } // Update cached incremental witnesses - GetMainSignals().ChainTip(pindexDelete, &block, newTree, false); + GetMainSignals().ChainTip(pindexDelete, &block, newSproutTree, newSaplingTree, false); return true; } @@ -2858,8 +2860,10 @@ bool static ConnectTip(CValidationState &state, CBlockIndex *pindexNew, CBlock * pblock = █ } // Get the current commitment tree - ZCIncrementalMerkleTree oldTree; - assert(pcoinsTip->GetSproutAnchorAt(pcoinsTip->GetBestAnchor(SPROUT), oldTree)); + ZCIncrementalMerkleTree oldSproutTree; + ZCSaplingIncrementalMerkleTree oldSaplingTree; + assert(pcoinsTip->GetSproutAnchorAt(pcoinsTip->GetBestAnchor(SPROUT), oldSproutTree)); + assert(pcoinsTip->GetSaplingAnchorAt(pcoinsTip->GetBestAnchor(SAPLING), oldSaplingTree)); // Apply the block atomically to the chain state. int64_t nTime2 = GetTimeMicros(); nTimeReadFromDisk += nTime2 - nTime1; int64_t nTime3; @@ -2904,7 +2908,7 @@ bool static ConnectTip(CValidationState &state, CBlockIndex *pindexNew, CBlock * SyncWithWallets(tx, pblock); } // Update cached incremental witnesses - GetMainSignals().ChainTip(pindexNew, pblock, oldTree, true); + GetMainSignals().ChainTip(pindexNew, pblock, oldSproutTree, oldSaplingTree, true); EnforceNodeDeprecation(pindexNew->nHeight); diff --git a/src/primitives/transaction.cpp b/src/primitives/transaction.cpp index a93364220..79adf290e 100644 --- a/src/primitives/transaction.cpp +++ b/src/primitives/transaction.cpp @@ -149,6 +149,11 @@ std::string COutPoint::ToString() const return strprintf("COutPoint(%s, %u)", hash.ToString().substr(0,10), n); } +std::string SaplingOutPoint::ToString() const +{ + return strprintf("SaplingOutPoint(%s, %u)", hash.ToString().substr(0, 10), n); +} + CTxIn::CTxIn(COutPoint prevoutIn, CScript scriptSigIn, uint32_t nSequenceIn) { prevout = prevoutIn; diff --git a/src/primitives/transaction.h b/src/primitives/transaction.h index 4295dea9f..6ce15a94b 100644 --- a/src/primitives/transaction.h +++ b/src/primitives/transaction.h @@ -309,15 +309,14 @@ public: } }; -/** An outpoint - a combination of a transaction hash and an index n into its vout */ -class COutPoint +class BaseOutPoint { public: uint256 hash; uint32_t n; - COutPoint() { SetNull(); } - COutPoint(uint256 hashIn, uint32_t nIn) { hash = hashIn; n = nIn; } + BaseOutPoint() { SetNull(); } + BaseOutPoint(uint256 hashIn, uint32_t nIn) { hash = hashIn; n = nIn; } ADD_SERIALIZE_METHODS; @@ -330,21 +329,38 @@ public: void SetNull() { hash.SetNull(); n = (uint32_t) -1; } bool IsNull() const { return (hash.IsNull() && n == (uint32_t) -1); } - friend bool operator<(const COutPoint& a, const COutPoint& b) + friend bool operator<(const BaseOutPoint& a, const BaseOutPoint& b) { return (a.hash < b.hash || (a.hash == b.hash && a.n < b.n)); } - friend bool operator==(const COutPoint& a, const COutPoint& b) + friend bool operator==(const BaseOutPoint& a, const BaseOutPoint& b) { return (a.hash == b.hash && a.n == b.n); } - friend bool operator!=(const COutPoint& a, const COutPoint& b) + friend bool operator!=(const BaseOutPoint& a, const BaseOutPoint& b) { return !(a == b); } +}; +/** An outpoint - a combination of a transaction hash and an index n into its vout */ +class COutPoint : public BaseOutPoint +{ +public: + COutPoint() : BaseOutPoint() {}; + COutPoint(uint256 hashIn, uint32_t nIn) : BaseOutPoint(hashIn, nIn) {}; + std::string ToString() const; +}; + +/** An outpoint - a combination of a transaction hash and an index n into its sapling + * output description (vShieldedOutput) */ +class SaplingOutPoint : public BaseOutPoint +{ +public: + SaplingOutPoint() : BaseOutPoint() {}; + SaplingOutPoint(uint256 hashIn, uint32_t nIn) : BaseOutPoint(hashIn, nIn) {}; std::string ToString() const; }; diff --git a/src/utiltest.cpp b/src/utiltest.cpp index 017cd84cd..f2cd7f584 100644 --- a/src/utiltest.cpp +++ b/src/utiltest.cpp @@ -10,9 +10,10 @@ CWalletTx GetValidReceive(ZCJoinSplit& params, const libzcash::SproutSpendingKey& sk, CAmount value, - bool randomInputs) { + bool randomInputs, + int32_t version /* = 2 */) { CMutableTransaction mtx; - mtx.nVersion = 2; // Enable JoinSplits + mtx.nVersion = version; mtx.vin.resize(2); if (randomInputs) { mtx.vin[0].prevout.hash = GetRandHash(); @@ -46,6 +47,12 @@ CWalletTx GetValidReceive(ZCJoinSplit& params, inputs, outputs, 2*value, 0, false}; mtx.vjoinsplit.push_back(jsdesc); + if (version >= 4) { + // Shielded Output + OutputDescription od; + mtx.vShieldedOutput.push_back(od); + } + // Empty output script. uint32_t consensusBranchId = SPROUT_BRANCH_ID; CScript scriptCode; diff --git a/src/utiltest.h b/src/utiltest.h index 722930609..327dc7be4 100644 --- a/src/utiltest.h +++ b/src/utiltest.h @@ -9,7 +9,8 @@ CWalletTx GetValidReceive(ZCJoinSplit& params, const libzcash::SproutSpendingKey& sk, CAmount value, - bool randomInputs); + bool randomInputs, + int32_t version = 2); libzcash::SproutNote GetNote(ZCJoinSplit& params, const libzcash::SproutSpendingKey& sk, const CTransaction& tx, size_t js, size_t n); diff --git a/src/validationinterface.cpp b/src/validationinterface.cpp index cd3e30f3d..ae1e322c2 100644 --- a/src/validationinterface.cpp +++ b/src/validationinterface.cpp @@ -17,7 +17,7 @@ void RegisterValidationInterface(CValidationInterface* pwalletIn) { g_signals.SyncTransaction.connect(boost::bind(&CValidationInterface::SyncTransaction, pwalletIn, _1, _2)); g_signals.EraseTransaction.connect(boost::bind(&CValidationInterface::EraseFromWallet, pwalletIn, _1)); g_signals.UpdatedTransaction.connect(boost::bind(&CValidationInterface::UpdatedTransaction, pwalletIn, _1)); - g_signals.ChainTip.connect(boost::bind(&CValidationInterface::ChainTip, pwalletIn, _1, _2, _3, _4)); + g_signals.ChainTip.connect(boost::bind(&CValidationInterface::ChainTip, pwalletIn, _1, _2, _3, _4, _5)); g_signals.SetBestChain.connect(boost::bind(&CValidationInterface::SetBestChain, pwalletIn, _1)); g_signals.Inventory.connect(boost::bind(&CValidationInterface::Inventory, pwalletIn, _1)); g_signals.Broadcast.connect(boost::bind(&CValidationInterface::ResendWalletTransactions, pwalletIn, _1)); @@ -28,7 +28,7 @@ void UnregisterValidationInterface(CValidationInterface* pwalletIn) { g_signals.BlockChecked.disconnect(boost::bind(&CValidationInterface::BlockChecked, pwalletIn, _1, _2)); g_signals.Broadcast.disconnect(boost::bind(&CValidationInterface::ResendWalletTransactions, pwalletIn, _1)); g_signals.Inventory.disconnect(boost::bind(&CValidationInterface::Inventory, pwalletIn, _1)); - g_signals.ChainTip.disconnect(boost::bind(&CValidationInterface::ChainTip, pwalletIn, _1, _2, _3, _4)); + g_signals.ChainTip.disconnect(boost::bind(&CValidationInterface::ChainTip, pwalletIn, _1, _2, _3, _4, _5)); g_signals.SetBestChain.disconnect(boost::bind(&CValidationInterface::SetBestChain, pwalletIn, _1)); g_signals.UpdatedTransaction.disconnect(boost::bind(&CValidationInterface::UpdatedTransaction, pwalletIn, _1)); g_signals.EraseTransaction.disconnect(boost::bind(&CValidationInterface::EraseFromWallet, pwalletIn, _1)); diff --git a/src/validationinterface.h b/src/validationinterface.h index 1855dacd7..60e90d012 100644 --- a/src/validationinterface.h +++ b/src/validationinterface.h @@ -34,7 +34,7 @@ protected: virtual void UpdatedBlockTip(const CBlockIndex *pindex) {} virtual void SyncTransaction(const CTransaction &tx, const CBlock *pblock) {} virtual void EraseFromWallet(const uint256 &hash) {} - virtual void ChainTip(const CBlockIndex *pindex, const CBlock *pblock, ZCIncrementalMerkleTree tree, bool added) {} + virtual void ChainTip(const CBlockIndex *pindex, const CBlock *pblock, ZCIncrementalMerkleTree sproutTree, ZCSaplingIncrementalMerkleTree saplingTree, bool added) {} virtual void SetBestChain(const CBlockLocator &locator) {} virtual void UpdatedTransaction(const uint256 &hash) {} virtual void Inventory(const uint256 &hash) {} @@ -55,7 +55,7 @@ struct CMainSignals { /** Notifies listeners of an updated transaction without new data (for now: a coinbase potentially becoming visible). */ boost::signals2::signal UpdatedTransaction; /** Notifies listeners of a change to the tip of the active block chain. */ - boost::signals2::signal ChainTip; + boost::signals2::signal ChainTip; /** Notifies listeners of a new active block chain. */ boost::signals2::signal SetBestChain; /** Notifies listeners about an inventory item being seen on the network. */ diff --git a/src/wallet/asyncrpcoperation_mergetoaddress.cpp b/src/wallet/asyncrpcoperation_mergetoaddress.cpp index 846d0d674..cc4551f6f 100644 --- a/src/wallet/asyncrpcoperation_mergetoaddress.cpp +++ b/src/wallet/asyncrpcoperation_mergetoaddress.cpp @@ -343,7 +343,7 @@ bool AsyncRPCOperation_mergetoaddress::main_impl() std::vector vOutPoints = {jso}; uint256 inputAnchor; std::vector> vInputWitnesses; - pwalletMain->GetNoteWitnesses(vOutPoints, vInputWitnesses, inputAnchor); + pwalletMain->GetSproutNoteWitnesses(vOutPoints, vInputWitnesses, inputAnchor); jsopWitnessAnchorMap[jso.ToString()] = MergeToAddressWitnessAnchorData{vInputWitnesses[0], inputAnchor}; } } @@ -711,7 +711,7 @@ UniValue AsyncRPCOperation_mergetoaddress::perform_joinsplit(MergeToAddressJSInf uint256 anchor; { LOCK(cs_main); - pwalletMain->GetNoteWitnesses(outPoints, witnesses, anchor); + pwalletMain->GetSproutNoteWitnesses(outPoints, witnesses, anchor); } return perform_joinsplit(info, witnesses, anchor); } diff --git a/src/wallet/asyncrpcoperation_sendmany.cpp b/src/wallet/asyncrpcoperation_sendmany.cpp index 197f21bf8..f5da139eb 100644 --- a/src/wallet/asyncrpcoperation_sendmany.cpp +++ b/src/wallet/asyncrpcoperation_sendmany.cpp @@ -420,7 +420,7 @@ bool AsyncRPCOperation_sendmany::main_impl() { std::vector vOutPoints = { jso }; uint256 inputAnchor; std::vector> vInputWitnesses; - pwalletMain->GetNoteWitnesses(vOutPoints, vInputWitnesses, inputAnchor); + pwalletMain->GetSproutNoteWitnesses(vOutPoints, vInputWitnesses, inputAnchor); jsopWitnessAnchorMap[ jso.ToString() ] = WitnessAnchorData{ vInputWitnesses[0], inputAnchor }; } } @@ -935,7 +935,7 @@ UniValue AsyncRPCOperation_sendmany::perform_joinsplit(AsyncJoinSplitInfo & info uint256 anchor; { LOCK(cs_main); - pwalletMain->GetNoteWitnesses(outPoints, witnesses, anchor); + pwalletMain->GetSproutNoteWitnesses(outPoints, witnesses, anchor); } return perform_joinsplit(info, witnesses, anchor); } diff --git a/src/wallet/gtest/test_wallet.cpp b/src/wallet/gtest/test_wallet.cpp index 9e2fdb8fb..83807986e 100644 --- a/src/wallet/gtest/test_wallet.cpp +++ b/src/wallet/gtest/test_wallet.cpp @@ -51,8 +51,9 @@ public: void IncrementNoteWitnesses(const CBlockIndex* pindex, const CBlock* pblock, - ZCIncrementalMerkleTree& tree) { - CWallet::IncrementNoteWitnesses(pindex, pblock, tree); + ZCIncrementalMerkleTree& sproutTree, + ZCSaplingIncrementalMerkleTree& saplingTree) { + CWallet::IncrementNoteWitnesses(pindex, pblock, sproutTree, saplingTree); } void DecrementNoteWitnesses(const CBlockIndex* pindex) { CWallet::DecrementNoteWitnesses(pindex); @@ -68,8 +69,8 @@ public: } }; -CWalletTx GetValidReceive(const libzcash::SproutSpendingKey& sk, CAmount value, bool randomInputs) { - return GetValidReceive(*params, sk, value, randomInputs); +CWalletTx GetValidReceive(const libzcash::SproutSpendingKey& sk, CAmount value, bool randomInputs, int32_t version = 2) { + return GetValidReceive(*params, sk, value, randomInputs, version); } libzcash::SproutNote GetNote(const libzcash::SproutSpendingKey& sk, @@ -82,26 +83,52 @@ CWalletTx GetValidSpend(const libzcash::SproutSpendingKey& sk, return GetValidSpend(*params, sk, note, value); } -JSOutPoint CreateValidBlock(TestWallet& wallet, +std::vector SetSaplingNoteData(CWalletTx& wtx) { + mapSaplingNoteData_t saplingNoteData; + SaplingOutPoint saplingOutPoint = {wtx.GetHash(), 0}; + SaplingNoteData saplingNd; + saplingNoteData[saplingOutPoint] = saplingNd; + wtx.SetSaplingNoteData(saplingNoteData); + std::vector saplingNotes {saplingOutPoint}; + return saplingNotes; +} + +std::pair CreateValidBlock(TestWallet& wallet, const libzcash::SproutSpendingKey& sk, const CBlockIndex& index, CBlock& block, - ZCIncrementalMerkleTree& tree) { - auto wtx = GetValidReceive(sk, 50, true); + ZCIncrementalMerkleTree& sproutTree, + ZCSaplingIncrementalMerkleTree& saplingTree) { + auto wtx = GetValidReceive(sk, 50, true, 4); auto note = GetNote(sk, wtx, 0, 1); auto nullifier = note.nullifier(sk); - mapNoteData_t noteData; + mapSproutNoteData_t noteData; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; - CNoteData nd {sk.address(), nullifier}; + SproutNoteData nd {sk.address(), nullifier}; noteData[jsoutpt] = nd; - wtx.SetNoteData(noteData); + wtx.SetSproutNoteData(noteData); + auto saplingNotes = SetSaplingNoteData(wtx); wallet.AddToWallet(wtx, true, NULL); block.vtx.push_back(wtx); - wallet.IncrementNoteWitnesses(&index, &block, tree); + wallet.IncrementNoteWitnesses(&index, &block, sproutTree, saplingTree); - return jsoutpt; + return std::make_pair(jsoutpt, saplingNotes[0]); +} + +std::pair GetWitnessesAndAnchors(TestWallet& wallet, + std::vector& sproutNotes, + std::vector& saplingNotes, + std::vector>& sproutWitnesses, + std::vector>& saplingWitnesses) { + sproutWitnesses.clear(); + saplingWitnesses.clear(); + uint256 sproutAnchor; + uint256 saplingAnchor; + wallet.GetSproutNoteWitnesses(sproutNotes, sproutWitnesses, sproutAnchor); + wallet.GetSaplingNoteWitnesses(saplingNotes, saplingWitnesses, saplingAnchor); + return std::make_pair(sproutAnchor, saplingAnchor); } TEST(wallet_tests, setup_datadir_location_run_as_first_test) { @@ -117,9 +144,9 @@ TEST(wallet_tests, note_data_serialisation) { auto note = GetNote(sk, wtx, 0, 1); auto nullifier = note.nullifier(sk); - mapNoteData_t noteData; + mapSproutNoteData_t noteData; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; - CNoteData nd {sk.address(), nullifier}; + SproutNoteData nd {sk.address(), nullifier}; ZCIncrementalMerkleTree tree; nd.witnesses.push_front(tree.witness()); noteData[jsoutpt] = nd; @@ -127,7 +154,7 @@ TEST(wallet_tests, note_data_serialisation) { CDataStream ss(SER_DISK, CLIENT_VERSION); ss << noteData; - mapNoteData_t noteData2; + mapSproutNoteData_t noteData2; ss >> noteData2; EXPECT_EQ(noteData, noteData2); @@ -145,12 +172,12 @@ TEST(wallet_tests, find_unspent_notes) { auto note = GetNote(sk, wtx, 0, 1); auto nullifier = note.nullifier(sk); - mapNoteData_t noteData; + mapSproutNoteData_t noteData; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; - CNoteData nd {sk.address(), nullifier}; + SproutNoteData nd {sk.address(), nullifier}; noteData[jsoutpt] = nd; - wtx.SetNoteData(noteData); + wtx.SetSproutNoteData(noteData); wallet.AddToWallet(wtx, true, NULL); EXPECT_FALSE(wallet.IsSpent(nullifier)); @@ -240,12 +267,12 @@ TEST(wallet_tests, find_unspent_notes) { auto note = GetNote(sk, wtx, 0, 1); auto nullifier = note.nullifier(sk); - mapNoteData_t noteData; + mapSproutNoteData_t noteData; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; - CNoteData nd {sk.address(), nullifier}; + SproutNoteData nd {sk.address(), nullifier}; noteData[jsoutpt] = nd; - wtx.SetNoteData(noteData); + wtx.SetSproutNoteData(noteData); wallet.AddToWallet(wtx, true, NULL); EXPECT_FALSE(wallet.IsSpent(nullifier)); @@ -299,28 +326,28 @@ TEST(wallet_tests, set_note_addrs_in_cwallettx) { auto wtx = GetValidReceive(sk, 10, true); auto note = GetNote(sk, wtx, 0, 1); auto nullifier = note.nullifier(sk); - EXPECT_EQ(0, wtx.mapNoteData.size()); + EXPECT_EQ(0, wtx.mapSproutNoteData.size()); - mapNoteData_t noteData; + mapSproutNoteData_t noteData; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; - CNoteData nd {sk.address(), nullifier}; + SproutNoteData nd {sk.address(), nullifier}; noteData[jsoutpt] = nd; - wtx.SetNoteData(noteData); - EXPECT_EQ(noteData, wtx.mapNoteData); + wtx.SetSproutNoteData(noteData); + EXPECT_EQ(noteData, wtx.mapSproutNoteData); } TEST(wallet_tests, set_invalid_note_addrs_in_cwallettx) { CWalletTx wtx; - EXPECT_EQ(0, wtx.mapNoteData.size()); + EXPECT_EQ(0, wtx.mapSproutNoteData.size()); - mapNoteData_t noteData; + mapSproutNoteData_t noteData; auto sk = libzcash::SproutSpendingKey::random(); JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; - CNoteData nd {sk.address(), uint256()}; + SproutNoteData nd {sk.address(), uint256()}; noteData[jsoutpt] = nd; - EXPECT_THROW(wtx.SetNoteData(noteData), std::logic_error); + EXPECT_THROW(wtx.SetSproutNoteData(noteData), std::logic_error); } TEST(wallet_tests, GetNoteNullifier) { @@ -374,7 +401,7 @@ TEST(wallet_tests, FindMyNotes) { EXPECT_EQ(2, noteMap.size()); JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; - CNoteData nd {sk.address(), nullifier}; + SproutNoteData nd {sk.address(), nullifier}; EXPECT_EQ(1, noteMap.count(jsoutpt)); EXPECT_EQ(nd, noteMap[jsoutpt]); } @@ -397,7 +424,7 @@ TEST(wallet_tests, FindMyNotesInEncryptedWallet) { EXPECT_EQ(2, noteMap.size()); JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; - CNoteData nd {sk.address(), nullifier}; + SproutNoteData nd {sk.address(), nullifier}; EXPECT_EQ(1, noteMap.count(jsoutpt)); EXPECT_NE(nd, noteMap[jsoutpt]); @@ -490,12 +517,12 @@ TEST(wallet_tests, navigate_from_nullifier_to_note) { auto note = GetNote(sk, wtx, 0, 1); auto nullifier = note.nullifier(sk); - mapNoteData_t noteData; + mapSproutNoteData_t noteData; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; - CNoteData nd {sk.address(), nullifier}; + SproutNoteData nd {sk.address(), nullifier}; noteData[jsoutpt] = nd; - wtx.SetNoteData(noteData); + wtx.SetSproutNoteData(noteData); EXPECT_EQ(0, wallet.mapNullifiersToNotes.count(nullifier)); @@ -520,12 +547,12 @@ TEST(wallet_tests, spent_note_is_from_me) { EXPECT_FALSE(wallet.IsFromMe(wtx)); EXPECT_FALSE(wallet.IsFromMe(wtx2)); - mapNoteData_t noteData; + mapSproutNoteData_t noteData; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; - CNoteData nd {sk.address(), nullifier}; + SproutNoteData nd {sk.address(), nullifier}; noteData[jsoutpt] = nd; - wtx.SetNoteData(noteData); + wtx.SetSproutNoteData(noteData); EXPECT_FALSE(wallet.IsFromMe(wtx)); EXPECT_FALSE(wallet.IsFromMe(wtx2)); @@ -540,44 +567,53 @@ TEST(wallet_tests, cached_witnesses_empty_chain) { auto sk = libzcash::SproutSpendingKey::random(); wallet.AddSpendingKey(sk); - auto wtx = GetValidReceive(sk, 10, true); + auto wtx = GetValidReceive(sk, 10, true, 4); auto note = GetNote(sk, wtx, 0, 0); auto note2 = GetNote(sk, wtx, 0, 1); auto nullifier = note.nullifier(sk); auto nullifier2 = note2.nullifier(sk); - mapNoteData_t noteData; + mapSproutNoteData_t sproutNoteData; JSOutPoint jsoutpt {wtx.GetHash(), 0, 0}; JSOutPoint jsoutpt2 {wtx.GetHash(), 0, 1}; - CNoteData nd {sk.address(), nullifier}; - CNoteData nd2 {sk.address(), nullifier2}; - noteData[jsoutpt] = nd; - noteData[jsoutpt2] = nd2; - wtx.SetNoteData(noteData); + SproutNoteData nd {sk.address(), nullifier}; + SproutNoteData nd2 {sk.address(), nullifier2}; + sproutNoteData[jsoutpt] = nd; + sproutNoteData[jsoutpt2] = nd2; + wtx.SetSproutNoteData(sproutNoteData); - std::vector notes {jsoutpt, jsoutpt2}; - std::vector> witnesses; - uint256 anchor; + std::vector sproutNotes {jsoutpt, jsoutpt2}; + std::vector saplingNotes = SetSaplingNoteData(wtx); - wallet.GetNoteWitnesses(notes, witnesses, anchor); - EXPECT_FALSE((bool) witnesses[0]); - EXPECT_FALSE((bool) witnesses[1]); + std::vector> sproutWitnesses; + std::vector> saplingWitnesses; + + ::GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses); + + EXPECT_FALSE((bool) sproutWitnesses[0]); + EXPECT_FALSE((bool) sproutWitnesses[1]); + EXPECT_FALSE((bool) saplingWitnesses[0]); wallet.AddToWallet(wtx, true, NULL); - witnesses.clear(); - wallet.GetNoteWitnesses(notes, witnesses, anchor); - EXPECT_FALSE((bool) witnesses[0]); - EXPECT_FALSE((bool) witnesses[1]); + + ::GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses); + + EXPECT_FALSE((bool) sproutWitnesses[0]); + EXPECT_FALSE((bool) sproutWitnesses[1]); + EXPECT_FALSE((bool) saplingWitnesses[0]); CBlock block; block.vtx.push_back(wtx); CBlockIndex index(block); - ZCIncrementalMerkleTree tree; - wallet.IncrementNoteWitnesses(&index, &block, tree); - witnesses.clear(); - wallet.GetNoteWitnesses(notes, witnesses, anchor); - EXPECT_TRUE((bool) witnesses[0]); - EXPECT_TRUE((bool) witnesses[1]); + ZCIncrementalMerkleTree sproutTree; + ZCSaplingIncrementalMerkleTree saplingTree; + wallet.IncrementNoteWitnesses(&index, &block, sproutTree, saplingTree); + + ::GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses); + + EXPECT_TRUE((bool) sproutWitnesses[0]); + EXPECT_TRUE((bool) sproutWitnesses[1]); + EXPECT_TRUE((bool) saplingWitnesses[0]); // Until #1302 is implemented, this should triggger an assertion EXPECT_DEATH(wallet.DecrementNoteWitnesses(&index), @@ -586,9 +622,10 @@ TEST(wallet_tests, cached_witnesses_empty_chain) { TEST(wallet_tests, cached_witnesses_chain_tip) { TestWallet wallet; - uint256 anchor1; + std::pair anchors1; CBlock block1; - ZCIncrementalMerkleTree tree; + ZCIncrementalMerkleTree sproutTree; + ZCSaplingIncrementalMerkleTree saplingTree; auto sk = libzcash::SproutSpendingKey::random(); wallet.AddSpendingKey(sk); @@ -597,33 +634,40 @@ TEST(wallet_tests, cached_witnesses_chain_tip) { // First block (case tested in _empty_chain) CBlockIndex index1(block1); index1.nHeight = 1; - auto jsoutpt = CreateValidBlock(wallet, sk, index1, block1, tree); + auto outpts = CreateValidBlock(wallet, sk, index1, block1, sproutTree, saplingTree); // Called to fetch anchor - std::vector notes {jsoutpt}; - std::vector> witnesses; - wallet.GetNoteWitnesses(notes, witnesses, anchor1); + std::vector sproutNotes {outpts.first}; + std::vector saplingNotes {outpts.second}; + std::vector> sproutWitnesses; + std::vector> saplingWitnesses; + + anchors1 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses); + EXPECT_NE(anchors1.first, anchors1.second); } { // Second transaction - auto wtx = GetValidReceive(sk, 50, true); + auto wtx = GetValidReceive(sk, 50, true, 4); auto note = GetNote(sk, wtx, 0, 1); auto nullifier = note.nullifier(sk); - mapNoteData_t noteData; + mapSproutNoteData_t sproutNoteData; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; - CNoteData nd {sk.address(), nullifier}; - noteData[jsoutpt] = nd; - wtx.SetNoteData(noteData); + SproutNoteData nd {sk.address(), nullifier}; + sproutNoteData[jsoutpt] = nd; + wtx.SetSproutNoteData(sproutNoteData); + std::vector saplingNotes = SetSaplingNoteData(wtx); wallet.AddToWallet(wtx, true, NULL); - std::vector notes {jsoutpt}; - std::vector> witnesses; - uint256 anchor2; + std::vector sproutNotes {jsoutpt}; + std::vector> sproutWitnesses; + std::vector> saplingWitnesses; - wallet.GetNoteWitnesses(notes, witnesses, anchor2); - EXPECT_FALSE((bool) witnesses[0]); + GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses); + + EXPECT_FALSE((bool) sproutWitnesses[0]); + EXPECT_FALSE((bool) saplingWitnesses[0]); // Second block CBlock block2; @@ -631,46 +675,57 @@ TEST(wallet_tests, cached_witnesses_chain_tip) { block2.vtx.push_back(wtx); CBlockIndex index2(block2); index2.nHeight = 2; - ZCIncrementalMerkleTree tree2 {tree}; - wallet.IncrementNoteWitnesses(&index2, &block2, tree2); - witnesses.clear(); - wallet.GetNoteWitnesses(notes, witnesses, anchor2); - EXPECT_TRUE((bool) witnesses[0]); - EXPECT_NE(anchor1, anchor2); + ZCIncrementalMerkleTree sproutTree2 {sproutTree}; + ZCSaplingIncrementalMerkleTree saplingTree2 {saplingTree}; + wallet.IncrementNoteWitnesses(&index2, &block2, sproutTree2, saplingTree2); + + auto anchors2 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses); + EXPECT_NE(anchors2.first, anchors2.second); + + EXPECT_TRUE((bool) sproutWitnesses[0]); + EXPECT_TRUE((bool) saplingWitnesses[0]); + EXPECT_NE(anchors1.first, anchors2.first); + EXPECT_NE(anchors1.second, anchors2.second); // Decrementing should give us the previous anchor - uint256 anchor3; wallet.DecrementNoteWitnesses(&index2); - witnesses.clear(); - wallet.GetNoteWitnesses(notes, witnesses, anchor3); - EXPECT_FALSE((bool) witnesses[0]); + auto anchors3 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses); + + EXPECT_FALSE((bool) sproutWitnesses[0]); + EXPECT_FALSE((bool) saplingWitnesses[0]); // Should not equal first anchor because none of these notes had witnesses - EXPECT_NE(anchor1, anchor3); + EXPECT_NE(anchors1.first, anchors3.first); + EXPECT_NE(anchors1.second, anchors3.second); // Re-incrementing with the same block should give the same result - uint256 anchor4; - wallet.IncrementNoteWitnesses(&index2, &block2, tree); - witnesses.clear(); - wallet.GetNoteWitnesses(notes, witnesses, anchor4); - EXPECT_TRUE((bool) witnesses[0]); - EXPECT_EQ(anchor2, anchor4); + wallet.IncrementNoteWitnesses(&index2, &block2, sproutTree, saplingTree); + auto anchors4 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses); + EXPECT_NE(anchors4.first, anchors4.second); + + EXPECT_TRUE((bool) sproutWitnesses[0]); + EXPECT_TRUE((bool) saplingWitnesses[0]); + EXPECT_EQ(anchors2.first, anchors4.first); + EXPECT_EQ(anchors2.second, anchors4.second); // Incrementing with the same block again should not change the cache - uint256 anchor5; - wallet.IncrementNoteWitnesses(&index2, &block2, tree); - std::vector> witnesses5; - wallet.GetNoteWitnesses(notes, witnesses5, anchor5); - EXPECT_EQ(witnesses, witnesses5); - EXPECT_EQ(anchor4, anchor5); + wallet.IncrementNoteWitnesses(&index2, &block2, sproutTree, saplingTree); + std::vector> sproutWitnesses5; + std::vector> saplingWitnesses5; + + auto anchors5 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses5, saplingWitnesses5); + EXPECT_NE(anchors5.first, anchors5.second); + + EXPECT_EQ(sproutWitnesses, sproutWitnesses5); + EXPECT_EQ(saplingWitnesses, saplingWitnesses5); + EXPECT_EQ(anchors4.first, anchors5.first); + EXPECT_EQ(anchors4.second, anchors5.second); } } TEST(wallet_tests, CachedWitnessesDecrementFirst) { TestWallet wallet; - uint256 anchor2; - CBlock block2; - CBlockIndex index2(block2); - ZCIncrementalMerkleTree tree; + ZCIncrementalMerkleTree sproutTree; + ZCSaplingIncrementalMerkleTree saplingTree; auto sk = libzcash::SproutSpendingKey::random(); wallet.AddSpendingKey(sk); @@ -680,57 +735,70 @@ TEST(wallet_tests, CachedWitnessesDecrementFirst) { CBlock block1; CBlockIndex index1(block1); index1.nHeight = 1; - CreateValidBlock(wallet, sk, index1, block1, tree); + CreateValidBlock(wallet, sk, index1, block1, sproutTree, saplingTree); } + std::pair anchors2; + CBlock block2; + CBlockIndex index2(block2); + { // Second block (case tested in _chain_tip) index2.nHeight = 2; - auto jsoutpt = CreateValidBlock(wallet, sk, index2, block2, tree); + auto outpts = CreateValidBlock(wallet, sk, index2, block2, sproutTree, saplingTree); // Called to fetch anchor - std::vector notes {jsoutpt}; - std::vector> witnesses; - wallet.GetNoteWitnesses(notes, witnesses, anchor2); + std::vector sproutNotes {outpts.first}; + std::vector saplingNotes {outpts.second}; + std::vector> sproutWitnesses; + std::vector> saplingWitnesses; + anchors2 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses); } - { +{ // Third transaction - never mined - auto wtx = GetValidReceive(sk, 20, true); + auto wtx = GetValidReceive(sk, 20, true, 4); auto note = GetNote(sk, wtx, 0, 1); auto nullifier = note.nullifier(sk); - mapNoteData_t noteData; + mapSproutNoteData_t noteData; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; - CNoteData nd {sk.address(), nullifier}; + SproutNoteData nd {sk.address(), nullifier}; noteData[jsoutpt] = nd; - wtx.SetNoteData(noteData); + wtx.SetSproutNoteData(noteData); + std::vector saplingNotes = SetSaplingNoteData(wtx); wallet.AddToWallet(wtx, true, NULL); - std::vector notes {jsoutpt}; - std::vector> witnesses; - uint256 anchor3; + std::vector sproutNotes {jsoutpt}; + std::vector> sproutWitnesses; + std::vector> saplingWitnesses; - wallet.GetNoteWitnesses(notes, witnesses, anchor3); - EXPECT_FALSE((bool) witnesses[0]); + auto anchors3 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses); + + EXPECT_FALSE((bool) sproutWitnesses[0]); + EXPECT_FALSE((bool) saplingWitnesses[0]); // Decrementing (before the transaction has ever seen an increment) // should give us the previous anchor - uint256 anchor4; wallet.DecrementNoteWitnesses(&index2); - witnesses.clear(); - wallet.GetNoteWitnesses(notes, witnesses, anchor4); - EXPECT_FALSE((bool) witnesses[0]); + + auto anchors4 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses); + + EXPECT_FALSE((bool) sproutWitnesses[0]); + EXPECT_FALSE((bool) saplingWitnesses[0]); // Should not equal second anchor because none of these notes had witnesses - EXPECT_NE(anchor2, anchor4); + EXPECT_NE(anchors2.first, anchors4.first); + EXPECT_NE(anchors2.second, anchors4.second); // Re-incrementing with the same block should give the same result - uint256 anchor5; - wallet.IncrementNoteWitnesses(&index2, &block2, tree); - witnesses.clear(); - wallet.GetNoteWitnesses(notes, witnesses, anchor5); - EXPECT_FALSE((bool) witnesses[0]); - EXPECT_EQ(anchor3, anchor5); + wallet.IncrementNoteWitnesses(&index2, &block2, sproutTree, saplingTree); + + auto anchors5 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses); + + EXPECT_FALSE((bool) sproutWitnesses[0]); + EXPECT_FALSE((bool) saplingWitnesses[0]); + EXPECT_EQ(anchors3.first, anchors5.first); + EXPECT_EQ(anchors3.second, anchors5.second); } } @@ -738,11 +806,16 @@ TEST(wallet_tests, CachedWitnessesCleanIndex) { TestWallet wallet; std::vector blocks; std::vector indices; - std::vector notes; - std::vector anchors; - ZCIncrementalMerkleTree tree; - ZCIncrementalMerkleTree riTree = tree; - std::vector> witnesses; + std::vector sproutNotes; + std::vector saplingNotes; + std::vector sproutAnchors; + std::vector saplingAnchors; + ZCIncrementalMerkleTree sproutTree; + ZCIncrementalMerkleTree sproutRiTree = sproutTree; + ZCSaplingIncrementalMerkleTree saplingTree; + ZCSaplingIncrementalMerkleTree saplingRiTree = saplingTree; + std::vector> sproutWitnesses; + std::vector> saplingWitnesses; auto sk = libzcash::SproutSpendingKey::random(); wallet.AddSpendingKey(sk); @@ -753,58 +826,64 @@ TEST(wallet_tests, CachedWitnessesCleanIndex) { indices.resize(numBlocks); for (size_t i = 0; i < numBlocks; i++) { indices[i].nHeight = i; - auto old = tree.root(); - auto jsoutpt = CreateValidBlock(wallet, sk, indices[i], blocks[i], tree); - EXPECT_NE(old, tree.root()); - notes.push_back(jsoutpt); + auto oldSproutRoot = sproutTree.root(); + auto oldSaplingRoot = saplingTree.root(); + auto outpts = CreateValidBlock(wallet, sk, indices[i], blocks[i], sproutTree, saplingTree); + EXPECT_NE(oldSproutRoot, sproutTree.root()); + EXPECT_NE(oldSaplingRoot, saplingTree.root()); + sproutNotes.push_back(outpts.first); + saplingNotes.push_back(outpts.second); - witnesses.clear(); - uint256 anchor; - wallet.GetNoteWitnesses(notes, witnesses, anchor); + auto anchors = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses); for (size_t j = 0; j <= i; j++) { - EXPECT_TRUE((bool) witnesses[j]); + EXPECT_TRUE((bool) sproutWitnesses[j]); + EXPECT_TRUE((bool) saplingWitnesses[j]); } - anchors.push_back(anchor); + sproutAnchors.push_back(anchors.first); + saplingAnchors.push_back(anchors.second); } // Now pretend we are reindexing: the chain is cleared, and each block is // used to increment witnesses again. for (size_t i = 0; i < numBlocks; i++) { - ZCIncrementalMerkleTree riPrevTree {riTree}; - wallet.IncrementNoteWitnesses(&(indices[i]), &(blocks[i]), riTree); - witnesses.clear(); - uint256 anchor; - wallet.GetNoteWitnesses(notes, witnesses, anchor); + ZCIncrementalMerkleTree sproutRiPrevTree {sproutRiTree}; + ZCSaplingIncrementalMerkleTree saplingRiPrevTree {saplingRiTree}; + wallet.IncrementNoteWitnesses(&(indices[i]), &(blocks[i]), sproutRiTree, saplingRiTree); + + auto anchors = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses); for (size_t j = 0; j < numBlocks; j++) { - EXPECT_TRUE((bool) witnesses[j]); + EXPECT_TRUE((bool) sproutWitnesses[j]); + EXPECT_TRUE((bool) saplingWitnesses[j]); } // Should equal final anchor because witness cache unaffected - EXPECT_EQ(anchors.back(), anchor); + EXPECT_EQ(sproutAnchors.back(), anchors.first); + EXPECT_EQ(saplingAnchors.back(), anchors.second); if ((i == 5) || (i == 50)) { // Pretend a reorg happened that was recorded in the block files { wallet.DecrementNoteWitnesses(&(indices[i])); - witnesses.clear(); - uint256 anchor; - wallet.GetNoteWitnesses(notes, witnesses, anchor); + + auto anchors = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses); for (size_t j = 0; j < numBlocks; j++) { - EXPECT_TRUE((bool) witnesses[j]); + EXPECT_TRUE((bool) sproutWitnesses[j]); + EXPECT_TRUE((bool) saplingWitnesses[j]); } // Should equal final anchor because witness cache unaffected - EXPECT_EQ(anchors.back(), anchor); + EXPECT_EQ(sproutAnchors.back(), anchors.first); + EXPECT_EQ(saplingAnchors.back(), anchors.second); } { - wallet.IncrementNoteWitnesses(&(indices[i]), &(blocks[i]), riPrevTree); - witnesses.clear(); - uint256 anchor; - wallet.GetNoteWitnesses(notes, witnesses, anchor); + wallet.IncrementNoteWitnesses(&(indices[i]), &(blocks[i]), sproutRiPrevTree, saplingRiPrevTree); + auto anchors = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses); for (size_t j = 0; j < numBlocks; j++) { - EXPECT_TRUE((bool) witnesses[j]); + EXPECT_TRUE((bool) sproutWitnesses[j]); + EXPECT_TRUE((bool) saplingWitnesses[j]); } // Should equal final anchor because witness cache unaffected - EXPECT_EQ(anchors.back(), anchor); + EXPECT_EQ(sproutAnchors.back(), anchors.first); + EXPECT_EQ(saplingAnchors.back(), anchors.second); } } } @@ -816,44 +895,55 @@ TEST(wallet_tests, ClearNoteWitnessCache) { auto sk = libzcash::SproutSpendingKey::random(); wallet.AddSpendingKey(sk); - auto wtx = GetValidReceive(sk, 10, true); + auto wtx = GetValidReceive(sk, 10, true, 4); auto hash = wtx.GetHash(); auto note = GetNote(sk, wtx, 0, 0); auto nullifier = note.nullifier(sk); - mapNoteData_t noteData; + mapSproutNoteData_t noteData; JSOutPoint jsoutpt {wtx.GetHash(), 0, 0}; JSOutPoint jsoutpt2 {wtx.GetHash(), 0, 1}; - CNoteData nd {sk.address(), nullifier}; + SproutNoteData nd {sk.address(), nullifier}; noteData[jsoutpt] = nd; - wtx.SetNoteData(noteData); + wtx.SetSproutNoteData(noteData); + auto saplingNotes = SetSaplingNoteData(wtx); // Pretend we mined the tx by adding a fake witness - ZCIncrementalMerkleTree tree; - wtx.mapNoteData[jsoutpt].witnesses.push_front(tree.witness()); - wtx.mapNoteData[jsoutpt].witnessHeight = 1; + ZCIncrementalMerkleTree sproutTree; + wtx.mapSproutNoteData[jsoutpt].witnesses.push_front(sproutTree.witness()); + wtx.mapSproutNoteData[jsoutpt].witnessHeight = 1; wallet.nWitnessCacheSize = 1; + ZCSaplingIncrementalMerkleTree saplingTree; + wtx.mapSaplingNoteData[saplingNotes[0]].witnesses.push_front(saplingTree.witness()); + wtx.mapSaplingNoteData[saplingNotes[0]].witnessHeight = 1; + wallet.nWitnessCacheSize = 2; + wallet.AddToWallet(wtx, true, NULL); - std::vector notes {jsoutpt, jsoutpt2}; - std::vector> witnesses; - uint256 anchor2; + std::vector sproutNotes {jsoutpt, jsoutpt2}; + std::vector> sproutWitnesses; + std::vector> saplingWitnesses; // Before clearing, we should have a witness for one note - wallet.GetNoteWitnesses(notes, witnesses, anchor2); - EXPECT_TRUE((bool) witnesses[0]); - EXPECT_FALSE((bool) witnesses[1]); - EXPECT_EQ(1, wallet.mapWallet[hash].mapNoteData[jsoutpt].witnessHeight); - EXPECT_EQ(1, wallet.nWitnessCacheSize); + GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses); + EXPECT_TRUE((bool) sproutWitnesses[0]); + EXPECT_FALSE((bool) sproutWitnesses[1]); + EXPECT_TRUE((bool) saplingWitnesses[0]); + EXPECT_FALSE((bool) saplingWitnesses[1]); + EXPECT_EQ(1, wallet.mapWallet[hash].mapSproutNoteData[jsoutpt].witnessHeight); + EXPECT_EQ(1, wallet.mapWallet[hash].mapSaplingNoteData[saplingNotes[0]].witnessHeight); + EXPECT_EQ(2, wallet.nWitnessCacheSize); // After clearing, we should not have a witness for either note wallet.ClearNoteWitnessCache(); - witnesses.clear(); - wallet.GetNoteWitnesses(notes, witnesses, anchor2); - EXPECT_FALSE((bool) witnesses[0]); - EXPECT_FALSE((bool) witnesses[1]); - EXPECT_EQ(-1, wallet.mapWallet[hash].mapNoteData[jsoutpt].witnessHeight); + auto anchros2 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses); + EXPECT_FALSE((bool) sproutWitnesses[0]); + EXPECT_FALSE((bool) sproutWitnesses[1]); + EXPECT_FALSE((bool) saplingWitnesses[0]); + EXPECT_FALSE((bool) saplingWitnesses[1]); + EXPECT_EQ(-1, wallet.mapWallet[hash].mapSproutNoteData[jsoutpt].witnessHeight); + EXPECT_EQ(-1, wallet.mapWallet[hash].mapSaplingNoteData[saplingNotes[0]].witnessHeight); EXPECT_EQ(0, wallet.nWitnessCacheSize); } @@ -949,11 +1039,11 @@ TEST(wallet_tests, UpdateNullifierNoteMap) { auto nullifier = note.nullifier(sk); // Pretend that we called FindMyNotes while the wallet was locked - mapNoteData_t noteData; + mapSproutNoteData_t noteData; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; - CNoteData nd {sk.address()}; + SproutNoteData nd {sk.address()}; noteData[jsoutpt] = nd; - wtx.SetNoteData(noteData); + wtx.SetSproutNoteData(noteData); wallet.AddToWallet(wtx, true, NULL); EXPECT_EQ(0, wallet.mapNullifiersToNotes.count(nullifier)); @@ -984,35 +1074,35 @@ TEST(wallet_tests, UpdatedNoteData) { // First pretend we added the tx to the wallet and // we don't have the key for the second note - mapNoteData_t noteData; + mapSproutNoteData_t noteData; JSOutPoint jsoutpt {wtx.GetHash(), 0, 0}; - CNoteData nd {sk.address(), nullifier}; + SproutNoteData nd {sk.address(), nullifier}; noteData[jsoutpt] = nd; - wtx.SetNoteData(noteData); + wtx.SetSproutNoteData(noteData); // Pretend we mined the tx by adding a fake witness ZCIncrementalMerkleTree tree; - wtx.mapNoteData[jsoutpt].witnesses.push_front(tree.witness()); - wtx.mapNoteData[jsoutpt].witnessHeight = 100; + wtx.mapSproutNoteData[jsoutpt].witnesses.push_front(tree.witness()); + wtx.mapSproutNoteData[jsoutpt].witnessHeight = 100; // Now pretend we added the key for the second note, and // the tx was "added" to the wallet again to update it. // This happens via the 'z_importkey' RPC method. JSOutPoint jsoutpt2 {wtx2.GetHash(), 0, 1}; - CNoteData nd2 {sk.address(), nullifier2}; + SproutNoteData nd2 {sk.address(), nullifier2}; noteData[jsoutpt2] = nd2; - wtx2.SetNoteData(noteData); + wtx2.SetSproutNoteData(noteData); // The txs should initially be different - EXPECT_NE(wtx.mapNoteData, wtx2.mapNoteData); - EXPECT_EQ(1, wtx.mapNoteData[jsoutpt].witnesses.size()); - EXPECT_EQ(100, wtx.mapNoteData[jsoutpt].witnessHeight); + EXPECT_NE(wtx.mapSproutNoteData, wtx2.mapSproutNoteData); + EXPECT_EQ(1, wtx.mapSproutNoteData[jsoutpt].witnesses.size()); + EXPECT_EQ(100, wtx.mapSproutNoteData[jsoutpt].witnessHeight); // After updating, they should be the same EXPECT_TRUE(wallet.UpdatedNoteData(wtx2, wtx)); - EXPECT_EQ(wtx.mapNoteData, wtx2.mapNoteData); - EXPECT_EQ(1, wtx.mapNoteData[jsoutpt].witnesses.size()); - EXPECT_EQ(100, wtx.mapNoteData[jsoutpt].witnessHeight); + EXPECT_EQ(wtx.mapSproutNoteData, wtx2.mapSproutNoteData); + EXPECT_EQ(1, wtx.mapSproutNoteData[jsoutpt].witnesses.size()); + EXPECT_EQ(100, wtx.mapSproutNoteData[jsoutpt].witnessHeight); // TODO: The new note should get witnessed (but maybe not here) (#1350) } @@ -1028,12 +1118,12 @@ TEST(wallet_tests, MarkAffectedTransactionsDirty) { auto nullifier = note.nullifier(sk); auto wtx2 = GetValidSpend(sk, note, 5); - mapNoteData_t noteData; + mapSproutNoteData_t noteData; JSOutPoint jsoutpt {hash, 0, 1}; - CNoteData nd {sk.address(), nullifier}; + SproutNoteData nd {sk.address(), nullifier}; noteData[jsoutpt] = nd; - wtx.SetNoteData(noteData); + wtx.SetSproutNoteData(noteData); wallet.AddToWallet(wtx, true, NULL); wallet.MarkAffectedTransactionsDirty(wtx); diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index 4d57d04c2..d02b03e7c 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -449,11 +449,14 @@ bool CWallet::ChangeWalletPassphrase(const SecureString& strOldWalletPassphrase, return false; } -void CWallet::ChainTip(const CBlockIndex *pindex, const CBlock *pblock, - ZCIncrementalMerkleTree tree, bool added) +void CWallet::ChainTip(const CBlockIndex *pindex, + const CBlock *pblock, + ZCIncrementalMerkleTree sproutTree, + ZCSaplingIncrementalMerkleTree saplingTree, + bool added) { if (added) { - IncrementNoteWitnesses(pindex, pblock, tree); + IncrementNoteWitnesses(pindex, pblock, sproutTree, saplingTree); } else { DecrementNoteWitnesses(pindex); } @@ -469,7 +472,7 @@ std::set> CWallet::GetNullifiersFor { std::set> nullifierSet; for (const auto & txPair : mapWallet) { - for (const auto & noteDataPair : txPair.second.mapNoteData) { + for (const auto & noteDataPair : txPair.second.mapSproutNoteData) { if (noteDataPair.second.nullifier && addresses.count(noteDataPair.second.address)) { nullifierSet.insert(std::make_pair(noteDataPair.second.address, noteDataPair.second.nullifier.get())); } @@ -653,7 +656,7 @@ void CWallet::SyncMetaData(pair::iterator, typename TxSpe CWalletTx* copyTo = &mapWallet[hash]; if (copyFrom == copyTo) continue; copyTo->mapValue = copyFrom->mapValue; - // mapNoteData not copied on purpose + // mapSproutNoteData not copied on purpose // (it is always set correctly for each CWalletTx) copyTo->vOrderForm = copyFrom->vOrderForm; // fTimeReceivedIsTxTime not copied on purpose @@ -744,7 +747,11 @@ void CWallet::ClearNoteWitnessCache() { LOCK(cs_wallet); for (std::pair& wtxItem : mapWallet) { - for (mapNoteData_t::value_type& item : wtxItem.second.mapNoteData) { + for (mapSproutNoteData_t::value_type& item : wtxItem.second.mapSproutNoteData) { + item.second.witnesses.clear(); + item.second.witnessHeight = -1; + } + for (mapSaplingNoteData_t::value_type& item : wtxItem.second.mapSaplingNoteData) { item.second.witnesses.clear(); item.second.witnessHeight = -1; } @@ -752,176 +759,219 @@ void CWallet::ClearNoteWitnessCache() nWitnessCacheSize = 0; } +template +void CopyPreviousWitnesses(NoteDataMap& noteDataMap, int indexHeight, int64_t nWitnessCacheSize) +{ + for (auto& item : noteDataMap) { + auto* nd = &(item.second); + // Only increment witnesses that are behind the current height + if (nd->witnessHeight < indexHeight) { + // Check the validity of the cache + // The only time a note witnessed above the current height + // would be invalid here is during a reindex when blocks + // have been decremented, and we are incrementing the blocks + // immediately after. + assert(nWitnessCacheSize >= nd->witnesses.size()); + // Witnesses being incremented should always be either -1 + // (never incremented or decremented) or one below indexHeight + assert((nd->witnessHeight == -1) || (nd->witnessHeight == indexHeight - 1)); + // Copy the witness for the previous block if we have one + if (nd->witnesses.size() > 0) { + nd->witnesses.push_front(nd->witnesses.front()); + } + if (nd->witnesses.size() > WITNESS_CACHE_SIZE) { + nd->witnesses.pop_back(); + } + } + } +} + +template +void AppendNoteCommitment(NoteDataMap& noteDataMap, int indexHeight, int64_t nWitnessCacheSize, const uint256& note_commitment) +{ + for (auto& item : noteDataMap) { + auto* nd = &(item.second); + if (nd->witnessHeight < indexHeight && nd->witnesses.size() > 0) { + // Check the validity of the cache + // See comment in CopyPreviousWitnesses about validity. + assert(nWitnessCacheSize >= nd->witnesses.size()); + nd->witnesses.front().append(note_commitment); + } + } +} + +template +void WitnessNoteIfMine(std::map& noteDataMap, int indexHeight, int64_t nWitnessCacheSize, const OutPoint& key, const Witness& witness) +{ + if (noteDataMap.count(key) && noteDataMap[key].witnessHeight < indexHeight) { + auto* nd = &(noteDataMap[key]); + if (nd->witnesses.size() > 0) { + // We think this can happen because we write out the + // witness cache state after every block increment or + // decrement, but the block index itself is written in + // batches. So if the node crashes in between these two + // operations, it is possible for IncrementNoteWitnesses + // to be called again on previously-cached blocks. This + // doesn't affect existing cached notes because of the + // NoteData::witnessHeight checks. See #1378 for details. + LogPrintf("Inconsistent witness cache state found for %s\n- Cache size: %d\n- Top (height %d): %s\n- New (height %d): %s\n", + key.ToString(), nd->witnesses.size(), + nd->witnessHeight, + nd->witnesses.front().root().GetHex(), + indexHeight, + witness.root().GetHex()); + nd->witnesses.clear(); + } + nd->witnesses.push_front(witness); + // Set height to one less than pindex so it gets incremented + nd->witnessHeight = indexHeight - 1; + // Check the validity of the cache + assert(nWitnessCacheSize >= nd->witnesses.size()); + } +} + + +template +void UpdateWitnessHeights(NoteDataMap& noteDataMap, int indexHeight, int64_t nWitnessCacheSize) +{ + for (auto& item : noteDataMap) { + auto* nd = &(item.second); + if (nd->witnessHeight < indexHeight) { + nd->witnessHeight = indexHeight; + // Check the validity of the cache + // See comment in CopyPreviousWitnesses about validity. + assert(nWitnessCacheSize >= nd->witnesses.size()); + } + } +} + void CWallet::IncrementNoteWitnesses(const CBlockIndex* pindex, const CBlock* pblockIn, - ZCIncrementalMerkleTree& tree) + ZCIncrementalMerkleTree& sproutTree, + ZCSaplingIncrementalMerkleTree& saplingTree) { - { - LOCK(cs_wallet); - for (std::pair& wtxItem : mapWallet) { - for (mapNoteData_t::value_type& item : wtxItem.second.mapNoteData) { - CNoteData* nd = &(item.second); - // Only increment witnesses that are behind the current height - if (nd->witnessHeight < pindex->nHeight) { - // Check the validity of the cache - // The only time a note witnessed above the current height - // would be invalid here is during a reindex when blocks - // have been decremented, and we are incrementing the blocks - // immediately after. - assert(nWitnessCacheSize >= nd->witnesses.size()); - // Witnesses being incremented should always be either -1 - // (never incremented or decremented) or one below pindex - assert((nd->witnessHeight == -1) || - (nd->witnessHeight == pindex->nHeight - 1)); - // Copy the witness for the previous block if we have one - if (nd->witnesses.size() > 0) { - nd->witnesses.push_front(nd->witnesses.front()); - } - if (nd->witnesses.size() > WITNESS_CACHE_SIZE) { - nd->witnesses.pop_back(); - } + LOCK(cs_wallet); + for (std::pair& wtxItem : mapWallet) { + ::CopyPreviousWitnesses(wtxItem.second.mapSproutNoteData, pindex->nHeight, nWitnessCacheSize); + ::CopyPreviousWitnesses(wtxItem.second.mapSaplingNoteData, pindex->nHeight, nWitnessCacheSize); + } + + if (nWitnessCacheSize < WITNESS_CACHE_SIZE) { + nWitnessCacheSize += 1; + } + + const CBlock* pblock {pblockIn}; + CBlock block; + if (!pblock) { + ReadBlockFromDisk(block, pindex); + pblock = █ + } + + for (const CTransaction& tx : pblock->vtx) { + auto hash = tx.GetHash(); + bool txIsOurs = mapWallet.count(hash); + // Sprout + for (size_t i = 0; i < tx.vjoinsplit.size(); i++) { + const JSDescription& jsdesc = tx.vjoinsplit[i]; + for (uint8_t j = 0; j < jsdesc.commitments.size(); j++) { + const uint256& note_commitment = jsdesc.commitments[j]; + sproutTree.append(note_commitment); + + // Increment existing witnesses + for (std::pair& wtxItem : mapWallet) { + ::AppendNoteCommitment(wtxItem.second.mapSproutNoteData, pindex->nHeight, nWitnessCacheSize, note_commitment); + } + + // If this is our note, witness it + if (txIsOurs) { + JSOutPoint jsoutpt {hash, i, j}; + ::WitnessNoteIfMine(mapWallet[hash].mapSproutNoteData, pindex->nHeight, nWitnessCacheSize, jsoutpt, sproutTree.witness()); } } } - if (nWitnessCacheSize < WITNESS_CACHE_SIZE) { - nWitnessCacheSize += 1; - } + // Sapling + for (uint32_t i = 0; i < tx.vShieldedOutput.size(); i++) { + const uint256& note_commitment = tx.vShieldedOutput[i].cm; + saplingTree.append(note_commitment); - const CBlock* pblock {pblockIn}; - CBlock block; - if (!pblock) { - ReadBlockFromDisk(block, pindex); - pblock = █ - } + // Increment existing witnesses + for (std::pair& wtxItem : mapWallet) { + ::AppendNoteCommitment(wtxItem.second.mapSaplingNoteData, pindex->nHeight, nWitnessCacheSize, note_commitment); + } - for (const CTransaction& tx : pblock->vtx) { - auto hash = tx.GetHash(); - bool txIsOurs = mapWallet.count(hash); - for (size_t i = 0; i < tx.vjoinsplit.size(); i++) { - const JSDescription& jsdesc = tx.vjoinsplit[i]; - for (uint8_t j = 0; j < jsdesc.commitments.size(); j++) { - const uint256& note_commitment = jsdesc.commitments[j]; - tree.append(note_commitment); - - // Increment existing witnesses - for (std::pair& wtxItem : mapWallet) { - for (mapNoteData_t::value_type& item : wtxItem.second.mapNoteData) { - CNoteData* nd = &(item.second); - if (nd->witnessHeight < pindex->nHeight && - nd->witnesses.size() > 0) { - // Check the validity of the cache - // See earlier comment about validity. - assert(nWitnessCacheSize >= nd->witnesses.size()); - nd->witnesses.front().append(note_commitment); - } - } - } - - // If this is our note, witness it - if (txIsOurs) { - JSOutPoint jsoutpt {hash, i, j}; - if (mapWallet[hash].mapNoteData.count(jsoutpt) && - mapWallet[hash].mapNoteData[jsoutpt].witnessHeight < pindex->nHeight) { - CNoteData* nd = &(mapWallet[hash].mapNoteData[jsoutpt]); - if (nd->witnesses.size() > 0) { - // We think this can happen because we write out the - // witness cache state after every block increment or - // decrement, but the block index itself is written in - // batches. So if the node crashes in between these two - // operations, it is possible for IncrementNoteWitnesses - // to be called again on previously-cached blocks. This - // doesn't affect existing cached notes because of the - // CNoteData::witnessHeight checks. See #1378 for details. - LogPrintf("Inconsistent witness cache state found for %s\n- Cache size: %d\n- Top (height %d): %s\n- New (height %d): %s\n", - jsoutpt.ToString(), nd->witnesses.size(), - nd->witnessHeight, - nd->witnesses.front().root().GetHex(), - pindex->nHeight, - tree.witness().root().GetHex()); - nd->witnesses.clear(); - } - nd->witnesses.push_front(tree.witness()); - // Set height to one less than pindex so it gets incremented - nd->witnessHeight = pindex->nHeight - 1; - // Check the validity of the cache - assert(nWitnessCacheSize >= nd->witnesses.size()); - } - } - } + // If this is our note, witness it + if (txIsOurs) { + SaplingOutPoint outPoint {hash, i}; + ::WitnessNoteIfMine(mapWallet[hash].mapSaplingNoteData, pindex->nHeight, nWitnessCacheSize, outPoint, saplingTree.witness()); } } + } - // Update witness heights - for (std::pair& wtxItem : mapWallet) { - for (mapNoteData_t::value_type& item : wtxItem.second.mapNoteData) { - CNoteData* nd = &(item.second); - if (nd->witnessHeight < pindex->nHeight) { - nd->witnessHeight = pindex->nHeight; - // Check the validity of the cache - // See earlier comment about validity. - assert(nWitnessCacheSize >= nd->witnesses.size()); - } + // Update witness heights + for (std::pair& wtxItem : mapWallet) { + ::UpdateWitnessHeights(wtxItem.second.mapSproutNoteData, pindex->nHeight, nWitnessCacheSize); + ::UpdateWitnessHeights(wtxItem.second.mapSaplingNoteData, pindex->nHeight, nWitnessCacheSize); + } + + // For performance reasons, we write out the witness cache in + // CWallet::SetBestChain() (which also ensures that overall consistency + // of the wallet.dat is maintained). +} + +template +void DecrementNoteWitnesses(NoteDataMap& noteDataMap, int indexHeight, int64_t nWitnessCacheSize) +{ + for (auto& item : noteDataMap) { + auto* nd = &(item.second); + // Only decrement witnesses that are not above the current height + if (nd->witnessHeight <= indexHeight) { + // Check the validity of the cache + // See comment below (this would be invalid if there were a + // prior decrement). + assert(nWitnessCacheSize >= nd->witnesses.size()); + // Witnesses being decremented should always be either -1 + // (never incremented or decremented) or equal to the height + // of the block being removed (indexHeight) + assert((nd->witnessHeight == -1) || (nd->witnessHeight == indexHeight)); + if (nd->witnesses.size() > 0) { + nd->witnesses.pop_front(); } + // indexHeight is the height of the block being removed, so + // the new witness cache height is one below it. + nd->witnessHeight = indexHeight - 1; + } + // Check the validity of the cache + // Technically if there are notes witnessed above the current + // height, their cache will now be invalid (relative to the new + // value of nWitnessCacheSize). However, this would only occur + // during a reindex, and by the time the reindex reaches the tip + // of the chain again, the existing witness caches will be valid + // again. + // We don't set nWitnessCacheSize to zero at the start of the + // reindex because the on-disk blocks had already resulted in a + // chain that didn't trigger the assertion below. + if (nd->witnessHeight < indexHeight) { + // Subtract 1 to compare to what nWitnessCacheSize will be after + // decrementing. + assert((nWitnessCacheSize - 1) >= nd->witnesses.size()); } - - // For performance reasons, we write out the witness cache in - // CWallet::SetBestChain() (which also ensures that overall consistency - // of the wallet.dat is maintained). } } void CWallet::DecrementNoteWitnesses(const CBlockIndex* pindex) { - { - LOCK(cs_wallet); - for (std::pair& wtxItem : mapWallet) { - for (mapNoteData_t::value_type& item : wtxItem.second.mapNoteData) { - CNoteData* nd = &(item.second); - // Only increment witnesses that are not above the current height - if (nd->witnessHeight <= pindex->nHeight) { - // Check the validity of the cache - // See comment below (this would be invalid if there was a - // prior decrement). - assert(nWitnessCacheSize >= nd->witnesses.size()); - // Witnesses being decremented should always be either -1 - // (never incremented or decremented) or equal to pindex - assert((nd->witnessHeight == -1) || - (nd->witnessHeight == pindex->nHeight)); - if (nd->witnesses.size() > 0) { - nd->witnesses.pop_front(); - } - // pindex is the block being removed, so the new witness cache - // height is one below it. - nd->witnessHeight = pindex->nHeight - 1; - } - } - } - nWitnessCacheSize -= 1; - for (std::pair& wtxItem : mapWallet) { - for (mapNoteData_t::value_type& item : wtxItem.second.mapNoteData) { - CNoteData* nd = &(item.second); - // Check the validity of the cache - // Technically if there are notes witnessed above the current - // height, their cache will now be invalid (relative to the new - // value of nWitnessCacheSize). However, this would only occur - // during a reindex, and by the time the reindex reaches the tip - // of the chain again, the existing witness caches will be valid - // again. - // We don't set nWitnessCacheSize to zero at the start of the - // reindex because the on-disk blocks had already resulted in a - // chain that didn't trigger the assertion below. - if (nd->witnessHeight < pindex->nHeight) { - assert(nWitnessCacheSize >= nd->witnesses.size()); - } - } - } - // TODO: If nWitnessCache is zero, we need to regenerate the caches (#1302) - assert(nWitnessCacheSize > 0); - - // For performance reasons, we write out the witness cache in - // CWallet::SetBestChain() (which also ensures that overall consistency - // of the wallet.dat is maintained). + LOCK(cs_wallet); + for (std::pair& wtxItem : mapWallet) { + ::DecrementNoteWitnesses(wtxItem.second.mapSproutNoteData, pindex->nHeight, nWitnessCacheSize); + ::DecrementNoteWitnesses(wtxItem.second.mapSaplingNoteData, pindex->nHeight, nWitnessCacheSize); } + nWitnessCacheSize -= 1; + // TODO: If nWitnessCache is zero, we need to regenerate the caches (#1302) + assert(nWitnessCacheSize > 0); + + // For performance reasons, we write out the witness cache in + // CWallet::SetBestChain() (which also ensures that overall consistency + // of the wallet.dat is maintained). } bool CWallet::EncryptWallet(const SecureString& strWalletPassphrase) @@ -1075,7 +1125,7 @@ bool CWallet::UpdateNullifierNoteMap() ZCNoteDecryption dec; for (std::pair& wtxItem : mapWallet) { - for (mapNoteData_t::value_type& item : wtxItem.second.mapNoteData) { + for (mapSproutNoteData_t::value_type& item : wtxItem.second.mapSproutNoteData) { if (!item.second.nullifier) { if (GetNoteDecryptor(item.second.address, dec)) { auto i = item.first.js; @@ -1103,7 +1153,7 @@ void CWallet::UpdateNullifierNoteMapWithTx(const CWalletTx& wtx) { { LOCK(cs_wallet); - for (const mapNoteData_t::value_type& item : wtx.mapNoteData) { + for (const mapSproutNoteData_t::value_type& item : wtx.mapSproutNoteData) { if (item.second.nullifier) { mapNullifiersToNotes[*item.second.nullifier] = item.first; } @@ -1238,12 +1288,12 @@ bool CWallet::AddToWallet(const CWalletTx& wtxIn, bool fFromLoadWallet, CWalletD bool CWallet::UpdatedNoteData(const CWalletTx& wtxIn, CWalletTx& wtx) { - if (wtxIn.mapNoteData.empty() || wtxIn.mapNoteData == wtx.mapNoteData) { + if (wtxIn.mapSproutNoteData.empty() || wtxIn.mapSproutNoteData == wtx.mapSproutNoteData) { return false; } - auto tmp = wtxIn.mapNoteData; + auto tmp = wtxIn.mapSproutNoteData; // Ensure we keep any cached witnesses we may already have - for (const std::pair nd : wtx.mapNoteData) { + for (const std::pair nd : wtx.mapSproutNoteData) { if (tmp.count(nd.first) && nd.second.witnesses.size() > 0) { tmp.at(nd.first).witnesses.assign( nd.second.witnesses.cbegin(), nd.second.witnesses.cend()); @@ -1251,7 +1301,7 @@ bool CWallet::UpdatedNoteData(const CWalletTx& wtxIn, CWalletTx& wtx) tmp.at(nd.first).witnessHeight = nd.second.witnessHeight; } // Now copy over the updated note data - wtx.mapNoteData = tmp; + wtx.mapSproutNoteData = tmp; return true; } @@ -1272,8 +1322,9 @@ bool CWallet::AddToWalletIfInvolvingMe(const CTransaction& tx, const CBlock* pbl CWalletTx wtx(this,tx); if (noteData.size() > 0) { - wtx.SetNoteData(noteData); + wtx.SetSproutNoteData(noteData); } + // TODO: Sapling note data // Get merkle branch if transaction was found in a block if (pblock) @@ -1365,14 +1416,14 @@ boost::optional CWallet::GetNoteNullifier(const JSDescription& jsdesc, * * It should never be necessary to call this method with a CWalletTx, because * the result of FindMyNotes (for the addresses available at the time) will - * already have been cached in CWalletTx.mapNoteData. + * already have been cached in CWalletTx.mapSproutNoteData. */ -mapNoteData_t CWallet::FindMyNotes(const CTransaction& tx) const +mapSproutNoteData_t CWallet::FindMyNotes(const CTransaction& tx) const { LOCK(cs_SpendingKeyStore); uint256 hash = tx.GetHash(); - mapNoteData_t noteData; + mapSproutNoteData_t noteData; for (size_t i = 0; i < tx.vjoinsplit.size(); i++) { auto hSig = tx.vjoinsplit[i].h_sig(*pzcashParams, tx.joinSplitPubKey); for (uint8_t j = 0; j < tx.vjoinsplit[i].ciphertexts.size(); j++) { @@ -1386,10 +1437,10 @@ mapNoteData_t CWallet::FindMyNotes(const CTransaction& tx) const item.second, hSig, j); if (nullifier) { - CNoteData nd {address, *nullifier}; + SproutNoteData nd {address, *nullifier}; noteData.insert(std::make_pair(jsoutpt, nd)); } else { - CNoteData nd {address}; + SproutNoteData nd {address}; noteData.insert(std::make_pair(jsoutpt, nd)); } break; @@ -1418,32 +1469,57 @@ bool CWallet::IsFromMe(const uint256& nullifier) const return false; } -void CWallet::GetNoteWitnesses(std::vector notes, - std::vector>& witnesses, - uint256 &final_anchor) +void CWallet::GetSproutNoteWitnesses(std::vector notes, + std::vector>& witnesses, + uint256 &final_anchor) { - { - LOCK(cs_wallet); - witnesses.resize(notes.size()); - boost::optional rt; - int i = 0; - for (JSOutPoint note : notes) { - if (mapWallet.count(note.hash) && - mapWallet[note.hash].mapNoteData.count(note) && - mapWallet[note.hash].mapNoteData[note].witnesses.size() > 0) { - witnesses[i] = mapWallet[note.hash].mapNoteData[note].witnesses.front(); - if (!rt) { - rt = witnesses[i]->root(); - } else { - assert(*rt == witnesses[i]->root()); - } + LOCK(cs_wallet); + witnesses.resize(notes.size()); + boost::optional rt; + int i = 0; + for (JSOutPoint note : notes) { + if (mapWallet.count(note.hash) && + mapWallet[note.hash].mapSproutNoteData.count(note) && + mapWallet[note.hash].mapSproutNoteData[note].witnesses.size() > 0) { + witnesses[i] = mapWallet[note.hash].mapSproutNoteData[note].witnesses.front(); + if (!rt) { + rt = witnesses[i]->root(); + } else { + assert(*rt == witnesses[i]->root()); } - i++; } - // All returned witnesses have the same anchor - if (rt) { - final_anchor = *rt; + i++; + } + // All returned witnesses have the same anchor + if (rt) { + final_anchor = *rt; + } +} + +void CWallet::GetSaplingNoteWitnesses(std::vector notes, + std::vector>& witnesses, + uint256 &final_anchor) +{ + LOCK(cs_wallet); + witnesses.resize(notes.size()); + boost::optional rt; + int i = 0; + for (SaplingOutPoint note : notes) { + if (mapWallet.count(note.hash) && + mapWallet[note.hash].mapSaplingNoteData.count(note) && + mapWallet[note.hash].mapSaplingNoteData[note].witnesses.size() > 0) { + witnesses[i] = mapWallet[note.hash].mapSaplingNoteData[note].witnesses.front(); + if (!rt) { + rt = witnesses[i]->root(); + } else { + assert(*rt == witnesses[i]->root()); + } } + i++; + } + // All returned witnesses have the same anchor + if (rt) { + final_anchor = *rt; } } @@ -1578,18 +1654,30 @@ CAmount CWallet::GetChange(const CTransaction& tx) const return nChange; } -void CWalletTx::SetNoteData(mapNoteData_t ¬eData) +void CWalletTx::SetSproutNoteData(mapSproutNoteData_t ¬eData) { - mapNoteData.clear(); - for (const std::pair nd : noteData) { + mapSproutNoteData.clear(); + for (const std::pair nd : noteData) { if (nd.first.js < vjoinsplit.size() && nd.first.n < vjoinsplit[nd.first.js].ciphertexts.size()) { // Store the address and nullifier for the Note - mapNoteData[nd.first] = nd.second; + mapSproutNoteData[nd.first] = nd.second; } else { // If FindMyNotes() was used to obtain noteData, // this should never happen - throw std::logic_error("CWalletTx::SetNoteData(): Invalid note"); + throw std::logic_error("CWalletTx::SetSproutNoteData(): Invalid note"); + } + } +} + +void CWalletTx::SetSaplingNoteData(mapSaplingNoteData_t ¬eData) +{ + mapSaplingNoteData.clear(); + for (const std::pair nd : noteData) { + if (nd.first.n < vShieldedOutput.size()) { + mapSaplingNoteData[nd.first] = nd.second; + } else { + throw std::logic_error("CWalletTx::SetSaplingNoteData(): Invalid note"); } } } @@ -1691,7 +1779,7 @@ void CWalletTx::GetAmounts(list& listReceived, // Check output side if (!fMyJSDesc) { - for (const std::pair nd : this->mapNoteData) { + for (const std::pair nd : this->mapSproutNoteData) { if (nd.first.js < vjoinsplit.size() && nd.first.n < vjoinsplit[nd.first.js].ciphertexts.size()) { fMyJSDesc = true; break; @@ -1892,12 +1980,18 @@ int CWallet::ScanForWalletTransactions(CBlockIndex* pindexStart, bool fUpdate) ret++; } - ZCIncrementalMerkleTree tree; + ZCIncrementalMerkleTree sproutTree; + ZCSaplingIncrementalMerkleTree saplingTree; // This should never fail: we should always be able to get the tree // state on the path to the tip of our chain - assert(pcoinsTip->GetSproutAnchorAt(pindex->hashSproutAnchor, tree)); + assert(pcoinsTip->GetSproutAnchorAt(pindex->hashSproutAnchor, sproutTree)); + if (pindex->pprev) { + if (NetworkUpgradeActive(pindex->pprev->nHeight, Params().GetConsensus(), Consensus::UPGRADE_SAPLING)) { + assert(pcoinsTip->GetSaplingAnchorAt(pindex->pprev->hashFinalSaplingRoot, saplingTree)); + } + } // Increment note witness caches - IncrementNoteWitnesses(pindex, &block, tree); + IncrementNoteWitnesses(pindex, &block, sproutTree, saplingTree); pindex = chainActive.Next(pindex); if (GetTime() >= nNow + 60) { @@ -3821,13 +3915,13 @@ void CWallet::GetFilteredNotes( continue; } - if (wtx.mapNoteData.size() == 0) { + if (wtx.mapSproutNoteData.size() == 0) { continue; } - for (auto & pair : wtx.mapNoteData) { + for (auto & pair : wtx.mapSproutNoteData) { JSOutPoint jsop = pair.first; - CNoteData nd = pair.second; + SproutNoteData nd = pair.second; SproutPaymentAddress pa = nd.address; // skip notes which belong to a different payment address in the wallet @@ -3902,13 +3996,13 @@ void CWallet::GetUnspentFilteredNotes( continue; } - if (wtx.mapNoteData.size() == 0) { + if (wtx.mapSproutNoteData.size() == 0) { continue; } - for (auto & pair : wtx.mapNoteData) { + for (auto & pair : wtx.mapSproutNoteData) { JSOutPoint jsop = pair.first; - CNoteData nd = pair.second; + SproutNoteData nd = pair.second; SproutPaymentAddress pa = nd.address; // skip notes which belong to a different payment address in the wallet diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index 4aee8d01f..43fdc5d77 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -197,14 +197,14 @@ public: std::string ToString() const; }; -class CNoteData +class SproutNoteData { public: libzcash::SproutPaymentAddress address; /** * Cached note nullifier. May not be set if the wallet was not unlocked when - * this was CNoteData was created. If not set, we always assume that the + * this was SproutNoteData was created. If not set, we always assume that the * note has not been spent. * * It's okay to cache the nullifier in the wallet, because we are storing @@ -225,7 +225,7 @@ public: /** * Block height corresponding to the most current witness. * - * When we first create a CNoteData in CWallet::FindMyNotes, this is set to + * When we first create a SproutNoteData in CWallet::FindMyNotes, this is set to * -1 as a placeholder. The next time CWallet::ChainTip is called, we can * determine what height the witness cache for this note is valid for (even * if no witnesses were cached), and so can set the correct value in @@ -233,10 +233,10 @@ public: */ int witnessHeight; - CNoteData() : address(), nullifier(), witnessHeight {-1} { } - CNoteData(libzcash::SproutPaymentAddress a) : + SproutNoteData() : address(), nullifier(), witnessHeight {-1} { } + SproutNoteData(libzcash::SproutPaymentAddress a) : address {a}, nullifier(), witnessHeight {-1} { } - CNoteData(libzcash::SproutPaymentAddress a, uint256 n) : + SproutNoteData(libzcash::SproutPaymentAddress a, uint256 n) : address {a}, nullifier {n}, witnessHeight {-1} { } ADD_SERIALIZE_METHODS; @@ -249,21 +249,35 @@ public: READWRITE(witnessHeight); } - friend bool operator<(const CNoteData& a, const CNoteData& b) { + friend bool operator<(const SproutNoteData& a, const SproutNoteData& b) { return (a.address < b.address || (a.address == b.address && a.nullifier < b.nullifier)); } - friend bool operator==(const CNoteData& a, const CNoteData& b) { + friend bool operator==(const SproutNoteData& a, const SproutNoteData& b) { return (a.address == b.address && a.nullifier == b.nullifier); } - friend bool operator!=(const CNoteData& a, const CNoteData& b) { + friend bool operator!=(const SproutNoteData& a, const SproutNoteData& b) { return !(a == b); } }; -typedef std::map mapNoteData_t; +class SaplingNoteData +{ +public: + /** + * We initialize the hight to -1 for the same reason as we do in SproutNoteData. + * See the comment in that class for a full description. + */ + SaplingNoteData() : witnessHeight {-1} { } + + std::list witnesses; + int witnessHeight; +}; + +typedef std::map mapSproutNoteData_t; +typedef std::map mapSaplingNoteData_t; /** Decrypted note and its location in a transaction. */ struct CSproutNotePlaintextEntry @@ -350,7 +364,8 @@ private: public: mapValue_t mapValue; - mapNoteData_t mapNoteData; + mapSproutNoteData_t mapSproutNoteData; + mapSaplingNoteData_t mapSaplingNoteData; std::vector > vOrderForm; unsigned int fTimeReceivedIsTxTime; unsigned int nTimeReceived; //! time received by this node @@ -403,7 +418,8 @@ public: { pwallet = pwalletIn; mapValue.clear(); - mapNoteData.clear(); + mapSproutNoteData.clear(); + mapSaplingNoteData.clear(); vOrderForm.clear(); fTimeReceivedIsTxTime = false; nTimeReceived = 0; @@ -453,12 +469,14 @@ public: std::vector vUnused; //! Used to be vtxPrev READWRITE(vUnused); READWRITE(mapValue); - READWRITE(mapNoteData); + READWRITE(mapSproutNoteData); READWRITE(vOrderForm); READWRITE(fTimeReceivedIsTxTime); READWRITE(nTimeReceived); READWRITE(fFromMe); READWRITE(fSpent); + // TODO: + //READWRITE(mapSaplingNoteData); if (ser_action.ForRead()) { @@ -495,7 +513,8 @@ public: MarkDirty(); } - void SetNoteData(mapNoteData_t ¬eData); + void SetSproutNoteData(mapSproutNoteData_t ¬eData); + void SetSaplingNoteData(mapSaplingNoteData_t ¬eData); //! filter decides which addresses will count towards the debit CAmount GetDebit(const isminefilter& filter) const; @@ -718,7 +737,8 @@ protected: */ void IncrementNoteWitnesses(const CBlockIndex* pindex, const CBlock* pblock, - ZCIncrementalMerkleTree& tree); + ZCIncrementalMerkleTree& sproutTree, + ZCSaplingIncrementalMerkleTree& saplingTree); /** * pindex is the old tip being disconnected. */ @@ -842,7 +862,7 @@ public: * * - GetFilteredNotes can't filter out spent notes. * - * - Per the comment in CNoteData, we assume that if we don't have a + * - Per the comment in SproutNoteData, we assume that if we don't have a * cached nullifier, the note is not spent. * * Another more problematic implication is that the wallet can fail to @@ -1053,12 +1073,16 @@ public: const ZCNoteDecryption& dec, const uint256& hSig, uint8_t n) const; - mapNoteData_t FindMyNotes(const CTransaction& tx) const; + mapSproutNoteData_t FindMyNotes(const CTransaction& tx) const; bool IsFromMe(const uint256& nullifier) const; - void GetNoteWitnesses( + void GetSproutNoteWitnesses( std::vector notes, std::vector>& witnesses, uint256 &final_anchor); + void GetSaplingNoteWitnesses( + std::vector notes, + std::vector>& witnesses, + uint256 &final_anchor); isminetype IsMine(const CTxIn& txin) const; CAmount GetDebit(const CTxIn& txin, const isminefilter& filter) const; @@ -1072,7 +1096,7 @@ public: CAmount GetDebit(const CTransaction& tx, const isminefilter& filter) const; CAmount GetCredit(const CTransaction& tx, const isminefilter& filter) const; CAmount GetChange(const CTransaction& tx) const; - void ChainTip(const CBlockIndex *pindex, const CBlock *pblock, ZCIncrementalMerkleTree tree, bool added); + void ChainTip(const CBlockIndex *pindex, const CBlock *pblock, ZCIncrementalMerkleTree sproutTree, ZCSaplingIncrementalMerkleTree saplingTree, bool added); /** Saves witness caches and best block locator to disk. */ void SetBestChain(const CBlockLocator& loc); std::set> GetNullifiersForAddresses(const std::set & addresses); diff --git a/src/zcbenchmarks.cpp b/src/zcbenchmarks.cpp index b7fb3276e..fea10000f 100644 --- a/src/zcbenchmarks.cpp +++ b/src/zcbenchmarks.cpp @@ -298,7 +298,8 @@ double benchmark_try_decrypt_notes(size_t nAddrs) double benchmark_increment_note_witnesses(size_t nTxs) { CWallet wallet; - ZCIncrementalMerkleTree tree; + ZCIncrementalMerkleTree sproutTree; + ZCSaplingIncrementalMerkleTree saplingTree; auto sk = libzcash::SproutSpendingKey::random(); wallet.AddSpendingKey(sk); @@ -310,12 +311,12 @@ double benchmark_increment_note_witnesses(size_t nTxs) auto note = GetNote(*pzcashParams, sk, wtx, 0, 1); auto nullifier = note.nullifier(sk); - mapNoteData_t noteData; + mapSproutNoteData_t noteData; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; - CNoteData nd {sk.address(), nullifier}; + SproutNoteData nd {sk.address(), nullifier}; noteData[jsoutpt] = nd; - wtx.SetNoteData(noteData); + wtx.SetSproutNoteData(noteData); wallet.AddToWallet(wtx, true, NULL); block1.vtx.push_back(wtx); } @@ -323,7 +324,7 @@ double benchmark_increment_note_witnesses(size_t nTxs) index1.nHeight = 1; // Increment to get transactions witnessed - wallet.ChainTip(&index1, &block1, tree, true); + wallet.ChainTip(&index1, &block1, sproutTree, saplingTree, true); // Second block CBlock block2; @@ -333,12 +334,12 @@ double benchmark_increment_note_witnesses(size_t nTxs) auto note = GetNote(*pzcashParams, sk, wtx, 0, 1); auto nullifier = note.nullifier(sk); - mapNoteData_t noteData; + mapSproutNoteData_t noteData; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; - CNoteData nd {sk.address(), nullifier}; + SproutNoteData nd {sk.address(), nullifier}; noteData[jsoutpt] = nd; - wtx.SetNoteData(noteData); + wtx.SetSproutNoteData(noteData); wallet.AddToWallet(wtx, true, NULL); block2.vtx.push_back(wtx); } @@ -347,7 +348,7 @@ double benchmark_increment_note_witnesses(size_t nTxs) struct timeval tv_start; timer_start(tv_start); - wallet.ChainTip(&index2, &block2, tree, true); + wallet.ChainTip(&index2, &block2, sproutTree, saplingTree, true); return timer_stop(tv_start); }