From be0fd8afacea9508d38e65a9f89655a5e25bba57 Mon Sep 17 00:00:00 2001 From: JuanPimiento <148992347+JuanPimientoCaicedo@users.noreply.github.com> Date: Tue, 20 Aug 2024 12:52:10 -0400 Subject: [PATCH] Add causal filtering to filter.py (#3172) --- doc/api.rst | 1 + src/spikeinterface/preprocessing/filter.py | 127 ++++++++++++++-- .../preprocessing/preprocessinglist.py | 1 + .../preprocessing/tests/test_filter.py | 137 +++++++++++++++++- 4 files changed, 254 insertions(+), 12 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 1966b48a37..42f9fec299 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -171,6 +171,7 @@ spikeinterface.preprocessing .. autofunction:: interpolate_bad_channels .. autofunction:: normalize_by_quantile .. autofunction:: notch_filter + .. autofunction:: causal_filter .. autofunction:: phase_shift .. autofunction:: rectify .. autofunction:: remove_artifacts diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 54c5ab2b2d..a67d163d3d 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -24,10 +24,12 @@ class FilterRecording(BasePreprocessor): """ - Generic filter class based on: - - * scipy.signal.iirfilter - * scipy.signal.filtfilt or scipy.signal.sosfilt + A generic filter class based on: + For filter coefficient generation: + * scipy.signal.iirfilter + For filter application: + * scipy.signal.filtfilt or scipy.signal.sosfiltfilt when direction = "forward-backward" + * scipy.signal.lfilter or scipy.signal.sosfilt when direction = "forward" or "backward" BandpassFilterRecording is built on top of it. @@ -56,6 +58,11 @@ class FilterRecording(BasePreprocessor): - numerator/denominator : ("ba") ftype : str, default: "butter" Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1". + direction : "forward" | "backward" | "forward-backward", default: "forward-backward" + Direction of filtering: + - "forward" - filter is applied to the timeseries in one direction, creating phase shifts + - "backward" - the timeseries is reversed, the filter is applied and filtered timeseries reversed again. Creates phase shifts in the opposite direction to "forward" + - "forward-backward" - Applies the filter in the forward and backward direction, resulting in zero-phase filtering. Note this doubles the effective filter order. Returns ------- @@ -75,6 +82,7 @@ def __init__( add_reflect_padding=False, coeff=None, dtype=None, + direction="forward-backward", ): import scipy.signal @@ -106,7 +114,13 @@ def __init__( for parent_segment in recording._recording_segments: self.add_recording_segment( FilterRecordingSegment( - parent_segment, filter_coeff, filter_mode, margin, dtype, add_reflect_padding=add_reflect_padding + parent_segment, + filter_coeff, + filter_mode, + margin, + dtype, + add_reflect_padding=add_reflect_padding, + direction=direction, ) ) @@ -121,14 +135,25 @@ def __init__( margin_ms=margin_ms, add_reflect_padding=add_reflect_padding, dtype=dtype.str, + direction=direction, ) class FilterRecordingSegment(BasePreprocessorSegment): - def __init__(self, parent_recording_segment, coeff, filter_mode, margin, dtype, add_reflect_padding=False): + def __init__( + self, + parent_recording_segment, + coeff, + filter_mode, + margin, + dtype, + add_reflect_padding=False, + direction="forward-backward", + ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.coeff = coeff self.filter_mode = filter_mode + self.direction = direction self.margin = margin self.add_reflect_padding = add_reflect_padding self.dtype = dtype @@ -150,11 +175,24 @@ def get_traces(self, start_frame, end_frame, channel_indices): import scipy.signal - if self.filter_mode == "sos": - filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0) - elif self.filter_mode == "ba": - b, a = self.coeff - filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) + if self.direction == "forward-backward": + if self.filter_mode == "sos": + filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0) + elif self.filter_mode == "ba": + b, a = self.coeff + filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) + else: + if self.direction == "backward": + traces_chunk = np.flip(traces_chunk, axis=0) + + if self.filter_mode == "sos": + filtered_traces = scipy.signal.sosfilt(self.coeff, traces_chunk, axis=0) + elif self.filter_mode == "ba": + b, a = self.coeff + filtered_traces = scipy.signal.lfilter(b, a, traces_chunk, axis=0) + + if self.direction == "backward": + filtered_traces = np.flip(filtered_traces, axis=0) if right_margin > 0: filtered_traces = filtered_traces[left_margin:-right_margin, :] @@ -289,6 +327,73 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): notch_filter = define_function_from_class(source_class=NotchFilterRecording, name="notch_filter") highpass_filter = define_function_from_class(source_class=HighpassFilterRecording, name="highpass_filter") + +def causal_filter( + recording, + direction="forward", + band=[300.0, 6000.0], + btype="bandpass", + filter_order=5, + ftype="butter", + filter_mode="sos", + margin_ms=5.0, + add_reflect_padding=False, + coeff=None, + dtype=None, +): + """ + Generic causal filter built on top of the filter function. + + Parameters + ---------- + recording : Recording + The recording extractor to be re-referenced + direction : "forward" | "backward", default: "forward" + Direction of causal filter. The "backward" option flips the traces in time before applying the filter + and then flips them back. + band : float or list, default: [300.0, 6000.0] + If float, cutoff frequency in Hz for "highpass" filter type + If list. band (low, high) in Hz for "bandpass" filter type + btype : "bandpass" | "highpass", default: "bandpass" + Type of the filter + margin_ms : float, default: 5.0 + Margin in ms on border to avoid border effect + coeff : array | None, default: None + Filter coefficients in the filter_mode form. + dtype : dtype or None, default: None + The dtype of the returned traces. If None, the dtype of the parent recording is used + add_reflect_padding : Bool, default False + If True, uses a left and right margin during calculation. + filter_order : order + The order of the filter for `scipy.signal.iirfilter` + filter_mode : "sos" | "ba", default: "sos" + Filter form of the filter coefficients for `scipy.signal.iirfilter`: + - second-order sections ("sos") + - numerator/denominator : ("ba") + ftype : str, default: "butter" + Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1". + + Returns + ------- + filter_recording : FilterRecording + The causal-filtered recording extractor object + """ + assert direction in ["forward", "backward"], "Direction must be either 'forward' or 'backward'" + return filter( + recording=recording, + direction=direction, + band=band, + btype=btype, + filter_order=filter_order, + ftype=ftype, + filter_mode=filter_mode, + margin_ms=margin_ms, + add_reflect_padding=add_reflect_padding, + coeff=coeff, + dtype=dtype, + ) + + bandpass_filter.__doc__ = bandpass_filter.__doc__.format(_common_filter_docs) highpass_filter.__doc__ = highpass_filter.__doc__.format(_common_filter_docs) diff --git a/src/spikeinterface/preprocessing/preprocessinglist.py b/src/spikeinterface/preprocessing/preprocessinglist.py index 149c6eb458..bdf5f2219c 100644 --- a/src/spikeinterface/preprocessing/preprocessinglist.py +++ b/src/spikeinterface/preprocessing/preprocessinglist.py @@ -12,6 +12,7 @@ notch_filter, HighpassFilterRecording, highpass_filter, + causal_filter, ) from .filter_gaussian import GaussianFilterRecording, gaussian_filter from .normalize_scale import ( diff --git a/src/spikeinterface/preprocessing/tests/test_filter.py b/src/spikeinterface/preprocessing/tests/test_filter.py index 68790b3273..9df60af3db 100644 --- a/src/spikeinterface/preprocessing/tests/test_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_filter.py @@ -4,7 +4,140 @@ from spikeinterface.core import generate_recording from spikeinterface import NumpyRecording, set_global_tmp_folder -from spikeinterface.preprocessing import filter, bandpass_filter, notch_filter +from spikeinterface.preprocessing import filter, bandpass_filter, notch_filter, causal_filter + + +class TestCausalFilter: + """ + The only thing that is not tested (JZ, as of 23/07/2024) is the + propagation of margin kwargs, these are general filter params + and can be tested in an upcoming PR. + """ + + @pytest.fixture(scope="session") + def recording_and_data(self): + recording = generate_recording(durations=[1]) + raw_data = recording.get_traces() + + return (recording, raw_data) + + def test_causal_filter_main_kwargs(self, recording_and_data): + """ + Perform a test that expected output is returned under change + of all key filter-related kwargs. First run the filter in + the forward direction with options and compare it + to the expected output from scipy. + + Next, change every filter-related kwarg and set in the backwards + direction. Again check it matches expected scipy output. + """ + from scipy.signal import lfilter, sosfilt + + recording, raw_data = recording_and_data + + # First, check in the forward direction with + # the default set of kwargs + options = self._get_filter_options() + + sos = self._run_iirfilter(options, recording) + + test_data = sosfilt(sos, raw_data, axis=0) + test_data.astype(recording.dtype) + + filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces() + + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6) + + # Then, change all kwargs to ensure they are propagated + # and check the backwards version. + options["band"] = [671] + options["btype"] = "highpass" + options["filter_order"] = 8 + options["ftype"] = "bessel" + options["filter_mode"] = "ba" + options["dtype"] = np.float16 + + b, a = self._run_iirfilter(options, recording) + + flip_raw = np.flip(raw_data, axis=0) + test_data = lfilter(b, a, flip_raw, axis=0) + test_data = np.flip(test_data, axis=0) + test_data = test_data.astype(options["dtype"]) + + filt_data = causal_filter(recording, direction="backward", **options, margin_ms=0).get_traces() + + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6) + + def test_causal_filter_custom_coeff(self, recording_and_data): + """ + A different path is taken when custom coeff is selected. + Therefore, explicitly test the expected outputs are obtained + when passing custom coeff, under the "ba" and "sos" conditions. + """ + from scipy.signal import lfilter, sosfilt + + recording, raw_data = recording_and_data + + options = self._get_filter_options() + options["filter_mode"] = "ba" + options["coeff"] = (np.array([0.1, 0.2, 0.3]), np.array([0.4, 0.5, 0.6])) + + # Check the custom coeff are propagated in both modes. + # First, in "ba" mode + test_data = lfilter(options["coeff"][0], options["coeff"][1], raw_data, axis=0) + test_data = test_data.astype(recording.get_dtype()) + + filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces() + + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6, equal_nan=True) + + # Next, in "sos" mode + options["filter_mode"] = "sos" + options["coeff"] = np.ones((2, 6)) + + test_data = sosfilt(options["coeff"], raw_data, axis=0) + test_data = test_data.astype(recording.get_dtype()) + + filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces() + + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6, equal_nan=True) + + def test_causal_kwarg_error_raised(self, recording_and_data): + """ + Test that passing the "forward-backward" direction results in + an error. It is is critical this error is raised, + otherwise the filter will no longer be causal. + """ + recording, raw_data = recording_and_data + + with pytest.raises(BaseException) as e: + filt_data = causal_filter(recording, direction="forward-backward") + + def _run_iirfilter(self, options, recording): + """ + Convenience function to convert Si kwarg + names to Scipy. + """ + from scipy.signal import iirfilter + + return iirfilter( + N=options["filter_order"], + Wn=options["band"], + btype=options["btype"], + ftype=options["ftype"], + output=options["filter_mode"], + fs=recording.get_sampling_frequency(), + ) + + def _get_filter_options(self): + return { + "band": [300.0, 6000.0], + "btype": "bandpass", + "filter_order": 5, + "ftype": "butter", + "filter_mode": "sos", + "coeff": None, + } def test_filter(): @@ -28,6 +161,8 @@ def test_filter(): # other filtering types rec3 = filter(rec, band=500.0, btype="highpass", filter_mode="ba", filter_order=2) rec4 = notch_filter(rec, freq=3000, q=30, margin_ms=5.0) + rec5 = causal_filter(rec, direction="forward") + rec6 = causal_filter(rec, direction="backward") # filter from coefficients from scipy.signal import iirfilter