Skip to content

Commit

Permalink
Apply suggestions and avoid using chmod on windows
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Sep 11, 2023
1 parent ede88da commit fe178c6
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 66 deletions.
111 changes: 57 additions & 54 deletions src/spikeinterface/core/waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1751,32 +1751,33 @@ def __init__(self, waveform_extractor):
if self.format == "binary":
self.extension_folder = self.folder / self.extension_name
if not self.extension_folder.is_dir():
if not self.waveform_extractor.is_read_only():
self.extension_folder.mkdir()
else:
if self.waveform_extractor.is_read_only():
warn(
"WaveformExtractor: cannot save extension in read-only mode. "
"Extension will be saved in memory."
)
self.format = "memory"
self.extension_folder = None
self.folder = None
else:
self.extension_folder.mkdir()

else:
import zarr

mode = "r+" if not self.waveform_extractor.is_read_only() else "r"
zarr_root = zarr.open(self.folder, mode=mode)
if self.extension_name not in zarr_root.keys():
if not self.waveform_extractor.is_read_only():
self.extension_group = zarr_root.create_group(self.extension_name)
else:
if self.waveform_extractor.is_read_only():
warn(
"WaveformExtractor: cannot save extension in read-only mode. "
"Extension will be saved in memory."
)
self.format = "memory"
self.extension_folder = None
self.folder = None
else:
self.extension_group = zarr_root.create_group(self.extension_name)
else:
self.extension_group = zarr_root[self.extension_name]
else:
Expand Down Expand Up @@ -1893,56 +1894,58 @@ def save(self, **kwargs):
self._save(**kwargs)

def _save(self, **kwargs):
if not self.waveform_extractor.is_read_only():
if self.format == "binary":
import pandas as pd

for ext_data_name, ext_data in self._extension_data.items():
if isinstance(ext_data, dict):
with (self.extension_folder / f"{ext_data_name}.json").open("w") as f:
json.dump(ext_data, f)
elif isinstance(ext_data, np.ndarray):
np.save(self.extension_folder / f"{ext_data_name}.npy", ext_data)
elif isinstance(ext_data, pd.DataFrame):
ext_data.to_csv(self.extension_folder / f"{ext_data_name}.csv", index=True)
else:
try:
with (self.extension_folder / f"{ext_data_name}.pkl").open("wb") as f:
pickle.dump(ext_data, f)
except:
raise Exception(f"Could not save {ext_data_name} as extension data")
elif self.format == "zarr":
from .zarrrecordingextractor import get_default_zarr_compressor
import pandas as pd
import numcodecs

compressor = kwargs.get("compressor", None)
if compressor is None:
compressor = get_default_zarr_compressor()
for ext_data_name, ext_data in self._extension_data.items():
if ext_data_name in self.extension_group:
del self.extension_group[ext_data_name]
if isinstance(ext_data, dict):
# Only save if not read only
if self.waveform_extractor.is_read_only():
return
if self.format == "binary":
import pandas as pd

for ext_data_name, ext_data in self._extension_data.items():
if isinstance(ext_data, dict):
with (self.extension_folder / f"{ext_data_name}.json").open("w") as f:
json.dump(ext_data, f)
elif isinstance(ext_data, np.ndarray):
np.save(self.extension_folder / f"{ext_data_name}.npy", ext_data)
elif isinstance(ext_data, pd.DataFrame):
ext_data.to_csv(self.extension_folder / f"{ext_data_name}.csv", index=True)
else:
try:
with (self.extension_folder / f"{ext_data_name}.pkl").open("wb") as f:
pickle.dump(ext_data, f)
except:
raise Exception(f"Could not save {ext_data_name} as extension data")
elif self.format == "zarr":
from .zarrrecordingextractor import get_default_zarr_compressor
import pandas as pd
import numcodecs

compressor = kwargs.get("compressor", None)
if compressor is None:
compressor = get_default_zarr_compressor()
for ext_data_name, ext_data in self._extension_data.items():
if ext_data_name in self.extension_group:
del self.extension_group[ext_data_name]
if isinstance(ext_data, dict):
self.extension_group.create_dataset(
name=ext_data_name, data=[ext_data], object_codec=numcodecs.JSON()
)
self.extension_group[ext_data_name].attrs["dict"] = True
elif isinstance(ext_data, np.ndarray):
self.extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor)
elif isinstance(ext_data, pd.DataFrame):
ext_data.to_xarray().to_zarr(
store=self.extension_group.store,
group=f"{self.extension_group.name}/{ext_data_name}",
mode="a",
)
self.extension_group[ext_data_name].attrs["dataframe"] = True
else:
try:
self.extension_group.create_dataset(
name=ext_data_name, data=[ext_data], object_codec=numcodecs.JSON()
)
self.extension_group[ext_data_name].attrs["dict"] = True
elif isinstance(ext_data, np.ndarray):
self.extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor)
elif isinstance(ext_data, pd.DataFrame):
ext_data.to_xarray().to_zarr(
store=self.extension_group.store,
group=f"{self.extension_group.name}/{ext_data_name}",
mode="a",
name=ext_data_name, data=ext_data, object_codec=numcodecs.Pickle()
)
self.extension_group[ext_data_name].attrs["dataframe"] = True
else:
try:
self.extension_group.create_dataset(
name=ext_data_name, data=ext_data, object_codec=numcodecs.Pickle()
)
except:
raise Exception(f"Could not save {ext_data_name} as extension data")
except:
raise Exception(f"Could not save {ext_data_name} as extension data")

def reset(self):
"""
Expand Down
28 changes: 16 additions & 12 deletions src/spikeinterface/postprocessing/tests/common_extension_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pandas as pd
import shutil
import platform
from pathlib import Path

from spikeinterface import extract_waveforms, load_extractor, load_waveforms, compute_sparsity
Expand Down Expand Up @@ -78,12 +79,13 @@ def setUp(self):
self.we2 = we2

# make we read-only
we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly"
if not we_ro_folder.is_dir():
shutil.copytree(we2.folder, we_ro_folder)
if platform.system() != "Windows":
we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly"
if not we_ro_folder.is_dir():
shutil.copytree(we2.folder, we_ro_folder)
# change permissions (R+X)
we_ro_folder.chmod(0o555)
self.we_ro = load_waveforms(we_ro_folder)
we_ro_folder.chmod(0o555)
self.we_ro = load_waveforms(we_ro_folder)

self.sparsity2 = compute_sparsity(we2, method="radius", radius_um=30)
we_memory = extract_waveforms(
Expand All @@ -108,8 +110,9 @@ def setUp(self):

def tearDown(self):
# allow pytest to delete RO folder
we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly"
we_ro_folder.chmod(0o777)
if platform.system() != "Windows":
we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly"
we_ro_folder.chmod(0o777)

def _test_extension_folder(self, we, in_memory=False):
if self.extension_function_kwargs_list is None:
Expand Down Expand Up @@ -193,8 +196,9 @@ def test_extension(self):
print(f"{ext_data_name} of type {type(ext_data_mem)} not tested.")

# read-only - Extension is memory only
_ = self.extension_class.get_extension_function()(self.we_ro, load_if_exists=False)
assert self.extension_class.extension_name in self.we_ro.get_available_extension_names()
ext_ro = self.we_ro.load_extension(self.extension_class.extension_name)
assert ext_ro.format == "memory"
assert ext_ro.extension_folder is None
if platform.system() != "Windows":
_ = self.extension_class.get_extension_function()(self.we_ro, load_if_exists=False)
assert self.extension_class.extension_name in self.we_ro.get_available_extension_names()
ext_ro = self.we_ro.load_extension(self.extension_class.extension_name)
assert ext_ro.format == "memory"
assert ext_ro.extension_folder is None

0 comments on commit fe178c6

Please sign in to comment.