Merge branch 'SpikeInterface:main' into meta_merging
yger authored Nov 5, 2024
2 parents 2f80b8d + e525d85 commit 42754ae
Showing 8 changed files with 126 additions and 88 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -7,7 +7,7 @@ authors = [
 ]
 description = "Python toolkit for analysis, visualization, and comparison of spike sorting output"
 readme = "README.md"
-requires-python = ">=3.9,<4.0"
+requires-python = ">=3.9,<3.13"  # Only numpy 2.1 supported on python 3.13 for windows. We need to wait for fix on neo
 classifiers = [
     "Programming Language :: Python :: 3 :: Only",
     "License :: OSI Approved :: MIT License",
79 changes: 49 additions & 30 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -22,21 +22,23 @@
 
 class ComputeRandomSpikes(AnalyzerExtension):
     """
-    AnalyzerExtension that select some random spikes.
+    AnalyzerExtension that select somes random spikes.
+    This allows for a subsampling of spikes for further calculations and is important
+    for managing that amount of memory and speed of computation in the analyzer.
 
     This will be used by the `waveforms`/`templates` extensions.
 
-    This internally use `random_spikes_selection()` parameters are the same.
+    This internally uses `random_spikes_selection()` parameters.
 
     Parameters
     ----------
-    method: "uniform" | "all", default: "uniform"
+    method : "uniform" | "all", default: "uniform"
         The method to select the spikes
-    max_spikes_per_unit: int, default: 500
+    max_spikes_per_unit : int, default: 500
         The maximum number of spikes per unit, ignored if method="all"
-    margin_size: int, default: None
+    margin_size : int, default: None
         A margin on each border of segments to avoid border spikes, ignored if method="all"
-    seed: int or None, default: None
+    seed : int or None, default: None
         A seed for the random generator, ignored if method="all"
 
     Returns
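
A minimal usage sketch for this extension (assuming the public SortingAnalyzer API; the generator call is only a stand-in data source, not part of this diff):

    import spikeinterface.full as si

    # Stand-in data; any recording/sorting pair works here.
    recording, sorting = si.generate_ground_truth_recording(durations=[60.0])
    analyzer = si.create_sorting_analyzer(sorting, recording)

    # Subsample spikes once; the `waveforms`/`templates` extensions reuse this selection.
    analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=500, seed=2205)
    some_spikes = analyzer.get_extension("random_spikes").get_random_spikes()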
@@ -104,7 +106,7 @@ def get_random_spikes(self):
         return self._some_spikes
 
     def get_selected_indices_in_spike_train(self, unit_id, segment_index):
-        # usefull for Waveforms extractor backwars compatibility
+        # useful for WaveformExtractor backwards compatibility
         # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain
         sorting = self.sorting_analyzer.sorting
         random_spikes_indices = self.data["random_spikes_indices"]
@@ -133,16 +135,16 @@ class ComputeWaveforms(AnalyzerExtension):
 
     Parameters
     ----------
-    ms_before: float, default: 1.0
+    ms_before : float, default: 1.0
         The number of ms to extract before the spike events
-    ms_after: float, default: 2.0
+    ms_after : float, default: 2.0
         The number of ms to extract after the spike events
-    dtype: None | dtype, default: None
+    dtype : None | dtype, default: None
         The dtype of the waveforms. If None, the dtype of the recording is used.
 
     Returns
     -------
-    waveforms: np.ndarray
+    waveforms : np.ndarray
         Array with computed waveforms with shape (num_random_spikes, num_samples, num_channels)
     """
@@ -380,7 +382,12 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=N
         assert isinstance(operators, list)
         for operator in operators:
             if isinstance(operator, str):
-                assert operator in ("average", "std", "median", "mad")
+                if operator not in ("average", "std", "median", "mad"):
+                    error_msg = (
+                        f"You have entered an operator {operator} in your `operators` argument which is "
+                        f"not supported. Please use any of ['average', 'std', 'median', 'mad'] instead."
+                    )
+                    raise ValueError(error_msg)
             else:
                 assert isinstance(operator, (list, tuple))
                 assert len(operator) == 2
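
For reference, a sketch of both entry forms this validation accepts (values are illustrative; per the `_run` change below, anything beyond "average"/"std" also needs the `waveforms` extension):

    # A string operator, or a ("percentile", value) pair of length 2.
    analyzer.compute("templates", operators=["average", "std", ("percentile", 95.0)])
    # An unsupported string such as "mean" now raises ValueError with the message built above.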
@@ -405,9 +412,13 @@ def _run(self, verbose=False, **job_kwargs):
             self._compute_and_append_from_waveforms(self.params["operators"])
 
         else:
-            for operator in self.params["operators"]:
-                if operator not in ("average", "std"):
-                    raise ValueError(f"Computing templates with operators {operator} needs the 'waveforms' extension")
+            bad_operator_list = [
+                operator for operator in self.params["operators"] if operator not in ("average", "std")
+            ]
+            if len(bad_operator_list) > 0:
+                raise ValueError(
+                    f"Computing templates with operators {bad_operator_list} requires the 'waveforms' extension"
+                )
 
             recording = self.sorting_analyzer.recording
             sorting = self.sorting_analyzer.sorting
@@ -441,7 +452,7 @@ def _run(self, verbose=False, **job_kwargs):
 
     def _compute_and_append_from_waveforms(self, operators):
         if not self.sorting_analyzer.has_extension("waveforms"):
-            raise ValueError(f"Computing templates with operators {operators} needs the 'waveforms' extension")
+            raise ValueError(f"Computing templates with operators {operators} requires the 'waveforms' extension")
 
         unit_ids = self.sorting_analyzer.unit_ids
         channel_ids = self.sorting_analyzer.channel_ids
@@ -466,7 +477,7 @@ def _compute_and_append_from_waveforms(self, operators):
 
         assert self.sorting_analyzer.has_extension(
             "random_spikes"
-        ), "compute templates requires the random_spikes extension. You can run sorting_analyzer.get_random_spikes()"
+        ), "compute 'templates' requires the random_spikes extension. You can run sorting_analyzer.compute('random_spikes')"
         some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()
         for unit_index, unit_id in enumerate(unit_ids):
             spike_mask = some_spikes["unit_index"] == unit_index
@@ -549,9 +560,17 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"):
         if operator != "percentile":
             key = operator
         else:
-            assert percentile is not None, "You must provide percentile=..."
+            assert percentile is not None, "You must provide percentile=... if `operator=percentile`"
             key = f"percentile_{percentile}"
 
+        if key not in self.data.keys():
+            error_msg = (
+                f"You have entered `operator={key}`, but the only operators calculated are "
+                f"{list(self.data.keys())}. Please use one of these as your `operator` in the "
+                f"`get_data` function."
+            )
+            raise ValueError(error_msg)
+
         templates_array = self.data[key]
 
         if outputs == "numpy":
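
A sketch of the accessor these guards protect, assuming templates were computed as in the earlier sketches:

    ext = analyzer.get_extension("templates")
    mean_templates = ext.get_data(operator="average")
    # operator="percentile" must be paired with percentile=..., and the resulting
    # key (e.g. "percentile_95.0") must already exist in ext.data, else ValueError.
    p95_templates = ext.get_data(operator="percentile", percentile=95.0)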
@@ -566,7 +585,7 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"):
                 probe=self.sorting_analyzer.get_probe(),
             )
         else:
-            raise ValueError("outputs must be numpy or Templates")
+            raise ValueError("outputs must be `numpy` or `Templates`")
 
     def get_templates(self, unit_ids=None, operator="average", percentile=None, save=True, outputs="numpy"):
         """
@@ -576,26 +595,26 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
 
         Parameters
         ----------
-        unit_ids: list or None
+        unit_ids : list or None
             Unit ids to retrieve waveforms for
-        operator: "average" | "median" | "std" | "percentile", default: "average"
+        operator : "average" | "median" | "std" | "percentile", default: "average"
             The operator to compute the templates
-        percentile: float, default: None
+        percentile : float, default: None
             Percentile to use for operator="percentile"
-        save: bool, default True
+        save : bool, default: True
             In case, the operator is not computed yet it can be saved to folder or zarr
-        outputs: "numpy" | "Templates"
+        outputs : "numpy" | "Templates", default: "numpy"
             Whether to return a numpy array or a Templates object
 
         Returns
         -------
-        templates: np.array
+        templates : np.array | Templates
             The returned templates (num_units, num_samples, num_channels)
         """
         if operator != "percentile":
             key = operator
         else:
-            assert percentile is not None, "You must provide percentile=..."
+            assert percentile is not None, "You must provide percentile=... if `operator='percentile'`"
             key = f"pencentile_{percentile}"
 
         if key in self.data:
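
A sketch of the public call (reading the surrounding code, `save=True` lets a not-yet-computed operator be computed and stored when waveforms are available):

    # Subset of units; operator/percentile follow the docstring above.
    templates = ext.get_templates(unit_ids=analyzer.unit_ids[:2], operator="median", save=True)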
@@ -632,7 +651,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
                 is_scaled=self.sorting_analyzer.return_scaled,
             )
         else:
-            raise ValueError("outputs must be numpy or Templates")
+            raise ValueError("`outputs` must be 'numpy' or 'Templates'")
 
     def get_unit_template(self, unit_id, operator="average"):
         """
@@ -642,7 +661,7 @@ def get_unit_template(self, unit_id, operator="average"):
         ----------
         unit_id: str | int
             Unit id to retrieve waveforms for
-        operator: str
+        operator: str, default: "average"
             The operator to compute the templates
 
         Returns
@@ -701,13 +720,13 @@ def _set_params(self, **noise_level_params):
         return params
 
     def _select_extension_data(self, unit_ids):
-        # this do not depend on units
+        # this does not depend on units
         return self.data
 
     def _merge_extension_data(
         self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
     ):
-        # this do not depend on units
+        # this does not depend on units
         return self.data.copy()
 
     def _run(self, verbose=False):
8 changes: 4 additions & 4 deletions src/spikeinterface/exporters/report.py
@@ -20,10 +20,10 @@ def export_report(
     **job_kwargs,
 ):
     """
-    Exports a SI spike sorting report. The report includes summary figures of the spike sorting output
-    (e.g. amplitude distributions, unit localization and depth VS amplitude) as well as unit-specific reports,
-    that include waveforms, templates, template maps, ISI distributions, and more.
+    Exports a SI spike sorting report. The report includes summary figures of the spike sorting output.
+    What is plotted depends on what has been calculated. Unit locations and unit waveforms are always included.
+    Unit waveform densities, correlograms and spike amplitudes are plotted if `waveforms`, `correlograms`,
+    and `spike_amplitudes` have been computed for the given `sorting_analyzer`.
 
     Parameters
     ----------
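
A usage sketch matching the reworked docstring (the exact set of optional extensions is an assumption; only what exists on the analyzer gets plotted):

    from spikeinterface.exporters import export_report

    # Compute the optional extensions first if you want them in the report.
    analyzer.compute(["waveforms", "templates", "correlograms", "spike_amplitudes"])
    export_report(analyzer, "./report_folder")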
7 changes: 4 additions & 3 deletions src/spikeinterface/sorters/basesorter.py
@@ -145,9 +145,10 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo
         elif recording.check_serializability("pickle"):
             recording.dump(output_folder / "spikeinterface_recording.pickle", relative_to=output_folder)
         else:
-            # TODO: deprecate and finally remove this after 0.100
-            d = {"warning": "The recording is not serializable to json"}
-            rec_file.write_text(json.dumps(d, indent=4), encoding="utf8")
+            raise RuntimeError(
+                "This recording is not serializable and so can not be sorted. Consider `recording.save()` to save a "
+                "compatible binary file."
+            )
 
         return output_folder
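
A sketch of the workaround the new error points to ("spykingcircus2" is only an example sorter name; the `folder` argument follows the current run_sorter signature):

    from spikeinterface.sorters import run_sorter

    # recording.save() writes a binary copy that serializes cleanly,
    # so initialize_folder no longer raises for it.
    recording_saved = recording.save(folder="./preprocessed")
    sorting = run_sorter("spykingcircus2", recording_saved, folder="./sc2_output")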

8 changes: 7 additions & 1 deletion src/spikeinterface/widgets/autocorrelograms.py
@@ -9,7 +9,13 @@ class AutoCorrelogramsWidget(CrossCorrelogramsWidget):
     # the doc is copied form CrossCorrelogramsWidget
 
     def __init__(self, *args, **kargs):
-        CrossCorrelogramsWidget.__init__(self, *args, **kargs)
+        _ = kargs.pop("min_similarity_for_correlograms", 0.0)
+        CrossCorrelogramsWidget.__init__(
+            self,
+            *args,
+            **kargs,
+            min_similarity_for_correlograms=None,
+        )
 
     def plot_matplotlib(self, data_plot, **backend_kwargs):
         import matplotlib.pyplot as plt
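
In effect, the auto-correlogram widget now discards any user-supplied threshold and forces `None`, so the kwarg can no longer collide with the explicit argument passed to the parent; a hedged sketch:

    import spikeinterface.widgets as sw

    # The threshold is popped and forced to None for auto-correlograms...
    sw.plot_autocorrelograms(analyzer, unit_ids=analyzer.unit_ids[:4])
    # ...but still applies to cross-correlograms with the sortingview backend.
    sw.plot_crosscorrelograms(analyzer, min_similarity_for_correlograms=0.2)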
3 changes: 2 additions & 1 deletion src/spikeinterface/widgets/crosscorrelograms.py
@@ -21,7 +21,8 @@ class CrossCorrelogramsWidget(BaseWidget):
         List of unit ids
     min_similarity_for_correlograms : float, default: 0.2
         For sortingview backend. Threshold for computing pair-wise cross-correlograms.
-        If template similarity between two units is below this threshold, the cross-correlogram is not displayed
+        If template similarity between two units is below this threshold, the cross-correlogram is not displayed.
+        For auto-correlograms plot, this is automatically set to None.
     window_ms : float, default: 100.0
         Window for CCGs in ms. If correlograms are already computed (e.g. with SortingAnalyzer),
         this argument is ignored
5 changes: 4 additions & 1 deletion src/spikeinterface/widgets/traces.py
@@ -52,6 +52,8 @@ class TracesWidget(BaseWidget):
         If dict, keys should be the same as recording keys
     scale : float, default: 1
         Scale factor for the traces
+    vspacing_factor : float, default: 1.5
+        Vertical spacing between channels as a multiple of maximum channel amplitude
     with_colorbar : bool, default: True
         When mode is "map", a colorbar is added
     tile_size : int, default: 1500
@@ -82,6 +84,7 @@ def __init__(
         tile_size=1500,
         seconds_per_row=0.2,
         scale=1,
+        vspacing_factor=1.5,
         with_colorbar=True,
         add_legend=True,
         backend=None,
@@ -168,7 +171,7 @@ def __init__(
         traces0 = list_traces[0]
         mean_channel_std = np.mean(np.std(traces0, axis=0))
         max_channel_amp = np.max(np.max(np.abs(traces0), axis=0))
-        vspacing = max_channel_amp * 1.5
+        vspacing = max_channel_amp * vspacing_factor
 
         if rec0.get_channel_groups() is None:
             color_groups = False
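
A sketch of the new knob, which simply replaces the hard-coded 1.5 multiplier above:

    import spikeinterface.widgets as sw

    # Larger values spread the channel traces further apart vertically.
    sw.plot_traces(recording, mode="line", vspacing_factor=3.0)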