Matched filtering with both peak signs simultaneously
yger authored Oct 7, 2024
1 parent 80cc888 commit fbbb89f
Showing 2 changed files with 59 additions and 71 deletions.
103 changes: 36 additions & 67 deletions src/spikeinterface/sortingcomponents/peak_detection.py
@@ -631,47 +631,31 @@ def __init__(
         self.conv_margin = prototype.shape[0]

         assert peak_sign in ("both", "neg", "pos")
-        idx = np.argmax(np.abs(prototype))
+        self.nbefore = int(ms_before * recording.sampling_frequency / 1000)
         if peak_sign == "neg":
-            assert prototype[idx] < 0, "Prototype should have a negative peak"
+            assert prototype[self.nbefore] < 0, "Prototype should have a negative peak"
+            peak_sign = "pos"
         elif peak_sign == "pos":
-            assert prototype[idx] > 0, "Prototype should have a positive peak"
-        elif peak_sign == "both":
-            raise NotImplementedError("Matched filtering not working with peak_sign=both yet!")
+            assert prototype[self.nbefore] > 0, "Prototype should have a positive peak"

         self.peak_sign = peak_sign
-        self.nbefore = int(ms_before * recording.sampling_frequency / 1000)
+        self.prototype = np.flip(prototype) / np.linalg.norm(prototype)

         contact_locations = recording.get_channel_locations()
         dist = np.linalg.norm(contact_locations[:, np.newaxis] - contact_locations[np.newaxis, :], axis=2)
-        weights, self.z_factors = get_convolution_weights(dist, **weight_method)
+        self.weights, self.z_factors = get_convolution_weights(dist, **weight_method)
+        self.num_z_factors = len(self.z_factors)
+        self.num_channels = recording.get_num_channels()
+        self.num_templates = self.num_channels
+        if peak_sign == "both":
+            self.weights = np.hstack((self.weights, self.weights))
+            self.weights[:, self.num_templates :, :] *= -1
+            self.num_templates *= 2

-        num_channels = recording.get_num_channels()
-        num_templates = num_channels * len(self.z_factors)
-        weights = weights.reshape(num_templates, -1)
-
-        templates = weights[:, None, :] * prototype[None, :, None]
-        templates -= templates.mean(axis=(1, 2))[:, None, None]
-        temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False)
-        temporal = temporal[:, :, :rank]
-        singular = singular[:, :rank]
-        spatial = spatial[:, :rank, :]
-        templates = np.matmul(temporal * singular[:, np.newaxis, :], spatial)
-        norms = np.linalg.norm(templates, axis=(1, 2))
-        del templates
-
-        temporal /= norms[:, np.newaxis, np.newaxis]
-        temporal = np.flip(temporal, axis=1)
-        spatial = np.moveaxis(spatial, [0, 1, 2], [1, 0, 2])
-        temporal = np.moveaxis(temporal, [0, 1, 2], [1, 2, 0])
-        singular = singular.T[:, :, np.newaxis]
-
-        self.temporal = temporal
-        self.spatial = spatial
-        self.singular = singular
+        self.weights = self.weights.reshape(self.num_templates * self.num_z_factors, -1)

         random_data = get_random_data_chunks(recording, return_scaled=False, **random_chunk_kwargs)
-        conv_random_data = self.get_convolved_traces(random_data, temporal, spatial, singular)
+        conv_random_data = self.get_convolved_traces(random_data)
         medians = np.median(conv_random_data, axis=1)
         medians = medians[:, None]
         noise_levels = np.median(np.abs(conv_random_data - medians), axis=1) / 0.6744897501960817
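The thresholds used above are robust per-template noise levels: the median absolute deviation (MAD) of the convolved random chunks, rescaled by 0.6744897501960817 (the 0.75 quantile of the standard normal) so the estimate equals the standard deviation for Gaussian noise. A minimal standalone sketch of the estimator on made-up data:

import numpy as np

rng = np.random.default_rng(42)
noise = rng.normal(loc=0.0, scale=5.0, size=100_000)
noise[:100] += 200.0  # a few spike-like outliers

# MAD-based sigma: barely moved by the outliers that inflate np.std
median = np.median(noise)
mad_sigma = np.median(np.abs(noise - median)) / 0.6744897501960817
print(np.std(noise), mad_sigma)  # std is inflated, mad_sigma stays near 5.0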
@@ -688,16 +672,13 @@ def get_trace_margin(self):
     def compute(self, traces, start_frame, end_frame, segment_index, max_margin):

         assert HAVE_NUMBA, "You need to install numba"
-        conv_traces = self.get_convolved_traces(traces, self.temporal, self.spatial, self.singular)
+        conv_traces = self.get_convolved_traces(traces)
         conv_traces /= self.abs_thresholds[:, None]
         conv_traces = conv_traces[:, self.conv_margin : -self.conv_margin]
         traces_center = conv_traces[:, self.exclude_sweep_size : -self.exclude_sweep_size]

-        num_z_factors = len(self.z_factors)
-        num_templates = traces.shape[1]
-
-        traces_center = traces_center.reshape(num_z_factors, num_templates, traces_center.shape[1])
-        conv_traces = conv_traces.reshape(num_z_factors, num_templates, conv_traces.shape[1])
+        traces_center = traces_center.reshape(self.num_z_factors, self.num_templates, traces_center.shape[1])
+        conv_traces = conv_traces.reshape(self.num_z_factors, self.num_templates, conv_traces.shape[1])
         peak_mask = traces_center > 1

         peak_mask = _numba_detect_peak_matched_filtering(
@@ -708,11 +689,13 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
             self.abs_thresholds,
             self.peak_sign,
             self.neighbours_mask,
-            num_templates,
+            self.num_channels,
         )

         # Find peaks and correct for time shift
         z_ind, peak_chan_ind, peak_sample_ind = np.nonzero(peak_mask)
+        if self.peak_sign == "both":
+            peak_chan_ind = peak_chan_ind % self.num_channels

         # If we want to estimate z
         # peak_chan_ind = peak_chan_ind % num_channels
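Since peak_mask has shape (num_z_factors, num_templates, num_samples), np.nonzero returns one index array per axis. With peak_sign="both", template indices at or above num_channels address the sign-flipped copies of the weights, and the new modulo folds them back onto physical channels. A toy illustration, with all shapes made up:

import numpy as np

num_z_factors, num_channels = 2, 4
num_templates = 2 * num_channels  # doubled when peak_sign == "both"

peak_mask = np.zeros((num_z_factors, num_templates, 10), dtype=bool)
peak_mask[0, 2, 5] = True  # original-polarity detector on channel 2
peak_mask[1, 6, 8] = True  # sign-flipped detector, 6 % 4 -> channel 2

z_ind, peak_chan_ind, peak_sample_ind = np.nonzero(peak_mask)
peak_chan_ind = peak_chan_ind % num_channels
print(z_ind, peak_chan_ind, peak_sample_ind)  # [0 1] [2 2] [5 8]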
@@ -739,16 +722,11 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
         # return is always a tuple
         return (local_peaks,)

-    def get_convolved_traces(self, traces, temporal, spatial, singular):
+    def get_convolved_traces(self, traces):
         import scipy.signal

-        num_timesteps, num_templates = len(traces), temporal.shape[1]
-        num_peaks = num_timesteps - self.conv_margin + 1
-        scalar_products = np.zeros((num_templates, num_peaks), dtype=np.float32)
-        spatially_filtered_data = np.matmul(spatial, traces.T[np.newaxis, :, :])
-        scaled_filtered_data = spatially_filtered_data * singular
-        objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, temporal, axes=2, mode="valid")
-        scalar_products += np.sum(objective_by_rank, axis=0)
+        tmp = scipy.signal.oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid")
+        scalar_products = np.dot(self.weights, tmp)
         return scalar_products
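The rewritten get_convolved_traces replaces the SVD-compressed per-template convolution with two cheap steps: one "valid" 1D convolution of every channel with the time-flipped, L2-normalized prototype (a matched filter), then a single matrix product with the weights to mix channels into per-template scalar products. A self-contained sketch of those two steps, in which the shapes, the Hanning prototype, and the random weights are illustrative assumptions:

import numpy as np
import scipy.signal

rng = np.random.default_rng(0)
num_samples, num_channels, len_proto = 10_000, 4, 60

traces = rng.normal(size=(num_samples, num_channels)).astype("float32")
prototype = -np.hanning(len_proto)  # negative-peaked spike waveform
prototype = np.flip(prototype) / np.linalg.norm(prototype)  # flipping turns convolution into correlation

# One 1D convolution per channel, broadcast over rows:
# result has shape (num_channels, num_samples - len_proto + 1)
tmp = scipy.signal.oaconvolve(prototype[None, :], traces.T, axes=1, mode="valid")

# Mix channels into per-template scalar products (rows = templates)
num_templates = 8
weights = rng.uniform(size=(num_templates, num_channels))
scalar_products = np.dot(weights, tmp)
print(scalar_products.shape)  # (8, 9941)

Duplicating the rows of weights and negating the second half, as __init__ now does for peak_sign="both", makes the extra templates respond to the opposite polarity at no additional convolution cost.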


@@ -873,37 +851,28 @@ def _numba_detect_peak_neg(
 @numba.jit(nopython=True, parallel=False)
 def _numba_detect_peak_matched_filtering(
-    traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_templates
+    traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_channels
 ):
     num_z = traces_center.shape[0]
+    num_templates = traces_center.shape[1]
     for template_ind in range(num_templates):
         for z in range(num_z):
             for s in range(peak_mask.shape[2]):
                 if not peak_mask[z, template_ind, s]:
                     continue
                 for neighbour in range(num_templates):
-                    if not neighbours_mask[template_ind, neighbour]:
-                        continue
                     for j in range(num_z):
+                        if not neighbours_mask[template_ind % num_channels, neighbour % num_channels]:
+                            continue
                         for i in range(exclude_sweep_size):
-                            if template_ind >= neighbour:
-                                if z >= j:
-                                    peak_mask[z, template_ind, s] &= (
-                                        traces_center[z, template_ind, s] >= traces_center[j, neighbour, s]
-                                    )
-                                else:
-                                    peak_mask[z, template_ind, s] &= (
-                                        traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
-                                    )
-                            elif template_ind < neighbour:
-                                if z > j:
-                                    peak_mask[z, template_ind, s] &= (
-                                        traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
-                                    )
-                                else:
-                                    peak_mask[z, template_ind, s] &= (
-                                        traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
-                                    )
+                            if template_ind >= neighbour and z >= j:
+                                peak_mask[z, template_ind, s] &= (
+                                    traces_center[z, template_ind, s] >= traces_center[j, neighbour, s]
+                                )
+                            else:
+                                peak_mask[z, template_ind, s] &= (
+                                    traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
+                                )
                             peak_mask[z, template_ind, s] &= (
                                 traces_center[z, template_ind, s] > traces[j, neighbour, s + i]
                             )
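The collapsed branches differ only in >= versus >: a candidate is allowed to tie with neighbours of lower-or-equal (template, z) index but must strictly beat the rest, so two detectors seeing exactly the same value cannot both report a peak. A pure-Python toy check of that rule for two tied templates at the same z (this mimics the comparison, it is not the numba kernel):

# Two detectors tie at the same sample: (z, template_ind) -> conv value
values = {(0, 3): 7.5, (0, 5): 7.5}

def survives(z, t):
    for (j, n), v in values.items():
        if (j, n) == (z, t):
            continue
        if t >= n and z >= j:
            if not values[(z, t)] >= v:  # ties allowed against lower indices
                return False
        elif not values[(z, t)] > v:  # strict elsewhere
            return False
    return True

for z, t in values:
    print((z, t), survives(z, t))
# (0, 3) False, (0, 5) True: exactly one of the tied peaks is kept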
27 changes: 23 additions & 4 deletions src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
@@ -328,19 +328,38 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs)
     )
     assert len(peaks_local_mf_filtering) > len(peaks_by_channel_np)

+    peaks_local_mf_filtering_both = detect_peaks(
+        recording,
+        method="matched_filtering",
+        peak_sign="both",
+        detect_threshold=5,
+        exclude_sweep_ms=0.1,
+        prototype=prototype,
+        ms_before=1.0,
+        **job_kwargs,
+    )
+    assert len(peaks_local_mf_filtering_both) > len(peaks_local_mf_filtering)

     DEBUG = False
     if DEBUG:
         import matplotlib.pyplot as plt

-        peaks = peaks_local_mf_filtering
+        peaks_local = peaks_by_channel_np
+        peaks_mf_neg = peaks_local_mf_filtering
+        peaks_mf_both = peaks_local_mf_filtering_both
+        labels = ["locally_exclusive", "mf_neg", "mf_both"]

-        sample_inds, chan_inds, amplitudes = peaks["sample_index"], peaks["channel_index"], peaks["amplitude"]
-        fig, ax = plt.subplots()
         chan_offset = 500
         traces = recording.get_traces().copy()
         traces += np.arange(traces.shape[1])[None, :] * chan_offset
+        fig, ax = plt.subplots()
         ax.plot(traces, color="k")
-        ax.scatter(sample_inds, chan_inds * chan_offset + amplitudes, color="r")
+
+        for count, peaks in enumerate([peaks_local, peaks_mf_neg, peaks_mf_both]):
+            sample_inds, chan_inds, amplitudes = peaks["sample_index"], peaks["channel_index"], peaks["amplitude"]
+            ax.scatter(sample_inds, chan_inds * chan_offset + amplitudes, label=labels[count])
+
+        ax.legend()
         plt.show()
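The new assertion encodes that scanning both polarities can only add detections. For a quick manual run of the new mode outside the test suite, something along these lines should work; the generated recording and the hand-built Hanning prototype are illustrative assumptions, whereas the test above derives its prototype from the recording:

import numpy as np
from spikeinterface.core import generate_recording
from spikeinterface.sortingcomponents.peak_detection import detect_peaks

recording = generate_recording(num_channels=4, durations=[5.0], sampling_frequency=30000.0)

# Negative-peaked prototype, 1 ms on each side of the trough
fs = recording.sampling_frequency
nbefore = int(1.0 * fs / 1000.0)
prototype = -np.hanning(2 * nbefore).astype("float32")

peaks_both = detect_peaks(
    recording,
    method="matched_filtering",
    peak_sign="both",
    detect_threshold=5,
    exclude_sweep_ms=0.1,
    prototype=prototype,
    ms_before=1.0,
)
print(len(peaks_both))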

