diff --git a/src/spikeinterface/core/globals.py b/src/spikeinterface/core/globals.py index e5581c7a67..d039206296 100644 --- a/src/spikeinterface/core/globals.py +++ b/src/spikeinterface/core/globals.py @@ -96,7 +96,7 @@ def is_set_global_dataset_folder(): ######################################## global global_job_kwargs -global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True) +global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) global global_job_kwargs_set global_job_kwargs_set = False diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 84ee502c14..cf7a67489c 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -380,10 +380,6 @@ def run(self): self.gather_func(res) else: n_jobs = min(self.n_jobs, len(all_chunks)) - ######## Do you want to limit the number of threads per process? - ######## It has to be done to speed up numpy a lot if multicores - ######## Otherwise, np.dot will be slow. How to do that, up to you - ######## This is just a suggestion, but here it adds a dependency # parallel with ProcessPoolExecutor( @@ -436,3 +432,59 @@ def function_wrapper(args): else: with threadpool_limits(limits=max_threads_per_process): return _func(segment_index, start_frame, end_frame, _worker_ctx) + + +# Here some utils copy/paste from DART (Charlie Windolf) + + +class MockFuture: + """A non-concurrent class for mocking the concurrent.futures API.""" + + def __init__(self, f, *args): + self.f = f + self.args = args + + def result(self): + return self.f(*self.args) + + +class MockPoolExecutor: + """A non-concurrent class for mocking the concurrent.futures API.""" + + def __init__( + self, + max_workers=None, + mp_context=None, + initializer=None, + initargs=None, + context=None, + ): + if initializer is not None: + initializer(*initargs) + self.map = map + self.imap = map + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return + + def submit(self, f, *args): + return MockFuture(f, *args) + + +class MockQueue: + """Another helper class for turning off concurrency when debugging.""" + + def __init__(self): + self.q = [] + self.put = self.q.append + self.get = lambda: self.q.pop(0) + + +def get_poolexecutor(n_jobs): + if n_jobs == 1: + return MockPoolExecutor + else: + return ProcessPoolExecutor diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 651804c995..a0ded216d1 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -432,6 +432,7 @@ def run_node_pipeline( job_name="pipeline", mp_context=None, gather_mode="memory", + gather_kwargs={}, squeeze_output=True, folder=None, names=None, @@ -448,7 +449,7 @@ def run_node_pipeline( if gather_mode == "memory": gather_func = GatherToMemory() elif gather_mode == "npy": - gather_func = GatherToNpy(folder, names) + gather_func = GatherToNpy(folder, names, **gather_kwargs) else: raise ValueError(f"wrong gather_mode : {gather_mode}") @@ -593,9 +594,9 @@ class GatherToNpy: * create the npy v1.0 header at the end with the correct shape and dtype """ - def __init__(self, folder, names, npy_header_size=1024): + def __init__(self, folder, names, npy_header_size=1024, exist_ok=False): self.folder = Path(folder) - self.folder.mkdir(parents=True, exist_ok=False) + self.folder.mkdir(parents=True, exist_ok=exist_ok) assert names is not None self.names = names self.npy_header_size = npy_header_size diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py index 8216a4aae6..d0672405d6 100644 --- a/src/spikeinterface/core/tests/test_globals.py +++ b/src/spikeinterface/core/tests/test_globals.py @@ -37,16 +37,20 @@ def test_global_tmp_folder(): def test_global_job_kwargs(): - job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True) + job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) global_job_kwargs = get_global_job_kwargs() - assert global_job_kwargs == dict(n_jobs=1, chunk_duration="1s", progress_bar=True) + assert global_job_kwargs == dict( + n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + ) set_global_job_kwargs(**job_kwargs) assert get_global_job_kwargs() == job_kwargs # test updating only one field partial_job_kwargs = dict(n_jobs=2) set_global_job_kwargs(**partial_job_kwargs) global_job_kwargs = get_global_job_kwargs() - assert global_job_kwargs == dict(n_jobs=2, chunk_duration="1s", progress_bar=True) + assert global_job_kwargs == dict( + n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + ) # test that fix_job_kwargs grabs global kwargs new_job_kwargs = dict(n_jobs=10) job_kwargs_split = fix_job_kwargs(new_job_kwargs) diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index 2bbf5e9b0f..b56180a9e9 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -346,6 +346,8 @@ def test_recordingless(): # delete original recording and rely on rec_attributes if platform.system() != "Windows": + # this avoid reference on the folder + del we, recording shutil.rmtree(cache_folder / "recording1") we_loaded = WaveformExtractor.load(wf_folder, with_recording=False) assert not we_loaded.has_recording() @@ -554,6 +556,6 @@ def test_non_json_object(): # test_WaveformExtractor() # test_extract_waveforms() # test_portability() - # test_recordingless() + test_recordingless() # test_compute_sparsity() - test_non_json_object() + # test_non_json_object() diff --git a/src/spikeinterface/sortingcomponents/clustering/clean.py b/src/spikeinterface/sortingcomponents/clustering/clean.py new file mode 100644 index 0000000000..cbded0c49f --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/clean.py @@ -0,0 +1,45 @@ +import numpy as np + +from .tools import FeaturesLoader, compute_template_from_sparse + +# This is work in progress ... + + +def clean_clusters( + peaks, + peak_labels, + recording, + features_dict_or_folder, + peak_sign="neg", +): + total_channels = recording.get_num_channels() + + if isinstance(features_dict_or_folder, dict): + features = features_dict_or_folder + else: + features = FeaturesLoader(features_dict_or_folder) + + clean_labels = peak_labels.copy() + + sparse_wfs = features["sparse_wfs"] + sparse_mask = features["sparse_mask"] + + labels_set = np.setdiff1d(peak_labels, [-1]).tolist() + n = len(labels_set) + + count = np.zeros(n, dtype="int64") + for i, label in enumerate(labels_set): + count[i] = np.sum(peak_labels == label) + print(count) + + templates = compute_template_from_sparse(peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels) + + if peak_sign == "both": + max_values = np.max(np.abs(templates), axis=(1, 2)) + elif peak_sign == "neg": + max_values = -np.min(templates, axis=(1, 2)) + elif peak_sign == "pos": + max_values = np.max(templates, axis=(1, 2)) + print(max_values) + + return clean_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py new file mode 100644 index 0000000000..d892d0723a --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -0,0 +1,581 @@ +from pathlib import Path +from multiprocessing import get_context +from concurrent.futures import ProcessPoolExecutor +from threadpoolctl import threadpool_limits +from tqdm.auto import tqdm + +import scipy.spatial +from sklearn.decomposition import PCA +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from hdbscan import HDBSCAN + +import numpy as np +import networkx as nx + +from spikeinterface.core.job_tools import get_poolexecutor, fix_job_kwargs + + +from .isocut5 import isocut5 + +from .tools import aggregate_sparse_features, FeaturesLoader, compute_template_from_sparse + + +def merge_clusters( + peaks, + peak_labels, + recording, + features_dict_or_folder, + radius_um=70, + method="waveforms_lda", + method_kwargs={}, + **job_kwargs, +): + """ + Merge cluster using differents methods. + + Parameters + ---------- + peaks: numpy.ndarray 1d + detected peaks (or a subset) + peak_labels: numpy.ndarray 1d + original label before merge + peak_labels.size == peaks.size + recording: Recording object + A recording object + features_dict_or_folder: dict or folder + A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features. + method: str + The method used + method_kwargs: dict + Option for the method. + Returns + ------- + merge_peak_labels: numpy.ndarray 1d + New vectors label after merges. + peak_shifts: numpy.ndarray 1d + A vector of sample shift to be reverse applied on original sample_index on peak detection + Negative shift means too early. + Posituve shift means too late. + So the correction must be applied like this externaly: + final_peaks = peaks.copy() + final_peaks['sample_index'] -= peak_shifts + + """ + + job_kwargs = fix_job_kwargs(job_kwargs) + + features = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) + sparse_wfs = features["sparse_wfs"] + sparse_mask = features["sparse_mask"] + + labels_set, pair_mask, pair_shift, pair_values = find_merge_pairs( + peaks, + peak_labels, + recording, + features_dict_or_folder, + sparse_wfs, + sparse_mask, + radius_um=radius_um, + method=method, + method_kwargs=method_kwargs, + **job_kwargs, + ) + + DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + ax.matshow(pair_values) + + pair_values[~pair_mask] = 20 + + import hdbscan + + fig, ax = plt.subplots() + clusterer = hdbscan.HDBSCAN(metric="precomputed", min_cluster_size=2, allow_single_cluster=True) + clusterer.fit(pair_values) + print(clusterer.labels_) + clusterer.single_linkage_tree_.plot(cmap="viridis", colorbar=True) + # ~ fig, ax = plt.subplots() + # ~ clusterer.minimum_spanning_tree_.plot(edge_cmap='viridis', + # ~ edge_alpha=0.6, + # ~ node_size=80, + # ~ edge_linewidth=2) + + graph = clusterer.single_linkage_tree_.to_networkx() + + import scipy.cluster + + fig, ax = plt.subplots() + scipy.cluster.hierarchy.dendrogram(clusterer.single_linkage_tree_.to_numpy(), ax=ax) + + import networkx as nx + + fig = plt.figure() + nx.draw_networkx(graph) + plt.show() + + plt.show() + + merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="partial") + # merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full") + + group_shifts = resolve_final_shifts(labels_set, merges, pair_mask, pair_shift) + + # apply final label and shift + merge_peak_labels = peak_labels.copy() + peak_shifts = np.zeros(peak_labels.size, dtype="int64") + for merge, shifts in zip(merges, group_shifts): + label0 = merge[0] + mask = np.in1d(peak_labels, merge) + merge_peak_labels[mask] = label0 + for l, label1 in enumerate(merge): + if l == 0: + # the first label is the reference (shift=0) + continue + peak_shifts[peak_labels == label1] = shifts[l] + + return merge_peak_labels, peak_shifts + + +def resolve_final_shifts(labels_set, merges, pair_mask, pair_shift): + labels_set = list(labels_set) + + group_shifts = [] + for merge in merges: + shifts = np.zeros(len(merge), dtype="int64") + + label_inds = [labels_set.index(label) for label in merge] + + label0 = merge[0] + ind0 = label_inds[0] + + # First find relative shift to label0 (l=0) in the subgraph + local_pair_mask = pair_mask[label_inds, :][:, label_inds] + local_pair_shift = None + G = None + for l, label1 in enumerate(merge): + if l == 0: + # the first label is the reference (shift=0) + continue + ind1 = label_inds[l] + if local_pair_mask[0, l]: + # easy case the pair label0<>label1 was existing + shift = pair_shift[ind0, ind1] + else: + # more complicated case need to find intermediate label and propagate the shift!! + if G is None: + # the the graph only once and only if needed + G = nx.from_numpy_array(local_pair_mask | local_pair_mask.T) + local_pair_shift = pair_shift[label_inds, :][:, label_inds] + local_pair_shift += local_pair_shift.T + + shift_chain = nx.shortest_path(G, source=l, target=0) + shift = 0 + for i in range(len(shift_chain) - 1): + shift += local_pair_shift[shift_chain[i + 1], shift_chain[i]] + shifts[l] = shift + + group_shifts.append(shifts) + + return group_shifts + + +def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full"): + """ + Agglomerate merge pairs into final merge groups. + + The merges are ordered by label. + + """ + + labels_set = np.array(labels_set) + + merges = [] + + graph = nx.from_numpy_array(pair_mask | pair_mask.T) + # put real nodes names for debugging + maps = dict(zip(np.arange(labels_set.size), labels_set)) + graph = nx.relabel_nodes(graph, maps) + + groups = list(nx.connected_components(graph)) + for group in groups: + if len(group) == 1: + continue + sub_graph = graph.subgraph(group) + # print(group, sub_graph) + cliques = list(nx.find_cliques(sub_graph)) + if len(cliques) == 1 and len(cliques[0]) == len(group): + # the sub graph is full connected: no ambiguity + # merges.append(labels_set[cliques[0]]) + merges.append(cliques[0]) + elif len(cliques) > 1: + # the subgraph is not fully connected + if connection_mode == "full": + # node merge + pass + elif connection_mode == "partial": + group = list(group) + # merges.append(labels_set[group]) + merges.append(group) + elif connection_mode == "clique": + raise NotImplementedError + else: + raise ValueError + + # DEBUG = True + DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + + fig = plt.figure() + nx.draw_networkx(sub_graph) + plt.show() + + # DEBUG = True + DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + + fig = plt.figure() + nx.draw_networkx(graph) + plt.show() + + # ensure ordered label + merges = [np.sort(merge) for merge in merges] + + return merges + + +def find_merge_pairs( + peaks, + peak_labels, + recording, + features_dict_or_folder, + sparse_wfs, + sparse_mask, + radius_um=70, + method="waveforms_lda", + method_kwargs={}, + **job_kwargs + # n_jobs=1, + # mp_context="fork", + # max_threads_per_process=1, + # progress_bar=True, +): + """ + Searh some possible merge 2 by 2. + """ + job_kwargs = fix_job_kwargs(job_kwargs) + + # features_dict_or_folder = Path(features_dict_or_folder) + + # peaks = features_dict_or_folder['peaks'] + total_channels = recording.get_num_channels() + + # sparse_wfs = features['sparse_wfs'] + + labels_set = np.setdiff1d(peak_labels, [-1]).tolist() + n = len(labels_set) + pair_mask = np.triu(np.ones((n, n), dtype="bool")) & ~np.eye(n, dtype="bool") + pair_shift = np.zeros((n, n), dtype="int64") + pair_values = np.zeros((n, n), dtype="float64") + + # compute template (no shift at this step) + + templates = compute_template_from_sparse( + peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels, peak_shifts=None + ) + + max_chans = np.argmax(np.max(np.abs(templates), axis=1), axis=1) + + channel_locs = recording.get_channel_locations() + template_locs = channel_locs[max_chans, :] + template_dist = scipy.spatial.distance.cdist(template_locs, template_locs, metric="euclidean") + + pair_mask = pair_mask & (template_dist < radius_um) + indices0, indices1 = np.nonzero(pair_mask) + + n_jobs = job_kwargs["n_jobs"] + mp_context = job_kwargs["mp_context"] + max_threads_per_process = job_kwargs["max_threads_per_process"] + progress_bar = job_kwargs["progress_bar"] + + Executor = get_poolexecutor(n_jobs) + + with Executor( + max_workers=n_jobs, + initializer=find_pair_worker_init, + mp_context=get_context(mp_context), + initargs=(recording, features_dict_or_folder, peak_labels, method, method_kwargs, max_threads_per_process), + ) as pool: + jobs = [] + for ind0, ind1 in zip(indices0, indices1): + label0 = labels_set[ind0] + label1 = labels_set[ind1] + jobs.append(pool.submit(find_pair_function_wrapper, label0, label1)) + + if progress_bar: + iterator = tqdm(jobs, desc=f"find_merge_pairs with {method}", total=len(jobs)) + else: + iterator = jobs + + for res in iterator: + is_merge, label0, label1, shift, merge_value = res.result() + ind0 = labels_set.index(label0) + ind1 = labels_set.index(label1) + + pair_mask[ind0, ind1] = is_merge + if is_merge: + pair_shift[ind0, ind1] = shift + pair_values[ind0, ind1] = merge_value + + pair_mask = pair_mask & (template_dist < radius_um) + indices0, indices1 = np.nonzero(pair_mask) + + return labels_set, pair_mask, pair_shift, pair_values + + +def find_pair_worker_init( + recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process +): + global _ctx + _ctx = {} + + _ctx["recording"] = recording + _ctx["original_labels"] = original_labels + _ctx["method"] = method + _ctx["method_kwargs"] = method_kwargs + _ctx["method_class"] = find_pair_method_dict[method] + _ctx["max_threads_per_process"] = max_threads_per_process + + # if isinstance(features_dict_or_folder, dict): + # _ctx["features"] = features_dict_or_folder + # else: + # _ctx["features"] = FeaturesLoader(features_dict_or_folder) + + _ctx["features"] = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) + + _ctx["peaks"] = _ctx["features"]["peaks"] + + +def find_pair_function_wrapper(label0, label1): + global _ctx + with threadpool_limits(limits=_ctx["max_threads_per_process"]): + is_merge, label0, label1, shift, merge_value = _ctx["method_class"].merge( + label0, label1, _ctx["original_labels"], _ctx["peaks"], _ctx["features"], **_ctx["method_kwargs"] + ) + return is_merge, label0, label1, shift, merge_value + + +class ProjectDistribution: + """ + This method is a refactorized mix between: + * old tridesclous code + * some ideas by Charlie Windolf in spikespvae + + The idea is : + * project the waveform (or features) samples on a 1d axis (using LDA for instance). + * check that it is the same or not distribution (diptest, distrib_overlap, ...) + + + """ + + name = "project_distribution" + + @staticmethod + def merge( + label0, + label1, + original_labels, + peaks, + features, + waveforms_sparse_mask=None, + feature_name="sparse_tsvd", + projection="centroid", + criteria="diptest", + threshold_diptest=0.5, + threshold_percentile=80.0, + threshold_overlap=0.4, + num_shift=2, + ): + if num_shift > 0: + assert feature_name == "sparse_wfs" + sparse_wfs = features[feature_name] + + assert waveforms_sparse_mask is not None + + (inds0,) = np.nonzero(original_labels == label0) + chans0 = np.unique(peaks["channel_index"][inds0]) + target_chans0 = np.flatnonzero(np.all(waveforms_sparse_mask[chans0, :], axis=0)) + + (inds1,) = np.nonzero(original_labels == label1) + chans1 = np.unique(peaks["channel_index"][inds1]) + target_chans1 = np.flatnonzero(np.all(waveforms_sparse_mask[chans1, :], axis=0)) + + if inds0.size < 40 or inds1.size < 40: + is_merge = False + merge_value = 0 + final_shift = 0 + return is_merge, label0, label1, final_shift, merge_value + + target_chans = np.intersect1d(target_chans0, target_chans1) + + inds = np.concatenate([inds0, inds1]) + labels = np.zeros(inds.size, dtype="int") + labels[inds0.size :] = 1 + wfs, out = aggregate_sparse_features(peaks, inds, sparse_wfs, waveforms_sparse_mask, target_chans) + wfs = wfs[~out] + labels = labels[~out] + + cut = np.searchsorted(labels, 1) + wfs0_ = wfs[:cut, :, :] + wfs1_ = wfs[cut:, :, :] + + template0_ = np.mean(wfs0_, axis=0) + template1_ = np.mean(wfs1_, axis=0) + num_samples = template0_.shape[0] + + template0 = template0_[num_shift : num_samples - num_shift, :] + + wfs0 = wfs0_[:, num_shift : num_samples - num_shift, :] + + # best shift strategy 1 = max cosine + # values = [] + # for shift in range(num_shift * 2 + 1): + # template1 = template1_[shift : shift + template0.shape[0], :] + # norm = np.linalg.norm(template0.flatten()) * np.linalg.norm(template1.flatten()) + # value = np.sum(template0.flatten() * template1.flatten()) / norm + # values.append(value) + # best_shift = np.argmax(values) + + # best shift strategy 2 = min dist**2 + # values = [] + # for shift in range(num_shift * 2 + 1): + # template1 = template1_[shift : shift + template0.shape[0], :] + # value = np.sum((template1 - template0)**2) + # values.append(value) + # best_shift = np.argmin(values) + + # best shift strategy 3 : average delta argmin between channels + channel_shift = np.argmax(np.abs(template1_), axis=0) - np.argmax(np.abs(template0_), axis=0) + mask = np.abs(channel_shift) <= num_shift + channel_shift = channel_shift[mask] + if channel_shift.size > 0: + best_shift = int(np.round(np.mean(channel_shift))) + num_shift + else: + best_shift = num_shift + + wfs1 = wfs1_[:, best_shift : best_shift + template0.shape[0], :] + template1 = template1_[best_shift : best_shift + template0.shape[0], :] + + if projection == "lda": + wfs_0_1 = np.concatenate([wfs0, wfs1], axis=0) + flat_wfs = wfs_0_1.reshape(wfs_0_1.shape[0], -1) + feat = LinearDiscriminantAnalysis(n_components=1).fit_transform(flat_wfs, labels) + feat = feat[:, 0] + feat0 = feat[:cut] + feat1 = feat[cut:] + + elif projection == "centroid": + vector_0_1 = template1 - template0 + vector_0_1 /= np.sum(vector_0_1**2) + feat0 = np.sum((wfs0 - template0[np.newaxis, :, :]) * vector_0_1[np.newaxis, :, :], axis=(1, 2)) + feat1 = np.sum((wfs1 - template0[np.newaxis, :, :]) * vector_0_1[np.newaxis, :, :], axis=(1, 2)) + # feat = np.sum((wfs_0_1 - template0[np.newaxis, :, :]) * vector_0_1[np.newaxis, :, :], axis=(1, 2)) + feat = np.concatenate([feat0, feat1], axis=0) + + else: + raise ValueError(f"bad projection {projection}") + + if criteria == "diptest": + dipscore, cutpoint = isocut5(feat) + is_merge = dipscore < threshold_diptest + merge_value = dipscore + elif criteria == "percentile": + l0 = np.percentile(feat0, threshold_percentile) + l1 = np.percentile(feat1, 100.0 - threshold_percentile) + is_merge = l0 >= l1 + merge_value = l0 - l1 + elif criteria == "distrib_overlap": + lim0 = min(np.min(feat0), np.min(feat1)) + lim1 = max(np.max(feat0), np.max(feat1)) + bin_size = (lim1 - lim0) / 200.0 + bins = np.arange(lim0, lim1, bin_size) + + pdf0, _ = np.histogram(feat0, bins=bins, density=True) + pdf1, _ = np.histogram(feat1, bins=bins, density=True) + pdf0 *= bin_size + pdf1 *= bin_size + overlap = np.sum(np.minimum(pdf0, pdf1)) + + is_merge = overlap >= threshold_overlap + + merge_value = 1 - overlap + + else: + raise ValueError(f"bad criteria {criteria}") + + if is_merge: + final_shift = best_shift - num_shift + else: + final_shift = 0 + + # DEBUG = True + DEBUG = False + + if DEBUG and is_merge: + # if DEBUG and not is_merge: + # if DEBUG and (overlap > 0.05 and overlap <0.25): + # if label0 == 49 and label1== 65: + import matplotlib.pyplot as plt + + flatten_wfs0 = wfs0.swapaxes(1, 2).reshape(wfs0.shape[0], -1) + flatten_wfs1 = wfs1.swapaxes(1, 2).reshape(wfs1.shape[0], -1) + + fig, axs = plt.subplots(ncols=2) + ax = axs[0] + ax.plot(flatten_wfs0.T, color="C0", alpha=0.01) + ax.plot(flatten_wfs1.T, color="C1", alpha=0.01) + m0 = np.mean(flatten_wfs0, axis=0) + m1 = np.mean(flatten_wfs1, axis=0) + ax.plot(m0, color="C0", alpha=1, lw=4, label=f"{label0} {inds0.size}") + ax.plot(m1, color="C1", alpha=1, lw=4, label=f"{label1} {inds1.size}") + + ax.legend() + + bins = np.linspace(np.percentile(feat, 1), np.percentile(feat, 99), 100) + bin_size = bins[1] - bins[0] + count0, _ = np.histogram(feat0, bins=bins, density=True) + count1, _ = np.histogram(feat1, bins=bins, density=True) + pdf0 = count0 * bin_size + pdf1 = count1 * bin_size + + ax = axs[1] + ax.plot(bins[:-1], pdf0, color="C0") + ax.plot(bins[:-1], pdf1, color="C1") + + if criteria == "diptest": + ax.set_title(f"{dipscore:.4f} {is_merge}") + elif criteria == "percentile": + ax.set_title(f"{l0:.4f} {l1:.4f} {is_merge}") + ax.axvline(l0, color="C0") + ax.axvline(l1, color="C1") + elif criteria == "distrib_overlap": + print( + lim0, + lim1, + ) + ax.set_title(f"{overlap:.4f} {is_merge}") + ax.plot(bins[:-1], np.minimum(pdf0, pdf1), ls="--", color="k") + + plt.show() + + return is_merge, label0, label1, final_shift, merge_value + + +find_pair_method_list = [ + ProjectDistribution, +] +find_pair_method_dict = {e.name: e for e in find_pair_method_list} diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py new file mode 100644 index 0000000000..9836e9110f --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -0,0 +1,278 @@ +from multiprocessing import get_context +from threadpoolctl import threadpool_limits +from tqdm.auto import tqdm + +from sklearn.decomposition import TruncatedSVD +from hdbscan import HDBSCAN + +import numpy as np + +from spikeinterface.core.job_tools import get_poolexecutor, fix_job_kwargs + +from .tools import aggregate_sparse_features, FeaturesLoader +from .isocut5 import isocut5 + + +# important all DEBUG and matplotlib are left in the code intentionally + + +def split_clusters( + peak_labels, + recording, + features_dict_or_folder, + method="hdbscan_on_local_pca", + method_kwargs={}, + recursive=False, + recursive_depth=None, + returns_split_count=False, + **job_kwargs, +): + """ + Run recusrsively (or not) in a multi process pool a local split method. + + Parameters + ---------- + peak_labels: numpy.array + Peak label before split + recording: Recording + Recording object + features_dict_or_folder: dict or folder + A dictionary of features precomputed with peak_pipeline or a folder containing npz file for features. + method: str + The method name + method_kwargs: dict + The method option + recursive: bool Default False + Reccursive or not. + recursive_depth: None or int + If recursive=True, then this is the max split per spikes. + returns_split_count: bool + Optionally return the split count vector. Same size as labels. + + Returns + ------- + new_labels: numpy.ndarray + The labels of peaks after split. + split_count: numpy.ndarray + Optionally returned + """ + + job_kwargs = fix_job_kwargs(job_kwargs) + n_jobs = job_kwargs["n_jobs"] + mp_context = job_kwargs["mp_context"] + progress_bar = job_kwargs["progress_bar"] + max_threads_per_process = job_kwargs["max_threads_per_process"] + + original_labels = peak_labels + peak_labels = peak_labels.copy() + split_count = np.zeros(peak_labels.size, dtype=int) + + Executor = get_poolexecutor(n_jobs) + + with Executor( + max_workers=n_jobs, + initializer=split_worker_init, + mp_context=get_context(mp_context), + initargs=(recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process), + ) as pool: + labels_set = np.setdiff1d(peak_labels, [-1]) + current_max_label = np.max(labels_set) + 1 + + jobs = [] + for label in labels_set: + peak_indices = np.flatnonzero(peak_labels == label) + if peak_indices.size > 0: + jobs.append(pool.submit(split_function_wrapper, peak_indices)) + + if progress_bar: + iterator = tqdm(jobs, desc=f"split_clusters with {method}", total=len(labels_set)) + else: + iterator = jobs + + for res in iterator: + is_split, local_labels, peak_indices = res.result() + if not is_split: + continue + + mask = local_labels >= 0 + peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label + peak_labels[peak_indices[~mask]] = local_labels[~mask] + + split_count[peak_indices] += 1 + + current_max_label += np.max(local_labels[mask]) + 1 + + if recursive: + if recursive_depth is not None: + # stop reccursivity when recursive_depth is reach + extra_ball = np.max(split_count[peak_indices]) < recursive_depth + else: + # reccurssive always + extra_ball = True + + if extra_ball: + new_labels_set = np.setdiff1d(peak_labels[peak_indices], [-1]) + for label in new_labels_set: + peak_indices = np.flatnonzero(peak_labels == label) + if peak_indices.size > 0: + jobs.append(pool.submit(split_function_wrapper, peak_indices)) + if progress_bar: + iterator.total += 1 + + if returns_split_count: + return peak_labels, split_count + else: + return peak_labels + + +global _ctx + + +def split_worker_init( + recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process +): + global _ctx + _ctx = {} + + _ctx["recording"] = recording + features_dict_or_folder + _ctx["original_labels"] = original_labels + _ctx["method"] = method + _ctx["method_kwargs"] = method_kwargs + _ctx["method_class"] = split_methods_dict[method] + _ctx["max_threads_per_process"] = max_threads_per_process + _ctx["features"] = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) + _ctx["peaks"] = _ctx["features"]["peaks"] + + +def split_function_wrapper(peak_indices): + global _ctx + with threadpool_limits(limits=_ctx["max_threads_per_process"]): + is_split, local_labels = _ctx["method_class"].split( + peak_indices, _ctx["peaks"], _ctx["features"], **_ctx["method_kwargs"] + ) + return is_split, local_labels, peak_indices + + +class LocalFeatureClustering: + """ + This method is a refactorized mix between: + * old tridesclous code + * "herding_split()" in DART/spikepsvae by Charlie Windolf + + The idea simple : + * agregate features (svd or even waveforms) with sparse channel. + * run a local feature reduction (pca or svd) + * try a new split (hdscan or isocut5) + """ + + name = "local_feature_clustering" + + @staticmethod + def split( + peak_indices, + peaks, + features, + clusterer="hdbscan", + feature_name="sparse_tsvd", + neighbours_mask=None, + waveforms_sparse_mask=None, + min_size_split=25, + min_cluster_size=25, + min_samples=25, + n_pca_features=2, + minimum_common_channels=2, + ): + local_labels = np.zeros(peak_indices.size, dtype=np.int64) + + # can be sparse_tsvd or sparse_wfs + sparse_features = features[feature_name] + + assert waveforms_sparse_mask is not None + + # target channel subset is done intersect local channels + neighbours + local_chans = np.unique(peaks["channel_index"][peak_indices]) + target_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0)) + + # TODO fix this a better way, this when cluster have too few overlapping channels + if target_channels.size < minimum_common_channels: + return False, None + + aligned_wfs, dont_have_channels = aggregate_sparse_features( + peaks, peak_indices, sparse_features, waveforms_sparse_mask, target_channels + ) + + local_labels[dont_have_channels] = -2 + kept = np.flatnonzero(~dont_have_channels) + if kept.size < min_size_split: + return False, None + + aligned_wfs = aligned_wfs[kept, :, :] + + flatten_features = aligned_wfs.reshape(aligned_wfs.shape[0], -1) + + # final_features = PCA(n_pca_features, whiten=True).fit_transform(flatten_features) + # final_features = PCA(n_pca_features, whiten=False).fit_transform(flatten_features) + final_features = TruncatedSVD(n_pca_features).fit_transform(flatten_features) + + if clusterer == "hdbscan": + clust = HDBSCAN( + min_cluster_size=min_cluster_size, + min_samples=min_samples, + allow_single_cluster=True, + cluster_selection_method="leaf", + ) + clust.fit(final_features) + possible_labels = clust.labels_ + is_split = np.setdiff1d(possible_labels, [-1]).size > 1 + elif clusterer == "isocut5": + dipscore, cutpoint = isocut5(final_features[:, 0]) + possible_labels = np.zeros(final_features.shape[0]) + if dipscore > 1.5: + mask = final_features[:, 0] > cutpoint + if np.sum(mask) > min_cluster_size and np.sum(~mask): + possible_labels[mask] = 1 + is_split = np.setdiff1d(possible_labels, [-1]).size > 1 + else: + is_split = False + else: + raise ValueError(f"wrong clusterer {clusterer}") + + # DEBUG = True + DEBUG = False + if DEBUG: + import matplotlib.pyplot as plt + + labels_set = np.setdiff1d(possible_labels, [-1]) + colors = plt.get_cmap("tab10", len(labels_set)) + colors = {k: colors(i) for i, k in enumerate(labels_set)} + colors[-1] = "k" + fix, axs = plt.subplots(nrows=2) + + flatten_wfs = aligned_wfs.swapaxes(1, 2).reshape(aligned_wfs.shape[0], -1) + + sl = slice(None, None, 10) + for k in np.unique(possible_labels): + mask = possible_labels == k + ax = axs[0] + ax.scatter(final_features[:, 0][mask][sl], final_features[:, 1][mask][sl], s=5, color=colors[k]) + + ax = axs[1] + ax.plot(flatten_wfs[mask][sl].T, color=colors[k], alpha=0.5) + + axs[0].set_title(f"{clusterer} {is_split}") + + plt.show() + + if not is_split: + return is_split, None + + local_labels[kept] = possible_labels + + return is_split, local_labels + + +split_methods_list = [ + LocalFeatureClustering, +] +split_methods_dict = {e.name: e for e in split_methods_list} diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py new file mode 100644 index 0000000000..8e25c9cb7f --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -0,0 +1,196 @@ +from pathlib import Path +from typing import Any +import numpy as np + + +# TODO find a way to attach a a sparse_mask to a given features (waveforms, pca, tsvd ....) + + +class FeaturesLoader: + """ + Feature can be computed in memory or in a folder contaning npy files. + + This class read the folder and behave like a dict of array lazily. + + Parameters + ---------- + feature_folder + + preload + + """ + + def __init__(self, feature_folder, preload=["peaks"]): + self.feature_folder = Path(feature_folder) + + self.file_feature = {} + self.loaded_features = {} + for file in self.feature_folder.glob("*.npy"): + name = file.stem + if name in preload: + self.loaded_features[name] = np.load(file) + else: + self.file_feature[name] = file + + def __getitem__(self, name): + if name in self.loaded_features: + return self.loaded_features[name] + else: + return np.load(self.file_feature[name], mmap_mode="r") + + @staticmethod + def from_dict_or_folder(features_dict_or_folder): + if isinstance(features_dict_or_folder, dict): + return features_dict_or_folder + else: + return FeaturesLoader(features_dict_or_folder) + + +def aggregate_sparse_features(peaks, peak_indices, sparse_feature, sparse_mask, target_channels): + """ + Aggregate sparse features that have unaligned channels and realigned then on target_channels. + + This is usefull to aligned back peaks waveform or pca or tsvd when detected a differents channels. + + + Parameters + ---------- + peaks + + peak_indices + + sparse_feature + + sparse_mask + + target_channels + + Returns + ------- + aligned_features: numpy.array + Aligned features. shape is (local_peaks.size, sparse_feature.shape[1], target_channels.size) + dont_have_channels: numpy.array + Boolean vector to indicate spikes that do not have all target channels to be taken in account + shape is peak_indices.size + """ + local_peaks = peaks[peak_indices] + + aligned_features = np.zeros( + (local_peaks.size, sparse_feature.shape[1], target_channels.size), dtype=sparse_feature.dtype + ) + dont_have_channels = np.zeros(peak_indices.size, dtype=bool) + + for chan in np.unique(local_peaks["channel_index"]): + sparse_chans = np.flatnonzero(sparse_mask[chan, :]) + peak_inds = np.flatnonzero(local_peaks["channel_index"] == chan) + if np.all(np.isin(target_channels, sparse_chans)): + # peaks feature channel have all target_channels + source_chans = np.flatnonzero(np.in1d(sparse_chans, target_channels)) + aligned_features[peak_inds, :, :] = sparse_feature[peak_indices[peak_inds], :, :][:, :, source_chans] + else: + # some channel are missing, peak are not removde + dont_have_channels[peak_inds] = True + + return aligned_features, dont_have_channels + + +def compute_template_from_sparse( + peaks, labels, labels_set, sparse_waveforms, sparse_mask, total_channels, peak_shifts=None +): + """ + Compute template average from single sparse waveforms buffer. + + Parameters + ---------- + peaks + + labels + + labels_set + + sparse_waveforms + + sparse_mask + + total_channels + + peak_shifts + + Returns + ------- + templates: numpy.array + Templates shape : (len(labels_set), num_samples, total_channels) + """ + n = len(labels_set) + + templates = np.zeros((n, sparse_waveforms.shape[1], total_channels), dtype=sparse_waveforms.dtype) + + for i, label in enumerate(labels_set): + peak_indices = np.flatnonzero(labels == label) + + local_chans = np.unique(peaks["channel_index"][peak_indices]) + target_channels = np.flatnonzero(np.all(sparse_mask[local_chans, :], axis=0)) + + aligned_wfs, dont_have_channels = aggregate_sparse_features( + peaks, peak_indices, sparse_waveforms, sparse_mask, target_channels + ) + + if peak_shifts is not None: + apply_waveforms_shift(aligned_wfs, peak_shifts[peak_indices], inplace=True) + + templates[i, :, :][:, target_channels] = np.mean(aligned_wfs[~dont_have_channels], axis=0) + + return templates + + +def apply_waveforms_shift(waveforms, peak_shifts, inplace=False): + """ + Apply a shift a spike level to realign waveforms buffers. + + This is usefull to compute template after merge when to cluster are shifted. + + A negative shift need the waveforms to be moved toward the right because the trough was too early. + A positive shift need the waveforms to be moved toward the left because the trough was too late. + + Note the border sample are left as before without move. + + Parameters + ---------- + + waveforms + + peak_shifts + + inplace + + Returns + ------- + aligned_waveforms + + + """ + + print("apply_waveforms_shift") + + if inplace: + aligned_waveforms = waveforms + else: + aligned_waveforms = waveforms.copy() + + shift_set = np.unique(peak_shifts) + assert max(np.abs(shift_set)) < aligned_waveforms.shape[1] + + for shift in shift_set: + if shift == 0: + continue + mask = peak_shifts == shift + wfs = waveforms[mask] + + if shift > 0: + aligned_waveforms[mask, :-shift, :] = wfs[:, shift:, :] + else: + aligned_waveforms[mask, -shift:, :] = wfs[:, :-shift, :] + + print("apply_waveforms_shift DONE") + + return aligned_waveforms diff --git a/src/spikeinterface/sortingcomponents/tests/test_merge.py b/src/spikeinterface/sortingcomponents/tests/test_merge.py new file mode 100644 index 0000000000..6b3ea2a901 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/tests/test_merge.py @@ -0,0 +1,14 @@ +import pytest +import numpy as np + +from spikeinterface.sortingcomponents.clustering.split import split_clusters + +# no proper test at the moment this is used in tridesclous2 + + +def test_merge(): + pass + + +if __name__ == "__main__": + test_merge() diff --git a/src/spikeinterface/sortingcomponents/tests/test_split.py b/src/spikeinterface/sortingcomponents/tests/test_split.py new file mode 100644 index 0000000000..5953f74e24 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/tests/test_split.py @@ -0,0 +1,14 @@ +import pytest +import numpy as np + +from spikeinterface.sortingcomponents.clustering.split import split_clusters + +# no proper test at the moment this is used in tridesclous2 + + +def test_split(): + pass + + +if __name__ == "__main__": + test_split()