diff --git a/src/pygenstability/data_clustering.py b/src/pygenstability/data_clustering.py index c26df20..ec9e79b 100644 --- a/src/pygenstability/data_clustering.py +++ b/src/pygenstability/data_clustering.py @@ -11,6 +11,7 @@ from pygenstability.pygenstability import run as pgs_run from pygenstability.plotting import plot_scan as pgs_plot_scan from pygenstability.optimal_scales import identify_optimal_scales +from pygenstability.contrib.sankey import plot_sankey as pgs_plot_sankey def compute_CkNN(D, k=5, delta=1): @@ -124,7 +125,7 @@ def labels_(self): # store labels of robust partitions for i in self.results_["selected_partitions"]: - # only store non-trivial robust partitions + # only return non-trivial robust partitions robust_partition = self.results_["community_id"][i] if not np.allclose(robust_partition, np.zeros(self.adjacency_.shape[0])): labels.append(robust_partition) @@ -205,3 +206,29 @@ def plot_robust_partitions( title=f"Robust Partion {m+1} (with {len(np.unique(partition))} clusters)", ) plt.show() + + def plot_sankey( + self, + optimal_scales=True, + live=False, + filename="communities_sankey.html", + scale_index=None, + ): + """Plot Sankey diagram.""" + + # plot non-trivial optimal scales only + if optimal_scales: + n_partitions = len(self.labels_) + # collect indices of non-trivial partitions + scale_index = self.results_["selected_partitions"][:n_partitions] + + # plot Sankey diagram + fig = pgs_plot_sankey( + self.results_, + optimal_scales=False, + live=live, + filename=filename, + scale_index=scale_index, + ) + + return fig