From 960d20d3b8f0384bb638f5dc2db270d7dd62e388 Mon Sep 17 00:00:00 2001 From: David Mulcahey Date: Fri, 30 Aug 2024 19:07:15 -0400 Subject: [PATCH] rework --- tests/test_sensor.py | 6 +++--- zha/application/helpers.py | 32 +++++++++++++++++++++----------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/tests/test_sensor.py b/tests/test_sensor.py index da07d94b..2695fa6c 100644 --- a/tests/test_sensor.py +++ b/tests/test_sensor.py @@ -1283,13 +1283,13 @@ async def test_device_unavailable_or_disabled_skips_entity_polling( assert entity.state["state"] == 60 assert entity.enabled is True - assert len(zha_gateway.global_updater._update_listeners) == 6 + assert len(zha_gateway.global_updater._update_listeners) == 5 # let's drop the normal update method from the updater entity.disable() assert entity.enabled is False - assert len(zha_gateway.global_updater._update_listeners) == 5 + assert len(zha_gateway.global_updater._update_listeners) == 4 # wrap the update method so we can count how many times it was called entity.update = MagicMock(wraps=entity.update) @@ -1300,7 +1300,7 @@ async def test_device_unavailable_or_disabled_skips_entity_polling( # re-enable the entity and ensure it is back in the updater and that update is called entity.enable() - assert len(zha_gateway.global_updater._update_listeners) == 6 + assert len(zha_gateway.global_updater._update_listeners) == 5 assert entity.enabled is True await asyncio.sleep(zha_gateway.global_updater.__polling_interval + 2) diff --git a/zha/application/helpers.py b/zha/application/helpers.py index 8240cb07..b1520f30 100644 --- a/zha/application/helpers.py +++ b/zha/application/helpers.py @@ -371,7 +371,8 @@ class GlobalUpdater: def __init__(self, gateway: Gateway): """Initialize the GlobalUpdater.""" self._updater_task_handle: asyncio.Task = None - self._update_listeners: list[Callable | Awaitable] = [] + self._update_listeners: list[Callable] = [] + self._update_awaitables: list[Callable[[], Awaitable]] = [] self._gateway: Gateway = gateway self._polling_active: bool = False @@ -401,25 +402,34 @@ def stop(self): self._updater_task_handle = None _LOGGER.debug("global updater stopped") - def register_update_listener(self, listener: Callable | Awaitable): + def register_update_listener(self, listener: Callable | Callable[[], Awaitable]): """Register an update listener.""" - if listener in self._update_listeners: + if listener in self._update_listeners or listener in self._update_awaitables: _LOGGER.debug( "listener already registered with global updater - nothing to register: %s", listener, ) return - self._update_listeners.append(listener) + if inspect.iscoroutinefunction(listener): + self._update_awaitables.append(listener) + else: + self._update_listeners.append(listener) - def remove_update_listener(self, listener: Callable | Awaitable): + def remove_update_listener(self, listener: Callable | Callable[[], Awaitable]): """Remove an update listener.""" - if listener not in self._update_listeners: + if ( + listener not in self._update_listeners + and listener not in self._update_awaitables + ): _LOGGER.debug( "listener not registered with global updater - nothing to remove: %s", listener, ) return - self._update_listeners.remove(listener) + if inspect.iscoroutinefunction(listener): + self._update_awaitables.remove(listener) + else: + self._update_listeners.remove(listener) @periodic(_REFRESH_INTERVAL) async def update_listeners(self): @@ -430,10 +440,10 @@ async def update_listeners(self): _LOGGER.debug("Global updater running update callbacks") for listener in self._update_listeners: _LOGGER.debug("Global updater running update callback") - if inspect.iscoroutinefunction(listener): - await listener() - else: - listener() # type: ignore + listener() + await gather_with_limited_concurrency( + 3, *[awaitable() for awaitable in self._update_awaitables] + ) self._polling_active = False else: _LOGGER.debug("Global updater interval skipped")