zcash_client_sqlite: Replace Either-based definition of `wallet::commitment_tree::Error` with a bespoke error type.

This commit is contained in:
Kris Nuttycombe 2023-08-03 11:19:40 -06:00
parent eade2acab9
commit 0ee45e40c4
5 changed files with 112 additions and 90 deletions

View File

@ -32,7 +32,6 @@ tracing = "0.1"
# - Serialization
byteorder = "1"
prost = "0.11"
either = "1.8"
group = "0.13"
jubjub = "0.10"

View File

@ -1,14 +1,13 @@
//! Error types for problems that may arise when reading or storing wallet data to SQLite.
use either::Either;
use std::error;
use std::fmt;
use std::io;
use shardtree::error::ShardTreeError;
use zcash_client_backend::encoding::{Bech32DecodeError, TransparentCodecError};
use zcash_primitives::{consensus::BlockHeight, zip32::AccountId};
use crate::wallet::commitment_tree;
use crate::PRUNING_DEPTH;
#[cfg(feature = "transparent-inputs")]
@ -85,7 +84,7 @@ pub enum SqliteClientError {
/// An error occurred in inserting data into or accessing data from one of the wallet's note
/// commitment trees.
CommitmentTree(ShardTreeError<Either<io::Error, rusqlite::Error>>),
CommitmentTree(ShardTreeError<commitment_tree::Error>),
}
impl error::Error for SqliteClientError {
@ -176,8 +175,8 @@ impl From<zcash_primitives::memo::Error> for SqliteClientError {
}
}
impl From<ShardTreeError<Either<io::Error, rusqlite::Error>>> for SqliteClientError {
fn from(e: ShardTreeError<Either<io::Error, rusqlite::Error>>) -> Self {
impl From<ShardTreeError<commitment_tree::Error>> for SqliteClientError {
fn from(e: ShardTreeError<commitment_tree::Error>) -> Self {
SqliteClientError::CommitmentTree(e)
}
}

View File

@ -32,7 +32,6 @@
// Catch documentation errors caused by code changes.
#![deny(rustdoc::broken_intra_doc_links)]
use either::Either;
use maybe_rayon::{
prelude::{IndexedParallelIterator, ParallelIterator},
slice::ParallelSliceMut,
@ -86,7 +85,7 @@ pub mod error;
pub mod serialization;
pub mod wallet;
use wallet::commitment_tree::put_shard_roots;
use wallet::commitment_tree::{self, put_shard_roots};
/// The maximum number of blocks the wallet is allowed to rewind. This is
/// consistent with the bound in zcashd, and allows block data deeper than
@ -726,7 +725,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
}
impl<P: consensus::Parameters> WalletCommitmentTrees for WalletDb<rusqlite::Connection, P> {
type Error = Either<io::Error, rusqlite::Error>;
type Error = commitment_tree::Error;
type SaplingShardStore<'a> =
SqliteShardStore<&'a rusqlite::Transaction<'a>, sapling::Node, SAPLING_SHARD_HEIGHT>;
@ -739,21 +738,21 @@ impl<P: consensus::Parameters> WalletCommitmentTrees for WalletDb<rusqlite::Conn
SAPLING_SHARD_HEIGHT,
>,
) -> Result<A, E>,
E: From<ShardTreeError<Either<io::Error, rusqlite::Error>>>,
E: From<ShardTreeError<Self::Error>>,
{
let tx = self
.conn
.transaction()
.map_err(|e| ShardTreeError::Storage(Either::Right(e)))?;
.map_err(|e| ShardTreeError::Storage(commitment_tree::Error::Query(e)))?;
let shard_store = SqliteShardStore::from_connection(&tx, SAPLING_TABLES_PREFIX)
.map_err(|e| ShardTreeError::Storage(Either::Right(e)))?;
.map_err(|e| ShardTreeError::Storage(commitment_tree::Error::Query(e)))?;
let result = {
let mut shardtree = ShardTree::new(shard_store, PRUNING_DEPTH.try_into().unwrap());
callback(&mut shardtree)?
};
tx.commit()
.map_err(|e| ShardTreeError::Storage(Either::Right(e)))?;
.map_err(|e| ShardTreeError::Storage(commitment_tree::Error::Query(e)))?;
Ok(result)
}
@ -765,7 +764,7 @@ impl<P: consensus::Parameters> WalletCommitmentTrees for WalletDb<rusqlite::Conn
let tx = self
.conn
.transaction()
.map_err(|e| ShardTreeError::Storage(Either::Right(e)))?;
.map_err(|e| ShardTreeError::Storage(commitment_tree::Error::Query(e)))?;
put_shard_roots::<_, { sapling::NOTE_COMMITMENT_TREE_DEPTH }, SAPLING_SHARD_HEIGHT>(
&tx,
SAPLING_TABLES_PREFIX,
@ -773,13 +772,13 @@ impl<P: consensus::Parameters> WalletCommitmentTrees for WalletDb<rusqlite::Conn
roots,
)?;
tx.commit()
.map_err(|e| ShardTreeError::Storage(Either::Right(e)))?;
.map_err(|e| ShardTreeError::Storage(commitment_tree::Error::Query(e)))?;
Ok(())
}
}
impl<'conn, P: consensus::Parameters> WalletCommitmentTrees for WalletDb<SqlTransaction<'conn>, P> {
type Error = Either<io::Error, rusqlite::Error>;
type Error = commitment_tree::Error;
type SaplingShardStore<'a> =
SqliteShardStore<&'a rusqlite::Transaction<'a>, sapling::Node, SAPLING_SHARD_HEIGHT>;
@ -792,11 +791,11 @@ impl<'conn, P: consensus::Parameters> WalletCommitmentTrees for WalletDb<SqlTran
SAPLING_SHARD_HEIGHT,
>,
) -> Result<A, E>,
E: From<ShardTreeError<Either<io::Error, rusqlite::Error>>>,
E: From<ShardTreeError<commitment_tree::Error>>,
{
let mut shardtree = ShardTree::new(
SqliteShardStore::from_connection(self.conn.0, SAPLING_TABLES_PREFIX)
.map_err(|e| ShardTreeError::Storage(Either::Right(e)))?,
.map_err(|e| ShardTreeError::Storage(commitment_tree::Error::Query(e)))?,
PRUNING_DEPTH.try_into().unwrap(),
);
let result = callback(&mut shardtree)?;

View File

@ -1,7 +1,7 @@
use either::Either;
use rusqlite::{self, named_params, OptionalExtension};
use std::{
collections::BTreeSet,
error, fmt,
io::{self, Cursor},
marker::PhantomData,
sync::Arc,
@ -19,6 +19,33 @@ use zcash_primitives::{consensus::BlockHeight, merkle_tree::HashSer};
use crate::serialization::{read_shard, write_shard};
/// Errors that can appear in SQLite-back [`ShardStore`] implementation operations.
#[derive(Debug)]
pub enum Error {
/// Errors in deserializing stored shard data
Serialization(io::Error),
/// Errors encountered querying stored shard data
Query(rusqlite::Error),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match &self {
Error::Serialization(err) => write!(f, "Commitment tree serializtion error: {}", err),
Error::Query(err) => write!(f, "Commitment tree query or update error: {}", err),
}
}
}
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match &self {
Error::Serialization(e) => Some(e),
Error::Query(e) => Some(e),
}
}
}
pub struct SqliteShardStore<C, H, const SHARD_HEIGHT: u8> {
pub(crate) conn: C,
table_prefix: &'static str,
@ -45,7 +72,7 @@ impl<'conn, 'a: 'conn, H: HashSer, const SHARD_HEIGHT: u8> ShardStore
{
type H = H;
type CheckpointId = BlockHeight;
type Error = Either<io::Error, rusqlite::Error>;
type Error = Error;
fn get_shard(
&self,
@ -147,7 +174,7 @@ impl<H: HashSer, const SHARD_HEIGHT: u8> ShardStore
{
type H = H;
type CheckpointId = BlockHeight;
type Error = Either<io::Error, rusqlite::Error>;
type Error = Error;
fn get_shard(
&self,
@ -161,9 +188,9 @@ impl<H: HashSer, const SHARD_HEIGHT: u8> ShardStore
}
fn put_shard(&mut self, subtree: LocatedPrunableTree<Self::H>) -> Result<(), Self::Error> {
let tx = self.conn.transaction().map_err(Either::Right)?;
let tx = self.conn.transaction().map_err(Error::Query)?;
put_shard(&tx, self.table_prefix, subtree)?;
tx.commit().map_err(Either::Right)?;
tx.commit().map_err(Error::Query)?;
Ok(())
}
@ -196,9 +223,9 @@ impl<H: HashSer, const SHARD_HEIGHT: u8> ShardStore
checkpoint_id: Self::CheckpointId,
checkpoint: Checkpoint,
) -> Result<(), Self::Error> {
let tx = self.conn.transaction().map_err(Either::Right)?;
let tx = self.conn.transaction().map_err(Error::Query)?;
add_checkpoint(&tx, self.table_prefix, checkpoint_id, checkpoint)?;
tx.commit().map_err(Either::Right)
tx.commit().map_err(Error::Query)
}
fn checkpoint_count(&self) -> Result<usize, Self::Error> {
@ -223,9 +250,9 @@ impl<H: HashSer, const SHARD_HEIGHT: u8> ShardStore
where
F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>,
{
let tx = self.conn.transaction().map_err(Either::Right)?;
let tx = self.conn.transaction().map_err(Error::Query)?;
with_checkpoints(&tx, self.table_prefix, limit, callback)?;
tx.commit().map_err(Either::Right)
tx.commit().map_err(Error::Query)
}
fn update_checkpoint_with<F>(
@ -236,30 +263,28 @@ impl<H: HashSer, const SHARD_HEIGHT: u8> ShardStore
where
F: Fn(&mut Checkpoint) -> Result<(), Self::Error>,
{
let tx = self.conn.transaction().map_err(Either::Right)?;
let tx = self.conn.transaction().map_err(Error::Query)?;
let result = update_checkpoint_with(&tx, self.table_prefix, *checkpoint_id, update)?;
tx.commit().map_err(Either::Right)?;
tx.commit().map_err(Error::Query)?;
Ok(result)
}
fn remove_checkpoint(&mut self, checkpoint_id: &Self::CheckpointId) -> Result<(), Self::Error> {
let tx = self.conn.transaction().map_err(Either::Right)?;
let tx = self.conn.transaction().map_err(Error::Query)?;
remove_checkpoint(&tx, self.table_prefix, *checkpoint_id)?;
tx.commit().map_err(Either::Right)
tx.commit().map_err(Error::Query)
}
fn truncate_checkpoints(
&mut self,
checkpoint_id: &Self::CheckpointId,
) -> Result<(), Self::Error> {
let tx = self.conn.transaction().map_err(Either::Right)?;
let tx = self.conn.transaction().map_err(Error::Query)?;
truncate_checkpoints(&tx, self.table_prefix, *checkpoint_id)?;
tx.commit().map_err(Either::Right)
tx.commit().map_err(Error::Query)
}
}
type Error = Either<io::Error, rusqlite::Error>;
pub(crate) fn get_shard<H: HashSer>(
conn: &rusqlite::Connection,
table_prefix: &'static str,
@ -276,12 +301,12 @@ pub(crate) fn get_shard<H: HashSer>(
|row| Ok((row.get::<_, Vec<u8>>(0)?, row.get::<_, Option<Vec<u8>>>(1)?)),
)
.optional()
.map_err(Either::Right)?
.map_err(Error::Query)?
.map(|(shard_data, root_hash)| {
let shard_tree = read_shard(&mut Cursor::new(shard_data)).map_err(Either::Left)?;
let shard_tree = read_shard(&mut Cursor::new(shard_data)).map_err(Error::Serialization)?;
let located_tree = LocatedPrunableTree::from_parts(shard_root_addr, shard_tree);
if let Some(root_hash_data) = root_hash {
let root_hash = H::read(Cursor::new(root_hash_data)).map_err(Either::Left)?;
let root_hash = H::read(Cursor::new(root_hash_data)).map_err(Error::Serialization)?;
Ok(located_tree.reannotate_root(Some(Arc::new(root_hash))))
} else {
Ok(located_tree)
@ -311,10 +336,10 @@ pub(crate) fn last_shard<H: HashSer>(
},
)
.optional()
.map_err(Either::Right)?
.map_err(Error::Query)?
.map(|(shard_index, shard_data)| {
let shard_root = Address::from_parts(shard_root_level, shard_index);
let shard_tree = read_shard(&mut Cursor::new(shard_data)).map_err(Either::Left)?;
let shard_tree = read_shard(&mut Cursor::new(shard_data)).map_err(Error::Serialization)?;
Ok(LocatedPrunableTree::from_parts(shard_root, shard_tree))
})
.transpose()
@ -336,10 +361,10 @@ pub(crate) fn put_shard<H: HashSer>(
})
})
.transpose()
.map_err(Either::Left)?;
.map_err(Error::Serialization)?;
let mut subtree_data = vec![];
write_shard(&mut subtree_data, subtree.root()).map_err(Either::Left)?;
write_shard(&mut subtree_data, subtree.root()).map_err(Error::Serialization)?;
let mut stmt_put_shard = conn
.prepare_cached(&format!(
@ -350,7 +375,7 @@ pub(crate) fn put_shard<H: HashSer>(
shard_data = :shard_data",
table_prefix
))
.map_err(Either::Right)?;
.map_err(Error::Query)?;
stmt_put_shard
.execute(named_params![
@ -358,7 +383,7 @@ pub(crate) fn put_shard<H: HashSer>(
":root_hash": subtree_root_hash,
":shard_data": subtree_data
])
.map_err(Either::Right)?;
.map_err(Error::Query)?;
Ok(())
}
@ -373,14 +398,14 @@ pub(crate) fn get_shard_roots(
"SELECT shard_index FROM {}_tree_shards ORDER BY shard_index",
table_prefix
))
.map_err(Either::Right)?;
let mut rows = stmt.query([]).map_err(Either::Right)?;
.map_err(Error::Query)?;
let mut rows = stmt.query([]).map_err(Error::Query)?;
let mut res = vec![];
while let Some(row) = rows.next().map_err(Either::Right)? {
while let Some(row) = rows.next().map_err(Error::Query)? {
res.push(Address::from_parts(
shard_root_level,
row.get(0).map_err(Either::Right)?,
row.get(0).map_err(Error::Query)?,
));
}
Ok(res)
@ -398,7 +423,7 @@ pub(crate) fn truncate(
),
[from.index()],
)
.map_err(Either::Right)
.map_err(Error::Query)
.map(|_| ())
}
@ -412,10 +437,10 @@ pub(crate) fn get_cap<H: HashSer>(
|row| row.get::<_, Vec<u8>>(0),
)
.optional()
.map_err(Either::Right)?
.map_err(Error::Query)?
.map_or_else(
|| Ok(PrunableTree::empty()),
|cap_data| read_shard(&mut Cursor::new(cap_data)).map_err(Either::Left),
|cap_data| read_shard(&mut Cursor::new(cap_data)).map_err(Error::Serialization),
)
}
@ -432,11 +457,11 @@ pub(crate) fn put_cap<H: HashSer>(
SET cap_data = :cap_data",
table_prefix
))
.map_err(Either::Right)?;
.map_err(Error::Query)?;
let mut cap_data = vec![];
write_shard(&mut cap_data, &cap).map_err(Either::Left)?;
stmt.execute([cap_data]).map_err(Either::Right)?;
write_shard(&mut cap_data, &cap).map_err(Error::Serialization)?;
stmt.execute([cap_data]).map_err(Error::Query)?;
Ok(())
}
@ -456,7 +481,7 @@ pub(crate) fn min_checkpoint_id(
.map(|opt| opt.map(BlockHeight::from))
},
)
.map_err(Either::Right)
.map_err(Error::Query)
}
pub(crate) fn max_checkpoint_id(
@ -474,7 +499,7 @@ pub(crate) fn max_checkpoint_id(
.map(|opt| opt.map(BlockHeight::from))
},
)
.map_err(Either::Right)
.map_err(Error::Query)
}
pub(crate) fn add_checkpoint(
@ -489,14 +514,14 @@ pub(crate) fn add_checkpoint(
VALUES (:checkpoint_id, :position)",
table_prefix
))
.map_err(Either::Right)?;
.map_err(Error::Query)?;
stmt_insert_checkpoint
.execute(named_params![
":checkpoint_id": u32::from(checkpoint_id),
":position": checkpoint.position().map(u64::from)
])
.map_err(Either::Right)?;
.map_err(Error::Query)?;
let mut stmt_insert_mark_removed = conn
.prepare_cached(&format!(
@ -504,7 +529,7 @@ pub(crate) fn add_checkpoint(
VALUES (:checkpoint_id, :position)",
table_prefix
))
.map_err(Either::Right)?;
.map_err(Error::Query)?;
for pos in checkpoint.marks_removed() {
stmt_insert_mark_removed
@ -512,7 +537,7 @@ pub(crate) fn add_checkpoint(
":checkpoint_id": u32::from(checkpoint_id),
":position": u64::from(*pos)
])
.map_err(Either::Right)?;
.map_err(Error::Query)?;
}
Ok(())
@ -527,7 +552,7 @@ pub(crate) fn checkpoint_count(
[],
|row| row.get::<_, usize>(0),
)
.map_err(Either::Right)
.map_err(Error::Query)
}
pub(crate) fn get_checkpoint(
@ -550,7 +575,7 @@ pub(crate) fn get_checkpoint(
},
)
.optional()
.map_err(Either::Right)?;
.map_err(Error::Query)?;
checkpoint_position
.map(|pos_opt| {
@ -561,15 +586,15 @@ pub(crate) fn get_checkpoint(
WHERE checkpoint_id = ?",
table_prefix
))
.map_err(Either::Right)?;
.map_err(Error::Query)?;
let mark_removed_rows = stmt
.query([u32::from(checkpoint_id)])
.map_err(Either::Right)?;
.map_err(Error::Query)?;
let marks_removed = mark_removed_rows
.mapped(|row| row.get::<_, u64>(0).map(Position::from))
.collect::<Result<BTreeSet<_>, _>>()
.map_err(Either::Right)?;
.map_err(Error::Query)?;
Ok(Checkpoint::from_parts(
pos_opt.map_or(TreeState::Empty, TreeState::AtPosition),
@ -609,7 +634,7 @@ pub(crate) fn get_checkpoint_at_depth(
},
)
.optional()
.map_err(Either::Right)?;
.map_err(Error::Query)?;
checkpoint_parts
.map(|(checkpoint_id, pos_opt)| {
@ -620,15 +645,15 @@ pub(crate) fn get_checkpoint_at_depth(
WHERE checkpoint_id = ?",
table_prefix
))
.map_err(Either::Right)?;
.map_err(Error::Query)?;
let mark_removed_rows = stmt
.query([u32::from(checkpoint_id)])
.map_err(Either::Right)?;
.map_err(Error::Query)?;
let marks_removed = mark_removed_rows
.mapped(|row| row.get::<_, u64>(0).map(Position::from))
.collect::<Result<BTreeSet<_>, _>>()
.map_err(Either::Right)?;
.map_err(Error::Query)?;
Ok((
checkpoint_id,
@ -658,7 +683,7 @@ where
LIMIT :limit",
table_prefix
))
.map_err(Either::Right)?;
.map_err(Error::Query)?;
let mut stmt_get_checkpoint_marks_removed = conn
.prepare_cached(&format!(
@ -667,27 +692,27 @@ where
WHERE checkpoint_id = :checkpoint_id",
table_prefix
))
.map_err(Either::Right)?;
.map_err(Error::Query)?;
let mut rows = stmt_get_checkpoints
.query(named_params![":limit": limit])
.map_err(Either::Right)?;
.map_err(Error::Query)?;
while let Some(row) = rows.next().map_err(Either::Right)? {
let checkpoint_id = row.get::<_, u32>(0).map_err(Either::Right)?;
while let Some(row) = rows.next().map_err(Error::Query)? {
let checkpoint_id = row.get::<_, u32>(0).map_err(Error::Query)?;
let tree_state = row
.get::<_, Option<u64>>(1)
.map(|opt| opt.map_or_else(|| TreeState::Empty, |p| TreeState::AtPosition(p.into())))
.map_err(Either::Right)?;
.map_err(Error::Query)?;
let mark_removed_rows = stmt_get_checkpoint_marks_removed
.query(named_params![":checkpoint_id": checkpoint_id])
.map_err(Either::Right)?;
.map_err(Error::Query)?;
let marks_removed = mark_removed_rows
.mapped(|row| row.get::<_, u64>(0).map(Position::from))
.collect::<Result<BTreeSet<_>, _>>()
.map_err(Either::Right)?;
.map_err(Error::Query)?;
callback(
&BlockHeight::from(checkpoint_id),
@ -730,11 +755,11 @@ pub(crate) fn remove_checkpoint(
WHERE checkpoint_id = :checkpoint_id",
table_prefix
))
.map_err(Either::Right)?;
.map_err(Error::Query)?;
stmt_delete_checkpoint
.execute(named_params![":checkpoint_id": u32::from(checkpoint_id),])
.map_err(Either::Right)?;
.map_err(Error::Query)?;
Ok(())
}
@ -753,7 +778,7 @@ pub(crate) fn truncate_checkpoints(
),
[u32::from(checkpoint_id)],
)
.map_err(Either::Right)?;
.map_err(Error::Query)?;
Ok(())
}
@ -844,18 +869,18 @@ pub(crate) fn put_shard_roots<
SET subtree_end_height = :subtree_end_height, root_hash = :root_hash",
table_prefix
))
.map_err(|e| ShardTreeError::Storage(Either::Right(e)))?;
.map_err(|e| ShardTreeError::Storage(Error::Query(e)))?;
// The `shard_data` value will only be used in the case that no tree already exists.
let mut shard_data: Vec<u8> = vec![];
let tree = PrunableTree::leaf((root.root_hash().clone(), RetentionFlags::EPHEMERAL));
write_shard(&mut shard_data, &tree)
.map_err(|e| ShardTreeError::Storage(Either::Left(e)))?;
.map_err(|e| ShardTreeError::Storage(Error::Serialization(e)))?;
let mut root_hash_data: Vec<u8> = vec![];
root.root_hash()
.write(&mut root_hash_data)
.map_err(|e| ShardTreeError::Storage(Either::Left(e)))?;
.map_err(|e| ShardTreeError::Storage(Error::Serialization(e)))?;
stmt.execute(named_params![
":shard_index": start_index + i,
@ -863,7 +888,7 @@ pub(crate) fn put_shard_roots<
":root_hash": root_hash_data,
":shard_data": shard_data,
])
.map_err(|e| ShardTreeError::Storage(Either::Right(e)))?;
.map_err(|e| ShardTreeError::Storage(Error::Query(e)))?;
}
Ok(())

View File

@ -1,7 +1,7 @@
//! Functions for initializing the various databases.
use either::Either;
use incrementalmerkletree::Retention;
use std::{collections::HashMap, fmt, io};
use std::{collections::HashMap, fmt};
use tracing::debug;
use rusqlite::{self, types::ToSql};
@ -24,7 +24,7 @@ use zcash_client_backend::{data_api::SAPLING_SHARD_HEIGHT, keys::UnifiedFullView
use crate::{error::SqliteClientError, wallet, WalletDb, PRUNING_DEPTH, SAPLING_TABLES_PREFIX};
use super::commitment_tree::SqliteShardStore;
use super::commitment_tree::{self, SqliteShardStore};
mod migrations;
@ -43,7 +43,7 @@ pub enum WalletMigrationError {
BalanceError(BalanceError),
/// Wrapper for commitment tree invariant violations
CommitmentTree(ShardTreeError<Either<io::Error, rusqlite::Error>>),
CommitmentTree(ShardTreeError<commitment_tree::Error>),
}
impl From<rusqlite::Error> for WalletMigrationError {
@ -58,8 +58,8 @@ impl From<BalanceError> for WalletMigrationError {
}
}
impl From<ShardTreeError<Either<io::Error, rusqlite::Error>>> for WalletMigrationError {
fn from(e: ShardTreeError<Either<io::Error, rusqlite::Error>>) -> Self {
impl From<ShardTreeError<commitment_tree::Error>> for WalletMigrationError {
fn from(e: ShardTreeError<commitment_tree::Error>) -> Self {
WalletMigrationError::CommitmentTree(e)
}
}