From 0e470c2cfb2327fd3f62a65d33825423b63d4175 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 14 Sep 2023 21:02:08 +0200 Subject: [PATCH] wip merge agglomerate_pairs graph --- .../sortingcomponents/clustering/merge.py | 37 +++++++++++++------ 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index d97fc84b5a..2e839ef0fc 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -158,6 +158,10 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" merges = [] graph = nx.from_numpy_matrix(pair_mask | pair_mask.T) + # put real nodes names for debugging + maps = dict(zip(np.arange(labels_set.size), labels_set)) + graph = nx.relabel_nodes(graph, maps) + groups = list(nx.connected_components(graph)) for group in groups: if len(group) == 1: @@ -167,7 +171,8 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" cliques = list(nx.find_cliques(sub_graph)) if len(cliques) == 1 and len(cliques[0]) == len(group): # the sub graph is full connected: no ambiguity - merges.append(labels_set[cliques[0]]) + # merges.append(labels_set[cliques[0]]) + merges.append(cliques[0]) elif len(cliques) > 1: # the subgraph is not fully connected if connection_mode == "full": @@ -175,7 +180,8 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" pass elif connection_mode == "partial": group = list(group) - merges.append(labels_set[group]) + # merges.append(labels_set[group]) + merges.append(group) elif connection_mode == "clique": raise NotImplementedError else: @@ -190,8 +196,8 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full" nx.draw_networkx(sub_graph) plt.show() - # DEBUG = True - DEBUG = False + DEBUG = True + # DEBUG = False if DEBUG: import matplotlib.pyplot as plt @@ -358,6 +364,13 @@ def merge( chans1 = np.unique(peaks["channel_index"][inds1]) target_chans1 = np.flatnonzero(np.all(waveforms_sparse_mask[chans1, :], axis=0)) + if inds0.size <40 or inds1.size <40: + is_merge = False + merge_value = 0 + final_shift = 0 + return is_merge, label0, label1, final_shift, merge_value + + target_chans = np.intersect1d(target_chans0, target_chans1) inds = np.concatenate([inds0, inds1]) @@ -444,11 +457,11 @@ def merge( else: final_shift = 0 - # DEBUG = True - DEBUG = False + DEBUG = True + # DEBUG = False - # if DEBUG and is_merge: - if DEBUG: + if DEBUG and is_merge: + # if DEBUG: import matplotlib.pyplot as plt flatten_wfs0 = wfs0.swapaxes(1, 2).reshape(wfs0.shape[0], -1) @@ -460,8 +473,8 @@ def merge( ax.plot(flatten_wfs1.T, color="C1", alpha=0.01) m0 = np.mean(flatten_wfs0, axis=0) m1 = np.mean(flatten_wfs1, axis=0) - ax.plot(m0, color="C0", alpha=1, lw=4, label="label0") - ax.plot(m1, color="C1", alpha=1, lw=4, label="label1") + ax.plot(m0, color="C0", alpha=1, lw=4, label=f"{label0} {inds0.size}") + ax.plot(m1, color="C1", alpha=1, lw=4, label=f"{label1} {inds1.size}") ax.legend() @@ -474,7 +487,9 @@ def merge( ax.plot(bins[:-1], count0, color="C0") ax.plot(bins[:-1], count1, color="C1") - ax.set_title(f"{dipscore}") + ax.set_title(f"{dipscore:.4f} {is_merge}") + plt.show() + return is_merge, label0, label1, final_shift, merge_value