zcash_client_sqlite: Generalize more `TestState` operations.

This commit is contained in:
Kris Nuttycombe 2024-09-05 17:31:41 -06:00
parent acd26d5d53
commit 58b464d102
3 changed files with 117 additions and 90 deletions

View File

@ -55,6 +55,7 @@ pub struct TransactionSummary<AccountId> {
}
impl<AccountId> TransactionSummary<AccountId> {
#[allow(clippy::too_many_arguments)]
pub fn new(
account_id: AccountId,
txid: TxId,

View File

@ -615,7 +615,7 @@ impl<C: Borrow<rusqlite::Connection>, P: consensus::Parameters> WalletRead for W
Ok(iter.collect())
}
#[cfg(feature = "test-dependencies")]
#[cfg(any(test, feature = "test-dependencies"))]
fn get_tx_history(&self) -> Result<Vec<TransactionSummary<Self::AccountId>>, Self::Error> {
wallet::testing::get_tx_history(self.conn.borrow())
}

View File

@ -11,7 +11,7 @@ use nonempty::NonEmpty;
use prost::Message;
use rand_chacha::ChaChaRng;
use rand_core::{CryptoRng, RngCore, SeedableRng};
use rusqlite::{params, Connection};
use rusqlite::params;
use secrecy::{Secret, SecretVec};
use shardtree::error::ShardTreeError;
@ -75,12 +75,12 @@ use zcash_protocol::value::Zatoshis;
use crate::{
chain::init::init_cache_database,
error::SqliteClientError,
wallet::{
commitment_tree, get_wallet_summary, sapling::tests::test_prover, SubtreeScanProgress,
},
AccountId, ReceivedNoteId, WalletDb,
wallet::{get_wallet_summary, sapling::tests::test_prover, SubtreeScanProgress},
AccountId, ReceivedNoteId,
};
use self::db::TestDb;
use super::BlockDb;
#[cfg(feature = "orchard")]
@ -859,7 +859,7 @@ where
Cache: TestCache,
<Cache::BlockSource as BlockSource>::Error: fmt::Debug,
ParamsT: consensus::Parameters + Send + 'static,
DbT: WalletWrite,
DbT: InputSource + WalletWrite + WalletCommitmentTrees,
<DbT as WalletRead>::AccountId: ConditionallySelectable + Default + Send + 'static,
{
/// Invokes [`scan_cached_blocks`] with the given arguments, expecting success.
@ -880,7 +880,10 @@ where
limit: usize,
) -> Result<
ScanSummary,
data_api::chain::error::Error<DbT::Error, <Cache::BlockSource as BlockSource>::Error>,
data_api::chain::error::Error<
<DbT as WalletRead>::Error,
<Cache::BlockSource as BlockSource>::Error,
>,
> {
let prior_cached_block = self
.latest_cached_block_below_height(from_height)
@ -897,41 +900,7 @@ where
);
result
}
}
impl<Cache, DbT: WalletRead + Reset> TestState<Cache, DbT, LocalNetwork> {
/// Resets the wallet using a new wallet database but with the same cache of blocks,
/// and returns the old wallet database file.
///
/// This does not recreate accounts, nor does it rescan the cached blocks.
/// The resulting wallet has no test account.
/// Before using any `generate_*` method on the reset state, call `reset_latest_cached_block()`.
pub(crate) fn reset(&mut self) -> DbT::Handle {
self.latest_block_height = None;
self.test_account = None;
DbT::reset(self)
}
// /// Reset the latest cached block to the most recent one in the cache database.
// #[allow(dead_code)]
// pub(crate) fn reset_latest_cached_block(&mut self) {
// self.cache
// .block_source()
// .with_blocks::<_, Infallible>(None, None, |block: CompactBlock| {
// let chain_metadata = block.chain_metadata.unwrap();
// self.latest_cached_block = Some(CachedBlock::at(
// BlockHash::from_slice(block.hash.as_slice()),
// BlockHeight::from_u32(block.height.try_into().unwrap()),
// chain_metadata.sapling_commitment_tree_size,
// chain_metadata.orchard_commitment_tree_size,
// ));
// Ok(())
// })
// .unwrap();
// }
}
impl<Cache> TestState<Cache, db::TestDb, LocalNetwork> {
/// Insert shard roots for both trees.
pub(crate) fn put_subtree_roots(
&mut self,
@ -939,7 +908,7 @@ impl<Cache> TestState<Cache, db::TestDb, LocalNetwork> {
sapling_roots: &[CommitmentTreeRoot<sapling::Node>],
#[cfg(feature = "orchard")] orchard_start_index: u64,
#[cfg(feature = "orchard")] orchard_roots: &[CommitmentTreeRoot<MerkleHashOrchard>],
) -> Result<(), ShardTreeError<commitment_tree::Error>> {
) -> Result<(), ShardTreeError<<DbT as WalletCommitmentTrees>::Error>> {
self.wallet_mut()
.put_sapling_subtree_roots(sapling_start_index, sapling_roots)?;
@ -949,7 +918,18 @@ impl<Cache> TestState<Cache, db::TestDb, LocalNetwork> {
Ok(())
}
}
impl<Cache, DbT, ParamsT, AccountIdT, ErrT> TestState<Cache, DbT, ParamsT>
where
ParamsT: consensus::Parameters + Send + 'static,
AccountIdT: std::cmp::Eq + std::hash::Hash,
ErrT: std::fmt::Debug,
DbT: InputSource<AccountId = AccountIdT, Error = ErrT>
+ WalletWrite<AccountId = AccountIdT, Error = ErrT>
+ WalletCommitmentTrees,
<DbT as WalletRead>::AccountId: ConditionallySelectable + Default + Send + 'static,
{
/// Invokes [`create_spend_to_address`] with the given arguments.
#[allow(deprecated)]
#[allow(clippy::type_complexity)]
@ -967,16 +947,17 @@ impl<Cache> TestState<Cache, db::TestDb, LocalNetwork> {
) -> Result<
NonEmpty<TxId>,
data_api::error::Error<
SqliteClientError,
commitment_tree::Error,
GreedyInputSelectorError<Zip317FeeError, ReceivedNoteId>,
ErrT,
<DbT as WalletCommitmentTrees>::Error,
GreedyInputSelectorError<Zip317FeeError, <DbT as InputSource>::NoteRef>,
Zip317FeeError,
>,
> {
let prover = test_prover();
let network = self.network().clone();
create_spend_to_address(
self.wallet_data.db_mut(),
&self.network,
self.wallet_mut(),
&network,
&prover,
&prover,
usk,
@ -1002,20 +983,21 @@ impl<Cache> TestState<Cache, db::TestDb, LocalNetwork> {
) -> Result<
NonEmpty<TxId>,
data_api::error::Error<
SqliteClientError,
commitment_tree::Error,
ErrT,
<DbT as WalletCommitmentTrees>::Error,
InputsT::Error,
<InputsT::FeeRule as FeeRule>::Error,
>,
>
where
InputsT: InputSelector<InputSource = WalletDb<Connection, LocalNetwork>>,
InputsT: InputSelector<InputSource = DbT>,
{
#![allow(deprecated)]
let prover = test_prover();
let network = self.network().clone();
spend(
self.wallet_data.db_mut(),
&self.network,
self.wallet_mut(),
&network,
&prover,
&prover,
input_selector,
@ -1030,25 +1012,26 @@ impl<Cache> TestState<Cache, db::TestDb, LocalNetwork> {
#[allow(clippy::type_complexity)]
pub(crate) fn propose_transfer<InputsT>(
&mut self,
spend_from_account: AccountId,
spend_from_account: <DbT as InputSource>::AccountId,
input_selector: &InputsT,
request: zip321::TransactionRequest,
min_confirmations: NonZeroU32,
) -> Result<
Proposal<InputsT::FeeRule, ReceivedNoteId>,
Proposal<InputsT::FeeRule, <DbT as InputSource>::NoteRef>,
data_api::error::Error<
SqliteClientError,
ErrT,
Infallible,
InputsT::Error,
<InputsT::FeeRule as FeeRule>::Error,
>,
>
where
InputsT: InputSelector<InputSource = WalletDb<Connection, LocalNetwork>>,
InputsT: InputSelector<InputSource = DbT>,
{
let network = self.network().clone();
propose_transfer::<_, _, _, Infallible>(
self.wallet_data.db_mut(),
&self.network,
self.wallet_mut(),
&network,
spend_from_account,
input_selector,
request,
@ -1061,7 +1044,7 @@ impl<Cache> TestState<Cache, db::TestDb, LocalNetwork> {
#[allow(clippy::too_many_arguments)]
pub(crate) fn propose_standard_transfer<CommitmentTreeErrT>(
&mut self,
spend_from_account: AccountId,
spend_from_account: <DbT as InputSource>::AccountId,
fee_rule: StandardFeeRule,
min_confirmations: NonZeroU32,
to: &Address,
@ -1070,17 +1053,18 @@ impl<Cache> TestState<Cache, db::TestDb, LocalNetwork> {
change_memo: Option<MemoBytes>,
fallback_change_pool: ShieldedProtocol,
) -> Result<
Proposal<StandardFeeRule, ReceivedNoteId>,
Proposal<StandardFeeRule, <DbT as InputSource>::NoteRef>,
data_api::error::Error<
SqliteClientError,
ErrT,
CommitmentTreeErrT,
GreedyInputSelectorError<Zip317FeeError, ReceivedNoteId>,
GreedyInputSelectorError<Zip317FeeError, <DbT as InputSource>::NoteRef>,
Zip317FeeError,
>,
> {
let network = self.network().clone();
let result = propose_standard_transfer_to_address::<_, _, CommitmentTreeErrT>(
self.wallet_data.db_mut(),
&self.network,
self.wallet_mut(),
&network,
fee_rule,
spend_from_account,
min_confirmations,
@ -1092,7 +1076,7 @@ impl<Cache> TestState<Cache, db::TestDb, LocalNetwork> {
);
if let Ok(proposal) = &result {
check_proposal_serialization_roundtrip(self.wallet_data.db(), proposal);
check_proposal_serialization_roundtrip(self.wallet(), proposal);
}
result
@ -1111,18 +1095,19 @@ impl<Cache> TestState<Cache, db::TestDb, LocalNetwork> {
) -> Result<
Proposal<InputsT::FeeRule, Infallible>,
data_api::error::Error<
SqliteClientError,
ErrT,
Infallible,
InputsT::Error,
<InputsT::FeeRule as FeeRule>::Error,
>,
>
where
InputsT: ShieldingSelector<InputSource = WalletDb<Connection, LocalNetwork>>,
InputsT: ShieldingSelector<InputSource = DbT>,
{
let network = self.network().clone();
propose_shielding::<_, _, _, Infallible>(
self.wallet_data.db_mut(),
&self.network,
self.wallet_mut(),
&network,
input_selector,
shielding_threshold,
from_addrs,
@ -1131,6 +1116,7 @@ impl<Cache> TestState<Cache, db::TestDb, LocalNetwork> {
}
/// Invokes [`create_proposed_transactions`] with the given arguments.
#[allow(clippy::type_complexity)]
pub(crate) fn create_proposed_transactions<InputsErrT, FeeRuleT>(
&mut self,
usk: &UnifiedSpendingKey,
@ -1139,8 +1125,8 @@ impl<Cache> TestState<Cache, db::TestDb, LocalNetwork> {
) -> Result<
NonEmpty<TxId>,
data_api::error::Error<
SqliteClientError,
commitment_tree::Error,
ErrT,
<DbT as WalletCommitmentTrees>::Error,
InputsErrT,
FeeRuleT::Error,
>,
@ -1149,9 +1135,10 @@ impl<Cache> TestState<Cache, db::TestDb, LocalNetwork> {
FeeRuleT: FeeRule,
{
let prover = test_prover();
let network = self.network().clone();
create_proposed_transactions(
self.wallet_data.db_mut(),
&self.network,
self.wallet_mut(),
&network,
&prover,
&prover,
usk,
@ -1173,19 +1160,20 @@ impl<Cache> TestState<Cache, db::TestDb, LocalNetwork> {
) -> Result<
NonEmpty<TxId>,
data_api::error::Error<
SqliteClientError,
commitment_tree::Error,
ErrT,
<DbT as WalletCommitmentTrees>::Error,
InputsT::Error,
<InputsT::FeeRule as FeeRule>::Error,
>,
>
where
InputsT: ShieldingSelector<InputSource = WalletDb<Connection, LocalNetwork>>,
InputsT: ShieldingSelector<InputSource = DbT>,
{
let prover = test_prover();
let network = self.network().clone();
shield_transparent_funds(
self.wallet_data.db_mut(),
&self.network,
self.wallet_mut(),
&network,
&prover,
&prover,
input_selector,
@ -1198,21 +1186,25 @@ impl<Cache> TestState<Cache, db::TestDb, LocalNetwork> {
fn with_account_balance<T, F: FnOnce(&AccountBalance) -> T>(
&self,
account: AccountId,
account: AccountIdT,
min_confirmations: u32,
f: F,
) -> T {
let binding = self.get_wallet_summary(min_confirmations).unwrap();
let binding = self
.wallet()
.get_wallet_summary(min_confirmations)
.unwrap()
.unwrap();
f(binding.account_balances().get(&account).unwrap())
}
pub(crate) fn get_total_balance(&self, account: AccountId) -> NonNegativeAmount {
pub(crate) fn get_total_balance(&self, account: AccountIdT) -> NonNegativeAmount {
self.with_account_balance(account, 0, |balance| balance.total())
}
pub(crate) fn get_spendable_balance(
&self,
account: AccountId,
account: AccountIdT,
min_confirmations: u32,
) -> NonNegativeAmount {
self.with_account_balance(account, min_confirmations, |balance| {
@ -1222,7 +1214,7 @@ impl<Cache> TestState<Cache, db::TestDb, LocalNetwork> {
pub(crate) fn get_pending_shielded_balance(
&self,
account: AccountId,
account: AccountIdT,
min_confirmations: u32,
) -> NonNegativeAmount {
self.with_account_balance(account, min_confirmations, |balance| {
@ -1234,14 +1226,16 @@ impl<Cache> TestState<Cache, db::TestDb, LocalNetwork> {
#[allow(dead_code)]
pub(crate) fn get_pending_change(
&self,
account: AccountId,
account: AccountIdT,
min_confirmations: u32,
) -> NonNegativeAmount {
self.with_account_balance(account, min_confirmations, |balance| {
balance.change_pending_confirmation()
})
}
}
impl<Cache> TestState<Cache, TestDb, LocalNetwork> {
pub(crate) fn get_wallet_summary(
&self,
min_confirmations: u32,
@ -1328,6 +1322,38 @@ impl<Cache> TestState<Cache, db::TestDb, LocalNetwork> {
}
}
impl<Cache, DbT: WalletRead + Reset> TestState<Cache, DbT, LocalNetwork> {
/// Resets the wallet using a new wallet database but with the same cache of blocks,
/// and returns the old wallet database file.
///
/// This does not recreate accounts, nor does it rescan the cached blocks.
/// The resulting wallet has no test account.
/// Before using any `generate_*` method on the reset state, call `reset_latest_cached_block()`.
pub(crate) fn reset(&mut self) -> DbT::Handle {
self.latest_block_height = None;
self.test_account = None;
DbT::reset(self)
}
// /// Reset the latest cached block to the most recent one in the cache database.
// #[allow(dead_code)]
// pub(crate) fn reset_latest_cached_block(&mut self) {
// self.cache
// .block_source()
// .with_blocks::<_, Infallible>(None, None, |block: CompactBlock| {
// let chain_metadata = block.chain_metadata.unwrap();
// self.latest_cached_block = Some(CachedBlock::at(
// BlockHash::from_slice(block.hash.as_slice()),
// BlockHeight::from_u32(block.height.try_into().unwrap()),
// chain_metadata.sapling_commitment_tree_size,
// chain_metadata.orchard_commitment_tree_size,
// ));
// Ok(())
// })
// .unwrap();
// }
}
// See the doc comment for `TestState::run_sqlite3` above.
//
// - `db_path` is the path to the database file.
@ -2108,11 +2134,11 @@ impl TestCache for FsBlockCache {
}
}
pub(crate) fn input_selector<P: consensus::Parameters>(
pub(crate) fn input_selector<DbT: InputSource>(
fee_rule: StandardFeeRule,
change_memo: Option<&str>,
fallback_change_pool: ShieldedProtocol,
) -> GreedyInputSelector<WalletDb<rusqlite::Connection, P>, standard::SingleOutputChangeStrategy> {
) -> GreedyInputSelector<DbT, standard::SingleOutputChangeStrategy> {
let change_memo = change_memo.map(|m| MemoBytes::from(m.parse::<Memo>().unwrap()));
let change_strategy =
standard::SingleOutputChangeStrategy::new(fee_rule, change_memo, fallback_change_pool);
@ -2121,9 +2147,9 @@ pub(crate) fn input_selector<P: consensus::Parameters>(
// Checks that a protobuf proposal serialized from the provided proposal value correctly parses to
// the same proposal value.
fn check_proposal_serialization_roundtrip<P: consensus::Parameters>(
wallet_data: &WalletDb<rusqlite::Connection, P>,
proposal: &Proposal<StandardFeeRule, ReceivedNoteId>,
fn check_proposal_serialization_roundtrip<DbT: InputSource>(
wallet_data: &DbT,
proposal: &Proposal<StandardFeeRule, DbT::NoteRef>,
) {
let proposal_proto = proposal::Proposal::from_standard_proposal(proposal);
let deserialized_proposal = proposal_proto.try_into_standard_proposal(wallet_data);