Fix test compilation with WalletWrite changes.

This commit is contained in:
Kris Nuttycombe 2021-01-11 18:13:40 -07:00
parent aad2e174c1
commit 7d92150965
12 changed files with 223 additions and 177 deletions

View File

@ -1,4 +1,5 @@
use std::cmp;
use std::collections::HashMap;
use std::fmt::Debug;
use zcash_primitives::{
@ -6,7 +7,7 @@ use zcash_primitives::{
consensus::{self, BlockHeight},
merkle_tree::{CommitmentTree, IncrementalWitness},
note_encryption::Memo,
primitives::{Note, PaymentAddress, Nullifier},
primitives::{Note, Nullifier, PaymentAddress},
sapling::Node,
transaction::{components::Amount, Transaction, TxId},
zip32::ExtendedFullViewingKey,
@ -28,19 +29,19 @@ pub mod wallet;
///
/// This trait defines the read-only portion of the storage
/// interface atop which higher-level wallet operations are
/// implemented. It serves to allow wallet functions to be
/// implemented. It serves to allow wallet functions to be
/// abstracted away from any particular data storage substrate.
pub trait WalletRead {
/// The type of errors produced by a wallet backend.
type Error;
/// Backend-specific note identifier.
/// Backend-specific note identifier.
///
/// For example, this might be a database identifier type
/// or a UUID.
type NoteRef: Copy + Debug;
type NoteRef: Copy + Debug;
/// Backend-specific transaction identifier.
/// Backend-specific transaction identifier.
///
/// For example, this might be a database identifier type
/// or a TxId if the backend is able to support that type
@ -51,7 +52,7 @@ pub trait WalletRead {
fn block_height_extrema(&self) -> Result<Option<(BlockHeight, BlockHeight)>, Self::Error>;
/// Returns the default target height and anchor height, given the
/// range of block heights that the backend knows about.
/// range of block heights that the backend knows about.
fn get_target_and_anchor_heights(
&self,
) -> Result<Option<(BlockHeight, BlockHeight)>, Self::Error> {
@ -105,10 +106,10 @@ pub trait WalletRead {
fn get_extended_full_viewing_keys<P: consensus::Parameters>(
&self,
params: &P,
) -> Result<Vec<ExtendedFullViewingKey>, Self::Error>;
) -> Result<HashMap<AccountId, ExtendedFullViewingKey>, Self::Error>;
/// Checks whether the specified extended full viewing key is a valid
/// key for the specified account.
/// Checks whether the specified extended full viewing key is
/// associated with the account.
fn is_valid_account_extfvk<P: consensus::Parameters>(
&self,
params: &P,
@ -120,13 +121,13 @@ pub trait WalletRead {
///
/// This balance amount is the raw balance of all transactions in known
/// mined blocks, irrespective of confirmation depth.
// TODO: Do we actually need this? You can always get the "verified"
// TODO: Do we actually need this? You can always get the "verified"
// balance from the current chain tip.
fn get_balance(&self, account: AccountId) -> Result<Amount, Self::Error>;
/// Returns the wallet balance for an account as of the specified block
/// height. and
///
///
/// This may be used to obtain a balance that ignores notes that have been
/// received so recently that they are not yet deemed spendable.
fn get_verified_balance(
@ -161,7 +162,7 @@ pub trait WalletRead {
fn get_nullifiers(&self) -> Result<Vec<(Nullifier, AccountId)>, Self::Error>;
/// Returns a list of spendable notes sufficient to cover the specified
/// target value, if possible.
/// target value, if possible.
fn select_spendable_notes(
&self,
account: AccountId,
@ -175,8 +176,8 @@ pub trait WalletRead {
pub trait WalletWrite: WalletRead {
/// Perform one or more write operations of this trait transactionally.
/// Implementations of this method must ensure that all mutations to the
/// state of the data store made by the provided closure must be performed
/// atomically and modifications to state must be automatically rolled back
/// state of the data store made by the provided closure must be performed
/// atomically and modifications to state must be automatically rolled back
/// if the provided closure returns an error.
fn transactionally<F, A>(&mut self, f: F) -> Result<A, Self::Error>
where
@ -217,8 +218,11 @@ pub trait WalletWrite: WalletRead {
created_at: Option<time::OffsetDateTime>,
) -> Result<Self::TxRef, Self::Error>;
/// Mark the specified transaction as spent and record the nullifier.
fn mark_spent(&mut self, tx_ref: Self::TxRef, nf: &Nullifier) -> Result<(), Self::Error>;
/// Record a note as having been received, along with its nullifier and the transaction
/// within which the note was created.
fn put_received_note<T: ShieldedOutput>(
&mut self,
output: &T,
@ -306,7 +310,7 @@ impl ShieldedOutput for DecryptedOutput {
self.index
}
fn account(&self) -> AccountId {
AccountId(self.account as u32)
self.account
}
fn to(&self) -> &PaymentAddress {
&self.to

View File

@ -9,10 +9,10 @@ use zcash_primitives::{
use crate::{
data_api::{
error::{ChainInvalid, Error},
BlockSource, WalletRead, WalletWrite,
BlockSource, WalletWrite,
},
proto::compact_formats::CompactBlock,
wallet::{WalletTx},
wallet::WalletTx,
welding_rig::scan_block,
};
@ -120,21 +120,21 @@ where
/// let cache_file = NamedTempFile::new().unwrap();
/// let cache = BlockDB::for_path(cache_file).unwrap();
/// let data_file = NamedTempFile::new().unwrap();
/// let data = WalletDB::for_path(data_file).unwrap();
/// let data = WalletDB::for_path(data_file).unwrap().get_update_ops().unwrap();
/// scan_cached_blocks(&Network::TestNetwork, &cache, &data, None);
/// ```
///
/// [`init_blocks_table`]: crate::init::init_blocks_table
pub fn scan_cached_blocks<'db, E, E0, N, P, C, D>(
pub fn scan_cached_blocks<E, E0, N, P, C, D>(
params: &P,
cache: &C,
mut data: &'db D,
data: &mut D,
limit: Option<u32>,
) -> Result<(), E>
where
P: consensus::Parameters,
C: BlockSource<Error = E>,
&'db D: WalletRead<Error = E, NoteRef = N> + WalletWrite<Error = E, NoteRef = N>,
D: WalletWrite<Error = E, NoteRef = N>,
N: Copy + Debug,
E: From<Error<E0, N>>,
{
@ -151,6 +151,7 @@ where
// Fetch the ExtendedFullViewingKeys we are tracking
let extfvks = data.get_extended_full_viewing_keys(params)?;
let ivks: Vec<_> = extfvks.values().map(|extfvk| extfvk.fvk.vk.ivk()).collect();
// Get the most recent CommitmentTree
let mut tree = data
@ -181,7 +182,7 @@ where
scan_block(
params,
block,
&extfvks[..],
&ivks,
&nullifiers,
&mut tree,
&mut witness_refs[..],
@ -226,25 +227,27 @@ where
}
// remove spent nullifiers from the nullifier set
nullifiers.retain(|(nf, _acc)| {
!tx.shielded_spends
.iter()
.any(|spend| &spend.nf == nf)
});
nullifiers
.retain(|(nf, _acc)| !tx.shielded_spends.iter().any(|spend| &spend.nf == nf));
for output in tx.shielded_outputs {
let nf = output.note.nf(
&extfvks[output.account.0 as usize].fvk.vk,
output.witness.position() as u64,
);
match &extfvks.get(&output.account) {
Some(extfvk) => {
let nf = output.note.nf(
&extfvk.fvk.vk,
output.witness.position() as u64,
);
let note_id = up.put_received_note(&output, &Some(nf), tx_row)?;
let note_id = up.put_received_note(&output, &Some(nf), tx_row)?;
// Save witness for note.
witnesses.push((note_id, output.witness));
// Save witness for note.
witnesses.push((note_id, output.witness));
// Cache nullifier for note (to detect subsequent spends in this scan).
nullifiers.push((nf, output.account));
// Cache nullifier for note (to detect subsequent spends in this scan).
nullifiers.push((nf, output.account));
}
None => ()
}
}
}

View File

@ -32,7 +32,7 @@ pub fn decrypt_and_store_transaction<'db, E0, N, E, P, D>(
where
E: From<Error<E0, N>>,
P: consensus::Parameters,
&'db D: WalletRead<Error = E> + WalletWrite<Error = E>,
&'db D: WalletWrite<Error = E>,
{
// Fetch the ExtendedFullViewingKeys we are tracking
let extfvks = data.get_extended_full_viewing_keys(params)?;
@ -125,7 +125,7 @@ where
/// let to = extsk.default_address().unwrap().1.into();
///
/// let data_file = NamedTempFile::new().unwrap();
/// let db = WalletDB::for_path(data_file).unwrap();
/// let db = WalletDB::for_path(data_file).unwrap().get_update_ops().unwrap();
/// match create_spend_to_address(
/// &db,
/// &Network::TestNetwork,
@ -141,8 +141,8 @@ where
/// Err(e) => (),
/// }
/// ```
pub fn create_spend_to_address<'db, E0, N, E, P, D, R>(
mut data: &'db D,
pub fn create_spend_to_address<E0, N, E, P, D, R>(
data: &mut D,
params: &P,
prover: impl TxProver,
account: AccountId,
@ -156,7 +156,7 @@ where
E0: Into<Error<E, N>>,
P: consensus::Parameters + Clone,
R: Copy + Debug,
&'db D: WalletRead<Error = E0, TxRef = R> + WalletWrite<Error = E0, TxRef = R>,
D: WalletWrite<Error = E0, TxRef = R>,
{
// Check that the ExtendedSpendingKey we have been given corresponds to the
// ExtendedFullViewingKey for the account we are spending from.
@ -233,33 +233,33 @@ where
// Update the database atomically, to ensure the result is internally consistent.
data.transactionally(|up| {
let created = time::OffsetDateTime::now_utc();
let tx_ref = up.put_tx_data(&tx, Some(created))?;
let created = time::OffsetDateTime::now_utc();
let tx_ref = up.put_tx_data(&tx, Some(created))?;
// Mark notes as spent.
//
// This locks the notes so they aren't selected again by a subsequent call to
// create_spend_to_address() before this transaction has been mined (at which point the notes
// get re-marked as spent).
//
// Assumes that create_spend_to_address() will never be called in parallel, which is a
// reasonable assumption for a light client such as a mobile phone.
for spend in &tx.shielded_spends {
up.mark_spent(tx_ref, &spend.nullifier)?;
}
// Mark notes as spent.
//
// This locks the notes so they aren't selected again by a subsequent call to
// create_spend_to_address() before this transaction has been mined (at which point the notes
// get re-marked as spent).
//
// Assumes that create_spend_to_address() will never be called in parallel, which is a
// reasonable assumption for a light client such as a mobile phone.
for spend in &tx.shielded_spends {
up.mark_spent(tx_ref, &spend.nullifier)?;
}
up.insert_sent_note(
params,
tx_ref,
output_index as usize,
account,
to,
value,
memo,
)?;
up.insert_sent_note(
params,
tx_ref,
output_index as usize,
account,
to,
value,
memo,
)?;
// Return the row number of the transaction, so the caller can fetch it for sending.
Ok(tx_ref)
})
.map_err(|e| e.into())
// Return the row number of the transaction, so the caller can fetch it for sending.
Ok(tx_ref)
})
.map_err(|e| e.into())
}

View File

@ -1,3 +1,5 @@
use std::collections::HashMap;
use zcash_primitives::{
consensus::{self, BlockHeight},
note_encryption::{try_sapling_note_decryption, try_sapling_output_recovery, Memo},
@ -6,6 +8,8 @@ use zcash_primitives::{
zip32::ExtendedFullViewingKey,
};
use crate::wallet::AccountId;
/// A decrypted shielded output.
pub struct DecryptedOutput {
/// The index of the output within [`shielded_outputs`].
@ -15,7 +19,7 @@ pub struct DecryptedOutput {
/// The note within the output.
pub note: Note,
/// The account that decrypted the note.
pub account: usize,
pub account: AccountId,
/// The address the note was sent to.
pub to: PaymentAddress,
/// The memo included with the note.
@ -33,22 +37,19 @@ pub fn decrypt_transaction<P: consensus::Parameters>(
params: &P,
height: BlockHeight,
tx: &Transaction,
extfvks: &[ExtendedFullViewingKey],
extfvks: &HashMap<AccountId, ExtendedFullViewingKey>,
) -> Vec<DecryptedOutput> {
let mut decrypted = vec![];
// Cache IncomingViewingKey calculation
let vks: Vec<_> = extfvks
.iter()
.map(|extfvk| (extfvk.fvk.vk.ivk(), extfvk.fvk.ovk))
.collect();
for (account, extfvk) in extfvks.iter() {
let ivk = extfvk.fvk.vk.ivk();
let ovk = extfvk.fvk.ovk;
for (index, output) in tx.shielded_outputs.iter().enumerate() {
for (account, (ivk, ovk)) in vks.iter().enumerate() {
for (index, output) in tx.shielded_outputs.iter().enumerate() {
let ((note, to, memo), outgoing) = match try_sapling_note_decryption(
params,
height,
ivk,
&ivk,
&output.ephemeral_key,
&output.cmu,
&output.enc_ciphertext,
@ -57,7 +58,7 @@ pub fn decrypt_transaction<P: consensus::Parameters>(
None => match try_sapling_output_recovery(
params,
height,
ovk,
&ovk,
&output.cv,
&output.cmu,
&output.ephemeral_key,
@ -71,7 +72,7 @@ pub fn decrypt_transaction<P: consensus::Parameters>(
decrypted.push(DecryptedOutput {
index,
note,
account,
account: *account,
to,
memo,
outgoing,

View File

@ -6,11 +6,10 @@ use subtle::{ConditionallySelectable, ConstantTimeEq, CtOption};
use zcash_primitives::{
consensus::{self, BlockHeight},
merkle_tree::{CommitmentTree, IncrementalWitness},
primitives::Nullifier,
note_encryption::try_sapling_compact_note_decryption,
primitives::Nullifier,
sapling::Node,
transaction::TxId,
zip32::ExtendedFullViewingKey,
};
use crate::proto::compact_formats::{CompactBlock, CompactOutput};
@ -91,13 +90,12 @@ fn scan_output<P: consensus::Parameters>(
pub fn scan_block<P: consensus::Parameters>(
params: &P,
block: CompactBlock,
extfvks: &[ExtendedFullViewingKey],
ivks: &[jubjub::Fr],
nullifiers: &[(Nullifier, AccountId)],
tree: &mut CommitmentTree<Node>,
existing_witnesses: &mut [&mut IncrementalWitness<Node>],
) -> Vec<WalletTx> {
let mut wtxs: Vec<WalletTx> = vec![];
let ivks: Vec<_> = extfvks.iter().map(|extfvk| extfvk.fvk.vk.ivk()).collect();
let block_height = block.height();
for tx in block.vtx.into_iter() {
@ -201,13 +199,14 @@ mod tests {
constants::SPENDING_KEY_GENERATOR,
merkle_tree::CommitmentTree,
note_encryption::{Memo, SaplingNoteEncryption},
primitives::Note,
primitives::{Note, Nullifier},
transaction::components::Amount,
util::generate_random_rseed,
zip32::{ExtendedFullViewingKey, ExtendedSpendingKey},
};
use super::scan_block;
use crate::wallet::AccountId;
use crate::proto::compact_formats::{CompactBlock, CompactOutput, CompactSpend, CompactTx};
fn random_compact_tx(mut rng: impl RngCore) -> CompactTx {
@ -247,7 +246,7 @@ mod tests {
/// Returns the CompactBlock.
fn fake_compact_block(
height: BlockHeight,
nf: [u8; 32],
nf: Nullifier,
extfvk: ExtendedFullViewingKey,
value: Amount,
tx_after: bool,
@ -286,7 +285,7 @@ mod tests {
}
let mut cspend = CompactSpend::new();
cspend.set_nf(nf.to_vec());
cspend.set_nf(nf.0.to_vec());
let mut cout = CompactOutput::new();
cout.set_cmu(cmu);
cout.set_epk(epk);
@ -317,7 +316,7 @@ mod tests {
let cb = fake_compact_block(
1u32.into(),
[0; 32],
Nullifier([0; 32]),
extfvk.clone(),
Amount::from_u64(5).unwrap(),
false,
@ -328,7 +327,7 @@ mod tests {
let txs = scan_block(
&Network::TestNetwork,
cb,
&[extfvk],
&[extfvk.fvk.vk.ivk()],
&[],
&mut tree,
&mut [],
@ -342,7 +341,7 @@ mod tests {
assert_eq!(tx.shielded_spends.len(), 0);
assert_eq!(tx.shielded_outputs.len(), 1);
assert_eq!(tx.shielded_outputs[0].index, 0);
assert_eq!(tx.shielded_outputs[0].account, 0);
assert_eq!(tx.shielded_outputs[0].account, AccountId(0));
assert_eq!(tx.shielded_outputs[0].note.value, 5);
// Check that the witness root matches
@ -356,7 +355,7 @@ mod tests {
let cb = fake_compact_block(
1u32.into(),
[0; 32],
Nullifier([0; 32]),
extfvk.clone(),
Amount::from_u64(5).unwrap(),
true,
@ -367,7 +366,7 @@ mod tests {
let txs = scan_block(
&Network::TestNetwork,
cb,
&[extfvk],
&[extfvk.fvk.vk.ivk()],
&[],
&mut tree,
&mut [],
@ -381,7 +380,7 @@ mod tests {
assert_eq!(tx.shielded_spends.len(), 0);
assert_eq!(tx.shielded_outputs.len(), 1);
assert_eq!(tx.shielded_outputs[0].index, 0);
assert_eq!(tx.shielded_outputs[0].account, 0);
assert_eq!(tx.shielded_outputs[0].account, AccountId(0));
assert_eq!(tx.shielded_outputs[0].note.value, 5);
// Check that the witness root matches
@ -392,8 +391,8 @@ mod tests {
fn scan_block_with_my_spend() {
let extsk = ExtendedSpendingKey::master(&[]);
let extfvk = ExtendedFullViewingKey::from(&extsk);
let nf = [7; 32];
let account = 12;
let nf = Nullifier([7; 32]);
let account = AccountId(12);
let cb = fake_compact_block(1u32.into(), nf, extfvk, Amount::from_u64(5).unwrap(), false);
assert_eq!(cb.vtx.len(), 2);
@ -403,7 +402,7 @@ mod tests {
&Network::TestNetwork,
cb,
&[],
&[(&nf, account)],
&[(nf.clone(), account)],
&mut tree,
&mut [],
);

View File

@ -137,7 +137,6 @@ where
#[cfg(test)]
mod tests {
use rusqlite::Connection;
use tempfile::NamedTempFile;
use zcash_primitives::{
@ -163,17 +162,17 @@ mod tests {
init::{init_accounts_table, init_data_database},
rewind_to_height,
},
AccountId, BlockDB, NoteId, WalletDB,
AccountId, BlockDB, NoteId, WalletDB
};
#[test]
fn valid_chain_states() {
let cache_file = NamedTempFile::new().unwrap();
let db_cache = BlockDB(Connection::open(cache_file.path()).unwrap());
let db_cache = BlockDB::for_path(cache_file.path()).unwrap();
init_cache_database(&db_cache).unwrap();
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB(Connection::open(data_file.path()).unwrap());
let db_data = WalletDB::for_path(data_file.path()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
@ -207,7 +206,13 @@ mod tests {
.unwrap();
// Scan the cache
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(
&tests::network(),
&db_cache,
&mut db_write,
None
).unwrap();
// Data-only chain should be valid
validate_chain(
@ -235,7 +240,7 @@ mod tests {
.unwrap();
// Scan the cache again
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Data-only chain should be valid
validate_chain(
@ -249,11 +254,11 @@ mod tests {
#[test]
fn invalid_chain_cache_disconnected() {
let cache_file = NamedTempFile::new().unwrap();
let db_cache = BlockDB(Connection::open(cache_file.path()).unwrap());
let db_cache = BlockDB::for_path(cache_file.path()).unwrap();
init_cache_database(&db_cache).unwrap();
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB(Connection::open(data_file.path()).unwrap());
let db_data = WalletDB::for_path(data_file.path()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
@ -278,7 +283,8 @@ mod tests {
insert_into_cache(&db_cache, &cb2);
// Scan the cache
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Data-only chain should be valid
validate_chain(
@ -322,11 +328,11 @@ mod tests {
#[test]
fn invalid_chain_cache_reorg() {
let cache_file = NamedTempFile::new().unwrap();
let db_cache = BlockDB(Connection::open(cache_file.path()).unwrap());
let db_cache = BlockDB::for_path(cache_file.path()).unwrap();
init_cache_database(&db_cache).unwrap();
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB(Connection::open(data_file.path()).unwrap());
let db_data = WalletDB::for_path(data_file.path()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
@ -351,7 +357,8 @@ mod tests {
insert_into_cache(&db_cache, &cb2);
// Scan the cache
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Data-only chain should be valid
validate_chain(
@ -395,11 +402,11 @@ mod tests {
#[test]
fn data_db_rewinding() {
let cache_file = NamedTempFile::new().unwrap();
let db_cache = BlockDB(Connection::open(cache_file.path()).unwrap());
let db_cache = BlockDB::for_path(cache_file.path()).unwrap();
init_cache_database(&db_cache).unwrap();
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB(Connection::open(data_file.path()).unwrap());
let db_data = WalletDB::for_path(data_file.path()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
@ -426,7 +433,8 @@ mod tests {
insert_into_cache(&db_cache, &cb2);
// Scan the cache
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Account balance should reflect both received notes
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value + value2);
@ -444,7 +452,7 @@ mod tests {
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value);
// Scan the cache again
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Account balance should again reflect both received notes
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value + value2);
@ -453,11 +461,11 @@ mod tests {
#[test]
fn scan_cached_blocks_requires_sequential_blocks() {
let cache_file = NamedTempFile::new().unwrap();
let db_cache = BlockDB(Connection::open(cache_file.path()).unwrap());
let db_cache = BlockDB::for_path(cache_file.path()).unwrap();
init_cache_database(&db_cache).unwrap();
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB(Connection::open(data_file.path()).unwrap());
let db_data = WalletDB::for_path(data_file.path()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
@ -474,7 +482,8 @@ mod tests {
value,
);
insert_into_cache(&db_cache, &cb1);
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value);
// We cannot scan a block of height SAPLING_ACTIVATION_HEIGHT + 2 next
@ -491,7 +500,7 @@ mod tests {
value,
);
insert_into_cache(&db_cache, &cb3);
match scan_cached_blocks(&tests::network(), &db_cache, &db_data, None) {
match scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None) {
Ok(_) => panic!("Should have failed"),
Err(e) => {
assert_eq!(
@ -507,7 +516,7 @@ mod tests {
// If we add a block of height SAPLING_ACTIVATION_HEIGHT + 1, we can now scan both
insert_into_cache(&db_cache, &cb2);
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
assert_eq!(
get_balance(&db_data, AccountId(0)).unwrap(),
Amount::from_u64(150_000).unwrap()
@ -517,11 +526,11 @@ mod tests {
#[test]
fn scan_cached_blocks_finds_received_notes() {
let cache_file = NamedTempFile::new().unwrap();
let db_cache = BlockDB(Connection::open(cache_file.path()).unwrap());
let db_cache = BlockDB::for_path(cache_file.path()).unwrap();
init_cache_database(&db_cache).unwrap();
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB(Connection::open(data_file.path()).unwrap());
let db_data = WalletDB::for_path(data_file.path()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
@ -543,7 +552,8 @@ mod tests {
insert_into_cache(&db_cache, &cb);
// Scan the cache
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Account balance should reflect the received note
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value);
@ -555,7 +565,7 @@ mod tests {
insert_into_cache(&db_cache, &cb2);
// Scan the cache again
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Account balance should reflect both received notes
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value + value2);
@ -564,11 +574,11 @@ mod tests {
#[test]
fn scan_cached_blocks_finds_change_notes() {
let cache_file = NamedTempFile::new().unwrap();
let db_cache = BlockDB(Connection::open(cache_file.path()).unwrap());
let db_cache = BlockDB::for_path(cache_file.path()).unwrap();
init_cache_database(&db_cache).unwrap();
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB(Connection::open(data_file.path()).unwrap());
let db_data = WalletDB::for_path(data_file.path()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
@ -590,7 +600,8 @@ mod tests {
insert_into_cache(&db_cache, &cb);
// Scan the cache
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Account balance should reflect the received note
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value);
@ -612,7 +623,7 @@ mod tests {
);
// Scan the cache again
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Account balance should equal the change
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value - value2);

View File

@ -25,6 +25,7 @@
//! [`init_cache_database`]: crate::init::init_cache_database
use std::fmt;
use std::collections::HashMap;
use std::path::Path;
use rusqlite::{types::ToSql, Connection, Statement, NO_PARAMS};
@ -57,6 +58,8 @@ pub mod chain;
pub mod error;
pub mod wallet;
/// A newtype wrapper for sqlite primary key values for the notes
/// table.
#[derive(Debug, Copy, Clone)]
pub struct NoteId(pub i64);
@ -66,13 +69,18 @@ impl fmt::Display for NoteId {
}
}
/// A newtype wrapper for the sqlite connection to the wallet database.
pub struct WalletDB(Connection);
impl WalletDB {
/// Construct a connection to the wallet database stored at the specified path.
pub fn for_path<P: AsRef<Path>>(path: P) -> Result<Self, rusqlite::Error> {
Connection::open(path).map(WalletDB)
}
/// Given a wallet database connection, obtain a handle for the write operations
/// for that database. This operation may eagerly initialize and cache sqlite
/// prepared statements that are used in write operations.
pub fn get_update_ops<'a>(&'a self) -> Result<DataConnStmtCache<'a>, SqliteClientError> {
Ok(
DataConnStmtCache {
@ -148,7 +156,7 @@ impl WalletDB {
}
}
impl<'a> WalletRead for &'a WalletDB {
impl WalletRead for WalletDB {
type Error = SqliteClientError;
type NoteRef = NoteId;
type TxRef = i64;
@ -168,7 +176,7 @@ impl<'a> WalletRead for &'a WalletDB {
fn get_extended_full_viewing_keys<P: consensus::Parameters>(
&self,
params: &P,
) -> Result<Vec<ExtendedFullViewingKey>, Self::Error> {
) -> Result<HashMap<AccountId, ExtendedFullViewingKey>, Self::Error> {
wallet::get_extended_full_viewing_keys(self, params)
}
@ -285,7 +293,7 @@ impl<'a> WalletRead for DataConnStmtCache<'a> {
fn get_extended_full_viewing_keys<P: consensus::Parameters>(
&self,
params: &P,
) -> Result<Vec<ExtendedFullViewingKey>, Self::Error> {
) -> Result<HashMap<AccountId, ExtendedFullViewingKey>, Self::Error> {
self.conn.get_extended_full_viewing_keys(params)
}
@ -569,7 +577,7 @@ impl<'a> WalletWrite for DataConnStmtCache<'a> {
tx_ref: Self::TxRef,
) -> Result<(), Self::Error> {
let output_index = output.index as i64;
let account = output.account as i64;
let account = output.account.0 as i64;
let value = output.note.value as i64;
let to_str = encode_payment_address(params.hrp_sapling_payment_address(), &output.to);
@ -588,7 +596,7 @@ impl<'a> WalletWrite for DataConnStmtCache<'a> {
params,
tx_ref,
output.index,
AccountId(output.account as u32),
output.account,
&RecipientAddress::Shielded(output.to.clone()),
Amount::from_u64(output.note.value)
.map_err(|_| Error::CorruptedData("Note value invalid.".to_string()))?,
@ -676,7 +684,7 @@ mod tests {
block::BlockHash,
consensus::{BlockHeight, Network, NetworkUpgrade, Parameters},
note_encryption::{Memo, SaplingNoteEncryption},
primitives::{Note, PaymentAddress},
primitives::{Note, Nullifier, PaymentAddress},
transaction::components::Amount,
util::generate_random_rseed,
zip32::ExtendedFullViewingKey,
@ -715,7 +723,7 @@ mod tests {
prev_hash: BlockHash,
extfvk: ExtendedFullViewingKey,
value: Amount,
) -> (CompactBlock, Vec<u8>) {
) -> (CompactBlock, Nullifier) {
let to = extfvk.default_address().unwrap().1;
// Create a fake Note for the account
@ -762,7 +770,7 @@ mod tests {
pub(crate) fn fake_compact_block_spending(
height: BlockHeight,
prev_hash: BlockHash,
(nf, in_value): (Vec<u8>, Amount),
(nf, in_value): (Nullifier, Amount),
extfvk: ExtendedFullViewingKey,
to: PaymentAddress,
value: Amount,
@ -772,7 +780,7 @@ mod tests {
// Create a fake CompactBlock containing the note
let mut cspend = CompactSpend::new();
cspend.set_nf(nf);
cspend.set_nf(nf.to_vec());
let mut ctx = CompactTx::new();
let mut txid = vec![0; 32];
rng.fill_bytes(&mut txid);

View File

@ -1,6 +1,7 @@
//! Functions for querying information in the data database.
use rusqlite::{OptionalExtension, ToSql, NO_PARAMS};
use std::collections::HashMap;
use zcash_primitives::{
block::BlockHash,
@ -63,15 +64,16 @@ pub fn get_address<P: consensus::Parameters>(
pub fn get_extended_full_viewing_keys<P: consensus::Parameters>(
data: &WalletDB,
params: &P,
) -> Result<Vec<ExtendedFullViewingKey>, SqliteClientError> {
) -> Result<HashMap<AccountId, ExtendedFullViewingKey>, SqliteClientError> {
// Fetch the ExtendedFullViewingKeys we are tracking
let mut stmt_fetch_accounts = data
.0
.prepare("SELECT extfvk FROM accounts ORDER BY account ASC")?;
.prepare("SELECT account, extfvk FROM accounts ORDER BY account ASC")?;
let rows = stmt_fetch_accounts
.query_map(NO_PARAMS, |row| {
row.get(0).map(|extfvk: String| {
let acct = row.get(0).map(AccountId)?;
let extfvk = row.get(1).map(|extfvk: String| {
decode_extended_full_viewing_key(
params.hrp_sapling_extended_full_viewing_key(),
&extfvk,
@ -79,11 +81,19 @@ pub fn get_extended_full_viewing_keys<P: consensus::Parameters>(
.map_err(|e| Error::Bech32(e))
.and_then(|k| k.ok_or(Error::IncorrectHRPExtFVK))
.map_err(SqliteClientError)
})
})?;
Ok((acct, extfvk))
})
.map_err(SqliteClientError::from)?;
rows.collect::<Result<Result<_, _>, _>>()?
let mut res: HashMap<AccountId, ExtendedFullViewingKey> = HashMap::new();
for row in rows {
let (account_id, efvkr) = row?;
res.insert(account_id, efvkr?);
}
Ok(res)
}
pub fn is_valid_account_extfvk<P: consensus::Parameters>(
@ -144,7 +154,7 @@ pub fn get_balance(data: &WalletDB, account: AccountId) -> Result<Amount, Sqlite
}
}
/// Returns the verified balance for the account at the specified height,
/// Returns the verified balance for the account at the specified height,
/// This may be used to obtain a balance that ignores notes that have been
/// received so recently that they are not yet deemed spendable.
///

View File

@ -148,7 +148,7 @@ mod tests {
get_balance, get_verified_balance,
init::{init_accounts_table, init_blocks_table, init_data_database},
},
AccountId, BlockDB, WalletDB,
AccountId, BlockDB, WalletDB, DataConnStmtCache
};
fn test_prover() -> impl TxProver {
@ -177,8 +177,9 @@ mod tests {
let to = extsk0.default_address().unwrap().1.into();
// Invalid extsk for the given account should cause an error
let mut db_write = db_data.get_update_ops().unwrap();
match create_spend_to_address(
&db_data,
&mut db_write,
&tests::network(),
test_prover(),
AccountId(0),
@ -193,7 +194,7 @@ mod tests {
}
match create_spend_to_address(
&db_data,
&mut db_write,
&tests::network(),
test_prover(),
AccountId(1),
@ -221,8 +222,9 @@ mod tests {
let to = extsk.default_address().unwrap().1.into();
// We cannot do anything if we aren't synchronised
let mut db_write = db_data.get_update_ops().unwrap();
match create_spend_to_address(
&db_data,
&mut db_write,
&tests::network(),
test_prover(),
AccountId(0),
@ -261,8 +263,9 @@ mod tests {
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), Amount::zero());
// We cannot spend anything
let mut db_write = db_data.get_update_ops().unwrap();
match create_spend_to_address(
&db_data,
&mut db_write,
&tests::network(),
test_prover(),
AccountId(0),
@ -304,7 +307,8 @@ mod tests {
value,
);
insert_into_cache(&db_cache, &cb);
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Verified balance matches total balance
let (_, anchor_height) = (&db_data).get_target_and_anchor_heights().unwrap().unwrap();
@ -322,7 +326,7 @@ mod tests {
value,
);
insert_into_cache(&db_cache, &cb);
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Verified balance does not include the second note
let (_, anchor_height2) = (&db_data).get_target_and_anchor_heights().unwrap().unwrap();
@ -336,7 +340,7 @@ mod tests {
let extsk2 = ExtendedSpendingKey::master(&[]);
let to = extsk2.default_address().unwrap().1.into();
match create_spend_to_address(
&db_data,
&mut db_write,
&tests::network(),
test_prover(),
AccountId(0),
@ -364,11 +368,11 @@ mod tests {
);
insert_into_cache(&db_cache, &cb);
}
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Second spend still fails
match create_spend_to_address(
&db_data,
&mut db_write,
&tests::network(),
test_prover(),
AccountId(0),
@ -393,11 +397,11 @@ mod tests {
value,
);
insert_into_cache(&db_cache, &cb);
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Second spend should now succeed
create_spend_to_address(
&db_data,
&mut db_write,
&tests::network(),
test_prover(),
AccountId(0),
@ -434,14 +438,15 @@ mod tests {
value,
);
insert_into_cache(&db_cache, &cb);
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value);
// Send some of the funds to another address
let extsk2 = ExtendedSpendingKey::master(&[]);
let to = extsk2.default_address().unwrap().1.into();
create_spend_to_address(
&db_data,
&mut db_write,
&tests::network(),
test_prover(),
AccountId(0),
@ -455,7 +460,7 @@ mod tests {
// A second spend fails because there are no usable notes
match create_spend_to_address(
&db_data,
&mut db_write,
&tests::network(),
test_prover(),
AccountId(0),
@ -483,11 +488,11 @@ mod tests {
);
insert_into_cache(&db_cache, &cb);
}
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Second spend still fails
match create_spend_to_address(
&db_data,
&mut db_write,
&tests::network(),
test_prover(),
AccountId(0),
@ -512,11 +517,11 @@ mod tests {
value,
);
insert_into_cache(&db_cache, &cb);
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Second spend should now succeed
create_spend_to_address(
&db_data,
&mut db_write,
&tests::network(),
test_prover(),
AccountId(0),
@ -554,16 +559,17 @@ mod tests {
value,
);
insert_into_cache(&db_cache, &cb);
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value);
let extsk2 = ExtendedSpendingKey::master(&[]);
let addr2 = extsk2.default_address().unwrap().1;
let to = addr2.clone().into();
let send_and_recover_with_policy = |ovk_policy| {
let send_and_recover_with_policy = |db_write: &mut DataConnStmtCache<'_>, ovk_policy| {
let tx_row = create_spend_to_address(
&db_data,
db_write,
&network,
test_prover(),
AccountId(0),
@ -576,7 +582,7 @@ mod tests {
.unwrap();
// Fetch the transaction from the database
let raw_tx: Vec<_> = db_data
let raw_tx: Vec<_> = db_write.conn
.0
.query_row(
"SELECT raw FROM transactions
@ -588,7 +594,7 @@ mod tests {
let tx = Transaction::read(&raw_tx[..]).unwrap();
// Fetch the output index from the database
let output_index: i64 = db_data
let output_index: i64 = db_write.conn
.0
.query_row(
"SELECT output_index FROM sent_notes
@ -614,7 +620,7 @@ mod tests {
// Send some of the funds to another address, keeping history.
// The recipient output is decryptable by the sender.
let (_, recovered_to, _) = send_and_recover_with_policy(OvkPolicy::Sender).unwrap();
let (_, recovered_to, _) = send_and_recover_with_policy(&mut db_write, OvkPolicy::Sender).unwrap();
assert_eq!(&recovered_to, &addr2);
// Mine blocks SAPLING_ACTIVATION_HEIGHT + 1 to 22 (that don't send us funds)
@ -628,10 +634,10 @@ mod tests {
);
insert_into_cache(&db_cache, &cb);
}
scan_cached_blocks(&network, &db_cache, &db_data, None).unwrap();
scan_cached_blocks(&network, &db_cache, &mut db_write, None).unwrap();
// Send the funds again, discarding history.
// Neither transaction output is decryptable by the sender.
assert!(send_and_recover_with_policy(OvkPolicy::Discard).is_none());
assert!(send_and_recover_with_policy(&mut db_write, OvkPolicy::Discard).is_none());
}
}

View File

@ -206,6 +206,10 @@ impl Nullifier {
nf.0.copy_from_slice(bytes);
nf
}
pub fn to_vec(&self) -> Vec<u8> {
self.0.to_vec()
}
}
#[derive(Clone, Debug)]
@ -275,14 +279,14 @@ impl Note {
// Compute nf = BLAKE2s(nk | rho)
Nullifier::from_slice(
Blake2sParams::new()
.hash_length(32)
.personal(constants::PRF_NF_PERSONALIZATION)
.to_state()
.update(&viewing_key.nk.to_bytes())
.update(&rho.to_bytes())
.finalize()
.as_bytes()
Blake2sParams::new()
.hash_length(32)
.personal(constants::PRF_NF_PERSONALIZATION)
.to_state()
.update(&viewing_key.nk.to_bytes())
.update(&rho.to_bytes())
.finalize()
.as_bytes(),
)
}

View File

@ -10,9 +10,9 @@ use std::io::{self, Read, Write};
use crate::extensions::transparent as tze;
use crate::legacy::Script;
use crate::primitives::Nullifier;
use crate::redjubjub::{PublicKey, Signature};
use crate::serialize::{CompactSize, Vector};
use crate::primitives::Nullifier;
pub mod amount;
pub use self::amount::Amount;

View File

@ -609,7 +609,7 @@ fn test_input_circuit_with_bls12_381() {
}
let expected_nf = note.nf(&viewing_key, position);
let expected_nf = multipack::bytes_to_bits_le(&expected_nf);
let expected_nf = multipack::bytes_to_bits_le(&expected_nf.0);
let expected_nf = multipack::compute_multipacking(&expected_nf);
assert_eq!(expected_nf.len(), 2);
@ -789,7 +789,7 @@ fn test_input_circuit_with_bls12_381_external_test_vectors() {
}
let expected_nf = note.nf(&viewing_key, position);
let expected_nf = multipack::bytes_to_bits_le(&expected_nf);
let expected_nf = multipack::bytes_to_bits_le(&expected_nf.0);
let expected_nf = multipack::compute_multipacking(&expected_nf);
assert_eq!(expected_nf.len(), 2);