Code simplifications

Co-authored-by: Daira Hopwood <daira@jacaranda.org>
This commit is contained in:
Jack Grigg 2023-07-12 17:12:14 +01:00
parent bb920341a6
commit f7163e9dd9
3 changed files with 126 additions and 202 deletions

View File

@ -48,7 +48,7 @@ impl ScanRange {
/// Returns whether or not the scan range is empty.
pub fn is_empty(&self) -> bool {
self.block_range.end == self.block_range.start
self.block_range.is_empty()
}
/// Returns the number of blocks in the scan range.
@ -89,20 +89,16 @@ impl ScanRange {
/// end of the first range returned and the start of the second. Returns `None` if
/// `p <= self.block_range().start || p >= self.block_range().end`.
pub fn split_at(&self, p: BlockHeight) -> Option<(Self, Self)> {
if p > self.block_range.start && p < self.block_range.end {
Some((
ScanRange {
block_range: self.block_range.start..p,
priority: self.priority,
},
ScanRange {
block_range: p..self.block_range.end,
priority: self.priority,
},
))
} else {
None
}
(p > self.block_range.start && p < self.block_range.end).then_some((
ScanRange {
block_range: self.block_range.start..p,
priority: self.priority,
},
ScanRange {
block_range: p..self.block_range.end,
priority: self.priority,
},
))
}
}

View File

@ -414,10 +414,13 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
});
let mut wallet_note_ids = vec![];
let mut sapling_commitments = vec![];
let mut end_height = None;
let mut last_scanned_height = None;
let mut note_positions = vec![];
for block in blocks.into_iter() {
if end_height.iter().any(|prev| block.height() != *prev + 1) {
if last_scanned_height
.iter()
.any(|prev| block.height() != *prev + 1)
{
return Err(SqliteClientError::NonSequentialBlocks);
}
@ -453,14 +456,14 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
.map(|out| out.note_commitment_tree_position())
}));
end_height = Some(block.height());
last_scanned_height = Some(block.height());
sapling_commitments.extend(block.into_sapling_commitments().into_iter());
}
// We will have a start position and an end height in all cases where `blocks` is
// non-empty.
if let Some(((start_height, start_position), end_height)) =
start_positions.zip(end_height)
// We will have a start position and a last scanned height in all cases where
// `blocks` is non-empty.
if let Some(((start_height, start_position), last_scanned_height)) =
start_positions.zip(last_scanned_height)
{
// Update the Sapling note commitment tree with all newly read note commitments
let mut sapling_commitments = sapling_commitments.into_iter();
@ -470,14 +473,14 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
})?;
// Update now-expired transactions that didn't get mined.
wallet::update_expired_notes(wdb.conn.0, end_height)?;
wallet::update_expired_notes(wdb.conn.0, last_scanned_height)?;
wallet::scanning::scan_complete(
wdb.conn.0,
&wdb.params,
Range {
start: start_height,
end: end_height + 1,
end: last_scanned_height + 1,
},
&note_positions,
)?;

View File

@ -107,24 +107,13 @@ pub(crate) fn suggest_scan_ranges(
// This implements the dominance rule for range priority. If the inserted range's priority is
// `Verify`, this replaces any existing priority. Otherwise, if the current priority is
// `Scanned`, this overwrites any priority
fn update_priority(current: ScanPriority, inserted: ScanPriority) -> ScanPriority {
match (current, inserted) {
(_, ScanPriority::Verify) => ScanPriority::Verify,
(ScanPriority::Scanned, _) => ScanPriority::Scanned,
(_, ScanPriority::Scanned) => ScanPriority::Scanned,
(a, b) => max(a, b),
}
}
fn dominance(current: &ScanPriority, inserted: &ScanPriority, insert: Insert) -> Dominance {
match (current, inserted) {
(_, ScanPriority::Verify | ScanPriority::Scanned) => Dominance::from(insert),
(ScanPriority::Scanned, _) => Dominance::from(!insert),
(a, b) => match a.cmp(b) {
Ordering::Less => Dominance::from(insert),
Ordering::Equal => Dominance::Equal,
Ordering::Greater => Dominance::from(!insert),
},
match (current.cmp(inserted), (current, inserted)) {
(Ordering::Equal, _) => Dominance::Equal,
(_, (_, ScanPriority::Verify | ScanPriority::Scanned)) => Dominance::from(insert),
(_, (ScanPriority::Scanned, _)) => Dominance::from(!insert),
(Ordering::Less, _) => Dominance::from(insert),
(Ordering::Greater, _) => Dominance::from(!insert),
}
}
@ -141,31 +130,16 @@ enum RangeOrdering {
impl RangeOrdering {
fn cmp<A: Ord>(a: &Range<A>, b: &Range<A>) -> Self {
use RangeOrdering::*;
use Ordering::*;
assert!(a.start <= a.end && b.start <= b.end);
if a.end <= b.start {
LeftFirstDisjoint
} else if b.end <= a.start {
RightFirstDisjoint
} else if a.start < b.start {
if a.end >= b.end {
RightContained
} else {
LeftFirstOverlap
}
} else if b.start < a.start {
if b.end >= a.end {
LeftContained
} else {
RightFirstOverlap
}
} else {
// a.start == b.start
match a.end.cmp(&b.end) {
Ordering::Less => LeftContained,
Ordering::Equal => Equal,
Ordering::Greater => RightContained,
}
match (a.start.cmp(&b.start), a.end.cmp(&b.end)) {
_ if a.end <= b.start => RangeOrdering::LeftFirstDisjoint,
_ if b.end <= a.start => RangeOrdering::RightFirstDisjoint,
(Less, Less) => RangeOrdering::LeftFirstOverlap,
(Equal, Less) | (Greater, Less) | (Greater, Equal) => RangeOrdering::LeftContained,
(Equal, Equal) => RangeOrdering::Equal,
(Equal, Greater) | (Less, Greater) | (Less, Equal) => RangeOrdering::RightContained,
(Greater, Greater) => RangeOrdering::RightFirstOverlap,
}
}
}
@ -196,9 +170,9 @@ fn join_nonoverlapping(left: ScanRange, right: ScanRange) -> Joined {
);
match join_nonoverlapping(left, gap) {
Joined::One(left) => join_nonoverlapping(left, right),
Joined::One(merged) => join_nonoverlapping(merged, right),
Joined::Two(left, gap) => match join_nonoverlapping(gap, right) {
Joined::One(right) => Joined::Two(left, right),
Joined::One(merged) => Joined::Two(left, merged),
Joined::Two(gap, right) => Joined::Three(left, gap, right),
_ => unreachable!(),
},
@ -232,19 +206,15 @@ fn insert(current: ScanRange, to_insert: ScanRange) -> Joined {
left.block_range().start..max(left.block_range().end, right.block_range().end),
left.priority(),
)),
Dominance::Right => {
if let Some(left) = left.truncate_end(right.block_range().start) {
if let Some(end) = left.truncate_start(right.block_range().end) {
Joined::Three(left, right, end)
} else {
Joined::Two(left, right)
}
} else if let Some(end) = left.truncate_start(right.block_range().end) {
Joined::Two(right, end)
} else {
Joined::One(right)
}
}
Dominance::Right => match (
left.truncate_end(right.block_range().start),
left.truncate_start(right.block_range().end),
) {
(Some(before), Some(after)) => Joined::Three(before, right, after),
(Some(before), None) => Joined::Two(before, right),
(None, Some(after)) => Joined::Two(right, after),
(None, None) => Joined::One(right),
},
}
}
@ -254,7 +224,10 @@ fn insert(current: ScanRange, to_insert: ScanRange) -> Joined {
LeftFirstOverlap | RightContained => join_overlapping(to_insert, current, Insert::Left),
Equal => Joined::One(ScanRange::from_parts(
to_insert.block_range().clone(),
update_priority(current.priority(), to_insert.priority()),
match dominance(&current.priority(), &to_insert.priority(), Insert::Right) {
Dominance::Left | Dominance::Equal => current.priority(),
Dominance::Right => to_insert.priority(),
},
)),
RightFirstOverlap | LeftContained => join_overlapping(current, to_insert, Insert::Right),
RightFirstDisjoint => join_nonoverlapping(current, to_insert),
@ -299,6 +272,36 @@ impl SpanningTree {
}
}
fn from_insert(
left: Box<Self>,
right: Box<Self>,
to_insert: ScanRange,
insert: Insert,
) -> Self {
let (left, right) = match insert {
Insert::Left => (Box::new(left.insert(to_insert)), right),
Insert::Right => (left, Box::new(right.insert(to_insert))),
};
SpanningTree::Parent {
span: left.span().start..right.span().end,
left,
right,
}
}
fn from_split(left: Self, right: Self, to_insert: ScanRange, split_point: BlockHeight) -> Self {
let (l_insert, r_insert) = to_insert
.split_at(split_point)
.expect("Split point is within the range of to_insert");
let left = Box::new(left.insert(l_insert));
let right = Box::new(right.insert(r_insert));
SpanningTree::Parent {
span: left.span().start..right.span().end,
left,
right,
}
}
fn insert(self, to_insert: ScanRange) -> Self {
match self {
SpanningTree::Leaf(cur) => Self::from_joined(insert(cur, to_insert)),
@ -311,33 +314,15 @@ impl SpanningTree {
match RangeOrdering::cmp(&span, to_insert.block_range()) {
LeftFirstDisjoint => {
// extend the right-hand branch
SpanningTree::Parent {
span: left.span().start..to_insert.block_range().end,
left,
right: Box::new(right.insert(to_insert)),
}
Self::from_insert(left, right, to_insert, Insert::Right)
}
LeftFirstOverlap => {
let split_point = left.span().end;
if split_point > to_insert.block_range().start {
let (l_insert, r_insert) = to_insert
.split_at(split_point)
.expect("Split point is within the range of to_insert");
let left = Box::new(left.insert(l_insert));
let right = Box::new(right.insert(r_insert));
SpanningTree::Parent {
span: left.span().start..right.span().end,
left,
right,
}
Self::from_split(*left, *right, to_insert, split_point)
} else {
// to_insert is fully contained in or equals the right child
SpanningTree::Parent {
span: left.span().start
..max(right.span().end, to_insert.block_range().end),
left,
right: Box::new(right.insert(to_insert)),
}
Self::from_insert(left, right, to_insert, Insert::Right)
}
}
RightContained => {
@ -346,44 +331,19 @@ impl SpanningTree {
let split_point = left.span().end;
if to_insert.block_range().start >= split_point {
// to_insert is fully contained in the right
SpanningTree::Parent {
span,
left,
right: Box::new(right.insert(to_insert)),
}
Self::from_insert(left, right, to_insert, Insert::Right)
} else if to_insert.block_range().end <= split_point {
// to_insert is fully contained in the left
SpanningTree::Parent {
span,
left: Box::new(left.insert(to_insert)),
right,
}
Self::from_insert(left, right, to_insert, Insert::Left)
} else {
// to_insert must be split.
let (l_insert, r_insert) = to_insert
.split_at(split_point)
.expect("Split point is within the range of to_insert");
let left = Box::new(left.insert(l_insert));
let right = Box::new(right.insert(r_insert));
SpanningTree::Parent {
span: left.span().start..right.span().end,
left,
right,
}
Self::from_split(*left, *right, to_insert, split_point)
}
}
Equal => {
if left.span().end > to_insert.block_range().start {
let (l_insert, r_insert) = to_insert
.split_at(left.span().end)
.expect("Split point is within the range of to_insert");
let left = Box::new(left.insert(l_insert));
let right = Box::new(right.insert(r_insert));
SpanningTree::Parent {
span: left.span().start..right.span().end,
left,
right,
}
let split_point = left.span().end;
if split_point > to_insert.block_range().start {
Self::from_split(*left, *right, to_insert, split_point)
} else {
// to_insert is fully contained in the right subtree
right.insert(to_insert)
@ -392,47 +352,21 @@ impl SpanningTree {
LeftContained => {
// the current span is fully contained within to_insert, so we will extend
// or overwrite both sides
let (l_insert, r_insert) = to_insert
.split_at(left.span().end)
.expect("Split point is within the range of to_insert");
let left = Box::new(left.insert(l_insert));
let right = Box::new(right.insert(r_insert));
SpanningTree::Parent {
span: left.span().start..right.span().end,
left,
right,
}
let split_point = left.span().end;
Self::from_split(*left, *right, to_insert, split_point)
}
RightFirstOverlap => {
let split_point = left.span().end;
if split_point < to_insert.block_range().end {
let (l_insert, r_insert) = to_insert
.split_at(split_point)
.expect("Split point is within the range of to_insert");
let left = Box::new(left.insert(l_insert));
let right = Box::new(right.insert(r_insert));
SpanningTree::Parent {
span: left.span().start..right.span().end,
left,
right,
}
Self::from_split(*left, *right, to_insert, split_point)
} else {
// to_insert is fully contained in or equals the left child
SpanningTree::Parent {
span: min(to_insert.block_range().start, left.span().start)
..right.span().end,
left: Box::new(left.insert(to_insert)),
right,
}
Self::from_insert(left, right, to_insert, Insert::Left)
}
}
RightFirstDisjoint => {
// extend the left-hand branch
SpanningTree::Parent {
span: to_insert.block_range().start..right.span().end,
left: Box::new(left.insert(to_insert)),
right,
}
Self::from_insert(left, right, to_insert, Insert::Left)
}
}
}
@ -445,7 +379,7 @@ impl SpanningTree {
SpanningTree::Leaf(entry) => {
if let Some(top) = acc.pop() {
match join_nonoverlapping(top, entry) {
Joined::One(entry) => acc.push(entry),
Joined::One(merged) => acc.push(merged),
Joined::Two(l, r) => {
acc.push(l);
acc.push(r);
@ -479,7 +413,7 @@ pub(crate) fn insert_queue_entries<'a>(
)?;
for entry in entries {
if entry.block_range().end > entry.block_range().start {
if !entry.is_empty() {
stmt.execute(named_params![
":block_range_start": u32::from(entry.block_range().start) ,
":block_range_end": u32::from(entry.block_range().end),
@ -494,7 +428,7 @@ pub(crate) fn insert_queue_entries<'a>(
pub(crate) fn replace_queue_entries(
conn: &rusqlite::Connection,
query_range: &Range<BlockHeight>,
mut entries: impl Iterator<Item = ScanRange>,
entries: impl Iterator<Item = ScanRange>,
) -> Result<(), SqliteClientError> {
let (to_create, to_delete_ends) = {
let mut suggested_stmt = conn.prepare_cached(
@ -527,7 +461,7 @@ pub(crate) fn replace_queue_entries(
// identified as needing to be fully scanned. For each such range add it to the
// spanning tree (these should all be nonoverlapping ranges, but we might coalesce
// some in the process).
let mut existing_ranges: Option<SpanningTree> = None;
let mut to_create: Option<SpanningTree> = None;
let mut to_delete_ends: Vec<Value> = vec![];
while let Some(row) = rows.next()? {
let entry = ScanRange::from_parts(
@ -546,7 +480,7 @@ pub(crate) fn replace_queue_entries(
},
);
to_delete_ends.push(Value::from(u32::from(entry.block_range().end)));
existing_ranges = if let Some(cur) = existing_ranges {
to_create = if let Some(cur) = to_create {
Some(cur.insert(entry))
} else {
Some(SpanningTree::Leaf(entry))
@ -555,15 +489,12 @@ pub(crate) fn replace_queue_entries(
// Update the tree that we read from the database, or if we didn't find any ranges
// start with the scanned range.
let mut to_create = match (existing_ranges, entries.next()) {
(Some(cur), Some(entry)) => Some(cur.insert(entry)),
(None, Some(entry)) => Some(SpanningTree::Leaf(entry)),
(Some(cur), None) => Some(cur),
(None, None) => None,
};
for entry in entries {
to_create = to_create.map(|cur| cur.insert(entry));
to_create = if let Some(cur) = to_create {
Some(cur.insert(entry))
} else {
Some(SpanningTree::Leaf(entry))
};
}
(to_create, to_delete_ends)
@ -611,6 +542,16 @@ pub(crate) fn scan_complete<P: consensus::Parameters>(
WHERE shard_index = :shard_index",
)?;
let mut sapling_shard_end = |index: u64| -> Result<Option<BlockHeight>, rusqlite::Error> {
Ok(sapling_shard_end_stmt
.query_row(named_params![":shard_index": index], |row| {
row.get::<_, Option<u32>>(0)
.map(|opt| opt.map(BlockHeight::from))
})
.optional()?
.flatten())
};
// if no notes belonging to the wallet were found, so don't need to extend the scanning
// range suggestions to include the associated subtrees, and our bounds are just the
// scanned range
@ -618,38 +559,18 @@ pub(crate) fn scan_complete<P: consensus::Parameters>(
.map(|(min_idx, max_idx)| {
let range_min = if *min_idx > 0 {
// get the block height of the end of the previous shard
sapling_shard_end_stmt
.query_row(named_params![":shard_index": *min_idx - 1], |row| {
row.get::<_, Option<u32>>(0)
.map(|opt| opt.map(BlockHeight::from))
})
.optional()?
.flatten()
sapling_shard_end(*min_idx - 1)?
} else {
// our lower bound is going to be the Sapling activation height
params.activation_height(NetworkUpgrade::Sapling)
};
// get the block height for the end of the current shard
let range_max = sapling_shard_end_stmt
.query_row(named_params![":shard_index": max_idx], |row| {
row.get::<_, Option<u32>>(0)
.map(|opt| opt.map(BlockHeight::from))
})
.optional()?
.flatten();
let range_max = sapling_shard_end(*max_idx)?;
Ok::<Range<BlockHeight>, rusqlite::Error>(match (range_min, range_max) {
(Some(start), Some(end)) => Range { start, end },
(Some(start), None) => Range {
start,
end: range.end,
},
(None, Some(end)) => Range {
start: range.start,
end,
},
(None, None) => range.clone(),
Ok::<Range<BlockHeight>, rusqlite::Error>(Range {
start: range_min.unwrap_or(range.start),
end: range_max.unwrap_or(range.end),
})
})
.transpose()
@ -804,6 +725,10 @@ mod tests {
assert_eq!(RangeOrdering::cmp(&(1..2), &(0..1)), RightFirstDisjoint);
assert_eq!(RangeOrdering::cmp(&(0..1), &(2..3)), LeftFirstDisjoint);
assert_eq!(RangeOrdering::cmp(&(2..3), &(0..1)), RightFirstDisjoint);
assert_eq!(RangeOrdering::cmp(&(1..2), &(2..2)), LeftFirstDisjoint);
assert_eq!(RangeOrdering::cmp(&(2..2), &(1..2)), RightFirstDisjoint);
assert_eq!(RangeOrdering::cmp(&(1..1), &(1..2)), LeftFirstDisjoint);
assert_eq!(RangeOrdering::cmp(&(1..2), &(1..1)), RightFirstDisjoint);
// Contained
assert_eq!(RangeOrdering::cmp(&(1..2), &(0..3)), LeftContained);