diff --git a/src/gtest/test_keystore.cpp b/src/gtest/test_keystore.cpp index 6b4d08a59..ccf9cb9ba 100644 --- a/src/gtest/test_keystore.cpp +++ b/src/gtest/test_keystore.cpp @@ -215,12 +215,6 @@ TEST(keystore_tests, StoreAndRetrieveSaplingSpendingKey) { EXPECT_FALSE(keyStore.HaveSaplingIncomingViewingKey(addr)); EXPECT_FALSE(keyStore.GetSaplingIncomingViewingKey(addr, ivkOut)); - // If we don't specify the default address, that mapping isn't created - keyStore.AddSaplingSpendingKey(sk); - EXPECT_TRUE(keyStore.HaveSaplingSpendingKey(fvk)); - EXPECT_TRUE(keyStore.HaveSaplingFullViewingKey(ivk)); - EXPECT_FALSE(keyStore.HaveSaplingIncomingViewingKey(addr)); - // When we specify the default address, we get the full mapping keyStore.AddSaplingSpendingKey(sk, addr); EXPECT_TRUE(keyStore.HaveSaplingSpendingKey(fvk)); diff --git a/src/keystore.cpp b/src/keystore.cpp index fd42f7634..54fb02590 100644 --- a/src/keystore.cpp +++ b/src/keystore.cpp @@ -125,13 +125,13 @@ bool CBasicKeyStore::AddSproutSpendingKey(const libzcash::SproutSpendingKey &sk) //! Sapling bool CBasicKeyStore::AddSaplingSpendingKey( const libzcash::SaplingExtendedSpendingKey &sk, - const boost::optional &defaultAddr) + const libzcash::SaplingPaymentAddress &defaultAddr) { LOCK(cs_SpendingKeyStore); auto fvk = sk.expsk.full_viewing_key(); // if SaplingFullViewingKey is not in SaplingFullViewingKeyMap, add it - if (!AddSaplingFullViewingKey(fvk, defaultAddr)){ + if (!AddSaplingFullViewingKey(fvk, defaultAddr)) { return false; } @@ -151,17 +151,27 @@ bool CBasicKeyStore::AddSproutViewingKey(const libzcash::SproutViewingKey &vk) bool CBasicKeyStore::AddSaplingFullViewingKey( const libzcash::SaplingFullViewingKey &fvk, - const boost::optional &defaultAddr) + const libzcash::SaplingPaymentAddress &defaultAddr) { LOCK(cs_SpendingKeyStore); auto ivk = fvk.in_viewing_key(); mapSaplingFullViewingKeys[ivk] = fvk; - if (defaultAddr) { - // Add defaultAddr -> SaplingIncomingViewing to SaplingIncomingViewingKeyMap - mapSaplingIncomingViewingKeys[defaultAddr.get()] = ivk; - } - + return AddSaplingIncomingViewingKey(ivk, defaultAddr); +} + +// This function updates the wallet's internal address->ivk map. +// If we add an address that is already in the map, the map will +// remain unchanged as each address only has one ivk. +bool CBasicKeyStore::AddSaplingIncomingViewingKey( + const libzcash::SaplingIncomingViewingKey &ivk, + const libzcash::SaplingPaymentAddress &addr) +{ + LOCK(cs_SpendingKeyStore); + + // Add addr -> SaplingIncomingViewing to SaplingIncomingViewingKeyMap + mapSaplingIncomingViewingKeys[addr] = ivk; + return true; } diff --git a/src/keystore.h b/src/keystore.h index 5a80384bb..b369dec78 100644 --- a/src/keystore.h +++ b/src/keystore.h @@ -66,7 +66,7 @@ public: //! Add a Sapling spending key to the store. virtual bool AddSaplingSpendingKey( const libzcash::SaplingExtendedSpendingKey &sk, - const boost::optional &defaultAddr = boost::none) =0; + const libzcash::SaplingPaymentAddress &defaultAddr) =0; //! Check whether a Sapling spending key corresponding to a given Sapling viewing key is present in the store. virtual bool HaveSaplingSpendingKey(const libzcash::SaplingFullViewingKey &fvk) const =0; @@ -75,13 +75,16 @@ public: //! Support for Sapling full viewing keys virtual bool AddSaplingFullViewingKey( const libzcash::SaplingFullViewingKey &fvk, - const boost::optional &defaultAddr = boost::none) =0; + const libzcash::SaplingPaymentAddress &defaultAddr) =0; virtual bool HaveSaplingFullViewingKey(const libzcash::SaplingIncomingViewingKey &ivk) const =0; virtual bool GetSaplingFullViewingKey( const libzcash::SaplingIncomingViewingKey &ivk, libzcash::SaplingFullViewingKey& fvkOut) const =0; //! Sapling incoming viewing keys + virtual bool AddSaplingIncomingViewingKey( + const libzcash::SaplingIncomingViewingKey &ivk, + const libzcash::SaplingPaymentAddress &addr) =0; virtual bool HaveSaplingIncomingViewingKey(const libzcash::SaplingPaymentAddress &addr) const =0; virtual bool GetSaplingIncomingViewingKey( const libzcash::SaplingPaymentAddress &addr, @@ -236,7 +239,7 @@ public: //! Sapling bool AddSaplingSpendingKey( const libzcash::SaplingExtendedSpendingKey &sk, - const boost::optional &defaultAddr = boost::none); + const libzcash::SaplingPaymentAddress &defaultAddr); bool HaveSaplingSpendingKey(const libzcash::SaplingFullViewingKey &fvk) const { bool result; @@ -263,12 +266,15 @@ public: virtual bool AddSaplingFullViewingKey( const libzcash::SaplingFullViewingKey &fvk, - const boost::optional &defaultAddr = boost::none); + const libzcash::SaplingPaymentAddress &defaultAddr); virtual bool HaveSaplingFullViewingKey(const libzcash::SaplingIncomingViewingKey &ivk) const; virtual bool GetSaplingFullViewingKey( const libzcash::SaplingIncomingViewingKey &ivk, libzcash::SaplingFullViewingKey& fvkOut) const; + virtual bool AddSaplingIncomingViewingKey( + const libzcash::SaplingIncomingViewingKey &ivk, + const libzcash::SaplingPaymentAddress &addr); virtual bool HaveSaplingIncomingViewingKey(const libzcash::SaplingPaymentAddress &addr) const; virtual bool GetSaplingIncomingViewingKey( const libzcash::SaplingPaymentAddress &addr, diff --git a/src/wallet/crypter.cpp b/src/wallet/crypter.cpp index 88807598a..25d69cc3d 100644 --- a/src/wallet/crypter.cpp +++ b/src/wallet/crypter.cpp @@ -449,7 +449,7 @@ bool CCryptoKeyStore::AddSproutSpendingKey(const libzcash::SproutSpendingKey &sk bool CCryptoKeyStore::AddSaplingSpendingKey( const libzcash::SaplingExtendedSpendingKey &sk, - const boost::optional &defaultAddr) + const libzcash::SaplingPaymentAddress &defaultAddr) { { LOCK(cs_SpendingKeyStore); @@ -496,7 +496,7 @@ bool CCryptoKeyStore::AddCryptedSproutSpendingKey( bool CCryptoKeyStore::AddCryptedSaplingSpendingKey( const libzcash::SaplingFullViewingKey &fvk, const std::vector &vchCryptedSecret, - const boost::optional &defaultAddr) + const libzcash::SaplingPaymentAddress &defaultAddr) { { LOCK(cs_SpendingKeyStore); @@ -505,7 +505,7 @@ bool CCryptoKeyStore::AddCryptedSaplingSpendingKey( } // if SaplingFullViewingKey is not in SaplingFullViewingKeyMap, add it - if (!AddSaplingFullViewingKey(fvk, defaultAddr)){ + if (!AddSaplingFullViewingKey(fvk, defaultAddr)) { return false; } @@ -614,7 +614,7 @@ bool CCryptoKeyStore::EncryptKeys(CKeyingMaterial& vMasterKeyIn) if (!EncryptSecret(vMasterKeyIn, vchSecret, fvk.GetFingerprint(), vchCryptedSecret)) { return false; } - if (!AddCryptedSaplingSpendingKey(fvk, vchCryptedSecret)) { + if (!AddCryptedSaplingSpendingKey(fvk, vchCryptedSecret, sk.DefaultAddress())) { return false; } } diff --git a/src/wallet/crypter.h b/src/wallet/crypter.h index 7c326fbc4..b751ce300 100644 --- a/src/wallet/crypter.h +++ b/src/wallet/crypter.h @@ -243,10 +243,10 @@ public: virtual bool AddCryptedSaplingSpendingKey( const libzcash::SaplingFullViewingKey &fvk, const std::vector &vchCryptedSecret, - const boost::optional &defaultAddr = boost::none); + const libzcash::SaplingPaymentAddress &defaultAddr); bool AddSaplingSpendingKey( const libzcash::SaplingExtendedSpendingKey &sk, - const boost::optional &defaultAddr = boost::none); + const libzcash::SaplingPaymentAddress &defaultAddr); bool HaveSaplingSpendingKey(const libzcash::SaplingFullViewingKey &fvk) const { { diff --git a/src/wallet/gtest/test_wallet.cpp b/src/wallet/gtest/test_wallet.cpp index a67c3b491..f4b9ddc55 100644 --- a/src/wallet/gtest/test_wallet.cpp +++ b/src/wallet/gtest/test_wallet.cpp @@ -512,13 +512,13 @@ TEST(WalletTests, FindMySaplingNotes) { // No Sapling notes can be found in tx which does not belong to the wallet CWalletTx wtx {&wallet, tx}; ASSERT_FALSE(wallet.HaveSaplingSpendingKey(fvk)); - auto noteMap = wallet.FindMySaplingNotes(wtx); + auto noteMap = wallet.FindMySaplingNotes(wtx).first; EXPECT_EQ(0, noteMap.size()); // Add spending key to wallet, so Sapling notes can be found - ASSERT_TRUE(wallet.AddSaplingZKey(sk)); + ASSERT_TRUE(wallet.AddSaplingZKey(sk, pk)); ASSERT_TRUE(wallet.HaveSaplingSpendingKey(fvk)); - noteMap = wallet.FindMySaplingNotes(wtx); + noteMap = wallet.FindMySaplingNotes(wtx).first; EXPECT_EQ(2, noteMap.size()); // Revert to default @@ -630,7 +630,7 @@ TEST(WalletTests, GetConflictedSaplingNotes) { auto ivk = fvk.in_viewing_key(); auto pk = sk.DefaultAddress(); - ASSERT_TRUE(wallet.AddSaplingZKey(sk)); + ASSERT_TRUE(wallet.AddSaplingZKey(sk, pk)); ASSERT_TRUE(wallet.HaveSaplingSpendingKey(fvk)); // Generate note A @@ -664,7 +664,7 @@ TEST(WalletTests, GetConflictedSaplingNotes) { EXPECT_EQ(0, chainActive.Height()); // Simulate SyncTransaction which calls AddToWalletIfInvolvingMe - auto saplingNoteData = wallet.FindMySaplingNotes(wtx); + auto saplingNoteData = wallet.FindMySaplingNotes(wtx).first; ASSERT_TRUE(saplingNoteData.size() > 0); wtx.SetSaplingNoteData(saplingNoteData); wtx.SetMerkleBranch(block); @@ -815,7 +815,7 @@ TEST(WalletTests, SaplingNullifierIsSpent) { auto tx = maybe_tx.get(); CWalletTx wtx {&wallet, tx}; - ASSERT_TRUE(wallet.AddSaplingZKey(sk)); + ASSERT_TRUE(wallet.AddSaplingZKey(sk, pk)); ASSERT_TRUE(wallet.HaveSaplingSpendingKey(fvk)); // Manually compute the nullifier based on the known position @@ -912,7 +912,7 @@ TEST(WalletTests, NavigateFromSaplingNullifierToNote) { auto tx = maybe_tx.get(); CWalletTx wtx {&wallet, tx}; - ASSERT_TRUE(wallet.AddSaplingZKey(sk)); + ASSERT_TRUE(wallet.AddSaplingZKey(sk, pk)); ASSERT_TRUE(wallet.HaveSaplingSpendingKey(fvk)); // Manually compute the nullifier based on the expected position @@ -938,7 +938,7 @@ TEST(WalletTests, NavigateFromSaplingNullifierToNote) { // Simulate SyncTransaction which calls AddToWalletIfInvolvingMe wtx.SetMerkleBranch(block); - auto saplingNoteData = wallet.FindMySaplingNotes(wtx); + auto saplingNoteData = wallet.FindMySaplingNotes(wtx).first; ASSERT_TRUE(saplingNoteData.size() > 0); wtx.SetSaplingNoteData(saplingNoteData); wallet.AddToWallet(wtx, true, NULL); @@ -1048,7 +1048,7 @@ TEST(WalletTests, SpentSaplingNoteIsFromMe) { auto tx = maybe_tx.get(); CWalletTx wtx {&wallet, tx}; - ASSERT_TRUE(wallet.AddSaplingZKey(sk)); + ASSERT_TRUE(wallet.AddSaplingZKey(sk, pk)); ASSERT_TRUE(wallet.HaveSaplingSpendingKey(fvk)); // Fake-mine the transaction @@ -1064,7 +1064,7 @@ TEST(WalletTests, SpentSaplingNoteIsFromMe) { EXPECT_TRUE(chainActive.Contains(&fakeIndex)); EXPECT_EQ(0, chainActive.Height()); - auto saplingNoteData = wallet.FindMySaplingNotes(wtx); + auto saplingNoteData = wallet.FindMySaplingNotes(wtx).first; ASSERT_TRUE(saplingNoteData.size() > 0); wtx.SetSaplingNoteData(saplingNoteData); wtx.SetMerkleBranch(block); @@ -1141,7 +1141,7 @@ TEST(WalletTests, SpentSaplingNoteIsFromMe) { EXPECT_TRUE(chainActive.Contains(&fakeIndex2)); EXPECT_EQ(1, chainActive.Height()); - auto saplingNoteData2 = wallet.FindMySaplingNotes(wtx2); + auto saplingNoteData2 = wallet.FindMySaplingNotes(wtx2).first; ASSERT_TRUE(saplingNoteData2.size() > 0); wtx2.SetSaplingNoteData(saplingNoteData2); wtx2.SetMerkleBranch(block2); @@ -1751,7 +1751,7 @@ TEST(WalletTests, UpdatedSaplingNoteData) { // Wallet contains fvk1 but not fvk2 CWalletTx wtx {&wallet, tx}; - ASSERT_TRUE(wallet.AddSaplingZKey(sk)); + ASSERT_TRUE(wallet.AddSaplingZKey(sk, pk)); ASSERT_TRUE(wallet.HaveSaplingSpendingKey(fvk)); ASSERT_FALSE(wallet.HaveSaplingSpendingKey(fvk2)); @@ -1769,7 +1769,7 @@ TEST(WalletTests, UpdatedSaplingNoteData) { EXPECT_EQ(0, chainActive.Height()); // Simulate SyncTransaction which calls AddToWalletIfInvolvingMe - auto saplingNoteData = wallet.FindMySaplingNotes(wtx); + auto saplingNoteData = wallet.FindMySaplingNotes(wtx).first; ASSERT_TRUE(saplingNoteData.size() == 1); // wallet only has key for change output wtx.SetSaplingNoteData(saplingNoteData); wtx.SetMerkleBranch(block); @@ -1784,10 +1784,10 @@ TEST(WalletTests, UpdatedSaplingNoteData) { wtx = wallet.mapWallet[hash]; // Now lets add key fvk2 so wallet can find the payment note sent to pk2 - ASSERT_TRUE(wallet.AddSaplingZKey(sk2)); + ASSERT_TRUE(wallet.AddSaplingZKey(sk2, pk2)); ASSERT_TRUE(wallet.HaveSaplingSpendingKey(fvk2)); CWalletTx wtx2 = wtx; - auto saplingNoteData2 = wallet.FindMySaplingNotes(wtx2); + auto saplingNoteData2 = wallet.FindMySaplingNotes(wtx2).first; ASSERT_TRUE(saplingNoteData2.size() == 2); wtx2.SetSaplingNoteData(saplingNoteData2); @@ -1881,7 +1881,7 @@ TEST(WalletTests, MarkAffectedSaplingTransactionsDirty) { auto ivk = fvk.in_viewing_key(); auto pk = sk.DefaultAddress(); - ASSERT_TRUE(wallet.AddSaplingZKey(sk)); + ASSERT_TRUE(wallet.AddSaplingZKey(sk, pk)); ASSERT_TRUE(wallet.HaveSaplingSpendingKey(fvk)); // Set up transparent address @@ -1923,7 +1923,7 @@ TEST(WalletTests, MarkAffectedSaplingTransactionsDirty) { EXPECT_EQ(0, chainActive.Height()); // Simulate SyncTransaction which calls AddToWalletIfInvolvingMe - auto saplingNoteData = wallet.FindMySaplingNotes(wtx); + auto saplingNoteData = wallet.FindMySaplingNotes(wtx).first; ASSERT_TRUE(saplingNoteData.size() > 0); wtx.SetSaplingNoteData(saplingNoteData); wtx.SetMerkleBranch(block); diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index d838f485e..fea0dc881 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -154,7 +154,7 @@ SaplingPaymentAddress CWallet::GenerateNewSaplingZKey() // Add spending key to keystore bool CWallet::AddSaplingZKey( const libzcash::SaplingExtendedSpendingKey &sk, - const boost::optional &defaultAddr) + const libzcash::SaplingPaymentAddress &defaultAddr) { AssertLockHeld(cs_wallet); // mapSaplingZKeyMetadata @@ -295,7 +295,7 @@ bool CWallet::AddCryptedSproutSpendingKey( bool CWallet::AddCryptedSaplingSpendingKey(const libzcash::SaplingFullViewingKey &fvk, const std::vector &vchCryptedSecret, - const boost::optional &defaultAddr) + const libzcash::SaplingPaymentAddress &defaultAddr) { if (!CCryptoKeyStore::AddCryptedSaplingSpendingKey(fvk, vchCryptedSecret, defaultAddr)) return false; @@ -1537,7 +1537,14 @@ bool CWallet::AddToWalletIfInvolvingMe(const CTransaction& tx, const CBlock* pbl bool fExisted = mapWallet.count(tx.GetHash()) != 0; if (fExisted && !fUpdate) return false; auto sproutNoteData = FindMySproutNotes(tx); - auto saplingNoteData = FindMySaplingNotes(tx); + auto saplingNoteDataAndAddressesToAdd = FindMySaplingNotes(tx); + auto saplingNoteData = saplingNoteDataAndAddressesToAdd.first; + auto addressesToAdd = saplingNoteDataAndAddressesToAdd.second; + for (const auto &addressToAdd : addressesToAdd) { + if (!AddSaplingIncomingViewingKey(addressToAdd.second, addressToAdd.first)) { + return false; + } + } if (fExisted || IsMine(tx) || IsFromMe(tx) || sproutNoteData.size() > 0 || saplingNoteData.size() > 0) { CWalletTx wtx(this,tx); @@ -1698,12 +1705,13 @@ mapSproutNoteData_t CWallet::FindMySproutNotes(const CTransaction &tx) const * the result of FindMySaplingNotes (for the addresses available at the time) will * already have been cached in CWalletTx.mapSaplingNoteData. */ -mapSaplingNoteData_t CWallet::FindMySaplingNotes(const CTransaction &tx) const +std::pair CWallet::FindMySaplingNotes(const CTransaction &tx) const { LOCK(cs_SpendingKeyStore); uint256 hash = tx.GetHash(); mapSaplingNoteData_t noteData; + SaplingIncomingViewingKeyMap viewingKeysToAdd; // Protocol Spec: 4.19 Block Chain Scanning (Sapling) for (uint32_t i = 0; i < tx.vShieldedOutput.size(); ++i) { @@ -1714,6 +1722,10 @@ mapSaplingNoteData_t CWallet::FindMySaplingNotes(const CTransaction &tx) const if (!result) { continue; } + auto address = ivk.address(result.get().d); + if (address && mapSaplingIncomingViewingKeys.count(address.get()) == 0) { + viewingKeysToAdd[address.get()] = ivk; + } // We don't cache the nullifier here as computing it requires knowledge of the note position // in the commitment tree, which can only be determined when the transaction has been mined. SaplingOutPoint op {hash, i}; @@ -1724,7 +1736,7 @@ mapSaplingNoteData_t CWallet::FindMySaplingNotes(const CTransaction &tx) const } } - return noteData; + return std::make_pair(noteData, viewingKeysToAdd); } bool CWallet::IsSproutNullifierFromMe(const uint256& nullifier) const diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index 3589f9629..a08a8c782 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -1058,11 +1058,11 @@ public: //! Adds Sapling spending key to the store, and saves it to disk bool AddSaplingZKey( const libzcash::SaplingExtendedSpendingKey &key, - const boost::optional &defaultAddr = boost::none); + const libzcash::SaplingPaymentAddress &defaultAddr); bool AddCryptedSaplingSpendingKey( const libzcash::SaplingFullViewingKey &fvk, const std::vector &vchCryptedSecret, - const boost::optional &defaultAddr = boost::none); + const libzcash::SaplingPaymentAddress &defaultAddr); /** * Increment the next transaction order id @@ -1132,7 +1132,7 @@ public: const uint256& hSig, uint8_t n) const; mapSproutNoteData_t FindMySproutNotes(const CTransaction& tx) const; - mapSaplingNoteData_t FindMySaplingNotes(const CTransaction& tx) const; + std::pair FindMySaplingNotes(const CTransaction& tx) const; bool IsSproutNullifierFromMe(const uint256& nullifier) const; bool IsSaplingNullifierFromMe(const uint256& nullifier) const;