Fix tests

This commit is contained in:
Mariano Sorgente 2020-03-09 12:36:46 +09:00
parent 54a49f2287
commit d0b80f075c
No known key found for this signature in database
GPG Key ID: 0F866338C369278C
7 changed files with 134 additions and 79 deletions

View File

@ -172,9 +172,7 @@ class FullNode:
)
# If connected to a wallet, send the LCA
lca = self.blockchain.lca_block
new_lca = wallet_protocol.NewLCA(
lca.header_hash, lca.prev_header_hash, lca.height, lca.weight
)
new_lca = wallet_protocol.NewLCA(lca.header_hash, lca.height, lca.weight)
yield OutboundMessage(
NodeType.WALLET, Message("new_lca", new_lca), Delivery.RESPOND
)
@ -1635,10 +1633,7 @@ class FullNode:
new_lca = self.blockchain.lca_block
if new_lca != prev_lca:
new_lca_req = wallet_protocol.NewLCA(
new_lca.header_hash,
new_lca.prev_header_hash,
new_lca.height,
new_lca.weight,
new_lca.header_hash, new_lca.height, new_lca.weight,
)
yield OutboundMessage(
NodeType.WALLET, Message("new_lca", new_lca_req), Delivery.BROADCAST

View File

@ -65,7 +65,6 @@ class RejectAllHeaderHashesAfterRequest:
@cbor_message
class NewLCA:
lca_hash: bytes32
prev_header_hash: bytes32
height: uint32
weight: uint128

View File

@ -8,7 +8,8 @@ from blspy import ExtendedPrivateKey
from src.util.merkle_set import (
confirm_included_already_hashed,
confirm_not_included_already_hashed,
MerkleSet)
MerkleSet,
)
from src.protocols import wallet_protocol
from src.consensus.constants import constants as consensus_constants
from src.server.server import ChiaServer
@ -323,6 +324,9 @@ class WalletNode:
async for msg in self._block_finished(new_br, new_hb):
yield msg
if res == ReceiveBlockResult.ADDED_TO_HEAD:
self.log.info(
f"Updated LCA to {block_record.prev_header_hash} at height {block_record.height}"
)
# Removes outdated cached blocks if we're not syncing
if not self.sync_mode:
for header_hash in self.cached_blocks:
@ -392,6 +396,7 @@ class WalletNode:
@api_request
async def new_lca(self, request: wallet_protocol.NewLCA):
print("Got LCA height", request)
if self.sync_mode:
return
# If already seen LCA, ignore.
@ -416,7 +421,7 @@ class WalletNode:
self.sync_mode = False
else:
header_request = wallet_protocol.RequestHeader(
uint32(request.height - 1), request.prev_header_hash
uint32(request.height), request.lca_hash
)
yield OutboundMessage(
NodeType.FULL_NODE,
@ -426,6 +431,7 @@ class WalletNode:
@api_request
async def respond_header(self, response: wallet_protocol.RespondHeader):
print("Got header height", response.header_block.header.height)
block = response.header_block
# If we already have, return
if block.header_hash in self.wallet_state_manager.block_records:
@ -507,7 +513,8 @@ class WalletNode:
# Verify removals root
removals_merkle_set = MerkleSet()
for coin in removals:
removals_merkle_set.add_already_hashed(coin.name())
if coin is not None:
removals_merkle_set.add_already_hashed(coin.name())
removals_root = removals_merkle_set.get_root()
if header_block.header.data.removals_root != removals_root:
return

View File

@ -329,8 +329,8 @@ class WalletStateManager:
return ReceiveBlockResult.DISCONNECTED_BLOCK
if header_block is not None:
# TODO: validate header block
pass
if not self.validate_header_block(header_block):
return ReceiveBlockResult.INVALID_BLOCK
self.block_records[block.header_hash] = block
await self.wallet_store.add_block_record(block, False)
@ -368,12 +368,12 @@ class WalletStateManager:
for path_block in blocks_to_add:
self.height_to_hash[path_block.height] = path_block.header_hash
await self.wallet_store.add_block_to_path(path_block.header_hash)
if header_block:
if header_block is not None:
coinbase = header_block.header.data.coinbase
fees_coin = header_block.header.data.fees_coin
if (await self.is_addition_relevant(coinbase.puzzle_hash)):
if await self.is_addition_relevant(coinbase):
await self.coin_added(coinbase, path_block.height, True)
if (await self.is_addition_relevant(fees_coin.puzzle_hash)):
if await self.is_addition_relevant(fees_coin):
await self.coin_added(fees_coin, path_block.height, True)
for coin in path_block.additions:
await self.coin_added(coin, path_block.height, False)
@ -384,6 +384,10 @@ class WalletStateManager:
return ReceiveBlockResult.ADDED_AS_ORPHAN
async def validate_header_block(self, header_block: HeaderBlock) -> bool:
# TODO(mariano): implement
return True
def find_fork_for_lca(self, new_lca: BlockRecord) -> uint32:
""" Tries to find height where new chain (current) diverged from the old chain where old_lca was the LCA"""
tmp_old: BlockRecord = self.block_records[self.lca]

View File

@ -43,59 +43,22 @@ class TestFullSync:
)
await asyncio.sleep(2) # Allow connections to get made
start_unf = time.time()
start = time.time()
while time.time() - start_unf < 60:
while time.time() - start < 60:
# The second node should eventually catch up to the first one, and have the
# same tip at height num_blocks - 1.
if (
max([h.height for h in full_node_2.blockchain.get_current_tips()])
== num_blocks - 1
):
print(f"Time taken to sync {num_blocks} is {time.time() - start_unf}")
print(f"Time taken to sync {num_blocks} is {time.time() - start}")
return
await asyncio.sleep(0.1)
raise Exception(
f"Took too long to process blocks, stopped at: {time.time() - start_unf}"
)
@pytest.mark.asyncio
async def test_basic_sync_wallet(self, wallet_node):
num_blocks = 25
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
full_node_1, wallet_node, server_1, server_2 = wallet_node
for i in range(1, num_blocks):
async for _ in full_node_1.respond_block(
full_node_protocol.RespondBlock(blocks[i])
):
pass
await server_2.start_client(
PeerInfo(server_1._host, uint16(server_1._port)), None
)
await asyncio.sleep(2) # Allow connections to get made
start_unf = time.time()
while time.time() - start_unf < 60:
# The second node should eventually catch up to the first one, and have the
# same tip at height num_blocks - 1.
if (
wallet_node.wallet_state_manager.block_records[
wallet_node.wallet_state_manager.lca
].height
== num_blocks - 7
):
print(f"Time taken to sync {num_blocks} is {time.time() - start_unf}")
return
await asyncio.sleep(0.1)
raise Exception(
f"Took too long to process blocks, stopped at: {time.time() - start_unf}"
f"Took too long to process blocks, stopped at: {time.time() - start}"
)
@pytest.mark.asyncio
@ -132,17 +95,109 @@ class TestFullSync:
)
await asyncio.sleep(2) # Allow connections to get made
start_unf = time.time()
start = time.time()
while time.time() - start_unf < 30:
while time.time() - start < 30:
# The second node should eventually catch up to the first one, and have the
# same tip at height num_blocks - 1.
if (
max([h.height for h in full_node_2.blockchain.get_current_tips()])
== num_blocks - 1
):
print(f"Time taken to sync {num_blocks} is {time.time() - start_unf}")
print(f"Time taken to sync {num_blocks} is {time.time() - start}")
return
await asyncio.sleep(0.1)
raise Exception("Took too long to process blocks")
@pytest.mark.asyncio
async def test_basic_sync_wallet(self, wallet_node):
num_blocks = 25
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
full_node_1, wallet_node, server_1, server_2 = wallet_node
for i in range(1, len(blocks)):
async for _ in full_node_1.respond_block(
full_node_protocol.RespondBlock(blocks[i])
):
pass
await server_2.start_client(
PeerInfo(server_1._host, uint16(server_1._port)), None
)
start = time.time()
found = False
while time.time() - start < 60:
# The second node should eventually catch up to the first one, and have the
# same tip at height num_blocks - 1.
if (
wallet_node.wallet_state_manager.block_records[
wallet_node.wallet_state_manager.lca
].height
== num_blocks - 6
):
found = True
break
await asyncio.sleep(0.1)
if not found:
raise Exception(
f"Took too long to process blocks, stopped at: {time.time() - start}"
)
# Tests a reorg with the wallet
start = time.time()
found = False
blocks_reorg = bt.get_consecutive_blocks(test_constants, 15, blocks[:-5], 10)
for i in range(1, len(blocks_reorg)):
async for msg in full_node_1.respond_block(
full_node_protocol.RespondBlock(blocks_reorg[i])
):
server_1.push_message(msg)
start = time.time()
while time.time() - start < 60:
if (
wallet_node.wallet_state_manager.block_records[
wallet_node.wallet_state_manager.lca
].height
== 33
):
found = True
break
await asyncio.sleep(0.1)
if not found:
raise Exception(
f"Took too long to process blocks, stopped at: {time.time() - start}"
)
@pytest.mark.asyncio
async def test_short_sync_wallet(self, wallet_node):
num_blocks = 8
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
full_node_1, wallet_node, server_1, server_2 = wallet_node
for i in range(1, len(blocks)):
async for _ in full_node_1.respond_block(
full_node_protocol.RespondBlock(blocks[i])
):
pass
await server_2.start_client(
PeerInfo(server_1._host, uint16(server_1._port)), None
)
start = time.time()
while time.time() - start < 60:
# The second node should eventually catch up to the first one, and have the
# same tip at height num_blocks - 1.
if (
wallet_node.wallet_state_manager.block_records[
wallet_node.wallet_state_manager.lca
].height
== 6
):
return
await asyncio.sleep(0.1)
raise Exception(
f"Took too long to process blocks, stopped at: {time.time() - start}"
)

View File

@ -102,7 +102,7 @@ async def setup_wallet_node(port, introducer_port=None, dic={}):
test_constants_copy = test_constants.copy()
for k in dic.keys():
test_constants_copy[k] = dic[k]
db_path = "test-wallet-db" + token_bytes(32).hex()
db_path = "test-wallet-db" + token_bytes(32).hex() + ".db"
if Path(db_path).exists():
Path(db_path).unlink()

View File

@ -4,7 +4,8 @@
# from blspy import ExtendedPrivateKey
# from src.wallet.wallet_node import WalletNode
# from tests.setup_nodes import setup_two_nodes, test_constants, bt
# from src.protocols import full_node_protocol
# from tests.setup_nodes import setup_node_and_wallet, test_constants, bt
# @pytest.fixture(scope="module")
@ -15,31 +16,25 @@
# class TestWallet:
# @pytest.fixture(scope="function")
# async def two_nodes(self):
# async for _ in setup_two_nodes({"COINBASE_FREEZE_PERIOD": 0}):
# async def wallet_node(self):
# async for _ in setup_node_and_wallet():
# yield _
# @pytest.mark.asyncio
# async def test_wallet_receive_body(self, two_nodes):
# sk = bytes(ExtendedPrivateKey.from_seed(b"")).hex()
# key_config = {"wallet_sk": sk}
# wallet_node = await WalletNode.create({}, key_config)
# wallet = wallet_node.wallet
# await wallet_node.wallet_store._clear_database()
# await wallet_node.tx_store._clear_database()
# async def test_wallet_receive_body(self, wallet_node):
# num_blocks = 10
# full_node_1, wallet_node, server_1, server_2 = wallet_node
# wallet = wallet_node.wallet
# ph = await wallet.get_new_puzzlehash()
# blocks = bt.get_consecutive_blocks(
# test_constants, num_blocks, [], 10, reward_puzzlehash=ph,
# )
# for i in range(1, num_blocks):
# a = RespondBody(
# blocks[i].header, blocks[i].transactions_generator, blocks[i].height
# )
# await wallet_node.received_body(a)
# for i in range(1, len(blocks)):
# async for _ in full_node_1.respond_block(
# full_node_protocol.RespondBlock(blocks[i])
# ):
# pass
# await asyncio.sleep(50)
# assert await wallet.get_confirmed_balance() == 144000000000000