diff --git a/src/key_io.cpp b/src/key_io.cpp index f118e3625..d2f7b2ded 100644 --- a/src/key_io.cpp +++ b/src/key_io.cpp @@ -270,36 +270,59 @@ std::string EncodePaymentAddress(const libzcash::PaymentAddress& zaddr) return boost::apply_visitor(PaymentAddressEncoder(Params()), zaddr); } -libzcash::PaymentAddress DecodePaymentAddress(const std::string& str) +template +T1 DecodeAny( + const std::string& str, + std::pair sprout, + boost::optional> sapling) { std::vector data; if (DecodeBase58Check(str, data)) { - const std::vector& zaddr_prefix = Params().Base58Prefix(CChainParams::ZCPAYMENT_ADDRRESS); - if ((data.size() == libzcash::SerializedSproutPaymentAddressSize + zaddr_prefix.size()) && - std::equal(zaddr_prefix.begin(), zaddr_prefix.end(), data.begin())) { - CSerializeData serialized(data.begin() + zaddr_prefix.size(), data.end()); + const std::vector& prefix = Params().Base58Prefix(sprout.first); + if ((data.size() == sprout.second + prefix.size()) && + std::equal(prefix.begin(), prefix.end(), data.begin())) { + CSerializeData serialized(data.begin() + prefix.size(), data.end()); CDataStream ss(serialized, SER_NETWORK, PROTOCOL_VERSION); - libzcash::SproutPaymentAddress ret; + T2 ret; ss >> ret; + memory_cleanse(serialized.data(), serialized.size()); + memory_cleanse(data.data(), data.size()); return ret; } } - data.clear(); - auto bech = bech32::Decode(str); - if (bech.first == Params().Bech32HRP(CChainParams::SAPLING_PAYMENT_ADDRESS) && - bech.second.size() == ConvertedSaplingPaymentAddressSize) { - // Bech32 decoding - data.reserve((bech.second.size() * 5) / 8); - if (ConvertBits<5, 8, false>([&](unsigned char c) { data.push_back(c); }, bech.second.begin(), bech.second.end())) { - CDataStream ss(data, SER_NETWORK, PROTOCOL_VERSION); - libzcash::SaplingPaymentAddress ret; - ss >> ret; - return ret; + + if (sapling) { + data.clear(); + auto bech = bech32::Decode(str); + if (bech.first == Params().Bech32HRP(sapling.get().first) && + bech.second.size() == sapling.get().second) { + // Bech32 decoding + data.reserve((bech.second.size() * 5) / 8); + if (ConvertBits<5, 8, false>([&](unsigned char c) { data.push_back(c); }, bech.second.begin(), bech.second.end())) { + CDataStream ss(data, SER_NETWORK, PROTOCOL_VERSION); + T3 ret; + ss >> ret; + memory_cleanse(data.data(), data.size()); + return ret; + } } } + + memory_cleanse(data.data(), data.size()); return libzcash::InvalidEncoding(); } +libzcash::PaymentAddress DecodePaymentAddress(const std::string& str) +{ + return DecodeAny( + str, + std::make_pair(CChainParams::ZCPAYMENT_ADDRRESS, libzcash::SerializedSproutPaymentAddressSize), + std::make_pair(CChainParams::SAPLING_PAYMENT_ADDRESS, ConvertedSaplingPaymentAddressSize) + ); +} + bool IsValidPaymentAddressString(const std::string& str) { return IsValidPaymentAddress(DecodePaymentAddress(str)); } @@ -311,22 +334,12 @@ std::string EncodeViewingKey(const libzcash::ViewingKey& vk) libzcash::ViewingKey DecodeViewingKey(const std::string& str) { - std::vector data; - if (DecodeBase58Check(str, data)) { - const std::vector& vk_prefix = Params().Base58Prefix(CChainParams::ZCVIEWING_KEY); - if ((data.size() == libzcash::SerializedSproutViewingKeySize + vk_prefix.size()) && - std::equal(vk_prefix.begin(), vk_prefix.end(), data.begin())) { - CSerializeData serialized(data.begin() + vk_prefix.size(), data.end()); - CDataStream ss(serialized, SER_NETWORK, PROTOCOL_VERSION); - libzcash::SproutViewingKey ret; - ss >> ret; - memory_cleanse(serialized.data(), serialized.size()); - memory_cleanse(data.data(), data.size()); - return ret; - } - } - memory_cleanse(data.data(), data.size()); - return libzcash::InvalidEncoding(); + return DecodeAny( + str, + std::make_pair(CChainParams::ZCVIEWING_KEY, libzcash::SerializedSproutViewingKeySize), + boost::none + ); } std::string EncodeSpendingKey(const libzcash::SpendingKey& zkey) @@ -336,34 +349,12 @@ std::string EncodeSpendingKey(const libzcash::SpendingKey& zkey) libzcash::SpendingKey DecodeSpendingKey(const std::string& str) { - std::vector data; - if (DecodeBase58Check(str, data)) { - const std::vector& zkey_prefix = Params().Base58Prefix(CChainParams::ZCSPENDING_KEY); - if ((data.size() == libzcash::SerializedSproutSpendingKeySize + zkey_prefix.size()) && - std::equal(zkey_prefix.begin(), zkey_prefix.end(), data.begin())) { - CSerializeData serialized(data.begin() + zkey_prefix.size(), data.end()); - CDataStream ss(serialized, SER_NETWORK, PROTOCOL_VERSION); - libzcash::SproutSpendingKey ret; - ss >> ret; - memory_cleanse(serialized.data(), serialized.size()); - memory_cleanse(data.data(), data.size()); - return ret; - } - } - data.clear(); - auto bech = bech32::Decode(str); - if (bech.first == Params().Bech32HRP(CChainParams::SAPLING_EXTENDED_SPEND_KEY) && - bech.second.size() == ConvertedSaplingExtendedSpendingKeySize) { - // Bech32 decoding - data.reserve((bech.second.size() * 5) / 8); - if (ConvertBits<5, 8, false>([&](unsigned char c) { data.push_back(c); }, bech.second.begin(), bech.second.end())) { - CDataStream ss(data, SER_NETWORK, PROTOCOL_VERSION); - libzcash::SaplingExtendedSpendingKey ret; - ss >> ret; - memory_cleanse(data.data(), data.size()); - return ret; - } - } - memory_cleanse(data.data(), data.size()); - return libzcash::InvalidEncoding(); + + return DecodeAny( + str, + std::make_pair(CChainParams::ZCSPENDING_KEY, libzcash::SerializedSproutSpendingKeySize), + std::make_pair(CChainParams::SAPLING_EXTENDED_SPEND_KEY, ConvertedSaplingExtendedSpendingKeySize) + ); }