Skip to content

Commit

Permalink
rework
Browse files Browse the repository at this point in the history
  • Loading branch information
dmulcahey committed Aug 30, 2024
1 parent 682c587 commit 960d20d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
6 changes: 3 additions & 3 deletions tests/test_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1283,13 +1283,13 @@ async def test_device_unavailable_or_disabled_skips_entity_polling(

assert entity.state["state"] == 60
assert entity.enabled is True
assert len(zha_gateway.global_updater._update_listeners) == 6
assert len(zha_gateway.global_updater._update_listeners) == 5

# let's drop the normal update method from the updater
entity.disable()

assert entity.enabled is False
assert len(zha_gateway.global_updater._update_listeners) == 5
assert len(zha_gateway.global_updater._update_listeners) == 4

# wrap the update method so we can count how many times it was called
entity.update = MagicMock(wraps=entity.update)
Expand All @@ -1300,7 +1300,7 @@ async def test_device_unavailable_or_disabled_skips_entity_polling(

# re-enable the entity and ensure it is back in the updater and that update is called
entity.enable()
assert len(zha_gateway.global_updater._update_listeners) == 6
assert len(zha_gateway.global_updater._update_listeners) == 5
assert entity.enabled is True

await asyncio.sleep(zha_gateway.global_updater.__polling_interval + 2)
Expand Down
32 changes: 21 additions & 11 deletions zha/application/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,8 @@ class GlobalUpdater:
def __init__(self, gateway: Gateway):
"""Initialize the GlobalUpdater."""
self._updater_task_handle: asyncio.Task = None
self._update_listeners: list[Callable | Awaitable] = []
self._update_listeners: list[Callable] = []
self._update_awaitables: list[Callable[[], Awaitable]] = []
self._gateway: Gateway = gateway
self._polling_active: bool = False

Expand Down Expand Up @@ -401,25 +402,34 @@ def stop(self):
self._updater_task_handle = None
_LOGGER.debug("global updater stopped")

def register_update_listener(self, listener: Callable | Awaitable):
def register_update_listener(self, listener: Callable | Callable[[], Awaitable]):
"""Register an update listener."""
if listener in self._update_listeners:
if listener in self._update_listeners or listener in self._update_awaitables:
_LOGGER.debug(
"listener already registered with global updater - nothing to register: %s",
listener,
)
return
self._update_listeners.append(listener)
if inspect.iscoroutinefunction(listener):
self._update_awaitables.append(listener)
else:
self._update_listeners.append(listener)

def remove_update_listener(self, listener: Callable | Awaitable):
def remove_update_listener(self, listener: Callable | Callable[[], Awaitable]):
"""Remove an update listener."""
if listener not in self._update_listeners:
if (
listener not in self._update_listeners
and listener not in self._update_awaitables
):
_LOGGER.debug(
"listener not registered with global updater - nothing to remove: %s",
listener,
)
return
self._update_listeners.remove(listener)
if inspect.iscoroutinefunction(listener):
self._update_awaitables.remove(listener)
else:
self._update_listeners.remove(listener)

@periodic(_REFRESH_INTERVAL)
async def update_listeners(self):
Expand All @@ -430,10 +440,10 @@ async def update_listeners(self):
_LOGGER.debug("Global updater running update callbacks")
for listener in self._update_listeners:
_LOGGER.debug("Global updater running update callback")
if inspect.iscoroutinefunction(listener):
await listener()
else:
listener() # type: ignore
listener()
await gather_with_limited_concurrency(
3, *[awaitable() for awaitable in self._update_awaitables]
)
self._polling_active = False
else:
_LOGGER.debug("Global updater interval skipped")
Expand Down

0 comments on commit 960d20d

Please sign in to comment.