diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 9de2762562..db3d88f116 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -3,7 +3,6 @@ import os import shutil import numpy as np -import os from spikeinterface.core import NumpySorting, load_extractor, BaseRecording, get_noise_levels, extract_waveforms from spikeinterface.core.job_tools import fix_job_kwargs @@ -21,18 +20,17 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): sorter_name = "spykingcircus2" _default_params = { - "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "waveforms": {"max_spikes_per_unit": 200, "overwrite": True}, + "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75}, + "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, "filtering": {"dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 5}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, "localization": {}, "clustering": {}, "matching": {}, - "registration": {}, "apply_preprocessing": True, - "shared_memory": False, - "job_kwargs": {}, + "shared_memory": True, + "job_kwargs": {"n_jobs": -1}, } @classmethod @@ -63,8 +61,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## First, we are filtering the data filtering_params = params["filtering"].copy() if params["apply_preprocessing"]: - # if recording.is_filtered == True: - # print('Looks like the recording is already filtered, check preprocessing!') recording_f = bandpass_filter(recording, **filtering_params) recording_f = common_reference(recording_f) else: @@ -103,8 +99,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We launch a clustering (using hdbscan) relying on positions and features extracted on ## the fly from the snippets clustering_params = params["clustering"].copy() - clustering_params.update(params["waveforms"]) - clustering_params.update(params["general"]) + clustering_params["waveforms_kwargs"] = params["waveforms"] + + for k in ["ms_before", "ms_after"]: + clustering_params["waveforms_kwargs"][k] = params["general"][k] + clustering_params.update(dict(shared_memory=params["shared_memory"])) clustering_params["job_kwargs"] = job_kwargs clustering_params["tmp_folder"] = sorter_output_folder / "clustering" @@ -126,6 +125,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): waveforms_params = params["waveforms"].copy() waveforms_params.update(job_kwargs) + for k in ["ms_before", "ms_after"]: + waveforms_params[k] = params["general"][k] + if params["shared_memory"]: mode = "memory" waveforms_folder = None @@ -143,6 +145,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_params.update({"noise_levels": noise_levels}) matching_job_params = job_kwargs.copy() + for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: + if value in matching_job_params: + matching_job_params.pop(value) + matching_job_params["chunk_duration"] = "100ms" spikes = find_spikes_from_templates( diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index 772c99bc0a..4efabbc9c5 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -600,29 +600,38 @@ def plot_comparison_matching( else: ax = axs[j] comp1, comp2 = comp_per_method[method1], comp_per_method[method2] - for performance, color in zip(performance_names, colors): - perf1 = comp1.get_performance()[performance] - perf2 = comp2.get_performance()[performance] - ax.plot(perf2, perf1, ".", label=performance, color=color) - ax.plot([0, 1], [0, 1], "k--", alpha=0.5) - ax.set_ylim(ylim) - ax.set_xlim(ylim) - ax.spines[["right", "top"]].set_visible(False) - ax.set_aspect("equal") - - if j == 0: - ax.set_ylabel(f"{method1}") - else: - ax.set_yticks([]) - if i == num_methods - 1: - ax.set_xlabel(f"{method2}") + if i <= j: + for performance, color in zip(performance_names, colors): + perf1 = comp1.get_performance()[performance] + perf2 = comp2.get_performance()[performance] + ax.plot(perf2, perf1, ".", label=performance, color=color) + + ax.plot([0, 1], [0, 1], "k--", alpha=0.5) + ax.set_ylim(ylim) + ax.set_xlim(ylim) + ax.spines[["right", "top"]].set_visible(False) + ax.set_aspect("equal") + + if j == i: + ax.set_ylabel(f"{method1}") + else: + ax.set_yticks([]) + if i == j: + ax.set_xlabel(f"{method2}") + else: + ax.set_xticks([]) + if i == num_methods - 1 and j == num_methods - 1: + patches = [] + for color, name in zip(colors, performance_names): + patches.append(mpatches.Patch(color=color, label=name)) + ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0) else: + ax.spines["bottom"].set_visible(False) + ax.spines["left"].set_visible(False) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) ax.set_xticks([]) - if i == num_methods - 1 and j == num_methods - 1: - patches = [] - for color, name in zip(colors, performance_names): - patches.append(mpatches.Patch(color=color, label=name)) - ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0) + ax.set_yticks([]) plt.tight_layout(h_pad=0, w_pad=0) return fig, axs diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 23fdbf1979..b87bbc7cee 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -536,7 +536,6 @@ def remove_duplicates_via_matching( waveform_extractor, noise_levels, peak_labels, - sparsify_threshold=1, method_kwargs={}, job_kwargs={}, tmp_folder=None, @@ -552,6 +551,10 @@ def remove_duplicates_via_matching( from pathlib import Path job_kwargs = fix_job_kwargs(job_kwargs) + + if waveform_extractor.is_sparse(): + sparsity = waveform_extractor.sparsity.mask + templates = waveform_extractor.get_all_templates(mode="median").copy() nb_templates = len(templates) duration = waveform_extractor.nbefore + waveform_extractor.nafter @@ -559,9 +562,9 @@ def remove_duplicates_via_matching( fs = waveform_extractor.recording.get_sampling_frequency() num_chans = waveform_extractor.recording.get_num_channels() - for t in range(nb_templates): - is_silent = templates[t].ptp(0) < sparsify_threshold - templates[t, :, is_silent] = 0 + if waveform_extractor.is_sparse(): + for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids): + templates[count][:, ~sparsity[count]] = 0 zdata = templates.reshape(nb_templates, -1) @@ -581,6 +584,7 @@ def remove_duplicates_via_matching( recording = BinaryRecordingExtractor(tmp_filename, num_channels=num_chans, sampling_frequency=fs, dtype="float32") recording.annotate(is_filtered=True) + recording = recording.set_probe(waveform_extractor.recording.get_probe()) margin = 2 * max(waveform_extractor.nbefore, waveform_extractor.nafter) half_marging = margin // 2 @@ -597,7 +601,6 @@ def remove_duplicates_via_matching( "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, "amplitudes": [0.95, 1.05], - "sparsify_threshold": sparsify_threshold, "omp_min_sps": 0.1, "templates": None, "overlaps": None, diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index fcbcac097f..be8ecd6702 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -41,7 +41,6 @@ class RandomProjectionClustering: "ms_before": 1.5, "ms_after": 1.5, "random_seed": 42, - "cleaning_method": "matching", "shared_memory": False, "min_values": {"ptp": 0, "energy": 0}, "tmp_folder": None, @@ -160,86 +159,60 @@ def main_function(cls, recording, peaks, params): spikes["segment_index"] = peaks[mask]["segment_index"] spikes["unit_index"] = peak_labels[mask] - cleaning_method = params["cleaning_method"] - if verbose: - print("We found %d raw clusters, starting to clean with %s..." % (len(labels), cleaning_method)) - - if cleaning_method == "cosine": - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - labels, - nbefore, - nafter, - mode="shared_memory", - return_scaled=False, - folder=None, - dtype=recording.get_dtype(), - sparsity_mask=None, - copy=True, - **params["job_kwargs"], - ) - - labels, peak_labels = remove_duplicates( - wfs_arrays, noise_levels, peak_labels, num_samples, num_chans, **params["cleaning_kwargs"] - ) - - elif cleaning_method == "dip": - wfs_arrays = {} - for label in labels: - mask = label == peak_labels - wfs_arrays[label] = hdbscan_data[mask] - - labels, peak_labels = remove_duplicates_via_dip(wfs_arrays, peak_labels, **params["cleaning_kwargs"]) - - elif cleaning_method == "matching": - # create a tmp folder - if params["tmp_folder"] is None: - name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) - tmp_folder = get_global_tmp_folder() / name - else: - tmp_folder = Path(params["tmp_folder"]) - - if params["shared_memory"]: - waveform_folder = None - mode = "memory" - else: - waveform_folder = tmp_folder / "waveforms" - mode = "folder" - - sorting_folder = tmp_folder / "sorting" - sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["unit_index"], fs) - sorting = sorting.save(folder=sorting_folder) - we = extract_waveforms( - recording, - sorting, - waveform_folder, - ms_before=params["ms_before"], - ms_after=params["ms_after"], - **params["job_kwargs"], - return_scaled=False, - mode=mode, - ) - - cleaning_matching_params = params["job_kwargs"].copy() - cleaning_matching_params["chunk_duration"] = "100ms" - cleaning_matching_params["n_jobs"] = 1 - cleaning_matching_params["verbose"] = False - cleaning_matching_params["progress_bar"] = False - - cleaning_params = params["cleaning_kwargs"].copy() - cleaning_params["tmp_folder"] = tmp_folder - - labels, peak_labels = remove_duplicates_via_matching( - we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params - ) - - if params["tmp_folder"] is None: - shutil.rmtree(tmp_folder) - else: + print("We found %d raw clusters, starting to clean with matching..." % (len(labels))) + + # create a tmp folder + if params["tmp_folder"] is None: + name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) + tmp_folder = get_global_tmp_folder() / name + else: + tmp_folder = Path(params["tmp_folder"]) + + if params["shared_memory"]: + waveform_folder = None + mode = "memory" + else: + waveform_folder = tmp_folder / "waveforms" + mode = "folder" + + sorting_folder = tmp_folder / "sorting" + unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) + sorting = NumpySorting(spikes, fs, unit_ids=unit_ids) + sorting = sorting.save(folder=sorting_folder) + we = extract_waveforms( + recording, + sorting, + waveform_folder, + ms_before=params["ms_before"], + ms_after=params["ms_after"], + **params["job_kwargs"], + return_scaled=False, + mode=mode, + ) + + cleaning_matching_params = params["job_kwargs"].copy() + for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: + if value in cleaning_matching_params: + cleaning_matching_params.pop(value) + cleaning_matching_params["chunk_duration"] = "100ms" + cleaning_matching_params["n_jobs"] = 1 + cleaning_matching_params["verbose"] = False + cleaning_matching_params["progress_bar"] = False + + cleaning_params = params["cleaning_kwargs"].copy() + cleaning_params["tmp_folder"] = tmp_folder + + labels, peak_labels = remove_duplicates_via_matching( + we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + ) + + if params["tmp_folder"] is None: + shutil.rmtree(tmp_folder) + else: + if not params["shared_memory"]: shutil.rmtree(tmp_folder / "waveforms") - shutil.rmtree(tmp_folder / "sorting") + shutil.rmtree(tmp_folder / "sorting") if verbose: print("We kept %d non-duplicated clusters..." % len(labels)) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 2196320378..a19e7b71b5 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -5,7 +5,6 @@ import scipy.spatial -from tqdm import tqdm import scipy try: @@ -16,7 +15,8 @@ except ImportError: HAVE_SKLEARN = False -from spikeinterface.core import get_noise_levels, get_random_data_chunks + +from spikeinterface.core import get_noise_levels, get_random_data_chunks, compute_sparsity from spikeinterface.sortingcomponents.peak_detection import DetectPeakByChannel (potrs,) = scipy.linalg.get_lapack_funcs(("potrs",), dtype=np.float32) @@ -131,6 +131,38 @@ def _freq_domain_conv(in1, in2, axes, shape, cache, calc_fast_len=True): return ret +def compute_overlaps(templates, num_samples, num_channels, sparsities): + num_templates = len(templates) + + dense_templates = np.zeros((num_templates, num_samples, num_channels), dtype=np.float32) + for i in range(num_templates): + dense_templates[i, :, sparsities[i]] = templates[i].T + + size = 2 * num_samples - 1 + + all_delays = list(range(0, num_samples + 1)) + + overlaps = {} + + for delay in all_delays: + source = dense_templates[:, :delay, :].reshape(num_templates, -1) + target = dense_templates[:, num_samples - delay :, :].reshape(num_templates, -1) + + overlaps[delay] = scipy.sparse.csr_matrix(source.dot(target.T)) + + if delay < num_samples: + overlaps[size - delay + 1] = overlaps[delay].T.tocsr() + + new_overlaps = [] + + for i in range(num_templates): + data = [overlaps[j][i, :].T for j in range(size)] + data = scipy.sparse.hstack(data) + new_overlaps += [data] + + return new_overlaps + + class CircusOMPPeeler(BaseTemplateMatchingEngine): """ Orthogonal Matching Pursuit inspired from Spyking Circus sorter @@ -152,21 +184,18 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): (Minimal, Maximal) amplitudes allowed for every template omp_min_sps: float Stopping criteria of the OMP algorithm, in percentage of the norm - sparsify_threshold: float - Templates are sparsified in order to keep only the channels necessary - to explain. ptp limit for considering a channel as silent - smoothing_factor: float - Templates are smoothed via Spline Interpolation noise_levels: array The noise levels, for every channels. If None, they will be automatically computed random_chunk_kwargs: dict Parameters for computing noise levels, if not provided (sub optimal) + sparse_kwargs: dict + Parameters to extract a sparsity mask from the waveform_extractor, if not + already sparse. ----- """ _default_params = { - "sparsify_threshold": 1, "amplitudes": [0.6, 2], "omp_min_sps": 0.1, "waveform_extractor": None, @@ -175,36 +204,21 @@ class CircusOMPPeeler(BaseTemplateMatchingEngine): "norms": None, "random_chunk_kwargs": {}, "noise_levels": None, - "smoothing_factor": 0.25, + "sparse_kwargs": {"method": "ptp", "threshold": 1}, "ignored_ids": [], + "vicinity": 0, } - @classmethod - def _sparsify_template(cls, template, sparsify_threshold): - is_silent = template.ptp(0) < sparsify_threshold - template[:, is_silent] = 0 - (active_channels,) = np.where(np.logical_not(is_silent)) - - return template, active_channels - - @classmethod - def _regularize_template(cls, template, smoothing_factor=0.25): - nb_channels = template.shape[1] - nb_timesteps = template.shape[0] - xaxis = np.arange(nb_timesteps) - for i in range(nb_channels): - z = scipy.interpolate.UnivariateSpline(xaxis, template[:, i]) - z.set_smoothing_factor(smoothing_factor) - template[:, i] = z(xaxis) - return template - @classmethod def _prepare_templates(cls, d): waveform_extractor = d["waveform_extractor"] - num_samples = d["num_samples"] - num_channels = d["num_channels"] num_templates = len(d["waveform_extractor"].sorting.unit_ids) + if not waveform_extractor.is_sparse(): + sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask + else: + sparsity = waveform_extractor.sparsity.mask + templates = waveform_extractor.get_all_templates(mode="median").copy() d["sparsities"] = {} @@ -212,52 +226,10 @@ def _prepare_templates(cls, d): d["norms"] = np.zeros(num_templates, dtype=np.float32) for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids): - if d["smoothing_factor"] > 0: - template = cls._regularize_template(templates[count], d["smoothing_factor"]) - else: - template = templates[count] - template, active_channels = cls._sparsify_template(template, d["sparsify_threshold"]) - d["sparsities"][count] = active_channels + template = templates[count][:, sparsity[count]] + (d["sparsities"][count],) = np.nonzero(sparsity[count]) d["norms"][count] = np.linalg.norm(template) - d["templates"][count] = template[:, active_channels] / d["norms"][count] - - return d - - @classmethod - def _prepare_overlaps(cls, d): - templates = d["templates"] - num_samples = d["num_samples"] - num_channels = d["num_channels"] - num_templates = d["num_templates"] - sparsities = d["sparsities"] - - dense_templates = np.zeros((num_templates, num_samples, num_channels), dtype=np.float32) - for i in range(num_templates): - dense_templates[i, :, sparsities[i]] = templates[i].T - - size = 2 * num_samples - 1 - - all_delays = list(range(0, num_samples + 1)) - - overlaps = {} - - for delay in all_delays: - source = dense_templates[:, :delay, :].reshape(num_templates, -1) - target = dense_templates[:, num_samples - delay :, :].reshape(num_templates, -1) - - overlaps[delay] = scipy.sparse.csr_matrix(source.dot(target.T)) - - if delay < num_samples: - overlaps[size - delay + 1] = overlaps[delay].T.tocsr() - - new_overlaps = [] - - for i in range(num_templates): - data = [overlaps[j][i, :].T for j in range(size)] - data = scipy.sparse.hstack(data) - new_overlaps += [data] - - d["overlaps"] = new_overlaps + d["templates"][count] = template / d["norms"][count] return d @@ -276,6 +248,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["nbefore"] = d["waveform_extractor"].nbefore d["nafter"] = d["waveform_extractor"].nafter d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() + d["vicinity"] *= d["num_samples"] if d["noise_levels"] is None: print("CircusOMPPeeler : noise should be computed outside") @@ -290,15 +263,12 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["num_templates"] = len(d["templates"]) if d["overlaps"] is None: - d = cls._prepare_overlaps(d) + d["overlaps"] = compute_overlaps(d["templates"], d["num_samples"], d["num_channels"], d["sparsities"]) d["ignored_ids"] = np.array(d["ignored_ids"]) omp_min_sps = d["omp_min_sps"] - norms = d["norms"] - sparsities = d["sparsities"] - - nb_active_channels = np.array([len(sparsities[i]) for i in range(d["num_templates"])]) + # nb_active_channels = np.array([len(d['sparsities'][count]) for count in range(d['num_templates'])]) d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) return d @@ -336,6 +306,7 @@ def main_function(cls, traces, d): sparsities = d["sparsities"] ignored_ids = d["ignored_ids"] stop_criteria = d["stop_criteria"] + vicinity = d["vicinity"] if "cached_fft_kernels" not in d: d["cached_fft_kernels"] = {"fshape": 0} @@ -381,7 +352,7 @@ def main_function(cls, traces, d): spikes = np.empty(scalar_products.size, dtype=spike_dtype) idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) - M = np.zeros((num_peaks, num_peaks), dtype=np.float32) + M = np.zeros((100, 100), dtype=np.float32) all_selections = np.empty((2, scalar_products.size), dtype=np.int32) final_amplitudes = np.zeros(scalar_products.shape, dtype=np.float32) @@ -393,6 +364,8 @@ def main_function(cls, traces, d): cached_overlaps = {} is_valid = scalar_products > stop_criteria + all_amplitudes = np.zeros(0, dtype=np.float32) + is_in_vicinity = np.zeros(0, dtype=np.int32) while np.any(is_valid): best_amplitude_ind = scalar_products[is_valid].argmax() @@ -412,20 +385,39 @@ def main_function(cls, traces, d): M = Z M[num_selection, idx] = cached_overlaps[best_cluster_ind][selection[0, idx], myline] - scipy.linalg.solve_triangular( - M[:num_selection, :num_selection], - M[num_selection, :num_selection], - trans=0, - lower=1, - overwrite_b=True, - check_finite=False, - ) - - v = nrm2(M[num_selection, :num_selection]) ** 2 - Lkk = 1 - v - if Lkk <= omp_tol: # selected atoms are dependent - break - M[num_selection, num_selection] = np.sqrt(Lkk) + + if vicinity == 0: + scipy.linalg.solve_triangular( + M[:num_selection, :num_selection], + M[num_selection, :num_selection], + trans=0, + lower=1, + overwrite_b=True, + check_finite=False, + ) + + v = nrm2(M[num_selection, :num_selection]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + is_in_vicinity = np.where(np.abs(delta_t) < vicinity)[0] + + if len(is_in_vicinity) > 0: + L = M[is_in_vicinity, :][:, is_in_vicinity] + + M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular( + L, M[num_selection, is_in_vicinity], trans=0, lower=1, overwrite_b=True, check_finite=False + ) + + v = nrm2(M[num_selection, is_in_vicinity]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + M[num_selection, num_selection] = 1.0 else: M[0, 0] = 1 @@ -435,9 +427,16 @@ def main_function(cls, traces, d): selection = all_selections[:, :num_selection] res_sps = full_sps[selection[0], selection[1]] - all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) - - all_amplitudes /= norms[selection[0]] + if True: # vicinity == 0: + all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) + all_amplitudes /= norms[selection[0]] + else: + # This is not working, need to figure out why + is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) + all_amplitudes = np.append(all_amplitudes, np.float32(1)) + L = M[is_in_vicinity, :][:, is_in_vicinity] + all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False) + all_amplitudes[is_in_vicinity] /= norms[selection[0][is_in_vicinity]] diff_amplitudes = all_amplitudes - final_amplitudes[selection[0], selection[1]] modified = np.where(np.abs(diff_amplitudes) > omp_tol)[0] @@ -515,14 +514,12 @@ class CircusPeeler(BaseTemplateMatchingEngine): Maximal amplitude allowed for every template min_amplitude: float Minimal amplitude allowed for every template - sparsify_threshold: float - Templates are sparsified in order to keep only the channels necessary - to explain a given fraction of the total norm use_sparse_matrix_threshold: float If density of the templates is below a given threshold, sparse matrix are used (memory efficient) - progress_bar_steps: bool - In order to display or not steps from the algorithm + sparse_kwargs: dict + Parameters to extract a sparsity mask from the waveform_extractor, if not + already sparse. ----- @@ -535,68 +532,40 @@ class CircusPeeler(BaseTemplateMatchingEngine): "detect_threshold": 5, "noise_levels": None, "random_chunk_kwargs": {}, - "sparsify_threshold": 0.99, "max_amplitude": 1.5, "min_amplitude": 0.5, "use_sparse_matrix_threshold": 0.25, - "progess_bar_steps": False, "waveform_extractor": None, - "smoothing_factor": 0.25, + "sparse_kwargs": {"method": "ptp", "threshold": 1}, } - @classmethod - def _sparsify_template(cls, template, sparsify_threshold, noise_levels): - is_silent = template.std(0) < 0.1 * noise_levels - - template[:, is_silent] = 0 - - channel_norms = np.linalg.norm(template, axis=0) ** 2 - total_norm = np.linalg.norm(template) ** 2 - - idx = np.argsort(channel_norms)[::-1] - explained_norms = np.cumsum(channel_norms[idx] / total_norm) - channel = np.searchsorted(explained_norms, sparsify_threshold) - active_channels = np.sort(idx[:channel]) - template[:, idx[channel:]] = 0 - return template, active_channels - - @classmethod - def _regularize_template(cls, template, smoothing_factor=0.25): - nb_channels = template.shape[1] - nb_timesteps = template.shape[0] - xaxis = np.arange(nb_timesteps) - for i in range(nb_channels): - z = scipy.interpolate.UnivariateSpline(xaxis, template[:, i]) - z.set_smoothing_factor(smoothing_factor) - template[:, i] = z(xaxis) - return template - @classmethod def _prepare_templates(cls, d): - parameters = d - waveform_extractor = parameters["waveform_extractor"] - num_samples = parameters["num_samples"] - num_channels = parameters["num_channels"] - num_templates = parameters["num_templates"] - max_amplitude = parameters["max_amplitude"] - min_amplitude = parameters["min_amplitude"] - use_sparse_matrix_threshold = parameters["use_sparse_matrix_threshold"] + waveform_extractor = d["waveform_extractor"] + num_samples = d["num_samples"] + num_channels = d["num_channels"] + num_templates = d["num_templates"] + use_sparse_matrix_threshold = d["use_sparse_matrix_threshold"] + + d["norms"] = np.zeros(num_templates, dtype=np.float32) - parameters["norms"] = np.zeros(num_templates, dtype=np.float32) + all_units = list(d["waveform_extractor"].sorting.unit_ids) - all_units = list(parameters["waveform_extractor"].sorting.unit_ids) + if not waveform_extractor.is_sparse(): + sparsity = compute_sparsity(waveform_extractor, **d["sparse_kwargs"]).mask + else: + sparsity = waveform_extractor.sparsity.mask templates = waveform_extractor.get_all_templates(mode="median").copy() + d["sparsities"] = {} + d["circus_templates"] = {} for count, unit_id in enumerate(all_units): - if parameters["smoothing_factor"] > 0: - templates[count] = cls._regularize_template(templates[count], parameters["smoothing_factor"]) - - templates[count], _ = cls._sparsify_template( - templates[count], parameters["sparsify_threshold"], parameters["noise_levels"] - ) - parameters["norms"][count] = np.linalg.norm(templates[count]) - templates[count] /= parameters["norms"][count] + (d["sparsities"][count],) = np.nonzero(sparsity[count]) + templates[count][:, ~sparsity[count]] = 0 + d["norms"][count] = np.linalg.norm(templates[count]) + templates[count] /= d["norms"][count] + d["circus_templates"][count] = templates[count][:, sparsity[count]] templates = templates.reshape(num_templates, -1) @@ -604,54 +573,11 @@ def _prepare_templates(cls, d): if nnz <= use_sparse_matrix_threshold: templates = scipy.sparse.csr_matrix(templates) print(f"Templates are automatically sparsified (sparsity level is {nnz})") - parameters["is_dense"] = False + d["is_dense"] = False else: - parameters["is_dense"] = True - - parameters["templates"] = templates + d["is_dense"] = True - return parameters - - @classmethod - def _prepare_overlaps(cls, d): - templates = d["templates"] - num_samples = d["num_samples"] - num_channels = d["num_channels"] - num_templates = d["num_templates"] - is_dense = d["is_dense"] - - if not is_dense: - dense_templates = templates.toarray() - else: - dense_templates = templates - - dense_templates = dense_templates.reshape(num_templates, num_samples, num_channels) - - size = 2 * num_samples - 1 - - all_delays = list(range(0, num_samples + 1)) - if d["progess_bar_steps"]: - all_delays = tqdm(all_delays, desc="[1] compute overlaps") - - overlaps = {} - - for delay in all_delays: - source = dense_templates[:, :delay, :].reshape(num_templates, -1) - target = dense_templates[:, num_samples - delay :, :].reshape(num_templates, -1) - - overlaps[delay] = scipy.sparse.csr_matrix(source.dot(target.T)) - - if delay < num_samples: - overlaps[size - delay] = overlaps[delay].T.tocsr() - - new_overlaps = [] - - for i in range(num_templates): - data = [overlaps[j][i, :].T for j in range(size)] - data = scipy.sparse.hstack(data) - new_overlaps += [data] - - d["overlaps"] = new_overlaps + d["templates"] = templates return d @@ -687,15 +613,13 @@ def _optimize_amplitudes(cls, noise_snippets, d): alpha = 0.5 norms = parameters["norms"] all_units = list(waveform_extractor.sorting.unit_ids) - if parameters["progess_bar_steps"]: - all_units = tqdm(all_units, desc="[2] compute amplitudes") parameters["amplitudes"] = np.zeros((num_templates, 2), dtype=np.float32) noise = templates.dot(noise_snippets) / norms[:, np.newaxis] all_amps = {} for count, unit_id in enumerate(all_units): - waveform = waveform_extractor.get_waveforms(unit_id) + waveform = waveform_extractor.get_waveforms(unit_id, force_dense=True) snippets = waveform.reshape(waveform.shape[0], -1).T amps = templates.dot(snippets) / norms[:, np.newaxis] good = amps[count, :].flatten() @@ -708,16 +632,6 @@ def _optimize_amplitudes(cls, noise_snippets, d): res = scipy.optimize.differential_evolution(cls._cost_function_mcc, bounds=cost_bounds, args=cost_kwargs) parameters["amplitudes"][count] = res.x - # import pylab as plt - # plt.hist(good, 100, alpha=0.5) - # plt.hist(bad, 100, alpha=0.5) - # plt.hist(noise[count], 100, alpha=0.5) - # ymin, ymax = plt.ylim() - # plt.plot([res.x[0], res.x[0]], [ymin, ymax], 'k--') - # plt.plot([res.x[1], res.x[1]], [ymin, ymax], 'k--') - # plt.savefig('test_%d.png' %count) - # plt.close() - return d @classmethod @@ -727,8 +641,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): default_parameters.update(kwargs) # assert isinstance(d['waveform_extractor'], WaveformExtractor) - - for v in ["sparsify_threshold", "use_sparse_matrix_threshold"]: + for v in ["use_sparse_matrix_threshold"]: assert (default_parameters[v] >= 0) and (default_parameters[v] <= 1), f"{v} should be in [0, 1]" default_parameters["num_channels"] = default_parameters["waveform_extractor"].recording.get_num_channels() @@ -746,7 +659,13 @@ def initialize_and_check_kwargs(cls, recording, kwargs): ) default_parameters = cls._prepare_templates(default_parameters) - default_parameters = cls._prepare_overlaps(default_parameters) + + default_parameters["overlaps"] = compute_overlaps( + default_parameters["circus_templates"], + default_parameters["num_samples"], + default_parameters["num_channels"], + default_parameters["sparsities"], + ) default_parameters["exclude_sweep_size"] = int( default_parameters["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0 @@ -817,31 +736,31 @@ def main_function(cls, traces, d): sym_patch = d["sym_patch"] peak_traces = traces[margin // 2 : -margin // 2, :] - peak_sample_ind, peak_chan_ind = DetectPeakByChannel.detect_peaks( + peak_sample_index, peak_chan_ind = DetectPeakByChannel.detect_peaks( peak_traces, peak_sign, abs_threholds, exclude_sweep_size ) if jitter > 0: - jittered_peaks = peak_sample_ind[:, np.newaxis] + np.arange(-jitter, jitter) + jittered_peaks = peak_sample_index[:, np.newaxis] + np.arange(-jitter, jitter) jittered_channels = peak_chan_ind[:, np.newaxis] + np.zeros(2 * jitter) mask = (jittered_peaks > 0) & (jittered_peaks < len(peak_traces)) jittered_peaks = jittered_peaks[mask] jittered_channels = jittered_channels[mask] - peak_sample_ind, unique_idx = np.unique(jittered_peaks, return_index=True) + peak_sample_index, unique_idx = np.unique(jittered_peaks, return_index=True) peak_chan_ind = jittered_channels[unique_idx] else: - peak_sample_ind, unique_idx = np.unique(peak_sample_ind, return_index=True) + peak_sample_index, unique_idx = np.unique(peak_sample_index, return_index=True) peak_chan_ind = peak_chan_ind[unique_idx] - num_peaks = len(peak_sample_ind) + num_peaks = len(peak_sample_index) if sym_patch: - snippets = extract_patches_2d(traces, patch_sizes)[peak_sample_ind] - peak_sample_ind += margin // 2 + snippets = extract_patches_2d(traces, patch_sizes)[peak_sample_index] + peak_sample_index += margin // 2 else: - peak_sample_ind += margin // 2 + peak_sample_index += margin // 2 snippet_window = np.arange(-d["nbefore"], d["nafter"]) - snippets = traces[peak_sample_ind[:, np.newaxis] + snippet_window] + snippets = traces[peak_sample_index[:, np.newaxis] + snippet_window] if num_peaks > 0: snippets = snippets.reshape(num_peaks, -1) @@ -865,10 +784,10 @@ def main_function(cls, traces, d): best_cluster_ind, peak_index = np.unravel_index(idx_lookup[is_valid][best_amplitude_ind], idx_lookup.shape) best_amplitude = scalar_products[best_cluster_ind, peak_index] - best_peak_sample_ind = peak_sample_ind[peak_index] + best_peak_sample_index = peak_sample_index[peak_index] best_peak_chan_ind = peak_chan_ind[peak_index] - peak_data = peak_sample_ind - peak_sample_ind[peak_index] + peak_data = peak_sample_index - peak_sample_index[peak_index] is_valid_nn = np.searchsorted(peak_data, [-neighbor_window, neighbor_window + 1]) idx_neighbor = peak_data[is_valid_nn[0] : is_valid_nn[1]] + neighbor_window @@ -880,7 +799,7 @@ def main_function(cls, traces, d): scalar_products[:, is_valid_nn[0] : is_valid_nn[1]] += to_add scalar_products[best_cluster_ind, is_valid_nn[0] : is_valid_nn[1]] = -np.inf - spikes["sample_index"][num_spikes] = best_peak_sample_ind + spikes["sample_index"][num_spikes] = best_peak_sample_index spikes["channel_index"][num_spikes] = best_peak_chan_ind spikes["cluster_index"][num_spikes] = best_cluster_ind spikes["amplitude"][num_spikes] = best_amplitude