diff --git a/src/utiltest.cpp b/src/utiltest.cpp index 3c16e73e7..8eb23249f 100644 --- a/src/utiltest.cpp +++ b/src/utiltest.cpp @@ -10,7 +10,7 @@ #include // Sprout -CWalletTx GetValidSproutReceive(ZCJoinSplit& params, +CMutableTransaction GetValidSproutReceiveTransaction(ZCJoinSplit& params, const libzcash::SproutSpendingKey& sk, CAmount value, bool randomInputs, @@ -71,6 +71,34 @@ CWalletTx GetValidSproutReceive(ZCJoinSplit& params, joinSplitPrivKey ) == 0); + return mtx; +} + +CWalletTx GetValidSproutReceive(ZCJoinSplit& params, + const libzcash::SproutSpendingKey& sk, + CAmount value, + bool randomInputs, + int32_t version /* = 2 */) +{ + CMutableTransaction mtx = GetValidSproutReceiveTransaction( + params, sk, value, randomInputs, version + ); + CTransaction tx {mtx}; + CWalletTx wtx {NULL, tx}; + return wtx; +} + +CWalletTx GetInvalidCommitmentSproutReceive(ZCJoinSplit& params, + const libzcash::SproutSpendingKey& sk, + CAmount value, + bool randomInputs, + int32_t version /* = 2 */) +{ + CMutableTransaction mtx = GetValidSproutReceiveTransaction( + params, sk, value, randomInputs, version + ); + mtx.vjoinsplit[0].commitments[0] = uint256(); + mtx.vjoinsplit[0].commitments[1] = uint256(); CTransaction tx {mtx}; CWalletTx wtx {NULL, tx}; return wtx; diff --git a/src/utiltest.h b/src/utiltest.h index 1a91eefe2..5480c559e 100644 --- a/src/utiltest.h +++ b/src/utiltest.h @@ -18,6 +18,11 @@ CWalletTx GetValidSproutReceive(ZCJoinSplit& params, CAmount value, bool randomInputs, int32_t version = 2); +CWalletTx GetInvalidCommitmentSproutReceive(ZCJoinSplit& params, + const libzcash::SproutSpendingKey& sk, + CAmount value, + bool randomInputs, + int32_t version = 2); libzcash::SproutNote GetSproutNote(ZCJoinSplit& params, const libzcash::SproutSpendingKey& sk, const CTransaction& tx, size_t js, size_t n); diff --git a/src/wallet/gtest/test_wallet.cpp b/src/wallet/gtest/test_wallet.cpp index 8308243ab..9af4aa060 100644 --- a/src/wallet/gtest/test_wallet.cpp +++ b/src/wallet/gtest/test_wallet.cpp @@ -75,6 +75,10 @@ CWalletTx GetValidSproutReceive(const libzcash::SproutSpendingKey& sk, CAmount v return GetValidSproutReceive(*params, sk, value, randomInputs, version); } +CWalletTx GetInvalidCommitmentSproutReceive(const libzcash::SproutSpendingKey& sk, CAmount value, bool randomInputs, int32_t version = 2) { + return GetInvalidCommitmentSproutReceive(*params, sk, value, randomInputs, version); +} + libzcash::SproutNote GetSproutNote(const libzcash::SproutSpendingKey& sk, const CTransaction& tx, size_t js, size_t n) { return GetSproutNote(*params, sk, tx, js, n); @@ -436,6 +440,27 @@ TEST(WalletTests, SetInvalidSaplingNoteDataInCWalletTx) { EXPECT_THROW(wtx.SetSaplingNoteData(noteData), std::logic_error); } +TEST(WalletTests, CheckSproutNoteCommitmentAgainstNotePlaintext) { + CWallet wallet; + + auto sk = libzcash::SproutSpendingKey::random(); + auto address = sk.address(); + auto dec = ZCNoteDecryption(sk.receiving_key()); + + auto wtx = GetInvalidCommitmentSproutReceive(sk, 10, true); + auto note = GetSproutNote(sk, wtx, 0, 1); + auto nullifier = note.nullifier(sk); + + auto hSig = wtx.vjoinsplit[0].h_sig( + *params, wtx.joinSplitPubKey); + + ASSERT_THROW(wallet.GetSproutNoteNullifier( + wtx.vjoinsplit[0], + address, + dec, + hSig, 1), libzcash::note_decryption_failed); +} + TEST(WalletTests, GetSproutNoteNullifier) { CWallet wallet; diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index 89aabad69..6dfa15025 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -1699,6 +1699,12 @@ boost::optional CWallet::GetSproutNoteNullifier(const JSDescription &js hSig, (unsigned char) n); auto note = note_pt.note(address); + + // Check note plaintext against note commitment + if (note.cm() != jsdesc.commitments[n]) { + throw libzcash::note_decryption_failed(); + } + // SpendingKeys are only available if: // - We have them (this isn't a viewing key) // - The wallet is unlocked