Skip to content

Commit

Permalink
Add causal filtering to filter.py (#3172)
Browse files Browse the repository at this point in the history
  • Loading branch information
JuanPimientoCaicedo authored Aug 20, 2024
1 parent 694f862 commit be0fd8a
Show file tree
Hide file tree
Showing 4 changed files with 254 additions and 12 deletions.
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ spikeinterface.preprocessing
.. autofunction:: interpolate_bad_channels
.. autofunction:: normalize_by_quantile
.. autofunction:: notch_filter
.. autofunction:: causal_filter
.. autofunction:: phase_shift
.. autofunction:: rectify
.. autofunction:: remove_artifacts
Expand Down
127 changes: 116 additions & 11 deletions src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@

class FilterRecording(BasePreprocessor):
"""
Generic filter class based on:
* scipy.signal.iirfilter
* scipy.signal.filtfilt or scipy.signal.sosfilt
A generic filter class based on:
For filter coefficient generation:
* scipy.signal.iirfilter
For filter application:
* scipy.signal.filtfilt or scipy.signal.sosfiltfilt when direction = "forward-backward"
* scipy.signal.lfilter or scipy.signal.sosfilt when direction = "forward" or "backward"
BandpassFilterRecording is built on top of it.
Expand Down Expand Up @@ -56,6 +58,11 @@ class FilterRecording(BasePreprocessor):
- numerator/denominator : ("ba")
ftype : str, default: "butter"
Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1".
direction : "forward" | "backward" | "forward-backward", default: "forward-backward"
Direction of filtering:
- "forward" - filter is applied to the timeseries in one direction, creating phase shifts
- "backward" - the timeseries is reversed, the filter is applied and filtered timeseries reversed again. Creates phase shifts in the opposite direction to "forward"
- "forward-backward" - Applies the filter in the forward and backward direction, resulting in zero-phase filtering. Note this doubles the effective filter order.
Returns
-------
Expand All @@ -75,6 +82,7 @@ def __init__(
add_reflect_padding=False,
coeff=None,
dtype=None,
direction="forward-backward",
):
import scipy.signal

Expand Down Expand Up @@ -106,7 +114,13 @@ def __init__(
for parent_segment in recording._recording_segments:
self.add_recording_segment(
FilterRecordingSegment(
parent_segment, filter_coeff, filter_mode, margin, dtype, add_reflect_padding=add_reflect_padding
parent_segment,
filter_coeff,
filter_mode,
margin,
dtype,
add_reflect_padding=add_reflect_padding,
direction=direction,
)
)

Expand All @@ -121,14 +135,25 @@ def __init__(
margin_ms=margin_ms,
add_reflect_padding=add_reflect_padding,
dtype=dtype.str,
direction=direction,
)


class FilterRecordingSegment(BasePreprocessorSegment):
def __init__(self, parent_recording_segment, coeff, filter_mode, margin, dtype, add_reflect_padding=False):
def __init__(
self,
parent_recording_segment,
coeff,
filter_mode,
margin,
dtype,
add_reflect_padding=False,
direction="forward-backward",
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
self.coeff = coeff
self.filter_mode = filter_mode
self.direction = direction
self.margin = margin
self.add_reflect_padding = add_reflect_padding
self.dtype = dtype
Expand All @@ -150,11 +175,24 @@ def get_traces(self, start_frame, end_frame, channel_indices):

import scipy.signal

if self.filter_mode == "sos":
filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0)
elif self.filter_mode == "ba":
b, a = self.coeff
filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0)
if self.direction == "forward-backward":
if self.filter_mode == "sos":
filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0)
elif self.filter_mode == "ba":
b, a = self.coeff
filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0)
else:
if self.direction == "backward":
traces_chunk = np.flip(traces_chunk, axis=0)

if self.filter_mode == "sos":
filtered_traces = scipy.signal.sosfilt(self.coeff, traces_chunk, axis=0)
elif self.filter_mode == "ba":
b, a = self.coeff
filtered_traces = scipy.signal.lfilter(b, a, traces_chunk, axis=0)

if self.direction == "backward":
filtered_traces = np.flip(filtered_traces, axis=0)

if right_margin > 0:
filtered_traces = filtered_traces[left_margin:-right_margin, :]
Expand Down Expand Up @@ -289,6 +327,73 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None):
notch_filter = define_function_from_class(source_class=NotchFilterRecording, name="notch_filter")
highpass_filter = define_function_from_class(source_class=HighpassFilterRecording, name="highpass_filter")


def causal_filter(
recording,
direction="forward",
band=[300.0, 6000.0],
btype="bandpass",
filter_order=5,
ftype="butter",
filter_mode="sos",
margin_ms=5.0,
add_reflect_padding=False,
coeff=None,
dtype=None,
):
"""
Generic causal filter built on top of the filter function.
Parameters
----------
recording : Recording
The recording extractor to be re-referenced
direction : "forward" | "backward", default: "forward"
Direction of causal filter. The "backward" option flips the traces in time before applying the filter
and then flips them back.
band : float or list, default: [300.0, 6000.0]
If float, cutoff frequency in Hz for "highpass" filter type
If list. band (low, high) in Hz for "bandpass" filter type
btype : "bandpass" | "highpass", default: "bandpass"
Type of the filter
margin_ms : float, default: 5.0
Margin in ms on border to avoid border effect
coeff : array | None, default: None
Filter coefficients in the filter_mode form.
dtype : dtype or None, default: None
The dtype of the returned traces. If None, the dtype of the parent recording is used
add_reflect_padding : Bool, default False
If True, uses a left and right margin during calculation.
filter_order : order
The order of the filter for `scipy.signal.iirfilter`
filter_mode : "sos" | "ba", default: "sos"
Filter form of the filter coefficients for `scipy.signal.iirfilter`:
- second-order sections ("sos")
- numerator/denominator : ("ba")
ftype : str, default: "butter"
Filter type for `scipy.signal.iirfilter` e.g. "butter", "cheby1".
Returns
-------
filter_recording : FilterRecording
The causal-filtered recording extractor object
"""
assert direction in ["forward", "backward"], "Direction must be either 'forward' or 'backward'"
return filter(
recording=recording,
direction=direction,
band=band,
btype=btype,
filter_order=filter_order,
ftype=ftype,
filter_mode=filter_mode,
margin_ms=margin_ms,
add_reflect_padding=add_reflect_padding,
coeff=coeff,
dtype=dtype,
)


bandpass_filter.__doc__ = bandpass_filter.__doc__.format(_common_filter_docs)
highpass_filter.__doc__ = highpass_filter.__doc__.format(_common_filter_docs)

Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/preprocessing/preprocessinglist.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
notch_filter,
HighpassFilterRecording,
highpass_filter,
causal_filter,
)
from .filter_gaussian import GaussianFilterRecording, gaussian_filter
from .normalize_scale import (
Expand Down
137 changes: 136 additions & 1 deletion src/spikeinterface/preprocessing/tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,140 @@
from spikeinterface.core import generate_recording
from spikeinterface import NumpyRecording, set_global_tmp_folder

from spikeinterface.preprocessing import filter, bandpass_filter, notch_filter
from spikeinterface.preprocessing import filter, bandpass_filter, notch_filter, causal_filter


class TestCausalFilter:
"""
The only thing that is not tested (JZ, as of 23/07/2024) is the
propagation of margin kwargs, these are general filter params
and can be tested in an upcoming PR.
"""

@pytest.fixture(scope="session")
def recording_and_data(self):
recording = generate_recording(durations=[1])
raw_data = recording.get_traces()

return (recording, raw_data)

def test_causal_filter_main_kwargs(self, recording_and_data):
"""
Perform a test that expected output is returned under change
of all key filter-related kwargs. First run the filter in
the forward direction with options and compare it
to the expected output from scipy.
Next, change every filter-related kwarg and set in the backwards
direction. Again check it matches expected scipy output.
"""
from scipy.signal import lfilter, sosfilt

recording, raw_data = recording_and_data

# First, check in the forward direction with
# the default set of kwargs
options = self._get_filter_options()

sos = self._run_iirfilter(options, recording)

test_data = sosfilt(sos, raw_data, axis=0)
test_data.astype(recording.dtype)

filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces()

assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6)

# Then, change all kwargs to ensure they are propagated
# and check the backwards version.
options["band"] = [671]
options["btype"] = "highpass"
options["filter_order"] = 8
options["ftype"] = "bessel"
options["filter_mode"] = "ba"
options["dtype"] = np.float16

b, a = self._run_iirfilter(options, recording)

flip_raw = np.flip(raw_data, axis=0)
test_data = lfilter(b, a, flip_raw, axis=0)
test_data = np.flip(test_data, axis=0)
test_data = test_data.astype(options["dtype"])

filt_data = causal_filter(recording, direction="backward", **options, margin_ms=0).get_traces()

assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6)

def test_causal_filter_custom_coeff(self, recording_and_data):
"""
A different path is taken when custom coeff is selected.
Therefore, explicitly test the expected outputs are obtained
when passing custom coeff, under the "ba" and "sos" conditions.
"""
from scipy.signal import lfilter, sosfilt

recording, raw_data = recording_and_data

options = self._get_filter_options()
options["filter_mode"] = "ba"
options["coeff"] = (np.array([0.1, 0.2, 0.3]), np.array([0.4, 0.5, 0.6]))

# Check the custom coeff are propagated in both modes.
# First, in "ba" mode
test_data = lfilter(options["coeff"][0], options["coeff"][1], raw_data, axis=0)
test_data = test_data.astype(recording.get_dtype())

filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces()

assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6, equal_nan=True)

# Next, in "sos" mode
options["filter_mode"] = "sos"
options["coeff"] = np.ones((2, 6))

test_data = sosfilt(options["coeff"], raw_data, axis=0)
test_data = test_data.astype(recording.get_dtype())

filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces()

assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6, equal_nan=True)

def test_causal_kwarg_error_raised(self, recording_and_data):
"""
Test that passing the "forward-backward" direction results in
an error. It is is critical this error is raised,
otherwise the filter will no longer be causal.
"""
recording, raw_data = recording_and_data

with pytest.raises(BaseException) as e:
filt_data = causal_filter(recording, direction="forward-backward")

def _run_iirfilter(self, options, recording):
"""
Convenience function to convert Si kwarg
names to Scipy.
"""
from scipy.signal import iirfilter

return iirfilter(
N=options["filter_order"],
Wn=options["band"],
btype=options["btype"],
ftype=options["ftype"],
output=options["filter_mode"],
fs=recording.get_sampling_frequency(),
)

def _get_filter_options(self):
return {
"band": [300.0, 6000.0],
"btype": "bandpass",
"filter_order": 5,
"ftype": "butter",
"filter_mode": "sos",
"coeff": None,
}


def test_filter():
Expand All @@ -28,6 +161,8 @@ def test_filter():
# other filtering types
rec3 = filter(rec, band=500.0, btype="highpass", filter_mode="ba", filter_order=2)
rec4 = notch_filter(rec, freq=3000, q=30, margin_ms=5.0)
rec5 = causal_filter(rec, direction="forward")
rec6 = causal_filter(rec, direction="backward")

# filter from coefficients
from scipy.signal import iirfilter
Expand Down

0 comments on commit be0fd8a

Please sign in to comment.