Benchmark sorting components + Tridesclous2 improvement #2811

Merged · 18 commits · May 13, 2024
3 changes: 3 additions & 0 deletions src/spikeinterface/comparison/groundtruthstudy.py
@@ -167,6 +167,8 @@ def remove_sorting(self, key):
sorting_folder = self.folder / "sortings" / self.key_to_str(key)
log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json"
comparison_file = self.folder / "comparisons" / self.key_to_str(key)
self.sortings[key] = None
self.comparisons[key] = None
if sorting_folder.exists():
shutil.rmtree(sorting_folder)
for f in (log_file, comparison_file):
@@ -381,6 +383,7 @@ def get_performance_by_unit(self, case_keys=None):

perf_by_unit = pd.concat(perf_by_unit)
perf_by_unit = perf_by_unit.set_index(self.levels)
perf_by_unit = perf_by_unit.sort_index()
return perf_by_unit

def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None):
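Side note on the `sort_index()` addition above: lexically sorting the MultiIndex keeps later partial-label selections fast and warning-free. A minimal, standalone pandas sketch (the level names here are hypothetical, not the study's actual levels):

```python
import pandas as pd

# Hypothetical two-level case keys standing in for a study's levels
index = pd.MultiIndex.from_tuples(
    [("tdc2", "rec1"), ("sc2", "rec0"), ("tdc2", "rec0")],
    names=["sorter", "dataset"],
)
perf_by_unit = pd.DataFrame({"accuracy": [0.9, 0.7, 0.8]}, index=index)

# On an unsorted MultiIndex, label-based slicing can be slow or even raise
# UnsortedIndexError; sort_index() makes the index lexsorted and lookups stable.
perf_by_unit = perf_by_unit.sort_index()
print(perf_by_unit.loc["tdc2"])
```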
2 changes: 1 addition & 1 deletion src/spikeinterface/core/node_pipeline.py
@@ -255,7 +255,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
return (local_peaks,)


def sorting_to_peaks(sorting, extremum_channel_inds, dtype):
def sorting_to_peaks(sorting, extremum_channel_inds, dtype=spike_peak_dtype):
spikes = sorting.to_spike_vector()
peaks = np.zeros(spikes.size, dtype=dtype)
peaks["sample_index"] = spikes["sample_index"]
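For illustration, a small sketch of the new default in action (assuming, as elsewhere in SpikeInterface, that `extremum_channel_inds` is a dict mapping unit id to channel index; the toy mapping below is made up):

```python
import numpy as np
from spikeinterface.core import generate_sorting
from spikeinterface.core.node_pipeline import sorting_to_peaks, spike_peak_dtype

sorting = generate_sorting(num_units=3, durations=[5.0], seed=0)

# Toy mapping: every unit assigned to channel 0, just to exercise the call
extremum_channel_inds = {unit_id: 0 for unit_id in sorting.unit_ids}

# dtype can now be omitted and falls back to spike_peak_dtype
peaks = sorting_to_peaks(sorting, extremum_channel_inds)
assert peaks.dtype == np.dtype(spike_peak_dtype)
```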
3 changes: 2 additions & 1 deletion src/spikeinterface/core/sortinganalyzer.py
@@ -330,7 +330,8 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scale
json.dump(check_json(info), f, indent=4)

# save a copy of the sorting
NumpyFolderSorting.write_sorting(sorting, folder / "sorting")
# NumpyFolderSorting.write_sorting(sorting, folder / "sorting")
sorting.save(folder=folder / "sorting")

# save recording and sorting provenance
if recording.check_serializability("json"):
36 changes: 33 additions & 3 deletions src/spikeinterface/generation/drifting_generator.py
@@ -97,7 +97,7 @@ def make_one_displacement_vector(
start_drift_index = int(t_start_drift * displacement_sampling_frequency)
end_drift_index = int(t_end_drift * displacement_sampling_frequency)

num_samples = int(displacement_sampling_frequency * duration)
num_samples = int(np.ceil(displacement_sampling_frequency * duration))
displacement_vector = np.zeros(num_samples, dtype="float32")

if drift_mode == "zigzag":
@@ -286,6 +286,7 @@ def generate_drifting_recording(
),
generate_sorting_kwargs=dict(firing_rates=(2.0, 8.0), refractory_period_ms=4.0),
generate_noise_kwargs=dict(noise_levels=(12.0, 15.0), spatial_decay=25.0),
more_outputs=False,
seed=None,
):
"""
Expand Down Expand Up @@ -314,6 +315,8 @@ def generate_drifting_recording(
Parameters given to generate_sorting().
generate_noise_kwargs: dict
Parameters given to generate_noise().
more_outputs: bool, default False
Optionally return a dict with more variables.
seed: None or int
A unique seed for all steps.

@@ -326,7 +329,14 @@
sorting: Sorting
The ground truth sorting object.
Same for both recordings.

more_infos:
If more_outputs=True, then also return a dict that contains various information such as:
* displacement_vectors
* displacement_sampling_frequency
* unit_locations
* displacement_unit_factor
* unit_displacements
This can be helpful for motion benchmarks.
"""

rng = np.random.default_rng(seed=seed)
@@ -356,6 +366,14 @@ def generate_drifting_recording(
generate_displacement_vector(duration, unit_locations[:, :2], seed=seed, **generate_displacement_vector_kwargs)
)

# unit_displacements is the sum of all displacements (times, units, direction_x_y)
unit_displacements = np.zeros((displacement_vectors.shape[0], num_units, 2))
for direction in (0, 1):
# x and y
for i in range(displacement_vectors.shape[2]):
m = displacement_vectors[:, direction, i][:, np.newaxis] * displacement_unit_factor[:, i][np.newaxis, :]
unit_displacements[:, :, direction] += m

# unit_params need to be fixed before the displacement steps
generate_templates_kwargs = generate_templates_kwargs.copy()
unit_params = _ensure_unit_params(generate_templates_kwargs.get("unit_params", {}), num_units, seed)
@@ -400,6 +418,8 @@ def generate_drifting_recording(
seed=seed,
)

sorting.set_property("gt_unit_locations", unit_locations)

## Important: precomputing displacements does not work on borders and so does not work for tetrodes
# here we bypass the interpolation and regenerate templates at several positions.
## drifting_templates.precompute_displacements(displacements_steps)
@@ -437,4 +457,14 @@ def generate_drifting_recording(
amplitude_factor=None,
)

return static_recording, drifting_recording, sorting
if more_outputs:
more_infos = dict(
displacement_vectors=displacement_vectors,
displacement_sampling_frequency=displacement_sampling_frequency,
unit_locations=unit_locations,
displacement_unit_factor=displacement_unit_factor,
unit_displacements=unit_displacements,
)
return static_recording, drifting_recording, sorting, more_infos
else:
return static_recording, drifting_recording, sorting
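A usage sketch of the new `more_outputs` flag (argument values are arbitrary, everything else is left at its defaults; the import path and the `num_units`/`duration` keyword names are assumed to match the existing public signature):

```python
from spikeinterface.generation import generate_drifting_recording

static_rec, drifting_rec, gt_sorting, more_infos = generate_drifting_recording(
    num_units=20,
    duration=60.0,
    more_outputs=True,
    seed=42,
)

# Ground-truth unit locations are now also attached to the sorting itself
unit_locations = gt_sorting.get_property("gt_unit_locations")

# Summed per-unit displacement over time: (num displacement samples, num_units, 2)
print(more_infos["unit_displacements"].shape)
print(sorted(more_infos.keys()))
```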
178 changes: 27 additions & 151 deletions src/spikeinterface/sorters/internal/tridesclous2.py
@@ -33,6 +33,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter):

_default_params = {
"apply_preprocessing": True,
"apply_motion_correction": False,
"cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
"waveforms": {
"ms_before": 0.5,
@@ -52,10 +53,12 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
"ms_before": 2.0,
"ms_after": 3.0,
"max_spikes_per_unit": 400,
"sparsity_threshold": 2.0,
# "peak_shift_ms": 0.2,
},
# "matching": {"method": "tridesclous", "method_kwargs": {"peak_shift_ms": 0.2, "radius_um": 100.0}},
"matching": {"method": "circus-omp-svd", "method_kwargs": {}},
# "matching": {"method": "circus-omp-svd", "method_kwargs": {}},
"matching": {"method": "wobble", "method_kwargs": {}},
"job_kwargs": {"n_jobs": -1},
"save_array": True,
}
@@ -102,6 +105,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
from spikeinterface.sortingcomponents.clustering.split import split_clusters
from spikeinterface.sortingcomponents.clustering.merge import merge_clusters
from spikeinterface.sortingcomponents.clustering.tools import compute_template_from_sparse
from spikeinterface.sortingcomponents.clustering.main import find_cluster_from_peaks
from spikeinterface.sortingcomponents.tools import remove_empty_templates

from sklearn.decomposition import TruncatedSVD

@@ -115,10 +120,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
# preprocessing
if params["apply_preprocessing"]:
recording = bandpass_filter(recording_raw, **params["filtering"])
# TODO what is the best about zscore>common_reference or the reverse
recording = common_reference(recording)
recording = zscore(recording, dtype="float32")
# recording = whiten(recording, dtype="float32")
recording = whiten(recording, dtype="float32")

# used only if "folder" or "zarr"
cache_folder = sorter_output_folder / "cache_preprocessing"
@@ -148,152 +152,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
if verbose:
print("We kept %d peaks for clustering" % len(peaks))

ms_before = params["waveforms"]["ms_before"]
ms_after = params["waveforms"]["ms_after"]
clustering_kwargs = {}
clustering_kwargs["folder"] = sorter_output_folder
clustering_kwargs["waveforms"] = params["waveforms"].copy()
clustering_kwargs["clustering"] = params["clustering"].copy()

# SVD for time compression
few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000)
few_wfs = extract_waveform_at_max_channel(
recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
labels_set, post_clean_label, extra_out = find_cluster_from_peaks(
recording, peaks, method="tdc_clustering", method_kwargs=clustering_kwargs, extra_outputs=True, **job_kwargs
)

wfs = few_wfs[:, :, 0]
tsvd = TruncatedSVD(params["svd"]["n_components"])
tsvd.fit(wfs)

model_folder = sorter_output_folder / "tsvd_model"

model_folder.mkdir(exist_ok=True)
with open(model_folder / "pca_model.pkl", "wb") as f:
pickle.dump(tsvd, f)

model_params = {
"ms_before": ms_before,
"ms_after": ms_after,
"sampling_frequency": float(sampling_frequency),
}
with open(model_folder / "params.json", "w") as f:
json.dump(model_params, f)

# features

features_folder = sorter_output_folder / "features"
node0 = PeakRetriever(recording, peaks)

radius_um = params["waveforms"]["radius_um"]
node1 = ExtractSparseWaveforms(
recording,
parents=[node0],
return_output=True,
ms_before=ms_before,
ms_after=ms_after,
radius_um=radius_um,
)

model_folder_path = sorter_output_folder / "tsvd_model"

node2 = TemporalPCAProjection(
recording, parents=[node0, node1], return_output=True, model_folder_path=model_folder_path
)

pipeline_nodes = [node0, node1, node2]

output = run_node_pipeline(
recording,
pipeline_nodes,
job_kwargs,
gather_mode="npy",
gather_kwargs=dict(exist_ok=True),
folder=features_folder,
names=["sparse_wfs", "sparse_tsvd"],
)

# TODO make this generic in GatherNPY ???
sparse_mask = node1.neighbours_mask
np.save(features_folder / "sparse_mask.npy", sparse_mask)
np.save(features_folder / "peaks.npy", peaks)

# Clustering: channel index > split > merge
split_radius_um = params["clustering"]["split_radius_um"]
neighbours_mask = get_channel_distances(recording) < split_radius_um

original_labels = peaks["channel_index"]

min_cluster_size = 50

post_split_label, split_count = split_clusters(
original_labels,
recording,
features_folder,
method="local_feature_clustering",
method_kwargs=dict(
# clusterer="hdbscan",
clusterer="isocut5",
feature_name="sparse_tsvd",
# feature_name="sparse_wfs",
neighbours_mask=neighbours_mask,
waveforms_sparse_mask=sparse_mask,
min_size_split=min_cluster_size,
clusterer_kwargs={"min_cluster_size": min_cluster_size},
n_pca_features=3,
),
recursive=True,
recursive_depth=3,
returns_split_count=True,
**job_kwargs,
)

merge_radius_um = params["clustering"]["merge_radius_um"]
threshold_diff = params["clustering"]["threshold_diff"]

post_merge_label, peak_shifts = merge_clusters(
peaks,
post_split_label,
recording,
features_folder,
radius_um=merge_radius_um,
# method="project_distribution",
# method_kwargs=dict(
# waveforms_sparse_mask=sparse_mask,
# feature_name="sparse_wfs",
# projection="centroid",
# criteria="distrib_overlap",
# threshold_overlap=0.3,
# min_cluster_size=min_cluster_size + 1,
# num_shift=5,
# ),
method="normalized_template_diff",
method_kwargs=dict(
waveforms_sparse_mask=sparse_mask,
threshold_diff=threshold_diff,
min_cluster_size=min_cluster_size + 1,
num_shift=5,
),
**job_kwargs,
)

# sparse_wfs = np.load(features_folder / "sparse_wfs.npy", mmap_mode="r")

peak_shifts = extra_out["peak_shifts"]
new_peaks = peaks.copy()
new_peaks["sample_index"] -= peak_shifts

# clean very small cluster before peeler
post_clean_label = post_merge_label.copy()

minimum_cluster_size = 25
labels_set, count = np.unique(post_clean_label, return_counts=True)
to_remove = labels_set[count < minimum_cluster_size]
mask = np.isin(post_clean_label, to_remove)
post_clean_label[mask] = -1

# final label sets
labels_set = np.unique(post_clean_label)
labels_set = labels_set[labels_set >= 0]

mask = post_clean_label >= 0
sorting_pre_peeler = NumpySorting.from_times_labels(
new_peaks["sample_index"][mask],
post_merge_label[mask],
post_clean_label[mask],
sampling_frequency,
unit_ids=labels_set,
)
@@ -303,6 +177,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

nbefore = int(params["templates"]["ms_before"] * sampling_frequency / 1000.0)
nafter = int(params["templates"]["ms_after"] * sampling_frequency / 1000.0)
sparsity_threshold = params["templates"]["sparsity_threshold"]
templates_array = estimate_templates_with_accumulator(
recording_w,
sorting_pre_peeler.to_spike_vector(),
@@ -320,8 +195,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
)
# TODO : try other methods for sparsity
# sparsity = compute_sparsity(templates_dense, method="radius", radius_um=120.)
sparsity = compute_sparsity(templates_dense, noise_levels=noise_levels, threshold=1.0)
sparsity = compute_sparsity(templates_dense, noise_levels=noise_levels, threshold=sparsity_threshold)
templates = templates_dense.to_sparse(sparsity)
templates = remove_empty_templates(templates)

# snrs = compute_snrs(we, peak_sign=params["detection"]["peak_sign"], peak_mode="extremum")
# print(snrs)
@@ -350,12 +226,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
# )
# )

if matching_method == "circus-omp-svd":
job_kwargs = job_kwargs.copy()
for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in job_kwargs:
job_kwargs.pop(value)
job_kwargs["chunk_duration"] = "100ms"
# if matching_method == "circus-omp-svd":
# job_kwargs = job_kwargs.copy()
# for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
# if value in job_kwargs:
# job_kwargs.pop(value)
# job_kwargs["chunk_duration"] = "100ms"

spikes = find_spikes_from_templates(
recording_w, method=matching_method, method_kwargs=matching_params, **job_kwargs
@@ -366,9 +242,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

np.save(sorter_output_folder / "noise_levels.npy", noise_levels)
np.save(sorter_output_folder / "all_peaks.npy", all_peaks)
np.save(sorter_output_folder / "post_split_label.npy", post_split_label)
np.save(sorter_output_folder / "split_count.npy", split_count)
np.save(sorter_output_folder / "post_merge_label.npy", post_merge_label)
# np.save(sorter_output_folder / "post_split_label.npy", post_split_label)
# np.save(sorter_output_folder / "split_count.npy", split_count)
# np.save(sorter_output_folder / "post_merge_label.npy", post_merge_label)
np.save(sorter_output_folder / "spikes.npy", spikes)

final_spikes = np.zeros(spikes.size, dtype=minimum_spike_dtype)
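The refactor above delegates the whole split/merge pipeline to the `tdc_clustering` method of `find_cluster_from_peaks`. A rough standalone sketch of that flow outside the sorter (the `method_kwargs` mirror the sorter's `params["waveforms"]`/`params["clustering"]` blocks, but the exact values and the peak-detection settings here are guesses, not the sorter's defaults):

```python
from pathlib import Path

from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.clustering.main import find_cluster_from_peaks

recording, _ = generate_ground_truth_recording(durations=[30.0], num_units=10, seed=0)
job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=False)

peaks = detect_peaks(recording, method="locally_exclusive", detect_threshold=5, **job_kwargs)

labels_set, peak_labels, extra_out = find_cluster_from_peaks(
    recording,
    peaks,
    method="tdc_clustering",
    method_kwargs=dict(
        folder=Path("tdc_clustering_output"),
        waveforms={"ms_before": 0.5, "ms_after": 1.5},
        clustering={},
    ),
    extra_outputs=True,
    **job_kwargs,
)

# As in the sorter, per-peak shifts returned by the clustering re-align spike times
new_peaks = peaks.copy()
new_peaks["sample_index"] -= extra_out["peak_shifts"]
```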
4 changes: 2 additions & 2 deletions src/spikeinterface/sorters/tests/test_launcher.py
@@ -176,11 +176,11 @@ def test_run_sorter_by_property():
setup_module()
job_list = get_job_list()

test_run_sorter_jobs_loop(job_list)
# test_run_sorter_jobs_loop(job_list)
# test_run_sorter_jobs_joblib(job_list)
# test_run_sorter_jobs_processpoolexecutor(job_list)
# test_run_sorter_jobs_multiprocessing(job_list)
# test_run_sorter_jobs_dask(job_list)
# test_run_sorter_jobs_slurm(job_list)

# test_run_sorter_by_property()
test_run_sorter_by_property()