Skip to content

Commit

Permalink
Merge branch 'main' into meta_merging_sc2
Browse files Browse the repository at this point in the history
  • Loading branch information
yger authored Jul 9, 2024
2 parents d697b8c + 588ff5f commit 4d34a61
Show file tree
Hide file tree
Showing 11 changed files with 119 additions and 62 deletions.
15 changes: 14 additions & 1 deletion src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class BaseExtractor:
# This replaces the old key_properties
# These are annotations/properties that always need to be
# dumped (for instance locations, groups, is_filtered, etc.)
_main_annotations = []
_main_annotations = ["name"]
_main_properties = []

# these properties are skipped by default in copy_metadata
Expand Down Expand Up @@ -79,6 +79,19 @@ def __init__(self, main_ids: Sequence) -> None:
# preferred context for multiprocessing
self._preferred_mp_context = None

@property
def name(self):
    """Return the display name of this extractor.

    The value stored under the "name" annotation is used when it is set;
    otherwise the name of the instance's class is returned.
    """
    annotated_name = self._annotations.get("name", None)
    if annotated_name is None:
        return self.__class__.__name__
    return annotated_name

@name.setter
def name(self, value):
    """Store `value` as the "name" annotation, or drop the annotation when `value` is None."""
    if value is None:
        # Remove the annotation if it exists, so the getter falls back
        # to the class name again.
        self._annotations.pop("name", None)
        return
    self.annotate(name=value)

def get_num_segments(self) -> int:
    """Return the number of segments.

    This base version is only a placeholder: the real implementation is
    provided in BaseRecording or BaseSorting.

    Raises
    ------
    NotImplementedError
        Always, when called on this base implementation.
    """
    raise NotImplementedError
Expand Down
17 changes: 7 additions & 10 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class BaseRecording(BaseRecordingSnippets):
Internally handle list of RecordingSegment
"""

_main_annotations = ["is_filtered"]
_main_annotations = BaseRecordingSnippets._main_annotations + ["is_filtered"]
_main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"]
_main_features = [] # recording do not handle features

Expand All @@ -45,9 +45,6 @@ def __init__(self, sampling_frequency: float, channel_ids: list, dtype):
self.annotate(is_filtered=False)

def __repr__(self):

class_name = self.__class__.__name__
name_to_display = class_name
num_segments = self.get_num_segments()

txt = self._repr_header()
Expand All @@ -57,7 +54,7 @@ def __repr__(self):
split_index = txt.rfind("-", 0, 100) # Find the last "-" before character 100
if split_index != -1:
first_line = txt[:split_index]
recording_string_space = len(name_to_display) + 2 # Length of name_to_display plus ": "
recording_string_space = len(self.name) + 2 # Length of self.name plus ": "
white_space_to_align_with_first_line = " " * recording_string_space
second_line = white_space_to_align_with_first_line + txt[split_index + 1 :].lstrip()
txt = first_line + "\n" + second_line
Expand Down Expand Up @@ -97,21 +94,21 @@ def list_to_string(lst, max_size=6):
return txt

def _repr_header(self):
class_name = self.__class__.__name__
name_to_display = class_name
num_segments = self.get_num_segments()
num_channels = self.get_num_channels()
sf_khz = self.get_sampling_frequency() / 1000.0
sf_hz = self.get_sampling_frequency()
sf_khz = sf_hz / 1000
dtype = self.get_dtype()

total_samples = self.get_total_samples()
total_duration = self.get_total_duration()
total_memory_size = self.get_total_memory_size()
sampling_frequency_repr = f"{sf_khz:0.1f}kHz" if sf_hz > 10_000.0 else f"{sf_hz:0.1f}Hz"

txt = (
f"{name_to_display}: "
f"{self.name}: "
f"{num_channels} channels - "
f"{sf_khz:0.1f}kHz - "
f"{sampling_frequency_repr} - "
f"{num_segments} segments - "
f"{total_samples:,} samples - "
f"{convert_seconds_to_str(total_duration)} - "
Expand Down
1 change: 0 additions & 1 deletion src/spikeinterface/core/basesnippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ class BaseSnippets(BaseRecordingSnippets):
Abstract class representing several multichannel snippets.
"""

_main_annotations = []
_main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"]
_main_features = []

Expand Down
3 changes: 1 addition & 2 deletions src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@ def __init__(self, sampling_frequency: float, unit_ids: List):
self._cached_spike_trains = {}

def __repr__(self):
    """Return a short text summary of the sorting.

    The first line contains ``self.name``, the number of units, the number
    of segments and the sampling frequency in kHz. When the constructor
    kwargs contain a "file_path" entry, it is appended on a second line.
    """
    nseg = self.get_num_segments()
    nunits = self.get_num_units()
    sf_khz = self.get_sampling_frequency() / 1000.0
    # The header uses the (possibly user-defined) name rather than the raw class name.
    txt = f"{self.name}: {nunits} units - {nseg} segments - {sf_khz:0.1f}kHz"
    if "file_path" in self._kwargs:
        txt += f"\n file_path: {self._kwargs['file_path']}"
    return txt
Expand Down
5 changes: 5 additions & 0 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def generate_recording(
probe.set_device_channel_indices(np.arange(num_channels))
recording.set_probe(probe, in_place=True)

recording.name = "SyntheticRecording"

return recording


Expand Down Expand Up @@ -2122,4 +2124,7 @@ def generate_ground_truth_recording(
recording.set_channel_gains(1.0)
recording.set_channel_offsets(0.0)

recording.name = "GroundTruthRecording"
sorting.name = "GroundTruthSorting"

return recording, sorting
10 changes: 7 additions & 3 deletions src/spikeinterface/core/segmentutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,16 +191,20 @@ def get_traces(self, start_frame, end_frame, channel_indices):
seg_start = self.cumsum_length[i]
if i == i0:
# first
traces_chunk = rec_seg.get_traces(start_frame - seg_start, None, channel_indices)
end_frame_ = rec_seg.get_num_samples()
traces_chunk = rec_seg.get_traces(start_frame - seg_start, end_frame_, channel_indices)
all_traces.append(traces_chunk)
elif i == i1:
# last
if (end_frame - seg_start) > 0:
traces_chunk = rec_seg.get_traces(None, end_frame - seg_start, channel_indices)
start_frame_ = 0
traces_chunk = rec_seg.get_traces(start_frame_, end_frame - seg_start, channel_indices)
all_traces.append(traces_chunk)
else:
# in between
traces_chunk = rec_seg.get_traces(None, None, channel_indices)
start_frame_ = 0
end_frame_ = rec_seg.get_num_samples()
traces_chunk = rec_seg.get_traces(start_frame_, end_frame_, channel_indices)
all_traces.append(traces_chunk)
traces = np.concatenate(all_traces, axis=0)

Expand Down
31 changes: 30 additions & 1 deletion src/spikeinterface/core/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
"""

from typing import Sequence
import numpy as np
from spikeinterface.core.base import BaseExtractor
from spikeinterface.core import generate_recording, concatenate_recordings
from spikeinterface.core import generate_recording, generate_ground_truth_recording, concatenate_recordings


class DummyDictExtractor(BaseExtractor):
Expand Down Expand Up @@ -65,6 +66,34 @@ def test_check_if_serializable():
assert not extractor.check_serializability("json")


def test_name_and_repr():
    """Check that `name` can be set, survives to_dict/from_dict and shows up in the repr."""
    test_recording, test_sorting = generate_ground_truth_recording(seed=0, durations=[2])
    assert test_recording.name == "GroundTruthRecording"
    assert test_sorting.name == "GroundTruthSorting"

    # set a different name
    test_recording.name = "MyRecording"
    assert test_recording.name == "MyRecording"

    # to/from dict
    test_recording_dict = test_recording.to_dict()
    test_recording2 = BaseExtractor.from_dict(test_recording_dict)
    assert test_recording2.name == "MyRecording"

    # repr
    rec_str = str(test_recording2)
    assert "MyRecording" in rec_str
    # above 10khz, sampling frequency is printed in kHz
    assert "kHz" in rec_str

    # resetting the name to None falls back to the class name
    test_recording2.name = None
    assert "MyRecording" not in str(test_recording2)
    assert test_recording2.__class__.__name__ in str(test_recording2)

    # below 10khz sampling frequency is printed in Hz
    test_rec_low_fs = generate_recording(seed=0, durations=[2], sampling_frequency=5000)
    rec_str = str(test_rec_low_fs)
    assert "Hz" in rec_str
    # "kHz" contains "Hz", so the check above alone cannot tell the two formats apart
    assert "kHz" not in rec_str


if __name__ == "__main__":
test_check_if_memory_serializable()
test_check_if_serializable()
4 changes: 0 additions & 4 deletions src/spikeinterface/core/tests/test_segmentutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@
from numpy.testing import assert_raises

from spikeinterface.core import (
AppendSegmentRecording,
AppendSegmentSorting,
ConcatenateSegmentRecording,
ConcatenateSegmentSorting,
NumpyRecording,
NumpySorting,
append_recordings,
Expand Down
63 changes: 35 additions & 28 deletions src/spikeinterface/extractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,40 @@ def _retrieve_electrodes_indices_from_electrical_series_backend(open_file, elect
return electrodes_indices


class NwbRecordingExtractor(BaseRecording):
class _BaseNWBExtractor:
    """A class for common methods for NWB extractors.

    It groups the cleanup logic that releases the underlying file handle
    when an extractor instance is destroyed.
    """

    def _close_hdf5_file(self):
        """Close every HDF5 object that is still open on ``self._file``.

        Does nothing when the instance has no ``_file`` attribute. A failure
        to close one object is reported with a warning and the remaining
        objects are still processed.
        """
        if not hasattr(self, "_file"):
            return

        import warnings

        import h5py

        main_file_id = self._file.id
        open_object_ids_main = h5py.h5f.get_obj_ids(main_file_id, types=h5py.h5f.OBJ_ALL)
        for object_id in open_object_ids_main:
            # NOTE(review): h5i.get_name is documented to return None for objects
            # without a name; guard so that cleanup does not abort on them.
            raw_name = h5py.h5i.get_name(object_id)
            object_name = raw_name.decode("utf-8") if raw_name is not None else "<unnamed>"
            try:
                object_id.close()
            except Exception:
                # Best-effort cleanup: report and keep closing the other objects,
                # without swallowing KeyboardInterrupt/SystemExit.
                warnings.warn(f"Error closing object {object_name}")

    def __del__(self):
        """Release the file resources held by the extractor, depending on how it was opened."""
        # backend mode
        if hasattr(self, "_file"):
            if hasattr(self._file, "store"):
                # the file object exposes a `store`: close that store
                self._file.store.close()
            else:
                self._close_hdf5_file()
        # pynwb mode
        elif hasattr(self, "_nwbfile"):
            io = self._nwbfile.get_read_io()
            if io is not None:
                io.close()


class NwbRecordingExtractor(BaseRecording, _BaseNWBExtractor):
"""Load an NWBFile as a RecordingExtractor.
Parameters
Expand Down Expand Up @@ -623,19 +656,6 @@ def __init__(
"file": file,
}

def __del__(self):
    """Close the file handle held by the extractor when it is destroyed.

    If the instance has a ``_file`` attribute (backend mode), its ``store``
    is closed when present, otherwise the file itself is closed. If it only
    has a ``_nwbfile`` attribute (pynwb mode), the read IO returned by
    ``get_read_io()`` is closed when it is not None.
    """
    if not hasattr(self, "_file"):
        # pynwb mode
        if hasattr(self, "_nwbfile"):
            read_io = self._nwbfile.get_read_io()
            if read_io is not None:
                read_io.close()
        return

    # backend mode
    if hasattr(self._file, "store"):
        self._file.store.close()
    else:
        self._file.close()

def _fetch_recording_segment_info_pynwb(self, file, cache, load_time_vector, samples_for_rate_estimation):
self._nwbfile = read_nwbfile(
backend=self.backend,
Expand Down Expand Up @@ -949,7 +969,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
return traces


class NwbSortingExtractor(BaseSorting):
class NwbSortingExtractor(BaseSorting, _BaseNWBExtractor):
"""Load an NWBFile as a SortingExtractor.
Parameters
----------
Expand Down Expand Up @@ -1105,19 +1125,6 @@ def __init__(
"t_start": self.t_start,
}

def __del__(self):
    """Release the underlying file resources when the extractor is destroyed.

    Backend mode (``_file`` attribute present): close ``_file.store`` when the
    file has one, otherwise close ``_file`` directly. Pynwb mode (only
    ``_nwbfile`` present): close the IO returned by ``get_read_io()`` unless
    it is None.
    """
    if hasattr(self, "_file"):
        # backend mode
        has_store = hasattr(self._file, "store")
        to_close = self._file.store if has_store else self._file
        to_close.close()
        return

    if hasattr(self, "_nwbfile"):
        # pynwb mode
        read_io = self._nwbfile.get_read_io()
        if read_io is not None:
            read_io.close()

def _fetch_sorting_segment_info_pynwb(
self, unit_table_path: str = None, samples_for_rate_estimation: int = 1000, cache: bool = False
):
Expand Down
28 changes: 17 additions & 11 deletions src/spikeinterface/preprocessing/common_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class CommonReferenceRecording(BasePreprocessor):
recording : RecordingExtractor
The recording extractor to be re-referenced
reference : "global" | "single" | "local", default: "global"
If "global" the reference is the average or median across all the channels.
If "global" the reference is the average or median across all the channels. To select a subset of channels,
you can use the `ref_channel_ids` parameter.
If "single", the reference is a single channel or a list of channels that need to be set with the `ref_channel_ids`.
If "local", the reference is the set of channels within an annulus that must be set with the `local_radius` parameter.
operator : "median" | "average", default: "median"
Expand All @@ -51,10 +52,10 @@ class CommonReferenceRecording(BasePreprocessor):
List of lists containing the channel ids for splitting the reference. The CMR, CAR, or referencing with respect to
single channels are applied group-wise. However, this is not applied for the local CAR.
It is useful when dealing with different channel groups, e.g. multiple tetrodes.
ref_channel_ids : list or str or int, default: None
If no "groups" are specified, all channels are referenced to "ref_channel_ids". If "groups" is provided, then a
list of channels to be applied to each group is expected. If "single" reference, a list of one channel or an
int is expected.
ref_channel_ids : list | str | int | None, default: None
If "global" reference, a list of channels to be used as reference.
If "single" reference, a list of one channel or a single channel id is expected.
If "groups" is provided, then a list of channels to be applied to each group is expected.
local_radius : tuple(int, int), default: (30, 55)
Use in the local CAR implementation as the selecting annulus with the following format:
Expand Down Expand Up @@ -82,10 +83,10 @@ def __init__(
recording: BaseRecording,
reference: Literal["global", "single", "local"] = "global",
operator: Literal["median", "average"] = "median",
groups=None,
ref_channel_ids=None,
local_radius=(30, 55),
dtype=None,
groups: list | None = None,
ref_channel_ids: list | str | int | None = None,
local_radius: tuple[float, float] = (30.0, 55.0),
dtype: str | np.dtype | None = None,
):
num_chans = recording.get_num_channels()
neighbors = None
Expand All @@ -96,7 +97,9 @@ def __init__(
raise ValueError("'operator' must be either 'median', 'average'")

if reference == "global":
pass
if ref_channel_ids is not None:
if not isinstance(ref_channel_ids, list):
raise ValueError("With 'global' reference, provide 'ref_channel_ids' as a list")
elif reference == "single":
assert ref_channel_ids is not None, "With 'single' reference, provide 'ref_channel_ids'"
if groups is not None:
Expand Down Expand Up @@ -182,7 +185,10 @@ def get_traces(self, start_frame, end_frame, channel_indices):
traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None))

if self.reference == "global":
shift = self.operator_func(traces, axis=1, keepdims=True)
if self.ref_channel_indices is None:
shift = self.operator_func(traces, axis=1, keepdims=True)
else:
shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=True)
re_referenced_traces = traces[:, channel_indices] - shift
elif self.reference == "single":
# single channel -> no need of operator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

def _generate_test_recording():
    """Build the recording shared by the tests of this module.

    A 4-channel recording is generated with ``durations=[1.0]`` and its
    channels are renamed to "a", "b", "c" and "d".
    """
    recording = generate_recording(durations=[1.0], num_channels=4)
    # A single rename is enough; no intermediate channel_slice is needed.
    recording = recording.rename_channels(np.array(["a", "b", "c", "d"]))
    return recording


Expand All @@ -23,12 +23,14 @@ def recording():
def test_common_reference(recording):
# Test simple case
rec_cmr = common_reference(recording, reference="global", operator="median")
rec_cmr_ref = common_reference(recording, reference="global", operator="median", ref_channel_ids=["a", "b", "c"])
rec_car = common_reference(recording, reference="global", operator="average")
rec_sin = common_reference(recording, reference="single", ref_channel_ids=["a"])
rec_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="median")

traces = recording.get_traces()
assert np.allclose(traces, rec_cmr.get_traces() + np.median(traces, axis=1, keepdims=True), atol=0.01)
assert np.allclose(traces, rec_cmr_ref.get_traces() + np.median(traces[:, :3], axis=1, keepdims=True), atol=0.01)
assert np.allclose(traces, rec_car.get_traces() + np.mean(traces, axis=1, keepdims=True), atol=0.01)
assert not np.all(rec_sin.get_traces()[0])
assert np.allclose(rec_sin.get_traces()[:, 1], traces[:, 1] - traces[:, 0])
Expand Down

0 comments on commit 4d34a61

Please sign in to comment.