-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #13 from scipp/grow-events-script
feat: add grow nexus script
- Loading branch information
Showing
11 changed files
with
218 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ | |
# Do not make an environment from this file, use test.txt instead! | ||
|
||
pytest | ||
numpy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import argparse | ||
import shutil | ||
from typing import Optional | ||
|
||
import h5py | ||
|
||
|
||
def _scale_group(event_data: h5py.Group, scale: int): | ||
if not all( | ||
required_field in event_data | ||
for required_field in ('event_index', 'event_time_offset', 'event_id') | ||
): | ||
return | ||
event_index = (event_data['event_index'][:] * scale).astype('uint') | ||
event_data['event_index'][:] = event_index | ||
|
||
size = event_data['event_id'].size | ||
event_data['event_id'].resize(event_index[-1], axis=0) | ||
event_data['event_time_offset'].resize(event_index[-1], axis=0) | ||
|
||
for s in range(1, scale): | ||
event_data['event_id'][s * size : (s + 1) * size] = event_data['event_id'][ | ||
:size | ||
] | ||
event_data['event_time_offset'][s * size : (s + 1) * size] = event_data[ | ||
'event_time_offset' | ||
][:size] | ||
|
||
|
||
def _grow_nexus_file_impl(file: h5py.File, detector_scale: int, monitor_scale: int): | ||
for group in file.values(): | ||
if group.attrs.get('NX_class', '') == 'NXentry': | ||
entry = group | ||
break | ||
for group in entry.values(): | ||
if group.attrs.get('NX_class', '') == 'NXinstrument': | ||
instrument = group | ||
break | ||
for group in instrument.values(): | ||
if (nx_class := group.attrs.get('NX_class', '')) in ( | ||
'NXdetector', | ||
'NXmonitor', | ||
): | ||
for subgroup in group.values(): | ||
if subgroup.attrs.get('NX_class', '') == 'NXevent_data': | ||
_scale_group( | ||
subgroup, | ||
scale=detector_scale | ||
if nx_class == 'NXdetector' | ||
else monitor_scale, | ||
) | ||
|
||
|
||
def grow_nexus_file( | ||
*, filename: str, detector_scale: int, monitor_scale: Optional[int] | ||
): | ||
with h5py.File(filename, 'a') as f: | ||
_grow_nexus_file_impl( | ||
f, | ||
detector_scale, | ||
monitor_scale if monitor_scale is not None else detector_scale, | ||
) | ||
|
||
|
||
def integer_greater_than_one(x): | ||
x = int(x) | ||
if x < 1: | ||
raise argparse.ArgumentTypeError('Must be larger than or equal to 1') | ||
return x | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"-f", | ||
"--file", | ||
type=str, | ||
help=( | ||
'Input file name. The events in the input file will be ' | ||
'repeated `scale` times and stored in the output file.' | ||
), | ||
required=True, | ||
) | ||
parser.add_argument( | ||
"-o", | ||
"--output", | ||
type=str, | ||
help='Output file name where the resulting nexus file will be written.', | ||
required=True, | ||
) | ||
parser.add_argument( | ||
"-s", | ||
"--detector-scale", | ||
type=integer_greater_than_one, | ||
help=('Scale factor to multiply the number of detector events by.'), | ||
required=True, | ||
) | ||
parser.add_argument( | ||
"-m", | ||
"--monitor-scale", | ||
type=integer_greater_than_one, | ||
default=None, | ||
help=( | ||
'Scale factor to multiply the number of monitor events by. ' | ||
'If not given, the detector scale will be used' | ||
), | ||
) | ||
args = parser.parse_args() | ||
if args.file != args.output: | ||
shutil.copy2(args.file, args.output) | ||
grow_nexus_file( | ||
filename=args.output, | ||
detector_scale=args.detector_scale, | ||
monitor_scale=args.monitor_scale, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import os | ||
import tempfile | ||
|
||
import h5py | ||
import numpy as np | ||
import pytest | ||
|
||
from ess.reduce.scripts.grow_nexus import grow_nexus_file | ||
|
||
|
||
@pytest.fixture | ||
def nexus_file(): | ||
with tempfile.TemporaryDirectory() as tmp: | ||
path = os.path.join(tmp, 'test.nxs') | ||
with h5py.File(path, 'a') as hf: | ||
entry = hf.create_group('entry') | ||
entry.attrs['NX_class'] = 'NXentry' | ||
|
||
instrument = entry.create_group('instrument') | ||
instrument.attrs['NX_class'] = 'NXinstrument' | ||
|
||
for group, nxclass in ( | ||
('detector', 'NXdetector'), | ||
('monitor', 'NXmonitor'), | ||
): | ||
detector = instrument.create_group(group) | ||
detector.attrs['NX_class'] = nxclass | ||
|
||
event_data = detector.create_group('event_data') | ||
event_data.attrs['NX_class'] = 'NXevent_data' | ||
|
||
event_data.create_dataset( | ||
'event_index', | ||
data=np.array([2, 4, 6]), | ||
maxshape=(None,), | ||
chunks=True, | ||
) | ||
event_data.create_dataset( | ||
'event_time_zero', | ||
data=np.array([0, 1, 2]), | ||
maxshape=(None,), | ||
chunks=True, | ||
) | ||
event_data.create_dataset( | ||
'event_id', | ||
data=np.array([0, 1, 2, 0, 1, 2]), | ||
maxshape=(None,), | ||
chunks=True, | ||
) | ||
event_data.create_dataset( | ||
'event_time_offset', | ||
data=np.array([1, 2, 1, 2, 1, 2]), | ||
maxshape=(None,), | ||
chunks=True, | ||
) | ||
|
||
yield path | ||
|
||
|
||
@pytest.mark.parametrize('monitor_scale', (1, 2, None)) | ||
@pytest.mark.parametrize('detector_scale', (1, 2)) | ||
def test_grow_nexus(nexus_file, detector_scale, monitor_scale): | ||
grow_nexus_file( | ||
filename=nexus_file, detector_scale=detector_scale, monitor_scale=monitor_scale | ||
) | ||
|
||
monitor_scale = monitor_scale if monitor_scale is not None else detector_scale | ||
|
||
with h5py.File(nexus_file, 'r') as f: | ||
for detector, scale in zip( | ||
('detector', 'monitor'), (detector_scale, monitor_scale) | ||
): | ||
np.testing.assert_equal( | ||
[scale * i for i in [2, 4, 6]], | ||
f[f'entry/instrument/{detector}/event_data/event_index'][()], | ||
) | ||
np.testing.assert_equal( | ||
scale * [0, 1, 2, 0, 1, 2], | ||
f[f'entry/instrument/{detector}/event_data/event_id'][()], | ||
) | ||
np.testing.assert_equal( | ||
scale * [1, 2, 1, 2, 1, 2], | ||
f[f'entry/instrument/{detector}/event_data/event_time_offset'][()], | ||
) |