Skip to content

Commit

Permalink
Fix running async functions without event loop during testing
Browse files Browse the repository at this point in the history
  • Loading branch information
arkq committed Dec 19, 2024
1 parent ddc48d9 commit 91c18a0
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from enum import Enum, IntFlag
from functools import partial
from functools import partial, wraps
from itertools import chain
from typing import Any, Iterable, List, Optional, Tuple

Expand Down Expand Up @@ -100,6 +100,15 @@ class TestRunnerHooks:
_GLOBAL_DATA = {}


def asyncio_thread_executor(f):
@wraps(f)
def wrapper(*args, **kwargs):
thread = threading.Thread(target=asyncio.run, args=(f(*args, **kwargs),))
thread.start()
thread.join()
return wrapper


def stash_globally(o: object) -> str:
id = str(uuid.uuid1())
_GLOBAL_DATA[id] = o
Expand Down Expand Up @@ -2074,10 +2083,9 @@ def parse_matter_test_args(argv: Optional[List[str]] = None) -> MatterTestConfig
return convert_args_to_matter_config(parser.parse_known_args(argv)[0])


def _async_runner(body, self: MatterBaseTest, *args, **kwargs):
async def async_runner_with_timeout(body, self: MatterBaseTest, *args, **kwargs):
timeout = self.matter_test_config.timeout if self.matter_test_config.timeout is not None else self.default_timeout
runner_with_timeout = asyncio.wait_for(body(self, *args, **kwargs), timeout=timeout)
return asyncio.run(runner_with_timeout)
return asyncio.wait_for(body(self, *args, **kwargs), timeout=timeout)


def async_test_body(body):
Expand All @@ -2089,7 +2097,7 @@ def async_test_body(body):
"""

def async_runner(self: MatterBaseTest, *args, **kwargs):
return _async_runner(body, self, *args, **kwargs)
return asyncio_thread_executor(body)(self, *args, **kwargs)

return async_runner

Expand Down Expand Up @@ -2268,9 +2276,9 @@ def run_on_singleton_matching_endpoint(accept_function: EndpointCheckFunction):
Note that currently this test is limited to devices with a SINGLE matching endpoint.
"""
def run_on_singleton_matching_endpoint_internal(body):
def matching_runner(self: MatterBaseTest, *args, **kwargs):
runner_with_timeout = asyncio.wait_for(_get_all_matching_endpoints(self, accept_function), timeout=30)
matching = asyncio.run(runner_with_timeout)
@asyncio_thread_executor
async def matching_runner(self: MatterBaseTest, *args, **kwargs):
matching = await asyncio.wait_for(_get_all_matching_endpoints(self, accept_function), timeout=30)
asserts.assert_less_equal(len(matching), 1, "More than one matching endpoint found for singleton test.")
if not matching:
logging.info("Test is not applicable to any endpoint - skipping test")
Expand All @@ -2281,7 +2289,7 @@ def matching_runner(self: MatterBaseTest, *args, **kwargs):
old_endpoint = self.matter_test_config.endpoint
self.matter_test_config.endpoint = matching[0]
logging.info(f'Running test on endpoint {self.matter_test_config.endpoint}')
_async_runner(body, self, *args, **kwargs)
await async_runner_with_timeout(body, self, *args, **kwargs)
finally:
self.matter_test_config.endpoint = old_endpoint
return matching_runner
Expand Down Expand Up @@ -2315,15 +2323,15 @@ def run_if_endpoint_matches(accept_function: EndpointCheckFunction):
PICS values internally.
"""
def run_if_endpoint_matches_internal(body):
def per_endpoint_runner(self: MatterBaseTest, *args, **kwargs):
runner_with_timeout = asyncio.wait_for(should_run_test_on_endpoint(self, accept_function), timeout=60)
should_run_test = asyncio.run(runner_with_timeout)
@asyncio_thread_executor
async def per_endpoint_runner(self: MatterBaseTest, *args, **kwargs):
should_run_test = await asyncio.wait_for(should_run_test_on_endpoint(self, accept_function), timeout=60)
if not should_run_test:
logging.info("Test is not applicable to this endpoint - skipping test")
asserts.skip('Endpoint does not match test requirements')
return
logging.info(f'Running test on endpoint {self.matter_test_config.endpoint}')
_async_runner(body, self, *args, **kwargs)
await async_runner_with_timeout(body, self, *args, **kwargs)
return per_endpoint_runner
return run_if_endpoint_matches_internal

Expand All @@ -2335,14 +2343,14 @@ def __init__(self, *args):
super().__init__(*args)
self.is_commissioning = True

def test_run_commissioning(self):
@asyncio_thread_executor
async def test_run_commissioning(self):
conf = self.matter_test_config
for commission_idx, node_id in enumerate(conf.dut_node_ids):
logging.info("Starting commissioning for root index %d, fabric ID 0x%016X, node ID 0x%016X" %
(conf.root_of_trust_index, conf.fabric_id, node_id))
logging.info("Commissioning method: %s" % conf.commissioning_method)

if not asyncio.run(self._commission_device(commission_idx)):
if not await self._commission_device(commission_idx):
raise signals.TestAbortAll("Failed to commission node")

async def _commission_device(self, i) -> bool:
Expand Down Expand Up @@ -2443,7 +2451,8 @@ def get_test_info(test_class: MatterBaseTest, matter_test_config: MatterTestConf
return info


def run_tests_no_exit(test_class: MatterBaseTest, matter_test_config: MatterTestConfig, hooks: TestRunnerHooks, default_controller=None, external_stack=None) -> bool:
async def run_tests_no_exit(test_class: MatterBaseTest, matter_test_config: MatterTestConfig,
hooks: TestRunnerHooks, default_controller=None, external_stack=None) -> bool:

get_test_info(test_class, matter_test_config)

Expand Down Expand Up @@ -2534,6 +2543,10 @@ def run_tests_no_exit(test_class: MatterBaseTest, matter_test_config: MatterTest
return ok


def run_tests(test_class: MatterBaseTest, matter_test_config: MatterTestConfig, hooks: TestRunnerHooks, default_controller=None, external_stack=None) -> None:
if not run_tests_no_exit(test_class, matter_test_config, hooks, default_controller, external_stack):
def run_tests_no_exit_sync(*args, **kwargs) -> bool:
return asyncio.run(run_tests_no_exit(*args, **kwargs))


def run_tests(*args, **kwargs):
if not run_tests_no_exit_sync(*args, **kwargs):
sys.exit(1)
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@
try:
from chip.testing.basic_composition import BasicCompositionTests
from chip.testing.matter_testing import (MatterBaseTest, MatterStackState, MatterTestConfig, TestStep, async_test_body,
run_tests_no_exit)
run_tests_no_exit_sync)
except ImportError:
sys.path.append(os.path.abspath(
os.path.join(os.path.dirname(__file__), '..')))
from chip.testing.basic_composition import BasicCompositionTests
from chip.testing.matter_testing import (MatterBaseTest, MatterStackState, MatterTestConfig, TestStep, async_test_body,
run_tests_no_exit)
run_tests_no_exit_sync)

try:
import fetch_paa_certs_from_dcl
Expand Down Expand Up @@ -390,7 +390,7 @@ def run_test(test_class: MatterBaseTest, tests: typing.List[str], test_config: T
stack = test_config.get_stack()
controller = test_config.get_controller()
matter_config = test_config.get_config(tests)
ok = run_tests_no_exit(test_class, matter_config, hooks, controller, stack)
ok = run_tests_no_exit_sync(test_class, matter_config, hooks, controller, stack)
if not ok:
print(f"Test failure. Failed on step: {hooks.get_failures()}")
return hooks.get_failures()
Expand Down
4 changes: 2 additions & 2 deletions src/python_testing/test_testing/MockTestRunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from unittest.mock import MagicMock

from chip.clusters import Attribute
from chip.testing.matter_testing import MatterStackState, MatterTestConfig, run_tests_no_exit
from chip.testing.matter_testing import MatterStackState, MatterTestConfig, run_tests_no_exit_sync


class AsyncMock(MagicMock):
Expand Down Expand Up @@ -75,4 +75,4 @@ def run_test_with_mock_read(self, read_cache: Attribute.AsyncReadTransaction.Re
self.default_controller.Read = AsyncMock(return_value=read_cache)
# This doesn't need to do anything since we are overriding the read anyway
self.default_controller.FindOrEstablishPASESession = AsyncMock(return_value=None)
return run_tests_no_exit(self.test_class, self.config, hooks, self.default_controller, self.stack)
return run_tests_no_exit_sync(self.test_class, self.config, hooks, self.default_controller, self.stack)
6 changes: 3 additions & 3 deletions src/python_testing/test_testing/test_TC_CCNTL_2_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
from MockTestRunner import AsyncMock, MockTestRunner

try:
from chip.testing.matter_testing import MatterTestConfig, get_default_paa_trust_store, run_tests_no_exit
from chip.testing.matter_testing import MatterTestConfig, get_default_paa_trust_store, run_tests_no_exit_sync
except ImportError:
sys.path.append(os.path.abspath(
os.path.join(os.path.dirname(__file__), '..')))
from chip.testing.matter_testing import MatterTestConfig, get_default_paa_trust_store, run_tests_no_exit
from chip.testing.matter_testing import MatterTestConfig, get_default_paa_trust_store, run_tests_no_exit_sync

invoke_call_count = 0
event_call_count = 0
Expand Down Expand Up @@ -166,7 +166,7 @@ def run_test_with_mock(self, dynamic_invoke_return: typing.Callable, dynamic_eve
self.default_controller.FindOrEstablishPASESession = AsyncMock(return_value=None)
self.default_controller.ReadEvent = AsyncMock(return_value=[], side_effect=dynamic_event_return)

return run_tests_no_exit(self.test_class, self.config, hooks, self.default_controller, self.stack)
return run_tests_no_exit_sync(self.test_class, self.config, hooks, self.default_controller, self.stack)


@click.command()
Expand Down
6 changes: 3 additions & 3 deletions src/python_testing/test_testing/test_TC_MCORE_FS_1_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
from MockTestRunner import AsyncMock, MockTestRunner

try:
from chip.testing.matter_testing import MatterTestConfig, get_default_paa_trust_store, run_tests_no_exit
from chip.testing.matter_testing import MatterTestConfig, get_default_paa_trust_store, run_tests_no_exit_sync
except ImportError:
sys.path.append(os.path.abspath(
os.path.join(os.path.dirname(__file__), '..')))
from chip.testing.matter_testing import MatterTestConfig, get_default_paa_trust_store, run_tests_no_exit
from chip.testing.matter_testing import MatterTestConfig, get_default_paa_trust_store, run_tests_no_exit_sync

invoke_call_count = 0
event_call_count = 0
Expand Down Expand Up @@ -137,7 +137,7 @@ def run_test_with_mock(self, dynamic_invoke_return: typing.Callable, dynamic_eve
self.default_controller.FindOrEstablishPASESession = AsyncMock(return_value=None)
self.default_controller.ReadEvent = AsyncMock(return_value=[], side_effect=dynamic_event_return)

return run_tests_no_exit(self.test_class, self.config, hooks, self.default_controller, self.stack)
return run_tests_no_exit_sync(self.test_class, self.config, hooks, self.default_controller, self.stack)


@click.command()
Expand Down

0 comments on commit 91c18a0

Please sign in to comment.