zcash_client_sqlite: Add tests for sqlite-backed ShardTree & fix revealed issues.

This commit is contained in:
Kris Nuttycombe 2023-06-15 13:50:07 -06:00
parent 425b5e01d7
commit 0a4236f725
6 changed files with 395 additions and 77 deletions

View File

@ -66,9 +66,7 @@ use zcash_client_backend::{
DecryptedOutput, TransferType, DecryptedOutput, TransferType,
}; };
use crate::{ use crate::{error::SqliteClientError, wallet::sapling::commitment_tree::SqliteShardStore};
error::SqliteClientError, wallet::sapling::commitment_tree::WalletDbSaplingShardStore,
};
#[cfg(feature = "unstable")] #[cfg(feature = "unstable")]
use { use {
@ -617,7 +615,8 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
impl<P: consensus::Parameters> WalletCommitmentTrees for WalletDb<rusqlite::Connection, P> { impl<P: consensus::Parameters> WalletCommitmentTrees for WalletDb<rusqlite::Connection, P> {
type Error = Either<io::Error, rusqlite::Error>; type Error = Either<io::Error, rusqlite::Error>;
type SaplingShardStore<'a> = WalletDbSaplingShardStore<'a, 'a>; type SaplingShardStore<'a> =
SqliteShardStore<&'a rusqlite::Transaction<'a>, sapling::Node, SAPLING_SHARD_HEIGHT>;
fn with_sapling_tree_mut<F, A, E>(&mut self, mut callback: F) -> Result<A, E> fn with_sapling_tree_mut<F, A, E>(&mut self, mut callback: F) -> Result<A, E>
where where
@ -634,7 +633,7 @@ impl<P: consensus::Parameters> WalletCommitmentTrees for WalletDb<rusqlite::Conn
.conn .conn
.transaction() .transaction()
.map_err(|e| ShardTreeError::Storage(Either::Right(e)))?; .map_err(|e| ShardTreeError::Storage(Either::Right(e)))?;
let shard_store = WalletDbSaplingShardStore::from_connection(&tx) let shard_store = SqliteShardStore::from_connection(&tx)
.map_err(|e| ShardTreeError::Storage(Either::Right(e)))?; .map_err(|e| ShardTreeError::Storage(Either::Right(e)))?;
let result = { let result = {
let mut shardtree = ShardTree::new(shard_store, 100); let mut shardtree = ShardTree::new(shard_store, 100);
@ -648,7 +647,8 @@ impl<P: consensus::Parameters> WalletCommitmentTrees for WalletDb<rusqlite::Conn
impl<'conn, P: consensus::Parameters> WalletCommitmentTrees for WalletDb<SqlTransaction<'conn>, P> { impl<'conn, P: consensus::Parameters> WalletCommitmentTrees for WalletDb<SqlTransaction<'conn>, P> {
type Error = Either<io::Error, rusqlite::Error>; type Error = Either<io::Error, rusqlite::Error>;
type SaplingShardStore<'a> = WalletDbSaplingShardStore<'a, 'a>; type SaplingShardStore<'a> =
SqliteShardStore<&'a rusqlite::Transaction<'a>, sapling::Node, SAPLING_SHARD_HEIGHT>;
fn with_sapling_tree_mut<F, A, E>(&mut self, mut callback: F) -> Result<A, E> fn with_sapling_tree_mut<F, A, E>(&mut self, mut callback: F) -> Result<A, E>
where where
@ -662,7 +662,7 @@ impl<'conn, P: consensus::Parameters> WalletCommitmentTrees for WalletDb<SqlTran
E: From<ShardTreeError<Either<io::Error, rusqlite::Error>>>, E: From<ShardTreeError<Either<io::Error, rusqlite::Error>>>,
{ {
let mut shardtree = ShardTree::new( let mut shardtree = ShardTree::new(
WalletDbSaplingShardStore::from_connection(self.conn.0) SqliteShardStore::from_connection(self.conn.0)
.map_err(|e| ShardTreeError::Storage(Either::Right(e)))?, .map_err(|e| ShardTreeError::Storage(Either::Right(e)))?,
100, 100,
); );

View File

@ -403,6 +403,7 @@ mod tests {
checkpoint_id INTEGER NOT NULL, checkpoint_id INTEGER NOT NULL,
mark_removed_position INTEGER NOT NULL, mark_removed_position INTEGER NOT NULL,
FOREIGN KEY (checkpoint_id) REFERENCES sapling_tree_checkpoints(checkpoint_id) FOREIGN KEY (checkpoint_id) REFERENCES sapling_tree_checkpoints(checkpoint_id)
ON DELETE CASCADE
)", )",
"CREATE TABLE sapling_tree_checkpoints ( "CREATE TABLE sapling_tree_checkpoints (
checkpoint_id INTEGER PRIMARY KEY, checkpoint_id INTEGER PRIMARY KEY,

View File

@ -20,7 +20,7 @@ use zcash_primitives::{
use crate::wallet::{ use crate::wallet::{
init::{migrations::received_notes_nullable_nf, WalletMigrationError}, init::{migrations::received_notes_nullable_nf, WalletMigrationError},
sapling::commitment_tree::WalletDbSaplingShardStore, sapling::commitment_tree::SqliteShardStore,
}; };
pub(super) const MIGRATION_ID: Uuid = Uuid::from_fields( pub(super) const MIGRATION_ID: Uuid = Uuid::from_fields(
@ -87,10 +87,14 @@ impl RusqliteMigration for Migration {
checkpoint_id INTEGER NOT NULL, checkpoint_id INTEGER NOT NULL,
mark_removed_position INTEGER NOT NULL, mark_removed_position INTEGER NOT NULL,
FOREIGN KEY (checkpoint_id) REFERENCES sapling_tree_checkpoints(checkpoint_id) FOREIGN KEY (checkpoint_id) REFERENCES sapling_tree_checkpoints(checkpoint_id)
ON DELETE CASCADE
);", );",
)?; )?;
let shard_store = WalletDbSaplingShardStore::from_connection(transaction)?; let shard_store =
SqliteShardStore::<_, sapling::Node, SAPLING_SHARD_HEIGHT>::from_connection(
transaction,
)?;
let mut shard_tree: ShardTree< let mut shard_tree: ShardTree<
_, _,
{ sapling::NOTE_COMMITMENT_TREE_DEPTH }, { sapling::NOTE_COMMITMENT_TREE_DEPTH },

View File

@ -1,36 +1,38 @@
use either::Either; use either::Either;
use rusqlite::{self, named_params, Connection, OptionalExtension}; use rusqlite::{self, named_params, OptionalExtension};
use std::{ use std::{
collections::BTreeSet, collections::BTreeSet,
io::{self, Cursor}, io::{self, Cursor},
ops::Deref, marker::PhantomData,
}; };
use incrementalmerkletree::{Address, Level, Position}; use incrementalmerkletree::{Address, Level, Position};
use shardtree::{Checkpoint, LocatedPrunableTree, PrunableTree, ShardStore, TreeState}; use shardtree::{Checkpoint, LocatedPrunableTree, PrunableTree, ShardStore, TreeState};
use zcash_primitives::{consensus::BlockHeight, merkle_tree::HashSer, sapling}; use zcash_primitives::{consensus::BlockHeight, merkle_tree::HashSer};
use zcash_client_backend::data_api::SAPLING_SHARD_HEIGHT;
use crate::serialization::{read_shard, write_shard_v1}; use crate::serialization::{read_shard, write_shard_v1};
const SHARD_ROOT_LEVEL: Level = Level::new(SAPLING_SHARD_HEIGHT); pub struct SqliteShardStore<C, H, const SHARD_HEIGHT: u8> {
pub(crate) conn: C,
pub struct WalletDbSaplingShardStore<'conn, 'a> { _hash_type: PhantomData<H>,
pub(crate) conn: &'a rusqlite::Transaction<'conn>,
} }
impl<'conn, 'a> WalletDbSaplingShardStore<'conn, 'a> { impl<C, H, const SHARD_HEIGHT: u8> SqliteShardStore<C, H, SHARD_HEIGHT> {
pub(crate) fn from_connection( const SHARD_ROOT_LEVEL: Level = Level::new(SHARD_HEIGHT);
conn: &'a rusqlite::Transaction<'conn>,
) -> Result<Self, rusqlite::Error> { pub(crate) fn from_connection(conn: C) -> Result<Self, rusqlite::Error> {
Ok(WalletDbSaplingShardStore { conn }) Ok(SqliteShardStore {
conn,
_hash_type: PhantomData,
})
} }
} }
impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> { impl<'conn, 'a: 'conn, H: HashSer, const SHARD_HEIGHT: u8> ShardStore
type H = sapling::Node; for SqliteShardStore<&'a rusqlite::Transaction<'conn>, H, SHARD_HEIGHT>
{
type H = H;
type CheckpointId = BlockHeight; type CheckpointId = BlockHeight;
type Error = Either<io::Error, rusqlite::Error>; type Error = Either<io::Error, rusqlite::Error>;
@ -42,7 +44,7 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
} }
fn last_shard(&self) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> { fn last_shard(&self) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> {
last_shard(self.conn) last_shard(self.conn, Self::SHARD_ROOT_LEVEL)
} }
fn put_shard(&mut self, subtree: LocatedPrunableTree<Self::H>) -> Result<(), Self::Error> { fn put_shard(&mut self, subtree: LocatedPrunableTree<Self::H>) -> Result<(), Self::Error> {
@ -50,7 +52,7 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
} }
fn get_shard_roots(&self) -> Result<Vec<Address>, Self::Error> { fn get_shard_roots(&self) -> Result<Vec<Address>, Self::Error> {
get_shard_roots(self.conn) get_shard_roots(self.conn, Self::SHARD_ROOT_LEVEL)
} }
fn truncate(&mut self, from: Address) -> Result<(), Self::Error> { fn truncate(&mut self, from: Address) -> Result<(), Self::Error> {
@ -66,11 +68,11 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
} }
fn min_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> { fn min_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
todo!() min_checkpoint_id(self.conn)
} }
fn max_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> { fn max_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
todo!() max_checkpoint_id(self.conn)
} }
fn add_checkpoint( fn add_checkpoint(
@ -99,11 +101,11 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
get_checkpoint(self.conn, *checkpoint_id) get_checkpoint(self.conn, *checkpoint_id)
} }
fn with_checkpoints<F>(&mut self, _limit: usize, _callback: F) -> Result<(), Self::Error> fn with_checkpoints<F>(&mut self, limit: usize, callback: F) -> Result<(), Self::Error>
where where
F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>, F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>,
{ {
todo!() with_checkpoints(self.conn, limit, callback)
} }
fn update_checkpoint_with<F>( fn update_checkpoint_with<F>(
@ -129,12 +131,128 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> {
} }
} }
impl<H: HashSer, const SHARD_HEIGHT: u8> ShardStore
for SqliteShardStore<rusqlite::Connection, H, SHARD_HEIGHT>
{
type H = H;
type CheckpointId = BlockHeight;
type Error = Either<io::Error, rusqlite::Error>;
fn get_shard(
&self,
shard_root: Address,
) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> {
get_shard(&self.conn, shard_root)
}
fn last_shard(&self) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> {
last_shard(&self.conn, Self::SHARD_ROOT_LEVEL)
}
fn put_shard(&mut self, subtree: LocatedPrunableTree<Self::H>) -> Result<(), Self::Error> {
let tx = self.conn.transaction().map_err(Either::Right)?;
put_shard(&tx, subtree)?;
tx.commit().map_err(Either::Right)?;
Ok(())
}
fn get_shard_roots(&self) -> Result<Vec<Address>, Self::Error> {
get_shard_roots(&self.conn, Self::SHARD_ROOT_LEVEL)
}
fn truncate(&mut self, from: Address) -> Result<(), Self::Error> {
truncate(&self.conn, from)
}
fn get_cap(&self) -> Result<PrunableTree<Self::H>, Self::Error> {
get_cap(&self.conn)
}
fn put_cap(&mut self, cap: PrunableTree<Self::H>) -> Result<(), Self::Error> {
put_cap(&self.conn, cap)
}
fn min_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
min_checkpoint_id(&self.conn)
}
fn max_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
max_checkpoint_id(&self.conn)
}
fn add_checkpoint(
&mut self,
checkpoint_id: Self::CheckpointId,
checkpoint: Checkpoint,
) -> Result<(), Self::Error> {
let tx = self.conn.transaction().map_err(Either::Right)?;
add_checkpoint(&tx, checkpoint_id, checkpoint)?;
tx.commit().map_err(Either::Right)
}
fn checkpoint_count(&self) -> Result<usize, Self::Error> {
checkpoint_count(&self.conn)
}
fn get_checkpoint_at_depth(
&self,
checkpoint_depth: usize,
) -> Result<Option<(Self::CheckpointId, Checkpoint)>, Self::Error> {
get_checkpoint_at_depth(&self.conn, checkpoint_depth)
}
fn get_checkpoint(
&self,
checkpoint_id: &Self::CheckpointId,
) -> Result<Option<Checkpoint>, Self::Error> {
get_checkpoint(&self.conn, *checkpoint_id)
}
fn with_checkpoints<F>(&mut self, limit: usize, callback: F) -> Result<(), Self::Error>
where
F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>,
{
let tx = self.conn.transaction().map_err(Either::Right)?;
with_checkpoints(&tx, limit, callback)?;
tx.commit().map_err(Either::Right)
}
fn update_checkpoint_with<F>(
&mut self,
checkpoint_id: &Self::CheckpointId,
update: F,
) -> Result<bool, Self::Error>
where
F: Fn(&mut Checkpoint) -> Result<(), Self::Error>,
{
let tx = self.conn.transaction().map_err(Either::Right)?;
let result = update_checkpoint_with(&tx, *checkpoint_id, update)?;
tx.commit().map_err(Either::Right)?;
Ok(result)
}
fn remove_checkpoint(&mut self, checkpoint_id: &Self::CheckpointId) -> Result<(), Self::Error> {
let tx = self.conn.transaction().map_err(Either::Right)?;
remove_checkpoint(&tx, *checkpoint_id)?;
tx.commit().map_err(Either::Right)
}
fn truncate_checkpoints(
&mut self,
checkpoint_id: &Self::CheckpointId,
) -> Result<(), Self::Error> {
let tx = self.conn.transaction().map_err(Either::Right)?;
truncate_checkpoints(&tx, *checkpoint_id)?;
tx.commit().map_err(Either::Right)
}
}
type Error = Either<io::Error, rusqlite::Error>; type Error = Either<io::Error, rusqlite::Error>;
pub(crate) fn get_shard( pub(crate) fn get_shard<H: HashSer>(
conn: &rusqlite::Connection, conn: &rusqlite::Connection,
shard_root: Address, shard_root: Address,
) -> Result<Option<LocatedPrunableTree<sapling::Node>>, Error> { ) -> Result<Option<LocatedPrunableTree<H>>, Error> {
conn.query_row( conn.query_row(
"SELECT shard_data "SELECT shard_data
FROM sapling_tree_shards FROM sapling_tree_shards
@ -151,9 +269,10 @@ pub(crate) fn get_shard(
.transpose() .transpose()
} }
pub(crate) fn last_shard( pub(crate) fn last_shard<H: HashSer>(
conn: &rusqlite::Connection, conn: &rusqlite::Connection,
) -> Result<Option<LocatedPrunableTree<sapling::Node>>, Error> { shard_root_level: Level,
) -> Result<Option<LocatedPrunableTree<H>>, Error> {
conn.query_row( conn.query_row(
"SELECT shard_index, shard_data "SELECT shard_index, shard_data
FROM sapling_tree_shards FROM sapling_tree_shards
@ -169,16 +288,16 @@ pub(crate) fn last_shard(
.optional() .optional()
.map_err(Either::Right)? .map_err(Either::Right)?
.map(|(shard_index, shard_data)| { .map(|(shard_index, shard_data)| {
let shard_root = Address::from_parts(SHARD_ROOT_LEVEL, shard_index); let shard_root = Address::from_parts(shard_root_level, shard_index);
let shard_tree = read_shard(&mut Cursor::new(shard_data)).map_err(Either::Left)?; let shard_tree = read_shard(&mut Cursor::new(shard_data)).map_err(Either::Left)?;
Ok(LocatedPrunableTree::from_parts(shard_root, shard_tree)) Ok(LocatedPrunableTree::from_parts(shard_root, shard_tree))
}) })
.transpose() .transpose()
} }
pub(crate) fn put_shard( pub(crate) fn put_shard<H: HashSer>(
conn: &rusqlite::Connection, conn: &rusqlite::Transaction<'_>,
subtree: LocatedPrunableTree<sapling::Node>, subtree: LocatedPrunableTree<H>,
) -> Result<(), Error> { ) -> Result<(), Error> {
let subtree_root_hash = subtree let subtree_root_hash = subtree
.root() .root()
@ -196,26 +315,31 @@ pub(crate) fn put_shard(
let mut subtree_data = vec![]; let mut subtree_data = vec![];
write_shard_v1(&mut subtree_data, subtree.root()).map_err(Either::Left)?; write_shard_v1(&mut subtree_data, subtree.root()).map_err(Either::Left)?;
conn.prepare_cached( let mut stmt_put_shard = conn
"INSERT INTO sapling_tree_shards (shard_index, root_hash, shard_data) .prepare_cached(
VALUES (:shard_index, :root_hash, :shard_data) "INSERT INTO sapling_tree_shards (shard_index, root_hash, shard_data)
ON CONFLICT (shard_index) DO UPDATE VALUES (:shard_index, :root_hash, :shard_data)
SET root_hash = :root_hash, ON CONFLICT (shard_index) DO UPDATE
shard_data = :shard_data", SET root_hash = :root_hash,
) shard_data = :shard_data",
.and_then(|mut stmt_put_shard| { )
stmt_put_shard.execute(named_params![ .map_err(Either::Right)?;
stmt_put_shard
.execute(named_params![
":shard_index": subtree.root_addr().index(), ":shard_index": subtree.root_addr().index(),
":root_hash": subtree_root_hash, ":root_hash": subtree_root_hash,
":shard_data": subtree_data ":shard_data": subtree_data
]) ])
}) .map_err(Either::Right)?;
.map_err(Either::Right)?;
Ok(()) Ok(())
} }
pub(crate) fn get_shard_roots(conn: &rusqlite::Connection) -> Result<Vec<Address>, Error> { pub(crate) fn get_shard_roots(
conn: &rusqlite::Connection,
shard_root_level: Level,
) -> Result<Vec<Address>, Error> {
let mut stmt = conn let mut stmt = conn
.prepare("SELECT shard_index FROM sapling_tree_shards ORDER BY shard_index") .prepare("SELECT shard_index FROM sapling_tree_shards ORDER BY shard_index")
.map_err(Either::Right)?; .map_err(Either::Right)?;
@ -224,14 +348,14 @@ pub(crate) fn get_shard_roots(conn: &rusqlite::Connection) -> Result<Vec<Address
let mut res = vec![]; let mut res = vec![];
while let Some(row) = rows.next().map_err(Either::Right)? { while let Some(row) = rows.next().map_err(Either::Right)? {
res.push(Address::from_parts( res.push(Address::from_parts(
SHARD_ROOT_LEVEL, shard_root_level,
row.get(0).map_err(Either::Right)?, row.get(0).map_err(Either::Right)?,
)); ));
} }
Ok(res) Ok(res)
} }
pub(crate) fn truncate(conn: &rusqlite::Transaction<'_>, from: Address) -> Result<(), Error> { pub(crate) fn truncate(conn: &rusqlite::Connection, from: Address) -> Result<(), Error> {
conn.execute( conn.execute(
"DELETE FROM sapling_tree_shards WHERE shard_index >= ?", "DELETE FROM sapling_tree_shards WHERE shard_index >= ?",
[from.index()], [from.index()],
@ -240,7 +364,7 @@ pub(crate) fn truncate(conn: &rusqlite::Transaction<'_>, from: Address) -> Resul
.map(|_| ()) .map(|_| ())
} }
pub(crate) fn get_cap(conn: &rusqlite::Connection) -> Result<PrunableTree<sapling::Node>, Error> { pub(crate) fn get_cap<H: HashSer>(conn: &rusqlite::Connection) -> Result<PrunableTree<H>, Error> {
conn.query_row("SELECT cap_data FROM sapling_tree_cap", [], |row| { conn.query_row("SELECT cap_data FROM sapling_tree_cap", [], |row| {
row.get::<_, Vec<u8>>(0) row.get::<_, Vec<u8>>(0)
}) })
@ -252,9 +376,9 @@ pub(crate) fn get_cap(conn: &rusqlite::Connection) -> Result<PrunableTree<saplin
) )
} }
pub(crate) fn put_cap( pub(crate) fn put_cap<H: HashSer>(
conn: &rusqlite::Transaction<'_>, conn: &rusqlite::Connection,
cap: PrunableTree<sapling::Node>, cap: PrunableTree<H>,
) -> Result<(), Error> { ) -> Result<(), Error> {
let mut stmt = conn let mut stmt = conn
.prepare_cached( .prepare_cached(
@ -272,22 +396,62 @@ pub(crate) fn put_cap(
Ok(()) Ok(())
} }
pub(crate) fn min_checkpoint_id(conn: &rusqlite::Connection) -> Result<Option<BlockHeight>, Error> {
conn.query_row(
"SELECT MIN(checkpoint_id) FROM sapling_tree_checkpoints",
[],
|row| {
row.get::<_, Option<u32>>(0)
.map(|opt| opt.map(BlockHeight::from))
},
)
.map_err(Either::Right)
}
pub(crate) fn max_checkpoint_id(conn: &rusqlite::Connection) -> Result<Option<BlockHeight>, Error> {
conn.query_row(
"SELECT MAX(checkpoint_id) FROM sapling_tree_checkpoints",
[],
|row| {
row.get::<_, Option<u32>>(0)
.map(|opt| opt.map(BlockHeight::from))
},
)
.map_err(Either::Right)
}
pub(crate) fn add_checkpoint( pub(crate) fn add_checkpoint(
conn: &rusqlite::Transaction<'_>, conn: &rusqlite::Transaction<'_>,
checkpoint_id: BlockHeight, checkpoint_id: BlockHeight,
checkpoint: Checkpoint, checkpoint: Checkpoint,
) -> Result<(), Error> { ) -> Result<(), Error> {
conn.prepare_cached( let mut stmt_insert_checkpoint = conn
"INSERT INTO sapling_tree_checkpoints (checkpoint_id, position) .prepare_cached(
VALUES (:checkpoint_id, :position)", "INSERT INTO sapling_tree_checkpoints (checkpoint_id, position)
) VALUES (:checkpoint_id, :position)",
.and_then(|mut stmt_insert_checkpoint| { )
stmt_insert_checkpoint.execute(named_params![ .map_err(Either::Right)?;
stmt_insert_checkpoint
.execute(named_params![
":checkpoint_id": u32::from(checkpoint_id), ":checkpoint_id": u32::from(checkpoint_id),
":position": checkpoint.position().map(u64::from) ":position": checkpoint.position().map(u64::from)
]) ])
}) .map_err(Either::Right)?;
.map_err(Either::Right)?;
let mut stmt_insert_mark_removed = conn.prepare_cached(
"INSERT INTO sapling_tree_checkpoint_marks_removed (checkpoint_id, mark_removed_position)
VALUES (:checkpoint_id, :position)",
).map_err(Either::Right)?;
for pos in checkpoint.marks_removed() {
stmt_insert_mark_removed
.execute(named_params![
":checkpoint_id": u32::from(checkpoint_id),
":position": u64::from(*pos)
])
.map_err(Either::Right)?;
}
Ok(()) Ok(())
} }
@ -299,10 +463,10 @@ pub(crate) fn checkpoint_count(conn: &rusqlite::Connection) -> Result<usize, Err
.map_err(Either::Right) .map_err(Either::Right)
} }
pub(crate) fn get_checkpoint<C: Deref<Target = Connection>>( pub(crate) fn get_checkpoint(
conn: &C, conn: &rusqlite::Connection,
checkpoint_id: BlockHeight, checkpoint_id: BlockHeight,
) -> Result<Option<Checkpoint>, Either<io::Error, rusqlite::Error>> { ) -> Result<Option<Checkpoint>, Error> {
let checkpoint_position = conn let checkpoint_position = conn
.query_row( .query_row(
"SELECT position "SELECT position
@ -347,10 +511,14 @@ pub(crate) fn get_checkpoint<C: Deref<Target = Connection>>(
.transpose() .transpose()
} }
pub(crate) fn get_checkpoint_at_depth<C: Deref<Target = Connection>>( pub(crate) fn get_checkpoint_at_depth(
conn: &C, conn: &rusqlite::Connection,
checkpoint_depth: usize, checkpoint_depth: usize,
) -> Result<Option<(BlockHeight, Checkpoint)>, Either<io::Error, rusqlite::Error>> { ) -> Result<Option<(BlockHeight, Checkpoint)>, Error> {
if checkpoint_depth == 0 {
return Ok(None);
}
let checkpoint_parts = conn let checkpoint_parts = conn
.query_row( .query_row(
"SELECT checkpoint_id, position "SELECT checkpoint_id, position
@ -358,7 +526,7 @@ pub(crate) fn get_checkpoint_at_depth<C: Deref<Target = Connection>>(
ORDER BY checkpoint_id DESC ORDER BY checkpoint_id DESC
LIMIT 1 LIMIT 1
OFFSET :offset", OFFSET :offset",
named_params![":offset": checkpoint_depth], named_params![":offset": checkpoint_depth - 1],
|row| { |row| {
let checkpoint_id: u32 = row.get(0)?; let checkpoint_id: u32 = row.get(0)?;
let position: Option<u64> = row.get(1)?; let position: Option<u64> = row.get(1)?;
@ -404,6 +572,62 @@ pub(crate) fn get_checkpoint_at_depth<C: Deref<Target = Connection>>(
.transpose() .transpose()
} }
pub(crate) fn with_checkpoints<F>(
conn: &rusqlite::Transaction<'_>,
limit: usize,
mut callback: F,
) -> Result<(), Error>
where
F: FnMut(&BlockHeight, &Checkpoint) -> Result<(), Error>,
{
let mut stmt_get_checkpoints = conn
.prepare_cached(
"SELECT checkpoint_id, position
FROM sapling_tree_checkpoints
LIMIT :limit",
)
.map_err(Either::Right)?;
let mut stmt_get_checkpoint_marks_removed = conn
.prepare_cached(
"SELECT mark_removed_position
FROM sapling_tree_checkpoint_marks_removed
WHERE checkpoint_id = :checkpoint_id",
)
.map_err(Either::Right)?;
let mut rows = stmt_get_checkpoints
.query(named_params![":limit": limit])
.map_err(Either::Right)?;
while let Some(row) = rows.next().map_err(Either::Right)? {
let checkpoint_id = row.get::<_, u32>(0).map_err(Either::Right)?;
let tree_state = row
.get::<_, Option<u64>>(1)
.map(|opt| opt.map_or_else(|| TreeState::Empty, |p| TreeState::AtPosition(p.into())))
.map_err(Either::Right)?;
let mut mark_removed_rows = stmt_get_checkpoint_marks_removed
.query(named_params![":checkpoint_id": checkpoint_id])
.map_err(Either::Right)?;
let mut marks_removed = BTreeSet::new();
while let Some(mr_row) = mark_removed_rows.next().map_err(Either::Right)? {
let mark_removed_position = mr_row
.get::<_, u64>(0)
.map(Position::from)
.map_err(Either::Right)?;
marks_removed.insert(mark_removed_position);
}
callback(
&BlockHeight::from(checkpoint_id),
&Checkpoint::from_parts(tree_state, marks_removed),
)?
}
Ok(())
}
pub(crate) fn update_checkpoint_with<F>( pub(crate) fn update_checkpoint_with<F>(
conn: &rusqlite::Transaction<'_>, conn: &rusqlite::Transaction<'_>,
checkpoint_id: BlockHeight, checkpoint_id: BlockHeight,
@ -426,11 +650,17 @@ pub(crate) fn remove_checkpoint(
conn: &rusqlite::Transaction<'_>, conn: &rusqlite::Transaction<'_>,
checkpoint_id: BlockHeight, checkpoint_id: BlockHeight,
) -> Result<(), Error> { ) -> Result<(), Error> {
conn.execute( // sapling_tree_checkpoints is constructed with `ON DELETE CASCADE`
"DELETE FROM sapling_tree_checkpoints WHERE checkpoint_id = ?", let mut stmt_delete_checkpoint = conn
[u32::from(checkpoint_id)], .prepare_cached(
) "DELETE FROM sapling_tree_checkpoints
.map_err(Either::Right)?; WHERE checkpoint_id = :checkpoint_id",
)
.map_err(Either::Right)?;
stmt_delete_checkpoint
.execute(named_params![":checkpoint_id": u32::from(checkpoint_id),])
.map_err(Either::Right)?;
Ok(()) Ok(())
} }
@ -452,3 +682,62 @@ pub(crate) fn truncate_checkpoints(
.map_err(Either::Right)?; .map_err(Either::Right)?;
Ok(()) Ok(())
} }
#[cfg(test)]
mod tests {
use tempfile::NamedTempFile;
use incrementalmerkletree::testing::{
check_append, check_checkpoint_rewind, check_remove_mark, check_rewind_remove_mark,
check_root_hashes, check_witness_consistency, check_witnesses,
};
use shardtree::ShardTree;
use super::SqliteShardStore;
use crate::{tests, wallet::init::init_wallet_db, WalletDb};
fn new_tree(m: usize) -> ShardTree<SqliteShardStore<rusqlite::Connection, String, 3>, 4, 3> {
let data_file = NamedTempFile::new().unwrap();
let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap();
data_file.keep().unwrap();
init_wallet_db(&mut db_data, None).unwrap();
let store = SqliteShardStore::<_, String, 3>::from_connection(db_data.conn).unwrap();
ShardTree::new(store, m)
}
#[test]
fn append() {
check_append(new_tree);
}
#[test]
fn root_hashes() {
check_root_hashes(new_tree);
}
#[test]
fn witnesses() {
check_witnesses(new_tree);
}
#[test]
fn witness_consistency() {
check_witness_consistency(new_tree);
}
#[test]
fn checkpoint_rewind() {
check_checkpoint_rewind(new_tree);
}
#[test]
fn remove_mark() {
check_remove_mark(new_tree);
}
#[test]
fn rewind_remove_mark() {
check_rewind_remove_mark(new_tree);
}
}

View File

@ -627,6 +627,12 @@ pub mod testing {
) )
}) })
} }
impl incrementalmerkletree::testing::TestCheckpoint for BlockHeight {
fn from_u64(value: u64) -> Self {
BlockHeight(u32::try_from(value).expect("Test checkpoint ids do not exceed 32 bits"))
}
}
} }
#[cfg(test)] #[cfg(test)]

View File

@ -292,6 +292,7 @@ pub mod testing {
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use incrementalmerkletree::frontier::testing::TestNode; use incrementalmerkletree::frontier::testing::TestNode;
use std::io::{self, Read, Write}; use std::io::{self, Read, Write};
use zcash_encoding::Vector;
use super::HashSer; use super::HashSer;
@ -304,6 +305,23 @@ pub mod testing {
writer.write_u64::<LittleEndian>(self.0) writer.write_u64::<LittleEndian>(self.0)
} }
} }
impl HashSer for String {
fn read<R: Read>(reader: R) -> io::Result<String> {
Vector::read(reader, |r| r.read_u8()).and_then(|xs| {
String::from_utf8(xs).map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Not a valid utf8 string: {:?}", e),
)
})
})
}
fn write<W: Write>(&self, writer: W) -> io::Result<()> {
Vector::write(writer, self.as_bytes(), |w, b| w.write_u8(*b))
}
}
} }
#[cfg(test)] #[cfg(test)]