diff --git a/src/chain.cpp b/src/chain.cpp index c9c787900..39520cc8f 100644 --- a/src/chain.cpp +++ b/src/chain.cpp @@ -7,13 +7,6 @@ using namespace std; -uint256 CBlockIndex::GetSaplingAnchorEnd() const { - // TODO: The block header's hashFinalSaplingRoot is only guaranteed to - // be valid on or after the Sapling activation height. - - return hashFinalSaplingRoot; -} - /** * CChain implementation */ diff --git a/src/chain.h b/src/chain.h index 017944bf8..b5a1a3ba6 100644 --- a/src/chain.h +++ b/src/chain.h @@ -321,9 +321,6 @@ public: //! Efficiently find an ancestor of this block. CBlockIndex* GetAncestor(int height); const CBlockIndex* GetAncestor(int height) const; - - //! Get the root of the Sapling merkle tree (at the end of this block) - uint256 GetSaplingAnchorEnd() const; }; /** Used to marshal pointers into hashes for db storage. */ diff --git a/src/coins.cpp b/src/coins.cpp index c8b98d428..42209539e 100644 --- a/src/coins.cpp +++ b/src/coins.cpp @@ -585,9 +585,12 @@ bool CCoinsViewCache::HaveJoinSplitRequirements(const CTransaction& tx) const for (const SpendDescription &spendDescription : tx.vShieldedSpend) { if (GetNullifier(spendDescription.nullifier, SAPLING)) // Prevent double spends return false; - } - // TODO: Sapling anchor checks + ZCSaplingIncrementalMerkleTree tree; + if (!GetSaplingAnchorAt(spendDescription.anchor, tree)) { + return false; + } + } return true; } diff --git a/src/main.cpp b/src/main.cpp index 6a409402c..2851fbbbb 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -2149,8 +2149,19 @@ bool DisconnectBlock(CBlock& block, CValidationState& state, CBlockIndex* pindex } } - // set the old best anchor back - view.PopAnchor(blockUndo.old_tree_root, SPROUT); + // set the old best Sprout anchor back + view.PopAnchor(blockUndo.old_sprout_tree_root, SPROUT); + + // set the old best Sapling anchor back + // We can get this from the `hashFinalSaplingRoot` of the last block + // However, this is only reliable if the last block was on or after + // the Sapling activation height. Otherwise, the last anchor was the + // empty root. + if (NetworkUpgradeActive(pindex->pprev->nHeight, Params().GetConsensus(), Consensus::UPGRADE_SAPLING)) { + view.PopAnchor(pindex->pprev->hashFinalSaplingRoot, SAPLING); + } else { + view.PopAnchor(ZCSaplingIncrementalMerkleTree::empty_root(), SAPLING); + } // move best block pointer to prevout block view.SetBestBlock(pindex->pprev->GetBlockHash()); @@ -2330,22 +2341,25 @@ bool ConnectBlock(const CBlock& block, CValidationState& state, CBlockIndex* pin // Construct the incremental merkle tree at the current // block position, - auto old_tree_root = view.GetBestAnchor(SPROUT); + auto old_sprout_tree_root = view.GetBestAnchor(SPROUT); // saving the top anchor in the block index as we go. if (!fJustCheck) { - pindex->hashSproutAnchor = old_tree_root; + pindex->hashSproutAnchor = old_sprout_tree_root; } - ZCIncrementalMerkleTree tree; + ZCIncrementalMerkleTree sprout_tree; // This should never fail: we should always be able to get the root // that is on the tip of our chain - assert(view.GetSproutAnchorAt(old_tree_root, tree)); + assert(view.GetSproutAnchorAt(old_sprout_tree_root, sprout_tree)); { // Consistency check: the root of the tree we're given should // match what we asked for. - assert(tree.root() == old_tree_root); + assert(sprout_tree.root() == old_sprout_tree_root); } + ZCSaplingIncrementalMerkleTree sapling_tree; + assert(view.GetSaplingAnchorAt(view.GetBestAnchor(SAPLING), sapling_tree)); + // Grab the consensus branch ID for the block's height auto consensusBranchId = CurrentEpochBranchId(pindex->nHeight, Params().GetConsensus()); @@ -2403,19 +2417,34 @@ bool ConnectBlock(const CBlock& block, CValidationState& state, CBlockIndex* pin BOOST_FOREACH(const uint256 ¬e_commitment, joinsplit.commitments) { // Insert the note commitments into our temporary tree. - tree.append(note_commitment); + sprout_tree.append(note_commitment); } } + BOOST_FOREACH(const OutputDescription &outputDescription, tx.vShieldedOutput) { + sapling_tree.append(outputDescription.cm); + } + vPos.push_back(std::make_pair(tx.GetHash(), pos)); pos.nTxOffset += ::GetSerializeSize(tx, SER_DISK, CLIENT_VERSION); } - view.PushSproutAnchor(tree); + view.PushSproutAnchor(sprout_tree); + view.PushSaplingAnchor(sapling_tree); if (!fJustCheck) { - pindex->hashFinalSproutRoot = tree.root(); + pindex->hashFinalSproutRoot = sprout_tree.root(); + } + blockundo.old_sprout_tree_root = old_sprout_tree_root; + + // If Sapling is active, block.hashFinalSaplingRoot must be the + // same as the root of the Sapling tree + if (NetworkUpgradeActive(pindex->nHeight, chainparams.GetConsensus(), Consensus::UPGRADE_SAPLING)) { + if (block.hashFinalSaplingRoot != sapling_tree.root()) { + return state.DoS(100, + error("ConnectBlock(): block's hashFinalSaplingRoot is incorrect"), + REJECT_INVALID, "bad-sapling-root-in-block"); + } } - blockundo.old_tree_root = old_tree_root; int64_t nTime1 = GetTimeMicros(); nTimeConnect += nTime1 - nTimeStart; LogPrint("bench", " - Connect %u transactions: %.2fms (%.3fms/tx, %.3fms/txin) [%.2fs]\n", (unsigned)block.vtx.size(), 0.001 * (nTime1 - nTimeStart), 0.001 * (nTime1 - nTimeStart) / block.vtx.size(), nInputs <= 1 ? 0 : 0.001 * (nTime1 - nTimeStart) / (nInputs-1), nTimeConnect * 0.000001); diff --git a/src/undo.h b/src/undo.h index e01814e72..fbb350e60 100644 --- a/src/undo.h +++ b/src/undo.h @@ -67,14 +67,14 @@ class CBlockUndo { public: std::vector vtxundo; // for all but the coinbase - uint256 old_tree_root; + uint256 old_sprout_tree_root; ADD_SERIALIZE_METHODS; template inline void SerializationOp(Stream& s, Operation ser_action) { READWRITE(vtxundo); - READWRITE(old_tree_root); + READWRITE(old_sprout_tree_root); } };