Factor out witness retrieval from scan.

This commit is contained in:
Kris Nuttycombe 2020-08-18 13:48:59 -06:00
parent 499dcd2e6c
commit ffd503134d
7 changed files with 84 additions and 45 deletions

View File

@ -27,7 +27,8 @@ pub const ANCHOR_OFFSET: u32 = 10;
/// This function does not mutate either of the databases.
pub fn validate_combined_chain<
E0,
E: From<Error<E0>>,
N,
E: From<Error<E0, N>>,
P: consensus::Parameters,
C: CacheOps<Error = E>,
D: DBOps<Error = E>,
@ -84,7 +85,7 @@ pub fn validate_combined_chain<
/// Determines the target height for a transaction, and the height from which to
/// select anchors, based on the current synchronised block chain.
pub fn get_target_and_anchor_heights<E0, E: From<Error<E0>>, D: DBOps<Error = E>>(
pub fn get_target_and_anchor_heights<E0, N, E: From<Error<E0, N>>, D: DBOps<Error = E>>(
data: &D,
) -> Result<(BlockHeight, BlockHeight), E> {
data.block_height_extrema().and_then(|heights| {

View File

@ -14,7 +14,7 @@ pub enum ChainInvalid {
}
#[derive(Debug)]
pub enum Error<DbError> {
pub enum Error<DbError, NoteId> {
CorruptedData(&'static str),
IncorrectHRPExtFVK,
InsufficientBalance(u64, u64),
@ -23,7 +23,7 @@ pub enum Error<DbError> {
InvalidMemo(std::str::Utf8Error),
InvalidNewWitnessAnchor(usize, TxId, BlockHeight, Node),
InvalidNote,
InvalidWitnessAnchor(i64, BlockHeight),
InvalidWitnessAnchor(NoteId, BlockHeight),
ScanRequired,
TableNotEmpty,
Bech32(bech32::Error),
@ -36,16 +36,16 @@ pub enum Error<DbError> {
}
impl ChainInvalid {
pub fn prev_hash_mismatch<E>(at_height: BlockHeight) -> Error<E> {
pub fn prev_hash_mismatch<E, N>(at_height: BlockHeight) -> Error<E, N> {
Error::InvalidChain(at_height, ChainInvalid::PrevHashMismatch)
}
pub fn block_height_mismatch<E>(at_height: BlockHeight, found: BlockHeight) -> Error<E> {
pub fn block_height_mismatch<E, N>(at_height: BlockHeight, found: BlockHeight) -> Error<E, N> {
Error::InvalidChain(at_height, ChainInvalid::BlockHeightMismatch(found))
}
}
impl<E: fmt::Display> fmt::Display for Error<E> {
impl<E: fmt::Display, N: fmt::Display> fmt::Display for Error<E, N> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match &self {
Error::CorruptedData(reason) => write!(f, "Data DB is corrupted: {}", reason),
@ -86,7 +86,7 @@ impl<E: fmt::Display> fmt::Display for Error<E> {
}
}
impl<E: error::Error + 'static> error::Error for Error<E> {
impl<E: error::Error + 'static, N: error::Error + 'static> error::Error for Error<E, N> {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match &self {
Error::InvalidMemo(e) => Some(e),
@ -100,31 +100,31 @@ impl<E: error::Error + 'static> error::Error for Error<E> {
}
}
impl<E> From<bech32::Error> for Error<E> {
impl<E, N> From<bech32::Error> for Error<E, N> {
fn from(e: bech32::Error) -> Self {
Error::Bech32(e)
}
}
impl<E> From<bs58::decode::Error> for Error<E> {
impl<E, N> From<bs58::decode::Error> for Error<E, N> {
fn from(e: bs58::decode::Error) -> Self {
Error::Base58(e)
}
}
impl<E> From<builder::Error> for Error<E> {
impl<E, N> From<builder::Error> for Error<E, N> {
fn from(e: builder::Error) -> Self {
Error::Builder(e)
}
}
impl<E> From<std::io::Error> for Error<E> {
impl<E, N> From<std::io::Error> for Error<E, N> {
fn from(e: std::io::Error) -> Self {
Error::Io(e)
}
}
impl<E> From<protobuf::ProtobufError> for Error<E> {
impl<E, N> From<protobuf::ProtobufError> for Error<E, N> {
fn from(e: protobuf::ProtobufError) -> Self {
Error::Protobuf(e)
}

View File

@ -1,7 +1,7 @@
use zcash_primitives::{
block::BlockHash,
consensus::{self, BlockHeight},
merkle_tree::CommitmentTree,
merkle_tree::{CommitmentTree, IncrementalWitness},
primitives::PaymentAddress,
sapling::Node,
transaction::components::Amount,
@ -73,6 +73,11 @@ pub trait DBOps {
block_height: BlockHeight,
) -> Result<Option<CommitmentTree<Node>>, Self::Error>;
fn get_witnesses(
&self,
block_height: BlockHeight,
) -> Result<Vec<(Self::NoteId, IncrementalWitness<Node>)>, Self::Error>;
// fn get_witnesses(block_height: BlockHeight) -> Result<Box<dyn Iterator<Item = IncrementalWitness<Node>>>, Self::Error>;
//
// fn get_nullifiers() -> Result<(Vec<u8>, Account), Self::Error>;

View File

@ -4,8 +4,10 @@ use zcash_primitives::transaction::builder;
use zcash_client_backend::data_api::error::Error;
use crate::NoteId;
#[derive(Debug)]
pub struct SqliteClientError(pub Error<rusqlite::Error>);
pub struct SqliteClientError(pub Error<rusqlite::Error, NoteId>);
impl fmt::Display for SqliteClientError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
@ -13,8 +15,8 @@ impl fmt::Display for SqliteClientError {
}
}
impl From<Error<rusqlite::Error>> for SqliteClientError {
fn from(e: Error<rusqlite::Error>) -> Self {
impl From<Error<rusqlite::Error, NoteId>> for SqliteClientError {
fn from(e: Error<rusqlite::Error, NoteId>) -> Self {
SqliteClientError(e)
}
}

View File

@ -24,13 +24,15 @@
//! [`CompactBlock`]: zcash_client_backend::proto::compact_formats::CompactBlock
//! [`init_cache_database`]: crate::init::init_cache_database
use rusqlite::Connection;
use std::fmt;
use std::path::Path;
use rusqlite::Connection;
use zcash_primitives::{
block::BlockHash,
consensus::{self, BlockHeight},
merkle_tree::CommitmentTree,
merkle_tree::{CommitmentTree, IncrementalWitness},
primitives::PaymentAddress,
sapling::Node,
transaction::components::Amount,
@ -52,9 +54,18 @@ pub mod query;
pub mod scan;
pub mod transact;
#[derive(Debug, Copy, Clone)]
pub struct AccountId(pub u32);
#[derive(Debug, Copy, Clone)]
pub struct NoteId(pub i64);
impl fmt::Display for NoteId {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Note {}", self.0)
}
}
pub struct DataConnection(Connection);
impl DataConnection {
@ -146,6 +157,13 @@ impl DBOps for DataConnection {
) -> Result<Option<CommitmentTree<Node>>, Self::Error> {
query::get_commitment_tree(self, block_height)
}
fn get_witnesses(
&self,
block_height: BlockHeight,
) -> Result<Vec<(Self::NoteId, IncrementalWitness<Node>)>, Self::Error> {
query::get_witnesses(self, block_height)
}
}
pub struct CacheConnection(Connection);

View File

@ -4,7 +4,7 @@ use rusqlite::{OptionalExtension, NO_PARAMS};
use zcash_primitives::{
consensus::{self, BlockHeight},
merkle_tree::CommitmentTree,
merkle_tree::{CommitmentTree, IncrementalWitness},
note_encryption::Memo,
primitives::PaymentAddress,
sapling::Node,
@ -261,6 +261,30 @@ pub fn get_commitment_tree(
.map_err(SqliteClientError::from)
}
pub fn get_witnesses(
data: &DataConnection,
block_height: BlockHeight,
) -> Result<Vec<(NoteId, IncrementalWitness<Node>)>, SqliteClientError> {
let mut stmt_fetch_witnesses = data
.0
.prepare("SELECT note, witness FROM sapling_witnesses WHERE block = ?")?;
let witnesses = stmt_fetch_witnesses
.query_map(&[u32::from(block_height)], |row| {
let id_note = NoteId(row.get(0)?);
let data: Vec<u8> = row.get(1)?;
Ok(IncrementalWitness::read(&data[..]).map(|witness| (id_note, witness)))
})
.map_err(SqliteClientError::from)?;
let mut res = vec![];
for witness in witnesses {
// unwrap database error & IO error from IncrementalWitness::read
res.push(witness??);
}
Ok(res)
}
#[cfg(test)]
mod tests {
use rusqlite::Connection;

View File

@ -24,7 +24,7 @@ use zcash_primitives::{
transaction::Transaction,
};
use crate::{error::SqliteClientError, CacheConnection, DataConnection};
use crate::{error::SqliteClientError, CacheConnection, DataConnection, NoteId};
struct CompactBlockRow {
height: BlockHeight,
@ -105,15 +105,7 @@ pub fn scan_cached_blocks<P: consensus::Parameters>(
.map(|t| t.unwrap_or(CommitmentTree::new()))?;
// Get most recent incremental witnesses for the notes we are tracking
let mut stmt_fetch_witnesses = data
.0
.prepare("SELECT note, witness FROM sapling_witnesses WHERE block = ?")?;
let witnesses = stmt_fetch_witnesses.query_map(&[u32::from(last_height)], |row| {
let id_note = row.get(0)?;
let data: Vec<_> = row.get(1)?;
Ok(IncrementalWitness::read(&data[..]).map(|witness| WitnessRow { id_note, witness }))
})?;
let mut witnesses: Vec<_> = witnesses.collect::<Result<Result<_, _>, _>>()??;
let mut witnesses = data.get_witnesses(last_height)?;
// Get the nullifiers for the notes we are tracking
let mut stmt_fetch_nullifiers = data
@ -209,7 +201,7 @@ pub fn scan_cached_blocks<P: consensus::Parameters>(
let txs = {
let nf_refs: Vec<_> = nullifiers.iter().map(|(nf, acc)| (&nf[..], *acc)).collect();
let mut witness_refs: Vec<_> = witnesses.iter_mut().map(|w| &mut w.witness).collect();
let mut witness_refs: Vec<_> = witnesses.iter_mut().map(|w| &mut w.1).collect();
scan_block(
params,
block,
@ -225,9 +217,9 @@ pub fn scan_cached_blocks<P: consensus::Parameters>(
{
let cur_root = tree.root();
for row in &witnesses {
if row.witness.root() != cur_root {
if row.1.root() != cur_root {
return Err(SqliteClientError(Error::InvalidWitnessAnchor(
row.id_note,
row.0,
last_height,
)));
}
@ -305,7 +297,7 @@ pub fn scan_cached_blocks<P: consensus::Parameters>(
// - A note value will never exceed 2^63 zatoshis.
// First try updating an existing received note into the database.
let note_row = if stmt_update_note.execute(&[
let note_id = if stmt_update_note.execute(&[
(output.account as i64).to_sql()?,
output.to.diversifier().0.to_sql()?,
(output.note.value as i64).to_sql()?,
@ -327,20 +319,17 @@ pub fn scan_cached_blocks<P: consensus::Parameters>(
nf.to_sql()?,
output.is_change.to_sql()?,
])?;
data.0.last_insert_rowid()
NoteId(data.0.last_insert_rowid())
} else {
// It was there, so grab its row number.
stmt_select_note.query_row(
&[tx_row.to_sql()?, (output.index as i64).to_sql()?],
|row| row.get(0),
|row| row.get(0).map(NoteId),
)?
};
// Save witness for note.
witnesses.push(WitnessRow {
id_note: note_row,
witness: output.witness,
});
witnesses.push((note_id, output.witness));
// Cache nullifier for note (to detect subsequent spends in this scan).
nullifiers.push((nf, output.account));
@ -352,11 +341,11 @@ pub fn scan_cached_blocks<P: consensus::Parameters>(
for witness_row in witnesses.iter() {
encoded.clear();
witness_row
.witness
.1
.write(&mut encoded)
.expect("Should be able to write to a Vec");
stmt_insert_witness.execute(&[
witness_row.id_note.to_sql()?,
(witness_row.0).0.to_sql()?,
u32::from(last_height).to_sql()?,
encoded.to_sql()?,
])?;
@ -567,7 +556,7 @@ mod tests {
self, fake_compact_block, fake_compact_block_spending, insert_into_cache,
sapling_activation_height,
},
AccountId, CacheConnection, DataConnection,
AccountId, CacheConnection, DataConnection, NoteId,
};
use super::scan_cached_blocks;
@ -617,8 +606,8 @@ mod tests {
Ok(_) => panic!("Should have failed"),
Err(e) => {
assert_eq!(
e.0.to_string(),
ChainInvalid::block_height_mismatch::<rusqlite::Error>(
e.to_string(),
ChainInvalid::block_height_mismatch::<rusqlite::Error, NoteId>(
sapling_activation_height() + 1,
sapling_activation_height() + 2
)