diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 2961f11981..d2d1afaafb 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -631,7 +631,7 @@ def __init__( weight_method={}, ): PeakDetector.__init__(self, recording, return_output=True) - import scipy + from scipy.sparse import csr_matrix if not HAVE_NUMBA: raise ModuleNotFoundError('matched_filtering" needs numba which is not installed') @@ -665,7 +665,7 @@ def __init__( self.num_templates *= 2 self.weights = self.weights.reshape(self.num_templates * self.num_z_factors, -1) - self.weights = scipy.sparse.csr_matrix(self.weights) + self.weights = csr_matrix(self.weights) random_data = get_random_data_chunks(recording, return_scaled=False, **random_chunk_kwargs) conv_random_data = self.get_convolved_traces(random_data) medians = np.median(conv_random_data, axis=1) @@ -735,9 +735,8 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): return (local_peaks,) def get_convolved_traces(self, traces): - import scipy.signal - - tmp = scipy.signal.oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid") + from scipy.signal import oaconvolve + tmp = oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid") scalar_products = self.weights.dot(tmp) return scalar_products