Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Report Device name in error when using AsyncStatus.wrap #607

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions src/ophyd_async/core/_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from bluesky.protocols import Status

from ._device import Device
from ._protocol import Watcher
from ._utils import Callback, P, T, WatcherUpdate

Expand All @@ -23,13 +24,14 @@
class AsyncStatusBase(Status):
"""Convert asyncio awaitable to bluesky Status interface"""

def __init__(self, awaitable: Coroutine | asyncio.Task):
def __init__(self, awaitable: Coroutine | asyncio.Task, name: str | None = None):
if isinstance(awaitable, asyncio.Task):
self.task = awaitable
else:
self.task = asyncio.create_task(awaitable)
self.task.add_done_callback(self._run_callbacks)
self._callbacks: list[Callback[Status]] = []
self._name = name

def __await__(self):
return self.task.__await__()
Expand Down Expand Up @@ -76,21 +78,27 @@ def __repr__(self) -> str:
status = "done"
else:
status = "pending"
return f"<{type(self).__name__}, task: {self.task.get_coro()}, {status}>"
device_str = f"device: {self._name}, " if self._name else ""
return (
f"<{type(self).__name__}, {device_str}"
f"task: {self.task.get_coro()}, {status}>"
)

__str__ = __repr__


class AsyncStatus(AsyncStatusBase):
"""Convert asyncio awaitable to bluesky Status interface"""

@classmethod
def wrap(cls: type[AS], f: Callable[P, Coroutine]) -> Callable[P, AS]:
"""Wrap an async function in an AsyncStatus."""

@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS:
return cls(f(*args, **kwargs))
if args and isinstance(args[0], Device):
name = args[0].name
else:
name = None
return cls(f(*args, **kwargs), name=name)

# type is actually functools._Wrapped[P, Awaitable, P, AS]
# but functools._Wrapped is not necessarily available
Expand All @@ -100,11 +108,13 @@ def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS:
class WatchableAsyncStatus(AsyncStatusBase, Generic[T]):
"""Convert AsyncIterator of WatcherUpdates to bluesky Status interface."""

def __init__(self, iterator: AsyncIterator[WatcherUpdate[T]]):
def __init__(
self, iterator: AsyncIterator[WatcherUpdate[T]], name: str | None = None
):
self._watchers: list[Watcher] = []
self._start = time.monotonic()
self._last_update: WatcherUpdate[T] | None = None
super().__init__(self._notify_watchers_from(iterator))
super().__init__(self._notify_watchers_from(iterator), name)

async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]):
async for update in iterator:
Expand Down Expand Up @@ -136,7 +146,11 @@ def wrap(

@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS:
return cls(f(*args, **kwargs))
if args and isinstance(args[0], Device):
name = args[0].name
else:
name = None
return cls(f(*args, **kwargs), name=name)

return cast(Callable[P, WAS], wrap_f)

Expand Down
21 changes: 15 additions & 6 deletions tests/core/test_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,14 @@ class FailingMovable(Movable, Device):
def _fail(self):
raise ValueError("This doesn't work")

async def _set(self, value):
@AsyncStatus.wrap
async def set(self, value):
if value:
self._fail()

def set(self, value) -> AsyncStatus:
return AsyncStatus(self._set(value))
return self._fail()


async def test_status_propogates_traceback_under_RE(RE) -> None:
expected_call_stack = ["_set", "_fail"]
expected_call_stack = ["set", "_fail"]
d = FailingMovable()
with pytest.raises(FailedStatus) as ctx:
RE(bps.mv(d, 3))
Expand Down Expand Up @@ -203,3 +201,14 @@ async def test_completed_status():
with pytest.raises(ValueError):
await completed_status(ValueError())
await completed_status()


async def test_device_name_in_failure_message_AsyncStatus_wrap(RE):
device_name = "MyFailingMovable"
d = FailingMovable(name=device_name)
with pytest.raises(FailedStatus) as ctx:
RE(bps.mv(d, 3))
# FailingMovable.set is decorated with @AsyncStatus.wrap
# undecorated methods will not print the device name
status: AsyncStatus = ctx.value.args[0]
assert f"device: {device_name}" in repr(status)
Loading