Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 5, 2023
1 parent 48da4ea commit de2d642
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 51 deletions.
1 change: 1 addition & 0 deletions src/spikeinterface/core/job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ def function_wrapper(args):

# Here some utils copy/paste from DART (Charlie Windolf)


class MockFuture:
"""A non-concurrent class for mocking the concurrent.futures API."""

Expand Down
8 changes: 6 additions & 2 deletions src/spikeinterface/core/tests/test_globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,18 @@ def test_global_tmp_folder():
def test_global_job_kwargs():
job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1)
global_job_kwargs = get_global_job_kwargs()
assert global_job_kwargs == dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1)
assert global_job_kwargs == dict(
n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1
)
set_global_job_kwargs(**job_kwargs)
assert get_global_job_kwargs() == job_kwargs
# test updating only one field
partial_job_kwargs = dict(n_jobs=2)
set_global_job_kwargs(**partial_job_kwargs)
global_job_kwargs = get_global_job_kwargs()
assert global_job_kwargs == dict(n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1)
assert global_job_kwargs == dict(
n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1
)
# test that fix_job_kwargs grabs global kwargs
new_job_kwargs = dict(n_jobs=10)
job_kwargs_split = fix_job_kwargs(new_job_kwargs)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/tests/test_waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,4 +558,4 @@ def test_non_json_object():
# test_portability()
test_recordingless()
# test_compute_sparsity()
# test_non_json_object()
# test_non_json_object()
77 changes: 39 additions & 38 deletions src/spikeinterface/sortingcomponents/clustering/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,45 +80,46 @@ def merge_clusters(
method_kwargs=method_kwargs,
**job_kwargs,
)



DEBUG = False
if DEBUG:
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.matshow(pair_values)
pair_values[~pair_mask] = 20

pair_values[~pair_mask] = 20

import hdbscan

fig, ax = plt.subplots()
clusterer = hdbscan.HDBSCAN(metric='precomputed', min_cluster_size=2, allow_single_cluster=True)
clusterer = hdbscan.HDBSCAN(metric="precomputed", min_cluster_size=2, allow_single_cluster=True)
clusterer.fit(pair_values)
print(clusterer.labels_)
clusterer.single_linkage_tree_.plot(cmap='viridis', colorbar=True)
#~ fig, ax = plt.subplots()
#~ clusterer.minimum_spanning_tree_.plot(edge_cmap='viridis',
#~ edge_alpha=0.6,
#~ node_size=80,
#~ edge_linewidth=2)
clusterer.single_linkage_tree_.plot(cmap="viridis", colorbar=True)
# ~ fig, ax = plt.subplots()
# ~ clusterer.minimum_spanning_tree_.plot(edge_cmap='viridis',
# ~ edge_alpha=0.6,
# ~ node_size=80,
# ~ edge_linewidth=2)

graph = clusterer.single_linkage_tree_.to_networkx()

import scipy.cluster

fig, ax = plt.subplots()
scipy.cluster.hierarchy.dendrogram(clusterer.single_linkage_tree_.to_numpy(), ax=ax)

import networkx as nx

fig = plt.figure()
nx.draw_networkx(graph)
plt.show()

plt.show()




merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="partial")
# merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full")
# merges = agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full")

group_shifts = resolve_final_shifts(labels_set, merges, pair_mask, pair_shift)

Expand Down Expand Up @@ -223,7 +224,7 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full"
else:
raise ValueError

# DEBUG = True
# DEBUG = True
DEBUG = False
if DEBUG:
import matplotlib.pyplot as plt
Expand All @@ -232,7 +233,7 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full"
nx.draw_networkx(sub_graph)
plt.show()

# DEBUG = True
# DEBUG = True
DEBUG = False
if DEBUG:
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -377,9 +378,10 @@ class ProjectDistribution:
The idea is :
* project the waveform (or features) samples on a 1d axis (using LDA for instance).
* check that it is the same or not distribution (diptest, distrib_overlap, ...)
"""

name = "project_distribution"

@staticmethod
Expand Down Expand Up @@ -412,13 +414,12 @@ def merge(
chans1 = np.unique(peaks["channel_index"][inds1])
target_chans1 = np.flatnonzero(np.all(waveforms_sparse_mask[chans1, :], axis=0))

if inds0.size <40 or inds1.size <40:
if inds0.size < 40 or inds1.size < 40:
is_merge = False
merge_value = 0
final_shift = 0
return is_merge, label0, label1, final_shift, merge_value


target_chans = np.intersect1d(target_chans0, target_chans1)

inds = np.concatenate([inds0, inds1])
Expand Down Expand Up @@ -500,20 +501,19 @@ def merge(
elif criteria == "distrib_overlap":
lim0 = min(np.min(feat0), np.min(feat1))
lim1 = max(np.max(feat0), np.max(feat1))
bin_size = (lim1 - lim0) / 200.
bin_size = (lim1 - lim0) / 200.0
bins = np.arange(lim0, lim1, bin_size)

pdf0, _ = np.histogram(feat0, bins=bins, density=True)
pdf1, _ = np.histogram(feat1, bins=bins, density=True)
pdf0 *= bin_size
pdf1 *= bin_size
pdf1 *= bin_size
overlap = np.sum(np.minimum(pdf0, pdf1))

is_merge = overlap >= threshold_overlap

merge_value = 1 - overlap



else:
raise ValueError(f"bad criteria {criteria}")

Expand All @@ -522,13 +522,13 @@ def merge(
else:
final_shift = 0

# DEBUG = True
# DEBUG = True
DEBUG = False

if DEBUG and is_merge:
if DEBUG and not is_merge:
if DEBUG and (overlap > 0.05 and overlap <0.25):
if label0 == 49 and label1== 65:
# if DEBUG and not is_merge:
# if DEBUG and (overlap > 0.05 and overlap <0.25):
# if label0 == 49 and label1== 65:
import matplotlib.pyplot as plt

flatten_wfs0 = wfs0.swapaxes(1, 2).reshape(wfs0.shape[0], -1)
Expand All @@ -551,7 +551,6 @@ def merge(
count1, _ = np.histogram(feat1, bins=bins, density=True)
pdf0 = count0 * bin_size
pdf1 = count1 * bin_size


ax = axs[1]
ax.plot(bins[:-1], pdf0, color="C0")
Expand All @@ -564,13 +563,15 @@ def merge(
ax.axvline(l0, color="C0")
ax.axvline(l1, color="C1")
elif criteria == "distrib_overlap":
print(lim0, lim1, )
print(
lim0,
lim1,
)
ax.set_title(f"{overlap:.4f} {is_merge}")
ax.plot(bins[:-1], np.minimum(pdf0, pdf1), ls='--', color='k')
ax.plot(bins[:-1], np.minimum(pdf0, pdf1), ls="--", color="k")

plt.show()


return is_merge, label0, label1, final_shift, merge_value


Expand Down
18 changes: 9 additions & 9 deletions src/spikeinterface/sortingcomponents/clustering/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,12 @@ def split_function_wrapper(peak_indices):
return is_split, local_labels, peak_indices



class LocalFeatureClustering:
"""
This method is a refactorized mix between:
* old tridesclous code
* "herding_split()" in DART/spikepsvae by Charlie Windolf
The idea simple :
* agregate features (svd or even waveforms) with sparse channel.
* run a local feature reduction (pca or svd)
Expand All @@ -183,7 +182,6 @@ def split(
min_samples=25,
n_pca_features=2,
minimum_common_channels=2,

):
local_labels = np.zeros(peak_indices.size, dtype=np.int64)

Expand Down Expand Up @@ -218,8 +216,12 @@ def split(
final_features = TruncatedSVD(n_pca_features).fit_transform(flatten_features)

if clusterer == "hdbscan":
clust = HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples, allow_single_cluster=True,
cluster_selection_method="leaf")
clust = HDBSCAN(
min_cluster_size=min_cluster_size,
min_samples=min_samples,
allow_single_cluster=True,
cluster_selection_method="leaf",
)
clust.fit(final_features)
possible_labels = clust.labels_
is_split = np.setdiff1d(possible_labels, [-1]).size > 1
Expand All @@ -236,9 +238,7 @@ def split(
else:
raise ValueError(f"wrong clusterer {clusterer}")



# DEBUG = True
# DEBUG = True
DEBUG = False
if DEBUG:
import matplotlib.pyplot as plt
Expand All @@ -259,7 +259,7 @@ def split(

ax = axs[1]
ax.plot(flatten_wfs[mask][sl].T, color=colors[k], alpha=0.5)

axs[0].set_title(f"{clusterer} {is_split}")

plt.show()
Expand Down
1 change: 0 additions & 1 deletion src/spikeinterface/sortingcomponents/clustering/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def aggregate_sparse_features(peaks, peak_indices, sparse_feature, sparse_mask,
return aligned_features, dont_have_channels



def compute_template_from_sparse(
peaks, labels, labels_set, sparse_waveforms, sparse_mask, total_channels, peak_shifts=None
):
Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/sortingcomponents/tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

# no proper test at the moment this is used in tridesclous2


def test_merge():
pass

Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/sortingcomponents/tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

# no proper test at the moment this is used in tridesclous2


def test_split():
pass

Expand Down

0 comments on commit de2d642

Please sign in to comment.