diff --git a/zebra-state/src/service/non_finalized_state.rs b/zebra-state/src/service/non_finalized_state.rs index 034e5484a..831c9ea24 100644 --- a/zebra-state/src/service/non_finalized_state.rs +++ b/zebra-state/src/service/non_finalized_state.rs @@ -204,7 +204,7 @@ impl NonFinalizedState { #[tracing::instrument(level = "debug", skip(self, finalized_state, new_chain))] fn validate_and_commit( &self, - mut new_chain: Arc, + new_chain: Arc, prepared: PreparedBlock, finalized_state: &FinalizedState, ) -> Result, ValidateContextError> { @@ -243,9 +243,10 @@ impl NonFinalizedState { // We're pretty sure the new block is valid, // so clone the inner chain if needed, then add the new block. - Arc::make_mut(&mut new_chain).push(contextual)?; - - Ok(new_chain) + Arc::try_unwrap(new_chain) + .unwrap_or_else(|shared_chain| (*shared_chain).clone()) + .push(contextual) + .map(Arc::new) } /// Returns the length of the non-finalized portion of the current best chain. diff --git a/zebra-state/src/service/non_finalized_state/chain.rs b/zebra-state/src/service/non_finalized_state/chain.rs index 01a6b0590..a8b6b4cf2 100644 --- a/zebra-state/src/service/non_finalized_state/chain.rs +++ b/zebra-state/src/service/non_finalized_state/chain.rs @@ -189,31 +189,19 @@ impl Chain { /// Push a contextually valid non-finalized block into this chain as the new tip. /// - /// If the block is invalid, clears the chain, and returns an error. + /// If the block is invalid, drops this chain, and returns an error. /// /// Note: a [`ContextuallyValidBlock`] isn't actually contextually valid until /// [`update_chain_state_with`] returns success. #[instrument(level = "debug", skip(self, block), fields(block = %block.block))] - pub fn push(&mut self, block: ContextuallyValidBlock) -> Result<(), ValidateContextError> { + pub fn push(mut self, block: ContextuallyValidBlock) -> Result { // update cumulative data members - if let Err(error) = self.update_chain_tip_with(&block) { - // The chain could be in an invalid half-updated state, so clear its data. - *self = Chain::new( - self.network, - sprout::tree::NoteCommitmentTree::default(), - sapling::tree::NoteCommitmentTree::default(), - orchard::tree::NoteCommitmentTree::default(), - HistoryTree::default(), - ValueBalance::zero(), - ); - - return Err(error); - } + self.update_chain_tip_with(&block)?; tracing::debug!(block = %block.block, "adding block to chain"); self.blocks.insert(block.height, block); - Ok(()) + Ok(self) } /// Remove the lowest height block of the non-finalized portion of a chain. diff --git a/zebra-state/src/service/non_finalized_state/tests/prop.rs b/zebra-state/src/service/non_finalized_state/tests/prop.rs index 9cb5da21f..48894ee9f 100644 --- a/zebra-state/src/service/non_finalized_state/tests/prop.rs +++ b/zebra-state/src/service/non_finalized_state/tests/prop.rs @@ -64,7 +64,7 @@ fn push_genesis_chain() -> Result<()> { chain_values.insert(block.height.into(), (block.chain_value_pool_change.into(), None)); - only_chain + only_chain = only_chain .push(block.clone()) .map_err(|e| (e, chain_values.clone())) .expect("invalid chain value pools"); @@ -105,7 +105,7 @@ fn push_history_tree_chain() -> Result<()> { .iter() .take(count) .map(ContextuallyValidBlock::test_with_zero_chain_pool_change) { - only_chain.push(block)?; + only_chain = only_chain.push(block)?; } prop_assert_eq!(only_chain.blocks.len(), count); @@ -154,7 +154,7 @@ fn forked_equals_pushed_genesis() -> Result<()> { block, partial_chain.unspent_utxos(), )?; - partial_chain + partial_chain = partial_chain .push(block) .expect("partial chain push is valid"); } @@ -171,7 +171,7 @@ fn forked_equals_pushed_genesis() -> Result<()> { for block in chain.iter().cloned() { let block = ContextuallyValidBlock::with_block_and_spent_utxos(block, full_chain.unspent_utxos())?; - full_chain + full_chain = full_chain .push(block.clone()) .expect("full chain push is valid"); @@ -221,7 +221,7 @@ fn forked_equals_pushed_genesis() -> Result<()> { for block in chain.iter().skip(fork_at_count).cloned() { let block = ContextuallyValidBlock::with_block_and_spent_utxos(block, forked.unspent_utxos())?; - forked.push(block).expect("forked chain push is valid"); + forked = forked.push(block).expect("forked chain push is valid"); } prop_assert_eq!(forked.blocks.len(), full_chain.blocks.len()); @@ -261,13 +261,13 @@ fn forked_equals_pushed_history_tree() -> Result<()> { .iter() .take(fork_at_count) .map(ContextuallyValidBlock::test_with_zero_chain_pool_change) { - partial_chain.push(block)?; + partial_chain = partial_chain.push(block)?; } for block in chain .iter() .map(ContextuallyValidBlock::test_with_zero_chain_pool_change) { - full_chain.push(block.clone())?; + full_chain = full_chain.push(block.clone())?; } let mut forked = full_chain @@ -291,7 +291,7 @@ fn forked_equals_pushed_history_tree() -> Result<()> { .iter() .skip(fork_at_count) .map(ContextuallyValidBlock::test_with_zero_chain_pool_change) { - forked.push(block)?; + forked = forked.push(block)?; } prop_assert_eq!(forked.blocks.len(), full_chain.blocks.len()); @@ -328,7 +328,7 @@ fn finalized_equals_pushed_genesis() -> Result<()> { .iter() .take(finalized_count) .map(ContextuallyValidBlock::test_with_zero_spent_utxos) { - full_chain.push(block)?; + full_chain = full_chain.push(block)?; } let mut partial_chain = Chain::new( @@ -343,14 +343,14 @@ fn finalized_equals_pushed_genesis() -> Result<()> { .iter() .skip(finalized_count) .map(ContextuallyValidBlock::test_with_zero_spent_utxos) { - partial_chain.push(block.clone())?; + partial_chain = partial_chain.push(block.clone())?; } for block in chain .iter() .skip(finalized_count) .map(ContextuallyValidBlock::test_with_zero_spent_utxos) { - full_chain.push(block.clone())?; + full_chain = full_chain.push(block.clone())?; } for _ in 0..finalized_count { @@ -398,7 +398,7 @@ fn finalized_equals_pushed_history_tree() -> Result<()> { .iter() .take(finalized_count) .map(ContextuallyValidBlock::test_with_zero_spent_utxos) { - full_chain.push(block)?; + full_chain = full_chain.push(block)?; } let mut partial_chain = Chain::new( @@ -414,14 +414,14 @@ fn finalized_equals_pushed_history_tree() -> Result<()> { .iter() .skip(finalized_count) .map(ContextuallyValidBlock::test_with_zero_spent_utxos) { - partial_chain.push(block.clone())?; + partial_chain = partial_chain.push(block.clone())?; } for block in chain .iter() .skip(finalized_count) .map(ContextuallyValidBlock::test_with_zero_spent_utxos) { - full_chain.push(block.clone())?; + full_chain= full_chain.push(block.clone())?; } for _ in 0..finalized_count { @@ -563,8 +563,8 @@ fn different_blocks_different_chains() -> Result<()> { } else { Default::default() }; - let mut chain1 = Chain::new(Network::Mainnet, Default::default(), Default::default(), Default::default(), finalized_tree1, ValueBalance::fake_populated_pool()); - let mut chain2 = Chain::new(Network::Mainnet, Default::default(), Default::default(), Default::default(), finalized_tree2, ValueBalance::fake_populated_pool()); + let chain1 = Chain::new(Network::Mainnet, Default::default(), Default::default(), Default::default(), finalized_tree1, ValueBalance::fake_populated_pool()); + let chain2 = Chain::new(Network::Mainnet, Default::default(), Default::default(), Default::default(), finalized_tree2, ValueBalance::fake_populated_pool()); let block1 = vec1[1].clone().prepare().test_with_zero_spent_utxos(); let block2 = vec2[1].clone().prepare().test_with_zero_spent_utxos(); @@ -572,8 +572,8 @@ fn different_blocks_different_chains() -> Result<()> { let result1 = chain1.push(block1.clone()); let result2 = chain2.push(block2.clone()); - // if there is an error, the chains come back empty - if result1.is_ok() && result2.is_ok() { + // if there is an error, we don't get the chains back + if let (Ok(mut chain1), Ok(chain2)) = (result1, result2) { if block1 == block2 { // the blocks were equal, so the chains should be equal diff --git a/zebra-state/src/service/non_finalized_state/tests/vectors.rs b/zebra-state/src/service/non_finalized_state/tests/vectors.rs index 7a3d326b3..6bff5d582 100644 --- a/zebra-state/src/service/non_finalized_state/tests/vectors.rs +++ b/zebra-state/src/service/non_finalized_state/tests/vectors.rs @@ -50,7 +50,7 @@ fn construct_single() -> Result<()> { ValueBalance::fake_populated_pool(), ); - chain.push(block.prepare().test_with_zero_spent_utxos())?; + chain = chain.push(block.prepare().test_with_zero_spent_utxos())?; assert_eq!(1, chain.blocks.len()); @@ -81,7 +81,7 @@ fn construct_many() -> Result<()> { ); for block in blocks { - chain.push(block.prepare().test_with_zero_spent_utxos())?; + chain = chain.push(block.prepare().test_with_zero_spent_utxos())?; } assert_eq!(100, chain.blocks.len()); @@ -105,7 +105,7 @@ fn ord_matches_work() -> Result<()> { Default::default(), ValueBalance::fake_populated_pool(), ); - lesser_chain.push(less_block.prepare().test_with_zero_spent_utxos())?; + lesser_chain = lesser_chain.push(less_block.prepare().test_with_zero_spent_utxos())?; let mut bigger_chain = Chain::new( Network::Mainnet, @@ -115,7 +115,7 @@ fn ord_matches_work() -> Result<()> { Default::default(), ValueBalance::zero(), ); - bigger_chain.push(more_block.prepare().test_with_zero_spent_utxos())?; + bigger_chain = bigger_chain.push(more_block.prepare().test_with_zero_spent_utxos())?; assert!(bigger_chain > lesser_chain);