diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 177188f21d..4961db8524 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -232,6 +232,8 @@ def __repr__(self) -> str: txt += " - sparse" if self.has_recording(): txt += " - has recording" + if self.has_temporary_recording(): + txt += " - has temporary recording" ext_txt = f"Loaded {len(self.extensions)} extensions: " + ", ".join(self.extensions.keys()) txt += "\n" + ext_txt return txt @@ -350,7 +352,7 @@ def create_memory(cls, sorting, recording, sparsity, return_scaled, rec_attribut def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes): # used by create and save_as - assert recording is not None, "To create a SortingAnalyzer you need recording not None" + assert recording is not None, "To create a SortingAnalyzer you need to specify the recording" folder = Path(folder) if folder.is_dir(): @@ -1221,7 +1223,7 @@ def compute(self, input, save=True, extension_params=None, verbose=False, **kwar extensions[ext_name] = ext_params self.compute_several_extensions(extensions=extensions, save=save, verbose=verbose, **job_kwargs) else: - raise ValueError("SortingAnalyzer.compute() need str, dict or list") + raise ValueError("SortingAnalyzer.compute() needs a str, dict or list") def compute_one_extension(self, extension_name, save=True, verbose=False, **kwargs) -> "AnalyzerExtension": """ @@ -1355,7 +1357,9 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job for extension_name, extension_params in extensions_with_pipeline.items(): extension_class = get_extension_class(extension_name) - assert self.has_recording(), f"Extension {extension_name} need the recording" + assert ( + self.has_recording() or self.has_temporary_recording() + ), f"Extension {extension_name} requires the recording" for variable_name in extension_class.nodepipeline_variables: result_routage.append((extension_name, variable_name)) @@ -1603,17 +1607,17 @@ def _sort_extensions_by_dependency(extensions): def _get_children_dependencies(extension_name): """ Extension classes have a `depend_on` attribute to declare on which class they - depend. For instance "templates" depend on "waveforms". "waveforms depends on "random_spikes". + depend on. For instance "templates" depends on "waveforms". "waveforms" depends on "random_spikes". - This function is making the reverse way : get all children that depend of a + This function is going the opposite way: it finds all children that depend on a particular extension. - This is recursive so this includes : children and so grand children and great grand children + The implementation is recursive so that the output includes children, grand children, great grand children, etc. - This function is usefull for deleting on recompute. - For instance recompute the "waveforms" need to delete "template" - This make sens if "ms_before" is change in "waveforms" because the template also depends - on this parameters. + This function is useful for deleting existing extensions on recompute. + For instance, recomputing the "waveforms" needs to delete the "templates", since the latter depends on the former. + For this particular example, if we change the "ms_before" parameter of the "waveforms", also the "templates" will + require recomputation as this parameter is inherited. """ names = [] children = _extension_children[extension_name] diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index f1f89403c7..1871c11b85 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -359,12 +359,12 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs): job_kwargs = fix_job_kwargs(job_kwargs) p = self.params - we = self.sorting_analyzer - sorting = we.sorting + sorting_analyzer = self.sorting_analyzer + sorting = sorting_analyzer.sorting assert ( - we.has_recording() - ), "To compute PCA projections for all spikes, the waveform extractor needs the recording" - recording = we.recording + sorting_analyzer.has_recording() or sorting_analyzer.has_temporary_recording() + ), "To compute PCA projections for all spikes, the sorting analyzer needs the recording" + recording = sorting_analyzer.recording # assert sorting.get_num_segments() == 1 assert p["mode"] in ("by_channel_local", "by_channel_global") @@ -374,8 +374,9 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs): sparsity = self.sorting_analyzer.sparsity if sparsity is None: - sparse_channels_indices = {unit_id: np.arange(we.get_num_channels()) for unit_id in we.unit_ids} - max_channels_per_template = we.get_num_channels() + num_channels = recording.get_num_channels() + sparse_channels_indices = {unit_id: np.arange(num_channels) for unit_id in sorting_analyzer.unit_ids} + max_channels_per_template = num_channels else: sparse_channels_indices = sparsity.unit_id_to_channel_indices max_channels_per_template = max([chan_inds.size for chan_inds in sparse_channels_indices.values()]) @@ -449,9 +450,7 @@ def _fit_by_channel_local(self, n_jobs, progress_bar): return pca_models def _fit_by_channel_global(self, progress_bar): - # we = self.sorting_analyzer p = self.params - # unit_ids = we.unit_ids unit_ids = self.sorting_analyzer.unit_ids # there is one unique PCA accross channels