Skip to content

Commit

Permalink
Merge branch 'SpikeInterface:main' into meta_merging
Browse files Browse the repository at this point in the history
  • Loading branch information
yger authored Oct 28, 2024
2 parents 08a599b + 0df1160 commit b10ca04
Show file tree
Hide file tree
Showing 15 changed files with 374 additions and 118 deletions.
2 changes: 1 addition & 1 deletion doc/modules/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ and merging unit groups.
sorting_analyzer_select = sorting_analyzer.select_units(unit_ids=[0, 1, 2, 3])
sorting_analyzer_remove = sorting_analyzer.remove_units(remove_unit_ids=[0])
sorting_analyzer_merge = sorting_analyzer.merge_units([0, 1], [2, 3])
sorting_analyzer_merge = sorting_analyzer.merge_units([[0, 1], [2, 3]])
All computed extensions will be automatically propagated or merged when curating. Please refer to the
:ref:`modules/curation:Curation module` documentation for more information.
Expand Down
11 changes: 8 additions & 3 deletions doc/modules/curation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ The ``censored_period_ms`` parameter is the time window in milliseconds to consi
The :py:func:`~spikeinterface.curation.remove_redundand_units` function removes
redundant units from the sorting output. Redundant units are units that share over
a certain percentage of spikes, by default 80%.
The function can acto both on a ``BaseSorting`` or a ``SortingAnalyzer`` object.
The function can act both on a ``BaseSorting`` or a ``SortingAnalyzer`` object.

.. code-block:: python
Expand All @@ -102,13 +102,18 @@ The function can acto both on a ``BaseSorting`` or a ``SortingAnalyzer`` object.
)
# remove redundant units from SortingAnalyzer object
clean_sorting_analyzer = remove_redundant_units(
# note this returns a cleaned sorting
clean_sorting = remove_redundant_units(
sorting_analyzer,
duplicate_threshold=0.9,
remove_strategy="min_shift"
)
# in order to have a SortingAnalyer with only the non-redundant units one must
# select the designed units remembering to give format and folder if one wants
# a persistent SortingAnalyzer.
clean_sorting_analyzer = sorting_analyzer.select_units(clean_sorting.unit_ids)
We recommend usinf the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps
We recommend using the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps
the unit (among the redundant ones), with a better template alignment.


Expand Down
14 changes: 12 additions & 2 deletions src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,12 +691,13 @@ class ComputeNoiseLevels(AnalyzerExtension):
need_recording = True
use_nodepipeline = False
need_job_kwargs = False
need_backward_compatibility_on_load = True

def __init__(self, sorting_analyzer):
AnalyzerExtension.__init__(self, sorting_analyzer)

def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, seed=None):
params = dict(num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, seed=seed)
def _set_params(self, **noise_level_params):
params = noise_level_params.copy()
return params

def _select_extension_data(self, unit_ids):
Expand All @@ -717,6 +718,15 @@ def _run(self, verbose=False):
def _get_data(self):
return self.data["noise_levels"]

def _handle_backward_compatibility_on_load(self):
# The old parameters used to be params=dict(num_chunks_per_segment=20, chunk_size=10000, seed=None)
# now it is handle more explicitly using random_slices_kwargs=dict()
for key in ("num_chunks_per_segment", "chunk_size", "seed"):
if key in self.params:
if "random_slices_kwargs" not in self.params:
self.params["random_slices_kwargs"] = dict()
self.params["random_slices_kwargs"][key] = self.params.pop(key)


register_result_extension(ComputeNoiseLevels)
compute_noise_levels = ComputeNoiseLevels.function_factory()
6 changes: 4 additions & 2 deletions src/spikeinterface/core/baserecordingsnippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,10 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False
number_of_device_channel_indices = np.max(list(device_channel_indices) + [0])
if number_of_device_channel_indices >= self.get_num_channels():
error_msg = (
f"The given Probe have 'device_channel_indices' that do not match channel count \n"
f"{number_of_device_channel_indices} vs {self.get_num_channels()} \n"
f"The given Probe either has 'device_channel_indices' that does not match channel count \n"
f"{len(device_channel_indices)} vs {self.get_num_channels()} \n"
f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n"
f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n"
f"device_channel_indices are the following: {device_channel_indices} \n"
f"recording channels are the following: {self.get_channel_ids()} \n"
)
Expand Down
35 changes: 21 additions & 14 deletions src/spikeinterface/core/job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,22 @@ def ensure_n_jobs(recording, n_jobs=1):
return n_jobs


def chunk_duration_to_chunk_size(chunk_duration, recording):
if isinstance(chunk_duration, float):
chunk_size = int(chunk_duration * recording.get_sampling_frequency())
elif isinstance(chunk_duration, str):
if chunk_duration.endswith("ms"):
chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0
elif chunk_duration.endswith("s"):
chunk_duration = float(chunk_duration.replace("s", ""))
else:
raise ValueError("chunk_duration must ends with s or ms")
chunk_size = int(chunk_duration * recording.get_sampling_frequency())
else:
raise ValueError("chunk_duration must be str or float")
return chunk_size


def ensure_chunk_size(
recording, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs
):
Expand Down Expand Up @@ -231,18 +247,7 @@ def ensure_chunk_size(
num_channels = recording.get_num_channels()
chunk_size = int(total_memory / (num_channels * n_bytes * n_jobs))
elif chunk_duration is not None:
if isinstance(chunk_duration, float):
chunk_size = int(chunk_duration * recording.get_sampling_frequency())
elif isinstance(chunk_duration, str):
if chunk_duration.endswith("ms"):
chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0
elif chunk_duration.endswith("s"):
chunk_duration = float(chunk_duration.replace("s", ""))
else:
raise ValueError("chunk_duration must ends with s or ms")
chunk_size = int(chunk_duration * recording.get_sampling_frequency())
else:
raise ValueError("chunk_duration must be str or float")
chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording)
else:
# Edge case to define single chunk per segment for n_jobs=1.
# All chunking parameters equal None mean single chunk per segment
Expand Down Expand Up @@ -382,11 +387,13 @@ def __init__(
f"chunk_duration={chunk_duration_str}",
)

def run(self):
def run(self, all_chunks=None):
"""
Runs the defined jobs.
"""
all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size)

if all_chunks is None:
all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size)

if self.handle_returns:
returns = []
Expand Down
Loading

0 comments on commit b10ca04

Please sign in to comment.