add wallet tests

This commit is contained in:
Yostra 2020-10-20 18:45:09 -07:00
parent 8a9518a320
commit c874bf3ad2
8 changed files with 107 additions and 95 deletions

View File

@ -36,11 +36,7 @@ from src.protocols import (
timelord_protocol,
wallet_protocol,
)
<<<<<<< HEAD
from src.server.connection import PeerConnections
=======
from src.server.node_discovery import FullNodePeers
>>>>>>> test_wallet & peer discovery
from src.server.outbound_message import Delivery, Message, NodeType, OutboundMessage
from src.server.server import ChiaServer
from src.types.end_of_slot_bundle import EndOfSubSlotBundle
@ -616,9 +612,17 @@ class FullNode:
# This is a block we asked for during sync
if self.sync_peers_handler is not None:
async for req in self.sync_peers_handler.new_block(sub_block):
yield req
return
resp: List[OutboundMessage] = await self.sync_peers_handler.new_block(
respond_block.block
)
for req in resp:
type = req.peer_type
node_id = req.specific_peer_node_id
message = req.message
if node_id is None:
await self.server.send_to_all([message], type)
else:
await self.server.send_to_specific([message], node_id)
# Adds the block to seen, and check if it's seen before (which means header is in memory)
header_hash = sub_block.header.get_hash()

View File

@ -195,36 +195,39 @@ class ChiaServer:
if full_message is None or connection is None:
continue
try:
connection.log.info(
f"<- {full_message.function} from peer {connection.peer_node_id}"
)
if len(full_message.function) == 0 or full_message.function.startswith(
"_"
):
# This prevents remote calling of private methods that start with "_"
raise ProtocolError(
Err.INVALID_PROTOCOL_MESSAGE, [full_message.function]
async def api_call(full_message, connection):
try:
connection.log.info(
f"<- {full_message.function} from peer {connection.peer_node_id}"
)
if len(
full_message.function
) == 0 or full_message.function.startswith("_"):
# This prevents remote calling of private methods that start with "_"
raise ProtocolError(
Err.INVALID_PROTOCOL_MESSAGE, [full_message.function]
)
f = getattr(self.api, full_message.function, None)
f = getattr(self.api, full_message.function, None)
if f is None:
raise ProtocolError(
Err.INVALID_PROTOCOL_MESSAGE, [full_message.function]
if f is None:
raise ProtocolError(
Err.INVALID_PROTOCOL_MESSAGE, [full_message.function]
)
response = await f(full_message.data, connection)
if response is not None:
await connection.send_message(response)
except Exception as e:
tb = traceback.format_exc()
connection.log.error(
f"Exception: {e}, closing connection {connection}. {tb}"
)
await connection.close()
response = await f(full_message.data, connection)
if response is not None:
await connection.send_message(response)
except Exception as e:
tb = traceback.format_exc()
connection.log.error(
f"Exception: {e}, closing connection {connection}. {tb}"
)
await connection.close()
asyncio.create_task(api_call(full_message, connection))
async def send_to_others(
self, messages: List[Message], type: NodeType, origin_peer: WSChiaConnection
@ -264,9 +267,15 @@ class ChiaServer:
return result
def close_all(self):
async def close_all_connections(self):
for id, connection in self.global_connections.items():
asyncio.ensure_future(connection.close())
try:
await connection.close()
except Exception as e:
self.log.error(f"exeption while closing connection {e}")
def close_all(self):
asyncio.ensure_future(self.close_all_connections())
self.site_shutdown_task = asyncio.create_task(self.site.stop())
self.app_shut_down_task = asyncio.create_task(self.app.shutdown())

View File

@ -429,7 +429,7 @@ class WalletNodeAPI:
if not self.wallet_node.wallet_state_manager.sync_mode:
self.wallet_node.log.warning("Receiving header hashes while not syncing.")
return
self.header_hashes = response.hashes
self.wallet_node.header_hashes = response.hashes
@api_request
async def reject_all_header_hashes_after_request(
@ -446,7 +446,7 @@ class WalletNodeAPI:
or self.wallet_node.backup_initialized is False
):
return
self.header_hashes_error = True
self.wallet_node.header_hashes_error = True
@api_request
async def new_lca(self, request: wallet_protocol.NewLCA, peer: WSChiaConnection):

View File

@ -277,10 +277,10 @@ async def setup_node_and_wallet(consensus_constants: ConsensusConstants, startin
setup_wallet_node(21235, consensus_constants, None, starting_height=starting_height),
]
full_node, s1 = await node_iters[0].__anext__()
full_node_api = await node_iters[0].__anext__()
wallet, s2 = await node_iters[1].__anext__()
yield full_node, wallet, s1, s2
yield (full_node_api, wallet, full_node_api.full_node.server, s2)
await _teardown_nodes(node_iters)
@ -291,7 +291,7 @@ async def setup_simulators_and_wallets(
dic: Dict,
starting_height=None,
):
simulators: List[Tuple[FullNode, ChiaServer, FullNodeAPI]] = []
simulators: List[FullNodeAPI] = []
wallets = []
node_iters = []
@ -360,14 +360,14 @@ async def setup_full_system(consensus_constants: ConsensusConstants):
vdf = await node_iters[3].__anext__()
timelord, timelord_server = await node_iters[4].__anext__()
node1, node1_server = await node_iters[5].__anext__()
node2, node2_server = await node_iters[6].__anext__()
node_api_1 = await node_iters[5].__anext__()
node_api_2 = await node_iters[6].__anext__()
vdf_sanitizer = await node_iters[7].__anext__()
sanitizer, sanitizer_server = await node_iters[8].__anext__()
yield (
node1,
node2,
node_api_1,
node_api_2,
harvester,
farmer,
introducer,
@ -375,7 +375,7 @@ async def setup_full_system(consensus_constants: ConsensusConstants):
vdf,
sanitizer,
vdf_sanitizer,
node1_server,
node_api_1.full_node.server,
)
await _teardown_nodes(node_iters)

View File

@ -20,7 +20,7 @@ bt = None # TODO: almog
def node_height_at_least(node, h):
if (max([h.height for h in node.blockchain.get_current_tips()])) >= h:
if (max([h.height for h in node.full_node.blockchain.get_current_tips()])) >= h:
return True
return False
@ -40,12 +40,12 @@ class TestSimulation:
await time_out_assert(500, node_height_at_least, True, node2, 10)
# Wait additional 2 minutes to get a compact block.
max_height = node1.blockchain.lca_block.height
max_height = node1.full_node.blockchain.lca_block.height
async def has_compact(node1, node2, max_height):
for h in range(1, max_height):
blocks_1: List[FullBlock] = await node1.block_store.get_full_blocks_at([uint32(h)])
blocks_2: List[FullBlock] = await node2.block_store.get_full_blocks_at([uint32(h)])
blocks_1: List[FullBlock] = await node1.full_node.block_store.get_full_blocks_at([uint32(h)])
blocks_2: List[FullBlock] = await node2.full_node.block_store.get_full_blocks_at([uint32(h)])
has_compact_1 = False
has_compact_2 = False
for block in blocks_1:

View File

@ -31,15 +31,18 @@ class TestCCWalletBackup:
async def test_coin_backup(self, two_wallet_nodes):
num_blocks = 5
full_nodes, wallets = two_wallet_nodes
full_node_1, server_1 = full_nodes[0]
full_node_api = full_nodes[0]
full_node_server = full_node_api.full_node.server
wallet_node, server_2 = wallets[0]
wallet = wallet_node.wallet_state_manager.main_wallet
ph = await wallet.get_new_puzzlehash()
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
await server_2.start_client(
PeerInfo("localhost", uint16(full_node_server._port)), None
)
for i in range(1, 4):
await full_node_1.farm_new_block(FarmNewBlockProtocol(ph))
await full_node_api.farm_new_block(FarmNewBlockProtocol(ph), None)
funds = sum(
[
@ -55,7 +58,7 @@ class TestCCWalletBackup:
)
for i in range(1, num_blocks):
await full_node_1.farm_new_block(FarmNewBlockProtocol(ph))
await full_node_api.farm_new_block(FarmNewBlockProtocol(ph), None)
await time_out_assert(15, cc_wallet.get_confirmed_balance, 100)
await time_out_assert(15, cc_wallet.get_unconfirmed_balance, 100)
@ -76,7 +79,9 @@ class TestCCWalletBackup:
assert started is False
await wallet_node._start(backup_file=file_path)
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
await server_2.start_client(
PeerInfo("localhost", uint16(full_node_server._port)), None
)
all_wallets = wallet_node.wallet_state_manager.wallets
assert len(all_wallets) == 2

View File

@ -125,7 +125,6 @@ class TestWalletSimulator:
await time_out_assert(5, wallet.get_confirmed_balance, new_funds - 10)
await time_out_assert(5, wallet.get_unconfirmed_balance, new_funds - 10)
@pytest.mark.asyncio
async def test_wallet_coinbase_reorg(self, wallet_node):
num_blocks = 5
@ -136,7 +135,9 @@ class TestWalletSimulator:
wallet = wallet_node.wallet_state_manager.main_wallet
ph = await wallet.get_new_puzzlehash()
await server_2.start_client(PeerInfo("localhost", uint16(fn_server._port)), None)
await server_2.start_client(
PeerInfo("localhost", uint16(fn_server._port)), None
)
for i in range(1, num_blocks):
await full_node_api.farm_new_block(FarmNewBlockProtocol(ph), None)
@ -193,12 +194,8 @@ class TestWalletSimulator:
all_blocks = await full_node_api_0.get_current_blocks(full_node_api_0.get_tip())
for block in all_blocks:
await full_node_1._respond_block(
full_node_protocol.RespondBlock(block)
)
await full_node_2._respond_block(
full_node_protocol.RespondBlock(block)
)
await full_node_1._respond_block(full_node_protocol.RespondBlock(block))
await full_node_2._respond_block(full_node_protocol.RespondBlock(block))
funds = sum(
[
@ -291,7 +288,9 @@ class TestWalletSimulator:
await time_out_assert(5, wallet_0.get_unconfirmed_balance, funds - 10)
for i in range(0, 4):
await full_node_api_0.farm_new_block(FarmNewBlockProtocol(token_bytes()), None)
await full_node_api_0.farm_new_block(
FarmNewBlockProtocol(token_bytes()), None
)
new_funds = sum(
[
@ -311,7 +310,9 @@ class TestWalletSimulator:
await wallet_1.push_transaction(tx)
for i in range(0, 4):
await full_node_api_0.farm_new_block(FarmNewBlockProtocol(token_bytes()), None)
await full_node_api_0.farm_new_block(
FarmNewBlockProtocol(token_bytes()), None
)
await wallet_0.get_confirmed_balance()
await wallet_0.get_unconfirmed_balance()

View File

@ -46,15 +46,16 @@ class TestWalletSync:
async def test_basic_sync_wallet(self, wallet_node):
num_blocks = 300 # This must be greater than the short_sync in wallet_node
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [])
full_node_1, wallet_node, server_1, server_2 = wallet_node
full_node_api, wallet_node, full_node_server, wallet_server = wallet_node
for i in range(1, len(blocks)):
async for _ in full_node_1.respond_block(
await full_node_api.full_node._respond_block(
full_node_protocol.RespondBlock(blocks[i])
):
pass
)
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
await wallet_server.start_client(
PeerInfo("localhost", uint16(full_node_server._port)), None
)
# The second node should eventually catch up to the first one, and have the
# same tip at height num_blocks - 1.
@ -65,10 +66,9 @@ class TestWalletSync:
# Tests a reorg with the wallet
blocks_reorg = bt.get_consecutive_blocks(test_constants, 15, blocks[:-5])
for i in range(1, len(blocks_reorg)):
async for msg in full_node_1.respond_block(
await full_node_api.full_node._respond_block(
full_node_protocol.RespondBlock(blocks_reorg[i])
):
server_1.push_message(msg)
)
await time_out_assert(200, wallet_height_at_least, True, wallet_node, 33)
@ -79,10 +79,9 @@ class TestWalletSync:
full_node_1, wallet_node, server_1, server_2 = wallet_node_starting_height
for i in range(1, len(blocks)):
async for _ in full_node_1.respond_block(
await full_node_1.full_node._respond_block(
full_node_protocol.RespondBlock(blocks[i])
):
pass
)
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
@ -97,10 +96,9 @@ class TestWalletSync:
full_node_1, wallet_node, server_1, server_2 = wallet_node
for i in range(1, len(blocks)):
async for _ in full_node_1.respond_block(
await full_node_1.full_node._respond_block(
full_node_protocol.RespondBlock(blocks[i])
):
pass
)
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
await time_out_assert(60, wallet_height_at_least, True, wallet_node, 3)
@ -120,16 +118,14 @@ class TestWalletSync:
test_constants, 3, [], 10, b"", coinbase_puzzlehash
)
for block in blocks:
[
_
async for _ in full_node_1.respond_block(
full_node_protocol.RespondBlock(block)
)
]
await full_node_1.full_node._respond_block(
full_node_protocol.RespondBlock(block)
)
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
await time_out_assert(60, wallet_height_at_least, True, wallet_node, 1)
server_2.global_connections.close_all_connections()
await server_2.close_all_connections()
dic_h = {}
prev_coin = blocks[1].get_coinbase()
@ -150,16 +146,15 @@ class TestWalletSync:
)
# Move chain to height 16, with consecutive transactions in blocks 4 to 14
for block in blocks:
async for _ in full_node_1.respond_block(
await full_node_1.full_node._respond_block(
full_node_protocol.RespondBlock(block)
):
pass
)
# Do a short sync from 0 to 14
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
await time_out_assert(60, wallet_height_at_least, True, wallet_node, 14)
server_2.global_connections.close_all_connections()
await server_2.close_all_connections()
# 3 block rewards and 3 fees - 1000 coins spent
assert (
@ -205,16 +200,15 @@ class TestWalletSync:
# Move chain to height 34, with consecutive transactions in blocks 4 to 14
for block in blocks:
async for _ in full_node_1.respond_block(
await full_node_1.full_node._respond_block(
full_node_protocol.RespondBlock(block)
):
pass
)
# Do a sync from 0 to 22
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
await time_out_assert(60, wallet_height_at_least, True, wallet_node, 28)
server_2.global_connections.close_all_connections()
await server_2.close_all_connections()
# 3 block rewards and 3 fees - 1000 coins spent
assert (
@ -272,10 +266,9 @@ class TestWalletSync:
dic_h,
)
for block in blocks:
async for _ in full_node_1.respond_block(
await full_node_1.full_node._respond_block(
full_node_protocol.RespondBlock(block)
):
pass
)
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
await time_out_assert(60, wallet_height_at_least, True, wallet_node, 38)