From 6402c589c6cb00185e14d4f8c4b6bb9acc981489 Mon Sep 17 00:00:00 2001 From: therealyingtong Date: Sat, 20 Jun 2020 15:35:16 +0800 Subject: [PATCH] Refactor SaplingNotePlaintext::decrypt Break up plaintext decryption into height-dependent and non-height-dependent parts. --- src/wallet/rpcwallet.cpp | 13 ++-- src/wallet/wallet.cpp | 110 +++++++++++++++++++++++++++--- src/wallet/wallet.h | 8 ++- src/zcash/Note.cpp | 144 +++++++++++++++++++++++++++------------ src/zcash/Note.hpp | 28 ++++++++ src/zcbenchmarks.cpp | 2 +- 6 files changed, 243 insertions(+), 62 deletions(-) diff --git a/src/wallet/rpcwallet.cpp b/src/wallet/rpcwallet.cpp index e45aa218e..08bf81662 100644 --- a/src/wallet/rpcwallet.cpp +++ b/src/wallet/rpcwallet.cpp @@ -3769,8 +3769,9 @@ UniValue z_viewtransaction(const UniValue& params, bool fHelp) auto op = res->second; auto wtxPrev = pwalletMain->mapWallet.at(op.hash); - // TODO: decide which height to use here instead of wtxPrev.nExpiryHeight - auto decrypted = wtxPrev.DecryptSaplingNote(Params().GetConsensus(), wtxPrev.nExpiryHeight, op).get(); + // We don't need to check the leadbyte here: if wtx exists in + // the wallet, it must have already passed the leadbyte check + auto decrypted = wtxPrev.DecryptSaplingNoteWithoutLeadByteCheck(op).get(); auto notePt = decrypted.first; auto pa = decrypted.second; @@ -3798,16 +3799,16 @@ UniValue z_viewtransaction(const UniValue& params, bool fHelp) SaplingPaymentAddress pa; bool isOutgoing; - // TODO: decide which height to use here instead of wtx.nExpiryHeight - auto decrypted = wtx.DecryptSaplingNote(Params().GetConsensus(), wtx.nExpiryHeight, op); + // We don't need to check the leadbyte here: if wtx exists in + // the wallet, it must have already passed the leadbyte check + auto decrypted = wtx.DecryptSaplingNoteWithoutLeadByteCheck(op); if (decrypted) { notePt = decrypted->first; pa = decrypted->second; isOutgoing = false; } else { // Try recovering the output - // TODO: decide which height to use here instead of wtxPrev.nExpiryHeight - auto recovered = wtx.RecoverSaplingNote(Params().GetConsensus(), wtx.nExpiryHeight, op, ovks); + auto recovered = wtx.RecoverSaplingNoteWithoutLeadByteCheck(op, ovks); if (recovered) { notePt = recovered->first; pa = recovered->second; diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index 0ccea2ae0..d43881256 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -1496,8 +1496,15 @@ void CWallet::UpdateSaplingNullifierNoteMapWithTx(CWalletTx& wtx) { auto extfvk = mapSaplingFullViewingKeys.at(nd.ivk); OutputDescription output = wtx.vShieldedOutput[op.n]; - // TODO: decide which height to use here instead of wtx.nExpiryHeight - auto optPlaintext = SaplingNotePlaintext::decrypt(Params().GetConsensus(), wtx.nExpiryHeight, output.encCiphertext, nd.ivk, output.ephemeralKey, output.cmu); + auto optDeserialized = SaplingNotePlaintext::attempt_sapling_enc_decryption_deserialization(output.encCiphertext, nd.ivk, output.ephemeralKey); + + if (!optDeserialized) { + // The transaction would not have entered the wallet unless + // its plaintest had been succesfully decrypted previously. + assert(false); + } + + auto optPlaintext = SaplingNotePlaintext::plaintext_checks_without_height(*optDeserialized, nd.ivk, output.ephemeralKey, output.cmu); if (!optPlaintext) { // An item in mapSaplingNoteData must have already been successfully decrypted, // otherwise the item would not exist in the first place. @@ -1716,7 +1723,7 @@ bool CWallet::AddToWalletIfInvolvingMe(const CTransaction& tx, const CBlock* pbl bool fExisted = mapWallet.count(tx.GetHash()) != 0; if (fExisted && !fUpdate) return false; auto sproutNoteData = FindMySproutNotes(tx); - auto saplingNoteDataAndAddressesToAdd = FindMySaplingNotes(tx); + auto saplingNoteDataAndAddressesToAdd = FindMySaplingNotes(tx, chainActive.Height()); auto saplingNoteData = saplingNoteDataAndAddressesToAdd.first; auto addressesToAdd = saplingNoteDataAndAddressesToAdd.second; for (const auto &addressToAdd : addressesToAdd) { @@ -1890,7 +1897,7 @@ mapSproutNoteData_t CWallet::FindMySproutNotes(const CTransaction &tx) const * the result of FindMySaplingNotes (for the addresses available at the time) will * already have been cached in CWalletTx.mapSaplingNoteData. */ -std::pair CWallet::FindMySaplingNotes(const CTransaction &tx) const +std::pair CWallet::FindMySaplingNotes(const CTransaction &tx, int height) const { LOCK(cs_KeyStore); uint256 hash = tx.GetHash(); @@ -1904,8 +1911,7 @@ std::pair CWallet::FindMySap for (auto it = mapSaplingFullViewingKeys.begin(); it != mapSaplingFullViewingKeys.end(); ++it) { SaplingIncomingViewingKey ivk = it->first; - // TODO: decide which height to use here instead of wtx.nExpiryHeight - auto result = SaplingNotePlaintext::decrypt(Params().GetConsensus(), tx.nExpiryHeight, output.encCiphertext, ivk, output.ephemeralKey, output.cmu); + auto result = SaplingNotePlaintext::decrypt(Params().GetConsensus(), height, output.encCiphertext, ivk, output.ephemeralKey, output.cmu); if (!result) { continue; } @@ -2331,6 +2337,41 @@ boost::optional> CWalletTx::DecryptSaplingNoteWithoutLeadByteCheck(SaplingOutPoint op) const +{ + // Check whether we can decrypt this SaplingOutPoint + if (this->mapSaplingNoteData.count(op) == 0) { + return boost::none; + } + + auto output = this->vShieldedOutput[op.n]; + auto nd = this->mapSaplingNoteData.at(op); + + auto optDeserialized = SaplingNotePlaintext::attempt_sapling_enc_decryption_deserialization(output.encCiphertext, nd.ivk, output.ephemeralKey); + + if (!optDeserialized) { + // The transaction would not have entered the wallet unless + // its plaintest had been succesfully decrypted previously. + assert(false); + } + + auto maybe_pt = SaplingNotePlaintext::plaintext_checks_without_height( + *optDeserialized, + nd.ivk, + output.ephemeralKey, + output.cmu); + assert(static_cast(maybe_pt)); + auto notePt = maybe_pt.get(); + + auto maybe_pa = nd.ivk.address(notePt.d); + assert(static_cast(maybe_pa)); + auto pa = maybe_pa.get(); + + return std::make_pair(notePt, pa); +} + boost::optional> CWalletTx::RecoverSaplingNote(const Consensus::Params& params, int height, SaplingOutPoint op, std::set& ovks) const @@ -2366,6 +2407,47 @@ boost::optional> CWalletTx::RecoverSaplingNoteWithoutLeadByteCheck(SaplingOutPoint op, std::set& ovks) const +{ + auto output = this->vShieldedOutput[op.n]; + + for (auto ovk : ovks) { + auto outPt = SaplingOutgoingPlaintext::decrypt( + output.outCiphertext, + ovk, + output.cv, + output.cmu, + output.ephemeralKey); + if (!outPt) { + continue; + } + + auto optDeserialized = SaplingNotePlaintext::attempt_sapling_enc_decryption_deserialization(output.encCiphertext, output.ephemeralKey, outPt->esk, outPt->pk_d); + + if (!optDeserialized) { + // The transaction would not have entered the wallet unless + // its plaintest had been succesfully decrypted previously. + assert(false); + } + + auto maybe_pt = SaplingNotePlaintext::plaintext_checks_without_height( + *optDeserialized, + output.ephemeralKey, + outPt->esk, + outPt->pk_d, + output.cmu); + assert(static_cast(maybe_pt)); + auto notePt = maybe_pt.get(); + + return std::make_pair(notePt, SaplingPaymentAddress(notePt.d, outPt->pk_d)); + } + + // Couldn't recover with any of the provided OutgoingViewingKeys + return boost::none; +} + int64_t CWalletTx::GetTxTime() const { int64_t n = nTimeSmart; @@ -4982,11 +5064,17 @@ void CWallet::GetFilteredNotes( SaplingOutPoint op = pair.first; SaplingNoteData nd = pair.second; - // TODO: decide which height to use here instead of wtx.nExpiryHeight - auto maybe_pt = SaplingNotePlaintext::decrypt( - Params().GetConsensus(), - wtx.nExpiryHeight, - wtx.vShieldedOutput[op.n].encCiphertext, + auto optDeserialized = SaplingNotePlaintext::attempt_sapling_enc_decryption_deserialization(wtx.vShieldedOutput[op.n].encCiphertext, nd.ivk, wtx.vShieldedOutput[op.n].ephemeralKey); + + if (!optDeserialized) { + // The transaction would not have entered the wallet unless + // its plaintest had been succesfully decrypted previously. + assert(false); + } + // We don't need to check the leadbyte here: if wtx exists in + // the wallet, it must have already passed the leadbyte check + auto maybe_pt = SaplingNotePlaintext::plaintext_checks_without_height( + *optDeserialized, nd.ivk, wtx.vShieldedOutput[op.n].ephemeralKey, wtx.vShieldedOutput[op.n].cmu); diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index d0696afd6..4d6bcb95a 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -567,10 +567,16 @@ public: boost::optional> DecryptSaplingNote(const Consensus::Params& params, int height, SaplingOutPoint op) const; + boost::optional> DecryptSaplingNoteWithoutLeadByteCheck(SaplingOutPoint op) const; boost::optional> RecoverSaplingNote(const Consensus::Params& params, int height, SaplingOutPoint op, std::set& ovks) const; + boost::optional> RecoverSaplingNoteWithoutLeadByteCheck(SaplingOutPoint op, std::set& ovks) const; //! filter decides which addresses will count towards the debit CAmount GetDebit(const isminefilter& filter) const; @@ -1221,7 +1227,7 @@ public: const uint256& hSig, uint8_t n) const; mapSproutNoteData_t FindMySproutNotes(const CTransaction& tx) const; - std::pair FindMySaplingNotes(const CTransaction& tx) const; + std::pair FindMySaplingNotes(const CTransaction& tx, int height) const; bool IsSproutNullifierFromMe(const uint256& nullifier) const; bool IsSaplingNullifierFromMe(const uint256& nullifier) const; diff --git a/src/zcash/Note.cpp b/src/zcash/Note.cpp index 76fcb1be2..08da32ed6 100644 --- a/src/zcash/Note.cpp +++ b/src/zcash/Note.cpp @@ -209,34 +209,40 @@ boost::optional SaplingNotePlaintext::decrypt( const uint256 &cmu ) { - auto pt = AttemptSaplingEncDecryption(ciphertext, ivk, epk); - if (!pt) { - return boost::none; - } - - // Deserialize from the plaintext - SaplingNotePlaintext ret; - CDataStream ss(SER_NETWORK, PROTOCOL_VERSION); - ss << pt.get(); - ss >> ret; - assert(ss.size() == 0); - - // Check leadbyte is allowed at block height - if (!plaintext_version_is_valid(params, height, ret.leadByte)) { + auto ret = attempt_sapling_enc_decryption_deserialization(ciphertext, ivk, epk); + + if (!ret) { return boost::none; + } else { + const SaplingNotePlaintext plaintext = *ret; + + // Check leadbyte is allowed at block height + if (!plaintext_version_is_valid(params, height, plaintext.leadByte)) { + return boost::none; + } + + return plaintext_checks_without_height(plaintext, ivk, epk, cmu); } +} +boost::optional SaplingNotePlaintext::plaintext_checks_without_height( + const SaplingNotePlaintext &plaintext, + const uint256 &ivk, + const uint256 &epk, + const uint256 &cmu +) +{ uint256 pk_d; - if (!librustzcash_ivk_to_pkd(ivk.begin(), ret.d.data(), pk_d.begin())) { + if (!librustzcash_ivk_to_pkd(ivk.begin(), plaintext.d.data(), pk_d.begin())) { return boost::none; } uint256 cmu_expected; - uint256 rcm = ret.rcm(); + uint256 rcm = plaintext.rcm(); if (!librustzcash_sapling_compute_cm( - ret.d.data(), + plaintext.d.data(), pk_d.begin(), - ret.value(), + plaintext.value(), rcm.begin(), cmu_expected.begin() )) @@ -248,12 +254,12 @@ boost::optional SaplingNotePlaintext::decrypt( return boost::none; } - if (ret.leadByte == 0x02) { + if (plaintext.leadByte == 0x02) { // ZIP 212: Check that epk is consistent to prevent against linkability // attacks without relying on the soundness of the SNARK. uint256 expected_epk; - uint256 esk = ret.generate_esk(); - if (!librustzcash_sapling_ka_derivepublic(ret.d.data(), esk.begin(), expected_epk.begin())) { + uint256 esk = plaintext.generate_esk(); + if (!librustzcash_sapling_ka_derivepublic(plaintext.d.data(), esk.begin(), expected_epk.begin())) { return boost::none; } if (expected_epk != epk) { @@ -261,7 +267,29 @@ boost::optional SaplingNotePlaintext::decrypt( } } - return ret; + return plaintext; +} + +boost::optional SaplingNotePlaintext::attempt_sapling_enc_decryption_deserialization( + const SaplingEncCiphertext &ciphertext, + const uint256 &ivk, + const uint256 &epk +) +{ + auto encPlaintext = AttemptSaplingEncDecryption(ciphertext, ivk, epk); + + if (!encPlaintext) { + return boost::none; + }; + + // Deserialize from the plaintext + SaplingNotePlaintext plaintext; + CDataStream ss(SER_NETWORK, PROTOCOL_VERSION); + ss << encPlaintext.get(); + ss >> plaintext; + assert(ss.size() == 0); + + return plaintext; } boost::optional SaplingNotePlaintext::decrypt( @@ -274,26 +302,33 @@ boost::optional SaplingNotePlaintext::decrypt( const uint256 &cmu ) { - auto pt = AttemptSaplingEncDecryption(ciphertext, epk, esk, pk_d); - if (!pt) { - return boost::none; - } - - // Deserialize from the plaintext - SaplingNotePlaintext ret; - CDataStream ss(SER_NETWORK, PROTOCOL_VERSION); - ss << pt.get(); - ss >> ret; - assert(ss.size() == 0); - - // Check leadbyte is legible at block height - if (!plaintext_version_is_valid(params, height, ret.leadByte)) { + auto ret = attempt_sapling_enc_decryption_deserialization(ciphertext, epk, esk, pk_d); + + if (!ret) { return boost::none; + } else { + SaplingNotePlaintext plaintext = *ret; + + // Check leadbyte is allowed at block height + if (!plaintext_version_is_valid(params, height, plaintext.leadByte)) { + return boost::none; + } + + return plaintext_checks_without_height(plaintext, epk, esk, pk_d, cmu); } +} +boost::optional SaplingNotePlaintext::plaintext_checks_without_height( + const SaplingNotePlaintext &plaintext, + const uint256 &epk, + const uint256 &esk, + const uint256 &pk_d, + const uint256 &cmu +) +{ // Check that epk is consistent with esk uint256 expected_epk; - if (!librustzcash_sapling_ka_derivepublic(ret.d.data(), esk.begin(), expected_epk.begin())) { + if (!librustzcash_sapling_ka_derivepublic(plaintext.d.data(), esk.begin(), expected_epk.begin())) { return boost::none; } if (expected_epk != epk) { @@ -301,11 +336,11 @@ boost::optional SaplingNotePlaintext::decrypt( } uint256 cmu_expected; - uint256 rcm = ret.rcm(); + uint256 rcm = plaintext.rcm(); if (!librustzcash_sapling_compute_cm( - ret.d.data(), + plaintext.d.data(), pk_d.begin(), - ret.value(), + plaintext.value(), rcm.begin(), cmu_expected.begin() )) @@ -317,15 +352,38 @@ boost::optional SaplingNotePlaintext::decrypt( return boost::none; } - if (ret.leadByte == 0x02) { + if (plaintext.leadByte == 0x02) { // ZIP 212: Additionally check that the esk provided to this function // is consistent with the esk we can derive - if (esk != ret.generate_esk()) { + if (esk != plaintext.generate_esk()) { return boost::none; } } - return ret; + return plaintext; +} + +boost::optional SaplingNotePlaintext::attempt_sapling_enc_decryption_deserialization( + const SaplingEncCiphertext &ciphertext, + const uint256 &epk, + const uint256 &esk, + const uint256 &pk_d +) +{ + auto encPlaintext = AttemptSaplingEncDecryption(ciphertext, epk, esk, pk_d); + + if (!encPlaintext) { + return boost::none; + }; + + // Deserialize from the plaintext + SaplingNotePlaintext plaintext; + CDataStream ss(SER_NETWORK, PROTOCOL_VERSION); + ss << encPlaintext.get(); + ss >> plaintext; + assert(ss.size() == 0); + + return plaintext; } boost::optional SaplingNotePlaintext::encrypt(const uint256& pk_d) const diff --git a/src/zcash/Note.hpp b/src/zcash/Note.hpp index 134c289eb..b9df41547 100644 --- a/src/zcash/Note.hpp +++ b/src/zcash/Note.hpp @@ -167,6 +167,19 @@ public: const uint256 &cmu ); + static boost::optional plaintext_checks_without_height( + const SaplingNotePlaintext &plaintext, + const uint256 &ivk, + const uint256 &epk, + const uint256 &cmu + ); + + static boost::optional attempt_sapling_enc_decryption_deserialization( + const SaplingEncCiphertext &ciphertext, + const uint256 &ivk, + const uint256 &epk + ); + static boost::optional decrypt( const Consensus::Params& params, int height, @@ -177,6 +190,21 @@ public: const uint256 &cmu ); + static boost::optional plaintext_checks_without_height( + const SaplingNotePlaintext &plaintext, + const uint256 &epk, + const uint256 &esk, + const uint256 &pk_d, + const uint256 &cmu + ); + + static boost::optional attempt_sapling_enc_decryption_deserialization( + const SaplingEncCiphertext &ciphertext, + const uint256 &epk, + const uint256 &esk, + const uint256 &pk_d + ); + boost::optional note(const SaplingIncomingViewingKey& ivk) const; virtual ~SaplingNotePlaintext() {} diff --git a/src/zcbenchmarks.cpp b/src/zcbenchmarks.cpp index 33378f2db..4311f7bf7 100644 --- a/src/zcbenchmarks.cpp +++ b/src/zcbenchmarks.cpp @@ -307,7 +307,7 @@ double benchmark_try_decrypt_sapling_notes(size_t nKeys) struct timeval tv_start; timer_start(tv_start); - auto noteDataMapAndAddressesToAdd = wallet.FindMySaplingNotes(tx); + auto noteDataMapAndAddressesToAdd = wallet.FindMySaplingNotes(tx, 1); assert(noteDataMapAndAddressesToAdd.first.empty()); return timer_stop(tv_start); }