diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index b654b965ff..1602fa93b5 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -290,6 +290,7 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): self._kwargs = dict(recording=recording, freq=freq, q=q, margin_ms=margin_ms, dtype=dtype.str) + class Causalfilter(BasePreprocessor): """ filter class based on: @@ -346,7 +347,6 @@ def __init__( add_reflect_padding=False, coeff=None, dtype=None, - ): import scipy.signal @@ -378,9 +378,14 @@ def __init__( for parent_segment in recording._recording_segments: self.add_recording_segment( CausalFilterRecordingSegment( - parent_segment, filter_coeff, filter_mode, margin, dtype, - direction, add_reflect_padding=add_reflect_padding - ) + parent_segment, + filter_coeff, + filter_mode, + margin, + dtype, + direction, + add_reflect_padding=add_reflect_padding, + ) ) self._kwargs = dict( @@ -399,7 +404,9 @@ def __init__( class CausalFilterRecordingSegment(BasePreprocessorSegment): - def __init__(self, parent_recording_segment, coeff, filter_mode, margin, dtype, direction, add_reflect_padding=False): + def __init__( + self, parent_recording_segment, coeff, filter_mode, margin, dtype, direction, add_reflect_padding=False + ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.coeff = coeff self.filter_mode = filter_mode @@ -429,13 +436,15 @@ def get_traces(self, start_frame, end_frame, channel_indices): if self.direction == "forward": filtered_traces = scipy.signal.sosfilt(self.coeff, traces_chunk, axis=0) elif self.direction == "backward": - filtered_traces = np.flip(scipy.signal.sosfilt(self.coeff, np.flip(traces_chunk, axis = 0), axis = 0), axis = 0) + filtered_traces = np.flip( + scipy.signal.sosfilt(self.coeff, np.flip(traces_chunk, axis=0), axis=0), axis=0 + ) elif self.filter_mode == "ba": b, a = self.coeff if self.direction == "forward": filtered_traces = scipy.signal.lfilt(b, a, traces_chunk, axis=0) elif self.direction == "backward": - filtered_traces = np.flip(scipy.signal.lfilt(b, a, np.flip(traces_chunk, axis = 0), axis=0), axis = 0) + filtered_traces = np.flip(scipy.signal.lfilt(b, a, np.flip(traces_chunk, axis=0), axis=0), axis=0) if right_margin > 0: filtered_traces = filtered_traces[left_margin:-right_margin, :] else: @@ -446,6 +455,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): return filtered_traces.astype(self.dtype) + # functions for API filter = define_function_from_class(source_class=FilterRecording, name="filter") bandpass_filter = define_function_from_class(source_class=BandpassFilterRecording, name="bandpass_filter")