Skip to content

Commit

Permalink
Add NeXus prefix to Names
Browse files Browse the repository at this point in the history
  • Loading branch information
nvaytet committed Mar 8, 2024
1 parent 4446a00 commit 3a51a10
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 22 deletions.
24 changes: 12 additions & 12 deletions src/ess/reduce/nexus.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@
NeXusGroup = NewType('NeXusGroup', snx.Group)
"""A ScippNexus group in an open file."""

DetectorName = NewType('DetectorName', str)
NeXusDetectorName = NewType('NeXusDetectorName', str)
"""Name of a detector (bank) in a NeXus file."""
EntryName = NewType('EntryName', str)
NeXusEntryName = NewType('NeXusEntryName', str)
"""Name of an entry in a NeXus file."""
MonitorName = NewType('MonitorName', str)
NeXusMonitorName = NewType('NeXusMonitorName', str)
"""Name of a monitor in a NeXus file."""
SourceName = NewType('SourceName', str)
NeXusSourceName = NewType('NeXusSourceName', str)
"""Name of a source in a NeXus file."""

RawDetector = NewType('RawDetector', sc.DataGroup)
Expand All @@ -59,8 +59,8 @@
def load_detector(
file_path: Union[FilePath, NeXusFile, NeXusGroup],
*,
detector_name: DetectorName,
entry_name: Optional[EntryName] = None,
detector_name: NeXusDetectorName,
entry_name: Optional[NeXusEntryName] = None,
) -> RawDetector:
"""Load a single detector (bank) from a NeXus file.
Expand Down Expand Up @@ -103,8 +103,8 @@ def load_detector(
def load_monitor(
file_path: Union[FilePath, NeXusFile, NeXusGroup],
*,
monitor_name: MonitorName,
entry_name: Optional[EntryName] = None,
monitor_name: NeXusMonitorName,
entry_name: Optional[NeXusEntryName] = None,
) -> RawMonitor:
"""Load a single monitor from a NeXus file.
Expand Down Expand Up @@ -147,8 +147,8 @@ def load_monitor(
def load_source(
file_path: Union[FilePath, NeXusFile, NeXusGroup],
*,
source_name: Optional[SourceName] = None,
entry_name: Optional[EntryName] = None,
source_name: Optional[NeXusSourceName] = None,
entry_name: Optional[NeXusEntryName] = None,
) -> RawSource:
"""Load a source from a NeXus file.
Expand Down Expand Up @@ -192,7 +192,7 @@ def load_source(

def load_sample(
file_path: Union[FilePath, NeXusFile, NeXusGroup],
entry_name: Optional[EntryName] = None,
entry_name: Optional[NeXusEntryName] = None,
) -> RawSample:
"""Load a sample from a NeXus file.
Expand Down Expand Up @@ -231,7 +231,7 @@ def _load_group_with_positions(
*,
group_name: Optional[str],
nx_class: Type[snx.NXobject],
entry_name: Optional[EntryName] = None,
entry_name: Optional[NeXusEntryName] = None,
) -> sc.DataGroup:
with _open_nexus_file(file_path) as f:
entry = _unique_child_group(f, snx.NXentry, entry_name)
Expand Down
20 changes: 10 additions & 10 deletions tests/nexus_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,11 @@ def expected_sample() -> sc.DataGroup:
return _sample_data()


@pytest.mark.parametrize('entry_name', (None, nexus.EntryName('entry-001')))
@pytest.mark.parametrize('entry_name', (None, nexus.NeXusEntryName('entry-001')))
def test_load_detector(nexus_file, expected_bank12, entry_name):
detector = nexus.load_detector(
nexus_file,
detector_name=nexus.DetectorName('bank12'),
detector_name=nexus.NeXusDetectorName('bank12'),
entry_name=entry_name,
)
sc.testing.assert_identical(detector['bank12_events'], expected_bank12)
Expand All @@ -232,7 +232,7 @@ def test_load_detector_requires_entry_name_if_not_unique(nexus_file):
with pytest.raises(ValueError):
nexus.load_detector(
nexus.FilePath(nexus_file),
detector_name=nexus.DetectorName('bank12'),
detector_name=nexus.NeXusDetectorName('bank12'),
entry_name=None,
)

Expand All @@ -247,24 +247,24 @@ def test_load_detector_select_entry_if_not_unique(nexus_file, expected_bank12):

detector = nexus.load_detector(
nexus.FilePath(nexus_file),
detector_name=nexus.DetectorName('bank12'),
entry_name=nexus.EntryName('entry-001'),
detector_name=nexus.NeXusDetectorName('bank12'),
entry_name=nexus.NeXusEntryName('entry-001'),
)
sc.testing.assert_identical(detector['bank12_events'], expected_bank12)


@pytest.mark.parametrize('entry_name', (None, nexus.EntryName('entry-001')))
@pytest.mark.parametrize('entry_name', (None, nexus.NeXusEntryName('entry-001')))
def test_load_monitor(nexus_file, expected_monitor, entry_name):
monitor = nexus.load_monitor(
nexus_file,
monitor_name=nexus.MonitorName('monitor'),
monitor_name=nexus.NeXusMonitorName('monitor'),
entry_name=entry_name,
)
sc.testing.assert_identical(monitor['data'], expected_monitor)


@pytest.mark.parametrize('entry_name', (None, nexus.EntryName('entry-001')))
@pytest.mark.parametrize('source_name', (None, nexus.SourceName('source')))
@pytest.mark.parametrize('entry_name', (None, nexus.NeXusEntryName('entry-001')))
@pytest.mark.parametrize('source_name', (None, nexus.NeXusSourceName('source')))
def test_load_source(nexus_file, expected_source, entry_name, source_name):
source = nexus.load_source(
nexus_file,
Expand All @@ -277,7 +277,7 @@ def test_load_source(nexus_file, expected_source, entry_name, source_name):
sc.testing.assert_identical(source, nexus.RawSource(expected_source))


@pytest.mark.parametrize('entry_name', (None, nexus.EntryName('entry-001')))
@pytest.mark.parametrize('entry_name', (None, nexus.NeXusEntryName('entry-001')))
def test_load_sample(nexus_file, expected_sample, entry_name):
sample = nexus.load_sample(nexus_file, entry_name=entry_name)
sc.testing.assert_identical(sample, nexus.RawSample(expected_sample))
Expand Down

0 comments on commit 3a51a10

Please sign in to comment.