Skip to content

Commit

Permalink
Merge pull request #1957 from alejoe91/postprocessing-read-only
Browse files Browse the repository at this point in the history
Allow to postprocess on read-only waveform folders
  • Loading branch information
samuelgarcia authored Sep 19, 2023
2 parents 3210f8e + 492f2a2 commit 17b0c03
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 8 deletions.
44 changes: 40 additions & 4 deletions src/spikeinterface/core/waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import shutil
from typing import Iterable, Literal, Optional
import json
import os

import numpy as np
from copy import deepcopy
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(
self._template_cache = {}
self._params = {}
self._loaded_extensions = dict()
self._is_read_only = False
self.sparsity = sparsity

self.folder = folder
Expand All @@ -103,6 +105,8 @@ def __init__(
if (self.folder / "params.json").is_file():
with open(str(self.folder / "params.json"), "r") as f:
self._params = json.load(f)
if not os.access(self.folder, os.W_OK):
self._is_read_only = True
else:
# this is in case of in-memory
self.format = "memory"
Expand Down Expand Up @@ -399,6 +403,9 @@ def return_scaled(self) -> bool:
def dtype(self):
    """Return the ``"dtype"`` entry of the stored parameter dictionary.

    Raises
    ------
    KeyError
        If ``self._params`` has no ``"dtype"`` key.
    """
    stored_params = self._params
    return stored_params["dtype"]

def is_read_only(self) -> bool:
    """Tell whether this extractor is flagged as read-only.

    The flag defaults to ``False`` and is switched on at construction time
    when the waveform folder fails the ``os.access(..., os.W_OK)`` check.

    Returns
    -------
    bool
        The current value of the internal ``_is_read_only`` flag.
    """
    read_only_flag = self._is_read_only
    return read_only_flag

def has_recording(self) -> bool:
    """Tell whether a recording object is attached to this extractor.

    Returns
    -------
    bool
        ``False`` when the internal ``_recording`` attribute is ``None``,
        ``True`` otherwise.
    """
    if self._recording is None:
        return False
    return True

Expand Down Expand Up @@ -516,6 +523,10 @@ def is_extension(self, extension_name) -> bool:
"""
if self.folder is None:
return extension_name in self._loaded_extensions

if extension_name in self._loaded_extensions:
# extension already loaded in memory
return True
else:
if self.format == "binary":
return (self.folder / extension_name).is_dir() and (
Expand Down Expand Up @@ -1740,13 +1751,33 @@ def __init__(self, waveform_extractor):
if self.format == "binary":
self.extension_folder = self.folder / self.extension_name
if not self.extension_folder.is_dir():
self.extension_folder.mkdir()
if self.waveform_extractor.is_read_only():
warn(
"WaveformExtractor: cannot save extension in read-only mode. "
"Extension will be saved in memory."
)
self.format = "memory"
self.extension_folder = None
self.folder = None
else:
self.extension_folder.mkdir()

else:
import zarr

zarr_root = zarr.open(self.folder, mode="r+")
mode = "r+" if not self.waveform_extractor.is_read_only() else "r"
zarr_root = zarr.open(self.folder, mode=mode)
if self.extension_name not in zarr_root.keys():
self.extension_group = zarr_root.create_group(self.extension_name)
if self.waveform_extractor.is_read_only():
warn(
"WaveformExtractor: cannot save extension in read-only mode. "
"Extension will be saved in memory."
)
self.format = "memory"
self.extension_folder = None
self.folder = None
else:
self.extension_group = zarr_root.create_group(self.extension_name)
else:
self.extension_group = zarr_root[self.extension_name]
else:
Expand Down Expand Up @@ -1863,6 +1894,9 @@ def save(self, **kwargs):
self._save(**kwargs)

def _save(self, **kwargs):
# Only save if not read only
if self.waveform_extractor.is_read_only():
return
if self.format == "binary":
import pandas as pd

Expand Down Expand Up @@ -1900,7 +1934,9 @@ def _save(self, **kwargs):
self.extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor)
elif isinstance(ext_data, pd.DataFrame):
ext_data.to_xarray().to_zarr(
store=self.extension_group.store, group=f"{self.extension_group.name}/{ext_data_name}", mode="a"
store=self.extension_group.store,
group=f"{self.extension_group.name}/{ext_data_name}",
mode="a",
)
self.extension_group[ext_data_name].attrs["dataframe"] = True
else:
Expand Down
6 changes: 5 additions & 1 deletion src/spikeinterface/exporters/to_phy.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,11 @@ def export_to_phy(
templates[unit_ind, :, :][:, : len(chan_inds)] = template
templates_ind[unit_ind, : len(chan_inds)] = chan_inds

template_similarity = compute_template_similarity(waveform_extractor, method="cosine_similarity")
if waveform_extractor.is_extension("similarity"):
tmc = waveform_extractor.load_extension("similarity")
template_similarity = tmc.get_data()
else:
template_similarity = compute_template_similarity(waveform_extractor, method="cosine_similarity")

np.save(str(output_folder / "templates.npy"), templates)
np.save(str(output_folder / "template_ind.npy"), templates_ind)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import numpy as np
import pandas as pd
import shutil
import platform
from pathlib import Path

from spikeinterface import extract_waveforms, load_extractor, compute_sparsity
from spikeinterface import extract_waveforms, load_extractor, load_waveforms, compute_sparsity
from spikeinterface.extractors import toy_example

if hasattr(pytest, "global_test_folder"):
Expand Down Expand Up @@ -76,6 +77,16 @@ def setUp(self):
overwrite=True,
)
self.we2 = we2

# make we read-only
if platform.system() != "Windows":
we_ro_folder = cache_folder / "toy_waveforms_2seg_readonly"
if not we_ro_folder.is_dir():
shutil.copytree(we2.folder, we_ro_folder)
# change permissions (R+X)
we_ro_folder.chmod(0o555)
self.we_ro = load_waveforms(we_ro_folder)

self.sparsity2 = compute_sparsity(we2, method="radius", radius_um=30)
we_memory = extract_waveforms(
recording,
Expand All @@ -97,6 +108,12 @@ def setUp(self):
folder=cache_folder / "toy_sorting_2seg_sparse", format="binary", sparsity=sparsity, overwrite=True
)

def tearDown(self):
    """Restore full permissions on the read-only waveform folder.

    The folder is made read-only (``0o555``) during set-up; it is reset to
    ``0o777`` here so that pytest can delete it afterwards.
    """
    if platform.system() == "Windows":
        # the read-only copy is only created on non-Windows platforms
        return

    readonly_dir = cache_folder / "toy_waveforms_2seg_readonly"
    readonly_dir.chmod(0o777)

def _test_extension_folder(self, we, in_memory=False):
if self.extension_function_kwargs_list is None:
extension_function_kwargs_list = [dict()]
Expand Down Expand Up @@ -177,3 +194,11 @@ def test_extension(self):
assert ext_data_mem.equals(ext_data_zarr)
else:
print(f"{ext_data_name} of type {type(ext_data_mem)} not tested.")

# read-only - Extension is memory only
if platform.system() != "Windows":
_ = self.extension_class.get_extension_function()(self.we_ro, load_if_exists=False)
assert self.extension_class.extension_name in self.we_ro.get_available_extension_names()
ext_ro = self.we_ro.load_extension(self.extension_class.extension_name)
assert ext_ro.format == "memory"
assert ext_ro.extension_folder is None
4 changes: 2 additions & 2 deletions src/spikeinterface/postprocessing/unit_localization.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,8 @@ def enforce_decrease_shells_data(wf_data, maxchan, radial_parents, in_place=Fals
def get_grid_convolution_templates_and_weights(
contact_locations, radius_um=50, upsampling_um=5, sigma_um=np.linspace(10, 50.0, 5), margin_um=50
):
import sklearn.metrics

x_min, x_max = contact_locations[:, 0].min(), contact_locations[:, 0].max()
y_min, y_max = contact_locations[:, 1].min(), contact_locations[:, 1].max()

Expand All @@ -593,8 +595,6 @@ def get_grid_convolution_templates_and_weights(
template_positions[:, 0] = all_x.flatten()
template_positions[:, 1] = all_y.flatten()

import sklearn

# mask to get nearest template given a channel
dist = sklearn.metrics.pairwise_distances(contact_locations, template_positions)
nearest_template_mask = dist < radius_um
Expand Down

0 comments on commit 17b0c03

Please sign in to comment.