From de2d642d5f833b5c3f68df5150311b1ed5eddca8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Oct 2023 14:41:53 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/job_tools.py | 1 + src/spikeinterface/core/tests/test_globals.py | 8 +- .../core/tests/test_waveform_extractor.py | 2 +- .../sortingcomponents/clustering/merge.py | 77 ++++++++++--------- .../sortingcomponents/clustering/split.py | 18 ++--- .../sortingcomponents/clustering/tools.py | 1 - .../sortingcomponents/tests/test_merge.py | 1 + .../sortingcomponents/tests/test_split.py | 1 + 8 files changed, 58 insertions(+), 51 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index e42f7bb8b4..cf7a67489c 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -436,6 +436,7 @@ def function_wrapper(args): # Here some utils copy/paste from DART (Charlie Windolf) + class MockFuture: """A non-concurrent class for mocking the concurrent.futures API.""" diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py index 2c0792c152..d0672405d6 100644 --- a/src/spikeinterface/core/tests/test_globals.py +++ b/src/spikeinterface/core/tests/test_globals.py @@ -39,14 +39,18 @@ def test_global_tmp_folder(): def test_global_job_kwargs(): job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) global_job_kwargs = get_global_job_kwargs() - assert global_job_kwargs == dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) + assert global_job_kwargs == dict( + n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + ) set_global_job_kwargs(**job_kwargs) assert get_global_job_kwargs() == job_kwargs # test updating only one field partial_job_kwargs = dict(n_jobs=2) set_global_job_kwargs(**partial_job_kwargs) global_job_kwargs = get_global_job_kwargs() - assert global_job_kwargs == dict(n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) + assert global_job_kwargs == dict( + n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + ) # test that fix_job_kwargs grabs global kwargs new_job_kwargs = dict(n_jobs=10) job_kwargs_split = fix_job_kwargs(new_job_kwargs) diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index de6c3d752a..b56180a9e9 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -558,4 +558,4 @@ def test_non_json_object(): # test_portability() test_recordingless() # test_compute_sparsity() - # test_non_json_object() + # test_non_json_object() diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 5539ec1051..d892d0723a 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -80,45 +80,46 @@ def merge_clusters( method_kwargs=method_kwargs, **job_kwargs, ) - - + DEBUG = False if DEBUG: import matplotlib.pyplot as plt + fig, ax = plt.subplots() ax.matshow(pair_values) - - pair_values[~pair_mask] = 20 - + + pair_values[~pair_mask] = 20 + import hdbscan + fig, ax = plt.subplots() - clusterer = hdbscan.HDBSCAN(metric='precomputed', min_cluster_size=2, allow_single_cluster=True) + clusterer = hdbscan.HDBSCAN(metric="precomputed", min_cluster_size=2, allow_single_cluster=True) clusterer.fit(pair_values) print(clusterer.labels_) - clusterer.single_linkage_tree_.plot(cmap='viridis', colorbar=True) - #~ fig, ax = plt.subplots() - #~ clusterer.minimum_spanning_tree_.plot(edge_cmap='viridis', - #~ edge_alpha=0.6, - #~ node_size=80, - #~ edge_linewidth=2) - + clusterer.single_linkage_tree_.plot(cmap="viridis", colorbar=True) + # ~ fig, ax = plt.subplots() + # ~ clusterer.minimum_spanning_tree_.plot(edge_cmap='viridis', + # ~ edge_alpha=0.6, + # ~ node_size=80, + # ~ edge_linewidth=2) + graph = clusterer.single_linkage_tree_.to_networkx() import scipy.cluster + fig, ax = plt.subplots() scipy.cluster.hierarchy.dendrogram(clusterer.single_linkage_tree_.to_numpy(), ax=ax) - + import networkx as nx + fig = plt.figure() nx.draw_networkx(graph) plt.show() plt.show() - - - + merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="partial") - # merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full") + # merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full") group_shifts = resolve_final_shifts(labels_set, merges, pair_mask, pair_shift) @@ -223,7 +224,7 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" else: raise ValueError - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -232,7 +233,7 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" nx.draw_networkx(sub_graph) plt.show() - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -377,9 +378,10 @@ class ProjectDistribution: The idea is : * project the waveform (or features) samples on a 1d axis (using LDA for instance). * check that it is the same or not distribution (diptest, distrib_overlap, ...) - + """ + name = "project_distribution" @staticmethod @@ -412,13 +414,12 @@ def merge( chans1 = np.unique(peaks["channel_index"][inds1]) target_chans1 = np.flatnonzero(np.all(waveforms_sparse_mask[chans1, :], axis=0)) - if inds0.size <40 or inds1.size <40: + if inds0.size < 40 or inds1.size < 40: is_merge = False merge_value = 0 final_shift = 0 return is_merge, label0, label1, final_shift, merge_value - target_chans = np.intersect1d(target_chans0, target_chans1) inds = np.concatenate([inds0, inds1]) @@ -500,20 +501,19 @@ def merge( elif criteria == "distrib_overlap": lim0 = min(np.min(feat0), np.min(feat1)) lim1 = max(np.max(feat0), np.max(feat1)) - bin_size = (lim1 - lim0) / 200. + bin_size = (lim1 - lim0) / 200.0 bins = np.arange(lim0, lim1, bin_size) - + pdf0, _ = np.histogram(feat0, bins=bins, density=True) pdf1, _ = np.histogram(feat1, bins=bins, density=True) pdf0 *= bin_size - pdf1 *= bin_size + pdf1 *= bin_size overlap = np.sum(np.minimum(pdf0, pdf1)) - + is_merge = overlap >= threshold_overlap - + merge_value = 1 - overlap - - + else: raise ValueError(f"bad criteria {criteria}") @@ -522,13 +522,13 @@ def merge( else: final_shift = 0 - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG and is_merge: - # if DEBUG and not is_merge: - # if DEBUG and (overlap > 0.05 and overlap <0.25): - # if label0 == 49 and label1== 65: + # if DEBUG and not is_merge: + # if DEBUG and (overlap > 0.05 and overlap <0.25): + # if label0 == 49 and label1== 65: import matplotlib.pyplot as plt flatten_wfs0 = wfs0.swapaxes(1, 2).reshape(wfs0.shape[0], -1) @@ -551,7 +551,6 @@ def merge( count1, _ = np.histogram(feat1, bins=bins, density=True) pdf0 = count0 * bin_size pdf1 = count1 * bin_size - ax = axs[1] ax.plot(bins[:-1], pdf0, color="C0") @@ -564,13 +563,15 @@ def merge( ax.axvline(l0, color="C0") ax.axvline(l1, color="C1") elif criteria == "distrib_overlap": - print(lim0, lim1, ) + print( + lim0, + lim1, + ) ax.set_title(f"{overlap:.4f} {is_merge}") - ax.plot(bins[:-1], np.minimum(pdf0, pdf1), ls='--', color='k') + ax.plot(bins[:-1], np.minimum(pdf0, pdf1), ls="--", color="k") plt.show() - return is_merge, label0, label1, final_shift, merge_value diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index dc649cec97..9836e9110f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -154,13 +154,12 @@ def split_function_wrapper(peak_indices): return is_split, local_labels, peak_indices - class LocalFeatureClustering: """ This method is a refactorized mix between: * old tridesclous code * "herding_split()" in DART/spikepsvae by Charlie Windolf - + The idea simple : * agregate features (svd or even waveforms) with sparse channel. * run a local feature reduction (pca or svd) @@ -183,7 +182,6 @@ def split( min_samples=25, n_pca_features=2, minimum_common_channels=2, - ): local_labels = np.zeros(peak_indices.size, dtype=np.int64) @@ -218,8 +216,12 @@ def split( final_features = TruncatedSVD(n_pca_features).fit_transform(flatten_features) if clusterer == "hdbscan": - clust = HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples, allow_single_cluster=True, - cluster_selection_method="leaf") + clust = HDBSCAN( + min_cluster_size=min_cluster_size, + min_samples=min_samples, + allow_single_cluster=True, + cluster_selection_method="leaf", + ) clust.fit(final_features) possible_labels = clust.labels_ is_split = np.setdiff1d(possible_labels, [-1]).size > 1 @@ -236,9 +238,7 @@ def split( else: raise ValueError(f"wrong clusterer {clusterer}") - - - # DEBUG = True + # DEBUG = True DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -259,7 +259,7 @@ def split( ax = axs[1] ax.plot(flatten_wfs[mask][sl].T, color=colors[k], alpha=0.5) - + axs[0].set_title(f"{clusterer} {is_split}") plt.show() diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index c334daebe3..8e25c9cb7f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -94,7 +94,6 @@ def aggregate_sparse_features(peaks, peak_indices, sparse_feature, sparse_mask, return aligned_features, dont_have_channels - def compute_template_from_sparse( peaks, labels, labels_set, sparse_waveforms, sparse_mask, total_channels, peak_shifts=None ): diff --git a/src/spikeinterface/sortingcomponents/tests/test_merge.py b/src/spikeinterface/sortingcomponents/tests/test_merge.py index b7a669a263..6b3ea2a901 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_merge.py +++ b/src/spikeinterface/sortingcomponents/tests/test_merge.py @@ -5,6 +5,7 @@ # no proper test at the moment this is used in tridesclous2 + def test_merge(): pass diff --git a/src/spikeinterface/sortingcomponents/tests/test_split.py b/src/spikeinterface/sortingcomponents/tests/test_split.py index ca5e5b57e7..5953f74e24 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_split.py +++ b/src/spikeinterface/sortingcomponents/tests/test_split.py @@ -5,6 +5,7 @@ # no proper test at the moment this is used in tridesclous2 + def test_split(): pass