Auto merge of #4328 - oxarbitrage:issue3344, r=str4d

Reduce duplication in key_io decode

Attempt to fix https://github.com/zcash/zcash/issues/3344 by adding a new `DecodeAny` template function with arguments that will handle all the cases. Then on each decoding this new function is called with the appropriate arguments resulting in reducing code duplication.

Some of the complexity(`boost::optional`) will be removed if `ViewingKey` is implemented for sapling.
This commit is contained in:
Homu 2020-02-17 21:50:45 +00:00
commit 5dde1a7702
1 changed files with 54 additions and 63 deletions

View File

@ -270,36 +270,59 @@ std::string EncodePaymentAddress(const libzcash::PaymentAddress& zaddr)
return boost::apply_visitor(PaymentAddressEncoder(Params()), zaddr); return boost::apply_visitor(PaymentAddressEncoder(Params()), zaddr);
} }
libzcash::PaymentAddress DecodePaymentAddress(const std::string& str) template<typename T1, typename T2, typename T3 = T2>
T1 DecodeAny(
const std::string& str,
std::pair<CChainParams::Base58Type, size_t> sprout,
boost::optional<std::pair<CChainParams::Bech32Type, size_t>> sapling)
{ {
std::vector<unsigned char> data; std::vector<unsigned char> data;
if (DecodeBase58Check(str, data)) { if (DecodeBase58Check(str, data)) {
const std::vector<unsigned char>& zaddr_prefix = Params().Base58Prefix(CChainParams::ZCPAYMENT_ADDRRESS); const std::vector<unsigned char>& prefix = Params().Base58Prefix(sprout.first);
if ((data.size() == libzcash::SerializedSproutPaymentAddressSize + zaddr_prefix.size()) && if ((data.size() == sprout.second + prefix.size()) &&
std::equal(zaddr_prefix.begin(), zaddr_prefix.end(), data.begin())) { std::equal(prefix.begin(), prefix.end(), data.begin())) {
CSerializeData serialized(data.begin() + zaddr_prefix.size(), data.end()); CSerializeData serialized(data.begin() + prefix.size(), data.end());
CDataStream ss(serialized, SER_NETWORK, PROTOCOL_VERSION); CDataStream ss(serialized, SER_NETWORK, PROTOCOL_VERSION);
libzcash::SproutPaymentAddress ret; T2 ret;
ss >> ret; ss >> ret;
memory_cleanse(serialized.data(), serialized.size());
memory_cleanse(data.data(), data.size());
return ret; return ret;
} }
} }
data.clear();
auto bech = bech32::Decode(str); if (sapling) {
if (bech.first == Params().Bech32HRP(CChainParams::SAPLING_PAYMENT_ADDRESS) && data.clear();
bech.second.size() == ConvertedSaplingPaymentAddressSize) { auto bech = bech32::Decode(str);
// Bech32 decoding if (bech.first == Params().Bech32HRP(sapling.get().first) &&
data.reserve((bech.second.size() * 5) / 8); bech.second.size() == sapling.get().second) {
if (ConvertBits<5, 8, false>([&](unsigned char c) { data.push_back(c); }, bech.second.begin(), bech.second.end())) { // Bech32 decoding
CDataStream ss(data, SER_NETWORK, PROTOCOL_VERSION); data.reserve((bech.second.size() * 5) / 8);
libzcash::SaplingPaymentAddress ret; if (ConvertBits<5, 8, false>([&](unsigned char c) { data.push_back(c); }, bech.second.begin(), bech.second.end())) {
ss >> ret; CDataStream ss(data, SER_NETWORK, PROTOCOL_VERSION);
return ret; T3 ret;
ss >> ret;
memory_cleanse(data.data(), data.size());
return ret;
}
} }
} }
memory_cleanse(data.data(), data.size());
return libzcash::InvalidEncoding(); return libzcash::InvalidEncoding();
} }
libzcash::PaymentAddress DecodePaymentAddress(const std::string& str)
{
return DecodeAny<libzcash::PaymentAddress,
libzcash::SproutPaymentAddress,
libzcash::SaplingPaymentAddress>(
str,
std::make_pair(CChainParams::ZCPAYMENT_ADDRRESS, libzcash::SerializedSproutPaymentAddressSize),
std::make_pair(CChainParams::SAPLING_PAYMENT_ADDRESS, ConvertedSaplingPaymentAddressSize)
);
}
bool IsValidPaymentAddressString(const std::string& str) { bool IsValidPaymentAddressString(const std::string& str) {
return IsValidPaymentAddress(DecodePaymentAddress(str)); return IsValidPaymentAddress(DecodePaymentAddress(str));
} }
@ -311,22 +334,12 @@ std::string EncodeViewingKey(const libzcash::ViewingKey& vk)
libzcash::ViewingKey DecodeViewingKey(const std::string& str) libzcash::ViewingKey DecodeViewingKey(const std::string& str)
{ {
std::vector<unsigned char> data; return DecodeAny<libzcash::ViewingKey,
if (DecodeBase58Check(str, data)) { libzcash::SproutViewingKey>(
const std::vector<unsigned char>& vk_prefix = Params().Base58Prefix(CChainParams::ZCVIEWING_KEY); str,
if ((data.size() == libzcash::SerializedSproutViewingKeySize + vk_prefix.size()) && std::make_pair(CChainParams::ZCVIEWING_KEY, libzcash::SerializedSproutViewingKeySize),
std::equal(vk_prefix.begin(), vk_prefix.end(), data.begin())) { boost::none
CSerializeData serialized(data.begin() + vk_prefix.size(), data.end()); );
CDataStream ss(serialized, SER_NETWORK, PROTOCOL_VERSION);
libzcash::SproutViewingKey ret;
ss >> ret;
memory_cleanse(serialized.data(), serialized.size());
memory_cleanse(data.data(), data.size());
return ret;
}
}
memory_cleanse(data.data(), data.size());
return libzcash::InvalidEncoding();
} }
std::string EncodeSpendingKey(const libzcash::SpendingKey& zkey) std::string EncodeSpendingKey(const libzcash::SpendingKey& zkey)
@ -336,34 +349,12 @@ std::string EncodeSpendingKey(const libzcash::SpendingKey& zkey)
libzcash::SpendingKey DecodeSpendingKey(const std::string& str) libzcash::SpendingKey DecodeSpendingKey(const std::string& str)
{ {
std::vector<unsigned char> data;
if (DecodeBase58Check(str, data)) { return DecodeAny<libzcash::SpendingKey,
const std::vector<unsigned char>& zkey_prefix = Params().Base58Prefix(CChainParams::ZCSPENDING_KEY); libzcash::SproutSpendingKey,
if ((data.size() == libzcash::SerializedSproutSpendingKeySize + zkey_prefix.size()) && libzcash::SaplingExtendedSpendingKey>(
std::equal(zkey_prefix.begin(), zkey_prefix.end(), data.begin())) { str,
CSerializeData serialized(data.begin() + zkey_prefix.size(), data.end()); std::make_pair(CChainParams::ZCSPENDING_KEY, libzcash::SerializedSproutSpendingKeySize),
CDataStream ss(serialized, SER_NETWORK, PROTOCOL_VERSION); std::make_pair(CChainParams::SAPLING_EXTENDED_SPEND_KEY, ConvertedSaplingExtendedSpendingKeySize)
libzcash::SproutSpendingKey ret; );
ss >> ret;
memory_cleanse(serialized.data(), serialized.size());
memory_cleanse(data.data(), data.size());
return ret;
}
}
data.clear();
auto bech = bech32::Decode(str);
if (bech.first == Params().Bech32HRP(CChainParams::SAPLING_EXTENDED_SPEND_KEY) &&
bech.second.size() == ConvertedSaplingExtendedSpendingKeySize) {
// Bech32 decoding
data.reserve((bech.second.size() * 5) / 8);
if (ConvertBits<5, 8, false>([&](unsigned char c) { data.push_back(c); }, bech.second.begin(), bech.second.end())) {
CDataStream ss(data, SER_NETWORK, PROTOCOL_VERSION);
libzcash::SaplingExtendedSpendingKey ret;
ss >> ret;
memory_cleanse(data.data(), data.size());
return ret;
}
}
memory_cleanse(data.data(), data.size());
return libzcash::InvalidEncoding();
} }