Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add Client.apply method, adjust TaskStatus to properly capture exceptions #53

Merged
merged 8 commits into from
Jul 12, 2024
22 changes: 22 additions & 0 deletions docs/source/upcoming_release_notes/53-enh_client_apply.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
53 enh_client_apply
###################

API Breaks
----------
- N/A

Features
--------
- Implements `Client.apply` method for writing values from `Entry` data to the control system.

Bugfixes
--------
- N/A

Maintenance
-----------
- Adjusts `TaskStatus` use to properly capture exceptions.

Contributors
------------
- tangkong
120 changes: 112 additions & 8 deletions superscore/client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
"""Client for superscore. Used for programmatic interactions with superscore"""
from typing import Any, Generator
import logging
from typing import Any, Generator, List, Optional, Union
from uuid import UUID

from superscore.backends.core import _Backend
from superscore.model import Entry
from superscore.control_layers import ControlLayer
from superscore.control_layers.status import TaskStatus
from superscore.model import Entry, Setpoint, Snapshot

logger = logging.getLogger(__name__)


class Client:
backend: _Backend
cl: ControlLayer

def __init__(self, backend=None, **kwargs) -> None:
# if backend is None, startup default filestore backend
return
def __init__(self, backend: _Backend, **kwargs) -> None:
self.backend = backend
self.cl = ControlLayer()

@classmethod
def from_config(cls, cfg=None):
Expand All @@ -34,9 +41,106 @@ def compare(self, entry_l: Entry, entry_r: Entry) -> Any:
"""Compare two entries. Should be of same type, and return a diff"""
raise NotImplementedError

def apply(self, entry: Entry):
"""Apply settings found in ``entry``. If no values found, no-op"""
raise NotImplementedError
def apply(
self,
entry: Union[Setpoint, Snapshot],
sequential: bool = False
) -> Optional[List[TaskStatus]]:
"""
Apply settings found in ``entry``. If no writable values found, return.
If ``sequential`` is True, apply values in ``entry`` in sequence, blocking
with each put request. Else apply all values simultaneously (asynchronously)

Returns
tangkong marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
entry : Union[Setpoint, Snapshot]
The entry to apply values from
sequential : bool, optional
Whether to apply values sequentially, by default False

Returns
-------
Optional[List[TaskStatus]]
TaskStatus(es) for each value applied.
"""
if not isinstance(entry, (Setpoint, Snapshot)):
logger.info("Entries must be a Snapshot or Setpoint")
return

if isinstance(entry, Setpoint):
return [self.cl.put(entry.pv_name, entry.data)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the extra wrapping of list/not list here, I have to wonder if it would have been simpler for put to always return a list of statuses, even if we only put one PV

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had this thought too. In the end I decided that if I, as a user of the Client, put to a single PV and had to index into a single-element list, I'd be annoyed.

But maybe the users and I can just deal with it 🤷

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either way is probably ok, your explanation here makes sense to me


# Gather pv-value list and apply at once
status_list = []
pv_list, data_list = self._gather_data(entry)
if sequential:
for pv, data in zip(pv_list, data_list):
logger.debug(f'Putting {pv} = {data}')
status: TaskStatus = self.cl.put(pv, data)
if status.exception():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is status guaranteed to be complete after self.cl.put returns it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does currently, since we asyncio.run and block until the coroutine is finished. I'm not entirely convinced this is the best way to do things, but returning the status immediately would require us either

  • be in an async context
  • run the event loop in a separate thread (ophyd style).

I've been ok with this for now since we have the ability to run parallel puts (and wait on all of them), but it has been something I've been pondering changing

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is totally ok, I wanted to make sure you hadn't skip a wait call here (since wait was added in this PR).
It's not so strange to require more nuanced asynchronous behavior to be done in an async context.

logger.warning(f"Failed to put {pv} = {data}, "
"terminating put sequence")
return

status_list.append(status)
else:
return self.cl.put(pv_list, data_list)

def _gather_data(
self,
entry: Union[Setpoint, Snapshot, UUID],
pv_list: Optional[List[str]] = None,
data_list: Optional[List[Any]] = None
) -> Optional[tuple[List[str], List[Any]]]:
"""
Gather writable pv name - data pairs recursively.
If pv_list and data_list are provided, gathered data will be added to
these lists in-place. If both lists are omitted, this function will return
the two lists after gathering.

Queries the backend to fill any UUID values found.

Parameters
----------
entry : Union[Setpoint, Snapshot, UUID]
Entry to gather writable data from
pv_list : Optional[List[str]], optional
List of addresses to write data to, by default None
data_list : Optional[List[Any]], optional
List of data to write to addresses in ``pv_list``, by default None

Returns
-------
Optional[tuple[List[str], List[Any]]]
the filled pv_list and data_list
"""
top_level = False
if (pv_list is None) and (data_list is None):
pv_list = []
data_list = []
top_level = True
elif (pv_list is None) or (data_list is None):
raise ValueError(
"Arguments pv_list and data_list must either both be provided "
"or both omitted."
)

if isinstance(entry, Snapshot):
for child in entry.children:
self._gather_data(child, pv_list, data_list)
elif isinstance(entry, UUID):
child_entry = self.backend.get_entry(entry)
self._gather_data(child_entry, pv_list, data_list)
elif isinstance(entry, Setpoint):
pv_list.append(entry.pv_name)
data_list.append(entry.data)

# Readbacks are not writable, and are not gathered

if top_level:
return pv_list, data_list

def validate(self, entry: Entry):
"""
Expand Down
1 change: 1 addition & 0 deletions superscore/control_layers/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .core import ControlLayer # noqa
from .status import TaskStatus # noqa
6 changes: 3 additions & 3 deletions superscore/control_layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ async def status_coro():
status = self._put_one(address, value)
if cb is not None:
status.add_callback(cb)
await status.task
await asyncio.gather(status, return_exceptions=True)
return status

return asyncio.run(status_coro())
Expand Down Expand Up @@ -185,15 +185,15 @@ async def status_coros():
status.add_callback(c)

statuses.append(status)
await asyncio.gather(*[s.task for s in statuses])
await asyncio.gather(*statuses, return_exceptions=True)
return statuses

return asyncio.run(status_coros())

@TaskStatus.wrap
async def _put_one(self, address: str, value: Any):
"""
Base async get function. Use this to construct higher-level get methods
Base async put function. Use this to construct higher-level put methods
"""
shim = self.shim_from_pv(address)
await shim.put(address, value)
Expand Down
27 changes: 25 additions & 2 deletions superscore/control_layers/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
class TaskStatus:
"""
Unified Status object for wrapping task completion information and attaching
callbacks This must be created inside of a coroutine, but can be returned to
synchronous scope for examining the task
callbacks. This must be created inside of a coroutine, but can be returned to
synchronous scope for examining the task.

Awaiting this status is similar to awaiting the wrapped task.

Largely vendored from bluesky/ophyd-async
"""
Expand Down Expand Up @@ -57,6 +59,27 @@ def success(self) -> bool:
and self.task.exception() is None
)

def wait(self, timeout=None) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small documentation thing: you've added this wait function here and not documented its inclusion in the release notes or in the PR text. It's also not used anywhere outside the test suite yet.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought I'd need it, then I realized the full ramifications of the statuses I'd created 😆

"""
Block until the coroutine finishes. Raises asyncio.TimeoutError if
the timeout elapses before the task is completed

To be called in a synchronous context, if the status has not been awaited

Parameters
----------
timeout : number, optional
timeout in seconds, by default None

Raises
------
asyncio.TimeoutError
"""
# ensure task runs in the event loop it was assigned to originally
asyncio.get_event_loop().run_until_complete(
asyncio.wait_for(self.task, timeout)
)

def __repr__(self) -> str:
if self.done:
if e := self.exception():
Expand Down
27 changes: 27 additions & 0 deletions superscore/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import shutil
from pathlib import Path
from typing import List
from unittest.mock import MagicMock

import pytest

from superscore.backends.core import _Backend
from superscore.backends.filestore import FilestoreBackend
from superscore.backends.test import TestBackend
from superscore.client import Client
from superscore.control_layers._base_shim import _BaseShim
from superscore.control_layers.core import ControlLayer
from superscore.model import (Collection, Parameter, Readback, Root, Setpoint,
Expand Down Expand Up @@ -697,3 +699,28 @@ def dummy_cl() -> ControlLayer:
cl.shims['ca'] = DummyShim()
cl.shims['pva'] = DummyShim()
return cl


@pytest.fixture(scope='function')
def mock_backend() -> _Backend:
bk = _Backend()
bk.delete_entry = MagicMock()
bk.save_entry = MagicMock()
bk.get_entry = MagicMock()
bk.search = MagicMock()
bk.update_entry = MagicMock()


class MockTaskStatus:
def exception(self):
return None

@property
def done(self):
return True


@pytest.fixture(scope='function')
def mock_client(mock_backend: _Backend) -> Client:
client = Client(backend=mock_backend)
return client
8 changes: 5 additions & 3 deletions superscore/tests/test_cl.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ def test_fail(dummy_cl):
mock_ca_put = AsyncMock(side_effect=ValueError)
dummy_cl.shims['ca'].put = mock_ca_put

# exceptions get passed through the control layer
with pytest.raises(ValueError):
dummy_cl.put("THAT:PV", 4)
# exceptions get captured in status object
status = dummy_cl.put("THAT:PV", 4)
assert isinstance(status.exception(), ValueError)

assert mock_ca_put.called


def test_put_callback(dummy_cl):
Expand Down
21 changes: 21 additions & 0 deletions superscore/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from unittest.mock import patch

from superscore.client import Client
from superscore.model import Root

from .conftest import MockTaskStatus


@patch('superscore.control_layers.core.ControlLayer.put')
def test_apply(put_mock, mock_client: Client, sample_database: Root):
put_mock.return_value = MockTaskStatus()
snap = sample_database.entries[3]
mock_client.apply(snap)
assert put_mock.call_count == 1
call_args = put_mock.call_args[0]
assert len(call_args[0]) == len(call_args[1]) == 3

put_mock.reset_mock()

mock_client.apply(snap, sequential=True)
assert put_mock.call_count == 3
42 changes: 40 additions & 2 deletions superscore/tests/test_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@ async def inner_coroutine():
return inner_coroutine


@pytest.fixture
async def long_coroutine_status() -> TaskStatus:
@TaskStatus.wrap
async def inner_coroutine():
for i in range(100):
print(f'coro wait: {i}')
await asyncio.sleep(1)

return inner_coroutine()


async def test_status_success(normal_coroutine):
st = TaskStatus(normal_coroutine())
assert isinstance(st, TaskStatus)
Expand All @@ -40,7 +51,34 @@ async def test_status_fail(failing_coroutine):
with pytest.raises(ValueError):
await status

assert type(status.exception()) == ValueError
assert isinstance(status.exception(), ValueError)
shilorigins marked this conversation as resolved.
Show resolved Hide resolved


def test_sync_status_fail(failing_coroutine):
# A usage note for the curious. If we gather these tasks with
# `return_exceptions` = False (default), the first exception will be propagated,
# though the other tasks will complete. This may stop tasks from being returned
# `retur_exceptions` = True will not raise exceptions, instead those exceptions
# will only be captured in `task.exception()`
async def wrap_coro(return_exc: bool):
status = TaskStatus(failing_coroutine())
await asyncio.gather(status, return_exceptions=return_exc)
return status

status = asyncio.run(wrap_coro(True))
assert status.done
assert isinstance(status.exception(), ValueError)

with pytest.raises(ValueError):
asyncio.run(wrap_coro(False))


def test_status_wait(long_coroutine_status):
assert not long_coroutine_status.done
with pytest.raises(asyncio.TimeoutError):
long_coroutine_status.wait(1)
assert long_coroutine_status.done
assert isinstance(long_coroutine_status.exception(), asyncio.CancelledError)


async def test_status_wrap():
Expand All @@ -50,5 +88,5 @@ async def coro_status():

st = coro_status()
assert isinstance(st, TaskStatus)
await st.task
await st
assert st.done