zcash_client_sqlite: Support shardtree checkpoint functionality

This commit is contained in:
Kris Nuttycombe 2023-06-14 16:49:16 -06:00
parent c42cffeb1d
commit 425b5e01d7
6 changed files with 270 additions and 105 deletions

View File

@ -23,12 +23,12 @@ pub mod migrations;
/// Implements a traversal of `limit` blocks of the block cache database. /// Implements a traversal of `limit` blocks of the block cache database.
/// ///
/// Starting at the next block above `last_scanned_height`, the `with_row` callback is invoked with /// Starting at `from_height`, the `with_row` callback is invoked with each block retrieved from
/// each block retrieved from the backing store. If the `limit` value provided is `None`, all /// the backing store. If the `limit` value provided is `None`, all blocks are traversed up to the
/// blocks are traversed up to the maximum height. /// maximum height.
pub(crate) fn blockdb_with_blocks<F, DbErrT, NoteRef>( pub(crate) fn blockdb_with_blocks<F, DbErrT, NoteRef>(
block_source: &BlockDb, block_source: &BlockDb,
last_scanned_height: Option<BlockHeight>, from_height: Option<BlockHeight>,
limit: Option<u32>, limit: Option<u32>,
mut with_row: F, mut with_row: F,
) -> Result<(), Error<DbErrT, SqliteClientError, NoteRef>> ) -> Result<(), Error<DbErrT, SqliteClientError, NoteRef>>
@ -44,14 +44,14 @@ where
.0 .0
.prepare( .prepare(
"SELECT height, data FROM compactblocks "SELECT height, data FROM compactblocks
WHERE height > ? WHERE height >= ?
ORDER BY height ASC LIMIT ?", ORDER BY height ASC LIMIT ?",
) )
.map_err(to_chain_error)?; .map_err(to_chain_error)?;
let mut rows = stmt_blocks let mut rows = stmt_blocks
.query(params![ .query(params![
last_scanned_height.map_or(0u32, u32::from), from_height.map_or(0u32, u32::from),
limit.unwrap_or(u32::max_value()), limit.unwrap_or(u32::max_value()),
]) ])
.map_err(to_chain_error)?; .map_err(to_chain_error)?;
@ -191,13 +191,13 @@ pub(crate) fn blockmetadb_find_block(
/// Implements a traversal of `limit` blocks of the filesystem-backed /// Implements a traversal of `limit` blocks of the filesystem-backed
/// block cache. /// block cache.
/// ///
/// Starting at the next block height above `last_scanned_height`, the `with_row` callback is /// Starting at `from_height`, the `with_row` callback is invoked with each block retrieved from
/// invoked with each block retrieved from the backing store. If the `limit` value provided is /// the backing store. If the `limit` value provided is `None`, all blocks are traversed up to the
/// `None`, all blocks are traversed up to the maximum height for which metadata is available. /// maximum height for which metadata is available.
#[cfg(feature = "unstable")] #[cfg(feature = "unstable")]
pub(crate) fn fsblockdb_with_blocks<F, DbErrT, NoteRef>( pub(crate) fn fsblockdb_with_blocks<F, DbErrT, NoteRef>(
cache: &FsBlockDb, cache: &FsBlockDb,
last_scanned_height: Option<BlockHeight>, from_height: Option<BlockHeight>,
limit: Option<u32>, limit: Option<u32>,
mut with_block: F, mut with_block: F,
) -> Result<(), Error<DbErrT, FsBlockDbError, NoteRef>> ) -> Result<(), Error<DbErrT, FsBlockDbError, NoteRef>>
@ -214,7 +214,7 @@ where
.prepare( .prepare(
"SELECT height, blockhash, time, sapling_outputs_count, orchard_actions_count "SELECT height, blockhash, time, sapling_outputs_count, orchard_actions_count
FROM compactblocks_meta FROM compactblocks_meta
WHERE height > ? WHERE height >= ?
ORDER BY height ASC LIMIT ?", ORDER BY height ASC LIMIT ?",
) )
.map_err(to_chain_error)?; .map_err(to_chain_error)?;
@ -222,7 +222,7 @@ where
let rows = stmt_blocks let rows = stmt_blocks
.query_map( .query_map(
params![ params![
last_scanned_height.map_or(0u32, u32::from), from_height.map_or(0u32, u32::from),
limit.unwrap_or(u32::max_value()), limit.unwrap_or(u32::max_value()),
], ],
|row| { |row| {
@ -269,14 +269,22 @@ mod tests {
use tempfile::NamedTempFile; use tempfile::NamedTempFile;
use zcash_primitives::{ use zcash_primitives::{
block::BlockHash, transaction::components::Amount, zip32::ExtendedSpendingKey, block::BlockHash,
transaction::{components::Amount, fees::zip317::FeeRule},
zip32::ExtendedSpendingKey,
}; };
use zcash_client_backend::data_api::chain::{ use zcash_client_backend::{
error::{Cause, Error}, address::RecipientAddress,
scan_cached_blocks, validate_chain, data_api::{
chain::{error::Error, scan_cached_blocks, validate_chain},
wallet::{input_selection::GreedyInputSelector, spend},
WalletRead, WalletWrite,
},
fees::{zip317::SingleOutputChangeStrategy, DustOutputPolicy},
wallet::OvkPolicy,
zip321::{Payment, TransactionRequest},
}; };
use zcash_client_backend::data_api::WalletRead;
use crate::{ use crate::{
chain::init::init_cache_database, chain::init::init_cache_database,
@ -573,7 +581,7 @@ mod tests {
} }
#[test] #[test]
fn scan_cached_blocks_requires_sequential_blocks() { fn scan_cached_blocks_allows_blocks_out_of_order() {
let cache_file = NamedTempFile::new().unwrap(); let cache_file = NamedTempFile::new().unwrap();
let db_cache = BlockDb::for_path(cache_file.path()).unwrap(); let db_cache = BlockDb::for_path(cache_file.path()).unwrap();
init_cache_database(&db_cache).unwrap(); init_cache_database(&db_cache).unwrap();
@ -583,7 +591,9 @@ mod tests {
init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap(); init_wallet_db(&mut db_data, Some(Secret::new(vec![]))).unwrap();
// Add an account to the wallet // Add an account to the wallet
let (dfvk, _taddr) = init_test_accounts_table(&mut db_data); let seed = Secret::new([0u8; 32].to_vec());
let (_, usk) = db_data.create_account(&seed).unwrap();
let dfvk = usk.sapling().to_diversifiable_full_viewing_key();
// Create a block with height SAPLING_ACTIVATION_HEIGHT // Create a block with height SAPLING_ACTIVATION_HEIGHT
let value = Amount::from_u64(50000).unwrap(); let value = Amount::from_u64(50000).unwrap();
@ -602,7 +612,7 @@ mod tests {
value value
); );
// We cannot scan a block of height SAPLING_ACTIVATION_HEIGHT + 2 next // Create blocks to reach SAPLING_ACTIVATION_HEIGHT + 2
let (cb2, _) = fake_compact_block( let (cb2, _) = fake_compact_block(
sapling_activation_height() + 1, sapling_activation_height() + 1,
cb1.hash(), cb1.hash(),
@ -619,25 +629,62 @@ mod tests {
value, value,
2, 2,
); );
insert_into_cache(&db_cache, &cb3);
match scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None, None) {
Err(Error::Chain(e)) => {
assert_matches!(
e.cause(),
Cause::BlockHeightDiscontinuity(h) if *h
== sapling_activation_height() + 2
);
}
Ok(_) | Err(_) => panic!("Should have failed"),
}
// If we add a block of height SAPLING_ACTIVATION_HEIGHT + 1, we can now scan both // Scan the later block first
insert_into_cache(&db_cache, &cb3);
assert_matches!(
scan_cached_blocks(
&tests::network(),
&db_cache,
&mut db_data,
Some(sapling_activation_height() + 2),
None
),
Ok(_)
);
// If we add a block of height SAPLING_ACTIVATION_HEIGHT + 1, we can now scan that
insert_into_cache(&db_cache, &cb2); insert_into_cache(&db_cache, &cb2);
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None, None).unwrap(); scan_cached_blocks(
&tests::network(),
&db_cache,
&mut db_data,
Some(sapling_activation_height() + 1),
Some(1),
)
.unwrap();
assert_eq!( assert_eq!(
get_balance(&db_data.conn, AccountId::from(0)).unwrap(), get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
Amount::from_u64(150_000).unwrap() Amount::from_u64(150_000).unwrap()
); );
// We can spend the received notes
let req = TransactionRequest::new(vec![Payment {
recipient_address: RecipientAddress::Shielded(dfvk.default_address().1),
amount: Amount::from_u64(110_000).unwrap(),
memo: None,
label: None,
message: None,
other_params: vec![],
}])
.unwrap();
let input_selector = GreedyInputSelector::new(
SingleOutputChangeStrategy::new(FeeRule::standard()),
DustOutputPolicy::default(),
);
assert_matches!(
spend(
&mut db_data,
&tests::network(),
crate::wallet::sapling::tests::test_prover(),
&input_selector,
&usk,
req,
OvkPolicy::Sender,
1,
),
Ok(_)
);
} }
#[test] #[test]

View File

@ -399,7 +399,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
self.transactionally(|wdb| { self.transactionally(|wdb| {
// Insert the block into the database. // Insert the block into the database.
let block_height = block.block_height; let block_height = block.block_height;
wallet::insert_block( wallet::put_block(
wdb.conn.0, wdb.conn.0,
block_height, block_height,
block.block_hash, block.block_hash,

View File

@ -64,7 +64,7 @@
//! wallet. //! wallet.
//! - `memo` the shielded memo associated with the output, if any. //! - `memo` the shielded memo associated with the output, if any.
use rusqlite::{self, named_params, params, OptionalExtension, ToSql}; use rusqlite::{self, named_params, OptionalExtension, ToSql};
use std::collections::HashMap; use std::collections::HashMap;
use std::convert::TryFrom; use std::convert::TryFrom;
use std::io::Cursor; use std::io::Cursor;
@ -735,15 +735,18 @@ pub(crate) fn get_unspent_transparent_outputs<P: consensus::Parameters>(
FROM utxos u FROM utxos u
LEFT OUTER JOIN transactions tx LEFT OUTER JOIN transactions tx
ON tx.id_tx = u.spent_in_tx ON tx.id_tx = u.spent_in_tx
WHERE u.address = ? WHERE u.address = :address
AND u.height <= ? AND u.height <= :max_height
AND tx.block IS NULL", AND tx.block IS NULL",
)?; )?;
let addr_str = address.encode(params); let addr_str = address.encode(params);
let mut utxos = Vec::<WalletTransparentOutput>::new(); let mut utxos = Vec::<WalletTransparentOutput>::new();
let mut rows = stmt_blocks.query(params![addr_str, u32::from(max_height)])?; let mut rows = stmt_blocks.query(named_params![
":address": addr_str,
":max_height": u32::from(max_height)
])?;
let excluded: BTreeSet<OutPoint> = exclude.iter().cloned().collect(); let excluded: BTreeSet<OutPoint> = exclude.iter().cloned().collect();
while let Some(row) = rows.next()? { while let Some(row) = rows.next()? {
let txid: Vec<u8> = row.get(0)?; let txid: Vec<u8> = row.get(0)?;
@ -796,14 +799,17 @@ pub(crate) fn get_transparent_balances<P: consensus::Parameters>(
FROM utxos u FROM utxos u
LEFT OUTER JOIN transactions tx LEFT OUTER JOIN transactions tx
ON tx.id_tx = u.spent_in_tx ON tx.id_tx = u.spent_in_tx
WHERE u.received_by_account = ? WHERE u.received_by_account = :account_id
AND u.height <= ? AND u.height <= :max_height
AND tx.block IS NULL AND tx.block IS NULL
GROUP BY u.address", GROUP BY u.address",
)?; )?;
let mut res = HashMap::new(); let mut res = HashMap::new();
let mut rows = stmt_blocks.query(params![u32::from(account), u32::from(max_height)])?; let mut rows = stmt_blocks.query(named_params![
":account_id": u32::from(account),
":max_height": u32::from(max_height)
])?;
while let Some(row) = rows.next()? { while let Some(row) = rows.next()? {
let taddr_str: String = row.get(0)?; let taddr_str: String = row.get(0)?;
let taddr = TransparentAddress::decode(params, &taddr_str)?; let taddr = TransparentAddress::decode(params, &taddr_str)?;
@ -816,14 +822,14 @@ pub(crate) fn get_transparent_balances<P: consensus::Parameters>(
} }
/// Inserts information about a scanned block into the database. /// Inserts information about a scanned block into the database.
pub(crate) fn insert_block( pub(crate) fn put_block(
conn: &rusqlite::Connection, conn: &rusqlite::Connection,
block_height: BlockHeight, block_height: BlockHeight,
block_hash: BlockHash, block_hash: BlockHash,
block_time: u32, block_time: u32,
sapling_commitment_tree_size: Option<u64>, sapling_commitment_tree_size: Option<u64>,
) -> Result<(), SqliteClientError> { ) -> Result<(), SqliteClientError> {
let mut stmt_insert_block = conn.prepare_cached( let mut stmt_upsert_block = conn.prepare_cached(
"INSERT INTO blocks ( "INSERT INTO blocks (
height, height,
hash, hash,
@ -831,14 +837,24 @@ pub(crate) fn insert_block(
sapling_commitment_tree_size, sapling_commitment_tree_size,
sapling_tree sapling_tree
) )
VALUES (?, ?, ?, ?, x'00')", VALUES (
:height,
:hash,
:block_time,
:sapling_commitment_tree_size,
x'00'
)
ON CONFLICT (height) DO UPDATE
SET hash = :hash,
time = :block_time,
sapling_commitment_tree_size = :sapling_commitment_tree_size",
)?; )?;
stmt_insert_block.execute(params![ stmt_upsert_block.execute(named_params![
u32::from(block_height), ":height": u32::from(block_height),
&block_hash.0[..], ":hash": &block_hash.0[..],
block_time, ":block_time": block_time,
sapling_commitment_tree_size ":sapling_commitment_tree_size": sapling_commitment_tree_size
])?; ])?;
Ok(()) Ok(())

View File

@ -11,6 +11,7 @@ use schemer_rusqlite::RusqliteMigration;
use shardtree::ShardTree; use shardtree::ShardTree;
use uuid::Uuid; use uuid::Uuid;
use zcash_client_backend::data_api::SAPLING_SHARD_HEIGHT;
use zcash_primitives::{ use zcash_primitives::{
consensus::BlockHeight, consensus::BlockHeight,
merkle_tree::{read_commitment_tree, read_incremental_witness}, merkle_tree::{read_commitment_tree, read_incremental_witness},
@ -93,7 +94,7 @@ impl RusqliteMigration for Migration {
let mut shard_tree: ShardTree< let mut shard_tree: ShardTree<
_, _,
{ sapling::NOTE_COMMITMENT_TREE_DEPTH }, { sapling::NOTE_COMMITMENT_TREE_DEPTH },
{ sapling::NOTE_COMMITMENT_TREE_DEPTH / 2 }, SAPLING_SHARD_HEIGHT,
> = ShardTree::new(shard_store, 100); > = ShardTree::new(shard_store, 100);
// Insert all the tree information that we can get from block-end commitment trees // Insert all the tree information that we can get from block-end commitment trees
{ {

View File

@ -368,7 +368,7 @@ pub(crate) fn put_received_note<T: ReceivedSaplingOutput>(
#[cfg(test)] #[cfg(test)]
#[allow(deprecated)] #[allow(deprecated)]
mod tests { pub(crate) mod tests {
use rusqlite::Connection; use rusqlite::Connection;
use secrecy::Secret; use secrecy::Secret;
use tempfile::NamedTempFile; use tempfile::NamedTempFile;
@ -427,7 +427,7 @@ mod tests {
}, },
}; };
fn test_prover() -> impl TxProver { pub fn test_prover() -> impl TxProver {
match LocalTxProver::with_default_location() { match LocalTxProver::with_default_location() {
Some(tx_prover) => tx_prover, Some(tx_prover) => tx_prover,
None => { None => {
@ -463,7 +463,7 @@ mod tests {
Amount::from_u64(1).unwrap(), Amount::from_u64(1).unwrap(),
None, None,
OvkPolicy::Sender, OvkPolicy::Sender,
10, 1,
), ),
Err(data_api::error::Error::KeyNotRecognized) Err(data_api::error::Error::KeyNotRecognized)
); );
@ -492,7 +492,7 @@ mod tests {
Amount::from_u64(1).unwrap(), Amount::from_u64(1).unwrap(),
None, None,
OvkPolicy::Sender, OvkPolicy::Sender,
10, 1,
), ),
Err(data_api::error::Error::ScanRequired) Err(data_api::error::Error::ScanRequired)
); );
@ -535,7 +535,7 @@ mod tests {
Amount::from_u64(1).unwrap(), Amount::from_u64(1).unwrap(),
None, None,
OvkPolicy::Sender, OvkPolicy::Sender,
10, 1,
), ),
Err(data_api::error::Error::InsufficientFunds { Err(data_api::error::Error::InsufficientFunds {
available, available,
@ -740,7 +740,7 @@ mod tests {
Amount::from_u64(15000).unwrap(), Amount::from_u64(15000).unwrap(),
None, None,
OvkPolicy::Sender, OvkPolicy::Sender,
10, 1,
), ),
Ok(_) Ok(_)
); );
@ -756,7 +756,7 @@ mod tests {
Amount::from_u64(2000).unwrap(), Amount::from_u64(2000).unwrap(),
None, None,
OvkPolicy::Sender, OvkPolicy::Sender,
10, 1,
), ),
Err(data_api::error::Error::InsufficientFunds { Err(data_api::error::Error::InsufficientFunds {
available, available,
@ -791,7 +791,7 @@ mod tests {
Amount::from_u64(2000).unwrap(), Amount::from_u64(2000).unwrap(),
None, None,
OvkPolicy::Sender, OvkPolicy::Sender,
10, 1,
), ),
Err(data_api::error::Error::InsufficientFunds { Err(data_api::error::Error::InsufficientFunds {
available, available,
@ -822,7 +822,7 @@ mod tests {
Amount::from_u64(2000).unwrap(), Amount::from_u64(2000).unwrap(),
None, None,
OvkPolicy::Sender, OvkPolicy::Sender,
10, 1,
) )
.unwrap(); .unwrap();
} }
@ -874,7 +874,7 @@ mod tests {
Amount::from_u64(15000).unwrap(), Amount::from_u64(15000).unwrap(),
None, None,
ovk_policy, ovk_policy,
10, 1,
) )
.unwrap(); .unwrap();
@ -962,7 +962,7 @@ mod tests {
scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None, None).unwrap(); scan_cached_blocks(&tests::network(), &db_cache, &mut db_data, None, None).unwrap();
// Verified balance matches total balance // Verified balance matches total balance
let (_, anchor_height) = db_data.get_target_and_anchor_heights(10).unwrap().unwrap(); let (_, anchor_height) = db_data.get_target_and_anchor_heights(1).unwrap().unwrap();
assert_eq!( assert_eq!(
get_balance(&db_data.conn, AccountId::from(0)).unwrap(), get_balance(&db_data.conn, AccountId::from(0)).unwrap(),
value value
@ -983,7 +983,7 @@ mod tests {
Amount::from_u64(50000).unwrap(), Amount::from_u64(50000).unwrap(),
None, None,
OvkPolicy::Sender, OvkPolicy::Sender,
10, 1,
), ),
Ok(_) Ok(_)
); );
@ -1039,7 +1039,7 @@ mod tests {
Amount::from_u64(50000).unwrap(), Amount::from_u64(50000).unwrap(),
None, None,
OvkPolicy::Sender, OvkPolicy::Sender,
10, 1,
), ),
Ok(_) Ok(_)
); );
@ -1193,7 +1193,7 @@ mod tests {
DustOutputPolicy::default(), DustOutputPolicy::default(),
); );
// Add funds to the wallet // Ensure that the wallet has at least one block
let (cb, _) = fake_compact_block( let (cb, _) = fake_compact_block(
sapling_activation_height(), sapling_activation_height(),
BlockHash([0; 32]), BlockHash([0; 32]),
@ -1215,7 +1215,7 @@ mod tests {
&usk, &usk,
&[*taddr], &[*taddr],
&MemoBytes::empty(), &MemoBytes::empty(),
0 1
), ),
Ok(_) Ok(_)
); );

View File

@ -1,19 +1,22 @@
use either::Either; use either::Either;
use incrementalmerkletree::{Address, Position};
use rusqlite::{self, named_params, Connection, OptionalExtension}; use rusqlite::{self, named_params, Connection, OptionalExtension};
use shardtree::{Checkpoint, LocatedPrunableTree, PrunableTree, ShardStore, TreeState};
use std::{ use std::{
collections::BTreeSet, collections::BTreeSet,
io::{self, Cursor}, io::{self, Cursor},
ops::Deref, ops::Deref,
}; };
use incrementalmerkletree::{Address, Level, Position};
use shardtree::{Checkpoint, LocatedPrunableTree, PrunableTree, ShardStore, TreeState};
use zcash_primitives::{consensus::BlockHeight, merkle_tree::HashSer, sapling}; use zcash_primitives::{consensus::BlockHeight, merkle_tree::HashSer, sapling};
use zcash_client_backend::data_api::SAPLING_SHARD_HEIGHT;
use crate::serialization::{read_shard, write_shard_v1}; use crate::serialization::{read_shard, write_shard_v1};
const SHARD_ROOT_LEVEL: Level = Level::new(SAPLING_SHARD_HEIGHT);
pub struct WalletDbSaplingShardStore<'conn, 'a> { pub struct WalletDbSaplingShardStore<'conn, 'a> {
pub(crate) conn: &'a rusqlite::Transaction<'conn>, pub(crate) conn: &'a rusqlite::Transaction<'conn>,
} }
@ -39,8 +42,7 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
} }
fn last_shard(&self) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> { fn last_shard(&self) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> {
// SELECT shard_data FROM sapling_tree ORDER BY shard_index DESC LIMIT 1 last_shard(self.conn)
todo!()
} }
fn put_shard(&mut self, subtree: LocatedPrunableTree<Self::H>) -> Result<(), Self::Error> { fn put_shard(&mut self, subtree: LocatedPrunableTree<Self::H>) -> Result<(), Self::Error> {
@ -48,8 +50,7 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
} }
fn get_shard_roots(&self) -> Result<Vec<Address>, Self::Error> { fn get_shard_roots(&self) -> Result<Vec<Address>, Self::Error> {
// SELECT get_shard_roots(self.conn)
todo!()
} }
fn truncate(&mut self, from: Address) -> Result<(), Self::Error> { fn truncate(&mut self, from: Address) -> Result<(), Self::Error> {
@ -86,9 +87,9 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
fn get_checkpoint_at_depth( fn get_checkpoint_at_depth(
&self, &self,
_checkpoint_depth: usize, checkpoint_depth: usize,
) -> Result<Option<(Self::CheckpointId, Checkpoint)>, Self::Error> { ) -> Result<Option<(Self::CheckpointId, Checkpoint)>, Self::Error> {
todo!() get_checkpoint_at_depth(self.conn, checkpoint_depth)
} }
fn get_checkpoint( fn get_checkpoint(
@ -150,6 +151,31 @@ pub(crate) fn get_shard(
.transpose() .transpose()
} }
pub(crate) fn last_shard(
conn: &rusqlite::Connection,
) -> Result<Option<LocatedPrunableTree<sapling::Node>>, Error> {
conn.query_row(
"SELECT shard_index, shard_data
FROM sapling_tree_shards
ORDER BY shard_index DESC
LIMIT 1",
[],
|row| {
let shard_index: u64 = row.get(0)?;
let shard_data: Vec<u8> = row.get(1)?;
Ok((shard_index, shard_data))
},
)
.optional()
.map_err(Either::Right)?
.map(|(shard_index, shard_data)| {
let shard_root = Address::from_parts(SHARD_ROOT_LEVEL, shard_index);
let shard_tree = read_shard(&mut Cursor::new(shard_data)).map_err(Either::Left)?;
Ok(LocatedPrunableTree::from_parts(shard_root, shard_tree))
})
.transpose()
}
pub(crate) fn put_shard( pub(crate) fn put_shard(
conn: &rusqlite::Connection, conn: &rusqlite::Connection,
subtree: LocatedPrunableTree<sapling::Node>, subtree: LocatedPrunableTree<sapling::Node>,
@ -172,10 +198,10 @@ pub(crate) fn put_shard(
conn.prepare_cached( conn.prepare_cached(
"INSERT INTO sapling_tree_shards (shard_index, root_hash, shard_data) "INSERT INTO sapling_tree_shards (shard_index, root_hash, shard_data)
VALUES (:shard_index, :root_hash, :shard_data) VALUES (:shard_index, :root_hash, :shard_data)
ON CONFLICT (shard_index) DO UPDATE ON CONFLICT (shard_index) DO UPDATE
SET root_hash = :root_hash, SET root_hash = :root_hash,
shard_data = :shard_data", shard_data = :shard_data",
) )
.and_then(|mut stmt_put_shard| { .and_then(|mut stmt_put_shard| {
stmt_put_shard.execute(named_params![ stmt_put_shard.execute(named_params![
@ -189,6 +215,22 @@ pub(crate) fn put_shard(
Ok(()) Ok(())
} }
pub(crate) fn get_shard_roots(conn: &rusqlite::Connection) -> Result<Vec<Address>, Error> {
let mut stmt = conn
.prepare("SELECT shard_index FROM sapling_tree_shards ORDER BY shard_index")
.map_err(Either::Right)?;
let mut rows = stmt.query([]).map_err(Either::Right)?;
let mut res = vec![];
while let Some(row) = rows.next().map_err(Either::Right)? {
res.push(Address::from_parts(
SHARD_ROOT_LEVEL,
row.get(0).map_err(Either::Right)?,
));
}
Ok(res)
}
pub(crate) fn truncate(conn: &rusqlite::Transaction<'_>, from: Address) -> Result<(), Error> { pub(crate) fn truncate(conn: &rusqlite::Transaction<'_>, from: Address) -> Result<(), Error> {
conn.execute( conn.execute(
"DELETE FROM sapling_tree_shards WHERE shard_index >= ?", "DELETE FROM sapling_tree_shards WHERE shard_index >= ?",
@ -264,8 +306,8 @@ pub(crate) fn get_checkpoint<C: Deref<Target = Connection>>(
let checkpoint_position = conn let checkpoint_position = conn
.query_row( .query_row(
"SELECT position "SELECT position
FROM sapling_tree_checkpoints FROM sapling_tree_checkpoints
WHERE checkpoint_id = ?", WHERE checkpoint_id = ?",
[u32::from(checkpoint_id)], [u32::from(checkpoint_id)],
|row| { |row| {
row.get::<_, Option<u64>>(0) row.get::<_, Option<u64>>(0)
@ -275,32 +317,91 @@ pub(crate) fn get_checkpoint<C: Deref<Target = Connection>>(
.optional() .optional()
.map_err(Either::Right)?; .map_err(Either::Right)?;
let mut marks_removed = BTreeSet::new(); checkpoint_position
let mut stmt = conn .map(|pos_opt| {
.prepare_cached( let mut marks_removed = BTreeSet::new();
"SELECT mark_removed_position let mut stmt = conn
FROM sapling_tree_checkpoint_marks_removed .prepare_cached(
WHERE checkpoint_id = ?", "SELECT mark_removed_position
FROM sapling_tree_checkpoint_marks_removed
WHERE checkpoint_id = ?",
)
.map_err(Either::Right)?;
let mut mark_removed_rows = stmt
.query([u32::from(checkpoint_id)])
.map_err(Either::Right)?;
while let Some(row) = mark_removed_rows.next().map_err(Either::Right)? {
marks_removed.insert(
row.get::<_, u64>(0)
.map(Position::from)
.map_err(Either::Right)?,
);
}
Ok(Checkpoint::from_parts(
pos_opt.map_or(TreeState::Empty, TreeState::AtPosition),
marks_removed,
))
})
.transpose()
}
pub(crate) fn get_checkpoint_at_depth<C: Deref<Target = Connection>>(
conn: &C,
checkpoint_depth: usize,
) -> Result<Option<(BlockHeight, Checkpoint)>, Either<io::Error, rusqlite::Error>> {
let checkpoint_parts = conn
.query_row(
"SELECT checkpoint_id, position
FROM sapling_tree_checkpoints
ORDER BY checkpoint_id DESC
LIMIT 1
OFFSET :offset",
named_params![":offset": checkpoint_depth],
|row| {
let checkpoint_id: u32 = row.get(0)?;
let position: Option<u64> = row.get(1)?;
Ok((
BlockHeight::from(checkpoint_id),
position.map(Position::from),
))
},
) )
.map_err(Either::Right)?; .optional()
let mut mark_removed_rows = stmt
.query([u32::from(checkpoint_id)])
.map_err(Either::Right)?; .map_err(Either::Right)?;
while let Some(row) = mark_removed_rows.next().map_err(Either::Right)? { checkpoint_parts
marks_removed.insert( .map(|(checkpoint_id, pos_opt)| {
row.get::<_, u64>(0) let mut marks_removed = BTreeSet::new();
.map(Position::from) let mut stmt = conn
.map_err(Either::Right)?, .prepare_cached(
); "SELECT mark_removed_position
} FROM sapling_tree_checkpoint_marks_removed
WHERE checkpoint_id = ?",
)
.map_err(Either::Right)?;
let mut mark_removed_rows = stmt
.query([u32::from(checkpoint_id)])
.map_err(Either::Right)?;
Ok(checkpoint_position.map(|pos_opt| { while let Some(row) = mark_removed_rows.next().map_err(Either::Right)? {
Checkpoint::from_parts( marks_removed.insert(
pos_opt.map_or(TreeState::Empty, TreeState::AtPosition), row.get::<_, u64>(0)
marks_removed, .map(Position::from)
) .map_err(Either::Right)?,
})) );
}
Ok((
checkpoint_id,
Checkpoint::from_parts(
pos_opt.map_or(TreeState::Empty, TreeState::AtPosition),
marks_removed,
),
))
})
.transpose()
} }
pub(crate) fn update_checkpoint_with<F>( pub(crate) fn update_checkpoint_with<F>(