diff --git a/pyzeebe/grpc_internals/zeebe_adapter_base.py b/pyzeebe/grpc_internals/zeebe_adapter_base.py index a343bc41..7d2da686 100644 --- a/pyzeebe/grpc_internals/zeebe_adapter_base.py +++ b/pyzeebe/grpc_internals/zeebe_adapter_base.py @@ -1,7 +1,11 @@ +from __future__ import annotations + import logging +from collections.abc import Callable from typing import NoReturn import grpc +from typing_extensions import TypeAlias from zeebe_grpc.gateway_pb2_grpc import GatewayStub from pyzeebe.errors import ( @@ -14,6 +18,8 @@ from pyzeebe.errors.pyzeebe_errors import PyZeebeError from pyzeebe.grpc_internals.grpc_utils import is_error_status +Callback: TypeAlias = Callable[[], None] + logger = logging.getLogger(__name__) @@ -25,11 +31,15 @@ def __init__(self, grpc_channel: grpc.aio.Channel, max_connection_retries: int = self.retrying_connection = False self._max_connection_retries = max_connection_retries self._current_connection_retries = 0 + self._on_disconnect_callbacks: list[Callback] = [] @property def connected(self) -> bool: return self._connected + def add_disconnect_callback(self, callback: Callback) -> None: + self._on_disconnect_callbacks.append(callback) + def _should_retry(self) -> bool: return self._max_connection_retries == -1 or self._current_connection_retries < self._max_connection_retries @@ -50,6 +60,8 @@ async def _close(self) -> None: logger.exception("Failed to close channel, %s exception was raised", type(exception).__name__) finally: self._connected = False + for callback in self._on_disconnect_callbacks: + callback() def _create_pyzeebe_error_from_grpc_error(grpc_error: grpc.aio.AioRpcError) -> PyZeebeError: diff --git a/pyzeebe/worker/worker.py b/pyzeebe/worker/worker.py index aabaca1e..957d9653 100644 --- a/pyzeebe/worker/worker.py +++ b/pyzeebe/worker/worker.py @@ -48,14 +48,15 @@ def __init__( tenant_ids (list[str]): A list of tenant IDs for which to activate jobs. New in Zeebe 8.3. """ super().__init__(before, after, exception_handler) + self._stop_event = anyio.Event() self.zeebe_adapter = ZeebeAdapter(grpc_channel, max_connection_retries) + self.zeebe_adapter.add_disconnect_callback(self._stop_event.set) self.name = name or socket.gethostname() self.request_timeout = request_timeout self.poll_retry_delay = poll_retry_delay self.tenant_ids = tenant_ids self._job_pollers: list[JobPoller] = [] self._job_executors: list[JobExecutor] = [] - self._stop_event = anyio.Event() def _init_tasks(self) -> None: self._job_executors, self._job_pollers = [], [] @@ -110,11 +111,12 @@ async def stop(self) -> None: """ Stop the worker. This will emit a signal asking tasks to complete the current task and stop polling for new. """ - for poller in self._job_pollers: - await poller.stop() + async with anyio.create_task_group() as tg: + for poller in self._job_pollers: + tg.start_soon(poller.stop) - for executor in self._job_executors: - await executor.stop() + for executor in self._job_executors: + tg.start_soon(executor.stop) self._stop_event.set() diff --git a/tests/unit/grpc_internals/zeebe_adapter_base_test.py b/tests/unit/grpc_internals/zeebe_adapter_base_test.py index 11bf59f2..4d6a137d 100644 --- a/tests/unit/grpc_internals/zeebe_adapter_base_test.py +++ b/tests/unit/grpc_internals/zeebe_adapter_base_test.py @@ -1,4 +1,4 @@ -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, Mock import grpc import pytest @@ -67,6 +67,9 @@ async def test_raises_unkown_grpc_status_code_on_unkown_status_code( await zeebe_adapter._handle_grpc_error(error) async def test_closes_after_retries_exceeded(self, zeebe_adapter: ZeebeAdapterBase): + on_disconnect_callback = Mock() + zeebe_adapter.add_disconnect_callback(on_disconnect_callback) + error = grpc.aio.AioRpcError(grpc.StatusCode.UNAVAILABLE, None, None) zeebe_adapter._channel.close = AsyncMock() @@ -76,8 +79,12 @@ async def test_closes_after_retries_exceeded(self, zeebe_adapter: ZeebeAdapterBa assert zeebe_adapter.connected is False zeebe_adapter._channel.close.assert_awaited_once() + on_disconnect_callback.assert_called_once() async def test_closes_after_internal_error(self, zeebe_adapter: ZeebeAdapterBase): + on_disconnect_callback = Mock() + zeebe_adapter.add_disconnect_callback(on_disconnect_callback) + error = grpc.aio.AioRpcError(grpc.StatusCode.INTERNAL, None, None) zeebe_adapter._channel.close = AsyncMock() @@ -87,3 +94,4 @@ async def test_closes_after_internal_error(self, zeebe_adapter: ZeebeAdapterBase assert zeebe_adapter.connected is False zeebe_adapter._channel.close.assert_awaited_once() + on_disconnect_callback.assert_called_once() diff --git a/tests/unit/worker/worker_test.py b/tests/unit/worker/worker_test.py index 9f234604..013d4aac 100644 --- a/tests/unit/worker/worker_test.py +++ b/tests/unit/worker/worker_test.py @@ -296,3 +296,18 @@ async def poll2(): poller_mock.poll.assert_awaited_once() poller2_mock.poll.assert_awaited_once() assert poller2_cancel_event.is_set() + + async def test_stop_after_retries_exceeded(self, zeebe_worker: ZeebeWorker): + @zeebe_worker.task(str(uuid4())) + def dummy_function(): + pass + + zeebe_worker.zeebe_adapter._gateway_stub.ActivateJobs.side_effect = [ + grpc.aio.AioRpcError(grpc.StatusCode.INTERNAL, None, None) + ] + zeebe_worker.zeebe_adapter._max_connection_retries = 1 + + await zeebe_worker.work() + + zeebe_worker.zeebe_adapter._gateway_stub.ActivateJobs.assert_called_once() + assert zeebe_worker._stop_event.is_set() is True