diff --git a/zcash_client_backend/src/data_api/chain.rs b/zcash_client_backend/src/data_api/chain.rs index 4aa0f71f0..de2fcbb82 100644 --- a/zcash_client_backend/src/data_api/chain.rs +++ b/zcash_client_backend/src/data_api/chain.rs @@ -27,7 +27,8 @@ pub const ANCHOR_OFFSET: u32 = 10; /// This function does not mutate either of the databases. pub fn validate_combined_chain< E0, - E: From>, + N, + E: From>, P: consensus::Parameters, C: CacheOps, D: DBOps, @@ -84,7 +85,7 @@ pub fn validate_combined_chain< /// Determines the target height for a transaction, and the height from which to /// select anchors, based on the current synchronised block chain. -pub fn get_target_and_anchor_heights>, D: DBOps>( +pub fn get_target_and_anchor_heights>, D: DBOps>( data: &D, ) -> Result<(BlockHeight, BlockHeight), E> { data.block_height_extrema().and_then(|heights| { diff --git a/zcash_client_backend/src/data_api/error.rs b/zcash_client_backend/src/data_api/error.rs index ba4376e22..5f68d5924 100644 --- a/zcash_client_backend/src/data_api/error.rs +++ b/zcash_client_backend/src/data_api/error.rs @@ -14,7 +14,7 @@ pub enum ChainInvalid { } #[derive(Debug)] -pub enum Error { +pub enum Error { CorruptedData(&'static str), IncorrectHRPExtFVK, InsufficientBalance(u64, u64), @@ -23,7 +23,7 @@ pub enum Error { InvalidMemo(std::str::Utf8Error), InvalidNewWitnessAnchor(usize, TxId, BlockHeight, Node), InvalidNote, - InvalidWitnessAnchor(i64, BlockHeight), + InvalidWitnessAnchor(NoteId, BlockHeight), ScanRequired, TableNotEmpty, Bech32(bech32::Error), @@ -36,16 +36,16 @@ pub enum Error { } impl ChainInvalid { - pub fn prev_hash_mismatch(at_height: BlockHeight) -> Error { + pub fn prev_hash_mismatch(at_height: BlockHeight) -> Error { Error::InvalidChain(at_height, ChainInvalid::PrevHashMismatch) } - pub fn block_height_mismatch(at_height: BlockHeight, found: BlockHeight) -> Error { + pub fn block_height_mismatch(at_height: BlockHeight, found: BlockHeight) -> Error { Error::InvalidChain(at_height, ChainInvalid::BlockHeightMismatch(found)) } } -impl fmt::Display for Error { +impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match &self { Error::CorruptedData(reason) => write!(f, "Data DB is corrupted: {}", reason), @@ -86,7 +86,7 @@ impl fmt::Display for Error { } } -impl error::Error for Error { +impl error::Error for Error { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match &self { Error::InvalidMemo(e) => Some(e), @@ -100,31 +100,31 @@ impl error::Error for Error { } } -impl From for Error { +impl From for Error { fn from(e: bech32::Error) -> Self { Error::Bech32(e) } } -impl From for Error { +impl From for Error { fn from(e: bs58::decode::Error) -> Self { Error::Base58(e) } } -impl From for Error { +impl From for Error { fn from(e: builder::Error) -> Self { Error::Builder(e) } } -impl From for Error { +impl From for Error { fn from(e: std::io::Error) -> Self { Error::Io(e) } } -impl From for Error { +impl From for Error { fn from(e: protobuf::ProtobufError) -> Self { Error::Protobuf(e) } diff --git a/zcash_client_backend/src/data_api/mod.rs b/zcash_client_backend/src/data_api/mod.rs index b076f2760..780f6e2e6 100644 --- a/zcash_client_backend/src/data_api/mod.rs +++ b/zcash_client_backend/src/data_api/mod.rs @@ -1,7 +1,7 @@ use zcash_primitives::{ block::BlockHash, consensus::{self, BlockHeight}, - merkle_tree::CommitmentTree, + merkle_tree::{CommitmentTree, IncrementalWitness}, primitives::PaymentAddress, sapling::Node, transaction::components::Amount, @@ -73,6 +73,11 @@ pub trait DBOps { block_height: BlockHeight, ) -> Result>, Self::Error>; + fn get_witnesses( + &self, + block_height: BlockHeight, + ) -> Result)>, Self::Error>; + // fn get_witnesses(block_height: BlockHeight) -> Result>>, Self::Error>; // // fn get_nullifiers() -> Result<(Vec, Account), Self::Error>; diff --git a/zcash_client_sqlite/src/error.rs b/zcash_client_sqlite/src/error.rs index 97890a9c6..401125d0c 100644 --- a/zcash_client_sqlite/src/error.rs +++ b/zcash_client_sqlite/src/error.rs @@ -4,8 +4,10 @@ use zcash_primitives::transaction::builder; use zcash_client_backend::data_api::error::Error; +use crate::NoteId; + #[derive(Debug)] -pub struct SqliteClientError(pub Error); +pub struct SqliteClientError(pub Error); impl fmt::Display for SqliteClientError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -13,8 +15,8 @@ impl fmt::Display for SqliteClientError { } } -impl From> for SqliteClientError { - fn from(e: Error) -> Self { +impl From> for SqliteClientError { + fn from(e: Error) -> Self { SqliteClientError(e) } } diff --git a/zcash_client_sqlite/src/lib.rs b/zcash_client_sqlite/src/lib.rs index 51dcee1da..a1eae4126 100644 --- a/zcash_client_sqlite/src/lib.rs +++ b/zcash_client_sqlite/src/lib.rs @@ -24,13 +24,15 @@ //! [`CompactBlock`]: zcash_client_backend::proto::compact_formats::CompactBlock //! [`init_cache_database`]: crate::init::init_cache_database -use rusqlite::Connection; +use std::fmt; use std::path::Path; +use rusqlite::Connection; + use zcash_primitives::{ block::BlockHash, consensus::{self, BlockHeight}, - merkle_tree::CommitmentTree, + merkle_tree::{CommitmentTree, IncrementalWitness}, primitives::PaymentAddress, sapling::Node, transaction::components::Amount, @@ -52,9 +54,18 @@ pub mod query; pub mod scan; pub mod transact; +#[derive(Debug, Copy, Clone)] pub struct AccountId(pub u32); + +#[derive(Debug, Copy, Clone)] pub struct NoteId(pub i64); +impl fmt::Display for NoteId { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Note {}", self.0) + } +} + pub struct DataConnection(Connection); impl DataConnection { @@ -146,6 +157,13 @@ impl DBOps for DataConnection { ) -> Result>, Self::Error> { query::get_commitment_tree(self, block_height) } + + fn get_witnesses( + &self, + block_height: BlockHeight, + ) -> Result)>, Self::Error> { + query::get_witnesses(self, block_height) + } } pub struct CacheConnection(Connection); diff --git a/zcash_client_sqlite/src/query.rs b/zcash_client_sqlite/src/query.rs index b34b92676..29382fede 100644 --- a/zcash_client_sqlite/src/query.rs +++ b/zcash_client_sqlite/src/query.rs @@ -4,7 +4,7 @@ use rusqlite::{OptionalExtension, NO_PARAMS}; use zcash_primitives::{ consensus::{self, BlockHeight}, - merkle_tree::CommitmentTree, + merkle_tree::{CommitmentTree, IncrementalWitness}, note_encryption::Memo, primitives::PaymentAddress, sapling::Node, @@ -261,6 +261,30 @@ pub fn get_commitment_tree( .map_err(SqliteClientError::from) } +pub fn get_witnesses( + data: &DataConnection, + block_height: BlockHeight, +) -> Result)>, SqliteClientError> { + let mut stmt_fetch_witnesses = data + .0 + .prepare("SELECT note, witness FROM sapling_witnesses WHERE block = ?")?; + let witnesses = stmt_fetch_witnesses + .query_map(&[u32::from(block_height)], |row| { + let id_note = NoteId(row.get(0)?); + let data: Vec = row.get(1)?; + Ok(IncrementalWitness::read(&data[..]).map(|witness| (id_note, witness))) + }) + .map_err(SqliteClientError::from)?; + + let mut res = vec![]; + for witness in witnesses { + // unwrap database error & IO error from IncrementalWitness::read + res.push(witness??); + } + + Ok(res) +} + #[cfg(test)] mod tests { use rusqlite::Connection; diff --git a/zcash_client_sqlite/src/scan.rs b/zcash_client_sqlite/src/scan.rs index 976ff9da6..4686bb0a3 100644 --- a/zcash_client_sqlite/src/scan.rs +++ b/zcash_client_sqlite/src/scan.rs @@ -24,7 +24,7 @@ use zcash_primitives::{ transaction::Transaction, }; -use crate::{error::SqliteClientError, CacheConnection, DataConnection}; +use crate::{error::SqliteClientError, CacheConnection, DataConnection, NoteId}; struct CompactBlockRow { height: BlockHeight, @@ -105,15 +105,7 @@ pub fn scan_cached_blocks( .map(|t| t.unwrap_or(CommitmentTree::new()))?; // Get most recent incremental witnesses for the notes we are tracking - let mut stmt_fetch_witnesses = data - .0 - .prepare("SELECT note, witness FROM sapling_witnesses WHERE block = ?")?; - let witnesses = stmt_fetch_witnesses.query_map(&[u32::from(last_height)], |row| { - let id_note = row.get(0)?; - let data: Vec<_> = row.get(1)?; - Ok(IncrementalWitness::read(&data[..]).map(|witness| WitnessRow { id_note, witness })) - })?; - let mut witnesses: Vec<_> = witnesses.collect::, _>>()??; + let mut witnesses = data.get_witnesses(last_height)?; // Get the nullifiers for the notes we are tracking let mut stmt_fetch_nullifiers = data @@ -209,7 +201,7 @@ pub fn scan_cached_blocks( let txs = { let nf_refs: Vec<_> = nullifiers.iter().map(|(nf, acc)| (&nf[..], *acc)).collect(); - let mut witness_refs: Vec<_> = witnesses.iter_mut().map(|w| &mut w.witness).collect(); + let mut witness_refs: Vec<_> = witnesses.iter_mut().map(|w| &mut w.1).collect(); scan_block( params, block, @@ -225,9 +217,9 @@ pub fn scan_cached_blocks( { let cur_root = tree.root(); for row in &witnesses { - if row.witness.root() != cur_root { + if row.1.root() != cur_root { return Err(SqliteClientError(Error::InvalidWitnessAnchor( - row.id_note, + row.0, last_height, ))); } @@ -305,7 +297,7 @@ pub fn scan_cached_blocks( // - A note value will never exceed 2^63 zatoshis. // First try updating an existing received note into the database. - let note_row = if stmt_update_note.execute(&[ + let note_id = if stmt_update_note.execute(&[ (output.account as i64).to_sql()?, output.to.diversifier().0.to_sql()?, (output.note.value as i64).to_sql()?, @@ -327,20 +319,17 @@ pub fn scan_cached_blocks( nf.to_sql()?, output.is_change.to_sql()?, ])?; - data.0.last_insert_rowid() + NoteId(data.0.last_insert_rowid()) } else { // It was there, so grab its row number. stmt_select_note.query_row( &[tx_row.to_sql()?, (output.index as i64).to_sql()?], - |row| row.get(0), + |row| row.get(0).map(NoteId), )? }; // Save witness for note. - witnesses.push(WitnessRow { - id_note: note_row, - witness: output.witness, - }); + witnesses.push((note_id, output.witness)); // Cache nullifier for note (to detect subsequent spends in this scan). nullifiers.push((nf, output.account)); @@ -352,11 +341,11 @@ pub fn scan_cached_blocks( for witness_row in witnesses.iter() { encoded.clear(); witness_row - .witness + .1 .write(&mut encoded) .expect("Should be able to write to a Vec"); stmt_insert_witness.execute(&[ - witness_row.id_note.to_sql()?, + (witness_row.0).0.to_sql()?, u32::from(last_height).to_sql()?, encoded.to_sql()?, ])?; @@ -567,7 +556,7 @@ mod tests { self, fake_compact_block, fake_compact_block_spending, insert_into_cache, sapling_activation_height, }, - AccountId, CacheConnection, DataConnection, + AccountId, CacheConnection, DataConnection, NoteId, }; use super::scan_cached_blocks; @@ -617,8 +606,8 @@ mod tests { Ok(_) => panic!("Should have failed"), Err(e) => { assert_eq!( - e.0.to_string(), - ChainInvalid::block_height_mismatch::( + e.to_string(), + ChainInvalid::block_height_mismatch::( sapling_activation_height() + 1, sapling_activation_height() + 2 )