diff --git a/tests/test_async_.py b/tests/test_async_.py index 151e390a..3bebcfaa 100644 --- a/tests/test_async_.py +++ b/tests/test_async_.py @@ -3,29 +3,13 @@ import asyncio import functools import time -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, patch import pytest from zha import async_ as zha_async from zha.application.gateway import Gateway from zha.async_ import AsyncUtilMixin, ZHAJob, ZHAJobType, create_eager_task -from zha.decorators import callback - - -async def test_zhajob_forbid_coroutine() -> None: - """Test zhajob forbids coroutines.""" - - async def bla(): - pass - - coro = bla() - - with pytest.raises(ValueError): - _ = ZHAJob(coro).job_type - - # To avoid warning about unawaited coro - await coro @pytest.mark.parametrize("eager_start", [True, False]) @@ -33,15 +17,14 @@ async def test_cancellable_zhajob(zha_gateway: Gateway, eager_start: bool) -> No """Simulate a shutdown, ensure cancellable jobs are cancelled.""" job = MagicMock() - @callback def run_job(job: ZHAJob) -> None: """Call the action.""" zha_gateway.async_run_zha_job(job, eager_start=eager_start) timer1 = zha_gateway.loop.call_later( - 60, run_job, ZHAJob(callback(job), cancel_on_shutdown=True) + 60, run_job, ZHAJob(job, cancel_on_shutdown=True) ) - timer2 = zha_gateway.loop.call_later(60, run_job, ZHAJob(callback(job))) + timer2 = zha_gateway.loop.call_later(60, run_job, ZHAJob(job)) await zha_gateway.shutdown() @@ -57,7 +40,7 @@ async def test_async_add_zha_job_schedule_callback() -> None: zha_gateway = MagicMock(loop=MagicMock(wraps=asyncio.get_running_loop())) job = MagicMock() - AsyncUtilMixin.async_add_zha_job(zha_gateway, ZHAJob(callback(job))) + AsyncUtilMixin.async_add_zha_job(zha_gateway, ZHAJob(job)) assert len(zha_gateway.loop.call_soon.mock_calls) == 1 assert len(zha_gateway.loop.create_task.mock_calls) == 0 assert len(zha_gateway.add_job.mock_calls) == 0 @@ -71,9 +54,7 @@ async def test_async_add_zha_job_eager_start_coro_suspends( async def job_that_suspends(): await asyncio.sleep(0) - task = zha_gateway.async_add_zha_job( - ZHAJob(callback(job_that_suspends)), eager_start=True - ) + task = zha_gateway.async_add_zha_job(ZHAJob(job_that_suspends), eager_start=True) assert not task.done() assert task in zha_gateway._tracked_completable_tasks await task @@ -88,7 +69,7 @@ async def test_async_run_zha_job_eager_start_coro_suspends( async def job_that_suspends(): await asyncio.sleep(0) - task = zha_gateway.async_run_zha_job(ZHAJob(callback(job_that_suspends))) + task = zha_gateway.async_run_zha_job(ZHAJob(job_that_suspends)) assert not task.done() assert task in zha_gateway._tracked_completable_tasks await task @@ -101,9 +82,7 @@ async def test_async_add_zha_job_background(zha_gateway: Gateway) -> None: async def job_that_suspends(): await asyncio.sleep(0) - task = zha_gateway.async_add_zha_job( - ZHAJob(callback(job_that_suspends)), background=True - ) + task = zha_gateway.async_add_zha_job(ZHAJob(job_that_suspends), background=True) assert not task.done() assert task in zha_gateway._background_tasks await task @@ -116,9 +95,7 @@ async def test_async_run_zha_job_background(zha_gateway: Gateway) -> None: async def job_that_suspends(): await asyncio.sleep(0) - task = zha_gateway.async_run_zha_job( - ZHAJob(callback(job_that_suspends)), background=True - ) + task = zha_gateway.async_run_zha_job(ZHAJob(job_that_suspends), background=True) assert not task.done() assert task in zha_gateway._background_tasks await task @@ -131,9 +108,7 @@ async def test_async_add_zha_job_eager_background(zha_gateway: Gateway) -> None: async def job_that_suspends(): await asyncio.sleep(0) - task = zha_gateway.async_add_zha_job( - ZHAJob(callback(job_that_suspends)), background=True - ) + task = zha_gateway.async_add_zha_job(ZHAJob(job_that_suspends), background=True) assert not task.done() assert task in zha_gateway._background_tasks await task @@ -146,17 +121,17 @@ async def test_async_run_zha_job_eager_background(zha_gateway: Gateway) -> None: async def job_that_suspends(): await asyncio.sleep(0) - task = zha_gateway.async_run_zha_job( - ZHAJob(callback(job_that_suspends)), background=True - ) + task = zha_gateway.async_run_zha_job(ZHAJob(job_that_suspends), background=True) assert not task.done() assert task in zha_gateway._background_tasks await task assert task not in zha_gateway._background_tasks -async def test_async_run_zha_job_background_synchronous( +@pytest.mark.parametrize("background", [True, False]) +async def test_async_run_zha_job_background_no_suspend( zha_gateway: Gateway, + background: bool, ) -> None: """Test scheduling a coro as an eager background task with async_run_zha_job.""" @@ -164,25 +139,10 @@ async def job_that_does_not_suspends(): pass task = zha_gateway.async_run_zha_job( - ZHAJob(callback(job_that_does_not_suspends)), - background=True, - ) - assert task.done() - assert task not in zha_gateway._background_tasks - assert task not in zha_gateway._tracked_completable_tasks - await task - - -async def test_async_run_zha_job_synchronous(zha_gateway: Gateway) -> None: - """Test scheduling a coro as an eager task with async_run_zha_job.""" - - async def job_that_does_not_suspends(): - pass - - task = zha_gateway.async_run_zha_job( - ZHAJob(callback(job_that_does_not_suspends)), - background=False, + ZHAJob(job_that_does_not_suspends), + background=background, ) + assert task is not None assert task.done() assert task not in zha_gateway._background_tasks assert task not in zha_gateway._tracked_completable_tasks @@ -219,7 +179,7 @@ async def test_async_add_zha_job_schedule_partial_callback() -> None: """Test that we schedule partial coros and add jobs to the job pool.""" zha_gateway = MagicMock(loop=MagicMock(wraps=asyncio.get_running_loop())) job = MagicMock() - partial = functools.partial(callback(job)) + partial = functools.partial(job) AsyncUtilMixin.async_add_zha_job(zha_gateway, ZHAJob(partial)) assert len(zha_gateway.loop.call_soon.mock_calls) == 1 @@ -252,6 +212,7 @@ async def job(): ) as mock_create_eager_task: zha_job = ZHAJob(job) task = AsyncUtilMixin.async_add_zha_job(zha_gateway, zha_job, eager_start=True) + assert task is not None assert len(zha_gateway.loop.call_soon.mock_calls) == 0 assert len(zha_gateway.add_job.mock_calls) == 0 assert mock_create_eager_task.mock_calls @@ -283,9 +244,9 @@ def job(): pass AsyncUtilMixin.async_add_zha_job(zha_gateway, ZHAJob(job)) - assert len(zha_gateway.loop.call_soon.mock_calls) == 0 + assert len(zha_gateway.loop.call_soon.mock_calls) == 1 assert len(zha_gateway.loop.create_task.mock_calls) == 0 - assert len(zha_gateway.loop.run_in_executor.mock_calls) == 2 + assert len(zha_gateway.loop.run_in_executor.mock_calls) == 0 async def test_async_create_task_schedule_coroutine() -> None: @@ -337,7 +298,7 @@ def job(): asyncio.get_running_loop() # ensure we are in the event loop calls.append(1) - AsyncUtilMixin.async_run_zha_job(zha_gateway, ZHAJob(callback(job))) + AsyncUtilMixin.async_run_zha_job(zha_gateway, ZHAJob(job)) assert len(calls) == 1 @@ -360,24 +321,11 @@ async def test_async_run_zha_job_calls_callback() -> None: def job(): calls.append(1) - AsyncUtilMixin.async_run_zha_job(zha_gateway, ZHAJob(callback(job))) + AsyncUtilMixin.async_run_zha_job(zha_gateway, ZHAJob(job)) assert len(calls) == 1 assert len(zha_gateway.async_add_job.mock_calls) == 0 -async def test_async_run_zha_job_delegates_non_async() -> None: - """Test that the callback annotation is respected.""" - zha_gateway = MagicMock() - calls = [] - - def job(): - calls.append(1) - - AsyncUtilMixin.async_run_zha_job(zha_gateway, ZHAJob(job)) - assert len(calls) == 0 - assert len(zha_gateway.async_add_zha_job.mock_calls) == 1 - - async def test_async_create_task_pending_tasks_coro(zha_gateway: Gateway) -> None: """Add a coro to pending tasks.""" call_count = [] @@ -395,34 +343,6 @@ async def test_coro(): assert len(zha_gateway._tracked_completable_tasks) == 0 -async def test_async_run_job_starts_tasks_eagerly(zha_gateway: Gateway) -> None: - """Test async_run_job starts tasks eagerly.""" - runs = [] - - async def _test(): - runs.append(True) - - task = zha_gateway.async_run_job(_test) - # No call to zha_gateway.async_block_till_done to ensure the task is run eagerly - assert len(runs) == 1 - assert task.done() - await task - - -async def test_async_run_job_starts_coro_eagerly(zha_gateway: Gateway) -> None: - """Test async_run_job starts coros eagerly.""" - runs = [] - - async def _test(): - runs.append(True) - - task = zha_gateway.async_run_job(_test()) - # No call to zha_gateway.async_block_till_done to ensure the task is run eagerly - assert len(runs) == 1 - assert task.done() - await task - - @pytest.mark.parametrize("eager_start", [True, False]) async def test_background_task(zha_gateway: Gateway, eager_start: bool) -> None: """Test background tasks being quit.""" @@ -447,7 +367,6 @@ async def test_task(): def test_ZHAJob_passing_job_type(): """Test passing the job type to ZHAJob when we already know it.""" - @callback def callback_func(): pass @@ -506,31 +425,6 @@ async def _async_add_executor_job(): await task -@patch("concurrent.futures.Future") -@patch("threading.get_ident") -def test_run_callback_threadsafe_from_inside_event_loop(mock_ident, _) -> None: - """Testing calling run_callback_threadsafe from inside an event loop.""" - callback_fn = MagicMock() - - loop = Mock(spec=["call_soon_threadsafe"]) - - loop._thread_ident = None - mock_ident.return_value = 5 - zha_async.run_callback_threadsafe(loop, callback_fn) - assert len(loop.call_soon_threadsafe.mock_calls) == 1 - - loop._thread_ident = 5 - mock_ident.return_value = 5 - with pytest.raises(RuntimeError): - zha_async.run_callback_threadsafe(loop, callback_fn) - assert len(loop.call_soon_threadsafe.mock_calls) == 1 - - loop._thread_ident = 1 - mock_ident.return_value = 5 - zha_async.run_callback_threadsafe(loop, callback_fn) - assert len(loop.call_soon_threadsafe.mock_calls) == 2 - - async def test_gather_with_limited_concurrency() -> None: """Test gather_with_limited_concurrency limits the number of running tasks.""" @@ -553,74 +447,6 @@ async def _increment_runs_if_in_time(): assert results == [2, 2, -1, -1] -async def test_shutdown_run_callback_threadsafe(zha_gateway: Gateway) -> None: - """Test we can shutdown run_callback_threadsafe.""" - zha_async.shutdown_run_callback_threadsafe(zha_gateway.loop) - callback_fn = MagicMock() - - with pytest.raises(RuntimeError): - zha_async.run_callback_threadsafe(zha_gateway.loop, callback_fn) - - -async def test_run_callback_threadsafe(zha_gateway: Gateway) -> None: - """Test run_callback_threadsafe runs code in the event loop.""" - it_ran = False - - def callback_fn(): - nonlocal it_ran - it_ran = True - - assert zha_async.run_callback_threadsafe(zha_gateway.loop, callback_fn) - assert it_ran is False - - # Verify that async_block_till_done will flush - # out the callback - await zha_gateway.async_block_till_done() - assert it_ran is True - - -async def test_run_callback_threadsafe_exception(zha_gateway: Gateway) -> None: - """Test run_callback_threadsafe runs code in the event loop.""" - it_ran = False - - def callback_fn(): - nonlocal it_ran - it_ran = True - raise ValueError("Test") - - future = zha_async.run_callback_threadsafe(zha_gateway.loop, callback_fn) - assert future - assert it_ran is False - - # Verify that async_block_till_done will flush - # out the callback - await zha_gateway.async_block_till_done() - assert it_ran is True - - with pytest.raises(ValueError): - future.result() - - -async def test_callback_is_always_scheduled(zha_gateway: Gateway) -> None: - """Test run_callback_threadsafe always calls call_soon_threadsafe before checking for shutdown.""" - # We have to check the shutdown state AFTER the callback is scheduled otherwise - # the function could continue on and the caller call `future.result()` after - # the point in the main thread where callbacks are no longer run. - - callback_fn = MagicMock() - zha_async.shutdown_run_callback_threadsafe(zha_gateway.loop) - - with ( - patch.object( - zha_gateway.loop, "call_soon_threadsafe" - ) as mock_call_soon_threadsafe, - pytest.raises(RuntimeError), - ): - zha_async.run_callback_threadsafe(zha_gateway.loop, callback_fn) - - mock_call_soon_threadsafe.assert_called_once() - - async def test_create_eager_task_312(zha_gateway: Gateway) -> None: # pylint: disable=unused-argument """Test create_eager_task schedules a task eagerly in the event loop. diff --git a/tests/test_debouncer.py b/tests/test_debouncer.py index c3708979..057b5e37 100644 --- a/tests/test_debouncer.py +++ b/tests/test_debouncer.py @@ -8,7 +8,6 @@ from zha.application.gateway import Gateway from zha.debounce import Debouncer -from zha.decorators import callback _LOGGER = logging.getLogger(__name__) @@ -135,7 +134,7 @@ async def test_immediate_works_with_callback_function(zha_gateway: Gateway) -> N _LOGGER, cooldown=0.01, immediate=True, - function=callback(Mock(side_effect=lambda: calls.append(None))), + function=Mock(side_effect=lambda: calls.append(None)), ) # Call when nothing happening @@ -176,7 +175,6 @@ async def test_immediate_works_with_passed_callback_function_raises( """Test immediate works with a callback function that raises.""" calls: list[None] = [] - @callback def _append_and_raise() -> None: calls.append(None) raise RuntimeError("forced_raise") diff --git a/zha/application/helpers.py b/zha/application/helpers.py index 815d5e9c..f55b794b 100644 --- a/zha/application/helpers.py +++ b/zha/application/helpers.py @@ -29,7 +29,7 @@ CONF_DEFAULT_CONSIDER_UNAVAILABLE_MAINS, ) from zha.async_ import gather_with_limited_concurrency -from zha.decorators import SetRegistry, callback, periodic +from zha.decorators import SetRegistry, periodic # from zha.zigbee.cluster_handlers.registries import BINDABLE_CLUSTERS BINDABLE_CLUSTERS = SetRegistry() @@ -394,7 +394,6 @@ def remove_update_listener(self, listener: Callable): """Remove an update listener.""" self._update_listeners.remove(listener) - @callback @periodic(_REFRESH_INTERVAL) async def update_listeners(self): """Update all listeners.""" diff --git a/zha/application/platforms/__init__.py b/zha/application/platforms/__init__.py index bccbefea..2a8a9392 100644 --- a/zha/application/platforms/__init__.py +++ b/zha/application/platforms/__init__.py @@ -17,7 +17,6 @@ from zha.application import Platform from zha.const import STATE_CHANGED from zha.debounce import Debouncer -from zha.decorators import callback from zha.event import EventBase from zha.mixins import LogMixin from zha.zigbee.cluster_handlers import ClusterHandlerInfo @@ -453,7 +452,6 @@ def group(self) -> Group: """Return the group.""" return self._group - @callback def debounced_update(self, _: Any | None = None) -> None: """Debounce updating group entity from member entity updates.""" # Delay to ensure that we get updates from all members before updating the group entity diff --git a/zha/application/platforms/fan/__init__.py b/zha/application/platforms/fan/__init__.py index a95cddf1..cb5c6157 100644 --- a/zha/application/platforms/fan/__init__.py +++ b/zha/application/platforms/fan/__init__.py @@ -40,7 +40,6 @@ ranged_value_to_percentage, ) from zha.application.registries import PLATFORM_ENTITIES -from zha.decorators import callback from zha.zigbee.cluster_handlers import ( ClusterAttributeUpdatedEvent, wrap_zigpy_exceptions, @@ -346,7 +345,6 @@ async def _async_set_fan_mode(self, fan_mode: int) -> None: self.maybe_emit_state_changed_event() - @callback def update(self, _: Any = None) -> None: """Attempt to retrieve on off state from the fan.""" self.debug("Updating fan group entity state") diff --git a/zha/application/platforms/light/__init__.py b/zha/application/platforms/light/__init__.py index c35907b2..f1d4d450 100644 --- a/zha/application/platforms/light/__init__.py +++ b/zha/application/platforms/light/__init__.py @@ -65,7 +65,7 @@ ) from zha.application.registries import PLATFORM_ENTITIES from zha.debounce import Debouncer -from zha.decorators import callback, periodic +from zha.decorators import periodic from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent from zha.zigbee.cluster_handlers.const import ( CLUSTER_HANDLER_ATTRIBUTE_UPDATED, @@ -1246,7 +1246,6 @@ async def async_turn_off(self, **kwargs: Any) -> None: if self._debounced_member_refresh: await self._debounced_member_refresh.async_call() - @callback def update(self, _: Any = None) -> None: """Query all members and determine the light group state.""" self.debug("Updating light group entity state") diff --git a/zha/application/platforms/switch.py b/zha/application/platforms/switch.py index 2908fe21..ac5c1c93 100644 --- a/zha/application/platforms/switch.py +++ b/zha/application/platforms/switch.py @@ -24,7 +24,6 @@ PlatformEntity, ) from zha.application.registries import PLATFORM_ENTITIES -from zha.decorators import callback from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent from zha.zigbee.cluster_handlers.const import ( CLUSTER_HANDLER_ATTRIBUTE_UPDATED, @@ -169,7 +168,6 @@ async def async_turn_off(self, **kwargs: Any) -> None: # pylint: disable=unused self._state = False self.maybe_emit_state_changed_event() - @callback def update(self, _: Any | None = None) -> None: """Query all members and determine the light group state.""" self.debug("Updating switch group entity state") diff --git a/zha/application/platforms/update.py b/zha/application/platforms/update.py index 4be8cc9a..f49450b3 100644 --- a/zha/application/platforms/update.py +++ b/zha/application/platforms/update.py @@ -15,7 +15,6 @@ from zha.application import Platform from zha.application.platforms import BaseEntityInfo, EntityCategory, PlatformEntity from zha.application.registries import PLATFORM_ENTITIES -from zha.decorators import callback from zha.exceptions import ZHAException from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent from zha.zigbee.cluster_handlers.const import ( @@ -232,7 +231,6 @@ def handle_cluster_handler_attribute_updated( self._attr_installed_version = f"0x{event.attribute_value:08x}" self.maybe_emit_state_changed_event() - @callback def device_ota_update_available( self, image: OtaImageWithMetadata, current_file_version: int ) -> None: diff --git a/zha/async_.py b/zha/async_.py index 2e067332..782209a2 100644 --- a/zha/async_.py +++ b/zha/async_.py @@ -5,14 +5,12 @@ import asyncio from asyncio import AbstractEventLoop, Future, Semaphore, Task, gather, get_running_loop from collections.abc import Awaitable, Callable, Collection, Coroutine, Iterable -import concurrent.futures import contextlib from dataclasses import dataclass import enum import functools from functools import cached_property import logging -import threading import time from typing import ( TYPE_CHECKING, @@ -27,8 +25,6 @@ from zigpy.types.named import EUI64 -from zha.decorators import callback - _T = TypeVar("_T") _R = TypeVar("_R") _R_co = TypeVar("_R_co", covariant=True) @@ -36,8 +32,6 @@ _Ts = TypeVarTuple("_Ts") BLOCK_LOG_TIMEOUT: Final[int] = 60 -_SHUTDOWN_RUN_CALLBACK_THREADSAFE = "_zha_shutdown_run_callback_threadsafe" - _LOGGER = logging.getLogger(__name__) @@ -75,72 +69,6 @@ async def sem_task(task: Awaitable[Any]) -> Any: ) -def run_callback_threadsafe( - loop: AbstractEventLoop, callback_fn: Callable[[*_Ts], _T], *args: *_Ts -) -> concurrent.futures.Future[_T]: - """Submit a callback object to a given event loop. - - Return a concurrent.futures.Future to access the result. - """ - ident = loop.__dict__.get("_thread_ident") - if ident is not None and ident == threading.get_ident(): - raise RuntimeError("Cannot be called from within the event loop") - - future: concurrent.futures.Future[_T] = concurrent.futures.Future() - - def run_callback() -> None: - """Run callback and store result.""" - try: - future.set_result(callback_fn(*args)) - except Exception as exc: # pylint: disable=broad-except - if future.set_running_or_notify_cancel(): - future.set_exception(exc) - else: - _LOGGER.warning("Exception on lost future: ", exc_info=True) - - loop.call_soon_threadsafe(run_callback) - - if hasattr(loop, _SHUTDOWN_RUN_CALLBACK_THREADSAFE): - # - # If the final `Gateway.async_block_till_done` in - # `Gateway.shutdown` has already been called, the callback - # will never run and, `future.result()` will block forever which - # will prevent the thread running this code from shutting down which - # will result in a deadlock when the main thread attempts to shutdown - # the executor and `.join()` the thread running this code. - # - # To prevent this deadlock we do the following on shutdown: - # - # 1. Set the _SHUTDOWN_RUN_CALLBACK_THREADSAFE attr on this function - # by calling `shutdown_run_callback_threadsafe` - # 2. Call `zha_gateway.async_block_till_done` at least once after shutdown - # to ensure all callbacks have run - # 3. Raise an exception here to ensure `future.result()` can never be - # called and hit the deadlock since once `shutdown_run_callback_threadsafe` - # we cannot promise the callback will be executed. - # - raise RuntimeError("The event loop is in the process of shutting down.") - - return future - - -def shutdown_run_callback_threadsafe(loop: AbstractEventLoop) -> None: - """Call when run_callback_threadsafe should prevent creating new futures. - - We must finish all callbacks before the executor is shutdown - or we can end up in a deadlock state where: - - `executor.result()` is waiting for its `._condition` - and the executor shutdown is trying to `.join()` the - executor thread. - - This function is considered irreversible and should only ever - be called when ZHA is going to shutdown and - python is going to exit. - """ - setattr(loop, _SHUTDOWN_RUN_CALLBACK_THREADSAFE, True) - - def cancelling(task: Future[Any]) -> bool: """Return True if task is cancelling.""" return bool((cancelling_ := getattr(task, "cancelling", None)) and cancelling_()) @@ -152,7 +80,6 @@ class ZHAJobType(enum.Enum): Coroutinefunction = 1 Callback = 2 - Executor = 3 class ZHAJob(Generic[_P, _R_co]): @@ -203,22 +130,13 @@ class ZHAJobWithArgs: def get_zhajob_callable_job_type(target: Callable[..., Any]) -> ZHAJobType: """Determine the job type from the callable.""" # Check for partials to properly determine if coroutine function - check_target = target - while isinstance(check_target, functools.partial): - check_target = check_target.func + while isinstance(target, functools.partial): + target = target.func - if asyncio.iscoroutinefunction(check_target): + if asyncio.iscoroutinefunction(target): return ZHAJobType.Coroutinefunction - if is_callback(check_target): + else: return ZHAJobType.Callback - if asyncio.iscoroutine(check_target): - raise ValueError("Coroutine not allowed to be passed to ZHAJob") - return ZHAJobType.Executor - - -def is_callback(func: Callable[..., Any]) -> bool: - """Check if function is safe to be called in the event loop.""" - return getattr(func, "_zha_callback", False) is True class AsyncUtilMixin: @@ -236,13 +154,6 @@ def __init__(self, *args, **kw_args) -> None: async def shutdown(self) -> None: """Shutdown the executor.""" - # Prevent run_callback_threadsafe from scheduling any additional - # callbacks in the event loop as callbacks created on the futures - # it returns will never run after the final `self.async_block_till_done` - # which will cause the futures to block forever when waiting for - # the `result()` which will cause a deadlock when shutting down the executor. - shutdown_run_callback_threadsafe(self.loop) - async def _cancel_tasks(tasks_to_cancel: Iterable) -> None: tasks = [t for t in tasks_to_cancel if not (t.done() or t.cancelled())] for task in tasks: @@ -320,14 +231,13 @@ def _cancel_cancellable_timers(self) -> None: ): handle.cancel() - @callback def async_add_zha_job( self, zhajob: ZHAJob[..., Coroutine[Any, Any, _R] | _R], *args: Any, eager_start: bool = False, background: bool = False, - ) -> asyncio.Future[_R] | None: + ) -> asyncio.Future[_R] | asyncio.Task[_R] | None: """Add a ZHAJob from within the event loop. If eager_start is True, coroutine functions will be scheduled eagerly. @@ -342,32 +252,27 @@ def async_add_zha_job( # if TYPE_CHECKING to avoid the overhead of constructing # the type used for the cast. For history see: # https://github.com/home-assistant/core/pull/71960 - if zhajob.job_type is ZHAJobType.Coroutinefunction: - if TYPE_CHECKING: - zhajob.target = cast( - Callable[..., Coroutine[Any, Any, _R]], zhajob.target - ) - # Use loop.create_task - # to avoid the extra function call in asyncio.create_task. - if eager_start: - task = create_eager_task( - zhajob.target(*args), - name=zhajob.name, - loop=self.loop, - ) - if task.done(): - return task - else: - task = self.loop.create_task(zhajob.target(*args), name=zhajob.name) - elif zhajob.job_type is ZHAJobType.Callback: + if zhajob.job_type is ZHAJobType.Callback: if TYPE_CHECKING: zhajob.target = cast(Callable[..., _R], zhajob.target) self.loop.call_soon(zhajob.target, *args) return None + + if TYPE_CHECKING: + zhajob.target = cast(Callable[..., Coroutine[Any, Any, _R]], zhajob.target) + + # Use loop.create_task + # to avoid the extra function call in asyncio.create_task. + if eager_start: + task = create_eager_task( + zhajob.target(*args), + name=zhajob.name, + loop=self.loop, + ) + if task.done(): + return task else: - if TYPE_CHECKING: - zhajob.target = cast(Callable[..., _R], zhajob.target) - task = self.loop.run_in_executor(None, zhajob.target, *args) + task = self.loop.create_task(zhajob.target(*args), name=zhajob.name) task_bucket = ( self._background_tasks if background else self._tracked_completable_tasks @@ -388,7 +293,6 @@ def create_task( functools.partial(self.async_create_task, target, name, eager_start=True) ) - @callback def async_create_task( self, target: Coroutine[Any, Any, _R], @@ -414,7 +318,6 @@ def async_create_task( task.add_done_callback(self._tracked_completable_tasks.remove) return task - @callback def async_create_background_task( self, target: Coroutine[Any, Any, _R], @@ -454,7 +357,6 @@ def async_create_background_task( task.add_done_callback(task_bucket.remove) return task - @callback def async_add_executor_job( self, target: Callable[..., _T], *args: Any ) -> asyncio.Future[_T]: @@ -468,13 +370,13 @@ def async_add_executor_job( task.add_done_callback(task_bucket.remove) return task - @callback def async_run_zha_job( self, zhajob: ZHAJob[..., Coroutine[Any, Any, _R] | _R], *args: Any, background: bool = False, - ) -> asyncio.Future[_R] | None: + eager_start: bool = True, + ) -> asyncio.Future[_R] | asyncio.Task[_R] | None: """Run a ZHAJob from within the event loop. This method must be run in the event loop. @@ -495,29 +397,8 @@ def async_run_zha_job( return None return self.async_add_zha_job( - zhajob, *args, eager_start=True, background=background + zhajob, + *args, + eager_start=eager_start, + background=background, ) - - @callback - def async_run_job( - self, - target: Callable[..., Coroutine[Any, Any, _R] | _R] | Coroutine[Any, Any, _R], - *args: Any, - ) -> asyncio.Future[_R] | None: - """Run a job from within the event loop. - - This method must be run in the event loop. - - target: target to call. - args: parameters for method to call. - """ - if asyncio.iscoroutine(target): - return self.async_create_task(target, eager_start=True) - - # This code path is performance sensitive and uses - # if TYPE_CHECKING to avoid the overhead of constructing - # the type used for the cast. For history see: - # https://github.com/home-assistant/core/pull/71960 - if TYPE_CHECKING: - target = cast(Callable[..., Coroutine[Any, Any, _R] | _R], target) - return self.async_run_zha_job(ZHAJob(target), *args) diff --git a/zha/decorators.py b/zha/decorators.py index ca61dc53..bc521439 100644 --- a/zha/decorators.py +++ b/zha/decorators.py @@ -100,9 +100,3 @@ async def wrapper(*args: Any, **kwargs: Any) -> None: return wrapper return scheduler - - -def callback(func: Callable[..., Any]) -> Callable[..., Any]: - """Annotation to mark method as safe to call from within the event loop.""" - setattr(func, "_zha_callback", True) - return func