diff --git a/doc/callbacks/plotting.md b/doc/callbacks/plotting.md index 5bfffba5..899e94d0 100644 --- a/doc/callbacks/plotting.md +++ b/doc/callbacks/plotting.md @@ -31,15 +31,23 @@ to plot a scan with a logarithmically-scaled y-axis: ```python import matplotlib.pyplot as plt from ibex_bluesky_core.callbacks.plotting import LivePlot -# Create a new figure to plot onto. -plt.figure() -# Make a new set of axes on that figure -ax = plt.gca() -# Set the y-scale to logarithmic -ax.set_yscale("log") -# Use the above axes in a LivePlot callback -plot_callback = LivePlot(y="y_variable", x="x_variable", ax=ax, yerr="yerr_variable") -# yerr is the uncertanties of each y value, producing error bars +from ibex_bluesky_core.plan_stubs import call_qt_aware + +def plan(): + # Create a new figure to plot onto. + yield from call_qt_aware(plt.figure) + # Make a new set of axes on that figure + ax = yield from call_qt_aware(plt.gca) + # Set the y-scale to logarithmic + yield from call_qt_aware(ax.set_yscale, "log") + # Use the above axes in a LivePlot callback + plot_callback = LivePlot(y="y_variable", x="x_variable", ax=ax, yerr="yerr_variable") + # yerr is the uncertanties of each y value, producing error bars +``` + +```{note} +See [docs for `call_qt_aware`](../plan_stubs/matplotlib_helpers.md) for a description of why we need to use +`yield from call_qt_aware` rather than calling `matplotlib` functions directly. ``` By providing a signal name to the `yerr` argument you can pass uncertainties to LivePlot, by not providing anything for this argument means that no errorbars will be drawn. Errorbars are drawn after each point collected, displaying their standard deviation- uncertainty data is collected from Bluesky event documents and errorbars are updated after every new point added. diff --git a/doc/devices/blocks.md b/doc/devices/blocks.md index 93fa40ce..9788e7cd 100644 --- a/doc/devices/blocks.md +++ b/doc/devices/blocks.md @@ -131,12 +131,16 @@ from ibex_bluesky_core.devices.block import block_mot mot_block = block_mot("motor_block") ``` +A motor block does not need an explicit write config: it always waits for the requested motion +to complete. See {py:obj}`ibex_bluesky_core.devices.block.BlockMot` for a detailed mapping of +the usual write-configuration options and how these are instead achieved by a motor block. + ## Configuring block write behaviour `BlockRw` and `BlockRwRbv` both take a `write_config` argument, which can be used to configure the behaviour on writing to a block, for example tolerances and settle times. -See the docstring on `ibex_bluesky_core.devices.block.BlockWriteConfig` for a detailed +See {py:class}`ibex_bluesky_core.devices.block.BlockWriteConfig` for a detailed description of all the options which are available. ## Run control diff --git a/doc/plan_stubs/matplotlib_helpers.md b/doc/plan_stubs/matplotlib_helpers.md new file mode 100644 index 00000000..74d66b1f --- /dev/null +++ b/doc/plan_stubs/matplotlib_helpers.md @@ -0,0 +1,50 @@ +# `call_qt_aware` (matplotlib helpers) + +When attempting to use `matplotlib` UI functions directly in a plan, and running `matplotlib` using a `Qt` +backend (e.g. in a standalone shell outside IBEX), you may see a hang or an error of the form: + +``` +UserWarning: Starting a Matplotlib GUI outside of the main thread will likely fail. + fig, ax = plt.subplots() +``` + +This is because the `RunEngine` runs plans in a worker thread, not in the main thread, which then requires special +handling when calling functions that will update a UI. + +The {py:obj}`ibex_bluesky_core.plan_stubs.call_qt_aware` plan stub can call `matplotlib` functions in a +Qt-aware context, which allows them to be run directly from a plan. It allows the same arguments and +keyword-arguments as the underlying matplotlib function it is passed. + +```{note} +Callbacks such as `LivePlot` and `LiveFitPlot` already route UI calls to the appropriate UI thread by default. +The following plan stubs are only necessary if you need to call functions which will create or update a matplotlib +plot from a plan directly - for example to create or close a set of axes before passing them to callbacks. +``` + +Usage example: + +```python +import matplotlib.pyplot as plt +from ibex_bluesky_core.plan_stubs import call_qt_aware +from ibex_bluesky_core.callbacks.plotting import LivePlot +from bluesky.callbacks import LiveFitPlot +from bluesky.preprocessors import subs_decorator + + +def my_plan(): + # BAD - likely to either crash or hang the plan. + # plt.close("all") + # fig, ax = plt.subplots() + + # GOOD + yield from call_qt_aware(plt.close, "all") + fig, ax = yield from call_qt_aware(plt.subplots) + + # Pass the matplotlib ax object to other callbacks + @subs_decorator([ + LiveFitPlot(..., ax=ax), + LivePlot(..., ax=ax), + ]) + def inner_plan(): + ... +``` diff --git a/manual_system_tests/dae_scan.py b/manual_system_tests/dae_scan.py index dd8a55c4..e488c1f8 100644 --- a/manual_system_tests/dae_scan.py +++ b/manual_system_tests/dae_scan.py @@ -1,8 +1,8 @@ """Demonstration plan showing basic bluesky functionality.""" import os +from collections.abc import Generator from pathlib import Path -from typing import Generator import bluesky.plan_stubs as bps import bluesky.plans as bp @@ -28,6 +28,7 @@ GoodFramesNormalizer, ) from ibex_bluesky_core.devices.simpledae.waiters import GoodFramesWaiter +from ibex_bluesky_core.plan_stubs import call_qt_aware from ibex_bluesky_core.run_engine import get_run_engine NUM_POINTS: int = 3 @@ -72,7 +73,8 @@ def dae_scan_plan() -> Generator[Msg, None, None]: controller.run_number.set_name("run number") reducer.intensity.set_name("normalized counts") - _, ax = plt.subplots() + _, ax = yield from call_qt_aware(plt.subplots) + lf = LiveFit( Linear.fit(), y=reducer.intensity.name, x=block.name, yerr=reducer.intensity_stddev.name ) diff --git a/manual_system_tests/interruption.py b/manual_system_tests/interruption.py index 152a42a7..dacdb3e4 100644 --- a/manual_system_tests/interruption.py +++ b/manual_system_tests/interruption.py @@ -1,7 +1,7 @@ """Demonstration plan showing basic bluesky functionality.""" import os -from typing import Generator +from collections.abc import Generator import bluesky.plan_stubs as bps from bluesky.utils import Msg diff --git a/ruff.toml b/ruff.toml index 6323b8e3..e8b272de 100644 --- a/ruff.toml +++ b/ruff.toml @@ -2,31 +2,49 @@ line-length = 100 indent-width = 4 [lint] +preview = true extend-select = [ "N", # pep8-naming - "D", # pydocstyle (can use this later but for now causes too many errors) + "D", # pydocstyle "I", # isort (for imports) "E501", # Line too long ({width} > {limit}) - "E", - "F", - "W", - "ANN", - "ASYNC", # Asyncio-specific checks - "B", - "NPY", # Numpy-specific rules - "RUF", # Ruff-specific checks, include some useful asyncio rules + "E", # Pycodestyle errors + "W", # Pycodestyle warnings + "F", # Pyflakes + "PL", # Pylint + "B", # Flake8-bugbear + "PIE", # Flake8-pie + "ANN", # Annotations + "ASYNC", # Asyncio-specific checks + "NPY", # Numpy-specific rules + "RUF", # Ruff-specific checks, include some useful asyncio rules + "FURB", # Rules from refurb + "ERA", # Commented-out code + "PT", # Pytest-specific rules + "LOG", # Logging-specific rules + "G", # Logging-specific rules + "UP", # Pyupgrade + "SLF", # Private-member usage + "PERF", # Performance-related rules ] ignore = [ "D406", # Section name should end with a newline ("{name}") "D407", # Missing dashed underline after section ("{name}") - "D213", # Incompatible with D212 - "D203", # Incompatible with D211 + "D213", # Incompatible with D212 + "D203", # Incompatible with D211 + "B901", # This is a standard, expected, pattern in bluesky + "PLR6301" # Too noisy ] [lint.per-file-ignores] "tests/*" = [ - "N802", - "D", # Don't require method documentation for test methods - "ANN" # Don't require tests to use type annotations + "N802", # Allow test names to be long / not pep8 + "D", # Don't require method documentation for test methods + "ANN", # Don't require tests to use type annotations + "PLR2004", # Allow magic numbers in tests + "PLR0915", # Allow complex tests + "PLR0914", # Allow complex tests + "PLC2701", # Allow tests to import "private" things + "SLF001", # Allow tests to use "private" things ] "doc/conf.py" = [ "D100" @@ -34,3 +52,6 @@ ignore = [ [lint.pep8-naming] extend-ignore-names = ["RE"] # Conventional name used for RunEngine + +[lint.pylint] +max-args = 6 diff --git a/src/ibex_bluesky_core/callbacks/document_logger.py b/src/ibex_bluesky_core/callbacks/document_logger.py index a11b0d5d..77290cfc 100644 --- a/src/ibex_bluesky_core/callbacks/document_logger.py +++ b/src/ibex_bluesky_core/callbacks/document_logger.py @@ -34,5 +34,5 @@ def __call__(self, name: str, document: dict[str, Any]) -> None: to_write: dict[str, Any] = {"type": name, "document": document} - with open(self.filename, "a") as outfile: + with open(self.filename, "a", encoding="utf8") as outfile: outfile.write(f"{json.dumps(to_write)}\n") diff --git a/src/ibex_bluesky_core/callbacks/file_logger.py b/src/ibex_bluesky_core/callbacks/file_logger.py index e059c910..6e2ec6fc 100644 --- a/src/ibex_bluesky_core/callbacks/file_logger.py +++ b/src/ibex_bluesky_core/callbacks/file_logger.py @@ -65,7 +65,7 @@ def start(self, doc: RunStart) -> None: ) header_data[START_TIME] = formatted_time - with open(self.filename, "a", newline="") as outfile: + with open(self.filename, "a", newline="", encoding="utf-8") as outfile: for key, value in header_data.items(): outfile.write(f"{key}: {value}\n") @@ -102,7 +102,7 @@ def event(self, doc: Event) -> Event: else value ) - with open(self.filename, "a", newline="") as outfile: + with open(self.filename, "a", newline="", encoding="utf-8") as outfile: file_delimiter = "," if doc[SEQ_NUM] == 1: # If this is the first event, write out the units before writing event data. diff --git a/src/ibex_bluesky_core/callbacks/fitting/fitting_utils.py b/src/ibex_bluesky_core/callbacks/fitting/fitting_utils.py index 25c05d23..f7502a1c 100644 --- a/src/ibex_bluesky_core/callbacks/fitting/fitting_utils.py +++ b/src/ibex_bluesky_core/callbacks/fitting/fitting_utils.py @@ -32,7 +32,6 @@ def model(cls, *args: int) -> lmfit.Model: (x-values: NDArray, parameters: np.float64 -> y-values: NDArray) """ - pass @classmethod @abstractmethod @@ -49,7 +48,6 @@ def guess( (x-values: NDArray, y-values: NDArray -> parameters: Dict[str, lmfit.Parameter]) """ - pass @classmethod def fit(cls, *args: int) -> FitMethod: @@ -208,8 +206,9 @@ class Polynomial(Fit): @classmethod def _check_degree(cls, args: tuple[int, ...]) -> int: """Check that polynomial degree is valid.""" - degree = args[0] if args else 7 - if not (0 <= degree <= 7): + max_degree = 7 + degree = args[0] if args else max_degree + if not (0 <= degree <= max_degree): raise ValueError("The polynomial degree should be at least 0 and smaller than 8.") return degree diff --git a/src/ibex_bluesky_core/callbacks/plotting.py b/src/ibex_bluesky_core/callbacks/plotting.py index 5547346b..2feb4367 100644 --- a/src/ibex_bluesky_core/callbacks/plotting.py +++ b/src/ibex_bluesky_core/callbacks/plotting.py @@ -37,7 +37,7 @@ def __init__( """ super().__init__(y=y, x=x, *args, **kwargs) # noqa: B026 if yerr is not None: - self.yerr, *others = get_obj_fields([yerr]) + self.yerr, *_others = get_obj_fields([yerr]) else: self.yerr = None self.yerr_data = [] diff --git a/src/ibex_bluesky_core/devices/__init__.py b/src/ibex_bluesky_core/devices/__init__.py index a96f78e7..306e63f5 100644 --- a/src/ibex_bluesky_core/devices/__init__.py +++ b/src/ibex_bluesky_core/devices/__init__.py @@ -5,7 +5,7 @@ import binascii import os import zlib -from typing import Type, TypeVar +from typing import TypeVar from ophyd_async.core import SignalDatatype, SignalRW from ophyd_async.epics.core import epics_signal_rw @@ -18,7 +18,7 @@ def get_pv_prefix() -> str: prefix = os.getenv("MYPVPREFIX") if prefix is None: - raise EnvironmentError("MYPVPREFIX environment variable not available - please define") + raise OSError("MYPVPREFIX environment variable not available - please define") return prefix @@ -48,7 +48,7 @@ def compress_and_hex(value: str) -> bytes: return binascii.hexlify(compr) -def isis_epics_signal_rw(datatype: Type[T], read_pv: str, name: str = "") -> SignalRW[T]: +def isis_epics_signal_rw(datatype: type[T], read_pv: str, name: str = "") -> SignalRW[T]: """Make a RW signal with ISIS' PV naming standard ie. read_pv as TITLE, write_pv as TITLE:SP.""" write_pv = f"{read_pv}:SP" return epics_signal_rw(datatype, read_pv, write_pv, name) diff --git a/src/ibex_bluesky_core/devices/block.py b/src/ibex_bluesky_core/devices/block.py index bab6476e..98269398 100644 --- a/src/ibex_bluesky_core/devices/block.py +++ b/src/ibex_bluesky_core/devices/block.py @@ -3,7 +3,7 @@ import asyncio import logging from dataclasses import dataclass -from typing import Callable, Generic, Type, TypeVar +from typing import Callable, Generic, TypeVar from bluesky.protocols import Locatable, Location, Movable, Triggerable from ophyd_async.core import ( @@ -14,6 +14,7 @@ StandardReadable, StandardReadableFormat, observe_value, + wait_for_value, ) from ophyd_async.epics.core import epics_signal_r, epics_signal_rw from ophyd_async.epics.motor import Motor @@ -39,8 +40,13 @@ "block_rw_rbv", ] +# When using the global moving flag, we want to give IOCs enough time to update the +# global flag before checking it. This is an amount of time always applied before +# looking at the global moving flag. +GLOBAL_MOVING_FLAG_PRE_WAIT = 0.1 -@dataclass(kw_only=True) + +@dataclass(kw_only=True, frozen=True) class BlockWriteConfig(Generic[T]): """Configuration settings for writing to blocks. @@ -77,12 +83,19 @@ def check(setpoint: T, actual: T) -> bool: A wait time, in seconds, which is unconditionally applied just before the set status is marked as complete. Defaults to zero. + use_global_moving_flag: + Whether to wait for the IBEX global moving indicator to return "stationary". This is useful + for compound moves, where changing a single block may cause multiple underlying axes to + move, and all movement needs to be complete before the set is considered complete. Defaults + to False. + """ use_completion_callback: bool = True set_success_func: Callable[[T, T], bool] | None = None set_timeout_s: float | None = None settle_time_s: float = 0.0 + use_global_moving_flag: bool = False class RunControl(StandardReadable): @@ -119,7 +132,7 @@ def __init__(self, prefix: str, name: str = "") -> None: class BlockR(StandardReadable, Triggerable, Generic[T]): """Device representing an IBEX readable block of arbitrary data type.""" - def __init__(self, datatype: Type[T], prefix: str, block_name: str) -> None: + def __init__(self, datatype: type[T], prefix: str, block_name: str) -> None: """Create a new read-only block. Args: @@ -154,7 +167,7 @@ class BlockRw(BlockR[T], Movable): def __init__( self, - datatype: Type[T], + datatype: type[T], prefix: str, block_name: str, *, @@ -185,6 +198,11 @@ def __init__( self._write_config: BlockWriteConfig[T] = write_config or BlockWriteConfig() + if self._write_config.use_global_moving_flag: + # Misleading PV name... says it's a str but it's really a bi record. + # Only link to this if we need to (i.e. if use_global_moving_flag was requested) + self.global_moving = epics_signal_r(bool, f"{prefix}CS:MOT:MOVING:STR") + super().__init__(datatype=datatype, prefix=prefix, block_name=block_name) @AsyncStatus.wrap @@ -198,6 +216,21 @@ async def do_set(setpoint: T) -> None: ) logger.info("Got completion callback from setting block %s to %s", self.name, setpoint) + if self._write_config.use_global_moving_flag: + logger.info( + "Waiting for global moving flag on setting block %s to %s", self.name, setpoint + ) + # Paranoid sleep - ensure that the global flag has had a chance to go into moving, + # otherwise there could be a race condition where we check the flag before the move + # has even started. + await asyncio.sleep(GLOBAL_MOVING_FLAG_PRE_WAIT) + await wait_for_value(self.global_moving, False, timeout=None) + logger.info( + "Done wait for global moving flag on setting block %s to %s", + self.name, + setpoint, + ) + # Wait for the _set_success_func to return true. # This uses an "async for" to loop over items from observe_value, which is an async # generator. See documentation on "observe_value" or python "async for" for more details @@ -231,7 +264,7 @@ class BlockRwRbv(BlockRw[T], Locatable): def __init__( self, - datatype: Type[T], + datatype: type[T], prefix: str, block_name: str, *, @@ -282,6 +315,7 @@ def __init__( """Create a new motor-record block. The 'BlockMot' object supports motion-specific functionality such as: + - Stopping if a scan is aborted (supports the bluesky 'Stoppable' protocol) - Limit checking (before a move starts - supports the bluesky 'Checkable' protocol) - Automatic calculation of move timeouts based on motor velocity @@ -308,6 +342,9 @@ def __init__( keyword-argument on set(). settle_time_s: Use .DLY on the motor record to configure this. + use_global_moving_flag: + This is unnecessary for a single motor block, as a completion callback will always be + used instead to detect when a single move has finished. """ self.run_control = RunControl(f"{prefix}CS:SB:{block_name}:RC:") @@ -326,7 +363,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(name={self.name})" -def block_r(datatype: Type[T], block_name: str) -> BlockR[T]: +def block_r(datatype: type[T], block_name: str) -> BlockR[T]: """Get a local read-only block for the current instrument. See documentation of BlockR for more information. @@ -335,7 +372,7 @@ def block_r(datatype: Type[T], block_name: str) -> BlockR[T]: def block_rw( - datatype: Type[T], block_name: str, *, write_config: BlockWriteConfig[T] | None = None + datatype: type[T], block_name: str, *, write_config: BlockWriteConfig[T] | None = None ) -> BlockRw[T]: """Get a local read-write block for the current instrument. @@ -347,7 +384,7 @@ def block_rw( def block_rw_rbv( - datatype: Type[T], block_name: str, *, write_config: BlockWriteConfig[T] | None = None + datatype: type[T], block_name: str, *, write_config: BlockWriteConfig[T] | None = None ) -> BlockRwRbv[T]: """Get a local read/write/setpoint readback block for the current instrument. diff --git a/src/ibex_bluesky_core/devices/dae/__init__.py b/src/ibex_bluesky_core/devices/dae/__init__.py index 81a506c3..f0bfd31a 100644 --- a/src/ibex_bluesky_core/devices/dae/__init__.py +++ b/src/ibex_bluesky_core/devices/dae/__init__.py @@ -1,11 +1,11 @@ """Utilities for the DAE device - mostly XML helpers.""" from enum import Enum -from typing import Any, Dict, List +from typing import Any from xml.etree.ElementTree import Element -def convert_xml_to_names_and_values(xml: Element) -> Dict[str, str]: +def convert_xml_to_names_and_values(xml: Element) -> dict[str, str]: """Convert an XML element's children to a dict containing .text:.text.""" names_and_values = dict() elements = get_all_elements_in_xml_with_child_called_name(xml) @@ -16,7 +16,7 @@ def convert_xml_to_names_and_values(xml: Element) -> Dict[str, str]: return names_and_values -def get_all_elements_in_xml_with_child_called_name(xml: Element) -> List[Element]: +def get_all_elements_in_xml_with_child_called_name(xml: Element) -> list[Element]: """Find all elements with a "name" element, but ignore the first one as it's the root.""" elements = xml.findall("*/Name/..") return elements @@ -32,7 +32,7 @@ def _get_names_and_values(element: Element) -> tuple[Any, Any] | tuple[None, Non def set_value_in_dae_xml( - elements: List[Element], name: str, value: str | Enum | int | None | float + elements: list[Element], name: str, value: str | Enum | int | float | None ) -> None: """Find and set a value in the DAE XML, given a name and value. diff --git a/src/ibex_bluesky_core/devices/dae/dae_period_settings.py b/src/ibex_bluesky_core/devices/dae/dae_period_settings.py index 66ec95e5..341c9425 100644 --- a/src/ibex_bluesky_core/devices/dae/dae_period_settings.py +++ b/src/ibex_bluesky_core/devices/dae/dae_period_settings.py @@ -4,7 +4,6 @@ import xml.etree.ElementTree as ET from dataclasses import dataclass from enum import Enum -from typing import List from xml.etree.ElementTree import tostring from bluesky.protocols import Locatable, Location, Movable @@ -58,13 +57,13 @@ class SinglePeriodSettings: class DaePeriodSettingsData: """Dataclass for the hardware period settings.""" - periods_settings: List[SinglePeriodSettings] | None = None - periods_soft_num: None | int = None + periods_settings: list[SinglePeriodSettings] | None = None + periods_soft_num: int | None = None periods_type: PeriodType | None = None periods_src: PeriodSource | None = None - periods_file: None | str = None - periods_seq: None | int = None - periods_delay: None | int = None + periods_file: str | None = None + periods_seq: int | None = None + periods_delay: int | None = None def _convert_xml_to_period_settings(value: str) -> DaePeriodSettingsData: diff --git a/src/ibex_bluesky_core/devices/dae/dae_tcb_settings.py b/src/ibex_bluesky_core/devices/dae/dae_tcb_settings.py index b272fa42..9d7c9989 100644 --- a/src/ibex_bluesky_core/devices/dae/dae_tcb_settings.py +++ b/src/ibex_bluesky_core/devices/dae/dae_tcb_settings.py @@ -4,7 +4,6 @@ import xml.etree.ElementTree as ET from dataclasses import dataclass from enum import Enum -from typing import Dict from xml.etree.ElementTree import tostring from bluesky.protocols import Locatable, Location, Movable @@ -66,14 +65,14 @@ class TimeRegimeRow: class TimeRegime: """Time regime - contains a dict(rows) which is row_number:TimeRegimeRow.""" - rows: Dict[int, TimeRegimeRow] + rows: dict[int, TimeRegimeRow] @dataclass(kw_only=True) class DaeTCBSettingsData: """Dataclass for the DAE TCB settings.""" - tcb_tables: Dict[int, TimeRegime] | None = None + tcb_tables: dict[int, TimeRegime] | None = None tcb_file: str | None = None time_unit: TimeUnit | None = None tcb_calculation_method: CalculationMethod | None = None diff --git a/src/ibex_bluesky_core/devices/simpledae/__init__.py b/src/ibex_bluesky_core/devices/simpledae/__init__.py index e99e1186..27a67ef5 100644 --- a/src/ibex_bluesky_core/devices/simpledae/__init__.py +++ b/src/ibex_bluesky_core/devices/simpledae/__init__.py @@ -51,9 +51,9 @@ def __init__( """ self.prefix = prefix - self.controller: "Controller" = controller - self.waiter: "Waiter" = waiter - self.reducer: "Reducer" = reducer + self.controller: Controller = controller + self.waiter: Waiter = waiter + self.reducer: Reducer = reducer logger.info( "created simpledae with prefix=%s, controller=%s, waiter=%s, reducer=%s", @@ -72,8 +72,7 @@ def __init__( # published when the top-level SimpleDae object is read. extra_readables = set() for strategy in [self.controller, self.waiter, self.reducer]: - for sig in strategy.additional_readable_signals(self): - extra_readables.add(sig) + extra_readables.update(strategy.additional_readable_signals(self)) logger.info("extra readables: %s", list(extra_readables)) self.add_readables(devices=list(extra_readables)) diff --git a/src/ibex_bluesky_core/devices/simpledae/controllers.py b/src/ibex_bluesky_core/devices/simpledae/controllers.py index 771fe802..12fe7742 100644 --- a/src/ibex_bluesky_core/devices/simpledae/controllers.py +++ b/src/ibex_bluesky_core/devices/simpledae/controllers.py @@ -83,7 +83,7 @@ async def start_counting(self, dae: "SimpleDae") -> None: await dae.controls.resume_run.trigger(wait=True, timeout=None) await wait_for_value( dae.run_state, - lambda v: v in [RunstateEnum.RUNNING, RunstateEnum.WAITING, RunstateEnum.VETOING], + lambda v: v in {RunstateEnum.RUNNING, RunstateEnum.WAITING, RunstateEnum.VETOING}, timeout=10, ) @@ -130,7 +130,7 @@ async def start_counting(self, dae: "SimpleDae") -> None: await dae.controls.begin_run.trigger(wait=True, timeout=None) await wait_for_value( dae.run_state, - lambda v: v in [RunstateEnum.RUNNING, RunstateEnum.WAITING, RunstateEnum.VETOING], + lambda v: v in {RunstateEnum.RUNNING, RunstateEnum.WAITING, RunstateEnum.VETOING}, timeout=10, ) diff --git a/src/ibex_bluesky_core/devices/simpledae/reducers.py b/src/ibex_bluesky_core/devices/simpledae/reducers.py index cc2e8838..0c9b431c 100644 --- a/src/ibex_bluesky_core/devices/simpledae/reducers.py +++ b/src/ibex_bluesky_core/devices/simpledae/reducers.py @@ -3,8 +3,9 @@ import asyncio import logging import math -from abc import ABCMeta, abstractmethod -from typing import TYPE_CHECKING, Collection, Sequence +from abc import ABC, abstractmethod +from collections.abc import Collection, Sequence +from typing import TYPE_CHECKING import scipp as sc from ophyd_async.core import ( @@ -41,7 +42,7 @@ async def sum_spectra(spectra: Collection[DaeSpectra]) -> sc.Variable | sc.DataA return summed_counts -class ScalarNormalizer(Reducer, StandardReadable, metaclass=ABCMeta): +class ScalarNormalizer(Reducer, StandardReadable, ABC): """Sum a set of user-specified spectra, then normalize by a scalar signal.""" def __init__(self, prefix: str, detector_spectra: Sequence[int]) -> None: diff --git a/src/ibex_bluesky_core/devices/simpledae/waiters.py b/src/ibex_bluesky_core/devices/simpledae/waiters.py index 258e870a..af321775 100644 --- a/src/ibex_bluesky_core/devices/simpledae/waiters.py +++ b/src/ibex_bluesky_core/devices/simpledae/waiters.py @@ -2,7 +2,7 @@ import asyncio import logging -from abc import ABCMeta, abstractmethod +from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Generic, TypeVar from ophyd_async.core import ( @@ -22,7 +22,7 @@ T = TypeVar("T", int, float) -class SimpleWaiter(Waiter, Generic[T], metaclass=ABCMeta): +class SimpleWaiter(Waiter, Generic[T], ABC): """Wait for a single DAE variable to be greater or equal to a specified numeric value.""" def __init__(self, value: T) -> None: diff --git a/src/ibex_bluesky_core/plan_stubs/__init__.py b/src/ibex_bluesky_core/plan_stubs/__init__.py index 803e43b5..3e1b4fbe 100644 --- a/src/ibex_bluesky_core/plan_stubs/__init__.py +++ b/src/ibex_bluesky_core/plan_stubs/__init__.py @@ -1,6 +1,7 @@ """Core plan stubs.""" -from typing import Callable, Generator, ParamSpec, TypeVar, cast +from collections.abc import Generator +from typing import Callable, ParamSpec, TypeVar, cast import bluesky.plan_stubs as bps from bluesky.utils import Msg @@ -10,6 +11,10 @@ CALL_SYNC_MSG_KEY = "ibex_bluesky_core_call_sync" +CALL_QT_AWARE_MSG_KEY = "ibex_bluesky_core_call_qt_aware" + + +__all__ = ["call_qt_aware", "call_sync"] def call_sync(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> Generator[Msg, None, T]: @@ -35,9 +40,41 @@ def call_sync(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> Genera Args: func: A callable to run. - args: Arbitrary arguments to be passed to the wrapped function - kwargs: Arbitrary keyword arguments to be passed to the wrapped function + *args: Arbitrary arguments to be passed to the wrapped function + **kwargs: Arbitrary keyword arguments to be passed to the wrapped function + + Returns: + The return value of the wrapped function """ yield from bps.clear_checkpoint() return cast(T, (yield Msg(CALL_SYNC_MSG_KEY, func, *args, **kwargs))) + + +def call_qt_aware( + func: Callable[P, T], *args: P.args, **kwargs: P.kwargs +) -> Generator[Msg, None, T]: + """Call a matplotlib function in a Qt-aware context, from within a plan. + + If matplotlib is using a Qt backend then UI operations are run on the Qt thread via Qt signals. + + Only matplotlib functions may be run using this plan stub. + + Args: + func: A matplotlib function reference. + *args: Arbitrary arguments, passed through to matplotlib.pyplot.subplots + **kwargs: Arbitrary keyword arguments, passed through to matplotlib.pyplot.subplots + + Raises: + ValueError: if the passed function is not a matplotlib function. + + Returns: + The return value of the wrapped function + + """ + # Limit potential for misuse - constrain to just running matplotlib functions. + if not getattr(func, "__module__", "").startswith("matplotlib"): + raise ValueError("Only matplotlib functions should be passed to call_qt_aware") + + yield from bps.clear_checkpoint() + return cast(T, (yield Msg(CALL_QT_AWARE_MSG_KEY, func, *args, **kwargs))) diff --git a/src/ibex_bluesky_core/preprocessors.py b/src/ibex_bluesky_core/preprocessors.py index 8ccaa304..bf2bc964 100644 --- a/src/ibex_bluesky_core/preprocessors.py +++ b/src/ibex_bluesky_core/preprocessors.py @@ -1,7 +1,7 @@ """Bluesky plan preprocessors specific to ISIS.""" import logging -from typing import Generator +from collections.abc import Generator from bluesky import plan_stubs as bps from bluesky import preprocessors as bpp diff --git a/src/ibex_bluesky_core/run_engine/__init__.py b/src/ibex_bluesky_core/run_engine/__init__.py index f6d9344c..19355bc4 100644 --- a/src/ibex_bluesky_core/run_engine/__init__.py +++ b/src/ibex_bluesky_core/run_engine/__init__.py @@ -17,8 +17,8 @@ __all__ = ["get_run_engine"] -from ibex_bluesky_core.plan_stubs import CALL_SYNC_MSG_KEY -from ibex_bluesky_core.run_engine._msg_handlers import call_sync_handler +from ibex_bluesky_core.plan_stubs import CALL_QT_AWARE_MSG_KEY, CALL_SYNC_MSG_KEY +from ibex_bluesky_core.run_engine._msg_handlers import call_qt_aware_handler, call_sync_handler logger = logging.getLogger(__name__) @@ -97,6 +97,7 @@ def get_run_engine() -> RunEngine: RE.subscribe(log_callback) RE.register_command(CALL_SYNC_MSG_KEY, call_sync_handler) + RE.register_command(CALL_QT_AWARE_MSG_KEY, call_qt_aware_handler) RE.preprocessors.append(functools.partial(bpp.plan_mutator, msg_proc=add_rb_number_processor)) diff --git a/src/ibex_bluesky_core/run_engine/_msg_handlers.py b/src/ibex_bluesky_core/run_engine/_msg_handlers.py index 8c821787..8801fc9b 100644 --- a/src/ibex_bluesky_core/run_engine/_msg_handlers.py +++ b/src/ibex_bluesky_core/run_engine/_msg_handlers.py @@ -9,7 +9,9 @@ from asyncio import CancelledError, Event, get_running_loop from typing import Any +from bluesky.callbacks.mpl_plotting import QtAwareCallback from bluesky.utils import Msg +from event_model import RunStart logger = logging.getLogger(__name__) @@ -28,7 +30,7 @@ async def call_sync_handler(msg: Msg) -> Any: # noqa: ANN401 def _wrapper() -> Any: # noqa: ANN401 nonlocal ret, exc - logger.info("Running '{func.__name__}' with args=({msg.args}), kwargs=({msg.kwargs})") + logger.info("Running '%s' with args=(%s), kwargs=(%s)", func.__name__, msg.args, msg.kwargs) try: ret = func(*msg.args, **msg.kwargs) logger.debug("Running '%s' successful", func.__name__) @@ -94,3 +96,50 @@ def _wrapper() -> Any: # noqa: ANN401 logger.debug("Re-raising %s thrown by %s", exc.__class__.__name__, func.__name__) raise exc return ret + + +async def call_qt_aware_handler(msg: Msg) -> Any: # noqa: ANN401 + """Handle ibex_bluesky_core.plan_stubs.call_sync.""" + func = msg.obj + done_event = Event() + result: Any = None + exc: BaseException | None = None + loop = get_running_loop() + + # Slightly hacky, this isn't really a callback per-se but we want to benefit from + # bluesky's Qt-matplotlib infrastructure. + # This never gets attached to the RunEngine. + class _Cb(QtAwareCallback): + def start(self, doc: RunStart) -> None: + nonlocal result, exc + try: + logger.info( + "Running '%s' with args=(%s), kwargs=(%s) (Qt)", + func.__name__, + msg.args, + msg.kwargs, + ) + result = func(*msg.args, **msg.kwargs) + logger.debug("Running '%s' (Qt) successful", func.__name__) + except BaseException as e: + logger.error( + "Running '%s' failed with %s: %s", func.__name__, e.__class__.__name__, e + ) + exc = e + finally: + loop.call_soon_threadsafe(done_event.set) + + cb = _Cb() + # Send fake event to our callback to trigger it (actual contents unimportant) + # If not using Qt, this will run synchronously i.e. block until complete + # If using Qt, this will be sent off to the Qt teleporter which will execute it asynchronously, + # and we have to wait for the event to be set. + cb("start", {"time": 0, "uid": ""}) + + # Attempting to forcibly interrupt a function while it's doing UI operations/using + # Qt signals is highly likely to be a bad idea. Don't do that here. No special ctrl-c handling. + await done_event.wait() + + if exc is not None: + raise exc + return result diff --git a/tests/callbacks/fitting/test_fitting_methods.py b/tests/callbacks/fitting/test_fitting_methods.py index c5e9e3b4..8fed994d 100644 --- a/tests/callbacks/fitting/test_fitting_methods.py +++ b/tests/callbacks/fitting/test_fitting_methods.py @@ -266,7 +266,9 @@ def test_polynomial_model_order(self, deg: int): # -1 and 8 are both invalid polynomial degrees x = np.zeros(3) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="The polynomial degree should be at least 0 and smaller than 8." + ): Polynomial.model(deg).func(x) def test_polynomial_model(self): @@ -301,7 +303,9 @@ def test_invalid_polynomial_guess(self, deg: int): y = np.array([1.0, 0.0, 1.0]) # -1 and 8 are both invalid polynomial degrees - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="The polynomial degree should be at least 0 and smaller than 8." + ): Polynomial.guess(deg)(x, y) diff --git a/tests/callbacks/test_document_logging_callback.py b/tests/callbacks/test_document_logging_callback.py index a38cd381..26c55e8b 100644 --- a/tests/callbacks/test_document_logging_callback.py +++ b/tests/callbacks/test_document_logging_callback.py @@ -1,8 +1,8 @@ # pyright: reportMissingParameterType=false import json +from collections.abc import Generator from pathlib import Path -from typing import Generator from unittest.mock import mock_open, patch import bluesky.plan_stubs as bps @@ -22,7 +22,7 @@ def basic_plan() -> Generator[Msg, None, None]: result: RunEngineResult = RE(basic_plan()) filepath = log_location / f"{result.run_start_uids[0]}.log" - for i in range(0, 2): + for i in range(2): assert m.call_args_list[i].args == (filepath, "a") # Checks that the file is opened 2 times, for open and then stop diff --git a/tests/callbacks/test_write_log_callback.py b/tests/callbacks/test_write_log_callback.py index ad108612..b4b131b3 100644 --- a/tests/callbacks/test_write_log_callback.py +++ b/tests/callbacks/test_write_log_callback.py @@ -26,7 +26,7 @@ def test_header_data_all_available_on_start(cb): cb.start(run_start) result = save_path / f"{run_start['uid']}.txt" - mock_file.assert_called_with(result, "a", newline="") + mock_file.assert_called_with(result, "a", newline="", encoding="utf-8") # time should have been renamed to start_time and converted to human readable mock_file().write.assert_any_call("start_time: 2024-10-04 14:43:43\n") mock_file().write.assert_any_call(f"uid: {uid}\n") @@ -67,7 +67,7 @@ def test_event_prints_header_with_units_and_respects_precision_of_value_on_first with patch("ibex_bluesky_core.callbacks.file_logger.open", mock_open()) as mock_file: cb.event(event) - mock_file.assert_called_with(cb.filename, "a", newline="") + mock_file.assert_called_with(cb.filename, "a", newline="", encoding="utf-8") first_call = call(f"\n{field_name}({units})\n") second_call = call(f"{expected_value:.{prec}f}\n") mock_file().write.assert_has_calls([first_call, second_call]) @@ -92,7 +92,7 @@ def test_event_prints_header_without_units_and_does_not_truncate_precision_if_no with patch("ibex_bluesky_core.callbacks.file_logger.open", mock_open()) as mock_file: cb.event(event) - mock_file.assert_called_with(cb.filename, "a", newline="") + mock_file.assert_called_with(cb.filename, "a", newline="", encoding="utf-8") mock_file().write.assert_has_calls([call("\ntest\n"), call("1.2345\n")]) assert mock_file().write.call_count == 2 @@ -118,7 +118,7 @@ def test_event_prints_header_only_on_first_event_and_does_not_truncate_if_not_fl with patch("ibex_bluesky_core.callbacks.file_logger.open", mock_open()) as mock_file: cb.event(second_event) - mock_file.assert_called_with(cb.filename, "a", newline="") + mock_file.assert_called_with(cb.filename, "a", newline="", encoding="utf-8") mock_file().write.assert_called_once_with(f"{expected_value}\n") assert mock_file().write.call_count == 1 diff --git a/tests/devices/simpledae/test_controllers.py b/tests/devices/simpledae/test_controllers.py index bd7ad957..a8507931 100644 --- a/tests/devices/simpledae/test_controllers.py +++ b/tests/devices/simpledae/test_controllers.py @@ -30,7 +30,7 @@ def aborting_run_per_point_controller() -> RunPerPointController: return RunPerPointController(save_run=False) -async def test_period_per_point_controller_publishes_current_period( +def test_period_per_point_controller_publishes_current_period( simpledae: SimpleDae, period_per_point_controller: PeriodPerPointController ): assert period_per_point_controller.additional_readable_signals(simpledae) == [ @@ -83,7 +83,7 @@ async def test_run_per_point_controller_starts_and_ends_runs( get_mock_put(simpledae.controls.end_run).assert_called_once_with(None, wait=True) -async def test_run_per_point_controller_publishes_run( +def test_run_per_point_controller_publishes_run( simpledae: SimpleDae, run_per_point_controller: RunPerPointController ): assert run_per_point_controller.additional_readable_signals(simpledae) == [ @@ -91,7 +91,7 @@ async def test_run_per_point_controller_publishes_run( ] -async def test_aborting_run_per_point_controller_doesnt_publish_run( +def test_aborting_run_per_point_controller_doesnt_publish_run( simpledae: SimpleDae, aborting_run_per_point_controller: RunPerPointController ): assert aborting_run_per_point_controller.additional_readable_signals(simpledae) == [] diff --git a/tests/devices/simpledae/test_reducers.py b/tests/devices/simpledae/test_reducers.py index de784797..29504422 100644 --- a/tests/devices/simpledae/test_reducers.py +++ b/tests/devices/simpledae/test_reducers.py @@ -49,7 +49,7 @@ def __init__(self): # Scalar Normalizer -async def test_period_good_frames_normalizer_publishes_period_good_frames( +def test_period_good_frames_normalizer_publishes_period_good_frames( period_good_frames_reducer: PeriodGoodFramesNormalizer, ): fake_dae: SimpleDae = FakeDae() # type: ignore @@ -60,7 +60,7 @@ async def test_period_good_frames_normalizer_publishes_period_good_frames( assert period_good_frames_reducer.denominator(fake_dae) == fake_dae.period.good_frames -async def test_good_frames_normalizer_publishes_good_frames( +def test_good_frames_normalizer_publishes_good_frames( good_frames_reducer: GoodFramesNormalizer, ): fake_dae: SimpleDae = FakeDae() # type: ignore @@ -71,7 +71,7 @@ async def test_good_frames_normalizer_publishes_good_frames( assert good_frames_reducer.denominator(fake_dae) == fake_dae.good_frames -async def test_scalar_normalizer_publishes_uncertainties( +def test_scalar_normalizer_publishes_uncertainties( simpledae: SimpleDae, good_frames_reducer: GoodFramesNormalizer, ): @@ -284,7 +284,7 @@ async def test_monitor_normalizer_uncertainties( assert intensity_stddev == pytest.approx(math.sqrt((6000 + (6000**2 / 15000)) / 15000**2), 1e-4) -async def test_monitor_normalizer_publishes_raw_and_normalized_counts( +def test_monitor_normalizer_publishes_raw_and_normalized_counts( simpledae: SimpleDae, monitor_normalizer: MonitorNormalizer, ): @@ -294,7 +294,7 @@ async def test_monitor_normalizer_publishes_raw_and_normalized_counts( assert monitor_normalizer.mon_counts in readables -async def test_monitor_normalizer_publishes_raw_and_normalized_count_uncertainties( +def test_monitor_normalizer_publishes_raw_and_normalized_count_uncertainties( simpledae: SimpleDae, monitor_normalizer: MonitorNormalizer, ): diff --git a/tests/devices/test_block.py b/tests/devices/test_block.py index 8ecebbc7..29f43a45 100644 --- a/tests/devices/test_block.py +++ b/tests/devices/test_block.py @@ -2,7 +2,7 @@ import asyncio import sys -from unittest.mock import ANY, MagicMock, patch +from unittest.mock import ANY, MagicMock, call, patch import bluesky.plan_stubs as bps import bluesky.plans as bp @@ -10,6 +10,7 @@ from ophyd_async.core import get_mock_put, set_mock_value from ibex_bluesky_core.devices.block import ( + GLOBAL_MOVING_FLAG_PRE_WAIT, BlockMot, BlockR, BlockRw, @@ -100,12 +101,12 @@ async def test_locate(rw_rbv_block): } -async def test_hints(readable_block): +def test_hints(readable_block): # The primary readback should be the only "hinted" signal on a block assert readable_block.hints == {"fields": ["float_block"]} -async def test_mot_hints(mot_block): +def test_mot_hints(mot_block): assert mot_block.hints == {"fields": ["mot_block"]} @@ -205,8 +206,41 @@ async def test_block_set_with_settle_time_longer_than_timeout(): mock_aio_sleep.assert_called_once_with(30) +async def test_block_set_waiting_for_global_moving_flag(): + block = await _block_with_write_config( + BlockWriteConfig(use_global_moving_flag=True, set_timeout_s=0.1) + ) + + set_mock_value(block.global_moving, False) + with patch("ibex_bluesky_core.devices.block.asyncio.sleep") as mock_aio_sleep: + await block.set(10) + # Only check first call, as wait_for_value from ophyd_async gives us a few more... + assert mock_aio_sleep.mock_calls[0] == call(GLOBAL_MOVING_FLAG_PRE_WAIT) + + +async def test_block_set_waiting_for_global_moving_flag_timeout(): + block = await _block_with_write_config( + BlockWriteConfig(use_global_moving_flag=True, set_timeout_s=0.1) + ) + + set_mock_value(block.global_moving, True) + with patch("ibex_bluesky_core.devices.block.asyncio.sleep") as mock_aio_sleep: + with pytest.raises(aio_timeout_error): + await block.set(10) + # Only check first call, as wait_for_value from ophyd_async gives us a few more... + assert mock_aio_sleep.mock_calls[0] == call(GLOBAL_MOVING_FLAG_PRE_WAIT) + + +async def test_block_without_use_global_moving_flag_does_not_refer_to_global_moving_pv(): + block_without = await _block_with_write_config(BlockWriteConfig(use_global_moving_flag=False)) + block_with = await _block_with_write_config(BlockWriteConfig(use_global_moving_flag=True)) + + assert not hasattr(block_without, "global_moving") + assert hasattr(block_with, "global_moving") + + @pytest.mark.parametrize( - "func,args", + ("func", "args"), [ (block_r, (float, "some_block")), (block_rw, (float, "some_block")), @@ -239,18 +273,18 @@ async def test_runcontrol_read_and_describe(readable_block): assert descriptor["float_block-run_control-in_range"]["dtype"] == "boolean" -async def test_runcontrol_hints(readable_block): +def test_runcontrol_hints(readable_block): # Hinted field for explicitly reading run-control: is the reading in range? hints = readable_block.run_control.hints assert hints == {"fields": ["float_block-run_control-in_range"]} -async def test_runcontrol_monitors_correct_pv(readable_block): +def test_runcontrol_monitors_correct_pv(readable_block): source = readable_block.run_control.in_range.source assert source.endswith("UNITTEST:MOCK:CS:SB:float_block:RC:INRANGE") -async def test_mot_block_runcontrol_monitors_correct_pv(mot_block): +def test_mot_block_runcontrol_monitors_correct_pv(mot_block): source = mot_block.run_control.in_range.source # The main "motor" uses mot_block:SP:RBV, but run control should not. assert source.endswith("UNITTEST:MOCK:CS:SB:mot_block:RC:INRANGE") diff --git a/tests/devices/test_init.py b/tests/devices/test_init.py index affcfda2..e946a40d 100644 --- a/tests/devices/test_init.py +++ b/tests/devices/test_init.py @@ -33,7 +33,7 @@ def test_get_pv_prefix(): def test_cannot_get_pv_prefix(): with patch("os.getenv") as mock_getenv: mock_getenv.return_value = None - with pytest.raises(EnvironmentError): + with pytest.raises(EnvironmentError, match="MYPVPREFIX environment variable not available"): get_pv_prefix() diff --git a/tests/test_log.py b/tests/test_log.py index 29c3556f..180fadd9 100644 --- a/tests/test_log.py +++ b/tests/test_log.py @@ -25,7 +25,7 @@ def test_setup_logging_does_not_crash_if_directory_cannot_be_created( mock_makedirs.side_effect = OSError setup_logging() - stdout, stderr = capfd.readouterr() + _, stderr = capfd.readouterr() assert stderr == "unable to create ibex_bluesky_core log directory\n" @@ -49,13 +49,13 @@ def test_set_bluesky_log_levels_default_previously_unset(): def test_set_bluesky_log_levels_default_previously_set(): # Setup, set some explicit log levels on various loggers. - logging.getLogger("ibex_bluesky_core").setLevel(logging.WARN) + logging.getLogger("ibex_bluesky_core").setLevel(logging.WARNING) logging.getLogger("bluesky").setLevel(logging.INFO) logging.getLogger("ophyd_async").setLevel(logging.DEBUG) set_bluesky_log_levels() # Assert we didn't override the previously explicitly-set levels - assert logging.getLogger("ibex_bluesky_core").level == logging.WARN + assert logging.getLogger("ibex_bluesky_core").level == logging.WARNING assert logging.getLogger("bluesky").level == logging.INFO assert logging.getLogger("ophyd_async").level == logging.DEBUG diff --git a/tests/test_plan_stubs.py b/tests/test_plan_stubs.py index c63dd5bd..c3227727 100644 --- a/tests/test_plan_stubs.py +++ b/tests/test_plan_stubs.py @@ -1,13 +1,13 @@ # pyright: reportMissingParameterType=false - import time from asyncio import CancelledError -from unittest.mock import patch +from unittest.mock import MagicMock, patch +import matplotlib.pyplot as plt import pytest from bluesky.utils import Msg -from ibex_bluesky_core.plan_stubs import call_sync +from ibex_bluesky_core.plan_stubs import CALL_QT_AWARE_MSG_KEY, call_qt_aware, call_sync from ibex_bluesky_core.run_engine._msg_handlers import call_sync_handler @@ -66,3 +66,56 @@ def f(): end = time.monotonic() assert end - start == pytest.approx(1, abs=0.2) + + +def test_call_qt_aware_returns_result(RE): + def f(arg, keyword_arg): + assert arg == "foo" + assert keyword_arg == "bar" + return 123 + + def plan(): + return (yield Msg(CALL_QT_AWARE_MSG_KEY, f, "foo", keyword_arg="bar")) + + result = RE(plan()) + + assert result.plan_result == 123 + + +def test_call_qt_aware_throws_exception(RE): + def f(): + raise ValueError("broke it") + + def plan(): + return (yield Msg(CALL_QT_AWARE_MSG_KEY, f)) + + with pytest.raises(ValueError, match="broke it"): + RE(plan()) + + +def test_call_qt_aware_matplotlib_function(RE): + mock = MagicMock(spec=plt.close) + mock.__module__ = "matplotlib.pyplot" + mock.return_value = 123 + + def plan(): + return (yield from call_qt_aware(mock, "all")) + + result = RE(plan()) + assert result.plan_result == 123 + mock.assert_called_once_with("all") + + +def test_call_qt_aware_non_matplotlib_function(RE): + mock = MagicMock() + mock.__module__ = "some_random_module" + + def plan(): + return (yield from call_qt_aware(mock, "arg", keyword_arg="kwarg")) + + with pytest.raises( + ValueError, match="Only matplotlib functions should be passed to call_qt_aware" + ): + RE(plan()) + + mock.assert_not_called() diff --git a/tests/test_preprocessors.py b/tests/test_preprocessors.py index 2a3ec9e8..3201d2d4 100644 --- a/tests/test_preprocessors.py +++ b/tests/test_preprocessors.py @@ -24,7 +24,7 @@ async def mock_rb_num() -> SignalWithExpectedRbv: return SignalWithExpectedRbv(mock_rbnum_signal, rb_num) -async def test_rb_number_preprocessor_adds_rb_number(RE, mock_rb_num): +def test_rb_number_preprocessor_adds_rb_number(RE, mock_rb_num): with ( patch( "ibex_bluesky_core.preprocessors._get_rb_number_signal", return_value=mock_rb_num.signal @@ -45,7 +45,7 @@ def plan(): assert start_doc["rb_number"] == mock_rb_num.rb_num -async def test_rb_number_preprocessor_adds_unknown_if_signal_not_connected(RE, mock_rb_num): +def test_rb_number_preprocessor_adds_unknown_if_signal_not_connected(RE, mock_rb_num): with ( patch( "ibex_bluesky_core.preprocessors._get_rb_number_signal", return_value=mock_rb_num.signal diff --git a/tests/test_run_engine.py b/tests/test_run_engine.py index b7f4f4cd..0256dd59 100644 --- a/tests/test_run_engine.py +++ b/tests/test_run_engine.py @@ -1,7 +1,8 @@ # pyright: reportMissingParameterType=false import threading -from typing import Any, Generator +from collections.abc import Generator +from typing import Any from unittest.mock import MagicMock import bluesky.plan_stubs as bps