Improve wallet tests.

Signed-off-by: Daira Emma Hopwood <daira@jacaranda.org>
This commit is contained in:
Daira Emma Hopwood 2023-09-06 19:50:30 +01:00
parent 46e2c4c800
commit 401af23484
4 changed files with 193 additions and 51 deletions

View File

@ -35,7 +35,7 @@ use zcash_client_backend::{
wallet::OvkPolicy,
zip321,
};
use zcash_note_encryption::Domain;
use zcash_note_encryption::{Domain, COMPACT_NOTE_SIZE};
use zcash_primitives::{
block::BlockHash,
consensus::{self, BlockHeight, Network, NetworkUpgrade, Parameters},
@ -52,7 +52,7 @@ use zcash_primitives::{
Amount,
},
fees::FeeRule,
TxId,
Transaction, TxId,
},
zip32::{sapling::DiversifiableFullViewingKey, DiversifierIndex},
};
@ -266,6 +266,34 @@ where
(height, res)
}
/// Creates a fake block at the expected next height containing only the given
/// transaction, and inserts it into the cache.
/// This assumes that the transaction only has Sapling spends and outputs.
///
/// This generated block will be treated as the latest block, and subsequent calls to
/// [`Self::generate_next_block`] will build on it.
pub(crate) fn generate_next_block_from_tx(
&mut self,
tx: &Transaction,
) -> (BlockHeight, Cache::InsertResult) {
let (height, prev_hash, initial_sapling_tree_size) = self
.latest_cached_block
.map(|(prev_height, prev_hash, end_size)| (prev_height + 1, prev_hash, end_size))
.unwrap_or_else(|| (self.sapling_activation_height(), BlockHash([0; 32]), 0));
let cb = fake_compact_block_from_tx(height, prev_hash, tx, initial_sapling_tree_size);
let res = self.cache.insert(&cb);
self.latest_cached_block = Some((
height,
cb.hash(),
initial_sapling_tree_size
+ cb.vtx.iter().map(|tx| tx.outputs.len() as u32).sum::<u32>(),
));
(height, res)
}
/// Invokes [`scan_cached_blocks`] with the given arguments, expecting success.
pub(crate) fn scan_cached_blocks(&mut self, from_height: BlockHeight, limit: usize) {
assert_matches!(self.try_scan_cached_blocks(from_height, limit), Ok(_));
@ -663,6 +691,38 @@ pub(crate) fn fake_compact_block<P: consensus::Parameters>(
(cb, note.nf(&dfvk.fvk().vk.nk, 0))
}
/// Create a fake CompactBlock at the given height containing only the given transaction.
/// This assumes that the transaction only has Sapling spends and outputs.
pub(crate) fn fake_compact_block_from_tx(
height: BlockHeight,
prev_hash: BlockHash,
tx: &Transaction,
initial_sapling_tree_size: u32,
) -> CompactBlock {
// Create a fake CompactTx
let mut ctx = CompactTx {
hash: tx.txid().as_ref().to_vec(),
..Default::default()
};
if let Some(bundle) = tx.sapling_bundle() {
for spend in bundle.shielded_spends() {
ctx.spends.push(CompactSaplingSpend {
nf: spend.nullifier().to_vec(),
});
}
for output in bundle.shielded_outputs() {
ctx.outputs.push(CompactSaplingOutput {
cmu: output.cmu().to_bytes().to_vec(),
ephemeral_key: output.ephemeral_key().0.to_vec(),
ciphertext: output.enc_ciphertext()[..COMPACT_NOTE_SIZE].to_vec(),
});
}
}
fake_compact_block_from_compact_tx(ctx, height, prev_hash, initial_sapling_tree_size)
}
/// Create a fake CompactBlock at the given height, spending a single note from the
/// given address.
#[allow(clippy::too_many_arguments)]
@ -737,6 +797,16 @@ pub(crate) fn fake_compact_block_spending<P: consensus::Parameters>(
}
});
fake_compact_block_from_compact_tx(ctx, height, prev_hash, initial_sapling_tree_size)
}
pub(crate) fn fake_compact_block_from_compact_tx(
ctx: CompactTx,
height: BlockHeight,
prev_hash: BlockHash,
initial_sapling_tree_size: u32,
) -> CompactBlock {
let mut rng = OsRng;
let mut cb = CompactBlock {
hash: {
let mut hash = vec![0; 32];

View File

@ -512,8 +512,10 @@ pub(crate) mod tests {
let (h, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value.into());
st.scan_cached_blocks(h, 1);
// Verified balance matches total balance
// Spendable balance matches total balance
assert_eq!(st.get_total_balance(account), value);
assert_eq!(st.get_spendable_balance(account, 1), value);
assert_eq!(
block_max_scanned(&st.wallet().conn)
.unwrap()
@ -709,11 +711,16 @@ pub(crate) mod tests {
let (h1, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value.into());
st.scan_cached_blocks(h1, 1);
// Verified balance matches total balance
// Spendable balance matches total balance at 1 confirmation.
assert_eq!(st.get_total_balance(account), value);
assert_eq!(st.get_spendable_balance(account, 1), value);
// Value is considered pending
// Value is considered pending at 10 confirmations.
assert_eq!(st.get_pending_shielded_balance(account, 10), value);
assert_eq!(
st.get_spendable_balance(account, 10),
NonNegativeAmount::ZERO
);
// Wallet is fully scanned
let summary = st.get_wallet_summary(1);
@ -766,7 +773,10 @@ pub(crate) mod tests {
}
st.scan_cached_blocks(h2 + 1, 8);
// Second spend still fails
// Total balance is value * number of blocks scanned (10).
assert_eq!(st.get_total_balance(account), (value * 10).unwrap());
// Spend still fails
assert_matches!(
st.create_spend_to_address(
&usk,
@ -788,17 +798,38 @@ pub(crate) mod tests {
let (h11, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value.into());
st.scan_cached_blocks(h11, 1);
// Second spend should now succeed
assert_matches!(
st.create_spend_to_address(
// Total balance is value * number of blocks scanned (11).
assert_eq!(st.get_total_balance(account), (value * 11).unwrap());
// Spendable balance at 10 confirmations is value * 2.
assert_eq!(st.get_spendable_balance(account, 10), (value * 2).unwrap());
assert_eq!(
st.get_pending_shielded_balance(account, 10),
(value * 9).unwrap()
);
// Spend should now succeed
let amount_sent = NonNegativeAmount::from_u64(70000).unwrap();
let txid = st
.create_spend_to_address(
&usk,
&to,
Amount::from_u64(70000).unwrap(),
amount_sent.into(),
None,
OvkPolicy::Sender,
NonZeroU32::new(10).unwrap(),
),
Ok(_)
)
.unwrap();
let tx = &st.wallet().get_transaction(txid).unwrap();
let (h, _) = st.generate_next_block_from_tx(tx);
st.scan_cached_blocks(h, 1);
// TODO: send to an account so that we can check its balance.
assert_eq!(
st.get_total_balance(account),
((value * 11).unwrap()
- (amount_sent + NonNegativeAmount::from_u64(10000).unwrap()).unwrap())
.unwrap()
);
}
@ -816,9 +847,12 @@ pub(crate) mod tests {
let value = NonNegativeAmount::from_u64(50000).unwrap();
let (h1, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value.into());
st.scan_cached_blocks(h1, 1);
assert_eq!(st.get_total_balance(account), value);
// Send some of the funds to another address
// Spendable balance matches total balance at 1 confirmation.
assert_eq!(st.get_total_balance(account), value);
assert_eq!(st.get_spendable_balance(account, 1), value);
// Send some of the funds to another address, but don't mine the tx.
let extsk2 = ExtendedSpendingKey::master(&[]);
let to = extsk2.default_address().1.into();
assert_matches!(
@ -886,16 +920,33 @@ pub(crate) mod tests {
);
st.scan_cached_blocks(h43, 1);
// Spendable balance matches total balance at 1 confirmation.
assert_eq!(st.get_total_balance(account), value);
assert_eq!(st.get_spendable_balance(account, 1), value);
// Second spend should now succeed
st.create_spend_to_address(
&usk,
&to,
Amount::from_u64(2000).unwrap(),
None,
OvkPolicy::Sender,
NonZeroU32::new(1).unwrap(),
)
.unwrap();
let amount_sent2 = NonNegativeAmount::from_u64(2000).unwrap();
let txid2 = st
.create_spend_to_address(
&usk,
&to,
amount_sent2.into(),
None,
OvkPolicy::Sender,
NonZeroU32::new(1).unwrap(),
)
.unwrap();
let tx2 = &st.wallet().get_transaction(txid2).unwrap();
let (h, _) = st.generate_next_block_from_tx(tx2);
st.scan_cached_blocks(h, 1);
// TODO: send to an account so that we can check its balance.
assert_eq!(
st.get_total_balance(account),
(value - (amount_sent2 + NonNegativeAmount::from_u64(10000).unwrap()).unwrap())
.unwrap()
);
}
#[test]
@ -912,7 +963,10 @@ pub(crate) mod tests {
let value = NonNegativeAmount::from_u64(50000).unwrap();
let (h1, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value.into());
st.scan_cached_blocks(h1, 1);
// Spendable balance matches total balance at 1 confirmation.
assert_eq!(st.get_total_balance(account), value);
assert_eq!(st.get_spendable_balance(account, 1), value);
let extsk2 = ExtendedSpendingKey::master(&[]);
let addr2 = extsk2.default_address().1;
@ -1007,16 +1061,15 @@ pub(crate) mod tests {
let dfvk = st.test_account_sapling().unwrap();
// Add funds to the wallet in a single note
let value = Amount::from_u64(60000).unwrap();
let (h, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value);
let value = NonNegativeAmount::from_u64(60000).unwrap();
let (h, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value.into());
st.scan_cached_blocks(h, 1);
// Verified balance matches total balance
assert_eq!(
st.get_total_balance(account),
NonNegativeAmount::try_from(value).unwrap()
);
// Spendable balance matches total balance at 1 confirmation.
assert_eq!(st.get_total_balance(account), value);
assert_eq!(st.get_spendable_balance(account, 1), value);
// TODO: generate_next_block_from_tx does not currently support transparent outputs.
let to = TransparentAddress::PublicKey([7; 20]).into();
assert_matches!(
st.create_spend_to_address(
@ -1042,22 +1095,22 @@ pub(crate) mod tests {
let dfvk = st.test_account_sapling().unwrap();
// Add funds to the wallet in a single note owned by the internal spending key
let value = Amount::from_u64(60000).unwrap();
let (h, _, _) = st.generate_next_block(&dfvk, AddressType::Internal, value);
let value = NonNegativeAmount::from_u64(60000).unwrap();
let (h, _, _) = st.generate_next_block(&dfvk, AddressType::Internal, value.into());
st.scan_cached_blocks(h, 1);
// Verified balance matches total balance
// Spendable balance matches total balance at 1 confirmation.
assert_eq!(st.get_total_balance(account), value);
assert_eq!(st.get_spendable_balance(account, 1), value);
// Value is considered pending at 10 confirmations.
assert_eq!(st.get_pending_shielded_balance(account, 10), value);
assert_eq!(
st.get_total_balance(account),
NonNegativeAmount::try_from(value).unwrap()
);
// the balance is considered pending
assert_eq!(
st.get_pending_shielded_balance(account, 10),
NonNegativeAmount::try_from(value).unwrap()
st.get_spendable_balance(account, 10),
NonNegativeAmount::ZERO
);
// TODO: generate_next_block_from_tx does not currently support transparent outputs.
let to = TransparentAddress::PublicKey([7; 20]).into();
assert_matches!(
st.create_spend_to_address(
@ -1100,12 +1153,10 @@ pub(crate) mod tests {
st.scan_cached_blocks(h1, 11);
// Verified balance matches total balance
let total = Amount::from_u64(60000).unwrap();
assert_eq!(
st.get_total_balance(account),
NonNegativeAmount::try_from(total).unwrap()
);
// Spendable balance matches total balance
let total = NonNegativeAmount::from_u64(60000).unwrap();
assert_eq!(st.get_total_balance(account), total);
assert_eq!(st.get_spendable_balance(account, 1), total);
let input_selector = GreedyInputSelector::new(
zip317::SingleOutputChangeStrategy::new(Zip317FeeRule::standard()),
@ -1148,15 +1199,26 @@ pub(crate) mod tests {
}])
.unwrap();
assert_matches!(
st.spend(
let txid = st
.spend(
&input_selector,
&usk,
req,
OvkPolicy::Sender,
NonZeroU32::new(1).unwrap(),
),
Ok(_)
)
.unwrap();
let tx = &st.wallet().get_transaction(txid).unwrap();
let (h, _) = st.generate_next_block_from_tx(tx);
st.scan_cached_blocks(h, 1);
// TODO: send to an account so that we can check its balance.
// We sent back to the same account so the amount_sent should be included
// in the total balance.
assert_eq!(
st.get_total_balance(account),
(total - NonNegativeAmount::from_u64(10000).unwrap()).unwrap()
);
}

View File

@ -6,6 +6,8 @@ and this library adheres to Rust's notion of
[Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Added
- Trait implementation `Mul<usize>` for `NonNegativeAmount`.
## [0.13.0-rc.1] - 2023-09-08
### Added

View File

@ -306,6 +306,14 @@ impl Sub<NonNegativeAmount> for Option<NonNegativeAmount> {
}
}
impl Mul<usize> for NonNegativeAmount {
type Output = Option<Self>;
fn mul(self, rhs: usize) -> Option<NonNegativeAmount> {
(self.0 * rhs).map(NonNegativeAmount)
}
}
/// A type for balance violations in amount addition and subtraction
/// (overflow and underflow of allowed ranges)
#[derive(Copy, Clone, Debug, PartialEq, Eq)]