From 8349b90593622af022fa6b80ede0bc021296e5d6 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 14 Sep 2023 21:07:04 +0200 Subject: [PATCH 01/19] implement proof of concept merge_clusters/split_clusters from tridesclous and columbia codes. --- src/spikeinterface/core/globals.py | 2 +- src/spikeinterface/core/job_tools.py | 60 ++- src/spikeinterface/core/node_pipeline.py | 7 +- .../sortingcomponents/clustering/clean.py | 45 ++ .../sortingcomponents/clustering/merge.py | 500 ++++++++++++++++++ .../sortingcomponents/clustering/split.py | 260 +++++++++ .../sortingcomponents/clustering/tools.py | 196 +++++++ .../sortingcomponents/tests/test_split.py | 12 + 8 files changed, 1074 insertions(+), 8 deletions(-) create mode 100644 src/spikeinterface/sortingcomponents/clustering/clean.py create mode 100644 src/spikeinterface/sortingcomponents/clustering/merge.py create mode 100644 src/spikeinterface/sortingcomponents/clustering/split.py create mode 100644 src/spikeinterface/sortingcomponents/clustering/tools.py create mode 100644 src/spikeinterface/sortingcomponents/tests/test_split.py diff --git a/src/spikeinterface/core/globals.py b/src/spikeinterface/core/globals.py index e5581c7a67..d039206296 100644 --- a/src/spikeinterface/core/globals.py +++ b/src/spikeinterface/core/globals.py @@ -96,7 +96,7 @@ def is_set_global_dataset_folder(): ######################################## global global_job_kwargs -global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True) +global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) global global_job_kwargs_set global_job_kwargs_set = False diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index c0ee77d2fd..3e25d64983 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -380,10 +380,6 @@ def run(self): self.gather_func(res) else: n_jobs = min(self.n_jobs, len(all_chunks)) - ######## Do you want to limit the number of threads per process? - ######## It has to be done to speed up numpy a lot if multicores - ######## Otherwise, np.dot will be slow. 
How to do that, up to you - ######## This is just a suggestion, but here it adds a dependency # parallel with ProcessPoolExecutor( @@ -436,3 +432,59 @@ def function_wrapper(args): else: with threadpool_limits(limits=max_threads_per_process): return _func(segment_index, start_frame, end_frame, _worker_ctx) + + +# Here some utils + + +class MockFuture: + """A non-concurrent class for mocking the concurrent.futures API.""" + + def __init__(self, f, *args): + self.f = f + self.args = args + + def result(self): + return self.f(*self.args) + + +class MockPoolExecutor: + """A non-concurrent class for mocking the concurrent.futures API.""" + + def __init__( + self, + max_workers=None, + mp_context=None, + initializer=None, + initargs=None, + context=None, + ): + if initializer is not None: + initializer(*initargs) + self.map = map + self.imap = map + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return + + def submit(self, f, *args): + return MockFuture(f, *args) + + +class MockQueue: + """Another helper class for turning off concurrency when debugging.""" + + def __init__(self): + self.q = [] + self.put = self.q.append + self.get = lambda: self.q.pop(0) + + +def get_poolexecutor(n_jobs): + if n_jobs == 1: + return MockPoolExecutor + else: + return ProcessPoolExecutor diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index b11f40a441..cd858da1e1 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -436,6 +436,7 @@ def run_node_pipeline( job_name="pipeline", mp_context=None, gather_mode="memory", + gather_kwargs={}, squeeze_output=True, folder=None, names=None, @@ -452,7 +453,7 @@ def run_node_pipeline( if gather_mode == "memory": gather_func = GatherToMemory() elif gather_mode == "npy": - gather_func = GatherToNpy(folder, names) + gather_func = GatherToNpy(folder, names, **gather_kwargs) else: raise ValueError(f"wrong gather_mode : {gather_mode}") @@ -597,9 +598,9 @@ class GatherToNpy: * create the npy v1.0 header at the end with the correct shape and dtype """ - def __init__(self, folder, names, npy_header_size=1024): + def __init__(self, folder, names, npy_header_size=1024, exist_ok=False): self.folder = Path(folder) - self.folder.mkdir(parents=True, exist_ok=False) + self.folder.mkdir(parents=True, exist_ok=exist_ok) assert names is not None self.names = names self.npy_header_size = npy_header_size diff --git a/src/spikeinterface/sortingcomponents/clustering/clean.py b/src/spikeinterface/sortingcomponents/clustering/clean.py new file mode 100644 index 0000000000..cbded0c49f --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/clean.py @@ -0,0 +1,45 @@ +import numpy as np + +from .tools import FeaturesLoader, compute_template_from_sparse + +# This is work in progress ... 
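+# For now `clean_clusters` only gathers per-label spike counts and the peak
+# amplitude of each template (following `peak_sign`); the labels are returned
+# unchanged and the actual cleaning rules remain to be written.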
+ + +def clean_clusters( + peaks, + peak_labels, + recording, + features_dict_or_folder, + peak_sign="neg", +): + total_channels = recording.get_num_channels() + + if isinstance(features_dict_or_folder, dict): + features = features_dict_or_folder + else: + features = FeaturesLoader(features_dict_or_folder) + + clean_labels = peak_labels.copy() + + sparse_wfs = features["sparse_wfs"] + sparse_mask = features["sparse_mask"] + + labels_set = np.setdiff1d(peak_labels, [-1]).tolist() + n = len(labels_set) + + count = np.zeros(n, dtype="int64") + for i, label in enumerate(labels_set): + count[i] = np.sum(peak_labels == label) + print(count) + + templates = compute_template_from_sparse(peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels) + + if peak_sign == "both": + max_values = np.max(np.abs(templates), axis=(1, 2)) + elif peak_sign == "neg": + max_values = -np.min(templates, axis=(1, 2)) + elif peak_sign == "pos": + max_values = np.max(templates, axis=(1, 2)) + print(max_values) + + return clean_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py new file mode 100644 index 0000000000..2e839ef0fc --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -0,0 +1,500 @@ +from pathlib import Path +from multiprocessing import get_context +from concurrent.futures import ProcessPoolExecutor +from threadpoolctl import threadpool_limits +from tqdm.auto import tqdm + +import scipy.spatial +from sklearn.decomposition import PCA +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from hdbscan import HDBSCAN + +import numpy as np +import networkx as nx + +from spikeinterface.core.job_tools import get_poolexecutor, fix_job_kwargs + + +from .isocut5 import isocut5 + +from .tools import aggregate_sparse_features, FeaturesLoader, compute_template_from_sparse + + +def merge_clusters( + peaks, + peak_labels, + recording, + features_dict_or_folder, + radius_um=70, + method="waveforms_lda", + method_kwargs={}, + **job_kwargs, +): + """ + Merge cluster using differents methods. + + Parameters + ---------- + peaks: numpy.ndarray 1d + detected peaks (or a subset) + peak_labels: numpy.ndarray 1d + original label before merge + peak_labels.size == peaks.size + recording: Recording object + A recording object + features_dict_or_folder: dict or folder + A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features. + method: str + The method used + method_kwargs: dict + Option for the method. + Returns + ------- + merge_peak_labels: numpy.ndarray 1d + New vectors label after merges. + peak_shifts: numpy.ndarray 1d + A vector of sample shift to be reverse applied on original sample_index on peak detection + Negative shift means too early. + Posituve shift means too late. 
+ So the correction must be applied like this externaly: + final_peaks = peaks.copy() + final_peaks['sample_index'] -= peak_shifts + + """ + + job_kwargs = fix_job_kwargs(job_kwargs) + + features = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) + sparse_wfs = features["sparse_wfs"] + sparse_mask = features["sparse_mask"] + + labels_set, pair_mask, pair_shift, pair_values = find_merge_pairs( + peaks, + peak_labels, + recording, + features_dict_or_folder, + sparse_wfs, + sparse_mask, + radius_um=radius_um, + method=method, + method_kwargs=method_kwargs, + **job_kwargs, + ) + + # merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="partial") + merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full") + + group_shifts = resolve_final_shifts(labels_set, merges, pair_mask, pair_shift) + + # apply final label and shift + merge_peak_labels = peak_labels.copy() + peak_shifts = np.zeros(peak_labels.size, dtype="int64") + for merge, shifts in zip(merges, group_shifts): + label0 = merge[0] + mask = np.in1d(peak_labels, merge) + merge_peak_labels[mask] = label0 + for l, label1 in enumerate(merge): + if l == 0: + # the first label is the reference (shift=0) + continue + peak_shifts[peak_labels == label1] = shifts[l] + + return merge_peak_labels, peak_shifts + + +def resolve_final_shifts(labels_set, merges, pair_mask, pair_shift): + labels_set = list(labels_set) + + group_shifts = [] + for merge in merges: + shifts = np.zeros(len(merge), dtype="int64") + + label_inds = [labels_set.index(label) for label in merge] + + label0 = merge[0] + ind0 = label_inds[0] + + # First find relative shift to label0 (l=0) in the subgraph + local_pair_mask = pair_mask[label_inds, :][:, label_inds] + local_pair_shift = None + G = None + for l, label1 in enumerate(merge): + if l == 0: + # the first label is the reference (shift=0) + continue + ind1 = label_inds[l] + if local_pair_mask[0, l]: + # easy case the pair label0<>label1 was existing + shift = pair_shift[ind0, ind1] + else: + # more complicated case need to find intermediate label and propagate the shift!! + if G is None: + # the the graph only once and only if needed + G = nx.from_numpy_array(local_pair_mask | local_pair_mask.T) + local_pair_shift = pair_shift[label_inds, :][:, label_inds] + local_pair_shift += local_pair_shift.T + + shift_chain = nx.shortest_path(G, source=l, target=0) + shift = 0 + for i in range(len(shift_chain) - 1): + shift += local_pair_shift[shift_chain[i + 1], shift_chain[i]] + shifts[l] = shift + + group_shifts.append(shifts) + + return group_shifts + + +def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full"): + """ + Agglomerate merge pairs into final merge groups. + + The merges are ordered by label. 
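+
+    Parameters
+    ----------
+    labels_set: list or numpy.array
+        The cluster labels.
+    pair_mask: numpy.array (n, n)
+        Boolean matrix of the validated candidate pairs.
+    pair_values: numpy.array (n, n)
+        Merge value of each candidate pair (not used by this function yet).
+    connection_mode: "full" | "partial" | "clique"
+        How a connected component that is not a complete clique is handled:
+        "full" keeps it unmerged, "partial" merges the whole component,
+        "clique" is not implemented yet.
+
+    Returns
+    -------
+    merges: list of numpy.array
+        Each element is a sorted array of labels to merge together.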
+ + """ + + labels_set = np.array(labels_set) + + merges = [] + + graph = nx.from_numpy_matrix(pair_mask | pair_mask.T) + # put real nodes names for debugging + maps = dict(zip(np.arange(labels_set.size), labels_set)) + graph = nx.relabel_nodes(graph, maps) + + groups = list(nx.connected_components(graph)) + for group in groups: + if len(group) == 1: + continue + sub_graph = graph.subgraph(group) + # print(group, sub_graph) + cliques = list(nx.find_cliques(sub_graph)) + if len(cliques) == 1 and len(cliques[0]) == len(group): + # the sub graph is full connected: no ambiguity + # merges.append(labels_set[cliques[0]]) + merges.append(cliques[0]) + elif len(cliques) > 1: + # the subgraph is not fully connected + if connection_mode == "full": + # node merge + pass + elif connection_mode == "partial": + group = list(group) + # merges.append(labels_set[group]) + merges.append(group) + elif connection_mode == "clique": + raise NotImplementedError + else: + raise ValueError + + # DEBUG = True + DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + + fig = plt.figure() + nx.draw_networkx(sub_graph) + plt.show() + + DEBUG = True + # DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + + fig = plt.figure() + nx.draw_networkx(graph) + plt.show() + + # ensure ordered label + merges = [np.sort(merge) for merge in merges] + + return merges + + +def find_merge_pairs( + peaks, + peak_labels, + recording, + features_dict_or_folder, + sparse_wfs, + sparse_mask, + radius_um=70, + method="waveforms_lda", + method_kwargs={}, + **job_kwargs + # n_jobs=1, + # mp_context="fork", + # max_threads_per_process=1, + # progress_bar=True, +): + """ + Searh some possible merge 2 by 2. + """ + job_kwargs = fix_job_kwargs(job_kwargs) + + # features_dict_or_folder = Path(features_dict_or_folder) + + # peaks = features_dict_or_folder['peaks'] + total_channels = recording.get_num_channels() + + # sparse_wfs = features['sparse_wfs'] + + labels_set = np.setdiff1d(peak_labels, [-1]).tolist() + n = len(labels_set) + pair_mask = np.triu(np.ones((n, n), dtype="bool")) & ~np.eye(n, dtype="bool") + pair_shift = np.zeros((n, n), dtype="int64") + pair_values = np.zeros((n, n), dtype="float64") + + # compute template (no shift at this step) + + templates = compute_template_from_sparse( + peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels, peak_shifts=None + ) + + max_chans = np.argmax(np.max(np.abs(templates), axis=1), axis=1) + + channel_locs = recording.get_channel_locations() + template_locs = channel_locs[max_chans, :] + template_dist = scipy.spatial.distance.cdist(template_locs, template_locs, metric="euclidean") + + pair_mask = pair_mask & (template_dist < radius_um) + indices0, indices1 = np.nonzero(pair_mask) + + n_jobs = job_kwargs["n_jobs"] + mp_context = job_kwargs["mp_context"] + max_threads_per_process = job_kwargs["max_threads_per_process"] + progress_bar = job_kwargs["progress_bar"] + + Executor = get_poolexecutor(n_jobs) + + with Executor( + max_workers=n_jobs, + initializer=find_pair_worker_init, + mp_context=get_context(mp_context), + initargs=(recording, features_dict_or_folder, peak_labels, method, method_kwargs, max_threads_per_process), + ) as pool: + jobs = [] + for ind0, ind1 in zip(indices0, indices1): + label0 = labels_set[ind0] + label1 = labels_set[ind1] + jobs.append(pool.submit(find_pair_function_wrapper, label0, label1)) + + if progress_bar: + iterator = tqdm(jobs, desc=f"find_merge_pairs with {method}", total=len(jobs)) + else: + iterator = jobs + + for res in iterator: 
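+            # each job returns the decision for one candidate pair,
+            # write it back into the pairwise matrices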
+ is_merge, label0, label1, shift, merge_value = res.result() + ind0 = labels_set.index(label0) + ind1 = labels_set.index(label1) + + pair_mask[ind0, ind1] = is_merge + if is_merge: + pair_shift[ind0, ind1] = shift + pair_values[ind0, ind1] = merge_value + + pair_mask = pair_mask & (template_dist < radius_um) + indices0, indices1 = np.nonzero(pair_mask) + + return labels_set, pair_mask, pair_shift, pair_values + + +def find_pair_worker_init( + recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process +): + global _ctx + _ctx = {} + + _ctx["recording"] = recording + _ctx["original_labels"] = original_labels + _ctx["method"] = method + _ctx["method_kwargs"] = method_kwargs + _ctx["method_class"] = find_pair_method_dict[method] + _ctx["max_threads_per_process"] = max_threads_per_process + + # if isinstance(features_dict_or_folder, dict): + # _ctx["features"] = features_dict_or_folder + # else: + # _ctx["features"] = FeaturesLoader(features_dict_or_folder) + + _ctx["features"] = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) + + _ctx["peaks"] = _ctx["features"]["peaks"] + + +def find_pair_function_wrapper(label0, label1): + global _ctx + with threadpool_limits(limits=_ctx["max_threads_per_process"]): + is_merge, label0, label1, shift, merge_value = _ctx["method_class"].merge( + label0, label1, _ctx["original_labels"], _ctx["peaks"], _ctx["features"], **_ctx["method_kwargs"] + ) + return is_merge, label0, label1, shift, merge_value + + +class WaveformsLda: + name = "waveforms_lda" + + @staticmethod + def merge( + label0, + label1, + original_labels, + peaks, + features, + waveforms_sparse_mask=None, + feature_name="sparse_tsvd", + projection="centroid", + criteria="diptest", + threshold_diptest=0.5, + threshold_percentile=80.0, + num_shift=2, + ): + if num_shift > 0: + assert feature_name == "sparse_wfs" + sparse_wfs = features[feature_name] + + assert waveforms_sparse_mask is not None + + (inds0,) = np.nonzero(original_labels == label0) + chans0 = np.unique(peaks["channel_index"][inds0]) + target_chans0 = np.flatnonzero(np.all(waveforms_sparse_mask[chans0, :], axis=0)) + + (inds1,) = np.nonzero(original_labels == label1) + chans1 = np.unique(peaks["channel_index"][inds1]) + target_chans1 = np.flatnonzero(np.all(waveforms_sparse_mask[chans1, :], axis=0)) + + if inds0.size <40 or inds1.size <40: + is_merge = False + merge_value = 0 + final_shift = 0 + return is_merge, label0, label1, final_shift, merge_value + + + target_chans = np.intersect1d(target_chans0, target_chans1) + + inds = np.concatenate([inds0, inds1]) + labels = np.zeros(inds.size, dtype="int") + labels[inds0.size :] = 1 + wfs, out = aggregate_sparse_features(peaks, inds, sparse_wfs, waveforms_sparse_mask, target_chans) + wfs = wfs[~out] + labels = labels[~out] + + cut = np.searchsorted(labels, 1) + wfs0_ = wfs[:cut, :, :] + wfs1_ = wfs[cut:, :, :] + + template0_ = np.mean(wfs0_, axis=0) + template1_ = np.mean(wfs1_, axis=0) + num_samples = template0_.shape[0] + + template0 = template0_[num_shift : num_samples - num_shift, :] + + wfs0 = wfs0_[:, num_shift : num_samples - num_shift, :] + + # best shift strategy 1 = max cosine + # values = [] + # for shift in range(num_shift * 2 + 1): + # template1 = template1_[shift : shift + template0.shape[0], :] + # norm = np.linalg.norm(template0.flatten()) * np.linalg.norm(template1.flatten()) + # value = np.sum(template0.flatten() * template1.flatten()) / norm + # values.append(value) + # best_shift = np.argmax(values) + + # best shift 
strategy 2 = min dist**2 + # values = [] + # for shift in range(num_shift * 2 + 1): + # template1 = template1_[shift : shift + template0.shape[0], :] + # value = np.sum((template1 - template0)**2) + # values.append(value) + # best_shift = np.argmin(values) + + # best shift strategy 3 : average delta argmin between channels + channel_shift = np.argmax(np.abs(template1_), axis=0) - np.argmax(np.abs(template0_), axis=0) + mask = np.abs(channel_shift) <= num_shift + channel_shift = channel_shift[mask] + if channel_shift.size > 0: + best_shift = int(np.round(np.mean(channel_shift))) + num_shift + else: + best_shift = num_shift + + wfs1 = wfs1_[:, best_shift : best_shift + template0.shape[0], :] + template1 = template1_[best_shift : best_shift + template0.shape[0], :] + + if projection == "lda": + wfs_0_1 = np.concatenate([wfs0, wfs1], axis=0) + flat_wfs = wfs_0_1.reshape(wfs_0_1.shape[0], -1) + feat = LinearDiscriminantAnalysis(n_components=1).fit_transform(flat_wfs, labels) + feat = feat[:, 0] + feat0 = feat[:cut] + feat1 = feat[cut:] + + elif projection == "centroid": + vector_0_1 = template1 - template0 + vector_0_1 /= np.sum(vector_0_1**2) + feat0 = np.sum((wfs0 - template0[np.newaxis, :, :]) * vector_0_1[np.newaxis, :, :], axis=(1, 2)) + feat1 = np.sum((wfs1 - template0[np.newaxis, :, :]) * vector_0_1[np.newaxis, :, :], axis=(1, 2)) + # feat = np.sum((wfs_0_1 - template0[np.newaxis, :, :]) * vector_0_1[np.newaxis, :, :], axis=(1, 2)) + feat = np.concatenate([feat0, feat1], axis=0) + + else: + raise ValueError(f"bad projection {projection}") + + if criteria == "diptest": + dipscore, cutpoint = isocut5(feat) + is_merge = dipscore < threshold_diptest + merge_value = dipscore + elif criteria == "percentile": + l0 = np.percentile(feat0, threshold_percentile) + l1 = np.percentile(feat1, 100.0 - threshold_percentile) + is_merge = l0 >= l1 + merge_value = l0 - l1 + else: + raise ValueError(f"bad criteria {criteria}") + + if is_merge: + final_shift = best_shift - num_shift + else: + final_shift = 0 + + DEBUG = True + # DEBUG = False + + if DEBUG and is_merge: + # if DEBUG: + import matplotlib.pyplot as plt + + flatten_wfs0 = wfs0.swapaxes(1, 2).reshape(wfs0.shape[0], -1) + flatten_wfs1 = wfs1.swapaxes(1, 2).reshape(wfs1.shape[0], -1) + + fig, axs = plt.subplots(ncols=2) + ax = axs[0] + ax.plot(flatten_wfs0.T, color="C0", alpha=0.01) + ax.plot(flatten_wfs1.T, color="C1", alpha=0.01) + m0 = np.mean(flatten_wfs0, axis=0) + m1 = np.mean(flatten_wfs1, axis=0) + ax.plot(m0, color="C0", alpha=1, lw=4, label=f"{label0} {inds0.size}") + ax.plot(m1, color="C1", alpha=1, lw=4, label=f"{label1} {inds1.size}") + + ax.legend() + + bins = np.linspace(np.percentile(feat, 1), np.percentile(feat, 99), 100) + + count0, _ = np.histogram(feat0, bins=bins) + count1, _ = np.histogram(feat1, bins=bins) + + ax = axs[1] + ax.plot(bins[:-1], count0, color="C0") + ax.plot(bins[:-1], count1, color="C1") + + ax.set_title(f"{dipscore:.4f} {is_merge}") + plt.show() + + + return is_merge, label0, label1, final_shift, merge_value + + +find_pair_method_list = [ + WaveformsLda, +] +find_pair_method_dict = {e.name: e for e in find_pair_method_list} diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py new file mode 100644 index 0000000000..411d8c2116 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -0,0 +1,260 @@ +from multiprocessing import get_context +from threadpoolctl import threadpool_limits +from tqdm.auto import tqdm + 
+from sklearn.decomposition import TruncatedSVD +from hdbscan import HDBSCAN + +import numpy as np + +from spikeinterface.core.job_tools import get_poolexecutor, fix_job_kwargs + +from .tools import aggregate_sparse_features, FeaturesLoader +from .isocut5 import isocut5 + + +def split_clusters( + peak_labels, + recording, + features_dict_or_folder, + method="hdbscan_on_local_pca", + method_kwargs={}, + recursive=False, + recursive_depth=None, + returns_split_count=False, + **job_kwargs, +): + """ + Run recusrsively or not in a multi process pool a local split method. + + Parameters + ---------- + peak_labels: numpy.array + Peak label before split + recording: Recording + Recording object + features_dict_or_folder: dict or folder + A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features. + method: str + The method name + method_kwargs: dict + The method option + recursive: bool Default False + Reccursive or not. + recursive_depth: None or int + If recursive=True, then this is the max split per spikes. + returns_split_count: bool + Optionally return the split count vector. Same size as labels. + + Returns + ------- + new_labels: numpy.ndarray + The labels of peaks after split. + split_count: numpy.ndarray + Optionally returned + """ + + job_kwargs = fix_job_kwargs(job_kwargs) + n_jobs = job_kwargs["n_jobs"] + mp_context = job_kwargs["mp_context"] + progress_bar = job_kwargs["progress_bar"] + max_threads_per_process = job_kwargs["max_threads_per_process"] + + original_labels = peak_labels + peak_labels = peak_labels.copy() + split_count = np.zeros(peak_labels.size, dtype=int) + + Executor = get_poolexecutor(n_jobs) + + with Executor( + max_workers=n_jobs, + initializer=split_worker_init, + mp_context=get_context(mp_context), + initargs=(recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process), + ) as pool: + labels_set = np.setdiff1d(peak_labels, [-1]) + current_max_label = np.max(labels_set) + 1 + + jobs = [] + for label in labels_set: + peak_indices = np.flatnonzero(peak_labels == label) + if peak_indices.size > 0: + jobs.append(pool.submit(split_function_wrapper, peak_indices)) + + if progress_bar: + iterator = tqdm(jobs, desc=f"split_clusters with {method}", total=len(labels_set)) + else: + iterator = jobs + + for res in iterator: + is_split, local_labels, peak_indices = res.result() + if not is_split: + continue + + mask = local_labels >= 0 + peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label + peak_labels[peak_indices[~mask]] = local_labels[~mask] + + split_count[peak_indices] += 1 + + current_max_label += np.max(local_labels[mask]) + 1 + + if recursive: + if recursive_depth is not None: + # stop reccursivity when recursive_depth is reach + extra_ball = np.max(split_count[peak_indices]) < recursive_depth + else: + # reccurssive always + extra_ball = True + + if extra_ball: + new_labels_set = np.setdiff1d(peak_labels[peak_indices], [-1]) + for label in new_labels_set: + peak_indices = np.flatnonzero(peak_labels == label) + if peak_indices.size > 0: + jobs.append(pool.submit(split_function_wrapper, peak_indices)) + if progress_bar: + iterator.total += 1 + + if returns_split_count: + return peak_labels, split_count + else: + return peak_labels + + +global _ctx + + +def split_worker_init( + recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process +): + global _ctx + _ctx = {} + + _ctx["recording"] = recording + features_dict_or_folder + 
_ctx["original_labels"] = original_labels + _ctx["method"] = method + _ctx["method_kwargs"] = method_kwargs + _ctx["method_class"] = split_methods_dict[method] + _ctx["max_threads_per_process"] = max_threads_per_process + _ctx["features"] = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) + _ctx["peaks"] = _ctx["features"]["peaks"] + + +def split_function_wrapper(peak_indices): + global _ctx + with threadpool_limits(limits=_ctx["max_threads_per_process"]): + is_split, local_labels = _ctx["method_class"].split( + peak_indices, _ctx["peaks"], _ctx["features"], **_ctx["method_kwargs"] + ) + return is_split, local_labels, peak_indices + + +class HdbscanOnLocalPca: + # @charlie : this is the equivalent of "herding_split()" in DART + # but simplified, flexible and renamed + + name = "hdbscan_on_local_pca" + + @staticmethod + def split( + peak_indices, + peaks, + features, + clusterer="hdbscan", + feature_name="sparse_tsvd", + neighbours_mask=None, + waveforms_sparse_mask=None, + min_size_split=25, + min_cluster_size=25, + min_samples=25, + n_pca_features=2, + ): + local_labels = np.zeros(peak_indices.size, dtype=np.int64) + + # can be sparse_tsvd or sparse_wfs + sparse_features = features[feature_name] + + assert waveforms_sparse_mask is not None + + # target channel subset is done intersect local channels + neighbours + local_chans = np.unique(peaks["channel_index"][peak_indices]) + target_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0)) + + # TODO fix this a better way, this when cluster have too few overlapping channels + minimum_channels = 2 + if target_channels.size < minimum_channels: + return False, None + + aligned_wfs, dont_have_channels = aggregate_sparse_features( + peaks, peak_indices, sparse_features, waveforms_sparse_mask, target_channels + ) + + local_labels[dont_have_channels] = -2 + kept = np.flatnonzero(~dont_have_channels) + if kept.size < min_size_split: + return False, None + + aligned_wfs = aligned_wfs[kept, :, :] + + flatten_features = aligned_wfs.reshape(aligned_wfs.shape[0], -1) + + # final_features = PCA(n_pca_features, whiten=True).fit_transform(flatten_features) + # final_features = PCA(n_pca_features, whiten=False).fit_transform(flatten_features) + final_features = TruncatedSVD(n_pca_features).fit_transform(flatten_features) + + if clusterer == "hdbscan": + clust = HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples, allow_single_cluster=True) + clust.fit(final_features) + possible_labels = clust.labels_ + elif clusterer == "isocut5": + dipscore, cutpoint = isocut5(final_features[:, 0]) + possible_labels = np.zeros(final_features.shape[0]) + if dipscore > 1.5: + mask = final_features[:, 0] > cutpoint + if np.sum(mask) > min_cluster_size and np.sum(~mask): + possible_labels[mask] = 1 + else: + return False, None + else: + raise ValueError(f"wrong clusterer {clusterer}") + + is_split = np.setdiff1d(possible_labels, [-1]).size > 1 + + # DEBUG = True + DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + + labels_set = np.setdiff1d(possible_labels, [-1]) + colors = plt.get_cmap("tab10", len(labels_set)) + colors = {k: colors(i) for i, k in enumerate(labels_set)} + colors[-1] = "k" + fix, axs = plt.subplots(nrows=2) + + flatten_wfs = aligned_wfs.swapaxes(1, 2).reshape(aligned_wfs.shape[0], -1) + + sl = slice(None, None, 10) + for k in np.unique(possible_labels): + mask = possible_labels == k + ax = axs[0] + ax.scatter(final_features[:, 0][mask][sl], final_features[:, 1][mask][sl], s=5, color=colors[k]) + + ax = 
axs[1] + ax.plot(flatten_wfs[mask][sl].T, color=colors[k], alpha=0.5) + + plt.show() + + if not is_split: + return is_split, None + + local_labels[kept] = possible_labels + + return is_split, local_labels + + +split_methods_list = [ + HdbscanOnLocalPca, +] +split_methods_dict = {e.name: e for e in split_methods_list} diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py new file mode 100644 index 0000000000..9a537ab8a8 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -0,0 +1,196 @@ +from pathlib import Path +from typing import Any +import numpy as np + + +# TODO find a way to attach a a sparse_mask to a given features (waveforms, pca, tsvd ....) + + +class FeaturesLoader: + """ + Feature can be computed in memory or in a folder contaning npy files. + + This class read the folder and behave like a dict of array lazily. + + Parameters + ---------- + feature_folder + + preload + + """ + + def __init__(self, feature_folder, preload=["peaks"]): + self.feature_folder = Path(feature_folder) + + self.file_feature = {} + self.loaded_features = {} + for file in self.feature_folder.glob("*.npy"): + name = file.stem + if name in preload: + self.loaded_features[name] = np.load(file) + else: + self.file_feature[name] = file + + def __getitem__(self, name): + if name in self.loaded_features: + return self.loaded_features[name] + else: + return np.load(self.file_feature[name], mmap_mode="r") + + @staticmethod + def from_dict_or_folder(features_dict_or_folder): + if isinstance(features_dict_or_folder, dict): + return features_dict_or_folder + else: + return FeaturesLoader(features_dict_or_folder) + + +def aggregate_sparse_features(peaks, peak_indices, sparse_feature, sparse_mask, target_channels): + """ + Aggregate sparse features that have unaligned channels and realigned then on target_channels. + + This is usefull to aligned back peaks waveform or pca or tsvd when detected a differents channels. + + + Parameters + ---------- + peaks + + peak_indices + + sparse_feature + + sparse_mask + + target_channels + + Returns + ------- + aligned_features: numpy.array + Aligned features. shape is (local_peaks.size, sparse_feature.shape[1], target_channels.size) + dont_have_channels: numpy.array + Boolean vector to indicate spikes that do not have all target channels to be taken in account + shape is peak_indices.size + """ + local_peaks = peaks[peak_indices] + + aligned_features = np.zeros( + (local_peaks.size, sparse_feature.shape[1], target_channels.size), dtype=sparse_feature.dtype + ) + dont_have_channels = np.zeros(peak_indices.size, dtype=bool) + + for chan in np.unique(local_peaks["channel_index"]): + sparse_chans = np.flatnonzero(sparse_mask[chan, :]) + peak_inds = np.flatnonzero(local_peaks["channel_index"] == chan) + if np.all(np.in1d(target_channels, sparse_chans)): + # peaks feature channel have all target_channels + source_chans = np.flatnonzero(np.in1d(sparse_chans, target_channels)) + aligned_features[peak_inds, :, :] = sparse_feature[peak_indices[peak_inds], :, :][:, :, source_chans] + else: + # some channel are missing, peak are not removde + dont_have_channels[peak_inds] = True + + return aligned_features, dont_have_channels + + +def compute_template_from_sparse( + peaks, labels, labels_set, sparse_waveforms, sparse_mask, total_channels, peak_shifts=None +): + """ + Compute template average from single sparse waveforms buffer. 
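+
+    For each label the sparse waveforms are re-aligned on the channels shared
+    by the whole cluster (see `aggregate_sparse_features`), optionally shifted
+    with `peak_shifts`, and averaged into a dense buffer of shape
+    (len(labels_set), num_samples, total_channels).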
+ + Parameters + ---------- + peaks + + labels + + labels_set + + sparse_waveforms + + sparse_mask + + total_channels + + peak_shifts + + Returns + ------- + templates: numpy.array + Templates shape : (len(labels_set), num_samples, total_channels) + """ + n = len(labels_set) + + templates = np.zeros((n, sparse_waveforms.shape[1], total_channels), dtype=sparse_waveforms.dtype) + + for i, label in enumerate(labels_set): + peak_indices = np.flatnonzero(labels == label) + + local_chans = np.unique(peaks["channel_index"][peak_indices]) + target_channels = np.flatnonzero(np.all(sparse_mask[local_chans, :], axis=0)) + + aligned_wfs, dont_have_channels = aggregate_sparse_features( + peaks, peak_indices, sparse_waveforms, sparse_mask, target_channels + ) + + if peak_shifts is not None: + apply_waveforms_shift(aligned_wfs, peak_shifts[peak_indices], inplace=True) + + templates[i, :, :][:, target_channels] = np.mean(aligned_wfs[~dont_have_channels], axis=0) + + return templates + + +def apply_waveforms_shift(waveforms, peak_shifts, inplace=False): + """ + Apply a shift a spike level to realign waveforms buffers. + + This is usefull to compute template after merge when to cluster are shifted. + + A negative shift need the waveforms to be moved toward the right because the trough was too early. + A positive shift need the waveforms to be moved toward the left because the trough was too late. + + Note the border sample are left as before without move. + + Parameters + ---------- + + waveforms + + peak_shifts + + inplace + + Returns + ------- + aligned_waveforms + + + """ + + print("apply_waveforms_shift") + + if inplace: + aligned_waveforms = waveforms + else: + aligned_waveforms = waveforms.copy() + + shift_set = np.unique(peak_shifts) + assert max(np.abs(shift_set)) < aligned_waveforms.shape[1] + + for shift in shift_set: + if shift == 0: + continue + mask = peak_shifts == shift + wfs = waveforms[mask] + + if shift > 0: + aligned_waveforms[mask, :-shift, :] = wfs[:, shift:, :] + else: + aligned_waveforms[mask, -shift:, :] = wfs[:, :-shift, :] + + print("apply_waveforms_shift DONE") + + return aligned_waveforms diff --git a/src/spikeinterface/sortingcomponents/tests/test_split.py b/src/spikeinterface/sortingcomponents/tests/test_split.py new file mode 100644 index 0000000000..ed5e756469 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/tests/test_split.py @@ -0,0 +1,12 @@ +import pytest +import numpy as np + +from spikeinterface.sortingcomponents.clustering.split import split_clusters + + +def test_split(): + pass + + +if __name__ == "__main__": + test_split() From 9c4bba37b89012c4d016394916a04832df3109c7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 19 Sep 2023 21:00:13 +0200 Subject: [PATCH 02/19] wip --- .../sortingcomponents/clustering/merge.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 2e839ef0fc..e2049d70bf 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -157,7 +157,7 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" merges = [] - graph = nx.from_numpy_matrix(pair_mask | pair_mask.T) + graph = nx.from_numpy_array(pair_mask | pair_mask.T) # put real nodes names for debugging maps = dict(zip(np.arange(labels_set.size), labels_set)) graph = nx.relabel_nodes(graph, maps) @@ -196,8 +196,8 @@ def 
agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" nx.draw_networkx(sub_graph) plt.show() - DEBUG = True - # DEBUG = False + # DEBUG = True + DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -457,8 +457,8 @@ def merge( else: final_shift = 0 - DEBUG = True - # DEBUG = False + # DEBUG = True + DEBUG = False if DEBUG and is_merge: # if DEBUG: @@ -487,7 +487,15 @@ def merge( ax.plot(bins[:-1], count0, color="C0") ax.plot(bins[:-1], count1, color="C1") - ax.set_title(f"{dipscore:.4f} {is_merge}") + if criteria == "diptest": + ax.set_title(f"{dipscore:.4f} {is_merge}") + elif criteria == "percentile": + ax.set_title(f"{l0:.4f} {l1:.4f} {is_merge}") + ax.axvline(l0, color="C0") + ax.axvline(l1, color="C1") + + + plt.show() From 1939b936e94d30c8437633f89c49fd006ca71a80 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 4 Oct 2023 10:19:11 +0200 Subject: [PATCH 03/19] Diff for SC2 --- src/spikeinterface/sorters/internal/spyking_circus2.py | 7 ++++--- .../sortingcomponents/clustering/clustering_tools.py | 7 +++++-- .../sortingcomponents/clustering/random_projections.py | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index a0a4d0823c..db06287f6c 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -6,7 +6,7 @@ from spikeinterface.core import NumpySorting, load_extractor, BaseRecording, get_noise_levels, extract_waveforms from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore +from spikeinterface.preprocessing import common_reference, zscore, whiten, highpass_filter try: import hdbscan @@ -22,7 +22,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, - "filtering": {"dtype": "float32"}, + "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 5}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, "localization": {}, @@ -60,11 +60,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## First, we are filtering the data filtering_params = params["filtering"].copy() if params["apply_preprocessing"]: - recording_f = bandpass_filter(recording, **filtering_params) + recording_f = highpass_filter(recording, **filtering_params) recording_f = common_reference(recording_f) else: recording_f = recording + #recording_f = whiten(recording_f, dtype="float32") recording_f = zscore(recording_f, dtype="float32") ## Then, we are detecting peaks with a locally_exclusive method diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 891c355448..6dba4b7f0f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -598,14 +598,17 @@ def remove_duplicates_via_matching( "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, "amplitudes": [0.95, 1.05], - "omp_min_sps": 0.1, + "omp_min_sps": 0.05, } ) + spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()['unit_index'], 
return_counts=True) + indices = np.argsort(counts) + ignore_ids = [] similar_templates = [[], []] - for i in range(nb_templates): + for i in np.arange(nb_templates)[indices]: t_start = padding + i * duration t_stop = padding + (i + 1) * duration diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 1f97bf5201..d7ceef2561 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -33,7 +33,7 @@ class RandomProjectionClustering: "min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": os.cpu_count(), - "cluster_selection_method": "leaf", + "cluster_selection_method": "leaf" }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, From 7d9c0753fb3c59577dd244d3c9bce1d6272015e6 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 4 Oct 2023 14:11:54 +0200 Subject: [PATCH 04/19] WIP --- src/spikeinterface/preprocessing/remove_artifacts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 7e84822c61..8e72b96c6d 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -1,4 +1,5 @@ import numpy as np +import scipy from spikeinterface.core.core_tools import define_function_from_class From e97005aa5e94328cee3d97097b98d6a7289ee437 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 4 Oct 2023 16:21:54 +0200 Subject: [PATCH 05/19] Patch for scipy --- src/spikeinterface/preprocessing/remove_artifacts.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 8e72b96c6d..1746b23941 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -1,5 +1,4 @@ import numpy as np -import scipy from spikeinterface.core.core_tools import define_function_from_class @@ -108,8 +107,6 @@ def __init__( time_jitter=0, waveforms_kwargs={"allow_unfiltered": True, "mode": "memory"}, ): - import scipy.interpolate - available_modes = ("zeros", "linear", "cubic", "average", "median") num_seg = recording.get_num_segments() @@ -237,7 +234,6 @@ def __init__( time_pad, sparsity, ): - import scipy.interpolate BasePreprocessorSegment.__init__(self, parent_recording_segment) @@ -255,6 +251,8 @@ def __init__( self.sparsity = sparsity def get_traces(self, start_frame, end_frame, channel_indices): + + if self.mode in ["average", "median"]: traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) else: @@ -286,6 +284,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): elif trig + pad[1] >= end_frame - start_frame: traces[trig - pad[0] :, :] = 0 elif self.mode in ["linear", "cubic"]: + import scipy.interpolate for trig in triggers: if pad is None: pre_data_end_idx = trig - 1 From 2a5e37c83054999514ccacd45b3c81d1865bc196 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Oct 2023 14:23:26 +0000 Subject: [PATCH 06/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/remove_artifacts.py | 4 +--- 
1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 1746b23941..1eafa48a0b 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -234,7 +234,6 @@ def __init__( time_pad, sparsity, ): - BasePreprocessorSegment.__init__(self, parent_recording_segment) self.triggers = np.asarray(triggers, dtype="int64") @@ -251,8 +250,6 @@ def __init__( self.sparsity = sparsity def get_traces(self, start_frame, end_frame, channel_indices): - - if self.mode in ["average", "median"]: traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) else: @@ -285,6 +282,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces[trig - pad[0] :, :] = 0 elif self.mode in ["linear", "cubic"]: import scipy.interpolate + for trig in triggers: if pad is None: pre_data_end_idx = trig - 1 From 4cd3747786728e2942bef43b5c9d5ecba8d102fb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 06:25:31 +0000 Subject: [PATCH 07/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/clustering/clustering_tools.py | 2 +- .../sortingcomponents/clustering/random_projections.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index db06287f6c..6cf925e852 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -65,7 +65,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: recording_f = recording - #recording_f = whiten(recording_f, dtype="float32") + # recording_f = whiten(recording_f, dtype="float32") recording_f = zscore(recording_f, dtype="float32") ## Then, we are detecting peaks with a locally_exclusive method diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 6dba4b7f0f..72cfd71791 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -602,7 +602,7 @@ def remove_duplicates_via_matching( } ) - spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()['unit_index'], return_counts=True) + spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()["unit_index"], return_counts=True) indices = np.argsort(counts) ignore_ids = [] diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index d7ceef2561..1f97bf5201 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -33,7 +33,7 @@ class RandomProjectionClustering: "min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": os.cpu_count(), - "cluster_selection_method": "leaf" + "cluster_selection_method": "leaf", }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, From 
22c0eb426507be87790cbcd68427e3d3764721ee Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 5 Oct 2023 08:29:18 +0200 Subject: [PATCH 08/19] Fix bug while reloading --- .../sortingcomponents/clustering/clustering_tools.py | 2 +- .../sortingcomponents/clustering/random_projections.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 6dba4b7f0f..d94345f56b 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -664,7 +664,7 @@ def remove_duplicates_via_matching( labels = np.unique(new_labels) labels = labels[labels >= 0] - del recording, sub_recording, method_kwargs + del recording, sub_recording, method_kwargs, waveform_extractor os.remove(tmp_filename) return labels, new_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index d7ceef2561..4d1dd1f9d5 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -223,6 +223,8 @@ def sigmoid(x, L, x0, k, b): ) del we, sorting + import gc + gc.collect() if params["tmp_folder"] is None: shutil.rmtree(tmp_folder) From f69d7e3dbd013c52564b79c1f6ce5c87a3f67af0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 06:30:11 +0000 Subject: [PATCH 09/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/random_projections.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 7cb882409d..620346a875 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -224,6 +224,7 @@ def sigmoid(x, L, x0, k, b): del we, sorting import gc + gc.collect() if params["tmp_folder"] is None: From 403890ce83b065a76bcc1542a562d1a73e6e04be Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 5 Oct 2023 09:01:02 +0200 Subject: [PATCH 10/19] Found it! 
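
The kwargs given to find_spikes_from_templates are now copied into a local
dict (local_params) before being updated inside the loop, so the caller's
method_kwargs is no longer modified in place across templates.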
--- .../clustering/clustering_tools.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index ce29c47113..734ceff1a3 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -593,7 +593,9 @@ def remove_duplicates_via_matching( chunk_size = duration + 3 * margin - method_kwargs.update( + local_params = method_kwargs.copy() + + local_params.update( { "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, @@ -613,12 +615,12 @@ def remove_duplicates_via_matching( t_stop = padding + (i + 1) * duration sub_recording = recording.frame_slice(t_start - half_marging, t_stop + half_marging) - method_kwargs.update({"ignored_ids": ignore_ids + [i]}) + local_params.update({"ignored_ids": ignore_ids + [i]}) spikes, computed = find_spikes_from_templates( - sub_recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs + sub_recording, method=method, method_kwargs=local_params, extra_outputs=True, **job_kwargs ) if method == "circus-omp-svd": - method_kwargs.update( + local_params.update( { "overlaps": computed["overlaps"], "templates": computed["templates"], @@ -632,7 +634,7 @@ def remove_duplicates_via_matching( } ) elif method == "circus-omp": - method_kwargs.update( + local_params.update( { "overlaps": computed["overlaps"], "templates": computed["templates"], @@ -664,7 +666,7 @@ def remove_duplicates_via_matching( labels = np.unique(new_labels) labels = labels[labels >= 0] - del recording, sub_recording, method_kwargs, waveform_extractor + del recording, sub_recording, local_params, waveform_extractor os.remove(tmp_filename) return labels, new_labels From 6951e856c0794e78108be180d6f16e0fde6af6e2 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 5 Oct 2023 09:54:27 +0200 Subject: [PATCH 11/19] WIP --- .../sortingcomponents/clustering/clustering_tools.py | 2 +- .../sortingcomponents/clustering/random_projections.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 734ceff1a3..b4938717f8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -610,7 +610,7 @@ def remove_duplicates_via_matching( ignore_ids = [] similar_templates = [[], []] - for i in np.arange(nb_templates)[indices]: + for i in indices: t_start = padding + i * duration t_stop = padding + (i + 1) * duration diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 620346a875..1f97bf5201 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -223,9 +223,6 @@ def sigmoid(x, L, x0, k, b): ) del we, sorting - import gc - - gc.collect() if params["tmp_folder"] is None: shutil.rmtree(tmp_folder) From fdebd12b09654796a177f4ab91b8e614409f5ac7 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 5 Oct 2023 10:43:20 +0200 Subject: [PATCH 12/19] Sparse waveforms were not handled --- src/spikeinterface/sorters/internal/spyking_circus2.py | 4 ++-- 
.../sortingcomponents/clustering/random_projections.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 6cf925e852..0c3b9f95d1 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -99,10 +99,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We launch a clustering (using hdbscan) relying on positions and features extracted on ## the fly from the snippets clustering_params = params["clustering"].copy() - clustering_params["waveforms_kwargs"] = params["waveforms"] + clustering_params["waveforms"] = params["waveforms"].copy() for k in ["ms_before", "ms_after"]: - clustering_params["waveforms_kwargs"][k] = params["general"][k] + clustering_params["waveforms"][k] = params["general"][k] clustering_params.update(dict(shared_memory=params["shared_memory"])) clustering_params["job_kwargs"] = job_kwargs diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 1f97bf5201..ffb868f682 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -199,9 +199,8 @@ def sigmoid(x, L, x0, k, b): recording, sorting, waveform_folder, - ms_before=params["ms_before"], - ms_after=params["ms_after"], **params["job_kwargs"], + **params['waveforms'], return_scaled=False, mode=mode, ) From b6f9235a7cf9c2ad106ec0e4cb6be365a243d2af Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 08:44:20 +0000 Subject: [PATCH 13/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/random_projections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index ffb868f682..a81458d7a8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -200,7 +200,7 @@ def sigmoid(x, L, x0, k, b): sorting, waveform_folder, **params["job_kwargs"], - **params['waveforms'], + **params["waveforms"], return_scaled=False, mode=mode, ) From f4f3fb4199a59add1882b26e0925e08c00d1fed3 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 5 Oct 2023 11:34:41 +0200 Subject: [PATCH 14/19] wip --- .../sortingcomponents/clustering/merge.py | 87 ++++++++++++++++--- .../sortingcomponents/clustering/split.py | 13 ++- 2 files changed, 83 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index e2049d70bf..1dd9f9fc37 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -80,9 +80,45 @@ def merge_clusters( method_kwargs=method_kwargs, **job_kwargs, ) + + + DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + fig, ax = plt.subplots() + ax.matshow(pair_values) + + pair_values[~pair_mask] = 20 + + import hdbscan + fig, ax = plt.subplots() + clusterer = hdbscan.HDBSCAN(metric='precomputed', min_cluster_size=2, 
allow_single_cluster=True) + clusterer.fit(pair_values) + print(clusterer.labels_) + clusterer.single_linkage_tree_.plot(cmap='viridis', colorbar=True) + #~ fig, ax = plt.subplots() + #~ clusterer.minimum_spanning_tree_.plot(edge_cmap='viridis', + #~ edge_alpha=0.6, + #~ node_size=80, + #~ edge_linewidth=2) + + graph = clusterer.single_linkage_tree_.to_networkx() + + import scipy.cluster + fig, ax = plt.subplots() + scipy.cluster.hierarchy.dendrogram(clusterer.single_linkage_tree_.to_numpy(), ax=ax) + + import networkx as nx + fig = plt.figure() + nx.draw_networkx(graph) + plt.show() - # merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="partial") - merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full") + plt.show() + + + + merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="partial") + # merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full") group_shifts = resolve_final_shifts(labels_set, merges, pair_mask, pair_shift) @@ -187,7 +223,7 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" else: raise ValueError - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -196,7 +232,7 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" nx.draw_networkx(sub_graph) plt.show() - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -348,6 +384,7 @@ def merge( criteria="diptest", threshold_diptest=0.5, threshold_percentile=80.0, + threshold_overlap=0.4, num_shift=2, ): if num_shift > 0: @@ -449,6 +486,23 @@ def merge( l1 = np.percentile(feat1, 100.0 - threshold_percentile) is_merge = l0 >= l1 merge_value = l0 - l1 + elif criteria == "distrib_overlap": + lim0 = min(np.min(feat0), np.min(feat1)) + lim1 = max(np.max(feat0), np.max(feat1)) + bin_size = (lim1 - lim0) / 200. 
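+            # the two projected distributions are binned on a common 200-bin
+            # histogram; their overlap (1 = identical, 0 = fully separated)
+            # is then compared to threshold_overlap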
+ bins = np.arange(lim0, lim1, bin_size) + + pdf0, _ = np.histogram(feat0, bins=bins, density=True) + pdf1, _ = np.histogram(feat1, bins=bins, density=True) + pdf0 *= bin_size + pdf1 *= bin_size + overlap = np.sum(np.minimum(pdf0, pdf1)) + + is_merge = overlap >= threshold_overlap + + merge_value = 1 - overlap + + else: raise ValueError(f"bad criteria {criteria}") @@ -457,11 +511,13 @@ def merge( else: final_shift = 0 - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG and is_merge: - # if DEBUG: + # if DEBUG and not is_merge: + # if DEBUG and (overlap > 0.05 and overlap <0.25): + # if label0 == 49 and label1== 65: import matplotlib.pyplot as plt flatten_wfs0 = wfs0.swapaxes(1, 2).reshape(wfs0.shape[0], -1) @@ -479,13 +535,16 @@ def merge( ax.legend() bins = np.linspace(np.percentile(feat, 1), np.percentile(feat, 99), 100) - - count0, _ = np.histogram(feat0, bins=bins) - count1, _ = np.histogram(feat1, bins=bins) + bin_size = bins[1] - bins[0] + count0, _ = np.histogram(feat0, bins=bins, density=True) + count1, _ = np.histogram(feat1, bins=bins, density=True) + pdf0 = count0 * bin_size + pdf1 = count1 * bin_size + ax = axs[1] - ax.plot(bins[:-1], count0, color="C0") - ax.plot(bins[:-1], count1, color="C1") + ax.plot(bins[:-1], pdf0, color="C0") + ax.plot(bins[:-1], pdf1, color="C1") if criteria == "diptest": ax.set_title(f"{dipscore:.4f} {is_merge}") @@ -493,9 +552,11 @@ def merge( ax.set_title(f"{l0:.4f} {l1:.4f} {is_merge}") ax.axvline(l0, color="C0") ax.axvline(l1, color="C1") + elif criteria == "distrib_overlap": + print(lim0, lim1, ) + ax.set_title(f"{overlap:.4f} {is_merge}") + ax.plot(bins[:-1], np.minimum(pdf0, pdf1), ls='--', color='k') - - plt.show() diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index 411d8c2116..d3e630a165 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -205,9 +205,11 @@ def split( final_features = TruncatedSVD(n_pca_features).fit_transform(flatten_features) if clusterer == "hdbscan": - clust = HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples, allow_single_cluster=True) + clust = HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples, allow_single_cluster=True, + cluster_selection_method="leaf") clust.fit(final_features) possible_labels = clust.labels_ + is_split = np.setdiff1d(possible_labels, [-1]).size > 1 elif clusterer == "isocut5": dipscore, cutpoint = isocut5(final_features[:, 0]) possible_labels = np.zeros(final_features.shape[0]) @@ -215,14 +217,15 @@ def split( mask = final_features[:, 0] > cutpoint if np.sum(mask) > min_cluster_size and np.sum(~mask): possible_labels[mask] = 1 + is_split = np.setdiff1d(possible_labels, [-1]).size > 1 else: - return False, None + is_split = False else: raise ValueError(f"wrong clusterer {clusterer}") - is_split = np.setdiff1d(possible_labels, [-1]).size > 1 + - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -243,6 +246,8 @@ def split( ax = axs[1] ax.plot(flatten_wfs[mask][sl].T, color=colors[k], alpha=0.5) + + axs[0].set_title(f"{clusterer} {is_split}") plt.show() From bef9c4ab9d5eeea9331bfbab5076da23ef5f61cc Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 5 Oct 2023 16:09:48 +0200 Subject: [PATCH 15/19] change split merge naming --- .../sortingcomponents/clustering/merge.py | 17 +++++++++-- .../sortingcomponents/clustering/split.py | 29 ++++++++++++++----- 
.../sortingcomponents/tests/test_merge.py | 13 +++++++++ .../sortingcomponents/tests/test_split.py | 1 + 4 files changed, 49 insertions(+), 11 deletions(-) create mode 100644 src/spikeinterface/sortingcomponents/tests/test_merge.py diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 1dd9f9fc37..5539ec1051 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -368,8 +368,19 @@ def find_pair_function_wrapper(label0, label1): return is_merge, label0, label1, shift, merge_value -class WaveformsLda: - name = "waveforms_lda" +class ProjectDistribution: + """ + This method is a refactorized mix between: + * old tridesclous code + * some ideas by Charlie Windolf in spikespvae + + The idea is : + * project the waveform (or features) samples on a 1d axis (using LDA for instance). + * check that it is the same or not distribution (diptest, distrib_overlap, ...) + + + """ + name = "project_distribution" @staticmethod def merge( @@ -564,6 +575,6 @@ def merge( find_pair_method_list = [ - WaveformsLda, + ProjectDistribution, ] find_pair_method_dict = {e.name: e for e in find_pair_method_list} diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index d3e630a165..dc649cec97 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -13,6 +13,9 @@ from .isocut5 import isocut5 +# important all DEBUG and matplotlib are left in the code intentionally + + def split_clusters( peak_labels, recording, @@ -25,7 +28,7 @@ def split_clusters( **job_kwargs, ): """ - Run recusrsively or not in a multi process pool a local split method. + Run recusrsively (or not) in a multi process pool a local split method. Parameters ---------- @@ -151,11 +154,20 @@ def split_function_wrapper(peak_indices): return is_split, local_labels, peak_indices -class HdbscanOnLocalPca: - # @charlie : this is the equivalent of "herding_split()" in DART - # but simplified, flexible and renamed - name = "hdbscan_on_local_pca" +class LocalFeatureClustering: + """ + This method is a refactorized mix between: + * old tridesclous code + * "herding_split()" in DART/spikepsvae by Charlie Windolf + + The idea simple : + * agregate features (svd or even waveforms) with sparse channel. 
+ * run a local feature reduction (pca or svd) + * try a new split (hdscan or isocut5) + """ + + name = "local_feature_clustering" @staticmethod def split( @@ -170,6 +182,8 @@ def split( min_cluster_size=25, min_samples=25, n_pca_features=2, + minimum_common_channels=2, + ): local_labels = np.zeros(peak_indices.size, dtype=np.int64) @@ -183,8 +197,7 @@ def split( target_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0)) # TODO fix this a better way, this when cluster have too few overlapping channels - minimum_channels = 2 - if target_channels.size < minimum_channels: + if target_channels.size < minimum_common_channels: return False, None aligned_wfs, dont_have_channels = aggregate_sparse_features( @@ -260,6 +273,6 @@ def split( split_methods_list = [ - HdbscanOnLocalPca, + LocalFeatureClustering, ] split_methods_dict = {e.name: e for e in split_methods_list} diff --git a/src/spikeinterface/sortingcomponents/tests/test_merge.py b/src/spikeinterface/sortingcomponents/tests/test_merge.py new file mode 100644 index 0000000000..b7a669a263 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/tests/test_merge.py @@ -0,0 +1,13 @@ +import pytest +import numpy as np + +from spikeinterface.sortingcomponents.clustering.split import split_clusters + +# no proper test at the moment this is used in tridesclous2 + +def test_merge(): + pass + + +if __name__ == "__main__": + test_merge() diff --git a/src/spikeinterface/sortingcomponents/tests/test_split.py b/src/spikeinterface/sortingcomponents/tests/test_split.py index ed5e756469..ca5e5b57e7 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_split.py +++ b/src/spikeinterface/sortingcomponents/tests/test_split.py @@ -3,6 +3,7 @@ from spikeinterface.sortingcomponents.clustering.split import split_clusters +# no proper test at the moment this is used in tridesclous2 def test_split(): pass From df35c6a2ba3458597e2ec3c47673cbee9e4b7182 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 5 Oct 2023 16:22:42 +0200 Subject: [PATCH 16/19] small fixes in tests --- src/spikeinterface/core/job_tools.py | 3 +-- src/spikeinterface/core/tests/test_globals.py | 6 +++--- src/spikeinterface/core/tests/test_waveform_extractor.py | 6 ++++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index a13e1dd527..e42f7bb8b4 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -434,8 +434,7 @@ def function_wrapper(args): return _func(segment_index, start_frame, end_frame, _worker_ctx) -# Here some utils - +# Here some utils copy/paste from DART (Charlie Windolf) class MockFuture: """A non-concurrent class for mocking the concurrent.futures API.""" diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py index 8216a4aae6..2c0792c152 100644 --- a/src/spikeinterface/core/tests/test_globals.py +++ b/src/spikeinterface/core/tests/test_globals.py @@ -37,16 +37,16 @@ def test_global_tmp_folder(): def test_global_job_kwargs(): - job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True) + job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) global_job_kwargs = get_global_job_kwargs() - assert global_job_kwargs == dict(n_jobs=1, chunk_duration="1s", progress_bar=True) + assert global_job_kwargs == dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) 
set_global_job_kwargs(**job_kwargs) assert get_global_job_kwargs() == job_kwargs # test updating only one field partial_job_kwargs = dict(n_jobs=2) set_global_job_kwargs(**partial_job_kwargs) global_job_kwargs = get_global_job_kwargs() - assert global_job_kwargs == dict(n_jobs=2, chunk_duration="1s", progress_bar=True) + assert global_job_kwargs == dict(n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) # test that fix_job_kwargs grabs global kwargs new_job_kwargs = dict(n_jobs=10) job_kwargs_split = fix_job_kwargs(new_job_kwargs) diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index 2bbf5e9b0f..de6c3d752a 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -346,6 +346,8 @@ def test_recordingless(): # delete original recording and rely on rec_attributes if platform.system() != "Windows": + # this avoid reference on the folder + del we, recording shutil.rmtree(cache_folder / "recording1") we_loaded = WaveformExtractor.load(wf_folder, with_recording=False) assert not we_loaded.has_recording() @@ -554,6 +556,6 @@ def test_non_json_object(): # test_WaveformExtractor() # test_extract_waveforms() # test_portability() - # test_recordingless() + test_recordingless() # test_compute_sparsity() - test_non_json_object() + # test_non_json_object() From 94bfb70f528603ecf22d7b499228146792fb33b9 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 5 Oct 2023 16:23:56 +0200 Subject: [PATCH 17/19] in1d to isin --- src/spikeinterface/sortingcomponents/clustering/tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 9a537ab8a8..8e25c9cb7f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -83,7 +83,7 @@ def aggregate_sparse_features(peaks, peak_indices, sparse_feature, sparse_mask, for chan in np.unique(local_peaks["channel_index"]): sparse_chans = np.flatnonzero(sparse_mask[chan, :]) peak_inds = np.flatnonzero(local_peaks["channel_index"] == chan) - if np.all(np.in1d(target_channels, sparse_chans)): + if np.all(np.isin(target_channels, sparse_chans)): # peaks feature channel have all target_channels source_chans = np.flatnonzero(np.in1d(sparse_chans, target_channels)) aligned_features[peak_inds, :, :] = sparse_feature[peak_indices[peak_inds], :, :][:, :, source_chans] From 48da4ea5f429eac411a331a39d9b468428b70897 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 5 Oct 2023 16:45:52 +0200 Subject: [PATCH 18/19] wip --- src/spikeinterface/sortingcomponents/clustering/tools.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 8e25c9cb7f..c334daebe3 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -94,6 +94,7 @@ def aggregate_sparse_features(peaks, peak_indices, sparse_feature, sparse_mask, return aligned_features, dont_have_channels + def compute_template_from_sparse( peaks, labels, labels_set, sparse_waveforms, sparse_mask, total_channels, peak_shifts=None ): From de2d642d5f833b5c3f68df5150311b1ed5eddca8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 14:41:53 +0000 Subject: [PATCH 19/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/job_tools.py | 1 + src/spikeinterface/core/tests/test_globals.py | 8 +- .../core/tests/test_waveform_extractor.py | 2 +- .../sortingcomponents/clustering/merge.py | 77 ++++++++++--------- .../sortingcomponents/clustering/split.py | 18 ++--- .../sortingcomponents/clustering/tools.py | 1 - .../sortingcomponents/tests/test_merge.py | 1 + .../sortingcomponents/tests/test_split.py | 1 + 8 files changed, 58 insertions(+), 51 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index e42f7bb8b4..cf7a67489c 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -436,6 +436,7 @@ def function_wrapper(args): # Here some utils copy/paste from DART (Charlie Windolf) + class MockFuture: """A non-concurrent class for mocking the concurrent.futures API.""" diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py index 2c0792c152..d0672405d6 100644 --- a/src/spikeinterface/core/tests/test_globals.py +++ b/src/spikeinterface/core/tests/test_globals.py @@ -39,14 +39,18 @@ def test_global_tmp_folder(): def test_global_job_kwargs(): job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) global_job_kwargs = get_global_job_kwargs() - assert global_job_kwargs == dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) + assert global_job_kwargs == dict( + n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + ) set_global_job_kwargs(**job_kwargs) assert get_global_job_kwargs() == job_kwargs # test updating only one field partial_job_kwargs = dict(n_jobs=2) set_global_job_kwargs(**partial_job_kwargs) global_job_kwargs = get_global_job_kwargs() - assert global_job_kwargs == dict(n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) + assert global_job_kwargs == dict( + n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + ) # test that fix_job_kwargs grabs global kwargs new_job_kwargs = dict(n_jobs=10) job_kwargs_split = fix_job_kwargs(new_job_kwargs) diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index de6c3d752a..b56180a9e9 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -558,4 +558,4 @@ def test_non_json_object(): # test_portability() test_recordingless() # test_compute_sparsity() - # test_non_json_object() + # test_non_json_object() diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 5539ec1051..d892d0723a 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -80,45 +80,46 @@ def merge_clusters( method_kwargs=method_kwargs, **job_kwargs, ) - - + DEBUG = False if DEBUG: import matplotlib.pyplot as plt + fig, ax = plt.subplots() ax.matshow(pair_values) - - pair_values[~pair_mask] = 20 - + + pair_values[~pair_mask] = 20 + import hdbscan + fig, ax = plt.subplots() - clusterer = 
hdbscan.HDBSCAN(metric='precomputed', min_cluster_size=2, allow_single_cluster=True) + clusterer = hdbscan.HDBSCAN(metric="precomputed", min_cluster_size=2, allow_single_cluster=True) clusterer.fit(pair_values) print(clusterer.labels_) - clusterer.single_linkage_tree_.plot(cmap='viridis', colorbar=True) - #~ fig, ax = plt.subplots() - #~ clusterer.minimum_spanning_tree_.plot(edge_cmap='viridis', - #~ edge_alpha=0.6, - #~ node_size=80, - #~ edge_linewidth=2) - + clusterer.single_linkage_tree_.plot(cmap="viridis", colorbar=True) + # ~ fig, ax = plt.subplots() + # ~ clusterer.minimum_spanning_tree_.plot(edge_cmap='viridis', + # ~ edge_alpha=0.6, + # ~ node_size=80, + # ~ edge_linewidth=2) + graph = clusterer.single_linkage_tree_.to_networkx() import scipy.cluster + fig, ax = plt.subplots() scipy.cluster.hierarchy.dendrogram(clusterer.single_linkage_tree_.to_numpy(), ax=ax) - + import networkx as nx + fig = plt.figure() nx.draw_networkx(graph) plt.show() plt.show() - - - + merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="partial") - # merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full") + # merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full") group_shifts = resolve_final_shifts(labels_set, merges, pair_mask, pair_shift) @@ -223,7 +224,7 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" else: raise ValueError - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -232,7 +233,7 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" nx.draw_networkx(sub_graph) plt.show() - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -377,9 +378,10 @@ class ProjectDistribution: The idea is : * project the waveform (or features) samples on a 1d axis (using LDA for instance). * check that it is the same or not distribution (diptest, distrib_overlap, ...) - + """ + name = "project_distribution" @staticmethod @@ -412,13 +414,12 @@ def merge( chans1 = np.unique(peaks["channel_index"][inds1]) target_chans1 = np.flatnonzero(np.all(waveforms_sparse_mask[chans1, :], axis=0)) - if inds0.size <40 or inds1.size <40: + if inds0.size < 40 or inds1.size < 40: is_merge = False merge_value = 0 final_shift = 0 return is_merge, label0, label1, final_shift, merge_value - target_chans = np.intersect1d(target_chans0, target_chans1) inds = np.concatenate([inds0, inds1]) @@ -500,20 +501,19 @@ def merge( elif criteria == "distrib_overlap": lim0 = min(np.min(feat0), np.min(feat1)) lim1 = max(np.max(feat0), np.max(feat1)) - bin_size = (lim1 - lim0) / 200. 
+ bin_size = (lim1 - lim0) / 200.0 bins = np.arange(lim0, lim1, bin_size) - + pdf0, _ = np.histogram(feat0, bins=bins, density=True) pdf1, _ = np.histogram(feat1, bins=bins, density=True) pdf0 *= bin_size - pdf1 *= bin_size + pdf1 *= bin_size overlap = np.sum(np.minimum(pdf0, pdf1)) - + is_merge = overlap >= threshold_overlap - + merge_value = 1 - overlap - - + else: raise ValueError(f"bad criteria {criteria}") @@ -522,13 +522,13 @@ def merge( else: final_shift = 0 - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG and is_merge: - # if DEBUG and not is_merge: - # if DEBUG and (overlap > 0.05 and overlap <0.25): - # if label0 == 49 and label1== 65: + # if DEBUG and not is_merge: + # if DEBUG and (overlap > 0.05 and overlap <0.25): + # if label0 == 49 and label1== 65: import matplotlib.pyplot as plt flatten_wfs0 = wfs0.swapaxes(1, 2).reshape(wfs0.shape[0], -1) @@ -551,7 +551,6 @@ def merge( count1, _ = np.histogram(feat1, bins=bins, density=True) pdf0 = count0 * bin_size pdf1 = count1 * bin_size - ax = axs[1] ax.plot(bins[:-1], pdf0, color="C0") @@ -564,13 +563,15 @@ def merge( ax.axvline(l0, color="C0") ax.axvline(l1, color="C1") elif criteria == "distrib_overlap": - print(lim0, lim1, ) + print( + lim0, + lim1, + ) ax.set_title(f"{overlap:.4f} {is_merge}") - ax.plot(bins[:-1], np.minimum(pdf0, pdf1), ls='--', color='k') + ax.plot(bins[:-1], np.minimum(pdf0, pdf1), ls="--", color="k") plt.show() - return is_merge, label0, label1, final_shift, merge_value diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index dc649cec97..9836e9110f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -154,13 +154,12 @@ def split_function_wrapper(peak_indices): return is_split, local_labels, peak_indices - class LocalFeatureClustering: """ This method is a refactorized mix between: * old tridesclous code * "herding_split()" in DART/spikepsvae by Charlie Windolf - + The idea simple : * agregate features (svd or even waveforms) with sparse channel. 
* run a local feature reduction (pca or svd) @@ -183,7 +182,6 @@ def split( min_samples=25, n_pca_features=2, minimum_common_channels=2, - ): local_labels = np.zeros(peak_indices.size, dtype=np.int64) @@ -218,8 +216,12 @@ def split( final_features = TruncatedSVD(n_pca_features).fit_transform(flatten_features) if clusterer == "hdbscan": - clust = HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples, allow_single_cluster=True, - cluster_selection_method="leaf") + clust = HDBSCAN( + min_cluster_size=min_cluster_size, + min_samples=min_samples, + allow_single_cluster=True, + cluster_selection_method="leaf", + ) clust.fit(final_features) possible_labels = clust.labels_ is_split = np.setdiff1d(possible_labels, [-1]).size > 1 @@ -236,9 +238,7 @@ def split( else: raise ValueError(f"wrong clusterer {clusterer}") - - - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -259,7 +259,7 @@ def split( ax = axs[1] ax.plot(flatten_wfs[mask][sl].T, color=colors[k], alpha=0.5) - + axs[0].set_title(f"{clusterer} {is_split}") plt.show() diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index c334daebe3..8e25c9cb7f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -94,7 +94,6 @@ def aggregate_sparse_features(peaks, peak_indices, sparse_feature, sparse_mask, return aligned_features, dont_have_channels - def compute_template_from_sparse( peaks, labels, labels_set, sparse_waveforms, sparse_mask, total_channels, peak_shifts=None ): diff --git a/src/spikeinterface/sortingcomponents/tests/test_merge.py b/src/spikeinterface/sortingcomponents/tests/test_merge.py index b7a669a263..6b3ea2a901 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_merge.py +++ b/src/spikeinterface/sortingcomponents/tests/test_merge.py @@ -5,6 +5,7 @@ # no proper test at the moment this is used in tridesclous2 + def test_merge(): pass diff --git a/src/spikeinterface/sortingcomponents/tests/test_split.py b/src/spikeinterface/sortingcomponents/tests/test_split.py index ca5e5b57e7..5953f74e24 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_split.py +++ b/src/spikeinterface/sortingcomponents/tests/test_split.py @@ -5,6 +5,7 @@ # no proper test at the moment this is used in tridesclous2 + def test_split(): pass
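
The new `distrib_overlap` criterion added to the merge method in `merge.py` above (patch 14/19, reformatted by pre-commit in 19/19) reduces to a histogram-overlap test on the 1-d projections of the two candidate clusters. The helper below is only an illustrative sketch of that computation: `distrib_overlap` as a standalone function and the synthetic features are made up for the example, while `threshold_overlap=0.4` and the 200 bins mirror the defaults introduced by the patch.

import numpy as np

def distrib_overlap(feat0, feat1, threshold_overlap=0.4, num_bins=200):
    """Histogram overlap between two 1-d projected feature distributions.
    An overlap close to 1 means the two clusters are indistinguishable on
    this axis and are candidates for a merge."""
    lim0 = min(np.min(feat0), np.min(feat1))
    lim1 = max(np.max(feat0), np.max(feat1))
    bin_size = (lim1 - lim0) / num_bins
    bins = np.arange(lim0, lim1, bin_size)

    pdf0, _ = np.histogram(feat0, bins=bins, density=True)
    pdf1, _ = np.histogram(feat1, bins=bins, density=True)
    # density=True normalizes to unit area; multiplying by bin_size gives per-bin mass
    pdf0 *= bin_size
    pdf1 *= bin_size

    overlap = np.sum(np.minimum(pdf0, pdf1))
    is_merge = overlap >= threshold_overlap
    merge_value = 1 - overlap
    return is_merge, merge_value

rng = np.random.default_rng(42)
close = distrib_overlap(rng.normal(0, 1, 500), rng.normal(0.2, 1, 500))
far = distrib_overlap(rng.normal(0, 1, 500), rng.normal(6, 1, 500))
print(close)   # large overlap, expected (True, small merge_value)
print(far)     # almost no overlap, expected (False, merge_value close to 1)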
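
On the split side, `LocalFeatureClustering.split()` (the method renamed in patch 15/19) aggregates the sparse features of one cluster on its common channels, reduces them with TruncatedSVD, and lets HDBSCAN with leaf cluster selection decide whether more than one sub-cluster is present. The sketch below reproduces only that final decision on dense synthetic features: `try_local_split` is a hypothetical helper, the channel-sparsity bookkeeping and the `isocut5` branch are left out, and the keyword defaults mirror the patch.

import numpy as np
from sklearn.decomposition import TruncatedSVD
from hdbscan import HDBSCAN

def try_local_split(features, n_pca_features=2, min_cluster_size=25, min_samples=25):
    """Decide whether the features of a single cluster should be split:
    flatten, reduce locally, cluster with HDBSCAN (leaf selection) and
    require at least two non-noise labels (-1 is HDBSCAN noise)."""
    flat = features.reshape(features.shape[0], -1)
    reduced = TruncatedSVD(n_pca_features).fit_transform(flat)

    clust = HDBSCAN(
        min_cluster_size=min_cluster_size,
        min_samples=min_samples,
        allow_single_cluster=True,
        cluster_selection_method="leaf",
    )
    clust.fit(reduced)
    possible_labels = clust.labels_
    is_split = np.setdiff1d(possible_labels, [-1]).size > 1
    return is_split, possible_labels

# two hidden blobs inside what was labelled as one cluster
rng = np.random.default_rng(0)
features = np.concatenate([rng.normal(0.0, 0.5, (200, 10)),
                           rng.normal(3.0, 0.5, (200, 10))])
print(try_local_split(features)[0])   # expected: True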
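
The updated assertions in `test_globals.py` (patches 16 and 19) show that `mp_context` and `max_threads_per_process` now travel with the other global job kwargs and survive partial updates. A short usage sketch, assuming these helpers keep their usual home in `spikeinterface.core.globals`:

from spikeinterface.core.globals import get_global_job_kwargs, set_global_job_kwargs

# same values as in the updated test: the two new fields are accepted
# and returned together with the existing ones
set_global_job_kwargs(n_jobs=4, chunk_duration="1s", progress_bar=True,
                      mp_context=None, max_threads_per_process=1)
print(get_global_job_kwargs())
# expected: {'n_jobs': 4, 'chunk_duration': '1s', 'progress_bar': True,
#            'mp_context': None, 'max_threads_per_process': 1}

# updating a single field keeps the rest, as the test also checks
set_global_job_kwargs(n_jobs=2)
print(get_global_job_kwargs()["n_jobs"])   # 2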