Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudon committed Apr 29, 2024
1 parent e09c18c commit ed8f0d4
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 58 deletions.
2 changes: 1 addition & 1 deletion src/pygenstability/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Import main functions."""

from pygenstability.constructors import *
from pygenstability.data_clustering import DataClustering
from pygenstability.io import load_results
from pygenstability.io import save_results
from pygenstability.optimal_scales import identify_optimal_scales
from pygenstability.plotting import *
from pygenstability.pygenstability import evaluate_NVI
from pygenstability.pygenstability import run
from pygenstability.data_clustering import DataClustering
52 changes: 27 additions & 25 deletions src/pygenstability/data_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@

import matplotlib.pyplot as plt
import numpy as np

from scipy.spatial.distance import pdist, squareform
from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.spatial.distance import pdist
from scipy.spatial.distance import squareform
from sklearn.neighbors import kneighbors_graph

from pygenstability.pygenstability import run as pgs_run
from pygenstability.plotting import plot_scan as pgs_plot_scan
from pygenstability.optimal_scales import identify_optimal_scales
from pygenstability.contrib.sankey import plot_sankey as pgs_plot_sankey
from pygenstability.optimal_scales import identify_optimal_scales
from pygenstability.plotting import plot_scan as pgs_plot_scan
from pygenstability.pygenstability import run as pgs_run


def _compute_CkNN(D, k=5, delta=1):
Expand Down Expand Up @@ -49,9 +49,7 @@ def get_graph(self, X):
"""Construct graph from samples-by-features matrix."""
# if precomputed take X as adjacency matrix
if self.method == "precomputed":
assert (
X.shape[0] == X.shape[1]
), "Precomputed matrix should be a square matrix."
assert X.shape[0] == X.shape[1], "Precomputed matrix should be a square matrix."
self.adjacency_ = X
return self.adjacency_

Expand All @@ -67,9 +65,7 @@ def get_graph(self, X):
sparse = _compute_CkNN(D_norm, self.k, self.delta)

elif self.method == "knn-mst":
sparse = kneighbors_graph(
D_norm, n_neighbors=self.k, metric="precomputed"
).toarray()
sparse = kneighbors_graph(D_norm, n_neighbors=self.k, metric="precomputed").toarray()

# undirected distance backbone is given by sparse graph and MST
mst = minimum_spanning_tree(D_norm)
Expand Down Expand Up @@ -200,6 +196,7 @@ def fit(self, X):
-----------
X : {array-like, sparse matrix} of shape (n_samples,n_features) or \
(n_samples,n_samples) if graph_method='precomputed'
Data to fit
Returns:
-------
Expand All @@ -214,9 +211,7 @@ def fit(self, X):

return self

def scale_selection(
self, kernel_size=0.1, window_size=0.1, max_nvi=1, basin_radius=0.01
):
def scale_selection(self, kernel_size=0.1, window_size=0.1, max_nvi=1, basin_radius=0.01):
"""Identify optimal scales [3].
Parameters:
Expand Down Expand Up @@ -252,9 +247,7 @@ def scale_selection(
if window_size < 1:
window_size = int(window_size * self.results_["run_params"]["n_scale"])
if basin_radius < 1:
basin_radius = max(
1, int(basin_radius * self.results_["run_params"]["n_scale"])
)
basin_radius = max(1, int(basin_radius * self.results_["run_params"]["n_scale"]))

# apply scale selection algorithm
self.results_ = identify_optimal_scales(
Expand All @@ -267,15 +260,15 @@ def scale_selection(

return self.labels_

def plot_scan(self):
def plot_scan(self, *args, **kwargs):
"""Plot summary figure for PyGenStability scan."""
if self.results_ is None:
return

pgs_plot_scan(self.results_)
pgs_plot_scan(self.results_, *args, **kwargs)

def plot_robust_partitions(
self, x_coord, y_coord, edge_width=1.0, node_size=20.0, cmap="tab20"
self, x_coord, y_coord, edge_width=1.0, node_size=20.0, cmap="tab20", show=True
):
"""Plot robust partitions with graph layout.
Expand All @@ -293,13 +286,21 @@ def plot_robust_partitions(
node_size : float, default=20.0
Node size in graph. This parameter is expected to be positive.
cmap : str, default:'tab20'
cmap : str, default='tab20'
Color map for cluster colors.
show : bool, default=True
Show the figures.
Returns:
--------
figs : list of all matplotlib figures created
"""
figs = []
for m, partition in enumerate(self.labels_):

# plot
_, ax = plt.subplots(1, figsize=(10, 10))
fig, ax = plt.subplots(1, figsize=(10, 10))
figs.append(fig)

# plot edges
for i in range(self.adjacency_.shape[0]):
Expand All @@ -322,6 +323,7 @@ def plot_robust_partitions(
ylabel="y",
title=f"Robust Partion {m+1} (with {len(np.unique(partition))} clusters)",
)
if show:
plt.show()

def plot_sankey(
Expand Down
44 changes: 12 additions & 32 deletions src/pygenstability/pygenstability.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,7 @@ def _get_constructor_data(constructor, scales, pool, tqdm_disable=False):

def _check_method(method): # pragma: no cover
if _NO_LEIDEN and _NO_LOUVAIN:
raise Exception(
"Without Louvain or Leiden solver, we cannot run PyGenStability"
)
raise Exception("Without Louvain or Leiden solver, we cannot run PyGenStability")

if method == "louvain" and _NO_LOUVAIN:
print("Louvain is not available, we fallback to leiden.")
Expand Down Expand Up @@ -249,17 +247,13 @@ def run(
communities = _process_runs(t, results, all_results)

if with_NVI:
_compute_NVI(
communities, all_results, pool, n_partitions=min(n_NVI, n_tries)
)
_compute_NVI(communities, all_results, pool, n_partitions=min(n_NVI, n_tries))

save_results(all_results, filename=result_file)

if with_postprocessing:
L.info("Apply postprocessing...")
_apply_postprocessing(
all_results, pool, constructor_data, tqdm_disable, method=method
)
_apply_postprocessing(all_results, pool, constructor_data, tqdm_disable, method=method)

if with_ttprime or with_optimal_scales:
L.info("Compute ttprimes...")
Expand All @@ -273,9 +267,7 @@ def run(
"window_size": max(2, int(0.1 * n_scale)),
"basin_radius": max(1, int(0.01 * n_scale)),
}
all_results = identify_optimal_scales(
all_results, **optimal_scales_kwargs
)
all_results = identify_optimal_scales(all_results, **optimal_scales_kwargs)

save_results(all_results, filename=result_file)

Expand Down Expand Up @@ -304,9 +296,7 @@ def _compute_NVI(communities, all_results, pool, n_partitions=10):
worker = partial(evaluate_NVI, partitions=selected_partitions)
index_pairs = [[i, j] for i in range(n_partitions) for j in range(n_partitions)]
chunksize = _get_chunksize(len(index_pairs), pool)
all_results["NVI"].append(
np.mean(list(pool.imap(worker, index_pairs, chunksize=chunksize)))
)
all_results["NVI"].append(np.mean(list(pool.imap(worker, index_pairs, chunksize=chunksize))))


def evaluate_NVI(index_pair, partitions):
Expand Down Expand Up @@ -351,9 +341,7 @@ def _to_indices(matrix, directed=False):


@_timing
def _optimise(
_, quality_indices, quality_values, null_model, global_shift, method="louvain"
):
def _optimise(_, quality_indices, quality_values, null_model, global_shift, method="louvain"):
"""Worker for generalized Markov Stability optimisation runs."""
if method == "louvain":
stability, community_id = generalized_louvain.run_louvain(
Expand Down Expand Up @@ -393,9 +381,7 @@ def _optimise(
return stability + global_shift, community_id


def _evaluate_quality(
partition_id, qualities_index, null_model, global_shift, method="louvain"
):
def _evaluate_quality(partition_id, qualities_index, null_model, global_shift, method="louvain"):
"""Worker for generalized Markov Stability optimisation runs."""
if method == "louvain":
quality = generalized_louvain.evaluate_quality(
Expand Down Expand Up @@ -452,18 +438,14 @@ def _compute_ttprime(all_results, pool):
chunksize = _get_chunksize(len(index_pairs), pool)
ttprime_list = pool.map(worker, index_pairs, chunksize=chunksize)

all_results["ttprime"] = np.zeros(
[len(all_results["scales"]), len(all_results["scales"])]
)
all_results["ttprime"] = np.zeros([len(all_results["scales"]), len(all_results["scales"])])
for i, ttp in enumerate(ttprime_list):
all_results["ttprime"][index_pairs[i][0], index_pairs[i][1]] = ttp
all_results["ttprime"] += all_results["ttprime"].T


@_timing
def _apply_postprocessing(
all_results, pool, constructors, tqdm_disable=False, method="louvain"
):
def _apply_postprocessing(all_results, pool, constructors, tqdm_disable=False, method="louvain"):
"""Apply postprocessing."""
all_results_raw = all_results.copy()

Expand All @@ -485,10 +467,8 @@ def _apply_postprocessing(
)
)

all_results["community_id"][i] = all_results_raw["community_id"][
all_results["community_id"][i] = all_results_raw["community_id"][best_quality_id]
all_results["stability"][i] = all_results_raw["stability"][best_quality_id]
all_results["number_of_communities"][i] = all_results_raw["number_of_communities"][
best_quality_id
]
all_results["stability"][i] = all_results_raw["stability"][best_quality_id]
all_results["number_of_communities"][i] = all_results_raw[
"number_of_communities"
][best_quality_id]

0 comments on commit ed8f0d4

Please sign in to comment.