diff --git a/docs/examples/workflow.ipynb b/docs/examples/workflow.ipynb
index 5a0551c..709bd1f 100644
--- a/docs/examples/workflow.ipynb
+++ b/docs/examples/workflow.ipynb
@@ -9,7 +9,8 @@
"## Collect Parameters and Providers\n",
"### Simulation(McStas) Data\n",
"There is a dedicated loader, ``load_mcstas_nexus`` for ``McStas`` simulation data workflow.
\n",
- "``MaximumProbability`` can be manually provided to the loader to derive more realistic number of events.
\n",
+ "``MaximumProbability`` can be manually provided to the loader
\n",
+ "to derive more realistic number of events.
\n",
"It is because ``weights`` are given as probability, not number of events in a McStas file.
"
]
},
@@ -47,7 +48,9 @@
"source": [
"from typing import get_type_hints\n",
"param_reprs = {key.__name__: value for key, value in params.items()}\n",
- "prov_reprs = {get_type_hints(prov)['return'].__name__: prov.__name__ for prov in providers}\n",
+ "prov_reprs = {\n",
+ " get_type_hints(prov)['return'].__name__: prov.__name__ for prov in providers\n",
+ "}\n",
"\n",
"# Providers and parameters to be used for pipeline\n",
"sc.DataGroup(**prov_reprs, **param_reprs)"
@@ -91,11 +94,38 @@
"da = nmx_workflow.compute(NMXData)\n",
"da"
]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Instrument View\n",
+ "\n",
+ "Pixel positions are not used for later steps,\n",
+ "but it is included in the coordinates for instrument view.\n",
+ "\n",
+ "All pixel positions are relative to the sample position,\n",
+ "therefore the sample is at (0, 0, 0).\n",
+ "\n",
+ "You can plot the instrument view like below.\n",
+ "\n",
+ "```python\n",
+ "import scippneutron as scn\n",
+ "\n",
+ "unnecessary_coords = list(coord for coord in da.coords if coord != 'position')\n",
+ "instrument_view_da = da.drop_coords(unnecessary_coords).flatten(['panel', 'id'], 'id').hist()\n",
+ "view = scn.instrument_view(instrument_view_da)\n",
+ "view.children[0].toolbar.cameraz()\n",
+ "view\n",
+ "```\n",
+ "\n",
+ "**It might be very slow or not work in the ``VS Code`` jupyter notebook editor.**"
+ ]
}
],
"metadata": {
"kernelspec": {
- "display_name": "nmx-dev-39",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -113,5 +143,5 @@
}
},
"nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 4
}
diff --git a/src/ess/nmx/mcstas_loader.py b/src/ess/nmx/mcstas_loader.py
index 9085542..c56ed28 100644
--- a/src/ess/nmx/mcstas_loader.py
+++ b/src/ess/nmx/mcstas_loader.py
@@ -7,7 +7,7 @@
PixelIDs = NewType("PixelIDs", sc.Variable)
InputFilepath = NewType("InputFilepath", str)
-NMXData = NewType("NMXData", sc.DataArray)
+NMXData = NewType("NMXData", sc.DataGroup)
# McStas Configurations
MaximumProbability = NewType("MaximumProbability", int)
@@ -37,13 +37,6 @@ def _copy_partial_var(
return var
-def _get_mcstas_pixel_ids() -> PixelIDs:
- """pixel IDs for each detector"""
- intervals = [(1, 1638401), (2000001, 3638401), (4000001, 5638401)]
- ids = [sc.arange('id', start, stop, unit=None) for start, stop in intervals]
- return PixelIDs(sc.concat(ids, 'id'))
-
-
def load_mcstas_nexus(
file_path: InputFilepath,
max_probability: Optional[MaximumProbability] = None,
@@ -60,6 +53,9 @@ def load_mcstas_nexus(
"""
+ from .mcstas_xml import read_mcstas_geometry_xml
+
+ geometry = read_mcstas_geometry_xml(file_path)
probability = max_probability or DefaultMaximumProbability
with snx.File(file_path) as file:
@@ -76,6 +72,9 @@ def load_mcstas_nexus(
weights = (probability / weights.max()) * weights
loaded = sc.DataArray(data=weights, coords={'t': t_list, 'id': id_list})
- grouped = loaded.group(_get_mcstas_pixel_ids())
+ coords = geometry.to_coords()
+ grouped = loaded.group(coords.pop('pixel_id'))
+ da = grouped.fold(dim='id', sizes={'panel': len(geometry.detectors), 'id': -1})
+ da.coords.update(coords)
- return NMXData(grouped.fold(dim='id', sizes={'panel': 3, 'id': -1}))
+ return NMXData(da)
diff --git a/src/ess/nmx/mcstas_xml.py b/src/ess/nmx/mcstas_xml.py
new file mode 100644
index 0000000..d981a30
--- /dev/null
+++ b/src/ess/nmx/mcstas_xml.py
@@ -0,0 +1,403 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+# McStas instrument geometry xml description related functions.
+from dataclasses import dataclass
+from pathlib import Path
+from types import MappingProxyType
+from typing import Iterable, Optional, Protocol, Tuple, TypeVar, Union
+
+import scipp as sc
+
+T = TypeVar('T')
+
+
+_AXISNAME_TO_UNIT_VECTOR = MappingProxyType(
+ {
+ 'x': sc.vector([1.0, 0.0, 0.0]),
+ 'y': sc.vector([0.0, 1.0, 0.0]),
+ 'z': sc.vector([0.0, 0.0, 1.0]),
+ }
+)
+
+
+class _XML(Protocol):
+ """XML element or tree type.
+
+ Temporarily used for type hinting.
+ Builtin XML type is blocked by bandit security check."""
+
+ tag: str
+ attrib: dict[str, str]
+
+ def find(self, name: str) -> Optional['_XML']:
+ ...
+
+ def __iter__(self) -> '_XML':
+ ...
+
+ def __next__(self) -> '_XML':
+ ...
+
+
+def _check_and_unpack_if_only_one(xml_items: list[_XML], name: str) -> _XML:
+ """Check if there is only one element with ``name``."""
+ if len(xml_items) > 1:
+ raise ValueError(f"Multiple {name}s found.")
+ elif len(xml_items) == 0:
+ raise ValueError(f"No {name} found.")
+
+ return xml_items.pop()
+
+
+def select_by_tag(xml_items: _XML, tag: str) -> _XML:
+ """Select element with ``tag`` if there is only one."""
+
+ return _check_and_unpack_if_only_one(list(filter_by_tag(xml_items, tag)), tag)
+
+
+def filter_by_tag(xml_items: Iterable[_XML], tag: str) -> Iterable[_XML]:
+ """Filter xml items by tag."""
+ return (item for item in xml_items if item.tag == tag)
+
+
+def filter_by_type_prefix(xml_items: Iterable[_XML], prefix: str) -> Iterable[_XML]:
+ """Filter xml items by type prefix."""
+ return (
+ item for item in xml_items if item.attrib.get('type', '').startswith(prefix)
+ )
+
+
+def select_by_type_prefix(xml_items: Iterable[_XML], prefix: str) -> _XML:
+ """Select xml item by type prefix."""
+
+ cands = list(filter_by_type_prefix(xml_items, prefix))
+ return _check_and_unpack_if_only_one(cands, prefix)
+
+
+def find_attributes(component: _XML, *args: str) -> dict[str, float]:
+ """Retrieve ``args`` as float from xml."""
+
+ return {key: float(component.attrib[key]) for key in args}
+
+
+@dataclass
+class SimulationSettings:
+ """Simulation settings extracted from McStas instrument xml description."""
+
+ # From
+ length_unit: str # 'unit' of
+ angle_unit: str # 'unit' of
+ # From
+ beam_axis: str # 'axis' of
+ handedness: str # 'val' of
+
+ @classmethod
+ def from_xml(cls, tree: _XML) -> 'SimulationSettings':
+ """Create simulation settings from xml."""
+ defaults = select_by_tag(tree, 'defaults')
+ length_desc = select_by_tag(defaults, 'length')
+ angle_desc = select_by_tag(defaults, 'angle')
+ reference_frame = select_by_tag(defaults, 'reference-frame')
+ along_beam = select_by_tag(reference_frame, 'along-beam')
+ handedness = select_by_tag(reference_frame, 'handedness')
+
+ return cls(
+ length_unit=length_desc.attrib['unit'],
+ angle_unit=angle_desc.attrib['unit'],
+ beam_axis=along_beam.attrib['axis'],
+ handedness=handedness.attrib['val'],
+ )
+
+
+def _position_from_location(location: _XML, unit: str = 'm') -> sc.Variable:
+ """Retrieve position from location."""
+ x, y, z = find_attributes(location, 'x', 'y', 'z').values()
+ return sc.vector([x, y, z], unit=unit)
+
+
+def _rotation_matrix_from_location(
+ location: _XML, angle_unit: str = 'degree'
+) -> sc.Variable:
+ """Retrieve rotation matrix from location."""
+ from .rotation import axis_angle_to_quaternion, quaternion_to_matrix
+
+ attribs = find_attributes(location, 'axis-x', 'axis-y', 'axis-z', 'rot')
+ x, y, z, w = axis_angle_to_quaternion(
+ x=attribs['axis-x'],
+ y=attribs['axis-y'],
+ z=attribs['axis-z'],
+ theta=sc.scalar(-attribs['rot'], unit=angle_unit),
+ )
+ return quaternion_to_matrix(x=x, y=y, z=z, w=w)
+
+
+@dataclass
+class DetectorDesc:
+ """Detector information extracted from McStas instrument xml description."""
+
+ # From
+ component_type: str # 'type'
+ name: str
+ id_start: int # 'idstart'
+ fast_axis_name: str # 'idfillbyfirst'
+ # From
+ num_x: int # 'xpixels'
+ num_y: int # 'ypixels'
+ step_x: sc.Variable # 'xstep'
+ step_y: sc.Variable # 'ystep'
+ start_x: float # 'xstart'
+ start_y: float # 'ystart'
+ # From under
+ position: sc.Variable # 'x', 'y', 'z'
+ # Calculated fields
+ rotation_matrix: sc.Variable
+ slow_axis_name: str
+ fast_axis: sc.Variable
+ slow_axis: sc.Variable
+
+ @classmethod
+ def from_xml(
+ cls,
+ *,
+ component: _XML,
+ type_desc: _XML,
+ simulation_settings: SimulationSettings,
+ ) -> 'DetectorDesc':
+ """Create detector description from xml component and type."""
+
+ location = select_by_tag(component, 'location')
+ rotation_matrix = _rotation_matrix_from_location(
+ location, simulation_settings.angle_unit
+ )
+ fast_axis_name = component.attrib['idfillbyfirst']
+ slow_axis_name = 'xy'.replace(fast_axis_name, '')
+
+ length_unit = simulation_settings.length_unit
+
+ return cls(
+ component_type=type_desc.attrib['name'],
+ name=component.attrib['name'],
+ id_start=int(component.attrib['idstart']),
+ fast_axis_name=fast_axis_name,
+ slow_axis_name=slow_axis_name,
+ num_x=int(type_desc.attrib['xpixels']),
+ num_y=int(type_desc.attrib['ypixels']),
+ step_x=sc.scalar(float(type_desc.attrib['xstep']), unit=length_unit),
+ step_y=sc.scalar(float(type_desc.attrib['ystep']), unit=length_unit),
+ start_x=float(type_desc.attrib['xstart']),
+ start_y=float(type_desc.attrib['ystart']),
+ position=_position_from_location(location, simulation_settings.length_unit),
+ rotation_matrix=rotation_matrix,
+ fast_axis=rotation_matrix * _AXISNAME_TO_UNIT_VECTOR[fast_axis_name],
+ slow_axis=rotation_matrix * _AXISNAME_TO_UNIT_VECTOR[slow_axis_name],
+ )
+
+ @property
+ def total_pixels(self) -> int:
+ return self.num_x * self.num_y
+
+ @property
+ def slow_step(self) -> sc.Variable:
+ return self.step_y if self.fast_axis_name == 'x' else self.step_x
+
+ @property
+ def fast_step(self) -> sc.Variable:
+ return self.step_x if self.fast_axis_name == 'x' else self.step_y
+
+ @property
+ def num_fast_pixels_per_row(self) -> int:
+ """Number of pixels in each row of the detector along the fast axis."""
+ return self.num_x if self.fast_axis_name == 'x' else self.num_y
+
+
+def _collect_detector_descriptions(tree: _XML) -> Tuple[DetectorDesc, ...]:
+ """Retrieve detector geometry descriptions from mcstas file."""
+ type_list = filter_by_tag(tree, 'type')
+ simulation_settings = SimulationSettings.from_xml(tree)
+
+ def _find_type_desc(det: _XML) -> _XML:
+ for type_ in type_list:
+ if type_.attrib['name'] == det.attrib['type']:
+ return type_
+
+ raise ValueError(
+ f"Cannot find type {det.attrib['type']} for {det.attrib['name']}."
+ )
+
+ detector_components = [
+ DetectorDesc.from_xml(
+ component=det,
+ type_desc=_find_type_desc(det),
+ simulation_settings=simulation_settings,
+ )
+ for det in filter_by_type_prefix(filter_by_tag(tree, 'component'), 'MonNDtype')
+ ]
+
+ return tuple(sorted(detector_components, key=lambda x: x.id_start))
+
+
+@dataclass
+class SampleDesc:
+ """Sample description extracted from McStas instrument xml description."""
+
+ # From
+ component_type: str
+ name: str
+ # From under
+ position: sc.Variable
+ rotation_matrix: sc.Variable
+
+ @classmethod
+ def from_xml(
+ cls, *, tree: _XML, simulation_settings: SimulationSettings
+ ) -> 'SampleDesc':
+ """Create sample description from xml component."""
+ source_xml = select_by_type_prefix(tree, 'sampleMantid-type')
+ location = select_by_tag(source_xml, 'location')
+
+ return cls(
+ component_type=source_xml.attrib['type'],
+ name=source_xml.attrib['name'],
+ position=_position_from_location(location, simulation_settings.length_unit),
+ rotation_matrix=_rotation_matrix_from_location(
+ location, simulation_settings.angle_unit
+ ),
+ )
+
+ def position_from_sample(self, other: sc.Variable) -> sc.Variable:
+ """Position of ``other`` relative to the sample.
+
+ All positions and distance are stored relative to the sample position.
+
+ Parameters
+ ----------
+ other:
+ Position of the other object in 3D vector.
+
+ """
+
+ return other - self.position
+
+
+@dataclass
+class SourceDesc:
+ """Source description extracted from McStas instrument xml description."""
+
+ # From
+ component_type: str
+ name: str
+ # From under
+ position: sc.Variable
+
+ @classmethod
+ def from_xml(
+ cls, *, tree: _XML, simulation_settings: SimulationSettings
+ ) -> 'SourceDesc':
+ """Create source description from xml component."""
+ source_xml = select_by_type_prefix(tree, 'sourceMantid-type')
+ location = select_by_tag(source_xml, 'location')
+
+ return cls(
+ component_type=source_xml.attrib['type'],
+ name=source_xml.attrib['name'],
+ position=_position_from_location(location, simulation_settings.length_unit),
+ )
+
+
+def _construct_pixel_ids(detector_descs: Tuple[DetectorDesc, ...]) -> sc.Variable:
+ """Pixel IDs for all detectors."""
+ intervals = [
+ (desc.id_start, desc.id_start + desc.total_pixels) for desc in detector_descs
+ ]
+ ids = [sc.arange('id', start, stop, unit=None) for start, stop in intervals]
+ return sc.concat(ids, 'id')
+
+
+def _pixel_positions(
+ detector: DetectorDesc, position_offset: sc.Variable
+) -> sc.Variable:
+ """Position of pixels of the ``detector``.
+
+ Position of each pixel is relative to the position_offset.
+ """
+ pixel_idx = sc.arange('id', detector.total_pixels)
+ n_col = sc.scalar(detector.num_fast_pixels_per_row)
+
+ pixel_n_slow = pixel_idx // n_col
+ pixel_n_fast = pixel_idx % n_col
+
+ fast_axis_steps = detector.fast_axis * detector.fast_step
+ slow_axis_steps = detector.slow_axis * detector.slow_step
+
+ return (
+ (pixel_n_slow * slow_axis_steps)
+ + (pixel_n_fast * fast_axis_steps)
+ + position_offset
+ )
+
+
+def _detector_pixel_positions(
+ detector_descs: Tuple[DetectorDesc, ...], sample: SampleDesc
+) -> sc.Variable:
+ """Position of pixels of all detectors."""
+
+ positions = [
+ _pixel_positions(detector, sample.position_from_sample(detector.position))
+ for detector in detector_descs
+ ]
+ return sc.concat(positions, 'panel')
+
+
+@dataclass
+class McStasInstrument:
+ simulation_settings: SimulationSettings
+ detectors: Tuple[DetectorDesc, ...]
+ source: SourceDesc
+ sample: SampleDesc
+
+ @classmethod
+ def from_xml(cls, tree: _XML) -> 'McStasInstrument':
+ """Create McStas instrument from xml."""
+ simulation_settings = SimulationSettings.from_xml(tree)
+
+ return cls(
+ simulation_settings=simulation_settings,
+ detectors=_collect_detector_descriptions(tree),
+ source=SourceDesc.from_xml(
+ tree=tree, simulation_settings=simulation_settings
+ ),
+ sample=SampleDesc.from_xml(
+ tree=tree, simulation_settings=simulation_settings
+ ),
+ )
+
+ def to_coords(self) -> dict[str, sc.Variable]:
+ """Extract coordinates from the McStas instrument description."""
+ slow_axes = [det.slow_axis for det in self.detectors]
+ fast_axes = [det.fast_axis for det in self.detectors]
+ origins = [
+ self.sample.position_from_sample(det.position) for det in self.detectors
+ ]
+ detector_dim = 'panel'
+
+ return {
+ 'pixel_id': _construct_pixel_ids(self.detectors),
+ 'fast_axis': sc.concat(fast_axes, detector_dim),
+ 'slow_axis': sc.concat(slow_axes, detector_dim),
+ 'origin_position': sc.concat(origins, detector_dim),
+ 'sample_position': self.sample.position_from_sample(self.sample.position),
+ 'source_position': self.sample.position_from_sample(self.source.position),
+ 'sample_name': sc.scalar(self.sample.name),
+ 'position': _detector_pixel_positions(self.detectors, self.sample),
+ }
+
+
+def read_mcstas_geometry_xml(file_path: Union[Path, str]) -> McStasInstrument:
+ """Retrieve geometry parameters from mcstas file"""
+ import h5py
+ from defusedxml.ElementTree import fromstring
+
+ instrument_xml_path = 'entry1/instrument/instrument_xml/data'
+ with h5py.File(file_path) as file:
+ tree = fromstring(file[instrument_xml_path][...][0])
+ return McStasInstrument.from_xml(tree)
diff --git a/src/ess/nmx/rotation.py b/src/ess/nmx/rotation.py
new file mode 100644
index 0000000..4ec91c8
--- /dev/null
+++ b/src/ess/nmx/rotation.py
@@ -0,0 +1,70 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+# Rotation related functions for NMX
+import numpy as np
+import scipp as sc
+from numpy.typing import NDArray
+
+
+def axis_angle_to_quaternion(
+ *, x: float, y: float, z: float, theta: sc.Variable
+) -> NDArray:
+ """Convert axis-angle to queternions, [x, y, z, w].
+
+ Parameters
+ ----------
+ x:
+ X component of axis of rotation.
+ y:
+ Y component of axis of rotation.
+ z:
+ Z component of axis of rotation.
+ theta:
+ Angle of rotation, with unit of ``rad`` or ``deg``.
+
+ Returns
+ -------
+ :
+ A list of (normalized) quaternions, [x, y, z, w].
+
+ Notes
+ -----
+ Axis of rotation (x, y, z) does not need to be normalized,
+ but it returns a unit quaternion (x, y, z, w).
+
+ """
+
+ w: sc.Variable = sc.cos(theta.to(unit='rad') / 2)
+ xyz: sc.Variable = -sc.sin(theta.to(unit='rad') / 2) * sc.vector([x, y, z])
+ q = np.array([*xyz.values, w.value])
+ return q / np.linalg.norm(q)
+
+
+def quaternion_to_matrix(*, x: float, y: float, z: float, w: float) -> sc.Variable:
+ """Convert quaternion to rotation matrix.
+
+ Parameters
+ ----------
+ x:
+ x(a) component of quaternion.
+ y:
+ y(b) component of quaternion.
+ z:
+ z(c) component of quaternion.
+ w:
+ w component of quaternion.
+
+ Returns
+ -------
+ :
+ A 3x3 rotation matrix.
+
+ """
+ from scipy.spatial.transform import Rotation
+
+ return sc.spatial.rotations_from_rotvecs(
+ rotation_vectors=sc.vector(
+ Rotation.from_quat([x, y, z, w]).as_rotvec(),
+ unit='rad',
+ )
+ )
diff --git a/tests/loader_test.py b/tests/loader_test.py
index cef49f2..2b969c8 100644
--- a/tests/loader_test.py
+++ b/tests/loader_test.py
@@ -10,6 +10,7 @@
def test_file_reader_mcstas() -> None:
+ import numpy as np
import scippnexus as snx
from ess.nmx.mcstas_loader import (
@@ -30,8 +31,27 @@ def test_file_reader_mcstas() -> None:
assert isinstance(da, sc.DataArray)
assert da.shape == (3, 1280 * 1280)
+ assert sc.identical(
+ da.coords['sample_position'], sc.vector(value=[0, 0, 0], unit='m')
+ )
assert da.bins.size().sum().value == data_length
assert sc.identical(da.data.max(), expected_weight_max)
+ # Expected coordinate values are provided by the IDS
+ # based on the simulation settings of the sample file.
+ # The expected values are rounded to 2 decimal places.
+ assert np.all(
+ np.round(da.coords['fast_axis'].values, 2)
+ == sc.vectors(
+ dims=['panel'],
+ values=[(1.0, 0.0, -0.01), (-0.01, 0.0, -1.0), (0.01, 0.0, 1.0)],
+ ).values,
+ )
+ assert sc.identical(
+ da.coords['slow_axis'],
+ sc.vectors(
+ dims=['panel'], values=[[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]
+ ),
+ )
@pytest.fixture