Auto merge of #3353 - Eirik0:3062-cache-sapling-witnesses, r=bitcartel

Cache Sapling witnesses in the wallet

Closes #3062

I have not update the tests in test_wallet.cpp. Also, there are several other methods in the wallet that have to do with witnesses and note data which will need to be updated, but this PR focuses on IncrementNoteWitnesses and DecrementNoteWitnesses.
This commit is contained in:
Homu 2018-07-25 23:40:11 -07:00
commit 3eefe12c79
13 changed files with 701 additions and 459 deletions

View File

@ -2824,15 +2824,17 @@ bool static DisconnectTip(CValidationState &state, bool fBare = false) {
// Update chainActive and related variables. // Update chainActive and related variables.
UpdateTip(pindexDelete->pprev); UpdateTip(pindexDelete->pprev);
// Get the current commitment tree // Get the current commitment tree
ZCIncrementalMerkleTree newTree; ZCIncrementalMerkleTree newSproutTree;
assert(pcoinsTip->GetSproutAnchorAt(pcoinsTip->GetBestAnchor(SPROUT), newTree)); ZCSaplingIncrementalMerkleTree newSaplingTree;
assert(pcoinsTip->GetSproutAnchorAt(pcoinsTip->GetBestAnchor(SPROUT), newSproutTree));
assert(pcoinsTip->GetSaplingAnchorAt(pcoinsTip->GetBestAnchor(SAPLING), newSaplingTree));
// Let wallets know transactions went from 1-confirmed to // Let wallets know transactions went from 1-confirmed to
// 0-confirmed or conflicted: // 0-confirmed or conflicted:
BOOST_FOREACH(const CTransaction &tx, block.vtx) { BOOST_FOREACH(const CTransaction &tx, block.vtx) {
SyncWithWallets(tx, NULL); SyncWithWallets(tx, NULL);
} }
// Update cached incremental witnesses // Update cached incremental witnesses
GetMainSignals().ChainTip(pindexDelete, &block, newTree, false); GetMainSignals().ChainTip(pindexDelete, &block, newSproutTree, newSaplingTree, false);
return true; return true;
} }
@ -2858,8 +2860,10 @@ bool static ConnectTip(CValidationState &state, CBlockIndex *pindexNew, CBlock *
pblock = █ pblock = █
} }
// Get the current commitment tree // Get the current commitment tree
ZCIncrementalMerkleTree oldTree; ZCIncrementalMerkleTree oldSproutTree;
assert(pcoinsTip->GetSproutAnchorAt(pcoinsTip->GetBestAnchor(SPROUT), oldTree)); ZCSaplingIncrementalMerkleTree oldSaplingTree;
assert(pcoinsTip->GetSproutAnchorAt(pcoinsTip->GetBestAnchor(SPROUT), oldSproutTree));
assert(pcoinsTip->GetSaplingAnchorAt(pcoinsTip->GetBestAnchor(SAPLING), oldSaplingTree));
// Apply the block atomically to the chain state. // Apply the block atomically to the chain state.
int64_t nTime2 = GetTimeMicros(); nTimeReadFromDisk += nTime2 - nTime1; int64_t nTime2 = GetTimeMicros(); nTimeReadFromDisk += nTime2 - nTime1;
int64_t nTime3; int64_t nTime3;
@ -2904,7 +2908,7 @@ bool static ConnectTip(CValidationState &state, CBlockIndex *pindexNew, CBlock *
SyncWithWallets(tx, pblock); SyncWithWallets(tx, pblock);
} }
// Update cached incremental witnesses // Update cached incremental witnesses
GetMainSignals().ChainTip(pindexNew, pblock, oldTree, true); GetMainSignals().ChainTip(pindexNew, pblock, oldSproutTree, oldSaplingTree, true);
EnforceNodeDeprecation(pindexNew->nHeight); EnforceNodeDeprecation(pindexNew->nHeight);

View File

@ -149,6 +149,11 @@ std::string COutPoint::ToString() const
return strprintf("COutPoint(%s, %u)", hash.ToString().substr(0,10), n); return strprintf("COutPoint(%s, %u)", hash.ToString().substr(0,10), n);
} }
std::string SaplingOutPoint::ToString() const
{
return strprintf("SaplingOutPoint(%s, %u)", hash.ToString().substr(0, 10), n);
}
CTxIn::CTxIn(COutPoint prevoutIn, CScript scriptSigIn, uint32_t nSequenceIn) CTxIn::CTxIn(COutPoint prevoutIn, CScript scriptSigIn, uint32_t nSequenceIn)
{ {
prevout = prevoutIn; prevout = prevoutIn;

View File

@ -309,15 +309,14 @@ public:
} }
}; };
/** An outpoint - a combination of a transaction hash and an index n into its vout */ class BaseOutPoint
class COutPoint
{ {
public: public:
uint256 hash; uint256 hash;
uint32_t n; uint32_t n;
COutPoint() { SetNull(); } BaseOutPoint() { SetNull(); }
COutPoint(uint256 hashIn, uint32_t nIn) { hash = hashIn; n = nIn; } BaseOutPoint(uint256 hashIn, uint32_t nIn) { hash = hashIn; n = nIn; }
ADD_SERIALIZE_METHODS; ADD_SERIALIZE_METHODS;
@ -330,21 +329,38 @@ public:
void SetNull() { hash.SetNull(); n = (uint32_t) -1; } void SetNull() { hash.SetNull(); n = (uint32_t) -1; }
bool IsNull() const { return (hash.IsNull() && n == (uint32_t) -1); } bool IsNull() const { return (hash.IsNull() && n == (uint32_t) -1); }
friend bool operator<(const COutPoint& a, const COutPoint& b) friend bool operator<(const BaseOutPoint& a, const BaseOutPoint& b)
{ {
return (a.hash < b.hash || (a.hash == b.hash && a.n < b.n)); return (a.hash < b.hash || (a.hash == b.hash && a.n < b.n));
} }
friend bool operator==(const COutPoint& a, const COutPoint& b) friend bool operator==(const BaseOutPoint& a, const BaseOutPoint& b)
{ {
return (a.hash == b.hash && a.n == b.n); return (a.hash == b.hash && a.n == b.n);
} }
friend bool operator!=(const COutPoint& a, const COutPoint& b) friend bool operator!=(const BaseOutPoint& a, const BaseOutPoint& b)
{ {
return !(a == b); return !(a == b);
} }
};
/** An outpoint - a combination of a transaction hash and an index n into its vout */
class COutPoint : public BaseOutPoint
{
public:
COutPoint() : BaseOutPoint() {};
COutPoint(uint256 hashIn, uint32_t nIn) : BaseOutPoint(hashIn, nIn) {};
std::string ToString() const;
};
/** An outpoint - a combination of a transaction hash and an index n into its sapling
* output description (vShieldedOutput) */
class SaplingOutPoint : public BaseOutPoint
{
public:
SaplingOutPoint() : BaseOutPoint() {};
SaplingOutPoint(uint256 hashIn, uint32_t nIn) : BaseOutPoint(hashIn, nIn) {};
std::string ToString() const; std::string ToString() const;
}; };

View File

@ -10,9 +10,10 @@
CWalletTx GetValidReceive(ZCJoinSplit& params, CWalletTx GetValidReceive(ZCJoinSplit& params,
const libzcash::SproutSpendingKey& sk, CAmount value, const libzcash::SproutSpendingKey& sk, CAmount value,
bool randomInputs) { bool randomInputs,
int32_t version /* = 2 */) {
CMutableTransaction mtx; CMutableTransaction mtx;
mtx.nVersion = 2; // Enable JoinSplits mtx.nVersion = version;
mtx.vin.resize(2); mtx.vin.resize(2);
if (randomInputs) { if (randomInputs) {
mtx.vin[0].prevout.hash = GetRandHash(); mtx.vin[0].prevout.hash = GetRandHash();
@ -46,6 +47,12 @@ CWalletTx GetValidReceive(ZCJoinSplit& params,
inputs, outputs, 2*value, 0, false}; inputs, outputs, 2*value, 0, false};
mtx.vjoinsplit.push_back(jsdesc); mtx.vjoinsplit.push_back(jsdesc);
if (version >= 4) {
// Shielded Output
OutputDescription od;
mtx.vShieldedOutput.push_back(od);
}
// Empty output script. // Empty output script.
uint32_t consensusBranchId = SPROUT_BRANCH_ID; uint32_t consensusBranchId = SPROUT_BRANCH_ID;
CScript scriptCode; CScript scriptCode;

View File

@ -9,7 +9,8 @@
CWalletTx GetValidReceive(ZCJoinSplit& params, CWalletTx GetValidReceive(ZCJoinSplit& params,
const libzcash::SproutSpendingKey& sk, CAmount value, const libzcash::SproutSpendingKey& sk, CAmount value,
bool randomInputs); bool randomInputs,
int32_t version = 2);
libzcash::SproutNote GetNote(ZCJoinSplit& params, libzcash::SproutNote GetNote(ZCJoinSplit& params,
const libzcash::SproutSpendingKey& sk, const libzcash::SproutSpendingKey& sk,
const CTransaction& tx, size_t js, size_t n); const CTransaction& tx, size_t js, size_t n);

View File

@ -17,7 +17,7 @@ void RegisterValidationInterface(CValidationInterface* pwalletIn) {
g_signals.SyncTransaction.connect(boost::bind(&CValidationInterface::SyncTransaction, pwalletIn, _1, _2)); g_signals.SyncTransaction.connect(boost::bind(&CValidationInterface::SyncTransaction, pwalletIn, _1, _2));
g_signals.EraseTransaction.connect(boost::bind(&CValidationInterface::EraseFromWallet, pwalletIn, _1)); g_signals.EraseTransaction.connect(boost::bind(&CValidationInterface::EraseFromWallet, pwalletIn, _1));
g_signals.UpdatedTransaction.connect(boost::bind(&CValidationInterface::UpdatedTransaction, pwalletIn, _1)); g_signals.UpdatedTransaction.connect(boost::bind(&CValidationInterface::UpdatedTransaction, pwalletIn, _1));
g_signals.ChainTip.connect(boost::bind(&CValidationInterface::ChainTip, pwalletIn, _1, _2, _3, _4)); g_signals.ChainTip.connect(boost::bind(&CValidationInterface::ChainTip, pwalletIn, _1, _2, _3, _4, _5));
g_signals.SetBestChain.connect(boost::bind(&CValidationInterface::SetBestChain, pwalletIn, _1)); g_signals.SetBestChain.connect(boost::bind(&CValidationInterface::SetBestChain, pwalletIn, _1));
g_signals.Inventory.connect(boost::bind(&CValidationInterface::Inventory, pwalletIn, _1)); g_signals.Inventory.connect(boost::bind(&CValidationInterface::Inventory, pwalletIn, _1));
g_signals.Broadcast.connect(boost::bind(&CValidationInterface::ResendWalletTransactions, pwalletIn, _1)); g_signals.Broadcast.connect(boost::bind(&CValidationInterface::ResendWalletTransactions, pwalletIn, _1));
@ -28,7 +28,7 @@ void UnregisterValidationInterface(CValidationInterface* pwalletIn) {
g_signals.BlockChecked.disconnect(boost::bind(&CValidationInterface::BlockChecked, pwalletIn, _1, _2)); g_signals.BlockChecked.disconnect(boost::bind(&CValidationInterface::BlockChecked, pwalletIn, _1, _2));
g_signals.Broadcast.disconnect(boost::bind(&CValidationInterface::ResendWalletTransactions, pwalletIn, _1)); g_signals.Broadcast.disconnect(boost::bind(&CValidationInterface::ResendWalletTransactions, pwalletIn, _1));
g_signals.Inventory.disconnect(boost::bind(&CValidationInterface::Inventory, pwalletIn, _1)); g_signals.Inventory.disconnect(boost::bind(&CValidationInterface::Inventory, pwalletIn, _1));
g_signals.ChainTip.disconnect(boost::bind(&CValidationInterface::ChainTip, pwalletIn, _1, _2, _3, _4)); g_signals.ChainTip.disconnect(boost::bind(&CValidationInterface::ChainTip, pwalletIn, _1, _2, _3, _4, _5));
g_signals.SetBestChain.disconnect(boost::bind(&CValidationInterface::SetBestChain, pwalletIn, _1)); g_signals.SetBestChain.disconnect(boost::bind(&CValidationInterface::SetBestChain, pwalletIn, _1));
g_signals.UpdatedTransaction.disconnect(boost::bind(&CValidationInterface::UpdatedTransaction, pwalletIn, _1)); g_signals.UpdatedTransaction.disconnect(boost::bind(&CValidationInterface::UpdatedTransaction, pwalletIn, _1));
g_signals.EraseTransaction.disconnect(boost::bind(&CValidationInterface::EraseFromWallet, pwalletIn, _1)); g_signals.EraseTransaction.disconnect(boost::bind(&CValidationInterface::EraseFromWallet, pwalletIn, _1));

View File

@ -34,7 +34,7 @@ protected:
virtual void UpdatedBlockTip(const CBlockIndex *pindex) {} virtual void UpdatedBlockTip(const CBlockIndex *pindex) {}
virtual void SyncTransaction(const CTransaction &tx, const CBlock *pblock) {} virtual void SyncTransaction(const CTransaction &tx, const CBlock *pblock) {}
virtual void EraseFromWallet(const uint256 &hash) {} virtual void EraseFromWallet(const uint256 &hash) {}
virtual void ChainTip(const CBlockIndex *pindex, const CBlock *pblock, ZCIncrementalMerkleTree tree, bool added) {} virtual void ChainTip(const CBlockIndex *pindex, const CBlock *pblock, ZCIncrementalMerkleTree sproutTree, ZCSaplingIncrementalMerkleTree saplingTree, bool added) {}
virtual void SetBestChain(const CBlockLocator &locator) {} virtual void SetBestChain(const CBlockLocator &locator) {}
virtual void UpdatedTransaction(const uint256 &hash) {} virtual void UpdatedTransaction(const uint256 &hash) {}
virtual void Inventory(const uint256 &hash) {} virtual void Inventory(const uint256 &hash) {}
@ -55,7 +55,7 @@ struct CMainSignals {
/** Notifies listeners of an updated transaction without new data (for now: a coinbase potentially becoming visible). */ /** Notifies listeners of an updated transaction without new data (for now: a coinbase potentially becoming visible). */
boost::signals2::signal<void (const uint256 &)> UpdatedTransaction; boost::signals2::signal<void (const uint256 &)> UpdatedTransaction;
/** Notifies listeners of a change to the tip of the active block chain. */ /** Notifies listeners of a change to the tip of the active block chain. */
boost::signals2::signal<void (const CBlockIndex *, const CBlock *, ZCIncrementalMerkleTree, bool)> ChainTip; boost::signals2::signal<void (const CBlockIndex *, const CBlock *, ZCIncrementalMerkleTree, ZCSaplingIncrementalMerkleTree, bool)> ChainTip;
/** Notifies listeners of a new active block chain. */ /** Notifies listeners of a new active block chain. */
boost::signals2::signal<void (const CBlockLocator &)> SetBestChain; boost::signals2::signal<void (const CBlockLocator &)> SetBestChain;
/** Notifies listeners about an inventory item being seen on the network. */ /** Notifies listeners about an inventory item being seen on the network. */

View File

@ -343,7 +343,7 @@ bool AsyncRPCOperation_mergetoaddress::main_impl()
std::vector<JSOutPoint> vOutPoints = {jso}; std::vector<JSOutPoint> vOutPoints = {jso};
uint256 inputAnchor; uint256 inputAnchor;
std::vector<boost::optional<ZCIncrementalWitness>> vInputWitnesses; std::vector<boost::optional<ZCIncrementalWitness>> vInputWitnesses;
pwalletMain->GetNoteWitnesses(vOutPoints, vInputWitnesses, inputAnchor); pwalletMain->GetSproutNoteWitnesses(vOutPoints, vInputWitnesses, inputAnchor);
jsopWitnessAnchorMap[jso.ToString()] = MergeToAddressWitnessAnchorData{vInputWitnesses[0], inputAnchor}; jsopWitnessAnchorMap[jso.ToString()] = MergeToAddressWitnessAnchorData{vInputWitnesses[0], inputAnchor};
} }
} }
@ -711,7 +711,7 @@ UniValue AsyncRPCOperation_mergetoaddress::perform_joinsplit(MergeToAddressJSInf
uint256 anchor; uint256 anchor;
{ {
LOCK(cs_main); LOCK(cs_main);
pwalletMain->GetNoteWitnesses(outPoints, witnesses, anchor); pwalletMain->GetSproutNoteWitnesses(outPoints, witnesses, anchor);
} }
return perform_joinsplit(info, witnesses, anchor); return perform_joinsplit(info, witnesses, anchor);
} }

View File

@ -420,7 +420,7 @@ bool AsyncRPCOperation_sendmany::main_impl() {
std::vector<JSOutPoint> vOutPoints = { jso }; std::vector<JSOutPoint> vOutPoints = { jso };
uint256 inputAnchor; uint256 inputAnchor;
std::vector<boost::optional<ZCIncrementalWitness>> vInputWitnesses; std::vector<boost::optional<ZCIncrementalWitness>> vInputWitnesses;
pwalletMain->GetNoteWitnesses(vOutPoints, vInputWitnesses, inputAnchor); pwalletMain->GetSproutNoteWitnesses(vOutPoints, vInputWitnesses, inputAnchor);
jsopWitnessAnchorMap[ jso.ToString() ] = WitnessAnchorData{ vInputWitnesses[0], inputAnchor }; jsopWitnessAnchorMap[ jso.ToString() ] = WitnessAnchorData{ vInputWitnesses[0], inputAnchor };
} }
} }
@ -935,7 +935,7 @@ UniValue AsyncRPCOperation_sendmany::perform_joinsplit(AsyncJoinSplitInfo & info
uint256 anchor; uint256 anchor;
{ {
LOCK(cs_main); LOCK(cs_main);
pwalletMain->GetNoteWitnesses(outPoints, witnesses, anchor); pwalletMain->GetSproutNoteWitnesses(outPoints, witnesses, anchor);
} }
return perform_joinsplit(info, witnesses, anchor); return perform_joinsplit(info, witnesses, anchor);
} }

View File

@ -51,8 +51,9 @@ public:
void IncrementNoteWitnesses(const CBlockIndex* pindex, void IncrementNoteWitnesses(const CBlockIndex* pindex,
const CBlock* pblock, const CBlock* pblock,
ZCIncrementalMerkleTree& tree) { ZCIncrementalMerkleTree& sproutTree,
CWallet::IncrementNoteWitnesses(pindex, pblock, tree); ZCSaplingIncrementalMerkleTree& saplingTree) {
CWallet::IncrementNoteWitnesses(pindex, pblock, sproutTree, saplingTree);
} }
void DecrementNoteWitnesses(const CBlockIndex* pindex) { void DecrementNoteWitnesses(const CBlockIndex* pindex) {
CWallet::DecrementNoteWitnesses(pindex); CWallet::DecrementNoteWitnesses(pindex);
@ -68,8 +69,8 @@ public:
} }
}; };
CWalletTx GetValidReceive(const libzcash::SproutSpendingKey& sk, CAmount value, bool randomInputs) { CWalletTx GetValidReceive(const libzcash::SproutSpendingKey& sk, CAmount value, bool randomInputs, int32_t version = 2) {
return GetValidReceive(*params, sk, value, randomInputs); return GetValidReceive(*params, sk, value, randomInputs, version);
} }
libzcash::SproutNote GetNote(const libzcash::SproutSpendingKey& sk, libzcash::SproutNote GetNote(const libzcash::SproutSpendingKey& sk,
@ -82,26 +83,52 @@ CWalletTx GetValidSpend(const libzcash::SproutSpendingKey& sk,
return GetValidSpend(*params, sk, note, value); return GetValidSpend(*params, sk, note, value);
} }
JSOutPoint CreateValidBlock(TestWallet& wallet, std::vector<SaplingOutPoint> SetSaplingNoteData(CWalletTx& wtx) {
mapSaplingNoteData_t saplingNoteData;
SaplingOutPoint saplingOutPoint = {wtx.GetHash(), 0};
SaplingNoteData saplingNd;
saplingNoteData[saplingOutPoint] = saplingNd;
wtx.SetSaplingNoteData(saplingNoteData);
std::vector<SaplingOutPoint> saplingNotes {saplingOutPoint};
return saplingNotes;
}
std::pair<JSOutPoint, SaplingOutPoint> CreateValidBlock(TestWallet& wallet,
const libzcash::SproutSpendingKey& sk, const libzcash::SproutSpendingKey& sk,
const CBlockIndex& index, const CBlockIndex& index,
CBlock& block, CBlock& block,
ZCIncrementalMerkleTree& tree) { ZCIncrementalMerkleTree& sproutTree,
auto wtx = GetValidReceive(sk, 50, true); ZCSaplingIncrementalMerkleTree& saplingTree) {
auto wtx = GetValidReceive(sk, 50, true, 4);
auto note = GetNote(sk, wtx, 0, 1); auto note = GetNote(sk, wtx, 0, 1);
auto nullifier = note.nullifier(sk); auto nullifier = note.nullifier(sk);
mapNoteData_t noteData; mapSproutNoteData_t noteData;
JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1};
CNoteData nd {sk.address(), nullifier}; SproutNoteData nd {sk.address(), nullifier};
noteData[jsoutpt] = nd; noteData[jsoutpt] = nd;
wtx.SetNoteData(noteData); wtx.SetSproutNoteData(noteData);
auto saplingNotes = SetSaplingNoteData(wtx);
wallet.AddToWallet(wtx, true, NULL); wallet.AddToWallet(wtx, true, NULL);
block.vtx.push_back(wtx); block.vtx.push_back(wtx);
wallet.IncrementNoteWitnesses(&index, &block, tree); wallet.IncrementNoteWitnesses(&index, &block, sproutTree, saplingTree);
return jsoutpt; return std::make_pair(jsoutpt, saplingNotes[0]);
}
std::pair<uint256, uint256> GetWitnessesAndAnchors(TestWallet& wallet,
std::vector<JSOutPoint>& sproutNotes,
std::vector<SaplingOutPoint>& saplingNotes,
std::vector<boost::optional<ZCIncrementalWitness>>& sproutWitnesses,
std::vector<boost::optional<ZCSaplingIncrementalWitness>>& saplingWitnesses) {
sproutWitnesses.clear();
saplingWitnesses.clear();
uint256 sproutAnchor;
uint256 saplingAnchor;
wallet.GetSproutNoteWitnesses(sproutNotes, sproutWitnesses, sproutAnchor);
wallet.GetSaplingNoteWitnesses(saplingNotes, saplingWitnesses, saplingAnchor);
return std::make_pair(sproutAnchor, saplingAnchor);
} }
TEST(wallet_tests, setup_datadir_location_run_as_first_test) { TEST(wallet_tests, setup_datadir_location_run_as_first_test) {
@ -117,9 +144,9 @@ TEST(wallet_tests, note_data_serialisation) {
auto note = GetNote(sk, wtx, 0, 1); auto note = GetNote(sk, wtx, 0, 1);
auto nullifier = note.nullifier(sk); auto nullifier = note.nullifier(sk);
mapNoteData_t noteData; mapSproutNoteData_t noteData;
JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1};
CNoteData nd {sk.address(), nullifier}; SproutNoteData nd {sk.address(), nullifier};
ZCIncrementalMerkleTree tree; ZCIncrementalMerkleTree tree;
nd.witnesses.push_front(tree.witness()); nd.witnesses.push_front(tree.witness());
noteData[jsoutpt] = nd; noteData[jsoutpt] = nd;
@ -127,7 +154,7 @@ TEST(wallet_tests, note_data_serialisation) {
CDataStream ss(SER_DISK, CLIENT_VERSION); CDataStream ss(SER_DISK, CLIENT_VERSION);
ss << noteData; ss << noteData;
mapNoteData_t noteData2; mapSproutNoteData_t noteData2;
ss >> noteData2; ss >> noteData2;
EXPECT_EQ(noteData, noteData2); EXPECT_EQ(noteData, noteData2);
@ -145,12 +172,12 @@ TEST(wallet_tests, find_unspent_notes) {
auto note = GetNote(sk, wtx, 0, 1); auto note = GetNote(sk, wtx, 0, 1);
auto nullifier = note.nullifier(sk); auto nullifier = note.nullifier(sk);
mapNoteData_t noteData; mapSproutNoteData_t noteData;
JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1};
CNoteData nd {sk.address(), nullifier}; SproutNoteData nd {sk.address(), nullifier};
noteData[jsoutpt] = nd; noteData[jsoutpt] = nd;
wtx.SetNoteData(noteData); wtx.SetSproutNoteData(noteData);
wallet.AddToWallet(wtx, true, NULL); wallet.AddToWallet(wtx, true, NULL);
EXPECT_FALSE(wallet.IsSpent(nullifier)); EXPECT_FALSE(wallet.IsSpent(nullifier));
@ -240,12 +267,12 @@ TEST(wallet_tests, find_unspent_notes) {
auto note = GetNote(sk, wtx, 0, 1); auto note = GetNote(sk, wtx, 0, 1);
auto nullifier = note.nullifier(sk); auto nullifier = note.nullifier(sk);
mapNoteData_t noteData; mapSproutNoteData_t noteData;
JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1};
CNoteData nd {sk.address(), nullifier}; SproutNoteData nd {sk.address(), nullifier};
noteData[jsoutpt] = nd; noteData[jsoutpt] = nd;
wtx.SetNoteData(noteData); wtx.SetSproutNoteData(noteData);
wallet.AddToWallet(wtx, true, NULL); wallet.AddToWallet(wtx, true, NULL);
EXPECT_FALSE(wallet.IsSpent(nullifier)); EXPECT_FALSE(wallet.IsSpent(nullifier));
@ -299,28 +326,28 @@ TEST(wallet_tests, set_note_addrs_in_cwallettx) {
auto wtx = GetValidReceive(sk, 10, true); auto wtx = GetValidReceive(sk, 10, true);
auto note = GetNote(sk, wtx, 0, 1); auto note = GetNote(sk, wtx, 0, 1);
auto nullifier = note.nullifier(sk); auto nullifier = note.nullifier(sk);
EXPECT_EQ(0, wtx.mapNoteData.size()); EXPECT_EQ(0, wtx.mapSproutNoteData.size());
mapNoteData_t noteData; mapSproutNoteData_t noteData;
JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1};
CNoteData nd {sk.address(), nullifier}; SproutNoteData nd {sk.address(), nullifier};
noteData[jsoutpt] = nd; noteData[jsoutpt] = nd;
wtx.SetNoteData(noteData); wtx.SetSproutNoteData(noteData);
EXPECT_EQ(noteData, wtx.mapNoteData); EXPECT_EQ(noteData, wtx.mapSproutNoteData);
} }
TEST(wallet_tests, set_invalid_note_addrs_in_cwallettx) { TEST(wallet_tests, set_invalid_note_addrs_in_cwallettx) {
CWalletTx wtx; CWalletTx wtx;
EXPECT_EQ(0, wtx.mapNoteData.size()); EXPECT_EQ(0, wtx.mapSproutNoteData.size());
mapNoteData_t noteData; mapSproutNoteData_t noteData;
auto sk = libzcash::SproutSpendingKey::random(); auto sk = libzcash::SproutSpendingKey::random();
JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1};
CNoteData nd {sk.address(), uint256()}; SproutNoteData nd {sk.address(), uint256()};
noteData[jsoutpt] = nd; noteData[jsoutpt] = nd;
EXPECT_THROW(wtx.SetNoteData(noteData), std::logic_error); EXPECT_THROW(wtx.SetSproutNoteData(noteData), std::logic_error);
} }
TEST(wallet_tests, GetNoteNullifier) { TEST(wallet_tests, GetNoteNullifier) {
@ -374,7 +401,7 @@ TEST(wallet_tests, FindMyNotes) {
EXPECT_EQ(2, noteMap.size()); EXPECT_EQ(2, noteMap.size());
JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1};
CNoteData nd {sk.address(), nullifier}; SproutNoteData nd {sk.address(), nullifier};
EXPECT_EQ(1, noteMap.count(jsoutpt)); EXPECT_EQ(1, noteMap.count(jsoutpt));
EXPECT_EQ(nd, noteMap[jsoutpt]); EXPECT_EQ(nd, noteMap[jsoutpt]);
} }
@ -397,7 +424,7 @@ TEST(wallet_tests, FindMyNotesInEncryptedWallet) {
EXPECT_EQ(2, noteMap.size()); EXPECT_EQ(2, noteMap.size());
JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1};
CNoteData nd {sk.address(), nullifier}; SproutNoteData nd {sk.address(), nullifier};
EXPECT_EQ(1, noteMap.count(jsoutpt)); EXPECT_EQ(1, noteMap.count(jsoutpt));
EXPECT_NE(nd, noteMap[jsoutpt]); EXPECT_NE(nd, noteMap[jsoutpt]);
@ -490,12 +517,12 @@ TEST(wallet_tests, navigate_from_nullifier_to_note) {
auto note = GetNote(sk, wtx, 0, 1); auto note = GetNote(sk, wtx, 0, 1);
auto nullifier = note.nullifier(sk); auto nullifier = note.nullifier(sk);
mapNoteData_t noteData; mapSproutNoteData_t noteData;
JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1};
CNoteData nd {sk.address(), nullifier}; SproutNoteData nd {sk.address(), nullifier};
noteData[jsoutpt] = nd; noteData[jsoutpt] = nd;
wtx.SetNoteData(noteData); wtx.SetSproutNoteData(noteData);
EXPECT_EQ(0, wallet.mapNullifiersToNotes.count(nullifier)); EXPECT_EQ(0, wallet.mapNullifiersToNotes.count(nullifier));
@ -520,12 +547,12 @@ TEST(wallet_tests, spent_note_is_from_me) {
EXPECT_FALSE(wallet.IsFromMe(wtx)); EXPECT_FALSE(wallet.IsFromMe(wtx));
EXPECT_FALSE(wallet.IsFromMe(wtx2)); EXPECT_FALSE(wallet.IsFromMe(wtx2));
mapNoteData_t noteData; mapSproutNoteData_t noteData;
JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1};
CNoteData nd {sk.address(), nullifier}; SproutNoteData nd {sk.address(), nullifier};
noteData[jsoutpt] = nd; noteData[jsoutpt] = nd;
wtx.SetNoteData(noteData); wtx.SetSproutNoteData(noteData);
EXPECT_FALSE(wallet.IsFromMe(wtx)); EXPECT_FALSE(wallet.IsFromMe(wtx));
EXPECT_FALSE(wallet.IsFromMe(wtx2)); EXPECT_FALSE(wallet.IsFromMe(wtx2));
@ -540,44 +567,53 @@ TEST(wallet_tests, cached_witnesses_empty_chain) {
auto sk = libzcash::SproutSpendingKey::random(); auto sk = libzcash::SproutSpendingKey::random();
wallet.AddSpendingKey(sk); wallet.AddSpendingKey(sk);
auto wtx = GetValidReceive(sk, 10, true); auto wtx = GetValidReceive(sk, 10, true, 4);
auto note = GetNote(sk, wtx, 0, 0); auto note = GetNote(sk, wtx, 0, 0);
auto note2 = GetNote(sk, wtx, 0, 1); auto note2 = GetNote(sk, wtx, 0, 1);
auto nullifier = note.nullifier(sk); auto nullifier = note.nullifier(sk);
auto nullifier2 = note2.nullifier(sk); auto nullifier2 = note2.nullifier(sk);
mapNoteData_t noteData; mapSproutNoteData_t sproutNoteData;
JSOutPoint jsoutpt {wtx.GetHash(), 0, 0}; JSOutPoint jsoutpt {wtx.GetHash(), 0, 0};
JSOutPoint jsoutpt2 {wtx.GetHash(), 0, 1}; JSOutPoint jsoutpt2 {wtx.GetHash(), 0, 1};
CNoteData nd {sk.address(), nullifier}; SproutNoteData nd {sk.address(), nullifier};
CNoteData nd2 {sk.address(), nullifier2}; SproutNoteData nd2 {sk.address(), nullifier2};
noteData[jsoutpt] = nd; sproutNoteData[jsoutpt] = nd;
noteData[jsoutpt2] = nd2; sproutNoteData[jsoutpt2] = nd2;
wtx.SetNoteData(noteData); wtx.SetSproutNoteData(sproutNoteData);
std::vector<JSOutPoint> notes {jsoutpt, jsoutpt2}; std::vector<JSOutPoint> sproutNotes {jsoutpt, jsoutpt2};
std::vector<boost::optional<ZCIncrementalWitness>> witnesses; std::vector<SaplingOutPoint> saplingNotes = SetSaplingNoteData(wtx);
uint256 anchor;
wallet.GetNoteWitnesses(notes, witnesses, anchor); std::vector<boost::optional<ZCIncrementalWitness>> sproutWitnesses;
EXPECT_FALSE((bool) witnesses[0]); std::vector<boost::optional<ZCSaplingIncrementalWitness>> saplingWitnesses;
EXPECT_FALSE((bool) witnesses[1]);
::GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses);
EXPECT_FALSE((bool) sproutWitnesses[0]);
EXPECT_FALSE((bool) sproutWitnesses[1]);
EXPECT_FALSE((bool) saplingWitnesses[0]);
wallet.AddToWallet(wtx, true, NULL); wallet.AddToWallet(wtx, true, NULL);
witnesses.clear();
wallet.GetNoteWitnesses(notes, witnesses, anchor); ::GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses);
EXPECT_FALSE((bool) witnesses[0]);
EXPECT_FALSE((bool) witnesses[1]); EXPECT_FALSE((bool) sproutWitnesses[0]);
EXPECT_FALSE((bool) sproutWitnesses[1]);
EXPECT_FALSE((bool) saplingWitnesses[0]);
CBlock block; CBlock block;
block.vtx.push_back(wtx); block.vtx.push_back(wtx);
CBlockIndex index(block); CBlockIndex index(block);
ZCIncrementalMerkleTree tree; ZCIncrementalMerkleTree sproutTree;
wallet.IncrementNoteWitnesses(&index, &block, tree); ZCSaplingIncrementalMerkleTree saplingTree;
witnesses.clear(); wallet.IncrementNoteWitnesses(&index, &block, sproutTree, saplingTree);
wallet.GetNoteWitnesses(notes, witnesses, anchor);
EXPECT_TRUE((bool) witnesses[0]); ::GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses);
EXPECT_TRUE((bool) witnesses[1]);
EXPECT_TRUE((bool) sproutWitnesses[0]);
EXPECT_TRUE((bool) sproutWitnesses[1]);
EXPECT_TRUE((bool) saplingWitnesses[0]);
// Until #1302 is implemented, this should triggger an assertion // Until #1302 is implemented, this should triggger an assertion
EXPECT_DEATH(wallet.DecrementNoteWitnesses(&index), EXPECT_DEATH(wallet.DecrementNoteWitnesses(&index),
@ -586,9 +622,10 @@ TEST(wallet_tests, cached_witnesses_empty_chain) {
TEST(wallet_tests, cached_witnesses_chain_tip) { TEST(wallet_tests, cached_witnesses_chain_tip) {
TestWallet wallet; TestWallet wallet;
uint256 anchor1; std::pair<uint256, uint256> anchors1;
CBlock block1; CBlock block1;
ZCIncrementalMerkleTree tree; ZCIncrementalMerkleTree sproutTree;
ZCSaplingIncrementalMerkleTree saplingTree;
auto sk = libzcash::SproutSpendingKey::random(); auto sk = libzcash::SproutSpendingKey::random();
wallet.AddSpendingKey(sk); wallet.AddSpendingKey(sk);
@ -597,33 +634,40 @@ TEST(wallet_tests, cached_witnesses_chain_tip) {
// First block (case tested in _empty_chain) // First block (case tested in _empty_chain)
CBlockIndex index1(block1); CBlockIndex index1(block1);
index1.nHeight = 1; index1.nHeight = 1;
auto jsoutpt = CreateValidBlock(wallet, sk, index1, block1, tree); auto outpts = CreateValidBlock(wallet, sk, index1, block1, sproutTree, saplingTree);
// Called to fetch anchor // Called to fetch anchor
std::vector<JSOutPoint> notes {jsoutpt}; std::vector<JSOutPoint> sproutNotes {outpts.first};
std::vector<boost::optional<ZCIncrementalWitness>> witnesses; std::vector<SaplingOutPoint> saplingNotes {outpts.second};
wallet.GetNoteWitnesses(notes, witnesses, anchor1); std::vector<boost::optional<ZCIncrementalWitness>> sproutWitnesses;
std::vector<boost::optional<ZCSaplingIncrementalWitness>> saplingWitnesses;
anchors1 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses);
EXPECT_NE(anchors1.first, anchors1.second);
} }
{ {
// Second transaction // Second transaction
auto wtx = GetValidReceive(sk, 50, true); auto wtx = GetValidReceive(sk, 50, true, 4);
auto note = GetNote(sk, wtx, 0, 1); auto note = GetNote(sk, wtx, 0, 1);
auto nullifier = note.nullifier(sk); auto nullifier = note.nullifier(sk);
mapNoteData_t noteData; mapSproutNoteData_t sproutNoteData;
JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1};
CNoteData nd {sk.address(), nullifier}; SproutNoteData nd {sk.address(), nullifier};
noteData[jsoutpt] = nd; sproutNoteData[jsoutpt] = nd;
wtx.SetNoteData(noteData); wtx.SetSproutNoteData(sproutNoteData);
std::vector<SaplingOutPoint> saplingNotes = SetSaplingNoteData(wtx);
wallet.AddToWallet(wtx, true, NULL); wallet.AddToWallet(wtx, true, NULL);
std::vector<JSOutPoint> notes {jsoutpt}; std::vector<JSOutPoint> sproutNotes {jsoutpt};
std::vector<boost::optional<ZCIncrementalWitness>> witnesses; std::vector<boost::optional<ZCIncrementalWitness>> sproutWitnesses;
uint256 anchor2; std::vector<boost::optional<ZCSaplingIncrementalWitness>> saplingWitnesses;
wallet.GetNoteWitnesses(notes, witnesses, anchor2); GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses);
EXPECT_FALSE((bool) witnesses[0]);
EXPECT_FALSE((bool) sproutWitnesses[0]);
EXPECT_FALSE((bool) saplingWitnesses[0]);
// Second block // Second block
CBlock block2; CBlock block2;
@ -631,46 +675,57 @@ TEST(wallet_tests, cached_witnesses_chain_tip) {
block2.vtx.push_back(wtx); block2.vtx.push_back(wtx);
CBlockIndex index2(block2); CBlockIndex index2(block2);
index2.nHeight = 2; index2.nHeight = 2;
ZCIncrementalMerkleTree tree2 {tree}; ZCIncrementalMerkleTree sproutTree2 {sproutTree};
wallet.IncrementNoteWitnesses(&index2, &block2, tree2); ZCSaplingIncrementalMerkleTree saplingTree2 {saplingTree};
witnesses.clear(); wallet.IncrementNoteWitnesses(&index2, &block2, sproutTree2, saplingTree2);
wallet.GetNoteWitnesses(notes, witnesses, anchor2);
EXPECT_TRUE((bool) witnesses[0]); auto anchors2 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses);
EXPECT_NE(anchor1, anchor2); EXPECT_NE(anchors2.first, anchors2.second);
EXPECT_TRUE((bool) sproutWitnesses[0]);
EXPECT_TRUE((bool) saplingWitnesses[0]);
EXPECT_NE(anchors1.first, anchors2.first);
EXPECT_NE(anchors1.second, anchors2.second);
// Decrementing should give us the previous anchor // Decrementing should give us the previous anchor
uint256 anchor3;
wallet.DecrementNoteWitnesses(&index2); wallet.DecrementNoteWitnesses(&index2);
witnesses.clear(); auto anchors3 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses);
wallet.GetNoteWitnesses(notes, witnesses, anchor3);
EXPECT_FALSE((bool) witnesses[0]); EXPECT_FALSE((bool) sproutWitnesses[0]);
EXPECT_FALSE((bool) saplingWitnesses[0]);
// Should not equal first anchor because none of these notes had witnesses // Should not equal first anchor because none of these notes had witnesses
EXPECT_NE(anchor1, anchor3); EXPECT_NE(anchors1.first, anchors3.first);
EXPECT_NE(anchors1.second, anchors3.second);
// Re-incrementing with the same block should give the same result // Re-incrementing with the same block should give the same result
uint256 anchor4; wallet.IncrementNoteWitnesses(&index2, &block2, sproutTree, saplingTree);
wallet.IncrementNoteWitnesses(&index2, &block2, tree); auto anchors4 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses);
witnesses.clear(); EXPECT_NE(anchors4.first, anchors4.second);
wallet.GetNoteWitnesses(notes, witnesses, anchor4);
EXPECT_TRUE((bool) witnesses[0]); EXPECT_TRUE((bool) sproutWitnesses[0]);
EXPECT_EQ(anchor2, anchor4); EXPECT_TRUE((bool) saplingWitnesses[0]);
EXPECT_EQ(anchors2.first, anchors4.first);
EXPECT_EQ(anchors2.second, anchors4.second);
// Incrementing with the same block again should not change the cache // Incrementing with the same block again should not change the cache
uint256 anchor5; wallet.IncrementNoteWitnesses(&index2, &block2, sproutTree, saplingTree);
wallet.IncrementNoteWitnesses(&index2, &block2, tree); std::vector<boost::optional<ZCIncrementalWitness>> sproutWitnesses5;
std::vector<boost::optional<ZCIncrementalWitness>> witnesses5; std::vector<boost::optional<ZCSaplingIncrementalWitness>> saplingWitnesses5;
wallet.GetNoteWitnesses(notes, witnesses5, anchor5);
EXPECT_EQ(witnesses, witnesses5); auto anchors5 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses5, saplingWitnesses5);
EXPECT_EQ(anchor4, anchor5); EXPECT_NE(anchors5.first, anchors5.second);
EXPECT_EQ(sproutWitnesses, sproutWitnesses5);
EXPECT_EQ(saplingWitnesses, saplingWitnesses5);
EXPECT_EQ(anchors4.first, anchors5.first);
EXPECT_EQ(anchors4.second, anchors5.second);
} }
} }
TEST(wallet_tests, CachedWitnessesDecrementFirst) { TEST(wallet_tests, CachedWitnessesDecrementFirst) {
TestWallet wallet; TestWallet wallet;
uint256 anchor2; ZCIncrementalMerkleTree sproutTree;
CBlock block2; ZCSaplingIncrementalMerkleTree saplingTree;
CBlockIndex index2(block2);
ZCIncrementalMerkleTree tree;
auto sk = libzcash::SproutSpendingKey::random(); auto sk = libzcash::SproutSpendingKey::random();
wallet.AddSpendingKey(sk); wallet.AddSpendingKey(sk);
@ -680,57 +735,70 @@ TEST(wallet_tests, CachedWitnessesDecrementFirst) {
CBlock block1; CBlock block1;
CBlockIndex index1(block1); CBlockIndex index1(block1);
index1.nHeight = 1; index1.nHeight = 1;
CreateValidBlock(wallet, sk, index1, block1, tree); CreateValidBlock(wallet, sk, index1, block1, sproutTree, saplingTree);
} }
std::pair<uint256, uint256> anchors2;
CBlock block2;
CBlockIndex index2(block2);
{ {
// Second block (case tested in _chain_tip) // Second block (case tested in _chain_tip)
index2.nHeight = 2; index2.nHeight = 2;
auto jsoutpt = CreateValidBlock(wallet, sk, index2, block2, tree); auto outpts = CreateValidBlock(wallet, sk, index2, block2, sproutTree, saplingTree);
// Called to fetch anchor // Called to fetch anchor
std::vector<JSOutPoint> notes {jsoutpt}; std::vector<JSOutPoint> sproutNotes {outpts.first};
std::vector<boost::optional<ZCIncrementalWitness>> witnesses; std::vector<SaplingOutPoint> saplingNotes {outpts.second};
wallet.GetNoteWitnesses(notes, witnesses, anchor2); std::vector<boost::optional<ZCIncrementalWitness>> sproutWitnesses;
std::vector<boost::optional<ZCSaplingIncrementalWitness>> saplingWitnesses;
anchors2 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses);
} }
{ {
// Third transaction - never mined // Third transaction - never mined
auto wtx = GetValidReceive(sk, 20, true); auto wtx = GetValidReceive(sk, 20, true, 4);
auto note = GetNote(sk, wtx, 0, 1); auto note = GetNote(sk, wtx, 0, 1);
auto nullifier = note.nullifier(sk); auto nullifier = note.nullifier(sk);
mapNoteData_t noteData; mapSproutNoteData_t noteData;
JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1};
CNoteData nd {sk.address(), nullifier}; SproutNoteData nd {sk.address(), nullifier};
noteData[jsoutpt] = nd; noteData[jsoutpt] = nd;
wtx.SetNoteData(noteData); wtx.SetSproutNoteData(noteData);
std::vector<SaplingOutPoint> saplingNotes = SetSaplingNoteData(wtx);
wallet.AddToWallet(wtx, true, NULL); wallet.AddToWallet(wtx, true, NULL);
std::vector<JSOutPoint> notes {jsoutpt}; std::vector<JSOutPoint> sproutNotes {jsoutpt};
std::vector<boost::optional<ZCIncrementalWitness>> witnesses; std::vector<boost::optional<ZCIncrementalWitness>> sproutWitnesses;
uint256 anchor3; std::vector<boost::optional<ZCSaplingIncrementalWitness>> saplingWitnesses;
wallet.GetNoteWitnesses(notes, witnesses, anchor3); auto anchors3 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses);
EXPECT_FALSE((bool) witnesses[0]);
EXPECT_FALSE((bool) sproutWitnesses[0]);
EXPECT_FALSE((bool) saplingWitnesses[0]);
// Decrementing (before the transaction has ever seen an increment) // Decrementing (before the transaction has ever seen an increment)
// should give us the previous anchor // should give us the previous anchor
uint256 anchor4;
wallet.DecrementNoteWitnesses(&index2); wallet.DecrementNoteWitnesses(&index2);
witnesses.clear();
wallet.GetNoteWitnesses(notes, witnesses, anchor4); auto anchors4 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses);
EXPECT_FALSE((bool) witnesses[0]);
EXPECT_FALSE((bool) sproutWitnesses[0]);
EXPECT_FALSE((bool) saplingWitnesses[0]);
// Should not equal second anchor because none of these notes had witnesses // Should not equal second anchor because none of these notes had witnesses
EXPECT_NE(anchor2, anchor4); EXPECT_NE(anchors2.first, anchors4.first);
EXPECT_NE(anchors2.second, anchors4.second);
// Re-incrementing with the same block should give the same result // Re-incrementing with the same block should give the same result
uint256 anchor5; wallet.IncrementNoteWitnesses(&index2, &block2, sproutTree, saplingTree);
wallet.IncrementNoteWitnesses(&index2, &block2, tree);
witnesses.clear(); auto anchors5 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses);
wallet.GetNoteWitnesses(notes, witnesses, anchor5);
EXPECT_FALSE((bool) witnesses[0]); EXPECT_FALSE((bool) sproutWitnesses[0]);
EXPECT_EQ(anchor3, anchor5); EXPECT_FALSE((bool) saplingWitnesses[0]);
EXPECT_EQ(anchors3.first, anchors5.first);
EXPECT_EQ(anchors3.second, anchors5.second);
} }
} }
@ -738,11 +806,16 @@ TEST(wallet_tests, CachedWitnessesCleanIndex) {
TestWallet wallet; TestWallet wallet;
std::vector<CBlock> blocks; std::vector<CBlock> blocks;
std::vector<CBlockIndex> indices; std::vector<CBlockIndex> indices;
std::vector<JSOutPoint> notes; std::vector<JSOutPoint> sproutNotes;
std::vector<uint256> anchors; std::vector<SaplingOutPoint> saplingNotes;
ZCIncrementalMerkleTree tree; std::vector<uint256> sproutAnchors;
ZCIncrementalMerkleTree riTree = tree; std::vector<uint256> saplingAnchors;
std::vector<boost::optional<ZCIncrementalWitness>> witnesses; ZCIncrementalMerkleTree sproutTree;
ZCIncrementalMerkleTree sproutRiTree = sproutTree;
ZCSaplingIncrementalMerkleTree saplingTree;
ZCSaplingIncrementalMerkleTree saplingRiTree = saplingTree;
std::vector<boost::optional<ZCIncrementalWitness>> sproutWitnesses;
std::vector<boost::optional<ZCSaplingIncrementalWitness>> saplingWitnesses;
auto sk = libzcash::SproutSpendingKey::random(); auto sk = libzcash::SproutSpendingKey::random();
wallet.AddSpendingKey(sk); wallet.AddSpendingKey(sk);
@ -753,58 +826,64 @@ TEST(wallet_tests, CachedWitnessesCleanIndex) {
indices.resize(numBlocks); indices.resize(numBlocks);
for (size_t i = 0; i < numBlocks; i++) { for (size_t i = 0; i < numBlocks; i++) {
indices[i].nHeight = i; indices[i].nHeight = i;
auto old = tree.root(); auto oldSproutRoot = sproutTree.root();
auto jsoutpt = CreateValidBlock(wallet, sk, indices[i], blocks[i], tree); auto oldSaplingRoot = saplingTree.root();
EXPECT_NE(old, tree.root()); auto outpts = CreateValidBlock(wallet, sk, indices[i], blocks[i], sproutTree, saplingTree);
notes.push_back(jsoutpt); EXPECT_NE(oldSproutRoot, sproutTree.root());
EXPECT_NE(oldSaplingRoot, saplingTree.root());
sproutNotes.push_back(outpts.first);
saplingNotes.push_back(outpts.second);
witnesses.clear(); auto anchors = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses);
uint256 anchor;
wallet.GetNoteWitnesses(notes, witnesses, anchor);
for (size_t j = 0; j <= i; j++) { for (size_t j = 0; j <= i; j++) {
EXPECT_TRUE((bool) witnesses[j]); EXPECT_TRUE((bool) sproutWitnesses[j]);
EXPECT_TRUE((bool) saplingWitnesses[j]);
} }
anchors.push_back(anchor); sproutAnchors.push_back(anchors.first);
saplingAnchors.push_back(anchors.second);
} }
// Now pretend we are reindexing: the chain is cleared, and each block is // Now pretend we are reindexing: the chain is cleared, and each block is
// used to increment witnesses again. // used to increment witnesses again.
for (size_t i = 0; i < numBlocks; i++) { for (size_t i = 0; i < numBlocks; i++) {
ZCIncrementalMerkleTree riPrevTree {riTree}; ZCIncrementalMerkleTree sproutRiPrevTree {sproutRiTree};
wallet.IncrementNoteWitnesses(&(indices[i]), &(blocks[i]), riTree); ZCSaplingIncrementalMerkleTree saplingRiPrevTree {saplingRiTree};
witnesses.clear(); wallet.IncrementNoteWitnesses(&(indices[i]), &(blocks[i]), sproutRiTree, saplingRiTree);
uint256 anchor;
wallet.GetNoteWitnesses(notes, witnesses, anchor); auto anchors = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses);
for (size_t j = 0; j < numBlocks; j++) { for (size_t j = 0; j < numBlocks; j++) {
EXPECT_TRUE((bool) witnesses[j]); EXPECT_TRUE((bool) sproutWitnesses[j]);
EXPECT_TRUE((bool) saplingWitnesses[j]);
} }
// Should equal final anchor because witness cache unaffected // Should equal final anchor because witness cache unaffected
EXPECT_EQ(anchors.back(), anchor); EXPECT_EQ(sproutAnchors.back(), anchors.first);
EXPECT_EQ(saplingAnchors.back(), anchors.second);
if ((i == 5) || (i == 50)) { if ((i == 5) || (i == 50)) {
// Pretend a reorg happened that was recorded in the block files // Pretend a reorg happened that was recorded in the block files
{ {
wallet.DecrementNoteWitnesses(&(indices[i])); wallet.DecrementNoteWitnesses(&(indices[i]));
witnesses.clear();
uint256 anchor; auto anchors = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses);
wallet.GetNoteWitnesses(notes, witnesses, anchor);
for (size_t j = 0; j < numBlocks; j++) { for (size_t j = 0; j < numBlocks; j++) {
EXPECT_TRUE((bool) witnesses[j]); EXPECT_TRUE((bool) sproutWitnesses[j]);
EXPECT_TRUE((bool) saplingWitnesses[j]);
} }
// Should equal final anchor because witness cache unaffected // Should equal final anchor because witness cache unaffected
EXPECT_EQ(anchors.back(), anchor); EXPECT_EQ(sproutAnchors.back(), anchors.first);
EXPECT_EQ(saplingAnchors.back(), anchors.second);
} }
{ {
wallet.IncrementNoteWitnesses(&(indices[i]), &(blocks[i]), riPrevTree); wallet.IncrementNoteWitnesses(&(indices[i]), &(blocks[i]), sproutRiPrevTree, saplingRiPrevTree);
witnesses.clear(); auto anchors = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses);
uint256 anchor;
wallet.GetNoteWitnesses(notes, witnesses, anchor);
for (size_t j = 0; j < numBlocks; j++) { for (size_t j = 0; j < numBlocks; j++) {
EXPECT_TRUE((bool) witnesses[j]); EXPECT_TRUE((bool) sproutWitnesses[j]);
EXPECT_TRUE((bool) saplingWitnesses[j]);
} }
// Should equal final anchor because witness cache unaffected // Should equal final anchor because witness cache unaffected
EXPECT_EQ(anchors.back(), anchor); EXPECT_EQ(sproutAnchors.back(), anchors.first);
EXPECT_EQ(saplingAnchors.back(), anchors.second);
} }
} }
} }
@ -816,44 +895,55 @@ TEST(wallet_tests, ClearNoteWitnessCache) {
auto sk = libzcash::SproutSpendingKey::random(); auto sk = libzcash::SproutSpendingKey::random();
wallet.AddSpendingKey(sk); wallet.AddSpendingKey(sk);
auto wtx = GetValidReceive(sk, 10, true); auto wtx = GetValidReceive(sk, 10, true, 4);
auto hash = wtx.GetHash(); auto hash = wtx.GetHash();
auto note = GetNote(sk, wtx, 0, 0); auto note = GetNote(sk, wtx, 0, 0);
auto nullifier = note.nullifier(sk); auto nullifier = note.nullifier(sk);
mapNoteData_t noteData; mapSproutNoteData_t noteData;
JSOutPoint jsoutpt {wtx.GetHash(), 0, 0}; JSOutPoint jsoutpt {wtx.GetHash(), 0, 0};
JSOutPoint jsoutpt2 {wtx.GetHash(), 0, 1}; JSOutPoint jsoutpt2 {wtx.GetHash(), 0, 1};
CNoteData nd {sk.address(), nullifier}; SproutNoteData nd {sk.address(), nullifier};
noteData[jsoutpt] = nd; noteData[jsoutpt] = nd;
wtx.SetNoteData(noteData); wtx.SetSproutNoteData(noteData);
auto saplingNotes = SetSaplingNoteData(wtx);
// Pretend we mined the tx by adding a fake witness // Pretend we mined the tx by adding a fake witness
ZCIncrementalMerkleTree tree; ZCIncrementalMerkleTree sproutTree;
wtx.mapNoteData[jsoutpt].witnesses.push_front(tree.witness()); wtx.mapSproutNoteData[jsoutpt].witnesses.push_front(sproutTree.witness());
wtx.mapNoteData[jsoutpt].witnessHeight = 1; wtx.mapSproutNoteData[jsoutpt].witnessHeight = 1;
wallet.nWitnessCacheSize = 1; wallet.nWitnessCacheSize = 1;
ZCSaplingIncrementalMerkleTree saplingTree;
wtx.mapSaplingNoteData[saplingNotes[0]].witnesses.push_front(saplingTree.witness());
wtx.mapSaplingNoteData[saplingNotes[0]].witnessHeight = 1;
wallet.nWitnessCacheSize = 2;
wallet.AddToWallet(wtx, true, NULL); wallet.AddToWallet(wtx, true, NULL);
std::vector<JSOutPoint> notes {jsoutpt, jsoutpt2}; std::vector<JSOutPoint> sproutNotes {jsoutpt, jsoutpt2};
std::vector<boost::optional<ZCIncrementalWitness>> witnesses; std::vector<boost::optional<ZCIncrementalWitness>> sproutWitnesses;
uint256 anchor2; std::vector<boost::optional<ZCSaplingIncrementalWitness>> saplingWitnesses;
// Before clearing, we should have a witness for one note // Before clearing, we should have a witness for one note
wallet.GetNoteWitnesses(notes, witnesses, anchor2); GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses);
EXPECT_TRUE((bool) witnesses[0]); EXPECT_TRUE((bool) sproutWitnesses[0]);
EXPECT_FALSE((bool) witnesses[1]); EXPECT_FALSE((bool) sproutWitnesses[1]);
EXPECT_EQ(1, wallet.mapWallet[hash].mapNoteData[jsoutpt].witnessHeight); EXPECT_TRUE((bool) saplingWitnesses[0]);
EXPECT_EQ(1, wallet.nWitnessCacheSize); EXPECT_FALSE((bool) saplingWitnesses[1]);
EXPECT_EQ(1, wallet.mapWallet[hash].mapSproutNoteData[jsoutpt].witnessHeight);
EXPECT_EQ(1, wallet.mapWallet[hash].mapSaplingNoteData[saplingNotes[0]].witnessHeight);
EXPECT_EQ(2, wallet.nWitnessCacheSize);
// After clearing, we should not have a witness for either note // After clearing, we should not have a witness for either note
wallet.ClearNoteWitnessCache(); wallet.ClearNoteWitnessCache();
witnesses.clear(); auto anchros2 = GetWitnessesAndAnchors(wallet, sproutNotes, saplingNotes, sproutWitnesses, saplingWitnesses);
wallet.GetNoteWitnesses(notes, witnesses, anchor2); EXPECT_FALSE((bool) sproutWitnesses[0]);
EXPECT_FALSE((bool) witnesses[0]); EXPECT_FALSE((bool) sproutWitnesses[1]);
EXPECT_FALSE((bool) witnesses[1]); EXPECT_FALSE((bool) saplingWitnesses[0]);
EXPECT_EQ(-1, wallet.mapWallet[hash].mapNoteData[jsoutpt].witnessHeight); EXPECT_FALSE((bool) saplingWitnesses[1]);
EXPECT_EQ(-1, wallet.mapWallet[hash].mapSproutNoteData[jsoutpt].witnessHeight);
EXPECT_EQ(-1, wallet.mapWallet[hash].mapSaplingNoteData[saplingNotes[0]].witnessHeight);
EXPECT_EQ(0, wallet.nWitnessCacheSize); EXPECT_EQ(0, wallet.nWitnessCacheSize);
} }
@ -949,11 +1039,11 @@ TEST(wallet_tests, UpdateNullifierNoteMap) {
auto nullifier = note.nullifier(sk); auto nullifier = note.nullifier(sk);
// Pretend that we called FindMyNotes while the wallet was locked // Pretend that we called FindMyNotes while the wallet was locked
mapNoteData_t noteData; mapSproutNoteData_t noteData;
JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1};
CNoteData nd {sk.address()}; SproutNoteData nd {sk.address()};
noteData[jsoutpt] = nd; noteData[jsoutpt] = nd;
wtx.SetNoteData(noteData); wtx.SetSproutNoteData(noteData);
wallet.AddToWallet(wtx, true, NULL); wallet.AddToWallet(wtx, true, NULL);
EXPECT_EQ(0, wallet.mapNullifiersToNotes.count(nullifier)); EXPECT_EQ(0, wallet.mapNullifiersToNotes.count(nullifier));
@ -984,35 +1074,35 @@ TEST(wallet_tests, UpdatedNoteData) {
// First pretend we added the tx to the wallet and // First pretend we added the tx to the wallet and
// we don't have the key for the second note // we don't have the key for the second note
mapNoteData_t noteData; mapSproutNoteData_t noteData;
JSOutPoint jsoutpt {wtx.GetHash(), 0, 0}; JSOutPoint jsoutpt {wtx.GetHash(), 0, 0};
CNoteData nd {sk.address(), nullifier}; SproutNoteData nd {sk.address(), nullifier};
noteData[jsoutpt] = nd; noteData[jsoutpt] = nd;
wtx.SetNoteData(noteData); wtx.SetSproutNoteData(noteData);
// Pretend we mined the tx by adding a fake witness // Pretend we mined the tx by adding a fake witness
ZCIncrementalMerkleTree tree; ZCIncrementalMerkleTree tree;
wtx.mapNoteData[jsoutpt].witnesses.push_front(tree.witness()); wtx.mapSproutNoteData[jsoutpt].witnesses.push_front(tree.witness());
wtx.mapNoteData[jsoutpt].witnessHeight = 100; wtx.mapSproutNoteData[jsoutpt].witnessHeight = 100;
// Now pretend we added the key for the second note, and // Now pretend we added the key for the second note, and
// the tx was "added" to the wallet again to update it. // the tx was "added" to the wallet again to update it.
// This happens via the 'z_importkey' RPC method. // This happens via the 'z_importkey' RPC method.
JSOutPoint jsoutpt2 {wtx2.GetHash(), 0, 1}; JSOutPoint jsoutpt2 {wtx2.GetHash(), 0, 1};
CNoteData nd2 {sk.address(), nullifier2}; SproutNoteData nd2 {sk.address(), nullifier2};
noteData[jsoutpt2] = nd2; noteData[jsoutpt2] = nd2;
wtx2.SetNoteData(noteData); wtx2.SetSproutNoteData(noteData);
// The txs should initially be different // The txs should initially be different
EXPECT_NE(wtx.mapNoteData, wtx2.mapNoteData); EXPECT_NE(wtx.mapSproutNoteData, wtx2.mapSproutNoteData);
EXPECT_EQ(1, wtx.mapNoteData[jsoutpt].witnesses.size()); EXPECT_EQ(1, wtx.mapSproutNoteData[jsoutpt].witnesses.size());
EXPECT_EQ(100, wtx.mapNoteData[jsoutpt].witnessHeight); EXPECT_EQ(100, wtx.mapSproutNoteData[jsoutpt].witnessHeight);
// After updating, they should be the same // After updating, they should be the same
EXPECT_TRUE(wallet.UpdatedNoteData(wtx2, wtx)); EXPECT_TRUE(wallet.UpdatedNoteData(wtx2, wtx));
EXPECT_EQ(wtx.mapNoteData, wtx2.mapNoteData); EXPECT_EQ(wtx.mapSproutNoteData, wtx2.mapSproutNoteData);
EXPECT_EQ(1, wtx.mapNoteData[jsoutpt].witnesses.size()); EXPECT_EQ(1, wtx.mapSproutNoteData[jsoutpt].witnesses.size());
EXPECT_EQ(100, wtx.mapNoteData[jsoutpt].witnessHeight); EXPECT_EQ(100, wtx.mapSproutNoteData[jsoutpt].witnessHeight);
// TODO: The new note should get witnessed (but maybe not here) (#1350) // TODO: The new note should get witnessed (but maybe not here) (#1350)
} }
@ -1028,12 +1118,12 @@ TEST(wallet_tests, MarkAffectedTransactionsDirty) {
auto nullifier = note.nullifier(sk); auto nullifier = note.nullifier(sk);
auto wtx2 = GetValidSpend(sk, note, 5); auto wtx2 = GetValidSpend(sk, note, 5);
mapNoteData_t noteData; mapSproutNoteData_t noteData;
JSOutPoint jsoutpt {hash, 0, 1}; JSOutPoint jsoutpt {hash, 0, 1};
CNoteData nd {sk.address(), nullifier}; SproutNoteData nd {sk.address(), nullifier};
noteData[jsoutpt] = nd; noteData[jsoutpt] = nd;
wtx.SetNoteData(noteData); wtx.SetSproutNoteData(noteData);
wallet.AddToWallet(wtx, true, NULL); wallet.AddToWallet(wtx, true, NULL);
wallet.MarkAffectedTransactionsDirty(wtx); wallet.MarkAffectedTransactionsDirty(wtx);

View File

@ -449,11 +449,14 @@ bool CWallet::ChangeWalletPassphrase(const SecureString& strOldWalletPassphrase,
return false; return false;
} }
void CWallet::ChainTip(const CBlockIndex *pindex, const CBlock *pblock, void CWallet::ChainTip(const CBlockIndex *pindex,
ZCIncrementalMerkleTree tree, bool added) const CBlock *pblock,
ZCIncrementalMerkleTree sproutTree,
ZCSaplingIncrementalMerkleTree saplingTree,
bool added)
{ {
if (added) { if (added) {
IncrementNoteWitnesses(pindex, pblock, tree); IncrementNoteWitnesses(pindex, pblock, sproutTree, saplingTree);
} else { } else {
DecrementNoteWitnesses(pindex); DecrementNoteWitnesses(pindex);
} }
@ -469,7 +472,7 @@ std::set<std::pair<libzcash::PaymentAddress, uint256>> CWallet::GetNullifiersFor
{ {
std::set<std::pair<libzcash::PaymentAddress, uint256>> nullifierSet; std::set<std::pair<libzcash::PaymentAddress, uint256>> nullifierSet;
for (const auto & txPair : mapWallet) { for (const auto & txPair : mapWallet) {
for (const auto & noteDataPair : txPair.second.mapNoteData) { for (const auto & noteDataPair : txPair.second.mapSproutNoteData) {
if (noteDataPair.second.nullifier && addresses.count(noteDataPair.second.address)) { if (noteDataPair.second.nullifier && addresses.count(noteDataPair.second.address)) {
nullifierSet.insert(std::make_pair(noteDataPair.second.address, noteDataPair.second.nullifier.get())); nullifierSet.insert(std::make_pair(noteDataPair.second.address, noteDataPair.second.nullifier.get()));
} }
@ -653,7 +656,7 @@ void CWallet::SyncMetaData(pair<typename TxSpendMap<T>::iterator, typename TxSpe
CWalletTx* copyTo = &mapWallet[hash]; CWalletTx* copyTo = &mapWallet[hash];
if (copyFrom == copyTo) continue; if (copyFrom == copyTo) continue;
copyTo->mapValue = copyFrom->mapValue; copyTo->mapValue = copyFrom->mapValue;
// mapNoteData not copied on purpose // mapSproutNoteData not copied on purpose
// (it is always set correctly for each CWalletTx) // (it is always set correctly for each CWalletTx)
copyTo->vOrderForm = copyFrom->vOrderForm; copyTo->vOrderForm = copyFrom->vOrderForm;
// fTimeReceivedIsTxTime not copied on purpose // fTimeReceivedIsTxTime not copied on purpose
@ -744,7 +747,11 @@ void CWallet::ClearNoteWitnessCache()
{ {
LOCK(cs_wallet); LOCK(cs_wallet);
for (std::pair<const uint256, CWalletTx>& wtxItem : mapWallet) { for (std::pair<const uint256, CWalletTx>& wtxItem : mapWallet) {
for (mapNoteData_t::value_type& item : wtxItem.second.mapNoteData) { for (mapSproutNoteData_t::value_type& item : wtxItem.second.mapSproutNoteData) {
item.second.witnesses.clear();
item.second.witnessHeight = -1;
}
for (mapSaplingNoteData_t::value_type& item : wtxItem.second.mapSaplingNoteData) {
item.second.witnesses.clear(); item.second.witnesses.clear();
item.second.witnessHeight = -1; item.second.witnessHeight = -1;
} }
@ -752,176 +759,219 @@ void CWallet::ClearNoteWitnessCache()
nWitnessCacheSize = 0; nWitnessCacheSize = 0;
} }
template<typename NoteDataMap>
void CopyPreviousWitnesses(NoteDataMap& noteDataMap, int indexHeight, int64_t nWitnessCacheSize)
{
for (auto& item : noteDataMap) {
auto* nd = &(item.second);
// Only increment witnesses that are behind the current height
if (nd->witnessHeight < indexHeight) {
// Check the validity of the cache
// The only time a note witnessed above the current height
// would be invalid here is during a reindex when blocks
// have been decremented, and we are incrementing the blocks
// immediately after.
assert(nWitnessCacheSize >= nd->witnesses.size());
// Witnesses being incremented should always be either -1
// (never incremented or decremented) or one below indexHeight
assert((nd->witnessHeight == -1) || (nd->witnessHeight == indexHeight - 1));
// Copy the witness for the previous block if we have one
if (nd->witnesses.size() > 0) {
nd->witnesses.push_front(nd->witnesses.front());
}
if (nd->witnesses.size() > WITNESS_CACHE_SIZE) {
nd->witnesses.pop_back();
}
}
}
}
template<typename NoteDataMap>
void AppendNoteCommitment(NoteDataMap& noteDataMap, int indexHeight, int64_t nWitnessCacheSize, const uint256& note_commitment)
{
for (auto& item : noteDataMap) {
auto* nd = &(item.second);
if (nd->witnessHeight < indexHeight && nd->witnesses.size() > 0) {
// Check the validity of the cache
// See comment in CopyPreviousWitnesses about validity.
assert(nWitnessCacheSize >= nd->witnesses.size());
nd->witnesses.front().append(note_commitment);
}
}
}
template<typename OutPoint, typename NoteData, typename Witness>
void WitnessNoteIfMine(std::map<OutPoint, NoteData>& noteDataMap, int indexHeight, int64_t nWitnessCacheSize, const OutPoint& key, const Witness& witness)
{
if (noteDataMap.count(key) && noteDataMap[key].witnessHeight < indexHeight) {
auto* nd = &(noteDataMap[key]);
if (nd->witnesses.size() > 0) {
// We think this can happen because we write out the
// witness cache state after every block increment or
// decrement, but the block index itself is written in
// batches. So if the node crashes in between these two
// operations, it is possible for IncrementNoteWitnesses
// to be called again on previously-cached blocks. This
// doesn't affect existing cached notes because of the
// NoteData::witnessHeight checks. See #1378 for details.
LogPrintf("Inconsistent witness cache state found for %s\n- Cache size: %d\n- Top (height %d): %s\n- New (height %d): %s\n",
key.ToString(), nd->witnesses.size(),
nd->witnessHeight,
nd->witnesses.front().root().GetHex(),
indexHeight,
witness.root().GetHex());
nd->witnesses.clear();
}
nd->witnesses.push_front(witness);
// Set height to one less than pindex so it gets incremented
nd->witnessHeight = indexHeight - 1;
// Check the validity of the cache
assert(nWitnessCacheSize >= nd->witnesses.size());
}
}
template<typename NoteDataMap>
void UpdateWitnessHeights(NoteDataMap& noteDataMap, int indexHeight, int64_t nWitnessCacheSize)
{
for (auto& item : noteDataMap) {
auto* nd = &(item.second);
if (nd->witnessHeight < indexHeight) {
nd->witnessHeight = indexHeight;
// Check the validity of the cache
// See comment in CopyPreviousWitnesses about validity.
assert(nWitnessCacheSize >= nd->witnesses.size());
}
}
}
void CWallet::IncrementNoteWitnesses(const CBlockIndex* pindex, void CWallet::IncrementNoteWitnesses(const CBlockIndex* pindex,
const CBlock* pblockIn, const CBlock* pblockIn,
ZCIncrementalMerkleTree& tree) ZCIncrementalMerkleTree& sproutTree,
ZCSaplingIncrementalMerkleTree& saplingTree)
{ {
{ LOCK(cs_wallet);
LOCK(cs_wallet); for (std::pair<const uint256, CWalletTx>& wtxItem : mapWallet) {
for (std::pair<const uint256, CWalletTx>& wtxItem : mapWallet) { ::CopyPreviousWitnesses(wtxItem.second.mapSproutNoteData, pindex->nHeight, nWitnessCacheSize);
for (mapNoteData_t::value_type& item : wtxItem.second.mapNoteData) { ::CopyPreviousWitnesses(wtxItem.second.mapSaplingNoteData, pindex->nHeight, nWitnessCacheSize);
CNoteData* nd = &(item.second); }
// Only increment witnesses that are behind the current height
if (nd->witnessHeight < pindex->nHeight) { if (nWitnessCacheSize < WITNESS_CACHE_SIZE) {
// Check the validity of the cache nWitnessCacheSize += 1;
// The only time a note witnessed above the current height }
// would be invalid here is during a reindex when blocks
// have been decremented, and we are incrementing the blocks const CBlock* pblock {pblockIn};
// immediately after. CBlock block;
assert(nWitnessCacheSize >= nd->witnesses.size()); if (!pblock) {
// Witnesses being incremented should always be either -1 ReadBlockFromDisk(block, pindex);
// (never incremented or decremented) or one below pindex pblock = &block;
assert((nd->witnessHeight == -1) || }
(nd->witnessHeight == pindex->nHeight - 1));
// Copy the witness for the previous block if we have one for (const CTransaction& tx : pblock->vtx) {
if (nd->witnesses.size() > 0) { auto hash = tx.GetHash();
nd->witnesses.push_front(nd->witnesses.front()); bool txIsOurs = mapWallet.count(hash);
} // Sprout
if (nd->witnesses.size() > WITNESS_CACHE_SIZE) { for (size_t i = 0; i < tx.vjoinsplit.size(); i++) {
nd->witnesses.pop_back(); const JSDescription& jsdesc = tx.vjoinsplit[i];
} for (uint8_t j = 0; j < jsdesc.commitments.size(); j++) {
const uint256& note_commitment = jsdesc.commitments[j];
sproutTree.append(note_commitment);
// Increment existing witnesses
for (std::pair<const uint256, CWalletTx>& wtxItem : mapWallet) {
::AppendNoteCommitment(wtxItem.second.mapSproutNoteData, pindex->nHeight, nWitnessCacheSize, note_commitment);
}
// If this is our note, witness it
if (txIsOurs) {
JSOutPoint jsoutpt {hash, i, j};
::WitnessNoteIfMine(mapWallet[hash].mapSproutNoteData, pindex->nHeight, nWitnessCacheSize, jsoutpt, sproutTree.witness());
} }
} }
} }
if (nWitnessCacheSize < WITNESS_CACHE_SIZE) { // Sapling
nWitnessCacheSize += 1; for (uint32_t i = 0; i < tx.vShieldedOutput.size(); i++) {
} const uint256& note_commitment = tx.vShieldedOutput[i].cm;
saplingTree.append(note_commitment);
const CBlock* pblock {pblockIn}; // Increment existing witnesses
CBlock block; for (std::pair<const uint256, CWalletTx>& wtxItem : mapWallet) {
if (!pblock) { ::AppendNoteCommitment(wtxItem.second.mapSaplingNoteData, pindex->nHeight, nWitnessCacheSize, note_commitment);
ReadBlockFromDisk(block, pindex); }
pblock = &block;
}
for (const CTransaction& tx : pblock->vtx) { // If this is our note, witness it
auto hash = tx.GetHash(); if (txIsOurs) {
bool txIsOurs = mapWallet.count(hash); SaplingOutPoint outPoint {hash, i};
for (size_t i = 0; i < tx.vjoinsplit.size(); i++) { ::WitnessNoteIfMine(mapWallet[hash].mapSaplingNoteData, pindex->nHeight, nWitnessCacheSize, outPoint, saplingTree.witness());
const JSDescription& jsdesc = tx.vjoinsplit[i];
for (uint8_t j = 0; j < jsdesc.commitments.size(); j++) {
const uint256& note_commitment = jsdesc.commitments[j];
tree.append(note_commitment);
// Increment existing witnesses
for (std::pair<const uint256, CWalletTx>& wtxItem : mapWallet) {
for (mapNoteData_t::value_type& item : wtxItem.second.mapNoteData) {
CNoteData* nd = &(item.second);
if (nd->witnessHeight < pindex->nHeight &&
nd->witnesses.size() > 0) {
// Check the validity of the cache
// See earlier comment about validity.
assert(nWitnessCacheSize >= nd->witnesses.size());
nd->witnesses.front().append(note_commitment);
}
}
}
// If this is our note, witness it
if (txIsOurs) {
JSOutPoint jsoutpt {hash, i, j};
if (mapWallet[hash].mapNoteData.count(jsoutpt) &&
mapWallet[hash].mapNoteData[jsoutpt].witnessHeight < pindex->nHeight) {
CNoteData* nd = &(mapWallet[hash].mapNoteData[jsoutpt]);
if (nd->witnesses.size() > 0) {
// We think this can happen because we write out the
// witness cache state after every block increment or
// decrement, but the block index itself is written in
// batches. So if the node crashes in between these two
// operations, it is possible for IncrementNoteWitnesses
// to be called again on previously-cached blocks. This
// doesn't affect existing cached notes because of the
// CNoteData::witnessHeight checks. See #1378 for details.
LogPrintf("Inconsistent witness cache state found for %s\n- Cache size: %d\n- Top (height %d): %s\n- New (height %d): %s\n",
jsoutpt.ToString(), nd->witnesses.size(),
nd->witnessHeight,
nd->witnesses.front().root().GetHex(),
pindex->nHeight,
tree.witness().root().GetHex());
nd->witnesses.clear();
}
nd->witnesses.push_front(tree.witness());
// Set height to one less than pindex so it gets incremented
nd->witnessHeight = pindex->nHeight - 1;
// Check the validity of the cache
assert(nWitnessCacheSize >= nd->witnesses.size());
}
}
}
} }
} }
}
// Update witness heights // Update witness heights
for (std::pair<const uint256, CWalletTx>& wtxItem : mapWallet) { for (std::pair<const uint256, CWalletTx>& wtxItem : mapWallet) {
for (mapNoteData_t::value_type& item : wtxItem.second.mapNoteData) { ::UpdateWitnessHeights(wtxItem.second.mapSproutNoteData, pindex->nHeight, nWitnessCacheSize);
CNoteData* nd = &(item.second); ::UpdateWitnessHeights(wtxItem.second.mapSaplingNoteData, pindex->nHeight, nWitnessCacheSize);
if (nd->witnessHeight < pindex->nHeight) { }
nd->witnessHeight = pindex->nHeight;
// Check the validity of the cache // For performance reasons, we write out the witness cache in
// See earlier comment about validity. // CWallet::SetBestChain() (which also ensures that overall consistency
assert(nWitnessCacheSize >= nd->witnesses.size()); // of the wallet.dat is maintained).
} }
template<typename NoteDataMap>
void DecrementNoteWitnesses(NoteDataMap& noteDataMap, int indexHeight, int64_t nWitnessCacheSize)
{
for (auto& item : noteDataMap) {
auto* nd = &(item.second);
// Only decrement witnesses that are not above the current height
if (nd->witnessHeight <= indexHeight) {
// Check the validity of the cache
// See comment below (this would be invalid if there were a
// prior decrement).
assert(nWitnessCacheSize >= nd->witnesses.size());
// Witnesses being decremented should always be either -1
// (never incremented or decremented) or equal to the height
// of the block being removed (indexHeight)
assert((nd->witnessHeight == -1) || (nd->witnessHeight == indexHeight));
if (nd->witnesses.size() > 0) {
nd->witnesses.pop_front();
} }
// indexHeight is the height of the block being removed, so
// the new witness cache height is one below it.
nd->witnessHeight = indexHeight - 1;
}
// Check the validity of the cache
// Technically if there are notes witnessed above the current
// height, their cache will now be invalid (relative to the new
// value of nWitnessCacheSize). However, this would only occur
// during a reindex, and by the time the reindex reaches the tip
// of the chain again, the existing witness caches will be valid
// again.
// We don't set nWitnessCacheSize to zero at the start of the
// reindex because the on-disk blocks had already resulted in a
// chain that didn't trigger the assertion below.
if (nd->witnessHeight < indexHeight) {
// Subtract 1 to compare to what nWitnessCacheSize will be after
// decrementing.
assert((nWitnessCacheSize - 1) >= nd->witnesses.size());
} }
// For performance reasons, we write out the witness cache in
// CWallet::SetBestChain() (which also ensures that overall consistency
// of the wallet.dat is maintained).
} }
} }
void CWallet::DecrementNoteWitnesses(const CBlockIndex* pindex) void CWallet::DecrementNoteWitnesses(const CBlockIndex* pindex)
{ {
{ LOCK(cs_wallet);
LOCK(cs_wallet); for (std::pair<const uint256, CWalletTx>& wtxItem : mapWallet) {
for (std::pair<const uint256, CWalletTx>& wtxItem : mapWallet) { ::DecrementNoteWitnesses(wtxItem.second.mapSproutNoteData, pindex->nHeight, nWitnessCacheSize);
for (mapNoteData_t::value_type& item : wtxItem.second.mapNoteData) { ::DecrementNoteWitnesses(wtxItem.second.mapSaplingNoteData, pindex->nHeight, nWitnessCacheSize);
CNoteData* nd = &(item.second);
// Only increment witnesses that are not above the current height
if (nd->witnessHeight <= pindex->nHeight) {
// Check the validity of the cache
// See comment below (this would be invalid if there was a
// prior decrement).
assert(nWitnessCacheSize >= nd->witnesses.size());
// Witnesses being decremented should always be either -1
// (never incremented or decremented) or equal to pindex
assert((nd->witnessHeight == -1) ||
(nd->witnessHeight == pindex->nHeight));
if (nd->witnesses.size() > 0) {
nd->witnesses.pop_front();
}
// pindex is the block being removed, so the new witness cache
// height is one below it.
nd->witnessHeight = pindex->nHeight - 1;
}
}
}
nWitnessCacheSize -= 1;
for (std::pair<const uint256, CWalletTx>& wtxItem : mapWallet) {
for (mapNoteData_t::value_type& item : wtxItem.second.mapNoteData) {
CNoteData* nd = &(item.second);
// Check the validity of the cache
// Technically if there are notes witnessed above the current
// height, their cache will now be invalid (relative to the new
// value of nWitnessCacheSize). However, this would only occur
// during a reindex, and by the time the reindex reaches the tip
// of the chain again, the existing witness caches will be valid
// again.
// We don't set nWitnessCacheSize to zero at the start of the
// reindex because the on-disk blocks had already resulted in a
// chain that didn't trigger the assertion below.
if (nd->witnessHeight < pindex->nHeight) {
assert(nWitnessCacheSize >= nd->witnesses.size());
}
}
}
// TODO: If nWitnessCache is zero, we need to regenerate the caches (#1302)
assert(nWitnessCacheSize > 0);
// For performance reasons, we write out the witness cache in
// CWallet::SetBestChain() (which also ensures that overall consistency
// of the wallet.dat is maintained).
} }
nWitnessCacheSize -= 1;
// TODO: If nWitnessCache is zero, we need to regenerate the caches (#1302)
assert(nWitnessCacheSize > 0);
// For performance reasons, we write out the witness cache in
// CWallet::SetBestChain() (which also ensures that overall consistency
// of the wallet.dat is maintained).
} }
bool CWallet::EncryptWallet(const SecureString& strWalletPassphrase) bool CWallet::EncryptWallet(const SecureString& strWalletPassphrase)
@ -1075,7 +1125,7 @@ bool CWallet::UpdateNullifierNoteMap()
ZCNoteDecryption dec; ZCNoteDecryption dec;
for (std::pair<const uint256, CWalletTx>& wtxItem : mapWallet) { for (std::pair<const uint256, CWalletTx>& wtxItem : mapWallet) {
for (mapNoteData_t::value_type& item : wtxItem.second.mapNoteData) { for (mapSproutNoteData_t::value_type& item : wtxItem.second.mapSproutNoteData) {
if (!item.second.nullifier) { if (!item.second.nullifier) {
if (GetNoteDecryptor(item.second.address, dec)) { if (GetNoteDecryptor(item.second.address, dec)) {
auto i = item.first.js; auto i = item.first.js;
@ -1103,7 +1153,7 @@ void CWallet::UpdateNullifierNoteMapWithTx(const CWalletTx& wtx)
{ {
{ {
LOCK(cs_wallet); LOCK(cs_wallet);
for (const mapNoteData_t::value_type& item : wtx.mapNoteData) { for (const mapSproutNoteData_t::value_type& item : wtx.mapSproutNoteData) {
if (item.second.nullifier) { if (item.second.nullifier) {
mapNullifiersToNotes[*item.second.nullifier] = item.first; mapNullifiersToNotes[*item.second.nullifier] = item.first;
} }
@ -1238,12 +1288,12 @@ bool CWallet::AddToWallet(const CWalletTx& wtxIn, bool fFromLoadWallet, CWalletD
bool CWallet::UpdatedNoteData(const CWalletTx& wtxIn, CWalletTx& wtx) bool CWallet::UpdatedNoteData(const CWalletTx& wtxIn, CWalletTx& wtx)
{ {
if (wtxIn.mapNoteData.empty() || wtxIn.mapNoteData == wtx.mapNoteData) { if (wtxIn.mapSproutNoteData.empty() || wtxIn.mapSproutNoteData == wtx.mapSproutNoteData) {
return false; return false;
} }
auto tmp = wtxIn.mapNoteData; auto tmp = wtxIn.mapSproutNoteData;
// Ensure we keep any cached witnesses we may already have // Ensure we keep any cached witnesses we may already have
for (const std::pair<JSOutPoint, CNoteData> nd : wtx.mapNoteData) { for (const std::pair<JSOutPoint, SproutNoteData> nd : wtx.mapSproutNoteData) {
if (tmp.count(nd.first) && nd.second.witnesses.size() > 0) { if (tmp.count(nd.first) && nd.second.witnesses.size() > 0) {
tmp.at(nd.first).witnesses.assign( tmp.at(nd.first).witnesses.assign(
nd.second.witnesses.cbegin(), nd.second.witnesses.cend()); nd.second.witnesses.cbegin(), nd.second.witnesses.cend());
@ -1251,7 +1301,7 @@ bool CWallet::UpdatedNoteData(const CWalletTx& wtxIn, CWalletTx& wtx)
tmp.at(nd.first).witnessHeight = nd.second.witnessHeight; tmp.at(nd.first).witnessHeight = nd.second.witnessHeight;
} }
// Now copy over the updated note data // Now copy over the updated note data
wtx.mapNoteData = tmp; wtx.mapSproutNoteData = tmp;
return true; return true;
} }
@ -1272,8 +1322,9 @@ bool CWallet::AddToWalletIfInvolvingMe(const CTransaction& tx, const CBlock* pbl
CWalletTx wtx(this,tx); CWalletTx wtx(this,tx);
if (noteData.size() > 0) { if (noteData.size() > 0) {
wtx.SetNoteData(noteData); wtx.SetSproutNoteData(noteData);
} }
// TODO: Sapling note data
// Get merkle branch if transaction was found in a block // Get merkle branch if transaction was found in a block
if (pblock) if (pblock)
@ -1365,14 +1416,14 @@ boost::optional<uint256> CWallet::GetNoteNullifier(const JSDescription& jsdesc,
* *
* It should never be necessary to call this method with a CWalletTx, because * It should never be necessary to call this method with a CWalletTx, because
* the result of FindMyNotes (for the addresses available at the time) will * the result of FindMyNotes (for the addresses available at the time) will
* already have been cached in CWalletTx.mapNoteData. * already have been cached in CWalletTx.mapSproutNoteData.
*/ */
mapNoteData_t CWallet::FindMyNotes(const CTransaction& tx) const mapSproutNoteData_t CWallet::FindMyNotes(const CTransaction& tx) const
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_SpendingKeyStore);
uint256 hash = tx.GetHash(); uint256 hash = tx.GetHash();
mapNoteData_t noteData; mapSproutNoteData_t noteData;
for (size_t i = 0; i < tx.vjoinsplit.size(); i++) { for (size_t i = 0; i < tx.vjoinsplit.size(); i++) {
auto hSig = tx.vjoinsplit[i].h_sig(*pzcashParams, tx.joinSplitPubKey); auto hSig = tx.vjoinsplit[i].h_sig(*pzcashParams, tx.joinSplitPubKey);
for (uint8_t j = 0; j < tx.vjoinsplit[i].ciphertexts.size(); j++) { for (uint8_t j = 0; j < tx.vjoinsplit[i].ciphertexts.size(); j++) {
@ -1386,10 +1437,10 @@ mapNoteData_t CWallet::FindMyNotes(const CTransaction& tx) const
item.second, item.second,
hSig, j); hSig, j);
if (nullifier) { if (nullifier) {
CNoteData nd {address, *nullifier}; SproutNoteData nd {address, *nullifier};
noteData.insert(std::make_pair(jsoutpt, nd)); noteData.insert(std::make_pair(jsoutpt, nd));
} else { } else {
CNoteData nd {address}; SproutNoteData nd {address};
noteData.insert(std::make_pair(jsoutpt, nd)); noteData.insert(std::make_pair(jsoutpt, nd));
} }
break; break;
@ -1418,32 +1469,57 @@ bool CWallet::IsFromMe(const uint256& nullifier) const
return false; return false;
} }
void CWallet::GetNoteWitnesses(std::vector<JSOutPoint> notes, void CWallet::GetSproutNoteWitnesses(std::vector<JSOutPoint> notes,
std::vector<boost::optional<ZCIncrementalWitness>>& witnesses, std::vector<boost::optional<ZCIncrementalWitness>>& witnesses,
uint256 &final_anchor) uint256 &final_anchor)
{ {
{ LOCK(cs_wallet);
LOCK(cs_wallet); witnesses.resize(notes.size());
witnesses.resize(notes.size()); boost::optional<uint256> rt;
boost::optional<uint256> rt; int i = 0;
int i = 0; for (JSOutPoint note : notes) {
for (JSOutPoint note : notes) { if (mapWallet.count(note.hash) &&
if (mapWallet.count(note.hash) && mapWallet[note.hash].mapSproutNoteData.count(note) &&
mapWallet[note.hash].mapNoteData.count(note) && mapWallet[note.hash].mapSproutNoteData[note].witnesses.size() > 0) {
mapWallet[note.hash].mapNoteData[note].witnesses.size() > 0) { witnesses[i] = mapWallet[note.hash].mapSproutNoteData[note].witnesses.front();
witnesses[i] = mapWallet[note.hash].mapNoteData[note].witnesses.front(); if (!rt) {
if (!rt) { rt = witnesses[i]->root();
rt = witnesses[i]->root(); } else {
} else { assert(*rt == witnesses[i]->root());
assert(*rt == witnesses[i]->root());
}
} }
i++;
} }
// All returned witnesses have the same anchor i++;
if (rt) { }
final_anchor = *rt; // All returned witnesses have the same anchor
if (rt) {
final_anchor = *rt;
}
}
void CWallet::GetSaplingNoteWitnesses(std::vector<SaplingOutPoint> notes,
std::vector<boost::optional<ZCSaplingIncrementalWitness>>& witnesses,
uint256 &final_anchor)
{
LOCK(cs_wallet);
witnesses.resize(notes.size());
boost::optional<uint256> rt;
int i = 0;
for (SaplingOutPoint note : notes) {
if (mapWallet.count(note.hash) &&
mapWallet[note.hash].mapSaplingNoteData.count(note) &&
mapWallet[note.hash].mapSaplingNoteData[note].witnesses.size() > 0) {
witnesses[i] = mapWallet[note.hash].mapSaplingNoteData[note].witnesses.front();
if (!rt) {
rt = witnesses[i]->root();
} else {
assert(*rt == witnesses[i]->root());
}
} }
i++;
}
// All returned witnesses have the same anchor
if (rt) {
final_anchor = *rt;
} }
} }
@ -1578,18 +1654,30 @@ CAmount CWallet::GetChange(const CTransaction& tx) const
return nChange; return nChange;
} }
void CWalletTx::SetNoteData(mapNoteData_t &noteData) void CWalletTx::SetSproutNoteData(mapSproutNoteData_t &noteData)
{ {
mapNoteData.clear(); mapSproutNoteData.clear();
for (const std::pair<JSOutPoint, CNoteData> nd : noteData) { for (const std::pair<JSOutPoint, SproutNoteData> nd : noteData) {
if (nd.first.js < vjoinsplit.size() && if (nd.first.js < vjoinsplit.size() &&
nd.first.n < vjoinsplit[nd.first.js].ciphertexts.size()) { nd.first.n < vjoinsplit[nd.first.js].ciphertexts.size()) {
// Store the address and nullifier for the Note // Store the address and nullifier for the Note
mapNoteData[nd.first] = nd.second; mapSproutNoteData[nd.first] = nd.second;
} else { } else {
// If FindMyNotes() was used to obtain noteData, // If FindMyNotes() was used to obtain noteData,
// this should never happen // this should never happen
throw std::logic_error("CWalletTx::SetNoteData(): Invalid note"); throw std::logic_error("CWalletTx::SetSproutNoteData(): Invalid note");
}
}
}
void CWalletTx::SetSaplingNoteData(mapSaplingNoteData_t &noteData)
{
mapSaplingNoteData.clear();
for (const std::pair<SaplingOutPoint, SaplingNoteData> nd : noteData) {
if (nd.first.n < vShieldedOutput.size()) {
mapSaplingNoteData[nd.first] = nd.second;
} else {
throw std::logic_error("CWalletTx::SetSaplingNoteData(): Invalid note");
} }
} }
} }
@ -1691,7 +1779,7 @@ void CWalletTx::GetAmounts(list<COutputEntry>& listReceived,
// Check output side // Check output side
if (!fMyJSDesc) { if (!fMyJSDesc) {
for (const std::pair<JSOutPoint, CNoteData> nd : this->mapNoteData) { for (const std::pair<JSOutPoint, SproutNoteData> nd : this->mapSproutNoteData) {
if (nd.first.js < vjoinsplit.size() && nd.first.n < vjoinsplit[nd.first.js].ciphertexts.size()) { if (nd.first.js < vjoinsplit.size() && nd.first.n < vjoinsplit[nd.first.js].ciphertexts.size()) {
fMyJSDesc = true; fMyJSDesc = true;
break; break;
@ -1892,12 +1980,18 @@ int CWallet::ScanForWalletTransactions(CBlockIndex* pindexStart, bool fUpdate)
ret++; ret++;
} }
ZCIncrementalMerkleTree tree; ZCIncrementalMerkleTree sproutTree;
ZCSaplingIncrementalMerkleTree saplingTree;
// This should never fail: we should always be able to get the tree // This should never fail: we should always be able to get the tree
// state on the path to the tip of our chain // state on the path to the tip of our chain
assert(pcoinsTip->GetSproutAnchorAt(pindex->hashSproutAnchor, tree)); assert(pcoinsTip->GetSproutAnchorAt(pindex->hashSproutAnchor, sproutTree));
if (pindex->pprev) {
if (NetworkUpgradeActive(pindex->pprev->nHeight, Params().GetConsensus(), Consensus::UPGRADE_SAPLING)) {
assert(pcoinsTip->GetSaplingAnchorAt(pindex->pprev->hashFinalSaplingRoot, saplingTree));
}
}
// Increment note witness caches // Increment note witness caches
IncrementNoteWitnesses(pindex, &block, tree); IncrementNoteWitnesses(pindex, &block, sproutTree, saplingTree);
pindex = chainActive.Next(pindex); pindex = chainActive.Next(pindex);
if (GetTime() >= nNow + 60) { if (GetTime() >= nNow + 60) {
@ -3821,13 +3915,13 @@ void CWallet::GetFilteredNotes(
continue; continue;
} }
if (wtx.mapNoteData.size() == 0) { if (wtx.mapSproutNoteData.size() == 0) {
continue; continue;
} }
for (auto & pair : wtx.mapNoteData) { for (auto & pair : wtx.mapSproutNoteData) {
JSOutPoint jsop = pair.first; JSOutPoint jsop = pair.first;
CNoteData nd = pair.second; SproutNoteData nd = pair.second;
SproutPaymentAddress pa = nd.address; SproutPaymentAddress pa = nd.address;
// skip notes which belong to a different payment address in the wallet // skip notes which belong to a different payment address in the wallet
@ -3902,13 +3996,13 @@ void CWallet::GetUnspentFilteredNotes(
continue; continue;
} }
if (wtx.mapNoteData.size() == 0) { if (wtx.mapSproutNoteData.size() == 0) {
continue; continue;
} }
for (auto & pair : wtx.mapNoteData) { for (auto & pair : wtx.mapSproutNoteData) {
JSOutPoint jsop = pair.first; JSOutPoint jsop = pair.first;
CNoteData nd = pair.second; SproutNoteData nd = pair.second;
SproutPaymentAddress pa = nd.address; SproutPaymentAddress pa = nd.address;
// skip notes which belong to a different payment address in the wallet // skip notes which belong to a different payment address in the wallet

View File

@ -197,14 +197,14 @@ public:
std::string ToString() const; std::string ToString() const;
}; };
class CNoteData class SproutNoteData
{ {
public: public:
libzcash::SproutPaymentAddress address; libzcash::SproutPaymentAddress address;
/** /**
* Cached note nullifier. May not be set if the wallet was not unlocked when * Cached note nullifier. May not be set if the wallet was not unlocked when
* this was CNoteData was created. If not set, we always assume that the * this was SproutNoteData was created. If not set, we always assume that the
* note has not been spent. * note has not been spent.
* *
* It's okay to cache the nullifier in the wallet, because we are storing * It's okay to cache the nullifier in the wallet, because we are storing
@ -225,7 +225,7 @@ public:
/** /**
* Block height corresponding to the most current witness. * Block height corresponding to the most current witness.
* *
* When we first create a CNoteData in CWallet::FindMyNotes, this is set to * When we first create a SproutNoteData in CWallet::FindMyNotes, this is set to
* -1 as a placeholder. The next time CWallet::ChainTip is called, we can * -1 as a placeholder. The next time CWallet::ChainTip is called, we can
* determine what height the witness cache for this note is valid for (even * determine what height the witness cache for this note is valid for (even
* if no witnesses were cached), and so can set the correct value in * if no witnesses were cached), and so can set the correct value in
@ -233,10 +233,10 @@ public:
*/ */
int witnessHeight; int witnessHeight;
CNoteData() : address(), nullifier(), witnessHeight {-1} { } SproutNoteData() : address(), nullifier(), witnessHeight {-1} { }
CNoteData(libzcash::SproutPaymentAddress a) : SproutNoteData(libzcash::SproutPaymentAddress a) :
address {a}, nullifier(), witnessHeight {-1} { } address {a}, nullifier(), witnessHeight {-1} { }
CNoteData(libzcash::SproutPaymentAddress a, uint256 n) : SproutNoteData(libzcash::SproutPaymentAddress a, uint256 n) :
address {a}, nullifier {n}, witnessHeight {-1} { } address {a}, nullifier {n}, witnessHeight {-1} { }
ADD_SERIALIZE_METHODS; ADD_SERIALIZE_METHODS;
@ -249,21 +249,35 @@ public:
READWRITE(witnessHeight); READWRITE(witnessHeight);
} }
friend bool operator<(const CNoteData& a, const CNoteData& b) { friend bool operator<(const SproutNoteData& a, const SproutNoteData& b) {
return (a.address < b.address || return (a.address < b.address ||
(a.address == b.address && a.nullifier < b.nullifier)); (a.address == b.address && a.nullifier < b.nullifier));
} }
friend bool operator==(const CNoteData& a, const CNoteData& b) { friend bool operator==(const SproutNoteData& a, const SproutNoteData& b) {
return (a.address == b.address && a.nullifier == b.nullifier); return (a.address == b.address && a.nullifier == b.nullifier);
} }
friend bool operator!=(const CNoteData& a, const CNoteData& b) { friend bool operator!=(const SproutNoteData& a, const SproutNoteData& b) {
return !(a == b); return !(a == b);
} }
}; };
typedef std::map<JSOutPoint, CNoteData> mapNoteData_t; class SaplingNoteData
{
public:
/**
* We initialize the hight to -1 for the same reason as we do in SproutNoteData.
* See the comment in that class for a full description.
*/
SaplingNoteData() : witnessHeight {-1} { }
std::list<ZCSaplingIncrementalWitness> witnesses;
int witnessHeight;
};
typedef std::map<JSOutPoint, SproutNoteData> mapSproutNoteData_t;
typedef std::map<SaplingOutPoint, SaplingNoteData> mapSaplingNoteData_t;
/** Decrypted note and its location in a transaction. */ /** Decrypted note and its location in a transaction. */
struct CSproutNotePlaintextEntry struct CSproutNotePlaintextEntry
@ -350,7 +364,8 @@ private:
public: public:
mapValue_t mapValue; mapValue_t mapValue;
mapNoteData_t mapNoteData; mapSproutNoteData_t mapSproutNoteData;
mapSaplingNoteData_t mapSaplingNoteData;
std::vector<std::pair<std::string, std::string> > vOrderForm; std::vector<std::pair<std::string, std::string> > vOrderForm;
unsigned int fTimeReceivedIsTxTime; unsigned int fTimeReceivedIsTxTime;
unsigned int nTimeReceived; //! time received by this node unsigned int nTimeReceived; //! time received by this node
@ -403,7 +418,8 @@ public:
{ {
pwallet = pwalletIn; pwallet = pwalletIn;
mapValue.clear(); mapValue.clear();
mapNoteData.clear(); mapSproutNoteData.clear();
mapSaplingNoteData.clear();
vOrderForm.clear(); vOrderForm.clear();
fTimeReceivedIsTxTime = false; fTimeReceivedIsTxTime = false;
nTimeReceived = 0; nTimeReceived = 0;
@ -453,12 +469,14 @@ public:
std::vector<CMerkleTx> vUnused; //! Used to be vtxPrev std::vector<CMerkleTx> vUnused; //! Used to be vtxPrev
READWRITE(vUnused); READWRITE(vUnused);
READWRITE(mapValue); READWRITE(mapValue);
READWRITE(mapNoteData); READWRITE(mapSproutNoteData);
READWRITE(vOrderForm); READWRITE(vOrderForm);
READWRITE(fTimeReceivedIsTxTime); READWRITE(fTimeReceivedIsTxTime);
READWRITE(nTimeReceived); READWRITE(nTimeReceived);
READWRITE(fFromMe); READWRITE(fFromMe);
READWRITE(fSpent); READWRITE(fSpent);
// TODO:
//READWRITE(mapSaplingNoteData);
if (ser_action.ForRead()) if (ser_action.ForRead())
{ {
@ -495,7 +513,8 @@ public:
MarkDirty(); MarkDirty();
} }
void SetNoteData(mapNoteData_t &noteData); void SetSproutNoteData(mapSproutNoteData_t &noteData);
void SetSaplingNoteData(mapSaplingNoteData_t &noteData);
//! filter decides which addresses will count towards the debit //! filter decides which addresses will count towards the debit
CAmount GetDebit(const isminefilter& filter) const; CAmount GetDebit(const isminefilter& filter) const;
@ -718,7 +737,8 @@ protected:
*/ */
void IncrementNoteWitnesses(const CBlockIndex* pindex, void IncrementNoteWitnesses(const CBlockIndex* pindex,
const CBlock* pblock, const CBlock* pblock,
ZCIncrementalMerkleTree& tree); ZCIncrementalMerkleTree& sproutTree,
ZCSaplingIncrementalMerkleTree& saplingTree);
/** /**
* pindex is the old tip being disconnected. * pindex is the old tip being disconnected.
*/ */
@ -842,7 +862,7 @@ public:
* *
* - GetFilteredNotes can't filter out spent notes. * - GetFilteredNotes can't filter out spent notes.
* *
* - Per the comment in CNoteData, we assume that if we don't have a * - Per the comment in SproutNoteData, we assume that if we don't have a
* cached nullifier, the note is not spent. * cached nullifier, the note is not spent.
* *
* Another more problematic implication is that the wallet can fail to * Another more problematic implication is that the wallet can fail to
@ -1053,12 +1073,16 @@ public:
const ZCNoteDecryption& dec, const ZCNoteDecryption& dec,
const uint256& hSig, const uint256& hSig,
uint8_t n) const; uint8_t n) const;
mapNoteData_t FindMyNotes(const CTransaction& tx) const; mapSproutNoteData_t FindMyNotes(const CTransaction& tx) const;
bool IsFromMe(const uint256& nullifier) const; bool IsFromMe(const uint256& nullifier) const;
void GetNoteWitnesses( void GetSproutNoteWitnesses(
std::vector<JSOutPoint> notes, std::vector<JSOutPoint> notes,
std::vector<boost::optional<ZCIncrementalWitness>>& witnesses, std::vector<boost::optional<ZCIncrementalWitness>>& witnesses,
uint256 &final_anchor); uint256 &final_anchor);
void GetSaplingNoteWitnesses(
std::vector<SaplingOutPoint> notes,
std::vector<boost::optional<ZCSaplingIncrementalWitness>>& witnesses,
uint256 &final_anchor);
isminetype IsMine(const CTxIn& txin) const; isminetype IsMine(const CTxIn& txin) const;
CAmount GetDebit(const CTxIn& txin, const isminefilter& filter) const; CAmount GetDebit(const CTxIn& txin, const isminefilter& filter) const;
@ -1072,7 +1096,7 @@ public:
CAmount GetDebit(const CTransaction& tx, const isminefilter& filter) const; CAmount GetDebit(const CTransaction& tx, const isminefilter& filter) const;
CAmount GetCredit(const CTransaction& tx, const isminefilter& filter) const; CAmount GetCredit(const CTransaction& tx, const isminefilter& filter) const;
CAmount GetChange(const CTransaction& tx) const; CAmount GetChange(const CTransaction& tx) const;
void ChainTip(const CBlockIndex *pindex, const CBlock *pblock, ZCIncrementalMerkleTree tree, bool added); void ChainTip(const CBlockIndex *pindex, const CBlock *pblock, ZCIncrementalMerkleTree sproutTree, ZCSaplingIncrementalMerkleTree saplingTree, bool added);
/** Saves witness caches and best block locator to disk. */ /** Saves witness caches and best block locator to disk. */
void SetBestChain(const CBlockLocator& loc); void SetBestChain(const CBlockLocator& loc);
std::set<std::pair<libzcash::PaymentAddress, uint256>> GetNullifiersForAddresses(const std::set<libzcash::PaymentAddress> & addresses); std::set<std::pair<libzcash::PaymentAddress, uint256>> GetNullifiersForAddresses(const std::set<libzcash::PaymentAddress> & addresses);

View File

@ -298,7 +298,8 @@ double benchmark_try_decrypt_notes(size_t nAddrs)
double benchmark_increment_note_witnesses(size_t nTxs) double benchmark_increment_note_witnesses(size_t nTxs)
{ {
CWallet wallet; CWallet wallet;
ZCIncrementalMerkleTree tree; ZCIncrementalMerkleTree sproutTree;
ZCSaplingIncrementalMerkleTree saplingTree;
auto sk = libzcash::SproutSpendingKey::random(); auto sk = libzcash::SproutSpendingKey::random();
wallet.AddSpendingKey(sk); wallet.AddSpendingKey(sk);
@ -310,12 +311,12 @@ double benchmark_increment_note_witnesses(size_t nTxs)
auto note = GetNote(*pzcashParams, sk, wtx, 0, 1); auto note = GetNote(*pzcashParams, sk, wtx, 0, 1);
auto nullifier = note.nullifier(sk); auto nullifier = note.nullifier(sk);
mapNoteData_t noteData; mapSproutNoteData_t noteData;
JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1};
CNoteData nd {sk.address(), nullifier}; SproutNoteData nd {sk.address(), nullifier};
noteData[jsoutpt] = nd; noteData[jsoutpt] = nd;
wtx.SetNoteData(noteData); wtx.SetSproutNoteData(noteData);
wallet.AddToWallet(wtx, true, NULL); wallet.AddToWallet(wtx, true, NULL);
block1.vtx.push_back(wtx); block1.vtx.push_back(wtx);
} }
@ -323,7 +324,7 @@ double benchmark_increment_note_witnesses(size_t nTxs)
index1.nHeight = 1; index1.nHeight = 1;
// Increment to get transactions witnessed // Increment to get transactions witnessed
wallet.ChainTip(&index1, &block1, tree, true); wallet.ChainTip(&index1, &block1, sproutTree, saplingTree, true);
// Second block // Second block
CBlock block2; CBlock block2;
@ -333,12 +334,12 @@ double benchmark_increment_note_witnesses(size_t nTxs)
auto note = GetNote(*pzcashParams, sk, wtx, 0, 1); auto note = GetNote(*pzcashParams, sk, wtx, 0, 1);
auto nullifier = note.nullifier(sk); auto nullifier = note.nullifier(sk);
mapNoteData_t noteData; mapSproutNoteData_t noteData;
JSOutPoint jsoutpt {wtx.GetHash(), 0, 1}; JSOutPoint jsoutpt {wtx.GetHash(), 0, 1};
CNoteData nd {sk.address(), nullifier}; SproutNoteData nd {sk.address(), nullifier};
noteData[jsoutpt] = nd; noteData[jsoutpt] = nd;
wtx.SetNoteData(noteData); wtx.SetSproutNoteData(noteData);
wallet.AddToWallet(wtx, true, NULL); wallet.AddToWallet(wtx, true, NULL);
block2.vtx.push_back(wtx); block2.vtx.push_back(wtx);
} }
@ -347,7 +348,7 @@ double benchmark_increment_note_witnesses(size_t nTxs)
struct timeval tv_start; struct timeval tv_start;
timer_start(tv_start); timer_start(tv_start);
wallet.ChainTip(&index2, &block2, tree, true); wallet.ChainTip(&index2, &block2, sproutTree, saplingTree, true);
return timer_stop(tv_start); return timer_stop(tv_start);
} }