From e97005aa5e94328cee3d97097b98d6a7289ee437 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 4 Oct 2023 16:21:54 +0200 Subject: [PATCH] Patch for scipy --- src/spikeinterface/preprocessing/remove_artifacts.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 8e72b96c6d..1746b23941 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -1,5 +1,4 @@ import numpy as np -import scipy from spikeinterface.core.core_tools import define_function_from_class @@ -108,8 +107,6 @@ def __init__( time_jitter=0, waveforms_kwargs={"allow_unfiltered": True, "mode": "memory"}, ): - import scipy.interpolate - available_modes = ("zeros", "linear", "cubic", "average", "median") num_seg = recording.get_num_segments() @@ -237,7 +234,6 @@ def __init__( time_pad, sparsity, ): - import scipy.interpolate BasePreprocessorSegment.__init__(self, parent_recording_segment) @@ -255,6 +251,8 @@ def __init__( self.sparsity = sparsity def get_traces(self, start_frame, end_frame, channel_indices): + + if self.mode in ["average", "median"]: traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) else: @@ -286,6 +284,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): elif trig + pad[1] >= end_frame - start_frame: traces[trig - pad[0] :, :] = 0 elif self.mode in ["linear", "cubic"]: + import scipy.interpolate for trig in triggers: if pad is None: pre_data_end_idx = trig - 1