Matched filtering to enhance peak detection #2259

Merged · 152 commits · merged Apr 30, 2024

Commits
d43103b
Adding matched filtering detection
yger Oct 10, 2023
dbe80f8
SC2 with matched filtering
yger Oct 10, 2023
f472cfc
WIP
yger Oct 10, 2023
6b58f87
WIP
yger Oct 11, 2023
b7cb4b4
WIP
yger Oct 11, 2023
cc9dacb
Merge branch 'SpikeInterface:main' into matched_filtering
yger Oct 22, 2023
6f2e036
Merge branch 'SpikeInterface:main' into matched_filtering
yger Oct 24, 2023
3242026
WIP
yger Oct 26, 2023
e8a9875
Merge branch 'SpikeInterface:main' into matched_filtering
yger Oct 26, 2023
4a68ad3
WIP
yger Oct 26, 2023
8949154
Merge branch 'matched_filtering' of github.com:yger/spikeinterface in…
yger Oct 26, 2023
3c45af9
WIP
yger Oct 27, 2023
b49a10f
WIP
yger Oct 27, 2023
85efeee
WIP
yger Nov 6, 2023
691991e
WIP
yger Nov 7, 2023
ea9a97c
Cleaning with ptp maps computed while detecting peaks
yger Nov 7, 2023
cdb6214
Merge branch 'SpikeInterface:main' into matched_filtering
yger Nov 9, 2023
b78bf5d
WIP
yger Nov 9, 2023
3a2d2f7
Merge branch 'SpikeInterface:main' into matched_filtering
yger Nov 9, 2023
61349c6
WIP
yger Nov 23, 2023
588c511
WIP
yger Nov 23, 2023
844e5b3
WIP
yger Nov 23, 2023
38bf2d1
WIP
yger Nov 23, 2023
92b8ca1
Merge branch 'SpikeInterface:main' into matched_filtering
yger Nov 24, 2023
5c80044
Adding the matched filtering for peak detection
yger Nov 24, 2023
8ac71c1
Better with only one sigma
yger Nov 24, 2023
e6a5bf7
Adding a test for matched filtering
yger Nov 24, 2023
e0b3a7a
Some assertions
yger Nov 24, 2023
34d9e77
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2023
6b36b71
WIP
yger Nov 24, 2023
55417ca
WIp
yger Nov 24, 2023
e7db8d9
Cleaning
yger Nov 24, 2023
fc35e5a
Cleaning'
yger Nov 24, 2023
2faa212
Fixes
yger Nov 24, 2023
7a20fb4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2023
7101f0a
Merge branch 'main' into matched_filtering
yger Nov 24, 2023
4ff592e
default n_jobs
yger Nov 24, 2023
83154b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2023
c2f29cd
Merge branch 'SpikeInterface:main' into matched_filtering
yger Nov 27, 2023
100e944
Merge branch 'SpikeInterface:main' into matched_filtering
yger Nov 28, 2023
7a1832a
Merge branch 'SpikeInterface:main' into matched_filtering
yger Nov 28, 2023
803c571
Merge branch 'SpikeInterface:main' into matched_filtering
yger Dec 4, 2023
f200a12
Merge branch 'SpikeInterface:main' into matched_filtering
yger Dec 5, 2023
1e0bcdc
Merge branch 'grid_depth' of github.com:yger/spikeinterface into matc…
yger Dec 5, 2023
ff87c59
Multiple depths
yger Dec 5, 2023
18888ec
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 5, 2023
2f4ad4e
Calibrating
yger Dec 5, 2023
ddf2da0
Merge branch 'matched_filtering' of github.com:yger/spikeinterface in…
yger Dec 5, 2023
7c6fafc
WIP
yger Dec 11, 2023
7641b99
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2023
96d34cd
Merge branch 'SpikeInterface:main' into matched_filtering
yger Dec 11, 2023
dee1018
Simplifications
yger Dec 12, 2023
65caace
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 12, 2023
d2ef11d
Export also the depths where peak has been found
yger Dec 12, 2023
077d81e
Merge branch 'matched_filtering' of github.com:yger/spikeinterface in…
yger Dec 12, 2023
8c5f1e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 12, 2023
560b3fc
Simplification
yger Dec 12, 2023
5e6c156
Merge branch 'matched_filtering' of github.com:yger/spikeinterface in…
yger Dec 12, 2023
4424dc1
Should we estimate depth as weighted average?
yger Dec 12, 2023
09d1d61
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 12, 2023
4b3fe94
Same params
yger Dec 12, 2023
f34786a
Merge branch 'SpikeInterface:main' into matched_filtering
yger Dec 13, 2023
5f0a93a
Fixing the shape of the templates
yger Dec 13, 2023
bb699bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2023
305dfac
Merge branch 'SpikeInterface:main' into matched_filtering
yger Dec 13, 2023
f78e04a
Merge branch 'SpikeInterface:main' into matched_filtering
yger Dec 13, 2023
5ca0523
Merge branch 'grid_depth' of github.com:yger/spikeinterface into matc…
yger Dec 13, 2023
b1336c3
Merge branch 'SpikeInterface:main' into matched_filtering
yger Dec 14, 2023
a858904
Merge branch 'grid_depth' of github.com:yger/spikeinterface into matc…
yger Dec 14, 2023
301f59b
Merge branch 'SpikeInterface:main' into matched_filtering
yger Dec 15, 2023
1451e3a
Merge branch 'matched_filtering' of github.com:yger/spikeinterface in…
yger Dec 15, 2023
166a07e
Merge branch 'grid_depth' of github.com:yger/spikeinterface into matc…
yger Dec 15, 2023
b558299
WIP
yger Dec 15, 2023
ec76bee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 15, 2023
f3779e3
Fixes
yger Dec 15, 2023
4fac47d
WIP
yger Dec 15, 2023
a4ee1ee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 15, 2023
2bb8c03
Bug in prototype
yger Dec 15, 2023
1391215
Merge branch 'matched_filtering' of github.com:yger/spikeinterface in…
yger Dec 15, 2023
b77c132
Merge branch 'SpikeInterface:main' into matched_filtering
yger Dec 15, 2023
4b4bf42
Merge branch 'SpikeInterface:main' into matched_filtering
yger Dec 18, 2023
8092576
Merge branch 'matched_filtering' of github.com:yger/spikeinterface in…
yger Dec 18, 2023
60c15fc
WIP
yger Dec 18, 2023
a57b77a
Merge branch 'SpikeInterface:main' into matched_filtering
yger Dec 19, 2023
3d23d98
Merge branch 'SpikeInterface:main' into matched_filtering
yger Dec 19, 2023
8e929bb
Merge branch 'matched_filtering' of github.com:yger/spikeinterface in…
yger Dec 20, 2023
dd14e9f
Merge branch 'grid_depth' of github.com:yger/spikeinterface into matc…
yger Dec 20, 2023
4e72653
Merge branch 'SpikeInterface:main' into matched_filtering
yger Dec 20, 2023
3dbd312
Merge branch 'matched_filtering' of github.com:yger/spikeinterface in…
yger Dec 20, 2023
ff4e835
Merge branch 'SpikeInterface:main' into matched_filtering
yger Dec 20, 2023
764eb2d
Merge branch 'matched_filtering' of github.com:yger/spikeinterface in…
yger Dec 20, 2023
304e2bd
Merge branch 'grid_depth' of github.com:yger/spikeinterface into matc…
yger Dec 20, 2023
8a1ca30
Merge branch 'grid_depth' of github.com:yger/spikeinterface into matc…
yger Dec 21, 2023
6d30edf
Merge branch 'matched_filtering' of github.com:yger/spikeinterface in…
yger Dec 21, 2023
9544348
Merge branch 'grid_depth' of github.com:yger/spikeinterface into matc…
yger Dec 21, 2023
2f5fd7d
WIP
yger Dec 21, 2023
59f9c95
Merge branch 'SpikeInterface:main' into matched_filtering
yger Jan 2, 2024
3ba776f
Merge branch 'matched_filtering' of github.com:yger/spikeinterface in…
yger Jan 2, 2024
c768f72
Merge branch 'grid_depth' of github.com:yger/spikeinterface into matc…
yger Jan 2, 2024
6746621
New arguments
yger Jan 2, 2024
d49d48b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2024
085014e
WIP
yger Jan 2, 2024
5c7e482
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2024
72ac994
Merge branch 'SpikeInterface:main' into matched_filtering
yger Jan 3, 2024
f95c7f0
Merge branch 'grid_depth' of github.com:yger/spikeinterface into matc…
yger Jan 5, 2024
67f868a
Merge branch 'SpikeInterface:main' into matched_filtering
yger Jan 12, 2024
e48cb5d
Merge branch 'SpikeInterface:main' into matched_filtering
yger Jan 12, 2024
d6fb1a1
Merge branch 'SpikeInterface:main' into matched_filtering
yger Jan 12, 2024
98ec656
Merge branch 'SpikeInterface:main' into matched_filtering
yger Jan 15, 2024
d6d1648
Merge branch 'SpikeInterface:main' into matched_filtering
yger Jan 19, 2024
375831f
WIP
yger Jan 19, 2024
c233f73
Merge branch 'main' into matched_filtering
yger Jan 24, 2024
7d87668
Merge branch 'grid_depth' of github.com:yger/spikeinterface into matc…
yger Jan 24, 2024
4d1ca87
Merge branch 'main' into matched_filtering
yger Jan 30, 2024
368685b
Merge branch 'SpikeInterface:main' into matched_filtering
yger Jan 30, 2024
72a72c4
New signature to get prototype
yger Jan 30, 2024
f971f2e
Merge branch 'SpikeInterface:main' into matched_filtering
yger Feb 1, 2024
96accaf
Merge branch 'SpikeInterface:main' into matched_filtering
yger Feb 2, 2024
6013c53
Merge branch 'SpikeInterface:main' into matched_filtering
yger Feb 2, 2024
73a00f2
Merge branch 'SpikeInterface:main' into matched_filtering
yger Feb 5, 2024
6455e2a
Merge branch 'SpikeInterface:main' into matched_filtering
yger Feb 6, 2024
b937cab
Merge branch 'SpikeInterface:main' into matched_filtering
yger Feb 16, 2024
1e635a4
WIP
yger Mar 11, 2024
a59ae17
WIP
yger Mar 11, 2024
f4801aa
WIP
yger Mar 14, 2024
d692c18
Naming
yger Mar 14, 2024
0bfdaf1
Names
yger Mar 14, 2024
39a0990
Merge branch 'SpikeInterface:main' into matched_filtering
yger Mar 15, 2024
021fdaf
Default params to speed up as almost no cost
yger Mar 15, 2024
f5e332d
Merge branch 'main' into matched_filtering
yger Mar 15, 2024
3dabe45
Merge branch 'main' of github.com:yger/spikeinterface into matched_fi…
yger Mar 20, 2024
a4a962f
Merge branch 'SpikeInterface:main' into matched_filtering
yger Mar 27, 2024
05483fb
Merge branch 'main' into matched_filtering
yger Mar 29, 2024
ce5da53
patch
yger Mar 29, 2024
c59f8ea
Merge branch 'main' into matched_filtering
yger Mar 29, 2024
59cf286
wrong merge
yger Mar 29, 2024
e3874f2
Merge branch 'SpikeInterface:main' into matched_filtering
yger Apr 1, 2024
de83198
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Apr 3, 2024
871b564
Merge branch 'SpikeInterface:main' into matched_filtering
yger Apr 4, 2024
a5b9609
Merge branch 'main' into matched_filtering
yger Apr 9, 2024
4a5c077
Merge branch 'main' into matched_filtering
yger Apr 10, 2024
c274ea0
Merge branch 'main' into matched_filtering
yger Apr 11, 2024
ba0a3cb
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Apr 12, 2024
cecfee6
Merge branch 'main' of https://github.com/SpikeInterface/spikeinterfa…
yger Apr 15, 2024
e7c2451
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Apr 17, 2024
747b2ee
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Apr 19, 2024
1a3d48d
Merge branch 'SpikeInterface:main' into matched_filtering
yger Apr 23, 2024
874b4a9
Refactor DetectPeakMatchedFiltering into one class
samuelgarcia Apr 26, 2024
85d29e9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2024
32186da
Merge branch 'main' into matched_filtering
samuelgarcia Apr 26, 2024
a210ba9
Add convolution margin for "matched_filtering" peak detection
samuelgarcia Apr 30, 2024
2ab25d6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 30, 2024
30 changes: 28 additions & 2 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -16,6 +16,8 @@
from spikeinterface.core.sparsity import compute_sparsity
from spikeinterface.sortingcomponents.tools import remove_empty_templates

from spikeinterface.sortingcomponents.tools import get_prototype_spike

try:
import hdbscan

@@ -42,6 +44,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
"clustering": {"legacy": False},
"matching": {"method": "circus-omp-svd"},
"apply_preprocessing": True,
"matched_filtering": False,
"cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
"multi_units_only": False,
"job_kwargs": {"n_jobs": 0.8},
@@ -99,6 +102,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

sampling_frequency = recording.get_sampling_frequency()
num_channels = recording.get_num_channels()
ms_before = params["general"].get("ms_before", 2)
ms_after = params["general"].get("ms_after", 2)
radius_um = params["general"].get("radius_um", 100)

## First, we are filtering the data
filtering_params = params["filtering"].copy()
@@ -126,11 +132,31 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
radius_um = params["general"].get("radius_um", 100)
if "radius_um" not in detection_params:
detection_params["radius_um"] = radius_um

if "exclude_sweep_ms" not in detection_params:
detection_params["exclude_sweep_ms"] = max(params["general"]["ms_before"], params["general"]["ms_after"])
detection_params["exclude_sweep_ms"] = max(ms_before, ms_after)
if "radius_um" not in detection_params:
detection_params["radius_um"] = radius_um
detection_params["noise_levels"] = noise_levels

peaks = detect_peaks(recording_f, method="locally_exclusive", **detection_params)
fs = recording_f.get_sampling_frequency()
nbefore = int(ms_before * fs / 1000.0)
nafter = int(ms_after * fs / 1000.0)

peaks = detect_peaks(recording_f, "locally_exclusive", **detection_params)

if params["matched_filtering"]:
prototype = get_prototype_spike(recording_f, peaks, ms_before, ms_after, **job_kwargs)
detection_params["prototype"] = prototype

matching_job_params = job_kwargs.copy()
for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in matching_job_params:
matching_job_params.pop(value)

matching_job_params["chunk_duration"] = "100ms"

peaks = detect_peaks(recording_f, "matched_filtering", **detection_params, **matching_job_params)

if verbose:
print("We found %d peaks in total" % len(peaks))
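The new flow in spyking_circus2.py runs a first `locally_exclusive` detection, builds a spike prototype from those peaks via `get_prototype_spike`, and then re-runs detection with the `matched_filtering` method on fixed 100 ms chunks. The job-kwargs cleanup step from the hunk above can be sketched in plain Python (the function name and example values here are hypothetical, not part of the library):

```python
def prepare_matched_filtering_job_kwargs(job_kwargs):
    """Drop chunking arguments that would conflict with the fixed 100 ms
    chunks used by the matched-filtering pass, mirroring the diff above."""
    matching_job_params = dict(job_kwargs)
    for key in ("chunk_size", "chunk_memory", "total_memory", "chunk_duration"):
        # pop(key, None) is equivalent to the diff's membership check + pop
        matching_job_params.pop(key, None)
    matching_job_params["chunk_duration"] = "100ms"
    return matching_job_params

print(prepare_matched_filtering_job_kwargs({"n_jobs": 4, "chunk_size": 30000}))
# {'n_jobs': 4, 'chunk_duration': '100ms'}
```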
1 change: 1 addition & 0 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -31,6 +31,7 @@
ExtractSparseWaveforms,
PeakRetriever,
)

from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel


197 changes: 196 additions & 1 deletion src/spikeinterface/sortingcomponents/peak_detection.py
@@ -14,7 +14,7 @@
split_job_kwargs,
fix_job_kwargs,
)
from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances
from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances, get_random_data_chunks

from spikeinterface.core.baserecording import BaseRecording
from spikeinterface.core.node_pipeline import (
@@ -25,6 +25,7 @@
base_peak_dtype,
)

from spikeinterface.postprocessing.unit_localization import get_convolution_weights
from ..core import get_chunk_with_margin

from .tools import make_multi_method_doc
@@ -586,6 +587,174 @@ def detect_peaks(cls, traces, peak_sign, abs_thresholds, exclude_sweep_size, nei
return peak_sample_ind, peak_chan_ind


class DetectPeakMatchedFiltering(PeakDetector):
"""Detect peaks using the 'matched_filtering' method."""

name = "matched_filtering"
engine = "numba"
preferred_mp_context = None
params_doc = (
DetectPeakByChannel.params_doc
+ """
radius_um: float
The radius to use to select neighbour channels for locally exclusive detection.
prototype: array
The canonical waveform of action potentials
rank : int (default 1)
The rank for SVD convolution of spatiotemporal templates with the traces
weight_method: dict
Parameters passed to the get_convolution_weights() function, controlling
how positions are estimated. One argument is mode, which can be either
gaussian_2d (Kilosort-like) or exponential_3d (default)
"""
)

def __init__(
self,
recording,
prototype,
peak_sign="neg",
detect_threshold=5,
exclude_sweep_ms=0.1,
radius_um=50,
rank=1,
noise_levels=None,
random_chunk_kwargs={"num_chunks_per_segment": 5},
weight_method={},
):
PeakDetector.__init__(self, recording, return_output=True)

if not HAVE_NUMBA:
raise ModuleNotFoundError('"matched_filtering" needs numba which is not installed')

self.exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0)
channel_distance = get_channel_distances(recording)
self.neighbours_mask = channel_distance <= radius_um

self.conv_margin = prototype.shape[0]

assert peak_sign in ("both", "neg", "pos")
idx = np.argmax(np.abs(prototype))
if peak_sign == "neg":
assert prototype[idx] < 0, "Prototype should have a negative peak"
peak_sign = "pos"
elif peak_sign == "pos":
assert prototype[idx] > 0, "Prototype should have a positive peak"
elif peak_sign == "both":
raise NotImplementedError("Matched filtering not working with peak_sign=both yet!")

self.peak_sign = peak_sign
contact_locations = recording.get_channel_locations()
dist = np.linalg.norm(contact_locations[:, np.newaxis] - contact_locations[np.newaxis, :], axis=2)
weights, self.z_factors = get_convolution_weights(dist, **weight_method)

num_channels = recording.get_num_channels()
num_templates = num_channels * len(self.z_factors)
weights = weights.reshape(num_templates, -1)

templates = weights[:, None, :] * prototype[None, :, None]
templates -= templates.mean(axis=(1, 2))[:, None, None]
temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False)
temporal = temporal[:, :, :rank]
singular = singular[:, :rank]
spatial = spatial[:, :rank, :]
templates = np.matmul(temporal * singular[:, np.newaxis, :], spatial)
norms = np.linalg.norm(templates, axis=(1, 2))
del templates

temporal /= norms[:, np.newaxis, np.newaxis]
temporal = np.flip(temporal, axis=1)
spatial = np.moveaxis(spatial, [0, 1, 2], [1, 0, 2])
temporal = np.moveaxis(temporal, [0, 1, 2], [1, 2, 0])
singular = singular.T[:, :, np.newaxis]

self.temporal = temporal
self.spatial = spatial
self.singular = singular

random_data = get_random_data_chunks(recording, return_scaled=False, **random_chunk_kwargs)
conv_random_data = self.get_convolved_traces(random_data, temporal, spatial, singular)
medians = np.median(conv_random_data, axis=1)
medians = medians[:, None]
noise_levels = np.median(np.abs(conv_random_data - medians), axis=1) / 0.6744897501960817
self.abs_thresholds = noise_levels * detect_threshold

self._dtype = np.dtype(base_peak_dtype + [("z", "float32")])

def get_dtype(self):
return self._dtype

def get_trace_margin(self):
return self.exclude_sweep_size + self.conv_margin

def compute(self, traces, start_frame, end_frame, segment_index, max_margin):

# peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask, temporal, spatial, singular, z_factors = self.args

assert HAVE_NUMBA, "You need to install numba"
conv_traces = self.get_convolved_traces(traces, self.temporal, self.spatial, self.singular)
conv_traces /= self.abs_thresholds[:, None]
conv_traces = conv_traces[:, self.conv_margin : -self.conv_margin]
traces_center = conv_traces[:, self.exclude_sweep_size : -self.exclude_sweep_size]
num_z_factors = len(self.z_factors)
num_channels = conv_traces.shape[0] // num_z_factors

peak_mask = traces_center > 1
peak_mask = _numba_detect_peak_matched_filtering(
conv_traces,
traces_center,
peak_mask,
self.exclude_sweep_size,
self.abs_thresholds,
self.peak_sign,
self.neighbours_mask,
num_channels,
)

# Find peaks and correct for time shift
peak_chan_ind, peak_sample_ind = np.nonzero(peak_mask)

# If we do not want to estimate the z accurately
z = self.z_factors[peak_chan_ind // num_channels]
peak_chan_ind = peak_chan_ind % num_channels

# If we want to estimate z
# peak_chan_ind = peak_chan_ind % num_channels
# z = np.zeros(len(peak_sample_ind), dtype=np.float32)
# for count in range(len(peak_chan_ind)):
# channel = peak_chan_ind[count]
# peak = peak_sample_ind[count]
# data = traces[channel::num_channels, peak]
# z[count] = np.dot(data, z_factors)/data.sum()

if peak_sample_ind.size == 0 or peak_chan_ind.size == 0:
return (np.zeros(0, dtype=self._dtype),)

peak_sample_ind += self.exclude_sweep_size + self.conv_margin

peak_amplitude = traces[peak_sample_ind, peak_chan_ind]
local_peaks = np.zeros(peak_sample_ind.size, dtype=self._dtype)
local_peaks["sample_index"] = peak_sample_ind
local_peaks["channel_index"] = peak_chan_ind
local_peaks["amplitude"] = peak_amplitude
local_peaks["segment_index"] = segment_index
local_peaks["z"] = z

# return is always a tuple
return (local_peaks,)

def get_convolved_traces(self, traces, temporal, spatial, singular):
import scipy.signal

num_timesteps, num_templates = len(traces), temporal.shape[1]
scalar_products = np.zeros((num_templates, num_timesteps), dtype=np.float32)
spatially_filtered_data = np.matmul(spatial, traces.T[np.newaxis, :, :])
scaled_filtered_data = spatially_filtered_data * singular
objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, temporal, axes=2, mode="same")
scalar_products += np.sum(objective_by_rank, axis=0)
return scalar_products


class DetectPeakLocallyExclusiveTorch(PeakDetectorWrapper):
"""Detect peaks using the "locally exclusive" method with pytorch."""

@@ -705,6 +874,31 @@ def _numba_detect_peak_neg(
break
return peak_mask

@numba.jit(nopython=True, parallel=False)
def _numba_detect_peak_matched_filtering(
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_channels
):
num_chans = traces_center.shape[0]
for chan_ind in range(num_chans):
for s in range(peak_mask.shape[1]):
if not peak_mask[chan_ind, s]:
continue
for neighbour in range(num_chans):
if not neighbours_mask[chan_ind % num_channels, neighbour % num_channels]:
continue
for i in range(exclude_sweep_size):
if chan_ind != neighbour:
peak_mask[chan_ind, s] &= traces_center[chan_ind, s] >= traces_center[neighbour, s]
peak_mask[chan_ind, s] &= traces_center[chan_ind, s] > traces[neighbour, s + i]
peak_mask[chan_ind, s] &= (
traces_center[chan_ind, s] >= traces[neighbour, exclude_sweep_size + s + i + 1]
)
if not peak_mask[chan_ind, s]:
break
if not peak_mask[chan_ind, s]:
break
return peak_mask


if HAVE_TORCH:

@@ -1089,6 +1283,7 @@ def detect_peak(self, traces):
DetectPeakLocallyExclusiveOpenCL,
DetectPeakByChannelTorch,
DetectPeakLocallyExclusiveTorch,
DetectPeakMatchedFiltering,
]
detect_peak_methods = {m.name: m for m in _methods_list}
method_doc = make_multi_method_doc(_methods_list)
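The heart of `DetectPeakMatchedFiltering` is the rank-reduced convolution in `get_convolved_traces()` combined with the MAD-based thresholds computed in `__init__`. A self-contained sketch of that pipeline on random data (the sizes and the spatial weights are hypothetical stand-ins; the real class derives weights from channel distances via `get_convolution_weights()` and has `num_channels * len(z_factors)` templates):

```python
import numpy as np
import scipy.signal

rng = np.random.default_rng(0)
num_channels, num_samples, num_timesteps, rank = 4, 40, 1000, 1

# Hypothetical spatial weights (num_templates x num_channels) and a
# canonical spike prototype with a negative peak.
weights = rng.random((num_channels, num_channels))
prototype = -np.hanning(num_samples)

# Build one spatiotemporal template per channel and compress by SVD.
templates = weights[:, None, :] * prototype[None, :, None]
templates -= templates.mean(axis=(1, 2))[:, None, None]
temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False)
temporal, singular, spatial = temporal[:, :, :rank], singular[:, :rank], spatial[:, :rank, :]
norms = np.linalg.norm(np.matmul(temporal * singular[:, None, :], spatial), axis=(1, 2))
temporal = np.flip(temporal / norms[:, None, None], axis=1)  # flip = correlation

# Rearrange so rank is the leading axis, as in get_convolved_traces().
spatial = np.moveaxis(spatial, [0, 1, 2], [1, 0, 2])
temporal = np.moveaxis(temporal, [0, 1, 2], [1, 2, 0])
singular = singular.T[:, :, None]

traces = rng.standard_normal((num_timesteps, num_channels)).astype(np.float32)

# Separable matched filter: spatial projection, scale by singular values,
# temporal convolution, then sum over ranks.
filtered = np.matmul(spatial, traces.T[None, :, :]) * singular
conv = scipy.signal.oaconvolve(filtered, temporal, axes=2, mode="same").sum(axis=0)

# Robust per-template thresholds from the MAD of the convolved traces.
medians = np.median(conv, axis=1)[:, None]
mad = np.median(np.abs(conv - medians), axis=1) / 0.6744897501960817
print(conv.shape)  # (4, 1000)
```

Compressing each template to a low rank keeps the convolution cost proportional to `rank` rather than to the number of channels, which is why the SVD step precedes the filtering.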
46 changes: 44 additions & 2 deletions src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
@@ -22,6 +22,7 @@
)

from spikeinterface.core.node_pipeline import run_node_pipeline
from spikeinterface.sortingcomponents.tools import get_prototype_spike

from spikeinterface.sortingcomponents.tests.common import make_dataset

@@ -306,6 +307,42 @@ def test_detect_peaks_locally_exclusive(recording, job_kwargs, torch_job_kwargs)
assert len(peaks_local_numba) == len(peaks_local_cl)


def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs):
peaks_by_channel_np = detect_peaks(
recording, method="locally_exclusive", peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1, **job_kwargs
)

ms_before = 1.0
ms_after = 1.0
prototype = get_prototype_spike(recording, peaks_by_channel_np, ms_before, ms_after, **job_kwargs)

peaks_local_mf_filtering = detect_peaks(
recording,
method="matched_filtering",
peak_sign="neg",
detect_threshold=5,
exclude_sweep_ms=0.1,
prototype=prototype,
**job_kwargs,
)
assert len(peaks_local_mf_filtering) > len(peaks_by_channel_np)

DEBUG = False
if DEBUG:
import matplotlib.pyplot as plt

peaks = peaks_local_mf_filtering

sample_inds, chan_inds, amplitudes = peaks["sample_index"], peaks["channel_index"], peaks["amplitude"]
chan_offset = 500
traces = recording.get_traces().copy()
traces += np.arange(traces.shape[1])[None, :] * chan_offset
fig, ax = plt.subplots()
ax.plot(traces, color="k")
ax.scatter(sample_inds, chan_inds * chan_offset + amplitudes, color="r")
plt.show()


detection_classes = [
DetectPeakByChannel,
DetectPeakByChannelTorch,
@@ -465,7 +502,7 @@ def test_peak_detection_with_pipeline(recording, job_kwargs, torch_job_kwargs, t
tmp_path = Path(tempfile.mkdtemp())

job_kwargs_main = job_kwargs()
torch_job_kwargs_main = torch_job_kwargs(job_kwargs_main)
# torch_job_kwargs_main = torch_job_kwargs(job_kwargs_main)
# Create a temporary directory using the standard library
# tmp_dir_main = tempfile.mkdtemp()
# pca_model_folder_path_main = pca_model_folder_path(recording, job_kwargs_main, tmp_dir_main)
@@ -476,4 +513,9 @@
# )

# test_peak_sign_consistency(recording, torch_job_kwargs_main, DetectPeakLocallyExclusiveTorch)
test_peak_detection_with_pipeline(recording, job_kwargs_main, torch_job_kwargs_main, tmp_path)
# test_peak_detection_with_pipeline(recording, job_kwargs_main, torch_job_kwargs_main, tmp_path)

test_detect_peaks_locally_exclusive_matched_filtering(
recording,
job_kwargs_main,
)
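`get_prototype_spike()` itself is not shown in this diff; conceptually it averages peak-aligned snippets on each peak's channel into one canonical waveform. A hypothetical sketch of such a computation (the function and its normalization are illustrative assumptions, not the library's implementation):

```python
import numpy as np

def prototype_from_peaks(traces, sample_inds, chan_inds, nbefore, nafter):
    """Average peak-aligned single-channel snippets, each normalized so the
    waveform has unit magnitude at the alignment sample (hypothetical sketch)."""
    waveforms = []
    for s, c in zip(sample_inds, chan_inds):
        if s - nbefore < 0 or s + nafter > traces.shape[0]:
            continue  # skip peaks too close to the recording edges
        wf = traces[s - nbefore : s + nafter, c]
        waveforms.append(wf / np.abs(wf[nbefore]))
    return np.mean(waveforms, axis=0)

# Synthetic demo: three negative spikes injected on channel 3.
rng = np.random.default_rng(1)
traces = rng.standard_normal((2000, 8)).astype(np.float32)
spike = -np.hanning(20)
for s in (100, 500, 900):
    traces[s - 10 : s + 10, 3] += 8 * spike

proto = prototype_from_peaks(traces, np.array([100, 500, 900]), np.array([3, 3, 3]), 10, 10)
print(proto.shape)  # (20,)
```

The resulting prototype is exactly what the new `matched_filtering` method consumes through its `prototype` argument, as exercised by `test_detect_peaks_locally_exclusive_matched_filtering` above.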