zcash_client_sqlite: Add tests for sqlite-backed ShardTree & fix revealed issues.

This commit is contained in:
Kris Nuttycombe 2023-06-15 13:50:07 -06:00
parent 425b5e01d7
commit 0a4236f725
6 changed files with 395 additions and 77 deletions

View File

@ -66,9 +66,7 @@ use zcash_client_backend::{
DecryptedOutput, TransferType,
};
use crate::{
error::SqliteClientError, wallet::sapling::commitment_tree::WalletDbSaplingShardStore,
};
use crate::{error::SqliteClientError, wallet::sapling::commitment_tree::SqliteShardStore};
#[cfg(feature = "unstable")]
use {
@ -617,7 +615,8 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
impl<P: consensus::Parameters> WalletCommitmentTrees for WalletDb<rusqlite::Connection, P> {
type Error = Either<io::Error, rusqlite::Error>;
type SaplingShardStore<'a> = WalletDbSaplingShardStore<'a, 'a>;
type SaplingShardStore<'a> =
SqliteShardStore<&'a rusqlite::Transaction<'a>, sapling::Node, SAPLING_SHARD_HEIGHT>;
fn with_sapling_tree_mut<F, A, E>(&mut self, mut callback: F) -> Result<A, E>
where
@ -634,7 +633,7 @@ impl<P: consensus::Parameters> WalletCommitmentTrees for WalletDb<rusqlite::Conn
.conn
.transaction()
.map_err(|e| ShardTreeError::Storage(Either::Right(e)))?;
let shard_store = WalletDbSaplingShardStore::from_connection(&tx)
let shard_store = SqliteShardStore::from_connection(&tx)
.map_err(|e| ShardTreeError::Storage(Either::Right(e)))?;
let result = {
let mut shardtree = ShardTree::new(shard_store, 100);
@ -648,7 +647,8 @@ impl<P: consensus::Parameters> WalletCommitmentTrees for WalletDb<rusqlite::Conn
impl<'conn, P: consensus::Parameters> WalletCommitmentTrees for WalletDb<SqlTransaction<'conn>, P> {
type Error = Either<io::Error, rusqlite::Error>;
type SaplingShardStore<'a> = WalletDbSaplingShardStore<'a, 'a>;
type SaplingShardStore<'a> =
SqliteShardStore<&'a rusqlite::Transaction<'a>, sapling::Node, SAPLING_SHARD_HEIGHT>;
fn with_sapling_tree_mut<F, A, E>(&mut self, mut callback: F) -> Result<A, E>
where
@ -662,7 +662,7 @@ impl<'conn, P: consensus::Parameters> WalletCommitmentTrees for WalletDb<SqlTran
E: From<ShardTreeError<Either<io::Error, rusqlite::Error>>>,
{
let mut shardtree = ShardTree::new(
WalletDbSaplingShardStore::from_connection(self.conn.0)
SqliteShardStore::from_connection(self.conn.0)
.map_err(|e| ShardTreeError::Storage(Either::Right(e)))?,
100,
);

View File

@ -403,6 +403,7 @@ mod tests {
checkpoint_id INTEGER NOT NULL,
mark_removed_position INTEGER NOT NULL,
FOREIGN KEY (checkpoint_id) REFERENCES sapling_tree_checkpoints(checkpoint_id)
ON DELETE CASCADE
)",
"CREATE TABLE sapling_tree_checkpoints (
checkpoint_id INTEGER PRIMARY KEY,

View File

@ -20,7 +20,7 @@ use zcash_primitives::{
use crate::wallet::{
init::{migrations::received_notes_nullable_nf, WalletMigrationError},
sapling::commitment_tree::WalletDbSaplingShardStore,
sapling::commitment_tree::SqliteShardStore,
};
pub(super) const MIGRATION_ID: Uuid = Uuid::from_fields(
@ -87,10 +87,14 @@ impl RusqliteMigration for Migration {
checkpoint_id INTEGER NOT NULL,
mark_removed_position INTEGER NOT NULL,
FOREIGN KEY (checkpoint_id) REFERENCES sapling_tree_checkpoints(checkpoint_id)
ON DELETE CASCADE
);",
)?;
let shard_store = WalletDbSaplingShardStore::from_connection(transaction)?;
let shard_store =
SqliteShardStore::<_, sapling::Node, SAPLING_SHARD_HEIGHT>::from_connection(
transaction,
)?;
let mut shard_tree: ShardTree<
_,
{ sapling::NOTE_COMMITMENT_TREE_DEPTH },

View File

@ -1,36 +1,38 @@
use either::Either;
use rusqlite::{self, named_params, Connection, OptionalExtension};
use rusqlite::{self, named_params, OptionalExtension};
use std::{
collections::BTreeSet,
io::{self, Cursor},
ops::Deref,
marker::PhantomData,
};
use incrementalmerkletree::{Address, Level, Position};
use shardtree::{Checkpoint, LocatedPrunableTree, PrunableTree, ShardStore, TreeState};
use zcash_primitives::{consensus::BlockHeight, merkle_tree::HashSer, sapling};
use zcash_client_backend::data_api::SAPLING_SHARD_HEIGHT;
use zcash_primitives::{consensus::BlockHeight, merkle_tree::HashSer};
use crate::serialization::{read_shard, write_shard_v1};
const SHARD_ROOT_LEVEL: Level = Level::new(SAPLING_SHARD_HEIGHT);
pub struct WalletDbSaplingShardStore<'conn, 'a> {
pub(crate) conn: &'a rusqlite::Transaction<'conn>,
pub struct SqliteShardStore<C, H, const SHARD_HEIGHT: u8> {
pub(crate) conn: C,
_hash_type: PhantomData<H>,
}
impl<'conn, 'a> WalletDbSaplingShardStore<'conn, 'a> {
pub(crate) fn from_connection(
conn: &'a rusqlite::Transaction<'conn>,
) -> Result<Self, rusqlite::Error> {
Ok(WalletDbSaplingShardStore { conn })
impl<C, H, const SHARD_HEIGHT: u8> SqliteShardStore<C, H, SHARD_HEIGHT> {
const SHARD_ROOT_LEVEL: Level = Level::new(SHARD_HEIGHT);
pub(crate) fn from_connection(conn: C) -> Result<Self, rusqlite::Error> {
Ok(SqliteShardStore {
conn,
_hash_type: PhantomData,
})
}
}
impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
type H = sapling::Node;
impl<'conn, 'a: 'conn, H: HashSer, const SHARD_HEIGHT: u8> ShardStore
for SqliteShardStore<&'a rusqlite::Transaction<'conn>, H, SHARD_HEIGHT>
{
type H = H;
type CheckpointId = BlockHeight;
type Error = Either<io::Error, rusqlite::Error>;
@ -42,7 +44,7 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
}
fn last_shard(&self) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> {
last_shard(self.conn)
last_shard(self.conn, Self::SHARD_ROOT_LEVEL)
}
fn put_shard(&mut self, subtree: LocatedPrunableTree<Self::H>) -> Result<(), Self::Error> {
@ -50,7 +52,7 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
}
fn get_shard_roots(&self) -> Result<Vec<Address>, Self::Error> {
get_shard_roots(self.conn)
get_shard_roots(self.conn, Self::SHARD_ROOT_LEVEL)
}
fn truncate(&mut self, from: Address) -> Result<(), Self::Error> {
@ -66,11 +68,11 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
}
fn min_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
todo!()
min_checkpoint_id(self.conn)
}
fn max_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
todo!()
max_checkpoint_id(self.conn)
}
fn add_checkpoint(
@ -99,11 +101,11 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
get_checkpoint(self.conn, *checkpoint_id)
}
fn with_checkpoints<F>(&mut self, _limit: usize, _callback: F) -> Result<(), Self::Error>
fn with_checkpoints<F>(&mut self, limit: usize, callback: F) -> Result<(), Self::Error>
where
F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>,
{
todo!()
with_checkpoints(self.conn, limit, callback)
}
fn update_checkpoint_with<F>(
@ -129,12 +131,128 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
}
}
impl<H: HashSer, const SHARD_HEIGHT: u8> ShardStore
for SqliteShardStore<rusqlite::Connection, H, SHARD_HEIGHT>
{
type H = H;
type CheckpointId = BlockHeight;
type Error = Either<io::Error, rusqlite::Error>;
fn get_shard(
&self,
shard_root: Address,
) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> {
get_shard(&self.conn, shard_root)
}
fn last_shard(&self) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> {
last_shard(&self.conn, Self::SHARD_ROOT_LEVEL)
}
fn put_shard(&mut self, subtree: LocatedPrunableTree<Self::H>) -> Result<(), Self::Error> {
let tx = self.conn.transaction().map_err(Either::Right)?;
put_shard(&tx, subtree)?;
tx.commit().map_err(Either::Right)?;
Ok(())
}
fn get_shard_roots(&self) -> Result<Vec<Address>, Self::Error> {
get_shard_roots(&self.conn, Self::SHARD_ROOT_LEVEL)
}
fn truncate(&mut self, from: Address) -> Result<(), Self::Error> {
truncate(&self.conn, from)
}
fn get_cap(&self) -> Result<PrunableTree<Self::H>, Self::Error> {
get_cap(&self.conn)
}
fn put_cap(&mut self, cap: PrunableTree<Self::H>) -> Result<(), Self::Error> {
put_cap(&self.conn, cap)
}
fn min_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
min_checkpoint_id(&self.conn)
}
fn max_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
max_checkpoint_id(&self.conn)
}
fn add_checkpoint(
&mut self,
checkpoint_id: Self::CheckpointId,
checkpoint: Checkpoint,
) -> Result<(), Self::Error> {
let tx = self.conn.transaction().map_err(Either::Right)?;
add_checkpoint(&tx, checkpoint_id, checkpoint)?;
tx.commit().map_err(Either::Right)
}
fn checkpoint_count(&self) -> Result<usize, Self::Error> {
checkpoint_count(&self.conn)
}
fn get_checkpoint_at_depth(
&self,
checkpoint_depth: usize,
) -> Result<Option<(Self::CheckpointId, Checkpoint)>, Self::Error> {
get_checkpoint_at_depth(&self.conn, checkpoint_depth)
}
fn get_checkpoint(
&self,
checkpoint_id: &Self::CheckpointId,
) -> Result<Option<Checkpoint>, Self::Error> {
get_checkpoint(&self.conn, *checkpoint_id)
}
fn with_checkpoints<F>(&mut self, limit: usize, callback: F) -> Result<(), Self::Error>
where
F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>,
{
let tx = self.conn.transaction().map_err(Either::Right)?;
with_checkpoints(&tx, limit, callback)?;
tx.commit().map_err(Either::Right)
}
fn update_checkpoint_with<F>(
&mut self,
checkpoint_id: &Self::CheckpointId,
update: F,
) -> Result<bool, Self::Error>
where
F: Fn(&mut Checkpoint) -> Result<(), Self::Error>,
{
let tx = self.conn.transaction().map_err(Either::Right)?;
let result = update_checkpoint_with(&tx, *checkpoint_id, update)?;
tx.commit().map_err(Either::Right)?;
Ok(result)
}
fn remove_checkpoint(&mut self, checkpoint_id: &Self::CheckpointId) -> Result<(), Self::Error> {
let tx = self.conn.transaction().map_err(Either::Right)?;
remove_checkpoint(&tx, *checkpoint_id)?;
tx.commit().map_err(Either::Right)
}
fn truncate_checkpoints(
&mut self,
checkpoint_id: &Self::CheckpointId,
) -> Result<(), Self::Error> {
let tx = self.conn.transaction().map_err(Either::Right)?;
truncate_checkpoints(&tx, *checkpoint_id)?;
tx.commit().map_err(Either::Right)
}
}
type Error = Either<io::Error, rusqlite::Error>;
pub(crate) fn get_shard(
pub(crate) fn get_shard<H: HashSer>(
conn: &rusqlite::Connection,
shard_root: Address,
) -> Result<Option<LocatedPrunableTree<sapling::Node>>, Error> {
) -> Result<Option<LocatedPrunableTree<H>>, Error> {
conn.query_row(
"SELECT shard_data
FROM sapling_tree_shards
@ -151,9 +269,10 @@ pub(crate) fn get_shard(
.transpose()
}
pub(crate) fn last_shard(
pub(crate) fn last_shard<H: HashSer>(
conn: &rusqlite::Connection,
) -> Result<Option<LocatedPrunableTree<sapling::Node>>, Error> {
shard_root_level: Level,
) -> Result<Option<LocatedPrunableTree<H>>, Error> {
conn.query_row(
"SELECT shard_index, shard_data
FROM sapling_tree_shards
@ -169,16 +288,16 @@ pub(crate) fn last_shard(
.optional()
.map_err(Either::Right)?
.map(|(shard_index, shard_data)| {
let shard_root = Address::from_parts(SHARD_ROOT_LEVEL, shard_index);
let shard_root = Address::from_parts(shard_root_level, shard_index);
let shard_tree = read_shard(&mut Cursor::new(shard_data)).map_err(Either::Left)?;
Ok(LocatedPrunableTree::from_parts(shard_root, shard_tree))
})
.transpose()
}
pub(crate) fn put_shard(
conn: &rusqlite::Connection,
subtree: LocatedPrunableTree<sapling::Node>,
pub(crate) fn put_shard<H: HashSer>(
conn: &rusqlite::Transaction<'_>,
subtree: LocatedPrunableTree<H>,
) -> Result<(), Error> {
let subtree_root_hash = subtree
.root()
@ -196,26 +315,31 @@ pub(crate) fn put_shard(
let mut subtree_data = vec![];
write_shard_v1(&mut subtree_data, subtree.root()).map_err(Either::Left)?;
conn.prepare_cached(
"INSERT INTO sapling_tree_shards (shard_index, root_hash, shard_data)
VALUES (:shard_index, :root_hash, :shard_data)
ON CONFLICT (shard_index) DO UPDATE
SET root_hash = :root_hash,
shard_data = :shard_data",
)
.and_then(|mut stmt_put_shard| {
stmt_put_shard.execute(named_params![
let mut stmt_put_shard = conn
.prepare_cached(
"INSERT INTO sapling_tree_shards (shard_index, root_hash, shard_data)
VALUES (:shard_index, :root_hash, :shard_data)
ON CONFLICT (shard_index) DO UPDATE
SET root_hash = :root_hash,
shard_data = :shard_data",
)
.map_err(Either::Right)?;
stmt_put_shard
.execute(named_params![
":shard_index": subtree.root_addr().index(),
":root_hash": subtree_root_hash,
":shard_data": subtree_data
])
})
.map_err(Either::Right)?;
.map_err(Either::Right)?;
Ok(())
}
pub(crate) fn get_shard_roots(conn: &rusqlite::Connection) -> Result<Vec<Address>, Error> {
pub(crate) fn get_shard_roots(
conn: &rusqlite::Connection,
shard_root_level: Level,
) -> Result<Vec<Address>, Error> {
let mut stmt = conn
.prepare("SELECT shard_index FROM sapling_tree_shards ORDER BY shard_index")
.map_err(Either::Right)?;
@ -224,14 +348,14 @@ pub(crate) fn get_shard_roots(conn: &rusqlite::Connection) -> Result<Vec<Address
let mut res = vec![];
while let Some(row) = rows.next().map_err(Either::Right)? {
res.push(Address::from_parts(
SHARD_ROOT_LEVEL,
shard_root_level,
row.get(0).map_err(Either::Right)?,
));
}
Ok(res)
}
pub(crate) fn truncate(conn: &rusqlite::Transaction<'_>, from: Address) -> Result<(), Error> {
pub(crate) fn truncate(conn: &rusqlite::Connection, from: Address) -> Result<(), Error> {
conn.execute(
"DELETE FROM sapling_tree_shards WHERE shard_index >= ?",
[from.index()],
@ -240,7 +364,7 @@ pub(crate) fn truncate(conn: &rusqlite::Transaction<'_>, from: Address) -> Resul
.map(|_| ())
}
pub(crate) fn get_cap(conn: &rusqlite::Connection) -> Result<PrunableTree<sapling::Node>, Error> {
pub(crate) fn get_cap<H: HashSer>(conn: &rusqlite::Connection) -> Result<PrunableTree<H>, Error> {
conn.query_row("SELECT cap_data FROM sapling_tree_cap", [], |row| {
row.get::<_, Vec<u8>>(0)
})
@ -252,9 +376,9 @@ pub(crate) fn get_cap(conn: &rusqlite::Connection) -> Result<PrunableTree<saplin
)
}
pub(crate) fn put_cap(
conn: &rusqlite::Transaction<'_>,
cap: PrunableTree<sapling::Node>,
pub(crate) fn put_cap<H: HashSer>(
conn: &rusqlite::Connection,
cap: PrunableTree<H>,
) -> Result<(), Error> {
let mut stmt = conn
.prepare_cached(
@ -272,22 +396,62 @@ pub(crate) fn put_cap(
Ok(())
}
pub(crate) fn min_checkpoint_id(conn: &rusqlite::Connection) -> Result<Option<BlockHeight>, Error> {
conn.query_row(
"SELECT MIN(checkpoint_id) FROM sapling_tree_checkpoints",
[],
|row| {
row.get::<_, Option<u32>>(0)
.map(|opt| opt.map(BlockHeight::from))
},
)
.map_err(Either::Right)
}
pub(crate) fn max_checkpoint_id(conn: &rusqlite::Connection) -> Result<Option<BlockHeight>, Error> {
conn.query_row(
"SELECT MAX(checkpoint_id) FROM sapling_tree_checkpoints",
[],
|row| {
row.get::<_, Option<u32>>(0)
.map(|opt| opt.map(BlockHeight::from))
},
)
.map_err(Either::Right)
}
pub(crate) fn add_checkpoint(
conn: &rusqlite::Transaction<'_>,
checkpoint_id: BlockHeight,
checkpoint: Checkpoint,
) -> Result<(), Error> {
conn.prepare_cached(
"INSERT INTO sapling_tree_checkpoints (checkpoint_id, position)
VALUES (:checkpoint_id, :position)",
)
.and_then(|mut stmt_insert_checkpoint| {
stmt_insert_checkpoint.execute(named_params![
let mut stmt_insert_checkpoint = conn
.prepare_cached(
"INSERT INTO sapling_tree_checkpoints (checkpoint_id, position)
VALUES (:checkpoint_id, :position)",
)
.map_err(Either::Right)?;
stmt_insert_checkpoint
.execute(named_params![
":checkpoint_id": u32::from(checkpoint_id),
":position": checkpoint.position().map(u64::from)
])
})
.map_err(Either::Right)?;
.map_err(Either::Right)?;
let mut stmt_insert_mark_removed = conn.prepare_cached(
"INSERT INTO sapling_tree_checkpoint_marks_removed (checkpoint_id, mark_removed_position)
VALUES (:checkpoint_id, :position)",
).map_err(Either::Right)?;
for pos in checkpoint.marks_removed() {
stmt_insert_mark_removed
.execute(named_params![
":checkpoint_id": u32::from(checkpoint_id),
":position": u64::from(*pos)
])
.map_err(Either::Right)?;
}
Ok(())
}
@ -299,10 +463,10 @@ pub(crate) fn checkpoint_count(conn: &rusqlite::Connection) -> Result<usize, Err
.map_err(Either::Right)
}
pub(crate) fn get_checkpoint<C: Deref<Target = Connection>>(
conn: &C,
pub(crate) fn get_checkpoint(
conn: &rusqlite::Connection,
checkpoint_id: BlockHeight,
) -> Result<Option<Checkpoint>, Either<io::Error, rusqlite::Error>> {
) -> Result<Option<Checkpoint>, Error> {
let checkpoint_position = conn
.query_row(
"SELECT position
@ -347,10 +511,14 @@ pub(crate) fn get_checkpoint<C: Deref<Target = Connection>>(
.transpose()
}
pub(crate) fn get_checkpoint_at_depth<C: Deref<Target = Connection>>(
conn: &C,
pub(crate) fn get_checkpoint_at_depth(
conn: &rusqlite::Connection,
checkpoint_depth: usize,
) -> Result<Option<(BlockHeight, Checkpoint)>, Either<io::Error, rusqlite::Error>> {
) -> Result<Option<(BlockHeight, Checkpoint)>, Error> {
if checkpoint_depth == 0 {
return Ok(None);
}
let checkpoint_parts = conn
.query_row(
"SELECT checkpoint_id, position
@ -358,7 +526,7 @@ pub(crate) fn get_checkpoint_at_depth<C: Deref<Target = Connection>>(
ORDER BY checkpoint_id DESC
LIMIT 1
OFFSET :offset",
named_params![":offset": checkpoint_depth],
named_params![":offset": checkpoint_depth - 1],
|row| {
let checkpoint_id: u32 = row.get(0)?;
let position: Option<u64> = row.get(1)?;
@ -404,6 +572,62 @@ pub(crate) fn get_checkpoint_at_depth<C: Deref<Target = Connection>>(
.transpose()
}
pub(crate) fn with_checkpoints<F>(
conn: &rusqlite::Transaction<'_>,
limit: usize,
mut callback: F,
) -> Result<(), Error>
where
F: FnMut(&BlockHeight, &Checkpoint) -> Result<(), Error>,
{
let mut stmt_get_checkpoints = conn
.prepare_cached(
"SELECT checkpoint_id, position
FROM sapling_tree_checkpoints
LIMIT :limit",
)
.map_err(Either::Right)?;
let mut stmt_get_checkpoint_marks_removed = conn
.prepare_cached(
"SELECT mark_removed_position
FROM sapling_tree_checkpoint_marks_removed
WHERE checkpoint_id = :checkpoint_id",
)
.map_err(Either::Right)?;
let mut rows = stmt_get_checkpoints
.query(named_params![":limit": limit])
.map_err(Either::Right)?;
while let Some(row) = rows.next().map_err(Either::Right)? {
let checkpoint_id = row.get::<_, u32>(0).map_err(Either::Right)?;
let tree_state = row
.get::<_, Option<u64>>(1)
.map(|opt| opt.map_or_else(|| TreeState::Empty, |p| TreeState::AtPosition(p.into())))
.map_err(Either::Right)?;
let mut mark_removed_rows = stmt_get_checkpoint_marks_removed
.query(named_params![":checkpoint_id": checkpoint_id])
.map_err(Either::Right)?;
let mut marks_removed = BTreeSet::new();
while let Some(mr_row) = mark_removed_rows.next().map_err(Either::Right)? {
let mark_removed_position = mr_row
.get::<_, u64>(0)
.map(Position::from)
.map_err(Either::Right)?;
marks_removed.insert(mark_removed_position);
}
callback(
&BlockHeight::from(checkpoint_id),
&Checkpoint::from_parts(tree_state, marks_removed),
)?
}
Ok(())
}
pub(crate) fn update_checkpoint_with<F>(
conn: &rusqlite::Transaction<'_>,
checkpoint_id: BlockHeight,
@ -426,11 +650,17 @@ pub(crate) fn remove_checkpoint(
conn: &rusqlite::Transaction<'_>,
checkpoint_id: BlockHeight,
) -> Result<(), Error> {
conn.execute(
"DELETE FROM sapling_tree_checkpoints WHERE checkpoint_id = ?",
[u32::from(checkpoint_id)],
)
.map_err(Either::Right)?;
// sapling_tree_checkpoints is constructed with `ON DELETE CASCADE`
let mut stmt_delete_checkpoint = conn
.prepare_cached(
"DELETE FROM sapling_tree_checkpoints
WHERE checkpoint_id = :checkpoint_id",
)
.map_err(Either::Right)?;
stmt_delete_checkpoint
.execute(named_params![":checkpoint_id": u32::from(checkpoint_id),])
.map_err(Either::Right)?;
Ok(())
}
@ -452,3 +682,62 @@ pub(crate) fn truncate_checkpoints(
.map_err(Either::Right)?;
Ok(())
}
#[cfg(test)]
mod tests {
use tempfile::NamedTempFile;
use incrementalmerkletree::testing::{
check_append, check_checkpoint_rewind, check_remove_mark, check_rewind_remove_mark,
check_root_hashes, check_witness_consistency, check_witnesses,
};
use shardtree::ShardTree;
use super::SqliteShardStore;
use crate::{tests, wallet::init::init_wallet_db, WalletDb};
fn new_tree(m: usize) -> ShardTree<SqliteShardStore<rusqlite::Connection, String, 3>, 4, 3> {
let data_file = NamedTempFile::new().unwrap();
let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap();
data_file.keep().unwrap();
init_wallet_db(&mut db_data, None).unwrap();
let store = SqliteShardStore::<_, String, 3>::from_connection(db_data.conn).unwrap();
ShardTree::new(store, m)
}
#[test]
fn append() {
check_append(new_tree);
}
#[test]
fn root_hashes() {
check_root_hashes(new_tree);
}
#[test]
fn witnesses() {
check_witnesses(new_tree);
}
#[test]
fn witness_consistency() {
check_witness_consistency(new_tree);
}
#[test]
fn checkpoint_rewind() {
check_checkpoint_rewind(new_tree);
}
#[test]
fn remove_mark() {
check_remove_mark(new_tree);
}
#[test]
fn rewind_remove_mark() {
check_rewind_remove_mark(new_tree);
}
}

View File

@ -627,6 +627,12 @@ pub mod testing {
)
})
}
impl incrementalmerkletree::testing::TestCheckpoint for BlockHeight {
fn from_u64(value: u64) -> Self {
BlockHeight(u32::try_from(value).expect("Test checkpoint ids do not exceed 32 bits"))
}
}
}
#[cfg(test)]

View File

@ -292,6 +292,7 @@ pub mod testing {
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use incrementalmerkletree::frontier::testing::TestNode;
use std::io::{self, Read, Write};
use zcash_encoding::Vector;
use super::HashSer;
@ -304,6 +305,23 @@ pub mod testing {
writer.write_u64::<LittleEndian>(self.0)
}
}
impl HashSer for String {
fn read<R: Read>(reader: R) -> io::Result<String> {
Vector::read(reader, |r| r.read_u8()).and_then(|xs| {
String::from_utf8(xs).map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Not a valid utf8 string: {:?}", e),
)
})
})
}
fn write<W: Write>(&self, writer: W) -> io::Result<()> {
Vector::write(writer, self.as_bytes(), |w, b| w.write_u8(*b))
}
}
}
#[cfg(test)]