Skip to content

Commit

Permalink
Merge pull request #3378 from alejoe91/zarr-consolidated
Browse files Browse the repository at this point in the history
Ensure sorting analyzers in zarr are consolidated
  • Loading branch information
samuelgarcia authored Sep 12, 2024
2 parents 48b2131 + ff07ac6 commit c176830
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 12 deletions.
60 changes: 48 additions & 12 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def create_sorting_analyzer(
return_scaled : bool, default: True
All extensions that play with traces will use this global return_scaled : "waveforms", "noise_levels", "templates".
This prevents return_scaled from being different across extensions and having a wrong snr, for instance.
overwrite: bool, default: False
If True, overwrite the folder if it already exists.
Returns
-------
Expand Down Expand Up @@ -486,7 +489,11 @@ def _get_zarr_root(self, mode="r+"):

if is_path_remote(str(self.folder)):
mode = "r"
zarr_root = zarr.open(self.folder, mode=mode, storage_options=self.storage_options)
# we open_consolidated only if we are in read mode
if mode in ("r+", "a"):
zarr_root = zarr.open(str(self.folder), mode=mode, storage_options=self.storage_options)
else:
zarr_root = zarr.open_consolidated(self.folder, mode=mode, storage_options=self.storage_options)
return zarr_root

@classmethod
Expand Down Expand Up @@ -564,11 +571,13 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at

recording_info = zarr_root.create_group("extensions")

zarr.consolidate_metadata(zarr_root.store)

@classmethod
def load_from_zarr(cls, folder, recording=None, storage_options=None):
import zarr

zarr_root = zarr.open(str(folder), mode="r", storage_options=storage_options)
zarr_root = zarr.open_consolidated(str(folder), mode="r", storage_options=storage_options)

# load internal sorting in memory
sorting = NumpySorting.from_sorting(
Expand Down Expand Up @@ -614,6 +623,7 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None):
format="zarr",
sparsity=sparsity,
return_scaled=return_scaled,
storage_options=storage_options,
)

return sorting_analyzer
Expand Down Expand Up @@ -1462,7 +1472,7 @@ def delete_extension(self, extension_name) -> None:
if self.format != "memory" and self.has_extension(extension_name):
# need a reload to reset the folder
ext = self.load_extension(extension_name)
ext.reset()
ext.delete()

# remove from dict
self.extensions.pop(extension_name, None)
Expand Down Expand Up @@ -2004,19 +2014,17 @@ def run(self, save=True, **kwargs):
# NB: this call to _save_params() also resets the folder or zarr group
self._save_params()
self._save_importing_provenance()
self._save_run_info()

t_start = perf_counter()
self._run(**kwargs)
t_end = perf_counter()
self.run_info["runtime_s"] = t_end - t_start
self.run_info["run_completed"] = True

if save and not self.sorting_analyzer.is_read_only():
self._save_run_info()
self._save_data(**kwargs)

self.run_info["run_completed"] = True
self._save_run_info()

def save(self, **kwargs):
self._save_params()
self._save_importing_provenance()
Expand Down Expand Up @@ -2062,7 +2070,7 @@ def _save_data(self, **kwargs):
except:
raise Exception(f"Could not save {ext_data_name} as extension data")
elif self.format == "zarr":

import zarr
import numcodecs

extension_group = self._get_zarr_extension_group(mode="r+")
Expand Down Expand Up @@ -2096,6 +2104,8 @@ def _save_data(self, **kwargs):
except:
raise Exception(f"Could not save {ext_data_name} as extension data")
extension_group[ext_data_name].attrs["object"] = True
# we need to re-consolidate
zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)

def _reset_extension_folder(self):
"""
Expand All @@ -2110,8 +2120,35 @@ def _reset_extension_folder(self):
elif self.format == "zarr":
import zarr

zarr_root = zarr.open(self.folder, mode="r+")
extension_group = zarr_root["extensions"].create_group(self.extension_name, overwrite=True)
zarr_root = self.sorting_analyzer._get_zarr_root(mode="r+")
_ = zarr_root["extensions"].create_group(self.extension_name, overwrite=True)
zarr.consolidate_metadata(zarr_root.store)

def _delete_extension_folder(self):
    """
    Remove this extension's persisted data, whichever backend is in use
    (plain binary folder or zarr group).
    """
    if self.format == "binary_folder":
        # Stored as a regular directory: wipe the whole tree if it exists.
        ext_folder = self._get_binary_extension_folder()
        if ext_folder.is_dir():
            shutil.rmtree(ext_folder)
    elif self.format == "zarr":
        import zarr

        root = self.sorting_analyzer._get_zarr_root(mode="r+")
        if self.extension_name in root["extensions"]:
            del root["extensions"][self.extension_name]
        # The store was mutated, so its consolidated metadata must be rebuilt.
        zarr.consolidate_metadata(root.store)

def delete(self):
    """
    Fully remove the extension: its on-disk data (folder or zarr) and its
    in-memory params/run-info/data.
    """
    # Drop the persisted representation first, then reset in-memory state
    # so the extension looks as if it was never computed.
    self._delete_extension_folder()
    self.params = None
    self.run_info = self._default_run_info_dict()
    self.data = {}

def reset(self):
"""
Expand All @@ -2128,7 +2165,7 @@ def set_params(self, save=True, **params):
Set parameters for the extension and
make it persistent in json.
"""
# this ensure data is also deleted and corresponf to params
# this ensure data is also deleted and corresponds to params
# this also ensure the group is created
self._reset_extension_folder()

Expand All @@ -2141,7 +2178,6 @@ def set_params(self, save=True, **params):
if save:
self._save_params()
self._save_importing_provenance()
self._save_run_info()

def _save_params(self):
params_to_save = self.params.copy()
Expand Down
5 changes: 5 additions & 0 deletions src/spikeinterface/core/tests/test_analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,20 @@ def _check_result_extension(sorting_analyzer, extension_name, cache_folder):
)
def test_ComputeRandomSpikes(format, sparse, create_cache_folder):
cache_folder = create_cache_folder
print("Creating analyzer")
sorting_analyzer = get_sorting_analyzer(cache_folder, format=format, sparse=sparse)

print("Computing random spikes")
ext = sorting_analyzer.compute("random_spikes", max_spikes_per_unit=10, seed=2205)
indices = ext.data["random_spikes_indices"]
assert indices.size == 10 * sorting_analyzer.sorting.unit_ids.size

print("Checking results")
_check_result_extension(sorting_analyzer, "random_spikes", cache_folder)
print("Delering extension")
sorting_analyzer.delete_extension("random_spikes")

print("Re-computing random spikes")
ext = sorting_analyzer.compute("random_spikes", method="all")
indices = ext.data["random_spikes_indices"]
assert indices.size == len(sorting_analyzer.sorting.to_spike_vector())
Expand Down
3 changes: 3 additions & 0 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset):


def test_load_without_runtime_info(tmp_path, dataset):
import zarr

recording, sorting = dataset

folder = tmp_path / "test_SortingAnalyzer_run_info"
Expand Down Expand Up @@ -153,6 +155,7 @@ def test_load_without_runtime_info(tmp_path, dataset):
root = sorting_analyzer._get_zarr_root(mode="r+")
for ext in extensions:
del root["extensions"][ext].attrs["run_info"]
zarr.consolidate_metadata(root.store)
# should raise a warning for missing run_info
with pytest.warns(UserWarning):
sorting_analyzer = load_sorting_analyzer(folder, format="auto")
Expand Down

0 comments on commit c176830

Please sign in to comment.