From 30d1ecce4249a3e645ca09be39799277186e11c6 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 4 Sep 2023 11:47:37 +0200 Subject: [PATCH 1/7] Allow to postprocess on read-only waveform folders --- src/spikeinterface/core/waveform_extractor.py | 55 ++++++++++--------- .../tests/common_extension_tests.py | 23 +++++++- .../postprocessing/unit_localization.py | 4 +- 3 files changed, 54 insertions(+), 28 deletions(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 877c9fb00c..e404e74be4 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -4,6 +4,7 @@ import shutil from typing import Iterable, Literal, Optional import json +import os import numpy as np from copy import deepcopy @@ -87,6 +88,7 @@ def __init__( self._template_cache = {} self._params = {} self._loaded_extensions = dict() + self._is_read_only = False self.sparsity = sparsity self.folder = folder @@ -103,6 +105,8 @@ def __init__( if (self.folder / "params.json").is_file(): with open(str(self.folder / "params.json"), "r") as f: self._params = json.load(f) + if not os.access(self.folder, os.W_OK): + self._is_read_only = True else: # this is in case of in-memory self.format = "memory" @@ -399,6 +403,9 @@ def return_scaled(self) -> bool: def dtype(self): return self._params["dtype"] + def is_read_only(self) -> bool: + return self._is_read_only + def has_recording(self) -> bool: return self._recording is not None @@ -514,18 +521,8 @@ def is_extension(self, extension_name) -> bool: exists: bool Whether the extension exists or not """ - if self.folder is None: - return extension_name in self._loaded_extensions - else: - if self.format == "binary": - return (self.folder / extension_name).is_dir() and ( - self.folder / extension_name / "params.json" - ).is_file() - elif self.format == "zarr": - return ( - extension_name in self._waveforms_root.keys() - and "params" in self._waveforms_root[extension_name].attrs.keys() - ) + # Extensions are always loaded in memory + return extension_name in self._loaded_extensions def load_extension(self, extension_name): """ @@ -1735,20 +1732,28 @@ def __init__(self, waveform_extractor): self.waveform_extractor = waveform_extractor if self.waveform_extractor.folder is not None: - self.folder = self.waveform_extractor.folder - self.format = self.waveform_extractor.format - if self.format == "binary": - self.extension_folder = self.folder / self.extension_name - if not self.extension_folder.is_dir(): - self.extension_folder.mkdir() - else: - import zarr - - zarr_root = zarr.open(self.folder, mode="r+") - if self.extension_name not in zarr_root.keys(): - self.extension_group = zarr_root.create_group(self.extension_name) + if not self.waveform_extractor.is_read_only(): + self.folder = self.waveform_extractor.folder + self.format = self.waveform_extractor.format + if self.format == "binary": + self.extension_folder = self.folder / self.extension_name + if not self.extension_folder.is_dir(): + self.extension_folder.mkdir() else: - self.extension_group = zarr_root[self.extension_name] + import zarr + + zarr_root = zarr.open(self.folder, mode="r+") + if self.extension_name not in zarr_root.keys(): + self.extension_group = zarr_root.create_group(self.extension_name) + else: + self.extension_group = zarr_root[self.extension_name] + else: + warn( + "WaveformExtractor: cannot save extension in read-only mode. " "Extension will be saved in memory." + ) + self.format = "memory" + self.extension_folder = None + self.folder = None else: self.format = "memory" self.extension_folder = None diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index b9c72f9b99..f44d58470c 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -4,7 +4,7 @@ import shutil from pathlib import Path -from spikeinterface import extract_waveforms, load_extractor, compute_sparsity +from spikeinterface import extract_waveforms, load_extractor, load_waveforms, compute_sparsity from spikeinterface.extractors import toy_example if hasattr(pytest, "global_test_folder"): @@ -76,6 +76,15 @@ def setUp(self): overwrite=True, ) self.we2 = we2 + + # make we read-only + we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" + if not we_ro_folder.is_dir(): + shutil.copytree(we2.folder, we_ro_folder) + # change permissions (R+X) + we_ro_folder.chmod(0o555) + self.we_ro = load_waveforms(we_ro_folder) + self.sparsity2 = compute_sparsity(we2, method="radius", radius_um=30) we_memory = extract_waveforms( recording, @@ -97,6 +106,11 @@ def setUp(self): folder=cache_folder / "toy_sorting_2seg_sparse", format="binary", sparsity=sparsity, overwrite=True ) + def tearDown(self): + # allow pytest to delete RO folder + we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" + we_ro_folder.chmod(0o777) + def _test_extension_folder(self, we, in_memory=False): if self.extension_function_kwargs_list is None: extension_function_kwargs_list = [dict()] @@ -177,3 +191,10 @@ def test_extension(self): assert ext_data_mem.equals(ext_data_zarr) else: print(f"{ext_data_name} of type {type(ext_data_mem)} not tested.") + + # read-only - Extension is memory only + _ = self.extension_class.get_extension_function()(self.we_ro, load_if_exists=False) + assert self.extension_class.extension_name in self.we_ro.get_available_extension_names() + ext_ro = self.we_ro.load_extension(self.extension_class.extension_name) + assert ext_ro.format == "memory" + assert ext_ro.extension_folder is None diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index 740fdd234b..d2739f69dd 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -570,6 +570,8 @@ def enforce_decrease_shells_data(wf_data, maxchan, radial_parents, in_place=Fals def get_grid_convolution_templates_and_weights( contact_locations, radius_um=50, upsampling_um=5, sigma_um=np.linspace(10, 50.0, 5), margin_um=50 ): + import sklearn.metrics + x_min, x_max = contact_locations[:, 0].min(), contact_locations[:, 0].max() y_min, y_max = contact_locations[:, 1].min(), contact_locations[:, 1].max() @@ -593,8 +595,6 @@ def get_grid_convolution_templates_and_weights( template_positions[:, 0] = all_x.flatten() template_positions[:, 1] = all_y.flatten() - import sklearn - # mask to get nearest template given a channel dist = sklearn.metrics.pairwise_distances(contact_locations, template_positions) nearest_template_mask = dist < radius_um From b8ee13c208cf928573595d941803b11e38278eb0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 4 Sep 2023 15:02:13 +0200 Subject: [PATCH 2/7] Restore extension loading --- src/spikeinterface/core/waveform_extractor.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index e404e74be4..6083732c11 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -521,8 +521,22 @@ def is_extension(self, extension_name) -> bool: exists: bool Whether the extension exists or not """ - # Extensions are always loaded in memory - return extension_name in self._loaded_extensions + if self.folder is None: + return extension_name in self._loaded_extensions + else: + # Extensions already loaded in memory + if extension_name in self._loaded_extensions: + return True + else: + if self.format == "binary": + return (self.folder / extension_name).is_dir() and ( + self.folder / extension_name / "params.json" + ).is_file() + elif self.format == "zarr": + return ( + extension_name in self._waveforms_root.keys() + and "params" in self._waveforms_root[extension_name].attrs.keys() + ) def load_extension(self, extension_name): """ From def525c20a463b625c2f014fd5a84be4f79a00ef Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 5 Sep 2023 15:38:06 +0200 Subject: [PATCH 3/7] handle re-loading correctly --- src/spikeinterface/core/waveform_extractor.py | 140 ++++++++++-------- 1 file changed, 77 insertions(+), 63 deletions(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 6083732c11..39d115e22c 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -1746,28 +1746,39 @@ def __init__(self, waveform_extractor): self.waveform_extractor = waveform_extractor if self.waveform_extractor.folder is not None: - if not self.waveform_extractor.is_read_only(): - self.folder = self.waveform_extractor.folder - self.format = self.waveform_extractor.format - if self.format == "binary": - self.extension_folder = self.folder / self.extension_name - if not self.extension_folder.is_dir(): + self.folder = self.waveform_extractor.folder + self.format = self.waveform_extractor.format + if self.format == "binary": + self.extension_folder = self.folder / self.extension_name + if not self.extension_folder.is_dir(): + if not self.waveform_extractor.is_read_only(): self.extension_folder.mkdir() - else: - import zarr + else: + raise Exception( + "WaveformExtractor: cannot save extension in read-only mode. " + "Extension will be saved in memory." + ) + self.format = "memory" + self.extension_folder = None + self.folder = None + else: + import zarr - zarr_root = zarr.open(self.folder, mode="r+") - if self.extension_name not in zarr_root.keys(): + mode = "r+" if not self.waveform_extractor.is_read_only() else "r" + zarr_root = zarr.open(self.folder, mode=mode) + if self.extension_name not in zarr_root.keys(): + if not self.waveform_extractor.is_read_only(): self.extension_group = zarr_root.create_group(self.extension_name) else: - self.extension_group = zarr_root[self.extension_name] - else: - warn( - "WaveformExtractor: cannot save extension in read-only mode. " "Extension will be saved in memory." - ) - self.format = "memory" - self.extension_folder = None - self.folder = None + raise Exception( + "WaveformExtractor: cannot save extension in read-only mode. " + "Extension will be saved in memory." + ) + self.format = "memory" + self.extension_folder = None + self.folder = None + else: + self.extension_group = zarr_root[self.extension_name] else: self.format = "memory" self.extension_folder = None @@ -1882,53 +1893,56 @@ def save(self, **kwargs): self._save(**kwargs) def _save(self, **kwargs): - if self.format == "binary": - import pandas as pd - - for ext_data_name, ext_data in self._extension_data.items(): - if isinstance(ext_data, dict): - with (self.extension_folder / f"{ext_data_name}.json").open("w") as f: - json.dump(ext_data, f) - elif isinstance(ext_data, np.ndarray): - np.save(self.extension_folder / f"{ext_data_name}.npy", ext_data) - elif isinstance(ext_data, pd.DataFrame): - ext_data.to_csv(self.extension_folder / f"{ext_data_name}.csv", index=True) - else: - try: - with (self.extension_folder / f"{ext_data_name}.pkl").open("wb") as f: - pickle.dump(ext_data, f) - except: - raise Exception(f"Could not save {ext_data_name} as extension data") - elif self.format == "zarr": - from .zarrrecordingextractor import get_default_zarr_compressor - import pandas as pd - import numcodecs - - compressor = kwargs.get("compressor", None) - if compressor is None: - compressor = get_default_zarr_compressor() - for ext_data_name, ext_data in self._extension_data.items(): - if ext_data_name in self.extension_group: - del self.extension_group[ext_data_name] - if isinstance(ext_data, dict): - self.extension_group.create_dataset( - name=ext_data_name, data=[ext_data], object_codec=numcodecs.JSON() - ) - self.extension_group[ext_data_name].attrs["dict"] = True - elif isinstance(ext_data, np.ndarray): - self.extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) - elif isinstance(ext_data, pd.DataFrame): - ext_data.to_xarray().to_zarr( - store=self.extension_group.store, group=f"{self.extension_group.name}/{ext_data_name}", mode="a" - ) - self.extension_group[ext_data_name].attrs["dataframe"] = True - else: - try: + if not self.waveform_extractor.is_read_only(): + if self.format == "binary": + import pandas as pd + + for ext_data_name, ext_data in self._extension_data.items(): + if isinstance(ext_data, dict): + with (self.extension_folder / f"{ext_data_name}.json").open("w") as f: + json.dump(ext_data, f) + elif isinstance(ext_data, np.ndarray): + np.save(self.extension_folder / f"{ext_data_name}.npy", ext_data) + elif isinstance(ext_data, pd.DataFrame): + ext_data.to_csv(self.extension_folder / f"{ext_data_name}.csv", index=True) + else: + try: + with (self.extension_folder / f"{ext_data_name}.pkl").open("wb") as f: + pickle.dump(ext_data, f) + except: + raise Exception(f"Could not save {ext_data_name} as extension data") + elif self.format == "zarr": + from .zarrrecordingextractor import get_default_zarr_compressor + import pandas as pd + import numcodecs + + compressor = kwargs.get("compressor", None) + if compressor is None: + compressor = get_default_zarr_compressor() + for ext_data_name, ext_data in self._extension_data.items(): + if ext_data_name in self.extension_group: + del self.extension_group[ext_data_name] + if isinstance(ext_data, dict): self.extension_group.create_dataset( - name=ext_data_name, data=ext_data, object_codec=numcodecs.Pickle() + name=ext_data_name, data=[ext_data], object_codec=numcodecs.JSON() + ) + self.extension_group[ext_data_name].attrs["dict"] = True + elif isinstance(ext_data, np.ndarray): + self.extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) + elif isinstance(ext_data, pd.DataFrame): + ext_data.to_xarray().to_zarr( + store=self.extension_group.store, + group=f"{self.extension_group.name}/{ext_data_name}", + mode="a", ) - except: - raise Exception(f"Could not save {ext_data_name} as extension data") + self.extension_group[ext_data_name].attrs["dataframe"] = True + else: + try: + self.extension_group.create_dataset( + name=ext_data_name, data=ext_data, object_codec=numcodecs.Pickle() + ) + except: + raise Exception(f"Could not save {ext_data_name} as extension data") def reset(self): """ From dfa67e681afec0ef741b16e61417c70123c97ef5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 6 Sep 2023 12:08:01 +0200 Subject: [PATCH 4/7] warn instead of raise --- src/spikeinterface/core/waveform_extractor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 39d115e22c..431440c846 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -1754,7 +1754,7 @@ def __init__(self, waveform_extractor): if not self.waveform_extractor.is_read_only(): self.extension_folder.mkdir() else: - raise Exception( + warn( "WaveformExtractor: cannot save extension in read-only mode. " "Extension will be saved in memory." ) @@ -1770,7 +1770,7 @@ def __init__(self, waveform_extractor): if not self.waveform_extractor.is_read_only(): self.extension_group = zarr_root.create_group(self.extension_name) else: - raise Exception( + warn( "WaveformExtractor: cannot save extension in read-only mode. " "Extension will be saved in memory." ) From f60024b0c52e17edfebe02b8170f9ac3d78b053f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 6 Sep 2023 12:24:41 +0200 Subject: [PATCH 5/7] Do not overwrite similarity in Phy if available --- src/spikeinterface/exporters/to_phy.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 5615402fdb..c92861a8bf 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -178,7 +178,11 @@ def export_to_phy( templates[unit_ind, :, :][:, : len(chan_inds)] = template templates_ind[unit_ind, : len(chan_inds)] = chan_inds - template_similarity = compute_template_similarity(waveform_extractor, method="cosine_similarity") + if waveform_extractor.is_extension("similarity"): + tmc = waveform_extractor.load_extension("similarity") + template_similarity = tmc.get_data() + else: + template_similarity = compute_template_similarity(waveform_extractor, method="cosine_similarity") np.save(str(output_folder / "templates.npy"), templates) np.save(str(output_folder / "template_ind.npy"), templates_ind) From fe178c67ac9428477ca146dd6ac453bf1cccfc78 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 11 Sep 2023 10:37:00 +0200 Subject: [PATCH 6/7] Apply suggestions and avoid using chmod on windows --- src/spikeinterface/core/waveform_extractor.py | 111 +++++++++--------- .../tests/common_extension_tests.py | 28 +++-- 2 files changed, 73 insertions(+), 66 deletions(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 431440c846..3647e915bf 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -1751,9 +1751,7 @@ def __init__(self, waveform_extractor): if self.format == "binary": self.extension_folder = self.folder / self.extension_name if not self.extension_folder.is_dir(): - if not self.waveform_extractor.is_read_only(): - self.extension_folder.mkdir() - else: + if self.waveform_extractor.is_read_only(): warn( "WaveformExtractor: cannot save extension in read-only mode. " "Extension will be saved in memory." @@ -1761,15 +1759,16 @@ def __init__(self, waveform_extractor): self.format = "memory" self.extension_folder = None self.folder = None + else: + self.extension_folder.mkdir() + else: import zarr mode = "r+" if not self.waveform_extractor.is_read_only() else "r" zarr_root = zarr.open(self.folder, mode=mode) if self.extension_name not in zarr_root.keys(): - if not self.waveform_extractor.is_read_only(): - self.extension_group = zarr_root.create_group(self.extension_name) - else: + if self.waveform_extractor.is_read_only(): warn( "WaveformExtractor: cannot save extension in read-only mode. " "Extension will be saved in memory." @@ -1777,6 +1776,8 @@ def __init__(self, waveform_extractor): self.format = "memory" self.extension_folder = None self.folder = None + else: + self.extension_group = zarr_root.create_group(self.extension_name) else: self.extension_group = zarr_root[self.extension_name] else: @@ -1893,56 +1894,58 @@ def save(self, **kwargs): self._save(**kwargs) def _save(self, **kwargs): - if not self.waveform_extractor.is_read_only(): - if self.format == "binary": - import pandas as pd - - for ext_data_name, ext_data in self._extension_data.items(): - if isinstance(ext_data, dict): - with (self.extension_folder / f"{ext_data_name}.json").open("w") as f: - json.dump(ext_data, f) - elif isinstance(ext_data, np.ndarray): - np.save(self.extension_folder / f"{ext_data_name}.npy", ext_data) - elif isinstance(ext_data, pd.DataFrame): - ext_data.to_csv(self.extension_folder / f"{ext_data_name}.csv", index=True) - else: - try: - with (self.extension_folder / f"{ext_data_name}.pkl").open("wb") as f: - pickle.dump(ext_data, f) - except: - raise Exception(f"Could not save {ext_data_name} as extension data") - elif self.format == "zarr": - from .zarrrecordingextractor import get_default_zarr_compressor - import pandas as pd - import numcodecs - - compressor = kwargs.get("compressor", None) - if compressor is None: - compressor = get_default_zarr_compressor() - for ext_data_name, ext_data in self._extension_data.items(): - if ext_data_name in self.extension_group: - del self.extension_group[ext_data_name] - if isinstance(ext_data, dict): + # Only save if not read only + if self.waveform_extractor.is_read_only(): + return + if self.format == "binary": + import pandas as pd + + for ext_data_name, ext_data in self._extension_data.items(): + if isinstance(ext_data, dict): + with (self.extension_folder / f"{ext_data_name}.json").open("w") as f: + json.dump(ext_data, f) + elif isinstance(ext_data, np.ndarray): + np.save(self.extension_folder / f"{ext_data_name}.npy", ext_data) + elif isinstance(ext_data, pd.DataFrame): + ext_data.to_csv(self.extension_folder / f"{ext_data_name}.csv", index=True) + else: + try: + with (self.extension_folder / f"{ext_data_name}.pkl").open("wb") as f: + pickle.dump(ext_data, f) + except: + raise Exception(f"Could not save {ext_data_name} as extension data") + elif self.format == "zarr": + from .zarrrecordingextractor import get_default_zarr_compressor + import pandas as pd + import numcodecs + + compressor = kwargs.get("compressor", None) + if compressor is None: + compressor = get_default_zarr_compressor() + for ext_data_name, ext_data in self._extension_data.items(): + if ext_data_name in self.extension_group: + del self.extension_group[ext_data_name] + if isinstance(ext_data, dict): + self.extension_group.create_dataset( + name=ext_data_name, data=[ext_data], object_codec=numcodecs.JSON() + ) + self.extension_group[ext_data_name].attrs["dict"] = True + elif isinstance(ext_data, np.ndarray): + self.extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) + elif isinstance(ext_data, pd.DataFrame): + ext_data.to_xarray().to_zarr( + store=self.extension_group.store, + group=f"{self.extension_group.name}/{ext_data_name}", + mode="a", + ) + self.extension_group[ext_data_name].attrs["dataframe"] = True + else: + try: self.extension_group.create_dataset( - name=ext_data_name, data=[ext_data], object_codec=numcodecs.JSON() - ) - self.extension_group[ext_data_name].attrs["dict"] = True - elif isinstance(ext_data, np.ndarray): - self.extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) - elif isinstance(ext_data, pd.DataFrame): - ext_data.to_xarray().to_zarr( - store=self.extension_group.store, - group=f"{self.extension_group.name}/{ext_data_name}", - mode="a", + name=ext_data_name, data=ext_data, object_codec=numcodecs.Pickle() ) - self.extension_group[ext_data_name].attrs["dataframe"] = True - else: - try: - self.extension_group.create_dataset( - name=ext_data_name, data=ext_data, object_codec=numcodecs.Pickle() - ) - except: - raise Exception(f"Could not save {ext_data_name} as extension data") + except: + raise Exception(f"Could not save {ext_data_name} as extension data") def reset(self): """ diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index f44d58470c..f7272ddefe 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -2,6 +2,7 @@ import numpy as np import pandas as pd import shutil +import platform from pathlib import Path from spikeinterface import extract_waveforms, load_extractor, load_waveforms, compute_sparsity @@ -78,12 +79,13 @@ def setUp(self): self.we2 = we2 # make we read-only - we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" - if not we_ro_folder.is_dir(): - shutil.copytree(we2.folder, we_ro_folder) + if platform.system() != "Windows": + we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" + if not we_ro_folder.is_dir(): + shutil.copytree(we2.folder, we_ro_folder) # change permissions (R+X) - we_ro_folder.chmod(0o555) - self.we_ro = load_waveforms(we_ro_folder) + we_ro_folder.chmod(0o555) + self.we_ro = load_waveforms(we_ro_folder) self.sparsity2 = compute_sparsity(we2, method="radius", radius_um=30) we_memory = extract_waveforms( @@ -108,8 +110,9 @@ def setUp(self): def tearDown(self): # allow pytest to delete RO folder - we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" - we_ro_folder.chmod(0o777) + if platform.system() != "Windows": + we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly" + we_ro_folder.chmod(0o777) def _test_extension_folder(self, we, in_memory=False): if self.extension_function_kwargs_list is None: @@ -193,8 +196,9 @@ def test_extension(self): print(f"{ext_data_name} of type {type(ext_data_mem)} not tested.") # read-only - Extension is memory only - _ = self.extension_class.get_extension_function()(self.we_ro, load_if_exists=False) - assert self.extension_class.extension_name in self.we_ro.get_available_extension_names() - ext_ro = self.we_ro.load_extension(self.extension_class.extension_name) - assert ext_ro.format == "memory" - assert ext_ro.extension_folder is None + if platform.system() != "Windows": + _ = self.extension_class.get_extension_function()(self.we_ro, load_if_exists=False) + assert self.extension_class.extension_name in self.we_ro.get_available_extension_names() + ext_ro = self.we_ro.load_extension(self.extension_class.extension_name) + assert ext_ro.format == "memory" + assert ext_ro.extension_folder is None From 426f395c6cb210b016b119225af540fd968fb30f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 15 Sep 2023 12:38:50 +0200 Subject: [PATCH 7/7] Removed unnecessary else --- src/spikeinterface/core/waveform_extractor.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 3647e915bf..6881ab3ec5 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -523,20 +523,20 @@ def is_extension(self, extension_name) -> bool: """ if self.folder is None: return extension_name in self._loaded_extensions + + if extension_name in self._loaded_extensions: + # extension already loaded in memory + return True else: - # Extensions already loaded in memory - if extension_name in self._loaded_extensions: - return True - else: - if self.format == "binary": - return (self.folder / extension_name).is_dir() and ( - self.folder / extension_name / "params.json" - ).is_file() - elif self.format == "zarr": - return ( - extension_name in self._waveforms_root.keys() - and "params" in self._waveforms_root[extension_name].attrs.keys() - ) + if self.format == "binary": + return (self.folder / extension_name).is_dir() and ( + self.folder / extension_name / "params.json" + ).is_file() + elif self.format == "zarr": + return ( + extension_name in self._waveforms_root.keys() + and "params" in self._waveforms_root[extension_name].attrs.keys() + ) def load_extension(self, extension_name): """