Apply suggestions from code review

Co-authored-by: str4d <jack@electriccoin.co>
This commit is contained in:
Kris Nuttycombe 2022-01-03 15:46:28 -07:00 committed by Kris Nuttycombe
parent ad54591061
commit eb53abbbaf
8 changed files with 86 additions and 75 deletions

View File

@ -536,17 +536,19 @@ TEST(KeystoreTests, StoreAndRetrieveUFVK) {
auto addrPair = zufvk.FindAddress(diversifier_index_t(0), {ReceiverType::Sapling}).value(); auto addrPair = zufvk.FindAddress(diversifier_index_t(0), {ReceiverType::Sapling}).value();
EXPECT_TRUE(addrPair.first.GetSaplingReceiver().has_value()); EXPECT_TRUE(addrPair.first.GetSaplingReceiver().has_value());
auto saplingReceiver = addrPair.first.GetSaplingReceiver().value(); auto saplingReceiver = addrPair.first.GetSaplingReceiver().value();
auto ufvkmeta = keyStore.GetUFVKMetadataForReceiver(saplingReceiver);
EXPECT_FALSE(ufvkmeta.has_value());
auto saplingIvk = zufvk.GetSaplingKey().value().fvk.in_viewing_key(); auto saplingIvk = zufvk.GetSaplingKey().value().fvk.in_viewing_key();
keyStore.AddSaplingIncomingViewingKey(saplingIvk, saplingReceiver); keyStore.AddSaplingIncomingViewingKey(saplingIvk, saplingReceiver);
auto ufvkmeta = keyStore.GetUFVKMetadataForReceiver(saplingReceiver); ufvkmeta = keyStore.GetUFVKMetadataForReceiver(saplingReceiver);
EXPECT_TRUE(ufvkmeta.has_value()); EXPECT_TRUE(ufvkmeta.has_value());
EXPECT_EQ(ufvkmeta.value().first, ufvkid); EXPECT_EQ(ufvkmeta.value().first, ufvkid);
EXPECT_FALSE(ufvkmeta.value().second.has_value()); EXPECT_FALSE(ufvkmeta.value().second.has_value());
} }
TEST(KeystoreTests, AddUnifiedAddress) { TEST(KeystoreTests, AddTransparentReceiverForUnifiedAddress) {
SelectParams(CBaseChainParams::TESTNET); SelectParams(CBaseChainParams::TESTNET);
CBasicKeyStore keyStore; CBasicKeyStore keyStore;
@ -559,10 +561,12 @@ TEST(KeystoreTests, AddUnifiedAddress) {
auto ufvkid = zufvk.GetKeyID(); auto ufvkid = zufvk.GetKeyID();
auto addrPair = zufvk.FindAddress(diversifier_index_t(0), {ReceiverType::P2PKH, ReceiverType::Sapling}).value(); auto addrPair = zufvk.FindAddress(diversifier_index_t(0), {ReceiverType::P2PKH, ReceiverType::Sapling}).value();
EXPECT_TRUE(addrPair.first.GetP2PKHReceiver().has_value()); EXPECT_TRUE(addrPair.first.GetP2PKHReceiver().has_value());
keyStore.AddUnifiedAddress(ufvkid, addrPair.second, addrPair.first);
auto ufvkmeta = keyStore.GetUFVKMetadataForReceiver(addrPair.first.GetP2PKHReceiver().value()); auto ufvkmeta = keyStore.GetUFVKMetadataForReceiver(addrPair.first.GetP2PKHReceiver().value());
EXPECT_FALSE(ufvkmeta.has_value());
keyStore.AddTransparentReceiverForUnifiedAddress(ufvkid, addrPair.second, addrPair.first);
ufvkmeta = keyStore.GetUFVKMetadataForReceiver(addrPair.first.GetP2PKHReceiver().value());
EXPECT_TRUE(ufvkmeta.has_value()); EXPECT_TRUE(ufvkmeta.has_value());
EXPECT_EQ(ufvkmeta.value().first, ufvkid); EXPECT_EQ(ufvkmeta.value().first, ufvkid);
} }

View File

@ -316,7 +316,7 @@ bool CBasicKeyStore::AddUnifiedFullViewingKey(
return true; return true;
} }
bool CBasicKeyStore::AddUnifiedAddress( bool CBasicKeyStore::AddTransparentReceiverForUnifiedAddress(
const libzcash::UFVKId& keyId, const libzcash::UFVKId& keyId,
const libzcash::diversifier_index_t& diversifierIndex, const libzcash::diversifier_index_t& diversifierIndex,
const libzcash::UnifiedAddress& ua) const libzcash::UnifiedAddress& ua)
@ -361,10 +361,11 @@ CBasicKeyStore::GetUFVKMetadataForReceiver(const libzcash::Receiver& receiver) c
std::optional<std::pair<libzcash::UFVKId, std::optional<libzcash::diversifier_index_t>>> std::optional<std::pair<libzcash::UFVKId, std::optional<libzcash::diversifier_index_t>>>
FindUFVKId::operator()(const libzcash::SaplingPaymentAddress& saplingAddr) const { FindUFVKId::operator()(const libzcash::SaplingPaymentAddress& saplingAddr) const {
if (keystore.mapSaplingIncomingViewingKeys.count(saplingAddr) > 0) { const auto saplingIvk = keystore.mapSaplingIncomingViewingKeys.find(saplingAddr);
const auto& saplingIvk = keystore.mapSaplingIncomingViewingKeys.at(saplingAddr); if (saplingIvk != keystore.mapSaplingIncomingViewingKeys.end()) {
if (keystore.mapSaplingKeyUnified.count(saplingIvk) > 0) { const auto ufvkId = keystore.mapSaplingKeyUnified.find(saplingIvk->second);
return std::make_pair(keystore.mapSaplingKeyUnified.at(saplingIvk), std::nullopt); if (ufvkId != keystore.mapSaplingKeyUnified.end()) {
return std::make_pair(ufvkId->second, std::nullopt);
} else { } else {
return std::nullopt; return std::nullopt;
} }
@ -374,16 +375,18 @@ FindUFVKId::operator()(const libzcash::SaplingPaymentAddress& saplingAddr) const
} }
std::optional<std::pair<libzcash::UFVKId, std::optional<libzcash::diversifier_index_t>>> std::optional<std::pair<libzcash::UFVKId, std::optional<libzcash::diversifier_index_t>>>
FindUFVKId::operator()(const CScriptID& scriptId) const { FindUFVKId::operator()(const CScriptID& scriptId) const {
if (keystore.mapP2SHUnified.count(scriptId) > 0) { const auto metadata = keystore.mapP2SHUnified.find(scriptId);
return keystore.mapP2SHUnified.at(scriptId); if (metadata != keystore.mapP2SHUnified.end()) {
return metadata->second;
} else { } else {
return std::nullopt; return std::nullopt;
} }
} }
std::optional<std::pair<libzcash::UFVKId, std::optional<libzcash::diversifier_index_t>>> std::optional<std::pair<libzcash::UFVKId, std::optional<libzcash::diversifier_index_t>>>
FindUFVKId::operator()(const CKeyID& keyId) const { FindUFVKId::operator()(const CKeyID& keyId) const {
if (keystore.mapP2PKHUnified.count(keyId) > 0) { const auto metadata = keystore.mapP2PKHUnified.find(keyId);
return keystore.mapP2PKHUnified.at(keyId); if (metadata != keystore.mapP2PKHUnified.end()) {
return metadata->second;
} else { } else {
return std::nullopt; return std::nullopt;
} }

View File

@ -117,7 +117,7 @@ public:
* viewing key upon discovery of the address as having received * viewing key upon discovery of the address as having received
* funds. * funds.
*/ */
virtual bool AddUnifiedAddress( virtual bool AddTransparentReceiverForUnifiedAddress(
const libzcash::UFVKId& keyId, const libzcash::UFVKId& keyId,
const libzcash::diversifier_index_t& diversifierIndex, const libzcash::diversifier_index_t& diversifierIndex,
const libzcash::UnifiedAddress& ua) = 0; const libzcash::UnifiedAddress& ua) = 0;
@ -357,7 +357,7 @@ public:
virtual bool AddUnifiedFullViewingKey( virtual bool AddUnifiedFullViewingKey(
const libzcash::ZcashdUnifiedFullViewingKey &ufvk); const libzcash::ZcashdUnifiedFullViewingKey &ufvk);
virtual bool AddUnifiedAddress( virtual bool AddTransparentReceiverForUnifiedAddress(
const libzcash::UFVKId& keyId, const libzcash::UFVKId& keyId,
const libzcash::diversifier_index_t& diversifierIndex, const libzcash::diversifier_index_t& diversifierIndex,
const libzcash::UnifiedAddress& ua); const libzcash::UnifiedAddress& ua);

View File

@ -182,7 +182,6 @@ TEST(WalletTests, SproutNoteDataSerialisation) {
EXPECT_EQ(noteData[jsoutpt].witnesses, noteData2[jsoutpt].witnesses); EXPECT_EQ(noteData[jsoutpt].witnesses, noteData2[jsoutpt].witnesses);
} }
TEST(WalletTests, FindUnspentSproutNotes) { TEST(WalletTests, FindUnspentSproutNotes) {
SelectParams(CBaseChainParams::TESTNET); SelectParams(CBaseChainParams::TESTNET);
@ -356,9 +355,6 @@ TEST(WalletTests, FindUnspentSproutNotes) {
mapBlockIndex.erase(blockHash); mapBlockIndex.erase(blockHash);
mapBlockIndex.erase(blockHash2); mapBlockIndex.erase(blockHash2);
mapBlockIndex.erase(blockHash3); mapBlockIndex.erase(blockHash3);
// Revert to default
RegtestDeactivateSapling();
} }
@ -2232,4 +2228,7 @@ TEST(WalletTests, GenerateUnifiedAddress) {
expected = AddressGenerationError::DiversifierSpaceExhausted; expected = AddressGenerationError::DiversifierSpaceExhausted;
EXPECT_EQ(uaResult, expected); EXPECT_EQ(uaResult, expected);
} }
// Revert to default
RegtestDeactivateSapling();
} }

View File

@ -195,9 +195,7 @@ bool CWallet::AddSaplingZKey(const libzcash::SaplingExtendedSpendingKey &sk)
return true; return true;
} }
bool CWallet::AddSaplingFullViewingKey( bool CWallet::AddSaplingFullViewingKey(const libzcash::SaplingExtendedFullViewingKey &extfvk)
const libzcash::SaplingExtendedFullViewingKey &extfvk,
const std::optional<libzcash::UFVKId>& ufvkId)
{ {
AssertLockHeld(cs_wallet); AssertLockHeld(cs_wallet);
@ -636,16 +634,16 @@ UAGenerationResult CWallet::GenerateUnifiedAddress(
return AddressGenerationError::NoAddressForDiversifier; return AddressGenerationError::NoAddressForDiversifier;
} }
// Persist the newly created address to the keystore assert(mapUfvkAddressMetadata[ufvkid].SetReceivers(diversifierIndex, receiverTypes));
mapUfvkAddressMetadata[ufvkid].SetReceivers(diversifierIndex, receiverTypes); // Writing this data is handled by `CWalletDB::WriteUnifiedAddressMetadata` below.
CCryptoKeyStore::AddUnifiedAddress(ufvkid, diversifierIndex, address.value()); assert(CCryptoKeyStore::AddTransparentReceiverForUnifiedAddress(ufvkid, diversifierIndex, address.value()));
// Save the metadata for the generated address so that we can re-derive // Save the metadata for the generated address so that we can re-derive
// it in the future. // it in the future.
ZcashdUnifiedAddressMetadata addrmeta(ufvkid, diversifierIndex, receiverTypes); ZcashdUnifiedAddressMetadata addrmeta(ufvkid, diversifierIndex, receiverTypes);
if (fFileBacked && !CWalletDB(strWalletFile).WriteUnifiedAddressMetadata(addrmeta)) { if (fFileBacked && !CWalletDB(strWalletFile).WriteUnifiedAddressMetadata(addrmeta)) {
throw std::runtime_error( throw std::runtime_error(
"CWallet::AdddUnifiedAddress(): Writing unified address metadata failed"); "CWallet::AddUnifiedAddress(): Writing unified address metadata failed");
} }
if (hasTransparent) { if (hasTransparent) {
@ -691,11 +689,14 @@ bool CWallet::LoadUnifiedFullViewingKey(const libzcash::UnifiedFullViewingKey &k
if (metadata != mapUfvkAddressMetadata.end()) { if (metadata != mapUfvkAddressMetadata.end()) {
// restore unified addresses that have been previously generated to the // restore unified addresses that have been previously generated to the
// keystore // keystore
for (const auto &[j, receiverTypes] : metadata->second.GetAllReceivers()) { for (const auto &[j, receiverTypes] : metadata->second.GetKnownReceiverSetsByDiversifierIndex()) {
auto addr = zufvk.Address(j, receiverTypes).value(); auto addr = zufvk.Address(j, receiverTypes).value();
CCryptoKeyStore::AddUnifiedAddress(zufvk.GetKeyID(), j, addr); if (!CCryptoKeyStore::AddTransparentReceiverForUnifiedAddress(zufvk.GetKeyID(), j, addr)) {
return false;
} }
} }
}
return CCryptoKeyStore::AddUnifiedFullViewingKey(zufvk); return CCryptoKeyStore::AddUnifiedFullViewingKey(zufvk);
} }
@ -710,16 +711,18 @@ bool CWallet::LoadUnifiedAccountMetadata(const ZcashdUnifiedAccountMetadata &skm
bool CWallet::LoadUnifiedAddressMetadata(const ZcashdUnifiedAddressMetadata &addrmeta) bool CWallet::LoadUnifiedAddressMetadata(const ZcashdUnifiedAddressMetadata &addrmeta)
{ {
AssertLockHeld(cs_wallet); AssertLockHeld(cs_wallet);
mapUfvkAddressMetadata[addrmeta.GetKeyID()].SetReceivers( if (!mapUfvkAddressMetadata[addrmeta.GetKeyID()].SetReceivers(
addrmeta.GetDiversifierIndex(), addrmeta.GetDiversifierIndex(),
addrmeta.GetReceiverTypes()); addrmeta.GetReceiverTypes())) {
return false;
}
auto ufvk = GetUnifiedFullViewingKey(addrmeta.GetKeyID()); auto ufvk = GetUnifiedFullViewingKey(addrmeta.GetKeyID());
if (ufvk.has_value()) { if (ufvk.has_value()) {
// Regenerate the unified address and add it to the keystore. // Regenerate the unified address and add it to the keystore.
auto j = addrmeta.GetDiversifierIndex(); auto j = addrmeta.GetDiversifierIndex();
auto addr = ufvk.value().Address(j, addrmeta.GetReceiverTypes()).value(); auto addr = ufvk.value().Address(j, addrmeta.GetReceiverTypes()).value();
return CCryptoKeyStore::AddUnifiedAddress(addrmeta.GetKeyID(), j, addr); return CCryptoKeyStore::AddTransparentReceiverForUnifiedAddress(addrmeta.GetKeyID(), j, addr);
} }
return true; return true;
@ -5862,13 +5865,14 @@ std::optional<libzcash::UnifiedAddress> LookupUnifiedAddress::operator()(const l
} }
diversifier_index_t j; diversifier_index_t j;
auto metadata = wallet.mapUfvkAddressMetadata.find(ufvkid); // If the wallet is missing metadata at this UFVK id, it is probably
if (metadata != wallet.mapUfvkAddressMetadata.end()) { // corrupt and the node should shut down.
const auto& metadata = wallet.mapUfvkAddressMetadata.at(ufvkid);
librustzcash_sapling_diversifier_index( librustzcash_sapling_diversifier_index(
ufvk.value().GetSaplingKey().value().dk.begin(), ufvk.value().GetSaplingKey().value().dk.begin(),
saplingAddr.d.begin(), saplingAddr.d.begin(),
j.begin()); j.begin());
auto receivers = metadata->second.GetReceivers(j); auto receivers = metadata.GetReceivers(j);
if (receivers.has_value()) { if (receivers.has_value()) {
return ufvk.value().Address(j, receivers.value()); return ufvk.value().Address(j, receivers.value());
} else { } else {
@ -5879,9 +5883,6 @@ std::optional<libzcash::UnifiedAddress> LookupUnifiedAddress::operator()(const l
} else { } else {
return std::nullopt; return std::nullopt;
} }
} else {
return std::nullopt;
}
} }
std::optional<libzcash::UnifiedAddress> LookupUnifiedAddress::operator()(const CScriptID& scriptId) const { std::optional<libzcash::UnifiedAddress> LookupUnifiedAddress::operator()(const CScriptID& scriptId) const {
return std::nullopt; return std::nullopt;
@ -5899,12 +5900,14 @@ std::optional<libzcash::UnifiedAddress> LookupUnifiedAddress::operator()(const C
throw std::runtime_error("CWallet::LookupUnifiedAddress(): UFVK has no P2PKH key part."); throw std::runtime_error("CWallet::LookupUnifiedAddress(): UFVK has no P2PKH key part.");
} }
// Find the set of receivers at the diversifier index. If no metadata is available // If the wallet is missing metadata at this UFVK id, it is probably
// for the ufvk, or we do not know the receiver types for the address produced // corrupt and the node should shut down.
// at this diversifier, we cannot reconstruct the address. const auto& metadata = wallet.mapUfvkAddressMetadata.at(ufvkid);
auto metadata = wallet.mapUfvkAddressMetadata.find(ufvkid);
if (metadata != wallet.mapUfvkAddressMetadata.end()) { // Find the set of receivers at the diversifier index. If we do not
auto receivers = metadata->second.GetReceivers(j); // know the receiver types for the address produced at this
// diversifier, we cannot reconstruct the address.
auto receivers = metadata.GetReceivers(j);
if (receivers.has_value()) { if (receivers.has_value()) {
return ufvk.value().Address(j, receivers.value()); return ufvk.value().Address(j, receivers.value());
} else { } else {
@ -5913,9 +5916,6 @@ std::optional<libzcash::UnifiedAddress> LookupUnifiedAddress::operator()(const C
} else { } else {
return std::nullopt; return std::nullopt;
} }
} else {
return std::nullopt;
}
} }
std::optional<libzcash::UnifiedAddress> LookupUnifiedAddress::operator()(const libzcash::UnknownReceiver& receiver) const { std::optional<libzcash::UnifiedAddress> LookupUnifiedAddress::operator()(const libzcash::UnknownReceiver& receiver) const {
return std::nullopt; return std::nullopt;

View File

@ -685,13 +685,21 @@ public:
class UFVKAddressMetadata class UFVKAddressMetadata
{ {
private: private:
// The account ID may be absent for imported UFVKs, and also may temporarily
// be absent when this data structure is in a partially-reconstructed state
// during the wallet load process.
std::optional<libzcash::AccountId> accountId; std::optional<libzcash::AccountId> accountId;
std::map<libzcash::diversifier_index_t, std::set<libzcash::ReceiverType>> addressReceivers; std::map<libzcash::diversifier_index_t, std::set<libzcash::ReceiverType>> addressReceivers;
public: public:
UFVKAddressMetadata() {} UFVKAddressMetadata() {}
UFVKAddressMetadata(libzcash::AccountId accountId): accountId(accountId) {} UFVKAddressMetadata(libzcash::AccountId accountId): accountId(accountId) {}
const std::map<libzcash::diversifier_index_t, std::set<libzcash::ReceiverType>>& GetAllReceivers() const { /**
* Return all currently known diversifier indices for which addresses
* have been generated, each accompanied by the associated set of receiver
* types that were used when generating that address.
*/
const std::map<libzcash::diversifier_index_t, std::set<libzcash::ReceiverType>>& GetKnownReceiverSetsByDiversifierIndex() const {
return addressReceivers; return addressReceivers;
} }
@ -715,10 +723,11 @@ public:
bool SetReceivers( bool SetReceivers(
const libzcash::diversifier_index_t& j, const libzcash::diversifier_index_t& j,
const std::set<libzcash::ReceiverType>& receivers) { const std::set<libzcash::ReceiverType>& receivers) {
if (addressReceivers.count(j) > 0) { const auto [it, success] = addressReceivers.insert(std::make_pair(j, receivers));
return addressReceivers[j] == receivers; if (success) {
return true;
} else { } else {
return addressReceivers.insert(std::make_pair(j, receivers)).second; return it->second == receivers;
} }
} }
@ -742,12 +751,7 @@ public:
if (addressReceivers.empty()) { if (addressReceivers.empty()) {
return libzcash::diversifier_index_t(0); return libzcash::diversifier_index_t(0);
} else { } else {
auto lastIndex = addressReceivers.rbegin()->first; return addressReceivers.rbegin()->first.succ();
if (lastIndex.increment()) {
return lastIndex;
} else {
return std::nullopt;
}
} }
} }
}; };
@ -1167,8 +1171,7 @@ public:
//! CBasicKeyStore::AddSaplingFullViewingKey is called directly when adding a //! CBasicKeyStore::AddSaplingFullViewingKey is called directly when adding a
//! full viewing key to the keystore, to avoid this override. //! full viewing key to the keystore, to avoid this override.
bool AddSaplingFullViewingKey( bool AddSaplingFullViewingKey(
const libzcash::SaplingExtendedFullViewingKey &extfvk, const libzcash::SaplingExtendedFullViewingKey &extfvk);
const std::optional<libzcash::UFVKId>& ufvkId = std::nullopt);
bool AddSaplingIncomingViewingKey( bool AddSaplingIncomingViewingKey(
const libzcash::SaplingIncomingViewingKey &ivk, const libzcash::SaplingIncomingViewingKey &ivk,
const libzcash::SaplingPaymentAddress &addr); const libzcash::SaplingPaymentAddress &addr);
@ -1199,8 +1202,10 @@ public:
std::pair<libzcash::ZcashdUnifiedSpendingKey, libzcash::AccountId> std::pair<libzcash::ZcashdUnifiedSpendingKey, libzcash::AccountId>
GenerateNewUnifiedSpendingKey(); GenerateNewUnifiedSpendingKey();
//! Generate the next available unified spending key from the wallet's //! Generate the unified spending key for the specified ZIP-32/BIP-44
//! mnemonic seed. //! account identifier from the wallet's mnemonic seed, or returns
//! std::nullopt if the account identifier does not produce a valid
//! spending key for all receiver types.
std::optional<libzcash::ZcashdUnifiedSpendingKey> std::optional<libzcash::ZcashdUnifiedSpendingKey>
GenerateUnifiedSpendingKeyForAccount(libzcash::AccountId accountId); GenerateUnifiedSpendingKeyForAccount(libzcash::AccountId accountId);

View File

@ -21,11 +21,11 @@ bool libzcash::HasShielded(const std::set<ReceiverType>& receiverTypes) {
} }
bool libzcash::HasTransparent(const std::set<ReceiverType>& receiverTypes) { bool libzcash::HasTransparent(const std::set<ReceiverType>& receiverTypes) {
auto has_shielded = [](ReceiverType r) { auto has_transparent = [](ReceiverType r) {
// TODO: update this as support for new transparent protocols is added. // TODO: update this as support for new transparent protocols is added.
return r == ReceiverType::P2PKH || r == ReceiverType::P2SH; return r == ReceiverType::P2PKH || r == ReceiverType::P2SH;
}; };
return std::find_if(receiverTypes.begin(), receiverTypes.end(), has_shielded) != receiverTypes.end(); return std::find_if(receiverTypes.begin(), receiverTypes.end(), has_transparent) != receiverTypes.end();
} }
std::optional<ZcashdUnifiedSpendingKey> ZcashdUnifiedSpendingKey::ForAccount( std::optional<ZcashdUnifiedSpendingKey> ZcashdUnifiedSpendingKey::ForAccount(
@ -132,5 +132,5 @@ std::optional<std::pair<UnifiedAddress, diversifier_index_t>> ZcashdUnifiedFullV
std::optional<std::pair<UnifiedAddress, diversifier_index_t>> ZcashdUnifiedFullViewingKey::FindAddress( std::optional<std::pair<UnifiedAddress, diversifier_index_t>> ZcashdUnifiedFullViewingKey::FindAddress(
const diversifier_index_t& j) const { const diversifier_index_t& j) const {
return FindAddress(j, {ReceiverType::P2PKH, ReceiverType::Sapling, ReceiverType::Orchard}); return FindAddress(j, {ReceiverType::P2PKH, ReceiverType::Sapling});
} }

View File

@ -15,7 +15,7 @@ enum class ReceiverType: uint32_t {
P2PKH = 0x00, P2PKH = 0x00,
P2SH = 0x01, P2SH = 0x01,
Sapling = 0x02, Sapling = 0x02,
Orchard = 0x03 //Orchard = 0x03
}; };
/** /**