1use rusqlite::{self, named_params, OptionalExtension};
2use std::{
3 collections::BTreeSet,
4 error, fmt,
5 io::{self, Cursor},
6 marker::PhantomData,
7 num::NonZeroU32,
8 ops::Range,
9 sync::Arc,
10};
11
12use incrementalmerkletree::{Address, Hashable, Level, Position, Retention};
13use shardtree::{
14 error::{QueryError, ShardTreeError},
15 store::{Checkpoint, ShardStore, TreeState},
16 LocatedPrunableTree, LocatedTree, PrunableTree, RetentionFlags,
17};
18
19use zcash_client_backend::{
20 data_api::chain::CommitmentTreeRoot,
21 serialization::shardtree::{read_shard, write_shard},
22};
23use zcash_primitives::merkle_tree::HashSer;
24use zcash_protocol::{consensus::BlockHeight, ShieldedProtocol};
25
26use crate::{error::SqliteClientError, sapling_tree};
27
28#[cfg(feature = "orchard")]
29use crate::orchard_tree;
30
/// Errors that can occur while reading or writing the persistent commitment tree state.
#[derive(Debug)]
pub enum Error {
    /// Failure to serialize or deserialize stored shard, cap, or root-hash data, or stored data
    /// that decoded but was structurally invalid for its address.
    Serialization(io::Error),
    /// Failure reported by SQLite while querying or updating the tree tables.
    Query(rusqlite::Error),
    /// Raised by `add_checkpoint` when a checkpoint already exists at `checkpoint_id`, but its
    /// stored tree state (or, when the tree states agree, its stored set of removed marks)
    /// differs from the checkpoint being added.
    CheckpointConflict {
        /// The block height at which the conflicting checkpoint was being added.
        checkpoint_id: BlockHeight,
        /// The checkpoint that the caller attempted to add.
        checkpoint: Checkpoint,
        /// The tree state already stored for `checkpoint_id`.
        extant_tree_state: TreeState,
        /// The removed marks already stored for `checkpoint_id`; `None` when the conflict was
        /// detected on the tree state alone and the marks were not read.
        extant_marks_removed: Option<BTreeSet<Position>>,
    },
    /// Raised when a proposed range of shard indices neither overlaps nor is directly adjacent
    /// to the range of shard indices already stored, which would leave a gap.
    SubtreeDiscontinuity {
        /// The half-open range of shard indices the caller attempted to write.
        attempted_insertion_range: Range<u64>,
        /// The half-open range spanned by the stored shard indices (`min..max + 1`).
        existing_range: Range<u64>,
    },
}
54
55impl fmt::Display for Error {
56 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
57 match &self {
58 Error::Serialization(err) => write!(f, "Commitment tree serialization error: {}", err),
59 Error::Query(err) => write!(f, "Commitment tree query or update error: {}", err),
60 Error::CheckpointConflict {
61 checkpoint_id,
62 checkpoint,
63 extant_tree_state,
64 extant_marks_removed,
65 } => {
66 write!(
67 f,
68 "Conflict at checkpoint id {}, tried to insert {:?}, which is incompatible with existing state ({:?}, {:?})",
69 checkpoint_id, checkpoint, extant_tree_state, extant_marks_removed
70 )
71 }
72 Error::SubtreeDiscontinuity {
73 attempted_insertion_range,
74 existing_range,
75 } => {
76 write!(
77 f,
78 "Attempted to write subtree roots with indices {:?} which is discontinuous with existing subtree range {:?}",
79 attempted_insertion_range, existing_range,
80 )
81 }
82 }
83 }
84}
85
86impl error::Error for Error {
87 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
88 match &self {
89 Error::Serialization(e) => Some(e),
90 Error::Query(e) => Some(e),
91 Error::CheckpointConflict { .. } => None,
92 Error::SubtreeDiscontinuity { .. } => None,
93 }
94 }
95}
96
/// A `ShardStore` backed by SQLite tables whose names share a common prefix.
///
/// `C` is the database handle type; this file implements `ShardStore` for an owned
/// `rusqlite::Connection` and for a borrowed `&rusqlite::Transaction`. `H` is the hash type
/// stored in the tree, and `SHARD_HEIGHT` is the tree level at which shard roots reside.
pub struct SqliteShardStore<C, H, const SHARD_HEIGHT: u8> {
    /// Database handle used for all reads and writes.
    pub(crate) conn: C,
    /// Prefix for the `<prefix>_tree_shards`, `<prefix>_tree_cap`, `<prefix>_tree_checkpoints`
    /// and `<prefix>_tree_checkpoint_marks_removed` tables.
    table_prefix: &'static str,
    // Ties the store to its hash type without storing a value of it.
    _hash_type: PhantomData<H>,
}
102
103impl<C, H, const SHARD_HEIGHT: u8> SqliteShardStore<C, H, SHARD_HEIGHT> {
104 const SHARD_ROOT_LEVEL: Level = Level::new(SHARD_HEIGHT);
105
106 pub(crate) fn from_connection(
107 conn: C,
108 table_prefix: &'static str,
109 ) -> Result<Self, rusqlite::Error> {
110 Ok(SqliteShardStore {
111 conn,
112 table_prefix,
113 _hash_type: PhantomData,
114 })
115 }
116}
117
118impl<'conn, 'a: 'conn, H: HashSer, const SHARD_HEIGHT: u8> ShardStore
119 for SqliteShardStore<&'a rusqlite::Transaction<'conn>, H, SHARD_HEIGHT>
120{
121 type H = H;
122 type CheckpointId = BlockHeight;
123 type Error = Error;
124
125 fn get_shard(
126 &self,
127 shard_root: Address,
128 ) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> {
129 get_shard(self.conn, self.table_prefix, shard_root)
130 }
131
132 fn last_shard(&self) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> {
133 last_shard(self.conn, self.table_prefix, Self::SHARD_ROOT_LEVEL)
134 }
135
136 fn put_shard(&mut self, subtree: LocatedPrunableTree<Self::H>) -> Result<(), Self::Error> {
137 put_shard(self.conn, self.table_prefix, subtree)
138 }
139
140 fn get_shard_roots(&self) -> Result<Vec<Address>, Self::Error> {
141 get_shard_roots(self.conn, self.table_prefix, Self::SHARD_ROOT_LEVEL)
142 }
143
144 fn truncate_shards(&mut self, shard_index: u64) -> Result<(), Self::Error> {
145 truncate_shards(self.conn, self.table_prefix, shard_index)
146 }
147
148 fn get_cap(&self) -> Result<PrunableTree<Self::H>, Self::Error> {
149 get_cap(self.conn, self.table_prefix)
150 }
151
152 fn put_cap(&mut self, cap: PrunableTree<Self::H>) -> Result<(), Self::Error> {
153 put_cap(self.conn, self.table_prefix, cap)
154 }
155
156 fn min_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
157 min_checkpoint_id(self.conn, self.table_prefix)
158 }
159
160 fn max_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
161 max_checkpoint_id(self.conn, self.table_prefix)
162 }
163
164 fn add_checkpoint(
165 &mut self,
166 checkpoint_id: Self::CheckpointId,
167 checkpoint: Checkpoint,
168 ) -> Result<(), Self::Error> {
169 add_checkpoint(self.conn, self.table_prefix, checkpoint_id, checkpoint)
170 }
171
172 fn checkpoint_count(&self) -> Result<usize, Self::Error> {
173 checkpoint_count(self.conn, self.table_prefix)
174 }
175
176 fn get_checkpoint_at_depth(
177 &self,
178 checkpoint_depth: usize,
179 ) -> Result<Option<(Self::CheckpointId, Checkpoint)>, Self::Error> {
180 get_checkpoint_at_depth(self.conn, self.table_prefix, checkpoint_depth)
181 .map_err(Error::Query)
182 }
183
184 fn get_checkpoint(
185 &self,
186 checkpoint_id: &Self::CheckpointId,
187 ) -> Result<Option<Checkpoint>, Self::Error> {
188 get_checkpoint(self.conn, self.table_prefix, *checkpoint_id)
189 }
190
191 fn with_checkpoints<F>(&mut self, limit: usize, callback: F) -> Result<(), Self::Error>
192 where
193 F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>,
194 {
195 with_checkpoints(self.conn, self.table_prefix, limit, callback)
196 }
197
198 fn for_each_checkpoint<F>(&self, limit: usize, callback: F) -> Result<(), Self::Error>
199 where
200 F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>,
201 {
202 with_checkpoints(self.conn, self.table_prefix, limit, callback)
203 }
204
205 fn update_checkpoint_with<F>(
206 &mut self,
207 checkpoint_id: &Self::CheckpointId,
208 update: F,
209 ) -> Result<bool, Self::Error>
210 where
211 F: Fn(&mut Checkpoint) -> Result<(), Self::Error>,
212 {
213 update_checkpoint_with(self.conn, self.table_prefix, *checkpoint_id, update)
214 }
215
216 fn remove_checkpoint(&mut self, checkpoint_id: &Self::CheckpointId) -> Result<(), Self::Error> {
217 remove_checkpoint(self.conn, self.table_prefix, *checkpoint_id)
218 }
219
220 fn truncate_checkpoints_retaining(
221 &mut self,
222 checkpoint_id: &Self::CheckpointId,
223 ) -> Result<(), Self::Error> {
224 truncate_checkpoints_retaining(self.conn, self.table_prefix, *checkpoint_id)
225 }
226}
227
228impl<H: HashSer, const SHARD_HEIGHT: u8> ShardStore
229 for SqliteShardStore<rusqlite::Connection, H, SHARD_HEIGHT>
230{
231 type H = H;
232 type CheckpointId = BlockHeight;
233 type Error = Error;
234
235 fn get_shard(
236 &self,
237 shard_root: Address,
238 ) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> {
239 get_shard(&self.conn, self.table_prefix, shard_root)
240 }
241
242 fn last_shard(&self) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> {
243 last_shard(&self.conn, self.table_prefix, Self::SHARD_ROOT_LEVEL)
244 }
245
246 fn put_shard(&mut self, subtree: LocatedPrunableTree<Self::H>) -> Result<(), Self::Error> {
247 let tx = self.conn.transaction().map_err(Error::Query)?;
248 put_shard(&tx, self.table_prefix, subtree)?;
249 tx.commit().map_err(Error::Query)?;
250 Ok(())
251 }
252
253 fn get_shard_roots(&self) -> Result<Vec<Address>, Self::Error> {
254 get_shard_roots(&self.conn, self.table_prefix, Self::SHARD_ROOT_LEVEL)
255 }
256
257 fn truncate_shards(&mut self, shard_index: u64) -> Result<(), Self::Error> {
258 truncate_shards(&self.conn, self.table_prefix, shard_index)
259 }
260
261 fn get_cap(&self) -> Result<PrunableTree<Self::H>, Self::Error> {
262 get_cap(&self.conn, self.table_prefix)
263 }
264
265 fn put_cap(&mut self, cap: PrunableTree<Self::H>) -> Result<(), Self::Error> {
266 put_cap(&self.conn, self.table_prefix, cap)
267 }
268
269 fn min_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
270 min_checkpoint_id(&self.conn, self.table_prefix)
271 }
272
273 fn max_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
274 max_checkpoint_id(&self.conn, self.table_prefix)
275 }
276
277 fn add_checkpoint(
278 &mut self,
279 checkpoint_id: Self::CheckpointId,
280 checkpoint: Checkpoint,
281 ) -> Result<(), Self::Error> {
282 let tx = self.conn.transaction().map_err(Error::Query)?;
283 add_checkpoint(&tx, self.table_prefix, checkpoint_id, checkpoint)?;
284 tx.commit().map_err(Error::Query)
285 }
286
287 fn checkpoint_count(&self) -> Result<usize, Self::Error> {
288 checkpoint_count(&self.conn, self.table_prefix)
289 }
290
291 fn get_checkpoint_at_depth(
292 &self,
293 checkpoint_depth: usize,
294 ) -> Result<Option<(Self::CheckpointId, Checkpoint)>, Self::Error> {
295 get_checkpoint_at_depth(&self.conn, self.table_prefix, checkpoint_depth)
296 .map_err(Error::Query)
297 }
298
299 fn get_checkpoint(
300 &self,
301 checkpoint_id: &Self::CheckpointId,
302 ) -> Result<Option<Checkpoint>, Self::Error> {
303 get_checkpoint(&self.conn, self.table_prefix, *checkpoint_id)
304 }
305
306 fn with_checkpoints<F>(&mut self, limit: usize, callback: F) -> Result<(), Self::Error>
307 where
308 F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>,
309 {
310 let tx = self.conn.transaction().map_err(Error::Query)?;
311 with_checkpoints(&tx, self.table_prefix, limit, callback)?;
312 tx.commit().map_err(Error::Query)
313 }
314
315 fn for_each_checkpoint<F>(&self, limit: usize, callback: F) -> Result<(), Self::Error>
316 where
317 F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>,
318 {
319 let tx = self.conn.unchecked_transaction().map_err(Error::Query)?;
320 with_checkpoints(&tx, self.table_prefix, limit, callback)?;
321 tx.rollback().map_err(Error::Query)
324 }
325
326 fn update_checkpoint_with<F>(
327 &mut self,
328 checkpoint_id: &Self::CheckpointId,
329 update: F,
330 ) -> Result<bool, Self::Error>
331 where
332 F: Fn(&mut Checkpoint) -> Result<(), Self::Error>,
333 {
334 let tx = self.conn.transaction().map_err(Error::Query)?;
335 let result = update_checkpoint_with(&tx, self.table_prefix, *checkpoint_id, update)?;
336 tx.commit().map_err(Error::Query)?;
337 Ok(result)
338 }
339
340 fn remove_checkpoint(&mut self, checkpoint_id: &Self::CheckpointId) -> Result<(), Self::Error> {
341 let tx = self.conn.transaction().map_err(Error::Query)?;
342 remove_checkpoint(&tx, self.table_prefix, *checkpoint_id)?;
343 tx.commit().map_err(Error::Query)
344 }
345
346 fn truncate_checkpoints_retaining(
347 &mut self,
348 checkpoint_id: &Self::CheckpointId,
349 ) -> Result<(), Self::Error> {
350 let tx = self.conn.transaction().map_err(Error::Query)?;
351 truncate_checkpoints_retaining(&tx, self.table_prefix, *checkpoint_id)?;
352 tx.commit().map_err(Error::Query)
353 }
354}
355
356pub(crate) fn get_shard<H: HashSer>(
357 conn: &rusqlite::Connection,
358 table_prefix: &'static str,
359 shard_root_addr: Address,
360) -> Result<Option<LocatedPrunableTree<H>>, Error> {
361 conn.query_row(
362 &format!(
363 "SELECT shard_data, root_hash
364 FROM {}_tree_shards
365 WHERE shard_index = :shard_index",
366 table_prefix
367 ),
368 named_params![":shard_index": shard_root_addr.index()],
369 |row| Ok((row.get::<_, Vec<u8>>(0)?, row.get::<_, Option<Vec<u8>>>(1)?)),
370 )
371 .optional()
372 .map_err(Error::Query)?
373 .map(|(shard_data, root_hash)| {
374 let shard_tree = read_shard(&mut Cursor::new(shard_data)).map_err(Error::Serialization)?;
375 let located_tree =
376 LocatedPrunableTree::from_parts(shard_root_addr, shard_tree).map_err(|e| {
377 Error::Serialization(io::Error::new(
378 io::ErrorKind::InvalidData,
379 format!("Tree contained invalid data at address {:?}", e),
380 ))
381 })?;
382 if let Some(root_hash_data) = root_hash {
383 let root_hash = H::read(Cursor::new(root_hash_data)).map_err(Error::Serialization)?;
384 Ok(located_tree.reannotate_root(Some(Arc::new(root_hash))))
385 } else {
386 Ok(located_tree)
387 }
388 })
389 .transpose()
390}
391
392pub(crate) fn last_shard<H: HashSer>(
393 conn: &rusqlite::Connection,
394 table_prefix: &'static str,
395 shard_root_level: Level,
396) -> Result<Option<LocatedPrunableTree<H>>, Error> {
397 conn.query_row(
398 &format!(
399 "SELECT shard_index, shard_data
400 FROM {}_tree_shards
401 ORDER BY shard_index DESC
402 LIMIT 1",
403 table_prefix
404 ),
405 [],
406 |row| {
407 let shard_index: u64 = row.get(0)?;
408 let shard_data: Vec<u8> = row.get(1)?;
409 Ok((shard_index, shard_data))
410 },
411 )
412 .optional()
413 .map_err(Error::Query)?
414 .map(|(shard_index, shard_data)| {
415 let shard_root = Address::from_parts(shard_root_level, shard_index);
416 let shard_tree = read_shard(&mut Cursor::new(shard_data)).map_err(Error::Serialization)?;
417 LocatedPrunableTree::from_parts(shard_root, shard_tree).map_err(|e| {
418 Error::Serialization(io::Error::new(
419 io::ErrorKind::InvalidData,
420 format!("Tree contained invalid data at address {:?}", e),
421 ))
422 })
423 })
424 .transpose()
425}
426
427#[tracing::instrument(skip(conn))]
431fn check_shard_discontinuity(
432 conn: &rusqlite::Connection,
433 table_prefix: &'static str,
434 proposed_insertion_range: Range<u64>,
435) -> Result<(), Error> {
436 if let Ok((Some(stored_min), Some(stored_max))) = conn
437 .query_row(
438 &format!(
439 "SELECT MIN(shard_index), MAX(shard_index) FROM {}_tree_shards",
440 table_prefix
441 ),
442 [],
443 |row| {
444 let min = row.get::<_, Option<u64>>(0)?;
445 let max = row.get::<_, Option<u64>>(1)?;
446 Ok((min, max))
447 },
448 )
449 .map_err(Error::Query)
450 {
451 let (cur_start, cur_end) = (stored_min, stored_max + 1);
459 let (ins_start, ins_end) = (proposed_insertion_range.start, proposed_insertion_range.end);
460 if cur_start > ins_end || ins_start > cur_end {
461 return Err(Error::SubtreeDiscontinuity {
462 attempted_insertion_range: proposed_insertion_range,
463 existing_range: cur_start..cur_end,
464 });
465 }
466 }
467
468 Ok(())
469}
470
471pub(crate) fn put_shard<H: HashSer>(
472 conn: &rusqlite::Transaction<'_>,
473 table_prefix: &'static str,
474 subtree: LocatedPrunableTree<H>,
475) -> Result<(), Error> {
476 let subtree_root_hash = subtree
477 .root()
478 .annotation()
479 .and_then(|ann| {
480 ann.as_ref().map(|rc| {
481 let mut root_hash = vec![];
482 rc.write(&mut root_hash)?;
483 Ok(root_hash)
484 })
485 })
486 .transpose()
487 .map_err(Error::Serialization)?;
488
489 let mut subtree_data = vec![];
490 write_shard(&mut subtree_data, subtree.root()).map_err(Error::Serialization)?;
491
492 let shard_index = subtree.root_addr().index();
493
494 check_shard_discontinuity(conn, table_prefix, shard_index..shard_index + 1)?;
495
496 let mut stmt_put_shard = conn
497 .prepare_cached(&format!(
498 "INSERT INTO {}_tree_shards (shard_index, root_hash, shard_data)
499 VALUES (:shard_index, :root_hash, :shard_data)
500 ON CONFLICT (shard_index) DO UPDATE
501 SET root_hash = :root_hash,
502 shard_data = :shard_data",
503 table_prefix
504 ))
505 .map_err(Error::Query)?;
506
507 stmt_put_shard
508 .execute(named_params![
509 ":shard_index": shard_index,
510 ":root_hash": subtree_root_hash,
511 ":shard_data": subtree_data
512 ])
513 .map_err(Error::Query)?;
514
515 Ok(())
516}
517
518pub(crate) fn get_shard_roots(
519 conn: &rusqlite::Connection,
520 table_prefix: &'static str,
521 shard_root_level: Level,
522) -> Result<Vec<Address>, Error> {
523 let mut stmt = conn
524 .prepare(&format!(
525 "SELECT shard_index FROM {}_tree_shards ORDER BY shard_index",
526 table_prefix
527 ))
528 .map_err(Error::Query)?;
529 let mut rows = stmt.query([]).map_err(Error::Query)?;
530
531 let mut res = vec![];
532 while let Some(row) = rows.next().map_err(Error::Query)? {
533 res.push(Address::from_parts(
534 shard_root_level,
535 row.get(0).map_err(Error::Query)?,
536 ));
537 }
538 Ok(res)
539}
540
541pub(crate) fn truncate_shards(
542 conn: &rusqlite::Connection,
543 table_prefix: &'static str,
544 shard_index: u64,
545) -> Result<(), Error> {
546 conn.execute(
547 &format!(
548 "DELETE FROM {}_tree_shards WHERE shard_index >= ?",
549 table_prefix
550 ),
551 [shard_index],
552 )
553 .map_err(Error::Query)
554 .map(|_| ())
555}
556
557#[tracing::instrument(skip(conn))]
558pub(crate) fn get_cap<H: HashSer>(
559 conn: &rusqlite::Connection,
560 table_prefix: &'static str,
561) -> Result<PrunableTree<H>, Error> {
562 conn.query_row(
563 &format!("SELECT cap_data FROM {}_tree_cap", table_prefix),
564 [],
565 |row| row.get::<_, Vec<u8>>(0),
566 )
567 .optional()
568 .map_err(Error::Query)?
569 .map_or_else(
570 || Ok(PrunableTree::empty()),
571 |cap_data| read_shard(&mut Cursor::new(cap_data)).map_err(Error::Serialization),
572 )
573}
574
575#[tracing::instrument(skip(conn, cap))]
576pub(crate) fn put_cap<H: HashSer>(
577 conn: &rusqlite::Connection,
578 table_prefix: &'static str,
579 cap: PrunableTree<H>,
580) -> Result<(), Error> {
581 let mut stmt = conn
582 .prepare_cached(&format!(
583 "INSERT INTO {}_tree_cap (cap_id, cap_data)
584 VALUES (0, :cap_data)
585 ON CONFLICT (cap_id) DO UPDATE
586 SET cap_data = :cap_data",
587 table_prefix
588 ))
589 .map_err(Error::Query)?;
590
591 let mut cap_data = vec![];
592 write_shard(&mut cap_data, &cap).map_err(Error::Serialization)?;
593 stmt.execute([cap_data]).map_err(Error::Query)?;
594
595 Ok(())
596}
597
598pub(crate) fn min_checkpoint_id(
599 conn: &rusqlite::Connection,
600 table_prefix: &'static str,
601) -> Result<Option<BlockHeight>, Error> {
602 conn.query_row(
603 &format!(
604 "SELECT MIN(checkpoint_id) FROM {}_tree_checkpoints",
605 table_prefix
606 ),
607 [],
608 |row| {
609 row.get::<_, Option<u32>>(0)
610 .map(|opt| opt.map(BlockHeight::from))
611 },
612 )
613 .map_err(Error::Query)
614}
615
616pub(crate) fn max_checkpoint_id(
617 conn: &rusqlite::Connection,
618 table_prefix: &'static str,
619) -> Result<Option<BlockHeight>, Error> {
620 conn.query_row(
621 &format!(
622 "SELECT MAX(checkpoint_id) FROM {}_tree_checkpoints",
623 table_prefix
624 ),
625 [],
626 |row| {
627 row.get::<_, Option<u32>>(0)
628 .map(|opt| opt.map(BlockHeight::from))
629 },
630 )
631 .map_err(Error::Query)
632}
633
634pub(crate) fn add_checkpoint(
635 conn: &rusqlite::Transaction<'_>,
636 table_prefix: &'static str,
637 checkpoint_id: BlockHeight,
638 checkpoint: Checkpoint,
639) -> Result<(), Error> {
640 let extant_tree_state = conn
641 .query_row(
642 &format!(
643 "SELECT position FROM {}_tree_checkpoints WHERE checkpoint_id = :checkpoint_id",
644 table_prefix
645 ),
646 named_params![":checkpoint_id": u32::from(checkpoint_id),],
647 |row| {
648 row.get::<_, Option<u64>>(0).map(|opt| {
649 opt.map_or_else(
650 || TreeState::Empty,
651 |pos| TreeState::AtPosition(Position::from(pos)),
652 )
653 })
654 },
655 )
656 .optional()
657 .map_err(Error::Query)?;
658
659 match extant_tree_state {
660 Some(current) => {
661 if current != checkpoint.tree_state() {
662 Err(Error::CheckpointConflict {
666 checkpoint_id,
667 checkpoint,
668 extant_tree_state: current,
669 extant_marks_removed: None,
670 })
671 } else {
672 let marks_removed = get_marks_removed(conn, table_prefix, checkpoint_id)?;
675 if &marks_removed == checkpoint.marks_removed() {
676 Ok(())
677 } else {
678 Err(Error::CheckpointConflict {
679 checkpoint_id,
680 checkpoint,
681 extant_tree_state: current,
682 extant_marks_removed: Some(marks_removed),
683 })
684 }
685 }
686 }
687 None => {
688 let mut stmt_insert_checkpoint = conn
689 .prepare_cached(&format!(
690 "INSERT INTO {}_tree_checkpoints (checkpoint_id, position)
691 VALUES (:checkpoint_id, :position)",
692 table_prefix
693 ))
694 .map_err(Error::Query)?;
695
696 stmt_insert_checkpoint
697 .execute(named_params![
698 ":checkpoint_id": u32::from(checkpoint_id),
699 ":position": checkpoint.position().map(u64::from)
700 ])
701 .map_err(Error::Query)?;
702
703 let mut stmt_insert_mark_removed = conn
704 .prepare_cached(&format!(
705 "INSERT INTO {}_tree_checkpoint_marks_removed (checkpoint_id, mark_removed_position)
706 VALUES (:checkpoint_id, :position)",
707 table_prefix
708 ))
709 .map_err(Error::Query)?;
710
711 for pos in checkpoint.marks_removed() {
712 stmt_insert_mark_removed
713 .execute(named_params![
714 ":checkpoint_id": u32::from(checkpoint_id),
715 ":position": u64::from(*pos)
716 ])
717 .map_err(Error::Query)?;
718 }
719
720 Ok(())
721 }
722 }
723}
724
725pub(crate) fn checkpoint_count(
726 conn: &rusqlite::Connection,
727 table_prefix: &'static str,
728) -> Result<usize, Error> {
729 conn.query_row(
730 &format!("SELECT COUNT(*) FROM {}_tree_checkpoints", table_prefix),
731 [],
732 |row| row.get::<_, usize>(0),
733 )
734 .map_err(Error::Query)
735}
736
737fn get_marks_removed(
738 conn: &rusqlite::Connection,
739 table_prefix: &'static str,
740 checkpoint_id: BlockHeight,
741) -> Result<BTreeSet<Position>, Error> {
742 let mut stmt = conn
743 .prepare_cached(&format!(
744 "SELECT mark_removed_position
745 FROM {}_tree_checkpoint_marks_removed
746 WHERE checkpoint_id = ?",
747 table_prefix
748 ))
749 .map_err(Error::Query)?;
750 let mark_removed_rows = stmt
751 .query([u32::from(checkpoint_id)])
752 .map_err(Error::Query)?;
753
754 mark_removed_rows
755 .mapped(|row| row.get::<_, u64>(0).map(Position::from))
756 .collect::<Result<BTreeSet<_>, _>>()
757 .map_err(Error::Query)
758}
759
760pub(crate) fn get_checkpoint(
761 conn: &rusqlite::Connection,
762 table_prefix: &'static str,
763 checkpoint_id: BlockHeight,
764) -> Result<Option<Checkpoint>, Error> {
765 let checkpoint_position = conn
766 .query_row(
767 &format!(
768 "SELECT position
769 FROM {}_tree_checkpoints
770 WHERE checkpoint_id = ?",
771 table_prefix
772 ),
773 [u32::from(checkpoint_id)],
774 |row| {
775 row.get::<_, Option<u64>>(0)
776 .map(|opt| opt.map(Position::from))
777 },
778 )
779 .optional()
780 .map_err(Error::Query)?;
781
782 checkpoint_position
783 .map(|pos_opt| {
784 Ok(Checkpoint::from_parts(
785 pos_opt.map_or(TreeState::Empty, TreeState::AtPosition),
786 get_marks_removed(conn, table_prefix, checkpoint_id)?,
787 ))
788 })
789 .transpose()
790}
791
792pub(crate) fn get_max_checkpointed_height(
793 conn: &rusqlite::Connection,
794 table_prefix: &'static str,
795 chain_tip_height: BlockHeight,
796 min_confirmations: NonZeroU32,
797) -> Result<Option<BlockHeight>, rusqlite::Error> {
798 let max_checkpoint_height =
799 u32::from(chain_tip_height).saturating_sub(u32::from(min_confirmations) - 1);
800
801 conn.query_row(
804 &format!(
805 "SELECT checkpoint_id
806 FROM {}_tree_checkpoints
807 WHERE checkpoint_id <= :max_checkpoint_height
808 ORDER BY checkpoint_id DESC
809 LIMIT 1",
810 table_prefix
811 ),
812 named_params![":max_checkpoint_height": max_checkpoint_height],
813 |row| row.get::<_, u32>(0).map(BlockHeight::from),
814 )
815 .optional()
816}
817
818pub(crate) fn get_checkpoint_at_depth(
819 conn: &rusqlite::Connection,
820 table_prefix: &'static str,
821 checkpoint_depth: usize,
822) -> Result<Option<(BlockHeight, Checkpoint)>, rusqlite::Error> {
823 let checkpoint_parts = conn
824 .query_row(
825 &format!(
826 "SELECT checkpoint_id, position
827 FROM {}_tree_checkpoints
828 ORDER BY checkpoint_id DESC
829 LIMIT 1
830 OFFSET :offset",
831 table_prefix
832 ),
833 named_params![":offset": checkpoint_depth],
834 |row| {
835 let checkpoint_id: u32 = row.get(0)?;
836 let position: Option<u64> = row.get(1)?;
837 Ok((
838 BlockHeight::from(checkpoint_id),
839 position.map(Position::from),
840 ))
841 },
842 )
843 .optional()?;
844
845 checkpoint_parts
846 .map(|(checkpoint_id, pos_opt)| {
847 let mut stmt = conn.prepare_cached(&format!(
848 "SELECT mark_removed_position
849 FROM {}_tree_checkpoint_marks_removed
850 WHERE checkpoint_id = ?",
851 table_prefix
852 ))?;
853 let mark_removed_rows = stmt.query([u32::from(checkpoint_id)])?;
854
855 let marks_removed = mark_removed_rows
856 .mapped(|row| row.get::<_, u64>(0).map(Position::from))
857 .collect::<Result<BTreeSet<_>, _>>()?;
858
859 Ok((
860 checkpoint_id,
861 Checkpoint::from_parts(
862 pos_opt.map_or(TreeState::Empty, TreeState::AtPosition),
863 marks_removed,
864 ),
865 ))
866 })
867 .transpose()
868}
869
870pub(crate) fn with_checkpoints<F>(
871 conn: &rusqlite::Transaction<'_>,
872 table_prefix: &'static str,
873 limit: usize,
874 mut callback: F,
875) -> Result<(), Error>
876where
877 F: FnMut(&BlockHeight, &Checkpoint) -> Result<(), Error>,
878{
879 let mut stmt_get_checkpoints = conn
880 .prepare_cached(&format!(
881 "SELECT checkpoint_id, position
882 FROM {}_tree_checkpoints
883 ORDER BY position
884 LIMIT :limit",
885 table_prefix
886 ))
887 .map_err(Error::Query)?;
888
889 let mut stmt_get_checkpoint_marks_removed = conn
890 .prepare_cached(&format!(
891 "SELECT mark_removed_position
892 FROM {}_tree_checkpoint_marks_removed
893 WHERE checkpoint_id = :checkpoint_id",
894 table_prefix
895 ))
896 .map_err(Error::Query)?;
897
898 let mut rows = stmt_get_checkpoints
899 .query(named_params![":limit": limit])
900 .map_err(Error::Query)?;
901
902 while let Some(row) = rows.next().map_err(Error::Query)? {
903 let checkpoint_id = row.get::<_, u32>(0).map_err(Error::Query)?;
904 let tree_state = row
905 .get::<_, Option<u64>>(1)
906 .map(|opt| opt.map_or_else(|| TreeState::Empty, |p| TreeState::AtPosition(p.into())))
907 .map_err(Error::Query)?;
908
909 let mark_removed_rows = stmt_get_checkpoint_marks_removed
910 .query(named_params![":checkpoint_id": checkpoint_id])
911 .map_err(Error::Query)?;
912
913 let marks_removed = mark_removed_rows
914 .mapped(|row| row.get::<_, u64>(0).map(Position::from))
915 .collect::<Result<BTreeSet<_>, _>>()
916 .map_err(Error::Query)?;
917
918 callback(
919 &BlockHeight::from(checkpoint_id),
920 &Checkpoint::from_parts(tree_state, marks_removed),
921 )?
922 }
923
924 Ok(())
925}
926
927pub(crate) fn update_checkpoint_with<F>(
928 conn: &rusqlite::Transaction<'_>,
929 table_prefix: &'static str,
930 checkpoint_id: BlockHeight,
931 update: F,
932) -> Result<bool, Error>
933where
934 F: Fn(&mut Checkpoint) -> Result<(), Error>,
935{
936 if let Some(mut c) = get_checkpoint(conn, table_prefix, checkpoint_id)? {
937 update(&mut c)?;
938 remove_checkpoint(conn, table_prefix, checkpoint_id)?;
939 add_checkpoint(conn, table_prefix, checkpoint_id, c)?;
940 Ok(true)
941 } else {
942 Ok(false)
943 }
944}
945
946pub(crate) fn remove_checkpoint(
947 conn: &rusqlite::Transaction<'_>,
948 table_prefix: &'static str,
949 checkpoint_id: BlockHeight,
950) -> Result<(), Error> {
951 let mut stmt_delete_checkpoint = conn
954 .prepare_cached(&format!(
955 "DELETE FROM {}_tree_checkpoints
956 WHERE checkpoint_id = :checkpoint_id",
957 table_prefix
958 ))
959 .map_err(Error::Query)?;
960
961 stmt_delete_checkpoint
962 .execute(named_params![":checkpoint_id": u32::from(checkpoint_id),])
963 .map_err(Error::Query)?;
964
965 Ok(())
966}
967
968pub(crate) fn truncate_checkpoints_retaining(
969 conn: &rusqlite::Transaction<'_>,
970 table_prefix: &'static str,
971 checkpoint_id: BlockHeight,
972) -> Result<(), Error> {
973 conn.execute(
976 &format!(
977 "DELETE FROM {}_tree_checkpoints WHERE checkpoint_id > ?",
978 table_prefix
979 ),
980 [u32::from(checkpoint_id)],
981 )
982 .map_err(Error::Query)?;
983
984 conn.execute(
986 &format!(
987 "DELETE FROM {}_tree_checkpoint_marks_removed WHERE checkpoint_id = ?",
988 table_prefix
989 ),
990 [u32::from(checkpoint_id)],
991 )
992 .map_err(Error::Query)?;
993
994 Ok(())
995}
996
997#[tracing::instrument(skip(conn, roots))]
998pub(crate) fn put_shard_roots<
999 H: Hashable + HashSer + Clone + Eq,
1000 const DEPTH: u8,
1001 const SHARD_HEIGHT: u8,
1002>(
1003 conn: &rusqlite::Transaction<'_>,
1004 table_prefix: &'static str,
1005 start_index: u64,
1006 roots: &[CommitmentTreeRoot<H>],
1007) -> Result<(), ShardTreeError<Error>> {
1008 if roots.is_empty() {
1009 return Ok(());
1011 }
1012
1013 #[derive(Clone, Debug, PartialEq, Eq)]
1017 struct LevelShifter<H, const SHARD_HEIGHT: u8>(H);
1018 impl<H: Hashable, const SHARD_HEIGHT: u8> Hashable for LevelShifter<H, SHARD_HEIGHT> {
1019 fn empty_leaf() -> Self {
1020 Self(H::empty_root(SHARD_HEIGHT.into()))
1021 }
1022
1023 fn combine(level: Level, a: &Self, b: &Self) -> Self {
1024 Self(H::combine(level + SHARD_HEIGHT, &a.0, &b.0))
1025 }
1026
1027 fn empty_root(level: Level) -> Self
1028 where
1029 Self: Sized,
1030 {
1031 Self(H::empty_root(level + SHARD_HEIGHT))
1032 }
1033 }
1034 impl<H: HashSer, const SHARD_HEIGHT: u8> HashSer for LevelShifter<H, SHARD_HEIGHT> {
1035 fn read<R: io::Read>(reader: R) -> io::Result<Self>
1036 where
1037 Self: Sized,
1038 {
1039 H::read(reader).map(Self)
1040 }
1041
1042 fn write<W: io::Write>(&self, writer: W) -> io::Result<()> {
1043 self.0.write(writer)
1044 }
1045 }
1046
1047 let cap = LocatedTree::from_parts(
1048 Address::from_parts((DEPTH - SHARD_HEIGHT).into(), 0),
1049 get_cap::<LevelShifter<H, SHARD_HEIGHT>>(conn, table_prefix)
1050 .map_err(ShardTreeError::Storage)?,
1051 )
1052 .map_err(|e| {
1053 ShardTreeError::Storage(Error::Serialization(io::Error::new(
1054 io::ErrorKind::InvalidData,
1055 format!("Note commitment tree cap was invalid at address {:?}", e),
1056 )))
1057 })?;
1058
1059 let insert_into_cap = tracing::info_span!("insert_into_cap").entered();
1060 let cap_result = cap
1061 .batch_insert::<(), _>(
1062 Position::from(start_index),
1063 roots
1064 .iter()
1065 .map(|r| (LevelShifter(r.root_hash().clone()), Retention::Reference)),
1066 )
1067 .map_err(ShardTreeError::Insert)?
1068 .expect("slice of inserted roots was verified to be nonempty");
1069 drop(insert_into_cap);
1070
1071 put_cap(conn, table_prefix, cap_result.subtree.take_root()).map_err(ShardTreeError::Storage)?;
1072
1073 check_shard_discontinuity(
1074 conn,
1075 table_prefix,
1076 start_index..start_index + (roots.len() as u64),
1077 )
1078 .map_err(ShardTreeError::Storage)?;
1079
1080 let mut stmt = conn
1085 .prepare_cached(&format!(
1086 "INSERT INTO {}_tree_shards (shard_index, subtree_end_height, root_hash, shard_data)
1087 VALUES (:shard_index, :subtree_end_height, :root_hash, :shard_data)
1088 ON CONFLICT (shard_index) DO UPDATE
1089 SET subtree_end_height = :subtree_end_height, root_hash = :root_hash",
1090 table_prefix
1091 ))
1092 .map_err(|e| ShardTreeError::Storage(Error::Query(e)))?;
1093
1094 let put_roots = tracing::info_span!("write_shards").entered();
1095 for (root, i) in roots.iter().zip(0u64..) {
1096 let mut shard_data: Vec<u8> = vec![];
1098 let tree = PrunableTree::leaf((root.root_hash().clone(), RetentionFlags::EPHEMERAL));
1099 write_shard(&mut shard_data, &tree)
1100 .map_err(|e| ShardTreeError::Storage(Error::Serialization(e)))?;
1101
1102 let mut root_hash_data: Vec<u8> = vec![];
1103 root.root_hash()
1104 .write(&mut root_hash_data)
1105 .map_err(|e| ShardTreeError::Storage(Error::Serialization(e)))?;
1106
1107 stmt.execute(named_params![
1108 ":shard_index": start_index + i,
1109 ":subtree_end_height": u32::from(root.subtree_end_height()),
1110 ":root_hash": root_hash_data,
1111 ":shard_data": shard_data,
1112 ])
1113 .map_err(|e| ShardTreeError::Storage(Error::Query(e)))?;
1114 }
1115 drop(put_roots);
1116
1117 Ok(())
1118}
1119
1120pub(crate) fn check_witnesses(
1121 conn: &rusqlite::Transaction<'_>,
1122) -> Result<Vec<Range<BlockHeight>>, SqliteClientError> {
1123 let chain_tip_height =
1124 super::chain_tip_height(conn)?.ok_or(SqliteClientError::ChainHeightUnknown)?;
1125 let wallet_birthday = super::wallet_birthday(conn)?.ok_or(SqliteClientError::AccountUnknown)?;
1126 let unspent_sapling_note_meta =
1127 super::sapling::select_unspent_note_meta(conn, chain_tip_height, wallet_birthday)?;
1128
1129 let mut scan_ranges = vec![];
1130 let mut sapling_incomplete = vec![];
1131 let sapling_tree = sapling_tree(conn)?;
1132 for m in unspent_sapling_note_meta.iter() {
1133 match sapling_tree.witness_at_checkpoint_depth(m.commitment_tree_position(), 0) {
1134 Ok(_) => {}
1135 Err(ShardTreeError::Query(QueryError::TreeIncomplete(mut addrs))) => {
1136 sapling_incomplete.append(&mut addrs);
1137 }
1138 Err(other) => {
1139 return Err(SqliteClientError::CommitmentTree(other));
1140 }
1141 }
1142 }
1143
1144 for addr in sapling_incomplete {
1145 let range = super::get_block_range(conn, ShieldedProtocol::Sapling, addr)?;
1146 scan_ranges.extend(range.into_iter());
1147 }
1148
1149 #[cfg(feature = "orchard")]
1150 {
1151 let unspent_orchard_note_meta =
1152 super::orchard::select_unspent_note_meta(conn, chain_tip_height, wallet_birthday)?;
1153 let mut orchard_incomplete = vec![];
1154 let orchard_tree = orchard_tree(conn)?;
1155 for m in unspent_orchard_note_meta.iter() {
1156 match orchard_tree.witness_at_checkpoint_depth(m.commitment_tree_position(), 0) {
1157 Ok(_) => {}
1158 Err(ShardTreeError::Query(QueryError::TreeIncomplete(mut addrs))) => {
1159 orchard_incomplete.append(&mut addrs);
1160 }
1161 Err(other) => {
1162 return Err(SqliteClientError::CommitmentTree(other));
1163 }
1164 }
1165 }
1166
1167 for addr in orchard_incomplete {
1168 let range = super::get_block_range(conn, ShieldedProtocol::Orchard, addr)?;
1169 scan_ranges.extend(range.into_iter());
1170 }
1171 }
1172
1173 Ok(scan_ranges)
1174}
1175
#[cfg(test)]
mod tests {
    use tempfile::NamedTempFile;

    use incrementalmerkletree::{Marking, Position, Retention};
    use incrementalmerkletree_testing::{
        check_append, check_checkpoint_rewind, check_remove_mark, check_rewind_remove_mark,
        check_root_hashes, check_witness_consistency, check_witnesses,
    };
    use shardtree::ShardTree;
    use zcash_client_backend::data_api::{
        chain::CommitmentTreeRoot,
        testing::{pool::ShieldedPoolTester, sapling::SaplingPoolTester},
    };
    use zcash_protocol::consensus::{BlockHeight, Network};

    use super::SqliteShardStore;
    use crate::{
        testing::{
            db::{test_clock, test_rng},
            pool::ShieldedPoolPersistence,
        },
        wallet::init::WalletMigrator,
        WalletDb,
    };

    /// Creates a wallet database in a fresh temporary file, runs all migrations on it,
    /// and hands back the underlying SQLite connection.
    ///
    /// The temporary file is persisted with `keep`, so it is not removed when the
    /// `NamedTempFile` handle is dropped at the end of this function.
    fn migrated_wallet_conn() -> rusqlite::Connection {
        let temp = NamedTempFile::new().unwrap();
        let mut wallet = WalletDb::for_path(
            temp.path(),
            Network::TestNetwork,
            test_clock(),
            test_rng(),
        )
        .unwrap();
        temp.keep().unwrap();

        WalletMigrator::new().init_or_migrate(&mut wallet).unwrap();
        wallet.conn
    }

    /// Builds a depth-4, shard-height-3 `ShardTree` of `String` leaves.
    ///
    /// The tree is stored in the `T` pool's tables of a freshly migrated wallet
    /// database. `m` is forwarded to `ShardTree::new` as its checkpoint limit.
    fn new_tree<T: ShieldedPoolTester + ShieldedPoolPersistence>(
        m: usize,
    ) -> ShardTree<SqliteShardStore<rusqlite::Connection, String, 3>, 4, 3> {
        let conn = migrated_wallet_conn();
        let store = SqliteShardStore::<_, String, 3>::from_connection(conn, T::TABLES_PREFIX)
            .unwrap();
        ShardTree::new(store, m)
    }

    #[cfg(feature = "orchard")]
    mod orchard {
        use zcash_client_backend::data_api::testing::orchard::OrchardPoolTester;

        use super::{
            check_append, check_checkpoint_rewind, check_remove_mark, check_rewind_remove_mark,
            check_root_hashes, check_witness_consistency, check_witnesses, new_tree,
        };

        #[test]
        fn append() {
            check_append(new_tree::<OrchardPoolTester>);
        }

        #[test]
        fn root_hashes() {
            check_root_hashes(new_tree::<OrchardPoolTester>);
        }

        #[test]
        fn witnesses() {
            check_witnesses(new_tree::<OrchardPoolTester>);
        }

        #[test]
        fn witness_consistency() {
            check_witness_consistency(new_tree::<OrchardPoolTester>);
        }

        #[test]
        fn checkpoint_rewind() {
            check_checkpoint_rewind(new_tree::<OrchardPoolTester>);
        }

        #[test]
        fn remove_mark() {
            check_remove_mark(new_tree::<OrchardPoolTester>);
        }

        #[test]
        fn rewind_remove_mark() {
            check_rewind_remove_mark(new_tree::<OrchardPoolTester>);
        }

        #[test]
        fn put_shard_roots() {
            // Path-qualified: this test function shares its name with the generic helper.
            super::put_shard_roots::<OrchardPoolTester>()
        }
    }

    #[test]
    fn sapling_append() {
        check_append(new_tree::<SaplingPoolTester>);
    }

    #[test]
    fn sapling_root_hashes() {
        check_root_hashes(new_tree::<SaplingPoolTester>);
    }

    #[test]
    fn sapling_witnesses() {
        check_witnesses(new_tree::<SaplingPoolTester>);
    }

    #[test]
    fn sapling_witness_consistency() {
        check_witness_consistency(new_tree::<SaplingPoolTester>);
    }

    #[test]
    fn sapling_checkpoint_rewind() {
        check_checkpoint_rewind(new_tree::<SaplingPoolTester>);
    }

    #[test]
    fn sapling_remove_mark() {
        check_remove_mark(new_tree::<SaplingPoolTester>);
    }

    #[test]
    fn sapling_rewind_remove_mark() {
        check_rewind_remove_mark(new_tree::<SaplingPoolTester>);
    }

    #[test]
    fn sapling_put_shard_roots() {
        put_shard_roots::<SaplingPoolTester>()
    }

    /// Checks that shard roots written with `put_shard_roots` combine with leaves
    /// appended later.
    ///
    /// Four shard roots are stored for pool `T` inside an uncommitted transaction.
    /// Leaves are then appended into the last of those shards, and the resulting
    /// witness path must mix in-shard nodes with the stored roots.
    fn put_shard_roots<T: ShieldedPoolTester + ShieldedPoolPersistence>() {
        let mut conn = migrated_wallet_conn();
        let tx = conn.transaction().unwrap();
        let store =
            SqliteShardStore::<_, String, 3>::from_connection(&tx, T::TABLES_PREFIX).unwrap();

        // Shards 0..=2 get stand-in root hashes "0", "1" and "2". Shard 3 gets
        // "abcdefgh", which corresponds to the leaves appended below. Shard `idx` is
        // recorded as ending at height 3 * (idx + 1).
        let roots: Vec<_> = (0u32..4)
            .map(|idx| {
                let root_hash = match idx {
                    3 => "abcdefgh".to_string(),
                    _ => idx.to_string(),
                };
                CommitmentTreeRoot::from_parts(BlockHeight::from((idx + 1) * 3), root_hash)
            })
            .collect();
        super::put_shard_roots::<_, 6, 3>(store.conn, T::TABLES_PREFIX, 0, &roots).unwrap();

        // With a shard height of 3 each shard spans 8 leaves, so position 24 is the
        // first leaf of shard 3. Leaf 'c' (position 26) is marked, and a checkpoint is
        // taken at leaf 'h'.
        let mut tree = ShardTree::<_, 6, 3>::new(store, 10);
        let checkpoint_height = BlockHeight::from(3);
        let leaves = ('a'..='h').map(|c| {
            let retention = match c {
                'c' => Retention::Marked,
                'h' => Retention::Checkpoint {
                    id: checkpoint_height,
                    marking: Marking::None,
                },
                _ => Retention::Ephemeral,
            };
            (c.to_string(), retention)
        });
        tree.batch_insert(Position::from(24), leaves).unwrap();

        // Expected witness for 'c': siblings built from the inserted leaves, then the
        // stored roots of shard 2 and of shards 0-1. The final element is presumably
        // the empty-subtree hash at the top level.
        let witness = tree
            .witness_at_checkpoint_id(Position::from(26), &checkpoint_height)
            .unwrap()
            .expect("an anchor exists at the expected checkpoint height");
        assert_eq!(
            witness.path_elems(),
            &[
                "d",
                "ab",
                "efgh",
                "2",
                "01",
                "________________________________"
            ]
        );
    }
}