Skip to content

Commit

Permalink
Imports
Browse files Browse the repository at this point in the history
  • Loading branch information
yger committed Oct 14, 2024
1 parent 5568e1a commit 1427816
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/spikeinterface/sortingcomponents/peak_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ def __init__(
weight_method={},
):
PeakDetector.__init__(self, recording, return_output=True)
import scipy
from scipy.sparse import csr_matrix

if not HAVE_NUMBA:
raise ModuleNotFoundError('matched_filtering" needs numba which is not installed')
Expand Down Expand Up @@ -665,7 +665,7 @@ def __init__(
self.num_templates *= 2

self.weights = self.weights.reshape(self.num_templates * self.num_z_factors, -1)
self.weights = scipy.sparse.csr_matrix(self.weights)
self.weights = csr_matrix(self.weights)
random_data = get_random_data_chunks(recording, return_scaled=False, **random_chunk_kwargs)
conv_random_data = self.get_convolved_traces(random_data)
medians = np.median(conv_random_data, axis=1)
Expand Down Expand Up @@ -735,9 +735,8 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
return (local_peaks,)

def get_convolved_traces(self, traces):
import scipy.signal

tmp = scipy.signal.oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid")
from scipy.signal import oaconvolve
tmp = oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid")
scalar_products = self.weights.dot(tmp)
return scalar_products

Expand Down

0 comments on commit 1427816

Please sign in to comment.