Move decoding errors to sqlite crate.

Also move dependency on params out of wallet read/write methods.
The result is cleaner because these parameters are only required
for backend-specific encoding and decoding operations.
This commit is contained in:
Kris Nuttycombe 2021-01-12 18:24:18 -07:00
parent 5927e32059
commit 48f226f8b5
10 changed files with 657 additions and 593 deletions

View File

@ -6,7 +6,7 @@ use std::fmt::Debug;
use zcash_primitives::{
block::BlockHash,
consensus::{self, BlockHeight},
consensus::{BlockHeight},
merkle_tree::{CommitmentTree, IncrementalWitness},
note_encryption::Memo,
primitives::{Note, Nullifier, PaymentAddress},
@ -97,23 +97,20 @@ pub trait WalletRead {
/// Returns the payment address for the specified account, if the account
/// identifier specified refers to a valid account for this wallet.
fn get_address<P: consensus::Parameters>(
fn get_address(
&self,
params: &P,
account: AccountId,
) -> Result<Option<PaymentAddress>, Self::Error>;
/// Returns all extended full viewing keys known about by this wallet
fn get_extended_full_viewing_keys<P: consensus::Parameters>(
fn get_extended_full_viewing_keys(
&self,
params: &P,
) -> Result<HashMap<AccountId, ExtendedFullViewingKey>, Self::Error>;
/// Checks whether the specified extended full viewing key is
/// associated with the account.
fn is_valid_account_extfvk<P: consensus::Parameters>(
fn is_valid_account_extfvk(
&self,
params: &P,
account: AccountId,
extfvk: &ExtendedFullViewingKey,
) -> Result<bool, Self::Error>;
@ -204,9 +201,8 @@ pub trait WalletWrite: WalletRead {
/// a chain reorg might invalidate some stored state, this method must be
/// implemented in order to allow users of this API to "reset" the data store
/// to correctly represent chainstate as of a specified block height.
fn rewind_to_height<P: consensus::Parameters>(
fn rewind_to_height(
&mut self,
parameters: &P,
block_height: BlockHeight,
) -> Result<(), Self::Error>;
@ -260,17 +256,15 @@ pub trait WalletWrite: WalletRead {
/// Add the decrypted contents of a sent note to the database if it does not exist;
/// otherwise, update the note. This is useful in the case of a wallet restore where
/// the send of the note is being discovered via trial decryption.
fn put_sent_note<P: consensus::Parameters>(
fn put_sent_note(
&mut self,
params: &P,
output: &DecryptedOutput,
tx_ref: Self::TxRef,
) -> Result<(), Self::Error>;
/// Add the decrypted contents of a sent note to the database.
fn insert_sent_note<P: consensus::Parameters>(
fn insert_sent_note(
&mut self,
params: &P,
tx_ref: Self::TxRef,
output_index: usize,
account: AccountId,
@ -417,24 +411,21 @@ pub mod testing {
Ok(None)
}
fn get_address<P: consensus::Parameters>(
fn get_address(
&self,
_params: &P,
_account: AccountId,
) -> Result<Option<PaymentAddress>, Self::Error> {
Ok(None)
}
fn get_extended_full_viewing_keys<P: consensus::Parameters>(
&self,
_params: &P,
fn get_extended_full_viewing_keys(
&self
) -> Result<HashMap<AccountId, ExtendedFullViewingKey>, Self::Error> {
Ok(HashMap::new())
}
fn is_valid_account_extfvk<P: consensus::Parameters>(
fn is_valid_account_extfvk(
&self,
_params: &P,
_account: AccountId,
_extfvk: &ExtendedFullViewingKey,
) -> Result<bool, Self::Error> {
@ -502,9 +493,8 @@ pub mod testing {
Ok(())
}
fn rewind_to_height<P: consensus::Parameters>(
fn rewind_to_height(
&mut self,
_parameters: &P,
_block_height: BlockHeight,
) -> Result<(), Self::Error> {
Ok(())
@ -556,18 +546,16 @@ pub mod testing {
Ok(())
}
fn put_sent_note<P: consensus::Parameters>(
fn put_sent_note(
&mut self,
_params: &P,
_output: &DecryptedOutput,
_tx_ref: Self::TxRef,
) -> Result<(), Self::Error> {
Ok(())
}
fn insert_sent_note<P: consensus::Parameters>(
fn insert_sent_note(
&mut self,
_params: &P,
_tx_ref: Self::TxRef,
_output_index: usize,
_account: AccountId,

View File

@ -122,8 +122,9 @@ where
/// let cache_file = NamedTempFile::new().unwrap();
/// let cache = BlockDB::for_path(cache_file).unwrap();
/// let data_file = NamedTempFile::new().unwrap();
/// let data = WalletDB::for_path(data_file).unwrap().get_update_ops().unwrap();
/// scan_cached_blocks(&Network::TestNetwork, &cache, &data, None);
/// let db_read = WalletDB::for_path(data_file, Network::TestNetwork).unwrap();
/// let mut data = db_read.get_update_ops().unwrap();
/// scan_cached_blocks(&Network::TestNetwork, &cache, &mut data, None);
/// ```
///
/// [`init_blocks_table`]: crate::init::init_blocks_table
@ -152,7 +153,7 @@ where
})?;
// Fetch the ExtendedFullViewingKeys we are tracking
let extfvks = data.get_extended_full_viewing_keys(params)?;
let extfvks = data.get_extended_full_viewing_keys()?;
let ivks: Vec<_> = extfvks.values().map(|extfvk| extfvk.fvk.vk.ivk()).collect();
// Get the most recent CommitmentTree

View File

@ -23,42 +23,25 @@ pub enum ChainInvalid {
#[derive(Debug)]
pub enum Error<DbError, NoteId> {
/// Decoding of a stored value from its serialized form has failed.
CorruptedData(String),
/// Decoding of the extended full viewing key has failed (for the specified network)
IncorrectHRPExtFVK,
/// Unable to create a new spend because the wallet balance is not sufficient.
InsufficientBalance(Amount, Amount),
/// Chain validation detected an error in the block at the specified block height.
InvalidChain(BlockHeight, ChainInvalid),
/// A provided extfvk is not associated with the specified account.
InvalidExtSK(AccountId),
/// A received memo cannot be interpreted as a UTF-8 string.
InvalidMemo(std::str::Utf8Error),
/// The root of an output's witness tree in a newly arrived transaction does not correspond to
/// root of the stored commitment tree at the recorded height.
InvalidNewWitnessAnchor(usize, TxId, BlockHeight, Node),
/// The rcm value for a note cannot be decoded to a valid JubJub point.
InvalidNote,
/// The root of an output's witness tree in a previously stored transaction does not correspond to
/// root of the current commitment tree.
InvalidWitnessAnchor(NoteId, BlockHeight),
/// The wallet must first perform a scan of the blockchain before other
/// operations can be performed.
ScanRequired,
/// Illegal attempt to reinitialize an already-initialized wallet database.
//TODO: This ought to be moved to the database backend error type.
TableNotEmpty,
/// Bech32 decoding error
Bech32(bech32::Error),
/// Base58 decoding error
Base58(bs58::decode::Error),
/// An error occurred building a new transaction.
Builder(builder::Error),
/// Wrapper for errors from the underlying data store.
Database(DbError),
/// Wrapper for errors from the IO subsystem
Io(std::io::Error),
/// An error occurred decoding a protobuf message.
Protobuf(protobuf::ProtobufError),
/// The wallet attempted a sapling-only operation at a block
@ -82,8 +65,6 @@ impl ChainInvalid {
impl<E: fmt::Display, N: fmt::Display> fmt::Display for Error<E, N> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match &self {
Error::CorruptedData(reason) => write!(f, "Data DB is corrupted: {}", reason),
Error::IncorrectHRPExtFVK => write!(f, "Incorrect HRP for extfvk"),
Error::InsufficientBalance(have, need) => write!(
f,
"Insufficient balance (have {}, need {} including fee)",
@ -95,25 +76,19 @@ impl<E: fmt::Display, N: fmt::Display> fmt::Display for Error<E, N> {
Error::InvalidExtSK(account) => {
write!(f, "Incorrect ExtendedSpendingKey for account {}", account.0)
}
Error::InvalidMemo(e) => write!(f, "{}", e),
Error::InvalidNewWitnessAnchor(output, txid, last_height, anchor) => write!(
f,
"New witness for output {} in tx {} has incorrect anchor after scanning block {}: {:?}",
output, txid, last_height, anchor,
),
Error::InvalidNote => write!(f, "Invalid note"),
Error::InvalidWitnessAnchor(id_note, last_height) => write!(
f,
"Witness for note {} has incorrect anchor after scanning block {}",
id_note, last_height
),
Error::ScanRequired => write!(f, "Must scan blocks first"),
Error::TableNotEmpty => write!(f, "Table is not empty"),
Error::Bech32(e) => write!(f, "{}", e),
Error::Base58(e) => write!(f, "{}", e),
Error::Builder(e) => write!(f, "{:?}", e),
Error::Database(e) => write!(f, "{}", e),
Error::Io(e) => write!(f, "{}", e),
Error::Protobuf(e) => write!(f, "{}", e),
Error::SaplingNotActive => write!(f, "Could not determine Sapling upgrade activation height."),
}
@ -123,40 +98,25 @@ impl<E: fmt::Display, N: fmt::Display> fmt::Display for Error<E, N> {
impl<E: error::Error + 'static, N: error::Error + 'static> error::Error for Error<E, N> {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match &self {
Error::InvalidMemo(e) => Some(e),
Error::Bech32(e) => Some(e),
Error::Builder(e) => Some(e),
Error::Database(e) => Some(e),
Error::Io(e) => Some(e),
Error::Protobuf(e) => Some(e),
_ => None,
}
}
}
impl<E, N> From<bech32::Error> for Error<E, N> {
fn from(e: bech32::Error) -> Self {
Error::Bech32(e)
}
}
impl<E, N> From<bs58::decode::Error> for Error<E, N> {
fn from(e: bs58::decode::Error) -> Self {
Error::Base58(e)
}
}
impl<E, N> From<builder::Error> for Error<E, N> {
fn from(e: builder::Error) -> Self {
Error::Builder(e)
}
}
impl<E, N> From<std::io::Error> for Error<E, N> {
fn from(e: std::io::Error) -> Self {
Error::Io(e)
}
}
//impl<E, N> From<std::io::Error> for Error<E, N> {
// fn from(e: std::io::Error) -> Self {
// Error::Io(e)
// }
//}
impl<E, N> From<protobuf::ProtobufError> for Error<E, N> {
fn from(e: protobuf::ProtobufError) -> Self {

View File

@ -35,7 +35,7 @@ where
&'db D: WalletWrite<Error = E>,
{
// Fetch the ExtendedFullViewingKeys we are tracking
let extfvks = data.get_extended_full_viewing_keys(params)?;
let extfvks = data.get_extended_full_viewing_keys()?;
// Height is block height for mined transactions, and the "mempool height" (chain height + 1)
// for mempool transactions.
@ -57,7 +57,7 @@ where
for output in outputs {
if output.outgoing {
up.put_sent_note(params, &output, tx_ref)?;
up.put_sent_note(&output, tx_ref)?;
} else {
up.put_received_note(&output, &None, tx_ref)?;
}
@ -125,9 +125,10 @@ where
/// let to = extsk.default_address().unwrap().1.into();
///
/// let data_file = NamedTempFile::new().unwrap();
/// let db = WalletDB::for_path(data_file).unwrap().get_update_ops().unwrap();
/// let db_read = WalletDB::for_path(data_file, Network::TestNetwork).unwrap();
/// let mut db = db_read.get_update_ops().unwrap();
/// match create_spend_to_address(
/// &db,
/// &mut db,
/// &Network::TestNetwork,
/// tx_prover,
/// account,
@ -162,7 +163,7 @@ where
// ExtendedFullViewingKey for the account we are spending from.
let extfvk = ExtendedFullViewingKey::from(extsk);
if !data
.is_valid_account_extfvk(params, account, &extfvk)
.is_valid_account_extfvk(account, &extfvk)
.map_err(|e| e.into())?
{
return Err(Error::InvalidExtSK(account));
@ -249,7 +250,6 @@ where
}
up.insert_sent_note(
params,
tx_ref,
output_index as usize,
account,

View File

@ -16,7 +16,6 @@
//! scan_cached_blocks,
//! },
//! error::Error,
//! testing::{MockBlockSource, MockWalletDB}
//! },
//! };
//!
@ -28,8 +27,13 @@
//! };
//!
//! let network = Network::TestNetwork;
//! let db_cache = MockBlockSource { };
//! let mut db_data = MockWalletDB { };
//! let cache_file = NamedTempFile::new().unwrap();
//! let db_cache = BlockDB::for_path(cache_file).unwrap();
//! let db_file = NamedTempFile::new().unwrap();
//! let db_read = WalletDB::for_path(db_file, network).unwrap();
//! init_data_database(&db_read).unwrap();
//!
//! let mut db_data = db_read.get_update_ops().unwrap();
//!
//! // 1) Download new CompactBlocks into db_cache.
//!
@ -48,7 +52,7 @@
//! let rewind_height = upper_bound - 10;
//!
//! // b) Rewind scanned block information.
//! db_data.rewind_to_height(&network, rewind_height);
//! db_data.rewind_to_height(rewind_height);
//!
//! // c) Delete cached blocks from rewind_height onwards.
//! //
@ -75,13 +79,13 @@
//! ```
use protobuf::parse_from_bytes;
use rusqlite::types::ToSql;
use rusqlite::{params};
use zcash_primitives::consensus::BlockHeight;
use zcash_client_backend::{data_api::error::Error, proto::compact_formats::CompactBlock};
use crate::{error::SqliteClientError, BlockDB};
use crate::{error::{SqliteClientError, db_error}, BlockDB, NoteId};
pub mod init;
@ -95,18 +99,19 @@ pub fn with_blocks<F>(
from_height: BlockHeight,
limit: Option<u32>,
mut with_row: F,
) -> Result<(), SqliteClientError>
) -> Result<(), Error<SqliteClientError, NoteId>>
where
F: FnMut(CompactBlock) -> Result<(), SqliteClientError>,
F: FnMut(CompactBlock) -> Result<(), Error<SqliteClientError, NoteId>>,
{
// Fetch the CompactBlocks we need to scan
let mut stmt_blocks = cache.0.prepare(
"SELECT height, data FROM compactblocks WHERE height > ? ORDER BY height ASC LIMIT ?",
)?;
).map_err(db_error)?;
let rows = stmt_blocks.query_map(
&[
u32::from(from_height).to_sql()?,
limit.unwrap_or(u32::max_value()).to_sql()?,
params![
u32::from(from_height),
limit.unwrap_or(u32::max_value()),
],
|row| {
Ok(CompactBlockRow {
@ -114,18 +119,19 @@ where
data: row.get(1)?,
})
},
)?;
).map_err(db_error)?;
for row_result in rows {
let cbr = row_result?;
let block: CompactBlock = parse_from_bytes(&cbr.data)?;
let cbr = row_result.map_err(db_error)?;
let block: CompactBlock = parse_from_bytes(&cbr.data).map_err(Error::from)?;
if block.height() != cbr.height {
return Err(Error::CorruptedData(format!(
return Err(
Error::Database(SqliteClientError::CorruptedData(format!(
"Block height {} did not match row's height field value {}",
block.height(),
cbr.height
))
)))
.into());
}
@ -162,7 +168,7 @@ mod tests {
init::{init_accounts_table, init_data_database},
rewind_to_height,
},
AccountId, BlockDB, NoteId, WalletDB
AccountId, BlockDB, NoteId, WalletDB,
};
#[test]
@ -172,13 +178,13 @@ mod tests {
init_cache_database(&db_cache).unwrap();
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB::for_path(data_file.path()).unwrap();
let db_data = WalletDB::for_path(data_file.path(), tests::network()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
let extsk = ExtendedSpendingKey::master(&[]);
let extfvk = ExtendedFullViewingKey::from(&extsk);
init_accounts_table(&db_data, &tests::network(), &[extfvk.clone()]).unwrap();
init_accounts_table(&db_data, &[extfvk.clone()]).unwrap();
// Empty chain should be valid
validate_chain(
@ -207,12 +213,7 @@ mod tests {
// Scan the cache
let mut db_write = db_data.get_update_ops().unwrap();
scan_cached_blocks(
&tests::network(),
&db_cache,
&mut db_write,
None
).unwrap();
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Data-only chain should be valid
validate_chain(
@ -258,13 +259,13 @@ mod tests {
init_cache_database(&db_cache).unwrap();
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB::for_path(data_file.path()).unwrap();
let db_data = WalletDB::for_path(data_file.path(), tests::network()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
let extsk = ExtendedSpendingKey::master(&[]);
let extfvk = ExtendedFullViewingKey::from(&extsk);
init_accounts_table(&db_data, &tests::network(), &[extfvk.clone()]).unwrap();
init_accounts_table(&db_data, &[extfvk.clone()]).unwrap();
// Create some fake CompactBlocks
let (cb, _) = fake_compact_block(
@ -315,9 +316,7 @@ mod tests {
&tests::network(),
&db_cache,
(&db_data).get_max_height_hash().unwrap(),
)
.map_err(|e| e.0)
{
) {
Err(Error::InvalidChain(upper_bound, _)) => {
assert_eq!(upper_bound, sapling_activation_height() + 2)
}
@ -332,13 +331,13 @@ mod tests {
init_cache_database(&db_cache).unwrap();
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB::for_path(data_file.path()).unwrap();
let db_data = WalletDB::for_path(data_file.path(), tests::network()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
let extsk = ExtendedSpendingKey::master(&[]);
let extfvk = ExtendedFullViewingKey::from(&extsk);
init_accounts_table(&db_data, &tests::network(), &[extfvk.clone()]).unwrap();
init_accounts_table(&db_data, &[extfvk.clone()]).unwrap();
// Create some fake CompactBlocks
let (cb, _) = fake_compact_block(
@ -390,7 +389,6 @@ mod tests {
&db_cache,
(&db_data).get_max_height_hash().unwrap(),
)
.map_err(|e| e.0)
{
Err(Error::InvalidChain(upper_bound, _)) => {
assert_eq!(upper_bound, sapling_activation_height() + 3)
@ -406,13 +404,13 @@ mod tests {
init_cache_database(&db_cache).unwrap();
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB::for_path(data_file.path()).unwrap();
let db_data = WalletDB::for_path(data_file.path(), tests::network()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
let extsk = ExtendedSpendingKey::master(&[]);
let extfvk = ExtendedFullViewingKey::from(&extsk);
init_accounts_table(&db_data, &tests::network(), &[extfvk.clone()]).unwrap();
init_accounts_table(&db_data, &[extfvk.clone()]).unwrap();
// Account balance should be zero
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), Amount::zero());
@ -440,13 +438,13 @@ mod tests {
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value + value2);
// "Rewind" to height of last scanned block
rewind_to_height(&db_data, &tests::network(), sapling_activation_height() + 1).unwrap();
rewind_to_height(&db_data, sapling_activation_height() + 1).unwrap();
// Account balance should be unaltered
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value + value2);
// Rewind so that one block is dropped
rewind_to_height(&db_data, &tests::network(), sapling_activation_height()).unwrap();
rewind_to_height(&db_data, sapling_activation_height()).unwrap();
// Account balance should only contain the first received note
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value);
@ -465,13 +463,13 @@ mod tests {
init_cache_database(&db_cache).unwrap();
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB::for_path(data_file.path()).unwrap();
let db_data = WalletDB::for_path(data_file.path(), tests::network()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
let extsk = ExtendedSpendingKey::master(&[]);
let extfvk = ExtendedFullViewingKey::from(&extsk);
init_accounts_table(&db_data, &tests::network(), &[extfvk.clone()]).unwrap();
init_accounts_table(&db_data, &[extfvk.clone()]).unwrap();
// Create a block with height SAPLING_ACTIVATION_HEIGHT
let value = Amount::from_u64(50000).unwrap();
@ -530,13 +528,13 @@ mod tests {
init_cache_database(&db_cache).unwrap();
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB::for_path(data_file.path()).unwrap();
let db_data = WalletDB::for_path(data_file.path(), tests::network()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
let extsk = ExtendedSpendingKey::master(&[]);
let extfvk = ExtendedFullViewingKey::from(&extsk);
init_accounts_table(&db_data, &tests::network(), &[extfvk.clone()]).unwrap();
init_accounts_table(&db_data, &[extfvk.clone()]).unwrap();
// Account balance should be zero
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), Amount::zero());
@ -578,13 +576,13 @@ mod tests {
init_cache_database(&db_cache).unwrap();
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB::for_path(data_file.path()).unwrap();
let db_data = WalletDB::for_path(data_file.path(), tests::network()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
let extsk = ExtendedSpendingKey::master(&[]);
let extfvk = ExtendedFullViewingKey::from(&extsk);
init_accounts_table(&db_data, &tests::network(), &[extfvk.clone()]).unwrap();
init_accounts_table(&db_data, &[extfvk.clone()]).unwrap();
// Account balance should be zero
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), Amount::zero());

View File

@ -1,64 +1,86 @@
use std::error;
use std::fmt;
use zcash_primitives::transaction::builder;
use zcash_client_backend::data_api::error::Error;
use crate::NoteId;
#[derive(Debug)]
pub struct SqliteClientError(pub Error<rusqlite::Error, NoteId>);
pub enum SqliteClientError {
/// Decoding of a stored value from its serialized form has failed.
CorruptedData(String),
/// Decoding of the extended full viewing key has failed (for the specified network)
IncorrectHRPExtFVK,
/// The rcm value for a note cannot be decoded to a valid JubJub point.
InvalidNote,
/// Bech32 decoding error
Bech32(bech32::Error),
/// Base58 decoding error
Base58(bs58::decode::Error),
/// Illegal attempt to reinitialize an already-initialized wallet database.
TableNotEmpty,
/// Wrapper for rusqlite errors.
DbError(rusqlite::Error),
/// Wrapper for errors from the IO subsystem
Io(std::io::Error),
/// A received memo cannot be interpreted as a UTF-8 string.
InvalidMemo(std::str::Utf8Error),
}
impl error::Error for SqliteClientError {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match &self {
SqliteClientError::InvalidMemo(e) => Some(e),
SqliteClientError::Bech32(e) => Some(e),
SqliteClientError::DbError(e) => Some(e),
SqliteClientError::Io(e) => Some(e),
_ => None,
}
}
}
impl fmt::Display for SqliteClientError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(f)
}
}
impl From<Error<rusqlite::Error, NoteId>> for SqliteClientError {
fn from(e: Error<rusqlite::Error, NoteId>) -> Self {
SqliteClientError(e)
}
}
impl From<bech32::Error> for SqliteClientError {
fn from(e: bech32::Error) -> Self {
SqliteClientError(Error::Bech32(e))
match &self {
SqliteClientError::CorruptedData(reason) => {
write!(f, "Data DB is corrupted: {}", reason)
}
SqliteClientError::IncorrectHRPExtFVK => write!(f, "Incorrect HRP for extfvk"),
SqliteClientError::InvalidNote => write!(f, "Invalid note"),
SqliteClientError::Bech32(e) => write!(f, "{}", e),
SqliteClientError::Base58(e) => write!(f, "{}", e),
SqliteClientError::TableNotEmpty => write!(f, "Table is not empty"),
SqliteClientError::DbError(e) => write!(f, "{}", e),
SqliteClientError::Io(e) => write!(f, "{}", e),
SqliteClientError::InvalidMemo(e) => write!(f, "{}", e),
}
}
}
impl From<rusqlite::Error> for SqliteClientError {
fn from(e: rusqlite::Error) -> Self {
SqliteClientError(Error::Database(e))
}
}
impl From<bs58::decode::Error> for SqliteClientError {
fn from(e: bs58::decode::Error) -> Self {
SqliteClientError(Error::Base58(e))
}
}
impl From<builder::Error> for SqliteClientError {
fn from(e: builder::Error) -> Self {
SqliteClientError(Error::Builder(e))
SqliteClientError::DbError(e)
}
}
impl From<std::io::Error> for SqliteClientError {
fn from(e: std::io::Error) -> Self {
SqliteClientError(Error::Io(e))
SqliteClientError::Io(e)
}
}
impl From<protobuf::ProtobufError> for SqliteClientError {
fn from(e: protobuf::ProtobufError) -> Self {
SqliteClientError(Error::Protobuf(e))
impl From<bech32::Error> for SqliteClientError {
fn from(e: bech32::Error) -> Self {
SqliteClientError::Bech32(e)
}
}
impl From<SqliteClientError> for Error<rusqlite::Error, NoteId> {
fn from(e: SqliteClientError) -> Self {
e.0
impl From<bs58::decode::Error> for SqliteClientError {
fn from(e: bs58::decode::Error) -> Self {
SqliteClientError::Base58(e)
}
}
pub fn db_error(r: rusqlite::Error) -> Error<SqliteClientError, NoteId> {
Error::Database(SqliteClientError::DbError(r))
}

View File

@ -24,13 +24,11 @@
//! [`CompactBlock`]: zcash_client_backend::proto::compact_formats::CompactBlock
//! [`init_cache_database`]: crate::init::init_cache_database
use std::fmt;
use std::collections::HashMap;
use std::fmt;
use std::path::Path;
use rusqlite::{types::ToSql, Connection, Statement, NO_PARAMS};
use ff::PrimeField;
use rusqlite::{Connection, Statement, NO_PARAMS};
use zcash_primitives::{
block::BlockHash,
@ -52,7 +50,7 @@ use zcash_client_backend::{
DecryptedOutput,
};
use crate::error::SqliteClientError;
use crate::error::{db_error, SqliteClientError};
pub mod chain;
pub mod error;
@ -69,53 +67,56 @@ impl fmt::Display for NoteId {
}
}
/// A newtype wrapper for the sqlite connection to the wallet database.
pub struct WalletDB(Connection);
/// A wrapper for the sqlite connection to the wallet database.
pub struct WalletDB<P> {
conn: Connection,
params: P,
}
impl WalletDB {
impl<P: consensus::Parameters> WalletDB<P> {
/// Construct a connection to the wallet database stored at the specified path.
pub fn for_path<P: AsRef<Path>>(path: P) -> Result<Self, rusqlite::Error> {
Connection::open(path).map(WalletDB)
pub fn for_path<F: AsRef<Path>>(path: F, params: P) -> Result<Self, rusqlite::Error> {
Connection::open(path).map(move |conn| WalletDB { conn, params })
}
/// Given a wallet database connection, obtain a handle for the write operations
/// for that database. This operation may eagerly initialize and cache sqlite
/// prepared statements that are used in write operations.
pub fn get_update_ops<'a>(&'a self) -> Result<DataConnStmtCache<'a>, SqliteClientError> {
pub fn get_update_ops<'a>(&'a self) -> Result<DataConnStmtCache<'a, P>, SqliteClientError> {
Ok(
DataConnStmtCache {
conn: self,
stmt_insert_block: self.0.prepare(
wallet_db: self,
stmt_insert_block: self.conn.prepare(
"INSERT INTO blocks (height, hash, time, sapling_tree)
VALUES (?, ?, ?, ?)",
)?,
stmt_insert_tx_meta: self.0.prepare(
stmt_insert_tx_meta: self.conn.prepare(
"INSERT INTO transactions (txid, block, tx_index)
VALUES (?, ?, ?)",
)?,
stmt_update_tx_meta: self.0.prepare(
stmt_update_tx_meta: self.conn.prepare(
"UPDATE transactions
SET block = ?, tx_index = ? WHERE txid = ?",
)?,
stmt_insert_tx_data: self.0.prepare(
stmt_insert_tx_data: self.conn.prepare(
"INSERT INTO transactions (txid, created, expiry_height, raw)
VALUES (?, ?, ?, ?)",
)?,
stmt_update_tx_data: self.0.prepare(
stmt_update_tx_data: self.conn.prepare(
"UPDATE transactions
SET expiry_height = ?, raw = ? WHERE txid = ?",
)?,
stmt_select_tx_ref: self.0.prepare(
stmt_select_tx_ref: self.conn.prepare(
"SELECT id_tx FROM transactions WHERE txid = ?",
)?,
stmt_mark_recived_note_spent: self.0.prepare(
stmt_mark_recived_note_spent: self.conn.prepare(
"UPDATE received_notes SET spent = ? WHERE nf = ?"
)?,
stmt_insert_received_note: self.0.prepare(
stmt_insert_received_note: self.conn.prepare(
"INSERT INTO received_notes (tx, output_index, account, diversifier, value, rcm, memo, nf, is_change)
VALUES (:tx, :output_index, :account, :diversifier, :value, :rcm, :memo, :nf, :is_change)",
)?,
stmt_update_received_note: self.0.prepare(
stmt_update_received_note: self.conn.prepare(
"UPDATE received_notes
SET account = :account,
diversifier = :diversifier,
@ -126,26 +127,26 @@ impl WalletDB {
is_change = IFNULL(:is_change, is_change)
WHERE tx = :tx AND output_index = :output_index",
)?,
stmt_select_received_note: self.0.prepare(
stmt_select_received_note: self.conn.prepare(
"SELECT id_note FROM received_notes WHERE tx = ? AND output_index = ?"
)?,
stmt_update_sent_note: self.0.prepare(
stmt_update_sent_note: self.conn.prepare(
"UPDATE sent_notes
SET from_account = ?, address = ?, value = ?, memo = ?
WHERE tx = ? AND output_index = ?",
)?,
stmt_insert_sent_note: self.0.prepare(
stmt_insert_sent_note: self.conn.prepare(
"INSERT INTO sent_notes (tx, output_index, from_account, address, value, memo)
VALUES (?, ?, ?, ?, ?, ?)",
)?,
stmt_insert_witness: self.0.prepare(
stmt_insert_witness: self.conn.prepare(
"INSERT INTO sapling_witnesses (note, block, witness)
VALUES (?, ?, ?)",
)?,
stmt_prune_witnesses: self.0.prepare(
stmt_prune_witnesses: self.conn.prepare(
"DELETE FROM sapling_witnesses WHERE block < ?"
)?,
stmt_update_expired: self.0.prepare(
stmt_update_expired: self.conn.prepare(
"UPDATE received_notes SET spent = NULL WHERE EXISTS (
SELECT id_tx FROM transactions
WHERE id_tx = received_notes.spent AND block IS NULL AND expiry_height < ?
@ -156,49 +157,43 @@ impl WalletDB {
}
}
impl WalletRead for WalletDB {
type Error = SqliteClientError;
impl<P: consensus::Parameters> WalletRead for WalletDB<P> {
type Error = Error<SqliteClientError, NoteId>;
type NoteRef = NoteId;
type TxRef = i64;
fn block_height_extrema(&self) -> Result<Option<(BlockHeight, BlockHeight)>, Self::Error> {
wallet::block_height_extrema(self).map_err(SqliteClientError::from)
wallet::block_height_extrema(self).map_err(db_error)
}
fn get_block_hash(&self, block_height: BlockHeight) -> Result<Option<BlockHash>, Self::Error> {
wallet::get_block_hash(self, block_height).map_err(SqliteClientError::from)
wallet::get_block_hash(self, block_height).map_err(db_error)
}
fn get_tx_height(&self, txid: TxId) -> Result<Option<BlockHeight>, Self::Error> {
wallet::get_tx_height(self, txid).map_err(SqliteClientError::from)
wallet::get_tx_height(self, txid).map_err(db_error)
}
fn get_extended_full_viewing_keys<P: consensus::Parameters>(
fn get_extended_full_viewing_keys(
&self,
params: &P,
) -> Result<HashMap<AccountId, ExtendedFullViewingKey>, Self::Error> {
wallet::get_extended_full_viewing_keys(self, params)
wallet::get_extended_full_viewing_keys(self).map_err(Error::Database)
}
fn get_address<P: consensus::Parameters>(
&self,
params: &P,
account: AccountId,
) -> Result<Option<PaymentAddress>, Self::Error> {
wallet::get_address(self, params, account)
fn get_address(&self, account: AccountId) -> Result<Option<PaymentAddress>, Self::Error> {
wallet::get_address(self, account).map_err(Error::Database)
}
fn is_valid_account_extfvk<P: consensus::Parameters>(
fn is_valid_account_extfvk(
&self,
params: &P,
account: AccountId,
extfvk: &ExtendedFullViewingKey,
) -> Result<bool, Self::Error> {
wallet::is_valid_account_extfvk(self, params, account, extfvk)
wallet::is_valid_account_extfvk(self, account, extfvk).map_err(Error::Database)
}
fn get_balance(&self, account: AccountId) -> Result<Amount, Self::Error> {
wallet::get_balance(self, account)
wallet::get_balance(self, account).map_err(Error::Database)
}
fn get_verified_balance(
@ -206,36 +201,36 @@ impl WalletRead for WalletDB {
account: AccountId,
anchor_height: BlockHeight,
) -> Result<Amount, Self::Error> {
wallet::get_verified_balance(self, account, anchor_height)
wallet::get_verified_balance(self, account, anchor_height).map_err(Error::Database)
}
fn get_received_memo_as_utf8(
&self,
id_note: Self::NoteRef,
) -> Result<Option<String>, Self::Error> {
wallet::get_received_memo_as_utf8(self, id_note)
wallet::get_received_memo_as_utf8(self, id_note).map_err(Error::Database)
}
fn get_sent_memo_as_utf8(&self, id_note: Self::NoteRef) -> Result<Option<String>, Self::Error> {
wallet::get_sent_memo_as_utf8(self, id_note)
wallet::get_sent_memo_as_utf8(self, id_note).map_err(Error::Database)
}
fn get_commitment_tree(
&self,
block_height: BlockHeight,
) -> Result<Option<CommitmentTree<Node>>, Self::Error> {
wallet::get_commitment_tree(self, block_height)
wallet::get_commitment_tree(self, block_height).map_err(Error::Database)
}
fn get_witnesses(
&self,
block_height: BlockHeight,
) -> Result<Vec<(Self::NoteRef, IncrementalWitness<Node>)>, Self::Error> {
wallet::get_witnesses(self, block_height)
wallet::get_witnesses(self, block_height).map_err(Error::Database)
}
fn get_nullifiers(&self) -> Result<Vec<(Nullifier, AccountId)>, Self::Error> {
wallet::get_nullifiers(self)
wallet::get_nullifiers(self).map_err(Error::Database)
}
fn select_spendable_notes(
@ -245,11 +240,12 @@ impl WalletRead for WalletDB {
anchor_height: BlockHeight,
) -> Result<Vec<SpendableNote>, Self::Error> {
wallet::transact::select_spendable_notes(self, account, target_value, anchor_height)
.map_err(Error::Database)
}
}
pub struct DataConnStmtCache<'a> {
conn: &'a WalletDB,
pub struct DataConnStmtCache<'a, P> {
wallet_db: &'a WalletDB<P>,
stmt_insert_block: Statement<'a>,
stmt_insert_tx_meta: Statement<'a>,
@ -273,49 +269,43 @@ pub struct DataConnStmtCache<'a> {
stmt_update_expired: Statement<'a>,
}
impl<'a> WalletRead for DataConnStmtCache<'a> {
type Error = SqliteClientError;
impl<'a, P: consensus::Parameters> WalletRead for DataConnStmtCache<'a, P> {
type Error = Error<SqliteClientError, NoteId>;
type NoteRef = NoteId;
type TxRef = i64;
fn block_height_extrema(&self) -> Result<Option<(BlockHeight, BlockHeight)>, Self::Error> {
self.conn.block_height_extrema()
self.wallet_db.block_height_extrema()
}
fn get_block_hash(&self, block_height: BlockHeight) -> Result<Option<BlockHash>, Self::Error> {
self.conn.get_block_hash(block_height)
self.wallet_db.get_block_hash(block_height)
}
fn get_tx_height(&self, txid: TxId) -> Result<Option<BlockHeight>, Self::Error> {
self.conn.get_tx_height(txid)
self.wallet_db.get_tx_height(txid)
}
fn get_extended_full_viewing_keys<P: consensus::Parameters>(
fn get_extended_full_viewing_keys(
&self,
params: &P,
) -> Result<HashMap<AccountId, ExtendedFullViewingKey>, Self::Error> {
self.conn.get_extended_full_viewing_keys(params)
self.wallet_db.get_extended_full_viewing_keys()
}
fn get_address<P: consensus::Parameters>(
&self,
params: &P,
account: AccountId,
) -> Result<Option<PaymentAddress>, Self::Error> {
self.conn.get_address(params, account)
fn get_address(&self, account: AccountId) -> Result<Option<PaymentAddress>, Self::Error> {
self.wallet_db.get_address(account)
}
fn is_valid_account_extfvk<P: consensus::Parameters>(
fn is_valid_account_extfvk(
&self,
params: &P,
account: AccountId,
extfvk: &ExtendedFullViewingKey,
) -> Result<bool, Self::Error> {
self.conn.is_valid_account_extfvk(params, account, extfvk)
self.wallet_db.is_valid_account_extfvk(account, extfvk)
}
fn get_balance(&self, account: AccountId) -> Result<Amount, Self::Error> {
self.conn.get_balance(account)
self.wallet_db.get_balance(account)
}
fn get_verified_balance(
@ -323,36 +313,36 @@ impl<'a> WalletRead for DataConnStmtCache<'a> {
account: AccountId,
anchor_height: BlockHeight,
) -> Result<Amount, Self::Error> {
self.conn.get_verified_balance(account, anchor_height)
self.wallet_db.get_verified_balance(account, anchor_height)
}
fn get_received_memo_as_utf8(
&self,
id_note: Self::NoteRef,
) -> Result<Option<String>, Self::Error> {
self.conn.get_received_memo_as_utf8(id_note)
self.wallet_db.get_received_memo_as_utf8(id_note)
}
fn get_sent_memo_as_utf8(&self, id_note: Self::NoteRef) -> Result<Option<String>, Self::Error> {
self.conn.get_sent_memo_as_utf8(id_note)
self.wallet_db.get_sent_memo_as_utf8(id_note)
}
fn get_commitment_tree(
&self,
block_height: BlockHeight,
) -> Result<Option<CommitmentTree<Node>>, Self::Error> {
self.conn.get_commitment_tree(block_height)
self.wallet_db.get_commitment_tree(block_height)
}
fn get_witnesses(
&self,
block_height: BlockHeight,
) -> Result<Vec<(Self::NoteRef, IncrementalWitness<Node>)>, Self::Error> {
self.conn.get_witnesses(block_height)
self.wallet_db.get_witnesses(block_height)
}
fn get_nullifiers(&self) -> Result<Vec<(Nullifier, AccountId)>, Self::Error> {
self.conn.get_nullifiers()
self.wallet_db.get_nullifiers()
}
fn select_spendable_notes(
@ -361,23 +351,30 @@ impl<'a> WalletRead for DataConnStmtCache<'a> {
target_value: Amount,
anchor_height: BlockHeight,
) -> Result<Vec<SpendableNote>, Self::Error> {
self.conn.select_spendable_notes(account, target_value, anchor_height)
self.wallet_db
.select_spendable_notes(account, target_value, anchor_height)
}
}
impl<'a> WalletWrite for DataConnStmtCache<'a> {
impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> {
fn transactionally<F, A>(&mut self, f: F) -> Result<A, Self::Error>
where
F: FnOnce(&mut Self) -> Result<A, Self::Error>,
{
self.conn.0.execute("BEGIN IMMEDIATE", NO_PARAMS)?;
self.wallet_db
.conn
.execute("BEGIN IMMEDIATE", NO_PARAMS)
.map_err(db_error)?;
match f(self) {
Ok(result) => {
self.conn.0.execute("COMMIT", NO_PARAMS)?;
self.wallet_db
.conn
.execute("COMMIT", NO_PARAMS)
.map_err(db_error)?;
Ok(result)
}
Err(error) => {
match self.conn.0.execute("ROLLBACK", NO_PARAMS) {
match self.wallet_db.conn.execute("ROLLBACK", NO_PARAMS) {
Ok(_) => Err(error),
Err(e) =>
// REVIEW: If rollback fails, what do we want to do? I think that
@ -386,7 +383,7 @@ impl<'a> WalletWrite for DataConnStmtCache<'a> {
panic!(
"Rollback failed with error {} while attempting to recover from error {}; database is likely corrupt.",
e,
error.0
error
)
}
}
@ -400,28 +397,12 @@ impl<'a> WalletWrite for DataConnStmtCache<'a> {
block_time: u32,
commitment_tree: &CommitmentTree<Node>,
) -> Result<(), Self::Error> {
let mut encoded_tree = Vec::new();
commitment_tree
.write(&mut encoded_tree)
.expect("Should be able to write to a Vec");
self.stmt_insert_block.execute(&[
u32::from(block_height).to_sql()?,
block_hash.0.to_sql()?,
block_time.to_sql()?,
encoded_tree.to_sql()?,
])?;
Ok(())
wallet::insert_block(self, block_height, block_hash, block_time, commitment_tree)
.map_err(Error::Database)
}
fn rewind_to_height<P: consensus::Parameters>(
&mut self,
parameters: &P,
block_height: BlockHeight,
) -> Result<(), Self::Error> {
wallet::rewind_to_height(self.conn, parameters, block_height)
fn rewind_to_height(&mut self, block_height: BlockHeight) -> Result<(), Self::Error> {
wallet::rewind_to_height(self.wallet_db, block_height)
}
fn put_tx_meta(
@ -429,27 +410,7 @@ impl<'a> WalletWrite for DataConnStmtCache<'a> {
tx: &WalletTx,
height: BlockHeight,
) -> Result<Self::TxRef, Self::Error> {
let txid = tx.txid.0.to_vec();
if self.stmt_update_tx_meta.execute(&[
u32::from(height).to_sql()?,
(tx.index as i64).to_sql()?,
txid.to_sql()?,
])? == 0
{
// It isn't there, so insert our transaction into the database.
self.stmt_insert_tx_meta.execute(&[
txid.to_sql()?,
u32::from(height).to_sql()?,
(tx.index as i64).to_sql()?,
])?;
Ok(self.conn.0.last_insert_rowid())
} else {
// It was there, so grab its row number.
self.stmt_select_tx_ref
.query_row(&[txid], |row| row.get(0))
.map_err(SqliteClientError::from)
}
wallet::put_tx_meta(self, tx, height).map_err(Error::Database)
}
fn put_tx_data(
@ -457,38 +418,11 @@ impl<'a> WalletWrite for DataConnStmtCache<'a> {
tx: &Transaction,
created_at: Option<time::OffsetDateTime>,
) -> Result<Self::TxRef, Self::Error> {
let txid = tx.txid().0.to_vec();
let mut raw_tx = vec![];
tx.write(&mut raw_tx)?;
if self.stmt_update_tx_data.execute(&[
u32::from(tx.expiry_height).to_sql()?,
raw_tx.to_sql()?,
txid.to_sql()?,
])? == 0
{
// It isn't there, so insert our transaction into the database.
self.stmt_insert_tx_data.execute(&[
txid.to_sql()?,
created_at.to_sql()?,
u32::from(tx.expiry_height).to_sql()?,
raw_tx.to_sql()?,
])?;
Ok(self.conn.0.last_insert_rowid())
} else {
// It was there, so grab its row number.
self.stmt_select_tx_ref
.query_row(&[txid], |row| row.get(0))
.map_err(SqliteClientError::from)
}
wallet::put_tx_data(self, tx, created_at).map_err(Error::Database)
}
fn mark_spent(&mut self, tx_ref: Self::TxRef, nf: &Nullifier) -> Result<(), Self::Error> {
self.stmt_mark_recived_note_spent
.execute(&[tx_ref.to_sql()?, nf.0.to_sql()?])?;
Ok(())
wallet::mark_spent(self, tx_ref, nf).map_err(Error::Database)
}
// Assumptions:
@ -500,44 +434,7 @@ impl<'a> WalletWrite for DataConnStmtCache<'a> {
nf_opt: &Option<Nullifier>,
tx_ref: Self::TxRef,
) -> Result<Self::NoteRef, Self::Error> {
let rcm = output.note().rcm().to_repr();
let account = output.account().0 as i64;
let diversifier = output.to().diversifier().0.to_vec();
let value = output.note().value as i64;
let rcm = rcm.as_ref();
let memo = output.memo().map(|m| m.as_bytes());
let is_change = output.is_change();
let tx = tx_ref;
let output_index = output.index() as i64;
let nf_bytes = nf_opt.map(|nf| nf.0.to_vec());
let sql_args: Vec<(&str, &dyn ToSql)> = vec![
(&":account", &account),
(&":diversifier", &diversifier),
(&":value", &value),
(&":rcm", &rcm),
(&":nf", &nf_bytes),
(&":memo", &memo),
(&":is_change", &is_change),
(&":tx", &tx),
(&":output_index", &output_index),
];
// First try updating an existing received note into the database.
if self.stmt_update_received_note.execute_named(&sql_args)? == 0 {
// It isn't there, so insert our note into the database.
self.stmt_insert_received_note.execute_named(&sql_args)?;
Ok(NoteId(self.conn.0.last_insert_rowid()))
} else {
// It was there, so grab its row number.
self.stmt_select_received_note
.query_row(
&[tx_ref.to_sql()?, (output.index() as i64).to_sql()?],
|row| row.get(0).map(NoteId),
)
.map_err(SqliteClientError::from)
}
wallet::put_received_note(self, output, nf_opt, tx_ref).map_err(Error::Database)
}
fn insert_witness(
@ -546,70 +443,27 @@ impl<'a> WalletWrite for DataConnStmtCache<'a> {
witness: &IncrementalWitness<Node>,
height: BlockHeight,
) -> Result<(), Self::Error> {
let mut encoded = Vec::new();
witness
.write(&mut encoded)
.expect("Should be able to write to a Vec");
self.stmt_insert_witness.execute(&[
note_id.0.to_sql()?,
u32::from(height).to_sql()?,
encoded.to_sql()?,
])?;
Ok(())
wallet::insert_witness(self, note_id, witness, height).map_err(Error::Database)
}
fn prune_witnesses(&mut self, below_height: BlockHeight) -> Result<(), Self::Error> {
self.stmt_prune_witnesses
.execute(&[u32::from(below_height)])?;
Ok(())
wallet::prune_witnesses(self, below_height).map_err(Error::Database)
}
fn update_expired_notes(&mut self, height: BlockHeight) -> Result<(), Self::Error> {
self.stmt_update_expired.execute(&[u32::from(height)])?;
Ok(())
wallet::update_expired_notes(self, height).map_err(Error::Database)
}
fn put_sent_note<P: consensus::Parameters>(
fn put_sent_note(
&mut self,
params: &P,
output: &DecryptedOutput,
tx_ref: Self::TxRef,
) -> Result<(), Self::Error> {
let output_index = output.index as i64;
let account = output.account.0 as i64;
let value = output.note.value as i64;
let to_str = encode_payment_address(params.hrp_sapling_payment_address(), &output.to);
// Try updating an existing sent note.
if self.stmt_update_sent_note.execute(&[
account.to_sql()?,
to_str.to_sql()?,
value.to_sql()?,
output.memo.as_bytes().to_sql()?,
tx_ref.to_sql()?,
output_index.to_sql()?,
])? == 0
{
// It isn't there, so insert.
self.insert_sent_note(
params,
tx_ref,
output.index,
output.account,
&RecipientAddress::Shielded(output.to.clone()),
Amount::from_u64(output.note.value)
.map_err(|_| Error::CorruptedData("Note value invalid.".to_string()))?,
Some(output.memo.clone()),
)?
}
Ok(())
wallet::put_sent_note(self, output, tx_ref).map_err(Error::Database)
}
fn insert_sent_note<P: consensus::Parameters>(
fn insert_sent_note(
&mut self,
params: &P,
tx_ref: Self::TxRef,
output_index: usize,
account: AccountId,
@ -617,18 +471,8 @@ impl<'a> WalletWrite for DataConnStmtCache<'a> {
value: Amount,
memo: Option<Memo>,
) -> Result<(), Self::Error> {
let to_str = to.encode(params);
let ivalue: i64 = value.into();
self.stmt_insert_sent_note.execute(&[
tx_ref.to_sql()?,
(output_index as i64).to_sql()?,
account.0.to_sql()?,
to_str.to_sql()?,
ivalue.to_sql()?,
memo.map(|m| m.as_bytes().to_vec()).to_sql()?,
])?;
Ok(())
wallet::insert_sent_note(self, tx_ref, output_index, account, to, value, memo)
.map_err(Error::Database)
}
}
@ -641,7 +485,7 @@ impl BlockDB {
}
impl BlockSource for BlockDB {
type Error = SqliteClientError;
type Error = Error<SqliteClientError, NoteId>;
fn with_blocks<F>(
&self,
@ -670,7 +514,7 @@ mod tests {
use group::GroupEncoding;
use protobuf::Message;
use rand_core::{OsRng, RngCore};
use rusqlite::types::ToSql;
use rusqlite::{params};
use zcash_client_backend::proto::compact_formats::{
CompactBlock, CompactOutput, CompactSpend, CompactTx,
@ -853,10 +697,7 @@ mod tests {
.0
.prepare("INSERT INTO compactblocks (height, data) VALUES (?, ?)")
.unwrap()
.execute(&[
u32::from(cb.height()).to_sql().unwrap(),
cb_bytes.to_sql().unwrap(),
])
.execute(params![u32::from(cb.height()), cb_bytes,])
.unwrap();
}
}

View File

@ -1,6 +1,7 @@
//! Functions for querying information in the data database.
//! Functions for querying information in the wdb database.
use rusqlite::{OptionalExtension, ToSql, NO_PARAMS};
use ff::PrimeField;
use rusqlite::{params, OptionalExtension, ToSql, NO_PARAMS};
use std::collections::HashMap;
use zcash_primitives::{
@ -10,18 +11,25 @@ use zcash_primitives::{
note_encryption::Memo,
primitives::{Nullifier, PaymentAddress},
sapling::Node,
transaction::{components::Amount, TxId},
transaction::{components::Amount, Transaction, TxId},
zip32::ExtendedFullViewingKey,
};
use zcash_client_backend::{
data_api::error::Error,
address::RecipientAddress,
data_api::{error::Error, ShieldedOutput},
encoding::{
decode_extended_full_viewing_key, decode_payment_address, encode_extended_full_viewing_key,
encode_payment_address,
},
DecryptedOutput,
wallet::{AccountId, WalletTx},
};
use crate::{error::SqliteClientError, AccountId, NoteId, WalletDB};
use crate::{
error::{db_error, SqliteClientError},
DataConnStmtCache, NoteId, WalletDB,
};
pub mod init;
pub mod transact;
@ -42,32 +50,30 @@ pub mod transact;
/// };
///
/// let data_file = NamedTempFile::new().unwrap();
/// let db = WalletDB::for_path(data_file).unwrap();
/// let addr = get_address(&db, &Network::TestNetwork, AccountId(0));
/// let db = WalletDB::for_path(data_file, Network::TestNetwork).unwrap();
/// let addr = get_address(&db, AccountId(0));
/// ```
pub fn get_address<P: consensus::Parameters>(
data: &WalletDB,
params: &P,
wdb: &WalletDB<P>,
account: AccountId,
) -> Result<Option<PaymentAddress>, SqliteClientError> {
let addr: String = data.0.query_row(
let addr: String = wdb.conn.query_row(
"SELECT address FROM accounts
WHERE account = ?",
&[account.0],
|row| row.get(0),
)?;
decode_payment_address(params.hrp_sapling_payment_address(), &addr)
.map_err(|e| SqliteClientError(e.into()))
decode_payment_address(wdb.params.hrp_sapling_payment_address(), &addr)
.map_err(SqliteClientError::Bech32)
}
pub fn get_extended_full_viewing_keys<P: consensus::Parameters>(
data: &WalletDB,
params: &P,
wdb: &WalletDB<P>,
) -> Result<HashMap<AccountId, ExtendedFullViewingKey>, SqliteClientError> {
// Fetch the ExtendedFullViewingKeys we are tracking
let mut stmt_fetch_accounts = data
.0
let mut stmt_fetch_accounts = wdb
.conn
.prepare("SELECT account, extfvk FROM accounts ORDER BY account ASC")?;
let rows = stmt_fetch_accounts
@ -75,12 +81,11 @@ pub fn get_extended_full_viewing_keys<P: consensus::Parameters>(
let acct = row.get(0).map(AccountId)?;
let extfvk = row.get(1).map(|extfvk: String| {
decode_extended_full_viewing_key(
params.hrp_sapling_extended_full_viewing_key(),
wdb.params.hrp_sapling_extended_full_viewing_key(),
&extfvk,
)
.map_err(|e| Error::Bech32(e))
.and_then(|k| k.ok_or(Error::IncorrectHRPExtFVK))
.map_err(SqliteClientError)
.map_err(SqliteClientError::Bech32)
.and_then(|k| k.ok_or(SqliteClientError::IncorrectHRPExtFVK))
})?;
Ok((acct, extfvk))
@ -89,7 +94,7 @@ pub fn get_extended_full_viewing_keys<P: consensus::Parameters>(
let mut res: HashMap<AccountId, ExtendedFullViewingKey> = HashMap::new();
for row in rows {
let (account_id, efvkr) = row?;
let (account_id, efvkr) = row?;
res.insert(account_id, efvkr?);
}
@ -97,17 +102,16 @@ pub fn get_extended_full_viewing_keys<P: consensus::Parameters>(
}
pub fn is_valid_account_extfvk<P: consensus::Parameters>(
data: &WalletDB,
params: &P,
wdb: &WalletDB<P>,
account: AccountId,
extfvk: &ExtendedFullViewingKey,
) -> Result<bool, SqliteClientError> {
data.0
wdb.conn
.prepare("SELECT * FROM accounts WHERE account = ? AND extfvk = ?")?
.exists(&[
account.0.to_sql()?,
encode_extended_full_viewing_key(
params.hrp_sapling_extended_full_viewing_key(),
wdb.params.hrp_sapling_extended_full_viewing_key(),
extfvk,
)
.to_sql()?,
@ -127,6 +131,7 @@ pub fn is_valid_account_extfvk<P: consensus::Parameters>(
///
/// ```
/// use tempfile::NamedTempFile;
/// use zcash_primitives::consensus::Network;
/// use zcash_client_backend::wallet::AccountId;
/// use zcash_client_sqlite::{
/// WalletDB,
@ -134,11 +139,11 @@ pub fn is_valid_account_extfvk<P: consensus::Parameters>(
/// };
///
/// let data_file = NamedTempFile::new().unwrap();
/// let db = WalletDB::for_path(data_file).unwrap();
/// let db = WalletDB::for_path(data_file, Network::TestNetwork).unwrap();
/// let addr = get_balance(&db, AccountId(0));
/// ```
pub fn get_balance(data: &WalletDB, account: AccountId) -> Result<Amount, SqliteClientError> {
let balance = data.0.query_row(
pub fn get_balance<P>(wdb: &WalletDB<P>, account: AccountId) -> Result<Amount, SqliteClientError> {
let balance = wdb.conn.query_row(
"SELECT SUM(value) FROM received_notes
INNER JOIN transactions ON transactions.id_tx = received_notes.tx
WHERE account = ? AND spent IS NULL AND transactions.block IS NOT NULL",
@ -148,9 +153,9 @@ pub fn get_balance(data: &WalletDB, account: AccountId) -> Result<Amount, Sqlite
match Amount::from_i64(balance) {
Ok(amount) if !amount.is_negative() => Ok(amount),
_ => Err(SqliteClientError(Error::CorruptedData(
_ => Err(SqliteClientError::CorruptedData(
"Sum of values in received_notes is out of range".to_string(),
))),
)),
}
}
@ -162,7 +167,7 @@ pub fn get_balance(data: &WalletDB, account: AccountId) -> Result<Amount, Sqlite
///
/// ```
/// use tempfile::NamedTempFile;
/// use zcash_primitives::consensus::{BlockHeight};
/// use zcash_primitives::consensus::{BlockHeight, Network};
/// use zcash_client_backend::wallet::AccountId;
/// use zcash_client_sqlite::{
/// WalletDB,
@ -170,15 +175,15 @@ pub fn get_balance(data: &WalletDB, account: AccountId) -> Result<Amount, Sqlite
/// };
///
/// let data_file = NamedTempFile::new().unwrap();
/// let db = WalletDB::for_path(data_file).unwrap();
/// let db = WalletDB::for_path(data_file, Network::TestNetwork).unwrap();
/// let addr = get_verified_balance(&db, AccountId(0), BlockHeight::from_u32(0));
/// ```
pub fn get_verified_balance(
data: &WalletDB,
pub fn get_verified_balance<P>(
wdb: &WalletDB<P>,
account: AccountId,
anchor_height: BlockHeight,
) -> Result<Amount, SqliteClientError> {
let balance = data.0.query_row(
let balance = wdb.conn.query_row(
"SELECT SUM(value) FROM received_notes
INNER JOIN transactions ON transactions.id_tx = received_notes.tx
WHERE account = ? AND spent IS NULL AND transactions.block <= ?",
@ -188,21 +193,22 @@ pub fn get_verified_balance(
match Amount::from_i64(balance) {
Ok(amount) if !amount.is_negative() => Ok(amount),
_ => Err(SqliteClientError(Error::CorruptedData(
_ => Err(SqliteClientError::CorruptedData(
"Sum of values in received_notes is out of range".to_string(),
))),
)),
}
}
/// Returns the memo for a received note, if it is known and a valid UTF-8 string.
///
/// The note is identified by its row index in the `received_notes` table within the data
/// The note is identified by its row index in the `received_notes` table within the wdb
/// database.
///
/// # Examples
///
/// ```
/// use tempfile::NamedTempFile;
/// use zcash_primitives::consensus::Network;
/// use zcash_client_sqlite::{
/// NoteId,
/// WalletDB,
@ -210,14 +216,14 @@ pub fn get_verified_balance(
/// };
///
/// let data_file = NamedTempFile::new().unwrap();
/// let db = WalletDB::for_path(data_file).unwrap();
/// let db = WalletDB::for_path(data_file, Network::TestNetwork).unwrap();
/// let memo = get_received_memo_as_utf8(&db, NoteId(27));
/// ```
pub fn get_received_memo_as_utf8(
data: &WalletDB,
pub fn get_received_memo_as_utf8<P>(
wdb: &WalletDB<P>,
id_note: NoteId,
) -> Result<Option<String>, SqliteClientError> {
let memo: Vec<_> = data.0.query_row(
let memo: Vec<_> = wdb.conn.query_row(
"SELECT memo FROM received_notes
WHERE id_note = ?",
&[id_note.0],
@ -227,7 +233,7 @@ pub fn get_received_memo_as_utf8(
match Memo::from_bytes(&memo) {
Some(memo) => match memo.to_utf8() {
Some(Ok(res)) => Ok(Some(res)),
Some(Err(e)) => Err(SqliteClientError(Error::InvalidMemo(e))),
Some(Err(e)) => Err(SqliteClientError::InvalidMemo(e)),
None => Ok(None),
},
None => Ok(None),
@ -236,13 +242,14 @@ pub fn get_received_memo_as_utf8(
/// Returns the memo for a sent note, if it is known and a valid UTF-8 string.
///
/// The note is identified by its row index in the `sent_notes` table within the data
/// The note is identified by its row index in the `sent_notes` table within the wdb
/// database.
///
/// # Examples
///
/// ```
/// use tempfile::NamedTempFile;
/// use zcash_primitives::consensus::Network;
/// use zcash_client_sqlite::{
/// NoteId,
/// WalletDB,
@ -250,14 +257,14 @@ pub fn get_received_memo_as_utf8(
/// };
///
/// let data_file = NamedTempFile::new().unwrap();
/// let db = WalletDB::for_path(data_file).unwrap();
/// let db = WalletDB::for_path(data_file, Network::TestNetwork).unwrap();
/// let memo = get_sent_memo_as_utf8(&db, NoteId(12));
/// ```
pub fn get_sent_memo_as_utf8(
data: &WalletDB,
pub fn get_sent_memo_as_utf8<P>(
wdb: &WalletDB<P>,
id_note: NoteId,
) -> Result<Option<String>, SqliteClientError> {
let memo: Vec<_> = data.0.query_row(
let memo: Vec<_> = wdb.conn.query_row(
"SELECT memo FROM sent_notes
WHERE id_note = ?",
&[id_note.0],
@ -267,17 +274,17 @@ pub fn get_sent_memo_as_utf8(
match Memo::from_bytes(&memo) {
Some(memo) => match memo.to_utf8() {
Some(Ok(res)) => Ok(Some(res)),
Some(Err(e)) => Err(SqliteClientError(Error::InvalidMemo(e))),
Some(Err(e)) => Err(SqliteClientError::InvalidMemo(e)),
None => Ok(None),
},
None => Ok(None),
}
}
pub fn block_height_extrema(
conn: &WalletDB,
pub fn block_height_extrema<P>(
wdb: &WalletDB<P>,
) -> Result<Option<(BlockHeight, BlockHeight)>, rusqlite::Error> {
conn.0
wdb.conn
.query_row(
"SELECT MIN(height), MAX(height) FROM blocks",
NO_PARAMS,
@ -295,8 +302,11 @@ pub fn block_height_extrema(
.or(Ok(None))
}
pub fn get_tx_height(conn: &WalletDB, txid: TxId) -> Result<Option<BlockHeight>, rusqlite::Error> {
conn.0
pub fn get_tx_height<P>(
wdb: &WalletDB<P>,
txid: TxId,
) -> Result<Option<BlockHeight>, rusqlite::Error> {
wdb.conn
.query_row(
"SELECT block FROM transactions WHERE txid = ?",
&[txid.0.to_vec()],
@ -305,11 +315,11 @@ pub fn get_tx_height(conn: &WalletDB, txid: TxId) -> Result<Option<BlockHeight>,
.optional()
}
pub fn get_block_hash(
conn: &WalletDB,
pub fn get_block_hash<P>(
wdb: &WalletDB<P>,
block_height: BlockHeight,
) -> Result<Option<BlockHash>, rusqlite::Error> {
conn.0
wdb.conn
.query_row(
"SELECT hash FROM blocks WHERE height = ?",
&[u32::from(block_height)],
@ -328,55 +338,61 @@ pub fn get_block_hash(
///
/// This should only be executed inside a transactional context.
pub fn rewind_to_height<P: consensus::Parameters>(
conn: &WalletDB,
parameters: &P,
wdb: &WalletDB<P>,
block_height: BlockHeight,
) -> Result<(), SqliteClientError> {
let sapling_activation_height = parameters
) -> Result<(), Error<SqliteClientError, NoteId>> {
let sapling_activation_height = wdb
.params
.activation_height(NetworkUpgrade::Sapling)
.ok_or(SqliteClientError(Error::SaplingNotActive))?;
.ok_or(Error::SaplingNotActive)?;
// Recall where we synced up to previously.
// If we have never synced, use Sapling activation height.
let last_scanned_height =
conn.0
.query_row("SELECT MAX(height) FROM blocks", NO_PARAMS, |row| {
row.get(0)
.map(u32::into)
.or(Ok(sapling_activation_height - 1))
})?;
let last_scanned_height = wdb
.conn
.query_row("SELECT MAX(height) FROM blocks", NO_PARAMS, |row| {
row.get(0)
.map(|h: u32| h.into())
.or(Ok(sapling_activation_height - 1))
})
.map_err(db_error)?;
// nothing to do if we're deleting back down to the max height
if block_height >= last_scanned_height {
// Nothing to do.
return Ok(());
Ok(())
} else {
// Decrement witnesses.
wdb.conn
.execute(
"DELETE FROM sapling_witnesses WHERE block > ?",
&[u32::from(block_height)],
)
.map_err(db_error)?;
// Un-mine transactions.
wdb.conn
.execute(
"UPDATE transactions SET block = NULL, tx_index = NULL WHERE block > ?",
&[u32::from(block_height)],
)
.map_err(db_error)?;
// Now that they aren't depended on, delete scanned blocks.
wdb.conn
.execute(
"DELETE FROM blocks WHERE height > ?",
&[u32::from(block_height)],
)
.map_err(db_error)?;
Ok(())
}
// Decrement witnesses.
conn.0.execute(
"DELETE FROM sapling_witnesses WHERE block > ?",
&[u32::from(block_height)],
)?;
// Un-mine transactions.
conn.0.execute(
"UPDATE transactions SET block = NULL, tx_index = NULL WHERE block > ?",
&[u32::from(block_height)],
)?;
// Now that they aren't depended on, delete scanned blocks.
conn.0.execute(
"DELETE FROM blocks WHERE height > ?",
&[u32::from(block_height)],
)?;
Ok(())
}
pub fn get_commitment_tree(
data: &WalletDB,
pub fn get_commitment_tree<P>(
wdb: &WalletDB<P>,
block_height: BlockHeight,
) -> Result<Option<CommitmentTree<Node>>, SqliteClientError> {
data.0
wdb.conn
.query_row_and_then(
"SELECT sapling_tree FROM blocks WHERE height = ?",
&[u32::from(block_height)],
@ -395,18 +411,18 @@ pub fn get_commitment_tree(
.map_err(SqliteClientError::from)
}
pub fn get_witnesses(
data: &WalletDB,
pub fn get_witnesses<P>(
wdb: &WalletDB<P>,
block_height: BlockHeight,
) -> Result<Vec<(NoteId, IncrementalWitness<Node>)>, SqliteClientError> {
let mut stmt_fetch_witnesses = data
.0
let mut stmt_fetch_witnesses = wdb
.conn
.prepare("SELECT note, witness FROM sapling_witnesses WHERE block = ?")?;
let witnesses = stmt_fetch_witnesses
.query_map(&[u32::from(block_height)], |row| {
let id_note = NoteId(row.get(0)?);
let data: Vec<u8> = row.get(1)?;
Ok(IncrementalWitness::read(&data[..]).map(|witness| (id_note, witness)))
let wdb: Vec<u8> = row.get(1)?;
Ok(IncrementalWitness::read(&wdb[..]).map(|witness| (id_note, witness)))
})
.map_err(SqliteClientError::from)?;
@ -419,10 +435,12 @@ pub fn get_witnesses(
Ok(res)
}
pub fn get_nullifiers(data: &WalletDB) -> Result<Vec<(Nullifier, AccountId)>, SqliteClientError> {
pub fn get_nullifiers<P>(
wdb: &WalletDB<P>,
) -> Result<Vec<(Nullifier, AccountId)>, SqliteClientError> {
// Get the nullifiers for the notes we are tracking
let mut stmt_fetch_nullifiers = data
.0
let mut stmt_fetch_nullifiers = wdb
.conn
.prepare("SELECT id_note, nf, account FROM received_notes WHERE spent IS NULL")?;
let nullifiers = stmt_fetch_nullifiers.query_map(NO_PARAMS, |row| {
let nf_bytes: Vec<u8> = row.get(1)?;
@ -438,9 +456,244 @@ pub fn get_nullifiers(data: &WalletDB) -> Result<Vec<(Nullifier, AccountId)>, Sq
Ok(res)
}
pub fn insert_block<'a, P>(
stmts: &mut DataConnStmtCache<'a, P>,
block_height: BlockHeight,
block_hash: BlockHash,
block_time: u32,
commitment_tree: &CommitmentTree<Node>,
) -> Result<(), SqliteClientError> {
let mut encoded_tree = Vec::new();
commitment_tree.write(&mut encoded_tree).unwrap();
stmts.stmt_insert_block
.execute(params![
u32::from(block_height),
&block_hash.0[..],
block_time,
encoded_tree
])?;
Ok(())
}
pub fn put_tx_meta<'a, P>(
stmts: &mut DataConnStmtCache<'a, P>,
tx: &WalletTx,
height: BlockHeight,
) -> Result<i64, SqliteClientError> {
let txid = tx.txid.0.to_vec();
if stmts
.stmt_update_tx_meta
.execute(params![u32::from(height), (tx.index as i64), txid])?
== 0
{
// It isn't there, so insert our transaction into the database.
stmts
.stmt_insert_tx_meta
.execute(params![txid, u32::from(height), (tx.index as i64),])?;
Ok(stmts.wallet_db.conn.last_insert_rowid())
} else {
// It was there, so grab its row number.
stmts
.stmt_select_tx_ref
.query_row(&[txid], |row| row.get(0))
.map_err(SqliteClientError::from)
}
}
pub fn put_tx_data<'a, P>(
stmts: &mut DataConnStmtCache<'a, P>,
tx: &Transaction,
created_at: Option<time::OffsetDateTime>,
) -> Result<i64, SqliteClientError> {
let txid = tx.txid().0.to_vec();
let mut raw_tx = vec![];
tx.write(&mut raw_tx)?;
if stmts
.stmt_update_tx_data
.execute(params![u32::from(tx.expiry_height), raw_tx, txid,])?
== 0
{
// It isn't there, so insert our transaction into the database.
stmts.stmt_insert_tx_data.execute(params![
txid,
created_at,
u32::from(tx.expiry_height),
raw_tx
])?;
Ok(stmts.wallet_db.conn.last_insert_rowid())
} else {
// It was there, so grab its row number.
stmts
.stmt_select_tx_ref
.query_row(&[txid], |row| row.get(0))
.map_err(SqliteClientError::from)
}
}
pub fn mark_spent<'a, P>(
stmts: &mut DataConnStmtCache<'a, P>,
tx_ref: i64,
nf: &Nullifier,
) -> Result<(), SqliteClientError> {
stmts
.stmt_mark_recived_note_spent
.execute(&[tx_ref.to_sql()?, nf.0.to_sql()?])?;
Ok(())
}
// Assumptions:
// - A transaction will not contain more than 2^63 shielded outputs.
// - A note value will never exceed 2^63 zatoshis.
pub fn put_received_note<'a, P, T: ShieldedOutput>(
stmts: &mut DataConnStmtCache<'a, P>,
output: &T,
nf_opt: &Option<Nullifier>,
tx_ref: i64,
) -> Result<NoteId, SqliteClientError> {
let rcm = output.note().rcm().to_repr();
let account = output.account().0 as i64;
let diversifier = output.to().diversifier().0.to_vec();
let value = output.note().value as i64;
let rcm = rcm.as_ref();
let memo = output.memo().map(|m| m.as_bytes());
let is_change = output.is_change();
let tx = tx_ref;
let output_index = output.index() as i64;
let nf_bytes = nf_opt.map(|nf| nf.0.to_vec());
let sql_args: Vec<(&str, &dyn ToSql)> = vec![
(&":account", &account),
(&":diversifier", &diversifier),
(&":value", &value),
(&":rcm", &rcm),
(&":nf", &nf_bytes),
(&":memo", &memo),
(&":is_change", &is_change),
(&":tx", &tx),
(&":output_index", &output_index),
];
// First try updating an existing received note into the database.
if stmts.stmt_update_received_note.execute_named(&sql_args)? == 0 {
// It isn't there, so insert our note into the database.
stmts.stmt_insert_received_note.execute_named(&sql_args)?;
Ok(NoteId(stmts.wallet_db.conn.last_insert_rowid()))
} else {
// It was there, so grab its row number.
stmts
.stmt_select_received_note
.query_row(params![tx_ref, (output.index() as i64)], |row| {
row.get(0).map(NoteId)
})
.map_err(SqliteClientError::from)
}
}
pub fn insert_witness<'a, P>(
stmts: &mut DataConnStmtCache<'a, P>,
note_id: NoteId,
witness: &IncrementalWitness<Node>,
height: BlockHeight,
) -> Result<(), SqliteClientError> {
let mut encoded = Vec::new();
witness.write(&mut encoded).unwrap();
stmts
.stmt_insert_witness
.execute(params![note_id.0, u32::from(height), encoded])?;
Ok(())
}
pub fn prune_witnesses<'a, P>(
stmts: &mut DataConnStmtCache<'a, P>,
below_height: BlockHeight,
) -> Result<(), SqliteClientError> {
stmts
.stmt_prune_witnesses
.execute(&[u32::from(below_height)])?;
Ok(())
}
pub fn update_expired_notes<'a, P>(
stmts: &mut DataConnStmtCache<'a, P>,
height: BlockHeight,
) -> Result<(), SqliteClientError> {
stmts.stmt_update_expired.execute(&[u32::from(height)])?;
Ok(())
}
pub fn put_sent_note<'a, P: consensus::Parameters>(
stmts: &mut DataConnStmtCache<'a, P>,
output: &DecryptedOutput,
tx_ref: i64,
) -> Result<(), SqliteClientError> {
let output_index = output.index as i64;
let account = output.account.0 as i64;
let value = output.note.value as i64;
let to_str = encode_payment_address(
stmts.wallet_db.params.hrp_sapling_payment_address(),
&output.to,
);
// Try updating an existing sent note.
if stmts.stmt_update_sent_note.execute(params![
account,
to_str,
value,
&output.memo.as_bytes(),
tx_ref,
output_index
])? == 0
{
// It isn't there, so insert.
insert_sent_note(
stmts,
tx_ref,
output.index,
output.account,
&RecipientAddress::Shielded(output.to.clone()),
Amount::from_u64(output.note.value)
.map_err(|_| SqliteClientError::CorruptedData("Note value invalid.".to_string()))?,
Some(output.memo.clone()),
)?
}
Ok(())
}
pub fn insert_sent_note<'a, P: consensus::Parameters>(
stmts: &mut DataConnStmtCache<'a, P>,
tx_ref: i64,
output_index: usize,
account: AccountId,
to: &RecipientAddress,
value: Amount,
memo: Option<Memo>,
) -> Result<(), SqliteClientError> {
let to_str = to.encode(&stmts.wallet_db.params);
let ivalue: i64 = value.into();
stmts.stmt_insert_sent_note.execute(params![
tx_ref,
(output_index as i64),
account.0,
to_str,
ivalue,
memo.map(|m| m.as_bytes().to_vec()),
])?;
Ok(())
}
#[cfg(test)]
mod tests {
use rusqlite::Connection;
use tempfile::NamedTempFile;
use zcash_primitives::{
@ -461,13 +714,13 @@ mod tests {
#[test]
fn empty_database_has_no_balance() {
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB(Connection::open(data_file.path()).unwrap());
let db_data = WalletDB::for_path(data_file.path(), tests::network()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
let extsk = ExtendedSpendingKey::master(&[]);
let extfvks = [ExtendedFullViewingKey::from(&extsk)];
init_accounts_table(&db_data, &tests::network(), &extfvks).unwrap();
init_accounts_table(&db_data, &extfvks).unwrap();
// The account should be empty
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), Amount::zero());
@ -476,7 +729,7 @@ mod tests {
assert_eq!((&db_data).get_target_and_anchor_heights().unwrap(), None);
// An invalid account has zero balance
assert!(get_address(&db_data, &tests::network(), AccountId(1)).is_err());
assert!(get_address(&db_data, AccountId(1)).is_err());
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), Amount::zero());
}
}

View File

@ -8,7 +8,7 @@ use zcash_primitives::{
zip32::ExtendedFullViewingKey,
};
use zcash_client_backend::{data_api::error::Error, encoding::encode_extended_full_viewing_key};
use zcash_client_backend::{encoding::encode_extended_full_viewing_key};
use crate::{address_from_extfvk, error::SqliteClientError, WalletDB};
@ -18,17 +18,18 @@ use crate::{address_from_extfvk, error::SqliteClientError, WalletDB};
///
/// ```
/// use tempfile::NamedTempFile;
/// use zcash_primitives::consensus::Network;
/// use zcash_client_sqlite::{
/// WalletDB,
/// wallet::init::init_data_database,
/// };
///
/// let data_file = NamedTempFile::new().unwrap();
/// let db = WalletDB::for_path(data_file.path()).unwrap();
/// let db = WalletDB::for_path(data_file.path(), Network::TestNetwork).unwrap();
/// init_data_database(&db).unwrap();
/// ```
pub fn init_data_database(db_data: &WalletDB) -> Result<(), rusqlite::Error> {
db_data.0.execute(
pub fn init_data_database<P>(wdb: &WalletDB<P>) -> Result<(), rusqlite::Error> {
wdb.conn.execute(
"CREATE TABLE IF NOT EXISTS accounts (
account INTEGER PRIMARY KEY,
extfvk TEXT NOT NULL,
@ -36,7 +37,7 @@ pub fn init_data_database(db_data: &WalletDB) -> Result<(), rusqlite::Error> {
)",
NO_PARAMS,
)?;
db_data.0.execute(
wdb.conn.execute(
"CREATE TABLE IF NOT EXISTS blocks (
height INTEGER PRIMARY KEY,
hash BLOB NOT NULL,
@ -45,7 +46,7 @@ pub fn init_data_database(db_data: &WalletDB) -> Result<(), rusqlite::Error> {
)",
NO_PARAMS,
)?;
db_data.0.execute(
wdb.conn.execute(
"CREATE TABLE IF NOT EXISTS transactions (
id_tx INTEGER PRIMARY KEY,
txid BLOB NOT NULL UNIQUE,
@ -58,7 +59,7 @@ pub fn init_data_database(db_data: &WalletDB) -> Result<(), rusqlite::Error> {
)",
NO_PARAMS,
)?;
db_data.0.execute(
wdb.conn.execute(
"CREATE TABLE IF NOT EXISTS received_notes (
id_note INTEGER PRIMARY KEY,
tx INTEGER NOT NULL,
@ -78,7 +79,7 @@ pub fn init_data_database(db_data: &WalletDB) -> Result<(), rusqlite::Error> {
)",
NO_PARAMS,
)?;
db_data.0.execute(
wdb.conn.execute(
"CREATE TABLE IF NOT EXISTS sapling_witnesses (
id_witness INTEGER PRIMARY KEY,
note INTEGER NOT NULL,
@ -90,7 +91,7 @@ pub fn init_data_database(db_data: &WalletDB) -> Result<(), rusqlite::Error> {
)",
NO_PARAMS,
)?;
db_data.0.execute(
wdb.conn.execute(
"CREATE TABLE IF NOT EXISTS sent_notes (
id_note INTEGER PRIMARY KEY,
tx INTEGER NOT NULL,
@ -131,38 +132,37 @@ pub fn init_data_database(db_data: &WalletDB) -> Result<(), rusqlite::Error> {
/// };
///
/// let data_file = NamedTempFile::new().unwrap();
/// let db_data = WalletDB::for_path(data_file.path()).unwrap();
/// let db_data = WalletDB::for_path(data_file.path(), Network::TestNetwork).unwrap();
/// init_data_database(&db_data).unwrap();
///
/// let extsk = ExtendedSpendingKey::master(&[]);
/// let extfvks = [ExtendedFullViewingKey::from(&extsk)];
/// init_accounts_table(&db_data, &Network::TestNetwork, &extfvks).unwrap();
/// init_accounts_table(&db_data, &extfvks).unwrap();
/// ```
///
/// [`get_address`]: crate::wallet::get_address
/// [`scan_cached_blocks`]: crate::scan::scan_cached_blocks
/// [`create_to_address`]: crate::transact::create_to_address
pub fn init_accounts_table<P: consensus::Parameters>(
data: &WalletDB,
params: &P,
wdb: &WalletDB<P>,
extfvks: &[ExtendedFullViewingKey],
) -> Result<(), SqliteClientError> {
let mut empty_check = data.0.prepare("SELECT * FROM accounts LIMIT 1")?;
let mut empty_check = wdb.conn.prepare("SELECT * FROM accounts LIMIT 1")?;
if empty_check.exists(NO_PARAMS)? {
return Err(SqliteClientError(Error::TableNotEmpty));
return Err(SqliteClientError::TableNotEmpty);
}
// Insert accounts atomically
data.0.execute("BEGIN IMMEDIATE", NO_PARAMS)?;
wdb.conn.execute("BEGIN IMMEDIATE", NO_PARAMS)?;
for (account, extfvk) in extfvks.iter().enumerate() {
let extfvk_str = encode_extended_full_viewing_key(
params.hrp_sapling_extended_full_viewing_key(),
wdb.params.hrp_sapling_extended_full_viewing_key(),
extfvk,
);
let address_str = address_from_extfvk(params, extfvk);
let address_str = address_from_extfvk(&wdb.params, extfvk);
data.0.execute(
wdb.conn.execute(
"INSERT INTO accounts (account, extfvk, address)
VALUES (?, ?, ?)",
&[
@ -172,7 +172,7 @@ pub fn init_accounts_table<P: consensus::Parameters>(
],
)?;
}
data.0.execute("COMMIT", NO_PARAMS)?;
wdb.conn.execute("COMMIT", NO_PARAMS)?;
Ok(())
}
@ -188,7 +188,7 @@ pub fn init_accounts_table<P: consensus::Parameters>(
/// use tempfile::NamedTempFile;
/// use zcash_primitives::{
/// block::BlockHash,
/// consensus::BlockHeight,
/// consensus::{BlockHeight, Network},
/// };
/// use zcash_client_sqlite::{
/// WalletDB,
@ -206,22 +206,22 @@ pub fn init_accounts_table<P: consensus::Parameters>(
/// let sapling_tree = &[];
///
/// let data_file = NamedTempFile::new().unwrap();
/// let db = WalletDB::for_path(data_file.path()).unwrap();
/// let db = WalletDB::for_path(data_file.path(), Network::TestNetwork).unwrap();
/// init_blocks_table(&db, height, hash, time, sapling_tree);
/// ```
pub fn init_blocks_table(
data: &WalletDB,
pub fn init_blocks_table<P>(
wdb: &WalletDB<P>,
height: BlockHeight,
hash: BlockHash,
time: u32,
sapling_tree: &[u8],
) -> Result<(), SqliteClientError> {
let mut empty_check = data.0.prepare("SELECT * FROM blocks LIMIT 1")?;
let mut empty_check = wdb.conn.prepare("SELECT * FROM blocks LIMIT 1")?;
if empty_check.exists(NO_PARAMS)? {
return Err(SqliteClientError(Error::TableNotEmpty));
return Err(SqliteClientError::TableNotEmpty);
}
data.0.execute(
wdb.conn.execute(
"INSERT INTO blocks (height, hash, time, sapling_tree)
VALUES (?, ?, ?, ?)",
&[
@ -237,7 +237,6 @@ pub fn init_blocks_table(
#[cfg(test)]
mod tests {
use rusqlite::Connection;
use tempfile::NamedTempFile;
use zcash_primitives::{
@ -253,28 +252,28 @@ mod tests {
#[test]
fn init_accounts_table_only_works_once() {
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB(Connection::open(data_file.path()).unwrap());
let db_data = WalletDB::for_path(data_file.path(), tests::network()).unwrap();
init_data_database(&db_data).unwrap();
// We can call the function as many times as we want with no data
init_accounts_table(&db_data, &tests::network(), &[]).unwrap();
init_accounts_table(&db_data, &tests::network(), &[]).unwrap();
init_accounts_table(&db_data, &[]).unwrap();
init_accounts_table(&db_data, &[]).unwrap();
// First call with data should initialise the accounts table
let extfvks = [ExtendedFullViewingKey::from(&ExtendedSpendingKey::master(
&[],
))];
init_accounts_table(&db_data, &tests::network(), &extfvks).unwrap();
init_accounts_table(&db_data, &extfvks).unwrap();
// Subsequent calls should return an error
init_accounts_table(&db_data, &tests::network(), &[]).unwrap_err();
init_accounts_table(&db_data, &tests::network(), &extfvks).unwrap_err();
init_accounts_table(&db_data, &[]).unwrap_err();
init_accounts_table(&db_data, &extfvks).unwrap_err();
}
#[test]
fn init_blocks_table_only_works_once() {
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB(Connection::open(data_file.path()).unwrap());
let db_data = WalletDB::for_path(data_file.path(), tests::network()).unwrap();
init_data_database(&db_data).unwrap();
// First call with data should initialise the blocks table
@ -301,16 +300,16 @@ mod tests {
#[test]
fn init_accounts_table_stores_correct_address() {
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB(Connection::open(data_file.path()).unwrap());
let db_data = WalletDB::for_path(data_file.path(), tests::network()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
let extsk = ExtendedSpendingKey::master(&[]);
let extfvks = [ExtendedFullViewingKey::from(&extsk)];
init_accounts_table(&db_data, &tests::network(), &extfvks).unwrap();
init_accounts_table(&db_data, &extfvks).unwrap();
// The account's address should be in the data DB
let pa = get_address(&db_data, &tests::network(), AccountId(0)).unwrap();
let pa = get_address(&db_data, AccountId(0)).unwrap();
assert_eq!(pa.unwrap(), extsk.default_address().unwrap().1);
}
}

View File

@ -13,14 +13,13 @@ use zcash_primitives::{
};
use zcash_client_backend::{
data_api::error::Error,
wallet::{AccountId, SpendableNote},
};
use crate::{error::SqliteClientError, WalletDB};
pub fn select_spendable_notes(
data: &WalletDB,
pub fn select_spendable_notes<P>(
wdb: &WalletDB<P>,
account: AccountId,
target_value: Amount,
anchor_height: BlockHeight,
@ -43,7 +42,7 @@ pub fn select_spendable_notes(
// required value, bringing the sum of all selected notes across the threshold.
//
// 4) Match the selected notes against the witnesses at the desired height.
let mut stmt_select_notes = data.0.prepare(
let mut stmt_select_notes = wdb.conn.prepare(
"WITH selected AS (
WITH eligible AS (
SELECT id_note, diversifier, value, rcm,
@ -76,9 +75,9 @@ pub fn select_spendable_notes(
let diversifier = {
let d: Vec<_> = row.get(0)?;
if d.len() != 11 {
return Err(SqliteClientError(Error::CorruptedData(
return Err(SqliteClientError::CorruptedData(
"Invalid diversifier length".to_string(),
)));
));
}
let mut tmp = [0; 11];
tmp.copy_from_slice(&d);
@ -96,9 +95,9 @@ pub fn select_spendable_notes(
let rcm = jubjub::Fr::from_repr(
rcm_bytes[..]
.try_into()
.map_err(|_| SqliteClientError(Error::InvalidNote))?,
.map_err(|_| SqliteClientError::InvalidNote)?,
)
.ok_or(SqliteClientError(Error::InvalidNote))?;
.ok_or(SqliteClientError::InvalidNote)?;
Rseed::BeforeZip212(rcm)
};
@ -148,7 +147,7 @@ mod tests {
get_balance, get_verified_balance,
init::{init_accounts_table, init_blocks_table, init_data_database},
},
AccountId, BlockDB, WalletDB, DataConnStmtCache
AccountId, BlockDB, DataConnStmtCache, WalletDB,
};
fn test_prover() -> impl TxProver {
@ -163,7 +162,7 @@ mod tests {
#[test]
fn create_to_address_fails_on_incorrect_extsk() {
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB(Connection::open(data_file.path()).unwrap());
let db_data = WalletDB::for_path(data_file.path(), tests::network()).unwrap();
init_data_database(&db_data).unwrap();
// Add two accounts to the wallet
@ -173,7 +172,7 @@ mod tests {
ExtendedFullViewingKey::from(&extsk0),
ExtendedFullViewingKey::from(&extsk1),
];
init_accounts_table(&db_data, &tests::network(), &extfvks).unwrap();
init_accounts_table(&db_data, &extfvks).unwrap();
let to = extsk0.default_address().unwrap().1.into();
// Invalid extsk for the given account should cause an error
@ -212,20 +211,20 @@ mod tests {
#[test]
fn create_to_address_fails_with_no_blocks() {
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB(Connection::open(data_file.path()).unwrap());
let db_data = WalletDB::for_path(data_file.path(), tests::network()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
let extsk = ExtendedSpendingKey::master(&[]);
let extfvks = [ExtendedFullViewingKey::from(&extsk)];
init_accounts_table(&db_data, &tests::network(), &extfvks).unwrap();
init_accounts_table(&db_data, &extfvks).unwrap();
let to = extsk.default_address().unwrap().1.into();
// We cannot do anything if we aren't synchronised
let mut db_write = db_data.get_update_ops().unwrap();
match create_spend_to_address(
&mut db_write,
&tests::network(),
&tests::network(),
test_prover(),
AccountId(0),
&extsk,
@ -242,7 +241,7 @@ mod tests {
#[test]
fn create_to_address_fails_on_insufficient_balance() {
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB(Connection::open(data_file.path()).unwrap());
let db_data = WalletDB::for_path(data_file.path(), tests::network()).unwrap();
init_data_database(&db_data).unwrap();
init_blocks_table(
&db_data,
@ -256,7 +255,7 @@ mod tests {
// Add an account to the wallet
let extsk = ExtendedSpendingKey::master(&[]);
let extfvks = [ExtendedFullViewingKey::from(&extsk)];
init_accounts_table(&db_data, &tests::network(), &extfvks).unwrap();
init_accounts_table(&db_data, &extfvks).unwrap();
let to = extsk.default_address().unwrap().1.into();
// Account balance should be zero
@ -290,13 +289,13 @@ mod tests {
init_cache_database(&db_cache).unwrap();
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB(Connection::open(data_file.path()).unwrap());
let db_data = WalletDB::for_path(data_file.path(), tests::network()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
let extsk = ExtendedSpendingKey::master(&[]);
let extfvk = ExtendedFullViewingKey::from(&extsk);
init_accounts_table(&db_data, &tests::network(), &[extfvk.clone()]).unwrap();
init_accounts_table(&db_data, &[extfvk.clone()]).unwrap();
// Add funds to the wallet in a single note
let value = Amount::from_u64(50000).unwrap();
@ -421,13 +420,13 @@ mod tests {
init_cache_database(&db_cache).unwrap();
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB(Connection::open(data_file.path()).unwrap());
let db_data = WalletDB::for_path(data_file.path(), tests::network()).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
let extsk = ExtendedSpendingKey::master(&[]);
let extfvk = ExtendedFullViewingKey::from(&extsk);
init_accounts_table(&db_data, &tests::network(), &[extfvk.clone()]).unwrap();
init_accounts_table(&db_data, &[extfvk.clone()]).unwrap();
// Add funds to the wallet in a single note
let value = Amount::from_u64(50000).unwrap();
@ -542,13 +541,13 @@ mod tests {
init_cache_database(&db_cache).unwrap();
let data_file = NamedTempFile::new().unwrap();
let db_data = WalletDB(Connection::open(data_file.path()).unwrap());
let db_data = WalletDB::for_path(data_file.path(), network).unwrap();
init_data_database(&db_data).unwrap();
// Add an account to the wallet
let extsk = ExtendedSpendingKey::master(&[]);
let extfvk = ExtendedFullViewingKey::from(&extsk);
init_accounts_table(&db_data, &network, &[extfvk.clone()]).unwrap();
init_accounts_table(&db_data, &[extfvk.clone()]).unwrap();
// Add funds to the wallet in a single note
let value = Amount::from_u64(50000).unwrap();
@ -567,10 +566,10 @@ mod tests {
let addr2 = extsk2.default_address().unwrap().1;
let to = addr2.clone().into();
let send_and_recover_with_policy = |db_write: &mut DataConnStmtCache<'_>, ovk_policy| {
let send_and_recover_with_policy = |db_write: &mut DataConnStmtCache<'_, _>, ovk_policy| {
let tx_row = create_spend_to_address(
db_write,
&network,
&tests::network(),
test_prover(),
AccountId(0),
&extsk,
@ -582,8 +581,9 @@ mod tests {
.unwrap();
// Fetch the transaction from the database
let raw_tx: Vec<_> = db_write.conn
.0
let raw_tx: Vec<_> = db_write
.wallet_db
.conn
.query_row(
"SELECT raw FROM transactions
WHERE id_tx = ?",
@ -594,8 +594,9 @@ mod tests {
let tx = Transaction::read(&raw_tx[..]).unwrap();
// Fetch the output index from the database
let output_index: i64 = db_write.conn
.0
let output_index: i64 = db_write
.wallet_db
.conn
.query_row(
"SELECT output_index FROM sent_notes
WHERE tx = ?",
@ -620,7 +621,8 @@ mod tests {
// Send some of the funds to another address, keeping history.
// The recipient output is decryptable by the sender.
let (_, recovered_to, _) = send_and_recover_with_policy(&mut db_write, OvkPolicy::Sender).unwrap();
let (_, recovered_to, _) =
send_and_recover_with_policy(&mut db_write, OvkPolicy::Sender).unwrap();
assert_eq!(&recovered_to, &addr2);
// Mine blocks SAPLING_ACTIVATION_HEIGHT + 1 to 22 (that don't send us funds)