From 86e9c63d600d82f6ecb06432f067975be6d9353a Mon Sep 17 00:00:00 2001 From: alexis arnaudon Date: Sun, 14 Feb 2021 11:19:19 +0100 Subject: [PATCH 1/3] small corrections --- pygenstability/constructors.py | 6 +- pygenstability/plotting.py | 109 ++++++++++++------------------- pygenstability/pygenstability.py | 8 +-- 3 files changed, 51 insertions(+), 72 deletions(-) diff --git a/pygenstability/constructors.py b/pygenstability/constructors.py index 9384811..bb464f9 100644 --- a/pygenstability/constructors.py +++ b/pygenstability/constructors.py @@ -11,11 +11,13 @@ DTYPE = "float128" -def load_constructor(constructor): +def load_constructor(constructor, graph, **kwargs): """Load a constructor from its name, or as a custom Constructor class.""" if isinstance(constructor, str): + if graph is None: + raise Exception(f"No graph was provided with a generic constructor {constructor}") try: - return getattr(sys.modules[__name__], "constructor_%s" % constructor) + return getattr(sys.modules[__name__], "constructor_%s" % constructor(graph, **kwargs)) except AttributeError as exc: raise Exception("Could not load constructor %s" % constructor) from exc if not isinstance(constructor, Constructor): diff --git a/pygenstability/plotting.py b/pygenstability/plotting.py index 74e00c6..dd37a4b 100644 --- a/pygenstability/plotting.py +++ b/pygenstability/plotting.py @@ -10,10 +10,19 @@ from matplotlib import patches from tqdm import tqdm -L = logging.getLogger("pygenstability") +L = logging.getLogger(__name__) +# pylint: disable=import-outside-toplevel -def plot_scan(all_results, time_axis=True, figure_name="scan_results.pdf", use_plotly=True): + +def plot_scan( + all_results, + time_axis=True, + figure_name="scan_results.pdf", + use_plotly=True, + live=True, + plotly_filename="scan_results.html", +): """Plot results of pygenstability with matplotlib or plotly.""" if len(all_results["times"]) == 1: L.info("Cannot plot the results if only one time point, we display the result instead:") @@ -22,33 +31,31 @@ def plot_scan(all_results, time_axis=True, figure_name="scan_results.pdf", use_p if use_plotly: try: - plot_scan_plotly(all_results) + plot_scan_plotly(all_results, live=live, filename=plotly_filename) except ImportError: L.warning( "Plotly is not installed, please install package with \ pip install pygenstabiliy[plotly], using matplotlib instead." ) - - plot_scan_plt(all_results, time_axis=time_axis, figure_name=figure_name) + else: + plot_scan_plt(all_results, time_axis=time_axis, figure_name=figure_name) def plot_scan_plotly( # pylint: disable=too-many-branches,too-many-statements,too-many-locals all_results, + live=False, + filename="clusters.html", ): """Plot results of pygenstability with plotly.""" - # from plotly.subplots import make_subplots # pylint: disable=import-outside-toplevel - import plotly.graph_objects as go # pylint: disable=import-outside-toplevel + import plotly.graph_objects as go + from plotly.offline import plot as _plot if all_results["run_params"]["log_time"]: times = np.log10(all_results["times"]) else: times = all_results["times"] - hovertemplate = str( - "Time: %{x:.2f}" - + "
Number of communities: %{y}" - + "
%{text}" - ) + hovertemplate = str("Time: %{x:.2f},
%{text}") if "variation_information" in all_results: vi_data = all_results["variation_information"] @@ -62,8 +69,10 @@ def plot_scan_plotly( # pylint: disable=too-many-branches,too-many-statements,t vi_ticks = False text = [ - "Stability: {0:.3f},
Variation Information: {1:.3f},
Index: {2}".format(s, vi, i) - for s, vi, i in zip( + f"""Number of communities: {n},
Stability: {np.round(s, 3)}, +
Variation Information: {np.round(vi, 3)},
Index: {i}""" + for n, s, vi, i in zip( + all_results["number_of_communities"], all_results["stability"], vi_data, np.arange(0, len(times)), @@ -107,16 +116,16 @@ def plot_scan_plotly( # pylint: disable=too-many-branches,too-many-statements,t ), showscale=showscale, ) - - stab = go.Scatter( - x=times, - y=all_results["stability"], - mode="lines+markers", - hovertemplate=hovertemplate, - text=text, - name="Stability", - marker_color="blue", - ) + if "stability" in all_results: + stab = go.Scatter( + x=times, + y=all_results["stability"], + mode="lines+markers", + hovertemplate=hovertemplate, + text=text, + name="Stability", + marker_color="blue", + ) vi = go.Scatter( x=times, @@ -131,42 +140,18 @@ def plot_scan_plotly( # pylint: disable=too-many-branches,too-many-statements,t opacity=vi_opacity, ) - opt_criterion = go.Scatter( - x=times, - y=all_results["optimal_scale_criterion"], - mode="lines+markers", - hovertemplate=hovertemplate, - text=text, - name="Optimal Scale Criterion", - yaxis="y5", - xaxis="x3", - marker_color="orange", - ) - - opt_scale = go.Scatter( - x=times[all_results["selected_partitions"]], - y=np.zeros(len(all_results["selected_partitions"])), - mode="markers", - hovertemplate=hovertemplate, - text=text, - name="Optimal Scale", - yaxis="y5", - xaxis="x3", - marker_color="black", - ) - layout = go.Layout( yaxis=dict( title="Stability", titlefont=dict(color="blue"), tickfont=dict(color="blue"), - domain=[0.26, 0.49], + domain=[0.0, 0.28], ), yaxis2=dict( title=tprime_title, titlefont=dict(color="black"), tickfont=dict(color="black"), - domain=[0.51, 1], + domain=[0.32, 1], side="right", range=[times[0], times[-1]], ), @@ -184,22 +169,17 @@ def plot_scan_plotly( # pylint: disable=too-many-branches,too-many-statements,t tickfont=dict(color="red"), overlaying="y2", ), - yaxis5=dict( - title="Optimal Scale Criterion", - titlefont=dict(color="orange"), - tickfont=dict(color="orange"), - domain=[0, 0.24], - ), - xaxis=dict( - range=[times[0], times[-1]], - ), + xaxis=dict(range=[times[0], times[-1]]), xaxis2=dict(range=[times[0], times[-1]]), height=600, width=800, ) - fig = go.Figure(data=[stab, ncom, vi, ttprime, opt_criterion, opt_scale], layout=layout) - fig.show() + fig = go.Figure(data=[stab, ncom, vi, ttprime], layout=layout) + _plot(fig, filename=filename) + + if live: + fig.show() def plot_single_community( @@ -313,7 +293,6 @@ def plot_scan_plt(all_results, time_axis=True, figure_name="scan_results.svg"): ax1.set_xticks([]) plot_number_comm(all_results, ax=ax1, time_axis=time_axis) - if "ttprime" in all_results: ax1.yaxis.tick_right() ax1.yaxis.set_label_position("right") @@ -408,6 +387,7 @@ def plot_sankey(all_results, live=False, filename="communities_sankey.svg", time time_index (bool): plot time of indices """ import plotly.graph_objects as go + from plotly.offline import plot as _plot sources = [] targets = [] @@ -448,10 +428,7 @@ def plot_sankey(all_results, live=False, filename="communities_sankey.svg", time layout=layout, ) - try: - fig.write_image(filename) - except Exception: # pylint: disable=broad-except - print("Plotly figure cannot be saved, please install the relevant packages.") + _plot(fig, filename=filename) if live: fig.show() diff --git a/pygenstability/pygenstability.py b/pygenstability/pygenstability.py index fbf5f9a..d79c24f 100644 --- a/pygenstability/pygenstability.py +++ b/pygenstability/pygenstability.py @@ -61,7 +61,7 @@ def _get_params(all_locals): def run( - graph, + graph=None, constructor="linearized", min_time=-2.0, max_time=0.5, @@ -78,10 +78,10 @@ def run( n_workers=4, tqdm_disable=False, ): - """Main funtion to compute clustering at various time scales. + """Main function to compute clustering at various time scales. Args: - graph (scipy.csgraph): graph to cluster + graph (scipy.csgraph): graph to cluster, if None, the constructor cannot be a str constructor (str/function): name of the quality constructor, or custom constructor function. It must have two arguments, graph and time. min_time (float): minimum Markov time @@ -108,7 +108,7 @@ def run( log_time=log_time, times=times, ) - constructor = load_constructor(constructor)(graph, with_spectral_gap=with_spectral_gap) + constructor = load_constructor(constructor, graph, with_spectral_gap=with_spectral_gap) pool = multiprocessing.Pool(n_workers) L.info("Start loop over times...") From 5cc3c76e8de8fa768f3696a8bc182aa70c2664d2 Mon Sep 17 00:00:00 2001 From: alexis arnaudon Date: Sun, 14 Feb 2021 11:33:49 +0100 Subject: [PATCH 2/3] example cleanup --- examples/params.yaml | 23 ----------------------- examples/run_simple_example.sh | 7 ------- examples/simple_example.py | 5 +++-- pygenstability/constructors.py | 2 +- pygenstability/plotting.py | 2 +- 5 files changed, 5 insertions(+), 34 deletions(-) delete mode 100644 examples/params.yaml diff --git a/examples/params.yaml b/examples/params.yaml deleted file mode 100644 index 38a9f2c..0000000 --- a/examples/params.yaml +++ /dev/null @@ -1,23 +0,0 @@ -# time parmaeters -min_time: -2.0 -max_time: 0.5 -n_time: 20 -log_time: True - -n_runs: 100 # number of Louvain runs -n_workers: 2 # number of workers for parallel computing - -# quality constructor -#constructor: 'continuous_combinatorial' -constructor: 'continuous_linearized' -#constructor: 'continuous_normalized' -#constructor: 'signed_modularity' - -# additional parameters -compute_mutual_information: True -n_partitions: 20 # number of louvian runs for MI computation -compute_ttprime: True -apply_postprocessing: False #True - - - diff --git a/examples/run_simple_example.sh b/examples/run_simple_example.sh index 61b4707..29ecd77 100755 --- a/examples/run_simple_example.sh +++ b/examples/run_simple_example.sh @@ -1,4 +1,3 @@ -<<<<<<< Updated upstream #!/bin/bash python create_graph.py @@ -15,12 +14,6 @@ pygenstability run \ # sbm_graph.pkl pygenstability plot_scan --help -======= -#!/bin/zsh - -python create_graph.py -pygenstability run --n-time 100 sbm_graph.pkl ->>>>>>> Stashed changes pygenstability plot_scan results.pkl pygenstability plot_communities --help diff --git a/examples/simple_example.py b/examples/simple_example.py index a6aa47d..9cd0329 100644 --- a/examples/simple_example.py +++ b/examples/simple_example.py @@ -12,12 +12,13 @@ def simple_test(): all_results = run(graph) - plotting.plot_scan(all_results, use_plotly=True) + plotting.plot_scan(all_results, use_plotly=False) + plotting.plot_scan(all_results, use_plotly=True, live=False) with open("sbm_graph.gpickle", "rb") as pickle_file: graph = pickle.load(pickle_file) plotting.plot_communities(graph, all_results) - plotting.plot_sankey(all_results) + plotting.plot_sankey(all_results, live=False) if __name__ == "__main__": diff --git a/pygenstability/constructors.py b/pygenstability/constructors.py index bb464f9..0775c64 100644 --- a/pygenstability/constructors.py +++ b/pygenstability/constructors.py @@ -17,7 +17,7 @@ def load_constructor(constructor, graph, **kwargs): if graph is None: raise Exception(f"No graph was provided with a generic constructor {constructor}") try: - return getattr(sys.modules[__name__], "constructor_%s" % constructor(graph, **kwargs)) + return getattr(sys.modules[__name__], "constructor_%s" % constructor)(graph, **kwargs) except AttributeError as exc: raise Exception("Could not load constructor %s" % constructor) from exc if not isinstance(constructor, Constructor): diff --git a/pygenstability/plotting.py b/pygenstability/plotting.py index dd37a4b..4d175bc 100644 --- a/pygenstability/plotting.py +++ b/pygenstability/plotting.py @@ -377,7 +377,7 @@ def plot_clustered_adjacency( plt.savefig(figure_name, bbox_inches="tight") -def plot_sankey(all_results, live=False, filename="communities_sankey.svg", time_index=None): +def plot_sankey(all_results, live=False, filename="communities_sankey.html", time_index=None): """Plot Sankey diagram of communities accros time. Args: From cebf105ce79ed6b753128ac3de8b24b39463ed9d Mon Sep 17 00:00:00 2001 From: alexis arnaudon Date: Sun, 14 Feb 2021 11:51:56 +0100 Subject: [PATCH 3/3] docstring update --- pygenstability/constructors.py | 15 +++++++- pygenstability/plotting.py | 63 +++++++++++++++++++++++--------- pygenstability/pygenstability.py | 4 +- 3 files changed, 61 insertions(+), 21 deletions(-) diff --git a/pygenstability/constructors.py b/pygenstability/constructors.py index 0775c64..cc39691 100644 --- a/pygenstability/constructors.py +++ b/pygenstability/constructors.py @@ -55,10 +55,21 @@ def get_spectral_gap(laplacian): class Constructor: - """Parent constructor class.""" + """Parent constructor class. + + This class encodes method specific construction of quality matrix and null models. + Use the method prepare to load and compute time independent quantities, and the method get_data + to return quality matrix, null model, and possible global shift (for linearised stability). + """ def __init__(self, graph, with_spectral_gap=False, **kwargs): - """Initialise constructor.""" + """The constructor calls te prepare method upon initialisation. + + Args: + graph (csgraph): graph for which to run clustering + with_spectral_gap (bool): set to True to use spectral gap time rescale if available + kwargs (dict): any other properties to pass to the constructor. + """ self.graph = graph self.with_spectral_gap = with_spectral_gap self.spectral_gap = None diff --git a/pygenstability/plotting.py b/pygenstability/plotting.py index 4d175bc..3c79432 100644 --- a/pygenstability/plotting.py +++ b/pygenstability/plotting.py @@ -19,11 +19,20 @@ def plot_scan( all_results, time_axis=True, figure_name="scan_results.pdf", - use_plotly=True, + use_plotly=False, live=True, plotly_filename="scan_results.html", ): - """Plot results of pygenstability with matplotlib or plotly.""" + """Plot results of pygenstability with matplotlib or plotly. + + Args: + all_results (dict): results of pygenstability scan + time_axis (bool): display time of time index on time axis + figure_name (str): name of matplotlib figure + use_plotly (bool): use matplotlib or plotly backend + live (bool): for plotly backend, open browser with pot + plotly_filename (str): filename of .html figure from plotly + """ if len(all_results["times"]) == 1: L.info("Cannot plot the results if only one time point, we display the result instead:") L.info(all_results) @@ -31,17 +40,17 @@ def plot_scan( if use_plotly: try: - plot_scan_plotly(all_results, live=live, filename=plotly_filename) + _plot_scan_plotly(all_results, live=live, filename=plotly_filename) except ImportError: L.warning( "Plotly is not installed, please install package with \ pip install pygenstabiliy[plotly], using matplotlib instead." ) else: - plot_scan_plt(all_results, time_axis=time_axis, figure_name=figure_name) + _plot_scan_plt(all_results, time_axis=time_axis, figure_name=figure_name) -def plot_scan_plotly( # pylint: disable=too-many-branches,too-many-statements,too-many-locals +def _plot_scan_plotly( # pylint: disable=too-many-branches,too-many-statements,too-many-locals all_results, live=False, filename="clusters.html", @@ -185,7 +194,18 @@ def plot_scan_plotly( # pylint: disable=too-many-branches,too-many-statements,t def plot_single_community( graph, all_results, time_id, edge_color="0.5", edge_width=0.5, node_size=100 ): - """Plot the community structures for a given time.""" + """Plot the community structures for a given time. + + Args: + graph (networkx.Graph): graph to plot + all_results (dict): results of pygenstability scan + time_id (int): index of time to plot + folder (str): folder to save figures + edge_color (str): color of edges + edge_width (float): width of edges + node_size (float): size of nodes + ext (str): extension of figures files + """ pos = {u: graph.nodes[u]["pos"] for u in graph} node_color = all_results["community_id"][time_id] @@ -212,7 +232,16 @@ def plot_single_community( def plot_communities( graph, all_results, folder="communities", edge_color="0.5", edge_width=0.5, ext=".pdf" ): - """Plot the community structures at each time in a folder.""" + """Plot the community structures at each time in a folder. + + Args: + graph (networkx.Graph): graph to plot + all_results (dict): results of pygenstability scan + folder (str): folder to save figures + edge_color (str): color of edges + edge_width (float): width of edgs + ext (str): extension of figures files + """ if not os.path.isdir(folder): os.mkdir(folder) @@ -237,7 +266,7 @@ def _get_times(all_results, time_axis=True): return all_results["times"] -def plot_number_comm(all_results, ax, time_axis=True): +def _plot_number_comm(all_results, ax, time_axis=True): """Plot number of communities.""" times = _get_times(all_results, time_axis) @@ -246,7 +275,7 @@ def plot_number_comm(all_results, ax, time_axis=True): ax.tick_params("y", colors="C3") -def plot_ttprime(all_results, ax, time_axis): +def _plot_ttprime(all_results, ax, time_axis): """Plot ttprime.""" times = _get_times(all_results, time_axis) @@ -257,7 +286,7 @@ def plot_ttprime(all_results, ax, time_axis): ax.axis([times[0], times[-1], times[0], times[-1]]) -def plot_variation_information(all_results, ax, time_axis=True): +def _plot_variation_information(all_results, ax, time_axis=True): """Plot variation information.""" times = _get_times(all_results, time_axis=time_axis) ax.plot(times, all_results["variation_information"], "-", lw=2.0, c="C2", label="VI") @@ -269,7 +298,7 @@ def plot_variation_information(all_results, ax, time_axis=True): ax.axis([times[0], times[-1], 0.0, np.max(all_results["variation_information"]) * 1.1]) -def plot_stability(all_results, ax, time_axis=True): +def _plot_stability(all_results, ax, time_axis=True): """Plot stability.""" times = _get_times(all_results, time_axis=time_axis) ax.plot(times, all_results["stability"], "-", label=r"$Q$", c="C0") @@ -279,20 +308,20 @@ def plot_stability(all_results, ax, time_axis=True): ax.set_xlabel(r"$log_{10}(t)$") -def plot_scan_plt(all_results, time_axis=True, figure_name="scan_results.svg"): +def _plot_scan_plt(all_results, time_axis=True, figure_name="scan_results.svg"): """Plot results of pygenstability with matplotlib.""" gs = gridspec.GridSpec(2, 1, height_ratios=[1.0, 0.5]) gs.update(hspace=0) if "ttprime" in all_results: ax0 = plt.subplot(gs[0, 0]) - plot_ttprime(all_results, ax=ax0, time_axis=time_axis) + _plot_ttprime(all_results, ax=ax0, time_axis=time_axis) ax1 = ax0.twinx() else: ax1 = plt.subplot(gs[0, 0]) ax1.set_xticks([]) - plot_number_comm(all_results, ax=ax1, time_axis=time_axis) + _plot_number_comm(all_results, ax=ax1, time_axis=time_axis) if "ttprime" in all_results: ax1.yaxis.tick_right() ax1.yaxis.set_label_position("right") @@ -300,11 +329,11 @@ def plot_scan_plt(all_results, time_axis=True, figure_name="scan_results.svg"): ax2 = plt.subplot(gs[1, 0]) if "stability" in all_results: - plot_stability(all_results, ax=ax2, time_axis=time_axis) + _plot_stability(all_results, ax=ax2, time_axis=time_axis) if "variation_information" in all_results: ax3 = ax2.twinx() - plot_variation_information(all_results, ax=ax3, time_axis=time_axis) + _plot_variation_information(all_results, ax=ax3, time_axis=time_axis) plt.savefig(figure_name, bbox_inches="tight") @@ -378,7 +407,7 @@ def plot_clustered_adjacency( def plot_sankey(all_results, live=False, filename="communities_sankey.html", time_index=None): - """Plot Sankey diagram of communities accros time. + """Plot Sankey diagram of communities accros time (plotly only). Args: all_results (dict): results from run function diff --git a/pygenstability/pygenstability.py b/pygenstability/pygenstability.py index d79c24f..397b33a 100644 --- a/pygenstability/pygenstability.py +++ b/pygenstability/pygenstability.py @@ -116,7 +116,7 @@ def run( all_results["run_params"] = run_params for time in tqdm(times, disable=tqdm_disable): quality_matrix, null_model, global_shift = constructor.get_data(time) - louvain_results = run_several_louvains( + louvain_results = _run_several_louvains( quality_matrix, null_model, global_shift, n_louvain, pool ) communities = _process_louvain_run(time, louvain_results, all_results) @@ -227,7 +227,7 @@ def _evaluate_quality(partition_id, qualities_index, null_model, global_shift): return quality -def run_several_louvains(quality_matrix, null_model, global_shift, n_runs, pool): +def _run_several_louvains(quality_matrix, null_model, global_shift, n_runs, pool): """Run several louvain on the current quality matrix.""" quality_indices, quality_values = _to_indices(quality_matrix) worker = partial(