Skip to content

Commit

Permalink
Merge pull request #3392 from samuelgarcia/fix_probegorup_location
Browse files Browse the repository at this point in the history
Fix proposal for channel location when probegroup
  • Loading branch information
alejoe91 authored Sep 13, 2024
2 parents 8d9f8db + ba278a2 commit 4f5b34e
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 19 deletions.
27 changes: 15 additions & 12 deletions src/spikeinterface/core/baserecordingsnippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False
else:
raise ValueError("must give Probe or ProbeGroup or list of Probe")

# check that the probe do not overlap
num_probes = len(probegroup.probes)
if num_probes > 1:
check_probe_do_not_overlap(probegroup.probes)

# handle not connected channels
assert all(
probe.device_channel_indices is not None for probe in probegroup.probes
Expand Down Expand Up @@ -234,7 +239,7 @@ def set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False)

warning_msg = (
"`set_probes` is now a private function and the public function will be "
"removed in 0.103.0. Please use `set_probe` or `set_probegroups` instead"
"removed in 0.103.0. Please use `set_probe` or `set_probegroup` instead"
)

warn(warning_msg, category=DeprecationWarning, stacklevel=2)
Expand Down Expand Up @@ -348,17 +353,15 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy"):
if channel_ids is None:
channel_ids = self.get_channel_ids()
channel_indices = self.ids_to_indices(channel_ids)
if self.get_property("contact_vector") is not None:
if len(self.get_probes()) == 1:
probe = self.get_probe()
positions = probe.contact_positions[channel_indices]
else:
all_probes = self.get_probes()
# check that multiple probes are non-overlapping
check_probe_do_not_overlap(all_probes)
all_positions = np.vstack([probe.contact_positions for probe in all_probes])
positions = all_positions[channel_indices]
return select_axes(positions, axes)
contact_vector = self.get_property("contact_vector")
if contact_vector is not None:
# here we bypass the probe reconstruction so this works both for probe and probegroup
ndim = len(axes)
all_positions = np.zeros((contact_vector.size, ndim), dtype="float64")
for i, dim in enumerate(axes):
all_positions[:, i] = contact_vector[dim]
positions = all_positions[channel_indices]
return positions
else:
locations = self.get_property("location")
if locations is None:
Expand Down
3 changes: 1 addition & 2 deletions src/spikeinterface/core/recording_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,11 +888,10 @@ def check_probe_do_not_overlap(probes):

for j in range(i + 1, len(probes)):
probe_j = probes[j]

if np.any(
np.array(
[
x_bounds_i[0] < cp[0] < x_bounds_i[1] and y_bounds_i[0] < cp[1] < y_bounds_i[1]
x_bounds_i[0] <= cp[0] <= x_bounds_i[1] and y_bounds_i[0] <= cp[1] <= y_bounds_i[1]
for cp in probe_j.contact_positions
]
)
Expand Down
13 changes: 10 additions & 3 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,9 +1111,16 @@ def get_probe(self):
def get_channel_locations(self) -> np.ndarray:
# important note : contrary to recording
# this give all channel locations, so no kwargs like channel_ids and axes
all_probes = self.get_probegroup().probes
all_positions = np.vstack([probe.contact_positions for probe in all_probes])
return all_positions
probegroup = self.get_probegroup()
probe_as_numpy_array = probegroup.to_numpy(complete=True)
# we need to sort by device_channel_indices to ensure the order of locations is correct
probe_as_numpy_array = probe_as_numpy_array[np.argsort(probe_as_numpy_array["device_channel_indices"])]
ndim = probegroup.ndim
locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64")
# here we only loop through xy because only 2d locations are supported
for i, dim in enumerate(["x", "y"][:ndim]):
locations[:, i] = probe_as_numpy_array[dim]
return locations

def channel_ids_to_indices(self, channel_ids) -> np.ndarray:
all_channel_ids = list(self.rec_attributes["channel_ids"])
Expand Down
33 changes: 31 additions & 2 deletions src/spikeinterface/core/tests/test_baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
from numpy.testing import assert_raises

from probeinterface import Probe
from probeinterface import Probe, ProbeGroup, generate_linear_probe

from spikeinterface.core import BinaryRecordingExtractor, NumpyRecording, load_extractor, get_default_zarr_compressor
from spikeinterface.core.base import BaseExtractor
Expand Down Expand Up @@ -358,6 +358,34 @@ def test_BaseRecording(create_cache_folder):
assert np.allclose(rec_u.get_traces(cast_unsigned=True), rec_i.get_traces().astype("float"))


def test_interleaved_probegroups():
recording = generate_recording(durations=[1.0], num_channels=16)

probe1 = generate_linear_probe(num_elec=8, ypitch=20.0)
probe2_overlap = probe1.copy()

probegroup_overlap = ProbeGroup()
probegroup_overlap.add_probe(probe1)
probegroup_overlap.add_probe(probe2_overlap)
probegroup_overlap.set_global_device_channel_indices(np.arange(16))

# setting overlapping probes should raise an error
with pytest.raises(Exception):
recording.set_probegroup(probegroup_overlap)

probe2 = probe1.copy()
probe2.move([100.0, 100.0])
probegroup = ProbeGroup()
probegroup.add_probe(probe1)
probegroup.add_probe(probe2)
probegroup.set_global_device_channel_indices(np.random.permutation(16))

recording.set_probegroup(probegroup)
probegroup_set = recording.get_probegroup()
# check that the probe group is correctly set, by sorting the device channel indices
assert np.array_equal(probegroup_set.get_global_device_channel_indices()["device_channel_indices"], np.arange(16))


def test_rename_channels():
recording = generate_recording(durations=[1.0], num_channels=3)
renamed_recording = recording.rename_channels(new_channel_ids=["a", "b", "c"])
Expand Down Expand Up @@ -399,4 +427,5 @@ def test_time_slice_with_time_vector():


if __name__ == "__main__":
test_BaseRecording()
# test_BaseRecording()
test_interleaved_probegroups()
21 changes: 21 additions & 0 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,27 @@ def test_SortingAnalyzer_tmp_recording(dataset):
sorting_analyzer.set_temporary_recording(recording_sliced)


def test_SortingAnalyzer_interleaved_probegroup(dataset):
from probeinterface import generate_linear_probe, ProbeGroup

recording, sorting = dataset
num_channels = recording.get_num_channels()
probe1 = generate_linear_probe(num_elec=num_channels // 2, ypitch=20.0)
probe2 = probe1.copy()
probe2.move([100.0, 100.0])

probegroup = ProbeGroup()
probegroup.add_probe(probe1)
probegroup.add_probe(probe2)
probegroup.set_global_device_channel_indices(np.random.permutation(num_channels))

recording = recording.set_probegroup(probegroup)

sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False)
# check that locations are correct
assert np.array_equal(recording.get_channel_locations(), sorting_analyzer.get_channel_locations())


def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):

register_result_extension(DummyAnalyzerExtension)
Expand Down

0 comments on commit 4f5b34e

Please sign in to comment.