Fix tests
This commit is contained in:
parent
54a49f2287
commit
d0b80f075c
|
@ -172,9 +172,7 @@ class FullNode:
|
|||
)
|
||||
# If connected to a wallet, send the LCA
|
||||
lca = self.blockchain.lca_block
|
||||
new_lca = wallet_protocol.NewLCA(
|
||||
lca.header_hash, lca.prev_header_hash, lca.height, lca.weight
|
||||
)
|
||||
new_lca = wallet_protocol.NewLCA(lca.header_hash, lca.height, lca.weight)
|
||||
yield OutboundMessage(
|
||||
NodeType.WALLET, Message("new_lca", new_lca), Delivery.RESPOND
|
||||
)
|
||||
|
@ -1635,10 +1633,7 @@ class FullNode:
|
|||
new_lca = self.blockchain.lca_block
|
||||
if new_lca != prev_lca:
|
||||
new_lca_req = wallet_protocol.NewLCA(
|
||||
new_lca.header_hash,
|
||||
new_lca.prev_header_hash,
|
||||
new_lca.height,
|
||||
new_lca.weight,
|
||||
new_lca.header_hash, new_lca.height, new_lca.weight,
|
||||
)
|
||||
yield OutboundMessage(
|
||||
NodeType.WALLET, Message("new_lca", new_lca_req), Delivery.BROADCAST
|
||||
|
|
|
@ -65,7 +65,6 @@ class RejectAllHeaderHashesAfterRequest:
|
|||
@cbor_message
|
||||
class NewLCA:
|
||||
lca_hash: bytes32
|
||||
prev_header_hash: bytes32
|
||||
height: uint32
|
||||
weight: uint128
|
||||
|
||||
|
|
|
@ -8,7 +8,8 @@ from blspy import ExtendedPrivateKey
|
|||
from src.util.merkle_set import (
|
||||
confirm_included_already_hashed,
|
||||
confirm_not_included_already_hashed,
|
||||
MerkleSet)
|
||||
MerkleSet,
|
||||
)
|
||||
from src.protocols import wallet_protocol
|
||||
from src.consensus.constants import constants as consensus_constants
|
||||
from src.server.server import ChiaServer
|
||||
|
@ -323,6 +324,9 @@ class WalletNode:
|
|||
async for msg in self._block_finished(new_br, new_hb):
|
||||
yield msg
|
||||
if res == ReceiveBlockResult.ADDED_TO_HEAD:
|
||||
self.log.info(
|
||||
f"Updated LCA to {block_record.prev_header_hash} at height {block_record.height}"
|
||||
)
|
||||
# Removes outdated cached blocks if we're not syncing
|
||||
if not self.sync_mode:
|
||||
for header_hash in self.cached_blocks:
|
||||
|
@ -392,6 +396,7 @@ class WalletNode:
|
|||
|
||||
@api_request
|
||||
async def new_lca(self, request: wallet_protocol.NewLCA):
|
||||
print("Got LCA height", request)
|
||||
if self.sync_mode:
|
||||
return
|
||||
# If already seen LCA, ignore.
|
||||
|
@ -416,7 +421,7 @@ class WalletNode:
|
|||
self.sync_mode = False
|
||||
else:
|
||||
header_request = wallet_protocol.RequestHeader(
|
||||
uint32(request.height - 1), request.prev_header_hash
|
||||
uint32(request.height), request.lca_hash
|
||||
)
|
||||
yield OutboundMessage(
|
||||
NodeType.FULL_NODE,
|
||||
|
@ -426,6 +431,7 @@ class WalletNode:
|
|||
|
||||
@api_request
|
||||
async def respond_header(self, response: wallet_protocol.RespondHeader):
|
||||
print("Got header height", response.header_block.header.height)
|
||||
block = response.header_block
|
||||
# If we already have, return
|
||||
if block.header_hash in self.wallet_state_manager.block_records:
|
||||
|
@ -507,7 +513,8 @@ class WalletNode:
|
|||
# Verify removals root
|
||||
removals_merkle_set = MerkleSet()
|
||||
for coin in removals:
|
||||
removals_merkle_set.add_already_hashed(coin.name())
|
||||
if coin is not None:
|
||||
removals_merkle_set.add_already_hashed(coin.name())
|
||||
removals_root = removals_merkle_set.get_root()
|
||||
if header_block.header.data.removals_root != removals_root:
|
||||
return
|
||||
|
|
|
@ -329,8 +329,8 @@ class WalletStateManager:
|
|||
return ReceiveBlockResult.DISCONNECTED_BLOCK
|
||||
|
||||
if header_block is not None:
|
||||
# TODO: validate header block
|
||||
pass
|
||||
if not self.validate_header_block(header_block):
|
||||
return ReceiveBlockResult.INVALID_BLOCK
|
||||
|
||||
self.block_records[block.header_hash] = block
|
||||
await self.wallet_store.add_block_record(block, False)
|
||||
|
@ -368,12 +368,12 @@ class WalletStateManager:
|
|||
for path_block in blocks_to_add:
|
||||
self.height_to_hash[path_block.height] = path_block.header_hash
|
||||
await self.wallet_store.add_block_to_path(path_block.header_hash)
|
||||
if header_block:
|
||||
if header_block is not None:
|
||||
coinbase = header_block.header.data.coinbase
|
||||
fees_coin = header_block.header.data.fees_coin
|
||||
if (await self.is_addition_relevant(coinbase.puzzle_hash)):
|
||||
if await self.is_addition_relevant(coinbase):
|
||||
await self.coin_added(coinbase, path_block.height, True)
|
||||
if (await self.is_addition_relevant(fees_coin.puzzle_hash)):
|
||||
if await self.is_addition_relevant(fees_coin):
|
||||
await self.coin_added(fees_coin, path_block.height, True)
|
||||
for coin in path_block.additions:
|
||||
await self.coin_added(coin, path_block.height, False)
|
||||
|
@ -384,6 +384,10 @@ class WalletStateManager:
|
|||
|
||||
return ReceiveBlockResult.ADDED_AS_ORPHAN
|
||||
|
||||
async def validate_header_block(self, header_block: HeaderBlock) -> bool:
|
||||
# TODO(mariano): implement
|
||||
return True
|
||||
|
||||
def find_fork_for_lca(self, new_lca: BlockRecord) -> uint32:
|
||||
""" Tries to find height where new chain (current) diverged from the old chain where old_lca was the LCA"""
|
||||
tmp_old: BlockRecord = self.block_records[self.lca]
|
||||
|
|
|
@ -43,59 +43,22 @@ class TestFullSync:
|
|||
)
|
||||
|
||||
await asyncio.sleep(2) # Allow connections to get made
|
||||
start_unf = time.time()
|
||||
start = time.time()
|
||||
|
||||
while time.time() - start_unf < 60:
|
||||
while time.time() - start < 60:
|
||||
# The second node should eventually catch up to the first one, and have the
|
||||
# same tip at height num_blocks - 1.
|
||||
if (
|
||||
max([h.height for h in full_node_2.blockchain.get_current_tips()])
|
||||
== num_blocks - 1
|
||||
):
|
||||
print(f"Time taken to sync {num_blocks} is {time.time() - start_unf}")
|
||||
print(f"Time taken to sync {num_blocks} is {time.time() - start}")
|
||||
|
||||
return
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
raise Exception(
|
||||
f"Took too long to process blocks, stopped at: {time.time() - start_unf}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_sync_wallet(self, wallet_node):
|
||||
num_blocks = 25
|
||||
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
|
||||
full_node_1, wallet_node, server_1, server_2 = wallet_node
|
||||
|
||||
for i in range(1, num_blocks):
|
||||
async for _ in full_node_1.respond_block(
|
||||
full_node_protocol.RespondBlock(blocks[i])
|
||||
):
|
||||
pass
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
|
||||
await asyncio.sleep(2) # Allow connections to get made
|
||||
start_unf = time.time()
|
||||
|
||||
while time.time() - start_unf < 60:
|
||||
# The second node should eventually catch up to the first one, and have the
|
||||
# same tip at height num_blocks - 1.
|
||||
if (
|
||||
wallet_node.wallet_state_manager.block_records[
|
||||
wallet_node.wallet_state_manager.lca
|
||||
].height
|
||||
== num_blocks - 7
|
||||
):
|
||||
print(f"Time taken to sync {num_blocks} is {time.time() - start_unf}")
|
||||
|
||||
return
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
raise Exception(
|
||||
f"Took too long to process blocks, stopped at: {time.time() - start_unf}"
|
||||
f"Took too long to process blocks, stopped at: {time.time() - start}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -132,17 +95,109 @@ class TestFullSync:
|
|||
)
|
||||
await asyncio.sleep(2) # Allow connections to get made
|
||||
|
||||
start_unf = time.time()
|
||||
start = time.time()
|
||||
|
||||
while time.time() - start_unf < 30:
|
||||
while time.time() - start < 30:
|
||||
# The second node should eventually catch up to the first one, and have the
|
||||
# same tip at height num_blocks - 1.
|
||||
if (
|
||||
max([h.height for h in full_node_2.blockchain.get_current_tips()])
|
||||
== num_blocks - 1
|
||||
):
|
||||
print(f"Time taken to sync {num_blocks} is {time.time() - start_unf}")
|
||||
print(f"Time taken to sync {num_blocks} is {time.time() - start}")
|
||||
return
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
raise Exception("Took too long to process blocks")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_sync_wallet(self, wallet_node):
|
||||
num_blocks = 25
|
||||
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
|
||||
full_node_1, wallet_node, server_1, server_2 = wallet_node
|
||||
|
||||
for i in range(1, len(blocks)):
|
||||
async for _ in full_node_1.respond_block(
|
||||
full_node_protocol.RespondBlock(blocks[i])
|
||||
):
|
||||
pass
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
found = False
|
||||
while time.time() - start < 60:
|
||||
# The second node should eventually catch up to the first one, and have the
|
||||
# same tip at height num_blocks - 1.
|
||||
if (
|
||||
wallet_node.wallet_state_manager.block_records[
|
||||
wallet_node.wallet_state_manager.lca
|
||||
].height
|
||||
== num_blocks - 6
|
||||
):
|
||||
found = True
|
||||
break
|
||||
await asyncio.sleep(0.1)
|
||||
if not found:
|
||||
raise Exception(
|
||||
f"Took too long to process blocks, stopped at: {time.time() - start}"
|
||||
)
|
||||
|
||||
# Tests a reorg with the wallet
|
||||
start = time.time()
|
||||
found = False
|
||||
blocks_reorg = bt.get_consecutive_blocks(test_constants, 15, blocks[:-5], 10)
|
||||
for i in range(1, len(blocks_reorg)):
|
||||
async for msg in full_node_1.respond_block(
|
||||
full_node_protocol.RespondBlock(blocks_reorg[i])
|
||||
):
|
||||
server_1.push_message(msg)
|
||||
start = time.time()
|
||||
|
||||
while time.time() - start < 60:
|
||||
if (
|
||||
wallet_node.wallet_state_manager.block_records[
|
||||
wallet_node.wallet_state_manager.lca
|
||||
].height
|
||||
== 33
|
||||
):
|
||||
found = True
|
||||
break
|
||||
await asyncio.sleep(0.1)
|
||||
if not found:
|
||||
raise Exception(
|
||||
f"Took too long to process blocks, stopped at: {time.time() - start}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_short_sync_wallet(self, wallet_node):
|
||||
num_blocks = 8
|
||||
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
|
||||
full_node_1, wallet_node, server_1, server_2 = wallet_node
|
||||
|
||||
for i in range(1, len(blocks)):
|
||||
async for _ in full_node_1.respond_block(
|
||||
full_node_protocol.RespondBlock(blocks[i])
|
||||
):
|
||||
pass
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
start = time.time()
|
||||
while time.time() - start < 60:
|
||||
# The second node should eventually catch up to the first one, and have the
|
||||
# same tip at height num_blocks - 1.
|
||||
if (
|
||||
wallet_node.wallet_state_manager.block_records[
|
||||
wallet_node.wallet_state_manager.lca
|
||||
].height
|
||||
== 6
|
||||
):
|
||||
return
|
||||
await asyncio.sleep(0.1)
|
||||
raise Exception(
|
||||
f"Took too long to process blocks, stopped at: {time.time() - start}"
|
||||
)
|
||||
|
|
|
@ -102,7 +102,7 @@ async def setup_wallet_node(port, introducer_port=None, dic={}):
|
|||
test_constants_copy = test_constants.copy()
|
||||
for k in dic.keys():
|
||||
test_constants_copy[k] = dic[k]
|
||||
db_path = "test-wallet-db" + token_bytes(32).hex()
|
||||
db_path = "test-wallet-db" + token_bytes(32).hex() + ".db"
|
||||
if Path(db_path).exists():
|
||||
Path(db_path).unlink()
|
||||
|
||||
|
|
|
@ -4,7 +4,8 @@
|
|||
# from blspy import ExtendedPrivateKey
|
||||
|
||||
# from src.wallet.wallet_node import WalletNode
|
||||
# from tests.setup_nodes import setup_two_nodes, test_constants, bt
|
||||
# from src.protocols import full_node_protocol
|
||||
# from tests.setup_nodes import setup_node_and_wallet, test_constants, bt
|
||||
|
||||
|
||||
# @pytest.fixture(scope="module")
|
||||
|
@ -15,31 +16,25 @@
|
|||
|
||||
# class TestWallet:
|
||||
# @pytest.fixture(scope="function")
|
||||
# async def two_nodes(self):
|
||||
# async for _ in setup_two_nodes({"COINBASE_FREEZE_PERIOD": 0}):
|
||||
# async def wallet_node(self):
|
||||
# async for _ in setup_node_and_wallet():
|
||||
# yield _
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_wallet_receive_body(self, two_nodes):
|
||||
# sk = bytes(ExtendedPrivateKey.from_seed(b"")).hex()
|
||||
# key_config = {"wallet_sk": sk}
|
||||
|
||||
# wallet_node = await WalletNode.create({}, key_config)
|
||||
# wallet = wallet_node.wallet
|
||||
# await wallet_node.wallet_store._clear_database()
|
||||
# await wallet_node.tx_store._clear_database()
|
||||
|
||||
# async def test_wallet_receive_body(self, wallet_node):
|
||||
# num_blocks = 10
|
||||
# full_node_1, wallet_node, server_1, server_2 = wallet_node
|
||||
# wallet = wallet_node.wallet
|
||||
# ph = await wallet.get_new_puzzlehash()
|
||||
# blocks = bt.get_consecutive_blocks(
|
||||
# test_constants, num_blocks, [], 10, reward_puzzlehash=ph,
|
||||
# )
|
||||
|
||||
# for i in range(1, num_blocks):
|
||||
# a = RespondBody(
|
||||
# blocks[i].header, blocks[i].transactions_generator, blocks[i].height
|
||||
# )
|
||||
# await wallet_node.received_body(a)
|
||||
# for i in range(1, len(blocks)):
|
||||
# async for _ in full_node_1.respond_block(
|
||||
# full_node_protocol.RespondBlock(blocks[i])
|
||||
# ):
|
||||
# pass
|
||||
# await asyncio.sleep(50)
|
||||
|
||||
# assert await wallet.get_confirmed_balance() == 144000000000000
|
||||
|
||||
|
|
Loading…
Reference in New Issue