diff --git a/docs/source/traits_user_manual/notification.rst b/docs/source/traits_user_manual/notification.rst index 13ca695df..2d4f1de62 100644 --- a/docs/source/traits_user_manual/notification.rst +++ b/docs/source/traits_user_manual/notification.rst @@ -313,6 +313,39 @@ it is invoked. The following example shows the first option: :start-at: from traits.api +Async Notification Handlers +``````````````````````````` + +Since Traits 7.0 you can use an async coroutine as an observe handler, either +with an |@observe| decorator:: + + class AsyncExample(HasTraits): + value = Str() + + @observe('value') + async def _value_updated(self, event): + await asyncio.sleep(0) + print("value changed") + +or via the |HasTraits.observe| method:: + + async def async_observer(self, event): + await asyncio.sleep(0) + print("value changed") + + async_example = AsyncExample() + async_example.observe(async_observer, "value") + + +When a trait change event occurs which is observed by an async handler while +in an asyncio event loop, a task will be created to call the handler at a later +time. If the event loop is not running an exception will be raised. + +.. warning:: + + This is an experimental feature, and behavior may change in the future. + + Features and fixes provided by |@observe| ----------------------------------------- diff --git a/traits/observation/observe.py b/traits/observation/observe.py index dfab99761..eeb3a8e75 100644 --- a/traits/observation/observe.py +++ b/traits/observation/observe.py @@ -8,23 +8,39 @@ # # Thanks for using Enthought open source! +import asyncio +import inspect + from traits.observation._observe import add_or_remove_notifiers from traits.observation.expression import compile_expr +#: Set to hold references to active async traits handlers. +_active_handler_tasks = set() + def dispatch_same(handler, event): """ Dispatch an event handler on the same thread. + This dispatcher accepts both callables and async callables, the latter + being dispatched asynchronously via an async Task. Asynchronous dispatch + is only available when an async event loop is running; it will raise if + it cannot create an async Task. + Parameters ---------- - handler : callable(event) + handler : callable(event) or async callable(event) User-defined callable to handle change events. ``event`` is an object representing the change. Its type and content depends on the change. event : object The event object to be given to handler. """ - handler(event) + if inspect.iscoroutinefunction(handler): + task = asyncio.create_task(handler(event)) + _active_handler_tasks.add(task) + task.add_done_callback(_active_handler_tasks.discard) + else: + handler(event) def observe( diff --git a/traits/observation/tests/test_observe.py b/traits/observation/tests/test_observe.py index 29f00b2c5..436987965 100644 --- a/traits/observation/tests/test_observe.py +++ b/traits/observation/tests/test_observe.py @@ -8,6 +8,8 @@ # # Thanks for using Enthought open source! +import asyncio +from contextlib import contextmanager import unittest from unittest import mock @@ -653,3 +655,70 @@ def test_apply_observers_different_target(self): # then # the handler should be called twice as the targets are different. self.assertEqual(handler.call_count, 2) + + +# ---- Low-level tests for async dispatch_same ------------------------------ + + +class TestAsyncDispatchSame(unittest.IsolatedAsyncioTestCase): + """Test low-level async dispatch.""" + + def setUp(self): + from traits.observation.observe import _active_handler_tasks + + # ensure no lingering references to handler tasks after test run + self.addCleanup(_active_handler_tasks.clear) + + push_exception_handler(reraise_exceptions=True) + self.addCleanup(pop_exception_handler) + + async def test_async_dispatch(self): + event = asyncio.Event() + + async def handler(event): + event.set() + + dispatch_same(handler, event) + + await asyncio.wait_for(event.wait(), timeout=10) + + self.assertTrue(event.is_set()) + + async def test_async_dispatch_error(self): + event = asyncio.Event() + exceptions = [] + + async def handler(event): + raise Exception("Bad handler") + + def exception_handler(loop, context): + exceptions.append(context["exception"].args[0]) + event.set() + + with self.asyncio_exception_handler(exception_handler): + dispatch_same(handler, event) + await asyncio.wait_for(event.wait(), timeout=10.0) + + self.assertEqual(exceptions, ["Bad handler"]) + + def test_async_dispatch_no_loop(self): + event = asyncio.Event() + + async def handler(event): + event.set() + + with self.assertWarns(RuntimeWarning): + with self.assertRaises(RuntimeError): + dispatch_same(handler, event) + + self.assertFalse(event.is_set()) + + @contextmanager + def asyncio_exception_handler(self, exc_handler): + loop = asyncio.get_event_loop() + old_handler = loop.get_exception_handler() + loop.set_exception_handler(exc_handler) + try: + yield exc_handler + finally: + loop.set_exception_handler(old_handler) diff --git a/traits/tests/test_observe.py b/traits/tests/test_observe.py index bd0460f8f..266f5daa1 100644 --- a/traits/tests/test_observe.py +++ b/traits/tests/test_observe.py @@ -12,6 +12,7 @@ See tests in ``traits.observations`` for more targeted tests. """ +import asyncio import unittest from traits.api import ( @@ -930,3 +931,68 @@ class A(HasTraits): self.assertEqual(event.index, 2) self.assertEqual(event.removed, [3]) self.assertEqual(event.added, [4]) + + +# Integration tests for async observe decorator ------------------------------- + + +class SimpleAsyncExample(HasTraits): + + value = Str() + + events = List() + + event = Instance(asyncio.Event) + + queue = Instance(asyncio.Queue) + + @observe('value') + async def value_changed_async(self, event): + queue_value = await self.queue.get() + self.events.append((event, queue_value)) + self.event.set() + + +class TestAsyncObserverDecorator(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + from traits.observation.observe import _active_handler_tasks + + # ensure no lingering references to handler tasks after test run + self.addCleanup(_active_handler_tasks.clear) + + async def test_async_dispatch(self): + event = asyncio.Event() + queue = asyncio.Queue() + + obj = SimpleAsyncExample(value='initial', event=event, queue=queue) + + self.assertEqual(len(obj.events), 0) + + task = asyncio.create_task(queue.put("first")) + + await asyncio.wait_for(event.wait(), timeout=10) + + self.assertTrue(task.done()) + self.assertEqual(len(obj.events), 1) + trait_event, queue_value = obj.events[0] + self.assertEqual(trait_event.name, 'value') + self.assertEqual(trait_event.new, 'initial') + self.assertEqual(queue_value, 'first') + + event.clear() + + obj.value = 'changed' + + self.assertEqual(len(obj.events), 1) + + task = asyncio.create_task(queue.put("second")) + + await asyncio.wait_for(event.wait(), timeout=10) + + self.assertTrue(task.done()) + self.assertEqual(len(obj.events), 2) + trait_event, queue_value = obj.events[1] + self.assertEqual(trait_event.name, 'value') + self.assertEqual(trait_event.new, 'changed') + self.assertEqual(queue_value, 'second')