diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index b0d470fe40..c322d61230 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -7,6 +7,8 @@
 from spikeinterface.core import NumpySorting, load_extractor, BaseRecording, get_noise_levels, extract_waveforms
 from spikeinterface.core.job_tools import fix_job_kwargs
 from spikeinterface.preprocessing import common_reference, zscore, whiten, highpass_filter
+from spikeinterface.sortingcomponents.tools import cache_preprocessing
+from spikeinterface.core.basesorting import minimum_spike_dtype
 
 try:
     import hdbscan
@@ -30,17 +32,52 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
         },
         "filtering": {"freq_min": 150, "dtype": "float32"},
         "detection": {"peak_sign": "neg", "detect_threshold": 4},
-        "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000},
+        "selection": {
+            "method": "smart_sampling_amplitudes",
+            "n_peaks_per_channel": 5000,
+            "min_n_peaks": 20000,
+            "select_per_channel": False,
+        },
         "clustering": {"legacy": False},
-        "matching": {},
+        "matching": {"method": "circus-omp-svd", "method_kwargs": {}},
         "apply_preprocessing": True,
         "shared_memory": True,
-        "job_kwargs": {"n_jobs": -1},
+        "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
+        "multi_units_only": False,
+        "job_kwargs": {"n_jobs": 0.8},
         "debug": False,
     }
 
     handle_multi_segment = True
 
+    _params_description = {
+        "general": "A dictionary to describe how templates should be computed. User can define ms_before and ms_after (in ms) \
+        and also the radius_um to be considered during clustering",
+        "waveforms": "A dictionary to be passed to all the calls to extract_waveforms that will be performed internally. Default is \
+        to consider sparse waveforms",
+        "filtering": "A dictionary for the high_pass filter to be used during preprocessing",
+        "detection": "A dictionary for the peak detection node (locally_exclusive)",
+        "selection": "A dictionary for the peak selection node. Default is to use smart_sampling_amplitudes, with a minimum of 20000 peaks \
+        and 5000 peaks per electrode on average.",
+        "clustering": "A dictionary to be provided to the clustering method. By default, random_projections is used, but if legacy is set to \
+        True, another clustering method called circus will be used, similar to the one used in Spyking Circus 1",
+        "matching": "A dictionary to specify the matching engine used to recover spikes. The default method is circus-omp-svd, but other engines \
+        can be used",
+        "apply_preprocessing": "Boolean to specify whether circus 2 should preprocess the recording or not. If yes, then high_pass filtering + common \
+        median reference + zscore",
+        "shared_memory": "Boolean to specify if the code should, as much as possible, use an internal data structure in memory (faster)",
+        "cache_preprocessing": "How to cache the preprocessed recording. Mode can be memory, folder, zarr, with extra arguments. In case of memory (default), \
+        memory_limit will control how much RAM can be used. In case of folder or zarr, delete_cache controls if the cache is cleaned after sorting",
+        "multi_units_only": "Boolean to get only multi unit activity (i.e. one template per electrode)",
+        "job_kwargs": "A dictionary to specify how many jobs and which parameters they should use",
+        "debug": "Boolean to specify if the internal data structure should be kept for debugging",
+    }
+
+    sorter_description = """Spyking Circus 2 is a rewriting of Spyking Circus, within the SpikeInterface framework.
+    It uses a more conservative clustering algorithm (compared to Spyking Circus), which is less prone to hallucinate units and/or find noise.
+    In addition, it also uses a full Orthogonal Matching Pursuit engine to reconstruct the traces, leading to more spikes
+    being discovered."""
+
     @classmethod
     def get_sorter_version(cls):
         return "2.0"
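For reference, here is a minimal usage sketch (not part of the patch) of how the new defaults above can be overridden when the sorter is launched through the SpikeInterface API. The generated recording is a toy one and every value shown is purely illustrative.

# Illustrative only: run SpykingCircus2 on a toy recording, overriding a few of the
# defaults defined above. Parameter values are assumptions, not recommendations.
from spikeinterface.core import generate_recording
from spikeinterface.sorters import run_sorter

recording = generate_recording(num_channels=16, durations=[10.0])  # toy recording
sorting = run_sorter(
    "spykingcircus2",
    recording,
    output_folder="sc2_output",
    apply_preprocessing=True,
    matching={"method": "circus-omp-svd", "method_kwargs": {}},
    cache_preprocessing={"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
    job_kwargs={"n_jobs": 0.8},
    verbose=True,
)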
@@ -62,14 +99,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
 
         recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
 
-        sampling_rate = recording.get_sampling_frequency()
+        sampling_frequency = recording.get_sampling_frequency()
         num_channels = recording.get_num_channels()
 
         ## First, we are filtering the data
         filtering_params = params["filtering"].copy()
         if params["apply_preprocessing"]:
             recording_f = highpass_filter(recording, **filtering_params)
-            recording_f = common_reference(recording_f)
+            if num_channels > 1:
+                recording_f = common_reference(recording_f)
         else:
             recording_f = recording
             recording_f.annotate(is_filtered=True)
@@ -78,6 +116,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         recording_f = zscore(recording_f, dtype="float32")
         noise_levels = np.ones(num_channels, dtype=np.float32)
 
+        if recording_f.check_serializability("json"):
+            recording_f.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None)
+        elif recording_f.check_serializability("pickle"):
+            recording_f.dump(sorter_output_folder / "preprocessed_recording.pickle", relative_to=None)
+
+        recording_f = cache_preprocessing(recording_f, **job_kwargs, **params["cache_preprocessing"])
+
         ## Then, we are detecting peaks with a locally_exclusive method
         detection_params = params["detection"].copy()
         detection_params.update(job_kwargs)
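The two hunks above amount to the following stand-alone preprocessing chain when apply_preprocessing=True. This is only a sketch with a toy generated recording; the values mirror the sorter defaults and the names are assumptions.

# Sketch of the equivalent stand-alone preprocessing chain applied when
# apply_preprocessing=True (toy recording; values mirror the defaults above).
from spikeinterface.core import generate_recording
from spikeinterface.preprocessing import common_reference, highpass_filter, zscore

recording = generate_recording(num_channels=4, durations=[5.0])
recording_f = highpass_filter(recording, freq_min=150, dtype="float32")
if recording_f.get_num_channels() > 1:
    # the common median reference is now skipped for single-channel recordings
    recording_f = common_reference(recording_f)
recording_f = zscore(recording_f, dtype="float32")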
@@ -85,122 +130,153 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         detection_params["radius_um"] = params["general"]["radius_um"]
         if "exclude_sweep_ms" not in detection_params:
             detection_params["exclude_sweep_ms"] = max(params["general"]["ms_before"], params["general"]["ms_after"])
+        detection_params["noise_levels"] = noise_levels
 
         peaks = detect_peaks(recording_f, method="locally_exclusive", **detection_params)
 
         if verbose:
             print("We found %d peaks in total" % len(peaks))
 
-        ## We subselect a subset of all the peaks, by making the distributions os SNRs over all
-        ## channels as flat as possible
-        selection_params = params["selection"]
-        selection_params["n_peaks"] = params["selection"]["n_peaks_per_channel"] * num_channels
-        selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"])
-
-        selection_params.update({"noise_levels": noise_levels})
-        selected_peaks = select_peaks(
-            peaks, method="smart_sampling_amplitudes", select_per_channel=False, **selection_params
-        )
-
-        if verbose:
-            print("We kept %d peaks for clustering" % len(selected_peaks))
-
-        ## We launch a clustering (using hdbscan) relying on positions and features extracted on
-        ## the fly from the snippets
-        clustering_params = params["clustering"].copy()
-        clustering_params["waveforms"] = params["waveforms"].copy()
-
-        for k in ["ms_before", "ms_after"]:
-            clustering_params["waveforms"][k] = params["general"][k]
-
-        clustering_params.update(dict(shared_memory=params["shared_memory"]))
-        clustering_params["job_kwargs"] = job_kwargs
-        clustering_params["tmp_folder"] = sorter_output_folder / "clustering"
-        clustering_params.update({"noise_levels": noise_levels})
-
-        if "legacy" in clustering_params:
-            legacy = clustering_params.pop("legacy")
+        if params["multi_units_only"]:
+            sorting = NumpySorting.from_peaks(peaks, sampling_frequency, unit_ids=recording_f.unit_ids)
         else:
-            legacy = False
+            ## We subselect a subset of all the peaks, by making the distribution of SNRs over all
+            ## channels as flat as possible
+            selection_params = params["selection"]
+            selection_params["n_peaks"] = params["selection"]["n_peaks_per_channel"] * num_channels
+            selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"])
+
+            selection_params.update({"noise_levels": noise_levels})
+            selected_peaks = select_peaks(peaks, **selection_params)
+
+            if verbose:
+                print("We kept %d peaks for clustering" % len(selected_peaks))
+
+            ## We launch a clustering (using hdbscan) relying on positions and features extracted on
+            ## the fly from the snippets
+            clustering_params = params["clustering"].copy()
+            clustering_params["waveforms"] = params["waveforms"].copy()
+
+            for k in ["ms_before", "ms_after"]:
+                clustering_params["waveforms"][k] = params["general"][k]
+
+            clustering_params.update(dict(shared_memory=params["shared_memory"]))
+            clustering_params["job_kwargs"] = job_kwargs
+            clustering_params["tmp_folder"] = sorter_output_folder / "clustering"
+
+            if "legacy" in clustering_params:
+                legacy = clustering_params.pop("legacy")
+            else:
+                legacy = False
+
+            if legacy:
+                if verbose:
+                    print("We are using the legacy mode for the clustering")
+                clustering_method = "circus"
+            else:
+                clustering_method = "random_projections"
+
+            labels, peak_labels = find_cluster_from_peaks(
+                recording_f, selected_peaks, method=clustering_method, method_kwargs=clustering_params
+            )
+
+            ## We get the labels for our peaks
+            mask = peak_labels > -1
+
+            labeled_peaks = np.zeros(np.sum(mask), dtype=minimum_spike_dtype)
+            labeled_peaks["sample_index"] = selected_peaks[mask]["sample_index"]
+            labeled_peaks["segment_index"] = selected_peaks[mask]["segment_index"]
+            for count, l in enumerate(labels):
+                sub_mask = peak_labels[mask] == l
+                labeled_peaks["unit_index"][sub_mask] = count
+            unit_ids = np.arange(len(np.unique(labeled_peaks["unit_index"])))
+            sorting = NumpySorting(labeled_peaks, sampling_frequency, unit_ids=unit_ids)
+
+            clustering_folder = sorter_output_folder / "clustering"
+            clustering_folder.mkdir(parents=True, exist_ok=True)
+
+            if not params["debug"]:
+                shutil.rmtree(clustering_folder)
+            else:
+                np.save(clustering_folder / "labels", labels)
+                np.save(clustering_folder / "peaks", selected_peaks)
+
+            ## We get the templates out of such a clustering
+            waveforms_params = params["waveforms"].copy()
+            waveforms_params.update(job_kwargs)
+
+            for k in ["ms_before", "ms_after"]:
+                waveforms_params[k] = params["general"][k]
+
+            if params["shared_memory"] and not params["debug"]:
+                mode = "memory"
+                waveforms_folder = None
+            else:
+                sorting = sorting.save(folder=clustering_folder / "sorting")
+                mode = "folder"
+                waveforms_folder = sorter_output_folder / "waveforms"
+
+            we = extract_waveforms(
+                recording_f,
+                sorting,
+                waveforms_folder,
+                return_scaled=False,
+                precompute_template=["median"],
+                mode=mode,
+                **waveforms_params,
+            )
+
+            ## We launch an OMP matching pursuit by full convolution of the templates and the raw traces
+            matching_method = params["matching"]["method"]
+            matching_params = params["matching"]["method_kwargs"].copy()
+            matching_job_params = {}
+            matching_job_params.update(job_kwargs)
+            if matching_method == "wobble":
+                matching_params["templates"] = we.get_all_templates(mode="median")
+                matching_params["nbefore"] = we.nbefore
+                matching_params["nafter"] = we.nafter
+            else:
+                matching_params["waveform_extractor"] = we
+
+            if matching_method == "circus-omp-svd":
+                for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
+                    if value in matching_job_params:
+                        matching_job_params.pop(value)
+                matching_job_params["chunk_duration"] = "100ms"
+
+            spikes = find_spikes_from_templates(
+                recording_f, matching_method, method_kwargs=matching_params, **matching_job_params
+            )
+
+            if params["debug"]:
+                fitting_folder = sorter_output_folder / "fitting"
+                fitting_folder.mkdir(parents=True, exist_ok=True)
+                np.save(fitting_folder / "spikes", spikes)
+
+            if verbose:
+                print("We found %d spikes" % len(spikes))
+
+            ## And this is it! We have a spyking circus
+            sorting = np.zeros(spikes.size, dtype=minimum_spike_dtype)
+            sorting["sample_index"] = spikes["sample_index"]
+            sorting["unit_index"] = spikes["cluster_index"]
+            sorting["segment_index"] = spikes["segment_index"]
+            sorting = NumpySorting(sorting, sampling_frequency, unit_ids)
 
-        if legacy:
-            clustering_method = "circus"
-        else:
-            clustering_method = "random_projections"
-
-        labels, peak_labels = find_cluster_from_peaks(
-            recording_f, selected_peaks, method=clustering_method, method_kwargs=clustering_params
-        )
-
-        ## We get the labels for our peaks
-        mask = peak_labels > -1
-        sorting = NumpySorting.from_times_labels(
-            selected_peaks["sample_index"][mask], peak_labels[mask].astype(int), sampling_rate
-        )
-        clustering_folder = sorter_output_folder / "clustering"
-        clustering_folder.mkdir(parents=True, exist_ok=True)
-
-        if not params["debug"]:
-            shutil.rmtree(clustering_folder)
-        else:
-            np.save(clustering_folder / "labels", labels)
-            np.save(clustering_folder / "peaks", selected_peaks)
-
-        ## We get the templates our of such a clustering
-        waveforms_params = params["waveforms"].copy()
-        waveforms_params.update(job_kwargs)
-
-        for k in ["ms_before", "ms_after"]:
-            waveforms_params[k] = params["general"][k]
-
-        if params["shared_memory"] and not params["debug"]:
-            mode = "memory"
-            waveforms_folder = None
-        else:
-            sorting = sorting.save(folder=clustering_folder)
-            mode = "folder"
-            waveforms_folder = sorter_output_folder / "waveforms"
-
-        we = extract_waveforms(
-            recording_f,
-            sorting,
-            waveforms_folder,
-            return_scaled=False,
-            precompute_template=["median"],
-            mode=mode,
-            **waveforms_params,
-        )
-
-        ## We launch a OMP matching pursuit by full convolution of the templates and the raw traces
-        matching_params = params["matching"].copy()
-        matching_params["waveform_extractor"] = we
-
-        matching_job_params = job_kwargs.copy()
-        for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
-            if value in matching_job_params:
-                matching_job_params.pop(value)
-
-        matching_job_params["chunk_duration"] = "100ms"
-
-        spikes = find_spikes_from_templates(
-            recording_f, method="circus-omp-svd", method_kwargs=matching_params, **matching_job_params
-        )
-
-        if params["debug"]:
-            fitting_folder = sorter_output_folder / "fitting"
-            fitting_folder.mkdir(parents=True, exist_ok=True)
-            np.save(fitting_folder / "spikes", spikes)
-
-        if verbose:
-            print("We found %d spikes" % len(spikes))
-
-        ## And this is it! We have a spyking circus
-        sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["cluster_index"], sampling_rate)
 
         sorting_folder = sorter_output_folder / "sorting"
-
         if sorting_folder.exists():
             shutil.rmtree(sorting_folder)
 
+        folder_to_delete = None
+        cache_mode = params["cache_preprocessing"]["mode"]
+        delete_cache = params["cache_preprocessing"]["delete_cache"]
+        if cache_mode in ["folder", "zarr"] and delete_cache:
+            folder_to_delete = recording_f._kwargs["folder_path"]
+
+        del recording_f
+        if folder_to_delete is not None:
+            shutil.rmtree(folder_to_delete)
+
         sorting = sorting.save(folder=sorting_folder)
 
         return sorting
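Because the matching engine is now read from params["matching"], the same step can also be driven directly through the sortingcomponents API. The sketch below is hedged: recording_f and we stand for the preprocessed recording and the WaveformExtractor produced by the code above, and the job parameters are arbitrary.

# Hedged sketch of calling the matching step directly; `recording_f` and `we`
# are assumed to come from the sorter code shown above.
from spikeinterface.sortingcomponents.matching import find_spikes_from_templates

spikes = find_spikes_from_templates(
    recording_f,
    "circus-omp-svd",
    method_kwargs={"waveform_extractor": we},
    n_jobs=4,
    chunk_duration="100ms",
)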
diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
index 629b0b13ac..838839a29e 100644
--- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
+++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
@@ -593,7 +593,9 @@ def remove_duplicates_via_matching(
 
     local_params = method_kwargs.copy()
 
-    local_params.update({"waveform_extractor": waveform_extractor, "amplitudes": [0.975, 1.025], "omp_min_sps": 1e-3})
+    local_params.update(
+        {"waveform_extractor": waveform_extractor, "amplitudes": [0.975, 1.025], "optimize_amplitudes": False}
+    )
 
     spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()["unit_index"], return_counts=True)
     indices = np.argsort(counts)
diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
index 32782ec627..499ce7869a 100644
--- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py
+++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
@@ -12,6 +12,7 @@
     HAVE_HDBSCAN = False
 
 import random, string, os
+from spikeinterface.core.basesorting import minimum_spike_dtype
 from spikeinterface.core import get_global_tmp_folder, get_channel_distances, get_random_data_chunks
 from sklearn.preprocessing import QuantileTransformer, MaxAbsScaler
 from spikeinterface.core.waveform_tools import extract_waveforms_to_buffers
@@ -70,8 +71,6 @@ def main_function(cls, recording, peaks, params):
         d = params
         verbose = d["job_kwargs"]["verbose"]
 
-        peak_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]
-
         fs = recording.get_sampling_frequency()
         nbefore = int(params["ms_before"] * fs / 1000.0)
         nafter = int(params["ms_after"] * fs / 1000.0)
@@ -99,9 +98,11 @@ def main_function(cls, recording, peaks, params):
 
         node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"])
 
-        projections = np.random.randn(num_chans, d["nb_projections"])
-        projections -= projections.mean(0)
-        projections /= projections.std(0)
+        num_projections = min(num_chans, d["nb_projections"])
+        projections = np.random.randn(num_chans, num_projections)
+        if num_chans > 1:
+            projections -= projections.mean(0)
+            projections /= projections.std(0)
 
         nbefore = int(params["ms_before"] * fs / 1000)
         nafter = int(params["ms_after"] * fs / 1000)
@@ -161,7 +162,7 @@ def main_function(cls, recording, peaks, params):
                 best_spikes[unit_ind] = np.random.permutation(all_indices[mask])[:max_spikes]
             nb_spikes += best_spikes[unit_ind].size
 
-        spikes = np.zeros(nb_spikes, dtype=peak_dtype)
+        spikes = np.zeros(nb_spikes, dtype=minimum_spike_dtype)
 
         mask = np.zeros(0, dtype=np.int32)
         for unit_ind in labels:
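A stand-alone numpy sketch of the projection construction introduced above: the number of random directions is capped by the number of channels, and the directions are standardized across channels only when more than one channel is present. The toy snippet at the end is an assumption used purely to illustrate how such a matrix is applied.

# Numpy-only sketch of the projection-matrix construction from the hunk above.
import numpy as np

num_chans, nb_projections = 4, 10
num_projections = min(num_chans, nb_projections)            # never more projections than channels
projections = np.random.randn(num_chans, num_projections)   # one random direction per column
if num_chans > 1:
    projections -= projections.mean(0)                      # zero mean across channels
    projections /= projections.std(0)                       # unit std across channels

# Illustrative use: project a toy (num_samples, num_chans) snippet onto these directions.
snippet = np.random.randn(30, num_chans)
features = snippet @ projections                            # shape (30, num_projections)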
diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py
index f22c3e3399..81fbb70d80 100644
--- a/src/spikeinterface/sortingcomponents/matching/circus.py
+++ b/src/spikeinterface/sortingcomponents/matching/circus.py
@@ -518,6 +518,7 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine):
         "sparse_kwargs": {"method": "ptp", "threshold": 1},
         "ignored_ids": [],
         "vicinity": 0,
+        "optimize_amplitudes": False,
     }
 
     @classmethod
@@ -555,14 +556,39 @@ def _prepare_templates(cls, d):
         # We reconstruct the approximated templates
         templates = np.matmul(d["temporal"] * d["singular"][:, np.newaxis, :], d["spatial"])
 
-        d["templates"] = {}
+        d["templates"] = np.zeros(templates.shape, dtype=np.float32)
         d["norms"] = np.zeros(num_templates, dtype=np.float32)
 
         # And get the norms, saving compressed templates for CC matrix
         for count in range(num_templates):
             template = templates[count][:, d["sparsity_mask"][count]]
             d["norms"][count] = np.linalg.norm(template)
-            d["templates"][count] = template / d["norms"][count]
+            d["templates"][count][:, d["sparsity_mask"][count]] = template / d["norms"][count]
+
+        if d["optimize_amplitudes"]:
+            noise = np.random.randn(200, d["num_samples"] * d["num_channels"])
+            r = d["templates"].reshape(num_templates, -1).dot(noise.reshape(len(noise), -1).T)
+            s = r / d["norms"][:, np.newaxis]
+            mad = np.median(np.abs(s - np.median(s, 1)[:, np.newaxis]), 1)
+            a_min = np.median(s, 1) + 5 * mad
+
+            means = np.zeros((num_templates, num_templates), dtype=np.float32)
+            stds = np.zeros((num_templates, num_templates), dtype=np.float32)
+            for count, unit_id in enumerate(waveform_extractor.unit_ids):
+                w = waveform_extractor.get_waveforms(unit_id, force_dense=True)
+                r = d["templates"].reshape(num_templates, -1).dot(w.reshape(len(w), -1).T)
+                s = r / d["norms"][:, np.newaxis]
+                means[count] = np.median(s, 1)
+                stds[count] = np.median(np.abs(s - np.median(s, 1)[:, np.newaxis]), 1)
+
+            _, a_max = d["amplitudes"]
+            d["amplitudes"] = np.zeros((num_templates, 2), dtype=np.float32)
+
+            for count in range(num_templates):
+                indices = np.argsort(means[count])
+                a = np.where(indices == count)[0][0]
+                d["amplitudes"][count][1] = 1 + 5 * stds[count, indices[a]]
+                d["amplitudes"][count][0] = max(a_min[count], 1 - 5 * stds[count, indices[a]])
 
         d["temporal"] /= d["norms"][:, np.newaxis, np.newaxis]
         d["temporal"] = np.flip(d["temporal"], axis=1)
@@ -663,7 +689,12 @@ def main_function(cls, traces, d):
         omp_tol = np.finfo(np.float32).eps
         num_samples = d["nafter"] + d["nbefore"]
         neighbor_window = num_samples - 1
-        min_amplitude, max_amplitude = d["amplitudes"]
+        if d["optimize_amplitudes"]:
+            min_amplitude, max_amplitude = d["amplitudes"][:, 0], d["amplitudes"][:, 1]
+            min_amplitude = min_amplitude[:, np.newaxis]
+            max_amplitude = max_amplitude[:, np.newaxis]
+        else:
+            min_amplitude, max_amplitude = d["amplitudes"]
         ignored_ids = d["ignored_ids"]
         vicinity = d["vicinity"]
         rank = d["rank"]
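A worked numpy sketch of the noise-based amplitude floor computed when optimize_amplitudes is enabled: scalar products between unit-norm templates and pure-noise snippets are summarized by their median and MAD, and the floor sits five MADs above the median. Shapes are toy values; the per-unit upper and lower bounds in the patch follow the same median/MAD recipe on the unit waveforms.

# Toy illustration of the amplitude floor used above (numpy only; shapes are assumptions).
import numpy as np

num_templates, num_samples, num_channels = 3, 60, 4
templates = np.random.randn(num_templates, num_samples, num_channels).astype(np.float32)
norms = np.linalg.norm(templates.reshape(num_templates, -1), axis=1)
templates /= norms[:, np.newaxis, np.newaxis]  # unit-norm templates, as in _prepare_templates

noise = np.random.randn(200, num_samples * num_channels)
s = templates.reshape(num_templates, -1) @ noise.T / norms[:, np.newaxis]
mad = np.median(np.abs(s - np.median(s, 1)[:, np.newaxis]), 1)
a_min = np.median(s, 1) + 5 * mad  # smallest amplitude considered distinguishable from noise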
diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index 328e3b715d..0a73ee9a81 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -1,7 +1,15 @@
 import numpy as np
 
+try:
+    import psutil
+
+    HAVE_PSUTIL = True
+except:
+    HAVE_PSUTIL = False
+
 from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractSparseWaveforms, PeakRetriever
 from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer
+from spikeinterface.core.job_tools import split_job_kwargs
 
 
 def make_multi_method_doc(methods, ident="    "):
@@ -69,3 +77,24 @@ def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0
     )
     prototype = np.median(waveforms[:, :, 0] / (waveforms[:, nbefore, 0][:, np.newaxis]), axis=0)
     return prototype
+
+
+def cache_preprocessing(recording, mode="memory", memory_limit=0.5, delete_cache=True, **extra_kwargs):
+    save_kwargs, job_kwargs = split_job_kwargs(extra_kwargs)
+
+    if mode == "memory":
+        if HAVE_PSUTIL:
+            assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
+            memory_usage = memory_limit * psutil.virtual_memory()[4]
+            if recording.get_total_memory_size() < memory_usage:
+                recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs)
+            else:
+                print("Recording too large to be preloaded in RAM...")
+        else:
+            print("psutil is required to preload in memory")
+    elif mode == "folder":
+        recording = recording.save_to_folder(**extra_kwargs)
+    elif mode == "zarr":
+        recording = recording.save_to_zarr(**extra_kwargs)
+
+    return recording
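Finally, a hedged usage sketch of the new cache_preprocessing helper on a toy recording. psutil is assumed to be installed so that the memory mode can check the available RAM; otherwise a message is printed and the recording is returned unchanged.

# Illustrative call of the new helper on a toy generated recording.
from spikeinterface.core import generate_recording
from spikeinterface.sortingcomponents.tools import cache_preprocessing

recording = generate_recording(num_channels=4, durations=[10.0])
cached = cache_preprocessing(recording, mode="memory", memory_limit=0.5)
print(cached)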