Skip to content

Commit

Permalink
Drop @callback and other unused async helpers (#75)
Browse files Browse the repository at this point in the history
* Bye bye `@callback`

* Fix unit tests

* Drop `async_run_job`

* Drop `run_callback_threadsafe`
  • Loading branch information
puddly authored Jul 10, 2024
1 parent 2008457 commit 0854912
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 363 deletions.
218 changes: 22 additions & 196 deletions tests/test_async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,28 @@
import asyncio
import functools
import time
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import MagicMock, patch

import pytest

from zha import async_ as zha_async
from zha.application.gateway import Gateway
from zha.async_ import AsyncUtilMixin, ZHAJob, ZHAJobType, create_eager_task
from zha.decorators import callback


async def test_zhajob_forbid_coroutine() -> None:
"""Test zhajob forbids coroutines."""

async def bla():
pass

coro = bla()

with pytest.raises(ValueError):
_ = ZHAJob(coro).job_type

# To avoid warning about unawaited coro
await coro


@pytest.mark.parametrize("eager_start", [True, False])
async def test_cancellable_zhajob(zha_gateway: Gateway, eager_start: bool) -> None:
"""Simulate a shutdown, ensure cancellable jobs are cancelled."""
job = MagicMock()

@callback
def run_job(job: ZHAJob) -> None:
"""Call the action."""
zha_gateway.async_run_zha_job(job, eager_start=eager_start)

timer1 = zha_gateway.loop.call_later(
60, run_job, ZHAJob(callback(job), cancel_on_shutdown=True)
60, run_job, ZHAJob(job, cancel_on_shutdown=True)
)
timer2 = zha_gateway.loop.call_later(60, run_job, ZHAJob(callback(job)))
timer2 = zha_gateway.loop.call_later(60, run_job, ZHAJob(job))

await zha_gateway.shutdown()

Expand All @@ -57,7 +40,7 @@ async def test_async_add_zha_job_schedule_callback() -> None:
zha_gateway = MagicMock(loop=MagicMock(wraps=asyncio.get_running_loop()))
job = MagicMock()

AsyncUtilMixin.async_add_zha_job(zha_gateway, ZHAJob(callback(job)))
AsyncUtilMixin.async_add_zha_job(zha_gateway, ZHAJob(job))
assert len(zha_gateway.loop.call_soon.mock_calls) == 1
assert len(zha_gateway.loop.create_task.mock_calls) == 0
assert len(zha_gateway.add_job.mock_calls) == 0
Expand All @@ -71,9 +54,7 @@ async def test_async_add_zha_job_eager_start_coro_suspends(
async def job_that_suspends():
await asyncio.sleep(0)

task = zha_gateway.async_add_zha_job(
ZHAJob(callback(job_that_suspends)), eager_start=True
)
task = zha_gateway.async_add_zha_job(ZHAJob(job_that_suspends), eager_start=True)
assert not task.done()
assert task in zha_gateway._tracked_completable_tasks
await task
Expand All @@ -88,7 +69,7 @@ async def test_async_run_zha_job_eager_start_coro_suspends(
async def job_that_suspends():
await asyncio.sleep(0)

task = zha_gateway.async_run_zha_job(ZHAJob(callback(job_that_suspends)))
task = zha_gateway.async_run_zha_job(ZHAJob(job_that_suspends))
assert not task.done()
assert task in zha_gateway._tracked_completable_tasks
await task
Expand All @@ -101,9 +82,7 @@ async def test_async_add_zha_job_background(zha_gateway: Gateway) -> None:
async def job_that_suspends():
await asyncio.sleep(0)

task = zha_gateway.async_add_zha_job(
ZHAJob(callback(job_that_suspends)), background=True
)
task = zha_gateway.async_add_zha_job(ZHAJob(job_that_suspends), background=True)
assert not task.done()
assert task in zha_gateway._background_tasks
await task
Expand All @@ -116,9 +95,7 @@ async def test_async_run_zha_job_background(zha_gateway: Gateway) -> None:
async def job_that_suspends():
await asyncio.sleep(0)

task = zha_gateway.async_run_zha_job(
ZHAJob(callback(job_that_suspends)), background=True
)
task = zha_gateway.async_run_zha_job(ZHAJob(job_that_suspends), background=True)
assert not task.done()
assert task in zha_gateway._background_tasks
await task
Expand All @@ -131,9 +108,7 @@ async def test_async_add_zha_job_eager_background(zha_gateway: Gateway) -> None:
async def job_that_suspends():
await asyncio.sleep(0)

task = zha_gateway.async_add_zha_job(
ZHAJob(callback(job_that_suspends)), background=True
)
task = zha_gateway.async_add_zha_job(ZHAJob(job_that_suspends), background=True)
assert not task.done()
assert task in zha_gateway._background_tasks
await task
Expand All @@ -146,43 +121,28 @@ async def test_async_run_zha_job_eager_background(zha_gateway: Gateway) -> None:
async def job_that_suspends():
await asyncio.sleep(0)

task = zha_gateway.async_run_zha_job(
ZHAJob(callback(job_that_suspends)), background=True
)
task = zha_gateway.async_run_zha_job(ZHAJob(job_that_suspends), background=True)
assert not task.done()
assert task in zha_gateway._background_tasks
await task
assert task not in zha_gateway._background_tasks


async def test_async_run_zha_job_background_synchronous(
@pytest.mark.parametrize("background", [True, False])
async def test_async_run_zha_job_background_no_suspend(
zha_gateway: Gateway,
background: bool,
) -> None:
"""Test scheduling a coro as an eager background task with async_run_zha_job."""

async def job_that_does_not_suspends():
pass

task = zha_gateway.async_run_zha_job(
ZHAJob(callback(job_that_does_not_suspends)),
background=True,
)
assert task.done()
assert task not in zha_gateway._background_tasks
assert task not in zha_gateway._tracked_completable_tasks
await task


async def test_async_run_zha_job_synchronous(zha_gateway: Gateway) -> None:
"""Test scheduling a coro as an eager task with async_run_zha_job."""

async def job_that_does_not_suspends():
pass

task = zha_gateway.async_run_zha_job(
ZHAJob(callback(job_that_does_not_suspends)),
background=False,
ZHAJob(job_that_does_not_suspends),
background=background,
)
assert task is not None
assert task.done()
assert task not in zha_gateway._background_tasks
assert task not in zha_gateway._tracked_completable_tasks
Expand Down Expand Up @@ -219,7 +179,7 @@ async def test_async_add_zha_job_schedule_partial_callback() -> None:
"""Test that we schedule partial coros and add jobs to the job pool."""
zha_gateway = MagicMock(loop=MagicMock(wraps=asyncio.get_running_loop()))
job = MagicMock()
partial = functools.partial(callback(job))
partial = functools.partial(job)

AsyncUtilMixin.async_add_zha_job(zha_gateway, ZHAJob(partial))
assert len(zha_gateway.loop.call_soon.mock_calls) == 1
Expand Down Expand Up @@ -252,6 +212,7 @@ async def job():
) as mock_create_eager_task:
zha_job = ZHAJob(job)
task = AsyncUtilMixin.async_add_zha_job(zha_gateway, zha_job, eager_start=True)
assert task is not None
assert len(zha_gateway.loop.call_soon.mock_calls) == 0
assert len(zha_gateway.add_job.mock_calls) == 0
assert mock_create_eager_task.mock_calls
Expand Down Expand Up @@ -283,9 +244,9 @@ def job():
pass

AsyncUtilMixin.async_add_zha_job(zha_gateway, ZHAJob(job))
assert len(zha_gateway.loop.call_soon.mock_calls) == 0
assert len(zha_gateway.loop.call_soon.mock_calls) == 1
assert len(zha_gateway.loop.create_task.mock_calls) == 0
assert len(zha_gateway.loop.run_in_executor.mock_calls) == 2
assert len(zha_gateway.loop.run_in_executor.mock_calls) == 0


async def test_async_create_task_schedule_coroutine() -> None:
Expand Down Expand Up @@ -337,7 +298,7 @@ def job():
asyncio.get_running_loop() # ensure we are in the event loop
calls.append(1)

AsyncUtilMixin.async_run_zha_job(zha_gateway, ZHAJob(callback(job)))
AsyncUtilMixin.async_run_zha_job(zha_gateway, ZHAJob(job))
assert len(calls) == 1


Expand All @@ -360,24 +321,11 @@ async def test_async_run_zha_job_calls_callback() -> None:
def job():
calls.append(1)

AsyncUtilMixin.async_run_zha_job(zha_gateway, ZHAJob(callback(job)))
AsyncUtilMixin.async_run_zha_job(zha_gateway, ZHAJob(job))
assert len(calls) == 1
assert len(zha_gateway.async_add_job.mock_calls) == 0


async def test_async_run_zha_job_delegates_non_async() -> None:
"""Test that the callback annotation is respected."""
zha_gateway = MagicMock()
calls = []

def job():
calls.append(1)

AsyncUtilMixin.async_run_zha_job(zha_gateway, ZHAJob(job))
assert len(calls) == 0
assert len(zha_gateway.async_add_zha_job.mock_calls) == 1


async def test_async_create_task_pending_tasks_coro(zha_gateway: Gateway) -> None:
"""Add a coro to pending tasks."""
call_count = []
Expand All @@ -395,34 +343,6 @@ async def test_coro():
assert len(zha_gateway._tracked_completable_tasks) == 0


async def test_async_run_job_starts_tasks_eagerly(zha_gateway: Gateway) -> None:
"""Test async_run_job starts tasks eagerly."""
runs = []

async def _test():
runs.append(True)

task = zha_gateway.async_run_job(_test)
# No call to zha_gateway.async_block_till_done to ensure the task is run eagerly
assert len(runs) == 1
assert task.done()
await task


async def test_async_run_job_starts_coro_eagerly(zha_gateway: Gateway) -> None:
"""Test async_run_job starts coros eagerly."""
runs = []

async def _test():
runs.append(True)

task = zha_gateway.async_run_job(_test())
# No call to zha_gateway.async_block_till_done to ensure the task is run eagerly
assert len(runs) == 1
assert task.done()
await task


@pytest.mark.parametrize("eager_start", [True, False])
async def test_background_task(zha_gateway: Gateway, eager_start: bool) -> None:
"""Test background tasks being quit."""
Expand All @@ -447,7 +367,6 @@ async def test_task():
def test_ZHAJob_passing_job_type():
"""Test passing the job type to ZHAJob when we already know it."""

@callback
def callback_func():
pass

Expand Down Expand Up @@ -506,31 +425,6 @@ async def _async_add_executor_job():
await task


@patch("concurrent.futures.Future")
@patch("threading.get_ident")
def test_run_callback_threadsafe_from_inside_event_loop(mock_ident, _) -> None:
"""Testing calling run_callback_threadsafe from inside an event loop."""
callback_fn = MagicMock()

loop = Mock(spec=["call_soon_threadsafe"])

loop._thread_ident = None
mock_ident.return_value = 5
zha_async.run_callback_threadsafe(loop, callback_fn)
assert len(loop.call_soon_threadsafe.mock_calls) == 1

loop._thread_ident = 5
mock_ident.return_value = 5
with pytest.raises(RuntimeError):
zha_async.run_callback_threadsafe(loop, callback_fn)
assert len(loop.call_soon_threadsafe.mock_calls) == 1

loop._thread_ident = 1
mock_ident.return_value = 5
zha_async.run_callback_threadsafe(loop, callback_fn)
assert len(loop.call_soon_threadsafe.mock_calls) == 2


async def test_gather_with_limited_concurrency() -> None:
"""Test gather_with_limited_concurrency limits the number of running tasks."""

Expand All @@ -553,74 +447,6 @@ async def _increment_runs_if_in_time():
assert results == [2, 2, -1, -1]


async def test_shutdown_run_callback_threadsafe(zha_gateway: Gateway) -> None:
"""Test we can shutdown run_callback_threadsafe."""
zha_async.shutdown_run_callback_threadsafe(zha_gateway.loop)
callback_fn = MagicMock()

with pytest.raises(RuntimeError):
zha_async.run_callback_threadsafe(zha_gateway.loop, callback_fn)


async def test_run_callback_threadsafe(zha_gateway: Gateway) -> None:
"""Test run_callback_threadsafe runs code in the event loop."""
it_ran = False

def callback_fn():
nonlocal it_ran
it_ran = True

assert zha_async.run_callback_threadsafe(zha_gateway.loop, callback_fn)
assert it_ran is False

# Verify that async_block_till_done will flush
# out the callback
await zha_gateway.async_block_till_done()
assert it_ran is True


async def test_run_callback_threadsafe_exception(zha_gateway: Gateway) -> None:
"""Test run_callback_threadsafe runs code in the event loop."""
it_ran = False

def callback_fn():
nonlocal it_ran
it_ran = True
raise ValueError("Test")

future = zha_async.run_callback_threadsafe(zha_gateway.loop, callback_fn)
assert future
assert it_ran is False

# Verify that async_block_till_done will flush
# out the callback
await zha_gateway.async_block_till_done()
assert it_ran is True

with pytest.raises(ValueError):
future.result()


async def test_callback_is_always_scheduled(zha_gateway: Gateway) -> None:
"""Test run_callback_threadsafe always calls call_soon_threadsafe before checking for shutdown."""
# We have to check the shutdown state AFTER the callback is scheduled otherwise
# the function could continue on and the caller call `future.result()` after
# the point in the main thread where callbacks are no longer run.

callback_fn = MagicMock()
zha_async.shutdown_run_callback_threadsafe(zha_gateway.loop)

with (
patch.object(
zha_gateway.loop, "call_soon_threadsafe"
) as mock_call_soon_threadsafe,
pytest.raises(RuntimeError),
):
zha_async.run_callback_threadsafe(zha_gateway.loop, callback_fn)

mock_call_soon_threadsafe.assert_called_once()


async def test_create_eager_task_312(zha_gateway: Gateway) -> None: # pylint: disable=unused-argument
"""Test create_eager_task schedules a task eagerly in the event loop.
Expand Down
Loading

0 comments on commit 0854912

Please sign in to comment.