From ad9f25d55ad6cb5ea092d5abb19401ecbef1d773 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Mon, 16 Sep 2024 10:05:33 -0600
Subject: [PATCH] Add `MockSortingInterface` (#1065)

---
 CHANGELOG.md                                  |  4 +-
 src/neuroconv/tools/testing/__init__.py       |  8 ++-
 .../tools/testing/mock_interfaces.py          | 51 +++++++++++++++++++
 tests/test_ecephys/test_ecephys_interfaces.py | 33 +++++++++++-
 4 files changed, 91 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index dfa9612aa..7570045b5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,8 @@
 ## Features
 * Added chunking/compression for string-only compound objects: [PR #1042](https://github.com/catalystneuro/neuroconv/pull/1042)
 * Added automated EFS volume creation and mounting to the `submit_aws_job` helper function. [PR #1018](https://github.com/catalystneuro/neuroconv/pull/1018)
+* Added a `MockSortingInterface` for testing purposes. [PR #1065](https://github.com/catalystneuro/neuroconv/pull/1065)
+

 ## Improvements
 * Add a Zarr writing test to the on-data tests [PR #1056](https://github.com/catalystneuro/neuroconv/pull/1056)
@@ -33,8 +35,6 @@
 * Added `get_stream_names` to `OpenEphysRecordingInterface`: [PR #1039](https://github.com/catalystneuro/neuroconv/pull/1039)
 * Most data interfaces and converters now use Pydantic to validate their inputs, including existence of file and folder paths. [PR #1022](https://github.com/catalystneuro/neuroconv/pull/1022)
 * All remaining data interfaces and converters now use Pydantic to validate their inputs, including existence of file and folder paths. [PR #1055](https://github.com/catalystneuro/neuroconv/pull/1055)
-* Added a mock for segmentation extractors interfaces in ophys: `MockSegmentationInterface` [PR #1067](https://github.com/catalystneuro/neuroconv/pull/1067)
-* Added automated EFS volume creation and mounting to the `submit_aws_job` helper function. [PR #1018](https://github.com/catalystneuro/neuroconv/pull/1018)

 ### Improvements

diff --git a/src/neuroconv/tools/testing/__init__.py b/src/neuroconv/tools/testing/__init__.py
index 7179a7544..79b54d3f9 100644
--- a/src/neuroconv/tools/testing/__init__.py
+++ b/src/neuroconv/tools/testing/__init__.py
@@ -5,5 +5,11 @@
     mock_ZarrDatasetIOConfiguration,
 )
 from .mock_files import generate_path_expander_demo_ibl
-from .mock_interfaces import MockBehaviorEventInterface, MockSpikeGLXNIDQInterface
+from .mock_interfaces import (
+    MockBehaviorEventInterface,
+    MockSpikeGLXNIDQInterface,
+    MockRecordingInterface,
+    MockImagingInterface,
+    MockSortingInterface,
+)
 from .mock_ttl_signals import generate_mock_ttl_signal, regenerate_test_cases
diff --git a/src/neuroconv/tools/testing/mock_interfaces.py b/src/neuroconv/tools/testing/mock_interfaces.py
index 87f6dcf8e..43f8c2dd2 100644
--- a/src/neuroconv/tools/testing/mock_interfaces.py
+++ b/src/neuroconv/tools/testing/mock_interfaces.py
@@ -11,6 +11,9 @@
 from ...datainterfaces.ecephys.baserecordingextractorinterface import (
     BaseRecordingExtractorInterface,
 )
+from ...datainterfaces.ecephys.basesortingextractorinterface import (
+    BaseSortingExtractorInterface,
+)
 from ...datainterfaces.ophys.baseimagingextractorinterface import (
     BaseImagingExtractorInterface,
 )
@@ -160,6 +163,54 @@ def get_metadata(self) -> dict:
         return metadata


+class MockSortingInterface(BaseSortingExtractorInterface):
+    """A mock sorting extractor interface for generating synthetic sorting data."""
+
+    # TODO: Implement this class with the lazy generator once it is merged:
+    # https://github.com/SpikeInterface/spikeinterface/pull/2227
+
+    ExtractorModuleName = "spikeinterface.core.generate"
+    ExtractorName = "generate_sorting"
+
+    def __init__(
+        self,
+        num_units: int = 4,
+        sampling_frequency: float = 30_000.0,
+        durations: tuple[float] = (1.0,),
+        seed: int = 0,
+        verbose: bool = True,
+    ):
+        """
+        Parameters
+        ----------
+        num_units : int, optional
+            Number of units to generate, by default 4.
+        sampling_frequency : float, optional
+            Sampling frequency of the generated data in Hz, by default 30,000.0 Hz.
+        durations : tuple of float, optional
+            Durations of the segments in seconds, by default (1.0,).
+        seed : int, optional
+            Seed for the random number generator, by default 0.
+        verbose : bool, optional
+            Control whether to display verbose messages during writing, by default True.
+
+        """
+
+        super().__init__(
+            num_units=num_units,
+            sampling_frequency=sampling_frequency,
+            durations=durations,
+            seed=seed,
+            verbose=verbose,
+        )
+
+    def get_metadata(self) -> dict:  # noqa D102
+        metadata = super().get_metadata()
+        session_start_time = datetime.now().astimezone()
+        metadata["NWBFile"]["session_start_time"] = session_start_time
+        return metadata
+
+
 class MockImagingInterface(BaseImagingExtractorInterface):
     """
     A mock imaging interface for testing purposes.
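The new interface does not generate data itself: through `ExtractorModuleName` and `ExtractorName` it delegates to SpikeInterface's `generate_sorting` function. The snippet below is a rough sketch of the wrapped call, assuming `spikeinterface` is installed; the keyword values mirror the interface defaults above, and the exact default unit ids may differ between SpikeInterface versions.

# Sketch of roughly what MockSortingInterface builds under the hood (illustrative only).
from spikeinterface.core.generate import generate_sorting

sorting = generate_sorting(
    num_units=4,                 # same defaults as MockSortingInterface above
    sampling_frequency=30_000.0,
    durations=[1.0],             # one 1-second segment
    seed=0,                      # seeded, so the spike trains are reproducible
)
print(sorting.unit_ids)
print(sorting.get_unit_spike_train(unit_id=sorting.unit_ids[0], return_times=True))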
diff --git a/tests/test_ecephys/test_ecephys_interfaces.py b/tests/test_ecephys/test_ecephys_interfaces.py
index 6372b3535..5591bb6fb 100644
--- a/tests/test_ecephys/test_ecephys_interfaces.py
+++ b/tests/test_ecephys/test_ecephys_interfaces.py
@@ -20,7 +20,10 @@
     BaseSortingExtractorInterface,
 )
 from neuroconv.tools.nwb_helpers import get_module
-from neuroconv.tools.testing.mock_interfaces import MockRecordingInterface
+from neuroconv.tools.testing.mock_interfaces import (
+    MockRecordingInterface,
+    MockSortingInterface,
+)

 python_version = Version(get_python_version())

@@ -67,7 +70,33 @@ def test_spike2_import_assertions_3_11(self):
         Spike2RecordingInterface.get_all_channels_info(file_path="does_not_matter.smrx")


-class TestSortingInterface(unittest.TestCase):
+class TestSortingInterface:
+
+    def test_run_conversion(self, tmp_path):
+
+        nwbfile_path = Path(tmp_path) / "test_sorting.nwb"
+        num_units = 4
+        interface = MockSortingInterface(num_units=num_units, durations=(1.0,))
+        interface.sorting_extractor = interface.sorting_extractor.rename_units(new_unit_ids=["a", "b", "c", "d"])
+
+        interface.run_conversion(nwbfile_path=nwbfile_path)
+        with NWBHDF5IO(nwbfile_path, "r") as io:
+            nwbfile = io.read()
+
+            units = nwbfile.units
+            assert len(units) == num_units
+            units_df = units.to_dataframe()
+            # pynwb stores each sorting unit id in the `unit_name` column of the units table
+            for unit_id in interface.sorting_extractor.unit_ids:
+                # Look up the row that was written for this unit
+                row = units_df.query(f"unit_name == '{unit_id}'")
+                spike_times = interface.sorting_extractor.get_unit_spike_train(unit_id=unit_id, return_times=True)
+                written_spike_times = row["spike_times"].iloc[0]
+
+                np.testing.assert_array_equal(spike_times, written_spike_times)
+
+
+class TestSortingInterfaceOld(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         cls.test_dir = Path(mkdtemp())
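Beyond the test above, here is a minimal usage sketch of the new mock, assuming this patch plus `pynwb` are installed; the file path and parameter values are illustrative.

# Illustrative sketch: exercising MockSortingInterface end to end.
from pathlib import Path
from tempfile import mkdtemp

from pynwb import NWBHDF5IO

from neuroconv.tools.testing.mock_interfaces import MockSortingInterface

interface = MockSortingInterface(num_units=3, durations=(0.5,), seed=42)
metadata = interface.get_metadata()  # session_start_time is filled in by the mock

nwbfile_path = Path(mkdtemp()) / "mock_sorting.nwb"
interface.run_conversion(nwbfile_path=nwbfile_path, metadata=metadata)

with NWBHDF5IO(nwbfile_path, mode="r") as io:
    nwbfile = io.read()
    assert len(nwbfile.units) == 3  # one row per generated unit

Because the generator is seeded, the written spike times are reproducible across runs, which is what makes exact comparisons such as `np.testing.assert_array_equal` in `test_run_conversion` viable.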