diff --git a/doc/images/spikeinterface_gui.png b/doc/images/spikeinterface_gui.png
new file mode 100644
index 0000000000..6afc7f762d
Binary files /dev/null and b/doc/images/spikeinterface_gui.png differ
diff --git a/doc/modules/core.rst b/doc/modules/core.rst
index e993d0120b..81787e7f7b 100644
--- a/doc/modules/core.rst
+++ b/doc/modules/core.rst
@@ -377,6 +377,20 @@ backends without writing to disk. So, you can compute an extension *in-memory* w
 you have decided on your desired parameters you can either use :code:`compute` with :code:`save=True` or use
 :code:`save_as` to write everything out to disk.
+
+Finally, the :code:`SortingAnalyzer` object can be used directly to curate a spike sorting output by selecting/removing units
+and merging unit groups.
+
+.. code-block:: python
+
+    sorting_analyzer_select = sorting_analyzer.select_units(unit_ids=[0, 1, 2, 3])
+    sorting_analyzer_remove = sorting_analyzer.remove_units(remove_unit_ids=[0])
+    sorting_analyzer_merge = sorting_analyzer.merge_units(merge_unit_groups=[[0, 1], [2, 3]])
+
+All computed extensions will be automatically propagated or merged when curating. Please refer to the
+:ref:`modules/curation` documentation for more information.
+
+
 Event
 -----
diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst
index 45e6fb9ae8..3cdf5c170b 100644
--- a/doc/modules/curation.rst
+++ b/doc/modules/curation.rst
@@ -1,48 +1,162 @@
+.. _curation:
 Curation module
 ===============
-**Note:** As of February 2023, this module is still under construction and quite experimental.
-The API of some of the functions could be changed/improved from time to time.
+.. note::
+    As of July 2024, this module is still under construction and quite experimental.
+    The API of some of the functions could be changed/improved from time to time.
-Manual curation
----------------
-SpikeInterface offers machinery to manually curate a sorting output and keep track of the curation history.
-The curation has several "steps" that can be repeated and chained:
+Curation with the ``SortingAnalyzer``
+-------------------------------------
- * remove/select units
- * split units
- * merge units
+The :py:class:`~spikeinterface.core.SortingAnalyzer`, as seen in previous modules,
+is a powerful tool to postprocess the spike sorting output, as it can compute many
+extensions and metrics to further characterize the spike sorting results.
-This functionality is done with :py:class:`~spikeinterface.curation.CurationSorting` class.
-Internally, this class keeps the history of curation as a graph.
-The merging and splitting operations are handled by the :py:class:`~spikeinterface.curation.MergeUnitsSorting` and
-:py:class:`~spikeinterface.curation.SplitUnitSorting`. These two classes can also be used independently.
+To facilitate the spike sorting workflow, the :py:class:`~spikeinterface.core.SortingAnalyzer`
+can also be used to perform curation tasks, such as removing bad units or merging similar units.
+Here's an example of how to use the :py:class:`~spikeinterface.core.SortingAnalyzer` to remove
+a subset of units from a spike sorting output and to perform some merges:
.. 
code-block:: python
-    from spikeinterface.curation import CurationSorting
+    from spikeinterface import create_sorting_analyzer
-    sorting = run_sorter(sorter_name='kilosort2', recording=recording)
+    sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording)
-    cs = CurationSorting(parent_sorting=sorting)
+    # compute some extensions
+    sorting_analyzer.compute(["random_spikes", "templates", "template_similarity", "correlograms"])
-    # make a first merge
-    cs.merge(units_to_merge=['#1', '#5', '#15'])
+    # remove some units
+    remove_unit_ids = [1, 2]
+    sorting_analyzer2 = sorting_analyzer.remove_units(remove_unit_ids=remove_unit_ids)
-    # make a second merge
-    cs.merge(units_to_merge=['#11', '#21'])
+    # merge some units
+    merge_unit_groups = [[4, 5], [7, 8, 12]]
+    sorting_analyzer3 = sorting_analyzer2.merge_units(
+        merge_unit_groups=merge_unit_groups,
+        censor_ms=0.5,
+        merging_mode="soft"
+    )
-    # make a split
-    split_index = ... # some criteria on spikes
-    cs.split(split_unit_id='#20', indices_list=split_index)
-    # here is the final clean sorting
-    clean_sorting = cs.sorting
+Importantly, all the extensions that were computed on the original :py:class:`~spikeinterface.core.SortingAnalyzer`
+(``sorting_analyzer``) are automatically propagated to the returned new
+:py:class:`~spikeinterface.core.SortingAnalyzer` objects (``sorting_analyzer2``, ``sorting_analyzer3``).
+
+In particular, the merging step supports a few interesting and useful options.
+If ``censor_ms`` is set, the function will remove spikes that are too close in time after the merge
+(in the case above, closer than 0.5 ms).
+The ``merging_mode`` parameter can be set to ``"soft"`` (default) or ``"hard"``. The ``"soft"`` mode will
+try to smartly combine the existing extension data (e.g. templates, template similarity, etc.)
+to estimate the merged units' data, when possible. This is the fastest mode, but it can be less accurate.
+The ``"hard"`` mode will simply merge the spike trains of the units and recompute the extensions on the
+merged spike train. This is more accurate but slower, especially for the extensions that need to traverse the
+raw data (e.g., spike amplitudes, spike locations, amplitude scalings, etc.).
+
+
+Automatic curation tools
+------------------------
+
+The :code:`spikeinterface.curation` module provides several automatic curation tools to clean spike sorting outputs.
+Many of them are ported, adapted, or inspired by `Lussac `_
+([Llobet]_).
+
+
+Remove duplicated spikes and redundant units
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The curation module has some convenient functions that allow you to remove redundant
+units and duplicated spikes from the sorting output.
+
+The :py:func:`~spikeinterface.curation.remove_duplicated_spikes` function removes
+duplicated spikes from the sorting output. Duplicated spikes are spikes that
+occur within a certain time window for the same unit.
+
+.. code-block:: python
+
+    from spikeinterface.curation import remove_duplicated_spikes
+
+    # remove duplicated spikes from BaseSorting object
+    clean_sorting = remove_duplicated_spikes(sorting, censored_period_ms=0.1)
+
+The ``censored_period_ms`` parameter is the time window in milliseconds to consider two spikes as duplicated.
+
+The :py:func:`~spikeinterface.curation.remove_redundant_units` function removes
+redundant units from the sorting output. Redundant units are units that share over
+a certain percentage of spikes, by default 80%.
+The function can act on both a ``BaseSorting`` and a ``SortingAnalyzer`` object.
+
+.. code-block:: python
+
+    from spikeinterface.curation import remove_redundant_units
+
+    # remove redundant units from BaseSorting object
+    clean_sorting = remove_redundant_units(
+        sorting,
+        duplicate_threshold=0.9,
+        remove_strategy="max_spikes"
+    )
+
+    # remove redundant units from SortingAnalyzer object
+    clean_sorting_analyzer = remove_redundant_units(
+        sorting_analyzer,
+        duplicate_threshold=0.9,
+        remove_strategy="min_shift"
+    )
+
+We recommend using the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps
+the unit (among the redundant ones) with the best template alignment.
+
+
+Auto-merging units
+^^^^^^^^^^^^^^^^^^
+
+The :py:func:`~spikeinterface.curation.get_potential_auto_merge` function returns a list of potential merges.
+The list of potential merges can then be applied to the sorting output.
+:py:func:`~spikeinterface.curation.get_potential_auto_merge` has many internal tricks and steps to identify potential
+merges. It offers multiple "presets" and the flexibility to apply individual steps, with different parameters.
+**Read the function documentation carefully and do not apply it blindly!**
+
+
+.. code-block:: python
+
+    from spikeinterface import create_sorting_analyzer
+    from spikeinterface.curation import get_potential_auto_merge
+
+    analyzer = create_sorting_analyzer(sorting=sorting, recording=recording)
+
+    # some extensions are required
+    analyzer.compute(["random_spikes", "templates", "template_similarity", "correlograms"])
-Manual curation format
-----------------------
+    # merge_unit_pairs is a list of unit pairs, with unit_ids to be merged.
+    merge_unit_pairs = get_potential_auto_merge(
+        analyzer=analyzer,
+        preset="similarity_correlograms",
+    )
+    # with resolve_graph=True, merge_unit_groups is a list of merge groups,
+    # which can contain more than two units
+    merge_unit_groups = get_potential_auto_merge(
+        analyzer=analyzer,
+        preset="similarity_correlograms",
+        resolve_graph=True
+    )
+
+    # here we apply the merges
+    analyzer_merged = analyzer.merge_units(merge_unit_groups=merge_unit_groups)
+
+
+Manual curation
+---------------
+
+While automatic curation tools can be very useful, manual curation is still widely used to
+clean spike sorting outputs and it is sometimes necessary to have a human in the loop.
+
+
+Curation format
+^^^^^^^^^^^^^^^
 SpikeInterface internally supports a JSON-based manual curation format.
 When manual curation is necessary, modifying a dataset in place is a bad practice.
@@ -62,7 +176,7 @@ This format has two part:
   * **manual output** curation with the folowing keys:
     * "manual_labels"
-    * "merged_unit_groups"
+    * "merge_unit_groups"
     * "removed_units"
 Here is the description of the format with a simple example (the first part of the
@@ -128,7 +242,7 @@ format is the definition; the second part of the format is manual action):
             ]
         }
     ],
-    "merged_unit_groups": [
+    "merge_unit_groups": [
         [
             "u3",
             "u6"
@@ -146,43 +260,54 @@ format is the definition; the second part of the format is manual action):
 }
+.. note::
+    The curation format was recently introduced (v0.101.0), and we are still working on
+    properly integrating it into the SpikeInterface ecosystem.
+    Soon there will be functions available in the curation module to apply this
+    standardized curation format to ``SortingAnalyzer`` and ``BaseSorting`` objects.
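+
+Until those functions are available, a curation dictionary in this format can already be applied
+by hand with the ``SortingAnalyzer`` methods described above. The sketch below is only illustrative
+and makes a few assumptions: ``curation.json`` is a placeholder file name, the dictionary follows the
+format defined in this section, the ``validate_curation_dict`` helper lives in
+``spikeinterface.curation.curation_format``, and ``manual_labels`` are left to be handled separately
+(e.g. stored as unit properties).
+
+.. code-block:: python
+
+    import json
+
+    from spikeinterface.curation.curation_format import validate_curation_dict
+
+    # "curation.json" is a placeholder for a file following the curation format above
+    with open("curation.json", "r") as f:
+        curation_dict = json.load(f)
+    validate_curation_dict(curation_dict)
+
+    # apply the merges first, then drop the units flagged for removal
+    if len(curation_dict["merge_unit_groups"]) > 0:
+        sorting_analyzer = sorting_analyzer.merge_units(
+            merge_unit_groups=curation_dict["merge_unit_groups"]
+        )
+    if len(curation_dict["removed_units"]) > 0:
+        sorting_analyzer = sorting_analyzer.remove_units(
+            remove_unit_ids=curation_dict["removed_units"]
+        )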
-Automatic curation tools ------------------------- -`Lussac `_ is an external package with several strategies -for automatic curation of a spike sorting output. +Using the ``SpikeInterface GUI`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Some of them, like the auto-merging, have been ported to SpikeInterface. -The :py:func:`~spikeinterface.curation.get_potential_auto_merge` function returns a list of potential merges. -The list of potential merges can be then applied to the sorting output. -:py:func:`~spikeinterface.curation.get_potential_auto_merge` has many internal tricks and steps to identify potential -merges. Therefore, it has many parameters and options. -**Read the function documentation carefully and do not apply it blindly!** +We support several tools to perform manual curation of spike sorting outputs. +The first one is the `SpikeInterface-GUI `_, a QT-based GUI that allows you to +visualize and curate the spike sorting output. -.. code-block:: python +.. image:: ../images/spikeinterface_gui.png - from spikeinterface.curation import MergeUnitsSorting, get_potential_auto_merge +To launch the GUI, you can use the :py:func:`~spikeinterface.widgets.plot_sorting_summary` function +and select the ``backend='spikeinterface_gui'``. - sorting = run_sorter(sorter_name='kilosort', recording=recording) +.. code-block:: python - we = extract_waveforms(recording=recording, sorting=sorting, folder='wf_folder') + from spikeinterface import create_sorting_analyzer + from spikeinterface.curation import apply_sortingview_curation + from spikeinterface.widgets import plot_sorting_summary - # merges is a list of lists, with unit_ids to be merged. - merges = get_potential_auto_merge(waveform_extractor=we, minimum_spikes=1000, maximum_distance_um=150., - peak_sign="neg", bin_ms=0.25, window_ms=100., - corr_diff_thresh=0.16, template_diff_thresh=0.25, - censored_period_ms=0., refractory_period_ms=1.0, - contamination_threshold=0.2, num_channels=5, num_shift=5, - firing_contamination_balance=1.5) + sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording) + + # some extensions are required + sorting_analyzer.compute([ + "random_spikes", + "noise_levels", + "templates", + "template_similarity", + "unit_locations", + "spike_amplitudes", + "principal_components", + "correlograms" + ] + ) + sorting_analyzer.compute("quality_metrics", metric_names=["snr"]) - # here we apply the merges - clean_sorting = MergeUnitsSorting(parent_sorting=sorting, units_to_merge=merges) + # this will open the GUI in a different window + plot_sorting_summary(sorting_analyzer=sorting_analyzer, curation=True, backend='spikeinterface_gui') -Manual curation with sortingview ---------------------------------- +Using the ``sortingview`` web-app +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Within the :code:`sortingview` widgets backend (see :ref:`sorting_view`), the :py:func:`~spikeinterface.widgets.plot_sorting_summary` produces a powerful web-based GUI that enables manual curation @@ -197,31 +322,32 @@ The manual curation (including merges and labels) can be applied to a SpikeInter .. 
code-block:: python + from spikeinterface import create_sorting_analyzer from spikeinterface.curation import apply_sortingview_curation - from spikeinterface.postprocessing import (compute_spike_amplitudes, compute_unit_locations, - compute_template_similarity, compute_correlograms) from spikeinterface.widgets import plot_sorting_summary - # run a sorter and export waveforms - sorting = run_sorter(sorter_name='kilosort2', recording=recording) - we = extract_waveforms(recording=recording, sorting=sorting, folder='wf_folder') + sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording) - # some postprocessing is required - _ = compute_spike_amplitudes(waveform_extractor=we) - _ = compute_unit_locations(waveform_extractor=we) - _ = compute_template_similarity(waveform_extractor=we) - _ = compute_correlograms(waveform_extractor=we) + # some extensions are required + sorting_analyzer.compute([ + "random_spikes", + "templates", + "template_similarity", + "unit_locations", + "spike_amplitudes", + "correlograms"] + ) # This loads the data to the cloud for web-based plotting and sharing # curation=True required for allowing curation in the sortingview gui - plot_sorting_summary(waveform_extractor=we, curation=True, backend='sortingview') + plot_sorting_summary(sorting_analyzer=sorting_analyzer, curation=True, backend='sortingview') # we open the printed link URL in a browser # - make manual merges and labeling # - from the curation box, click on "Save as snapshot (sha1://)" # copy the uri sha_uri = "sha1://59feb326204cf61356f1a2eb31f04d8e0177c4f1" - clean_sorting = apply_sortingview_curation(sorting=sorting, uri_or_json=sha_uri) + clean_sorting = apply_sortingview_curation(sorting=sorting_analyzer.sorting, uri_or_json=sha_uri) Note that you can also "Export as JSON" and pass the json file as :code:`uri_or_json` parameter. @@ -234,7 +360,43 @@ Other curation tools We have other tools for cleaning spike sorting outputs: * :py:func:`~spikeinterface.curation.find_duplicated_spikes` : find duplicated spikes in the spike trains - * | :py:func:`~spikeinterface.curation.remove_duplicated_spikes` : remove all duplicated spikes from the spike trains - | :py:class:`~spikeinterface.core.BaseSorting` object (internally using the previous function) * | :py:func:`~spikeinterface.curation.remove_excess_spikes` : remove spikes whose times are greater than the | recording's number of samples (by segment) + + +The `CurationSorting` class (deprecated) +---------------------------------------- + +SpikeInterface offers machinery to manually curate a sorting output and keep track of the curation history. +The curation has several "steps" that can be repeated and chained: + + * remove/select units + * split units + * merge units + +This functionality is done with :py:class:`~spikeinterface.curation.CurationSorting` class. +Internally, this class keeps the history of curation as a graph. +The merging and splitting operations are handled by the :py:class:`~spikeinterface.curation.MergeUnitsSorting` and +:py:class:`~spikeinterface.curation.SplitUnitSorting`. These two classes can also be used independently. + + +.. code-block:: python + + from spikeinterface.curation import CurationSorting + + sorting = run_sorter(sorter_name='kilosort2', recording=recording) + + cs = CurationSorting(parent_sorting=sorting) + + # make a first merge + cs.merge(units_to_merge=['#1', '#5', '#15']) + + # make a second merge + cs.merge(units_to_merge=['#11', '#21']) + + # make a split + split_index = ... 
# some criteria on spikes + cs.split(split_unit_id='#20', indices_list=split_index) + + # here is the final clean sorting + clean_sorting = cs.sorting diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 073708f353..ad23a5f249 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -9,6 +9,8 @@ * ComputeNoiseLevels which is very convenient to have """ +import warnings + import numpy as np from .sortinganalyzer import AnalyzerExtension, register_result_extension @@ -76,6 +78,20 @@ def _select_extension_data(self, unit_ids): new_data["random_spikes_indices"] = np.flatnonzero(selected_mask[keep_spike_mask]) return new_data + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs + ): + new_data = dict() + random_spikes_indices = self.data["random_spikes_indices"] + if keep_mask is None: + new_data["random_spikes_indices"] = random_spikes_indices.copy() + else: + spikes = self.sorting_analyzer.sorting.to_spike_vector() + selected_mask = np.zeros(spikes.size, dtype=bool) + selected_mask[random_spikes_indices] = True + new_data["random_spikes_indices"] = np.flatnonzero(selected_mask[keep_mask]) + return new_data + def _get_data(self): return self.data["random_spikes_indices"] @@ -224,18 +240,66 @@ def _select_extension_data(self, unit_ids): return new_data - def get_waveforms_one_unit( - self, - unit_id, - force_dense: bool = False, + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs ): + new_data = dict() + + waveforms = self.data["waveforms"] + some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes() + if keep_mask is not None: + spike_indices = self.sorting_analyzer.get_extension("random_spikes").get_data() + valid = keep_mask[spike_indices] + some_spikes = some_spikes[valid] + waveforms = waveforms[valid] + else: + waveforms = waveforms.copy() + + old_sparsity = self.sorting_analyzer.sparsity + if old_sparsity is not None: + # we need a realignement inside each group because we take the channel intersection sparsity + for group_ids in merge_unit_groups: + group_indices = self.sorting_analyzer.sorting.ids_to_indices(group_ids) + group_sparsity_mask = old_sparsity.mask[group_indices, :] + group_selection = [] + for unit_id in group_ids: + unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id) + selection = np.flatnonzero(some_spikes["unit_index"] == unit_index) + group_selection.append(selection) + _inplace_sparse_realign_waveforms(waveforms, group_selection, group_sparsity_mask) + + old_num_chans = int(np.max(np.sum(old_sparsity.mask, axis=1))) + new_num_chans = int(np.max(np.sum(new_sorting_analyzer.sparsity.mask, axis=1))) + if new_num_chans < old_num_chans: + waveforms = waveforms[:, :, :new_num_chans] + + return dict(waveforms=waveforms) + + def get_waveforms_one_unit(self, unit_id, force_dense: bool = False): + """ + Returns the waveforms of a unit id. + + Parameters + ---------- + unit_id : int or str + The unit id to return waveforms for + force_dense : bool, default: False + If True, and SortingAnalyzer must be sparse then only waveforms on sparse channels are returned. + + Returns + ------- + waveforms: np.array + The waveforms (num_waveforms, num_samples, num_channels). + In case sparsity is used, only the waveforms on sparse channels are returned. 
+ """ sorting = self.sorting_analyzer.sorting unit_index = sorting.id_to_index(unit_id) - # spikes = sorting.to_spike_vector() - # some_spikes = spikes[self.sorting_analyzer.random_spikes_indices] + + waveforms = self.data["waveforms"] some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes() + spike_mask = some_spikes["unit_index"] == unit_index - wfs = self.data["waveforms"][spike_mask, :, :] + wfs = waveforms[spike_mask, :, :] if self.sorting_analyzer.sparsity is not None: chan_inds = self.sorting_analyzer.sparsity.unit_id_to_channel_indices[unit_id] @@ -252,6 +316,22 @@ def _get_data(self): return self.data["waveforms"] +def _inplace_sparse_realign_waveforms(waveforms, group_selection, group_sparsity_mask): + # this is used by "waveforms" extension but also "pca" + + # common mask is intersection + common_mask = np.all(group_sparsity_mask, axis=0) + + for i in range(len(group_selection)): + chan_mask = group_sparsity_mask[i, :] + sel = group_selection[i] + wfs = waveforms[sel, :, :][:, :, : np.sum(chan_mask)] + keep_mask = common_mask[chan_mask] + wfs = wfs[:, :, keep_mask] + waveforms[:, :, : wfs.shape[2]][sel, :, :] = wfs + waveforms[:, :, wfs.shape[2] :][sel, :, :] = 0.0 + + compute_waveforms = ComputeWaveforms.function_factory() register_result_extension(ComputeWaveforms) @@ -298,16 +378,13 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=N waveforms_extension = self.sorting_analyzer.get_extension("waveforms") if waveforms_extension is not None: - nbefore = waveforms_extension.nbefore - nafter = waveforms_extension.nafter - else: - nbefore = int(ms_before * self.sorting_analyzer.sampling_frequency / 1000.0) - nafter = int(ms_after * self.sorting_analyzer.sampling_frequency / 1000.0) + ms_before = waveforms_extension.params["ms_before"] + ms_after = waveforms_extension.params["ms_after"] params = dict( operators=operators, - nbefore=nbefore, - nafter=nafter, + ms_before=ms_before, + ms_after=ms_after, ) return params @@ -316,6 +393,7 @@ def _run(self, verbose=False, **job_kwargs): if self.sorting_analyzer.has_extension("waveforms"): self._compute_and_append_from_waveforms(self.params["operators"]) + else: for operator in self.params["operators"]: if operator not in ("average", "std"): @@ -380,7 +458,6 @@ def _compute_and_append_from_waveforms(self, operators): "random_spikes" ), "compute templates requires the random_spikes extension. You can run sorting_analyzer.get_random_spikes()" some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes() - for unit_index, unit_id in enumerate(unit_ids): spike_mask = some_spikes["unit_index"] == unit_index wfs = waveforms[spike_mask, :, :] @@ -410,11 +487,33 @@ def _compute_and_append_from_waveforms(self, operators): @property def nbefore(self): - return self.params["nbefore"] + if "ms_before" not in self.params: + # compatibility february 2024 > july 2024 + self.params["ms_before"] = self.params["nbefore"] * 1000.0 / self.sorting_analyzer.sampling_frequency + warnings.warn( + "The 'nbefore' parameter is deprecated and it's been replaced by 'ms_before' in the params." 
+ "You can save the sorting_analyzer to update the params.", + DeprecationWarning, + stacklevel=2, + ) + + nbefore = int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0) + return nbefore @property def nafter(self): - return self.params["nafter"] + if "ms_after" not in self.params: + # compatibility february 2024 > july 2024 + warnings.warn( + "The 'nafter' parameter is deprecated and it's been replaced by 'ms_after' in the params." + "You can save the sorting_analyzer to update the params.", + DeprecationWarning, + stacklevel=2, + ) + self.params["ms_after"] = self.params["nafter"] * 1000.0 / self.sorting_analyzer.sampling_frequency + + nafter = int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0) + return nafter def _select_extension_data(self, unit_ids): keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) @@ -425,12 +524,43 @@ def _select_extension_data(self, unit_ids): return new_data + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs + ): + + all_new_units = new_sorting_analyzer.unit_ids + new_data = dict() + counts = self.sorting_analyzer.sorting.count_num_spikes_per_unit() + for key, arr in self.data.items(): + new_data[key] = np.zeros((len(all_new_units), arr.shape[1], arr.shape[2]), dtype=arr.dtype) + for unit_index, unit_id in enumerate(all_new_units): + if unit_id not in new_unit_ids: + keep_unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id) + new_data[key][unit_index] = arr[keep_unit_index, :, :] + else: + merge_group = merge_unit_groups[list(new_unit_ids).index(unit_id)] + keep_unit_indices = self.sorting_analyzer.sorting.ids_to_indices(merge_group) + # We do a weighted sum of the templates + weights = np.zeros(len(merge_group), dtype=np.float32) + for count, merge_unit_id in enumerate(merge_group): + weights[count] = counts[merge_unit_id] + weights /= weights.sum() + new_data[key][unit_index] = (arr[keep_unit_indices, :, :] * weights[:, np.newaxis, np.newaxis]).sum( + 0 + ) + if new_sorting_analyzer.sparsity is not None: + chan_ids = new_sorting_analyzer.sparsity.unit_id_to_channel_indices[unit_id] + mask = ~np.isin(np.arange(arr.shape[2]), chan_ids) + new_data[key][unit_index][:, mask] = 0 + + return new_data + def _get_data(self, operator="average", percentile=None, outputs="numpy"): if operator != "percentile": key = operator else: assert percentile is not None, "You must provide percentile=..." - key = f"pencentile_{percentile}" + key = f"percentile_{percentile}" templates_array = self.data[key] @@ -582,6 +712,12 @@ def _select_extension_data(self, unit_ids): # this do not depend on units return self.data + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs + ): + # this do not depend on units + return self.data.copy() + def _run(self, verbose=False): self.data["noise_levels"] = get_noise_levels( self.sorting_analyzer.recording, return_scaled=self.sorting_analyzer.return_scaled, **self.params diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index db57d028f7..62aa7f37c3 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -973,6 +973,8 @@ def inject_some_split_units(sorting, split_ids: list, num_split=2, output_ids=Fa ------- sorting_with_split : NumpySorting A sorting with split units. 
+ other_ids : dict + The dictionary with the split unit_ids. Returned only if output_ids is True. """ unit_ids = sorting.unit_ids assert unit_ids.dtype.kind == "i" diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 918d95bf52..2a2f7b6b5a 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +from spikeinterface.core import NumpySorting from .basesorting import BaseSorting from .numpyextractors import NumpySorting @@ -226,7 +227,7 @@ def random_spikes_selection( def apply_merges_to_sorting( - sorting, units_to_merge, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy="append" + sorting, merge_unit_groups, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy="append" ): """ Apply a resolved representation of the merges to a sorting object. @@ -241,16 +242,16 @@ def apply_merges_to_sorting( ---------- sorting : Sorting The Sorting object to apply merges. - units_to_merge : list/tuple of lists/tuples + merge_unit_groups : list/tuple of lists/tuples A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), but it can also have more (merge multiple units at once). new_unit_ids : list | None, default: None - A new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`. If None, + A new unit_ids for merged units. If given, it needs to have the same length as `merge_unit_groups`. If None, merged units will have the first unit_id of every lists of merges. censor_ms: float | None, default: None When applying the merges, should be discard consecutive spikes violating a given refractory per return_kept : bool, default: False - If True, also return also a booolean mask of kept spikes. + If True, also return also a boolean mask of kept spikes. new_id_strategy : "append" | "take_first", default: "append" The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids. @@ -270,15 +271,15 @@ def apply_merges_to_sorting( keep_mask = np.ones(len(spikes), dtype=bool) new_unit_ids = generate_unit_ids_for_merge_group( - sorting.unit_ids, units_to_merge, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy + sorting.unit_ids, merge_unit_groups, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy ) rename_ids = {} - for i, merge_group in enumerate(units_to_merge): + for i, merge_group in enumerate(merge_unit_groups): for unit_id in merge_group: rename_ids[unit_id] = new_unit_ids[i] - all_unit_ids = _get_ids_after_merging(sorting.unit_ids, units_to_merge, new_unit_ids) + all_unit_ids = _get_ids_after_merging(sorting.unit_ids, merge_unit_groups, new_unit_ids) all_unit_ids = list(all_unit_ids) num_seg = sorting.get_num_segments() @@ -302,7 +303,7 @@ def apply_merges_to_sorting( if censor_ms is not None: rpv = int(sorting.sampling_frequency * censor_ms / 1000.0) - for group_old_ids in units_to_merge: + for group_old_ids in merge_unit_groups: for segment_index in range(num_seg): group_indices = [] for unit_id in group_old_ids: @@ -321,7 +322,7 @@ def apply_merges_to_sorting( return sorting -def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids): +def _get_ids_after_merging(old_unit_ids, merge_unit_groups, new_unit_ids): """ Function to get the list of unique unit_ids after some merges, with given new_units_ids would be provided. 
@@ -332,11 +333,11 @@ def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids): ---------- old_unit_ids : np.array The old unit_ids. - units_to_merge : list/tuple of lists/tuples + merge_unit_groups : list/tuple of lists/tuples A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), but it can also have more (merge multiple units at once). new_unit_ids : list | None - A new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`. + A new unit_ids for merged units. If given, it needs to have the same length as `merge_unit_groups`. Returns ------- @@ -347,10 +348,10 @@ def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids): """ old_unit_ids = np.asarray(old_unit_ids) - assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge" + assert len(new_unit_ids) == len(merge_unit_groups), "new_unit_ids should have the same len as merge_unit_groups" all_unit_ids = list(old_unit_ids.copy()) - for new_unit_id, group_ids in zip(new_unit_ids, units_to_merge): + for new_unit_id, group_ids in zip(new_unit_ids, merge_unit_groups): assert len(group_ids) > 1, "A merge should have at least two units" for unit_id in group_ids: assert unit_id in old_unit_ids, "Merged ids should be in the sorting" @@ -363,21 +364,21 @@ def _get_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids): return np.array(all_unit_ids) -def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids=None, new_id_strategy="append"): +def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_ids=None, new_id_strategy="append"): """ Function to generate new units ids during a merging procedure. If new_units_ids are provided, it will return these unit ids, checking that they have the the same - length as `units_to_merge`. + length as `merge_unit_groups`. Parameters ---------- old_unit_ids : np.array The old unit_ids. - units_to_merge : list/tuple of lists/tuples + merge_unit_groups : list/tuple of lists/tuples A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), but it can also have more (merge multiple units at once). new_unit_ids : list | None, default: None - Optional new unit_ids for merged units. If given, it needs to have the same length as `units_to_merge`. + Optional new unit_ids for merged units. If given, it needs to have the same length as `merge_unit_groups`. If None, new ids will be generated. new_id_strategy : "append" | "take_first", default: "append" The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids. 
@@ -394,17 +395,17 @@ def generate_unit_ids_for_merge_group(old_unit_ids, units_to_merge, new_unit_ids if new_unit_ids is not None: # then only doing a consistency check - assert len(new_unit_ids) == len(units_to_merge), "new_unit_ids should have the same len as units_to_merge" + assert len(new_unit_ids) == len(merge_unit_groups), "new_unit_ids should have the same len as merge_unit_groups" # new_unit_ids can also be part of old_unit_ids only inside the same group: for i, new_unit_id in enumerate(new_unit_ids): if new_unit_id in old_unit_ids: - assert new_unit_id in units_to_merge[i], "new_unit_ids already exists but outside the merged groups" + assert new_unit_id in merge_unit_groups[i], "new_unit_ids already exists but outside the merged groups" else: dtype = old_unit_ids.dtype - num_merge = len(units_to_merge) + num_merge = len(merge_unit_groups) # select new_unit_ids greater that the max id, event greater than the numerical str ids if new_id_strategy == "take_first": - new_unit_ids = [to_be_merged[0] for to_be_merged in units_to_merge] + new_unit_ids = [to_be_merged[0] for to_be_merged in merge_unit_groups] elif new_id_strategy == "append": if np.issubdtype(dtype, np.character): # dtype str diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 89e9e2cf0f..27a47a31ac 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -24,6 +24,7 @@ from .base import load_extractor from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match from .core_tools import check_json, retrieve_importing_provenance +from .sorting_tools import generate_unit_ids_for_merge_group, _get_ids_after_merging from .job_tools import split_job_kwargs from .numpyextractors import NumpySorting from .sparsity import ChannelSparsity, estimate_sparsity @@ -629,11 +630,54 @@ def set_temporary_recording(self, recording: BaseRecording): warnings.warn("SortingAnalyzer recording is already set. The current recording is temporarily replaced.") self._temporary_recording = recording - def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> "SortingAnalyzer": - """ - Internal used by both save_as(), copy() and select_units() which are more or less the same. + def _save_or_select_or_merge( + self, + format="binary_folder", + folder=None, + unit_ids=None, + merge_unit_groups=None, + censor_ms=None, + merging_mode="soft", + sparsity_overlap=0.75, + verbose=False, + new_unit_ids=None, + **job_kwargs, + ) -> "SortingAnalyzer": """ + Internal method used by both `save_as()`, `copy()`, `select_units()`, and `merge_units()`. + Parameters + ---------- + format : "memory" | "binary_folder" | "zarr", default: "binary_folder" + The format to save the SortingAnalyzer object + folder : str | Path | None, default: None + The folder where the SortingAnalyzer object will be saved + unit_ids : list or None, default: None + The unit ids to keep in the new SortingAnalyzer object. If `merge_unit_groups` is not None, + `unit_ids` must be given it must contain all unit_ids. + merge_unit_groups : list/tuple of lists/tuples or None, default: None + A list of lists for every merge group. Each element needs to have at least two elements + (two units to merge). If `merge_unit_groups` is not None, `new_unit_ids` must be given. + censor_ms : None or float, default: None + When merging units, any spikes violating this refractory period will be discarded. 
+ merging_mode : "soft" | "hard", default: "soft" + How merges are performed. In the "soft" mode, merges will be approximated, with no smart merging + of the extension data. + sparsity_overlap : float, default 0.75 + The percentage of overlap that units should share in order to accept merges. If this criteria is not + achieved, soft merging will not be performed. + new_unit_ids : list or None, default: None + The new unit ids for merged units. Required if `merge_unit_groups` is not None. + verbose : bool, default: False + If True, output is verbose. + job_kwargs : dict + Keyword arguments for parallelization. + + Returns + ------- + new_sorting_analyzer : SortingAnalyzer + The newly created SortingAnalyzer object. + """ if self.has_recording(): recording = self._recording elif self.has_temporary_recording(): @@ -641,24 +685,65 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> else: recording = None - if self.sparsity is not None and unit_ids is None: + if self.sparsity is not None and unit_ids is None and merge_unit_groups is None: sparsity = self.sparsity - elif self.sparsity is not None and unit_ids is not None: + elif self.sparsity is not None and unit_ids is not None and merge_unit_groups is None: sparsity_mask = self.sparsity.mask[np.isin(self.unit_ids, unit_ids), :] sparsity = ChannelSparsity(sparsity_mask, unit_ids, self.channel_ids) + elif self.sparsity is not None and merge_unit_groups is not None: + all_unit_ids = unit_ids + sparsity_mask = np.zeros((len(all_unit_ids), self.sparsity.mask.shape[1]), dtype=bool) + for unit_index, unit_id in enumerate(all_unit_ids): + if unit_id in new_unit_ids: + # This is a new unit, and the sparsity mask will be the intersection of the + # ones of all merges + current_merge_group = merge_unit_groups[list(new_unit_ids).index(unit_id)] + merge_unit_indices = self.sorting.ids_to_indices(current_merge_group) + union_mask = np.sum(self.sparsity.mask[merge_unit_indices], axis=0) > 0 + if merging_mode == "soft": + intersection_mask = np.prod(self.sparsity.mask[merge_unit_indices], axis=0) > 0 + thr = np.sum(intersection_mask) / np.sum(union_mask) + assert thr > sparsity_overlap, ( + f"The sparsities of {current_merge_group} do not overlap enough for a soft merge using " + f"a sparsity threshold of {sparsity_overlap}. You can either lower the threshold or use " + "a hard merge." + ) + sparsity_mask[unit_index] = intersection_mask + elif merging_mode == "hard": + sparsity_mask[unit_index] = union_mask + else: + # This means that the unit is already in the previous sorting + index = self.sorting.id_to_index(unit_id) + sparsity_mask[unit_index] = self.sparsity.mask[index] + sparsity = ChannelSparsity(sparsity_mask, list(all_unit_ids), self.channel_ids) else: sparsity = None # Note that the sorting is a copy we need to go back to the orginal sorting (if available) sorting_provenance = self.get_sorting_provenance() if sorting_provenance is None: - # if the original sorting objetc is not available anymore (kilosort folder deleted, ....), take the copy + # if the original sorting object is not available anymore (kilosort folder deleted, ....), take the copy sorting_provenance = self.sorting - if unit_ids is not None: + if merge_unit_groups is None: # when only some unit_ids then the sorting must be sliced # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!! 
sorting_provenance = sorting_provenance.select_units(unit_ids) + else: + from spikeinterface.core.sorting_tools import apply_merges_to_sorting + + sorting_provenance, keep_mask = apply_merges_to_sorting( + sorting=sorting_provenance, + merge_unit_groups=merge_unit_groups, + new_unit_ids=new_unit_ids, + censor_ms=censor_ms, + return_kept=True, + ) + if censor_ms is None: + # in this case having keep_mask None is faster instead of having a vector of ones + keep_mask = None + # TODO: sam/pierre would create a curation field / curation.json with the applied merges. + # What do you think? if format == "memory": # This make a copy of actual SortingAnalyzer @@ -691,10 +776,31 @@ def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> # make a copy of extensions # note that the copy of extension handle itself the slicing of units when necessary and also the saveing - for extension_name, extension in self.extensions.items(): - new_ext = new_sorting_analyzer.extensions[extension_name] = extension.copy( - new_sorting_analyzer, unit_ids=unit_ids - ) + sorted_extensions = _sort_extensions_by_dependency(self.extensions) + recompute_dict = {} + + for extension_name, extension in sorted_extensions.items(): + if merge_unit_groups is None: + # copy full or select + new_sorting_analyzer.extensions[extension_name] = extension.copy( + new_sorting_analyzer, unit_ids=unit_ids + ) + else: + # merge + if merging_mode == "soft": + new_sorting_analyzer.extensions[extension_name] = extension.merge( + new_sorting_analyzer, + merge_unit_groups=merge_unit_groups, + new_unit_ids=new_unit_ids, + keep_mask=keep_mask, + verbose=verbose, + **job_kwargs, + ) + elif merging_mode == "hard": + recompute_dict[extension_name] = extension.params + + if merge_unit_groups is not None and merging_mode == "hard" and len(recompute_dict) > 0: + new_sorting_analyzer.compute_several_extensions(recompute_dict, save=True, verbose=verbose, **job_kwargs) return new_sorting_analyzer @@ -714,7 +820,7 @@ def save_as(self, format="memory", folder=None) -> "SortingAnalyzer": format : "binary_folder" | "zarr", default: "binary_folder" The backend to use for saving the waveforms """ - return self._save_or_select(format=format, folder=folder, unit_ids=None) + return self._save_or_select_or_merge(format=format, folder=folder) def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyzer": """ @@ -727,23 +833,136 @@ def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyz ---------- unit_ids : list or array The unit ids to keep in the new SortingAnalyzer object + format : "memory" | "binary_folder" | "zarr" , default: "memory" + The format of the returned SortingAnalyzer. folder : Path or None - The new folder where selected waveforms are copied - format: - a + The new folder where selected waveforms are copied. + + Returns + ------- + analyzer : SortingAnalyzer + The newly create sorting_analyzer with the selected units + """ + # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!! + return self._save_or_select_or_merge(format=format, folder=folder, unit_ids=unit_ids) + + def remove_units(self, remove_unit_ids, format="memory", folder=None) -> "SortingAnalyzer": + """ + This method is equivalent to `save_as()` but with removal of a subset of units. + Filters units by creating a new sorting analyzer object in a new folder. + + Extensions are also updated to remove the unit ids. 
+
+        Parameters
+        ----------
+        remove_unit_ids : list or array
+            The unit ids to remove in the new SortingAnalyzer object.
+        format : "memory" | "binary_folder" | "zarr", default: "memory"
+            The format of the returned SortingAnalyzer.
+        folder : Path or None
+            The new folder where selected waveforms are copied.
+
         Returns
         -------
-        we : SortingAnalyzer
+        analyzer : SortingAnalyzer
             The newly create sorting_analyzer with the selected units
         """
         # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!!
-        return self._save_or_select(format=format, folder=folder, unit_ids=unit_ids)
+        unit_ids = self.unit_ids[~np.isin(self.unit_ids, remove_unit_ids)]
+        return self._save_or_select_or_merge(format=format, folder=folder, unit_ids=unit_ids)
+
+    def merge_units(
+        self,
+        merge_unit_groups,
+        new_unit_ids=None,
+        censor_ms=None,
+        merging_mode="soft",
+        sparsity_overlap=0.75,
+        new_id_strategy="append",
+        format="memory",
+        folder=None,
+        verbose=False,
+        **job_kwargs,
+    ) -> "SortingAnalyzer":
+        """
+        This method is equivalent to `save_as()` but with a list of merges that have to be applied.
+        Merges units by creating a new sorting analyzer object in a new folder with the appropriate merges.
+
+        Extensions are also updated to display the merged unit ids.
+
+        Parameters
+        ----------
+        merge_unit_groups : list/tuple of lists/tuples
+            A list of lists for every merge group. Each element needs to have at least two elements (two units to merge),
+            but it can also have more (merge multiple units at once).
+        new_unit_ids : None or list
+            A new unit_ids for merged units. If given, it needs to have the same length as `merge_unit_groups`. If None,
+            merged units will have the first unit_id of every list of merges.
+        censor_ms : None or float
+            When merging units, any spikes violating this refractory period will be discarded. Default is None.
+        merging_mode : "soft" | "hard", default: "soft"
+            How merges are performed. In the "soft" mode, merges will be approximated, with no reloading of the
+            waveforms. This will lead to approximations. If "hard", recomputations are accurately performed,
+            reloading waveforms if needed.
+        sparsity_overlap : float, default: 0.75
+            The percentage of overlap that units should share in order to accept merges. If this criterion is not
+            met, soft merging will not be possible.
+        new_id_strategy : "append" | "take_first", default: "append"
+            The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
+                * "append" : new_unit_ids will be added at the end of max(sorting.unit_ids)
+                * "take_first" : new_unit_ids will be the first unit_id of every list of merges
+        folder : Path or None
+            The new folder where selected waveforms are copied.
+        format : "memory" | "binary_folder" | "zarr", default: "memory"
+            The format of the folder. 
+        verbose : bool, default: False
+            If True, output is verbose.
+
+        Returns
+        -------
+        analyzer : SortingAnalyzer
+            The newly created sorting_analyzer with the merged units
+        """
+
+        assert merging_mode in ["soft", "hard"], "Merging mode should be either soft or hard"
+
+        if len(merge_unit_groups) == 0:
+            # TODO I think we should raise an error or at least make a copy and not return itself
+            return self
+
+        for units in merge_unit_groups:
+            # TODO more checks like each unit is only in one group
+            if len(units) < 2:
+                raise ValueError("Merging requires at least two units to merge")
+
+        # TODO : no this function did not exist before
+        if not isinstance(merge_unit_groups[0], (list, tuple)):
+            # keep backward compatibility : the previous behavior was only one merge
+            merge_unit_groups = [merge_unit_groups]
+
+        new_unit_ids = generate_unit_ids_for_merge_group(
+            self.unit_ids, merge_unit_groups, new_unit_ids, new_id_strategy
+        )
+        all_unit_ids = _get_ids_after_merging(self.unit_ids, merge_unit_groups, new_unit_ids=new_unit_ids)
+
+        return self._save_or_select_or_merge(
+            format=format,
+            folder=folder,
+            merge_unit_groups=merge_unit_groups,
+            unit_ids=all_unit_ids,
+            censor_ms=censor_ms,
+            merging_mode=merging_mode,
+            sparsity_overlap=sparsity_overlap,
+            verbose=verbose,
+            new_unit_ids=new_unit_ids,
+            **job_kwargs,
+        )

     def copy(self):
         """
         Create a a copy of SortingAnalyzer with format "memory".
         """
-        return self._save_or_select(format="memory", folder=None, unit_ids=None)
+        return self._save_or_select_or_merge(format="memory", folder=None)

     def is_read_only(self) -> bool:
         if self.format == "memory":
@@ -877,7 +1096,7 @@ def compute(self, input, save=True, extension_params=None, verbose=False, **kwar
         input : str or dict or list
             The extensions to compute, which can be passed as:
             * a string: compute one extension. Additional parameters can be passed as key word arguments.
-            * a dict: compute several extensions. The keys are the extension names and the values are dictiopnaries with the extension parameters.
+            * a dict: compute several extensions. The keys are the extension names and the values are dictionaries with the extension parameters.
             * a list: compute several extensions. The list contains the extension names. Additional parameters can be passed with the extension_params
               argument.
         save : bool, default: True
@@ -1001,7 +1220,6 @@ def compute_one_extension(self, extension_name, save=True, verbose=False, **kwar
         extension_instance.run(save=save, verbose=verbose)
         self.extensions[extension_name] = extension_instance
-
         return extension_instance

     def compute_several_extensions(self, extensions, save=True, verbose=False, **job_kwargs):
@@ -1462,6 +1680,7 @@ class AnalyzerExtension:
       * _set_params()
       * _run()
       * _select_extension_data()
+      * _merge_extension_data()
       * _get_data()

     The subclass must also set an `extension_name` class attribute which is not None by default. 
@@ -1504,6 +1723,12 @@ def _select_extension_data(self, unit_ids): # must be implemented in subclass raise NotImplementedError + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask, verbose=False, **job_kwargs + ): + # must be implemented in subclass + raise NotImplementedError + def _get_pipeline_nodes(self): # must be implemented in subclass only if use_nodepipeline=True raise NotImplementedError @@ -1512,9 +1737,6 @@ def _get_data(self): # must be implemented in subclass raise NotImplementedError - # - ####### - @classmethod def function_factory(cls): # make equivalent @@ -1667,6 +1889,23 @@ def copy(self, new_sorting_analyzer, unit_ids=None): new_extension.save() return new_extension + def merge( + self, + new_sorting_analyzer, + merge_unit_groups, + new_unit_ids, + keep_mask=None, + verbose=False, + **job_kwargs, + ): + new_extension = self.__class__(new_sorting_analyzer) + new_extension.params = self.params.copy() + new_extension.data = self._merge_extension_data( + merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask, verbose=verbose, **job_kwargs + ) + new_extension.save() + return new_extension + def run(self, save=True, **kwargs): if save and not self.sorting_analyzer.is_read_only(): # this also reset the folder or zarr group diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 7456680b2a..d89eb7fac0 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -1,5 +1,6 @@ import pytest from pathlib import Path +import numpy as np import shutil @@ -112,6 +113,9 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): # this bug requires that we have an info.json file so we calculate templates above select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=[1]) assert len(select_units_sorting_analyer.unit_ids) == 1 + remove_units_sorting_analyer = sorting_analyzer.remove_units(remove_unit_ids=[1]) + assert len(remove_units_sorting_analyer.unit_ids) == len(sorting_analyzer.unit_ids) - 1 + assert 1 not in remove_units_sorting_analyer.unit_ids folder = tmp_path / "test_SortingAnalyzer_zarr.zarr" if folder.exists(): @@ -215,13 +219,59 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): keep_unit_ids = original_sorting.unit_ids[::2] sorting_analyzer2 = sorting_analyzer.select_units(unit_ids=keep_unit_ids, format=format, folder=folder) - # check propagation of result data and correct sligin + # check propagation of result data and correct aligin assert np.array_equal(keep_unit_ids, sorting_analyzer2.unit_ids) data = sorting_analyzer2.get_extension("dummy").data assert data["result_one"] == sorting_analyzer.get_extension("dummy").data["result_one"] # unit 1, 3, ... 
should be removed
     assert np.all(~np.isin(data["result_two"], [1, 3]))

+    # remove unit_ids in several formats
+    if format != "memory":
+        if format == "zarr":
+            folder = cache_folder / f"test_SortingAnalyzer_remove_units_with_{format}.zarr"
+        else:
+            folder = cache_folder / f"test_SortingAnalyzer_remove_units_with_{format}"
+        if folder.exists():
+            shutil.rmtree(folder)
+    else:
+        folder = None
+    # compute one extension to check the slice
+    sorting_analyzer.compute("dummy")
+    remove_unit_ids = original_sorting.unit_ids[::2]
+    sorting_analyzer3 = sorting_analyzer.remove_units(remove_unit_ids=remove_unit_ids, format=format, folder=folder)
+
+    # check propagation of result data and correct alignment
+    assert np.array_equal(original_sorting.unit_ids[1::2], sorting_analyzer3.unit_ids)
+    data = sorting_analyzer3.get_extension("dummy").data
+    assert data["result_one"] == sorting_analyzer.get_extension("dummy").data["result_one"]
+    # unit 0, 2, ... should be removed
+    assert np.all(~np.isin(data["result_two"], [0, 2]))
+
+    if format != "memory":
+        if format == "zarr":
+            folder = cache_folder / f"test_SortingAnalyzer_merge_soft_with_{format}.zarr"
+        else:
+            folder = cache_folder / f"test_SortingAnalyzer_merge_with_{format}"
+        if folder.exists():
+            shutil.rmtree(folder)
+    else:
+        folder = None
+    sorting_analyzer4 = sorting_analyzer.merge_units(merge_unit_groups=[[0, 1]], format=format, folder=folder)
+
+    if format != "memory":
+        if format == "zarr":
+            folder = cache_folder / f"test_SortingAnalyzer_merge_hard_with_{format}.zarr"
+        else:
+            folder = cache_folder / f"test_SortingAnalyzer_merge_hard_with_{format}"
+        if folder.exists():
+            shutil.rmtree(folder)
+    else:
+        folder = None
+    sorting_analyzer5 = sorting_analyzer.merge_units(
+        merge_unit_groups=[[0, 1]], new_unit_ids=[50], format=format, folder=folder, merging_mode="hard"
+    )
+
     # test compute with extension-specific params
     sorting_analyzer.compute(["dummy"], extension_params={"dummy": {"param1": 5.5}})
     dummy_ext = sorting_analyzer.get_extension("dummy")
@@ -237,11 +287,11 @@ def test_extension_params():
         assert ext in computable_extension
         if mod == "spikeinterface.core":
             default_params = get_default_analyzer_extension_params(ext)
-            print(ext, default_params)
+            # print(ext, default_params)
         else:
             try:
                 default_params = get_default_analyzer_extension_params(ext)
-                print(ext, default_params)
+                # print(ext, default_params)
             except:
                 print(f"Failed to import {ext}")

@@ -254,7 +304,6 @@ class DummyAnalyzerExtension(AnalyzerExtension):
     def _set_params(self, param0="yep", param1=1.2, param2=[1, 2, 3.0]):
         params = dict(param0=param0, param1=param1, param2=param2)
-        params["more_option"] = "yep"
         return params

     def _run(self, **kwargs):
@@ -264,6 +313,7 @@ def _run(self, **kwargs):
         # and represent nothing (the trick is to use unit_index for testing slice)
         spikes = self.sorting_analyzer.sorting.to_spike_vector()
         self.data["result_two"] = spikes["unit_index"].copy()
+        self.data["result_three"] = np.zeros((len(self.sorting_analyzer.unit_ids), 2))

     def _select_extension_data(self, unit_ids):
         keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids))
@@ -276,6 +326,32 @@ def _select_extension_data(self, unit_ids):
         new_data["result_one"] = self.data["result_one"]
         new_data["result_two"] = self.data["result_two"][keep_spike_mask]

+        keep_spike_mask = np.isin(self.sorting_analyzer.unit_ids, unit_ids)
+        new_data["result_three"] = self.data["result_three"][keep_spike_mask]
+
         return new_data

+    def _merge_extension_data(
+        self, merge_unit_groups, new_unit_ids, 
new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs + ): + + all_new_unit_ids = new_sorting_analyzer.unit_ids + new_data = dict() + new_data["result_one"] = self.data["result_one"] + new_data["result_two"] = self.data["result_two"] + + arr = self.data["result_three"] + num_dims = arr.shape[1] + new_data["result_three"] = np.zeros((len(all_new_unit_ids), num_dims), dtype=arr.dtype) + for unit_ind, unit_id in enumerate(all_new_unit_ids): + if unit_id not in new_unit_ids: + keep_unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id) + new_data["result_three"][unit_ind] = arr[keep_unit_index] + else: + id = np.flatnonzero(new_unit_ids == unit_id)[0] + keep_unit_indices = self.sorting_analyzer.sorting.ids_to_indices(merge_unit_groups[id]) + new_data["result_three"][unit_ind] = arr[keep_unit_indices].mean(axis=0) + return new_data def _get_data(self): @@ -331,4 +407,5 @@ def test_extensions_sorting(): test_SortingAnalyzer_zarr(tmp_path, dataset) test_SortingAnalyzer_tmp_recording(dataset) test_extension() + test_SortingAnalyzer_merge_all_extensions() test_extension_params() diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index 40feac3397..d6d60ee73b 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -6,7 +6,8 @@ from __future__ import annotations -from typing import Literal, Optional +import warnings +from typing import Optional from pathlib import Path @@ -543,7 +544,11 @@ def _read_old_waveforms_extractor_binary(folder, sorting): ext = new_class(sorting_analyzer) with open(ext_folder / "params.json", "r") as f: params = json.load(f) - ext.params = params + # update params + new_params = ext._set_params() + updated_params = make_ext_params_up_to_date(ext, params, new_params) + ext.set_params(**updated_params) + if new_name == "spike_amplitudes": amplitudes = [] for segment_index in range(sorting.get_num_segments()): @@ -604,6 +609,21 @@ def _read_old_waveforms_extractor_binary(folder, sorting): return sorting_analyzer +def make_ext_params_up_to_date(ext, old_params, new_params): + # adjust params + old_name = ext.extension_name + updated_params = old_params.copy() + for p, values in old_params.items(): + if isinstance(values, dict): + new_values = new_params.get(p, {}) + updated_params[p] = make_ext_params_up_to_date(ext, values, new_values) + else: + if p not in new_params: + warnings.warn(f"Removing legacy param {p} from {old_name} extension") + updated_params.pop(p) + return updated_params + + # this was never used, let's comment it out # def _read_old_waveforms_extractor_zarr(folder, sorting): # import zarr diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index d6eded4345..fc75f74399 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -28,7 +28,7 @@ def validate_curation_dict(curation_dict): # unit_ids labeled_unit_set = set([lbl["unit_id"] for lbl in curation_dict["manual_labels"]]) - merged_units_set = set(sum(curation_dict["merged_unit_groups"], [])) + merged_units_set = set(sum(curation_dict["merge_unit_groups"], [])) removed_units_set = set(curation_dict["removed_units"]) if curation_dict["unit_ids"] is not None: @@ -41,7 +41,7 @@ def validate_curation_dict(curation_dict): if not removed_units_set.issubset(unit_set): raise 
ValueError("Curation format: some removed units are not in the unit list") - all_merging_groups = [set(group) for group in curation_dict["merged_unit_groups"]] + all_merging_groups = [set(group) for group in curation_dict["merge_unit_groups"]] for gp_1, gp_2 in combinations(all_merging_groups, 2): if len(gp_1.intersection(gp_2)) != 0: raise ValueError("Some units belong to multiple merge groups") @@ -112,7 +112,7 @@ def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_fo "unit_ids": None, "label_definitions": labels_def, "manual_labels": manual_labels, - "merged_unit_groups": merge_groups, + "merge_unit_groups": merge_groups, "removed_units": [], } diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index 6d132fbe97..94812ee0aa 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -23,7 +23,7 @@ category_key1': List[str], } ], - 'merged_unit_groups': List[List[unit_ids]], # one cell goes into at most one list + 'merge_unit_groups': List[List[unit_ids]], # one cell goes into at most one list 'removed_units': List[unit_ids] # Can not be in the merged_units } """ @@ -50,7 +50,7 @@ }, {"unit_id": 3, "putative_type": ["inhibitory"]}, ], - "merged_unit_groups": [[3, 6], [10, 14, 20]], # one cell goes into at most one list + "merge_unit_groups": [[3, 6], [10, 14, 20]], # one cell goes into at most one list "removed_units": [31, 42], # Can not be in the merged_units } @@ -75,23 +75,23 @@ }, {"unit_id": "u3", "putative_type": ["inhibitory"]}, ], - "merged_unit_groups": [["u3", "u6"], ["u10", "u14", "u20"]], # one cell goes into at most one list + "merge_unit_groups": [["u3", "u6"], ["u10", "u14", "u20"]], # one cell goes into at most one list "removed_units": ["u31", "u42"], # Can not be in the merged_units } # This is a failure example with duplicated merge duplicate_merge = curation_ids_int.copy() -duplicate_merge["merged_unit_groups"] = [[3, 6, 10], [10, 14, 20]] +duplicate_merge["merge_unit_groups"] = [[3, 6, 10], [10, 14, 20]] # This is a failure example with unit 3 both in removed and merged merged_and_removed = curation_ids_int.copy() -merged_and_removed["merged_unit_groups"] = [[3, 6], [10, 14, 20]] +merged_and_removed["merge_unit_groups"] = [[3, 6], [10, 14, 20]] merged_and_removed["removed_units"] = [3, 31, 42] # this is a failure because unit 99 is not in the initial list unknown_merged_unit = curation_ids_int.copy() -unknown_merged_unit["merged_unit_groups"] = [[3, 6, 99], [10, 14, 20]] +unknown_merged_unit["merge_unit_groups"] = [[3, 6, 99], [10, 14, 20]] # this is a failure because unit 99 is not in the initial list unknown_removed_unit = curation_ids_int.copy() diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 8ff9cc5666..0d57aec21e 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -13,8 +13,6 @@ from ..core.template_tools import get_dense_templates_array, _get_nbefore -# TODO extra sparsity and job_kwargs handling - class ComputeAmplitudeScalings(AnalyzerExtension): """ @@ -114,6 +112,22 @@ def _select_extension_data(self, unit_ids): new_data["collision_mask"] = self.data["collision_mask"][keep_spike_mask] return new_data + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, 
verbose=False, **job_kwargs + ): + new_data = dict() + + if keep_mask is None: + new_data["amplitude_scalings"] = self.data["amplitude_scalings"].copy() + if self.params["handle_collisions"]: + new_data["collision_mask"] = self.data["collision_mask"].copy() + else: + new_data["amplitude_scalings"] = self.data["amplitude_scalings"][keep_mask] + if self.params["handle_collisions"]: + new_data["collision_mask"] = self.data["collision_mask"][keep_mask] + + return new_data + def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 7c22260dbe..3c65f2075c 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -90,6 +90,14 @@ def _select_extension_data(self, unit_ids): new_data = dict(ccgs=new_ccgs, bins=new_bins) return new_data + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, censor_ms=None, verbose=False, **job_kwargs + ): + # recomputing correlogram is fast enough and much easier in this case + new_ccgs, new_bins = _compute_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params) + new_data = dict(ccgs=new_ccgs, bins=new_bins) + return new_data + def _run(self, verbose=False): ccgs, bins = _compute_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params) self.data["ccgs"] = ccgs diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index c738383636..fa919e11e2 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -56,6 +56,30 @@ def _select_extension_data(self, unit_ids): new_extension_data = dict(isi_histograms=new_isi_hists, bins=new_bins) return new_extension_data + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, censor_ms=None, verbose=False, **job_kwargs + ): + new_bins = self.data["bins"] + arr = self.data["isi_histograms"] + num_dims = arr.shape[1] + all_new_units = new_sorting_analyzer.unit_ids + new_isi_hists = np.zeros((len(all_new_units), num_dims), dtype=arr.dtype) + + # compute all new isi at once + new_sorting = new_sorting_analyzer.sorting.select_units(new_unit_ids) + only_new_hist, _ = _compute_isi_histograms(new_sorting, **self.params) + + for unit_ind, unit_id in enumerate(all_new_units): + if unit_id not in new_unit_ids: + keep_unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id) + new_isi_hists[unit_ind, :] = arr[keep_unit_index, :] + else: + new_unit_index = new_sorting.id_to_index(unit_id) + new_isi_hists[unit_ind, :] = only_new_hist[new_unit_index, :] + + new_extension_data = dict(isi_histograms=new_isi_hists, bins=new_bins) + return new_extension_data + def _run(self, verbose=False): isi_histograms, bins = _compute_isi_histograms(self.sorting_analyzer.sorting, **self.params) self.data["isi_histograms"] = isi_histograms diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index b7571a6f3e..e6278fc59f 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -18,6 +18,7 @@ def compute_monopolar_triangulation( sorting_analyzer_or_templates: SortingAnalyzer | Templates, + unit_ids=None, optimizer: str = "least_square", radius_um: float = 75, max_distance_um: float = 1000, @@ -46,6 +47,8 @@ def 
compute_monopolar_triangulation( ---------- sorting_analyzer_or_templates : SortingAnalyzer | Templates A SortingAnalyzer or Templates object + unit_ids : list | None + A list of unit_ids to restrict the computation method : "least_square" | "minimize_with_log_penality", default: "least_square" The optimizer to use radius_um : float, default: 75 @@ -71,7 +74,6 @@ def compute_monopolar_triangulation( assert optimizer in ("least_square", "minimize_with_log_penality") assert feature in ["ptp", "energy", "peak_voltage"], f"{feature} is not a valid feature" - unit_ids = sorting_analyzer_or_templates.unit_ids contact_locations = sorting_analyzer_or_templates.get_channel_locations() @@ -81,6 +83,13 @@ def compute_monopolar_triangulation( ) nbefore = _get_nbefore(sorting_analyzer_or_templates) + if unit_ids is None: + unit_ids = sorting_analyzer_or_templates.unit_ids + else: + unit_ids = np.asanyarray(unit_ids) + keep = np.isin(sorting_analyzer_or_templates.unit_ids, unit_ids) + templates = templates[keep, :, :] + if enforce_decrease: neighbours_mask = np.zeros((templates.shape[0], templates.shape[2]), dtype=bool) for i, unit_id in enumerate(unit_ids): @@ -118,6 +127,7 @@ def compute_monopolar_triangulation( def compute_center_of_mass( sorting_analyzer_or_templates: SortingAnalyzer | Templates, + unit_ids=None, peak_sign: str = "neg", radius_um: float = 75, feature: str = "ptp", @@ -129,6 +139,8 @@ def compute_center_of_mass( ---------- sorting_analyzer_or_templates : SortingAnalyzer | Templates A SortingAnalyzer or Templates object + unit_ids : list | None + A list of unit_ids to restrict the computation peak_sign : "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels radius_um : float @@ -140,7 +152,6 @@ def compute_center_of_mass( ------- unit_location: np.array """ - unit_ids = sorting_analyzer_or_templates.unit_ids contact_locations = sorting_analyzer_or_templates.get_channel_locations() @@ -154,6 +165,13 @@ def compute_center_of_mass( ) nbefore = _get_nbefore(sorting_analyzer_or_templates) + if unit_ids is None: + unit_ids = sorting_analyzer_or_templates.unit_ids + else: + unit_ids = np.asanyarray(unit_ids) + keep = np.isin(sorting_analyzer_or_templates.unit_ids, unit_ids) + templates = templates[keep, :, :] + unit_location = np.zeros((unit_ids.size, 2), dtype="float64") for i, unit_id in enumerate(unit_ids): chan_inds = sparsity.unit_id_to_channel_indices[unit_id] @@ -179,6 +197,7 @@ def compute_center_of_mass( def compute_grid_convolution( sorting_analyzer_or_templates: SortingAnalyzer | Templates, + unit_ids=None, peak_sign: str = "neg", radius_um: float = 40.0, upsampling_um: float = 5, @@ -195,6 +214,8 @@ def compute_grid_convolution( ---------- sorting_analyzer_or_templates : SortingAnalyzer | Templates A SortingAnalyzer or Templates object + unit_ids : list | None + A list of unit_ids to restrict the computation peak_sign : "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels radius_um : float, default: 40.0 @@ -220,7 +241,6 @@ def compute_grid_convolution( """ contact_locations = sorting_analyzer_or_templates.get_channel_locations() - unit_ids = sorting_analyzer_or_templates.unit_ids templates = get_dense_templates_array( sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates) @@ -228,6 +248,13 @@ def compute_grid_convolution( nbefore = _get_nbefore(sorting_analyzer_or_templates) nafter = templates.shape[1] - nbefore + if unit_ids is None: + unit_ids =
sorting_analyzer_or_templates.unit_ids + else: + unit_ids = np.asanyarray(unit_ids) + keep = np.isin(sorting_analyzer_or_templates.unit_ids, unit_ids) + templates = templates[keep, :, :] + fs = sorting_analyzer_or_templates.sampling_frequency percentile = 100 - percentile assert 0 <= percentile <= 100, "Percentile should be in [0, 100]" @@ -621,3 +648,10 @@ def get_convolution_weights( if HAVE_NUMBA: enforce_decrease_shells = numba.jit(enforce_decrease_shells_data, nopython=True) + + +_unit_location_methods = { + "center_of_mass": compute_center_of_mass, + "grid_convolution": compute_grid_convolution, + "monopolar_triangulation": compute_monopolar_triangulation, +} diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 6252c0582b..1138adac7d 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -13,6 +13,8 @@ from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, fix_job_kwargs +from spikeinterface.core.analyzer_extension_core import _inplace_sparse_realign_waveforms + _possible_modes = ["by_channel_local", "by_channel_global", "concatenated"] @@ -101,6 +103,50 @@ def _select_extension_data(self, unit_ids): new_data[k] = v return new_data + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs + ): + + pca_projections = self.data["pca_projection"] + some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes() + + if keep_mask is not None: + spike_indices = self.sorting_analyzer.get_extension("random_spikes").get_data() + valid = keep_mask[spike_indices] + some_spikes = some_spikes[valid] + pca_projections = pca_projections[valid] + else: + pca_projections = pca_projections.copy() + + old_sparsity = self.sorting_analyzer.sparsity + if old_sparsity is not None: + + # we need a realignement inside each group because we take the channel intersection sparsity + # the story is same as in "waveforms" extension + for group_ids in merge_unit_groups: + group_indices = self.sorting_analyzer.sorting.ids_to_indices(group_ids) + group_sparsity_mask = old_sparsity.mask[group_indices, :] + group_selection = [] + for unit_id in group_ids: + unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id) + selection = np.flatnonzero(some_spikes["unit_index"] == unit_index) + group_selection.append(selection) + + _inplace_sparse_realign_waveforms(pca_projections, group_selection, group_sparsity_mask) + + old_num_chans = int(np.max(np.sum(old_sparsity.mask, axis=1))) + new_num_chans = int(np.max(np.sum(new_sorting_analyzer.sparsity.mask, axis=1))) + if new_num_chans < old_num_chans: + pca_projections = pca_projections[:, :, :new_num_chans] + + new_data = dict(pca_projection=pca_projections) + + # one or several model + for k, v in self.data.items(): + if "model" in k: + new_data[k] = v + return new_data + def get_pca_model(self): """ Returns the scikit-learn PCA model objects. 
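Reviewer note: a minimal, hedged sketch of the user-facing call path that the per-extension ``_merge_extension_data`` implementations above support. The extension list and unit ids are hypothetical, and the keyword names (``merge_unit_groups``, ``censor_ms``, ``merging_mode``, ``sparsity_overlap``) follow the tests added later in this patch rather than any settled public API.

.. code-block:: python

    from spikeinterface import create_sorting_analyzer

    # assuming `recording` and `sorting` objects already exist
    analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True)
    analyzer.compute(["random_spikes", "waveforms", "templates", "principal_components"])

    # "soft" reuses the existing extension data through _merge_extension_data,
    # "hard" recomputes the extensions on the merged spike trains
    merged = analyzer.merge_units(
        merge_unit_groups=[[0, 1]],  # hypothetical unit ids
        censor_ms=2,
        merging_mode="soft",
        sparsity_overlap=0.0,
    )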
diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 72cbcb651f..e82a9e61e4 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -81,6 +81,18 @@ def _select_extension_data(self, unit_ids): return new_data + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs + ): + new_data = dict() + + if keep_mask is None: + new_data["amplitudes"] = self.data["amplitudes"].copy() + else: + new_data["amplitudes"] = self.data["amplitudes"][keep_mask] + + return new_data + def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 23301292e5..53e55b4d1f 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -92,6 +92,19 @@ def _select_extension_data(self, unit_ids): new_spike_locations = self.data["spike_locations"][spike_mask] return dict(spike_locations=new_spike_locations) + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs + ): + + if keep_mask is None: + new_spike_locations = self.data["spike_locations"].copy() + else: + new_spike_locations = self.data["spike_locations"][keep_mask] + + ### In theory here, we should recompute the locations since the peak positions + ### in a merged could be different. Should be discussed + return dict(spike_locations=new_spike_locations) + def _get_pipeline_nodes(self): from spikeinterface.sortingcomponents.peak_localization import get_localization_pipeline_nodes diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index fdc4ef4719..eef2a2f32c 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -111,6 +111,7 @@ def _set_params( sparsity=None, metrics_kwargs=None, include_multi_channel_metrics=False, + **other_kwargs, ): # TODO alessio can you check this : this used to be in the function but now we have ComputeTemplateMetrics.function_factory() @@ -132,6 +133,10 @@ def _set_params( if metrics_kwargs is None: metrics_kwargs_ = _default_function_kwargs.copy() + if len(other_kwargs) > 0: + for m in other_kwargs: + if m in metrics_kwargs_: + metrics_kwargs_[m] = other_kwargs[m] else: metrics_kwargs_ = _default_function_kwargs.copy() metrics_kwargs_.update(metrics_kwargs) @@ -150,7 +155,28 @@ def _select_extension_data(self, unit_ids): new_metrics = self.data["metrics"].loc[np.array(unit_ids)] return dict(metrics=new_metrics) - def _run(self, verbose=False): + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs + ): + import pandas as pd + + old_metrics = self.data["metrics"] + + all_unit_ids = new_sorting_analyzer.unit_ids + not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids)] + + metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) + + metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] + metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs) + + new_data = dict(metrics=metrics) + return new_data + + def _compute_metrics(self, 
sorting_analyzer, unit_ids=None, verbose=False, **job_kwargs): + """ + Compute template metrics. + """ import pandas as pd from scipy.signal import resample_poly @@ -158,16 +184,15 @@ def _run(self, verbose=False): sparsity = self.params["sparsity"] peak_sign = self.params["peak_sign"] upsampling_factor = self.params["upsampling_factor"] - unit_ids = self.sorting_analyzer.unit_ids - sampling_frequency = self.sorting_analyzer.sampling_frequency + if unit_ids is None: + unit_ids = sorting_analyzer.unit_ids + sampling_frequency = sorting_analyzer.sampling_frequency metrics_single_channel = [m for m in metric_names if m in get_single_channel_template_metric_names()] metrics_multi_channel = [m for m in metric_names if m in get_multi_channel_template_metric_names()] if sparsity is None: - extremum_channels_ids = get_template_extremum_channel( - self.sorting_analyzer, peak_sign=peak_sign, outputs="id" - ) + extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign, outputs="id") template_metrics = pd.DataFrame(index=unit_ids, columns=metric_names) else: @@ -182,16 +207,17 @@ def _run(self, verbose=False): ) template_metrics = pd.DataFrame(index=multi_index, columns=metric_names) - all_templates = get_dense_templates_array(self.sorting_analyzer, return_scaled=True) + all_templates = get_dense_templates_array(sorting_analyzer, return_scaled=True) - channel_locations = self.sorting_analyzer.get_channel_locations() + channel_locations = sorting_analyzer.get_channel_locations() - for unit_index, unit_id in enumerate(unit_ids): + for unit_id in unit_ids: + unit_index = sorting_analyzer.sorting.id_to_index(unit_id) template_all_chans = all_templates[unit_index] chan_ids = np.array(extremum_channels_ids[unit_id]) if chan_ids.ndim == 0: chan_ids = [chan_ids] - chan_ind = self.sorting_analyzer.channel_ids_to_indices(chan_ids) + chan_ind = sorting_analyzer.channel_ids_to_indices(chan_ids) template = template_all_chans[:, chan_ind] # compute single_channel metrics @@ -225,8 +251,8 @@ def _run(self, verbose=False): for metric_name in metrics_multi_channel: # retrieve template (with sparsity if waveform extractor is sparse) template = all_templates[unit_index, :, :] - if self.sorting_analyzer.is_sparse(): - mask = self.sorting_analyzer.sparsity.mask[unit_index, :] + if sorting_analyzer.is_sparse(): + mask = sorting_analyzer.sparsity.mask[unit_index, :] template = template[:, mask] if template.shape[1] < self.min_channels_for_multi_channel_warning: @@ -234,8 +260,8 @@ def _run(self, verbose=False): f"With less than {self.min_channels_for_multi_channel_warning} channels, " "multi-channel metrics might not be reliable." 
) - if self.sorting_analyzer.is_sparse(): - channel_locations_sparse = channel_locations[self.sorting_analyzer.sparsity.mask[unit_index]] + if sorting_analyzer.is_sparse(): + channel_locations_sparse = channel_locations[sorting_analyzer.sparsity.mask[unit_index]] else: channel_locations_sparse = channel_locations @@ -255,7 +281,12 @@ def _run(self, verbose=False): **self.params["metrics_kwargs"], ) template_metrics.at[index, metric_name] = value - self.data["metrics"] = template_metrics + return template_metrics + + def _run(self, verbose=False): + self.data["metrics"] = self._compute_metrics( + sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose + ) def _get_data(self): return self.data["metrics"] diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 777f84dfd7..a9592b0b91 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -5,6 +5,7 @@ from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension from ..core.template_tools import get_dense_templates_array +from ..core.sparsity import ChannelSparsity class ComputeTemplateSimilarity(AnalyzerExtension): @@ -63,6 +64,59 @@ def _select_extension_data(self, unit_ids): new_similarity = self.data["similarity"][unit_indices][:, unit_indices] return dict(similarity=new_similarity) + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs + ): + num_shifts = int(self.params["max_lag_ms"] * self.sorting_analyzer.sampling_frequency / 1000) + all_templates_array = get_dense_templates_array( + new_sorting_analyzer, return_scaled=self.sorting_analyzer.return_scaled + ) + + keep = np.isin(new_sorting_analyzer.unit_ids, new_unit_ids) + new_templates_array = all_templates_array[keep, :, :] + if new_sorting_analyzer.sparsity is None: + new_sparsity = None + else: + new_sparsity = ChannelSparsity( + new_sorting_analyzer.sparsity.mask[keep, :], new_unit_ids, new_sorting_analyzer.channel_ids + ) + + new_similarity = compute_similarity_with_templates_array( + new_templates_array, + all_templates_array, + method=self.params["method"], + num_shifts=num_shifts, + support=self.params["support"], + sparsity=new_sparsity, + other_sparsity=new_sorting_analyzer.sparsity, + ) + + old_similarity = self.data["similarity"] + + all_new_unit_ids = new_sorting_analyzer.unit_ids + n = all_new_unit_ids.size + similarity = np.zeros((n, n), dtype=old_similarity.dtype) + + # copy old similarity + for unit_ind1, unit_id1 in enumerate(all_new_unit_ids): + if unit_id1 not in new_unit_ids: + old_ind1 = self.sorting_analyzer.sorting.id_to_index(unit_id1) + for unit_ind2, unit_id2 in enumerate(all_new_unit_ids): + if unit_id2 not in new_unit_ids: + old_ind2 = self.sorting_analyzer.sorting.id_to_index(unit_id2) + s = self.data["similarity"][old_ind1, old_ind2] + similarity[unit_ind1, unit_ind2] = s + similarity[unit_ind1, unit_ind2] = s + + # insert new similarity both way + for unit_ind, unit_id in enumerate(all_new_unit_ids): + if unit_id in new_unit_ids: + new_index = list(new_unit_ids).index(unit_id) + similarity[unit_ind, :] = new_similarity[new_index, :] + similarity[:, unit_ind] = new_similarity[new_index, :] + + return dict(similarity=similarity) + def _run(self, verbose=False): num_shifts = int(self.params["max_lag_ms"] * self.sorting_analyzer.sampling_frequency / 1000) templates_array = 
get_dense_templates_array( @@ -114,6 +168,8 @@ def compute_similarity_with_templates_array( num_channels = templates_array.shape[2] other_num_templates = other_templates_array.shape[0] + same_array = np.array_equal(templates_array, other_templates_array) + mask = None if sparsity is not None and other_sparsity is not None: if support == "intersection": @@ -139,7 +195,14 @@ def compute_similarity_with_templates_array( # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t # So the matrix can be computed only for negative lags and be transposed - for count, shift in enumerate(range(-num_shifts, 1)): + + if same_array: + # optimisation when array are the same because of symetry in shift + shift_loop = range(-num_shifts, 1) + else: + shift_loop = range(-num_shifts, num_shifts + 1) + + for count, shift in enumerate(shift_loop): src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts] tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] for i in range(num_templates): @@ -147,7 +210,8 @@ def compute_similarity_with_templates_array( tgt_templates = tgt_sliced_templates[overlapping_templates[i]] for gcount, j in enumerate(overlapping_templates[i]): # symmetric values are handled later - if num_templates == other_num_templates and j < i: + if same_array and j < i: + # no need exhaustive looping when same template continue src = src_template[:, mask[i, j]].reshape(1, -1) tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1) @@ -164,11 +228,12 @@ def compute_similarity_with_templates_array( distances[count, i, j] /= norm_i + norm_j else: distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="cosine") - if num_templates == other_num_templates: + + if same_array: distances[count, j, i] = distances[count, i, j] - if num_shifts != 0: - distances[num_shifts_both_sides - count - 1] = distances[count].T + if same_array and num_shifts != 0: + distances[num_shifts_both_sides - count - 1] = distances[count].T distances = np.min(distances, axis=0) similarity = 1 - distances diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 52dbaf23d4..c93e941033 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -101,7 +101,7 @@ def _prepare_sorting_analyzer(self, format, sparse, extension_class): sorting_analyzer = self.get_sorting_analyzer( self.recording, self.sorting, format=format, sparsity=sparsity_, name=extension_class.extension_name ) - sorting_analyzer.compute("random_spikes", max_spikes_per_unit=50, seed=2205) + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=20, seed=2205) for dependency_name in extension_class.depend_on: if "|" in dependency_name: @@ -133,6 +133,11 @@ def _check_one(self, sorting_analyzer, extension_class, params): sliced = sorting_analyzer.select_units(some_unit_ids, format="memory") assert np.array_equal(sliced.unit_ids, sorting_analyzer.unit_ids[::2]) + some_merges = [sorting_analyzer.unit_ids[:2].tolist()] + num_units_after_merge = len(sorting_analyzer.unit_ids) - 1 + merged = sorting_analyzer.merge_units(some_merges, format="memory", mode="soft", sparsity_overlap=0.0) + assert len(merged.unit_ids) == num_units_after_merge + def run_extension_tests(self, extension_class, params): """ Convenience function to perform all checks on the extension 
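Reviewer note: a small sketch of calling the refactored ``compute_similarity_with_templates_array`` directly, mirroring the parametrized test added later in this patch; the array shapes are arbitrary.

.. code-block:: python

    import numpy as np
    from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array

    rng = np.random.default_rng(seed=2205)
    templates_array = rng.random(size=(2, 20, 5))        # (num_templates, num_samples, num_channels)
    other_templates_array = rng.random(size=(4, 20, 5))

    # when the two arrays are identical the symmetric shift shortcut is used,
    # otherwise both negative and positive lags are scanned explicitly
    similarity = compute_similarity_with_templates_array(
        templates_array, other_templates_array, method="cosine", num_shifts=8
    )
    print(similarity.shape)  # expected (num_templates, num_other_templates)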
diff --git a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py new file mode 100644 index 0000000000..bf0000135c --- /dev/null +++ b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py @@ -0,0 +1,190 @@ +import pytest + +import time +import numpy as np + +from spikeinterface import ( + create_sorting_analyzer, + generate_ground_truth_recording, + set_global_job_kwargs, + get_template_extremum_amplitude, +) +from spikeinterface.core.generate import inject_some_split_units + + +def get_dataset(): + recording, sorting = generate_ground_truth_recording( + durations=[30.0], + sampling_frequency=16000.0, + num_channels=10, + num_units=10, + generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + seed=2205, + ) + + # since templates are going to be averaged and this might be a problem for amplitude scaling + # we select the 3 units with the largest templates to split + analyzer_raw = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) + analyzer_raw.compute(["random_spikes", "templates"]) + # select 3 largest templates to split + sort_by_amp = np.argsort(list(get_template_extremum_amplitude(analyzer_raw).values()))[::-1] + split_ids = sorting.unit_ids[sort_by_amp][:3] + + sorting_with_splits, other_ids = inject_some_split_units( + sorting, num_split=3, split_ids=split_ids, output_ids=True, seed=0 + ) + return recording, sorting_with_splits, other_ids + + +@pytest.fixture(scope="module") +def dataset(): + return get_dataset() + + +@pytest.mark.parametrize("sparse", [False, True]) +def test_SortingAnalyzer_merge_all_extensions(dataset, sparse): + set_global_job_kwargs(n_jobs=1) + + recording, sorting, other_ids = dataset + + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=sparse) + + # we apply the merges according to the artificial splits + merges = [list(v) for v in other_ids.values()] + split_unit_ids = np.ravel(merges) + unmerged_unit_ids = sorting_analyzer.unit_ids[~np.isin(sorting_analyzer.unit_ids, split_unit_ids)] + + # even if this is in postprocessing, we make an extension for quality metrics + extension_dict = { + "noise_levels": dict(), + "random_spikes": dict(), + "waveforms": dict(), + "templates": dict(), + "principal_components": dict(), + "spike_amplitudes": dict(), + "template_similarity": dict(), + "correlograms": dict(), + "isi_histograms": dict(), + "amplitude_scalings": dict(handle_collisions=False), # otherwise hard mode could fail due to dropped spikes + "spike_locations": dict(method="center_of_mass"), # trick to avoid UserWarning + "unit_locations": dict(), + "template_metrics": dict(), + "quality_metrics": dict(metric_names=["firing_rate", "isi_violation", "snr"]), + } + extension_data_type = { + "noise_levels": None, + "templates": "unit", + "isi_histograms": "unit", + "unit_locations": "unit", + "spike_amplitudes": "spike", + "amplitude_scalings": "spike", + "spike_locations": "spike", + "quality_metrics": "pandas", + "template_metrics": "pandas", + "correlograms": "matrix", + "template_similarity": "matrix", + "principal_components": "random", + "waveforms": "random", + "random_spikes": "random_spikes", + } + data_with_miltiple_returns = ["isi_histograms", "correlograms"] + + # due to incremental PCA, hard computation could result in different results for PCA + # the model is differents always + random_computation = 
["principal_components"] + + sorting_analyzer.compute(extension_dict, n_jobs=1) + + # TODO: still some UserWarnings for n_jobs, where from? + t0 = time.perf_counter() + analyzer_merged_hard = sorting_analyzer.merge_units( + merge_unit_groups=merges, censor_ms=2, merging_mode="hard", n_jobs=1 + ) + t_hard = time.perf_counter() - t0 + + t0 = time.perf_counter() + analyzer_merged_soft = sorting_analyzer.merge_units( + merge_unit_groups=merges, censor_ms=2, merging_mode="soft", sparsity_overlap=0.0, n_jobs=1 + ) + t_soft = time.perf_counter() - t0 + + # soft must faster + assert t_soft < t_hard + np.testing.assert_array_equal(analyzer_merged_hard.unit_ids, analyzer_merged_soft.unit_ids) + new_unit_ids = list(np.arange(max(split_unit_ids) + 1, max(split_unit_ids) + 1 + len(merges))) + np.testing.assert_array_equal(analyzer_merged_hard.unit_ids, list(unmerged_unit_ids) + new_unit_ids) + + for ext in extension_dict: + # 1. check that data are exactly the same for unchanged units between hard/soft/original + data_original = sorting_analyzer.get_extension(ext).get_data() + data_hard = analyzer_merged_hard.get_extension(ext).get_data() + data_soft = analyzer_merged_soft.get_extension(ext).get_data() + if ext in data_with_miltiple_returns: + data_original = data_original[0] + data_hard = data_hard[0] + data_soft = data_soft[0] + data_original_unmerged = get_extension_data_for_units( + sorting_analyzer, data_original, unmerged_unit_ids, extension_data_type[ext] + ) + data_hard_unmerged = get_extension_data_for_units( + analyzer_merged_hard, data_hard, unmerged_unit_ids, extension_data_type[ext] + ) + data_soft_unmerged = get_extension_data_for_units( + analyzer_merged_soft, data_soft, unmerged_unit_ids, extension_data_type[ext] + ) + + np.testing.assert_array_equal(data_original_unmerged, data_soft_unmerged) + + if ext not in random_computation: + np.testing.assert_array_equal(data_original_unmerged, data_hard_unmerged) + else: + print(f"Skipping hard test for {ext} due to randomness in computation") + + # 2. 
check that soft/hard data are similar for merged units + data_hard_merged = get_extension_data_for_units( + analyzer_merged_hard, data_hard, new_unit_ids, extension_data_type[ext] + ) + data_soft_merged = get_extension_data_for_units( + analyzer_merged_soft, data_soft, new_unit_ids, extension_data_type[ext] + ) + + if ext not in random_computation: + if extension_data_type[ext] == "pandas": + data_hard_merged = data_hard_merged.dropna().to_numpy().astype("float") + data_soft_merged = data_soft_merged.dropna().to_numpy().astype("float") + if data_hard_merged.dtype.fields is None: + assert np.allclose(data_hard_merged, data_soft_merged, rtol=0.1) + else: + for f in data_hard_merged.dtype.fields: + assert np.allclose(data_hard_merged[f], data_soft_merged[f], rtol=0.1) + + +def get_extension_data_for_units(sorting_analyzer, data, unit_ids, ext_data_type): + unit_indices = sorting_analyzer.sorting.ids_to_indices(unit_ids) + spike_vector = sorting_analyzer.sorting.to_spike_vector() + if ext_data_type is None: + return data + elif ext_data_type == "random_spikes": + random_spikes = spike_vector[data] + unit_mask = np.isin(random_spikes["unit_index"], unit_indices) + # since merging could scramble unit ids and drop spikes, we need to get the original unit ids + return sorting_analyzer.unit_ids[random_spikes[unit_mask]["unit_index"]] + elif ext_data_type == "random": + random_indices = sorting_analyzer.get_extension("random_spikes").get_data() + unit_mask = np.isin(spike_vector[random_indices]["unit_index"], unit_indices) + return data[unit_mask] + elif ext_data_type == "matrix": + return data[unit_indices][:, unit_indices] + elif ext_data_type == "unit": + return data[unit_indices] + elif ext_data_type == "spike": + unit_mask = np.isin(spike_vector["unit_index"], unit_indices) + return data[unit_mask] + elif ext_data_type == "pandas": + return data.loc[unit_ids].dropna() + + +if __name__ == "__main__": + dataset = get_dataset() + test_SortingAnalyzer_merge_all_extensions(dataset, False) diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 38ae3b2c5e..4de86be32b 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -109,6 +109,14 @@ def test_get_projections(self, sparse): assert some_projections.shape[2] == some_channel_ids.size assert 1 not in spike_unit_index + # check correctness + channel_indices = sorting_analyzer.recording.ids_to_indices(some_channel_ids) + for unit_id in some_unit_ids: + unit_index = sorting_analyzer.sorting.id_to_index(unit_id) + spike_mask = spike_unit_index == unit_index + proj_one_unit = ext.get_projections_one_unit(unit_id, sparse=False) + np.testing.assert_array_almost_equal(some_projections[spike_mask], proj_one_unit[:, :, channel_indices]) + @pytest.mark.parametrize("sparse", [True, False]) def test_compute_for_all_spikes(self, sparse): """ @@ -167,3 +175,8 @@ def test_project_new(self): assert new_proj.shape[0] == num_spike assert new_proj.shape[1] == n_components assert new_proj.shape[2] == ext_pca.data["pca_projection"].shape[2] + + +if __name__ == "__main__": + test = TestPrincipalComponentsExtension() + test.test_get_projections(sparse=True) diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index f98a5624db..cc6797c262 100644 --- 
a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -1,10 +1,13 @@ import pytest +import numpy as np + from spikeinterface.postprocessing.tests.common_extension_tests import ( AnalyzerExtensionCommonTestSuite, ) from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap, ComputeTemplateSimilarity +from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array class TestSimilarityExtension(AnalyzerExtensionCommonTestSuite): @@ -45,3 +48,42 @@ def test_check_equal_template_with_distribution_overlap(self): waveforms1 = wf_ext.get_waveforms_one_unit(unit_id1) assert not check_equal_template_with_distribution_overlap(waveforms0, waveforms1) + + +@pytest.mark.parametrize( + "params", + [ + dict(method="cosine"), + dict(method="cosine", num_shifts=8), + dict(method="l2"), + dict(method="l1", support="intersection"), + dict(method="l2", support="union"), + dict(method="cosine", support="dense"), + ], +) +def test_compute_similarity_with_templates_array(params): + # TODO @ pierre please make more test here + + rng = np.random.default_rng(seed=2205) + templates_array = rng.random(size=(2, 20, 5)) + other_templates_array = rng.random(size=(4, 20, 5)) + + similarity = compute_similarity_with_templates_array(templates_array, other_templates_array, **params) + print(similarity.shape) + + +if __name__ == "__main__": + from spikeinterface.postprocessing.tests.common_extension_tests import get_dataset + from spikeinterface.core import estimate_sparsity + from pathlib import Path + + test = TestSimilarityExtension() + + test.recording, test.sorting = get_dataset() + + test.sparsity = estimate_sparsity(test.sorting, test.recording, method="radius", radius_um=20) + test.cache_folder = Path("./cache_folder") + test.test_extension(params=dict(method="l2")) + + # params = dict(method="cosine", num_shifts=8) + # test_compute_similarity_with_templates_array(params) diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index 9435030775..516f22e31e 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -4,12 +4,10 @@ import warnings from ..core.sortinganalyzer import register_result_extension, AnalyzerExtension -from .localization_tools import ( - compute_center_of_mass, - compute_grid_convolution, - compute_monopolar_triangulation, -) +from .localization_tools import _unit_location_methods + +# this dict is for peak location dtype_localize_by_method = { "center_of_mass": [("x", "float64"), ("y", "float64")], "grid_convolution": [("x", "float64"), ("y", "float64"), ("z", "float64")], @@ -49,7 +47,8 @@ def __init__(self, sorting_analyzer): AnalyzerExtension.__init__(self, sorting_analyzer) def _set_params(self, method="monopolar_triangulation", **method_kwargs): - params = dict(method=method, method_kwargs=method_kwargs) + params = dict(method=method) + params.update(method_kwargs) return params def _select_extension_data(self, unit_ids): @@ -57,27 +56,49 @@ def _select_extension_data(self, unit_ids): new_unit_location = self.data["unit_locations"][unit_inds] return dict(unit_locations=new_unit_location) + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs + ): + old_unit_locations = self.data["unit_locations"] + num_dims = 
old_unit_locations.shape[1] + + method = self.params.get("method") + method_kwargs = self.params.copy() + method_kwargs.pop("method") + func = _unit_location_methods[method] + new_unit_locations = func(new_sorting_analyzer, unit_ids=new_unit_ids, **method_kwargs) + assert new_unit_locations.shape[0] == len(new_unit_ids) + + all_new_unit_ids = new_sorting_analyzer.unit_ids + unit_location = np.zeros((len(all_new_unit_ids), num_dims), dtype=old_unit_locations.dtype) + for unit_index, unit_id in enumerate(all_new_unit_ids): + if unit_id not in new_unit_ids: + old_index = self.sorting_analyzer.sorting.id_to_index(unit_id) + unit_location[unit_index] = old_unit_locations[old_index] + else: + new_index = list(new_unit_ids).index(unit_id) + unit_location[unit_index] = new_unit_locations[new_index] + + return dict(unit_locations=unit_location) + def _run(self, verbose=False): - method = self.params["method"] - method_kwargs = self.params["method_kwargs"] + method = self.params.get("method") + method_kwargs = self.params.copy() + method_kwargs.pop("method") - assert method in possible_localization_methods + if method not in _unit_location_methods: + raise ValueError(f"Wrong method for unit_locations: it should be in {list(_unit_location_methods.keys())}") - if method == "center_of_mass": - unit_location = compute_center_of_mass(self.sorting_analyzer, **method_kwargs) - elif method == "grid_convolution": - unit_location = compute_grid_convolution(self.sorting_analyzer, **method_kwargs) - elif method == "monopolar_triangulation": - unit_location = compute_monopolar_triangulation(self.sorting_analyzer, **method_kwargs) - self.data["unit_locations"] = unit_location + func = _unit_location_methods[method] + self.data["unit_locations"] = func(self.sorting_analyzer, **method_kwargs) def get_data(self, outputs="numpy"): if outputs == "numpy": return self.data["unit_locations"] elif outputs == "by_unit": locations_by_unit = {} - for unit_ind, unit_id in enumerate(self.sorting_analyzer.unit_ids): - locations_by_unit[unit_id] = self.data["unit_locations"][unit_ind] + for unit_index, unit_id in enumerate(self.sorting_analyzer.unit_ids): + locations_by_unit[unit_id] = self.data["unit_locations"][unit_index] return locations_by_unit diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 24165da5b3..7465d58737 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -593,6 +593,9 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ sorting = sorting_analyzer.sorting + if unit_ids is None: + unit_ids = sorting.unit_ids + spike_counts = sorting.count_num_spikes_per_unit(outputs="dict") spikes = sorting.to_spike_vector() @@ -603,21 +606,15 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ for sync_idx, synchrony_size in enumerate(synchrony_sizes_np): sync_id_metrics_dict = {} for i, unit_id in enumerate(all_unit_ids): + if unit_id not in unit_ids: + continue if spike_counts[unit_id] != 0: sync_id_metrics_dict[unit_id] = synchrony_counts[sync_idx][i] / spike_counts[unit_id] else: sync_id_metrics_dict[unit_id] = 0 synchrony_metrics_dict[f"sync_spike_{synchrony_size}"] = sync_id_metrics_dict - if np.all(unit_ids == None) or (len(unit_ids) == len(all_unit_ids)): - return res(**synchrony_metrics_dict) - else: - reduced_synchrony_metrics_dict = {} - for key in synchrony_metrics_dict: - reduced_synchrony_metrics_dict[key] = {
unit_id: synchrony_metrics_dict[key][unit_id] for unit_id in unit_ids - } - return res(**reduced_synchrony_metrics_dict) + return res(**synchrony_metrics_dict) _default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8)) @@ -1036,6 +1033,7 @@ def compute_drift_metrics( spike_locations_by_unit = {} for unit_id in unit_ids: unit_index = sorting.id_to_index(unit_id) + # TODO @alessio this is very slow this sjould be done with spike_vector_to_indices() in code spike_mask = spikes["unit_index"] == unit_index spike_locations_by_unit[unit_id] = spike_locations[spike_mask] @@ -1074,8 +1072,9 @@ def compute_drift_metrics( # reference positions are the medians across segments reference_positions = np.zeros(len(unit_ids)) - for unit_ind, unit_id in enumerate(unit_ids): - reference_positions[unit_ind] = np.median(spike_locations_by_unit[unit_id][direction]) + for i, unit_id in enumerate(unit_ids): + unit_ind = sorting.id_to_index(unit_id) + reference_positions[i] = np.median(spike_locations_by_unit[unit_id][direction]) # now compute median positions and concatenate them over segments median_position_segments = None @@ -1098,7 +1097,8 @@ def compute_drift_metrics( spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction] for i, unit_id in enumerate(unit_ids): - mask = spikes_in_bin["unit_index"] == sorting.id_to_index(unit_id) + unit_ind = sorting.id_to_index(unit_id) + mask = spikes_in_bin["unit_index"] == unit_ind if np.sum(mask) >= min_spikes_per_interval: median_positions[i, bin_index] = np.median(spike_locations_in_bin[mask]) if median_position_segments is None: @@ -1108,8 +1108,8 @@ def compute_drift_metrics( # finally, compute deviations and drifts position_diffs = median_position_segments - reference_positions[:, None] - for unit_ind, unit_id in enumerate(unit_ids): - position_diff = position_diffs[unit_ind] + for i, unit_id in enumerate(unit_ids): + position_diff = position_diffs[i] if np.any(np.isnan(position_diff)): # deal with nans: if more than 50% nans --> set to nan if np.sum(np.isnan(position_diff)) > min_fraction_valid_intervals * len(position_diff): diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index fa1940c2ba..1c5a491bf8 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -162,8 +162,8 @@ def compute_pc_metrics( if progress_bar: units_loop = tqdm(units_loop, desc="calculate pc_metrics", total=len(unit_ids)) - for unit_ind, unit_id in units_loop: - pca_metrics_unit = pca_metrics_one_unit(items[unit_ind]) + for i, unit_id in units_loop: + pca_metrics_unit = pca_metrics_one_unit(items[i]) for metric_name, metric in pca_metrics_unit.items(): pc_metrics[metric_name][unit_id] = metric elif run_in_parallel and non_nn_metrics: diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index f3eecb20bf..0c7cf25237 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -86,7 +86,25 @@ def _select_extension_data(self, unit_ids): new_data = dict(metrics=new_metrics) return new_data - def _run(self, verbose=False, **job_kwargs): + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs + ): + import pandas as pd + + old_metrics = self.data["metrics"] + + all_unit_ids = new_sorting_analyzer.unit_ids + 
not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids)] + + metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) + + metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] + metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs) + + new_data = dict(metrics=metrics) + return new_data + + def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job_kwargs): """ Compute quality metrics. """ @@ -100,15 +118,19 @@ def _run(self, verbose=False, **job_kwargs): n_jobs = job_kwargs["n_jobs"] progress_bar = job_kwargs["progress_bar"] - sorting = self.sorting_analyzer.sorting - unit_ids = sorting.unit_ids - non_empty_unit_ids = sorting.get_non_empty_unit_ids() - empty_unit_ids = unit_ids[~np.isin(unit_ids, non_empty_unit_ids)] - if len(empty_unit_ids) > 0: - warnings.warn( - f"Units {empty_unit_ids} are empty. Quality metrcs will be set to NaN " - f"for these units.\n To remove empty units, use `sorting.remove_empty_units()`." - ) + if unit_ids is None: + sorting = sorting_analyzer.sorting + unit_ids = sorting.unit_ids + non_empty_unit_ids = sorting.get_non_empty_unit_ids() + empty_unit_ids = unit_ids[~np.isin(unit_ids, non_empty_unit_ids)] + if len(empty_unit_ids) > 0: + warnings.warn( + f"Units {empty_unit_ids} are empty. Quality metrics will be set to NaN " + f"for these units.\n To remove empty units, use `sorting.remove_empty_units()`." + ) + else: + non_empty_unit_ids = unit_ids + empty_unit_ids = [] import pandas as pd @@ -126,7 +148,7 @@ def _run(self, verbose=False, **job_kwargs): func = _misc_metric_name_to_func[metric_name] params = qm_params[metric_name] if metric_name in qm_params else {} - res = func(self.sorting_analyzer, unit_ids=non_empty_unit_ids, **params) + res = func(sorting_analyzer, unit_ids=non_empty_unit_ids, **params) # QM with uninstall dependencies might return None if res is not None: if isinstance(res, dict): @@ -141,10 +163,10 @@ def _run(self, verbose=False, **job_kwargs): # metrics based on PCs pc_metric_names = [k for k in metric_names if k in _possible_pc_metric_names] if len(pc_metric_names) > 0 and not self.params["skip_pc_metrics"]: - if not self.sorting_analyzer.has_extension("principal_components"): + if not sorting_analyzer.has_extension("principal_components"): raise ValueError("waveform_principal_component must be provied") pc_metrics = compute_pc_metrics( - self.sorting_analyzer, + sorting_analyzer, unit_ids=non_empty_unit_ids, metric_names=pc_metric_names, # sparsity=sparsity, @@ -160,7 +182,12 @@ def _run(self, verbose=False, **job_kwargs): if len(empty_unit_ids) > 0: metrics.loc[empty_unit_ids] = np.nan - self.data["metrics"] = metrics + return metrics + + def _run(self, verbose=False, **job_kwargs): + self.data["metrics"] = self._compute_metrics( + sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, **job_kwargs + ) def _get_data(self): return self.data["metrics"] diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 90b622b9ab..bb222200e9 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -573,19 +573,21 @@ def test_calculate_sd_ratio(sorting_analyzer_simple): sorting_analyzer = _sorting_analyzer_simple() print(sorting_analyzer) + test_unit_structure_in_output(_small_sorting_analyzer()) + # 
test_calculate_firing_rate_num_spikes(sorting_analyzer) # test_calculate_snrs(sorting_analyzer) - test_calculate_amplitude_cutoff(sorting_analyzer) + # test_calculate_amplitude_cutoff(sorting_analyzer) # test_calculate_presence_ratio(sorting_analyzer) # test_calculate_amplitude_median(sorting_analyzer) # test_calculate_sliding_rp_violations(sorting_analyzer) # test_calculate_drift_metrics(sorting_analyzer) - test_synchrony_metrics(sorting_analyzer) - test_synchrony_metrics_unit_id_subset(sorting_analyzer) - test_synchrony_metrics_no_unit_ids(sorting_analyzer) + # test_synchrony_metrics(sorting_analyzer) + # test_synchrony_metrics_unit_id_subset(sorting_analyzer) + # test_synchrony_metrics_no_unit_ids(sorting_analyzer) # test_calculate_firing_range(sorting_analyzer) # test_calculate_amplitude_cv_metrics(sorting_analyzer) - test_calculate_sd_ratio(sorting_analyzer) + # test_calculate_sd_ratio(sorting_analyzer) # sorting_analyzer_violations = _sorting_analyzer_violations() # print(sorting_analyzer_violations) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index be75877f02..29bc5e0f73 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -368,7 +368,7 @@ def final_cleaning_circus(recording, sorting, templates, **merging_kwargs): sa = create_sorting_analyzer(sorting, recording, format="memory", sparsity=sparsity) sa.extensions["templates"] = ComputeTemplates(sa) - sa.extensions["templates"].params = {"nbefore": templates.nbefore} + sa.extensions["templates"].params = {"ms_before": templates.ms_before, "ms_after": templates.ms_after} sa.extensions["templates"].data["average"] = templates_array sa.compute("unit_locations", method="monopolar_triangulation") merges = get_potential_auto_merge(sa, **merging_kwargs) diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index 6d2ad09239..b578eb4478 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -44,6 +44,8 @@ def get_localization_pipeline_nodes( method in possible_localization_methods ), f"Method {method} is not supported. Choose from {possible_localization_methods}" + # TODO : this is a bad idea because it triggers a warning when n_jobs is not set globally + # because the job_kwargs is never transmitted until here method_kwargs, job_kwargs = split_job_kwargs(kwargs) if method == "center_of_mass":