Merge pull request #2768 from yger/circus2_more_improvements
Improvements for circus 2
samuelgarcia authored May 1, 2024
2 parents b6c2c91 + faad05d commit 3c50b34
Showing 13 changed files with 446 additions and 177 deletions.
5 changes: 4 additions & 1 deletion src/spikeinterface/comparison/groundtruthstudy.py
@@ -141,7 +141,10 @@ def scan_folder(self):
comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle")
if comparison_file.exists():
with open(comparison_file, mode="rb") as f:
self.comparisons[key] = pickle.load(f)
try:
self.comparisons[key] = pickle.load(f)
except Exception:
pass

def __repr__(self):
t = f"{self.__class__.__name__} {self.folder.stem} \n"
2 changes: 2 additions & 0 deletions src/spikeinterface/curation/auto_merge.py
@@ -78,6 +78,8 @@ def get_potential_auto_merge(
template_diff_thresh: float, default: 0.25
The threshold on the "template distance metric" for considering a merge.
It needs to be between 0 and 1
template_metric: str, default: 'l1'
The metric to be used when comparing templates. Default is the l1 norm
censored_period_ms: float, default: 0.3
Used to compute the refractory period violations aka "contamination"
refractory_period_ms: float, default: 1
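For orientation, a minimal sketch of a call using the new parameter (hedged: `analyzer` is an assumed, pre-built SortingAnalyzer; the other keyword values are the documented defaults):

    from spikeinterface.curation.auto_merge import get_potential_auto_merge

    # Hedged sketch: `analyzer` is an assumed SortingAnalyzer with the
    # extensions required by get_potential_auto_merge already computed.
    merges = get_potential_auto_merge(
        analyzer,
        template_diff_thresh=0.25,  # threshold on the template distance metric (0 to 1)
        template_metric="l1",       # new parameter documented above
        censored_period_ms=0.3,
        refractory_period_ms=1.0,
    )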
164 changes: 116 additions & 48 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -6,17 +6,20 @@
import shutil
import numpy as np

from spikeinterface.core import NumpySorting, load_extractor, BaseRecording
from spikeinterface.core import NumpySorting
from spikeinterface.core.job_tools import fix_job_kwargs
from spikeinterface.core.recording_tools import get_noise_levels
from spikeinterface.core.template import Templates
from spikeinterface.core.template_tools import get_template_extremum_amplitude
from spikeinterface.core.waveform_tools import estimate_templates
from spikeinterface.preprocessing import common_reference, zscore, whiten, highpass_filter
from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion
from spikeinterface.sortingcomponents.tools import cache_preprocessing
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.sparsity import compute_sparsity
from spikeinterface.sortingcomponents.tools import remove_empty_templates
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
from spikeinterface.curation.auto_merge import get_potential_auto_merge
from spikeinterface.core.analyzer_extension_core import ComputeTemplates

from spikeinterface.sortingcomponents.tools import get_prototype_spike

try:
import hdbscan
@@ -32,16 +35,24 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
_default_params = {
"general": {"ms_before": 2, "ms_after": 2, "radius_um": 100},
"sparsity": {"method": "ptp", "threshold": 0.25},
"filtering": {"freq_min": 150},
"filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2},
"detection": {"peak_sign": "neg", "detect_threshold": 4},
"selection": {
"method": "smart_sampling_amplitudes",
"method": "uniform",
"n_peaks_per_channel": 5000,
"min_n_peaks": 100000,
"select_per_channel": False,
"seed": 42,
},
"clustering": {"legacy": False},
"drift_correction": {"preset": "nonrigid_fast_and_accurate"},
"merging": {
"minimum_spikes": 10,
"corr_diff_thresh": 0.5,
"template_metric": "cosine",
"censor_correlograms_ms": 0.4,
"num_channels": 5,
},
"clustering": {"legacy": True},
"matching": {"method": "circus-omp-svd"},
"apply_preprocessing": True,
"matched_filtering": False,
@@ -65,14 +76,15 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
True, another clustering method called circus will be used, similar to the one used in Spyking Circus 1",
"matching": "A dictionary to specify the matching engine used to recover spikes. The method default is circus-omp-svd, but other engines\
can be used",
"merging": "A dictionary to specify the final merging param to group cells after template matching (get_potential_auto_merge)",
"motion_correction": "A dictionary to be provided if motion correction has to be performed (dense probe only)",
"apply_preprocessing": "Boolean to specify whether circus 2 should preprocess the recording or not. If yes, then high_pass filtering + common\
median reference + zscore",
"shared_memory": "Boolean to specify if the code should, as much as possible, use an internal data structure in memory (faster)",
"cache_preprocessing": "How to cache the preprocessed recording. Mode can be memory, file, zarr, with extra arguments. In case of memory (default), \
memory_limit will control how much RAM can be used. In case of folder or zarr, delete_cache controls if cache is cleaned after sorting",
"multi_units_only": "Boolean to get only multi units activity (i.e. one template per electrode)",
"job_kwargs": "A dictionary to specify how many jobs and which parameters they should used",
"debug": "Boolean to specify if the internal data structure should be kept for debugging",
"debug": "Boolean to specify if internal data structures made during the sorting should be kept for debugging",
}
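A hedged sketch of the cache_preprocessing options described above (keys as documented; the values are illustrative):

    # Hedged sketch: illustrative cache_preprocessing configurations.
    cache_in_memory = {"mode": "memory", "memory_limit": 0.5}  # cap the RAM usage
    cache_on_disk = {"mode": "zarr", "delete_cache": True}     # cache removed after sorting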

sorter_description = """Spyking Circus 2 is a rewriting of Spyking Circus, within the SpikeInterface framework
@@ -93,6 +105,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
from spikeinterface.sortingcomponents.peak_selection import select_peaks
from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks
from spikeinterface.sortingcomponents.matching import find_spikes_from_templates
from spikeinterface.sortingcomponents.tools import remove_empty_templates
from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction
from spikeinterface.sortingcomponents.tools import get_prototype_spike

job_kwargs = params["job_kwargs"]
job_kwargs = fix_job_kwargs(job_kwargs)
@@ -109,60 +124,70 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
## First, we are filtering the data
filtering_params = params["filtering"].copy()
if params["apply_preprocessing"]:
recording_f = highpass_filter(recording, **filtering_params, dtype="float32")
recording_f = bandpass_filter(recording, **filtering_params, dtype="float32")
if num_channels > 1:
recording_f = common_reference(recording_f)
else:
recording_f = recording
recording_f.annotate(is_filtered=True)

recording_f = zscore(recording_f, dtype="float32")
noise_levels = np.ones(recording_f.get_num_channels(), dtype=np.float32)
valid_geometry = check_probe_for_drift_correction(recording_f)
if params["drift_correction"] is not None:
if not valid_geometry:
print("Geometry of the probe does not allow 1D drift correction")
motion_folder = None
else:
print("Motion correction activated (probe geometry compatible)")
motion_folder = sorter_output_folder / "motion"
params["drift_correction"].update({"folder": motion_folder})
recording_f = correct_motion(recording_f, **params["drift_correction"])
else:
motion_folder = None
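For reference, a minimal standalone sketch of the motion-correction call wired in above (hedged: `rec` is an assumed preprocessed recording; the preset name comes from the drift_correction defaults):

    from spikeinterface.preprocessing import correct_motion

    # Hedged sketch: estimate and interpolate motion; `folder` stores the
    # motion info so the merging step can reuse it later.
    rec_corrected = correct_motion(rec, preset="nonrigid_fast_and_accurate", folder="motion")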

## We need to whiten before the template matching step, to boost the results
# TODO: add regularize=True when ready
recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32")

if recording_f.check_serializability("json"):
recording_f.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None)
elif recording_f.check_serializability("pickle"):
recording_f.dump(sorter_output_folder / "preprocessed_recording.pickle", relative_to=None)
noise_levels = get_noise_levels(recording_w, return_scaled=False)

recording_f = cache_preprocessing(recording_f, **job_kwargs, **params["cache_preprocessing"])
if recording_w.check_serializability("json"):
recording_w.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None)
elif recording_w.check_serializability("pickle"):
recording_w.dump(sorter_output_folder / "preprocessed_recording.pickle", relative_to=None)

recording_w = cache_preprocessing(recording_w, **job_kwargs, **params["cache_preprocessing"])
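Taken together, the chain above amounts to the following standalone sketch (hedged: `recording` is an assumed input; parameter values are the defaults used in this file):

    from spikeinterface.preprocessing import bandpass_filter, common_reference, whiten

    # Hedged sketch of the preprocessing chain: filter, reference, then whiten.
    rec = bandpass_filter(recording, freq_min=150, freq_max=7000, ftype="bessel", filter_order=2, dtype="float32")
    rec = common_reference(rec)  # only applied when num_channels > 1
    rec = whiten(rec, mode="local", radius_um=100, dtype="float32")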

## Then, we are detecting peaks with a locally_exclusive method
detection_params = params["detection"].copy()
detection_params.update(job_kwargs)
radius_um = params["general"].get("radius_um", 100)
if "radius_um" not in detection_params:
detection_params["radius_um"] = radius_um

if "exclude_sweep_ms" not in detection_params:
detection_params["exclude_sweep_ms"] = max(ms_before, ms_after)
if "radius_um" not in detection_params:
detection_params["radius_um"] = radius_um
detection_params["radius_um"] = detection_params.get("radius_um", 50)
detection_params["exclude_sweep_ms"] = detection_params.get("exclude_sweep_ms", 0.5)
detection_params["noise_levels"] = noise_levels

fs = recording_f.get_sampling_frequency()
fs = recording_w.get_sampling_frequency()
nbefore = int(ms_before * fs / 1000.0)
nafter = int(ms_after * fs / 1000.0)

peaks = detect_peaks(recording_f, "locally_exclusive", **detection_params)
peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params)

if params["matched_filtering"]:
prototype = get_prototype_spike(recording_f, peaks, ms_before, ms_after, **job_kwargs)
prototype = get_prototype_spike(recording_w, peaks, ms_before, ms_after, **job_kwargs)
detection_params["prototype"] = prototype

matching_job_params = job_kwargs.copy()
for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in matching_job_params:
matching_job_params.pop(value)
if value in detection_params:
detection_params.pop(value)

matching_job_params["chunk_duration"] = "100ms"
detection_params["chunk_duration"] = "100ms"

peaks = detect_peaks(recording_f, "matched_filtering", **detection_params, **matching_job_params)
peaks = detect_peaks(recording_w, "matched_filtering", **detection_params)

if verbose:
print("We found %d peaks in total" % len(peaks))

if params["multi_units_only"]:
sorting = NumpySorting.from_peaks(peaks, sampling_frequency, unit_ids=recording_f.unit_ids)
sorting = NumpySorting.from_peaks(peaks, sampling_frequency, unit_ids=recording_w.unit_ids)
else:
## We subselect a subset of all the peaks, by making the distributions of SNRs over all
## channels as flat as possible
@@ -182,25 +207,21 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
clustering_params["waveforms"] = {}
clustering_params["sparsity"] = params["sparsity"]
clustering_params["radius_um"] = radius_um

for k in ["ms_before", "ms_after"]:
clustering_params["waveforms"][k] = params["general"][k]

clustering_params["waveforms"]["ms_before"] = ms_before
clustering_params["waveforms"]["ms_after"] = ms_after
clustering_params["job_kwargs"] = job_kwargs
clustering_params["noise_levels"] = noise_levels
clustering_params["tmp_folder"] = sorter_output_folder / "clustering"

legacy = clustering_params.get("legacy", False)
legacy = clustering_params.get("legacy", True)

if legacy:
if verbose:
print("We are using the legacy mode for the clustering")
clustering_method = "circus"
else:
clustering_method = "random_projections"

labels, peak_labels = find_cluster_from_peaks(
recording_f, selected_peaks, method=clustering_method, method_kwargs=clustering_params
recording_w, selected_peaks, method=clustering_method, method_kwargs=clustering_params
)

## We get the labels for our peaks
@@ -224,11 +245,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
np.save(clustering_folder / "labels", labels)
np.save(clustering_folder / "peaks", selected_peaks)

nbefore = int(params["general"]["ms_before"] * sampling_frequency / 1000.0)
nafter = int(params["general"]["ms_after"] * sampling_frequency / 1000.0)

recording_w = whiten(recording_f, mode="local", radius_um=100.0)

templates_array = estimate_templates(
recording_w, labeled_peaks, unit_ids, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs
)
@@ -287,17 +303,69 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
if sorting_folder.exists():
shutil.rmtree(sorting_folder)

merging_params = params["merging"].copy()

if len(merging_params) > 0:
if params["drift_correction"] and motion_folder is not None:
from spikeinterface.preprocessing.motion import load_motion_info

motion_info = load_motion_info(motion_folder)
merging_params["maximum_distance_um"] = max(50, 2 * np.abs(motion_info["motion"]).max())

# peak_sign = params['detection'].get('peak_sign', 'neg')
# best_amplitudes = get_template_extremum_amplitude(templates, peak_sign=peak_sign)
# guessed_amplitudes = spikes['amplitude'].copy()
# for ind in unit_ids:
# mask = spikes['cluster_index'] == ind
# guessed_amplitudes[mask] *= best_amplitudes[ind]

if params["debug"]:
curation_folder = sorter_output_folder / "curation"
if curation_folder.exists():
shutil.rmtree(curation_folder)
sorting.save(folder=curation_folder)
# np.save(fitting_folder / "amplitudes", guessed_amplitudes)

sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params)

if verbose:
print(f"Final merging, keeping {len(sorting.unit_ids)} units")

folder_to_delete = None
cache_mode = params["cache_preprocessing"].get("mode", "memory")
delete_cache = params["cache_preprocessing"].get("delete_cache", True)

if cache_mode in ["folder", "zarr"] and delete_cache:
folder_to_delete = recording_f._kwargs["folder_path"]
folder_to_delete = recording_w._kwargs["folder_path"]

del recording_f
del recording_w
if folder_to_delete is not None:
shutil.rmtree(folder_to_delete)

sorting = sorting.save(folder=sorting_folder)

return sorting


def final_cleaning_circus(recording, sorting, templates, **merging_kwargs):

    from spikeinterface.sortingcomponents.clustering.clustering_tools import (
        resolve_merging_graph,
        apply_merges_to_sorting,
    )

    sparsity = templates.sparsity
    templates_array = templates.get_dense_templates().copy()

    sa = create_sorting_analyzer(sorting, recording, format="memory", sparsity=sparsity)

    sa.extensions["templates"] = ComputeTemplates(sa)
    sa.extensions["templates"].params = {"nbefore": templates.nbefore}
    sa.extensions["templates"].data["average"] = templates_array
    sa.compute("unit_locations", method="monopolar_triangulation")
    merges = get_potential_auto_merge(sa, **merging_kwargs)
    merges = resolve_merging_graph(sorting, merges)
    sorting = apply_merges_to_sorting(sorting, merges)
    # sorting = merge_units_sorting(sorting, merges)

    return sorting
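A hedged usage sketch for the helper above, with the merging defaults from _default_params (`recording_w`, `sorting` and `templates` are assumed to exist, as built earlier in _run_from_folder):

    # Hedged sketch: final merging pass, mirroring the call in _run_from_folder.
    clean_sorting = final_cleaning_circus(
        recording_w,
        sorting,
        templates,
        minimum_spikes=10,
        corr_diff_thresh=0.5,
        template_metric="cosine",
        censor_correlograms_ms=0.4,
        num_channels=5,
    )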
3 changes: 1 addition & 2 deletions src/spikeinterface/sorters/internal/tridesclous2.py
@@ -234,8 +234,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
neighbours_mask=neighbours_mask,
waveforms_sparse_mask=sparse_mask,
min_size_split=min_cluster_size,
min_cluster_size=min_cluster_size,
min_samples=50,
clusterer_kwargs={"min_cluster_size": min_cluster_size},
n_pca_features=3,
),
recursive=True,
Additional changed file (path not shown in this view)
@@ -49,7 +49,6 @@ def run(self, **job_kwargs):

def compute_result(self, **result_params):
self.noise = self.result["peak_labels"] < 0

spikes = self.gt_sorting.to_spike_vector()
self.result["sliced_gt_sorting"] = NumpySorting(
spikes[self.indices], self.recording.sampling_frequency, self.gt_sorting.unit_ids
@@ -301,8 +300,8 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs
result = self.get_result(key)
scores = result["gt_comparison"].agreement_scores

# positions = result["gt_comparison"].sorting1.get_property('gt_unit_locations')
positions = self.datasets[key[1]][1].get_property("gt_unit_locations")
positions = result["sliced_gt_sorting"].get_property("gt_unit_locations")
# positions = self.datasets[key[1]][1].get_property("gt_unit_locations")
depth = positions[:, 1]

analyzer = self.get_sorting_analyzer(key)