Commit
Merge branch 'main' into prepare_release
alejoe91 authored Feb 5, 2024
2 parents d27beb1 + bd317fe commit a9c72c3
Showing 7 changed files with 281 additions and 34 deletions.
127 changes: 126 additions & 1 deletion src/spikeinterface/core/template.py
@@ -2,6 +2,7 @@
import json
from dataclasses import dataclass, field, astuple
from probeinterface import Probe
from pathlib import Path
from .sparsity import ChannelSparsity


@@ -168,6 +169,131 @@ def from_dict(cls, data):
probe=data["probe"] if data["probe"] is None else Probe.from_dict(data["probe"]),
)

def add_templates_to_zarr_group(self, zarr_group: "zarr.Group") -> None:
"""
Adds a serialized version of the object to a given Zarr group.
It is the inverse of the `from_zarr_group` method.
Parameters
----------
zarr_group : zarr.Group
The Zarr group to which the template object will be serialized.
Notes
-----
This method will create datasets within the Zarr group for `templates_array`,
`channel_ids`, and `unit_ids`. It will also add `sampling_frequency` and `nbefore`
as attributes to the group. If `sparsity_mask` and `probe` are not None, they will
be included as a dataset and a subgroup, respectively.
The `templates_array` dataset is chunked with a single unit per chunk
to optimize read/write operations on individual units.
"""

# Saves one chunk per unit
arrays_chunk = (1, None, None)
zarr_group.create_dataset("templates_array", data=self.templates_array, chunks=arrays_chunk)
zarr_group.create_dataset("channel_ids", data=self.channel_ids)
zarr_group.create_dataset("unit_ids", data=self.unit_ids)

zarr_group.attrs["sampling_frequency"] = self.sampling_frequency
zarr_group.attrs["nbefore"] = self.nbefore

if self.sparsity_mask is not None:
zarr_group.create_dataset("sparsity_mask", data=self.sparsity_mask)

if self.probe is not None:
probe_group = zarr_group.create_group("probe")
self.probe.add_probe_to_zarr_group(probe_group)
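
A minimal sketch of the resulting on-disk layout (illustration only, not part of the commit; the shapes and the `templates.zarr` path are invented; `None` entries in `chunks` expand to the full axis length in zarr):

import numpy as np
import zarr

# Hypothetical templates: 2 units, 90 samples, 3 channels.
templates_array = np.zeros((2, 90, 3), dtype="float32")

group = zarr.open_group("templates.zarr", mode="w")
# One unit per chunk along the first axis, full extent on the others.
group.create_dataset("templates_array", data=templates_array, chunks=(1, None, None))
print(group["templates_array"].chunks)  # (1, 90, 3)

Reading group["templates_array"][unit_index] then touches a single chunk on disk, which is the access pattern the docstring above is optimizing for.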

def to_zarr(self, folder_path: str | Path) -> None:
"""
Saves the object's data to a Zarr file in the specified folder.
Uses the `add_templates_to_zarr_group` method to serialize the object to a Zarr group,
then saves the group to a Zarr file.
Parameters
----------
folder_path : str | Path
The path to the folder where the Zarr data will be saved.
"""
import zarr

zarr_group = zarr.open_group(folder_path, mode="w")

self.add_templates_to_zarr_group(zarr_group)

@classmethod
def from_zarr_group(cls, zarr_group: "zarr.Group") -> "Templates":
"""
Loads an instance of the class from an open Zarr group.
This is the inverse of the `add_templates_to_zarr_group` method.
Parameters
----------
zarr_group : zarr.Group
The Zarr group from which to load the instance.
Returns
-------
Templates
An instance of Templates populated with the data from the Zarr group.
Notes
-----
This method assumes the Zarr group has the same structure as the one created by
the `add_templates_to_zarr_group` method.
"""
templates_array = zarr_group["templates_array"]
channel_ids = zarr_group["channel_ids"]
unit_ids = zarr_group["unit_ids"]
sampling_frequency = zarr_group.attrs["sampling_frequency"]
nbefore = zarr_group.attrs["nbefore"]

sparsity_mask = None
if "sparsity_mask" in zarr_group:
sparsity_mask = zarr_group["sparsity_mask"]

probe = None
if "probe" in zarr_group:
probe = Probe.from_zarr_group(zarr_group["probe"])

return cls(
templates_array=templates_array,
sampling_frequency=sampling_frequency,
nbefore=nbefore,
sparsity_mask=sparsity_mask,
channel_ids=channel_ids,
unit_ids=unit_ids,
probe=probe,
)

@staticmethod
def from_zarr(folder_path: str | Path) -> "Templates":
"""
Deserialize the Templates object from a Zarr file located at the given folder path.
Parameters
----------
folder_path : str | Path
The path to the folder where the Zarr file is located.
Returns
-------
Templates
An instance of Templates initialized with data from the Zarr file.
"""
import zarr

zarr_group = zarr.open_group(folder_path, mode="r")

return Templates.from_zarr_group(zarr_group)
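
A hedged end-to-end sketch of the roundtrip (assuming `Templates` is importable from `spikeinterface.core`, as in the test module below; the shapes and the path are invented):

import numpy as np
from spikeinterface.core import Templates

# Hypothetical dense templates: 2 units, 90 samples, 3 channels.
templates = Templates(
    templates_array=np.random.randn(2, 90, 3).astype("float32"),
    sampling_frequency=30_000.0,
    nbefore=30,
)
templates.to_zarr("templates.zarr")
loaded = Templates.from_zarr("templates.zarr")
assert loaded == templates  # __eq__ compares the fields, including the arrays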

def to_json(self):
from spikeinterface.core.core_tools import SIJsonEncoder

@@ -209,7 +335,6 @@ def __eq__(self, other):
return False
if not np.array_equal(s_field.channel_ids, o_field.channel_ids):
return False

else:
if s_field != o_field:
return False
17 changes: 16 additions & 1 deletion src/spikeinterface/core/tests/test_template_class.py
@@ -20,7 +20,9 @@ def generate_test_template(template_type):
probe = generate_multi_columns_probe(num_columns=1, num_contact_per_column=[3])

if template_type == "dense":
return Templates(templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore)
return Templates(
templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore, probe=probe
)
elif template_type == "sparse": # sparse with sparse templates
sparsity_mask = np.array([[True, False, True], [False, True, False]])
sparsity = ChannelSparsity(
@@ -92,6 +94,19 @@ def test_initialization_fail_with_dense_templates():
template = generate_test_template(template_type="sparse_with_dense_templates")


@pytest.mark.parametrize("template_type", ["dense", "sparse"])
def test_save_and_load_zarr(template_type, tmp_path):
original_template = generate_test_template(template_type)

zarr_path = tmp_path / "templates.zarr"
original_template.to_zarr(str(zarr_path))

# Load from the Zarr archive
loaded_template = Templates.from_zarr(str(zarr_path))

assert original_template == loaded_template


if __name__ == "__main__":
# test_json_serialization("sparse")
test_json_serialization("dense")
25 changes: 16 additions & 9 deletions src/spikeinterface/extractors/neoextractors/openephys.py
@@ -56,8 +56,9 @@ class OpenEphysLegacyRecordingExtractor(NeoBaseRecordingExtractor):
If there are several blocks (experiments), specify the block index you want to load
all_annotations: bool, default: False
Load exhaustively all annotations from neo
ignore_timestamps_errors: bool, default: False
Ignore the discontinuous timestamps errors in neo
ignore_timestamps_errors: None
Deprecated keyword argument, now ignored.
neo.OpenEphysRawIO now handles gaps directly, at the cost of slower reads.
"""

mode = "folder"
@@ -71,9 +72,15 @@ def __init__(
stream_name=None,
block_index=None,
all_annotations=False,
ignore_timestamps_errors=False,
ignore_timestamps_errors=None,
):
neo_kwargs = self.map_to_neo_kwargs(folder_path, ignore_timestamps_errors)
if ignore_timestamps_errors is not None:
warnings.warn(
"OpenEphysLegacyRecordingExtractor: ignore_timestamps_errors is deprecated and is ignored",
DeprecationWarning,
stacklevel=2,
)
neo_kwargs = self.map_to_neo_kwargs(folder_path)
NeoBaseRecordingExtractor.__init__(
self,
stream_id=stream_id,
@@ -85,8 +92,8 @@ def __init__(
self._kwargs.update(dict(folder_path=str(Path(folder_path).absolute())))

@classmethod
def map_to_neo_kwargs(cls, folder_path, ignore_timestamps_errors=False):
neo_kwargs = {"dirname": str(folder_path), "ignore_timestamps_errors": ignore_timestamps_errors}
def map_to_neo_kwargs(cls, folder_path):
neo_kwargs = {"dirname": str(folder_path)}
neo_kwargs = drop_invalid_neo_arguments_for_version_0_12_0(neo_kwargs)
return neo_kwargs

@@ -330,9 +337,9 @@ def read_openephys(folder_path, **kwargs):
recording: OpenEphysLegacyRecordingExtractor or OpenEphysBinaryExtractor
"""
# auto guess format
files = [str(f) for f in Path(folder_path).iterdir()]
if np.any([f.endswith("continuous") for f in files]):
#  format = 'legacy'
files = [f for f in Path(folder_path).iterdir()]
if np.any([".continuous" in f.name and f.is_file() for f in files]):
# format = 'legacy'
recording = OpenEphysLegacyRecordingExtractor(folder_path, **kwargs)
else:
# format = 'binary'
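
A hedged usage sketch of the auto-guess (illustration only; the folder path is hypothetical). Any top-level `*.continuous` file selects the legacy reader, otherwise the binary reader is used:

import spikeinterface.extractors as se

# Legacy sessions contain files such as 100_CH1.continuous at the top level;
# binary sessions do not, so they fall through to the 'binary' branch.
recording = se.read_openephys("/path/to/openephys_session")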
2 changes: 2 additions & 0 deletions src/spikeinterface/extractors/tests/test_neoextractors.py
@@ -117,6 +117,8 @@ class OpenEphysLegacyRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
downloads = ["openephys"]
entities = [
"openephys/OpenEphys_SampleData_1",
# This has gaps!!!
"openephys/OpenEphys_SampleData_2_(multiple_starts)",
]


61 changes: 42 additions & 19 deletions src/spikeinterface/preprocessing/filter_gaussian.py
@@ -7,40 +7,48 @@
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment


class GaussianBandpassFilterRecording(BasePreprocessor):
class GaussianFilterRecording(BasePreprocessor):
"""
Class for performing a bandpass gaussian filtering/smoothing on a recording.
Class for performing a gaussian filtering/smoothing on a recording.
This is done by a convolution with a Gaussian kernel, which acts as a lowpass-filter.
The highpass-filter can be computed by subtracting the result.
A highpass-filter can be computed by subtracting the result of the convolution
from the original signal.
A bandpass-filter is obtained by subtracting the heavily smoothed signal
(gaussian cutoff freq_min) from the lightly smoothed one (gaussian cutoff freq_max).
Here, the bandpass is computed in the Fourier domain to accelerate the computation.
Here, convolution is performed in the Fourier domain to accelerate the computation.
Parameters
----------
recording: BaseRecording
The recording extractor to be filtered.
freq_min: float
freq_min: float or None
The lower frequency cutoff for the bandpass filter.
If None, the resulting object is a lowpass filter.
freq_max: float
freq_max: float or None
The higher frequency cutoff for the bandpass filter.
If None, the resulting object is a highpass filter.
margin_sd: float, default: 5.0
The number of standard deviations to take for the margins.
Returns
-------
gaussian_bandpass_filtered_recording: GaussianBandpassFilterRecording
gaussian_filtered_recording: GaussianFilterRecording
The filtered recording extractor object.
"""

name = "gaussian_bandpass_filter"
name = "gaussian_filter"

def __init__(
self, recording: BaseRecording, freq_min: float = 300.0, freq_max: float = 5000.0, margin_sd: float = 5.0
):
sf = recording.sampling_frequency
BasePreprocessor.__init__(self, recording)
self.annotate(is_filtered=True)

if freq_min is None and freq_max is None:
raise ValueError("At least one of `freq_min` or `freq_max` must be specified.")

for parent_segment in recording._recording_segments:
self.add_recording_segment(GaussianFilterRecordingSegment(parent_segment, freq_min, freq_max, margin_sd))

@@ -58,9 +66,14 @@ def __init__(
self.cached_gaussian = dict()

sf = parent_recording_segment.sampling_frequency
low_sigma = sf / (2 * np.pi * freq_min)
high_sigma = sf / (2 * np.pi * freq_max)
self.margin = 1 + int(max(low_sigma, high_sigma) * margin_sd)

# Margin from widest gaussian
sigmas = []
if freq_min is not None:
sigmas.append(sf / (2 * np.pi * freq_min))
if freq_max is not None:
sigmas.append(sf / (2 * np.pi * freq_max))
self.margin = 1 + int(max(sigmas) * margin_sd)
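
As a worked example with hypothetical values: at sf = 30 kHz with freq_min = 300 Hz, the widest gaussian has sigma = 30000 / (2 * pi * 300) ≈ 15.9 samples, so with the default margin_sd = 5 the margin is 1 + int(15.9 * 5) = 80 samples on each side of the chunk.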

def get_traces(
self,
@@ -69,15 +82,27 @@ def __init__(
channel_indices: Union[Iterable, None] = None,
):
traces, left_margin, right_margin = get_chunk_with_margin(
self.parent_recording_segment, start_frame, end_frame, channel_indices, self.margin
self.parent_recording_segment,
start_frame,
end_frame,
channel_indices,
self.margin,
add_reflect_padding=True,
)
dtype = traces.dtype

traces_fft = np.fft.fft(traces, axis=0)
gauss_low = self._create_gaussian(traces.shape[0], self.freq_min)
gauss_high = self._create_gaussian(traces.shape[0], self.freq_max)

filtered_fft = traces_fft * (gauss_high - gauss_low)[:, None]
if self.freq_max is not None:
pos_factor = self._create_gaussian(traces.shape[0], self.freq_max)
else:
pos_factor = np.ones((traces.shape[0],))
if self.freq_min is not None:
neg_factor = self._create_gaussian(traces.shape[0], self.freq_min)
else:
neg_factor = np.zeros((traces.shape[0],))

filtered_fft = traces_fft * (pos_factor - neg_factor)[:, None]
filtered_traces = np.real(np.fft.ifft(filtered_fft, axis=0))

if np.issubdtype(dtype, np.integer):
@@ -111,6 +136,4 @@ def _create_gaussian(self, N: int, cutoff_f: float):
return gaussian
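
The body of `_create_gaussian` is elided above. A minimal sketch of the idea, not the committed implementation (the helper name and the values are invented): build the gaussian directly on the FFT frequency grid, with unit gain at DC and gain exp(-1/2) ≈ 0.61 at the cutoff.

import numpy as np

def gaussian_response(n_samples: int, cutoff_f: float, sf: float):
    # Frequencies laid out as np.fft.fft expects (positive, then negative).
    freqs = np.fft.fftfreq(n_samples, d=1.0 / sf)
    # The Fourier transform of a time-domain gaussian of std sf / (2 * pi * cutoff_f)
    # is a frequency-domain gaussian of std cutoff_f.
    return np.exp(-(freqs**2) / (2 * cutoff_f**2))

# Bandpass transfer function as combined in get_traces: lowpass(freq_max) - lowpass(freq_min).
band = gaussian_response(1024, 5000.0, 30000.0) - gaussian_response(1024, 300.0, 30000.0)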


gaussian_bandpass_filter = define_function_from_class(
source_class=GaussianBandpassFilterRecording, name="gaussian_filter"
)
gaussian_filter = define_function_from_class(source_class=GaussianFilterRecording, name="gaussian_filter")
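
Usage of the renamed function might look like the following sketch (assuming `generate_recording` from `spikeinterface.core` as a toy source; the cutoff values are the documented defaults):

import spikeinterface.preprocessing as spre
from spikeinterface.core import generate_recording

recording = generate_recording(durations=[2.0], sampling_frequency=30_000.0)

# Both cutoffs -> bandpass; freq_min=None -> lowpass; freq_max=None -> highpass.
rec_band = spre.gaussian_filter(recording, freq_min=300.0, freq_max=5000.0)
rec_low = spre.gaussian_filter(recording, freq_min=None, freq_max=5000.0)
rec_high = spre.gaussian_filter(recording, freq_min=300.0, freq_max=None)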
4 changes: 2 additions & 2 deletions src/spikeinterface/preprocessing/preprocessinglist.py
@@ -11,7 +11,7 @@
HighpassFilterRecording,
highpass_filter,
)
from .filter_gaussian import GaussianBandpassFilterRecording, gaussian_bandpass_filter
from .filter_gaussian import GaussianFilterRecording, gaussian_filter
from .normalize_scale import (
NormalizeByQuantileRecording,
normalize_by_quantile,
@@ -46,7 +46,7 @@
BandpassFilterRecording,
HighpassFilterRecording,
NotchFilterRecording,
GaussianBandpassFilterRecording,
GaussianFilterRecording,
# gain offset stuff
NormalizeByQuantileRecording,
ScaleRecording,