From dcedbb33739663e0cb662d313e0224f3d64d04f8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sun, 15 Sep 2024 13:01:11 +0200 Subject: [PATCH 1/5] Simplify pandas save-load and convert dtypes --- pyproject.toml | 3 -- src/spikeinterface/core/sortinganalyzer.py | 37 ++++++++++++------- .../postprocessing/template_metrics.py | 2 +- .../tests/common_extension_tests.py | 23 +++++++++++- .../quality_metric_calculator.py | 2 +- 5 files changed, 47 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8309ca89fe..b5894bf3a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,6 @@ preprocessing = [ full = [ "h5py", "pandas", - "xarray", "scipy", "scikit-learn", "networkx", @@ -148,7 +147,6 @@ test = [ "pytest-dependency", "pytest-cov", - "xarray", "huggingface_hub", # preprocessing @@ -193,7 +191,6 @@ docs = [ "pandas", # in the modules gallery comparison tutorial "hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous "numba", # For many postprocessing functions - "xarray", # For use of SortingAnalyzer zarr format "networkx", # Download data "pooch>=1.8.2", diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 8980fb5559..312f85a8ca 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1970,12 +1970,14 @@ def load_data(self): if "dict" in ext_data_.attrs: ext_data = ext_data_[0] elif "dataframe" in ext_data_.attrs: - import xarray + import pandas as pd - ext_data = xarray.open_zarr( - ext_data_.store, group=f"{extension_group.name}/{ext_data_name}" - ).to_pandas() - ext_data.index.rename("", inplace=True) + index = ext_data_["index"] + ext_data = pd.DataFrame(index=index) + for col in ext_data_.keys(): + if col != "index": + ext_data.loc[:, col] = ext_data_[col][:] + ext_data = ext_data.convert_dtypes() elif "object" in ext_data_.attrs: ext_data = ext_data_[0] else: @@ -2031,12 +2033,21 @@ def run(self, save=True, **kwargs): if save and not self.sorting_analyzer.is_read_only(): self._save_run_info() self._save_data(**kwargs) + if self.format == "zarr": + import zarr + + zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) def save(self, **kwargs): self._save_params() self._save_importing_provenance() - self._save_data(**kwargs) self._save_run_info() + self._save_data(**kwargs) + + if self.format == "zarr": + import zarr + + zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) def _save_data(self, **kwargs): if self.format == "memory": @@ -2096,12 +2107,12 @@ def _save_data(self, **kwargs): elif isinstance(ext_data, np.ndarray): extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame): - ext_data.to_xarray().to_zarr( - store=extension_group.store, - group=f"{extension_group.name}/{ext_data_name}", - mode="a", - ) - extension_group[ext_data_name].attrs["dataframe"] = True + df_group = extension_group.create_group(ext_data_name) + # first we save the index + df_group.create_dataset(name="index", data=ext_data.index.to_numpy()) + for col in ext_data.columns: + df_group.create_dataset(name=col, data=ext_data[col].to_numpy()) + df_group.attrs["dataframe"] = True else: # any object try: @@ -2111,8 +2122,6 @@ def _save_data(self, **kwargs): except: raise Exception(f"Could not save {ext_data_name} as extension data") extension_group[ext_data_name].attrs["object"] = True - # we need to re-consolidate - zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) def _reset_extension_folder(self): """ diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 45ba55dee4..ee5ac6103b 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -287,7 +287,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") value = np.nan template_metrics.at[index, metric_name] = value - return template_metrics + return template_metrics.convert_dtypes() def _run(self, verbose=False): self.data["metrics"] = self._compute_metrics( diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 3945e71881..1b0a94d635 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -3,9 +3,10 @@ import pytest import shutil import numpy as np +import pandas as pd from spikeinterface.core import generate_ground_truth_recording -from spikeinterface.core import create_sorting_analyzer +from spikeinterface.core import create_sorting_analyzer, load_sorting_analyzer from spikeinterface.core import estimate_sparsity @@ -138,6 +139,26 @@ def _check_one(self, sorting_analyzer, extension_class, params): merged = sorting_analyzer.merge_units(some_merges, format="memory", merging_mode="soft", sparsity_overlap=0.0) assert len(merged.unit_ids) == num_units_after_merge + # test roundtrip + if sorting_analyzer.format in ("binary_folder", "zarr"): + sorting_analyzer_loaded = load_sorting_analyzer(sorting_analyzer.folder) + ext_loaded = sorting_analyzer_loaded.get_extension(extension_class.extension_name) + for ext_data_name, ext_data_loaded in ext_loaded.data.items(): + if isinstance(ext_data_loaded, np.ndarray): + assert np.array_equal(ext.data[ext_data_name], ext_data_loaded) + elif isinstance(ext_data_loaded, pd.DataFrame): + # skip nan values + for col in ext_data_loaded.columns: + np.testing.assert_array_almost_equal( + ext.data[ext_data_name][col].dropna().to_numpy(), + ext_data_loaded[col].dropna().to_numpy(), + decimal=5, + ) + elif isinstance(ext_data_loaded, dict): + assert ext.data[ext_data_name] == ext_data_loaded + else: + continue + def run_extension_tests(self, extension_class, params): """ Convenience function to perform all checks on the extension diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index cdf6151e95..3d7096651f 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -185,7 +185,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job if len(empty_unit_ids) > 0: metrics.loc[empty_unit_ids] = np.nan - return metrics + return metrics.convert_dtypes() def _run(self, verbose=False, **job_kwargs): self.data["metrics"] = self._compute_metrics( From 9000ce1980130fac15394b7cf68137fe239b0b13 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sun, 15 Sep 2024 13:08:25 +0200 Subject: [PATCH 2/5] local import --- .../postprocessing/tests/common_extension_tests.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 1b0a94d635..2207b98da6 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -3,7 +3,6 @@ import pytest import shutil import numpy as np -import pandas as pd from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core import create_sorting_analyzer, load_sorting_analyzer @@ -117,6 +116,8 @@ def _check_one(self, sorting_analyzer, extension_class, params): with the passed parameters, and check the output is not empty, the extension exists and `select_units()` method works. """ + import pandas as pd + if extension_class.need_job_kwargs: job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True) else: From c1228d9269663118c065ae9e5f68cb1c0a8b16c7 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sun, 15 Sep 2024 18:00:28 +0200 Subject: [PATCH 3/5] Add comment and re-consolidation step for 0.101.0 datasets --- src/spikeinterface/core/sortinganalyzer.py | 15 +++++++++++++++ .../postprocessing/template_metrics.py | 6 +++++- .../qualitymetrics/quality_metric_calculator.py | 9 ++++++--- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 312f85a8ca..a94b7aa3dc 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -11,6 +11,7 @@ import shutil import warnings import importlib +from packaging.version import parse from time import perf_counter import numpy as np @@ -579,6 +580,20 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None): zarr_root = zarr.open_consolidated(str(folder), mode="r", storage_options=storage_options) + si_info = zarr_root.attrs["spikeinterface_info"] + if parse(si_info["version"]) < parse("0.101.1"): + # v0.101.0 did not have a consolidate metadata step after computing extensions. + # Here we try to consolidate the metadata and throw a warning if it fails. + try: + zarr_root_a = zarr.open(str(folder), mode="a", storage_options=storage_options) + zarr.consolidate_metadata(zarr_root_a.store) + except Exception as e: + warnings.warn( + "The zarr store was not properly consolidated prior to v0.101.1. " + "This may lead to unexpected behavior in loading extensions. " + "Please consider re-saving the SortingAnalyzer object." + ) + # load internal sorting in memory sorting = NumpySorting.from_sorting( ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options), diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index ee5ac6103b..aa50be4c13 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -287,7 +287,11 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") value = np.nan template_metrics.at[index, metric_name] = value - return template_metrics.convert_dtypes() + + # we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns + # (in case of NaN values) + template_metrics = template_metrics.convert_dtypes() + return template_metrics def _run(self, verbose=False): self.data["metrics"] = self._compute_metrics( diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 3d7096651f..b2804c2638 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -108,6 +108,8 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job """ Compute quality metrics. """ + import pandas as pd + metric_names = self.params["metric_names"] qm_params = self.params["qm_params"] # sparsity = self.params["sparsity"] @@ -132,8 +134,6 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job non_empty_unit_ids = unit_ids empty_unit_ids = [] - import pandas as pd - metrics = pd.DataFrame(index=unit_ids) # simple metrics not based on PCs @@ -185,7 +185,10 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job if len(empty_unit_ids) > 0: metrics.loc[empty_unit_ids] = np.nan - return metrics.convert_dtypes() + # we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns + # (in case of NaN values) + metrics = metrics.convert_dtypes() + return metrics def _run(self, verbose=False, **job_kwargs): self.data["metrics"] = self._compute_metrics( From b1677fabd82f36d0ed51af8418a559661cbfa4e3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sun, 15 Sep 2024 18:43:27 +0200 Subject: [PATCH 4/5] Update src/spikeinterface/qualitymetrics/quality_metric_calculator.py --- src/spikeinterface/qualitymetrics/quality_metric_calculator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 41ad40293a..3b6c6d3e50 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -141,7 +141,6 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri """ import pandas as pd - metric_names = self.params["metric_names"] qm_params = self.params["qm_params"] # sparsity = self.params["sparsity"] seed = self.params["seed"] From 9a7295948b83bef46736252f49ac6b164acee53d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 16 Sep 2024 10:04:17 +0200 Subject: [PATCH 5/5] Update src/spikeinterface/core/sortinganalyzer.py --- src/spikeinterface/core/sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index a94b7aa3dc..177188f21d 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -591,7 +591,7 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None): warnings.warn( "The zarr store was not properly consolidated prior to v0.101.1. " "This may lead to unexpected behavior in loading extensions. " - "Please consider re-saving the SortingAnalyzer object." + "Please consider re-generating the SortingAnalyzer object." ) # load internal sorting in memory