From a830445e0638bc919bd57dd9984ddf6e58749616 Mon Sep 17 00:00:00 2001 From: d-schindler <60650591+d-schindler@users.noreply.github.com> Date: Thu, 18 Apr 2024 14:30:27 +0200 Subject: [PATCH] more improvements --- src/pygenstability/data_clustering.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/pygenstability/data_clustering.py b/src/pygenstability/data_clustering.py index 4c83e4d..c26df20 100644 --- a/src/pygenstability/data_clustering.py +++ b/src/pygenstability/data_clustering.py @@ -44,7 +44,7 @@ def __init__( # attributes self.adjacency_ = None - def fit(self, X): + def get_graph(self, X): """Construct graph from samples-by-features matrix.""" # if precomputed take X as adjacency matrix @@ -132,9 +132,10 @@ def labels_(self): return labels def fit(self, X): + """Construct graph and run PyGenStability for multiscale data clustering.""" # construct graph - self.adjacency_ = csr_matrix(super().fit(X)) + self.adjacency_ = csr_matrix(self.get_graph(X)) # run PyGenStability self.results_ = pgs_run(self.adjacency_, **self.pgs_kwargs) @@ -198,5 +199,9 @@ def plot_robust_partitions( ax.scatter(x_coord, y_coord, s=node_size, c=partition, zorder=10, cmap=cmap) # set labels - ax.set(xlabel="x", ylabel="y", title=f"Robust Partion {m+1}") + ax.set( + xlabel="x", + ylabel="y", + title=f"Robust Partion {m+1} (with {len(np.unique(partition))} clusters)", + ) plt.show()