diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 94bf1ea11c..7138807e73 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -8,6 +8,7 @@ import weakref import shutil import warnings +import importlib import numpy as np @@ -908,35 +909,29 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs): def get_saved_extension_names(self): """ - Get extension saved in folder or zarr that can be loaded. + Get extension names saved in folder or zarr that can be loaded. + This do not load data, this only explores the directory. """ - assert self.format != "memory" - global _possible_extensions + saved_extension_names = [] + if self.format == "binary_folder": + ext_folder = self.folder / "extensions" + if ext_folder.is_dir(): + for extension_folder in ext_folder.iterdir(): + is_saved = extension_folder.is_dir() and (extension_folder / "params.json").is_file() + if not is_saved: + continue + saved_extension_names.append(extension_folder.stem) - if self.format == "zarr": + elif self.format == "zarr": zarr_root = self._get_zarr_root(mode="r") if "extensions" in zarr_root.keys(): extension_group = zarr_root["extensions"] - else: - extension_group = None + for extension_name in extension_group.keys(): + if "params" in extension_group[extension_name].attrs.keys(): + saved_extension_names.append(extension_name) - saved_extension_names = [] - for extension_class in _possible_extensions: - extension_name = extension_class.extension_name - - if self.format == "binary_folder": - extension_folder = self.folder / "extensions" / extension_name - is_saved = extension_folder.is_dir() and (extension_folder / "params.json").is_file() - elif self.format == "zarr": - if extension_group is not None: - is_saved = ( - extension_name in extension_group.keys() - and "params" in extension_group[extension_name].attrs.keys() - ) - else: - is_saved = False - if is_saved: - saved_extension_names.append(extension_class.extension_name) + else: + raise ValueError("SortingAnalyzer.get_saved_extension_names() works only with binary_folder and zarr") return saved_extension_names @@ -1057,7 +1052,7 @@ def register_result_extension(extension_class): _possible_extensions.append(extension_class) -def get_extension_class(extension_name: str): +def get_extension_class(extension_name: str, auto_import=True): """ Get extension class from name and check if registered. @@ -1065,6 +1060,8 @@ def get_extension_class(extension_name: str): ---------- extension_name: str The extension name. + auto_import: bool, default True + Auto import the module if the extension class is not registered yet. Returns ------- @@ -1073,9 +1070,20 @@ def get_extension_class(extension_name: str): """ global _possible_extensions extensions_dict = {ext.extension_name: ext for ext in _possible_extensions} - assert ( - extension_name in extensions_dict - ), f"Extension '{extension_name}' is not registered, please import related module before use" + + if extension_name not in extensions_dict: + if extension_name in _builtin_extensions: + module = _builtin_extensions[extension_name] + if auto_import: + imported_module = importlib.import_module(module) + extensions_dict = {ext.extension_name: ext for ext in _possible_extensions} + else: + raise ValueError( + f"Extension '{extension_name}' is not registered, please import related module before use: 'import {module}'" + ) + else: + raise ValueError(f"Extension '{extension_name}' is unknown maybe this is an external extension or a typo.") + ext_class = extensions_dict[extension_name] return ext_class @@ -1471,3 +1479,27 @@ def get_pipeline_nodes(self): def get_data(self, *args, **kwargs): assert len(self.data) > 0, f"You must run the extension {self.extension_name} before retrieving data" return self._get_data(*args, **kwargs) + + +# this is a hardcoded list to to improve error message and auto_import mechanism +# this is important because extension are registered when the submodule is imported +_builtin_extensions = { + # from core + "random_spikes": "spikeinterface.core", + "waveforms": "spikeinterface.core", + "templates": "spikeinterface.core", + "fast_templates": "spikeinterface.core", + "noise_levels": "spikeinterface.core", + # from postprocessing + "amplitude_scalings": "spikeinterface.postprocessing", + "correlograms": "spikeinterface.postprocessing", + "isi_histograms": "spikeinterface.postprocessing", + "principal_components": "spikeinterface.postprocessing", + "spike_amplitudes": "spikeinterface.postprocessing", + "spike_locations": "spikeinterface.postprocessing", + "template_metrics": "spikeinterface.postprocessing", + "template_similarity": "spikeinterface.postprocessing", + "unit_locations": "spikeinterface.postprocessing", + # from quality metrics + "quality_metrics": "spikeinterface.qualitymetrics", +}