Skip to content

Commit

Permalink
Replace h5 data group/dataset creation functions with helper functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
YooSunYoung committed Jan 11, 2024
1 parent ef0cf2f commit efaf78f
Showing 1 changed file with 85 additions and 47 deletions.
132 changes: 85 additions & 47 deletions src/ess/nmx/reduction.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
import pathlib
from typing import NewType, Union
from typing import NewType, Optional, Union

import h5py
import scipp as sc
Expand Down Expand Up @@ -73,6 +73,49 @@ def counts(self) -> sc.DataArray:
"""Binned time of arrival data from flattened event data."""
return self['counts']

def _create_dataset_from_var(
self,
*,
root_entry: h5py.Group,
var: sc.Variable,
name: str,
long_name: Optional[str] = None,
compression: Optional[str] = None,
compression_opts: Optional[int] = None,
) -> h5py.Dataset:
compression_options = dict()
if compression is not None:
compression_options["compression"] = compression
if compression_opts is not None:
compression_options["compression_opts"] = compression_opts

dataset = root_entry.create_dataset(
name,
data=var.values,
**compression_options,
)
dataset.attrs["units"] = str(var.unit)
if long_name is not None:
dataset.attrs["long_name"] = long_name
return dataset

def _create_compressed_dataset(
self,
*,
root_entry: h5py.Group,
name: str,
var: sc.Variable,
long_name: Optional[str] = None,
) -> h5py.Dataset:
return self._create_dataset_from_var(
root_entry=root_entry,
var=var,
name=name,
long_name=long_name,
compression="gzip",
compression_opts=4,
)

def _create_root_data_entry(self, file_obj: h5py.File) -> h5py.Group:
nx_entry = file_obj.create_group("NMX_data")
nx_entry.attrs["NX_class"] = "NXentry"
Expand All @@ -85,70 +128,62 @@ def _create_root_data_entry(self, file_obj: h5py.File) -> h5py.Group:
def _create_sample_group(self, nx_entry: h5py.Group) -> h5py.Group:
nx_sample = nx_entry.create_group("NXsample")
nx_sample["name"] = self.sample_name.value
crystal_rotation = nx_sample.create_dataset(
'crystal_rotation', data=self.crystal_rotation.values
# Crystal rotation
self._create_dataset_from_var(
root_entry=nx_sample,
var=self.crystal_rotation,
name='crystal_rotation',
long_name='crystal rotation in Phi (XYZ)',
)
crystal_rotation.attrs["units"] = str(self.crystal_rotation.unit)
crystal_rotation.attrs["long_name"] = 'crystal rotation in Phi (XYZ)'

return nx_sample

def _create_compressed_dataset(
self, nx_entry: h5py.Group, name: str, var: sc.Variable, *, long_name: str
) -> h5py.Dataset:
dataset = nx_entry.create_dataset(
name,
data=var.values,
compression="gzip",
compression_opts=4,
)
dataset.attrs["units"] = str(var.unit)
dataset.attrs["long_name"] = name
return dataset

def _create_instrument_group(self, nx_entry: h5py.Group) -> h5py.Group:
nx_instrument = nx_entry.create_group("NXinstrument")
nx_instrument.attrs["nr_detector"] = self.origin_position.sizes['panel']
nx_instrument.create_dataset("proton_charge", data=self.proton_charge)

nx_detector_1 = nx_instrument.create_group("detector_1")
counts = nx_detector_1.create_dataset(
"counts", data=[self.counts.values], compression="gzip", compression_opts=4
# Detector counts
self._create_compressed_dataset(
root_entry=nx_detector_1,
name="counts",
var=self.counts.fold(
'id', sizes={'panel': 1, 'id': self.counts.sizes['id']}
),
)
counts.attrs["units"] = "counts"
t_spectrum = nx_detector_1.create_dataset(
"t_bin",
data=self.counts.coords['t'].values,
compression="gzip",
compression_opts=4,
# Time of arrival bin edges
self._create_dataset_from_var(
root_entry=nx_detector_1,
var=self.counts.coords['t'],
name="t_bin",
long_name="t_bin TOF (ms)",
)
t_spectrum.attrs["units"] = "s"
t_spectrum.attrs["long_name"] = "t_bin TOF (ms)"
pixel_id = nx_detector_1.create_dataset(
"pixel_id",
data=self.counts.coords['id'].values,
compression="gzip",
compression_opts=4,
# Pixel IDs
self._create_compressed_dataset(
root_entry=nx_detector_1,
name="pixel_id",
var=self.counts.coords['id'],
long_name="pixel ID",
)
pixel_id.attrs["units"] = ""
pixel_id.attrs["long_name"] = "pixel ID"
return nx_instrument

def _create_detector_group(self, nx_entry: h5py.Group) -> h5py.Group:
nx_detector = nx_entry.create_group("NXdetector")
# Position of the first pixel (lowest ID) in the detector
detector_origins = nx_detector.create_dataset(
"origin",
data=self.origin_position.values,
compression="gzip",
compression_opts=4,
self._create_compressed_dataset(
root_entry=nx_detector,
name="origin",
var=self.origin_position,
)
detector_origins.attrs["units"] = "m"
# Fast axis, along where the pixel ID increases by 1
nx_detector.create_dataset("fast_axis", data=self.fast_axis.values)
self._create_dataset_from_var(
root_entry=nx_detector, var=self.fast_axis, name="fast_axis"
)
# Slow axis, along where the pixel ID increases
# by the number of pixels in the fast axis
nx_detector.create_dataset("slow_axis", data=self.slow_axis.values)
self._create_dataset_from_var(
root_entry=nx_detector, var=self.slow_axis, name="slow_axis"
)
return nx_detector

def _create_source_group(self, nx_entry: h5py.Group) -> h5py.Group:
Expand Down Expand Up @@ -187,9 +222,12 @@ def bin_time_of_arrival(
) -> NMXReducedData:
"""Bin time of arrival data into ``time_bin_step`` bins."""

counts: sc.DataArray = nmx_data.weights.flatten(dims=['panel', 'id'], to='id').hist(
t=time_bin_step
)
counts.unit = 'counts'

return NMXReducedData(
counts=nmx_data.weights.flatten(dims=['panel', 'id'], to='id').hist(
t=time_bin_step
),
counts=counts,
**{key: nmx_data[key] for key in nmx_data.keys() if key != 'weights'},
)

0 comments on commit efaf78f

Please sign in to comment.