Skip to content

Commit

Permalink
set geopandas backend to the plot_network function name. update the g…
Browse files Browse the repository at this point in the history
…eopandas backend to handle lists for attributes and add shapes for valves and pumps.
  • Loading branch information
kbonney committed Oct 21, 2024
1 parent 487511a commit 9bc07cf
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 90 deletions.
2 changes: 1 addition & 1 deletion wntr/graphics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
The wntr.graphics package contains graphic functions
"""
from wntr.graphics.network import plot_network, plot_network_gis, plot_interactive_network, plot_leaflet_network, network_animation
from wntr.graphics.network import plot_network, plot_network_nx, plot_interactive_network, plot_leaflet_network, network_animation
from wntr.graphics.layer import plot_valve_layer
from wntr.graphics.curve import plot_fragility_curve, plot_pump_curve, plot_tank_volume_curve
from wntr.graphics.color import custom_colormap, random_colormap
Expand Down
182 changes: 93 additions & 89 deletions wntr/graphics/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import networkx as nx
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.path as mpath
from matplotlib import animation
import numpy as np

try:
import plotly
Expand All @@ -22,13 +24,56 @@

logger = logging.getLogger(__name__)


arrow_verts = [
(0.0, 0.0),
(0.5, 0.5),
(0.5, -0.5),
(0.0, 0.0),
]

arrow_marker = mpath.Path(arrow_verts)

def _get_angle(line, loc=0.5):
# calculate orientation angle
p1 = line.interpolate(loc-0.01, normalized=True)
p2 = line.interpolate(loc+0.01, normalized=True)
angle = math.atan2(p2.y-p1.y, p2.x - p1.x)
angle = math.atan2(p2.y-p1.y, p2.x - p1.x) # radians
angle = math.degrees(angle)
return angle


def _prepare_attribute(attribute, gdf):
kwds = {}
if attribute is not None:
# if dict convert to a series
if isinstance(attribute, dict):
attribute = pd.Series(attribute)
# if series add as a column to link gdf
if isinstance(attribute, pd.Series):
gdf["_attribute"] = attribute
kwds["column"] = "_attribute"
# if list, create new boolean column that captures which indices are in the list
# TODO need to check this with original behavior
elif isinstance(attribute, list):
gdf["_attribute"] = np.nan
gdf.loc[gdf.index.isin(attribute), "_attribute"] = 1
kwds["column"] = "_attribute"
# if str, assert that column name exists
elif isinstance(attribute, str):
if attribute not in gdf.columns:
raise KeyError(f"attribute {attribute} does not exist.")
kwds["column"] = attribute
else:
raise TypeError("attribute must be dict, Series, list, or str")
else:
kwds["color"] = "black"
return kwds





# def _create_oriented_arrow(line, length=0.01)

def _format_node_attribute(node_attribute, wn):
Expand All @@ -53,7 +98,7 @@ def _format_link_attribute(link_attribute, wn):

return link_attribute

def plot_network_gis(
def plot_network(
wn, node_attribute=None, link_attribute=None, title=None,
node_size=20, node_range=None, node_alpha=1, node_cmap=None, node_labels=False,
link_width=1, link_range=None, link_alpha=1, link_cmap=None, link_labels=False,
Expand Down Expand Up @@ -155,66 +200,37 @@ def plot_network_gis(
if title is not None:
ax.set_title(title)

# set aspect setting
aspect = None
# aspect = "auto"
# aspect = "equal"

# initialize gis objects
wn_gis = wn.to_gis()

link_gdf = pd.concat((wn_gis.pipes, wn_gis.pumps, wn_gis.valves))

node_gdf = pd.concat((wn_gis.junctions, wn_gis.tanks, wn_gis.reservoirs))

# missing keyword args
# these are used for elements that do not have a value for the link_attribute
# missing_kwds = {"color": "black"}

# set tank and reservoir marker
tank_marker = "P"
tank_marker = "D"
reservoir_marker = "s"

# colormap
if link_cmap is None:
link_cmap = plt.get_cmap('Spectral_r')
if node_cmap is None:
node_cmap = plt.get_cmap('Spectral_r')

# ranges
if link_range is None:
link_range = (None, None)
if node_range is None:
node_range = (None, None)

wn_gis = wn.to_gis()
link_gdf = pd.concat((wn_gis.pipes, wn_gis.pumps, wn_gis.valves))
node_gdf = pd.concat((wn_gis.junctions, wn_gis.tanks, wn_gis.reservoirs))

# process link attribute
link_kwds = _prepare_attribute(link_attribute, link_gdf)

# prepare pipe plotting keywords
link_kwds = {}
if link_attribute is not None:
# if dict convert to a series
if isinstance(link_attribute, dict):
link_attribute = pd.Series(link_attribute)
# if series add as a column to link gdf
if isinstance(link_attribute, pd.Series):
link_gdf["_link_attribute"] = link_attribute
link_kwds["column"] = "_link_attribute"
# if list, create new boolean column that captures which indices are in the list
# TODO need to check this with original behavior
elif isinstance(link_attribute, list):
link_gdf["_link_attribute"] = link_gdf.index.isin(link_attribute).astype(int)
link_kwds["column"] = "_link_attribute"
# if str, assert that column name exists
elif isinstance(link_attribute, str):
if link_attribute not in link_gdf.columns:
raise KeyError(f"link_attribute {link_attribute} does not exist.")
link_kwds["column"] = link_attribute
else:
raise TypeError("link_attribute must be dict, Series, list, or str")
if isinstance(link_attribute, list):
link_kwds["column"] = "_attribute"
link_kwds["cmap"] = custom_colormap(2,("red", "red"))
link_kwds["legend"] = False
elif isinstance(link_attribute, (dict, pd.Series, str)):
link_kwds["cmap"] = link_cmap
if add_colorbar:
link_kwds["legend"] = True
link_kwds["vmin"] = link_range[0]
link_kwds["vmax"] = link_range[1]
link_kwds["legend"] = add_colorbar
else:
link_kwds["color"] = "black"

Expand All @@ -230,36 +246,21 @@ def plot_network_gis(
link_cbar_kwds["shrink"] = 0.5
link_cbar_kwds["pad"] = 0.0
link_cbar_kwds["label"] = link_colorbar_label

# prepare junctin plotting keywords
node_kwds = {}
if node_attribute is not None:
# if dict convert to a series
if isinstance(node_attribute, dict):
node_attribute = pd.Series(node_attribute)
# if series add as a column to node gdf
if isinstance(node_attribute, pd.Series):
node_gdf["_node_attribute"] = node_attribute
node_kwds["column"] = "_node_attribute"
# if list, create new boolean column that captures which indices are in the list
# TODO need to check this with original behavior
elif isinstance(node_attribute, list):
node_gdf["_node_attribute"] = node_gdf.index.isin(node_attribute).astype(int)
node_kwds["column"] = "_node_attribute"
# if str, assert that column name exists
elif isinstance(node_attribute, str):
if node_attribute not in node_gdf.columns:
raise KeyError(f"node_attribute {node_attribute} does not exist.")
node_kwds["column"] = node_attribute
else:
raise TypeError("node_attribute must be dict, Series, list, or str")

# process node attribute
node_kwds = _prepare_attribute(node_attribute, node_gdf)

if isinstance(node_attribute, list):
node_kwds["cmap"] = custom_colormap(2,("red", "red"))
node_kwds["legend"] = False
elif isinstance(node_attribute, (dict, pd.Series, str)):
node_kwds["cmap"] = node_cmap
if add_colorbar:
node_kwds["legend"] = True
node_kwds["vmin"] = node_range[0]
node_kwds["vmax"] = node_range[1]
node_kwds["legend"] = add_colorbar
else:
node_kwds["color"] = "black"

node_kwds["alpha"] = node_alpha
node_kwds["markersize"] = node_size

Expand All @@ -268,21 +269,21 @@ def plot_network_gis(
node_cbar_kwds["pad"] = 0.0
node_cbar_kwds["label"] = node_colorbar_label

# plot nodes - each type is plotted separately to allow for different marker types
# plot junctions
# junction_mask
node_gdf[node_gdf.node_type == "Junction"].plot(
ax=ax, aspect=aspect, zorder=3, legend_kwds=node_cbar_kwds, **node_kwds)

# turn off legend for subsequent node plots
node_kwds["legend"] = False

# plot tanks
node_kwds["markersize"] = node_size * 1.5
node_kwds["markersize"] = node_size * 2.0
node_gdf[node_gdf.node_type == "Tank"].plot(
ax=ax, aspect=aspect, zorder=4, marker=tank_marker, **node_kwds)

# plot reservoirs
node_kwds["markersize"] = node_size * 2.0
node_kwds["markersize"] = node_size * 3.0
node_gdf[node_gdf.node_type == "Reservoir"].plot(
ax=ax, aspect=aspect, zorder=5, marker=reservoir_marker, **node_kwds)

Expand All @@ -295,24 +296,24 @@ def plot_network_gis(
ax=ax, aspect=aspect, zorder=2, legend_kwds=link_cbar_kwds, **link_kwds)

# plot pumps
# if len(wn_gis.pumps) >0:
# wn_gis.pumps.plot(ax=ax, color="purple", aspect=aspect)
# wn_gis.pumps["midpoint"] = wn_gis.pumps.geometry.interpolate(0.5, normalized=True)
# wn_gis.pumps["angle"] = wn_gis.pumps.apply(lambda row: _get_angle(row.geometry), axis=1)
# for idx , row in wn_gis.pumps.iterrows():
# x,y = row["midpoint"].x, row["midpoint"].y
# angle = row["angle"]
# ax.scatter(x,y, color="purple", s=100, marker=(3,0, angle-90))
if len(wn_gis.pumps) >0:
# wn_gis.pumps.plot(ax=ax, color="purple", aspect=aspect)
wn_gis.pumps["midpoint"] = wn_gis.pumps.geometry.interpolate(0.5, normalized=True)
wn_gis.pumps["angle"] = wn_gis.pumps.apply(lambda row: _get_angle(row.geometry), axis=1)
for idx , row in wn_gis.pumps.iterrows():
x,y = row["midpoint"].x, row["midpoint"].y
angle = row["angle"]
ax.scatter(x,y, color="black", s=100, marker=(3, 0, angle-90))
# ax.scatter(x,y, color="purple", s=100, marker=arrow_marker)

# plot valves
# if len(wn_gis.valves) >0:
# # wn_gis.valves.plot(ax=ax, color="green", aspect=aspect)
# wn_gis.valves["midpoint"] = wn_gis.valves.geometry.interpolate(0.5, normalized=True)
# wn_gis.valves["angle"] = wn_gis.valves.apply(lambda row: _get_angle(row.geometry), axis=1)
# for idx , row in wn_gis.valves.iterrows():
# x,y = row["midpoint"].x, row["midpoint"].y
# angle = row["angle"]
# ax.scatter(x,y, color="green", s=100, marker=(3,0, angle-90))
if len(wn_gis.valves) >0:
wn_gis.valves["midpoint"] = wn_gis.valves.geometry.interpolate(0.5, normalized=True)
wn_gis.valves["angle"] = wn_gis.valves.apply(lambda row: _get_angle(row.geometry), axis=1)
for idx , row in wn_gis.valves.iterrows():
x,y = row["midpoint"].x, row["midpoint"].y
angle = row["angle"]
ax.scatter(x,y, color="black", s=200, marker=(2,0, angle))

# annotation
if node_labels:
Expand All @@ -329,10 +330,13 @@ def plot_network_gis(
if filename:
plt.savefig(filename)

if show_plot is True:
plt.show(block=False)

return ax


def plot_network(wn, node_attribute=None, link_attribute=None, title=None,
def plot_network_nx(wn, node_attribute=None, link_attribute=None, title=None,
node_size=20, node_range=[None,None], node_alpha=1, node_cmap=None, node_labels=False,
link_width=1, link_range=[None,None], link_alpha=1, link_cmap=None, link_labels=False,
add_colorbar=True, node_colorbar_label='Node', link_colorbar_label='Link',
Expand Down

0 comments on commit 9bc07cf

Please sign in to comment.