Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sc2 fixes #2074

Merged
merged 11 commits into from
Oct 5, 2023
11 changes: 6 additions & 5 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from spikeinterface.core import NumpySorting, load_extractor, BaseRecording, get_noise_levels, extract_waveforms
from spikeinterface.core.job_tools import fix_job_kwargs
from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore
from spikeinterface.preprocessing import common_reference, zscore, whiten, highpass_filter

try:
import hdbscan
Expand All @@ -22,7 +22,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
_default_params = {
"general": {"ms_before": 2, "ms_after": 2, "radius_um": 100},
"waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1},
"filtering": {"dtype": "float32"},
"filtering": {"freq_min": 150, "dtype": "float32"},
"detection": {"peak_sign": "neg", "detect_threshold": 5},
"selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000},
"localization": {},
Expand Down Expand Up @@ -60,11 +60,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
## First, we are filtering the data
filtering_params = params["filtering"].copy()
if params["apply_preprocessing"]:
recording_f = bandpass_filter(recording, **filtering_params)
recording_f = highpass_filter(recording, **filtering_params)
recording_f = common_reference(recording_f)
else:
recording_f = recording

# recording_f = whiten(recording_f, dtype="float32")
recording_f = zscore(recording_f, dtype="float32")

## Then, we are detecting peaks with a locally_exclusive method
Expand Down Expand Up @@ -98,10 +99,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
## We launch a clustering (using hdbscan) relying on positions and features extracted on
## the fly from the snippets
clustering_params = params["clustering"].copy()
clustering_params["waveforms_kwargs"] = params["waveforms"]
clustering_params["waveforms"] = params["waveforms"].copy()

for k in ["ms_before", "ms_after"]:
clustering_params["waveforms_kwargs"][k] = params["general"][k]
clustering_params["waveforms"][k] = params["general"][k]

clustering_params.update(dict(shared_memory=params["shared_memory"]))
clustering_params["job_kwargs"] = job_kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -593,29 +593,34 @@ def remove_duplicates_via_matching(

chunk_size = duration + 3 * margin

method_kwargs.update(
local_params = method_kwargs.copy()

local_params.update(
{
"waveform_extractor": waveform_extractor,
"noise_levels": noise_levels,
"amplitudes": [0.95, 1.05],
"omp_min_sps": 0.1,
"omp_min_sps": 0.05,
}
)

spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()["unit_index"], return_counts=True)
indices = np.argsort(counts)

ignore_ids = []
similar_templates = [[], []]

for i in range(nb_templates):
for i in indices:
t_start = padding + i * duration
t_stop = padding + (i + 1) * duration

sub_recording = recording.frame_slice(t_start - half_marging, t_stop + half_marging)
method_kwargs.update({"ignored_ids": ignore_ids + [i]})
local_params.update({"ignored_ids": ignore_ids + [i]})
spikes, computed = find_spikes_from_templates(
sub_recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs
sub_recording, method=method, method_kwargs=local_params, extra_outputs=True, **job_kwargs
)
if method == "circus-omp-svd":
method_kwargs.update(
local_params.update(
{
"overlaps": computed["overlaps"],
"templates": computed["templates"],
Expand All @@ -629,7 +634,7 @@ def remove_duplicates_via_matching(
}
)
elif method == "circus-omp":
method_kwargs.update(
local_params.update(
{
"overlaps": computed["overlaps"],
"templates": computed["templates"],
Expand Down Expand Up @@ -661,7 +666,7 @@ def remove_duplicates_via_matching(
labels = np.unique(new_labels)
labels = labels[labels >= 0]

del recording, sub_recording, method_kwargs
del recording, sub_recording, local_params, waveform_extractor
os.remove(tmp_filename)

return labels, new_labels
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,8 @@ def sigmoid(x, L, x0, k, b):
recording,
sorting,
waveform_folder,
ms_before=params["ms_before"],
ms_after=params["ms_after"],
**params["job_kwargs"],
**params["waveforms"],
return_scaled=False,
mode=mode,
)
Expand Down