diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 8b4f094c20..ad31b97d8e 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -45,8 +45,12 @@ def __init__(self, main_ids: Sequence) -> None: self._kwargs = {} # 'main_ids' will either be channel_ids or units_ids - # They is used for properties + # They are used for properties self._main_ids = np.array(main_ids) + if len(self._main_ids) > 0: + assert ( + self._main_ids.dtype.kind in "uiSU" + ), f"Main IDs can only be integers (signed/unsigned) or strings, not {self._main_ids.dtype}" # dict at object level self._annotations = {} @@ -984,7 +988,7 @@ def _load_extractor_from_dict(dic) -> BaseExtractor: class_name = None if "kwargs" not in dic: - raise Exception(f"This dict cannot be load into extractor {dic}") + raise Exception(f"This dict cannot be loaded into extractor {dic}") # Create new kwargs to avoid modifying the original dict["kwargs"] new_kwargs = dict() @@ -1005,7 +1009,7 @@ def _load_extractor_from_dict(dic) -> BaseExtractor: assert extractor_class is not None and class_name is not None, "Could not load spikeinterface class" if not _check_same_version(class_name, dic["version"]): warnings.warn( - f"Versions are not the same. This might lead compatibility errors. " + f"Versions are not the same. This might lead to compatibility errors. " f"Using {class_name.split('.')[0]}=={dic['version']} is recommended" ) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 08f187895b..2977211c25 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -305,7 +305,8 @@ def get_traces( if not self.has_scaled(): raise ValueError( - "This recording do not support return_scaled=True (need gain_to_uV and offset_" "to_uV properties)" + "This recording does not support return_scaled=True (need gain_to_uV and offset_" + "to_uV properties)" ) else: gains = self.get_property("gain_to_uV") @@ -416,8 +417,8 @@ def set_times(self, times, segment_index=None, with_warning=True): if with_warning: warn( "Setting times with Recording.set_times() is not recommended because " - "times are not always propagated to across preprocessing" - "Use use this carefully!" + "times are not always propagated across preprocessing. " + "Use this carefully!" 
) def sample_index_to_time(self, sample_ind, segment_index=None): diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index affde8a75e..d411f38d2a 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -1,4 +1,4 @@ -from typing import List +from __future__ import annotations from pathlib import Path import numpy as np @@ -19,7 +19,7 @@ class BaseRecordingSnippets(BaseExtractor): has_default_locations = False - def __init__(self, sampling_frequency: float, channel_ids: List, dtype): + def __init__(self, sampling_frequency: float, channel_ids: list[str | int], dtype: np.dtype): BaseExtractor.__init__(self, channel_ids) self._sampling_frequency = sampling_frequency self._dtype = np.dtype(dtype) diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index f35bc2b266..b4e3c11f55 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -1,10 +1,8 @@ from typing import List, Union -from pathlib import Path from .base import BaseSegment from .baserecordingsnippets import BaseRecordingSnippets import numpy as np from warnings import warn -from probeinterface import Probe, ProbeGroup, write_probeinterface, read_probeinterface, select_axes # snippets segments? diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index e6d08d38f7..2a06a699cb 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -170,7 +170,7 @@ def register_recording(self, recording, check_spike_frames=True): if check_spike_frames: if has_exceeding_spikes(recording, self): warnings.warn( - "Some spikes are exceeding the recording's duration! " + "Some spikes exceed the recording's duration! " "Removing these excess spikes with `spikeinterface.curation.remove_excess_spikes()` " "Might be necessary for further postprocessing." 
) diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index 72a95637f6..b45290caa5 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -91,7 +91,7 @@ def __init__( file_path_list = [Path(file_paths)] if t_starts is not None: - assert len(t_starts) == len(file_path_list), "t_starts must be a list of same size than file_paths" + assert len(t_starts) == len(file_path_list), "t_starts must be a list of the same size as file_paths" t_starts = [float(t_start) for t_start in t_starts] dtype = np.dtype(dtype) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index d36e168f8d..8714580821 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -104,11 +104,11 @@ def __init__(self, channel_map, parent_segments): times_kargs0 = parent_segment0.get_times_kwargs() if times_kargs0["time_vector"] is None: for ps in parent_segments: - assert ps.get_times_kwargs()["time_vector"] is None, "All segment should not have times set" + assert ps.get_times_kwargs()["time_vector"] is None, "No segment should have times set" else: for ps in parent_segments: assert ps.get_times_kwargs()["t_start"] == times_kargs0["t_start"], ( - "All segment should have the same " "t_start" + "All segments should have the same " "t_start" ) BaseRecordingSegment.__init__(self, **times_kargs0) diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index ebd1b7db03..3a21e356a6 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -35,7 +35,7 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) ), "ChannelSliceRecording: renamed channel_ids must be the same size" assert ( self._channel_ids.size == np.unique(self._channel_ids).size - ), "ChannelSliceRecording : channel_ids not unique" + ), "ChannelSliceRecording: channel_ids are not unique" sampling_frequency = parent_recording.get_sampling_frequency() @@ -123,7 +123,7 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): ), "ChannelSliceSnippets: renamed channel_ids must be the same size" assert ( self._channel_ids.size == np.unique(self._channel_ids).size - ), "ChannelSliceSnippets : channel_ids not unique" + ), "ChannelSliceSnippets: channel_ids are not unique" sampling_frequency = parent_snippets.get_sampling_frequency() diff --git a/src/spikeinterface/core/frameslicerecording.py b/src/spikeinterface/core/frameslicerecording.py index 968f27c6ad..b8574c506f 100644 --- a/src/spikeinterface/core/frameslicerecording.py +++ b/src/spikeinterface/core/frameslicerecording.py @@ -27,7 +27,7 @@ class FrameSliceRecording(BaseRecording): def __init__(self, parent_recording, start_frame=None, end_frame=None): channel_ids = parent_recording.get_channel_ids() - assert parent_recording.get_num_segments() == 1, "FrameSliceRecording work only with one segment" + assert parent_recording.get_num_segments() == 1, "FrameSliceRecording only works with one segment" parent_size = parent_recording.get_num_samples(0) if start_frame is None: diff --git a/src/spikeinterface/core/frameslicesorting.py b/src/spikeinterface/core/frameslicesorting.py index 5da5350f06..ed1391b0e2 100644 --- a/src/spikeinterface/core/frameslicesorting.py +++ 
b/src/spikeinterface/core/frameslicesorting.py @@ -36,7 +36,7 @@ class FrameSliceSorting(BaseSorting): def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike_frames=True): unit_ids = parent_sorting.get_unit_ids() - assert parent_sorting.get_num_segments() == 1, "FrameSliceSorting work only with one segment" + assert parent_sorting.get_num_segments() == 1, "FrameSliceSorting only works with one segment" if start_frame is None: start_frame = 0 @@ -49,10 +49,10 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike end_frame = parent_n_samples assert ( end_frame <= parent_n_samples - ), "`end_frame` should be smaller than the sortings total number of samples." + ), "`end_frame` should be smaller than the sorting's total number of samples." assert ( start_frame <= parent_n_samples - ), "`start_frame` should be smaller than the sortings total number of samples." + ), "`start_frame` should be smaller than the sorting's total number of samples." if check_spike_frames and has_exceeding_spikes(parent_sorting._recording, parent_sorting): raise ValueError( "The sorting object has spikes exceeding the recording duration. You have to remove those spikes " @@ -67,7 +67,7 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike end_frame = max_spike_time + 1 assert start_frame < end_frame, ( - "`start_frame` should be greater than `end_frame`. " + "`start_frame` should be less than `end_frame`. " "This may be due to start_frame >= max_spike_time, if the end frame " "was not specified explicitly." ) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 06a5ec96ec..0c67404069 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1101,11 +1101,11 @@ def __init__( # handle also upsampling and jitter upsample_factor = templates.shape[3] elif templates.ndim == 5: - # handle also dirft + # handle also drift raise NotImplementedError("Drift will be implented soon...") # upsample_factor = templates.shape[3] else: - raise ValueError("templates have wring dim should 3 or 4") + raise ValueError("templates have wrong ndim: should be 3 or 4") if upsample_factor is not None: assert upsample_vector is not None diff --git a/src/spikeinterface/core/npysnippetsextractor.py b/src/spikeinterface/core/npysnippetsextractor.py index 80979ce6c9..69c48356e5 100644 --- a/src/spikeinterface/core/npysnippetsextractor.py +++ b/src/spikeinterface/core/npysnippetsextractor.py @@ -27,6 +27,9 @@ def __init__( num_segments = len(file_paths) data = np.load(file_paths[0], mmap_mode="r") + if channel_ids is None: + channel_ids = np.arange(data["snippet"].shape[2]) + BaseSnippets.__init__( self, sampling_frequency, @@ -84,7 +87,7 @@ def write_snippets(snippets, file_paths, dtype=None): arr = np.empty(n, dtype=snippets_t, order="F") arr["frame"] = snippets.get_frames(segment_index=i) arr["snippet"] = snippets.get_snippets(segment_index=i).astype(dtype, copy=False) - + file_paths[i].parent.mkdir(parents=True, exist_ok=True) np.save(file_paths[i], arr) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 8c5c62d568..896e3800d7 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -102,7 +102,11 @@ def __init__(self, mask, unit_ids, channel_ids): self.num_channels = self.channel_ids.size self.num_units = self.unit_ids.size - self.max_num_active_channels = self.mask.sum(axis=1).max() + if self.mask.shape[0]: + 
self.max_num_active_channels = self.mask.sum(axis=1).max() + else: + # empty sorting without units + self.max_num_active_channels = 0 def __repr__(self): density = np.mean(self.mask) diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 95278b76da..b6022e27c0 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -1,3 +1,4 @@ +from __future__ import annotations import numpy as np import warnings @@ -5,7 +6,9 @@ from .recording_tools import get_channel_distances, get_noise_levels -def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: str = "extremum"): +def get_template_amplitudes( + waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum" +): """ Get amplitude per channel for each unit. @@ -13,9 +16,9 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') - mode: str + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels + mode: "extremum" | "at_index", default: "extremum" 'extremum': max or min 'at_index': take value at spike index @@ -24,8 +27,8 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st peak_values: dict Dictionary with unit ids as keys and template amplitudes as values """ - assert peak_sign in ("both", "neg", "pos") - assert mode in ("extremum", "at_index") + assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'both', 'neg', or 'pos'" + assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'" unit_ids = waveform_extractor.sorting.unit_ids before = waveform_extractor.nbefore @@ -57,7 +60,10 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st def get_template_extremum_channel( - waveform_extractor, peak_sign: str = "neg", mode: str = "extremum", outputs: str = "id" + waveform_extractor, + peak_sign: "neg" | "pos" | "both" = "neg", + mode: "extremum" | "at_index" = "extremum", + outputs: "id" | "index" = "id", ): """ Compute the channel with the extremum peak for each unit. @@ -66,12 +72,12 @@ def get_template_extremum_channel( ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') - mode: str + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels + mode: "extremum" | "at_index", default: "extremum" 'extremum': max or min 'at_index': take value at spike index - outputs: str + outputs: "id" | "index", default: "id" * 'id': channel id * 'index': channel index @@ -159,7 +165,7 @@ def get_template_channel_sparsity( get_template_channel_sparsity.__doc__ = get_template_channel_sparsity.__doc__.format(_sparsity_doc) -def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str = "neg"): +def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg"): """ In some situations spike sorters could return a spike index with a small shift related to the waveform peak. This function estimates and return these alignment shifts for the mean template. 
@@ -169,8 +175,8 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels Returns ------- @@ -203,7 +209,9 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str return shifts -def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg", mode: str = "at_index"): +def get_template_extremum_amplitude( + waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "at_index" +): """ Computes amplitudes on the best channel. @@ -211,9 +219,9 @@ def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg", ---------- waveform_extractor: WaveformExtractor The waveform extractor - peak_sign: str - Sign of the template to compute best channels ('neg', 'pos', 'both') - mode: str + peak_sign: "neg" | "pos" | "both", default: "neg" + Sign of the template to compute best channels + mode: "extremum" | "at_index", default: "at_index" Where the amplitude is computed 'extremum': max or min 'at_index': take value at spike index @@ -223,8 +231,8 @@ def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg", amplitudes: dict Dictionary with unit ids as keys and amplitudes as values """ - assert peak_sign in ("both", "neg", "pos") - assert mode in ("extremum", "at_index") + assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'neg', 'pos', or 'both'" + assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'" unit_ids = waveform_extractor.sorting.unit_ids before = waveform_extractor.nbefore diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index b56180a9e9..204f796c0e 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -559,3 +559,4 @@ def test_non_json_object(): test_recordingless() # test_compute_sparsity() # test_non_json_object() + test_empty_sorting() diff --git a/src/spikeinterface/core/unitsaggregationsorting.py b/src/spikeinterface/core/unitsaggregationsorting.py index 32158f00df..4e98864ba9 100644 --- a/src/spikeinterface/core/unitsaggregationsorting.py +++ b/src/spikeinterface/core/unitsaggregationsorting.py @@ -95,7 +95,7 @@ def __init__(self, sorting_list, renamed_unit_ids=None): try: property_dict[prop_name] = np.concatenate((property_dict[prop_name], values)) except Exception as e: - print(f"Skipping property '{prop_name}' for shape inconsistency") + print(f"Skipping property '{prop_name}' due to shape inconsistency") del property_dict[prop_name] break for prop_name, prop_values in property_dict.items(): diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 576a0a1a58..0fc5694207 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -1457,13 +1457,13 @@ def extract_waveforms( folder=None, mode="folder", precompute_template=("average",), - ms_before=3.0, - ms_after=4.0, + ms_before=1.0, + ms_after=2.0, max_spikes_per_unit=500, overwrite=False, return_scaled=True, dtype=None, - sparse=False, + sparse=True, sparsity=None, num_spikes_for_sparsity=100, allow_unfiltered=False, @@ -1507,7 +1507,7 @@ def 
extract_waveforms( If True and recording has gain_to_uV/offset_to_uV properties, waveforms are converted to uV. dtype: dtype or None Dtype of the output waveforms. If None, the recording dtype is maintained. - sparse: bool (default False) + sparse: bool, default: True If True, before extracting all waveforms the `precompute_sparsity()` function is run using a few spikes to get an estimate of dense templates to create a ChannelSparsity object. Then, the waveforms will be sparse at extraction time, which saves a lot of memory. @@ -1726,6 +1726,7 @@ def precompute_sparsity( max_spikes_per_unit=num_spikes_for_sparsity, return_scaled=False, allow_unfiltered=allow_unfiltered, + sparse=False, **job_kwargs, ) local_sparsity = compute_sparsity(local_we, **sparse_kwargs) diff --git a/src/spikeinterface/exporters/tests/test_export_to_phy.py b/src/spikeinterface/exporters/tests/test_export_to_phy.py index 7528f0ebf9..39bb875ea8 100644 --- a/src/spikeinterface/exporters/tests/test_export_to_phy.py +++ b/src/spikeinterface/exporters/tests/test_export_to_phy.py @@ -78,7 +78,7 @@ def test_export_to_phy_by_property(): recording = recording.save(folder=rec_folder) sorting = sorting.save(folder=sort_folder) - waveform_extractor = extract_waveforms(recording, sorting, waveform_folder) + waveform_extractor = extract_waveforms(recording, sorting, waveform_folder, sparse=False) sparsity_group = compute_sparsity(waveform_extractor, method="by_property", by_property="group") export_to_phy( waveform_extractor, @@ -96,7 +96,7 @@ def test_export_to_phy_by_property(): # Remove one channel recording_rm = recording.channel_slice([0, 2, 3, 4, 5, 6, 7]) - waveform_extractor_rm = extract_waveforms(recording_rm, sorting, waveform_folder_rm) + waveform_extractor_rm = extract_waveforms(recording_rm, sorting, waveform_folder_rm, sparse=False) sparsity_group = compute_sparsity(waveform_extractor_rm, method="by_property", by_property="group") export_to_phy( @@ -130,7 +130,7 @@ def test_export_to_phy_by_sparsity(): if f.is_dir(): shutil.rmtree(f) - waveform_extractor = extract_waveforms(recording, sorting, waveform_folder) + waveform_extractor = extract_waveforms(recording, sorting, waveform_folder, sparse=False) sparsity_radius = compute_sparsity(waveform_extractor, method="radius", radius_um=50.0) export_to_phy( waveform_extractor, diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index ebc810b953..31a452f389 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -94,6 +94,7 @@ def export_to_phy( if waveform_extractor.is_sparse(): used_sparsity = waveform_extractor.sparsity + assert sparsity is None, "The waveform extractor is already sparse: 'sparsity' must be None" elif sparsity is not None: used_sparsity = sparsity else: diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index 31241a4147..0980e89f1c 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -118,7 +118,7 @@ def __init__( spike_times = spikes_data["times"] # CellExplorer reports spike times in units seconds; SpikeExtractors uses time units of sampling frames - unit_ids = unit_ids[:].tolist() + unit_ids = [str(unit_id) for unit_id in unit_ids] spiketrains_dict = {unit_id: spike_times[index] for index, unit_id in enumerate(unit_ids)} for unit_id in unit_ids: spiketrains_dict[unit_id] = (sampling_frequency * spiketrains_dict[unit_id]).round().astype(np.int64) diff --git 
a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 6cd5238abd..6e693635eb 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -137,8 +137,8 @@ def compute_crosscorrelogram_from_spiketrain(spike_times1, spike_times2, window_ def compute_correlograms( waveform_or_sorting_extractor, load_if_exists=False, - window_ms: float = 100.0, - bin_ms: float = 5.0, + window_ms: float = 50.0, + bin_ms: float = 1.0, method: str = "auto", ): """Compute auto and cross correlograms. diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 8f864e9b84..50e2ecdb57 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -57,6 +57,7 @@ def setUp(self): ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500, + sparse=False, n_jobs=1, chunk_size=30000, overwrite=True, @@ -92,6 +93,7 @@ def setUp(self): ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500, + sparse=False, n_jobs=1, chunk_size=30000, overwrite=True, @@ -112,6 +114,7 @@ def setUp(self): recording, sorting, mode="memory", + sparse=False, ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500, diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index d2739f69dd..48ceb34a4e 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -96,7 +96,7 @@ def get_extension_function(): def compute_unit_locations( - waveform_extractor, load_if_exists=False, method="center_of_mass", outputs="numpy", **method_kwargs + waveform_extractor, load_if_exists=False, method="monopolar_triangulation", outputs="numpy", **method_kwargs ): """ Localize units in 2D or 3D with several methods given the template. diff --git a/src/spikeinterface/preprocessing/clip.py b/src/spikeinterface/preprocessing/clip.py index a2349c1ee9..cc18d51d2e 100644 --- a/src/spikeinterface/preprocessing/clip.py +++ b/src/spikeinterface/preprocessing/clip.py @@ -97,7 +97,7 @@ def __init__( chunk_size=500, seed=0, ): - assert direction in ("upper", "lower", "both") + assert direction in ("upper", "lower", "both"), "'direction' must be 'upper', 'lower', or 'both'" if fill_value is None or quantile_threshold is not None: random_data = get_random_data_chunks( diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index d2ac227217..6d6ce256de 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -83,7 +83,7 @@ def __init__( ref_channel_ids = np.asarray(ref_channel_ids) assert np.all( [ch in recording.get_channel_ids() for ch in ref_channel_ids] - ), "Some wrong 'ref_channel_ids'!" + ), "Some 'ref_channel_ids' are not in the recording's channel ids!" elif reference == "local": assert groups is None, "With 'local' CAR, the group option should not be used." 
closest_inds, dist = get_closest_channels(recording) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index cc4e8601e2..e6e2836a35 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -211,9 +211,9 @@ def detect_bad_channels( if bad_channel_ids.size > recording.get_num_channels() / 3: warnings.warn( - "Over 1/3 of channels are detected as bad. In the precense of a high" + "Over 1/3 of channels are detected as bad. In the presence of a high " "number of dead / noisy channels, bad channel detection may fail " - "(erroneously label good channels as dead)." + "(good channels may be erroneously labeled as dead)." ) elif method == "neighborhood_r2": diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 51c1fb4ad6..b31088edf7 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -71,10 +71,10 @@ def __init__( ): import scipy.signal - assert filter_mode in ("sos", "ba") + assert filter_mode in ("sos", "ba"), "'filter_mode' must be 'sos' or 'ba'" fs = recording.get_sampling_frequency() if coeff is None: - assert btype in ("bandpass", "highpass") + assert btype in ("bandpass", "highpass"), "'btype' must be 'bandpass' or 'highpass'" # coefficient # self.coeff is 'sos' or 'ab' style filter_coeff = scipy.signal.iirfilter( @@ -258,7 +258,7 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): if dtype.kind == "u": raise TypeError( "The notch filter only supports signed types. Use the 'dtype' argument" - "to specify a signed type (e.g. 'int16', 'float32'" + " to specify a signed type (e.g. 
'int16', 'float32')" ) BasePreprocessor.__init__(self, recording, dtype=dtype) diff --git a/src/spikeinterface/preprocessing/filter_opencl.py b/src/spikeinterface/preprocessing/filter_opencl.py index 790279d647..d3a08297c6 100644 --- a/src/spikeinterface/preprocessing/filter_opencl.py +++ b/src/spikeinterface/preprocessing/filter_opencl.py @@ -50,9 +50,9 @@ def __init__( margin_ms=5.0, ): assert HAVE_PYOPENCL, "You need to install pyopencl (and GPU driver!!)" - - assert btype in ("bandpass", "lowpass", "highpass", "bandstop") - assert filter_mode in ("sos",) + btype_modes = ("bandpass", "lowpass", "highpass", "bandstop") + assert btype in btype_modes, f"'btype' must be in {btype_modes}" + assert filter_mode in ("sos",), "'filter_mode' must be 'sos'" # coefficient sf = recording.get_sampling_frequency() @@ -96,8 +96,8 @@ def __init__(self, parent_recording_segment, executor, margin): self.margin = margin def get_traces(self, start_frame, end_frame, channel_indices): - assert start_frame is not None, "FilterOpenCLRecording work with fixed chunk_size" - assert end_frame is not None, "FilterOpenCLRecording work with fixed chunk_size" + assert start_frame is not None, "FilterOpenCLRecording only works with fixed chunk_size" + assert end_frame is not None, "FilterOpenCLRecording only works with fixed chunk_size" chunk_size = end_frame - start_frame if chunk_size != self.executor.chunk_size: @@ -157,7 +157,7 @@ def process(self, traces): if traces.shape[0] != self.full_size: if self.full_size is not None: - print(f"Warning : chunk_size have change {self.chunk_size} {traces.shape[0]}, need recompile CL!!!") + print(f"Warning : chunk_size has changed {self.chunk_size} {traces.shape[0]}, need to recompile CL!!!") self.create_buffers_and_compile() event = pyopencl.enqueue_copy(self.queue, self.input_cl, traces) diff --git a/src/spikeinterface/preprocessing/highpass_spatial_filter.py b/src/spikeinterface/preprocessing/highpass_spatial_filter.py index aa98410568..4df4a409bc 100644 --- a/src/spikeinterface/preprocessing/highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/highpass_spatial_filter.py @@ -212,7 +212,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = traces * self.taper[np.newaxis, :] # apply actual HP filter - import scipy + import scipy.signal traces = scipy.signal.sosfiltfilt(self.sos_filter, traces, axis=1) diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index 7d43982853..bd53866b6a 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -68,7 +68,7 @@ def __init__( dtype="float32", **random_chunk_kwargs, ): - assert mode in ("pool_channel", "by_channel") + assert mode in ("pool_channel", "by_channel"), "'mode' must be 'pool_channel' or 'by_channel'" random_data = get_random_data_chunks(recording, **random_chunk_kwargs) @@ -260,7 +260,7 @@ def __init__( dtype="float32", **random_chunk_kwargs, ): - assert mode in ("median+mad", "mean+std") + assert mode in ("median+mad", "mean+std"), "'mode' must be 'median+mad' or 'mean+std'" # fix dtype dtype_ = fix_dtype(recording, dtype) diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index 9c8b2589a0..bdba55038d 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -42,7 +42,9 @@ def __init__(self, recording, margin_ms=40.0, 
inter_sample_shift=None, dtype=Non assert "inter_sample_shift" in recording.get_property_keys(), "'inter_sample_shift' is not a property!" sample_shifts = recording.get_property("inter_sample_shift") else: - assert len(inter_sample_shift) == recording.get_num_channels(), "sample " + assert ( + len(inter_sample_shift) == recording.get_num_channels() + ), "the 'inter_sample_shift' must be the same size as the number of channels" sample_shifts = np.asarray(inter_sample_shift) margin = int(margin_ms * recording.get_sampling_frequency() / 1000.0) diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index 9bacd8e2c9..a49a605a75 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -91,7 +91,7 @@ def run_sorter( sorter_name: str, recording: BaseRecording, output_folder: Optional[str] = None, - remove_existing_folder: bool = True, + remove_existing_folder: bool = False, delete_output_folder: bool = False, verbose: bool = False, raise_error: bool = True, diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py index d25f1ea97b..468b96ff3b 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py @@ -43,6 +43,8 @@ def plot(self): self._do_plot() def _do_plot(self): + from matplotlib import pyplot as plt + fig = self.figure for ax in fig.axes: @@ -177,6 +179,8 @@ def plot(self): def _do_plot(self): import sklearn + import matplotlib.pyplot as plt + import matplotlib # compute similarity # take index of template (respect unit_ids order) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py index 39eb80e2e5..8814e0131a 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py @@ -32,10 +32,10 @@ def setUp(self): self.num_units = len(self._sorting.get_unit_ids()) #  self._we = extract_waveforms(self._rec, self._sorting, './toy_example', load_if_exists=True) - if (cache_folder / "mearec_test").is_dir(): - self._we = load_waveforms(cache_folder / "mearec_test") + if (cache_folder / "mearec_test_old_api").is_dir(): + self._we = load_waveforms(cache_folder / "mearec_test_old_api") else: - self._we = extract_waveforms(self._rec, self._sorting, cache_folder / "mearec_test") + self._we = extract_waveforms(self._rec, self._sorting, cache_folder / "mearec_test_old_api", sparse=False) self._amplitudes = compute_spike_amplitudes(self._we, peak_sign="neg", outputs="by_unit") self._gt_comp = sc.compare_sorter_to_ground_truth(self._sorting, self._sorting) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index f44878927d..1a2fdf38d9 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -48,29 +48,30 @@ def setUpClass(cls): cls.sorting = se.MEArecSortingExtractor(local_path) cls.num_units = len(cls.sorting.get_unit_ids()) - if (cache_folder / "mearec_test").is_dir(): - cls.we = load_waveforms(cache_folder / "mearec_test") + if (cache_folder / "mearec_test_dense").is_dir(): + cls.we_dense = load_waveforms(cache_folder / "mearec_test_dense") else: - cls.we = extract_waveforms(cls.recording, cls.sorting, cache_folder / 
"mearec_test") + cls.we_dense = extract_waveforms( + cls.recording, cls.sorting, cache_folder / "mearec_test_dense", sparse=False + ) + metric_names = ["snr", "isi_violation", "num_spikes"] + _ = compute_spike_amplitudes(cls.we_dense) + _ = compute_unit_locations(cls.we_dense) + _ = compute_spike_locations(cls.we_dense) + _ = compute_quality_metrics(cls.we_dense, metric_names=metric_names) + _ = compute_template_metrics(cls.we_dense) + _ = compute_correlograms(cls.we_dense) + _ = compute_template_similarity(cls.we_dense) sw.set_default_plotter_backend("matplotlib") - metric_names = ["snr", "isi_violation", "num_spikes"] - _ = compute_spike_amplitudes(cls.we) - _ = compute_unit_locations(cls.we) - _ = compute_spike_locations(cls.we) - _ = compute_quality_metrics(cls.we, metric_names=metric_names) - _ = compute_template_metrics(cls.we) - _ = compute_correlograms(cls.we) - _ = compute_template_similarity(cls.we) - # make sparse waveforms - cls.sparsity_radius = compute_sparsity(cls.we, method="radius", radius_um=50) - cls.sparsity_best = compute_sparsity(cls.we, method="best_channels", num_channels=5) + cls.sparsity_radius = compute_sparsity(cls.we_dense, method="radius", radius_um=50) + cls.sparsity_best = compute_sparsity(cls.we_dense, method="best_channels", num_channels=5) if (cache_folder / "mearec_test_sparse").is_dir(): cls.we_sparse = load_waveforms(cache_folder / "mearec_test_sparse") else: - cls.we_sparse = cls.we.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius) + cls.we_sparse = cls.we_dense.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius) cls.skip_backends = ["ipywidgets", "ephyviewer"] @@ -124,17 +125,17 @@ def test_plot_unit_waveforms(self): possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_waveforms(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_waveforms(self.we_dense, backend=backend, **self.backend_kwargs[backend]) unit_ids = self.sorting.unit_ids[:6] sw.plot_unit_waveforms( - self.we, + self.we_dense, sparsity=self.sparsity_radius, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend], ) sw.plot_unit_waveforms( - self.we, + self.we_dense, sparsity=self.sparsity_best, unit_ids=unit_ids, backend=backend, @@ -148,10 +149,10 @@ def test_plot_unit_templates(self): possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_templates(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_templates(self.we_dense, backend=backend, **self.backend_kwargs[backend]) unit_ids = self.sorting.unit_ids[:6] sw.plot_unit_templates( - self.we, + self.we_dense, sparsity=self.sparsity_radius, unit_ids=unit_ids, backend=backend, @@ -171,7 +172,7 @@ def test_plot_unit_waveforms_density_map(self): if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] sw.plot_unit_waveforms_density_map( - self.we, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) def test_plot_unit_waveforms_density_map_sparsity_radius(self): @@ -180,7 +181,7 @@ def test_plot_unit_waveforms_density_map_sparsity_radius(self): if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] sw.plot_unit_waveforms_density_map( - self.we, + self.we_dense, 
sparsity=self.sparsity_radius, same_axis=False, unit_ids=unit_ids, @@ -234,11 +235,15 @@ def test_amplitudes(self): possible_backends = list(sw.AmplitudesWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_amplitudes(self.we, backend=backend, **self.backend_kwargs[backend]) - unit_ids = self.we.unit_ids[:4] - sw.plot_amplitudes(self.we, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]) + sw.plot_amplitudes(self.we_dense, backend=backend, **self.backend_kwargs[backend]) + unit_ids = self.we_dense.unit_ids[:4] + sw.plot_amplitudes(self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend]) sw.plot_amplitudes( - self.we, unit_ids=unit_ids, plot_histograms=True, backend=backend, **self.backend_kwargs[backend] + self.we_dense, + unit_ids=unit_ids, + plot_histograms=True, + backend=backend, + **self.backend_kwargs[backend], ) sw.plot_amplitudes( self.we_sparse, @@ -252,9 +257,9 @@ def test_plot_all_amplitudes_distributions(self): possible_backends = list(sw.AllAmplitudesDistributionsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - unit_ids = self.we.unit_ids[:4] + unit_ids = self.we_dense.unit_ids[:4] sw.plot_all_amplitudes_distributions( - self.we, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] + self.we_dense, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] ) sw.plot_all_amplitudes_distributions( self.we_sparse, unit_ids=unit_ids, backend=backend, **self.backend_kwargs[backend] @@ -264,7 +269,9 @@ def test_unit_locations(self): possible_backends = list(sw.UnitLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_locations(self.we, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_locations( + self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] + ) sw.plot_unit_locations( self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) @@ -273,7 +280,9 @@ def test_spike_locations(self): possible_backends = list(sw.SpikeLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_spike_locations(self.we, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) + sw.plot_spike_locations( + self.we_dense, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] + ) sw.plot_spike_locations( self.we_sparse, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) @@ -282,28 +291,28 @@ def test_similarity(self): possible_backends = list(sw.TemplateSimilarityWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_template_similarity(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_similarity(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_template_similarity(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_quality_metrics(self): possible_backends = list(sw.QualityMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_quality_metrics(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_quality_metrics(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_quality_metrics(self.we_sparse, backend=backend, 
**self.backend_kwargs[backend]) def test_template_metrics(self): possible_backends = list(sw.TemplateMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_template_metrics(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_template_metrics(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_template_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_depths(self): possible_backends = list(sw.UnitDepthsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_unit_depths(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_unit_depths(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_unit_depths(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_summary(self): @@ -311,17 +320,17 @@ def test_plot_unit_summary(self): for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_summary( - self.we, self.we.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] + self.we_dense, self.we_dense.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] ) sw.plot_unit_summary( - self.we_sparse, self.we.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] + self.we_sparse, self.we_sparse.sorting.unit_ids[0], backend=backend, **self.backend_kwargs[backend] ) def test_sorting_summary(self): possible_backends = list(sw.SortingSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_sorting_summary(self.we, backend=backend, **self.backend_kwargs[backend]) + sw.plot_sorting_summary(self.we_dense, backend=backend, **self.backend_kwargs[backend]) sw.plot_sorting_summary(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_agreement_matrix(self): @@ -369,10 +378,10 @@ def test_plot_rasters(self): # mytest.test_quality_metrics() # mytest.test_template_metrics() # mytest.test_amplitudes() - # mytest.test_plot_agreement_matrix() + mytest.test_plot_agreement_matrix() # mytest.test_plot_confusion_matrix() # mytest.test_plot_probe_map() - mytest.test_plot_rasters() + # mytest.test_plot_rasters() # plt.ion() plt.show() diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 9b6716e8f3..fc8b30eb05 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -88,26 +88,32 @@ def __init__( else: raise ValueError("plot_traces recording must be recording or dict or list") - layer_keys = list(recordings.keys()) + if rec0.has_channel_location(): + channel_locations = rec0.get_channel_locations() + else: + channel_locations = None - if segment_index is None: - if rec0.get_num_segments() != 1: - raise ValueError("You must provide segment_index=...") - segment_index = 0 + if order_channel_by_depth and channel_locations is not None: + from ..preprocessing import depth_order + + rec0 = depth_order(rec0) + recordings = {k: depth_order(rec) for k, rec in recordings.items()} + + if channel_ids is not None: + # ensure that channel_ids are in the right order + channel_ids_ = list(rec0.channel_ids) + order = np.argsort([channel_ids_.index(c) for c in channel_ids]) + channel_ids = list(np.array(channel_ids)[order]) if channel_ids is None: channel_ids = rec0.channel_ids - if "location" in rec0.get_property_keys(): - channel_locations = 
rec0.get_channel_locations() - else: - channel_locations = None + layer_keys = list(recordings.keys()) - if order_channel_by_depth: - if channel_locations is not None: - order, _ = order_channels_by_depth(rec0, channel_ids) - else: - order = None + if segment_index is None: + if rec0.get_num_segments() != 1: + raise ValueError("You must provide segment_index=...") + segment_index = 0 fs = rec0.get_sampling_frequency() if time_range is None: @@ -124,7 +130,7 @@ def __init__( cmap = cmap times, list_traces, frame_range, channel_ids = _get_trace_list( - recordings, channel_ids, time_range, segment_index, order, return_scaled + recordings, channel_ids, time_range, segment_index, return_scaled=return_scaled ) # stat for auto scaling done on the first layer @@ -138,9 +144,10 @@ def __init__( # colors is a nested dict by layer and channels # lets first create black for all channels and layer + # all colors are generated for ipywidgets colors = {} for k in layer_keys: - colors[k] = {chan_id: "k" for chan_id in channel_ids} + colors[k] = {chan_id: "k" for chan_id in rec0.channel_ids} if color_groups: channel_groups = rec0.get_channel_groups(channel_ids=channel_ids) @@ -149,7 +156,7 @@ def __init__( group_colors = get_some_colors(groups, color_engine="auto") channel_colors = {} - for i, chan_id in enumerate(channel_ids): + for i, chan_id in enumerate(rec0.channel_ids): group = channel_groups[i] channel_colors[chan_id] = group_colors[group] @@ -159,12 +166,12 @@ def __init__( elif color is not None: # old behavior one color for all channel # if multi layer then black for all - colors[layer_keys[0]] = {chan_id: color for chan_id in channel_ids} + colors[layer_keys[0]] = {chan_id: color for chan_id in rec0.channel_ids} elif color is None and len(recordings) > 1: # several layer layer_colors = get_some_colors(layer_keys) for k in layer_keys: - colors[k] = {chan_id: layer_colors[k] for chan_id in channel_ids} + colors[k] = {chan_id: layer_colors[k] for chan_id in rec0.channel_ids} else: # color is None unique layer : all channels black pass @@ -201,7 +208,6 @@ def __init__( show_channel_ids=show_channel_ids, add_legend=add_legend, order_channel_by_depth=order_channel_by_depth, - order=order, tile_size=tile_size, num_timepoints_per_row=int(seconds_per_row * fs), return_scaled=return_scaled, @@ -336,6 +342,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) self.scaler = ScaleWidget() self.channel_selector = ChannelSelector(self.rec0.channel_ids) + self.channel_selector.value = list(data_plot["channel_ids"]) left_sidebar = W.VBox( children=[ @@ -398,17 +405,17 @@ def _mode_changed(self, change=None): def _retrieve_traces(self, change=None): channel_ids = np.array(self.channel_selector.value) - if self.data_plot["order_channel_by_depth"]: - order, _ = order_channels_by_depth(self.rec0, channel_ids) - else: - order = None + # if self.data_plot["order_channel_by_depth"]: + # order, _ = order_channels_by_depth(self.rec0, channel_ids) + # else: + # order = None start_frame, end_frame, segment_index = self.time_slider.value time_range = np.array([start_frame, end_frame]) / self.rec0.sampling_frequency self._selected_recordings = {k: self.recordings[k] for k in self._get_layers()} times, list_traces, frame_range, channel_ids = _get_trace_list( self._selected_recordings, channel_ids, time_range, segment_index, return_scaled=self.return_scaled ) self._channel_ids = channel_ids @@ -523,7 +530,7 @@ def 
plot_ephyviewer(self, data_plot, **backend_kwargs): app.exec() -def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False): +def _get_trace_list(recordings, channel_ids, time_range, segment_index, return_scaled=False): # function also used in ipywidgets plotter k0 = list(recordings.keys())[0] rec0 = recordings[k0] @@ -550,11 +557,6 @@ def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=No return_scaled=return_scaled, ) - if order is not None: - traces = traces[:, order] list_traces.append(traces) - if order is not None: - channel_ids = np.array(channel_ids)[order] - return times, list_traces, frame_range, channel_ids diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 6e872eca55..58dd5c7f32 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -235,8 +235,7 @@ def __init__(self, channel_ids, **kwargs): self.slider.observe(self.on_slider_changed, names=["value"], type="change") self.selector.observe(self.on_selector_changed, names=["value"], type="change") - # TODO external value change - # self.observe(self.value_changed, names=['value'], type="change") + self.observe(self.value_changed, names=["value"], type="change") def on_slider_changed(self, change=None): i0, i1 = self.slider.value @@ -260,6 +259,18 @@ def on_selector_changed(self, change=None): self.value = channel_ids + def value_changed(self, change=None): + self.selector.unobserve(self.on_selector_changed, names=["value"], type="change") + self.selector.value = change["new"] + self.selector.observe(self.on_selector_changed, names=["value"], type="change") + + channel_ids = self.selector.value + self.slider.unobserve(self.on_slider_changed, names=["value"], type="change") + i0 = self.channel_ids.index(channel_ids[0]) + i1 = self.channel_ids.index(channel_ids[-1]) + 1 + self.slider.value = (i0, i1) + self.slider.observe(self.on_slider_changed, names=["value"], type="change") + class ScaleWidget(W.VBox): value = traitlets.Float()
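
Note on the `extract_waveforms` default change above (`sparse=False` -> `sparse=True`, together with the shorter `ms_before=1.0` / `ms_after=2.0` cut-outs): waveforms are now sparse unless callers opt out, which is why the tests in this patch add explicit `sparse=False` arguments. A minimal sketch of the two call patterns, assuming the `generate_recording` / `generate_sorting` toy helpers from `spikeinterface.core` (any recording/sorting pair with matching sampling frequency would do):

```python
# Sketch only: illustrates the new extract_waveforms defaults introduced in this
# patch; generate_recording / generate_sorting are assumed toy generators.
from spikeinterface.core import extract_waveforms, generate_recording, generate_sorting

recording = generate_recording(num_channels=4, durations=[10.0])
sorting = generate_sorting(num_units=5, durations=[10.0])

# New default: sparse=True, so a ChannelSparsity is estimated from a few spikes
# (num_spikes_for_sparsity=100) and waveforms are stored sparse at extraction time.
we_sparse = extract_waveforms(recording, sorting, folder=None, mode="memory")
assert we_sparse.is_sparse()

# Callers that need dense waveforms (e.g. to run compute_sparsity themselves, as
# the updated Phy-export tests do) must now opt out explicitly.
we_dense = extract_waveforms(recording, sorting, folder=None, mode="memory", sparse=False)
assert not we_dense.is_sparse()
```

This is also why `precompute_sparsity()` now passes `sparse=False` to its internal `extract_waveforms` call: the sparsity-estimation pass itself has to extract dense waveforms, otherwise it would recurse into another sparsity estimation.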