diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py
index 066194725d..073708f353 100644
--- a/src/spikeinterface/core/analyzer_extension_core.py
+++ b/src/spikeinterface/core/analyzer_extension_core.py
@@ -49,9 +49,8 @@ class ComputeRandomSpikes(AnalyzerExtension):
     use_nodepipeline = False
     need_job_kwargs = False

-    def _run(
-        self,
-    ):
+    def _run(self, verbose=False):
+
         self.data["random_spikes_indices"] = random_spikes_selection(
             self.sorting_analyzer.sorting,
             num_samples=self.sorting_analyzer.rec_attributes["num_samples"],
@@ -145,7 +144,7 @@ def nbefore(self):
     def nafter(self):
         return int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0)

-    def _run(self, **job_kwargs):
+    def _run(self, verbose=False, **job_kwargs):
         self.data.clear()

         recording = self.sorting_analyzer.recording
@@ -183,6 +182,7 @@ def _run(self, **job_kwargs):
             sparsity_mask=sparsity_mask,
             copy=copy,
             job_name="compute_waveforms",
+            verbose=verbose,
             **job_kwargs,
         )

@@ -311,7 +311,7 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=N
         )
         return params

-    def _run(self, **job_kwargs):
+    def _run(self, verbose=False, **job_kwargs):
         self.data.clear()

         if self.sorting_analyzer.has_extension("waveforms"):
@@ -339,6 +339,7 @@ def _run(self, **job_kwargs):
             self.nafter,
             return_scaled=return_scaled,
             return_std=return_std,
+            verbose=verbose,
             **job_kwargs,
         )

@@ -581,7 +582,7 @@ def _select_extension_data(self, unit_ids):
         # this does not depend on units
         return self.data

-    def _run(self):
+    def _run(self, verbose=False):
         self.data["noise_levels"] = get_noise_levels(
             self.sorting_analyzer.recording, return_scaled=self.sorting_analyzer.return_scaled, **self.params
         )
diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index ee6cf5268d..3585b07b23 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -473,6 +473,7 @@ def run_node_pipeline(
     squeeze_output=True,
     folder=None,
     names=None,
+    verbose=False,
 ):
     """
     Common function to run a pipeline with a peak detector or already detected peaks.
@@ -499,6 +500,7 @@ def run_node_pipeline(
         init_args,
         gather_func=gather_func,
         job_name=job_name,
+        verbose=verbose,
         **job_kwargs,
     )
diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index d9fcf44442..53e060262b 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -835,7 +835,7 @@ def get_num_units(self) -> int:
         return self.sorting.get_num_units()

     ## extensions zone
-    def compute(self, input, save=True, extension_params=None, **kwargs):
+    def compute(self, input, save=True, extension_params=None, verbose=False, **kwargs):
         """
         Compute one extension or several extensions.
         Internally calls compute_one_extension() or compute_several_extensions() depending on the input type.
@@ -883,11 +883,11 @@ def compute(self, input, save=True, extension_params=None, **kwargs):
         )
         """
         if isinstance(input, str):
-            return self.compute_one_extension(extension_name=input, save=save, **kwargs)
+            return self.compute_one_extension(extension_name=input, save=save, verbose=verbose, **kwargs)
         elif isinstance(input, dict):
             params_, job_kwargs = split_job_kwargs(kwargs)
             assert len(params_) == 0, "Too many arguments for SortingAnalyzer.compute_several_extensions()"
-            self.compute_several_extensions(extensions=input, save=save, **job_kwargs)
+            self.compute_several_extensions(extensions=input, save=save, verbose=verbose, **job_kwargs)
         elif isinstance(input, list):
             params_, job_kwargs = split_job_kwargs(kwargs)
             assert len(params_) == 0, "Too many arguments for SortingAnalyzer.compute_several_extensions()"
@@ -898,11 +898,11 @@ def compute(self, input, save=True, extension_params=None, **kwargs):
                     ext_name in input
                 ), f"SortingAnalyzer.compute(): Parameters specified for {ext_name}, which is not in the specified {input}"
                 extensions[ext_name] = ext_params
-            self.compute_several_extensions(extensions=extensions, save=save, **job_kwargs)
+            self.compute_several_extensions(extensions=extensions, save=save, verbose=verbose, **job_kwargs)
         else:
             raise ValueError("SortingAnalyzer.compute() need str, dict or list")

-    def compute_one_extension(self, extension_name, save=True, **kwargs):
+    def compute_one_extension(self, extension_name, save=True, verbose=False, **kwargs):
         """
         Compute one extension.

@@ -961,13 +961,16 @@ def compute_one_extension(self, extension_name, save=True, **kwargs):

         extension_instance = extension_class(self)
         extension_instance.set_params(save=save, **params)
-        extension_instance.run(save=save, **job_kwargs)
+        if extension_class.need_job_kwargs:
+            extension_instance.run(save=save, verbose=verbose, **job_kwargs)
+        else:
+            extension_instance.run(save=save, verbose=verbose)

         self.extensions[extension_name] = extension_instance

         return extension_instance

-    def compute_several_extensions(self, extensions, save=True, **job_kwargs):
+    def compute_several_extensions(self, extensions, save=True, verbose=False, **job_kwargs):
         """
         Compute several extensions

@@ -1021,9 +1024,9 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
         for extension_name, extension_params in extensions_without_pipeline.items():
             extension_class = get_extension_class(extension_name)
             if extension_class.need_job_kwargs:
-                self.compute_one_extension(extension_name, save=save, **extension_params, **job_kwargs)
+                self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params, **job_kwargs)
             else:
-                self.compute_one_extension(extension_name, save=save, **extension_params)
+                self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params)
         # then extensions with pipeline
         if len(extensions_with_pipeline) > 0:
             all_nodes = []
@@ -1053,6 +1056,7 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
                 job_name=job_name,
                 gather_mode="memory",
                 squeeze_output=False,
+                verbose=verbose,
             )

             for r, result in enumerate(results):
@@ -1071,9 +1075,9 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
         for extension_name, extension_params in extensions_post_pipeline.items():
             extension_class = get_extension_class(extension_name)
             if extension_class.need_job_kwargs:
-                self.compute_one_extension(extension_name, save=save, **extension_params, **job_kwargs)
+                self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params, **job_kwargs)
             else:
-                self.compute_one_extension(extension_name, save=save, **extension_params)
+                self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params)

     def get_saved_extension_names(self):
         """
diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py
index 58966334db..acc368b2e5 100644
--- a/src/spikeinterface/core/waveform_tools.py
+++ b/src/spikeinterface/core/waveform_tools.py
@@ -221,6 +221,7 @@ def distribute_waveforms_to_buffers(
     mode="memmap",
     sparsity_mask=None,
     job_name=None,
+    verbose=False,
     **job_kwargs,
 ):
     """
@@ -281,7 +282,9 @@
     )

     if job_name is None:
         job_name = f"extract waveforms {mode} multi buffer"
-    processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs)
+    processor = ChunkRecordingExecutor(
+        recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs
+    )
     processor.run()

@@ -410,6 +413,7 @@ def extract_waveforms_to_single_buffer(
     sparsity_mask=None,
     copy=True,
     job_name=None,
+    verbose=False,
     **job_kwargs,
 ):
     """
@@ -523,7 +527,9 @@

     if job_name is None:
         job_name = f"extract waveforms {mode} mono buffer"
-    processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs)
+    processor = ChunkRecordingExecutor(
+        recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs
+    )
     processor.run()

     if mode == "memmap":
@@ -783,6 +789,7 @@ def estimate_templates_with_accumulator(
     return_scaled: bool = True,
     job_name=None,
     return_std: bool = False,
+    verbose: bool = False,
     **job_kwargs,
 ):
     """
@@ -861,7 +868,9 @@

     if job_name is None:
         job_name = "estimate_templates_with_accumulator"
-    processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs)
+    processor = ChunkRecordingExecutor(
+        recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs
+    )
     processor.run()

     # average
diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py
index e2dcdd8e5a..d2b363e69a 100644
--- a/src/spikeinterface/postprocessing/amplitude_scalings.py
+++ b/src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -181,7 +181,7 @@ def _get_pipeline_nodes(self):
         nodes = [spike_retriever_node, amplitude_scalings_node]
         return nodes

-    def _run(self, **job_kwargs):
+    def _run(self, verbose=False, **job_kwargs):
         job_kwargs = fix_job_kwargs(job_kwargs)
         nodes = self.get_pipeline_nodes()
         amp_scalings, collision_mask = run_node_pipeline(
@@ -190,6 +190,7 @@ def _run(self, **job_kwargs):
             job_kwargs=job_kwargs,
             job_name="amplitude_scalings",
             gather_mode="memory",
+            verbose=verbose,
         )
         self.data["amplitude_scalings"] = amp_scalings
         if self.params["handle_collisions"]:
diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py
index 3a01305d6b..f0bd151c68 100644
--- a/src/spikeinterface/postprocessing/correlograms.py
+++ b/src/spikeinterface/postprocessing/correlograms.py
@@ -70,7 +70,7 @@ def _select_extension_data(self, unit_ids):
         new_data = dict(ccgs=new_ccgs, bins=new_bins)
         return new_data

-    def _run(self):
+    def _run(self, verbose=False):
         ccgs, bins = compute_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params)
         self.data["ccgs"] = ccgs
         self.data["bins"] = bins
diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py
index c7e850993f..3742cbfa96 100644
--- a/src/spikeinterface/postprocessing/isi.py
+++ b/src/spikeinterface/postprocessing/isi.py
@@ -56,7 +56,7 @@ def _select_extension_data(self, unit_ids):
         new_extension_data = dict(isi_histograms=new_isi_hists, bins=new_bins)
         return new_extension_data

-    def _run(self):
+    def _run(self, verbose=False):
         isi_histograms, bins = _compute_isi_histograms(self.sorting_analyzer.sorting, **self.params)
         self.data["isi_histograms"] = isi_histograms
         self.data["bins"] = bins
diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py
index 76d7c1744e..8eb375e90b 100644
--- a/src/spikeinterface/postprocessing/principal_component.py
+++ b/src/spikeinterface/postprocessing/principal_component.py
@@ -256,7 +256,7 @@ def project_new(self, new_spikes, new_waveforms, progress_bar=True):
         new_projections = self._transform_waveforms(new_spikes, new_waveforms, pca_model, progress_bar=progress_bar)
         return new_projections

-    def _run(self, **job_kwargs):
+    def _run(self, verbose=False, **job_kwargs):
         """
         Compute the PCs on waveforms extracted by ComputeWaveforms.
         Projections are computed only on the waveforms sampled by the SortingAnalyzer.
@@ -295,7 +295,7 @@ def _run(self, **job_kwargs):
     def _get_data(self):
         return self.data["pca_projection"]

-    def run_for_all_spikes(self, file_path=None, **job_kwargs):
+    def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):
         """
         Project all spikes from the sorting on the PCA model.
         This is a long computation because waveforms need to be extracted from each spike.
@@ -359,7 +359,9 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs):
             unit_channels,
             pca_model,
         )
-        processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name="extract PCs", **job_kwargs)
+        processor = ChunkRecordingExecutor(
+            recording, func, init_func, init_args, job_name="extract PCs", verbose=verbose, **job_kwargs
+        )
         processor.run()

     def _fit_by_channel_local(self, n_jobs, progress_bar):
diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py
index add9764790..cc1d4b26e9 100644
--- a/src/spikeinterface/postprocessing/spike_amplitudes.py
+++ b/src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -107,7 +107,7 @@ def _get_pipeline_nodes(self):
         nodes = [spike_retriever_node, spike_amplitudes_node]
         return nodes

-    def _run(self, **job_kwargs):
+    def _run(self, verbose=False, **job_kwargs):
         job_kwargs = fix_job_kwargs(job_kwargs)
         nodes = self.get_pipeline_nodes()
         amps = run_node_pipeline(
@@ -116,6 +116,7 @@ def _run(self, **job_kwargs):
             job_kwargs=job_kwargs,
             job_name="spike_amplitudes",
             gather_mode="memory",
+            verbose=verbose,
         )
         self.data["amplitudes"] = amps
diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py
index d1e1d38c6a..d468bd90ab 100644
--- a/src/spikeinterface/postprocessing/spike_locations.py
+++ b/src/spikeinterface/postprocessing/spike_locations.py
@@ -120,7 +120,7 @@ def _get_pipeline_nodes(self):
         )
         return nodes

-    def _run(self, **job_kwargs):
+    def _run(self, verbose=False, **job_kwargs):
         job_kwargs = fix_job_kwargs(job_kwargs)
         nodes = self.get_pipeline_nodes()
         spike_locations = run_node_pipeline(
@@ -129,6 +129,7 @@ def _run(self, **job_kwargs):
             job_kwargs=job_kwargs,
             job_name="spike_locations",
             gather_mode="memory",
+            verbose=verbose,
         )
         self.data["spike_locations"] = spike_locations
diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py
index 19e6a1b47a..d7179ffefa 100644
--- a/src/spikeinterface/postprocessing/template_metrics.py
+++ b/src/spikeinterface/postprocessing/template_metrics.py
@@ -150,7 +150,7 @@ def _select_extension_data(self, unit_ids):
         new_metrics = self.data["metrics"].loc[np.array(unit_ids)]
         return dict(metrics=new_metrics)

-    def _run(self):
+    def _run(self, verbose=False):
         import pandas as pd
         from scipy.signal import resample_poly

diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py
index 4a45da0269..15a1fe34ce 100644
--- a/src/spikeinterface/postprocessing/template_similarity.py
+++ b/src/spikeinterface/postprocessing/template_similarity.py
@@ -42,7 +42,7 @@ def _select_extension_data(self, unit_ids):
         new_similarity = self.data["similarity"][unit_indices][:, unit_indices]
         return dict(similarity=new_similarity)

-    def _run(self):
+    def _run(self, verbose=False):
         templates_array = get_dense_templates_array(
             self.sorting_analyzer, return_scaled=self.sorting_analyzer.return_scaled
         )
diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py
index d1d3b5075a..e40523e7e5 100644
--- a/src/spikeinterface/postprocessing/unit_localization.py
+++ b/src/spikeinterface/postprocessing/unit_localization.py
@@ -64,7 +64,7 @@ def _select_extension_data(self, unit_ids):
         new_unit_location = self.data["unit_locations"][unit_inds]
         return dict(unit_locations=new_unit_location)

-    def _run(self):
+    def _run(self, verbose=False):
         method = self.params["method"]
         method_kwargs = self.params["method_kwargs"]
diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py
index 1c5c947b02..9476a0df03 100644
--- a/src/spikeinterface/sortingcomponents/matching/main.py
+++ b/src/spikeinterface/sortingcomponents/matching/main.py
@@ -7,7 +7,9 @@
 from spikeinterface.core import get_chunk_with_margin


-def find_spikes_from_templates(recording, method="naive", method_kwargs={}, extra_outputs=False, **job_kwargs):
+def find_spikes_from_templates(
+    recording, method="naive", method_kwargs={}, extra_outputs=False, verbose=False, **job_kwargs
+):
     """Find spikes in a recording using the given templates.

     Parameters
@@ -53,7 +55,14 @@ def find_spikes_from_templates(recording, method="naive", method_kwargs={}, extr
     init_func = _init_worker_find_spikes
     init_args = (recording, method, method_kwargs_seralized)
     processor = ChunkRecordingExecutor(
-        recording, func, init_func, init_args, handle_returns=True, job_name=f"find spikes ({method})", **job_kwargs
+        recording,
+        func,
+        init_func,
+        init_args,
+        handle_returns=True,
+        job_name=f"find spikes ({method})",
+        verbose=verbose,
+        **job_kwargs,
     )
     spikes = processor.run()
diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py
index 508a033c41..4e1fa64961 100644
--- a/src/spikeinterface/sortingcomponents/peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/peak_detection.py
@@ -9,10 +9,8 @@
 import numpy as np

 from spikeinterface.core.job_tools import (
-    ChunkRecordingExecutor,
     _shared_job_kwargs_doc,
     split_job_kwargs,
-    fix_job_kwargs,
 )

 from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances, get_random_data_chunks
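
Usage sketch (appendix, not part of the diff): a minimal example of how the verbose flag threaded through above reaches the workers. It assumes the generate_ground_truth_recording and create_sorting_analyzer helpers from SortingAnalyzer-era spikeinterface; durations, num_units, seed and the extension list are illustrative values.

# Minimal sketch: exercising the new verbose flag from the public API.
# Assumptions (not in the diff): generate_ground_truth_recording and
# create_sorting_analyzer are importable from spikeinterface.core.
from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording

recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5, seed=42)
analyzer = create_sorting_analyzer(sorting, recording, format="memory")

# verbose defaults to False at every level, so existing callers keep quiet logs.
# Passing verbose=True is forwarded along compute() ->
# compute_one_extension()/compute_several_extensions() -> each extension's
# _run() -> ChunkRecordingExecutor / run_node_pipeline.
analyzer.compute(["random_spikes", "waveforms", "templates"], verbose=True, n_jobs=1)

Note that extensions with need_job_kwargs = False (correlograms, isi, noise_levels, ...) also accept verbose in _run(), which lets compute_one_extension() call run() uniformly for both kinds of extension.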