diff --git a/src/ess/reduce/nexus/__init__.py b/src/ess/reduce/nexus/__init__.py index 514c1a88..07042f82 100644 --- a/src/ess/reduce/nexus/__init__.py +++ b/src/ess/reduce/nexus/__init__.py @@ -12,24 +12,18 @@ from . import types from ._nexus_loader import ( - extract_detector_data, - extract_monitor_data, - load_detector, load_event_data, group_event_data, - load_monitor, - load_sample, - load_source, + load_component, + compute_component_position, + extract_events_or_histogram, ) __all__ = [ 'types', - 'extract_detector_data', - 'extract_monitor_data', 'group_event_data', - 'load_detector', 'load_event_data', - 'load_monitor', - 'load_sample', - 'load_source', + 'load_component', + 'compute_component_position', + 'extract_events_or_histogram', ] diff --git a/src/ess/reduce/nexus/_nexus_loader.py b/src/ess/reduce/nexus/_nexus_loader.py index 2330434a..ea75abe7 100644 --- a/src/ess/reduce/nexus/_nexus_loader.py +++ b/src/ess/reduce/nexus/_nexus_loader.py @@ -13,21 +13,7 @@ import scippnexus as snx from ..logging import get_logger -from .types import ( - AnyNeXusMonitorName, - AnyRunAnyNeXusMonitor, - AnyRunFilename, - AnyRunNeXusDetector, - AnyRunNeXusSample, - AnyRunNeXusSource, - NeXusDetectorName, - NeXusEntryName, - NeXusGroup, - NeXusLocationSpec, - NeXusSourceName, - RawDetectorData, - RawMonitorData, -) +from .types import AnyRunFilename, NeXusEntryName, NeXusGroup, NeXusLocationSpec class NoNewDefinitionsType: ... @@ -36,206 +22,6 @@ class NoNewDefinitionsType: ... NoNewDefinitions = NoNewDefinitionsType() -def load_detector( - file_path: AnyRunFilename, - selection=(), - *, - detector_name: NeXusDetectorName, - entry_name: NeXusEntryName | None = None, - definitions: Mapping | None | NoNewDefinitionsType = NoNewDefinitions, -) -> AnyRunNeXusDetector: - """Load a single detector (bank) from a NeXus file. - - The detector positions are computed automatically from NeXus transformations, - and the combined transformation is stored under the name 'transform'. - - Parameters - ---------- - file_path: - Indicates where to load data from. - One of: - - - Path to a NeXus file on disk. - - File handle or buffer for reading binary data. - - A ScippNexus group of the root of a NeXus file. - detector_name: - Name of the detector (bank) to load. - Must be a group in an instrument group in the entry (see below). - entry_name: - Name of the entry that contains the detector. - If ``None``, the entry will be located based - on its NeXus class, but there cannot be more than 1. - definitions: - Definitions used by scippnexus loader, see :py:`scippnexus.File` - for documentation. - - Returns - ------- - : - A data group containing the detector events or histogram - and any auxiliary data stored in the same NeXus group. - """ - return AnyRunNeXusDetector( - load_component( - NeXusLocationSpec( - filename=file_path, - component_name=detector_name, - entry_name=entry_name, - selection=selection, - ), - nx_class=snx.NXdetector, - definitions=definitions, - ) - ) - - -def load_monitor( - file_path: AnyRunFilename, - selection=(), - *, - monitor_name: AnyNeXusMonitorName, - entry_name: NeXusEntryName | None = None, - definitions: Mapping | None | NoNewDefinitionsType = NoNewDefinitions, -) -> AnyRunAnyNeXusMonitor: - """Load a single monitor from a NeXus file. - - The monitor position is computed automatically from NeXus transformations, - and the combined transformation is stored under the name 'transform'. - - Parameters - ---------- - file_path: - Indicates where to load data from. - One of: - - - Path to a NeXus file on disk. - - File handle or buffer for reading binary data. - - A ScippNexus group of the root of a NeXus file. - monitor_name: - Name of the monitor to load. - Must be a group in an instrument group in the entry (see below). - entry_name: - Name of the entry that contains the monitor. - If ``None``, the entry will be located based - on its NeXus class, but there cannot be more than 1. - definitions: - Definitions used by scippnexus loader, see :py:`scippnexus.File` - for documentation. - - Returns - ------- - : - A data group containing the monitor events or histogram - and any auxiliary data stored in the same NeXus group. - """ - return AnyRunAnyNeXusMonitor( - load_component( - NeXusLocationSpec( - filename=file_path, - component_name=monitor_name, - entry_name=entry_name, - selection=selection, - ), - nx_class=snx.NXmonitor, - definitions=definitions, - ) - ) - - -def load_source( - file_path: AnyRunFilename, - *, - source_name: NeXusSourceName | None = None, - entry_name: NeXusEntryName | None = None, - definitions: Mapping | None | NoNewDefinitionsType = NoNewDefinitions, -) -> AnyRunNeXusSource: - """Load a source from a NeXus file. - - The source position is computed automatically from NeXus transformations, - and the combined transformation is stored under the name 'transform'. - - Parameters - ---------- - file_path: - Indicates where to load data from. - One of: - - - Path to a NeXus file on disk. - - File handle or buffer for reading binary data. - - A ScippNexus group of the root of a NeXus file. - source_name: - Name of the source to load. - Must be a group in an instrument group in the entry (see below). - If ``None``, the source will be located based - on its NeXus class. - entry_name: - Name of the instrument that contains the source. - If ``None``, the entry will be located based - on its NeXus class, but there cannot be more than 1. - definitions: - Definitions used by scippnexus loader, see :py:`scippnexus.File` - for documentation. - - Returns - ------- - : - A data group containing all data stored in - the source NeXus group. - """ - return AnyRunNeXusSource( - load_component( - NeXusLocationSpec( - filename=file_path, component_name=source_name, entry_name=entry_name - ), - nx_class=snx.NXsource, - definitions=definitions, - ) - ) - - -def load_sample( - file_path: AnyRunFilename, - entry_name: NeXusEntryName | None = None, - definitions: Mapping | None | NoNewDefinitionsType = NoNewDefinitions, -) -> AnyRunNeXusSample: - """Load a sample from a NeXus file. - - The sample is located based on its NeXus class. - There can be only one sample in a NeXus file or - in the entry indicated by ``entry_name``. - - Parameters - ---------- - file_path: - Indicates where to load data from. - One of: - - - Path to a NeXus file on disk. - - File handle or buffer for reading binary data. - - A ScippNexus group of the root of a NeXus file. - entry_name: - Name of the instrument that contains the source. - If ``None``, the entry will be located based - on its NeXus class, but there cannot be more than 1. - definitions: - Definitions used by scippnexus loader, see :py:`scippnexus.File` - for documentation. - - Returns - ------- - : - A data group containing all data stored in - the sample NeXus group. - """ - return AnyRunNeXusSample( - load_component( - NeXusLocationSpec(filename=file_path, entry_name=entry_name), - nx_class=snx.NXsample, - definitions=definitions, - ) - ) - - def load_component( location: NeXusLocationSpec, *, @@ -254,24 +40,26 @@ def load_component( instrument = _unique_child_group(entry, snx.NXinstrument, None) component = _unique_child_group(instrument, nx_class, group_name) loaded = cast(sc.DataGroup, component[selection]) + loaded['nexus_component_name'] = component.name.split('/')[-1] + return compute_component_position(loaded) - transform_out_name = 'transform' - if transform_out_name in loaded: - raise RuntimeError( - f"Loaded data contains an item '{transform_out_name}' but we want to " - "store the combined NeXus transformations under that name." - ) - position_out_name = 'position' - if position_out_name in loaded: - raise RuntimeError( - f"Loaded data contains an item '{position_out_name}' but we want to " - "store the computed positions under that name." - ) - loaded = snx.compute_positions( - loaded, store_position=position_out_name, store_transform=transform_out_name + +def compute_component_position(dg: sc.DataGroup) -> sc.DataGroup: + transform_out_name = 'transform' + if transform_out_name in dg: + raise RuntimeError( + f"Loaded data contains an item '{transform_out_name}' but we want to " + "store the combined NeXus transformations under that name." ) - loaded['nexus_component_name'] = component.name.split('/')[-1] - return loaded + position_out_name = 'position' + if position_out_name in dg: + raise RuntimeError( + f"Loaded data contains an item '{position_out_name}' but we want to " + "store the computed positions under that name." + ) + return snx.compute_positions( + dg, store_position=position_out_name, store_transform=transform_out_name + ) def _open_nexus_file( @@ -311,65 +99,7 @@ def _unique_child_group( return next(iter(children.values())) # type: ignore[return-value] -def extract_detector_data(detector: AnyRunNeXusDetector) -> RawDetectorData: - """Get and return the events or histogram from a detector loaded from NeXus. - - This function looks for a data array in the detector group and returns that. - - Parameters - ---------- - detector: - A detector loaded from NeXus. - - Returns - ------- - : - A data array containing the events or histogram. - - Raises - ------ - ValueError - If there is more than one data array. - - See also - -------- - load_detector: - Load a detector from a NeXus file in a format compatible with - ``extract_detector_data``. - """ - return RawDetectorData(_extract_events_or_histogram(detector)) - - -def extract_monitor_data(monitor: AnyRunAnyNeXusMonitor) -> RawMonitorData: - """Get and return the events or histogram from a monitor loaded from NeXus. - - This function looks for a data array in the monitor group and returns that. - - Parameters - ---------- - monitor: - A monitor loaded from NeXus. - - Returns - ------- - : - A data array containing the events or histogram. - - Raises - ------ - ValueError - If there is more than one data array. - - See also - -------- - load_monitor: - Load a monitor from a NeXus file in a format compatible with - ``extract_monitor_data``. - """ - return RawMonitorData(_extract_events_or_histogram(monitor)) - - -def _extract_events_or_histogram(dg: sc.DataGroup) -> sc.DataArray: +def extract_events_or_histogram(dg: sc.DataGroup) -> sc.DataArray: event_data_arrays = { key: value for key, value in dg.items() diff --git a/src/ess/reduce/nexus/workflow.py b/src/ess/reduce/nexus/workflow.py index cd958c39..59d227c9 100644 --- a/src/ess/reduce/nexus/workflow.py +++ b/src/ess/reduce/nexus/workflow.py @@ -358,7 +358,7 @@ def get_calibrated_detector( bank_sizes: Dictionary of detector bank sizes. """ - da = nexus.extract_detector_data(detector) + da = nexus.extract_events_or_histogram(detector) if ( sizes := (bank_sizes or {}).get(detector.get('nexus_component_name')) ) is not None: @@ -421,7 +421,7 @@ def get_calibrated_monitor( Position of the neutron source. """ return AnyRunAnyCalibratedMonitor( - nexus.extract_monitor_data(monitor).assign_coords( + nexus.extract_events_or_histogram(monitor).assign_coords( position=monitor['position'] + offset.to(unit=monitor['position'].unit), source_position=source_position, ) diff --git a/tests/nexus/nexus_loader_test.py b/tests/nexus/nexus_loader_test.py index f392ccf4..ae7f34ee 100644 --- a/tests/nexus/nexus_loader_test.py +++ b/tests/nexus/nexus_loader_test.py @@ -12,6 +12,7 @@ import scippnexus as snx from ess.reduce import nexus +from ess.reduce.nexus.types import NeXusLocationSpec year_zero = sc.datetime('1970-01-01T00:00:00') @@ -219,12 +220,14 @@ def expected_sample() -> sc.DataGroup: ], ) def test_load_detector(nexus_file, expected_bank12, entry_name, selection): - detector = nexus.load_detector( - nexus_file, - **({'selection': selection} if selection is not None else {}), - detector_name=nexus.types.NeXusDetectorName('bank12'), + loc = NeXusLocationSpec( + filename=nexus_file, entry_name=entry_name, + component_name=nexus.types.NeXusDetectorName('bank12'), ) + if selection is not None: + loc.selection = selection + detector = nexus.load_component(loc, nx_class=snx.NXdetector) if selection: expected = expected_bank12.bins[selection] expected.coords.pop(selection[0]) @@ -250,11 +253,13 @@ def test_load_detector(nexus_file, expected_bank12, entry_name, selection): def test_load_and_group_event_data_consistent_with_load_via_detector( nexus_file, selection ): - detector = nexus.load_detector( - nexus_file, - selection=selection, - detector_name=nexus.types.NeXusDetectorName('bank12'), - )['bank12_events'] + loc = NeXusLocationSpec( + filename=nexus_file, + component_name=nexus.types.NeXusDetectorName('bank12'), + ) + if selection: + loc.selection = selection + detector = nexus.load_component(loc, nx_class=snx.NXdetector)['bank12_events'] events = nexus.load_event_data( nexus_file, selection=selection, @@ -268,10 +273,11 @@ def test_load_and_group_event_data_consistent_with_load_via_detector( def test_group_event_data_does_not_modify_input(nexus_file): - detector = nexus.load_detector( - nexus_file, - detector_name=nexus.types.NeXusDetectorName('bank12'), - )['bank12_events'] + loc = NeXusLocationSpec( + filename=nexus_file, + component_name=nexus.types.NeXusDetectorName('bank12'), + ) + detector = nexus.load_component(loc, nx_class=snx.NXdetector)['bank12_events'] events = nexus.load_event_data( nexus_file, component_name=nexus.types.NeXusDetectorName('bank12'), @@ -284,19 +290,15 @@ def test_group_event_data_does_not_modify_input(nexus_file): def test_load_detector_open_file_with_new_definitions_raises(nexus_file): + loc = NeXusLocationSpec( + filename=nexus_file, + component_name=nexus.types.NeXusDetectorName('bank12'), + ) if isinstance(nexus_file, snx.Group): with pytest.raises(ValueError, match="new definitions"): - nexus.load_detector( - nexus_file, - detector_name=nexus.types.NeXusDetectorName('bank12'), - definitions={}, - ) + nexus.load_component(loc, nx_class=snx.NXdetector, definitions={}) else: - nexus.load_detector( - nexus_file, - detector_name=nexus.types.NeXusDetectorName('bank12'), - definitions={}, - ) + nexus.load_component(loc, nx_class=snx.NXdetector, definitions={}) def test_load_detector_new_definitions_applied(nexus_file, expected_bank12): @@ -308,13 +310,15 @@ def detector(*args, **kwargs): new_definition_used = True return snx.base_definitions()['NXdetector'](*args, **kwargs) - nexus.load_detector( - nexus_file, - detector_name=nexus.types.NeXusDetectorName('bank12'), - definitions=dict( - snx.base_definitions(), - NXdetector=detector, - ), + loc = NeXusLocationSpec( + filename=nexus_file, + component_name=nexus.types.NeXusDetectorName('bank12'), + ) + + nexus.load_component( + loc, + nx_class=snx.NXdetector, + definitions=dict(snx.base_definitions(), NXdetector=detector), ) assert new_definition_used @@ -327,12 +331,13 @@ def test_load_detector_requires_entry_name_if_not_unique(nexus_file): with snx.File(nexus_file, 'r+') as f: f.create_class('entry', snx.NXentry) + loc = NeXusLocationSpec( + filename=nexus.types.FilePath(nexus_file), + component_name=nexus.types.NeXusDetectorName('bank12'), + entry_name=None, + ) with pytest.raises(ValueError, match="Expected exactly one"): - nexus.load_detector( - nexus.types.FilePath(nexus_file), - detector_name=nexus.types.NeXusDetectorName('bank12'), - entry_name=None, - ) + nexus.load_component(loc, nx_class=snx.NXdetector) def test_load_detector_select_entry_if_not_unique(nexus_file, expected_bank12): @@ -343,11 +348,12 @@ def test_load_detector_select_entry_if_not_unique(nexus_file, expected_bank12): with snx.File(nexus_file, 'r+') as f: f.create_class('entry', snx.NXentry) - detector = nexus.load_detector( - nexus.types.FilePath(nexus_file), - detector_name=nexus.types.NeXusDetectorName('bank12'), + loc = NeXusLocationSpec( + filename=nexus.types.FilePath(nexus_file), + component_name=nexus.types.NeXusDetectorName('bank12'), entry_name=nexus.types.NeXusEntryName('entry-001'), ) + detector = nexus.load_component(loc, nx_class=snx.NXdetector) sc.testing.assert_identical(detector['bank12_events'], expected_bank12) @@ -361,12 +367,14 @@ def test_load_detector_select_entry_if_not_unique(nexus_file, expected_bank12): ], ) def test_load_monitor(nexus_file, expected_monitor, entry_name, selection): - monitor = nexus.load_monitor( - nexus_file, - **({'selection': selection} if selection is not None else {}), - monitor_name=nexus.types.AnyNeXusMonitorName('monitor'), + loc = NeXusLocationSpec( + filename=nexus_file, entry_name=entry_name, + component_name=nexus.types.AnyNeXusMonitorName('monitor'), ) + if selection is not None: + loc.selection = selection + monitor = nexus.load_component(loc, nx_class=snx.NXmonitor) expected = expected_monitor[selection] if selection else expected_monitor sc.testing.assert_identical(monitor['data'], expected) @@ -374,11 +382,12 @@ def test_load_monitor(nexus_file, expected_monitor, entry_name, selection): @pytest.mark.parametrize('entry_name', [None, nexus.types.NeXusEntryName('entry-001')]) @pytest.mark.parametrize('source_name', [None, nexus.types.NeXusSourceName('source')]) def test_load_source(nexus_file, expected_source, entry_name, source_name): - source = nexus.load_source( - nexus_file, + loc = NeXusLocationSpec( + filename=nexus_file, entry_name=entry_name, - source_name=source_name, + component_name=source_name, ) + source = nexus.load_component(loc, nx_class=snx.NXsource) # NeXus details that we don't need to test as long as the positions are ok: del source['depends_on'] del source['transformations'] @@ -388,8 +397,8 @@ def test_load_source(nexus_file, expected_source, entry_name, source_name): @pytest.mark.parametrize( ('loader', 'cls', 'name'), [ - (nexus.load_source, snx.NXsource, 'NXsource'), - (nexus.load_sample, snx.NXsample, 'NXsample'), + (nexus.load_component, snx.NXsource, 'NXsource'), + (nexus.load_component, snx.NXsample, 'NXsample'), ], ) def test_load_new_definitions_applied(nexus_file, loader, cls, name): @@ -401,19 +410,18 @@ def new(*args, **kwargs): new_definition_used = True return cls(*args, **kwargs) - loader( - nexus_file, - definitions={ - **snx.base_definitions(), - name: new, - }, - ) + loc = NeXusLocationSpec(filename=nexus_file) + loader(loc, nx_class=cls, definitions={**snx.base_definitions(), name: new}) assert new_definition_used @pytest.mark.parametrize('entry_name', [None, nexus.types.NeXusEntryName('entry-001')]) def test_load_sample(nexus_file, expected_sample, entry_name): - sample = nexus.load_sample(nexus_file, entry_name=entry_name) + loc = NeXusLocationSpec( + filename=nexus_file, + entry_name=entry_name, + ) + sample = nexus.load_component(loc, nx_class=snx.NXsample) sc.testing.assert_identical(sample, nexus.types.AnyRunNeXusSample(expected_sample)) @@ -425,7 +433,7 @@ def test_extract_detector_data(): ' _': sc.linspace('xx', 2, 3, 10), } ) - data = nexus.extract_detector_data(nexus.types.AnyRunNeXusDetector(detector)) + data = nexus.extract_events_or_histogram(nexus.types.AnyRunNeXusDetector(detector)) sc.testing.assert_identical(data, nexus.types.RawDetectorData(detector['jdl2ab'])) @@ -437,7 +445,7 @@ def test_extract_monitor_data(): ' _': sc.linspace('xx', 2, 3, 10), } ) - data = nexus.extract_monitor_data(nexus.types.AnyRunAnyNeXusMonitor(monitor)) + data = nexus.extract_events_or_histogram(nexus.types.AnyRunAnyNeXusMonitor(monitor)) sc.testing.assert_identical(data, nexus.types.RawMonitorData(monitor['(eed)'])) @@ -453,17 +461,17 @@ def test_extract_detector_data_requires_unique_dense_data(): with pytest.raises( ValueError, match="Cannot uniquely identify the data to extract" ): - nexus.extract_detector_data(nexus.types.AnyRunNeXusDetector(detector)) + nexus.extract_events_or_histogram(nexus.types.AnyRunNeXusDetector(detector)) def test_extract_detector_data_ignores_position_data_array(): detector = sc.DataGroup(jdl2ab=sc.data.data_xy(), position=sc.data.data_xy()) - nexus.extract_detector_data(nexus.types.AnyRunNeXusDetector(detector)) + nexus.extract_events_or_histogram(nexus.types.AnyRunNeXusDetector(detector)) def test_extract_detector_data_ignores_transform_data_array(): detector = sc.DataGroup(jdl2ab=sc.data.data_xy(), transform=sc.data.data_xy()) - nexus.extract_detector_data(nexus.types.AnyRunNeXusDetector(detector)) + nexus.extract_events_or_histogram(nexus.types.AnyRunNeXusDetector(detector)) def test_extract_detector_data_requires_unique_event_data(): @@ -478,7 +486,7 @@ def test_extract_detector_data_requires_unique_event_data(): with pytest.raises( ValueError, match="Cannot uniquely identify the data to extract" ): - nexus.extract_detector_data(nexus.types.AnyRunNeXusDetector(detector)) + nexus.extract_events_or_histogram(nexus.types.AnyRunNeXusDetector(detector)) def test_extract_detector_data_favors_event_data_over_histogram_data(): @@ -490,5 +498,5 @@ def test_extract_detector_data_favors_event_data_over_histogram_data(): ' _': sc.linspace('xx', 2, 3, 10), } ) - data = nexus.extract_detector_data(nexus.types.AnyRunNeXusDetector(detector)) + data = nexus.extract_events_or_histogram(nexus.types.AnyRunNeXusDetector(detector)) sc.testing.assert_identical(data, nexus.types.RawDetectorData(detector['lob']))