Skip to content

Commit

Permalink
Propagate #3139 to bug-fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Jul 4, 2024
1 parent 1197aad commit 7c27eed
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
30 changes: 18 additions & 12 deletions src/spikeinterface/preprocessing/common_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class CommonReferenceRecording(BasePreprocessor):
recording: RecordingExtractor
The recording extractor to be re-referenced
reference: "global" | "single" | "local", default: "global"
If "global" the reference is the average or median across all the channels.
If "global" the reference is the average or median across all the channels. To select a subset of channels,
you can use the `ref_channel_ids` parameter.
If "single", the reference is a single channel or a list of channels that need to be set with the `ref_channel_ids`.
If "local", the reference is the set of channels within an annulus that must be set with the `local_radius` parameter.
operator: "median" | "average", default: "median"
Expand All @@ -51,11 +52,11 @@ class CommonReferenceRecording(BasePreprocessor):
List of lists containing the channel ids for splitting the reference. The CMR, CAR, or referencing with respect to
single channels are applied group-wise. However, this is not applied for the local CAR.
It is useful when dealing with different channel groups, e.g. multiple tetrodes.
ref_channel_ids: list or str or int, default: None
If no "groups" are specified, all channels are referenced to "ref_channel_ids". If "groups" is provided, then a
list of channels to be applied to each group is expected. If "single" reference, a list of one channel or an
int is expected.
local_radius: tuple(int, int), default: (30, 55)
ref_channel_ids: list | str | int | None, default: None
If "global" reference, a list of channels to be used as reference.
If "single" reference, a list of one channel or a single channel id is expected.
If "groups" is provided, then a list of channels to be applied to each group is expected.
local_radius : tuple(int, int), default: (30, 55)
Use in the local CAR implementation as the selecting annulus with the following format:
`(exclude radius, include radius)`
Expand All @@ -82,10 +83,10 @@ def __init__(
recording: BaseRecording,
reference: Literal["global", "single", "local"] = "global",
operator: Literal["median", "average"] = "median",
groups=None,
ref_channel_ids=None,
local_radius=(30, 55),
dtype=None,
groups: list | None = None,
ref_channel_ids: list | str | int | None = None,
local_radius: tuple[float, float] = (30.0, 55.0),
dtype: str | np.dtype | None = None,
):
num_chans = recording.get_num_channels()
neighbors = None
Expand All @@ -96,7 +97,9 @@ def __init__(
raise ValueError("'operator' must be either 'median', 'average'")

if reference == "global":
pass
if ref_channel_ids is not None:
if not isinstance(ref_channel_ids, list):
raise ValueError("With 'global' reference, provide 'ref_channel_ids' as a list")
elif reference == "single":
assert ref_channel_ids is not None, "With 'single' reference, provide 'ref_channel_ids'"
if groups is not None:
Expand Down Expand Up @@ -182,7 +185,10 @@ def get_traces(self, start_frame, end_frame, channel_indices):
traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None))

if self.reference == "global":
shift = self.operator_func(traces, axis=1, keepdims=True)
if self.ref_channel_indices is None:
shift = self.operator_func(traces, axis=1, keepdims=True)
else:
shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=True)
re_referenced_traces = traces[:, channel_indices] - shift
elif self.reference == "single":
# single channel -> no need of operator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

def _generate_test_recording():
recording = generate_recording(durations=[1.0], num_channels=4)
recording = recording.channel_slice(recording.channel_ids, np.array(["a", "b", "c", "d"]))
recording = recording.rename_channels(np.array(["a", "b", "c", "d"]))
return recording


Expand All @@ -23,12 +23,14 @@ def recording():
def test_common_reference(recording):
# Test simple case
rec_cmr = common_reference(recording, reference="global", operator="median")
rec_cmr_ref = common_reference(recording, reference="global", operator="median", ref_channel_ids=["a", "b", "c"])
rec_car = common_reference(recording, reference="global", operator="average")
rec_sin = common_reference(recording, reference="single", ref_channel_ids=["a"])
rec_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="median")

traces = recording.get_traces()
assert np.allclose(traces, rec_cmr.get_traces() + np.median(traces, axis=1, keepdims=True), atol=0.01)
assert np.allclose(traces, rec_cmr_ref.get_traces() + np.median(traces[:, :3], axis=1, keepdims=True), atol=0.01)
assert np.allclose(traces, rec_car.get_traces() + np.mean(traces, axis=1, keepdims=True), atol=0.01)
assert not np.all(rec_sin.get_traces()[0])
assert np.allclose(rec_sin.get_traces()[:, 1], traces[:, 1] - traces[:, 0])
Expand Down

0 comments on commit 7c27eed

Please sign in to comment.