diff --git a/wntr/graphics/network.py b/wntr/graphics/network.py index 1b270bba4..7b49704fd 100644 --- a/wntr/graphics/network.py +++ b/wntr/graphics/network.py @@ -3,10 +3,14 @@ water network model. """ import logging +import math import networkx as nx import pandas as pd import matplotlib.pyplot as plt +import matplotlib.path as mpath from matplotlib import animation +import matplotlib as mpl +import numpy as np try: import plotly @@ -21,6 +25,24 @@ logger = logging.getLogger(__name__) + +arrow_verts = [ + (0.0, 0.0), + (0.5, 0.5), + (0.5, -0.5), + (0.0, 0.0), +] + +arrow_marker = mpath.Path(arrow_verts) + +def _get_angle(line, loc=0.5): + # calculate orientation angle + p1 = line.interpolate(loc-0.01, normalized=True) + p2 = line.interpolate(loc+0.01, normalized=True) + angle = math.atan2(p2.y-p1.y, p2.x - p1.x) # radians + angle = math.degrees(angle) + return angle + def _format_node_attribute(node_attribute, wn): if isinstance(node_attribute, str): @@ -42,12 +64,13 @@ def _format_link_attribute(link_attribute, wn): link_attribute = dict(link_attribute) return link_attribute - -def plot_network(wn, node_attribute=None, link_attribute=None, title=None, - node_size=20, node_range=[None,None], node_alpha=1, node_cmap=None, node_labels=False, - link_width=1, link_range=[None,None], link_alpha=1, link_cmap=None, link_labels=False, - add_colorbar=True, node_colorbar_label='Node', link_colorbar_label='Link', - directed=False, ax=None, show_plot=True, filename=None): + +def plot_network( + wn, node_attribute=None, link_attribute=None, title=None, + node_size=20, node_range=None, node_alpha=1, node_cmap=None, node_labels=False, + link_width=1, link_range=None, link_alpha=1, link_cmap=None, link_labels=False, + add_colorbar=True, node_colorbar_label=None, link_colorbar_label=None, + directed=False, legend=False, ax=None, show_plot=True, filename=None): """ Plot network graphic @@ -126,7 +149,7 @@ def plot_network(wn, node_attribute=None, link_attribute=None, title=None, ax: matplotlib axes object, optional Axes for plotting (None indicates that a new figure with a single axes will be used) - + show_plot: bool, optional If True, show plot with plt.show() @@ -137,113 +160,182 @@ def plot_network(wn, node_attribute=None, link_attribute=None, title=None, ------- ax : matplotlib axes object """ - if ax is None: # create a new figure plt.figure(facecolor='w', edgecolor='k') ax = plt.gca() - # Graph - G = wn.to_graph() - if not directed: - G = G.to_undirected() - - # Position - pos = nx.get_node_attributes(G,'pos') - if len(pos) == 0: - pos = None - - # Define node properties - add_node_colorbar = add_colorbar + if title is not None: + ax.set_title(title) + + aspect = "equal" + + tank_marker = "D" + reservoir_marker = "s" + + if link_cmap is None: + link_cmap = plt.get_cmap('Spectral_r') + if node_cmap is None: + node_cmap = plt.get_cmap('Spectral_r') + + if link_range is None: + link_range = (None, None) + if node_range is None: + node_range = (None, None) + + # use attribute name if no other label is provided + if node_colorbar_label is None and isinstance(node_attribute, str): + node_colorbar_label = node_attribute + if link_colorbar_label is None and isinstance(link_attribute, str): + link_colorbar_label = link_attribute + + wn_gis = wn.to_gis() + # add node_type so that node assets can be plotted separately + wn_gis.junctions["node_type"] = "Junction" + wn_gis.tanks["node_type"] = "Tank" + wn_gis.reservoirs["node_type"] = "Reservoir" + link_gdf = pd.concat((wn_gis.pipes, wn_gis.pumps, wn_gis.valves)) + node_gdf = pd.concat((wn_gis.junctions, wn_gis.tanks, wn_gis.reservoirs)) + + # Node attribute + node_kwds = {} + node_cbar = add_colorbar if node_attribute is not None: + node_gdf["_attribute"] = _format_node_attribute(node_attribute, wn) + node_kwds["column"] = "_attribute" + # handle cbar/cmap if isinstance(node_attribute, list): - if node_cmap is None: - node_cmap = ['red', 'red'] - add_node_colorbar = False - - if node_cmap is None: - node_cmap = plt.get_cmap('Spectral_r') - elif isinstance(node_cmap, list): - if len(node_cmap) == 1: - node_cmap = node_cmap*2 - node_cmap = custom_colormap(len(node_cmap), node_cmap) - - node_attribute = _format_node_attribute(node_attribute, wn) - nodelist,nodecolor = zip(*node_attribute.items()) - + node_kwds["cmap"] = custom_colormap(2,["red", "red"]) + node_cbar = False + elif isinstance(node_attribute, (dict, pd.Series, str)): + node_kwds["cmap"] = node_cmap + + # manually extract min/max if no range is given + node_attribute_values = node_gdf[node_kwds["column"]] + if node_range[0] is None: + node_kwds["vmin"] = np.nanmin(node_attribute_values) + else: + node_kwds["vmin"] = node_range[0] + if node_range[1] is None: + node_kwds["vmax"] = np.nanmax(node_attribute_values) + else: + node_kwds["vmax"] = node_range[1] + else: + raise TypeError("attribute must be dict, Series, list, or str") else: - nodelist = None - nodecolor = 'k' + node_kwds["color"] = "black" + node_cbar = False + + node_kwds["alpha"] = node_alpha + node_kwds["markersize"] = node_size - add_link_colorbar = add_colorbar + node_cbar_kwds = {} + node_cbar_kwds["shrink"] = 0.5 + node_cbar_kwds["pad"] = 0.0 + node_cbar_kwds["alpha"] = node_alpha + node_cbar_kwds["label"] = node_colorbar_label + + # Link attribute + link_kwds = {} + link_cbar = add_colorbar if link_attribute is not None: + link_gdf["_attribute"] = pd.Series(_format_link_attribute(link_attribute, wn)) + link_kwds["column"] = "_attribute" + # handle cbar/cmap if isinstance(link_attribute, list): - if link_cmap is None: - link_cmap = ['red', 'red'] - add_link_colorbar = False - - if link_cmap is None: - link_cmap = plt.get_cmap('Spectral_r') - elif isinstance(link_cmap, list): - if len(link_cmap) == 1: - link_cmap = link_cmap*2 - link_cmap = custom_colormap(len(link_cmap), link_cmap) + link_kwds["cmap"] = custom_colormap(2,["red", "red"]) + link_cbar = False + elif isinstance(link_attribute, (dict, pd.Series, str)): + link_kwds["cmap"] = link_cmap - link_attribute = _format_link_attribute(link_attribute, wn) - - # Replace link_attribute dictionary defined as - # {link_name: attr} with {(start_node, end_node, link_name): attr} - attr = {} - for link_name, value in link_attribute.items(): - link = wn.get_link(link_name) - attr[(link.start_node_name, link.end_node_name, link_name)] = value - link_attribute = attr - - linklist,linkcolor = zip(*link_attribute.items()) + # manually extract min/max if no range is given + link_attribute_values = link_gdf[link_kwds["column"]] + if link_range[0] is None: + link_kwds["vmin"] = np.nanmin(link_attribute_values) + else: + link_kwds["vmin"] = link_range[0] + if link_range[1] is None: + link_kwds["vmax"] = np.nanmax(link_attribute_values) + else: + link_kwds["vmax"] = link_range[1] + else: + raise TypeError("attribute must be dict, Series, list, or str") else: - linklist = None - linkcolor = 'k' + link_kwds["color"] = "black" + link_cbar = False - if title is not None: - ax.set_title(title) - - edge_background = nx.draw_networkx_edges(G, pos, edge_color='grey', - width=0.5, ax=ax) + link_kwds["linewidth"] = link_width + link_kwds["alpha"] = link_alpha + + background_link_kwds = {} + background_link_kwds["color"] = "grey" + background_link_kwds["linewidth"] = link_width / 2 + background_link_kwds["alpha"] = link_alpha + + link_cbar_kwds = {} + link_cbar_kwds["shrink"] = 0.5 + link_cbar_kwds["pad"] = 0.05 + link_cbar_kwds["label"] = link_colorbar_label + link_cbar_kwds["alpha"] = link_alpha + + missing_node_kwds={"color": "black"} + missing_link_kwds={"color": "black"} + + # plot nodes - each type is plotted separately to allow for different marker types + node_gdf[node_gdf.node_type == "Junction"].plot( + ax=ax, aspect=aspect, zorder=3, legend=False, label="Junction", missing_kwds=missing_node_kwds, **node_kwds) + + node_kwds["markersize"] = node_size * 2.0 + node_gdf[node_gdf.node_type == "Tank"].plot( + ax=ax, aspect=aspect, zorder=4, marker=tank_marker, legend=False, label="Tank", missing_kwds=missing_node_kwds, **node_kwds) - nodes = nx.draw_networkx_nodes(G, pos, - nodelist=nodelist, node_color=nodecolor, node_size=node_size, - alpha=node_alpha, cmap=node_cmap, vmin=node_range[0], vmax = node_range[1], - linewidths=0, ax=ax) - edges = nx.draw_networkx_edges(G, pos, edgelist=linklist, arrows=directed, - edge_color=linkcolor, width=link_width, alpha=link_alpha, edge_cmap=link_cmap, - edge_vmin=link_range[0], edge_vmax=link_range[1], ax=ax) + node_kwds["markersize"] = node_size * 3.0 + node_gdf[node_gdf.node_type == "Reservoir"].plot( + ax=ax, aspect=aspect, zorder=5, marker=reservoir_marker, legend=False, label="Reservoir", missing_kwds=missing_node_kwds,**node_kwds) + + if node_cbar: + sm = mpl.cm.ScalarMappable(cmap=node_kwds["cmap"]) + sm.set_clim(node_kwds["vmin"], node_kwds["vmax"]) + + node_cbar = ax.figure.colorbar(sm, ax=ax, **node_cbar_kwds) + + # plot links + # background + link_gdf.plot( + ax=ax, aspect=aspect, zorder=1, legend=False, **background_link_kwds) + + # main plot + link_gdf.plot( + ax=ax, aspect=aspect, zorder=2, legend=False, missing_kwds=missing_link_kwds, **link_kwds) + + if link_cbar: + sm = mpl.cm.ScalarMappable(cmap=link_kwds["cmap"]) + sm.set_clim(link_kwds["vmin"], link_kwds["vmax"]) + + link_cbar = ax.figure.colorbar(sm, ax=ax, **link_cbar_kwds) + if node_labels: - labels = dict(zip(wn.node_name_list, wn.node_name_list)) - nx.draw_networkx_labels(G, pos, labels, font_size=7, ax=ax) + for x, y, label in zip(node_gdf.geometry.x, node_gdf.geometry.y, node_gdf.index): + ax.annotate(label, xy=(x, y))#, xytext=(3, 3),)# textcoords="offset points") + if link_labels: - labels = {} - for link_name in wn.link_name_list: - link = wn.get_link(link_name) - labels[(link.start_node_name, link.end_node_name)] = link_name - nx.draw_networkx_edge_labels(G, pos, labels, font_size=7, ax=ax) - if add_node_colorbar and node_attribute: - clb = plt.colorbar(nodes, shrink=0.5, pad=0, ax=ax) - clb.ax.set_title(node_colorbar_label, fontsize=10) - if add_link_colorbar and link_attribute: - if link_range[0] is None: - vmin = min(link_attribute.values()) - else: - vmin = link_range[0] - if link_range[1] is None: - vmax = max(link_attribute.values()) - else: - vmax = link_range[1] - sm = plt.cm.ScalarMappable(cmap=link_cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax)) - sm.set_array([]) - clb = plt.colorbar(sm, shrink=0.5, pad=0.05, ax=ax) - clb.ax.set_title(link_colorbar_label, fontsize=10) - + midpoints = link_gdf.geometry.apply(lambda x: x.interpolate(0.5, normalized=True)) + for x, y, label in zip(midpoints.geometry.x, midpoints.geometry.y, link_gdf.index): + ax.annotate(label, xy=(x, y))#, xytext=(3, 3),)# textcoords="offset points") + + if directed: + link_gdf["_midpoint"] = link_gdf.geometry.interpolate(0.5, normalized=True) + link_gdf["_angle"] = link_gdf.apply(lambda row: _get_angle(row.geometry), axis=1) + for idx , row in link_gdf.iterrows(): + x,y = row["_midpoint"].x, row["_midpoint"].y + angle = row["_angle"] + ax.scatter(x,y, color="black", s=50, marker=(3,0, angle-90)) + + if legend: + handles, labels = ax.get_legend_handles_labels() + leg = ax.legend(handles, labels, loc='upper right', title="Legend") + ax.axis('off') if filename: diff --git a/wntr/tests/test_graphics.py b/wntr/tests/test_graphics.py index 0374ebef3..8caed7e45 100644 --- a/wntr/tests/test_graphics.py +++ b/wntr/tests/test_graphics.py @@ -5,7 +5,12 @@ import warnings from os.path import abspath, dirname, isfile, join +import networkx as nx import matplotlib.pylab as plt +import matplotlib +from wntr.graphics.color import custom_colormap +import pandas as pd +import numpy as np import wntr testdir = dirname(abspath(str(__file__))) @@ -22,7 +27,6 @@ def test_plot_network1(self): inp_file = join(ex_datadir, "Net6.inp") wn = wntr.network.WaterNetworkModel(inp_file) - plt.figure() wntr.graphics.plot_network(wn) plt.savefig(filename, format="png") plt.close() @@ -37,7 +41,7 @@ def test_plot_network2(self): filename = abspath(join(testdir, "plot_network2_undirected.png")) if isfile(filename): os.remove(filename) - plt.figure() + wntr.graphics.plot_network( wn, node_attribute="elevation", link_attribute="length" ) @@ -50,7 +54,7 @@ def test_plot_network2(self): filename = abspath(join(testdir, "plot_network2_directed.png")) if isfile(filename): os.remove(filename) - plt.figure() + wntr.graphics.plot_network( wn, node_attribute="elevation", link_attribute="length", directed=True ) @@ -67,7 +71,6 @@ def test_plot_network3(self): inp_file = join(ex_datadir, "Net1.inp") wn = wntr.network.WaterNetworkModel(inp_file) - plt.figure() wntr.graphics.plot_network( wn, node_attribute=["11", "21"], @@ -87,7 +90,6 @@ def test_plot_network4(self): inp_file = join(ex_datadir, "Net1.inp") wn = wntr.network.WaterNetworkModel(inp_file) - plt.figure() wntr.graphics.plot_network( wn, node_attribute={"11": 5, "21": 10}, @@ -108,7 +110,6 @@ def test_plot_network5(self): wn = wntr.network.WaterNetworkModel(inp_file) pop = wntr.metrics.population(wn) - plt.figure() wntr.graphics.plot_network( wn, node_attribute=pop, node_range=[0, 500], title="Population" ) @@ -117,6 +118,95 @@ def test_plot_network5(self): self.assertTrue(isfile(filename)) + def test_plot_network6(self): + # legend + filename = abspath(join(testdir, "plot_network6.png")) + if isfile(filename): + os.remove(filename) + + inp_file = join(ex_datadir, "Net6.inp") + wn = wntr.network.WaterNetworkModel(inp_file) + + wntr.graphics.plot_network( + wn, node_attribute="elevation", link_attribute="diameter", + add_colorbar=True, legend=True + ) + plt.savefig(filename, format="png") + plt.close() + + self.assertTrue(isfile(filename)) + + def test_plot_network_options(self): + # NOTE:to compare with the old plot_network set compare=True. + # this should be set to false for regular testing + compare = False + + cmap = matplotlib.colormaps['viridis'] + + inp_file = join(ex_datadir, "Net6.inp") + wn = wntr.network.WaterNetworkModel(inp_file) + + random_node_values = pd.Series( + np.random.rand(len(wn.node_name_list)), index=wn.node_name_list) + random_link_values = pd.Series( + np.random.rand(len(wn.link_name_list)), index=wn.link_name_list) + random_pipe_values = pd.Series( + np.random.rand(len(wn.pipe_name_list)), index=wn.pipe_name_list) + random_node_dict_subset = dict(random_node_values.iloc[:10]) + random_link_dict_subset = dict(random_link_values.iloc[:10]) + node_list = list(wn.node_name_list[:10]) + link_list = list(wn.link_name_list[:10]) + + kwarg_list = [ + {"node_attribute": "elevation", + "node_range": [0,20], + "node_alpha": 0.5, + "node_colorbar_label": "test_label"}, + {"link_attribute": "diameter", + "link_range": [0,None], + "link_alpha": 0.5, + "link_colorbar_label": "test_label"}, + {"link_attribute": "diameter", + "node_attribute": "elevation"}, + {"node_labels": True, + "link_labels": True}, + {"node_attribute": "elevation", + "add_colorbar": False}, + {"link_attribute": "diameter", + "add_colorbar": False}, + {"node_attribute": node_list}, + {"node_attribute": random_node_values}, + {"node_attribute": random_node_dict_subset}, + {"link_attribute": link_list}, + {"link_attribute": random_link_values}, + {"link_attribute": random_link_dict_subset}, + {"directed": True}, + {"link_attribute": random_pipe_values, + "node_size": 0, + "link_cmap": cmap, + "link_range": [0,1], + "link_width": 1.5}, + ] + + for kwargs in kwarg_list: + filename = abspath(join(testdir, "plot_network_options.png")) + if isfile(filename): + os.remove(filename) + if compare: + fig, ax = plt.subplots(1,2) + wntr.graphics.plot_network(wn, ax=ax[0], title="GIS plot_network", **kwargs) + plot_network_nx(wn, ax=ax[1], title="NX plot_network", **kwargs) + fig.savefig(filename, format="png") + plt.close(fig) + else: + wntr.graphics.plot_network(wn, **kwargs) + plt.savefig(filename, format="png") + plt.close() + + self.assertTrue(isfile(filename)) + os.remove(filename) + + def test_plot_interactive_network1(self): filename = abspath(join(testdir, "plot_interactive_network1.html")) @@ -235,6 +325,241 @@ def test_custom_colormap(self): ) self.assertEqual(cmp.N, 3) self.assertEqual(cmp.name, "custom") + + +# old plotting function using networkx backend to compare with geopandas +def plot_network_nx(wn, node_attribute=None, link_attribute=None, title=None, + node_size=20, node_range=[None,None], node_alpha=1, node_cmap=None, node_labels=False, + link_width=1, link_range=[None,None], link_alpha=1, link_cmap=None, link_labels=False, + add_colorbar=True, node_colorbar_label='Node', link_colorbar_label='Link', + directed=False, ax=None, show_plot=True, filename=None): + """ + Plot network graphic + + Parameters + ---------- + wn : wntr WaterNetworkModel + A WaterNetworkModel object + + node_attribute : None, str, list, pd.Series, or dict, optional + + - If node_attribute is a string, then a node attribute dictionary is + created using node_attribute = wn.query_node_attribute(str) + - If node_attribute is a list, then each node in the list is given a + value of 1. + - If node_attribute is a pd.Series, then it should be in the format + {nodeid: x} where nodeid is a string and x is a float. + - If node_attribute is a dict, then it should be in the format + {nodeid: x} where nodeid is a string and x is a float + + link_attribute : None, str, list, pd.Series, or dict, optional + + - If link_attribute is a string, then a link attribute dictionary is + created using edge_attribute = wn.query_link_attribute(str) + - If link_attribute is a list, then each link in the list is given a + value of 1. + - If link_attribute is a pd.Series, then it should be in the format + {linkid: x} where linkid is a string and x is a float. + - If link_attribute is a dict, then it should be in the format + {linkid: x} where linkid is a string and x is a float. + + title: str, optional + Plot title + + node_size: int, optional + Node size + + node_range: list, optional + Node color range ([None,None] indicates autoscale) + + node_alpha: int, optional + Node transparency + + node_cmap: matplotlib.pyplot.cm colormap or list of named colors, optional + Node colormap + + node_labels: bool, optional + If True, the graph will include each node labelled with its name. + + link_width: int, optional + Link width + + link_range : list, optional + Link color range ([None,None] indicates autoscale) + + link_alpha : int, optional + Link transparency + + link_cmap: matplotlib.pyplot.cm colormap or list of named colors, optional + Link colormap + + link_labels: bool, optional + If True, the graph will include each link labelled with its name. + + add_colorbar: bool, optional + Add colorbar + + node_colorbar_label: str, optional + Node colorbar label + + link_colorbar_label: str, optional + Link colorbar label + + directed: bool, optional + If True, plot the directed graph + + ax: matplotlib axes object, optional + Axes for plotting (None indicates that a new figure with a single + axes will be used) + + show_plot: bool, optional + If True, show plot with plt.show() + + filename : str, optional + Filename used to save the figure + + Returns + ------- + ax : matplotlib axes object + """ + + def _format_node_attribute(node_attribute, wn): + + if isinstance(node_attribute, str): + node_attribute = wn.query_node_attribute(node_attribute) + if isinstance(node_attribute, list): + node_attribute = dict(zip(node_attribute,[1]*len(node_attribute))) + if isinstance(node_attribute, pd.Series): + node_attribute = dict(node_attribute) + + return node_attribute + + def _format_link_attribute(link_attribute, wn): + + if isinstance(link_attribute, str): + link_attribute = wn.query_link_attribute(link_attribute) + if isinstance(link_attribute, list): + link_attribute = dict(zip(link_attribute,[1]*len(link_attribute))) + if isinstance(link_attribute, pd.Series): + link_attribute = dict(link_attribute) + + return link_attribute + + if ax is None: # create a new figure + plt.figure(facecolor='w', edgecolor='k') + ax = plt.gca() + + # Graph + G = wn.to_graph() + if not directed: + G = G.to_undirected() + + # Position + pos = nx.get_node_attributes(G,'pos') + if len(pos) == 0: + pos = None + + # Define node properties + add_node_colorbar = add_colorbar + if node_attribute is not None: + + if isinstance(node_attribute, list): + if node_cmap is None: + node_cmap = ['red', 'red'] + add_node_colorbar = False + + if node_cmap is None: + node_cmap = plt.get_cmap('Spectral_r') + elif isinstance(node_cmap, list): + if len(node_cmap) == 1: + node_cmap = node_cmap*2 + node_cmap = custom_colormap(len(node_cmap), node_cmap) + + node_attribute = _format_node_attribute(node_attribute, wn) + nodelist,nodecolor = zip(*node_attribute.items()) + + else: + nodelist = None + nodecolor = 'k' + + add_link_colorbar = add_colorbar + if link_attribute is not None: + + if isinstance(link_attribute, list): + if link_cmap is None: + link_cmap = ['red', 'red'] + add_link_colorbar = False + + if link_cmap is None: + link_cmap = plt.get_cmap('Spectral_r') + elif isinstance(link_cmap, list): + if len(link_cmap) == 1: + link_cmap = link_cmap*2 + link_cmap = custom_colormap(len(link_cmap), link_cmap) + + link_attribute = _format_link_attribute(link_attribute, wn) + + # Replace link_attribute dictionary defined as + # {link_name: attr} with {(start_node, end_node, link_name): attr} + attr = {} + for link_name, value in link_attribute.items(): + link = wn.get_link(link_name) + attr[(link.start_node_name, link.end_node_name, link_name)] = value + link_attribute = attr + + linklist,linkcolor = zip(*link_attribute.items()) + else: + linklist = None + linkcolor = 'k' + + if title is not None: + ax.set_title(title) + + edge_background = nx.draw_networkx_edges(G, pos, edge_color='grey', + width=0.5, ax=ax) + + nodes = nx.draw_networkx_nodes(G, pos, + nodelist=nodelist, node_color=nodecolor, node_size=node_size, + alpha=node_alpha, cmap=node_cmap, vmin=node_range[0], vmax = node_range[1], + linewidths=0, ax=ax) + edges = nx.draw_networkx_edges(G, pos, edgelist=linklist, arrows=directed, + edge_color=linkcolor, width=link_width, alpha=link_alpha, edge_cmap=link_cmap, + edge_vmin=link_range[0], edge_vmax=link_range[1], ax=ax) + if node_labels: + labels = dict(zip(wn.node_name_list, wn.node_name_list)) + nx.draw_networkx_labels(G, pos, labels, font_size=7, ax=ax) + if link_labels: + labels = {} + for link_name in wn.link_name_list: + link = wn.get_link(link_name) + labels[(link.start_node_name, link.end_node_name)] = link_name + nx.draw_networkx_edge_labels(G, pos, labels, font_size=7, ax=ax) + if add_node_colorbar and node_attribute: + clb = plt.colorbar(nodes, shrink=0.5, pad=0, ax=ax) + clb.ax.set_title(node_colorbar_label, fontsize=10) + if add_link_colorbar and link_attribute: + if link_range[0] is None: + vmin = min(link_attribute.values()) + else: + vmin = link_range[0] + if link_range[1] is None: + vmax = max(link_attribute.values()) + else: + vmax = link_range[1] + sm = plt.cm.ScalarMappable(cmap=link_cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax)) + sm.set_array([]) + clb = plt.colorbar(sm, shrink=0.5, pad=0.05, ax=ax) + clb.ax.set_title(link_colorbar_label, fontsize=10) + + ax.axis('off') + + if filename: + plt.savefig(filename) + + if show_plot is True: + plt.show(block=False) + + return ax if __name__ == "__main__":