Merge pull request #2074 from yger/sc2_fixes
Sc2 fixes
samuelgarcia authored Oct 5, 2023
2 parents af9660a + b6f9235 commit 8559546
Showing 3 changed files with 20 additions and 15 deletions.
11 changes: 6 additions & 5 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -6,7 +6,7 @@
 
 from spikeinterface.core import NumpySorting, load_extractor, BaseRecording, get_noise_levels, extract_waveforms
 from spikeinterface.core.job_tools import fix_job_kwargs
-from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore
+from spikeinterface.preprocessing import common_reference, zscore, whiten, highpass_filter
 
 try:
     import hdbscan
@@ -22,7 +22,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
     _default_params = {
         "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100},
         "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1},
-        "filtering": {"dtype": "float32"},
+        "filtering": {"freq_min": 150, "dtype": "float32"},
         "detection": {"peak_sign": "neg", "detect_threshold": 5},
         "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000},
         "localization": {},
@@ -60,11 +60,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         ## First, we are filtering the data
         filtering_params = params["filtering"].copy()
         if params["apply_preprocessing"]:
-            recording_f = bandpass_filter(recording, **filtering_params)
+            recording_f = highpass_filter(recording, **filtering_params)
             recording_f = common_reference(recording_f)
         else:
             recording_f = recording
 
+        # recording_f = whiten(recording_f, dtype="float32")
         recording_f = zscore(recording_f, dtype="float32")
 
         ## Then, we are detecting peaks with a locally_exclusive method
@@ -98,10 +99,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         ## We launch a clustering (using hdbscan) relying on positions and features extracted on
         ## the fly from the snippets
         clustering_params = params["clustering"].copy()
-        clustering_params["waveforms_kwargs"] = params["waveforms"]
+        clustering_params["waveforms"] = params["waveforms"].copy()
 
         for k in ["ms_before", "ms_after"]:
-            clustering_params["waveforms_kwargs"][k] = params["general"][k]
+            clustering_params["waveforms"][k] = params["general"][k]
 
         clustering_params.update(dict(shared_memory=params["shared_memory"]))
         clustering_params["job_kwargs"] = job_kwargs
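
As a reading aid, here is a minimal sketch of the preprocessing chain these new defaults produce, assuming a toy recording built with generate_recording (an assumption for this sketch; in the sorter the recording comes from run_sorter). The whitening step stays commented out in the sorter itself, exactly as in the diff.

import spikeinterface.core as sc
from spikeinterface.preprocessing import common_reference, highpass_filter, zscore

# Toy recording standing in for the one handed to the sorter (assumption for this sketch).
recording = sc.generate_recording(num_channels=4, sampling_frequency=30000.0, durations=[5.0])

filtering_params = {"freq_min": 150, "dtype": "float32"}  # new default introduced by this PR

# High-pass filtering (previously band-pass), then common reference.
recording_f = highpass_filter(recording, **filtering_params)
recording_f = common_reference(recording_f)

# Whitening remains disabled in the sorter; kept here only as the commented hint from the diff.
# recording_f = whiten(recording_f, dtype="float32")
recording_f = zscore(recording_f, dtype="float32")
print(recording_f)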
Second changed file (path not shown in this capture)
@@ -593,29 +593,34 @@ def remove_duplicates_via_matching(
 
     chunk_size = duration + 3 * margin
 
-    method_kwargs.update(
+    local_params = method_kwargs.copy()
+
+    local_params.update(
         {
             "waveform_extractor": waveform_extractor,
             "noise_levels": noise_levels,
             "amplitudes": [0.95, 1.05],
-            "omp_min_sps": 0.1,
+            "omp_min_sps": 0.05,
         }
     )
 
+    spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()["unit_index"], return_counts=True)
+    indices = np.argsort(counts)
+
     ignore_ids = []
     similar_templates = [[], []]
 
-    for i in range(nb_templates):
+    for i in indices:
         t_start = padding + i * duration
         t_stop = padding + (i + 1) * duration
 
         sub_recording = recording.frame_slice(t_start - half_marging, t_stop + half_marging)
-        method_kwargs.update({"ignored_ids": ignore_ids + [i]})
+        local_params.update({"ignored_ids": ignore_ids + [i]})
         spikes, computed = find_spikes_from_templates(
-            sub_recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs
+            sub_recording, method=method, method_kwargs=local_params, extra_outputs=True, **job_kwargs
         )
         if method == "circus-omp-svd":
-            method_kwargs.update(
+            local_params.update(
                 {
                     "overlaps": computed["overlaps"],
                     "templates": computed["templates"],
@@ -629,7 +634,7 @@ def remove_duplicates_via_matching(
                 }
             )
         elif method == "circus-omp":
-            method_kwargs.update(
+            local_params.update(
                 {
                     "overlaps": computed["overlaps"],
                     "templates": computed["templates"],
@@ -661,7 +666,7 @@ def remove_duplicates_via_matching(
     labels = np.unique(new_labels)
     labels = labels[labels >= 0]
 
-    del recording, sub_recording, method_kwargs
+    del recording, sub_recording, local_params, waveform_extractor
     os.remove(tmp_filename)
 
     return labels, new_labels
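
The core of this fix is that per-template keys are no longer written back into the caller's method_kwargs but into a local copy, and templates are visited from the least to the most active one. A small, self-contained sketch of that copy-then-update pattern, with purely illustrative names and values:

import numpy as np

def match_per_template(method_kwargs, spike_counts):
    # Copy once so repeated calls always start from the caller's original arguments.
    local_params = method_kwargs.copy()
    local_params.update({"amplitudes": [0.95, 1.05], "omp_min_sps": 0.05})

    # Iterate templates sorted by how many spikes they fired (fewest first).
    order = np.argsort(spike_counts)

    ignored = []
    for i in order:
        # Per-iteration state lives only in the local copy.
        local_params.update({"ignored_ids": ignored + [int(i)]})
        # ... template matching on the corresponding recording slice would go here ...
    return method_kwargs, local_params  # the first dict is unchanged

original, used = match_per_template({"max_iter": 10}, np.array([120, 5, 40]))
print(original)             # {'max_iter': 10}
print(used["omp_min_sps"])  # 0.05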
Third changed file (path not shown in this capture)
@@ -199,9 +199,8 @@ def sigmoid(x, L, x0, k, b):
         recording,
         sorting,
         waveform_folder,
-        ms_before=params["ms_before"],
-        ms_after=params["ms_after"],
         **params["job_kwargs"],
+        **params["waveforms"],
         return_scaled=False,
         mode=mode,
     )
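
This dovetails with the first file: ms_before and ms_after now travel inside the waveforms dict (filled from params["general"]), so the extract_waveforms call simply unpacks that dict instead of naming them. A hedged sketch of such a call on toy data; toy_example, the folder name, and the exact key values are assumptions for illustration, mirroring the sorter defaults shown above.

from spikeinterface.core import extract_waveforms
from spikeinterface.extractors import toy_example

# Toy recording/sorting pair; in the clustering code these come from earlier steps.
recording, sorting = toy_example(duration=10, num_units=5, seed=42)

# ms_before / ms_after ride along inside the waveforms dict, as set up in the first file.
waveforms_params = {
    "max_spikes_per_unit": 200,
    "overwrite": True,
    "sparse": True,
    "method": "ptp",
    "threshold": 1,
    "ms_before": 2,
    "ms_after": 2,
}
job_kwargs = {"n_jobs": 1, "chunk_duration": "1s"}

we = extract_waveforms(
    recording,
    sorting,
    "waveforms_tmp",  # illustrative folder name
    **waveforms_params,
    **job_kwargs,
    return_scaled=False,
    mode="folder",
)
print(we)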
Expand Down
