reorg in wallets
This commit is contained in:
parent
d8211f5257
commit
b93a5a33e9
|
@ -1,3 +1,5 @@
|
|||
from secrets import token_bytes
|
||||
|
||||
from src.full_node.full_node import FullNode
|
||||
from typing import AsyncGenerator, List, Dict
|
||||
from src.full_node.blockchain import Blockchain
|
||||
|
@ -16,6 +18,7 @@ from src.types.header import Header
|
|||
from src.types.sized_bytes import bytes32
|
||||
from src.full_node.coin_store import CoinStore
|
||||
from src.util.api_decorators import api_request
|
||||
from src.util.ints import uint32
|
||||
from tests.block_tools import BlockTools
|
||||
|
||||
OutboundMessageGenerator = AsyncGenerator[OutboundMessage, None]
|
||||
|
@ -105,6 +108,7 @@ class FullNodeSimulator(FullNode):
|
|||
async for msg in super().request_additions(request):
|
||||
yield msg
|
||||
|
||||
# WALLET LOCAL TEST PROTOCOL
|
||||
def get_tip(self):
|
||||
tips = self.blockchain.tips
|
||||
top = tips[0]
|
||||
|
@ -115,7 +119,6 @@ class FullNodeSimulator(FullNode):
|
|||
|
||||
return top
|
||||
|
||||
# WALLET LOCAL TEST PROTOCOL
|
||||
async def get_current_blocks(self, tip: Header) -> List[FullBlock]:
|
||||
|
||||
current_blocks: List[FullBlock] = []
|
||||
|
@ -156,3 +159,27 @@ class FullNodeSimulator(FullNode):
|
|||
|
||||
async for msg in self.respond_block(full_node_protocol.RespondBlock(new_lca)):
|
||||
self.server.push_message(msg)
|
||||
|
||||
@api_request
|
||||
async def reorg_from_index_to_new_index(
|
||||
self, old_index: uint32, new_index: uint32, coinbase_ph: bytes32
|
||||
):
|
||||
top_tip = self.get_tip()
|
||||
|
||||
current_blocks = await self.get_current_blocks(top_tip)
|
||||
block_count = new_index - old_index
|
||||
|
||||
more_blocks = bt.get_consecutive_blocks(
|
||||
self.constants,
|
||||
block_count,
|
||||
current_blocks[:old_index],
|
||||
10,
|
||||
seed=token_bytes(),
|
||||
reward_puzzlehash=coinbase_ph,
|
||||
transaction_data_at_height={},
|
||||
)
|
||||
|
||||
for block in more_blocks:
|
||||
async for msg in self.respond_block(full_node_protocol.RespondBlock(block)):
|
||||
self.server.push_message(msg)
|
||||
self.log.info(f"New message: {msg}")
|
||||
|
|
|
@ -46,6 +46,7 @@ class WalletStateManager:
|
|||
|
||||
# TODO Don't allow user to send tx until wallet is synced
|
||||
synced: bool
|
||||
genesis: FullBlock
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
|
@ -76,6 +77,7 @@ class WalletStateManager:
|
|||
self.height_to_hash = {}
|
||||
self.block_records = await self.wallet_store.get_lca_path()
|
||||
genesis = FullBlock.from_bytes(self.constants["GENESIS_BLOCK"])
|
||||
self.genesis = genesis
|
||||
|
||||
if len(self.block_records) > 0:
|
||||
# Header hash with the highest weight
|
||||
|
@ -130,7 +132,9 @@ class WalletStateManager:
|
|||
valid_index = current_index - coinbase_freeze_period
|
||||
record_list: Set[
|
||||
CoinRecord
|
||||
] = await self.wallet_store.get_coin_records_by_spent_and_index(False, valid_index)
|
||||
] = await self.wallet_store.get_coin_records_by_spent_and_index(
|
||||
False, valid_index
|
||||
)
|
||||
|
||||
amount: uint64 = uint64(0)
|
||||
|
||||
|
@ -228,7 +232,6 @@ class WalletStateManager:
|
|||
sum += coinrecord.coin.amount
|
||||
used_coins.add(coinrecord.coin)
|
||||
|
||||
|
||||
# This happens when we couldn't use one of the coins because it's already used
|
||||
# but unconfirmed, and we are waiting for the change. (unconfirmed_additions)
|
||||
if sum < amount:
|
||||
|
@ -406,7 +409,7 @@ class WalletStateManager:
|
|||
blocks_to_add: List[BlockRecord] = []
|
||||
tip_hash: bytes32 = block.header_hash
|
||||
while True:
|
||||
if tip_hash == fork_hash:
|
||||
if tip_hash == fork_hash or tip_hash == self.genesis.header_hash:
|
||||
break
|
||||
record = self.block_records[tip_hash]
|
||||
blocks_to_add.append(record)
|
||||
|
@ -519,6 +522,7 @@ class WalletStateManager:
|
|||
Rolls back and updates the coin_store and transaction store. It's possible this height
|
||||
is the tip, or even beyond the tip.
|
||||
"""
|
||||
self.log.warning(f"Rolling back to {index}")
|
||||
await self.wallet_store.rollback_lca_to_block(index)
|
||||
|
||||
reorged: List[TransactionRecord] = await self.tx_store.get_transaction_above(
|
||||
|
|
|
@ -143,12 +143,15 @@ class WalletStore:
|
|||
coins.add(CoinRecord(coin, row[1], row[2], row[3], row[4]))
|
||||
return coins
|
||||
|
||||
async def get_coin_records_by_spent_and_index(self, spent: bool, index: uint32) -> Set[CoinRecord]:
|
||||
async def get_coin_records_by_spent_and_index(
|
||||
self, spent: bool, index: uint32
|
||||
) -> Set[CoinRecord]:
|
||||
""" Returns set of CoinRecords that have been confirmed before index height. """
|
||||
coins = set()
|
||||
|
||||
cursor = await self.db_connection.execute(
|
||||
"SELECT * from coin_record WHERE spent=? and confirmed_index<?", (int(spent), int(index),)
|
||||
"SELECT * from coin_record WHERE spent=? and confirmed_index<?",
|
||||
(int(spent), int(index),),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
await cursor.close()
|
||||
|
|
|
@ -80,7 +80,9 @@ async def setup_full_node_simulator(db_name, port, introducer_port=None, dic={})
|
|||
f"full_node_{port}",
|
||||
test_constants_copy,
|
||||
)
|
||||
server_1 = ChiaServer(port, full_node_1, NodeType.FULL_NODE, name="full-node-simulator-server")
|
||||
server_1 = ChiaServer(
|
||||
port, full_node_1, NodeType.FULL_NODE, name="full-node-simulator-server"
|
||||
)
|
||||
_ = await server_1.start_server(config["host"], full_node_1._on_connect)
|
||||
full_node_1._set_server(server_1)
|
||||
|
||||
|
@ -155,7 +157,11 @@ async def setup_wallet_node(port, introducer_port=None, key_seed=b"", dic={}):
|
|||
Path(db_path).unlink()
|
||||
|
||||
wallet = await WalletNode.create(
|
||||
config, key_config, db_path=db_path, override_constants=test_constants_copy, name="wallet1"
|
||||
config,
|
||||
key_config,
|
||||
db_path=db_path,
|
||||
override_constants=test_constants_copy,
|
||||
name="wallet1",
|
||||
)
|
||||
server = ChiaServer(port, wallet, NodeType.WALLET, name="wallet-server")
|
||||
wallet.set_server(server)
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
import asyncio
|
||||
from secrets import token_bytes
|
||||
|
||||
import pytest
|
||||
|
||||
from src.types.header import Header
|
||||
from src.types.peer_info import PeerInfo
|
||||
from src.util.ints import uint16, uint32
|
||||
from tests.setup_nodes import (
|
||||
|
@ -113,3 +116,45 @@ class TestWalletSimulator:
|
|||
|
||||
assert confirmed_balance == new_funds - 10
|
||||
assert unconfirmed_balance == new_funds - 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wallet_coinbase_reorg(self, wallet_node):
|
||||
num_blocks = 10
|
||||
full_node_1, wallet_node, server_1, server_2 = wallet_node
|
||||
wallet = wallet_node.wallet
|
||||
ph = await wallet.get_new_puzzlehash()
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
for i in range(1, num_blocks):
|
||||
await full_node_1.farm_new_block(ph)
|
||||
|
||||
await asyncio.sleep(3)
|
||||
funds = sum(
|
||||
[
|
||||
calculate_base_fee(uint32(i)) + calculate_block_reward(uint32(i))
|
||||
for i in range(1, num_blocks - 2)
|
||||
]
|
||||
)
|
||||
assert await wallet.get_confirmed_balance() == funds
|
||||
|
||||
wallet_lca = wallet_node.wallet_state_manager.lca
|
||||
lca_header: Header = wallet_node.wallet_state_manager.block_records[wallet_lca]
|
||||
|
||||
await full_node_1.reorg_from_index_to_new_index(
|
||||
5, num_blocks + 3, token_bytes()
|
||||
)
|
||||
await asyncio.sleep(3)
|
||||
|
||||
for i in range(1, 5):
|
||||
wallet_node.log.info(f"i si: {i}")
|
||||
|
||||
funds = sum(
|
||||
[
|
||||
calculate_base_fee(uint32(i)) + calculate_block_reward(uint32(i))
|
||||
for i in range(1, 4)
|
||||
]
|
||||
)
|
||||
|
||||
assert await wallet.get_confirmed_balance() == funds
|
||||
|
|
Loading…
Reference in New Issue