Skip to content

Commit

Permalink
Add MockSortingInterface (#1065)
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin authored Sep 16, 2024
1 parent ad1e2a1 commit ad9f25d
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 5 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
## Features
* Added chunking/compression for string-only compound objects: [PR #1042](https://github.com/catalystneuro/neuroconv/pull/1042)
* Added automated EFS volume creation and mounting to the `submit_aws_job` helper function. [PR #1018](https://github.com/catalystneuro/neuroconv/pull/1018)
* Added a `MockSortingInterface` for testing purposes. [PR #1065](https://github.com/catalystneuro/neuroconv/pull/1065)


## Improvements
* Add writing to zarr test for to the test on data [PR #1056](https://github.com/catalystneuro/neuroconv/pull/1056)
Expand All @@ -33,8 +35,6 @@
* Added `get_stream_names` to `OpenEphysRecordingInterface`: [PR #1039](https://github.com/catalystneuro/neuroconv/pull/1039)
* Most data interfaces and converters now use Pydantic to validate their inputs, including existence of file and folder paths. [PR #1022](https://github.com/catalystneuro/neuroconv/pull/1022)
* All remaining data interfaces and converters now use Pydantic to validate their inputs, including existence of file and folder paths. [PR #1055](https://github.com/catalystneuro/neuroconv/pull/1055)
* Added a mock for segmentation extractors interfaces in ophys: `MockSegmentationInterface` [PR #1067](https://github.com/catalystneuro/neuroconv/pull/1067)
* Added automated EFS volume creation and mounting to the `submit_aws_job` helper function. [PR #1018](https://github.com/catalystneuro/neuroconv/pull/1018)


### Improvements
Expand Down
8 changes: 7 additions & 1 deletion src/neuroconv/tools/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,11 @@
mock_ZarrDatasetIOConfiguration,
)
from .mock_files import generate_path_expander_demo_ibl
from .mock_interfaces import MockBehaviorEventInterface, MockSpikeGLXNIDQInterface
from .mock_interfaces import (
MockBehaviorEventInterface,
MockSpikeGLXNIDQInterface,
MockRecordingInterface,
MockImagingInterface,
MockSortingInterface,
)
from .mock_ttl_signals import generate_mock_ttl_signal, regenerate_test_cases
51 changes: 51 additions & 0 deletions src/neuroconv/tools/testing/mock_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from ...datainterfaces.ecephys.baserecordingextractorinterface import (
BaseRecordingExtractorInterface,
)
from ...datainterfaces.ecephys.basesortingextractorinterface import (
BaseSortingExtractorInterface,
)
from ...datainterfaces.ophys.baseimagingextractorinterface import (
BaseImagingExtractorInterface,
)
Expand Down Expand Up @@ -160,6 +163,54 @@ def get_metadata(self) -> dict:
return metadata


class MockSortingInterface(BaseSortingExtractorInterface):
"""A mock sorting extractor interface for generating synthetic sorting data."""

# TODO: Implement this class with the lazy generator once is merged
# https://github.com/SpikeInterface/spikeinterface/pull/2227

ExtractorModuleName = "spikeinterface.core.generate"
ExtractorName = "generate_sorting"

def __init__(
self,
num_units: int = 4,
sampling_frequency: float = 30_000.0,
durations: tuple[float] = (1.0,),
seed: int = 0,
verbose: bool = True,
):
"""
Parameters
----------
num_units : int, optional
Number of units to generate, by default 4.
sampling_frequency : float, optional
Sampling frequency of the generated data in Hz, by default 30,000.0 Hz.
durations : tuple of float, optional
Durations of the segments in seconds, by default (1.0,).
seed : int, optional
Seed for the random number generator, by default 0.
verbose : bool, optional
Control whether to display verbose messages during writing, by default True.
"""

super().__init__(
num_units=num_units,
sampling_frequency=sampling_frequency,
durations=durations,
seed=seed,
verbose=verbose,
)

def get_metadata(self) -> dict: # noqa D102
metadata = super().get_metadata()
session_start_time = datetime.now().astimezone()
metadata["NWBFile"]["session_start_time"] = session_start_time
return metadata


class MockImagingInterface(BaseImagingExtractorInterface):
"""
A mock imaging interface for testing purposes.
Expand Down
33 changes: 31 additions & 2 deletions tests/test_ecephys/test_ecephys_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
BaseSortingExtractorInterface,
)
from neuroconv.tools.nwb_helpers import get_module
from neuroconv.tools.testing.mock_interfaces import MockRecordingInterface
from neuroconv.tools.testing.mock_interfaces import (
MockRecordingInterface,
MockSortingInterface,
)

python_version = Version(get_python_version())

Expand Down Expand Up @@ -67,7 +70,33 @@ def test_spike2_import_assertions_3_11(self):
Spike2RecordingInterface.get_all_channels_info(file_path="does_not_matter.smrx")


class TestSortingInterface(unittest.TestCase):
class TestSortingInterface:

def test_run_conversion(self, tmp_path):

nwbfile_path = Path(tmp_path) / "test_sorting.nwb"
num_units = 4
interface = MockSortingInterface(num_units=num_units, durations=(1.0,))
interface.sorting_extractor = interface.sorting_extractor.rename_units(new_unit_ids=["a", "b", "c", "d"])

interface.run_conversion(nwbfile_path=nwbfile_path)
with NWBHDF5IO(nwbfile_path, "r") as io:
nwbfile = io.read()

units = nwbfile.units
assert len(units) == num_units
units_df = units.to_dataframe()
# Get index in units table
for unit_id in interface.sorting_extractor.unit_ids:
# In pynwb we write unit name as unit_id
row = units_df.query(f"unit_name == '{unit_id}'")
spike_times = interface.sorting_extractor.get_unit_spike_train(unit_id=unit_id, return_times=True)
written_spike_times = row["spike_times"].iloc[0]

np.testing.assert_array_equal(spike_times, written_spike_times)


class TestSortingInterfaceOld(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.test_dir = Path(mkdtemp())
Expand Down

0 comments on commit ad9f25d

Please sign in to comment.