Add several tests

This commit is contained in:
Andrew Arnott 2024-06-18 11:12:05 -06:00
parent d27bf4fc12
commit b075636e86
No known key found for this signature in database
GPG Key ID: 251505B99C25745D
3 changed files with 196 additions and 44 deletions

View File

@ -612,13 +612,13 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
_spending_key_available: bool,
) -> Result<Self::Account, Self::Error> {
self.transactionally(|wdb| {
Ok(wallet::add_account(
wallet::add_account(
wdb.conn.0,
&wdb.params,
AccountSource::Imported,
wallet::ViewingKey::Full(Box::new(ufvk.to_owned())),
birthday,
)?)
)
})
}
@ -1909,11 +1909,14 @@ extern crate assert_matches;
#[cfg(test)]
mod tests {
use secrecy::SecretVec;
use zcash_client_backend::data_api::{WalletRead, WalletWrite};
use secrecy::{Secret, SecretVec};
use zcash_client_backend::data_api::{
chain::ChainState, Account, AccountBirthday, AccountSource, WalletRead, WalletWrite,
};
use zcash_keys::keys::UnifiedSpendingKey;
use zcash_primitives::block::BlockHash;
use crate::{testing::TestBuilder, AccountId, DEFAULT_UA_REQUEST};
use crate::{error::SqliteClientError, testing::TestBuilder, AccountId, DEFAULT_UA_REQUEST};
#[cfg(feature = "unstable")]
use {
@ -1977,6 +1980,111 @@ mod tests {
assert_eq!(addr2, addr2_cur);
}
#[test]
pub(crate) fn import_account_hd_0() {
let st = TestBuilder::new()
.with_account_from_sapling_activation(BlockHash([0; 32]))
.with_account_having_index(zip32::AccountId::ZERO)
.build();
assert_matches!(
st.test_account().unwrap().account().source(),
AccountSource::Derived { seed_fingerprint: _, account_index } if account_index == zip32::AccountId::ZERO);
}
#[test]
pub(crate) fn import_account_hd_1_then_2() {
let mut st = TestBuilder::new().build();
let birthday = AccountBirthday::from_parts(
ChainState::empty(st.wallet().params.sapling.unwrap() - 1, BlockHash([0; 32])),
None,
);
let seed = Secret::new(vec![0u8; 32]);
let zip32_index = zip32::AccountId::ZERO.next().unwrap();
let first = st
.wallet_mut()
.import_account_hd(&seed, zip32_index, &birthday)
.unwrap();
assert_matches!(
first.0.source(),
AccountSource::Derived { seed_fingerprint: _, account_index } if account_index == zip32_index);
let second = st
.wallet_mut()
.import_account_hd(&seed, zip32_index.next().unwrap(), &birthday)
.unwrap();
assert_matches!(
second.0.source(),
AccountSource::Derived { seed_fingerprint: _, account_index } if account_index == zip32_index.next().unwrap());
}
#[test]
pub(crate) fn import_account_hd_1_twice() {
let mut st = TestBuilder::new().build();
let birthday = AccountBirthday::from_parts(
ChainState::empty(st.wallet().params.sapling.unwrap() - 1, BlockHash([0; 32])),
None,
);
let seed = Secret::new(vec![0u8; 32]);
let zip32_index = zip32::AccountId::ZERO.next().unwrap();
let first = st
.wallet_mut()
.import_account_hd(&seed, zip32_index, &birthday)
.unwrap();
assert_matches!(
st.wallet_mut().import_account_hd(&seed, zip32_index, &birthday),
Err(SqliteClientError::AccountCollision(id)) if id == first.0.id());
}
#[test]
pub(crate) fn import_account_ufvk() {
let mut st = TestBuilder::new().build();
let birthday = AccountBirthday::from_parts(
ChainState::empty(st.wallet().params.sapling.unwrap() - 1, BlockHash([0; 32])),
None,
);
let seed = vec![0u8; 32];
let usk = UnifiedSpendingKey::from_seed(&st.wallet().params, &seed, zip32::AccountId::ZERO)
.unwrap();
let ufvk = usk.to_unified_full_viewing_key();
let account = st
.wallet_mut()
.import_account_ufvk(&ufvk, &birthday, true)
.unwrap();
assert_eq!(
ufvk.encode(&st.wallet().params),
account.ufvk().unwrap().encode(&st.wallet().params)
);
}
#[test]
pub(crate) fn create_account_then_conflicting_import_account_ufvk() {
let mut st = TestBuilder::new().build();
let birthday = AccountBirthday::from_parts(
ChainState::empty(st.wallet().params.sapling.unwrap() - 1, BlockHash([0; 32])),
None,
);
let seed = Secret::new(vec![0u8; 32]);
let seed_based = st.wallet_mut().create_account(&seed, &birthday).unwrap();
let seed_based_account = st.wallet().get_account(seed_based.0).unwrap().unwrap();
let ufvk = seed_based_account.ufvk().unwrap();
assert_matches!(
st.wallet_mut().import_account_ufvk(ufvk, &birthday, true),
Err(SqliteClientError::AccountCollision(id)) if id == seed_based.0);
}
#[cfg(feature = "transparent-inputs")]
#[test]
fn transparent_receivers() {

View File

@ -26,6 +26,7 @@ use sapling::{
zip32::DiversifiableFullViewingKey,
Note, Nullifier,
};
use zcash_client_backend::data_api::Account as AccountTrait;
#[allow(deprecated)]
use zcash_client_backend::{
address::Address,
@ -74,7 +75,7 @@ use crate::{
error::SqliteClientError,
wallet::{
commitment_tree, get_wallet_summary, init::init_wallet_db, sapling::tests::test_prover,
SubtreeScanProgress,
Account, SubtreeScanProgress,
},
AccountId, ReceivedNoteId, WalletDb,
};
@ -117,6 +118,7 @@ pub(crate) struct TestBuilder<Cache> {
cache: Cache,
initial_chain_state: Option<InitialChainState>,
account_birthday: Option<AccountBirthday>,
account_index: Option<zip32::AccountId>,
}
impl TestBuilder<()> {
@ -143,6 +145,7 @@ impl TestBuilder<()> {
cache: (),
initial_chain_state: None,
account_birthday: None,
account_index: None,
}
}
@ -154,6 +157,7 @@ impl TestBuilder<()> {
cache: BlockCache::new(),
initial_chain_state: self.initial_chain_state,
account_birthday: self.account_birthday,
account_index: self.account_index,
}
}
@ -166,6 +170,7 @@ impl TestBuilder<()> {
cache: FsBlockCache::new(),
initial_chain_state: self.initial_chain_state,
account_birthday: self.account_birthday,
account_index: self.account_index,
}
}
}
@ -227,6 +232,12 @@ impl<Cache> TestBuilder<Cache> {
self
}
pub(crate) fn with_account_having_index(mut self, index: zip32::AccountId) -> Self {
assert!(self.account_index.is_none());
self.account_index = Some(index);
self
}
/// Builds the state for this test.
pub(crate) fn build(self) -> TestState<Cache> {
let data_file = NamedTempFile::new().unwrap();
@ -288,11 +299,17 @@ impl<Cache> TestBuilder<Cache> {
let test_account = self.account_birthday.map(|birthday| {
let seed = Secret::new(vec![0u8; 32]);
let (account_id, usk) = db_data.create_account(&seed, &birthday).unwrap();
let (account, usk) = match self.account_index {
Some(index) => db_data.import_account_hd(&seed, index, &birthday).unwrap(),
None => {
let result = db_data.create_account(&seed, &birthday).unwrap();
(db_data.get_account(result.0).unwrap().unwrap(), result.1)
}
};
(
seed,
TestAccount {
account_id,
account,
usk,
birthday,
},
@ -394,14 +411,18 @@ impl CachedBlock {
#[derive(Clone)]
pub(crate) struct TestAccount {
account_id: AccountId,
account: Account,
usk: UnifiedSpendingKey,
birthday: AccountBirthday,
}
impl TestAccount {
pub(crate) fn account(&self) -> &Account {
&self.account
}
pub(crate) fn account_id(&self) -> AccountId {
self.account_id
self.account.id()
}
pub(crate) fn usk(&self) -> &UnifiedSpendingKey {

View File

@ -65,7 +65,7 @@
//! - `memo` the shielded memo associated with the output, if any.
use incrementalmerkletree::Retention;
use rusqlite::{self, named_params, OptionalExtension};
use rusqlite::{self, named_params, params, OptionalExtension};
use secrecy::{ExposeSecret, SecretVec};
use shardtree::{error::ShardTreeError, store::ShardStore, ShardTree};
use zip32::fingerprint::SeedFingerprint;
@ -390,40 +390,63 @@ pub(crate) fn add_account<P: consensus::Parameters>(
#[cfg(not(feature = "orchard"))]
let birthday_orchard_tree_size: Option<u64> = None;
let account_id: AccountId = conn.query_row(
r#"
INSERT INTO accounts (
account_kind, hd_seed_fingerprint, hd_account_index,
ufvk, uivk,
orchard_fvk_item_cache, sapling_fvk_item_cache, p2pkh_fvk_item_cache,
birthday_height, birthday_sapling_tree_size, birthday_orchard_tree_size,
recover_until_height
let ufvk_encoded = viewing_key.ufvk().map(|ufvk| ufvk.encode(params));
let account_id: AccountId = conn
.query_row(
r#"
INSERT INTO accounts (
account_kind, hd_seed_fingerprint, hd_account_index,
ufvk, uivk,
orchard_fvk_item_cache, sapling_fvk_item_cache, p2pkh_fvk_item_cache,
birthday_height, birthday_sapling_tree_size, birthday_orchard_tree_size,
recover_until_height
)
VALUES (
:account_kind, :hd_seed_fingerprint, :hd_account_index,
:ufvk, :uivk,
:orchard_fvk_item_cache, :sapling_fvk_item_cache, :p2pkh_fvk_item_cache,
:birthday_height, :birthday_sapling_tree_size, :birthday_orchard_tree_size,
:recover_until_height
)
RETURNING id;
"#,
named_params![
":account_kind": account_kind_code(kind),
":hd_seed_fingerprint": hd_seed_fingerprint.as_ref().map(|fp| fp.to_bytes()),
":hd_account_index": hd_account_index.map(u32::from),
":ufvk": ufvk_encoded,
":uivk": viewing_key.uivk().encode(params),
":orchard_fvk_item_cache": orchard_item,
":sapling_fvk_item_cache": sapling_item,
":p2pkh_fvk_item_cache": transparent_item,
":birthday_height": u32::from(birthday.height()),
":birthday_sapling_tree_size": birthday_sapling_tree_size,
":birthday_orchard_tree_size": birthday_orchard_tree_size,
":recover_until_height": birthday.recover_until().map(u32::from)
],
|row| Ok(AccountId(row.get(0)?)),
)
VALUES (
:account_kind, :hd_seed_fingerprint, :hd_account_index,
:ufvk, :uivk,
:orchard_fvk_item_cache, :sapling_fvk_item_cache, :p2pkh_fvk_item_cache,
:birthday_height, :birthday_sapling_tree_size, :birthday_orchard_tree_size,
:recover_until_height
)
RETURNING id;
"#,
named_params![
":account_kind": account_kind_code(kind),
":hd_seed_fingerprint": hd_seed_fingerprint.as_ref().map(|fp| fp.to_bytes()),
":hd_account_index": hd_account_index.map(u32::from),
":ufvk": viewing_key.ufvk().map(|ufvk| ufvk.encode(params)),
":uivk": viewing_key.uivk().encode(params),
":orchard_fvk_item_cache": orchard_item,
":sapling_fvk_item_cache": sapling_item,
":p2pkh_fvk_item_cache": transparent_item,
":birthday_height": u32::from(birthday.height()),
":birthday_sapling_tree_size": birthday_sapling_tree_size,
":birthday_orchard_tree_size": birthday_orchard_tree_size,
":recover_until_height": birthday.recover_until().map(u32::from)
],
|row| Ok(AccountId(row.get(0)?)),
)?;
.map_err(|e| match e {
rusqlite::Error::SqliteFailure(f, s)
if f.code == rusqlite::ErrorCode::ConstraintViolation =>
{
// An account conflict occurred.
// Make a best effort to determine the AccountId of the pre-existing row
// and provide that to our caller.
if s.clone().is_some_and(|s| s.contains(".ufvk")) {
if let Ok(id) = conn.query_row(
"SELECT id FROM accounts WHERE ufvk = ?",
params![ufvk_encoded],
|row| Ok(AccountId(row.get(0)?)),
) {
return SqliteClientError::AccountCollision(id);
}
}
SqliteClientError::from(rusqlite::Error::SqliteFailure(f, s))
}
_ => SqliteClientError::from(e),
})?;
let account = Account {
account_id,