Sparse waveforms for Spyking Circus 2 #1943

Merged
merged 39 commits into SpikeInterface:main from yger:factoring_omp on Sep 20, 2023
Changes from 33 commits
Commits
39 commits
9dc04f1
WIP
yger Jul 14, 2023
0f9fee6
WIP
yger Jul 17, 2023
15a5b5f
Merge branch 'main' of https://github.com/SpikeInterface/spikeinterfa…
yger Jul 17, 2023
7a3d4c2
WIP
yger Jul 17, 2023
4a19ad2
Merge branch 'SpikeInterface:main' into factoring_omp
yger Jul 18, 2023
d89d5a9
WIP
yger Jul 24, 2023
892305b
WIP
yger Jul 24, 2023
3044c80
Merge branch 'SpikeInterface:main' into factoring_omp
yger Aug 28, 2023
1cb122c
WIP for circus2
yger Aug 28, 2023
169a3e9
Merge branch 'factoring_omp' of github.com:yger/spikeinterface into f…
yger Aug 28, 2023
ef204dd
WIP
yger Aug 28, 2023
242799f
Docs
yger Aug 28, 2023
5566c91
Fix for circus
yger Aug 29, 2023
75c9793
WIP
yger Aug 29, 2023
d7e9ac1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2023
cb3c07c
Merge branch 'main' into factoring_omp
yger Aug 29, 2023
14c8f58
useless dependency
yger Aug 29, 2023
e455da3
Fix for classical circus with sparsity
yger Aug 29, 2023
2f84c6b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2023
3d849fb
Fix for classical circus with sparsity
yger Aug 29, 2023
025d31c
Merge branch 'factoring_omp' of github.com:yger/spikeinterface into f…
yger Aug 29, 2023
7dcfdb0
Fixing slow tests with SC2
yger Aug 29, 2023
9f196b5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2023
78fd5ed
Merge branch 'main' into factoring_omp
yger Aug 29, 2023
1c7c802
WIP for cleaning
yger Aug 29, 2023
f53dba1
Merge branch 'factoring_omp' of github.com:yger/spikeinterface into f…
yger Aug 29, 2023
af4f187
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2023
8c2af8f
WIP
yger Aug 29, 2023
8379d9a
Merge branch 'main' into factoring_omp
yger Aug 30, 2023
434dd9c
Merge branch 'SpikeInterface:main' into factoring_omp
yger Sep 13, 2023
75ee1af
Merge branch 'factoring_omp' of github.com:yger/spikeinterface into f…
yger Sep 13, 2023
99e7acc
WIP
yger Sep 13, 2023
cc79213
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2023
dda7803
Adding unit_ids
yger Sep 13, 2023
e4b99cb
Merge branch 'main' into factoring_omp
yger Sep 13, 2023
19e7f5d
Merge branch 'main' into factoring_omp
yger Sep 14, 2023
9d7ae1f
Merge branch 'SpikeInterface:main' into factoring_omp
yger Sep 14, 2023
393ca66
Merge branch 'SpikeInterface:main' into factoring_omp
yger Sep 19, 2023
3781c16
Merge branch 'SpikeInterface:main' into factoring_omp
yger Sep 20, 2023
26 changes: 16 additions & 10 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -3,7 +3,6 @@
import os
import shutil
import numpy as np
import os

from spikeinterface.core import NumpySorting, load_extractor, BaseRecording, get_noise_levels, extract_waveforms
from spikeinterface.core.job_tools import fix_job_kwargs
@@ -21,18 +20,17 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
sorter_name = "spykingcircus2"

_default_params = {
"general": {"ms_before": 2, "ms_after": 2, "radius_um": 100},
"waveforms": {"max_spikes_per_unit": 200, "overwrite": True},
"general": {"ms_before": 2, "ms_after": 2, "radius_um": 75},
"waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1},
"filtering": {"dtype": "float32"},
"detection": {"peak_sign": "neg", "detect_threshold": 5},
"selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000},
"localization": {},
"clustering": {},
"matching": {},
"registration": {},
"apply_preprocessing": True,
"shared_memory": False,
"job_kwargs": {},
"shared_memory": True,
"job_kwargs": {"n_jobs": -1},
}

@classmethod
@@ -63,8 +61,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
## First, we are filtering the data
filtering_params = params["filtering"].copy()
if params["apply_preprocessing"]:
# if recording.is_filtered == True:
# print('Looks like the recording is already filtered, check preprocessing!')
recording_f = bandpass_filter(recording, **filtering_params)
recording_f = common_reference(recording_f)
else:
@@ -103,8 +99,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
## We launch a clustering (using hdbscan) relying on positions and features extracted on
## the fly from the snippets
clustering_params = params["clustering"].copy()
clustering_params.update(params["waveforms"])
clustering_params.update(params["general"])
clustering_params["waveforms_kwargs"] = params["waveforms"]

for k in ["ms_before", "ms_after"]:
clustering_params["waveforms_kwargs"][k] = params["general"][k]

clustering_params.update(dict(shared_memory=params["shared_memory"]))
clustering_params["job_kwargs"] = job_kwargs
clustering_params["tmp_folder"] = sorter_output_folder / "clustering"
@@ -126,6 +125,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
waveforms_params = params["waveforms"].copy()
waveforms_params.update(job_kwargs)

for k in ["ms_before", "ms_after"]:
waveforms_params[k] = params["general"][k]

if params["shared_memory"]:
mode = "memory"
waveforms_folder = None
@@ -143,6 +145,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
matching_params.update({"noise_levels": noise_levels})

matching_job_params = job_kwargs.copy()
for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in matching_job_params:
matching_job_params.pop(value)

matching_job_params["chunk_duration"] = "100ms"

spikes = find_spikes_from_templates(
@@ -600,29 +600,38 @@ def plot_comparison_matching(
else:
ax = axs[j]
comp1, comp2 = comp_per_method[method1], comp_per_method[method2]
for performance, color in zip(performance_names, colors):
perf1 = comp1.get_performance()[performance]
perf2 = comp2.get_performance()[performance]
ax.plot(perf2, perf1, ".", label=performance, color=color)
ax.plot([0, 1], [0, 1], "k--", alpha=0.5)
ax.set_ylim(ylim)
ax.set_xlim(ylim)
ax.spines[["right", "top"]].set_visible(False)
ax.set_aspect("equal")

if j == 0:
ax.set_ylabel(f"{method1}")
else:
ax.set_yticks([])
if i == num_methods - 1:
ax.set_xlabel(f"{method2}")
if i <= j:
for performance, color in zip(performance_names, colors):
perf1 = comp1.get_performance()[performance]
perf2 = comp2.get_performance()[performance]
ax.plot(perf2, perf1, ".", label=performance, color=color)

ax.plot([0, 1], [0, 1], "k--", alpha=0.5)
ax.set_ylim(ylim)
ax.set_xlim(ylim)
ax.spines[["right", "top"]].set_visible(False)
ax.set_aspect("equal")

if j == i:
ax.set_ylabel(f"{method1}")
else:
ax.set_yticks([])
if i == j:
ax.set_xlabel(f"{method2}")
else:
ax.set_xticks([])
if i == num_methods - 1 and j == num_methods - 1:
patches = []
for color, name in zip(colors, performance_names):
patches.append(mpatches.Patch(color=color, label=name))
ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0)
else:
ax.spines["bottom"].set_visible(False)
ax.spines["left"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_xticks([])
if i == num_methods - 1 and j == num_methods - 1:
patches = []
for color, name in zip(colors, performance_names):
patches.append(mpatches.Patch(color=color, label=name))
ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0)
ax.set_yticks([])
plt.tight_layout(h_pad=0, w_pad=0)
return fig, axs

@@ -536,7 +536,6 @@ def remove_duplicates_via_matching(
waveform_extractor,
noise_levels,
peak_labels,
sparsify_threshold=1,
method_kwargs={},
job_kwargs={},
tmp_folder=None,
@@ -552,16 +551,20 @@
from pathlib import Path

job_kwargs = fix_job_kwargs(job_kwargs)

if waveform_extractor.is_sparse():
sparsity = waveform_extractor.sparsity.mask

templates = waveform_extractor.get_all_templates(mode="median").copy()
nb_templates = len(templates)
duration = waveform_extractor.nbefore + waveform_extractor.nafter

fs = waveform_extractor.recording.get_sampling_frequency()
num_chans = waveform_extractor.recording.get_num_channels()

for t in range(nb_templates):
is_silent = templates[t].ptp(0) < sparsify_threshold
templates[t, :, is_silent] = 0
if waveform_extractor.is_sparse():
for count, unit_id in enumerate(waveform_extractor.sorting.unit_ids):
templates[count][:, ~sparsity[count]] = 0

zdata = templates.reshape(nb_templates, -1)

@@ -581,6 +584,7 @@

recording = BinaryRecordingExtractor(tmp_filename, num_channels=num_chans, sampling_frequency=fs, dtype="float32")
recording.annotate(is_filtered=True)
recording = recording.set_probe(waveform_extractor.recording.get_probe())

margin = 2 * max(waveform_extractor.nbefore, waveform_extractor.nafter)
half_marging = margin // 2
@@ -597,7 +601,6 @@
"waveform_extractor": waveform_extractor,
"noise_levels": noise_levels,
"amplitudes": [0.95, 1.05],
"sparsify_threshold": sparsify_threshold,
"omp_min_sps": 0.1,
"templates": None,
"overlaps": None,
130 changes: 51 additions & 79 deletions src/spikeinterface/sortingcomponents/clustering/random_projections.py
@@ -41,7 +41,6 @@ class RandomProjectionClustering:
"ms_before": 1.5,
"ms_after": 1.5,
"random_seed": 42,
"cleaning_method": "matching",
"shared_memory": False,
"min_values": {"ptp": 0, "energy": 0},
"tmp_folder": None,
@@ -160,86 +159,59 @@ def main_function(cls, recording, peaks, params):
spikes["segment_index"] = peaks[mask]["segment_index"]
spikes["unit_index"] = peak_labels[mask]

cleaning_method = params["cleaning_method"]

if verbose:
print("We found %d raw clusters, starting to clean with %s..." % (len(labels), cleaning_method))

if cleaning_method == "cosine":
wfs_arrays = extract_waveforms_to_buffers(
recording,
spikes,
labels,
nbefore,
nafter,
mode="shared_memory",
return_scaled=False,
folder=None,
dtype=recording.get_dtype(),
sparsity_mask=None,
copy=True,
**params["job_kwargs"],
)

labels, peak_labels = remove_duplicates(
wfs_arrays, noise_levels, peak_labels, num_samples, num_chans, **params["cleaning_kwargs"]
)

elif cleaning_method == "dip":
wfs_arrays = {}
for label in labels:
mask = label == peak_labels
wfs_arrays[label] = hdbscan_data[mask]

labels, peak_labels = remove_duplicates_via_dip(wfs_arrays, peak_labels, **params["cleaning_kwargs"])

elif cleaning_method == "matching":
# create a tmp folder
if params["tmp_folder"] is None:
name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
tmp_folder = get_global_tmp_folder() / name
else:
tmp_folder = Path(params["tmp_folder"])

if params["shared_memory"]:
waveform_folder = None
mode = "memory"
else:
waveform_folder = tmp_folder / "waveforms"
mode = "folder"

sorting_folder = tmp_folder / "sorting"
sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["unit_index"], fs)
sorting = sorting.save(folder=sorting_folder)
we = extract_waveforms(
recording,
sorting,
waveform_folder,
ms_before=params["ms_before"],
ms_after=params["ms_after"],
**params["job_kwargs"],
return_scaled=False,
mode=mode,
)

cleaning_matching_params = params["job_kwargs"].copy()
cleaning_matching_params["chunk_duration"] = "100ms"
cleaning_matching_params["n_jobs"] = 1
cleaning_matching_params["verbose"] = False
cleaning_matching_params["progress_bar"] = False

cleaning_params = params["cleaning_kwargs"].copy()
cleaning_params["tmp_folder"] = tmp_folder

labels, peak_labels = remove_duplicates_via_matching(
we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params
)

if params["tmp_folder"] is None:
shutil.rmtree(tmp_folder)
else:
print("We found %d raw clusters, starting to clean with matching..." % (len(labels)))

# create a tmp folder
if params["tmp_folder"] is None:
name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
tmp_folder = get_global_tmp_folder() / name
else:
tmp_folder = Path(params["tmp_folder"])

if params["shared_memory"]:
waveform_folder = None
mode = "memory"
else:
waveform_folder = tmp_folder / "waveforms"
mode = "folder"

sorting_folder = tmp_folder / "sorting"
sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["unit_index"], fs)
Member
This is faster and should be the default when we already have a spike vector. Note that you also need to give the unit_ids.

Suggested change
sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["unit_index"], fs)
sorting = NumpySorting(spikes, fs)

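A minimal sketch of the suggested construction once the unit_ids are passed along. It reuses the `spikes`, `labels`, and `fs` variables from the surrounding clustering code, and the positional `NumpySorting(spikes, sampling_frequency, unit_ids)` signature is an assumption rather than something shown in this diff:

```python
# Sketch only: build the sorting directly from the structured spike vector
# instead of going through from_times_labels(). `spikes`, `labels` and `fs`
# are assumed from the clustering code above; passing unit_ids positionally
# is an assumption about the NumpySorting constructor.
import numpy as np
from spikeinterface.core import NumpySorting

unit_ids = np.arange(len(labels))             # one unit id per raw cluster
sorting = NumpySorting(spikes, fs, unit_ids)  # spike vector + sampling frequency + unit ids
```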
Member
@yger this is the only change that is required! Can you add the unit_ids?

Collaborator Author
Done. However, I have a question about the new waveform speedup with single or multi buffers. @samuelgarcia @alejoe91 shouldn't this functionality be exposed at the extract_waveforms() level?

Member
Yes, this is the plan, but not now. The WaveformExtractor will be refactored a lot anyway, so it is not even certain this will be done before the big refactoring.

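Purely for illustration of the question above, a hypothetical sketch of what exposing the buffer choice at the extract_waveforms() level could look like. The `buffer_layout` keyword is invented for this example and does not exist; only the mode="memory" / mode="folder" options used in the diff below are real:

```python
# Hypothetical sketch: a keyword (named `buffer_layout` here) that would let
# the caller choose a single shared buffer vs. per-unit buffers. The keyword is
# invented for illustration; `recording` and `sorting` are assumed from context.
from spikeinterface.core import extract_waveforms

we = extract_waveforms(
    recording,
    sorting,
    folder=None,
    mode="memory",              # existing option: keep waveforms in memory
    # buffer_layout="single",   # hypothetical single- vs. multi-buffer switch
    ms_before=2.0,
    ms_after=2.0,
    return_scaled=False,
)
```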
sorting = sorting.save(folder=sorting_folder)
we = extract_waveforms(
recording,
sorting,
waveform_folder,
ms_before=params["ms_before"],
ms_after=params["ms_after"],
**params["job_kwargs"],
return_scaled=False,
mode=mode,
)

cleaning_matching_params = params["job_kwargs"].copy()
for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in cleaning_matching_params:
cleaning_matching_params.pop(value)
cleaning_matching_params["chunk_duration"] = "100ms"
cleaning_matching_params["n_jobs"] = 1
cleaning_matching_params["verbose"] = False
cleaning_matching_params["progress_bar"] = False

cleaning_params = params["cleaning_kwargs"].copy()
cleaning_params["tmp_folder"] = tmp_folder

labels, peak_labels = remove_duplicates_via_matching(
we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params
)

if params["tmp_folder"] is None:
shutil.rmtree(tmp_folder)
else:
if not params["shared_memory"]:
shutil.rmtree(tmp_folder / "waveforms")
shutil.rmtree(tmp_folder / "sorting")
shutil.rmtree(tmp_folder / "sorting")

if verbose:
print("We kept %d non-duplicated clusters..." % len(labels))