Skip to content

Commit

Permalink
refactor: use quattro.gather in runner
Browse files Browse the repository at this point in the history
  • Loading branch information
fubuloubu committed May 31, 2024
1 parent 85996e9 commit 0274d16
Showing 1 changed file with 58 additions and 66 deletions.
124 changes: 58 additions & 66 deletions silverback/runner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import atexit
from abc import ABC, abstractmethod

import quattro
from ape import chain
from ape.logging import logger
from ape.utils import ManagerAccessMixin
Expand Down Expand Up @@ -109,6 +111,10 @@ async def _event_task(self, task_data: TaskData):
handle an event handler task for the given contract event
"""

def _shutdown(self):
asyncio.run(self.app.broker.shutdown(), debug=True)
logger.info("Application shutdown completed")

async def run(self):
"""
Run the task broker client for the assembled ``SilverbackApp`` application.
Expand All @@ -124,6 +130,8 @@ async def run(self):
"""
# Initialize broker (run worker startup events)
await self.app.broker.startup()
# NOTE: Always ensure we shutdown the broker no matter what
atexit.register(self._shutdown)

# Obtain system configuration for worker
result = await run_taskiq_task_wait_result(
Expand All @@ -133,18 +141,18 @@ async def run(self):
raise StartupFailure("Unable to determine system configuration of worker")

# NOTE: Increase the specifier set here if there is a breaking change to this
if Version(result.return_value.sdk_version) not in SpecifierSet(">=0.5.0"):
# TODO: set to next breaking change release before release
if (sdk_version := Version(result.return_value.sdk_version)) not in SpecifierSet(">=0.5.0"):
raise StartupFailure("Worker SDK version too old, please rebuild")

if not (
system_tasks := set(TaskType(task_name) for task_name in result.return_value.task_types)
):
# NOTE: Guaranteed to be at least one because of `TaskType.SYSTEM_CONFIG`
raise StartupFailure("No system tasks detected, startup failure")
# NOTE: Guaranteed to be at least one because of `TaskType.SYSTEM_CONFIG`

system_tasks_str = "\n- ".join(system_tasks)
logger.info(
f"Worker using Silverback SDK v{result.return_value.sdk_version}"
f"Worker using Silverback SDK v{sdk_version}"
f", available task types:\n- {system_tasks_str}"
)

Expand All @@ -163,20 +171,18 @@ async def run(self):
self.state = AppState(last_block_seen=-1, last_block_processed=-1)

# Execute Silverback startup task before we init the rest
startup_taskdata_result = await run_taskiq_task_wait_result(
self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.STARTUP
)

if startup_taskdata_result.is_err:
raise StartupFailure(startup_taskdata_result.error)

else:
startup_task_handlers = map(
self._create_task_kicker, startup_taskdata_result.return_value
if (
startup_taskdata_result := await run_taskiq_task_wait_result(
self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.STARTUP
)
).is_err:
raise StartupFailure(startup_taskdata_result.error)

elif startup_task_handlers := tuple(
map(self._create_task_kicker, startup_taskdata_result.return_value)
):
startup_task_results = await run_taskiq_task_group_wait_results(
(task_handler for task_handler in startup_task_handlers), self.state
startup_task_handlers, self.state
)

if any(result.is_err for result in startup_task_results):
Expand All @@ -187,21 +193,26 @@ async def run(self):

elif self.recorder:
converted_results = map(TaskResult.from_taskiq, startup_task_results)
await asyncio.gather(*(self.recorder.add_result(r) for r in converted_results))
await quattro.gather(*(self.recorder.add_result(r) for r in converted_results))

# NOTE: No need to handle results otherwise
# else: No need to handle results otherwise

else:
logger.info("No startup tasks detected")

# Create our long-running event listeners
new_block_taskdata_results = await run_taskiq_task_wait_result(
self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.NEW_BLOCK
)
if new_block_taskdata_results.is_err:
if (
new_block_taskdata_results := await run_taskiq_task_wait_result(
self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.NEW_BLOCK
)
).is_err:
raise StartupFailure(new_block_taskdata_results.error)

event_log_taskdata_results = await run_taskiq_task_wait_result(
self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.EVENT_LOG
)
if event_log_taskdata_results.is_err:
if (
event_log_taskdata_results := await run_taskiq_task_wait_result(
self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.EVENT_LOG
)
).is_err:
raise StartupFailure(event_log_taskdata_results.error)

if (
Expand All @@ -212,50 +223,28 @@ async def run(self):
raise NoTasksAvailableError()

# NOTE: Any propagated failure in here should be handled such that shutdown tasks also run
# TODO: `asyncio.TaskGroup` added in Python 3.11
listener_tasks = (
*(
asyncio.create_task(self._block_task(task_def))
for task_def in new_block_taskdata_results.return_value
),
*(
asyncio.create_task(self._event_task(task_def))
for task_def in event_log_taskdata_results.return_value
),
)

# NOTE: Safe to do this because no tasks were actually scheduled to run
if len(listener_tasks) == 0:
raise NoTasksAvailableError()

# Run until one task bubbles up an exception that should stop execution
tasks_with_errors, tasks_running = await asyncio.wait(
listener_tasks, return_when=asyncio.FIRST_EXCEPTION
exceptions_or_none = await quattro.gather(
*(self._block_task(task_def) for task_def in new_block_taskdata_results.return_value),
*(self._event_task(task_def) for task_def in event_log_taskdata_results.return_value),
return_exceptions=True,
)
if runtime_errors := "\n".join(str(task.exception()) for task in tasks_with_errors):
# NOTE: In case we are somehow not displaying the error correctly with task status
logger.debug(f"Runtime error(s) detected, shutting down:\n{runtime_errors}")

# Cancel any still running
(task.cancel() for task in tasks_running)
# NOTE: All listener tasks are shut down now
# NOTE: Result is either None or Exception
if err_msg := "\n\n".join(str(e) for e in exceptions_or_none if e):
logger.error(f"Runtime error(s) detected, shutting down:\n{err_msg}")

# Execute Silverback shutdown task(s) before shutting down the broker and app
shutdown_taskdata_result = await run_taskiq_task_wait_result(
self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.SHUTDOWN
)

if shutdown_taskdata_result.is_err:
raise StartupFailure(shutdown_taskdata_result.error)

else:
shutdown_task_handlers = map(
self._create_task_kicker, shutdown_taskdata_result.return_value
if (
shutdown_taskdata_result := await run_taskiq_task_wait_result(
self._create_system_task_kicker(TaskType.SYSTEM_USER_TASKDATA), TaskType.SHUTDOWN
)
).is_err:
raise RuntimeError(shutdown_taskdata_result.error)

shutdown_task_results = await run_taskiq_task_group_wait_results(
(task_handler for task_handler in shutdown_task_handlers)
)
elif shutdown_task_handlers := tuple(
map(self._create_task_kicker, shutdown_taskdata_result.return_value)
):
shutdown_task_results = await run_taskiq_task_group_wait_results(shutdown_task_handlers)

if any(result.is_err for result in shutdown_task_results):
errors_str = "\n".join(
Expand All @@ -265,11 +254,14 @@ async def run(self):

elif self.recorder:
converted_results = map(TaskResult.from_taskiq, shutdown_task_results)
await asyncio.gather(*(self.recorder.add_result(r) for r in converted_results))
await quattro.gather(*(self.recorder.add_result(r) for r in converted_results))

# else: No need to handle results otherwise

# NOTE: No need to handle results otherwise
else:
logger.info("No shutdown tasks detected")

await self.app.broker.shutdown()
# NOTE: atexit handles self.app.broker.shutdown()


class WebsocketRunner(BaseRunner, ManagerAccessMixin):
Expand Down

0 comments on commit 0274d16

Please sign in to comment.