Skip to content

Commit

Permalink
Remove rounding and make rotation function arguments keyword-only.
Browse files Browse the repository at this point in the history
  • Loading branch information
YooSunYoung committed Jan 10, 2024
1 parent 10b73a5 commit ed546e8
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/ess/nmx/mcstas_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def load_mcstas_nexus(

loaded = sc.DataArray(data=weights, coords={'t': t_list, 'id': id_list})
coords = geometry.to_coords()
grouped = loaded.group(coords.pop('pixel_ids'))
grouped = loaded.group(coords.pop('pixel_id'))
da = grouped.fold(dim='id', sizes={'panel': len(geometry.detectors), 'id': -1})
da.coords.update(coords)

Expand Down
46 changes: 30 additions & 16 deletions src/ess/nmx/mcstas_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from types import MappingProxyType
from typing import Iterable, Optional, Protocol, Tuple, TypeVar, Union

import numpy as np
import scipp as sc

T = TypeVar('T')
Expand Down Expand Up @@ -122,11 +121,14 @@ def _rotation_matrix_from_location(
"""Retrieve rotation matrix from location."""
from .rotation import axis_angle_to_quaternion, quaternion_to_matrix

theta, x, y, z = find_attributes(
location, 'rot', 'axis-x', 'axis-y', 'axis-z'
).values()
q = axis_angle_to_quaternion(x, y, z, sc.scalar(-theta, unit=angle_unit))
return quaternion_to_matrix(*q)
attribs = find_attributes(location, 'axis-x', 'axis-y', 'axis-z', 'rot')
x, y, z, w = axis_angle_to_quaternion(
x=attribs['axis-x'],
y=attribs['axis-y'],
z=attribs['axis-z'],
theta=sc.scalar(-attribs['rot'], unit=angle_unit),
)
return quaternion_to_matrix(x=x, y=y, z=z, w=w)


@dataclass
Expand Down Expand Up @@ -155,12 +157,16 @@ class DetectorDesc:

@classmethod
def from_xml(
cls, component: _XML, type_desc: _XML, simulation_settings: SimulationSettings
cls,
*,
component: _XML,
type_desc: _XML,
simulation_settings: SimulationSettings,
) -> 'DetectorDesc':
"""Create detector description from xml component and type."""

def _rotate_axis(matrix: sc.Variable, axis: sc.Variable) -> sc.Variable:
return sc.vector(np.round((matrix * axis).values, 2))
return matrix * axis

location = select_by_tag(component, 'location')
rotation_matrix = _rotation_matrix_from_location(
Expand Down Expand Up @@ -226,7 +232,11 @@ def _find_type_desc(det: _XML) -> _XML:
)

detector_components = [
DetectorDesc.from_xml(det, _find_type_desc(det), simulation_settings)
DetectorDesc.from_xml(
component=det,
type_desc=_find_type_desc(det),
simulation_settings=simulation_settings,
)
for det in filter_by_type_prefix(filter_by_tag(tree, 'component'), 'MonNDtype')
]

Expand All @@ -246,7 +256,7 @@ class SampleDesc:

@classmethod
def from_xml(
cls, tree: _XML, simulation_settings: SimulationSettings
cls, *, tree: _XML, simulation_settings: SimulationSettings
) -> 'SampleDesc':
"""Create sample description from xml component."""
source_xml = select_by_type_prefix(tree, 'sampleMantid-type')
Expand Down Expand Up @@ -288,7 +298,7 @@ class SourceDesc:

@classmethod
def from_xml(
cls, tree: _XML, simulation_settings: SimulationSettings
cls, *, tree: _XML, simulation_settings: SimulationSettings
) -> 'SourceDesc':
"""Create source description from xml component."""
source_xml = select_by_type_prefix(tree, 'sourceMantid-type')
Expand Down Expand Up @@ -318,10 +328,10 @@ def _pixel_positions(
Position of each pixel is relative to the position_offset.
"""
pixel_idx = sc.arange('id', detector.total_pixels)
n_row = sc.scalar(detector.num_fast_pixels_per_row)
n_col = sc.scalar(detector.num_fast_pixels_per_row)

pixel_n_slow = pixel_idx // n_row
pixel_n_fast = pixel_idx % n_row
pixel_n_slow = pixel_idx // n_col
pixel_n_fast = pixel_idx % n_col

fast_axis_steps = detector.fast_axis * detector.fast_step
slow_axis_steps = detector.slow_axis * detector.slow_step
Expand Down Expand Up @@ -360,8 +370,12 @@ def from_xml(cls, tree: _XML) -> 'McStasInstrument':
return cls(
simulation_settings=simulation_settings,
detectors=_collect_detector_descriptions(tree),
source=SourceDesc.from_xml(tree, simulation_settings),
sample=SampleDesc.from_xml(tree, simulation_settings),
source=SourceDesc.from_xml(
tree=tree, simulation_settings=simulation_settings
),
sample=SampleDesc.from_xml(
tree=tree, simulation_settings=simulation_settings
),
)

def to_coords(self) -> dict[str, sc.Variable]:
Expand Down
4 changes: 2 additions & 2 deletions src/ess/nmx/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def axis_angle_to_quaternion(
x: float, y: float, z: float, theta: sc.Variable
*, x: float, y: float, z: float, theta: sc.Variable
) -> NDArray:
"""Convert axis-angle to queternions, [x, y, z, w].
Expand Down Expand Up @@ -40,7 +40,7 @@ def axis_angle_to_quaternion(
return q / np.linalg.norm(q)


def quaternion_to_matrix(x: float, y: float, z: float, w: float) -> sc.Variable:
def quaternion_to_matrix(*, x: float, y: float, z: float, w: float) -> sc.Variable:
"""Convert quaternion to rotation matrix.
Parameters
Expand Down
10 changes: 6 additions & 4 deletions tests/loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@


def test_file_reader_mcstas() -> None:
import numpy as np
import scippnexus as snx

from ess.nmx.mcstas_loader import (
Expand Down Expand Up @@ -37,12 +38,13 @@ def test_file_reader_mcstas() -> None:
assert sc.identical(da.data.max(), expected_weight_max)
# Expected coordinate values are provided by the IDS
# based on the simulation settings of the sample file.
assert sc.identical(
da.coords['fast_axis'],
sc.vectors(
# The expected values are rounded to 2 decimal places.
assert np.all(
np.round(da.coords['fast_axis'].values, 2)
== sc.vectors(
dims=['panel'],
values=[(1.0, 0.0, -0.01), (-0.01, 0.0, -1.0), (0.01, 0.0, 1.0)],
),
).values,
)
assert sc.identical(
da.coords['slow_axis'],
Expand Down

0 comments on commit ed546e8

Please sign in to comment.