diff --git a/src/Makefile.gtest.include b/src/Makefile.gtest.include index dab686806..c93e04332 100644 --- a/src/Makefile.gtest.include +++ b/src/Makefile.gtest.include @@ -18,7 +18,8 @@ zcash_gtest_SOURCES = \ gtest/test_txid.cpp \ gtest/test_wallet_zkeys.cpp \ gtest/test_libzcash_utils.cpp \ - gtest/test_proofs.cpp + gtest/test_proofs.cpp \ + wallet/gtest/test_wallet.cpp zcash_gtest_CPPFLAGS = -DMULTICORE -fopenmp -DBINARY_OUTPUT -DCURVE_ALT_BN128 -DSTATIC diff --git a/src/gtest/test_keystore.cpp b/src/gtest/test_keystore.cpp index 8587ba49b..11d967e89 100644 --- a/src/gtest/test_keystore.cpp +++ b/src/gtest/test_keystore.cpp @@ -27,3 +27,17 @@ TEST(keystore_tests, store_and_retrieve_spending_key) { EXPECT_EQ(1, addrs.size()); EXPECT_EQ(1, addrs.count(addr)); } + +TEST(keystore_tests, store_and_retrieve_note_decryptor) { + CBasicKeyStore keyStore; + ZCNoteDecryption decOut; + + auto sk = libzcash::SpendingKey::random(); + auto addr = sk.address(); + + EXPECT_FALSE(keyStore.GetNoteDecryptor(addr, decOut)); + + keyStore.AddSpendingKey(sk); + EXPECT_TRUE(keyStore.GetNoteDecryptor(addr, decOut)); + EXPECT_EQ(ZCNoteDecryption(sk.viewing_key()), decOut); +} diff --git a/src/keystore.cpp b/src/keystore.cpp index 7240cd747..f32ba0c32 100644 --- a/src/keystore.cpp +++ b/src/keystore.cpp @@ -87,6 +87,8 @@ bool CBasicKeyStore::HaveWatchOnly() const bool CBasicKeyStore::AddSpendingKey(const libzcash::SpendingKey &sk) { LOCK(cs_SpendingKeyStore); - mapSpendingKeys[sk.address()] = sk; + auto address = sk.address(); + mapSpendingKeys[address] = sk; + mapNoteDecryptors.insert(std::make_pair(address, ZCNoteDecryption(sk.viewing_key()))); return true; } diff --git a/src/keystore.h b/src/keystore.h index 987f32070..aa3aefdf2 100644 --- a/src/keystore.h +++ b/src/keystore.h @@ -12,6 +12,7 @@ #include "script/standard.h" #include "sync.h" #include "zcash/Address.hpp" +#include "zcash/NoteEncryption.hpp" #include #include @@ -60,6 +61,7 @@ typedef std::map KeyMap; typedef std::map ScriptMap; typedef std::set WatchOnlySet; typedef std::map SpendingKeyMap; +typedef std::map NoteDecryptorMap; /** Basic key store, that keeps keys in an address->secret map */ class CBasicKeyStore : public CKeyStore @@ -69,6 +71,7 @@ protected: ScriptMap mapScripts; WatchOnlySet setWatchOnly; SpendingKeyMap mapSpendingKeys; + NoteDecryptorMap mapNoteDecryptors; public: bool AddKeyPubKey(const CKey& key, const CPubKey &pubkey); @@ -139,6 +142,19 @@ public: } return false; } + bool GetNoteDecryptor(const libzcash::PaymentAddress &address, ZCNoteDecryption &decOut) const + { + { + LOCK(cs_SpendingKeyStore); + NoteDecryptorMap::const_iterator mi = mapNoteDecryptors.find(address); + if (mi != mapNoteDecryptors.end()) + { + decOut = mi->second; + return true; + } + } + return false; + } void GetPaymentAddresses(std::set &setAddress) const { setAddress.clear(); diff --git a/src/main.cpp b/src/main.cpp index ddeed4e05..9f850f493 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -2398,11 +2398,16 @@ bool static DisconnectTip(CValidationState &state) { mempool.check(pcoinsTip); // Update chainActive and related variables. UpdateTip(pindexDelete->pprev); + // Get the current commitment tree + ZCIncrementalMerkleTree newTree; + assert(pcoinsTip->GetAnchorAt(pcoinsTip->GetBestAnchor(), newTree)); // Let wallets know transactions went from 1-confirmed to // 0-confirmed or conflicted: BOOST_FOREACH(const CTransaction &tx, block.vtx) { SyncWithWallets(tx, NULL); } + // Update cached incremental witnesses + GetMainSignals().ChainTip(pindexDelete, &block, newTree, false); return true; } @@ -2427,6 +2432,9 @@ bool static ConnectTip(CValidationState &state, CBlockIndex *pindexNew, CBlock * return AbortNode(state, "Failed to read block"); pblock = █ } + // Get the current commitment tree + ZCIncrementalMerkleTree oldTree; + assert(pcoinsTip->GetAnchorAt(pcoinsTip->GetBestAnchor(), oldTree)); // Apply the block atomically to the chain state. int64_t nTime2 = GetTimeMicros(); nTimeReadFromDisk += nTime2 - nTime1; int64_t nTime3; @@ -2468,6 +2476,8 @@ bool static ConnectTip(CValidationState &state, CBlockIndex *pindexNew, CBlock * BOOST_FOREACH(const CTransaction &tx, pblock->vtx) { SyncWithWallets(tx, pblock); } + // Update cached incremental witnesses + GetMainSignals().ChainTip(pindexNew, pblock, oldTree, true); int64_t nTime6 = GetTimeMicros(); nTimePostConnect += nTime6 - nTime5; nTimeTotal += nTime6 - nTime1; LogPrint("bench", " - Connect postprocess: %.2fms [%.2fs]\n", (nTime6 - nTime5) * 0.001, nTimePostConnect * 0.000001); diff --git a/src/primitives/transaction.cpp b/src/primitives/transaction.cpp index 854a0be66..5c1f94758 100644 --- a/src/primitives/transaction.cpp +++ b/src/primitives/transaction.cpp @@ -15,11 +15,14 @@ JSDescription::JSDescription(ZCJoinSplit& params, const boost::array& inputs, const boost::array& outputs, CAmount vpub_old, - CAmount vpub_new) : vpub_old(vpub_old), vpub_new(vpub_new), anchor(anchor) + CAmount vpub_new, + bool computeProof) : vpub_old(vpub_old), vpub_new(vpub_new), anchor(anchor) { boost::array notes; - params.loadProvingKey(); + if (computeProof) { + params.loadProvingKey(); + } proof = params.prove( inputs, outputs, @@ -33,7 +36,8 @@ JSDescription::JSDescription(ZCJoinSplit& params, commitments, vpub_old, vpub_new, - anchor + anchor, + computeProof ); } diff --git a/src/primitives/transaction.h b/src/primitives/transaction.h index c88b26d17..44375fc63 100644 --- a/src/primitives/transaction.h +++ b/src/primitives/transaction.h @@ -74,7 +74,8 @@ public: const boost::array& inputs, const boost::array& outputs, CAmount vpub_old, - CAmount vpub_new + CAmount vpub_new, + bool computeProof = true // Set to false in some tests ); // Verifies that the JoinSplit proof is correct. diff --git a/src/serialize.h b/src/serialize.h index aca3ed076..34d41bf84 100644 --- a/src/serialize.h +++ b/src/serialize.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -544,6 +545,13 @@ template unsigned int GetSerializeSize(co template void Serialize(Stream& os, const std::set& m, int nType, int nVersion); template void Unserialize(Stream& is, std::set& m, int nType, int nVersion); +/** + * list + */ +template unsigned int GetSerializeSize(const std::list& m, int nType, int nVersion); +template void Serialize(Stream& os, const std::list& m, int nType, int nVersion); +template void Unserialize(Stream& is, std::list& m, int nType, int nVersion); + @@ -890,6 +898,42 @@ void Unserialize(Stream& is, std::set& m, int nType, int nVersion) +/** + * list + */ +template +unsigned int GetSerializeSize(const std::list& l, int nType, int nVersion) +{ + unsigned int nSize = GetSizeOfCompactSize(l.size()); + for (typename std::list::const_iterator it = l.begin(); it != l.end(); ++it) + nSize += GetSerializeSize((*it), nType, nVersion); + return nSize; +} + +template +void Serialize(Stream& os, const std::list& l, int nType, int nVersion) +{ + WriteCompactSize(os, l.size()); + for (typename std::list::const_iterator it = l.begin(); it != l.end(); ++it) + Serialize(os, (*it), nType, nVersion); +} + +template +void Unserialize(Stream& is, std::list& l, int nType, int nVersion) +{ + l.clear(); + unsigned int nSize = ReadCompactSize(is); + typename std::list::iterator it = l.begin(); + for (unsigned int i = 0; i < nSize; i++) + { + T item; + Unserialize(is, item, nType, nVersion); + l.push_back(item); + } +} + + + /** * Support for ADD_SERIALIZE_METHODS and READWRITE macro */ diff --git a/src/validationinterface.cpp b/src/validationinterface.cpp index aa9aefb0d..3df6cd3f2 100644 --- a/src/validationinterface.cpp +++ b/src/validationinterface.cpp @@ -16,6 +16,7 @@ void RegisterValidationInterface(CValidationInterface* pwalletIn) { g_signals.SyncTransaction.connect(boost::bind(&CValidationInterface::SyncTransaction, pwalletIn, _1, _2)); g_signals.EraseTransaction.connect(boost::bind(&CValidationInterface::EraseFromWallet, pwalletIn, _1)); g_signals.UpdatedTransaction.connect(boost::bind(&CValidationInterface::UpdatedTransaction, pwalletIn, _1)); + g_signals.ChainTip.connect(boost::bind(&CValidationInterface::ChainTip, pwalletIn, _1, _2, _3, _4)); g_signals.SetBestChain.connect(boost::bind(&CValidationInterface::SetBestChain, pwalletIn, _1)); g_signals.Inventory.connect(boost::bind(&CValidationInterface::Inventory, pwalletIn, _1)); g_signals.Broadcast.connect(boost::bind(&CValidationInterface::ResendWalletTransactions, pwalletIn, _1)); @@ -26,6 +27,7 @@ void UnregisterValidationInterface(CValidationInterface* pwalletIn) { g_signals.BlockChecked.disconnect(boost::bind(&CValidationInterface::BlockChecked, pwalletIn, _1, _2)); g_signals.Broadcast.disconnect(boost::bind(&CValidationInterface::ResendWalletTransactions, pwalletIn, _1)); g_signals.Inventory.disconnect(boost::bind(&CValidationInterface::Inventory, pwalletIn, _1)); + g_signals.ChainTip.disconnect(boost::bind(&CValidationInterface::ChainTip, pwalletIn, _1, _2, _3, _4)); g_signals.SetBestChain.disconnect(boost::bind(&CValidationInterface::SetBestChain, pwalletIn, _1)); g_signals.UpdatedTransaction.disconnect(boost::bind(&CValidationInterface::UpdatedTransaction, pwalletIn, _1)); g_signals.EraseTransaction.disconnect(boost::bind(&CValidationInterface::EraseFromWallet, pwalletIn, _1)); @@ -36,6 +38,7 @@ void UnregisterAllValidationInterfaces() { g_signals.BlockChecked.disconnect_all_slots(); g_signals.Broadcast.disconnect_all_slots(); g_signals.Inventory.disconnect_all_slots(); + g_signals.ChainTip.disconnect_all_slots(); g_signals.SetBestChain.disconnect_all_slots(); g_signals.UpdatedTransaction.disconnect_all_slots(); g_signals.EraseTransaction.disconnect_all_slots(); diff --git a/src/validationinterface.h b/src/validationinterface.h index 2a4c7ecce..144e716bb 100644 --- a/src/validationinterface.h +++ b/src/validationinterface.h @@ -8,7 +8,10 @@ #include +#include "zcash/IncrementalMerkleTree.hpp" + class CBlock; +class CBlockIndex; struct CBlockLocator; class CTransaction; class CValidationInterface; @@ -30,6 +33,7 @@ class CValidationInterface { protected: virtual void SyncTransaction(const CTransaction &tx, const CBlock *pblock) {} virtual void EraseFromWallet(const uint256 &hash) {} + virtual void ChainTip(const CBlockIndex *pindex, const CBlock *pblock, ZCIncrementalMerkleTree tree, bool added) {} virtual void SetBestChain(const CBlockLocator &locator) {} virtual void UpdatedTransaction(const uint256 &hash) {} virtual void Inventory(const uint256 &hash) {} @@ -47,6 +51,8 @@ struct CMainSignals { boost::signals2::signal EraseTransaction; /** Notifies listeners of an updated transaction without new data (for now: a coinbase potentially becoming visible). */ boost::signals2::signal UpdatedTransaction; + /** Notifies listeners of a change to the tip of the active block chain. */ + boost::signals2::signal ChainTip; /** Notifies listeners of a new active block chain. */ boost::signals2::signal SetBestChain; /** Notifies listeners about an inventory item being seen on the network. */ diff --git a/src/wallet/gtest/test_wallet.cpp b/src/wallet/gtest/test_wallet.cpp new file mode 100644 index 000000000..6b35baa2d --- /dev/null +++ b/src/wallet/gtest/test_wallet.cpp @@ -0,0 +1,474 @@ +#include +#include +#include + +#include "base58.h" +#include "chainparams.h" +#include "main.h" +#include "random.h" +#include "wallet/wallet.h" +#include "zcash/JoinSplit.hpp" +#include "zcash/Note.hpp" +#include "zcash/NoteEncryption.hpp" + +using ::testing::_; +using ::testing::Return; + +ZCJoinSplit* params = ZCJoinSplit::Unopened(); + +class TestWallet : public CWallet { +public: + TestWallet() : CWallet() { } + + void IncrementNoteWitnesses(const CBlockIndex* pindex, + const CBlock* pblock, + ZCIncrementalMerkleTree tree) { + CWallet::IncrementNoteWitnesses(pindex, pblock, tree); + } + void DecrementNoteWitnesses() { + CWallet::DecrementNoteWitnesses(); + } +}; + +CWalletTx GetValidReceive(const libzcash::SpendingKey& sk, CAmount value, bool randomInputs) { + CMutableTransaction mtx; + mtx.nVersion = 2; // Enable JoinSplits + mtx.vin.resize(2); + if (randomInputs) { + mtx.vin[0].prevout.hash = GetRandHash(); + mtx.vin[1].prevout.hash = GetRandHash(); + } else { + mtx.vin[0].prevout.hash = uint256S("0000000000000000000000000000000000000000000000000000000000000001"); + mtx.vin[1].prevout.hash = uint256S("0000000000000000000000000000000000000000000000000000000000000002"); + } + mtx.vin[0].prevout.n = 0; + mtx.vin[1].prevout.n = 0; + + // Generate an ephemeral keypair. + uint256 joinSplitPubKey; + unsigned char joinSplitPrivKey[crypto_sign_SECRETKEYBYTES]; + crypto_sign_keypair(joinSplitPubKey.begin(), joinSplitPrivKey); + mtx.joinSplitPubKey = joinSplitPubKey; + + boost::array inputs = { + libzcash::JSInput(), // dummy input + libzcash::JSInput() // dummy input + }; + + boost::array outputs = { + libzcash::JSOutput(sk.address(), value), + libzcash::JSOutput(sk.address(), value) + }; + + boost::array output_notes; + + // Prepare JoinSplits + uint256 rt; + JSDescription jsdesc {*params, mtx.joinSplitPubKey, rt, + inputs, outputs, value, 0, false}; + mtx.vjoinsplit.push_back(jsdesc); + + // Empty output script. + CScript scriptCode; + CTransaction signTx(mtx); + uint256 dataToBeSigned = SignatureHash(scriptCode, signTx, NOT_AN_INPUT, SIGHASH_ALL); + + // Add the signature + assert(crypto_sign_detached(&mtx.joinSplitSig[0], NULL, + dataToBeSigned.begin(), 32, + joinSplitPrivKey + ) == 0); + + CTransaction tx {mtx}; + CWalletTx wtx {NULL, tx}; + return wtx; +} + +libzcash::Note GetNote(const libzcash::SpendingKey& sk, + const CTransaction& tx, size_t js, size_t n) { + ZCNoteDecryption decryptor {sk.viewing_key()}; + auto hSig = tx.vjoinsplit[js].h_sig(*params, tx.joinSplitPubKey); + auto note_pt = libzcash::NotePlaintext::decrypt( + decryptor, + tx.vjoinsplit[js].ciphertexts[n], + tx.vjoinsplit[js].ephemeralKey, + hSig, + (unsigned char) n); + return note_pt.note(sk.address()); +} + +CWalletTx GetValidSpend(const libzcash::SpendingKey& sk, + const libzcash::Note& note, CAmount value) { + CMutableTransaction mtx; + mtx.vout.resize(2); + mtx.vout[0].nValue = value; + mtx.vout[1].nValue = 0; + + // Generate an ephemeral keypair. + uint256 joinSplitPubKey; + unsigned char joinSplitPrivKey[crypto_sign_SECRETKEYBYTES]; + crypto_sign_keypair(joinSplitPubKey.begin(), joinSplitPrivKey); + mtx.joinSplitPubKey = joinSplitPubKey; + + // Fake tree for the unused witness + ZCIncrementalMerkleTree tree; + + boost::array inputs = { + libzcash::JSInput(tree.witness(), note, sk), + libzcash::JSInput() // dummy input + }; + + boost::array outputs = { + libzcash::JSOutput(), // dummy output + libzcash::JSOutput() // dummy output + }; + + boost::array output_notes; + + // Prepare JoinSplits + uint256 rt; + JSDescription jsdesc {*params, mtx.joinSplitPubKey, rt, + inputs, outputs, 0, value, false}; + mtx.vjoinsplit.push_back(jsdesc); + + // Empty output script. + CScript scriptCode; + CTransaction signTx(mtx); + uint256 dataToBeSigned = SignatureHash(scriptCode, signTx, NOT_AN_INPUT, SIGHASH_ALL); + + // Add the signature + assert(crypto_sign_detached(&mtx.joinSplitSig[0], NULL, + dataToBeSigned.begin(), 32, + joinSplitPrivKey + ) == 0); + CTransaction tx {mtx}; + CWalletTx wtx {NULL, tx}; + return wtx; +} + +TEST(wallet_tests, note_data_serialisation) { + auto sk = libzcash::SpendingKey::random(); + auto wtx = GetValidReceive(sk, 10, true); + auto note = GetNote(sk, wtx, 0, 1); + auto nullifier = note.nullifier(sk); + + mapNoteData_t noteData; + JSOutPoint jsoutpt {wtx.GetTxid(), 0, 1}; + CNoteData nd {sk.address(), nullifier}; + ZCIncrementalMerkleTree tree; + nd.witnesses.push_front(tree.witness()); + noteData[jsoutpt] = nd; + + CDataStream ss(SER_DISK, CLIENT_VERSION); + ss << noteData; + + mapNoteData_t noteData2; + ss >> noteData2; + + EXPECT_EQ(noteData, noteData2); + EXPECT_EQ(noteData[jsoutpt].witnesses, noteData2[jsoutpt].witnesses); +} + +TEST(wallet_tests, set_note_addrs_in_cwallettx) { + auto sk = libzcash::SpendingKey::random(); + auto wtx = GetValidReceive(sk, 10, true); + auto note = GetNote(sk, wtx, 0, 1); + auto nullifier = note.nullifier(sk); + EXPECT_EQ(0, wtx.mapNoteData.size()); + + mapNoteData_t noteData; + JSOutPoint jsoutpt {wtx.GetTxid(), 0, 1}; + CNoteData nd {sk.address(), nullifier}; + noteData[jsoutpt] = nd; + + wtx.SetNoteData(noteData); + EXPECT_EQ(noteData, wtx.mapNoteData); +} + +TEST(wallet_tests, set_invalid_note_addrs_in_cwallettx) { + CWalletTx wtx; + EXPECT_EQ(0, wtx.mapNoteData.size()); + + mapNoteData_t noteData; + auto sk = libzcash::SpendingKey::random(); + JSOutPoint jsoutpt {wtx.GetTxid(), 0, 1}; + CNoteData nd {sk.address(), uint256()}; + noteData[jsoutpt] = nd; + + EXPECT_THROW(wtx.SetNoteData(noteData), std::logic_error); +} + +TEST(wallet_tests, find_note_in_tx) { + CWallet wallet; + + auto sk = libzcash::SpendingKey::random(); + wallet.AddSpendingKey(sk); + + auto wtx = GetValidReceive(sk, 10, true); + auto note = GetNote(sk, wtx, 0, 1); + auto nullifier = note.nullifier(sk); + + auto noteMap = wallet.FindMyNotes(wtx); + EXPECT_EQ(2, noteMap.size()); + + JSOutPoint jsoutpt {wtx.GetTxid(), 0, 1}; + CNoteData nd {sk.address(), nullifier}; + EXPECT_EQ(1, noteMap.count(jsoutpt)); + EXPECT_EQ(nd, noteMap[jsoutpt]); +} + +TEST(wallet_tests, get_conflicted_notes) { + CWallet wallet; + + auto sk = libzcash::SpendingKey::random(); + wallet.AddSpendingKey(sk); + + auto wtx = GetValidReceive(sk, 10, true); + auto note = GetNote(sk, wtx, 0, 1); + auto nullifier = note.nullifier(sk); + + auto wtx2 = GetValidSpend(sk, note, 5); + auto wtx3 = GetValidSpend(sk, note, 10); + auto hash2 = wtx2.GetTxid(); + auto hash3 = wtx3.GetTxid(); + + // No conflicts for no spends + EXPECT_EQ(0, wallet.GetConflicts(hash2).size()); + wallet.AddToWallet(wtx, true, NULL); + EXPECT_EQ(0, wallet.GetConflicts(hash2).size()); + + // No conflicts for one spend + wallet.AddToWallet(wtx2, true, NULL); + EXPECT_EQ(0, wallet.GetConflicts(hash2).size()); + + // Conflicts for two spends + wallet.AddToWallet(wtx3, true, NULL); + auto c3 = wallet.GetConflicts(hash2); + EXPECT_EQ(2, c3.size()); + EXPECT_EQ(std::set({hash2, hash3}), c3); +} + +TEST(wallet_tests, nullifier_is_spent) { + CWallet wallet; + + auto sk = libzcash::SpendingKey::random(); + wallet.AddSpendingKey(sk); + + auto wtx = GetValidReceive(sk, 10, true); + auto note = GetNote(sk, wtx, 0, 1); + auto nullifier = note.nullifier(sk); + + EXPECT_FALSE(wallet.IsSpent(nullifier)); + + wallet.AddToWallet(wtx, true, NULL); + EXPECT_FALSE(wallet.IsSpent(nullifier)); + + auto wtx2 = GetValidSpend(sk, note, 5); + wallet.AddToWallet(wtx2, true, NULL); + EXPECT_FALSE(wallet.IsSpent(nullifier)); + + // Fake-mine the transaction + EXPECT_EQ(-1, chainActive.Height()); + CBlock block; + block.vtx.push_back(wtx2); + block.hashMerkleRoot = block.BuildMerkleTree(); + auto blockHash = block.GetHash(); + CBlockIndex fakeIndex {block}; + mapBlockIndex.insert(std::make_pair(blockHash, &fakeIndex)); + chainActive.SetTip(&fakeIndex); + EXPECT_TRUE(chainActive.Contains(&fakeIndex)); + EXPECT_EQ(0, chainActive.Height()); + + wtx2.SetMerkleBranch(block); + wallet.AddToWallet(wtx2, true, NULL); + EXPECT_TRUE(wallet.IsSpent(nullifier)); + + // Tear down + chainActive.SetTip(NULL); + mapBlockIndex.erase(blockHash); +} + +TEST(wallet_tests, navigate_from_nullifier_to_note) { + CWallet wallet; + + auto sk = libzcash::SpendingKey::random(); + wallet.AddSpendingKey(sk); + + auto wtx = GetValidReceive(sk, 10, true); + auto note = GetNote(sk, wtx, 0, 1); + auto nullifier = note.nullifier(sk); + + mapNoteData_t noteData; + JSOutPoint jsoutpt {wtx.GetTxid(), 0, 1}; + CNoteData nd {sk.address(), nullifier}; + noteData[jsoutpt] = nd; + + wtx.SetNoteData(noteData); + + EXPECT_EQ(0, wallet.mapNullifiersToNotes.count(nullifier)); + + wallet.AddToWallet(wtx, true, NULL); + EXPECT_EQ(1, wallet.mapNullifiersToNotes.count(nullifier)); + EXPECT_EQ(wtx.GetTxid(), wallet.mapNullifiersToNotes[nullifier].hash); + EXPECT_EQ(0, wallet.mapNullifiersToNotes[nullifier].js); + EXPECT_EQ(1, wallet.mapNullifiersToNotes[nullifier].n); +} + +TEST(wallet_tests, spent_note_is_from_me) { + CWallet wallet; + + auto sk = libzcash::SpendingKey::random(); + wallet.AddSpendingKey(sk); + + auto wtx = GetValidReceive(sk, 10, true); + auto note = GetNote(sk, wtx, 0, 1); + auto nullifier = note.nullifier(sk); + auto wtx2 = GetValidSpend(sk, note, 5); + + EXPECT_FALSE(wallet.IsFromMe(wtx)); + EXPECT_FALSE(wallet.IsFromMe(wtx2)); + + mapNoteData_t noteData; + JSOutPoint jsoutpt {wtx.GetTxid(), 0, 1}; + CNoteData nd {sk.address(), nullifier}; + noteData[jsoutpt] = nd; + + wtx.SetNoteData(noteData); + EXPECT_FALSE(wallet.IsFromMe(wtx)); + EXPECT_FALSE(wallet.IsFromMe(wtx2)); + + wallet.AddToWallet(wtx, true, NULL); + EXPECT_FALSE(wallet.IsFromMe(wtx)); + EXPECT_TRUE(wallet.IsFromMe(wtx2)); +} + +TEST(wallet_tests, cached_witnesses_empty_chain) { + TestWallet wallet; + + auto sk = libzcash::SpendingKey::random(); + wallet.AddSpendingKey(sk); + + auto wtx = GetValidReceive(sk, 10, true); + auto note = GetNote(sk, wtx, 0, 0); + auto note2 = GetNote(sk, wtx, 0, 1); + auto nullifier = note.nullifier(sk); + auto nullifier2 = note2.nullifier(sk); + + mapNoteData_t noteData; + JSOutPoint jsoutpt {wtx.GetTxid(), 0, 0}; + JSOutPoint jsoutpt2 {wtx.GetTxid(), 0, 1}; + CNoteData nd {sk.address(), nullifier}; + CNoteData nd2 {sk.address(), nullifier2}; + noteData[jsoutpt] = nd; + noteData[jsoutpt2] = nd2; + wtx.SetNoteData(noteData); + + std::vector notes {jsoutpt, jsoutpt2}; + std::vector> witnesses; + uint256 anchor; + + wallet.GetNoteWitnesses(notes, witnesses, anchor); + EXPECT_FALSE((bool) witnesses[0]); + EXPECT_FALSE((bool) witnesses[1]); + + wallet.AddToWallet(wtx, true, NULL); + witnesses.clear(); + wallet.GetNoteWitnesses(notes, witnesses, anchor); + EXPECT_FALSE((bool) witnesses[0]); + EXPECT_FALSE((bool) witnesses[1]); + + CBlock block; + block.vtx.push_back(wtx); + ZCIncrementalMerkleTree tree; + wallet.IncrementNoteWitnesses(NULL, &block, tree); + witnesses.clear(); + wallet.GetNoteWitnesses(notes, witnesses, anchor); + EXPECT_TRUE((bool) witnesses[0]); + EXPECT_TRUE((bool) witnesses[1]); + + // Until #1302 is implemented, this should triggger an assertion + EXPECT_DEATH(wallet.DecrementNoteWitnesses(), + "Assertion `nWitnessCacheSize > 0' failed."); +} + +TEST(wallet_tests, cached_witnesses_chain_tip) { + TestWallet wallet; + uint256 anchor1; + CBlock block1; + ZCIncrementalMerkleTree tree; + + auto sk = libzcash::SpendingKey::random(); + wallet.AddSpendingKey(sk); + + { + // First transaction (case tested in _empty_chain) + auto wtx = GetValidReceive(sk, 10, true); + auto note = GetNote(sk, wtx, 0, 1); + auto nullifier = note.nullifier(sk); + + mapNoteData_t noteData; + JSOutPoint jsoutpt {wtx.GetTxid(), 0, 1}; + CNoteData nd {sk.address(), nullifier}; + noteData[jsoutpt] = nd; + wtx.SetNoteData(noteData); + wallet.AddToWallet(wtx, true, NULL); + + std::vector notes {jsoutpt}; + std::vector> witnesses; + + // First block (case tested in _empty_chain) + block1.vtx.push_back(wtx); + wallet.IncrementNoteWitnesses(NULL, &block1, tree); + // Called to fetch anchor + wallet.GetNoteWitnesses(notes, witnesses, anchor1); + } + + { + // Second transaction + auto wtx = GetValidReceive(sk, 50, true); + auto note = GetNote(sk, wtx, 0, 1); + auto nullifier = note.nullifier(sk); + + mapNoteData_t noteData; + JSOutPoint jsoutpt {wtx.GetTxid(), 0, 1}; + CNoteData nd {sk.address(), nullifier}; + noteData[jsoutpt] = nd; + wtx.SetNoteData(noteData); + wallet.AddToWallet(wtx, true, NULL); + + std::vector notes {jsoutpt}; + std::vector> witnesses; + uint256 anchor2; + + wallet.GetNoteWitnesses(notes, witnesses, anchor2); + EXPECT_FALSE((bool) witnesses[0]); + + // Second block + CBlock block2; + block2.hashPrevBlock = block1.GetHash(); + block2.vtx.push_back(wtx); + ZCIncrementalMerkleTree tree2 {tree}; + wallet.IncrementNoteWitnesses(NULL, &block2, tree2); + witnesses.clear(); + wallet.GetNoteWitnesses(notes, witnesses, anchor2); + EXPECT_TRUE((bool) witnesses[0]); + EXPECT_NE(anchor1, anchor2); + + // Decrementing should give us the previous anchor + uint256 anchor3; + wallet.DecrementNoteWitnesses(); + witnesses.clear(); + wallet.GetNoteWitnesses(notes, witnesses, anchor3); + EXPECT_FALSE((bool) witnesses[0]); + // Should not equal first anchor because none of these notes had witnesses + EXPECT_NE(anchor1, anchor3); + + // Re-incrementing with the same block should give the same result + uint256 anchor4; + wallet.IncrementNoteWitnesses(NULL, &block2, tree); + witnesses.clear(); + wallet.GetNoteWitnesses(notes, witnesses, anchor4); + EXPECT_TRUE((bool) witnesses[0]); + EXPECT_EQ(anchor2, anchor4); + } +} diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index 0f6eeff49..3407df53e 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -8,8 +8,8 @@ #include "base58.h" #include "checkpoints.h" #include "coincontrol.h" -#include "consensus/consensus.h" #include "consensus/validation.h" +#include "init.h" #include "main.h" #include "net.h" #include "script/script.h" @@ -17,6 +17,7 @@ #include "timedata.h" #include "util.h" #include "utilmoneystr.h" +#include "zcash/Note.hpp" #include @@ -57,6 +58,11 @@ struct CompareValueOnly } }; +std::string JSOutPoint::ToString() const +{ + return strprintf("JSOutPoint(%s, %d, %d)", hash.ToString().substr(0,10), js, n); +} + std::string COutput::ToString() const { return strprintf("COutput(%s, %d, %d) [%s]", tx->GetTxid().ToString(), i, nDepth, FormatMoney(tx->vout[i].nValue)); @@ -329,6 +335,16 @@ bool CWallet::ChangeWalletPassphrase(const SecureString& strOldWalletPassphrase, return false; } +void CWallet::ChainTip(const CBlockIndex *pindex, const CBlock *pblock, + ZCIncrementalMerkleTree tree, bool added) +{ + if (added) { + IncrementNoteWitnesses(pindex, pblock, tree); + } else { + DecrementNoteWitnesses(); + } +} + void CWallet::SetBestChain(const CBlockLocator& loc) { CWalletDB walletdb(strWalletFile); @@ -394,6 +410,20 @@ set CWallet::GetConflicts(const uint256& txid) const for (TxSpends::const_iterator it = range.first; it != range.second; ++it) result.insert(it->second); } + + std::pair range_n; + + for (const JSDescription& jsdesc : wtx.vjoinsplit) { + for (const uint256& nullifier : jsdesc.nullifiers) { + if (mapTxNullifiers.count(nullifier) <= 1) { + continue; // No conflict if zero or one spends + } + range_n = mapTxNullifiers.equal_range(nullifier); + for (TxNullifiers::const_iterator it = range_n.first; it != range_n.second; ++it) { + result.insert(it->second); + } + } + } return result; } @@ -449,7 +479,8 @@ bool CWallet::Verify(const string& walletFile, string& warningString, string& er return true; } -void CWallet::SyncMetaData(pair range) +template +void CWallet::SyncMetaData(pair::iterator, typename TxSpendMap::iterator> range) { // We want all the wallet transactions in range to have the same metadata as // the oldest (smallest nOrderPos). @@ -457,7 +488,7 @@ void CWallet::SyncMetaData(pair range) int nMinOrderPos = std::numeric_limits::max(); const CWalletTx* copyFrom = NULL; - for (TxSpends::iterator it = range.first; it != range.second; ++it) + for (typename TxSpendMap::iterator it = range.first; it != range.second; ++it) { const uint256& hash = it->second; int n = mapWallet[hash].nOrderPos; @@ -468,12 +499,14 @@ void CWallet::SyncMetaData(pair range) } } // Now copy data from copyFrom to rest: - for (TxSpends::iterator it = range.first; it != range.second; ++it) + for (typename TxSpendMap::iterator it = range.first; it != range.second; ++it) { const uint256& hash = it->second; CWalletTx* copyTo = &mapWallet[hash]; if (copyFrom == copyTo) continue; copyTo->mapValue = copyFrom->mapValue; + // mapNoteData not copied on purpose + // (it is always set correctly for each CWalletTx) copyTo->vOrderForm = copyFrom->vOrderForm; // fTimeReceivedIsTxTime not copied on purpose // nTimeReceived not copied on purpose @@ -505,15 +538,42 @@ bool CWallet::IsSpent(const uint256& hash, unsigned int n) const return false; } +/** + * Note is spent if any non-conflicted transaction + * spends it: + */ +bool CWallet::IsSpent(const uint256& nullifier) const +{ + pair range; + range = mapTxNullifiers.equal_range(nullifier); + + for (TxNullifiers::const_iterator it = range.first; it != range.second; ++it) { + const uint256& wtxid = it->second; + std::map::const_iterator mit = mapWallet.find(wtxid); + if (mit != mapWallet.end() && mit->second.GetDepthInMainChain() >= 0) { + return true; // Spent + } + } + return false; +} + void CWallet::AddToSpends(const COutPoint& outpoint, const uint256& wtxid) { mapTxSpends.insert(make_pair(outpoint, wtxid)); pair range; range = mapTxSpends.equal_range(outpoint); - SyncMetaData(range); + SyncMetaData(range); } +void CWallet::AddToSpends(const uint256& nullifier, const uint256& wtxid) +{ + mapTxNullifiers.insert(make_pair(nullifier, wtxid)); + + pair range; + range = mapTxNullifiers.equal_range(nullifier); + SyncMetaData(range); +} void CWallet::AddToSpends(const uint256& wtxid) { @@ -522,8 +582,102 @@ void CWallet::AddToSpends(const uint256& wtxid) if (thisTx.IsCoinBase()) // Coinbases don't spend anything! return; - BOOST_FOREACH(const CTxIn& txin, thisTx.vin) + for (const CTxIn& txin : thisTx.vin) { AddToSpends(txin.prevout, wtxid); + } + for (const JSDescription& jsdesc : thisTx.vjoinsplit) { + for (const uint256& nullifier : jsdesc.nullifiers) { + AddToSpends(nullifier, wtxid); + } + } +} + +void CWallet::IncrementNoteWitnesses(const CBlockIndex* pindex, + const CBlock* pblockIn, + ZCIncrementalMerkleTree tree) +{ + { + LOCK(cs_wallet); + for (std::pair& wtxItem : mapWallet) { + for (mapNoteData_t::value_type& item : wtxItem.second.mapNoteData) { + CNoteData* nd = &(item.second); + // Check the validity of the cache + assert(nWitnessCacheSize >= nd->witnesses.size()); + // Copy the witness for the previous block if we have one + if (nd->witnesses.size() > 0) { + nd->witnesses.push_front(nd->witnesses.front()); + } + if (nd->witnesses.size() > WITNESS_CACHE_SIZE) { + nd->witnesses.pop_back(); + } + } + } + + const CBlock* pblock {pblockIn}; + CBlock block; + if (!pblock) { + ReadBlockFromDisk(block, pindex); + pblock = █ + } + + for (const CTransaction& tx : pblock->vtx) { + auto hash = tx.GetTxid(); + bool txIsOurs = mapWallet.count(hash); + for (size_t i = 0; i < tx.vjoinsplit.size(); i++) { + const JSDescription& jsdesc = tx.vjoinsplit[i]; + for (uint8_t j = 0; j < jsdesc.commitments.size(); j++) { + const uint256& note_commitment = jsdesc.commitments[j]; + tree.append(note_commitment); + + // Increment existing witnesses + for (std::pair& wtxItem : mapWallet) { + for (mapNoteData_t::value_type& item : wtxItem.second.mapNoteData) { + CNoteData* nd = &(item.second); + if (nd->witnesses.size() > 0) { + nd->witnesses.front().append(note_commitment); + } + } + } + + // If this is our note, witness it + if (txIsOurs) { + JSOutPoint jsoutpt {hash, i, j}; + if (mapWallet[hash].mapNoteData.count(jsoutpt)) { + mapWallet[hash].mapNoteData[jsoutpt].witnesses.push_front( + tree.witness()); + } + } + } + } + } + if (nWitnessCacheSize < WITNESS_CACHE_SIZE) { + nWitnessCacheSize += 1; + } + if (fFileBacked) { + CWalletDB(strWalletFile).WriteWitnessCacheSize(nWitnessCacheSize); + } + } +} + +void CWallet::DecrementNoteWitnesses() +{ + { + LOCK(cs_wallet); + for (std::pair& wtxItem : mapWallet) { + for (mapNoteData_t::value_type& item : wtxItem.second.mapNoteData) { + CNoteData* nd = &(item.second); + if (nd->witnesses.size() > 0) { + nd->witnesses.pop_front(); + } + } + } + nWitnessCacheSize -= 1; + // TODO: If nWitnessCache is zero, we need to regenerate the caches (#1302) + assert(nWitnessCacheSize > 0); + if (fFileBacked) { + CWalletDB(strWalletFile).WriteWitnessCacheSize(nWitnessCacheSize); + } + } } bool CWallet::EncryptWallet(const SecureString& strWalletPassphrase) @@ -665,6 +819,16 @@ void CWallet::MarkDirty() } } +void CWallet::UpdateNullifierNoteMap(const CWalletTx& wtx) +{ + { + LOCK(cs_wallet); + for (const mapNoteData_t::value_type& item : wtx.mapNoteData) { + mapNullifiersToNotes[item.second.nullifier] = item.first; + } + } +} + bool CWallet::AddToWallet(const CWalletTx& wtxIn, bool fFromLoadWallet, CWalletDB* pwalletdb) { uint256 hash = wtxIn.GetTxid(); @@ -673,6 +837,7 @@ bool CWallet::AddToWallet(const CWalletTx& wtxIn, bool fFromLoadWallet, CWalletD { mapWallet[hash] = wtxIn; mapWallet[hash].BindWallet(this); + UpdateNullifierNoteMap(mapWallet[hash]); AddToSpends(hash); } else @@ -682,6 +847,7 @@ bool CWallet::AddToWallet(const CWalletTx& wtxIn, bool fFromLoadWallet, CWalletD pair::iterator, bool> ret = mapWallet.insert(make_pair(hash, wtxIn)); CWalletTx& wtx = (*ret.first).second; wtx.BindWallet(this); + UpdateNullifierNoteMap(wtx); bool fInsertedNew = ret.second; if (fInsertedNew) { @@ -751,6 +917,20 @@ bool CWallet::AddToWallet(const CWalletTx& wtxIn, bool fFromLoadWallet, CWalletD wtx.nIndex = wtxIn.nIndex; fUpdated = true; } + if (!wtxIn.mapNoteData.empty() && wtxIn.mapNoteData != wtx.mapNoteData) + { + auto tmp = wtxIn.mapNoteData; + // Ensure we keep any cached witnesses we may already have + for (const std::pair nd : wtx.mapNoteData) { + if (tmp.count(nd.first) && nd.second.witnesses.size() > 0) { + tmp.at(nd.first).witnesses.assign( + nd.second.witnesses.cbegin(), nd.second.witnesses.cend()); + } + } + // Now copy over the updated note data + wtx.mapNoteData = tmp; + fUpdated = true; + } if (wtxIn.fFromMe && wtxIn.fFromMe != wtx.fFromMe) { wtx.fFromMe = wtxIn.fFromMe; @@ -796,10 +976,15 @@ bool CWallet::AddToWalletIfInvolvingMe(const CTransaction& tx, const CBlock* pbl AssertLockHeld(cs_wallet); bool fExisted = mapWallet.count(tx.GetTxid()) != 0; if (fExisted && !fUpdate) return false; - if (fExisted || IsMine(tx) || IsFromMe(tx)) + auto noteData = FindMyNotes(tx); + if (fExisted || IsMine(tx) || IsFromMe(tx) || noteData.size() > 0) { CWalletTx wtx(this,tx); + if (noteData.size() > 0) { + wtx.SetNoteData(noteData); + } + // Get merkle branch if transaction was found in a block if (pblock) wtx.SetMerkleBranch(*pblock); @@ -828,6 +1013,14 @@ void CWallet::SyncTransaction(const CTransaction& tx, const CBlock* pblock) if (mapWallet.count(txin.prevout.hash)) mapWallet[txin.prevout.hash].MarkDirty(); } + for (const JSDescription& jsdesc : tx.vjoinsplit) { + for (const uint256& nullifier : jsdesc.nullifiers) { + if (mapNullifiersToNotes.count(nullifier) && + mapWallet.count(mapNullifiersToNotes[nullifier].hash)) { + mapWallet[mapNullifiersToNotes[nullifier].hash].MarkDirty(); + } + } + } } void CWallet::EraseFromWallet(const uint256 &hash) @@ -843,6 +1036,86 @@ void CWallet::EraseFromWallet(const uint256 &hash) } +mapNoteData_t CWallet::FindMyNotes(const CTransaction& tx) const +{ + LOCK(cs_SpendingKeyStore); + uint256 hash = tx.GetTxid(); + + mapNoteData_t noteData; + libzcash::SpendingKey key; + for (size_t i = 0; i < tx.vjoinsplit.size(); i++) { + auto hSig = tx.vjoinsplit[i].h_sig(*pzcashParams, tx.joinSplitPubKey); + for (uint8_t j = 0; j < tx.vjoinsplit[i].ciphertexts.size(); j++) { + for (const NoteDecryptorMap::value_type& item : mapNoteDecryptors) { + try { + auto note_pt = libzcash::NotePlaintext::decrypt( + item.second, + tx.vjoinsplit[i].ciphertexts[j], + tx.vjoinsplit[i].ephemeralKey, + hSig, + (unsigned char) j); + auto address = item.first; + // Decryptors are only cached when SpendingKeys are added + assert(GetSpendingKey(address, key)); + auto note = note_pt.note(address); + JSOutPoint jsoutpt {hash, i, j}; + CNoteData nd {address, note.nullifier(key)}; + noteData.insert(std::make_pair(jsoutpt, nd)); + break; + } catch (const std::runtime_error &) { + // Couldn't decrypt with this spending key + } catch (const std::exception &exc) { + // Unexpected failure + LogPrintf("FindMyNotes(): Unexpected error while testing decrypt:\n"); + LogPrintf("%s\n", exc.what()); + } + } + } + } + return noteData; +} + +bool CWallet::IsFromMe(const uint256& nullifier) const +{ + { + LOCK(cs_wallet); + if (mapNullifiersToNotes.count(nullifier) && + mapWallet.count(mapNullifiersToNotes.at(nullifier).hash)) { + return true; + } + } + return false; +} + +void CWallet::GetNoteWitnesses(std::vector notes, + std::vector>& witnesses, + uint256 &final_anchor) +{ + { + LOCK(cs_wallet); + witnesses.resize(notes.size()); + boost::optional rt; + int i = 0; + for (JSOutPoint note : notes) { + if (mapWallet.count(note.hash) && + mapWallet[note.hash].mapNoteData.count(note) && + mapWallet[note.hash].mapNoteData[note].witnesses.size() > 0) { + witnesses[i] = mapWallet[note.hash].mapNoteData[note].witnesses.front(); + if (!rt) { + rt = witnesses[i]->root(); + } else { + assert(*rt == witnesses[i]->root()); + } + } + i++; + } + // All returned witnesses have the same anchor + if (rt) { + final_anchor = *rt; + } + } +} + isminetype CWallet::IsMine(const CTxIn &txin) const { { @@ -925,7 +1198,17 @@ bool CWallet::IsMine(const CTransaction& tx) const bool CWallet::IsFromMe(const CTransaction& tx) const { - return (GetDebit(tx, ISMINE_ALL) > 0); + if (GetDebit(tx, ISMINE_ALL) > 0) { + return true; + } + for (const JSDescription& jsdesc : tx.vjoinsplit) { + for (const uint256& nullifier : jsdesc.nullifiers) { + if (IsFromMe(nullifier)) { + return true; + } + } + } + return false; } CAmount CWallet::GetDebit(const CTransaction& tx, const isminefilter& filter) const @@ -964,6 +1247,22 @@ CAmount CWallet::GetChange(const CTransaction& tx) const return nChange; } +void CWalletTx::SetNoteData(mapNoteData_t ¬eData) +{ + mapNoteData.clear(); + for (const std::pair nd : noteData) { + if (nd.first.js < vjoinsplit.size() && + nd.first.n < vjoinsplit[nd.first.js].ciphertexts.size()) { + // Store the address and nullifier for the Note + mapNoteData[nd.first] = nd.second; + } else { + // If FindMyNotes() was used to obtain noteData, + // this should never happen + throw std::logic_error("CWalletTx::SetNoteData(): Invalid note"); + } + } +} + int64_t CWalletTx::GetTxTime() const { int64_t n = nTimeSmart; diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index 0651c6404..b736e6906 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -7,6 +7,8 @@ #define BITCOIN_WALLET_WALLET_H #include "amount.h" +#include "coins.h" +#include "consensus/consensus.h" #include "key.h" #include "keystore.h" #include "primitives/block.h" @@ -52,6 +54,10 @@ static const unsigned int DEFAULT_TX_CONFIRM_TARGET = 2; static const CAmount nHighTransactionMaxFeeWarning = 100 * nHighTransactionFeeWarning; //! Largest (in bytes) free transaction we're willing to create static const unsigned int MAX_FREE_TRANSACTION_CREATE_SIZE = 1000; +//! Size of witness cache +// Should be large enough that we can expect not to reorg beyond our cache +// unless there is some exceptional network disruption. +static const unsigned int WITNESS_CACHE_SIZE = COINBASE_MATURITY; class CAccountingEntry; class CBlockIndex; @@ -146,6 +152,95 @@ struct COutputEntry int vout; }; +/** An note outpoint */ +class JSOutPoint +{ +public: + // Transaction hash + uint256 hash; + // Index into CTransaction.vjoinsplit + size_t js; + // Index into JSDescription fields of length ZC_NUM_JS_OUTPUTS + uint8_t n; + + JSOutPoint() { SetNull(); } + JSOutPoint(uint256 h, size_t js, uint8_t n) : hash {h}, js {js}, n {n} { } + + ADD_SERIALIZE_METHODS; + + template + inline void SerializationOp(Stream& s, Operation ser_action, int nType, int nVersion) { + READWRITE(hash); + READWRITE(js); + READWRITE(n); + } + + void SetNull() { hash.SetNull(); } + bool IsNull() const { return hash.IsNull(); } + + friend bool operator<(const JSOutPoint& a, const JSOutPoint& b) { + return (a.hash < b.hash || + (a.hash == b.hash && a.js < b.js) || + (a.hash == b.hash && a.js == b.js && a.n < b.n)); + } + + friend bool operator==(const JSOutPoint& a, const JSOutPoint& b) { + return (a.hash == b.hash && a.js == b.js && a.n == b.n); + } + + friend bool operator!=(const JSOutPoint& a, const JSOutPoint& b) { + return !(a == b); + } + + std::string ToString() const; +}; + +class CNoteData +{ +public: + libzcash::PaymentAddress address; + + // It's okay to cache the nullifier in the wallet, because we are storing + // the spending key there too, which could be used to derive this. + // If PR #1210 is merged, we need to revisit the threat model and decide + // whether it is okay to store this unencrypted while the spending key is + // encrypted. + uint256 nullifier; + + /** + * Cached incremental witnesses for spendable Notes. + * Beginning of the list is the most recent witness. + */ + std::list witnesses; + + CNoteData() : address(), nullifier() { } + CNoteData(libzcash::PaymentAddress a, uint256 n) : address {a}, nullifier {n} { } + + ADD_SERIALIZE_METHODS; + + template + inline void SerializationOp(Stream& s, Operation ser_action, int nType, int nVersion) { + READWRITE(address); + READWRITE(nullifier); + READWRITE(witnesses); + } + + friend bool operator<(const CNoteData& a, const CNoteData& b) { + return (a.address < b.address || + (a.address == b.address && a.nullifier < b.nullifier)); + } + + friend bool operator==(const CNoteData& a, const CNoteData& b) { + return (a.address == b.address && a.nullifier == b.nullifier); + } + + friend bool operator!=(const CNoteData& a, const CNoteData& b) { + return !(a == b); + } +}; + +typedef std::map mapNoteData_t; + /** A transaction with a merkle branch linking it to the block chain. */ class CMerkleTx : public CTransaction { @@ -216,6 +311,7 @@ private: public: mapValue_t mapValue; + mapNoteData_t mapNoteData; std::vector > vOrderForm; unsigned int fTimeReceivedIsTxTime; unsigned int nTimeReceived; //! time received by this node @@ -268,6 +364,7 @@ public: { pwallet = pwalletIn; mapValue.clear(); + mapNoteData.clear(); vOrderForm.clear(); fTimeReceivedIsTxTime = false; nTimeReceived = 0; @@ -317,6 +414,7 @@ public: std::vector vUnused; //! Used to be vtxPrev READWRITE(vUnused); READWRITE(mapValue); + READWRITE(mapNoteData); READWRITE(vOrderForm); READWRITE(fTimeReceivedIsTxTime); READWRITE(nTimeReceived); @@ -358,6 +456,8 @@ public: MarkDirty(); } + void SetNoteData(mapNoteData_t ¬eData); + //! filter decides which addresses will count towards the debit CAmount GetDebit(const isminefilter& filter) const; CAmount GetCredit(const isminefilter& filter) const; @@ -461,17 +561,43 @@ private: int64_t nLastResend; bool fBroadcastTransactions; + template + using TxSpendMap = std::multimap; /** * Used to keep track of spent outpoints, and * detect and report conflicts (double-spends or * mutated transactions where the mutant gets mined). */ - typedef std::multimap TxSpends; + typedef TxSpendMap TxSpends; TxSpends mapTxSpends; + /** + * Used to keep track of spent Notes, and + * detect and report conflicts (double-spends). + */ + typedef TxSpendMap TxNullifiers; + TxNullifiers mapTxNullifiers; + void AddToSpends(const COutPoint& outpoint, const uint256& wtxid); + void AddToSpends(const uint256& nullifier, const uint256& wtxid); void AddToSpends(const uint256& wtxid); - void SyncMetaData(std::pair); +public: + /* + * Size of the incremental witness cache for the notes in our wallet. + * This will always be greater than or equal to the size of the largest + * incremental witness cache in any transaction in mapWallet. + */ + int64_t nWitnessCacheSize; + +protected: + void IncrementNoteWitnesses(const CBlockIndex* pindex, + const CBlock* pblock, + ZCIncrementalMerkleTree tree); + void DecrementNoteWitnesses(); + +private: + template + void SyncMetaData(std::pair::iterator, typename TxSpendMap::iterator>); public: /* @@ -525,8 +651,10 @@ public: nLastResend = 0; nTimeFirstKey = 0; fBroadcastTransactions = false; + nWitnessCacheSize = 0; } + std::map mapNullifiersToNotes; std::map mapWallet; int64_t nOrderPosNext; @@ -549,6 +677,7 @@ public: bool SelectCoinsMinConf(const CAmount& nTargetValue, int nConfMine, int nConfTheirs, std::vector vCoins, std::set >& setCoinsRet, CAmount& nValueRet) const; bool IsSpent(const uint256& hash, unsigned int n) const; + bool IsSpent(const uint256& nullifier) const; bool IsLockedCoin(uint256 hash, unsigned int n) const; void LockCoin(COutPoint& output); @@ -628,6 +757,7 @@ public: TxItems OrderedTxItems(std::list& acentries, std::string strAccount = ""); void MarkDirty(); + void UpdateNullifierNoteMap(const CWalletTx& wtx); bool AddToWallet(const CWalletTx& wtxIn, bool fFromLoadWallet, CWalletDB* pwalletdb); void SyncTransaction(const CTransaction& tx, const CBlock* pblock); bool AddToWalletIfInvolvingMe(const CTransaction& tx, const CBlock* pblock, bool fUpdate); @@ -667,6 +797,13 @@ public: std::set GetAccountAddresses(std::string strAccount) const; + mapNoteData_t FindMyNotes(const CTransaction& tx) const; + bool IsFromMe(const uint256& nullifier) const; + void GetNoteWitnesses( + std::vector notes, + std::vector>& witnesses, + uint256 &final_anchor); + isminetype IsMine(const CTxIn& txin) const; CAmount GetDebit(const CTxIn& txin, const isminefilter& filter) const; isminetype IsMine(const CTxOut& txout) const; @@ -679,6 +816,7 @@ public: CAmount GetDebit(const CTransaction& tx, const isminefilter& filter) const; CAmount GetCredit(const CTransaction& tx, const isminefilter& filter) const; CAmount GetChange(const CTransaction& tx) const; + void ChainTip(const CBlockIndex *pindex, const CBlock *pblock, ZCIncrementalMerkleTree tree, bool added); void SetBestChain(const CBlockLocator& loc); DBErrors LoadWallet(bool& fFirstRunRet); diff --git a/src/wallet/walletdb.cpp b/src/wallet/walletdb.cpp index ab1d364e4..e72ee5e3e 100644 --- a/src/wallet/walletdb.cpp +++ b/src/wallet/walletdb.cpp @@ -162,6 +162,12 @@ bool CWalletDB::WriteDefaultKey(const CPubKey& vchPubKey) return Write(std::string("defaultkey"), vchPubKey); } +bool CWalletDB::WriteWitnessCacheSize(int64_t nWitnessCacheSize) +{ + nWalletDBUpdated++; + return Write(std::string("witnesscachesize"), nWitnessCacheSize); +} + bool CWalletDB::ReadPool(int64_t nPool, CKeyPool& keypool) { return Read(std::make_pair(std::string("pool"), nPool), keypool); @@ -631,6 +637,10 @@ ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, return false; } } + else if (strType == "witnesscachesize") + { + ssValue >> pwallet->nWitnessCacheSize; + } } catch (...) { return false; diff --git a/src/wallet/walletdb.h b/src/wallet/walletdb.h index d261c9644..317f71304 100644 --- a/src/wallet/walletdb.h +++ b/src/wallet/walletdb.h @@ -106,6 +106,8 @@ public: bool WriteDefaultKey(const CPubKey& vchPubKey); + bool WriteWitnessCacheSize(int64_t nWitnessCacheSize); + bool ReadPool(int64_t nPool, CKeyPool& keypool); bool WritePool(int64_t nPool, const CKeyPool& keypool); bool ErasePool(int64_t nPool); diff --git a/src/zcash/Address.hpp b/src/zcash/Address.hpp index 36b9402a3..58caae772 100644 --- a/src/zcash/Address.hpp +++ b/src/zcash/Address.hpp @@ -23,7 +23,13 @@ public: READWRITE(pk_enc); } - friend inline bool operator<(const PaymentAddress& a, const PaymentAddress& b) { return a.a_pk < b.a_pk; } + friend inline bool operator==(const PaymentAddress& a, const PaymentAddress& b) { + return a.a_pk == b.a_pk && a.pk_enc == b.pk_enc; + } + friend inline bool operator<(const PaymentAddress& a, const PaymentAddress& b) { + return (a.a_pk < b.a_pk || + (a.a_pk == b.a_pk && a.pk_enc < b.pk_enc)); + } }; class ViewingKey : public uint256 { diff --git a/src/zcash/IncrementalMerkleTree.hpp b/src/zcash/IncrementalMerkleTree.hpp index ab4913922..1d023168a 100644 --- a/src/zcash/IncrementalMerkleTree.hpp +++ b/src/zcash/IncrementalMerkleTree.hpp @@ -43,10 +43,19 @@ public: Hash empty_root(size_t depth) { return empty_roots.at(depth); } + template + friend bool operator==(const EmptyMerkleRoots& a, + const EmptyMerkleRoots& b); private: boost::array empty_roots; }; +template +bool operator==(const EmptyMerkleRoots& a, + const EmptyMerkleRoots& b) { + return a.empty_roots == b.empty_roots; +} + template class IncrementalWitness; @@ -90,6 +99,10 @@ public: return emptyroots.empty_root(Depth); } + template + friend bool operator==(const IncrementalMerkleTree& a, + const IncrementalMerkleTree& b); + private: static EmptyMerkleRoots emptyroots; boost::optional left; @@ -104,11 +117,23 @@ private: void wfcheck() const; }; +template +bool operator==(const IncrementalMerkleTree& a, + const IncrementalMerkleTree& b) { + return (a.emptyroots == b.emptyroots && + a.left == b.left && + a.right == b.right && + a.parents == b.parents); +} + template class IncrementalWitness { friend class IncrementalMerkleTree; public: + // Required for Unserialize() + IncrementalWitness() {} + MerklePath path() const { return tree.path(partial_path()); } @@ -130,6 +155,10 @@ public: cursor_depth = tree.next_depth(filled.size()); } + template + friend bool operator==(const IncrementalWitness& a, + const IncrementalWitness& b); + private: IncrementalMerkleTree tree; std::vector filled; @@ -139,6 +168,15 @@ private: IncrementalWitness(IncrementalMerkleTree tree) : tree(tree) {} }; +template +bool operator==(const IncrementalWitness& a, + const IncrementalWitness& b) { + return (a.tree == b.tree && + a.filled == b.filled && + a.cursor == b.cursor && + a.cursor_depth == b.cursor_depth); +} + class SHA256Compress : public uint256 { public: SHA256Compress() : uint256() {} diff --git a/src/zcash/JoinSplit.cpp b/src/zcash/JoinSplit.cpp index b103581b3..71c1ae0ad 100644 --- a/src/zcash/JoinSplit.cpp +++ b/src/zcash/JoinSplit.cpp @@ -173,9 +173,10 @@ public: boost::array& out_commitments, uint64_t vpub_old, uint64_t vpub_new, - const uint256& rt + const uint256& rt, + bool computeProof ) { - if (!pk) { + if (computeProof && !pk) { throw std::runtime_error("JoinSplit proving key not loaded"); } @@ -231,6 +232,10 @@ public: out_macs[i] = PRF_pk(inputs[i].key, i, h_sig); } + if (!computeProof) { + return ZCProof(); + } + protoboard pb; { joinsplit_gadget g(pb); diff --git a/src/zcash/JoinSplit.hpp b/src/zcash/JoinSplit.hpp index e9e89c62d..1b655728d 100644 --- a/src/zcash/JoinSplit.hpp +++ b/src/zcash/JoinSplit.hpp @@ -73,7 +73,8 @@ public: boost::array& out_commitments, uint64_t vpub_old, uint64_t vpub_new, - const uint256& rt + const uint256& rt, + bool computeProof = true ) = 0; virtual bool verify( diff --git a/src/zcash/NoteEncryption.hpp b/src/zcash/NoteEncryption.hpp index 7161d5a20..e1f3718b0 100644 --- a/src/zcash/NoteEncryption.hpp +++ b/src/zcash/NoteEncryption.hpp @@ -61,6 +61,7 @@ public: typedef boost::array Ciphertext; typedef boost::array Plaintext; + NoteDecryption() { } NoteDecryption(uint256 sk_enc); Plaintext decrypt(const Ciphertext &ciphertext, @@ -68,6 +69,14 @@ public: const uint256 &hSig, unsigned char nonce ) const; + + friend inline bool operator==(const NoteDecryption& a, const NoteDecryption& b) { + return a.sk_enc == b.sk_enc && a.pk_enc == b.pk_enc; + } + friend inline bool operator<(const NoteDecryption& a, const NoteDecryption& b) { + return (a.sk_enc < b.sk_enc || + (a.sk_enc == b.sk_enc && a.pk_enc < b.pk_enc)); + } }; uint256 random_uint256();