diff --git a/zcash_client_backend/src/data_api.rs b/zcash_client_backend/src/data_api.rs index 72dd062a3..3c2ed1687 100644 --- a/zcash_client_backend/src/data_api.rs +++ b/zcash_client_backend/src/data_api.rs @@ -1242,12 +1242,9 @@ pub trait WalletTest: InputSource + WalletRead { #[allow(clippy::type_complexity)] fn get_checkpoint_history( &self, + protocol: &ShieldedProtocol, ) -> Result< - Vec<( - BlockHeight, - ShieldedProtocol, - Option, - )>, + Vec<(BlockHeight, Option)>, ::Error, >; diff --git a/zcash_client_backend/src/data_api/testing.rs b/zcash_client_backend/src/data_api/testing.rs index 060f77108..3289a69d6 100644 --- a/zcash_client_backend/src/data_api/testing.rs +++ b/zcash_client_backend/src/data_api/testing.rs @@ -17,7 +17,7 @@ use assert_matches::assert_matches; use group::ff::Field; use incrementalmerkletree::{Marking, Retention}; use nonempty::NonEmpty; -use rand::{CryptoRng, RngCore, SeedableRng}; +use rand::{CryptoRng, Rng, RngCore, SeedableRng}; use rand_chacha::ChaChaRng; use secrecy::{ExposeSecret, Secret, SecretVec}; use shardtree::{error::ShardTreeError, store::memory::MemoryShardStore, ShardTree}; @@ -507,6 +507,7 @@ where ); self.cache.insert(&compact_block) } + /// Creates a fake block at the expected next height containing a single output of the /// given value, and inserts it into the cache. pub fn generate_next_block( @@ -762,6 +763,24 @@ where (height, res) } + + /// Truncates the test wallet and block cache to the specified height, discarding all data from + /// blocks at heights greater than the specified height, excluding transaction data that may + /// not be recoverable from the chain. + pub fn truncate_to_height(&mut self, height: BlockHeight) { + self.wallet_mut().truncate_to_height(height).unwrap(); + self.cache.truncate_to_height(height); + self.cached_blocks.split_off(&(height + 1)); + self.latest_block_height = Some(height); + } + + /// Truncates the test wallet to the specified height, and resets the cache's latest block + /// height but does not truncate the block cache. This is useful for circumstances when you + /// want to re-scan a set of cached blocks. + pub fn truncate_to_height_retaining_cache(&mut self, height: BlockHeight) { + self.wallet_mut().truncate_to_height(height).unwrap(); + self.latest_block_height = Some(height); + } } impl TestState @@ -1590,7 +1609,7 @@ impl TestBuilder { } /// Trait used by tests that require a full viewing key. -pub trait TestFvk { +pub trait TestFvk: Clone { /// The type of nullifier corresponding to the kind of note that this full viewing key /// can detect (and that its corresponding spending key can spend). type Nullifier: Copy; @@ -2013,6 +2032,16 @@ impl FakeCompactOutput { value, } } + + /// Constructs a new random fake external output to the given FVK with a value in the range + /// 10000..1000000 ZAT. + pub fn random(rng: &mut R, fvk: Fvk) -> Self { + Self { + fvk, + address_type: AddressType::DefaultExternal, + value: Zatoshis::const_from_u64(rng.gen_range(10000..1000000)), + } + } } /// Create a fake CompactBlock at the given height, containing the specified fake compact outputs. @@ -2241,6 +2270,10 @@ pub trait TestCache { /// Inserts a CompactBlock into the cache DB. fn insert(&mut self, cb: &CompactBlock) -> Self::InsertResult; + + /// Deletes block data from the cache, retaining blocks at heights less than or equal to the + /// specified height. + fn truncate_to_height(&mut self, height: BlockHeight); } /// A convenience type for the note commitments contained within a [`CompactBlock`]. diff --git a/zcash_client_backend/src/data_api/testing/pool.rs b/zcash_client_backend/src/data_api/testing/pool.rs index aaf70a1ae..b92f3c406 100644 --- a/zcash_client_backend/src/data_api/testing/pool.rs +++ b/zcash_client_backend/src/data_api/testing/pool.rs @@ -6,8 +6,8 @@ use std::{ }; use assert_matches::assert_matches; -use incrementalmerkletree::{frontier::Frontier, Level}; -use rand::RngCore; +use incrementalmerkletree::{frontier::Frontier, Level, Position}; +use rand::{Rng, RngCore}; use secrecy::Secret; use shardtree::error::ShardTreeError; @@ -25,6 +25,7 @@ use zcash_primitives::{ }; use zcash_protocol::{ consensus::{self, BlockHeight, NetworkUpgrade, Parameters}, + local_consensus::LocalNetwork, memo::{Memo, MemoBytes}, value::Zatoshis, ShieldedProtocol, @@ -77,7 +78,7 @@ use { }; #[cfg(feature = "orchard")] -use {crate::PoolType, incrementalmerkletree::Position}; +use crate::PoolType; /// Trait that exposes the pool-specific types and operations necessary to run the /// single-shielded-pool tests on a given pool. @@ -2193,7 +2194,7 @@ pub fn multi_pool_checkpoint( .unwrap(); assert_eq!(st.get_total_balance(acct_id), expected_final); - let expected_checkpoints_p0: Vec<(BlockHeight, ShieldedProtocol, Option)> = [ + let expected_checkpoints_p0: Vec<(BlockHeight, Option)> = [ (99999, None), (100000, Some(0)), (100001, Some(1)), @@ -2204,16 +2205,10 @@ pub fn multi_pool_checkpoint( (100020, Some(6)), ] .into_iter() - .map(|(h, pos)| { - ( - BlockHeight::from(h), - P0::SHIELDED_PROTOCOL, - pos.map(Position::from), - ) - }) + .map(|(h, pos)| (BlockHeight::from(h), pos.map(Position::from))) .collect(); - let expected_checkpoints_p1: Vec<(BlockHeight, ShieldedProtocol, Option)> = [ + let expected_checkpoints_p1: Vec<(BlockHeight, Option)> = [ (99999, None), (100000, None), (100001, None), @@ -2224,33 +2219,20 @@ pub fn multi_pool_checkpoint( (100020, Some(2)), ] .into_iter() - .map(|(h, pos)| { - ( - BlockHeight::from(h), - P1::SHIELDED_PROTOCOL, - pos.map(Position::from), - ) - }) + .map(|(h, pos)| (BlockHeight::from(h), pos.map(Position::from))) .collect(); - let actual_checkpoints = st.wallet().get_checkpoint_history().unwrap(); + let p0_checkpoints = st + .wallet() + .get_checkpoint_history(&P0::SHIELDED_PROTOCOL) + .unwrap(); + assert_eq!(p0_checkpoints.to_vec(), expected_checkpoints_p0); - assert_eq!( - actual_checkpoints - .iter() - .filter(|(_, p, _)| p == &P0::SHIELDED_PROTOCOL) - .cloned() - .collect::>(), - expected_checkpoints_p0 - ); - assert_eq!( - actual_checkpoints - .iter() - .filter(|(_, p, _)| p == &P1::SHIELDED_PROTOCOL) - .cloned() - .collect::>(), - expected_checkpoints_p1 - ); + let p1_checkpoints = st + .wallet() + .get_checkpoint_history(&P1::SHIELDED_PROTOCOL) + .unwrap(); + assert_eq!(p1_checkpoints.to_vec(), expected_checkpoints_p1); } #[cfg(feature = "orchard")] @@ -2443,6 +2425,145 @@ where ); } +pub fn reorg_to_checkpoint(ds_factory: DSF, cache: C) +where + DSF: DataStoreFactory, + ::AccountId: std::fmt::Debug, + C: TestCache, +{ + let mut st = TestBuilder::new() + .with_data_store_factory(ds_factory) + .with_block_cache(cache) + .with_account_from_sapling_activation(BlockHash([0; 32])) + .build(); + + let account = st.test_account().cloned().unwrap(); + + // Create a sequence of blocks to serve as the foundation of our chain state. + let p0_fvk = T::random_fvk(st.rng_mut()); + let gen_random_block = |st: &mut TestState, + output_count: usize| { + let fake_outputs = + std::iter::repeat_with(|| FakeCompactOutput::random(st.rng_mut(), p0_fvk.clone())) + .take(output_count) + .collect::>(); + st.generate_next_block_multi(&fake_outputs[..]); + output_count + }; + + // The stable portion of the tree will contain 20 notes. + for _ in 0..10 { + gen_random_block(&mut st, 4); + } + + // We will reorg to this height. + let reorg_height = account.birthday().height() + 4; + let reorg_position = Position::from(19); + + // Scan the first 5 blocks. The last block in this sequence will be where we simulate a + // reorg. + st.scan_cached_blocks(account.birthday().height(), 5); + assert_eq!( + st.wallet() + .block_max_scanned() + .unwrap() + .unwrap() + .block_height(), + reorg_height + ); + + // There will be 6 checkpoints: one for the prior block frontier, and then one for each scanned + // block. + let checkpoints = st + .wallet() + .get_checkpoint_history(&T::SHIELDED_PROTOCOL) + .unwrap(); + assert_eq!(checkpoints.len(), 6); + assert_eq!( + checkpoints.last(), + Some(&(reorg_height, Some(reorg_position))) + ); + + // Scan another block, then simulate a reorg. + st.scan_cached_blocks(reorg_height + 1, 1); + assert_eq!( + st.wallet() + .block_max_scanned() + .unwrap() + .unwrap() + .block_height(), + reorg_height + 1 + ); + let checkpoints = st + .wallet() + .get_checkpoint_history(&T::SHIELDED_PROTOCOL) + .unwrap(); + assert_eq!(checkpoints.len(), 7); + assert_eq!( + checkpoints.last(), + Some(&(reorg_height + 1, Some(reorg_position + 4))) + ); + + // /\ /\ /\ + // .... /\/\/\/\/\/\ + // c d e + + // Truncate back to the reorg height, but retain the block cache. + st.truncate_to_height_retaining_cache(reorg_height); + + // The following error-prone tree state is generated by the current truncate implementation: + // /\ /\ + // .... /\/\/\/\ + // c + + // We have pruned back to the original checkpoints & tree state. + // let checkpoints = st + // .wallet() + // .get_checkpoint_history(&T::SHIELDED_PROTOCOL) + // .unwrap(); + // assert_eq!(checkpoints.len(), 6); + // assert_eq!( + // checkpoints.last(), + // Some(&(reorg_height, Some(reorg_position))) + // ); + + // Skip two blocks, then (re) scan the same block. + st.scan_cached_blocks(reorg_height + 2, 1); + + // Given the buggy truncation, this would result in this the following tree state: + // /\ /\ \ /\ + // .... /\/\/\/\ \/\/\ + // c e f + + // let checkpoints = st + // .wallet() + // .get_checkpoint_history(&T::SHIELDED_PROTOCOL) + // .unwrap(); + // // Even though we only scanned one block, we get a checkpoint at both the start and the end of + // // the block due to the insertion of the prior block frontier. + // assert_eq!(checkpoints.len(), 8); + // assert_eq!( + // checkpoints.last(), + // Some(&(reorg_height + 2, Some(reorg_position + 8))) + // ); + + // Now, fully truncate back to the reorg height. This should leave the tree in a state + // where it can be added to with arbitrary notes. + st.truncate_to_height(reorg_height); + + // Generate some new random blocks + for _ in 0..10 { + let output_count = st.rng_mut().gen_range(2..10); + gen_random_block(&mut st, output_count); + } + + // The previous truncation retained the cache, so re-scanning the same blocks would have + // resulted in the same note commitment tree state, and hence no conflicts; could occur. Now + // that we have cleared the cache and generated a different sequence blocks, if truncation did + // not completely clear the tree state this would generates a note commitment tree conflict. + st.scan_cached_blocks(reorg_height + 1, 1); +} + pub fn scan_cached_blocks_allows_blocks_out_of_order( ds_factory: impl DataStoreFactory, cache: impl TestCache, diff --git a/zcash_client_sqlite/src/chain.rs b/zcash_client_sqlite/src/chain.rs index 88951daa0..01209e181 100644 --- a/zcash_client_sqlite/src/chain.rs +++ b/zcash_client_sqlite/src/chain.rs @@ -364,6 +364,17 @@ mod tests { testing::pool::data_db_truncation::() } + #[test] + fn reorg_to_checkpoint_sapling() { + testing::pool::reorg_to_checkpoint::() + } + + #[test] + #[cfg(feature = "orchard")] + fn reorg_to_checkpoint_orchard() { + testing::pool::reorg_to_checkpoint::() + } + #[test] fn scan_cached_blocks_allows_blocks_out_of_order_sapling() { testing::pool::scan_cached_blocks_allows_blocks_out_of_order::() diff --git a/zcash_client_sqlite/src/lib.rs b/zcash_client_sqlite/src/lib.rs index c735aeb3a..28f407b51 100644 --- a/zcash_client_sqlite/src/lib.rs +++ b/zcash_client_sqlite/src/lib.rs @@ -685,15 +685,12 @@ impl, P: consensus::Parameters> WalletTest for W fn get_checkpoint_history( &self, + protocol: &ShieldedProtocol, ) -> Result< - Vec<( - BlockHeight, - ShieldedProtocol, - Option, - )>, + Vec<(BlockHeight, Option)>, ::Error, > { - wallet::testing::get_checkpoint_history(self.conn.borrow()) + wallet::testing::get_checkpoint_history(self.conn.borrow(), protocol) } #[cfg(feature = "transparent-inputs")] @@ -1186,6 +1183,8 @@ impl WalletWrite for WalletDb from_state.block_height(), from_state.final_sapling_tree().tree_size() ); + // We insert the frontier with `Checkpoint` retention because we need to be + // able to truncate the tree back to this point. sapling_tree.insert_frontier( from_state.final_sapling_tree().clone(), Retention::Checkpoint { @@ -1235,6 +1234,8 @@ impl WalletWrite for WalletDb from_state.block_height(), from_state.final_orchard_tree().tree_size() ); + // We insert the frontier with `Checkpoint` retention because we need to be + // able to truncate the tree back to this point. orchard_tree.insert_frontier( from_state.final_orchard_tree().clone(), Retention::Checkpoint { diff --git a/zcash_client_sqlite/src/testing.rs b/zcash_client_sqlite/src/testing.rs index aab011794..508961b89 100644 --- a/zcash_client_sqlite/src/testing.rs +++ b/zcash_client_sqlite/src/testing.rs @@ -56,12 +56,23 @@ impl TestCache for BlockCache { let res = NoteCommitments::from_compact_block(cb); self.db_cache .0 - .prepare("INSERT INTO compactblocks (height, data) VALUES (?, ?)") - .unwrap() - .execute(params![u32::from(cb.height()), cb_bytes,]) + .execute( + "INSERT INTO compactblocks (height, data) VALUES (?, ?)", + params![u32::from(cb.height()), cb_bytes,], + ) .unwrap(); res } + + fn truncate_to_height(&mut self, height: zcash_protocol::consensus::BlockHeight) { + self.db_cache + .0 + .execute( + "DELETE FROM compactblocks WHERE height > ?", + params![u32::from(height)], + ) + .unwrap(); + } } #[cfg(feature = "unstable")] @@ -115,4 +126,8 @@ impl TestCache for FsBlockCache { meta } + + fn truncate_to_height(&mut self, height: zcash_protocol::consensus::BlockHeight) { + self.db_meta.truncate_to_height(height).unwrap() + } } diff --git a/zcash_client_sqlite/src/testing/pool.rs b/zcash_client_sqlite/src/testing/pool.rs index a82eef350..21b2c6a79 100644 --- a/zcash_client_sqlite/src/testing/pool.rs +++ b/zcash_client_sqlite/src/testing/pool.rs @@ -217,6 +217,13 @@ pub(crate) fn data_db_truncation() { ) } +pub(crate) fn reorg_to_checkpoint() { + zcash_client_backend::data_api::testing::pool::reorg_to_checkpoint::( + TestDbFactory, + BlockCache::new(), + ) +} + pub(crate) fn scan_cached_blocks_allows_blocks_out_of_order() { zcash_client_backend::data_api::testing::pool::scan_cached_blocks_allows_blocks_out_of_order::( TestDbFactory, diff --git a/zcash_client_sqlite/src/wallet.rs b/zcash_client_sqlite/src/wallet.rs index 53ce52d53..ca3e98c27 100644 --- a/zcash_client_sqlite/src/wallet.rs +++ b/zcash_client_sqlite/src/wallet.rs @@ -2434,11 +2434,13 @@ pub(crate) fn truncate_to_height( params: params.clone(), }; wdb.with_sapling_tree_mut(|tree| { - tree.truncate_removing_checkpoint(&block_height).map(|_| ()) + tree.truncate_removing_checkpoint(&block_height)?; + Ok::<_, SqliteClientError>(()) })?; #[cfg(feature = "orchard")] wdb.with_orchard_tree_mut(|tree| { - tree.truncate_removing_checkpoint(&block_height).map(|_| ()) + tree.truncate_removing_checkpoint(&block_height)?; + Ok::<_, SqliteClientError>(()) })?; // Do not delete sent notes; this can contain data that is not recoverable @@ -3448,7 +3450,10 @@ pub mod testing { ShieldedProtocol, }; - use crate::{error::SqliteClientError, AccountId}; + use crate::{error::SqliteClientError, AccountId, SAPLING_TABLES_PREFIX}; + + #[cfg(feature = "orchard")] + use crate::ORCHARD_TABLES_PREFIX; pub(crate) fn get_tx_history( conn: &rusqlite::Connection, @@ -3490,24 +3495,31 @@ pub mod testing { #[allow(dead_code)] // used only for tests that are flagged off by default pub(crate) fn get_checkpoint_history( conn: &rusqlite::Connection, - ) -> Result)>, SqliteClientError> { - let mut stmt = conn.prepare_cached( - "SELECT checkpoint_id, 2 AS pool, position FROM sapling_tree_checkpoints - UNION - SELECT checkpoint_id, 3 AS pool, position FROM orchard_tree_checkpoints + protocol: &ShieldedProtocol, + ) -> Result)>, SqliteClientError> { + let table_prefix = match protocol { + ShieldedProtocol::Sapling => SAPLING_TABLES_PREFIX, + #[cfg(feature = "orchard")] + ShieldedProtocol::Orchard => ORCHARD_TABLES_PREFIX, + #[cfg(not(feature = "orchard"))] + ShieldedProtocol::Orchard => { + return Err(SqliteClientError::UnsupportedPoolType( + zcash_protocol::PoolType::ORCHARD, + )); + } + }; + + let mut stmt = conn.prepare_cached(&format!( + "SELECT checkpoint_id, position FROM {}_tree_checkpoints ORDER BY checkpoint_id", - )?; + table_prefix + ))?; let results = stmt .query_and_then::<_, SqliteClientError, _, _>([], |row| { Ok(( BlockHeight::from(row.get::<_, u32>(0)?), - match row.get::<_, i64>(1)? { - 2 => ShieldedProtocol::Sapling, - 3 => ShieldedProtocol::Orchard, - _ => unreachable!(), - }, - row.get::<_, Option>(2)?.map(Position::from), + row.get::<_, Option>(1)?.map(Position::from), )) })? .collect::, _>>()?;