From 492507ec3451b6cb4862e7c1b6985074eefd0085 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 11 Sep 2024 11:46:41 +0200 Subject: [PATCH 1/8] Fix proposal for channel location when probegroup --- .../core/baserecordingsnippets.py | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 428472bf93..f7b55d3f6a 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -348,17 +348,22 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy"): if channel_ids is None: channel_ids = self.get_channel_ids() channel_indices = self.ids_to_indices(channel_ids) - if self.get_property("contact_vector") is not None: - if len(self.get_probes()) == 1: - probe = self.get_probe() - positions = probe.contact_positions[channel_indices] - else: - all_probes = self.get_probes() - # check that multiple probes are non-overlapping - check_probe_do_not_overlap(all_probes) - all_positions = np.vstack([probe.contact_positions for probe in all_probes]) - positions = all_positions[channel_indices] - return select_axes(positions, axes) + contact_vector = self.get_property("contact_vector") + if contact_vector is not None: + # to avoid the get_probes() when only one probe do check unique probe_id + num_probes = np.unique(contact_vector["probe_index"]).size + if num_probes > 1: + # get_probes() is called only when several probes check_overlaps + # TODO make this check_probe_do_not_overlap() use only the contact_vector instead of constructing the probe + check_probe_do_not_overlap(self.get_probes()) + + # here we bypass the probe reconstruction so this works both for probe and probegroup + ndim = len(axes) + all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") + for i, dim in enumerate(axes): + all_positions[:, i] = contact_vector[dim] + positions = all_positions[channel_indices] + return positions else: locations = self.get_property("location") if locations is None: From e5c710d50dc21e3e86748d31c4894a541fc3bcab Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 11 Sep 2024 13:04:20 +0200 Subject: [PATCH 2/8] Refactor analyzer.get_channel_locations() --- src/spikeinterface/core/sortinganalyzer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 49a31738e3..a7a1ad587e 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1101,9 +1101,14 @@ def get_probe(self): def get_channel_locations(self) -> np.ndarray: # important note : contrary to recording # this give all channel locations, so no kwargs like channel_ids and axes - all_probes = self.get_probegroup().probes - all_positions = np.vstack([probe.contact_positions for probe in all_probes]) - return all_positions + probegroup = self.get_probegroup() + probe_as_numpy_array = probegroup.to_numpy() + # duplicate positions to "locations" property + ndim = probegroup.ndim + locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") + for i, dim in enumerate(["x", "y", "z"][:ndim]): + locations[:, i] = probe_as_numpy_array[dim] + return locations def channel_ids_to_indices(self, channel_ids) -> np.ndarray: all_channel_ids = list(self.rec_attributes["channel_ids"]) From 0f0834428081a7db68bf5be97d04803527f872e0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 11 Sep 2024 13:09:34 +0200 Subject: [PATCH 3/8] Check probes do not overlap at _set_probes --- src/spikeinterface/core/baserecordingsnippets.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index f7b55d3f6a..763f9e5801 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -145,6 +145,11 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False else: raise ValueError("must give Probe or ProbeGroup or list of Probe") + # check that the probe do not overlap + num_probes = len(probegroup.probes) + if num_probes > 1: + check_probe_do_not_overlap(probegroup.probes) + # handle not connected channels assert all( probe.device_channel_indices is not None for probe in probegroup.probes @@ -350,13 +355,6 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy"): channel_indices = self.ids_to_indices(channel_ids) contact_vector = self.get_property("contact_vector") if contact_vector is not None: - # to avoid the get_probes() when only one probe do check unique probe_id - num_probes = np.unique(contact_vector["probe_index"]).size - if num_probes > 1: - # get_probes() is called only when several probes check_overlaps - # TODO make this check_probe_do_not_overlap() use only the contact_vector instead of constructing the probe - check_probe_do_not_overlap(self.get_probes()) - # here we bypass the probe reconstruction so this works both for probe and probegroup ndim = len(axes) all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") From cc447b05d91651e6b3a29280049fb23cc3fd0d10 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 11 Sep 2024 13:10:06 +0200 Subject: [PATCH 4/8] Update src/spikeinterface/core/sortinganalyzer.py --- src/spikeinterface/core/sortinganalyzer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index a7a1ad587e..83e214f4ab 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1103,7 +1103,6 @@ def get_channel_locations(self) -> np.ndarray: # this give all channel locations, so no kwargs like channel_ids and axes probegroup = self.get_probegroup() probe_as_numpy_array = probegroup.to_numpy() - # duplicate positions to "locations" property ndim = probegroup.ndim locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") for i, dim in enumerate(["x", "y", "z"][:ndim]): From d62653b8810cc53cd7275a4b25e4628ef87acf5c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 11 Sep 2024 13:15:33 +0200 Subject: [PATCH 5/8] Sort probegroup array by device_channel_indices --- src/spikeinterface/core/sortinganalyzer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index a7a1ad587e..f0bc8e49bb 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1103,6 +1103,8 @@ def get_channel_locations(self) -> np.ndarray: # this give all channel locations, so no kwargs like channel_ids and axes probegroup = self.get_probegroup() probe_as_numpy_array = probegroup.to_numpy() + # we need to sort by device_channel_indices to ensure the order of locations is correct + probe_as_numpy_array = probe_as_numpy_array[np.argsort(probe_as_numpy_array["device_channel_indices"])] # duplicate positions to "locations" property ndim = probegroup.ndim locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") From 9539c93ba2a7dac133d4ce742cc92e834ca9576c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 11 Sep 2024 13:18:42 +0200 Subject: [PATCH 6/8] fix to_numpy --- src/spikeinterface/core/sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 1f54f6687f..6ce8d180c5 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1102,7 +1102,7 @@ def get_channel_locations(self) -> np.ndarray: # important note : contrary to recording # this give all channel locations, so no kwargs like channel_ids and axes probegroup = self.get_probegroup() - probe_as_numpy_array = probegroup.to_numpy() + probe_as_numpy_array = probegroup.to_numpy(complete=True) # we need to sort by device_channel_indices to ensure the order of locations is correct probe_as_numpy_array = probe_as_numpy_array[np.argsort(probe_as_numpy_array["device_channel_indices"])] ndim = probegroup.ndim From 5a8535a166f9c449d24054dfa47380b6cdb1e811 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 12 Sep 2024 16:54:27 +0200 Subject: [PATCH 7/8] Add recording and analyzer tests with interleaved probegroups --- .../core/baserecordingsnippets.py | 2 +- src/spikeinterface/core/recording_tools.py | 3 +- src/spikeinterface/core/sortinganalyzer.py | 3 +- .../core/tests/test_baserecording.py | 33 +++++++++++++++++-- .../core/tests/test_sortinganalyzer.py | 21 ++++++++++++ 5 files changed, 56 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 763f9e5801..d6088a01d7 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -239,7 +239,7 @@ def set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False) warning_msg = ( "`set_probes` is now a private function and the public function will be " - "removed in 0.103.0. Please use `set_probe` or `set_probegroups` instead" + "removed in 0.103.0. Please use `set_probe` or `set_probegroup` instead" ) warn(warning_msg, category=DeprecationWarning, stacklevel=2) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 0ec5449bae..5137eda545 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -888,11 +888,10 @@ def check_probe_do_not_overlap(probes): for j in range(i + 1, len(probes)): probe_j = probes[j] - if np.any( np.array( [ - x_bounds_i[0] < cp[0] < x_bounds_i[1] and y_bounds_i[0] < cp[1] < y_bounds_i[1] + x_bounds_i[0] <= cp[0] <= x_bounds_i[1] and y_bounds_i[0] <= cp[1] <= y_bounds_i[1] for cp in probe_j.contact_positions ] ) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 6ce8d180c5..e3b6527b90 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1107,7 +1107,8 @@ def get_channel_locations(self) -> np.ndarray: probe_as_numpy_array = probe_as_numpy_array[np.argsort(probe_as_numpy_array["device_channel_indices"])] ndim = probegroup.ndim locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") - for i, dim in enumerate(["x", "y", "z"][:ndim]): + # here we only loop through xy because only 2d locations are supported + for i, dim in enumerate(["x", "y"][:ndim]): locations[:, i] = probe_as_numpy_array[dim] return locations diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 6b60efe2b6..df614978ba 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -10,7 +10,7 @@ import numpy as np from numpy.testing import assert_raises -from probeinterface import Probe +from probeinterface import Probe, ProbeGroup, generate_linear_probe from spikeinterface.core import BinaryRecordingExtractor, NumpyRecording, load_extractor, get_default_zarr_compressor from spikeinterface.core.base import BaseExtractor @@ -358,6 +358,34 @@ def test_BaseRecording(create_cache_folder): assert np.allclose(rec_u.get_traces(cast_unsigned=True), rec_i.get_traces().astype("float")) +def test_interleaved_probegroups(): + recording = generate_recording(durations=[1.0], num_channels=16) + + probe1 = generate_linear_probe(num_elec=8, ypitch=20.0) + probe2_overlap = probe1.copy() + + probegroup_overlap = ProbeGroup() + probegroup_overlap.add_probe(probe1) + probegroup_overlap.add_probe(probe2_overlap) + probegroup_overlap.set_global_device_channel_indices(np.arange(16)) + + # setting overlapping probes should raise an error + with pytest.raises(Exception): + recording.set_probegroup(probegroup_overlap) + + probe2 = probe1.copy() + probe2.move([100.0, 100.0]) + probegroup = ProbeGroup() + probegroup.add_probe(probe1) + probegroup.add_probe(probe2) + probegroup.set_global_device_channel_indices(np.random.permutation(16)) + + recording.set_probegroup(probegroup) + probegroup_set = recording.get_probegroup() + # check that the probe group is correctly set, by sorting the device channel indices + assert np.array_equal(probegroup_set.get_global_device_channel_indices()["device_channel_indices"], np.arange(16)) + + def test_rename_channels(): recording = generate_recording(durations=[1.0], num_channels=3) renamed_recording = recording.rename_channels(new_channel_ids=["a", "b", "c"]) @@ -399,4 +427,5 @@ def test_time_slice_with_time_vector(): if __name__ == "__main__": - test_BaseRecording() + # test_BaseRecording() + test_interleaved_probegroups() diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 3f45487f4c..4468c3f505 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -178,6 +178,27 @@ def test_SortingAnalyzer_tmp_recording(dataset): sorting_analyzer.set_temporary_recording(recording_sliced) +def test_SortingAnalyzer_interleaved_probegroup(dataset): + from probeinterface import generate_linear_probe, ProbeGroup + + recording, sorting = dataset + num_channels = recording.get_num_channels() + probe1 = generate_linear_probe(num_elec=num_channels // 2, ypitch=20.0) + probe2 = probe1.copy() + probe2.move([100.0, 100.0]) + + probegroup = ProbeGroup() + probegroup.add_probe(probe1) + probegroup.add_probe(probe2) + probegroup.set_global_device_channel_indices(np.random.permutation(num_channels)) + + recording.set_probegroup(probegroup) + + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) + # check that locations are correct + assert np.array_equal(recording.get_channel_locations(), sorting_analyzer.get_channel_locations()) + + def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): register_result_extension(DummyAnalyzerExtension) From ba278a2915e9e2d8bc9e6653dbf655a637aba94a Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Thu, 12 Sep 2024 21:17:04 +0200 Subject: [PATCH 8/8] Update src/spikeinterface/core/tests/test_sortinganalyzer.py --- src/spikeinterface/core/tests/test_sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 4468c3f505..bc1db643df 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -192,7 +192,7 @@ def test_SortingAnalyzer_interleaved_probegroup(dataset): probegroup.add_probe(probe2) probegroup.set_global_device_channel_indices(np.random.permutation(num_channels)) - recording.set_probegroup(probegroup) + recording = recording.set_probegroup(probegroup) sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) # check that locations are correct