Greatly simplify `Service`.
This commit is contained in:
parent
6dcdcdba3d
commit
7bbf64fa18
|
@ -76,6 +76,15 @@ class Farmer:
|
|||
error_str = "No keys exist. Please run 'chia keys generate' or open the UI."
|
||||
raise RuntimeError(error_str)
|
||||
|
||||
async def _start(self):
|
||||
pass
|
||||
|
||||
def _close(self):
|
||||
pass
|
||||
|
||||
async def _await_closed(self):
|
||||
pass
|
||||
|
||||
async def _on_connect(self):
|
||||
# Sends a handshake to the harvester
|
||||
msg = harvester_protocol.HarvesterHandshake(
|
||||
|
|
|
@ -4,7 +4,7 @@ import logging.config
|
|||
import signal
|
||||
|
||||
from sys import platform
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
try:
|
||||
import uvloop
|
||||
|
@ -25,14 +25,6 @@ from .reconnect_task import start_reconnect_task
|
|||
from .ssl_context import load_ssl_paths
|
||||
|
||||
|
||||
stopped_by_signal = False
|
||||
|
||||
|
||||
def global_signal_handler(*args):
|
||||
global stopped_by_signal
|
||||
stopped_by_signal = True
|
||||
|
||||
|
||||
class Service:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -73,15 +65,6 @@ class Service:
|
|||
|
||||
ssl_cert_path, ssl_key_path = load_ssl_paths(root_path, config)
|
||||
|
||||
async def start_callback():
|
||||
await api._start()
|
||||
|
||||
def stop_callback():
|
||||
api._close()
|
||||
|
||||
async def await_closed_callback():
|
||||
await api._await_closed()
|
||||
|
||||
self._server = ChiaServer(
|
||||
advertised_port,
|
||||
api,
|
||||
|
@ -103,75 +86,73 @@ class Service:
|
|||
self._server_listen_ports = server_listen_ports
|
||||
|
||||
self._api = api
|
||||
self._task = None
|
||||
self._is_stopping = False
|
||||
self._did_start = False
|
||||
self._is_stopping = asyncio.Event()
|
||||
self._stopped_by_rpc = False
|
||||
|
||||
self._on_connect_callback = on_connect_callback
|
||||
self._start_callback = start_callback
|
||||
self._stop_callback = stop_callback
|
||||
self._await_closed_callback = await_closed_callback
|
||||
self._advertised_port = advertised_port
|
||||
self._server_sockets: List = []
|
||||
|
||||
def start(self):
|
||||
if self._task is not None:
|
||||
async def start(self, **kwargs):
|
||||
# we include `kwargs` as a hack for the wallet, which for some
|
||||
# reason allows parameters to `_start`. This is serious BRAIN DAMAGE,
|
||||
# and should be fixed at some point.
|
||||
# TODO: move those parameters to `__init__`
|
||||
if self._did_start:
|
||||
return
|
||||
self._did_start = True
|
||||
|
||||
async def _run():
|
||||
for port in self._upnp_ports:
|
||||
upnp_remap_port(port)
|
||||
self._enable_signals()
|
||||
|
||||
if self._start_callback:
|
||||
await self._start_callback()
|
||||
await self._api._start(**kwargs)
|
||||
|
||||
self._rpc_task = None
|
||||
self._rpc_close_task = None
|
||||
if self._rpc_info:
|
||||
rpc_api, rpc_port = self._rpc_info
|
||||
for port in self._upnp_ports:
|
||||
upnp_remap_port(port)
|
||||
|
||||
self._rpc_task = asyncio.create_task(
|
||||
start_rpc_server(
|
||||
rpc_api(self._api),
|
||||
self.self_hostname,
|
||||
self.daemon_port,
|
||||
rpc_port,
|
||||
self.stop,
|
||||
)
|
||||
self._server_sockets = [
|
||||
await start_server(self._server, self._on_connect_callback)
|
||||
for _ in self._server_listen_ports
|
||||
]
|
||||
|
||||
self._reconnect_tasks = [
|
||||
start_reconnect_task(self._server, _, self._log, self._auth_connect_peers)
|
||||
for _ in self._connect_peers
|
||||
]
|
||||
|
||||
self._rpc_task = None
|
||||
self._rpc_close_task = None
|
||||
if self._rpc_info:
|
||||
rpc_api, rpc_port = self._rpc_info
|
||||
|
||||
self._rpc_task = asyncio.create_task(
|
||||
start_rpc_server(
|
||||
rpc_api(self._api),
|
||||
self.self_hostname,
|
||||
self.daemon_port,
|
||||
rpc_port,
|
||||
self.stop,
|
||||
)
|
||||
|
||||
self._reconnect_tasks = [
|
||||
start_reconnect_task(
|
||||
self._server, _, self._log, self._auth_connect_peers
|
||||
)
|
||||
for _ in self._connect_peers
|
||||
]
|
||||
self._server_sockets = [
|
||||
await start_server(self._server, self._on_connect_callback)
|
||||
for _ in self._server_listen_ports
|
||||
]
|
||||
|
||||
signal.signal(signal.SIGINT, global_signal_handler)
|
||||
signal.signal(signal.SIGTERM, global_signal_handler)
|
||||
if platform == "win32" or platform == "cygwin":
|
||||
# pylint: disable=E1101
|
||||
signal.signal(signal.SIGBREAK, global_signal_handler) # type: ignore
|
||||
|
||||
self._task = asyncio.create_task(_run())
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
self.start()
|
||||
await self._task
|
||||
while not stopped_by_signal and not self._is_stopping:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
self.stop()
|
||||
await self.start()
|
||||
await self.wait_closed()
|
||||
return 0
|
||||
|
||||
def _enable_signals(self):
|
||||
signal.signal(signal.SIGINT, self._accept_signal)
|
||||
signal.signal(signal.SIGTERM, self._accept_signal)
|
||||
if platform == "win32" or platform == "cygwin":
|
||||
# pylint: disable=E1101
|
||||
signal.signal(signal.SIGBREAK, self._accept_signal) # type: ignore
|
||||
|
||||
def _accept_signal(self, signal_number: int, stack_frame):
|
||||
self._log.info(f"got signal {signal_number}")
|
||||
self.stop()
|
||||
|
||||
def stop(self):
|
||||
if not self._is_stopping:
|
||||
self._is_stopping = True
|
||||
if not self._is_stopping.is_set():
|
||||
self._is_stopping.set()
|
||||
self._log.info("Closing server sockets")
|
||||
for _ in self._server_sockets:
|
||||
_.close()
|
||||
|
@ -180,11 +161,10 @@ class Service:
|
|||
_.cancel()
|
||||
self._log.info("Closing connections")
|
||||
self._server.close_all()
|
||||
self._api._close()
|
||||
self._api._shut_down = True
|
||||
|
||||
self._log.info("Calling service stop callback")
|
||||
if self._stop_callback:
|
||||
self._stop_callback()
|
||||
|
||||
if self._rpc_task:
|
||||
self._log.info("Closing RPC server")
|
||||
|
@ -195,6 +175,8 @@ class Service:
|
|||
self._rpc_close_task = asyncio.create_task(close_rpc_server())
|
||||
|
||||
async def wait_closed(self):
|
||||
await self._is_stopping.wait()
|
||||
|
||||
self._log.info("Waiting for socket to be closed (if opened)")
|
||||
for _ in self._server_sockets:
|
||||
await _.wait_closed()
|
||||
|
@ -207,9 +189,8 @@ class Service:
|
|||
await self._rpc_close_task
|
||||
self._log.info("Closed RPC server")
|
||||
|
||||
if self._await_closed_callback:
|
||||
self._log.info("Waiting for service _await_closed callback")
|
||||
await self._await_closed_callback()
|
||||
self._log.info("Waiting for service _await_closed callback")
|
||||
await self._api._await_closed()
|
||||
self._log.info(
|
||||
f"Service {self._service_name} at port {self._advertised_port} fully closed"
|
||||
)
|
||||
|
|
|
@ -94,19 +94,6 @@ async def setup_full_node(
|
|||
bt=bt,
|
||||
)
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def start_callback():
|
||||
await api._start()
|
||||
nonlocal started
|
||||
started.set()
|
||||
|
||||
def stop_callback():
|
||||
api._close()
|
||||
|
||||
async def await_closed_callback():
|
||||
await api._await_closed()
|
||||
|
||||
service = Service(
|
||||
root_path=bt.root_path,
|
||||
api=api,
|
||||
|
@ -116,19 +103,15 @@ async def setup_full_node(
|
|||
server_listen_ports=[port],
|
||||
auth_connect_peers=False,
|
||||
on_connect_callback=api._on_connect,
|
||||
start_callback=start_callback,
|
||||
stop_callback=stop_callback,
|
||||
await_closed_callback=await_closed_callback,
|
||||
parse_cli_args=False,
|
||||
)
|
||||
|
||||
run_task = asyncio.create_task(service.run())
|
||||
await started.wait()
|
||||
await service.start()
|
||||
|
||||
yield api, api.server
|
||||
|
||||
service.stop()
|
||||
await run_task
|
||||
await service.wait_closed()
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
|
||||
|
@ -175,19 +158,6 @@ async def setup_wallet_node(
|
|||
if full_node_port is not None:
|
||||
connect_peers = [PeerInfo(self_hostname, full_node_port)]
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def start_callback():
|
||||
await api._start(new_wallet=True)
|
||||
nonlocal started
|
||||
started.set()
|
||||
|
||||
def stop_callback():
|
||||
api._close()
|
||||
|
||||
async def await_closed_callback():
|
||||
await api._await_closed()
|
||||
|
||||
service = Service(
|
||||
root_path=bt.root_path,
|
||||
api=api,
|
||||
|
@ -198,19 +168,15 @@ async def setup_wallet_node(
|
|||
connect_peers=connect_peers,
|
||||
auth_connect_peers=False,
|
||||
on_connect_callback=api._on_connect,
|
||||
start_callback=start_callback,
|
||||
stop_callback=stop_callback,
|
||||
await_closed_callback=await_closed_callback,
|
||||
parse_cli_args=False,
|
||||
)
|
||||
|
||||
run_task = asyncio.create_task(service.run())
|
||||
await started.wait()
|
||||
await service.start(new_wallet=True)
|
||||
|
||||
yield api, api.server
|
||||
|
||||
service.stop()
|
||||
await run_task
|
||||
await service.wait_closed()
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
keychain.delete_all_keys()
|
||||
|
@ -219,19 +185,6 @@ async def setup_wallet_node(
|
|||
async def setup_harvester(port, farmer_port, consensus_constants: ConsensusConstants):
|
||||
api = Harvester(bt.root_path, consensus_constants)
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def start_callback():
|
||||
await api._start()
|
||||
nonlocal started
|
||||
started.set()
|
||||
|
||||
def stop_callback():
|
||||
api._close()
|
||||
|
||||
async def await_closed_callback():
|
||||
await api._await_closed()
|
||||
|
||||
service = Service(
|
||||
root_path=bt.root_path,
|
||||
api=api,
|
||||
|
@ -241,19 +194,15 @@ async def setup_harvester(port, farmer_port, consensus_constants: ConsensusConst
|
|||
server_listen_ports=[port],
|
||||
connect_peers=[PeerInfo(self_hostname, farmer_port)],
|
||||
auth_connect_peers=True,
|
||||
start_callback=start_callback,
|
||||
stop_callback=stop_callback,
|
||||
await_closed_callback=await_closed_callback,
|
||||
parse_cli_args=False,
|
||||
)
|
||||
|
||||
run_task = asyncio.create_task(service.run())
|
||||
await started.wait()
|
||||
await service.start()
|
||||
|
||||
yield api, api.server
|
||||
|
||||
service.stop()
|
||||
await run_task
|
||||
await service.wait_closed()
|
||||
|
||||
|
||||
async def setup_farmer(
|
||||
|
@ -274,12 +223,6 @@ async def setup_farmer(
|
|||
|
||||
api = Farmer(config, config_pool, bt.keychain, consensus_constants)
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def start_callback():
|
||||
nonlocal started
|
||||
started.set()
|
||||
|
||||
service = Service(
|
||||
root_path=bt.root_path,
|
||||
api=api,
|
||||
|
@ -290,36 +233,21 @@ async def setup_farmer(
|
|||
on_connect_callback=api._on_connect,
|
||||
connect_peers=connect_peers,
|
||||
auth_connect_peers=False,
|
||||
start_callback=start_callback,
|
||||
parse_cli_args=False,
|
||||
)
|
||||
|
||||
run_task = asyncio.create_task(service.run())
|
||||
await started.wait()
|
||||
await service.start()
|
||||
|
||||
yield api, api.server
|
||||
|
||||
service.stop()
|
||||
await run_task
|
||||
await service.wait_closed()
|
||||
|
||||
|
||||
async def setup_introducer(port):
|
||||
config = load_config(bt.root_path, "config.yaml", "introducer")
|
||||
api = Introducer(config["max_peers_to_send"], config["recent_peer_threshold"])
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def start_callback():
|
||||
await api._start()
|
||||
nonlocal started
|
||||
started.set()
|
||||
|
||||
def stop_callback():
|
||||
api._close()
|
||||
|
||||
async def await_closed_callback():
|
||||
await api._await_closed()
|
||||
|
||||
service = Service(
|
||||
root_path=bt.root_path,
|
||||
api=api,
|
||||
|
@ -328,19 +256,15 @@ async def setup_introducer(port):
|
|||
service_name="introducer",
|
||||
server_listen_ports=[port],
|
||||
auth_connect_peers=False,
|
||||
start_callback=start_callback,
|
||||
stop_callback=stop_callback,
|
||||
await_closed_callback=await_closed_callback,
|
||||
parse_cli_args=False,
|
||||
)
|
||||
|
||||
run_task = asyncio.create_task(service.run())
|
||||
await started.wait()
|
||||
await service.start()
|
||||
|
||||
yield api, api.server
|
||||
|
||||
service.stop()
|
||||
await run_task
|
||||
await service.wait_closed()
|
||||
|
||||
|
||||
async def setup_vdf_clients(port):
|
||||
|
@ -367,19 +291,6 @@ async def setup_timelord(
|
|||
|
||||
api = Timelord(config, consensus_constants.DISCRIMINANT_SIZE_BITS)
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def start_callback():
|
||||
await api._start()
|
||||
nonlocal started
|
||||
started.set()
|
||||
|
||||
def stop_callback():
|
||||
api._close()
|
||||
|
||||
async def await_closed_callback():
|
||||
await api._await_closed()
|
||||
|
||||
service = Service(
|
||||
root_path=bt.root_path,
|
||||
api=api,
|
||||
|
@ -389,19 +300,15 @@ async def setup_timelord(
|
|||
server_listen_ports=[port],
|
||||
connect_peers=[PeerInfo(self_hostname, full_node_port)],
|
||||
auth_connect_peers=False,
|
||||
start_callback=start_callback,
|
||||
stop_callback=stop_callback,
|
||||
await_closed_callback=await_closed_callback,
|
||||
parse_cli_args=False,
|
||||
)
|
||||
|
||||
run_task = asyncio.create_task(service.run())
|
||||
await started.wait()
|
||||
await service.start()
|
||||
|
||||
yield api, api.server
|
||||
|
||||
service.stop()
|
||||
await run_task
|
||||
await service.wait_closed()
|
||||
|
||||
|
||||
async def setup_two_nodes(consensus_constants: ConsensusConstants):
|
||||
|
|
|
@ -23,12 +23,6 @@ def node_height_at_least(node, h):
|
|||
return False
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
loop = asyncio.get_event_loop()
|
||||
yield loop
|
||||
|
||||
|
||||
class TestSimulation:
|
||||
@pytest.fixture(scope="function")
|
||||
async def simulation(self):
|
||||
|
|
Loading…
Reference in New Issue