diff --git a/zcash_client_sqlite/Cargo.toml b/zcash_client_sqlite/Cargo.toml index c58dba2a7..80a200026 100644 --- a/zcash_client_sqlite/Cargo.toml +++ b/zcash_client_sqlite/Cargo.toml @@ -47,6 +47,7 @@ uuid = "1.1" [dev-dependencies] assert_matches = "1.5" +incrementalmerkletree = { version = "0.4", features = ["legacy-api", "test-dependencies"] } proptest = "1.0.0" rand_core = "0.6" regex = "1.4" @@ -59,6 +60,7 @@ zcash_address = { version = "0.3", path = "../components/zcash_address", feature [features] mainnet = [] test-dependencies = [ + "incrementalmerkletree/test-dependencies", "zcash_primitives/test-dependencies", "zcash_client_backend/test-dependencies", ] diff --git a/zcash_client_sqlite/src/chain.rs b/zcash_client_sqlite/src/chain.rs index fc9e8d09f..81a0e028a 100644 --- a/zcash_client_sqlite/src/chain.rs +++ b/zcash_client_sqlite/src/chain.rs @@ -299,7 +299,7 @@ mod tests { init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); // Add an account to the wallet - let (dfvk, _taddr) = init_test_accounts_table(&db_data); + let (dfvk, _taddr) = init_test_accounts_table(&mut db_data); // Empty chain should return None assert_matches!(db_data.get_max_height_hash(), Ok(None)); @@ -328,8 +328,7 @@ mod tests { assert_matches!(validate_chain_result, Ok(())); // Scan the cache - let mut db_write = db_data.get_update_ops().unwrap(); - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Data-only chain should be valid validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap(); @@ -348,7 +347,7 @@ mod tests { validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap(); // Scan the cache again - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Data-only chain should be valid validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap(); @@ -365,7 +364,7 @@ mod tests { init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); // Add an account to the wallet - let (dfvk, _taddr) = init_test_accounts_table(&db_data); + let (dfvk, _taddr) = init_test_accounts_table(&mut db_data); // Create some fake CompactBlocks let (cb, _) = fake_compact_block( @@ -386,8 +385,7 @@ mod tests { insert_into_cache(&db_cache, &cb2); // Scan the cache - let mut db_write = db_data.get_update_ops().unwrap(); - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Data-only chain should be valid validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap(); @@ -427,7 +425,7 @@ mod tests { init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); // Add an account to the wallet - let (dfvk, _taddr) = init_test_accounts_table(&db_data); + let (dfvk, _taddr) = init_test_accounts_table(&mut db_data); // Create some fake CompactBlocks let (cb, _) = fake_compact_block( @@ -448,8 +446,7 @@ mod tests { insert_into_cache(&db_cache, &cb2); // Scan the cache - let mut db_write = db_data.get_update_ops().unwrap(); - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Data-only chain should be valid validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap(); @@ -489,11 +486,11 @@ mod tests { init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); // Add an account to the wallet - let (dfvk, _taddr) = init_test_accounts_table(&db_data); + let (dfvk, _taddr) = init_test_accounts_table(&mut db_data); // Account balance should be zero assert_eq!( - get_balance(&db_data, AccountId::from(0)).unwrap(), + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), Amount::zero() ); @@ -519,36 +516,46 @@ mod tests { insert_into_cache(&db_cache, &cb2); // Scan the cache - let mut db_write = db_data.get_update_ops().unwrap(); - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Account balance should reflect both received notes assert_eq!( - get_balance(&db_data, AccountId::from(0)).unwrap(), + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), (value + value2).unwrap() ); // "Rewind" to height of last scanned block - truncate_to_height(&db_data, sapling_activation_height() + 1).unwrap(); + db_data + .transactionally(|wdb| { + truncate_to_height(&wdb.conn.0, &wdb.params, sapling_activation_height() + 1) + }) + .unwrap(); // Account balance should be unaltered assert_eq!( - get_balance(&db_data, AccountId::from(0)).unwrap(), + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), (value + value2).unwrap() ); // Rewind so that one block is dropped - truncate_to_height(&db_data, sapling_activation_height()).unwrap(); + db_data + .transactionally(|wdb| { + truncate_to_height(&wdb.conn.0, &wdb.params, sapling_activation_height()) + }) + .unwrap(); // Account balance should only contain the first received note - assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value); + assert_eq!( + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), + value + ); // Scan the cache again - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Account balance should again reflect both received notes assert_eq!( - get_balance(&db_data, AccountId::from(0)).unwrap(), + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), (value + value2).unwrap() ); } @@ -564,7 +571,7 @@ mod tests { init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); // Add an account to the wallet - let (dfvk, _taddr) = init_test_accounts_table(&db_data); + let (dfvk, _taddr) = init_test_accounts_table(&mut db_data); // Create a block with height SAPLING_ACTIVATION_HEIGHT let value = Amount::from_u64(50000).unwrap(); @@ -576,9 +583,11 @@ mod tests { value, ); insert_into_cache(&db_cache, &cb1); - let mut db_write = db_data.get_update_ops().unwrap(); - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); - assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); + assert_eq!( + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), + value + ); // We cannot scan a block of height SAPLING_ACTIVATION_HEIGHT + 2 next let (cb2, _) = fake_compact_block( @@ -596,7 +605,7 @@ mod tests { value, ); insert_into_cache(&db_cache, &cb3); - match scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None) { + match scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None) { Err(Error::Chain(e)) => { assert_matches!( e.cause(), @@ -609,9 +618,9 @@ mod tests { // If we add a block of height SAPLING_ACTIVATION_HEIGHT + 1, we can now scan both insert_into_cache(&db_cache, &cb2); - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); assert_eq!( - get_balance(&db_data, AccountId::from(0)).unwrap(), + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), Amount::from_u64(150_000).unwrap() ); } @@ -627,11 +636,11 @@ mod tests { init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); // Add an account to the wallet - let (dfvk, _taddr) = init_test_accounts_table(&db_data); + let (dfvk, _taddr) = init_test_accounts_table(&mut db_data); // Account balance should be zero assert_eq!( - get_balance(&db_data, AccountId::from(0)).unwrap(), + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), Amount::zero() ); @@ -647,11 +656,13 @@ mod tests { insert_into_cache(&db_cache, &cb); // Scan the cache - let mut db_write = db_data.get_update_ops().unwrap(); - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Account balance should reflect the received note - assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value); + assert_eq!( + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), + value + ); // Create a second fake CompactBlock sending more value to the address let value2 = Amount::from_u64(7).unwrap(); @@ -665,11 +676,11 @@ mod tests { insert_into_cache(&db_cache, &cb2); // Scan the cache again - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Account balance should reflect both received notes assert_eq!( - get_balance(&db_data, AccountId::from(0)).unwrap(), + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), (value + value2).unwrap() ); } @@ -685,11 +696,11 @@ mod tests { init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); // Add an account to the wallet - let (dfvk, _taddr) = init_test_accounts_table(&db_data); + let (dfvk, _taddr) = init_test_accounts_table(&mut db_data); // Account balance should be zero assert_eq!( - get_balance(&db_data, AccountId::from(0)).unwrap(), + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), Amount::zero() ); @@ -705,11 +716,13 @@ mod tests { insert_into_cache(&db_cache, &cb); // Scan the cache - let mut db_write = db_data.get_update_ops().unwrap(); - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Account balance should reflect the received note - assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value); + assert_eq!( + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), + value + ); // Create a second fake CompactBlock spending value from the address let extsk2 = ExtendedSpendingKey::master(&[0]); @@ -728,11 +741,11 @@ mod tests { ); // Scan the cache again - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Account balance should equal the change assert_eq!( - get_balance(&db_data, AccountId::from(0)).unwrap(), + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), (value - value2).unwrap() ); } diff --git a/zcash_client_sqlite/src/lib.rs b/zcash_client_sqlite/src/lib.rs index 60cc142ae..1835b27a9 100644 --- a/zcash_client_sqlite/src/lib.rs +++ b/zcash_client_sqlite/src/lib.rs @@ -32,18 +32,16 @@ // Catch documentation errors caused by code changes. #![deny(rustdoc::broken_intra_doc_links)] -use rusqlite::Connection; +use rusqlite::{self, Connection}; use secrecy::{ExposeSecret, SecretVec}; -use std::collections::HashMap; -use std::fmt; -use std::path::Path; +use std::{borrow::Borrow, collections::HashMap, convert::AsRef, fmt, path::Path}; use zcash_primitives::{ block::BlockHash, consensus::{self, BlockHeight}, legacy::TransparentAddress, memo::{Memo, MemoBytes}, - sapling::{self}, + sapling, transaction::{ components::{amount::Amount, OutPoint}, Transaction, TxId, @@ -72,9 +70,6 @@ use { std::{fs, io}, }; -mod prepared; -pub use prepared::DataConnStmtCache; - pub mod chain; pub mod error; pub mod wallet; @@ -107,12 +102,21 @@ impl fmt::Display for NoteId { pub struct UtxoId(pub i64); /// A wrapper for the SQLite connection to the wallet database. -pub struct WalletDb

{ - conn: Connection, +pub struct WalletDb { + conn: C, params: P, } -impl WalletDb

{ +/// A wrapper for a SQLite transaction affecting the wallet database. +pub struct SqlTransaction<'conn>(pub(crate) rusqlite::Transaction<'conn>); + +impl Borrow for SqlTransaction<'_> { + fn borrow(&self) -> &rusqlite::Connection { + &self.0 + } +} + +impl WalletDb { /// Construct a connection to the wallet database stored at the specified path. pub fn for_path>(path: F, params: P) -> Result { Connection::open(path).and_then(move |conn| { @@ -121,53 +125,60 @@ impl WalletDb

{ }) } - /// Given a wallet database connection, obtain a handle for the write operations - /// for that database. This operation may eagerly initialize and cache sqlite - /// prepared statements that are used in write operations. - pub fn get_update_ops(&self) -> Result, SqliteClientError> { - DataConnStmtCache::new(self) + pub fn transactionally(&mut self, f: F) -> Result + where + F: FnOnce(&WalletDb, P>) -> Result, + { + let wdb = WalletDb { + conn: SqlTransaction(self.conn.transaction()?), + params: self.params.clone(), + }; + let result = f(&wdb)?; + wdb.conn.0.commit()?; + Ok(result) } } -impl WalletRead for WalletDb

{ +impl, P: consensus::Parameters> WalletRead for WalletDb { type Error = SqliteClientError; type NoteRef = NoteId; type TxRef = i64; fn block_height_extrema(&self) -> Result, Self::Error> { - wallet::block_height_extrema(self).map_err(SqliteClientError::from) + wallet::block_height_extrema(self.conn.borrow()).map_err(SqliteClientError::from) } fn get_min_unspent_height(&self) -> Result, Self::Error> { - wallet::get_min_unspent_height(self).map_err(SqliteClientError::from) + wallet::get_min_unspent_height(self.conn.borrow()).map_err(SqliteClientError::from) } fn get_block_hash(&self, block_height: BlockHeight) -> Result, Self::Error> { - wallet::get_block_hash(self, block_height).map_err(SqliteClientError::from) + wallet::get_block_hash(self.conn.borrow(), block_height).map_err(SqliteClientError::from) } fn get_tx_height(&self, txid: TxId) -> Result, Self::Error> { - wallet::get_tx_height(self, txid).map_err(SqliteClientError::from) + wallet::get_tx_height(self.conn.borrow(), txid).map_err(SqliteClientError::from) } fn get_unified_full_viewing_keys( &self, ) -> Result, Self::Error> { - wallet::get_unified_full_viewing_keys(self) + wallet::get_unified_full_viewing_keys(self.conn.borrow(), &self.params) } fn get_account_for_ufvk( &self, ufvk: &UnifiedFullViewingKey, ) -> Result, Self::Error> { - wallet::get_account_for_ufvk(self, ufvk) + wallet::get_account_for_ufvk(self.conn.borrow(), &self.params, ufvk) } fn get_current_address( &self, account: AccountId, ) -> Result, Self::Error> { - wallet::get_current_address(self, account).map(|res| res.map(|(addr, _)| addr)) + wallet::get_current_address(self.conn.borrow(), &self.params, account) + .map(|res| res.map(|(addr, _)| addr)) } fn is_valid_account_extfvk( @@ -175,7 +186,7 @@ impl WalletRead for WalletDb

{ account: AccountId, extfvk: &ExtendedFullViewingKey, ) -> Result { - wallet::is_valid_account_extfvk(self, account, extfvk) + wallet::is_valid_account_extfvk(self.conn.borrow(), &self.params, account, extfvk) } fn get_balance_at( @@ -183,17 +194,19 @@ impl WalletRead for WalletDb

{ account: AccountId, anchor_height: BlockHeight, ) -> Result { - wallet::get_balance_at(self, account, anchor_height) + wallet::get_balance_at(self.conn.borrow(), account, anchor_height) } fn get_transaction(&self, id_tx: i64) -> Result { - wallet::get_transaction(self, id_tx) + wallet::get_transaction(self.conn.borrow(), &self.params, id_tx) } fn get_memo(&self, id_note: Self::NoteRef) -> Result, Self::Error> { match id_note { - NoteId::SentNoteId(id_note) => wallet::get_sent_memo(self, id_note), - NoteId::ReceivedNoteId(id_note) => wallet::get_received_memo(self, id_note), + NoteId::SentNoteId(id_note) => wallet::get_sent_memo(self.conn.borrow(), id_note), + NoteId::ReceivedNoteId(id_note) => { + wallet::get_received_memo(self.conn.borrow(), id_note) + } } } @@ -201,7 +214,7 @@ impl WalletRead for WalletDb

{ &self, block_height: BlockHeight, ) -> Result, Self::Error> { - wallet::sapling::get_sapling_commitment_tree(self, block_height) + wallet::sapling::get_sapling_commitment_tree(self.conn.borrow(), block_height) } #[allow(clippy::type_complexity)] @@ -209,7 +222,7 @@ impl WalletRead for WalletDb

{ &self, block_height: BlockHeight, ) -> Result, Self::Error> { - wallet::sapling::get_sapling_witnesses(self, block_height) + wallet::sapling::get_sapling_witnesses(self.conn.borrow(), block_height) } fn get_sapling_nullifiers( @@ -217,8 +230,8 @@ impl WalletRead for WalletDb

{ query: data_api::NullifierQuery, ) -> Result, Self::Error> { match query { - NullifierQuery::Unspent => wallet::sapling::get_sapling_nullifiers(self), - NullifierQuery::All => wallet::sapling::get_all_sapling_nullifiers(self), + NullifierQuery::Unspent => wallet::sapling::get_sapling_nullifiers(self.conn.borrow()), + NullifierQuery::All => wallet::sapling::get_all_sapling_nullifiers(self.conn.borrow()), } } @@ -228,7 +241,12 @@ impl WalletRead for WalletDb

{ anchor_height: BlockHeight, exclude: &[Self::NoteRef], ) -> Result>, Self::Error> { - wallet::sapling::get_spendable_sapling_notes(self, account, anchor_height, exclude) + wallet::sapling::get_spendable_sapling_notes( + self.conn.borrow(), + account, + anchor_height, + exclude, + ) } fn select_spendable_sapling_notes( @@ -239,7 +257,7 @@ impl WalletRead for WalletDb

{ exclude: &[Self::NoteRef], ) -> Result>, Self::Error> { wallet::sapling::select_spendable_sapling_notes( - self, + self.conn.borrow(), account, target_value, anchor_height, @@ -252,7 +270,7 @@ impl WalletRead for WalletDb

{ _account: AccountId, ) -> Result, Self::Error> { #[cfg(feature = "transparent-inputs")] - return wallet::get_transparent_receivers(&self.params, &self.conn, _account); + return wallet::get_transparent_receivers(self.conn.borrow(), &self.params, _account); #[cfg(not(feature = "transparent-inputs"))] panic!( @@ -267,7 +285,13 @@ impl WalletRead for WalletDb

{ _exclude: &[OutPoint], ) -> Result, Self::Error> { #[cfg(feature = "transparent-inputs")] - return wallet::get_unspent_transparent_outputs(self, _address, _max_height, _exclude); + return wallet::get_unspent_transparent_outputs( + self.conn.borrow(), + &self.params, + _address, + _max_height, + _exclude, + ); #[cfg(not(feature = "transparent-inputs"))] panic!( @@ -281,7 +305,12 @@ impl WalletRead for WalletDb

{ _max_height: BlockHeight, ) -> Result, Self::Error> { #[cfg(feature = "transparent-inputs")] - return wallet::get_transparent_balances(self, _account, _max_height); + return wallet::get_transparent_balances( + self.conn.borrow(), + &self.params, + _account, + _max_height, + ); #[cfg(not(feature = "transparent-inputs"))] panic!( @@ -290,177 +319,15 @@ impl WalletRead for WalletDb

{ } } -impl<'a, P: consensus::Parameters> WalletRead for DataConnStmtCache<'a, P> { - type Error = SqliteClientError; - type NoteRef = NoteId; - type TxRef = i64; - - fn block_height_extrema(&self) -> Result, Self::Error> { - self.wallet_db.block_height_extrema() - } - - fn get_min_unspent_height(&self) -> Result, Self::Error> { - self.wallet_db.get_min_unspent_height() - } - - fn get_block_hash(&self, block_height: BlockHeight) -> Result, Self::Error> { - self.wallet_db.get_block_hash(block_height) - } - - fn get_tx_height(&self, txid: TxId) -> Result, Self::Error> { - self.wallet_db.get_tx_height(txid) - } - - fn get_unified_full_viewing_keys( - &self, - ) -> Result, Self::Error> { - self.wallet_db.get_unified_full_viewing_keys() - } - - fn get_account_for_ufvk( - &self, - ufvk: &UnifiedFullViewingKey, - ) -> Result, Self::Error> { - self.wallet_db.get_account_for_ufvk(ufvk) - } - - fn get_current_address( - &self, - account: AccountId, - ) -> Result, Self::Error> { - self.wallet_db.get_current_address(account) - } - - fn is_valid_account_extfvk( - &self, - account: AccountId, - extfvk: &ExtendedFullViewingKey, - ) -> Result { - self.wallet_db.is_valid_account_extfvk(account, extfvk) - } - - fn get_balance_at( - &self, - account: AccountId, - anchor_height: BlockHeight, - ) -> Result { - self.wallet_db.get_balance_at(account, anchor_height) - } - - fn get_transaction(&self, id_tx: i64) -> Result { - self.wallet_db.get_transaction(id_tx) - } - - fn get_memo(&self, id_note: Self::NoteRef) -> Result, Self::Error> { - self.wallet_db.get_memo(id_note) - } - - fn get_commitment_tree( - &self, - block_height: BlockHeight, - ) -> Result, Self::Error> { - self.wallet_db.get_commitment_tree(block_height) - } - - #[allow(clippy::type_complexity)] - fn get_witnesses( - &self, - block_height: BlockHeight, - ) -> Result, Self::Error> { - self.wallet_db.get_witnesses(block_height) - } - - fn get_sapling_nullifiers( - &self, - query: data_api::NullifierQuery, - ) -> Result, Self::Error> { - self.wallet_db.get_sapling_nullifiers(query) - } - - fn get_spendable_sapling_notes( - &self, - account: AccountId, - anchor_height: BlockHeight, - exclude: &[Self::NoteRef], - ) -> Result>, Self::Error> { - self.wallet_db - .get_spendable_sapling_notes(account, anchor_height, exclude) - } - - fn select_spendable_sapling_notes( - &self, - account: AccountId, - target_value: Amount, - anchor_height: BlockHeight, - exclude: &[Self::NoteRef], - ) -> Result>, Self::Error> { - self.wallet_db - .select_spendable_sapling_notes(account, target_value, anchor_height, exclude) - } - - fn get_transparent_receivers( - &self, - account: AccountId, - ) -> Result, Self::Error> { - self.wallet_db.get_transparent_receivers(account) - } - - fn get_unspent_transparent_outputs( - &self, - address: &TransparentAddress, - max_height: BlockHeight, - exclude: &[OutPoint], - ) -> Result, Self::Error> { - self.wallet_db - .get_unspent_transparent_outputs(address, max_height, exclude) - } - - fn get_transparent_balances( - &self, - account: AccountId, - max_height: BlockHeight, - ) -> Result, Self::Error> { - self.wallet_db.get_transparent_balances(account, max_height) - } -} - -impl<'a, P: consensus::Parameters> DataConnStmtCache<'a, P> { - fn transactionally(&mut self, f: F) -> Result - where - F: FnOnce(&mut Self) -> Result, - { - self.wallet_db.conn.execute("BEGIN IMMEDIATE", [])?; - match f(self) { - Ok(result) => { - self.wallet_db.conn.execute("COMMIT", [])?; - Ok(result) - } - Err(error) => { - match self.wallet_db.conn.execute("ROLLBACK", []) { - Ok(_) => Err(error), - Err(e) => - // Panicking here is probably the right thing to do, because it - // means the database is corrupt. - panic!( - "Rollback failed with error {} while attempting to recover from error {}; database is likely corrupt.", - e, - error - ) - } - } - } - } -} - -impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> { +impl WalletWrite for WalletDb { type UtxoRef = UtxoId; fn create_account( &mut self, seed: &SecretVec, ) -> Result<(AccountId, UnifiedSpendingKey), Self::Error> { - self.transactionally(|stmts| { - let account = wallet::get_max_account_id(stmts.wallet_db)? + self.transactionally(|wdb| { + let account = wallet::get_max_account_id(&wdb.conn.0)? .map(|a| AccountId::from(u32::from(a) + 1)) .unwrap_or_else(|| AccountId::from(0)); @@ -468,15 +335,11 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> { return Err(SqliteClientError::AccountIdOutOfRange); } - let usk = UnifiedSpendingKey::from_seed( - &stmts.wallet_db.params, - seed.expose_secret(), - account, - ) - .map_err(|_| SqliteClientError::KeyDerivationError(account))?; + let usk = UnifiedSpendingKey::from_seed(&wdb.params, seed.expose_secret(), account) + .map_err(|_| SqliteClientError::KeyDerivationError(account))?; let ufvk = usk.to_unified_full_viewing_key(); - wallet::add_account(stmts.wallet_db, account, &ufvk)?; + wallet::add_account(&wdb.conn.0, &wdb.params, account, &ufvk)?; Ok((account, usk)) }) @@ -486,28 +349,37 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> { &mut self, account: AccountId, ) -> Result, Self::Error> { - match self.get_unified_full_viewing_keys()?.get(&account) { - Some(ufvk) => { - let search_from = match wallet::get_current_address(self.wallet_db, account)? { - Some((_, mut last_diversifier_index)) => { - last_diversifier_index - .increment() - .map_err(|_| SqliteClientError::DiversifierIndexOutOfRange)?; - last_diversifier_index - } - None => DiversifierIndex::default(), - }; + self.transactionally( + |wdb| match wdb.get_unified_full_viewing_keys()?.get(&account) { + Some(ufvk) => { + let search_from = + match wallet::get_current_address(&wdb.conn.0, &wdb.params, account)? { + Some((_, mut last_diversifier_index)) => { + last_diversifier_index + .increment() + .map_err(|_| SqliteClientError::DiversifierIndexOutOfRange)?; + last_diversifier_index + } + None => DiversifierIndex::default(), + }; - let (addr, diversifier_index) = ufvk - .find_address(search_from) - .ok_or(SqliteClientError::DiversifierIndexOutOfRange)?; + let (addr, diversifier_index) = ufvk + .find_address(search_from) + .ok_or(SqliteClientError::DiversifierIndexOutOfRange)?; - self.stmt_insert_address(account, diversifier_index, &addr)?; + wallet::insert_address( + &wdb.conn.0, + &wdb.params, + account, + diversifier_index, + &addr, + )?; - Ok(Some(addr)) - } - None => Ok(None), - } + Ok(Some(addr)) + } + None => Ok(None), + }, + ) } #[tracing::instrument(skip_all, fields(height = u32::from(block.block_height)))] @@ -517,11 +389,10 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> { block: &PrunedBlock, updated_witnesses: &[(Self::NoteRef, sapling::IncrementalWitness)], ) -> Result, Self::Error> { - // database updates for each block are transactional - self.transactionally(|up| { + self.transactionally(|wdb| { // Insert the block into the database. wallet::insert_block( - up, + &wdb.conn.0, block.block_height, block.block_hash, block.block_time, @@ -530,15 +401,16 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> { let mut new_witnesses = vec![]; for tx in block.transactions { - let tx_row = wallet::put_tx_meta(up, tx, block.block_height)?; + let tx_row = wallet::put_tx_meta(&wdb.conn.0, tx, block.block_height)?; // Mark notes as spent and remove them from the scanning cache for spend in &tx.sapling_spends { - wallet::sapling::mark_sapling_note_spent(up, tx_row, spend.nf())?; + wallet::sapling::mark_sapling_note_spent(&wdb.conn.0, tx_row, spend.nf())?; } for output in &tx.sapling_outputs { - let received_note_id = wallet::sapling::put_received_note(up, output, tx_row)?; + let received_note_id = + wallet::sapling::put_received_note(&wdb.conn.0, output, tx_row)?; // Save witness for note. new_witnesses.push((received_note_id, output.witness().clone())); @@ -549,17 +421,22 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> { for (received_note_id, witness) in updated_witnesses.iter().chain(new_witnesses.iter()) { if let NoteId::ReceivedNoteId(rnid) = *received_note_id { - wallet::sapling::insert_witness(up, rnid, witness, block.block_height)?; + wallet::sapling::insert_witness( + &wdb.conn.0, + rnid, + witness, + block.block_height, + )?; } else { return Err(SqliteClientError::InvalidNoteId); } } // Prune the stored witnesses (we only expect rollbacks of at most PRUNING_HEIGHT blocks). - wallet::prune_witnesses(up, block.block_height - PRUNING_HEIGHT)?; + wallet::prune_witnesses(&wdb.conn.0, block.block_height - PRUNING_HEIGHT)?; // Update now-expired transactions that didn't get mined. - wallet::update_expired_notes(up, block.block_height)?; + wallet::update_expired_notes(&wdb.conn.0, block.block_height)?; Ok(new_witnesses) }) @@ -569,91 +446,114 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> { &mut self, d_tx: DecryptedTransaction, ) -> Result { - self.transactionally(|up| { - let tx_ref = wallet::put_tx_data(up, d_tx.tx, None, None)?; + self.transactionally(|wdb| { + let tx_ref = wallet::put_tx_data(&wdb.conn.0, d_tx.tx, None, None)?; - let mut spending_account_id: Option = None; - for output in d_tx.sapling_outputs { - match output.transfer_type { - TransferType::Outgoing | TransferType::WalletInternal => { - let recipient = if output.transfer_type == TransferType::Outgoing { - Recipient::Sapling(output.note.recipient()) - } else { - Recipient::InternalAccount(output.account, PoolType::Sapling) - }; + let mut spending_account_id: Option = None; + for output in d_tx.sapling_outputs { + match output.transfer_type { + TransferType::Outgoing | TransferType::WalletInternal => { + let recipient = if output.transfer_type == TransferType::Outgoing { + Recipient::Sapling(output.note.recipient()) + } else { + Recipient::InternalAccount(output.account, PoolType::Sapling) + }; - wallet::put_sent_output( - up, - output.account, - tx_ref, - output.index, - &recipient, - Amount::from_u64(output.note.value().inner()).map_err(|_| - SqliteClientError::CorruptedData("Note value is not a valid Zcash amount.".to_string()))?, - Some(&output.memo), - )?; + wallet::put_sent_output( + &wdb.conn.0, + &wdb.params, + output.account, + tx_ref, + output.index, + &recipient, + Amount::from_u64(output.note.value().inner()).map_err(|_| { + SqliteClientError::CorruptedData( + "Note value is not a valid Zcash amount.".to_string(), + ) + })?, + Some(&output.memo), + )?; - if matches!(recipient, Recipient::InternalAccount(_, _)) { - wallet::sapling::put_received_note(up, output, tx_ref)?; - } + if matches!(recipient, Recipient::InternalAccount(_, _)) { + wallet::sapling::put_received_note(&wdb.conn.0, output, tx_ref)?; } - TransferType::Incoming => { - match spending_account_id { - Some(id) => - if id != output.account { - panic!("Unable to determine a unique account identifier for z->t spend."); - } - None => { - spending_account_id = Some(output.account); + } + TransferType::Incoming => { + match spending_account_id { + Some(id) => { + if id != output.account { + panic!("Unable to determine a unique account identifier for z->t spend."); } } - - wallet::sapling::put_received_note(up, output, tx_ref)?; - } - } - } - - // If any of the utxos spent in the transaction are ours, mark them as spent. - #[cfg(feature = "transparent-inputs")] - for txin in d_tx.tx.transparent_bundle().iter().flat_map(|b| b.vin.iter()) { - wallet::mark_transparent_utxo_spent(up, tx_ref, &txin.prevout)?; - } - - // If we have some transparent outputs: - if !d_tx.tx.transparent_bundle().iter().any(|b| b.vout.is_empty()) { - let nullifiers = self.wallet_db.get_sapling_nullifiers(data_api::NullifierQuery::All)?; - // If the transaction contains shielded spends from our wallet, we will store z->t - // transactions we observe in the same way they would be stored by - // create_spend_to_address. - if let Some((account_id, _)) = nullifiers.iter().find( - |(_, nf)| - d_tx.tx.sapling_bundle().iter().flat_map(|b| b.shielded_spends().iter()) - .any(|input| nf == input.nullifier()) - ) { - for (output_index, txout) in d_tx.tx.transparent_bundle().iter().flat_map(|b| b.vout.iter()).enumerate() { - if let Some(address) = txout.recipient_address() { - wallet::put_sent_output( - up, - *account_id, - tx_ref, - output_index, - &Recipient::Transparent(address), - txout.value, - None - )?; + None => { + spending_account_id = Some(output.account); } } + + wallet::sapling::put_received_note(&wdb.conn.0, output, tx_ref)?; } } - Ok(tx_ref) + } + + // If any of the utxos spent in the transaction are ours, mark them as spent. + #[cfg(feature = "transparent-inputs")] + for txin in d_tx + .tx + .transparent_bundle() + .iter() + .flat_map(|b| b.vin.iter()) + { + wallet::mark_transparent_utxo_spent(&wdb.conn.0, tx_ref, &txin.prevout)?; + } + + // If we have some transparent outputs: + if !d_tx + .tx + .transparent_bundle() + .iter() + .any(|b| b.vout.is_empty()) + { + let nullifiers = wdb.get_sapling_nullifiers(data_api::NullifierQuery::All)?; + // If the transaction contains shielded spends from our wallet, we will store z->t + // transactions we observe in the same way they would be stored by + // create_spend_to_address. + if let Some((account_id, _)) = nullifiers.iter().find(|(_, nf)| { + d_tx.tx + .sapling_bundle() + .iter() + .flat_map(|b| b.shielded_spends().iter()) + .any(|input| nf == input.nullifier()) + }) { + for (output_index, txout) in d_tx + .tx + .transparent_bundle() + .iter() + .flat_map(|b| b.vout.iter()) + .enumerate() + { + if let Some(address) = txout.recipient_address() { + wallet::put_sent_output( + &wdb.conn.0, + &wdb.params, + *account_id, + tx_ref, + output_index, + &Recipient::Transparent(address), + txout.value, + None, + )?; + } + } + } + } + Ok(tx_ref) }) } fn store_sent_tx(&mut self, sent_tx: &SentTransaction) -> Result { - // Update the database atomically, to ensure the result is internally consistent. - self.transactionally(|up| { + self.transactionally(|wdb| { let tx_ref = wallet::put_tx_data( - up, + &wdb.conn.0, sent_tx.tx, Some(sent_tx.fee_amount), Some(sent_tx.created), @@ -669,21 +569,31 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> { // reasonable assumption for a light client such as a mobile phone. if let Some(bundle) = sent_tx.tx.sapling_bundle() { for spend in bundle.shielded_spends() { - wallet::sapling::mark_sapling_note_spent(up, tx_ref, spend.nullifier())?; + wallet::sapling::mark_sapling_note_spent( + &wdb.conn.0, + tx_ref, + spend.nullifier(), + )?; } } #[cfg(feature = "transparent-inputs")] for utxo_outpoint in &sent_tx.utxos_spent { - wallet::mark_transparent_utxo_spent(up, tx_ref, utxo_outpoint)?; + wallet::mark_transparent_utxo_spent(&wdb.conn.0, tx_ref, utxo_outpoint)?; } for output in &sent_tx.outputs { - wallet::insert_sent_output(up, tx_ref, sent_tx.account, output)?; + wallet::insert_sent_output( + &wdb.conn.0, + &wdb.params, + tx_ref, + sent_tx.account, + output, + )?; if let Some((account, note)) = output.sapling_change_to() { wallet::sapling::put_received_note( - up, + &wdb.conn.0, &DecryptedOutput { index: output.output_index(), note: note.clone(), @@ -704,7 +614,9 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> { } fn truncate_to_height(&mut self, block_height: BlockHeight) -> Result<(), Self::Error> { - wallet::truncate_to_height(self.wallet_db, block_height) + self.transactionally(|wdb| { + wallet::truncate_to_height(&wdb.conn.0, &wdb.params, block_height) + }) } fn put_received_transparent_utxo( @@ -712,7 +624,7 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> { _output: &WalletTransparentOutput, ) -> Result { #[cfg(feature = "transparent-inputs")] - return wallet::put_received_transparent_utxo(self, _output); + return wallet::put_received_transparent_utxo(&self.conn, &self.params, _output); #[cfg(not(feature = "transparent-inputs"))] panic!( @@ -1056,7 +968,7 @@ mod tests { #[cfg(test)] pub(crate) fn init_test_accounts_table( - db_data: &WalletDb, + db_data: &mut WalletDb, ) -> (DiversifiableFullViewingKey, Option) { let (ufvk, taddr) = init_test_accounts_table_ufvk(db_data); (ufvk.sapling().unwrap().clone(), taddr) @@ -1064,7 +976,7 @@ mod tests { #[cfg(test)] pub(crate) fn init_test_accounts_table_ufvk( - db_data: &WalletDb, + db_data: &mut WalletDb, ) -> (UnifiedFullViewingKey, Option) { let seed = [0u8; 32]; let account = AccountId::from(0); @@ -1291,13 +1203,12 @@ mod tests { let account = AccountId::from(0); init_wallet_db(&mut db_data, None).unwrap(); - let _ = init_test_accounts_table_ufvk(&db_data); + init_test_accounts_table_ufvk(&mut db_data); let current_addr = db_data.get_current_address(account).unwrap(); assert!(current_addr.is_some()); - let mut update_ops = db_data.get_update_ops().unwrap(); - let addr2 = update_ops.get_next_available_address(account).unwrap(); + let addr2 = db_data.get_next_available_address(account).unwrap(); assert!(addr2.is_some()); assert_ne!(current_addr, addr2); @@ -1322,7 +1233,7 @@ mod tests { init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); // Add an account to the wallet. - let (ufvk, taddr) = init_test_accounts_table_ufvk(&db_data); + let (ufvk, taddr) = init_test_accounts_table_ufvk(&mut db_data); let taddr = taddr.unwrap(); let receivers = db_data.get_transparent_receivers(0.into()).unwrap(); diff --git a/zcash_client_sqlite/src/prepared.rs b/zcash_client_sqlite/src/prepared.rs deleted file mode 100644 index da97faa6c..000000000 --- a/zcash_client_sqlite/src/prepared.rs +++ /dev/null @@ -1,812 +0,0 @@ -//! Prepared SQL statements used by the wallet. -//! -//! Some `rusqlite` crate APIs are only available on prepared statements; these are stored -//! inside the [`DataConnStmtCache`]. When adding a new prepared statement: -//! -//! - Add it as a private field of `DataConnStmtCache`. -//! - Build the statement in [`DataConnStmtCache::new`]. -//! - Add a crate-private helper method to `DataConnStmtCache` for running the statement. - -use rusqlite::{named_params, params, Statement, ToSql}; -use zcash_primitives::{ - block::BlockHash, - consensus::{self, BlockHeight}, - memo::MemoBytes, - merkle_tree::{write_commitment_tree, write_incremental_witness}, - sapling::{self, Diversifier, Nullifier}, - transaction::{components::Amount, TxId}, - zip32::{AccountId, DiversifierIndex}, -}; - -use zcash_client_backend::{ - address::UnifiedAddress, - data_api::{PoolType, Recipient}, - encoding::AddressCodec, -}; - -use crate::{error::SqliteClientError, wallet::pool_code, NoteId, WalletDb}; - -#[cfg(feature = "transparent-inputs")] -use { - crate::UtxoId, rusqlite::OptionalExtension, - zcash_client_backend::wallet::WalletTransparentOutput, - zcash_primitives::transaction::components::transparent::OutPoint, -}; - -pub(crate) struct InsertAddress<'a> { - stmt: Statement<'a>, -} - -impl<'a> InsertAddress<'a> { - pub(crate) fn new(conn: &'a rusqlite::Connection) -> Result { - Ok(InsertAddress { - stmt: conn.prepare( - "INSERT INTO addresses ( - account, - diversifier_index_be, - address, - cached_transparent_receiver_address - ) - VALUES ( - :account, - :diversifier_index_be, - :address, - :cached_transparent_receiver_address - )", - )?, - }) - } - - /// Adds the given address and diversifier index to the addresses table. - /// - /// Returns the database row for the newly-inserted address. - pub(crate) fn execute( - &mut self, - params: &P, - account: AccountId, - mut diversifier_index: DiversifierIndex, - address: &UnifiedAddress, - ) -> Result<(), rusqlite::Error> { - // the diversifier index is stored in big-endian order to allow sorting - diversifier_index.0.reverse(); - self.stmt.execute(named_params![ - ":account": &u32::from(account), - ":diversifier_index_be": &&diversifier_index.0[..], - ":address": &address.encode(params), - ":cached_transparent_receiver_address": &address.transparent().map(|r| r.encode(params)), - ])?; - - Ok(()) - } -} - -/// The primary type used to implement [`WalletWrite`] for the SQLite database. -/// -/// A data structure that stores the SQLite prepared statements that are -/// required for the implementation of [`WalletWrite`] against the backing -/// store. -/// -/// [`WalletWrite`]: zcash_client_backend::data_api::WalletWrite -pub struct DataConnStmtCache<'a, P> { - pub(crate) wallet_db: &'a WalletDb

, - stmt_insert_block: Statement<'a>, - - stmt_insert_tx_meta: Statement<'a>, - stmt_update_tx_meta: Statement<'a>, - - stmt_insert_tx_data: Statement<'a>, - stmt_update_tx_data: Statement<'a>, - stmt_select_tx_ref: Statement<'a>, - - stmt_mark_sapling_note_spent: Statement<'a>, - #[cfg(feature = "transparent-inputs")] - stmt_mark_transparent_utxo_spent: Statement<'a>, - - #[cfg(feature = "transparent-inputs")] - stmt_insert_received_transparent_utxo: Statement<'a>, - #[cfg(feature = "transparent-inputs")] - stmt_update_received_transparent_utxo: Statement<'a>, - #[cfg(feature = "transparent-inputs")] - stmt_insert_legacy_transparent_utxo: Statement<'a>, - #[cfg(feature = "transparent-inputs")] - stmt_update_legacy_transparent_utxo: Statement<'a>, - stmt_insert_received_note: Statement<'a>, - stmt_update_received_note: Statement<'a>, - stmt_select_received_note: Statement<'a>, - - stmt_insert_sent_output: Statement<'a>, - stmt_update_sent_output: Statement<'a>, - - stmt_insert_witness: Statement<'a>, - stmt_prune_witnesses: Statement<'a>, - stmt_update_expired: Statement<'a>, - - stmt_insert_address: InsertAddress<'a>, -} - -impl<'a, P> DataConnStmtCache<'a, P> { - pub(crate) fn new(wallet_db: &'a WalletDb

) -> Result { - Ok( - DataConnStmtCache { - wallet_db, - stmt_insert_block: wallet_db.conn.prepare( - "INSERT INTO blocks (height, hash, time, sapling_tree) - VALUES (?, ?, ?, ?)", - )?, - stmt_insert_tx_meta: wallet_db.conn.prepare( - "INSERT INTO transactions (txid, block, tx_index) - VALUES (?, ?, ?)", - )?, - stmt_update_tx_meta: wallet_db.conn.prepare( - "UPDATE transactions - SET block = ?, tx_index = ? WHERE txid = ?", - )?, - stmt_insert_tx_data: wallet_db.conn.prepare( - "INSERT INTO transactions (txid, created, expiry_height, raw, fee) - VALUES (?, ?, ?, ?, ?)", - )?, - stmt_update_tx_data: wallet_db.conn.prepare( - "UPDATE transactions - SET expiry_height = :expiry_height, - raw = :raw, - fee = IFNULL(:fee, fee) - WHERE txid = :txid", - )?, - stmt_select_tx_ref: wallet_db.conn.prepare( - "SELECT id_tx FROM transactions WHERE txid = ?", - )?, - stmt_mark_sapling_note_spent: wallet_db.conn.prepare( - "UPDATE sapling_received_notes SET spent = ? WHERE nf = ?" - )?, - #[cfg(feature = "transparent-inputs")] - stmt_mark_transparent_utxo_spent: wallet_db.conn.prepare( - "UPDATE utxos SET spent_in_tx = :spent_in_tx - WHERE prevout_txid = :prevout_txid - AND prevout_idx = :prevout_idx" - )?, - #[cfg(feature = "transparent-inputs")] - stmt_insert_received_transparent_utxo: wallet_db.conn.prepare( - "INSERT INTO utxos ( - received_by_account, address, - prevout_txid, prevout_idx, script, - value_zat, height) - SELECT - addresses.account, :address, - :prevout_txid, :prevout_idx, :script, - :value_zat, :height - FROM addresses - WHERE addresses.cached_transparent_receiver_address = :address - RETURNING id_utxo" - )?, - #[cfg(feature = "transparent-inputs")] - stmt_update_received_transparent_utxo: wallet_db.conn.prepare( - "UPDATE utxos - SET received_by_account = addresses.account, - height = :height, - address = :address, - script = :script, - value_zat = :value_zat - FROM addresses - WHERE prevout_txid = :prevout_txid - AND prevout_idx = :prevout_idx - AND addresses.cached_transparent_receiver_address = :address - RETURNING id_utxo" - )?, - #[cfg(feature = "transparent-inputs")] - stmt_insert_legacy_transparent_utxo: wallet_db.conn.prepare( - "INSERT INTO utxos ( - received_by_account, address, - prevout_txid, prevout_idx, script, - value_zat, height) - VALUES - (:received_by_account, :address, - :prevout_txid, :prevout_idx, :script, - :value_zat, :height) - RETURNING id_utxo" - )?, - #[cfg(feature = "transparent-inputs")] - stmt_update_legacy_transparent_utxo: wallet_db.conn.prepare( - "UPDATE utxos - SET received_by_account = :received_by_account, - height = :height, - address = :address, - script = :script, - value_zat = :value_zat - WHERE prevout_txid = :prevout_txid - AND prevout_idx = :prevout_idx - RETURNING id_utxo" - )?, - stmt_insert_received_note: wallet_db.conn.prepare( - "INSERT INTO sapling_received_notes (tx, output_index, account, diversifier, value, rcm, memo, nf, is_change) - VALUES (:tx, :output_index, :account, :diversifier, :value, :rcm, :memo, :nf, :is_change)", - )?, - stmt_update_received_note: wallet_db.conn.prepare( - "UPDATE sapling_received_notes - SET account = :account, - diversifier = :diversifier, - value = :value, - rcm = :rcm, - nf = IFNULL(:nf, nf), - memo = IFNULL(:memo, memo), - is_change = IFNULL(:is_change, is_change) - WHERE tx = :tx AND output_index = :output_index", - )?, - stmt_select_received_note: wallet_db.conn.prepare( - "SELECT id_note FROM sapling_received_notes WHERE tx = ? AND output_index = ?" - )?, - stmt_update_sent_output: wallet_db.conn.prepare( - "UPDATE sent_notes - SET from_account = :from_account, - to_address = :to_address, - to_account = :to_account, - value = :value, - memo = IFNULL(:memo, memo) - WHERE tx = :tx - AND output_pool = :output_pool - AND output_index = :output_index", - )?, - stmt_insert_sent_output: wallet_db.conn.prepare( - "INSERT INTO sent_notes ( - tx, output_pool, output_index, from_account, - to_address, to_account, value, memo) - VALUES ( - :tx, :output_pool, :output_index, :from_account, - :to_address, :to_account, :value, :memo)" - )?, - stmt_insert_witness: wallet_db.conn.prepare( - "INSERT INTO sapling_witnesses (note, block, witness) - VALUES (?, ?, ?)", - )?, - stmt_prune_witnesses: wallet_db.conn.prepare( - "DELETE FROM sapling_witnesses WHERE block < ?" - )?, - stmt_update_expired: wallet_db.conn.prepare( - "UPDATE sapling_received_notes SET spent = NULL WHERE EXISTS ( - SELECT id_tx FROM transactions - WHERE id_tx = sapling_received_notes.spent AND block IS NULL AND expiry_height < ? - )", - )?, - stmt_insert_address: InsertAddress::new(&wallet_db.conn)? - } - ) - } - - /// Inserts information about a scanned block into the database. - pub fn stmt_insert_block( - &mut self, - block_height: BlockHeight, - block_hash: BlockHash, - block_time: u32, - commitment_tree: &sapling::CommitmentTree, - ) -> Result<(), SqliteClientError> { - let mut encoded_tree = Vec::new(); - write_commitment_tree(commitment_tree, &mut encoded_tree).unwrap(); - - self.stmt_insert_block.execute(params![ - u32::from(block_height), - &block_hash.0[..], - block_time, - encoded_tree - ])?; - - Ok(()) - } - - /// Inserts the given transaction and its block metadata into the wallet. - /// - /// Returns the database row for the newly-inserted transaction, or an error if the - /// transaction exists. - pub(crate) fn stmt_insert_tx_meta( - &mut self, - txid: &TxId, - height: BlockHeight, - tx_index: usize, - ) -> Result { - self.stmt_insert_tx_meta.execute(params![ - &txid.as_ref()[..], - u32::from(height), - (tx_index as i64), - ])?; - - Ok(self.wallet_db.conn.last_insert_rowid()) - } - - /// Updates the block metadata for the given transaction. - /// - /// Returns `false` if the transaction doesn't exist in the wallet. - pub(crate) fn stmt_update_tx_meta( - &mut self, - height: BlockHeight, - tx_index: usize, - txid: &TxId, - ) -> Result { - match self.stmt_update_tx_meta.execute(params![ - u32::from(height), - (tx_index as i64), - &txid.as_ref()[..], - ])? { - 0 => Ok(false), - 1 => Ok(true), - _ => unreachable!("txid column is marked as UNIQUE"), - } - } - - /// Inserts the given transaction and its data into the wallet. - /// - /// Returns the database row for the newly-inserted transaction, or an error if the - /// transaction exists. - pub(crate) fn stmt_insert_tx_data( - &mut self, - txid: &TxId, - created_at: Option, - expiry_height: BlockHeight, - raw_tx: &[u8], - fee: Option, - ) -> Result { - self.stmt_insert_tx_data.execute(params![ - &txid.as_ref()[..], - created_at, - u32::from(expiry_height), - raw_tx, - fee.map(i64::from) - ])?; - - Ok(self.wallet_db.conn.last_insert_rowid()) - } - - /// Updates the data for the given transaction. - /// - /// Returns `false` if the transaction doesn't exist in the wallet. - pub(crate) fn stmt_update_tx_data( - &mut self, - expiry_height: BlockHeight, - raw_tx: &[u8], - fee: Option, - txid: &TxId, - ) -> Result { - let sql_args: &[(&str, &dyn ToSql)] = &[ - (":expiry_height", &u32::from(expiry_height)), - (":raw", &raw_tx), - (":fee", &fee.map(i64::from)), - (":txid", &&txid.as_ref()[..]), - ]; - match self.stmt_update_tx_data.execute(sql_args)? { - 0 => Ok(false), - 1 => Ok(true), - _ => unreachable!("txid column is marked as UNIQUE"), - } - } - - /// Finds the database row for the given `txid`, if the transaction is in the wallet. - pub(crate) fn stmt_select_tx_ref(&mut self, txid: &TxId) -> Result { - self.stmt_select_tx_ref - .query_row([&txid.as_ref()[..]], |row| row.get(0)) - .map_err(SqliteClientError::from) - } - - /// Marks a given nullifier as having been revealed in the construction of the - /// specified transaction. - /// - /// Marking a note spent in this fashion does NOT imply that the spending transaction - /// has been mined. - /// - /// Returns `false` if the nullifier does not correspond to any received note. - pub(crate) fn stmt_mark_sapling_note_spent( - &mut self, - tx_ref: i64, - nf: &Nullifier, - ) -> Result { - match self - .stmt_mark_sapling_note_spent - .execute(params![tx_ref, &nf.0[..]])? - { - 0 => Ok(false), - 1 => Ok(true), - _ => unreachable!("nf column is marked as UNIQUE"), - } - } - - /// Marks the given UTXO as having been spent. - /// - /// Returns `false` if `outpoint` does not correspond to any tracked UTXO. - #[cfg(feature = "transparent-inputs")] - pub(crate) fn stmt_mark_transparent_utxo_spent( - &mut self, - tx_ref: i64, - outpoint: &OutPoint, - ) -> Result { - let sql_args: &[(&str, &dyn ToSql)] = &[ - (":spent_in_tx", &tx_ref), - (":prevout_txid", &outpoint.hash().to_vec()), - (":prevout_idx", &outpoint.n()), - ]; - - match self.stmt_mark_transparent_utxo_spent.execute(sql_args)? { - 0 => Ok(false), - 1 => Ok(true), - _ => unreachable!("tx_outpoint constraint is marked as UNIQUE"), - } - } -} - -impl<'a, P: consensus::Parameters> DataConnStmtCache<'a, P> { - /// Inserts a sent note into the wallet database. - /// - /// `output_index` is the index within the transaction that contains the recipient output: - /// - /// - If `to` is a Unified address, this is an index into the outputs of the transaction - /// within the bundle associated with the recipient's output pool. - /// - If `to` is a Sapling address, this is an index into the Sapling outputs of the - /// transaction. - /// - If `to` is a transparent address, this is an index into the transparent outputs of - /// the transaction. - /// - If `to` is an internal account, this is an index into the Sapling outputs of the - /// transaction. - #[allow(clippy::too_many_arguments)] - pub(crate) fn stmt_insert_sent_output( - &mut self, - tx_ref: i64, - output_index: usize, - from_account: AccountId, - to: &Recipient, - value: Amount, - memo: Option<&MemoBytes>, - ) -> Result<(), SqliteClientError> { - let (to_address, to_account, pool_type) = match to { - Recipient::Transparent(addr) => ( - Some(addr.encode(&self.wallet_db.params)), - None, - PoolType::Transparent, - ), - Recipient::Sapling(addr) => ( - Some(addr.encode(&self.wallet_db.params)), - None, - PoolType::Sapling, - ), - Recipient::Unified(addr, pool) => { - (Some(addr.encode(&self.wallet_db.params)), None, *pool) - } - Recipient::InternalAccount(id, pool) => (None, Some(u32::from(*id)), *pool), - }; - - self.stmt_insert_sent_output.execute(named_params![ - ":tx": &tx_ref, - ":output_pool": &pool_code(pool_type), - ":output_index": &i64::try_from(output_index).unwrap(), - ":from_account": &u32::from(from_account), - ":to_address": &to_address, - ":to_account": &to_account, - ":value": &i64::from(value), - ":memo": &memo.filter(|m| *m != &MemoBytes::empty()).map(|m| m.as_slice()), - ])?; - - Ok(()) - } - - /// Updates the data for the given sent note. - /// - /// Returns `false` if the transaction doesn't exist in the wallet. - #[allow(clippy::too_many_arguments)] - pub(crate) fn stmt_update_sent_output( - &mut self, - from_account: AccountId, - to: &Recipient, - value: Amount, - memo: Option<&MemoBytes>, - tx_ref: i64, - output_index: usize, - ) -> Result { - let (to_address, to_account, pool_type) = match to { - Recipient::Transparent(addr) => ( - Some(addr.encode(&self.wallet_db.params)), - None, - PoolType::Transparent, - ), - Recipient::Sapling(addr) => ( - Some(addr.encode(&self.wallet_db.params)), - None, - PoolType::Sapling, - ), - Recipient::Unified(addr, pool) => { - (Some(addr.encode(&self.wallet_db.params)), None, *pool) - } - Recipient::InternalAccount(id, pool) => (None, Some(u32::from(*id)), *pool), - }; - - match self.stmt_update_sent_output.execute(named_params![ - ":from_account": &u32::from(from_account), - ":to_address": &to_address, - ":to_account": &to_account, - ":value": &i64::from(value), - ":memo": &memo.filter(|m| *m != &MemoBytes::empty()).map(|m| m.as_slice()), - ":tx": &tx_ref, - ":output_pool": &pool_code(pool_type), - ":output_index": &i64::try_from(output_index).unwrap(), - ])? { - 0 => Ok(false), - 1 => Ok(true), - _ => unreachable!("tx_output constraint is marked as UNIQUE"), - } - } - - /// Adds the given received UTXO to the datastore. - /// - /// Returns the database identifier for the newly-inserted UTXO if the address to which the - /// UTXO was sent corresponds to a cached transparent receiver in the addresses table, or - /// Ok(None) if the address is unknown. Returns an error if the UTXO exists. - #[cfg(feature = "transparent-inputs")] - pub(crate) fn stmt_insert_received_transparent_utxo( - &mut self, - output: &WalletTransparentOutput, - ) -> Result, SqliteClientError> { - self.stmt_insert_received_transparent_utxo - .query_row( - named_params![ - ":address": &output.recipient_address().encode(&self.wallet_db.params), - ":prevout_txid": &output.outpoint().hash().to_vec(), - ":prevout_idx": &output.outpoint().n(), - ":script": &output.txout().script_pubkey.0, - ":value_zat": &i64::from(output.txout().value), - ":height": &u32::from(output.height()), - ], - |row| { - let id = row.get(0)?; - Ok(UtxoId(id)) - }, - ) - .optional() - .map_err(SqliteClientError::from) - } - - /// Adds the given received UTXO to the datastore. - /// - /// Returns the database identifier for the updated UTXO if the address to which the UTXO was - /// sent corresponds to a cached transparent receiver in the addresses table, or Ok(None) if - /// the address is unknown. - #[cfg(feature = "transparent-inputs")] - pub(crate) fn stmt_update_received_transparent_utxo( - &mut self, - output: &WalletTransparentOutput, - ) -> Result, SqliteClientError> { - self.stmt_update_received_transparent_utxo - .query_row( - named_params![ - ":prevout_txid": &output.outpoint().hash().to_vec(), - ":prevout_idx": &output.outpoint().n(), - ":address": &output.recipient_address().encode(&self.wallet_db.params), - ":script": &output.txout().script_pubkey.0, - ":value_zat": &i64::from(output.txout().value), - ":height": &u32::from(output.height()), - ], - |row| { - let id = row.get(0)?; - Ok(UtxoId(id)) - }, - ) - .optional() - .map_err(SqliteClientError::from) - } - - /// Adds the given legacy UTXO to the datastore. - /// - /// Returns the database row for the newly-inserted UTXO, or an error if the UTXO - /// exists. - #[cfg(feature = "transparent-inputs")] - pub(crate) fn stmt_insert_legacy_transparent_utxo( - &mut self, - output: &WalletTransparentOutput, - received_by_account: AccountId, - ) -> Result { - self.stmt_insert_legacy_transparent_utxo - .query_row( - named_params![ - ":received_by_account": &u32::from(received_by_account), - ":address": &output.recipient_address().encode(&self.wallet_db.params), - ":prevout_txid": &output.outpoint().hash().to_vec(), - ":prevout_idx": &output.outpoint().n(), - ":script": &output.txout().script_pubkey.0, - ":value_zat": &i64::from(output.txout().value), - ":height": &u32::from(output.height()), - ], - |row| { - let id = row.get(0)?; - Ok(UtxoId(id)) - }, - ) - .map_err(SqliteClientError::from) - } - - /// Adds the given legacy UTXO to the datastore. - /// - /// Returns the database row for the newly-inserted UTXO, or an error if the UTXO - /// exists. - #[cfg(feature = "transparent-inputs")] - pub(crate) fn stmt_update_legacy_transparent_utxo( - &mut self, - output: &WalletTransparentOutput, - received_by_account: AccountId, - ) -> Result, SqliteClientError> { - self.stmt_update_legacy_transparent_utxo - .query_row( - named_params![ - ":received_by_account": &u32::from(received_by_account), - ":prevout_txid": &output.outpoint().hash().to_vec(), - ":prevout_idx": &output.outpoint().n(), - ":address": &output.recipient_address().encode(&self.wallet_db.params), - ":script": &output.txout().script_pubkey.0, - ":value_zat": &i64::from(output.txout().value), - ":height": &u32::from(output.height()), - ], - |row| { - let id = row.get(0)?; - Ok(UtxoId(id)) - }, - ) - .optional() - .map_err(SqliteClientError::from) - } - - /// Adds the given address and diversifier index to the addresses table. - /// - /// Returns the database row for the newly-inserted address. - pub(crate) fn stmt_insert_address( - &mut self, - account: AccountId, - diversifier_index: DiversifierIndex, - address: &UnifiedAddress, - ) -> Result<(), SqliteClientError> { - self.stmt_insert_address.execute( - &self.wallet_db.params, - account, - diversifier_index, - address, - )?; - - Ok(()) - } -} - -impl<'a, P> DataConnStmtCache<'a, P> { - /// Inserts the given received note into the wallet. - /// - /// This implementation relies on the facts that: - /// - A transaction will not contain more than 2^63 shielded outputs. - /// - A note value will never exceed 2^63 zatoshis. - /// - /// Returns the database row for the newly-inserted note, or an error if the note - /// exists. - #[allow(clippy::too_many_arguments)] - pub(crate) fn stmt_insert_received_note( - &mut self, - tx_ref: i64, - output_index: usize, - account: AccountId, - diversifier: &Diversifier, - value: u64, - rcm: [u8; 32], - nf: Option<&Nullifier>, - memo: Option<&MemoBytes>, - is_change: bool, - ) -> Result { - let sql_args: &[(&str, &dyn ToSql)] = &[ - (":tx", &tx_ref), - (":output_index", &(output_index as i64)), - (":account", &u32::from(account)), - (":diversifier", &diversifier.0.as_ref()), - (":value", &(value as i64)), - (":rcm", &rcm.as_ref()), - (":nf", &nf.map(|nf| nf.0.as_ref())), - ( - ":memo", - &memo - .filter(|m| *m != &MemoBytes::empty()) - .map(|m| m.as_slice()), - ), - (":is_change", &is_change), - ]; - - self.stmt_insert_received_note.execute(sql_args)?; - - Ok(NoteId::ReceivedNoteId( - self.wallet_db.conn.last_insert_rowid(), - )) - } - - /// Updates the data for the given transaction. - /// - /// This implementation relies on the facts that: - /// - A transaction will not contain more than 2^63 shielded outputs. - /// - A note value will never exceed 2^63 zatoshis. - /// - /// Returns `false` if the transaction doesn't exist in the wallet. - #[allow(clippy::too_many_arguments)] - pub(crate) fn stmt_update_received_note( - &mut self, - account: AccountId, - diversifier: &Diversifier, - value: u64, - rcm: [u8; 32], - nf: Option<&Nullifier>, - memo: Option<&MemoBytes>, - is_change: bool, - tx_ref: i64, - output_index: usize, - ) -> Result { - let sql_args: &[(&str, &dyn ToSql)] = &[ - (":account", &u32::from(account)), - (":diversifier", &diversifier.0.as_ref()), - (":value", &(value as i64)), - (":rcm", &rcm.as_ref()), - (":nf", &nf.map(|nf| nf.0.as_ref())), - ( - ":memo", - &memo - .filter(|m| *m != &MemoBytes::empty()) - .map(|m| m.as_slice()), - ), - (":is_change", &is_change), - (":tx", &tx_ref), - (":output_index", &(output_index as i64)), - ]; - - match self.stmt_update_received_note.execute(sql_args)? { - 0 => Ok(false), - 1 => Ok(true), - _ => unreachable!("tx_output constraint is marked as UNIQUE"), - } - } - - /// Finds the database row for the given `txid`, if the transaction is in the wallet. - pub(crate) fn stmt_select_received_note( - &mut self, - tx_ref: i64, - output_index: usize, - ) -> Result { - self.stmt_select_received_note - .query_row(params![tx_ref, (output_index as i64)], |row| { - row.get(0).map(NoteId::ReceivedNoteId) - }) - .map_err(SqliteClientError::from) - } - - /// Records the incremental witness for the specified note, as of the given block - /// height. - /// - /// Returns `SqliteClientError::InvalidNoteId` if the note ID is for a sent note. - pub(crate) fn stmt_insert_witness( - &mut self, - note_id: NoteId, - height: BlockHeight, - witness: &sapling::IncrementalWitness, - ) -> Result<(), SqliteClientError> { - let note_id = match note_id { - NoteId::ReceivedNoteId(note_id) => Ok(note_id), - NoteId::SentNoteId(_) => Err(SqliteClientError::InvalidNoteId), - }?; - - let mut encoded = Vec::new(); - write_incremental_witness(witness, &mut encoded).unwrap(); - - self.stmt_insert_witness - .execute(params![note_id, u32::from(height), encoded])?; - - Ok(()) - } - - /// Removes old incremental witnesses up to the given block height. - pub(crate) fn stmt_prune_witnesses( - &mut self, - below_height: BlockHeight, - ) -> Result<(), SqliteClientError> { - self.stmt_prune_witnesses - .execute([u32::from(below_height)])?; - Ok(()) - } - - /// Marks notes that have not been mined in transactions as expired, up to the given - /// block height. - pub fn stmt_update_expired(&mut self, height: BlockHeight) -> Result<(), SqliteClientError> { - self.stmt_update_expired.execute([u32::from(height)])?; - Ok(()) - } -} diff --git a/zcash_client_sqlite/src/wallet.rs b/zcash_client_sqlite/src/wallet.rs index dbffcaadd..a05597134 100644 --- a/zcash_client_sqlite/src/wallet.rs +++ b/zcash_client_sqlite/src/wallet.rs @@ -1,4 +1,4 @@ -//! Functions for querying information in the wdb database. +//! Functions for querying information in the wallet database. //! //! These functions should generally not be used directly; instead, //! their functionality is available via the [`WalletRead`] and @@ -64,7 +64,7 @@ //! wallet. //! - `memo` the shielded memo associated with the output, if any. -use rusqlite::{named_params, OptionalExtension, ToSql}; +use rusqlite::{self, named_params, params, OptionalExtension, ToSql}; use std::collections::HashMap; use std::convert::TryFrom; @@ -72,6 +72,7 @@ use zcash_primitives::{ block::BlockHash, consensus::{self, BlockHeight, BranchId, NetworkUpgrade, Parameters}, memo::{Memo, MemoBytes}, + merkle_tree::write_commitment_tree, sapling::CommitmentTree, transaction::{components::Amount, Transaction, TxId}, zip32::{ @@ -83,22 +84,18 @@ use zcash_primitives::{ use zcash_client_backend::{ address::{RecipientAddress, UnifiedAddress}, data_api::{PoolType, Recipient, SentTransactionOutput}, + encoding::AddressCodec, keys::UnifiedFullViewingKey, wallet::WalletTx, }; -use crate::{ - error::SqliteClientError, prepared::InsertAddress, DataConnStmtCache, WalletDb, PRUNING_HEIGHT, -}; +use crate::{error::SqliteClientError, PRUNING_HEIGHT}; #[cfg(feature = "transparent-inputs")] use { crate::UtxoId, - rusqlite::{params, Connection}, std::collections::BTreeSet, - zcash_client_backend::{ - address::AddressMetadata, encoding::AddressCodec, wallet::WalletTransparentOutput, - }, + zcash_client_backend::{address::AddressMetadata, wallet::WalletTransparentOutput}, zcash_primitives::{ legacy::{keys::IncomingViewingKey, Script, TransparentAddress}, transaction::components::{OutPoint, TxOut}, @@ -118,28 +115,33 @@ pub(crate) fn pool_code(pool_type: PoolType) -> i64 { } } -pub(crate) fn get_max_account_id

( - wdb: &WalletDb

, +pub(crate) fn memo_repr(memo: Option<&MemoBytes>) -> Option<&[u8]> { + memo.filter(|m| *m != &MemoBytes::empty()) + .map(|m| m.as_slice()) +} + +pub(crate) fn get_max_account_id( + conn: &rusqlite::Connection, ) -> Result, SqliteClientError> { // This returns the most recently generated address. - wdb.conn - .query_row("SELECT MAX(account) FROM accounts", [], |row| { - let account_id: Option = row.get(0)?; - Ok(account_id.map(AccountId::from)) - }) - .map_err(SqliteClientError::from) + conn.query_row("SELECT MAX(account) FROM accounts", [], |row| { + let account_id: Option = row.get(0)?; + Ok(account_id.map(AccountId::from)) + }) + .map_err(SqliteClientError::from) } pub(crate) fn add_account( - wdb: &WalletDb

, + conn: &rusqlite::Transaction, + params: &P, account: AccountId, key: &UnifiedFullViewingKey, ) -> Result<(), SqliteClientError> { - add_account_internal(&wdb.conn, &wdb.params, "accounts", account, key) + add_account_internal(conn, params, "accounts", account, key) } pub(crate) fn add_account_internal>( - conn: &rusqlite::Connection, + conn: &rusqlite::Transaction, network: &P, accounts_table: &'static str, account: AccountId, @@ -156,18 +158,18 @@ pub(crate) fn add_account_internal( - wdb: &WalletDb

, + conn: &rusqlite::Connection, + params: &P, account: AccountId, ) -> Result, SqliteClientError> { // This returns the most recently generated address. - let addr: Option<(String, Vec)> = wdb - .conn + let addr: Option<(String, Vec)> = conn .query_row( "SELECT address, diversifier_index_be FROM addresses WHERE account = :account @@ -184,7 +186,7 @@ pub(crate) fn get_current_address( })?; di_be.reverse(); - RecipientAddress::decode(&wdb.params, &addr_str) + RecipientAddress::decode(params, &addr_str) .ok_or_else(|| { SqliteClientError::CorruptedData("Not a valid Zcash recipient address".to_owned()) }) @@ -200,10 +202,47 @@ pub(crate) fn get_current_address( .transpose() } +/// Adds the given address and diversifier index to the addresses table. +/// +/// Returns the database row for the newly-inserted address. +pub(crate) fn insert_address( + conn: &rusqlite::Connection, + params: &P, + account: AccountId, + mut diversifier_index: DiversifierIndex, + address: &UnifiedAddress, +) -> Result<(), rusqlite::Error> { + let mut stmt = conn.prepare_cached( + "INSERT INTO addresses ( + account, + diversifier_index_be, + address, + cached_transparent_receiver_address + ) + VALUES ( + :account, + :diversifier_index_be, + :address, + :cached_transparent_receiver_address + )", + )?; + + // the diversifier index is stored in big-endian order to allow sorting + diversifier_index.0.reverse(); + stmt.execute(named_params![ + ":account": &u32::from(account), + ":diversifier_index_be": &&diversifier_index.0[..], + ":address": &address.encode(params), + ":cached_transparent_receiver_address": &address.transparent().map(|r| r.encode(params)), + ])?; + + Ok(()) +} + #[cfg(feature = "transparent-inputs")] pub(crate) fn get_transparent_receivers( + conn: &rusqlite::Connection, params: &P, - conn: &Connection, account: AccountId, ) -> Result, SqliteClientError> { let mut ret = HashMap::new(); @@ -254,7 +293,7 @@ pub(crate) fn get_transparent_receivers( #[cfg(feature = "transparent-inputs")] pub(crate) fn get_legacy_transparent_address( params: &P, - conn: &Connection, + conn: &rusqlite::Connection, account: AccountId, ) -> Result, SqliteClientError> { // Get the UFVK for the account. @@ -288,18 +327,18 @@ pub(crate) fn get_legacy_transparent_address( /// Returns the [`UnifiedFullViewingKey`]s for the wallet. pub(crate) fn get_unified_full_viewing_keys( - wdb: &WalletDb

, + conn: &rusqlite::Connection, + params: &P, ) -> Result, SqliteClientError> { // Fetch the UnifiedFullViewingKeys we are tracking - let mut stmt_fetch_accounts = wdb - .conn - .prepare("SELECT account, ufvk FROM accounts ORDER BY account ASC")?; + let mut stmt_fetch_accounts = + conn.prepare("SELECT account, ufvk FROM accounts ORDER BY account ASC")?; let rows = stmt_fetch_accounts.query_map([], |row| { let acct: u32 = row.get(0)?; let account = AccountId::from(acct); let ufvk_str: String = row.get(1)?; - let ufvk = UnifiedFullViewingKey::decode(&wdb.params, &ufvk_str) + let ufvk = UnifiedFullViewingKey::decode(params, &ufvk_str) .map_err(SqliteClientError::CorruptedData); Ok((account, ufvk)) @@ -317,20 +356,20 @@ pub(crate) fn get_unified_full_viewing_keys( /// Returns the account id corresponding to a given [`UnifiedFullViewingKey`], /// if any. pub(crate) fn get_account_for_ufvk( - wdb: &WalletDb

, + conn: &rusqlite::Connection, + params: &P, ufvk: &UnifiedFullViewingKey, ) -> Result, SqliteClientError> { - wdb.conn - .query_row( - "SELECT account FROM accounts WHERE ufvk = ?", - [&ufvk.encode(&wdb.params)], - |row| { - let acct: u32 = row.get(0)?; - Ok(AccountId::from(acct)) - }, - ) - .optional() - .map_err(SqliteClientError::from) + conn.query_row( + "SELECT account FROM accounts WHERE ufvk = ?", + [&ufvk.encode(params)], + |row| { + let acct: u32 = row.get(0)?; + Ok(AccountId::from(acct)) + }, + ) + .optional() + .map_err(SqliteClientError::from) } /// Checks whether the specified [`ExtendedFullViewingKey`] is valid and corresponds to the @@ -338,15 +377,15 @@ pub(crate) fn get_account_for_ufvk( /// /// [`ExtendedFullViewingKey`]: zcash_primitives::zip32::ExtendedFullViewingKey pub(crate) fn is_valid_account_extfvk( - wdb: &WalletDb

, + conn: &rusqlite::Connection, + params: &P, account: AccountId, extfvk: &ExtendedFullViewingKey, ) -> Result { - wdb.conn - .prepare("SELECT ufvk FROM accounts WHERE account = ?")? + conn.prepare("SELECT ufvk FROM accounts WHERE account = ?")? .query_row([u32::from(account).to_sql()?], |row| { row.get(0).map(|ufvk_str: String| { - UnifiedFullViewingKey::decode(&wdb.params, &ufvk_str) + UnifiedFullViewingKey::decode(params, &ufvk_str) .map_err(SqliteClientError::CorruptedData) }) }) @@ -372,11 +411,11 @@ pub(crate) fn is_valid_account_extfvk( /// caveat. Use [`get_balance_at`] where you need a more reliable indication of the /// wallet balance. #[cfg(test)] -pub(crate) fn get_balance

( - wdb: &WalletDb

, +pub(crate) fn get_balance( + conn: &rusqlite::Connection, account: AccountId, ) -> Result { - let balance = wdb.conn.query_row( + let balance = conn.query_row( "SELECT SUM(value) FROM sapling_received_notes INNER JOIN transactions ON transactions.id_tx = sapling_received_notes.tx WHERE account = ? AND spent IS NULL AND transactions.block IS NOT NULL", @@ -395,12 +434,12 @@ pub(crate) fn get_balance

( /// Returns the verified balance for the account at the specified height, /// This may be used to obtain a balance that ignores notes that have been /// received so recently that they are not yet deemed spendable. -pub(crate) fn get_balance_at

( - wdb: &WalletDb

, +pub(crate) fn get_balance_at( + conn: &rusqlite::Connection, account: AccountId, anchor_height: BlockHeight, ) -> Result { - let balance = wdb.conn.query_row( + let balance = conn.query_row( "SELECT SUM(value) FROM sapling_received_notes INNER JOIN transactions ON transactions.id_tx = sapling_received_notes.tx WHERE account = ? AND spent IS NULL AND transactions.block <= ?", @@ -420,11 +459,11 @@ pub(crate) fn get_balance_at

( /// /// The note is identified by its row index in the `sapling_received_notes` table within the wdb /// database. -pub(crate) fn get_received_memo

( - wdb: &WalletDb

, +pub(crate) fn get_received_memo( + conn: &rusqlite::Connection, id_note: i64, ) -> Result, SqliteClientError> { - let memo_bytes: Option> = wdb.conn.query_row( + let memo_bytes: Option> = conn.query_row( "SELECT memo FROM sapling_received_notes WHERE id_note = ?", [id_note], @@ -442,10 +481,11 @@ pub(crate) fn get_received_memo

( /// Looks up a transaction by its internal database identifier. pub(crate) fn get_transaction( - wdb: &WalletDb

, + conn: &rusqlite::Connection, + params: &P, id_tx: i64, ) -> Result { - let (tx_bytes, block_height): (Vec<_>, BlockHeight) = wdb.conn.query_row( + let (tx_bytes, block_height): (Vec<_>, BlockHeight) = conn.query_row( "SELECT raw, block FROM transactions WHERE id_tx = ?", [id_tx], @@ -455,22 +495,19 @@ pub(crate) fn get_transaction( }, )?; - Transaction::read( - &tx_bytes[..], - BranchId::for_height(&wdb.params, block_height), - ) - .map_err(SqliteClientError::from) + Transaction::read(&tx_bytes[..], BranchId::for_height(params, block_height)) + .map_err(SqliteClientError::from) } /// Returns the memo for a sent note. /// /// The note is identified by its row index in the `sent_notes` table within the wdb /// database. -pub(crate) fn get_sent_memo

( - wdb: &WalletDb

, +pub(crate) fn get_sent_memo( + conn: &rusqlite::Connection, id_note: i64, ) -> Result, SqliteClientError> { - let memo_bytes: Option> = wdb.conn.query_row( + let memo_bytes: Option> = conn.query_row( "SELECT memo FROM sent_notes WHERE id_note = ?", [id_note], @@ -487,74 +524,66 @@ pub(crate) fn get_sent_memo

( } /// Returns the minimum and maximum heights for blocks stored in the wallet database. -pub(crate) fn block_height_extrema

( - wdb: &WalletDb

, +pub(crate) fn block_height_extrema( + conn: &rusqlite::Connection, ) -> Result, rusqlite::Error> { - wdb.conn - .query_row("SELECT MIN(height), MAX(height) FROM blocks", [], |row| { - let min_height: u32 = row.get(0)?; - let max_height: u32 = row.get(1)?; - Ok(Some(( - BlockHeight::from(min_height), - BlockHeight::from(max_height), - ))) - }) - //.optional() doesn't work here because a failed aggregate function - //produces a runtime error, not an empty set of rows. - .or(Ok(None)) + conn.query_row("SELECT MIN(height), MAX(height) FROM blocks", [], |row| { + let min_height: Option = row.get(0)?; + let max_height: Option = row.get(1)?; + Ok(min_height + .map(BlockHeight::from) + .zip(max_height.map(BlockHeight::from))) + }) } /// Returns the block height at which the specified transaction was mined, /// if any. -pub(crate) fn get_tx_height

( - wdb: &WalletDb

, +pub(crate) fn get_tx_height( + conn: &rusqlite::Connection, txid: TxId, ) -> Result, rusqlite::Error> { - wdb.conn - .query_row( - "SELECT block FROM transactions WHERE txid = ?", - [txid.as_ref().to_vec()], - |row| row.get(0).map(u32::into), - ) - .optional() + conn.query_row( + "SELECT block FROM transactions WHERE txid = ?", + [txid.as_ref().to_vec()], + |row| row.get(0).map(u32::into), + ) + .optional() } /// Returns the block hash for the block at the specified height, /// if any. -pub(crate) fn get_block_hash

( - wdb: &WalletDb

, +pub(crate) fn get_block_hash( + conn: &rusqlite::Connection, block_height: BlockHeight, ) -> Result, rusqlite::Error> { - wdb.conn - .query_row( - "SELECT hash FROM blocks WHERE height = ?", - [u32::from(block_height)], - |row| { - let row_data = row.get::<_, Vec<_>>(0)?; - Ok(BlockHash::from_slice(&row_data)) - }, - ) - .optional() + conn.query_row( + "SELECT hash FROM blocks WHERE height = ?", + [u32::from(block_height)], + |row| { + let row_data = row.get::<_, Vec<_>>(0)?; + Ok(BlockHash::from_slice(&row_data)) + }, + ) + .optional() } /// Gets the height to which the database must be truncated if any truncation that would remove a /// number of blocks greater than the pruning height is attempted. -pub(crate) fn get_min_unspent_height

( - wdb: &WalletDb

, +pub(crate) fn get_min_unspent_height( + conn: &rusqlite::Connection, ) -> Result, SqliteClientError> { - wdb.conn - .query_row( - "SELECT MIN(tx.block) + conn.query_row( + "SELECT MIN(tx.block) FROM sapling_received_notes n JOIN transactions tx ON tx.id_tx = n.tx WHERE n.spent IS NULL", - [], - |row| { - row.get(0) - .map(|maybe_height: Option| maybe_height.map(|height| height.into())) - }, - ) - .map_err(SqliteClientError::from) + [], + |row| { + row.get(0) + .map(|maybe_height: Option| maybe_height.map(|height| height.into())) + }, + ) + .map_err(SqliteClientError::from) } /// Truncates the database to the given height. @@ -564,25 +593,22 @@ pub(crate) fn get_min_unspent_height

( /// /// This should only be executed inside a transactional context. pub(crate) fn truncate_to_height( - wdb: &WalletDb

, + conn: &rusqlite::Transaction, + params: &P, block_height: BlockHeight, ) -> Result<(), SqliteClientError> { - let sapling_activation_height = wdb - .params + let sapling_activation_height = params .activation_height(NetworkUpgrade::Sapling) .expect("Sapling activation height mutst be available."); // Recall where we synced up to previously. - let last_scanned_height = wdb - .conn - .query_row("SELECT MAX(height) FROM blocks", [], |row| { - row.get(0) - .map(|h: u32| h.into()) - .or_else(|_| Ok(sapling_activation_height - 1)) - })?; + let last_scanned_height = conn.query_row("SELECT MAX(height) FROM blocks", [], |row| { + row.get::<_, Option>(0) + .map(|opt| opt.map_or_else(|| sapling_activation_height - 1, BlockHeight::from)) + })?; if block_height < last_scanned_height - PRUNING_HEIGHT { - if let Some(h) = get_min_unspent_height(wdb)? { + if let Some(h) = get_min_unspent_height(conn)? { if block_height > h { return Err(SqliteClientError::RequestedRewindInvalid(h, block_height)); } @@ -592,21 +618,21 @@ pub(crate) fn truncate_to_height( // nothing to do if we're deleting back down to the max height if block_height < last_scanned_height { // Decrement witnesses. - wdb.conn.execute( + conn.execute( "DELETE FROM sapling_witnesses WHERE block > ?", [u32::from(block_height)], )?; // Rewind received notes - wdb.conn.execute( + conn.execute( "DELETE FROM sapling_received_notes - WHERE id_note IN ( - SELECT rn.id_note - FROM sapling_received_notes rn - LEFT OUTER JOIN transactions tx - ON tx.id_tx = rn.tx - WHERE tx.block IS NOT NULL AND tx.block > ? - );", + WHERE id_note IN ( + SELECT rn.id_note + FROM sapling_received_notes rn + LEFT OUTER JOIN transactions tx + ON tx.id_tx = rn.tx + WHERE tx.block IS NOT NULL AND tx.block > ? + );", [u32::from(block_height)], )?; @@ -615,19 +641,19 @@ pub(crate) fn truncate_to_height( // presence of stale sent notes that link to unmined transactions. // Rewind utxos - wdb.conn.execute( + conn.execute( "DELETE FROM utxos WHERE height > ?", [u32::from(block_height)], )?; // Un-mine transactions. - wdb.conn.execute( + conn.execute( "UPDATE transactions SET block = NULL, tx_index = NULL WHERE block IS NOT NULL AND block > ?", [u32::from(block_height)], )?; // Now that they aren't depended on, delete scanned blocks. - wdb.conn.execute( + conn.execute( "DELETE FROM blocks WHERE height > ?", [u32::from(block_height)], )?; @@ -641,12 +667,13 @@ pub(crate) fn truncate_to_height( /// height less than or equal to the provided `max_height`. #[cfg(feature = "transparent-inputs")] pub(crate) fn get_unspent_transparent_outputs( - wdb: &WalletDb

, + conn: &rusqlite::Connection, + params: &P, address: &TransparentAddress, max_height: BlockHeight, exclude: &[OutPoint], ) -> Result, SqliteClientError> { - let mut stmt_blocks = wdb.conn.prepare( + let mut stmt_blocks = conn.prepare( "SELECT u.prevout_txid, u.prevout_idx, u.script, u.value_zat, u.height, tx.block as block FROM utxos u @@ -657,7 +684,7 @@ pub(crate) fn get_unspent_transparent_outputs( AND tx.block IS NULL", )?; - let addr_str = address.encode(&wdb.params); + let addr_str = address.encode(params); let mut utxos = Vec::::new(); let mut rows = stmt_blocks.query(params![addr_str, u32::from(max_height)])?; @@ -703,11 +730,12 @@ pub(crate) fn get_unspent_transparent_outputs( /// the provided `max_height`. #[cfg(feature = "transparent-inputs")] pub(crate) fn get_transparent_balances( - wdb: &WalletDb

, + conn: &rusqlite::Connection, + params: &P, account: AccountId, max_height: BlockHeight, ) -> Result, SqliteClientError> { - let mut stmt_blocks = wdb.conn.prepare( + let mut stmt_blocks = conn.prepare( "SELECT u.address, SUM(u.value_zat) FROM utxos u LEFT OUTER JOIN transactions tx @@ -722,7 +750,7 @@ pub(crate) fn get_transparent_balances( let mut rows = stmt_blocks.query(params![u32::from(account), u32::from(max_height)])?; while let Some(row) = rows.next()? { let taddr_str: String = row.get(0)?; - let taddr = TransparentAddress::decode(&wdb.params, &taddr_str)?; + let taddr = TransparentAddress::decode(params, &taddr_str)?; let value = Amount::from_i64(row.get(1)?).unwrap(); res.insert(taddr, value); @@ -732,149 +760,287 @@ pub(crate) fn get_transparent_balances( } /// Inserts information about a scanned block into the database. -pub(crate) fn insert_block<'a, P>( - stmts: &mut DataConnStmtCache<'a, P>, +pub(crate) fn insert_block( + conn: &rusqlite::Connection, block_height: BlockHeight, block_hash: BlockHash, block_time: u32, commitment_tree: &CommitmentTree, ) -> Result<(), SqliteClientError> { - stmts.stmt_insert_block(block_height, block_hash, block_time, commitment_tree) + let mut encoded_tree = Vec::new(); + write_commitment_tree(commitment_tree, &mut encoded_tree).unwrap(); + + let mut stmt_insert_block = conn.prepare_cached( + "INSERT INTO blocks (height, hash, time, sapling_tree) + VALUES (?, ?, ?, ?)", + )?; + + stmt_insert_block.execute(params![ + u32::from(block_height), + &block_hash.0[..], + block_time, + encoded_tree + ])?; + + Ok(()) } /// Inserts information about a mined transaction that was observed to /// contain a note related to this wallet into the database. -pub(crate) fn put_tx_meta<'a, P, N>( - stmts: &mut DataConnStmtCache<'a, P>, +pub(crate) fn put_tx_meta( + conn: &rusqlite::Connection, tx: &WalletTx, height: BlockHeight, ) -> Result { - if !stmts.stmt_update_tx_meta(height, tx.index, &tx.txid)? { - // It isn't there, so insert our transaction into the database. - stmts.stmt_insert_tx_meta(&tx.txid, height, tx.index) - } else { - // It was there, so grab its row number. - stmts.stmt_select_tx_ref(&tx.txid) - } + // It isn't there, so insert our transaction into the database. + let mut stmt_upsert_tx_meta = conn.prepare_cached( + "INSERT INTO transactions (txid, block, tx_index) + VALUES (:txid, :block, :tx_index) + ON CONFLICT (txid) DO UPDATE + SET block = :block, + tx_index = :tx_index + RETURNING id_tx", + )?; + + let tx_params = named_params![ + ":txid": &tx.txid.as_ref()[..], + ":block": u32::from(height), + ":tx_index": i64::try_from(tx.index).expect("transaction indices are representable as i64"), + ]; + + stmt_upsert_tx_meta + .query_row(tx_params, |row| row.get::<_, i64>(0)) + .map_err(SqliteClientError::from) } /// Inserts full transaction data into the database. -pub(crate) fn put_tx_data<'a, P>( - stmts: &mut DataConnStmtCache<'a, P>, +pub(crate) fn put_tx_data( + conn: &rusqlite::Connection, tx: &Transaction, fee: Option, created_at: Option, ) -> Result { - let txid = tx.txid(); + let mut stmt_upsert_tx_data = conn.prepare_cached( + "INSERT INTO transactions (txid, created, expiry_height, raw, fee) + VALUES (:txid, :created_at, :expiry_height, :raw, :fee) + ON CONFLICT (txid) DO UPDATE + SET expiry_height = :expiry_height, + raw = :raw, + fee = IFNULL(:fee, fee) + RETURNING id_tx", + )?; + let txid = tx.txid(); let mut raw_tx = vec![]; tx.write(&mut raw_tx)?; - if !stmts.stmt_update_tx_data(tx.expiry_height(), &raw_tx, fee, &txid)? { - // It isn't there, so insert our transaction into the database. - stmts.stmt_insert_tx_data(&txid, created_at, tx.expiry_height(), &raw_tx, fee) - } else { - // It was there, so grab its row number. - stmts.stmt_select_tx_ref(&txid) - } + let tx_params = named_params![ + ":txid": &txid.as_ref()[..], + ":created_at": created_at, + ":expiry_height": u32::from(tx.expiry_height()), + ":raw": raw_tx, + ":fee": fee.map(i64::from), + ]; + + stmt_upsert_tx_data + .query_row(tx_params, |row| row.get::<_, i64>(0)) + .map_err(SqliteClientError::from) } /// Marks the given UTXO as having been spent. #[cfg(feature = "transparent-inputs")] -pub(crate) fn mark_transparent_utxo_spent<'a, P>( - stmts: &mut DataConnStmtCache<'a, P>, +pub(crate) fn mark_transparent_utxo_spent( + conn: &rusqlite::Connection, tx_ref: i64, outpoint: &OutPoint, ) -> Result<(), SqliteClientError> { - stmts.stmt_mark_transparent_utxo_spent(tx_ref, outpoint)?; + let mut stmt_mark_transparent_utxo_spent = conn.prepare_cached( + "UPDATE utxos SET spent_in_tx = :spent_in_tx + WHERE prevout_txid = :prevout_txid + AND prevout_idx = :prevout_idx", + )?; + let sql_args = named_params![ + ":spent_in_tx": &tx_ref, + ":prevout_txid": &outpoint.hash().to_vec(), + ":prevout_idx": &outpoint.n(), + ]; + + stmt_mark_transparent_utxo_spent.execute(sql_args)?; Ok(()) } /// Adds the given received UTXO to the datastore. #[cfg(feature = "transparent-inputs")] -pub(crate) fn put_received_transparent_utxo<'a, P: consensus::Parameters>( - stmts: &mut DataConnStmtCache<'a, P>, +pub(crate) fn put_received_transparent_utxo( + conn: &rusqlite::Connection, + params: &P, output: &WalletTransparentOutput, ) -> Result { - stmts - .stmt_update_received_transparent_utxo(output) - .transpose() - .or_else(|| { - stmts - .stmt_insert_received_transparent_utxo(output) - .transpose() - }) - .unwrap_or_else(|| { - // This could occur if the UTXO is received at the legacy transparent - // address, in which case the join to the `addresses` table will fail. - // In this case, we should look up the legacy address for account 0 and - // check whether it matches the address for the received UTXO, and if - // so then insert/update it directly. - let account = AccountId::from(0u32); - get_legacy_transparent_address(&stmts.wallet_db.params, &stmts.wallet_db.conn, account) - .and_then(|legacy_taddr| { - if legacy_taddr - .iter() - .any(|(taddr, _)| taddr == output.recipient_address()) - { - stmts - .stmt_update_legacy_transparent_utxo(output, account) - .transpose() - .unwrap_or_else(|| { - stmts.stmt_insert_legacy_transparent_utxo(output, account) - }) - } else { - Err(SqliteClientError::AddressNotRecognized( - *output.recipient_address(), - )) - } - }) - }) + let address_str = output.recipient_address().encode(params); + let account_id = conn + .query_row( + "SELECT account FROM addresses WHERE cached_transparent_receiver_address = :address", + named_params![":address": &address_str], + |row| row.get::<_, u32>(0).map(AccountId::from), + ) + .optional()?; + + let utxoid = if let Some(account) = account_id { + put_legacy_transparent_utxo(conn, params, output, account)? + } else { + // If the UTXO is received at the legacy transparent address, there may be no entry in the + // addresses table that can be used to tie the address to a particular account. In this + // case, we should look up the legacy address for account 0 and check whether it matches + // the address for the received UTXO, and if so then insert/update it directly. + let account = AccountId::from(0u32); + get_legacy_transparent_address(params, conn, account).and_then(|legacy_taddr| { + if legacy_taddr + .iter() + .any(|(taddr, _)| taddr == output.recipient_address()) + { + put_legacy_transparent_utxo(conn, params, output, account) + .map_err(SqliteClientError::from) + } else { + Err(SqliteClientError::AddressNotRecognized( + *output.recipient_address(), + )) + } + })? + }; + + Ok(utxoid) +} + +#[cfg(feature = "transparent-inputs")] +pub(crate) fn put_legacy_transparent_utxo( + conn: &rusqlite::Connection, + params: &P, + output: &WalletTransparentOutput, + received_by_account: AccountId, +) -> Result { + #[cfg(feature = "transparent-inputs")] + let mut stmt_upsert_legacy_transparent_utxo = conn.prepare_cached( + "INSERT INTO utxos ( + prevout_txid, prevout_idx, + received_by_account, address, script, + value_zat, height) + VALUES + (:prevout_txid, :prevout_idx, + :received_by_account, :address, :script, + :value_zat, :height) + ON CONFLICT (prevout_txid, prevout_idx) DO UPDATE + SET received_by_account = :received_by_account, + height = :height, + address = :address, + script = :script, + value_zat = :value_zat + RETURNING id_utxo", + )?; + + let sql_args = named_params![ + ":prevout_txid": &output.outpoint().hash().to_vec(), + ":prevout_idx": &output.outpoint().n(), + ":received_by_account": &u32::from(received_by_account), + ":address": &output.recipient_address().encode(params), + ":script": &output.txout().script_pubkey.0, + ":value_zat": &i64::from(output.txout().value), + ":height": &u32::from(output.height()), + ]; + + stmt_upsert_legacy_transparent_utxo.query_row(sql_args, |row| row.get::<_, i64>(0).map(UtxoId)) } /// Removes old incremental witnesses up to the given block height. -pub(crate) fn prune_witnesses

( - stmts: &mut DataConnStmtCache<'_, P>, +pub(crate) fn prune_witnesses( + conn: &rusqlite::Connection, below_height: BlockHeight, ) -> Result<(), SqliteClientError> { - stmts.stmt_prune_witnesses(below_height) + let mut stmt_prune_witnesses = + conn.prepare_cached("DELETE FROM sapling_witnesses WHERE block < ?")?; + stmt_prune_witnesses.execute([u32::from(below_height)])?; + Ok(()) } /// Marks notes that have not been mined in transactions /// as expired, up to the given block height. -pub(crate) fn update_expired_notes

( - stmts: &mut DataConnStmtCache<'_, P>, +pub(crate) fn update_expired_notes( + conn: &rusqlite::Connection, height: BlockHeight, ) -> Result<(), SqliteClientError> { - stmts.stmt_update_expired(height) + let mut stmt_update_expired = conn.prepare_cached( + "UPDATE sapling_received_notes SET spent = NULL WHERE EXISTS ( + SELECT id_tx FROM transactions + WHERE id_tx = sapling_received_notes.spent AND block IS NULL AND expiry_height < ? + )", + )?; + stmt_update_expired.execute([u32::from(height)])?; + Ok(()) +} + +// A utility function for creation of parameters for use in `insert_sent_output` +// and `put_sent_output` +fn recipient_params( + params: &P, + to: &Recipient, +) -> (Option, Option, PoolType) { + match to { + Recipient::Transparent(addr) => (Some(addr.encode(params)), None, PoolType::Transparent), + Recipient::Sapling(addr) => (Some(addr.encode(params)), None, PoolType::Sapling), + Recipient::Unified(addr, pool) => (Some(addr.encode(params)), None, *pool), + Recipient::InternalAccount(id, pool) => (None, Some(u32::from(*id)), *pool), + } } /// Records information about a transaction output that your wallet created. -/// -/// This is a crate-internal convenience method. -pub(crate) fn insert_sent_output<'a, P: consensus::Parameters>( - stmts: &mut DataConnStmtCache<'a, P>, +pub(crate) fn insert_sent_output( + conn: &rusqlite::Connection, + params: &P, tx_ref: i64, from_account: AccountId, output: &SentTransactionOutput, ) -> Result<(), SqliteClientError> { - stmts.stmt_insert_sent_output( - tx_ref, - output.output_index(), - from_account, - output.recipient(), - output.value(), - output.memo(), - ) + let mut stmt_insert_sent_output = conn.prepare_cached( + "INSERT INTO sent_notes ( + tx, output_pool, output_index, from_account, + to_address, to_account, value, memo) + VALUES ( + :tx, :output_pool, :output_index, :from_account, + :to_address, :to_account, :value, :memo)", + )?; + + let (to_address, to_account, pool_type) = recipient_params(params, output.recipient()); + let sql_args = named_params![ + ":tx": &tx_ref, + ":output_pool": &pool_code(pool_type), + ":output_index": &i64::try_from(output.output_index()).unwrap(), + ":from_account": &u32::from(from_account), + ":to_address": &to_address, + ":to_account": &to_account, + ":value": &i64::from(output.value()), + ":memo": output.memo().filter(|m| *m != &MemoBytes::empty()).map(|m| m.as_slice()), + ]; + + stmt_insert_sent_output.execute(sql_args)?; + + Ok(()) } -/// Records information about a transaction output that your wallet created. +/// Records information about a transaction output that your wallet created, from the constituent +/// properties of that output. /// -/// This is a crate-internal convenience method. +/// - If `recipient` is a Unified address, `output_index` is an index into the outputs of the +/// transaction within the bundle associated with the recipient's output pool. +/// - If `recipient` is a Sapling address, `output_index` is an index into the Sapling outputs of +/// the transaction. +/// - If `recipient` is a transparent address, `output_index` is an index into the transparent +/// outputs of the transaction. +/// - If `recipient` is an internal account, `output_index` is an index into the Sapling outputs of +/// the transaction. #[allow(clippy::too_many_arguments)] -pub(crate) fn put_sent_output<'a, P: consensus::Parameters>( - stmts: &mut DataConnStmtCache<'a, P>, +pub(crate) fn put_sent_output( + conn: &rusqlite::Connection, + params: &P, from_account: AccountId, tx_ref: i64, output_index: usize, @@ -882,16 +1048,34 @@ pub(crate) fn put_sent_output<'a, P: consensus::Parameters>( value: Amount, memo: Option<&MemoBytes>, ) -> Result<(), SqliteClientError> { - if !stmts.stmt_update_sent_output(from_account, recipient, value, memo, tx_ref, output_index)? { - stmts.stmt_insert_sent_output( - tx_ref, - output_index, - from_account, - recipient, - value, - memo, - )?; - } + let mut stmt_upsert_sent_output = conn.prepare_cached( + "INSERT INTO sent_notes ( + tx, output_pool, output_index, from_account, + to_address, to_account, value, memo) + VALUES ( + :tx, :output_pool, :output_index, :from_account, + :to_address, :to_account, :value, :memo) + ON CONFLICT (tx, output_pool, output_index) DO UPDATE + SET from_account = :from_account, + to_address = :to_address, + to_account = :to_account, + value = :value, + memo = IFNULL(:memo, memo)", + )?; + + let (to_address, to_account, pool_type) = recipient_params(params, recipient); + let sql_args = named_params![ + ":tx": &tx_ref, + ":output_pool": &pool_code(pool_type), + ":output_index": &i64::try_from(output_index).unwrap(), + ":from_account": &u32::from(from_account), + ":to_address": &to_address, + ":to_account": &to_account, + ":value": &i64::from(value), + ":memo": memo_repr(memo) + ]; + + stmt_upsert_sent_output.execute(sql_args)?; Ok(()) } @@ -931,11 +1115,11 @@ mod tests { init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); // Add an account to the wallet - tests::init_test_accounts_table(&db_data); + tests::init_test_accounts_table(&mut db_data); // The account should be empty assert_eq!( - get_balance(&db_data, AccountId::from(0)).unwrap(), + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), Amount::zero() ); @@ -943,9 +1127,12 @@ mod tests { assert_eq!(db_data.get_target_and_anchor_heights(10).unwrap(), None); // An invalid account has zero balance - assert_matches!(get_current_address(&db_data, AccountId::from(1)), Ok(None)); + assert_matches!( + get_current_address(&db_data.conn, &db_data.params, AccountId::from(1)), + Ok(None) + ); assert_eq!( - get_balance(&db_data, AccountId::from(0)).unwrap(), + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), Amount::zero() ); } @@ -958,9 +1145,8 @@ mod tests { init_wallet_db(&mut db_data, None).unwrap(); // Add an account to the wallet - let mut ops = db_data.get_update_ops().unwrap(); let seed = Secret::new([0u8; 32].to_vec()); - let (account_id, _usk) = ops.create_account(&seed).unwrap(); + let (account_id, _usk) = db_data.create_account(&seed).unwrap(); let uaddr = db_data.get_current_address(account_id).unwrap().unwrap(); let taddr = uaddr.transparent().unwrap(); @@ -979,7 +1165,7 @@ mod tests { ) .unwrap(); - let res0 = super::put_received_transparent_utxo(&mut ops, &utxo); + let res0 = super::put_received_transparent_utxo(&db_data.conn, &db_data.params, &utxo); assert_matches!(res0, Ok(_)); // Change the mined height of the UTXO and upsert; we should get back @@ -993,12 +1179,13 @@ mod tests { BlockHeight::from_u32(34567), ) .unwrap(); - let res1 = super::put_received_transparent_utxo(&mut ops, &utxo2); + let res1 = super::put_received_transparent_utxo(&db_data.conn, &db_data.params, &utxo2); assert_matches!(res1, Ok(id) if id == res0.unwrap()); assert_matches!( super::get_unspent_transparent_outputs( - &db_data, + &db_data.conn, + &db_data.params, taddr, BlockHeight::from_u32(12345), &[] @@ -1008,7 +1195,8 @@ mod tests { assert_matches!( super::get_unspent_transparent_outputs( - &db_data, + &db_data.conn, + &db_data.params, taddr, BlockHeight::from_u32(34567), &[] @@ -1034,7 +1222,7 @@ mod tests { ) .unwrap(); - let res2 = super::put_received_transparent_utxo(&mut ops, &utxo2); + let res2 = super::put_received_transparent_utxo(&db_data.conn, &db_data.params, &utxo2); assert_matches!(res2, Err(_)); } } diff --git a/zcash_client_sqlite/src/wallet/init.rs b/zcash_client_sqlite/src/wallet/init.rs index 665032291..7f5a60ccc 100644 --- a/zcash_client_sqlite/src/wallet/init.rs +++ b/zcash_client_sqlite/src/wallet/init.rs @@ -111,14 +111,14 @@ impl std::error::Error for WalletMigrationError { // check for unspent transparent outputs whenever running initialization with a version of the // library *not* compiled with the `transparent-inputs` feature flag, and fail if any are present. pub fn init_wallet_db( - wdb: &mut WalletDb

, + wdb: &mut WalletDb, seed: Option>, ) -> Result<(), MigratorError> { init_wallet_db_internal(wdb, seed, &[]) } fn init_wallet_db_internal( - wdb: &mut WalletDb

, + wdb: &mut WalletDb, seed: Option>, target_migrations: &[Uuid], ) -> Result<(), MigratorError> { @@ -200,7 +200,7 @@ fn init_wallet_db_internal( /// let dfvk = extsk.to_diversifiable_full_viewing_key(); /// let ufvk = UnifiedFullViewingKey::new(None, Some(dfvk), None).unwrap(); /// let ufvks = HashMap::from([(account, ufvk)]); -/// init_accounts_table(&db_data, &ufvks).unwrap(); +/// init_accounts_table(&mut db_data, &ufvks).unwrap(); /// # } /// ``` /// @@ -208,29 +208,29 @@ fn init_wallet_db_internal( /// [`scan_cached_blocks`]: zcash_client_backend::data_api::chain::scan_cached_blocks /// [`create_spend_to_address`]: zcash_client_backend::data_api::wallet::create_spend_to_address pub fn init_accounts_table( - wdb: &WalletDb

, + wallet_db: &mut WalletDb, keys: &HashMap, ) -> Result<(), SqliteClientError> { - let mut empty_check = wdb.conn.prepare("SELECT * FROM accounts LIMIT 1")?; - if empty_check.exists([])? { - return Err(SqliteClientError::TableNotEmpty); - } - - // Ensure that the account identifiers are sequential and begin at zero. - if let Some(account_id) = keys.keys().max() { - if usize::try_from(u32::from(*account_id)).unwrap() >= keys.len() { - return Err(SqliteClientError::AccountIdDiscontinuity); + wallet_db.transactionally(|wdb| { + let mut empty_check = wdb.conn.0.prepare("SELECT * FROM accounts LIMIT 1")?; + if empty_check.exists([])? { + return Err(SqliteClientError::TableNotEmpty); } - } - // Insert accounts atomically - wdb.conn.execute("BEGIN IMMEDIATE", [])?; - for (account, key) in keys.iter() { - wallet::add_account(wdb, *account, key)?; - } - wdb.conn.execute("COMMIT", [])?; + // Ensure that the account identifiers are sequential and begin at zero. + if let Some(account_id) = keys.keys().max() { + if usize::try_from(u32::from(*account_id)).unwrap() >= keys.len() { + return Err(SqliteClientError::AccountIdDiscontinuity); + } + } - Ok(()) + // Insert accounts atomically + for (account, key) in keys.iter() { + wallet::add_account(&wdb.conn.0, &wdb.params, *account, key)?; + } + + Ok(()) + }) } /// Initialises the data database with the given block. @@ -262,33 +262,35 @@ pub fn init_accounts_table( /// let sapling_tree = &[]; /// /// let data_file = NamedTempFile::new().unwrap(); -/// let db = WalletDb::for_path(data_file.path(), Network::TestNetwork).unwrap(); -/// init_blocks_table(&db, height, hash, time, sapling_tree); +/// let mut db = WalletDb::for_path(data_file.path(), Network::TestNetwork).unwrap(); +/// init_blocks_table(&mut db, height, hash, time, sapling_tree); /// ``` -pub fn init_blocks_table

( - wdb: &WalletDb

, +pub fn init_blocks_table( + wallet_db: &mut WalletDb, height: BlockHeight, hash: BlockHash, time: u32, sapling_tree: &[u8], ) -> Result<(), SqliteClientError> { - let mut empty_check = wdb.conn.prepare("SELECT * FROM blocks LIMIT 1")?; - if empty_check.exists([])? { - return Err(SqliteClientError::TableNotEmpty); - } + wallet_db.transactionally(|wdb| { + let mut empty_check = wdb.conn.0.prepare("SELECT * FROM blocks LIMIT 1")?; + if empty_check.exists([])? { + return Err(SqliteClientError::TableNotEmpty); + } - wdb.conn.execute( - "INSERT INTO blocks (height, hash, time, sapling_tree) + wdb.conn.0.execute( + "INSERT INTO blocks (height, hash, time, sapling_tree) VALUES (?, ?, ?, ?)", - [ - u32::from(height).to_sql()?, - hash.0.to_sql()?, - time.to_sql()?, - sapling_tree.to_sql()?, - ], - )?; + [ + u32::from(height).to_sql()?, + hash.0.to_sql()?, + time.to_sql()?, + sapling_tree.to_sql()?, + ], + )?; - Ok(()) + Ok(()) + }) } #[cfg(test)] @@ -606,7 +608,7 @@ mod tests { #[test] fn init_migrate_from_0_3_0() { fn init_0_3_0

( - wdb: &mut WalletDb

, + wdb: &mut WalletDb, extfvk: &ExtendedFullViewingKey, account: AccountId, ) -> Result<(), rusqlite::Error> { @@ -722,7 +724,7 @@ mod tests { #[test] fn init_migrate_from_autoshielding_poc() { fn init_autoshielding

( - wdb: &WalletDb

, + wdb: &mut WalletDb, extfvk: &ExtendedFullViewingKey, account: AccountId, ) -> Result<(), rusqlite::Error> { @@ -878,14 +880,14 @@ mod tests { let extfvk = secret_key.to_extended_full_viewing_key(); let data_file = NamedTempFile::new().unwrap(); let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap(); - init_autoshielding(&db_data, &extfvk, account).unwrap(); + init_autoshielding(&mut db_data, &extfvk, account).unwrap(); init_wallet_db(&mut db_data, Some(Secret::new(seed.to_vec()))).unwrap(); } #[test] fn init_migrate_from_main_pre_migrations() { fn init_main

( - wdb: &WalletDb

, + wdb: &mut WalletDb, ufvk: &UnifiedFullViewingKey, account: AccountId, ) -> Result<(), rusqlite::Error> { @@ -1025,7 +1027,12 @@ mod tests { let secret_key = UnifiedSpendingKey::from_seed(&tests::network(), &seed, account).unwrap(); let data_file = NamedTempFile::new().unwrap(); let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap(); - init_main(&db_data, &secret_key.to_unified_full_viewing_key(), account).unwrap(); + init_main( + &mut db_data, + &secret_key.to_unified_full_viewing_key(), + account, + ) + .unwrap(); init_wallet_db(&mut db_data, Some(Secret::new(seed.to_vec()))).unwrap(); } @@ -1036,8 +1043,8 @@ mod tests { init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); // We can call the function as many times as we want with no data - init_accounts_table(&db_data, &HashMap::new()).unwrap(); - init_accounts_table(&db_data, &HashMap::new()).unwrap(); + init_accounts_table(&mut db_data, &HashMap::new()).unwrap(); + init_accounts_table(&mut db_data, &HashMap::new()).unwrap(); let seed = [0u8; 32]; let account = AccountId::from(0); @@ -1062,11 +1069,11 @@ mod tests { let ufvk = UnifiedFullViewingKey::new(Some(dfvk), None).unwrap(); let ufvks = HashMap::from([(account, ufvk)]); - init_accounts_table(&db_data, &ufvks).unwrap(); + init_accounts_table(&mut db_data, &ufvks).unwrap(); // Subsequent calls should return an error - init_accounts_table(&db_data, &HashMap::new()).unwrap_err(); - init_accounts_table(&db_data, &ufvks).unwrap_err(); + init_accounts_table(&mut db_data, &HashMap::new()).unwrap_err(); + init_accounts_table(&mut db_data, &ufvks).unwrap_err(); } #[test] @@ -1090,12 +1097,12 @@ mod tests { // should fail if we have a gap assert_matches!( - init_accounts_table(&db_data, &ufvks(&[0, 2])), + init_accounts_table(&mut db_data, &ufvks(&[0, 2])), Err(SqliteClientError::AccountIdDiscontinuity) ); // should succeed if there are no gaps - assert!(init_accounts_table(&db_data, &ufvks(&[0, 1, 2])).is_ok()); + assert!(init_accounts_table(&mut db_data, &ufvks(&[0, 1, 2])).is_ok()); } #[test] @@ -1106,7 +1113,7 @@ mod tests { // First call with data should initialise the blocks table init_blocks_table( - &db_data, + &mut db_data, BlockHeight::from(1u32), BlockHash([1; 32]), 1, @@ -1116,7 +1123,7 @@ mod tests { // Subsequent calls should return an error init_blocks_table( - &db_data, + &mut db_data, BlockHeight::from(2u32), BlockHash([2; 32]), 2, @@ -1139,7 +1146,7 @@ mod tests { let ufvk = usk.to_unified_full_viewing_key(); let expected_address = ufvk.sapling().unwrap().default_address().1; let ufvks = HashMap::from([(account_id, ufvk)]); - init_accounts_table(&db_data, &ufvks).unwrap(); + init_accounts_table(&mut db_data, &ufvks).unwrap(); // The account's address should be in the data DB let ua = db_data.get_current_address(AccountId::from(0)).unwrap(); @@ -1153,16 +1160,15 @@ mod tests { let mut db_data = WalletDb::for_path(data_file.path(), Network::MainNetwork).unwrap(); init_wallet_db(&mut db_data, None).unwrap(); - let mut ops = db_data.get_update_ops().unwrap(); let seed = test_vectors::UNIFIED[0].root_seed; - let (account, _usk) = ops.create_account(&Secret::new(seed.to_vec())).unwrap(); + let (account, _usk) = db_data.create_account(&Secret::new(seed.to_vec())).unwrap(); assert_eq!(account, AccountId::from(0u32)); for tv in &test_vectors::UNIFIED[..3] { if let Some(RecipientAddress::Unified(tvua)) = RecipientAddress::decode(&Network::MainNetwork, tv.unified_addr) { - let (ua, di) = wallet::get_current_address(&db_data, account) + let (ua, di) = wallet::get_current_address(&db_data.conn, &db_data.params, account) .unwrap() .expect("create_account generated the first address"); assert_eq!(DiversifierIndex::from(tv.diversifier_index), di); @@ -1170,7 +1176,8 @@ mod tests { assert_eq!(tvua.sapling(), ua.sapling()); assert_eq!(tv.unified_addr, ua.encode(&Network::MainNetwork)); - ops.get_next_available_address(account) + db_data + .get_next_available_address(account) .unwrap() .expect("get_next_available_address generated an address"); } else { diff --git a/zcash_client_sqlite/src/wallet/init/migrations/add_utxo_account.rs b/zcash_client_sqlite/src/wallet/init/migrations/add_utxo_account.rs index f658ae103..cc3c61f6a 100644 --- a/zcash_client_sqlite/src/wallet/init/migrations/add_utxo_account.rs +++ b/zcash_client_sqlite/src/wallet/init/migrations/add_utxo_account.rs @@ -67,7 +67,7 @@ impl RusqliteMigration for Migration

{ while let Some(row) = rows.next()? { let account: u32 = row.get(0)?; let taddrs = - get_transparent_receivers(&self._params, transaction, AccountId::from(account)) + get_transparent_receivers(transaction, &self._params, AccountId::from(account)) .map_err(|e| match e { SqliteClientError::DbError(e) => WalletMigrationError::DbError(e), SqliteClientError::CorruptedData(s) => { diff --git a/zcash_client_sqlite/src/wallet/sapling.rs b/zcash_client_sqlite/src/wallet/sapling.rs index a18281dc8..66734cfdc 100644 --- a/zcash_client_sqlite/src/wallet/sapling.rs +++ b/zcash_client_sqlite/src/wallet/sapling.rs @@ -1,12 +1,12 @@ //! Functions for Sapling support in the wallet. use group::ff::PrimeField; -use rusqlite::{named_params, types::Value, OptionalExtension, Row}; +use rusqlite::{named_params, params, types::Value, Connection, OptionalExtension, Row}; use std::rc::Rc; use zcash_primitives::{ consensus::BlockHeight, memo::MemoBytes, - merkle_tree::{read_commitment_tree, read_incremental_witness}, + merkle_tree::{read_commitment_tree, read_incremental_witness, write_incremental_witness}, sapling::{self, Diversifier, Note, Nullifier, Rseed}, transaction::components::Amount, zip32::AccountId, @@ -17,7 +17,9 @@ use zcash_client_backend::{ DecryptedOutput, TransferType, }; -use crate::{error::SqliteClientError, DataConnStmtCache, NoteId, WalletDb}; +use crate::{error::SqliteClientError, NoteId}; + +use super::memo_repr; /// This trait provides a generalization over shielded output representations. pub(crate) trait ReceivedSaplingOutput { @@ -117,13 +119,13 @@ fn to_spendable_note(row: &Row) -> Result, SqliteCli }) } -pub(crate) fn get_spendable_sapling_notes

( - wdb: &WalletDb

, +pub(crate) fn get_spendable_sapling_notes( + conn: &Connection, account: AccountId, anchor_height: BlockHeight, exclude: &[NoteId], ) -> Result>, SqliteClientError> { - let mut stmt_select_notes = wdb.conn.prepare( + let mut stmt_select_notes = conn.prepare_cached( "SELECT id_note, diversifier, value, rcm, witness FROM sapling_received_notes INNER JOIN transactions ON transactions.id_tx = sapling_received_notes.tx @@ -156,8 +158,8 @@ pub(crate) fn get_spendable_sapling_notes

( notes.collect::>() } -pub(crate) fn select_spendable_sapling_notes

( - wdb: &WalletDb

, +pub(crate) fn select_spendable_sapling_notes( + conn: &Connection, account: AccountId, target_value: Amount, anchor_height: BlockHeight, @@ -181,7 +183,7 @@ pub(crate) fn select_spendable_sapling_notes

( // required value, bringing the sum of all selected notes across the threshold. // // 4) Match the selected notes against the witnesses at the desired height. - let mut stmt_select_notes = wdb.conn.prepare( + let mut stmt_select_notes = conn.prepare_cached( "WITH selected AS ( WITH eligible AS ( SELECT id_note, diversifier, value, rcm, @@ -189,8 +191,8 @@ pub(crate) fn select_spendable_sapling_notes

( (PARTITION BY account, spent ORDER BY id_note) AS so_far FROM sapling_received_notes INNER JOIN transactions ON transactions.id_tx = sapling_received_notes.tx - WHERE account = :account - AND spent IS NULL + WHERE account = :account + AND spent IS NULL AND transactions.block <= :anchor_height AND id_note NOT IN rarray(:exclude) ) @@ -230,43 +232,42 @@ pub(crate) fn select_spendable_sapling_notes

( /// Returns the commitment tree for the block at the specified height, /// if any. -pub(crate) fn get_sapling_commitment_tree

( - wdb: &WalletDb

, +pub(crate) fn get_sapling_commitment_tree( + conn: &Connection, block_height: BlockHeight, ) -> Result, SqliteClientError> { - wdb.conn - .query_row_and_then( - "SELECT sapling_tree FROM blocks WHERE height = ?", - [u32::from(block_height)], - |row| { - let row_data: Vec = row.get(0)?; - read_commitment_tree(&row_data[..]).map_err(|e| { - rusqlite::Error::FromSqlConversionFailure( - row_data.len(), - rusqlite::types::Type::Blob, - Box::new(e), - ) - }) - }, - ) - .optional() - .map_err(SqliteClientError::from) + conn.query_row_and_then( + "SELECT sapling_tree FROM blocks WHERE height = ?", + [u32::from(block_height)], + |row| { + let row_data: Vec = row.get(0)?; + read_commitment_tree(&row_data[..]).map_err(|e| { + rusqlite::Error::FromSqlConversionFailure( + row_data.len(), + rusqlite::types::Type::Blob, + Box::new(e), + ) + }) + }, + ) + .optional() + .map_err(SqliteClientError::from) } /// Returns the incremental witnesses for the block at the specified height, /// if any. -pub(crate) fn get_sapling_witnesses

( - wdb: &WalletDb

, +pub(crate) fn get_sapling_witnesses( + conn: &Connection, block_height: BlockHeight, ) -> Result, SqliteClientError> { - let mut stmt_fetch_witnesses = wdb - .conn - .prepare("SELECT note, witness FROM sapling_witnesses WHERE block = ?")?; + let mut stmt_fetch_witnesses = + conn.prepare_cached("SELECT note, witness FROM sapling_witnesses WHERE block = ?")?; + let witnesses = stmt_fetch_witnesses .query_map([u32::from(block_height)], |row| { let id_note = NoteId::ReceivedNoteId(row.get(0)?); - let wdb: Vec = row.get(1)?; - Ok(read_incremental_witness(&wdb[..]).map(|witness| (id_note, witness))) + let witness_data: Vec = row.get(1)?; + Ok(read_incremental_witness(&witness_data[..]).map(|witness| (id_note, witness))) }) .map_err(SqliteClientError::from)?; @@ -277,13 +278,23 @@ pub(crate) fn get_sapling_witnesses

( /// Records the incremental witness for the specified note, /// as of the given block height. -pub(crate) fn insert_witness<'a, P>( - stmts: &mut DataConnStmtCache<'a, P>, +pub(crate) fn insert_witness( + conn: &Connection, note_id: i64, witness: &sapling::IncrementalWitness, height: BlockHeight, ) -> Result<(), SqliteClientError> { - stmts.stmt_insert_witness(NoteId::ReceivedNoteId(note_id), height, witness) + let mut stmt_insert_witness = conn.prepare_cached( + "INSERT INTO sapling_witnesses (note, block, witness) + VALUES (?, ?, ?)", + )?; + + let mut encoded = Vec::new(); + write_incremental_witness(witness, &mut encoded).unwrap(); + + stmt_insert_witness.execute(params![note_id, u32::from(height), encoded])?; + + Ok(()) } /// Retrieves the set of nullifiers for "potentially spendable" Sapling notes that the @@ -292,11 +303,11 @@ pub(crate) fn insert_witness<'a, P>( /// "Potentially spendable" means: /// - The transaction in which the note was created has been observed as mined. /// - No transaction in which the note's nullifier appears has been observed as mined. -pub(crate) fn get_sapling_nullifiers

( - wdb: &WalletDb

, +pub(crate) fn get_sapling_nullifiers( + conn: &Connection, ) -> Result, SqliteClientError> { // Get the nullifiers for the notes we are tracking - let mut stmt_fetch_nullifiers = wdb.conn.prepare( + let mut stmt_fetch_nullifiers = conn.prepare( "SELECT rn.id_note, rn.account, rn.nf, tx.block as block FROM sapling_received_notes rn LEFT OUTER JOIN transactions tx @@ -318,11 +329,11 @@ pub(crate) fn get_sapling_nullifiers

( } /// Returns the nullifiers for the notes that this wallet is tracking. -pub(crate) fn get_all_sapling_nullifiers

( - wdb: &WalletDb

, +pub(crate) fn get_all_sapling_nullifiers( + conn: &Connection, ) -> Result, SqliteClientError> { // Get the nullifiers for the notes we are tracking - let mut stmt_fetch_nullifiers = wdb.conn.prepare( + let mut stmt_fetch_nullifiers = conn.prepare( "SELECT rn.id_note, rn.account, rn.nf FROM sapling_received_notes rn WHERE nf IS NOT NULL", @@ -345,13 +356,19 @@ pub(crate) fn get_all_sapling_nullifiers

( /// /// Marking a note spent in this fashion does NOT imply that the /// spending transaction has been mined. -pub(crate) fn mark_sapling_note_spent<'a, P>( - stmts: &mut DataConnStmtCache<'a, P>, +pub(crate) fn mark_sapling_note_spent( + conn: &Connection, tx_ref: i64, nf: &Nullifier, -) -> Result<(), SqliteClientError> { - stmts.stmt_mark_sapling_note_spent(tx_ref, nf)?; - Ok(()) +) -> Result { + let mut stmt_mark_sapling_note_spent = + conn.prepare_cached("UPDATE sapling_received_notes SET spent = ? WHERE nf = ?")?; + + match stmt_mark_sapling_note_spent.execute(params![tx_ref, &nf.0[..]])? { + 0 => Ok(false), + 1 => Ok(true), + _ => unreachable!("nf column is marked as UNIQUE"), + } } /// Records the specified shielded output as having been received. @@ -359,49 +376,48 @@ pub(crate) fn mark_sapling_note_spent<'a, P>( /// This implementation relies on the facts that: /// - A transaction will not contain more than 2^63 shielded outputs. /// - A note value will never exceed 2^63 zatoshis. -pub(crate) fn put_received_note<'a, P, T: ReceivedSaplingOutput>( - stmts: &mut DataConnStmtCache<'a, P>, +pub(crate) fn put_received_note( + conn: &Connection, output: &T, tx_ref: i64, ) -> Result { + let mut stmt_upsert_received_note = conn.prepare_cached( + "INSERT INTO sapling_received_notes + (tx, output_index, account, diversifier, value, rcm, memo, nf, is_change) + VALUES + (:tx, :output_index, :account, :diversifier, :value, :rcm, :memo, :nf, :is_change) + ON CONFLICT (tx, output_index) DO UPDATE + SET account = :account, + diversifier = :diversifier, + value = :value, + rcm = :rcm, + nf = IFNULL(:nf, nf), + memo = IFNULL(:memo, memo), + is_change = IFNULL(:is_change, is_change) + RETURNING id_note", + )?; + let rcm = output.note().rcm().to_repr(); - let account = output.account(); let to = output.note().recipient(); let diversifier = to.diversifier(); - let value = output.note().value(); - let memo = output.memo(); - let is_change = output.is_change(); - let output_index = output.index(); - let nf = output.nullifier(); - // First try updating an existing received note into the database. - if !stmts.stmt_update_received_note( - account, - diversifier, - value.inner(), - rcm, - nf, - memo, - is_change, - tx_ref, - output_index, - )? { - // It isn't there, so insert our note into the database. - stmts.stmt_insert_received_note( - tx_ref, - output_index, - account, - diversifier, - value.inner(), - rcm, - nf, - memo, - is_change, - ) - } else { - // It was there, so grab its row number. - stmts.stmt_select_received_note(tx_ref, output.index()) - } + let sql_args = named_params![ + ":tx": &tx_ref, + ":output_index": i64::try_from(output.index()).expect("output indices are representable as i64"), + ":account": u32::from(output.account()), + ":diversifier": &diversifier.0.as_ref(), + ":value": output.note().value().inner(), + ":rcm": &rcm.as_ref(), + ":nf": output.nullifier().map(|nf| nf.0.as_ref()), + ":memo": memo_repr(output.memo()), + ":is_change": output.is_change() + ]; + + stmt_upsert_received_note + .query_row(sql_args, |row| { + row.get::<_, i64>(0).map(NoteId::ReceivedNoteId) + }) + .map_err(SqliteClientError::from) } #[cfg(test)] @@ -447,7 +463,7 @@ mod tests { get_balance, get_balance_at, init::{init_blocks_table, init_wallet_db}, }, - AccountId, BlockDb, DataConnStmtCache, WalletDb, + AccountId, BlockDb, WalletDb, }; #[cfg(feature = "transparent-inputs")] @@ -481,9 +497,8 @@ mod tests { init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); // Add an account to the wallet - let mut ops = db_data.get_update_ops().unwrap(); let seed = Secret::new([0u8; 32].to_vec()); - let (_, usk) = ops.create_account(&seed).unwrap(); + let (_, usk) = db_data.create_account(&seed).unwrap(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); let to = dfvk.default_address().1.into(); @@ -492,10 +507,9 @@ mod tests { let usk1 = UnifiedSpendingKey::from_seed(&network(), &[1u8; 32], acct1).unwrap(); // Attempting to spend with a USK that is not in the wallet results in an error - let mut db_write = db_data.get_update_ops().unwrap(); assert_matches!( create_spend_to_address( - &mut db_write, + &mut db_data, &tests::network(), test_prover(), &usk1, @@ -516,17 +530,15 @@ mod tests { init_wallet_db(&mut db_data, None).unwrap(); // Add an account to the wallet - let mut ops = db_data.get_update_ops().unwrap(); let seed = Secret::new([0u8; 32].to_vec()); - let (_, usk) = ops.create_account(&seed).unwrap(); + let (_, usk) = db_data.create_account(&seed).unwrap(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); let to = dfvk.default_address().1.into(); // We cannot do anything if we aren't synchronised - let mut db_write = db_data.get_update_ops().unwrap(); assert_matches!( create_spend_to_address( - &mut db_write, + &mut db_data, &tests::network(), test_prover(), &usk, @@ -546,7 +558,7 @@ mod tests { let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap(); init_wallet_db(&mut db_data, None).unwrap(); init_blocks_table( - &db_data, + &mut db_data, BlockHeight::from(1u32), BlockHash([1; 32]), 1, @@ -555,23 +567,21 @@ mod tests { .unwrap(); // Add an account to the wallet - let mut ops = db_data.get_update_ops().unwrap(); let seed = Secret::new([0u8; 32].to_vec()); - let (_, usk) = ops.create_account(&seed).unwrap(); + let (_, usk) = db_data.create_account(&seed).unwrap(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); let to = dfvk.default_address().1.into(); // Account balance should be zero assert_eq!( - get_balance(&db_data, AccountId::from(0)).unwrap(), + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), Amount::zero() ); // We cannot spend anything - let mut db_write = db_data.get_update_ops().unwrap(); assert_matches!( create_spend_to_address( - &mut db_write, + &mut db_data, &tests::network(), test_prover(), &usk, @@ -600,9 +610,8 @@ mod tests { init_wallet_db(&mut db_data, None).unwrap(); // Add an account to the wallet - let mut ops = db_data.get_update_ops().unwrap(); let seed = Secret::new([0u8; 32].to_vec()); - let (_, usk) = ops.create_account(&seed).unwrap(); + let (_, usk) = db_data.create_account(&seed).unwrap(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); // Add funds to the wallet in a single note @@ -615,14 +624,16 @@ mod tests { value, ); insert_into_cache(&db_cache, &cb); - let mut db_write = db_data.get_update_ops().unwrap(); - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Verified balance matches total balance let (_, anchor_height) = db_data.get_target_and_anchor_heights(10).unwrap().unwrap(); - assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value); assert_eq!( - get_balance_at(&db_data, AccountId::from(0), anchor_height).unwrap(), + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), + value + ); + assert_eq!( + get_balance_at(&db_data.conn, AccountId::from(0), anchor_height).unwrap(), value ); @@ -635,16 +646,16 @@ mod tests { value, ); insert_into_cache(&db_cache, &cb); - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Verified balance does not include the second note let (_, anchor_height2) = db_data.get_target_and_anchor_heights(10).unwrap().unwrap(); assert_eq!( - get_balance(&db_data, AccountId::from(0)).unwrap(), + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), (value + value).unwrap() ); assert_eq!( - get_balance_at(&db_data, AccountId::from(0), anchor_height2).unwrap(), + get_balance_at(&db_data.conn, AccountId::from(0), anchor_height2).unwrap(), value ); @@ -653,7 +664,7 @@ mod tests { let to = extsk2.default_address().1.into(); assert_matches!( create_spend_to_address( - &mut db_write, + &mut db_data, &tests::network(), test_prover(), &usk, @@ -683,12 +694,12 @@ mod tests { ); insert_into_cache(&db_cache, &cb); } - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Second spend still fails assert_matches!( create_spend_to_address( - &mut db_write, + &mut db_data, &tests::network(), test_prover(), &usk, @@ -715,12 +726,12 @@ mod tests { value, ); insert_into_cache(&db_cache, &cb); - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Second spend should now succeed assert_matches!( create_spend_to_address( - &mut db_write, + &mut db_data, &tests::network(), test_prover(), &usk, @@ -745,9 +756,8 @@ mod tests { init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); // Add an account to the wallet - let mut ops = db_data.get_update_ops().unwrap(); let seed = Secret::new([0u8; 32].to_vec()); - let (_, usk) = ops.create_account(&seed).unwrap(); + let (_, usk) = db_data.create_account(&seed).unwrap(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); // Add funds to the wallet in a single note @@ -760,16 +770,18 @@ mod tests { value, ); insert_into_cache(&db_cache, &cb); - let mut db_write = db_data.get_update_ops().unwrap(); - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); - assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); + assert_eq!( + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), + value + ); // Send some of the funds to another address let extsk2 = ExtendedSpendingKey::master(&[]); let to = extsk2.default_address().1.into(); assert_matches!( create_spend_to_address( - &mut db_write, + &mut db_data, &tests::network(), test_prover(), &usk, @@ -785,7 +797,7 @@ mod tests { // A second spend fails because there are no usable notes assert_matches!( create_spend_to_address( - &mut db_write, + &mut db_data, &tests::network(), test_prover(), &usk, @@ -814,12 +826,12 @@ mod tests { ); insert_into_cache(&db_cache, &cb); } - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Second spend still fails assert_matches!( create_spend_to_address( - &mut db_write, + &mut db_data, &tests::network(), test_prover(), &usk, @@ -845,11 +857,11 @@ mod tests { value, ); insert_into_cache(&db_cache, &cb); - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Second spend should now succeed create_spend_to_address( - &mut db_write, + &mut db_data, &tests::network(), test_prover(), &usk, @@ -874,9 +886,8 @@ mod tests { init_wallet_db(&mut db_data, None).unwrap(); // Add an account to the wallet - let mut ops = db_data.get_update_ops().unwrap(); let seed = Secret::new([0u8; 32].to_vec()); - let (_, usk) = ops.create_account(&seed).unwrap(); + let (_, usk) = db_data.create_account(&seed).unwrap(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); // Add funds to the wallet in a single note @@ -889,17 +900,19 @@ mod tests { value, ); insert_into_cache(&db_cache, &cb); - let mut db_write = db_data.get_update_ops().unwrap(); - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); - assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); + assert_eq!( + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), + value + ); let extsk2 = ExtendedSpendingKey::master(&[]); let addr2 = extsk2.default_address().1; let to = addr2.into(); - let send_and_recover_with_policy = |db_write: &mut DataConnStmtCache<'_, _>, ovk_policy| { + let send_and_recover_with_policy = |db_data: &mut WalletDb, ovk_policy| { let tx_row = create_spend_to_address( - db_write, + db_data, &tests::network(), test_prover(), &usk, @@ -912,8 +925,7 @@ mod tests { .unwrap(); // Fetch the transaction from the database - let raw_tx: Vec<_> = db_write - .wallet_db + let raw_tx: Vec<_> = db_data .conn .query_row( "SELECT raw FROM transactions @@ -944,7 +956,7 @@ mod tests { // Send some of the funds to another address, keeping history. // The recipient output is decryptable by the sender. let (_, recovered_to, _) = - send_and_recover_with_policy(&mut db_write, OvkPolicy::Sender).unwrap(); + send_and_recover_with_policy(&mut db_data, OvkPolicy::Sender).unwrap(); assert_eq!(&recovered_to, &addr2); // Mine blocks SAPLING_ACTIVATION_HEIGHT + 1 to 42 (that don't send us funds) @@ -959,11 +971,11 @@ mod tests { ); insert_into_cache(&db_cache, &cb); } - scan_cached_blocks(&network, &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&network, &db_cache, &mut db_data, None).unwrap(); // Send the funds again, discarding history. // Neither transaction output is decryptable by the sender. - assert!(send_and_recover_with_policy(&mut db_write, OvkPolicy::Discard).is_none()); + assert!(send_and_recover_with_policy(&mut db_data, OvkPolicy::Discard).is_none()); } #[test] @@ -977,9 +989,8 @@ mod tests { init_wallet_db(&mut db_data, None).unwrap(); // Add an account to the wallet - let mut ops = db_data.get_update_ops().unwrap(); let seed = Secret::new([0u8; 32].to_vec()); - let (_, usk) = ops.create_account(&seed).unwrap(); + let (_, usk) = db_data.create_account(&seed).unwrap(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); // Add funds to the wallet in a single note @@ -992,21 +1003,23 @@ mod tests { value, ); insert_into_cache(&db_cache, &cb); - let mut db_write = db_data.get_update_ops().unwrap(); - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Verified balance matches total balance let (_, anchor_height) = db_data.get_target_and_anchor_heights(10).unwrap().unwrap(); - assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value); assert_eq!( - get_balance_at(&db_data, AccountId::from(0), anchor_height).unwrap(), + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), + value + ); + assert_eq!( + get_balance_at(&db_data.conn, AccountId::from(0), anchor_height).unwrap(), value ); let to = TransparentAddress::PublicKey([7; 20]).into(); assert_matches!( create_spend_to_address( - &mut db_write, + &mut db_data, &tests::network(), test_prover(), &usk, @@ -1031,9 +1044,8 @@ mod tests { init_wallet_db(&mut db_data, None).unwrap(); // Add an account to the wallet - let mut ops = db_data.get_update_ops().unwrap(); let seed = Secret::new([0u8; 32].to_vec()); - let (_, usk) = ops.create_account(&seed).unwrap(); + let (_, usk) = db_data.create_account(&seed).unwrap(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); // Add funds to the wallet in a single note @@ -1046,21 +1058,23 @@ mod tests { value, ); insert_into_cache(&db_cache, &cb); - let mut db_write = db_data.get_update_ops().unwrap(); - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Verified balance matches total balance let (_, anchor_height) = db_data.get_target_and_anchor_heights(10).unwrap().unwrap(); - assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value); assert_eq!( - get_balance_at(&db_data, AccountId::from(0), anchor_height).unwrap(), + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), + value + ); + assert_eq!( + get_balance_at(&db_data.conn, AccountId::from(0), anchor_height).unwrap(), value ); let to = TransparentAddress::PublicKey([7; 20]).into(); assert_matches!( create_spend_to_address( - &mut db_write, + &mut db_data, &tests::network(), test_prover(), &usk, @@ -1085,9 +1099,8 @@ mod tests { init_wallet_db(&mut db_data, None).unwrap(); // Add an account to the wallet - let mut ops = db_data.get_update_ops().unwrap(); let seed = Secret::new([0u8; 32].to_vec()); - let (_, usk) = ops.create_account(&seed).unwrap(); + let (_, usk) = db_data.create_account(&seed).unwrap(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); // Add funds to the wallet @@ -1112,15 +1125,17 @@ mod tests { insert_into_cache(&db_cache, &cb); } - let mut db_write = db_data.get_update_ops().unwrap(); - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); // Verified balance matches total balance let total = Amount::from_u64(60000).unwrap(); let (_, anchor_height) = db_data.get_target_and_anchor_heights(1).unwrap().unwrap(); - assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), total); assert_eq!( - get_balance_at(&db_data, AccountId::from(0), anchor_height).unwrap(), + get_balance(&db_data.conn, AccountId::from(0)).unwrap(), + total + ); + assert_eq!( + get_balance_at(&db_data.conn, AccountId::from(0), anchor_height).unwrap(), total ); @@ -1142,7 +1157,7 @@ mod tests { assert_matches!( spend( - &mut db_write, + &mut db_data, &tests::network(), test_prover(), &input_selector, @@ -1170,7 +1185,7 @@ mod tests { assert_matches!( spend( - &mut db_write, + &mut db_data, &tests::network(), test_prover(), &input_selector, @@ -1195,9 +1210,8 @@ mod tests { init_wallet_db(&mut db_data, None).unwrap(); // Add an account to the wallet - let mut db_write = db_data.get_update_ops().unwrap(); let seed = Secret::new([0u8; 32].to_vec()); - let (account_id, usk) = db_write.create_account(&seed).unwrap(); + let (account_id, usk) = db_data.create_account(&seed).unwrap(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); let uaddr = db_data.get_current_address(account_id).unwrap().unwrap(); let taddr = uaddr.transparent().unwrap(); @@ -1212,7 +1226,7 @@ mod tests { ) .unwrap(); - let res0 = db_write.put_received_transparent_utxo(&utxo); + let res0 = db_data.put_received_transparent_utxo(&utxo); assert!(matches!(res0, Ok(_))); let input_selector = GreedyInputSelector::new( @@ -1229,11 +1243,11 @@ mod tests { Amount::from_u64(50000).unwrap(), ); insert_into_cache(&db_cache, &cb); - scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); + scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap(); assert_matches!( shield_transparent_funds( - &mut db_write, + &mut db_data, &tests::network(), test_prover(), &input_selector,