Merge pull request #2907 from samuelgarcia/fix_verbose

Fix more verbosity propagation

samuelgarcia authored May 25, 2024
2 parents 73c8f07 + ee4529f, commit f478f26
Showing 15 changed files with 64 additions and 36 deletions.
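In user-facing terms, the change means that a single `verbose` flag passed to `SortingAnalyzer.compute()` is now forwarded all the way down to the chunked job machinery instead of being dropped along the way. A minimal sketch of the intended usage, assuming a toy recording/sorting pair built with SpikeInterface's own generators (not part of this PR):

```python
from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording

# Toy data for illustration only; any (recording, sorting) pair would do.
recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5, seed=42)
analyzer = create_sorting_analyzer(sorting, recording, sparse=False)

# The verbose flag given here now reaches the ChunkRecordingExecutor /
# run_node_pipeline workers spawned by each extension.
analyzer.compute("random_spikes", verbose=False)
analyzer.compute("waveforms", verbose=True, n_jobs=2, chunk_duration="1s")
```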
13 changes: 7 additions & 6 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -49,9 +49,8 @@ class ComputeRandomSpikes(AnalyzerExtension):
use_nodepipeline = False
need_job_kwargs = False

def _run(
self,
):
def _run(self, verbose=False):

self.data["random_spikes_indices"] = random_spikes_selection(
self.sorting_analyzer.sorting,
num_samples=self.sorting_analyzer.rec_attributes["num_samples"],
@@ -145,7 +144,7 @@ def nbefore(self):
def nafter(self):
return int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0)

def _run(self, **job_kwargs):
def _run(self, verbose=False, **job_kwargs):
self.data.clear()

recording = self.sorting_analyzer.recording
@@ -183,6 +182,7 @@ def _run(self, **job_kwargs):
sparsity_mask=sparsity_mask,
copy=copy,
job_name="compute_waveforms",
verbose=verbose,
**job_kwargs,
)

@@ -311,7 +311,7 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=N
)
return params

def _run(self, **job_kwargs):
def _run(self, verbose=False, **job_kwargs):
self.data.clear()

if self.sorting_analyzer.has_extension("waveforms"):
@@ -339,6 +339,7 @@ def _run(self, **job_kwargs):
self.nafter,
return_scaled=return_scaled,
return_std=return_std,
verbose=verbose,
**job_kwargs,
)

@@ -581,7 +582,7 @@ def _select_extension_data(self, unit_ids):
# this do not depend on units
return self.data

def _run(self):
def _run(self, verbose=False):
self.data["noise_levels"] = get_noise_levels(
self.sorting_analyzer.recording, return_scaled=self.sorting_analyzer.return_scaled, **self.params
)
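All core extensions now accept `verbose` in `_run`, even those that do no chunked work (such as `ComputeNoiseLevels` above), so the analyzer can pass the flag unconditionally. A hedged sketch of what a custom extension would look like under this convention; the class, its name, and its data are hypothetical and not part of this PR:

```python
import numpy as np
from spikeinterface.core.sortinganalyzer import AnalyzerExtension, register_result_extension


class ComputeMyFeature(AnalyzerExtension):
    """Hypothetical extension illustrating the updated _run signature."""

    extension_name = "my_feature"
    depend_on = []
    need_recording = False
    use_nodepipeline = False
    need_job_kwargs = False

    def _set_params(self, scale=1.0):
        return dict(scale=scale)

    def _select_extension_data(self, unit_ids):
        # this does not depend on units/channels in this toy example
        return self.data

    def _run(self, verbose=False):
        # verbose must be accepted even when no chunked job is launched,
        # because SortingAnalyzer now always forwards it.
        if verbose:
            print("computing my_feature")
        num_units = self.sorting_analyzer.sorting.get_num_units()
        self.data["my_feature"] = np.full(num_units, self.params["scale"])

    def _get_data(self):
        return self.data["my_feature"]


register_result_extension(ComputeMyFeature)
```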
2 changes: 2 additions & 0 deletions src/spikeinterface/core/node_pipeline.py
@@ -473,6 +473,7 @@ def run_node_pipeline(
squeeze_output=True,
folder=None,
names=None,
verbose=False,
):
"""
Common function to run pipeline with peak detector or already detected peak.
@@ -499,6 +500,7 @@
init_args,
gather_func=gather_func,
job_name=job_name,
verbose=verbose,
**job_kwargs,
)

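`run_node_pipeline` gains the same explicit keyword, which it simply hands to its `ChunkRecordingExecutor`. A sketch of the call shape only; `recording` and `nodes` (a peak/spike retriever followed by feature nodes) are assumed to be built elsewhere:

```python
from spikeinterface.core.node_pipeline import run_node_pipeline

results = run_node_pipeline(
    recording,
    nodes,
    job_kwargs=dict(n_jobs=4, chunk_duration="1s", progress_bar=True),
    job_name="my_pipeline",
    gather_mode="memory",
    squeeze_output=False,
    verbose=True,  # forwarded to the underlying ChunkRecordingExecutor
)
```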
28 changes: 16 additions & 12 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -835,7 +835,7 @@ def get_num_units(self) -> int:
return self.sorting.get_num_units()

## extensions zone
def compute(self, input, save=True, extension_params=None, **kwargs):
def compute(self, input, save=True, extension_params=None, verbose=False, **kwargs):
"""
Compute one extension or several extensions.
Internally calls compute_one_extension() or compute_several_extensions() depending on the input type.
@@ -883,11 +883,11 @@ def compute(self, input, save=True, extension_params=None, **kwargs):
)
"""
if isinstance(input, str):
return self.compute_one_extension(extension_name=input, save=save, **kwargs)
return self.compute_one_extension(extension_name=input, save=save, verbose=verbose, **kwargs)
elif isinstance(input, dict):
params_, job_kwargs = split_job_kwargs(kwargs)
assert len(params_) == 0, "Too many arguments for SortingAnalyzer.compute_several_extensions()"
self.compute_several_extensions(extensions=input, save=save, **job_kwargs)
self.compute_several_extensions(extensions=input, save=save, verbose=verbose, **job_kwargs)
elif isinstance(input, list):
params_, job_kwargs = split_job_kwargs(kwargs)
assert len(params_) == 0, "Too many arguments for SortingAnalyzer.compute_several_extensions()"
@@ -898,11 +898,11 @@ def compute(self, input, save=True, extension_params=None, **kwargs):
ext_name in input
), f"SortingAnalyzer.compute(): Parameters specified for {ext_name}, which is not in the specified {input}"
extensions[ext_name] = ext_params
self.compute_several_extensions(extensions=extensions, save=save, **job_kwargs)
self.compute_several_extensions(extensions=extensions, save=save, verbose=verbose, **job_kwargs)
else:
raise ValueError("SortingAnalyzer.compute() need str, dict or list")

def compute_one_extension(self, extension_name, save=True, **kwargs):
def compute_one_extension(self, extension_name, save=True, verbose=False, **kwargs):
"""
Compute one extension.
@@ -925,7 +925,7 @@ def compute_one_extension(self, extension_name, save=True, **kwargs):
Returns
-------
result_extension: AnalyzerExtension
Return the extension instance.
Return the extension instance
Examples
--------
@@ -961,13 +961,16 @@ def compute_one_extension(self, extension_name, save=True, **kwargs):

extension_instance = extension_class(self)
extension_instance.set_params(save=save, **params)
extension_instance.run(save=save, **job_kwargs)
if extension_class.need_job_kwargs:
extension_instance.run(save=save, verbose=verbose, **job_kwargs)
else:
extension_instance.run(save=save, verbose=verbose)

self.extensions[extension_name] = extension_instance

return extension_instance

def compute_several_extensions(self, extensions, save=True, **job_kwargs):
def compute_several_extensions(self, extensions, save=True, verbose=False, **job_kwargs):
"""
Compute several extensions
@@ -1021,9 +1024,9 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
for extension_name, extension_params in extensions_without_pipeline.items():
extension_class = get_extension_class(extension_name)
if extension_class.need_job_kwargs:
self.compute_one_extension(extension_name, save=save, **extension_params, **job_kwargs)
self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params, **job_kwargs)
else:
self.compute_one_extension(extension_name, save=save, **extension_params)
self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params)
# then extensions with pipeline
if len(extensions_with_pipeline) > 0:
all_nodes = []
@@ -1053,6 +1056,7 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
job_name=job_name,
gather_mode="memory",
squeeze_output=False,
verbose=verbose,
)

for r, result in enumerate(results):
@@ -1071,9 +1075,9 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
for extension_name, extension_params in extensions_post_pipeline.items():
extension_class = get_extension_class(extension_name)
if extension_class.need_job_kwargs:
self.compute_one_extension(extension_name, save=save, **extension_params, **job_kwargs)
self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params, **job_kwargs)
else:
self.compute_one_extension(extension_name, save=save, **extension_params)
self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params)

def get_saved_extension_names(self):
"""
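Since `compute()` now threads `verbose` through both `compute_one_extension()` and `compute_several_extensions()`, a single flag covers the dict and list forms as well, including the node-pipeline pass and the `need_job_kwargs` split shown above. A sketch, reusing the hypothetical `analyzer` from the first example:

```python
analyzer.compute(
    {
        "random_spikes": dict(max_spikes_per_unit=200),
        "waveforms": dict(ms_before=1.0, ms_after=2.0),
        "templates": dict(operators=["average", "std"]),
        "noise_levels": dict(),
    },
    verbose=True,      # propagated to every extension, pipelined or not
    n_jobs=4,
    progress_bar=True,
)
```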
15 changes: 12 additions & 3 deletions src/spikeinterface/core/waveform_tools.py
@@ -221,6 +221,7 @@ def distribute_waveforms_to_buffers(
mode="memmap",
sparsity_mask=None,
job_name=None,
verbose=False,
**job_kwargs,
):
"""
@@ -281,7 +282,9 @@
)
if job_name is None:
job_name = f"extract waveforms {mode} multi buffer"
processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs)
processor = ChunkRecordingExecutor(
recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs
)
processor.run()


@@ -410,6 +413,7 @@ def extract_waveforms_to_single_buffer(
sparsity_mask=None,
copy=True,
job_name=None,
verbose=False,
**job_kwargs,
):
"""
@@ -523,7 +527,9 @@
if job_name is None:
job_name = f"extract waveforms {mode} mono buffer"

processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs)
processor = ChunkRecordingExecutor(
recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs
)
processor.run()

if mode == "memmap":
@@ -783,6 +789,7 @@ def estimate_templates_with_accumulator(
return_scaled: bool = True,
job_name=None,
return_std: bool = False,
verbose: bool = False,
**job_kwargs,
):
"""
@@ -861,7 +868,9 @@

if job_name is None:
job_name = "estimate_templates_with_accumulator"
processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs)
processor = ChunkRecordingExecutor(
recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs
)
processor.run()

# average
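The waveform helpers all follow one pattern: `verbose` is an explicit keyword handed to `ChunkRecordingExecutor` directly, rather than being expected to arrive inside `**job_kwargs`. A minimal sketch of that pattern; the helper name and its worker arguments are hypothetical, not part of the PR:

```python
from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs


def run_chunked_job(recording, func, init_func, init_args, job_name, verbose=False, **job_kwargs):
    """Hypothetical helper mirroring distribute_waveforms_to_buffers and friends."""
    job_kwargs = fix_job_kwargs(job_kwargs)
    processor = ChunkRecordingExecutor(
        recording,
        func,
        init_func,
        init_args,
        job_name=job_name,
        verbose=verbose,  # explicit, not fished out of job_kwargs
        **job_kwargs,
    )
    return processor.run()
```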
3 changes: 2 additions & 1 deletion src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -181,7 +181,7 @@ def _get_pipeline_nodes(self):
nodes = [spike_retriever_node, amplitude_scalings_node]
return nodes

def _run(self, **job_kwargs):
def _run(self, verbose=False, **job_kwargs):
job_kwargs = fix_job_kwargs(job_kwargs)
nodes = self.get_pipeline_nodes()
amp_scalings, collision_mask = run_node_pipeline(
@@ -190,6 +190,7 @@ def _run(self, **job_kwargs):
job_kwargs=job_kwargs,
job_name="amplitude_scalings",
gather_mode="memory",
verbose=verbose,
)
self.data["amplitude_scalings"] = amp_scalings
if self.params["handle_collisions"]:
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/correlograms.py
@@ -70,7 +70,7 @@ def _select_extension_data(self, unit_ids):
new_data = dict(ccgs=new_ccgs, bins=new_bins)
return new_data

def _run(self):
def _run(self, verbose=False):
ccgs, bins = compute_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params)
self.data["ccgs"] = ccgs
self.data["bins"] = bins
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/isi.py
@@ -56,7 +56,7 @@ def _select_extension_data(self, unit_ids):
new_extension_data = dict(isi_histograms=new_isi_hists, bins=new_bins)
return new_extension_data

def _run(self):
def _run(self, verbose=False):
isi_histograms, bins = _compute_isi_histograms(self.sorting_analyzer.sorting, **self.params)
self.data["isi_histograms"] = isi_histograms
self.data["bins"] = bins
8 changes: 5 additions & 3 deletions src/spikeinterface/postprocessing/principal_component.py
@@ -256,7 +256,7 @@ def project_new(self, new_spikes, new_waveforms, progress_bar=True):
new_projections = self._transform_waveforms(new_spikes, new_waveforms, pca_model, progress_bar=progress_bar)
return new_projections

def _run(self, **job_kwargs):
def _run(self, verbose=False, **job_kwargs):
"""
Compute the PCs on waveforms extracted by ComputeWaveforms.
Projections are computed only on the waveforms sampled by the SortingAnalyzer.
@@ -295,7 +295,7 @@ def _run(self, **job_kwargs):
def _get_data(self):
return self.data["pca_projection"]

def run_for_all_spikes(self, file_path=None, **job_kwargs):
def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):
"""
Project all spikes from the sorting on the PCA model.
This is a long computation because waveforms need to be extracted from each spike.
@@ -359,7 +359,9 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs):
unit_channels,
pca_model,
)
processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name="extract PCs", **job_kwargs)
processor = ChunkRecordingExecutor(
recording, func, init_func, init_args, job_name="extract PCs", verbose=verbose, **job_kwargs
)
processor.run()

def _fit_by_channel_local(self, n_jobs, progress_bar):
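For the PCA extension the flag matters in two places: the projection `_run` and the heavier `run_for_all_spikes`, which extracts waveforms for every spike. A sketch, assuming the extension has already been computed on `analyzer`; the output path is hypothetical:

```python
pca_ext = analyzer.get_extension("principal_components")
# file_path is just an illustrative destination for the all-spike projections
pca_ext.run_for_all_spikes(file_path="all_pcs.npy", verbose=True, n_jobs=4)
```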
3 changes: 2 additions & 1 deletion src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -107,7 +107,7 @@ def _get_pipeline_nodes(self):
nodes = [spike_retriever_node, spike_amplitudes_node]
return nodes

def _run(self, **job_kwargs):
def _run(self, verbose=False, **job_kwargs):
job_kwargs = fix_job_kwargs(job_kwargs)
nodes = self.get_pipeline_nodes()
amps = run_node_pipeline(
@@ -116,6 +116,7 @@ def _run(self, **job_kwargs):
job_kwargs=job_kwargs,
job_name="spike_amplitudes",
gather_mode="memory",
verbose=False,
)
self.data["amplitudes"] = amps

3 changes: 2 additions & 1 deletion src/spikeinterface/postprocessing/spike_locations.py
@@ -120,7 +120,7 @@ def _get_pipeline_nodes(self):
)
return nodes

def _run(self, **job_kwargs):
def _run(self, verbose=False, **job_kwargs):
job_kwargs = fix_job_kwargs(job_kwargs)
nodes = self.get_pipeline_nodes()
spike_locations = run_node_pipeline(
@@ -129,6 +129,7 @@ def _run(self, **job_kwargs):
job_kwargs=job_kwargs,
job_name="spike_locations",
gather_mode="memory",
verbose=verbose,
)
self.data["spike_locations"] = spike_locations

2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/template_metrics.py
@@ -150,7 +150,7 @@ def _select_extension_data(self, unit_ids):
new_metrics = self.data["metrics"].loc[np.array(unit_ids)]
return dict(metrics=new_metrics)

def _run(self):
def _run(self, verbose=False):
import pandas as pd
from scipy.signal import resample_poly

2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/template_similarity.py
@@ -42,7 +42,7 @@ def _select_extension_data(self, unit_ids):
new_similarity = self.data["similarity"][unit_indices][:, unit_indices]
return dict(similarity=new_similarity)

def _run(self):
def _run(self, verbose=False):
templates_array = get_dense_templates_array(
self.sorting_analyzer, return_scaled=self.sorting_analyzer.return_scaled
)
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/unit_localization.py
@@ -64,7 +64,7 @@ def _select_extension_data(self, unit_ids):
new_unit_location = self.data["unit_locations"][unit_inds]
return dict(unit_locations=new_unit_location)

def _run(self):
def _run(self, verbose=False):
method = self.params["method"]
method_kwargs = self.params["method_kwargs"]

13 changes: 11 additions & 2 deletions src/spikeinterface/sortingcomponents/matching/main.py
@@ -7,7 +7,9 @@
from spikeinterface.core import get_chunk_with_margin


def find_spikes_from_templates(recording, method="naive", method_kwargs={}, extra_outputs=False, **job_kwargs):
def find_spikes_from_templates(
recording, method="naive", method_kwargs={}, extra_outputs=False, verbose=False, **job_kwargs
):
"""Find spike from a recording from given templates.
Parameters
@@ -53,7 +55,14 @@ def find_spikes_from_templates(recording, method="naive", method_kwargs={}, extr
init_func = _init_worker_find_spikes
init_args = (recording, method, method_kwargs_seralized)
processor = ChunkRecordingExecutor(
recording, func, init_func, init_args, handle_returns=True, job_name=f"find spikes ({method})", **job_kwargs
recording,
func,
init_func,
init_args,
handle_returns=True,
job_name=f"find spikes ({method})",
verbose=verbose,
**job_kwargs,
)
spikes = processor.run()

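`find_spikes_from_templates` gets the same explicit keyword and forwards it to its executor. A sketch only; `recording` is as before and `templates` is assumed to be a SpikeInterface `Templates` object obtained elsewhere (for example from an analyzer's "templates" extension):

```python
from spikeinterface.sortingcomponents.matching import find_spikes_from_templates

spikes = find_spikes_from_templates(
    recording,
    method="naive",
    method_kwargs={"templates": templates},
    verbose=True,          # now reaches the "find spikes (naive)" job
    n_jobs=4,
    chunk_duration="1s",
    progress_bar=True,
)
```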
2 changes: 0 additions & 2 deletions src/spikeinterface/sortingcomponents/peak_detection.py
@@ -9,10 +9,8 @@
import numpy as np

from spikeinterface.core.job_tools import (
ChunkRecordingExecutor,
_shared_job_kwargs_doc,
split_job_kwargs,
fix_job_kwargs,
)
from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances, get_random_data_chunks

