Greatly simplify `Service`.

This commit is contained in:
Richard Kiss 2020-10-09 12:37:57 -07:00
parent 6dcdcdba3d
commit 7bbf64fa18
4 changed files with 77 additions and 186 deletions

View File

@ -76,6 +76,15 @@ class Farmer:
error_str = "No keys exist. Please run 'chia keys generate' or open the UI."
raise RuntimeError(error_str)
async def _start(self):
pass
def _close(self):
pass
async def _await_closed(self):
pass
async def _on_connect(self):
# Sends a handshake to the harvester
msg = harvester_protocol.HarvesterHandshake(

View File

@ -4,7 +4,7 @@ import logging.config
import signal
from sys import platform
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, List, Optional, Tuple
try:
import uvloop
@ -25,14 +25,6 @@ from .reconnect_task import start_reconnect_task
from .ssl_context import load_ssl_paths
stopped_by_signal = False
def global_signal_handler(*args):
global stopped_by_signal
stopped_by_signal = True
class Service:
def __init__(
self,
@ -73,15 +65,6 @@ class Service:
ssl_cert_path, ssl_key_path = load_ssl_paths(root_path, config)
async def start_callback():
await api._start()
def stop_callback():
api._close()
async def await_closed_callback():
await api._await_closed()
self._server = ChiaServer(
advertised_port,
api,
@ -103,75 +86,73 @@ class Service:
self._server_listen_ports = server_listen_ports
self._api = api
self._task = None
self._is_stopping = False
self._did_start = False
self._is_stopping = asyncio.Event()
self._stopped_by_rpc = False
self._on_connect_callback = on_connect_callback
self._start_callback = start_callback
self._stop_callback = stop_callback
self._await_closed_callback = await_closed_callback
self._advertised_port = advertised_port
self._server_sockets: List = []
def start(self):
if self._task is not None:
async def start(self, **kwargs):
# we include `kwargs` as a hack for the wallet, which for some
# reason allows parameters to `_start`. This is serious BRAIN DAMAGE,
# and should be fixed at some point.
# TODO: move those parameters to `__init__`
if self._did_start:
return
self._did_start = True
async def _run():
for port in self._upnp_ports:
upnp_remap_port(port)
self._enable_signals()
if self._start_callback:
await self._start_callback()
await self._api._start(**kwargs)
self._rpc_task = None
self._rpc_close_task = None
if self._rpc_info:
rpc_api, rpc_port = self._rpc_info
for port in self._upnp_ports:
upnp_remap_port(port)
self._rpc_task = asyncio.create_task(
start_rpc_server(
rpc_api(self._api),
self.self_hostname,
self.daemon_port,
rpc_port,
self.stop,
)
self._server_sockets = [
await start_server(self._server, self._on_connect_callback)
for _ in self._server_listen_ports
]
self._reconnect_tasks = [
start_reconnect_task(self._server, _, self._log, self._auth_connect_peers)
for _ in self._connect_peers
]
self._rpc_task = None
self._rpc_close_task = None
if self._rpc_info:
rpc_api, rpc_port = self._rpc_info
self._rpc_task = asyncio.create_task(
start_rpc_server(
rpc_api(self._api),
self.self_hostname,
self.daemon_port,
rpc_port,
self.stop,
)
self._reconnect_tasks = [
start_reconnect_task(
self._server, _, self._log, self._auth_connect_peers
)
for _ in self._connect_peers
]
self._server_sockets = [
await start_server(self._server, self._on_connect_callback)
for _ in self._server_listen_ports
]
signal.signal(signal.SIGINT, global_signal_handler)
signal.signal(signal.SIGTERM, global_signal_handler)
if platform == "win32" or platform == "cygwin":
# pylint: disable=E1101
signal.signal(signal.SIGBREAK, global_signal_handler) # type: ignore
self._task = asyncio.create_task(_run())
)
async def run(self):
self.start()
await self._task
while not stopped_by_signal and not self._is_stopping:
await asyncio.sleep(1)
self.stop()
await self.start()
await self.wait_closed()
return 0
def _enable_signals(self):
signal.signal(signal.SIGINT, self._accept_signal)
signal.signal(signal.SIGTERM, self._accept_signal)
if platform == "win32" or platform == "cygwin":
# pylint: disable=E1101
signal.signal(signal.SIGBREAK, self._accept_signal) # type: ignore
def _accept_signal(self, signal_number: int, stack_frame):
self._log.info(f"got signal {signal_number}")
self.stop()
def stop(self):
if not self._is_stopping:
self._is_stopping = True
if not self._is_stopping.is_set():
self._is_stopping.set()
self._log.info("Closing server sockets")
for _ in self._server_sockets:
_.close()
@ -180,11 +161,10 @@ class Service:
_.cancel()
self._log.info("Closing connections")
self._server.close_all()
self._api._close()
self._api._shut_down = True
self._log.info("Calling service stop callback")
if self._stop_callback:
self._stop_callback()
if self._rpc_task:
self._log.info("Closing RPC server")
@ -195,6 +175,8 @@ class Service:
self._rpc_close_task = asyncio.create_task(close_rpc_server())
async def wait_closed(self):
await self._is_stopping.wait()
self._log.info("Waiting for socket to be closed (if opened)")
for _ in self._server_sockets:
await _.wait_closed()
@ -207,9 +189,8 @@ class Service:
await self._rpc_close_task
self._log.info("Closed RPC server")
if self._await_closed_callback:
self._log.info("Waiting for service _await_closed callback")
await self._await_closed_callback()
self._log.info("Waiting for service _await_closed callback")
await self._api._await_closed()
self._log.info(
f"Service {self._service_name} at port {self._advertised_port} fully closed"
)

View File

@ -94,19 +94,6 @@ async def setup_full_node(
bt=bt,
)
started = asyncio.Event()
async def start_callback():
await api._start()
nonlocal started
started.set()
def stop_callback():
api._close()
async def await_closed_callback():
await api._await_closed()
service = Service(
root_path=bt.root_path,
api=api,
@ -116,19 +103,15 @@ async def setup_full_node(
server_listen_ports=[port],
auth_connect_peers=False,
on_connect_callback=api._on_connect,
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
parse_cli_args=False,
)
run_task = asyncio.create_task(service.run())
await started.wait()
await service.start()
yield api, api.server
service.stop()
await run_task
await service.wait_closed()
if db_path.exists():
db_path.unlink()
@ -175,19 +158,6 @@ async def setup_wallet_node(
if full_node_port is not None:
connect_peers = [PeerInfo(self_hostname, full_node_port)]
started = asyncio.Event()
async def start_callback():
await api._start(new_wallet=True)
nonlocal started
started.set()
def stop_callback():
api._close()
async def await_closed_callback():
await api._await_closed()
service = Service(
root_path=bt.root_path,
api=api,
@ -198,19 +168,15 @@ async def setup_wallet_node(
connect_peers=connect_peers,
auth_connect_peers=False,
on_connect_callback=api._on_connect,
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
parse_cli_args=False,
)
run_task = asyncio.create_task(service.run())
await started.wait()
await service.start(new_wallet=True)
yield api, api.server
service.stop()
await run_task
await service.wait_closed()
if db_path.exists():
db_path.unlink()
keychain.delete_all_keys()
@ -219,19 +185,6 @@ async def setup_wallet_node(
async def setup_harvester(port, farmer_port, consensus_constants: ConsensusConstants):
api = Harvester(bt.root_path, consensus_constants)
started = asyncio.Event()
async def start_callback():
await api._start()
nonlocal started
started.set()
def stop_callback():
api._close()
async def await_closed_callback():
await api._await_closed()
service = Service(
root_path=bt.root_path,
api=api,
@ -241,19 +194,15 @@ async def setup_harvester(port, farmer_port, consensus_constants: ConsensusConst
server_listen_ports=[port],
connect_peers=[PeerInfo(self_hostname, farmer_port)],
auth_connect_peers=True,
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
parse_cli_args=False,
)
run_task = asyncio.create_task(service.run())
await started.wait()
await service.start()
yield api, api.server
service.stop()
await run_task
await service.wait_closed()
async def setup_farmer(
@ -274,12 +223,6 @@ async def setup_farmer(
api = Farmer(config, config_pool, bt.keychain, consensus_constants)
started = asyncio.Event()
async def start_callback():
nonlocal started
started.set()
service = Service(
root_path=bt.root_path,
api=api,
@ -290,36 +233,21 @@ async def setup_farmer(
on_connect_callback=api._on_connect,
connect_peers=connect_peers,
auth_connect_peers=False,
start_callback=start_callback,
parse_cli_args=False,
)
run_task = asyncio.create_task(service.run())
await started.wait()
await service.start()
yield api, api.server
service.stop()
await run_task
await service.wait_closed()
async def setup_introducer(port):
config = load_config(bt.root_path, "config.yaml", "introducer")
api = Introducer(config["max_peers_to_send"], config["recent_peer_threshold"])
started = asyncio.Event()
async def start_callback():
await api._start()
nonlocal started
started.set()
def stop_callback():
api._close()
async def await_closed_callback():
await api._await_closed()
service = Service(
root_path=bt.root_path,
api=api,
@ -328,19 +256,15 @@ async def setup_introducer(port):
service_name="introducer",
server_listen_ports=[port],
auth_connect_peers=False,
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
parse_cli_args=False,
)
run_task = asyncio.create_task(service.run())
await started.wait()
await service.start()
yield api, api.server
service.stop()
await run_task
await service.wait_closed()
async def setup_vdf_clients(port):
@ -367,19 +291,6 @@ async def setup_timelord(
api = Timelord(config, consensus_constants.DISCRIMINANT_SIZE_BITS)
started = asyncio.Event()
async def start_callback():
await api._start()
nonlocal started
started.set()
def stop_callback():
api._close()
async def await_closed_callback():
await api._await_closed()
service = Service(
root_path=bt.root_path,
api=api,
@ -389,19 +300,15 @@ async def setup_timelord(
server_listen_ports=[port],
connect_peers=[PeerInfo(self_hostname, full_node_port)],
auth_connect_peers=False,
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
parse_cli_args=False,
)
run_task = asyncio.create_task(service.run())
await started.wait()
await service.start()
yield api, api.server
service.stop()
await run_task
await service.wait_closed()
async def setup_two_nodes(consensus_constants: ConsensusConstants):

View File

@ -23,12 +23,6 @@ def node_height_at_least(node, h):
return False
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
class TestSimulation:
@pytest.fixture(scope="function")
async def simulation(self):