diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py
index 2aa89555f2..8c9620bfa4 100644
--- a/src/spikeinterface/core/analyzer_extension_core.py
+++ b/src/spikeinterface/core/analyzer_extension_core.py
@@ -8,6 +8,7 @@
 It also implements:
   * ComputeNoiseLevels which is very convenient to have
 """
+import warnings
 
 import numpy as np
@@ -82,7 +83,7 @@ def _merge_extension_data(
     ):
         new_data = dict()
         random_spikes_indices = self.data["random_spikes_indices"]
-        if keep_mask is None:
+        if keep_mask is None:
             new_data["random_spikes_indices"] = random_spikes_indices.copy()
         else:
             mask = keep_mask[random_spikes_indices]
@@ -266,9 +267,8 @@
                 selection = np.flatnonzero(some_spikes["unit_index"] == unit_index)
                 group_selection.append(selection)
             _inplace_sparse_realign_waveforms(waveforms, group_selection, group_sparsity_mask)
-
-        return dict(waveforms=waveforms)
+        return dict(waveforms=waveforms)
 
     # def _merge_extension_data(
     #     self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
@@ -435,9 +435,9 @@
     #             wfs[:, :, channel_mask_channel_inds] = sparse_waveforms
     #         else:
     #             wfs = sparse_waveforms
-    #         some_waveforms[spike_mask] = wfs
-
-    #     return some_waveforms, selected_inds
+    #         some_waveforms[spike_mask] = wfs
+
+    #     return some_waveforms, selected_inds
 
     def get_waveforms_one_unit(
         self,
@@ -460,7 +460,7 @@
             The waveforms (num_waveforms, num_samples, num_channels).
             In case sparsity is used, only the waveforms on sparse channels are returned.
         """
-
+
         sorting = self.sorting_analyzer.sorting
         unit_index = sorting.id_to_index(unit_id)
         some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()
@@ -491,13 +491,11 @@ def _inplace_sparse_realign_waveforms(waveforms, group_selection, group_sparsity_mask):
     for i in range(len(group_selection)):
         chan_mask = group_sparsity_mask[i, :]
         sel = group_selection[i]
-        wfs = waveforms[sel, :, :][:, :, :np.sum(chan_mask)]
+        wfs = waveforms[sel, :, :][:, :, : np.sum(chan_mask)]
         keep_mask = common_mask[chan_mask]
         wfs = wfs[:, :, keep_mask]
-        waveforms[:, :, :wfs.shape[2]][sel, :, :] = wfs
-        waveforms[:, :, wfs.shape[2]:][sel, :, :] = 0.
-
-
+        waveforms[:, :, : wfs.shape[2]][sel, :, :] = wfs
+        waveforms[:, :, wfs.shape[2] :][sel, :, :] = 0.0
 
 
 compute_waveforms = ComputeWaveforms.function_factory()
@@ -561,7 +559,7 @@ def _run(self, verbose=False, **job_kwargs):
 
         if self.sorting_analyzer.has_extension("waveforms"):
             self._compute_and_append_from_waveforms(self.params["operators"])
-
+
         else:
             for operator in self.params["operators"]:
                 if operator not in ("average", "std"):
@@ -662,7 +660,8 @@ def nbefore(self):
             warnings.warn(
                 "The 'nbefore' parameter is deprecated and it's been replaced by 'ms_before' in the params."
                 "You can save the sorting_analyzer to update the params.",
-                DeprecationWarning, stacklevel=2
+                DeprecationWarning,
+                stacklevel=2,
             )
 
         nbefore = int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0)
@@ -675,7 +674,8 @@ def nafter(self):
             warnings.warn(
                 "The 'nafter' parameter is deprecated and it's been replaced by 'ms_after' in the params."
                 "You can save the sorting_analyzer to update the params.",
-                DeprecationWarning, stacklevel=2
+                DeprecationWarning,
+                stacklevel=2,
             )
             self.params["ms_after"] = self.params["nafter"] * 1000.0 / self.sorting_analyzer.sampling_frequency
 
@@ -712,7 +712,9 @@
                 for count, merge_unit_id in enumerate(merge_group):
                     weights[count] = counts[merge_unit_id]
                 weights /= weights.sum()
-                new_data[key][unit_index] = (arr[keep_unit_indices, :, :] * weights[:, np.newaxis, np.newaxis]).sum(0)
+                new_data[key][unit_index] = (arr[keep_unit_indices, :, :] * weights[:, np.newaxis, np.newaxis]).sum(
+                    0
+                )
                 if new_sorting_analyzer.sparsity is not None:
                     chan_ids = new_sorting_analyzer.sparsity.unit_id_to_channel_indices[unit_id]
                     mask = ~np.isin(np.arange(arr.shape[2]), chan_ids)
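Note on the ComputeTemplates change above: a merged unit's template is the spike-count-weighted average of the member templates. A minimal standalone sketch of that arithmetic (unit ids, counts, and shapes are hypothetical, not the library's code):

import numpy as np

counts = {"u1": 300, "u2": 100}  # hypothetical spike counts of the two merged units
arr = np.random.randn(2, 90, 4)  # (num_merged_units, num_samples, num_channels)

weights = np.array([counts["u1"], counts["u2"]], dtype=float)
weights /= weights.sum()  # -> [0.75, 0.25]
merged_template = (arr * weights[:, np.newaxis, np.newaxis]).sum(0)
assert merged_template.shape == (90, 4)  # one template, dense over all channels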
"You can save the sorting_analyzer to update the params.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) self.params["ms_after"] = self.params["nafter"] * 1000.0 / self.sorting_analyzer.sampling_frequency @@ -712,7 +712,9 @@ def _merge_extension_data( for count, merge_unit_id in enumerate(merge_group): weights[count] = counts[merge_unit_id] weights /= weights.sum() - new_data[key][unit_index] = (arr[keep_unit_indices, :, :] * weights[:, np.newaxis, np.newaxis]).sum(0) + new_data[key][unit_index] = (arr[keep_unit_indices, :, :] * weights[:, np.newaxis, np.newaxis]).sum( + 0 + ) if new_sorting_analyzer.sparsity is not None: chan_ids = new_sorting_analyzer.sparsity.unit_id_to_channel_indices[unit_id] mask = ~np.isin(np.arange(arr.shape[2]), chan_ids) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 35f9787111..a1ddebb740 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -33,7 +33,6 @@ from .node_pipeline import run_node_pipeline - # high level function def create_sorting_analyzer( sorting, @@ -654,15 +653,15 @@ def _save_or_select_or_merge( folder : str | Path | None, default: None The folder where the SortingAnalyzer object will be saved unit_ids : list or None, default: None - The unit ids to keep in the new SortingAnalyzer object. If `merge_unit_groups` is not None, + The unit ids to keep in the new SortingAnalyzer object. If `merge_unit_groups` is not None, `unit_ids` must be given it must contain all unit_ids. merge_unit_groups : list/tuple of lists/tuples or None, default: None - A list of lists for every merge group. Each element needs to have at least two elements + A list of lists for every merge group. Each element needs to have at least two elements (two units to merge). If `merge_unit_groups` is not None, `new_unit_ids` must be given. censor_ms : None or float, default: None When merging units, any spikes violating this refractory period will be discarded. merging_mode : "soft" | "hard", default: "soft" - How merges are performed. In the "soft" mode, merges will be approximated, with no smart merging + How merges are performed. In the "soft" mode, merges will be approximated, with no smart merging of the extension data. sparsity_overlap : float, default 0.75 The percentage of overlap that units should share in order to accept merges. If this criteria is not @@ -704,9 +703,7 @@ def _save_or_select_or_merge( if merging_mode == "soft": intersection_mask = np.prod(self.sparsity.mask[merge_unit_indices], axis=0) > 0 thr = np.sum(intersection_mask) / np.sum(union_mask) - assert ( - thr > sparsity_overlap - ), ( + assert thr > sparsity_overlap, ( f"The sparsities of {current_merge_group} do not overlap enough for a soft merge using " f"a sparsity threshold of {sparsity_overlap}. You can either lower the threshold or use " "a hard merge." 
diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py
index d2fbbf9227..591250f72c 100644
--- a/src/spikeinterface/core/tests/test_sortinganalyzer.py
+++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py
@@ -142,7 +142,6 @@ def test_SortingAnalyzer_tmp_recording(dataset):
 
     sorting_analyzer.set_temporary_recording(recording_sliced)
 
-
 def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):
 
     register_result_extension(DummyAnalyzerExtension)
diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py
index 1b50d225b2..0d57aec21e 100644
--- a/src/spikeinterface/postprocessing/amplitude_scalings.py
+++ b/src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -122,9 +122,9 @@ def _merge_extension_data(
             if self.params["handle_collisions"]:
                 new_data["collision_mask"] = self.data["collision_mask"].copy()
         else:
-            new_data["amplitude_scalings"] = self.data["amplitude_scalings"][keep_mask]
+            new_data["amplitude_scalings"] = self.data["amplitude_scalings"][keep_mask]
             if self.params["handle_collisions"]:
-                new_data["collision_mask"] = self.data["collision_mask"][keep_mask]
+                new_data["collision_mask"] = self.data["collision_mask"][keep_mask]
 
         return new_data
diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py
index f2a7adb0d1..1066b1ff2c 100644
--- a/src/spikeinterface/postprocessing/principal_component.py
+++ b/src/spikeinterface/postprocessing/principal_component.py
@@ -102,11 +102,11 @@ def _select_extension_data(self, unit_ids):
             if "model" in k:
                 new_data[k] = v
         return new_data
-
+
     def _merge_extension_data(
         self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
     ):
-
+
         pca_projections = self.data["pca_projection"]
 
         some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()
@@ -120,7 +120,7 @@
 
         old_sparsity = self.sorting_analyzer.sparsity
         if old_sparsity is not None:
-
+
             # we need a realignement inside each group because we take the channel intersection sparsity
             # the story is same as in "waveforms" extension
             for group_ids in merge_unit_groups:
@@ -131,9 +131,8 @@
                 unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id)
                 selection = np.flatnonzero(some_spikes["unit_index"] == unit_index)
                 group_selection.append(selection)
-
-            _inplace_sparse_realign_waveforms(pca_projections, group_selection, group_sparsity_mask)
+            _inplace_sparse_realign_waveforms(pca_projections, group_selection, group_sparsity_mask)
 
         new_data = dict(pca_projections=pca_projections)
@@ -143,7 +142,6 @@
                 new_data[k] = v
         return new_data
 
-
     # def _merge_extension_data(
     #     self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
     # ):
@@ -180,7 +178,6 @@
     #         new_data[k] = v
    #     return new_data
 
-
     def get_pca_model(self):
         """
         Returns the scikit-learn PCA model objects.
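Note: the principal_component change reuses _inplace_sparse_realign_waveforms (defined in analyzer_extension_core.py above) on pca_projections. A self-contained sketch of the realignment idea, with made-up masks and shapes: each unit's rows, stored compactly on its own sparse channels, are rewritten onto the intersection of the merge group's channel masks, and the unused tail is zeroed:

import numpy as np

group_sparsity_mask = np.array([[True, True, True, False],   # unit A's channels
                                [True, False, True, True]])  # unit B's channels
common_mask = group_sparsity_mask.all(axis=0)  # channels shared by the whole group

waveforms = np.random.randn(10, 60, 4)  # (num_spikes, num_samples, max_channels)
group_selection = [np.arange(0, 5), np.arange(5, 10)]  # rows belonging to each unit

for sel, chan_mask in zip(group_selection, group_sparsity_mask):
    n_active = int(np.sum(chan_mask))
    wfs = waveforms[sel][:, :, :n_active]  # this unit's compactly stored channels
    keep = common_mask[chan_mask]  # which of those survive the intersection
    wfs = wfs[:, :, keep]
    waveforms[sel, :, : wfs.shape[2]] = wfs  # realign onto the shared leading block
    waveforms[sel, :, wfs.shape[2] :] = 0.0  # zero the now-unused trailing channels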
diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py
index aec3f7ca84..e82a9e61e4 100644
--- a/src/spikeinterface/postprocessing/spike_amplitudes.py
+++ b/src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -85,7 +85,7 @@ def _merge_extension_data(
         self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
     ):
         new_data = dict()
-
+
         if keep_mask is None:
             new_data["amplitudes"] = self.data["amplitudes"].copy()
         else:
diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py
index 8add8c257c..3ca79105ae 100644
--- a/src/spikeinterface/postprocessing/template_metrics.py
+++ b/src/spikeinterface/postprocessing/template_metrics.py
@@ -160,15 +160,12 @@ def _merge_extension_data(
     ):
         import pandas as pd
 
-
         old_metrics = self.data["metrics"]
 
         all_unit_ids = new_sorting_analyzer.unit_ids
         not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids)]
 
         metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)
-
-
         metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
         metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs)
diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
index 156ebdf32e..3f67be3da1 100644
--- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
@@ -99,7 +99,7 @@ def _merge_extension_data(
         metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)
 
         metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
-        metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs)
+        metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs)
 
         new_data = dict(metrics=metrics)
         return new_data
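Note: template_metrics and quality_metric_calculator above share one merge pattern for metric tables: keep the old rows for untouched units and recompute rows only for the merged units. A small pandas sketch with made-up values, where the literal row stands in for the _compute_metrics(...) call:

import numpy as np
import pandas as pd

old_metrics = pd.DataFrame({"snr": [5.1, 4.2, 6.3]}, index=["u1", "u2", "u3"])
all_unit_ids = np.array(["u3", "u4"])  # hypothetical: u1 + u2 were merged into u4
new_unit_ids = np.array(["u4"])

not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids)]  # ["u3"]
metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)
metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]  # carry over kept units
metrics.loc[new_unit_ids, :] = [[4.8]]  # stand-in for the recomputed merged-unit row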