Merge branch 'wallet/sqlite_cached_statements'

This commit is contained in:
Kris Nuttycombe 2023-06-19 07:55:08 -06:00
commit d5166b5134
8 changed files with 969 additions and 1646 deletions

View File

@ -47,6 +47,7 @@ uuid = "1.1"
[dev-dependencies] [dev-dependencies]
assert_matches = "1.5" assert_matches = "1.5"
incrementalmerkletree = { version = "0.4", features = ["legacy-api", "test-dependencies"] }
proptest = "1.0.0" proptest = "1.0.0"
rand_core = "0.6" rand_core = "0.6"
regex = "1.4" regex = "1.4"
@ -59,6 +60,7 @@ zcash_address = { version = "0.3", path = "../components/zcash_address", feature
[features] [features]
mainnet = [] mainnet = []
test-dependencies = [ test-dependencies = [
"incrementalmerkletree/test-dependencies",
"zcash_primitives/test-dependencies", "zcash_primitives/test-dependencies",
"zcash_client_backend/test-dependencies", "zcash_client_backend/test-dependencies",
] ]

View File

@ -299,7 +299,7 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet // Add an account to the wallet
let (dfvk, _taddr) = init_test_accounts_table(&db_data); let (dfvk, _taddr) = init_test_accounts_table(&mut db_data);
// Empty chain should return None // Empty chain should return None
assert_matches!(db_data.get_max_height_hash(), Ok(None)); assert_matches!(db_data.get_max_height_hash(), Ok(None));
@ -328,8 +328,7 @@ mod tests {
assert_matches!(validate_chain_result, Ok(())); assert_matches!(validate_chain_result, Ok(()));
// Scan the cache // Scan the cache
let mut db_write = db_data.get_update_ops().unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Data-only chain should be valid // Data-only chain should be valid
validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap(); validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap();
@ -348,7 +347,7 @@ mod tests {
validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap(); validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap();
// Scan the cache again // Scan the cache again
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Data-only chain should be valid // Data-only chain should be valid
validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap(); validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap();
@ -365,7 +364,7 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet // Add an account to the wallet
let (dfvk, _taddr) = init_test_accounts_table(&db_data); let (dfvk, _taddr) = init_test_accounts_table(&mut db_data);
// Create some fake CompactBlocks // Create some fake CompactBlocks
let (cb, _) = fake_compact_block( let (cb, _) = fake_compact_block(
@ -386,8 +385,7 @@ mod tests {
insert_into_cache(&db_cache, &cb2); insert_into_cache(&db_cache, &cb2);
// Scan the cache // Scan the cache
let mut db_write = db_data.get_update_ops().unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Data-only chain should be valid // Data-only chain should be valid
validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap(); validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap();
@ -427,7 +425,7 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet // Add an account to the wallet
let (dfvk, _taddr) = init_test_accounts_table(&db_data); let (dfvk, _taddr) = init_test_accounts_table(&mut db_data);
// Create some fake CompactBlocks // Create some fake CompactBlocks
let (cb, _) = fake_compact_block( let (cb, _) = fake_compact_block(
@ -448,8 +446,7 @@ mod tests {
insert_into_cache(&db_cache, &cb2); insert_into_cache(&db_cache, &cb2);
// Scan the cache // Scan the cache
let mut db_write = db_data.get_update_ops().unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Data-only chain should be valid // Data-only chain should be valid
validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap(); validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap();
@ -489,11 +486,11 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet // Add an account to the wallet
let (dfvk, _taddr) = init_test_accounts_table(&db_data); let (dfvk, _taddr) = init_test_accounts_table(&mut db_data);
// Account balance should be zero // Account balance should be zero
assert_eq!( assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(), get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
Amount::zero() Amount::zero()
); );
@ -519,36 +516,46 @@ mod tests {
insert_into_cache(&db_cache, &cb2); insert_into_cache(&db_cache, &cb2);
// Scan the cache // Scan the cache
let mut db_write = db_data.get_update_ops().unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Account balance should reflect both received notes // Account balance should reflect both received notes
assert_eq!( assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(), get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
(value + value2).unwrap() (value + value2).unwrap()
); );
// "Rewind" to height of last scanned block // "Rewind" to height of last scanned block
truncate_to_height(&db_data, sapling_activation_height() + 1).unwrap(); db_data
.transactionally(|wdb| {
truncate_to_height(&wdb.conn.0, &wdb.params, sapling_activation_height() + 1)
})
.unwrap();
// Account balance should be unaltered // Account balance should be unaltered
assert_eq!( assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(), get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
(value + value2).unwrap() (value + value2).unwrap()
); );
// Rewind so that one block is dropped // Rewind so that one block is dropped
truncate_to_height(&db_data, sapling_activation_height()).unwrap(); db_data
.transactionally(|wdb| {
truncate_to_height(&wdb.conn.0, &wdb.params, sapling_activation_height())
})
.unwrap();
// Account balance should only contain the first received note // Account balance should only contain the first received note
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value); assert_eq!(
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
);
// Scan the cache again // Scan the cache again
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Account balance should again reflect both received notes // Account balance should again reflect both received notes
assert_eq!( assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(), get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
(value + value2).unwrap() (value + value2).unwrap()
); );
} }
@ -564,7 +571,7 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet // Add an account to the wallet
let (dfvk, _taddr) = init_test_accounts_table(&db_data); let (dfvk, _taddr) = init_test_accounts_table(&mut db_data);
// Create a block with height SAPLING_ACTIVATION_HEIGHT // Create a block with height SAPLING_ACTIVATION_HEIGHT
let value = Amount::from_u64(50000).unwrap(); let value = Amount::from_u64(50000).unwrap();
@ -576,9 +583,11 @@ mod tests {
value, value,
); );
insert_into_cache(&db_cache, &cb1); insert_into_cache(&db_cache, &cb1);
let mut db_write = db_data.get_update_ops().unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); assert_eq!(
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value); get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
);
// We cannot scan a block of height SAPLING_ACTIVATION_HEIGHT + 2 next // We cannot scan a block of height SAPLING_ACTIVATION_HEIGHT + 2 next
let (cb2, _) = fake_compact_block( let (cb2, _) = fake_compact_block(
@ -596,7 +605,7 @@ mod tests {
value, value,
); );
insert_into_cache(&db_cache, &cb3); insert_into_cache(&db_cache, &cb3);
match scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None) { match scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None) {
Err(Error::Chain(e)) => { Err(Error::Chain(e)) => {
assert_matches!( assert_matches!(
e.cause(), e.cause(),
@ -609,9 +618,9 @@ mod tests {
// If we add a block of height SAPLING_ACTIVATION_HEIGHT + 1, we can now scan both // If we add a block of height SAPLING_ACTIVATION_HEIGHT + 1, we can now scan both
insert_into_cache(&db_cache, &cb2); insert_into_cache(&db_cache, &cb2);
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
assert_eq!( assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(), get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
Amount::from_u64(150_000).unwrap() Amount::from_u64(150_000).unwrap()
); );
} }
@ -627,11 +636,11 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet // Add an account to the wallet
let (dfvk, _taddr) = init_test_accounts_table(&db_data); let (dfvk, _taddr) = init_test_accounts_table(&mut db_data);
// Account balance should be zero // Account balance should be zero
assert_eq!( assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(), get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
Amount::zero() Amount::zero()
); );
@ -647,11 +656,13 @@ mod tests {
insert_into_cache(&db_cache, &cb); insert_into_cache(&db_cache, &cb);
// Scan the cache // Scan the cache
let mut db_write = db_data.get_update_ops().unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Account balance should reflect the received note // Account balance should reflect the received note
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value); assert_eq!(
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
);
// Create a second fake CompactBlock sending more value to the address // Create a second fake CompactBlock sending more value to the address
let value2 = Amount::from_u64(7).unwrap(); let value2 = Amount::from_u64(7).unwrap();
@ -665,11 +676,11 @@ mod tests {
insert_into_cache(&db_cache, &cb2); insert_into_cache(&db_cache, &cb2);
// Scan the cache again // Scan the cache again
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Account balance should reflect both received notes // Account balance should reflect both received notes
assert_eq!( assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(), get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
(value + value2).unwrap() (value + value2).unwrap()
); );
} }
@ -685,11 +696,11 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet // Add an account to the wallet
let (dfvk, _taddr) = init_test_accounts_table(&db_data); let (dfvk, _taddr) = init_test_accounts_table(&mut db_data);
// Account balance should be zero // Account balance should be zero
assert_eq!( assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(), get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
Amount::zero() Amount::zero()
); );
@ -705,11 +716,13 @@ mod tests {
insert_into_cache(&db_cache, &cb); insert_into_cache(&db_cache, &cb);
// Scan the cache // Scan the cache
let mut db_write = db_data.get_update_ops().unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Account balance should reflect the received note // Account balance should reflect the received note
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value); assert_eq!(
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
);
// Create a second fake CompactBlock spending value from the address // Create a second fake CompactBlock spending value from the address
let extsk2 = ExtendedSpendingKey::master(&[0]); let extsk2 = ExtendedSpendingKey::master(&[0]);
@ -728,11 +741,11 @@ mod tests {
); );
// Scan the cache again // Scan the cache again
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Account balance should equal the change // Account balance should equal the change
assert_eq!( assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(), get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
(value - value2).unwrap() (value - value2).unwrap()
); );
} }

View File

@ -32,18 +32,16 @@
// Catch documentation errors caused by code changes. // Catch documentation errors caused by code changes.
#![deny(rustdoc::broken_intra_doc_links)] #![deny(rustdoc::broken_intra_doc_links)]
use rusqlite::Connection; use rusqlite::{self, Connection};
use secrecy::{ExposeSecret, SecretVec}; use secrecy::{ExposeSecret, SecretVec};
use std::collections::HashMap; use std::{borrow::Borrow, collections::HashMap, convert::AsRef, fmt, path::Path};
use std::fmt;
use std::path::Path;
use zcash_primitives::{ use zcash_primitives::{
block::BlockHash, block::BlockHash,
consensus::{self, BlockHeight}, consensus::{self, BlockHeight},
legacy::TransparentAddress, legacy::TransparentAddress,
memo::{Memo, MemoBytes}, memo::{Memo, MemoBytes},
sapling::{self}, sapling,
transaction::{ transaction::{
components::{amount::Amount, OutPoint}, components::{amount::Amount, OutPoint},
Transaction, TxId, Transaction, TxId,
@ -72,9 +70,6 @@ use {
std::{fs, io}, std::{fs, io},
}; };
mod prepared;
pub use prepared::DataConnStmtCache;
pub mod chain; pub mod chain;
pub mod error; pub mod error;
pub mod wallet; pub mod wallet;
@ -107,12 +102,21 @@ impl fmt::Display for NoteId {
pub struct UtxoId(pub i64); pub struct UtxoId(pub i64);
/// A wrapper for the SQLite connection to the wallet database. /// A wrapper for the SQLite connection to the wallet database.
pub struct WalletDb<P> { pub struct WalletDb<C, P> {
conn: Connection, conn: C,
params: P, params: P,
} }
impl<P: consensus::Parameters> WalletDb<P> { /// A wrapper for a SQLite transaction affecting the wallet database.
pub struct SqlTransaction<'conn>(pub(crate) rusqlite::Transaction<'conn>);
impl Borrow<rusqlite::Connection> for SqlTransaction<'_> {
fn borrow(&self) -> &rusqlite::Connection {
&self.0
}
}
impl<P: consensus::Parameters + Clone> WalletDb<Connection, P> {
/// Construct a connection to the wallet database stored at the specified path. /// Construct a connection to the wallet database stored at the specified path.
pub fn for_path<F: AsRef<Path>>(path: F, params: P) -> Result<Self, rusqlite::Error> { pub fn for_path<F: AsRef<Path>>(path: F, params: P) -> Result<Self, rusqlite::Error> {
Connection::open(path).and_then(move |conn| { Connection::open(path).and_then(move |conn| {
@ -121,53 +125,60 @@ impl<P: consensus::Parameters> WalletDb<P> {
}) })
} }
/// Given a wallet database connection, obtain a handle for the write operations pub fn transactionally<F, A>(&mut self, f: F) -> Result<A, SqliteClientError>
/// for that database. This operation may eagerly initialize and cache sqlite where
/// prepared statements that are used in write operations. F: FnOnce(&WalletDb<SqlTransaction<'_>, P>) -> Result<A, SqliteClientError>,
pub fn get_update_ops(&self) -> Result<DataConnStmtCache<'_, P>, SqliteClientError> { {
DataConnStmtCache::new(self) let wdb = WalletDb {
conn: SqlTransaction(self.conn.transaction()?),
params: self.params.clone(),
};
let result = f(&wdb)?;
wdb.conn.0.commit()?;
Ok(result)
} }
} }
impl<P: consensus::Parameters> WalletRead for WalletDb<P> { impl<C: Borrow<rusqlite::Connection>, P: consensus::Parameters> WalletRead for WalletDb<C, P> {
type Error = SqliteClientError; type Error = SqliteClientError;
type NoteRef = NoteId; type NoteRef = NoteId;
type TxRef = i64; type TxRef = i64;
fn block_height_extrema(&self) -> Result<Option<(BlockHeight, BlockHeight)>, Self::Error> { fn block_height_extrema(&self) -> Result<Option<(BlockHeight, BlockHeight)>, Self::Error> {
wallet::block_height_extrema(self).map_err(SqliteClientError::from) wallet::block_height_extrema(self.conn.borrow()).map_err(SqliteClientError::from)
} }
fn get_min_unspent_height(&self) -> Result<Option<BlockHeight>, Self::Error> { fn get_min_unspent_height(&self) -> Result<Option<BlockHeight>, Self::Error> {
wallet::get_min_unspent_height(self).map_err(SqliteClientError::from) wallet::get_min_unspent_height(self.conn.borrow()).map_err(SqliteClientError::from)
} }
fn get_block_hash(&self, block_height: BlockHeight) -> Result<Option<BlockHash>, Self::Error> { fn get_block_hash(&self, block_height: BlockHeight) -> Result<Option<BlockHash>, Self::Error> {
wallet::get_block_hash(self, block_height).map_err(SqliteClientError::from) wallet::get_block_hash(self.conn.borrow(), block_height).map_err(SqliteClientError::from)
} }
fn get_tx_height(&self, txid: TxId) -> Result<Option<BlockHeight>, Self::Error> { fn get_tx_height(&self, txid: TxId) -> Result<Option<BlockHeight>, Self::Error> {
wallet::get_tx_height(self, txid).map_err(SqliteClientError::from) wallet::get_tx_height(self.conn.borrow(), txid).map_err(SqliteClientError::from)
} }
fn get_unified_full_viewing_keys( fn get_unified_full_viewing_keys(
&self, &self,
) -> Result<HashMap<AccountId, UnifiedFullViewingKey>, Self::Error> { ) -> Result<HashMap<AccountId, UnifiedFullViewingKey>, Self::Error> {
wallet::get_unified_full_viewing_keys(self) wallet::get_unified_full_viewing_keys(self.conn.borrow(), &self.params)
} }
fn get_account_for_ufvk( fn get_account_for_ufvk(
&self, &self,
ufvk: &UnifiedFullViewingKey, ufvk: &UnifiedFullViewingKey,
) -> Result<Option<AccountId>, Self::Error> { ) -> Result<Option<AccountId>, Self::Error> {
wallet::get_account_for_ufvk(self, ufvk) wallet::get_account_for_ufvk(self.conn.borrow(), &self.params, ufvk)
} }
fn get_current_address( fn get_current_address(
&self, &self,
account: AccountId, account: AccountId,
) -> Result<Option<UnifiedAddress>, Self::Error> { ) -> Result<Option<UnifiedAddress>, Self::Error> {
wallet::get_current_address(self, account).map(|res| res.map(|(addr, _)| addr)) wallet::get_current_address(self.conn.borrow(), &self.params, account)
.map(|res| res.map(|(addr, _)| addr))
} }
fn is_valid_account_extfvk( fn is_valid_account_extfvk(
@ -175,7 +186,7 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
account: AccountId, account: AccountId,
extfvk: &ExtendedFullViewingKey, extfvk: &ExtendedFullViewingKey,
) -> Result<bool, Self::Error> { ) -> Result<bool, Self::Error> {
wallet::is_valid_account_extfvk(self, account, extfvk) wallet::is_valid_account_extfvk(self.conn.borrow(), &self.params, account, extfvk)
} }
fn get_balance_at( fn get_balance_at(
@ -183,17 +194,19 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
account: AccountId, account: AccountId,
anchor_height: BlockHeight, anchor_height: BlockHeight,
) -> Result<Amount, Self::Error> { ) -> Result<Amount, Self::Error> {
wallet::get_balance_at(self, account, anchor_height) wallet::get_balance_at(self.conn.borrow(), account, anchor_height)
} }
fn get_transaction(&self, id_tx: i64) -> Result<Transaction, Self::Error> { fn get_transaction(&self, id_tx: i64) -> Result<Transaction, Self::Error> {
wallet::get_transaction(self, id_tx) wallet::get_transaction(self.conn.borrow(), &self.params, id_tx)
} }
fn get_memo(&self, id_note: Self::NoteRef) -> Result<Option<Memo>, Self::Error> { fn get_memo(&self, id_note: Self::NoteRef) -> Result<Option<Memo>, Self::Error> {
match id_note { match id_note {
NoteId::SentNoteId(id_note) => wallet::get_sent_memo(self, id_note), NoteId::SentNoteId(id_note) => wallet::get_sent_memo(self.conn.borrow(), id_note),
NoteId::ReceivedNoteId(id_note) => wallet::get_received_memo(self, id_note), NoteId::ReceivedNoteId(id_note) => {
wallet::get_received_memo(self.conn.borrow(), id_note)
}
} }
} }
@ -201,7 +214,7 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
&self, &self,
block_height: BlockHeight, block_height: BlockHeight,
) -> Result<Option<sapling::CommitmentTree>, Self::Error> { ) -> Result<Option<sapling::CommitmentTree>, Self::Error> {
wallet::sapling::get_sapling_commitment_tree(self, block_height) wallet::sapling::get_sapling_commitment_tree(self.conn.borrow(), block_height)
} }
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
@ -209,7 +222,7 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
&self, &self,
block_height: BlockHeight, block_height: BlockHeight,
) -> Result<Vec<(Self::NoteRef, sapling::IncrementalWitness)>, Self::Error> { ) -> Result<Vec<(Self::NoteRef, sapling::IncrementalWitness)>, Self::Error> {
wallet::sapling::get_sapling_witnesses(self, block_height) wallet::sapling::get_sapling_witnesses(self.conn.borrow(), block_height)
} }
fn get_sapling_nullifiers( fn get_sapling_nullifiers(
@ -217,8 +230,8 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
query: data_api::NullifierQuery, query: data_api::NullifierQuery,
) -> Result<Vec<(AccountId, sapling::Nullifier)>, Self::Error> { ) -> Result<Vec<(AccountId, sapling::Nullifier)>, Self::Error> {
match query { match query {
NullifierQuery::Unspent => wallet::sapling::get_sapling_nullifiers(self), NullifierQuery::Unspent => wallet::sapling::get_sapling_nullifiers(self.conn.borrow()),
NullifierQuery::All => wallet::sapling::get_all_sapling_nullifiers(self), NullifierQuery::All => wallet::sapling::get_all_sapling_nullifiers(self.conn.borrow()),
} }
} }
@ -228,7 +241,12 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
anchor_height: BlockHeight, anchor_height: BlockHeight,
exclude: &[Self::NoteRef], exclude: &[Self::NoteRef],
) -> Result<Vec<ReceivedSaplingNote<Self::NoteRef>>, Self::Error> { ) -> Result<Vec<ReceivedSaplingNote<Self::NoteRef>>, Self::Error> {
wallet::sapling::get_spendable_sapling_notes(self, account, anchor_height, exclude) wallet::sapling::get_spendable_sapling_notes(
self.conn.borrow(),
account,
anchor_height,
exclude,
)
} }
fn select_spendable_sapling_notes( fn select_spendable_sapling_notes(
@ -239,7 +257,7 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
exclude: &[Self::NoteRef], exclude: &[Self::NoteRef],
) -> Result<Vec<ReceivedSaplingNote<Self::NoteRef>>, Self::Error> { ) -> Result<Vec<ReceivedSaplingNote<Self::NoteRef>>, Self::Error> {
wallet::sapling::select_spendable_sapling_notes( wallet::sapling::select_spendable_sapling_notes(
self, self.conn.borrow(),
account, account,
target_value, target_value,
anchor_height, anchor_height,
@ -252,7 +270,7 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
_account: AccountId, _account: AccountId,
) -> Result<HashMap<TransparentAddress, AddressMetadata>, Self::Error> { ) -> Result<HashMap<TransparentAddress, AddressMetadata>, Self::Error> {
#[cfg(feature = "transparent-inputs")] #[cfg(feature = "transparent-inputs")]
return wallet::get_transparent_receivers(&self.params, &self.conn, _account); return wallet::get_transparent_receivers(self.conn.borrow(), &self.params, _account);
#[cfg(not(feature = "transparent-inputs"))] #[cfg(not(feature = "transparent-inputs"))]
panic!( panic!(
@ -267,7 +285,13 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
_exclude: &[OutPoint], _exclude: &[OutPoint],
) -> Result<Vec<WalletTransparentOutput>, Self::Error> { ) -> Result<Vec<WalletTransparentOutput>, Self::Error> {
#[cfg(feature = "transparent-inputs")] #[cfg(feature = "transparent-inputs")]
return wallet::get_unspent_transparent_outputs(self, _address, _max_height, _exclude); return wallet::get_unspent_transparent_outputs(
self.conn.borrow(),
&self.params,
_address,
_max_height,
_exclude,
);
#[cfg(not(feature = "transparent-inputs"))] #[cfg(not(feature = "transparent-inputs"))]
panic!( panic!(
@ -281,7 +305,12 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
_max_height: BlockHeight, _max_height: BlockHeight,
) -> Result<HashMap<TransparentAddress, Amount>, Self::Error> { ) -> Result<HashMap<TransparentAddress, Amount>, Self::Error> {
#[cfg(feature = "transparent-inputs")] #[cfg(feature = "transparent-inputs")]
return wallet::get_transparent_balances(self, _account, _max_height); return wallet::get_transparent_balances(
self.conn.borrow(),
&self.params,
_account,
_max_height,
);
#[cfg(not(feature = "transparent-inputs"))] #[cfg(not(feature = "transparent-inputs"))]
panic!( panic!(
@ -290,177 +319,15 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
} }
} }
impl<'a, P: consensus::Parameters> WalletRead for DataConnStmtCache<'a, P> { impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P> {
type Error = SqliteClientError;
type NoteRef = NoteId;
type TxRef = i64;
fn block_height_extrema(&self) -> Result<Option<(BlockHeight, BlockHeight)>, Self::Error> {
self.wallet_db.block_height_extrema()
}
fn get_min_unspent_height(&self) -> Result<Option<BlockHeight>, Self::Error> {
self.wallet_db.get_min_unspent_height()
}
fn get_block_hash(&self, block_height: BlockHeight) -> Result<Option<BlockHash>, Self::Error> {
self.wallet_db.get_block_hash(block_height)
}
fn get_tx_height(&self, txid: TxId) -> Result<Option<BlockHeight>, Self::Error> {
self.wallet_db.get_tx_height(txid)
}
fn get_unified_full_viewing_keys(
&self,
) -> Result<HashMap<AccountId, UnifiedFullViewingKey>, Self::Error> {
self.wallet_db.get_unified_full_viewing_keys()
}
fn get_account_for_ufvk(
&self,
ufvk: &UnifiedFullViewingKey,
) -> Result<Option<AccountId>, Self::Error> {
self.wallet_db.get_account_for_ufvk(ufvk)
}
fn get_current_address(
&self,
account: AccountId,
) -> Result<Option<UnifiedAddress>, Self::Error> {
self.wallet_db.get_current_address(account)
}
fn is_valid_account_extfvk(
&self,
account: AccountId,
extfvk: &ExtendedFullViewingKey,
) -> Result<bool, Self::Error> {
self.wallet_db.is_valid_account_extfvk(account, extfvk)
}
fn get_balance_at(
&self,
account: AccountId,
anchor_height: BlockHeight,
) -> Result<Amount, Self::Error> {
self.wallet_db.get_balance_at(account, anchor_height)
}
fn get_transaction(&self, id_tx: i64) -> Result<Transaction, Self::Error> {
self.wallet_db.get_transaction(id_tx)
}
fn get_memo(&self, id_note: Self::NoteRef) -> Result<Option<Memo>, Self::Error> {
self.wallet_db.get_memo(id_note)
}
fn get_commitment_tree(
&self,
block_height: BlockHeight,
) -> Result<Option<sapling::CommitmentTree>, Self::Error> {
self.wallet_db.get_commitment_tree(block_height)
}
#[allow(clippy::type_complexity)]
fn get_witnesses(
&self,
block_height: BlockHeight,
) -> Result<Vec<(Self::NoteRef, sapling::IncrementalWitness)>, Self::Error> {
self.wallet_db.get_witnesses(block_height)
}
fn get_sapling_nullifiers(
&self,
query: data_api::NullifierQuery,
) -> Result<Vec<(AccountId, sapling::Nullifier)>, Self::Error> {
self.wallet_db.get_sapling_nullifiers(query)
}
fn get_spendable_sapling_notes(
&self,
account: AccountId,
anchor_height: BlockHeight,
exclude: &[Self::NoteRef],
) -> Result<Vec<ReceivedSaplingNote<Self::NoteRef>>, Self::Error> {
self.wallet_db
.get_spendable_sapling_notes(account, anchor_height, exclude)
}
fn select_spendable_sapling_notes(
&self,
account: AccountId,
target_value: Amount,
anchor_height: BlockHeight,
exclude: &[Self::NoteRef],
) -> Result<Vec<ReceivedSaplingNote<Self::NoteRef>>, Self::Error> {
self.wallet_db
.select_spendable_sapling_notes(account, target_value, anchor_height, exclude)
}
fn get_transparent_receivers(
&self,
account: AccountId,
) -> Result<HashMap<TransparentAddress, AddressMetadata>, Self::Error> {
self.wallet_db.get_transparent_receivers(account)
}
fn get_unspent_transparent_outputs(
&self,
address: &TransparentAddress,
max_height: BlockHeight,
exclude: &[OutPoint],
) -> Result<Vec<WalletTransparentOutput>, Self::Error> {
self.wallet_db
.get_unspent_transparent_outputs(address, max_height, exclude)
}
fn get_transparent_balances(
&self,
account: AccountId,
max_height: BlockHeight,
) -> Result<HashMap<TransparentAddress, Amount>, Self::Error> {
self.wallet_db.get_transparent_balances(account, max_height)
}
}
impl<'a, P: consensus::Parameters> DataConnStmtCache<'a, P> {
fn transactionally<F, A>(&mut self, f: F) -> Result<A, SqliteClientError>
where
F: FnOnce(&mut Self) -> Result<A, SqliteClientError>,
{
self.wallet_db.conn.execute("BEGIN IMMEDIATE", [])?;
match f(self) {
Ok(result) => {
self.wallet_db.conn.execute("COMMIT", [])?;
Ok(result)
}
Err(error) => {
match self.wallet_db.conn.execute("ROLLBACK", []) {
Ok(_) => Err(error),
Err(e) =>
// Panicking here is probably the right thing to do, because it
// means the database is corrupt.
panic!(
"Rollback failed with error {} while attempting to recover from error {}; database is likely corrupt.",
e,
error
)
}
}
}
}
}
impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
type UtxoRef = UtxoId; type UtxoRef = UtxoId;
fn create_account( fn create_account(
&mut self, &mut self,
seed: &SecretVec<u8>, seed: &SecretVec<u8>,
) -> Result<(AccountId, UnifiedSpendingKey), Self::Error> { ) -> Result<(AccountId, UnifiedSpendingKey), Self::Error> {
self.transactionally(|stmts| { self.transactionally(|wdb| {
let account = wallet::get_max_account_id(stmts.wallet_db)? let account = wallet::get_max_account_id(&wdb.conn.0)?
.map(|a| AccountId::from(u32::from(a) + 1)) .map(|a| AccountId::from(u32::from(a) + 1))
.unwrap_or_else(|| AccountId::from(0)); .unwrap_or_else(|| AccountId::from(0));
@ -468,15 +335,11 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
return Err(SqliteClientError::AccountIdOutOfRange); return Err(SqliteClientError::AccountIdOutOfRange);
} }
let usk = UnifiedSpendingKey::from_seed( let usk = UnifiedSpendingKey::from_seed(&wdb.params, seed.expose_secret(), account)
&stmts.wallet_db.params, .map_err(|_| SqliteClientError::KeyDerivationError(account))?;
seed.expose_secret(),
account,
)
.map_err(|_| SqliteClientError::KeyDerivationError(account))?;
let ufvk = usk.to_unified_full_viewing_key(); let ufvk = usk.to_unified_full_viewing_key();
wallet::add_account(stmts.wallet_db, account, &ufvk)?; wallet::add_account(&wdb.conn.0, &wdb.params, account, &ufvk)?;
Ok((account, usk)) Ok((account, usk))
}) })
@ -486,28 +349,37 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
&mut self, &mut self,
account: AccountId, account: AccountId,
) -> Result<Option<UnifiedAddress>, Self::Error> { ) -> Result<Option<UnifiedAddress>, Self::Error> {
match self.get_unified_full_viewing_keys()?.get(&account) { self.transactionally(
Some(ufvk) => { |wdb| match wdb.get_unified_full_viewing_keys()?.get(&account) {
let search_from = match wallet::get_current_address(self.wallet_db, account)? { Some(ufvk) => {
Some((_, mut last_diversifier_index)) => { let search_from =
last_diversifier_index match wallet::get_current_address(&wdb.conn.0, &wdb.params, account)? {
.increment() Some((_, mut last_diversifier_index)) => {
.map_err(|_| SqliteClientError::DiversifierIndexOutOfRange)?; last_diversifier_index
last_diversifier_index .increment()
} .map_err(|_| SqliteClientError::DiversifierIndexOutOfRange)?;
None => DiversifierIndex::default(), last_diversifier_index
}; }
None => DiversifierIndex::default(),
};
let (addr, diversifier_index) = ufvk let (addr, diversifier_index) = ufvk
.find_address(search_from) .find_address(search_from)
.ok_or(SqliteClientError::DiversifierIndexOutOfRange)?; .ok_or(SqliteClientError::DiversifierIndexOutOfRange)?;
self.stmt_insert_address(account, diversifier_index, &addr)?; wallet::insert_address(
&wdb.conn.0,
&wdb.params,
account,
diversifier_index,
&addr,
)?;
Ok(Some(addr)) Ok(Some(addr))
} }
None => Ok(None), None => Ok(None),
} },
)
} }
#[tracing::instrument(skip_all, fields(height = u32::from(block.block_height)))] #[tracing::instrument(skip_all, fields(height = u32::from(block.block_height)))]
@ -517,11 +389,10 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
block: &PrunedBlock, block: &PrunedBlock,
updated_witnesses: &[(Self::NoteRef, sapling::IncrementalWitness)], updated_witnesses: &[(Self::NoteRef, sapling::IncrementalWitness)],
) -> Result<Vec<(Self::NoteRef, sapling::IncrementalWitness)>, Self::Error> { ) -> Result<Vec<(Self::NoteRef, sapling::IncrementalWitness)>, Self::Error> {
// database updates for each block are transactional self.transactionally(|wdb| {
self.transactionally(|up| {
// Insert the block into the database. // Insert the block into the database.
wallet::insert_block( wallet::insert_block(
up, &wdb.conn.0,
block.block_height, block.block_height,
block.block_hash, block.block_hash,
block.block_time, block.block_time,
@ -530,15 +401,16 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
let mut new_witnesses = vec![]; let mut new_witnesses = vec![];
for tx in block.transactions { for tx in block.transactions {
let tx_row = wallet::put_tx_meta(up, tx, block.block_height)?; let tx_row = wallet::put_tx_meta(&wdb.conn.0, tx, block.block_height)?;
// Mark notes as spent and remove them from the scanning cache // Mark notes as spent and remove them from the scanning cache
for spend in &tx.sapling_spends { for spend in &tx.sapling_spends {
wallet::sapling::mark_sapling_note_spent(up, tx_row, spend.nf())?; wallet::sapling::mark_sapling_note_spent(&wdb.conn.0, tx_row, spend.nf())?;
} }
for output in &tx.sapling_outputs { for output in &tx.sapling_outputs {
let received_note_id = wallet::sapling::put_received_note(up, output, tx_row)?; let received_note_id =
wallet::sapling::put_received_note(&wdb.conn.0, output, tx_row)?;
// Save witness for note. // Save witness for note.
new_witnesses.push((received_note_id, output.witness().clone())); new_witnesses.push((received_note_id, output.witness().clone()));
@ -549,17 +421,22 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
for (received_note_id, witness) in updated_witnesses.iter().chain(new_witnesses.iter()) for (received_note_id, witness) in updated_witnesses.iter().chain(new_witnesses.iter())
{ {
if let NoteId::ReceivedNoteId(rnid) = *received_note_id { if let NoteId::ReceivedNoteId(rnid) = *received_note_id {
wallet::sapling::insert_witness(up, rnid, witness, block.block_height)?; wallet::sapling::insert_witness(
&wdb.conn.0,
rnid,
witness,
block.block_height,
)?;
} else { } else {
return Err(SqliteClientError::InvalidNoteId); return Err(SqliteClientError::InvalidNoteId);
} }
} }
// Prune the stored witnesses (we only expect rollbacks of at most PRUNING_HEIGHT blocks). // Prune the stored witnesses (we only expect rollbacks of at most PRUNING_HEIGHT blocks).
wallet::prune_witnesses(up, block.block_height - PRUNING_HEIGHT)?; wallet::prune_witnesses(&wdb.conn.0, block.block_height - PRUNING_HEIGHT)?;
// Update now-expired transactions that didn't get mined. // Update now-expired transactions that didn't get mined.
wallet::update_expired_notes(up, block.block_height)?; wallet::update_expired_notes(&wdb.conn.0, block.block_height)?;
Ok(new_witnesses) Ok(new_witnesses)
}) })
@ -569,91 +446,114 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
&mut self, &mut self,
d_tx: DecryptedTransaction, d_tx: DecryptedTransaction,
) -> Result<Self::TxRef, Self::Error> { ) -> Result<Self::TxRef, Self::Error> {
self.transactionally(|up| { self.transactionally(|wdb| {
let tx_ref = wallet::put_tx_data(up, d_tx.tx, None, None)?; let tx_ref = wallet::put_tx_data(&wdb.conn.0, d_tx.tx, None, None)?;
let mut spending_account_id: Option<AccountId> = None; let mut spending_account_id: Option<AccountId> = None;
for output in d_tx.sapling_outputs { for output in d_tx.sapling_outputs {
match output.transfer_type { match output.transfer_type {
TransferType::Outgoing | TransferType::WalletInternal => { TransferType::Outgoing | TransferType::WalletInternal => {
let recipient = if output.transfer_type == TransferType::Outgoing { let recipient = if output.transfer_type == TransferType::Outgoing {
Recipient::Sapling(output.note.recipient()) Recipient::Sapling(output.note.recipient())
} else { } else {
Recipient::InternalAccount(output.account, PoolType::Sapling) Recipient::InternalAccount(output.account, PoolType::Sapling)
}; };
wallet::put_sent_output( wallet::put_sent_output(
up, &wdb.conn.0,
output.account, &wdb.params,
tx_ref, output.account,
output.index, tx_ref,
&recipient, output.index,
Amount::from_u64(output.note.value().inner()).map_err(|_| &recipient,
SqliteClientError::CorruptedData("Note value is not a valid Zcash amount.".to_string()))?, Amount::from_u64(output.note.value().inner()).map_err(|_| {
Some(&output.memo), SqliteClientError::CorruptedData(
)?; "Note value is not a valid Zcash amount.".to_string(),
)
})?,
Some(&output.memo),
)?;
if matches!(recipient, Recipient::InternalAccount(_, _)) { if matches!(recipient, Recipient::InternalAccount(_, _)) {
wallet::sapling::put_received_note(up, output, tx_ref)?; wallet::sapling::put_received_note(&wdb.conn.0, output, tx_ref)?;
}
} }
TransferType::Incoming => { }
match spending_account_id { TransferType::Incoming => {
Some(id) => match spending_account_id {
if id != output.account { Some(id) => {
panic!("Unable to determine a unique account identifier for z->t spend."); if id != output.account {
} panic!("Unable to determine a unique account identifier for z->t spend.");
None => {
spending_account_id = Some(output.account);
} }
} }
None => {
wallet::sapling::put_received_note(up, output, tx_ref)?; spending_account_id = Some(output.account);
}
}
}
// If any of the utxos spent in the transaction are ours, mark them as spent.
#[cfg(feature = "transparent-inputs")]
for txin in d_tx.tx.transparent_bundle().iter().flat_map(|b| b.vin.iter()) {
wallet::mark_transparent_utxo_spent(up, tx_ref, &txin.prevout)?;
}
// If we have some transparent outputs:
if !d_tx.tx.transparent_bundle().iter().any(|b| b.vout.is_empty()) {
let nullifiers = self.wallet_db.get_sapling_nullifiers(data_api::NullifierQuery::All)?;
// If the transaction contains shielded spends from our wallet, we will store z->t
// transactions we observe in the same way they would be stored by
// create_spend_to_address.
if let Some((account_id, _)) = nullifiers.iter().find(
|(_, nf)|
d_tx.tx.sapling_bundle().iter().flat_map(|b| b.shielded_spends().iter())
.any(|input| nf == input.nullifier())
) {
for (output_index, txout) in d_tx.tx.transparent_bundle().iter().flat_map(|b| b.vout.iter()).enumerate() {
if let Some(address) = txout.recipient_address() {
wallet::put_sent_output(
up,
*account_id,
tx_ref,
output_index,
&Recipient::Transparent(address),
txout.value,
None
)?;
} }
} }
wallet::sapling::put_received_note(&wdb.conn.0, output, tx_ref)?;
} }
} }
Ok(tx_ref) }
// If any of the utxos spent in the transaction are ours, mark them as spent.
#[cfg(feature = "transparent-inputs")]
for txin in d_tx
.tx
.transparent_bundle()
.iter()
.flat_map(|b| b.vin.iter())
{
wallet::mark_transparent_utxo_spent(&wdb.conn.0, tx_ref, &txin.prevout)?;
}
// If we have some transparent outputs:
if !d_tx
.tx
.transparent_bundle()
.iter()
.any(|b| b.vout.is_empty())
{
let nullifiers = wdb.get_sapling_nullifiers(data_api::NullifierQuery::All)?;
// If the transaction contains shielded spends from our wallet, we will store z->t
// transactions we observe in the same way they would be stored by
// create_spend_to_address.
if let Some((account_id, _)) = nullifiers.iter().find(|(_, nf)| {
d_tx.tx
.sapling_bundle()
.iter()
.flat_map(|b| b.shielded_spends().iter())
.any(|input| nf == input.nullifier())
}) {
for (output_index, txout) in d_tx
.tx
.transparent_bundle()
.iter()
.flat_map(|b| b.vout.iter())
.enumerate()
{
if let Some(address) = txout.recipient_address() {
wallet::put_sent_output(
&wdb.conn.0,
&wdb.params,
*account_id,
tx_ref,
output_index,
&Recipient::Transparent(address),
txout.value,
None,
)?;
}
}
}
}
Ok(tx_ref)
}) })
} }
fn store_sent_tx(&mut self, sent_tx: &SentTransaction) -> Result<Self::TxRef, Self::Error> { fn store_sent_tx(&mut self, sent_tx: &SentTransaction) -> Result<Self::TxRef, Self::Error> {
// Update the database atomically, to ensure the result is internally consistent. self.transactionally(|wdb| {
self.transactionally(|up| {
let tx_ref = wallet::put_tx_data( let tx_ref = wallet::put_tx_data(
up, &wdb.conn.0,
sent_tx.tx, sent_tx.tx,
Some(sent_tx.fee_amount), Some(sent_tx.fee_amount),
Some(sent_tx.created), Some(sent_tx.created),
@ -669,21 +569,31 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
// reasonable assumption for a light client such as a mobile phone. // reasonable assumption for a light client such as a mobile phone.
if let Some(bundle) = sent_tx.tx.sapling_bundle() { if let Some(bundle) = sent_tx.tx.sapling_bundle() {
for spend in bundle.shielded_spends() { for spend in bundle.shielded_spends() {
wallet::sapling::mark_sapling_note_spent(up, tx_ref, spend.nullifier())?; wallet::sapling::mark_sapling_note_spent(
&wdb.conn.0,
tx_ref,
spend.nullifier(),
)?;
} }
} }
#[cfg(feature = "transparent-inputs")] #[cfg(feature = "transparent-inputs")]
for utxo_outpoint in &sent_tx.utxos_spent { for utxo_outpoint in &sent_tx.utxos_spent {
wallet::mark_transparent_utxo_spent(up, tx_ref, utxo_outpoint)?; wallet::mark_transparent_utxo_spent(&wdb.conn.0, tx_ref, utxo_outpoint)?;
} }
for output in &sent_tx.outputs { for output in &sent_tx.outputs {
wallet::insert_sent_output(up, tx_ref, sent_tx.account, output)?; wallet::insert_sent_output(
&wdb.conn.0,
&wdb.params,
tx_ref,
sent_tx.account,
output,
)?;
if let Some((account, note)) = output.sapling_change_to() { if let Some((account, note)) = output.sapling_change_to() {
wallet::sapling::put_received_note( wallet::sapling::put_received_note(
up, &wdb.conn.0,
&DecryptedOutput { &DecryptedOutput {
index: output.output_index(), index: output.output_index(),
note: note.clone(), note: note.clone(),
@ -704,7 +614,9 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
} }
fn truncate_to_height(&mut self, block_height: BlockHeight) -> Result<(), Self::Error> { fn truncate_to_height(&mut self, block_height: BlockHeight) -> Result<(), Self::Error> {
wallet::truncate_to_height(self.wallet_db, block_height) self.transactionally(|wdb| {
wallet::truncate_to_height(&wdb.conn.0, &wdb.params, block_height)
})
} }
fn put_received_transparent_utxo( fn put_received_transparent_utxo(
@ -712,7 +624,7 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
_output: &WalletTransparentOutput, _output: &WalletTransparentOutput,
) -> Result<Self::UtxoRef, Self::Error> { ) -> Result<Self::UtxoRef, Self::Error> {
#[cfg(feature = "transparent-inputs")] #[cfg(feature = "transparent-inputs")]
return wallet::put_received_transparent_utxo(self, _output); return wallet::put_received_transparent_utxo(&self.conn, &self.params, _output);
#[cfg(not(feature = "transparent-inputs"))] #[cfg(not(feature = "transparent-inputs"))]
panic!( panic!(
@ -1056,7 +968,7 @@ mod tests {
#[cfg(test)] #[cfg(test)]
pub(crate) fn init_test_accounts_table( pub(crate) fn init_test_accounts_table(
db_data: &WalletDb<Network>, db_data: &mut WalletDb<rusqlite::Connection, Network>,
) -> (DiversifiableFullViewingKey, Option<TransparentAddress>) { ) -> (DiversifiableFullViewingKey, Option<TransparentAddress>) {
let (ufvk, taddr) = init_test_accounts_table_ufvk(db_data); let (ufvk, taddr) = init_test_accounts_table_ufvk(db_data);
(ufvk.sapling().unwrap().clone(), taddr) (ufvk.sapling().unwrap().clone(), taddr)
@ -1064,7 +976,7 @@ mod tests {
#[cfg(test)] #[cfg(test)]
pub(crate) fn init_test_accounts_table_ufvk( pub(crate) fn init_test_accounts_table_ufvk(
db_data: &WalletDb<Network>, db_data: &mut WalletDb<rusqlite::Connection, Network>,
) -> (UnifiedFullViewingKey, Option<TransparentAddress>) { ) -> (UnifiedFullViewingKey, Option<TransparentAddress>) {
let seed = [0u8; 32]; let seed = [0u8; 32];
let account = AccountId::from(0); let account = AccountId::from(0);
@ -1291,13 +1203,12 @@ mod tests {
let account = AccountId::from(0); let account = AccountId::from(0);
init_wallet_db(&mut db_data, None).unwrap(); init_wallet_db(&mut db_data, None).unwrap();
let _ = init_test_accounts_table_ufvk(&db_data); init_test_accounts_table_ufvk(&mut db_data);
let current_addr = db_data.get_current_address(account).unwrap(); let current_addr = db_data.get_current_address(account).unwrap();
assert!(current_addr.is_some()); assert!(current_addr.is_some());
let mut update_ops = db_data.get_update_ops().unwrap(); let addr2 = db_data.get_next_available_address(account).unwrap();
let addr2 = update_ops.get_next_available_address(account).unwrap();
assert!(addr2.is_some()); assert!(addr2.is_some());
assert_ne!(current_addr, addr2); assert_ne!(current_addr, addr2);
@ -1322,7 +1233,7 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet. // Add an account to the wallet.
let (ufvk, taddr) = init_test_accounts_table_ufvk(&db_data); let (ufvk, taddr) = init_test_accounts_table_ufvk(&mut db_data);
let taddr = taddr.unwrap(); let taddr = taddr.unwrap();
let receivers = db_data.get_transparent_receivers(0.into()).unwrap(); let receivers = db_data.get_transparent_receivers(0.into()).unwrap();

View File

@ -1,812 +0,0 @@
//! Prepared SQL statements used by the wallet.
//!
//! Some `rusqlite` crate APIs are only available on prepared statements; these are stored
//! inside the [`DataConnStmtCache`]. When adding a new prepared statement:
//!
//! - Add it as a private field of `DataConnStmtCache`.
//! - Build the statement in [`DataConnStmtCache::new`].
//! - Add a crate-private helper method to `DataConnStmtCache` for running the statement.
use rusqlite::{named_params, params, Statement, ToSql};
use zcash_primitives::{
block::BlockHash,
consensus::{self, BlockHeight},
memo::MemoBytes,
merkle_tree::{write_commitment_tree, write_incremental_witness},
sapling::{self, Diversifier, Nullifier},
transaction::{components::Amount, TxId},
zip32::{AccountId, DiversifierIndex},
};
use zcash_client_backend::{
address::UnifiedAddress,
data_api::{PoolType, Recipient},
encoding::AddressCodec,
};
use crate::{error::SqliteClientError, wallet::pool_code, NoteId, WalletDb};
#[cfg(feature = "transparent-inputs")]
use {
crate::UtxoId, rusqlite::OptionalExtension,
zcash_client_backend::wallet::WalletTransparentOutput,
zcash_primitives::transaction::components::transparent::OutPoint,
};
pub(crate) struct InsertAddress<'a> {
stmt: Statement<'a>,
}
impl<'a> InsertAddress<'a> {
pub(crate) fn new(conn: &'a rusqlite::Connection) -> Result<Self, rusqlite::Error> {
Ok(InsertAddress {
stmt: conn.prepare(
"INSERT INTO addresses (
account,
diversifier_index_be,
address,
cached_transparent_receiver_address
)
VALUES (
:account,
:diversifier_index_be,
:address,
:cached_transparent_receiver_address
)",
)?,
})
}
/// Adds the given address and diversifier index to the addresses table.
///
/// Returns the database row for the newly-inserted address.
pub(crate) fn execute<P: consensus::Parameters>(
&mut self,
params: &P,
account: AccountId,
mut diversifier_index: DiversifierIndex,
address: &UnifiedAddress,
) -> Result<(), rusqlite::Error> {
// the diversifier index is stored in big-endian order to allow sorting
diversifier_index.0.reverse();
self.stmt.execute(named_params![
":account": &u32::from(account),
":diversifier_index_be": &&diversifier_index.0[..],
":address": &address.encode(params),
":cached_transparent_receiver_address": &address.transparent().map(|r| r.encode(params)),
])?;
Ok(())
}
}
/// The primary type used to implement [`WalletWrite`] for the SQLite database.
///
/// A data structure that stores the SQLite prepared statements that are
/// required for the implementation of [`WalletWrite`] against the backing
/// store.
///
/// [`WalletWrite`]: zcash_client_backend::data_api::WalletWrite
pub struct DataConnStmtCache<'a, P> {
pub(crate) wallet_db: &'a WalletDb<P>,
stmt_insert_block: Statement<'a>,
stmt_insert_tx_meta: Statement<'a>,
stmt_update_tx_meta: Statement<'a>,
stmt_insert_tx_data: Statement<'a>,
stmt_update_tx_data: Statement<'a>,
stmt_select_tx_ref: Statement<'a>,
stmt_mark_sapling_note_spent: Statement<'a>,
#[cfg(feature = "transparent-inputs")]
stmt_mark_transparent_utxo_spent: Statement<'a>,
#[cfg(feature = "transparent-inputs")]
stmt_insert_received_transparent_utxo: Statement<'a>,
#[cfg(feature = "transparent-inputs")]
stmt_update_received_transparent_utxo: Statement<'a>,
#[cfg(feature = "transparent-inputs")]
stmt_insert_legacy_transparent_utxo: Statement<'a>,
#[cfg(feature = "transparent-inputs")]
stmt_update_legacy_transparent_utxo: Statement<'a>,
stmt_insert_received_note: Statement<'a>,
stmt_update_received_note: Statement<'a>,
stmt_select_received_note: Statement<'a>,
stmt_insert_sent_output: Statement<'a>,
stmt_update_sent_output: Statement<'a>,
stmt_insert_witness: Statement<'a>,
stmt_prune_witnesses: Statement<'a>,
stmt_update_expired: Statement<'a>,
stmt_insert_address: InsertAddress<'a>,
}
impl<'a, P> DataConnStmtCache<'a, P> {
pub(crate) fn new(wallet_db: &'a WalletDb<P>) -> Result<Self, SqliteClientError> {
Ok(
DataConnStmtCache {
wallet_db,
stmt_insert_block: wallet_db.conn.prepare(
"INSERT INTO blocks (height, hash, time, sapling_tree)
VALUES (?, ?, ?, ?)",
)?,
stmt_insert_tx_meta: wallet_db.conn.prepare(
"INSERT INTO transactions (txid, block, tx_index)
VALUES (?, ?, ?)",
)?,
stmt_update_tx_meta: wallet_db.conn.prepare(
"UPDATE transactions
SET block = ?, tx_index = ? WHERE txid = ?",
)?,
stmt_insert_tx_data: wallet_db.conn.prepare(
"INSERT INTO transactions (txid, created, expiry_height, raw, fee)
VALUES (?, ?, ?, ?, ?)",
)?,
stmt_update_tx_data: wallet_db.conn.prepare(
"UPDATE transactions
SET expiry_height = :expiry_height,
raw = :raw,
fee = IFNULL(:fee, fee)
WHERE txid = :txid",
)?,
stmt_select_tx_ref: wallet_db.conn.prepare(
"SELECT id_tx FROM transactions WHERE txid = ?",
)?,
stmt_mark_sapling_note_spent: wallet_db.conn.prepare(
"UPDATE sapling_received_notes SET spent = ? WHERE nf = ?"
)?,
#[cfg(feature = "transparent-inputs")]
stmt_mark_transparent_utxo_spent: wallet_db.conn.prepare(
"UPDATE utxos SET spent_in_tx = :spent_in_tx
WHERE prevout_txid = :prevout_txid
AND prevout_idx = :prevout_idx"
)?,
#[cfg(feature = "transparent-inputs")]
stmt_insert_received_transparent_utxo: wallet_db.conn.prepare(
"INSERT INTO utxos (
received_by_account, address,
prevout_txid, prevout_idx, script,
value_zat, height)
SELECT
addresses.account, :address,
:prevout_txid, :prevout_idx, :script,
:value_zat, :height
FROM addresses
WHERE addresses.cached_transparent_receiver_address = :address
RETURNING id_utxo"
)?,
#[cfg(feature = "transparent-inputs")]
stmt_update_received_transparent_utxo: wallet_db.conn.prepare(
"UPDATE utxos
SET received_by_account = addresses.account,
height = :height,
address = :address,
script = :script,
value_zat = :value_zat
FROM addresses
WHERE prevout_txid = :prevout_txid
AND prevout_idx = :prevout_idx
AND addresses.cached_transparent_receiver_address = :address
RETURNING id_utxo"
)?,
#[cfg(feature = "transparent-inputs")]
stmt_insert_legacy_transparent_utxo: wallet_db.conn.prepare(
"INSERT INTO utxos (
received_by_account, address,
prevout_txid, prevout_idx, script,
value_zat, height)
VALUES
(:received_by_account, :address,
:prevout_txid, :prevout_idx, :script,
:value_zat, :height)
RETURNING id_utxo"
)?,
#[cfg(feature = "transparent-inputs")]
stmt_update_legacy_transparent_utxo: wallet_db.conn.prepare(
"UPDATE utxos
SET received_by_account = :received_by_account,
height = :height,
address = :address,
script = :script,
value_zat = :value_zat
WHERE prevout_txid = :prevout_txid
AND prevout_idx = :prevout_idx
RETURNING id_utxo"
)?,
stmt_insert_received_note: wallet_db.conn.prepare(
"INSERT INTO sapling_received_notes (tx, output_index, account, diversifier, value, rcm, memo, nf, is_change)
VALUES (:tx, :output_index, :account, :diversifier, :value, :rcm, :memo, :nf, :is_change)",
)?,
stmt_update_received_note: wallet_db.conn.prepare(
"UPDATE sapling_received_notes
SET account = :account,
diversifier = :diversifier,
value = :value,
rcm = :rcm,
nf = IFNULL(:nf, nf),
memo = IFNULL(:memo, memo),
is_change = IFNULL(:is_change, is_change)
WHERE tx = :tx AND output_index = :output_index",
)?,
stmt_select_received_note: wallet_db.conn.prepare(
"SELECT id_note FROM sapling_received_notes WHERE tx = ? AND output_index = ?"
)?,
stmt_update_sent_output: wallet_db.conn.prepare(
"UPDATE sent_notes
SET from_account = :from_account,
to_address = :to_address,
to_account = :to_account,
value = :value,
memo = IFNULL(:memo, memo)
WHERE tx = :tx
AND output_pool = :output_pool
AND output_index = :output_index",
)?,
stmt_insert_sent_output: wallet_db.conn.prepare(
"INSERT INTO sent_notes (
tx, output_pool, output_index, from_account,
to_address, to_account, value, memo)
VALUES (
:tx, :output_pool, :output_index, :from_account,
:to_address, :to_account, :value, :memo)"
)?,
stmt_insert_witness: wallet_db.conn.prepare(
"INSERT INTO sapling_witnesses (note, block, witness)
VALUES (?, ?, ?)",
)?,
stmt_prune_witnesses: wallet_db.conn.prepare(
"DELETE FROM sapling_witnesses WHERE block < ?"
)?,
stmt_update_expired: wallet_db.conn.prepare(
"UPDATE sapling_received_notes SET spent = NULL WHERE EXISTS (
SELECT id_tx FROM transactions
WHERE id_tx = sapling_received_notes.spent AND block IS NULL AND expiry_height < ?
)",
)?,
stmt_insert_address: InsertAddress::new(&wallet_db.conn)?
}
)
}
/// Inserts information about a scanned block into the database.
pub fn stmt_insert_block(
&mut self,
block_height: BlockHeight,
block_hash: BlockHash,
block_time: u32,
commitment_tree: &sapling::CommitmentTree,
) -> Result<(), SqliteClientError> {
let mut encoded_tree = Vec::new();
write_commitment_tree(commitment_tree, &mut encoded_tree).unwrap();
self.stmt_insert_block.execute(params![
u32::from(block_height),
&block_hash.0[..],
block_time,
encoded_tree
])?;
Ok(())
}
/// Inserts the given transaction and its block metadata into the wallet.
///
/// Returns the database row for the newly-inserted transaction, or an error if the
/// transaction exists.
pub(crate) fn stmt_insert_tx_meta(
&mut self,
txid: &TxId,
height: BlockHeight,
tx_index: usize,
) -> Result<i64, SqliteClientError> {
self.stmt_insert_tx_meta.execute(params![
&txid.as_ref()[..],
u32::from(height),
(tx_index as i64),
])?;
Ok(self.wallet_db.conn.last_insert_rowid())
}
/// Updates the block metadata for the given transaction.
///
/// Returns `false` if the transaction doesn't exist in the wallet.
pub(crate) fn stmt_update_tx_meta(
&mut self,
height: BlockHeight,
tx_index: usize,
txid: &TxId,
) -> Result<bool, SqliteClientError> {
match self.stmt_update_tx_meta.execute(params![
u32::from(height),
(tx_index as i64),
&txid.as_ref()[..],
])? {
0 => Ok(false),
1 => Ok(true),
_ => unreachable!("txid column is marked as UNIQUE"),
}
}
/// Inserts the given transaction and its data into the wallet.
///
/// Returns the database row for the newly-inserted transaction, or an error if the
/// transaction exists.
pub(crate) fn stmt_insert_tx_data(
&mut self,
txid: &TxId,
created_at: Option<time::OffsetDateTime>,
expiry_height: BlockHeight,
raw_tx: &[u8],
fee: Option<Amount>,
) -> Result<i64, SqliteClientError> {
self.stmt_insert_tx_data.execute(params![
&txid.as_ref()[..],
created_at,
u32::from(expiry_height),
raw_tx,
fee.map(i64::from)
])?;
Ok(self.wallet_db.conn.last_insert_rowid())
}
/// Updates the data for the given transaction.
///
/// Returns `false` if the transaction doesn't exist in the wallet.
pub(crate) fn stmt_update_tx_data(
&mut self,
expiry_height: BlockHeight,
raw_tx: &[u8],
fee: Option<Amount>,
txid: &TxId,
) -> Result<bool, SqliteClientError> {
let sql_args: &[(&str, &dyn ToSql)] = &[
(":expiry_height", &u32::from(expiry_height)),
(":raw", &raw_tx),
(":fee", &fee.map(i64::from)),
(":txid", &&txid.as_ref()[..]),
];
match self.stmt_update_tx_data.execute(sql_args)? {
0 => Ok(false),
1 => Ok(true),
_ => unreachable!("txid column is marked as UNIQUE"),
}
}
/// Finds the database row for the given `txid`, if the transaction is in the wallet.
pub(crate) fn stmt_select_tx_ref(&mut self, txid: &TxId) -> Result<i64, SqliteClientError> {
self.stmt_select_tx_ref
.query_row([&txid.as_ref()[..]], |row| row.get(0))
.map_err(SqliteClientError::from)
}
/// Marks a given nullifier as having been revealed in the construction of the
/// specified transaction.
///
/// Marking a note spent in this fashion does NOT imply that the spending transaction
/// has been mined.
///
/// Returns `false` if the nullifier does not correspond to any received note.
pub(crate) fn stmt_mark_sapling_note_spent(
&mut self,
tx_ref: i64,
nf: &Nullifier,
) -> Result<bool, SqliteClientError> {
match self
.stmt_mark_sapling_note_spent
.execute(params![tx_ref, &nf.0[..]])?
{
0 => Ok(false),
1 => Ok(true),
_ => unreachable!("nf column is marked as UNIQUE"),
}
}
/// Marks the given UTXO as having been spent.
///
/// Returns `false` if `outpoint` does not correspond to any tracked UTXO.
#[cfg(feature = "transparent-inputs")]
pub(crate) fn stmt_mark_transparent_utxo_spent(
&mut self,
tx_ref: i64,
outpoint: &OutPoint,
) -> Result<bool, SqliteClientError> {
let sql_args: &[(&str, &dyn ToSql)] = &[
(":spent_in_tx", &tx_ref),
(":prevout_txid", &outpoint.hash().to_vec()),
(":prevout_idx", &outpoint.n()),
];
match self.stmt_mark_transparent_utxo_spent.execute(sql_args)? {
0 => Ok(false),
1 => Ok(true),
_ => unreachable!("tx_outpoint constraint is marked as UNIQUE"),
}
}
}
impl<'a, P: consensus::Parameters> DataConnStmtCache<'a, P> {
/// Inserts a sent note into the wallet database.
///
/// `output_index` is the index within the transaction that contains the recipient output:
///
/// - If `to` is a Unified address, this is an index into the outputs of the transaction
/// within the bundle associated with the recipient's output pool.
/// - If `to` is a Sapling address, this is an index into the Sapling outputs of the
/// transaction.
/// - If `to` is a transparent address, this is an index into the transparent outputs of
/// the transaction.
/// - If `to` is an internal account, this is an index into the Sapling outputs of the
/// transaction.
#[allow(clippy::too_many_arguments)]
pub(crate) fn stmt_insert_sent_output(
&mut self,
tx_ref: i64,
output_index: usize,
from_account: AccountId,
to: &Recipient,
value: Amount,
memo: Option<&MemoBytes>,
) -> Result<(), SqliteClientError> {
let (to_address, to_account, pool_type) = match to {
Recipient::Transparent(addr) => (
Some(addr.encode(&self.wallet_db.params)),
None,
PoolType::Transparent,
),
Recipient::Sapling(addr) => (
Some(addr.encode(&self.wallet_db.params)),
None,
PoolType::Sapling,
),
Recipient::Unified(addr, pool) => {
(Some(addr.encode(&self.wallet_db.params)), None, *pool)
}
Recipient::InternalAccount(id, pool) => (None, Some(u32::from(*id)), *pool),
};
self.stmt_insert_sent_output.execute(named_params![
":tx": &tx_ref,
":output_pool": &pool_code(pool_type),
":output_index": &i64::try_from(output_index).unwrap(),
":from_account": &u32::from(from_account),
":to_address": &to_address,
":to_account": &to_account,
":value": &i64::from(value),
":memo": &memo.filter(|m| *m != &MemoBytes::empty()).map(|m| m.as_slice()),
])?;
Ok(())
}
/// Updates the data for the given sent note.
///
/// Returns `false` if the transaction doesn't exist in the wallet.
#[allow(clippy::too_many_arguments)]
pub(crate) fn stmt_update_sent_output(
&mut self,
from_account: AccountId,
to: &Recipient,
value: Amount,
memo: Option<&MemoBytes>,
tx_ref: i64,
output_index: usize,
) -> Result<bool, SqliteClientError> {
let (to_address, to_account, pool_type) = match to {
Recipient::Transparent(addr) => (
Some(addr.encode(&self.wallet_db.params)),
None,
PoolType::Transparent,
),
Recipient::Sapling(addr) => (
Some(addr.encode(&self.wallet_db.params)),
None,
PoolType::Sapling,
),
Recipient::Unified(addr, pool) => {
(Some(addr.encode(&self.wallet_db.params)), None, *pool)
}
Recipient::InternalAccount(id, pool) => (None, Some(u32::from(*id)), *pool),
};
match self.stmt_update_sent_output.execute(named_params![
":from_account": &u32::from(from_account),
":to_address": &to_address,
":to_account": &to_account,
":value": &i64::from(value),
":memo": &memo.filter(|m| *m != &MemoBytes::empty()).map(|m| m.as_slice()),
":tx": &tx_ref,
":output_pool": &pool_code(pool_type),
":output_index": &i64::try_from(output_index).unwrap(),
])? {
0 => Ok(false),
1 => Ok(true),
_ => unreachable!("tx_output constraint is marked as UNIQUE"),
}
}
/// Adds the given received UTXO to the datastore.
///
/// Returns the database identifier for the newly-inserted UTXO if the address to which the
/// UTXO was sent corresponds to a cached transparent receiver in the addresses table, or
/// Ok(None) if the address is unknown. Returns an error if the UTXO exists.
#[cfg(feature = "transparent-inputs")]
pub(crate) fn stmt_insert_received_transparent_utxo(
&mut self,
output: &WalletTransparentOutput,
) -> Result<Option<UtxoId>, SqliteClientError> {
self.stmt_insert_received_transparent_utxo
.query_row(
named_params![
":address": &output.recipient_address().encode(&self.wallet_db.params),
":prevout_txid": &output.outpoint().hash().to_vec(),
":prevout_idx": &output.outpoint().n(),
":script": &output.txout().script_pubkey.0,
":value_zat": &i64::from(output.txout().value),
":height": &u32::from(output.height()),
],
|row| {
let id = row.get(0)?;
Ok(UtxoId(id))
},
)
.optional()
.map_err(SqliteClientError::from)
}
/// Adds the given received UTXO to the datastore.
///
/// Returns the database identifier for the updated UTXO if the address to which the UTXO was
/// sent corresponds to a cached transparent receiver in the addresses table, or Ok(None) if
/// the address is unknown.
#[cfg(feature = "transparent-inputs")]
pub(crate) fn stmt_update_received_transparent_utxo(
&mut self,
output: &WalletTransparentOutput,
) -> Result<Option<UtxoId>, SqliteClientError> {
self.stmt_update_received_transparent_utxo
.query_row(
named_params![
":prevout_txid": &output.outpoint().hash().to_vec(),
":prevout_idx": &output.outpoint().n(),
":address": &output.recipient_address().encode(&self.wallet_db.params),
":script": &output.txout().script_pubkey.0,
":value_zat": &i64::from(output.txout().value),
":height": &u32::from(output.height()),
],
|row| {
let id = row.get(0)?;
Ok(UtxoId(id))
},
)
.optional()
.map_err(SqliteClientError::from)
}
/// Adds the given legacy UTXO to the datastore.
///
/// Returns the database row for the newly-inserted UTXO, or an error if the UTXO
/// exists.
#[cfg(feature = "transparent-inputs")]
pub(crate) fn stmt_insert_legacy_transparent_utxo(
&mut self,
output: &WalletTransparentOutput,
received_by_account: AccountId,
) -> Result<UtxoId, SqliteClientError> {
self.stmt_insert_legacy_transparent_utxo
.query_row(
named_params![
":received_by_account": &u32::from(received_by_account),
":address": &output.recipient_address().encode(&self.wallet_db.params),
":prevout_txid": &output.outpoint().hash().to_vec(),
":prevout_idx": &output.outpoint().n(),
":script": &output.txout().script_pubkey.0,
":value_zat": &i64::from(output.txout().value),
":height": &u32::from(output.height()),
],
|row| {
let id = row.get(0)?;
Ok(UtxoId(id))
},
)
.map_err(SqliteClientError::from)
}
/// Adds the given legacy UTXO to the datastore.
///
/// Returns the database row for the newly-inserted UTXO, or an error if the UTXO
/// exists.
#[cfg(feature = "transparent-inputs")]
pub(crate) fn stmt_update_legacy_transparent_utxo(
&mut self,
output: &WalletTransparentOutput,
received_by_account: AccountId,
) -> Result<Option<UtxoId>, SqliteClientError> {
self.stmt_update_legacy_transparent_utxo
.query_row(
named_params![
":received_by_account": &u32::from(received_by_account),
":prevout_txid": &output.outpoint().hash().to_vec(),
":prevout_idx": &output.outpoint().n(),
":address": &output.recipient_address().encode(&self.wallet_db.params),
":script": &output.txout().script_pubkey.0,
":value_zat": &i64::from(output.txout().value),
":height": &u32::from(output.height()),
],
|row| {
let id = row.get(0)?;
Ok(UtxoId(id))
},
)
.optional()
.map_err(SqliteClientError::from)
}
/// Adds the given address and diversifier index to the addresses table.
///
/// Returns the database row for the newly-inserted address.
pub(crate) fn stmt_insert_address(
&mut self,
account: AccountId,
diversifier_index: DiversifierIndex,
address: &UnifiedAddress,
) -> Result<(), SqliteClientError> {
self.stmt_insert_address.execute(
&self.wallet_db.params,
account,
diversifier_index,
address,
)?;
Ok(())
}
}
impl<'a, P> DataConnStmtCache<'a, P> {
/// Inserts the given received note into the wallet.
///
/// This implementation relies on the facts that:
/// - A transaction will not contain more than 2^63 shielded outputs.
/// - A note value will never exceed 2^63 zatoshis.
///
/// Returns the database row for the newly-inserted note, or an error if the note
/// exists.
#[allow(clippy::too_many_arguments)]
pub(crate) fn stmt_insert_received_note(
&mut self,
tx_ref: i64,
output_index: usize,
account: AccountId,
diversifier: &Diversifier,
value: u64,
rcm: [u8; 32],
nf: Option<&Nullifier>,
memo: Option<&MemoBytes>,
is_change: bool,
) -> Result<NoteId, SqliteClientError> {
let sql_args: &[(&str, &dyn ToSql)] = &[
(":tx", &tx_ref),
(":output_index", &(output_index as i64)),
(":account", &u32::from(account)),
(":diversifier", &diversifier.0.as_ref()),
(":value", &(value as i64)),
(":rcm", &rcm.as_ref()),
(":nf", &nf.map(|nf| nf.0.as_ref())),
(
":memo",
&memo
.filter(|m| *m != &MemoBytes::empty())
.map(|m| m.as_slice()),
),
(":is_change", &is_change),
];
self.stmt_insert_received_note.execute(sql_args)?;
Ok(NoteId::ReceivedNoteId(
self.wallet_db.conn.last_insert_rowid(),
))
}
/// Updates the data for the given transaction.
///
/// This implementation relies on the facts that:
/// - A transaction will not contain more than 2^63 shielded outputs.
/// - A note value will never exceed 2^63 zatoshis.
///
/// Returns `false` if the transaction doesn't exist in the wallet.
#[allow(clippy::too_many_arguments)]
pub(crate) fn stmt_update_received_note(
&mut self,
account: AccountId,
diversifier: &Diversifier,
value: u64,
rcm: [u8; 32],
nf: Option<&Nullifier>,
memo: Option<&MemoBytes>,
is_change: bool,
tx_ref: i64,
output_index: usize,
) -> Result<bool, SqliteClientError> {
let sql_args: &[(&str, &dyn ToSql)] = &[
(":account", &u32::from(account)),
(":diversifier", &diversifier.0.as_ref()),
(":value", &(value as i64)),
(":rcm", &rcm.as_ref()),
(":nf", &nf.map(|nf| nf.0.as_ref())),
(
":memo",
&memo
.filter(|m| *m != &MemoBytes::empty())
.map(|m| m.as_slice()),
),
(":is_change", &is_change),
(":tx", &tx_ref),
(":output_index", &(output_index as i64)),
];
match self.stmt_update_received_note.execute(sql_args)? {
0 => Ok(false),
1 => Ok(true),
_ => unreachable!("tx_output constraint is marked as UNIQUE"),
}
}
/// Finds the database row for the given `txid`, if the transaction is in the wallet.
pub(crate) fn stmt_select_received_note(
&mut self,
tx_ref: i64,
output_index: usize,
) -> Result<NoteId, SqliteClientError> {
self.stmt_select_received_note
.query_row(params![tx_ref, (output_index as i64)], |row| {
row.get(0).map(NoteId::ReceivedNoteId)
})
.map_err(SqliteClientError::from)
}
/// Records the incremental witness for the specified note, as of the given block
/// height.
///
/// Returns `SqliteClientError::InvalidNoteId` if the note ID is for a sent note.
pub(crate) fn stmt_insert_witness(
&mut self,
note_id: NoteId,
height: BlockHeight,
witness: &sapling::IncrementalWitness,
) -> Result<(), SqliteClientError> {
let note_id = match note_id {
NoteId::ReceivedNoteId(note_id) => Ok(note_id),
NoteId::SentNoteId(_) => Err(SqliteClientError::InvalidNoteId),
}?;
let mut encoded = Vec::new();
write_incremental_witness(witness, &mut encoded).unwrap();
self.stmt_insert_witness
.execute(params![note_id, u32::from(height), encoded])?;
Ok(())
}
/// Removes old incremental witnesses up to the given block height.
pub(crate) fn stmt_prune_witnesses(
&mut self,
below_height: BlockHeight,
) -> Result<(), SqliteClientError> {
self.stmt_prune_witnesses
.execute([u32::from(below_height)])?;
Ok(())
}
/// Marks notes that have not been mined in transactions as expired, up to the given
/// block height.
pub fn stmt_update_expired(&mut self, height: BlockHeight) -> Result<(), SqliteClientError> {
self.stmt_update_expired.execute([u32::from(height)])?;
Ok(())
}
}

File diff suppressed because it is too large Load Diff

View File

@ -111,14 +111,14 @@ impl std::error::Error for WalletMigrationError {
// check for unspent transparent outputs whenever running initialization with a version of the // check for unspent transparent outputs whenever running initialization with a version of the
// library *not* compiled with the `transparent-inputs` feature flag, and fail if any are present. // library *not* compiled with the `transparent-inputs` feature flag, and fail if any are present.
pub fn init_wallet_db<P: consensus::Parameters + 'static>( pub fn init_wallet_db<P: consensus::Parameters + 'static>(
wdb: &mut WalletDb<P>, wdb: &mut WalletDb<rusqlite::Connection, P>,
seed: Option<SecretVec<u8>>, seed: Option<SecretVec<u8>>,
) -> Result<(), MigratorError<WalletMigrationError>> { ) -> Result<(), MigratorError<WalletMigrationError>> {
init_wallet_db_internal(wdb, seed, &[]) init_wallet_db_internal(wdb, seed, &[])
} }
fn init_wallet_db_internal<P: consensus::Parameters + 'static>( fn init_wallet_db_internal<P: consensus::Parameters + 'static>(
wdb: &mut WalletDb<P>, wdb: &mut WalletDb<rusqlite::Connection, P>,
seed: Option<SecretVec<u8>>, seed: Option<SecretVec<u8>>,
target_migrations: &[Uuid], target_migrations: &[Uuid],
) -> Result<(), MigratorError<WalletMigrationError>> { ) -> Result<(), MigratorError<WalletMigrationError>> {
@ -200,7 +200,7 @@ fn init_wallet_db_internal<P: consensus::Parameters + 'static>(
/// let dfvk = extsk.to_diversifiable_full_viewing_key(); /// let dfvk = extsk.to_diversifiable_full_viewing_key();
/// let ufvk = UnifiedFullViewingKey::new(None, Some(dfvk), None).unwrap(); /// let ufvk = UnifiedFullViewingKey::new(None, Some(dfvk), None).unwrap();
/// let ufvks = HashMap::from([(account, ufvk)]); /// let ufvks = HashMap::from([(account, ufvk)]);
/// init_accounts_table(&db_data, &ufvks).unwrap(); /// init_accounts_table(&mut db_data, &ufvks).unwrap();
/// # } /// # }
/// ``` /// ```
/// ///
@ -208,29 +208,29 @@ fn init_wallet_db_internal<P: consensus::Parameters + 'static>(
/// [`scan_cached_blocks`]: zcash_client_backend::data_api::chain::scan_cached_blocks /// [`scan_cached_blocks`]: zcash_client_backend::data_api::chain::scan_cached_blocks
/// [`create_spend_to_address`]: zcash_client_backend::data_api::wallet::create_spend_to_address /// [`create_spend_to_address`]: zcash_client_backend::data_api::wallet::create_spend_to_address
pub fn init_accounts_table<P: consensus::Parameters>( pub fn init_accounts_table<P: consensus::Parameters>(
wdb: &WalletDb<P>, wallet_db: &mut WalletDb<rusqlite::Connection, P>,
keys: &HashMap<AccountId, UnifiedFullViewingKey>, keys: &HashMap<AccountId, UnifiedFullViewingKey>,
) -> Result<(), SqliteClientError> { ) -> Result<(), SqliteClientError> {
let mut empty_check = wdb.conn.prepare("SELECT * FROM accounts LIMIT 1")?; wallet_db.transactionally(|wdb| {
if empty_check.exists([])? { let mut empty_check = wdb.conn.0.prepare("SELECT * FROM accounts LIMIT 1")?;
return Err(SqliteClientError::TableNotEmpty); if empty_check.exists([])? {
} return Err(SqliteClientError::TableNotEmpty);
// Ensure that the account identifiers are sequential and begin at zero.
if let Some(account_id) = keys.keys().max() {
if usize::try_from(u32::from(*account_id)).unwrap() >= keys.len() {
return Err(SqliteClientError::AccountIdDiscontinuity);
} }
}
// Insert accounts atomically // Ensure that the account identifiers are sequential and begin at zero.
wdb.conn.execute("BEGIN IMMEDIATE", [])?; if let Some(account_id) = keys.keys().max() {
for (account, key) in keys.iter() { if usize::try_from(u32::from(*account_id)).unwrap() >= keys.len() {
wallet::add_account(wdb, *account, key)?; return Err(SqliteClientError::AccountIdDiscontinuity);
} }
wdb.conn.execute("COMMIT", [])?; }
Ok(()) // Insert accounts atomically
for (account, key) in keys.iter() {
wallet::add_account(&wdb.conn.0, &wdb.params, *account, key)?;
}
Ok(())
})
} }
/// Initialises the data database with the given block. /// Initialises the data database with the given block.
@ -262,33 +262,35 @@ pub fn init_accounts_table<P: consensus::Parameters>(
/// let sapling_tree = &[]; /// let sapling_tree = &[];
/// ///
/// let data_file = NamedTempFile::new().unwrap(); /// let data_file = NamedTempFile::new().unwrap();
/// let db = WalletDb::for_path(data_file.path(), Network::TestNetwork).unwrap(); /// let mut db = WalletDb::for_path(data_file.path(), Network::TestNetwork).unwrap();
/// init_blocks_table(&db, height, hash, time, sapling_tree); /// init_blocks_table(&mut db, height, hash, time, sapling_tree);
/// ``` /// ```
pub fn init_blocks_table<P>( pub fn init_blocks_table<P: consensus::Parameters>(
wdb: &WalletDb<P>, wallet_db: &mut WalletDb<rusqlite::Connection, P>,
height: BlockHeight, height: BlockHeight,
hash: BlockHash, hash: BlockHash,
time: u32, time: u32,
sapling_tree: &[u8], sapling_tree: &[u8],
) -> Result<(), SqliteClientError> { ) -> Result<(), SqliteClientError> {
let mut empty_check = wdb.conn.prepare("SELECT * FROM blocks LIMIT 1")?; wallet_db.transactionally(|wdb| {
if empty_check.exists([])? { let mut empty_check = wdb.conn.0.prepare("SELECT * FROM blocks LIMIT 1")?;
return Err(SqliteClientError::TableNotEmpty); if empty_check.exists([])? {
} return Err(SqliteClientError::TableNotEmpty);
}
wdb.conn.execute( wdb.conn.0.execute(
"INSERT INTO blocks (height, hash, time, sapling_tree) "INSERT INTO blocks (height, hash, time, sapling_tree)
VALUES (?, ?, ?, ?)", VALUES (?, ?, ?, ?)",
[ [
u32::from(height).to_sql()?, u32::from(height).to_sql()?,
hash.0.to_sql()?, hash.0.to_sql()?,
time.to_sql()?, time.to_sql()?,
sapling_tree.to_sql()?, sapling_tree.to_sql()?,
], ],
)?; )?;
Ok(()) Ok(())
})
} }
#[cfg(test)] #[cfg(test)]
@ -606,7 +608,7 @@ mod tests {
#[test] #[test]
fn init_migrate_from_0_3_0() { fn init_migrate_from_0_3_0() {
fn init_0_3_0<P>( fn init_0_3_0<P>(
wdb: &mut WalletDb<P>, wdb: &mut WalletDb<rusqlite::Connection, P>,
extfvk: &ExtendedFullViewingKey, extfvk: &ExtendedFullViewingKey,
account: AccountId, account: AccountId,
) -> Result<(), rusqlite::Error> { ) -> Result<(), rusqlite::Error> {
@ -722,7 +724,7 @@ mod tests {
#[test] #[test]
fn init_migrate_from_autoshielding_poc() { fn init_migrate_from_autoshielding_poc() {
fn init_autoshielding<P>( fn init_autoshielding<P>(
wdb: &WalletDb<P>, wdb: &mut WalletDb<rusqlite::Connection, P>,
extfvk: &ExtendedFullViewingKey, extfvk: &ExtendedFullViewingKey,
account: AccountId, account: AccountId,
) -> Result<(), rusqlite::Error> { ) -> Result<(), rusqlite::Error> {
@ -878,14 +880,14 @@ mod tests {
let extfvk = secret_key.to_extended_full_viewing_key(); let extfvk = secret_key.to_extended_full_viewing_key();
let data_file = NamedTempFile::new().unwrap(); let data_file = NamedTempFile::new().unwrap();
let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap(); let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap();
init_autoshielding(&db_data, &extfvk, account).unwrap(); init_autoshielding(&mut db_data, &extfvk, account).unwrap();
init_wallet_db(&mut db_data, Some(Secret::new(seed.to_vec()))).unwrap(); init_wallet_db(&mut db_data, Some(Secret::new(seed.to_vec()))).unwrap();
} }
#[test] #[test]
fn init_migrate_from_main_pre_migrations() { fn init_migrate_from_main_pre_migrations() {
fn init_main<P>( fn init_main<P>(
wdb: &WalletDb<P>, wdb: &mut WalletDb<rusqlite::Connection, P>,
ufvk: &UnifiedFullViewingKey, ufvk: &UnifiedFullViewingKey,
account: AccountId, account: AccountId,
) -> Result<(), rusqlite::Error> { ) -> Result<(), rusqlite::Error> {
@ -1025,7 +1027,12 @@ mod tests {
let secret_key = UnifiedSpendingKey::from_seed(&tests::network(), &seed, account).unwrap(); let secret_key = UnifiedSpendingKey::from_seed(&tests::network(), &seed, account).unwrap();
let data_file = NamedTempFile::new().unwrap(); let data_file = NamedTempFile::new().unwrap();
let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap(); let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap();
init_main(&db_data, &secret_key.to_unified_full_viewing_key(), account).unwrap(); init_main(
&mut db_data,
&secret_key.to_unified_full_viewing_key(),
account,
)
.unwrap();
init_wallet_db(&mut db_data, Some(Secret::new(seed.to_vec()))).unwrap(); init_wallet_db(&mut db_data, Some(Secret::new(seed.to_vec()))).unwrap();
} }
@ -1036,8 +1043,8 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// We can call the function as many times as we want with no data // We can call the function as many times as we want with no data
init_accounts_table(&db_data, &HashMap::new()).unwrap(); init_accounts_table(&mut db_data, &HashMap::new()).unwrap();
init_accounts_table(&db_data, &HashMap::new()).unwrap(); init_accounts_table(&mut db_data, &HashMap::new()).unwrap();
let seed = [0u8; 32]; let seed = [0u8; 32];
let account = AccountId::from(0); let account = AccountId::from(0);
@ -1062,11 +1069,11 @@ mod tests {
let ufvk = UnifiedFullViewingKey::new(Some(dfvk), None).unwrap(); let ufvk = UnifiedFullViewingKey::new(Some(dfvk), None).unwrap();
let ufvks = HashMap::from([(account, ufvk)]); let ufvks = HashMap::from([(account, ufvk)]);
init_accounts_table(&db_data, &ufvks).unwrap(); init_accounts_table(&mut db_data, &ufvks).unwrap();
// Subsequent calls should return an error // Subsequent calls should return an error
init_accounts_table(&db_data, &HashMap::new()).unwrap_err(); init_accounts_table(&mut db_data, &HashMap::new()).unwrap_err();
init_accounts_table(&db_data, &ufvks).unwrap_err(); init_accounts_table(&mut db_data, &ufvks).unwrap_err();
} }
#[test] #[test]
@ -1090,12 +1097,12 @@ mod tests {
// should fail if we have a gap // should fail if we have a gap
assert_matches!( assert_matches!(
init_accounts_table(&db_data, &ufvks(&[0, 2])), init_accounts_table(&mut db_data, &ufvks(&[0, 2])),
Err(SqliteClientError::AccountIdDiscontinuity) Err(SqliteClientError::AccountIdDiscontinuity)
); );
// should succeed if there are no gaps // should succeed if there are no gaps
assert!(init_accounts_table(&db_data, &ufvks(&[0, 1, 2])).is_ok()); assert!(init_accounts_table(&mut db_data, &ufvks(&[0, 1, 2])).is_ok());
} }
#[test] #[test]
@ -1106,7 +1113,7 @@ mod tests {
// First call with data should initialise the blocks table // First call with data should initialise the blocks table
init_blocks_table( init_blocks_table(
&db_data, &mut db_data,
BlockHeight::from(1u32), BlockHeight::from(1u32),
BlockHash([1; 32]), BlockHash([1; 32]),
1, 1,
@ -1116,7 +1123,7 @@ mod tests {
// Subsequent calls should return an error // Subsequent calls should return an error
init_blocks_table( init_blocks_table(
&db_data, &mut db_data,
BlockHeight::from(2u32), BlockHeight::from(2u32),
BlockHash([2; 32]), BlockHash([2; 32]),
2, 2,
@ -1139,7 +1146,7 @@ mod tests {
let ufvk = usk.to_unified_full_viewing_key(); let ufvk = usk.to_unified_full_viewing_key();
let expected_address = ufvk.sapling().unwrap().default_address().1; let expected_address = ufvk.sapling().unwrap().default_address().1;
let ufvks = HashMap::from([(account_id, ufvk)]); let ufvks = HashMap::from([(account_id, ufvk)]);
init_accounts_table(&db_data, &ufvks).unwrap(); init_accounts_table(&mut db_data, &ufvks).unwrap();
// The account's address should be in the data DB // The account's address should be in the data DB
let ua = db_data.get_current_address(AccountId::from(0)).unwrap(); let ua = db_data.get_current_address(AccountId::from(0)).unwrap();
@ -1153,16 +1160,15 @@ mod tests {
let mut db_data = WalletDb::for_path(data_file.path(), Network::MainNetwork).unwrap(); let mut db_data = WalletDb::for_path(data_file.path(), Network::MainNetwork).unwrap();
init_wallet_db(&mut db_data, None).unwrap(); init_wallet_db(&mut db_data, None).unwrap();
let mut ops = db_data.get_update_ops().unwrap();
let seed = test_vectors::UNIFIED[0].root_seed; let seed = test_vectors::UNIFIED[0].root_seed;
let (account, _usk) = ops.create_account(&Secret::new(seed.to_vec())).unwrap(); let (account, _usk) = db_data.create_account(&Secret::new(seed.to_vec())).unwrap();
assert_eq!(account, AccountId::from(0u32)); assert_eq!(account, AccountId::from(0u32));
for tv in &test_vectors::UNIFIED[..3] { for tv in &test_vectors::UNIFIED[..3] {
if let Some(RecipientAddress::Unified(tvua)) = if let Some(RecipientAddress::Unified(tvua)) =
RecipientAddress::decode(&Network::MainNetwork, tv.unified_addr) RecipientAddress::decode(&Network::MainNetwork, tv.unified_addr)
{ {
let (ua, di) = wallet::get_current_address(&db_data, account) let (ua, di) = wallet::get_current_address(&db_data.conn, &db_data.params, account)
.unwrap() .unwrap()
.expect("create_account generated the first address"); .expect("create_account generated the first address");
assert_eq!(DiversifierIndex::from(tv.diversifier_index), di); assert_eq!(DiversifierIndex::from(tv.diversifier_index), di);
@ -1170,7 +1176,8 @@ mod tests {
assert_eq!(tvua.sapling(), ua.sapling()); assert_eq!(tvua.sapling(), ua.sapling());
assert_eq!(tv.unified_addr, ua.encode(&Network::MainNetwork)); assert_eq!(tv.unified_addr, ua.encode(&Network::MainNetwork));
ops.get_next_available_address(account) db_data
.get_next_available_address(account)
.unwrap() .unwrap()
.expect("get_next_available_address generated an address"); .expect("get_next_available_address generated an address");
} else { } else {

View File

@ -67,7 +67,7 @@ impl<P: consensus::Parameters> RusqliteMigration for Migration<P> {
while let Some(row) = rows.next()? { while let Some(row) = rows.next()? {
let account: u32 = row.get(0)?; let account: u32 = row.get(0)?;
let taddrs = let taddrs =
get_transparent_receivers(&self._params, transaction, AccountId::from(account)) get_transparent_receivers(transaction, &self._params, AccountId::from(account))
.map_err(|e| match e { .map_err(|e| match e {
SqliteClientError::DbError(e) => WalletMigrationError::DbError(e), SqliteClientError::DbError(e) => WalletMigrationError::DbError(e),
SqliteClientError::CorruptedData(s) => { SqliteClientError::CorruptedData(s) => {

View File

@ -1,12 +1,12 @@
//! Functions for Sapling support in the wallet. //! Functions for Sapling support in the wallet.
use group::ff::PrimeField; use group::ff::PrimeField;
use rusqlite::{named_params, types::Value, OptionalExtension, Row}; use rusqlite::{named_params, params, types::Value, Connection, OptionalExtension, Row};
use std::rc::Rc; use std::rc::Rc;
use zcash_primitives::{ use zcash_primitives::{
consensus::BlockHeight, consensus::BlockHeight,
memo::MemoBytes, memo::MemoBytes,
merkle_tree::{read_commitment_tree, read_incremental_witness}, merkle_tree::{read_commitment_tree, read_incremental_witness, write_incremental_witness},
sapling::{self, Diversifier, Note, Nullifier, Rseed}, sapling::{self, Diversifier, Note, Nullifier, Rseed},
transaction::components::Amount, transaction::components::Amount,
zip32::AccountId, zip32::AccountId,
@ -17,7 +17,9 @@ use zcash_client_backend::{
DecryptedOutput, TransferType, DecryptedOutput, TransferType,
}; };
use crate::{error::SqliteClientError, DataConnStmtCache, NoteId, WalletDb}; use crate::{error::SqliteClientError, NoteId};
use super::memo_repr;
/// This trait provides a generalization over shielded output representations. /// This trait provides a generalization over shielded output representations.
pub(crate) trait ReceivedSaplingOutput { pub(crate) trait ReceivedSaplingOutput {
@ -117,13 +119,13 @@ fn to_spendable_note(row: &Row) -> Result<ReceivedSaplingNote<NoteId>, SqliteCli
}) })
} }
pub(crate) fn get_spendable_sapling_notes<P>( pub(crate) fn get_spendable_sapling_notes(
wdb: &WalletDb<P>, conn: &Connection,
account: AccountId, account: AccountId,
anchor_height: BlockHeight, anchor_height: BlockHeight,
exclude: &[NoteId], exclude: &[NoteId],
) -> Result<Vec<ReceivedSaplingNote<NoteId>>, SqliteClientError> { ) -> Result<Vec<ReceivedSaplingNote<NoteId>>, SqliteClientError> {
let mut stmt_select_notes = wdb.conn.prepare( let mut stmt_select_notes = conn.prepare_cached(
"SELECT id_note, diversifier, value, rcm, witness "SELECT id_note, diversifier, value, rcm, witness
FROM sapling_received_notes FROM sapling_received_notes
INNER JOIN transactions ON transactions.id_tx = sapling_received_notes.tx INNER JOIN transactions ON transactions.id_tx = sapling_received_notes.tx
@ -156,8 +158,8 @@ pub(crate) fn get_spendable_sapling_notes<P>(
notes.collect::<Result<_, _>>() notes.collect::<Result<_, _>>()
} }
pub(crate) fn select_spendable_sapling_notes<P>( pub(crate) fn select_spendable_sapling_notes(
wdb: &WalletDb<P>, conn: &Connection,
account: AccountId, account: AccountId,
target_value: Amount, target_value: Amount,
anchor_height: BlockHeight, anchor_height: BlockHeight,
@ -181,7 +183,7 @@ pub(crate) fn select_spendable_sapling_notes<P>(
// required value, bringing the sum of all selected notes across the threshold. // required value, bringing the sum of all selected notes across the threshold.
// //
// 4) Match the selected notes against the witnesses at the desired height. // 4) Match the selected notes against the witnesses at the desired height.
let mut stmt_select_notes = wdb.conn.prepare( let mut stmt_select_notes = conn.prepare_cached(
"WITH selected AS ( "WITH selected AS (
WITH eligible AS ( WITH eligible AS (
SELECT id_note, diversifier, value, rcm, SELECT id_note, diversifier, value, rcm,
@ -230,43 +232,42 @@ pub(crate) fn select_spendable_sapling_notes<P>(
/// Returns the commitment tree for the block at the specified height, /// Returns the commitment tree for the block at the specified height,
/// if any. /// if any.
pub(crate) fn get_sapling_commitment_tree<P>( pub(crate) fn get_sapling_commitment_tree(
wdb: &WalletDb<P>, conn: &Connection,
block_height: BlockHeight, block_height: BlockHeight,
) -> Result<Option<sapling::CommitmentTree>, SqliteClientError> { ) -> Result<Option<sapling::CommitmentTree>, SqliteClientError> {
wdb.conn conn.query_row_and_then(
.query_row_and_then( "SELECT sapling_tree FROM blocks WHERE height = ?",
"SELECT sapling_tree FROM blocks WHERE height = ?", [u32::from(block_height)],
[u32::from(block_height)], |row| {
|row| { let row_data: Vec<u8> = row.get(0)?;
let row_data: Vec<u8> = row.get(0)?; read_commitment_tree(&row_data[..]).map_err(|e| {
read_commitment_tree(&row_data[..]).map_err(|e| { rusqlite::Error::FromSqlConversionFailure(
rusqlite::Error::FromSqlConversionFailure( row_data.len(),
row_data.len(), rusqlite::types::Type::Blob,
rusqlite::types::Type::Blob, Box::new(e),
Box::new(e), )
) })
}) },
}, )
) .optional()
.optional() .map_err(SqliteClientError::from)
.map_err(SqliteClientError::from)
} }
/// Returns the incremental witnesses for the block at the specified height, /// Returns the incremental witnesses for the block at the specified height,
/// if any. /// if any.
pub(crate) fn get_sapling_witnesses<P>( pub(crate) fn get_sapling_witnesses(
wdb: &WalletDb<P>, conn: &Connection,
block_height: BlockHeight, block_height: BlockHeight,
) -> Result<Vec<(NoteId, sapling::IncrementalWitness)>, SqliteClientError> { ) -> Result<Vec<(NoteId, sapling::IncrementalWitness)>, SqliteClientError> {
let mut stmt_fetch_witnesses = wdb let mut stmt_fetch_witnesses =
.conn conn.prepare_cached("SELECT note, witness FROM sapling_witnesses WHERE block = ?")?;
.prepare("SELECT note, witness FROM sapling_witnesses WHERE block = ?")?;
let witnesses = stmt_fetch_witnesses let witnesses = stmt_fetch_witnesses
.query_map([u32::from(block_height)], |row| { .query_map([u32::from(block_height)], |row| {
let id_note = NoteId::ReceivedNoteId(row.get(0)?); let id_note = NoteId::ReceivedNoteId(row.get(0)?);
let wdb: Vec<u8> = row.get(1)?; let witness_data: Vec<u8> = row.get(1)?;
Ok(read_incremental_witness(&wdb[..]).map(|witness| (id_note, witness))) Ok(read_incremental_witness(&witness_data[..]).map(|witness| (id_note, witness)))
}) })
.map_err(SqliteClientError::from)?; .map_err(SqliteClientError::from)?;
@ -277,13 +278,23 @@ pub(crate) fn get_sapling_witnesses<P>(
/// Records the incremental witness for the specified note, /// Records the incremental witness for the specified note,
/// as of the given block height. /// as of the given block height.
pub(crate) fn insert_witness<'a, P>( pub(crate) fn insert_witness(
stmts: &mut DataConnStmtCache<'a, P>, conn: &Connection,
note_id: i64, note_id: i64,
witness: &sapling::IncrementalWitness, witness: &sapling::IncrementalWitness,
height: BlockHeight, height: BlockHeight,
) -> Result<(), SqliteClientError> { ) -> Result<(), SqliteClientError> {
stmts.stmt_insert_witness(NoteId::ReceivedNoteId(note_id), height, witness) let mut stmt_insert_witness = conn.prepare_cached(
"INSERT INTO sapling_witnesses (note, block, witness)
VALUES (?, ?, ?)",
)?;
let mut encoded = Vec::new();
write_incremental_witness(witness, &mut encoded).unwrap();
stmt_insert_witness.execute(params![note_id, u32::from(height), encoded])?;
Ok(())
} }
/// Retrieves the set of nullifiers for "potentially spendable" Sapling notes that the /// Retrieves the set of nullifiers for "potentially spendable" Sapling notes that the
@ -292,11 +303,11 @@ pub(crate) fn insert_witness<'a, P>(
/// "Potentially spendable" means: /// "Potentially spendable" means:
/// - The transaction in which the note was created has been observed as mined. /// - The transaction in which the note was created has been observed as mined.
/// - No transaction in which the note's nullifier appears has been observed as mined. /// - No transaction in which the note's nullifier appears has been observed as mined.
pub(crate) fn get_sapling_nullifiers<P>( pub(crate) fn get_sapling_nullifiers(
wdb: &WalletDb<P>, conn: &Connection,
) -> Result<Vec<(AccountId, Nullifier)>, SqliteClientError> { ) -> Result<Vec<(AccountId, Nullifier)>, SqliteClientError> {
// Get the nullifiers for the notes we are tracking // Get the nullifiers for the notes we are tracking
let mut stmt_fetch_nullifiers = wdb.conn.prepare( let mut stmt_fetch_nullifiers = conn.prepare(
"SELECT rn.id_note, rn.account, rn.nf, tx.block as block "SELECT rn.id_note, rn.account, rn.nf, tx.block as block
FROM sapling_received_notes rn FROM sapling_received_notes rn
LEFT OUTER JOIN transactions tx LEFT OUTER JOIN transactions tx
@ -318,11 +329,11 @@ pub(crate) fn get_sapling_nullifiers<P>(
} }
/// Returns the nullifiers for the notes that this wallet is tracking. /// Returns the nullifiers for the notes that this wallet is tracking.
pub(crate) fn get_all_sapling_nullifiers<P>( pub(crate) fn get_all_sapling_nullifiers(
wdb: &WalletDb<P>, conn: &Connection,
) -> Result<Vec<(AccountId, Nullifier)>, SqliteClientError> { ) -> Result<Vec<(AccountId, Nullifier)>, SqliteClientError> {
// Get the nullifiers for the notes we are tracking // Get the nullifiers for the notes we are tracking
let mut stmt_fetch_nullifiers = wdb.conn.prepare( let mut stmt_fetch_nullifiers = conn.prepare(
"SELECT rn.id_note, rn.account, rn.nf "SELECT rn.id_note, rn.account, rn.nf
FROM sapling_received_notes rn FROM sapling_received_notes rn
WHERE nf IS NOT NULL", WHERE nf IS NOT NULL",
@ -345,13 +356,19 @@ pub(crate) fn get_all_sapling_nullifiers<P>(
/// ///
/// Marking a note spent in this fashion does NOT imply that the /// Marking a note spent in this fashion does NOT imply that the
/// spending transaction has been mined. /// spending transaction has been mined.
pub(crate) fn mark_sapling_note_spent<'a, P>( pub(crate) fn mark_sapling_note_spent(
stmts: &mut DataConnStmtCache<'a, P>, conn: &Connection,
tx_ref: i64, tx_ref: i64,
nf: &Nullifier, nf: &Nullifier,
) -> Result<(), SqliteClientError> { ) -> Result<bool, SqliteClientError> {
stmts.stmt_mark_sapling_note_spent(tx_ref, nf)?; let mut stmt_mark_sapling_note_spent =
Ok(()) conn.prepare_cached("UPDATE sapling_received_notes SET spent = ? WHERE nf = ?")?;
match stmt_mark_sapling_note_spent.execute(params![tx_ref, &nf.0[..]])? {
0 => Ok(false),
1 => Ok(true),
_ => unreachable!("nf column is marked as UNIQUE"),
}
} }
/// Records the specified shielded output as having been received. /// Records the specified shielded output as having been received.
@ -359,49 +376,48 @@ pub(crate) fn mark_sapling_note_spent<'a, P>(
/// This implementation relies on the facts that: /// This implementation relies on the facts that:
/// - A transaction will not contain more than 2^63 shielded outputs. /// - A transaction will not contain more than 2^63 shielded outputs.
/// - A note value will never exceed 2^63 zatoshis. /// - A note value will never exceed 2^63 zatoshis.
pub(crate) fn put_received_note<'a, P, T: ReceivedSaplingOutput>( pub(crate) fn put_received_note<T: ReceivedSaplingOutput>(
stmts: &mut DataConnStmtCache<'a, P>, conn: &Connection,
output: &T, output: &T,
tx_ref: i64, tx_ref: i64,
) -> Result<NoteId, SqliteClientError> { ) -> Result<NoteId, SqliteClientError> {
let mut stmt_upsert_received_note = conn.prepare_cached(
"INSERT INTO sapling_received_notes
(tx, output_index, account, diversifier, value, rcm, memo, nf, is_change)
VALUES
(:tx, :output_index, :account, :diversifier, :value, :rcm, :memo, :nf, :is_change)
ON CONFLICT (tx, output_index) DO UPDATE
SET account = :account,
diversifier = :diversifier,
value = :value,
rcm = :rcm,
nf = IFNULL(:nf, nf),
memo = IFNULL(:memo, memo),
is_change = IFNULL(:is_change, is_change)
RETURNING id_note",
)?;
let rcm = output.note().rcm().to_repr(); let rcm = output.note().rcm().to_repr();
let account = output.account();
let to = output.note().recipient(); let to = output.note().recipient();
let diversifier = to.diversifier(); let diversifier = to.diversifier();
let value = output.note().value();
let memo = output.memo();
let is_change = output.is_change();
let output_index = output.index();
let nf = output.nullifier();
// First try updating an existing received note into the database. let sql_args = named_params![
if !stmts.stmt_update_received_note( ":tx": &tx_ref,
account, ":output_index": i64::try_from(output.index()).expect("output indices are representable as i64"),
diversifier, ":account": u32::from(output.account()),
value.inner(), ":diversifier": &diversifier.0.as_ref(),
rcm, ":value": output.note().value().inner(),
nf, ":rcm": &rcm.as_ref(),
memo, ":nf": output.nullifier().map(|nf| nf.0.as_ref()),
is_change, ":memo": memo_repr(output.memo()),
tx_ref, ":is_change": output.is_change()
output_index, ];
)? {
// It isn't there, so insert our note into the database. stmt_upsert_received_note
stmts.stmt_insert_received_note( .query_row(sql_args, |row| {
tx_ref, row.get::<_, i64>(0).map(NoteId::ReceivedNoteId)
output_index, })
account, .map_err(SqliteClientError::from)
diversifier,
value.inner(),
rcm,
nf,
memo,
is_change,
)
} else {
// It was there, so grab its row number.
stmts.stmt_select_received_note(tx_ref, output.index())
}
} }
#[cfg(test)] #[cfg(test)]
@ -447,7 +463,7 @@ mod tests {
get_balance, get_balance_at, get_balance, get_balance_at,
init::{init_blocks_table, init_wallet_db}, init::{init_blocks_table, init_wallet_db},
}, },
AccountId, BlockDb, DataConnStmtCache, WalletDb, AccountId, BlockDb, WalletDb,
}; };
#[cfg(feature = "transparent-inputs")] #[cfg(feature = "transparent-inputs")]
@ -481,9 +497,8 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet // Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec()); let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = ops.create_account(&seed).unwrap(); let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
let to = dfvk.default_address().1.into(); let to = dfvk.default_address().1.into();
@ -492,10 +507,9 @@ mod tests {
let usk1 = UnifiedSpendingKey::from_seed(&network(), &[1u8; 32], acct1).unwrap(); let usk1 = UnifiedSpendingKey::from_seed(&network(), &[1u8; 32], acct1).unwrap();
// Attempting to spend with a USK that is not in the wallet results in an error // Attempting to spend with a USK that is not in the wallet results in an error
let mut db_write = db_data.get_update_ops().unwrap();
assert_matches!( assert_matches!(
create_spend_to_address( create_spend_to_address(
&mut db_write, &mut db_data,
&tests::network(), &tests::network(),
test_prover(), test_prover(),
&usk1, &usk1,
@ -516,17 +530,15 @@ mod tests {
init_wallet_db(&mut db_data, None).unwrap(); init_wallet_db(&mut db_data, None).unwrap();
// Add an account to the wallet // Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec()); let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = ops.create_account(&seed).unwrap(); let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
let to = dfvk.default_address().1.into(); let to = dfvk.default_address().1.into();
// We cannot do anything if we aren't synchronised // We cannot do anything if we aren't synchronised
let mut db_write = db_data.get_update_ops().unwrap();
assert_matches!( assert_matches!(
create_spend_to_address( create_spend_to_address(
&mut db_write, &mut db_data,
&tests::network(), &tests::network(),
test_prover(), test_prover(),
&usk, &usk,
@ -546,7 +558,7 @@ mod tests {
let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap(); let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap();
init_wallet_db(&mut db_data, None).unwrap(); init_wallet_db(&mut db_data, None).unwrap();
init_blocks_table( init_blocks_table(
&db_data, &mut db_data,
BlockHeight::from(1u32), BlockHeight::from(1u32),
BlockHash([1; 32]), BlockHash([1; 32]),
1, 1,
@ -555,23 +567,21 @@ mod tests {
.unwrap(); .unwrap();
// Add an account to the wallet // Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec()); let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = ops.create_account(&seed).unwrap(); let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
let to = dfvk.default_address().1.into(); let to = dfvk.default_address().1.into();
// Account balance should be zero // Account balance should be zero
assert_eq!( assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(), get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
Amount::zero() Amount::zero()
); );
// We cannot spend anything // We cannot spend anything
let mut db_write = db_data.get_update_ops().unwrap();
assert_matches!( assert_matches!(
create_spend_to_address( create_spend_to_address(
&mut db_write, &mut db_data,
&tests::network(), &tests::network(),
test_prover(), test_prover(),
&usk, &usk,
@ -600,9 +610,8 @@ mod tests {
init_wallet_db(&mut db_data, None).unwrap(); init_wallet_db(&mut db_data, None).unwrap();
// Add an account to the wallet // Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec()); let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = ops.create_account(&seed).unwrap(); let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
// Add funds to the wallet in a single note // Add funds to the wallet in a single note
@ -615,14 +624,16 @@ mod tests {
value, value,
); );
insert_into_cache(&db_cache, &cb); insert_into_cache(&db_cache, &cb);
let mut db_write = db_data.get_update_ops().unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Verified balance matches total balance // Verified balance matches total balance
let (_, anchor_height) = db_data.get_target_and_anchor_heights(10).unwrap().unwrap(); let (_, anchor_height) = db_data.get_target_and_anchor_heights(10).unwrap().unwrap();
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value);
assert_eq!( assert_eq!(
get_balance_at(&db_data, AccountId::from(0), anchor_height).unwrap(), get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
);
assert_eq!(
get_balance_at(&db_data.conn, AccountId::from(0), anchor_height).unwrap(),
value value
); );
@ -635,16 +646,16 @@ mod tests {
value, value,
); );
insert_into_cache(&db_cache, &cb); insert_into_cache(&db_cache, &cb);
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Verified balance does not include the second note // Verified balance does not include the second note
let (_, anchor_height2) = db_data.get_target_and_anchor_heights(10).unwrap().unwrap(); let (_, anchor_height2) = db_data.get_target_and_anchor_heights(10).unwrap().unwrap();
assert_eq!( assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(), get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
(value + value).unwrap() (value + value).unwrap()
); );
assert_eq!( assert_eq!(
get_balance_at(&db_data, AccountId::from(0), anchor_height2).unwrap(), get_balance_at(&db_data.conn, AccountId::from(0), anchor_height2).unwrap(),
value value
); );
@ -653,7 +664,7 @@ mod tests {
let to = extsk2.default_address().1.into(); let to = extsk2.default_address().1.into();
assert_matches!( assert_matches!(
create_spend_to_address( create_spend_to_address(
&mut db_write, &mut db_data,
&tests::network(), &tests::network(),
test_prover(), test_prover(),
&usk, &usk,
@ -683,12 +694,12 @@ mod tests {
); );
insert_into_cache(&db_cache, &cb); insert_into_cache(&db_cache, &cb);
} }
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Second spend still fails // Second spend still fails
assert_matches!( assert_matches!(
create_spend_to_address( create_spend_to_address(
&mut db_write, &mut db_data,
&tests::network(), &tests::network(),
test_prover(), test_prover(),
&usk, &usk,
@ -715,12 +726,12 @@ mod tests {
value, value,
); );
insert_into_cache(&db_cache, &cb); insert_into_cache(&db_cache, &cb);
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Second spend should now succeed // Second spend should now succeed
assert_matches!( assert_matches!(
create_spend_to_address( create_spend_to_address(
&mut db_write, &mut db_data,
&tests::network(), &tests::network(),
test_prover(), test_prover(),
&usk, &usk,
@ -745,9 +756,8 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet // Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec()); let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = ops.create_account(&seed).unwrap(); let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
// Add funds to the wallet in a single note // Add funds to the wallet in a single note
@ -760,16 +770,18 @@ mod tests {
value, value,
); );
insert_into_cache(&db_cache, &cb); insert_into_cache(&db_cache, &cb);
let mut db_write = db_data.get_update_ops().unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); assert_eq!(
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value); get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
);
// Send some of the funds to another address // Send some of the funds to another address
let extsk2 = ExtendedSpendingKey::master(&[]); let extsk2 = ExtendedSpendingKey::master(&[]);
let to = extsk2.default_address().1.into(); let to = extsk2.default_address().1.into();
assert_matches!( assert_matches!(
create_spend_to_address( create_spend_to_address(
&mut db_write, &mut db_data,
&tests::network(), &tests::network(),
test_prover(), test_prover(),
&usk, &usk,
@ -785,7 +797,7 @@ mod tests {
// A second spend fails because there are no usable notes // A second spend fails because there are no usable notes
assert_matches!( assert_matches!(
create_spend_to_address( create_spend_to_address(
&mut db_write, &mut db_data,
&tests::network(), &tests::network(),
test_prover(), test_prover(),
&usk, &usk,
@ -814,12 +826,12 @@ mod tests {
); );
insert_into_cache(&db_cache, &cb); insert_into_cache(&db_cache, &cb);
} }
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Second spend still fails // Second spend still fails
assert_matches!( assert_matches!(
create_spend_to_address( create_spend_to_address(
&mut db_write, &mut db_data,
&tests::network(), &tests::network(),
test_prover(), test_prover(),
&usk, &usk,
@ -845,11 +857,11 @@ mod tests {
value, value,
); );
insert_into_cache(&db_cache, &cb); insert_into_cache(&db_cache, &cb);
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Second spend should now succeed // Second spend should now succeed
create_spend_to_address( create_spend_to_address(
&mut db_write, &mut db_data,
&tests::network(), &tests::network(),
test_prover(), test_prover(),
&usk, &usk,
@ -874,9 +886,8 @@ mod tests {
init_wallet_db(&mut db_data, None).unwrap(); init_wallet_db(&mut db_data, None).unwrap();
// Add an account to the wallet // Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec()); let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = ops.create_account(&seed).unwrap(); let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
// Add funds to the wallet in a single note // Add funds to the wallet in a single note
@ -889,17 +900,19 @@ mod tests {
value, value,
); );
insert_into_cache(&db_cache, &cb); insert_into_cache(&db_cache, &cb);
let mut db_write = db_data.get_update_ops().unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); assert_eq!(
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value); get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
);
let extsk2 = ExtendedSpendingKey::master(&[]); let extsk2 = ExtendedSpendingKey::master(&[]);
let addr2 = extsk2.default_address().1; let addr2 = extsk2.default_address().1;
let to = addr2.into(); let to = addr2.into();
let send_and_recover_with_policy = |db_write: &mut DataConnStmtCache<'_, _>, ovk_policy| { let send_and_recover_with_policy = |db_data: &mut WalletDb<Connection, _>, ovk_policy| {
let tx_row = create_spend_to_address( let tx_row = create_spend_to_address(
db_write, db_data,
&tests::network(), &tests::network(),
test_prover(), test_prover(),
&usk, &usk,
@ -912,8 +925,7 @@ mod tests {
.unwrap(); .unwrap();
// Fetch the transaction from the database // Fetch the transaction from the database
let raw_tx: Vec<_> = db_write let raw_tx: Vec<_> = db_data
.wallet_db
.conn .conn
.query_row( .query_row(
"SELECT raw FROM transactions "SELECT raw FROM transactions
@ -944,7 +956,7 @@ mod tests {
// Send some of the funds to another address, keeping history. // Send some of the funds to another address, keeping history.
// The recipient output is decryptable by the sender. // The recipient output is decryptable by the sender.
let (_, recovered_to, _) = let (_, recovered_to, _) =
send_and_recover_with_policy(&mut db_write, OvkPolicy::Sender).unwrap(); send_and_recover_with_policy(&mut db_data, OvkPolicy::Sender).unwrap();
assert_eq!(&recovered_to, &addr2); assert_eq!(&recovered_to, &addr2);
// Mine blocks SAPLING_ACTIVATION_HEIGHT + 1 to 42 (that don't send us funds) // Mine blocks SAPLING_ACTIVATION_HEIGHT + 1 to 42 (that don't send us funds)
@ -959,11 +971,11 @@ mod tests {
); );
insert_into_cache(&db_cache, &cb); insert_into_cache(&db_cache, &cb);
} }
scan_cached_blocks(&network, &db_cache, &mut db_write, None).unwrap(); scan_cached_blocks(&network, &db_cache, &mut db_data, None).unwrap();
// Send the funds again, discarding history. // Send the funds again, discarding history.
// Neither transaction output is decryptable by the sender. // Neither transaction output is decryptable by the sender.
assert!(send_and_recover_with_policy(&mut db_write, OvkPolicy::Discard).is_none()); assert!(send_and_recover_with_policy(&mut db_data, OvkPolicy::Discard).is_none());
} }
#[test] #[test]
@ -977,9 +989,8 @@ mod tests {
init_wallet_db(&mut db_data, None).unwrap(); init_wallet_db(&mut db_data, None).unwrap();
// Add an account to the wallet // Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec()); let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = ops.create_account(&seed).unwrap(); let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
// Add funds to the wallet in a single note // Add funds to the wallet in a single note
@ -992,21 +1003,23 @@ mod tests {
value, value,
); );
insert_into_cache(&db_cache, &cb); insert_into_cache(&db_cache, &cb);
let mut db_write = db_data.get_update_ops().unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Verified balance matches total balance // Verified balance matches total balance
let (_, anchor_height) = db_data.get_target_and_anchor_heights(10).unwrap().unwrap(); let (_, anchor_height) = db_data.get_target_and_anchor_heights(10).unwrap().unwrap();
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value);
assert_eq!( assert_eq!(
get_balance_at(&db_data, AccountId::from(0), anchor_height).unwrap(), get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
);
assert_eq!(
get_balance_at(&db_data.conn, AccountId::from(0), anchor_height).unwrap(),
value value
); );
let to = TransparentAddress::PublicKey([7; 20]).into(); let to = TransparentAddress::PublicKey([7; 20]).into();
assert_matches!( assert_matches!(
create_spend_to_address( create_spend_to_address(
&mut db_write, &mut db_data,
&tests::network(), &tests::network(),
test_prover(), test_prover(),
&usk, &usk,
@ -1031,9 +1044,8 @@ mod tests {
init_wallet_db(&mut db_data, None).unwrap(); init_wallet_db(&mut db_data, None).unwrap();
// Add an account to the wallet // Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec()); let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = ops.create_account(&seed).unwrap(); let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
// Add funds to the wallet in a single note // Add funds to the wallet in a single note
@ -1046,21 +1058,23 @@ mod tests {
value, value,
); );
insert_into_cache(&db_cache, &cb); insert_into_cache(&db_cache, &cb);
let mut db_write = db_data.get_update_ops().unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Verified balance matches total balance // Verified balance matches total balance
let (_, anchor_height) = db_data.get_target_and_anchor_heights(10).unwrap().unwrap(); let (_, anchor_height) = db_data.get_target_and_anchor_heights(10).unwrap().unwrap();
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value);
assert_eq!( assert_eq!(
get_balance_at(&db_data, AccountId::from(0), anchor_height).unwrap(), get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
);
assert_eq!(
get_balance_at(&db_data.conn, AccountId::from(0), anchor_height).unwrap(),
value value
); );
let to = TransparentAddress::PublicKey([7; 20]).into(); let to = TransparentAddress::PublicKey([7; 20]).into();
assert_matches!( assert_matches!(
create_spend_to_address( create_spend_to_address(
&mut db_write, &mut db_data,
&tests::network(), &tests::network(),
test_prover(), test_prover(),
&usk, &usk,
@ -1085,9 +1099,8 @@ mod tests {
init_wallet_db(&mut db_data, None).unwrap(); init_wallet_db(&mut db_data, None).unwrap();
// Add an account to the wallet // Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec()); let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = ops.create_account(&seed).unwrap(); let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
// Add funds to the wallet // Add funds to the wallet
@ -1112,15 +1125,17 @@ mod tests {
insert_into_cache(&db_cache, &cb); insert_into_cache(&db_cache, &cb);
} }
let mut db_write = db_data.get_update_ops().unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Verified balance matches total balance // Verified balance matches total balance
let total = Amount::from_u64(60000).unwrap(); let total = Amount::from_u64(60000).unwrap();
let (_, anchor_height) = db_data.get_target_and_anchor_heights(1).unwrap().unwrap(); let (_, anchor_height) = db_data.get_target_and_anchor_heights(1).unwrap().unwrap();
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), total);
assert_eq!( assert_eq!(
get_balance_at(&db_data, AccountId::from(0), anchor_height).unwrap(), get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
total
);
assert_eq!(
get_balance_at(&db_data.conn, AccountId::from(0), anchor_height).unwrap(),
total total
); );
@ -1142,7 +1157,7 @@ mod tests {
assert_matches!( assert_matches!(
spend( spend(
&mut db_write, &mut db_data,
&tests::network(), &tests::network(),
test_prover(), test_prover(),
&input_selector, &input_selector,
@ -1170,7 +1185,7 @@ mod tests {
assert_matches!( assert_matches!(
spend( spend(
&mut db_write, &mut db_data,
&tests::network(), &tests::network(),
test_prover(), test_prover(),
&input_selector, &input_selector,
@ -1195,9 +1210,8 @@ mod tests {
init_wallet_db(&mut db_data, None).unwrap(); init_wallet_db(&mut db_data, None).unwrap();
// Add an account to the wallet // Add an account to the wallet
let mut db_write = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec()); let seed = Secret::new([0u8; 32].to_vec());
let (account_id, usk) = db_write.create_account(&seed).unwrap(); let (account_id, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key(); let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
let uaddr = db_data.get_current_address(account_id).unwrap().unwrap(); let uaddr = db_data.get_current_address(account_id).unwrap().unwrap();
let taddr = uaddr.transparent().unwrap(); let taddr = uaddr.transparent().unwrap();
@ -1212,7 +1226,7 @@ mod tests {
) )
.unwrap(); .unwrap();
let res0 = db_write.put_received_transparent_utxo(&utxo); let res0 = db_data.put_received_transparent_utxo(&utxo);
assert!(matches!(res0, Ok(_))); assert!(matches!(res0, Ok(_)));
let input_selector = GreedyInputSelector::new( let input_selector = GreedyInputSelector::new(
@ -1229,11 +1243,11 @@ mod tests {
Amount::from_u64(50000).unwrap(), Amount::from_u64(50000).unwrap(),
); );
insert_into_cache(&db_cache, &cb); insert_into_cache(&db_cache, &cb);
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
assert_matches!( assert_matches!(
shield_transparent_funds( shield_transparent_funds(
&mut db_write, &mut db_data,
&tests::network(), &tests::network(),
test_prover(), test_prover(),
&input_selector, &input_selector,