diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 74bc0c1d14..1fa218851b 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -17,6 +17,7 @@ from .globals import get_global_tmp_folder, is_set_global_tmp_folder from .core_tools import ( check_json, + clean_zarr_folder_name, is_dict_extractor, SIJsonEncoder, make_paths_relative, @@ -1061,9 +1062,7 @@ def save_to_zarr( print(f"Use zarr_path={zarr_path}") else: if storage_options is None: - folder = Path(folder) - if folder.suffix != ".zarr": - folder = folder.parent / f"{folder.stem}.zarr" + folder = clean_zarr_folder_name(folder) if folder.is_dir() and overwrite: shutil.rmtree(folder) zarr_path = folder diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index b38222391c..b3595dddf2 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -153,6 +153,13 @@ def check_json(dictionary: dict) -> dict: return json.loads(json_string) +def clean_zarr_folder_name(folder): + folder = Path(folder) + if folder.suffix != ".zarr": + folder = folder.parent / f"{folder.stem}.zarr" + return folder + + def add_suffix(file_path, possible_suffix): file_path = Path(file_path) if isinstance(possible_suffix, str): diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index fa4547d272..fbf0307498 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -23,7 +23,7 @@ from .base import load_extractor from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match -from .core_tools import check_json, retrieve_importing_provenance, is_path_remote +from .core_tools import check_json, retrieve_importing_provenance, is_path_remote, clean_zarr_folder_name from .sorting_tools import generate_unit_ids_for_merge_group, _get_ids_after_merging from .job_tools import split_job_kwargs from .numpyextractors import NumpySorting @@ -111,6 +111,8 @@ def create_sorting_analyzer( sparsity off (or give external sparsity) like this. """ if format != "memory": + if format == "zarr": + folder = clean_zarr_folder_name(folder) if Path(folder).is_dir(): if not overwrite: raise ValueError(f"Folder already exists {folder}! Use overwrite=True to overwrite it.") @@ -162,6 +164,8 @@ def load_sorting_analyzer(folder, load_extensions=True, format="auto"): The loaded SortingAnalyzer """ + if format == "zarr": + folder = clean_zarr_folder_name(folder) return SortingAnalyzer.load(folder, load_extensions=load_extensions, format=format) @@ -269,6 +273,8 @@ def create( sorting_analyzer = cls.load_from_binary_folder(folder, recording=recording) sorting_analyzer.folder = Path(folder) elif format == "zarr": + assert folder is not None, "For format='zarr' folder must be provided" + folder = clean_zarr_folder_name(folder) cls.create_zarr(folder, sorting, recording, sparsity, return_scaled, rec_attributes=None) sorting_analyzer = cls.load_from_zarr(folder, recording=recording) sorting_analyzer.folder = Path(folder) @@ -487,10 +493,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at import zarr import numcodecs - folder = Path(folder) - # force zarr sufix - if folder.suffix != ".zarr": - folder = folder.parent / f"{folder.stem}.zarr" + folder = clean_zarr_folder_name(folder) if folder.is_dir(): raise ValueError(f"Folder already exists {folder}") @@ -768,9 +771,7 @@ def _save_or_select_or_merge( elif format == "zarr": assert folder is not None, "For format='zarr' folder must be provided" - folder = Path(folder) - if folder.suffix != ".zarr": - folder = folder.parent / f"{folder.stem}.zarr" + folder = clean_zarr_folder_name(folder) SortingAnalyzer.create_zarr( folder, sorting_provenance, recording, sparsity, self.return_scaled, self.rec_attributes ) @@ -829,6 +830,8 @@ def save_as(self, format="memory", folder=None) -> "SortingAnalyzer": format : "memory" | "binary_folder" | "zarr", default: "memory" The new backend format to use """ + if format == "zarr": + folder = clean_zarr_folder_name(folder) return self._save_or_select_or_merge(format=format, folder=folder) def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyzer": @@ -854,6 +857,8 @@ def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyz The newly create sorting_analyzer with the selected units """ # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!! + if format == "zarr": + folder = clean_zarr_folder_name(folder) return self._save_or_select_or_merge(format=format, folder=folder, unit_ids=unit_ids) def remove_units(self, remove_unit_ids, format="memory", folder=None) -> "SortingAnalyzer": @@ -880,6 +885,8 @@ def remove_units(self, remove_unit_ids, format="memory", folder=None) -> "Sortin """ # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!! unit_ids = self.unit_ids[~np.isin(self.unit_ids, remove_unit_ids)] + if format == "zarr": + folder = clean_zarr_folder_name(folder) return self._save_or_select_or_merge(format=format, folder=folder, unit_ids=unit_ids) def merge_units( @@ -938,6 +945,9 @@ def merge_units( The newly create `SortingAnalyzer` with the selected units """ + if format == "zarr": + folder = clean_zarr_folder_name(folder) + assert merging_mode in ["soft", "hard"], "Merging mode should be either soft or hard" if len(merge_unit_groups) == 0: