From b29ca10b8d96a0b9ffbaafdbf810b652af9f0b9b Mon Sep 17 00:00:00 2001 From: Kris Nuttycombe Date: Thu, 9 Dec 2021 18:34:04 -0700 Subject: [PATCH] Add unified address generation. Generate unified addresses from UFVKs, and add the associated metadata to the wallet database. --- src/keystore.cpp | 12 +++- src/keystore.h | 10 ++- src/wallet/wallet.cpp | 116 +++++++++++++++++++++++++++++++--- src/wallet/wallet.h | 32 ++++++++-- src/wallet/walletdb.cpp | 12 ++++ src/wallet/walletdb.h | 9 +-- src/zcash/address/unified.cpp | 42 +++++++++--- src/zcash/address/unified.h | 34 +++++++++- 8 files changed, 237 insertions(+), 30 deletions(-) diff --git a/src/keystore.cpp b/src/keystore.cpp index 4230fd7cb..54fdb2a92 100644 --- a/src/keystore.cpp +++ b/src/keystore.cpp @@ -309,7 +309,7 @@ bool CBasicKeyStore::AddUnifiedFullViewingKey( return true; } -bool CBasicKeyStore::AddUnifiedAddress( +void CBasicKeyStore::AddUnifiedAddress( const libzcash::UFVKId& keyId, const libzcash::UnifiedAddress& ua) { @@ -331,3 +331,13 @@ bool CBasicKeyStore::AddUnifiedAddress( } } +std::optional CBasicKeyStore::GetUnifiedFullViewingKey( + const libzcash::UFVKId& keyId) +{ + auto mi = mapUnifiedFullViewingKeys.find(keyId); + if (mi != mapUnifiedFullViewingKeys.end()) { + return mi->second; + } else { + return std::nullopt; + } +} diff --git a/src/keystore.h b/src/keystore.h index dd731195a..3f611fa39 100644 --- a/src/keystore.h +++ b/src/keystore.h @@ -108,10 +108,13 @@ public: const libzcash::ZcashdUnifiedFullViewingKey &ufvk ) = 0; - virtual bool AddUnifiedAddress( + virtual void AddUnifiedAddress( const libzcash::UFVKId& keyId, const libzcash::UnifiedAddress &ua ) = 0; + + virtual std::optional GetUnifiedFullViewingKey( + const libzcash::UFVKId& keyId) = 0; }; typedef std::map KeyMap; @@ -346,9 +349,12 @@ public: * Add the transparent component of the unified address, if any, * to the keystore to make it possible to identify the */ - virtual bool AddUnifiedAddress( + virtual void AddUnifiedAddress( const libzcash::UFVKId& keyId, const libzcash::UnifiedAddress &ua); + + virtual std::optional GetUnifiedFullViewingKey( + const libzcash::UFVKId& keyId); }; typedef std::vector > CKeyingMaterial; diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index a9f1eb9f1..16c6019a6 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -474,7 +474,7 @@ std::optional> CWallet::GetUnifiedFullViewingKeyByAccount(libzcash::AccountId accountId) { + if (!mnemonicHDChain.has_value()) { + throw std::runtime_error( + "CWallet::GenerateNewUnifiedSpendingKey(): Wallet is missing mnemonic seed metadata."); + } + + auto seedfp = mnemonicHDChain.value().GetSeedFingerprint(); + auto i = mapUnifiedKeyMetadata.find(std::make_pair(seedfp, accountId)); + if (i != mapUnifiedKeyMetadata.end()) { + auto keyId = i->second.GetKeyID(); + auto key = CCryptoKeyStore::GetUnifiedFullViewingKey(keyId); + if (key.has_value()) { + return std::make_pair(keyId, key.value()); + } else { + return std::nullopt; + } + } else { + return std::nullopt; + } } -void CWallet::LoadUnifiedKeyMetadata(const ZcashdUnifiedSpendingKeyMetadata &meta) +UAGenerationResult CWallet::GenerateUnifiedAddress( + const libzcash::AccountId& accountId, + const libzcash::diversifier_index_t& j, + const std::set& receiverTypes) +{ + if (!libzcash::HasShielded(receiverTypes)) { + return AddressGenerationError::InvalidReceiverTypes; + } + + auto identifiedKey = GetUnifiedFullViewingKeyByAccount(accountId); + if (identifiedKey.has_value()) { + auto ufvkid = identifiedKey.value().first; + auto ufvk = identifiedKey.value().second; + + // Check whether an address has already been generated for this + // diversifier index. If so, ensure that the set of receiver types + // being requested is the same as the set of receiver types that was + // previously generated; if so, return the previously generated address, + // otherwise return an error. + if (mapUnifiedAddressMetadata.count(ufvkid) > 0) { + const auto& accountKeys = mapUnifiedAddressMetadata.at(ufvkid); + if (accountKeys.count(j) > 0) { + if (accountKeys.at(j) == receiverTypes) { + ZcashdUnifiedAddressMetadata addrmeta(ufvkid, j, receiverTypes); + auto addr = ufvk.Address(j, receiverTypes); + return std::make_pair(addr.value(), addrmeta); + } else { + return AddressGenerationError::ExistingAddressMismatch; + } + } + } + + // Find a working diversifier and construct the associated address. + auto found = ufvk.FindAddress(j, receiverTypes); + auto diversifierIndex = found.second; + + // Persist the newly created address to the keystore + AddUnifiedAddress(ufvkid, found.first); + + // Save the metadata for the generated address so that we can re-derive + // it in the future. + ZcashdUnifiedAddressMetadata addrmeta(ufvkid, found.second, receiverTypes); + mapUnifiedAddressMetadata[ufvkid].insert({diversifierIndex, receiverTypes}); + if (fFileBacked) { + CWalletDB(strWalletFile).WriteUnifiedAddressMetadata(addrmeta); + } + return std::make_pair(found.first, addrmeta); + } else { + return AddressGenerationError::NoSuchAccount; + } +} + +bool CWallet::LoadUnifiedFullViewingKey(const libzcash::UnifiedFullViewingKey &key) +{ + auto ufvkid = key.GetKeyID(Params()); + auto zufvk = ZcashdUnifiedFullViewingKey::FromUnifiedFullViewingKey(key); + if (mapUnifiedAddressMetadata.count(ufvkid) > 0) { + // restore unified addresses that have been previously generated to the + // keystore + for (const auto &[j, receiverTypes] : mapUnifiedAddressMetadata[ufvkid]) { + auto addr = zufvk.Address(j, receiverTypes).value(); + AddUnifiedAddress(ufvkid, addr); + } + } + return CCryptoKeyStore::AddUnifiedFullViewingKey(ufvkid, zufvk); +} + +void CWallet::LoadUnifiedKeyMetadata(const ZcashdUnifiedSpendingKeyMetadata &skmeta) { AssertLockHeld(cs_wallet); // mapUnifiedKeyMetadata - mapUnifiedKeyMetadata.insert({meta.GetKeyID(), meta}); + auto metaKey = std::make_pair(skmeta.GetSeedFingerprint(), skmeta.GetAccountId()); + mapUnifiedKeyMetadata.insert({metaKey, skmeta}); +} + +bool CWallet::LoadUnifiedAddressMetadata(const ZcashdUnifiedAddressMetadata &addrmeta) +{ + AssertLockHeld(cs_wallet); // mapUnifiedKeyMetadata + auto ufvk = GetUnifiedFullViewingKey(addrmeta.GetKeyID()); + if (ufvk.has_value()) { + // restore unified addresses that have been previously generated + auto addr = ufvk.value().Address(addrmeta.GetDiversifierIndex(), addrmeta.GetReceiverTypes()); + if (addr.has_value()) { + AddUnifiedAddress(addrmeta.GetKeyID(), addr.value()); + } else { + // an error has occurred; the ufvk is loaded but cannot reproduce the + // address identified by the address metadata. + return false; + } + } + return true; } void CWallet::LoadKeyMetadata(const CPubKey &pubkey, const CKeyMetadata &meta) diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index d497f5b56..2d05c82c1 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -402,6 +402,17 @@ public: bool AcceptToMemoryPool(bool fLimitFree=true, bool fRejectAbsurdFee=true); }; +enum class AddressGenerationError { + NoSuchAccount, + InvalidReceiverTypes, + ExistingAddressMismatch, + NoSaplingAddressForDiversifier +}; + +typedef std::variant< + std::pair, + AddressGenerationError> UAGenerationResult; + /** * A transaction with a bunch of additional info that only the owner cares about. * It includes any unrecorded transactions needed to link it back to the block chain. @@ -843,7 +854,8 @@ public: std::map mapSproutZKeyMetadata; std::map mapSaplingZKeyMetadata; - std::map mapUnifiedKeyMetadata; + std::map, ZcashdUnifiedSpendingKeyMetadata> mapUnifiedKeyMetadata; + std::map>> mapUnifiedAddressMetadata; typedef std::map MasterKeyMap; MasterKeyMap mapMasterKeys; @@ -1129,10 +1141,22 @@ public: std::optional> GenerateUnifiedSpendingKeyForAccount(libzcash::AccountId accountId); - bool AddUnifiedFullViewingKey(const libzcash::UnifiedFullViewingKey &ufvk); - bool LoadUnifiedFullViewingKey(const libzcash::UnifiedFullViewingKey &key); + //! Retrieves the UFVK derived from the wallet's mnemonic seed for the specified account. + std::optional> + GetUnifiedFullViewingKeyByAccount(libzcash::AccountId account); - void LoadUnifiedKeyMetadata(const ZcashdUnifiedSpendingKeyMetadata &meta); + //! Generate a new unified address for the specified account, diversifier, and + //! set of receiver types. + UAGenerationResult GenerateUnifiedAddress( + const libzcash::AccountId& accountId, + const libzcash::diversifier_index_t& j, + const std::set& receivers); + + bool AddUnifiedFullViewingKey(const libzcash::UnifiedFullViewingKey &ufvk); + bool LoadUnifiedFullViewingKey(const libzcash::UnifiedFullViewingKey &ufvk); + + void LoadUnifiedKeyMetadata(const ZcashdUnifiedSpendingKeyMetadata &skmeta); + bool LoadUnifiedAddressMetadata(const ZcashdUnifiedAddressMetadata &addrmeta); /** * Increment the next transaction order id diff --git a/src/wallet/walletdb.cpp b/src/wallet/walletdb.cpp index 6b0660810..72080dc19 100644 --- a/src/wallet/walletdb.cpp +++ b/src/wallet/walletdb.cpp @@ -235,6 +235,13 @@ bool CWalletDB::WriteUnifiedFullViewingKey(const libzcash::UnifiedFullViewingKey return Write(std::make_pair(std::string("unifiedfvk"), ufvkId), ufvk.Encode(Params())); } +bool CWalletDB::WriteUnifiedAddressMetadata(const ZcashdUnifiedAddressMetadata& addrmeta) +{ + nWalletDBUpdateCounter++; + auto ufvkId = addrmeta.GetKeyID(); + return Write(std::make_pair(std::string("unifiedaddrmeta"), ufvkId), addrmeta); +} + // // // @@ -684,6 +691,11 @@ ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, auto keymeta = ZcashdUnifiedSpendingKeyMetadata::Read(ssValue); pwallet->LoadUnifiedKeyMetadata(keymeta); } + else if (strType == "unifiedaddrmeta") + { + auto keymeta = ZcashdUnifiedAddressMetadata::Read(ssValue); + pwallet->LoadUnifiedAddressMetadata(keymeta); + } else if (strType == "pool") { int64_t nIndex; diff --git a/src/wallet/walletdb.h b/src/wallet/walletdb.h index 124a78f73..71f5404d2 100644 --- a/src/wallet/walletdb.h +++ b/src/wallet/walletdb.h @@ -251,14 +251,14 @@ class ZcashdUnifiedAddressMetadata { private: libzcash::UFVKId ufvkId; libzcash::diversifier_index_t diversifierIndex; - std::vector receiverTypes; + std::set receiverTypes; ZcashdUnifiedAddressMetadata() {} public: ZcashdUnifiedAddressMetadata( libzcash::UFVKId ufvkId, libzcash::diversifier_index_t diversifierIndex, - std::vector receiverTypes): + std::set receiverTypes): ufvkId(ufvkId), diversifierIndex(diversifierIndex), receiverTypes(receiverTypes) {} libzcash::UFVKId GetKeyID() const { @@ -267,7 +267,7 @@ public: libzcash::diversifier_index_t GetDiversifierIndex() const { return diversifierIndex; } - const std::vector& GetReceiverTypes() const { + const std::set& GetReceiverTypes() const { return receiverTypes; } @@ -282,7 +282,7 @@ public: READWRITE(serReceiverTypes); receiverTypes.clear(); for (ReceiverTypeSer r : serReceiverTypes) - receiverTypes.push_back(r.t); + receiverTypes.insert(r.t); } else { std::vector serReceiverTypes; for (libzcash::ReceiverType r : receiverTypes) @@ -381,6 +381,7 @@ public: bool WriteUnifiedSpendingKeyMetadata(const ZcashdUnifiedSpendingKeyMetadata& keymeta); bool WriteUnifiedFullViewingKey(const libzcash::UnifiedFullViewingKey& ufvk); + bool WriteUnifiedAddressMetadata(const ZcashdUnifiedAddressMetadata& addrmeta); static void IncrementUpdateCounter(); static unsigned int GetUpdateCounter(); diff --git a/src/zcash/address/unified.cpp b/src/zcash/address/unified.cpp index 3d328467e..1bf4b13b5 100644 --- a/src/zcash/address/unified.cpp +++ b/src/zcash/address/unified.cpp @@ -12,6 +12,14 @@ using namespace libzcash; // Unified Keys // +bool libzcash::HasShielded(const std::set& receiverTypes) { + auto has_shielded = [](ReceiverType r) { + // TODO: update this as support for new protocols is added. + return r == ReceiverType::Sapling; + }; + return std::find_if(receiverTypes.begin(), receiverTypes.end(), has_shielded) != receiverTypes.end(); +} + std::optional ZcashdUnifiedSpendingKey::ForAccount( const HDSeed& seed, uint32_t bip44CoinType, @@ -63,10 +71,17 @@ ZcashdUnifiedFullViewingKey ZcashdUnifiedFullViewingKey::FromUnifiedFullViewingK return result; } -std::optional ZcashdUnifiedFullViewingKey::Address(diversifier_index_t j) const { - UnifiedAddress ua; +std::optional ZcashdUnifiedFullViewingKey::Address( + const diversifier_index_t& j, + const std::set& receiverTypes) const +{ + if (!HasShielded(receiverTypes)) { + throw std::runtime_error("Unified addresses must include a shielded receiver."); + } - if (saplingKey.has_value()) { + UnifiedAddress ua; + if (saplingKey.has_value() && + std::find(receiverTypes.begin(), receiverTypes.end(), ReceiverType::Sapling) != receiverTypes.end()) { auto saplingAddress = saplingKey.value().Address(j); if (saplingAddress.has_value()) { ua.AddReceiver(saplingAddress.value()); @@ -75,7 +90,8 @@ std::optional ZcashdUnifiedFullViewingKey::Address(diversifier_i } } - if (transparentKey.has_value()) { + if (transparentKey.has_value() && + std::find(receiverTypes.begin(), receiverTypes.end(), ReceiverType::P2PKH) != receiverTypes.end()) { const auto& tkey = transparentKey.value(); auto childIndex = j.ToTransparentChildIndex(); if (!childIndex.has_value()) return std::nullopt; @@ -98,12 +114,20 @@ std::optional ZcashdUnifiedFullViewingKey::Address(diversifier_i return ua; } -std::pair ZcashdUnifiedFullViewingKey::FindAddress(diversifier_index_t j) const { - auto addr = Address(j); +std::pair ZcashdUnifiedFullViewingKey::FindAddress( + const diversifier_index_t& j, + const std::set& receiverTypes) const { + diversifier_index_t j0(j); + auto addr = Address(j0, receiverTypes); while (!addr.has_value()) { - if (!j.increment()) + if (!j0.increment()) throw std::runtime_error(std::string(__func__) + ": diversifier index overflow.");; - addr = Address(j); + addr = Address(j0, receiverTypes); } - return std::make_pair(addr.value(), j); + return std::make_pair(addr.value(), j0); +} + +std::pair ZcashdUnifiedFullViewingKey::FindAddress( + const diversifier_index_t& j) const { + return FindAddress(j, {ReceiverType::P2PKH, ReceiverType::Sapling, ReceiverType::Orchard}); } diff --git a/src/zcash/address/unified.h b/src/zcash/address/unified.h index c8cce33cf..53a4755cb 100644 --- a/src/zcash/address/unified.h +++ b/src/zcash/address/unified.h @@ -17,6 +17,12 @@ enum class ReceiverType: uint32_t { Orchard = 0x03 }; +/** + * Test whether the specified list of receiver types contains a + * shielded receiver type + */ +bool HasShielded(const std::set& receiverTypes); + class ZcashdUnifiedSpendingKey; // prototypes for the classes handling ZIP-316 encoding (in Address.hpp) @@ -51,9 +57,33 @@ public: return saplingKey; } - std::optional Address(diversifier_index_t j) const; + /** + * Creates a new unified address having the specified receiver types, at the specified + * diversifier index, unless the diversifer index would generate an invalid receiver. + * Returns `std::nullopt` if the diversifier index does not produce a valid receiver + * for one or more of the specified receiver types; under this circumstance, the caller + * should usually try successive diversifier indices until the operation returns a + * non-null value. + * + * This method will throw if `receiverTypes` does not include a shielded receiver type. + */ + std::optional Address( + const diversifier_index_t& j, + const std::set& receiverTypes) const; - std::pair FindAddress(diversifier_index_t j) const; + /** + * Find the smallest diversifier index >= `j` such that it generates a valid + * unified address according to the conditions specified in the documentation + * for the `Address` method above, and returns the newly created address along + * with the diversifier index used to produce it. + * + * This method will throw if `receiverTypes` does not include a shielded receiver type. + */ + std::pair FindAddress( + const diversifier_index_t& j, + const std::set& receiverTypes) const; + + std::pair FindAddress(const diversifier_index_t& j) const; }; /**