diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 49a31738e3..424fab7c5e 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -77,6 +77,9 @@ def create_sorting_analyzer( return_scaled : bool, default: True All extensions that play with traces will use this global return_scaled : "waveforms", "noise_levels", "templates". This prevent return_scaled being differents from different extensions and having wrong snr for instance. + overwrite: bool, default: False + If True, overwrite the folder if it already exists. + Returns ------- @@ -486,7 +489,11 @@ def _get_zarr_root(self, mode="r+"): if is_path_remote(str(self.folder)): mode = "r" - zarr_root = zarr.open(self.folder, mode=mode, storage_options=self.storage_options) + # we open_consolidated only if we are in read mode + if mode in ("r+", "a"): + zarr_root = zarr.open(str(self.folder), mode=mode, storage_options=self.storage_options) + else: + zarr_root = zarr.open_consolidated(self.folder, mode=mode, storage_options=self.storage_options) return zarr_root @classmethod @@ -564,11 +571,13 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at recording_info = zarr_root.create_group("extensions") + zarr.consolidate_metadata(zarr_root.store) + @classmethod def load_from_zarr(cls, folder, recording=None, storage_options=None): import zarr - zarr_root = zarr.open(str(folder), mode="r", storage_options=storage_options) + zarr_root = zarr.open_consolidated(str(folder), mode="r", storage_options=storage_options) # load internal sorting in memory sorting = NumpySorting.from_sorting( @@ -614,6 +623,7 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None): format="zarr", sparsity=sparsity, return_scaled=return_scaled, + storage_options=storage_options, ) return sorting_analyzer @@ -1462,7 +1472,7 @@ def delete_extension(self, extension_name) -> None: if self.format != "memory" and self.has_extension(extension_name): # need a reload to reset the folder ext = self.load_extension(extension_name) - ext.reset() + ext.delete() # remove from dict self.extensions.pop(extension_name, None) @@ -2004,19 +2014,17 @@ def run(self, save=True, **kwargs): # NB: this call to _save_params() also resets the folder or zarr group self._save_params() self._save_importing_provenance() - self._save_run_info() t_start = perf_counter() self._run(**kwargs) t_end = perf_counter() self.run_info["runtime_s"] = t_end - t_start + self.run_info["run_completed"] = True if save and not self.sorting_analyzer.is_read_only(): + self._save_run_info() self._save_data(**kwargs) - self.run_info["run_completed"] = True - self._save_run_info() - def save(self, **kwargs): self._save_params() self._save_importing_provenance() @@ -2062,7 +2070,7 @@ def _save_data(self, **kwargs): except: raise Exception(f"Could not save {ext_data_name} as extension data") elif self.format == "zarr": - + import zarr import numcodecs extension_group = self._get_zarr_extension_group(mode="r+") @@ -2096,6 +2104,8 @@ def _save_data(self, **kwargs): except: raise Exception(f"Could not save {ext_data_name} as extension data") extension_group[ext_data_name].attrs["object"] = True + # we need to re-consolidate + zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) def _reset_extension_folder(self): """ @@ -2110,8 +2120,35 @@ def _reset_extension_folder(self): elif self.format == "zarr": import zarr - zarr_root = zarr.open(self.folder, mode="r+") - extension_group = zarr_root["extensions"].create_group(self.extension_name, overwrite=True) + zarr_root = self.sorting_analyzer._get_zarr_root(mode="r+") + _ = zarr_root["extensions"].create_group(self.extension_name, overwrite=True) + zarr.consolidate_metadata(zarr_root.store) + + def _delete_extension_folder(self): + """ + Delete the extension in a folder (binary or zarr). + """ + if self.format == "binary_folder": + extension_folder = self._get_binary_extension_folder() + if extension_folder.is_dir(): + shutil.rmtree(extension_folder) + + elif self.format == "zarr": + import zarr + + zarr_root = self.sorting_analyzer._get_zarr_root(mode="r+") + if self.extension_name in zarr_root["extensions"]: + del zarr_root["extensions"][self.extension_name] + zarr.consolidate_metadata(zarr_root.store) + + def delete(self): + """ + Delete the extension from the folder or zarr and from the dict. + """ + self._delete_extension_folder() + self.params = None + self.run_info = self._default_run_info_dict() + self.data = dict() def reset(self): """ @@ -2128,7 +2165,7 @@ def set_params(self, save=True, **params): Set parameters for the extension and make it persistent in json. """ - # this ensure data is also deleted and corresponf to params + # this ensure data is also deleted and corresponds to params # this also ensure the group is created self._reset_extension_folder() @@ -2141,7 +2178,6 @@ def set_params(self, save=True, **params): if save: self._save_params() self._save_importing_provenance() - self._save_run_info() def _save_params(self): params_to_save = self.params.copy() diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index b4d96a3391..626899ab6e 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -79,15 +79,20 @@ def _check_result_extension(sorting_analyzer, extension_name, cache_folder): ) def test_ComputeRandomSpikes(format, sparse, create_cache_folder): cache_folder = create_cache_folder + print("Creating analyzer") sorting_analyzer = get_sorting_analyzer(cache_folder, format=format, sparse=sparse) + print("Computing random spikes") ext = sorting_analyzer.compute("random_spikes", max_spikes_per_unit=10, seed=2205) indices = ext.data["random_spikes_indices"] assert indices.size == 10 * sorting_analyzer.sorting.unit_ids.size + print("Checking results") _check_result_extension(sorting_analyzer, "random_spikes", cache_folder) + print("Delering extension") sorting_analyzer.delete_extension("random_spikes") + print("Re-computing random spikes") ext = sorting_analyzer.compute("random_spikes", method="all") indices = ext.data["random_spikes_indices"] assert indices.size == len(sorting_analyzer.sorting.to_spike_vector()) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 3f45487f4c..77b8f2c5bf 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -126,6 +126,8 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): def test_load_without_runtime_info(tmp_path, dataset): + import zarr + recording, sorting = dataset folder = tmp_path / "test_SortingAnalyzer_run_info" @@ -153,6 +155,7 @@ def test_load_without_runtime_info(tmp_path, dataset): root = sorting_analyzer._get_zarr_root(mode="r+") for ext in extensions: del root["extensions"][ext].attrs["run_info"] + zarr.consolidate_metadata(root.store) # should raise a warning for missing run_info with pytest.warns(UserWarning): sorting_analyzer = load_sorting_analyzer(folder, format="auto")