Merge branch 'main' into report_without_waveforms
chrishalcrow authored Nov 4, 2024
2 parents 2faed13 + c4d2eaa commit d337001
Showing 25 changed files with 853 additions and 349 deletions.
2 changes: 1 addition & 1 deletion doc/modules/core.rst
@@ -385,7 +385,7 @@ and merging unit groups.
sorting_analyzer_select = sorting_analyzer.select_units(unit_ids=[0, 1, 2, 3])
sorting_analyzer_remove = sorting_analyzer.remove_units(remove_unit_ids=[0])
sorting_analyzer_merge = sorting_analyzer.merge_units([0, 1], [2, 3])
sorting_analyzer_merge = sorting_analyzer.merge_units([[0, 1], [2, 3]])
All computed extensions will be automatically propagated or merged when curating. Please refer to the
:ref:`modules/curation:Curation module` documentation for more information.
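For comparison, the new ``merge_units`` call takes a list of unit groups, each of which is merged into a single new unit. A minimal sketch (assuming a ``sorting_analyzer`` with units ``0``-``3``; the keyword name is illustrative):

.. code-block:: python

    # merge units 0 and 1 into one new unit, and units 2 and 3 into another
    sorting_analyzer_merge = sorting_analyzer.merge_units(merge_unit_groups=[[0, 1], [2, 3]])

    # a single merge group is still passed as a list of lists
    sorting_analyzer_single_merge = sorting_analyzer.merge_units([[0, 1]])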
11 changes: 8 additions & 3 deletions doc/modules/curation.rst
@@ -88,7 +88,7 @@ The ``censored_period_ms`` parameter is the time window in milliseconds to consi
The :py:func:`~spikeinterface.curation.remove_redundant_units` function removes
redundant units from the sorting output. Redundant units are units that share over
a certain percentage of spikes, by default 80%.
The function can acto both on a ``BaseSorting`` or a ``SortingAnalyzer`` object.
The function can act both on a ``BaseSorting`` or a ``SortingAnalyzer`` object.

.. code-block:: python
@@ -102,13 +102,18 @@ The function can acto both on a ``BaseSorting`` or a ``SortingAnalyzer`` object.
)
# remove redundant units from SortingAnalyzer object
clean_sorting_analyzer = remove_redundant_units(
# note this returns a cleaned sorting
clean_sorting = remove_redundant_units(
sorting_analyzer,
duplicate_threshold=0.9,
remove_strategy="min_shift"
)
# in order to have a SortingAnalyzer with only the non-redundant units one must
# select the desired units, remembering to give a format and folder if one wants
# a persistent SortingAnalyzer.
clean_sorting_analyzer = sorting_analyzer.select_units(clean_sorting.unit_ids)
We recommend usinf the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps
We recommend using the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps
the unit (among the redundant ones) with the best template alignment.
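Putting the pieces above together, a minimal end-to-end sketch (the ``format`` and ``folder`` values are placeholders, and ``sorting_analyzer`` is assumed to already have templates computed):

.. code-block:: python

    from spikeinterface.curation import remove_redundant_units

    # returns a cleaned BaseSorting containing only the non-redundant units
    clean_sorting = remove_redundant_units(
        sorting_analyzer,
        duplicate_threshold=0.9,
        remove_strategy="min_shift",
    )

    # propagate the selection back to the analyzer; give format/folder to persist it
    clean_sorting_analyzer = sorting_analyzer.select_units(
        clean_sorting.unit_ids,
        format="binary_folder",
        folder="clean_analyzer",
    )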


14 changes: 12 additions & 2 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -691,12 +691,13 @@ class ComputeNoiseLevels(AnalyzerExtension):
need_recording = True
use_nodepipeline = False
need_job_kwargs = False
need_backward_compatibility_on_load = True

def __init__(self, sorting_analyzer):
AnalyzerExtension.__init__(self, sorting_analyzer)

def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, seed=None):
params = dict(num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, seed=seed)
def _set_params(self, **noise_level_params):
params = noise_level_params.copy()
return params

def _select_extension_data(self, unit_ids):
@@ -717,6 +718,15 @@ def _run(self, verbose=False):
def _get_data(self):
return self.data["noise_levels"]

def _handle_backward_compatibility_on_load(self):
# The old parameters used to be params=dict(num_chunks_per_segment=20, chunk_size=10000, seed=None)
# now they are handled more explicitly using random_slices_kwargs=dict()
for key in ("num_chunks_per_segment", "chunk_size", "seed"):
if key in self.params:
if "random_slices_kwargs" not in self.params:
self.params["random_slices_kwargs"] = dict()
self.params["random_slices_kwargs"][key] = self.params.pop(key)


register_result_extension(ComputeNoiseLevels)
compute_noise_levels = ComputeNoiseLevels.function_factory()
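To illustrate what the backward-compatibility hook above does, here is a standalone sketch of the same remapping applied to an old-style params dict (the dict contents are just example values):

.. code-block:: python

    def remap_old_noise_level_params(params: dict) -> dict:
        # fold the legacy top-level keys into the new random_slices_kwargs sub-dict
        for key in ("num_chunks_per_segment", "chunk_size", "seed"):
            if key in params:
                params.setdefault("random_slices_kwargs", {})[key] = params.pop(key)
        return params

    old_params = {"num_chunks_per_segment": 20, "chunk_size": 10000, "seed": None}
    new_params = remap_old_noise_level_params(old_params)
    # {'random_slices_kwargs': {'num_chunks_per_segment': 20, 'chunk_size': 10000, 'seed': None}}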
6 changes: 4 additions & 2 deletions src/spikeinterface/core/baserecordingsnippets.py
@@ -172,8 +172,10 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False
number_of_device_channel_indices = np.max(list(device_channel_indices) + [0])
if number_of_device_channel_indices >= self.get_num_channels():
error_msg = (
f"The given Probe have 'device_channel_indices' that do not match channel count \n"
f"{number_of_device_channel_indices} vs {self.get_num_channels()} \n"
f"The given Probe either has 'device_channel_indices' that does not match channel count \n"
f"{len(device_channel_indices)} vs {self.get_num_channels()} \n"
f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n"
f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n"
f"device_channel_indices are the following: {device_channel_indices} \n"
f"recording channels are the following: {self.get_channel_ids()} \n"
)
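For reference, a sketch of a probe whose ``device_channel_indices`` satisfy the check above, using generator utilities as stand-ins for a real recording and probe layout:

.. code-block:: python

    from probeinterface import generate_linear_probe
    from spikeinterface.core import generate_recording

    recording = generate_recording(num_channels=4, durations=[10.0])
    probe = generate_linear_probe(num_elec=4)

    # 0-indexed: for 4 channels the valid device indices are 0..3
    probe.set_device_channel_indices([0, 1, 2, 3])
    recording_with_probe = recording.set_probe(probe)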
51 changes: 29 additions & 22 deletions src/spikeinterface/core/job_tools.py
@@ -149,12 +149,12 @@ def divide_segment_into_chunks(num_frames, chunk_size):


def divide_recording_into_chunks(recording, chunk_size):
all_chunks = []
recording_slices = []
for segment_index in range(recording.get_num_segments()):
num_frames = recording.get_num_samples(segment_index)
chunks = divide_segment_into_chunks(num_frames, chunk_size)
all_chunks.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks])
return all_chunks
recording_slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks])
return recording_slices


def ensure_n_jobs(recording, n_jobs=1):
@@ -185,6 +185,22 @@ def ensure_n_jobs(recording, n_jobs=1):
return n_jobs


def chunk_duration_to_chunk_size(chunk_duration, recording):
if isinstance(chunk_duration, float):
chunk_size = int(chunk_duration * recording.get_sampling_frequency())
elif isinstance(chunk_duration, str):
if chunk_duration.endswith("ms"):
chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0
elif chunk_duration.endswith("s"):
chunk_duration = float(chunk_duration.replace("s", ""))
else:
raise ValueError("chunk_duration must ends with s or ms")
chunk_size = int(chunk_duration * recording.get_sampling_frequency())
else:
raise ValueError("chunk_duration must be str or float")
return chunk_size
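As a quick usage sketch of the helper above (a hypothetical 30 kHz generated recording is assumed):

.. code-block:: python

    from spikeinterface.core import generate_recording
    from spikeinterface.core.job_tools import chunk_duration_to_chunk_size

    recording = generate_recording(num_channels=4, sampling_frequency=30000.0, durations=[10.0])

    chunk_duration_to_chunk_size("500ms", recording)  # 0.5 s * 30000 Hz = 15000 samples
    chunk_duration_to_chunk_size("1s", recording)     # 30000 samples
    chunk_duration_to_chunk_size(2.0, recording)      # floats are treated as seconds -> 60000 samples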


def ensure_chunk_size(
recording, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs
):
@@ -231,18 +247,7 @@
num_channels = recording.get_num_channels()
chunk_size = int(total_memory / (num_channels * n_bytes * n_jobs))
elif chunk_duration is not None:
if isinstance(chunk_duration, float):
chunk_size = int(chunk_duration * recording.get_sampling_frequency())
elif isinstance(chunk_duration, str):
if chunk_duration.endswith("ms"):
chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0
elif chunk_duration.endswith("s"):
chunk_duration = float(chunk_duration.replace("s", ""))
else:
raise ValueError("chunk_duration must ends with s or ms")
chunk_size = int(chunk_duration * recording.get_sampling_frequency())
else:
raise ValueError("chunk_duration must be str or float")
chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording)
else:
# Edge case to define single chunk per segment for n_jobs=1.
# All chunking parameters equal None mean single chunk per segment
@@ -382,11 +387,13 @@ def __init__(
f"chunk_duration={chunk_duration_str}",
)

def run(self):
def run(self, recording_slices=None):
"""
Runs the defined jobs.
"""
all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size)

if recording_slices is None:
recording_slices = divide_recording_into_chunks(self.recording, self.chunk_size)

if self.handle_returns:
returns = []
@@ -395,17 +402,17 @@ def run(self):

if self.n_jobs == 1:
if self.progress_bar:
all_chunks = tqdm(all_chunks, ascii=True, desc=self.job_name)
recording_slices = tqdm(recording_slices, ascii=True, desc=self.job_name)

worker_ctx = self.init_func(*self.init_args)
for segment_index, frame_start, frame_stop in all_chunks:
for segment_index, frame_start, frame_stop in recording_slices:
res = self.func(segment_index, frame_start, frame_stop, worker_ctx)
if self.handle_returns:
returns.append(res)
if self.gather_func is not None:
self.gather_func(res)
else:
n_jobs = min(self.n_jobs, len(all_chunks))
n_jobs = min(self.n_jobs, len(recording_slices))

# parallel
with ProcessPoolExecutor(
@@ -414,10 +421,10 @@
mp_context=mp.get_context(self.mp_context),
initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process),
) as executor:
results = executor.map(function_wrapper, all_chunks)
results = executor.map(function_wrapper, recording_slices)

if self.progress_bar:
results = tqdm(results, desc=self.job_name, total=len(all_chunks))
results = tqdm(results, desc=self.job_name, total=len(recording_slices))

for res in results:
if self.handle_returns:
7 changes: 6 additions & 1 deletion src/spikeinterface/core/node_pipeline.py
@@ -489,6 +489,7 @@ def run_node_pipeline(
names=None,
verbose=False,
skip_after_n_peaks=None,
recording_slices=None,
):
"""
Machinery to compute in parallel operations on peaks and traces.
@@ -540,6 +541,10 @@
skip_after_n_peaks : None | int
Skip the computation after n_peaks.
This is not exact because internally this skip is done per worker on average.
recording_slices : None | list[tuple]
Optionally give a list of slices to run the pipeline only on some chunks of the recording.
It must be a list of (segment_index, frame_start, frame_stop).
If None (default), the function iterates over the entire duration of the recording.
Returns
-------
@@ -578,7 +583,7 @@
**job_kwargs,
)

processor.run()
processor.run(recording_slices=recording_slices)

outs = gather_func.finalize_buffers(squeeze_output=squeeze_output)
return outs
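A sketch of how the new ``recording_slices`` argument might be used, restricting the pipeline to the first 10 chunks of segment 0 (``recording`` and ``nodes`` are assumed to be defined as elsewhere in this module, and the chunk size is arbitrary):

.. code-block:: python

    from spikeinterface.core.job_tools import divide_recording_into_chunks
    from spikeinterface.core.node_pipeline import run_node_pipeline

    # list of (segment_index, frame_start, frame_stop) tuples
    all_slices = divide_recording_into_chunks(recording, chunk_size=30000)
    first_slices = [s for s in all_slices if s[0] == 0][:10]

    outs = run_node_pipeline(
        recording,
        nodes,
        job_kwargs=dict(n_jobs=1, progress_bar=True),
        recording_slices=first_slices,
    )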