zcash_client_sqlite: Add shardtree truncation & checkpoint operations.

This commit is contained in:
Kris Nuttycombe 2023-06-13 11:20:18 -06:00
parent ade882d01c
commit d11f3d2acc
6 changed files with 209 additions and 54 deletions

View File

@ -539,7 +539,7 @@ mod tests {
// "Rewind" to height of last scanned block
db_data
.transactionally(|wdb| {
truncate_to_height(&wdb.conn.0, &wdb.params, sapling_activation_height() + 1)
truncate_to_height(wdb.conn.0, &wdb.params, sapling_activation_height() + 1)
})
.unwrap();
@ -552,7 +552,7 @@ mod tests {
// Rewind so that one block is dropped
db_data
.transactionally(|wdb| {
truncate_to_height(&wdb.conn.0, &wdb.params, sapling_activation_height())
truncate_to_height(wdb.conn.0, &wdb.params, sapling_activation_height())
})
.unwrap();

View File

@ -116,11 +116,11 @@ pub struct WalletDb<C, P> {
}
/// A wrapper for a SQLite transaction affecting the wallet database.
pub struct SqlTransaction<'conn>(pub(crate) rusqlite::Transaction<'conn>);
pub struct SqlTransaction<'conn>(pub(crate) &'conn rusqlite::Transaction<'conn>);
impl Borrow<rusqlite::Connection> for SqlTransaction<'_> {
fn borrow(&self) -> &rusqlite::Connection {
&self.0
self.0
}
}
@ -137,12 +137,13 @@ impl<P: consensus::Parameters + Clone> WalletDb<Connection, P> {
where
F: FnOnce(&mut WalletDb<SqlTransaction<'_>, P>) -> Result<A, E>,
{
let tx = self.conn.transaction()?;
let mut wdb = WalletDb {
conn: SqlTransaction(self.conn.transaction()?),
conn: SqlTransaction(&tx),
params: self.params.clone(),
};
let result = f(&mut wdb)?;
wdb.conn.0.commit()?;
tx.commit()?;
Ok(result)
}
}
@ -334,7 +335,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
seed: &SecretVec<u8>,
) -> Result<(AccountId, UnifiedSpendingKey), Self::Error> {
self.transactionally(|wdb| {
let account = wallet::get_max_account_id(&wdb.conn.0)?
let account = wallet::get_max_account_id(wdb.conn.0)?
.map(|a| AccountId::from(u32::from(a) + 1))
.unwrap_or_else(|| AccountId::from(0));
@ -346,7 +347,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
.map_err(|_| SqliteClientError::KeyDerivationError(account))?;
let ufvk = usk.to_unified_full_viewing_key();
wallet::add_account(&wdb.conn.0, &wdb.params, account, &ufvk)?;
wallet::add_account(wdb.conn.0, &wdb.params, account, &ufvk)?;
Ok((account, usk))
})
@ -360,7 +361,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
|wdb| match wdb.get_unified_full_viewing_keys()?.get(&account) {
Some(ufvk) => {
let search_from =
match wallet::get_current_address(&wdb.conn.0, &wdb.params, account)? {
match wallet::get_current_address(wdb.conn.0, &wdb.params, account)? {
Some((_, mut last_diversifier_index)) => {
last_diversifier_index
.increment()
@ -375,7 +376,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
.ok_or(SqliteClientError::DiversifierIndexOutOfRange)?;
wallet::insert_address(
&wdb.conn.0,
wdb.conn.0,
&wdb.params,
account,
diversifier_index,
@ -399,7 +400,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
// Insert the block into the database.
let block_height = block.block_height;
wallet::insert_block(
&wdb.conn.0,
wdb.conn.0,
block_height,
block.block_hash,
block.block_time,
@ -408,16 +409,16 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
let mut wallet_note_ids = vec![];
for tx in &block.transactions {
let tx_row = wallet::put_tx_meta(&wdb.conn.0, tx, block.block_height)?;
let tx_row = wallet::put_tx_meta(wdb.conn.0, tx, block.block_height)?;
// Mark notes as spent and remove them from the scanning cache
for spend in &tx.sapling_spends {
wallet::sapling::mark_sapling_note_spent(&wdb.conn.0, tx_row, spend.nf())?;
wallet::sapling::mark_sapling_note_spent(wdb.conn.0, tx_row, spend.nf())?;
}
for output in &tx.sapling_outputs {
let received_note_id =
wallet::sapling::put_received_note(&wdb.conn.0, output, tx_row)?;
wallet::sapling::put_received_note(wdb.conn.0, output, tx_row)?;
// Save witness for note.
wallet_note_ids.push(received_note_id);
@ -435,7 +436,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
})?;
// Update now-expired transactions that didn't get mined.
wallet::update_expired_notes(&wdb.conn.0, block_height)?;
wallet::update_expired_notes(wdb.conn.0, block_height)?;
Ok(wallet_note_ids)
})
@ -446,7 +447,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
d_tx: DecryptedTransaction,
) -> Result<Self::TxRef, Self::Error> {
self.transactionally(|wdb| {
let tx_ref = wallet::put_tx_data(&wdb.conn.0, d_tx.tx, None, None)?;
let tx_ref = wallet::put_tx_data(wdb.conn.0, d_tx.tx, None, None)?;
let mut spending_account_id: Option<AccountId> = None;
for output in d_tx.sapling_outputs {
@ -459,7 +460,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
};
wallet::put_sent_output(
&wdb.conn.0,
wdb.conn.0,
&wdb.params,
output.account,
tx_ref,
@ -474,7 +475,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
)?;
if matches!(recipient, Recipient::InternalAccount(_, _)) {
wallet::sapling::put_received_note(&wdb.conn.0, output, tx_ref)?;
wallet::sapling::put_received_note(wdb.conn.0, output, tx_ref)?;
}
}
TransferType::Incoming => {
@ -489,14 +490,14 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
}
}
wallet::sapling::put_received_note(&wdb.conn.0, output, tx_ref)?;
wallet::sapling::put_received_note(wdb.conn.0, output, tx_ref)?;
}
}
// If any of the utxos spent in the transaction are ours, mark them as spent.
#[cfg(feature = "transparent-inputs")]
for txin in d_tx.tx.transparent_bundle().iter().flat_map(|b| b.vin.iter()) {
wallet::mark_transparent_utxo_spent(&wdb.conn.0, tx_ref, &txin.prevout)?;
wallet::mark_transparent_utxo_spent(wdb.conn.0, tx_ref, &txin.prevout)?;
}
// If we have some transparent outputs:
@ -513,7 +514,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
for (output_index, txout) in d_tx.tx.transparent_bundle().iter().flat_map(|b| b.vout.iter()).enumerate() {
if let Some(address) = txout.recipient_address() {
wallet::put_sent_output(
&wdb.conn.0,
wdb.conn.0,
&wdb.params,
*account_id,
tx_ref,
@ -534,7 +535,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
fn store_sent_tx(&mut self, sent_tx: &SentTransaction) -> Result<Self::TxRef, Self::Error> {
self.transactionally(|wdb| {
let tx_ref = wallet::put_tx_data(
&wdb.conn.0,
wdb.conn.0,
sent_tx.tx,
Some(sent_tx.fee_amount),
Some(sent_tx.created),
@ -551,7 +552,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
if let Some(bundle) = sent_tx.tx.sapling_bundle() {
for spend in bundle.shielded_spends() {
wallet::sapling::mark_sapling_note_spent(
&wdb.conn.0,
wdb.conn.0,
tx_ref,
spend.nullifier(),
)?;
@ -560,12 +561,12 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
#[cfg(feature = "transparent-inputs")]
for utxo_outpoint in &sent_tx.utxos_spent {
wallet::mark_transparent_utxo_spent(&wdb.conn.0, tx_ref, utxo_outpoint)?;
wallet::mark_transparent_utxo_spent(wdb.conn.0, tx_ref, utxo_outpoint)?;
}
for output in &sent_tx.outputs {
wallet::insert_sent_output(
&wdb.conn.0,
wdb.conn.0,
&wdb.params,
tx_ref,
sent_tx.account,
@ -574,7 +575,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
if let Some((account, note)) = output.sapling_change_to() {
wallet::sapling::put_received_note(
&wdb.conn.0,
wdb.conn.0,
&DecryptedOutput {
index: output.output_index(),
note: note.clone(),
@ -596,7 +597,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
fn truncate_to_height(&mut self, block_height: BlockHeight) -> Result<(), Self::Error> {
self.transactionally(|wdb| {
wallet::truncate_to_height(&wdb.conn.0, &wdb.params, block_height)
wallet::truncate_to_height(wdb.conn.0, &wdb.params, block_height)
})
}
@ -661,7 +662,7 @@ impl<'conn, P: consensus::Parameters> WalletCommitmentTrees for WalletDb<SqlTran
E: From<ShardTreeError<Either<io::Error, rusqlite::Error>>>,
{
let mut shardtree = ShardTree::new(
WalletDbSaplingShardStore::from_connection(&self.conn.0)
WalletDbSaplingShardStore::from_connection(self.conn.0)
.map_err(|e| ShardTreeError::Storage(Either::Right(e)))?,
100,
);

View File

@ -89,7 +89,9 @@ use zcash_client_backend::{
wallet::WalletTx,
};
use crate::{error::SqliteClientError, PRUNING_HEIGHT};
use crate::{
error::SqliteClientError, SqlTransaction, WalletCommitmentTrees, WalletDb, PRUNING_HEIGHT,
};
#[cfg(feature = "transparent-inputs")]
use {
@ -637,7 +639,7 @@ pub(crate) fn get_min_unspent_height(
/// block, this function does nothing.
///
/// This should only be executed inside a transactional context.
pub(crate) fn truncate_to_height<P: consensus::Parameters>(
pub(crate) fn truncate_to_height<P: consensus::Parameters + Clone>(
conn: &rusqlite::Transaction,
params: &P,
block_height: BlockHeight,
@ -662,7 +664,16 @@ pub(crate) fn truncate_to_height<P: consensus::Parameters>(
// nothing to do if we're deleting back down to the max height
if block_height < last_scanned_height {
// Decrement witnesses.
// Truncate the note commitment trees
let mut wdb = WalletDb {
conn: SqlTransaction(conn),
params: params.clone(),
};
wdb.with_sapling_tree_mut(|tree| {
tree.truncate_removing_checkpoint(&block_height).map(|_| ())
})?;
// Remove any legacy Sapling witnesses
conn.execute(
"DELETE FROM sapling_witnesses WHERE block > ?",
[u32::from(block_height)],

View File

@ -237,7 +237,7 @@ pub fn init_accounts_table<P: consensus::Parameters>(
// Insert accounts atomically
for (account, key) in keys.iter() {
wallet::add_account(&wdb.conn.0, &wdb.params, *account, key)?;
wallet::add_account(wdb.conn.0, &wdb.params, *account, key)?;
}
Ok(())
@ -394,6 +394,9 @@ mod tests {
CONSTRAINT tx_output UNIQUE (tx, output_index)
)",
"CREATE TABLE sapling_tree_cap (
-- cap_id exists only to be able to take advantage of `ON CONFLICT`
-- upsert functionality; the table will only ever contain one row
cap_id INTEGER PRIMARY KEY,
cap_data BLOB NOT NULL
)",
"CREATE TABLE sapling_tree_checkpoint_marks_removed (

View File

@ -69,6 +69,9 @@ impl RusqliteMigration for Migration {
CONSTRAINT root_unique UNIQUE (root_hash)
);
CREATE TABLE sapling_tree_cap (
-- cap_id exists only to be able to take advantage of `ON CONFLICT`
-- upsert functionality; the table will only ever contain one row
cap_id INTEGER PRIMARY KEY,
cap_data BLOB NOT NULL
);",
)?;

View File

@ -1,10 +1,14 @@
use either::Either;
use incrementalmerkletree::Address;
use rusqlite::{self, named_params, OptionalExtension};
use shardtree::{Checkpoint, LocatedPrunableTree, PrunableTree, ShardStore};
use incrementalmerkletree::{Address, Position};
use rusqlite::{self, named_params, Connection, OptionalExtension};
use shardtree::{Checkpoint, LocatedPrunableTree, PrunableTree, ShardStore, TreeState};
use std::io::{self, Cursor};
use std::{
collections::BTreeSet,
io::{self, Cursor},
ops::Deref,
};
use zcash_primitives::{consensus::BlockHeight, merkle_tree::HashSer, sapling};
@ -48,16 +52,16 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
todo!()
}
fn truncate(&mut self, _from: Address) -> Result<(), Self::Error> {
todo!()
fn truncate(&mut self, from: Address) -> Result<(), Self::Error> {
truncate(self.conn, from)
}
fn get_cap(&self) -> Result<PrunableTree<Self::H>, Self::Error> {
todo!()
get_cap(self.conn)
}
fn put_cap(&mut self, _cap: PrunableTree<Self::H>) -> Result<(), Self::Error> {
todo!()
fn put_cap(&mut self, cap: PrunableTree<Self::H>) -> Result<(), Self::Error> {
put_cap(self.conn, cap)
}
fn min_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
@ -89,9 +93,9 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
fn get_checkpoint(
&self,
_checkpoint_id: &Self::CheckpointId,
checkpoint_id: &Self::CheckpointId,
) -> Result<Option<Checkpoint>, Self::Error> {
todo!()
get_checkpoint(self.conn, *checkpoint_id)
}
fn with_checkpoints<F>(&mut self, _limit: usize, _callback: F) -> Result<(), Self::Error>
@ -103,27 +107,24 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
fn update_checkpoint_with<F>(
&mut self,
_checkpoint_id: &Self::CheckpointId,
_update: F,
checkpoint_id: &Self::CheckpointId,
update: F,
) -> Result<bool, Self::Error>
where
F: Fn(&mut Checkpoint) -> Result<(), Self::Error>,
{
todo!()
update_checkpoint_with(self.conn, *checkpoint_id, update)
}
fn remove_checkpoint(
&mut self,
_checkpoint_id: &Self::CheckpointId,
) -> Result<(), Self::Error> {
todo!()
fn remove_checkpoint(&mut self, checkpoint_id: &Self::CheckpointId) -> Result<(), Self::Error> {
remove_checkpoint(self.conn, *checkpoint_id)
}
fn truncate_checkpoints(
&mut self,
_checkpoint_id: &Self::CheckpointId,
checkpoint_id: &Self::CheckpointId,
) -> Result<(), Self::Error> {
todo!()
truncate_checkpoints(self.conn, *checkpoint_id)
}
}
@ -134,7 +135,7 @@ pub(crate) fn get_shard(
shard_root: Address,
) -> Result<Option<LocatedPrunableTree<sapling::Node>>, Error> {
conn.query_row(
"SELECT shard_data
"SELECT shard_data
FROM sapling_tree_shards
WHERE shard_index = :shard_index",
named_params![":shard_index": shard_root.index()],
@ -188,6 +189,47 @@ pub(crate) fn put_shard(
Ok(())
}
pub(crate) fn truncate(conn: &rusqlite::Transaction<'_>, from: Address) -> Result<(), Error> {
conn.execute(
"DELETE FROM sapling_tree_shards WHERE shard_index >= ?",
[from.index()],
)
.map_err(Either::Right)
.map(|_| ())
}
pub(crate) fn get_cap(conn: &rusqlite::Connection) -> Result<PrunableTree<sapling::Node>, Error> {
conn.query_row("SELECT cap_data FROM sapling_tree_cap", [], |row| {
row.get::<_, Vec<u8>>(0)
})
.optional()
.map_err(Either::Right)?
.map_or_else(
|| Ok(PrunableTree::empty()),
|cap_data| read_shard(&mut Cursor::new(cap_data)).map_err(Either::Left),
)
}
pub(crate) fn put_cap(
conn: &rusqlite::Transaction<'_>,
cap: PrunableTree<sapling::Node>,
) -> Result<(), Error> {
let mut stmt = conn
.prepare_cached(
"INSERT INTO sapling_tree_cap (cap_id, cap_data)
VALUES (0, :cap_data)
ON CONFLICT (cap_id) DO UPDATE
SET cap_data = :cap_data",
)
.map_err(Either::Right)?;
let mut cap_data = vec![];
write_shard_v1(&mut cap_data, &cap).map_err(Either::Left)?;
stmt.execute([cap_data]).map_err(Either::Right)?;
Ok(())
}
pub(crate) fn add_checkpoint(
conn: &rusqlite::Transaction<'_>,
checkpoint_id: BlockHeight,
@ -214,3 +256,98 @@ pub(crate) fn checkpoint_count(conn: &rusqlite::Connection) -> Result<usize, Err
})
.map_err(Either::Right)
}
pub(crate) fn get_checkpoint<C: Deref<Target = Connection>>(
conn: &C,
checkpoint_id: BlockHeight,
) -> Result<Option<Checkpoint>, Either<io::Error, rusqlite::Error>> {
let checkpoint_position = conn
.query_row(
"SELECT position
FROM sapling_tree_checkpoints
WHERE checkpoint_id = ?",
[u32::from(checkpoint_id)],
|row| {
row.get::<_, Option<u64>>(0)
.map(|opt| opt.map(Position::from))
},
)
.optional()
.map_err(Either::Right)?;
let mut marks_removed = BTreeSet::new();
let mut stmt = conn
.prepare_cached(
"SELECT mark_removed_position
FROM sapling_tree_checkpoint_marks_removed
WHERE checkpoint_id = ?",
)
.map_err(Either::Right)?;
let mut mark_removed_rows = stmt
.query([u32::from(checkpoint_id)])
.map_err(Either::Right)?;
while let Some(row) = mark_removed_rows.next().map_err(Either::Right)? {
marks_removed.insert(
row.get::<_, u64>(0)
.map(Position::from)
.map_err(Either::Right)?,
);
}
Ok(checkpoint_position.map(|pos_opt| {
Checkpoint::from_parts(
pos_opt.map_or(TreeState::Empty, TreeState::AtPosition),
marks_removed,
)
}))
}
pub(crate) fn update_checkpoint_with<F>(
conn: &rusqlite::Transaction<'_>,
checkpoint_id: BlockHeight,
update: F,
) -> Result<bool, Error>
where
F: Fn(&mut Checkpoint) -> Result<(), Error>,
{
if let Some(mut c) = get_checkpoint(conn, checkpoint_id)? {
update(&mut c)?;
remove_checkpoint(conn, checkpoint_id)?;
add_checkpoint(conn, checkpoint_id, c)?;
Ok(true)
} else {
Ok(false)
}
}
pub(crate) fn remove_checkpoint(
conn: &rusqlite::Transaction<'_>,
checkpoint_id: BlockHeight,
) -> Result<(), Error> {
conn.execute(
"DELETE FROM sapling_tree_checkpoints WHERE checkpoint_id = ?",
[u32::from(checkpoint_id)],
)
.map_err(Either::Right)?;
Ok(())
}
pub(crate) fn truncate_checkpoints(
conn: &rusqlite::Transaction<'_>,
checkpoint_id: BlockHeight,
) -> Result<(), Error> {
conn.execute(
"DELETE FROM sapling_tree_checkpoints WHERE checkpoint_id >= ?",
[u32::from(checkpoint_id)],
)
.map_err(Either::Right)?;
conn.execute(
"DELETE FROM sapling_tree_checkpoint_marks_removed WHERE checkpoint_id >= ?",
[u32::from(checkpoint_id)],
)
.map_err(Either::Right)?;
Ok(())
}