From 9220ebb0dedfc588f6f450dd59c82eb390b076aa Mon Sep 17 00:00:00 2001 From: Kris Nuttycombe Date: Sun, 20 Mar 2022 12:27:12 -0600 Subject: [PATCH] Make CWallet::DefaultReceiverTypes height-dependent. Addresses https://github.com/zcash/zcash/pull/5700#discussion_r830538538 --- src/wallet/rpcwallet.cpp | 6 ++++-- src/wallet/wallet.cpp | 17 ++++++++++++++--- src/wallet/wallet.h | 6 +++--- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/wallet/rpcwallet.cpp b/src/wallet/rpcwallet.cpp index 3de78ee15..61ad6d051 100644 --- a/src/wallet/rpcwallet.cpp +++ b/src/wallet/rpcwallet.cpp @@ -3110,7 +3110,9 @@ UniValue z_getaddressforaccount(const UniValue& params, bool fHelp) throw JSONRPCError(RPC_WALLET_ENCRYPTION_FAILED, "Error: the Orchard wallet experimental extensions are disabled."); } - LOCK(pwalletMain->cs_wallet); + // cs_main is required for obtaining the current height, for + // CWallet::DefaultReceiverTypes + LOCK2(cs_main, pwalletMain->cs_wallet); int64_t accountInt = params[0].get_int64(); if (accountInt < 0 || accountInt >= ZCASH_LEGACY_ACCOUNT) { @@ -3136,7 +3138,7 @@ UniValue z_getaddressforaccount(const UniValue& params, bool fHelp) } if (receivers.empty()) { // Default is the best and second-best shielded pools, and the transparent pool. - receivers = CWallet::DefaultReceiverTypes(); + receivers = CWallet::DefaultReceiverTypes(chainActive.Height()); } std::optional j = std::nullopt; diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index 5c4c0f4cd..c9446f94d 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -61,7 +61,9 @@ const char * DEFAULT_WALLET_DAT = "wallet.dat"; */ CFeeRate CWallet::minTxFee = CFeeRate(DEFAULT_TRANSACTION_MINFEE); -std::set CWallet::DefaultReceiverTypes() { +std::set CWallet::DefaultReceiverTypes(int nHeight) { + // For now, just ignore the height information because the default + // is always the same. return {ReceiverType::P2PKH, ReceiverType::Sapling, ReceiverType::Orchard}; } @@ -860,7 +862,16 @@ std::pair CWallet::GetPaymentAddressForRecipient( const uint256& txid, const libzcash::RecipientAddress& recipient) const { + AssertLockHeld(cs_wallet); + auto self = this; + + auto nHeight = chainActive.Height(); + auto wtxPtr = mapWallet.find(txid); + if (wtxPtr != mapWallet.end()) { + nHeight = wtxPtr->second.GetDepthInMainChain(); + } + auto ufvk = self->GetUFVKForReceiver(RecipientAddressToReceiver(recipient)); std::pair defaultAddress = std::visit(match { [&](const CKeyID& addr) { @@ -909,7 +920,7 @@ std::pair CWallet::GetPaymentAddressForRecipient( if (j.value().second) { // std::get is safe here because we know we have a valid Sapling diversifier index auto defaultUA = std::get>( - ufvk->Address(j.value().first, CWallet::DefaultReceiverTypes())); + ufvk->Address(j.value().first, CWallet::DefaultReceiverTypes(nHeight))); return std::make_pair(PaymentAddress{defaultUA.first}, RecipientType::WalletExternalAddress); } else { return std::make_pair(PaymentAddress{addr}, RecipientType::WalletInternalAddress); @@ -932,7 +943,7 @@ std::pair CWallet::GetPaymentAddressForRecipient( if (j.has_value()) { if (j.value().second) { // Attempt to reproduce the original unified address - auto genResult = ufvk->Address(j.value().first, CWallet::DefaultReceiverTypes()); + auto genResult = ufvk->Address(j.value().first, CWallet::DefaultReceiverTypes(nHeight)); auto defaultUA = std::get_if>(&genResult); if (defaultUA != nullptr) { return std::make_pair(PaymentAddress{defaultUA->first}, RecipientType::WalletExternalAddress); diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index 5ea51b40e..31b1922bb 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -1741,10 +1741,10 @@ public: static CAmount GetRequiredFee(unsigned int nTxBytes); /** - * The current set of default receiver types used when the wallet generates - * unified addresses + * The set of default receiver types used when the wallet generates + * unified addresses, as of the specified chain height. */ - static std::set DefaultReceiverTypes(); + static std::set DefaultReceiverTypes(int nHeight); private: bool NewKeyPool();