zcash_client_sqlite: Remove the remainder of DataConnStmtCache

This commit is contained in:
Kris Nuttycombe 2023-06-09 11:02:00 -06:00
parent bf7f05282f
commit 2674209818
7 changed files with 631 additions and 816 deletions

View File

@ -299,7 +299,7 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet
let (dfvk, _taddr) = init_test_accounts_table(&db_data);
let (dfvk, _taddr) = init_test_accounts_table(&mut db_data);
// Empty chain should return None
assert_matches!(db_data.get_max_height_hash(), Ok(None));
@ -328,8 +328,7 @@ mod tests {
assert_matches!(validate_chain_result, Ok(()));
// Scan the cache
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Data-only chain should be valid
validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap();
@ -348,7 +347,7 @@ mod tests {
validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap();
// Scan the cache again
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Data-only chain should be valid
validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap();
@ -365,7 +364,7 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet
let (dfvk, _taddr) = init_test_accounts_table(&db_data);
let (dfvk, _taddr) = init_test_accounts_table(&mut db_data);
// Create some fake CompactBlocks
let (cb, _) = fake_compact_block(
@ -386,8 +385,7 @@ mod tests {
insert_into_cache(&db_cache, &cb2);
// Scan the cache
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Data-only chain should be valid
validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap();
@ -427,7 +425,7 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet
let (dfvk, _taddr) = init_test_accounts_table(&db_data);
let (dfvk, _taddr) = init_test_accounts_table(&mut db_data);
// Create some fake CompactBlocks
let (cb, _) = fake_compact_block(
@ -448,8 +446,7 @@ mod tests {
insert_into_cache(&db_cache, &cb2);
// Scan the cache
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Data-only chain should be valid
validate_chain(&db_cache, db_data.get_max_height_hash().unwrap(), None).unwrap();
@ -489,11 +486,11 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet
let (dfvk, _taddr) = init_test_accounts_table(&db_data);
let (dfvk, _taddr) = init_test_accounts_table(&mut db_data);
// Account balance should be zero
assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(),
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
Amount::zero()
);
@ -519,36 +516,46 @@ mod tests {
insert_into_cache(&db_cache, &cb2);
// Scan the cache
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Account balance should reflect both received notes
assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(),
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
(value + value2).unwrap()
);
// "Rewind" to height of last scanned block
truncate_to_height(&db_data, sapling_activation_height() + 1).unwrap();
db_data
.transactionally(|wdb| {
truncate_to_height(&wdb.conn.0, &wdb.params, sapling_activation_height() + 1)
})
.unwrap();
// Account balance should be unaltered
assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(),
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
(value + value2).unwrap()
);
// Rewind so that one block is dropped
truncate_to_height(&db_data, sapling_activation_height()).unwrap();
db_data
.transactionally(|wdb| {
truncate_to_height(&wdb.conn.0, &wdb.params, sapling_activation_height())
})
.unwrap();
// Account balance should only contain the first received note
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value);
assert_eq!(
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
);
// Scan the cache again
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Account balance should again reflect both received notes
assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(),
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
(value + value2).unwrap()
);
}
@ -564,7 +571,7 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet
let (dfvk, _taddr) = init_test_accounts_table(&db_data);
let (dfvk, _taddr) = init_test_accounts_table(&mut db_data);
// Create a block with height SAPLING_ACTIVATION_HEIGHT
let value = Amount::from_u64(50000).unwrap();
@ -576,9 +583,11 @@ mod tests {
value,
);
insert_into_cache(&db_cache, &cb1);
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value);
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
assert_eq!(
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
);
// We cannot scan a block of height SAPLING_ACTIVATION_HEIGHT + 2 next
let (cb2, _) = fake_compact_block(
@ -596,7 +605,7 @@ mod tests {
value,
);
insert_into_cache(&db_cache, &cb3);
match scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None) {
match scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None) {
Err(Error::Chain(e)) => {
assert_matches!(
e.cause(),
@ -609,9 +618,9 @@ mod tests {
// If we add a block of height SAPLING_ACTIVATION_HEIGHT + 1, we can now scan both
insert_into_cache(&db_cache, &cb2);
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(),
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
Amount::from_u64(150_000).unwrap()
);
}
@ -627,11 +636,11 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet
let (dfvk, _taddr) = init_test_accounts_table(&db_data);
let (dfvk, _taddr) = init_test_accounts_table(&mut db_data);
// Account balance should be zero
assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(),
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
Amount::zero()
);
@ -647,11 +656,13 @@ mod tests {
insert_into_cache(&db_cache, &cb);
// Scan the cache
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Account balance should reflect the received note
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value);
assert_eq!(
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
);
// Create a second fake CompactBlock sending more value to the address
let value2 = Amount::from_u64(7).unwrap();
@ -665,11 +676,11 @@ mod tests {
insert_into_cache(&db_cache, &cb2);
// Scan the cache again
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Account balance should reflect both received notes
assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(),
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
(value + value2).unwrap()
);
}
@ -685,11 +696,11 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet
let (dfvk, _taddr) = init_test_accounts_table(&db_data);
let (dfvk, _taddr) = init_test_accounts_table(&mut db_data);
// Account balance should be zero
assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(),
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
Amount::zero()
);
@ -705,11 +716,13 @@ mod tests {
insert_into_cache(&db_cache, &cb);
// Scan the cache
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Account balance should reflect the received note
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value);
assert_eq!(
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
);
// Create a second fake CompactBlock spending value from the address
let extsk2 = ExtendedSpendingKey::master(&[0]);
@ -728,11 +741,11 @@ mod tests {
);
// Scan the cache again
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Account balance should equal the change
assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(),
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
(value - value2).unwrap()
);
}

View File

@ -32,11 +32,9 @@
// Catch documentation errors caused by code changes.
#![deny(rustdoc::broken_intra_doc_links)]
use rusqlite::Connection;
use rusqlite::{self, Connection};
use secrecy::{ExposeSecret, SecretVec};
use std::collections::HashMap;
use std::fmt;
use std::path::Path;
use std::{borrow::Borrow, collections::HashMap, convert::AsRef, fmt, path::Path};
use zcash_primitives::{
block::BlockHash,
@ -72,9 +70,6 @@ use {
std::{fs, io},
};
mod prepared;
pub use prepared::DataConnStmtCache;
pub mod chain;
pub mod error;
pub mod wallet;
@ -107,12 +102,21 @@ impl fmt::Display for NoteId {
pub struct UtxoId(pub i64);
/// A wrapper for the SQLite connection to the wallet database.
pub struct WalletDb<P> {
conn: Connection,
pub struct WalletDb<C, P> {
conn: C,
params: P,
}
impl<P: consensus::Parameters> WalletDb<P> {
/// A wrapper for a SQLite transaction affecting the wallet database.
pub struct WalletTransaction<'conn>(pub(crate) rusqlite::Transaction<'conn>);
impl Borrow<rusqlite::Connection> for WalletTransaction<'_> {
fn borrow(&self) -> &rusqlite::Connection {
&self.0
}
}
impl<P: consensus::Parameters + Clone> WalletDb<Connection, P> {
/// Construct a connection to the wallet database stored at the specified path.
pub fn for_path<F: AsRef<Path>>(path: F, params: P) -> Result<Self, rusqlite::Error> {
Connection::open(path).and_then(move |conn| {
@ -121,53 +125,60 @@ impl<P: consensus::Parameters> WalletDb<P> {
})
}
/// Given a wallet database connection, obtain a handle for the write operations
/// for that database. This operation may eagerly initialize and cache sqlite
/// prepared statements that are used in write operations.
pub fn get_update_ops(&self) -> Result<DataConnStmtCache<'_, P>, SqliteClientError> {
DataConnStmtCache::new(self)
pub fn transactionally<F, A>(&mut self, f: F) -> Result<A, SqliteClientError>
where
F: FnOnce(&WalletDb<WalletTransaction<'_>, P>) -> Result<A, SqliteClientError>,
{
let wdb = WalletDb {
conn: WalletTransaction(self.conn.transaction()?),
params: self.params.clone(),
};
let result = f(&wdb)?;
wdb.conn.0.commit()?;
Ok(result)
}
}
impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
impl<C: Borrow<rusqlite::Connection>, P: consensus::Parameters> WalletRead for WalletDb<C, P> {
type Error = SqliteClientError;
type NoteRef = NoteId;
type TxRef = i64;
fn block_height_extrema(&self) -> Result<Option<(BlockHeight, BlockHeight)>, Self::Error> {
wallet::block_height_extrema(self).map_err(SqliteClientError::from)
wallet::block_height_extrema(self.conn.borrow()).map_err(SqliteClientError::from)
}
fn get_min_unspent_height(&self) -> Result<Option<BlockHeight>, Self::Error> {
wallet::get_min_unspent_height(self).map_err(SqliteClientError::from)
wallet::get_min_unspent_height(self.conn.borrow()).map_err(SqliteClientError::from)
}
fn get_block_hash(&self, block_height: BlockHeight) -> Result<Option<BlockHash>, Self::Error> {
wallet::get_block_hash(self, block_height).map_err(SqliteClientError::from)
wallet::get_block_hash(self.conn.borrow(), block_height).map_err(SqliteClientError::from)
}
fn get_tx_height(&self, txid: TxId) -> Result<Option<BlockHeight>, Self::Error> {
wallet::get_tx_height(self, txid).map_err(SqliteClientError::from)
wallet::get_tx_height(self.conn.borrow(), txid).map_err(SqliteClientError::from)
}
fn get_unified_full_viewing_keys(
&self,
) -> Result<HashMap<AccountId, UnifiedFullViewingKey>, Self::Error> {
wallet::get_unified_full_viewing_keys(self)
wallet::get_unified_full_viewing_keys(self.conn.borrow(), &self.params)
}
fn get_account_for_ufvk(
&self,
ufvk: &UnifiedFullViewingKey,
) -> Result<Option<AccountId>, Self::Error> {
wallet::get_account_for_ufvk(self, ufvk)
wallet::get_account_for_ufvk(self.conn.borrow(), &self.params, ufvk)
}
fn get_current_address(
&self,
account: AccountId,
) -> Result<Option<UnifiedAddress>, Self::Error> {
wallet::get_current_address(self, account).map(|res| res.map(|(addr, _)| addr))
wallet::get_current_address(self.conn.borrow(), &self.params, account)
.map(|res| res.map(|(addr, _)| addr))
}
fn is_valid_account_extfvk(
@ -175,7 +186,7 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
account: AccountId,
extfvk: &ExtendedFullViewingKey,
) -> Result<bool, Self::Error> {
wallet::is_valid_account_extfvk(self, account, extfvk)
wallet::is_valid_account_extfvk(self.conn.borrow(), &self.params, account, extfvk)
}
fn get_balance_at(
@ -183,17 +194,19 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
account: AccountId,
anchor_height: BlockHeight,
) -> Result<Amount, Self::Error> {
wallet::get_balance_at(self, account, anchor_height)
wallet::get_balance_at(self.conn.borrow(), account, anchor_height)
}
fn get_transaction(&self, id_tx: i64) -> Result<Transaction, Self::Error> {
wallet::get_transaction(self, id_tx)
wallet::get_transaction(self.conn.borrow(), &self.params, id_tx)
}
fn get_memo(&self, id_note: Self::NoteRef) -> Result<Option<Memo>, Self::Error> {
match id_note {
NoteId::SentNoteId(id_note) => wallet::get_sent_memo(self, id_note),
NoteId::ReceivedNoteId(id_note) => wallet::get_received_memo(self, id_note),
NoteId::SentNoteId(id_note) => wallet::get_sent_memo(self.conn.borrow(), id_note),
NoteId::ReceivedNoteId(id_note) => {
wallet::get_received_memo(self.conn.borrow(), id_note)
}
}
}
@ -201,7 +214,7 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
&self,
block_height: BlockHeight,
) -> Result<Option<sapling::CommitmentTree>, Self::Error> {
wallet::sapling::get_sapling_commitment_tree(self, block_height)
wallet::sapling::get_sapling_commitment_tree(self.conn.borrow(), block_height)
}
#[allow(clippy::type_complexity)]
@ -209,7 +222,7 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
&self,
block_height: BlockHeight,
) -> Result<Vec<(Self::NoteRef, sapling::IncrementalWitness)>, Self::Error> {
wallet::sapling::get_sapling_witnesses(self, block_height)
wallet::sapling::get_sapling_witnesses(self.conn.borrow(), block_height)
}
fn get_sapling_nullifiers(
@ -217,8 +230,8 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
query: data_api::NullifierQuery,
) -> Result<Vec<(AccountId, sapling::Nullifier)>, Self::Error> {
match query {
NullifierQuery::Unspent => wallet::sapling::get_sapling_nullifiers(self),
NullifierQuery::All => wallet::sapling::get_all_sapling_nullifiers(&self.conn),
NullifierQuery::Unspent => wallet::sapling::get_sapling_nullifiers(self.conn.borrow()),
NullifierQuery::All => wallet::sapling::get_all_sapling_nullifiers(self.conn.borrow()),
}
}
@ -228,7 +241,12 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
anchor_height: BlockHeight,
exclude: &[Self::NoteRef],
) -> Result<Vec<ReceivedSaplingNote<Self::NoteRef>>, Self::Error> {
wallet::sapling::get_spendable_sapling_notes(&self.conn, account, anchor_height, exclude)
wallet::sapling::get_spendable_sapling_notes(
self.conn.borrow(),
account,
anchor_height,
exclude,
)
}
fn select_spendable_sapling_notes(
@ -239,7 +257,7 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
exclude: &[Self::NoteRef],
) -> Result<Vec<ReceivedSaplingNote<Self::NoteRef>>, Self::Error> {
wallet::sapling::select_spendable_sapling_notes(
&self.conn,
self.conn.borrow(),
account,
target_value,
anchor_height,
@ -252,7 +270,7 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
_account: AccountId,
) -> Result<HashMap<TransparentAddress, AddressMetadata>, Self::Error> {
#[cfg(feature = "transparent-inputs")]
return wallet::get_transparent_receivers(&self.params, &self.conn, _account);
return wallet::get_transparent_receivers(self.conn.borrow(), &self.params, _account);
#[cfg(not(feature = "transparent-inputs"))]
panic!(
@ -267,7 +285,13 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
_exclude: &[OutPoint],
) -> Result<Vec<WalletTransparentOutput>, Self::Error> {
#[cfg(feature = "transparent-inputs")]
return wallet::get_unspent_transparent_outputs(self, _address, _max_height, _exclude);
return wallet::get_unspent_transparent_outputs(
self.conn.borrow(),
&self.params,
_address,
_max_height,
_exclude,
);
#[cfg(not(feature = "transparent-inputs"))]
panic!(
@ -281,7 +305,12 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
_max_height: BlockHeight,
) -> Result<HashMap<TransparentAddress, Amount>, Self::Error> {
#[cfg(feature = "transparent-inputs")]
return wallet::get_transparent_balances(self, _account, _max_height);
return wallet::get_transparent_balances(
self.conn.borrow(),
&self.params,
_account,
_max_height,
);
#[cfg(not(feature = "transparent-inputs"))]
panic!(
@ -290,177 +319,15 @@ impl<P: consensus::Parameters> WalletRead for WalletDb<P> {
}
}
impl<'a, P: consensus::Parameters> WalletRead for DataConnStmtCache<'a, P> {
type Error = SqliteClientError;
type NoteRef = NoteId;
type TxRef = i64;
fn block_height_extrema(&self) -> Result<Option<(BlockHeight, BlockHeight)>, Self::Error> {
self.wallet_db.block_height_extrema()
}
fn get_min_unspent_height(&self) -> Result<Option<BlockHeight>, Self::Error> {
self.wallet_db.get_min_unspent_height()
}
fn get_block_hash(&self, block_height: BlockHeight) -> Result<Option<BlockHash>, Self::Error> {
self.wallet_db.get_block_hash(block_height)
}
fn get_tx_height(&self, txid: TxId) -> Result<Option<BlockHeight>, Self::Error> {
self.wallet_db.get_tx_height(txid)
}
fn get_unified_full_viewing_keys(
&self,
) -> Result<HashMap<AccountId, UnifiedFullViewingKey>, Self::Error> {
self.wallet_db.get_unified_full_viewing_keys()
}
fn get_account_for_ufvk(
&self,
ufvk: &UnifiedFullViewingKey,
) -> Result<Option<AccountId>, Self::Error> {
self.wallet_db.get_account_for_ufvk(ufvk)
}
fn get_current_address(
&self,
account: AccountId,
) -> Result<Option<UnifiedAddress>, Self::Error> {
self.wallet_db.get_current_address(account)
}
fn is_valid_account_extfvk(
&self,
account: AccountId,
extfvk: &ExtendedFullViewingKey,
) -> Result<bool, Self::Error> {
self.wallet_db.is_valid_account_extfvk(account, extfvk)
}
fn get_balance_at(
&self,
account: AccountId,
anchor_height: BlockHeight,
) -> Result<Amount, Self::Error> {
self.wallet_db.get_balance_at(account, anchor_height)
}
fn get_transaction(&self, id_tx: i64) -> Result<Transaction, Self::Error> {
self.wallet_db.get_transaction(id_tx)
}
fn get_memo(&self, id_note: Self::NoteRef) -> Result<Option<Memo>, Self::Error> {
self.wallet_db.get_memo(id_note)
}
fn get_commitment_tree(
&self,
block_height: BlockHeight,
) -> Result<Option<sapling::CommitmentTree>, Self::Error> {
self.wallet_db.get_commitment_tree(block_height)
}
#[allow(clippy::type_complexity)]
fn get_witnesses(
&self,
block_height: BlockHeight,
) -> Result<Vec<(Self::NoteRef, sapling::IncrementalWitness)>, Self::Error> {
self.wallet_db.get_witnesses(block_height)
}
fn get_sapling_nullifiers(
&self,
query: data_api::NullifierQuery,
) -> Result<Vec<(AccountId, sapling::Nullifier)>, Self::Error> {
self.wallet_db.get_sapling_nullifiers(query)
}
fn get_spendable_sapling_notes(
&self,
account: AccountId,
anchor_height: BlockHeight,
exclude: &[Self::NoteRef],
) -> Result<Vec<ReceivedSaplingNote<Self::NoteRef>>, Self::Error> {
self.wallet_db
.get_spendable_sapling_notes(account, anchor_height, exclude)
}
fn select_spendable_sapling_notes(
&self,
account: AccountId,
target_value: Amount,
anchor_height: BlockHeight,
exclude: &[Self::NoteRef],
) -> Result<Vec<ReceivedSaplingNote<Self::NoteRef>>, Self::Error> {
self.wallet_db
.select_spendable_sapling_notes(account, target_value, anchor_height, exclude)
}
fn get_transparent_receivers(
&self,
account: AccountId,
) -> Result<HashMap<TransparentAddress, AddressMetadata>, Self::Error> {
self.wallet_db.get_transparent_receivers(account)
}
fn get_unspent_transparent_outputs(
&self,
address: &TransparentAddress,
max_height: BlockHeight,
exclude: &[OutPoint],
) -> Result<Vec<WalletTransparentOutput>, Self::Error> {
self.wallet_db
.get_unspent_transparent_outputs(address, max_height, exclude)
}
fn get_transparent_balances(
&self,
account: AccountId,
max_height: BlockHeight,
) -> Result<HashMap<TransparentAddress, Amount>, Self::Error> {
self.wallet_db.get_transparent_balances(account, max_height)
}
}
impl<'a, P: consensus::Parameters> DataConnStmtCache<'a, P> {
fn transactionally<F, A>(&mut self, f: F) -> Result<A, SqliteClientError>
where
F: FnOnce(&mut Self) -> Result<A, SqliteClientError>,
{
self.wallet_db.conn.execute("BEGIN IMMEDIATE", [])?;
match f(self) {
Ok(result) => {
self.wallet_db.conn.execute("COMMIT", [])?;
Ok(result)
}
Err(error) => {
match self.wallet_db.conn.execute("ROLLBACK", []) {
Ok(_) => Err(error),
Err(e) =>
// Panicking here is probably the right thing to do, because it
// means the database is corrupt.
panic!(
"Rollback failed with error {} while attempting to recover from error {}; database is likely corrupt.",
e,
error
)
}
}
}
}
}
impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P> {
type UtxoRef = UtxoId;
fn create_account(
&mut self,
seed: &SecretVec<u8>,
) -> Result<(AccountId, UnifiedSpendingKey), Self::Error> {
self.transactionally(|stmts| {
let account = wallet::get_max_account_id(stmts.wallet_db)?
self.transactionally(|wdb| {
let account = wallet::get_max_account_id(&wdb.conn.0)?
.map(|a| AccountId::from(u32::from(a) + 1))
.unwrap_or_else(|| AccountId::from(0));
@ -468,15 +335,11 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
return Err(SqliteClientError::AccountIdOutOfRange);
}
let usk = UnifiedSpendingKey::from_seed(
&stmts.wallet_db.params,
seed.expose_secret(),
account,
)
.map_err(|_| SqliteClientError::KeyDerivationError(account))?;
let usk = UnifiedSpendingKey::from_seed(&wdb.params, seed.expose_secret(), account)
.map_err(|_| SqliteClientError::KeyDerivationError(account))?;
let ufvk = usk.to_unified_full_viewing_key();
wallet::add_account(stmts.wallet_db, account, &ufvk)?;
wallet::add_account(&wdb.conn.0, &wdb.params, account, &ufvk)?;
Ok((account, usk))
})
@ -486,34 +349,37 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
&mut self,
account: AccountId,
) -> Result<Option<UnifiedAddress>, Self::Error> {
match self.get_unified_full_viewing_keys()?.get(&account) {
Some(ufvk) => {
let search_from = match wallet::get_current_address(self.wallet_db, account)? {
Some((_, mut last_diversifier_index)) => {
last_diversifier_index
.increment()
.map_err(|_| SqliteClientError::DiversifierIndexOutOfRange)?;
last_diversifier_index
}
None => DiversifierIndex::default(),
};
self.transactionally(
|wdb| match wdb.get_unified_full_viewing_keys()?.get(&account) {
Some(ufvk) => {
let search_from =
match wallet::get_current_address(&wdb.conn.0, &wdb.params, account)? {
Some((_, mut last_diversifier_index)) => {
last_diversifier_index
.increment()
.map_err(|_| SqliteClientError::DiversifierIndexOutOfRange)?;
last_diversifier_index
}
None => DiversifierIndex::default(),
};
let (addr, diversifier_index) = ufvk
.find_address(search_from)
.ok_or(SqliteClientError::DiversifierIndexOutOfRange)?;
let (addr, diversifier_index) = ufvk
.find_address(search_from)
.ok_or(SqliteClientError::DiversifierIndexOutOfRange)?;
wallet::insert_address(
&self.wallet_db.conn,
&self.wallet_db.params,
account,
diversifier_index,
&addr,
)?;
wallet::insert_address(
&wdb.conn.0,
&wdb.params,
account,
diversifier_index,
&addr,
)?;
Ok(Some(addr))
}
None => Ok(None),
}
Ok(Some(addr))
}
None => Ok(None),
},
)
}
#[tracing::instrument(skip_all, fields(height = u32::from(block.block_height)))]
@ -523,11 +389,10 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
block: &PrunedBlock,
updated_witnesses: &[(Self::NoteRef, sapling::IncrementalWitness)],
) -> Result<Vec<(Self::NoteRef, sapling::IncrementalWitness)>, Self::Error> {
// database updates for each block are transactional
self.transactionally(|up| {
self.transactionally(|wdb| {
// Insert the block into the database.
wallet::insert_block(
up,
&wdb.conn.0,
block.block_height,
block.block_hash,
block.block_time,
@ -536,20 +401,16 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
let mut new_witnesses = vec![];
for tx in block.transactions {
let tx_row = wallet::put_tx_meta(up, tx, block.block_height)?;
let tx_row = wallet::put_tx_meta(&wdb.conn.0, tx, block.block_height)?;
// Mark notes as spent and remove them from the scanning cache
for spend in &tx.sapling_spends {
wallet::sapling::mark_sapling_note_spent(
&up.wallet_db.conn,
tx_row,
spend.nf(),
)?;
wallet::sapling::mark_sapling_note_spent(&wdb.conn.0, tx_row, spend.nf())?;
}
for output in &tx.sapling_outputs {
let received_note_id =
wallet::sapling::put_received_note(&up.wallet_db.conn, output, tx_row)?;
wallet::sapling::put_received_note(&wdb.conn.0, output, tx_row)?;
// Save witness for note.
new_witnesses.push((received_note_id, output.witness().clone()));
@ -560,17 +421,22 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
for (received_note_id, witness) in updated_witnesses.iter().chain(new_witnesses.iter())
{
if let NoteId::ReceivedNoteId(rnid) = *received_note_id {
wallet::sapling::insert_witness(up, rnid, witness, block.block_height)?;
wallet::sapling::insert_witness(
&wdb.conn.0,
rnid,
witness,
block.block_height,
)?;
} else {
return Err(SqliteClientError::InvalidNoteId);
}
}
// Prune the stored witnesses (we only expect rollbacks of at most PRUNING_HEIGHT blocks).
wallet::prune_witnesses(up, block.block_height - PRUNING_HEIGHT)?;
wallet::prune_witnesses(&wdb.conn.0, block.block_height - PRUNING_HEIGHT)?;
// Update now-expired transactions that didn't get mined.
wallet::update_expired_notes(up, block.block_height)?;
wallet::update_expired_notes(&wdb.conn.0, block.block_height)?;
Ok(new_witnesses)
})
@ -580,93 +446,114 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
&mut self,
d_tx: DecryptedTransaction,
) -> Result<Self::TxRef, Self::Error> {
self.transactionally(|up| {
let tx_ref = wallet::put_tx_data(up, d_tx.tx, None, None)?;
self.transactionally(|wdb| {
let tx_ref = wallet::put_tx_data(&wdb.conn.0, d_tx.tx, None, None)?;
let mut spending_account_id: Option<AccountId> = None;
for output in d_tx.sapling_outputs {
match output.transfer_type {
TransferType::Outgoing | TransferType::WalletInternal => {
let recipient = if output.transfer_type == TransferType::Outgoing {
Recipient::Sapling(output.note.recipient())
} else {
Recipient::InternalAccount(output.account, PoolType::Sapling)
};
let mut spending_account_id: Option<AccountId> = None;
for output in d_tx.sapling_outputs {
match output.transfer_type {
TransferType::Outgoing | TransferType::WalletInternal => {
let recipient = if output.transfer_type == TransferType::Outgoing {
Recipient::Sapling(output.note.recipient())
} else {
Recipient::InternalAccount(output.account, PoolType::Sapling)
};
wallet::put_sent_output(
&up.wallet_db.conn,
&up.wallet_db.params,
output.account,
tx_ref,
output.index,
&recipient,
Amount::from_u64(output.note.value().inner()).map_err(|_|
SqliteClientError::CorruptedData("Note value is not a valid Zcash amount.".to_string()))?,
Some(&output.memo),
)?;
wallet::put_sent_output(
&wdb.conn.0,
&wdb.params,
output.account,
tx_ref,
output.index,
&recipient,
Amount::from_u64(output.note.value().inner()).map_err(|_| {
SqliteClientError::CorruptedData(
"Note value is not a valid Zcash amount.".to_string(),
)
})?,
Some(&output.memo),
)?;
if matches!(recipient, Recipient::InternalAccount(_, _)) {
wallet::sapling::put_received_note(&up.wallet_db.conn, output, tx_ref)?;
}
if matches!(recipient, Recipient::InternalAccount(_, _)) {
wallet::sapling::put_received_note(&wdb.conn.0, output, tx_ref)?;
}
TransferType::Incoming => {
match spending_account_id {
Some(id) =>
if id != output.account {
panic!("Unable to determine a unique account identifier for z->t spend.");
}
None => {
spending_account_id = Some(output.account);
}
TransferType::Incoming => {
match spending_account_id {
Some(id) => {
if id != output.account {
panic!("Unable to determine a unique account identifier for z->t spend.");
}
}
wallet::sapling::put_received_note(&up.wallet_db.conn, output, tx_ref)?;
}
}
}
// If any of the utxos spent in the transaction are ours, mark them as spent.
#[cfg(feature = "transparent-inputs")]
for txin in d_tx.tx.transparent_bundle().iter().flat_map(|b| b.vin.iter()) {
wallet::mark_transparent_utxo_spent(&up.wallet_db.conn, tx_ref, &txin.prevout)?;
}
// If we have some transparent outputs:
if !d_tx.tx.transparent_bundle().iter().any(|b| b.vout.is_empty()) {
let nullifiers = self.wallet_db.get_sapling_nullifiers(data_api::NullifierQuery::All)?;
// If the transaction contains shielded spends from our wallet, we will store z->t
// transactions we observe in the same way they would be stored by
// create_spend_to_address.
if let Some((account_id, _)) = nullifiers.iter().find(
|(_, nf)|
d_tx.tx.sapling_bundle().iter().flat_map(|b| b.shielded_spends().iter())
.any(|input| nf == input.nullifier())
) {
for (output_index, txout) in d_tx.tx.transparent_bundle().iter().flat_map(|b| b.vout.iter()).enumerate() {
if let Some(address) = txout.recipient_address() {
wallet::put_sent_output(
&up.wallet_db.conn,
&up.wallet_db.params,
*account_id,
tx_ref,
output_index,
&Recipient::Transparent(address),
txout.value,
None
)?;
None => {
spending_account_id = Some(output.account);
}
}
wallet::sapling::put_received_note(&wdb.conn.0, output, tx_ref)?;
}
}
Ok(tx_ref)
}
// If any of the utxos spent in the transaction are ours, mark them as spent.
#[cfg(feature = "transparent-inputs")]
for txin in d_tx
.tx
.transparent_bundle()
.iter()
.flat_map(|b| b.vin.iter())
{
wallet::mark_transparent_utxo_spent(&wdb.conn.0, tx_ref, &txin.prevout)?;
}
// If we have some transparent outputs:
if !d_tx
.tx
.transparent_bundle()
.iter()
.any(|b| b.vout.is_empty())
{
let nullifiers = wdb.get_sapling_nullifiers(data_api::NullifierQuery::All)?;
// If the transaction contains shielded spends from our wallet, we will store z->t
// transactions we observe in the same way they would be stored by
// create_spend_to_address.
if let Some((account_id, _)) = nullifiers.iter().find(|(_, nf)| {
d_tx.tx
.sapling_bundle()
.iter()
.flat_map(|b| b.shielded_spends().iter())
.any(|input| nf == input.nullifier())
}) {
for (output_index, txout) in d_tx
.tx
.transparent_bundle()
.iter()
.flat_map(|b| b.vout.iter())
.enumerate()
{
if let Some(address) = txout.recipient_address() {
wallet::put_sent_output(
&wdb.conn.0,
&wdb.params,
*account_id,
tx_ref,
output_index,
&Recipient::Transparent(address),
txout.value,
None,
)?;
}
}
}
}
Ok(tx_ref)
})
}
fn store_sent_tx(&mut self, sent_tx: &SentTransaction) -> Result<Self::TxRef, Self::Error> {
// Update the database atomically, to ensure the result is internally consistent.
self.transactionally(|up| {
self.transactionally(|wdb| {
let tx_ref = wallet::put_tx_data(
up,
&wdb.conn.0,
sent_tx.tx,
Some(sent_tx.fee_amount),
Some(sent_tx.created),
@ -683,7 +570,7 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
if let Some(bundle) = sent_tx.tx.sapling_bundle() {
for spend in bundle.shielded_spends() {
wallet::sapling::mark_sapling_note_spent(
&up.wallet_db.conn,
&wdb.conn.0,
tx_ref,
spend.nullifier(),
)?;
@ -692,13 +579,13 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
#[cfg(feature = "transparent-inputs")]
for utxo_outpoint in &sent_tx.utxos_spent {
wallet::mark_transparent_utxo_spent(&up.wallet_db.conn, tx_ref, utxo_outpoint)?;
wallet::mark_transparent_utxo_spent(&wdb.conn.0, tx_ref, utxo_outpoint)?;
}
for output in &sent_tx.outputs {
wallet::insert_sent_output(
&up.wallet_db.conn,
&up.wallet_db.params,
&wdb.conn.0,
&wdb.params,
tx_ref,
sent_tx.account,
output,
@ -706,7 +593,7 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
if let Some((account, note)) = output.sapling_change_to() {
wallet::sapling::put_received_note(
&up.wallet_db.conn,
&wdb.conn.0,
&DecryptedOutput {
index: output.output_index(),
note: note.clone(),
@ -727,7 +614,9 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
}
fn truncate_to_height(&mut self, block_height: BlockHeight) -> Result<(), Self::Error> {
wallet::truncate_to_height(self.wallet_db, block_height)
self.transactionally(|wdb| {
wallet::truncate_to_height(&wdb.conn.0, &wdb.params, block_height)
})
}
fn put_received_transparent_utxo(
@ -735,11 +624,7 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
_output: &WalletTransparentOutput,
) -> Result<Self::UtxoRef, Self::Error> {
#[cfg(feature = "transparent-inputs")]
return wallet::put_received_transparent_utxo(
&self.wallet_db.conn,
&self.wallet_db.params,
_output,
);
return wallet::put_received_transparent_utxo(&self.conn, &self.params, _output);
#[cfg(not(feature = "transparent-inputs"))]
panic!(
@ -1083,7 +968,7 @@ mod tests {
#[cfg(test)]
pub(crate) fn init_test_accounts_table(
db_data: &WalletDb<Network>,
db_data: &mut WalletDb<rusqlite::Connection, Network>,
) -> (DiversifiableFullViewingKey, Option<TransparentAddress>) {
let (ufvk, taddr) = init_test_accounts_table_ufvk(db_data);
(ufvk.sapling().unwrap().clone(), taddr)
@ -1091,7 +976,7 @@ mod tests {
#[cfg(test)]
pub(crate) fn init_test_accounts_table_ufvk(
db_data: &WalletDb<Network>,
db_data: &mut WalletDb<rusqlite::Connection, Network>,
) -> (UnifiedFullViewingKey, Option<TransparentAddress>) {
let seed = [0u8; 32];
let account = AccountId::from(0);
@ -1318,13 +1203,12 @@ mod tests {
let account = AccountId::from(0);
init_wallet_db(&mut db_data, None).unwrap();
let _ = init_test_accounts_table_ufvk(&db_data);
init_test_accounts_table_ufvk(&mut db_data);
let current_addr = db_data.get_current_address(account).unwrap();
assert!(current_addr.is_some());
let mut update_ops = db_data.get_update_ops().unwrap();
let addr2 = update_ops.get_next_available_address(account).unwrap();
let addr2 = db_data.get_next_available_address(account).unwrap();
assert!(addr2.is_some());
assert_ne!(current_addr, addr2);
@ -1349,7 +1233,7 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet.
let (ufvk, taddr) = init_test_accounts_table_ufvk(&db_data);
let (ufvk, taddr) = init_test_accounts_table_ufvk(&mut db_data);
let taddr = taddr.unwrap();
let receivers = db_data.get_transparent_receivers(0.into()).unwrap();

View File

@ -1,95 +0,0 @@
//! Prepared SQL statements used by the wallet.
//!
//! Some `rusqlite` crate APIs are only available on prepared statements; these are stored
//! inside the [`DataConnStmtCache`]. When adding a new prepared statement:
//!
//! - Add it as a private field of `DataConnStmtCache`.
//! - Build the statement in [`DataConnStmtCache::new`].
//! - Add a crate-private helper method to `DataConnStmtCache` for running the statement.
use rusqlite::{params, Statement};
use zcash_primitives::{consensus::BlockHeight, merkle_tree::write_incremental_witness, sapling};
use crate::{error::SqliteClientError, NoteId, WalletDb};
/// The primary type used to implement [`WalletWrite`] for the SQLite database.
///
/// A data structure that stores the SQLite prepared statements that are
/// required for the implementation of [`WalletWrite`] against the backing
/// store.
///
/// [`WalletWrite`]: zcash_client_backend::data_api::WalletWrite
pub struct DataConnStmtCache<'a, P> {
pub(crate) wallet_db: &'a WalletDb<P>,
stmt_insert_witness: Statement<'a>,
stmt_prune_witnesses: Statement<'a>,
stmt_update_expired: Statement<'a>,
}
impl<'a, P> DataConnStmtCache<'a, P> {
pub(crate) fn new(wallet_db: &'a WalletDb<P>) -> Result<Self, SqliteClientError> {
Ok(
DataConnStmtCache {
wallet_db,
#[cfg(feature = "transparent-inputs")]
stmt_insert_witness: wallet_db.conn.prepare(
"INSERT INTO sapling_witnesses (note, block, witness)
VALUES (?, ?, ?)",
)?,
stmt_prune_witnesses: wallet_db.conn.prepare(
"DELETE FROM sapling_witnesses WHERE block < ?"
)?,
stmt_update_expired: wallet_db.conn.prepare(
"UPDATE sapling_received_notes SET spent = NULL WHERE EXISTS (
SELECT id_tx FROM transactions
WHERE id_tx = sapling_received_notes.spent AND block IS NULL AND expiry_height < ?
)",
)?,
}
)
}
}
impl<'a, P> DataConnStmtCache<'a, P> {
/// Records the incremental witness for the specified note, as of the given block
/// height.
///
/// Returns `SqliteClientError::InvalidNoteId` if the note ID is for a sent note.
pub(crate) fn stmt_insert_witness(
&mut self,
note_id: NoteId,
height: BlockHeight,
witness: &sapling::IncrementalWitness,
) -> Result<(), SqliteClientError> {
let note_id = match note_id {
NoteId::ReceivedNoteId(note_id) => Ok(note_id),
NoteId::SentNoteId(_) => Err(SqliteClientError::InvalidNoteId),
}?;
let mut encoded = Vec::new();
write_incremental_witness(witness, &mut encoded).unwrap();
self.stmt_insert_witness
.execute(params![note_id, u32::from(height), encoded])?;
Ok(())
}
/// Removes old incremental witnesses up to the given block height.
pub(crate) fn stmt_prune_witnesses(
&mut self,
below_height: BlockHeight,
) -> Result<(), SqliteClientError> {
self.stmt_prune_witnesses
.execute([u32::from(below_height)])?;
Ok(())
}
/// Marks notes that have not been mined in transactions as expired, up to the given
/// block height.
pub fn stmt_update_expired(&mut self, height: BlockHeight) -> Result<(), SqliteClientError> {
self.stmt_update_expired.execute([u32::from(height)])?;
Ok(())
}
}

View File

@ -1,4 +1,4 @@
//! Functions for querying information in the wdb database.
//! Functions for querying information in the wallet database.
//!
//! These functions should generally not be used directly; instead,
//! their functionality is available via the [`WalletRead`] and
@ -64,7 +64,7 @@
//! wallet.
//! - `memo` the shielded memo associated with the output, if any.
use rusqlite::{named_params, params, Connection, OptionalExtension, ToSql};
use rusqlite::{self, named_params, params, OptionalExtension, ToSql};
use std::collections::HashMap;
use std::convert::TryFrom;
@ -89,7 +89,7 @@ use zcash_client_backend::{
wallet::WalletTx,
};
use crate::{error::SqliteClientError, DataConnStmtCache, WalletDb, PRUNING_HEIGHT};
use crate::{error::SqliteClientError, PRUNING_HEIGHT};
#[cfg(feature = "transparent-inputs")]
use {
@ -115,28 +115,28 @@ pub(crate) fn pool_code(pool_type: PoolType) -> i64 {
}
}
pub(crate) fn get_max_account_id<P>(
wdb: &WalletDb<P>,
pub(crate) fn get_max_account_id(
conn: &rusqlite::Connection,
) -> Result<Option<AccountId>, SqliteClientError> {
// This returns the most recently generated address.
wdb.conn
.query_row("SELECT MAX(account) FROM accounts", [], |row| {
let account_id: Option<u32> = row.get(0)?;
Ok(account_id.map(AccountId::from))
})
.map_err(SqliteClientError::from)
conn.query_row("SELECT MAX(account) FROM accounts", [], |row| {
let account_id: Option<u32> = row.get(0)?;
Ok(account_id.map(AccountId::from))
})
.map_err(SqliteClientError::from)
}
pub(crate) fn add_account<P: consensus::Parameters>(
wdb: &WalletDb<P>,
conn: &rusqlite::Transaction,
params: &P,
account: AccountId,
key: &UnifiedFullViewingKey,
) -> Result<(), SqliteClientError> {
add_account_internal(&wdb.conn, &wdb.params, "accounts", account, key)
add_account_internal(conn, params, "accounts", account, key)
}
pub(crate) fn add_account_internal<P: consensus::Parameters, E: From<rusqlite::Error>>(
conn: &rusqlite::Connection,
conn: &rusqlite::Transaction,
network: &P,
accounts_table: &'static str,
account: AccountId,
@ -159,12 +159,12 @@ pub(crate) fn add_account_internal<P: consensus::Parameters, E: From<rusqlite::E
}
pub(crate) fn get_current_address<P: consensus::Parameters>(
wdb: &WalletDb<P>,
conn: &rusqlite::Connection,
params: &P,
account: AccountId,
) -> Result<Option<(UnifiedAddress, DiversifierIndex)>, SqliteClientError> {
// This returns the most recently generated address.
let addr: Option<(String, Vec<u8>)> = wdb
.conn
let addr: Option<(String, Vec<u8>)> = conn
.query_row(
"SELECT address, diversifier_index_be
FROM addresses WHERE account = :account
@ -181,7 +181,7 @@ pub(crate) fn get_current_address<P: consensus::Parameters>(
})?;
di_be.reverse();
RecipientAddress::decode(&wdb.params, &addr_str)
RecipientAddress::decode(params, &addr_str)
.ok_or_else(|| {
SqliteClientError::CorruptedData("Not a valid Zcash recipient address".to_owned())
})
@ -201,7 +201,7 @@ pub(crate) fn get_current_address<P: consensus::Parameters>(
///
/// Returns the database row for the newly-inserted address.
pub(crate) fn insert_address<P: consensus::Parameters>(
conn: &Connection,
conn: &rusqlite::Connection,
params: &P,
account: AccountId,
mut diversifier_index: DiversifierIndex,
@ -236,8 +236,8 @@ pub(crate) fn insert_address<P: consensus::Parameters>(
#[cfg(feature = "transparent-inputs")]
pub(crate) fn get_transparent_receivers<P: consensus::Parameters>(
conn: &rusqlite::Connection,
params: &P,
conn: &Connection,
account: AccountId,
) -> Result<HashMap<TransparentAddress, AddressMetadata>, SqliteClientError> {
let mut ret = HashMap::new();
@ -288,7 +288,7 @@ pub(crate) fn get_transparent_receivers<P: consensus::Parameters>(
#[cfg(feature = "transparent-inputs")]
pub(crate) fn get_legacy_transparent_address<P: consensus::Parameters>(
params: &P,
conn: &Connection,
conn: &rusqlite::Connection,
account: AccountId,
) -> Result<Option<(TransparentAddress, DiversifierIndex)>, SqliteClientError> {
// Get the UFVK for the account.
@ -322,18 +322,18 @@ pub(crate) fn get_legacy_transparent_address<P: consensus::Parameters>(
/// Returns the [`UnifiedFullViewingKey`]s for the wallet.
pub(crate) fn get_unified_full_viewing_keys<P: consensus::Parameters>(
wdb: &WalletDb<P>,
conn: &rusqlite::Connection,
params: &P,
) -> Result<HashMap<AccountId, UnifiedFullViewingKey>, SqliteClientError> {
// Fetch the UnifiedFullViewingKeys we are tracking
let mut stmt_fetch_accounts = wdb
.conn
.prepare("SELECT account, ufvk FROM accounts ORDER BY account ASC")?;
let mut stmt_fetch_accounts =
conn.prepare("SELECT account, ufvk FROM accounts ORDER BY account ASC")?;
let rows = stmt_fetch_accounts.query_map([], |row| {
let acct: u32 = row.get(0)?;
let account = AccountId::from(acct);
let ufvk_str: String = row.get(1)?;
let ufvk = UnifiedFullViewingKey::decode(&wdb.params, &ufvk_str)
let ufvk = UnifiedFullViewingKey::decode(params, &ufvk_str)
.map_err(SqliteClientError::CorruptedData);
Ok((account, ufvk))
@ -351,20 +351,20 @@ pub(crate) fn get_unified_full_viewing_keys<P: consensus::Parameters>(
/// Returns the account id corresponding to a given [`UnifiedFullViewingKey`],
/// if any.
pub(crate) fn get_account_for_ufvk<P: consensus::Parameters>(
wdb: &WalletDb<P>,
conn: &rusqlite::Connection,
params: &P,
ufvk: &UnifiedFullViewingKey,
) -> Result<Option<AccountId>, SqliteClientError> {
wdb.conn
.query_row(
"SELECT account FROM accounts WHERE ufvk = ?",
[&ufvk.encode(&wdb.params)],
|row| {
let acct: u32 = row.get(0)?;
Ok(AccountId::from(acct))
},
)
.optional()
.map_err(SqliteClientError::from)
conn.query_row(
"SELECT account FROM accounts WHERE ufvk = ?",
[&ufvk.encode(params)],
|row| {
let acct: u32 = row.get(0)?;
Ok(AccountId::from(acct))
},
)
.optional()
.map_err(SqliteClientError::from)
}
/// Checks whether the specified [`ExtendedFullViewingKey`] is valid and corresponds to the
@ -372,15 +372,15 @@ pub(crate) fn get_account_for_ufvk<P: consensus::Parameters>(
///
/// [`ExtendedFullViewingKey`]: zcash_primitives::zip32::ExtendedFullViewingKey
pub(crate) fn is_valid_account_extfvk<P: consensus::Parameters>(
wdb: &WalletDb<P>,
conn: &rusqlite::Connection,
params: &P,
account: AccountId,
extfvk: &ExtendedFullViewingKey,
) -> Result<bool, SqliteClientError> {
wdb.conn
.prepare("SELECT ufvk FROM accounts WHERE account = ?")?
conn.prepare("SELECT ufvk FROM accounts WHERE account = ?")?
.query_row([u32::from(account).to_sql()?], |row| {
row.get(0).map(|ufvk_str: String| {
UnifiedFullViewingKey::decode(&wdb.params, &ufvk_str)
UnifiedFullViewingKey::decode(params, &ufvk_str)
.map_err(SqliteClientError::CorruptedData)
})
})
@ -406,11 +406,11 @@ pub(crate) fn is_valid_account_extfvk<P: consensus::Parameters>(
/// caveat. Use [`get_balance_at`] where you need a more reliable indication of the
/// wallet balance.
#[cfg(test)]
pub(crate) fn get_balance<P>(
wdb: &WalletDb<P>,
pub(crate) fn get_balance(
conn: &rusqlite::Connection,
account: AccountId,
) -> Result<Amount, SqliteClientError> {
let balance = wdb.conn.query_row(
let balance = conn.query_row(
"SELECT SUM(value) FROM sapling_received_notes
INNER JOIN transactions ON transactions.id_tx = sapling_received_notes.tx
WHERE account = ? AND spent IS NULL AND transactions.block IS NOT NULL",
@ -429,12 +429,12 @@ pub(crate) fn get_balance<P>(
/// Returns the verified balance for the account at the specified height,
/// This may be used to obtain a balance that ignores notes that have been
/// received so recently that they are not yet deemed spendable.
pub(crate) fn get_balance_at<P>(
wdb: &WalletDb<P>,
pub(crate) fn get_balance_at(
conn: &rusqlite::Connection,
account: AccountId,
anchor_height: BlockHeight,
) -> Result<Amount, SqliteClientError> {
let balance = wdb.conn.query_row(
let balance = conn.query_row(
"SELECT SUM(value) FROM sapling_received_notes
INNER JOIN transactions ON transactions.id_tx = sapling_received_notes.tx
WHERE account = ? AND spent IS NULL AND transactions.block <= ?",
@ -454,11 +454,11 @@ pub(crate) fn get_balance_at<P>(
///
/// The note is identified by its row index in the `sapling_received_notes` table within the wdb
/// database.
pub(crate) fn get_received_memo<P>(
wdb: &WalletDb<P>,
pub(crate) fn get_received_memo(
conn: &rusqlite::Connection,
id_note: i64,
) -> Result<Option<Memo>, SqliteClientError> {
let memo_bytes: Option<Vec<_>> = wdb.conn.query_row(
let memo_bytes: Option<Vec<_>> = conn.query_row(
"SELECT memo FROM sapling_received_notes
WHERE id_note = ?",
[id_note],
@ -476,10 +476,11 @@ pub(crate) fn get_received_memo<P>(
/// Looks up a transaction by its internal database identifier.
pub(crate) fn get_transaction<P: Parameters>(
wdb: &WalletDb<P>,
conn: &rusqlite::Connection,
params: &P,
id_tx: i64,
) -> Result<Transaction, SqliteClientError> {
let (tx_bytes, block_height): (Vec<_>, BlockHeight) = wdb.conn.query_row(
let (tx_bytes, block_height): (Vec<_>, BlockHeight) = conn.query_row(
"SELECT raw, block FROM transactions
WHERE id_tx = ?",
[id_tx],
@ -489,22 +490,19 @@ pub(crate) fn get_transaction<P: Parameters>(
},
)?;
Transaction::read(
&tx_bytes[..],
BranchId::for_height(&wdb.params, block_height),
)
.map_err(SqliteClientError::from)
Transaction::read(&tx_bytes[..], BranchId::for_height(params, block_height))
.map_err(SqliteClientError::from)
}
/// Returns the memo for a sent note.
///
/// The note is identified by its row index in the `sent_notes` table within the wdb
/// database.
pub(crate) fn get_sent_memo<P>(
wdb: &WalletDb<P>,
pub(crate) fn get_sent_memo(
conn: &rusqlite::Connection,
id_note: i64,
) -> Result<Option<Memo>, SqliteClientError> {
let memo_bytes: Option<Vec<_>> = wdb.conn.query_row(
let memo_bytes: Option<Vec<_>> = conn.query_row(
"SELECT memo FROM sent_notes
WHERE id_note = ?",
[id_note],
@ -521,74 +519,70 @@ pub(crate) fn get_sent_memo<P>(
}
/// Returns the minimum and maximum heights for blocks stored in the wallet database.
pub(crate) fn block_height_extrema<P>(
wdb: &WalletDb<P>,
pub(crate) fn block_height_extrema(
conn: &rusqlite::Connection,
) -> Result<Option<(BlockHeight, BlockHeight)>, rusqlite::Error> {
wdb.conn
.query_row("SELECT MIN(height), MAX(height) FROM blocks", [], |row| {
let min_height: u32 = row.get(0)?;
let max_height: u32 = row.get(1)?;
Ok(Some((
BlockHeight::from(min_height),
BlockHeight::from(max_height),
)))
})
//.optional() doesn't work here because a failed aggregate function
//produces a runtime error, not an empty set of rows.
.or(Ok(None))
conn.query_row("SELECT MIN(height), MAX(height) FROM blocks", [], |row| {
let min_height: u32 = row.get(0)?;
let max_height: u32 = row.get(1)?;
Ok(Some((
BlockHeight::from(min_height),
BlockHeight::from(max_height),
)))
})
//.optional() doesn't work here because a failed aggregate function
//produces a runtime error, not an empty set of rows.
.or(Ok(None))
}
/// Returns the block height at which the specified transaction was mined,
/// if any.
pub(crate) fn get_tx_height<P>(
wdb: &WalletDb<P>,
pub(crate) fn get_tx_height(
conn: &rusqlite::Connection,
txid: TxId,
) -> Result<Option<BlockHeight>, rusqlite::Error> {
wdb.conn
.query_row(
"SELECT block FROM transactions WHERE txid = ?",
[txid.as_ref().to_vec()],
|row| row.get(0).map(u32::into),
)
.optional()
conn.query_row(
"SELECT block FROM transactions WHERE txid = ?",
[txid.as_ref().to_vec()],
|row| row.get(0).map(u32::into),
)
.optional()
}
/// Returns the block hash for the block at the specified height,
/// if any.
pub(crate) fn get_block_hash<P>(
wdb: &WalletDb<P>,
pub(crate) fn get_block_hash(
conn: &rusqlite::Connection,
block_height: BlockHeight,
) -> Result<Option<BlockHash>, rusqlite::Error> {
wdb.conn
.query_row(
"SELECT hash FROM blocks WHERE height = ?",
[u32::from(block_height)],
|row| {
let row_data = row.get::<_, Vec<_>>(0)?;
Ok(BlockHash::from_slice(&row_data))
},
)
.optional()
conn.query_row(
"SELECT hash FROM blocks WHERE height = ?",
[u32::from(block_height)],
|row| {
let row_data = row.get::<_, Vec<_>>(0)?;
Ok(BlockHash::from_slice(&row_data))
},
)
.optional()
}
/// Gets the height to which the database must be truncated if any truncation that would remove a
/// number of blocks greater than the pruning height is attempted.
pub(crate) fn get_min_unspent_height<P>(
wdb: &WalletDb<P>,
pub(crate) fn get_min_unspent_height(
conn: &rusqlite::Connection,
) -> Result<Option<BlockHeight>, SqliteClientError> {
wdb.conn
.query_row(
"SELECT MIN(tx.block)
conn.query_row(
"SELECT MIN(tx.block)
FROM sapling_received_notes n
JOIN transactions tx ON tx.id_tx = n.tx
WHERE n.spent IS NULL",
[],
|row| {
row.get(0)
.map(|maybe_height: Option<u32>| maybe_height.map(|height| height.into()))
},
)
.map_err(SqliteClientError::from)
[],
|row| {
row.get(0)
.map(|maybe_height: Option<u32>| maybe_height.map(|height| height.into()))
},
)
.map_err(SqliteClientError::from)
}
/// Truncates the database to the given height.
@ -598,25 +592,23 @@ pub(crate) fn get_min_unspent_height<P>(
///
/// This should only be executed inside a transactional context.
pub(crate) fn truncate_to_height<P: consensus::Parameters>(
wdb: &WalletDb<P>,
conn: &rusqlite::Transaction,
params: &P,
block_height: BlockHeight,
) -> Result<(), SqliteClientError> {
let sapling_activation_height = wdb
.params
let sapling_activation_height = params
.activation_height(NetworkUpgrade::Sapling)
.expect("Sapling activation height mutst be available.");
// Recall where we synced up to previously.
let last_scanned_height = wdb
.conn
.query_row("SELECT MAX(height) FROM blocks", [], |row| {
row.get(0)
.map(|h: u32| h.into())
.or_else(|_| Ok(sapling_activation_height - 1))
})?;
let last_scanned_height = conn.query_row("SELECT MAX(height) FROM blocks", [], |row| {
row.get(0)
.map(|h: u32| h.into())
.or_else(|_| Ok(sapling_activation_height - 1))
})?;
if block_height < last_scanned_height - PRUNING_HEIGHT {
if let Some(h) = get_min_unspent_height(wdb)? {
if let Some(h) = get_min_unspent_height(conn)? {
if block_height > h {
return Err(SqliteClientError::RequestedRewindInvalid(h, block_height));
}
@ -626,13 +618,13 @@ pub(crate) fn truncate_to_height<P: consensus::Parameters>(
// nothing to do if we're deleting back down to the max height
if block_height < last_scanned_height {
// Decrement witnesses.
wdb.conn.execute(
conn.execute(
"DELETE FROM sapling_witnesses WHERE block > ?",
[u32::from(block_height)],
)?;
// Rewind received notes
wdb.conn.execute(
conn.execute(
"DELETE FROM sapling_received_notes
WHERE id_note IN (
SELECT rn.id_note
@ -649,19 +641,19 @@ pub(crate) fn truncate_to_height<P: consensus::Parameters>(
// presence of stale sent notes that link to unmined transactions.
// Rewind utxos
wdb.conn.execute(
conn.execute(
"DELETE FROM utxos WHERE height > ?",
[u32::from(block_height)],
)?;
// Un-mine transactions.
wdb.conn.execute(
conn.execute(
"UPDATE transactions SET block = NULL, tx_index = NULL WHERE block IS NOT NULL AND block > ?",
[u32::from(block_height)],
)?;
// Now that they aren't depended on, delete scanned blocks.
wdb.conn.execute(
conn.execute(
"DELETE FROM blocks WHERE height > ?",
[u32::from(block_height)],
)?;
@ -675,12 +667,13 @@ pub(crate) fn truncate_to_height<P: consensus::Parameters>(
/// height less than or equal to the provided `max_height`.
#[cfg(feature = "transparent-inputs")]
pub(crate) fn get_unspent_transparent_outputs<P: consensus::Parameters>(
wdb: &WalletDb<P>,
conn: &rusqlite::Connection,
params: &P,
address: &TransparentAddress,
max_height: BlockHeight,
exclude: &[OutPoint],
) -> Result<Vec<WalletTransparentOutput>, SqliteClientError> {
let mut stmt_blocks = wdb.conn.prepare(
let mut stmt_blocks = conn.prepare(
"SELECT u.prevout_txid, u.prevout_idx, u.script,
u.value_zat, u.height, tx.block as block
FROM utxos u
@ -691,7 +684,7 @@ pub(crate) fn get_unspent_transparent_outputs<P: consensus::Parameters>(
AND tx.block IS NULL",
)?;
let addr_str = address.encode(&wdb.params);
let addr_str = address.encode(params);
let mut utxos = Vec::<WalletTransparentOutput>::new();
let mut rows = stmt_blocks.query(params![addr_str, u32::from(max_height)])?;
@ -737,11 +730,12 @@ pub(crate) fn get_unspent_transparent_outputs<P: consensus::Parameters>(
/// the provided `max_height`.
#[cfg(feature = "transparent-inputs")]
pub(crate) fn get_transparent_balances<P: consensus::Parameters>(
wdb: &WalletDb<P>,
conn: &rusqlite::Connection,
params: &P,
account: AccountId,
max_height: BlockHeight,
) -> Result<HashMap<TransparentAddress, Amount>, SqliteClientError> {
let mut stmt_blocks = wdb.conn.prepare(
let mut stmt_blocks = conn.prepare(
"SELECT u.address, SUM(u.value_zat)
FROM utxos u
LEFT OUTER JOIN transactions tx
@ -756,7 +750,7 @@ pub(crate) fn get_transparent_balances<P: consensus::Parameters>(
let mut rows = stmt_blocks.query(params![u32::from(account), u32::from(max_height)])?;
while let Some(row) = rows.next()? {
let taddr_str: String = row.get(0)?;
let taddr = TransparentAddress::decode(&wdb.params, &taddr_str)?;
let taddr = TransparentAddress::decode(params, &taddr_str)?;
let value = Amount::from_i64(row.get(1)?).unwrap();
res.insert(taddr, value);
@ -766,8 +760,8 @@ pub(crate) fn get_transparent_balances<P: consensus::Parameters>(
}
/// Inserts information about a scanned block into the database.
pub(crate) fn insert_block<'a, P>(
stmts: &mut DataConnStmtCache<'a, P>,
pub(crate) fn insert_block(
conn: &rusqlite::Connection,
block_height: BlockHeight,
block_hash: BlockHash,
block_time: u32,
@ -776,7 +770,7 @@ pub(crate) fn insert_block<'a, P>(
let mut encoded_tree = Vec::new();
write_commitment_tree(commitment_tree, &mut encoded_tree).unwrap();
let mut stmt_insert_block = stmts.wallet_db.conn.prepare_cached(
let mut stmt_insert_block = conn.prepare_cached(
"INSERT INTO blocks (height, hash, time, sapling_tree)
VALUES (?, ?, ?, ?)",
)?;
@ -793,13 +787,13 @@ pub(crate) fn insert_block<'a, P>(
/// Inserts information about a mined transaction that was observed to
/// contain a note related to this wallet into the database.
pub(crate) fn put_tx_meta<'a, P, N>(
stmts: &mut DataConnStmtCache<'a, P>,
pub(crate) fn put_tx_meta<N>(
conn: &rusqlite::Connection,
tx: &WalletTx<N>,
height: BlockHeight,
) -> Result<i64, SqliteClientError> {
// It isn't there, so insert our transaction into the database.
let mut stmt_upsert_tx_meta = stmts.wallet_db.conn.prepare_cached(
let mut stmt_upsert_tx_meta = conn.prepare_cached(
"INSERT INTO transactions (txid, block, tx_index)
VALUES (:txid, :block, :tx_index)
ON CONFLICT (txid) DO UPDATE
@ -820,13 +814,13 @@ pub(crate) fn put_tx_meta<'a, P, N>(
}
/// Inserts full transaction data into the database.
pub(crate) fn put_tx_data<'a, P>(
stmts: &mut DataConnStmtCache<'a, P>,
pub(crate) fn put_tx_data(
conn: &rusqlite::Connection,
tx: &Transaction,
fee: Option<Amount>,
created_at: Option<time::OffsetDateTime>,
) -> Result<i64, SqliteClientError> {
let mut stmt_upsert_tx_data = stmts.wallet_db.conn.prepare_cached(
let mut stmt_upsert_tx_data = conn.prepare_cached(
"INSERT INTO transactions (txid, created, expiry_height, raw, fee)
VALUES (:txid, :created_at, :expiry_height, :raw, :fee)
ON CONFLICT (txid) DO UPDATE
@ -856,7 +850,7 @@ pub(crate) fn put_tx_data<'a, P>(
/// Marks the given UTXO as having been spent.
#[cfg(feature = "transparent-inputs")]
pub(crate) fn mark_transparent_utxo_spent(
conn: &Connection,
conn: &rusqlite::Connection,
tx_ref: i64,
outpoint: &OutPoint,
) -> Result<(), SqliteClientError> {
@ -879,7 +873,7 @@ pub(crate) fn mark_transparent_utxo_spent(
/// Adds the given received UTXO to the datastore.
#[cfg(feature = "transparent-inputs")]
pub(crate) fn put_received_transparent_utxo<P: consensus::Parameters>(
conn: &Connection,
conn: &rusqlite::Connection,
params: &P,
output: &WalletTransparentOutput,
) -> Result<UtxoId, SqliteClientError> {
@ -920,7 +914,7 @@ pub(crate) fn put_received_transparent_utxo<P: consensus::Parameters>(
#[cfg(feature = "transparent-inputs")]
pub(crate) fn put_legacy_transparent_utxo<P: consensus::Parameters>(
conn: &Connection,
conn: &rusqlite::Connection,
params: &P,
output: &WalletTransparentOutput,
received_by_account: AccountId,
@ -958,20 +952,30 @@ pub(crate) fn put_legacy_transparent_utxo<P: consensus::Parameters>(
}
/// Removes old incremental witnesses up to the given block height.
pub(crate) fn prune_witnesses<P>(
stmts: &mut DataConnStmtCache<'_, P>,
pub(crate) fn prune_witnesses(
conn: &rusqlite::Connection,
below_height: BlockHeight,
) -> Result<(), SqliteClientError> {
stmts.stmt_prune_witnesses(below_height)
let mut stmt_prune_witnesses =
conn.prepare_cached("DELETE FROM sapling_witnesses WHERE block < ?")?;
stmt_prune_witnesses.execute([u32::from(below_height)])?;
Ok(())
}
/// Marks notes that have not been mined in transactions
/// as expired, up to the given block height.
pub(crate) fn update_expired_notes<P>(
stmts: &mut DataConnStmtCache<'_, P>,
pub(crate) fn update_expired_notes(
conn: &rusqlite::Connection,
height: BlockHeight,
) -> Result<(), SqliteClientError> {
stmts.stmt_update_expired(height)
let mut stmt_update_expired = conn.prepare_cached(
"UPDATE sapling_received_notes SET spent = NULL WHERE EXISTS (
SELECT id_tx FROM transactions
WHERE id_tx = sapling_received_notes.spent AND block IS NULL AND expiry_height < ?
)",
)?;
stmt_update_expired.execute([u32::from(height)])?;
Ok(())
}
// A utility function for creation of parameters for use in `insert_sent_output`
@ -1001,7 +1005,7 @@ fn recipient_params<P: consensus::Parameters>(
///
/// This is a crate-internal convenience method.
pub(crate) fn insert_sent_output<P: consensus::Parameters>(
conn: &Connection,
conn: &rusqlite::Connection,
params: &P,
tx_ref: i64,
from_account: AccountId,
@ -1037,8 +1041,8 @@ pub(crate) fn insert_sent_output<P: consensus::Parameters>(
///
/// This is a crate-internal convenience method.
#[allow(clippy::too_many_arguments)]
pub(crate) fn put_sent_output<'a, P: consensus::Parameters>(
conn: &Connection,
pub(crate) fn put_sent_output<P: consensus::Parameters>(
conn: &rusqlite::Connection,
params: &P,
from_account: AccountId,
tx_ref: i64,
@ -1114,11 +1118,11 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet
tests::init_test_accounts_table(&db_data);
tests::init_test_accounts_table(&mut db_data);
// The account should be empty
assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(),
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
Amount::zero()
);
@ -1126,9 +1130,12 @@ mod tests {
assert_eq!(db_data.get_target_and_anchor_heights(10).unwrap(), None);
// An invalid account has zero balance
assert_matches!(get_current_address(&db_data, AccountId::from(1)), Ok(None));
assert_matches!(
get_current_address(&db_data.conn, &db_data.params, AccountId::from(1)),
Ok(None)
);
assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(),
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
Amount::zero()
);
}
@ -1141,9 +1148,8 @@ mod tests {
init_wallet_db(&mut db_data, None).unwrap();
// Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec());
let (account_id, _usk) = ops.create_account(&seed).unwrap();
let (account_id, _usk) = db_data.create_account(&seed).unwrap();
let uaddr = db_data.get_current_address(account_id).unwrap().unwrap();
let taddr = uaddr.transparent().unwrap();
@ -1162,8 +1168,7 @@ mod tests {
)
.unwrap();
let res0 =
super::put_received_transparent_utxo(&ops.wallet_db.conn, &ops.wallet_db.params, &utxo);
let res0 = super::put_received_transparent_utxo(&db_data.conn, &db_data.params, &utxo);
assert_matches!(res0, Ok(_));
// Change the mined height of the UTXO and upsert; we should get back
@ -1177,16 +1182,13 @@ mod tests {
BlockHeight::from_u32(34567),
)
.unwrap();
let res1 = super::put_received_transparent_utxo(
&ops.wallet_db.conn,
&ops.wallet_db.params,
&utxo2,
);
let res1 = super::put_received_transparent_utxo(&db_data.conn, &db_data.params, &utxo2);
assert_matches!(res1, Ok(id) if id == res0.unwrap());
assert_matches!(
super::get_unspent_transparent_outputs(
&db_data,
&db_data.conn,
&db_data.params,
taddr,
BlockHeight::from_u32(12345),
&[]
@ -1196,7 +1198,8 @@ mod tests {
assert_matches!(
super::get_unspent_transparent_outputs(
&db_data,
&db_data.conn,
&db_data.params,
taddr,
BlockHeight::from_u32(34567),
&[]
@ -1222,11 +1225,7 @@ mod tests {
)
.unwrap();
let res2 = super::put_received_transparent_utxo(
&ops.wallet_db.conn,
&ops.wallet_db.params,
&utxo2,
);
let res2 = super::put_received_transparent_utxo(&db_data.conn, &db_data.params, &utxo2);
assert_matches!(res2, Err(_));
}
}

View File

@ -111,14 +111,14 @@ impl std::error::Error for WalletMigrationError {
// check for unspent transparent outputs whenever running initialization with a version of the
// library *not* compiled with the `transparent-inputs` feature flag, and fail if any are present.
pub fn init_wallet_db<P: consensus::Parameters + 'static>(
wdb: &mut WalletDb<P>,
wdb: &mut WalletDb<rusqlite::Connection, P>,
seed: Option<SecretVec<u8>>,
) -> Result<(), MigratorError<WalletMigrationError>> {
init_wallet_db_internal(wdb, seed, &[])
}
fn init_wallet_db_internal<P: consensus::Parameters + 'static>(
wdb: &mut WalletDb<P>,
wdb: &mut WalletDb<rusqlite::Connection, P>,
seed: Option<SecretVec<u8>>,
target_migrations: &[Uuid],
) -> Result<(), MigratorError<WalletMigrationError>> {
@ -200,7 +200,7 @@ fn init_wallet_db_internal<P: consensus::Parameters + 'static>(
/// let dfvk = extsk.to_diversifiable_full_viewing_key();
/// let ufvk = UnifiedFullViewingKey::new(None, Some(dfvk), None).unwrap();
/// let ufvks = HashMap::from([(account, ufvk)]);
/// init_accounts_table(&db_data, &ufvks).unwrap();
/// init_accounts_table(&mut db_data, &ufvks).unwrap();
/// # }
/// ```
///
@ -208,29 +208,29 @@ fn init_wallet_db_internal<P: consensus::Parameters + 'static>(
/// [`scan_cached_blocks`]: zcash_client_backend::data_api::chain::scan_cached_blocks
/// [`create_spend_to_address`]: zcash_client_backend::data_api::wallet::create_spend_to_address
pub fn init_accounts_table<P: consensus::Parameters>(
wdb: &WalletDb<P>,
wallet_db: &mut WalletDb<rusqlite::Connection, P>,
keys: &HashMap<AccountId, UnifiedFullViewingKey>,
) -> Result<(), SqliteClientError> {
let mut empty_check = wdb.conn.prepare("SELECT * FROM accounts LIMIT 1")?;
if empty_check.exists([])? {
return Err(SqliteClientError::TableNotEmpty);
}
// Ensure that the account identifiers are sequential and begin at zero.
if let Some(account_id) = keys.keys().max() {
if usize::try_from(u32::from(*account_id)).unwrap() >= keys.len() {
return Err(SqliteClientError::AccountIdDiscontinuity);
wallet_db.transactionally(|wdb| {
let mut empty_check = wdb.conn.0.prepare("SELECT * FROM accounts LIMIT 1")?;
if empty_check.exists([])? {
return Err(SqliteClientError::TableNotEmpty);
}
}
// Insert accounts atomically
wdb.conn.execute("BEGIN IMMEDIATE", [])?;
for (account, key) in keys.iter() {
wallet::add_account(wdb, *account, key)?;
}
wdb.conn.execute("COMMIT", [])?;
// Ensure that the account identifiers are sequential and begin at zero.
if let Some(account_id) = keys.keys().max() {
if usize::try_from(u32::from(*account_id)).unwrap() >= keys.len() {
return Err(SqliteClientError::AccountIdDiscontinuity);
}
}
Ok(())
// Insert accounts atomically
for (account, key) in keys.iter() {
wallet::add_account(&wdb.conn.0, &wdb.params, *account, key)?;
}
Ok(())
})
}
/// Initialises the data database with the given block.
@ -262,33 +262,35 @@ pub fn init_accounts_table<P: consensus::Parameters>(
/// let sapling_tree = &[];
///
/// let data_file = NamedTempFile::new().unwrap();
/// let db = WalletDb::for_path(data_file.path(), Network::TestNetwork).unwrap();
/// init_blocks_table(&db, height, hash, time, sapling_tree);
/// let mut db = WalletDb::for_path(data_file.path(), Network::TestNetwork).unwrap();
/// init_blocks_table(&mut db, height, hash, time, sapling_tree);
/// ```
pub fn init_blocks_table<P>(
wdb: &WalletDb<P>,
pub fn init_blocks_table<P: consensus::Parameters>(
wallet_db: &mut WalletDb<rusqlite::Connection, P>,
height: BlockHeight,
hash: BlockHash,
time: u32,
sapling_tree: &[u8],
) -> Result<(), SqliteClientError> {
let mut empty_check = wdb.conn.prepare("SELECT * FROM blocks LIMIT 1")?;
if empty_check.exists([])? {
return Err(SqliteClientError::TableNotEmpty);
}
wallet_db.transactionally(|wdb| {
let mut empty_check = wdb.conn.0.prepare("SELECT * FROM blocks LIMIT 1")?;
if empty_check.exists([])? {
return Err(SqliteClientError::TableNotEmpty);
}
wdb.conn.execute(
"INSERT INTO blocks (height, hash, time, sapling_tree)
wdb.conn.0.execute(
"INSERT INTO blocks (height, hash, time, sapling_tree)
VALUES (?, ?, ?, ?)",
[
u32::from(height).to_sql()?,
hash.0.to_sql()?,
time.to_sql()?,
sapling_tree.to_sql()?,
],
)?;
[
u32::from(height).to_sql()?,
hash.0.to_sql()?,
time.to_sql()?,
sapling_tree.to_sql()?,
],
)?;
Ok(())
Ok(())
})
}
#[cfg(test)]
@ -606,7 +608,7 @@ mod tests {
#[test]
fn init_migrate_from_0_3_0() {
fn init_0_3_0<P>(
wdb: &mut WalletDb<P>,
wdb: &mut WalletDb<rusqlite::Connection, P>,
extfvk: &ExtendedFullViewingKey,
account: AccountId,
) -> Result<(), rusqlite::Error> {
@ -722,7 +724,7 @@ mod tests {
#[test]
fn init_migrate_from_autoshielding_poc() {
fn init_autoshielding<P>(
wdb: &WalletDb<P>,
wdb: &mut WalletDb<rusqlite::Connection, P>,
extfvk: &ExtendedFullViewingKey,
account: AccountId,
) -> Result<(), rusqlite::Error> {
@ -878,14 +880,14 @@ mod tests {
let extfvk = secret_key.to_extended_full_viewing_key();
let data_file = NamedTempFile::new().unwrap();
let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap();
init_autoshielding(&db_data, &extfvk, account).unwrap();
init_autoshielding(&mut db_data, &extfvk, account).unwrap();
init_wallet_db(&mut db_data, Some(Secret::new(seed.to_vec()))).unwrap();
}
#[test]
fn init_migrate_from_main_pre_migrations() {
fn init_main<P>(
wdb: &WalletDb<P>,
wdb: &mut WalletDb<rusqlite::Connection, P>,
ufvk: &UnifiedFullViewingKey,
account: AccountId,
) -> Result<(), rusqlite::Error> {
@ -1025,7 +1027,12 @@ mod tests {
let secret_key = UnifiedSpendingKey::from_seed(&tests::network(), &seed, account).unwrap();
let data_file = NamedTempFile::new().unwrap();
let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap();
init_main(&db_data, &secret_key.to_unified_full_viewing_key(), account).unwrap();
init_main(
&mut db_data,
&secret_key.to_unified_full_viewing_key(),
account,
)
.unwrap();
init_wallet_db(&mut db_data, Some(Secret::new(seed.to_vec()))).unwrap();
}
@ -1036,8 +1043,8 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// We can call the function as many times as we want with no data
init_accounts_table(&db_data, &HashMap::new()).unwrap();
init_accounts_table(&db_data, &HashMap::new()).unwrap();
init_accounts_table(&mut db_data, &HashMap::new()).unwrap();
init_accounts_table(&mut db_data, &HashMap::new()).unwrap();
let seed = [0u8; 32];
let account = AccountId::from(0);
@ -1062,11 +1069,11 @@ mod tests {
let ufvk = UnifiedFullViewingKey::new(Some(dfvk), None).unwrap();
let ufvks = HashMap::from([(account, ufvk)]);
init_accounts_table(&db_data, &ufvks).unwrap();
init_accounts_table(&mut db_data, &ufvks).unwrap();
// Subsequent calls should return an error
init_accounts_table(&db_data, &HashMap::new()).unwrap_err();
init_accounts_table(&db_data, &ufvks).unwrap_err();
init_accounts_table(&mut db_data, &HashMap::new()).unwrap_err();
init_accounts_table(&mut db_data, &ufvks).unwrap_err();
}
#[test]
@ -1090,12 +1097,12 @@ mod tests {
// should fail if we have a gap
assert_matches!(
init_accounts_table(&db_data, &ufvks(&[0, 2])),
init_accounts_table(&mut db_data, &ufvks(&[0, 2])),
Err(SqliteClientError::AccountIdDiscontinuity)
);
// should succeed if there are no gaps
assert!(init_accounts_table(&db_data, &ufvks(&[0, 1, 2])).is_ok());
assert!(init_accounts_table(&mut db_data, &ufvks(&[0, 1, 2])).is_ok());
}
#[test]
@ -1106,7 +1113,7 @@ mod tests {
// First call with data should initialise the blocks table
init_blocks_table(
&db_data,
&mut db_data,
BlockHeight::from(1u32),
BlockHash([1; 32]),
1,
@ -1116,7 +1123,7 @@ mod tests {
// Subsequent calls should return an error
init_blocks_table(
&db_data,
&mut db_data,
BlockHeight::from(2u32),
BlockHash([2; 32]),
2,
@ -1139,7 +1146,7 @@ mod tests {
let ufvk = usk.to_unified_full_viewing_key();
let expected_address = ufvk.sapling().unwrap().default_address().1;
let ufvks = HashMap::from([(account_id, ufvk)]);
init_accounts_table(&db_data, &ufvks).unwrap();
init_accounts_table(&mut db_data, &ufvks).unwrap();
// The account's address should be in the data DB
let ua = db_data.get_current_address(AccountId::from(0)).unwrap();
@ -1153,16 +1160,15 @@ mod tests {
let mut db_data = WalletDb::for_path(data_file.path(), Network::MainNetwork).unwrap();
init_wallet_db(&mut db_data, None).unwrap();
let mut ops = db_data.get_update_ops().unwrap();
let seed = test_vectors::UNIFIED[0].root_seed;
let (account, _usk) = ops.create_account(&Secret::new(seed.to_vec())).unwrap();
let (account, _usk) = db_data.create_account(&Secret::new(seed.to_vec())).unwrap();
assert_eq!(account, AccountId::from(0u32));
for tv in &test_vectors::UNIFIED[..3] {
if let Some(RecipientAddress::Unified(tvua)) =
RecipientAddress::decode(&Network::MainNetwork, tv.unified_addr)
{
let (ua, di) = wallet::get_current_address(&db_data, account)
let (ua, di) = wallet::get_current_address(&db_data.conn, &db_data.params, account)
.unwrap()
.expect("create_account generated the first address");
assert_eq!(DiversifierIndex::from(tv.diversifier_index), di);
@ -1170,7 +1176,8 @@ mod tests {
assert_eq!(tvua.sapling(), ua.sapling());
assert_eq!(tv.unified_addr, ua.encode(&Network::MainNetwork));
ops.get_next_available_address(account)
db_data
.get_next_available_address(account)
.unwrap()
.expect("get_next_available_address generated an address");
} else {

View File

@ -67,7 +67,7 @@ impl<P: consensus::Parameters> RusqliteMigration for Migration<P> {
while let Some(row) = rows.next()? {
let account: u32 = row.get(0)?;
let taddrs =
get_transparent_receivers(&self._params, transaction, AccountId::from(account))
get_transparent_receivers(transaction, &self._params, AccountId::from(account))
.map_err(|e| match e {
SqliteClientError::DbError(e) => WalletMigrationError::DbError(e),
SqliteClientError::CorruptedData(s) => {

View File

@ -6,7 +6,7 @@ use std::rc::Rc;
use zcash_primitives::{
consensus::BlockHeight,
memo::MemoBytes,
merkle_tree::{read_commitment_tree, read_incremental_witness},
merkle_tree::{read_commitment_tree, read_incremental_witness, write_incremental_witness},
sapling::{self, Diversifier, Note, Nullifier, Rseed},
transaction::components::Amount,
zip32::AccountId,
@ -17,7 +17,7 @@ use zcash_client_backend::{
DecryptedOutput, TransferType,
};
use crate::{error::SqliteClientError, DataConnStmtCache, NoteId, WalletDb};
use crate::{error::SqliteClientError, NoteId};
/// This trait provides a generalization over shielded output representations.
pub(crate) trait ReceivedSaplingOutput {
@ -230,43 +230,42 @@ pub(crate) fn select_spendable_sapling_notes(
/// Returns the commitment tree for the block at the specified height,
/// if any.
pub(crate) fn get_sapling_commitment_tree<P>(
wdb: &WalletDb<P>,
pub(crate) fn get_sapling_commitment_tree(
conn: &Connection,
block_height: BlockHeight,
) -> Result<Option<sapling::CommitmentTree>, SqliteClientError> {
wdb.conn
.query_row_and_then(
"SELECT sapling_tree FROM blocks WHERE height = ?",
[u32::from(block_height)],
|row| {
let row_data: Vec<u8> = row.get(0)?;
read_commitment_tree(&row_data[..]).map_err(|e| {
rusqlite::Error::FromSqlConversionFailure(
row_data.len(),
rusqlite::types::Type::Blob,
Box::new(e),
)
})
},
)
.optional()
.map_err(SqliteClientError::from)
conn.query_row_and_then(
"SELECT sapling_tree FROM blocks WHERE height = ?",
[u32::from(block_height)],
|row| {
let row_data: Vec<u8> = row.get(0)?;
read_commitment_tree(&row_data[..]).map_err(|e| {
rusqlite::Error::FromSqlConversionFailure(
row_data.len(),
rusqlite::types::Type::Blob,
Box::new(e),
)
})
},
)
.optional()
.map_err(SqliteClientError::from)
}
/// Returns the incremental witnesses for the block at the specified height,
/// if any.
pub(crate) fn get_sapling_witnesses<P>(
wdb: &WalletDb<P>,
pub(crate) fn get_sapling_witnesses(
conn: &Connection,
block_height: BlockHeight,
) -> Result<Vec<(NoteId, sapling::IncrementalWitness)>, SqliteClientError> {
let mut stmt_fetch_witnesses = wdb
.conn
.prepare("SELECT note, witness FROM sapling_witnesses WHERE block = ?")?;
let mut stmt_fetch_witnesses =
conn.prepare_cached("SELECT note, witness FROM sapling_witnesses WHERE block = ?")?;
let witnesses = stmt_fetch_witnesses
.query_map([u32::from(block_height)], |row| {
let id_note = NoteId::ReceivedNoteId(row.get(0)?);
let wdb: Vec<u8> = row.get(1)?;
Ok(read_incremental_witness(&wdb[..]).map(|witness| (id_note, witness)))
let witness_data: Vec<u8> = row.get(1)?;
Ok(read_incremental_witness(&witness_data[..]).map(|witness| (id_note, witness)))
})
.map_err(SqliteClientError::from)?;
@ -277,13 +276,23 @@ pub(crate) fn get_sapling_witnesses<P>(
/// Records the incremental witness for the specified note,
/// as of the given block height.
pub(crate) fn insert_witness<'a, P>(
stmts: &mut DataConnStmtCache<'a, P>,
pub(crate) fn insert_witness(
conn: &Connection,
note_id: i64,
witness: &sapling::IncrementalWitness,
height: BlockHeight,
) -> Result<(), SqliteClientError> {
stmts.stmt_insert_witness(NoteId::ReceivedNoteId(note_id), height, witness)
let mut stmt_insert_witness = conn.prepare_cached(
"INSERT INTO sapling_witnesses (note, block, witness)
VALUES (?, ?, ?)",
)?;
let mut encoded = Vec::new();
write_incremental_witness(witness, &mut encoded).unwrap();
stmt_insert_witness.execute(params![note_id, u32::from(height), encoded])?;
Ok(())
}
/// Retrieves the set of nullifiers for "potentially spendable" Sapling notes that the
@ -292,11 +301,11 @@ pub(crate) fn insert_witness<'a, P>(
/// "Potentially spendable" means:
/// - The transaction in which the note was created has been observed as mined.
/// - No transaction in which the note's nullifier appears has been observed as mined.
pub(crate) fn get_sapling_nullifiers<P>(
wdb: &WalletDb<P>,
pub(crate) fn get_sapling_nullifiers(
conn: &Connection,
) -> Result<Vec<(AccountId, Nullifier)>, SqliteClientError> {
// Get the nullifiers for the notes we are tracking
let mut stmt_fetch_nullifiers = wdb.conn.prepare(
let mut stmt_fetch_nullifiers = conn.prepare(
"SELECT rn.id_note, rn.account, rn.nf, tx.block as block
FROM sapling_received_notes rn
LEFT OUTER JOIN transactions tx
@ -454,7 +463,7 @@ mod tests {
get_balance, get_balance_at,
init::{init_blocks_table, init_wallet_db},
},
AccountId, BlockDb, DataConnStmtCache, WalletDb,
AccountId, BlockDb, WalletDb,
};
#[cfg(feature = "transparent-inputs")]
@ -488,9 +497,8 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = ops.create_account(&seed).unwrap();
let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
let to = dfvk.default_address().1.into();
@ -499,10 +507,9 @@ mod tests {
let usk1 = UnifiedSpendingKey::from_seed(&network(), &[1u8; 32], acct1).unwrap();
// Attempting to spend with a USK that is not in the wallet results in an error
let mut db_write = db_data.get_update_ops().unwrap();
assert_matches!(
create_spend_to_address(
&mut db_write,
&mut db_data,
&tests::network(),
test_prover(),
&usk1,
@ -523,17 +530,15 @@ mod tests {
init_wallet_db(&mut db_data, None).unwrap();
// Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = ops.create_account(&seed).unwrap();
let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
let to = dfvk.default_address().1.into();
// We cannot do anything if we aren't synchronised
let mut db_write = db_data.get_update_ops().unwrap();
assert_matches!(
create_spend_to_address(
&mut db_write,
&mut db_data,
&tests::network(),
test_prover(),
&usk,
@ -553,7 +558,7 @@ mod tests {
let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap();
init_wallet_db(&mut db_data, None).unwrap();
init_blocks_table(
&db_data,
&mut db_data,
BlockHeight::from(1u32),
BlockHash([1; 32]),
1,
@ -562,23 +567,21 @@ mod tests {
.unwrap();
// Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = ops.create_account(&seed).unwrap();
let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
let to = dfvk.default_address().1.into();
// Account balance should be zero
assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(),
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
Amount::zero()
);
// We cannot spend anything
let mut db_write = db_data.get_update_ops().unwrap();
assert_matches!(
create_spend_to_address(
&mut db_write,
&mut db_data,
&tests::network(),
test_prover(),
&usk,
@ -607,9 +610,8 @@ mod tests {
init_wallet_db(&mut db_data, None).unwrap();
// Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = ops.create_account(&seed).unwrap();
let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
// Add funds to the wallet in a single note
@ -622,14 +624,16 @@ mod tests {
value,
);
insert_into_cache(&db_cache, &cb);
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Verified balance matches total balance
let (_, anchor_height) = db_data.get_target_and_anchor_heights(10).unwrap().unwrap();
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value);
assert_eq!(
get_balance_at(&db_data, AccountId::from(0), anchor_height).unwrap(),
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
);
assert_eq!(
get_balance_at(&db_data.conn, AccountId::from(0), anchor_height).unwrap(),
value
);
@ -642,16 +646,16 @@ mod tests {
value,
);
insert_into_cache(&db_cache, &cb);
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Verified balance does not include the second note
let (_, anchor_height2) = db_data.get_target_and_anchor_heights(10).unwrap().unwrap();
assert_eq!(
get_balance(&db_data, AccountId::from(0)).unwrap(),
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
(value + value).unwrap()
);
assert_eq!(
get_balance_at(&db_data, AccountId::from(0), anchor_height2).unwrap(),
get_balance_at(&db_data.conn, AccountId::from(0), anchor_height2).unwrap(),
value
);
@ -660,7 +664,7 @@ mod tests {
let to = extsk2.default_address().1.into();
assert_matches!(
create_spend_to_address(
&mut db_write,
&mut db_data,
&tests::network(),
test_prover(),
&usk,
@ -690,12 +694,12 @@ mod tests {
);
insert_into_cache(&db_cache, &cb);
}
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Second spend still fails
assert_matches!(
create_spend_to_address(
&mut db_write,
&mut db_data,
&tests::network(),
test_prover(),
&usk,
@ -722,12 +726,12 @@ mod tests {
value,
);
insert_into_cache(&db_cache, &cb);
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Second spend should now succeed
assert_matches!(
create_spend_to_address(
&mut db_write,
&mut db_data,
&tests::network(),
test_prover(),
&usk,
@ -752,9 +756,8 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = ops.create_account(&seed).unwrap();
let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
// Add funds to the wallet in a single note
@ -767,16 +770,18 @@ mod tests {
value,
);
insert_into_cache(&db_cache, &cb);
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value);
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
assert_eq!(
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
);
// Send some of the funds to another address
let extsk2 = ExtendedSpendingKey::master(&[]);
let to = extsk2.default_address().1.into();
assert_matches!(
create_spend_to_address(
&mut db_write,
&mut db_data,
&tests::network(),
test_prover(),
&usk,
@ -792,7 +797,7 @@ mod tests {
// A second spend fails because there are no usable notes
assert_matches!(
create_spend_to_address(
&mut db_write,
&mut db_data,
&tests::network(),
test_prover(),
&usk,
@ -821,12 +826,12 @@ mod tests {
);
insert_into_cache(&db_cache, &cb);
}
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Second spend still fails
assert_matches!(
create_spend_to_address(
&mut db_write,
&mut db_data,
&tests::network(),
test_prover(),
&usk,
@ -852,11 +857,11 @@ mod tests {
value,
);
insert_into_cache(&db_cache, &cb);
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Second spend should now succeed
create_spend_to_address(
&mut db_write,
&mut db_data,
&tests::network(),
test_prover(),
&usk,
@ -881,9 +886,8 @@ mod tests {
init_wallet_db(&mut db_data, None).unwrap();
// Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = ops.create_account(&seed).unwrap();
let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
// Add funds to the wallet in a single note
@ -896,17 +900,19 @@ mod tests {
value,
);
insert_into_cache(&db_cache, &cb);
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value);
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
assert_eq!(
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
);
let extsk2 = ExtendedSpendingKey::master(&[]);
let addr2 = extsk2.default_address().1;
let to = addr2.into();
let send_and_recover_with_policy = |db_write: &mut DataConnStmtCache<'_, _>, ovk_policy| {
let send_and_recover_with_policy = |db_data: &mut WalletDb<Connection, _>, ovk_policy| {
let tx_row = create_spend_to_address(
db_write,
db_data,
&tests::network(),
test_prover(),
&usk,
@ -919,8 +925,7 @@ mod tests {
.unwrap();
// Fetch the transaction from the database
let raw_tx: Vec<_> = db_write
.wallet_db
let raw_tx: Vec<_> = db_data
.conn
.query_row(
"SELECT raw FROM transactions
@ -951,7 +956,7 @@ mod tests {
// Send some of the funds to another address, keeping history.
// The recipient output is decryptable by the sender.
let (_, recovered_to, _) =
send_and_recover_with_policy(&mut db_write, OvkPolicy::Sender).unwrap();
send_and_recover_with_policy(&mut db_data, OvkPolicy::Sender).unwrap();
assert_eq!(&recovered_to, &addr2);
// Mine blocks SAPLING_ACTIVATION_HEIGHT + 1 to 42 (that don't send us funds)
@ -966,11 +971,11 @@ mod tests {
);
insert_into_cache(&db_cache, &cb);
}
scan_cached_blocks(&network, &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&network, &db_cache, &mut db_data, None).unwrap();
// Send the funds again, discarding history.
// Neither transaction output is decryptable by the sender.
assert!(send_and_recover_with_policy(&mut db_write, OvkPolicy::Discard).is_none());
assert!(send_and_recover_with_policy(&mut db_data, OvkPolicy::Discard).is_none());
}
#[test]
@ -984,9 +989,8 @@ mod tests {
init_wallet_db(&mut db_data, None).unwrap();
// Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = ops.create_account(&seed).unwrap();
let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
// Add funds to the wallet in a single note
@ -999,21 +1003,23 @@ mod tests {
value,
);
insert_into_cache(&db_cache, &cb);
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Verified balance matches total balance
let (_, anchor_height) = db_data.get_target_and_anchor_heights(10).unwrap().unwrap();
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value);
assert_eq!(
get_balance_at(&db_data, AccountId::from(0), anchor_height).unwrap(),
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
);
assert_eq!(
get_balance_at(&db_data.conn, AccountId::from(0), anchor_height).unwrap(),
value
);
let to = TransparentAddress::PublicKey([7; 20]).into();
assert_matches!(
create_spend_to_address(
&mut db_write,
&mut db_data,
&tests::network(),
test_prover(),
&usk,
@ -1038,9 +1044,8 @@ mod tests {
init_wallet_db(&mut db_data, None).unwrap();
// Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = ops.create_account(&seed).unwrap();
let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
// Add funds to the wallet in a single note
@ -1053,21 +1058,23 @@ mod tests {
value,
);
insert_into_cache(&db_cache, &cb);
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Verified balance matches total balance
let (_, anchor_height) = db_data.get_target_and_anchor_heights(10).unwrap().unwrap();
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), value);
assert_eq!(
get_balance_at(&db_data, AccountId::from(0), anchor_height).unwrap(),
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
);
assert_eq!(
get_balance_at(&db_data.conn, AccountId::from(0), anchor_height).unwrap(),
value
);
let to = TransparentAddress::PublicKey([7; 20]).into();
assert_matches!(
create_spend_to_address(
&mut db_write,
&mut db_data,
&tests::network(),
test_prover(),
&usk,
@ -1092,9 +1099,8 @@ mod tests {
init_wallet_db(&mut db_data, None).unwrap();
// Add an account to the wallet
let mut ops = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = ops.create_account(&seed).unwrap();
let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
// Add funds to the wallet
@ -1119,15 +1125,17 @@ mod tests {
insert_into_cache(&db_cache, &cb);
}
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
// Verified balance matches total balance
let total = Amount::from_u64(60000).unwrap();
let (_, anchor_height) = db_data.get_target_and_anchor_heights(1).unwrap().unwrap();
assert_eq!(get_balance(&db_data, AccountId::from(0)).unwrap(), total);
assert_eq!(
get_balance_at(&db_data, AccountId::from(0), anchor_height).unwrap(),
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
total
);
assert_eq!(
get_balance_at(&db_data.conn, AccountId::from(0), anchor_height).unwrap(),
total
);
@ -1149,7 +1157,7 @@ mod tests {
assert_matches!(
spend(
&mut db_write,
&mut db_data,
&tests::network(),
test_prover(),
&input_selector,
@ -1177,7 +1185,7 @@ mod tests {
assert_matches!(
spend(
&mut db_write,
&mut db_data,
&tests::network(),
test_prover(),
&input_selector,
@ -1202,9 +1210,8 @@ mod tests {
init_wallet_db(&mut db_data, None).unwrap();
// Add an account to the wallet
let mut db_write = db_data.get_update_ops().unwrap();
let seed = Secret::new([0u8; 32].to_vec());
let (account_id, usk) = db_write.create_account(&seed).unwrap();
let (account_id, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
let uaddr = db_data.get_current_address(account_id).unwrap().unwrap();
let taddr = uaddr.transparent().unwrap();
@ -1219,7 +1226,7 @@ mod tests {
)
.unwrap();
let res0 = db_write.put_received_transparent_utxo(&utxo);
let res0 = db_data.put_received_transparent_utxo(&utxo);
assert!(matches!(res0, Ok(_)));
let input_selector = GreedyInputSelector::new(
@ -1236,11 +1243,11 @@ mod tests {
Amount::from_u64(50000).unwrap(),
);
insert_into_cache(&db_cache, &cb);
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None).unwrap();
assert_matches!(
shield_transparent_funds(
&mut db_write,
&mut db_data,
&tests::network(),
test_prover(),
&input_selector,