Skip to content

Commit

Permalink
wip merge agglomerate_pairs graph
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Sep 14, 2023

Verified

This commit was signed with the committer’s verified signature. The key has expired.
mcansh Logan McAnsh
1 parent 8de9d53 commit 0e470c2
Showing 1 changed file with 26 additions and 11 deletions.
37 changes: 26 additions & 11 deletions src/spikeinterface/sortingcomponents/clustering/merge.py
Original file line number Diff line number Diff line change
@@ -158,6 +158,10 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full"
merges = []

graph = nx.from_numpy_matrix(pair_mask | pair_mask.T)
# put real nodes names for debugging
maps = dict(zip(np.arange(labels_set.size), labels_set))
graph = nx.relabel_nodes(graph, maps)

groups = list(nx.connected_components(graph))
for group in groups:
if len(group) == 1:
@@ -167,15 +171,17 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full"
cliques = list(nx.find_cliques(sub_graph))
if len(cliques) == 1 and len(cliques[0]) == len(group):
# the sub graph is full connected: no ambiguity
merges.append(labels_set[cliques[0]])
# merges.append(labels_set[cliques[0]])
merges.append(cliques[0])
elif len(cliques) > 1:
# the subgraph is not fully connected
if connection_mode == "full":
# node merge
pass
elif connection_mode == "partial":
group = list(group)
merges.append(labels_set[group])
# merges.append(labels_set[group])
merges.append(group)
elif connection_mode == "clique":
raise NotImplementedError
else:
@@ -190,8 +196,8 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full"
nx.draw_networkx(sub_graph)
plt.show()

# DEBUG = True
DEBUG = False
DEBUG = True
# DEBUG = False
if DEBUG:
import matplotlib.pyplot as plt

@@ -358,6 +364,13 @@ def merge(
chans1 = np.unique(peaks["channel_index"][inds1])
target_chans1 = np.flatnonzero(np.all(waveforms_sparse_mask[chans1, :], axis=0))

if inds0.size <40 or inds1.size <40:
is_merge = False
merge_value = 0
final_shift = 0
return is_merge, label0, label1, final_shift, merge_value


target_chans = np.intersect1d(target_chans0, target_chans1)

inds = np.concatenate([inds0, inds1])
@@ -444,11 +457,11 @@ def merge(
else:
final_shift = 0

# DEBUG = True
DEBUG = False
DEBUG = True
# DEBUG = False

# if DEBUG and is_merge:
if DEBUG:
if DEBUG and is_merge:
# if DEBUG:
import matplotlib.pyplot as plt

flatten_wfs0 = wfs0.swapaxes(1, 2).reshape(wfs0.shape[0], -1)
@@ -460,8 +473,8 @@ def merge(
ax.plot(flatten_wfs1.T, color="C1", alpha=0.01)
m0 = np.mean(flatten_wfs0, axis=0)
m1 = np.mean(flatten_wfs1, axis=0)
ax.plot(m0, color="C0", alpha=1, lw=4, label="label0")
ax.plot(m1, color="C1", alpha=1, lw=4, label="label1")
ax.plot(m0, color="C0", alpha=1, lw=4, label=f"{label0} {inds0.size}")
ax.plot(m1, color="C1", alpha=1, lw=4, label=f"{label1} {inds1.size}")

ax.legend()

@@ -474,7 +487,9 @@ def merge(
ax.plot(bins[:-1], count0, color="C0")
ax.plot(bins[:-1], count1, color="C1")

ax.set_title(f"{dipscore}")
ax.set_title(f"{dipscore:.4f} {is_merge}")
plt.show()


return is_merge, label0, label1, final_shift, merge_value

0 comments on commit 0e470c2

Please sign in to comment.