From cf65301c82c48e72e10a77f6a7f891453b69e409 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 2 Oct 2023 15:24:30 +0200 Subject: [PATCH 01/27] Check main_ids are ints or strings --- src/spikeinterface/core/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 8b4f094c20..86692fa69c 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -47,6 +47,7 @@ def __init__(self, main_ids: Sequence) -> None: # 'main_ids' will either be channel_ids or units_ids # They is used for properties self._main_ids = np.array(main_ids) + assert self._main_ids.dtype.kind in "uiSU", "Main IDs can only be integers (signed/unsigned) or strings" # dict at object level self._annotations = {} From 8343d3a70a6bb3cf56f3013abc77c8e534059150 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 2 Oct 2023 15:57:54 +0200 Subject: [PATCH 02/27] Fix NpySnippets --- src/spikeinterface/core/base.py | 3 ++- src/spikeinterface/core/baserecordingsnippets.py | 4 ++-- src/spikeinterface/core/basesnippets.py | 2 -- src/spikeinterface/core/npysnippetsextractor.py | 5 ++++- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 86692fa69c..f1a51c99d1 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -47,7 +47,8 @@ def __init__(self, main_ids: Sequence) -> None: # 'main_ids' will either be channel_ids or units_ids # They is used for properties self._main_ids = np.array(main_ids) - assert self._main_ids.dtype.kind in "uiSU", "Main IDs can only be integers (signed/unsigned) or strings" + if len(self._main_ids) > 0: + assert self._main_ids.dtype.kind in "uiSU", "Main IDs can only be integers (signed/unsigned) or strings" # dict at object level self._annotations = {} diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index affde8a75e..d411f38d2a 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -1,4 +1,4 @@ -from typing import List +from __future__ import annotations from pathlib import Path import numpy as np @@ -19,7 +19,7 @@ class BaseRecordingSnippets(BaseExtractor): has_default_locations = False - def __init__(self, sampling_frequency: float, channel_ids: List, dtype): + def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype: np.dtype): BaseExtractor.__init__(self, channel_ids) self._sampling_frequency = sampling_frequency self._dtype = np.dtype(dtype) diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index f35bc2b266..b4e3c11f55 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -1,10 +1,8 @@ from typing import List, Union -from pathlib import Path from .base import BaseSegment from .baserecordingsnippets import BaseRecordingSnippets import numpy as np from warnings import warn -from probeinterface import Probe, ProbeGroup, write_probeinterface, read_probeinterface, select_axes # snippets segments? 
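A minimal standalone sketch (plain NumPy, not the actual BaseExtractor code) of the ID check introduced in PATCH 01 and relaxed for empty IDs in PATCH 02: NumPy's dtype.kind is "u"/"i" for unsigned/signed integers and "S"/"U" for byte/unicode strings, so any other kind (floats, mixed objects) is rejected.

import numpy as np

def check_main_ids(main_ids):
    ids = np.asarray(main_ids)
    if len(ids) > 0:
        assert ids.dtype.kind in "uiSU", "Main IDs can only be integers (signed/unsigned) or strings"
    return ids

check_main_ids([0, 1, 2])        # passes, dtype kind "i"
check_main_ids(["ch0", "ch1"])   # passes, dtype kind "U"
check_main_ids([])               # passes, empty IDs skip the check (PATCH 02 guard)
# check_main_ids([0.5, 1.0])     # would raise: dtype kind "f" is not allowed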
diff --git a/src/spikeinterface/core/npysnippetsextractor.py b/src/spikeinterface/core/npysnippetsextractor.py index 80979ce6c9..69c48356e5 100644 --- a/src/spikeinterface/core/npysnippetsextractor.py +++ b/src/spikeinterface/core/npysnippetsextractor.py @@ -27,6 +27,9 @@ def __init__( num_segments = len(file_paths) data = np.load(file_paths[0], mmap_mode="r") + if channel_ids is None: + channel_ids = np.arange(data["snippet"].shape[2]) + BaseSnippets.__init__( self, sampling_frequency, @@ -84,7 +87,7 @@ def write_snippets(snippets, file_paths, dtype=None): arr = np.empty(n, dtype=snippets_t, order="F") arr["frame"] = snippets.get_frames(segment_index=i) arr["snippet"] = snippets.get_snippets(segment_index=i).astype(dtype, copy=False) - + file_paths[i].parent.mkdir(parents=True, exist_ok=True) np.save(file_paths[i], arr) From 89d1f827c445702a61eda864c9972401567a9b67 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 2 Oct 2023 16:26:25 +0200 Subject: [PATCH 03/27] Force CellExplorer unit ids as int --- src/spikeinterface/core/base.py | 4 +++- src/spikeinterface/extractors/cellexplorersortingextractor.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index f1a51c99d1..1116aeb507 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -48,7 +48,9 @@ def __init__(self, main_ids: Sequence) -> None: # They is used for properties self._main_ids = np.array(main_ids) if len(self._main_ids) > 0: - assert self._main_ids.dtype.kind in "uiSU", "Main IDs can only be integers (signed/unsigned) or strings" + assert ( + self._main_ids.dtype.kind in "uiSU" + ), f"Main IDs can only be integers (signed/unsigned) or strings, not {self._main_ids.dtype}" # dict at object level self._annotations = {} diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index 31241a4147..f72670fbcd 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -118,7 +118,7 @@ def __init__( spike_times = spikes_data["times"] # CellExplorer reports spike times in units seconds; SpikeExtractors uses time units of sampling frames - unit_ids = unit_ids[:].tolist() + unit_ids = unit_ids[:].astype(int).tolist() spiketrains_dict = {unit_id: spike_times[index] for index, unit_id in enumerate(unit_ids)} for unit_id in unit_ids: spiketrains_dict[unit_id] = (sampling_frequency * spiketrains_dict[unit_id]).round().astype(np.int64) From d75f0588707da10a61e926e337334739a0b9a20b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 3 Oct 2023 10:58:15 +0200 Subject: [PATCH 04/27] Update src/spikeinterface/extractors/cellexplorersortingextractor.py Co-authored-by: Heberto Mayorquin --- src/spikeinterface/extractors/cellexplorersortingextractor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index f72670fbcd..0096a40a79 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -119,6 +119,7 @@ def __init__( # CellExplorer reports spike times in units seconds; SpikeExtractors uses time units of sampling frames unit_ids = unit_ids[:].astype(int).tolist() + unit_ids = [str(unit_id) for unit_id in unit_ids] spiketrains_dict = {unit_id: 
spike_times[index] for index, unit_id in enumerate(unit_ids)} for unit_id in unit_ids: spiketrains_dict[unit_id] = (sampling_frequency * spiketrains_dict[unit_id]).round().astype(np.int64) From 86b2271df55b671b49cd5b58601df94ab0dd2109 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 4 Oct 2023 16:03:51 +0200 Subject: [PATCH 05/27] Change some default parameters for better user experience. --- src/spikeinterface/core/waveform_extractor.py | 8 ++++---- src/spikeinterface/postprocessing/correlograms.py | 4 ++-- src/spikeinterface/postprocessing/unit_localization.py | 2 +- src/spikeinterface/sorters/runsorter.py | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 6d9e5d41e3..1c6002226f 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -1458,13 +1458,13 @@ def extract_waveforms( folder=None, mode="folder", precompute_template=("average",), - ms_before=3.0, - ms_after=4.0, + ms_before=1.0, + ms_after=2.0, max_spikes_per_unit=500, overwrite=False, return_scaled=True, dtype=None, - sparse=False, + sparse=True, sparsity=None, num_spikes_for_sparsity=100, allow_unfiltered=False, @@ -1508,7 +1508,7 @@ def extract_waveforms( If True and recording has gain_to_uV/offset_to_uV properties, waveforms are converted to uV. dtype: dtype or None Dtype of the output waveforms. If None, the recording dtype is maintained. - sparse: bool (default False) + sparse: bool (default True) If True, before extracting all waveforms the `precompute_sparsity()` function is run using a few spikes to get an estimate of dense templates to create a ChannelSparsity object. Then, the waveforms will be sparse at extraction time, which saves a lot of memory. diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 6cd5238abd..6e693635eb 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -137,8 +137,8 @@ def compute_crosscorrelogram_from_spiketrain(spike_times1, spike_times2, window_ def compute_correlograms( waveform_or_sorting_extractor, load_if_exists=False, - window_ms: float = 100.0, - bin_ms: float = 5.0, + window_ms: float = 50.0, + bin_ms: float = 1.0, method: str = "auto", ): """Compute auto and cross correlograms. diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index d2739f69dd..48ceb34a4e 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -96,7 +96,7 @@ def get_extension_function(): def compute_unit_locations( - waveform_extractor, load_if_exists=False, method="center_of_mass", outputs="numpy", **method_kwargs + waveform_extractor, load_if_exists=False, method="monopolar_triangulation", outputs="numpy", **method_kwargs ): """ Localize units in 2D or 3D with several methods given the template. 
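To make the effect of the new defaults above concrete, a sketch of a typical call chain, assuming the small toy_example() recording/sorting pair from spikeinterface.extractors just to have something to run on:

from spikeinterface.extractors import toy_example
from spikeinterface.core import extract_waveforms
from spikeinterface.postprocessing import compute_correlograms, compute_unit_locations

# toy_example returns a small simulated (recording, sorting) pair
recording, sorting = toy_example(duration=10, num_channels=4, seed=0)

# waveforms are now extracted with ms_before=1.0, ms_after=2.0 and sparse=True by default
we = extract_waveforms(recording, sorting, mode="memory")

# correlograms now default to window_ms=50.0 and bin_ms=1.0
correlograms, bins = compute_correlograms(we)

# unit locations now default to method="monopolar_triangulation"
unit_locations = compute_unit_locations(we)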
diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index 9bacd8e2c9..a49a605a75 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -91,7 +91,7 @@ def run_sorter( sorter_name: str, recording: BaseRecording, output_folder: Optional[str] = None, - remove_existing_folder: bool = True, + remove_existing_folder: bool = False, delete_output_folder: bool = False, verbose: bool = False, raise_error: bool = True, From d9803d43e9598810337d11d2e68414261dbc3b81 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 4 Oct 2023 17:07:05 +0200 Subject: [PATCH 06/27] oups --- src/spikeinterface/core/waveform_extractor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index d83b3d66f1..eb027faf81 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -1726,6 +1726,7 @@ def precompute_sparsity( max_spikes_per_unit=num_spikes_for_sparsity, return_scaled=False, allow_unfiltered=allow_unfiltered, + sparse=False, **job_kwargs, ) local_sparsity = compute_sparsity(local_we, **sparse_kwargs) From 590cd6ba2440569469859a0e08ce321a5320e27d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 4 Oct 2023 21:04:26 +0200 Subject: [PATCH 07/27] small fix --- src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py index d25f1ea97b..364fc298c6 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py @@ -43,6 +43,8 @@ def plot(self): self._do_plot() def _do_plot(self): + from matplotlib import pyplot as plt + fig = self.figure for ax in fig.axes: From 204c8e90fd44d56e4b5eb6b0b7e92f09ea18db91 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 4 Oct 2023 21:08:17 +0200 Subject: [PATCH 08/27] fix waveform extactor with empty sorting and sparse --- src/spikeinterface/core/sparsity.py | 6 +++++- src/spikeinterface/core/tests/test_waveform_extractor.py | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 8c5c62d568..896e3800d7 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -102,7 +102,11 @@ def __init__(self, mask, unit_ids, channel_ids): self.num_channels = self.channel_ids.size self.num_units = self.unit_ids.size - self.max_num_active_channels = self.mask.sum(axis=1).max() + if self.mask.shape[0]: + self.max_num_active_channels = self.mask.sum(axis=1).max() + else: + # empty sorting without units + self.max_num_active_channels = 0 def __repr__(self): density = np.mean(self.mask) diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index 2bbf5e9b0f..00244f600b 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -556,4 +556,5 @@ def test_non_json_object(): # test_portability() # test_recordingless() # test_compute_sparsity() - test_non_json_object() + # test_non_json_object() + test_empty_sorting() From 50f6fcf5322bf10f1b8310ac228921a975b17557 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 5 Oct 
2023 12:16:50 +0200 Subject: [PATCH 09/27] small fix unrelated --- src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py index 364fc298c6..c921f42c6d 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py @@ -179,6 +179,8 @@ def plot(self): def _do_plot(self): import sklearn + import matplotlib.pyplot as plt + import matplotlib # compute similarity # take index of template (respect unit_ids order) From 0798169827321ca8a823780baa377ed8d5820469 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Thu, 5 Oct 2023 13:12:27 +0200 Subject: [PATCH 10/27] Update src/spikeinterface/core/waveform_extractor.py --- src/spikeinterface/core/waveform_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index eb027faf81..0fc5694207 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -1507,7 +1507,7 @@ def extract_waveforms( If True and recording has gain_to_uV/offset_to_uV properties, waveforms are converted to uV. dtype: dtype or None Dtype of the output waveforms. If None, the recording dtype is maintained. - sparse: bool (default True) + sparse: bool, default: True If True, before extracting all waveforms the `precompute_sparsity()` function is run using a few spikes to get an estimate of dense templates to create a ChannelSparsity object. Then, the waveforms will be sparse at extraction time, which saves a lot of memory. 
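A self-contained sketch of the empty-sorting guard added to ChannelSparsity in PATCH 08: with no units the boolean mask has shape (0, num_channels), calling max() over the empty per-unit channel counts would fail, so the maximum number of active channels is defined as 0 instead.

import numpy as np

def max_num_active_channels(mask):
    # mask: boolean array of shape (num_units, num_channels)
    if mask.shape[0]:
        return mask.sum(axis=1).max()
    else:
        # empty sorting without units
        return 0

print(max_num_active_channels(np.ones((3, 8), dtype=bool)))    # 8
print(max_num_active_channels(np.zeros((0, 8), dtype=bool)))   # 0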
From 4293b2244be7b71aa0ce68f4dabad24d23318637 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 11:17:03 +0000 Subject: [PATCH 11/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py index c921f42c6d..468b96ff3b 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py @@ -44,7 +44,7 @@ def plot(self): def _do_plot(self): from matplotlib import pyplot as plt - + fig = self.figure for ax in fig.axes: From 3371915310a4bda8cbd9ecd8a5e2d2f3e0ee55b1 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 5 Oct 2023 15:36:46 +0200 Subject: [PATCH 12/27] Keep sparse=False in postprocessing tests --- .../postprocessing/tests/common_extension_tests.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 8f864e9b84..50e2ecdb57 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -57,6 +57,7 @@ def setUp(self): ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500, + sparse=False, n_jobs=1, chunk_size=30000, overwrite=True, @@ -92,6 +93,7 @@ def setUp(self): ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500, + sparse=False, n_jobs=1, chunk_size=30000, overwrite=True, @@ -112,6 +114,7 @@ def setUp(self): recording, sorting, mode="memory", + sparse=False, ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500, From f2fe6bbcedc5a1cca38918444afe52e3ae1bec19 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Thu, 5 Oct 2023 11:42:38 -0400 Subject: [PATCH 13/27] assert typo fixes round 1 --- src/spikeinterface/core/base.py | 6 +-- src/spikeinterface/core/baserecording.py | 6 +-- src/spikeinterface/core/basesorting.py | 2 +- .../core/binaryrecordingextractor.py | 2 +- .../core/channelsaggregationrecording.py | 4 +- src/spikeinterface/core/channelslice.py | 4 +- .../core/frameslicerecording.py | 2 +- src/spikeinterface/core/frameslicesorting.py | 8 ++-- src/spikeinterface/core/generate.py | 4 +- src/spikeinterface/core/template_tools.py | 41 ++++++++++--------- .../core/unitsaggregationsorting.py | 2 +- 11 files changed, 41 insertions(+), 40 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 8b4f094c20..ba18cf09b6 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -45,7 +45,7 @@ def __init__(self, main_ids: Sequence) -> None: self._kwargs = {} # 'main_ids' will either be channel_ids or units_ids - # They is used for properties + # They are used for properties self._main_ids = np.array(main_ids) # dict at object level @@ -984,7 +984,7 @@ def _load_extractor_from_dict(dic) -> BaseExtractor: class_name = None if "kwargs" not in dic: - raise Exception(f"This dict cannot be load into extractor {dic}") + raise Exception(f"This dict cannot be loaded into extractor {dic}") # Create new kwargs to avoid modifying the original dict["kwargs"] new_kwargs = dict() @@ -1005,7 +1005,7 @@ def 
_load_extractor_from_dict(dic) -> BaseExtractor: assert extractor_class is not None and class_name is not None, "Could not load spikeinterface class" if not _check_same_version(class_name, dic["version"]): warnings.warn( - f"Versions are not the same. This might lead compatibility errors. " + f"Versions are not the same. This might lead to compatibility errors. " f"Using {class_name.split('.')[0]}=={dic['version']} is recommended" ) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 08f187895b..d3572ef66b 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -305,7 +305,7 @@ def get_traces( if not self.has_scaled(): raise ValueError( - "This recording do not support return_scaled=True (need gain_to_uV and offset_" "to_uV properties)" + "This recording does not support return_scaled=True (need gain_to_uV and offset_" "to_uV properties)" ) else: gains = self.get_property("gain_to_uV") @@ -416,8 +416,8 @@ def set_times(self, times, segment_index=None, with_warning=True): if with_warning: warn( "Setting times with Recording.set_times() is not recommended because " - "times are not always propagated to across preprocessing" - "Use use this carefully!" + "times are not always propagated across preprocessing" + "Use this carefully!" ) def sample_index_to_time(self, sample_ind, segment_index=None): diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index e6d08d38f7..2a06a699cb 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -170,7 +170,7 @@ def register_recording(self, recording, check_spike_frames=True): if check_spike_frames: if has_exceeding_spikes(recording, self): warnings.warn( - "Some spikes are exceeding the recording's duration! " + "Some spikes exceed the recording's duration! " "Removing these excess spikes with `spikeinterface.curation.remove_excess_spikes()` " "Might be necessary for further postprocessing." 
) diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index 72a95637f6..b45290caa5 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -91,7 +91,7 @@ def __init__( file_path_list = [Path(file_paths)] if t_starts is not None: - assert len(t_starts) == len(file_path_list), "t_starts must be a list of same size than file_paths" + assert len(t_starts) == len(file_path_list), "t_starts must be a list of the same size as file_paths" t_starts = [float(t_start) for t_start in t_starts] dtype = np.dtype(dtype) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index d36e168f8d..8714580821 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -104,11 +104,11 @@ def __init__(self, channel_map, parent_segments): times_kargs0 = parent_segment0.get_times_kwargs() if times_kargs0["time_vector"] is None: for ps in parent_segments: - assert ps.get_times_kwargs()["time_vector"] is None, "All segment should not have times set" + assert ps.get_times_kwargs()["time_vector"] is None, "All segments should not have times set" else: for ps in parent_segments: assert ps.get_times_kwargs()["t_start"] == times_kargs0["t_start"], ( - "All segment should have the same " "t_start" + "All segments should have the same " "t_start" ) BaseRecordingSegment.__init__(self, **times_kargs0) diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index ebd1b7db03..3a21e356a6 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -35,7 +35,7 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) ), "ChannelSliceRecording: renamed channel_ids must be the same size" assert ( self._channel_ids.size == np.unique(self._channel_ids).size - ), "ChannelSliceRecording : channel_ids not unique" + ), "ChannelSliceRecording : channel_ids are not unique" sampling_frequency = parent_recording.get_sampling_frequency() @@ -123,7 +123,7 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): ), "ChannelSliceSnippets: renamed channel_ids must be the same size" assert ( self._channel_ids.size == np.unique(self._channel_ids).size - ), "ChannelSliceSnippets : channel_ids not unique" + ), "ChannelSliceSnippets : channel_ids are not unique" sampling_frequency = parent_snippets.get_sampling_frequency() diff --git a/src/spikeinterface/core/frameslicerecording.py b/src/spikeinterface/core/frameslicerecording.py index 968f27c6ad..b8574c506f 100644 --- a/src/spikeinterface/core/frameslicerecording.py +++ b/src/spikeinterface/core/frameslicerecording.py @@ -27,7 +27,7 @@ class FrameSliceRecording(BaseRecording): def __init__(self, parent_recording, start_frame=None, end_frame=None): channel_ids = parent_recording.get_channel_ids() - assert parent_recording.get_num_segments() == 1, "FrameSliceRecording work only with one segment" + assert parent_recording.get_num_segments() == 1, "FrameSliceRecording only works with one segment" parent_size = parent_recording.get_num_samples(0) if start_frame is None: diff --git a/src/spikeinterface/core/frameslicesorting.py b/src/spikeinterface/core/frameslicesorting.py index 5da5350f06..ed1391b0e2 100644 --- a/src/spikeinterface/core/frameslicesorting.py +++ 
b/src/spikeinterface/core/frameslicesorting.py @@ -36,7 +36,7 @@ class FrameSliceSorting(BaseSorting): def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike_frames=True): unit_ids = parent_sorting.get_unit_ids() - assert parent_sorting.get_num_segments() == 1, "FrameSliceSorting work only with one segment" + assert parent_sorting.get_num_segments() == 1, "FrameSliceSorting only works with one segment" if start_frame is None: start_frame = 0 @@ -49,10 +49,10 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike end_frame = parent_n_samples assert ( end_frame <= parent_n_samples - ), "`end_frame` should be smaller than the sortings total number of samples." + ), "`end_frame` should be smaller than the sortings' total number of samples." assert ( start_frame <= parent_n_samples - ), "`start_frame` should be smaller than the sortings total number of samples." + ), "`start_frame` should be smaller than the sortings' total number of samples." if check_spike_frames and has_exceeding_spikes(parent_sorting._recording, parent_sorting): raise ValueError( "The sorting object has spikes exceeding the recording duration. You have to remove those spikes " @@ -67,7 +67,7 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike end_frame = max_spike_time + 1 assert start_frame < end_frame, ( - "`start_frame` should be greater than `end_frame`. " + "`start_frame` should be less than `end_frame`. " "This may be due to start_frame >= max_spike_time, if the end frame " "was not specified explicitly." ) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 06a5ec96ec..0c67404069 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1101,11 +1101,11 @@ def __init__( # handle also upsampling and jitter upsample_factor = templates.shape[3] elif templates.ndim == 5: - # handle also dirft + # handle also drift raise NotImplementedError("Drift will be implented soon...") # upsample_factor = templates.shape[3] else: - raise ValueError("templates have wring dim should 3 or 4") + raise ValueError("templates have wrong dim should 3 or 4") if upsample_factor is not None: assert upsample_vector is not None diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 95278b76da..552642751c 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -1,3 +1,4 @@ +from __future__ import annotations import numpy as np import warnings @@ -5,7 +6,7 @@ from .recording_tools import get_channel_distances, get_noise_levels -def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: str = "extremum"): +def get_template_amplitudes(waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum"): """ Get amplitude per channel for each unit. 
@@ -13,9 +14,9 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') - mode: str + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels + mode: "extremum" | "at_index", default: "extremum" 'extremum': max or min 'at_index': take value at spike index @@ -24,8 +25,8 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st peak_values: dict Dictionary with unit ids as keys and template amplitudes as values """ - assert peak_sign in ("both", "neg", "pos") - assert mode in ("extremum", "at_index") + assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'both', 'neg', or 'pos'" + assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'" unit_ids = waveform_extractor.sorting.unit_ids before = waveform_extractor.nbefore @@ -57,7 +58,7 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st def get_template_extremum_channel( - waveform_extractor, peak_sign: str = "neg", mode: str = "extremum", outputs: str = "id" + waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum", outputs: "id" | "index" = "id" ): """ Compute the channel with the extremum peak for each unit. @@ -66,12 +67,12 @@ def get_template_extremum_channel( ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') - mode: str + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels + mode: "extremum" | "at_index", default: "extremum" 'extremum': max or min 'at_index': take value at spike index - outputs: str + outputs: "id" | "index", default: "id" * 'id': channel id * 'index': channel index @@ -159,7 +160,7 @@ def get_template_channel_sparsity( get_template_channel_sparsity.__doc__ = get_template_channel_sparsity.__doc__.format(_sparsity_doc) -def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str = "neg"): +def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg"): """ In some situations spike sorters could return a spike index with a small shift related to the waveform peak. This function estimates and return these alignment shifts for the mean template. @@ -169,8 +170,8 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels Returns ------- @@ -203,7 +204,7 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str return shifts -def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg", mode: str = "at_index"): +def get_template_extremum_amplitude(waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "at_index"): """ Computes amplitudes on the best channel. 
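A short sketch of why PATCH 13 also adds "from __future__ import annotations" at the top of template_tools.py: with postponed evaluation, literal-style annotations such as "neg" | "pos" | "both" are kept as strings and never evaluated at runtime (evaluating them eagerly would raise a TypeError on str | str), so they serve purely as documentation of the accepted values while the runtime check stays in the assert. The function below is a hypothetical stand-in, not the real template_tools code.

from __future__ import annotations

def template_amplitude(peak_sign: "neg" | "pos" | "both" = "neg") -> float:
    # runtime validation still happens explicitly, as in the patched functions
    assert peak_sign in ("neg", "pos", "both"), "'peak_sign' must be 'neg', 'pos' or 'both'"
    return 0.0

template_amplitude("neg")   # ok; template_amplitude(0) would fail the assert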
@@ -211,9 +212,9 @@ def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg", ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') - mode: str + peak_sign: "neg" | "pos" | "both" + Sign of the template to compute best channels + mode: "extremum" | "at_index", default: "at_index" Where the amplitude is computed 'extremum': max or min 'at_index': take value at spike index @@ -223,8 +224,8 @@ def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg", amplitudes: dict Dictionary with unit ids as keys and amplitudes as values """ - assert peak_sign in ("both", "neg", "pos") - assert mode in ("extremum", "at_index") + assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'neg' or 'pos' or 'both'" + assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'" unit_ids = waveform_extractor.sorting.unit_ids before = waveform_extractor.nbefore diff --git a/src/spikeinterface/core/unitsaggregationsorting.py b/src/spikeinterface/core/unitsaggregationsorting.py index 32158f00df..4e98864ba9 100644 --- a/src/spikeinterface/core/unitsaggregationsorting.py +++ b/src/spikeinterface/core/unitsaggregationsorting.py @@ -95,7 +95,7 @@ def __init__(self, sorting_list, renamed_unit_ids=None): try: property_dict[prop_name] = np.concatenate((property_dict[prop_name], values)) except Exception as e: - print(f"Skipping property '{prop_name}' for shape inconsistency") + print(f"Skipping property '{prop_name}' due to shape inconsistency") del property_dict[prop_name] break for prop_name, prop_values in property_dict.items(): From 2417b9af67a652f38e32cf24f749f9c7706554e9 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Thu, 5 Oct 2023 12:11:01 -0400 Subject: [PATCH 14/27] add asserts msgs and fix typos --- src/spikeinterface/preprocessing/clip.py | 2 +- src/spikeinterface/preprocessing/common_reference.py | 2 +- .../preprocessing/detect_bad_channels.py | 4 ++-- src/spikeinterface/preprocessing/filter.py | 6 +++--- src/spikeinterface/preprocessing/filter_opencl.py | 12 ++++++------ .../preprocessing/highpass_spatial_filter.py | 2 +- src/spikeinterface/preprocessing/normalize_scale.py | 4 ++-- src/spikeinterface/preprocessing/phase_shift.py | 2 +- 8 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/preprocessing/clip.py b/src/spikeinterface/preprocessing/clip.py index a2349c1ee9..cc18d51d2e 100644 --- a/src/spikeinterface/preprocessing/clip.py +++ b/src/spikeinterface/preprocessing/clip.py @@ -97,7 +97,7 @@ def __init__( chunk_size=500, seed=0, ): - assert direction in ("upper", "lower", "both") + assert direction in ("upper", "lower", "both"), "'direction' must be 'upper', 'lower', or 'both'" if fill_value is None or quantile_threshold is not None: random_data = get_random_data_chunks( diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index d2ac227217..6d6ce256de 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -83,7 +83,7 @@ def __init__( ref_channel_ids = np.asarray(ref_channel_ids) assert np.all( [ch in recording.get_channel_ids() for ch in ref_channel_ids] - ), "Some wrong 'ref_channel_ids'!" + ), "Some 'ref_channel_ids' are wrong!" 
elif reference == "local": assert groups is None, "With 'local' CAR, the group option should not be used." closest_inds, dist = get_closest_channels(recording) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index cc4e8601e2..e6e2836a35 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -211,9 +211,9 @@ def detect_bad_channels( if bad_channel_ids.size > recording.get_num_channels() / 3: warnings.warn( - "Over 1/3 of channels are detected as bad. In the precense of a high" + "Over 1/3 of channels are detected as bad. In the presence of a high" "number of dead / noisy channels, bad channel detection may fail " - "(erroneously label good channels as dead)." + "(good channels may be erroneously labeled as dead)." ) elif method == "neighborhood_r2": diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 51c1fb4ad6..b31088edf7 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -71,10 +71,10 @@ def __init__( ): import scipy.signal - assert filter_mode in ("sos", "ba") + assert filter_mode in ("sos", "ba"), "'filter' mode must be 'sos' or 'ba'" fs = recording.get_sampling_frequency() if coeff is None: - assert btype in ("bandpass", "highpass") + assert btype in ("bandpass", "highpass"), "'bytpe' must be 'bandpass' or 'highpass'" # coefficient # self.coeff is 'sos' or 'ab' style filter_coeff = scipy.signal.iirfilter( @@ -258,7 +258,7 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): if dtype.kind == "u": raise TypeError( "The notch filter only supports signed types. Use the 'dtype' argument" - "to specify a signed type (e.g. 'int16', 'float32'" + "to specify a signed type (e.g. 
'int16', 'float32')" ) BasePreprocessor.__init__(self, recording, dtype=dtype) diff --git a/src/spikeinterface/preprocessing/filter_opencl.py b/src/spikeinterface/preprocessing/filter_opencl.py index 790279d647..d3a08297c6 100644 --- a/src/spikeinterface/preprocessing/filter_opencl.py +++ b/src/spikeinterface/preprocessing/filter_opencl.py @@ -50,9 +50,9 @@ def __init__( margin_ms=5.0, ): assert HAVE_PYOPENCL, "You need to install pyopencl (and GPU driver!!)" - - assert btype in ("bandpass", "lowpass", "highpass", "bandstop") - assert filter_mode in ("sos",) + btype_modes = ("bandpass", "lowpass", "highpass", "bandstop") + assert btype in btype_modes, f"'btype' must be in {btype_modes}" + assert filter_mode in ("sos",), "'filter_mode' must be 'sos'" # coefficient sf = recording.get_sampling_frequency() @@ -96,8 +96,8 @@ def __init__(self, parent_recording_segment, executor, margin): self.margin = margin def get_traces(self, start_frame, end_frame, channel_indices): - assert start_frame is not None, "FilterOpenCLRecording work with fixed chunk_size" - assert end_frame is not None, "FilterOpenCLRecording work with fixed chunk_size" + assert start_frame is not None, "FilterOpenCLRecording only works with fixed chunk_size" + assert end_frame is not None, "FilterOpenCLRecording only works with fixed chunk_size" chunk_size = end_frame - start_frame if chunk_size != self.executor.chunk_size: @@ -157,7 +157,7 @@ def process(self, traces): if traces.shape[0] != self.full_size: if self.full_size is not None: - print(f"Warning : chunk_size have change {self.chunk_size} {traces.shape[0]}, need recompile CL!!!") + print(f"Warning : chunk_size has changed {self.chunk_size} {traces.shape[0]}, need to recompile CL!!!") self.create_buffers_and_compile() event = pyopencl.enqueue_copy(self.queue, self.input_cl, traces) diff --git a/src/spikeinterface/preprocessing/highpass_spatial_filter.py b/src/spikeinterface/preprocessing/highpass_spatial_filter.py index aa98410568..4df4a409bc 100644 --- a/src/spikeinterface/preprocessing/highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/highpass_spatial_filter.py @@ -212,7 +212,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = traces * self.taper[np.newaxis, :] # apply actual HP filter - import scipy + import scipy.signal traces = scipy.signal.sosfiltfilt(self.sos_filter, traces, axis=1) diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index 7d43982853..bd53866b6a 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -68,7 +68,7 @@ def __init__( dtype="float32", **random_chunk_kwargs, ): - assert mode in ("pool_channel", "by_channel") + assert mode in ("pool_channel", "by_channel"), "'mode' must be 'pool_channel' or 'by_channel'" random_data = get_random_data_chunks(recording, **random_chunk_kwargs) @@ -260,7 +260,7 @@ def __init__( dtype="float32", **random_chunk_kwargs, ): - assert mode in ("median+mad", "mean+std") + assert mode in ("median+mad", "mean+std"), "'mode' must be 'median+mad' or 'mean+std'" # fix dtype dtype_ = fix_dtype(recording, dtype) diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index 9c8b2589a0..237f32eca4 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -42,7 +42,7 @@ def __init__(self, recording, margin_ms=40.0, 
inter_sample_shift=None, dtype=Non assert "inter_sample_shift" in recording.get_property_keys(), "'inter_sample_shift' is not a property!" sample_shifts = recording.get_property("inter_sample_shift") else: - assert len(inter_sample_shift) == recording.get_num_channels(), "sample " + assert len(inter_sample_shift) == recording.get_num_channels(), "the 'inter_sample_shift' must be same size at the num_channels " sample_shifts = np.asarray(inter_sample_shift) margin = int(margin_ms * recording.get_sampling_frequency() / 1000.0) From 9db087de50bd4b132b5e42c743dcf17fa8a9106b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 16:27:04 +0000 Subject: [PATCH 15/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/baserecording.py | 3 ++- src/spikeinterface/core/template_tools.py | 13 ++++++++++--- src/spikeinterface/preprocessing/phase_shift.py | 4 +++- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index d3572ef66b..2977211c25 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -305,7 +305,8 @@ def get_traces( if not self.has_scaled(): raise ValueError( - "This recording does not support return_scaled=True (need gain_to_uV and offset_" "to_uV properties)" + "This recording does not support return_scaled=True (need gain_to_uV and offset_" + "to_uV properties)" ) else: gains = self.get_property("gain_to_uV") diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 552642751c..b6022e27c0 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -6,7 +6,9 @@ from .recording_tools import get_channel_distances, get_noise_levels -def get_template_amplitudes(waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum"): +def get_template_amplitudes( + waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum" +): """ Get amplitude per channel for each unit. @@ -58,7 +60,10 @@ def get_template_amplitudes(waveform_extractor, peak_sign: "neg" | "pos" | "both def get_template_extremum_channel( - waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum", outputs: "id" | "index" = "id" + waveform_extractor, + peak_sign: "neg" | "pos" | "both" = "neg", + mode: "extremum" | "at_index" = "extremum", + outputs: "id" | "index" = "id", ): """ Compute the channel with the extremum peak for each unit. @@ -204,7 +209,9 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: "neg return shifts -def get_template_extremum_amplitude(waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "at_index"): +def get_template_extremum_amplitude( + waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "at_index" +): """ Computes amplitudes on the best channel. 
diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index 237f32eca4..bdba55038d 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -42,7 +42,9 @@ def __init__(self, recording, margin_ms=40.0, inter_sample_shift=None, dtype=Non assert "inter_sample_shift" in recording.get_property_keys(), "'inter_sample_shift' is not a property!" sample_shifts = recording.get_property("inter_sample_shift") else: - assert len(inter_sample_shift) == recording.get_num_channels(), "the 'inter_sample_shift' must be same size at the num_channels " + assert ( + len(inter_sample_shift) == recording.get_num_channels() + ), "the 'inter_sample_shift' must be same size at the num_channels " sample_shifts = np.asarray(inter_sample_shift) margin = int(margin_ms * recording.get_sampling_frequency() / 1000.0) From 57078791382deed5fe73c4799bd352e6c3e0ee80 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 5 Oct 2023 18:39:27 +0200 Subject: [PATCH 16/27] Fix ipywidgets with explicit dense/sparse waveforms --- .../widgets/tests/test_widgets.py | 102 +++++++++--------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index f44878927d..da16136fa9 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -49,28 +49,28 @@ def setUpClass(cls): cls.num_units = len(cls.sorting.get_unit_ids()) if (cache_folder / "mearec_test").is_dir(): - cls.we = load_waveforms(cache_folder / "mearec_test") + cls.we_dense = load_waveforms(cache_folder / "mearec_test") else: - cls.we = extract_waveforms(cls.recording, cls.sorting, cache_folder / "mearec_test") + cls.we_dense = extract_waveforms(cls.recording, cls.sorting, cache_folder / "mearec_test", sparse=False) sw.set_default_plotter_backend("matplotlib") metric_names = ["snr", "isi_violation", "num_spikes"] - _ = compute_spike_amplitudes(cls.we) - _ = compute_unit_locations(cls.we) - _ = compute_spike_locations(cls.we) - _ = compute_quality_metrics(cls.we, metric_names=metric_names) - _ = compute_template_metrics(cls.we) - _ = compute_correlograms(cls.we) - _ = compute_template_similarity(cls.we) + _ = compute_spike_amplitudes(cls.we_dense) + _ = compute_unit_locations(cls.we_dense) + _ = compute_spike_locations(cls.we_dense) + _ = compute_quality_metrics(cls.we_dense, metric_names=metric_names) + _ = compute_template_metrics(cls.we_dense) + _ = compute_correlograms(cls.we_dense) + _ = compute_template_similarity(cls.we_dense) # make sparse waveforms - cls.sparsity_radius = compute_sparsity(cls.we, method="radius", radius_um=50) - cls.sparsity_best = compute_sparsity(cls.we, method="best_channels", num_channels=5) + cls.sparsity_radius = compute_sparsity(cls.we_dense, method="radius", radius_um=50) + cls.sparsity_best = compute_sparsity(cls.we_dense, method="best_channels", num_channels=5) if (cache_folder / "mearec_test_sparse").is_dir(): cls.we_sparse = load_waveforms(cache_folder / "mearec_test_sparse") else: - cls.we_sparse = cls.we.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius) + cls.we_sparse = cls.we_dense.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius) cls.skip_backends = ["ipywidgets", "ephyviewer"] @@ -124,17 +124,17 @@ def test_plot_unit_waveforms(self): possible_backends = 
list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_waveforms(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_waveforms(self.we_dense, backend=backend, **self.backend_kwargs[backend]) unit_ids = self.sorting.unit_ids[:6] sw.plot_unit_waveforms( - self.we, + self.we_dense, sparsity=self.sparsity_radius, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend], ) sw.plot_unit_waveforms( - self.we, + self.we_dense, sparsity=self.sparsity_best, unit_ids=unit_ids, backend=backend, @@ -148,10 +148,10 @@ def test_plot_unit_templates(self): possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_templates(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_templates(self.we_dense, backend=backend, **self.backend_kwargs[backend]) unit_ids = self.sorting.unit_ids[:6] sw.plot_unit_templates( - self.we, + self.we_dense, sparsity=self.sparsity_radius, unit_ids=unit_ids, backend=backend, @@ -171,7 +171,7 @@ def test_plot_unit_waveforms_density_map(self): if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] sw.plot_unit_waveforms_density_map( - self.we, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) def test_plot_unit_waveforms_density_map_sparsity_radius(self): @@ -180,7 +180,7 @@ def test_plot_unit_waveforms_density_map_sparsity_radius(self): if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] sw.plot_unit_waveforms_density_map( - self.we, + self.we_dense, sparsity=self.sparsity_radius, same_axis=False, unit_ids=unit_ids, @@ -234,11 +234,11 @@ def test_amplitudes(self): possible_backends = list(sw.AmplitudesWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_amplitudes(self.we, backend=backend, **self.backend_kwargs[backend]) - unit_ids = self.we.unit_ids[:4] - sw.plot_amplitudes(self.we, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]) + sw.plot_amplitudes(self.we_dense, backend=backend, **self.backend_kwargs[backend]) + unit_ids = self.we_dense.unit_ids[:4] + sw.plot_amplitudes(self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]) sw.plot_amplitudes( - self.we, unit_ids=unit_ids, plot_histograms=True, backend=backend, **self.backend_kwargs[backend] + self.we_dense, unit_ids=unit_ids, plot_histograms=True, backend=backend, **self.backend_kwargs[backend] ) sw.plot_amplitudes( self.we_sparse, @@ -252,9 +252,9 @@ def test_plot_all_amplitudes_distributions(self): possible_backends = list(sw.AllAmplitudesDistributionsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - unit_ids = self.we.unit_ids[:4] + unit_ids = self.we_dense.unit_ids[:4] sw.plot_all_amplitudes_distributions( - self.we, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) sw.plot_all_amplitudes_distributions( self.we_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] @@ -264,7 +264,7 @@ def test_unit_locations(self): possible_backends = list(sw.UnitLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - 
sw.plot_unit_locations(self.we, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_locations(self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) sw.plot_unit_locations( self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) @@ -273,7 +273,7 @@ def test_spike_locations(self): possible_backends = list(sw.SpikeLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_spike_locations(self.we, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) + sw.plot_spike_locations(self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) sw.plot_spike_locations( self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) @@ -282,28 +282,28 @@ def test_similarity(self): possible_backends = list(sw.TemplateSimilarityWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_template_similarity(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_similarity(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_template_similarity(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_quality_metrics(self): possible_backends = list(sw.QualityMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_quality_metrics(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_quality_metrics(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_quality_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_template_metrics(self): possible_backends = list(sw.TemplateMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_template_metrics(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_metrics(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_template_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_depths(self): possible_backends = list(sw.UnitDepthsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_depths(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_depths(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_unit_depths(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_summary(self): @@ -311,17 +311,17 @@ def test_plot_unit_summary(self): for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_summary( - self.we, self.we.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] + self.we_dense, self.we_dense.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] ) sw.plot_unit_summary( - self.we_sparse, self.we.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] + self.we_sparse, self.we_sparse.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] ) def test_sorting_summary(self): possible_backends = list(sw.SortingSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_sorting_summary(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_sorting_summary(self.we_dense, 
backend=backend, **self.backend_kwargs[backend]) sw.plot_sorting_summary(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_agreement_matrix(self): @@ -355,23 +355,23 @@ def test_plot_rasters(self): mytest = TestWidgets() mytest.setUpClass() - # mytest.test_plot_unit_waveforms_density_map() - # mytest.test_plot_unit_summary() - # mytest.test_plot_all_amplitudes_distributions() - # mytest.test_plot_traces() - # mytest.test_plot_unit_waveforms() - # mytest.test_plot_unit_templates() - # mytest.test_plot_unit_templates() - # mytest.test_plot_unit_depths() - # mytest.test_plot_unit_templates() - # mytest.test_plot_unit_summary() - # mytest.test_unit_locations() - # mytest.test_quality_metrics() - # mytest.test_template_metrics() - # mytest.test_amplitudes() - # mytest.test_plot_agreement_matrix() - # mytest.test_plot_confusion_matrix() - # mytest.test_plot_probe_map() + mytest.test_plot_unit_waveforms_density_map() + mytest.test_plot_unit_summary() + mytest.test_plot_all_amplitudes_distributions() + mytest.test_plot_traces() + mytest.test_plot_unit_waveforms() + mytest.test_plot_unit_templates() + mytest.test_plot_unit_templates() + mytest.test_plot_unit_depths() + mytest.test_plot_unit_templates() + mytest.test_plot_unit_summary() + mytest.test_unit_locations() + mytest.test_quality_metrics() + mytest.test_template_metrics() + mytest.test_amplitudes() + mytest.test_plot_agreement_matrix() + mytest.test_plot_confusion_matrix() + mytest.test_plot_probe_map() mytest.test_plot_rasters() # plt.ion() From 3ac58086dd8d46e02d433ee840378617d5d42e9d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Oct 2023 06:31:41 +0000 Subject: [PATCH 17/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/tests/test_widgets.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index da16136fa9..ca53d85648 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -238,7 +238,11 @@ def test_amplitudes(self): unit_ids = self.we_dense.unit_ids[:4] sw.plot_amplitudes(self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]) sw.plot_amplitudes( - self.we_dense, unit_ids=unit_ids, plot_histograms=True, backend=backend, **self.backend_kwargs[backend] + self.we_dense, + unit_ids=unit_ids, + plot_histograms=True, + backend=backend, + **self.backend_kwargs[backend], ) sw.plot_amplitudes( self.we_sparse, @@ -264,7 +268,9 @@ def test_unit_locations(self): possible_backends = list(sw.UnitLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_locations(self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_locations( + self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] + ) sw.plot_unit_locations( self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) @@ -273,7 +279,9 @@ def test_spike_locations(self): possible_backends = list(sw.SpikeLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_spike_locations(self.we_dense, with_channel_ids=True, backend=backend, 
**self.backend_kwargs[backend]) + sw.plot_spike_locations( + self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] + ) sw.plot_spike_locations( self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) From 3448e1ec4b19d5f5091ba6a2792362cf35a9f941 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 08:57:56 +0200 Subject: [PATCH 18/27] Fix plot_traces with ipywidgets when channel_ids is not None --- src/spikeinterface/widgets/traces.py | 10 ++++++---- src/spikeinterface/widgets/utils_ipywidgets.py | 16 ++++++++++++++-- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 9b6716e8f3..2783b6a369 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -138,9 +138,10 @@ def __init__( # colors is a nested dict by layer and channels # lets first create black for all channels and layer + # all color are generated for ipywidgets colors = {} for k in layer_keys: - colors[k] = {chan_id: "k" for chan_id in channel_ids} + colors[k] = {chan_id: "k" for chan_id in rec0.channel_ids} if color_groups: channel_groups = rec0.get_channel_groups(channel_ids=channel_ids) @@ -149,7 +150,7 @@ def __init__( group_colors = get_some_colors(groups, color_engine="auto") channel_colors = {} - for i, chan_id in enumerate(channel_ids): + for i, chan_id in enumerate(rec0.channel_ids): group = channel_groups[i] channel_colors[chan_id] = group_colors[group] @@ -159,12 +160,12 @@ def __init__( elif color is not None: # old behavior one color for all channel # if multi layer then black for all - colors[layer_keys[0]] = {chan_id: color for chan_id in channel_ids} + colors[layer_keys[0]] = {chan_id: color for chan_id in rec0.channel_ids} elif color is None and len(recordings) > 1: # several layer layer_colors = get_some_colors(layer_keys) for k in layer_keys: - colors[k] = {chan_id: layer_colors[k] for chan_id in channel_ids} + colors[k] = {chan_id: layer_colors[k] for chan_id in rec0.channel_ids} else: # color is None unique layer : all channels black pass @@ -336,6 +337,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) self.scaler = ScaleWidget() self.channel_selector = ChannelSelector(self.rec0.channel_ids) + self.channel_selector.value = data_plot["channel_ids"] left_sidebar = W.VBox( children=[ diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 6e872eca55..5bbe31302c 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -235,8 +235,7 @@ def __init__(self, channel_ids, **kwargs): self.slider.observe(self.on_slider_changed, names=["value"], type="change") self.selector.observe(self.on_selector_changed, names=["value"], type="change") - # TODO external value change - # self.observe(self.value_changed, names=['value'], type="change") + self.observe(self.value_changed, names=['value'], type="change") def on_slider_changed(self, change=None): i0, i1 = self.slider.value @@ -259,6 +258,19 @@ def on_selector_changed(self, change=None): self.slider.observe(self.on_slider_changed, names=["value"], type="change") self.value = channel_ids + + def value_changed(self, change=None): + self.selector.unobserve(self.on_selector_changed, names=["value"], type="change") + self.selector.value = change["new"] + self.selector.observe(self.on_selector_changed, names=["value"], type="change") + + 
channel_ids = self.selector.value + self.slider.unobserve(self.on_slider_changed, names=["value"], type="change") + i0 = self.channel_ids.index(channel_ids[0]) + i1 = self.channel_ids.index(channel_ids[-1]) + 1 + self.slider.value = (i0, i1) + self.slider.observe(self.on_slider_changed, names=["value"], type="change") + class ScaleWidget(W.VBox): From e51bb75f226c7c2be97c4a6ceeae460a7c610efe Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 09:25:35 +0200 Subject: [PATCH 19/27] Fix order_channel_by_depth in ipywidgets Fix order_channel_by_depth when channel_ids is given. --- src/spikeinterface/widgets/traces.py | 58 +++++++++++++++------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 2783b6a369..802f90c62a 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -88,6 +88,26 @@ def __init__( else: raise ValueError("plot_traces recording must be recording or dict or list") + if "location" in rec0.get_property_keys(): + channel_locations = rec0.get_channel_locations() + else: + channel_locations = None + + if order_channel_by_depth and channel_locations is not None: + from ..preprocessing import depth_order + rec0 = depth_order(rec0) + recordings = {k: depth_order(rec) for k, rec in recordings.items()} + + if channel_ids is not None: + # ensure that channel_ids are in the good order + channel_ids_ = list(rec0.channel_ids) + order = np.argsort([channel_ids_.index(c) for c in channel_ids]) + channel_ids = list(np.array(channel_ids)[order]) + + if channel_ids is None: + channel_ids = rec0.channel_ids + + layer_keys = list(recordings.keys()) if segment_index is None: @@ -95,19 +115,6 @@ def __init__( raise ValueError("You must provide segment_index=...") segment_index = 0 - if channel_ids is None: - channel_ids = rec0.channel_ids - - if "location" in rec0.get_property_keys(): - channel_locations = rec0.get_channel_locations() - else: - channel_locations = None - - if order_channel_by_depth: - if channel_locations is not None: - order, _ = order_channels_by_depth(rec0, channel_ids) - else: - order = None fs = rec0.get_sampling_frequency() if time_range is None: @@ -124,7 +131,7 @@ def __init__( cmap = cmap times, list_traces, frame_range, channel_ids = _get_trace_list( - recordings, channel_ids, time_range, segment_index, order, return_scaled + recordings, channel_ids, time_range, segment_index, return_scaled=return_scaled ) # stat for auto scaling done on the first layer @@ -202,7 +209,6 @@ def __init__( show_channel_ids=show_channel_ids, add_legend=add_legend, order_channel_by_depth=order_channel_by_depth, - order=order, tile_size=tile_size, num_timepoints_per_row=int(seconds_per_row * fs), return_scaled=return_scaled, @@ -337,7 +343,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) self.scaler = ScaleWidget() self.channel_selector = ChannelSelector(self.rec0.channel_ids) - self.channel_selector.value = data_plot["channel_ids"] + self.channel_selector.value = list(data_plot["channel_ids"]) left_sidebar = W.VBox( children=[ @@ -400,17 +406,17 @@ def _mode_changed(self, change=None): def _retrieve_traces(self, change=None): channel_ids = np.array(self.channel_selector.value) - if self.data_plot["order_channel_by_depth"]: - order, _ = order_channels_by_depth(self.rec0, channel_ids) - else: - order = None + # if self.data_plot["order_channel_by_depth"]: + # order, _ = order_channels_by_depth(self.rec0, channel_ids) + # else: + 
# order = None start_frame, end_frame, segment_index = self.time_slider.value time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency self._selected_recordings = {k: self.recordings[k] for k in self._get_layers()} times, list_traces, frame_range, channel_ids = _get_trace_list( - self._selected_recordings, channel_ids, time_range, segment_index, order, self.return_scaled + self._selected_recordings, channel_ids, time_range, segment_index, return_scaled=self.return_scaled ) self._channel_ids = channel_ids @@ -525,7 +531,7 @@ def plot_ephyviewer(self, data_plot, **backend_kwargs): app.exec() -def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False): +def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_scaled=False): # function also used in ipywidgets plotter k0 = list(recordings.keys())[0] rec0 = recordings[k0] @@ -552,11 +558,11 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=No return_scaled=return_scaled, ) - if order is not None: - traces = traces[:, order] + # if order is not None: + # traces = traces[:, order] list_traces.append(traces) - if order is not None: - channel_ids = np.array(channel_ids)[order] + # if order is not None: + # channel_ids = np.array(channel_ids)[order] return times, list_traces, frame_range, channel_ids From bc3234cc4ce7d35cd62e0c29e33e38002f43ecd0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 09:52:20 +0200 Subject: [PATCH 20/27] More fix in widgets due to sparse=True by default --- .../tests/test_widgets_legacy.py | 6 +- .../widgets/tests/test_widgets.py | 57 +++++++++---------- 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py index 39eb80e2e5..8814e0131a 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py @@ -32,10 +32,10 @@ def setUp(self): self.num_units = len(self._sorting.get_unit_ids()) #  self._we = extract_waveforms(self._rec, self._sorting, './toy_example', load_if_exists=True) - if (cache_folder / "mearec_test").is_dir(): - self._we = load_waveforms(cache_folder / "mearec_test") + if (cache_folder / "mearec_test_old_api").is_dir(): + self._we = load_waveforms(cache_folder / "mearec_test_old_api") else: - self._we = extract_waveforms(self._rec, self._sorting, cache_folder / "mearec_test") + self._we = extract_waveforms(self._rec, self._sorting, cache_folder / "mearec_test_old_api", sparse=False) self._amplitudes = compute_spike_amplitudes(self._we, peak_sign="neg", outputs="by_unit") self._gt_comp = sc.compare_sorter_to_ground_truth(self._sorting, self._sorting) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index ca53d85648..5f1a936a6e 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -48,22 +48,21 @@ def setUpClass(cls): cls.sorting = se.MEArecSortingExtractor(local_path) cls.num_units = len(cls.sorting.get_unit_ids()) - if (cache_folder / "mearec_test").is_dir(): - cls.we_dense = load_waveforms(cache_folder / "mearec_test") + if (cache_folder / "mearec_test_dense").is_dir(): + cls.we_dense = load_waveforms(cache_folder / "mearec_test_dense") else: - cls.we_dense = 
extract_waveforms(cls.recording, cls.sorting, cache_folder / "mearec_test", sparse=False) + cls.we_dense = extract_waveforms(cls.recording, cls.sorting, cache_folder / "mearec_test_dense", sparse=False) + metric_names = ["snr", "isi_violation", "num_spikes"] + _ = compute_spike_amplitudes(cls.we_dense) + _ = compute_unit_locations(cls.we_dense) + _ = compute_spike_locations(cls.we_dense) + _ = compute_quality_metrics(cls.we_dense, metric_names=metric_names) + _ = compute_template_metrics(cls.we_dense) + _ = compute_correlograms(cls.we_dense) + _ = compute_template_similarity(cls.we_dense) sw.set_default_plotter_backend("matplotlib") - metric_names = ["snr", "isi_violation", "num_spikes"] - _ = compute_spike_amplitudes(cls.we_dense) - _ = compute_unit_locations(cls.we_dense) - _ = compute_spike_locations(cls.we_dense) - _ = compute_quality_metrics(cls.we_dense, metric_names=metric_names) - _ = compute_template_metrics(cls.we_dense) - _ = compute_correlograms(cls.we_dense) - _ = compute_template_similarity(cls.we_dense) - # make sparse waveforms cls.sparsity_radius = compute_sparsity(cls.we_dense, method="radius", radius_um=50) cls.sparsity_best = compute_sparsity(cls.we_dense, method="best_channels", num_channels=5) @@ -363,24 +362,24 @@ def test_plot_rasters(self): mytest = TestWidgets() mytest.setUpClass() - mytest.test_plot_unit_waveforms_density_map() - mytest.test_plot_unit_summary() - mytest.test_plot_all_amplitudes_distributions() - mytest.test_plot_traces() - mytest.test_plot_unit_waveforms() - mytest.test_plot_unit_templates() - mytest.test_plot_unit_templates() - mytest.test_plot_unit_depths() - mytest.test_plot_unit_templates() - mytest.test_plot_unit_summary() - mytest.test_unit_locations() - mytest.test_quality_metrics() - mytest.test_template_metrics() - mytest.test_amplitudes() + # mytest.test_plot_unit_waveforms_density_map() + # mytest.test_plot_unit_summary() + # mytest.test_plot_all_amplitudes_distributions() + # mytest.test_plot_traces() + # mytest.test_plot_unit_waveforms() + # mytest.test_plot_unit_templates() + # mytest.test_plot_unit_templates() + # mytest.test_plot_unit_depths() + # mytest.test_plot_unit_templates() + # mytest.test_plot_unit_summary() + # mytest.test_unit_locations() + # mytest.test_quality_metrics() + # mytest.test_template_metrics() + # mytest.test_amplitudes() mytest.test_plot_agreement_matrix() - mytest.test_plot_confusion_matrix() - mytest.test_plot_probe_map() - mytest.test_plot_rasters() + # mytest.test_plot_confusion_matrix() + # mytest.test_plot_probe_map() + # mytest.test_plot_rasters() # plt.ion() plt.show() From 7cd60ac434288e7eb9d43684e0b575396f70daaa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Oct 2023 07:52:41 +0000 Subject: [PATCH 21/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/tests/test_widgets.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 5f1a936a6e..1a2fdf38d9 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -51,7 +51,9 @@ def setUpClass(cls): if (cache_folder / "mearec_test_dense").is_dir(): cls.we_dense = load_waveforms(cache_folder / "mearec_test_dense") else: - cls.we_dense = extract_waveforms(cls.recording, cls.sorting, cache_folder / "mearec_test_dense", 
sparse=False) + cls.we_dense = extract_waveforms( + cls.recording, cls.sorting, cache_folder / "mearec_test_dense", sparse=False + ) metric_names = ["snr", "isi_violation", "num_spikes"] _ = compute_spike_amplitudes(cls.we_dense) _ = compute_unit_locations(cls.we_dense) @@ -366,7 +368,7 @@ def test_plot_rasters(self): # mytest.test_plot_unit_summary() # mytest.test_plot_all_amplitudes_distributions() # mytest.test_plot_traces() - # mytest.test_plot_unit_waveforms() + # mytest.test_plot_unit_waveforms() # mytest.test_plot_unit_templates() # mytest.test_plot_unit_templates() # mytest.test_plot_unit_depths() From 5c5f32fb0df19cb5faf7e24c11758639c1740f18 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 09:53:33 +0200 Subject: [PATCH 22/27] yep --- src/spikeinterface/widgets/traces.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 802f90c62a..d010c96a27 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -88,7 +88,7 @@ def __init__( else: raise ValueError("plot_traces recording must be recording or dict or list") - if "location" in rec0.get_property_keys(): + if rec0.has_channel_locations(): channel_locations = rec0.get_channel_locations() else: channel_locations = None From 986d6d9f26417740dd7162e671db3082363930f6 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 10:20:20 +0200 Subject: [PATCH 23/27] Fix fix with sparse waveform extractor --- src/spikeinterface/exporters/tests/test_export_to_phy.py | 6 +++--- src/spikeinterface/exporters/to_phy.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/exporters/tests/test_export_to_phy.py b/src/spikeinterface/exporters/tests/test_export_to_phy.py index 7528f0ebf9..39bb875ea8 100644 --- a/src/spikeinterface/exporters/tests/test_export_to_phy.py +++ b/src/spikeinterface/exporters/tests/test_export_to_phy.py @@ -78,7 +78,7 @@ def test_export_to_phy_by_property(): recording = recording.save(folder=rec_folder) sorting = sorting.save(folder=sort_folder) - waveform_extractor = extract_waveforms(recording, sorting, waveform_folder) + waveform_extractor = extract_waveforms(recording, sorting, waveform_folder, sparse=False) sparsity_group = compute_sparsity(waveform_extractor, method="by_property", by_property="group") export_to_phy( waveform_extractor, @@ -96,7 +96,7 @@ def test_export_to_phy_by_property(): # Remove one channel recording_rm = recording.channel_slice([0, 2, 3, 4, 5, 6, 7]) - waveform_extractor_rm = extract_waveforms(recording_rm, sorting, waveform_folder_rm) + waveform_extractor_rm = extract_waveforms(recording_rm, sorting, waveform_folder_rm, sparse=False) sparsity_group = compute_sparsity(waveform_extractor_rm, method="by_property", by_property="group") export_to_phy( @@ -130,7 +130,7 @@ def test_export_to_phy_by_sparsity(): if f.is_dir(): shutil.rmtree(f) - waveform_extractor = extract_waveforms(recording, sorting, waveform_folder) + waveform_extractor = extract_waveforms(recording, sorting, waveform_folder, sparse=False) sparsity_radius = compute_sparsity(waveform_extractor, method="radius", radius_um=50.0) export_to_phy( waveform_extractor, diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index ebc810b953..31a452f389 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -94,6 +94,7 @@ def export_to_phy( if 
waveform_extractor.is_sparse(): used_sparsity = waveform_extractor.sparsity + assert sparsity is None elif sparsity is not None: used_sparsity = sparsity else: From 63494f2a44424085d7ad22935313f9cbd2c8b88c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Oct 2023 09:11:43 +0000 Subject: [PATCH 24/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/traces.py | 3 +-- src/spikeinterface/widgets/utils_ipywidgets.py | 5 ++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index d010c96a27..7a4306b284 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -95,6 +95,7 @@ def __init__( if order_channel_by_depth and channel_locations is not None: from ..preprocessing import depth_order + rec0 = depth_order(rec0) recordings = {k: depth_order(rec) for k, rec in recordings.items()} @@ -107,7 +108,6 @@ def __init__( if channel_ids is None: channel_ids = rec0.channel_ids - layer_keys = list(recordings.keys()) if segment_index is None: @@ -115,7 +115,6 @@ def __init__( raise ValueError("You must provide segment_index=...") segment_index = 0 - fs = rec0.get_sampling_frequency() if time_range is None: time_range = (0, 1.0) diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 5bbe31302c..58dd5c7f32 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -235,7 +235,7 @@ def __init__(self, channel_ids, **kwargs): self.slider.observe(self.on_slider_changed, names=["value"], type="change") self.selector.observe(self.on_selector_changed, names=["value"], type="change") - self.observe(self.value_changed, names=['value'], type="change") + self.observe(self.value_changed, names=["value"], type="change") def on_slider_changed(self, change=None): i0, i1 = self.slider.value @@ -258,7 +258,7 @@ def on_selector_changed(self, change=None): self.slider.observe(self.on_slider_changed, names=["value"], type="change") self.value = channel_ids - + def value_changed(self, change=None): self.selector.unobserve(self.on_selector_changed, names=["value"], type="change") self.selector.value = change["new"] @@ -272,7 +272,6 @@ def value_changed(self, change=None): self.slider.observe(self.on_slider_changed, names=["value"], type="change") - class ScaleWidget(W.VBox): value = traitlets.Float() From 5660de282ac43d96324184d47aa2d951910d6fec Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 6 Oct 2023 11:16:24 +0200 Subject: [PATCH 25/27] Simplify parsing in cellexplorer --- src/spikeinterface/extractors/cellexplorersortingextractor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index 0096a40a79..0980e89f1c 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -118,7 +118,6 @@ def __init__( spike_times = spikes_data["times"] # CellExplorer reports spike times in units seconds; SpikeExtractors uses time units of sampling frames - unit_ids = unit_ids[:].astype(int).tolist() unit_ids = [str(unit_id) for unit_id in unit_ids] spiketrains_dict = {unit_id: spike_times[index] for index, unit_id in 
enumerate(unit_ids)}
         for unit_id in unit_ids:

From c0d4c60095f9704f9b27adfb5fa0f4867adfaf10 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Fri, 6 Oct 2023 11:38:15 +0200
Subject: [PATCH 26/27] oups

---
 src/spikeinterface/widgets/traces.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py
index d010c96a27..ce34af0bfa 100644
--- a/src/spikeinterface/widgets/traces.py
+++ b/src/spikeinterface/widgets/traces.py
@@ -88,7 +88,7 @@ def __init__(
         else:
             raise ValueError("plot_traces recording must be recording or dict or list")
 
-        if rec0.has_channel_locations():
+        if rec0.has_channel_location():
             channel_locations = rec0.get_channel_locations()
         else:
             channel_locations = None

From 2907934928719cf8d0403a2c55628645483187f7 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Fri, 6 Oct 2023 11:48:37 +0200
Subject: [PATCH 27/27] clean

---
 src/spikeinterface/widgets/traces.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py
index 5a8212302c..fc8b30eb05 100644
--- a/src/spikeinterface/widgets/traces.py
+++ b/src/spikeinterface/widgets/traces.py
@@ -557,11 +557,6 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_s
             return_scaled=return_scaled,
         )
 
-        # if order is not None:
-        #     traces = traces[:, order]
         list_traces.append(traces)
 
-    # if order is not None:
-    #     channel_ids = np.array(channel_ids)[order]
-
     return times, list_traces, frame_range, channel_ids
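
A usage sketch (not part of the patch series) of the scenario the plot_traces fixes above target: plotting an explicit, out-of-order subset of channels. The generate_recording helper and its keyword arguments are assumptions taken from the wider SpikeInterface API; channel_ids, order_channel_by_depth, backend, and the "matplotlib" backend name are the ones appearing in the diffs above.

import spikeinterface.core as si
import spikeinterface.widgets as sw

# Toy recording; generate_recording() is assumed here, any Recording would do.
recording = si.generate_recording(num_channels=8, durations=[10.0])

# Pass a subset of channel ids, deliberately out of order: after the fixes the
# widget keeps a consistent channel order and per-channel colors instead of
# mishandling the subset.
channel_ids = recording.channel_ids[[5, 2, 3]]

sw.plot_traces(
    recording,
    channel_ids=channel_ids,
    order_channel_by_depth=False,  # True requires channel locations on the recording
    backend="matplotlib",
)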
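
A second sketch (also not part of the patches) of the sparsity rule added to export_to_phy above: an external sparsity is only accepted together with a dense WaveformExtractor, while a sparse one already carries its own and is exported without one. toy_example, the positional output-folder argument, and the compute_pc_features / copy_binary flags are assumptions; extract_waveforms, compute_sparsity, and the sparse/dense distinction come from the diffs above.

from spikeinterface.extractors import toy_example
from spikeinterface.core import extract_waveforms, compute_sparsity
from spikeinterface.exporters import export_to_phy

recording, sorting = toy_example(duration=10, num_channels=8, num_units=5, num_segments=1, seed=0)

# Dense extractor + externally computed sparsity: allowed.
we_dense = extract_waveforms(recording, sorting, "wf_dense", sparse=False)
sparsity = compute_sparsity(we_dense, method="radius", radius_um=50.0)
export_to_phy(we_dense, "phy_dense", sparsity=sparsity, compute_pc_features=False, copy_binary=False)

# Sparse extractor: it already holds a sparsity, so no external one is passed
# (export_to_phy now asserts that sparsity is None in this case).
we_sparse = extract_waveforms(recording, sorting, "wf_sparse", sparse=True)
export_to_phy(we_sparse, "phy_sparse", compute_pc_features=False, copy_binary=False)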