
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jul 8, 2024
1 parent 9b6eb38 commit feb0ff8
Showing 8 changed files with 33 additions and 39 deletions.
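The changes below are automated formatting fixes (trailing-whitespace removal, blank-line normalization, black-style line wrapping) applied by the pre-commit.ci bot. A minimal sketch of reproducing such fixes locally, assuming pre-commit is installed and the repository ships a .pre-commit-config.yaml:

```python
import subprocess

# Run all configured hooks against the whole repository, as pre-commit.ci does.
# check=False because formatting hooks exit non-zero whenever they modify files.
subprocess.run(["pre-commit", "run", "--all-files"], check=False)
```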
34 changes: 18 additions & 16 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -8,6 +8,7 @@
It also implements:
* ComputeNoiseLevels which is very convenient to have
"""
+
import warnings

import numpy as np
@@ -82,7 +83,7 @@ def _merge_extension_data(
):
new_data = dict()
random_spikes_indices = self.data["random_spikes_indices"]
- if keep_mask is None:
+ if keep_mask is None:
new_data["random_spikes_indices"] = random_spikes_indices.copy()
else:
mask = keep_mask[random_spikes_indices]
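The hunk above shows the censoring branch of the merge: stored random-spike indices are filtered with keep_mask. The hunk is cut before the remapping step, so the sketch below is an assumption about how such indices would be shifted once censored spikes are dropped; all names and values are illustrative.

```python
import numpy as np

n_spikes = 10
random_spikes_indices = np.array([1, 4, 7, 9])
keep_mask = np.ones(n_spikes, dtype=bool)
keep_mask[[4, 5]] = False  # two spikes censored by the merge

mask = keep_mask[random_spikes_indices]   # which selected spikes survive
new_positions = np.cumsum(keep_mask) - 1  # index of each kept spike after filtering
new_indices = new_positions[random_spikes_indices[mask]]
print(new_indices)  # [1 5 7]
```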
@@ -266,9 +267,8 @@ def _merge_extension_data(
selection = np.flatnonzero(some_spikes["unit_index"] == unit_index)
group_selection.append(selection)
_inplace_sparse_realign_waveforms(waveforms, group_selection, group_sparsity_mask)

- return dict(waveforms=waveforms)
-
+ return dict(waveforms=waveforms)

# def _merge_extension_data(
# self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
@@ -435,9 +435,9 @@ def _merge_extension_data(
# wfs[:, :, channel_mask_channel_inds] = sparse_waveforms
# else:
# wfs = sparse_waveforms
- # some_waveforms[spike_mask] = wfs
- # return some_waveforms, selected_inds
+ # some_waveforms[spike_mask] = wfs
+
+ # return some_waveforms, selected_inds

def get_waveforms_one_unit(
self,
@@ -460,7 +460,7 @@ def get_waveforms_one_unit(
The waveforms (num_waveforms, num_samples, num_channels).
In case sparsity is used, only the waveforms on sparse channels are returned.
"""
-
+
sorting = self.sorting_analyzer.sorting
unit_index = sorting.id_to_index(unit_id)
some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()
@@ -491,13 +491,11 @@ def _inplace_sparse_realign_waveforms(waveforms, group_selection, group_sparsity_mask):
for i in range(len(group_selection)):
chan_mask = group_sparsity_mask[i, :]
sel = group_selection[i]
- wfs = waveforms[sel, :, :][:, :, :np.sum(chan_mask)]
+ wfs = waveforms[sel, :, :][:, :, : np.sum(chan_mask)]
keep_mask = common_mask[chan_mask]
wfs = wfs[:, :, keep_mask]
- waveforms[:, :, :wfs.shape[2]][sel, :, :] = wfs
- waveforms[:, :, wfs.shape[2]:][sel, :, :] = 0.
-
-
+ waveforms[:, :, : wfs.shape[2]][sel, :, :] = wfs
+ waveforms[:, :, wfs.shape[2] :][sel, :, :] = 0.0


compute_waveforms = ComputeWaveforms.function_factory()
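The slice changes above are black's spacing rule for slices with complex bounds (a space around ":" when one side is a call like np.sum(...)). Functionally, _inplace_sparse_realign_waveforms packs the channels kept after a merge to the front of the channel axis and zeroes the rest; a minimal runnable sketch of that pattern, not the library function itself:

```python
import numpy as np

waveforms = np.random.randn(10, 4, 6)  # (num_spikes, num_samples, num_channels)
chan_mask = np.array([True, True, False, True, False, False])    # this unit's channels
common_mask = np.array([True, False, True, True, False, False])  # channels kept by the merge
sel = np.arange(5)  # indices of the spikes belonging to this unit

wfs = waveforms[sel, :, :][:, :, : np.sum(chan_mask)]  # dense block for this unit
keep_mask = common_mask[chan_mask]                     # which of its channels survive
wfs = wfs[:, :, keep_mask]
waveforms[:, :, : wfs.shape[2]][sel, :, :] = wfs       # pack kept channels left
waveforms[:, :, wfs.shape[2] :][sel, :, :] = 0.0       # zero the tail in place
```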
@@ -561,7 +559,7 @@ def _run(self, verbose=False, **job_kwargs):

if self.sorting_analyzer.has_extension("waveforms"):
self._compute_and_append_from_waveforms(self.params["operators"])
-
+
else:
for operator in self.params["operators"]:
if operator not in ("average", "std"):
@@ -662,7 +660,8 @@ def nbefore(self):
warnings.warn(
"The 'nbefore' parameter is deprecated and it's been replaced by 'ms_before' in the params."
"You can save the sorting_analyzer to update the params.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)

nbefore = int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0)
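The reformatted warnings.warn call above is the usual deprecation pattern; stacklevel=2 makes the warning point at the caller rather than at the property itself. A self-contained sketch of the pattern (the helper name is hypothetical; the ms-to-samples conversion mirrors the line above):

```python
import warnings

def nbefore_from_ms(ms_before: float, sampling_frequency: float) -> int:
    # stacklevel=2 attributes the DeprecationWarning to whoever called
    # this helper, not to the helper itself.
    warnings.warn(
        "The 'nbefore' parameter is deprecated and it's been replaced by 'ms_before' in the params.",
        DeprecationWarning,
        stacklevel=2,
    )
    return int(ms_before * sampling_frequency / 1000.0)

print(nbefore_from_ms(1.0, 30000.0))  # -> 30
```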
@@ -675,7 +674,8 @@ def nafter(self):
warnings.warn(
"The 'nafter' parameter is deprecated and it's been replaced by 'ms_after' in the params."
"You can save the sorting_analyzer to update the params.",
- DeprecationWarning, stacklevel=2
+ DeprecationWarning,
+ stacklevel=2,
)
self.params["ms_after"] = self.params["nafter"] * 1000.0 / self.sorting_analyzer.sampling_frequency

@@ -712,7 +712,9 @@ def _merge_extension_data(
for count, merge_unit_id in enumerate(merge_group):
weights[count] = counts[merge_unit_id]
weights /= weights.sum()
- new_data[key][unit_index] = (arr[keep_unit_indices, :, :] * weights[:, np.newaxis, np.newaxis]).sum(0)
+ new_data[key][unit_index] = (arr[keep_unit_indices, :, :] * weights[:, np.newaxis, np.newaxis]).sum(
+     0
+ )
if new_sorting_analyzer.sparsity is not None:
chan_ids = new_sorting_analyzer.sparsity.unit_id_to_channel_indices[unit_id]
mask = ~np.isin(np.arange(arr.shape[2]), chan_ids)
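The rewrapped line above is a weighted average: templates of the merged units are combined with weights proportional to each unit's spike count, and channels outside the new unit's sparsity are then zeroed. A sketch of the averaging step with illustrative shapes and values:

```python
import numpy as np

arr = np.stack([np.full((5, 3), 1.0), np.full((5, 3), 3.0)])  # (units, samples, channels)
counts = np.array([100.0, 300.0])                             # spikes per merged unit

weights = counts / counts.sum()
merged = (arr * weights[:, np.newaxis, np.newaxis]).sum(0)
print(merged[0, 0])  # 0.25 * 1.0 + 0.75 * 3.0 = 2.5
```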
15 changes: 7 additions & 8 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -33,7 +33,6 @@
from .node_pipeline import run_node_pipeline


-
# high level function
def create_sorting_analyzer(
sorting,
@@ -654,15 +653,15 @@ def _save_or_select_or_merge(
folder : str | Path | None, default: None
The folder where the SortingAnalyzer object will be saved
unit_ids : list or None, default: None
- The unit ids to keep in the new SortingAnalyzer object. If `merge_unit_groups` is not None,
+ The unit ids to keep in the new SortingAnalyzer object. If `merge_unit_groups` is not None,
`unit_ids` must be given it must contain all unit_ids.
merge_unit_groups : list/tuple of lists/tuples or None, default: None
- A list of lists for every merge group. Each element needs to have at least two elements
+ A list of lists for every merge group. Each element needs to have at least two elements
(two units to merge). If `merge_unit_groups` is not None, `new_unit_ids` must be given.
censor_ms : None or float, default: None
When merging units, any spikes violating this refractory period will be discarded.
merging_mode : "soft" | "hard", default: "soft"
- How merges are performed. In the "soft" mode, merges will be approximated, with no smart merging
+ How merges are performed. In the "soft" mode, merges will be approximated, with no smart merging
of the extension data.
sparsity_overlap : float, default 0.75
The percentage of overlap that units should share in order to accept merges. If this criteria is not
@@ -704,9 +703,7 @@
if merging_mode == "soft":
intersection_mask = np.prod(self.sparsity.mask[merge_unit_indices], axis=0) > 0
thr = np.sum(intersection_mask) / np.sum(union_mask)
- assert (
-     thr > sparsity_overlap
- ), (
+ assert thr > sparsity_overlap, (
f"The sparsities of {current_merge_group} do not overlap enough for a soft merge using "
f"a sparsity threshold of {sparsity_overlap}. You can either lower the threshold or use "
"a hard merge."
@@ -918,7 +915,9 @@ def merge_units(
# keep backward compatibility : the previous behavior was only one merge
merge_unit_groups = [merge_unit_groups]

- new_unit_ids = generate_unit_ids_for_merge_group(self.unit_ids, merge_unit_groups, new_unit_ids, new_id_strategy)
+ new_unit_ids = generate_unit_ids_for_merge_group(
+     self.unit_ids, merge_unit_groups, new_unit_ids, new_id_strategy
+ )
all_unit_ids = _get_ids_after_merging(self.unit_ids, merge_unit_groups, new_unit_ids=new_unit_ids)

return self._save_or_select_or_merge(
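The wrapped generate_unit_ids_for_merge_group call sits just after a backward-compatibility shim: a single merge group is promoted to a list of groups. A hypothetical helper sketching only that wrapping logic, not the library code:

```python
def normalize_merge_groups(merge_unit_groups):
    # A flat group like ["u1", "u2"] becomes [["u1", "u2"]]; a list of
    # groups passes through unchanged.
    if not isinstance(merge_unit_groups[0], (list, tuple)):
        merge_unit_groups = [merge_unit_groups]
    return merge_unit_groups

print(normalize_merge_groups(["u1", "u2"]))                  # [['u1', 'u2']]
print(normalize_merge_groups([["u1", "u2"], ["u3", "u4"]]))  # unchanged
```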
1 change: 0 additions & 1 deletion src/spikeinterface/core/tests/test_sortinganalyzer.py
@@ -142,7 +142,6 @@ def test_SortingAnalyzer_tmp_recording(dataset):
sorting_analyzer.set_temporary_recording(recording_sliced)


-
def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):

register_result_extension(DummyAnalyzerExtension)
4 changes: 2 additions & 2 deletions src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -122,9 +122,9 @@ def _merge_extension_data(
if self.params["handle_collisions"]:
new_data["collision_mask"] = self.data["collision_mask"].copy()
else:
new_data["amplitude_scalings"] = self.data["amplitude_scalings"][keep_mask]
new_data["amplitude_scalings"] = self.data["amplitude_scalings"][keep_mask]
if self.params["handle_collisions"]:
new_data["collision_mask"] = self.data["collision_mask"][keep_mask]
new_data["collision_mask"] = self.data["collision_mask"][keep_mask]

return new_data

11 changes: 4 additions & 7 deletions src/spikeinterface/postprocessing/principal_component.py
@@ -102,11 +102,11 @@ def _select_extension_data(self, unit_ids):
if "model" in k:
new_data[k] = v
return new_data
-
+
def _merge_extension_data(
self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
):

pca_projections = self.data["pca_projection"]
some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()

@@ -120,7 +120,7 @@ def _merge_extension_data(

old_sparsity = self.sorting_analyzer.sparsity
if old_sparsity is not None:
-
+
# we need a realignement inside each group because we take the channel intersection sparsity
# the story is same as in "waveforms" extension
for group_ids in merge_unit_groups:
@@ -131,9 +131,8 @@ def _merge_extension_data(
unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id)
selection = np.flatnonzero(some_spikes["unit_index"] == unit_index)
group_selection.append(selection)

- _inplace_sparse_realign_waveforms(pca_projections, group_selection, group_sparsity_mask)
-
+ _inplace_sparse_realign_waveforms(pca_projections, group_selection, group_sparsity_mask)

new_data = dict(pca_projections=pca_projections)

@@ -143,7 +142,6 @@ def _merge_extension_data(
new_data[k] = v
return new_data

-
# def _merge_extension_data(
# self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
# ):
@@ -180,7 +178,6 @@ def _merge_extension_data(
# new_data[k] = v
# return new_data

-
def get_pca_model(self):
"""
Returns the scikit-learn PCA model objects.
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -85,7 +85,7 @@ def _merge_extension_data(
self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
):
new_data = dict()
-
+
if keep_mask is None:
new_data["amplitudes"] = self.data["amplitudes"].copy()
else:
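This hunk only normalizes a blank line, but it shows the common keep_mask pattern for per-spike data during merges: copy when no censoring is requested, filter otherwise. A small sketch with illustrative values:

```python
import numpy as np

amplitudes = np.array([1.2, 0.8, 1.5, 0.9])
keep_mask = np.array([True, False, True, True])  # False = spike censored by the merge

new_amplitudes = amplitudes.copy() if keep_mask is None else amplitudes[keep_mask]
print(new_amplitudes)  # [1.2 1.5 0.9]
```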
3 changes: 0 additions & 3 deletions src/spikeinterface/postprocessing/template_metrics.py
@@ -160,15 +160,12 @@ def _merge_extension_data(
):
import pandas as pd

-
old_metrics = self.data["metrics"]

all_unit_ids = new_sorting_analyzer.unit_ids
not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids)]

metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)

-
-
metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs)
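The metrics table is rebuilt around the merge: rows for untouched units are copied from the old DataFrame, while rows for the newly created units are recomputed. A sketch with placeholder metric values standing in for _compute_metrics:

```python
import numpy as np
import pandas as pd

old_metrics = pd.DataFrame({"snr": [5.0, 3.2, 4.1]}, index=["u0", "u1", "u2"])
all_unit_ids = np.array(["u0", "u3"])  # u1 and u2 were merged into u3
new_unit_ids = np.array(["u3"])
not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids)]

metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)
metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
metrics.loc[new_unit_ids, :] = 3.6  # placeholder for _compute_metrics(...)
print(metrics)
```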
2 changes: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ def _merge_extension_data(
metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)

metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
- metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs)
+ metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs)

new_data = dict(metrics=metrics)
return new_data
