Skip to content

Commit

Permalink
Merge pull request #3474 from yger/mf_sparse
Browse files Browse the repository at this point in the history
Sparsify the weights
  • Loading branch information
samuelgarcia authored Oct 15, 2024
2 parents 0ae32e7 + b9f2cc8 commit 9fa21c9
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/spikeinterface/sortingcomponents/peak_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,7 @@ def __init__(
weight_method={},
):
PeakDetector.__init__(self, recording, return_output=True)
from scipy.sparse import csr_matrix

if not HAVE_NUMBA:
raise ModuleNotFoundError('matched_filtering" needs numba which is not installed')
Expand Down Expand Up @@ -664,7 +665,7 @@ def __init__(
self.num_templates *= 2

self.weights = self.weights.reshape(self.num_templates * self.num_z_factors, -1)

self.weights = csr_matrix(self.weights)
random_data = get_random_data_chunks(recording, return_scaled=False, **random_chunk_kwargs)
conv_random_data = self.get_convolved_traces(random_data)
medians = np.median(conv_random_data, axis=1)
Expand Down Expand Up @@ -734,10 +735,10 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
return (local_peaks,)

def get_convolved_traces(self, traces):
import scipy.signal
from scipy.signal import oaconvolve

tmp = scipy.signal.oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid")
scalar_products = np.dot(self.weights, tmp)
tmp = oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid")
scalar_products = self.weights.dot(tmp)
return scalar_products


Expand Down

0 comments on commit 9fa21c9

Please sign in to comment.