diff --git a/src/wallet/gtest/test_wallet.cpp b/src/wallet/gtest/test_wallet.cpp index 8db093e7..56ec6337 100644 --- a/src/wallet/gtest/test_wallet.cpp +++ b/src/wallet/gtest/test_wallet.cpp @@ -3,6 +3,7 @@ #include "base58.h" #include "chainparams.h" +#include "main.h" #include "random.h" #include "wallet/wallet.h" #include "zcash/JoinSplit.hpp" @@ -78,6 +79,55 @@ libzcash::Note GetNote(const libzcash::SpendingKey& sk, return note_pt.note(sk.address()); } +CWalletTx GetValidSpend(const libzcash::SpendingKey& sk, + const libzcash::Note& note, CAmount value) { + CMutableTransaction mtx; + mtx.vout.resize(2); + mtx.vout[0].nValue = value; + mtx.vout[1].nValue = 0; + + // Generate an ephemeral keypair. + uint256 joinSplitPubKey; + unsigned char joinSplitPrivKey[crypto_sign_SECRETKEYBYTES]; + crypto_sign_keypair(joinSplitPubKey.begin(), joinSplitPrivKey); + mtx.joinSplitPubKey = joinSplitPubKey; + + // Fake tree for the unused witness + ZCIncrementalMerkleTree tree; + + boost::array inputs = { + libzcash::JSInput(tree.witness(), note, sk), + libzcash::JSInput() // dummy input + }; + + boost::array outputs = { + libzcash::JSOutput(), // dummy output + libzcash::JSOutput() // dummy output + }; + + boost::array output_notes; + + // Prepare JoinSplits + uint256 rt; + JSDescription jsdesc {*params, mtx.joinSplitPubKey, rt, + inputs, outputs, 0, value, false}; + mtx.vjoinsplit.push_back(jsdesc); + + // Empty output script. + CScript scriptCode; + CTransaction signTx(mtx); + uint256 dataToBeSigned = SignatureHash(scriptCode, signTx, NOT_AN_INPUT, SIGHASH_ALL); + + // Add the signature + assert(crypto_sign_detached(&mtx.joinSplitSig[0], NULL, + dataToBeSigned.begin(), 32, + joinSplitPrivKey + ) == 0); + CTransaction tx {mtx}; + CWalletTx wtx {NULL, tx}; + return wtx; +} + TEST(wallet_tests, set_note_addrs_in_cwallettx) { auto sk = libzcash::SpendingKey::random(); auto wtx = GetValidReceive(sk, 10, true); @@ -125,3 +175,74 @@ TEST(wallet_tests, find_note_in_tx) { EXPECT_EQ(1, noteMap.count(jsoutpt)); EXPECT_EQ(nd, noteMap[jsoutpt]); } + +TEST(wallet_tests, get_conflicted_notes) { + CWallet wallet; + + auto sk = libzcash::SpendingKey::random(); + wallet.AddSpendingKey(sk); + + auto wtx = GetValidReceive(sk, 10, true); + auto note = GetNote(sk, wtx, 0, 1); + auto nullifier = note.nullifier(sk); + + auto wtx2 = GetValidSpend(sk, note, 5); + auto wtx3 = GetValidSpend(sk, note, 10); + auto hash2 = wtx2.GetTxid(); + auto hash3 = wtx3.GetTxid(); + + // No conflicts for no spends + EXPECT_EQ(0, wallet.GetConflicts(hash2).size()); + wallet.AddToWallet(wtx, true, NULL); + EXPECT_EQ(0, wallet.GetConflicts(hash2).size()); + + // No conflicts for one spend + wallet.AddToWallet(wtx2, true, NULL); + EXPECT_EQ(0, wallet.GetConflicts(hash2).size()); + + // Conflicts for two spends + wallet.AddToWallet(wtx3, true, NULL); + auto c3 = wallet.GetConflicts(hash2); + EXPECT_EQ(2, c3.size()); + EXPECT_EQ(std::set({hash2, hash3}), c3); +} + +TEST(wallet_tests, nullifier_is_spent) { + CWallet wallet; + + auto sk = libzcash::SpendingKey::random(); + wallet.AddSpendingKey(sk); + + auto wtx = GetValidReceive(sk, 10, true); + auto note = GetNote(sk, wtx, 0, 1); + auto nullifier = note.nullifier(sk); + + EXPECT_FALSE(wallet.IsSpent(nullifier)); + + wallet.AddToWallet(wtx, true, NULL); + EXPECT_FALSE(wallet.IsSpent(nullifier)); + + auto wtx2 = GetValidSpend(sk, note, 5); + wallet.AddToWallet(wtx2, true, NULL); + EXPECT_FALSE(wallet.IsSpent(nullifier)); + + // Fake-mine the transaction + EXPECT_EQ(-1, chainActive.Height()); + CBlock block; + block.vtx.push_back(wtx2); + block.hashMerkleRoot = block.BuildMerkleTree(); + auto blockHash = block.GetHash(); + CBlockIndex fakeIndex {block}; + mapBlockIndex.insert(std::make_pair(blockHash, &fakeIndex)); + chainActive.SetTip(&fakeIndex); + EXPECT_TRUE(chainActive.Contains(&fakeIndex)); + EXPECT_EQ(0, chainActive.Height()); + + wtx2.SetMerkleBranch(block); + wallet.AddToWallet(wtx2, true, NULL); + EXPECT_TRUE(wallet.IsSpent(nullifier)); + + // Tear down + chainActive.SetTip(NULL); + mapBlockIndex.erase(blockHash); +} diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index 0d287dac..7da567ee 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -401,6 +401,20 @@ set CWallet::GetConflicts(const uint256& txid) const for (TxSpends::const_iterator it = range.first; it != range.second; ++it) result.insert(it->second); } + + std::pair range_n; + + for (const JSDescription& jsdesc : wtx.vjoinsplit) { + for (const uint256& nullifier : jsdesc.nullifiers) { + if (mapTxNullifiers.count(nullifier) <= 1) { + continue; // No conflict if zero or one spends + } + range_n = mapTxNullifiers.equal_range(nullifier); + for (TxNullifiers::const_iterator it = range_n.first; it != range_n.second; ++it) { + result.insert(it->second); + } + } + } return result; } @@ -456,7 +470,8 @@ bool CWallet::Verify(const string& walletFile, string& warningString, string& er return true; } -void CWallet::SyncMetaData(pair range) +template +void CWallet::SyncMetaData(pair::iterator, typename TxSpendMap::iterator> range) { // We want all the wallet transactions in range to have the same metadata as // the oldest (smallest nOrderPos). @@ -464,7 +479,7 @@ void CWallet::SyncMetaData(pair range) int nMinOrderPos = std::numeric_limits::max(); const CWalletTx* copyFrom = NULL; - for (TxSpends::iterator it = range.first; it != range.second; ++it) + for (typename TxSpendMap::iterator it = range.first; it != range.second; ++it) { const uint256& hash = it->second; int n = mapWallet[hash].nOrderPos; @@ -475,7 +490,7 @@ void CWallet::SyncMetaData(pair range) } } // Now copy data from copyFrom to rest: - for (TxSpends::iterator it = range.first; it != range.second; ++it) + for (typename TxSpendMap::iterator it = range.first; it != range.second; ++it) { const uint256& hash = it->second; CWalletTx* copyTo = &mapWallet[hash]; @@ -514,15 +529,42 @@ bool CWallet::IsSpent(const uint256& hash, unsigned int n) const return false; } +/** + * Note is spent if any non-conflicted transaction + * spends it: + */ +bool CWallet::IsSpent(const uint256& nullifier) const +{ + pair range; + range = mapTxNullifiers.equal_range(nullifier); + + for (TxNullifiers::const_iterator it = range.first; it != range.second; ++it) { + const uint256& wtxid = it->second; + std::map::const_iterator mit = mapWallet.find(wtxid); + if (mit != mapWallet.end() && mit->second.GetDepthInMainChain() >= 0) { + return true; // Spent + } + } + return false; +} + void CWallet::AddToSpends(const COutPoint& outpoint, const uint256& wtxid) { mapTxSpends.insert(make_pair(outpoint, wtxid)); pair range; range = mapTxSpends.equal_range(outpoint); - SyncMetaData(range); + SyncMetaData(range); } +void CWallet::AddToSpends(const uint256& nullifier, const uint256& wtxid) +{ + mapTxNullifiers.insert(make_pair(nullifier, wtxid)); + + pair range; + range = mapTxNullifiers.equal_range(nullifier); + SyncMetaData(range); +} void CWallet::AddToSpends(const uint256& wtxid) { @@ -531,8 +573,14 @@ void CWallet::AddToSpends(const uint256& wtxid) if (thisTx.IsCoinBase()) // Coinbases don't spend anything! return; - BOOST_FOREACH(const CTxIn& txin, thisTx.vin) + for (const CTxIn& txin : thisTx.vin) { AddToSpends(txin.prevout, wtxid); + } + for (const JSDescription& jsdesc : thisTx.vjoinsplit) { + for (const uint256& nullifier : jsdesc.nullifiers) { + AddToSpends(nullifier, wtxid); + } + } } bool CWallet::EncryptWallet(const SecureString& strWalletPassphrase) diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index 6b118b9c..5c94655e 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -548,17 +548,28 @@ private: int64_t nLastResend; bool fBroadcastTransactions; + template + using TxSpendMap = std::multimap; /** * Used to keep track of spent outpoints, and * detect and report conflicts (double-spends or * mutated transactions where the mutant gets mined). */ - typedef std::multimap TxSpends; + typedef TxSpendMap TxSpends; TxSpends mapTxSpends; + /** + * Used to keep track of spent Notes, and + * detect and report conflicts (double-spends). + */ + typedef TxSpendMap TxNullifiers; + TxNullifiers mapTxNullifiers; + void AddToSpends(const COutPoint& outpoint, const uint256& wtxid); + void AddToSpends(const uint256& nullifier, const uint256& wtxid); void AddToSpends(const uint256& wtxid); - void SyncMetaData(std::pair); + template + void SyncMetaData(std::pair::iterator, typename TxSpendMap::iterator>); public: /* @@ -636,6 +647,7 @@ public: bool SelectCoinsMinConf(const CAmount& nTargetValue, int nConfMine, int nConfTheirs, std::vector vCoins, std::set >& setCoinsRet, CAmount& nValueRet) const; bool IsSpent(const uint256& hash, unsigned int n) const; + bool IsSpent(const uint256& nullifier) const; bool IsLockedCoin(uint256 hash, unsigned int n) const; void LockCoin(COutPoint& output);