diff --git a/src/ophyd_async/core/_status.py b/src/ophyd_async/core/_status.py index 93b988840..028a7324a 100644 --- a/src/ophyd_async/core/_status.py +++ b/src/ophyd_async/core/_status.py @@ -13,6 +13,7 @@ from bluesky.protocols import Status +from ._device import Device from ._protocol import Watcher from ._utils import Callback, P, T, WatcherUpdate @@ -23,13 +24,14 @@ class AsyncStatusBase(Status): """Convert asyncio awaitable to bluesky Status interface""" - def __init__(self, awaitable: Coroutine | asyncio.Task): + def __init__(self, awaitable: Coroutine | asyncio.Task, name: str | None = None): if isinstance(awaitable, asyncio.Task): self.task = awaitable else: self.task = asyncio.create_task(awaitable) self.task.add_done_callback(self._run_callbacks) self._callbacks: list[Callback[Status]] = [] + self._name = name def __await__(self): return self.task.__await__() @@ -76,21 +78,27 @@ def __repr__(self) -> str: status = "done" else: status = "pending" - return f"<{type(self).__name__}, task: {self.task.get_coro()}, {status}>" + device_str = f"device: {self._name}, " if self._name else "" + return ( + f"<{type(self).__name__}, {device_str}" + f"task: {self.task.get_coro()}, {status}>" + ) __str__ = __repr__ class AsyncStatus(AsyncStatusBase): - """Convert asyncio awaitable to bluesky Status interface""" - @classmethod def wrap(cls: type[AS], f: Callable[P, Coroutine]) -> Callable[P, AS]: """Wrap an async function in an AsyncStatus.""" @functools.wraps(f) def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS: - return cls(f(*args, **kwargs)) + if args and isinstance(args[0], Device): + name = args[0].name + else: + name = None + return cls(f(*args, **kwargs), name=name) # type is actually functools._Wrapped[P, Awaitable, P, AS] # but functools._Wrapped is not necessarily available @@ -100,11 +108,13 @@ def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS: class WatchableAsyncStatus(AsyncStatusBase, Generic[T]): """Convert AsyncIterator of WatcherUpdates to bluesky Status interface.""" - def __init__(self, iterator: AsyncIterator[WatcherUpdate[T]]): + def __init__( + self, iterator: AsyncIterator[WatcherUpdate[T]], name: str | None = None + ): self._watchers: list[Watcher] = [] self._start = time.monotonic() self._last_update: WatcherUpdate[T] | None = None - super().__init__(self._notify_watchers_from(iterator)) + super().__init__(self._notify_watchers_from(iterator), name) async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]): async for update in iterator: @@ -136,7 +146,11 @@ def wrap( @functools.wraps(f) def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS: - return cls(f(*args, **kwargs)) + if args and isinstance(args[0], Device): + name = args[0].name + else: + name = None + return cls(f(*args, **kwargs), name=name) return cast(Callable[P, WAS], wrap_f) diff --git a/tests/core/test_status.py b/tests/core/test_status.py index 263a39dca..109dfd4b9 100644 --- a/tests/core/test_status.py +++ b/tests/core/test_status.py @@ -118,16 +118,14 @@ class FailingMovable(Movable, Device): def _fail(self): raise ValueError("This doesn't work") - async def _set(self, value): + @AsyncStatus.wrap + async def set(self, value): if value: - self._fail() - - def set(self, value) -> AsyncStatus: - return AsyncStatus(self._set(value)) + return self._fail() async def test_status_propogates_traceback_under_RE(RE) -> None: - expected_call_stack = ["_set", "_fail"] + expected_call_stack = ["set", "_fail"] d = FailingMovable() with pytest.raises(FailedStatus) as ctx: RE(bps.mv(d, 3)) @@ -203,3 +201,14 @@ async def test_completed_status(): with pytest.raises(ValueError): await completed_status(ValueError()) await completed_status() + + +async def test_device_name_in_failure_message_AsyncStatus_wrap(RE): + device_name = "MyFailingMovable" + d = FailingMovable(name=device_name) + with pytest.raises(FailedStatus) as ctx: + RE(bps.mv(d, 3)) + # FailingMovable.set is decorated with @AsyncStatus.wrap + # undecorated methods will not print the device name + status: AsyncStatus = ctx.value.args[0] + assert f"device: {device_name}" in repr(status)