Merge pull request #2574 from yger/circus2_improvements
Circus2 improvements
samuelgarcia authored Mar 29, 2024
2 parents 630b06b + 2510af4 commit a858a93
Showing 7 changed files with 60 additions and 102 deletions.
24 changes: 8 additions & 16 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -37,11 +37,12 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
"n_peaks_per_channel": 5000,
"min_n_peaks": 100000,
"select_per_channel": False,
"seed": 42,
},
"clustering": {"legacy": False},
"matching": {"method": "circus-omp-svd"},
"apply_preprocessing": True,
"cache_preprocessing": {"mode": None, "memory_limit": 0.5, "delete_cache": True},
"cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
"multi_units_only": False,
"job_kwargs": {"n_jobs": 0.8},
"debug": False,
@@ -122,8 +123,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
## Then, we are detecting peaks with a locally_exclusive method
detection_params = params["detection"].copy()
detection_params.update(job_kwargs)
radius_um = params["general"].get("radius_um", 100)
if "radius_um" not in detection_params:
detection_params["radius_um"] = params["general"]["radius_um"]
detection_params["radius_um"] = radius_um
if "exclude_sweep_ms" not in detection_params:
detection_params["exclude_sweep_ms"] = max(params["general"]["ms_before"], params["general"]["ms_after"])
detection_params["noise_levels"] = noise_levels
@@ -153,6 +155,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
clustering_params = params["clustering"].copy()
clustering_params["waveforms"] = {}
clustering_params["sparsity"] = params["sparsity"]
clustering_params["radius_um"] = radius_um

for k in ["ms_before", "ms_after"]:
clustering_params["waveforms"][k] = params["general"][k]
@@ -161,10 +164,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
clustering_params["noise_levels"] = noise_levels
clustering_params["tmp_folder"] = sorter_output_folder / "clustering"

if "legacy" in clustering_params:
legacy = clustering_params.pop("legacy")
else:
legacy = False
legacy = clustering_params.get("legacy", False)

if legacy:
if verbose:
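The four-line membership test collapses into dict.get, which returns the stored value or the fallback in one expression. One nuance: the old code popped "legacy" out of clustering_params, while .get leaves the key in place. A minimal self-contained check of the idiom:

    # dict.get(key, default) replaces the if/else; unlike .pop it keeps the entry.
    clustering_params = {"legacy": True}
    legacy = clustering_params.get("legacy", False)
    assert legacy is True
    assert clustering_params.get("absent", False) is False
    assert "legacy" in clustering_params  # still present after .get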
@@ -260,16 +260,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
shutil.rmtree(sorting_folder)

folder_to_delete = None

if "mode" in params["cache_preprocessing"]:
cache_mode = params["cache_preprocessing"]["mode"]
else:
cache_mode = "memory"

if "delete_cache" in params["cache_preprocessing"]:
delete_cache = params["cache_preprocessing"]
else:
delete_cache = True
cache_mode = params["cache_preprocessing"].get("mode", "memory")
delete_cache = params["cache_preprocessing"].get("delete_cache", True)

if cache_mode in ["folder", "zarr"] and delete_cache:
folder_to_delete = recording_f._kwargs["folder_path"]
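The same idiom tidies the cache settings and fixes a real bug on the way: the removed branch assigned the whole cache_preprocessing dict to delete_cache rather than the boolean flag. The visible hunk stops before the cleanup itself, so the final rmtree below is an assumption about what follows, not code from this diff:

    import shutil

    # Sketch: resolve cache settings with defaults, then drop any on-disk cache.
    cache_preprocessing = {"mode": "folder", "memory_limit": 0.5, "delete_cache": True}
    cache_mode = cache_preprocessing.get("mode", "memory")
    delete_cache = cache_preprocessing.get("delete_cache", True)

    folder_to_delete = None
    if cache_mode in ["folder", "zarr"] and delete_cache:
        folder_to_delete = "/tmp/sc2_cache"  # recording_f._kwargs["folder_path"] in the sorter

    if folder_to_delete is not None:
        shutil.rmtree(folder_to_delete, ignore_errors=True)  # assumed cleanup step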
31 changes: 9 additions & 22 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -3,7 +3,6 @@
# """Sorting components: clustering"""
from pathlib import Path

import shutil
import numpy as np

try:
@@ -13,16 +12,13 @@
except:
HAVE_HDBSCAN = False

import random, string, os
from spikeinterface.core import get_global_tmp_folder, get_channel_distances
import random, string
from spikeinterface.core import get_global_tmp_folder
from spikeinterface.core.basesorting import minimum_spike_dtype
from sklearn.preprocessing import QuantileTransformer, MaxAbsScaler
from spikeinterface.core.waveform_tools import extract_waveforms_to_buffers, estimate_templates
from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip
from spikeinterface.core import NumpySorting
from spikeinterface.core.waveform_tools import estimate_templates
from .clustering_tools import remove_duplicates_via_matching
from spikeinterface.core.recording_tools import get_noise_levels
from spikeinterface.core.job_tools import fix_job_kwargs
from spikeinterface.core import extract_waveforms
from spikeinterface.sortingcomponents.peak_selection import select_peaks
from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection
from sklearn.decomposition import TruncatedSVD
@@ -32,7 +28,6 @@
import pickle, json
from spikeinterface.core.node_pipeline import (
run_node_pipeline,
ExtractDenseWaveforms,
ExtractSparseWaveforms,
PeakRetriever,
)
@@ -59,7 +54,6 @@ class CircusClustering:
"n_svd": [5, 10],
"ms_before": 0.5,
"ms_after": 0.5,
"random_seed": 42,
"noise_levels": None,
"tmp_folder": None,
"job_kwargs": {},
@@ -72,21 +66,13 @@ def main_function(cls, recording, peaks, params):
job_kwargs = fix_job_kwargs(params["job_kwargs"])

d = params
if "verbose" in job_kwargs:
verbose = job_kwargs["verbose"]
else:
verbose = False

peak_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]
verbose = job_kwargs.get("verbose", False)

fs = recording.get_sampling_frequency()
ms_before = params["ms_before"]
ms_after = params["ms_after"]
nbefore = int(ms_before * fs / 1000.0)
nafter = int(ms_after * fs / 1000.0)
num_samples = nbefore + nafter
num_chans = recording.get_num_channels()
np.random.seed(d["random_seed"])

if params["tmp_folder"] is None:
name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
@@ -122,7 +108,6 @@ def main_function(cls, recording, peaks, params):
json.dump(model_params, f)

# features
features_folder = model_folder / "features"
node0 = PeakRetriever(recording, peaks)

radius_um = params["radius_um"]
@@ -152,7 +137,10 @@ def main_function(cls, recording, peaks, params):
nb_clusters = 0
for c in np.unique(peaks["channel_index"]):
mask = peaks["channel_index"] == c
tsvd = TruncatedSVD(params["n_svd"][1])
if all_pc_data.shape[1] > params["n_svd"][1]:
tsvd = TruncatedSVD(params["n_svd"][1])
else:
tsvd = TruncatedSVD(all_pc_data.shape[1])
sub_data = all_pc_data[mask]
hdbscan_data = tsvd.fit_transform(sub_data.reshape(len(sub_data), -1))
try:
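The new guard caps the requested SVD rank when few features are available: sklearn's TruncatedSVD cannot return more components than the data supplies (its arpack solver even requires strictly fewer). A self-contained sketch of the same defensive pattern, with illustrative shapes:

    import numpy as np
    from sklearn.decomposition import TruncatedSVD

    # Cap the SVD rank by the available feature count, as the guard above does.
    rng = np.random.default_rng(0)
    all_pc_data = rng.standard_normal((500, 5, 4))  # (peaks, svd dims, channels), illustrative
    n_svd_max = 10
    n_components = min(n_svd_max, all_pc_data.shape[1])  # -> 5 here

    sub_data = all_pc_data[:100]
    tsvd = TruncatedSVD(n_components)
    hdbscan_data = tsvd.fit_transform(sub_data.reshape(len(sub_data), -1))
    assert hdbscan_data.shape == (100, n_components)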
@@ -206,7 +194,6 @@ def main_function(cls, recording, peaks, params):
cleaning_matching_params["progress_bar"] = False

cleaning_params = params["cleaning_kwargs"].copy()
cleaning_params["tmp_folder"] = tmp_folder

labels, peak_labels = remove_duplicates_via_matching(
templates, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params
41 changes: 21 additions & 20 deletions src/spikeinterface/sortingcomponents/clustering/clustering_tools.py
@@ -537,7 +537,7 @@ def remove_duplicates(

def remove_duplicates_via_matching(templates, peak_labels, method_kwargs={}, job_kwargs={}, tmp_folder=None):
from spikeinterface.sortingcomponents.matching import find_spikes_from_templates
from spikeinterface.core import BinaryRecordingExtractor
from spikeinterface.core import BinaryRecordingExtractor, NumpyRecording, SharedMemoryRecording
from spikeinterface.core import NumpySorting
from spikeinterface.core import get_global_tmp_folder
import os
@@ -553,25 +553,25 @@ def remove_duplicates_via_matching(templates, peak_labels, method_kwargs={}, job_kwargs={}, tmp_folder=None):
fs = templates.sampling_frequency
num_chans = len(templates.channel_ids)

zdata = templates_array.reshape(nb_templates, -1)

padding = 2 * duration
blanck = np.zeros(padding * num_chans, dtype=np.float32)

if tmp_folder is None:
tmp_folder = get_global_tmp_folder()

tmp_folder.mkdir(parents=True, exist_ok=True)

tmp_filename = tmp_folder / "tmp.raw"

f = open(tmp_filename, "wb")
f.write(blanck)
f.write(zdata.flatten())
f.write(blanck)
f.close()
tmp_filename = None
zdata = templates_array.reshape(nb_templates * duration, num_chans)
blank = np.zeros((2 * duration, num_chans), dtype=zdata.dtype)
zdata = np.vstack((blank, zdata, blank))

if tmp_folder is not None:
tmp_folder.mkdir(parents=True, exist_ok=True)
tmp_filename = tmp_folder / "tmp.raw"
f = open(tmp_filename, "wb")
f.write(zdata.flatten())
f.close()
recording = BinaryRecordingExtractor(
tmp_filename, num_channels=num_chans, sampling_frequency=fs, dtype=zdata.dtype
)
else:
recording = NumpyRecording(zdata, sampling_frequency=fs)
recording = SharedMemoryRecording.from_recording(recording)

recording = BinaryRecordingExtractor(tmp_filename, num_channels=num_chans, sampling_frequency=fs, dtype="float32")
recording = recording.set_probe(templates.probe)
recording.annotate(is_filtered=True)
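This is the core rewrite of the helper: templates are concatenated in time and padded once, and when no tmp_folder is supplied the array is wrapped in a NumpyRecording and promoted to shared memory, so matching workers read it without any disk I/O. A condensed, self-contained sketch of the in-memory branch (shapes and sampling rate are illustrative):

    import numpy as np
    from spikeinterface.core import NumpyRecording, SharedMemoryRecording

    nb_templates, duration, num_chans = 20, 90, 16
    templates_array = np.zeros((nb_templates, duration, num_chans), dtype=np.float32)

    # Stack templates along time and pad both ends with blanks, as above.
    zdata = templates_array.reshape(nb_templates * duration, num_chans)
    blank = np.zeros((2 * duration, num_chans), dtype=zdata.dtype)
    zdata = np.vstack((blank, zdata, blank))

    recording = NumpyRecording(zdata, sampling_frequency=30000.0)
    recording = SharedMemoryRecording.from_recording(recording)
    recording.annotate(is_filtered=True)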

@@ -580,7 +580,7 @@ def remove_duplicates_via_matching(templates, peak_labels, method_kwargs={}, job_kwargs={}, tmp_folder=None):

local_params = method_kwargs.copy()

local_params.update({"templates": templates, "amplitudes": [0.975, 1.025]})
local_params.update({"templates": templates, "amplitudes": [0.95, 1.05]})

ignore_ids = []
similar_templates = [[], []]
@@ -631,7 +631,8 @@ def remove_duplicates_via_matching(templates, peak_labels, method_kwargs={}, job_kwargs={}, tmp_folder=None):
labels = labels[labels >= 0]

del recording, sub_recording, local_params, templates
os.remove(tmp_filename)
if tmp_filename is not None:
os.remove(tmp_filename)

return labels, new_labels

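End to end, the helper now matches templates against this synthetic recording to find pairs that explain each other, and deletes the temporary file only if it created one. A hedged call sketch; templates and peak_labels are assumed outputs of an earlier clustering step, and the kwarg names follow the signature above:

    # Usage sketch for the deduplication helper (inputs assumed from a prior step).
    labels, peak_labels = remove_duplicates_via_matching(
        templates,    # spikeinterface Templates object with a probe attached
        peak_labels,  # one integer cluster label per detected peak
        method_kwargs={},  # forwarded to the template-matching engine
        job_kwargs={"n_jobs": 1, "progress_bar": False},
        tmp_folder=None,  # None -> the in-memory SharedMemoryRecording path
    )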
src/spikeinterface/sortingcomponents/clustering/random_projections.py
@@ -13,14 +13,9 @@
except:
HAVE_HDBSCAN = False

import random, string, os
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core import get_global_tmp_folder, get_channel_distances, get_random_data_chunks
from sklearn.preprocessing import QuantileTransformer, MaxAbsScaler
from spikeinterface.core.waveform_tools import extract_waveforms_to_buffers, estimate_templates
from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip
from spikeinterface.core import NumpySorting
from spikeinterface.core import extract_waveforms
from spikeinterface.core.waveform_tools import estimate_templates
from .clustering_tools import remove_duplicates_via_matching
from spikeinterface.core.recording_tools import get_noise_levels
from spikeinterface.core.job_tools import fix_job_kwargs
from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser
@@ -30,7 +25,6 @@
from spikeinterface.sortingcomponents.tools import remove_empty_templates
from spikeinterface.core.node_pipeline import (
run_node_pipeline,
ExtractDenseWaveforms,
ExtractSparseWaveforms,
PeakRetriever,
)
@@ -47,12 +41,14 @@ class RandomProjectionClustering:
"allow_single_cluster": True,
"core_dist_n_jobs": -1,
"cluster_selection_method": "leaf",
"cluster_selection_epsilon": 2,
},
"cleaning_kwargs": {},
"waveforms": {"ms_before": 2, "ms_after": 2},
"sparsity": {"method": "ptp", "threshold": 0.25},
"radius_um": 100,
"nb_projections": 10,
"feature": "energy",
"ms_before": 0.5,
"ms_after": 0.5,
"random_seed": 42,
@@ -69,25 +65,14 @@ def main_function(cls, recording, peaks, params):
job_kwargs = fix_job_kwargs(params["job_kwargs"])

d = params
if "verbose" in job_kwargs:
verbose = job_kwargs["verbose"]
else:
verbose = False
verbose = job_kwargs.get("verbose", False)

fs = recording.get_sampling_frequency()
radius_um = params["radius_um"]
nbefore = int(params["ms_before"] * fs / 1000.0)
nafter = int(params["ms_after"] * fs / 1000.0)
num_samples = nbefore + nafter
num_chans = recording.get_num_channels()
np.random.seed(d["random_seed"])

if params["tmp_folder"] is None:
name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
tmp_folder = get_global_tmp_folder() / name
else:
tmp_folder = Path(params["tmp_folder"]).absolute()

tmp_folder.mkdir(parents=True, exist_ok=True)
rng = np.random.RandomState(d["random_seed"])

node0 = PeakRetriever(recording, peaks)
node1 = ExtractSparseWaveforms(
@@ -96,30 +81,33 @@ def main_function(cls, recording, peaks, params):
return_output=False,
ms_before=params["ms_before"],
ms_after=params["ms_after"],
radius_um=params["radius_um"],
radius_um=radius_um,
)

node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"])

num_projections = min(num_chans, d["nb_projections"])
projections = np.random.randn(num_chans, num_projections)
projections = rng.randn(num_chans, num_projections)
if num_chans > 1:
projections -= projections.mean(0)
projections /= projections.std(0)
projections -= projections.mean()
projections /= projections.std()

nbefore = int(params["ms_before"] * fs / 1000)
nafter = int(params["ms_after"] * fs / 1000)
nsamples = nbefore + nafter

# noise_ptps = np.linalg.norm(np.random.randn(1000, nsamples), axis=1)
# noise_threshold = np.mean(noise_ptps) + 3 * np.std(noise_ptps)
# if params["feature"] == "ptp":
# noise_values = np.ptp(rng.randn(1000, nsamples), axis=1)
# elif params["feature"] == "energy":
# noise_values = np.linalg.norm(rng.randn(1000, nsamples), axis=1)
# noise_threshold = np.mean(noise_values) + 3 * np.std(noise_values)

node3 = RandomProjectionsFeature(
recording,
parents=[node0, node2],
return_output=True,
feature=params["feature"],
projections=projections,
radius_um=params["radius_um"],
radius_um=radius_um,
noise_threshold=None,
sparse=True,
)
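Two changes in this hunk: the projection matrix now comes from a local RandomState, so results are reproducible without mutating NumPy's global seed, and it is centered and scaled globally rather than per column (skipped entirely for single-channel recordings, where per-column statistics degenerate). A self-contained sketch of the projection feature this sets up, with an illustrative energy stand-in:

    import numpy as np

    rng = np.random.RandomState(42)  # local RNG: global np.random state untouched
    num_chans, nb_projections = 32, 10
    num_projections = min(num_chans, nb_projections)

    projections = rng.randn(num_chans, num_projections)
    if num_chans > 1:
        projections -= projections.mean()  # global centering/scaling, as in the diff
        projections /= projections.std()

    # Illustrative feature step: per-peak channel energies projected onto random axes.
    energies = np.abs(rng.randn(500, num_chans))
    features = energies @ projections  # shape (500, num_projections)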
@@ -130,8 +118,6 @@ def main_function(cls, recording, peaks, params):
recording, pipeline_nodes, job_kwargs=job_kwargs, job_name="extracting features"
)

import sklearn

clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"])
peak_labels = clustering[0]

@@ -175,7 +161,6 @@ def main_function(cls, recording, peaks, params):
cleaning_matching_params["progress_bar"] = False

cleaning_params = params["cleaning_kwargs"].copy()
cleaning_params["tmp_folder"] = tmp_folder

labels, peak_labels = remove_duplicates_via_matching(
templates, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params
4 changes: 1 addition & 3 deletions src/spikeinterface/sortingcomponents/features_from_peaks.py
@@ -204,9 +204,7 @@ def compute(self, traces, peaks, waveforms):
local_map = np.median(features, axis=0) < self.noise_threshold
features[features < local_map] = 0

denom = np.sum(features, axis=1)
mask = denom != 0
all_projections[idx[mask]] = np.dot(features[mask], local_projections) / (denom[mask][:, np.newaxis])
all_projections[idx] = np.dot(features, local_projections)

return all_projections

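The compute step in RandomProjectionsFeature no longer divides by each peak's summed feature mass, so amplitude information survives into the projected space. A small before/after sketch using the diff's names; features stands in for the per-channel values left after the median-based noise masking above:

    import numpy as np

    features = np.abs(np.random.randn(100, 16))  # stand-in for masked per-channel features
    local_projections = np.random.randn(16, 10)

    # Old behaviour: normalize each peak by its total feature mass, skipping empty rows.
    denom = np.sum(features, axis=1)
    mask = denom != 0
    old = np.dot(features[mask], local_projections) / denom[mask][:, np.newaxis]

    # New behaviour: raw projection; relative amplitudes across peaks are preserved.
    new = np.dot(features, local_projections)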
(Diffs for the remaining 2 of the 7 changed files were not loaded.)
