diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index a4369309f2..ffe7467753 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -39,7 +39,8 @@ class CommonReferenceRecording(BasePreprocessor): recording: RecordingExtractor The recording extractor to be re-referenced reference: "global" | "single" | "local", default: "global" - If "global" the reference is the average or median across all the channels. + If "global" the reference is the average or median across all the channels. To select a subset of channels, + you can use the `ref_channel_ids` parameter. If "single", the reference is a single channel or a list of channels that need to be set with the `ref_channel_ids`. If "local", the reference is the set of channels within an annulus that must be set with the `local_radius` parameter. operator: "median" | "average", default: "median" @@ -51,11 +52,11 @@ class CommonReferenceRecording(BasePreprocessor): List of lists containing the channel ids for splitting the reference. The CMR, CAR, or referencing with respect to single channels are applied group-wise. However, this is not applied for the local CAR. It is useful when dealing with different channel groups, e.g. multiple tetrodes. - ref_channel_ids: list or str or int, default: None - If no "groups" are specified, all channels are referenced to "ref_channel_ids". If "groups" is provided, then a - list of channels to be applied to each group is expected. If "single" reference, a list of one channel or an - int is expected. - local_radius: tuple(int, int), default: (30, 55) + ref_channel_ids: list | str | int | None, default: None + If "global" reference, a list of channels to be used as reference. + If "single" reference, a list of one channel or a single channel id is expected. + If "groups" is provided, then a list of channels to be applied to each group is expected. + local_radius : tuple(int, int), default: (30, 55) Use in the local CAR implementation as the selecting annulus with the following format: `(exclude radius, include radius)` @@ -82,10 +83,10 @@ def __init__( recording: BaseRecording, reference: Literal["global", "single", "local"] = "global", operator: Literal["median", "average"] = "median", - groups=None, - ref_channel_ids=None, - local_radius=(30, 55), - dtype=None, + groups: list | None = None, + ref_channel_ids: list | str | int | None = None, + local_radius: tuple[float, float] = (30.0, 55.0), + dtype: str | np.dtype | None = None, ): num_chans = recording.get_num_channels() neighbors = None @@ -96,7 +97,9 @@ def __init__( raise ValueError("'operator' must be either 'median', 'average'") if reference == "global": - pass + if ref_channel_ids is not None: + if not isinstance(ref_channel_ids, list): + raise ValueError("With 'global' reference, provide 'ref_channel_ids' as a list") elif reference == "single": assert ref_channel_ids is not None, "With 'single' reference, provide 'ref_channel_ids'" if groups is not None: @@ -182,7 +185,10 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) if self.reference == "global": - shift = self.operator_func(traces, axis=1, keepdims=True) + if self.ref_channel_indices is None: + shift = self.operator_func(traces, axis=1, keepdims=True) + else: + shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=True) re_referenced_traces = traces[:, channel_indices] - shift elif self.reference == "single": # single channel -> no need of operator diff --git a/src/spikeinterface/preprocessing/tests/test_common_reference.py b/src/spikeinterface/preprocessing/tests/test_common_reference.py index 1df9b21c81..8b37e7f4b9 100644 --- a/src/spikeinterface/preprocessing/tests/test_common_reference.py +++ b/src/spikeinterface/preprocessing/tests/test_common_reference.py @@ -11,7 +11,7 @@ def _generate_test_recording(): recording = generate_recording(durations=[1.0], num_channels=4) - recording = recording.channel_slice(recording.channel_ids, np.array(["a", "b", "c", "d"])) + recording = recording.rename_channels(np.array(["a", "b", "c", "d"])) return recording @@ -23,12 +23,14 @@ def recording(): def test_common_reference(recording): # Test simple case rec_cmr = common_reference(recording, reference="global", operator="median") + rec_cmr_ref = common_reference(recording, reference="global", operator="median", ref_channel_ids=["a", "b", "c"]) rec_car = common_reference(recording, reference="global", operator="average") rec_sin = common_reference(recording, reference="single", ref_channel_ids=["a"]) rec_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="median") traces = recording.get_traces() assert np.allclose(traces, rec_cmr.get_traces() + np.median(traces, axis=1, keepdims=True), atol=0.01) + assert np.allclose(traces, rec_cmr_ref.get_traces() + np.median(traces[:, :3], axis=1, keepdims=True), atol=0.01) assert np.allclose(traces, rec_car.get_traces() + np.mean(traces, axis=1, keepdims=True), atol=0.01) assert not np.all(rec_sin.get_traces()[0]) assert np.allclose(rec_sin.get_traces()[:, 1], traces[:, 1] - traces[:, 0])