Skip to content

Commit

Permalink
Merge pull request #3359 from samuelgarcia/improve_noise_level_machinery
Browse files Browse the repository at this point in the history
Improve noise level machinery
  • Loading branch information
samuelgarcia authored Oct 25, 2024
2 parents 4a1a45a + bac57fe commit 0df1160
Show file tree
Hide file tree
Showing 9 changed files with 295 additions and 107 deletions.
14 changes: 12 additions & 2 deletions src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,12 +691,13 @@ class ComputeNoiseLevels(AnalyzerExtension):
need_recording = True
use_nodepipeline = False
need_job_kwargs = False
need_backward_compatibility_on_load = True

def __init__(self, sorting_analyzer):
AnalyzerExtension.__init__(self, sorting_analyzer)

def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, seed=None):
params = dict(num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, seed=seed)
def _set_params(self, **noise_level_params):
params = noise_level_params.copy()
return params

def _select_extension_data(self, unit_ids):
Expand All @@ -717,6 +718,15 @@ def _run(self, verbose=False):
def _get_data(self):
return self.data["noise_levels"]

def _handle_backward_compatibility_on_load(self):
# The old parameters used to be params=dict(num_chunks_per_segment=20, chunk_size=10000, seed=None)
# now it is handle more explicitly using random_slices_kwargs=dict()
for key in ("num_chunks_per_segment", "chunk_size", "seed"):
if key in self.params:
if "random_slices_kwargs" not in self.params:
self.params["random_slices_kwargs"] = dict()
self.params["random_slices_kwargs"][key] = self.params.pop(key)


register_result_extension(ComputeNoiseLevels)
compute_noise_levels = ComputeNoiseLevels.function_factory()
35 changes: 21 additions & 14 deletions src/spikeinterface/core/job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,22 @@ def ensure_n_jobs(recording, n_jobs=1):
return n_jobs


def chunk_duration_to_chunk_size(chunk_duration, recording):
if isinstance(chunk_duration, float):
chunk_size = int(chunk_duration * recording.get_sampling_frequency())
elif isinstance(chunk_duration, str):
if chunk_duration.endswith("ms"):
chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0
elif chunk_duration.endswith("s"):
chunk_duration = float(chunk_duration.replace("s", ""))
else:
raise ValueError("chunk_duration must ends with s or ms")
chunk_size = int(chunk_duration * recording.get_sampling_frequency())
else:
raise ValueError("chunk_duration must be str or float")
return chunk_size


def ensure_chunk_size(
recording, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs
):
Expand Down Expand Up @@ -231,18 +247,7 @@ def ensure_chunk_size(
num_channels = recording.get_num_channels()
chunk_size = int(total_memory / (num_channels * n_bytes * n_jobs))
elif chunk_duration is not None:
if isinstance(chunk_duration, float):
chunk_size = int(chunk_duration * recording.get_sampling_frequency())
elif isinstance(chunk_duration, str):
if chunk_duration.endswith("ms"):
chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0
elif chunk_duration.endswith("s"):
chunk_duration = float(chunk_duration.replace("s", ""))
else:
raise ValueError("chunk_duration must ends with s or ms")
chunk_size = int(chunk_duration * recording.get_sampling_frequency())
else:
raise ValueError("chunk_duration must be str or float")
chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording)
else:
# Edge case to define single chunk per segment for n_jobs=1.
# All chunking parameters equal None mean single chunk per segment
Expand Down Expand Up @@ -382,11 +387,13 @@ def __init__(
f"chunk_duration={chunk_duration_str}",
)

def run(self):
def run(self, all_chunks=None):
"""
Runs the defined jobs.
"""
all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size)

if all_chunks is None:
all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size)

if self.handle_returns:
returns = []
Expand Down
Loading

0 comments on commit 0df1160

Please sign in to comment.