diff --git a/docs/source/upcoming_release_notes/53-enh_client_apply.rst b/docs/source/upcoming_release_notes/53-enh_client_apply.rst new file mode 100644 index 0000000..efe5efa --- /dev/null +++ b/docs/source/upcoming_release_notes/53-enh_client_apply.rst @@ -0,0 +1,22 @@ +53 enh_client_apply +################### + +API Breaks +---------- +- N/A + +Features +-------- +- Implements `Client.apply` method for writing values from `Entry` data to the control system. + +Bugfixes +-------- +- N/A + +Maintenance +----------- +- Adjusts `TaskStatus` use to properly capture exceptions. + +Contributors +------------ +- tangkong diff --git a/superscore/client.py b/superscore/client.py index 0482c5f..f306aee 100644 --- a/superscore/client.py +++ b/superscore/client.py @@ -1,16 +1,23 @@ """Client for superscore. Used for programmatic interactions with superscore""" -from typing import Any, Generator +import logging +from typing import Any, Generator, List, Optional, Union +from uuid import UUID from superscore.backends.core import _Backend -from superscore.model import Entry +from superscore.control_layers import ControlLayer +from superscore.control_layers.status import TaskStatus +from superscore.model import Entry, Setpoint, Snapshot + +logger = logging.getLogger(__name__) class Client: backend: _Backend + cl: ControlLayer - def __init__(self, backend=None, **kwargs) -> None: - # if backend is None, startup default filestore backend - return + def __init__(self, backend: _Backend, **kwargs) -> None: + self.backend = backend + self.cl = ControlLayer() @classmethod def from_config(cls, cfg=None): @@ -34,9 +41,104 @@ def compare(self, entry_l: Entry, entry_r: Entry) -> Any: """Compare two entries. Should be of same type, and return a diff""" raise NotImplementedError - def apply(self, entry: Entry): - """Apply settings found in ``entry``. If no values found, no-op""" - raise NotImplementedError + def apply( + self, + entry: Union[Setpoint, Snapshot], + sequential: bool = False + ) -> Optional[List[TaskStatus]]: + """ + Apply settings found in ``entry``. If no writable values found, return. + If ``sequential`` is True, apply values in ``entry`` in sequence, blocking + with each put request. Else apply all values simultaneously (asynchronously) + + Parameters + ---------- + entry : Union[Setpoint, Snapshot] + The entry to apply values from + sequential : bool, optional + Whether to apply values sequentially, by default False + + Returns + ------- + Optional[List[TaskStatus]] + TaskStatus(es) for each value applied. + """ + if not isinstance(entry, (Setpoint, Snapshot)): + logger.info("Entries must be a Snapshot or Setpoint") + return + + if isinstance(entry, Setpoint): + return [self.cl.put(entry.pv_name, entry.data)] + + # Gather pv-value list and apply at once + status_list = [] + pv_list, data_list = self._gather_data(entry) + if sequential: + for pv, data in zip(pv_list, data_list): + logger.debug(f'Putting {pv} = {data}') + status: TaskStatus = self.cl.put(pv, data) + if status.exception(): + logger.warning(f"Failed to put {pv} = {data}, " + "terminating put sequence") + return + + status_list.append(status) + else: + return self.cl.put(pv_list, data_list) + + def _gather_data( + self, + entry: Union[Setpoint, Snapshot, UUID], + pv_list: Optional[List[str]] = None, + data_list: Optional[List[Any]] = None + ) -> Optional[tuple[List[str], List[Any]]]: + """ + Gather writable pv name - data pairs recursively. + If pv_list and data_list are provided, gathered data will be added to + these lists in-place. If both lists are omitted, this function will return + the two lists after gathering. + + Queries the backend to fill any UUID values found. + + Parameters + ---------- + entry : Union[Setpoint, Snapshot, UUID] + Entry to gather writable data from + pv_list : Optional[List[str]], optional + List of addresses to write data to, by default None + data_list : Optional[List[Any]], optional + List of data to write to addresses in ``pv_list``, by default None + + Returns + ------- + Optional[tuple[List[str], List[Any]]] + the filled pv_list and data_list + """ + top_level = False + if (pv_list is None) and (data_list is None): + pv_list = [] + data_list = [] + top_level = True + elif (pv_list is None) or (data_list is None): + raise ValueError( + "Arguments pv_list and data_list must either both be provided " + "or both omitted." + ) + + if isinstance(entry, Snapshot): + for child in entry.children: + self._gather_data(child, pv_list, data_list) + elif isinstance(entry, UUID): + child_entry = self.backend.get_entry(entry) + self._gather_data(child_entry, pv_list, data_list) + elif isinstance(entry, Setpoint): + pv_list.append(entry.pv_name) + data_list.append(entry.data) + + # Readbacks are not writable, and are not gathered + + if top_level: + return pv_list, data_list def validate(self, entry: Entry): """ diff --git a/superscore/control_layers/__init__.py b/superscore/control_layers/__init__.py index af7fe27..877e1ae 100644 --- a/superscore/control_layers/__init__.py +++ b/superscore/control_layers/__init__.py @@ -1 +1,2 @@ from .core import ControlLayer # noqa +from .status import TaskStatus # noqa diff --git a/superscore/control_layers/core.py b/superscore/control_layers/core.py index 66391e0..1fda04b 100644 --- a/superscore/control_layers/core.py +++ b/superscore/control_layers/core.py @@ -145,7 +145,7 @@ async def status_coro(): status = self._put_one(address, value) if cb is not None: status.add_callback(cb) - await status.task + await asyncio.gather(status, return_exceptions=True) return status return asyncio.run(status_coro()) @@ -185,7 +185,7 @@ async def status_coros(): status.add_callback(c) statuses.append(status) - await asyncio.gather(*[s.task for s in statuses]) + await asyncio.gather(*statuses, return_exceptions=True) return statuses return asyncio.run(status_coros()) @@ -193,7 +193,7 @@ async def status_coros(): @TaskStatus.wrap async def _put_one(self, address: str, value: Any): """ - Base async get function. Use this to construct higher-level get methods + Base async put function. Use this to construct higher-level put methods """ shim = self.shim_from_pv(address) await shim.put(address, value) diff --git a/superscore/control_layers/status.py b/superscore/control_layers/status.py index c463a1f..4186509 100644 --- a/superscore/control_layers/status.py +++ b/superscore/control_layers/status.py @@ -10,8 +10,10 @@ class TaskStatus: """ Unified Status object for wrapping task completion information and attaching - callbacks This must be created inside of a coroutine, but can be returned to - synchronous scope for examining the task + callbacks. This must be created inside of a coroutine, but can be returned to + synchronous scope for examining the task. + + Awaiting this status is similar to awaiting the wrapped task. Largely vendored from bluesky/ophyd-async """ @@ -57,6 +59,27 @@ def success(self) -> bool: and self.task.exception() is None ) + def wait(self, timeout=None) -> None: + """ + Block until the coroutine finishes. Raises asyncio.TimeoutError if + the timeout elapses before the task is completed + + To be called in a synchronous context, if the status has not been awaited + + Parameters + ---------- + timeout : number, optional + timeout in seconds, by default None + + Raises + ------ + asyncio.TimeoutError + """ + # ensure task runs in the event loop it was assigned to originally + asyncio.get_event_loop().run_until_complete( + asyncio.wait_for(self.task, timeout) + ) + def __repr__(self) -> str: if self.done: if e := self.exception(): diff --git a/superscore/tests/conftest.py b/superscore/tests/conftest.py index 372d6ca..d30ff04 100644 --- a/superscore/tests/conftest.py +++ b/superscore/tests/conftest.py @@ -1,12 +1,14 @@ import shutil from pathlib import Path from typing import List +from unittest.mock import MagicMock import pytest from superscore.backends.core import _Backend from superscore.backends.filestore import FilestoreBackend from superscore.backends.test import TestBackend +from superscore.client import Client from superscore.control_layers._base_shim import _BaseShim from superscore.control_layers.core import ControlLayer from superscore.model import (Collection, Parameter, Readback, Root, Setpoint, @@ -697,3 +699,28 @@ def dummy_cl() -> ControlLayer: cl.shims['ca'] = DummyShim() cl.shims['pva'] = DummyShim() return cl + + +@pytest.fixture(scope='function') +def mock_backend() -> _Backend: + bk = _Backend() + bk.delete_entry = MagicMock() + bk.save_entry = MagicMock() + bk.get_entry = MagicMock() + bk.search = MagicMock() + bk.update_entry = MagicMock() + + +class MockTaskStatus: + def exception(self): + return None + + @property + def done(self): + return True + + +@pytest.fixture(scope='function') +def mock_client(mock_backend: _Backend) -> Client: + client = Client(backend=mock_backend) + return client diff --git a/superscore/tests/test_cl.py b/superscore/tests/test_cl.py index a135112..cc8dfee 100644 --- a/superscore/tests/test_cl.py +++ b/superscore/tests/test_cl.py @@ -38,9 +38,11 @@ def test_fail(dummy_cl): mock_ca_put = AsyncMock(side_effect=ValueError) dummy_cl.shims['ca'].put = mock_ca_put - # exceptions get passed through the control layer - with pytest.raises(ValueError): - dummy_cl.put("THAT:PV", 4) + # exceptions get captured in status object + status = dummy_cl.put("THAT:PV", 4) + assert isinstance(status.exception(), ValueError) + + assert mock_ca_put.called def test_put_callback(dummy_cl): diff --git a/superscore/tests/test_client.py b/superscore/tests/test_client.py new file mode 100644 index 0000000..3152d7a --- /dev/null +++ b/superscore/tests/test_client.py @@ -0,0 +1,21 @@ +from unittest.mock import patch + +from superscore.client import Client +from superscore.model import Root + +from .conftest import MockTaskStatus + + +@patch('superscore.control_layers.core.ControlLayer.put') +def test_apply(put_mock, mock_client: Client, sample_database: Root): + put_mock.return_value = MockTaskStatus() + snap = sample_database.entries[3] + mock_client.apply(snap) + assert put_mock.call_count == 1 + call_args = put_mock.call_args[0] + assert len(call_args[0]) == len(call_args[1]) == 3 + + put_mock.reset_mock() + + mock_client.apply(snap, sequential=True) + assert put_mock.call_count == 3 diff --git a/superscore/tests/test_status.py b/superscore/tests/test_status.py index e2678bc..1a83433 100644 --- a/superscore/tests/test_status.py +++ b/superscore/tests/test_status.py @@ -23,6 +23,17 @@ async def inner_coroutine(): return inner_coroutine +@pytest.fixture +async def long_coroutine_status() -> TaskStatus: + @TaskStatus.wrap + async def inner_coroutine(): + for i in range(100): + print(f'coro wait: {i}') + await asyncio.sleep(1) + + return inner_coroutine() + + async def test_status_success(normal_coroutine): st = TaskStatus(normal_coroutine()) assert isinstance(st, TaskStatus) @@ -40,7 +51,34 @@ async def test_status_fail(failing_coroutine): with pytest.raises(ValueError): await status - assert type(status.exception()) == ValueError + assert isinstance(status.exception(), ValueError) + + +def test_sync_status_fail(failing_coroutine): + # A usage note for the curious. If we gather these tasks with + # `return_exceptions` = False (default), the first exception will be propagated, + # though the other tasks will complete. This may stop tasks from being returned + # `retur_exceptions` = True will not raise exceptions, instead those exceptions + # will only be captured in `task.exception()` + async def wrap_coro(return_exc: bool): + status = TaskStatus(failing_coroutine()) + await asyncio.gather(status, return_exceptions=return_exc) + return status + + status = asyncio.run(wrap_coro(True)) + assert status.done + assert isinstance(status.exception(), ValueError) + + with pytest.raises(ValueError): + asyncio.run(wrap_coro(False)) + + +def test_status_wait(long_coroutine_status): + assert not long_coroutine_status.done + with pytest.raises(asyncio.TimeoutError): + long_coroutine_status.wait(1) + assert long_coroutine_status.done + assert isinstance(long_coroutine_status.exception(), asyncio.CancelledError) async def test_status_wrap(): @@ -50,5 +88,5 @@ async def coro_status(): st = coro_status() assert isinstance(st, TaskStatus) - await st.task + await st assert st.done