From 0ae32e729acb0be01d1cee28453e3ad3503877fb Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Fri, 11 Oct 2024 11:32:27 +0200 Subject: [PATCH] Tdc peeler (#3466) Improving the Peeler --- .../benchmark/benchmark_matching.py | 73 +- .../benchmark/benchmark_plot_tools.py | 64 +- .../sortingcomponents/matching/tdc.py | 686 +++++++++++++----- .../sortingcomponents/matching/wobble.py | 2 +- .../tests/test_template_matching.py | 40 +- 5 files changed, 608 insertions(+), 257 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_matching.py b/src/spikeinterface/benchmark/benchmark_matching.py index c53567f460..3799fa19b3 100644 --- a/src/spikeinterface/benchmark/benchmark_matching.py +++ b/src/spikeinterface/benchmark/benchmark_matching.py @@ -33,7 +33,7 @@ def run(self, **job_kwargs): sorting["unit_index"] = spikes["cluster_index"] sorting["segment_index"] = spikes["segment_index"] sorting = NumpySorting(sorting, self.recording.sampling_frequency, unit_ids) - self.result = {"sorting": sorting} + self.result = {"sorting": sorting, "spikes": spikes} self.result["templates"] = self.templates def compute_result(self, with_collision=False, **result_params): @@ -45,6 +45,7 @@ def compute_result(self, with_collision=False, **result_params): _run_key_saved = [ ("sorting", "sorting"), + ("spikes", "npy"), ("templates", "zarr_templates"), ] _result_key_saved = [("gt_collision", "pickle"), ("gt_comparison", "pickle")] @@ -71,6 +72,11 @@ def plot_performances_vs_snr(self, **kwargs): return plot_performances_vs_snr(self, **kwargs) + def plot_performances_comparison(self, **kwargs): + from .benchmark_plot_tools import plot_performances_comparison + + return plot_performances_comparison(self, **kwargs) + def plot_collisions(self, case_keys=None, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) @@ -90,70 +96,6 @@ def plot_collisions(self, case_keys=None, figsize=None): return fig - def plot_comparison_matching( - self, - case_keys=None, - performance_names=["accuracy", "recall", "precision"], - colors=["g", "b", "r"], - ylim=(-0.1, 1.1), - figsize=None, - ): - - if case_keys is None: - case_keys = list(self.cases.keys()) - - num_methods = len(case_keys) - import pylab as plt - - fig, axs = plt.subplots(ncols=num_methods, nrows=num_methods, figsize=(10, 10)) - for i, key1 in enumerate(case_keys): - for j, key2 in enumerate(case_keys): - if len(axs.shape) > 1: - ax = axs[i, j] - else: - ax = axs[j] - comp1 = self.get_result(key1)["gt_comparison"] - comp2 = self.get_result(key2)["gt_comparison"] - if i <= j: - for performance, color in zip(performance_names, colors): - perf1 = comp1.get_performance()[performance] - perf2 = comp2.get_performance()[performance] - ax.plot(perf2, perf1, ".", label=performance, color=color) - - ax.plot([0, 1], [0, 1], "k--", alpha=0.5) - ax.set_ylim(ylim) - ax.set_xlim(ylim) - ax.spines[["right", "top"]].set_visible(False) - ax.set_aspect("equal") - - label1 = self.cases[key1]["label"] - label2 = self.cases[key2]["label"] - if j == i: - ax.set_ylabel(f"{label1}") - else: - ax.set_yticks([]) - if i == j: - ax.set_xlabel(f"{label2}") - else: - ax.set_xticks([]) - if i == num_methods - 1 and j == num_methods - 1: - patches = [] - import matplotlib.patches as mpatches - - for color, name in zip(colors, performance_names): - patches.append(mpatches.Patch(color=color, label=name)) - ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0) - else: - ax.spines["bottom"].set_visible(False) - 
ax.spines["left"].set_visible(False) - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - ax.set_xticks([]) - ax.set_yticks([]) - plt.tight_layout(h_pad=0, w_pad=0) - - return fig - def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): import pandas as pd @@ -196,6 +138,7 @@ def plot_unit_counts(self, case_keys=None, figsize=None): plot_study_unit_counts(self, case_keys, figsize=figsize) def plot_unit_losses(self, before, after, metric=["precision"], figsize=None): + import matplotlib.pyplot as plt fig, axs = plt.subplots(ncols=1, nrows=len(metric), figsize=figsize, squeeze=False) diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py index a6e9b6dacc..e15636ebaf 100644 --- a/src/spikeinterface/benchmark/benchmark_plot_tools.py +++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py @@ -235,9 +235,71 @@ def plot_performances_vs_snr(study, case_keys=None, figsize=None, metrics=["accu ax.scatter(x, y, marker=".", label=label) ax.set_title(k) - ax.set_ylim(0, 1.05) + ax.set_ylim(-0.05, 1.05) if count == 2: ax.legend() return fig + + +def plot_performances_comparison( + study, + case_keys=None, + figsize=None, + metrics=["accuracy", "recall", "precision"], + colors=["g", "b", "r"], + ylim=(-0.1, 1.1), +): + import matplotlib.pyplot as plt + + if case_keys is None: + case_keys = list(study.cases.keys()) + + num_methods = len(case_keys) + assert num_methods >= 2, "plot_performances_comparison need at least 2 cases!" + + fig, axs = plt.subplots(ncols=num_methods - 1, nrows=num_methods - 1, figsize=(10, 10), squeeze=False) + for i, key1 in enumerate(case_keys): + for j, key2 in enumerate(case_keys): + + if i < j: + ax = axs[i, j - 1] + + comp1 = study.get_result(key1)["gt_comparison"] + comp2 = study.get_result(key2)["gt_comparison"] + + for performance, color in zip(metrics, colors): + perf1 = comp1.get_performance()[performance] + perf2 = comp2.get_performance()[performance] + ax.scatter(perf2, perf1, marker=".", label=performance, color=color) + + ax.plot([0, 1], [0, 1], "k--", alpha=0.5) + ax.set_ylim(ylim) + ax.set_xlim(ylim) + ax.spines[["right", "top"]].set_visible(False) + ax.set_aspect("equal") + + label1 = study.cases[key1]["label"] + label2 = study.cases[key2]["label"] + + if i == j - 1: + ax.set_xlabel(label2) + ax.set_ylabel(label1) + + else: + if j >= 1 and i < num_methods - 1: + ax = axs[i, j - 1] + ax.spines[["right", "top", "left", "bottom"]].set_visible(False) + ax.set_xticks([]) + ax.set_yticks([]) + + ax = axs[num_methods - 2, 0] + patches = [] + from matplotlib.patches import Patch + + for color, name in zip(colors, metrics): + patches.append(Patch(color=color, label=name)) + ax.legend(handles=patches) + fig.tight_layout() + return fig diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 56457fe2fa..125baa3bda 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -2,15 +2,11 @@ import numpy as np from spikeinterface.core import ( - get_noise_levels, get_channel_distances, - compute_sparsity, get_template_extremum_channel, ) -from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive -from spikeinterface.core.template import Templates - +from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive, DetectPeakMatchedFiltering 
from .base import BaseTemplateMatching, _base_matching_dtype


@@ -25,7 +21,7 @@

 class TridesclousPeeler(BaseTemplateMatching):
     """
-    Template-matching ported from Tridesclous sorter.
+    Template-matching used by the Tridesclous sorter.

     The idea of this peeler is pretty simple.
     1. Find peaks
@@ -34,8 +30,10 @@
     4. remove it from traces.
     5. in the residual find peaks again

-    This method is quite fast but don't give exelent results to resolve
-    spike collision when templates have high similarity.
+    Contrary to circus_peeler or wobble, this template matching works directly on the waveforms;
+    there is no SVD decomposition.
+
+
     """

     def __init__(
@@ -45,26 +43,29 @@ def __init__(
         parents=None,
         templates=None,
         peak_sign="neg",
+        exclude_sweep_ms=0.5,
         peak_shift_ms=0.2,
         detect_threshold=5,
         noise_levels=None,
-        radius_um=100.0,
-        num_closest=5,
-        sample_shift=3,
-        ms_before=0.8,
-        ms_after=1.2,
-        num_peeler_loop=2,
-        num_template_try=1,
+        use_fine_detector=True,
+        # TODO: optimize these radii
+        detection_radius_um=80.0,
+        cluster_radius_um=150.0,
+        amplitude_fitting_radius_um=150.0,
+        sample_shift=2,
+        ms_before=0.5,
+        ms_after=0.8,
+        max_peeler_loop=2,
+        amplitude_limits=(0.7, 1.4),
     ):
         BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None)

-        # maybe in base?
-        self.templates_array = templates.get_dense_templates()
-
         unit_ids = templates.unit_ids
         channel_ids = recording.channel_ids

+        num_templates = unit_ids.size
+
         sr = recording.sampling_frequency

         self.nbefore = templates.nbefore
@@ -82,8 +83,9 @@
         s1 = -(templates.nafter - nafter_short)
         if s1 == 0:
             s1 = None
+
+        # TODO: check whether the copy can be avoided
-        self.templates_short = self.templates_array[:, slice(s0, s1), :].copy()
+        self.sparse_templates_array_short = templates.templates_array[:, slice(s0, s1), :].copy()

         self.peak_shift = int(peak_shift_ms / 1000 * sr)

@@ -92,12 +94,12 @@
         self.abs_thresholds = noise_levels * detect_threshold

         channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = channel_distance < radius_um
+        self.neighbours_mask = channel_distance <= detection_radius_um

         if templates.sparsity is not None:
-            self.template_sparsity = templates.sparsity.mask
+            self.sparsity_mask = templates.sparsity.mask
         else:
-            self.template_sparsity = np.ones((unit_ids.size, channel_ids.size), dtype=bool)
+            self.sparsity_mask = np.ones((unit_ids.size, channel_ids.size), dtype=bool)

         extremum_chan = get_template_extremum_channel(templates, peak_sign=peak_sign, outputs="index")
         # as numpy vector
@@ -109,72 +111,108 @@
         # distance between units
         import scipy

-        unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean")
-
-        # seach for closet units and unitary discriminant vector
-        closest_units = []
-        for unit_ind, unit_id in enumerate(unit_ids):
-            order = np.argsort(unit_distances[unit_ind, :])
-            closest_u = np.arange(unit_ids.size)[order].tolist()
-            closest_u.remove(unit_ind)
-            closest_u = np.array(closest_u[:num_closest])
-
-            # compute unitary discriminent vector
-            (chans,) = np.nonzero(self.template_sparsity[unit_ind, :])
-            template_sparse = self.templates_array[unit_ind, :, :][:, chans]
-            closest_vec = []
-            # against N closets
-            for u in closest_u:
-                vec = self.templates_array[u, :, :][:, chans] - template_sparse
-                vec /= np.sum(vec**2)
-                closest_vec.append((u, vec))
-            # against noise
-            closest_vec.append((None, -template_sparse / np.sum(template_sparse**2)))
-
-            closest_units.append(closest_vec)
-
-        self.closest_units = closest_units
-
-        # distance channel from unit
-        import scipy
-
-        distances = scipy.spatial.distance.cdist(channel_locations, unit_locations, metric="euclidean")
-        near_cluster_mask = distances < radius_um

-        # nearby cluster for each channel
+        distances = scipy.spatial.distance.cdist(channel_locations, unit_locations, metric="euclidean")
+        near_cluster_mask = distances <= cluster_radius_um

         self.possible_clusters_by_channel = []
         for channel_index in range(distances.shape[0]):
             (cluster_inds,) = np.nonzero(near_cluster_mask[channel_index, :])
             self.possible_clusters_by_channel.append(cluster_inds)

+        # precompute template norms on sparse channels
+        self.template_norms = np.zeros(num_templates, dtype="float32")
+        for i in range(unit_ids.size):
+            chan_mask = self.sparsity_mask[i, :]
+            n = np.sum(chan_mask)
+            template = templates.templates_array[i, :, :n]
+            self.template_norms[i] = np.sum(template**2)
+
+        # channel-to-channel distances, for the amplitude-fitting neighborhood
+        distances = scipy.spatial.distance.cdist(channel_locations, channel_locations, metric="euclidean")
+        self.near_chan_mask = distances <= amplitude_fitting_radius_um

         self.possible_shifts = np.arange(-sample_shift, sample_shift + 1, dtype="int64")

-        self.num_peeler_loop = num_peeler_loop
-        self.num_template_try = num_template_try
+        self.max_peeler_loop = max_peeler_loop
+        self.amplitude_limits = amplitude_limits
+
+        self.fast_spike_detector = DetectPeakLocallyExclusive(
+            recording=recording,
+            peak_sign=peak_sign,
+            detect_threshold=detect_threshold,
+            exclude_sweep_ms=exclude_sweep_ms,
+            radius_um=detection_radius_um,
+            noise_levels=noise_levels,
+        )

-        self.margin = max(self.nbefore, self.nafter) * 2
+        ## get the prototype from the best channel of each template
+        prototype = np.zeros(self.nbefore + self.nafter, dtype="float32")
+        for i in range(num_templates):
+            template = templates.templates_array[i, :, :]
+            chan_ind = np.argmax(np.abs(template[self.nbefore, :]))
+            if template[self.nbefore, chan_ind] != 0:
+                prototype += template[:, chan_ind] / np.abs(template[self.nbefore, chan_ind])
+        prototype /= np.abs(prototype[self.nbefore])
+
+        # import matplotlib.pyplot as plt
+        # fig, ax = plt.subplots()
+        # ax.plot(prototype)
+        # plt.show()
+
+        self.use_fine_detector = use_fine_detector
+        if self.use_fine_detector:
+            self.fine_spike_detector = DetectPeakMatchedFiltering(
+                recording=recording,
+                prototype=prototype,
+                ms_before=templates.nbefore / sr * 1000.0,
+                peak_sign="neg",
+                detect_threshold=detect_threshold,
+                exclude_sweep_ms=exclude_sweep_ms,
+                radius_um=detection_radius_um,
+                weight_method=dict(
+                    z_list_um=np.array([50.0]),
+                    sigma_3d=2.5,
+                    mode="exponential_3d",
+                ),
+                noise_levels=None,
+            )
+
+        self.detector_margin0 = self.fast_spike_detector.get_trace_margin()
+        self.detector_margin1 = self.fine_spike_detector.get_trace_margin() if use_fine_detector else 0
+        self.peeler_margin = max(self.nbefore, self.nafter) * 2
+        self.margin = max(self.peeler_margin, self.detector_margin0, self.detector_margin1)

     def get_trace_margin(self):
         return self.margin
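The `template_norms` cached above feed the peeler's dot-product amplitude estimate: for a unit-amplitude template t and an observed snippet w ≈ a·t + noise, the least-squares amplitude is <w, t> / ||t||², so ||t||² only needs to be computed once per template. A minimal self-contained sketch of that estimate (toy arrays, not the actual SpikeInterface API):

import numpy as np

rng = np.random.default_rng(0)
t = rng.normal(size=(40, 4)).astype("float32")     # template (samples x channels)
norm = np.sum(t**2)                                # cached once, like template_norms
w = 0.9 * t + 0.05 * rng.normal(size=t.shape)      # snippet with true amplitude 0.9
amp = np.sum(w * t) / norm                         # least-squares amplitude, ~0.9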
     def compute_matching(self, traces, start_frame, end_frame, segment_index):
-        traces = traces.copy()
+
+        # TODO: check if this copy is useful
+        residuals = traces.copy()

         all_spikes = []
         level = 0
+        spikes_prev_loop = np.zeros(0, dtype=_base_matching_dtype)
+        use_fine_detector_level = False
         while True:
-            # spikes = _tdc_find_spikes(traces, d, level=level)
-            spikes = self._find_spikes_one_level(traces, level=level)
-            keep = spikes["cluster_index"] >= 0
-
-            if not np.any(keep):
-                break
-            all_spikes.append(spikes[keep])
+            # print('level', level)
+            spikes = self._find_spikes_one_level(residuals, spikes_prev_loop, use_fine_detector_level, level)
+            if spikes.size > 0:
+                all_spikes.append(spikes)

             level += 1

-            if level == self.num_peeler_loop:
-                break
+            # TODO: concatenate spikes from all previous loops instead of only the last one
+            spikes_prev_loop = spikes
+
+            if (spikes.size == 0) or (level == self.max_peeler_loop):
+                if self.use_fine_detector and not use_fine_detector_level:
+                    # extra loop with the fine detector
+                    use_fine_detector_level = True
+                    level = self.max_peeler_loop - 1
+                    continue
+                else:
+                    break

         if len(all_spikes) > 0:
             all_spikes = np.concatenate(all_spikes)
@@ -185,13 +223,34 @@

         return all_spikes

-    def _find_spikes_one_level(self, traces, level=0):
+    def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, level):

-        peak_traces = traces[self.margin // 2 : -self.margin // 2, :]
-        peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks(
-            peak_traces, self.peak_sign, self.abs_thresholds, self.peak_shift, self.neighbours_mask
-        )
-        peak_sample_ind += self.margin // 2
+        # print(use_fine_detector, level)
+
+        # TODO: change the threshold dynamically depending on the level
+        # peak_traces = traces[self.detector_margin : -self.detector_margin, :]

+        # peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks(
+        #     peak_traces, self.peak_sign, self.abs_thresholds, self.peak_shift, self.neighbours_mask
+        # )
+
+        if use_fine_detector:
+            peak_detector = self.fine_spike_detector
+        else:
+            peak_detector = self.fast_spike_detector
+
+        detector_margin = peak_detector.get_trace_margin()
+        if self.peeler_margin > detector_margin:
+            margin_shift = self.peeler_margin - detector_margin
+            sl = slice(margin_shift, -margin_shift)
+        else:
+            sl = slice(None)
+            margin_shift = 0
+        peak_traces = traces[sl, :]
+        (peaks,) = peak_detector.compute(peak_traces, None, None, 0, self.margin)
+        peak_sample_ind = peaks["sample_index"]
+        peak_chan_ind = peaks["channel_index"]
+        peak_sample_ind += margin_shift

         peak_amplitude = traces[peak_sample_ind, peak_chan_ind]
         order = np.argsort(np.abs(peak_amplitude))[::-1]
@@ -200,153 +259,438 @@

         spikes = np.zeros(peak_sample_ind.size, dtype=_base_matching_dtype)
         spikes["sample_index"] = peak_sample_ind
-        spikes["channel_index"] = peak_chan_ind  # TODO need to put the channel from template
+        spikes["channel_index"] = peak_chan_ind

-        possible_shifts = self.possible_shifts
-        distances_shift = np.zeros(possible_shifts.size)
+        distances_shift = np.zeros(self.possible_shifts.size)

-        for i in range(peak_sample_ind.size):
+        delta_sample = max(self.nbefore, self.nafter)  # TODO: check this, maybe add a margin
+        # neighbors_spikes_inds = get_neighbors_spikes(spikes["sample_index"], spikes["channel_index"], delta_sample, self.near_chan_mask)
+
+        # neighbors in the current and the previous level
+        neighbors_spikes_inds = get_neighbors_spikes(
+            np.concatenate([spikes["sample_index"], spikes_prev_loop["sample_index"]]),
+            np.concatenate([spikes["channel_index"], spikes_prev_loop["channel_index"]]),
+            delta_sample,
+            self.near_chan_mask,
+        )
+
+        for i in range(spikes.size):
             sample_index = peak_sample_ind[i]
             chan_ind = peak_chan_ind[i]
             possible_clusters = self.possible_clusters_by_channel[chan_ind]

             if possible_clusters.size > 0:
-                # ~ s0 = sample_index - d['nbefore']
-                # ~ s1 = sample_index + d['nafter']
+                cluster_index = get_most_probable_cluster(
+                    traces,
+                    self.sparse_templates_array_short,
+                    possible_clusters,
+                    sample_index,
+                    chan_ind,
+                    self.nbefore_short,
+                    self.nafter_short,
+                    self.sparsity_mask,
+                )
+
+                # import matplotlib.pyplot as plt
+                # fig, ax = plt.subplots()
+                # chans = np.any(self.sparsity_mask[possible_clusters, :], axis=0)
+                # wf = traces[sample_index - self.nbefore : sample_index + self.nafter][:, chans]
+                # ax.plot(wf.T.flatten(), color='k')
+                # dense_templates_array = self.templates.get_dense_templates()
+                # for c_ind in possible_clusters:
+                #     template = dense_templates_array[c_ind, :, :][:, chans]
+                #     ax.plot(template.T.flatten())
+                #     if c_ind == cluster_index:
+                #         ax.plot(template.T.flatten(), color='m', ls='--')
+                # ax.set_title(f"use_fine_detector{use_fine_detector} level{level}")
+                # plt.show()
+
+                chan_sparsity_mask = self.sparsity_mask[cluster_index, :]
+
+                # find the best shift
+                numba_best_shift_sparse(
+                    traces,
+                    self.sparse_templates_array_short[cluster_index, :, :],
+                    sample_index,
+                    self.nbefore_short,
+                    self.possible_shifts,
+                    distances_shift,
+                    chan_sparsity_mask,
+                )
+
+                ind_shift = np.argmin(distances_shift)
+                shift = self.possible_shifts[ind_shift]
+
+                # TODO DEBUG shift later
+                spikes["sample_index"][i] += shift
+
+                spikes["cluster_index"][i] = cluster_index
+
+                # check that the same cluster is not already detected at the same place;
+                # this can happen for small templates that get subtracted from the traces forever
+                outer_neighbors_inds = [ind for ind in neighbors_spikes_inds[i] if ind > i and ind >= spikes.size]
+                is_valid = True
+                for b in outer_neighbors_inds:
+                    b = b - spikes.size
+                    if (spikes[i]["sample_index"] == spikes_prev_loop[b]["sample_index"]) and (
+                        spikes[i]["cluster_index"] == spikes_prev_loop[b]["cluster_index"]
+                    ):
+                        is_valid = False
+
+                if is_valid:
+                    # temporarily assign a cluster to neighbors if not done yet
+                    inner_neighbors_inds = [ind for ind in neighbors_spikes_inds[i] if (ind > i and ind < spikes.size)]
+                    for b in inner_neighbors_inds:
+                        spikes["cluster_index"][b] = get_most_probable_cluster(
+                            traces,
+                            self.sparse_templates_array_short,
+                            possible_clusters,
+                            spikes["sample_index"][b],
+                            spikes["channel_index"][b],
+                            self.nbefore_short,
+                            self.nafter_short,
+                            self.sparsity_mask,
+                        )
+
+                    amp = fit_one_amplitude_with_neighbors(
+                        spikes[i],
+                        spikes[inner_neighbors_inds],
+                        traces,
+                        self.sparsity_mask,
+                        self.templates.templates_array,
+                        self.template_norms,
+                        self.nbefore,
+                        self.nafter,
+                    )

+                    low_lim, up_lim = self.amplitude_limits
+                    if low_lim <= amp <= up_lim:
+                        spikes["amplitude"][i] = amp
+                        wanted_channel_mask = np.ones(traces.shape[1], dtype=bool)  # TODO: move this before the loop
+                        construct_prediction_sparse(
+                            spikes[i : i + 1],
+                            traces,
+                            self.templates.templates_array,
+                            self.sparsity_mask,
+                            wanted_channel_mask,
+                            self.nbefore,
+                            additive=False,
+                        )
+                    elif low_lim > amp:
+                        # print("bad amp", amp)
+                        spikes["cluster_index"][i] = -1
+
+                        # import matplotlib.pyplot as plt
+                        # fig, ax = plt.subplots()
+                        # sample_ind = spikes["sample_index"][i]
+                        # print(chan_sparsity_mask)
+                        # wf = traces[sample_ind - self.nbefore : sample_ind + self.nafter][:, chan_sparsity_mask]
+                        # dense_templates_array = self.templates.get_dense_templates()
+                        # template = dense_templates_array[cluster_index, :, :][:, chan_sparsity_mask]
+                        # ax.plot(wf.T.flatten())
+                        # ax.plot(template.T.flatten())
+                        # ax.plot(template.T.flatten() * amp)
+                        # ax.set_title(f"amp{amp} use_fine_detector{use_fine_detector} level{level}")
+                        # plt.show()
+                    else:
+                        # amp > up_lim
+                        # TODO: should try another cluster for the fit!!
+                        # spikes["cluster_index"][i] = -1
+
+                        # force the amplitude to one; it will need a new fit at the next level
+                        spikes["amplitude"][i] = 1
+
+                        # print(amp)
+                        # import matplotlib.pyplot as plt
+                        # fig, ax = plt.subplots()
+                        # sample_ind = spikes["sample_index"][i]
+                        # wf = traces[sample_ind - self.nbefore : sample_ind + self.nafter][:, chan_sparsity_mask]
+                        # dense_templates_array = self.templates.get_dense_templates()
+                        # template = dense_templates_array[cluster_index, :, :][:, chan_sparsity_mask]
+                        # ax.plot(wf.T.flatten())
+                        # ax.plot(template.T.flatten())
+                        # ax.plot(template.T.flatten() * amp)
+                        # ax.set_title(f"amp{amp} use_fine_detector{use_fine_detector} level{level}")
+                        # plt.show()
+
+                # import matplotlib.pyplot as plt
+                # fig, ax = plt.subplots()
+                # chans = np.any(self.sparsity_mask[possible_clusters, :], axis=0)
+                # wf = traces[sample_index - self.nbefore : sample_index + self.nafter][:, chans]
+                # ax.plot(wf.T.flatten(), color='k')
+                # dense_templates_array = self.templates.get_dense_templates()
+                # for c_ind in possible_clusters:
+                #     template = dense_templates_array[c_ind, :, :][:, chans]
+                #     ax.plot(template.T.flatten())
+                #     if c_ind == cluster_index:
+                #         ax.plot(template.T.flatten(), color='m', ls='--')
+                # ax.set_title(f"use_fine_detector{use_fine_detector} level{level}")
+                # plt.show()

-                s0 = sample_index - self.nbefore_short
-                s1 = sample_index + self.nafter_short
-                wf_short = traces[s0:s1, :]
+                else:
+                    # not valid because already detected
+                    spikes["cluster_index"][i] = -1

-                ## pure numpy with cluster spasity
-                # distances = np.sum(np.sum((templates[possible_clusters, :, :] - wf[None, : , :])**2, axis=1), axis=1)
+            else:
+                # no possible cluster in the neighborhood of this channel
+                spikes["cluster_index"][i] = -1

-                ## pure numpy with cluster+channel spasity
-                # union_channels, = np.nonzero(np.any(d['template_sparsity'][possible_clusters, :], axis=0))
-                # distances = np.sum(np.sum((templates[possible_clusters][:, :, union_channels] - wf[: , union_channels][None, : :])**2, axis=1), axis=1)
+        # delta_sample = self.nbefore + self.nafter
+        # # TODO benchmark this and make this faster
+        # neighbors_spikes_inds = get_neighbors_spikes(spikes["sample_index"], spikes["channel_index"], delta_sample, self.near_chan_mask)
+        # for i in range(spikes.size):
+        #     amp = fit_one_amplitude_with_neighbors(spikes[i], spikes[neighbors_spikes_inds[i]], traces,
+        #                                            self.sparsity_mask, self.templates.templates_array, self.nbefore, self.nafter)
+        #     spikes["amplitude"][i] = amp

-                ## numba with cluster+channel spasity
-                union_channels = np.any(self.template_sparsity[possible_clusters, :], axis=0)
-                # distances = numba_sparse_dist(wf, templates, union_channels, possible_clusters)
-                distances = numba_sparse_dist(wf_short, self.templates_short, union_channels, possible_clusters)
+        keep = spikes["cluster_index"] >= 0
+        spikes = spikes[keep]

-                # DEBUG
-                # ~ ind = np.argmin(distances)
-                # ~ cluster_index = possible_clusters[ind]
+        # keep = (spikes["amplitude"] >= 0.7) & (spikes["amplitude"] <= 1.4)
+        # spikes = spikes[keep]

-                for ind in np.argsort(distances)[: self.num_template_try]:
-                    cluster_index = possible_clusters[ind]
+        # sparse_templates_array = self.templates.templates_array
+        # wanted_channel_mask = np.ones(traces.shape[1], dtype=bool)
+        # assert np.sum(wanted_channel_mask) == traces.shape[1]  # TODO remove this DEBUG later
+        # construct_prediction_sparse(spikes, traces, sparse_templates_array, self.sparsity_mask, wanted_channel_mask, self.nbefore, additive=False)

-                    chan_sparsity = self.template_sparsity[cluster_index, :]
-                    template_sparse = self.templates_array[cluster_index, :, :][:, chan_sparsity]
+        return spikes
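`compute_matching` and `_find_spikes_one_level` above implement the peeling strategy: detect peaks, classify each one against the nearby templates, fit an amplitude, subtract the prediction from the residual, then detect again at the next level. A compressed single-channel sketch of the idea (one spike per pass, toy code rather than the class's actual control flow):

import numpy as np

def peel(traces, templates, nbefore, threshold, max_loop=2):
    # templates: (num_templates, width) unit-amplitude waveforms on one channel
    residual = traces.copy()
    width = templates.shape[1]
    spikes = []
    for _ in range(max_loop):
        peak = int(np.argmax(np.abs(residual)))
        s0 = peak - nbefore
        if np.abs(residual[peak]) < threshold or s0 < 0 or s0 + width > residual.size:
            break
        snippet = residual[s0 : s0 + width]
        # classify: closest template in squared distance
        best = int(np.argmin(np.sum((templates - snippet[None, :]) ** 2, axis=1)))
        # fit: least-squares amplitude, then subtract the prediction
        amp = np.dot(snippet, templates[best]) / np.sum(templates[best] ** 2)
        residual[s0 : s0 + width] -= amp * templates[best]
        spikes.append((peak, best, amp))
    return spikes, residual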

-                    # find best shift
-                    ## pure numpy version
-                    # for s, shift in enumerate(possible_shifts):
-                    #     wf_shift = traces[s0 + shift: s1 + shift, chan_sparsity]
-                    #     distances_shift[s] = np.sum((template_sparse - wf_shift)**2)
-                    # ind_shift = np.argmin(distances_shift)
-                    # shift = possible_shifts[ind_shift]

+def get_most_probable_cluster(
+    traces,
+    sparse_templates_array,
+    possible_clusters,
+    sample_index,
+    chan_ind,
+    nbefore_short,
+    nafter_short,
+    template_sparsity_mask,
+):
+    s0 = sample_index - nbefore_short
+    s1 = sample_index + nafter_short
+    wf_short = traces[s0:s1, :]

-                    ## numba version
-                    numba_best_shift(
-                        traces,
-                        self.templates_array[cluster_index, :, :],
-                        sample_index,
-                        self.nbefore,
-                        possible_shifts,
-                        distances_shift,
-                        chan_sparsity,
-                    )
-                    ind_shift = np.argmin(distances_shift)
-                    shift = possible_shifts[ind_shift]
-
-                    sample_index = sample_index + shift
-                    s0 = sample_index - self.nbefore
-                    s1 = sample_index + self.nafter
-                    wf_sparse = traces[s0:s1, chan_sparsity]
+    ## numba with cluster+channel sparsity
+    union_channels = np.any(template_sparsity_mask[possible_clusters, :], axis=0)
+    distances = numba_sparse_distance(
+        wf_short, sparse_templates_array, template_sparsity_mask, union_channels, possible_clusters
+    )

-                    # accept or not
-
-                    centered = wf_sparse - template_sparse
-                    accepted = True
-                    for other_ind, other_vector in self.closest_units[cluster_index]:
-                        v = np.sum(centered * other_vector)
-                        if np.abs(v) > 0.5:
-                            accepted = False
-                            break
-
-                    if accepted:
-                        # ~ if ind != np.argsort(distances)[0]:
-                        # ~ print('not first one', np.argsort(distances), ind)
-                        break
+    ind = np.argmin(distances)
+    cluster_index = possible_clusters[ind]

-                if accepted:
-                    amplitude = 1.0
-
-                    # remove template
-                    template = self.templates_array[cluster_index, :, :]
-                    s0 = sample_index - self.nbefore
-                    s1 = sample_index + self.nafter
-                    traces[s0:s1, :] -= template * amplitude
+    return cluster_index

-                else:
-                    cluster_index = -1
-                    amplitude = 0.0

-            else:
-                cluster_index = -1
-                amplitude = 0.0
+def get_neighbors_spikes(sample_inds, chan_inds, delta_sample, near_chan_mask):

-            spikes["cluster_index"][i] = cluster_index
-            spikes["amplitude"][i] = amplitude
-
-        return spikes
+    neighbors_spikes_inds = []
+    for i in range(sample_inds.size):
+
+        inds = np.flatnonzero(np.abs(sample_inds - sample_inds[i]) < delta_sample)
+        neighb = []
+        for ind in inds:
+            if near_chan_mask[chan_inds[i], chan_inds[ind]] and i != ind:
+                neighb.append(ind)
+        neighbors_spikes_inds.append(neighb)
+
+    return neighbors_spikes_inds
+
+
+def fit_one_amplitude_with_neighbors(
+    spike, neighbors_spikes, traces, template_sparsity_mask, sparse_templates_array, template_norms, nbefore, nafter
+):
+    """
+    Fit the amplitude of one spike, with or without neighbors.
+
+    """
+
+    import scipy.linalg
+
+    cluster_index = spike["cluster_index"]
+    sample_index = spike["sample_index"]
+    chan_sparsity_mask = template_sparsity_mask[cluster_index, :]
+    num_chans = np.sum(chan_sparsity_mask)
+    if num_chans == 0:
+        # protect against an empty template when the sparsity is too aggressive
+        return 0.0
+    start, stop = sample_index - nbefore, sample_index + nafter
+    if neighbors_spikes is None or (neighbors_spikes.size == 0):
+        template = sparse_templates_array[cluster_index, :, :num_chans]
+        wf = traces[start:stop, :][:, chan_sparsity_mask]
+        amplitude = np.sum(template.flatten() * wf.flatten()) / template_norms[cluster_index]
+    else:
+
+        lim0 = min(start, np.min(neighbors_spikes["sample_index"]) - nbefore)
+        lim1 = max(stop, np.max(neighbors_spikes["sample_index"]) + nafter)
+
+        local_traces = traces[lim0:lim1, :][:, chan_sparsity_mask]
+        mask_not_fitted = (neighbors_spikes["amplitude"] == 0.0) & (neighbors_spikes["cluster_index"] >= 0)
+        local_spike = spike.copy()
+        local_spike["sample_index"] -= lim0
+        local_spike["amplitude"] = 1.0
+
+        local_neighbors_spikes = neighbors_spikes.copy()
+        local_neighbors_spikes["sample_index"] -= lim0
+        local_neighbors_spikes["amplitude"][:] = 1.0
+
+        num_spikes_to_fit = 1 + np.sum(mask_not_fitted)
+        x = np.zeros((lim1 - lim0, num_chans, num_spikes_to_fit), dtype="float32")
+        wanted_channel_mask = chan_sparsity_mask
+        construct_prediction_sparse(
+            np.array([local_spike]),
+            x[:, :, 0],
+            sparse_templates_array,
+            template_sparsity_mask,
+            chan_sparsity_mask,
+            nbefore,
+            True,
+        )
+
+        j = 1
+        for i in range(neighbors_spikes.size):
+            if mask_not_fitted[i]:
+                # add as one more regressor
+                construct_prediction_sparse(
+                    local_neighbors_spikes[i : i + 1],
+                    x[:, :, j],
+                    sparse_templates_array,
+                    template_sparsity_mask,
+                    chan_sparsity_mask,
+                    nbefore,
+                    True,
+                )
+                j += 1
+            elif local_neighbors_spikes[i]["sample_index"] >= 0:
+                # already fitted: remove its prediction from the traces
+                construct_prediction_sparse(
+                    local_neighbors_spikes[i : i + 1],
+                    local_traces,
+                    sparse_templates_array,
+                    template_sparsity_mask,
+                    chan_sparsity_mask,
+                    nbefore,
+                    False,
+                )
+            # else:
+            #     pass
+
+        x = x.reshape(-1, num_spikes_to_fit)
+        y = local_traces.flatten()
+
+        res = scipy.linalg.lstsq(x, y, cond=None, lapack_driver="gelsd")
+        amplitudes = res[0]
+        amplitude = amplitudes[0]
+
+        # import matplotlib.pyplot as plt
+        # x_plot = x.reshape((lim1 - lim0, num_chans, num_spikes_to_fit)).swapaxes(0, 1).reshape(-1, num_spikes_to_fit)
+        # pred = x @ amplitudes
+        # pred_plot = pred.reshape(-1, num_chans).T.flatten()
+        # y_plot = y.reshape(-1, num_chans).T.flatten()
+        # fig, ax = plt.subplots()
+        # ax.plot(x_plot, color='b')
+        # print(x_plot.shape, y_plot.shape)
+        # ax.plot(y_plot, color='g')
+        # ax.plot(pred_plot , color='r')
+        # ax.set_title(f"{amplitudes}")
+        # # ax.set_title(f"{amplitudes} {amp_dot}")
+        # plt.show()
+
+    return amplitude


 if HAVE_NUMBA:

     @jit(nopython=True)
+    def construct_prediction_sparse(
+        spikes, traces, sparse_templates_array, template_sparsity_mask, wanted_channel_mask, nbefore, additive
+    ):
+        # must have np.sum(wanted_channel_mask) == traces.shape[1]
+        total_chans = wanted_channel_mask.shape[0]
+        for spike in spikes:
+            ind0 = spike["sample_index"] - nbefore
+            ind1 = ind0 + sparse_templates_array.shape[1]
+            cluster_index = spike["cluster_index"]
+            amplitude = spike["amplitude"]
+            chan_in_template = 0
+            chan_in_trace = 0
+            for chan in range(total_chans):
+                if wanted_channel_mask[chan]:
+                    if template_sparsity_mask[cluster_index, chan]:
+                        if additive:
+                            traces[ind0:ind1, chan_in_trace] += (
+                                sparse_templates_array[cluster_index, :, chan_in_template] * amplitude
+                            )
+                        else:
+                            traces[ind0:ind1, chan_in_trace] -= (
+                                sparse_templates_array[cluster_index, :, chan_in_template] * amplitude
+                            )
+                        chan_in_template += 1
+                    chan_in_trace += 1
+                else:
+                    if template_sparsity_mask[cluster_index, chan]:
+                        chan_in_template += 1
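`fit_one_amplitude_with_neighbors` above turns overlap resolution into a small least-squares problem: each not-yet-fitted spike contributes one column of unit-amplitude predictions, and `scipy.linalg.lstsq` returns one amplitude per spike. A toy sketch of that step (synthetic regressors, not the real sparse predictions):

import numpy as np
import scipy.linalg

rng = np.random.default_rng(42)
pred1 = rng.normal(size=200).astype("float32")   # unit-amplitude prediction, spike of interest
pred2 = rng.normal(size=200).astype("float32")   # overlapping neighbor, not fitted yet
y = 1.1 * pred1 + 0.8 * pred2 + 0.01 * rng.normal(size=200)
x = np.stack([pred1, pred2], axis=1)             # one column per spike to fit
amplitudes, *_ = scipy.linalg.lstsq(x, y, cond=None, lapack_driver="gelsd")
# amplitudes ~ [1.1, 0.8]; the first entry is the spike of interest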
     @jit(nopython=True)
-    def numba_sparse_dist(wf, templates, union_channels, possible_clusters):
+    def numba_sparse_distance(
+        wf, sparse_templates_array, template_sparsity_mask, wanted_channel_mask, possible_clusters
+    ):
         """
         numba implementation that computes the distance from templates with sparsity
-        handle by two separate vectors
+
+        wf is dense
+        sparse_templates_array is sparse with the template_sparsity_mask
         """
-        total_cluster, width, num_chan = templates.shape
+        width, total_chans = wf.shape
         num_cluster = possible_clusters.shape[0]
         distances = np.zeros((num_cluster,), dtype=np.float32)
         for i in prange(num_cluster):
             cluster_index = possible_clusters[i]
             sum_dist = 0.0
-            for chan_ind in range(num_chan):
-                if union_channels[chan_ind]:
-                    for s in range(width):
-                        v = wf[s, chan_ind]
-                        t = templates[cluster_index, s, chan_ind]
-                        sum_dist += (v - t) ** 2
+            chan_in_template = 0
+            for chan in range(total_chans):
+                if wanted_channel_mask[chan]:
+                    if template_sparsity_mask[cluster_index, chan]:
+                        for s in range(width):
+                            v = wf[s, chan]
+                            t = sparse_templates_array[cluster_index, s, chan_in_template]
+                            sum_dist += (v - t) ** 2
+                        chan_in_template += 1
+                    else:
+                        for s in range(width):
+                            v = wf[s, chan]
+                            t = 0
+                            sum_dist += (v - t) ** 2
+                else:
+                    if template_sparsity_mask[cluster_index, chan]:
+                        chan_in_template += 1
             distances[i] = sum_dist
         return distances

     @jit(nopython=True)
-    def numba_best_shift(traces, template, sample_index, nbefore, possible_shifts, distances_shift, chan_sparsity):
+    def numba_best_shift_sparse(
+        traces, sparse_template, sample_index, nbefore, possible_shifts, distances_shift, chan_sparsity
+    ):
         """
         numba implementation to compute several sample shifts before template subtraction
         """
-        width, num_chan = template.shape
+        width = sparse_template.shape[0]
+        total_chans = traces.shape[1]
         n_shift = possible_shifts.size
         for i in range(n_shift):
             shift = possible_shifts[i]
             sum_dist = 0.0
-            for chan_ind in range(num_chan):
-                if chan_sparsity[chan_ind]:
+            chan_in_template = 0
+            for chan in range(total_chans):
+                if chan_sparsity[chan]:
                     for s in range(width):
-                        v = traces[sample_index - nbefore + s + shift, chan_ind]
-                        t = template[s, chan_ind]
+                        v = traces[sample_index - nbefore + s + shift, chan]
+                        t = sparse_template[s, chan_in_template]
                         sum_dist += (v - t) ** 2
+                    chan_in_template += 1
             distances_shift[i] = sum_dist
         return distances_shift

diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py
index 2531a922da..3099448b11 100644
--- a/src/spikeinterface/sortingcomponents/matching/wobble.py
+++ b/src/spikeinterface/sortingcomponents/matching/wobble.py
@@ -348,7 +348,7 @@ def __init__(

         BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None)

-        templates_array = templates.get_dense_templates().astype(np.float32, casting="safe")
+        templates_array = templates.get_dense_templates().astype(np.float32)

         # Aggregate useful parameters/variables for handy access in downstream functions
         params = WobbleParameters(**parameters)
diff --git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py
index cbf1d29932..7cd899a3bb 100644
--- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py
+++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py
@@ -9,8 +9,8 @@

 from spikeinterface.sortingcomponents.tests.common import make_dataset

-job_kwargs = dict(n_jobs=-1, chunk_duration="500ms", progress_bar=True)
-# job_kwargs = dict(n_jobs=1, chunk_duration="500ms", progress_bar=True)
+# job_kwargs = dict(n_jobs=-1, chunk_duration="500ms", progress_bar=True)
+job_kwargs = dict(n_jobs=1, chunk_duration="500ms", progress_bar=True)


 def get_sorting_analyzer():
@@ -45,7 +45,7 @@ def test_find_spikes_from_templates(method, sorting_analyzer):
         "templates": templates,
     }
     method_kwargs = {}
-    if method in ("naive", "tdc-peeler", "circus"):
+    if method in ("naive", "tdc-peeler", "circus", "tdc-peeler2"):
         method_kwargs["noise_levels"] = noise_levels

     # method_kwargs["wobble"] = {
@@ -61,26 +61,28 @@ def test_find_spikes_from_templates(method, sorting_analyzer):

     # print(info)

-    # DEBUG = True
+    DEBUG = True

-    # if DEBUG:
-    #     import matplotlib.pyplot as plt
-    #     import spikeinterface.full as si
+    if DEBUG:
+        import matplotlib.pyplot as plt
+        import spikeinterface.full as si

-    #     sorting_analyzer.compute("waveforms")
-    #     sorting_analyzer.compute("templates")
+        sorting_analyzer.compute("waveforms")
+        sorting_analyzer.compute("templates")

-    #     gt_sorting = sorting_analyzer.sorting
+        gt_sorting = sorting_analyzer.sorting

-    #     sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["cluster_index"], recording.sampling_frequency)
+        sorting = NumpySorting.from_times_labels(
+            spikes["sample_index"], spikes["cluster_index"], recording.sampling_frequency
+        )

-    #     ##metrics = si.compute_quality_metrics(sorting_analyzer, metric_names=["snr"])
+        ##metrics = si.compute_quality_metrics(sorting_analyzer, metric_names=["snr"])

-    #     fig, ax = plt.subplots()
-    #     comp = si.compare_sorter_to_ground_truth(gt_sorting, sorting)
-    #     si.plot_agreement_matrix(comp, ax=ax)
-    #     ax.set_title(method)
-    #     plt.show()
+        # fig, ax = plt.subplots()
+        # comp = si.compare_sorter_to_ground_truth(gt_sorting, sorting)
+        # si.plot_agreement_matrix(comp, ax=ax)
+        # ax.set_title(method)
+        # plt.show()


 if __name__ == "__main__":
@@ -88,6 +90,6 @@
     # method = "naive"
     # method = "tdc-peeler"
     # method = "circus"
-    method = "circus-omp-svd"
-    # method = "wobble"
+    # method = "circus-omp-svd"
+    method = "wobble"
     test_find_spikes_from_templates(method, sorting_analyzer)
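For reference, the matching engine exercised by this test can be driven end to end roughly like this — the method name "tdc-peeler" and the keyword arguments come from this diff, while `recording`, `templates`, and `job_kwargs` are assumed to be prepared as in `get_sorting_analyzer()` above:

from spikeinterface.core import get_noise_levels
from spikeinterface.sortingcomponents.matching import find_spikes_from_templates

noise_levels = get_noise_levels(recording, return_scaled=False)
spikes = find_spikes_from_templates(
    recording,
    method="tdc-peeler",
    method_kwargs=dict(templates=templates, noise_levels=noise_levels),
    **job_kwargs,
)
# spikes is a structured array with "sample_index", "channel_index",
# "cluster_index" and "amplitude" fields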