Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix running async functions without event loop during testing #36859

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import argparse
import asyncio
import builtins
import contextlib
import inspect
import json
import logging
Expand All @@ -37,7 +38,7 @@
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from enum import Enum, IntFlag
from functools import partial
from functools import partial, wraps
from itertools import chain
from typing import Any, Iterable, List, Optional, Tuple

Expand Down Expand Up @@ -100,6 +101,28 @@ class TestRunnerHooks:
_GLOBAL_DATA = {}


def asyncio_thread_executor(f):
"""Run an async function in an event loop in a separate thread.

This decorator function blocks the current thread until the async function
completes. Also, it forwards any exceptions that occurred in that thread.
"""
@wraps(f)
def wrapper(*args, **kwargs):
def run(coroutine, q: queue.Queue):
try:
asyncio.run(coroutine)
except Exception as e:
q.put(e)
q = queue.Queue()
thread = threading.Thread(target=run, args=(f(*args, **kwargs), q))
thread.start()
thread.join()
with contextlib.suppress(queue.Empty):
raise q.get(block=False)
return wrapper


def stash_globally(o: object) -> str:
id = str(uuid.uuid1())
_GLOBAL_DATA[id] = o
Expand Down Expand Up @@ -2074,10 +2097,9 @@ def parse_matter_test_args(argv: Optional[List[str]] = None) -> MatterTestConfig
return convert_args_to_matter_config(parser.parse_known_args(argv)[0])


def _async_runner(body, self: MatterBaseTest, *args, **kwargs):
async def async_runner_with_timeout(body, self: MatterBaseTest, *args, **kwargs):
timeout = self.matter_test_config.timeout if self.matter_test_config.timeout is not None else self.default_timeout
runner_with_timeout = asyncio.wait_for(body(self, *args, **kwargs), timeout=timeout)
return asyncio.run(runner_with_timeout)
return await asyncio.wait_for(body(self, *args, **kwargs), timeout=timeout)


def async_test_body(body):
Expand All @@ -2088,8 +2110,9 @@ def async_test_body(body):
a asyncio-run synchronous method. This decorator does the wrapping.
"""

def async_runner(self: MatterBaseTest, *args, **kwargs):
return _async_runner(body, self, *args, **kwargs)
@asyncio_thread_executor
async def async_runner(self: MatterBaseTest, *args, **kwargs):
return await async_runner_with_timeout(body, self, *args, **kwargs)

return async_runner

Expand Down Expand Up @@ -2268,9 +2291,9 @@ def run_on_singleton_matching_endpoint(accept_function: EndpointCheckFunction):
Note that currently this test is limited to devices with a SINGLE matching endpoint.
"""
def run_on_singleton_matching_endpoint_internal(body):
def matching_runner(self: MatterBaseTest, *args, **kwargs):
runner_with_timeout = asyncio.wait_for(_get_all_matching_endpoints(self, accept_function), timeout=30)
matching = asyncio.run(runner_with_timeout)
@asyncio_thread_executor
async def matching_runner(self: MatterBaseTest, *args, **kwargs):
matching = await asyncio.wait_for(_get_all_matching_endpoints(self, accept_function), timeout=30)
asserts.assert_less_equal(len(matching), 1, "More than one matching endpoint found for singleton test.")
if not matching:
logging.info("Test is not applicable to any endpoint - skipping test")
Expand All @@ -2281,7 +2304,7 @@ def matching_runner(self: MatterBaseTest, *args, **kwargs):
old_endpoint = self.matter_test_config.endpoint
self.matter_test_config.endpoint = matching[0]
logging.info(f'Running test on endpoint {self.matter_test_config.endpoint}')
_async_runner(body, self, *args, **kwargs)
await async_runner_with_timeout(body, self, *args, **kwargs)
finally:
self.matter_test_config.endpoint = old_endpoint
return matching_runner
Expand Down Expand Up @@ -2315,15 +2338,15 @@ def run_if_endpoint_matches(accept_function: EndpointCheckFunction):
PICS values internally.
"""
def run_if_endpoint_matches_internal(body):
def per_endpoint_runner(self: MatterBaseTest, *args, **kwargs):
runner_with_timeout = asyncio.wait_for(should_run_test_on_endpoint(self, accept_function), timeout=60)
should_run_test = asyncio.run(runner_with_timeout)
@asyncio_thread_executor
async def per_endpoint_runner(self: MatterBaseTest, *args, **kwargs):
should_run_test = await asyncio.wait_for(should_run_test_on_endpoint(self, accept_function), timeout=60)
if not should_run_test:
logging.info("Test is not applicable to this endpoint - skipping test")
asserts.skip('Endpoint does not match test requirements')
return
logging.info(f'Running test on endpoint {self.matter_test_config.endpoint}')
_async_runner(body, self, *args, **kwargs)
await async_runner_with_timeout(body, self, *args, **kwargs)
return per_endpoint_runner
return run_if_endpoint_matches_internal

Expand All @@ -2335,14 +2358,14 @@ def __init__(self, *args):
super().__init__(*args)
self.is_commissioning = True

def test_run_commissioning(self):
@asyncio_thread_executor
async def test_run_commissioning(self):
conf = self.matter_test_config
for commission_idx, node_id in enumerate(conf.dut_node_ids):
logging.info("Starting commissioning for root index %d, fabric ID 0x%016X, node ID 0x%016X" %
(conf.root_of_trust_index, conf.fabric_id, node_id))
logging.info("Commissioning method: %s" % conf.commissioning_method)

if not asyncio.run(self._commission_device(commission_idx)):
if not await self._commission_device(commission_idx):
raise signals.TestAbortAll("Failed to commission node")

async def _commission_device(self, i) -> bool:
Expand Down Expand Up @@ -2443,7 +2466,8 @@ def get_test_info(test_class: MatterBaseTest, matter_test_config: MatterTestConf
return info


def run_tests_no_exit(test_class: MatterBaseTest, matter_test_config: MatterTestConfig, hooks: TestRunnerHooks, default_controller=None, external_stack=None) -> bool:
async def run_tests_no_exit(test_class: MatterBaseTest, matter_test_config: MatterTestConfig,
hooks: TestRunnerHooks, default_controller=None, external_stack=None) -> bool:

get_test_info(test_class, matter_test_config)

Expand Down Expand Up @@ -2534,6 +2558,10 @@ def run_tests_no_exit(test_class: MatterBaseTest, matter_test_config: MatterTest
return ok


def run_tests(test_class: MatterBaseTest, matter_test_config: MatterTestConfig, hooks: TestRunnerHooks, default_controller=None, external_stack=None) -> None:
if not run_tests_no_exit(test_class, matter_test_config, hooks, default_controller, external_stack):
def run_tests_no_exit_sync(*args, **kwargs) -> bool:
return asyncio.run(run_tests_no_exit(*args, **kwargs))


def run_tests(*args, **kwargs):
if not run_tests_no_exit_sync(*args, **kwargs):
sys.exit(1)
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@
try:
from chip.testing.basic_composition import BasicCompositionTests
from chip.testing.matter_testing import (MatterBaseTest, MatterStackState, MatterTestConfig, TestStep, async_test_body,
run_tests_no_exit)
run_tests_no_exit_sync)
except ImportError:
sys.path.append(os.path.abspath(
os.path.join(os.path.dirname(__file__), '..')))
from chip.testing.basic_composition import BasicCompositionTests
from chip.testing.matter_testing import (MatterBaseTest, MatterStackState, MatterTestConfig, TestStep, async_test_body,
run_tests_no_exit)
run_tests_no_exit_sync)

try:
import fetch_paa_certs_from_dcl
Expand Down Expand Up @@ -390,7 +390,7 @@ def run_test(test_class: MatterBaseTest, tests: typing.List[str], test_config: T
stack = test_config.get_stack()
controller = test_config.get_controller()
matter_config = test_config.get_config(tests)
ok = run_tests_no_exit(test_class, matter_config, hooks, controller, stack)
ok = run_tests_no_exit_sync(test_class, matter_config, hooks, controller, stack)
if not ok:
print(f"Test failure. Failed on step: {hooks.get_failures()}")
return hooks.get_failures()
Expand Down
4 changes: 2 additions & 2 deletions src/python_testing/test_testing/MockTestRunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from unittest.mock import MagicMock

from chip.clusters import Attribute
from chip.testing.matter_testing import MatterStackState, MatterTestConfig, run_tests_no_exit
from chip.testing.matter_testing import MatterStackState, MatterTestConfig, run_tests_no_exit_sync


class AsyncMock(MagicMock):
Expand Down Expand Up @@ -75,4 +75,4 @@ def run_test_with_mock_read(self, read_cache: Attribute.AsyncReadTransaction.Re
self.default_controller.Read = AsyncMock(return_value=read_cache)
# This doesn't need to do anything since we are overriding the read anyway
self.default_controller.FindOrEstablishPASESession = AsyncMock(return_value=None)
return run_tests_no_exit(self.test_class, self.config, hooks, self.default_controller, self.stack)
return run_tests_no_exit_sync(self.test_class, self.config, hooks, self.default_controller, self.stack)
6 changes: 3 additions & 3 deletions src/python_testing/test_testing/test_TC_CCNTL_2_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
from MockTestRunner import AsyncMock, MockTestRunner

try:
from chip.testing.matter_testing import MatterTestConfig, get_default_paa_trust_store, run_tests_no_exit
from chip.testing.matter_testing import MatterTestConfig, get_default_paa_trust_store, run_tests_no_exit_sync
except ImportError:
sys.path.append(os.path.abspath(
os.path.join(os.path.dirname(__file__), '..')))
from chip.testing.matter_testing import MatterTestConfig, get_default_paa_trust_store, run_tests_no_exit
from chip.testing.matter_testing import MatterTestConfig, get_default_paa_trust_store, run_tests_no_exit_sync

invoke_call_count = 0
event_call_count = 0
Expand Down Expand Up @@ -166,7 +166,7 @@ def run_test_with_mock(self, dynamic_invoke_return: typing.Callable, dynamic_eve
self.default_controller.FindOrEstablishPASESession = AsyncMock(return_value=None)
self.default_controller.ReadEvent = AsyncMock(return_value=[], side_effect=dynamic_event_return)

return run_tests_no_exit(self.test_class, self.config, hooks, self.default_controller, self.stack)
return run_tests_no_exit_sync(self.test_class, self.config, hooks, self.default_controller, self.stack)


@click.command()
Expand Down
6 changes: 3 additions & 3 deletions src/python_testing/test_testing/test_TC_MCORE_FS_1_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
from MockTestRunner import AsyncMock, MockTestRunner

try:
from chip.testing.matter_testing import MatterTestConfig, get_default_paa_trust_store, run_tests_no_exit
from chip.testing.matter_testing import MatterTestConfig, get_default_paa_trust_store, run_tests_no_exit_sync
except ImportError:
sys.path.append(os.path.abspath(
os.path.join(os.path.dirname(__file__), '..')))
from chip.testing.matter_testing import MatterTestConfig, get_default_paa_trust_store, run_tests_no_exit
from chip.testing.matter_testing import MatterTestConfig, get_default_paa_trust_store, run_tests_no_exit_sync

invoke_call_count = 0
event_call_count = 0
Expand Down Expand Up @@ -137,7 +137,7 @@ def run_test_with_mock(self, dynamic_invoke_return: typing.Callable, dynamic_eve
self.default_controller.FindOrEstablishPASESession = AsyncMock(return_value=None)
self.default_controller.ReadEvent = AsyncMock(return_value=[], side_effect=dynamic_event_return)

return run_tests_no_exit(self.test_class, self.config, hooks, self.default_controller, self.stack)
return run_tests_no_exit_sync(self.test_class, self.config, hooks, self.default_controller, self.stack)


@click.command()
Expand Down
Loading