Simplify block source & clean up chain validation.

This commit is contained in:
Kris Nuttycombe 2020-10-19 16:52:48 -06:00
parent 8a215d67fe
commit 897a70dd9e
7 changed files with 155 additions and 229 deletions

View File

@ -202,26 +202,14 @@ pub trait BlockSource {
fn init_cache(&self) -> Result<(), Self::Error>;
// Validate the cached chain by applying a function that checks pairwise constraints
// (top_block :: &CompactBlock, next_block :: &CompactBlock) -> Result<(), Self::Error>
// beginning with the current maximum height walking backward through the chain, terminating
// with the block at `from_height`. Returns the hash of the block at height `from_height`
fn validate_chain<F>(
&self,
from_height: BlockHeight,
validate: F,
) -> Result<Option<BlockHash>, Self::Error>
where
F: Fn(&CompactBlock, &CompactBlock) -> Result<(), Self::Error>;
fn with_cached_blocks<F>(
fn with_blocks<F>(
&self,
from_height: BlockHeight,
limit: Option<u32>,
with_row: F,
) -> Result<(), Self::Error>
where
F: FnMut(BlockHeight, CompactBlock) -> Result<(), Self::Error>;
F: FnMut(CompactBlock) -> Result<(), Self::Error>;
}
pub trait ShieldedOutput {

View File

@ -38,7 +38,7 @@ use crate::{
/// - `Err(e)` if there was an error during validation unrelated to chain validity.
///
/// This function does not mutate either of the databases.
pub fn validate_combined_chain<'db, E0, N, E, P, C>(
pub fn validate_chain<'db, E0, N, E, P, C>(
parameters: &P,
cache: &C,
validate_from: Option<(BlockHeight, BlockHash)>,
@ -59,32 +59,26 @@ where
let from_height = validate_from
.map(|(height, _)| height)
.unwrap_or(sapling_activation_height - 1);
let scan_start_hash = cache.validate_chain(from_height, |top_block, next_block| {
if next_block.height() != top_block.height() - 1 {
Err(
ChainInvalid::block_height_mismatch(top_block.height() - 1, next_block.height())
.into(),
)
} else if next_block.hash() != top_block.prev_hash() {
Err(ChainInvalid::prev_hash_mismatch(next_block.height()).into())
} else {
Ok(())
}
})?;
match (scan_start_hash, validate_from) {
(Some(scan_start_hash), Some((validate_from_height, validate_from_hash))) => {
if scan_start_hash == validate_from_hash {
Ok(())
} else {
Err(ChainInvalid::prev_hash_mismatch(validate_from_height).into())
let mut prev_height = from_height;
let mut prev_hash: Option<BlockHash> = validate_from.map(|(_, hash)| hash);
cache.with_blocks(from_height, None, move |block| {
let current_height = block.height();
let result = if current_height != prev_height + 1 {
Err(ChainInvalid::block_height_discontinuity(prev_height + 1, current_height).into())
} else {
match prev_hash {
None => Ok(()),
Some(h) if h == block.prev_hash() => Ok(()),
Some(_) => Err(ChainInvalid::prev_hash_mismatch(current_height).into()),
}
}
_ => {
// No cached blocks are present, or the max data height is absent; this is fine.
Ok(())
}
}
};
prev_height = current_height;
prev_hash = Some(block.hash());
result
})
}
/// Scans at most `limit` new blocks added to the cache for any transactions received by
@ -169,111 +163,110 @@ where
// Get the nullifiers for the notes we are tracking
let mut nullifiers = data.get_nullifiers()?;
cache.with_cached_blocks(
last_height,
limit,
|height: BlockHeight, block: CompactBlock| {
// Scanned blocks MUST be height-sequential.
if height != (last_height + 1) {
return Err(ChainInvalid::block_height_mismatch(last_height + 1, height).into());
}
last_height = height;
cache.with_blocks(last_height, limit, |block: CompactBlock| {
let current_height = block.height();
// Scanned blocks MUST be height-sequential.
if current_height != (last_height + 1) {
return Err(
ChainInvalid::block_height_discontinuity(last_height + 1, current_height).into(),
);
}
last_height = current_height;
let block_hash = BlockHash::from_slice(&block.hash);
let block_time = block.time;
let block_hash = BlockHash::from_slice(&block.hash);
let block_time = block.time;
let txs: Vec<WalletTx> = {
let nf_refs: Vec<_> = nullifiers
.iter()
.map(|(nf, acc)| (&nf[..], acc.0 as usize))
.collect();
let mut witness_refs: Vec<_> = witnesses.iter_mut().map(|w| &mut w.1).collect();
scan_block(
params,
block,
&extfvks[..],
&nf_refs,
&mut tree,
&mut witness_refs[..],
)
};
let txs: Vec<WalletTx> = {
let nf_refs: Vec<_> = nullifiers
.iter()
.map(|(nf, acc)| (&nf[..], acc.0 as usize))
.collect();
let mut witness_refs: Vec<_> = witnesses.iter_mut().map(|w| &mut w.1).collect();
scan_block(
params,
block,
&extfvks[..],
&nf_refs,
&mut tree,
&mut witness_refs[..],
)
};
// Enforce that all roots match. This is slow, so only include in debug builds.
#[cfg(debug_assertions)]
{
let cur_root = tree.root();
for row in &witnesses {
if row.1.root() != cur_root {
return Err(Error::InvalidWitnessAnchor(row.0, last_height).into());
}
// Enforce that all roots match. This is slow, so only include in debug builds.
#[cfg(debug_assertions)]
{
let cur_root = tree.root();
for row in &witnesses {
if row.1.root() != cur_root {
return Err(Error::InvalidWitnessAnchor(row.0, last_height).into());
}
for tx in &txs {
for output in tx.shielded_outputs.iter() {
if output.witness.root() != cur_root {
return Err(Error::InvalidNewWitnessAnchor(
output.index,
tx.txid,
last_height,
output.witness.root(),
)
.into());
}
}
for tx in &txs {
for output in tx.shielded_outputs.iter() {
if output.witness.root() != cur_root {
return Err(Error::InvalidNewWitnessAnchor(
output.index,
tx.txid,
last_height,
output.witness.root(),
)
.into());
}
}
}
}
// database updates for each block are transactional
let mut db_update = data.get_update_ops()?;
db_update.transactionally(|up| {
// Insert the block into the database.
up.insert_block(height, block_hash, block_time, &tree)?;
// database updates for each block are transactional
let mut db_update = data.get_update_ops()?;
db_update.transactionally(|up| {
// Insert the block into the database.
up.insert_block(current_height, block_hash, block_time, &tree)?;
for tx in txs {
let tx_row = up.put_tx_meta(&tx, height)?;
for tx in txs {
let tx_row = up.put_tx_meta(&tx, current_height)?;
// Mark notes as spent and remove them from the scanning cache
for spend in &tx.shielded_spends {
up.mark_spent(tx_row, &spend.nf)?;
}
nullifiers.retain(|(nf, _acc)| {
tx.shielded_spends
.iter()
.find(|spend| &spend.nf == nf)
.is_none()
});
for output in tx.shielded_outputs {
let nf = output.note.nf(
&extfvks[output.account].fvk.vk,
output.witness.position() as u64,
);
let note_id = up.put_received_note(&output, Some(&nf), tx_row)?;
// Save witness for note.
witnesses.push((note_id, output.witness));
// Cache nullifier for note (to detect subsequent spends in this scan).
nullifiers.push((nf, AccountId(output.account as u32)));
}
// Mark notes as spent and remove them from the scanning cache
for spend in &tx.shielded_spends {
up.mark_spent(tx_row, &spend.nf)?;
}
// Insert current witnesses into the database.
for (note_id, witness) in witnesses.iter() {
up.insert_witness(*note_id, witness, last_height)?;
nullifiers.retain(|(nf, _acc)| {
tx.shielded_spends
.iter()
.find(|spend| &spend.nf == nf)
.is_none()
});
for output in tx.shielded_outputs {
let nf = output.note.nf(
&extfvks[output.account].fvk.vk,
output.witness.position() as u64,
);
let note_id = up.put_received_note(&output, Some(&nf), tx_row)?;
// Save witness for note.
witnesses.push((note_id, output.witness));
// Cache nullifier for note (to detect subsequent spends in this scan).
nullifiers.push((nf, AccountId(output.account as u32)));
}
}
// Prune the stored witnesses (we only expect rollbacks of at most 100 blocks).
up.prune_witnesses(last_height - 100)?;
// Insert current witnesses into the database.
for (note_id, witness) in witnesses.iter() {
up.insert_witness(*note_id, witness, last_height)?;
}
// Update now-expired transactions that didn't get mined.
up.update_expired_notes(last_height)?;
// Prune the stored witnesses (we only expect rollbacks of at most 100 blocks).
up.prune_witnesses(last_height - 100)?;
Ok(())
})
},
)?;
// Update now-expired transactions that didn't get mined.
up.update_expired_notes(last_height)?;
Ok(())
})
})?;
Ok(())
}

View File

@ -12,12 +12,12 @@ use crate::wallet::AccountId;
// Reasons a cached/scanned chain can fail validation.
pub enum ChainInvalid {
// A block's recorded `prev_hash` did not match the hash of the block that
// actually precedes it in the cache.
PrevHashMismatch,
/// (expected_height, actual_height)
BlockHeightMismatch(BlockHeight),
// NOTE(review): `BlockHeightMismatch` and `BlockHeightDiscontinuity` appear to be
// the before/after halves of a rename captured by this diff view — confirm that
// only one of the two variants exists in the final source.
BlockHeightDiscontinuity(BlockHeight),
}
#[derive(Debug)]
pub enum Error<DbError, NoteId> {
CorruptedData(&'static str),
CorruptedData(String),
IncorrectHRPExtFVK,
InsufficientBalance(Amount, Amount),
InvalidChain(BlockHeight, ChainInvalid),
@ -42,8 +42,11 @@ impl ChainInvalid {
Error::InvalidChain(at_height, ChainInvalid::PrevHashMismatch)
}
pub fn block_height_mismatch<E, N>(at_height: BlockHeight, found: BlockHeight) -> Error<E, N> {
Error::InvalidChain(at_height, ChainInvalid::BlockHeightMismatch(found))
pub fn block_height_discontinuity<E, N>(
at_height: BlockHeight,
found: BlockHeight,
) -> Error<E, N> {
Error::InvalidChain(at_height, ChainInvalid::BlockHeightDiscontinuity(found))
}
}

View File

@ -12,7 +12,7 @@
//! data_api::{
//! WalletRead,
//! chain::{
//! validate_combined_chain,
//! validate_chain,
//! scan_cached_blocks,
//! },
//! error::Error,
@ -37,7 +37,7 @@
//! //
//! // Given that we assume the server always gives us correct-at-the-time blocks, any
//! // errors are in the blocks we have previously cached or scanned.
//! if let Err(e) = validate_combined_chain(&network, &db_cache, (&db_data).get_max_height_hash().unwrap()) {
//! if let Err(e) = validate_chain(&network, &db_cache, (&db_data).get_max_height_hash().unwrap()) {
//! match e.0 {
//! Error::InvalidChain(upper_bound, _) => {
//! // a) Pick a height to rewind to.
@ -77,12 +77,9 @@ use protobuf::parse_from_bytes;
use rusqlite::types::ToSql;
use zcash_primitives::{block::BlockHash, consensus::BlockHeight};
use zcash_primitives::consensus::BlockHeight;
use zcash_client_backend::{
data_api::error::{ChainInvalid, Error},
proto::compact_formats::CompactBlock,
};
use zcash_client_backend::{data_api::error::Error, proto::compact_formats::CompactBlock};
use crate::{error::SqliteClientError, CacheConnection};
@ -93,69 +90,14 @@ struct CompactBlockRow {
data: Vec<u8>,
}
/// Validates the chain of cached compact blocks, walking *backwards* from the
/// highest cached height down to `from_height`.
///
/// `validate` is invoked once per adjacent pair as
/// `validate(&higher_block, &next_lower_block)` and may reject the pair.
///
/// Returns `Ok(Some(hash))` with the hash of the lowest block visited (the
/// block at `from_height`, assuming the walk found no gaps), or `Ok(None)`
/// when no blocks at or above `from_height` are cached.
pub fn validate_chain<F>(
conn: &CacheConnection,
from_height: BlockHeight,
validate: F,
) -> Result<Option<BlockHash>, SqliteClientError>
where
F: Fn(&CompactBlock, &CompactBlock) -> Result<(), SqliteClientError>,
{
// Rows come back in descending height order, so iteration proceeds tip-first.
let mut stmt_blocks = conn
.0
.prepare("SELECT height, data FROM compactblocks WHERE height >= ? ORDER BY height DESC")?;
let block_rows = stmt_blocks.query_map(&[u32::from(from_height)], |row| {
let height: BlockHeight = row.get(0).map(u32::into)?;
let data = row.get::<_, Vec<_>>(1)?;
Ok(CompactBlockRow { height, data })
})?;
// Lazily parse each row, cross-checking the height encoded in the serialized
// block against the row's `height` column; a mismatch means corrupt cache data.
let mut blocks = block_rows.map(|cbr_result| {
let cbr = cbr_result.map_err(Error::Database)?;
let block: CompactBlock = parse_from_bytes(&cbr.data).map_err(Error::from)?;
if block.height() == cbr.height {
Ok(block)
} else {
Err(ChainInvalid::block_height_mismatch(
cbr.height,
block.height(),
))
}
});
// Seed the walk with the highest cached block, if any.
let mut current_block: CompactBlock = match blocks.next() {
Some(Ok(block)) => block,
Some(Err(error)) => {
return Err(SqliteClientError(error));
}
None => {
// No cached blocks, and we've already validated the blocks we've scanned,
// so there's nothing to validate.
// TODO: Maybe we still want to check if there are cached blocks that are
// at heights we previously scanned? Check scanning flow again.
return Ok(None);
}
};
// Check every adjacent (higher, lower) pair on the way down to `from_height`.
for block_result in blocks {
let block = block_result?;
validate(&current_block, &block)?;
current_block = block;
}
// `current_block` is now the last (lowest) block visited.
Ok(Some(current_block.hash()))
}
pub fn with_cached_blocks<F>(
pub fn with_blocks<F>(
cache: &CacheConnection,
from_height: BlockHeight,
limit: Option<u32>,
mut with_row: F,
) -> Result<(), SqliteClientError>
where
F: FnMut(BlockHeight, CompactBlock) -> Result<(), SqliteClientError>,
F: FnMut(CompactBlock) -> Result<(), SqliteClientError>,
{
// Fetch the CompactBlocks we need to scan
let mut stmt_blocks = cache.0.prepare(
@ -175,8 +117,19 @@ where
)?;
for row_result in rows {
let row = row_result?;
with_row(row.height, parse_from_bytes(&row.data)?)?;
let cbr = row_result?;
let block: CompactBlock = parse_from_bytes(&cbr.data)?;
if block.height() != cbr.height {
return Err(Error::CorruptedData(format!(
"Block height {} did not match row's height field value {}",
block.height(),
cbr.height
))
.into());
}
with_row(block)?;
}
Ok(())
@ -195,7 +148,7 @@ mod tests {
use zcash_client_backend::data_api::WalletRead;
use zcash_client_backend::data_api::{
chain::{scan_cached_blocks, validate_combined_chain},
chain::{scan_cached_blocks, validate_chain},
error::{ChainInvalid, Error},
};
@ -229,7 +182,7 @@ mod tests {
init_accounts_table(&db_data, &tests::network(), &[extfvk.clone()]).unwrap();
// Empty chain should be valid
validate_combined_chain(
validate_chain(
&tests::network(),
&db_cache,
(&db_data).get_max_height_hash().unwrap(),
@ -246,7 +199,7 @@ mod tests {
insert_into_cache(&db_cache, &cb);
// Cache-only chain should be valid
validate_combined_chain(
validate_chain(
&tests::network(),
&db_cache,
(&db_data).get_max_height_hash().unwrap(),
@ -257,7 +210,7 @@ mod tests {
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
// Data-only chain should be valid
validate_combined_chain(
validate_chain(
&tests::network(),
&db_cache,
(&db_data).get_max_height_hash().unwrap(),
@ -274,7 +227,7 @@ mod tests {
insert_into_cache(&db_cache, &cb2);
// Data+cache chain should be valid
validate_combined_chain(
validate_chain(
&tests::network(),
&db_cache,
(&db_data).get_max_height_hash().unwrap(),
@ -285,7 +238,7 @@ mod tests {
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
// Data-only chain should be valid
validate_combined_chain(
validate_chain(
&tests::network(),
&db_cache,
(&db_data).get_max_height_hash().unwrap(),
@ -328,7 +281,7 @@ mod tests {
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
// Data-only chain should be valid
validate_combined_chain(
validate_chain(
&tests::network(),
&db_cache,
(&db_data).get_max_height_hash().unwrap(),
@ -352,7 +305,7 @@ mod tests {
insert_into_cache(&db_cache, &cb4);
// Data+cache chain should be invalid at the data/cache boundary
match validate_combined_chain(
match validate_chain(
&tests::network(),
&db_cache,
(&db_data).get_max_height_hash().unwrap(),
@ -360,7 +313,7 @@ mod tests {
.map_err(|e| e.0)
{
Err(Error::InvalidChain(upper_bound, _)) => {
assert_eq!(upper_bound, sapling_activation_height() + 1)
assert_eq!(upper_bound, sapling_activation_height() + 2)
}
_ => panic!(),
}
@ -401,7 +354,7 @@ mod tests {
scan_cached_blocks(&tests::network(), &db_cache, &db_data, None).unwrap();
// Data-only chain should be valid
validate_combined_chain(
validate_chain(
&tests::network(),
&db_cache,
(&db_data).get_max_height_hash().unwrap(),
@ -425,7 +378,7 @@ mod tests {
insert_into_cache(&db_cache, &cb4);
// Data+cache chain should be invalid inside the cache
match validate_combined_chain(
match validate_chain(
&tests::network(),
&db_cache,
(&db_data).get_max_height_hash().unwrap(),
@ -433,7 +386,7 @@ mod tests {
.map_err(|e| e.0)
{
Err(Error::InvalidChain(upper_bound, _)) => {
assert_eq!(upper_bound, sapling_activation_height() + 2)
assert_eq!(upper_bound, sapling_activation_height() + 3)
}
_ => panic!(),
}
@ -543,7 +496,7 @@ mod tests {
Err(e) => {
assert_eq!(
e.to_string(),
ChainInvalid::block_height_mismatch::<rusqlite::Error, NoteId>(
ChainInvalid::block_height_discontinuity::<rusqlite::Error, NoteId>(
sapling_activation_height() + 1,
sapling_activation_height() + 2
)

View File

@ -503,7 +503,7 @@ impl<'a> WalletWrite for DataConnStmtCache<'a> {
AccountId(output.account as u32),
&RecipientAddress::Shielded(output.to.clone()),
Amount::from_u64(output.note.value)
.map_err(|_| Error::CorruptedData("Note value invalid."))?,
.map_err(|_| Error::CorruptedData("Note value invalid.".to_string()))?,
Some(output.memo.clone()),
)?
}
@ -551,27 +551,16 @@ impl BlockSource for CacheConnection {
chain::init::init_cache_database(self).map_err(SqliteClientError::from)
}
fn validate_chain<F>(
&self,
from_height: BlockHeight,
validate: F,
) -> Result<Option<BlockHash>, Self::Error>
where
F: Fn(&CompactBlock, &CompactBlock) -> Result<(), Self::Error>,
{
chain::validate_chain(self, from_height, validate)
}
fn with_cached_blocks<F>(
fn with_blocks<F>(
&self,
from_height: BlockHeight,
limit: Option<u32>,
with_row: F,
) -> Result<(), Self::Error>
where
F: FnMut(BlockHeight, CompactBlock) -> Result<(), Self::Error>,
F: FnMut(CompactBlock) -> Result<(), Self::Error>,
{
chain::with_cached_blocks(self, from_height, limit, with_row)
chain::with_blocks(self, from_height, limit, with_row)
}
}

View File

@ -139,7 +139,7 @@ pub fn get_balance(data: &DataConnection, account: AccountId) -> Result<Amount,
match Amount::from_i64(balance) {
Ok(amount) if !amount.is_negative() => Ok(amount),
_ => Err(SqliteClientError(Error::CorruptedData(
"Sum of values in received_notes is out of range",
"Sum of values in received_notes is out of range".to_string(),
))),
}
}
@ -178,7 +178,7 @@ pub fn get_verified_balance(
match Amount::from_i64(balance) {
Ok(amount) if !amount.is_negative() => Ok(amount),
_ => Err(SqliteClientError(Error::CorruptedData(
"Sum of values in received_notes is out of range",
"Sum of values in received_notes is out of range".to_string(),
))),
}
}

View File

@ -77,7 +77,7 @@ pub fn select_spendable_notes(
let d: Vec<_> = row.get(0)?;
if d.len() != 11 {
return Err(SqliteClientError(Error::CorruptedData(
"Invalid diversifier length",
"Invalid diversifier length".to_string(),
)));
}
let mut tmp = [0; 11];