Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proposal for auto import extensions module. #2571

Merged
merged 10 commits into from
Mar 19, 2024
86 changes: 59 additions & 27 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import weakref
import shutil
import warnings
import importlib

import numpy as np

Expand Down Expand Up @@ -908,35 +909,29 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):

def get_saved_extension_names(self):
"""
Get extension saved in folder or zarr that can be loaded.
Get extension names saved in folder or zarr that can be loaded.
This do not load data, this only explores the directory.
"""
assert self.format != "memory"
global _possible_extensions
saved_extension_names = []
if self.format == "binary_folder":
ext_folder = self.folder / "extensions"
if ext_folder.is_dir():
for extension_folder in ext_folder.iterdir():
is_saved = extension_folder.is_dir() and (extension_folder / "params.json").is_file()
if not is_saved:
continue
saved_extension_names.append(extension_folder.stem)

if self.format == "zarr":
elif self.format == "zarr":
zarr_root = self._get_zarr_root(mode="r")
if "extensions" in zarr_root.keys():
extension_group = zarr_root["extensions"]
else:
extension_group = None
for extension_name in extension_group.keys():
if "params" in extension_group[extension_name].attrs.keys():
saved_extension_names.append(extension_name)

saved_extension_names = []
for extension_class in _possible_extensions:
extension_name = extension_class.extension_name

if self.format == "binary_folder":
extension_folder = self.folder / "extensions" / extension_name
is_saved = extension_folder.is_dir() and (extension_folder / "params.json").is_file()
elif self.format == "zarr":
if extension_group is not None:
is_saved = (
extension_name in extension_group.keys()
and "params" in extension_group[extension_name].attrs.keys()
)
else:
is_saved = False
if is_saved:
saved_extension_names.append(extension_class.extension_name)
else:
raise ValueError("SortingAnalyzer.get_saved_extension_names() works only with binary_folder and zarr")

return saved_extension_names

Expand Down Expand Up @@ -1057,14 +1052,16 @@ def register_result_extension(extension_class):
_possible_extensions.append(extension_class)


def get_extension_class(extension_name: str):
def get_extension_class(extension_name: str, auto_import=True):
"""
Get extension class from name and check if registered.

Parameters
----------
extension_name: str
The extension name.
auto_import: bool, default True
Auto import the module if the extension class is not registered yet.

Returns
-------
Expand All @@ -1073,9 +1070,20 @@ def get_extension_class(extension_name: str):
"""
global _possible_extensions
extensions_dict = {ext.extension_name: ext for ext in _possible_extensions}
assert (
extension_name in extensions_dict
), f"Extension '{extension_name}' is not registered, please import related module before use"

if extension_name not in extensions_dict:
if extension_name in _builtin_extensions:
module = _builtin_extensions[extension_name]
if auto_import:
imported_module = importlib.import_module(module)
extensions_dict = {ext.extension_name: ext for ext in _possible_extensions}
else:
raise ValueError(
f"Extension '{extension_name}' is not registered, please import related module before use: 'import {module}'"
)
else:
raise ValueError(f"Extension '{extension_name}' is unknown maybe this is an external extension or a typo.")

ext_class = extensions_dict[extension_name]
return ext_class

Expand Down Expand Up @@ -1471,3 +1479,27 @@ def get_pipeline_nodes(self):
def get_data(self, *args, **kwargs):
assert len(self.data) > 0, f"You must run the extension {self.extension_name} before retrieving data"
return self._get_data(*args, **kwargs)


# this is a hardcoded list to to improve error message and auto_import mechanism
# this is important because extension are registered when the submodule is imported
_builtin_extensions = {
# from core
"random_spikes": "spikeinterface.core",
"waveforms": "spikeinterface.core",
"templates": "spikeinterface.core",
"fast_templates": "spikeinterface.core",
"noise_levels": "spikeinterface.core",
# from postprocessing
"amplitude_scalings": "spikeinterface.postprocessing",
"correlograms": "spikeinterface.postprocessing",
"isi_histograms": "spikeinterface.postprocessing",
"principal_components": "spikeinterface.postprocessing",
"spike_amplitudes": "spikeinterface.postprocessing",
"spike_locations": "spikeinterface.postprocessing",
"template_metrics": "spikeinterface.postprocessing",
"template_similarity": "spikeinterface.postprocessing",
"unit_locations": "spikeinterface.postprocessing",
# from quality metrics
"quality_metrics": "spikeinterface.qualitymetrics",
}
Loading