diff --git a/wntr/graphics/__init__.py b/wntr/graphics/__init__.py index 5ea000ee9..7238bd7ea 100644 --- a/wntr/graphics/__init__.py +++ b/wntr/graphics/__init__.py @@ -1,7 +1,7 @@ """ The wntr.graphics package contains graphic functions """ -from wntr.graphics.network import plot_network, plot_network_gis, plot_interactive_network, plot_leaflet_network, network_animation +from wntr.graphics.network import plot_network, plot_network_nx, plot_interactive_network, plot_leaflet_network, network_animation from wntr.graphics.layer import plot_valve_layer from wntr.graphics.curve import plot_fragility_curve, plot_pump_curve, plot_tank_volume_curve from wntr.graphics.color import custom_colormap, random_colormap diff --git a/wntr/graphics/network.py b/wntr/graphics/network.py index 37529d3a7..c8d12ce78 100644 --- a/wntr/graphics/network.py +++ b/wntr/graphics/network.py @@ -7,7 +7,9 @@ import networkx as nx import pandas as pd import matplotlib.pyplot as plt +import matplotlib.path as mpath from matplotlib import animation +import numpy as np try: import plotly @@ -22,13 +24,56 @@ logger = logging.getLogger(__name__) + +arrow_verts = [ + (0.0, 0.0), + (0.5, 0.5), + (0.5, -0.5), + (0.0, 0.0), +] + +arrow_marker = mpath.Path(arrow_verts) + def _get_angle(line, loc=0.5): # calculate orientation angle p1 = line.interpolate(loc-0.01, normalized=True) p2 = line.interpolate(loc+0.01, normalized=True) - angle = math.atan2(p2.y-p1.y, p2.x - p1.x) + angle = math.atan2(p2.y-p1.y, p2.x - p1.x) # radians + angle = math.degrees(angle) return angle + +def _prepare_attribute(attribute, gdf): + kwds = {} + if attribute is not None: + # if dict convert to a series + if isinstance(attribute, dict): + attribute = pd.Series(attribute) + # if series add as a column to link gdf + if isinstance(attribute, pd.Series): + gdf["_attribute"] = attribute + kwds["column"] = "_attribute" + # if list, create new boolean column that captures which indices are in the list + # TODO need to check this with original behavior + elif isinstance(attribute, list): + gdf["_attribute"] = np.nan + gdf.loc[gdf.index.isin(attribute), "_attribute"] = 1 + kwds["column"] = "_attribute" + # if str, assert that column name exists + elif isinstance(attribute, str): + if attribute not in gdf.columns: + raise KeyError(f"attribute {attribute} does not exist.") + kwds["column"] = attribute + else: + raise TypeError("attribute must be dict, Series, list, or str") + else: + kwds["color"] = "black" + return kwds + + + + + # def _create_oriented_arrow(line, length=0.01) def _format_node_attribute(node_attribute, wn): @@ -53,7 +98,7 @@ def _format_link_attribute(link_attribute, wn): return link_attribute -def plot_network_gis( +def plot_network( wn, node_attribute=None, link_attribute=None, title=None, node_size=20, node_range=None, node_alpha=1, node_cmap=None, node_labels=False, link_width=1, link_range=None, link_alpha=1, link_cmap=None, link_labels=False, @@ -155,66 +200,37 @@ def plot_network_gis( if title is not None: ax.set_title(title) - # set aspect setting aspect = None - # aspect = "auto" - # aspect = "equal" - - # initialize gis objects - wn_gis = wn.to_gis() - - link_gdf = pd.concat((wn_gis.pipes, wn_gis.pumps, wn_gis.valves)) - node_gdf = pd.concat((wn_gis.junctions, wn_gis.tanks, wn_gis.reservoirs)) - - # missing keyword args - # these are used for elements that do not have a value for the link_attribute - # missing_kwds = {"color": "black"} - - # set tank and reservoir marker - tank_marker = "P" + tank_marker = "D" reservoir_marker = "s" - # colormap if link_cmap is None: link_cmap = plt.get_cmap('Spectral_r') if node_cmap is None: node_cmap = plt.get_cmap('Spectral_r') - # ranges if link_range is None: link_range = (None, None) if node_range is None: node_range = (None, None) + + wn_gis = wn.to_gis() + link_gdf = pd.concat((wn_gis.pipes, wn_gis.pumps, wn_gis.valves)) + node_gdf = pd.concat((wn_gis.junctions, wn_gis.tanks, wn_gis.reservoirs)) + # process link attribute + link_kwds = _prepare_attribute(link_attribute, link_gdf) - # prepare pipe plotting keywords - link_kwds = {} - if link_attribute is not None: - # if dict convert to a series - if isinstance(link_attribute, dict): - link_attribute = pd.Series(link_attribute) - # if series add as a column to link gdf - if isinstance(link_attribute, pd.Series): - link_gdf["_link_attribute"] = link_attribute - link_kwds["column"] = "_link_attribute" - # if list, create new boolean column that captures which indices are in the list - # TODO need to check this with original behavior - elif isinstance(link_attribute, list): - link_gdf["_link_attribute"] = link_gdf.index.isin(link_attribute).astype(int) - link_kwds["column"] = "_link_attribute" - # if str, assert that column name exists - elif isinstance(link_attribute, str): - if link_attribute not in link_gdf.columns: - raise KeyError(f"link_attribute {link_attribute} does not exist.") - link_kwds["column"] = link_attribute - else: - raise TypeError("link_attribute must be dict, Series, list, or str") + if isinstance(link_attribute, list): + link_kwds["column"] = "_attribute" + link_kwds["cmap"] = custom_colormap(2,("red", "red")) + link_kwds["legend"] = False + elif isinstance(link_attribute, (dict, pd.Series, str)): link_kwds["cmap"] = link_cmap - if add_colorbar: - link_kwds["legend"] = True link_kwds["vmin"] = link_range[0] link_kwds["vmax"] = link_range[1] + link_kwds["legend"] = add_colorbar else: link_kwds["color"] = "black" @@ -230,36 +246,21 @@ def plot_network_gis( link_cbar_kwds["shrink"] = 0.5 link_cbar_kwds["pad"] = 0.0 link_cbar_kwds["label"] = link_colorbar_label - - # prepare junctin plotting keywords - node_kwds = {} - if node_attribute is not None: - # if dict convert to a series - if isinstance(node_attribute, dict): - node_attribute = pd.Series(node_attribute) - # if series add as a column to node gdf - if isinstance(node_attribute, pd.Series): - node_gdf["_node_attribute"] = node_attribute - node_kwds["column"] = "_node_attribute" - # if list, create new boolean column that captures which indices are in the list - # TODO need to check this with original behavior - elif isinstance(node_attribute, list): - node_gdf["_node_attribute"] = node_gdf.index.isin(node_attribute).astype(int) - node_kwds["column"] = "_node_attribute" - # if str, assert that column name exists - elif isinstance(node_attribute, str): - if node_attribute not in node_gdf.columns: - raise KeyError(f"node_attribute {node_attribute} does not exist.") - node_kwds["column"] = node_attribute - else: - raise TypeError("node_attribute must be dict, Series, list, or str") + + # process node attribute + node_kwds = _prepare_attribute(node_attribute, node_gdf) + + if isinstance(node_attribute, list): + node_kwds["cmap"] = custom_colormap(2,("red", "red")) + node_kwds["legend"] = False + elif isinstance(node_attribute, (dict, pd.Series, str)): node_kwds["cmap"] = node_cmap - if add_colorbar: - node_kwds["legend"] = True node_kwds["vmin"] = node_range[0] node_kwds["vmax"] = node_range[1] + node_kwds["legend"] = add_colorbar else: node_kwds["color"] = "black" + node_kwds["alpha"] = node_alpha node_kwds["markersize"] = node_size @@ -268,8 +269,8 @@ def plot_network_gis( node_cbar_kwds["pad"] = 0.0 node_cbar_kwds["label"] = node_colorbar_label + # plot nodes - each type is plotted separately to allow for different marker types # plot junctions - # junction_mask node_gdf[node_gdf.node_type == "Junction"].plot( ax=ax, aspect=aspect, zorder=3, legend_kwds=node_cbar_kwds, **node_kwds) @@ -277,12 +278,12 @@ def plot_network_gis( node_kwds["legend"] = False # plot tanks - node_kwds["markersize"] = node_size * 1.5 + node_kwds["markersize"] = node_size * 2.0 node_gdf[node_gdf.node_type == "Tank"].plot( ax=ax, aspect=aspect, zorder=4, marker=tank_marker, **node_kwds) # plot reservoirs - node_kwds["markersize"] = node_size * 2.0 + node_kwds["markersize"] = node_size * 3.0 node_gdf[node_gdf.node_type == "Reservoir"].plot( ax=ax, aspect=aspect, zorder=5, marker=reservoir_marker, **node_kwds) @@ -295,24 +296,24 @@ def plot_network_gis( ax=ax, aspect=aspect, zorder=2, legend_kwds=link_cbar_kwds, **link_kwds) # plot pumps - # if len(wn_gis.pumps) >0: - # wn_gis.pumps.plot(ax=ax, color="purple", aspect=aspect) - # wn_gis.pumps["midpoint"] = wn_gis.pumps.geometry.interpolate(0.5, normalized=True) - # wn_gis.pumps["angle"] = wn_gis.pumps.apply(lambda row: _get_angle(row.geometry), axis=1) - # for idx , row in wn_gis.pumps.iterrows(): - # x,y = row["midpoint"].x, row["midpoint"].y - # angle = row["angle"] - # ax.scatter(x,y, color="purple", s=100, marker=(3,0, angle-90)) + if len(wn_gis.pumps) >0: + # wn_gis.pumps.plot(ax=ax, color="purple", aspect=aspect) + wn_gis.pumps["midpoint"] = wn_gis.pumps.geometry.interpolate(0.5, normalized=True) + wn_gis.pumps["angle"] = wn_gis.pumps.apply(lambda row: _get_angle(row.geometry), axis=1) + for idx , row in wn_gis.pumps.iterrows(): + x,y = row["midpoint"].x, row["midpoint"].y + angle = row["angle"] + ax.scatter(x,y, color="black", s=100, marker=(3, 0, angle-90)) + # ax.scatter(x,y, color="purple", s=100, marker=arrow_marker) # plot valves - # if len(wn_gis.valves) >0: - # # wn_gis.valves.plot(ax=ax, color="green", aspect=aspect) - # wn_gis.valves["midpoint"] = wn_gis.valves.geometry.interpolate(0.5, normalized=True) - # wn_gis.valves["angle"] = wn_gis.valves.apply(lambda row: _get_angle(row.geometry), axis=1) - # for idx , row in wn_gis.valves.iterrows(): - # x,y = row["midpoint"].x, row["midpoint"].y - # angle = row["angle"] - # ax.scatter(x,y, color="green", s=100, marker=(3,0, angle-90)) + if len(wn_gis.valves) >0: + wn_gis.valves["midpoint"] = wn_gis.valves.geometry.interpolate(0.5, normalized=True) + wn_gis.valves["angle"] = wn_gis.valves.apply(lambda row: _get_angle(row.geometry), axis=1) + for idx , row in wn_gis.valves.iterrows(): + x,y = row["midpoint"].x, row["midpoint"].y + angle = row["angle"] + ax.scatter(x,y, color="black", s=200, marker=(2,0, angle)) # annotation if node_labels: @@ -329,10 +330,13 @@ def plot_network_gis( if filename: plt.savefig(filename) + if show_plot is True: + plt.show(block=False) + return ax -def plot_network(wn, node_attribute=None, link_attribute=None, title=None, +def plot_network_nx(wn, node_attribute=None, link_attribute=None, title=None, node_size=20, node_range=[None,None], node_alpha=1, node_cmap=None, node_labels=False, link_width=1, link_range=[None,None], link_alpha=1, link_cmap=None, link_labels=False, add_colorbar=True, node_colorbar_label='Node', link_colorbar_label='Link',