Merge branch 'main' of github.com:spikeinterface/spikeinterface into merging_units
yger committed Jul 9, 2024
2 parents a3ead81 + a1958ce commit f313a60
Showing 5 changed files with 62 additions and 47 deletions.
10 changes: 7 additions & 3 deletions src/spikeinterface/core/segmentutils.py
@@ -191,16 +191,20 @@ def get_traces(self, start_frame, end_frame, channel_indices):
                 seg_start = self.cumsum_length[i]
                 if i == i0:
                     # first
-                    traces_chunk = rec_seg.get_traces(start_frame - seg_start, None, channel_indices)
+                    end_frame_ = rec_seg.get_num_samples()
+                    traces_chunk = rec_seg.get_traces(start_frame - seg_start, end_frame_, channel_indices)
                     all_traces.append(traces_chunk)
                 elif i == i1:
                     # last
                     if (end_frame - seg_start) > 0:
-                        traces_chunk = rec_seg.get_traces(None, end_frame - seg_start, channel_indices)
+                        start_frame_ = 0
+                        traces_chunk = rec_seg.get_traces(start_frame_, end_frame - seg_start, channel_indices)
                         all_traces.append(traces_chunk)
                 else:
                     # in between
-                    traces_chunk = rec_seg.get_traces(None, None, channel_indices)
+                    start_frame_ = 0
+                    end_frame_ = rec_seg.get_num_samples()
+                    traces_chunk = rec_seg.get_traces(start_frame_, end_frame_, channel_indices)
                     all_traces.append(traces_chunk)
             traces = np.concatenate(all_traces, axis=0)

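For orientation, here is a minimal usage sketch (not part of this commit; recording parameters are arbitrary) of the code path the segmentutils change exercises: a read that spans the boundary between two concatenated recordings, where each sub-segment is now queried with explicit start/end frames instead of None.

from spikeinterface.core import generate_recording, concatenate_recordings

# Two short dummy recordings joined into a single mono-segment recording.
rec_a = generate_recording(durations=[1.0], num_channels=4)
rec_b = generate_recording(durations=[1.0], num_channels=4)
rec_concat = concatenate_recordings([rec_a, rec_b])

# A request that starts inside rec_a and ends inside rec_b hits the "first" and
# "last" branches of the get_traces code shown above.
n_a = rec_a.get_num_samples(segment_index=0)
traces = rec_concat.get_traces(start_frame=n_a - 100, end_frame=n_a + 100)
assert traces.shape == (200, 4)  # (num_samples, num_channels)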
4 changes: 0 additions & 4 deletions src/spikeinterface/core/tests/test_segmentutils.py
@@ -5,10 +5,6 @@
 from numpy.testing import assert_raises

 from spikeinterface.core import (
-    AppendSegmentRecording,
-    AppendSegmentSorting,
-    ConcatenateSegmentRecording,
-    ConcatenateSegmentSorting,
     NumpyRecording,
     NumpySorting,
     append_recordings,
63 changes: 35 additions & 28 deletions src/spikeinterface/extractors/nwbextractors.py
@@ -401,7 +401,40 @@ def _retrieve_electrodes_indices_from_electrical_series_backend(open_file, electrical_series)
     return electrodes_indices


-class NwbRecordingExtractor(BaseRecording):
+class _BaseNWBExtractor:
+    "A class for common methods for NWB extractors."
+
+    def _close_hdf5_file(self):
+        has_hdf5_backend = hasattr(self, "_file")
+        if has_hdf5_backend:
+            import h5py
+
+            main_file_id = self._file.id
+            open_object_ids_main = h5py.h5f.get_obj_ids(main_file_id, types=h5py.h5f.OBJ_ALL)
+            for object_id in open_object_ids_main:
+                object_name = h5py.h5i.get_name(object_id).decode("utf-8")
+                try:
+                    object_id.close()
+                except:
+                    import warnings
+
+                    warnings.warn(f"Error closing object {object_name}")
+
+    def __del__(self):
+        # backend mode
+        if hasattr(self, "_file"):
+            if hasattr(self._file, "store"):
+                self._file.store.close()
+            else:
+                self._close_hdf5_file()
+        # pynwb mode
+        elif hasattr(self, "_nwbfile"):
+            io = self._nwbfile.get_read_io()
+            if io is not None:
+                io.close()
+
+
+class NwbRecordingExtractor(BaseRecording, _BaseNWBExtractor):
     """Load an NWBFile as a RecordingExtractor.
     Parameters
@@ -623,19 +656,6 @@ def __init__(
             "file": file,
         }

-    def __del__(self):
-        # backend mode
-        if hasattr(self, "_file"):
-            if hasattr(self._file, "store"):
-                self._file.store.close()
-            else:
-                self._file.close()
-        # pynwb mode
-        elif hasattr(self, "_nwbfile"):
-            io = self._nwbfile.get_read_io()
-            if io is not None:
-                io.close()
-
     def _fetch_recording_segment_info_pynwb(self, file, cache, load_time_vector, samples_for_rate_estimation):
         self._nwbfile = read_nwbfile(
             backend=self.backend,
@@ -949,7 +969,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
         return traces


-class NwbSortingExtractor(BaseSorting):
+class NwbSortingExtractor(BaseSorting, _BaseNWBExtractor):
     """Load an NWBFile as a SortingExtractor.
     Parameters
     ----------
@@ -1105,19 +1125,6 @@ def __init__(
             "t_start": self.t_start,
         }

-    def __del__(self):
-        # backend mode
-        if hasattr(self, "_file"):
-            if hasattr(self._file, "store"):
-                self._file.store.close()
-            else:
-                self._file.close()
-        # pynwb mode
-        elif hasattr(self, "_nwbfile"):
-            io = self._nwbfile.get_read_io()
-            if io is not None:
-                io.close()
-
     def _fetch_sorting_segment_info_pynwb(
         self, unit_table_path: str = None, samples_for_rate_estimation: int = 1000, cache: bool = False
     ):
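A hedged usage sketch follows (the NWB file path is a placeholder, not from the commit): after this refactor, NwbRecordingExtractor and NwbSortingExtractor share one cleanup path through _BaseNWBExtractor, so dropping the last reference closes the zarr store, the pynwb io, or any still-open h5py object handles.

from spikeinterface.extractors import NwbRecordingExtractor

# "session.nwb" is a hypothetical file path used only for illustration.
recording = NwbRecordingExtractor("session.nwb")
traces = recording.get_traces(start_frame=0, end_frame=1000)

# Deleting the extractor triggers _BaseNWBExtractor.__del__, which picks the right
# close routine for the backend (zarr store, pynwb io, or lingering h5py object ids).
del recording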
28 changes: 17 additions & 11 deletions src/spikeinterface/preprocessing/common_reference.py
@@ -39,7 +39,8 @@ class CommonReferenceRecording(BasePreprocessor):
     recording : RecordingExtractor
         The recording extractor to be re-referenced
     reference : "global" | "single" | "local", default: "global"
-        If "global" the reference is the average or median across all the channels.
+        If "global" the reference is the average or median across all the channels. To select a subset of channels,
+        you can use the `ref_channel_ids` parameter.
         If "single", the reference is a single channel or a list of channels that need to be set with the `ref_channel_ids`.
         If "local", the reference is the set of channels within an annulus that must be set with the `local_radius` parameter.
     operator : "median" | "average", default: "median"
@@ -51,10 +52,10 @@ class CommonReferenceRecording(BasePreprocessor):
         List of lists containing the channel ids for splitting the reference. The CMR, CAR, or referencing with respect to
         single channels are applied group-wise. However, this is not applied for the local CAR.
         It is useful when dealing with different channel groups, e.g. multiple tetrodes.
-    ref_channel_ids : list or str or int, default: None
-        If no "groups" are specified, all channels are referenced to "ref_channel_ids". If "groups" is provided, then a
-        list of channels to be applied to each group is expected. If "single" reference, a list of one channel or an
-        int is expected.
+    ref_channel_ids : list | str | int | None, default: None
+        If "global" reference, a list of channels to be used as reference.
+        If "single" reference, a list of one channel or a single channel id is expected.
+        If "groups" is provided, then a list of channels to be applied to each group is expected.
     local_radius : tuple(int, int), default: (30, 55)
         Use in the local CAR implementation as the selecting annulus with the following format:
@@ -82,10 +83,10 @@ def __init__(
         recording: BaseRecording,
         reference: Literal["global", "single", "local"] = "global",
         operator: Literal["median", "average"] = "median",
-        groups=None,
-        ref_channel_ids=None,
-        local_radius=(30, 55),
-        dtype=None,
+        groups: list | None = None,
+        ref_channel_ids: list | str | int | None = None,
+        local_radius: tuple[float, float] = (30.0, 55.0),
+        dtype: str | np.dtype | None = None,
     ):
         num_chans = recording.get_num_channels()
         neighbors = None
@@ -96,7 +97,9 @@ def __init__(
             raise ValueError("'operator' must be either 'median', 'average'")

         if reference == "global":
-            pass
+            if ref_channel_ids is not None:
+                if not isinstance(ref_channel_ids, list):
+                    raise ValueError("With 'global' reference, provide 'ref_channel_ids' as a list")
         elif reference == "single":
             assert ref_channel_ids is not None, "With 'single' reference, provide 'ref_channel_ids'"
             if groups is not None:
@@ -182,7 +185,10 @@ def get_traces(self, start_frame, end_frame, channel_indices):
         traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None))

         if self.reference == "global":
-            shift = self.operator_func(traces, axis=1, keepdims=True)
+            if self.ref_channel_indices is None:
+                shift = self.operator_func(traces, axis=1, keepdims=True)
+            else:
+                shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=True)
             re_referenced_traces = traces[:, channel_indices] - shift
         elif self.reference == "single":
             # single channel -> no need of operator
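A short sketch of what the new global-reference branch computes (mirroring the test change below; channel ids and tolerance are illustrative): the shift is the operator applied over the user-selected reference channels only, and it is subtracted from every channel.

import numpy as np
from spikeinterface.core import generate_recording
from spikeinterface.preprocessing import common_reference

rec = generate_recording(durations=[1.0], num_channels=4)
rec = rec.rename_channels(np.array(["a", "b", "c", "d"]))

# Global median reference computed from channels "a", "b" and "c" only.
rec_ref = common_reference(rec, reference="global", operator="median", ref_channel_ids=["a", "b", "c"])

traces = rec.get_traces()
shift = np.median(traces[:, :3], axis=1, keepdims=True)
assert np.allclose(rec_ref.get_traces(), traces - shift, atol=0.01)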
4 changes: 3 additions & 1 deletion src/spikeinterface/preprocessing/tests/test_common_reference.py
@@ -11,7 +11,7 @@

 def _generate_test_recording():
     recording = generate_recording(durations=[1.0], num_channels=4)
-    recording = recording.channel_slice(recording.channel_ids, np.array(["a", "b", "c", "d"]))
+    recording = recording.rename_channels(np.array(["a", "b", "c", "d"]))
     return recording


@@ -23,12 +23,14 @@ def recording():
 def test_common_reference(recording):
     # Test simple case
     rec_cmr = common_reference(recording, reference="global", operator="median")
+    rec_cmr_ref = common_reference(recording, reference="global", operator="median", ref_channel_ids=["a", "b", "c"])
     rec_car = common_reference(recording, reference="global", operator="average")
     rec_sin = common_reference(recording, reference="single", ref_channel_ids=["a"])
     rec_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="median")

     traces = recording.get_traces()
     assert np.allclose(traces, rec_cmr.get_traces() + np.median(traces, axis=1, keepdims=True), atol=0.01)
+    assert np.allclose(traces, rec_cmr_ref.get_traces() + np.median(traces[:, :3], axis=1, keepdims=True), atol=0.01)
     assert np.allclose(traces, rec_car.get_traces() + np.mean(traces, axis=1, keepdims=True), atol=0.01)
     assert not np.all(rec_sin.get_traces()[0])
     assert np.allclose(rec_sin.get_traces()[:, 1], traces[:, 1] - traces[:, 0])
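As a side note on the fixture change above, a small sketch (same generate_recording arguments as the fixture; the renamed ids are illustrative) contrasting the two calls: rename_channels relabels all channel ids in place, while channel_slice selects a subset of channels and can optionally rename it.

import numpy as np
from spikeinterface.core import generate_recording

rec = generate_recording(durations=[1.0], num_channels=4)

# New fixture style: keep every channel, just relabel the ids.
renamed = rec.rename_channels(np.array(["a", "b", "c", "d"]))

# Old fixture style: channel_slice keeps the given channels and renames them.
sliced = rec.channel_slice(rec.channel_ids, renamed_channel_ids=np.array(["a", "b", "c", "d"]))

assert list(renamed.channel_ids) == ["a", "b", "c", "d"]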
