zcash_client_sqlite: Support shardtree checkpoint functionality

This commit is contained in:
Kris Nuttycombe 2023-06-14 16:49:16 -06:00
parent c42cffeb1d
commit 425b5e01d7
6 changed files with 270 additions and 105 deletions

View File

@ -23,12 +23,12 @@ pub mod migrations;
/// Implements a traversal of `limit` blocks of the block cache database.
///
/// Starting at the next block above `last_scanned_height`, the `with_row` callback is invoked with
/// each block retrieved from the backing store. If the `limit` value provided is `None`, all
/// blocks are traversed up to the maximum height.
/// Starting at `from_height`, the `with_row` callback is invoked with each block retrieved from
/// the backing store. If the `limit` value provided is `None`, all blocks are traversed up to the
/// maximum height.
pub(crate) fn blockdb_with_blocks<F, DbErrT, NoteRef>(
block_source: &BlockDb,
last_scanned_height: Option<BlockHeight>,
from_height: Option<BlockHeight>,
limit: Option<u32>,
mut with_row: F,
) -> Result<(), Error<DbErrT, SqliteClientError, NoteRef>>
@ -44,14 +44,14 @@ where
.0
.prepare(
"SELECT height, data FROM compactblocks
WHERE height > ?
WHERE height >= ?
ORDER BY height ASC LIMIT ?",
)
.map_err(to_chain_error)?;
let mut rows = stmt_blocks
.query(params![
last_scanned_height.map_or(0u32, u32::from),
from_height.map_or(0u32, u32::from),
limit.unwrap_or(u32::max_value()),
])
.map_err(to_chain_error)?;
@ -191,13 +191,13 @@ pub(crate) fn blockmetadb_find_block(
/// Implements a traversal of `limit` blocks of the filesystem-backed
/// block cache.
///
/// Starting at the next block height above `last_scanned_height`, the `with_row` callback is
/// invoked with each block retrieved from the backing store. If the `limit` value provided is
/// `None`, all blocks are traversed up to the maximum height for which metadata is available.
/// Starting at `from_height`, the `with_row` callback is invoked with each block retrieved from
/// the backing store. If the `limit` value provided is `None`, all blocks are traversed up to the
/// maximum height for which metadata is available.
#[cfg(feature = "unstable")]
pub(crate) fn fsblockdb_with_blocks<F, DbErrT, NoteRef>(
cache: &FsBlockDb,
last_scanned_height: Option<BlockHeight>,
from_height: Option<BlockHeight>,
limit: Option<u32>,
mut with_block: F,
) -> Result<(), Error<DbErrT, FsBlockDbError, NoteRef>>
@ -214,7 +214,7 @@ where
.prepare(
"SELECT height, blockhash, time, sapling_outputs_count, orchard_actions_count
FROM compactblocks_meta
WHERE height > ?
WHERE height >= ?
ORDER BY height ASC LIMIT ?",
)
.map_err(to_chain_error)?;
@ -222,7 +222,7 @@ where
let rows = stmt_blocks
.query_map(
params![
last_scanned_height.map_or(0u32, u32::from),
from_height.map_or(0u32, u32::from),
limit.unwrap_or(u32::max_value()),
],
|row| {
@ -269,14 +269,22 @@ mod tests {
use tempfile::NamedTempFile;
use zcash_primitives::{
block::BlockHash, transaction::components::Amount, zip32::ExtendedSpendingKey,
block::BlockHash,
transaction::{components::Amount, fees::zip317::FeeRule},
zip32::ExtendedSpendingKey,
};
use zcash_client_backend::data_api::chain::{
error::{Cause, Error},
scan_cached_blocks, validate_chain,
use zcash_client_backend::{
address::RecipientAddress,
data_api::{
chain::{error::Error, scan_cached_blocks, validate_chain},
wallet::{input_selection::GreedyInputSelector, spend},
WalletRead, WalletWrite,
},
fees::{zip317::SingleOutputChangeStrategy, DustOutputPolicy},
wallet::OvkPolicy,
zip321::{Payment, TransactionRequest},
};
use zcash_client_backend::data_api::WalletRead;
use crate::{
chain::init::init_cache_database,
@ -573,7 +581,7 @@ mod tests {
}
#[test]
fn scan_cached_blocks_requires_sequential_blocks() {
fn scan_cached_blocks_allows_blocks_out_of_order() {
let cache_file = NamedTempFile::new().unwrap();
let db_cache = BlockDb::for_path(cache_file.path()).unwrap();
init_cache_database(&db_cache).unwrap();
@ -583,7 +591,9 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet
let (dfvk, _taddr) = init_test_accounts_table(&mut db_data);
let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
// Create a block with height SAPLING_ACTIVATION_HEIGHT
let value = Amount::from_u64(50000).unwrap();
@ -602,7 +612,7 @@ mod tests {
value
);
// We cannot scan a block of height SAPLING_ACTIVATION_HEIGHT + 2 next
// Create blocks to reach SAPLING_ACTIVATION_HEIGHT + 2
let (cb2, _) = fake_compact_block(
sapling_activation_height() + 1,
cb1.hash(),
@ -619,25 +629,62 @@ mod tests {
value,
2,
);
insert_into_cache(&db_cache, &cb3);
match scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None, None) {
Err(Error::Chain(e)) => {
assert_matches!(
e.cause(),
Cause::BlockHeightDiscontinuity(h) if *h
== sapling_activation_height() + 2
);
}
Ok(_) | Err(_) => panic!("Should have failed"),
}
// If we add a block of height SAPLING_ACTIVATION_HEIGHT + 1, we can now scan both
// Scan the later block first
insert_into_cache(&db_cache, &cb3);
assert_matches!(
scan_cached_blocks(
&tests::network(),
&db_cache,
&mut db_data,
Some(sapling_activation_height() + 2),
None
),
Ok(_)
);
// If we add a block of height SAPLING_ACTIVATION_HEIGHT + 1, we can now scan that
insert_into_cache(&db_cache, &cb2);
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None, None).unwrap();
scan_cached_blocks(
&tests::network(),
&db_cache,
&mut db_data,
Some(sapling_activation_height() + 1),
Some(1),
)
.unwrap();
assert_eq!(
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
Amount::from_u64(150_000).unwrap()
);
// We can spend the received notes
let req = TransactionRequest::new(vec![Payment {
recipient_address: RecipientAddress::Shielded(dfvk.default_address().1),
amount: Amount::from_u64(110_000).unwrap(),
memo: None,
label: None,
message: None,
other_params: vec![],
}])
.unwrap();
let input_selector = GreedyInputSelector::new(
SingleOutputChangeStrategy::new(FeeRule::standard()),
DustOutputPolicy::default(),
);
assert_matches!(
spend(
&mut db_data,
&tests::network(),
crate::wallet::sapling::tests::test_prover(),
&input_selector,
&usk,
req,
OvkPolicy::Sender,
1,
),
Ok(_)
);
}
#[test]

View File

@ -399,7 +399,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
self.transactionally(|wdb| {
// Insert the block into the database.
let block_height = block.block_height;
wallet::insert_block(
wallet::put_block(
wdb.conn.0,
block_height,
block.block_hash,

View File

@ -64,7 +64,7 @@
//! wallet.
//! - `memo` the shielded memo associated with the output, if any.
use rusqlite::{self, named_params, params, OptionalExtension, ToSql};
use rusqlite::{self, named_params, OptionalExtension, ToSql};
use std::collections::HashMap;
use std::convert::TryFrom;
use std::io::Cursor;
@ -735,15 +735,18 @@ pub(crate) fn get_unspent_transparent_outputs<P: consensus::Parameters>(
FROM utxos u
LEFT OUTER JOIN transactions tx
ON tx.id_tx = u.spent_in_tx
WHERE u.address = ?
AND u.height <= ?
WHERE u.address = :address
AND u.height <= :max_height
AND tx.block IS NULL",
)?;
let addr_str = address.encode(params);
let mut utxos = Vec::<WalletTransparentOutput>::new();
let mut rows = stmt_blocks.query(params![addr_str, u32::from(max_height)])?;
let mut rows = stmt_blocks.query(named_params![
":address": addr_str,
":max_height": u32::from(max_height)
])?;
let excluded: BTreeSet<OutPoint> = exclude.iter().cloned().collect();
while let Some(row) = rows.next()? {
let txid: Vec<u8> = row.get(0)?;
@ -796,14 +799,17 @@ pub(crate) fn get_transparent_balances<P: consensus::Parameters>(
FROM utxos u
LEFT OUTER JOIN transactions tx
ON tx.id_tx = u.spent_in_tx
WHERE u.received_by_account = ?
AND u.height <= ?
WHERE u.received_by_account = :account_id
AND u.height <= :max_height
AND tx.block IS NULL
GROUP BY u.address",
)?;
let mut res = HashMap::new();
let mut rows = stmt_blocks.query(params![u32::from(account), u32::from(max_height)])?;
let mut rows = stmt_blocks.query(named_params![
":account_id": u32::from(account),
":max_height": u32::from(max_height)
])?;
while let Some(row) = rows.next()? {
let taddr_str: String = row.get(0)?;
let taddr = TransparentAddress::decode(params, &taddr_str)?;
@ -816,14 +822,14 @@ pub(crate) fn get_transparent_balances<P: consensus::Parameters>(
}
/// Inserts information about a scanned block into the database.
pub(crate) fn insert_block(
pub(crate) fn put_block(
conn: &rusqlite::Connection,
block_height: BlockHeight,
block_hash: BlockHash,
block_time: u32,
sapling_commitment_tree_size: Option<u64>,
) -> Result<(), SqliteClientError> {
let mut stmt_insert_block = conn.prepare_cached(
let mut stmt_upsert_block = conn.prepare_cached(
"INSERT INTO blocks (
height,
hash,
@ -831,14 +837,24 @@ pub(crate) fn insert_block(
sapling_commitment_tree_size,
sapling_tree
)
VALUES (?, ?, ?, ?, x'00')",
VALUES (
:height,
:hash,
:block_time,
:sapling_commitment_tree_size,
x'00'
)
ON CONFLICT (height) DO UPDATE
SET hash = :hash,
time = :block_time,
sapling_commitment_tree_size = :sapling_commitment_tree_size",
)?;
stmt_insert_block.execute(params![
u32::from(block_height),
&block_hash.0[..],
block_time,
sapling_commitment_tree_size
stmt_upsert_block.execute(named_params![
":height": u32::from(block_height),
":hash": &block_hash.0[..],
":block_time": block_time,
":sapling_commitment_tree_size": sapling_commitment_tree_size
])?;
Ok(())

View File

@ -11,6 +11,7 @@ use schemer_rusqlite::RusqliteMigration;
use shardtree::ShardTree;
use uuid::Uuid;
use zcash_client_backend::data_api::SAPLING_SHARD_HEIGHT;
use zcash_primitives::{
consensus::BlockHeight,
merkle_tree::{read_commitment_tree, read_incremental_witness},
@ -93,7 +94,7 @@ impl RusqliteMigration for Migration {
let mut shard_tree: ShardTree<
_,
{ sapling::NOTE_COMMITMENT_TREE_DEPTH },
{ sapling::NOTE_COMMITMENT_TREE_DEPTH / 2 },
SAPLING_SHARD_HEIGHT,
> = ShardTree::new(shard_store, 100);
// Insert all the tree information that we can get from block-end commitment trees
{

View File

@ -368,7 +368,7 @@ pub(crate) fn put_received_note<T: ReceivedSaplingOutput>(
#[cfg(test)]
#[allow(deprecated)]
mod tests {
pub(crate) mod tests {
use rusqlite::Connection;
use secrecy::Secret;
use tempfile::NamedTempFile;
@ -427,7 +427,7 @@ mod tests {
},
};
fn test_prover() -> impl TxProver {
pub fn test_prover() -> impl TxProver {
match LocalTxProver::with_default_location() {
Some(tx_prover) => tx_prover,
None => {
@ -463,7 +463,7 @@ mod tests {
Amount::from_u64(1).unwrap(),
None,
OvkPolicy::Sender,
10,
1,
),
Err(data_api::error::Error::KeyNotRecognized)
);
@ -492,7 +492,7 @@ mod tests {
Amount::from_u64(1).unwrap(),
None,
OvkPolicy::Sender,
10,
1,
),
Err(data_api::error::Error::ScanRequired)
);
@ -535,7 +535,7 @@ mod tests {
Amount::from_u64(1).unwrap(),
None,
OvkPolicy::Sender,
10,
1,
),
Err(data_api::error::Error::InsufficientFunds {
available,
@ -740,7 +740,7 @@ mod tests {
Amount::from_u64(15000).unwrap(),
None,
OvkPolicy::Sender,
10,
1,
),
Ok(_)
);
@ -756,7 +756,7 @@ mod tests {
Amount::from_u64(2000).unwrap(),
None,
OvkPolicy::Sender,
10,
1,
),
Err(data_api::error::Error::InsufficientFunds {
available,
@ -791,7 +791,7 @@ mod tests {
Amount::from_u64(2000).unwrap(),
None,
OvkPolicy::Sender,
10,
1,
),
Err(data_api::error::Error::InsufficientFunds {
available,
@ -822,7 +822,7 @@ mod tests {
Amount::from_u64(2000).unwrap(),
None,
OvkPolicy::Sender,
10,
1,
)
.unwrap();
}
@ -874,7 +874,7 @@ mod tests {
Amount::from_u64(15000).unwrap(),
None,
ovk_policy,
10,
1,
)
.unwrap();
@ -962,7 +962,7 @@ mod tests {
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None, None).unwrap();
// Verified balance matches total balance
let (_, anchor_height) = db_data.get_target_and_anchor_heights(10).unwrap().unwrap();
let (_, anchor_height) = db_data.get_target_and_anchor_heights(1).unwrap().unwrap();
assert_eq!(
get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value
@ -983,7 +983,7 @@ mod tests {
Amount::from_u64(50000).unwrap(),
None,
OvkPolicy::Sender,
10,
1,
),
Ok(_)
);
@ -1039,7 +1039,7 @@ mod tests {
Amount::from_u64(50000).unwrap(),
None,
OvkPolicy::Sender,
10,
1,
),
Ok(_)
);
@ -1193,7 +1193,7 @@ mod tests {
DustOutputPolicy::default(),
);
// Add funds to the wallet
// Ensure that the wallet has at least one block
let (cb, _) = fake_compact_block(
sapling_activation_height(),
BlockHash([0; 32]),
@ -1215,7 +1215,7 @@ mod tests {
&usk,
&[*taddr],
&MemoBytes::empty(),
0
1
),
Ok(_)
);

View File

@ -1,19 +1,22 @@
use either::Either;
use incrementalmerkletree::{Address, Position};
use rusqlite::{self, named_params, Connection, OptionalExtension};
use shardtree::{Checkpoint, LocatedPrunableTree, PrunableTree, ShardStore, TreeState};
use std::{
collections::BTreeSet,
io::{self, Cursor},
ops::Deref,
};
use incrementalmerkletree::{Address, Level, Position};
use shardtree::{Checkpoint, LocatedPrunableTree, PrunableTree, ShardStore, TreeState};
use zcash_primitives::{consensus::BlockHeight, merkle_tree::HashSer, sapling};
use zcash_client_backend::data_api::SAPLING_SHARD_HEIGHT;
use crate::serialization::{read_shard, write_shard_v1};
const SHARD_ROOT_LEVEL: Level = Level::new(SAPLING_SHARD_HEIGHT);
pub struct WalletDbSaplingShardStore<'conn, 'a> {
pub(crate) conn: &'a rusqlite::Transaction<'conn>,
}
@ -39,8 +42,7 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
}
fn last_shard(&self) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> {
// SELECT shard_data FROM sapling_tree ORDER BY shard_index DESC LIMIT 1
todo!()
last_shard(self.conn)
}
fn put_shard(&mut self, subtree: LocatedPrunableTree<Self::H>) -> Result<(), Self::Error> {
@ -48,8 +50,7 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
}
fn get_shard_roots(&self) -> Result<Vec<Address>, Self::Error> {
// SELECT
todo!()
get_shard_roots(self.conn)
}
fn truncate(&mut self, from: Address) -> Result<(), Self::Error> {
@ -86,9 +87,9 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
fn get_checkpoint_at_depth(
&self,
_checkpoint_depth: usize,
checkpoint_depth: usize,
) -> Result<Option<(Self::CheckpointId, Checkpoint)>, Self::Error> {
todo!()
get_checkpoint_at_depth(self.conn, checkpoint_depth)
}
fn get_checkpoint(
@ -150,6 +151,31 @@ pub(crate) fn get_shard(
.transpose()
}
pub(crate) fn last_shard(
conn: &rusqlite::Connection,
) -> Result<Option<LocatedPrunableTree<sapling::Node>>, Error> {
conn.query_row(
"SELECT shard_index, shard_data
FROM sapling_tree_shards
ORDER BY shard_index DESC
LIMIT 1",
[],
|row| {
let shard_index: u64 = row.get(0)?;
let shard_data: Vec<u8> = row.get(1)?;
Ok((shard_index, shard_data))
},
)
.optional()
.map_err(Either::Right)?
.map(|(shard_index, shard_data)| {
let shard_root = Address::from_parts(SHARD_ROOT_LEVEL, shard_index);
let shard_tree = read_shard(&mut Cursor::new(shard_data)).map_err(Either::Left)?;
Ok(LocatedPrunableTree::from_parts(shard_root, shard_tree))
})
.transpose()
}
pub(crate) fn put_shard(
conn: &rusqlite::Connection,
subtree: LocatedPrunableTree<sapling::Node>,
@ -189,6 +215,22 @@ pub(crate) fn put_shard(
Ok(())
}
pub(crate) fn get_shard_roots(conn: &rusqlite::Connection) -> Result<Vec<Address>, Error> {
let mut stmt = conn
.prepare("SELECT shard_index FROM sapling_tree_shards ORDER BY shard_index")
.map_err(Either::Right)?;
let mut rows = stmt.query([]).map_err(Either::Right)?;
let mut res = vec![];
while let Some(row) = rows.next().map_err(Either::Right)? {
res.push(Address::from_parts(
SHARD_ROOT_LEVEL,
row.get(0).map_err(Either::Right)?,
));
}
Ok(res)
}
pub(crate) fn truncate(conn: &rusqlite::Transaction<'_>, from: Address) -> Result<(), Error> {
conn.execute(
"DELETE FROM sapling_tree_shards WHERE shard_index >= ?",
@ -275,6 +317,8 @@ pub(crate) fn get_checkpoint<C: Deref<Target = Connection>>(
.optional()
.map_err(Either::Right)?;
checkpoint_position
.map(|pos_opt| {
let mut marks_removed = BTreeSet::new();
let mut stmt = conn
.prepare_cached(
@ -295,12 +339,69 @@ pub(crate) fn get_checkpoint<C: Deref<Target = Connection>>(
);
}
Ok(checkpoint_position.map(|pos_opt| {
Ok(Checkpoint::from_parts(
pos_opt.map_or(TreeState::Empty, TreeState::AtPosition),
marks_removed,
))
})
.transpose()
}
pub(crate) fn get_checkpoint_at_depth<C: Deref<Target = Connection>>(
conn: &C,
checkpoint_depth: usize,
) -> Result<Option<(BlockHeight, Checkpoint)>, Either<io::Error, rusqlite::Error>> {
let checkpoint_parts = conn
.query_row(
"SELECT checkpoint_id, position
FROM sapling_tree_checkpoints
ORDER BY checkpoint_id DESC
LIMIT 1
OFFSET :offset",
named_params![":offset": checkpoint_depth],
|row| {
let checkpoint_id: u32 = row.get(0)?;
let position: Option<u64> = row.get(1)?;
Ok((
BlockHeight::from(checkpoint_id),
position.map(Position::from),
))
},
)
.optional()
.map_err(Either::Right)?;
checkpoint_parts
.map(|(checkpoint_id, pos_opt)| {
let mut marks_removed = BTreeSet::new();
let mut stmt = conn
.prepare_cached(
"SELECT mark_removed_position
FROM sapling_tree_checkpoint_marks_removed
WHERE checkpoint_id = ?",
)
.map_err(Either::Right)?;
let mut mark_removed_rows = stmt
.query([u32::from(checkpoint_id)])
.map_err(Either::Right)?;
while let Some(row) = mark_removed_rows.next().map_err(Either::Right)? {
marks_removed.insert(
row.get::<_, u64>(0)
.map(Position::from)
.map_err(Either::Right)?,
);
}
Ok((
checkpoint_id,
Checkpoint::from_parts(
pos_opt.map_or(TreeState::Empty, TreeState::AtPosition),
marks_removed,
)
}))
),
))
})
.transpose()
}
pub(crate) fn update_checkpoint_with<F>(