Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
jackbdoughty committed Oct 9, 2024
1 parent 0f5a152 commit e38024a
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 71 deletions.
2 changes: 2 additions & 0 deletions manual_system_tests/dae_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def dae_scan_plan() -> Generator[Msg, None, None]:
block.name,
controller.run_number.name,
reducer.intensity.name,
reducer.intensity_stddev.name,
reducer.det_counts.name,
reducer.det_counts_stddev.name,
dae.good_frames.name,
]
),
Expand Down
77 changes: 62 additions & 15 deletions src/ibex_bluesky_core/devices/dae/dae_spectra.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,49 @@
"""ophyd-async devices and utilities for a single DAE spectra."""

import asyncio

from bluesky.protocols import Triggerable
import scipp as sc
import typing
import numpy as np
from event_model.documents.event_descriptor import DataKey
from numpy import float32
from numpy.typing import NDArray
from ophyd_async.core import SignalR, StandardReadable, soft_signal_r_and_setter
from ophyd_async.core import SignalR, StandardReadable, soft_signal_r_and_setter, AsyncStageable, AsyncStatus
from ophyd_async.epics.signal import epics_signal_r


class DaeSpectra(StandardReadable):
# def soft_signal_r_and_setter(
# datatype: type[T] | None = None,
# initial_value: T | None = None,
# name: str = "",
# units: str | None = None,
# precision: int | None = None,
# ) -> tuple[SignalR[T], Callable[[T], None]]:
# """Returns a tuple of a read-only Signal and a callable through
# which the signal can be internally modified within the device.
# May pass metadata, which are propagated into describe.
# Use soft_signal_rw if you want a device that is externally modifiable
# """
# metadata = SignalMetadata(units=units, precision=precision)
# backend = SoftSignalBackend(datatype, initial_value, metadata=metadata)
# signal = SignalR(backend, name=name)

# return (signal, backend.set_value, lambda u: metadata.units = u)


# def get_units(sig: SignalR[Any]) -> str:
# pass


class DaeSpectra(StandardReadable, Triggerable):
"""Subdevice for a single DAE spectra."""


def __init__(self, dae_prefix: str, *, spectra: int, period: int, name: str = "") -> None:
"""Set up signals for a single DAE spectra."""
# x-axis; time-of-flight.
# These are bin-centre coordinates.
self.tof: SignalR[NDArray[float32]] = epics_signal_r(
self._tof_raw: SignalR[NDArray[float32]] = epics_signal_r(
NDArray[float32], f"{dae_prefix}SPEC:{period}:{spectra}:X"
)
self.tof_size: SignalR[int] = epics_signal_r(
Expand All @@ -26,7 +52,7 @@ def __init__(self, dae_prefix: str, *, spectra: int, period: int, name: str = ""

# x-axis; time-of-flight.
# These are bin-edge coordinates, with a size one more than the corresponding data.
self.tof_edges: SignalR[NDArray[float32]] = epics_signal_r(
self._tof_edges_raw: SignalR[NDArray[float32]] = epics_signal_r(
NDArray[float32], f"{dae_prefix}SPEC:{period}:{spectra}:XE"
)
self.tof_edges_size: SignalR[int] = epics_signal_r(
Expand All @@ -38,7 +64,7 @@ def __init__(self, dae_prefix: str, *, spectra: int, period: int, name: str = ""
# that ToF bin.
# - Unsuitable for summing counts directly.
# - Will give a continuous plot for non-uniform bin sizes.
self.counts_per_time: SignalR[NDArray[float32]] = epics_signal_r(
self._counts_per_time_raw: SignalR[NDArray[float32]] = epics_signal_r(
NDArray[float32], f"{dae_prefix}SPEC:{period}:{spectra}:Y"
)
self.counts_per_time_size: SignalR[int] = epics_signal_r(
Expand All @@ -49,17 +75,34 @@ def __init__(self, dae_prefix: str, *, spectra: int, period: int, name: str = ""
# This is unnormalized number of counts per ToF bin.
# - Suitable for summing counts
# - This will give a discontinuous plot for non-uniform bin sizes.
self.counts: SignalR[NDArray[float32]] = epics_signal_r(
self._counts_raw: SignalR[NDArray[float32]] = epics_signal_r(
NDArray[float32], f"{dae_prefix}SPEC:{period}:{spectra}:YC"
)
self.counts_size: SignalR[int] = epics_signal_r(
int, f"{dae_prefix}SPEC:{period}:{spectra}:YC.NORD"
)

self.tof, self._tof_setter = soft_signal_r_and_setter(NDArray[float32], [])
self.tof_edges, self._tof_edges_setter = soft_signal_r_and_setter(NDArray[float32], [])
self.counts_per_time, self._counts_per_time_setter = soft_signal_r_and_setter(NDArray[float32], [])
self.counts, self._counts_setter = soft_signal_r_and_setter(NDArray[float32], [])
self.stddev, self._stddev_setter = soft_signal_r_and_setter(NDArray[float32], [])

super().__init__(name=name)


@AsyncStatus.wrap
async def trigger(self) -> None:

self._tof_setter(await self._read_sized(self._tof_raw, self.tof_size))
self._tof_edges_setter(await self._read_sized(self._tof_edges_raw, self.tof_edges_size))
self._counts_per_time_setter(await self._read_sized(self._counts_per_time_raw, self.counts_per_time_size))
self._counts_setter(await self._read_sized(self._counts_raw, self.counts_size))

stddev = await self.counts.get_value()
self._stddev_setter(np.sqrt(stddev))


async def _read_sized(
self, array_signal: SignalR[NDArray[float32]], size_signal: SignalR[int]
) -> NDArray[float32]:
Expand All @@ -68,24 +111,28 @@ async def _read_sized(

async def read_tof(self) -> NDArray[float32]:
"""Read a correctly-sized time-of-flight (x) array representing bin centres."""
return await self._read_sized(self.tof, self.tof_size)
tof = await self.tof.get_value()
return typing.cast(NDArray[float32], tof)

async def read_tof_edges(self) -> NDArray[float32]:
"""Read a correctly-sized time-of-flight (x) array representing bin edges."""
return await self._read_sized(self.tof_edges, self.tof_edges_size)
tof_edges = await self.tof_edges.get_value()
return typing.cast(NDArray[float32], tof_edges)

async def read_counts(self) -> NDArray[float32]:
"""Read a correctly-sized array of counts."""
return await self._read_sized(self.counts, self.counts_size)
counts = await self.counts.get_value()
return typing.cast(NDArray[float32], counts)

async def read_counts_per_time(self) -> NDArray[float32]:
"""Read a correctly-sized array of counts divided by bin width."""
return await self._read_sized(self.counts_per_time, self.counts_per_time_size)
counts_per_time = await self.counts_per_time.get_value()
return typing.cast(NDArray[float32], counts_per_time)

async def read_counts_uncertainties(self) -> NDArray[float32]:
"""Read a correctly-sized array of uncertainties for each count."""

return await self._read_sized(self.stddev, self.counts_size) # type: ignore ???????
stddev = await self.stddev.get_value()
return typing.cast(NDArray[float32], stddev)

async def read_spectrum_dataarray(self) -> sc.DataArray:
"""Get a scipp DataArray containing the current data from this spectrum.
Expand All @@ -98,7 +145,7 @@ async def read_spectrum_dataarray(self) -> sc.DataArray:
"""
tof_edges, tof_edges_descriptor, counts = await asyncio.gather(
self.read_tof_edges(),
self.tof_edges.describe(),
self._tof_edges_raw.describe(),
self.read_counts(),
)

Expand All @@ -109,7 +156,7 @@ async def read_spectrum_dataarray(self) -> sc.DataArray:
f"Edges size was {tof_edges.size}, counts size was {counts.size}."
)

datakey: DataKey = tof_edges_descriptor[self.tof_edges.name]
datakey: DataKey = tof_edges_descriptor[self._tof_edges_raw.name]
unit = datakey.get("units", None)
if unit is None:
raise ValueError("Could not determine engineering units of tof edges.")
Expand Down
37 changes: 10 additions & 27 deletions src/ibex_bluesky_core/devices/simpledae/reducers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ def denominator(self, dae: "SimpleDae") -> SignalR[int] | SignalR[float]:

async def reduce_data(self, dae: "SimpleDae") -> None:
"""Apply the normalization."""

await asyncio.gather(*[self.detectors[i].trigger() for i in self.detectors])

summed_counts, denominator = await asyncio.gather(
sum_spectra(self.detectors.values()), self.denominator(dae).get_value()
)
Expand All @@ -81,10 +84,6 @@ async def reduce_data(self, dae: "SimpleDae") -> None:
self._det_counts_stddev_setter(math.sqrt(detector_counts_var))
self._intensity_stddev_setter(math.sqrt(intensity_var))

for spec in self.detectors.values():
stddev = await spec.read_counts()
spec._stddev_setter(np.sqrt(stddev))

def additional_readable_signals(self, dae: "SimpleDae") -> list[Device]:
"""Publish interesting signals derived or used by this reducer."""
return [
Expand All @@ -94,11 +93,6 @@ def additional_readable_signals(self, dae: "SimpleDae") -> list[Device]:
self.det_counts_stddev,
self.intensity_stddev,
]

def readable_detector_count_uncertainties(self, dae: "SimpleDae") -> list[Device]:
"""Publish individual uncertainty signals for all detectors."""

return [self.detectors[i].stddev for i in self.detectors]


class PeriodGoodFramesNormalizer(ScalarNormalizer):
Expand Down Expand Up @@ -151,6 +145,13 @@ def __init__(

async def reduce_data(self, dae: "SimpleDae") -> None:
"""Apply the normalization."""
async def trigger_detectors():
await asyncio.gather(*[self.detectors[i].trigger() for i in self.detectors])
async def trigger_monitors():
await asyncio.gather(*[self.monitors[i].trigger() for i in self.monitors])

await asyncio.gather(trigger_detectors(), trigger_monitors())

detector_counts, monitor_counts = await asyncio.gather(
sum_spectra(self.detectors.values()), sum_spectra(self.monitors.values())
)
Expand All @@ -167,14 +168,6 @@ async def reduce_data(self, dae: "SimpleDae") -> None:
self._mon_counts_stddev_setter(math.sqrt(monitor_counts_var))
self._intensity_stddev_setter(math.sqrt(intensity_var))

for spec in self.detectors.values():
stddev = await spec.read_counts()
spec._stddev_setter(np.sqrt(stddev))

for spec in self.monitors.values():
stddev = await spec.read_counts()
spec._stddev_setter(np.sqrt(stddev))

def additional_readable_signals(self, dae: "SimpleDae") -> list[Device]:
"""Publish interesting signals derived or used by this reducer."""
return [
Expand All @@ -185,13 +178,3 @@ def additional_readable_signals(self, dae: "SimpleDae") -> list[Device]:
self.mon_counts_stddev,
self.intensity_stddev,
]

def readable_detector_count_uncertainties(self, dae: "SimpleDae") -> list[Device]:
"""Publish individual uncertainty signals for all detectors."""

return [self.detectors[i].stddev for i in self.detectors]

def readable_monitor_count_uncertainties(self, dae: "SimpleDae") -> list[Device]:
"""Publish individual uncertainty signals for all monitors."""

return [self.monitors[i].stddev for i in self.monitors]
25 changes: 0 additions & 25 deletions tests/devices/simpledae/test_reducers.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,6 @@ async def test_period_good_frames_normalizer_uncertainties(
assert intensity_stddev == pytest.approx(math.sqrt((21000 + (123**2 / 21000) ) / 123**2), 1e-4)


async def test_scalar_normalizer_publishes_individual_detector_count_uncertainties(
simpledae: SimpleDae,
period_good_frames_reducer: PeriodGoodFramesNormalizer,
):

readables = period_good_frames_reducer.readable_detector_count_uncertainties(simpledae)
assert period_good_frames_reducer.detectors[1].stddev in readables


# Monitor Normalizer


Expand Down Expand Up @@ -217,19 +208,3 @@ async def test_monitor_normalizer_publishes_raw_and_normalized_count_uncertainti
assert monitor_normalizer.intensity_stddev in readables
assert monitor_normalizer.det_counts_stddev in readables
assert monitor_normalizer.mon_counts_stddev in readables


async def test_monitor_normalizer_publishes_individual_detector_count_uncertainties(
simpledae: SimpleDae,
monitor_normalizer: MonitorNormalizer,
):
readables = monitor_normalizer.readable_detector_count_uncertainties(simpledae)
assert monitor_normalizer.detectors[1].stddev in readables


async def test_monitor_normalizer_publishes_individual_monitor_count_uncertainties(
simpledae: SimpleDae,
monitor_normalizer: MonitorNormalizer,
):
readables = monitor_normalizer.readable_monitor_count_uncertainties(simpledae)
assert monitor_normalizer.monitors[2].stddev in readables
10 changes: 6 additions & 4 deletions tests/devices/test_dae.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,15 +950,17 @@ def test_empty_dae_settings_dataclass_does_not_change_any_settings(dae: Dae, RE:


async def test_read_spectra_correctly_sizes_arrays(spectrum: DaeSpectra):
set_mock_value(spectrum.tof, np.zeros(dtype=np.float32, shape=(1000,)))
set_mock_value(spectrum._tof_raw, np.zeros(dtype=np.float32, shape=(1000,)))
set_mock_value(spectrum.tof_size, 100)
set_mock_value(spectrum.counts, np.zeros(dtype=np.float32, shape=(2000,)))
set_mock_value(spectrum._counts_raw, np.zeros(dtype=np.float32, shape=(2000,)))
set_mock_value(spectrum.counts_size, 200)
set_mock_value(spectrum.counts_per_time, np.zeros(dtype=np.float32, shape=(3000,)))
set_mock_value(spectrum._counts_per_time_raw, np.zeros(dtype=np.float32, shape=(3000,)))
set_mock_value(spectrum.counts_per_time_size, 300)
set_mock_value(spectrum.tof_edges, np.zeros(dtype=np.float32, shape=(4000,)))
set_mock_value(spectrum._tof_edges_raw, np.zeros(dtype=np.float32, shape=(4000,)))
set_mock_value(spectrum.tof_edges_size, 400)

await spectrum.trigger()

assert (await spectrum.read_tof()).shape == (100,)
assert (await spectrum.read_counts()).shape == (200,)
assert (await spectrum.read_counts_per_time()).shape == (300,)
Expand Down

0 comments on commit e38024a

Please sign in to comment.