From d11f3d2acc11d92f3f819e8f765516304501f9ab Mon Sep 17 00:00:00 2001 From: Kris Nuttycombe Date: Tue, 13 Jun 2023 11:20:18 -0600 Subject: [PATCH] zcash_client_sqlite: Add shardtree truncation & checkpoint operations. --- zcash_client_sqlite/src/chain.rs | 4 +- zcash_client_sqlite/src/lib.rs | 53 ++--- zcash_client_sqlite/src/wallet.rs | 17 +- zcash_client_sqlite/src/wallet/init.rs | 5 +- .../init/migrations/shardtree_support.rs | 3 + .../src/wallet/sapling/commitment_tree.rs | 181 +++++++++++++++--- 6 files changed, 209 insertions(+), 54 deletions(-) diff --git a/zcash_client_sqlite/src/chain.rs b/zcash_client_sqlite/src/chain.rs index 11e065f9e..d115482e2 100644 --- a/zcash_client_sqlite/src/chain.rs +++ b/zcash_client_sqlite/src/chain.rs @@ -539,7 +539,7 @@ mod tests { // "Rewind" to height of last scanned block db_data .transactionally(|wdb| { - truncate_to_height(&wdb.conn.0, &wdb.params, sapling_activation_height() + 1) + truncate_to_height(wdb.conn.0, &wdb.params, sapling_activation_height() + 1) }) .unwrap(); @@ -552,7 +552,7 @@ mod tests { // Rewind so that one block is dropped db_data .transactionally(|wdb| { - truncate_to_height(&wdb.conn.0, &wdb.params, sapling_activation_height()) + truncate_to_height(wdb.conn.0, &wdb.params, sapling_activation_height()) }) .unwrap(); diff --git a/zcash_client_sqlite/src/lib.rs b/zcash_client_sqlite/src/lib.rs index 88b884b12..cae3685e7 100644 --- a/zcash_client_sqlite/src/lib.rs +++ b/zcash_client_sqlite/src/lib.rs @@ -116,11 +116,11 @@ pub struct WalletDb { } /// A wrapper for a SQLite transaction affecting the wallet database. -pub struct SqlTransaction<'conn>(pub(crate) rusqlite::Transaction<'conn>); +pub struct SqlTransaction<'conn>(pub(crate) &'conn rusqlite::Transaction<'conn>); impl Borrow for SqlTransaction<'_> { fn borrow(&self) -> &rusqlite::Connection { - &self.0 + self.0 } } @@ -137,12 +137,13 @@ impl WalletDb { where F: FnOnce(&mut WalletDb, P>) -> Result, { + let tx = self.conn.transaction()?; let mut wdb = WalletDb { - conn: SqlTransaction(self.conn.transaction()?), + conn: SqlTransaction(&tx), params: self.params.clone(), }; let result = f(&mut wdb)?; - wdb.conn.0.commit()?; + tx.commit()?; Ok(result) } } @@ -334,7 +335,7 @@ impl WalletWrite for WalletDb seed: &SecretVec, ) -> Result<(AccountId, UnifiedSpendingKey), Self::Error> { self.transactionally(|wdb| { - let account = wallet::get_max_account_id(&wdb.conn.0)? + let account = wallet::get_max_account_id(wdb.conn.0)? .map(|a| AccountId::from(u32::from(a) + 1)) .unwrap_or_else(|| AccountId::from(0)); @@ -346,7 +347,7 @@ impl WalletWrite for WalletDb .map_err(|_| SqliteClientError::KeyDerivationError(account))?; let ufvk = usk.to_unified_full_viewing_key(); - wallet::add_account(&wdb.conn.0, &wdb.params, account, &ufvk)?; + wallet::add_account(wdb.conn.0, &wdb.params, account, &ufvk)?; Ok((account, usk)) }) @@ -360,7 +361,7 @@ impl WalletWrite for WalletDb |wdb| match wdb.get_unified_full_viewing_keys()?.get(&account) { Some(ufvk) => { let search_from = - match wallet::get_current_address(&wdb.conn.0, &wdb.params, account)? { + match wallet::get_current_address(wdb.conn.0, &wdb.params, account)? { Some((_, mut last_diversifier_index)) => { last_diversifier_index .increment() @@ -375,7 +376,7 @@ impl WalletWrite for WalletDb .ok_or(SqliteClientError::DiversifierIndexOutOfRange)?; wallet::insert_address( - &wdb.conn.0, + wdb.conn.0, &wdb.params, account, diversifier_index, @@ -399,7 +400,7 @@ impl WalletWrite for WalletDb // Insert the block into the database. let block_height = block.block_height; wallet::insert_block( - &wdb.conn.0, + wdb.conn.0, block_height, block.block_hash, block.block_time, @@ -408,16 +409,16 @@ impl WalletWrite for WalletDb let mut wallet_note_ids = vec![]; for tx in &block.transactions { - let tx_row = wallet::put_tx_meta(&wdb.conn.0, tx, block.block_height)?; + let tx_row = wallet::put_tx_meta(wdb.conn.0, tx, block.block_height)?; // Mark notes as spent and remove them from the scanning cache for spend in &tx.sapling_spends { - wallet::sapling::mark_sapling_note_spent(&wdb.conn.0, tx_row, spend.nf())?; + wallet::sapling::mark_sapling_note_spent(wdb.conn.0, tx_row, spend.nf())?; } for output in &tx.sapling_outputs { let received_note_id = - wallet::sapling::put_received_note(&wdb.conn.0, output, tx_row)?; + wallet::sapling::put_received_note(wdb.conn.0, output, tx_row)?; // Save witness for note. wallet_note_ids.push(received_note_id); @@ -435,7 +436,7 @@ impl WalletWrite for WalletDb })?; // Update now-expired transactions that didn't get mined. - wallet::update_expired_notes(&wdb.conn.0, block_height)?; + wallet::update_expired_notes(wdb.conn.0, block_height)?; Ok(wallet_note_ids) }) @@ -446,7 +447,7 @@ impl WalletWrite for WalletDb d_tx: DecryptedTransaction, ) -> Result { self.transactionally(|wdb| { - let tx_ref = wallet::put_tx_data(&wdb.conn.0, d_tx.tx, None, None)?; + let tx_ref = wallet::put_tx_data(wdb.conn.0, d_tx.tx, None, None)?; let mut spending_account_id: Option = None; for output in d_tx.sapling_outputs { @@ -459,7 +460,7 @@ impl WalletWrite for WalletDb }; wallet::put_sent_output( - &wdb.conn.0, + wdb.conn.0, &wdb.params, output.account, tx_ref, @@ -474,7 +475,7 @@ impl WalletWrite for WalletDb )?; if matches!(recipient, Recipient::InternalAccount(_, _)) { - wallet::sapling::put_received_note(&wdb.conn.0, output, tx_ref)?; + wallet::sapling::put_received_note(wdb.conn.0, output, tx_ref)?; } } TransferType::Incoming => { @@ -489,14 +490,14 @@ impl WalletWrite for WalletDb } } - wallet::sapling::put_received_note(&wdb.conn.0, output, tx_ref)?; + wallet::sapling::put_received_note(wdb.conn.0, output, tx_ref)?; } } // If any of the utxos spent in the transaction are ours, mark them as spent. #[cfg(feature = "transparent-inputs")] for txin in d_tx.tx.transparent_bundle().iter().flat_map(|b| b.vin.iter()) { - wallet::mark_transparent_utxo_spent(&wdb.conn.0, tx_ref, &txin.prevout)?; + wallet::mark_transparent_utxo_spent(wdb.conn.0, tx_ref, &txin.prevout)?; } // If we have some transparent outputs: @@ -513,7 +514,7 @@ impl WalletWrite for WalletDb for (output_index, txout) in d_tx.tx.transparent_bundle().iter().flat_map(|b| b.vout.iter()).enumerate() { if let Some(address) = txout.recipient_address() { wallet::put_sent_output( - &wdb.conn.0, + wdb.conn.0, &wdb.params, *account_id, tx_ref, @@ -534,7 +535,7 @@ impl WalletWrite for WalletDb fn store_sent_tx(&mut self, sent_tx: &SentTransaction) -> Result { self.transactionally(|wdb| { let tx_ref = wallet::put_tx_data( - &wdb.conn.0, + wdb.conn.0, sent_tx.tx, Some(sent_tx.fee_amount), Some(sent_tx.created), @@ -551,7 +552,7 @@ impl WalletWrite for WalletDb if let Some(bundle) = sent_tx.tx.sapling_bundle() { for spend in bundle.shielded_spends() { wallet::sapling::mark_sapling_note_spent( - &wdb.conn.0, + wdb.conn.0, tx_ref, spend.nullifier(), )?; @@ -560,12 +561,12 @@ impl WalletWrite for WalletDb #[cfg(feature = "transparent-inputs")] for utxo_outpoint in &sent_tx.utxos_spent { - wallet::mark_transparent_utxo_spent(&wdb.conn.0, tx_ref, utxo_outpoint)?; + wallet::mark_transparent_utxo_spent(wdb.conn.0, tx_ref, utxo_outpoint)?; } for output in &sent_tx.outputs { wallet::insert_sent_output( - &wdb.conn.0, + wdb.conn.0, &wdb.params, tx_ref, sent_tx.account, @@ -574,7 +575,7 @@ impl WalletWrite for WalletDb if let Some((account, note)) = output.sapling_change_to() { wallet::sapling::put_received_note( - &wdb.conn.0, + wdb.conn.0, &DecryptedOutput { index: output.output_index(), note: note.clone(), @@ -596,7 +597,7 @@ impl WalletWrite for WalletDb fn truncate_to_height(&mut self, block_height: BlockHeight) -> Result<(), Self::Error> { self.transactionally(|wdb| { - wallet::truncate_to_height(&wdb.conn.0, &wdb.params, block_height) + wallet::truncate_to_height(wdb.conn.0, &wdb.params, block_height) }) } @@ -661,7 +662,7 @@ impl<'conn, P: consensus::Parameters> WalletCommitmentTrees for WalletDb>>, { let mut shardtree = ShardTree::new( - WalletDbSaplingShardStore::from_connection(&self.conn.0) + WalletDbSaplingShardStore::from_connection(self.conn.0) .map_err(|e| ShardTreeError::Storage(Either::Right(e)))?, 100, ); diff --git a/zcash_client_sqlite/src/wallet.rs b/zcash_client_sqlite/src/wallet.rs index 8f071d55d..6b2d15072 100644 --- a/zcash_client_sqlite/src/wallet.rs +++ b/zcash_client_sqlite/src/wallet.rs @@ -89,7 +89,9 @@ use zcash_client_backend::{ wallet::WalletTx, }; -use crate::{error::SqliteClientError, PRUNING_HEIGHT}; +use crate::{ + error::SqliteClientError, SqlTransaction, WalletCommitmentTrees, WalletDb, PRUNING_HEIGHT, +}; #[cfg(feature = "transparent-inputs")] use { @@ -637,7 +639,7 @@ pub(crate) fn get_min_unspent_height( /// block, this function does nothing. /// /// This should only be executed inside a transactional context. -pub(crate) fn truncate_to_height( +pub(crate) fn truncate_to_height( conn: &rusqlite::Transaction, params: &P, block_height: BlockHeight, @@ -662,7 +664,16 @@ pub(crate) fn truncate_to_height( // nothing to do if we're deleting back down to the max height if block_height < last_scanned_height { - // Decrement witnesses. + // Truncate the note commitment trees + let mut wdb = WalletDb { + conn: SqlTransaction(conn), + params: params.clone(), + }; + wdb.with_sapling_tree_mut(|tree| { + tree.truncate_removing_checkpoint(&block_height).map(|_| ()) + })?; + + // Remove any legacy Sapling witnesses conn.execute( "DELETE FROM sapling_witnesses WHERE block > ?", [u32::from(block_height)], diff --git a/zcash_client_sqlite/src/wallet/init.rs b/zcash_client_sqlite/src/wallet/init.rs index 6fba9396c..a98806ee6 100644 --- a/zcash_client_sqlite/src/wallet/init.rs +++ b/zcash_client_sqlite/src/wallet/init.rs @@ -237,7 +237,7 @@ pub fn init_accounts_table( // Insert accounts atomically for (account, key) in keys.iter() { - wallet::add_account(&wdb.conn.0, &wdb.params, *account, key)?; + wallet::add_account(wdb.conn.0, &wdb.params, *account, key)?; } Ok(()) @@ -394,6 +394,9 @@ mod tests { CONSTRAINT tx_output UNIQUE (tx, output_index) )", "CREATE TABLE sapling_tree_cap ( + -- cap_id exists only to be able to take advantage of `ON CONFLICT` + -- upsert functionality; the table will only ever contain one row + cap_id INTEGER PRIMARY KEY, cap_data BLOB NOT NULL )", "CREATE TABLE sapling_tree_checkpoint_marks_removed ( diff --git a/zcash_client_sqlite/src/wallet/init/migrations/shardtree_support.rs b/zcash_client_sqlite/src/wallet/init/migrations/shardtree_support.rs index e16c36c8c..f22b03c20 100644 --- a/zcash_client_sqlite/src/wallet/init/migrations/shardtree_support.rs +++ b/zcash_client_sqlite/src/wallet/init/migrations/shardtree_support.rs @@ -69,6 +69,9 @@ impl RusqliteMigration for Migration { CONSTRAINT root_unique UNIQUE (root_hash) ); CREATE TABLE sapling_tree_cap ( + -- cap_id exists only to be able to take advantage of `ON CONFLICT` + -- upsert functionality; the table will only ever contain one row + cap_id INTEGER PRIMARY KEY, cap_data BLOB NOT NULL );", )?; diff --git a/zcash_client_sqlite/src/wallet/sapling/commitment_tree.rs b/zcash_client_sqlite/src/wallet/sapling/commitment_tree.rs index 80c800fe6..912831fcf 100644 --- a/zcash_client_sqlite/src/wallet/sapling/commitment_tree.rs +++ b/zcash_client_sqlite/src/wallet/sapling/commitment_tree.rs @@ -1,10 +1,14 @@ use either::Either; -use incrementalmerkletree::Address; -use rusqlite::{self, named_params, OptionalExtension}; -use shardtree::{Checkpoint, LocatedPrunableTree, PrunableTree, ShardStore}; +use incrementalmerkletree::{Address, Position}; +use rusqlite::{self, named_params, Connection, OptionalExtension}; +use shardtree::{Checkpoint, LocatedPrunableTree, PrunableTree, ShardStore, TreeState}; -use std::io::{self, Cursor}; +use std::{ + collections::BTreeSet, + io::{self, Cursor}, + ops::Deref, +}; use zcash_primitives::{consensus::BlockHeight, merkle_tree::HashSer, sapling}; @@ -48,16 +52,16 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> { todo!() } - fn truncate(&mut self, _from: Address) -> Result<(), Self::Error> { - todo!() + fn truncate(&mut self, from: Address) -> Result<(), Self::Error> { + truncate(self.conn, from) } fn get_cap(&self) -> Result, Self::Error> { - todo!() + get_cap(self.conn) } - fn put_cap(&mut self, _cap: PrunableTree) -> Result<(), Self::Error> { - todo!() + fn put_cap(&mut self, cap: PrunableTree) -> Result<(), Self::Error> { + put_cap(self.conn, cap) } fn min_checkpoint_id(&self) -> Result, Self::Error> { @@ -89,9 +93,9 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> { fn get_checkpoint( &self, - _checkpoint_id: &Self::CheckpointId, + checkpoint_id: &Self::CheckpointId, ) -> Result, Self::Error> { - todo!() + get_checkpoint(self.conn, *checkpoint_id) } fn with_checkpoints(&mut self, _limit: usize, _callback: F) -> Result<(), Self::Error> @@ -103,27 +107,24 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> { fn update_checkpoint_with( &mut self, - _checkpoint_id: &Self::CheckpointId, - _update: F, + checkpoint_id: &Self::CheckpointId, + update: F, ) -> Result where F: Fn(&mut Checkpoint) -> Result<(), Self::Error>, { - todo!() + update_checkpoint_with(self.conn, *checkpoint_id, update) } - fn remove_checkpoint( - &mut self, - _checkpoint_id: &Self::CheckpointId, - ) -> Result<(), Self::Error> { - todo!() + fn remove_checkpoint(&mut self, checkpoint_id: &Self::CheckpointId) -> Result<(), Self::Error> { + remove_checkpoint(self.conn, *checkpoint_id) } fn truncate_checkpoints( &mut self, - _checkpoint_id: &Self::CheckpointId, + checkpoint_id: &Self::CheckpointId, ) -> Result<(), Self::Error> { - todo!() + truncate_checkpoints(self.conn, *checkpoint_id) } } @@ -134,7 +135,7 @@ pub(crate) fn get_shard( shard_root: Address, ) -> Result>, Error> { conn.query_row( - "SELECT shard_data + "SELECT shard_data FROM sapling_tree_shards WHERE shard_index = :shard_index", named_params![":shard_index": shard_root.index()], @@ -188,6 +189,47 @@ pub(crate) fn put_shard( Ok(()) } +pub(crate) fn truncate(conn: &rusqlite::Transaction<'_>, from: Address) -> Result<(), Error> { + conn.execute( + "DELETE FROM sapling_tree_shards WHERE shard_index >= ?", + [from.index()], + ) + .map_err(Either::Right) + .map(|_| ()) +} + +pub(crate) fn get_cap(conn: &rusqlite::Connection) -> Result, Error> { + conn.query_row("SELECT cap_data FROM sapling_tree_cap", [], |row| { + row.get::<_, Vec>(0) + }) + .optional() + .map_err(Either::Right)? + .map_or_else( + || Ok(PrunableTree::empty()), + |cap_data| read_shard(&mut Cursor::new(cap_data)).map_err(Either::Left), + ) +} + +pub(crate) fn put_cap( + conn: &rusqlite::Transaction<'_>, + cap: PrunableTree, +) -> Result<(), Error> { + let mut stmt = conn + .prepare_cached( + "INSERT INTO sapling_tree_cap (cap_id, cap_data) + VALUES (0, :cap_data) + ON CONFLICT (cap_id) DO UPDATE + SET cap_data = :cap_data", + ) + .map_err(Either::Right)?; + + let mut cap_data = vec![]; + write_shard_v1(&mut cap_data, &cap).map_err(Either::Left)?; + stmt.execute([cap_data]).map_err(Either::Right)?; + + Ok(()) +} + pub(crate) fn add_checkpoint( conn: &rusqlite::Transaction<'_>, checkpoint_id: BlockHeight, @@ -214,3 +256,98 @@ pub(crate) fn checkpoint_count(conn: &rusqlite::Connection) -> Result>( + conn: &C, + checkpoint_id: BlockHeight, +) -> Result, Either> { + let checkpoint_position = conn + .query_row( + "SELECT position + FROM sapling_tree_checkpoints + WHERE checkpoint_id = ?", + [u32::from(checkpoint_id)], + |row| { + row.get::<_, Option>(0) + .map(|opt| opt.map(Position::from)) + }, + ) + .optional() + .map_err(Either::Right)?; + + let mut marks_removed = BTreeSet::new(); + let mut stmt = conn + .prepare_cached( + "SELECT mark_removed_position + FROM sapling_tree_checkpoint_marks_removed + WHERE checkpoint_id = ?", + ) + .map_err(Either::Right)?; + let mut mark_removed_rows = stmt + .query([u32::from(checkpoint_id)]) + .map_err(Either::Right)?; + + while let Some(row) = mark_removed_rows.next().map_err(Either::Right)? { + marks_removed.insert( + row.get::<_, u64>(0) + .map(Position::from) + .map_err(Either::Right)?, + ); + } + + Ok(checkpoint_position.map(|pos_opt| { + Checkpoint::from_parts( + pos_opt.map_or(TreeState::Empty, TreeState::AtPosition), + marks_removed, + ) + })) +} + +pub(crate) fn update_checkpoint_with( + conn: &rusqlite::Transaction<'_>, + checkpoint_id: BlockHeight, + update: F, +) -> Result +where + F: Fn(&mut Checkpoint) -> Result<(), Error>, +{ + if let Some(mut c) = get_checkpoint(conn, checkpoint_id)? { + update(&mut c)?; + remove_checkpoint(conn, checkpoint_id)?; + add_checkpoint(conn, checkpoint_id, c)?; + Ok(true) + } else { + Ok(false) + } +} + +pub(crate) fn remove_checkpoint( + conn: &rusqlite::Transaction<'_>, + checkpoint_id: BlockHeight, +) -> Result<(), Error> { + conn.execute( + "DELETE FROM sapling_tree_checkpoints WHERE checkpoint_id = ?", + [u32::from(checkpoint_id)], + ) + .map_err(Either::Right)?; + + Ok(()) +} + +pub(crate) fn truncate_checkpoints( + conn: &rusqlite::Transaction<'_>, + checkpoint_id: BlockHeight, +) -> Result<(), Error> { + conn.execute( + "DELETE FROM sapling_tree_checkpoints WHERE checkpoint_id >= ?", + [u32::from(checkpoint_id)], + ) + .map_err(Either::Right)?; + + conn.execute( + "DELETE FROM sapling_tree_checkpoint_marks_removed WHERE checkpoint_id >= ?", + [u32::from(checkpoint_id)], + ) + .map_err(Either::Right)?; + Ok(()) +}