From 65b1be51590ff72ad7b1588f9fac3d745ecedeb9 Mon Sep 17 00:00:00 2001 From: YooSunyoung Date: Mon, 8 Jan 2024 18:45:57 +0100 Subject: [PATCH 01/10] Geometry parsing draft --- src/ess/nmx/mcstas_loader.py | 273 ++++++++++++++++++++++++++++++++++- src/ess/nmx/rotation.py | 70 +++++++++ 2 files changed, 337 insertions(+), 6 deletions(-) create mode 100644 src/ess/nmx/rotation.py diff --git a/src/ess/nmx/mcstas_loader.py b/src/ess/nmx/mcstas_loader.py index 9085542..103fd34 100644 --- a/src/ess/nmx/mcstas_loader.py +++ b/src/ess/nmx/mcstas_loader.py @@ -1,19 +1,128 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import Iterable, NewType, Optional +from dataclasses import dataclass +from types import MappingProxyType +from typing import Iterable, NamedTuple, NewType, Optional, Protocol, Tuple +import numpy as np import scipp as sc import scippnexus as snx +from typing_extensions import Self PixelIDs = NewType("PixelIDs", sc.Variable) InputFilepath = NewType("InputFilepath", str) -NMXData = NewType("NMXData", sc.DataArray) +NMXData = NewType("NMXData", sc.DataGroup) # McStas Configurations MaximumProbability = NewType("MaximumProbability", int) DefaultMaximumProbability = MaximumProbability(100_000) +AXIS_TO_VECTOR = MappingProxyType( + { + 'x': sc.vector([1.0, 0.0, 0.0]), + 'y': sc.vector([0.0, 1.0, 0.0]), + 'z': sc.vector([0.0, 0.0, 1.0]), + } +) + + +class _XML(Protocol): + """XML element type. Temporarily used for type hinting. + + Builtin XML type is blocked by bandit security check.""" + + tag: str + attrib: dict[str, str] + + def find(self, name: str) -> Optional[Self]: + ... + + def __iter__(self) -> Self: + ... + + def __next__(self) -> Self: + ... + + +class Position3D(NamedTuple): + """3D vector of location.""" + + x: float + y: float + z: float + + +class RotationAxisAngle(NamedTuple): + """Rotation in axis-angle representation.""" + + theta: float + x: float + y: float + z: float + + +@dataclass +class DetectorDesc: + """Combined information of detector and detector type in McStas.""" + + component_type: str # 'type' + name: str + id_start: int # 'idstart' + fast_axis_name: str # 'idfillbyfirst' + position: sc.Variable # 'x', 'y', 'z' + rotation: RotationAxisAngle + num_x: int # 'xpixels' + num_y: int # 'ypixels' + step_x: float # 'xstep' + step_y: float # 'ystep' + # Calculated fields + _rotation_matrix: Optional[sc.Variable] = None + _fast_axis: Optional[sc.Variable] = None + _slow_axis: Optional[sc.Variable] = None + + @property + def total_pixels(self) -> int: + return self.num_x * self.num_y + + @property + def slow_axis_name(self) -> str: + if self.fast_axis_name not in 'xy': + raise ValueError( + f"Invalid slow axis {self.fast_axis_name}.Should be 'x' or 'y'." + ) + + return 'xy'.replace(self.fast_axis_name, '') + + @property + def rotation_matrix(self) -> sc.Variable: + if self._rotation_matrix is None: + from .rotation import axis_angle_to_quaternion, quaternion_to_matrix + + theta, x, y, z = self.rotation + q = axis_angle_to_quaternion(x, y, z, sc.scalar(-theta, unit='deg')) + self._rotation_matrix = quaternion_to_matrix(*q) + + return self._rotation_matrix + + def _rotate_axis(self, axis: sc.Variable) -> sc.Variable: + return sc.vector(np.round((self.rotation_matrix * axis).values, 2)) + + @property + def fast_axis(self) -> sc.Variable: + if self._fast_axis is None: + self._fast_axis = self._rotate_axis(AXIS_TO_VECTOR[self.fast_axis_name]) + + return self._fast_axis + + @property + def slow_axis(self) -> sc.Variable: + if self._slow_axis is None: + self._slow_axis = self._rotate_axis(AXIS_TO_VECTOR[self.slow_axis_name]) + + return self._slow_axis + + def _retrieve_event_list_name(keys: Iterable[str]) -> str: prefix = "bank01_events_dat_list" @@ -37,13 +146,142 @@ def _copy_partial_var( return var -def _get_mcstas_pixel_ids() -> PixelIDs: +def _get_mcstas_pixel_ids(detector_descs: Tuple[DetectorDesc, ...]) -> PixelIDs: """pixel IDs for each detector""" - intervals = [(1, 1638401), (2000001, 3638401), (4000001, 5638401)] + intervals = [ + (desc.id_start, desc.id_start + desc.total_pixels) for desc in detector_descs + ] ids = [sc.arange('id', start, stop, unit=None) for start, stop in intervals] return PixelIDs(sc.concat(ids, 'id')) +def _pixel_positions(detector: DetectorDesc, sample_position: sc.Variable): + """pixel IDs for each detector""" + pixel_idx = sc.arange('id', detector.total_pixels) + n_rows = sc.scalar( + detector.num_x if detector.fast_axis_name == 'x' else detector.num_y + ) + steps = { + 'x': sc.scalar(detector.step_x, unit='m'), + 'y': sc.scalar(detector.step_y, unit='m'), + } + + pixel_n_slow = pixel_idx // n_rows + pixel_n_fast = pixel_idx % n_rows + + fast_axis_steps = detector.fast_axis * steps[detector.fast_axis_name] + slow_axis_steps = detector.slow_axis * steps[detector.slow_axis_name] + + return ( + (pixel_n_slow * slow_axis_steps) + + (pixel_n_fast * fast_axis_steps) + + (detector.position - sample_position) + ) + + +def _get_mcstas_pixel_positions( + detector_descs: Tuple[DetectorDesc, ...], sample_position +): + """pixel IDs for each detector""" + positions = [ + _pixel_positions(detector, sample_position) for detector in detector_descs + ] + return sc.concat(positions, 'panel') + + +def _read_mcstas_geometry_xml(file_path: InputFilepath) -> bytes: + """Retrieve geometry parameters from mcstas file""" + import h5py + + instrument_xml_path = 'entry1/instrument/instrument_xml/data' + with h5py.File(file_path) as file: + return file[instrument_xml_path][...][0] + + +def _select_by_type_prefix(components: list[_XML], prefix: str) -> list[_XML]: + """Select components by type prefix.""" + return [comp for comp in components if comp.attrib['type'].startswith(prefix)] + + +def _check_and_unpack_if_only_one(xml_items: list[_XML], name: str) -> _XML: + """Check if there is only one element with ``name``.""" + if len(xml_items) > 1: + raise ValueError(f"Multiple {name}s found.") + elif len(xml_items) == 0: + raise ValueError(f"No {name} found.") + + return xml_items.pop() + + +def _retrieve_attribs(component: _XML, *args: str) -> list[float]: + """Retrieve ``args`` from xml.""" + + return [float(component.attrib[key]) for key in args] + + +def find_location(component: _XML) -> _XML: + """Retrieve ``location`` from xml component.""" + location = component.find('location') + if location is None: + raise ValueError("No location found in component ", component.find('name')) + + return location + + +def _retrieve_3d_position(component: _XML) -> sc.Variable: + """Retrieve x, y, z position from xml.""" + location = find_location(component) + + return sc.vector(_retrieve_attribs(location, 'x', 'y', 'z'), unit='m') + + +def _retrieve_detector_descriptions(tree: _XML) -> Tuple[DetectorDesc, ...]: + """Retrieve detector geometry descriptions from mcstas file.""" + + def _retrieve_rotation_axis_angle(component: _XML) -> RotationAxisAngle: + """Retrieve rotation angle(theta), x, y, z axes from location.""" + location = find_location(component) + return RotationAxisAngle( + *_retrieve_attribs(location, 'rot', 'axis-x', 'axis-y', 'axis-z') + ) + + def _find_type_desc(det: _XML, types: list[_XML]) -> _XML: + for type_ in types: + if type_.attrib['name'] == det.attrib['type']: + return type_ + + raise ValueError( + f"Cannot find type {det.attrib['type']} for {det.attrib['name']}." + ) + + components = [branch for branch in tree if branch.tag == 'component'] + detectors = [ + comp for comp in components if comp.attrib['type'].startswith('MonNDtype') + ] + type_list = [branch for branch in tree if branch.tag == 'type'] + + detector_components = [] + for det in detectors: + det_type = _find_type_desc(det, type_list) + + detector_components.append( + DetectorDesc( + component_type=det_type.attrib['name'], + name=det.attrib['name'], + id_start=int(det.attrib['idstart']), + fast_axis_name=det.attrib['idfillbyfirst'], + position=_retrieve_3d_position(det), + rotation=RotationAxisAngle(*_retrieve_rotation_axis_angle(det)), + num_x=int(det_type.attrib['xpixels']), + num_y=int(det_type.attrib['ypixels']), + step_x=float(det_type.attrib['xstep']), + step_y=float(det_type.attrib['ystep']), + ) + ) + + return tuple(sorted(detector_components, key=lambda x: x.id_start)) + + def load_mcstas_nexus( file_path: InputFilepath, max_probability: Optional[MaximumProbability] = None, @@ -59,6 +297,20 @@ def load_mcstas_nexus( The maximum probability to scale the weights. """ + from defusedxml.ElementTree import fromstring + + tree = fromstring(_read_mcstas_geometry_xml(file_path)) + detector_descs = _retrieve_detector_descriptions(tree) + components = [branch for branch in tree if branch.tag == 'component'] + sources = _select_by_type_prefix(components, 'sourceMantid-type') + samples = _select_by_type_prefix(components, 'sampleMantid-type') + source = _check_and_unpack_if_only_one(sources, 'source') + sample = _check_and_unpack_if_only_one(samples, 'sample') + sample_position = _retrieve_3d_position(sample) + + slow_axes = [det.slow_axis for det in detector_descs] + fast_axes = [det.fast_axis for det in detector_descs] + origins = [det.position - sample_position for det in detector_descs] probability = max_probability or DefaultMaximumProbability @@ -76,6 +328,15 @@ def load_mcstas_nexus( weights = (probability / weights.max()) * weights loaded = sc.DataArray(data=weights, coords={'t': t_list, 'id': id_list}) - grouped = loaded.group(_get_mcstas_pixel_ids()) + grouped = loaded.group(_get_mcstas_pixel_ids(detector_descs)) + da = grouped.fold(dim='id', sizes={'panel': len(detector_descs), 'id': -1}) + da.coords['fast_axis'] = sc.concat(fast_axes, 'panel') + da.coords['slow_axis'] = sc.concat(slow_axes, 'panel') + da.coords['origin_position'] = sc.concat(origins, 'panel') + da.coords['position'] = _get_mcstas_pixel_positions( + detector_descs, sample_position + ) + da.coords['sample_position'] = sample_position - sample_position + da.coords['source_position'] = _retrieve_3d_position(source) - sample_position - return NMXData(grouped.fold(dim='id', sizes={'panel': 3, 'id': -1})) + return NMXData(da) diff --git a/src/ess/nmx/rotation.py b/src/ess/nmx/rotation.py new file mode 100644 index 0000000..98beab8 --- /dev/null +++ b/src/ess/nmx/rotation.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +# Rotation related functions for NMX +import numpy as np +import scipp as sc +from numpy.typing import NDArray + + +def axis_angle_to_quaternion( + x: float, y: float, z: float, theta: sc.Variable +) -> NDArray: + """Convert axis-angle to queternions, [x, y, z, w]. + + Parameters + ---------- + x: + X component of axis of rotation. + y: + Y component of axis of rotation. + z: + Z component of axis of rotation. + theta: + Angle of rotation, with unit of ``rad`` or ``deg``. + + Returns + ------- + : + A list of (normalized) queternions, [x, y, z, w]. + + Notes + ----- + Axis of rotation (x, y, z) does not need to be normalized, + but it returns a unit quaternion (x, y, z, w). + + """ + + w: sc.Variable = sc.cos(theta.to(unit='rad') / 2) + xyz: sc.Variable = -sc.sin(theta.to(unit='rad') / 2) * sc.vector([x, y, z]) + q = np.array([*xyz.values, w.value]) + return q / np.linalg.norm(q) + + +def quaternion_to_matrix(x: float, y: float, z: float, w: float) -> sc.Variable: + """Convert quaternion to rotation matrix. + + Parameters + ---------- + x: + x(a) component of quaternion. + y: + y(b) component of quaternion. + z: + z(c) component of quaternion. + w: + w component of quaternion. + + Returns + ------- + : + A 3X3 rotation matrix (3 vectors). + + """ + from scipy.spatial.transform import Rotation + + return sc.spatial.rotations_from_rotvecs( + rotation_vectors=sc.vector( + Rotation.from_quat([x, y, z, w]).as_rotvec(), + unit='rad', + ) + ) From 6e94c66148fec5fa2adb6ac347e4be624643e0f0 Mon Sep 17 00:00:00 2001 From: YooSunyoung Date: Tue, 9 Jan 2024 14:21:28 +0100 Subject: [PATCH 02/10] McStas data loader with mcstas geometry xml parser. --- docs/examples/workflow.ipynb | 33 ++- src/ess/nmx/mcstas_loader.py | 276 +----------------------- src/ess/nmx/mcstas_xml.py | 393 +++++++++++++++++++++++++++++++++++ 3 files changed, 431 insertions(+), 271 deletions(-) create mode 100644 src/ess/nmx/mcstas_xml.py diff --git a/docs/examples/workflow.ipynb b/docs/examples/workflow.ipynb index 5a0551c..5b47571 100644 --- a/docs/examples/workflow.ipynb +++ b/docs/examples/workflow.ipynb @@ -91,11 +91,40 @@ "da = nmx_workflow.compute(NMXData)\n", "da" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instrument View\n", + "\n", + "Pixel positions are not used for later steps, but it is included in the coordinates for instrument view.\n", + "\n", + "All pixel positions are respect to the sample position, therefore the sample is at (0, 0, 0).\n", + "\n", + "**It might be very slow or not work in the ``VS Code`` jupyter notebook editor.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import scippneutron as scn\n", + "\n", + "\n", + "unnecessary_coords = list(coord for coord in da.coords if coord != 'position')\n", + "instrument_view_da = da.drop_coords(unnecessary_coords).flatten(['panel', 'id'], 'id').hist()\n", + "view = scn.instrument_view(instrument_view_da)\n", + "view.children[0].toolbar.cameraz()\n", + "view" + ] } ], "metadata": { "kernelspec": { - "display_name": "nmx-dev-39", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -113,5 +142,5 @@ } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/src/ess/nmx/mcstas_loader.py b/src/ess/nmx/mcstas_loader.py index 103fd34..92dbf55 100644 --- a/src/ess/nmx/mcstas_loader.py +++ b/src/ess/nmx/mcstas_loader.py @@ -1,13 +1,9 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from dataclasses import dataclass -from types import MappingProxyType -from typing import Iterable, NamedTuple, NewType, Optional, Protocol, Tuple +from typing import Iterable, NewType, Optional -import numpy as np import scipp as sc import scippnexus as snx -from typing_extensions import Self PixelIDs = NewType("PixelIDs", sc.Variable) InputFilepath = NewType("InputFilepath", str) @@ -18,111 +14,6 @@ DefaultMaximumProbability = MaximumProbability(100_000) -AXIS_TO_VECTOR = MappingProxyType( - { - 'x': sc.vector([1.0, 0.0, 0.0]), - 'y': sc.vector([0.0, 1.0, 0.0]), - 'z': sc.vector([0.0, 0.0, 1.0]), - } -) - - -class _XML(Protocol): - """XML element type. Temporarily used for type hinting. - - Builtin XML type is blocked by bandit security check.""" - - tag: str - attrib: dict[str, str] - - def find(self, name: str) -> Optional[Self]: - ... - - def __iter__(self) -> Self: - ... - - def __next__(self) -> Self: - ... - - -class Position3D(NamedTuple): - """3D vector of location.""" - - x: float - y: float - z: float - - -class RotationAxisAngle(NamedTuple): - """Rotation in axis-angle representation.""" - - theta: float - x: float - y: float - z: float - - -@dataclass -class DetectorDesc: - """Combined information of detector and detector type in McStas.""" - - component_type: str # 'type' - name: str - id_start: int # 'idstart' - fast_axis_name: str # 'idfillbyfirst' - position: sc.Variable # 'x', 'y', 'z' - rotation: RotationAxisAngle - num_x: int # 'xpixels' - num_y: int # 'ypixels' - step_x: float # 'xstep' - step_y: float # 'ystep' - # Calculated fields - _rotation_matrix: Optional[sc.Variable] = None - _fast_axis: Optional[sc.Variable] = None - _slow_axis: Optional[sc.Variable] = None - - @property - def total_pixels(self) -> int: - return self.num_x * self.num_y - - @property - def slow_axis_name(self) -> str: - if self.fast_axis_name not in 'xy': - raise ValueError( - f"Invalid slow axis {self.fast_axis_name}.Should be 'x' or 'y'." - ) - - return 'xy'.replace(self.fast_axis_name, '') - - @property - def rotation_matrix(self) -> sc.Variable: - if self._rotation_matrix is None: - from .rotation import axis_angle_to_quaternion, quaternion_to_matrix - - theta, x, y, z = self.rotation - q = axis_angle_to_quaternion(x, y, z, sc.scalar(-theta, unit='deg')) - self._rotation_matrix = quaternion_to_matrix(*q) - - return self._rotation_matrix - - def _rotate_axis(self, axis: sc.Variable) -> sc.Variable: - return sc.vector(np.round((self.rotation_matrix * axis).values, 2)) - - @property - def fast_axis(self) -> sc.Variable: - if self._fast_axis is None: - self._fast_axis = self._rotate_axis(AXIS_TO_VECTOR[self.fast_axis_name]) - - return self._fast_axis - - @property - def slow_axis(self) -> sc.Variable: - if self._slow_axis is None: - self._slow_axis = self._rotate_axis(AXIS_TO_VECTOR[self.slow_axis_name]) - - return self._slow_axis - - def _retrieve_event_list_name(keys: Iterable[str]) -> str: prefix = "bank01_events_dat_list" @@ -146,142 +37,6 @@ def _copy_partial_var( return var -def _get_mcstas_pixel_ids(detector_descs: Tuple[DetectorDesc, ...]) -> PixelIDs: - """pixel IDs for each detector""" - intervals = [ - (desc.id_start, desc.id_start + desc.total_pixels) for desc in detector_descs - ] - ids = [sc.arange('id', start, stop, unit=None) for start, stop in intervals] - return PixelIDs(sc.concat(ids, 'id')) - - -def _pixel_positions(detector: DetectorDesc, sample_position: sc.Variable): - """pixel IDs for each detector""" - pixel_idx = sc.arange('id', detector.total_pixels) - n_rows = sc.scalar( - detector.num_x if detector.fast_axis_name == 'x' else detector.num_y - ) - steps = { - 'x': sc.scalar(detector.step_x, unit='m'), - 'y': sc.scalar(detector.step_y, unit='m'), - } - - pixel_n_slow = pixel_idx // n_rows - pixel_n_fast = pixel_idx % n_rows - - fast_axis_steps = detector.fast_axis * steps[detector.fast_axis_name] - slow_axis_steps = detector.slow_axis * steps[detector.slow_axis_name] - - return ( - (pixel_n_slow * slow_axis_steps) - + (pixel_n_fast * fast_axis_steps) - + (detector.position - sample_position) - ) - - -def _get_mcstas_pixel_positions( - detector_descs: Tuple[DetectorDesc, ...], sample_position -): - """pixel IDs for each detector""" - positions = [ - _pixel_positions(detector, sample_position) for detector in detector_descs - ] - return sc.concat(positions, 'panel') - - -def _read_mcstas_geometry_xml(file_path: InputFilepath) -> bytes: - """Retrieve geometry parameters from mcstas file""" - import h5py - - instrument_xml_path = 'entry1/instrument/instrument_xml/data' - with h5py.File(file_path) as file: - return file[instrument_xml_path][...][0] - - -def _select_by_type_prefix(components: list[_XML], prefix: str) -> list[_XML]: - """Select components by type prefix.""" - return [comp for comp in components if comp.attrib['type'].startswith(prefix)] - - -def _check_and_unpack_if_only_one(xml_items: list[_XML], name: str) -> _XML: - """Check if there is only one element with ``name``.""" - if len(xml_items) > 1: - raise ValueError(f"Multiple {name}s found.") - elif len(xml_items) == 0: - raise ValueError(f"No {name} found.") - - return xml_items.pop() - - -def _retrieve_attribs(component: _XML, *args: str) -> list[float]: - """Retrieve ``args`` from xml.""" - - return [float(component.attrib[key]) for key in args] - - -def find_location(component: _XML) -> _XML: - """Retrieve ``location`` from xml component.""" - location = component.find('location') - if location is None: - raise ValueError("No location found in component ", component.find('name')) - - return location - - -def _retrieve_3d_position(component: _XML) -> sc.Variable: - """Retrieve x, y, z position from xml.""" - location = find_location(component) - - return sc.vector(_retrieve_attribs(location, 'x', 'y', 'z'), unit='m') - - -def _retrieve_detector_descriptions(tree: _XML) -> Tuple[DetectorDesc, ...]: - """Retrieve detector geometry descriptions from mcstas file.""" - - def _retrieve_rotation_axis_angle(component: _XML) -> RotationAxisAngle: - """Retrieve rotation angle(theta), x, y, z axes from location.""" - location = find_location(component) - return RotationAxisAngle( - *_retrieve_attribs(location, 'rot', 'axis-x', 'axis-y', 'axis-z') - ) - - def _find_type_desc(det: _XML, types: list[_XML]) -> _XML: - for type_ in types: - if type_.attrib['name'] == det.attrib['type']: - return type_ - - raise ValueError( - f"Cannot find type {det.attrib['type']} for {det.attrib['name']}." - ) - - components = [branch for branch in tree if branch.tag == 'component'] - detectors = [ - comp for comp in components if comp.attrib['type'].startswith('MonNDtype') - ] - type_list = [branch for branch in tree if branch.tag == 'type'] - - detector_components = [] - for det in detectors: - det_type = _find_type_desc(det, type_list) - - detector_components.append( - DetectorDesc( - component_type=det_type.attrib['name'], - name=det.attrib['name'], - id_start=int(det.attrib['idstart']), - fast_axis_name=det.attrib['idfillbyfirst'], - position=_retrieve_3d_position(det), - rotation=RotationAxisAngle(*_retrieve_rotation_axis_angle(det)), - num_x=int(det_type.attrib['xpixels']), - num_y=int(det_type.attrib['ypixels']), - step_x=float(det_type.attrib['xstep']), - step_y=float(det_type.attrib['ystep']), - ) - ) - - return tuple(sorted(detector_components, key=lambda x: x.id_start)) - - def load_mcstas_nexus( file_path: InputFilepath, max_probability: Optional[MaximumProbability] = None, @@ -297,21 +52,10 @@ def load_mcstas_nexus( The maximum probability to scale the weights. """ - from defusedxml.ElementTree import fromstring - tree = fromstring(_read_mcstas_geometry_xml(file_path)) - detector_descs = _retrieve_detector_descriptions(tree) - components = [branch for branch in tree if branch.tag == 'component'] - sources = _select_by_type_prefix(components, 'sourceMantid-type') - samples = _select_by_type_prefix(components, 'sampleMantid-type') - source = _check_and_unpack_if_only_one(sources, 'source') - sample = _check_and_unpack_if_only_one(samples, 'sample') - sample_position = _retrieve_3d_position(sample) - - slow_axes = [det.slow_axis for det in detector_descs] - fast_axes = [det.fast_axis for det in detector_descs] - origins = [det.position - sample_position for det in detector_descs] + from .mcstas_xml import read_mcstas_geometry_xml + geometry = read_mcstas_geometry_xml(file_path) probability = max_probability or DefaultMaximumProbability with snx.File(file_path) as file: @@ -328,15 +72,9 @@ def load_mcstas_nexus( weights = (probability / weights.max()) * weights loaded = sc.DataArray(data=weights, coords={'t': t_list, 'id': id_list}) - grouped = loaded.group(_get_mcstas_pixel_ids(detector_descs)) - da = grouped.fold(dim='id', sizes={'panel': len(detector_descs), 'id': -1}) - da.coords['fast_axis'] = sc.concat(fast_axes, 'panel') - da.coords['slow_axis'] = sc.concat(slow_axes, 'panel') - da.coords['origin_position'] = sc.concat(origins, 'panel') - da.coords['position'] = _get_mcstas_pixel_positions( - detector_descs, sample_position - ) - da.coords['sample_position'] = sample_position - sample_position - da.coords['source_position'] = _retrieve_3d_position(source) - sample_position + coords = geometry.to_coords() + grouped = loaded.group(coords.pop('pixel_ids')) + da = grouped.fold(dim='id', sizes={'panel': len(geometry.detectors), 'id': -1}) + da.coords.update(coords) return NMXData(da) diff --git a/src/ess/nmx/mcstas_xml.py b/src/ess/nmx/mcstas_xml.py new file mode 100644 index 0000000..4ba79ef --- /dev/null +++ b/src/ess/nmx/mcstas_xml.py @@ -0,0 +1,393 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +# McStas instrument geometry xml description related functions. +from dataclasses import dataclass +from pathlib import Path +from types import MappingProxyType +from typing import Iterable, Optional, Protocol, Tuple, TypeVar, Union + +import numpy as np +import scipp as sc +from typing_extensions import Self + +T = TypeVar('T') + + +_AXISNAME_TO_UNIT_VECTOR = MappingProxyType( + { + 'x': sc.vector([1.0, 0.0, 0.0]), + 'y': sc.vector([0.0, 1.0, 0.0]), + 'z': sc.vector([0.0, 0.0, 1.0]), + } +) + + +class _XML(Protocol): + """XML element or tree type. + + Temporarily used for type hinting. + Builtin XML type is blocked by bandit security check.""" + + tag: str + attrib: dict[str, str] + + def find(self, name: str) -> Optional[Self]: + ... + + def __iter__(self) -> Self: + ... + + def __next__(self) -> Self: + ... + + +def _check_and_unpack_if_only_one(xml_items: list[_XML], name: str) -> _XML: + """Check if there is only one element with ``name``.""" + if len(xml_items) > 1: + raise ValueError(f"Multiple {name}s found.") + elif len(xml_items) == 0: + raise ValueError(f"No {name} found.") + + return xml_items.pop() + + +def select_by_tag(xml_items: _XML, tag: str) -> _XML: + """Select element with ``tag`` if there is only one.""" + + return _check_and_unpack_if_only_one(list(filter_by_tag(xml_items, tag)), tag) + + +def filter_by_tag(xml_items: Iterable[_XML], tag: str) -> Iterable[_XML]: + """Filter xml items by tag.""" + return (item for item in xml_items if item.tag == tag) + + +def filter_by_type_prefix(xml_items: Iterable[_XML], prefix: str) -> Iterable[_XML]: + """Filter xml items by type prefix.""" + return ( + item for item in xml_items if item.attrib.get('type', '').startswith(prefix) + ) + + +def select_by_type_prefix(xml_items: Iterable[_XML], prefix: str) -> _XML: + """Select xml item by type prefix.""" + + cands = list(filter_by_type_prefix(xml_items, prefix)) + return _check_and_unpack_if_only_one(cands, prefix) + + +def find_attributes(component: _XML, *args: str) -> dict[str, float]: + """Retrieve ``args`` as float from xml.""" + + return {key: float(component.attrib[key]) for key in args} + + +@dataclass +class SimulationSettings: + """Simulation settings extracted from McStas instrument xml description.""" + + # From + length_unit: str # 'unit' of + angle_unit: str # 'unit' of + # From + beam_axis: str # 'axis' of + handedness: str # 'val' of + + @classmethod + def from_xml(cls, tree: _XML) -> Self: + """Create simulation settings from xml.""" + defaults = select_by_tag(tree, 'defaults') + length_desc = select_by_tag(defaults, 'length') + angle_desc = select_by_tag(defaults, 'angle') + reference_frame = select_by_tag(defaults, 'reference-frame') + along_beam = select_by_tag(reference_frame, 'along-beam') + handedness = select_by_tag(reference_frame, 'handedness') + + return cls( + length_unit=length_desc.attrib['unit'], + angle_unit=angle_desc.attrib['unit'], + beam_axis=along_beam.attrib['axis'], + handedness=handedness.attrib['val'], + ) + + +def _position_from_location(location: _XML, unit: str = 'm') -> sc.Variable: + """Retrieve position from location.""" + x, y, z = find_attributes(location, 'x', 'y', 'z').values() + return sc.vector([x, y, z], unit=unit) + + +def _rotation_matrix_from_location( + location: _XML, angle_unit: str = 'degree' +) -> sc.Variable: + """Retrieve rotation matrix from location.""" + from .rotation import axis_angle_to_quaternion, quaternion_to_matrix + + theta, x, y, z = find_attributes( + location, 'rot', 'axis-x', 'axis-y', 'axis-z' + ).values() + q = axis_angle_to_quaternion(x, y, z, sc.scalar(-theta, unit=angle_unit)) + return quaternion_to_matrix(*q) + + +@dataclass +class DetectorDesc: + """Detector information extracted from McStas instrument xml description.""" + + # From + component_type: str # 'type' + name: str + id_start: int # 'idstart' + fast_axis_name: str # 'idfillbyfirst' + # From + num_x: int # 'xpixels' + num_y: int # 'ypixels' + step_x: sc.Variable # 'xstep' + step_y: sc.Variable # 'ystep' + start_x: float # 'xstart' + start_y: float # 'ystart' + # From under + position: sc.Variable # 'x', 'y', 'z' + # Calculated fields + rotation_matrix: sc.Variable + slow_axis_name: str + fast_axis: sc.Variable + slow_axis: sc.Variable + + @classmethod + def from_xml( + cls, component: _XML, type_desc: _XML, simulation_settings: SimulationSettings + ) -> Self: + """Create detector description from xml component and type.""" + + def _rotate_axis(matrix: sc.Variable, axis: sc.Variable) -> sc.Variable: + return sc.vector(np.round((matrix * axis).values, 2)) + + location = select_by_tag(component, 'location') + rotation_matrix = _rotation_matrix_from_location( + location, simulation_settings.angle_unit + ) + fast_axis_name = component.attrib['idfillbyfirst'] + slow_axis_name = 'xy'.replace(fast_axis_name, '') + + length_unit = simulation_settings.length_unit + + return cls( + component_type=type_desc.attrib['name'], + name=component.attrib['name'], + id_start=int(component.attrib['idstart']), + fast_axis_name=fast_axis_name, + slow_axis_name=slow_axis_name, + num_x=int(type_desc.attrib['xpixels']), + num_y=int(type_desc.attrib['ypixels']), + step_x=sc.scalar(float(type_desc.attrib['xstep']), unit=length_unit), + step_y=sc.scalar(float(type_desc.attrib['ystep']), unit=length_unit), + start_x=float(type_desc.attrib['xstart']), + start_y=float(type_desc.attrib['ystart']), + position=_position_from_location(location, simulation_settings.length_unit), + rotation_matrix=rotation_matrix, + fast_axis=_rotate_axis( + rotation_matrix, _AXISNAME_TO_UNIT_VECTOR[fast_axis_name] + ), + slow_axis=_rotate_axis( + rotation_matrix, _AXISNAME_TO_UNIT_VECTOR[slow_axis_name] + ), + ) + + @property + def total_pixels(self) -> int: + return self.num_x * self.num_y + + @property + def slow_step(self) -> sc.Variable: + return self.step_y if self.fast_axis_name == 'x' else self.step_x + + @property + def fast_step(self) -> sc.Variable: + return self.step_x if self.fast_axis_name == 'x' else self.step_y + + @property + def num_fast_pixels_per_row(self) -> int: + """Number of pixels in each row of the detector along the fast axis.""" + return self.num_x if self.fast_axis_name == 'x' else self.num_y + + +def _collect_detector_descriptions(tree: _XML) -> Tuple[DetectorDesc, ...]: + """Retrieve detector geometry descriptions from mcstas file.""" + type_list = filter_by_tag(tree, 'type') + simulation_settings = SimulationSettings.from_xml(tree) + + def _find_type_desc(det: _XML) -> _XML: + for type_ in type_list: + if type_.attrib['name'] == det.attrib['type']: + return type_ + + raise ValueError( + f"Cannot find type {det.attrib['type']} for {det.attrib['name']}." + ) + + detector_components = [ + DetectorDesc.from_xml(det, _find_type_desc(det), simulation_settings) + for det in filter_by_type_prefix(filter_by_tag(tree, 'component'), 'MonNDtype') + ] + + return tuple(sorted(detector_components, key=lambda x: x.id_start)) + + +@dataclass +class SampleDesc: + """Sample description extracted from McStas instrument xml description.""" + + # From + component_type: str + name: str + # From under + position: sc.Variable + rotation_matrix: sc.Variable + + @classmethod + def from_xml(cls, tree: _XML, simulation_settings: SimulationSettings) -> Self: + """Create sample description from xml component.""" + source_xml = select_by_type_prefix(tree, 'sampleMantid-type') + location = select_by_tag(source_xml, 'location') + + return cls( + component_type=source_xml.attrib['type'], + name=source_xml.attrib['name'], + position=_position_from_location(location, simulation_settings.length_unit), + rotation_matrix=_rotation_matrix_from_location( + location, simulation_settings.angle_unit + ), + ) + + def position_from_sample(self, other: sc.Variable) -> sc.Variable: + """Position of ``other`` relative to the sample. + + All positions and distance are stored respect to the sample position. + + Parameters + ---------- + other: + Position of the other object in 3D vector. + + """ + + return other - self.position + + +@dataclass +class SourceDesc: + """Source description extracted from McStas instrument xml description.""" + + # From + component_type: str + name: str + # From under + position: sc.Variable + + @classmethod + def from_xml(cls, tree: _XML, simulation_settings: SimulationSettings) -> Self: + """Create source description from xml component.""" + source_xml = select_by_type_prefix(tree, 'sourceMantid-type') + location = select_by_tag(source_xml, 'location') + + return cls( + component_type=source_xml.attrib['type'], + name=source_xml.attrib['name'], + position=_position_from_location(location, simulation_settings.length_unit), + ) + + +def _construct_pixel_ids(detector_descs: Tuple[DetectorDesc, ...]) -> sc.Variable: + """Pixel IDs for all detectors.""" + intervals = [ + (desc.id_start, desc.id_start + desc.total_pixels) for desc in detector_descs + ] + ids = [sc.arange('id', start, stop, unit=None) for start, stop in intervals] + return sc.concat(ids, 'id') + + +def _pixel_positions( + detector: DetectorDesc, position_offset: sc.Variable +) -> sc.Variable: + """Position of pixels of the ``detector``. + + Position of each pixel is relative to the position_offset. + """ + pixel_idx = sc.arange('id', detector.total_pixels) + n_row = sc.scalar(detector.num_fast_pixels_per_row) + + pixel_n_slow = pixel_idx // n_row + pixel_n_fast = pixel_idx % n_row + + fast_axis_steps = detector.fast_axis * detector.fast_step + slow_axis_steps = detector.slow_axis * detector.slow_step + + return ( + (pixel_n_slow * slow_axis_steps) + + (pixel_n_fast * fast_axis_steps) + + position_offset + ) + + +def _detector_pixel_positions( + detector_descs: Tuple[DetectorDesc, ...], sample: SampleDesc +) -> sc.Variable: + """Position of pixels of all detectors.""" + + positions = [ + _pixel_positions(detector, sample.position_from_sample(detector.position)) + for detector in detector_descs + ] + return sc.concat(positions, 'panel') + + +@dataclass +class McStasInstrument: + simulation_settings: SimulationSettings + detectors: Tuple[DetectorDesc, ...] + source: SourceDesc + sample: SampleDesc + + @classmethod + def from_xml(cls, tree: _XML) -> Self: + """Create McStas instrument from xml.""" + simulation_settings = SimulationSettings.from_xml(tree) + + return cls( + simulation_settings=simulation_settings, + detectors=_collect_detector_descriptions(tree), + source=SourceDesc.from_xml(tree, simulation_settings), + sample=SampleDesc.from_xml(tree, simulation_settings), + ) + + def to_coords(self) -> dict[str, sc.Variable]: + """Extract coordinates from the McStas instrument description.""" + slow_axes = [det.slow_axis for det in self.detectors] + fast_axes = [det.fast_axis for det in self.detectors] + origins = [ + self.sample.position_from_sample(det.position) for det in self.detectors + ] + detector_dim = 'panel' + + return { + 'pixel_ids': _construct_pixel_ids(self.detectors), + 'fast_axis': sc.concat(fast_axes, detector_dim), + 'slow_axis': sc.concat(slow_axes, detector_dim), + 'origin_position': sc.concat(origins, detector_dim), + 'sample_position': self.sample.position_from_sample(self.sample.position), + 'source_position': self.sample.position_from_sample(self.source.position), + 'sample_name': sc.scalar(self.sample.name), + 'position': _detector_pixel_positions(self.detectors, self.sample), + } + + +def read_mcstas_geometry_xml(file_path: Union[Path, str]) -> McStasInstrument: + """Retrieve geometry parameters from mcstas file""" + import h5py + from defusedxml.ElementTree import fromstring + + instrument_xml_path = 'entry1/instrument/instrument_xml/data' + with h5py.File(file_path) as file: + tree = fromstring(file[instrument_xml_path][...][0]) + return McStasInstrument.from_xml(tree) From ae8ebd01b47ac3ee15e2df3b4635e502d1122d29 Mon Sep 17 00:00:00 2001 From: YooSunyoung Date: Tue, 9 Jan 2024 14:34:14 +0100 Subject: [PATCH 03/10] Use string-based type-hinting instead of self. --- docs/examples/workflow.ipynb | 3 ++- src/ess/nmx/mcstas_xml.py | 21 ++++++++++++--------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/docs/examples/workflow.ipynb b/docs/examples/workflow.ipynb index 5b47571..1f8977b 100644 --- a/docs/examples/workflow.ipynb +++ b/docs/examples/workflow.ipynb @@ -9,7 +9,8 @@ "## Collect Parameters and Providers\n", "### Simulation(McStas) Data\n", "There is a dedicated loader, ``load_mcstas_nexus`` for ``McStas`` simulation data workflow.
\n", - "``MaximumProbability`` can be manually provided to the loader to derive more realistic number of events.
\n", + "``MaximumProbability`` can be manually provided to the loader
\n", + "to derive more realistic number of events.
\n", "It is because ``weights`` are given as probability, not number of events in a McStas file.
" ] }, diff --git a/src/ess/nmx/mcstas_xml.py b/src/ess/nmx/mcstas_xml.py index 4ba79ef..1464f6f 100644 --- a/src/ess/nmx/mcstas_xml.py +++ b/src/ess/nmx/mcstas_xml.py @@ -8,7 +8,6 @@ import numpy as np import scipp as sc -from typing_extensions import Self T = TypeVar('T') @@ -31,13 +30,13 @@ class _XML(Protocol): tag: str attrib: dict[str, str] - def find(self, name: str) -> Optional[Self]: + def find(self, name: str) -> Optional['_XML']: ... - def __iter__(self) -> Self: + def __iter__(self) -> '_XML': ... - def __next__(self) -> Self: + def __next__(self) -> '_XML': ... @@ -94,7 +93,7 @@ class SimulationSettings: handedness: str # 'val' of @classmethod - def from_xml(cls, tree: _XML) -> Self: + def from_xml(cls, tree: _XML) -> 'SimulationSettings': """Create simulation settings from xml.""" defaults = select_by_tag(tree, 'defaults') length_desc = select_by_tag(defaults, 'length') @@ -157,7 +156,7 @@ class DetectorDesc: @classmethod def from_xml( cls, component: _XML, type_desc: _XML, simulation_settings: SimulationSettings - ) -> Self: + ) -> 'DetectorDesc': """Create detector description from xml component and type.""" def _rotate_axis(matrix: sc.Variable, axis: sc.Variable) -> sc.Variable: @@ -246,7 +245,9 @@ class SampleDesc: rotation_matrix: sc.Variable @classmethod - def from_xml(cls, tree: _XML, simulation_settings: SimulationSettings) -> Self: + def from_xml( + cls, tree: _XML, simulation_settings: SimulationSettings + ) -> 'SampleDesc': """Create sample description from xml component.""" source_xml = select_by_type_prefix(tree, 'sampleMantid-type') location = select_by_tag(source_xml, 'location') @@ -286,7 +287,9 @@ class SourceDesc: position: sc.Variable @classmethod - def from_xml(cls, tree: _XML, simulation_settings: SimulationSettings) -> Self: + def from_xml( + cls, tree: _XML, simulation_settings: SimulationSettings + ) -> 'SourceDesc': """Create source description from xml component.""" source_xml = select_by_type_prefix(tree, 'sourceMantid-type') location = select_by_tag(source_xml, 'location') @@ -350,7 +353,7 @@ class McStasInstrument: sample: SampleDesc @classmethod - def from_xml(cls, tree: _XML) -> Self: + def from_xml(cls, tree: _XML) -> 'McStasInstrument': """Create McStas instrument from xml.""" simulation_settings = SimulationSettings.from_xml(tree) From 7de24c188ef035b1748c1560dfe97e088ca0cbac Mon Sep 17 00:00:00 2001 From: YooSunyoung Date: Tue, 9 Jan 2024 15:13:46 +0100 Subject: [PATCH 04/10] Add instrument view in the document. --- docs/examples/workflow.ipynb | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/examples/workflow.ipynb b/docs/examples/workflow.ipynb index 1f8977b..1921a98 100644 --- a/docs/examples/workflow.ipynb +++ b/docs/examples/workflow.ipynb @@ -48,7 +48,9 @@ "source": [ "from typing import get_type_hints\n", "param_reprs = {key.__name__: value for key, value in params.items()}\n", - "prov_reprs = {get_type_hints(prov)['return'].__name__: prov.__name__ for prov in providers}\n", + "prov_reprs = {\n", + " get_type_hints(prov)['return'].__name__: prov.__name__ for prov in providers\n", + "}\n", "\n", "# Providers and parameters to be used for pipeline\n", "sc.DataGroup(**prov_reprs, **param_reprs)" @@ -99,9 +101,11 @@ "source": [ "## Instrument View\n", "\n", - "Pixel positions are not used for later steps, but it is included in the coordinates for instrument view.\n", + "Pixel positions are not used for later steps,\n", + "but it is included in the coordinates for instrument view.\n", "\n", - "All pixel positions are respect to the sample position, therefore the sample is at (0, 0, 0).\n", + "All pixel positions are respect to the sample position,\n", + "therefore the sample is at (0, 0, 0).\n", "\n", "**It might be very slow or not work in the ``VS Code`` jupyter notebook editor.**" ] @@ -114,7 +118,6 @@ "source": [ "import scippneutron as scn\n", "\n", - "\n", "unnecessary_coords = list(coord for coord in da.coords if coord != 'position')\n", "instrument_view_da = da.drop_coords(unnecessary_coords).flatten(['panel', 'id'], 'id').hist()\n", "view = scn.instrument_view(instrument_view_da)\n", From b720685d947e2c9351d4a1c13cafbfd780cc465d Mon Sep 17 00:00:00 2001 From: YooSunyoung Date: Tue, 9 Jan 2024 15:21:56 +0100 Subject: [PATCH 05/10] Remove instrument view from the cell temporarily. --- docs/examples/workflow.ipynb | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/docs/examples/workflow.ipynb b/docs/examples/workflow.ipynb index 1921a98..ac93f67 100644 --- a/docs/examples/workflow.ipynb +++ b/docs/examples/workflow.ipynb @@ -107,22 +107,19 @@ "All pixel positions are respect to the sample position,\n", "therefore the sample is at (0, 0, 0).\n", "\n", - "**It might be very slow or not work in the ``VS Code`` jupyter notebook editor.**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + "You can plot the instrument view like below.\n", + "\n", + "```python\n", "import scippneutron as scn\n", "\n", "unnecessary_coords = list(coord for coord in da.coords if coord != 'position')\n", "instrument_view_da = da.drop_coords(unnecessary_coords).flatten(['panel', 'id'], 'id').hist()\n", "view = scn.instrument_view(instrument_view_da)\n", "view.children[0].toolbar.cameraz()\n", - "view" + "view\n", + "```\n", + "\n", + "**It might be very slow or not work in the ``VS Code`` jupyter notebook editor.**" ] } ], From 22bf7c5429097bbedaf8741d15d7c0112948ea68 Mon Sep 17 00:00:00 2001 From: YooSunyoung Date: Tue, 9 Jan 2024 15:36:49 +0100 Subject: [PATCH 06/10] Add expected coordinate values. --- tests/loader_test.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/loader_test.py b/tests/loader_test.py index cef49f2..bcf5dc5 100644 --- a/tests/loader_test.py +++ b/tests/loader_test.py @@ -30,8 +30,23 @@ def test_file_reader_mcstas() -> None: assert isinstance(da, sc.DataArray) assert da.shape == (3, 1280 * 1280) + assert sc.identical( + da.coords['sample_position'], sc.vector(value=[0, 0, 0], unit='m') + ) assert da.bins.size().sum().value == data_length assert sc.identical(da.data.max(), expected_weight_max) + # Expected coordinate values are provided by the IDS + # based on the simulation settings of the sample file. + assert sc.identical( + da.coords['fast_axis'], + sc.vectors( + dims=['panel'], values=[[0, 0, -0.01], [-0.01, 0, -1], [0.01, 0, 1]] + ), + ) + assert sc.identical( + da.coords['slow_axis'], + sc.vectors(dims=['panel'], values=[[0, 1, 0], [0, 1, 0], [0, 1, 0]]), + ) @pytest.fixture From 0fb3e65fb248aa33024f0d16e661bed2d354bdef Mon Sep 17 00:00:00 2001 From: YooSunyoung Date: Tue, 9 Jan 2024 15:44:36 +0100 Subject: [PATCH 07/10] Add expected coordinate values. --- tests/loader_test.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/loader_test.py b/tests/loader_test.py index bcf5dc5..84be0de 100644 --- a/tests/loader_test.py +++ b/tests/loader_test.py @@ -40,12 +40,15 @@ def test_file_reader_mcstas() -> None: assert sc.identical( da.coords['fast_axis'], sc.vectors( - dims=['panel'], values=[[0, 0, -0.01], [-0.01, 0, -1], [0.01, 0, 1]] + dims=['panel'], + values=[(1.0, 0.0, -0.01), (-0.01, 0.0, -1.0), (0.01, 0.0, 1.0)], ), ) assert sc.identical( da.coords['slow_axis'], - sc.vectors(dims=['panel'], values=[[0, 1, 0], [0, 1, 0], [0, 1, 0]]), + sc.vectors( + dims=['panel'], values=[[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] + ), ) From 5854fe2bfe89ae10ba65107d8174707272a42fd8 Mon Sep 17 00:00:00 2001 From: Sunyoung Yoo Date: Wed, 10 Jan 2024 09:05:43 +0100 Subject: [PATCH 08/10] Fix typos and update names. Co-authored-by: Simon Heybrock <12912489+SimonHeybrock@users.noreply.github.com> --- docs/examples/workflow.ipynb | 2 +- src/ess/nmx/mcstas_xml.py | 4 ++-- src/ess/nmx/rotation.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/examples/workflow.ipynb b/docs/examples/workflow.ipynb index ac93f67..709bd1f 100644 --- a/docs/examples/workflow.ipynb +++ b/docs/examples/workflow.ipynb @@ -104,7 +104,7 @@ "Pixel positions are not used for later steps,\n", "but it is included in the coordinates for instrument view.\n", "\n", - "All pixel positions are respect to the sample position,\n", + "All pixel positions are relative to the sample position,\n", "therefore the sample is at (0, 0, 0).\n", "\n", "You can plot the instrument view like below.\n", diff --git a/src/ess/nmx/mcstas_xml.py b/src/ess/nmx/mcstas_xml.py index 1464f6f..2a266c6 100644 --- a/src/ess/nmx/mcstas_xml.py +++ b/src/ess/nmx/mcstas_xml.py @@ -264,7 +264,7 @@ def from_xml( def position_from_sample(self, other: sc.Variable) -> sc.Variable: """Position of ``other`` relative to the sample. - All positions and distance are stored respect to the sample position. + All positions and distance are stored relative to the sample position. Parameters ---------- @@ -374,7 +374,7 @@ def to_coords(self) -> dict[str, sc.Variable]: detector_dim = 'panel' return { - 'pixel_ids': _construct_pixel_ids(self.detectors), + 'pixel_id': _construct_pixel_ids(self.detectors), 'fast_axis': sc.concat(fast_axes, detector_dim), 'slow_axis': sc.concat(slow_axes, detector_dim), 'origin_position': sc.concat(origins, detector_dim), diff --git a/src/ess/nmx/rotation.py b/src/ess/nmx/rotation.py index 98beab8..c3dd04e 100644 --- a/src/ess/nmx/rotation.py +++ b/src/ess/nmx/rotation.py @@ -25,7 +25,7 @@ def axis_angle_to_quaternion( Returns ------- : - A list of (normalized) queternions, [x, y, z, w]. + A list of (normalized) quaternions, [x, y, z, w]. Notes ----- @@ -57,7 +57,7 @@ def quaternion_to_matrix(x: float, y: float, z: float, w: float) -> sc.Variable: Returns ------- : - A 3X3 rotation matrix (3 vectors). + A 3x3 rotation matrix. """ from scipy.spatial.transform import Rotation From 74efc2e85a65c33a6046d5880985e5b47263fe7f Mon Sep 17 00:00:00 2001 From: YooSunyoung Date: Wed, 10 Jan 2024 09:37:11 +0100 Subject: [PATCH 09/10] Remove rounding and make rotation function arguments keyword-only. --- src/ess/nmx/mcstas_loader.py | 2 +- src/ess/nmx/mcstas_xml.py | 46 +++++++++++++++++++++++------------- src/ess/nmx/rotation.py | 4 ++-- tests/loader_test.py | 10 ++++---- 4 files changed, 39 insertions(+), 23 deletions(-) diff --git a/src/ess/nmx/mcstas_loader.py b/src/ess/nmx/mcstas_loader.py index 92dbf55..c56ed28 100644 --- a/src/ess/nmx/mcstas_loader.py +++ b/src/ess/nmx/mcstas_loader.py @@ -73,7 +73,7 @@ def load_mcstas_nexus( loaded = sc.DataArray(data=weights, coords={'t': t_list, 'id': id_list}) coords = geometry.to_coords() - grouped = loaded.group(coords.pop('pixel_ids')) + grouped = loaded.group(coords.pop('pixel_id')) da = grouped.fold(dim='id', sizes={'panel': len(geometry.detectors), 'id': -1}) da.coords.update(coords) diff --git a/src/ess/nmx/mcstas_xml.py b/src/ess/nmx/mcstas_xml.py index 2a266c6..f73a6fb 100644 --- a/src/ess/nmx/mcstas_xml.py +++ b/src/ess/nmx/mcstas_xml.py @@ -6,7 +6,6 @@ from types import MappingProxyType from typing import Iterable, Optional, Protocol, Tuple, TypeVar, Union -import numpy as np import scipp as sc T = TypeVar('T') @@ -122,11 +121,14 @@ def _rotation_matrix_from_location( """Retrieve rotation matrix from location.""" from .rotation import axis_angle_to_quaternion, quaternion_to_matrix - theta, x, y, z = find_attributes( - location, 'rot', 'axis-x', 'axis-y', 'axis-z' - ).values() - q = axis_angle_to_quaternion(x, y, z, sc.scalar(-theta, unit=angle_unit)) - return quaternion_to_matrix(*q) + attribs = find_attributes(location, 'axis-x', 'axis-y', 'axis-z', 'rot') + x, y, z, w = axis_angle_to_quaternion( + x=attribs['axis-x'], + y=attribs['axis-y'], + z=attribs['axis-z'], + theta=sc.scalar(-attribs['rot'], unit=angle_unit), + ) + return quaternion_to_matrix(x=x, y=y, z=z, w=w) @dataclass @@ -155,12 +157,16 @@ class DetectorDesc: @classmethod def from_xml( - cls, component: _XML, type_desc: _XML, simulation_settings: SimulationSettings + cls, + *, + component: _XML, + type_desc: _XML, + simulation_settings: SimulationSettings, ) -> 'DetectorDesc': """Create detector description from xml component and type.""" def _rotate_axis(matrix: sc.Variable, axis: sc.Variable) -> sc.Variable: - return sc.vector(np.round((matrix * axis).values, 2)) + return matrix * axis location = select_by_tag(component, 'location') rotation_matrix = _rotation_matrix_from_location( @@ -226,7 +232,11 @@ def _find_type_desc(det: _XML) -> _XML: ) detector_components = [ - DetectorDesc.from_xml(det, _find_type_desc(det), simulation_settings) + DetectorDesc.from_xml( + component=det, + type_desc=_find_type_desc(det), + simulation_settings=simulation_settings, + ) for det in filter_by_type_prefix(filter_by_tag(tree, 'component'), 'MonNDtype') ] @@ -246,7 +256,7 @@ class SampleDesc: @classmethod def from_xml( - cls, tree: _XML, simulation_settings: SimulationSettings + cls, *, tree: _XML, simulation_settings: SimulationSettings ) -> 'SampleDesc': """Create sample description from xml component.""" source_xml = select_by_type_prefix(tree, 'sampleMantid-type') @@ -288,7 +298,7 @@ class SourceDesc: @classmethod def from_xml( - cls, tree: _XML, simulation_settings: SimulationSettings + cls, *, tree: _XML, simulation_settings: SimulationSettings ) -> 'SourceDesc': """Create source description from xml component.""" source_xml = select_by_type_prefix(tree, 'sourceMantid-type') @@ -318,10 +328,10 @@ def _pixel_positions( Position of each pixel is relative to the position_offset. """ pixel_idx = sc.arange('id', detector.total_pixels) - n_row = sc.scalar(detector.num_fast_pixels_per_row) + n_col = sc.scalar(detector.num_fast_pixels_per_row) - pixel_n_slow = pixel_idx // n_row - pixel_n_fast = pixel_idx % n_row + pixel_n_slow = pixel_idx // n_col + pixel_n_fast = pixel_idx % n_col fast_axis_steps = detector.fast_axis * detector.fast_step slow_axis_steps = detector.slow_axis * detector.slow_step @@ -360,8 +370,12 @@ def from_xml(cls, tree: _XML) -> 'McStasInstrument': return cls( simulation_settings=simulation_settings, detectors=_collect_detector_descriptions(tree), - source=SourceDesc.from_xml(tree, simulation_settings), - sample=SampleDesc.from_xml(tree, simulation_settings), + source=SourceDesc.from_xml( + tree=tree, simulation_settings=simulation_settings + ), + sample=SampleDesc.from_xml( + tree=tree, simulation_settings=simulation_settings + ), ) def to_coords(self) -> dict[str, sc.Variable]: diff --git a/src/ess/nmx/rotation.py b/src/ess/nmx/rotation.py index c3dd04e..4ec91c8 100644 --- a/src/ess/nmx/rotation.py +++ b/src/ess/nmx/rotation.py @@ -7,7 +7,7 @@ def axis_angle_to_quaternion( - x: float, y: float, z: float, theta: sc.Variable + *, x: float, y: float, z: float, theta: sc.Variable ) -> NDArray: """Convert axis-angle to queternions, [x, y, z, w]. @@ -40,7 +40,7 @@ def axis_angle_to_quaternion( return q / np.linalg.norm(q) -def quaternion_to_matrix(x: float, y: float, z: float, w: float) -> sc.Variable: +def quaternion_to_matrix(*, x: float, y: float, z: float, w: float) -> sc.Variable: """Convert quaternion to rotation matrix. Parameters diff --git a/tests/loader_test.py b/tests/loader_test.py index 84be0de..2b969c8 100644 --- a/tests/loader_test.py +++ b/tests/loader_test.py @@ -10,6 +10,7 @@ def test_file_reader_mcstas() -> None: + import numpy as np import scippnexus as snx from ess.nmx.mcstas_loader import ( @@ -37,12 +38,13 @@ def test_file_reader_mcstas() -> None: assert sc.identical(da.data.max(), expected_weight_max) # Expected coordinate values are provided by the IDS # based on the simulation settings of the sample file. - assert sc.identical( - da.coords['fast_axis'], - sc.vectors( + # The expected values are rounded to 2 decimal places. + assert np.all( + np.round(da.coords['fast_axis'].values, 2) + == sc.vectors( dims=['panel'], values=[(1.0, 0.0, -0.01), (-0.01, 0.0, -1.0), (0.01, 0.0, 1.0)], - ), + ).values, ) assert sc.identical( da.coords['slow_axis'], From f4495e2bd5a9e085a094e09613b365eb7fa4ed7f Mon Sep 17 00:00:00 2001 From: YooSunyoung Date: Wed, 10 Jan 2024 11:47:38 +0100 Subject: [PATCH 10/10] Remove unnecessary helper. --- src/ess/nmx/mcstas_xml.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/ess/nmx/mcstas_xml.py b/src/ess/nmx/mcstas_xml.py index f73a6fb..d981a30 100644 --- a/src/ess/nmx/mcstas_xml.py +++ b/src/ess/nmx/mcstas_xml.py @@ -165,9 +165,6 @@ def from_xml( ) -> 'DetectorDesc': """Create detector description from xml component and type.""" - def _rotate_axis(matrix: sc.Variable, axis: sc.Variable) -> sc.Variable: - return matrix * axis - location = select_by_tag(component, 'location') rotation_matrix = _rotation_matrix_from_location( location, simulation_settings.angle_unit @@ -191,12 +188,8 @@ def _rotate_axis(matrix: sc.Variable, axis: sc.Variable) -> sc.Variable: start_y=float(type_desc.attrib['ystart']), position=_position_from_location(location, simulation_settings.length_unit), rotation_matrix=rotation_matrix, - fast_axis=_rotate_axis( - rotation_matrix, _AXISNAME_TO_UNIT_VECTOR[fast_axis_name] - ), - slow_axis=_rotate_axis( - rotation_matrix, _AXISNAME_TO_UNIT_VECTOR[slow_axis_name] - ), + fast_axis=rotation_matrix * _AXISNAME_TO_UNIT_VECTOR[fast_axis_name], + slow_axis=rotation_matrix * _AXISNAME_TO_UNIT_VECTOR[slow_axis_name], ) @property