From bd1493cead2e781595d984fb3523fff63f59e872 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 23 Aug 2024 15:12:39 -0600 Subject: [PATCH] [testing] Re-organize data tests around the one dataset per test (#1026) --- CHANGELOG.md | 1 + .../tools/testing/data_interface_mixins.py | 297 +++++------ tests/test_behavior/test_audio_interface.py | 64 ++- .../test_mock_recording_interface.py | 8 +- .../test_on_data/test_behavior_interfaces.py | 265 +++++----- .../test_miniscope_converter.py | 16 +- tests/test_on_data/test_imaging_interfaces.py | 229 +++++---- .../test_on_data/test_recording_interfaces.py | 481 ++++++++++-------- .../test_segmentation_interfaces.py | 225 +++++--- tests/test_on_data/test_sorting_interfaces.py | 185 ++++--- 10 files changed, 970 insertions(+), 801 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b343771a8..1edae1bb0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ * Added `get_json_schema_from_method_signature` which constructs Pydantic models automatically from the signature of any function with typical annotation types used throughout NeuroConv. [PR #1016](https://github.com/catalystneuro/neuroconv/pull/1016) * Replaced all interface annotations with Pydantic types. [PR #1017](https://github.com/catalystneuro/neuroconv/pull/1017) * Changed typehint collections (e.g. `List`) to standard collections (e.g. `list`). [PR #1021](https://github.com/catalystneuro/neuroconv/pull/1021) +* Testing now is only one dataset per test [PR #1026](https://github.com/catalystneuro/neuroconv/pull/1026) diff --git a/src/neuroconv/tools/testing/data_interface_mixins.py b/src/neuroconv/tools/testing/data_interface_mixins.py index b923851c2..07e25bede 100644 --- a/src/neuroconv/tools/testing/data_interface_mixins.py +++ b/src/neuroconv/tools/testing/data_interface_mixins.py @@ -8,7 +8,7 @@ from typing import Literal, Optional, Type, Union import numpy as np -from hdmf.testing import TestCase as HDMFTestCase +import pytest from hdmf_zarr import NWBZarrIO from jsonschema.validators import Draft7Validator, validate from numpy.testing import assert_array_equal @@ -35,14 +35,11 @@ ) from neuroconv.utils import NWBMetaDataEncoder -from .mock_probes import generate_mock_probe - class DataInterfaceTestMixin: """ Generic class for testing DataInterfaces. - This mixin must be paired with unittest.TestCase. Several of these tests are required to be run in a specific order. In this case, there is a `test_conversion_as_lone_interface` that calls the `check` functions in @@ -63,11 +60,26 @@ class DataInterfaceTestMixin: """ data_interface_cls: Type[BaseDataInterface] - interface_kwargs: Union[dict, list[dict]] + interface_kwargs: dict save_directory: Path = Path(tempfile.mkdtemp()) - conversion_options: dict = dict() + conversion_options: Optional[dict] = None maxDiff = None + @pytest.fixture + def setup_interface(self, request): + + self.test_name: str = "" + self.conversion_options = self.conversion_options or dict() + self.interface = self.data_interface_cls(**self.interface_kwargs) + + return self.interface, self.test_name + + @pytest.fixture(scope="class", autouse=True) + def setup_default_conversion_options(self, request): + cls = request.cls + cls.conversion_options = cls.conversion_options or dict() + return cls.conversion_options + def test_source_schema_valid(self): schema = self.data_interface_cls.get_source_schema() Draft7Validator.check_schema(schema=schema) @@ -150,8 +162,7 @@ def check_run_conversion_in_nwbconverter_with_backend( class TestNWBConverter(NWBConverter): data_interface_classes = dict(Test=type(self.interface)) - test_kwargs = self.test_kwargs[0] if isinstance(self.test_kwargs, list) else self.test_kwargs - source_data = dict(Test=test_kwargs) + source_data = dict(Test=self.interface_kwargs) converter = TestNWBConverter(source_data=source_data) metadata = converter.get_metadata() @@ -174,8 +185,7 @@ def check_run_conversion_in_nwbconverter_with_backend_configuration( class TestNWBConverter(NWBConverter): data_interface_classes = dict(Test=type(self.interface)) - test_kwargs = self.test_kwargs[0] if isinstance(self.test_kwargs, list) else self.test_kwargs - source_data = dict(Test=test_kwargs) + source_data = dict(Test=self.interface_kwargs) converter = TestNWBConverter(source_data=source_data) metadata = converter.get_metadata() @@ -213,59 +223,59 @@ def run_custom_checks(self): """Override this in child classes to inject additional custom checks.""" pass - def test_all_conversion_checks(self): - interface_kwargs = self.interface_kwargs - if isinstance(interface_kwargs, dict): - interface_kwargs = [interface_kwargs] - for num, kwargs in enumerate(interface_kwargs): - with self.subTest(str(num)): - self.case = num - self.test_kwargs = kwargs - self.interface = self.data_interface_cls(**self.test_kwargs) + def test_all_conversion_checks(self, setup_interface, tmp_path): + interface, test_name = setup_interface - self.check_metadata_schema_valid() - self.check_conversion_options_schema_valid() - self.check_metadata() - self.nwbfile_path = str(self.save_directory / f"{self.__class__.__name__}_{num}.nwb") + # Create a unique test name and file path + nwbfile_path = str(tmp_path / f"{self.__class__.__name__}_{self.test_name}.nwb") + self.nwbfile_path = nwbfile_path - self.check_no_metadata_mutation() - - self.check_configure_backend_for_equivalent_nwbfiles() - - self.check_run_conversion_in_nwbconverter_with_backend(nwbfile_path=self.nwbfile_path, backend="hdf5") - self.check_run_conversion_in_nwbconverter_with_backend_configuration( - nwbfile_path=self.nwbfile_path, backend="hdf5" - ) + # Now run the checks using the setup objects + self.check_metadata_schema_valid() + self.check_conversion_options_schema_valid() + self.check_metadata() + self.check_no_metadata_mutation() + self.check_configure_backend_for_equivalent_nwbfiles() - self.check_run_conversion_with_backend(nwbfile_path=self.nwbfile_path, backend="hdf5") - self.check_run_conversion_with_backend_configuration(nwbfile_path=self.nwbfile_path, backend="hdf5") + self.check_run_conversion_in_nwbconverter_with_backend(nwbfile_path=nwbfile_path, backend="hdf5") + self.check_run_conversion_in_nwbconverter_with_backend_configuration(nwbfile_path=nwbfile_path, backend="hdf5") - self.check_read_nwb(nwbfile_path=self.nwbfile_path) + self.check_run_conversion_with_backend(nwbfile_path=nwbfile_path, backend="hdf5") + self.check_run_conversion_with_backend_configuration(nwbfile_path=nwbfile_path, backend="hdf5") - # TODO: enable when all H5DataIO prewraps are gone - # self.nwbfile_path = str(self.save_directory / f"{self.__class__.__name__}_{num}.nwb.zarr") - # self.check_run_conversion(nwbfile_path=self.nwbfile_path, backend="zarr") - # self.check_run_conversion_custom_backend(nwbfile_path=self.nwbfile_path, backend="zarr") - # self.check_basic_zarr_read(nwbfile_path=self.nwbfile_path) + self.check_read_nwb(nwbfile_path=nwbfile_path) - # Any extra custom checks to run - self.run_custom_checks() + # Any extra custom checks to run + self.run_custom_checks() class TemporalAlignmentMixin: """ Generic class for testing temporal alignment methods. - - This mixin must be paired with a unittest.TestCase class. """ data_interface_cls: Type[BaseDataInterface] - interface_kwargs: Union[dict, list[dict]] + interface_kwargs: dict + save_directory: Path = Path(tempfile.mkdtemp()) + conversion_options: Optional[dict] = None maxDiff = None + @pytest.fixture + def setup_interface(self, request): + + self.test_name: str = "" + self.interface = self.data_interface_cls(**self.interface_kwargs) + return self.interface, self.test_name + + @pytest.fixture(scope="class", autouse=True) + def setup_default_conversion_options(self, request): + cls = request.cls + cls.conversion_options = cls.conversion_options or dict() + return cls.conversion_options + def setUpFreshInterface(self): """Protocol for creating a fresh instance of the interface.""" - self.interface = self.data_interface_cls(**self.test_kwargs) + self.interface = self.data_interface_cls(**self.interface_kwargs) def check_interface_get_original_timestamps(self): """ @@ -330,22 +340,17 @@ def check_nwbfile_temporal_alignment(self): """Check the temporally aligned timing information makes it into the NWB file.""" pass # TODO: will be easier to add when interface have 'add' methods separate from .run_conversion() - def test_interface_alignment(self): - interface_kwargs = self.interface_kwargs - if isinstance(interface_kwargs, dict): - interface_kwargs = [interface_kwargs] - for num, kwargs in enumerate(interface_kwargs): - with self.subTest(str(num)): - self.case = num - self.test_kwargs = kwargs + def test_interface_alignment(self, setup_interface): - self.check_interface_get_original_timestamps() - self.check_interface_get_timestamps() - self.check_interface_set_aligned_timestamps() - self.check_shift_timestamps_by_start_time() - self.check_interface_original_timestamps_inmutability() + interface, test_name = setup_interface - self.check_nwbfile_temporal_alignment() + self.check_interface_get_original_timestamps() + self.check_interface_get_timestamps() + self.check_interface_set_aligned_timestamps() + self.check_shift_timestamps_by_start_time() + self.check_interface_original_timestamps_inmutability() + + self.check_nwbfile_temporal_alignment() class ImagingExtractorInterfaceTestMixin(DataInterfaceTestMixin, TemporalAlignmentMixin): @@ -366,10 +371,11 @@ def check_read_nwb(self, nwbfile_path: str): def check_nwbfile_temporal_alignment(self): nwbfile_path = str( - self.save_directory / f"{self.data_interface_cls.__name__}_{self.case}_test_starting_time_alignment.nwb" + self.save_directory + / f"{self.data_interface_cls.__name__}_{self.test_name}_test_starting_time_alignment.nwb" ) - interface = self.data_interface_cls(**self.test_kwargs) + interface = self.data_interface_cls(**self.interface_kwargs) aligned_starting_time = 1.23 interface.set_aligned_starting_time(aligned_starting_time=aligned_starting_time) @@ -399,10 +405,6 @@ def check_read(self, nwbfile_path: str): class RecordingExtractorInterfaceTestMixin(DataInterfaceTestMixin, TemporalAlignmentMixin): """ Generic class for testing any recording interface. - - Runs all the basic DataInterface tests as well as temporal alignment tests. - - This mixin must be paired with a hdmf.testing.TestCase class. """ data_interface_cls: Type[BaseRecordingExtractorInterface] @@ -482,12 +484,9 @@ def check_interface_set_aligned_timestamps(self): retrieved_aligned_timestamps = self.interface.get_timestamps() assert_array_equal(x=retrieved_aligned_timestamps, y=aligned_timestamps) else: - assert isinstance( - self, HDMFTestCase - ), "The RecordingExtractorInterfaceTestMixin must be mixed-in with the TestCase from hdmf.testing!" - with self.assertRaisesWith( - exc_type=AssertionError, - exc_msg="This recording has multiple segments; please use 'align_segment_timestamps' instead.", + with pytest.raises( + AssertionError, + match="This recording has multiple segments; please use 'align_segment_timestamps' instead.", ): all_unaligned_timestamps = self.interface.get_timestamps() @@ -590,74 +589,32 @@ def check_interface_original_timestamps_inmutability(self): post_alignment_original_timestamps = self.interface.get_original_timestamps() assert_array_equal(x=post_alignment_original_timestamps, y=pre_alignment_original_timestamps) else: - assert isinstance( - self, HDMFTestCase - ), "The RecordingExtractorInterfaceTestMixin must be mixed-in with the TestCase from hdmf.testing!" - with self.assertRaisesWith( - exc_type=AssertionError, - exc_msg="This recording has multiple segments; please use 'align_segment_timestamps' instead.", + with pytest.raises( + AssertionError, + match="This recording has multiple segments; please use 'align_segment_timestamps' instead.", ): - all_pre_alignement_timestamps = self.interface.get_original_timestamps() + all_pre_alignment_timestamps = self.interface.get_original_timestamps() all_aligned_timestamps = [ - unaligned_timestamps + 1.23 for unaligned_timestamps in all_pre_alignement_timestamps + unaligned_timestamps + 1.23 for unaligned_timestamps in all_pre_alignment_timestamps ] self.interface.set_aligned_timestamps(aligned_timestamps=all_aligned_timestamps) - def test_interface_alignment(self): - interface_kwargs = self.interface_kwargs - if isinstance(interface_kwargs, dict): - interface_kwargs = [interface_kwargs] - for num, kwargs in enumerate(interface_kwargs): - with self.subTest(str(num)): - self.case = num - self.test_kwargs = kwargs + def test_interface_alignment(self, setup_interface): - self.check_interface_get_original_timestamps() - self.check_interface_get_timestamps() - self.check_interface_set_aligned_timestamps() - self.check_interface_set_aligned_segment_timestamps() - self.check_shift_timestamps_by_start_time() - self.check_shift_segment_timestamps_by_starting_times() - self.check_interface_original_timestamps_inmutability() + interface, test_name = setup_interface - self.check_nwbfile_temporal_alignment() + self.check_interface_get_original_timestamps() + self.check_interface_get_timestamps() + self.check_interface_set_aligned_timestamps() + self.check_shift_timestamps_by_start_time() + self.check_interface_original_timestamps_inmutability() - def test_all_conversion_checks(self): - interface_kwargs = self.interface_kwargs - if isinstance(interface_kwargs, dict): - interface_kwargs = [interface_kwargs] - for num, kwargs in enumerate(interface_kwargs): - with self.subTest(str(num)): - self.case = num - self.test_kwargs = kwargs - self.interface = self.data_interface_cls(**self.test_kwargs) - assert isinstance(self.interface, BaseRecordingExtractorInterface) - if not self.interface.has_probe(): - self.interface.set_probe( - generate_mock_probe(num_channels=self.interface.recording_extractor.get_num_channels()), - group_mode="by_shank", - ) + self.check_interface_set_aligned_segment_timestamps() + self.check_shift_timestamps_by_start_time() + self.check_shift_segment_timestamps_by_starting_times() - self.check_metadata_schema_valid() - self.check_conversion_options_schema_valid() - self.check_metadata() - self.nwbfile_path = str(self.save_directory / f"{self.__class__.__name__}_{num}.nwb") - - self.check_no_metadata_mutation() - - self.check_run_conversion_in_nwbconverter_with_backend(nwbfile_path=self.nwbfile_path, backend="hdf5") - self.check_run_conversion_in_nwbconverter_with_backend_configuration( - nwbfile_path=self.nwbfile_path, backend="hdf5" - ) - - self.check_run_conversion_with_backend(nwbfile_path=self.nwbfile_path, backend="hdf5") - self.check_run_conversion_with_backend_configuration(nwbfile_path=self.nwbfile_path, backend="hdf5") - - self.check_read_nwb(nwbfile_path=self.nwbfile_path) - - # Any extra custom checks to run - self.run_custom_checks() + self.check_nwbfile_temporal_alignment() class SortingExtractorInterfaceTestMixin(DataInterfaceTestMixin, TemporalAlignmentMixin): @@ -666,7 +623,7 @@ class SortingExtractorInterfaceTestMixin(DataInterfaceTestMixin, TemporalAlignme associated_recording_kwargs: Optional[dict] = None def setUpFreshInterface(self): - self.interface = self.data_interface_cls(**self.test_kwargs) + self.interface = self.data_interface_cls(**self.interface_kwargs) recording_interface = self.associated_recording_cls(**self.associated_recording_kwargs) self.interface.register_recording(recording_interface=recording_interface) @@ -768,26 +725,45 @@ def check_shift_segment_timestamps_by_starting_times(self): ): assert_array_equal(x=retrieved_aligned_timestamps, y=expected_aligned_timestamps) - def test_interface_alignment(self): - interface_kwargs = self.interface_kwargs - if isinstance(interface_kwargs, dict): - interface_kwargs = [interface_kwargs] - for num, kwargs in enumerate(interface_kwargs): - with self.subTest(str(num)): - self.case = num - self.test_kwargs = kwargs + def test_all_conversion_checks(self, setup_interface, tmp_path): + # The fixture `setup_interface` sets up the necessary objects + interface, test_name = setup_interface - if self.associated_recording_cls is None: - continue + # Create a unique test name and file path + nwbfile_path = str(tmp_path / f"{self.__class__.__name__}_{self.test_name}.nwb") - # Skip get_original_timestamps() checks since unsupported - self.check_interface_get_timestamps() - self.check_interface_set_aligned_timestamps() - self.check_interface_set_aligned_segment_timestamps() - self.check_shift_timestamps_by_start_time() - self.check_shift_segment_timestamps_by_starting_times() + # Now run the checks using the setup objects + self.check_metadata_schema_valid() + self.check_conversion_options_schema_valid() + self.check_metadata() + self.check_no_metadata_mutation() + self.check_configure_backend_for_equivalent_nwbfiles() - self.check_nwbfile_temporal_alignment() + self.check_run_conversion_in_nwbconverter_with_backend(nwbfile_path=nwbfile_path, backend="hdf5") + self.check_run_conversion_in_nwbconverter_with_backend_configuration(nwbfile_path=nwbfile_path, backend="hdf5") + + self.check_run_conversion_with_backend(nwbfile_path=nwbfile_path, backend="hdf5") + self.check_run_conversion_with_backend_configuration(nwbfile_path=nwbfile_path, backend="hdf5") + + self.check_read_nwb(nwbfile_path=nwbfile_path) + + # Any extra custom checks to run + self.run_custom_checks() + + def test_interface_alignment(self, setup_interface): + + # TODO sorting can have times without associated recordings, test this later + if self.associated_recording_cls is None: + return None + + # Skip get_original_timestamps() checks since unsupported + self.check_interface_get_timestamps() + self.check_interface_set_aligned_timestamps() + self.check_interface_set_aligned_segment_timestamps() + self.check_shift_timestamps_by_start_time() + self.check_shift_segment_timestamps_by_starting_times() + + self.check_nwbfile_temporal_alignment() class AudioInterfaceTestMixin(DataInterfaceTestMixin, TemporalAlignmentMixin): @@ -824,7 +800,7 @@ class VideoInterfaceMixin(DataInterfaceTestMixin, TemporalAlignmentMixin): def check_read_nwb(self, nwbfile_path: str): with NWBHDF5IO(path=nwbfile_path, mode="r", load_namespaces=True) as io: nwbfile = io.read() - video_type = Path(self.test_kwargs["file_paths"][0]).suffix[1:] + video_type = Path(self.interface_kwargs["file_paths"][0]).suffix[1:] assert f"Video: video_{video_type}" in nwbfile.acquisition def check_interface_set_aligned_timestamps(self): @@ -858,7 +834,9 @@ def check_shift_timestamps_by_start_time(self): def check_set_aligned_segment_starting_times(self): self.setUpFreshInterface() - aligned_segment_starting_times = [1.23 * file_path_index for file_path_index in range(len(self.test_kwargs))] + aligned_segment_starting_times = [ + 1.23 * file_path_index for file_path_index in range(len(self.interface_kwargs)) + ] self.interface.set_aligned_segment_starting_times(aligned_segment_starting_times=aligned_segment_starting_times) all_aligned_timestamps = self.interface.get_timestamps() @@ -887,27 +865,6 @@ def check_interface_original_timestamps_inmutability(self): ): assert_array_equal(x=post_alignment_original_timestamps, y=pre_alignment_original_timestamps) - def check_nwbfile_temporal_alignment(self): - pass # TODO in separate PR - - def test_interface_alignment(self): - interface_kwargs = self.interface_kwargs - if isinstance(interface_kwargs, dict): - interface_kwargs = [interface_kwargs] - for num, kwargs in enumerate(interface_kwargs): - with self.subTest(str(num)): - self.case = num - self.test_kwargs = kwargs - - self.check_interface_get_original_timestamps() - self.check_interface_get_timestamps() - self.check_interface_set_aligned_timestamps() - self.check_shift_timestamps_by_start_time() - self.check_interface_original_timestamps_inmutability() - self.check_set_aligned_segment_starting_times() - - self.check_nwbfile_temporal_alignment() - class MedPCInterfaceMixin(DataInterfaceTestMixin, TemporalAlignmentMixin): def check_no_metadata_mutation(self, metadata: dict): diff --git a/tests/test_behavior/test_audio_interface.py b/tests/test_behavior/test_audio_interface.py index bdfcf1e52..4fad4fe3d 100644 --- a/tests/test_behavior/test_audio_interface.py +++ b/tests/test_behavior/test_audio_interface.py @@ -1,14 +1,12 @@ -import shutil +import re from copy import deepcopy from datetime import datetime from pathlib import Path -from tempfile import mkdtemp -from warnings import warn import jsonschema import numpy as np +import pytest from dateutil.tz import gettz -from hdmf.testing import TestCase from numpy.testing import assert_array_equal from pydantic import FilePath from pynwb import NWBHDF5IO @@ -38,38 +36,39 @@ def create_audio_files( return audio_file_names -class TestAudioInterface(AudioInterfaceTestMixin, TestCase): - @classmethod - def setUpClass(cls): +class TestAudioInterface(AudioInterfaceTestMixin): + + data_interface_cls = AudioInterface + + @pytest.fixture(scope="class", autouse=True) + def setup_test(self, request, tmp_path_factory): + + cls = request.cls + cls.session_start_time = datetime.now(tz=gettz(name="US/Pacific")) cls.num_frames = int(1e7) cls.num_audio_files = 3 cls.sampling_rate = 500 cls.aligned_segment_starting_times = [0.0, 20.0, 40.0] - cls.test_dir = Path(mkdtemp()) + class_tmp_dir = tmp_path_factory.mktemp("class_tmp_dir") + cls.test_dir = Path(class_tmp_dir) cls.file_paths = create_audio_files( test_dir=cls.test_dir, num_audio_files=cls.num_audio_files, sampling_rate=cls.sampling_rate, num_frames=cls.num_frames, ) - cls.data_interface_cls = AudioInterface cls.interface_kwargs = dict(file_paths=[cls.file_paths[0]]) - def setUp(self): + @pytest.fixture(scope="function", autouse=True) + def setup_converter(self): + self.nwbfile_path = str(self.test_dir / "audio_test.nwb") self.create_audio_converter() self.metadata = self.nwb_converter.get_metadata() self.metadata["NWBFile"].update(session_start_time=self.session_start_time) - @classmethod - def tearDownClass(cls): - try: - shutil.rmtree(cls.test_dir) - except PermissionError: # Windows CI bug - warn(f"Unable to fully clean the temporary directory: {cls.test_dir}\n\nPlease remove it manually.") - def create_audio_converter(self): class AudioTestNWBConverter(NWBConverter): data_interface_classes = dict(Audio=AudioInterface) @@ -83,7 +82,7 @@ class AudioTestNWBConverter(NWBConverter): def test_unsupported_format(self): exc_msg = "The currently supported file format for audio is WAV file. Some of the provided files does not match this format: ['.test']." - with self.assertRaisesWith(ValueError, exc_msg=exc_msg): + with pytest.raises(ValueError, match=re.escape(exc_msg)): AudioInterface(file_paths=["test.test"]) def test_get_metadata(self): @@ -91,10 +90,10 @@ def test_get_metadata(self): metadata = audio_interface.get_metadata() audio_metadata = metadata["Behavior"]["Audio"] - self.assertEqual(len(audio_metadata), self.num_audio_files) + assert len(audio_metadata) == self.num_audio_files def test_incorrect_write_as(self): - with self.assertRaises(jsonschema.exceptions.ValidationError): + with pytest.raises(jsonschema.exceptions.ValidationError): self.nwb_converter.run_conversion( nwbfile_path=self.nwbfile_path, metadata=self.metadata, @@ -125,7 +124,7 @@ def test_incomplete_metadata(self): expected_error_message = ( "The Audio metadata is incomplete (1 entry)! Expected 3 (one for each entry of 'file_paths')." ) - with self.assertRaisesWith(exc_type=AssertionError, exc_msg=expected_error_message): + with pytest.raises(AssertionError, match=re.escape(expected_error_message)): self.nwb_converter.run_conversion(nwbfile_path=self.nwbfile_path, metadata=metadata, overwrite=True) def test_metadata_update(self): @@ -137,7 +136,7 @@ def test_metadata_update(self): nwbfile = io.read() container = nwbfile.stimulus audio_name = metadata["Behavior"]["Audio"][0]["name"] - self.assertEqual("New description for Acoustic waveform series.", container[audio_name].description) + assert container[audio_name].description == "New description for Acoustic waveform series." def test_not_all_metadata_are_unique(self): metadata = deepcopy(self.metadata) @@ -149,21 +148,18 @@ def test_not_all_metadata_are_unique(self): ], ) expected_error_message = "Some of the names for Audio metadata are not unique." - with self.assertRaisesWith(exc_type=AssertionError, exc_msg=expected_error_message): + with pytest.raises(AssertionError, match=re.escape(expected_error_message)): self.interface.run_conversion(nwbfile_path=self.nwbfile_path, metadata=metadata, overwrite=True) def test_segment_starting_times_are_floats(self): - with self.assertRaisesWith( - exc_type=AssertionError, exc_msg="Argument 'aligned_segment_starting_times' must be a list of floats." - ): + with pytest.raises(AssertionError, match="Argument 'aligned_segment_starting_times' must be a list of floats."): self.interface.set_aligned_segment_starting_times(aligned_segment_starting_times=[0, 1, 2]) def test_segment_starting_times_length_mismatch(self): - with self.assertRaisesWith( - exc_type=AssertionError, - exc_msg="The number of entries in 'aligned_segment_starting_times' (4) must be equal to the number of audio file paths (3).", - ): + with pytest.raises(AssertionError) as exc_info: self.interface.set_aligned_segment_starting_times(aligned_segment_starting_times=[0.0, 1.0, 2.0, 4.0]) + exc_msg = "The number of entries in 'aligned_segment_starting_times' (4) must be equal to the number of audio file paths (3)." + assert str(exc_info.value) == exc_msg def test_set_aligned_segment_starting_times(self): fresh_interface = AudioInterface(file_paths=self.file_paths[:2]) @@ -210,12 +206,10 @@ def test_run_conversion(self): nwbfile = io.read() container = nwbfile.stimulus metadata = self.nwb_converter.get_metadata() - self.assertEqual(3, len(container)) + assert len(container) == 3 for audio_ind, audio_metadata in enumerate(metadata["Behavior"]["Audio"]): audio_interface_name = audio_metadata["name"] assert audio_interface_name in container - self.assertEqual( - self.aligned_segment_starting_times[audio_ind], container[audio_interface_name].starting_time - ) - self.assertEqual(self.sampling_rate, container[audio_interface_name].rate) + assert self.aligned_segment_starting_times[audio_ind] == container[audio_interface_name].starting_time + assert self.sampling_rate == container[audio_interface_name].rate assert_array_equal(audio_test_data[audio_ind], container[audio_interface_name].data) diff --git a/tests/test_ecephys/test_mock_recording_interface.py b/tests/test_ecephys/test_mock_recording_interface.py index d7dfc5714..a33f3acd1 100644 --- a/tests/test_ecephys/test_mock_recording_interface.py +++ b/tests/test_ecephys/test_mock_recording_interface.py @@ -1,13 +1,9 @@ -import unittest - from neuroconv.tools.testing.data_interface_mixins import ( RecordingExtractorInterfaceTestMixin, ) from neuroconv.tools.testing.mock_interfaces import MockRecordingInterface -class TestMockRecordingInterface(unittest.TestCase, RecordingExtractorInterfaceTestMixin): +class TestMockRecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = MockRecordingInterface - interface_kwargs = [ - dict(durations=[0.100]), - ] + interface_kwargs = dict(durations=[0.100]) diff --git a/tests/test_on_data/test_behavior_interfaces.py b/tests/test_on_data/test_behavior_interfaces.py index 1d25aaf36..33d0d468b 100644 --- a/tests/test_on_data/test_behavior_interfaces.py +++ b/tests/test_on_data/test_behavior_interfaces.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd +import pytest import sleap_io from hdmf.testing import TestCase from natsort import natsorted @@ -41,7 +42,7 @@ from setup_paths import BEHAVIOR_DATA_PATH, OUTPUT_PATH -class TestLightningPoseDataInterface(DataInterfaceTestMixin, TemporalAlignmentMixin, unittest.TestCase): +class TestLightningPoseDataInterface(DataInterfaceTestMixin, TemporalAlignmentMixin): data_interface_cls = LightningPoseDataInterface interface_kwargs = dict( file_path=str(BEHAVIOR_DATA_PATH / "lightningpose" / "outputs/2023-11-09/10-14-37/video_preds/test_vid.csv"), @@ -52,8 +53,11 @@ class TestLightningPoseDataInterface(DataInterfaceTestMixin, TemporalAlignmentMi conversion_options = dict(reference_frame="(0,0) corresponds to the top left corner of the video.") save_directory = OUTPUT_PATH - @classmethod - def setUpClass(cls): + @pytest.fixture(scope="class", autouse=True) + def setup_metadata(self, request): + + cls = request.cls + cls.pose_estimation_name = "PoseEstimation" cls.original_video_height = 406 cls.original_video_width = 396 @@ -97,47 +101,61 @@ def setUpClass(cls): cls.test_data = pd.read_csv(cls.interface_kwargs["file_path"], header=[0, 1, 2])["heatmap_tracker"] def check_extracted_metadata(self, metadata: dict): - self.assertEqual( - metadata["NWBFile"]["session_start_time"], - datetime(2023, 11, 9, 10, 14, 37, 0), - ) - self.assertIn(self.pose_estimation_name, metadata["Behavior"]) - self.assertEqual( - metadata["Behavior"][self.pose_estimation_name], self.expected_metadata[self.pose_estimation_name] - ) + assert metadata["NWBFile"]["session_start_time"] == datetime(2023, 11, 9, 10, 14, 37, 0) + assert self.pose_estimation_name in metadata["Behavior"] + assert metadata["Behavior"][self.pose_estimation_name] == self.expected_metadata[self.pose_estimation_name] def check_read_nwb(self, nwbfile_path: str): with NWBHDF5IO(path=nwbfile_path, mode="r", load_namespaces=True) as io: nwbfile = io.read() - self.assertIn("behavior", nwbfile.processing) - self.assertIn(self.pose_estimation_name, nwbfile.processing["behavior"].data_interfaces) + + # Replacing assertIn with pytest-style assert + assert "behavior" in nwbfile.processing + assert self.pose_estimation_name in nwbfile.processing["behavior"].data_interfaces + pose_estimation_container = nwbfile.processing["behavior"].data_interfaces[self.pose_estimation_name] - self.assertIsInstance(pose_estimation_container, PoseEstimation) + + # Replacing assertIsInstance with pytest-style assert + assert isinstance(pose_estimation_container, PoseEstimation) pose_estimation_metadata = self.expected_metadata[self.pose_estimation_name] - self.assertEqual(pose_estimation_container.description, pose_estimation_metadata["description"]) - self.assertEqual(pose_estimation_container.scorer, pose_estimation_metadata["scorer"]) - self.assertEqual(pose_estimation_container.source_software, pose_estimation_metadata["source_software"]) + + # Replacing assertEqual with pytest-style assert + assert pose_estimation_container.description == pose_estimation_metadata["description"] + assert pose_estimation_container.scorer == pose_estimation_metadata["scorer"] + assert pose_estimation_container.source_software == pose_estimation_metadata["source_software"] + + # Using numpy's assert_array_equal assert_array_equal( pose_estimation_container.dimensions[:], [[self.original_video_height, self.original_video_width]] ) - self.assertEqual(len(pose_estimation_container.pose_estimation_series), len(self.expected_keypoint_names)) + # Replacing assertEqual with pytest-style assert + assert len(pose_estimation_container.pose_estimation_series) == len(self.expected_keypoint_names) + for keypoint_name in self.expected_keypoint_names: series_metadata = pose_estimation_metadata[keypoint_name] - self.assertIn(series_metadata["name"], pose_estimation_container.pose_estimation_series) + + # Replacing assertIn with pytest-style assert + assert series_metadata["name"] in pose_estimation_container.pose_estimation_series + pose_estimation_series = pose_estimation_container.pose_estimation_series[series_metadata["name"]] - self.assertIsInstance(pose_estimation_series, PoseEstimationSeries) - self.assertEqual(pose_estimation_series.unit, "px") - self.assertEqual(pose_estimation_series.description, series_metadata["description"]) - self.assertEqual(pose_estimation_series.reference_frame, self.conversion_options["reference_frame"]) + + # Replacing assertIsInstance with pytest-style assert + assert isinstance(pose_estimation_series, PoseEstimationSeries) + + # Replacing assertEqual with pytest-style assert + assert pose_estimation_series.unit == "px" + assert pose_estimation_series.description == series_metadata["description"] + assert pose_estimation_series.reference_frame == self.conversion_options["reference_frame"] test_data = self.test_data[keypoint_name] + + # Using numpy's assert_array_equal assert_array_equal(pose_estimation_series.data[:], test_data[["x", "y"]].values) - assert_array_equal(pose_estimation_series.confidence[:], test_data["likelihood"].values) -class TestLightningPoseDataInterfaceWithStubTest(DataInterfaceTestMixin, TemporalAlignmentMixin, unittest.TestCase): +class TestLightningPoseDataInterfaceWithStubTest(DataInterfaceTestMixin, TemporalAlignmentMixin): data_interface_cls = LightningPoseDataInterface interface_kwargs = dict( file_path=str(BEHAVIOR_DATA_PATH / "lightningpose" / "outputs/2023-11-09/10-14-37/video_preds/test_vid.csv"), @@ -145,6 +163,7 @@ class TestLightningPoseDataInterfaceWithStubTest(DataInterfaceTestMixin, Tempora BEHAVIOR_DATA_PATH / "lightningpose" / "outputs/2023-11-09/10-14-37/video_preds/test_vid.mp4" ), ) + conversion_options = dict(stub_test=True) save_directory = OUTPUT_PATH @@ -153,18 +172,16 @@ def check_read_nwb(self, nwbfile_path: str): nwbfile = io.read() pose_estimation_container = nwbfile.processing["behavior"].data_interfaces["PoseEstimation"] for pose_estimation_series in pose_estimation_container.pose_estimation_series.values(): - self.assertEqual(pose_estimation_series.data.shape[0], 10) - self.assertEqual(pose_estimation_series.confidence.shape[0], 10) + assert pose_estimation_series.data.shape[0] == 10 + assert pose_estimation_series.confidence.shape[0] == 10 -class TestFicTracDataInterface(DataInterfaceTestMixin, unittest.TestCase): +class TestFicTracDataInterface(DataInterfaceTestMixin): data_interface_cls = FicTracDataInterface - interface_kwargs = [ - dict( - file_path=str(BEHAVIOR_DATA_PATH / "FicTrac" / "sample" / "sample-20230724_113055.dat"), - configuration_file_path=str(BEHAVIOR_DATA_PATH / "FicTrac" / "sample" / "config.txt"), - ), - ] + interface_kwargs = dict( + file_path=str(BEHAVIOR_DATA_PATH / "FicTrac" / "sample" / "sample-20230724_113055.dat"), + configuration_file_path=str(BEHAVIOR_DATA_PATH / "FicTrac" / "sample" / "config.txt"), + ) save_directory = OUTPUT_PATH @@ -228,11 +245,11 @@ def check_read_nwb(self, nwbfile_path: str): # This is currently structured to assert spatial_series.timestamps[0] == 0.0 -class TestFicTracDataInterfaceWithRadius(DataInterfaceTestMixin, unittest.TestCase): +class TestFicTracDataInterfaceWithRadius(DataInterfaceTestMixin): data_interface_cls = FicTracDataInterface - interface_kwargs = [ - dict(file_path=str(BEHAVIOR_DATA_PATH / "FicTrac" / "sample" / "sample-20230724_113055.dat"), radius=1.0), - ] + interface_kwargs = dict( + file_path=str(BEHAVIOR_DATA_PATH / "FicTrac" / "sample" / "sample-20230724_113055.dat"), radius=1.0 + ) save_directory = OUTPUT_PATH @@ -296,74 +313,25 @@ def check_read_nwb(self, nwbfile_path: str): # This is currently structured to assert spatial_series.timestamps[0] == 0.0 -class TestFicTracDataInterfaceTiming(TemporalAlignmentMixin, unittest.TestCase): +class TestFicTracDataInterfaceTiming(TemporalAlignmentMixin): data_interface_cls = FicTracDataInterface - interface_kwargs = [dict(file_path=str(BEHAVIOR_DATA_PATH / "FicTrac" / "sample" / "sample-20230724_113055.dat"))] + interface_kwargs = dict(file_path=str(BEHAVIOR_DATA_PATH / "FicTrac" / "sample" / "sample-20230724_113055.dat")) save_directory = OUTPUT_PATH -class TestVideoInterface(VideoInterfaceMixin, unittest.TestCase): - data_interface_cls = VideoInterface - interface_kwargs = [ - dict(file_paths=[str(BEHAVIOR_DATA_PATH / "videos" / "CFR" / "video_avi.avi")]), - dict(file_paths=[str(BEHAVIOR_DATA_PATH / "videos" / "CFR" / "video_flv.flv")]), - dict(file_paths=[str(BEHAVIOR_DATA_PATH / "videos" / "CFR" / "video_mov.mov")]), - dict(file_paths=[str(BEHAVIOR_DATA_PATH / "videos" / "CFR" / "video_mp4.mp4")]), - dict(file_paths=[str(BEHAVIOR_DATA_PATH / "videos" / "CFR" / "video_wmv.wmv")]), - ] - save_directory = OUTPUT_PATH - - -class TestDeepLabCutInterface(DeepLabCutInterfaceMixin, unittest.TestCase): +class TestDeepLabCutInterface(DeepLabCutInterfaceMixin): data_interface_cls = DeepLabCutInterface - interface_kwargs_item = dict( + interface_kwargs = dict( file_path=str(BEHAVIOR_DATA_PATH / "DLC" / "m3v1mp4DLC_resnet50_openfieldAug20shuffle1_30000.h5"), config_file_path=str(BEHAVIOR_DATA_PATH / "DLC" / "config.yaml"), subject_name="ind1", ) - # intentional duplicate to workaround 2 tests with changes after interface construction - interface_kwargs = [ - interface_kwargs_item, # this is case=0, no custom timestamp - interface_kwargs_item, # this is case=1, with custom timestamp - ] - - # custom timestamps only for case 1 - _custom_timestamps_case_1 = np.concatenate( - (np.linspace(10, 110, 1000), np.linspace(150, 250, 1000), np.linspace(300, 400, 330)) - ) - save_directory = OUTPUT_PATH def run_custom_checks(self): - self.check_custom_timestamps(nwbfile_path=self.nwbfile_path) self.check_renaming_instance(nwbfile_path=self.nwbfile_path) - def check_custom_timestamps(self, nwbfile_path: str): - # TODO: Peel out into separate test class and replace this part with check_read_nwb - if self.case != 1: # set custom timestamps - return - - metadata = self.interface.get_metadata() - metadata["NWBFile"].update(session_start_time=datetime.now().astimezone()) - - self.interface.set_aligned_timestamps(self._custom_timestamps_case_1) - assert len(self.interface._timestamps) == 2330 - - self.interface.run_conversion(nwbfile_path=nwbfile_path, metadata=metadata, overwrite=True) - - with NWBHDF5IO(path=nwbfile_path, mode="r", load_namespaces=True) as io: - nwbfile = io.read() - assert "behavior" in nwbfile.processing - processing_module_interfaces = nwbfile.processing["behavior"].data_interfaces - assert "PoseEstimation" in processing_module_interfaces - - pose_estimation_series_in_nwb = processing_module_interfaces["PoseEstimation"].pose_estimation_series - - for pose_estimation in pose_estimation_series_in_nwb.values(): - pose_timestamps = pose_estimation.timestamps - np.testing.assert_array_equal(pose_timestamps, self._custom_timestamps_case_1) - def check_renaming_instance(self, nwbfile_path: str): custom_container_name = "TestPoseEstimation" @@ -381,7 +349,6 @@ def check_renaming_instance(self, nwbfile_path: str): assert custom_container_name in nwbfile.processing["behavior"].data_interfaces def check_read_nwb(self, nwbfile_path: str): - # TODO: move this to the upstream mixin with NWBHDF5IO(path=nwbfile_path, mode="r", load_namespaces=True) as io: nwbfile = io.read() assert "behavior" in nwbfile.processing @@ -398,7 +365,50 @@ def check_read_nwb(self, nwbfile_path: str): assert all(expected_pose_estimation_series_are_in_nwb_file) -class TestSLEAPInterface(DataInterfaceTestMixin, TemporalAlignmentMixin, unittest.TestCase): +class TestDeepLabCutInterfaceSetTimestamps(DeepLabCutInterfaceMixin): + data_interface_cls = DeepLabCutInterface + interface_kwargs = dict( + file_path=str(BEHAVIOR_DATA_PATH / "DLC" / "m3v1mp4DLC_resnet50_openfieldAug20shuffle1_30000.h5"), + config_file_path=str(BEHAVIOR_DATA_PATH / "DLC" / "config.yaml"), + subject_name="ind1", + ) + + save_directory = OUTPUT_PATH + + def run_custom_checks(self): + self.check_custom_timestamps(nwbfile_path=self.nwbfile_path) + + def check_custom_timestamps(self, nwbfile_path: str): + custom_timestamps = np.concatenate( + (np.linspace(10, 110, 1000), np.linspace(150, 250, 1000), np.linspace(300, 400, 330)) + ) + + metadata = self.interface.get_metadata() + metadata["NWBFile"].update(session_start_time=datetime.now().astimezone()) + + self.interface.set_aligned_timestamps(custom_timestamps) + assert len(self.interface._timestamps) == 2330 + + self.interface.run_conversion(nwbfile_path=nwbfile_path, metadata=metadata, overwrite=True) + + with NWBHDF5IO(path=nwbfile_path, mode="r", load_namespaces=True) as io: + nwbfile = io.read() + assert "behavior" in nwbfile.processing + processing_module_interfaces = nwbfile.processing["behavior"].data_interfaces + assert "PoseEstimation" in processing_module_interfaces + + pose_estimation_series_in_nwb = processing_module_interfaces["PoseEstimation"].pose_estimation_series + + for pose_estimation in pose_estimation_series_in_nwb.values(): + pose_timestamps = pose_estimation.timestamps + np.testing.assert_array_equal(pose_timestamps, custom_timestamps) + + # This was tested in the other test + def check_read_nwb(self, nwbfile_path: str): + pass + + +class TestSLEAPInterface(DataInterfaceTestMixin, TemporalAlignmentMixin): data_interface_cls = SLEAPInterface interface_kwargs = dict( file_path=str(BEHAVIOR_DATA_PATH / "sleap" / "predictions_1.2.7_provenance_and_tracking.slp"), @@ -429,16 +439,18 @@ def check_read_nwb(self, nwbfile_path: str): # This is currently structured to "wingL", "wingR", ] - self.assertCountEqual(first=pose_estimation_series_in_nwb, second=expected_pose_estimation_series) + + assert set(pose_estimation_series_in_nwb) == set(expected_pose_estimation_series) -class TestMiniscopeInterface(DataInterfaceTestMixin, unittest.TestCase): +class TestMiniscopeInterface(DataInterfaceTestMixin): data_interface_cls = MiniscopeBehaviorInterface interface_kwargs = dict(folder_path=str(OPHYS_DATA_PATH / "imaging_datasets" / "Miniscope" / "C6-J588_Disc5")) save_directory = OUTPUT_PATH - @classmethod - def setUpClass(cls) -> None: + @pytest.fixture(scope="class", autouse=True) + def setup_metadata(self, request): + cls = request.cls folder_path = Path(OPHYS_DATA_PATH / "imaging_datasets" / "Miniscope" / "C6-J588_Disc5") cls.device_name = "BehavCam2" cls.image_series_name = "BehavCamImageSeries" @@ -455,45 +467,42 @@ def setUpClass(cls) -> None: cls.timestamps = get_timestamps(folder_path=str(folder_path), file_pattern="BehavCam*/timeStamps.csv") def check_extracted_metadata(self, metadata: dict): - self.assertEqual( - metadata["NWBFile"]["session_start_time"], - datetime(2021, 10, 7, 15, 3, 28, 635), - ) - self.assertEqual(metadata["Behavior"]["Device"][0], self.device_metadata) + assert metadata["NWBFile"]["session_start_time"] == datetime(2021, 10, 7, 15, 3, 28, 635) + assert metadata["Behavior"]["Device"][0] == self.device_metadata image_series_metadata = metadata["Behavior"]["ImageSeries"][0] - self.assertEqual(image_series_metadata["name"], self.image_series_name) - self.assertEqual(image_series_metadata["device"], self.device_name) - self.assertEqual(image_series_metadata["unit"], "px") - self.assertEqual(image_series_metadata["dimension"], [1280, 720]) # width x height + assert image_series_metadata["name"] == self.image_series_name + assert image_series_metadata["device"] == self.device_name + assert image_series_metadata["unit"] == "px" + assert image_series_metadata["dimension"] == [1280, 720] # width x height def check_read_nwb(self, nwbfile_path: str): with NWBHDF5IO(nwbfile_path, "r") as io: nwbfile = io.read() # Check device metadata - self.assertIn(self.device_name, nwbfile.devices) + assert self.device_name in nwbfile.devices device = nwbfile.devices[self.device_name] - self.assertIsInstance(device, Miniscope) - self.assertEqual(device.compression, self.device_metadata["compression"]) - self.assertEqual(device.deviceType, self.device_metadata["deviceType"]) - self.assertEqual(device.framesPerFile, self.device_metadata["framesPerFile"]) + assert isinstance(device, Miniscope) + assert device.compression == self.device_metadata["compression"] + assert device.deviceType == self.device_metadata["deviceType"] + assert device.framesPerFile == self.device_metadata["framesPerFile"] roi = [self.device_metadata["ROI"]["height"], self.device_metadata["ROI"]["width"]] assert_array_equal(device.ROI[:], roi) # Check ImageSeries - self.assertIn(self.image_series_name, nwbfile.acquisition) + assert self.image_series_name in nwbfile.acquisition image_series = nwbfile.acquisition[self.image_series_name] - self.assertEqual(image_series.format, "external") + assert image_series.format == "external" assert_array_equal(image_series.starting_frame, self.starting_frames) assert_array_equal(image_series.dimension[:], [1280, 720]) - self.assertEqual(image_series.unit, "px") - self.assertEqual(device, nwbfile.acquisition[self.image_series_name].device) + assert image_series.unit == "px" + assert device == nwbfile.acquisition[self.image_series_name].device assert_array_equal(image_series.timestamps[:], self.timestamps) assert_array_equal(image_series.external_file[:], self.external_files) -class TestNeuralynxNvtInterface(DataInterfaceTestMixin, TemporalAlignmentMixin, unittest.TestCase): +class TestNeuralynxNvtInterface(DataInterfaceTestMixin, TemporalAlignmentMixin): data_interface_cls = NeuralynxNvtInterface interface_kwargs = dict(file_path=str(BEHAVIOR_DATA_PATH / "neuralynx" / "test.nvt")) conversion_options = dict(add_angle=True) @@ -628,6 +637,30 @@ def test_sleap_interface_timestamps_propagation(self, data_interface, interface_ assert set(extracted_timestamps).issubset(expected_timestamps) +class TestVideoInterface(VideoInterfaceMixin): + data_interface_cls = VideoInterface + save_directory = OUTPUT_PATH + + @pytest.fixture( + params=[ + (dict(file_paths=[str(BEHAVIOR_DATA_PATH / "videos" / "CFR" / "video_avi.avi")])), + (dict(file_paths=[str(BEHAVIOR_DATA_PATH / "videos" / "CFR" / "video_flv.flv")])), + (dict(file_paths=[str(BEHAVIOR_DATA_PATH / "videos" / "CFR" / "video_mov.mov")])), + (dict(file_paths=[str(BEHAVIOR_DATA_PATH / "videos" / "CFR" / "video_mp4.mp4")])), + (dict(file_paths=[str(BEHAVIOR_DATA_PATH / "videos" / "CFR" / "video_wmv.wmv")])), + ], + ids=["avi", "flv", "mov", "mp4", "wmv"], + ) + def setup_interface(self, request): + + test_id = request.node.callspec.id + self.test_name = test_id + self.interface_kwargs = request.param + self.interface = self.data_interface_cls(**self.interface_kwargs) + + return self.interface, self.test_name + + class TestVideoConversions(TestCase): @classmethod def setUpClass(cls): diff --git a/tests/test_on_data/test_format_converters/test_miniscope_converter.py b/tests/test_on_data/test_format_converters/test_miniscope_converter.py index 813008455..a1e02ac1d 100644 --- a/tests/test_on_data/test_format_converters/test_miniscope_converter.py +++ b/tests/test_on_data/test_format_converters/test_miniscope_converter.py @@ -56,19 +56,9 @@ def tearDownClass(cls) -> None: def test_converter_metadata(self): metadata = self.converter.get_metadata() - self.assertEqual( - metadata["NWBFile"]["session_start_time"], - datetime(2021, 10, 7, 15, 3, 28, 635), - ) - self.assertDictEqual( - metadata["Ophys"]["Device"][0], - self.device_metadata, - ) - - self.assertDictEqual( - metadata["Behavior"]["Device"][0], - self.behavcam_metadata, - ) + assert metadata["NWBFile"]["session_start_time"] == datetime(2021, 10, 7, 15, 3, 28, 635) + assert metadata["Ophys"]["Device"][0] == self.device_metadata + assert metadata["Behavior"]["Device"][0] == self.behavcam_metadata def test_run_conversion(self): nwbfile_path = str(self.test_dir / "test_miniscope_converter.nwb") diff --git a/tests/test_on_data/test_imaging_interfaces.py b/tests/test_on_data/test_imaging_interfaces.py index 28f3d43ac..1a5328e52 100644 --- a/tests/test_on_data/test_imaging_interfaces.py +++ b/tests/test_on_data/test_imaging_interfaces.py @@ -1,9 +1,10 @@ import platform from datetime import datetime from pathlib import Path -from unittest import TestCase, skipIf +from unittest import skipIf import numpy as np +import pytest from dateutil.tz import tzoffset from hdmf.testing import TestCase as hdmf_TestCase from numpy.testing import assert_array_equal @@ -40,7 +41,7 @@ from setup_paths import OPHYS_DATA_PATH, OUTPUT_PATH -class TestTiffImagingInterface(ImagingExtractorInterfaceTestMixin, TestCase): +class TestTiffImagingInterface(ImagingExtractorInterfaceTestMixin): data_interface_cls = TiffImagingInterface interface_kwargs = dict( file_path=str(OPHYS_DATA_PATH / "imaging_datasets" / "Tif" / "demoMovie.tif"), @@ -69,7 +70,7 @@ class TestTiffImagingInterface(ImagingExtractorInterfaceTestMixin, TestCase): }, ], ) -class TestScanImageImagingInterfaceMultiPlaneCase(ScanImageMultiPlaneImagingInterfaceMixin, TestCase): +class TestScanImageImagingInterfaceMultiPlaneCase(ScanImageMultiPlaneImagingInterfaceMixin): data_interface_cls = ScanImageImagingInterface interface_kwargs = dict( file_path=str(OPHYS_DATA_PATH / "imaging_datasets" / "ScanImage" / "scanimage_20220923_roi.tif"), @@ -127,7 +128,7 @@ def check_extracted_metadata(self, metadata: dict): }, ], ) -class TestScanImageImagingInterfaceSinglePlaneCase(ScanImageSinglePlaneImagingInterfaceMixin, TestCase): +class TestScanImageImagingInterfaceSinglePlaneCase(ScanImageSinglePlaneImagingInterfaceMixin): data_interface_cls = ScanImageImagingInterface save_directory = OUTPUT_PATH interface_kwargs = dict( @@ -203,7 +204,7 @@ def test_non_volumetric_data(self): @skipIf(platform.machine() == "arm64", "Interface not supported on arm64 architecture") -class TestScanImageLegacyImagingInterface(ImagingExtractorInterfaceTestMixin, TestCase): +class TestScanImageLegacyImagingInterface(ImagingExtractorInterfaceTestMixin): data_interface_cls = ScanImageImagingInterface interface_kwargs = dict(file_path=str(OPHYS_DATA_PATH / "imaging_datasets" / "Tif" / "sample_scanimage.tiff")) save_directory = OUTPUT_PATH @@ -238,7 +239,7 @@ def check_extracted_metadata(self, metadata: dict): }, ], ) -class TestScanImageMultiFileImagingInterfaceMultiPlaneCase(ScanImageMultiPlaneImagingInterfaceMixin, TestCase): +class TestScanImageMultiFileImagingInterfaceMultiPlaneCase(ScanImageMultiPlaneImagingInterfaceMixin): data_interface_cls = ScanImageMultiFileImagingInterface interface_kwargs = dict( folder_path=str(OPHYS_DATA_PATH / "imaging_datasets" / "ScanImage"), @@ -279,7 +280,7 @@ def check_extracted_metadata(self, metadata: dict): }, ], ) -class TestScanImageMultiFileImagingInterfaceSinglePlaneCase(ScanImageSinglePlaneImagingInterfaceMixin, TestCase): +class TestScanImageMultiFileImagingInterfaceSinglePlaneCase(ScanImageSinglePlaneImagingInterfaceMixin): data_interface_cls = ScanImageMultiFileImagingInterface interface_kwargs = dict( folder_path=str(OPHYS_DATA_PATH / "imaging_datasets" / "ScanImage"), @@ -353,22 +354,26 @@ def test_plane_name_not_specified(self): ScanImageSinglePlaneMultiFileImagingInterface(folder_path=folder_path, file_pattern=file_pattern) -class TestHdf5ImagingInterface(ImagingExtractorInterfaceTestMixin, TestCase): +class TestHdf5ImagingInterface(ImagingExtractorInterfaceTestMixin): data_interface_cls = Hdf5ImagingInterface interface_kwargs = dict(file_path=str(OPHYS_DATA_PATH / "imaging_datasets" / "hdf5" / "demoMovie.hdf5")) save_directory = OUTPUT_PATH -class TestSbxImagingInterface(ImagingExtractorInterfaceTestMixin, TestCase): +class TestSbxImagingInterfaceMat(ImagingExtractorInterfaceTestMixin): data_interface_cls = SbxImagingInterface - interface_kwargs = [ - dict(file_path=str(OPHYS_DATA_PATH / "imaging_datasets" / "Scanbox" / "sample.mat")), - dict(file_path=str(OPHYS_DATA_PATH / "imaging_datasets" / "Scanbox" / "sample.sbx")), - ] + interface_kwargs = dict(file_path=str(OPHYS_DATA_PATH / "imaging_datasets" / "Scanbox" / "sample.mat")) save_directory = OUTPUT_PATH -class TestBrukerTiffImagingInterface(ImagingExtractorInterfaceTestMixin, TestCase): +class TestSbxImagingInterfaceSBX(ImagingExtractorInterfaceTestMixin): + data_interface_cls = SbxImagingInterface + interface_kwargs = dict(file_path=str(OPHYS_DATA_PATH / "imaging_datasets" / "Scanbox" / "sample.sbx")) + + save_directory = OUTPUT_PATH + + +class TestBrukerTiffImagingInterface(ImagingExtractorInterfaceTestMixin): data_interface_cls = BrukerTiffSinglePlaneImagingInterface interface_kwargs = dict( folder_path=str( @@ -377,8 +382,11 @@ class TestBrukerTiffImagingInterface(ImagingExtractorInterfaceTestMixin, TestCas ) save_directory = OUTPUT_PATH - @classmethod - def setUpClass(cls) -> None: + @pytest.fixture(scope="class", autouse=True) + def setup_metadata(cls, request): + + cls = request.cls + cls.device_metadata = dict(name="BrukerFluorescenceMicroscope", description="Version 5.6.64.400") cls.optical_channel_metadata = dict( name="Ch2", @@ -397,7 +405,6 @@ def setUpClass(cls) -> None: grid_spacing=[1.1078125e-06, 1.1078125e-06], origin_coords=[0.0, 0.0], ) - cls.two_photon_series_metadata = dict( name="TwoPhotonSeries", description="Imaging data acquired from the Bruker Two-Photon Microscope.", @@ -407,7 +414,6 @@ def setUpClass(cls) -> None: scan_line_rate=15840.580398865815, field_of_view=[0.0005672, 0.0005672], ) - cls.ophys_metadata = dict( Device=[cls.device_metadata], ImagingPlane=[cls.imaging_plane_metadata], @@ -415,8 +421,8 @@ def setUpClass(cls) -> None: ) def check_extracted_metadata(self, metadata: dict): - self.assertEqual(metadata["NWBFile"]["session_start_time"], datetime(2023, 2, 20, 15, 58, 25)) - self.assertDictEqual(metadata["Ophys"], self.ophys_metadata) + assert metadata["NWBFile"]["session_start_time"] == datetime(2023, 2, 20, 15, 58, 25) + assert metadata["Ophys"] == self.ophys_metadata def check_read_nwb(self, nwbfile_path: str): """Check the ophys metadata made it to the NWB file""" @@ -424,29 +430,27 @@ def check_read_nwb(self, nwbfile_path: str): with NWBHDF5IO(nwbfile_path, "r") as io: nwbfile = io.read() - self.assertIn(self.device_metadata["name"], nwbfile.devices) - self.assertEqual( - nwbfile.devices[self.device_metadata["name"]].description, self.device_metadata["description"] - ) - self.assertIn(self.imaging_plane_metadata["name"], nwbfile.imaging_planes) + assert self.device_metadata["name"] in nwbfile.devices + assert nwbfile.devices[self.device_metadata["name"]].description == self.device_metadata["description"] + assert self.imaging_plane_metadata["name"] in nwbfile.imaging_planes imaging_plane = nwbfile.imaging_planes[self.imaging_plane_metadata["name"]] optical_channel = imaging_plane.optical_channel[0] - self.assertEqual(optical_channel.name, self.optical_channel_metadata["name"]) - self.assertEqual(optical_channel.description, self.optical_channel_metadata["description"]) - self.assertEqual(imaging_plane.description, self.imaging_plane_metadata["description"]) - self.assertEqual(imaging_plane.imaging_rate, self.imaging_plane_metadata["imaging_rate"]) + assert optical_channel.name == self.optical_channel_metadata["name"] + assert optical_channel.description == self.optical_channel_metadata["description"] + assert imaging_plane.description == self.imaging_plane_metadata["description"] + assert imaging_plane.imaging_rate == self.imaging_plane_metadata["imaging_rate"] assert_array_equal(imaging_plane.grid_spacing[:], self.imaging_plane_metadata["grid_spacing"]) - self.assertIn(self.two_photon_series_metadata["name"], nwbfile.acquisition) + assert self.two_photon_series_metadata["name"] in nwbfile.acquisition two_photon_series = nwbfile.acquisition[self.two_photon_series_metadata["name"]] - self.assertEqual(two_photon_series.description, self.two_photon_series_metadata["description"]) - self.assertEqual(two_photon_series.unit, self.two_photon_series_metadata["unit"]) - self.assertEqual(two_photon_series.scan_line_rate, self.two_photon_series_metadata["scan_line_rate"]) + assert two_photon_series.description == self.two_photon_series_metadata["description"] + assert two_photon_series.unit == self.two_photon_series_metadata["unit"] + assert two_photon_series.scan_line_rate == self.two_photon_series_metadata["scan_line_rate"] assert_array_equal(two_photon_series.field_of_view[:], self.two_photon_series_metadata["field_of_view"]) super().check_read_nwb(nwbfile_path=nwbfile_path) -class TestBrukerTiffImagingInterfaceDualPlaneCase(ImagingExtractorInterfaceTestMixin, TestCase): +class TestBrukerTiffImagingInterfaceDualPlaneCase(ImagingExtractorInterfaceTestMixin): data_interface_cls = BrukerTiffMultiPlaneImagingInterface interface_kwargs = dict( folder_path=str( @@ -455,8 +459,10 @@ class TestBrukerTiffImagingInterfaceDualPlaneCase(ImagingExtractorInterfaceTestM ) save_directory = OUTPUT_PATH - @classmethod - def setUpClass(cls) -> None: + @pytest.fixture(scope="class", autouse=True) + def setup_metadata(self, request): + cls = request.cls + cls.photon_series_name = "TwoPhotonSeries" cls.num_frames = 5 cls.image_shape = (512, 512, 2) @@ -501,22 +507,23 @@ def run_custom_checks(self): streams = self.data_interface_cls.get_streams( folder_path=self.interface_kwargs["folder_path"], plane_separation_type="contiguous" ) - self.assertEqual(streams, self.available_streams) + + assert streams == self.available_streams def check_extracted_metadata(self, metadata: dict): - self.assertEqual(metadata["NWBFile"]["session_start_time"], datetime(2022, 11, 3, 11, 20, 34)) - self.assertDictEqual(metadata["Ophys"], self.ophys_metadata) + assert metadata["NWBFile"]["session_start_time"] == datetime(2022, 11, 3, 11, 20, 34) + assert metadata["Ophys"] == self.ophys_metadata def check_read_nwb(self, nwbfile_path: str): with NWBHDF5IO(path=nwbfile_path) as io: nwbfile = io.read() photon_series = nwbfile.acquisition[self.photon_series_name] - self.assertEqual(photon_series.data.shape, (self.num_frames, *self.image_shape)) - assert_array_equal(photon_series.dimension[:], self.image_shape) - self.assertEqual(photon_series.rate, 20.629515014336377) + assert photon_series.data.shape == (self.num_frames, *self.image_shape) + np.testing.assert_array_equal(photon_series.dimension[:], self.image_shape) + assert photon_series.rate == 20.629515014336377 -class TestBrukerTiffImagingInterfaceDualPlaneDisjointCase(ImagingExtractorInterfaceTestMixin, TestCase): +class TestBrukerTiffImagingInterfaceDualPlaneDisjointCase(ImagingExtractorInterfaceTestMixin): data_interface_cls = BrukerTiffSinglePlaneImagingInterface interface_kwargs = dict( folder_path=str( @@ -526,8 +533,11 @@ class TestBrukerTiffImagingInterfaceDualPlaneDisjointCase(ImagingExtractorInterf ) save_directory = OUTPUT_PATH - @classmethod - def setUpClass(cls) -> None: + @pytest.fixture(scope="class", autouse=True) + def setup_metadata(cls, request): + + cls = request.cls + cls.photon_series_name = "TwoPhotonSeriesCh2000002" cls.num_frames = 5 cls.image_shape = (512, 512) @@ -570,18 +580,19 @@ def setUpClass(cls) -> None: def run_custom_checks(self): # check stream names streams = self.data_interface_cls.get_streams(folder_path=self.interface_kwargs["folder_path"]) - self.assertEqual(streams, self.available_streams) + assert streams == self.available_streams def check_extracted_metadata(self, metadata: dict): - self.assertEqual(metadata["NWBFile"]["session_start_time"], datetime(2022, 11, 3, 11, 20, 34)) - self.assertDictEqual(metadata["Ophys"], self.ophys_metadata) + assert metadata["NWBFile"]["session_start_time"] == datetime(2022, 11, 3, 11, 20, 34) + assert metadata["Ophys"] == self.ophys_metadata def check_nwbfile_temporal_alignment(self): nwbfile_path = str( - self.save_directory / f"{self.data_interface_cls.__name__}_{self.case}_test_starting_time_alignment.nwb" + self.save_directory + / f"{self.data_interface_cls.__name__}_{self.test_name}_test_starting_time_alignment.nwb" ) - interface = self.data_interface_cls(**self.test_kwargs) + interface = self.data_interface_cls(**self.interface_kwargs) aligned_starting_time = 1.23 interface.set_aligned_starting_time(aligned_starting_time=aligned_starting_time) @@ -598,12 +609,12 @@ def check_read_nwb(self, nwbfile_path: str): with NWBHDF5IO(path=nwbfile_path) as io: nwbfile = io.read() photon_series = nwbfile.acquisition[self.photon_series_name] - self.assertEqual(photon_series.data.shape, (self.num_frames, *self.image_shape)) - assert_array_equal(photon_series.dimension[:], self.image_shape) - self.assertEqual(photon_series.rate, 10.314757507168189) + assert photon_series.data.shape == (self.num_frames, *self.image_shape) + np.testing.assert_array_equal(photon_series.dimension[:], self.image_shape) + assert photon_series.rate == 10.314757507168189 -class TestBrukerTiffImagingInterfaceDualColorCase(ImagingExtractorInterfaceTestMixin, TestCase): +class TestBrukerTiffImagingInterfaceDualColorCase(ImagingExtractorInterfaceTestMixin): data_interface_cls = BrukerTiffSinglePlaneImagingInterface interface_kwargs = dict( folder_path=str( @@ -613,8 +624,10 @@ class TestBrukerTiffImagingInterfaceDualColorCase(ImagingExtractorInterfaceTestM ) save_directory = OUTPUT_PATH - @classmethod - def setUpClass(cls) -> None: + @pytest.fixture(scope="class", autouse=True) + def setup_metadata(cls, request): + + cls = request.cls cls.photon_series_name = "TwoPhotonSeriesCh2" cls.num_frames = 10 cls.image_shape = (512, 512) @@ -657,26 +670,27 @@ def setUpClass(cls) -> None: def run_custom_checks(self): # check stream names streams = self.data_interface_cls.get_streams(folder_path=self.interface_kwargs["folder_path"]) - self.assertEqual(streams, self.available_streams) + assert streams == self.available_streams def check_extracted_metadata(self, metadata: dict): - self.assertEqual(metadata["NWBFile"]["session_start_time"], datetime(2023, 7, 6, 15, 13, 58)) - self.assertDictEqual(metadata["Ophys"], self.ophys_metadata) + assert metadata["NWBFile"]["session_start_time"] == datetime(2023, 7, 6, 15, 13, 58) + assert metadata["Ophys"] == self.ophys_metadata def check_read_nwb(self, nwbfile_path: str): with NWBHDF5IO(path=nwbfile_path) as io: nwbfile = io.read() photon_series = nwbfile.acquisition[self.photon_series_name] - self.assertEqual(photon_series.data.shape, (self.num_frames, *self.image_shape)) - assert_array_equal(photon_series.dimension[:], self.image_shape) - self.assertEqual(photon_series.rate, 29.873615189896864) + assert photon_series.data.shape == (self.num_frames, *self.image_shape) + np.testing.assert_array_equal(photon_series.dimension[:], self.image_shape) + assert photon_series.rate == 29.873615189896864 def check_nwbfile_temporal_alignment(self): nwbfile_path = str( - self.save_directory / f"{self.data_interface_cls.__name__}_{self.case}_test_starting_time_alignment.nwb" + self.save_directory + / f"{self.data_interface_cls.__name__}_{self.test_name}_test_starting_time_alignment.nwb" ) - interface = self.data_interface_cls(**self.test_kwargs) + interface = self.data_interface_cls(**self.interface_kwargs) aligned_starting_time = 1.23 interface.set_aligned_starting_time(aligned_starting_time=aligned_starting_time) @@ -690,15 +704,16 @@ def check_nwbfile_temporal_alignment(self): assert nwbfile.acquisition[self.photon_series_name].starting_time == aligned_starting_time -class TestMicroManagerTiffImagingInterface(ImagingExtractorInterfaceTestMixin, TestCase): +class TestMicroManagerTiffImagingInterface(ImagingExtractorInterfaceTestMixin): data_interface_cls = MicroManagerTiffImagingInterface interface_kwargs = dict( folder_path=str(OPHYS_DATA_PATH / "imaging_datasets" / "MicroManagerTif" / "TS12_20220407_20hz_noteasy_1") ) save_directory = OUTPUT_PATH - @classmethod - def setUpClass(cls) -> None: + @pytest.fixture(scope="class", autouse=True) + def setup_metadata(self, request): + cls = request.cls cls.device_metadata = dict(name="Microscope") cls.optical_channel_metadata = dict( name="OpticalChannelDefault", @@ -731,45 +746,47 @@ def setUpClass(cls) -> None: ) def check_extracted_metadata(self, metadata: dict): - self.assertEqual( - metadata["NWBFile"]["session_start_time"], - datetime(2022, 4, 7, 15, 6, 56, 842000, tzinfo=tzoffset(None, -18000)), + + assert metadata["NWBFile"]["session_start_time"] == datetime( + 2022, 4, 7, 15, 6, 56, 842000, tzinfo=tzoffset(None, -18000) ) - self.assertDictEqual(metadata["Ophys"], self.ophys_metadata) + assert metadata["Ophys"] == self.ophys_metadata def check_read_nwb(self, nwbfile_path: str): """Check the ophys metadata made it to the NWB file""" - with NWBHDF5IO(nwbfile_path, "r") as io: + # Assuming you would create and write an NWB file here before reading it back + + with NWBHDF5IO(str(nwbfile_path), "r") as io: nwbfile = io.read() - self.assertIn(self.imaging_plane_metadata["name"], nwbfile.imaging_planes) + assert self.imaging_plane_metadata["name"] in nwbfile.imaging_planes imaging_plane = nwbfile.imaging_planes[self.imaging_plane_metadata["name"]] optical_channel = imaging_plane.optical_channel[0] - self.assertEqual(optical_channel.name, self.optical_channel_metadata["name"]) - self.assertEqual(optical_channel.description, self.optical_channel_metadata["description"]) - self.assertEqual(imaging_plane.description, self.imaging_plane_metadata["description"]) - self.assertEqual(imaging_plane.imaging_rate, self.imaging_plane_metadata["imaging_rate"]) - self.assertIn(self.two_photon_series_metadata["name"], nwbfile.acquisition) + assert optical_channel.name == self.optical_channel_metadata["name"] + assert optical_channel.description == self.optical_channel_metadata["description"] + assert imaging_plane.description == self.imaging_plane_metadata["description"] + assert imaging_plane.imaging_rate == self.imaging_plane_metadata["imaging_rate"] + assert self.two_photon_series_metadata["name"] in nwbfile.acquisition two_photon_series = nwbfile.acquisition[self.two_photon_series_metadata["name"]] - self.assertEqual(two_photon_series.description, self.two_photon_series_metadata["description"]) - self.assertEqual(two_photon_series.unit, self.two_photon_series_metadata["unit"]) - self.assertEqual(two_photon_series.format, self.two_photon_series_metadata["format"]) + assert two_photon_series.description == self.two_photon_series_metadata["description"] + assert two_photon_series.unit == self.two_photon_series_metadata["unit"] + assert two_photon_series.format == self.two_photon_series_metadata["format"] assert_array_equal(two_photon_series.dimension[:], self.two_photon_series_metadata["dimension"]) super().check_read_nwb(nwbfile_path=nwbfile_path) -class TestMiniscopeImagingInterface(MiniscopeImagingInterfaceMixin, hdmf_TestCase): +class TestMiniscopeImagingInterface(MiniscopeImagingInterfaceMixin): data_interface_cls = MiniscopeImagingInterface interface_kwargs = dict(folder_path=str(OPHYS_DATA_PATH / "imaging_datasets" / "Miniscope" / "C6-J588_Disc5")) save_directory = OUTPUT_PATH - @classmethod - def setUpClass(cls) -> None: + @pytest.fixture(scope="class", autouse=True) + def setup_metadata(cls, request): + cls = request.cls + cls.device_name = "Miniscope" - cls.imaging_plane_name = "ImagingPlane" - cls.photon_series_name = "OnePhotonSeries" cls.device_metadata = dict( name=cls.device_name, @@ -781,27 +798,35 @@ def setUpClass(cls) -> None: led0=47, ) - def check_extracted_metadata(self, metadata: dict): - self.assertEqual( - metadata["NWBFile"]["session_start_time"], - datetime(2021, 10, 7, 15, 3, 28, 635), + cls.imaging_plane_name = "ImagingPlane" + cls.imaging_plane_metadata = dict( + name=cls.imaging_plane_name, + device=cls.device_name, + imaging_rate=15.0, ) - self.assertEqual(metadata["Ophys"]["Device"][0], self.device_metadata) + + cls.photon_series_name = "OnePhotonSeries" + cls.photon_series_metadata = dict( + name=cls.photon_series_name, + unit="px", + ) + + def check_extracted_metadata(self, metadata: dict): + assert metadata["NWBFile"]["session_start_time"] == datetime(2021, 10, 7, 15, 3, 28, 635) + assert metadata["Ophys"]["Device"][0] == self.device_metadata + imaging_plane_metadata = metadata["Ophys"]["ImagingPlane"][0] - self.assertEqual(imaging_plane_metadata["name"], self.imaging_plane_name) - self.assertEqual(imaging_plane_metadata["device"], self.device_name) - self.assertEqual(imaging_plane_metadata["imaging_rate"], 15.0) + assert imaging_plane_metadata["name"] == self.imaging_plane_metadata["name"] + assert imaging_plane_metadata["device"] == self.imaging_plane_metadata["device"] + assert imaging_plane_metadata["imaging_rate"] == self.imaging_plane_metadata["imaging_rate"] one_photon_series_metadata = metadata["Ophys"]["OnePhotonSeries"][0] - self.assertEqual(one_photon_series_metadata["name"], self.photon_series_name) - self.assertEqual(one_photon_series_metadata["unit"], "px") - - def run_custom_checks(self): - self.check_incorrect_folder_structure_raises() + assert one_photon_series_metadata["name"] == self.photon_series_metadata["name"] + assert one_photon_series_metadata["unit"] == self.photon_series_metadata["unit"] - def check_incorrect_folder_structure_raises(self): + def test_incorrect_folder_structure_raises(self): folder_path = Path(self.interface_kwargs["folder_path"]) / "15_03_28/BehavCam_2/" - with self.assertRaisesWith( - exc_type=AssertionError, exc_msg="The main folder should contain at least one subfolder named 'Miniscope'." + with pytest.raises( + AssertionError, match="The main folder should contain at least one subfolder named 'Miniscope'." ): self.data_interface_cls(folder_path=folder_path) diff --git a/tests/test_on_data/test_recording_interfaces.py b/tests/test_on_data/test_recording_interfaces.py index 6c3a410e2..cf838cadc 100644 --- a/tests/test_on_data/test_recording_interfaces.py +++ b/tests/test_on_data/test_recording_interfaces.py @@ -2,10 +2,9 @@ from platform import python_version from sys import platform from typing import Literal -from unittest import skip, skipIf import numpy as np -from hdmf.testing import TestCase +import pytest from numpy.testing import assert_array_equal from packaging import version from pynwb import NWBHDF5IO @@ -47,7 +46,7 @@ this_python_version = version.parse(python_version()) -class TestAlphaOmegaRecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): +class TestAlphaOmegaRecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = AlphaOmegaRecordingInterface interface_kwargs = dict(folder_path=str(DATA_PATH / "alphaomega" / "mpx_map_version4")) save_directory = OUTPUT_PATH @@ -56,49 +55,73 @@ def check_extracted_metadata(self, metadata: dict): assert metadata["NWBFile"]["session_start_time"] == datetime(2021, 11, 19, 15, 23, 15) -class TestAxonRecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): +class TestAxonRecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = AxonaRecordingInterface interface_kwargs = dict(file_path=str(DATA_PATH / "axona" / "axona_raw.bin")) save_directory = OUTPUT_PATH -class TestBiocamRecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): +class TestBiocamRecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = BiocamRecordingInterface interface_kwargs = dict(file_path=str(DATA_PATH / "biocam" / "biocam_hw3.0_fw1.6.brw")) save_directory = OUTPUT_PATH -class TestBlackrockRecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): +class TestBlackrockRecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = BlackrockRecordingInterface - interface_kwargs = [ - dict(file_path=str(DATA_PATH / "blackrock" / "blackrock_2_1" / "l101210-001.ns5")), - dict(file_path=str(DATA_PATH / "blackrock" / "FileSpec2.3001.ns5")), - dict(file_path=str(DATA_PATH / "blackrock" / "blackrock_2_1" / "l101210-001.ns2")), - ] save_directory = OUTPUT_PATH + @pytest.fixture( + params=[ + dict(file_path=str(DATA_PATH / "blackrock" / "blackrock_2_1" / "l101210-001.ns5")), + dict(file_path=str(DATA_PATH / "blackrock" / "FileSpec2.3001.ns5")), + dict(file_path=str(DATA_PATH / "blackrock" / "blackrock_2_1" / "l101210-001.ns2")), + ], + ids=["blackrock_ns5_v1", "blackrock_ns5_v2", "blackrock_ns2"], + ) + def setup_interface(self, request): + test_id = request.node.callspec.id + self.test_name = test_id + self.interface_kwargs = request.param + self.interface = self.data_interface_cls(**self.interface_kwargs) + + return self.interface, self.test_name + -@skipIf( +@pytest.mark.skipif( platform == "darwin" or this_python_version > version.parse("3.9"), - reason="Interface unsupported for OSX. Interface only runs on python 3.9", + reason="Interface unsupported for OSX. Interface only runs on Python 3.9", ) -class TestSpike2RecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): +class TestSpike2RecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = Spike2RecordingInterface interface_kwargs = dict(file_path=str(DATA_PATH / "spike2" / "m365_1sec.smrx")) save_directory = OUTPUT_PATH -class TestCellExplorerRecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): +class TestCellExplorerRecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = CellExplorerRecordingInterface - interface_kwargs = [ - dict(folder_path=str(DATA_PATH / "cellexplorer" / "dataset_4" / "Peter_MS22_180629_110319_concat_stubbed")), - dict( - folder_path=str(DATA_PATH / "cellexplorer" / "dataset_4" / "Peter_MS22_180629_110319_concat_stubbed_hdf5") - ), - ] save_directory = OUTPUT_PATH - def test_add_channel_metadata_to_nwb(self): + @pytest.fixture( + params=[ + dict(folder_path=str(DATA_PATH / "cellexplorer" / "dataset_4" / "Peter_MS22_180629_110319_concat_stubbed")), + dict( + folder_path=str( + DATA_PATH / "cellexplorer" / "dataset_4" / "Peter_MS22_180629_110319_concat_stubbed_hdf5" + ) + ), + ], + ids=["matlab", "hdf5"], + ) + def setup_interface(self, request): + test_id = request.node.callspec.id + self.test_name = test_id + self.interface_kwargs = request.param + self.interface = self.data_interface_cls(**self.interface_kwargs) + + return self.interface, self.test_name + + def test_add_channel_metadata_to_nwb(self, setup_interface): channel_id = "1" expected_channel_properties_recorder = { "location": np.array([791.5, -160.0]), @@ -112,42 +135,41 @@ def test_add_channel_metadata_to_nwb(self): "group_name": "Group 5", } - interface_kwargs = self.interface_kwargs - for num, kwargs in enumerate(interface_kwargs): - with self.subTest(str(num)): - self.case = num - self.test_kwargs = kwargs - self.interface = self.data_interface_cls(**self.test_kwargs) - self.nwbfile_path = str(self.save_directory / f"{self.data_interface_cls.__name__}_{num}_channel.nwb") - - metadata = self.interface.get_metadata() - metadata["NWBFile"].update(session_start_time=datetime.now().astimezone()) - self.interface.run_conversion( - nwbfile_path=self.nwbfile_path, - overwrite=True, - metadata=metadata, - ) + self.nwbfile_path = str( + self.save_directory / f"{self.data_interface_cls.__name__}_{self.test_name}_channel.nwb" + ) + + metadata = self.interface.get_metadata() + metadata["NWBFile"].update(session_start_time=datetime.now().astimezone()) + self.interface.run_conversion( + nwbfile_path=self.nwbfile_path, + overwrite=True, + metadata=metadata, + ) + + # Test addition to recording extractor + recording_extractor = self.interface.recording_extractor + for key, expected_value in expected_channel_properties_recorder.items(): + extracted_value = recording_extractor.get_channel_property(channel_id=channel_id, key=key) + if key == "location": + assert np.allclose(expected_value, extracted_value) + else: + assert expected_value == extracted_value + + # Test addition to electrodes table + with NWBHDF5IO(self.nwbfile_path, "r") as io: + nwbfile = io.read() + electrode_table = nwbfile.electrodes.to_dataframe() + electrode_table_row = electrode_table.query(f"channel_name=='{channel_id}'").iloc[0] + for key, value in expected_channel_properties_electrodes.items(): + assert electrode_table_row[key] == value + - # Test addition to recording extractor - recording_extractor = self.interface.recording_extractor - for key, expected_value in expected_channel_properties_recorder.items(): - extracted_value = recording_extractor.get_channel_property(channel_id=channel_id, key=key) - if key == "location": - assert np.allclose(expected_value, extracted_value) - else: - assert expected_value == extracted_value - - # Test addition to electrodes table - with NWBHDF5IO(self.nwbfile_path, "r") as io: - nwbfile = io.read() - electrode_table = nwbfile.electrodes.to_dataframe() - electrode_table_row = electrode_table.query(f"channel_name=='{channel_id}'").iloc[0] - for key, value in expected_channel_properties_electrodes.items(): - assert electrode_table_row[key] == value - - -@skipIf(platform == "darwin", reason="Not supported for OSX.") -class TestEDFRecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): +@pytest.mark.skipif( + platform == "darwin", + reason="Interface unsupported for OSX.", +) +class TestEDFRecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = EDFRecordingInterface interface_kwargs = dict(file_path=str(DATA_PATH / "edf" / "edf+C.edf")) save_directory = OUTPUT_PATH @@ -157,23 +179,17 @@ def check_extracted_metadata(self, metadata: dict): def test_interface_alignment(self): interface_kwargs = self.interface_kwargs - if isinstance(interface_kwargs, dict): - interface_kwargs = [interface_kwargs] - for num, kwargs in enumerate(interface_kwargs): - with self.subTest(str(num)): - self.case = num - self.test_kwargs = kwargs - - # TODO - debug hanging I/O from pyedflib - # self.check_interface_get_original_timestamps() - # self.check_interface_get_timestamps() - # self.check_align_starting_time_internal() - # self.check_align_starting_time_external() - # self.check_interface_align_timestamps() - # self.check_shift_timestamps_by_start_time() - # self.check_interface_original_timestamps_inmutability() - - self.check_nwbfile_temporal_alignment() + + # TODO - debug hanging I/O from pyedflib + # self.check_interface_get_original_timestamps() + # self.check_interface_get_timestamps() + # self.check_align_starting_time_internal() + # self.check_align_starting_time_external() + # self.check_interface_align_timestamps() + # self.check_shift_timestamps_by_start_time() + # self.check_interface_original_timestamps_inmutability() + + self.check_nwbfile_temporal_alignment() # EDF has simultaneous access issues; can't have multiple interfaces open on the same file at once... def check_run_conversion_in_nwbconverter_with_backend( @@ -190,17 +206,30 @@ def check_run_conversion_with_backend(self, nwbfile_path: str, backend: Literal[ pass -class TestIntanRecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): +class TestIntanRecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = IntanRecordingInterface - interface_kwargs = [ - dict(file_path=str(DATA_PATH / "intan" / "intan_rhd_test_1.rhd")), - dict(file_path=str(DATA_PATH / "intan" / "intan_rhs_test_1.rhs")), - ] + interface_kwargs = [] save_directory = OUTPUT_PATH + @pytest.fixture( + params=[ + dict(file_path=str(DATA_PATH / "intan" / "intan_rhd_test_1.rhd")), + dict(file_path=str(DATA_PATH / "intan" / "intan_rhs_test_1.rhs")), + ], + ids=["rhd", "rhs"], + ) + def setup_interface(self, request): + + test_id = request.node.callspec.id + self.test_name = test_id + self.interface_kwargs = request.param + self.interface = self.data_interface_cls(**self.interface_kwargs) + + return self.interface, self.test_name + -@skip(reason="This interface fails to load the necessary plugin sometimes.") -class TestMaxOneRecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): +@pytest.mark.skip(reason="This interface fails to load the necessary plugin sometimes.") +class TestMaxOneRecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = MaxOneRecordingInterface interface_kwargs = dict(file_path=str(DATA_PATH / "maxwell" / "MaxOne_data" / "Record" / "000011" / "data.raw.h5")) save_directory = OUTPUT_PATH @@ -211,13 +240,13 @@ def check_extracted_metadata(self, metadata: dict): assert metadata["Ecephys"]["Device"][0]["description"] == "Recorded using Maxwell version '20190530'." -class TestMCSRawRecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): +class TestMCSRawRecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = MCSRawRecordingInterface interface_kwargs = dict(file_path=str(DATA_PATH / "rawmcs" / "raw_mcs_with_header_1.raw")) save_directory = OUTPUT_PATH -class TestMEArecRecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): +class TestMEArecRecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = MEArecRecordingInterface interface_kwargs = dict(file_path=str(DATA_PATH / "mearec" / "mearec_test_10s.h5")) save_directory = OUTPUT_PATH @@ -245,74 +274,86 @@ def check_extracted_metadata(self, metadata: dict): ) -class TestNeuralynxRecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): +class TestNeuralynxRecordingInterfaceV574: data_interface_cls = NeuralynxRecordingInterface - interface_kwargs = [ - dict(folder_path=str(DATA_PATH / "neuralynx" / "Cheetah_v5.7.4" / "original_data")), - dict(folder_path=str(DATA_PATH / "neuralynx" / "Cheetah_v5.6.3" / "original_data")), - dict(folder_path=str(DATA_PATH / "neuralynx" / "Cheetah_v5.4.0" / "original_data")), - ] + interface_kwargs = (dict(folder_path=str(DATA_PATH / "neuralynx" / "Cheetah_v5.7.4" / "original_data")),) + save_directory = OUTPUT_PATH def check_extracted_metadata(self, metadata: dict): file_metadata = metadata["NWBFile"] + assert metadata["NWBFile"]["session_start_time"] == datetime(2017, 2, 16, 17, 56, 4) + assert metadata["NWBFile"]["session_id"] == "d8ba8eef-8d11-4cdc-86dc-05f50d4ba13d" + assert '"FileType": "NCS"' in file_metadata["notes"] + assert '"recording_closed": "2017-02-16 18:01:18"' in file_metadata["notes"] + assert '"ADMaxValue": "32767"' in file_metadata["notes"] + assert '"sampling_rate": "32000.0"' in file_metadata["notes"] + assert metadata["Ecephys"]["Device"][-1] == { + "name": "AcqSystem1 DigitalLynxSX", + "description": "Cheetah 5.7.4", + } + + def check_read(self, nwbfile_path): + super().check_read(nwbfile_path) + expected_single_channel_props = { + "DSPLowCutFilterEnabled": "True", + "DspLowCutFrequency": "10", + "DspLowCutNumTaps": "0", + "DspLowCutFilterType": "DCO", + "DSPHighCutFilterEnabled": "True", + "DspHighCutFrequency": "9000", + "DspHighCutNumTaps": "64", + "DspHighCutFilterType": "FIR", + "DspDelayCompensation": "Enabled", + } - if self.case == 0: - assert metadata["NWBFile"]["session_start_time"] == datetime(2017, 2, 16, 17, 56, 4) - assert metadata["NWBFile"]["session_id"] == "d8ba8eef-8d11-4cdc-86dc-05f50d4ba13d" - assert '"FileType": "NCS"' in file_metadata["notes"] - assert '"recording_closed": "2017-02-16 18:01:18"' in file_metadata["notes"] - assert '"ADMaxValue": "32767"' in file_metadata["notes"] - assert '"sampling_rate": "32000.0"' in file_metadata["notes"] - assert metadata["Ecephys"]["Device"][-1] == { - "name": "AcqSystem1 DigitalLynxSX", - "description": "Cheetah 5.7.4", - } - - elif self.case == 1: - assert file_metadata["session_start_time"] == datetime(2016, 11, 28, 21, 50, 33, 322000) - # Metadata extracted directly from file header (neo >= 0.11) - assert '"FileType": "CSC"' in file_metadata["notes"] - assert '"recording_closed": "2016-11-28 22:44:41.145000"' in file_metadata["notes"] - assert '"ADMaxValue": "32767"' in file_metadata["notes"] - assert '"sampling_rate": "2000.0"' in file_metadata["notes"] - assert metadata["Ecephys"]["Device"][-1] == {"name": "DigitalLynxSX", "description": "Cheetah 5.6.3"} - - elif self.case == 2: - assert file_metadata["session_start_time"] == datetime(2001, 1, 1, 0, 0) - assert '"recording_closed": "2001-01-01 00:00:00"' in file_metadata["notes"] - assert '"ADMaxValue": "32767"' in file_metadata["notes"] - assert '"sampling_rate": "1017.375"' in file_metadata["notes"] - assert metadata["Ecephys"]["Device"][-1] == {"name": "DigitalLynx", "description": "Cheetah 5.4.0"} + n_channels = self.interface.recording_extractor.get_num_channels() + + for key, exp_value in expected_single_channel_props.items(): + extracted_value = self.interface.recording_extractor.get_property(key) + assert len(extracted_value) == n_channels + assert exp_value == extracted_value[0] + + +class TestNeuralynxRecordingInterfaceV563: + data_interface_cls = NeuralynxRecordingInterface + interface_kwargs = (dict(folder_path=str(DATA_PATH / "neuralynx" / "Cheetah_v5.6.3" / "original_data")),) + + save_directory = OUTPUT_PATH + + def check_extracted_metadata(self, metadata: dict): + file_metadata = metadata["NWBFile"] + assert file_metadata["session_start_time"] == datetime(2016, 11, 28, 21, 50, 33, 322000) + assert '"FileType": "CSC"' in file_metadata["notes"] + assert '"recording_closed": "2016-11-28 22:44:41.145000"' in file_metadata["notes"] + assert '"ADMaxValue": "32767"' in file_metadata["notes"] + assert '"sampling_rate": "2000.0"' in file_metadata["notes"] + assert metadata["Ecephys"]["Device"][-1] == {"name": "DigitalLynxSX", "description": "Cheetah 5.6.3"} + + def check_read(self, nwbfile_path): + super().check_read(nwbfile_path) + # Add any specific checks for Cheetah_v5.6.3 if needed + + +class TestNeuralynxRecordingInterfaceV540: + data_interface_cls = NeuralynxRecordingInterface + interface_kwargs = (dict(folder_path=str(DATA_PATH / "neuralynx" / "Cheetah_v5.4.0" / "original_data")),) + save_directory = OUTPUT_PATH + + def check_extracted_metadata(self, metadata: dict): + file_metadata = metadata["NWBFile"] + assert file_metadata["session_start_time"] == datetime(2001, 1, 1, 0, 0) + assert '"recording_closed": "2001-01-01 00:00:00"' in file_metadata["notes"] + assert '"ADMaxValue": "32767"' in file_metadata["notes"] + assert '"sampling_rate": "1017.375"' in file_metadata["notes"] + assert metadata["Ecephys"]["Device"][-1] == {"name": "DigitalLynx", "description": "Cheetah 5.4.0"} def check_read(self, nwbfile_path): super().check_read(nwbfile_path) - if self.case == 0: - expected_single_channel_props = { - "DSPLowCutFilterEnabled": "True", - "DspLowCutFrequency": "10", - "DspLowCutNumTaps": "0", - "DspLowCutFilterType": "DCO", - "DSPHighCutFilterEnabled": "True", - "DspHighCutFrequency": "9000", - "DspHighCutNumTaps": "64", - "DspHighCutFilterType": "FIR", - "DspDelayCompensation": "Enabled", - # don't check for filter delay as the unit might be differently parsed - # "DspFilterDelay_µs": "984" - } - - n_channels = self.interface.recording_extractor.get_num_channels() - - for key, exp_value in expected_single_channel_props.items(): - extracted_value = self.interface.recording_extractor.get_property(key) - # check consistency of number of entries - assert len(extracted_value) == n_channels - # check values for first channel - assert exp_value == extracted_value[0] - - -class TestMultiStreamNeuralynxRecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): + # Add any specific checks for Cheetah_v5.4.0 if need + + +class TestMultiStreamNeuralynxRecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = NeuralynxRecordingInterface interface_kwargs = dict( folder_path=str(DATA_PATH / "neuralynx" / "Cheetah_v6.4.1dev" / "original_data"), @@ -334,57 +375,44 @@ def check_extracted_metadata(self, metadata: dict): } -class TestNeuroScopeRecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): +class TestNeuroScopeRecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = NeuroScopeRecordingInterface interface_kwargs = dict(file_path=str(DATA_PATH / "neuroscope" / "test1" / "test1.dat")) save_directory = OUTPUT_PATH -class TestOpenEphysBinaryRecordingInterfaceClassMethodsAndAssertions(RecordingExtractorInterfaceTestMixin, TestCase): +class TestOpenEphysBinaryRecordingInterfaceClassMethodsAndAssertions: + data_interface_cls = OpenEphysBinaryRecordingInterface - interface_kwargs = [] - save_directory = OUTPUT_PATH def test_get_stream_names(self): - self.assertCountEqual( - first=self.data_interface_cls.get_stream_names( - folder_path=str(DATA_PATH / "openephysbinary" / "v0.5.3_two_neuropixels_stream") - ), - second=["Record_Node_107#Neuropix-PXI-116.0", "Record_Node_107#Neuropix-PXI-116.1"], + + stream_names = self.data_interface_cls.get_stream_names( + folder_path=str(DATA_PATH / "openephysbinary" / "v0.5.3_two_neuropixels_stream" / "Record_Node_107") ) + assert stream_names == ["Record_Node_107#Neuropix-PXI-116.0", "Record_Node_107#Neuropix-PXI-116.1"] + def test_folder_structure_assertion(self): - with self.assertRaisesWith( - exc_type=ValueError, - exc_msg=( - "Unable to identify the OpenEphys folder structure! Please check that your `folder_path` contains a " - "settings.xml file and sub-folders of the following form: 'experiment' -> 'recording' ->" - " 'continuous'." - ), + with pytest.raises( + ValueError, + match=r"Unable to identify the OpenEphys folder structure! Please check that your `folder_path` contains a settings.xml file and sub-folders of the following form: 'experiment' -> 'recording' -> 'continuous'.", ): OpenEphysBinaryRecordingInterface(folder_path=str(DATA_PATH / "openephysbinary")) def test_stream_name_missing_assertion(self): - with self.assertRaisesWith( - exc_type=ValueError, - exc_msg=( - "More than one stream is detected! " - "Please specify which stream you wish to load with the `stream_name` argument. " - "To see what streams are available, call " - " `OpenEphysRecordingInterface.get_stream_names(folder_path=...)`." - ), + with pytest.raises( + ValueError, + match=r"More than one stream is detected! Please specify which stream you wish to load with the `stream_name` argument. To see what streams are available, call\s+`OpenEphysRecordingInterface.get_stream_names\(folder_path=\.\.\.\)`.", ): OpenEphysBinaryRecordingInterface( folder_path=str(DATA_PATH / "openephysbinary" / "v0.5.3_two_neuropixels_stream" / "Record_Node_107") ) def test_stream_name_not_available_assertion(self): - with self.assertRaisesWith( - exc_type=ValueError, - exc_msg=( - "The selected stream 'not_a_stream' is not in the available streams " - "'['Record_Node_107#Neuropix-PXI-116.0', 'Record_Node_107#Neuropix-PXI-116.1']'!" - ), + with pytest.raises( + ValueError, + match=r"The selected stream 'not_a_stream' is not in the available streams '\['Record_Node_107#Neuropix-PXI-116.0', 'Record_Node_107#Neuropix-PXI-116.1'\]'!", ): OpenEphysBinaryRecordingInterface( folder_path=str(DATA_PATH / "openephysbinary" / "v0.5.3_two_neuropixels_stream" / "Record_Node_107"), @@ -392,7 +420,7 @@ def test_stream_name_not_available_assertion(self): ) -class TestOpenEphysBinaryRecordingInterfaceVersion0_4_4(RecordingExtractorInterfaceTestMixin, TestCase): +class TestOpenEphysBinaryRecordingInterfaceVersion0_4_4(RecordingExtractorInterfaceTestMixin): data_interface_cls = OpenEphysBinaryRecordingInterface interface_kwargs = dict(folder_path=str(DATA_PATH / "openephysbinary" / "v0.4.4.1_with_video_tracking")) save_directory = OUTPUT_PATH @@ -401,7 +429,7 @@ def check_extracted_metadata(self, metadata: dict): assert metadata["NWBFile"]["session_start_time"] == datetime(2021, 2, 15, 17, 20, 4) -class TestOpenEphysBinaryRecordingInterfaceVersion0_5_3_Stream1(RecordingExtractorInterfaceTestMixin, TestCase): +class TestOpenEphysBinaryRecordingInterfaceVersion0_5_3_Stream1(RecordingExtractorInterfaceTestMixin): data_interface_cls = OpenEphysBinaryRecordingInterface interface_kwargs = dict( folder_path=str(DATA_PATH / "openephysbinary" / "v0.5.3_two_neuropixels_stream" / "Record_Node_107"), @@ -413,7 +441,7 @@ def check_extracted_metadata(self, metadata: dict): assert metadata["NWBFile"]["session_start_time"] == datetime(2020, 11, 24, 15, 46, 56) -class TestOpenEphysBinaryRecordingInterfaceVersion0_5_3_Stream2(RecordingExtractorInterfaceTestMixin, TestCase): +class TestOpenEphysBinaryRecordingInterfaceVersion0_5_3_Stream2(RecordingExtractorInterfaceTestMixin): data_interface_cls = OpenEphysBinaryRecordingInterface interface_kwargs = dict( folder_path=str(DATA_PATH / "openephysbinary" / "v0.5.3_two_neuropixels_stream" / "Record_Node_107"), @@ -426,7 +454,7 @@ def check_extracted_metadata(self, metadata: dict): class TestOpenEphysBinaryRecordingInterfaceWithBlocks_version_0_6_block_1_stream_1( - RecordingExtractorInterfaceTestMixin, TestCase + RecordingExtractorInterfaceTestMixin ): """From Issue #695, exposed `block_index` argument and added tests on data that include multiple blocks.""" @@ -442,7 +470,7 @@ def check_extracted_metadata(self, metadata: dict): assert metadata["NWBFile"]["session_start_time"] == datetime(2022, 5, 3, 10, 52, 24) -class TestOpenEphysLegacyRecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): +class TestOpenEphysLegacyRecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = OpenEphysLegacyRecordingInterface interface_kwargs = dict(folder_path=str(DATA_PATH / "openephys" / "OpenEphys_SampleData_1")) save_directory = OUTPUT_PATH @@ -451,37 +479,84 @@ def check_extracted_metadata(self, metadata: dict): assert metadata["NWBFile"]["session_start_time"] == datetime(2018, 10, 3, 13, 16, 50) -class TestOpenEphysRecordingInterfaceRouter(RecordingExtractorInterfaceTestMixin, TestCase): +class TestOpenEphysRecordingInterfaceRouter(RecordingExtractorInterfaceTestMixin): data_interface_cls = OpenEphysRecordingInterface - interface_kwargs = [ - dict(folder_path=str(DATA_PATH / "openephys" / "OpenEphys_SampleData_1")), - dict(folder_path=str(DATA_PATH / "openephysbinary" / "v0.4.4.1_with_video_tracking")), - dict( - folder_path=str(DATA_PATH / "openephysbinary" / "v0.5.3_two_neuropixels_stream" / "Record_Node_107"), - stream_name="Record_Node_107#Neuropix-PXI-116.0", - ), - dict( - folder_path=str(DATA_PATH / "openephysbinary" / "v0.5.3_two_neuropixels_stream" / "Record_Node_107"), - stream_name="Record_Node_107#Neuropix-PXI-116.1", - ), - ] save_directory = OUTPUT_PATH + @pytest.fixture( + params=[ + dict(folder_path=str(DATA_PATH / "openephys" / "OpenEphys_SampleData_1")), + dict(folder_path=str(DATA_PATH / "openephysbinary" / "v0.4.4.1_with_video_tracking")), + dict( + folder_path=str(DATA_PATH / "openephysbinary" / "v0.5.3_two_neuropixels_stream" / "Record_Node_107"), + stream_name="Record_Node_107#Neuropix-PXI-116.0", + ), + dict( + folder_path=str(DATA_PATH / "openephysbinary" / "v0.5.3_two_neuropixels_stream" / "Record_Node_107"), + stream_name="Record_Node_107#Neuropix-PXI-116.1", + ), + ], + ids=[ + "OpenEphys_SampleData_1", + "v0.4.4.1_with_video_tracking", + "Record_Node_107_Neuropix-PXI-116.0", + "Record_Node_107_Neuropix-PXI-116.1", + ], + ) + def setup_interface(self, request): + test_id = request.node.callspec.id + self.test_name = test_id + self.interface_kwargs = request.param + self.interface = self.data_interface_cls(**self.interface_kwargs) + + return self.interface, self.test_name + + def test_interface_extracted_metadata(self, setup_interface): + interface, test_name = setup_interface + metadata = interface.get_metadata() + assert "NWBFile" in metadata # Example assertion + # Additional assertions specific to the metadata can be added here + -class TestSpikeGadgetsRecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): +class TestSpikeGadgetsRecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = SpikeGadgetsRecordingInterface - interface_kwargs = [ - dict(file_path=str(DATA_PATH / "spikegadgets" / "20210225_em8_minirec2_ac.rec")), - dict(file_path=str(DATA_PATH / "spikegadgets" / "20210225_em8_minirec2_ac.rec"), gains=[0.195]), - dict(file_path=str(DATA_PATH / "spikegadgets" / "20210225_em8_minirec2_ac.rec"), gains=[0.385] * 512), - dict(file_path=str(DATA_PATH / "spikegadgets" / "W122_06_09_2019_1_fromSD.rec")), - dict(file_path=str(DATA_PATH / "spikegadgets" / "W122_06_09_2019_1_fromSD.rec"), gains=[0.195]), - dict(file_path=str(DATA_PATH / "spikegadgets" / "W122_06_09_2019_1_fromSD.rec"), gains=[0.385] * 128), - ] save_directory = OUTPUT_PATH + @pytest.fixture( + params=[ + dict(file_path=str(DATA_PATH / "spikegadgets" / "20210225_em8_minirec2_ac.rec")), + dict(file_path=str(DATA_PATH / "spikegadgets" / "20210225_em8_minirec2_ac.rec"), gains=[0.195]), + dict(file_path=str(DATA_PATH / "spikegadgets" / "20210225_em8_minirec2_ac.rec"), gains=[0.385] * 512), + dict(file_path=str(DATA_PATH / "spikegadgets" / "W122_06_09_2019_1_fromSD.rec")), + dict(file_path=str(DATA_PATH / "spikegadgets" / "W122_06_09_2019_1_fromSD.rec"), gains=[0.195]), + dict(file_path=str(DATA_PATH / "spikegadgets" / "W122_06_09_2019_1_fromSD.rec"), gains=[0.385] * 128), + ], + ids=[ + "20210225_em8_minirec2_ac_default_gains", + "20210225_em8_minirec2_ac_gains_0.195", + "20210225_em8_minirec2_ac_gains_0.385x512", + "W122_06_09_2019_1_fromSD_default_gains", + "W122_06_09_2019_1_fromSD_gains_0.195", + "W122_06_09_2019_1_fromSD_gains_0.385x128", + ], + ) + def setup_interface(self, request): + test_id = request.node.callspec.id + self.test_name = test_id + self.interface_kwargs = request.param + self.interface = self.data_interface_cls(**self.interface_kwargs) + + return self.interface, self.test_name + + def test_extracted_metadata(self, setup_interface): + interface, test_name = setup_interface + metadata = interface.get_metadata() + # Example assertion + assert "NWBFile" in metadata + # Additional assertions specific to the metadata can be added here + -class TestSpikeGLXRecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): +class TestSpikeGLXRecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = SpikeGLXRecordingInterface interface_kwargs = dict( file_path=str(DATA_PATH / "spikeglx" / "Noise4Sam_g0" / "Noise4Sam_g0_imec0" / "Noise4Sam_g0_t0.imec0.ap.bin") @@ -502,7 +577,7 @@ def check_extracted_metadata(self, metadata: dict): ) -class TestSpikeGLXRecordingInterfaceLongNHP(RecordingExtractorInterfaceTestMixin, TestCase): +class TestSpikeGLXRecordingInterfaceLongNHP(RecordingExtractorInterfaceTestMixin): data_interface_cls = SpikeGLXRecordingInterface interface_kwargs = dict( file_path=str( @@ -530,7 +605,7 @@ def check_extracted_metadata(self, metadata: dict): ) -class TestTdtRecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): +class TestTdtRecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = TdtRecordingInterface test_gain_value = 0.195 # arbitrary value to test gain interface_kwargs = dict(folder_path=str(DATA_PATH / "tdt" / "aep_05"), gain=test_gain_value) @@ -555,7 +630,7 @@ def check_read_nwb(self, nwbfile_path: str): return super().check_read_nwb(nwbfile_path=nwbfile_path) -class TestPlexonRecordingInterface(RecordingExtractorInterfaceTestMixin, TestCase): +class TestPlexonRecordingInterface(RecordingExtractorInterfaceTestMixin): data_interface_cls = PlexonRecordingInterface interface_kwargs = dict( # Only File_plexon_3.plx has an ecephys recording stream diff --git a/tests/test_on_data/test_segmentation_interfaces.py b/tests/test_on_data/test_segmentation_interfaces.py index 4caf9c48e..3d2547df8 100644 --- a/tests/test_on_data/test_segmentation_interfaces.py +++ b/tests/test_on_data/test_segmentation_interfaces.py @@ -1,6 +1,4 @@ -from unittest import TestCase - -from parameterized import parameterized_class +import pytest from neuroconv.datainterfaces import ( CaimanSegmentationInterface, @@ -18,26 +16,44 @@ from setup_paths import OPHYS_DATA_PATH, OUTPUT_PATH -@parameterized_class( - [ - {"conversion_options": {"mask_type": "image", "include_background_segmentation": True}}, - {"conversion_options": {"mask_type": "pixel", "include_background_segmentation": True}}, - {"conversion_options": {"mask_type": "voxel", "include_background_segmentation": True}}, - # {"conversion_options": {"mask_type": None, "include_background_segmentation": True}}, # Uncomment when https://github.com/catalystneuro/neuroconv/issues/530 is resolved - {"conversion_options": {"include_roi_centroids": False, "include_background_segmentation": True}}, - {"conversion_options": {"include_roi_acceptance": False, "include_background_segmentation": True}}, - {"conversion_options": {"include_background_segmentation": False}}, - ] -) -class TestCaimanSegmentationInterface(SegmentationExtractorInterfaceTestMixin, TestCase): +class TestCaimanSegmentationInterface(SegmentationExtractorInterfaceTestMixin): data_interface_cls = CaimanSegmentationInterface interface_kwargs = dict( file_path=str(OPHYS_DATA_PATH / "segmentation_datasets" / "caiman" / "caiman_analysis.hdf5") ) save_directory = OUTPUT_PATH + @pytest.fixture( + params=[ + {"mask_type": "image", "include_background_segmentation": True}, + {"mask_type": "pixel", "include_background_segmentation": True}, + {"mask_type": "voxel", "include_background_segmentation": True}, + # {"mask_type": None, "include_background_segmentation": True}, # Uncomment when https://github.com/catalystneuro/neuroconv/issues/530 is resolved + {"include_roi_centroids": False, "include_background_segmentation": True}, + {"include_roi_acceptance": False, "include_background_segmentation": True}, + {"include_background_segmentation": False}, + ], + ids=[ + "mask_type_image", + "mask_type_pixel", + "mask_type_voxel", + "exclude_roi_centroids", + "exclude_roi_acceptance", + "exclude_background_segmentation", + ], + ) + def setup_interface(self, request): + + test_id = request.node.callspec.id + self.test_name = test_id + self.interface_kwargs = self.interface_kwargs + self.conversion_options = request.param + self.interface = self.data_interface_cls(**self.interface_kwargs) + + return self.interface, self.test_name + -class TestCnmfeSegmentationInterface(SegmentationExtractorInterfaceTestMixin, TestCase): +class TestCnmfeSegmentationInterface(SegmentationExtractorInterfaceTestMixin): data_interface_cls = CnmfeSegmentationInterface interface_kwargs = dict( file_path=str( @@ -47,84 +63,139 @@ class TestCnmfeSegmentationInterface(SegmentationExtractorInterfaceTestMixin, Te save_directory = OUTPUT_PATH -class TestExtractSegmentationInterface(SegmentationExtractorInterfaceTestMixin, TestCase): +class TestExtractSegmentationInterface(SegmentationExtractorInterfaceTestMixin): data_interface_cls = ExtractSegmentationInterface - interface_kwargs = [ - dict( - file_path=str( - OPHYS_DATA_PATH - / "segmentation_datasets" - / "extract" - / "2014_04_01_p203_m19_check01_extractAnalysis.mat" + save_directory = OUTPUT_PATH + + @pytest.fixture( + params=[ + dict( + file_path=str( + OPHYS_DATA_PATH + / "segmentation_datasets" + / "extract" + / "2014_04_01_p203_m19_check01_extractAnalysis.mat" + ), + sampling_frequency=15.0, # typically provided by user + ), + dict( + file_path=str(OPHYS_DATA_PATH / "segmentation_datasets" / "extract" / "extract_public_output.mat"), + sampling_frequency=15.0, # typically provided by user ), - sampling_frequency=15.0, # typically provided by user - ), - dict( - file_path=str(OPHYS_DATA_PATH / "segmentation_datasets" / "extract" / "extract_public_output.mat"), - sampling_frequency=15.0, # typically provided by user - ), - ] + ], + ids=["dataset_1", "dataset_2"], + ) + def setup_interface(self, request): + test_id = request.node.callspec.id + self.test_name = test_id + self.interface_kwargs = request.param + self.interface = self.data_interface_cls(**self.interface_kwargs) + + return self.interface, self.test_name + + +def test_extract_segmentation_interface_non_default_output_struct_name(): + """Test that the value for 'output_struct_name' is propagated to the extractor level + where an error is raised.""" + file_path = OPHYS_DATA_PATH / "segmentation_datasets" / "extract" / "extract_public_output.mat" + + with pytest.raises(AssertionError, match="Output struct name 'not_output' not found in file."): + ExtractSegmentationInterface( + file_path=str(file_path), + sampling_frequency=15.0, + output_struct_name="not_output", + ) + + +class TestSuite2pSegmentationInterfaceChan1Plane0(SegmentationExtractorInterfaceTestMixin): + data_interface_cls = Suite2pSegmentationInterface save_directory = OUTPUT_PATH + interface_kwargs = dict( + folder_path=str(OPHYS_DATA_PATH / "segmentation_datasets" / "suite2p"), + channel_name="chan1", + plane_name="plane0", + ) + + @pytest.fixture(scope="class", autouse=True) + def setup_metadata(self, request): + cls = request.cls + plane_suffix = "Chan1Plane0" + cls.imaging_plane_names = "ImagingPlane" + plane_suffix + cls.plane_segmentation_names = "PlaneSegmentation" + plane_suffix + cls.mean_image_names = "MeanImage" + plane_suffix + cls.correlation_image_names = "CorrelationImage" + plane_suffix + cls.raw_traces_names = "RoiResponseSeries" + plane_suffix + cls.neuropil_traces_names = "Neuropil" + plane_suffix + cls.deconvolved_trace_name = "Deconvolved" + plane_suffix + + def test_check_extracted_metadata(self): + self.interface = self.data_interface_cls(**self.interface_kwargs) + + metadata = self.interface.get_metadata() + + assert metadata["Ophys"]["ImagingPlane"][0]["name"] == self.imaging_plane_names + plane_segmentation_metadata = metadata["Ophys"]["ImageSegmentation"]["plane_segmentations"][0] + plane_segmentation_name = self.plane_segmentation_names + assert plane_segmentation_metadata["name"] == plane_segmentation_name + summary_images_metadata = metadata["Ophys"]["SegmentationImages"][plane_segmentation_name] + assert summary_images_metadata["correlation"]["name"] == self.correlation_image_names + assert summary_images_metadata["mean"]["name"] == self.mean_image_names + + raw_traces_metadata = metadata["Ophys"]["Fluorescence"][plane_segmentation_name]["raw"] + assert raw_traces_metadata["name"] == self.raw_traces_names + neuropil_traces_metadata = metadata["Ophys"]["Fluorescence"][plane_segmentation_name]["neuropil"] + assert neuropil_traces_metadata["name"] == self.neuropil_traces_names - def test_extract_segmentation_interface_non_default_output_struct_name(self): - """Test that the value for 'output_struct_name' is propagated to the extractor level - where an error is raised.""" - file_path = OPHYS_DATA_PATH / "segmentation_datasets" / "extract" / "extract_public_output.mat" - with self.assertRaisesRegex(AssertionError, "Output struct name 'not_output' not found in file."): - ExtractSegmentationInterface( - file_path=str(file_path), - sampling_frequency=15.0, - output_struct_name="not_output", - ) + deconvolved_trace_metadata = metadata["Ophys"]["Fluorescence"][plane_segmentation_name]["deconvolved"] + assert deconvolved_trace_metadata["name"] == self.deconvolved_trace_name -class TestSuite2pSegmentationInterface(SegmentationExtractorInterfaceTestMixin, TestCase): +class TestSuite2pSegmentationInterfaceChan2Plane0(SegmentationExtractorInterfaceTestMixin): data_interface_cls = Suite2pSegmentationInterface - interface_kwargs = [ - dict( - folder_path=str(OPHYS_DATA_PATH / "segmentation_datasets" / "suite2p"), - channel_name="chan1", - plane_name="plane0", - ), - dict( - folder_path=str(OPHYS_DATA_PATH / "segmentation_datasets" / "suite2p"), - channel_name="chan2", - plane_name="plane0", - ), - ] save_directory = OUTPUT_PATH + interface_kwargs = dict( + folder_path=str(OPHYS_DATA_PATH / "segmentation_datasets" / "suite2p"), + channel_name="chan2", + plane_name="plane0", + ) - @classmethod - def setUpClass(cls) -> None: - plane_suffices = ["Chan1Plane0", "Chan2Plane0"] - cls.imaging_plane_names = ["ImagingPlane" + plane_suffix for plane_suffix in plane_suffices] - cls.plane_segmentation_names = ["PlaneSegmentation" + plane_suffix for plane_suffix in plane_suffices] - cls.mean_image_names = ["MeanImage" + plane_suffix for plane_suffix in plane_suffices] - cls.correlation_image_names = ["CorrelationImage" + plane_suffix for plane_suffix in plane_suffices] - cls.raw_traces_names = ["RoiResponseSeries" + plane_suffix for plane_suffix in plane_suffices] - cls.neuropil_traces_names = ["Neuropil" + plane_suffix for plane_suffix in plane_suffices] - cls.deconvolved_trace_name = "Deconvolved" + plane_suffices[0] - - def check_extracted_metadata(self, metadata: dict): - """Check extracted metadata is adjusted correctly for each plane and channel combination.""" - self.assertEqual(metadata["Ophys"]["ImagingPlane"][0]["name"], self.imaging_plane_names[self.case]) + @pytest.fixture(scope="class", autouse=True) + def setup_metadata(self, request): + cls = request.cls + + plane_suffix = "Chan2Plane0" + cls.imaging_plane_names = "ImagingPlane" + plane_suffix + cls.plane_segmentation_names = "PlaneSegmentation" + plane_suffix + cls.mean_image_names = "MeanImage" + plane_suffix + cls.correlation_image_names = "CorrelationImage" + plane_suffix + cls.raw_traces_names = "RoiResponseSeries" + plane_suffix + cls.neuropil_traces_names = "Neuropil" + plane_suffix + cls.deconvolved_trace_name = None + + def test_check_extracted_metadata(self): + self.interface = self.data_interface_cls(**self.interface_kwargs) + + metadata = self.interface.get_metadata() + + assert metadata["Ophys"]["ImagingPlane"][0]["name"] == self.imaging_plane_names plane_segmentation_metadata = metadata["Ophys"]["ImageSegmentation"]["plane_segmentations"][0] - plane_segmentation_name = self.plane_segmentation_names[self.case] - self.assertEqual(plane_segmentation_metadata["name"], plane_segmentation_name) + plane_segmentation_name = self.plane_segmentation_names + assert plane_segmentation_metadata["name"] == plane_segmentation_name summary_images_metadata = metadata["Ophys"]["SegmentationImages"][plane_segmentation_name] - self.assertEqual(summary_images_metadata["correlation"]["name"], self.correlation_image_names[self.case]) - self.assertEqual(summary_images_metadata["mean"]["name"], self.mean_image_names[self.case]) + assert summary_images_metadata["correlation"]["name"] == self.correlation_image_names + assert summary_images_metadata["mean"]["name"] == self.mean_image_names raw_traces_metadata = metadata["Ophys"]["Fluorescence"][plane_segmentation_name]["raw"] - self.assertEqual(raw_traces_metadata["name"], self.raw_traces_names[self.case]) + assert raw_traces_metadata["name"] == self.raw_traces_names neuropil_traces_metadata = metadata["Ophys"]["Fluorescence"][plane_segmentation_name]["neuropil"] - self.assertEqual(neuropil_traces_metadata["name"], self.neuropil_traces_names[self.case]) - if self.case == 0: + assert neuropil_traces_metadata["name"] == self.neuropil_traces_names + + if self.deconvolved_trace_name: deconvolved_trace_metadata = metadata["Ophys"]["Fluorescence"][plane_segmentation_name]["deconvolved"] - self.assertEqual(deconvolved_trace_metadata["name"], self.deconvolved_trace_name) + assert deconvolved_trace_metadata["name"] == self.deconvolved_trace_name -class TestSuite2pSegmentationInterfaceWithStubTest(SegmentationExtractorInterfaceTestMixin, TestCase): +class TestSuite2pSegmentationInterfaceWithStubTest(SegmentationExtractorInterfaceTestMixin): data_interface_cls = Suite2pSegmentationInterface interface_kwargs = dict( folder_path=str(OPHYS_DATA_PATH / "segmentation_datasets" / "suite2p"), diff --git a/tests/test_on_data/test_sorting_interfaces.py b/tests/test_on_data/test_sorting_interfaces.py index 8898d780b..dfb4ff599 100644 --- a/tests/test_on_data/test_sorting_interfaces.py +++ b/tests/test_on_data/test_sorting_interfaces.py @@ -1,5 +1,4 @@ from datetime import datetime -from unittest import TestCase import numpy as np from pynwb import NWBHDF5IO @@ -25,7 +24,7 @@ from setup_paths import OUTPUT_PATH -class TestBlackrockSortingInterface(SortingExtractorInterfaceTestMixin, TestCase): +class TestBlackrockSortingInterface(SortingExtractorInterfaceTestMixin): data_interface_cls = BlackrockSortingInterface interface_kwargs = dict(file_path=str(DATA_PATH / "blackrock" / "FileSpec2.3001.nev")) @@ -35,51 +34,81 @@ class TestBlackrockSortingInterface(SortingExtractorInterfaceTestMixin, TestCase save_directory = OUTPUT_PATH -class TestCellExplorerSortingInterfaceBuzCode(SortingExtractorInterfaceTestMixin, TestCase): +import pytest + + +class TestCellExplorerSortingInterfaceBuzCode(SortingExtractorInterfaceTestMixin): """This corresponds to the Buzsaki old CellExplorerFormat or Buzcode format.""" data_interface_cls = CellExplorerSortingInterface - interface_kwargs = [ - dict( - file_path=str( - DATA_PATH / "cellexplorer" / "dataset_1" / "20170311_684um_2088um_170311_134350.spikes.cellinfo.mat" - ) - ), - dict(file_path=str(DATA_PATH / "cellexplorer" / "dataset_2" / "20170504_396um_0um_merge.spikes.cellinfo.mat")), - dict( - file_path=str(DATA_PATH / "cellexplorer" / "dataset_3" / "20170519_864um_900um_merge.spikes.cellinfo.mat") - ), - ] save_directory = OUTPUT_PATH + @pytest.fixture( + params=[ + dict( + file_path=str( + DATA_PATH / "cellexplorer" / "dataset_1" / "20170311_684um_2088um_170311_134350.spikes.cellinfo.mat" + ) + ), + dict( + file_path=str(DATA_PATH / "cellexplorer" / "dataset_2" / "20170504_396um_0um_merge.spikes.cellinfo.mat") + ), + dict( + file_path=str( + DATA_PATH / "cellexplorer" / "dataset_3" / "20170519_864um_900um_merge.spikes.cellinfo.mat" + ) + ), + ], + ids=["dataset_1", "dataset_2", "dataset_3"], + ) + def setup_interface(self, request): + test_id = request.node.callspec.id + self.test_name = test_id + self.interface_kwargs = request.param + self.interface = self.data_interface_cls(**self.interface_kwargs) -class TestCellEploreSortingInterface(SortingExtractorInterfaceTestMixin, TestCase): + return self.interface, self.test_name + + +class TestCellExplorerSortingInterface(SortingExtractorInterfaceTestMixin): """This corresponds to the Buzsaki new CellExplorerFormat where a session.mat file with rich metadata is provided.""" data_interface_cls = CellExplorerSortingInterface - interface_kwargs = [ - dict( - file_path=str( - DATA_PATH - / "cellexplorer" - / "dataset_4" - / "Peter_MS22_180629_110319_concat_stubbed" - / "Peter_MS22_180629_110319_concat_stubbed.spikes.cellinfo.mat" - ) - ), - dict( - file_path=str( - DATA_PATH - / "cellexplorer" - / "dataset_4" - / "Peter_MS22_180629_110319_concat_stubbed_hdf5" - / "Peter_MS22_180629_110319_concat_stubbed_hdf5.spikes.cellinfo.mat" - ) - ), - ] save_directory = OUTPUT_PATH - def test_writing_channel_metadata(self): + @pytest.fixture( + params=[ + dict( + file_path=str( + DATA_PATH + / "cellexplorer" + / "dataset_4" + / "Peter_MS22_180629_110319_concat_stubbed" + / "Peter_MS22_180629_110319_concat_stubbed.spikes.cellinfo.mat" + ) + ), + dict( + file_path=str( + DATA_PATH + / "cellexplorer" + / "dataset_4" + / "Peter_MS22_180629_110319_concat_stubbed_hdf5" + / "Peter_MS22_180629_110319_concat_stubbed_hdf5.spikes.cellinfo.mat" + ) + ), + ], + ids=["mat", "hdf5"], + ) + def setup_interface(self, request): + self.test_name = request.node.callspec.id + self.interface_kwargs = request.param + self.interface = self.data_interface_cls(**self.interface_kwargs) + + return self.interface, self.test_name + + def test_writing_channel_metadata(self, setup_interface): + interface, test_name = setup_interface + channel_id = "1" expected_channel_properties_recorder = { "location": np.array([791.5, -160.0]), @@ -93,51 +122,49 @@ def test_writing_channel_metadata(self): "group_name": "Group 5", } - interface_kwargs = self.interface_kwargs - for num, kwargs in enumerate(interface_kwargs): - with self.subTest(str(num)): - self.case = num - self.test_kwargs = kwargs - self.interface = self.data_interface_cls(**self.test_kwargs) - self.nwbfile_path = str(self.save_directory / f"{self.data_interface_cls.__name__}_{num}_channel.nwb") - - metadata = self.interface.get_metadata() - metadata["NWBFile"].update(session_start_time=datetime.now().astimezone()) - self.interface.run_conversion( - nwbfile_path=self.nwbfile_path, - overwrite=True, - metadata=metadata, - write_ecephys_metadata=True, - ) + self.nwbfile_path = str(self.save_directory / f"{self.data_interface_cls.__name__}_{test_name}_channel.nwb") + + metadata = interface.get_metadata() + metadata["NWBFile"].update(session_start_time=datetime.now().astimezone()) + interface.run_conversion( + nwbfile_path=self.nwbfile_path, + overwrite=True, + metadata=metadata, + write_ecephys_metadata=True, + ) + + # Test that the registered recording has the expected channel properties + recording_extractor = interface.generate_recording_with_channel_metadata() + for key, expected_value in expected_channel_properties_recorder.items(): + extracted_value = recording_extractor.get_channel_property(channel_id=channel_id, key=key) + if key == "location": + assert np.allclose(expected_value, extracted_value) + else: + assert expected_value == extracted_value + + # Test that the electrode table has the expected values + with NWBHDF5IO(self.nwbfile_path, "r") as io: + nwbfile = io.read() + electrode_table = nwbfile.electrodes.to_dataframe() + electrode_table_row = electrode_table.query(f"channel_name=='{channel_id}'").iloc[0] + for key, value in expected_channel_properties_electrodes.items(): + assert electrode_table_row[key] == value + + +class TestNeuralynxSortingInterfaceCheetahV551(SortingExtractorInterfaceTestMixin): + data_interface_cls = NeuralynxSortingInterface + interface_kwargs = dict(folder_path=str(DATA_PATH / "neuralynx" / "Cheetah_v5.5.1" / "original_data")) + save_directory = OUTPUT_PATH - # Test that the registered recording has the `` - recording_extractor = self.interface.generate_recording_with_channel_metadata() - for key, expected_value in expected_channel_properties_recorder.items(): - extracted_value = recording_extractor.get_channel_property(channel_id=channel_id, key=key) - if key == "location": - assert np.allclose(expected_value, extracted_value) - else: - assert expected_value == extracted_value - - # Test that the electrode table has the expected values - with NWBHDF5IO(self.nwbfile_path, "r") as io: - nwbfile = io.read() - electrode_table = nwbfile.electrodes.to_dataframe() - electrode_table_row = electrode_table.query(f"channel_name=='{channel_id}'").iloc[0] - for key, value in expected_channel_properties_electrodes.items(): - assert electrode_table_row[key] == value - - -class TestNeuralynxSortingInterface(SortingExtractorInterfaceTestMixin, TestCase): + +class TestNeuralynxSortingInterfaceCheetah563(SortingExtractorInterfaceTestMixin): data_interface_cls = NeuralynxSortingInterface - interface_kwargs = [ - dict(folder_path=str(DATA_PATH / "neuralynx" / "Cheetah_v5.5.1" / "original_data")), - dict(folder_path=str(DATA_PATH / "neuralynx" / "Cheetah_v5.6.3" / "original_data")), - ] + interface_kwargs = dict(folder_path=str(DATA_PATH / "neuralynx" / "Cheetah_v5.6.3" / "original_data")) + save_directory = OUTPUT_PATH -class TestNeuroScopeSortingInterface(SortingExtractorInterfaceTestMixin, TestCase): +class TestNeuroScopeSortingInterface(SortingExtractorInterfaceTestMixin): data_interface_cls = NeuroScopeSortingInterface interface_kwargs = dict( folder_path=str(DATA_PATH / "neuroscope" / "dataset_1"), @@ -149,7 +176,7 @@ def check_extracted_metadata(self, metadata: dict): assert metadata["NWBFile"]["session_start_time"] == datetime(2015, 8, 31, 0, 0) -class TestNeuroScopeSortingInterfaceNoXMLSpecified(SortingExtractorInterfaceTestMixin, TestCase): +class TestNeuroScopeSortingInterfaceNoXMLSpecified(SortingExtractorInterfaceTestMixin): """Corresponding to issue https://github.com/NeurodataWithoutBorders/nwb-guide/issues/881.""" data_interface_cls = NeuroScopeSortingInterface @@ -161,13 +188,13 @@ def check_extracted_metadata(self, metadata: dict): pass -class TestPhySortingInterface(SortingExtractorInterfaceTestMixin, TestCase): +class TestPhySortingInterface(SortingExtractorInterfaceTestMixin): data_interface_cls = PhySortingInterface interface_kwargs = dict(folder_path=str(DATA_PATH / "phy" / "phy_example_0")) save_directory = OUTPUT_PATH -class TestPlexonSortingInterface(SortingExtractorInterfaceTestMixin, TestCase): +class TestPlexonSortingInterface(SortingExtractorInterfaceTestMixin): data_interface_cls = PlexonSortingInterface interface_kwargs = dict(file_path=str(DATA_PATH / "plexon" / "File_plexon_2.plx")) save_directory = OUTPUT_PATH