Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update plot_network to use the geopandas plotting API in place of the networkx plotting API #451

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
2fe12ed
first draft of plot_network function with a GIS backend
kbonney Jun 11, 2024
48ac43e
incorporating the rest of the keywords from plot_network in plot_netw…
kbonney Sep 30, 2024
b1bd4f6
Merge branch 'main' into gis_plotting
kbonney Sep 30, 2024
3baa93a
combining link-like and node-like gis files
kbonney Oct 1, 2024
487511a
adding tank/reservoir marker shapes and implementating other type opt…
kbonney Oct 1, 2024
9bc07cf
set geopandas backend to the plot_network function name. update the g…
kbonney Oct 21, 2024
3bb3c4b
archive networkx plotting function in test suite
kbonney Oct 21, 2024
1a7e0a5
clean up function and add directed functionality
kbonney Nov 5, 2024
08f9ee9
add comparison test for plotting
kbonney Nov 5, 2024
1eaaade
handle cbar manually to avoid error in earthquake demo
kbonney Nov 5, 2024
f0c32d4
Merge remote-tracking branch 'usepa/main' into gis_plotting
kbonney Nov 5, 2024
80eaf96
extend test cases
kbonney Nov 5, 2024
bacf024
Merge remote-tracking branch 'usepa/main' into gis_plotting
kbonney Nov 20, 2024
0757c2c
Merge remote-tracking branch 'usepa/main' into gis_plotting
kbonney Nov 20, 2024
d0d59f4
fix bug caused by node_type no longer provided by GIS geodataframes, …
kbonney Nov 20, 2024
6e3fcf6
add extra test cases, remove unneccesary calls to plt.figure
kbonney Nov 20, 2024
a3b2576
setting compare to false
kbonney Nov 20, 2024
c267ac1
add alpha to colorbars
kbonney Nov 26, 2024
664e4a5
fix color bug on mpl 3.8, change plot_valve/pumps kwarg name, set asp…
kbonney Nov 26, 2024
6401716
fix test case to adjust for new kwarg names
kbonney Nov 26, 2024
105d243
add legend and plot all nodes regardless of node_attribute
kbonney Dec 2, 2024
c8f9fd5
add additional tests
kbonney Dec 2, 2024
0558ff8
remove old plot network code, clean up plot network, add default colo…
kbonney Dec 3, 2024
4a8f5a4
remove pump/valve direction plotting tests
kbonney Dec 3, 2024
046006a
update plot_network to use _format functions instead of _prepare func…
kbonney Dec 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
284 changes: 188 additions & 96 deletions wntr/graphics/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
water network model.
"""
import logging
import math
import networkx as nx
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.path as mpath
from matplotlib import animation
import matplotlib as mpl
import numpy as np

try:
import plotly
Expand All @@ -21,6 +25,24 @@

logger = logging.getLogger(__name__)


arrow_verts = [
(0.0, 0.0),
(0.5, 0.5),
(0.5, -0.5),
(0.0, 0.0),
]

arrow_marker = mpath.Path(arrow_verts)

def _get_angle(line, loc=0.5):
# calculate orientation angle
p1 = line.interpolate(loc-0.01, normalized=True)
p2 = line.interpolate(loc+0.01, normalized=True)
angle = math.atan2(p2.y-p1.y, p2.x - p1.x) # radians
angle = math.degrees(angle)
return angle

def _format_node_attribute(node_attribute, wn):

if isinstance(node_attribute, str):
Expand All @@ -42,12 +64,13 @@ def _format_link_attribute(link_attribute, wn):
link_attribute = dict(link_attribute)

return link_attribute

def plot_network(wn, node_attribute=None, link_attribute=None, title=None,
node_size=20, node_range=[None,None], node_alpha=1, node_cmap=None, node_labels=False,
link_width=1, link_range=[None,None], link_alpha=1, link_cmap=None, link_labels=False,
add_colorbar=True, node_colorbar_label='Node', link_colorbar_label='Link',
directed=False, ax=None, show_plot=True, filename=None):

def plot_network(
wn, node_attribute=None, link_attribute=None, title=None,
node_size=20, node_range=None, node_alpha=1, node_cmap=None, node_labels=False,
link_width=1, link_range=None, link_alpha=1, link_cmap=None, link_labels=False,
add_colorbar=True, node_colorbar_label=None, link_colorbar_label=None,
directed=False, legend=False, ax=None, show_plot=True, filename=None):
"""
Plot network graphic

Expand Down Expand Up @@ -126,7 +149,7 @@ def plot_network(wn, node_attribute=None, link_attribute=None, title=None,
ax: matplotlib axes object, optional
Axes for plotting (None indicates that a new figure with a single
axes will be used)

show_plot: bool, optional
If True, show plot with plt.show()

Expand All @@ -137,113 +160,182 @@ def plot_network(wn, node_attribute=None, link_attribute=None, title=None,
-------
ax : matplotlib axes object
"""

if ax is None: # create a new figure
plt.figure(facecolor='w', edgecolor='k')
ax = plt.gca()

# Graph
G = wn.to_graph()
if not directed:
G = G.to_undirected()

# Position
pos = nx.get_node_attributes(G,'pos')
if len(pos) == 0:
pos = None

# Define node properties
add_node_colorbar = add_colorbar
if title is not None:
ax.set_title(title)

aspect = "equal"

tank_marker = "D"
reservoir_marker = "s"

if link_cmap is None:
link_cmap = plt.get_cmap('Spectral_r')
if node_cmap is None:
node_cmap = plt.get_cmap('Spectral_r')

if link_range is None:
link_range = (None, None)
if node_range is None:
node_range = (None, None)

# use attribute name if no other label is provided
if node_colorbar_label is None and isinstance(node_attribute, str):
node_colorbar_label = node_attribute
if link_colorbar_label is None and isinstance(link_attribute, str):
link_colorbar_label = link_attribute

wn_gis = wn.to_gis()
# add node_type so that node assets can be plotted separately
wn_gis.junctions["node_type"] = "Junction"
wn_gis.tanks["node_type"] = "Tank"
wn_gis.reservoirs["node_type"] = "Reservoir"
link_gdf = pd.concat((wn_gis.pipes, wn_gis.pumps, wn_gis.valves))
node_gdf = pd.concat((wn_gis.junctions, wn_gis.tanks, wn_gis.reservoirs))

# Node attribute
node_kwds = {}
node_cbar = add_colorbar
if node_attribute is not None:
node_gdf["_attribute"] = _format_node_attribute(node_attribute, wn)
node_kwds["column"] = "_attribute"

# handle cbar/cmap
if isinstance(node_attribute, list):
if node_cmap is None:
node_cmap = ['red', 'red']
add_node_colorbar = False

if node_cmap is None:
node_cmap = plt.get_cmap('Spectral_r')
elif isinstance(node_cmap, list):
if len(node_cmap) == 1:
node_cmap = node_cmap*2
node_cmap = custom_colormap(len(node_cmap), node_cmap)

node_attribute = _format_node_attribute(node_attribute, wn)
nodelist,nodecolor = zip(*node_attribute.items())

node_kwds["cmap"] = custom_colormap(2,["red", "red"])
node_cbar = False
elif isinstance(node_attribute, (dict, pd.Series, str)):
node_kwds["cmap"] = node_cmap

# manually extract min/max if no range is given
node_attribute_values = node_gdf[node_kwds["column"]]
if node_range[0] is None:
node_kwds["vmin"] = np.nanmin(node_attribute_values)
else:
node_kwds["vmin"] = node_range[0]
if node_range[1] is None:
node_kwds["vmax"] = np.nanmax(node_attribute_values)
else:
node_kwds["vmax"] = node_range[1]
else:
raise TypeError("attribute must be dict, Series, list, or str")
else:
nodelist = None
nodecolor = 'k'
node_kwds["color"] = "black"
node_cbar = False

node_kwds["alpha"] = node_alpha
node_kwds["markersize"] = node_size

add_link_colorbar = add_colorbar
node_cbar_kwds = {}
node_cbar_kwds["shrink"] = 0.5
node_cbar_kwds["pad"] = 0.0
node_cbar_kwds["alpha"] = node_alpha
node_cbar_kwds["label"] = node_colorbar_label

# Link attribute
link_kwds = {}
link_cbar = add_colorbar
if link_attribute is not None:
link_gdf["_attribute"] = pd.Series(_format_link_attribute(link_attribute, wn))
link_kwds["column"] = "_attribute"

# handle cbar/cmap
if isinstance(link_attribute, list):
if link_cmap is None:
link_cmap = ['red', 'red']
add_link_colorbar = False

if link_cmap is None:
link_cmap = plt.get_cmap('Spectral_r')
elif isinstance(link_cmap, list):
if len(link_cmap) == 1:
link_cmap = link_cmap*2
link_cmap = custom_colormap(len(link_cmap), link_cmap)
link_kwds["cmap"] = custom_colormap(2,["red", "red"])
link_cbar = False
elif isinstance(link_attribute, (dict, pd.Series, str)):
link_kwds["cmap"] = link_cmap

link_attribute = _format_link_attribute(link_attribute, wn)

# Replace link_attribute dictionary defined as
# {link_name: attr} with {(start_node, end_node, link_name): attr}
attr = {}
for link_name, value in link_attribute.items():
link = wn.get_link(link_name)
attr[(link.start_node_name, link.end_node_name, link_name)] = value
link_attribute = attr

linklist,linkcolor = zip(*link_attribute.items())
# manually extract min/max if no range is given
link_attribute_values = link_gdf[link_kwds["column"]]
if link_range[0] is None:
link_kwds["vmin"] = np.nanmin(link_attribute_values)
else:
link_kwds["vmin"] = link_range[0]
if link_range[1] is None:
link_kwds["vmax"] = np.nanmax(link_attribute_values)
else:
link_kwds["vmax"] = link_range[1]
else:
raise TypeError("attribute must be dict, Series, list, or str")
else:
linklist = None
linkcolor = 'k'
link_kwds["color"] = "black"
link_cbar = False

if title is not None:
ax.set_title(title)

edge_background = nx.draw_networkx_edges(G, pos, edge_color='grey',
width=0.5, ax=ax)
link_kwds["linewidth"] = link_width
link_kwds["alpha"] = link_alpha

background_link_kwds = {}
background_link_kwds["color"] = "grey"
background_link_kwds["linewidth"] = link_width / 2
background_link_kwds["alpha"] = link_alpha

link_cbar_kwds = {}
link_cbar_kwds["shrink"] = 0.5
link_cbar_kwds["pad"] = 0.05
link_cbar_kwds["label"] = link_colorbar_label
link_cbar_kwds["alpha"] = link_alpha

missing_node_kwds={"color": "black"}
missing_link_kwds={"color": "black"}

# plot nodes - each type is plotted separately to allow for different marker types
node_gdf[node_gdf.node_type == "Junction"].plot(
ax=ax, aspect=aspect, zorder=3, legend=False, label="Junction", missing_kwds=missing_node_kwds, **node_kwds)

node_kwds["markersize"] = node_size * 2.0
node_gdf[node_gdf.node_type == "Tank"].plot(
ax=ax, aspect=aspect, zorder=4, marker=tank_marker, legend=False, label="Tank", missing_kwds=missing_node_kwds, **node_kwds)

nodes = nx.draw_networkx_nodes(G, pos,
nodelist=nodelist, node_color=nodecolor, node_size=node_size,
alpha=node_alpha, cmap=node_cmap, vmin=node_range[0], vmax = node_range[1],
linewidths=0, ax=ax)
edges = nx.draw_networkx_edges(G, pos, edgelist=linklist, arrows=directed,
edge_color=linkcolor, width=link_width, alpha=link_alpha, edge_cmap=link_cmap,
edge_vmin=link_range[0], edge_vmax=link_range[1], ax=ax)
node_kwds["markersize"] = node_size * 3.0
node_gdf[node_gdf.node_type == "Reservoir"].plot(
ax=ax, aspect=aspect, zorder=5, marker=reservoir_marker, legend=False, label="Reservoir", missing_kwds=missing_node_kwds,**node_kwds)

if node_cbar:
sm = mpl.cm.ScalarMappable(cmap=node_kwds["cmap"])
sm.set_clim(node_kwds["vmin"], node_kwds["vmax"])

node_cbar = ax.figure.colorbar(sm, ax=ax, **node_cbar_kwds)

# plot links
# background
link_gdf.plot(
ax=ax, aspect=aspect, zorder=1, legend=False, **background_link_kwds)

# main plot
link_gdf.plot(
ax=ax, aspect=aspect, zorder=2, legend=False, missing_kwds=missing_link_kwds, **link_kwds)

if link_cbar:
sm = mpl.cm.ScalarMappable(cmap=link_kwds["cmap"])
sm.set_clim(link_kwds["vmin"], link_kwds["vmax"])

link_cbar = ax.figure.colorbar(sm, ax=ax, **link_cbar_kwds)

if node_labels:
labels = dict(zip(wn.node_name_list, wn.node_name_list))
nx.draw_networkx_labels(G, pos, labels, font_size=7, ax=ax)
for x, y, label in zip(node_gdf.geometry.x, node_gdf.geometry.y, node_gdf.index):
ax.annotate(label, xy=(x, y))#, xytext=(3, 3),)# textcoords="offset points")

if link_labels:
labels = {}
for link_name in wn.link_name_list:
link = wn.get_link(link_name)
labels[(link.start_node_name, link.end_node_name)] = link_name
nx.draw_networkx_edge_labels(G, pos, labels, font_size=7, ax=ax)
if add_node_colorbar and node_attribute:
clb = plt.colorbar(nodes, shrink=0.5, pad=0, ax=ax)
clb.ax.set_title(node_colorbar_label, fontsize=10)
if add_link_colorbar and link_attribute:
if link_range[0] is None:
vmin = min(link_attribute.values())
else:
vmin = link_range[0]
if link_range[1] is None:
vmax = max(link_attribute.values())
else:
vmax = link_range[1]
sm = plt.cm.ScalarMappable(cmap=link_cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax))
sm.set_array([])
clb = plt.colorbar(sm, shrink=0.5, pad=0.05, ax=ax)
clb.ax.set_title(link_colorbar_label, fontsize=10)

midpoints = link_gdf.geometry.apply(lambda x: x.interpolate(0.5, normalized=True))
for x, y, label in zip(midpoints.geometry.x, midpoints.geometry.y, link_gdf.index):
ax.annotate(label, xy=(x, y))#, xytext=(3, 3),)# textcoords="offset points")

if directed:
link_gdf["_midpoint"] = link_gdf.geometry.interpolate(0.5, normalized=True)
link_gdf["_angle"] = link_gdf.apply(lambda row: _get_angle(row.geometry), axis=1)
for idx , row in link_gdf.iterrows():
x,y = row["_midpoint"].x, row["_midpoint"].y
angle = row["_angle"]
ax.scatter(x,y, color="black", s=50, marker=(3,0, angle-90))

if legend:
handles, labels = ax.get_legend_handles_labels()
leg = ax.legend(handles, labels, loc='upper right', title="Legend")

ax.axis('off')

if filename:
Expand Down
Loading
Loading