Commit

Merge pull request #2267 from samuelgarcia/tridesclous2
Tridesclous2 update
alejoe91 authored Feb 6, 2024
2 parents 6bb7e71 + ecb7e3e commit ae03264
Showing 3 changed files with 141 additions and 54 deletions.
3 changes: 2 additions & 1 deletion src/spikeinterface/comparison/groundtruthstudy.py
@@ -23,7 +23,8 @@


# This separator is used to join tuple keys into names when saving folders
_key_separator = "_##_"
# _key_separator = "_##_"
_key_separator = "_-°°-_"


class GroundTruthStudy:
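The separator exists so that tuple case keys can be flattened into a single folder name and recovered later; the new "_-°°-_" token is far less likely to collide with user-chosen key strings than "_##_". A minimal round-trip sketch (key_to_name and name_to_key are hypothetical helpers for illustration, not GroundTruthStudy's actual API):

# Hypothetical helpers illustrating the separator's round-trip role.
_key_separator = "_-°°-_"

def key_to_name(key):
    # A tuple key such as ("tdc2", "probe1") becomes one folder name.
    if isinstance(key, tuple):
        return _key_separator.join(str(k) for k in key)
    return str(key)

def name_to_key(name):
    # Reverse the flattening; a name without the separator is a plain key.
    parts = name.split(_key_separator)
    return tuple(parts) if len(parts) > 1 else name

assert name_to_key(key_to_name(("tdc2", "probe1"))) == ("tdc2", "probe1")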
121 changes: 75 additions & 46 deletions src/spikeinterface/sorters/internal/tridesclous2.py
@@ -1,11 +1,10 @@
from __future__ import annotations

import shutil

from .si_based import ComponentsBasedSorter

from spikeinterface.core import (
load_extractor,
BaseRecording,
get_noise_levels,
extract_waveforms,
NumpySorting,
@@ -14,10 +13,12 @@

from spikeinterface.core.job_tools import fix_job_kwargs

from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore
from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore, whiten
from spikeinterface.core.basesorting import minimum_spike_dtype

from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel
from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel, cache_preprocessing

# from spikeinterface.qualitymetrics import compute_snrs

import numpy as np

@@ -30,6 +31,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter):

_default_params = {
"apply_preprocessing": True,
"cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
"waveforms": {
"ms_before": 0.5,
"ms_after": 1.5,
@@ -38,22 +40,39 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
"filtering": {"freq_min": 300.0, "freq_max": 12000.0},
"detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.0},
"selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000},
"features": {},
"svd": {"n_components": 6},
"clustering": {
"split_radius_um": 40.0,
"merge_radius_um": 40.0,
"threshold_diff": 1.5,
},
"templates": {
"ms_before": 1.5,
"ms_after": 2.5,
"ms_before": 2.0,
"ms_after": 3.0,
"max_spikes_per_unit": 400,
# "peak_shift_ms": 0.2,
},
"matching": {"peak_shift_ms": 0.2, "radius_um": 100.0},
# "matching": {"method": "tridesclous", "method_kwargs": {"peak_shift_ms": 0.2, "radius_um": 100.0}},
"matching": {"method": "circus-omp-svd", "method_kwargs": {}},
"job_kwargs": {"n_jobs": -1},
"save_array": True,
}

_params_description = {
"apply_preprocessing": "Apply internal preprocessing or not",
"cache_preprocessing": "A dict contaning how to cache the preprocessed recording. mode='memory' | 'folder | 'zarr' ",
"waveforms": "A dictonary containing waveforms params: ms_before, ms_after, radius_um",
"filtering": "A dictonary containing filtering params: freq_min, freq_max",
"detection": "A dictonary containing detection params: peak_sign, detect_threshold, exclude_sweep_ms, radius_um",
"selection": "A dictonary containing selection params: n_peaks_per_channel, min_n_peaks",
"svd": "A dictonary containing svd params: n_components",
"clustering": "A dictonary containing clustering params: split_radius_um, merge_radius_um",
"templates": "A dictonary containing waveforms params for peeler: ms_before, ms_after",
"matching": "A dictonary containing matching params for matching: peak_shift_ms, radius_um",
"job_kwargs": "A dictionary containing job kwargs",
"save_array": "Save or not intermediate arrays",
}

handle_multi_segment = True

@classmethod
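These defaults can be overridden per call through run_sorter; a minimal sketch (the recording variable, output folder, and chosen values are illustrative assumptions, and a nested dict replaces the corresponding default dict wholesale):

# Sketch: overriding selected Tridesclous2 defaults at call time.
from spikeinterface.sorters import run_sorter

sorting = run_sorter(
    "tridesclous2",
    recording,  # assumed: an existing BaseRecording
    output_folder="tdc2_output",
    matching={"method": "circus-omp-svd", "method_kwargs": {}},
    clustering={"split_radius_um": 40.0, "merge_radius_um": 40.0, "threshold_diff": 1.5},
)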
@@ -97,6 +116,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
# TODO: which is better, zscore then common_reference or the reverse?
recording = common_reference(recording)
recording = zscore(recording, dtype="float32")
# recording = whiten(recording, dtype="float32")

# used only if "folder" or "zarr"
cache_folder = sorter_output_folder / "cache_preprocessing"
recording = cache_preprocessing(
recording, folder=cache_folder, **job_kwargs, **params["cache_preprocessing"]
)

noise_levels = np.ones(num_chans, dtype="float32")
else:
recording = recording_raw
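cache_preprocessing materializes the lazy preprocessing chain once so the later passes (feature extraction, matching) do not refilter the traces. A simplified illustration of the dispatch implied by the defaults above; this is a sketch of the idea, not spikeinterface's exact implementation:

import psutil

def cache_preprocessing_sketch(recording, mode="memory", memory_limit=0.5,
                               delete_cache=True, folder=None, **job_kwargs):
    # delete_cache is assumed to control cleanup of on-disk caches afterwards.
    if mode == "memory":
        # Materialize in RAM only if the traces fit under the budget,
        # a fraction of currently available memory.
        budget = psutil.virtual_memory().available * memory_limit
        if recording.get_total_memory_size() < budget:
            recording = recording.save(format="memory", **job_kwargs)
    elif mode == "folder":
        recording = recording.save(format="binary", folder=folder, **job_kwargs)
    elif mode == "zarr":
        recording = recording.save(format="zarr", folder=folder, **job_kwargs)
    # any other mode: keep the lazy chain as-is
    return recording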
@@ -151,22 +178,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
features_folder = sorter_output_folder / "features"
node0 = PeakRetriever(recording, peaks)

# node1 = ExtractDenseWaveforms(rec, parents=[node0], return_output=False,
# ms_before=0.5,
# ms_after=1.5,
# )

# node2 = LocalizeCenterOfMass(rec, parents=[node0, node1], return_output=True,
# local_radius_um=75.0,
# feature="ptp", )

# node2 = LocalizeGridConvolution(rec, parents=[node0, node1], return_output=True,
# local_radius_um=40.,
# upsampling_um=5.0,
# )

radius_um = params["waveforms"]["radius_um"]
node3 = ExtractSparseWaveforms(
node1 = ExtractSparseWaveforms(
recording,
parents=[node0],
return_output=True,
@@ -177,12 +190,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

model_folder_path = sorter_output_folder / "tsvd_model"

node4 = TemporalPCAProjection(
recording, parents=[node0, node3], return_output=True, model_folder_path=model_folder_path
node2 = TemporalPCAProjection(
recording, parents=[node0, node1], return_output=True, model_folder_path=model_folder_path
)

# pipeline_nodes = [node0, node1, node2, node3, node4]
pipeline_nodes = [node0, node3, node4]
pipeline_nodes = [node0, node1, node2]

output = run_node_pipeline(
recording,
@@ -195,7 +207,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
)

# TODO make this generic in GatherNPY ???
sparse_mask = node3.neighbours_mask
sparse_mask = node1.neighbours_mask
np.save(features_folder / "sparse_mask.npy", sparse_mask)
np.save(features_folder / "peaks.npy", peaks)

@@ -231,6 +243,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
)

merge_radius_um = params["clustering"]["merge_radius_um"]
threshold_diff = params["clustering"]["threshold_diff"]

post_merge_label, peak_shifts = merge_clusters(
peaks,
@@ -251,7 +264,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
method="normalized_template_diff",
method_kwargs=dict(
waveforms_sparse_mask=sparse_mask,
threshold_diff=0.2,
threshold_diff=threshold_diff,
min_cluster_size=min_cluster_size + 1,
num_shift=5,
),
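Note that the hard-coded threshold_diff=0.2 becomes the configurable clustering parameter (default 1.5). The scale changes because normalized_template_diff in merge.py below now scores a weighted per-channel maximum of the normalized difference rather than its mean, so old and new thresholds are not directly comparable.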
@@ -284,29 +297,45 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
)
sorting_temp = sorting_temp.save(folder=sorter_output_folder / "sorting_temp")

ms_before = params["templates"]["ms_before"]
ms_after = params["templates"]["ms_after"]
max_spikes_per_unit = 300
we = extract_waveforms(recording, sorting_temp, sorter_output_folder / "waveforms_temp", **params["templates"])

we = extract_waveforms(
recording,
sorting_temp,
sorter_output_folder / "waveforms_temp",
ms_before=ms_before,
ms_after=ms_after,
max_spikes_per_unit=max_spikes_per_unit,
**job_kwargs,
)
# snrs = compute_snrs(we, peak_sign=params["detection"]["peak_sign"], peak_mode="extremum")
# print(snrs)

# matching_params = params["matching"].copy()
# matching_params["waveform_extractor"] = we
# matching_params["noise_levels"] = noise_levels
# matching_params["peak_sign"] = params["detection"]["peak_sign"]
# matching_params["detect_threshold"] = params["detection"]["detect_threshold"]
# matching_params["radius_um"] = params["detection"]["radius_um"]

# spikes = find_spikes_from_templates(
# recording, method="tridesclous", method_kwargs=matching_params, **job_kwargs
# )

matching_method = params["matching"]["method"]
matching_params = params["matching"]["method_kwargs"].copy()

matching_params = params["matching"].copy()
matching_params["waveform_extractor"] = we
matching_params["noise_levels"] = noise_levels
matching_params["peak_sign"] = params["detection"]["peak_sign"]
matching_params["detect_threshold"] = params["detection"]["detect_threshold"]
matching_params["radius_um"] = params["detection"]["radius_um"]
# matching_params["peak_sign"] = params["detection"]["peak_sign"]
# matching_params["detect_threshold"] = params["detection"]["detect_threshold"]
# matching_params["radius_um"] = params["detection"]["radius_um"]

# spikes = find_spikes_from_templates(
# recording, method="tridesclous", method_kwargs=matching_params, **job_kwargs
# )
# )

if matching_method == "circus-omp-svd":
job_kwargs = job_kwargs.copy()
for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in job_kwargs:
job_kwargs.pop(value)
job_kwargs["chunk_duration"] = "100ms"

spikes = find_spikes_from_templates(
recording, method="tridesclous", method_kwargs=matching_params, **job_kwargs
recording, method=matching_method, method_kwargs=matching_params, **job_kwargs
)

if params["save_array"]:
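The chunking override above matters because the OMP-based engine is meant to run on short chunks; here is the same sanitization as a standalone helper, mirroring the diff (a sketch, not library code):

def sanitize_job_kwargs_for_omp(job_kwargs):
    # Drop any user-supplied chunking keys, then force short chunks,
    # exactly as the circus-omp-svd branch above does.
    job_kwargs = dict(job_kwargs)
    for key in ("chunk_size", "chunk_memory", "total_memory", "chunk_duration"):
        job_kwargs.pop(key, None)
    job_kwargs["chunk_duration"] = "100ms"
    return job_kwargs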
71 changes: 64 additions & 7 deletions src/spikeinterface/sortingcomponents/clustering/merge.py
@@ -617,7 +617,7 @@ def merge(
peaks,
features,
waveforms_sparse_mask=None,
threshold_diff=0.05,
threshold_diff=1.5,
min_cluster_size=50,
num_shift=5,
):
@@ -649,39 +649,96 @@
num_samples = template0.shape[0]
# norm = np.mean(np.abs(template0)) + np.mean(np.abs(template1))
norm = np.mean(np.abs(template0) + np.abs(template1))

# norm_per_channel = np.max(np.abs(template0) + np.abs(template1), axis=0) / 2.
norm_per_channel = (np.max(np.abs(template0), axis=0) + np.max(np.abs(template1), axis=0)) * 0.5
# norm_per_channel = np.max(np.abs(template0)) + np.max(np.abs(template1)) / 2.
# print(norm_per_channel)

all_shift_diff = []
# all_shift_diff_by_channel = []
for shift in range(-num_shift, num_shift + 1):
temp0 = template0[num_shift : num_samples - num_shift, :]
temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :]
d = np.mean(np.abs(temp0 - temp1)) / (norm)
all_shift_diff.append(d)
# d = np.mean(np.abs(temp0 - temp1)) / (norm)
# d = np.max(np.abs(temp0 - temp1)) / (norm)
diff_per_channel = np.abs(temp0 - temp1) / norm

diff_max = np.max(diff_per_channel, axis=0)

# diff = np.max(diff_per_channel)
diff = np.average(diff_max, weights=norm_per_channel)
# diff = np.average(diff_max)
all_shift_diff.append(diff)
# diff_by_channel = np.mean(np.abs(temp0 - temp1), axis=0) / (norm)
# all_shift_diff_by_channel.append(diff_by_channel)
# d = np.mean(diff_by_channel)
# all_shift_diff.append(d)
normed_diff = np.min(all_shift_diff)

is_merge = normed_diff < threshold_diff

if is_merge:
merge_value = normed_diff
final_shift = np.argmin(all_shift_diff) - num_shift

# diff_by_channel = all_shift_diff_by_channel[np.argmin(all_shift_diff)]
else:
final_shift = 0
merge_value = np.nan

if DEBUG and normed_diff < 0.2:
# print('merge_value', merge_value, 'final_shift', final_shift, 'is_merge', is_merge)

DEBUG = False
# DEBUG = True
# if DEBUG and ( 0. < normed_diff < .4):
# if 0.5 < normed_diff < 4:
if DEBUG and is_merge:
# if DEBUG:

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
fig, axs = plt.subplots(nrows=3)

m0 = template0.flatten()
m1 = template1.flatten()
temp0 = template0[num_shift : num_samples - num_shift, :]
temp1 = template1[num_shift + final_shift : num_samples - num_shift + final_shift, :]

diff_per_channel = np.abs(temp0 - temp1) / norm
diff = np.max(diff_per_channel)

m0 = temp0.T.flatten()
m1 = temp1.T.flatten()

ax = axs[0]
ax.plot(m0, color="C0", label=f"{label0} {inds0.size}")
ax.plot(m1, color="C1", label=f"{label1} {inds1.size}")

ax.set_title(
f"union{union_chans.size} intersect{target_chans.size} \n {normed_diff:.3f} {final_shift} {is_merge}"
)
ax.legend()

ax = axs[1]

# ~ temp0 = template0[num_shift : num_samples - num_shift, :]
# ~ temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :]
ax.plot(np.abs(m0 - m1))
# ax.axhline(norm, ls='--', color='k')
ax = axs[2]
ax.plot(diff_per_channel.T.flatten())
ax.axhline(threshold_diff, ls="--")
ax.axhline(normed_diff)

# ax.axhline(normed_diff, ls='-', color='b')
# ax.plot(norm, ls='--')
# ax.plot(diff_by_channel)

# ax.plot(np.abs(m0) + np.abs(m1))

# ax.plot(np.abs(m0 - m1) / (np.abs(m0) + np.abs(m1)))

# ax.set_title(f"{norm=:.3f}")

plt.show()

return is_merge, label0, label1, final_shift, merge_value
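For reference, the new merge criterion can be read as: normalize the per-sample, per-channel template difference by the mean summed amplitude, reduce each channel to its maximum over time, average those maxima weighted by per-channel peak amplitude, and take the best score over all shifts; the pair is merged when that score falls below threshold_diff. A self-contained numpy sketch of the same computation (toy templates, illustrative values):

import numpy as np

def normalized_template_diff_sketch(template0, template1, threshold_diff=1.5, num_shift=5):
    # Templates are dense (num_samples, num_channels) arrays on shared channels.
    num_samples = template0.shape[0]
    norm = np.mean(np.abs(template0) + np.abs(template1))
    # Per-channel weights: average of the two per-channel peak amplitudes.
    norm_per_channel = (np.max(np.abs(template0), axis=0)
                        + np.max(np.abs(template1), axis=0)) * 0.5
    all_shift_diff = []
    for shift in range(-num_shift, num_shift + 1):
        temp0 = template0[num_shift:num_samples - num_shift, :]
        temp1 = template1[num_shift + shift:num_samples - num_shift + shift, :]
        diff_per_channel = np.abs(temp0 - temp1) / norm
        diff_max = np.max(diff_per_channel, axis=0)
        # Channels with larger templates weigh more in the score.
        all_shift_diff.append(np.average(diff_max, weights=norm_per_channel))
    normed_diff = float(np.min(all_shift_diff))
    final_shift = int(np.argmin(all_shift_diff)) - num_shift
    return normed_diff < threshold_diff, final_shift, normed_diff

# Toy check: a template against a slightly shifted copy of itself should merge.
rng = np.random.default_rng(0)
t0 = rng.normal(size=(40, 4))
print(normalized_template_diff_sketch(t0, np.roll(t0, 1, axis=0)))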
