Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add verify cleanup fixture #54

Merged
merged 5 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 109 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""Test configuration for the ZHA component."""

import asyncio
from collections.abc import Callable
from collections.abc import Callable, Generator
from contextlib import contextmanager
import itertools
import logging
import os
import reprlib
import threading
import time
from types import TracebackType
from typing import Any, Optional
Expand All @@ -25,6 +29,7 @@
import zigpy.zdo.types as zdo_t

from tests import common
from zha.application import Platform
from zha.application.gateway import Gateway
from zha.application.helpers import (
AlarmControlPanelOptions,
Expand All @@ -33,11 +38,13 @@
ZHAConfiguration,
ZHAData,
)
from zha.async_ import ZHAJob
from zha.zigbee.device import Device

FIXTURE_GRP_ID = 0x1001
FIXTURE_GRP_NAME = "fixture group"
COUNTER_NAMES = ["counter_1", "counter_2", "counter_3"]
INSTANCES = []
_LOGGER = logging.getLogger(__name__)


Expand Down Expand Up @@ -120,6 +127,105 @@ def _wrap_mock_instance(obj: Any) -> MagicMock:
return mock


@contextmanager
def long_repr_strings() -> Generator[None, None, None]:
"""Increase reprlib maxstring and maxother to 300."""
arepr = reprlib.aRepr
original_maxstring = arepr.maxstring
original_maxother = arepr.maxother
arepr.maxstring = 300
arepr.maxother = 300
try:
yield
finally:
arepr.maxstring = original_maxstring
arepr.maxother = original_maxother


@pytest.fixture(autouse=True)
def expected_lingering_tasks() -> bool:
"""Temporary ability to bypass test failures.

Parametrize to True to bypass the pytest failure.
@pytest.mark.parametrize("expected_lingering_tasks", [True])

This should be removed when all lingering tasks have been cleaned up.
"""
return False


@pytest.fixture(autouse=True)
def expected_lingering_timers() -> bool:
"""Temporary ability to bypass test failures.

Parametrize to True to bypass the pytest failure.
@pytest.mark.parametrize("expected_lingering_timers", [True])

This should be removed when all lingering timers have been cleaned up.
"""
current_test = os.getenv("PYTEST_CURRENT_TEST")
if (
current_test
and current_test.startswith("tests/components/")
and current_test.split("/")[2] not in {platform.value for platform in Platform}
):
# As a starting point, we ignore non-platform components
return True
return False


@pytest.fixture(autouse=True)
def verify_cleanup(
event_loop: asyncio.AbstractEventLoop,
expected_lingering_tasks: bool, # pylint: disable=redefined-outer-name
expected_lingering_timers: bool, # pylint: disable=redefined-outer-name
) -> Generator[None, None, None]:
"""Verify that the test has cleaned up resources correctly."""
threads_before = frozenset(threading.enumerate())
tasks_before = asyncio.all_tasks(event_loop)
yield

event_loop.run_until_complete(event_loop.shutdown_default_executor())

if len(INSTANCES) >= 2:
count = len(INSTANCES)
for inst in INSTANCES:
inst.stop()
pytest.exit(f"Detected non stopped instances ({count}), aborting test run")

# Warn and clean-up lingering tasks and timers
# before moving on to the next test.
tasks = asyncio.all_tasks(event_loop) - tasks_before
for task in tasks:
if expected_lingering_tasks:
_LOGGER.warning("Lingering task after test %r", task)
else:
pytest.fail(f"Lingering task after test {task!r}")
task.cancel()
if tasks:
event_loop.run_until_complete(asyncio.wait(tasks))

for handle in event_loop._scheduled: # type: ignore[attr-defined]
if not handle.cancelled():
with long_repr_strings():
if expected_lingering_timers:
_LOGGER.warning("Lingering timer after test %r", handle)
elif handle._args and isinstance(job := handle._args[-1], ZHAJob):
if job.cancel_on_shutdown:
continue
pytest.fail(f"Lingering timer after job {job!r}")
else:
pytest.fail(f"Lingering timer after test {handle!r}")
handle.cancel()

# Verify no threads where left behind.
threads = frozenset(threading.enumerate()) - threads_before
for thread in threads:
assert isinstance(thread, threading._DummyThread) or thread.name.startswith(
"waitpid-"
)


@pytest.fixture
async def zigpy_app_controller():
"""Zigpy ApplicationController fixture."""
Expand Down Expand Up @@ -213,12 +319,14 @@ async def __aenter__(self) -> Gateway:
self.zha_gateway = await Gateway.async_from_config(self.zha_data)
await self.zha_gateway.async_block_till_done()
await self.zha_gateway.async_initialize_devices_and_entities()
INSTANCES.append(self.zha_gateway)
return self.zha_gateway

async def __aexit__(
self, exc_type: Exception, exc_value: str, traceback: TracebackType
) -> None:
"""Shutdown the ZHA gateway."""
INSTANCES.remove(self.zha_gateway)
await self.zha_gateway.shutdown()
await asyncio.sleep(0)

Expand Down
3 changes: 3 additions & 0 deletions zha/application/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,9 @@ async def shutdown(self) -> None:
for device in self._devices.values():
await device.on_remove()

for group in self._groups.values():
await group.on_remove()

_LOGGER.debug("Shutting down ZHA ControllerApplication")
await self.application_controller.shutdown()
self.application_controller = None
Expand Down
7 changes: 6 additions & 1 deletion zha/application/platforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __init__(self, unique_id: str) -> None:

self.__previous_state: Any = None
self._tracked_tasks: list[asyncio.Task] = []
self._tracked_handles: list[asyncio.Handle] = []

@property
def fallback_name(self) -> str | None:
Expand Down Expand Up @@ -213,7 +214,11 @@ def state(self) -> dict[str, Any]:
}

async def on_remove(self) -> None:
"""Cancel tasks this entity owns."""
"""Cancel tasks and timers this entity owns."""
for handle in self._tracked_handles:
self.debug("Cancelling handle: %s", handle)
handle.cancel()

tasks = [t for t in self._tracked_tasks if not (t.done() or t.cancelled())]
for task in tasks:
self.debug("Cancelling task: %s", task)
Expand Down
5 changes: 5 additions & 0 deletions zha/application/platforms/light/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import asyncio
from collections import Counter
from collections.abc import Callable
import contextlib
import dataclasses
from dataclasses import dataclass
import functools
Expand Down Expand Up @@ -652,13 +653,17 @@ def async_transition_start_timer(self, transition_time) -> None:
transition_time,
self.async_transition_complete,
)
self._tracked_handles.append(self._transition_listener)

def _async_unsub_transition_listener(self) -> None:
"""Unsubscribe transition listener."""
if self._transition_listener:
self._transition_listener.cancel()
self._transition_listener = None

with contextlib.suppress(ValueError):
self._tracked_handles.remove(self._transition_listener)

def async_transition_complete(self, _=None) -> None:
"""Set _transitioning_individual to False and write HA state."""
self.debug("transition complete - future attribute reports will write HA state")
Expand Down
6 changes: 6 additions & 0 deletions zha/application/platforms/siren.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import asyncio
import contextlib
from dataclasses import dataclass
from enum import IntFlag
import functools
Expand Down Expand Up @@ -165,6 +166,7 @@ async def async_turn_on(self, **kwargs: Any) -> None:
self._off_listener = asyncio.get_running_loop().call_later(
siren_duration, self.async_set_off
)
self._tracked_handles.append(self._off_listener)
self.maybe_emit_state_changed_event()

async def async_turn_off(self, **kwargs: Any) -> None: # pylint: disable=unused-argument
Expand All @@ -180,5 +182,9 @@ def async_set_off(self) -> None:
self._attr_is_on = False
if self._off_listener:
self._off_listener.cancel()

with contextlib.suppress(ValueError):
self._tracked_handles.remove(self._off_listener)

self._off_listener = None
self.maybe_emit_state_changed_event()
1 change: 0 additions & 1 deletion zha/zigbee/cluster_handlers/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,6 @@ def cluster_command(self, tsn, command_id, args):
)
if on_time > 0:
self._off_listener = asyncio.get_running_loop().call_later(
self._endpoint.device.hass,
(on_time / 10), # value is in 10ths of a second
self.set_to_off,
)
Expand Down
2 changes: 1 addition & 1 deletion zha/zigbee/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def update_available(self, available: bool) -> None:
)

def emit_zha_event(self, event_data: dict[str, str | int]) -> None: # pylint: disable=unused-argument
"""Relay events to hass."""
"""Relay events directly."""
self.emit(
ZHA_EVENT,
ZHAEvent(
Expand Down
5 changes: 5 additions & 0 deletions zha/zigbee/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,8 @@ def log(self, level: int, msg: str, *args: Any, **kwargs) -> None:
msg = f"[%s](%s): {msg}"
args = (self.name, self.group_id) + args
_LOGGER.log(level, msg, *args, **kwargs)

async def on_remove(self) -> None:
"""Cancel tasks this group owns."""
for group_entity in self._group_entities.values():
await group_entity.on_remove()
Loading