Commit 74f93b4
Merge branch 'main' into add-verbose-to-mda
chrishalcrow authored Nov 13, 2024
2 parents 8845d3d + 0e13973 commit 74f93b4
Showing 35 changed files with 1,022 additions and 466 deletions.
doc/modules/core.rst (2 changes: 1 addition & 1 deletion)
@@ -385,7 +385,7 @@ and merging unit groups.
sorting_analyzer_select = sorting_analyzer.select_units(unit_ids=[0, 1, 2, 3])
sorting_analyzer_remove = sorting_analyzer.remove_units(remove_unit_ids=[0])
-sorting_analyzer_merge = sorting_analyzer.merge_units([0, 1], [2, 3])
+sorting_analyzer_merge = sorting_analyzer.merge_units([[0, 1], [2, 3]])
All computed extensions will be automatically propagated or merged when curating. Please refer to the
:ref:`modules/curation:Curation module` documentation for more information.
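For reference, a minimal sketch of the corrected call (assuming an existing `sorting_analyzer`, as in the surrounding docs): `merge_units` takes a single list of merge groups rather than separate list arguments.

    # each inner list is one group of unit ids merged into a single new unit
    merge_groups = [[0, 1], [2, 3]]
    sorting_analyzer_merge = sorting_analyzer.merge_units(merge_groups)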
doc/modules/curation.rst (11 changes: 8 additions & 3 deletions)
@@ -88,7 +88,7 @@ The ``censored_period_ms`` parameter is the time window in milliseconds to consi
The :py:func:`~spikeinterface.curation.remove_redundand_units` function removes
redundant units from the sorting output. Redundant units are units that share over
a certain percentage of spikes, by default 80%.
-The function can acto both on a ``BaseSorting`` or a ``SortingAnalyzer`` object.
+The function can act both on a ``BaseSorting`` or a ``SortingAnalyzer`` object.

.. code-block:: python
@@ -102,13 +102,18 @@ The function can acto both on a ``BaseSorting`` or a ``SortingAnalyzer`` object.
)
# remove redundant units from SortingAnalyzer object
-clean_sorting_analyzer = remove_redundant_units(
+# note this returns a cleaned sorting
+clean_sorting = remove_redundant_units(
sorting_analyzer,
duplicate_threshold=0.9,
remove_strategy="min_shift"
)
+# in order to have a SortingAnalyer with only the non-redundant units one must
+# select the designed units remembering to give format and folder if one wants
+# a persistent SortingAnalyzer.
+clean_sorting_analyzer = sorting_analyzer.select_units(clean_sorting.unit_ids)
-We recommend usinf the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps
+We recommend using the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps
the unit (among the redundant ones), with a better template alignment.
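Putting the corrected snippet together, a hedged end-to-end sketch of the recommended workflow described above (threshold and strategy values taken from the docs; `sorting_analyzer` assumed to exist):

    from spikeinterface.curation import remove_redundant_units

    # remove_redundant_units returns a cleaned sorting, not an analyzer
    clean_sorting = remove_redundant_units(
        sorting_analyzer,
        duplicate_threshold=0.9,
        remove_strategy="min_shift",
    )
    # keep only the non-redundant units; pass format/folder here for persistence
    clean_sorting_analyzer = sorting_analyzer.select_units(clean_sorting.unit_ids)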


pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -7,7 +7,7 @@ authors = [
]
description = "Python toolkit for analysis, visualization, and comparison of spike sorting output"
readme = "README.md"
requires-python = ">=3.9,<4.0"
requires-python = ">=3.9,<3.13" # Only numpy 2.1 supported on python 3.13 for windows. We need to wait for fix on neo
classifiers = [
"Programming Language :: Python :: 3 :: Only",
"License :: OSI Approved :: MIT License",
src/spikeinterface/core/analyzer_extension_core.py (93 changes: 61 additions & 32 deletions)
@@ -22,21 +22,23 @@

class ComputeRandomSpikes(AnalyzerExtension):
"""
-AnalyzerExtension that select some random spikes.
+AnalyzerExtension that select somes random spikes.
+This allows for a subsampling of spikes for further calculations and is important
+for managing that amount of memory and speed of computation in the analyzer.
This will be used by the `waveforms`/`templates` extensions.
-This internally use `random_spikes_selection()` parameters are the same.
+This internally uses `random_spikes_selection()` parameters.
Parameters
----------
method: "uniform" | "all", default: "uniform"
method : "uniform" | "all", default: "uniform"
The method to select the spikes
-max_spikes_per_unit: int, default: 500
+max_spikes_per_unit : int, default: 500
The maximum number of spikes per unit, ignored if method="all"
-margin_size: int, default: None
+margin_size : int, default: None
A margin on each border of segments to avoid border spikes, ignored if method="all"
-seed: int or None, default: None
+seed : int or None, default: None
A seed for the random generator, ignored if method="all"
Returns
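As a usage sketch of these parameters (values illustrative; the `compute("random_spikes")` entry point is the one named later in this diff):

    sorting_analyzer.compute(
        "random_spikes",
        method="uniform",         # or "all" to keep every spike
        max_spikes_per_unit=500,  # ignored when method="all"
        seed=2205,                # optional, for reproducibility
    )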
@@ -104,7 +106,7 @@ def get_random_spikes(self):
return self._some_spikes

def get_selected_indices_in_spike_train(self, unit_id, segment_index):
-# usefull for Waveforms extractor backwars compatibility
+# useful for WaveformExtractor backwards compatibility
# In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain
sorting = self.sorting_analyzer.sorting
random_spikes_indices = self.data["random_spikes_indices"]
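A short hedged sketch of this backward-compatibility accessor (unit and segment values illustrative):

    ext = sorting_analyzer.get_extension("random_spikes")
    # indices of the selected spikes within unit 0's spike train in segment 0
    selected = ext.get_selected_indices_in_spike_train(unit_id=0, segment_index=0)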
@@ -133,16 +135,16 @@ class ComputeWaveforms(AnalyzerExtension):
Parameters
----------
-ms_before: float, default: 1.0
+ms_before : float, default: 1.0
The number of ms to extract before the spike events
-ms_after: float, default: 2.0
+ms_after : float, default: 2.0
The number of ms to extract after the spike events
-dtype: None | dtype, default: None
+dtype : None | dtype, default: None
The dtype of the waveforms. If None, the dtype of the recording is used.
Returns
-------
-waveforms: np.ndarray
+waveforms : np.ndarray
Array with computed waveforms with shape (num_random_spikes, num_samples, num_channels)
"""

Expand Down Expand Up @@ -380,7 +382,12 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=N
assert isinstance(operators, list)
for operator in operators:
if isinstance(operator, str):
assert operator in ("average", "std", "median", "mad")
if operator not in ("average", "std", "median", "mad"):
error_msg = (
f"You have entered an operator {operator} in your `operators` argument which is "
f"not supported. Please use any of ['average', 'std', 'median', 'mad'] instead."
)
raise ValueError(error_msg)
else:
assert isinstance(operator, (list, tuple))
assert len(operator) == 2
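A sketch of the behaviour this change introduces (operator value hypothetical): an unsupported string operator now raises a descriptive `ValueError` instead of failing on a bare assert.

    # "mean" is not among ("average", "std", "median", "mad"), so this now raises
    # a ValueError naming the supported operators
    sorting_analyzer.compute("templates", operators=["average", "mean"])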
@@ -405,9 +412,13 @@ def _run(self, verbose=False, **job_kwargs):
self._compute_and_append_from_waveforms(self.params["operators"])

else:
-for operator in self.params["operators"]:
-if operator not in ("average", "std"):
-raise ValueError(f"Computing templates with operators {operator} needs the 'waveforms' extension")
+bad_operator_list = [
+operator for operator in self.params["operators"] if operator not in ("average", "std")
+]
+if len(bad_operator_list) > 0:
+raise ValueError(
+f"Computing templates with operators {bad_operator_list} requires the 'waveforms' extension"
+)

recording = self.sorting_analyzer.recording
sorting = self.sorting_analyzer.sorting
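A sketch of what the refactored check enforces: operators beyond "average" and "std" are computed from stored waveforms, so the `waveforms` extension must exist first (extension names as used elsewhere in this diff):

    sorting_analyzer.compute("random_spikes")
    sorting_analyzer.compute("waveforms")
    # "median" is outside ("average", "std"); without "waveforms" this would raise
    sorting_analyzer.compute("templates", operators=["average", "median"])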
@@ -441,7 +452,7 @@ def _run(self, verbose=False, **job_kwargs):

def _compute_and_append_from_waveforms(self, operators):
if not self.sorting_analyzer.has_extension("waveforms"):
raise ValueError(f"Computing templates with operators {operators} needs the 'waveforms' extension")
raise ValueError(f"Computing templates with operators {operators} requires the 'waveforms' extension")

unit_ids = self.sorting_analyzer.unit_ids
channel_ids = self.sorting_analyzer.channel_ids
@@ -466,7 +477,7 @@ def _compute_and_append_from_waveforms(self, operators):

assert self.sorting_analyzer.has_extension(
"random_spikes"
), "compute templates requires the random_spikes extension. You can run sorting_analyzer.get_random_spikes()"
), "compute 'templates' requires the random_spikes extension. You can run sorting_analyzer.compute('random_spikes')"
some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()
for unit_index, unit_id in enumerate(unit_ids):
spike_mask = some_spikes["unit_index"] == unit_index
@@ -549,9 +560,17 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"):
if operator != "percentile":
key = operator
else:
-assert percentile is not None, "You must provide percentile=..."
+assert percentile is not None, "You must provide percentile=... if `operator=percentile`"
key = f"percentile_{percentile}"

+if key not in self.data.keys():
+error_msg = (
+f"You have entered `operator={key}`, but the only operators calculated are "
+f"{list(self.data.keys())}. Please use one of these as your `operator` in the "
+f"`get_data` function."
+)
+raise ValueError(error_msg)

templates_array = self.data[key]

if outputs == "numpy":
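A hedged retrieval sketch matching the checks above (percentile value illustrative):

    ext = sorting_analyzer.get_extension("templates")
    templates_median = ext.get_data(operator="median")
    # operator="percentile" additionally requires percentile=...
    templates_p90 = ext.get_data(operator="percentile", percentile=90)
    # asking for an operator that was never computed now raises a descriptive ValueError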
@@ -566,7 +585,7 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"):
probe=self.sorting_analyzer.get_probe(),
)
else:
raise ValueError("outputs must be numpy or Templates")
raise ValueError("outputs must be `numpy` or `Templates`")

def get_templates(self, unit_ids=None, operator="average", percentile=None, save=True, outputs="numpy"):
"""
@@ -576,26 +595,26 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
Parameters
----------
-unit_ids: list or None
+unit_ids : list or None
Unit ids to retrieve waveforms for
operator: "average" | "median" | "std" | "percentile", default: "average"
operator : "average" | "median" | "std" | "percentile", default: "average"
The operator to compute the templates
-percentile: float, default: None
+percentile : float, default: None
Percentile to use for operator="percentile"
-save: bool, default True
+save : bool, default: True
In case, the operator is not computed yet it can be saved to folder or zarr
outputs: "numpy" | "Templates"
outputs : "numpy" | "Templates", default: "numpy"
Whether to return a numpy array or a Templates object
Returns
-------
-templates: np.array
+templates : np.array | Templates
The returned templates (num_units, num_samples, num_channels)
"""
if operator != "percentile":
key = operator
else:
-assert percentile is not None, "You must provide percentile=..."
+assert percentile is not None, "You must provide percentile=... if `operator='percentile'`"
key = f"pencentile_{percentile}"

if key in self.data:
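A hedged usage sketch for `get_templates` (unit ids illustrative):

    templates_ext = sorting_analyzer.get_extension("templates")
    unit_templates = templates_ext.get_templates(
        unit_ids=sorting_analyzer.unit_ids[:3],
        operator="average",
        outputs="Templates",  # a Templates object instead of a numpy array
    )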
@@ -632,7 +651,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
is_scaled=self.sorting_analyzer.return_scaled,
)
else:
raise ValueError("outputs must be numpy or Templates")
raise ValueError("`outputs` must be 'numpy' or 'Templates'")

def get_unit_template(self, unit_id, operator="average"):
"""
@@ -642,7 +661,7 @@ def get_unit_template(self, unit_id, operator="average"):
----------
unit_id: str | int
Unit id to retrieve waveforms for
-operator: str
+operator: str, default: "average"
The operator to compute the templates
Returns
Expand Down Expand Up @@ -691,22 +710,23 @@ class ComputeNoiseLevels(AnalyzerExtension):
need_recording = True
use_nodepipeline = False
need_job_kwargs = False
+need_backward_compatibility_on_load = True

def __init__(self, sorting_analyzer):
AnalyzerExtension.__init__(self, sorting_analyzer)

-def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, seed=None):
-params = dict(num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, seed=seed)
+def _set_params(self, **noise_level_params):
+params = noise_level_params.copy()
return params

def _select_extension_data(self, unit_ids):
-# this do not depend on units
+# this does not depend on units
return self.data

def _merge_extension_data(
self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
):
-# this do not depend on units
+# this does not depend on units
return self.data.copy()

def _run(self, verbose=False):
@@ -717,6 +737,15 @@ def _run(self, verbose=False):
def _get_data(self):
return self.data["noise_levels"]

+def _handle_backward_compatibility_on_load(self):
+# The old parameters used to be params=dict(num_chunks_per_segment=20, chunk_size=10000, seed=None)
+# now it is handle more explicitly using random_slices_kwargs=dict()
+for key in ("num_chunks_per_segment", "chunk_size", "seed"):
+if key in self.params:
+if "random_slices_kwargs" not in self.params:
+self.params["random_slices_kwargs"] = dict()
+self.params["random_slices_kwargs"][key] = self.params.pop(key)


register_result_extension(ComputeNoiseLevels)
compute_noise_levels = ComputeNoiseLevels.function_factory()
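A sketch of the parameter migration the compatibility hook performs on load (key names from the code above; the nesting under `random_slices_kwargs` is the new layout):

    # params stored by an older version
    old_params = dict(num_chunks_per_segment=20, chunk_size=10000, seed=None)
    # equivalent params after _handle_backward_compatibility_on_load
    new_params = dict(
        random_slices_kwargs=dict(num_chunks_per_segment=20, chunk_size=10000, seed=None)
    )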
src/spikeinterface/core/baserecordingsnippets.py (6 changes: 4 additions & 2 deletions)
@@ -172,8 +172,10 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=Fals
number_of_device_channel_indices = np.max(list(device_channel_indices) + [0])
if number_of_device_channel_indices >= self.get_num_channels():
error_msg = (
f"The given Probe have 'device_channel_indices' that do not match channel count \n"
f"{number_of_device_channel_indices} vs {self.get_num_channels()} \n"
f"The given Probe either has 'device_channel_indices' that does not match channel count \n"
f"{len(device_channel_indices)} vs {self.get_num_channels()} \n"
f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n"
f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n"
f"device_channel_indices are the following: {device_channel_indices} \n"
f"recording channels are the following: {self.get_channel_ids()} \n"
)
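A sketch of the 0-indexing pitfall the expanded message spells out (probe construction via probeinterface's `generate_linear_probe` is an assumption for illustration):

    from probeinterface import generate_linear_probe

    probe = generate_linear_probe(num_elec=4)
    probe.set_device_channel_indices([0, 1, 2, 3])    # ok: 4 channels, indices 0..3
    # probe.set_device_channel_indices([1, 2, 3, 4])  # would trip the check above:
    # max index 4 equals the channel count, since Python indexing starts at 0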
