Merge pull request #2860 from samuelgarcia/tridesclous2
Update tridesclous2
samuelgarcia authored May 24, 2024
2 parents 95d2917 + 023baba commit 5d69ec0
Showing 7 changed files with 150 additions and 82 deletions.
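
The diff below updates the default parameters of the tridesclous2 internal sorter (narrower band-pass, lower template sparsity threshold, local whitening) and wires in optional motion correction. A minimal usage sketch, assuming the generic run_sorter entry point of spikeinterface; `recording` and the output folder are placeholders, and the output-folder keyword may differ between releases:

    from spikeinterface.sorters import run_sorter

    sorting = run_sorter(
        "tridesclous2",
        recording,                       # any BaseRecording to be sorted (placeholder)
        folder="tdc2_output",            # may be output_folder on older releases
        apply_preprocessing=True,
        apply_motion_correction=False,   # set True to enable the motion-correction path added here
        filtering={"freq_min": 300.0, "freq_max": 8000.0},  # new default band from this commit
    )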
137 changes: 64 additions & 73 deletions src/spikeinterface/sorters/internal/tridesclous2.py
@@ -7,7 +7,6 @@
from spikeinterface.core import (
get_noise_levels,
NumpySorting,
get_channel_distances,
estimate_templates_with_accumulator,
Templates,
compute_sparsity,
@@ -18,29 +17,26 @@
from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore, whiten
from spikeinterface.core.basesorting import minimum_spike_dtype

from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel, cache_preprocessing
from spikeinterface.sortingcomponents.tools import cache_preprocessing

# from spikeinterface.qualitymetrics import compute_snrs

import numpy as np

import pickle
import json


class Tridesclous2Sorter(ComponentsBasedSorter):
sorter_name = "tridesclous2"

_default_params = {
"apply_preprocessing": True,
"apply_motion_correction": False,
"motion_correction": {"preset": "nonrigid_fast_and_accurate"},
"cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
"waveforms": {
"ms_before": 0.5,
"ms_after": 1.5,
"radius_um": 120.0,
},
"filtering": {"freq_min": 300.0, "freq_max": 12000.0},
"filtering": {"freq_min": 300.0, "freq_max": 8000.0},
"detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.0},
"selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000},
"svd": {"n_components": 6},
@@ -53,7 +49,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
"ms_before": 2.0,
"ms_after": 3.0,
"max_spikes_per_unit": 400,
"sparsity_threshold": 2.0,
"sparsity_threshold": 1.5,
# "peak_shift_ms": 0.2,
},
# "matching": {"method": "tridesclous", "method_kwargs": {"peak_shift_ms": 0.2, "radius_um": 100.0}},
@@ -86,31 +82,18 @@ def get_sorter_version(cls):

@classmethod
def _run_from_folder(cls, sorter_output_folder, params, verbose):
job_kwargs = params["job_kwargs"].copy()
job_kwargs = fix_job_kwargs(job_kwargs)
job_kwargs["progress_bar"] = verbose

from spikeinterface.sortingcomponents.matching import find_spikes_from_templates
from spikeinterface.core.node_pipeline import (
run_node_pipeline,
ExtractDenseWaveforms,
ExtractSparseWaveforms,
PeakRetriever,
)
from spikeinterface.sortingcomponents.peak_detection import detect_peaks, DetectPeakLocallyExclusive
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_selection import select_peaks
from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass, LocalizeGridConvolution
from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection

from spikeinterface.sortingcomponents.clustering.split import split_clusters
from spikeinterface.sortingcomponents.clustering.merge import merge_clusters
from spikeinterface.sortingcomponents.clustering.tools import compute_template_from_sparse
from spikeinterface.sortingcomponents.clustering.main import find_cluster_from_peaks
from spikeinterface.sortingcomponents.tools import remove_empty_templates
from spikeinterface.preprocessing import correct_motion
from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording

from sklearn.decomposition import TruncatedSVD

import hdbscan
job_kwargs = params["job_kwargs"].copy()
job_kwargs = fix_job_kwargs(job_kwargs)
job_kwargs["progress_bar"] = verbose

recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)

@@ -119,10 +102,44 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

# preprocessing
if params["apply_preprocessing"]:
recording = bandpass_filter(recording_raw, **params["filtering"])
if params["apply_motion_correction"]:
rec_for_motion = recording_raw
if params["apply_preprocessing"]:
rec_for_motion = bandpass_filter(rec_for_motion, freq_min=300.0, freq_max=6000.0, dtype="float32")
rec_for_motion = common_reference(rec_for_motion)
if verbose:
print("Start correct_motion()")
_, motion_info = correct_motion(
rec_for_motion,
folder=sorter_output_folder / "motion",
output_motion_info=True,
**params["motion_correction"],
)
if verbose:
print("Done correct_motion()")

recording = bandpass_filter(recording_raw, **params["filtering"], dtype="float32")
recording = common_reference(recording)

if params["apply_motion_correction"]:
interpolate_motion_kwargs = dict(
direction=1,
border_mode="force_extrapolate",
spatial_interpolation_method="kriging",
sigma_um=20.0,
p=2,
)

recording = InterpolateMotionRecording(
recording,
motion_info["motion"],
motion_info["temporal_bins"],
motion_info["spatial_bins"],
**interpolate_motion_kwargs,
)

recording = zscore(recording, dtype="float32")
recording = whiten(recording, dtype="float32")
recording = whiten(recording, dtype="float32", mode="local", radius_um=100.0)

# used only if "folder" or "zarr"
cache_folder = sorter_output_folder / "cache_preprocessing"
@@ -141,7 +158,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
all_peaks = detect_peaks(recording, method="locally_exclusive", **detection_params, **job_kwargs)

if verbose:
print("We found %d peaks in total" % len(all_peaks))
print(f"detect_peaks(): {len(all_peaks)} peaks found")

# selection
selection_params = params["selection"].copy()
@@ -150,36 +167,38 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
peaks = select_peaks(all_peaks, method="uniform", n_peaks=n_peaks)

if verbose:
print("We kept %d peaks for clustering" % len(peaks))
print(f"select_peaks(): {len(peaks)} peaks kept for clustering")

clustering_kwargs = {}
clustering_kwargs["folder"] = sorter_output_folder
clustering_kwargs["waveforms"] = params["waveforms"].copy()
clustering_kwargs["clustering"] = params["clustering"].copy()

labels_set, post_clean_label, extra_out = find_cluster_from_peaks(
labels_set, clustering_label, extra_out = find_cluster_from_peaks(
recording, peaks, method="tdc_clustering", method_kwargs=clustering_kwargs, extra_outputs=True, **job_kwargs
)
peak_shifts = extra_out["peak_shifts"]
new_peaks = peaks.copy()
new_peaks["sample_index"] -= peak_shifts

mask = post_clean_label >= 0
mask = clustering_label >= 0
sorting_pre_peeler = NumpySorting.from_times_labels(
new_peaks["sample_index"][mask],
post_clean_label[mask],
clustering_label[mask],
sampling_frequency,
unit_ids=labels_set,
)
# sorting_pre_peeler = sorting_pre_peeler.save(folder=sorter_output_folder / "sorting_pre_peeler")

recording_w = whiten(recording, mode="local", radius_um=100.0)
if verbose:
print(f"find_cluster_from_peaks(): {sorting_pre_peeler.unit_ids.size} cluster found")

recording_for_peeler = recording

nbefore = int(params["templates"]["ms_before"] * sampling_frequency / 1000.0)
nafter = int(params["templates"]["ms_after"] * sampling_frequency / 1000.0)
sparsity_threshold = params["templates"]["sparsity_threshold"]

templates_array = estimate_templates_with_accumulator(
recording_w,
recording_for_peeler,
sorting_pre_peeler.to_spike_vector(),
sorting_pre_peeler.unit_ids,
nbefore,
@@ -192,61 +211,33 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
sampling_frequency=sampling_frequency,
nbefore=nbefore,
sparsity_mask=None,
probe=recording_w.get_probe(),
probe=recording_for_peeler.get_probe(),
is_scaled=False,
)

# TODO : try other methods for sparsity
sparsity_threshold = params["templates"]["sparsity_threshold"]
# sparsity = compute_sparsity(templates_dense, method="radius", radius_um=120.)
sparsity = compute_sparsity(templates_dense, noise_levels=noise_levels, threshold=sparsity_threshold)
templates = templates_dense.to_sparse(sparsity)
templates = remove_empty_templates(templates)

# snrs = compute_snrs(we, peak_sign=params["detection"]["peak_sign"], peak_mode="extremum")
# print(snrs)

# matching_params = params["matching"].copy()
# matching_params["noise_levels"] = noise_levels
# matching_params["peak_sign"] = params["detection"]["peak_sign"]
# matching_params["detect_threshold"] = params["detection"]["detect_threshold"]
# matching_params["radius_um"] = params["detection"]["radius_um"]

# spikes = find_spikes_from_templates(
# recording, method="tridesclous", method_kwargs=matching_params, **job_kwargs
# )

## peeler
matching_method = params["matching"]["method"]
matching_params = params["matching"]["method_kwargs"].copy()

matching_params["templates"] = templates
matching_params["noise_levels"] = noise_levels
# matching_params["peak_sign"] = params["detection"]["peak_sign"]
# matching_params["detect_threshold"] = params["detection"]["detect_threshold"]
# matching_params["radius_um"] = params["detection"]["radius_um"]

# spikes = find_spikes_from_templates(
# recording, method="tridesclous", method_kwargs=matching_params, **job_kwargs
# )
# )

# if matching_method == "circus-omp-svd":
# job_kwargs = job_kwargs.copy()
# for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
# if value in job_kwargs:
# job_kwargs.pop(value)
# job_kwargs["chunk_duration"] = "100ms"

spikes = find_spikes_from_templates(
recording_w, method=matching_method, method_kwargs=matching_params, **job_kwargs
recording_for_peeler, method=matching_method, method_kwargs=matching_params, **job_kwargs
)

if params["save_array"]:
sorting_pre_peeler = sorting_pre_peeler.save(folder=sorter_output_folder / "sorting_pre_peeler")

np.save(sorter_output_folder / "noise_levels.npy", noise_levels)
np.save(sorter_output_folder / "all_peaks.npy", all_peaks)
# np.save(sorter_output_folder / "post_split_label.npy", post_split_label)
# np.save(sorter_output_folder / "split_count.npy", split_count)
# np.save(sorter_output_folder / "post_merge_label.npy", post_merge_label)
np.save(sorter_output_folder / "peaks.npy", peaks)
np.save(sorter_output_folder / "clustering_label.npy", clustering_label)
np.save(sorter_output_folder / "spikes.npy", spikes)

final_spikes = np.zeros(spikes.size, dtype=minimum_spike_dtype)
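For reference, the motion-correction path added above can be reproduced outside the sorter. A hedged standalone sketch that mirrors the parameter values in the diff, assuming `recording_raw` is the raw BaseRecording and writing motion estimates to a placeholder folder:

    from spikeinterface.preprocessing import bandpass_filter, common_reference, correct_motion
    from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording

    # Estimate drift on a band-passed, common-referenced copy of the raw data.
    rec_for_motion = bandpass_filter(recording_raw, freq_min=300.0, freq_max=6000.0, dtype="float32")
    rec_for_motion = common_reference(rec_for_motion)
    _, motion_info = correct_motion(
        rec_for_motion,
        folder="motion",                      # placeholder output folder
        output_motion_info=True,
        preset="nonrigid_fast_and_accurate",  # default preset from this commit
    )

    # Apply the estimated motion to the recording that will actually be sorted.
    recording = common_reference(bandpass_filter(recording_raw, freq_min=300.0, freq_max=8000.0, dtype="float32"))
    recording = InterpolateMotionRecording(
        recording,
        motion_info["motion"],
        motion_info["temporal_bins"],
        motion_info["spatial_bins"],
        direction=1,
        border_mode="force_extrapolate",
        spatial_interpolation_method="kriging",
        sigma_um=20.0,
        p=2,
    )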
@@ -188,6 +188,8 @@ def plot_agreements(self, case_keys=None, figsize=(15, 15)):
ax.set_title(self.cases[key]["label"])
plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax)

return fig

def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)):
if case_keys is None:
case_keys = list(self.cases.keys())
@@ -210,6 +212,8 @@ def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)):
if count == 2:
ax.legend()

return fig

def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)):

if case_keys is None:
@@ -244,6 +248,8 @@ def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)):
label = self.cases[key]["label"]
axs[0, count].set_title(label)

return fig

def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5)):

if case_keys is None:
@@ -296,6 +302,8 @@ def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5)):
axs[0, count].set_title(label)
axs[0, count].legend()

return fig

def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figsize=(15, 5)):

if case_keys is None:
@@ -354,6 +362,8 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figsize=(15, 5)):
axs[0, count].set_title(label)
# axs[0, count].legend()

return fig

def plot_unit_losses(self, case_before, case_after, metric="agreement", figsize=None):

fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize)
@@ -384,6 +394,7 @@ def plot_unit_losses(self, case_before, case_after, metric="agreement", figsize=None):
fig.colorbar(im, ax=ax)
ax.set_title(k)
ax.set_ylabel("snr")
return fig

def plot_comparison_clustering(
self,
@@ -444,10 +455,13 @@ def plot_comparison_clustering(

plt.tight_layout(h_pad=0, w_pad=0)

return fig

def plot_some_over_merged(self, case_keys=None, overmerged_score=0.05, max_units=5, figsize=None):
if case_keys is None:
case_keys = list(self.cases.keys())

figs = []
for count, key in enumerate(case_keys):
label = self.cases[key]["label"]
comp = self.get_result(key)["gt_comparison"]
@@ -475,13 +489,17 @@ def plot_some_over_merged(self, case_keys=None, overmerged_score=0.05, max_units=5, figsize=None):
ax.set_xticks([])

fig.suptitle(label)
figs.append(fig)
else:
print(key, "no overmerged")

return figs

def plot_some_over_splited(self, case_keys=None, oversplit_score=0.05, max_units=5, figsize=None):
if case_keys is None:
case_keys = list(self.cases.keys())

figs = []
for count, key in enumerate(case_keys):
label = self.cases[key]["label"]
comp = self.get_result(key)["gt_comparison"]
@@ -509,5 +527,8 @@ def plot_some_over_splited(self, case_keys=None, oversplit_score=0.05, max_units=5, figsize=None):
ax.set_xticks([])

fig.suptitle(label)
figs.append(fig)
else:
print(key, "no over splited")

return figs
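
The changes in this second file make each plotting helper return its Matplotlib figure (or a list of figures), so callers can save or post-process them. A short hedged usage sketch, assuming `study` is an instance of the benchmark class edited here (its name is outside this excerpt):

    fig = study.plot_performances_vs_snr()
    fig.savefig("performances_vs_snr.png", dpi=150)

    figs = study.plot_some_over_merged(overmerged_score=0.05, max_units=5)
    for i, f in enumerate(figs):
        f.savefig(f"overmerged_{i}.png", dpi=150)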