reorg in wallets

This commit is contained in:
Yostra 2020-03-10 12:58:42 -07:00
parent d8211f5257
commit b93a5a33e9
5 changed files with 93 additions and 8 deletions

View File

@ -1,3 +1,5 @@
from secrets import token_bytes
from src.full_node.full_node import FullNode
from typing import AsyncGenerator, List, Dict
from src.full_node.blockchain import Blockchain
@ -16,6 +18,7 @@ from src.types.header import Header
from src.types.sized_bytes import bytes32
from src.full_node.coin_store import CoinStore
from src.util.api_decorators import api_request
from src.util.ints import uint32
from tests.block_tools import BlockTools
OutboundMessageGenerator = AsyncGenerator[OutboundMessage, None]
@ -105,6 +108,7 @@ class FullNodeSimulator(FullNode):
async for msg in super().request_additions(request):
yield msg
# WALLET LOCAL TEST PROTOCOL
def get_tip(self):
tips = self.blockchain.tips
top = tips[0]
@ -115,7 +119,6 @@ class FullNodeSimulator(FullNode):
return top
# WALLET LOCAL TEST PROTOCOL
async def get_current_blocks(self, tip: Header) -> List[FullBlock]:
current_blocks: List[FullBlock] = []
@ -156,3 +159,27 @@ class FullNodeSimulator(FullNode):
async for msg in self.respond_block(full_node_protocol.RespondBlock(new_lca)):
self.server.push_message(msg)
@api_request
async def reorg_from_index_to_new_index(
self, old_index: uint32, new_index: uint32, coinbase_ph: bytes32
):
top_tip = self.get_tip()
current_blocks = await self.get_current_blocks(top_tip)
block_count = new_index - old_index
more_blocks = bt.get_consecutive_blocks(
self.constants,
block_count,
current_blocks[:old_index],
10,
seed=token_bytes(),
reward_puzzlehash=coinbase_ph,
transaction_data_at_height={},
)
for block in more_blocks:
async for msg in self.respond_block(full_node_protocol.RespondBlock(block)):
self.server.push_message(msg)
self.log.info(f"New message: {msg}")

View File

@ -46,6 +46,7 @@ class WalletStateManager:
# TODO Don't allow user to send tx until wallet is synced
synced: bool
genesis: FullBlock
@staticmethod
async def create(
@ -76,6 +77,7 @@ class WalletStateManager:
self.height_to_hash = {}
self.block_records = await self.wallet_store.get_lca_path()
genesis = FullBlock.from_bytes(self.constants["GENESIS_BLOCK"])
self.genesis = genesis
if len(self.block_records) > 0:
# Header hash with the highest weight
@ -130,7 +132,9 @@ class WalletStateManager:
valid_index = current_index - coinbase_freeze_period
record_list: Set[
CoinRecord
] = await self.wallet_store.get_coin_records_by_spent_and_index(False, valid_index)
] = await self.wallet_store.get_coin_records_by_spent_and_index(
False, valid_index
)
amount: uint64 = uint64(0)
@ -228,7 +232,6 @@ class WalletStateManager:
sum += coinrecord.coin.amount
used_coins.add(coinrecord.coin)
# This happens when we couldn't use one of the coins because it's already used
# but unconfirmed, and we are waiting for the change. (unconfirmed_additions)
if sum < amount:
@ -406,7 +409,7 @@ class WalletStateManager:
blocks_to_add: List[BlockRecord] = []
tip_hash: bytes32 = block.header_hash
while True:
if tip_hash == fork_hash:
if tip_hash == fork_hash or tip_hash == self.genesis.header_hash:
break
record = self.block_records[tip_hash]
blocks_to_add.append(record)
@ -519,6 +522,7 @@ class WalletStateManager:
Rolls back and updates the coin_store and transaction store. It's possible this height
is the tip, or even beyond the tip.
"""
self.log.warning(f"Rolling back to {index}")
await self.wallet_store.rollback_lca_to_block(index)
reorged: List[TransactionRecord] = await self.tx_store.get_transaction_above(

View File

@ -143,12 +143,15 @@ class WalletStore:
coins.add(CoinRecord(coin, row[1], row[2], row[3], row[4]))
return coins
async def get_coin_records_by_spent_and_index(self, spent: bool, index: uint32) -> Set[CoinRecord]:
async def get_coin_records_by_spent_and_index(
self, spent: bool, index: uint32
) -> Set[CoinRecord]:
""" Returns set of CoinRecords that have been confirmed before index height. """
coins = set()
cursor = await self.db_connection.execute(
"SELECT * from coin_record WHERE spent=? and confirmed_index<?", (int(spent), int(index),)
"SELECT * from coin_record WHERE spent=? and confirmed_index<?",
(int(spent), int(index),),
)
rows = await cursor.fetchall()
await cursor.close()

View File

@ -80,7 +80,9 @@ async def setup_full_node_simulator(db_name, port, introducer_port=None, dic={})
f"full_node_{port}",
test_constants_copy,
)
server_1 = ChiaServer(port, full_node_1, NodeType.FULL_NODE, name="full-node-simulator-server")
server_1 = ChiaServer(
port, full_node_1, NodeType.FULL_NODE, name="full-node-simulator-server"
)
_ = await server_1.start_server(config["host"], full_node_1._on_connect)
full_node_1._set_server(server_1)
@ -155,7 +157,11 @@ async def setup_wallet_node(port, introducer_port=None, key_seed=b"", dic={}):
Path(db_path).unlink()
wallet = await WalletNode.create(
config, key_config, db_path=db_path, override_constants=test_constants_copy, name="wallet1"
config,
key_config,
db_path=db_path,
override_constants=test_constants_copy,
name="wallet1",
)
server = ChiaServer(port, wallet, NodeType.WALLET, name="wallet-server")
wallet.set_server(server)

View File

@ -1,6 +1,9 @@
import asyncio
from secrets import token_bytes
import pytest
from src.types.header import Header
from src.types.peer_info import PeerInfo
from src.util.ints import uint16, uint32
from tests.setup_nodes import (
@ -113,3 +116,45 @@ class TestWalletSimulator:
assert confirmed_balance == new_funds - 10
assert unconfirmed_balance == new_funds - 10
@pytest.mark.asyncio
async def test_wallet_coinbase_reorg(self, wallet_node):
num_blocks = 10
full_node_1, wallet_node, server_1, server_2 = wallet_node
wallet = wallet_node.wallet
ph = await wallet.get_new_puzzlehash()
await server_2.start_client(
PeerInfo(server_1._host, uint16(server_1._port)), None
)
for i in range(1, num_blocks):
await full_node_1.farm_new_block(ph)
await asyncio.sleep(3)
funds = sum(
[
calculate_base_fee(uint32(i)) + calculate_block_reward(uint32(i))
for i in range(1, num_blocks - 2)
]
)
assert await wallet.get_confirmed_balance() == funds
wallet_lca = wallet_node.wallet_state_manager.lca
lca_header: Header = wallet_node.wallet_state_manager.block_records[wallet_lca]
await full_node_1.reorg_from_index_to_new_index(
5, num_blocks + 3, token_bytes()
)
await asyncio.sleep(3)
for i in range(1, 5):
wallet_node.log.info(f"i si: {i}")
funds = sum(
[
calculate_base_fee(uint32(i)) + calculate_block_reward(uint32(i))
for i in range(1, 4)
]
)
assert await wallet.get_confirmed_balance() == funds