Skip to content

Commit

Permalink
Merge pull request #639 from magland/set-probe
Browse files Browse the repository at this point in the history
add set_probe method to BaseRecordingExtractorInterface
  • Loading branch information
CodyCBakerPhD authored Nov 28, 2023
2 parents a4b157f + 5b3a999 commit 09b91cb
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* Modify the filtering of traces to also filter out traces with empty values. [PR #649](https://github.com/catalystneuro/neuroconv/pull/649)
* Added tool function `get_default_dataset_configurations` for identifying and collecting all fields of an in-memory `NWBFile` that could become datasets on disk; and return instances of the Pydantic dataset models filled with default values for chunking/buffering/compression. [PR #569](https://github.com/catalystneuro/neuroconv/pull/569)
* Added tool function `get_default_backend_configuration` for conveniently packaging the results of `get_default_dataset_configurations` into an easy-to-modify mapping from locations of objects within the file to their correseponding dataset configuration options, as well as linking to a specific backend DataIO. [PR #570](https://github.com/catalystneuro/neuroconv/pull/570)
* Added `set_probe()` method to `BaseRecordingExtractorInterface`. [PR #639](https://github.com/catalystneuro/neuroconv/pull/639)

### Fixes
* Fixed GenericDataChunkIterator (in hdmf.py) in the case where the number of dimensions is 1 and the size in bytes is greater than the threshold of 1 GB. [PR #638](https://github.com/catalystneuro/neuroconv/pull/638)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,40 @@ def set_aligned_segment_starting_times(self, aligned_segment_starting_times: Lis
]
self.set_aligned_segment_timestamps(aligned_segment_timestamps=aligned_segment_timestamps)

def set_probe(self, probe, group_mode: Literal["by_shank", "by_probe"]):
"""
Set the probe information via a ProbeInterface object.
Parameters
----------
probe : probeinterface.Probe
The probe object.
group_mode : {'by_shank', 'by_probe'}
How to group the channels. If 'by_shank', channels are grouped by the shank_id column.
If 'by_probe', channels are grouped by the probe_id column.
This is a required parameter to avoid the pitfall of using the wrong mode.
"""
# Set the probe to the recording extractor
self.recording_extractor.set_probe(
probe,
in_place=True,
group_mode=group_mode,
)
# Spike interface sets the "group" property
# But neuroconv allows "group_name" property to override spike interface "group" value
self.recording_extractor.set_property("group_name", self.recording_extractor.get_property("group").astype(str))

def has_probe(self) -> bool:
"""
Check if the recording extractor has probe information.
Returns
-------
has_probe : bool
True if the recording extractor has probe information.
"""
return self.recording_extractor.has_probe()

def align_by_interpolation(
self,
unaligned_timestamps: np.ndarray,
Expand Down
35 changes: 35 additions & 0 deletions src/neuroconv/tools/testing/data_interface_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
)
from neuroconv.utils import NWBMetaDataEncoder

from .mock_probes import generate_mock_probe


class DataInterfaceTestMixin:
"""
Expand Down Expand Up @@ -300,6 +302,14 @@ def check_read_nwb(self, nwbfile_path: str):
# are specified, which occurs during check_recordings_equal when there is only one channel
if self.nwb_recording.get_channel_ids()[0] != self.nwb_recording.get_channel_ids()[-1]:
check_recordings_equal(RX1=recording, RX2=self.nwb_recording, return_scaled=False)
for property_name in ["rel_x", "rel_y", "rel_z", "group"]:
if (
property_name in recording.get_property_keys()
or property_name in self.nwb_recording.get_property_keys()
):
assert_array_equal(
recording.get_property(property_name), self.nwb_recording.get_property(property_name)
)
if recording.has_scaled_traces() and self.nwb_recording.has_scaled_traces():
check_recordings_equal(RX1=recording, RX2=self.nwb_recording, return_scaled=True)

Expand Down Expand Up @@ -459,6 +469,31 @@ def test_interface_alignment(self):

self.check_nwbfile_temporal_alignment()

def test_conversion_as_lone_interface(self):
interface_kwargs = self.interface_kwargs
if isinstance(interface_kwargs, dict):
interface_kwargs = [interface_kwargs]
for num, kwargs in enumerate(interface_kwargs):
with self.subTest(str(num)):
self.case = num
self.test_kwargs = kwargs
self.interface = self.data_interface_cls(**self.test_kwargs)
assert isinstance(self.interface, BaseRecordingExtractorInterface)
if not self.interface.has_probe():
self.interface.set_probe(
generate_mock_probe(num_channels=self.interface.recording_extractor.get_num_channels()),
group_mode="by_shank",
)
self.check_metadata_schema_valid()
self.check_conversion_options_schema_valid()
self.check_metadata()
self.nwbfile_path = str(self.save_directory / f"{self.__class__.__name__}_{num}.nwb")
self.run_conversion(nwbfile_path=self.nwbfile_path)
self.check_read_nwb(nwbfile_path=self.nwbfile_path)

# Any extra custom checks to run
self.run_custom_checks()


class SortingExtractorInterfaceTestMixin(DataInterfaceTestMixin, TemporalAlignmentMixin):
data_interface_cls: BaseSortingExtractorInterface
Expand Down
29 changes: 29 additions & 0 deletions src/neuroconv/tools/testing/mock_probes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import List

import numpy as np


def generate_mock_probe(num_channels: int, num_shanks: int = 3):
import probeinterface as pi

# The shank ids will be 0, 0, 0, ..., 1, 1, 1, ..., 2, 2, 2, ...
shank_ids: List[int] = []
positions = np.zeros((num_channels, 2))
# ceil division
channels_per_shank = (num_channels + num_shanks - 1) // num_shanks
for i in range(num_shanks):
# x0, y0 is the position of the first electrode in the shank
x0 = 0
y0 = i * 200
for j in range(channels_per_shank):
if len(shank_ids) == num_channels:
break
shank_ids.append(i)
x = x0 + j * 10
y = y0 + (j % 2) * 10
positions[len(shank_ids) - 1] = x, y
probe = pi.Probe(ndim=2, si_units="um")
probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5})
probe.set_device_channel_indices(np.arange(num_channels))
probe.set_shank_ids(shank_ids)
return probe
2 changes: 1 addition & 1 deletion tests/test_on_data/test_recording_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def check_extracted_metadata(self, metadata: dict):
assert len(metadata["Ecephys"]["Device"]) == 1
assert metadata["Ecephys"]["Device"][0]["name"] == "Neuronexus-32"
assert metadata["Ecephys"]["Device"][0]["description"] == "The ecephys device for the MEArec recording."
assert len(metadata["Ecephys"]["ElectrodeGroup"]) == 1
# assert len(metadata["Ecephys"]["ElectrodeGroup"]) == 1 # do not test this condition because in the test we are setting a mock probe
assert metadata["Ecephys"]["ElectrodeGroup"][0]["device"] == "Neuronexus-32"
assert metadata["Ecephys"]["ElectricalSeries"]["description"] == (
'{"angle_tol": 15, "bursting": false, "chunk_duration": 0, "color_noise_floor": 1, '
Expand Down

0 comments on commit 09b91cb

Please sign in to comment.