Refactor SaplingNotePlaintext::decrypt

Break up plaintext decryption into height-dependent and non-height-dependent parts.
This commit is contained in:
therealyingtong 2020-06-20 15:35:16 +08:00
parent 3c8e970358
commit 6402c589c6
6 changed files with 243 additions and 62 deletions

View File

@ -3769,8 +3769,9 @@ UniValue z_viewtransaction(const UniValue& params, bool fHelp)
auto op = res->second;
auto wtxPrev = pwalletMain->mapWallet.at(op.hash);
// TODO: decide which height to use here instead of wtxPrev.nExpiryHeight
auto decrypted = wtxPrev.DecryptSaplingNote(Params().GetConsensus(), wtxPrev.nExpiryHeight, op).get();
// We don't need to check the leadbyte here: if wtx exists in
// the wallet, it must have already passed the leadbyte check
auto decrypted = wtxPrev.DecryptSaplingNoteWithoutLeadByteCheck(op).get();
auto notePt = decrypted.first;
auto pa = decrypted.second;
@ -3798,16 +3799,16 @@ UniValue z_viewtransaction(const UniValue& params, bool fHelp)
SaplingPaymentAddress pa;
bool isOutgoing;
// TODO: decide which height to use here instead of wtx.nExpiryHeight
auto decrypted = wtx.DecryptSaplingNote(Params().GetConsensus(), wtx.nExpiryHeight, op);
// We don't need to check the leadbyte here: if wtx exists in
// the wallet, it must have already passed the leadbyte check
auto decrypted = wtx.DecryptSaplingNoteWithoutLeadByteCheck(op);
if (decrypted) {
notePt = decrypted->first;
pa = decrypted->second;
isOutgoing = false;
} else {
// Try recovering the output
// TODO: decide which height to use here instead of wtxPrev.nExpiryHeight
auto recovered = wtx.RecoverSaplingNote(Params().GetConsensus(), wtx.nExpiryHeight, op, ovks);
auto recovered = wtx.RecoverSaplingNoteWithoutLeadByteCheck(op, ovks);
if (recovered) {
notePt = recovered->first;
pa = recovered->second;

View File

@ -1496,8 +1496,15 @@ void CWallet::UpdateSaplingNullifierNoteMapWithTx(CWalletTx& wtx) {
auto extfvk = mapSaplingFullViewingKeys.at(nd.ivk);
OutputDescription output = wtx.vShieldedOutput[op.n];
// TODO: decide which height to use here instead of wtx.nExpiryHeight
auto optPlaintext = SaplingNotePlaintext::decrypt(Params().GetConsensus(), wtx.nExpiryHeight, output.encCiphertext, nd.ivk, output.ephemeralKey, output.cmu);
auto optDeserialized = SaplingNotePlaintext::attempt_sapling_enc_decryption_deserialization(output.encCiphertext, nd.ivk, output.ephemeralKey);
if (!optDeserialized) {
// The transaction would not have entered the wallet unless
// its plaintest had been succesfully decrypted previously.
assert(false);
}
auto optPlaintext = SaplingNotePlaintext::plaintext_checks_without_height(*optDeserialized, nd.ivk, output.ephemeralKey, output.cmu);
if (!optPlaintext) {
// An item in mapSaplingNoteData must have already been successfully decrypted,
// otherwise the item would not exist in the first place.
@ -1716,7 +1723,7 @@ bool CWallet::AddToWalletIfInvolvingMe(const CTransaction& tx, const CBlock* pbl
bool fExisted = mapWallet.count(tx.GetHash()) != 0;
if (fExisted && !fUpdate) return false;
auto sproutNoteData = FindMySproutNotes(tx);
auto saplingNoteDataAndAddressesToAdd = FindMySaplingNotes(tx);
auto saplingNoteDataAndAddressesToAdd = FindMySaplingNotes(tx, chainActive.Height());
auto saplingNoteData = saplingNoteDataAndAddressesToAdd.first;
auto addressesToAdd = saplingNoteDataAndAddressesToAdd.second;
for (const auto &addressToAdd : addressesToAdd) {
@ -1890,7 +1897,7 @@ mapSproutNoteData_t CWallet::FindMySproutNotes(const CTransaction &tx) const
* the result of FindMySaplingNotes (for the addresses available at the time) will
* already have been cached in CWalletTx.mapSaplingNoteData.
*/
std::pair<mapSaplingNoteData_t, SaplingIncomingViewingKeyMap> CWallet::FindMySaplingNotes(const CTransaction &tx) const
std::pair<mapSaplingNoteData_t, SaplingIncomingViewingKeyMap> CWallet::FindMySaplingNotes(const CTransaction &tx, int height) const
{
LOCK(cs_KeyStore);
uint256 hash = tx.GetHash();
@ -1904,8 +1911,7 @@ std::pair<mapSaplingNoteData_t, SaplingIncomingViewingKeyMap> CWallet::FindMySap
for (auto it = mapSaplingFullViewingKeys.begin(); it != mapSaplingFullViewingKeys.end(); ++it) {
SaplingIncomingViewingKey ivk = it->first;
// TODO: decide which height to use here instead of wtx.nExpiryHeight
auto result = SaplingNotePlaintext::decrypt(Params().GetConsensus(), tx.nExpiryHeight, output.encCiphertext, ivk, output.ephemeralKey, output.cmu);
auto result = SaplingNotePlaintext::decrypt(Params().GetConsensus(), height, output.encCiphertext, ivk, output.ephemeralKey, output.cmu);
if (!result) {
continue;
}
@ -2331,6 +2337,41 @@ boost::optional<std::pair<
return std::make_pair(notePt, pa);
}
boost::optional<std::pair<
SaplingNotePlaintext,
SaplingPaymentAddress>> CWalletTx::DecryptSaplingNoteWithoutLeadByteCheck(SaplingOutPoint op) const
{
// Check whether we can decrypt this SaplingOutPoint
if (this->mapSaplingNoteData.count(op) == 0) {
return boost::none;
}
auto output = this->vShieldedOutput[op.n];
auto nd = this->mapSaplingNoteData.at(op);
auto optDeserialized = SaplingNotePlaintext::attempt_sapling_enc_decryption_deserialization(output.encCiphertext, nd.ivk, output.ephemeralKey);
if (!optDeserialized) {
// The transaction would not have entered the wallet unless
// its plaintest had been succesfully decrypted previously.
assert(false);
}
auto maybe_pt = SaplingNotePlaintext::plaintext_checks_without_height(
*optDeserialized,
nd.ivk,
output.ephemeralKey,
output.cmu);
assert(static_cast<bool>(maybe_pt));
auto notePt = maybe_pt.get();
auto maybe_pa = nd.ivk.address(notePt.d);
assert(static_cast<bool>(maybe_pa));
auto pa = maybe_pa.get();
return std::make_pair(notePt, pa);
}
boost::optional<std::pair<
SaplingNotePlaintext,
SaplingPaymentAddress>> CWalletTx::RecoverSaplingNote(const Consensus::Params& params, int height, SaplingOutPoint op, std::set<uint256>& ovks) const
@ -2366,6 +2407,47 @@ boost::optional<std::pair<
return boost::none;
}
boost::optional<std::pair<
SaplingNotePlaintext,
SaplingPaymentAddress>> CWalletTx::RecoverSaplingNoteWithoutLeadByteCheck(SaplingOutPoint op, std::set<uint256>& ovks) const
{
auto output = this->vShieldedOutput[op.n];
for (auto ovk : ovks) {
auto outPt = SaplingOutgoingPlaintext::decrypt(
output.outCiphertext,
ovk,
output.cv,
output.cmu,
output.ephemeralKey);
if (!outPt) {
continue;
}
auto optDeserialized = SaplingNotePlaintext::attempt_sapling_enc_decryption_deserialization(output.encCiphertext, output.ephemeralKey, outPt->esk, outPt->pk_d);
if (!optDeserialized) {
// The transaction would not have entered the wallet unless
// its plaintest had been succesfully decrypted previously.
assert(false);
}
auto maybe_pt = SaplingNotePlaintext::plaintext_checks_without_height(
*optDeserialized,
output.ephemeralKey,
outPt->esk,
outPt->pk_d,
output.cmu);
assert(static_cast<bool>(maybe_pt));
auto notePt = maybe_pt.get();
return std::make_pair(notePt, SaplingPaymentAddress(notePt.d, outPt->pk_d));
}
// Couldn't recover with any of the provided OutgoingViewingKeys
return boost::none;
}
int64_t CWalletTx::GetTxTime() const
{
int64_t n = nTimeSmart;
@ -4982,11 +5064,17 @@ void CWallet::GetFilteredNotes(
SaplingOutPoint op = pair.first;
SaplingNoteData nd = pair.second;
// TODO: decide which height to use here instead of wtx.nExpiryHeight
auto maybe_pt = SaplingNotePlaintext::decrypt(
Params().GetConsensus(),
wtx.nExpiryHeight,
wtx.vShieldedOutput[op.n].encCiphertext,
auto optDeserialized = SaplingNotePlaintext::attempt_sapling_enc_decryption_deserialization(wtx.vShieldedOutput[op.n].encCiphertext, nd.ivk, wtx.vShieldedOutput[op.n].ephemeralKey);
if (!optDeserialized) {
// The transaction would not have entered the wallet unless
// its plaintest had been succesfully decrypted previously.
assert(false);
}
// We don't need to check the leadbyte here: if wtx exists in
// the wallet, it must have already passed the leadbyte check
auto maybe_pt = SaplingNotePlaintext::plaintext_checks_without_height(
*optDeserialized,
nd.ivk,
wtx.vShieldedOutput[op.n].ephemeralKey,
wtx.vShieldedOutput[op.n].cmu);

View File

@ -567,10 +567,16 @@ public:
boost::optional<std::pair<
libzcash::SaplingNotePlaintext,
libzcash::SaplingPaymentAddress>> DecryptSaplingNote(const Consensus::Params& params, int height, SaplingOutPoint op) const;
boost::optional<std::pair<
libzcash::SaplingNotePlaintext,
libzcash::SaplingPaymentAddress>> DecryptSaplingNoteWithoutLeadByteCheck(SaplingOutPoint op) const;
boost::optional<std::pair<
libzcash::SaplingNotePlaintext,
libzcash::SaplingPaymentAddress>> RecoverSaplingNote(const Consensus::Params& params, int height,
SaplingOutPoint op, std::set<uint256>& ovks) const;
boost::optional<std::pair<
libzcash::SaplingNotePlaintext,
libzcash::SaplingPaymentAddress>> RecoverSaplingNoteWithoutLeadByteCheck(SaplingOutPoint op, std::set<uint256>& ovks) const;
//! filter decides which addresses will count towards the debit
CAmount GetDebit(const isminefilter& filter) const;
@ -1221,7 +1227,7 @@ public:
const uint256& hSig,
uint8_t n) const;
mapSproutNoteData_t FindMySproutNotes(const CTransaction& tx) const;
std::pair<mapSaplingNoteData_t, SaplingIncomingViewingKeyMap> FindMySaplingNotes(const CTransaction& tx) const;
std::pair<mapSaplingNoteData_t, SaplingIncomingViewingKeyMap> FindMySaplingNotes(const CTransaction& tx, int height) const;
bool IsSproutNullifierFromMe(const uint256& nullifier) const;
bool IsSaplingNullifierFromMe(const uint256& nullifier) const;

View File

@ -209,34 +209,40 @@ boost::optional<SaplingNotePlaintext> SaplingNotePlaintext::decrypt(
const uint256 &cmu
)
{
auto pt = AttemptSaplingEncDecryption(ciphertext, ivk, epk);
if (!pt) {
return boost::none;
}
// Deserialize from the plaintext
SaplingNotePlaintext ret;
CDataStream ss(SER_NETWORK, PROTOCOL_VERSION);
ss << pt.get();
ss >> ret;
assert(ss.size() == 0);
// Check leadbyte is allowed at block height
if (!plaintext_version_is_valid(params, height, ret.leadByte)) {
auto ret = attempt_sapling_enc_decryption_deserialization(ciphertext, ivk, epk);
if (!ret) {
return boost::none;
} else {
const SaplingNotePlaintext plaintext = *ret;
// Check leadbyte is allowed at block height
if (!plaintext_version_is_valid(params, height, plaintext.leadByte)) {
return boost::none;
}
return plaintext_checks_without_height(plaintext, ivk, epk, cmu);
}
}
boost::optional<SaplingNotePlaintext> SaplingNotePlaintext::plaintext_checks_without_height(
const SaplingNotePlaintext &plaintext,
const uint256 &ivk,
const uint256 &epk,
const uint256 &cmu
)
{
uint256 pk_d;
if (!librustzcash_ivk_to_pkd(ivk.begin(), ret.d.data(), pk_d.begin())) {
if (!librustzcash_ivk_to_pkd(ivk.begin(), plaintext.d.data(), pk_d.begin())) {
return boost::none;
}
uint256 cmu_expected;
uint256 rcm = ret.rcm();
uint256 rcm = plaintext.rcm();
if (!librustzcash_sapling_compute_cm(
ret.d.data(),
plaintext.d.data(),
pk_d.begin(),
ret.value(),
plaintext.value(),
rcm.begin(),
cmu_expected.begin()
))
@ -248,12 +254,12 @@ boost::optional<SaplingNotePlaintext> SaplingNotePlaintext::decrypt(
return boost::none;
}
if (ret.leadByte == 0x02) {
if (plaintext.leadByte == 0x02) {
// ZIP 212: Check that epk is consistent to prevent against linkability
// attacks without relying on the soundness of the SNARK.
uint256 expected_epk;
uint256 esk = ret.generate_esk();
if (!librustzcash_sapling_ka_derivepublic(ret.d.data(), esk.begin(), expected_epk.begin())) {
uint256 esk = plaintext.generate_esk();
if (!librustzcash_sapling_ka_derivepublic(plaintext.d.data(), esk.begin(), expected_epk.begin())) {
return boost::none;
}
if (expected_epk != epk) {
@ -261,7 +267,29 @@ boost::optional<SaplingNotePlaintext> SaplingNotePlaintext::decrypt(
}
}
return ret;
return plaintext;
}
boost::optional<SaplingNotePlaintext> SaplingNotePlaintext::attempt_sapling_enc_decryption_deserialization(
const SaplingEncCiphertext &ciphertext,
const uint256 &ivk,
const uint256 &epk
)
{
auto encPlaintext = AttemptSaplingEncDecryption(ciphertext, ivk, epk);
if (!encPlaintext) {
return boost::none;
};
// Deserialize from the plaintext
SaplingNotePlaintext plaintext;
CDataStream ss(SER_NETWORK, PROTOCOL_VERSION);
ss << encPlaintext.get();
ss >> plaintext;
assert(ss.size() == 0);
return plaintext;
}
boost::optional<SaplingNotePlaintext> SaplingNotePlaintext::decrypt(
@ -274,26 +302,33 @@ boost::optional<SaplingNotePlaintext> SaplingNotePlaintext::decrypt(
const uint256 &cmu
)
{
auto pt = AttemptSaplingEncDecryption(ciphertext, epk, esk, pk_d);
if (!pt) {
return boost::none;
}
// Deserialize from the plaintext
SaplingNotePlaintext ret;
CDataStream ss(SER_NETWORK, PROTOCOL_VERSION);
ss << pt.get();
ss >> ret;
assert(ss.size() == 0);
// Check leadbyte is legible at block height
if (!plaintext_version_is_valid(params, height, ret.leadByte)) {
auto ret = attempt_sapling_enc_decryption_deserialization(ciphertext, epk, esk, pk_d);
if (!ret) {
return boost::none;
} else {
SaplingNotePlaintext plaintext = *ret;
// Check leadbyte is allowed at block height
if (!plaintext_version_is_valid(params, height, plaintext.leadByte)) {
return boost::none;
}
return plaintext_checks_without_height(plaintext, epk, esk, pk_d, cmu);
}
}
boost::optional<SaplingNotePlaintext> SaplingNotePlaintext::plaintext_checks_without_height(
const SaplingNotePlaintext &plaintext,
const uint256 &epk,
const uint256 &esk,
const uint256 &pk_d,
const uint256 &cmu
)
{
// Check that epk is consistent with esk
uint256 expected_epk;
if (!librustzcash_sapling_ka_derivepublic(ret.d.data(), esk.begin(), expected_epk.begin())) {
if (!librustzcash_sapling_ka_derivepublic(plaintext.d.data(), esk.begin(), expected_epk.begin())) {
return boost::none;
}
if (expected_epk != epk) {
@ -301,11 +336,11 @@ boost::optional<SaplingNotePlaintext> SaplingNotePlaintext::decrypt(
}
uint256 cmu_expected;
uint256 rcm = ret.rcm();
uint256 rcm = plaintext.rcm();
if (!librustzcash_sapling_compute_cm(
ret.d.data(),
plaintext.d.data(),
pk_d.begin(),
ret.value(),
plaintext.value(),
rcm.begin(),
cmu_expected.begin()
))
@ -317,15 +352,38 @@ boost::optional<SaplingNotePlaintext> SaplingNotePlaintext::decrypt(
return boost::none;
}
if (ret.leadByte == 0x02) {
if (plaintext.leadByte == 0x02) {
// ZIP 212: Additionally check that the esk provided to this function
// is consistent with the esk we can derive
if (esk != ret.generate_esk()) {
if (esk != plaintext.generate_esk()) {
return boost::none;
}
}
return ret;
return plaintext;
}
boost::optional<SaplingNotePlaintext> SaplingNotePlaintext::attempt_sapling_enc_decryption_deserialization(
const SaplingEncCiphertext &ciphertext,
const uint256 &epk,
const uint256 &esk,
const uint256 &pk_d
)
{
auto encPlaintext = AttemptSaplingEncDecryption(ciphertext, epk, esk, pk_d);
if (!encPlaintext) {
return boost::none;
};
// Deserialize from the plaintext
SaplingNotePlaintext plaintext;
CDataStream ss(SER_NETWORK, PROTOCOL_VERSION);
ss << encPlaintext.get();
ss >> plaintext;
assert(ss.size() == 0);
return plaintext;
}
boost::optional<SaplingNotePlaintextEncryptionResult> SaplingNotePlaintext::encrypt(const uint256& pk_d) const

View File

@ -167,6 +167,19 @@ public:
const uint256 &cmu
);
static boost::optional<SaplingNotePlaintext> plaintext_checks_without_height(
const SaplingNotePlaintext &plaintext,
const uint256 &ivk,
const uint256 &epk,
const uint256 &cmu
);
static boost::optional<SaplingNotePlaintext> attempt_sapling_enc_decryption_deserialization(
const SaplingEncCiphertext &ciphertext,
const uint256 &ivk,
const uint256 &epk
);
static boost::optional<SaplingNotePlaintext> decrypt(
const Consensus::Params& params,
int height,
@ -177,6 +190,21 @@ public:
const uint256 &cmu
);
static boost::optional<SaplingNotePlaintext> plaintext_checks_without_height(
const SaplingNotePlaintext &plaintext,
const uint256 &epk,
const uint256 &esk,
const uint256 &pk_d,
const uint256 &cmu
);
static boost::optional<SaplingNotePlaintext> attempt_sapling_enc_decryption_deserialization(
const SaplingEncCiphertext &ciphertext,
const uint256 &epk,
const uint256 &esk,
const uint256 &pk_d
);
boost::optional<SaplingNote> note(const SaplingIncomingViewingKey& ivk) const;
virtual ~SaplingNotePlaintext() {}

View File

@ -307,7 +307,7 @@ double benchmark_try_decrypt_sapling_notes(size_t nKeys)
struct timeval tv_start;
timer_start(tv_start);
auto noteDataMapAndAddressesToAdd = wallet.FindMySaplingNotes(tx);
auto noteDataMapAndAddressesToAdd = wallet.FindMySaplingNotes(tx, 1);
assert(noteDataMapAndAddressesToAdd.first.empty());
return timer_stop(tv_start);
}