From 0a4236f725a3f46fce7247fa0860ba44871eb8d8 Mon Sep 17 00:00:00 2001 From: Kris Nuttycombe Date: Thu, 15 Jun 2023 13:50:07 -0600 Subject: [PATCH] zcash_client_sqlite: Add tests for sqlite-backed ShardTree & fix revealed issues. --- zcash_client_sqlite/src/lib.rs | 14 +- zcash_client_sqlite/src/wallet/init.rs | 1 + .../init/migrations/shardtree_support.rs | 8 +- .../src/wallet/sapling/commitment_tree.rs | 425 +++++++++++++++--- zcash_primitives/src/consensus.rs | 6 + zcash_primitives/src/merkle_tree.rs | 18 + 6 files changed, 395 insertions(+), 77 deletions(-) diff --git a/zcash_client_sqlite/src/lib.rs b/zcash_client_sqlite/src/lib.rs index 0d46a1ce0..9df48d5ca 100644 --- a/zcash_client_sqlite/src/lib.rs +++ b/zcash_client_sqlite/src/lib.rs @@ -66,9 +66,7 @@ use zcash_client_backend::{ DecryptedOutput, TransferType, }; -use crate::{ - error::SqliteClientError, wallet::sapling::commitment_tree::WalletDbSaplingShardStore, -}; +use crate::{error::SqliteClientError, wallet::sapling::commitment_tree::SqliteShardStore}; #[cfg(feature = "unstable")] use { @@ -617,7 +615,8 @@ impl WalletWrite for WalletDb impl WalletCommitmentTrees for WalletDb { type Error = Either; - type SaplingShardStore<'a> = WalletDbSaplingShardStore<'a, 'a>; + type SaplingShardStore<'a> = + SqliteShardStore<&'a rusqlite::Transaction<'a>, sapling::Node, SAPLING_SHARD_HEIGHT>; fn with_sapling_tree_mut(&mut self, mut callback: F) -> Result where @@ -634,7 +633,7 @@ impl WalletCommitmentTrees for WalletDb WalletCommitmentTrees for WalletDb WalletCommitmentTrees for WalletDb, P> { type Error = Either; - type SaplingShardStore<'a> = WalletDbSaplingShardStore<'a, 'a>; + type SaplingShardStore<'a> = + SqliteShardStore<&'a rusqlite::Transaction<'a>, sapling::Node, SAPLING_SHARD_HEIGHT>; fn with_sapling_tree_mut(&mut self, mut callback: F) -> Result where @@ -662,7 +662,7 @@ impl<'conn, P: consensus::Parameters> WalletCommitmentTrees for WalletDb>>, { let mut shardtree = ShardTree::new( - WalletDbSaplingShardStore::from_connection(self.conn.0) + SqliteShardStore::from_connection(self.conn.0) .map_err(|e| ShardTreeError::Storage(Either::Right(e)))?, 100, ); diff --git a/zcash_client_sqlite/src/wallet/init.rs b/zcash_client_sqlite/src/wallet/init.rs index a98806ee6..c730aea73 100644 --- a/zcash_client_sqlite/src/wallet/init.rs +++ b/zcash_client_sqlite/src/wallet/init.rs @@ -403,6 +403,7 @@ mod tests { checkpoint_id INTEGER NOT NULL, mark_removed_position INTEGER NOT NULL, FOREIGN KEY (checkpoint_id) REFERENCES sapling_tree_checkpoints(checkpoint_id) + ON DELETE CASCADE )", "CREATE TABLE sapling_tree_checkpoints ( checkpoint_id INTEGER PRIMARY KEY, diff --git a/zcash_client_sqlite/src/wallet/init/migrations/shardtree_support.rs b/zcash_client_sqlite/src/wallet/init/migrations/shardtree_support.rs index ded5b2d1f..6a597d56c 100644 --- a/zcash_client_sqlite/src/wallet/init/migrations/shardtree_support.rs +++ b/zcash_client_sqlite/src/wallet/init/migrations/shardtree_support.rs @@ -20,7 +20,7 @@ use zcash_primitives::{ use crate::wallet::{ init::{migrations::received_notes_nullable_nf, WalletMigrationError}, - sapling::commitment_tree::WalletDbSaplingShardStore, + sapling::commitment_tree::SqliteShardStore, }; pub(super) const MIGRATION_ID: Uuid = Uuid::from_fields( @@ -87,10 +87,14 @@ impl RusqliteMigration for Migration { checkpoint_id INTEGER NOT NULL, mark_removed_position INTEGER NOT NULL, FOREIGN KEY (checkpoint_id) REFERENCES sapling_tree_checkpoints(checkpoint_id) + ON DELETE CASCADE );", )?; - let shard_store = WalletDbSaplingShardStore::from_connection(transaction)?; + let shard_store = + SqliteShardStore::<_, sapling::Node, SAPLING_SHARD_HEIGHT>::from_connection( + transaction, + )?; let mut shard_tree: ShardTree< _, { sapling::NOTE_COMMITMENT_TREE_DEPTH }, diff --git a/zcash_client_sqlite/src/wallet/sapling/commitment_tree.rs b/zcash_client_sqlite/src/wallet/sapling/commitment_tree.rs index 5ea9e29fe..f685350c1 100644 --- a/zcash_client_sqlite/src/wallet/sapling/commitment_tree.rs +++ b/zcash_client_sqlite/src/wallet/sapling/commitment_tree.rs @@ -1,36 +1,38 @@ use either::Either; -use rusqlite::{self, named_params, Connection, OptionalExtension}; +use rusqlite::{self, named_params, OptionalExtension}; use std::{ collections::BTreeSet, io::{self, Cursor}, - ops::Deref, + marker::PhantomData, }; use incrementalmerkletree::{Address, Level, Position}; use shardtree::{Checkpoint, LocatedPrunableTree, PrunableTree, ShardStore, TreeState}; -use zcash_primitives::{consensus::BlockHeight, merkle_tree::HashSer, sapling}; - -use zcash_client_backend::data_api::SAPLING_SHARD_HEIGHT; +use zcash_primitives::{consensus::BlockHeight, merkle_tree::HashSer}; use crate::serialization::{read_shard, write_shard_v1}; -const SHARD_ROOT_LEVEL: Level = Level::new(SAPLING_SHARD_HEIGHT); - -pub struct WalletDbSaplingShardStore<'conn, 'a> { - pub(crate) conn: &'a rusqlite::Transaction<'conn>, +pub struct SqliteShardStore { + pub(crate) conn: C, + _hash_type: PhantomData, } -impl<'conn, 'a> WalletDbSaplingShardStore<'conn, 'a> { - pub(crate) fn from_connection( - conn: &'a rusqlite::Transaction<'conn>, - ) -> Result { - Ok(WalletDbSaplingShardStore { conn }) +impl SqliteShardStore { + const SHARD_ROOT_LEVEL: Level = Level::new(SHARD_HEIGHT); + + pub(crate) fn from_connection(conn: C) -> Result { + Ok(SqliteShardStore { + conn, + _hash_type: PhantomData, + }) } } -impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> { - type H = sapling::Node; +impl<'conn, 'a: 'conn, H: HashSer, const SHARD_HEIGHT: u8> ShardStore + for SqliteShardStore<&'a rusqlite::Transaction<'conn>, H, SHARD_HEIGHT> +{ + type H = H; type CheckpointId = BlockHeight; type Error = Either; @@ -42,7 +44,7 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> { } fn last_shard(&self) -> Result>, Self::Error> { - last_shard(self.conn) + last_shard(self.conn, Self::SHARD_ROOT_LEVEL) } fn put_shard(&mut self, subtree: LocatedPrunableTree) -> Result<(), Self::Error> { @@ -50,7 +52,7 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> { } fn get_shard_roots(&self) -> Result, Self::Error> { - get_shard_roots(self.conn) + get_shard_roots(self.conn, Self::SHARD_ROOT_LEVEL) } fn truncate(&mut self, from: Address) -> Result<(), Self::Error> { @@ -66,11 +68,11 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> { } fn min_checkpoint_id(&self) -> Result, Self::Error> { - todo!() + min_checkpoint_id(self.conn) } fn max_checkpoint_id(&self) -> Result, Self::Error> { - todo!() + max_checkpoint_id(self.conn) } fn add_checkpoint( @@ -99,11 +101,11 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> { get_checkpoint(self.conn, *checkpoint_id) } - fn with_checkpoints(&mut self, _limit: usize, _callback: F) -> Result<(), Self::Error> + fn with_checkpoints(&mut self, limit: usize, callback: F) -> Result<(), Self::Error> where F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>, { - todo!() + with_checkpoints(self.conn, limit, callback) } fn update_checkpoint_with( @@ -129,12 +131,128 @@ impl<'conn, 'a: 'conn> ShardStore for WalletDbSaplingShardStore<'conn, 'a> { } } +impl ShardStore + for SqliteShardStore +{ + type H = H; + type CheckpointId = BlockHeight; + type Error = Either; + + fn get_shard( + &self, + shard_root: Address, + ) -> Result>, Self::Error> { + get_shard(&self.conn, shard_root) + } + + fn last_shard(&self) -> Result>, Self::Error> { + last_shard(&self.conn, Self::SHARD_ROOT_LEVEL) + } + + fn put_shard(&mut self, subtree: LocatedPrunableTree) -> Result<(), Self::Error> { + let tx = self.conn.transaction().map_err(Either::Right)?; + put_shard(&tx, subtree)?; + tx.commit().map_err(Either::Right)?; + Ok(()) + } + + fn get_shard_roots(&self) -> Result, Self::Error> { + get_shard_roots(&self.conn, Self::SHARD_ROOT_LEVEL) + } + + fn truncate(&mut self, from: Address) -> Result<(), Self::Error> { + truncate(&self.conn, from) + } + + fn get_cap(&self) -> Result, Self::Error> { + get_cap(&self.conn) + } + + fn put_cap(&mut self, cap: PrunableTree) -> Result<(), Self::Error> { + put_cap(&self.conn, cap) + } + + fn min_checkpoint_id(&self) -> Result, Self::Error> { + min_checkpoint_id(&self.conn) + } + + fn max_checkpoint_id(&self) -> Result, Self::Error> { + max_checkpoint_id(&self.conn) + } + + fn add_checkpoint( + &mut self, + checkpoint_id: Self::CheckpointId, + checkpoint: Checkpoint, + ) -> Result<(), Self::Error> { + let tx = self.conn.transaction().map_err(Either::Right)?; + add_checkpoint(&tx, checkpoint_id, checkpoint)?; + tx.commit().map_err(Either::Right) + } + + fn checkpoint_count(&self) -> Result { + checkpoint_count(&self.conn) + } + + fn get_checkpoint_at_depth( + &self, + checkpoint_depth: usize, + ) -> Result, Self::Error> { + get_checkpoint_at_depth(&self.conn, checkpoint_depth) + } + + fn get_checkpoint( + &self, + checkpoint_id: &Self::CheckpointId, + ) -> Result, Self::Error> { + get_checkpoint(&self.conn, *checkpoint_id) + } + + fn with_checkpoints(&mut self, limit: usize, callback: F) -> Result<(), Self::Error> + where + F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>, + { + let tx = self.conn.transaction().map_err(Either::Right)?; + with_checkpoints(&tx, limit, callback)?; + tx.commit().map_err(Either::Right) + } + + fn update_checkpoint_with( + &mut self, + checkpoint_id: &Self::CheckpointId, + update: F, + ) -> Result + where + F: Fn(&mut Checkpoint) -> Result<(), Self::Error>, + { + let tx = self.conn.transaction().map_err(Either::Right)?; + let result = update_checkpoint_with(&tx, *checkpoint_id, update)?; + tx.commit().map_err(Either::Right)?; + Ok(result) + } + + fn remove_checkpoint(&mut self, checkpoint_id: &Self::CheckpointId) -> Result<(), Self::Error> { + let tx = self.conn.transaction().map_err(Either::Right)?; + remove_checkpoint(&tx, *checkpoint_id)?; + tx.commit().map_err(Either::Right) + } + + fn truncate_checkpoints( + &mut self, + checkpoint_id: &Self::CheckpointId, + ) -> Result<(), Self::Error> { + let tx = self.conn.transaction().map_err(Either::Right)?; + truncate_checkpoints(&tx, *checkpoint_id)?; + tx.commit().map_err(Either::Right) + } +} + type Error = Either; -pub(crate) fn get_shard( +pub(crate) fn get_shard( conn: &rusqlite::Connection, shard_root: Address, -) -> Result>, Error> { +) -> Result>, Error> { conn.query_row( "SELECT shard_data FROM sapling_tree_shards @@ -151,9 +269,10 @@ pub(crate) fn get_shard( .transpose() } -pub(crate) fn last_shard( +pub(crate) fn last_shard( conn: &rusqlite::Connection, -) -> Result>, Error> { + shard_root_level: Level, +) -> Result>, Error> { conn.query_row( "SELECT shard_index, shard_data FROM sapling_tree_shards @@ -169,16 +288,16 @@ pub(crate) fn last_shard( .optional() .map_err(Either::Right)? .map(|(shard_index, shard_data)| { - let shard_root = Address::from_parts(SHARD_ROOT_LEVEL, shard_index); + let shard_root = Address::from_parts(shard_root_level, shard_index); let shard_tree = read_shard(&mut Cursor::new(shard_data)).map_err(Either::Left)?; Ok(LocatedPrunableTree::from_parts(shard_root, shard_tree)) }) .transpose() } -pub(crate) fn put_shard( - conn: &rusqlite::Connection, - subtree: LocatedPrunableTree, +pub(crate) fn put_shard( + conn: &rusqlite::Transaction<'_>, + subtree: LocatedPrunableTree, ) -> Result<(), Error> { let subtree_root_hash = subtree .root() @@ -196,26 +315,31 @@ pub(crate) fn put_shard( let mut subtree_data = vec![]; write_shard_v1(&mut subtree_data, subtree.root()).map_err(Either::Left)?; - conn.prepare_cached( - "INSERT INTO sapling_tree_shards (shard_index, root_hash, shard_data) - VALUES (:shard_index, :root_hash, :shard_data) - ON CONFLICT (shard_index) DO UPDATE - SET root_hash = :root_hash, - shard_data = :shard_data", - ) - .and_then(|mut stmt_put_shard| { - stmt_put_shard.execute(named_params![ + let mut stmt_put_shard = conn + .prepare_cached( + "INSERT INTO sapling_tree_shards (shard_index, root_hash, shard_data) + VALUES (:shard_index, :root_hash, :shard_data) + ON CONFLICT (shard_index) DO UPDATE + SET root_hash = :root_hash, + shard_data = :shard_data", + ) + .map_err(Either::Right)?; + + stmt_put_shard + .execute(named_params![ ":shard_index": subtree.root_addr().index(), ":root_hash": subtree_root_hash, ":shard_data": subtree_data ]) - }) - .map_err(Either::Right)?; + .map_err(Either::Right)?; Ok(()) } -pub(crate) fn get_shard_roots(conn: &rusqlite::Connection) -> Result, Error> { +pub(crate) fn get_shard_roots( + conn: &rusqlite::Connection, + shard_root_level: Level, +) -> Result, Error> { let mut stmt = conn .prepare("SELECT shard_index FROM sapling_tree_shards ORDER BY shard_index") .map_err(Either::Right)?; @@ -224,14 +348,14 @@ pub(crate) fn get_shard_roots(conn: &rusqlite::Connection) -> Result, from: Address) -> Result<(), Error> { +pub(crate) fn truncate(conn: &rusqlite::Connection, from: Address) -> Result<(), Error> { conn.execute( "DELETE FROM sapling_tree_shards WHERE shard_index >= ?", [from.index()], @@ -240,7 +364,7 @@ pub(crate) fn truncate(conn: &rusqlite::Transaction<'_>, from: Address) -> Resul .map(|_| ()) } -pub(crate) fn get_cap(conn: &rusqlite::Connection) -> Result, Error> { +pub(crate) fn get_cap(conn: &rusqlite::Connection) -> Result, Error> { conn.query_row("SELECT cap_data FROM sapling_tree_cap", [], |row| { row.get::<_, Vec>(0) }) @@ -252,9 +376,9 @@ pub(crate) fn get_cap(conn: &rusqlite::Connection) -> Result, - cap: PrunableTree, +pub(crate) fn put_cap( + conn: &rusqlite::Connection, + cap: PrunableTree, ) -> Result<(), Error> { let mut stmt = conn .prepare_cached( @@ -272,22 +396,62 @@ pub(crate) fn put_cap( Ok(()) } +pub(crate) fn min_checkpoint_id(conn: &rusqlite::Connection) -> Result, Error> { + conn.query_row( + "SELECT MIN(checkpoint_id) FROM sapling_tree_checkpoints", + [], + |row| { + row.get::<_, Option>(0) + .map(|opt| opt.map(BlockHeight::from)) + }, + ) + .map_err(Either::Right) +} + +pub(crate) fn max_checkpoint_id(conn: &rusqlite::Connection) -> Result, Error> { + conn.query_row( + "SELECT MAX(checkpoint_id) FROM sapling_tree_checkpoints", + [], + |row| { + row.get::<_, Option>(0) + .map(|opt| opt.map(BlockHeight::from)) + }, + ) + .map_err(Either::Right) +} + pub(crate) fn add_checkpoint( conn: &rusqlite::Transaction<'_>, checkpoint_id: BlockHeight, checkpoint: Checkpoint, ) -> Result<(), Error> { - conn.prepare_cached( - "INSERT INTO sapling_tree_checkpoints (checkpoint_id, position) - VALUES (:checkpoint_id, :position)", - ) - .and_then(|mut stmt_insert_checkpoint| { - stmt_insert_checkpoint.execute(named_params![ + let mut stmt_insert_checkpoint = conn + .prepare_cached( + "INSERT INTO sapling_tree_checkpoints (checkpoint_id, position) + VALUES (:checkpoint_id, :position)", + ) + .map_err(Either::Right)?; + + stmt_insert_checkpoint + .execute(named_params![ ":checkpoint_id": u32::from(checkpoint_id), ":position": checkpoint.position().map(u64::from) ]) - }) - .map_err(Either::Right)?; + .map_err(Either::Right)?; + + let mut stmt_insert_mark_removed = conn.prepare_cached( + "INSERT INTO sapling_tree_checkpoint_marks_removed (checkpoint_id, mark_removed_position) + VALUES (:checkpoint_id, :position)", + ).map_err(Either::Right)?; + + for pos in checkpoint.marks_removed() { + stmt_insert_mark_removed + .execute(named_params![ + ":checkpoint_id": u32::from(checkpoint_id), + ":position": u64::from(*pos) + ]) + .map_err(Either::Right)?; + } Ok(()) } @@ -299,10 +463,10 @@ pub(crate) fn checkpoint_count(conn: &rusqlite::Connection) -> Result>( - conn: &C, +pub(crate) fn get_checkpoint( + conn: &rusqlite::Connection, checkpoint_id: BlockHeight, -) -> Result, Either> { +) -> Result, Error> { let checkpoint_position = conn .query_row( "SELECT position @@ -347,10 +511,14 @@ pub(crate) fn get_checkpoint>( .transpose() } -pub(crate) fn get_checkpoint_at_depth>( - conn: &C, +pub(crate) fn get_checkpoint_at_depth( + conn: &rusqlite::Connection, checkpoint_depth: usize, -) -> Result, Either> { +) -> Result, Error> { + if checkpoint_depth == 0 { + return Ok(None); + } + let checkpoint_parts = conn .query_row( "SELECT checkpoint_id, position @@ -358,7 +526,7 @@ pub(crate) fn get_checkpoint_at_depth>( ORDER BY checkpoint_id DESC LIMIT 1 OFFSET :offset", - named_params![":offset": checkpoint_depth], + named_params![":offset": checkpoint_depth - 1], |row| { let checkpoint_id: u32 = row.get(0)?; let position: Option = row.get(1)?; @@ -404,6 +572,62 @@ pub(crate) fn get_checkpoint_at_depth>( .transpose() } +pub(crate) fn with_checkpoints( + conn: &rusqlite::Transaction<'_>, + limit: usize, + mut callback: F, +) -> Result<(), Error> +where + F: FnMut(&BlockHeight, &Checkpoint) -> Result<(), Error>, +{ + let mut stmt_get_checkpoints = conn + .prepare_cached( + "SELECT checkpoint_id, position + FROM sapling_tree_checkpoints + LIMIT :limit", + ) + .map_err(Either::Right)?; + + let mut stmt_get_checkpoint_marks_removed = conn + .prepare_cached( + "SELECT mark_removed_position + FROM sapling_tree_checkpoint_marks_removed + WHERE checkpoint_id = :checkpoint_id", + ) + .map_err(Either::Right)?; + + let mut rows = stmt_get_checkpoints + .query(named_params![":limit": limit]) + .map_err(Either::Right)?; + + while let Some(row) = rows.next().map_err(Either::Right)? { + let checkpoint_id = row.get::<_, u32>(0).map_err(Either::Right)?; + let tree_state = row + .get::<_, Option>(1) + .map(|opt| opt.map_or_else(|| TreeState::Empty, |p| TreeState::AtPosition(p.into()))) + .map_err(Either::Right)?; + + let mut mark_removed_rows = stmt_get_checkpoint_marks_removed + .query(named_params![":checkpoint_id": checkpoint_id]) + .map_err(Either::Right)?; + let mut marks_removed = BTreeSet::new(); + while let Some(mr_row) = mark_removed_rows.next().map_err(Either::Right)? { + let mark_removed_position = mr_row + .get::<_, u64>(0) + .map(Position::from) + .map_err(Either::Right)?; + marks_removed.insert(mark_removed_position); + } + + callback( + &BlockHeight::from(checkpoint_id), + &Checkpoint::from_parts(tree_state, marks_removed), + )? + } + + Ok(()) +} + pub(crate) fn update_checkpoint_with( conn: &rusqlite::Transaction<'_>, checkpoint_id: BlockHeight, @@ -426,11 +650,17 @@ pub(crate) fn remove_checkpoint( conn: &rusqlite::Transaction<'_>, checkpoint_id: BlockHeight, ) -> Result<(), Error> { - conn.execute( - "DELETE FROM sapling_tree_checkpoints WHERE checkpoint_id = ?", - [u32::from(checkpoint_id)], - ) - .map_err(Either::Right)?; + // sapling_tree_checkpoints is constructed with `ON DELETE CASCADE` + let mut stmt_delete_checkpoint = conn + .prepare_cached( + "DELETE FROM sapling_tree_checkpoints + WHERE checkpoint_id = :checkpoint_id", + ) + .map_err(Either::Right)?; + + stmt_delete_checkpoint + .execute(named_params![":checkpoint_id": u32::from(checkpoint_id),]) + .map_err(Either::Right)?; Ok(()) } @@ -452,3 +682,62 @@ pub(crate) fn truncate_checkpoints( .map_err(Either::Right)?; Ok(()) } + +#[cfg(test)] +mod tests { + use tempfile::NamedTempFile; + + use incrementalmerkletree::testing::{ + check_append, check_checkpoint_rewind, check_remove_mark, check_rewind_remove_mark, + check_root_hashes, check_witness_consistency, check_witnesses, + }; + use shardtree::ShardTree; + + use super::SqliteShardStore; + use crate::{tests, wallet::init::init_wallet_db, WalletDb}; + + fn new_tree(m: usize) -> ShardTree, 4, 3> { + let data_file = NamedTempFile::new().unwrap(); + let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap(); + data_file.keep().unwrap(); + + init_wallet_db(&mut db_data, None).unwrap(); + let store = SqliteShardStore::<_, String, 3>::from_connection(db_data.conn).unwrap(); + ShardTree::new(store, m) + } + + #[test] + fn append() { + check_append(new_tree); + } + + #[test] + fn root_hashes() { + check_root_hashes(new_tree); + } + + #[test] + fn witnesses() { + check_witnesses(new_tree); + } + + #[test] + fn witness_consistency() { + check_witness_consistency(new_tree); + } + + #[test] + fn checkpoint_rewind() { + check_checkpoint_rewind(new_tree); + } + + #[test] + fn remove_mark() { + check_remove_mark(new_tree); + } + + #[test] + fn rewind_remove_mark() { + check_rewind_remove_mark(new_tree); + } +} diff --git a/zcash_primitives/src/consensus.rs b/zcash_primitives/src/consensus.rs index dc972f700..563c69806 100644 --- a/zcash_primitives/src/consensus.rs +++ b/zcash_primitives/src/consensus.rs @@ -627,6 +627,12 @@ pub mod testing { ) }) } + + impl incrementalmerkletree::testing::TestCheckpoint for BlockHeight { + fn from_u64(value: u64) -> Self { + BlockHeight(u32::try_from(value).expect("Test checkpoint ids do not exceed 32 bits")) + } + } } #[cfg(test)] diff --git a/zcash_primitives/src/merkle_tree.rs b/zcash_primitives/src/merkle_tree.rs index 6cda449bc..0a24b8a1f 100644 --- a/zcash_primitives/src/merkle_tree.rs +++ b/zcash_primitives/src/merkle_tree.rs @@ -292,6 +292,7 @@ pub mod testing { use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use incrementalmerkletree::frontier::testing::TestNode; use std::io::{self, Read, Write}; + use zcash_encoding::Vector; use super::HashSer; @@ -304,6 +305,23 @@ pub mod testing { writer.write_u64::(self.0) } } + + impl HashSer for String { + fn read(reader: R) -> io::Result { + Vector::read(reader, |r| r.read_u8()).and_then(|xs| { + String::from_utf8(xs).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("Not a valid utf8 string: {:?}", e), + ) + }) + }) + } + + fn write(&self, writer: W) -> io::Result<()> { + Vector::write(writer, self.as_bytes(), |w, b| w.write_u8(*b)) + } + } } #[cfg(test)]