Merge pull request #3412 from alejoe91/refactor-pandas-save-load
Refactor pandas save load and convert dtypes
samuelgarcia authored Sep 16, 2024
2 parents 667e4bd + 9a72959 commit b4dceac
Showing 5 changed files with 69 additions and 20 deletions.
3 changes: 0 additions & 3 deletions pyproject.toml
@@ -94,7 +94,6 @@ preprocessing = [
full = [
"h5py",
"pandas",
"xarray",
"scipy",
"scikit-learn",
"networkx",
@@ -148,7 +147,6 @@ test = [
"pytest-dependency",
"pytest-cov",

"xarray",
"huggingface_hub",

# preprocessing
@@ -193,7 +191,6 @@ docs = [
"pandas", # in the modules gallery comparison tutorial
"hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous
"numba", # For many postprocessing functions
"xarray", # For use of SortingAnalyzer zarr format
"networkx",
# Download data
"pooch>=1.8.2",
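With xarray dropped from the optional dependency sets, DataFrame-backed extension data is read with pandas alone. A quick sanity check of that, assuming an already-saved analyzer with a DataFrame-based extension such as quality_metrics (the folder name is hypothetical):

import sys
from spikeinterface.core import load_sorting_analyzer

analyzer = load_sorting_analyzer("analyzer.zarr")  # hypothetical saved analyzer
metrics = analyzer.get_extension("quality_metrics").get_data()  # pandas DataFrame
assert "xarray" not in sys.modules  # should hold now that loading is pandas-only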
52 changes: 38 additions & 14 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -11,6 +11,7 @@
import shutil
import warnings
import importlib
from packaging.version import parse
from time import perf_counter

import numpy as np
@@ -579,6 +580,20 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None):

zarr_root = zarr.open_consolidated(str(folder), mode="r", storage_options=storage_options)

si_info = zarr_root.attrs["spikeinterface_info"]
if parse(si_info["version"]) < parse("0.101.1"):
# v0.101.0 did not have a consolidate metadata step after computing extensions.
# Here we try to consolidate the metadata and throw a warning if it fails.
try:
zarr_root_a = zarr.open(str(folder), mode="a", storage_options=storage_options)
zarr.consolidate_metadata(zarr_root_a.store)
except Exception as e:
warnings.warn(
"The zarr store was not properly consolidated prior to v0.101.1. "
"This may lead to unexpected behavior in loading extensions. "
"Please consider re-generating the SortingAnalyzer object."
)

# load internal sorting in memory
sorting = NumpySorting.from_sorting(
ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options),
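For context, zarr's consolidated metadata is a single ".zmetadata" key that lets open_consolidated read the whole hierarchy without listing the store; analyzers written before v0.101.1 lack it for extensions, which is why load_from_zarr re-consolidates older stores. A minimal sketch of the mechanism, assuming the zarr v2 API in use here (demo.zarr is a throwaway path):

import zarr

store = zarr.DirectoryStore("demo.zarr")
root = zarr.group(store=store, overwrite=True)
ext = root.create_group("extensions")
ext.create_dataset("values", data=[1, 2, 3])

zarr.consolidate_metadata(store)  # writes the single ".zmetadata" key
consolidated = zarr.open_consolidated(store, mode="r")
print(consolidated["extensions/values"][:])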
@@ -1970,12 +1985,14 @@ def load_data(self):
if "dict" in ext_data_.attrs:
ext_data = ext_data_[0]
elif "dataframe" in ext_data_.attrs:
import xarray
import pandas as pd

ext_data = xarray.open_zarr(
ext_data_.store, group=f"{extension_group.name}/{ext_data_name}"
).to_pandas()
ext_data.index.rename("", inplace=True)
index = ext_data_["index"]
ext_data = pd.DataFrame(index=index)
for col in ext_data_.keys():
if col != "index":
ext_data.loc[:, col] = ext_data_[col][:]
ext_data = ext_data.convert_dtypes()
elif "object" in ext_data_.attrs:
ext_data = ext_data_[0]
else:
@@ -2031,12 +2048,21 @@ def run(self, save=True, **kwargs):
if save and not self.sorting_analyzer.is_read_only():
self._save_run_info()
self._save_data(**kwargs)
if self.format == "zarr":
import zarr

zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)

def save(self, **kwargs):
self._save_params()
self._save_importing_provenance()
self._save_data(**kwargs)
self._save_run_info()
self._save_data(**kwargs)

if self.format == "zarr":
import zarr

zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)

def _save_data(self, **kwargs):
if self.format == "memory":
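run() and save() now re-consolidate the store once at the end, rather than inside _save_data, so a fresh load sees newly computed extensions. A rough end-to-end sketch of that flow (folder name hypothetical, using the same generators as the tests below):

from spikeinterface.core import (
    create_sorting_analyzer,
    generate_ground_truth_recording,
    load_sorting_analyzer,
)

recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5)
analyzer = create_sorting_analyzer(sorting, recording, format="zarr", folder="analyzer.zarr")
analyzer.compute("random_spikes")  # run() consolidates the zarr store afterwards

reloaded = load_sorting_analyzer("analyzer.zarr")
assert reloaded.has_extension("random_spikes")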
@@ -2096,12 +2122,12 @@ def _save_data(self, **kwargs):
elif isinstance(ext_data, np.ndarray):
extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor)
elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame):
ext_data.to_xarray().to_zarr(
store=extension_group.store,
group=f"{extension_group.name}/{ext_data_name}",
mode="a",
)
extension_group[ext_data_name].attrs["dataframe"] = True
df_group = extension_group.create_group(ext_data_name)
# first we save the index
df_group.create_dataset(name="index", data=ext_data.index.to_numpy())
for col in ext_data.columns:
df_group.create_dataset(name=col, data=ext_data[col].to_numpy())
df_group.attrs["dataframe"] = True
else:
# any object
try:
@@ -2111,8 +2137,6 @@
except:
raise Exception(f"Could not save {ext_data_name} as extension data")
extension_group[ext_data_name].attrs["object"] = True
# we need to re-consolidate
zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)

def _reset_extension_folder(self):
"""
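Taken together with the load_data change above, DataFrames are now stored as one zarr dataset per column plus an "index" dataset, with no xarray involved. A toy roundtrip of that layout (path, group name, and values hypothetical):

import pandas as pd
import zarr

df = pd.DataFrame({"amplitude": [1.2, 3.4], "n_spikes": [10, 20]}, index=[0, 1])

# save side: mirrors _save_data above
root = zarr.open("df_demo.zarr", mode="w")
df_group = root.create_group("metrics")
df_group.create_dataset(name="index", data=df.index.to_numpy())
for col in df.columns:
    df_group.create_dataset(name=col, data=df[col].to_numpy())
df_group.attrs["dataframe"] = True

# load side: mirrors load_data above
group = zarr.open("df_demo.zarr", mode="r")["metrics"]
loaded = pd.DataFrame(index=group["index"][:])
for col in group.keys():
    if col != "index":
        loaded.loc[:, col] = group[col][:]
loaded = loaded.convert_dtypes()
print(loaded.dtypes)  # nullable Float64 / Int64 columns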
4 changes: 4 additions & 0 deletions src/spikeinterface/postprocessing/template_metrics.py
@@ -318,6 +318,10 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}")
value = np.nan
template_metrics.at[index, metric_name] = value

# we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns
# (in case of NaN values)
template_metrics = template_metrics.convert_dtypes()
return template_metrics

def _run(self, verbose=False):
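The convert_dtypes() call matters because filling a freshly created DataFrame cell by cell leaves object columns, and NaNs from failed metrics would otherwise keep them that way. A small illustration with made-up values:

import numpy as np
import pandas as pd

metrics = pd.DataFrame(index=["u0", "u1"], columns=["peak_to_valley"])
metrics.at["u0", "peak_to_valley"] = 0.5e-3
metrics.at["u1", "peak_to_valley"] = np.nan  # metric failed for this unit

print(metrics.dtypes)                   # object
print(metrics.convert_dtypes().dtypes)  # Float64 (nullable)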
@@ -5,7 +5,7 @@
import numpy as np

from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.core import create_sorting_analyzer
from spikeinterface.core import create_sorting_analyzer, load_sorting_analyzer
from spikeinterface.core import estimate_sparsity


@@ -116,6 +116,8 @@ def _check_one(self, sorting_analyzer, extension_class, params):
with the passed parameters, and check the output is not empty, the extension
exists and `select_units()` method works.
"""
import pandas as pd

if extension_class.need_job_kwargs:
job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True)
else:
@@ -138,6 +140,26 @@ def _check_one(self, sorting_analyzer, extension_class, params):
merged = sorting_analyzer.merge_units(some_merges, format="memory", merging_mode="soft", sparsity_overlap=0.0)
assert len(merged.unit_ids) == num_units_after_merge

# test roundtrip
if sorting_analyzer.format in ("binary_folder", "zarr"):
sorting_analyzer_loaded = load_sorting_analyzer(sorting_analyzer.folder)
ext_loaded = sorting_analyzer_loaded.get_extension(extension_class.extension_name)
for ext_data_name, ext_data_loaded in ext_loaded.data.items():
if isinstance(ext_data_loaded, np.ndarray):
assert np.array_equal(ext.data[ext_data_name], ext_data_loaded)
elif isinstance(ext_data_loaded, pd.DataFrame):
# skip nan values
for col in ext_data_loaded.columns:
np.testing.assert_array_almost_equal(
ext.data[ext_data_name][col].dropna().to_numpy(),
ext_data_loaded[col].dropna().to_numpy(),
decimal=5,
)
elif isinstance(ext_data_loaded, dict):
assert ext.data[ext_data_name] == ext_data_loaded
else:
continue

def run_extension_tests(self, extension_class, params):
"""
Convenience function to perform all checks on the extension
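The dropna() before comparing in the roundtrip test is needed because after convert_dtypes() the loaded columns hold pandas' nullable NA, which does not compare equal to np.nan under array comparison. A minimal sketch with made-up values:

import numpy as np
import pandas as pd

saved = pd.Series([1.0, np.nan, 3.0])
loaded = saved.convert_dtypes()  # Float64, with pd.NA in place of np.nan

# drop the missing values on both sides, then compare numerically
np.testing.assert_array_almost_equal(
    saved.dropna().to_numpy(),
    loaded.dropna().to_numpy(),
    decimal=5,
)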
@@ -139,6 +139,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
"""
Compute quality metrics.
"""
import pandas as pd

qm_params = self.params["qm_params"]
# sparsity = self.params["sparsity"]
@@ -163,8 +164,6 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
non_empty_unit_ids = unit_ids
empty_unit_ids = []

import pandas as pd

metrics = pd.DataFrame(index=unit_ids)

# simple metrics not based on PCs
@@ -216,6 +215,9 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
if len(empty_unit_ids) > 0:
metrics.loc[empty_unit_ids] = np.nan

# we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns
# (in case of NaN values)
metrics = metrics.convert_dtypes()
return metrics

def _run(self, verbose=False, **job_kwargs):
