Consistently group Sapling addresses by IVK for every source.

This commit is contained in:
Kris Nuttycombe 2021-10-07 19:58:31 -06:00
parent 2221bf5484
commit 23507899a3
2 changed files with 77 additions and 72 deletions

View File

@ -448,31 +448,83 @@ UniValue listaddresses(const UniValue& params, bool fHelp)
}
}
// inner function that groups Sapling addresses by IVK for use in all sources
// that can contain Sapling addresses
auto add_sapling = [&](
const std::set<SaplingPaymentAddress>& addrs,
const PaymentAddressSource source,
UniValue& entry
) {
bool hasData = false;
std::map<SaplingIncomingViewingKey, std::vector<SaplingPaymentAddress>> ivkAddrs;
for (const SaplingPaymentAddress& addr : addrs) {
if (GetSourceForPaymentAddress(pwalletMain)(addr) == source) {
SaplingIncomingViewingKey ivkRet;
if (pwalletMain->GetSaplingIncomingViewingKey(addr, ivkRet)) {
ivkAddrs[ivkRet].push_back(addr);
}
}
}
{
UniValue ivk_groups(UniValue::VARR);
for (const auto& [ivk, addrs] : ivkAddrs) {
UniValue sapling_addrs(UniValue::VARR);
for (const SaplingPaymentAddress& addr : addrs) {
sapling_addrs.push_back(keyIO.EncodePaymentAddress(addr));
}
UniValue sapling_obj(UniValue::VOBJ);
if (source == LegacyHDSeed) {
std::string hdKeypath = pwalletMain->mapSaplingZKeyMetadata[ivk].hdKeypath;
std::optional<unsigned long> accountId = libzcash::ParseZip32KeypathAccount(hdKeypath);
if (accountId.has_value()) {
sapling_obj.pushKV("zip32_account_id", (uint64_t) accountId.value());
}
}
sapling_obj.pushKV("addresses", sapling_addrs);
ivk_groups.push_back(sapling_obj);
}
if (!ivk_groups.empty()) {
entry.pushKV("sapling", ivk_groups);
hasData = true;
}
}
return hasData;
};
/// imported source
{
UniValue entry(UniValue::VOBJ);
entry.pushKV("source", "imported");
bool hasData = false;
{
UniValue imported_sapling_addrs(UniValue::VARR);
for (const SaplingPaymentAddress& addr : saplingAddresses) {
UniValue imported_sprout_addrs(UniValue::VARR);
for (const SproutPaymentAddress& addr : sproutAddresses) {
if (GetSourceForPaymentAddress(pwalletMain)(addr) == Imported) {
imported_sapling_addrs.push_back(keyIO.EncodePaymentAddress(addr));
imported_sprout_addrs.push_back(keyIO.EncodePaymentAddress(addr));
}
}
if (!imported_sapling_addrs.empty()) {
UniValue imported_sapling_obj(UniValue::VOBJ);
imported_sapling_obj.pushKV("addresses", imported_sapling_addrs);
UniValue imported_sapling(UniValue::VARR);
imported_sapling.push_back(imported_sapling_obj);
entry.pushKV("sapling", imported_sapling);
if (!imported_sprout_addrs.empty()) {
UniValue imported_sprout(UniValue::VOBJ);
imported_sprout.pushKV("addresses", imported_sprout_addrs);
entry.pushKV("sprout", imported_sprout);
hasData = true;
}
}
hasData |= add_sapling(saplingAddresses, Imported, entry);
if (hasData) {
ret.push_back(entry);
}
@ -513,24 +565,7 @@ UniValue listaddresses(const UniValue& params, bool fHelp)
}
}
{
UniValue watchonly_sapling_addrs(UniValue::VARR);
for (const SaplingPaymentAddress& addr : saplingAddresses) {
if (!HaveSpendingKeyForPaymentAddress(pwalletMain)(addr)) {
watchonly_sapling_addrs.push_back(keyIO.EncodePaymentAddress(addr));
}
}
if (!watchonly_sapling_addrs.empty()) {
UniValue watchonly_sapling_obj(UniValue::VOBJ);
watchonly_sapling_obj.pushKV("addresses", watchonly_sapling_addrs);
UniValue watchonly_sapling(UniValue::VARR);
watchonly_sapling.push_back(watchonly_sapling_obj);
entry.pushKV("sapling", watchonly_sapling);
hasData = true;
}
}
hasData |= add_sapling(saplingAddresses, ImportedWatchOnly, entry);
if (hasData) {
ret.push_back(entry);
@ -541,48 +576,12 @@ UniValue listaddresses(const UniValue& params, bool fHelp)
{
UniValue entry(UniValue::VOBJ);
entry.pushKV("source", "legacy_hdseed");
bool hasData = false;
std::map<SaplingIncomingViewingKey, std::vector<SaplingPaymentAddress>> ivkAddrs;
for (const SaplingPaymentAddress& addr : saplingAddresses) {
if (GetSourceForPaymentAddress(pwalletMain)(addr) == LegacyHDSeed) {
SaplingIncomingViewingKey ivkRet;
if (pwalletMain->GetSaplingIncomingViewingKey(addr, ivkRet)) {
ivkAddrs[ivkRet].push_back(addr);
}
}
}
{
UniValue legacy_sapling(UniValue::VARR);
for (const auto& [ivk, addrs] : ivkAddrs) {
UniValue legacy_sapling_addrs(UniValue::VARR);
for (const SaplingPaymentAddress& addr : addrs) {
legacy_sapling_addrs.push_back(keyIO.EncodePaymentAddress(addr));
}
// this is known to be nonempty from the GetSourceForPaymentAddress check.
std::string hdKeypath = pwalletMain->mapSaplingZKeyMetadata[ivk].hdKeypath;
std::optional<unsigned long> accountId = libzcash::ParseZip32KeypathAccount(hdKeypath);
UniValue legacy_sapling_obj(UniValue::VOBJ);
if (accountId.has_value()) {
legacy_sapling_obj.pushKV("zip32_account_id", (uint64_t) accountId.value());
}
legacy_sapling_obj.pushKV("addresses", legacy_sapling_addrs);
legacy_sapling.push_back(legacy_sapling_obj);
}
if (!legacy_sapling.empty()) {
entry.pushKV("sapling", legacy_sapling);
hasData = true;
}
}
bool hasData = add_sapling(saplingAddresses, LegacyHDSeed, entry);
if (hasData) {
ret.push_back(entry);
}
};
}
return ret;

View File

@ -701,14 +701,20 @@ BOOST_AUTO_TEST_CASE(rpc_wallet_z_importexport)
sproutCountMatch = (sprout_addrs.size() == n1);
}
if (source.get_str() == "imported") {
auto sapling_obj = find_value(a.get_obj(), "sapling").get_array()[0];
auto sapling_addrs = find_value(sapling_obj, "addresses").get_array();
saplingSpendingKeyMatch = (sapling_addrs.size() == n1 / 2);
int addr_count = 0;
for (auto sapling_obj : find_value(a.get_obj(), "sapling").get_array().getValues()) {
auto sapling_addrs = find_value(sapling_obj, "addresses").get_array();
addr_count += sapling_addrs.size();
}
saplingSpendingKeyMatch = (addr_count == n1 / 2);
}
if (source.get_str() == "imported_watchonly") {
auto sapling_obj = find_value(a.get_obj(), "sapling").get_array()[0];
auto sapling_addrs = find_value(sapling_obj, "addresses").get_array();
saplingIVKMatch = (sapling_addrs.size() == n1 / 2);
int addr_count = 0;
for (auto sapling_obj : find_value(a.get_obj(), "sapling").get_array().getValues()) {
auto sapling_addrs = find_value(sapling_obj, "addresses").get_array();
addr_count += sapling_addrs.size();
}
saplingIVKMatch = (addr_count == n1 / 2);
}
}
BOOST_CHECK(sproutCountMatch);