zcash_client_sqlite/wallet/commitment_tree.rs

1use rusqlite::{self, named_params, OptionalExtension};
2use std::{
3    collections::BTreeSet,
4    error, fmt,
5    io::{self, Cursor},
6    marker::PhantomData,
7    num::NonZeroU32,
8    ops::Range,
9    sync::Arc,
10};
11
12use incrementalmerkletree::{Address, Hashable, Level, Position, Retention};
13use shardtree::{
14    error::{QueryError, ShardTreeError},
15    store::{Checkpoint, ShardStore, TreeState},
16    LocatedPrunableTree, LocatedTree, PrunableTree, RetentionFlags,
17};
18
19use zcash_client_backend::{
20    data_api::chain::CommitmentTreeRoot,
21    serialization::shardtree::{read_shard, write_shard},
22};
23use zcash_primitives::merkle_tree::HashSer;
24use zcash_protocol::{consensus::BlockHeight, ShieldedProtocol};
25
26use crate::{error::SqliteClientError, sapling_tree};
27
28#[cfg(feature = "orchard")]
29use crate::orchard_tree;
30
31/// Errors that can appear in SQLite-back [`ShardStore`] implementation operations.
32#[derive(Debug)]
33pub enum Error {
34    /// Errors in deserializing stored shard data
35    Serialization(io::Error),
36    /// Errors encountered querying stored shard data
37    Query(rusqlite::Error),
38    /// Raised when the caller attempts to add a checkpoint at a block height where a checkpoint
39    /// already exists, but the tree state being checkpointed or the marks removed at that
40    /// checkpoint conflict with the existing tree state.
41    CheckpointConflict {
42        checkpoint_id: BlockHeight,
43        checkpoint: Checkpoint,
44        extant_tree_state: TreeState,
45        extant_marks_removed: Option<BTreeSet<Position>>,
46    },
47    /// Raised when attempting to add shard roots to the database that
48    /// are discontinuous with the existing roots in the database.
49    SubtreeDiscontinuity {
50        attempted_insertion_range: Range<u64>,
51        existing_range: Range<u64>,
52    },
53}
54
55impl fmt::Display for Error {
56    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
57        match &self {
58            Error::Serialization(err) => write!(f, "Commitment tree serialization error: {}", err),
59            Error::Query(err) => write!(f, "Commitment tree query or update error: {}", err),
60            Error::CheckpointConflict {
61                checkpoint_id,
62                checkpoint,
63                extant_tree_state,
64                extant_marks_removed,
65            } => {
66                write!(
67                    f,
68                    "Conflict at checkpoint id {}, tried to insert {:?}, which is incompatible with existing state ({:?}, {:?})",
69                    checkpoint_id, checkpoint, extant_tree_state, extant_marks_removed
70                )
71            }
72            Error::SubtreeDiscontinuity {
73                attempted_insertion_range,
74                existing_range,
75            } => {
76                write!(
77                    f,
78                    "Attempted to write subtree roots with indices {:?} which is discontinuous with existing subtree range {:?}",
79                    attempted_insertion_range, existing_range,
80                )
81            }
82        }
83    }
84}
85
86impl error::Error for Error {
87    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
88        match &self {
89            Error::Serialization(e) => Some(e),
90            Error::Query(e) => Some(e),
91            Error::CheckpointConflict { .. } => None,
92            Error::SubtreeDiscontinuity { .. } => None,
93        }
94    }
95}
96
/// A shard store whose data lives in SQLite tables named `{table_prefix}_tree_*`.
///
/// `C` is the connection handle. The `ShardStore` impls in this module are provided for an
/// owned `rusqlite::Connection` and for a borrowed `rusqlite::Transaction`. `H` is the node
/// hash type, which is tracked only at the type level. `SHARD_HEIGHT` is the tree level at
/// which shards are rooted.
pub struct SqliteShardStore<C, H, const SHARD_HEIGHT: u8> {
    // Handle used for every query issued by this store.
    pub(crate) conn: C,
    // Prefix prepended to every table name this store reads or writes.
    table_prefix: &'static str,
    // Zero-sized marker binding the hash type `H` to the store.
    _hash_type: PhantomData<H>,
}
102
103impl<C, H, const SHARD_HEIGHT: u8> SqliteShardStore<C, H, SHARD_HEIGHT> {
104    const SHARD_ROOT_LEVEL: Level = Level::new(SHARD_HEIGHT);
105
106    pub(crate) fn from_connection(
107        conn: C,
108        table_prefix: &'static str,
109    ) -> Result<Self, rusqlite::Error> {
110        Ok(SqliteShardStore {
111            conn,
112            table_prefix,
113            _hash_type: PhantomData,
114        })
115    }
116}
117
118impl<'conn, 'a: 'conn, H: HashSer, const SHARD_HEIGHT: u8> ShardStore
119    for SqliteShardStore<&'a rusqlite::Transaction<'conn>, H, SHARD_HEIGHT>
120{
121    type H = H;
122    type CheckpointId = BlockHeight;
123    type Error = Error;
124
125    fn get_shard(
126        &self,
127        shard_root: Address,
128    ) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> {
129        get_shard(self.conn, self.table_prefix, shard_root)
130    }
131
132    fn last_shard(&self) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> {
133        last_shard(self.conn, self.table_prefix, Self::SHARD_ROOT_LEVEL)
134    }
135
136    fn put_shard(&mut self, subtree: LocatedPrunableTree<Self::H>) -> Result<(), Self::Error> {
137        put_shard(self.conn, self.table_prefix, subtree)
138    }
139
140    fn get_shard_roots(&self) -> Result<Vec<Address>, Self::Error> {
141        get_shard_roots(self.conn, self.table_prefix, Self::SHARD_ROOT_LEVEL)
142    }
143
144    fn truncate_shards(&mut self, shard_index: u64) -> Result<(), Self::Error> {
145        truncate_shards(self.conn, self.table_prefix, shard_index)
146    }
147
148    fn get_cap(&self) -> Result<PrunableTree<Self::H>, Self::Error> {
149        get_cap(self.conn, self.table_prefix)
150    }
151
152    fn put_cap(&mut self, cap: PrunableTree<Self::H>) -> Result<(), Self::Error> {
153        put_cap(self.conn, self.table_prefix, cap)
154    }
155
156    fn min_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
157        min_checkpoint_id(self.conn, self.table_prefix)
158    }
159
160    fn max_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
161        max_checkpoint_id(self.conn, self.table_prefix)
162    }
163
164    fn add_checkpoint(
165        &mut self,
166        checkpoint_id: Self::CheckpointId,
167        checkpoint: Checkpoint,
168    ) -> Result<(), Self::Error> {
169        add_checkpoint(self.conn, self.table_prefix, checkpoint_id, checkpoint)
170    }
171
172    fn checkpoint_count(&self) -> Result<usize, Self::Error> {
173        checkpoint_count(self.conn, self.table_prefix)
174    }
175
176    fn get_checkpoint_at_depth(
177        &self,
178        checkpoint_depth: usize,
179    ) -> Result<Option<(Self::CheckpointId, Checkpoint)>, Self::Error> {
180        get_checkpoint_at_depth(self.conn, self.table_prefix, checkpoint_depth)
181            .map_err(Error::Query)
182    }
183
184    fn get_checkpoint(
185        &self,
186        checkpoint_id: &Self::CheckpointId,
187    ) -> Result<Option<Checkpoint>, Self::Error> {
188        get_checkpoint(self.conn, self.table_prefix, *checkpoint_id)
189    }
190
191    fn with_checkpoints<F>(&mut self, limit: usize, callback: F) -> Result<(), Self::Error>
192    where
193        F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>,
194    {
195        with_checkpoints(self.conn, self.table_prefix, limit, callback)
196    }
197
198    fn for_each_checkpoint<F>(&self, limit: usize, callback: F) -> Result<(), Self::Error>
199    where
200        F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>,
201    {
202        with_checkpoints(self.conn, self.table_prefix, limit, callback)
203    }
204
205    fn update_checkpoint_with<F>(
206        &mut self,
207        checkpoint_id: &Self::CheckpointId,
208        update: F,
209    ) -> Result<bool, Self::Error>
210    where
211        F: Fn(&mut Checkpoint) -> Result<(), Self::Error>,
212    {
213        update_checkpoint_with(self.conn, self.table_prefix, *checkpoint_id, update)
214    }
215
216    fn remove_checkpoint(&mut self, checkpoint_id: &Self::CheckpointId) -> Result<(), Self::Error> {
217        remove_checkpoint(self.conn, self.table_prefix, *checkpoint_id)
218    }
219
220    fn truncate_checkpoints_retaining(
221        &mut self,
222        checkpoint_id: &Self::CheckpointId,
223    ) -> Result<(), Self::Error> {
224        truncate_checkpoints_retaining(self.conn, self.table_prefix, *checkpoint_id)
225    }
226}
227
228impl<H: HashSer, const SHARD_HEIGHT: u8> ShardStore
229    for SqliteShardStore<rusqlite::Connection, H, SHARD_HEIGHT>
230{
231    type H = H;
232    type CheckpointId = BlockHeight;
233    type Error = Error;
234
235    fn get_shard(
236        &self,
237        shard_root: Address,
238    ) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> {
239        get_shard(&self.conn, self.table_prefix, shard_root)
240    }
241
242    fn last_shard(&self) -> Result<Option<LocatedPrunableTree<Self::H>>, Self::Error> {
243        last_shard(&self.conn, self.table_prefix, Self::SHARD_ROOT_LEVEL)
244    }
245
246    fn put_shard(&mut self, subtree: LocatedPrunableTree<Self::H>) -> Result<(), Self::Error> {
247        let tx = self.conn.transaction().map_err(Error::Query)?;
248        put_shard(&tx, self.table_prefix, subtree)?;
249        tx.commit().map_err(Error::Query)?;
250        Ok(())
251    }
252
253    fn get_shard_roots(&self) -> Result<Vec<Address>, Self::Error> {
254        get_shard_roots(&self.conn, self.table_prefix, Self::SHARD_ROOT_LEVEL)
255    }
256
257    fn truncate_shards(&mut self, shard_index: u64) -> Result<(), Self::Error> {
258        truncate_shards(&self.conn, self.table_prefix, shard_index)
259    }
260
261    fn get_cap(&self) -> Result<PrunableTree<Self::H>, Self::Error> {
262        get_cap(&self.conn, self.table_prefix)
263    }
264
265    fn put_cap(&mut self, cap: PrunableTree<Self::H>) -> Result<(), Self::Error> {
266        put_cap(&self.conn, self.table_prefix, cap)
267    }
268
269    fn min_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
270        min_checkpoint_id(&self.conn, self.table_prefix)
271    }
272
273    fn max_checkpoint_id(&self) -> Result<Option<Self::CheckpointId>, Self::Error> {
274        max_checkpoint_id(&self.conn, self.table_prefix)
275    }
276
277    fn add_checkpoint(
278        &mut self,
279        checkpoint_id: Self::CheckpointId,
280        checkpoint: Checkpoint,
281    ) -> Result<(), Self::Error> {
282        let tx = self.conn.transaction().map_err(Error::Query)?;
283        add_checkpoint(&tx, self.table_prefix, checkpoint_id, checkpoint)?;
284        tx.commit().map_err(Error::Query)
285    }
286
287    fn checkpoint_count(&self) -> Result<usize, Self::Error> {
288        checkpoint_count(&self.conn, self.table_prefix)
289    }
290
291    fn get_checkpoint_at_depth(
292        &self,
293        checkpoint_depth: usize,
294    ) -> Result<Option<(Self::CheckpointId, Checkpoint)>, Self::Error> {
295        get_checkpoint_at_depth(&self.conn, self.table_prefix, checkpoint_depth)
296            .map_err(Error::Query)
297    }
298
299    fn get_checkpoint(
300        &self,
301        checkpoint_id: &Self::CheckpointId,
302    ) -> Result<Option<Checkpoint>, Self::Error> {
303        get_checkpoint(&self.conn, self.table_prefix, *checkpoint_id)
304    }
305
306    fn with_checkpoints<F>(&mut self, limit: usize, callback: F) -> Result<(), Self::Error>
307    where
308        F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>,
309    {
310        let tx = self.conn.transaction().map_err(Error::Query)?;
311        with_checkpoints(&tx, self.table_prefix, limit, callback)?;
312        tx.commit().map_err(Error::Query)
313    }
314
315    fn for_each_checkpoint<F>(&self, limit: usize, callback: F) -> Result<(), Self::Error>
316    where
317        F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>,
318    {
319        let tx = self.conn.unchecked_transaction().map_err(Error::Query)?;
320        with_checkpoints(&tx, self.table_prefix, limit, callback)?;
321        // Here, we use `tx.rollback` as the semantics of this method is that the callback must
322        // not mutate the data store.
323        tx.rollback().map_err(Error::Query)
324    }
325
326    fn update_checkpoint_with<F>(
327        &mut self,
328        checkpoint_id: &Self::CheckpointId,
329        update: F,
330    ) -> Result<bool, Self::Error>
331    where
332        F: Fn(&mut Checkpoint) -> Result<(), Self::Error>,
333    {
334        let tx = self.conn.transaction().map_err(Error::Query)?;
335        let result = update_checkpoint_with(&tx, self.table_prefix, *checkpoint_id, update)?;
336        tx.commit().map_err(Error::Query)?;
337        Ok(result)
338    }
339
340    fn remove_checkpoint(&mut self, checkpoint_id: &Self::CheckpointId) -> Result<(), Self::Error> {
341        let tx = self.conn.transaction().map_err(Error::Query)?;
342        remove_checkpoint(&tx, self.table_prefix, *checkpoint_id)?;
343        tx.commit().map_err(Error::Query)
344    }
345
346    fn truncate_checkpoints_retaining(
347        &mut self,
348        checkpoint_id: &Self::CheckpointId,
349    ) -> Result<(), Self::Error> {
350        let tx = self.conn.transaction().map_err(Error::Query)?;
351        truncate_checkpoints_retaining(&tx, self.table_prefix, *checkpoint_id)?;
352        tx.commit().map_err(Error::Query)
353    }
354}
355
356pub(crate) fn get_shard<H: HashSer>(
357    conn: &rusqlite::Connection,
358    table_prefix: &'static str,
359    shard_root_addr: Address,
360) -> Result<Option<LocatedPrunableTree<H>>, Error> {
361    conn.query_row(
362        &format!(
363            "SELECT shard_data, root_hash
364             FROM {}_tree_shards
365             WHERE shard_index = :shard_index",
366            table_prefix
367        ),
368        named_params![":shard_index": shard_root_addr.index()],
369        |row| Ok((row.get::<_, Vec<u8>>(0)?, row.get::<_, Option<Vec<u8>>>(1)?)),
370    )
371    .optional()
372    .map_err(Error::Query)?
373    .map(|(shard_data, root_hash)| {
374        let shard_tree = read_shard(&mut Cursor::new(shard_data)).map_err(Error::Serialization)?;
375        let located_tree =
376            LocatedPrunableTree::from_parts(shard_root_addr, shard_tree).map_err(|e| {
377                Error::Serialization(io::Error::new(
378                    io::ErrorKind::InvalidData,
379                    format!("Tree contained invalid data at address {:?}", e),
380                ))
381            })?;
382        if let Some(root_hash_data) = root_hash {
383            let root_hash = H::read(Cursor::new(root_hash_data)).map_err(Error::Serialization)?;
384            Ok(located_tree.reannotate_root(Some(Arc::new(root_hash))))
385        } else {
386            Ok(located_tree)
387        }
388    })
389    .transpose()
390}
391
392pub(crate) fn last_shard<H: HashSer>(
393    conn: &rusqlite::Connection,
394    table_prefix: &'static str,
395    shard_root_level: Level,
396) -> Result<Option<LocatedPrunableTree<H>>, Error> {
397    conn.query_row(
398        &format!(
399            "SELECT shard_index, shard_data
400             FROM {}_tree_shards
401             ORDER BY shard_index DESC
402             LIMIT 1",
403            table_prefix
404        ),
405        [],
406        |row| {
407            let shard_index: u64 = row.get(0)?;
408            let shard_data: Vec<u8> = row.get(1)?;
409            Ok((shard_index, shard_data))
410        },
411    )
412    .optional()
413    .map_err(Error::Query)?
414    .map(|(shard_index, shard_data)| {
415        let shard_root = Address::from_parts(shard_root_level, shard_index);
416        let shard_tree = read_shard(&mut Cursor::new(shard_data)).map_err(Error::Serialization)?;
417        LocatedPrunableTree::from_parts(shard_root, shard_tree).map_err(|e| {
418            Error::Serialization(io::Error::new(
419                io::ErrorKind::InvalidData,
420                format!("Tree contained invalid data at address {:?}", e),
421            ))
422        })
423    })
424    .transpose()
425}
426
427/// Returns an error iff the proposed insertion range
428/// for the tree shards would create a discontinuity
429/// in the database.
430#[tracing::instrument(skip(conn))]
431fn check_shard_discontinuity(
432    conn: &rusqlite::Connection,
433    table_prefix: &'static str,
434    proposed_insertion_range: Range<u64>,
435) -> Result<(), Error> {
436    if let Ok((Some(stored_min), Some(stored_max))) = conn
437        .query_row(
438            &format!(
439                "SELECT MIN(shard_index), MAX(shard_index) FROM {}_tree_shards",
440                table_prefix
441            ),
442            [],
443            |row| {
444                let min = row.get::<_, Option<u64>>(0)?;
445                let max = row.get::<_, Option<u64>>(1)?;
446                Ok((min, max))
447            },
448        )
449        .map_err(Error::Query)
450    {
451        // If the ranges overlap, or are directly adjacent, then we aren't creating a
452        // discontinuity. We can check this by comparing their start-inclusive,
453        // end-exclusive bounds:
454        // - If `cur_start == ins_end` then the proposed insertion range is immediately
455        //   before the current shards. If `cur_start > ins_end` then there is a gap.
456        // - If `ins_start == cur_end` then the proposed insertion range is immediately
457        //   after the current shards. If `ins_start > cur_end` then there is a gap.
458        let (cur_start, cur_end) = (stored_min, stored_max + 1);
459        let (ins_start, ins_end) = (proposed_insertion_range.start, proposed_insertion_range.end);
460        if cur_start > ins_end || ins_start > cur_end {
461            return Err(Error::SubtreeDiscontinuity {
462                attempted_insertion_range: proposed_insertion_range,
463                existing_range: cur_start..cur_end,
464            });
465        }
466    }
467
468    Ok(())
469}
470
471pub(crate) fn put_shard<H: HashSer>(
472    conn: &rusqlite::Transaction<'_>,
473    table_prefix: &'static str,
474    subtree: LocatedPrunableTree<H>,
475) -> Result<(), Error> {
476    let subtree_root_hash = subtree
477        .root()
478        .annotation()
479        .and_then(|ann| {
480            ann.as_ref().map(|rc| {
481                let mut root_hash = vec![];
482                rc.write(&mut root_hash)?;
483                Ok(root_hash)
484            })
485        })
486        .transpose()
487        .map_err(Error::Serialization)?;
488
489    let mut subtree_data = vec![];
490    write_shard(&mut subtree_data, subtree.root()).map_err(Error::Serialization)?;
491
492    let shard_index = subtree.root_addr().index();
493
494    check_shard_discontinuity(conn, table_prefix, shard_index..shard_index + 1)?;
495
496    let mut stmt_put_shard = conn
497        .prepare_cached(&format!(
498            "INSERT INTO {}_tree_shards (shard_index, root_hash, shard_data)
499             VALUES (:shard_index, :root_hash, :shard_data)
500             ON CONFLICT (shard_index) DO UPDATE
501             SET root_hash = :root_hash,
502             shard_data = :shard_data",
503            table_prefix
504        ))
505        .map_err(Error::Query)?;
506
507    stmt_put_shard
508        .execute(named_params![
509            ":shard_index": shard_index,
510            ":root_hash": subtree_root_hash,
511            ":shard_data": subtree_data
512        ])
513        .map_err(Error::Query)?;
514
515    Ok(())
516}
517
518pub(crate) fn get_shard_roots(
519    conn: &rusqlite::Connection,
520    table_prefix: &'static str,
521    shard_root_level: Level,
522) -> Result<Vec<Address>, Error> {
523    let mut stmt = conn
524        .prepare(&format!(
525            "SELECT shard_index FROM {}_tree_shards ORDER BY shard_index",
526            table_prefix
527        ))
528        .map_err(Error::Query)?;
529    let mut rows = stmt.query([]).map_err(Error::Query)?;
530
531    let mut res = vec![];
532    while let Some(row) = rows.next().map_err(Error::Query)? {
533        res.push(Address::from_parts(
534            shard_root_level,
535            row.get(0).map_err(Error::Query)?,
536        ));
537    }
538    Ok(res)
539}
540
541pub(crate) fn truncate_shards(
542    conn: &rusqlite::Connection,
543    table_prefix: &'static str,
544    shard_index: u64,
545) -> Result<(), Error> {
546    conn.execute(
547        &format!(
548            "DELETE FROM {}_tree_shards WHERE shard_index >= ?",
549            table_prefix
550        ),
551        [shard_index],
552    )
553    .map_err(Error::Query)
554    .map(|_| ())
555}
556
557#[tracing::instrument(skip(conn))]
558pub(crate) fn get_cap<H: HashSer>(
559    conn: &rusqlite::Connection,
560    table_prefix: &'static str,
561) -> Result<PrunableTree<H>, Error> {
562    conn.query_row(
563        &format!("SELECT cap_data FROM {}_tree_cap", table_prefix),
564        [],
565        |row| row.get::<_, Vec<u8>>(0),
566    )
567    .optional()
568    .map_err(Error::Query)?
569    .map_or_else(
570        || Ok(PrunableTree::empty()),
571        |cap_data| read_shard(&mut Cursor::new(cap_data)).map_err(Error::Serialization),
572    )
573}
574
575#[tracing::instrument(skip(conn, cap))]
576pub(crate) fn put_cap<H: HashSer>(
577    conn: &rusqlite::Connection,
578    table_prefix: &'static str,
579    cap: PrunableTree<H>,
580) -> Result<(), Error> {
581    let mut stmt = conn
582        .prepare_cached(&format!(
583            "INSERT INTO {}_tree_cap (cap_id, cap_data)
584             VALUES (0, :cap_data)
585             ON CONFLICT (cap_id) DO UPDATE
586             SET cap_data = :cap_data",
587            table_prefix
588        ))
589        .map_err(Error::Query)?;
590
591    let mut cap_data = vec![];
592    write_shard(&mut cap_data, &cap).map_err(Error::Serialization)?;
593    stmt.execute([cap_data]).map_err(Error::Query)?;
594
595    Ok(())
596}
597
598pub(crate) fn min_checkpoint_id(
599    conn: &rusqlite::Connection,
600    table_prefix: &'static str,
601) -> Result<Option<BlockHeight>, Error> {
602    conn.query_row(
603        &format!(
604            "SELECT MIN(checkpoint_id) FROM {}_tree_checkpoints",
605            table_prefix
606        ),
607        [],
608        |row| {
609            row.get::<_, Option<u32>>(0)
610                .map(|opt| opt.map(BlockHeight::from))
611        },
612    )
613    .map_err(Error::Query)
614}
615
616pub(crate) fn max_checkpoint_id(
617    conn: &rusqlite::Connection,
618    table_prefix: &'static str,
619) -> Result<Option<BlockHeight>, Error> {
620    conn.query_row(
621        &format!(
622            "SELECT MAX(checkpoint_id) FROM {}_tree_checkpoints",
623            table_prefix
624        ),
625        [],
626        |row| {
627            row.get::<_, Option<u32>>(0)
628                .map(|opt| opt.map(BlockHeight::from))
629        },
630    )
631    .map_err(Error::Query)
632}
633
634pub(crate) fn add_checkpoint(
635    conn: &rusqlite::Transaction<'_>,
636    table_prefix: &'static str,
637    checkpoint_id: BlockHeight,
638    checkpoint: Checkpoint,
639) -> Result<(), Error> {
640    let extant_tree_state = conn
641        .query_row(
642            &format!(
643                "SELECT position FROM {}_tree_checkpoints WHERE checkpoint_id = :checkpoint_id",
644                table_prefix
645            ),
646            named_params![":checkpoint_id": u32::from(checkpoint_id),],
647            |row| {
648                row.get::<_, Option<u64>>(0).map(|opt| {
649                    opt.map_or_else(
650                        || TreeState::Empty,
651                        |pos| TreeState::AtPosition(Position::from(pos)),
652                    )
653                })
654            },
655        )
656        .optional()
657        .map_err(Error::Query)?;
658
659    match extant_tree_state {
660        Some(current) => {
661            if current != checkpoint.tree_state() {
662                // If the checkpoint position for a given checkpoint identifier has changed, we treat
663                // this as an error because the wallet should have detected a chain reorg and truncated
664                // the tree.
665                Err(Error::CheckpointConflict {
666                    checkpoint_id,
667                    checkpoint,
668                    extant_tree_state: current,
669                    extant_marks_removed: None,
670                })
671            } else {
672                // if the existing spends are the same, we can skip the insert; if the
673                // existing spends have changed, this is also a conflict.
674                let marks_removed = get_marks_removed(conn, table_prefix, checkpoint_id)?;
675                if &marks_removed == checkpoint.marks_removed() {
676                    Ok(())
677                } else {
678                    Err(Error::CheckpointConflict {
679                        checkpoint_id,
680                        checkpoint,
681                        extant_tree_state: current,
682                        extant_marks_removed: Some(marks_removed),
683                    })
684                }
685            }
686        }
687        None => {
688            let mut stmt_insert_checkpoint = conn
689                .prepare_cached(&format!(
690                    "INSERT INTO {}_tree_checkpoints (checkpoint_id, position)
691                     VALUES (:checkpoint_id, :position)",
692                    table_prefix
693                ))
694                .map_err(Error::Query)?;
695
696            stmt_insert_checkpoint
697                .execute(named_params![
698                    ":checkpoint_id": u32::from(checkpoint_id),
699                    ":position": checkpoint.position().map(u64::from)
700                ])
701                .map_err(Error::Query)?;
702
703            let mut stmt_insert_mark_removed = conn
704                .prepare_cached(&format!(
705                    "INSERT INTO {}_tree_checkpoint_marks_removed (checkpoint_id, mark_removed_position)
706                     VALUES (:checkpoint_id, :position)",
707                    table_prefix
708                ))
709                .map_err(Error::Query)?;
710
711            for pos in checkpoint.marks_removed() {
712                stmt_insert_mark_removed
713                    .execute(named_params![
714                        ":checkpoint_id": u32::from(checkpoint_id),
715                        ":position": u64::from(*pos)
716                    ])
717                    .map_err(Error::Query)?;
718            }
719
720            Ok(())
721        }
722    }
723}
724
725pub(crate) fn checkpoint_count(
726    conn: &rusqlite::Connection,
727    table_prefix: &'static str,
728) -> Result<usize, Error> {
729    conn.query_row(
730        &format!("SELECT COUNT(*) FROM {}_tree_checkpoints", table_prefix),
731        [],
732        |row| row.get::<_, usize>(0),
733    )
734    .map_err(Error::Query)
735}
736
737fn get_marks_removed(
738    conn: &rusqlite::Connection,
739    table_prefix: &'static str,
740    checkpoint_id: BlockHeight,
741) -> Result<BTreeSet<Position>, Error> {
742    let mut stmt = conn
743        .prepare_cached(&format!(
744            "SELECT mark_removed_position
745            FROM {}_tree_checkpoint_marks_removed
746            WHERE checkpoint_id = ?",
747            table_prefix
748        ))
749        .map_err(Error::Query)?;
750    let mark_removed_rows = stmt
751        .query([u32::from(checkpoint_id)])
752        .map_err(Error::Query)?;
753
754    mark_removed_rows
755        .mapped(|row| row.get::<_, u64>(0).map(Position::from))
756        .collect::<Result<BTreeSet<_>, _>>()
757        .map_err(Error::Query)
758}
759
760pub(crate) fn get_checkpoint(
761    conn: &rusqlite::Connection,
762    table_prefix: &'static str,
763    checkpoint_id: BlockHeight,
764) -> Result<Option<Checkpoint>, Error> {
765    let checkpoint_position = conn
766        .query_row(
767            &format!(
768                "SELECT position
769                 FROM {}_tree_checkpoints
770                 WHERE checkpoint_id = ?",
771                table_prefix
772            ),
773            [u32::from(checkpoint_id)],
774            |row| {
775                row.get::<_, Option<u64>>(0)
776                    .map(|opt| opt.map(Position::from))
777            },
778        )
779        .optional()
780        .map_err(Error::Query)?;
781
782    checkpoint_position
783        .map(|pos_opt| {
784            Ok(Checkpoint::from_parts(
785                pos_opt.map_or(TreeState::Empty, TreeState::AtPosition),
786                get_marks_removed(conn, table_prefix, checkpoint_id)?,
787            ))
788        })
789        .transpose()
790}
791
792pub(crate) fn get_max_checkpointed_height(
793    conn: &rusqlite::Connection,
794    table_prefix: &'static str,
795    chain_tip_height: BlockHeight,
796    min_confirmations: NonZeroU32,
797) -> Result<Option<BlockHeight>, rusqlite::Error> {
798    let max_checkpoint_height =
799        u32::from(chain_tip_height).saturating_sub(u32::from(min_confirmations) - 1);
800
801    // We exclude from consideration all checkpoints having heights greater than the maximum
802    // checkpoint height. The checkpoint depth is the number of excluded checkpoints + 1.
803    conn.query_row(
804        &format!(
805            "SELECT checkpoint_id
806             FROM {}_tree_checkpoints
807             WHERE checkpoint_id <= :max_checkpoint_height
808             ORDER BY checkpoint_id DESC
809             LIMIT 1",
810            table_prefix
811        ),
812        named_params![":max_checkpoint_height": max_checkpoint_height],
813        |row| row.get::<_, u32>(0).map(BlockHeight::from),
814    )
815    .optional()
816}
817
818pub(crate) fn get_checkpoint_at_depth(
819    conn: &rusqlite::Connection,
820    table_prefix: &'static str,
821    checkpoint_depth: usize,
822) -> Result<Option<(BlockHeight, Checkpoint)>, rusqlite::Error> {
823    let checkpoint_parts = conn
824        .query_row(
825            &format!(
826                "SELECT checkpoint_id, position
827                FROM {}_tree_checkpoints
828                ORDER BY checkpoint_id DESC
829                LIMIT 1
830                OFFSET :offset",
831                table_prefix
832            ),
833            named_params![":offset": checkpoint_depth],
834            |row| {
835                let checkpoint_id: u32 = row.get(0)?;
836                let position: Option<u64> = row.get(1)?;
837                Ok((
838                    BlockHeight::from(checkpoint_id),
839                    position.map(Position::from),
840                ))
841            },
842        )
843        .optional()?;
844
845    checkpoint_parts
846        .map(|(checkpoint_id, pos_opt)| {
847            let mut stmt = conn.prepare_cached(&format!(
848                "SELECT mark_removed_position
849                    FROM {}_tree_checkpoint_marks_removed
850                    WHERE checkpoint_id = ?",
851                table_prefix
852            ))?;
853            let mark_removed_rows = stmt.query([u32::from(checkpoint_id)])?;
854
855            let marks_removed = mark_removed_rows
856                .mapped(|row| row.get::<_, u64>(0).map(Position::from))
857                .collect::<Result<BTreeSet<_>, _>>()?;
858
859            Ok((
860                checkpoint_id,
861                Checkpoint::from_parts(
862                    pos_opt.map_or(TreeState::Empty, TreeState::AtPosition),
863                    marks_removed,
864                ),
865            ))
866        })
867        .transpose()
868}
869
870pub(crate) fn with_checkpoints<F>(
871    conn: &rusqlite::Transaction<'_>,
872    table_prefix: &'static str,
873    limit: usize,
874    mut callback: F,
875) -> Result<(), Error>
876where
877    F: FnMut(&BlockHeight, &Checkpoint) -> Result<(), Error>,
878{
879    let mut stmt_get_checkpoints = conn
880        .prepare_cached(&format!(
881            "SELECT checkpoint_id, position
882            FROM {}_tree_checkpoints
883            ORDER BY position
884            LIMIT :limit",
885            table_prefix
886        ))
887        .map_err(Error::Query)?;
888
889    let mut stmt_get_checkpoint_marks_removed = conn
890        .prepare_cached(&format!(
891            "SELECT mark_removed_position
892            FROM {}_tree_checkpoint_marks_removed
893            WHERE checkpoint_id = :checkpoint_id",
894            table_prefix
895        ))
896        .map_err(Error::Query)?;
897
898    let mut rows = stmt_get_checkpoints
899        .query(named_params![":limit": limit])
900        .map_err(Error::Query)?;
901
902    while let Some(row) = rows.next().map_err(Error::Query)? {
903        let checkpoint_id = row.get::<_, u32>(0).map_err(Error::Query)?;
904        let tree_state = row
905            .get::<_, Option<u64>>(1)
906            .map(|opt| opt.map_or_else(|| TreeState::Empty, |p| TreeState::AtPosition(p.into())))
907            .map_err(Error::Query)?;
908
909        let mark_removed_rows = stmt_get_checkpoint_marks_removed
910            .query(named_params![":checkpoint_id": checkpoint_id])
911            .map_err(Error::Query)?;
912
913        let marks_removed = mark_removed_rows
914            .mapped(|row| row.get::<_, u64>(0).map(Position::from))
915            .collect::<Result<BTreeSet<_>, _>>()
916            .map_err(Error::Query)?;
917
918        callback(
919            &BlockHeight::from(checkpoint_id),
920            &Checkpoint::from_parts(tree_state, marks_removed),
921        )?
922    }
923
924    Ok(())
925}
926
927pub(crate) fn update_checkpoint_with<F>(
928    conn: &rusqlite::Transaction<'_>,
929    table_prefix: &'static str,
930    checkpoint_id: BlockHeight,
931    update: F,
932) -> Result<bool, Error>
933where
934    F: Fn(&mut Checkpoint) -> Result<(), Error>,
935{
936    if let Some(mut c) = get_checkpoint(conn, table_prefix, checkpoint_id)? {
937        update(&mut c)?;
938        remove_checkpoint(conn, table_prefix, checkpoint_id)?;
939        add_checkpoint(conn, table_prefix, checkpoint_id, c)?;
940        Ok(true)
941    } else {
942        Ok(false)
943    }
944}
945
946pub(crate) fn remove_checkpoint(
947    conn: &rusqlite::Transaction<'_>,
948    table_prefix: &'static str,
949    checkpoint_id: BlockHeight,
950) -> Result<(), Error> {
951    // cascading delete here obviates the need to manually delete from
952    // `tree_checkpoint_marks_removed`
953    let mut stmt_delete_checkpoint = conn
954        .prepare_cached(&format!(
955            "DELETE FROM {}_tree_checkpoints
956             WHERE checkpoint_id = :checkpoint_id",
957            table_prefix
958        ))
959        .map_err(Error::Query)?;
960
961    stmt_delete_checkpoint
962        .execute(named_params![":checkpoint_id": u32::from(checkpoint_id),])
963        .map_err(Error::Query)?;
964
965    Ok(())
966}
967
968pub(crate) fn truncate_checkpoints_retaining(
969    conn: &rusqlite::Transaction<'_>,
970    table_prefix: &'static str,
971    checkpoint_id: BlockHeight,
972) -> Result<(), Error> {
973    // cascading delete here obviates the need to manually delete from
974    // `<protocol>_tree_checkpoint_marks_removed`
975    conn.execute(
976        &format!(
977            "DELETE FROM {}_tree_checkpoints WHERE checkpoint_id > ?",
978            table_prefix
979        ),
980        [u32::from(checkpoint_id)],
981    )
982    .map_err(Error::Query)?;
983
984    // we do however need to manually delete any marks associated with the retained checkpoint
985    conn.execute(
986        &format!(
987            "DELETE FROM {}_tree_checkpoint_marks_removed WHERE checkpoint_id = ?",
988            table_prefix
989        ),
990        [u32::from(checkpoint_id)],
991    )
992    .map_err(Error::Query)?;
993
994    Ok(())
995}
996
997#[tracing::instrument(skip(conn, roots))]
998pub(crate) fn put_shard_roots<
999    H: Hashable + HashSer + Clone + Eq,
1000    const DEPTH: u8,
1001    const SHARD_HEIGHT: u8,
1002>(
1003    conn: &rusqlite::Transaction<'_>,
1004    table_prefix: &'static str,
1005    start_index: u64,
1006    roots: &[CommitmentTreeRoot<H>],
1007) -> Result<(), ShardTreeError<Error>> {
1008    if roots.is_empty() {
1009        // nothing to do
1010        return Ok(());
1011    }
1012
1013    // We treat the cap as a tree with `DEPTH - SHARD_HEIGHT` levels, so that we can make a
1014    // batch insertion of root data using `Position::from(start_index)` as the starting position
1015    // and treating the roots as level-0 leaves.
1016    #[derive(Clone, Debug, PartialEq, Eq)]
1017    struct LevelShifter<H, const SHARD_HEIGHT: u8>(H);
1018    impl<H: Hashable, const SHARD_HEIGHT: u8> Hashable for LevelShifter<H, SHARD_HEIGHT> {
1019        fn empty_leaf() -> Self {
1020            Self(H::empty_root(SHARD_HEIGHT.into()))
1021        }
1022
1023        fn combine(level: Level, a: &Self, b: &Self) -> Self {
1024            Self(H::combine(level + SHARD_HEIGHT, &a.0, &b.0))
1025        }
1026
1027        fn empty_root(level: Level) -> Self
1028        where
1029            Self: Sized,
1030        {
1031            Self(H::empty_root(level + SHARD_HEIGHT))
1032        }
1033    }
1034    impl<H: HashSer, const SHARD_HEIGHT: u8> HashSer for LevelShifter<H, SHARD_HEIGHT> {
1035        fn read<R: io::Read>(reader: R) -> io::Result<Self>
1036        where
1037            Self: Sized,
1038        {
1039            H::read(reader).map(Self)
1040        }
1041
1042        fn write<W: io::Write>(&self, writer: W) -> io::Result<()> {
1043            self.0.write(writer)
1044        }
1045    }
1046
1047    let cap = LocatedTree::from_parts(
1048        Address::from_parts((DEPTH - SHARD_HEIGHT).into(), 0),
1049        get_cap::<LevelShifter<H, SHARD_HEIGHT>>(conn, table_prefix)
1050            .map_err(ShardTreeError::Storage)?,
1051    )
1052    .map_err(|e| {
1053        ShardTreeError::Storage(Error::Serialization(io::Error::new(
1054            io::ErrorKind::InvalidData,
1055            format!("Note commitment tree cap was invalid at address {:?}", e),
1056        )))
1057    })?;
1058
1059    let insert_into_cap = tracing::info_span!("insert_into_cap").entered();
1060    let cap_result = cap
1061        .batch_insert::<(), _>(
1062            Position::from(start_index),
1063            roots
1064                .iter()
1065                .map(|r| (LevelShifter(r.root_hash().clone()), Retention::Reference)),
1066        )
1067        .map_err(ShardTreeError::Insert)?
1068        .expect("slice of inserted roots was verified to be nonempty");
1069    drop(insert_into_cap);
1070
1071    put_cap(conn, table_prefix, cap_result.subtree.take_root()).map_err(ShardTreeError::Storage)?;
1072
1073    check_shard_discontinuity(
1074        conn,
1075        table_prefix,
1076        start_index..start_index + (roots.len() as u64),
1077    )
1078    .map_err(ShardTreeError::Storage)?;
1079
1080    // We want to avoid deserializing the subtree just to annotate its root node, so we simply
1081    // cache the downloaded root alongside of any already-persisted subtree. We will update the
1082    // subtree data itself by reannotating the root node of the tree, handling conflicts, at
1083    // the time that we deserialize the tree.
1084    let mut stmt = conn
1085        .prepare_cached(&format!(
1086            "INSERT INTO {}_tree_shards (shard_index, subtree_end_height, root_hash, shard_data)
1087            VALUES (:shard_index, :subtree_end_height, :root_hash, :shard_data)
1088            ON CONFLICT (shard_index) DO UPDATE
1089            SET subtree_end_height = :subtree_end_height, root_hash = :root_hash",
1090            table_prefix
1091        ))
1092        .map_err(|e| ShardTreeError::Storage(Error::Query(e)))?;
1093
1094    let put_roots = tracing::info_span!("write_shards").entered();
1095    for (root, i) in roots.iter().zip(0u64..) {
1096        // The `shard_data` value will only be used in the case that no tree already exists.
1097        let mut shard_data: Vec<u8> = vec![];
1098        let tree = PrunableTree::leaf((root.root_hash().clone(), RetentionFlags::EPHEMERAL));
1099        write_shard(&mut shard_data, &tree)
1100            .map_err(|e| ShardTreeError::Storage(Error::Serialization(e)))?;
1101
1102        let mut root_hash_data: Vec<u8> = vec![];
1103        root.root_hash()
1104            .write(&mut root_hash_data)
1105            .map_err(|e| ShardTreeError::Storage(Error::Serialization(e)))?;
1106
1107        stmt.execute(named_params![
1108            ":shard_index": start_index + i,
1109            ":subtree_end_height": u32::from(root.subtree_end_height()),
1110            ":root_hash": root_hash_data,
1111            ":shard_data": shard_data,
1112        ])
1113        .map_err(|e| ShardTreeError::Storage(Error::Query(e)))?;
1114    }
1115    drop(put_roots);
1116
1117    Ok(())
1118}
1119
1120pub(crate) fn check_witnesses(
1121    conn: &rusqlite::Transaction<'_>,
1122) -> Result<Vec<Range<BlockHeight>>, SqliteClientError> {
1123    let chain_tip_height =
1124        super::chain_tip_height(conn)?.ok_or(SqliteClientError::ChainHeightUnknown)?;
1125    let wallet_birthday = super::wallet_birthday(conn)?.ok_or(SqliteClientError::AccountUnknown)?;
1126    let unspent_sapling_note_meta =
1127        super::sapling::select_unspent_note_meta(conn, chain_tip_height, wallet_birthday)?;
1128
1129    let mut scan_ranges = vec![];
1130    let mut sapling_incomplete = vec![];
1131    let sapling_tree = sapling_tree(conn)?;
1132    for m in unspent_sapling_note_meta.iter() {
1133        match sapling_tree.witness_at_checkpoint_depth(m.commitment_tree_position(), 0) {
1134            Ok(_) => {}
1135            Err(ShardTreeError::Query(QueryError::TreeIncomplete(mut addrs))) => {
1136                sapling_incomplete.append(&mut addrs);
1137            }
1138            Err(other) => {
1139                return Err(SqliteClientError::CommitmentTree(other));
1140            }
1141        }
1142    }
1143
1144    for addr in sapling_incomplete {
1145        let range = super::get_block_range(conn, ShieldedProtocol::Sapling, addr)?;
1146        scan_ranges.extend(range.into_iter());
1147    }
1148
1149    #[cfg(feature = "orchard")]
1150    {
1151        let unspent_orchard_note_meta =
1152            super::orchard::select_unspent_note_meta(conn, chain_tip_height, wallet_birthday)?;
1153        let mut orchard_incomplete = vec![];
1154        let orchard_tree = orchard_tree(conn)?;
1155        for m in unspent_orchard_note_meta.iter() {
1156            match orchard_tree.witness_at_checkpoint_depth(m.commitment_tree_position(), 0) {
1157                Ok(_) => {}
1158                Err(ShardTreeError::Query(QueryError::TreeIncomplete(mut addrs))) => {
1159                    orchard_incomplete.append(&mut addrs);
1160                }
1161                Err(other) => {
1162                    return Err(SqliteClientError::CommitmentTree(other));
1163                }
1164            }
1165        }
1166
1167        for addr in orchard_incomplete {
1168            let range = super::get_block_range(conn, ShieldedProtocol::Orchard, addr)?;
1169            scan_ranges.extend(range.into_iter());
1170        }
1171    }
1172
1173    Ok(scan_ranges)
1174}
1175
#[cfg(test)]
mod tests {
    use tempfile::NamedTempFile;

    use incrementalmerkletree::{Marking, Position, Retention};
    use incrementalmerkletree_testing::{
        check_append, check_checkpoint_rewind, check_remove_mark, check_rewind_remove_mark,
        check_root_hashes, check_witness_consistency, check_witnesses,
    };
    use shardtree::ShardTree;
    use zcash_client_backend::data_api::{
        chain::CommitmentTreeRoot,
        testing::{pool::ShieldedPoolTester, sapling::SaplingPoolTester},
    };
    use zcash_protocol::consensus::{BlockHeight, Network};

    use super::SqliteShardStore;
    use crate::{
        testing::{
            db::{test_clock, test_rng},
            pool::ShieldedPoolPersistence,
        },
        wallet::init::WalletMigrator,
        WalletDb,
    };

    /// Creates a freshly migrated wallet database in a temporary file. Returns a `ShardTree`
    /// over the shard store for `T`'s table prefix, constructed with `m` as its maximum number
    /// of checkpoints.
    ///
    /// The returned tree owns the database connection, so the database file must outlive this
    /// function. `keep()` stops the temp file from being deleted when `data_file` goes out of
    /// scope here.
    /// NOTE(review): as a consequence, the file remains on disk after the test finishes.
    fn new_tree<T: ShieldedPoolTester + ShieldedPoolPersistence>(
        m: usize,
    ) -> ShardTree<SqliteShardStore<rusqlite::Connection, String, 3>, 4, 3> {
        let data_file = NamedTempFile::new().unwrap();
        let mut db_data = WalletDb::for_path(
            data_file.path(),
            Network::TestNetwork,
            test_clock(),
            test_rng(),
        )
        .unwrap();
        data_file.keep().unwrap();

        WalletMigrator::new().init_or_migrate(&mut db_data).unwrap();
        let store =
            SqliteShardStore::<_, String, 3>::from_connection(db_data.conn, T::TABLES_PREFIX)
                .unwrap();
        ShardTree::new(store, m)
    }

    #[cfg(feature = "orchard")]
    mod orchard {
        use super::new_tree;
        use zcash_client_backend::data_api::testing::orchard::OrchardPoolTester;

        #[test]
        fn append() {
            super::check_append(new_tree::<OrchardPoolTester>);
        }

        #[test]
        fn root_hashes() {
            super::check_root_hashes(new_tree::<OrchardPoolTester>);
        }

        #[test]
        fn witnesses() {
            super::check_witnesses(new_tree::<OrchardPoolTester>);
        }

        #[test]
        fn witness_consistency() {
            super::check_witness_consistency(new_tree::<OrchardPoolTester>);
        }

        #[test]
        fn checkpoint_rewind() {
            super::check_checkpoint_rewind(new_tree::<OrchardPoolTester>);
        }

        #[test]
        fn remove_mark() {
            super::check_remove_mark(new_tree::<OrchardPoolTester>);
        }

        #[test]
        fn rewind_remove_mark() {
            super::check_rewind_remove_mark(new_tree::<OrchardPoolTester>);
        }

        #[test]
        fn put_shard_roots() {
            super::put_shard_roots::<OrchardPoolTester>()
        }
    }

    #[test]
    fn sapling_append() {
        check_append(new_tree::<SaplingPoolTester>);
    }

    #[test]
    fn sapling_root_hashes() {
        check_root_hashes(new_tree::<SaplingPoolTester>);
    }

    #[test]
    fn sapling_witnesses() {
        check_witnesses(new_tree::<SaplingPoolTester>);
    }

    #[test]
    fn sapling_witness_consistency() {
        check_witness_consistency(new_tree::<SaplingPoolTester>);
    }

    #[test]
    fn sapling_checkpoint_rewind() {
        check_checkpoint_rewind(new_tree::<SaplingPoolTester>);
    }

    #[test]
    fn sapling_remove_mark() {
        check_remove_mark(new_tree::<SaplingPoolTester>);
    }

    #[test]
    fn sapling_rewind_remove_mark() {
        check_rewind_remove_mark(new_tree::<SaplingPoolTester>);
    }

    #[test]
    fn sapling_put_shard_roots() {
        put_shard_roots::<SaplingPoolTester>()
    }

    /// Stores four shard roots, inserts leaves into shard 3 with a checkpoint, and checks that
    /// the witness for a marked leaf combines the inserted leaves with the stored shard roots.
    fn put_shard_roots<T: ShieldedPoolTester + ShieldedPoolPersistence>() {
        // Unlike in `new_tree`, the connection never escapes this function. `data_file` is
        // declared before `db_data`, so it is dropped (deleting the file) only after the
        // connection has been closed. No `keep()` is needed, and no file is leaked.
        let data_file = NamedTempFile::new().unwrap();
        let mut db_data = WalletDb::for_path(
            data_file.path(),
            Network::TestNetwork,
            test_clock(),
            test_rng(),
        )
        .unwrap();

        WalletMigrator::new().init_or_migrate(&mut db_data).unwrap();
        let tx = db_data.conn.transaction().unwrap();
        let store =
            SqliteShardStore::<_, String, 3>::from_connection(&tx, T::TABLES_PREFIX).unwrap();

        // introduce some roots
        let roots = (0u32..4)
            .map(|idx| {
                CommitmentTreeRoot::from_parts(
                    BlockHeight::from((idx + 1) * 3),
                    if idx == 3 {
                        "abcdefgh".to_string()
                    } else {
                        idx.to_string()
                    },
                )
            })
            .collect::<Vec<_>>();
        super::put_shard_roots::<_, 6, 3>(store.conn, T::TABLES_PREFIX, 0, &roots).unwrap();

        // simulate discovery of a note
        let mut tree = ShardTree::<_, 6, 3>::new(store, 10);
        let checkpoint_height = BlockHeight::from(3);
        tree.batch_insert(
            Position::from(24),
            ('a'..='h').map(|c| {
                (
                    c.to_string(),
                    match c {
                        'c' => Retention::Marked,
                        'h' => Retention::Checkpoint {
                            id: checkpoint_height,
                            marking: Marking::None,
                        },
                        _ => Retention::Ephemeral,
                    },
                )
            }),
        )
        .unwrap();

        // construct a witness for the note
        let witness = tree
            .witness_at_checkpoint_id(Position::from(26), &checkpoint_height)
            .unwrap();
        assert_eq!(
            witness
                .expect("an anchor exists at the expected checkpoint height")
                .path_elems(),
            &[
                "d",
                "ab",
                "efgh",
                "2",
                "01",
                "________________________________"
            ]
        );
    }
}