From 7d9c0753fb3c59577dd244d3c9bce1d6272015e6 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 4 Oct 2023 14:11:54 +0200 Subject: [PATCH 1/3] WIP --- src/spikeinterface/preprocessing/remove_artifacts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 7e84822c61..8e72b96c6d 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -1,4 +1,5 @@ import numpy as np +import scipy from spikeinterface.core.core_tools import define_function_from_class From e97005aa5e94328cee3d97097b98d6a7289ee437 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 4 Oct 2023 16:21:54 +0200 Subject: [PATCH 2/3] Patch for scipy --- src/spikeinterface/preprocessing/remove_artifacts.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 8e72b96c6d..1746b23941 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -1,5 +1,4 @@ import numpy as np -import scipy from spikeinterface.core.core_tools import define_function_from_class @@ -108,8 +107,6 @@ def __init__( time_jitter=0, waveforms_kwargs={"allow_unfiltered": True, "mode": "memory"}, ): - import scipy.interpolate - available_modes = ("zeros", "linear", "cubic", "average", "median") num_seg = recording.get_num_segments() @@ -237,7 +234,6 @@ def __init__( time_pad, sparsity, ): - import scipy.interpolate BasePreprocessorSegment.__init__(self, parent_recording_segment) @@ -255,6 +251,8 @@ def __init__( self.sparsity = sparsity def get_traces(self, start_frame, end_frame, channel_indices): + + if self.mode in ["average", "median"]: traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) else: @@ -286,6 +284,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): elif trig + pad[1] >= end_frame - start_frame: traces[trig - pad[0] :, :] = 0 elif self.mode in ["linear", "cubic"]: + import scipy.interpolate for trig in triggers: if pad is None: pre_data_end_idx = trig - 1 From 2a5e37c83054999514ccacd45b3c81d1865bc196 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Oct 2023 14:23:26 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/remove_artifacts.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 1746b23941..1eafa48a0b 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -234,7 +234,6 @@ def __init__( time_pad, sparsity, ): - BasePreprocessorSegment.__init__(self, parent_recording_segment) self.triggers = np.asarray(triggers, dtype="int64") @@ -251,8 +250,6 @@ def __init__( self.sparsity = sparsity def get_traces(self, start_frame, end_frame, channel_indices): - - if self.mode in ["average", "median"]: traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) else: @@ -285,6 +282,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces[trig - pad[0] :, :] = 0 elif self.mode in ["linear", "cubic"]: import scipy.interpolate + for trig in triggers: if pad is None: pre_data_end_idx = trig - 1