Skip to content

Commit

Permalink
add plot_sankey method
Browse files Browse the repository at this point in the history
  • Loading branch information
d-schindler committed Apr 19, 2024
1 parent a830445 commit 8595f2e
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion src/pygenstability/data_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pygenstability.pygenstability import run as pgs_run
from pygenstability.plotting import plot_scan as pgs_plot_scan
from pygenstability.optimal_scales import identify_optimal_scales
from pygenstability.contrib.sankey import plot_sankey as pgs_plot_sankey


def compute_CkNN(D, k=5, delta=1):
Expand Down Expand Up @@ -124,7 +125,7 @@ def labels_(self):
# store labels of robust partitions
for i in self.results_["selected_partitions"]:

# only store non-trivial robust partitions
# only return non-trivial robust partitions
robust_partition = self.results_["community_id"][i]
if not np.allclose(robust_partition, np.zeros(self.adjacency_.shape[0])):
labels.append(robust_partition)
Expand Down Expand Up @@ -205,3 +206,29 @@ def plot_robust_partitions(
title=f"Robust Partion {m+1} (with {len(np.unique(partition))} clusters)",
)
plt.show()

def plot_sankey(
self,
optimal_scales=True,
live=False,
filename="communities_sankey.html",
scale_index=None,
):
"""Plot Sankey diagram."""

# plot non-trivial optimal scales only
if optimal_scales:
n_partitions = len(self.labels_)
# collect indices of non-trivial partitions
scale_index = self.results_["selected_partitions"][:n_partitions]

# plot Sankey diagram
fig = pgs_plot_sankey(
self.results_,
optimal_scales=False,
live=live,
filename=filename,
scale_index=scale_index,
)

return fig

0 comments on commit 8595f2e

Please sign in to comment.